コード例 #1
0
ファイル: tensorboard.py プロジェクト: vishalbelsare/CrypTen
def graph(model):
    """Convert a crypten.nn model into a TensorBoard ``GraphDef``.

    A module that is not already an ``nn.Graph`` is first wrapped in a
    one-node graph whose single input is named "input" and whose only module
    is registered as "output".

    Args:
        model: the ``nn.Module`` (or ``nn.Graph``) to convert.

    Returns:
        A ``GraphDef`` (``producer=22``) containing one "Variable" node per
        graph input, followed by one node per entry of ``model._graph``.

    Raises:
        AssertionError: if ``model`` is not an ``nn.Module``.
    """

    # convert individual module to graph:
    assert isinstance(model, nn.Module), "model must be crypten.nn.Module"
    if not isinstance(model, nn.Graph):
        wrapper = nn.Graph("input", "output")
        wrapper.add_module("output", model, ["input"])
        model = wrapper

    # create mapping to more interpretable node naming:
    mapping = {input_name: input_name for input_name in model.input_names}
    modules = dict(model.named_modules())
    for name, module in modules.items():
        # Take the class name directly. Slicing a fixed number of characters
        # off ``str(type(module))`` only yields the class name for classes
        # whose module path has exactly the expected length, and produces a
        # truncated or empty label for every other class.
        mapping[name] = "%s_%s" % (type(module).__name__, name)

    # create input variables:
    nodes = [
        NodeDef(
            name=mapping[input_name].encode(encoding="utf_8"),
            op="Variable",
            input=[],
        ) for input_name in model.input_names
    ]

    # loop all graph connections:
    for output_name, input_names in model._graph.items():

        # get parameters and type of module:
        module = modules[output_name]
        op = str(type(module))
        mapped_inputs = [mapping[name] for name in input_names]
        # One "name: size" entry per parameter, shown as a node attribute.
        parameter_string = "; ".join(
            "%s: %s" % (name, parameter.size())
            for name, parameter in module.named_parameters()
        ).encode(encoding="utf_8")

        # add to graph:
        nodes.append(
            NodeDef(
                name=mapping[output_name].encode(encoding="utf_8"),
                op=op,
                input=mapped_inputs,
                attr={"attr": AttrValue(s=parameter_string)},
            ))

    # return graph definition:
    return GraphDef(node=nodes, versions=VersionDef(producer=22))
コード例 #2
0
    def __init__(self, model, dummy_input, verbose=False):
        """Initialize the parent, then build the TensorBoard graph protos.

        After delegating to the parent initializer, ``self.trace`` is parsed
        into a node list (presumably the parent sets ``self.trace`` -- confirm
        against the base class). The results are stored on the instance:

        * ``self.stepstats``: a ``RunMetadata`` with one device entry,
          hard-coded to "/device:CPU:0".
        * ``self.graph_def``: a ``GraphDef`` holding the parsed nodes.

        Args:
            model: forwarded to the parent initializer.
            dummy_input: forwarded to the parent initializer and to
                ``self.parse``.
            verbose: when true, print ``self.trace.graph`` after parsing.
        """
        super().__init__(model, dummy_input)

        # Imported lazily so tensorboard is only required when this class is
        # actually instantiated.
        from tensorboard.compat.proto.config_pb2 import RunMetadata
        from tensorboard.compat.proto.graph_pb2 import GraphDef
        from tensorboard.compat.proto.step_stats_pb2 import StepStats, DeviceStepStats
        from tensorboard.compat.proto.versions_pb2 import VersionDef

        parsed_nodes = self.parse(self.trace.graph, self.trace, dummy_input)
        if verbose:
            print(self.trace.graph)

        cpu_entry = DeviceStepStats(device="/device:CPU:0")
        self.stepstats = RunMetadata(step_stats=StepStats(dev_stats=[cpu_entry]))

        producer_version = VersionDef(producer=22)
        self.graph_def = GraphDef(node=parsed_nodes, versions=producer_version)
コード例 #3
0
ファイル: _onnx_graph.py プロジェクト: yuguo68/pytorch
def parse(graph):
    """Translate a graph exposing ``input``/``output``/``node`` into a ``GraphDef``.

    Every entry of ``graph.input`` and ``graph.output`` becomes a "Variable"
    node carrying its element type and shape; every entry of ``graph.node``
    becomes a node named after its first output, with its attributes
    flattened into a single "parameters" string.

    Args:
        graph: object with ``input``, ``output`` and ``node`` sequences
            (presumably an ONNX ``GraphProto`` -- confirm against the caller).

    Returns:
        A ``GraphDef`` (``producer=22``) with the value nodes first, followed
        by the operator nodes.
    """
    import itertools

    nodes = []

    # Graph inputs and outputs are rendered as "Variable" nodes.
    for value in itertools.chain(graph.input, graph.output):
        # NOTE(review): this print is kept to preserve the existing stdout output.
        print(value.name)
        tensor_type = value.type.tensor_type
        shapeproto = TensorShapeProto(dim=[
            TensorShapeProto.Dim(size=d.dim_value)
            for d in tensor_type.shape.dim
        ])
        nodes.append(
            NodeDef(
                name=value.name.encode(encoding="utf_8"),
                op="Variable",
                input=[],
                attr={
                    "dtype": AttrValue(type=tensor_type.elem_type),
                    "shape": AttrValue(shape=shapeproto),
                },
            ))

    # Operator nodes: each attribute is rendered by joining the values of its
    # populated fields with " = ", and attributes are separated by ", ".
    for node in graph.node:
        rendered = [
            " = ".join(str(field[1]) for field in attribute.ListFields())
            for attribute in node.attribute
        ]
        attr = ", ".join(rendered).encode(encoding="utf_8")
        # NOTE(review): this print is kept to preserve the existing stdout output.
        print(node.output[0])
        nodes.append(
            NodeDef(
                name=node.output[0].encode(encoding="utf_8"),
                op=node.op_type,
                input=node.input,
                attr={"parameters": AttrValue(s=attr)},
            ))

    return GraphDef(node=nodes, versions=VersionDef(producer=22))
コード例 #4
0
ファイル: _pytorch_graph.py プロジェクト: xsacha/pytorch
def graph(model, args, verbose=False, use_strict_trace=True):
    """
    Trace a PyTorch model and build the protos TensorBoard needs to draw it.

    Args:
      model (PyTorch module): The model to be parsed.
      args (tuple): input tensor[s] for the model.
      verbose (bool): Whether to print the traced graph while processing.
      use_strict_trace (bool): Forwarded to `torch.jit.trace` as `strict`.
        Pass False when you want the tracer to record your mutable
        container types (list, dict).

    Returns:
      A `(GraphDef, RunMetadata)` tuple. The `RunMetadata` holds a single
      device entry that is always "/device:CPU:0".

    Raises:
      RuntimeError: re-raised, after being printed, when tracing or inlining
        fails.
    """
    with torch.onnx.select_model_mode_for_export(
            model,
            torch.onnx.TrainingMode.EVAL):  # TODO: move outside of torch.onnx?
        try:
            trace = torch.jit.trace(model, args, strict=use_strict_trace)
            traced_graph = trace.graph
            torch._C._jit_pass_inline(traced_graph)
        except RuntimeError as e:
            print(e)
            print('Error occurs, No graph saved')
            raise e

    if verbose:
        print(traced_graph)
    node_defs = parse(traced_graph, trace, args)

    # The device is hard-coded to CPU even if the model actually ran on GPU.
    # This only affects what TensorBoard displays; it has no bearing on
    # execution.
    # TODO: See if we can extract GPU vs CPU information from the PyTorch model
    # and pass it correctly to TensorBoard.
    #
    # Definition of StepStats and DeviceStepStats can be found at
    # https://github.com/tensorflow/tensorboard/blob/master/tensorboard/plugins/graph/tf_graph_common/test/graph-test.ts
    # and
    # https://github.com/tensorflow/tensorboard/blob/master/tensorboard/compat/proto/step_stats.proto
    cpu_entry = DeviceStepStats(device="/device:CPU:0")
    stepstats = RunMetadata(step_stats=StepStats(dev_stats=[cpu_entry]))
    graph_def = GraphDef(node=node_defs, versions=VersionDef(producer=22))
    return graph_def, stepstats
コード例 #5
0
ファイル: _pytorch_graph.py プロジェクト: zzprice/pytorch
def graph(model, args, verbose=False, operator_export_type='ONNX', omit_useless_nodes=True):
    """
    This method processes a PyTorch model and produces a `GraphDef` proto
    that can be logged to TensorBoard.

    Args:
      model (PyTorch module): The model to be parsed.
      args (tuple): input tensor[s] for the model.
      verbose (bool): Whether to print out verbose information while
        processing.
      operator_export_type (str): One of 'ONNX', 'ONNX_ATEN', or 'RAW'.
        Defaults to 'ONNX' format because it outputs the most visually
        understandable format.
      omit_useless_nodes (boolean): Whether to remove nodes from the graph.

    Returns:
      A `(GraphDef, RunMetadata)` tuple on success. If tracing raises
      `RuntimeError` and the follow-up eager call of the model does not
      raise, a bare empty `GraphDef` is returned instead (no tuple).
    """
    operator_export_type = getattr(OperatorExportTypes, operator_export_type)

    with torch.onnx.set_training(model, False):
        try:
            trace, _ = torch.jit.get_trace_graph(model, args)
        except RuntimeError:
            print('Error occurs, No graph saved')
            _ = model(*args)  # don't catch, just print the error message
            print("Checking if it's onnx problem...")
            try:
                import tempfile
                # Close the scratch file deterministically instead of leaving
                # the handle open until garbage collection.
                with tempfile.TemporaryFile() as scratch:
                    torch.onnx.export(model, args, scratch, verbose=True)
            except RuntimeError:
                print("Your model cannot be exported by onnx, please report to onnx team")
            # Create an object matching
            # https://github.com/tensorflow/tensorboard/blob/master/tensorboard/compat/proto/graph.proto
            # The producer version has been reverse engineered from standard
            # TensorBoard logged data.
            return GraphDef(versions=VersionDef(producer=22))

    try:
        # An optimized graph helps debug at a higher level. Users can focus
        # on connections between big modules such as Linear instead of W, x,
        # bias, matmul, etc. Honestly, most users don't care about those
        # detailed nodes information.
        _optimize_trace(trace, operator_export_type)
    except RuntimeError as e:
        # Optimize trace might fail (due to bad scopes in some cases we've seen)
        # and we don't want graph visualization to fail in this case. In this
        # case we'll log the warning and display the non-optimized graph.
        # `logging.warn` is a deprecated alias; `logging.warning` emits the
        # same message.
        logging.warning(e)
    graph = trace.graph()
    if verbose:
        print(graph)
    list_of_nodes, node_stats = parse(graph, args, omit_useless_nodes)
    # We are hardcoding that this was run on CPU even though it might have actually
    # run on GPU. Note this is what is shown in TensorBoard and has no bearing
    # on actual execution.
    # TODO: See if we can extract GPU vs CPU information from the PyTorch model
    # and pass it correctly to TensorBoard.
    #
    # Definition of StepStats and DeviceStepStats can be found at
    # https://github.com/tensorflow/tensorboard/blob/master/tensorboard/plugins/graph/tf_graph_common/test/graph-test.ts
    # and
    # https://github.com/tensorflow/tensorboard/blob/master/tensorboard/compat/proto/step_stats.proto
    stepstats = RunMetadata(step_stats=StepStats(dev_stats=[DeviceStepStats(device="/device:CPU:0",
                                                                            node_stats=node_stats)]))
    return GraphDef(node=list_of_nodes, versions=VersionDef(producer=22)), stepstats
コード例 #6
0
 def add_graph(self, model, *args, **kargs):
     """Build a graph for ``model`` via ``GraphVisitor`` and pass it to the file writer.

     All extra positional and keyword arguments are forwarded to
     ``GraphVisitor``. The accompanying ``RunMetadata`` has a single device
     entry, hard-coded to "/device:CPU:0".
     """
     visitor = GraphVisitor(model, *args, **kargs)

     device_entries = [DeviceStepStats(device="/device:CPU:0")]
     run_metadata = RunMetadata(step_stats=StepStats(dev_stats=device_entries))
     graph_def = GraphDef(node=visitor._graph, versions=VersionDef(producer=22))

     file_writer = self._get_file_writer()
     file_writer.add_graph((graph_def, run_metadata))
コード例 #7
0
def visualize(
    model_path: str,
    log_path: str,
    input: np.ndarray = None,
    inp_dict: dict = None,
    cal_params: bool = True,
    cal_flops: bool = True,
    cal_activations: bool = True,
    logging_to_stdout: bool = True,
    bar_length_max: int = 20,
):
    r"""
    Load megengine dumped model and visualize graph structure with tensorboard log files.
    Can also record and print model's statistics like :func:`~.module_stats`

    :param model_path: dir path for megengine dumped model.
    :param log_path: dir path for tensorboard graph log. When falsy, no
        TensorBoard log is written and only statistics are computed.
    :param input: user defined input data for running model and calculating stats,
        alternative to inp_dict, used when the model has only one input.
    :param inp_dict: input dict for running model and calculating stats, alternative
        to input, used when the model has more than one input. When both input and
        inp_dict are None, a random input will be used.
    :param cal_params: whether calculate and record params size.
    :param cal_flops: whether calculate and record op flops.
    :param cal_activations: whether calculate and record op activations.
    :param logging_to_stdout: whether print all calculated statistic details.
    :param bar_length_max: size of bar indicating max flops or parameter size in net stats.
    :return: a ``(total_stats, module_stats)`` pair of namedtuples:
        ``total_stats(param_size, flops, act_size)`` and
        ``module_stats(params, flops, activations)``. Returns ``None`` when
        ``log_path`` is set but the TensorBoard / TensorboardX imports fail.
    """
    if log_path:
        try:
            from tensorboard.compat.proto.attr_value_pb2 import AttrValue
            from tensorboard.compat.proto.config_pb2 import RunMetadata
            from tensorboard.compat.proto.graph_pb2 import GraphDef
            from tensorboard.compat.proto.node_def_pb2 import NodeDef
            from tensorboard.compat.proto.step_stats_pb2 import (
                AllocatorMemoryUsed,
                DeviceStepStats,
                NodeExecStats,
                StepStats,
            )
            from tensorboard.compat.proto.tensor_shape_pb2 import TensorShapeProto
            from tensorboard.compat.proto.versions_pb2 import VersionDef
            from tensorboardX import SummaryWriter
        except ImportError:
            logger.error(
                "TensorBoard and TensorboardX are required for visualize.",
                exc_info=True,
            )
            return

    enable_receptive_field()

    graph = Network.load(model_path)
    graph.reset_batch_size(1)

    # Replace the graph's input variables with constants built from the
    # user-supplied data, when any was given.
    has_input = False
    if input is not None or inp_dict is not None:
        has_input = True
        repl_dict = {}
        inp_vars = graph.input_vars
        if inp_dict is not None:
            assert len(inp_dict) == len(
                inp_vars
            ), "Inputs are not sufficient for calculation."
            for v in inp_vars:
                new_input = graph.make_const(inp_dict[v.name], name=v.name)
                repl_dict[v] = new_input
        else:
            assert len(inp_vars) == 1, "The graph needs more than one input."
            inp_var = inp_vars[0]
            repl_dict[inp_var] = graph.make_const(input, name=inp_var.name)
        graph.replace_vars(repl_dict=repl_dict)

    graph._compile()

    def process_name(name):
        # Names that start like a float literal (optional sign, digits, a dot)
        # keep their dots, since rewriting those leads to a display bug; in
        # every other name "." is replaced by "/".
        if not re.match(r"^[+-]?\d*\.\d*", name):
            name = name.replace(".", "/")
        return name.encode(encoding="utf-8")

    node_list = []
    flops_list = []
    params_list = []
    activations_list = []
    total_stats = namedtuple("total_stats", ["param_size", "flops", "act_size"])
    stats_details = namedtuple("module_stats", ["params", "flops", "activations"])

    for node in tqdm(graph.all_oprs):
        if hasattr(node, "output_idx"):
            node_oup = node.outputs[node.output_idx]
        else:
            if len(node.outputs) != 1:
                logger.warning(
                    "OpNode {} has more than one output and not has 'output_idx' attr.".format(
                        node
                    )
                )
            node_oup = node.outputs[0]

        inp_list = [process_name(var.owner.name) for var in node.inputs]
        # `attr` only exists when a log is requested; every later use of it
        # is guarded by `log_path` as well.
        if log_path:
            # detail format see tensorboard/compat/proto/attr_value.proto
            attr = {
                "_output_shapes": AttrValue(
                    list=AttrValue.ListValue(
                        shape=[
                            TensorShapeProto(
                                dim=[
                                    TensorShapeProto.Dim(size=d) for d in node_oup.shape
                                ]
                            )
                        ]
                    )
                ),
                "params": AttrValue(s=str(node.params).encode(encoding="utf-8")),
                "dtype": AttrValue(s=str(node_oup.dtype).encode(encoding="utf-8")),
            }

        if cal_flops:
            flops_stats = get_op_stats(node, node.inputs, node.outputs)
            if flops_stats is not None:
                # add op flops attr
                # NOTE(review): `flops_stats` is indexed like a mapping below, so
                # this attribute check may never be true -- confirm against
                # `get_op_stats` before changing it.
                if log_path and hasattr(flops_stats, "flops_num"):
                    attr["flops"] = AttrValue(
                        s=sizeof_fmt(flops_stats["flops"]).encode(encoding="utf-8")
                    )
                flops_stats["name"] = node.name
                flops_stats["class_name"] = node.type
                flops_list.append(flops_stats)

        if cal_activations:
            acts = get_activation_stats(node_oup.numpy(), has_input=has_input)
            acts["name"] = node.name
            acts["class_name"] = node.type
            activations_list.append(acts)

        if cal_params:
            if node.type == "ImmutableTensor":
                param_stats = get_param_stats(node.numpy())
                # add tensor size attr
                if log_path:
                    attr["size"] = AttrValue(
                        s=sizeof_fmt(param_stats["size"]).encode(encoding="utf-8")
                    )
                param_stats["name"] = node.name
                params_list.append(param_stats)

        if log_path:
            node_list.append(
                NodeDef(
                    name=process_name(node.name),
                    op=node.type,
                    input=inp_list,
                    attr=attr,
                )
            )
    # summary
    extra_info = {
        "#ops": len(graph.all_oprs),
        "#params": len(params_list),
    }

    (
        total_flops,
        total_param_dims,
        total_param_size,
        total_act_dims,
        total_act_size,
    ) = (0, 0, 0, 0, 0)

    if cal_params:
        total_param_dims, total_param_size, params_list = sum_param_stats(
            params_list, bar_length_max
        )
        extra_info["total_param_dims"] = sizeof_fmt(total_param_dims, suffix="")
        extra_info["total_param_size"] = sizeof_fmt(total_param_size)
        if logging_to_stdout:
            print_param_stats(params_list)

    if cal_flops:
        total_flops, flops_list = sum_op_stats(flops_list, bar_length_max)
        extra_info["total_flops"] = sizeof_fmt(total_flops, suffix="OPs")
        if logging_to_stdout:
            print_op_stats(flops_list)

    if cal_activations:
        total_act_dims, total_act_size, activations_list = sum_activations_stats(
            activations_list, bar_length_max
        )
        extra_info["total_act_dims"] = sizeof_fmt(total_act_dims, suffix="")
        extra_info["total_act_size"] = sizeof_fmt(total_act_size)
        if logging_to_stdout:
            print_activations_stats(activations_list, has_input=has_input)

    # The ratio is undefined when the total parameter size is zero; omit it
    # from the summary in that case instead of raising ZeroDivisionError.
    if cal_flops and cal_params and total_param_size:
        extra_info["flops/param_size"] = "{:3.3f}".format(
            total_flops / total_param_size
        )

    if log_path:
        graph_def = GraphDef(node=node_list, versions=VersionDef(producer=22))

        device = "/device:CPU:0"
        stepstats = RunMetadata(
            step_stats=StepStats(dev_stats=[DeviceStepStats(device=device)])
        )
        writer = SummaryWriter(log_path)
        try:
            writer._get_file_writer().add_graph((graph_def, stepstats))
        finally:
            # Flush and release the event file even if add_graph fails.
            writer.close()

    print_summary(**extra_info)

    return (
        total_stats(
            param_size=total_param_size, flops=total_flops, act_size=total_act_size,
        ),
        stats_details(
            params=params_list, flops=flops_list, activations=activations_list
        ),
    )
コード例 #8
0
ファイル: _pytorch_graph.py プロジェクト: xiang1563/pytorch
def graph(model,
          args,
          verbose=False,
          operator_export_type='ONNX',
          omit_useless_nodes=True):
    """
    This method processes a PyTorch model and produces a `GraphDef` proto
    that can be logged to TensorBoard.

    Args:
      model (PyTorch module): The model to be parsed.
      args (tuple): input tensor[s] for the model.
      verbose (bool): Whether to print out verbose information while
        processing.
      operator_export_type (str): One of 'ONNX', 'ONNX_ATEN', or 'RAW'.
        Defaults to 'ONNX' format because it outputs the most visually
        understandable format.
      omit_useless_nodes (boolean): Whether to remove nodes from the graph.

    Returns:
      A `(GraphDef, RunMetadata)` tuple on success. If tracing raises
      `RuntimeError` and the follow-up eager call of the model does not
      raise, a bare empty `GraphDef` is returned instead (no tuple).
    """
    operator_export_type = getattr(OperatorExportTypes, operator_export_type)

    # This code is similar to torch/onnx/utils.py, but adjusted to provide
    # the most visually understandable output.
    #
    # For example, the call
    #
    #    torch._C._jit_pass_onnx_peephole(graph)
    #
    # is left commented out below because that pass removes a lot of scope
    # information. The amount of optimization cannot be too much (lots of
    # information lost) or too little (too much useless information), so the
    # code was copied here to keep it unaffected by torch/onnx/utils.py
    # changes.
    def _optimize_trace(trace, operator_export_type):
        """Replace the trace's graph with its optimized version."""
        trace.set_graph(_optimize_graph(trace.graph(), operator_export_type))

    def _optimize_graph(graph, operator_export_type):
        """Run the fixed sequence of JIT passes and return the resulting graph."""
        # torch._C._jit_pass_remove_inplace_ops(graph)
        # Some ops like ones/zeros are now recorded into the trace where
        # constants were recorded previously; constant propagation keeps the
        # current level of onnx support without implementing symbolics for
        # all of them.
        torch._C._jit_pass_constant_propagation(graph)
        torch.onnx.utils._split_tensor_list_constants(graph, graph)
        # run dce to eliminate dead parts of the graph that might have been
        # left behind by things like symbolic_override
        torch._C._jit_pass_dce(graph)
        torch._C._jit_pass_lint(graph)

        # torch._C._jit_pass_canonicalize_ops(graph)
        torch._C._jit_pass_lint(graph)

        torch._C._jit_pass_peephole(graph, True)
        torch._C._jit_pass_lint(graph)

        # onnx only supports tensors, but 1 / 2 = 0.5 and tensor(1) / tensor(2) = 0
        torch._C._jit_pass_prepare_division_for_onnx(graph)
        # onnx only supports tensors, so we turn all out number types into tensors
        torch._C._jit_pass_erase_number_types(graph)
        # onnx does not support tuples, so try to remove them
        torch._C._jit_pass_lower_all_tuples(graph)
        torch._C._jit_pass_peephole(graph, True)
        torch._C._jit_pass_lint(graph)

        if operator_export_type != OperatorExportTypes.RAW:
            graph = torch._C._jit_pass_onnx(graph, operator_export_type)
            torch._C._jit_pass_lint(graph)
            # torch._C._jit_pass_onnx_peephole(graph)
            torch._C._jit_pass_lint(graph)
        torch._C._jit_pass_dce(graph)
        torch._C._jit_pass_lint(graph)
        torch._C._jit_pass_fixup_onnx_loops(graph)
        torch._C._jit_pass_lint(graph)
        graph = torch._C._jit_pass_canonicalize(graph)
        torch._C._jit_pass_lint(graph)
        return graph

    with torch.onnx.set_training(model, False):
        try:
            trace, _ = torch.jit.get_trace_graph(model, args)
        except RuntimeError:
            print('Error occurs, No graph saved')
            _ = model(*args)  # don't catch, just print the error message
            print("Checking if it's onnx problem...")
            try:
                import tempfile
                # Close the scratch file deterministically instead of leaving
                # the handle open until garbage collection.
                with tempfile.TemporaryFile() as scratch:
                    torch.onnx.export(model,
                                      args,
                                      scratch,
                                      verbose=True)
            except RuntimeError:
                print("Your model fails onnx too, please report to onnx team")
            # Create an object matching
            # https://github.com/tensorflow/tensorboard/blob/master/tensorboard/compat/proto/graph.proto
            # The producer version has been reverse engineered from standard
            # TensorBoard logged data.
            return GraphDef(versions=VersionDef(producer=22))

    try:
        # An optimized graph helps debug at a higher level. Users can focus
        # on connections between big modules such as Linear instead of W, x,
        # bias, matmul, etc. Honestly, most users don't care about those
        # detailed nodes information.
        _optimize_trace(trace, operator_export_type)
    except RuntimeError as e:
        # Optimize trace might fail (due to bad scopes in some cases we've seen)
        # and we don't want graph visualization to fail in this case. In this
        # case we'll log the warning and display the non-optimized graph.
        # `logging.warn` is a deprecated alias; `logging.warning` emits the
        # same message.
        logging.warning(e)
    graph = trace.graph()
    if verbose:
        print(graph)
    list_of_nodes, node_stats = parse(graph, args, omit_useless_nodes)
    # We are hardcoding that this was run on CPU even though it might have actually
    # run on GPU. Note this is what is shown in TensorBoard and has no bearing
    # on actual execution.
    # TODO: See if we can extract GPU vs CPU information from the PyTorch model
    # and pass it correctly to TensorBoard.
    #
    # Definition of StepStats and DeviceStepStats can be found at
    # https://github.com/tensorflow/tensorboard/blob/master/tensorboard/plugins/graph/tf_graph_common/test/graph-test.ts
    # and
    # https://github.com/tensorflow/tensorboard/blob/master/tensorboard/compat/proto/step_stats.proto
    stepstats = RunMetadata(step_stats=StepStats(dev_stats=[
        DeviceStepStats(device="/device:CPU:0", node_stats=node_stats)
    ]))
    return GraphDef(node=list_of_nodes,
                    versions=VersionDef(producer=22)), stepstats
コード例 #9
0
def visualize(
    model_path: str,
    log_path: str,
    bar_length_max: int = 20,
    log_params: bool = True,
    log_flops: bool = True,
):
    r"""
    Load megengine dumped model and visualize graph structure with tensorboard log files.
    Can also record and print model's statistics like :func:`~.module_stats`

    :param model_path: dir path for megengine dumped model.
    :param log_path: dir path for tensorboard graph log. When falsy, no
        TensorBoard log is written and only statistics are computed.
    :param bar_length_max: size of bar indicating max flops or parameter size in net stats.
    :param log_params: whether print and record params size.
    :param log_flops: whether print and record op flops.
    :return: ``(total_param_size, total_flops)``; each is 0 when the matching
        ``log_*`` flag is off. Returns ``None`` when ``log_path`` is set but
        the TensorBoard / TensorboardX imports fail.
    """
    if log_path:
        try:
            from tensorboard.compat.proto.attr_value_pb2 import AttrValue
            from tensorboard.compat.proto.config_pb2 import RunMetadata
            from tensorboard.compat.proto.graph_pb2 import GraphDef
            from tensorboard.compat.proto.node_def_pb2 import NodeDef
            from tensorboard.compat.proto.step_stats_pb2 import (
                AllocatorMemoryUsed,
                DeviceStepStats,
                NodeExecStats,
                StepStats,
            )
            from tensorboard.compat.proto.tensor_shape_pb2 import TensorShapeProto
            from tensorboard.compat.proto.versions_pb2 import VersionDef
            from tensorboardX import SummaryWriter
        except ImportError:
            logger.error(
                "TensorBoard and TensorboardX are required for visualize.",
                exc_info=True,
            )
            return
    # FIXME: remove this after resolving "span dist too large" warning
    old_level = set_mgb_log_level(logging.ERROR)

    # Everything below runs under try/finally so that the log level changed
    # above is restored even when loading or processing the model raises.
    try:
        enable_receptive_field()

        graph = Network.load(model_path)

        def process_name(name):
            # Names that start like a float literal (optional sign, digits, a
            # dot) keep their dots, since rewriting those leads to a display
            # bug; in every other name "." is replaced by "/".
            if not re.match(r"^[+-]?\d*\.\d*", name):
                name = name.replace(".", "/")
            return name.encode(encoding="utf-8")

        node_list = []
        flops_list = []
        params_list = []
        for node in graph.all_oprs:
            if hasattr(node, "output_idx"):
                node_oup = node.outputs[node.output_idx]
            else:
                if len(node.outputs) != 1:
                    logger.warning(
                        "OpNode {} has more than one output and not has 'output_idx' attr."
                        .format(node))
                node_oup = node.outputs[0]

            inp_list = [process_name(var.owner.name) for var in node.inputs]
            # `attr` only exists when a log is requested; every later use of
            # it is guarded by `log_path` as well.
            if log_path:
                # detail format see tensorboard/compat/proto/attr_value.proto
                attr = {
                    "_output_shapes":
                    AttrValue(list=AttrValue.ListValue(shape=[
                        TensorShapeProto(dim=[
                            TensorShapeProto.Dim(size=d) for d in node_oup.shape
                        ])
                    ])),
                    "params":
                    AttrValue(s=str(node.params).encode(encoding="utf-8")),
                    "dtype":
                    AttrValue(s=str(node_oup.dtype).encode(encoding="utf-8")),
                }
            flops_stats = get_op_stats(node, node.inputs, node.outputs)
            if flops_stats is not None:
                # add op flops attr
                # NOTE(review): `flops_stats` is indexed like a mapping below,
                # so this attribute check may never be true -- confirm against
                # `get_op_stats` before changing it.
                if log_path and hasattr(flops_stats, "flops_num"):
                    attr["flops"] = AttrValue(
                        s=sizeof_fmt(flops_stats["flops"]).encode(
                            encoding="utf-8"))
                flops_stats["name"] = node.name
                flops_stats["class_name"] = node.type
                flops_list.append(flops_stats)

            if node.type == "ImmutableTensor":
                param_stats = get_param_stats(node.numpy())
                # add tensor size attr
                if log_path:
                    attr["size"] = AttrValue(
                        s=sizeof_fmt(param_stats["size"]).encode(encoding="utf-8"))
                param_stats["name"] = node.name
                params_list.append(param_stats)

            if log_path:
                node_list.append(
                    NodeDef(
                        name=process_name(node.name),
                        op=node.type,
                        input=inp_list,
                        attr=attr,
                    ))
        # summary
        extra_info = {
            "#ops": len(graph.all_oprs),
            "#params": len(params_list),
        }

        total_flops, total_param_dims, total_param_size = 0, 0, 0
        if log_params:
            total_param_dims, total_param_size = print_param_stats(
                params_list, bar_length_max)
            extra_info["total_param_dims"] = sizeof_fmt(total_param_dims)
            extra_info["total_param_size"] = sizeof_fmt(total_param_size)
        if log_flops:
            total_flops = print_op_stats(flops_list, bar_length_max)
            extra_info["total_flops"] = sizeof_fmt(total_flops, suffix="OPs")
        # The ratio is undefined when the total parameter size is zero; omit
        # it from the summary instead of raising ZeroDivisionError.
        if log_params and log_flops and total_param_size:
            extra_info["flops/param_size"] = "{:3.3f}".format(total_flops /
                                                              total_param_size)

        if log_path:
            graph_def = GraphDef(node=node_list, versions=VersionDef(producer=22))

            device = "/device:CPU:0"
            stepstats = RunMetadata(step_stats=StepStats(
                dev_stats=[DeviceStepStats(device=device)]))
            writer = SummaryWriter(log_path)
            try:
                writer._get_file_writer().add_graph((graph_def, stepstats))
            finally:
                # Flush and release the event file even if add_graph fails.
                writer.close()

        print_summary(**extra_info)
    finally:
        # FIXME: remove this after resolving "span dist too large" warning
        _imperative_rt_logger.set_log_level(old_level)

    return total_param_size, total_flops