Exemplo n.º 1
0
def deserialize_value(value_proto):
    """Deserializes a value (of any type) from `executor_pb2.Value`.

    Args:
      value_proto: An instance of `executor_pb2.Value`.

    Returns:
      A tuple `(value, type_spec)`, where `value` is a deserialized
      representation of the transmitted value (e.g., a Numpy array, a
      `pb.Computation` instance, an `anonymous_tuple.AnonymousTuple`, or the
      member constituents of a federated value), and `type_spec` is the TFF
      type of that value (e.g. a tensor, named tuple, sequence or federated
      type, depending on which field of `value_proto` is set).

    Raises:
      TypeError: If the arguments are of the wrong types.
      ValueError: If the value is malformed.
    """
    py_typecheck.check_type(value_proto, executor_pb2.Value)
    which_value = value_proto.WhichOneof('value')
    if which_value == 'tensor':
        return deserialize_tensor_value(value_proto)
    elif which_value == 'computation':
        return (value_proto.computation,
                type_serialization.deserialize_type(
                    value_proto.computation.type))
    elif which_value == 'tuple':
        val_elems = []
        type_elems = []
        for e in value_proto.tuple.element:
            # An empty name in the proto denotes an unnamed element.
            name = e.name or None
            e_val, e_type = deserialize_value(e.value)
            val_elems.append((name, e_val))
            type_elems.append((name, e_type) if name else e_type)
        return (anonymous_tuple.AnonymousTuple(val_elems),
                computation_types.NamedTupleType(type_elems))
    elif which_value == 'sequence':
        return deserialize_sequence_value(value_proto.sequence)
    elif which_value == 'federated':
        type_spec = type_serialization.deserialize_type(
            computation_pb2.Type(federated=value_proto.federated.type))
        value = []
        for item in value_proto.federated.value:
            item_value, item_type = deserialize_value(item)
            # Every member constituent must fit the declared member type.
            type_utils.check_assignable_from(type_spec.member, item_type)
            value.append(item_value)
        if type_spec.all_equal:
            # An all_equal federated value is represented by exactly one
            # member constituent, which is returned unwrapped.
            if len(value) != 1:
                raise ValueError(
                    'Expected an all_equal value with exactly one member '
                    'constituent, found {}.'.format(len(value)))
            value = value[0]
        return value, type_spec
    else:
        raise ValueError(
            'Unable to deserialize a value of type {}.'.format(which_value))
Exemplo n.º 2
0
 def test_basic_functionality_of_call_class(self):
     """Checks type, accessors, reprs, validation and proto form of `Call`."""
     fn_ref = building_blocks.Reference(
         'foo', computation_types.FunctionType(tf.int32, tf.bool))
     arg_ref = building_blocks.Reference('bar', tf.int32)
     call = building_blocks.Call(fn_ref, arg_ref)
     # Python-level accessors and string forms.
     self.assertEqual(str(call.type_signature), 'bool')
     self.assertIs(call.function, fn_ref)
     self.assertIs(call.argument, arg_ref)
     expected_repr = ('Call(Reference(\'foo\', '
                      'FunctionType(TensorType(tf.int32), TensorType(tf.bool))), '
                      'Reference(\'bar\', TensorType(tf.int32)))')
     self.assertEqual(repr(call), expected_repr)
     self.assertEqual(call.compact_representation(), 'foo(bar)')
     # A missing argument and an argument of the wrong type are both rejected.
     with self.assertRaises(TypeError):
         building_blocks.Call(fn_ref)
     float_ref = building_blocks.Reference('bak', tf.float32)
     with self.assertRaises(TypeError):
         building_blocks.Call(fn_ref, float_ref)
     # Proto serialization mirrors the Python structure.
     call_proto = call.proto
     self.assertEqual(type_serialization.deserialize_type(call_proto.type),
                      call.type_signature)
     self.assertEqual(call_proto.WhichOneof('computation'), 'call')
     self.assertEqual(str(call_proto.call.function), str(fn_ref.proto))
     self.assertEqual(str(call_proto.call.argument), str(arg_ref.proto))
     self._serialize_deserialize_roundtrip_test(call)
Exemplo n.º 3
0
 def test_basic_functionality_of_block_class(self):
     """Checks type, locals, reprs, proto form and roundtrip of a `Block`."""
     x = building_blocks.Block([
         ('x', building_blocks.Reference('arg', (tf.int32, tf.int32))),
         ('y',
          building_blocks.Selection(
              building_blocks.Reference('x',
                                        (tf.int32, tf.int32)), index=0))
     ], building_blocks.Reference('y', tf.int32))
     self.assertEqual(str(x.type_signature), 'int32')
     self.assertEqual([(k, v.compact_representation())
                       for k, v in x.locals], [('x', 'arg'), ('y', 'x[0]')])
     self.assertEqual(x.result.compact_representation(), 'y')
     self.assertEqual(
         repr(x), 'Block([(\'x\', Reference(\'arg\', '
         'NamedTupleType([TensorType(tf.int32), TensorType(tf.int32)]))), '
         '(\'y\', Selection(Reference(\'x\', '
         'NamedTupleType([TensorType(tf.int32), TensorType(tf.int32)])), '
         'index=0))], '
         'Reference(\'y\', TensorType(tf.int32)))')
     self.assertEqual(x.compact_representation(), '(let x=arg,y=x[0] in y)')
     x_proto = x.proto
     self.assertEqual(type_serialization.deserialize_type(x_proto.type),
                      x.type_signature)
     self.assertEqual(x_proto.WhichOneof('computation'), 'block')
     self.assertEqual(str(x_proto.block.result), str(x.result.proto))
     # Check the count first so a dropped or extra local cannot slip past the
     # pairwise comparison below.
     self.assertLen(x_proto.block.local, len(x.locals))
     for loc_proto, (loc_name, loc_value) in zip(x_proto.block.local,
                                                 x.locals):
         self.assertEqual(loc_proto.name, loc_name)
         self.assertEqual(str(loc_proto.value), str(loc_value.proto))
     # Roundtrip the whole block once, after its locals have been inspected.
     self._serialize_deserialize_roundtrip_test(x)
Exemplo n.º 4
0
 def test_intrinsic_class_succeeds_simple_federated_map(self):
     """Builds a concretely-typed `federated_map` intrinsic and checks it."""
     fn_type = computation_types.FunctionType(tf.int32, tf.float32)
     clients_arg_type = computation_types.FederatedType(
         fn_type.parameter, placements.CLIENTS)
     clients_result_type = computation_types.FederatedType(
         fn_type.result, placements.CLIENTS)
     map_type = computation_types.FunctionType(
         [fn_type, clients_arg_type], clients_result_type)
     intrinsic = building_blocks.Intrinsic(intrinsic_defs.FEDERATED_MAP.uri,
                                           map_type)
     self.assertIsInstance(intrinsic, building_blocks.Intrinsic)
     self.assertEqual(
         str(intrinsic.type_signature),
         '(<(int32 -> float32),{int32}@CLIENTS> -> {float32}@CLIENTS)')
     self.assertEqual(intrinsic.uri, 'federated_map')
     self.assertEqual(intrinsic.compact_representation(), 'federated_map')
     # The proto form keeps both the type and the intrinsic URI.
     intrinsic_proto = intrinsic.proto
     self.assertEqual(
         type_serialization.deserialize_type(intrinsic_proto.type),
         intrinsic.type_signature)
     self.assertEqual(intrinsic_proto.WhichOneof('computation'), 'intrinsic')
     self.assertEqual(intrinsic_proto.intrinsic.uri, intrinsic.uri)
     self._serialize_deserialize_roundtrip_test(intrinsic)
Exemplo n.º 5
0
 def test_basic_functionality_of_tuple_class(self):
     """Checks construction, element access, reprs and proto form of `Tuple`."""
     foo = building_blocks.Reference('foo', tf.int32)
     bar = building_blocks.Reference('bar', tf.bool)
     tup = building_blocks.Tuple([foo, ('y', bar)])
     # An empty-string element name is rejected.
     with self.assertRaises(ValueError):
         _ = building_blocks.Tuple([('', bar)])
     self.assertIsInstance(tup, anonymous_tuple.AnonymousTuple)
     self.assertEqual(str(tup.type_signature), '<int32,y=bool>')
     self.assertEqual(
         repr(tup),
         'Tuple([(None, Reference(\'foo\', TensorType(tf.int32))), (\'y\', '
         'Reference(\'bar\', TensorType(tf.bool)))])')
     self.assertEqual(tup.compact_representation(), '<foo,y=bar>')
     # Named elements are reachable as attributes; all of them by index.
     self.assertEqual(dir(tup), ['y'])
     self.assertIs(tup.y, bar)
     self.assertLen(tup, 2)
     self.assertIs(tup[0], foo)
     self.assertIs(tup[1], bar)
     element_reprs = [elem.compact_representation() for elem in tup]
     self.assertEqual(','.join(element_reprs), 'foo,bar')
     tup_proto = tup.proto
     self.assertEqual(type_serialization.deserialize_type(tup_proto.type),
                      tup.type_signature)
     self.assertEqual(tup_proto.WhichOneof('computation'), 'tuple')
     self.assertEqual([elem.name for elem in tup_proto.tuple.element],
                      ['', 'y'])
     self._serialize_deserialize_roundtrip_test(tup)
Exemplo n.º 6
0
def run_tensorflow(computation_proto, arg=None):
    """Runs a TensorFlow computation with argument `arg`.

    Args:
      computation_proto: An instance of `pb.Computation`.
      arg: The argument to invoke the computation with, or `None` if the
        computation does not declare a parameter type and does not expect one.
        It is ignored when the computation has no parameter.

    Returns:
      The result of the computation, as fetched from a TensorFlow session.
    """
    with tf.Graph().as_default() as graph:
        function_type = type_serialization.deserialize_type(
            computation_proto.type)
        parameter_type = function_type.parameter
        # Only computations that declare a parameter get `arg` stamped in.
        stamped_arg = (None if parameter_type is None else
                       _stamp_value_into_graph(arg, parameter_type, graph))
        init_op, result_in_graph = (
            tensorflow_deserialization.deserialize_and_call_tf_computation(
                computation_proto, stamped_arg, graph))
    with tf.compat.v1.Session(graph=graph) as sess:
        if init_op:
            sess.run(init_op)
        return tensorflow_utils.fetch_value_in_session(sess, result_in_graph)
Exemplo n.º 7
0
    def test_returns_coputation(self):
        """`create_lambda_empty_tuple` yields a no-argument computation."""
        proto = computation_factory.create_lambda_empty_tuple()
        self.assertIsInstance(proto, pb.Computation)
        self.assertEqual(type_serialization.deserialize_type(proto.type),
                         computation_types.FunctionType(None, []))
Exemplo n.º 8
0
 def test_serialize_tensorflow_with_structured_type_signature(self):
     """Python container types survive serialization in `extra_type_spec`."""
     batch_type = collections.namedtuple('BatchType', ['x', 'y'])
     output_type = collections.namedtuple('OutputType', ['A', 'B'])
     serialize = tensorflow_serialization.serialize_py_fn_as_tf_computation
     comp, extra_type_spec = serialize(
         lambda z: output_type(2.0 * tf.cast(z.x, tf.float32), 3.0 * z.y),
         batch_type(tf.int32, (tf.float32, [2])),
         context_stack_impl.context_stack)
     expected_signature = (
         '(<x=int32,y=float32[2]> -> <A=float32,B=float32[2]>)')
     self.assertEqual(
         str(type_serialization.deserialize_type(comp.type)),
         expected_signature)
     self.assertEqual(comp.WhichOneof('computation'), 'tensorflow')
     self.assertEqual(str(extra_type_spec), expected_signature)
     container_type = computation_types.NamedTupleTypeWithPyContainerType
     # Both sides of the signature remember their originating namedtuple.
     for spec, py_container in ((extra_type_spec.parameter, batch_type),
                                (extra_type_spec.result, output_type)):
         self.assertIsInstance(spec, container_type)
         self.assertIs(container_type.get_container_type(spec), py_container)
Exemplo n.º 9
0
    def test_serialize_tensorflow_with_dataset_not_optimized(self):
        """A serialized dataset reduction must not contain optimization ops."""

        @tf.function
        def test_foo(ds):
            return ds.reduce(np.int64(0), lambda x, y: x + y)

        def legacy_dataset_reducer_example(ds):
            return test_foo(ds)

        comp, extra_type_spec = tensorflow_serialization.serialize_py_fn_as_tf_computation(
            legacy_dataset_reducer_example,
            computation_types.SequenceType(tf.int64),
            context_stack_impl.context_stack)
        self.assertEqual(str(type_serialization.deserialize_type(comp.type)),
                         '(int64* -> int64)')
        self.assertEqual(str(extra_type_spec), '(int64* -> int64)')
        self.assertEqual(comp.WhichOneof('computation'), 'tensorflow')
        parameter = tf.data.Dataset.range(5)

        graph_def = serialization_utils.unpack_graph_def(
            comp.tensorflow.graph_def)
        # Op names must be spelled exactly; a misspelled name would make this
        # check pass vacuously.
        self.assertGraphDoesNotContainOps(graph_def,
                                          ['OptimizeDataset', 'ModelDataset'])
        # Use the session as a context manager so it is closed after the run.
        with tf.compat.v1.Session() as sess:
            results = sess.run(
                tf.import_graph_def(
                    graph_def, {
                        comp.tensorflow.parameter.sequence.variant_tensor_name:
                        tf.data.experimental.to_variant(parameter)
                    }, [comp.tensorflow.result.tensor.tensor_name]))
        # sum(range(5)) == 10
        self.assertEqual(results, [10])
Exemplo n.º 10
0
def deserialize_sequence_value(sequence_value_proto):
  """Deserializes a `tf.data.Dataset`.

  Args:
    sequence_value_proto: `Sequence` protocol buffer message.

  Returns:
    A tuple of `(tf.data.Dataset, tff.Type)`, where the type is a
    `computation_types.SequenceType` whose element type is deserialized from
    `sequence_value_proto.element_type`.

  Raises:
    NotImplementedError: If the sequence is encoded as anything other than a
      zipped saved model.
  """
  py_typecheck.check_type(sequence_value_proto, executor_pb2.Value.Sequence)

  which_value = sequence_value_proto.WhichOneof('value')
  if which_value != 'zipped_saved_model':
    raise NotImplementedError(
        'Deserializing Sequences encoded as {!s} has not been implemented'
        .format(which_value))
  ds = tensorflow_serialization.deserialize_dataset(
      sequence_value_proto.zipped_saved_model)

  element_type = type_serialization.deserialize_type(
      sequence_value_proto.element_type)

  # If a serialized dataset had elements of nested structures of tensors (e.g.
  # `dict`, `OrderedDict`), the deserialized dataset will return `dict`,
  # `tuple`, or `namedtuple` (loses `collections.OrderedDict` in a conversion).
  #
  # Since the dataset will only be used inside TFF, we wrap the dictionary
  # coming from TF in an `OrderedDict` when necessary (a type that both TF and
  # TFF understand), using the field order stored in the TFF type stored during
  # serialization.
  ds = tensorflow_utils.coerce_dataset_elements_to_tff_type_spec(
      ds, element_type)

  return ds, computation_types.SequenceType(element=element_type)
Exemplo n.º 11
0
  def __init__(self, value, scope=None, type_spec=None):
    """Creates an instance of a value embedded in a lambda executor.

    Supported internal representations of `value`:

    * An `executor_value_base.ExecutorValue`: a value (functional or not)
      already embedded in the target executor.

    * An as-yet unprocessed `pb.Computation`: a function still to be invoked.
      It is always of a functional type; non-functional constructs are meant
      to be processed on the fly instead.

    * A Python coroutine callable accepting a single `LambdaExecutorValue`
      (or `None`) and returning a `LambdaExecutorValue`. Its type signature
      is always functional.

    * A single-level `anonymous_tuple.AnonymousTuple` whose elements are
      instances of this class (in any of the forms listed here).

    Args:
      value: The internal representation of the value, as listed above.
      scope: An optional computation scope. Only permitted when `value` is an
        unprocessed `pb.Computation`; otherwise it must be `None`, since a
        scope has no meaning for the other forms.
      type_spec: An optional type signature. Only permitted when `value` is a
        callable representing a function, in which case it must be a
        `computation_types.FunctionType`; otherwise it must be `None`, as the
        type is implied by the value.
    """
    if isinstance(value, executor_value_base.ExecutorValue):
      # Already embedded in the target executor: take its own type.
      py_typecheck.check_none(scope)
      py_typecheck.check_none(type_spec)
      type_spec = value.type_signature  # pytype: disable=attribute-error
    elif isinstance(value, pb.Computation):
      # The type is read from the proto and passed through get_function_type.
      if scope is not None:
        py_typecheck.check_type(scope, LambdaExecutorScope)
      py_typecheck.check_none(type_spec)
      proto_type = type_serialization.deserialize_type(value.type)
      type_spec = type_utils.get_function_type(proto_type)
    elif callable(value):
      # Callers must supply the functional type for Python callables.
      py_typecheck.check_none(scope)
      py_typecheck.check_type(type_spec, computation_types.FunctionType)
    else:
      py_typecheck.check_type(value, anonymous_tuple.AnonymousTuple)
      py_typecheck.check_none(scope)
      py_typecheck.check_none(type_spec)
      # Derive the tuple type from the members, keeping names where present.
      member_types = []
      for member_name, member in anonymous_tuple.iter_elements(value):
        py_typecheck.check_type(member, LambdaExecutorValue)
        member_type = member.type_signature
        if member_name is None:
          member_types.append(member_type)
        else:
          member_types.append((member_name, member_type))
      type_spec = computation_types.NamedTupleType(member_types)
    self._value = value
    self._scope = scope
    self._type_signature = type_spec
Exemplo n.º 12
0
    def test_returns_computation_sequence(self):
        """`create_lambda_identity` on a sequence type has a unary-op type."""
        seq_type = computation_types.SequenceType(tf.int32)
        proto = computation_factory.create_lambda_identity(seq_type)
        self.assertIsInstance(proto, pb.Computation)
        self.assertEqual(type_serialization.deserialize_type(proto.type),
                         type_factory.unary_op(seq_type))
    def test_returns_computation(self, type_signature, value):
        """TF identity over `type_signature` returns `value` unchanged."""
        proto = tensorflow_computation_factory.create_identity(type_signature)
        self.assertIsInstance(proto, pb.Computation)
        deserialized = type_serialization.deserialize_type(proto.type)
        self.assertEqual(deserialized, type_factory.unary_op(type_signature))
        self.assertEqual(test_utils.run_tensorflow(proto, value), value)
Exemplo n.º 14
0
    def test_returns_computation_tuple_named(self):
        """`create_lambda_identity` on a named tuple has a unary-op type."""
        tuple_type = computation_types.NamedTupleType([('a', tf.int32),
                                                       ('b', tf.float32)])
        proto = computation_factory.create_lambda_identity(tuple_type)
        self.assertIsInstance(proto, pb.Computation)
        self.assertEqual(type_serialization.deserialize_type(proto.type),
                         type_factory.unary_op(tuple_type))
Exemplo n.º 15
0
    def test_returns_coputation(self):
        """`create_empty_tuple` yields a no-argument computation returning `<>`."""
        proto = tensorflow_computation_factory.create_empty_tuple()

        self.assertIsInstance(proto, pb.Computation)
        actual_type = type_serialization.deserialize_type(proto.type)
        expected_type = computation_types.FunctionType(None, [])
        self.assertEqual(actual_type, expected_type)
        expected_value = anonymous_tuple.AnonymousTuple([])
        # The computation takes no parameter, so it is run without an argument
        # (passing the expected value would only make the check look circular).
        actual_value = test_utils.run_tensorflow(proto)
        self.assertEqual(actual_value, expected_value)
Exemplo n.º 16
0
 def test_basic_functionality_of_placement_class(self):
     """Checks type, uri, reprs and proto form of a `Placement`."""
     clients = building_blocks.Placement(placements.CLIENTS)
     self.assertEqual(str(clients.type_signature), 'placement')
     self.assertEqual(clients.uri, 'clients')
     self.assertEqual(repr(clients), 'Placement(\'clients\')')
     self.assertEqual(clients.compact_representation(), 'CLIENTS')
     serialized = clients.proto
     self.assertEqual(type_serialization.deserialize_type(serialized.type),
                      clients.type_signature)
     self.assertEqual(serialized.WhichOneof('computation'), 'placement')
     self.assertEqual(serialized.placement.uri, clients.uri)
     self._serialize_deserialize_roundtrip_test(clients)
Exemplo n.º 17
0
    def test_returns_computation_sequence(self):
        """TF identity over `int32*` has a unary-op type and echoes its input."""
        seq_type = computation_types.SequenceType(tf.int32)
        proto = tensorflow_computation_factory.create_identity(seq_type)
        self.assertIsInstance(proto, pb.Computation)
        self.assertEqual(type_serialization.deserialize_type(proto.type),
                         type_factory.unary_op(seq_type))
        elements = [10] * 3
        self.assertEqual(test_utils.run_tensorflow(proto, elements), elements)
Exemplo n.º 18
0
 def test_basic_functionality_of_reference_class(self):
     """Checks name, type, reprs and proto form of a `Reference`."""
     ref = building_blocks.Reference('foo', tf.int32)
     self.assertEqual(ref.name, 'foo')
     self.assertEqual(str(ref.type_signature), 'int32')
     self.assertEqual(repr(ref), 'Reference(\'foo\', TensorType(tf.int32))')
     self.assertEqual(ref.compact_representation(), 'foo')
     ref_proto = ref.proto
     self.assertEqual(type_serialization.deserialize_type(ref_proto.type),
                      ref.type_signature)
     self.assertEqual(ref_proto.WhichOneof('computation'), 'reference')
     self.assertEqual(ref_proto.reference.name, ref.name)
     self._serialize_deserialize_roundtrip_test(ref)
Exemplo n.º 19
0
 def test_serialize_tensorflow_with_no_parameter(self):
   """A parameterless TF function serializes to `( -> int32)` and runs."""
   comp, extra_type_spec = tensorflow_serialization.serialize_py_fn_as_tf_computation(
       lambda: tf.constant(99), None, context_stack_impl.context_stack)
   self.assertEqual(
       str(type_serialization.deserialize_type(comp.type)), '( -> int32)')
   self.assertEqual(str(extra_type_spec), '( -> int32)')
   self.assertEqual(comp.WhichOneof('computation'), 'tensorflow')
   # Use the session as a context manager so it is closed after the run.
   with tf.compat.v1.Session() as sess:
     results = sess.run(
         tf.import_graph_def(
             serialization_utils.unpack_graph_def(comp.tensorflow.graph_def),
             None, [comp.tensorflow.result.tensor.tensor_name]))
   self.assertEqual(results, [99])
Exemplo n.º 20
0
 def test_basic_intrinsic_functionality_plus_canonical_typecheck(self):
     """Checks a `generic_plus` intrinsic over `<int32,int32> -> int32`."""
     plus_type = computation_types.FunctionType([tf.int32, tf.int32], tf.int32)
     plus = building_blocks.Intrinsic('generic_plus', plus_type)
     self.assertEqual(str(plus.type_signature), '(<int32,int32> -> int32)')
     self.assertEqual(plus.uri, 'generic_plus')
     self.assertEqual(plus.compact_representation(), 'generic_plus')
     plus_proto = plus.proto
     self.assertEqual(type_serialization.deserialize_type(plus_proto.type),
                      plus.type_signature)
     self.assertEqual(plus_proto.WhichOneof('computation'), 'intrinsic')
     self.assertEqual(plus_proto.intrinsic.uri, plus.uri)
     self._serialize_deserialize_roundtrip_test(plus)
Exemplo n.º 21
0
    def test_returns_computation_with_tensor_float(self):
        """A float constant of type `float32[3]` evaluates to three copies."""
        value = 10.0
        type_signature = computation_types.TensorType(tf.float32, [3])
        proto = tensorflow_computation_factory.create_constant(
            value, type_signature)

        self.assertIsInstance(proto, pb.Computation)
        actual_type = type_serialization.deserialize_type(proto.type)
        expected_type = computation_types.FunctionType(None, type_signature)
        self.assertEqual(actual_type, expected_type)
        expected_value = [value] * 3
        # The constant computation takes no parameter, so it is run without an
        # argument; `expected_value` is used only for the comparison.
        actual_value = test_utils.run_tensorflow(proto)
        self.assertCountEqual(actual_value, expected_value)
Exemplo n.º 22
0
    def test_returns_computation_with_tuple_unnamed(self):
        """An int constant of type `<int32,int32,int32>` yields three copies."""
        value = 10
        type_signature = computation_types.NamedTupleType([tf.int32] * 3)
        proto = tensorflow_computation_factory.create_constant(
            value, type_signature)

        self.assertIsInstance(proto, pb.Computation)
        actual_type = type_serialization.deserialize_type(proto.type)
        expected_type = computation_types.FunctionType(None, type_signature)
        self.assertEqual(actual_type, expected_type)
        expected_value = [value] * 3
        # The constant computation takes no parameter, so it is run without an
        # argument; `expected_value` is used only for the comparison.
        actual_value = test_utils.run_tensorflow(proto)
        self.assertCountEqual(actual_value, expected_value)
    def test_returns_computation(self, type_signature, count, value):
        """Replicating the input `count` times yields a `count`-tuple of it."""
        proto = tensorflow_computation_factory.create_replicate_input(
            type_signature, count)
        self.assertIsInstance(proto, pb.Computation)
        deserialized = type_serialization.deserialize_type(proto.type)
        self.assertEqual(
            deserialized,
            computation_types.FunctionType(type_signature,
                                           [type_signature] * count))
        result = test_utils.run_tensorflow(proto, value)
        self.assertEqual(
            result, anonymous_tuple.AnonymousTuple([(None, value)] * count))
Exemplo n.º 24
0
    def test_returns_computation_tuple_named(self):
        """TF identity over a named tuple type echoes a matching value."""
        tuple_type = [('a', tf.int32), ('b', tf.float32)]
        proto = tensorflow_computation_factory.create_identity(tuple_type)
        self.assertIsInstance(proto, pb.Computation)
        self.assertEqual(type_serialization.deserialize_type(proto.type),
                         type_factory.unary_op(tuple_type))
        sample = anonymous_tuple.AnonymousTuple([('a', 10), ('b', 10.0)])
        self.assertEqual(test_utils.run_tensorflow(proto, sample), sample)
    def test_returns_computation(self, value, type_signature, expected_result):
        """A constant computation is parameterless and yields the constant."""
        proto = tensorflow_computation_factory.create_constant(
            value, type_signature)
        self.assertIsInstance(proto, pb.Computation)
        self.assertEqual(type_serialization.deserialize_type(proto.type),
                         computation_types.FunctionType(None, type_signature))
        actual_result = test_utils.run_tensorflow(proto)
        # List expectations are compared ignoring element order.
        check = (self.assertCountEqual
                 if isinstance(expected_result, list) else self.assertEqual)
        check(actual_result, expected_result)
Exemplo n.º 26
0
 def test_serialize_tensorflow_with_simple_add_three_lambda(self):
   """`x + 3` serializes to `(int32 -> int32)` and maps 1000 to 1003."""
   comp, extra_type_spec = tensorflow_serialization.serialize_py_fn_as_tf_computation(
       lambda x: x + 3, tf.int32, context_stack_impl.context_stack)
   self.assertEqual(
       str(type_serialization.deserialize_type(comp.type)), '(int32 -> int32)')
   self.assertEqual(str(extra_type_spec), '(int32 -> int32)')
   self.assertEqual(comp.WhichOneof('computation'), 'tensorflow')
   parameter = tf.constant(1000)
   # Use the session as a context manager so it is closed after the run.
   with tf.compat.v1.Session() as sess:
     results = sess.run(
         tf.import_graph_def(
             serialization_utils.unpack_graph_def(comp.tensorflow.graph_def),
             {comp.tensorflow.parameter.tensor.tensor_name: parameter},
             [comp.tensorflow.result.tensor.tensor_name]))
   self.assertEqual(results, [1003])
Exemplo n.º 27
0
    def _serialize_deserialize_roundtrip_test(self, type_list):
        """Checks that each type survives a serialize/deserialize roundtrip.

        Args:
          type_list: A list of instances of `computation_types.Type` or things
            convertible to it.
        """
        for convertible in type_list:
            original_type = computation_types.to_type(convertible)
            original_proto = type_serialization.serialize_type(original_type)
            roundtrip_type = type_serialization.deserialize_type(original_proto)
            roundtrip_proto = type_serialization.serialize_type(roundtrip_type)
            # Both the Python types and their protos must match textually, and
            # the types must also be equivalent.
            self.assertEqual(repr(original_type), repr(roundtrip_type))
            self.assertEqual(repr(original_proto), repr(roundtrip_proto))
            self.assertTrue(
                type_utils.are_equivalent_types(original_type, roundtrip_type))
Exemplo n.º 28
0
 def test_basic_functionality_of_intrinsic_class(self):
     """Checks type, uri, reprs and proto form of an `Intrinsic`."""
     add_one = building_blocks.Intrinsic(
         'add_one', computation_types.FunctionType(tf.int32, tf.int32))
     self.assertEqual(str(add_one.type_signature), '(int32 -> int32)')
     self.assertEqual(add_one.uri, 'add_one')
     expected_repr = ('Intrinsic(\'add_one\', '
                      'FunctionType(TensorType(tf.int32), TensorType(tf.int32)))')
     self.assertEqual(repr(add_one), expected_repr)
     self.assertEqual(add_one.compact_representation(), 'add_one')
     add_one_proto = add_one.proto
     self.assertEqual(type_serialization.deserialize_type(add_one_proto.type),
                      add_one.type_signature)
     self.assertEqual(add_one_proto.WhichOneof('computation'), 'intrinsic')
     self.assertEqual(add_one_proto.intrinsic.uri, add_one.uri)
     self._serialize_deserialize_roundtrip_test(add_one)
Exemplo n.º 29
0
 def test_basic_functionality_of_data_class(self):
     """Checks type, uri, reprs and proto form of a `Data` block."""
     seq_type = computation_types.SequenceType(tf.int32)
     data = building_blocks.Data('/tmp/mydata', seq_type)
     self.assertEqual(str(data.type_signature), 'int32*')
     self.assertEqual(data.uri, '/tmp/mydata')
     self.assertEqual(
         repr(data),
         'Data(\'/tmp/mydata\', SequenceType(TensorType(tf.int32)))')
     self.assertEqual(data.compact_representation(), '/tmp/mydata')
     data_proto = data.proto
     self.assertEqual(type_serialization.deserialize_type(data_proto.type),
                      data.type_signature)
     self.assertEqual(data_proto.WhichOneof('computation'), 'data')
     self.assertEqual(data_proto.data.uri, data.uri)
     self._serialize_deserialize_roundtrip_test(data)
Exemplo n.º 30
0
async def embed_tf_scalar_constant(executor, type_spec, value):
  """Embeds a constant `value` of TFF type `type_spec` in `executor`.

  The constant is built as a parameterless TensorFlow computation, which is
  embedded in `executor` and then invoked. The returned value therefore holds
  the constant itself rather than the computation that produces it.

  Args:
    executor: An instance of `tff.framework.Executor`.
    type_spec: An instance of `tff.Type`.
    value: A scalar value.

  Returns:
    An instance of `tff.framework.ExecutorValue` containing an embedded value.

  Raises:
    TypeError: If `executor` is not an instance of `executor_base.Executor`.
  """
  py_typecheck.check_type(executor, executor_base.Executor)
  proto = tensorflow_computation_factory.create_constant(value, type_spec)
  type_signature = type_serialization.deserialize_type(proto.type)
  # Embed the constant-producing computation, then call it with no argument.
  result = await executor.create_value(proto, type_signature)
  return await executor.create_call(result)