Example #1
 def test_add_control_deps_for_init_op(self):
   # Creates a graph (double edges are regular dependencies, single edges are
   # control dependencies) like this:
   #
   #  ghi
   #   |
   #  def
   #   ||
   #  def:0         foo
   #   ||        //     ||
   #  abc      bar      ||
   #     \   //   \\    ||
   #      bak        baz
   #
   graph_def = tf.compat.v1.GraphDef(node=[
       tf.compat.v1.NodeDef(name='foo', input=[]),
       tf.compat.v1.NodeDef(name='bar', input=['foo']),
       tf.compat.v1.NodeDef(name='baz', input=['foo', 'bar']),
       tf.compat.v1.NodeDef(name='bak', input=['bar', '^abc']),
       tf.compat.v1.NodeDef(name='abc', input=['def:0']),
       tf.compat.v1.NodeDef(name='def', input=['^ghi']),
       tf.compat.v1.NodeDef(name='ghi', input=[]),
   ])
   new_graph_def = tensorflow_utils.add_control_deps_for_init_op(
       graph_def, 'abc')
   self.assertEqual(
       ','.join('{}({})'.format(node.name, ','.join(node.input))
                for node in new_graph_def.node),
       'foo(^abc),bar(foo,^abc),baz(foo,bar,^abc),'
       'bak(bar,^abc),abc(def:0),def(^ghi),ghi()')
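The expected string in the assertion captures the rule being tested: every node that the init op ('abc') does not transitively depend on gains a '^abc' control input, while the init op and its own ancestors stay untouched so no cycle is created. Below is a minimal, illustrative re-implementation of that rule written for this page; it is a sketch to explain the test, not the library's actual code.

import tensorflow as tf


def add_init_control_deps_sketch(graph_def, init_op):
  """Sketch: add `^init_op` to every node the init op does not depend on."""
  init_name = init_op.split(':')[0]
  # Map each node to the node names it references (strip '^' and ':N').
  deps = {
      node.name: [i.lstrip('^').split(':')[0] for i in node.input]
      for node in graph_def.node
  }
  # Collect the init op and all of its transitive dependencies; adding a
  # control edge from any of these to the init op would create a cycle.
  ancestors, stack = set(), [init_name]
  while stack:
    name = stack.pop()
    if name not in ancestors:
      ancestors.add(name)
      stack.extend(deps.get(name, []))
  new_graph_def = tf.compat.v1.GraphDef()
  new_graph_def.CopyFrom(graph_def)
  control_input = '^' + init_name
  for node in new_graph_def.node:
    if node.name not in ancestors and control_input not in node.input:
      node.input.append(control_input)
  return new_graph_def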
Example #2
    def function_to_wrap():
        """No-arg function to import graph def.

    We pass a no-arg function to `tf.compat.v1.wrap_function` to avoid
    the leftover placeholders that can result from binding arguments to the
    imported graphdef via `input_map`. The correct signature will be added to
    this function later, via the `prune` call below.

    Returns:
      Result of importing graphdef backing `comp`.
    """
        graph_def = serialization_utils.unpack_graph_def(
            comp.tensorflow.graph_def)
        # TODO(b/159180073): clean raise after fixing dataset reduce.
        _check_dataset_reduce_in_multi_gpu(graph_def)

        init_op = comp.tensorflow.initialize_op
        if init_op:
            graph_def = tensorflow_utils.add_control_deps_for_init_op(
                graph_def, init_op)

        def _import_fn():
            return tf.import_graph_def(
                graph_merge.uniquify_shared_names(graph_def), name='')

        if must_pin_function_to_cpu:
            with tf.device('cpu'):
                return _import_fn()
        elif device is not None:
            with tf.device(device.name):
                return _import_fn()
        else:
            return _import_fn()
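For context, a hedged sketch of how a no-argument `function_to_wrap` like the one above is typically consumed: it is handed to `tf.compat.v1.wrap_function` with an empty signature, and the concrete feeds and fetches are attached afterwards via `prune`. The tensor names 'x:0' and 'y:0' below are placeholders, not names taken from the snippet.

import tensorflow as tf

# Import the graph with no bound arguments, so no leftover placeholders remain.
wrapped = tf.compat.v1.wrap_function(function_to_wrap, signature=[])
# `prune` then selects the actual inputs and outputs, yielding a callable with
# the intended signature.
pruned = wrapped.prune(
    feeds=wrapped.graph.get_tensor_by_name('x:0'),    # hypothetical input name
    fetches=wrapped.graph.get_tensor_by_name('y:0'))  # hypothetical output name
result = pruned(tf.constant(1.0))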
Example #3
    def function_to_wrap(*args):  # pylint: disable=missing-docstring
        if len(args) != len(input_tensor_names):
            raise RuntimeError('Expected {} arguments, found {}.'.format(
                len(input_tensor_names), len(args)))
        graph_def = serialization_utils.unpack_graph_def(
            comp.tensorflow.graph_def)
        init_op = comp.tensorflow.initialize_op
        if init_op:
            graph_def = tensorflow_utils.add_control_deps_for_init_op(
                graph_def, init_op)

        # Bind the wrapped function's arguments to the graph's input tensors
        # via `input_map`, and fetch the tensors named in `return_elements`.
        def _import_fn():
            return tf.import_graph_def(
                graph_merge.uniquify_shared_names(graph_def),
                input_map=dict(list(zip(input_tensor_names, args))),
                return_elements=output_tensor_names)

        if must_pin_function_to_cpu:
            with tf.device('cpu'):
                return _import_fn()
        elif device is not None:
            with tf.device(device):
                return _import_fn()
        else:
            return _import_fn()
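This variant takes positional tensor arguments and binds them to the imported graph through `input_map`, so it is wrapped with an explicit signature of `tf.TensorSpec`s matching `input_tensor_names`. A hedged usage sketch, assuming a single scalar float32 parameter:

import tensorflow as tf

# One TensorSpec per entry in `input_tensor_names` (shape and dtype are assumed
# here for illustration).
wrapped = tf.compat.v1.wrap_function(
    function_to_wrap,
    signature=[tf.TensorSpec(shape=[], dtype=tf.float32)])
# `return_elements` already names the outputs, so the wrapped function can be
# invoked directly with matching arguments.
outputs = wrapped(tf.constant(3.0))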
Example #4
def import_tensorflow_computation(comp, name='fn'):
    """Creates a `computation_module.ComputationModule` from a TF computation.

  WARNING: This helper function is under construction, and most capabilities are
  not implemented at this stage:

  * The parameter and result of `comp` can only be a single tensor. Named
    tuples, sequences, or functional types are not currently supported.

  * Only tensorflow code can be imported.

  TODO(b/153499219): Add support for named tuples, sequences, and functions.

  Args:
    comp: An instance of a `pb.Computation` with TensorFlow code to import.
    name: An optional `str` name of the (single) function in the IREE module.

  Returns:
    An instance of `Module` with the imported function present.

  Raises:
    TypeError: If arguments are of the wrong types, e.g., in `comp` is not a
      TensorFlow computation.
  """
    py_typecheck.check_type(comp, pb.Computation)
    type_spec = type_serialization.deserialize_type(comp.type)
    if not type_spec.is_function():
        type_spec = computation_types.FunctionType(None, type_spec)

    # TODO(b/153499219): Replace this with a recursive check of the signature
    # after relaxing the type restrictions and introducing nested structures.
    py_typecheck.check_type(type_spec.result, computation_types.TensorType)
    if type_spec.parameter is not None:
        py_typecheck.check_type(type_spec.parameter,
                                computation_types.TensorType)

    which_computation = comp.WhichOneof('computation')
    if which_computation != 'tensorflow':
        raise TypeError('Expected a TensorFlow computation, found {}.'.format(
            which_computation))

    output_tensor_names = tensorflow_utils.extract_tensor_names_from_binding(
        comp.tensorflow.result)
    if type_spec.parameter is not None:
        input_tensor_names = tensorflow_utils.extract_tensor_names_from_binding(
            comp.tensorflow.parameter)
    else:
        input_tensor_names = []

    graph_def = serialization_utils.unpack_graph_def(comp.tensorflow.graph_def)
    init_op = comp.tensorflow.initialize_op
    return_elements = input_tensor_names + output_tensor_names
    if init_op:
        graph_def = tensorflow_utils.add_control_deps_for_init_op(
            graph_def, init_op)
        return_elements.append(init_op)

    with tf.Graph().as_default() as graph:
        # TODO(b/153499219): See if we can reintroduce uniquify_shared_names().
        # Right now, it causes loader breakage, and unclear if still necessary.
        import_results = tf.graph_util.import_graph_def(
            graph_def, input_map={}, return_elements=return_elements, name='')

    if init_op:
        initializer = import_results[-1]
        import_results.pop()
    else:
        initializer = None

    inputs = import_results[0:len(input_tensor_names)]
    outputs = import_results[len(input_tensor_names):]

    with graph.as_default():
        # TODO(b/153499219): Find a way to reflect the nested parameter and result
        # structure here after relaxing the restrictions.
        if inputs:
            assert len(inputs) < 2
            input_dict = {
                'parameter':
                tf.compat.v1.saved_model.utils.build_tensor_info(inputs[0])
            }
        else:
            input_dict = {}
        assert len(outputs) == 1
        output_dict = {
            'result':
            tf.compat.v1.saved_model.utils.build_tensor_info(outputs[0])
        }
        sig_def = tf.compat.v1.saved_model.signature_def_utils.build_signature_def(
            inputs=input_dict, outputs=output_dict, method_name=name)
        with tempfile.TemporaryDirectory() as model_dir:
            builder = tf.compat.v1.saved_model.Builder(model_dir)
            with tf.compat.v1.Session(graph=graph) as sess:
                builder.add_meta_graph_and_variables(
                    sess, ['unused'],
                    signature_def_map={name: sig_def},
                    legacy_init_op=initializer,
                    strip_default_attrs=True)
                builder.save()
            iree_module = iree.compiler.tf.compile_saved_model(
                model_dir,
                import_type='SIGNATURE_DEF',
                import_only=True,
                saved_model_tags=set(['unused']),
                exported_names=[name])
            return computation_module.ComputationModule(
                iree_module, name, type_spec)
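A hedged end-to-end sketch of how this helper might be driven from a `tff.tf_computation`. The import path and the use of `computation_impl.ComputationImpl.get_proto` to obtain the underlying `pb.Computation` are assumptions about TFF's internal API at the time and may differ across versions.

import tensorflow as tf
import tensorflow_federated as tff
# Assumed internal module path; it has moved between TFF releases.
from tensorflow_federated.python.core.impl import computation_impl


@tff.tf_computation(tf.float32)
def add_one(x):
  return x + 1.0

# Extract the serialized `pb.Computation` (assumed internal API) and import it
# into an IREE-backed computation module.
comp_proto = computation_impl.ComputationImpl.get_proto(add_one)
module = import_tensorflow_computation(comp_proto, name='add_one')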