예제 #1
0
    def test_serialize_tensorflow_with_dataset_not_optimized(self):
        """A `tf.function` dataset reduction serializes without optimization ops.

        Serializes a reducer over an int64 sequence, checks its type signature,
        asserts the resulting graph contains neither `OptimizeDataset` nor
        `ModelDataset` ops, and runs it on `range(5)` expecting 10.
        """

        @tf.function
        def test_foo(ds):
            return ds.reduce(np.int64(0), lambda x, y: x + y)

        def legacy_dataset_reducer_example(ds):
            return test_foo(ds)

        comp, extra_type_spec = tensorflow_serialization.serialize_py_fn_as_tf_computation(
            legacy_dataset_reducer_example,
            computation_types.SequenceType(tf.int64),
            context_stack_impl.context_stack)
        self.assertEqual(str(type_serialization.deserialize_type(comp.type)),
                         '(int64* -> int64)')
        self.assertEqual(str(extra_type_spec), '(int64* -> int64)')
        self.assertEqual(comp.WhichOneof('computation'), 'tensorflow')
        parameter = tf.data.Dataset.range(5)

        graph_def = serialization_utils.unpack_graph_def(
            comp.tensorflow.graph_def)
        # These must be exact TF op type names: a misspelled name can never
        # match, so the check for it would pass vacuously.
        self.assertGraphDoesNotContainOps(graph_def,
                                          ['OptimizeDataset', 'ModelDataset'])
        # Use the session as a context manager so it is always closed.
        with tf.compat.v1.Session() as sess:
            results = sess.run(
                tf.import_graph_def(
                    graph_def, {
                        comp.tensorflow.parameter.sequence.variant_tensor_name:
                        tf.data.experimental.to_variant(parameter)
                    }, [comp.tensorflow.result.tensor.tensor_name]))
        self.assertEqual(results, [10])
예제 #2
0
    def test_serialize_tensorflow_with_table_no_variables(self):
        """A static vocabulary-table lookup serializes and runs in a session.

        Looks up ['b', 'c', 'a'] against the keys ['a', 'b', 'c'] (with one
        OOV bucket) and expects ids [1, 2, 0].
        """

        def table_lookup(word):
            initializer = tf.lookup.KeyValueTensorInitializer(
                ['a', 'b', 'c'], np.arange(3, dtype=np.int64))
            vocab_table = tf.lookup.StaticVocabularyTable(
                initializer, num_oov_buckets=1)
            return vocab_table.lookup(word)

        string_vector = computation_types.TensorType(
            dtype=tf.string, shape=(None, ))
        comp, extra_type_spec = (
            tensorflow_serialization.serialize_py_fn_as_tf_computation(
                table_lookup, string_vector, context_stack_impl.context_stack))
        signature = '(string[?] -> int64[?])'
        self.assertEqual(str(type_serialization.deserialize_type(comp.type)),
                         signature)
        self.assertEqual(str(extra_type_spec), signature)
        self.assertEqual(comp.WhichOneof('computation'), 'tensorflow')

        with tf.Graph().as_default() as graph:
            graph_def = serialization_utils.unpack_graph_def(
                comp.tensorflow.graph_def)
            tf.import_graph_def(graph_def, name='')
        with tf.compat.v1.Session(graph=graph) as sess:
            sess.run(fetches=comp.tensorflow.initialize_op)
            feed = {comp.tensorflow.parameter.tensor.tensor_name: ['b', 'c', 'a']}
            results = sess.run(
                fetches=comp.tensorflow.result.tensor.tensor_name,
                feed_dict=feed)
        self.assertAllEqual(results, [1, 2, 0])
예제 #3
0
    def test_returns_string_for_comp_with_left_overhang(self):
        """Checks compact, formatted and structural renderings of a nested call."""
        ref_type = computation_types.FunctionType(tf.int32, tf.int32)
        ref = computation_building_blocks.Reference('a', ref_type)
        tf_proto, _ = tensorflow_serialization.serialize_py_fn_as_tf_computation(
            lambda: tf.constant(1), None, context_stack_impl.context_stack)
        inner_call = computation_building_blocks.Call(
            computation_building_blocks.CompiledComputation(tf_proto, 'bbbbb'))
        outer_call = computation_building_blocks.Call(ref, inner_call)

        expected_oneline = 'a(comp#bbbbb())'
        self.assertEqual(
            computation_building_blocks.compact_representation(outer_call),
            expected_oneline)
        self.assertEqual(
            computation_building_blocks.formatted_representation(outer_call),
            expected_oneline)
        # pyformat: disable
        expected_tree = ('           Call\n'
                         '          /    \\\n'
                         '    Ref(a)      Call\n'
                         '               /\n'
                         'Compiled(bbbbb)')
        # pyformat: enable
        self.assertEqual(
            computation_building_blocks.structural_representation(outer_call),
            expected_tree)
 def test_serialize_tensorflow_with_structured_type_signature(self):
     """Named-tuple parameter and result containers survive serialization."""
     batch_type = collections.namedtuple('BatchType', ['x', 'y'])
     output_type = collections.namedtuple('OutputType', ['A', 'B'])
     param_spec = batch_type(tf.int32, (tf.float32, [2]))
     comp, extra_type_spec = (
         tensorflow_serialization.serialize_py_fn_as_tf_computation(
             lambda z: output_type(2.0 * tf.cast(z.x, tf.float32), 3.0 * z.y),
             param_spec, context_stack_impl.context_stack))
     expected_signature = (
         '(<x=int32,y=float32[2]> -> <A=float32,B=float32[2]>)')
     self.assertEqual(
         str(type_serialization.deserialize_type(comp.type)),
         expected_signature)
     self.assertEqual(comp.WhichOneof('computation'), 'tensorflow')
     self.assertEqual(str(extra_type_spec), expected_signature)
     container_cls = computation_types.NamedTupleTypeWithPyContainerType
     # Parameter first, then result, each checked for kind and container.
     for part, expected_container in (
         (extra_type_spec.parameter, batch_type),
         (extra_type_spec.result, output_type)):
         self.assertIsInstance(part, container_cls)
         self.assertIs(
             container_cls.get_container_type(part), expected_container)
예제 #5
0
 def test_serialize_tensorflow_with_no_parameter(self):
   """Serializes a no-argument computation and checks it evaluates to 99."""
   comp = tensorflow_serialization.serialize_py_fn_as_tf_computation(
       lambda: tf.constant(99), None, context_stack_impl.context_stack)
   self.assertEqual(
       str(type_serialization.deserialize_type(comp.type)), '( -> int32)')
   self.assertEqual(comp.WhichOneof('computation'), 'tensorflow')
   # Use the session as a context manager so it is closed even on failure.
   with tf.Session() as sess:
     results = sess.run(
         tf.import_graph_def(comp.tensorflow.graph_def, None,
                             [comp.tensorflow.result.tensor.tensor_name]))
   self.assertEqual(results, [99])
 def test_returns_string_for_compiled_computation(self):
     """A CompiledComputation named 'a' renders as `comp#a` / `Compiled(a)`."""
     proto, _ = tensorflow_serialization.serialize_py_fn_as_tf_computation(
         lambda: tf.constant(1), None, context_stack_impl.context_stack)
     compiled = computation_building_blocks.CompiledComputation(proto, 'a')
     self.assertEqual(compiled.compact_representation(), 'comp#a')
     self.assertEqual(compiled.formatted_representation(), 'comp#a')
     self.assertEqual(compiled.structural_representation(), 'Compiled(a)')
예제 #7
0
    def test_replace_compiled_computations_names_replaces_name(self):
        """The renaming transformation changes a CompiledComputation's `_name`."""
        original = computation_building_blocks.CompiledComputation(
            tensorflow_serialization.serialize_py_fn_as_tf_computation(
                lambda: tf.constant(1), None,
                context_stack_impl.context_stack))
        renamed = (
            transformations.replace_compiled_computations_names_with_unique_names(
                original))
        self.assertNotEqual(renamed._name, original._name)
 def test_serialize_tensorflow_with_simple_add_three_lambda(self):
   """Serializes `x + 3` over int32 and checks it maps 1000 to 1003."""
   comp = tensorflow_serialization.serialize_py_fn_as_tf_computation(
       lambda x: x + 3, tf.int32, context_stack_impl.context_stack)
   self.assertEqual(
       str(type_serialization.deserialize_type(comp.type)), '(int32 -> int32)')
   self.assertEqual(comp.WhichOneof('computation'), 'tensorflow')
   parameter = tf.constant(1000)
   # Use the session as a context manager so it is closed even on failure.
   with tf.Session() as sess:
     results = sess.run(
         tf.import_graph_def(
             serialization_utils.unpack_graph_def(comp.tensorflow.graph_def),
             {comp.tensorflow.parameter.tensor.tensor_name: parameter},
             [comp.tensorflow.result.tensor.tensor_name]))
   self.assertEqual(results, [1003])
예제 #9
0
 def test_deserialize_and_call_tf_computation_with_add_one(self):
   """Deserializes an add-one computation, applies it to 10 and expects 11."""
   add_one, _ = tensorflow_serialization.serialize_py_fn_as_tf_computation(
       lambda x: tf.add(x, 1, name='the_add'), tf.int32,
       context_stack_impl.context_stack)
   ten = tf.constant(10, name='the_ten')
   init_op, result = (
       tensorflow_deserialization.deserialize_and_call_tf_computation(
           add_one, ten, tf.get_default_graph()))
   self.assertTrue(tf.contrib.framework.is_tensor(result))
   with tf.Session() as sess:
     # Only run an initializer when deserialization produced one.
     if init_op:
       sess.run(init_op)
     result_val = sess.run(result)
   self.assertEqual(result_val, 11)
def _tf_wrapper_fn(target_fn, parameter_type, name=None):
    """Builds a TFF computation from a Python function containing TF logic.

    Args:
      target_fn: The Python callable to serialize as a TensorFlow computation.
      parameter_type: The TFF type of the parameter of `target_fn`.
      name: Ignored.

    Returns:
      A `computation_impl.ComputationImpl` wrapping the serialized proto.

    Raises:
      TypeError: If `type_utils.check_tf_comp_whitelisted` rejects
        `parameter_type`.
    """
    del name  # Unused.
    if type_utils.check_tf_comp_whitelisted(parameter_type):
        context_stack = context_stack_impl.context_stack
        serialized = tensorflow_serialization.serialize_py_fn_as_tf_computation(
            target_fn, parameter_type, context_stack)
        return computation_impl.ComputationImpl(serialized, context_stack)
    raise TypeError(
        '`tf_computation`s can accept only parameter types with '
        'constituents `SequenceType`, `NamedTupleType` '
        'and `TensorType`; you have attempted to create one '
        'with the type {}.'.format(parameter_type))
예제 #11
0
    def test_fetch_value_with_nested_datasets(self):
        """Both datasets of a returned two-element list are fetched intact."""

        def return_two_datasets():
            return [tf.data.Dataset.range(5), tf.data.Dataset.range(5)]

        stack = context_stack_impl.context_stack
        serialized = tensorflow_serialization.serialize_py_fn_as_tf_computation(
            return_two_datasets, None, stack)[0]
        fetched = computation_impl.ComputationImpl(serialized, stack)()
        expected = list(range(5))
        for position in (0, 1):
            self.assertEqual(list(iter(fetched[position])), expected)
예제 #12
0
    def test_fetch_value_with_empty_dataset_and_tensors(self):
        """An empty batched dataset is fetched alongside a float tensor."""

        def return_dataset():
            source = tf.data.Dataset.from_tensor_slices([[1, 1], [1, 1]])
            return [tf.constant([0., 0.]), source.batch(5).take(0)]

        stack = context_stack_impl.context_stack
        executable = computation_impl.ComputationImpl(
            tensorflow_serialization.serialize_py_fn_as_tf_computation(
                return_dataset, None, stack)[0], stack)

        fetched = executable()
        for index in range(2):
            self.assertEqual(fetched[0][index], 0.)
        empty_batch = np.zeros([0, 2], dtype=np.int32)
        self.assertEqual(str(fetched[1][0]), str(empty_batch))
def _tf_wrapper_fn(target_fn, parameter_type, unpack, name=None):
  """Plugs a Python function containing TensorFlow logic into TFF.

  Args:
    target_fn: The Python callable to serialize.
    parameter_type: The TFF type of the computation's parameter.
    unpack: Forwarded to `function_utils.wrap_as_zero_or_one_arg_callable`.
    name: Ignored.

  Returns:
    A `computation_impl.ComputationImpl` built from the serialized proto and
    the extra type spec returned by the serializer.

  Raises:
    TypeError: If `type_utils.is_tensorflow_compatible_type` rejects
      `parameter_type`.
  """
  del name  # Unused.
  wrapped_fn = function_utils.wrap_as_zero_or_one_arg_callable(
      target_fn, parameter_type, unpack)
  if type_utils.is_tensorflow_compatible_type(parameter_type):
    context_stack = context_stack_impl.context_stack
    comp_pb, extra_type_spec = (
        tensorflow_serialization.serialize_py_fn_as_tf_computation(
            wrapped_fn, parameter_type, context_stack))
    return computation_impl.ComputationImpl(
        comp_pb, context_stack, extra_type_spec)
  raise TypeError('`tf_computation`s can accept only parameter types with '
                  'constituents `SequenceType`, `NamedTupleType` '
                  'and `TensorType`; you have attempted to create one '
                  'with the type {}.'.format(parameter_type))
예제 #14
0
    def test_fetch_value_with_dataset_and_tensor(self):
        """A dataset between two scalar tensors is fetched with both tensors."""

        def return_dataset_and_tensor():
            return [tf.constant(0), tf.data.Dataset.range(5), tf.constant(5)]

        executable_return_dataset_and_tensor = computation_impl.ComputationImpl(
            tensorflow_serialization.serialize_py_fn_as_tf_computation(
                return_dataset_and_tensor, None,
                context_stack_impl.context_stack)[0],
            context_stack_impl.context_stack)

        fetched = executable_return_dataset_and_tensor()
        self.assertEqual(fetched[0], 0)
        # Use a distinct comprehension variable so it cannot be confused with
        # (or, under leaky scoping, clobber) the fetched result.
        self.assertEqual([elem for elem in iter(fetched[1])], list(range(5)))
        self.assertEqual(fetched[2], 5)
 def test_basic_functionality_of_compiled_computation_class(self):
     """Checks type, proto, repr, tff_repr and round-trip of CompiledComputation."""
     proto = tensorflow_serialization.serialize_py_fn_as_tf_computation(
         lambda x: x + 3, tf.int32, context_stack_impl.context_stack)
     unnamed = computation_building_blocks.CompiledComputation(proto)
     self.assertEqual(str(unnamed.type_signature), '(int32 -> int32)')
     self.assertEqual(str(unnamed.proto), str(proto))
     repr_pattern = (r'CompiledComputation\([0-9a-f]+, '
                     r'FunctionType\(TensorType\(tf\.int32\), '
                     r'TensorType\(tf\.int32\)\)\)')
     self.assertTrue(re.match(repr_pattern, repr(unnamed)))
     # Without an explicit name, a hex identifier is generated.
     self.assertTrue(re.match(r'comp#[0-9a-f]+', unnamed.tff_repr))
     named = computation_building_blocks.CompiledComputation(proto, name='foo')
     self.assertEqual(named.tff_repr, 'comp#foo')
     self._serialize_deserialize_roundtrip_test(unnamed)
예제 #16
0
    def test_fetch_value_with_empty_dataset_and_tensors(self):
        """Fetching an empty batched dataset preserves its element spec."""

        def return_dataset():
            source = tf.data.Dataset.from_tensor_slices([[1, 1], [1, 1]])
            return [tf.constant([0., 0.]), source.batch(5).take(0)]

        stack = context_stack_impl.context_stack
        serialized = tensorflow_serialization.serialize_py_fn_as_tf_computation(
            return_dataset, None, stack)[0]
        fetched = computation_impl.ComputationImpl(serialized, stack)()

        self.assertAllEqual(fetched[0], [0., 0.])
        expected_spec = tf.TensorSpec(shape=(None, 2), dtype=tf.int32)
        self.assertEqual(fetched[1].element_spec, expected_spec)
        # `take(0)` leaves the dataset empty, so iteration stops immediately.
        with self.assertRaises(StopIteration):
            next(iter(fetched[1]))
예제 #17
0
  def test_fetch_value_with_datasets_nested_at_second_level(self):
    """Datasets nested one level down in the result are fetched as well."""

    def return_two_datasets():
      return [
          tf.constant(0),
          [tf.data.Dataset.range(5), tf.data.Dataset.range(5)],
      ]

    stack = context_stack_impl.context_stack
    executable = computation_impl.ComputationImpl(
        tensorflow_serialization.serialize_py_fn_as_tf_computation(
            return_two_datasets, None, stack)[0], stack)

    fetched = executable()
    self.assertEqual(fetched[0], 0)
    expected = list(range(5))
    for index in range(2):
      self.assertEqual(fetched[1][index], expected)
예제 #18
0
def _wrap_constant_as_value(const, context_stack):
    """Wraps the given Python constant as a `tff.Value`.

    The constant is captured by a no-argument TensorFlow computation that
    returns `tf.constant(const)`, and the value is a call to that computation.

    Args:
      const: Python constant to be converted to a TFF value. Anything
        convertible to a Tensor via `tf.constant` can be passed in.
      context_stack: The context stack to use.

    Returns:
      An instance of `value_base.Value`.
    """
    py_typecheck.check_type(context_stack, context_stack_base.ContextStack)
    serialized = tensorflow_serialization.serialize_py_fn_as_tf_computation(
        lambda: tf.constant(const), None, context_stack)
    no_arg_call = computation_building_blocks.Call(
        computation_building_blocks.CompiledComputation(serialized))
    return ValueImpl(no_arg_call, context_stack)
예제 #19
0
 def test_returns_string_for_call_with_no_arg(self):
     """A no-arg call to `comp#a` renders as `comp#a()` and as a small tree."""
     proto, _ = tensorflow_serialization.serialize_py_fn_as_tf_computation(
         lambda: tf.constant(1), None, context_stack_impl.context_stack)
     call = computation_building_blocks.Call(
         computation_building_blocks.CompiledComputation(proto, 'a'))
     # Compact and formatted renderings agree for this small computation.
     for render in (computation_building_blocks.compact_representation,
                    computation_building_blocks.formatted_representation):
         self.assertEqual(render(call), 'comp#a()')
     # pyformat: disable
     expected_tree = ('            Call\n'
                      '           /\n'
                      'Compiled(a)')
     # pyformat: enable
     self.assertEqual(
         computation_building_blocks.structural_representation(call),
         expected_tree)
예제 #20
0
    def test_fetch_value_with_empty_structured_dataset_and_tensors(self):
        """An empty dataset of OrderedDict elements is fetched with a tensor."""

        def return_dataset():
            source = tf.data.Dataset.from_tensor_slices(
                collections.OrderedDict([('a', [1, 1]), ('b', [1, 1])]))
            return [tf.constant([0., 0.]), source.batch(5).take(0)]

        stack = context_stack_impl.context_stack
        executable = computation_impl.ComputationImpl(
            tensorflow_serialization.serialize_py_fn_as_tf_computation(
                return_dataset, None, stack)[0], stack)

        fetched = executable()
        for index in range(2):
            self.assertEqual(fetched[0][index], 0.)
        empty_column = np.zeros([0], dtype=np.int32)
        self.assertTrue(np.array_equal(fetched[1][0].a, empty_column))
        self.assertTrue(np.array_equal(fetched[1][0].b, empty_column))
예제 #21
0
def _create_call_to_py_fn(fn):
    r"""Builds a no-argument `Call` to a TensorFlow-serialized Python function.

    The resulting structure is:

             Call
            /
      Compiled
      Computation

    Args:
      fn: The Python function to wrap. It is serialized with a `None`
        parameter type.

    Returns:
      An instance of `computation_building_blocks.Call` wrapping the Python
      function.
    """
    compiled = computation_building_blocks.CompiledComputation(
        tensorflow_serialization.serialize_py_fn_as_tf_computation(
            fn, None, context_stack_impl.context_stack))
    return computation_building_blocks.Call(compiled)
예제 #22
0
  def test_serialize_tensorflow_with_data_set_sum_lambda(self):
    """Serializes a dataset sum-reduction and checks range(5) sums to 10."""

    def _legacy_dataset_reducer_example(ds):
      return ds.reduce(np.int64(0), lambda x, y: x + y)

    comp = tensorflow_serialization.serialize_py_fn_as_tf_computation(
        _legacy_dataset_reducer_example, computation_types.SequenceType(
            tf.int64), context_stack_impl.context_stack)
    self.assertEqual(
        str(type_serialization.deserialize_type(comp.type)),
        '(int64* -> int64)')
    self.assertEqual(comp.WhichOneof('computation'), 'tensorflow')
    parameter = tf.data.Dataset.range(5)
    # Use the session as a context manager so it is closed even on failure.
    with tf.Session() as sess:
      results = sess.run(
          tf.import_graph_def(
              comp.tensorflow.graph_def, {
                  comp.tensorflow.parameter.sequence.iterator_string_handle_name:
                      (parameter.make_one_shot_iterator().string_handle())
              }, [comp.tensorflow.result.tensor.tensor_name]))
    self.assertEqual(results, [10])
예제 #23
0
def _tf_wrapper_fn(target_fn, parameter_type, unpack, name=None):
    """Wrapper function to plug TensorFlow logic into the TFF framework.

    This function is passed through `computation_wrapper.ComputationWrapper`.
    Documentation of its arguments can be found inside the definition of that
    class.

    Raises:
      TypeError: If `type_analysis.is_tensorflow_compatible_type` rejects
        `parameter_type`.
    """
    del name  # Unused.
    wrapped_fn = function_utils.wrap_as_zero_or_one_arg_callable(
        target_fn, parameter_type, unpack)
    if type_analysis.is_tensorflow_compatible_type(parameter_type):
        context_stack = context_stack_impl.context_stack
        comp_pb, extra_type_spec = (
            tensorflow_serialization.serialize_py_fn_as_tf_computation(
                wrapped_fn, parameter_type, context_stack))
        return computation_impl.ComputationImpl(comp_pb, context_stack,
                                                extra_type_spec)
    raise TypeError(
        '`tf_computation`s can accept only parameter types with '
        'constituents `SequenceType`, `StructType` '
        'and `TensorType`; you have attempted to create one '
        'with the type {}.'.format(parameter_type))
예제 #24
0
    def test_fetch_value_with_empty_structured_dataset_and_tensors(self):
        """Fetching an empty dict-structured dataset preserves its structure."""

        def return_dataset():
            source = tf.data.Dataset.from_tensor_slices(
                collections.OrderedDict([('a', [1, 1]), ('b', [1, 1])]))
            return [tf.constant([0., 0.]), source.batch(5).take(0)]

        stack = context_stack_impl.context_stack
        executable = computation_impl.ComputationImpl(
            tensorflow_serialization.serialize_py_fn_as_tf_computation(
                return_dataset, None, stack)[0], stack)

        fetched = executable()
        self.assertAllEqual(fetched[0], [0., 0.])
        column_spec = tf.TensorSpec(shape=(None, ), dtype=tf.int32)
        expected_structure = collections.OrderedDict([
            ('a', column_spec),
            ('b', column_spec),
        ])
        self.assertEqual(
            tf.data.experimental.get_structure(fetched[1]), expected_structure)
        # `take(0)` leaves the dataset empty, so iteration stops immediately.
        with self.assertRaises(StopIteration):
            next(iter(fetched[1]))
    def test_serialize_tensorflow_with_data_set_sum_lambda(self):
        """Serializes a dataset sum-reduction and checks range(5) sums to 10."""

        def _legacy_dataset_reducer_example(ds):
            return ds.reduce(np.int64(0), lambda x, y: x + y)

        comp, extra_type_spec = tensorflow_serialization.serialize_py_fn_as_tf_computation(
            _legacy_dataset_reducer_example,
            computation_types.SequenceType(tf.int64),
            context_stack_impl.context_stack)
        self.assertEqual(str(type_serialization.deserialize_type(comp.type)),
                         '(int64* -> int64)')
        self.assertEqual(str(extra_type_spec), '(int64* -> int64)')
        self.assertEqual(comp.WhichOneof('computation'), 'tensorflow')
        parameter = tf.data.Dataset.range(5)
        # Use the session as a context manager so it is closed even on failure.
        with tf.compat.v1.Session() as sess:
            results = sess.run(
                tf.import_graph_def(
                    serialization_utils.unpack_graph_def(
                        comp.tensorflow.graph_def), {
                            comp.tensorflow.parameter.sequence.variant_tensor_name:
                            tf.data.experimental.to_variant(parameter)
                        }, [comp.tensorflow.result.tensor.tensor_name]))
        self.assertEqual(results, [10])
예제 #26
0
def _wrap_sequence_as_value(elements, element_type, context_stack):
    """Wraps `elements` as a TFF sequence with elements of type `element_type`.

    Args:
      elements: Python list of values to be wrapped as a TFF sequence value.
      element_type: An instance of `Type` that determines the type of the
        elements of the sequence.
      context_stack: The context stack to use.

    Returns:
      An instance of `tff.Value` backed by a no-argument TensorFlow
      computation that builds the dataset.

    Raises:
      TypeError: If `elements` and `element_type` are of incompatible types.
    """
    # TODO(b/113116813): Add support for other representations of sequences.
    py_typecheck.check_type(elements, list)
    py_typecheck.check_type(context_stack, context_stack_base.ContextStack)

    # Each element's inferred type must be assignable to the requested element
    # type of the sequence as a whole.
    for item in elements:
        inferred_type = type_utils.infer_type(item)
        if type_utils.is_assignable_from(element_type, inferred_type):
            continue
        raise TypeError(
            'Expected all sequence elements to be {}, found {}.'.format(
                str(element_type), str(inferred_type)))

    def _create_dataset_from_elements():
        # Builds the dataset in whichever graph is the default at the time the
        # serializer invokes this function.
        return graph_utils.make_data_set_from_elements(tf.get_default_graph(),
                                                       elements, element_type)

    compiled = computation_building_blocks.CompiledComputation(
        tensorflow_serialization.serialize_py_fn_as_tf_computation(
            _create_dataset_from_elements, None, context_stack))
    return ValueImpl(computation_building_blocks.Call(compiled), context_stack)
예제 #27
0
def _create_lambda_to_cast(dtype1, dtype2):
    r"""Creates a computation that TensorFlow-casts from `dtype1` to `dtype2`.

    Lambda
          \
           Call
          /    \
    Compiled   Reference
    Computation

    Here `CompiledComputation` is a TensorFlow computation casting from
    `dtype1` to `dtype2`.

    The `dtype` arguments can be either instances of `tf.dtypes.DType` or
    `computation_types.TensorType`; in the latter case, the `tf.dtypes.DType`
    of the tensor type is extracted.

    Args:
      dtype1: The type of the argument.
      dtype2: The type to cast the argument to.

    Returns:
      An instance of `computation_building_blocks.Lambda` wrapping a function
      that casts TensorFlow `dtype1` to `dtype2`.
    """
    source_dtype = (dtype1.dtype
                    if isinstance(dtype1, computation_types.TensorType)
                    else dtype1)
    target_dtype = (dtype2.dtype
                    if isinstance(dtype2, computation_types.TensorType)
                    else dtype2)
    py_typecheck.check_type(source_dtype, tf.dtypes.DType)
    py_typecheck.check_type(target_dtype, tf.dtypes.DType)
    param_ref = computation_building_blocks.Reference('arg', source_dtype)
    cast_proto = tensorflow_serialization.serialize_py_fn_as_tf_computation(
        lambda x: tf.cast(x, target_dtype), source_dtype,
        context_stack_impl.context_stack)
    cast_call = computation_building_blocks.Call(
        computation_building_blocks.CompiledComputation(cast_proto), param_ref)
    return computation_building_blocks.Lambda(param_ref.name, source_dtype,
                                              cast_call)
예제 #28
0
    def test_replace_compiled_computations_names_replaces_multiple_names(self):
        """Each of ten CompiledComputations receives a new and distinct name."""
        originals = [
            computation_building_blocks.CompiledComputation(
                tensorflow_serialization.serialize_py_fn_as_tf_computation(
                    lambda: tf.constant(1), None,
                    context_stack_impl.context_stack))
            for _ in range(10)
        ]
        tup = computation_building_blocks.Tuple(originals)

        transformed = (
            transformations.replace_compiled_computations_names_with_unique_names(
                tup))

        old_names = [element._name for element in tup]
        new_names = [element._name for element in transformed]
        self.assertNotEqual(new_names, old_names)
        self.assertEqual(
            len(new_names), len(set(new_names)),
            'The transformed computation names are not unique: {}.'.format(
                new_names))
예제 #29
0
def _create_compiled_computation(py_fn, arg_type):
    """Serializes `py_fn` as TensorFlow and wraps it in a `CompiledComputation`.

    Args:
      py_fn: The Python function to serialize.
      arg_type: The parameter type handed to the serializer.

    Returns:
      A `building_blocks.CompiledComputation` built from the serialized proto;
      the extra type spec returned alongside it is discarded.
    """
    serialized_proto, _unused_type_spec = (
        tensorflow_serialization.serialize_py_fn_as_tf_computation(
            py_fn, arg_type, context_stack_impl.context_stack))
    return building_blocks.CompiledComputation(serialized_proto)