コード例 #1
0
def convert_graph_def(
    graph_def, pass_pipeline='tf-standard-pipeline', show_debug_info=False):
  """Converts a GraphDef into the textual form of an MLIR module.

  Meant for inspecting TensorFlow internals only; for now the returned string
  is intended purely as a debugging aid.

  Args:
    graph_def: Either a graph_pb2.GraphDef object or the textual proto
      representation of a valid GraphDef.
    pass_pipeline: Textual description of the MLIR Pass Pipeline to run on
      the imported module. The syntax is described in the MLIR documentation:
      [textual pass pipeline syntax](https://mlir.llvm.org/docs/PassManagement/#textual-pass-pipeline-specification).
    show_debug_info: If true, locations are included in the emitted text.

  Returns:
    The MLIR module corresponding to `graph_def`, as text.

  Raises:
    InvalidArgumentError: if graph_def is invalid or cannot be converted to
      MLIR.
  """
  # All three arguments are forwarded positionally, unchanged, to the wrapper.
  mlir_text = pywrap_mlir.import_graphdef(
      graph_def, pass_pipeline, show_debug_info)
  return mlir_text
コード例 #2
0
def convert_graph_def(graph_def, pass_pipeline='tf-standard-pipeline'):
  """Converts a GraphDef into the textual form of an MLIR module.

  Args:
    graph_def: Either a graph_pb2.GraphDef object or the textual proto
      representation of a valid GraphDef.
    pass_pipeline: Textual description of the MLIR Pass Pipeline to run on
      the imported module. The syntax is described in the MLIR documentation:
      [textual pass pipeline syntax](https://mlir.llvm.org/docs/PassManagement/#textual-pass-pipeline-specification).

  Returns:
    The MLIR module corresponding to `graph_def`, as text.

  Raises:
    RuntimeError: on error.
  """
  # Both arguments are forwarded positionally, unchanged, to the wrapper.
  mlir_text = pywrap_mlir.import_graphdef(graph_def, pass_pipeline)
  return mlir_text
コード例 #3
0
  def testGraphDefToTf(self):
    """Exercises `import_graphdef` with the `tf-standard-pipeline` pipeline.

    A two-operand float32 add function is traced to a GraphDef, which is then
    imported with explicit input/output specifications. The test checks the
    emitted `@main` signature for 10x10 and for scalar (empty-shape) inputs,
    and the InvalidArgumentError raised for inconsistent specifications.
    """
    shape = (10, 10)
    operand_specs = tuple(
        tensor_spec.TensorSpec(shape=shape, dtype=dtypes.float32)
        for _ in range(2))

    @def_function.function(input_signature=operand_specs)
    def add_func(lhs, rhs):
      return math_ops.add(lhs, rhs)

    concrete_fn = add_func.get_concrete_function()
    graph_def = concrete_fn.graph.as_graph_def()

    def run_import(names, shapes):
      # Pipeline, debug-info flag, input dtypes and output node are the same
      # for every case in this test; only the names and shapes vary.
      return import_graphdef(
          graph_def,
          "tf-standard-pipeline",
          False,
          input_names=names,
          input_data_types=["DT_FLOAT", "DT_FLOAT"],
          input_data_shapes=shapes,
          output_names=["Add"])

    # The mlir-function signature must carry the requested inputs and outputs.
    module_text = run_import(["lhs", "rhs"], ["10,10", "10,10"])
    expected_patterns = (
        r"func @main\(%arg0: tensor<10x10xf32>, %arg1: tensor<10x10xf32>",
        r'inputs = "lhs,rhs"',
        r'outputs = "Add"',
    )
    for pattern in expected_patterns:
      self.assertRegex(module_text, pattern)

    # Scalar inputs: an empty shape string for each operand.
    module_text = run_import(["lhs", "rhs"], ["", ""])
    self.assertRegex(
        module_text, r"func @main\(%arg0: tensor<f32>, %arg1: tensor<f32>")

    # Invalid case: fewer input names than input data types.
    length_mismatch = "Length of input node array and data type doesn't match"
    with self.assertRaisesRegex(errors.InvalidArgumentError, length_mismatch):
      run_import(["lhs"], ["10,10", "10,10"])

    # Invalid case: the two operand shapes disagree with each other.
    shape_mismatch = "Dimensions must be equal"
    with self.assertRaisesRegex(errors.InvalidArgumentError, shape_mismatch):
      run_import(["lhs", "rhs"], ["10,11", "10,10"])