Example #1
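Both methods below exercise TFF's experimental IREE backend. A minimal scaffolding sketch follows so the snippets read as a self-contained test class; the import paths are assumptions inferred from the names used in the code (TFF's internal layout has changed across releases), and the test-class name is hypothetical.

import tempfile

import numpy as np
import tensorflow as tf
import iree.compiler.tf

# Assumed TFF-internal module paths (verify against the installed release).
from tensorflow_federated.python.core.api import computation_types
from tensorflow_federated.python.core.backends.iree import backend_info
from tensorflow_federated.python.core.backends.iree import computation_module
from tensorflow_federated.python.core.backends.iree import runtime

# `iree_compiler` in the second test refers to an older IREE TF compiler API
# (`tf_saved_model_to_compiler_module`); its import path depends on the
# installed IREE release and is left unresolved here.


class IreeBackendExampleTest(tf.test.TestCase):  # hypothetical class name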
    def test_computation_callable(self):
        # Build a trivial TF module with a single `foo(x) = x + 1.0` function,
        # export it as a SavedModel, and import it with the IREE TF compiler.
        tf_module = tf.Module()
        fn = lambda x: x + 1.0
        sig = [tf.TensorSpec([], tf.float32)]
        tf_module.foo = tf.function(fn, input_signature=sig)
        with tempfile.TemporaryDirectory() as model_dir:
            save_options = tf.saved_model.SaveOptions(save_debug_info=True)
            tf.saved_model.save(tf_module, model_dir, options=save_options)
            iree_compiler_module = iree.compiler.tf.compile_saved_model(
                model_dir, import_only=True)
        # Wrap the imported module and invoke it through the IREE runtime.
        my_computation_module = computation_module.ComputationModule(
            iree_compiler_module, 'foo',
            computation_types.FunctionType(tf.float32, tf.float32))
        computation_callable = runtime.ComputationCallable(
            my_computation_module, backend_info.VULKAN_SPIRV)
        self.assertTrue(callable(computation_callable))
        result = computation_callable(np.float32(5.0))
        self.assertEqual(result, 6.0)

    def test_module_class_with_add_one(self):
        tf_module = tf.Module()
        tf_module.foo = tf.function(
            lambda x: x + 1.0,
            input_signature=[tf.TensorSpec(shape=(), dtype=tf.float32)])
        # Note: unlike the test above, this one writes to a fixed path rather
        # than a temporary directory.
        model_dir = '/tmp/foo'
        save_options = tf.saved_model.SaveOptions(save_debug_info=True)
        tf.saved_model.save(tf_module, model_dir, options=save_options)
        iree_compiler_module = iree_compiler.tf_saved_model_to_compiler_module(
            model_dir, exported_names=['foo'])
        my_computation_module = computation_module.ComputationModule(
            iree_compiler_module, 'foo',
            computation_types.FunctionType(tf.float32, tf.float32))
        # The module exposes the compiler module, function name, and TFF type
        # signature it was constructed with.
        self.assertIs(my_computation_module.compiler_module,
                      iree_compiler_module)
        self.assertEqual(my_computation_module.function_name, 'foo')
        self.assertEqual(str(my_computation_module.type_signature),
                         '(float32 -> float32)')
Example #3
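The helper below is the compiler entry point of the same experimental IREE backend. A hedged import sketch follows; the module paths are assumptions inferred from the symbols used in the function body (`pb`, `py_typecheck`, `serialization_utils`, `type_serialization`, `tensorflow_utils`, `computation_module`) and may differ between TFF releases.

import tempfile

import tensorflow as tf
import iree.compiler.tf

# Assumed TFF-internal module paths (verify against the installed release).
from tensorflow_federated.proto.v0 import computation_pb2 as pb
from tensorflow_federated.python.common_libs import py_typecheck
from tensorflow_federated.python.common_libs import serialization_utils
from tensorflow_federated.python.core.api import computation_types
from tensorflow_federated.python.core.backends.iree import computation_module
from tensorflow_federated.python.core.impl.types import type_serialization
from tensorflow_federated.python.core.impl.utils import tensorflow_utils
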
def import_tensorflow_computation(comp, name='fn'):
    """Creates a `computation_module.ComputationModule` from a TF computation.

  WARNING: This helper function is under construction, and most capabilities are
  not implemented at this stage:

  * The parameter and result of `comp` can only be a single tensor. Named
    tuples, sequences, or functional types are not currently supported.

  * Only tensorflow code can be imported.

  TODO(b/153499219): Add support for named tuples, sequences, and functions.

  Args:
    comp: An instance of a `pb.Computation` with TensorFlow code to import.
    name: An optional `str` name of the (single) function in the IREE module.

  Returns:
    An instance of `Module` with the imported function present.

  Raises:
    TypeError: If arguments are of the wrong types, e.g., in `comp` is not a
      TensorFlow computation.
  """
    py_typecheck.check_type(comp, pb.Computation)
    type_spec = type_serialization.deserialize_type(comp.type)
    if not type_spec.is_function():
        type_spec = computation_types.FunctionType(None, type_spec)

    # TODO(b/153499219): Replace this with a recursive check of the signature
    # after relaxing the type restrictions and introducing nested structures.
    py_typecheck.check_type(type_spec.result, computation_types.TensorType)
    if type_spec.parameter is not None:
        py_typecheck.check_type(type_spec.parameter,
                                computation_types.TensorType)

    which_computation = comp.WhichOneof('computation')
    if which_computation != 'tensorflow':
        raise TypeError('Expected a TensorFlow computation, found {}.'.format(
            which_computation))

    output_tensor_names = tensorflow_utils.extract_tensor_names_from_binding(
        comp.tensorflow.result)
    if type_spec.parameter is not None:
        input_tensor_names = tensorflow_utils.extract_tensor_names_from_binding(
            comp.tensorflow.parameter)
    else:
        input_tensor_names = []

    graph_def = serialization_utils.unpack_graph_def(comp.tensorflow.graph_def)
    init_op = comp.tensorflow.initialize_op
    return_elements = input_tensor_names + output_tensor_names
    if init_op:
        graph_def = tensorflow_utils.add_control_deps_for_init_op(
            graph_def, init_op)
        return_elements.append(init_op)

    with tf.Graph().as_default() as graph:
        # TODO(b/153499219): See if we can reintroduce uniquify_shared_names().
        # Right now it causes loader breakage, and it is unclear whether it is
        # still necessary.
        import_results = tf.graph_util.import_graph_def(
            graph_def, input_map={}, return_elements=return_elements, name='')

    if init_op:
        initializer = import_results[-1]
        import_results.pop()
    else:
        initializer = None

    inputs = import_results[0:len(input_tensor_names)]
    outputs = import_results[len(input_tensor_names):]

    with graph.as_default():
        # TODO(b/153499219): Find a way to reflect the nested parameter and result
        # structure here after relaxing the restrictions.
        if inputs:
            assert len(inputs) < 2
            input_dict = {
                'parameter':
                tf.compat.v1.saved_model.utils.build_tensor_info(inputs[0])
            }
        else:
            input_dict = {}
        assert len(outputs) == 1
        output_dict = {
            'result':
            tf.compat.v1.saved_model.utils.build_tensor_info(outputs[0])
        }
        sig_def = tf.compat.v1.saved_model.signature_def_utils.build_signature_def(
            inputs=input_dict, outputs=output_dict, method_name=name)
        with tempfile.TemporaryDirectory() as model_dir:
            builder = tf.compat.v1.saved_model.Builder(model_dir)
            with tf.compat.v1.Session(graph=graph) as sess:
                builder.add_meta_graph_and_variables(
                    sess, ['unused'],
                    signature_def_map={name: sig_def},
                    legacy_init_op=initializer,
                    strip_default_attrs=True)
                builder.save()
            iree_module = iree.compiler.tf.compile_saved_model(
                model_dir,
                import_type='SIGNATURE_DEF',
                import_only=True,
                saved_model_tags=set(['unused']),
                exported_names=[name])
            return computation_module.ComputationModule(
                iree_module, name, type_spec)
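
For orientation, here is a hedged usage sketch. It assumes the serialized `pb.Computation` can be pulled out of a `tff.tf_computation`-decorated function via TFF's internal `computation_impl.ComputationImpl.get_proto` helper (an internal API that has moved between releases), which is roughly how the backend's own tests drove this function.

import tensorflow as tf
import tensorflow_federated as tff

# Assumed internal path; `ComputationImpl.get_proto` has moved across TFF
# releases, so treat this as illustrative rather than current API.
from tensorflow_federated.python.core.impl import computation_impl


@tff.tf_computation(tf.float32)
def add_one(x):
    return x + 1.0


# Serialize the TFF computation to its `pb.Computation` proto and import it.
comp_proto = computation_impl.ComputationImpl.get_proto(add_one)
module = import_tensorflow_computation(comp_proto, name='fn')
print(module.function_name)        # 'fn'
print(str(module.type_signature))  # '(float32 -> float32)'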