def convert_graph_def(graph_def,
                      pass_pipeline='tf-standard-pipeline',
                      show_debug_info=False):
  """Convert a GraphDef into the textual form of an MLIR module.

  Meant only for looking at TensorFlow internals; for now the returned string
  is a debugging aid rather than a stable format.

  Args:
    graph_def: A `graph_pb2.GraphDef` object, or the textual proto
      representation of a valid GraphDef.
    pass_pipeline: Textual MLIR pass-pipeline description to run on the
      imported module. For the syntax, see
      https://mlir.llvm.org/docs/PassManagement/#textual-pass-pipeline-specification
    show_debug_info: If True, locations are included in the emitted text.

  Returns:
    The MLIR module corresponding to `graph_def`, rendered as text.

  Raises:
    InvalidArgumentError: If `graph_def` is invalid or cannot be converted to
      MLIR.
  """
  mlir_module_text = pywrap_mlir.import_graphdef(
      graph_def, pass_pipeline, show_debug_info)
  return mlir_module_text
def convert_graph_def(graph_def,
                      pass_pipeline='tf-standard-pipeline',
                      show_debug_info=False):
  """Import a GraphDef and convert it to a textual MLIR module.

  This API is only intended for inspecting the internals of TensorFlow. For
  now, the returned string is intended for debugging purposes.

  Args:
    graph_def: An object of type graph_pb2.GraphDef or a textual proto
      representation of a valid GraphDef.
    pass_pipeline: A textual description of an MLIR Pass Pipeline to run on the
      module. See the MLIR documentation for the [textual pass pipeline
      syntax](https://mlir.llvm.org/docs/PassManagement/#textual-pass-pipeline-specification).
    show_debug_info: Whether to include locations in the emitted textual form.
      Defaults to False, which matches the previous two-argument behavior.

  Returns:
    A textual representation of the MLIR module corresponding to the graphdef.

  Raises:
    InvalidArgumentError: if graph_def is invalid or cannot be converted to
      MLIR.
  """
  # NOTE(review): this module defines `convert_graph_def` more than once. This
  # later definition is the one that wins, so it accepts the same
  # `show_debug_info` parameter as the other definition. Callers using either
  # signature therefore keep working. Consider deleting the duplicate.
  return pywrap_mlir.import_graphdef(graph_def, pass_pipeline, show_debug_info)
def testGraphDefToTf(self):
  """Checks GraphDef import via `import_graphdef` with explicit I/O specs.

  The test traces a small two-input `add` function and imports its GraphDef
  with the "tf-standard-pipeline" pass pipeline. It then checks the emitted
  MLIR text for the expected `main` signature and the inputs/outputs
  attributes. It also checks that two malformed specs raise
  `InvalidArgumentError`:

  * mismatched input-array lengths;
  * incompatible input shapes.
  """
  shape = (10, 10)

  @def_function.function(
      input_signature=(
          tensor_spec.TensorSpec(shape=shape, dtype=dtypes.float32),
          tensor_spec.TensorSpec(shape=shape, dtype=dtypes.float32),
      ))
  def add_func(lhs, rhs):
    return math_ops.add(lhs, rhs)

  graph_def = add_func.get_concrete_function().graph.as_graph_def()

  def run_import(names, shapes):
    # Every import below shares the same fixed arguments:
    # - the pipeline;
    # - the debug-info flag (False);
    # - the float input types;
    # - the single "Add" output.
    # Only the input names and shape strings vary.
    return import_graphdef(
        graph_def,
        "tf-standard-pipeline",
        False,
        input_names=names,
        input_data_types=["DT_FLOAT", "DT_FLOAT"],
        input_data_shapes=shapes,
        output_names=["Add"])

  # Fully specified 10x10 inputs: the main signature and the I/O attributes
  # must appear in the output.
  mlir_text = run_import(["lhs", "rhs"], ["10,10", "10,10"])
  self.assertRegex(
      mlir_text,
      r"func @main\(%arg0: tensor<10x10xf32>, %arg1: tensor<10x10xf32>")
  self.assertRegex(mlir_text, r'inputs = "lhs,rhs"')
  self.assertRegex(mlir_text, r'outputs = "Add"')

  # Empty shape strings are expected to yield scalar tensor<f32> arguments.
  mlir_text = run_import(["lhs", "rhs"], ["", ""])
  self.assertRegex(mlir_text,
                   r"func @main\(%arg0: tensor<f32>, %arg1: tensor<f32>")

  # One input name but two data types: the array lengths disagree.
  with self.assertRaisesRegex(
      errors.InvalidArgumentError,
      "Length of input node array and data type doesn't match"):
    run_import(["lhs"], ["10,10", "10,10"])

  # A 10x11 input cannot be added to a 10x10 input.
  with self.assertRaisesRegex(errors.InvalidArgumentError,
                              "Dimensions must be equal"):
    run_import(["lhs", "rhs"], ["10,11", "10,10"])