1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99
class Node:
    """A graph vertex carrying an integer ``id`` and a display ``name``.

    Equality and hashing are inherited from ``object``, so two ``Node``
    instances are distinct vertices even if they share ``id`` and ``name``.
    """

    def __init__(self, id: int, name):
        self.name = name
        self.id = id

    def __repr__(self):
        # Rendered as ``Node(<id>,<name>)`` with no space after the comma.
        return "Node({},{})".format(self.id, self.name)

    # str() and repr() render identically.
    __str__ = __repr__
class Relation:
    """A directed edge ``from_node -> to_node`` tagged with an integer ``id``."""

    def __init__(self, id: int, from_node: Node, to_node: Node):
        self.id = id
        self.from_node, self.to_node = from_node, to_node

    def __repr__(self):
        # Rendered as ``Relation(<id>,<from>,<to>)``; the endpoints use their
        # own str() form.
        return "Relation({},{},{})".format(self.id, self.from_node, self.to_node)

    # str() and repr() render identically.
    __str__ = __repr__
class Graph:
    """Directed graph over ``Node`` objects that supports topological sorting.

    Attributes:
        node_list: the nodes passed to the constructor.
        relation_list: the directed edges passed to the constructor.
        require_map: maps each node to the set of nodes that have an edge
            pointing into it (its direct prerequisites).
    """

    def __init__(self, node_list: "list[Node]", relation_list: "list[Relation]"):
        """Index ``relation_list`` by target node into ``require_map``.

        Args:
            node_list: every node of the graph.
            relation_list: directed edges; each must expose ``from_node`` and
                ``to_node`` attributes.

        Raises:
            ValueError: if a relation's ``from_node`` or ``to_node`` is not
                one of the nodes in ``node_list``.
        """
        self.node_list = node_list
        self.relation_list = relation_list
        self.require_map = {node: set() for node in node_list}
        for relation in relation_list:
            # Reject dangling edges up front. Otherwise an unknown target
            # crashes obscurely, and an unknown source is later mistaken for a cycle.
            for endpoint in (relation.from_node, relation.to_node):
                if endpoint not in self.require_map:
                    raise ValueError(
                        f"relation {relation} references node {endpoint} "
                        "that is not in node_list"
                    )
            self.require_map[relation.to_node].add(relation.from_node)

    def topological_sort(self):
        """Return all nodes ordered so that every edge points forward.

        This uses Kahn's algorithm and runs in O(V + E). Ready nodes are
        emitted first-in first-out. The queue is seeded in ``node_list``
        order, and newly freed nodes are queued in ``relation_list`` order,
        so the result is deterministic.

        Returns:
            list of every distinct node, each placed after all of its
            prerequisites.

        Raises:
            RuntimeError: if the graph contains at least one cycle
                (including a self-loop).
        """
        from collections import deque

        # Distinct successors of each node, in relation_list order. The dict
        # keeps insertion order and collapses duplicate edges, which matches
        # require_map storing predecessors as a set.
        successors = {node: {} for node in self.require_map}
        for relation in self.relation_list:
            successors[relation.from_node][relation.to_node] = None

        # Number of prerequisites of each node that have not been emitted yet.
        pending = {node: len(prereqs) for node, prereqs in self.require_map.items()}
        ready = deque(node for node, count in pending.items() if count == 0)
        order = []
        while ready:
            node = ready.popleft()
            order.append(node)
            for successor in successors[node]:
                pending[successor] -= 1
                if pending[successor] == 0:
                    ready.append(successor)

        # Nodes on a cycle, or reachable from one, never reach zero pending
        # prerequisites, so they are never emitted.
        if len(order) < len(pending):
            raise RuntimeError("图中至少含有一个环")
        return order

    def get_no_incoming_node(self):
        """Yield (generator) every node that has no prerequisites."""
        for node, prereqs in self.require_map.items():
            if not prereqs:
                yield node

    def get_qualified_node(self, node_list: "list[Node]"):
        """Yield (generator) every node whose prerequisites all appear in ``node_list``.

        Nodes with no prerequisites always qualify.
        """
        done = set(node_list)  # built once rather than once per node
        for node, prereqs in self.require_map.items():
            if prereqs <= done:
                yield node

    def get_each_node_m_with_an_edge_e_from_n_to_m(self, n: "Node"):
        """Yield (generator) the target of every edge leaving ``n``.

        Targets come out in ``relation_list`` order, once per matching edge,
        so duplicate edges yield duplicate targets.
        """
        for relation in self.relation_list:
            if relation.from_node == n:
                yield relation.to_node

    @staticmethod
    def remove_edge(edge_list: "list[Relation]", from_node: "Node", to_node: "Node"):
        """Remove the first edge ``from_node -> to_node`` from ``edge_list`` in place.

        Does nothing if ``edge_list`` contains no such edge.
        """
        for index, edge in enumerate(edge_list):
            if edge.from_node == from_node and edge.to_node == to_node:
                del edge_list[index]
                return

    def check_no_incoming_edges(self, m: "Node", node_list: "list[Node]") -> bool:
        """Return True if every prerequisite of ``m`` appears in ``node_list``."""
        return set(self.require_map.get(m)).issubset(node_list)
if __name__ == "__main__":
    # Diamond-shaped demo graph: y -> x -> a and y -> b -> a.
    # Node ids are assigned 1..4 in the order y, x, b, a.
    demo_nodes = {label: Node(node_id, label) for node_id, label in enumerate("yxba", start=1)}
    demo_edges = [("y", "x"), ("y", "b"), ("x", "a"), ("b", "a")]
    demo_relations = [
        Relation(rel_id, demo_nodes[src], demo_nodes[dst])
        for rel_id, (src, dst) in enumerate(demo_edges, start=1)
    ]
    # Nodes are deliberately handed over out of dependency order.
    demo_graph = Graph([demo_nodes[label] for label in "xyba"], demo_relations)
    print(demo_graph.topological_sort())
|