Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor graph visualization #165

Merged
merged 16 commits into from
Sep 16, 2021
Prev Previous commit
Next Next commit
Refactor the dot initialization
  • Loading branch information
zhiqwang committed Sep 16, 2021
commit bbaa13c35289dc699a095878e79163d1e3d79269
12 changes: 5 additions & 7 deletions yolort/relaying/ir_visualizer.py
Original file line number Diff line number Diff line change
@@ -36,20 +36,20 @@ def __init__(self, module):
self.absorbing_ops = ('aten::size', 'aten::_shape_as_tensor')

def render(self, classes_to_visit={'YOLO', 'YOLOHead'}):
return self.make_graph(self.module, classes_to_visit=classes_to_visit)
model_input = next(self.module.graph.inputs())
model_type = model_input.type().str().split('.')[-1]
dot = Digraph(format='svg', graph_attr={'label': model_type, 'labelloc': 't'})
self.make_graph(self.module, dot=dot, classes_to_visit=classes_to_visit)
return dot

def make_graph(self, module, dot=None, parent_dot=None, prefix="", input_preds=None,
classes_to_visit=None, classes_found=None):
graph = module.graph
preds = {}

self_input = next(graph.inputs())
self_type = self_input.type().str().split('.')[-1]
preds[self_input] = (set(), set()) # inps, ops

if dot is None:
dot = Digraph(format='svg', graph_attr={'label': self_type, 'labelloc': 't'})

for nr, i in enumerate(list(graph.inputs())[1:]):
name = f'{prefix}input_{i.debugName()}'
preds[i] = {name}, set()
@@ -143,8 +143,6 @@ def make_graph(self, module, dot=None, parent_dot=None, prefix="", input_preds=N
pr, op = preds[o]
self.make_edges(pr, f'input_{name}', name, op, dot)

return dot

def add_edge(self, dot, n1, n2):
if (n1, n2) not in self.seen_edges:
self.seen_edges.add((n1, n2))