可视化模型的计算图 (graph)

首先，我们需要安装 graphviz Python 包，具体安装方法也可以参考 graphviz 的官方文档。注意：该 Python 包只是接口，渲染图像还需要系统中已安装 Graphviz 软件本身（提供 `dot` 可执行文件）。

pip install -U graphviz
[1]:
from yolort.models import yolov5s
from yolort.relay import get_trace_module
[ ]:
# Build the yolov5s model; `pretrained=True` presumably loads released weights
# (TODO confirm in yolort.models).
model = yolov5s(pretrained=True)
# Trace the model into a TorchScript module. The printed code below shows that the
# result wraps the network as `self.model` and returns a tuple of three tensors.
tracing_module = get_trace_module(model)
[3]:
# Show the TorchScript source generated for the traced wrapper's `forward`.
print(tracing_module.code)
def forward(self,
    x: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
  model = self.model
  _0, _1, _2, = (model).forward(x, )
  return (_0, _1, _2)

[4]:
from yolort.relay.ir_visualizer import TorchScriptVisualizer
[5]:
# Visualize the inner traced submodule (the `model` attribute that the wrapper's
# `forward` calls) rather than the outer wrapper itself.
visualizer = TorchScriptVisualizer(tracing_module.model)
[6]:
# Render the graph, descending into modules of the listed classes; modules of other
# classes are presumably drawn collapsed (confirm against TorchScriptVisualizer.render).
dot1 = visualizer.render(classes_to_visit={'YOLO', 'YOLOHead'})
[7]:
# Evaluating the object as the cell's last expression displays it inline (as SVG here).
dot1
[7]:
../_images/notebooks_model-graph-visualization_9_0.svg
[8]:
# Same rendering, but visiting the post-processing module instead of the head.
dot2 = visualizer.render(classes_to_visit={'YOLO', 'PostProcess'})
[9]:
# Display the second rendering inline.
dot2
[9]:
../_images/notebooks_model-graph-visualization_11_0.svg
[ ]:

View this document as a notebook: https://github.com/zhiqwang/yolov5-rt-stack/blob/main/notebooks/model-graph-visualization.ipynb