def test_basic_functionality_of_tuple_class(self):
     x = building_blocks.Reference('foo', tf.int32)
     y = building_blocks.Reference('bar', tf.bool)
     z = building_blocks.Tuple([x, ('y', y)])
     with self.assertRaises(ValueError):
         _ = building_blocks.Tuple([('', y)])
     self.assertIsInstance(z, anonymous_tuple.AnonymousTuple)
     self.assertEqual(str(z.type_signature), '<int32,y=bool>')
     self.assertEqual(
         repr(z),
         'Tuple([(None, Reference(\'foo\', TensorType(tf.int32))), (\'y\', '
         'Reference(\'bar\', TensorType(tf.bool)))])')
     self.assertEqual(z.compact_representation(), '<foo,y=bar>')
     self.assertEqual(dir(z), ['y'])
     self.assertIs(z.y, y)
     self.assertLen(z, 2)
     self.assertIs(z[0], x)
     self.assertIs(z[1], y)
     self.assertEqual(','.join(e.compact_representation() for e in iter(z)),
                      'foo,bar')
     z_proto = z.proto
     self.assertEqual(type_serialization.deserialize_type(z_proto.type),
                      z.type_signature)
     self.assertEqual(z_proto.WhichOneof('computation'), 'tuple')
     self.assertEqual([e.name for e in z_proto.tuple.element], ['', 'y'])
     self._serialize_deserialize_roundtrip_test(z)
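
The `_serialize_deserialize_roundtrip_test` helper used throughout these tests is not shown in these snippets. As a rough, hedged sketch only (not the actual helper), a proto round trip for a building block could look like the following, reusing the aliases above and the `ComputationBuildingBlock.from_proto` constructor:

# Hedged sketch of a building-block proto round trip; illustrative only.
z = building_blocks.Tuple([building_blocks.Reference('foo', tf.int32),
                           ('y', building_blocks.Reference('bar', tf.bool))])
z_proto = z.proto
z_rebuilt = building_blocks.ComputationBuildingBlock.from_proto(z_proto)
assert str(z_rebuilt.proto) == str(z_proto)
assert z_rebuilt.compact_representation() == '<foo,y=bar>'
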
    def test_returns_computation(self):
        proto = computation_factory.create_lambda_empty_tuple()

        self.assertIsInstance(proto, pb.Computation)
        actual_type = type_serialization.deserialize_type(proto.type)
        expected_type = computation_types.FunctionType(None, [])
        self.assertEqual(actual_type, expected_type)
 def test_basic_functionality_of_block_class(self):
     x = building_blocks.Block([
         ('x', building_blocks.Reference('arg', (tf.int32, tf.int32))),
         ('y',
          building_blocks.Selection(
              building_blocks.Reference('x',
                                        (tf.int32, tf.int32)), index=0))
     ], building_blocks.Reference('y', tf.int32))
     self.assertEqual(str(x.type_signature), 'int32')
     self.assertEqual([(k, v.compact_representation())
                       for k, v in x.locals], [('x', 'arg'), ('y', 'x[0]')])
     self.assertEqual(x.result.compact_representation(), 'y')
     self.assertEqual(
         repr(x), 'Block([(\'x\', Reference(\'arg\', '
         'NamedTupleType([TensorType(tf.int32), TensorType(tf.int32)]))), '
         '(\'y\', Selection(Reference(\'x\', '
         'NamedTupleType([TensorType(tf.int32), TensorType(tf.int32)])), '
         'index=0))], '
         'Reference(\'y\', TensorType(tf.int32)))')
     self.assertEqual(x.compact_representation(), '(let x=arg,y=x[0] in y)')
     x_proto = x.proto
     self.assertEqual(type_serialization.deserialize_type(x_proto.type),
                      x.type_signature)
     self.assertEqual(x_proto.WhichOneof('computation'), 'block')
     self.assertEqual(str(x_proto.block.result), str(x.result.proto))
     for idx, loc_proto in enumerate(x_proto.block.local):
         loc_name, loc_value = x.locals[idx]
         self.assertEqual(loc_proto.name, loc_name)
         self.assertEqual(str(loc_proto.value), str(loc_value.proto))
     self._serialize_deserialize_roundtrip_test(x)
 def test_intrinsic_class_succeeds_simple_federated_map(self):
     simple_function = computation_types.FunctionType(tf.int32, tf.float32)
     federated_arg = computation_types.FederatedType(
         simple_function.parameter, placement_literals.CLIENTS)
     federated_result = computation_types.FederatedType(
         simple_function.result, placement_literals.CLIENTS)
     federated_map_concrete_type = computation_types.FunctionType(
         [simple_function, federated_arg], federated_result)
     concrete_federated_map = building_blocks.Intrinsic(
         intrinsic_defs.FEDERATED_MAP.uri, federated_map_concrete_type)
     self.assertIsInstance(concrete_federated_map,
                           building_blocks.Intrinsic)
     self.assertEqual(
         str(concrete_federated_map.type_signature),
         '(<(int32 -> float32),{int32}@CLIENTS> -> {float32}@CLIENTS)')
     self.assertEqual(concrete_federated_map.uri, 'federated_map')
     self.assertEqual(concrete_federated_map.compact_representation(),
                      'federated_map')
     concrete_federated_map_proto = concrete_federated_map.proto
     self.assertEqual(
         type_serialization.deserialize_type(
             concrete_federated_map_proto.type),
         concrete_federated_map.type_signature)
     self.assertEqual(
         concrete_federated_map_proto.WhichOneof('computation'),
         'intrinsic')
     self.assertEqual(concrete_federated_map_proto.intrinsic.uri,
                      concrete_federated_map.uri)
     self._serialize_deserialize_roundtrip_test(concrete_federated_map)
  def test_serialize_tensorflow_with_table_no_variables(self):

    def table_lookup(word):
      table = tf.lookup.StaticVocabularyTable(
          tf.lookup.KeyValueTensorInitializer(['a', 'b', 'c'],
                                              np.arange(3, dtype=np.int64)),
          num_oov_buckets=1)
      return table.lookup(word)

    comp, extra_type_spec = tensorflow_serialization.serialize_py_fn_as_tf_computation(
        table_lookup,
        computation_types.TensorType(dtype=tf.string, shape=(None,)),
        context_stack_impl.context_stack)
    self.assertEqual(
        str(type_serialization.deserialize_type(comp.type)),
        '(string[?] -> int64[?])')
    self.assertEqual(str(extra_type_spec), '(string[?] -> int64[?])')
    self.assertEqual(comp.WhichOneof('computation'), 'tensorflow')

    with tf.Graph().as_default() as g:
      tf.import_graph_def(
          serialization_utils.unpack_graph_def(comp.tensorflow.graph_def),
          name='')
    with tf.compat.v1.Session(graph=g) as sess:
      sess.run(fetches=comp.tensorflow.initialize_op)
      results = sess.run(
          fetches=comp.tensorflow.result.tensor.tensor_name,
          feed_dict={
              comp.tensorflow.parameter.tensor.tensor_name: ['b', 'c', 'a']
          })
    self.assertAllEqual(results, [1, 2, 0])
 def test_basic_functionality_of_call_class(self):
     x = building_blocks.Reference(
         'foo', computation_types.FunctionType(tf.int32, tf.bool))
     y = building_blocks.Reference('bar', tf.int32)
     z = building_blocks.Call(x, y)
     self.assertEqual(str(z.type_signature), 'bool')
     self.assertIs(z.function, x)
     self.assertIs(z.argument, y)
     self.assertEqual(
         repr(z), 'Call(Reference(\'foo\', '
         'FunctionType(TensorType(tf.int32), TensorType(tf.bool))), '
         'Reference(\'bar\', TensorType(tf.int32)))')
     self.assertEqual(z.compact_representation(), 'foo(bar)')
     with self.assertRaises(TypeError):
         building_blocks.Call(x)
     w = building_blocks.Reference('bak', tf.float32)
     with self.assertRaises(TypeError):
         building_blocks.Call(x, w)
     z_proto = z.proto
     self.assertEqual(type_serialization.deserialize_type(z_proto.type),
                      z.type_signature)
     self.assertEqual(z_proto.WhichOneof('computation'), 'call')
     self.assertEqual(str(z_proto.call.function), str(x.proto))
     self.assertEqual(str(z_proto.call.argument), str(y.proto))
     self._serialize_deserialize_roundtrip_test(z)
def run_tensorflow(computation_proto, arg=None):
    """Runs a TensorFlow computation with argument `arg`.

  Args:
    computation_proto: An instance of `pb.Computation`.
    arg: The argument to invoke the computation with, or `None` if the
      computation does not specify a parameter type and does not expect one.

  Returns:
    The result of the computation.
  """
    with tf.Graph().as_default() as graph:
        type_signature = type_serialization.deserialize_type(
            computation_proto.type)
        if type_signature.parameter is not None:
            stamped_arg = _stamp_value_into_graph(arg,
                                                  type_signature.parameter,
                                                  graph)
        else:
            stamped_arg = None
        init_op, result = tensorflow_deserialization.deserialize_and_call_tf_computation(
            computation_proto, stamped_arg, graph)
    with tf.compat.v1.Session(graph=graph) as sess:
        if init_op:
            sess.run(init_op)
        result = tensorflow_utils.fetch_value_in_session(sess, result)
    return result
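
As an illustration only (not part of the original tests), the `run_tensorflow` helper above can be exercised with a factory-built identity computation, assuming the `tensorflow_computation_factory` and `computation_types` aliases used elsewhere in these snippets:

# Hedged usage sketch; note that in some TFF versions create_identity returns
# a (proto, type) pair rather than a bare proto.
identity_proto = tensorflow_computation_factory.create_identity(
    computation_types.TensorType(tf.int32))
assert run_tensorflow(identity_proto, 10) == 10
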
  def test_serialize_tensorflow_with_dataset_not_optimized(self):

    @tf.function
    def test_foo(ds):
      return ds.reduce(np.int64(0), lambda x, y: x + y)

    def legacy_dataset_reducer_example(ds):
      return test_foo(ds)

    comp, extra_type_spec = tensorflow_serialization.serialize_py_fn_as_tf_computation(
        legacy_dataset_reducer_example,
        computation_types.SequenceType(tf.int64),
        context_stack_impl.context_stack)
    self.assertEqual(
        str(type_serialization.deserialize_type(comp.type)),
        '(int64* -> int64)')
    self.assertEqual(str(extra_type_spec), '(int64* -> int64)')
    self.assertEqual(comp.WhichOneof('computation'), 'tensorflow')
    parameter = tf.data.Dataset.range(5)

    graph_def = serialization_utils.unpack_graph_def(comp.tensorflow.graph_def)
    self.assertGraphDoesNotContainOps(
        graph_def, ['OptimizeDataset', 'OptimizeDatasetV2', 'ModelDataset'])
    results = tf.compat.v1.Session().run(
        tf.import_graph_def(
            graph_def, {
                comp.tensorflow.parameter.sequence.variant_tensor_name:
                    tf.data.experimental.to_variant(parameter)
            }, [comp.tensorflow.result.tensor.tensor_name]))
    self.assertEqual(results, [10])
 def test_basic_functionality_of_lambda_class(self):
   arg_name = 'arg'
   arg_type = [('f', computation_types.FunctionType(tf.int32, tf.int32)),
               ('x', tf.int32)]
   arg = building_blocks.Reference(arg_name, arg_type)
   arg_f = building_blocks.Selection(arg, name='f')
   arg_x = building_blocks.Selection(arg, name='x')
   x = building_blocks.Lambda(
       arg_name, arg_type,
       building_blocks.Call(arg_f, building_blocks.Call(arg_f, arg_x)))
   self.assertEqual(
       str(x.type_signature), '(<f=(int32 -> int32),x=int32> -> int32)')
   self.assertEqual(x.parameter_name, arg_name)
   self.assertEqual(str(x.parameter_type), '<f=(int32 -> int32),x=int32>')
   self.assertEqual(x.result.compact_representation(), 'arg.f(arg.f(arg.x))')
   arg_type_repr = (
       'NamedTupleType(['
       '(\'f\', FunctionType(TensorType(tf.int32), TensorType(tf.int32))), '
       '(\'x\', TensorType(tf.int32))])')
   self.assertEqual(
       repr(x), 'Lambda(\'arg\', {0}, '
       'Call(Selection(Reference(\'arg\', {0}), name=\'f\'), '
       'Call(Selection(Reference(\'arg\', {0}), name=\'f\'), '
       'Selection(Reference(\'arg\', {0}), name=\'x\'))))'.format(
           arg_type_repr))
   self.assertEqual(x.compact_representation(), '(arg -> arg.f(arg.f(arg.x)))')
   x_proto = x.proto
   self.assertEqual(
       type_serialization.deserialize_type(x_proto.type), x.type_signature)
   self.assertEqual(x_proto.WhichOneof('computation'), 'lambda')
   self.assertEqual(getattr(x_proto, 'lambda').parameter_name, arg_name)
   self.assertEqual(
       str(getattr(x_proto, 'lambda').result), str(x.result.proto))
   self._serialize_deserialize_roundtrip_test(x)
    def test_serialize_jax_with_int32_to_int32(self):
        self.skipTest('HLO pattern matching broken by '
                      'https://github.com/google/jax/pull/10232')

        def traced_fn(x):
            return x + 10

        param_type = computation_types.to_type(np.int32)
        arg_fn = function_utils.create_argument_unpacking_fn(
            traced_fn, param_type)
        ctx_stack = context_stack_impl.context_stack
        comp_pb = jax_serialization.serialize_jax_computation(
            traced_fn, arg_fn, param_type, ctx_stack)
        self.assertIsInstance(comp_pb, pb.Computation)
        self.assertEqual(comp_pb.WhichOneof('computation'), 'xla')
        type_spec = type_serialization.deserialize_type(comp_pb.type)
        self.assertEqual(str(type_spec), '(int32 -> int32)')
        xla_comp = xla_serialization.unpack_xla_computation(
            comp_pb.xla.hlo_module)
        self.assertIn('ROOT tuple.6 = (s32[]) tuple(add.5)',
                      xla_comp.as_hlo_text())
        self.assertEqual(str(comp_pb.xla.result), str(comp_pb.xla.parameter))
        self.assertEqual(str(comp_pb.xla.result), 'tensor {\n'
                         '  index: 0\n'
                         '}\n')
    def __init__(self, computation_proto, context_stack, annotated_type=None):
        """Constructs a new instance of ComputationImpl from the computation_proto.

    Args:
      computation_proto: The protocol buffer that represents the computation, an
        instance of pb.Computation.
      context_stack: The context stack to use.
      annotated_type: Optional, type information with additional annotations
        that replaces the information in `computation_proto.type`.

    Raises:
      TypeError: if `annotated_type` is not `None` and is not compatible with
      `computation_proto.type`.
    """
        py_typecheck.check_type(computation_proto, pb.Computation)
        py_typecheck.check_type(context_stack, context_stack_base.ContextStack)
        type_spec = type_serialization.deserialize_type(computation_proto.type)
        py_typecheck.check_type(type_spec, computation_types.Type)
        if annotated_type is not None:
            if not type_spec.is_assignable_from(annotated_type):
                raise TypeError(
                    'annotated_type not compatible with computation_proto.type\n'
                    'computation_proto.type: {!s}\n'
                    'annotated_type: {!s}'.format(type_spec, annotated_type))
            type_spec = annotated_type

        type_analysis.check_well_formed(type_spec)

        if not type_spec.is_function():
            raise TypeError(
                '{} is not a functional type, from proto: {}'.format(
                    str(type_spec), str(computation_proto)))

        super().__init__(type_spec, context_stack)
        self._computation_proto = computation_proto
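
A hedged construction sketch for the class above, assuming the enclosing class is `computation_impl.ComputationImpl` (as the docstring suggests) and reusing the `computation_factory` and `context_stack_impl` aliases from the other snippets:

# Illustrative only: wrap a factory-built identity computation.
proto = computation_factory.create_lambda_identity(
    computation_types.TensorType(tf.int32))
comp = computation_impl.ComputationImpl(proto,
                                        context_stack_impl.context_stack)
print(comp.type_signature)  # Expected to print '(int32 -> int32)'.
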
def deserialize_sequence_value(sequence_value_proto):
    """Deserializes a `tf.data.Dataset`.

  Args:
    sequence_value_proto: `Sequence` protocol buffer message.

  Returns:
    A tuple of `(tf.data.Dataset, tff.Type)`.
  """
    py_typecheck.check_type(sequence_value_proto, executor_pb2.Value.Sequence)

    which_value = sequence_value_proto.WhichOneof('value')
    if which_value == 'zipped_saved_model':
        ds = tensorflow_serialization.deserialize_dataset(
            sequence_value_proto.zipped_saved_model)
    else:
        raise NotImplementedError(
            'Deserializing Sequences encoded as {!s} has not been implemented'.
            format(which_value))

    element_type = type_serialization.deserialize_type(
        sequence_value_proto.element_type)

    # If a serialized dataset had elements of nested structures of tensors
    # (e.g. `dict`, `OrderedDict`), the deserialized dataset will return
    # `dict`, `tuple`, or `namedtuple` (losing `collections.OrderedDict` in
    # the conversion).
    #
    # Since the dataset will only be used inside TFF, we wrap the dictionary
    # coming from TF in an `OrderedDict` when necessary (a type that both TF
    # and TFF understand), using the field order from the TFF type recorded
    # during serialization.
    ds = tensorflow_utils.coerce_dataset_elements_to_tff_type_spec(
        ds, element_type)

    return ds, computation_types.SequenceType(element=element_type)
def _serialize_computation(
        comp: computation_pb2.Computation,
        type_spec: Optional[computation_types.Type]) -> _SerializeReturnType:
    """Serializes a TFF computation."""
    type_spec = executor_utils.reconcile_value_type_with_type_spec(
        type_serialization.deserialize_type(comp.type), type_spec)
    return serialization_bindings.Value(computation=comp), type_spec
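
For illustration, a hedged sketch of calling `_serialize_computation` with `type_spec=None`, so the returned type is reconciled from the proto itself (module aliases assumed from the surrounding snippets):

# Illustrative sketch only; not part of the original module.
comp_proto = computation_factory.create_lambda_identity(
    computation_types.TensorType(tf.int32))
value_pb, value_type = _serialize_computation(comp_proto, None)
print(value_type)  # Expected to print '(int32 -> int32)'.
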
 def test_serialize_tensorflow_with_structured_type_signature(self):
     batch_type = collections.namedtuple('BatchType', ['x', 'y'])
     output_type = collections.namedtuple('OutputType', ['A', 'B'])
     comp, extra_type_spec = tensorflow_serialization.serialize_py_fn_as_tf_computation(
         lambda z: output_type(2.0 * tf.cast(z.x, tf.float32), 3.0 * z.y),
         batch_type(tf.int32, (tf.float32, [2])),
         context_stack_impl.context_stack)
     self.assertEqual(
         str(type_serialization.deserialize_type(comp.type)),
         '(<x=int32,y=float32[2]> -> <A=float32,B=float32[2]>)')
     self.assertEqual(comp.WhichOneof('computation'), 'tensorflow')
     self.assertEqual(
         str(extra_type_spec),
         '(<x=int32,y=float32[2]> -> <A=float32,B=float32[2]>)')
     self.assertIsInstance(
         extra_type_spec.parameter,
         computation_types.NamedTupleTypeWithPyContainerType)
     self.assertIs(
         computation_types.NamedTupleTypeWithPyContainerType.
         get_container_type(extra_type_spec.parameter), batch_type)
     self.assertIsInstance(
         extra_type_spec.result,
         computation_types.NamedTupleTypeWithPyContainerType)
     self.assertIs(
         computation_types.NamedTupleTypeWithPyContainerType.
         get_container_type(extra_type_spec.result), output_type)
def _deserialize_sequence_value(
    sequence_value_proto: executor_pb2.Value.Sequence
) -> _DeserializeReturnType:
    """Deserializes a `tf.data.Dataset`.

  Args:
    sequence_value_proto: `Sequence` protocol buffer message.

  Returns:
    A tuple of `(tf.data.Dataset, tff.Type)`.
  """
    element_type = type_serialization.deserialize_type(
        sequence_value_proto.element_type)
    which_value = sequence_value_proto.WhichOneof('value')
    if which_value == 'zipped_saved_model':
        warnings.warn(
            'Deserializing a sequence value that was encoded as a zipped SavedModel.'
            ' This is a deprecated path, please update the binary that is '
            'serializing the sequences.', DeprecationWarning)
        ds = _deserialize_dataset_from_zipped_saved_model(
            sequence_value_proto.zipped_saved_model)
        ds = tensorflow_utils.coerce_dataset_elements_to_tff_type_spec(
            ds, element_type)
    elif which_value == 'serialized_graph_def':
        ds = _deserialize_dataset_from_graph_def(
            sequence_value_proto.serialized_graph_def, element_type)
    else:
        raise NotImplementedError(
            'Deserializing Sequences encoded as {!s} has not been implemented'.
            format(which_value))
    return ds, computation_types.SequenceType(element=element_type)
 async def _evaluate_lambda(
     self,
     comp: pb.Computation,
     scope: ReferenceResolvingExecutorScope,
 ) -> ReferenceResolvingExecutorValue:
   type_spec = type_serialization.deserialize_type(comp.type)
   return ReferenceResolvingExecutorValue(
       ScopedLambda(comp, scope), type_spec=type_spec)
 async def _evaluate_to_delegate(
     self,
     comp: pb.Computation,
     scope: ReferenceResolvingExecutorScope,
 ) -> ReferenceResolvingExecutorValue:
   return ReferenceResolvingExecutorValue(
       await self._target_executor.create_value(
           comp, type_serialization.deserialize_type(comp.type)))
    def test_returns_computation_sequence(self):
        type_signature = computation_types.SequenceType(tf.int32)

        proto = computation_factory.create_lambda_identity(type_signature)

        self.assertIsInstance(proto, pb.Computation)
        actual_type = type_serialization.deserialize_type(proto.type)
        expected_type = type_factory.unary_op(type_signature)
        self.assertEqual(actual_type, expected_type)
    def test_returns_computation_tuple_unnamed(self):
        type_signature = computation_types.StructType([tf.int32, tf.float32])

        proto = computation_factory.create_lambda_identity(type_signature)

        self.assertIsInstance(proto, pb.Computation)
        actual_type = type_serialization.deserialize_type(proto.type)
        expected_type = type_factory.unary_op(type_signature)
        self.assertEqual(actual_type, expected_type)
def _deserialize_type_spec(serialize_type_variable, python_container=None):
    """Deserialize a `tff.Type` protocol buffer into a python class instance."""
    type_spec = type_serialization.deserialize_type(
        computation_pb2.Type.FromString(
            serialize_type_variable.read_value().numpy()))
    if type_spec.is_struct() and python_container is not None:
        type_spec = computation_types.StructWithPythonType(
            structure.iter_elements(type_spec), python_container)
    return type_conversions.type_to_tf_structure(type_spec)
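
A hedged round-trip sketch for `_deserialize_type_spec`: store a serialized `tff.Type` proto in a string variable, then recover the corresponding `tf.TensorSpec` structure. The variable name mirrors the parameter above; TF 2.x eager execution is assumed.

# Illustrative sketch only; the variable holds the serialized Type proto bytes.
struct_type = computation_types.StructWithPythonType(
    [tf.int32, tf.float32], tuple)
type_proto = type_serialization.serialize_type(struct_type)
serialize_type_variable = tf.Variable(type_proto.SerializeToString(),
                                      dtype=tf.string)
tf_structure = _deserialize_type_spec(serialize_type_variable, tuple)
print(tf_structure)  # Expected: a tuple of two scalar tf.TensorSpec objects.
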
 def test_create_xla_tff_computation_int32x10_to_int32x10(self):
   xla_comp = _make_test_xla_comp_int32x10_to_int32x10()
   comp_pb = xla_serialization.create_xla_tff_computation(
       xla_comp, [0],
       computation_types.FunctionType((np.int32, (10,)), (np.int32, (10,))))
   self.assertIsInstance(comp_pb, pb.Computation)
   self.assertEqual(comp_pb.WhichOneof('computation'), 'xla')
   type_spec = type_serialization.deserialize_type(comp_pb.type)
   self.assertEqual(str(type_spec), '(int32[10] -> int32[10])')
    def test_returns_computation(self, type_signature, value):
        proto = tensorflow_computation_factory.create_identity(type_signature)

        self.assertIsInstance(proto, pb.Computation)
        actual_type = type_serialization.deserialize_type(proto.type)
        expected_type = type_factory.unary_op(type_signature)
        self.assertEqual(actual_type, expected_type)
        actual_result = test_utils.run_tensorflow(proto, value)
        self.assertEqual(actual_result, value)
    def test_returns_computation(self):
        proto, _ = tensorflow_computation_factory.create_empty_tuple()

        self.assertIsInstance(proto, pb.Computation)
        actual_type = type_serialization.deserialize_type(proto.type)
        expected_type = computation_types.FunctionType(None, [])
        expected_type.check_assignable_from(actual_type)
        actual_result = test_utils.run_tensorflow(proto)
        expected_result = structure.Struct([])
        self.assertEqual(actual_result, expected_result)
    def test_returns_computation(self):
        proto = tensorflow_computation_factory.create_empty_tuple()

        self.assertIsInstance(proto, pb.Computation)
        actual_type = type_serialization.deserialize_type(proto.type)
        expected_type = computation_types.FunctionType(None, [])
        self.assertEqual(actual_type, expected_type)
        actual_result = test_utils.run_tensorflow(proto)
        expected_result = anonymous_tuple.AnonymousTuple([])
        self.assertEqual(actual_result, expected_result)
 def test_serialize_deserialize_named_tuple_types_py_container(self):
     # The Py container is destroyed during ser/de.
     with_container = computation_types.StructWithPythonType(
         (tf.int32, tf.bool), tuple)
     p1 = type_serialization.serialize_type(with_container)
     without_container = type_serialization.deserialize_type(p1)
     self.assertNotEqual(with_container, without_container)  # Not equal.
     self.assertIsInstance(without_container, computation_types.StructType)
     self.assertNotIsInstance(without_container,
                              computation_types.StructWithPythonType)
     with_container.check_equivalent_to(without_container)
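
By contrast, types that carry no Python container are expected to round-trip to an equal value; a minimal hedged sketch using the same aliases:

# Minimal sketch: no Python container is involved, so equality should hold.
fn_type = computation_types.FunctionType(tf.int32, tf.bool)
roundtripped = type_serialization.deserialize_type(
    type_serialization.serialize_type(fn_type))
assert roundtripped == fn_type
assert str(roundtripped) == '(int32 -> bool)'
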
 def test_basic_functionality_of_placement_class(self):
   x = building_blocks.Placement(placement_literals.CLIENTS)
   self.assertEqual(str(x.type_signature), 'placement')
   self.assertEqual(x.uri, 'clients')
   self.assertEqual(repr(x), 'Placement(\'clients\')')
   self.assertEqual(x.compact_representation(), 'CLIENTS')
   x_proto = x.proto
   self.assertEqual(
       type_serialization.deserialize_type(x_proto.type), x.type_signature)
   self.assertEqual(x_proto.WhichOneof('computation'), 'placement')
   self.assertEqual(x_proto.placement.uri, x.uri)
   self._serialize_deserialize_roundtrip_test(x)
 def test_basic_functionality_of_reference_class(self):
   x = building_blocks.Reference('foo', tf.int32)
   self.assertEqual(x.name, 'foo')
   self.assertEqual(str(x.type_signature), 'int32')
   self.assertEqual(repr(x), 'Reference(\'foo\', TensorType(tf.int32))')
   self.assertEqual(x.compact_representation(), 'foo')
   x_proto = x.proto
   self.assertEqual(
       type_serialization.deserialize_type(x_proto.type), x.type_signature)
   self.assertEqual(x_proto.WhichOneof('computation'), 'reference')
   self.assertEqual(x_proto.reference.name, x.name)
   self._serialize_deserialize_roundtrip_test(x)
    def test_returns_computation(self, type_signature, count, value):
        proto, _ = tensorflow_computation_factory.create_replicate_input(
            type_signature, count)

        self.assertIsInstance(proto, pb.Computation)
        actual_type = type_serialization.deserialize_type(proto.type)
        expected_type = computation_types.FunctionType(
            type_signature, [type_signature] * count)
        expected_type.check_assignable_from(actual_type)
        actual_result = test_utils.run_tensorflow(proto, value)
        expected_result = structure.Struct([(None, value)] * count)
        self.assertEqual(actual_result, expected_result)
 def test_serialize_tensorflow_with_no_parameter(self):
   comp, extra_type_spec = tensorflow_serialization.serialize_py_fn_as_tf_computation(
       lambda: tf.constant(99), None, context_stack_impl.context_stack)
   self.assertEqual(
       str(type_serialization.deserialize_type(comp.type)), '( -> int32)')
   self.assertEqual(str(extra_type_spec), '( -> int32)')
   self.assertEqual(comp.WhichOneof('computation'), 'tensorflow')
   results = tf.compat.v1.Session().run(
       tf.import_graph_def(
           serialization_utils.unpack_graph_def(comp.tensorflow.graph_def),
           None, [comp.tensorflow.result.tensor.tensor_name]))
   self.assertEqual(results, [99])
def to_representation_for_type(value, type_spec, backend=None):
    """Verifies or converts the `value` to executor payload matching `type_spec`.

  The following kinds of `value` are supported:

  * Computations, either `pb.Computation` or `computation_impl.ComputationImpl`.

  * Numpy arrays and scalars, or Python scalars that are converted to Numpy.

  * Nested structures of the above.

  Args:
    value: The raw representation of a value to compare against `type_spec` and
      potentially to be converted.
    type_spec: An instance of `tff.Type`. Can be `None` for values that derive
      from `typed_object.TypedObject`.
    backend: The backend to use; an instance of `xla_client.Client`. Only used
      for functional types. Can be `None` if unused.

  Returns:
    Either `value` itself, or a modified version of it.

  Raises:
    TypeError: If the `value` is not compatible with `type_spec`.
    ValueError: If the arguments are incorrect.
  """
    if backend is not None:
        py_typecheck.check_type(backend, xla_client.Client)
    if type_spec is not None:
        type_spec = computation_types.to_type(type_spec)
    type_spec = type_utils.reconcile_value_with_type_spec(value, type_spec)
    if isinstance(value, computation_base.Computation):
        return to_representation_for_type(
            computation_impl.ComputationImpl.get_proto(value), type_spec,
            backend)
    if isinstance(value, pb.Computation):
        comp_type = type_serialization.deserialize_type(value.type)
        if type_spec is not None:
            comp_type.check_equivalent_to(type_spec)
        return _ComputationCallable(value, comp_type, backend)
    if isinstance(type_spec, computation_types.StructType):
        return structure.map_structure(
            lambda v, t: to_representation_for_type(v, t, backend),
            structure.from_container(value, recursive=True), type_spec)
    if isinstance(type_spec, computation_types.TensorType):
        type_spec.shape.assert_is_fully_defined()
        type_analysis.check_type(value, type_spec)
        if type_spec.shape.rank == 0:
            return np.dtype(type_spec.dtype.as_numpy_dtype).type(value)
        if type_spec.shape.rank > 0:
            return np.array(value, dtype=type_spec.dtype.as_numpy_dtype)
        raise TypeError('Unsupported tensor shape {}.'.format(type_spec.shape))
    raise TypeError('Unexpected type {}.'.format(type_spec))
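
A hedged usage sketch for `to_representation_for_type`, covering only the tensor branch (no XLA backend is needed for non-functional types; `np` is NumPy as in the other snippets):

# Illustrative sketch only.
scalar = to_representation_for_type(np.int32(5),
                                    computation_types.TensorType(tf.int32))
print(type(scalar), scalar)  # Expected: <class 'numpy.int32'> 5
vector = to_representation_for_type(
    np.array([1.0, 2.0], dtype=np.float32),
    computation_types.TensorType(tf.float32, [2]))
print(vector.dtype, vector)  # Expected: float32 [1. 2.]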