コード例 #1
0
def _operators_to_graph_def(shapes,
                            ops,
                            colon_replacement='$',
                            with_ssa=True,
                            with_gradient_scope=True,
                            blob_name_tracker=None,
                            show_simplified=False,
                            custom_rename=None):
    '''
    Main function to convert set of operators to a graph.

    Args:
        shapes: Dictionary mapping blob names to their shapes/dimensions.
        ops: List of Caffe2 operators, representing some computation graph
        ### **kwargs (model_to_graph_def, nets_to_graph_def, protos_to_graph_def) ###
        colon_replacement: Symbol to replace ':' with. ':i' in TF has a special
            meaning, so we need to replace it with a non-conflicting symbol.
            A falsy value skips the replacement entirely.
        with_ssa: Boolean; when true, blob names are rewritten by
            `_convert_to_ssa`.
        with_gradient_scope: Boolean; when true, `_add_gradient_scope` is
            applied to the blob names.
        blob_name_tracker: Dictionary tracking names of blobs (inputs/outputs
            from operators). If provided, it is cleared and refilled in place
            so the caller can inspect it afterwards; otherwise a private dict
            is used.
        show_simplified: Whether to show a simplified version of the model graph
            Sets all of the following values:
                clear_debug_info: Boolean representing whether to silence debug
                    info (which can be very verbose)
                show_forward_only: Boolean representing whether to only show
                    blobs involved in the forward pass
                show_cpu_only: Boolean representing whether to only show blobs
                    that are not associated with a gpu
                use_tensorflow_naming: Boolean representing whether to convert
                    some common Caffe2 naming conventions to their Tensorflow
                    counterparts
            In simplified mode only the input blobs (not every blob) are
            emitted as blob nodes.
        custom_rename: Function string -> string that defines a custom
            renaming function to use.

    Returns:
        current_graph: GraphDef representing the computation graph formed by the
            set of operators. Operator nodes come first, in operator order,
            followed by one node per blob in sorted blob-name order.
    '''
    # Reuse the caller's tracker object (so they can read it afterwards) but
    # never let stale entries from a previous call leak into this one.
    if blob_name_tracker is not None:
        blob_name_tracker.clear()
    else:
        blob_name_tracker = {}

    blob_name_tracker.update(_get_blob_names(ops))

    _clear_debug_info(ops, show_simplified)  # clear_debug_info
    ops = _filter_ops(ops, _check_if_forward,
                      show_simplified)  # show_forward_only
    ops = _filter_ops(ops, _check_if_cpu, show_simplified)  # show_cpu_only
    if custom_rename:
        _rename_all(shapes, blob_name_tracker, ops, custom_rename)
    if colon_replacement:
        _replace_colons(shapes, blob_name_tracker, ops, colon_replacement)
    if with_ssa:
        _convert_to_ssa(shapes, blob_name_tracker, ops)
    if with_gradient_scope:
        _add_gradient_scope(shapes, blob_name_tracker, ops)
    _fill_missing_operator_names(ops)
    if show_simplified:  # use_tensorflow_naming
        _rename_tensorflow_style(shapes, blob_name_tracker, ops)

    # Maps each output blob name to the (op, output_index) pairs producing it.
    producing_ops = {}
    blobs = set()
    input_blobs, inter_blobs, _ = _compute_in_out(ops)
    current_graph = GraphDef()
    seen = set(input_blobs)
    for op in ops:
        if show_simplified:
            nodes_from_op = _operator_to_node_simp(op, inter_blobs, seen)
        else:
            # .extend() expects an iterable
            nodes_from_op = [_operator_to_node(shapes, op)]
        current_graph.node.extend(nodes_from_op)
        blobs.update(op.input)
        for i, output_blob in enumerate(op.output):
            blobs.add(output_blob)
            producing_ops.setdefault(output_blob, []).append((op, i))

    if show_simplified:
        # Show a cleaner, easier-to-interpret version of the model graph
        blobs = input_blobs

    # Sort so the emitted node order does not depend on set iteration order,
    # which for strings varies between interpreter runs (hash randomization).
    # NOTE(review): an empty dict is passed as `shapes` here, so blob nodes do
    # not receive shape information from this call -- confirm this is intended.
    for blob in sorted(blobs):
        current_graph.node.extend([_blob_to_node(producing_ops, {}, blob)])

    return current_graph
コード例 #2
0
ファイル: keras_util.py プロジェクト: yatbear/tensorboard
def keras_model_to_graph_def(keras_layer):
    """Returns a GraphDef representation of the Keras model in a dict form.

    Note that it only supports models that implemented to_json().

    Args:
      keras_layer: A dict from Keras model.to_json().

    Returns:
      A GraphDef representation of the layers in the model. Each non-model
      layer becomes one NodeDef named by its scoped layer name. Nested models
      do not get a node of their own; they only update the bookkeeping (via
      `_update_dicts`) used to connect inputs/outputs across the model
      boundary.
    """
    input_to_layer = {}
    model_name_to_output = {}
    g = GraphDef()

    # Sequential model layers do not have a field "inbound_nodes" but
    # instead are defined implicitly via order of layers.
    prev_node_name = None

    for (name_scope, layer) in _walk_layers(keras_layer):
        if _is_model(layer):
            (
                input_to_layer,
                model_name_to_output,
                prev_node_name,
            ) = _update_dicts(
                name_scope,
                layer,
                input_to_layer,
                model_name_to_output,
                prev_node_name,
            )
            continue

        layer_config = layer.get("config")
        node_name = _scoped_name(name_scope, layer_config.get("name"))

        node_def = g.node.add()
        node_def.name = node_name

        class_name = layer.get("class_name")
        if class_name is not None:
            node_def.attr["keras_class"].s = class_name.encode("ascii")

        dtype_or_policy = layer_config.get("dtype")
        # Skip dtype processing if this is a dict, since it's presumably an instance of
        # tf/keras/mixed_precision/Policy rather than a single dtype.
        # TODO(#5548): parse the policy dict and populate the dtype attr with the variable dtype.
        if dtype_or_policy is not None and not isinstance(
                dtype_or_policy, dict):
            tf_dtype = dtypes.as_dtype(dtype_or_policy)
            node_def.attr["dtype"].type = tf_dtype.as_datatype_enum

        inbound_nodes_field = layer.get("inbound_nodes")
        if inbound_nodes_field is not None:
            for maybe_inbound_node in inbound_nodes_field:
                inbound_nodes = _norm_to_list_of_layers(maybe_inbound_node)
                # Only the layer name and the output index are used here.
                for [name, _, index, _] in inbound_nodes:
                    inbound_name = _scoped_name(name_scope, name)
                    # An input to a layer can be output from a model. In that case, the name
                    # of inbound_nodes to a layer is a name of a model. Remap the name of the
                    # model to output layer of the model. Also, since there can be multiple
                    # outputs in a model, make sure we pick the right output_layer from the model.
                    inbound_node_names = model_name_to_output.get(
                        inbound_name, [inbound_name])
                    # `index` can exceed the number of names we know about:
                    # several inbound entries may reference the same upstream
                    # layer, and a plain (non-model) layer only has the
                    # single-element default list above. Fall back to the last
                    # known name instead of raising IndexError.
                    if index < len(inbound_node_names):
                        input_name = inbound_node_names[index]
                    else:
                        input_name = inbound_node_names[-1]
                    node_def.input.append(input_name)
        elif prev_node_name is not None:
            # Sequential models: connect to the previously emitted layer.
            node_def.input.append(prev_node_name)

        if node_name in input_to_layer:
            node_def.input.append(input_to_layer.get(node_name))

        prev_node_name = node_def.name

    return g
コード例 #3
0
    def test_combine_graph_defs_function_collison(self):
        """Same-named functions with different bodies cannot be combined."""
        def parse_graph(text):
            parsed = GraphDef()
            text_format.Merge(text, parsed)
            return parsed

        # "foo" is implemented with Add here ...
        graph_with_add = parse_graph('''
      library {
        function {
          signature {
            name: "foo"
            input_arg {
              name: "x"
              type: DT_HALF
            }
            output_arg {
              name: "identity"
              type: DT_HALF
            }
          }
          node_def {
            name: "add"
            op: "Add"
            input: "x"
            input: "y"
          }
        }
      }
    ''')

        # ... but with Div here, so the two definitions conflict.
        graph_with_div = parse_graph('''
      library {
        function {
          signature {
            name: "foo"
            input_arg {
              name: "x"
              type: DT_HALF
            }
            output_arg {
              name: "identity"
              type: DT_HALF
            }
          }
          node_def {
            name: "div"
            op: "Div"
            input: "x"
            input: "y"
          }
        }
        function {
          signature {
            name: "foo_1"
            input_arg {
              name: "x"
              type: DT_HALF
            }
            output_arg {
              name: "identity"
              type: DT_HALF
            }
          }
          node_def {
            name: "add"
            op: "Add"
            input: "x"
            input: "y"
          }
        }
      }
    ''')

        expected_error = ('Cannot combine GraphDefs because functions share a '
                          'name but are different: foo')
        with six.assertRaisesRegex(self, ValueError, expected_error):
            graph_util.combine_graph_defs(graph_with_add, graph_with_div)
コード例 #4
0
    def test_combine_graph_defs_src_function_duplicate_keys(self):
        """A source graph that defines "bar" twice is rejected."""
        def parse_graph(text):
            parsed = GraphDef()
            text_format.Merge(text, parsed)
            return parsed

        clean_graph = parse_graph('''
      library {
        function {
          signature {
            name: "foo"
            input_arg {
              name: "x"
              type: DT_HALF
            }
            output_arg {
              name: "identity"
              type: DT_HALF
            }
          }
          node_def {
            name: "add"
            op: "Add"
            input: "x"
            input: "y"
          }
        }
      }
    ''')

        # Two library functions share the name "bar" within one GraphDef.
        duplicated_graph = parse_graph('''
      library {
        function {
          signature {
            name: "bar"
            input_arg {
              name: "x"
              type: DT_HALF
            }
            output_arg {
              name: "identity"
              type: DT_HALF
            }
          }
        }
        function {
          signature {
            name: "bar"
            input_arg {
              name: "y"
              type: DT_HALF
            }
            output_arg {
              name: "identity"
              type: DT_HALF
            }
          }
        }
      }
    ''')

        expected_error = 'A GraphDef contains non-unique function names: bar'
        with six.assertRaisesRegex(self, ValueError, expected_error):
            graph_util.combine_graph_defs(clean_graph, duplicated_graph)
コード例 #5
0
    def test_combine_graph_defs(self):
        """Disjoint node sets are concatenated in order."""
        def parse_graph(text):
            parsed = GraphDef()
            text_format.Merge(text, parsed)
            return parsed

        xwy_graph = parse_graph('''
      node {
        name: "X"
        op: "Input"
      }
      node {
        name: "W"
        op: "Input"
      }
      node {
        name: "Y"
        op: "MatMul"
        input: "X"
        input: "W"
      }
      versions {
        producer: 21
      }
    ''')

        abc_graph = parse_graph('''
      node {
        name: "A"
        op: "Input"
      }
      node {
        name: "B"
        op: "Input"
      }
      node {
        name: "C"
        op: "MatMul"
        input: "A"
        input: "B"
      }
      versions {
        producer: 21
      }
    ''')

        expected_text = '''
      node {
        name: "X"
        op: "Input"
      }
      node {
        name: "W"
        op: "Input"
      }
      node {
        name: "Y"
        op: "MatMul"
        input: "X"
        input: "W"
      }
      node {
        name: "A"
        op: "Input"
      }
      node {
        name: "B"
        op: "Input"
      }
      node {
        name: "C"
        op: "MatMul"
        input: "A"
        input: "B"
      }
      versions {
        producer: 21
      }
    '''

        combined = graph_util.combine_graph_defs(xwy_graph, abc_graph)
        self.assertProtoEquals(expected_text, combined)
コード例 #6
0
    def test_combine_graph_defs_function(self):
        """Identical same-named functions are deduplicated when combining."""
        def parse_graph(text):
            parsed = GraphDef()
            text_format.Merge(text, parsed)
            return parsed

        foo_only = parse_graph('''
      library {
        function {
          signature {
            name: "foo"
            input_arg {
              name: "x"
              type: DT_HALF
            }
            output_arg {
              name: "identity"
              type: DT_HALF
            }
          }
          node_def {
            name: "add"
            op: "Add"
            input: "x"
            input: "y"
          }
        }
      }
    ''')

        foo_and_foo_1 = parse_graph('''
      library {
        function {
          signature {
            name: "foo"
            input_arg {
              name: "x"
              type: DT_HALF
            }
            output_arg {
              name: "identity"
              type: DT_HALF
            }
          }
          node_def {
            name: "add"
            op: "Add"
            input: "x"
            input: "y"
          }
        }
        function {
          signature {
            name: "foo_1"
            input_arg {
              name: "x"
              type: DT_HALF
            }
            output_arg {
              name: "identity"
              type: DT_HALF
            }
          }
          node_def {
            name: "add"
            op: "Add"
            input: "x"
            input: "y"
          }
        }
      }
    ''')

        expected_text = '''
      library {
        function {
          signature {
            name: "foo"
            input_arg {
              name: "x"
              type: DT_HALF
            }
            output_arg {
              name: "identity"
              type: DT_HALF
            }
          }
          node_def {
            name: "add"
            op: "Add"
            input: "x"
            input: "y"
          }
        }
        function {
          signature {
            name: "foo_1"
            input_arg {
              name: "x"
              type: DT_HALF
            }
            output_arg {
              name: "identity"
              type: DT_HALF
            }
          }
          node_def {
            name: "add"
            op: "Add"
            input: "x"
            input: "y"
          }
        }
      }
    '''

        combined = graph_util.combine_graph_defs(foo_only, foo_and_foo_1)
        self.assertProtoEquals(expected_text, combined)
コード例 #7
0
ファイル: _pytorch_graph.py プロジェクト: xiang1563/pytorch
def graph(model,
          args,
          verbose=False,
          operator_export_type='ONNX',
          omit_useless_nodes=True):
    """
    This method processes a PyTorch model and produces a `GraphDef` proto
    that can be logged to TensorBoard.

    Args:
      model (PyTorch module): The model to be parsed.
      args (tuple): input tensor[s] for the model.
      verbose (bool): Whether to print out verbose information while
        processing.
      operator_export_type (str): One of 'ONNX', 'ONNX_ATEN', or 'RAW'.
        Defaults to 'ONNX' format  because it outputs the most visually
        understandable format.
      omit_useless_nodes (boolean): Forwarded to `parse`; presumably controls
        whether nodes deemed useless are dropped from the graph -- confirm in
        `parse`.

    Returns:
      A ``(GraphDef, RunMetadata)`` tuple. If tracing the model fails, the
      GraphDef has no nodes and the RunMetadata is empty, so callers can
      always unpack two values.
    """
    operator_export_type = getattr(OperatorExportTypes, operator_export_type)

    # This code is similar to torch/onnx/utils.py, but adjusted to provide
    # the most visually understandable output.
    #
    # For example, the commented out line
    #
    #    # torch._C._jit_pass_onnx_peephole(graph).
    #
    # This pass removes a lot of scope information. The amount of optimization
    # cannot be too much (lots of information lost) or too little (too much
    # useless information), therefore I copy-pasted the code so that it will
    # not be affected by torch/onnx/utils.py changes.
    def _optimize_trace(trace, operator_export_type):
        trace.set_graph(_optimize_graph(trace.graph(), operator_export_type))

    def _optimize_graph(graph, operator_export_type):
        # torch._C._jit_pass_remove_inplace_ops(graph)
        # we record now record some ops like ones/zeros
        # into a trace where we previously recorded constants
        # use constant prop to maintain our current level of onnx support
        # without implementing symbolics for all of them
        torch._C._jit_pass_constant_propagation(graph)
        torch.onnx.utils._split_tensor_list_constants(graph, graph)
        # run dce to eliminate dead parts of the graph that might have been
        # left behind by things like symbolic_override
        torch._C._jit_pass_dce(graph)
        torch._C._jit_pass_lint(graph)

        # torch._C._jit_pass_canonicalize_ops(graph)
        torch._C._jit_pass_lint(graph)

        torch._C._jit_pass_peephole(graph, True)
        torch._C._jit_pass_lint(graph)

        # onnx only supports tensors, but 1 / 2 = 0.5 and tensor(1) / tensor(2) = 0
        torch._C._jit_pass_prepare_division_for_onnx(graph)
        # onnx only supports tensors, so we turn all out number types into tensors
        torch._C._jit_pass_erase_number_types(graph)
        # onnx does not support tuples, so try to remove them
        torch._C._jit_pass_lower_all_tuples(graph)
        torch._C._jit_pass_peephole(graph, True)
        torch._C._jit_pass_lint(graph)

        if operator_export_type != OperatorExportTypes.RAW:
            graph = torch._C._jit_pass_onnx(graph, operator_export_type)
            torch._C._jit_pass_lint(graph)
            # torch._C._jit_pass_onnx_peephole(graph)
            torch._C._jit_pass_lint(graph)
        torch._C._jit_pass_dce(graph)
        torch._C._jit_pass_lint(graph)
        torch._C._jit_pass_fixup_onnx_loops(graph)
        torch._C._jit_pass_lint(graph)
        graph = torch._C._jit_pass_canonicalize(graph)
        torch._C._jit_pass_lint(graph)
        return graph

    with torch.onnx.set_training(model, False):
        try:
            trace, _ = torch.jit.get_trace_graph(model, args)
        except RuntimeError:
            print('Error occurs, No graph saved')
            _ = model(*args)  # don't catch, just print the error message
            print("Checking if it's onnx problem...")
            try:
                import tempfile
                # Close the scratch file even if the export fails.
                with tempfile.TemporaryFile() as scratch_file:
                    torch.onnx.export(model,
                                      args,
                                      scratch_file,
                                      verbose=True)
            except RuntimeError:
                print("Your model fails onnx too, please report to onnx team")
            # Create an object matching
            # https://github.com/tensorflow/tensorboard/blob/master/tensorboard/compat/proto/graph.proto
            # The producer version has been reverse engineered from standard
            # TensorBoard logged data.
            # Return the same (GraphDef, RunMetadata) shape as the success path
            # so callers that unpack or index the result do not break.
            return GraphDef(versions=VersionDef(producer=22)), RunMetadata()

    try:
        # An optimized graph helps debug at a higher level. Users can focus
        # on connections between big modules such as Linear instead of W, x,
        # bias, matmul, etc. Honestly, most users don't care about those
        # detailed nodes information.
        _optimize_trace(trace, operator_export_type)
    except RuntimeError as e:
        # Optimize trace might fail (due to bad scopes in some cases we've seen)
        # and we don't want graph visualization to fail in this case. In this
        # case we'll log the warning and display the non-optimized graph.
        # `logging.warn` is deprecated (removed in Python 3.13); the logged
        # text is still str(e).
        logging.warning(e)
    graph = trace.graph()
    if verbose:
        print(graph)
    list_of_nodes, node_stats = parse(graph, args, omit_useless_nodes)
    # We are hardcoding that this was run on CPU even though it might have actually
    # run on GPU. Note this is what is shown in TensorBoard and has no bearing
    # on actual execution.
    # TODO: See if we can extract GPU vs CPU information from the PyTorch model
    # and pass it correctly to TensorBoard.
    #
    # Definition of StepStats and DeviceStepStats can be found at
    # https://github.com/tensorflow/tensorboard/blob/master/tensorboard/plugins/graph/tf_graph_common/test/graph-test.ts
    # and
    # https://github.com/tensorflow/tensorboard/blob/master/tensorboard/compat/proto/step_stats.proto
    stepstats = RunMetadata(step_stats=StepStats(dev_stats=[
        DeviceStepStats(device="/device:CPU:0", node_stats=node_stats)
    ]))
    return GraphDef(node=list_of_nodes,
                    versions=VersionDef(producer=22)), stepstats
コード例 #8
0
    def test_combine_graph_defs_name_collided_but_same_content(self):
        """A node present in both graphs with identical content is kept once."""
        def parse_graph(text):
            parsed = GraphDef()
            text_format.Merge(text, parsed)
            return parsed

        xwy_graph = parse_graph("""
      node {
        name: "X"
        op: "Input"
      }
      node {
        name: "W"
        op: "Input"
      }
      node {
        name: "Y"
        op: "MatMul"
        input: "X"
        input: "W"
      }
      versions {
        producer: 21
      }
    """)

        # "X" also appears here, with the same definition as above.
        xa_graph = parse_graph("""
      node {
        name: "X"
        op: "Input"
      }
      node {
        name: "A"
        op: "Input"
      }
      versions {
        producer: 21
      }
    """)

        expected_text = """
      node {
        name: "X"
        op: "Input"
      }
      node {
        name: "W"
        op: "Input"
      }
      node {
        name: "Y"
        op: "MatMul"
        input: "X"
        input: "W"
      }
      node {
        name: "A"
        op: "Input"
      }
      versions {
        producer: 21
      }
    """

        combined = graph_util.combine_graph_defs(xwy_graph, xa_graph)
        self.assertProtoEquals(expected_text, combined)
コード例 #9
0
    def test_merge_graph_defs_function(self):
        """Library functions are prefixed per source graph when merging."""
        def parse_graph(text):
            parsed = GraphDef()
            text_format.Parse(text, parsed)
            return parsed

        half_foo_graph = parse_graph("""
                library {
                  function {
                    signature {
                      name: "foo"
                      input_arg {
                        name: "x"
                        type: DT_HALF
                      }
                      output_arg {
                        name: "identity"
                        type: DT_HALF
                      }
                    }
                    node_def {
                      name: "add"
                      op: "Add"
                      input: "x"
                      input: "y"
                    }
                  }
                }
            """)

        int_foo_graph = parse_graph("""
                library {
                  function {
                    signature {
                      name: "foo"
                      input_arg {
                        name: "x"
                        type: DT_INT32
                      }
                      output_arg {
                        name: "identity"
                        type: DT_INT32
                      }
                    }
                    node_def {
                      name: "add"
                      op: "Add"
                      input: "x"
                      input: "y"
                    }
                  }
                  function {
                    signature {
                      name: "foo_1"
                      input_arg {
                        name: "x"
                        type: DT_HALF
                      }
                      output_arg {
                        name: "identity"
                        type: DT_HALF
                      }
                    }
                    node_def {
                      name: "add"
                      op: "Add"
                      input: "x"
                      input: "y"
                    }
                  }
                }
            """)

        expected_proto = """
            library {
              function {
                signature {
                  name: "graph_1_foo"
                  input_arg {
                    name: "x"
                    type: DT_HALF
                  }
                  output_arg {
                    name: "identity"
                    type: DT_HALF
                  }
                }
                node_def {
                  name: "add"
                  op: "Add"
                  input: "x"
                  input: "y"
                }
              }
              function {
                signature {
                  name: "graph_2_foo"
                  input_arg {
                    name: "x"
                    type: DT_INT32
                  }
                  output_arg {
                    name: "identity"
                    type: DT_INT32
                  }
                }
                node_def {
                  name: "add"
                  op: "Add"
                  input: "x"
                  input: "y"
                }
              }
              function {
                signature {
                  name: "graph_2_foo_1"
                  input_arg {
                    name: "x"
                    type: DT_HALF
                  }
                  output_arg {
                    name: "identity"
                    type: DT_HALF
                  }
                }
                node_def {
                  name: "add"
                  op: "Add"
                  input: "x"
                  input: "y"
                }
              }
            }
        """

        merged = graph_util.merge_graph_defs([half_foo_graph, int_foo_graph])
        self.assertProtoEquals(expected_proto, merged)
コード例 #10
0
    def test_merge_graph_defs_partitioned_call_remap(self):
        """PartitionedCall "f" attrs follow the renamed library function."""
        def parse_graph(text):
            parsed = GraphDef()
            text_format.Parse(text, parsed)
            return parsed

        caller_graph = parse_graph("""
                node {
                  name: "X"
                  op: "PartitionedCall"
                  attr {
                    key: "f"
                    value {
                      func {
                        name: "foo"
                      }
                    }
                  }
                }
                library {
                  function {
                    signature {
                      name: "foo"
                      input_arg {
                        name: "x"
                        type: DT_HALF
                      }
                      output_arg {
                        name: "identity"
                        type: DT_HALF
                      }
                    }
                  }
                }
            """)
        # An empty second graph still forces per-graph prefixing.
        empty_graph = GraphDef()

        expected = parse_graph("""
                node {
                  name: "graph_1/X"
                  op: "PartitionedCall"
                  attr {
                    key: "f"
                    value {
                      func {
                        name: "graph_1_foo"
                      }
                    }
                  }
                }
                library {
                  function {
                    signature {
                      name: "graph_1_foo"
                      input_arg {
                        name: "x"
                        type: DT_HALF
                      }
                      output_arg {
                        name: "identity"
                        type: DT_HALF
                      }
                    }
                  }
                }
            """)

        merged = graph_util.merge_graph_defs([caller_graph, empty_graph])
        self.assertProtoEquals(expected, merged)
コード例 #11
0
    def test_merge_graph_defs(self):
        """Node names and inputs get a graph_<i>/ prefix per source graph."""
        def parse_graph(text):
            parsed = GraphDef()
            text_format.Parse(text, parsed)
            return parsed

        sources = [
            parse_graph("""
                node {
                  name: "X"
                  op: "Input"
                }
                node {
                  name: "W"
                  op: "Input"
                }
                node {
                  name: "Y"
                  op: "MatMul"
                  input: "X"
                  input: "W"
                }
                versions {
                  producer: 21
                }
            """),
            parse_graph("""
                node {
                  name: "A"
                  op: "Input"
                }
                node {
                  name: "B"
                  op: "Input"
                }
                node {
                  name: "C"
                  op: "MatMul"
                  input: "A"
                  input: "B"
                }
                versions {
                  producer: 21
                }
            """),
            parse_graph("""
                node {
                  name: "A"
                  op: "Input"
                }
                node {
                  name: "B"
                  op: "Input"
                }
                versions {
                  producer: 21
                }
            """),
        ]

        expected_proto = """
            node {
              name: "graph_1/X"
              op: "Input"
            }
            node {
              name: "graph_1/W"
              op: "Input"
            }
            node {
              name: "graph_1/Y"
              op: "MatMul"
              input: "graph_1/X"
              input: "graph_1/W"
            }
            node {
              name: "graph_2/A"
              op: "Input"
            }
            node {
              name: "graph_2/B"
              op: "Input"
            }
            node {
              name: "graph_2/C"
              op: "MatMul"
              input: "graph_2/A"
              input: "graph_2/B"
            }
            node {
              name: "graph_3/A"
              op: "Input"
            }
            node {
              name: "graph_3/B"
              op: "Input"
            }
            versions {
              producer: 21
            }
        """

        merged = graph_util.merge_graph_defs(sources)
        self.assertProtoEquals(expected_proto, merged)
コード例 #12
0
    def test_merge_graph_defs_name_collided_with_same_content(self):
        """Colliding node names stay distinct because each graph is prefixed."""
        def parse_graph(text):
            parsed = GraphDef()
            text_format.Parse(text, parsed)
            return parsed

        xw_graph = parse_graph("""
                node {
                  name: "X"
                  op: "Input"
                }
                node {
                  name: "W"
                  op: "Input"
                }
                node {
                  name: "Y"
                  op: "MatMul"
                  input: "X"
                  input: "W"
                }
                versions {
                  producer: 21
                }
            """)

        xa_graph = parse_graph("""
                node {
                  name: "X"
                  op: "Input"
                }
                node {
                  name: "A"
                  op: "Input"
                }
                node {
                  name: "Y"
                  op: "MatMul"
                  input: "X"
                  input: "A"
                }
                versions {
                  producer: 21
                }
            """)

        expected_proto = """
            node {
              name: "graph_1/X"
              op: "Input"
            }
            node {
              name: "graph_1/W"
              op: "Input"
            }
            node {
              name: "graph_1/Y"
              op: "MatMul"
              input: "graph_1/X"
              input: "graph_1/W"
            }
            node {
              name: "graph_2/X"
              op: "Input"
            }
            node {
              name: "graph_2/A"
              op: "Input"
            }
            node {
              name: "graph_2/Y"
              op: "MatMul"
              input: "graph_2/X"
              input: "graph_2/A"
            }
            versions {
              producer: 21
            }
        """

        merged = graph_util.merge_graph_defs([xw_graph, xa_graph])
        self.assertProtoEquals(expected_proto, merged)
コード例 #13
0
ファイル: keras_util.py プロジェクト: ecrawford-0/tensorflow
def keras_model_to_graph_def(keras_layer):
    """Convert the dict form of a Keras model into a GraphDef.

    Only models that implement ``to_json()`` are supported.

    Args:
      keras_layer: The dict obtained from a Keras ``model.to_json()``.

    Returns:
      A GraphDef holding one node per non-model layer, with inputs wired from
      the layer's ``inbound_nodes`` or, when that field is absent, from the
      previously emitted layer.
    """
    graph_def = GraphDef()
    layer_inputs = {}
    model_outputs = {}
    # Sequential-model layers carry no "inbound_nodes"; their input is implied
    # by layer order, so remember the name of the last node emitted.
    last_node_name = None

    for scope, layer in _walk_layers(keras_layer):
        if _is_model(layer):
            # Nested models are not emitted as nodes themselves; they only
            # update the bookkeeping used to wire up surrounding layers.
            layer_inputs, model_outputs, last_node_name = _update_dicts(
                scope, layer, layer_inputs, model_outputs, last_node_name)
            continue

        config = layer.get('config')
        node_name = _scoped_name(scope, config.get('name'))
        node = graph_def.node.add()
        node.name = node_name

        class_name = layer.get('class_name')
        if class_name is not None:
            node.attr['keras_class'].s = class_name.encode('ascii')

        dtype_name = config.get('dtype')
        if dtype_name is not None:
            node.attr['dtype'].type = dtypes.as_dtype(dtype_name).as_datatype_enum

        inbound = layer.get('inbound_nodes')
        if inbound is None:
            if last_node_name is not None:
                node.input.append(last_node_name)
        else:
            for maybe_inbound_node in inbound:
                for src_name, _size, out_index, _ in _norm_to_list_of_layers(
                        maybe_inbound_node):
                    scoped_src = _scoped_name(scope, src_name)
                    # A layer may consume the output of a nested model, in
                    # which case the inbound name is the model's name. Map it
                    # onto the model's output layers and pick the one at
                    # ``out_index``, since a model can have several outputs.
                    candidates = model_outputs.get(scoped_src, [scoped_src])
                    node.input.append(candidates[out_index])

        if node_name in layer_inputs:
            node.input.append(layer_inputs.get(node_name))

        last_node_name = node.name

    return graph_def
コード例 #14
0
 def add_graph(self, model, *args, **kargs):
     """Build a GraphDef for ``model`` and hand it to the file writer.

     Extra positional/keyword arguments are forwarded to ``GraphVisitor``.
     The accompanying run metadata lists a single CPU device with no
     per-node stats.
     """
     visitor = GraphVisitor(model, *args, **kargs)
     cpu_device_stats = DeviceStepStats(device="/device:CPU:0")
     run_metadata = RunMetadata(step_stats=StepStats(dev_stats=[cpu_device_stats]))
     graph_def = GraphDef(node=visitor._graph, versions=VersionDef(producer=22))
     file_writer = self._get_file_writer()
     file_writer.add_graph((graph_def, run_metadata))
コード例 #15
0
def visualize(
    model_path: str,
    log_path: str,
    input: np.ndarray = None,
    inp_dict: dict = None,
    cal_params: bool = True,
    cal_flops: bool = True,
    cal_activations: bool = True,
    logging_to_stdout: bool = True,
    bar_length_max: int = 20,
):
    r"""
    Load megengine dumped model and visualize graph structure with tensorboard log files.
    Can also record and print model's statistics like :func:`~.module_stats`

    :param model_path: dir path for megengine dumped model.
    :param log_path: dir path for tensorboard graph log. When falsy, no tensorboard
        log is written; statistics are still collected.
    :param input: user defined input data for running model and calculating stats, alternative with inp_dict, used when the model has only one input.
    :param inp_dict: input dict for running model and calculating stats, alternative with input, used when the model has more than one input. When both input and inp_dict are None, a random input will be used.
    :param cal_params: whether calculate and record params size.
    :param cal_flops: whether calculate and record op flops.
    :param cal_activations: whether calculate and record op activations.
    :param logging_to_stdout: whether print all calculated statistic details.
    :param bar_length_max: size of bar indicating max flops or parameter size in net stats.
    :return: a ``(total_stats, stats_details)`` pair of namedtuples. Totals for a
        disabled ``cal_*`` category are 0 and its detail list is empty. Returns
        ``None`` if ``log_path`` is set but TensorBoard/TensorboardX can't be imported.
    """
    if log_path:
        try:
            from tensorboard.compat.proto.attr_value_pb2 import AttrValue
            from tensorboard.compat.proto.config_pb2 import RunMetadata
            from tensorboard.compat.proto.graph_pb2 import GraphDef
            from tensorboard.compat.proto.node_def_pb2 import NodeDef
            from tensorboard.compat.proto.step_stats_pb2 import (
                AllocatorMemoryUsed,
                DeviceStepStats,
                NodeExecStats,
                StepStats,
            )
            from tensorboard.compat.proto.tensor_shape_pb2 import TensorShapeProto
            from tensorboard.compat.proto.versions_pb2 import VersionDef
            from tensorboardX import SummaryWriter
        except ImportError:
            logger.error(
                "TensorBoard and TensorboardX are required for visualize.",
                exc_info=True,
            )
            return

    enable_receptive_field()

    graph = Network.load(model_path)
    graph.reset_batch_size(1)

    # Replace the graph's input vars with user-supplied constants, if any.
    has_input = False
    if input is not None or inp_dict is not None:
        has_input = True
        repl_dict = {}
        inp_vars = graph.input_vars
        if inp_dict is not None:
            assert len(inp_dict) == len(
                inp_vars
            ), "Inputs are not sufficient for calculation."
            for v in inp_vars:
                new_input = graph.make_const(inp_dict[v.name], name=v.name)
                repl_dict[v] = new_input
        else:
            assert len(inp_vars) == 1, "The graph needs more than one input."
            inp_var = inp_vars[0]
            repl_dict[inp_var] = graph.make_const(input, name=inp_var.name)
        graph.replace_vars(repl_dict=repl_dict)

    graph._compile()

    def process_name(name):
        # nodes that start with point or contain float const will lead to display bug
        if not re.match(r"^[+-]?\d*\.\d*", name):
            name = name.replace(".", "/")
        return name.encode(encoding="utf-8")

    node_list = []
    flops_list = []
    params_list = []
    activations_list = []
    total_stats = namedtuple("total_stats", ["param_size", "flops", "act_size"])
    stats_details = namedtuple("module_stats", ["params", "flops", "activations"])

    for node in tqdm(graph.all_oprs):
        if hasattr(node, "output_idx"):
            node_oup = node.outputs[node.output_idx]
        else:
            if len(node.outputs) != 1:
                logger.warning(
                    "OpNode {} has more than one output and not has 'output_idx' attr.".format(
                        node
                    )
                )
            node_oup = node.outputs[0]

        inp_list = [process_name(var.owner.name) for var in node.inputs]
        if log_path:
            # detail format see tensorboard/compat/proto/attr_value.proto
            attr = {
                "_output_shapes": AttrValue(
                    list=AttrValue.ListValue(
                        shape=[
                            TensorShapeProto(
                                dim=[
                                    TensorShapeProto.Dim(size=d) for d in node_oup.shape
                                ]
                            )
                        ]
                    )
                ),
                "params": AttrValue(s=str(node.params).encode(encoding="utf-8")),
                "dtype": AttrValue(s=str(node_oup.dtype).encode(encoding="utf-8")),
            }

        if cal_flops:
            flops_stats = get_op_stats(node, node.inputs, node.outputs)
            if flops_stats is not None:
                # add op flops attr. flops_stats is subscripted like a dict, so
                # test for the key we read (hasattr() on a dict never sees keys).
                if log_path and "flops" in flops_stats:
                    attr["flops"] = AttrValue(
                        s=sizeof_fmt(flops_stats["flops"]).encode(encoding="utf-8")
                    )
                flops_stats["name"] = node.name
                flops_stats["class_name"] = node.type
                flops_list.append(flops_stats)

        if cal_activations:
            acts = get_activation_stats(node_oup.numpy(), has_input=has_input)
            acts["name"] = node.name
            acts["class_name"] = node.type
            activations_list.append(acts)

        if cal_params:
            if node.type == "ImmutableTensor":
                param_stats = get_param_stats(node.numpy())
                # add tensor size attr
                if log_path:
                    attr["size"] = AttrValue(
                        s=sizeof_fmt(param_stats["size"]).encode(encoding="utf-8")
                    )
                param_stats["name"] = node.name
                params_list.append(param_stats)

        if log_path:
            node_list.append(
                NodeDef(
                    name=process_name(node.name),
                    op=node.type,
                    input=inp_list,
                    attr=attr,
                )
            )
    # summary
    extra_info = {
        "#ops": len(graph.all_oprs),
        "#params": len(params_list),
    }

    (
        total_flops,
        total_param_dims,
        total_param_size,
        total_act_dims,
        total_act_size,
    ) = (0, 0, 0, 0, 0)

    if cal_params:
        total_param_dims, total_param_size, params_list = sum_param_stats(
            params_list, bar_length_max
        )
        extra_info["total_param_dims"] = sizeof_fmt(total_param_dims, suffix="")
        extra_info["total_param_size"] = sizeof_fmt(total_param_size)
        if logging_to_stdout:
            print_param_stats(params_list)

    if cal_flops:
        total_flops, flops_list = sum_op_stats(flops_list, bar_length_max)
        extra_info["total_flops"] = sizeof_fmt(total_flops, suffix="OPs")
        if logging_to_stdout:
            print_op_stats(flops_list)

    if cal_activations:
        total_act_dims, total_act_size, activations_list = sum_activations_stats(
            activations_list, bar_length_max
        )
        extra_info["total_act_dims"] = sizeof_fmt(total_act_dims, suffix="")
        extra_info["total_act_size"] = sizeof_fmt(total_act_size)
        if logging_to_stdout:
            print_activations_stats(activations_list, has_input=has_input)

    # A model without ImmutableTensor params has total_param_size == 0; skip
    # the ratio rather than raising ZeroDivisionError.
    if cal_flops and cal_params and total_param_size:
        extra_info["flops/param_size"] = "{:3.3f}".format(
            total_flops / total_param_size
        )

    if log_path:
        graph_def = GraphDef(node=node_list, versions=VersionDef(producer=22))

        device = "/device:CPU:0"
        stepstats = RunMetadata(
            step_stats=StepStats(dev_stats=[DeviceStepStats(device=device)])
        )
        writer = SummaryWriter(log_path)
        writer._get_file_writer().add_graph((graph_def, stepstats))

    print_summary(**extra_info)

    return (
        total_stats(
            param_size=total_param_size, flops=total_flops, act_size=total_act_size,
        ),
        stats_details(
            params=params_list, flops=flops_list, activations=activations_list
        ),
    )
コード例 #16
0
ファイル: _pytorch_graph.py プロジェクト: zzprice/pytorch
def graph(model, args, verbose=False, operator_export_type='ONNX', omit_useless_nodes=True):
    """
    This method processes a PyTorch model and produces a `GraphDef` proto
    that can be logged to TensorBoard.

    Args:
      model (PyTorch module): The model to be parsed.
      args (tuple): input tensor[s] for the model.
      verbose (bool): Whether to print out verbose information while
        processing.
      operator_export_type (str): One of 'ONNX', 'ONNX_ATEN', or 'RAW'.
        Defaults to 'ONNX' format because it outputs the most visually
        understandable format.
      omit_useless_nodes (boolean): Forwarded to ``parse``; presumably
        controls whether nodes deemed useless are dropped from the graph.

    Returns:
      A ``(GraphDef, RunMetadata)`` tuple. If tracing the model fails, the
      GraphDef has no nodes and the RunMetadata carries no node stats, so the
      result can be unpacked the same way in both cases.
    """
    operator_export_type = getattr(OperatorExportTypes, operator_export_type)

    with torch.onnx.set_training(model, False):
        try:
            trace, _ = torch.jit.get_trace_graph(model, args)
        except RuntimeError:
            print('Error occurs, No graph saved')
            _ = model(*args)  # don't catch, just print the error message
            print("Checking if it's onnx problem...")
            try:
                import tempfile
                # Use the scratch file as a context manager so it is closed
                # (and removed) once the export attempt finishes.
                with tempfile.TemporaryFile() as export_file:
                    torch.onnx.export(model, args, export_file, verbose=True)
            except RuntimeError:
                print("Your model cannot be exported by onnx, please report to onnx team")
            # Create an object matching
            # https://github.com/tensorflow/tensorboard/blob/master/tensorboard/compat/proto/graph.proto
            # The producer version has been reverse engineered from standard
            # TensorBoard logged data.
            # Return the same (GraphDef, RunMetadata) shape as the success path
            # below so callers that unpack/index the result do not crash.
            empty_stepstats = RunMetadata(
                step_stats=StepStats(dev_stats=[DeviceStepStats(device="/device:CPU:0")]))
            return GraphDef(versions=VersionDef(producer=22)), empty_stepstats

    try:
        # An optimized graph helps debug at a higher level. Users can focus
        # on connections between big modules such as Linear instead of W, x,
        # bias, matmul, etc. Honestly, most users don't care about those
        # detailed nodes information.
        _optimize_trace(trace, operator_export_type)
    except RuntimeError as e:
        # Optimize trace might fail (due to bad scopes in some cases we've seen)
        # and we don't want graph visualization to fail in this case. In this
        # case we'll log the warning and display the non-optimized graph.
        logging.warning(e)
    graph = trace.graph()
    if verbose:
        print(graph)
    list_of_nodes, node_stats = parse(graph, args, omit_useless_nodes)
    # We are hardcoding that this was run on CPU even though it might have actually
    # run on GPU. Note this is what is shown in TensorBoard and has no bearing
    # on actual execution.
    # TODO: See if we can extract GPU vs CPU information from the PyTorch model
    # and pass it correctly to TensorBoard.
    #
    # Definition of StepStats and DeviceStepStats can be found at
    # https://github.com/tensorflow/tensorboard/blob/master/tensorboard/plugins/graph/tf_graph_common/test/graph-test.ts
    # and
    # https://github.com/tensorflow/tensorboard/blob/master/tensorboard/compat/proto/step_stats.proto
    stepstats = RunMetadata(step_stats=StepStats(dev_stats=[DeviceStepStats(device="/device:CPU:0",
                                                                            node_stats=node_stats)]))
    return GraphDef(node=list_of_nodes, versions=VersionDef(producer=22)), stepstats
コード例 #17
0
def visualize(
    model_path: str,
    log_path: str,
    bar_length_max: int = 20,
    log_params: bool = True,
    log_flops: bool = True,
):
    r"""
    Load megengine dumped model and visualize graph structure with tensorboard log files.
    Can also record and print model's statistics like :func:`~.module_stats`

    :param model_path: dir path for megengine dumped model.
    :param log_path: dir path for tensorboard graph log. When falsy, no tensorboard
        log is written; statistics are still collected.
    :param bar_length_max: size of bar indicating max flops or parameter size in net stats.
    :param log_params: whether print and record params size.
    :param log_flops: whether print and record op flops.
    :return: ``(total_param_size, total_flops)``; each is 0 when its ``log_*`` flag
        is False. Returns ``None`` if ``log_path`` is set but TensorBoard/TensorboardX
        cannot be imported.
    """
    if log_path:
        try:
            from tensorboard.compat.proto.attr_value_pb2 import AttrValue
            from tensorboard.compat.proto.config_pb2 import RunMetadata
            from tensorboard.compat.proto.graph_pb2 import GraphDef
            from tensorboard.compat.proto.node_def_pb2 import NodeDef
            from tensorboard.compat.proto.step_stats_pb2 import (
                AllocatorMemoryUsed,
                DeviceStepStats,
                NodeExecStats,
                StepStats,
            )
            from tensorboard.compat.proto.tensor_shape_pb2 import TensorShapeProto
            from tensorboard.compat.proto.versions_pb2 import VersionDef
            from tensorboardX import SummaryWriter
        except ImportError:
            logger.error(
                "TensorBoard and TensorboardX are required for visualize.",
                exc_info=True,
            )
            return
    # FIXME: remove this after resolving "span dist too large" warning
    old_level = set_mgb_log_level(logging.ERROR)
    # Restore the previous level even if loading/analysis raises, so a failed
    # call does not leave megbrain logging silenced.
    try:
        enable_receptive_field()

        graph = Network.load(model_path)

        def process_name(name):
            # nodes that start with point or contain float const will lead to display bug
            if not re.match(r"^[+-]?\d*\.\d*", name):
                name = name.replace(".", "/")
            return name.encode(encoding="utf-8")

        node_list = []
        flops_list = []
        params_list = []
        for node in graph.all_oprs:
            if hasattr(node, "output_idx"):
                node_oup = node.outputs[node.output_idx]
            else:
                if len(node.outputs) != 1:
                    logger.warning(
                        "OpNode {} has more than one output and not has 'output_idx' attr."
                        .format(node))
                node_oup = node.outputs[0]

            inp_list = [process_name(var.owner.name) for var in node.inputs]
            if log_path:
                # detail format see tensorboard/compat/proto/attr_value.proto
                attr = {
                    "_output_shapes":
                    AttrValue(list=AttrValue.ListValue(shape=[
                        TensorShapeProto(dim=[
                            TensorShapeProto.Dim(size=d) for d in node_oup.shape
                        ])
                    ])),
                    "params":
                    AttrValue(s=str(node.params).encode(encoding="utf-8")),
                    "dtype":
                    AttrValue(s=str(node_oup.dtype).encode(encoding="utf-8")),
                }
            flops_stats = get_op_stats(node, node.inputs, node.outputs)
            if flops_stats is not None:
                # add op flops attr. flops_stats is subscripted like a dict, so
                # test for the key we read (hasattr() on a dict never sees keys).
                if log_path and "flops" in flops_stats:
                    attr["flops"] = AttrValue(
                        s=sizeof_fmt(flops_stats["flops"]).encode(
                            encoding="utf-8"))
                flops_stats["name"] = node.name
                flops_stats["class_name"] = node.type
                flops_list.append(flops_stats)

            if node.type == "ImmutableTensor":
                param_stats = get_param_stats(node.numpy())
                # add tensor size attr
                if log_path:
                    attr["size"] = AttrValue(
                        s=sizeof_fmt(param_stats["size"]).encode(encoding="utf-8"))
                param_stats["name"] = node.name
                params_list.append(param_stats)

            if log_path:
                node_list.append(
                    NodeDef(
                        name=process_name(node.name),
                        op=node.type,
                        input=inp_list,
                        attr=attr,
                    ))
        # summary
        extra_info = {
            "#ops": len(graph.all_oprs),
            "#params": len(params_list),
        }

        total_flops, total_param_dims, total_param_size = 0, 0, 0
        if log_params:
            total_param_dims, total_param_size = print_param_stats(
                params_list, bar_length_max)
            extra_info["total_param_dims"] = sizeof_fmt(total_param_dims)
            extra_info["total_param_size"] = sizeof_fmt(total_param_size)
        if log_flops:
            total_flops = print_op_stats(flops_list, bar_length_max)
            extra_info["total_flops"] = sizeof_fmt(total_flops, suffix="OPs")
        # A model without ImmutableTensor params has total_param_size == 0;
        # skip the ratio rather than raising ZeroDivisionError.
        if log_params and log_flops and total_param_size:
            extra_info["flops/param_size"] = "{:3.3f}".format(total_flops /
                                                              total_param_size)

        if log_path:
            graph_def = GraphDef(node=node_list, versions=VersionDef(producer=22))

            device = "/device:CPU:0"
            stepstats = RunMetadata(step_stats=StepStats(
                dev_stats=[DeviceStepStats(device=device)]))
            writer = SummaryWriter(log_path)
            writer._get_file_writer().add_graph((graph_def, stepstats))

        print_summary(**extra_info)
    finally:
        # FIXME: remove this after resolving "span dist too large" warning
        _imperative_rt_logger.set_log_level(old_level)

    return total_param_size, total_flops