Example #1
    def test_combine_graph_defs_src_gradient_func_non_unique(self):
        graph_def_a = GraphDef()
        text_format.Merge(
            '''
      library {
        gradient {
          function_name: "foo"
          gradient_func: "foo_grad"
        }
      }
    ''', graph_def_a)

        graph_def_b = GraphDef()
        text_format.Merge(
            '''
      library {
        gradient {
          function_name: "bar"
          gradient_func: "bar_grad"
        }
        gradient {
          function_name: "bar_baz"
          gradient_func: "bar_grad"
        }
      }
    ''', graph_def_b)

        with six.assertRaisesRegex(
                self, ValueError,
                'A GraphDef contains non-unique gradient function names: bar_grad'
        ):
            graph_util.combine_graph_defs(graph_def_a, graph_def_b)
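The expected message above pins down a uniqueness check over gradient_func names within a single GraphDef. A minimal sketch of that check (hypothetical helper; not the actual graph_util implementation):

    def _assert_unique_gradient_funcs(graph_def):
        # Hypothetical: mirrors the error raised in the test above.
        seen = set()
        for gradient in graph_def.library.gradient:
            if gradient.gradient_func in seen:
                raise ValueError(
                    'A GraphDef contains non-unique gradient function '
                    'names: %s' % gradient.gradient_func)
            seen.add(gradient.gradient_func)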
Example #2
    def test_combine_graph_defs_gradient_collison(self):
        graph_def_a = GraphDef()
        text_format.Merge(
            '''
      library {
        gradient {
          function_name: "foo"
          gradient_func: "foo_grad"
        }
      }
    ''', graph_def_a)

        graph_def_b = GraphDef()
        text_format.Merge(
            '''
      library {
        gradient {
          function_name: "bar"
          gradient_func: "bar_grad"
        }
        gradient {
          function_name: "foo_1"
          gradient_func: "foo_grad"
        }
      }
    ''', graph_def_b)

        with six.assertRaisesRegex(
                self, ValueError,
            ('share a gradient_func name but map to different functions: '
             'foo_grad')):
            graph_util.combine_graph_defs(graph_def_a, graph_def_b)
Example #3
    def test_merge_graph_defs_mismatch_version(self):
        graph_def_a = GraphDef()
        text_format.Parse(
            """
              node {
                name: "A"
                op: "Input"
              }
              versions {
                producer: 21
              }
          """,
            graph_def_a,
        )

        graph_def_b = GraphDef()
        text_format.Parse(
            """
              node {
                name: "A"
                op: "Input"
              }
              versions {
                producer: 100
              }
          """,
            graph_def_b,
        )

        with self.assertRaisesRegex(
            ValueError, "Cannot combine GraphDefs of different versions"
        ):
            graph_util.merge_graph_defs([graph_def_a, graph_def_b])
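The expected error suggests merge_graph_defs first verifies that all inputs share the same producer version. A hedged sketch of such a guard (hypothetical; the real implementation may differ):

    def _assert_same_producer(graph_defs):
        # Hypothetical guard mirroring the test's expected error.
        producers = {g.versions.producer for g in graph_defs}
        if len(producers) > 1:
            raise ValueError(
                'Cannot combine GraphDefs of different versions.')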
Example #4
    def test_combine_graph_defs_name_collided_different_content(self):
        graph_def_a = GraphDef()
        text_format.Merge(
            """
      node {
        name: "X"
        op: "Input"
      }
      node {
        name: "W"
        op: "Input"
      }
      node {
        name: "Y"
        op: "MatMul"
        input: "X"
        input: "W"
      }
      versions {
        producer: 21
      }
    """,
            graph_def_a,
        )

        graph_def_b = GraphDef()
        text_format.Merge(
            """
      node {
        name: "X"
        op: "Input"
        device: "cpu:0"
      }
      node {
        name: "Z"
        op: "Input"
      }
      node {
        name: "Q"
        op: "MatMul"
        input: "X"
        input: "Z"
      }
      versions {
        producer: 21
      }
    """,
            graph_def_b,
        )

        with six.assertRaisesRegex(
                self,
                ValueError,
            ("Cannot combine GraphDefs because nodes share a name but "
             "contents are different: X"),
        ):
            graph_util.combine_graph_defs(graph_def_a, graph_def_b)
Example #5
    def test_merge_graph_defs_gradient(self):
        expected_proto = """
            library {
              gradient {
                function_name: "graph_1_foo"
                gradient_func: "graph_1_foo_grad"
              }
              gradient {
                function_name: "graph_2_foo"
                gradient_func: "graph_2_foo_grad"
              }
              gradient {
                function_name: "graph_2_bar"
                gradient_func: "graph_2_bar_grad"
              }
            }
        """

        graph_def_a = GraphDef()
        text_format.Parse(
            """
                library {
                  gradient {
                    function_name: "foo"
                    gradient_func: "foo_grad"
                  }
                }
            """,
            graph_def_a,
        )

        graph_def_b = GraphDef()
        text_format.Parse(
            """
                library {
                  gradient {
                    function_name: "foo"
                    gradient_func: "foo_grad"
                  }
                  gradient {
                    function_name: "bar"
                    gradient_func: "bar_grad"
                  }
                }
            """,
            graph_def_b,
        )

        self.assertProtoEquals(
            expected_proto,
            graph_util.merge_graph_defs([graph_def_a, graph_def_b]),
        )
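The expected proto shows that when merging more than one GraphDef, function and gradient names get a graph_<i>_ prefix. A rough sketch of that renaming for gradient entries (hypothetical, assuming the prefix is applied uniformly as the expected proto suggests):

    def _prefix_gradients(graph_def, index):
        # Hypothetical: prepend 'graph_<index>_', as the expected proto
        # above suggests merge_graph_defs does for multiple inputs.
        prefix = 'graph_%d_' % index
        for gradient in graph_def.library.gradient:
            gradient.function_name = prefix + gradient.function_name
            gradient.gradient_func = prefix + gradient.gradient_func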
Example #6
    def test_combine_graph_defs_gradient(self):
        expected_proto = """
      library {
        gradient {
          function_name: "foo"
          gradient_func: "foo_grad"
        }
        gradient {
          function_name: "bar"
          gradient_func: "bar_grad"
        }
      }
    """

        graph_def_a = GraphDef()
        text_format.Merge(
            """
      library {
        gradient {
          function_name: "foo"
          gradient_func: "foo_grad"
        }
      }
    """,
            graph_def_a,
        )

        graph_def_b = GraphDef()
        text_format.Merge(
            """
      library {
        gradient {
          function_name: "foo"
          gradient_func: "foo_grad"
        }
        gradient {
          function_name: "bar"
          gradient_func: "bar_grad"
        }
      }
    """,
            graph_def_b,
        )

        self.assertProtoEquals(
            expected_proto,
            graph_util.combine_graph_defs(graph_def_a, graph_def_b),
        )
Example #7
def _operators_to_graph_def(
        shapes,
        ops,
        replace_colons='$',
        with_ssa=True,
        with_gradient_scope=True,
        track_blob_names=None,  # pass an empty dict to track blob names
):
    if track_blob_names is not None:
        track_blob_names.clear()
        track_blob_names.update(_get_blob_names(ops))
    if replace_colons:
        _replace_colons(shapes, track_blob_names, ops, replace_colons)
    if with_ssa:
        _convert_to_ssa(shapes, track_blob_names, ops)
    if with_gradient_scope:
        _add_gradient_scope(shapes, track_blob_names, ops)
    _fill_missing_operator_names(ops)
    g = GraphDef()
    producing_ops = {}
    blobs = set()
    for op in ops:
        g.node.extend([_operator_to_node(shapes, op)])
        for input_blob in op.input:
            blobs.add(input_blob)
        for i, output_blob in enumerate(op.output):
            blobs.add(output_blob)
            producing_ops.setdefault(output_blob, []).append((op, i))
    for blob in blobs:
        g.node.extend([_blob_to_node(producing_ops, shapes, blob)])
    return g
Example #8
    def test_pytorch_graph(self):
        dummy_input = (torch.zeros(1, 3),)

        class myLinear(torch.nn.Module):
            def __init__(self):
                super(myLinear, self).__init__()
                self.l = torch.nn.Linear(3, 5)

            def forward(self, x):
                return self.l(x)

        with self.createSummaryWriter() as w:
            w.add_graph(myLinear(), dummy_input)

        actual_proto, _ = graph(myLinear(), dummy_input)

        expected_str = read_expected_content(self)
        expected_proto = GraphDef()
        text_format.Parse(expected_str, expected_proto)

        self.assertEqual(len(expected_proto.node), len(actual_proto.node))
        for i in range(len(expected_proto.node)):
            expected_node = expected_proto.node[i]
            actual_node = actual_proto.node[i]
            self.assertEqual(expected_node.name, actual_node.name)
            self.assertEqual(expected_node.op, actual_node.op)
            self.assertEqual(expected_node.input, actual_node.input)
            self.assertEqual(expected_node.device, actual_node.device)
            self.assertEqual(
                sorted(expected_node.attr.keys()), sorted(actual_node.attr.keys()))
Example #9
def keras_model_to_graph_def(keras_layer):
  """Returns a GraphDef representation of the Keras model in a dict form.

  Note that it only supports models that implement to_json().

  Args:
    keras_layer: A dict from Keras model.to_json().

  Returns:
    A GraphDef representation of the layers in the model.
  """
  input_to_layer = {}
  model_name_to_output = {}
  g = GraphDef()

  # Sequential model layers do not have a field "inbound_nodes" but
  # instead are defined implicitly via order of layers.
  prev_node_name = None

  for (name_scope, layer) in _walk_layers(keras_layer):
    if _is_model(layer):
      (input_to_layer, model_name_to_output, prev_node_name) = _update_dicts(
          name_scope, layer, input_to_layer, model_name_to_output, prev_node_name)
      continue

    layer_config = layer.get('config')
    node_name = _scoped_name(name_scope, layer_config.get('name'))

    node_def = g.node.add()
    node_def.name = node_name

    if layer.get('class_name') is not None:
      keras_cls_name = layer.get('class_name').encode('ascii')
      node_def.attr['keras_class'].s = keras_cls_name

    if layer_config.get('dtype') is not None:
      tf_dtype = dtypes.as_dtype(layer_config.get('dtype'))
      node_def.attr['dtype'].type = tf_dtype.as_datatype_enum

    if layer.get('inbound_nodes') is not None:
      for maybe_inbound_node in layer.get('inbound_nodes'):
        inbound_nodes = _norm_to_list_of_layers(maybe_inbound_node)
        for [name, size, index, _] in inbound_nodes:
          inbound_name = _scoped_name(name_scope, name)
          # An input to a layer can be the output of a model. In that
          # case, the inbound node name is the model's name; remap it to
          # the model's output layer. Since a model can have multiple
          # outputs, make sure we pick the right output_layer from the
          # model.
          inbound_node_names = model_name_to_output.get(
              inbound_name, [inbound_name])
          node_def.input.append(inbound_node_names[index])
    elif prev_node_name is not None:
      node_def.input.append(prev_node_name)

    if node_name in input_to_layer:
      node_def.input.append(input_to_layer.get(node_name))

    prev_node_name = node_def.name

  return g
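Since keras_layer is documented as the dict form of model.to_json(), a typical call presumably looks like this (hedged sketch; assumes TensorFlow with Keras is installed):

    import json
    import tensorflow as tf

    model = tf.keras.Sequential(
        [tf.keras.layers.Dense(5, input_shape=(3,), name='dense')])
    # to_json() returns a JSON string; the converter expects the parsed dict.
    graph_def = keras_model_to_graph_def(json.loads(model.to_json()))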
Example #10
    def test_tensorboard_graphs(self):
        model = model_helper.ModelHelper(name="overfeat")
        data, label = brew.image_input(
            model, ["db"], ["data", "label"], is_test=0
        )
        with core.NameScope("conv1"):
            conv1 = brew.conv(model, data, "conv1", 3, 96, 11, stride=4)
            relu1 = brew.relu(model, conv1, conv1)
            pool1 = brew.max_pool(model, relu1, "pool1", kernel=2, stride=2)
        with core.NameScope("classifier"):
            fc = brew.fc(model, pool1, "fc", 4096, 1000)
            pred = brew.softmax(model, fc, "pred")
            xent = model.LabelCrossEntropy([pred, label], "xent")
            loss = model.AveragedLoss(xent, "loss")
        model.AddGradientOperators([loss], skip=1)

        c2_dir = tempfile.mkdtemp()
        tf_dir = tempfile.mkdtemp()

        with open(os.path.join(c2_dir, "init"), "w") as f:
            f.write(str(model.param_init_net.Proto()))
        with open(os.path.join(c2_dir, "net"), "w") as f:
            f.write(str(model.net.Proto()))
        runner = click.testing.CliRunner()
        result = runner.invoke(
            tb.cli,
            ["tensorboard-graphs",
             "--c2-netdef", os.path.join(c2_dir, "init"),
             "--c2-netdef", os.path.join(c2_dir, "net"),
             "--tf-dir", tf_dir])
        self.assertEqual(result.exit_code, 0)
        entries = list(os.walk(tf_dir))
        self.assertEqual(len(entries), 1)
        ((d, _, (fname,)),) = entries
        self.assertEqual(tf_dir, d)
        events = load_events(os.path.join(tf_dir, fname))
        self.assertEqual(len(events), 3)
        events = events[1:]
        nets = [model.param_init_net, model.net]
        for i, (event, net) in enumerate(zip(events, nets), start=1):
            self.assertEqual(event.step, i)
            self.assertEqual(event.wall_time, i)
            g = GraphDef()
            g.ParseFromString(event.graph_def)
            self.assertMultiLineEqual(
                str(g),
                str(tb_exporter.nets_to_graph_def([net])))
Example #11
    def test_combine_graph_defs_src_nodes_duplicate_keys(self):
        graph_def_a = GraphDef()
        text_format.Merge(
            """
      node {
        name: "X"
        op: "Input"
      }
      node {
        name: "Y"
        op: "Input"
      }
      versions {
        producer: 21
      }
    """,
            graph_def_a,
        )

        graph_def_b = GraphDef()
        text_format.Merge(
            """
      node {
        name: "W"
        op: "Input"
        device: "cpu:0"
      }
      node {
        name: "W"
        op: "Input"
      }
      versions {
        producer: 21
      }
    """,
            graph_def_b,
        )

        with six.assertRaisesRegex(
                self, ValueError,
                "A GraphDef contains non-unique node names: W"):
            graph_util.combine_graph_defs(graph_def_a, graph_def_b)
Example #12
def graph(model):
    """Converts a crypten.nn graph for consumption by TensorBoard."""

    # convert individual module to graph:
    assert isinstance(model, nn.Module), "model must be crypten.nn.Module"
    if not isinstance(model, nn.Graph):
        graph = nn.Graph("input", "output")
        graph.add_module("output", model, ["input"])
        model = graph

    # create mapping to more interpretable node naming:
    mapping = {input_name: input_name for input_name in model.input_names}
    modules = {name: module for name, module in model.named_modules()}
    for name, module in modules.items():
        # strip the "<class 'crypten.nn.module." prefix and trailing "'>"
        op = str(type(module))[26:-2]
        mapping[name] = "%s_%s" % (op, name)

    # create input variables:
    nodes = [
        NodeDef(
            name=mapping[input_name].encode(encoding="utf_8"),
            op="Variable",
            input=[],
        ) for input_name in model.input_names
    ]

    # loop all graph connections:
    for output_name, input_names in model._graph.items():

        # get parameters and type of module:
        module = modules[output_name]
        op = str(type(module))
        input_names = [mapping[name] for name in input_names]
        parameters = [
            "%s: %s" % (name, parameter.size())
            for name, parameter in module.named_parameters()
        ]
        parameter_string = "; ".join(parameters).encode(encoding="utf_8")

        # add to graph:
        nodes.append(
            NodeDef(
                name=mapping[output_name].encode(encoding="utf_8"),
                op=op,
                input=input_names,
                attr={"attr": AttrValue(s=parameter_string)},
            ))

    # return graph definition:
    return GraphDef(node=nodes, versions=VersionDef(producer=22))
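A hedged usage sketch (assumes crypten is installed; crypten.nn.from_pytorch is used here to obtain a crypten.nn.Module from a plain PyTorch module):

    import crypten
    import torch

    crypten.init()
    pt_model = torch.nn.Linear(3, 5)
    ct_model = crypten.nn.from_pytorch(pt_model, torch.empty(1, 3))
    graph_def = graph(ct_model)  # GraphDef ready for a TensorBoard writer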
Example #13
    def test_combine_graph_defs_dst_gradient_func_non_unique(self):
        graph_def_a = GraphDef()
        text_format.Merge(
            """
      library {
        gradient {
          function_name: "foo"
          gradient_func: "foo_grad"
        }
        gradient {
          function_name: "foo_bar"
          gradient_func: "foo_grad"
        }
      }
    """,
            graph_def_a,
        )

        graph_def_b = GraphDef()
        text_format.Merge(
            """
      library {
        gradient {
          function_name: "bar"
          gradient_func: "bar_grad"
        }
      }
    """,
            graph_def_b,
        )

        with six.assertRaisesRegex(
                self,
                ValueError,
                "A GraphDef contains non-unique gradient function names: foo_grad",
        ):
            graph_util.combine_graph_defs(graph_def_a, graph_def_b)
Example #14
    def test_combine_graph_defs_dst_nodes_duplicate_keys(self):
        graph_def_a = GraphDef()
        text_format.Merge(
            '''
      node {
        name: "X"
        op: "Input"
      }
      node {
        name: "X"
        op: "Input"
      }
      versions {
        producer: 21
      }
    ''', graph_def_a)

        graph_def_b = GraphDef()
        text_format.Merge(
            '''
      node {
        name: "X"
        op: "Input"
      }
      node {
        name: "Z"
        op: "Input"
      }
      versions {
        producer: 21
      }
    ''', graph_def_b)

        with six.assertRaisesRegex(
                self, ValueError,
                'A GraphDef contains non-unique node names: X'):
            graph_util.combine_graph_defs(graph_def_a, graph_def_b)
Example #15
    def __init__(self, model, dummy_input, verbose=False):
        super().__init__(model, dummy_input)

        from tensorboard.compat.proto.config_pb2 import RunMetadata
        from tensorboard.compat.proto.graph_pb2 import GraphDef
        from tensorboard.compat.proto.step_stats_pb2 import StepStats, DeviceStepStats
        from tensorboard.compat.proto.versions_pb2 import VersionDef

        list_of_nodes = self.parse(self.trace.graph, self.trace, dummy_input)
        if verbose:
            print(self.trace.graph)
        self.stepstats = RunMetadata(step_stats=StepStats(
            dev_stats=[DeviceStepStats(device="/device:CPU:0")]))
        self.graph_def = GraphDef(node=list_of_nodes,
                                  versions=VersionDef(producer=22))
Example #16
    def test_nested_nn_squential(self):

        dummy_input = torch.randn(2, 3)

        class InnerNNSquential(torch.nn.Module):
            def __init__(self, dim1, dim2):
                super().__init__()
                self.inner_nn_squential = torch.nn.Sequential(
                    torch.nn.Linear(dim1, dim2),
                    torch.nn.Linear(dim2, dim1),
                )

            def forward(self, x):
                x = self.inner_nn_squential(x)
                return x

        class OuterNNSquential(torch.nn.Module):
            def __init__(self, dim1=3, dim2=4, depth=2):
                super().__init__()
                layers = []
                for _ in range(depth):
                    layers.append(InnerNNSquential(dim1, dim2))
                self.outer_nn_squential = torch.nn.Sequential(*layers)

            def forward(self, x):
                x = self.outer_nn_squential(x)
                return x

        with self.createSummaryWriter() as w:
            w.add_graph(OuterNNSquential(), dummy_input)

        actual_proto, _ = graph(OuterNNSquential(), dummy_input)

        expected_str = read_expected_content(self)
        expected_proto = GraphDef()
        text_format.Parse(expected_str, expected_proto)

        self.assertEqual(len(expected_proto.node), len(actual_proto.node))
        for i in range(len(expected_proto.node)):
            expected_node = expected_proto.node[i]
            actual_node = actual_proto.node[i]
            self.assertEqual(expected_node.name, actual_node.name)
            self.assertEqual(expected_node.op, actual_node.op)
            self.assertEqual(expected_node.input, actual_node.input)
            self.assertEqual(expected_node.device, actual_node.device)
            self.assertEqual(sorted(expected_node.attr.keys()),
                             sorted(actual_node.attr.keys()))
Example #17
def parse(graph):
    nodes_proto = []
    nodes = []
    import itertools

    for node in itertools.chain(graph.input, graph.output):
        nodes_proto.append(node)

    for node in nodes_proto:
        print(node.name)
        shapeproto = TensorShapeProto(dim=[
            TensorShapeProto.Dim(size=d.dim_value)
            for d in node.type.tensor_type.shape.dim
        ])
        nodes.append(
            NodeDef(
                name=node.name.encode(encoding="utf_8"),
                op="Variable",
                input=[],
                attr={
                    "dtype": AttrValue(type=node.type.tensor_type.elem_type),
                    "shape": AttrValue(shape=shapeproto),
                },
            ))

    for node in graph.node:
        _attr = []
        for s in node.attribute:
            _attr.append(" = ".join([str(f[1]) for f in s.ListFields()]))
        attr = ", ".join(_attr).encode(encoding="utf_8")
        print(node.output[0])
        nodes.append(
            NodeDef(
                name=node.output[0].encode(encoding="utf_8"),
                op=node.op_type,
                input=node.input,
                attr={"parameters": AttrValue(s=attr)},
            ))

    # Two-pass token replacement: append the op name to the object id.
    # NOTE: this mapping is built but never applied to the nodes below.
    mapping = {}
    for node in nodes:
        mapping[node.name] = node.op + "_" + node.name

    return GraphDef(node=nodes, versions=VersionDef(producer=22))
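parse here reads graph.input, graph.output, and graph.node, which matches an ONNX GraphProto, so usage presumably looks like this (hedged sketch; 'model.onnx' is a placeholder path):

    import onnx

    onnx_model = onnx.load('model.onnx')  # placeholder path
    graph_def = parse(onnx_model.graph)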
Example #18
def graph(model, args, verbose=False, use_strict_trace=True):
    """
    This method processes a PyTorch model and produces a `GraphDef` proto
    that can be logged to TensorBoard.

    Args:
      model (PyTorch module): The model to be parsed.
      args (tuple): input tensor[s] for the model.
      verbose (bool): Whether to print out verbose information while
        processing.
      use_strict_trace (bool): Whether to pass keyword argument `strict` to
        `torch.jit.trace`. Pass False when you want the tracer to
        record your mutable container types (list, dict).
    """
    with torch.onnx.select_model_mode_for_export(
            model,
            torch.onnx.TrainingMode.EVAL):  # TODO: move outside of torch.onnx?
        try:
            trace = torch.jit.trace(model, args, strict=use_strict_trace)
            graph = trace.graph
            torch._C._jit_pass_inline(graph)
        except RuntimeError as e:
            print(e)
            print('Error occurred; no graph was saved.')
            raise

    if verbose:
        print(graph)
    list_of_nodes = parse(graph, trace, args)
    # We are hardcoding that this was run on CPU even though it might have actually
    # run on GPU. Note this is what is shown in TensorBoard and has no bearing
    # on actual execution.
    # TODO: See if we can extract GPU vs CPU information from the PyTorch model
    # and pass it correctly to TensorBoard.
    #
    # Definition of StepStats and DeviceStepStats can be found at
    # https://github.com/tensorflow/tensorboard/blob/master/tensorboard/plugins/graph/tf_graph_common/test/graph-test.ts
    # and
    # https://github.com/tensorflow/tensorboard/blob/master/tensorboard/compat/proto/step_stats.proto
    stepstats = RunMetadata(step_stats=StepStats(
        dev_stats=[DeviceStepStats(device="/device:CPU:0")]))
    return GraphDef(node=list_of_nodes,
                    versions=VersionDef(producer=22)), stepstats
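A hedged usage sketch (assumes torch is importable and the function above is in scope):

    import torch

    model = torch.nn.Linear(3, 5)
    graph_def, stepstats = graph(model, (torch.zeros(1, 3),))
    # The pair can then be handed to a TensorBoard file writer, e.g.
    # writer._get_file_writer().add_graph((graph_def, stepstats))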
Example #19
    def test_merge_graph_defs_single_graph_def_no_prefix(self):
        graph_def_a = GraphDef()
        text_format.Parse(
            """
              node {
                name: "A"
                op: "Input"
              }
              versions {
                producer: 21
              }
          """,
            graph_def_a,
        )

        self.assertProtoEquals(
            graph_def_a,
            graph_util.merge_graph_defs([graph_def_a]),
        )
Example #20
    def _test_graph(self, model, dummy_input, expected_file):
        actual_proto, _ = build_graph(model, dummy_input)

        assert os.path.exists(expected_file), expected_file
        with open(expected_file, "r") as f:
            expected_str = f.read()

        expected_proto = GraphDef()
        text_format.Parse(expected_str, expected_proto)

        self.assertEqual(len(expected_proto.node), len(actual_proto.node))
        for i in range(len(expected_proto.node)):
            expected_node = expected_proto.node[i]
            actual_node = actual_proto.node[i]
            self.assertEqual(expected_node.name, actual_node.name)
            self.assertEqual(expected_node.op, actual_node.op)
            self.assertEqual(expected_node.input, actual_node.input)
            self.assertEqual(expected_node.device, actual_node.device)
            self.assertEqual(sorted(expected_node.attr.keys()),
                             sorted(actual_node.attr.keys()))
Example #21
    def test_combine_graph_defs(self):
        expected_proto = '''
      node {
        name: "X"
        op: "Input"
      }
      node {
        name: "W"
        op: "Input"
      }
      node {
        name: "Y"
        op: "MatMul"
        input: "X"
        input: "W"
      }
      node {
        name: "A"
        op: "Input"
      }
      node {
        name: "B"
        op: "Input"
      }
      node {
        name: "C"
        op: "MatMul"
        input: "A"
        input: "B"
      }
      versions {
        producer: 21
      }
    '''

        graph_def_a = GraphDef()
        text_format.Merge(
            '''
      node {
        name: "X"
        op: "Input"
      }
      node {
        name: "W"
        op: "Input"
      }
      node {
        name: "Y"
        op: "MatMul"
        input: "X"
        input: "W"
      }
      versions {
        producer: 21
      }
    ''', graph_def_a)

        graph_def_b = GraphDef()
        text_format.Merge(
            '''
      node {
        name: "A"
        op: "Input"
      }
      node {
        name: "B"
        op: "Input"
      }
      node {
        name: "C"
        op: "MatMul"
        input: "A"
        input: "B"
      }
      versions {
        producer: 21
      }
    ''', graph_def_b)

        self.assertProtoEquals(
            expected_proto,
            graph_util.combine_graph_defs(graph_def_a, graph_def_b))
Example #22
def _operators_to_graph_def(shapes,
                            ops,
                            colon_replacement='$',
                            with_ssa=True,
                            with_gradient_scope=True,
                            blob_name_tracker=None,
                            show_simplified=False,
                            custom_rename=None):
    '''
    Main function to convert set of operators to a graph.

    Args:
        shapes: Dictionary mapping blob names to their shapes/dimensions.
        ops: List of Caffe2 operators, representing some computation graph.
        ### **kwargs (model_to_graph_def, nets_to_graph_def, protos_to_graph_def) ###
        colon_replacement: Symbol to replace ':' with. ':i' in TF has a special
            meaning, so we need to replace it with a non-conflicting symbol.
        with_ssa: Boolean; whether to convert blob names to SSA form.
        with_gradient_scope: Boolean; whether to add gradient scoping to blob
            names.
        blob_name_tracker: Dictionary tracking names of blobs (inputs/outputs
            from operators).
        show_simplified: Whether to show a simplified version of the model
            graph. Sets all of the following values:
                clear_debug_info: Boolean representing whether to silence debug
                    info (which can be very verbose)
                show_forward_only: Boolean representing whether to only show
                    blobs involved in the forward pass
                show_cpu_only: Boolean representing whether to only show blobs
                    that are not associated with a gpu
                use_tensorflow_naming: Boolean representing whether to convert
                    some common Caffe2 naming conventions to their Tensorflow
                    counterparts
        custom_rename: Function string -> string that defines a custom
            renaming function to use.

    Returns:
        current_graph: GraphDef representing the computation graph formed by the
            set of operators.
    '''
    if blob_name_tracker is not None:
        blob_name_tracker.clear()
    else:
        blob_name_tracker = {}

    blob_name_tracker.update(_get_blob_names(ops))

    _clear_debug_info(ops, show_simplified)  # clear_debug_info
    ops = _filter_ops(ops, _check_if_forward,
                      show_simplified)  # show_forward_only
    ops = _filter_ops(ops, _check_if_cpu, show_simplified)  # show_cpu_only
    if custom_rename:
        _rename_all(shapes, blob_name_tracker, ops, custom_rename)
    if colon_replacement:
        _replace_colons(shapes, blob_name_tracker, ops, colon_replacement)
    if with_ssa:
        _convert_to_ssa(shapes, blob_name_tracker, ops)
    if with_gradient_scope:
        _add_gradient_scope(shapes, blob_name_tracker, ops)
    _fill_missing_operator_names(ops)
    if show_simplified:  # use_tensorflow_naming
        _rename_tensorflow_style(shapes, blob_name_tracker, ops)
    producing_ops: Dict[caffe2_pb2.OperatorDef, List] = {}
    blobs = set()
    input_blobs, inter_blobs, _ = _compute_in_out(ops)
    current_graph = GraphDef()
    seen = set(input_blobs)
    for op in ops:
        nodes_from_op = _operator_to_node_simp(op, inter_blobs, seen) if \
            show_simplified else \
            [_operator_to_node(shapes, op)]  # .extend() expects an iterable
        current_graph.node.extend(nodes_from_op)
        for input_blob in op.input:
            blobs.add(input_blob)
        for i, output_blob in enumerate(op.output):
            blobs.add(output_blob)
            producing_ops.setdefault(output_blob, []).append((op, i))

    if show_simplified:
        # Show a cleaner, easier-to-interpret version of the model graph
        blobs = input_blobs

    for blob in sorted(blobs):
        current_graph.node.extend([_blob_to_node(producing_ops, {}, blob)])

    return current_graph
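A hedged usage sketch with Caffe2 (assumes caffe2 is installed; an empty shapes dict is passed, on the assumption that missing shape metadata is tolerated for display purposes):

    from caffe2.python import brew, model_helper

    model = model_helper.ModelHelper(name='toy')
    brew.fc(model, 'data', 'fc', dim_in=3, dim_out=5)
    ops = list(model.net.Proto().op)
    graph_def = _operators_to_graph_def({}, ops)  # {}: no shape info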
Example #23
def graph(model, args, verbose=False, operator_export_type='ONNX', omit_useless_nodes=True):
    """
    This method processes a PyTorch model and produces a `GraphDef` proto
    that can be logged to TensorBoard.

    Args:
      model (PyTorch module): The model to be parsed.
      args (tuple): input tensor[s] for the model.
      verbose (bool): Whether to print out verbose information while
        processing.
      operator_export_type (str): One of 'ONNX', 'ONNX_ATEN', or 'RAW'.
        Defaults to 'ONNX' because it produces the most visually
        understandable output.
      omit_useless_nodes (boolean): Whether to omit nodes that carry no
        useful information from the graph.
    """
    operator_export_type = getattr(OperatorExportTypes, operator_export_type)


    with torch.onnx.set_training(model, False):
        try:
            trace, _ = torch.jit.get_trace_graph(model, args)
        except RuntimeError:
            print('Error occurred; no graph was saved.')
            _ = model(*args)  # don't catch, just print the error message
            print("Checking if it's onnx problem...")
            try:
                import tempfile
                torch.onnx.export(
                    model, args, tempfile.TemporaryFile(), verbose=True)
            except RuntimeError:
                print("Your model cannot be exported by onnx, please report to onnx team")
            # Create an object matching
            # https://github.com/tensorflow/tensorboard/blob/master/tensorboard/compat/proto/graph.proto
            # The producer version has been reverse engineered from standard
            # TensorBoard logged data.
            return GraphDef(versions=VersionDef(producer=22))

    try:
        # An optimized graph helps debug at a higher level. Users can focus
        # on connections between big modules such as Linear instead of W, x,
        # bias, matmul, etc. Honestly, most users don't care about that
        # level of node detail.
        _optimize_trace(trace, operator_export_type)
    except RuntimeError as e:
        # Optimizing the trace might fail (due to bad scopes in some cases
        # we've seen) and we don't want graph visualization to fail in that
        # case, so log a warning and display the non-optimized graph.
        logging.warning(e)
    graph = trace.graph()
    if verbose:
        print(graph)
    list_of_nodes, node_stats = parse(graph, args, omit_useless_nodes)
    # We are hardcoding that this was run on CPU even though it might have actually
    # run on GPU. Note this is what is shown in TensorBoard and has no bearing
    # on actual execution.
    # TODO: See if we can extract GPU vs CPU information from the PyTorch model
    # and pass it correctly to TensorBoard.
    #
    # Definition of StepStats and DeviceStepStats can be found at
    # https://github.com/tensorflow/tensorboard/blob/master/tensorboard/plugins/graph/tf_graph_common/test/graph-test.ts
    # and
    # https://github.com/tensorflow/tensorboard/blob/master/tensorboard/compat/proto/step_stats.proto
    stepstats = RunMetadata(step_stats=StepStats(dev_stats=[DeviceStepStats(device="/device:CPU:0",
                                                                            node_stats=node_stats)]))
    return GraphDef(node=list_of_nodes, versions=VersionDef(producer=22)), stepstats
Example #24
    def add_graph(self, model, *args, **kargs):
        visitor = GraphVisitor(model, *args, **kargs)
        stepstats = RunMetadata(step_stats=StepStats(
            dev_stats=[DeviceStepStats(device="/device:CPU:0")]))
        graph = GraphDef(node=visitor._graph, versions=VersionDef(producer=22))
        self._get_file_writer().add_graph((graph, stepstats))
Example #25
def keras_model_to_graph_def(keras_layer):
    """Returns a GraphDef representation of the Keras model in a dict form.

    Note that it only supports models that implement to_json().

    Args:
      keras_layer: A dict from Keras model.to_json().

    Returns:
      A GraphDef representation of the layers in the model.
    """
    input_to_layer = {}
    model_name_to_output = {}
    g = GraphDef()

    # Sequential model layers do not have a field "inbound_nodes" but
    # instead are defined implicitly via order of layers.
    prev_node_name = None

    for (name_scope, layer) in _walk_layers(keras_layer):
        if _is_model(layer):
            (
                input_to_layer,
                model_name_to_output,
                prev_node_name,
            ) = _update_dicts(
                name_scope,
                layer,
                input_to_layer,
                model_name_to_output,
                prev_node_name,
            )
            continue

        layer_config = layer.get("config")
        node_name = _scoped_name(name_scope, layer_config.get("name"))

        node_def = g.node.add()
        node_def.name = node_name

        if layer.get("class_name") is not None:
            keras_cls_name = layer.get("class_name").encode("ascii")
            node_def.attr["keras_class"].s = keras_cls_name

        dtype_or_policy = layer_config.get("dtype")
        # Skip dtype processing if this is a dict, since it's presumably an
        # instance of tf/keras/mixed_precision/Policy rather than a single
        # dtype.
        # TODO(#5548): parse the policy dict and populate the dtype attr with
        # the variable dtype.
        if dtype_or_policy is not None and not isinstance(
                dtype_or_policy, dict):
            tf_dtype = dtypes.as_dtype(layer_config.get("dtype"))
            node_def.attr["dtype"].type = tf_dtype.as_datatype_enum
        if layer.get("inbound_nodes") is not None:
            for maybe_inbound_node in layer.get("inbound_nodes"):
                inbound_nodes = _norm_to_list_of_layers(maybe_inbound_node)
                for [name, size, index, _] in inbound_nodes:
                    inbound_name = _scoped_name(name_scope, name)
                    # An input to a layer can be the output of a model. In
                    # that case, the inbound node name is the model's name;
                    # remap it to the model's output layer. Since a model can
                    # have multiple outputs, make sure we pick the right
                    # output_layer from the model.
                    inbound_node_names = model_name_to_output.get(
                        inbound_name, [inbound_name])
                    node_def.input.append(inbound_node_names[index])
        elif prev_node_name is not None:
            node_def.input.append(prev_node_name)

        if node_name in input_to_layer:
            node_def.input.append(input_to_layer.get(node_name))

        prev_node_name = node_def.name

    return g
Example #26
def graph(model,
          args,
          verbose=False,
          operator_export_type='ONNX',
          omit_useless_nodes=True):
    """
    This method processes a PyTorch model and produces a `GraphDef` proto
    that can be logged to TensorBoard.

    Args:
      model (PyTorch module): The model to be parsed.
      args (tuple): input tensor[s] for the model.
      verbose (bool): Whether to print out verbose information while
        processing.
      operator_export_type (str): One of 'ONNX', 'ONNX_ATEN', or 'RAW'.
        Defaults to 'ONNX' because it produces the most visually
        understandable output.
      omit_useless_nodes (boolean): Whether to omit nodes that carry no
        useful information from the graph.
    """
    operator_export_type = getattr(OperatorExportTypes, operator_export_type)

    # This code is similar to torch/onnx/utils.py, but adjusted to provide
    # the most visually understandable output.
    #
    # For example, the line
    #
    #    # torch._C._jit_pass_onnx_peephole(graph)
    #
    # is commented out because that pass removes a lot of scope information.
    # The amount of optimization can be neither too much (lots of information
    # lost) nor too little (too much useless information), so the code is
    # copied here so that it will not be affected by torch/onnx/utils.py
    # changes.
    def _optimize_trace(trace, operator_export_type):
        trace.set_graph(_optimize_graph(trace.graph(), operator_export_type))

    def _optimize_graph(graph, operator_export_type):
        # torch._C._jit_pass_remove_inplace_ops(graph)
        # We now record some ops like ones/zeros into a trace where we
        # previously recorded constants; use constant propagation to maintain
        # our current level of ONNX support without implementing symbolics
        # for all of them.
        torch._C._jit_pass_constant_propagation(graph)
        torch.onnx.utils._split_tensor_list_constants(graph, graph)
        # run dce to eliminate dead parts of the graph that might have been
        # left behind by things like symbolic_override
        torch._C._jit_pass_dce(graph)
        torch._C._jit_pass_lint(graph)

        # torch._C._jit_pass_canonicalize_ops(graph)
        torch._C._jit_pass_lint(graph)

        torch._C._jit_pass_peephole(graph, True)
        torch._C._jit_pass_lint(graph)

        # onnx only supports tensors, but 1 / 2 = 0.5 and tensor(1) / tensor(2) = 0
        torch._C._jit_pass_prepare_division_for_onnx(graph)
        # onnx only supports tensors, so we turn all out number types into tensors
        torch._C._jit_pass_erase_number_types(graph)
        # onnx does not support tuples, so try to remove them
        torch._C._jit_pass_lower_all_tuples(graph)
        torch._C._jit_pass_peephole(graph, True)
        torch._C._jit_pass_lint(graph)

        if operator_export_type != OperatorExportTypes.RAW:
            graph = torch._C._jit_pass_onnx(graph, operator_export_type)
            torch._C._jit_pass_lint(graph)
            # torch._C._jit_pass_onnx_peephole(graph)
            torch._C._jit_pass_lint(graph)
        torch._C._jit_pass_dce(graph)
        torch._C._jit_pass_lint(graph)
        torch._C._jit_pass_fixup_onnx_loops(graph)
        torch._C._jit_pass_lint(graph)
        graph = torch._C._jit_pass_canonicalize(graph)
        torch._C._jit_pass_lint(graph)
        return graph

    with torch.onnx.set_training(model, False):
        try:
            trace, _ = torch.jit.get_trace_graph(model, args)
        except RuntimeError:
            print('Error occurred; no graph was saved.')
            _ = model(*args)  # don't catch, just print the error message
            print("Checking if it's onnx problem...")
            try:
                import tempfile
                torch.onnx.export(model,
                                  args,
                                  tempfile.TemporaryFile(),
                                  verbose=True)
            except RuntimeError:
                print("Your model fails onnx too, please report to onnx team")
            # Create an object matching
            # https://github.com/tensorflow/tensorboard/blob/master/tensorboard/compat/proto/graph.proto
            # The producer version has been reverse engineered from standard
            # TensorBoard logged data.
            return GraphDef(versions=VersionDef(producer=22))

    try:
        # An optimized graph helps debug at a higher level. Users can focus
        # on connections between big modules such as Linear instead of W, x,
        # bias, matmul, etc. Honestly, most users don't care about that
        # level of node detail.
        _optimize_trace(trace, operator_export_type)
    except RuntimeError as e:
        # Optimizing the trace might fail (due to bad scopes in some cases
        # we've seen) and we don't want graph visualization to fail in that
        # case, so log a warning and display the non-optimized graph.
        logging.warning(e)
    graph = trace.graph()
    if verbose:
        print(graph)
    list_of_nodes, node_stats = parse(graph, args, omit_useless_nodes)
    # We are hardcoding that this was run on CPU even though it might have actually
    # run on GPU. Note this is what is shown in TensorBoard and has no bearing
    # on actual execution.
    # TODO: See if we can extract GPU vs CPU information from the PyTorch model
    # and pass it correctly to TensorBoard.
    #
    # Definition of StepStats and DeviceStepStats can be found at
    # https://github.com/tensorflow/tensorboard/blob/master/tensorboard/plugins/graph/tf_graph_common/test/graph-test.ts
    # and
    # https://github.com/tensorflow/tensorboard/blob/master/tensorboard/compat/proto/step_stats.proto
    stepstats = RunMetadata(step_stats=StepStats(dev_stats=[
        DeviceStepStats(device="/device:CPU:0", node_stats=node_stats)
    ]))
    return GraphDef(node=list_of_nodes,
                    versions=VersionDef(producer=22)), stepstats
Example #27
    def test_combine_graph_defs_src_function_duplicate_keys(self):
        graph_def_a = GraphDef()
        text_format.Merge(
            '''
      library {
        function {
          signature {
            name: "foo"
            input_arg {
              name: "x"
              type: DT_HALF
            }
            output_arg {
              name: "identity"
              type: DT_HALF
            }
          }
          node_def {
            name: "add"
            op: "Add"
            input: "x"
            input: "y"
          }
        }
      }
    ''', graph_def_a)

        graph_def_b = GraphDef()
        text_format.Merge(
            '''
      library {
        function {
          signature {
            name: "bar"
            input_arg {
              name: "x"
              type: DT_HALF
            }
            output_arg {
              name: "identity"
              type: DT_HALF
            }
          }
        }
        function {
          signature {
            name: "bar"
            input_arg {
              name: "y"
              type: DT_HALF
            }
            output_arg {
              name: "identity"
              type: DT_HALF
            }
          }
        }
      }
    ''', graph_def_b)

        with six.assertRaisesRegex(
                self, ValueError,
                'A GraphDef contains non-unique function names: bar'):
            graph_util.combine_graph_defs(graph_def_a, graph_def_b)
Example #28
    def test_combine_graph_defs_function_collison(self):
        graph_def_a = GraphDef()
        text_format.Merge(
            '''
      library {
        function {
          signature {
            name: "foo"
            input_arg {
              name: "x"
              type: DT_HALF
            }
            output_arg {
              name: "identity"
              type: DT_HALF
            }
          }
          node_def {
            name: "add"
            op: "Add"
            input: "x"
            input: "y"
          }
        }
      }
    ''', graph_def_a)

        graph_def_b = GraphDef()
        text_format.Merge(
            '''
      library {
        function {
          signature {
            name: "foo"
            input_arg {
              name: "x"
              type: DT_HALF
            }
            output_arg {
              name: "identity"
              type: DT_HALF
            }
          }
          node_def {
            name: "div"
            op: "Div"
            input: "x"
            input: "y"
          }
        }
        function {
          signature {
            name: "foo_1"
            input_arg {
              name: "x"
              type: DT_HALF
            }
            output_arg {
              name: "identity"
              type: DT_HALF
            }
          }
          node_def {
            name: "add"
            op: "Add"
            input: "x"
            input: "y"
          }
        }
      }
    ''', graph_def_b)

        with six.assertRaisesRegex(
                self, ValueError,
            ('Cannot combine GraphDefs because functions share a name but '
             'are different: foo')):
            graph_util.combine_graph_defs(graph_def_a, graph_def_b)
Example #29
    def test_combine_graph_defs_function(self):
        expected_proto = '''
      library {
        function {
          signature {
            name: "foo"
            input_arg {
              name: "x"
              type: DT_HALF
            }
            output_arg {
              name: "identity"
              type: DT_HALF
            }
          }
          node_def {
            name: "add"
            op: "Add"
            input: "x"
            input: "y"
          }
        }
        function {
          signature {
            name: "foo_1"
            input_arg {
              name: "x"
              type: DT_HALF
            }
            output_arg {
              name: "identity"
              type: DT_HALF
            }
          }
          node_def {
            name: "add"
            op: "Add"
            input: "x"
            input: "y"
          }
        }
      }
    '''

        graph_def_a = GraphDef()
        text_format.Merge(
            '''
      library {
        function {
          signature {
            name: "foo"
            input_arg {
              name: "x"
              type: DT_HALF
            }
            output_arg {
              name: "identity"
              type: DT_HALF
            }
          }
          node_def {
            name: "add"
            op: "Add"
            input: "x"
            input: "y"
          }
        }
      }
    ''', graph_def_a)

        graph_def_b = GraphDef()
        text_format.Merge(
            '''
      library {
        function {
          signature {
            name: "foo"
            input_arg {
              name: "x"
              type: DT_HALF
            }
            output_arg {
              name: "identity"
              type: DT_HALF
            }
          }
          node_def {
            name: "add"
            op: "Add"
            input: "x"
            input: "y"
          }
        }
        function {
          signature {
            name: "foo_1"
            input_arg {
              name: "x"
              type: DT_HALF
            }
            output_arg {
              name: "identity"
              type: DT_HALF
            }
          }
          node_def {
            name: "add"
            op: "Add"
            input: "x"
            input: "y"
          }
        }
      }
    ''', graph_def_b)

        self.assertProtoEquals(
            expected_proto,
            graph_util.combine_graph_defs(graph_def_a, graph_def_b))
Example #30
def visualize(
    model_path: str,
    log_path: str,
    input: np.ndarray = None,
    inp_dict: dict = None,
    cal_params: bool = True,
    cal_flops: bool = True,
    cal_activations: bool = True,
    logging_to_stdout: bool = True,
    bar_length_max: int = 20,
):
    r"""
    Load megengine dumped model and visualize graph structure with tensorboard log files.
    Can also record and print model's statistics like :func:`~.module_stats`

    :param model_path: dir path for megengine dumped model.
    :param log_path: dir path for tensorboard graph log.
    :param input: user defined input data for running model and calculating stats, alternative with inp_dict, used when the model has only one input.
    :param inp_dict: input dict for running model and calculating stats, alternative with input, used when the model has more than one input. When both input and inp_dict are None, a random input will be used.
    :param cal_params: whether calculate and record params size.
    :param cal_flops: whether calculate and record op flops.
    :param cal_activations: whether calculate and record op activations.
    :param logging_to_stdout: whether print all calculated statistic details.
    :param bar_length_max: size of bar indicating max flops or parameter size in net stats.

    """
    if log_path:
        try:
            from tensorboard.compat.proto.attr_value_pb2 import AttrValue
            from tensorboard.compat.proto.config_pb2 import RunMetadata
            from tensorboard.compat.proto.graph_pb2 import GraphDef
            from tensorboard.compat.proto.node_def_pb2 import NodeDef
            from tensorboard.compat.proto.step_stats_pb2 import (
                AllocatorMemoryUsed,
                DeviceStepStats,
                NodeExecStats,
                StepStats,
            )
            from tensorboard.compat.proto.tensor_shape_pb2 import TensorShapeProto
            from tensorboard.compat.proto.versions_pb2 import VersionDef
            from tensorboardX import SummaryWriter
        except ImportError:
            logger.error(
                "TensorBoard and TensorboardX are required for visualize.",
                exc_info=True,
            )
            return

    enable_receptive_field()

    graph = Network.load(model_path)
    graph.reset_batch_size(1)

    has_input = False
    if input is not None or inp_dict is not None:
        has_input = True
        repl_dict = {}
        inp_vars = graph.input_vars
        if inp_dict is not None:
            assert len(inp_dict) == len(
                inp_vars
            ), "Inputs are not sufficient for calculation."
            for v in inp_vars:
                new_input = graph.make_const(inp_dict[v.name], name=v.name)
                repl_dict[v] = new_input
        else:
            assert len(inp_vars) == 1, "The graph has more than one input; use inp_dict."
            inp_var = inp_vars[0]
            repl_dict[inp_var] = graph.make_const(input, name=inp_var.name)
        graph.replace_vars(repl_dict=repl_dict)

    graph._compile()

    def process_name(name):
        # node names that start with a dot or look like float constants lead
        # to a display bug, so only rewrite "." to "/" for other names
        if not re.match(r"^[+-]?\d*\.\d*", name):
            name = name.replace(".", "/")
        return name.encode(encoding="utf-8")

    summary = [["item", "value"]]
    node_list = []
    flops_list = []
    params_list = []
    activations_list = []
    total_stats = namedtuple("total_stats", ["param_size", "flops", "act_size"])
    stats_details = namedtuple("module_stats", ["params", "flops", "activations"])

    for node in tqdm(graph.all_oprs):
        if hasattr(node, "output_idx"):
            node_oup = node.outputs[node.output_idx]
        else:
            if len(node.outputs) != 1:
                logger.warning(
                    "OpNode {} has more than one output and no "
                    "'output_idx' attr.".format(node)
                )
            node_oup = node.outputs[0]

        inp_list = [process_name(var.owner.name) for var in node.inputs]
        if log_path:
            # detail format see tensorboard/compat/proto/attr_value.proto
            attr = {
                "_output_shapes": AttrValue(
                    list=AttrValue.ListValue(
                        shape=[
                            TensorShapeProto(
                                dim=[
                                    TensorShapeProto.Dim(size=d) for d in node_oup.shape
                                ]
                            )
                        ]
                    )
                ),
                "params": AttrValue(s=str(node.params).encode(encoding="utf-8")),
                "dtype": AttrValue(s=str(node_oup.dtype).encode(encoding="utf-8")),
            }

        if cal_flops:
            flops_stats = get_op_stats(node, node.inputs, node.outputs)
            if flops_stats is not None:
                # add op flops attr
                if log_path and "flops" in flops_stats:
                    attr["flops"] = AttrValue(
                        s=sizeof_fmt(flops_stats["flops"]).encode(encoding="utf-8")
                    )
                flops_stats["name"] = node.name
                flops_stats["class_name"] = node.type
                flops_list.append(flops_stats)

        if cal_activations:
            acts = get_activation_stats(node_oup.numpy(), has_input=has_input)
            acts["name"] = node.name
            acts["class_name"] = node.type
            activations_list.append(acts)

        if cal_params:
            if node.type == "ImmutableTensor":
                param_stats = get_param_stats(node.numpy())
                # add tensor size attr
                if log_path:
                    attr["size"] = AttrValue(
                        s=sizeof_fmt(param_stats["size"]).encode(encoding="utf-8")
                    )
                param_stats["name"] = node.name
                params_list.append(param_stats)

        if log_path:
            node_list.append(
                NodeDef(
                    name=process_name(node.name),
                    op=node.type,
                    input=inp_list,
                    attr=attr,
                )
            )
    # summary
    extra_info = {
        "#ops": len(graph.all_oprs),
        "#params": len(params_list),
    }

    (
        total_flops,
        total_param_dims,
        total_param_size,
        total_act_dims,
        total_act_size,
    ) = (0, 0, 0, 0, 0)

    if cal_params:
        total_param_dims, total_param_size, params_list = sum_param_stats(
            params_list, bar_length_max
        )
        extra_info["total_param_dims"] = sizeof_fmt(total_param_dims, suffix="")
        extra_info["total_param_size"] = sizeof_fmt(total_param_size)
        if logging_to_stdout:
            print_param_stats(params_list)

    if cal_flops:
        total_flops, flops_list = sum_op_stats(flops_list, bar_length_max)
        extra_info["total_flops"] = sizeof_fmt(total_flops, suffix="OPs")
        if logging_to_stdout:
            print_op_stats(flops_list)

    if cal_activations:
        total_act_dims, total_act_size, activations_list = sum_activations_stats(
            activations_list, bar_length_max
        )
        extra_info["total_act_dims"] = sizeof_fmt(total_act_dims, suffix="")
        extra_info["total_act_size"] = sizeof_fmt(total_act_size)
        if logging_to_stdout:
            print_activations_stats(activations_list, has_input=has_input)

    if cal_flops and cal_params:
        extra_info["flops/param_size"] = "{:3.3f}".format(
            total_flops / total_param_size
        )

    if log_path:
        graph_def = GraphDef(node=node_list, versions=VersionDef(producer=22))

        device = "/device:CPU:0"
        stepstats = RunMetadata(
            step_stats=StepStats(dev_stats=[DeviceStepStats(device=device)])
        )
        writer = SummaryWriter(log_path)
        writer._get_file_writer().add_graph((graph_def, stepstats))

    print_summary(**extra_info)

    return (
        total_stats(
            param_size=total_param_size, flops=total_flops, act_size=total_act_size,
        ),
        stats_details(
            params=params_list, flops=flops_list, activations=activations_list
        ),
    )
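A hedged usage sketch (paths are placeholders; assumes a previously dumped MegEngine model):

    total, details = visualize(
        model_path='model.mge',  # placeholder for a dumped model file
        log_path='./tb_logs',
        logging_to_stdout=False,
    )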