예제 #1
0
  def from_proto(cls, computation_proto):
    """Deserializes 'computation_proto' into a building block instance.

    The kind of computation stored in the proto (its 'computation' oneof) is
    used to look up a deserializer in `cls._deserializer_dict`. The resulting
    building block's type signature is then checked for equivalence with the
    type declared in the proto.

    Args:
      computation_proto: An instance of pb.Computation.

    Returns:
      An instance of a class that implements 'ComputationBuildingBlock' and
      that holds the logic deserialized from 'computation_proto'.

    Raises:
      NotImplementedError: if no deserializer is registered for the kind of
        computation stored in 'computation_proto'.
      ValueError: if deserialization failed due to the argument being invalid,
        including when the type derived from the deserialized structure is not
        equivalent to the type declared in 'computation_proto'.
    """
    py_typecheck.check_type(computation_proto, pb.Computation)
    which_kind = computation_proto.WhichOneof('computation')
    deserialize_fn = cls._deserializer_dict.get(which_kind)
    if deserialize_fn is None:
      raise NotImplementedError(
          'Deserialization for computations of type {} has not been '
          'implemented yet.'.format(which_kind))
    building_block = deserialize_fn(computation_proto)
    declared_type = type_serialization.deserialize_type(computation_proto.type)
    derived_type = building_block.type_signature
    if not type_utils.are_equivalent_types(derived_type, declared_type):
      raise ValueError(
          'The type {} derived from the computation structure does not '
          'match the type {} declared in its signature'.format(
              str(derived_type), str(declared_type)))
    return building_block
 def test_basic_functionality_of_tuple_class(self):
     """Checks construction, repr, element access and serialization of Tuple."""
     ref_foo = computation_building_blocks.Reference('foo', tf.int32)
     ref_bar = computation_building_blocks.Reference('bar', tf.bool)
     tup = computation_building_blocks.Tuple([ref_foo, ('y', ref_bar)])
     # An empty-string element name is expected to be rejected.
     with self.assertRaises(ValueError):
         _ = computation_building_blocks.Tuple([('', ref_bar)])
     self.assertIsInstance(tup, anonymous_tuple.AnonymousTuple)
     self.assertEqual(str(tup.type_signature), '<int32,y=bool>')
     expected_repr = (
         'Tuple([(None, Reference(\'foo\', TensorType(tf.int32))), (\'y\', '
         'Reference(\'bar\', TensorType(tf.bool)))])')
     self.assertEqual(repr(tup), expected_repr)
     self.assertEqual(tup.tff_repr, '<foo,y=bar>')
     # `dir` lists only the named element.
     self.assertEqual(dir(tup), ['y'])
     self.assertIs(tup.y, ref_bar)
     self.assertLen(tup, 2)
     self.assertIs(tup[0], ref_foo)
     self.assertIs(tup[1], ref_bar)
     joined_reprs = ','.join(elem.tff_repr for elem in tup)
     self.assertEqual(joined_reprs, 'foo,bar')
     tup_proto = tup.proto
     self.assertEqual(type_serialization.deserialize_type(tup_proto.type),
                      tup.type_signature)
     self.assertEqual(tup_proto.WhichOneof('computation'), 'tuple')
     proto_names = [elem.name for elem in tup_proto.tuple.element]
     self.assertEqual(proto_names, ['', 'y'])
     self._serialize_deserialize_roundtrip_test(tup)
 def test_basic_functionality_of_lambda_class(self):
     """Checks construction, repr and serialization of Lambda."""
     param_name = 'arg'
     param_type = [('f', computation_types.FunctionType(tf.int32, tf.int32)),
                   ('x', tf.int32)]
     param_ref = computation_building_blocks.Reference(param_name, param_type)
     select_f = computation_building_blocks.Selection(param_ref, name='f')
     select_x = computation_building_blocks.Selection(param_ref, name='x')
     # The body computes arg.f(arg.f(arg.x)).
     inner_call = computation_building_blocks.Call(select_f, select_x)
     outer_call = computation_building_blocks.Call(select_f, inner_call)
     fn = computation_building_blocks.Lambda(param_name, param_type, outer_call)
     self.assertEqual(str(fn.type_signature),
                      '(<f=(int32 -> int32),x=int32> -> int32)')
     self.assertEqual(fn.parameter_name, param_name)
     self.assertEqual(str(fn.parameter_type), '<f=(int32 -> int32),x=int32>')
     self.assertEqual(fn.result.tff_repr, 'arg.f(arg.f(arg.x))')
     param_type_repr = (
         'NamedTupleType(['
         '(\'f\', FunctionType(TensorType(tf.int32), TensorType(tf.int32))), '
         '(\'x\', TensorType(tf.int32))])')
     expected_repr = (
         'Lambda(\'arg\', {0}, '
         'Call(Selection(Reference(\'arg\', {0}), name=\'f\'), '
         'Call(Selection(Reference(\'arg\', {0}), name=\'f\'), '
         'Selection(Reference(\'arg\', {0}), name=\'x\'))))').format(
             param_type_repr)
     self.assertEqual(repr(fn), expected_repr)
     self.assertEqual(fn.tff_repr, '(arg -> arg.f(arg.f(arg.x)))')
     fn_proto = fn.proto
     self.assertEqual(type_serialization.deserialize_type(fn_proto.type),
                      fn.type_signature)
     self.assertEqual(fn_proto.WhichOneof('computation'), 'lambda')
     # `lambda` is a Python keyword, so the proto field is read via getattr.
     lambda_proto = getattr(fn_proto, 'lambda')
     self.assertEqual(lambda_proto.parameter_name, param_name)
     self.assertEqual(str(lambda_proto.result), str(fn.result.proto))
     self._serialize_deserialize_roundtrip_test(fn)
예제 #4
0
def serialize_value(value, type_spec=None):
    """Serializes a value into `executor_pb2.Value`.

    Args:
      value: A value to be serialized.
      type_spec: Optional type spec, a `tff.Type` or something convertible to it.

    Returns:
      An instance of `executor_pb2.Value` with the serialized content of `value`.

    Raises:
      TypeError: If the arguments are of the wrong types.
      ValueError: If the value is malformed, or if the combination of its
        Python type and `type_spec` is not supported for serialization.
    """
    type_spec = computation_types.to_type(type_spec)
    if isinstance(value, computation_pb2.Computation):
        if type_spec is not None:
            # The reconciled type is not used here; the call only serves as a
            # compatibility check between the proto's type and `type_spec`
            # (assumed to raise on mismatch -- confirm in type_utils).
            type_utils.reconcile_value_type_with_type_spec(
                type_serialization.deserialize_type(value.type), type_spec)
        return executor_pb2.Value(computation=value)
    elif isinstance(value, computation_impl.ComputationImpl):
        # Unwrap to the underlying proto and recurse with the reconciled type.
        return serialize_value(
            computation_impl.ComputationImpl.get_proto(value),
            type_utils.reconcile_value_with_type_spec(value, type_spec))
    elif isinstance(type_spec, computation_types.TensorType):
        return serialize_tensor_value(value, type_spec)
    else:
        raise ValueError(
            'Unable to serialize value with Python type {} and {} TFF type.'.
            format(str(py_typecheck.type_string(type(value))),
                   str(type_spec) if type_spec is not None else 'unknown'))
예제 #5
0
def deserialize_value(value_proto):
    """Deserializes a value (of any type) from `executor_pb2.Value`.

    Args:
      value_proto: An instance of `executor_pb2.Value`.

    Returns:
      A tuple `(value, type_spec)`, where `value` is a deserialized
      representation of the transmitted value (e.g., Numpy array, or a
      `pb.Computation` instance), and `type_spec` is an instance of
      `tff.Type` that represents its type (for computations, this is the
      type declared in the computation proto).

    Raises:
      TypeError: If the arguments are of the wrong types.
      ValueError: If the value is malformed, or its kind is not supported.
    """
    py_typecheck.check_type(value_proto, executor_pb2.Value)
    which_value = value_proto.WhichOneof('value')
    if which_value == 'tensor':
        return deserialize_tensor_value(value_proto)
    elif which_value == 'computation':
        # Computations are passed through as protos; only the type is decoded.
        return (value_proto.computation,
                type_serialization.deserialize_type(
                    value_proto.computation.type))
    else:
        raise ValueError(
            'Unable to deserialize a value of type {}.'.format(which_value))
예제 #6
0
def deserialize_sequence_value(sequence_value_proto):
    """Deserializes a `tf.data.Dataset`.

    Args:
      sequence_value_proto: `Sequence` protocol buffer message.

    Returns:
      A tuple of `(tf.data.Dataset, tff.Type)`, where the type is a
      `computation_types.SequenceType` over the serialized element type.

    Raises:
      NotImplementedError: If the sequence is encoded in any form other than
        `zipped_saved_model`.
    """
    py_typecheck.check_type(sequence_value_proto, executor_pb2.Value.Sequence)

    which_value = sequence_value_proto.WhichOneof('value')
    if which_value == 'zipped_saved_model':
        ds = tensorflow_serialization.deserialize_dataset(
            sequence_value_proto.zipped_saved_model)
    else:
        raise NotImplementedError(
            'Deserializing Sequences encoded as {!s} has not been implemented'.
            format(which_value))

    element_type = type_serialization.deserialize_type(
        sequence_value_proto.element_type)

    # If a serialized dataset had elements of nested structures of tensors (e.g.
    # `dict`, `OrderedDict`), the deserialized dataset will return `dict`,
    # `tuple`, or `namedtuple` (`collections.OrderedDict` is lost in the
    # conversion).
    #
    # Since the dataset will only be used inside TFF, we wrap the dictionary
    # coming from TF in an `OrderedDict` when necessary (a type that both TF and
    # TFF understand), using the field order recorded in the TFF element type
    # at serialization time.
    ds = tensorflow_utils.coerce_dataset_elements_to_tff_type_spec(
        ds, element_type)

    return ds, computation_types.SequenceType(element=element_type)
 def test_basic_functionality_of_block_class(self):
     """Checks construction, repr and serialization of Block."""
     x = computation_building_blocks.Block([
         ('x',
          computation_building_blocks.Reference('arg',
                                                (tf.int32, tf.int32))),
         ('y',
          computation_building_blocks.Selection(
              computation_building_blocks.Reference('x',
                                                    (tf.int32, tf.int32)),
              index=0))
     ], computation_building_blocks.Reference('y', tf.int32))
     self.assertEqual(str(x.type_signature), 'int32')
     self.assertEqual([(k, v.tff_repr) for k, v in x.locals],
                      [('x', 'arg'), ('y', 'x[0]')])
     self.assertEqual(x.result.tff_repr, 'y')
     self.assertEqual(
         repr(x), 'Block([(\'x\', Reference(\'arg\', '
         'NamedTupleType([TensorType(tf.int32), TensorType(tf.int32)]))), '
         '(\'y\', Selection(Reference(\'x\', '
         'NamedTupleType([TensorType(tf.int32), TensorType(tf.int32)])), '
         'index=0))], '
         'Reference(\'y\', TensorType(tf.int32)))')
     self.assertEqual(x.tff_repr, '(let x=arg,y=x[0] in y)')
     x_proto = x.proto
     self.assertEqual(type_serialization.deserialize_type(x_proto.type),
                      x.type_signature)
     self.assertEqual(x_proto.WhichOneof('computation'), 'block')
     self.assertEqual(str(x_proto.block.result), str(x.result.proto))
     # Every local must be serialized, in order, under its original name.
     self.assertLen(x_proto.block.local, len(x.locals))
     for loc_proto, (loc_name, loc_value) in zip(x_proto.block.local,
                                                 x.locals):
         self.assertEqual(loc_proto.name, loc_name)
         self.assertEqual(str(loc_proto.value), str(loc_value.proto))
     # The round trip covers the whole block, so it runs exactly once.
     self._serialize_deserialize_roundtrip_test(x)
 def test_serialize_tensorflow_with_structured_type_signature(self):
     """Checks that the extra type spec records the Python container types."""
     batch_container = collections.namedtuple('BatchType', ['x', 'y'])
     output_container = collections.namedtuple('OutputType', ['A', 'B'])
     expected_signature = (
         '(<x=int32,y=float32[2]> -> <A=float32,B=float32[2]>)')
     comp, extra_type_spec = (
         tensorflow_serialization.serialize_py_fn_as_tf_computation(
             lambda z: output_container(
                 2.0 * tf.cast(z.x, tf.float32), 3.0 * z.y),
             batch_container(tf.int32, (tf.float32, [2])),
             context_stack_impl.context_stack))
     self.assertEqual(
         str(type_serialization.deserialize_type(comp.type)),
         expected_signature)
     self.assertEqual(comp.WhichOneof('computation'), 'tensorflow')
     self.assertEqual(str(extra_type_spec), expected_signature)
     container_cls = computation_types.NamedTupleTypeWithPyContainerType
     param_spec = extra_type_spec.parameter
     self.assertIsInstance(param_spec, container_cls)
     self.assertIs(container_cls.get_container_type(param_spec),
                   batch_container)
     result_spec = extra_type_spec.result
     self.assertIsInstance(result_spec, container_cls)
     self.assertIs(container_cls.get_container_type(result_spec),
                   output_container)
 def test_basic_functionality_of_call_class(self):
     """Checks construction, argument validation and serialization of Call."""
     fn_ref = computation_building_blocks.Reference(
         'foo', computation_types.FunctionType(tf.int32, tf.bool))
     arg_ref = computation_building_blocks.Reference('bar', tf.int32)
     call = computation_building_blocks.Call(fn_ref, arg_ref)
     self.assertEqual(str(call.type_signature), 'bool')
     self.assertIs(call.function, fn_ref)
     self.assertIs(call.argument, arg_ref)
     expected_repr = (
         'Call(Reference(\'foo\', '
         'FunctionType(TensorType(tf.int32), TensorType(tf.bool))), '
         'Reference(\'bar\', TensorType(tf.int32)))')
     self.assertEqual(repr(call), expected_repr)
     self.assertEqual(call.tff_repr, 'foo(bar)')
     # Omitting the argument of a function that takes one is rejected.
     with self.assertRaises(TypeError):
         computation_building_blocks.Call(fn_ref)
     # A float32 argument does not match the int32 parameter.
     mismatched_arg = computation_building_blocks.Reference('bak', tf.float32)
     with self.assertRaises(TypeError):
         computation_building_blocks.Call(fn_ref, mismatched_arg)
     call_proto = call.proto
     self.assertEqual(type_serialization.deserialize_type(call_proto.type),
                      call.type_signature)
     self.assertEqual(call_proto.WhichOneof('computation'), 'call')
     self.assertEqual(str(call_proto.call.function), str(fn_ref.proto))
     self.assertEqual(str(call_proto.call.argument), str(arg_ref.proto))
     self._serialize_deserialize_roundtrip_test(call)
예제 #10
0
 def test_intrinsic_class_succeeds_simple_federated_map(self):
   """Checks a federated_map intrinsic with a fully concrete type signature."""
   mapped_fn_type = computation_types.FunctionType(tf.int32, tf.float32)
   clients_arg_type = computation_types.FederatedType(
       mapped_fn_type.parameter, placements.CLIENTS)
   clients_result_type = computation_types.FederatedType(
       mapped_fn_type.result, placements.CLIENTS)
   map_type = computation_types.FunctionType(
       [mapped_fn_type, clients_arg_type], clients_result_type)
   intrinsic = building_blocks.Intrinsic(intrinsic_defs.FEDERATED_MAP.uri,
                                         map_type)
   self.assertIsInstance(intrinsic, building_blocks.Intrinsic)
   self.assertEqual(
       str(intrinsic.type_signature),
       '(<(int32 -> float32),{int32}@CLIENTS> -> {float32}@CLIENTS)')
   self.assertEqual(intrinsic.uri, 'federated_map')
   self.assertEqual(intrinsic.compact_representation(), 'federated_map')
   intrinsic_proto = intrinsic.proto
   self.assertEqual(
       type_serialization.deserialize_type(intrinsic_proto.type),
       intrinsic.type_signature)
   self.assertEqual(intrinsic_proto.WhichOneof('computation'), 'intrinsic')
   self.assertEqual(intrinsic_proto.intrinsic.uri, intrinsic.uri)
   self._serialize_deserialize_roundtrip_test(intrinsic)
예제 #11
0
 def from_proto(cls, computation_proto):
     """Builds an instance from a computation proto holding a placement.

     Args:
       computation_proto: An instance of pb.Computation whose 'computation'
         oneof is expected to be 'placement' and whose declared type is
         expected to be a `computation_types.PlacementType`.

     Returns:
       An instance of `cls` wrapping the placement literal that corresponds to
       the URI stored in the proto.
     """
     _check_computation_oneof(computation_proto, 'placement')
     declared_type = type_serialization.deserialize_type(computation_proto.type)
     py_typecheck.check_type(declared_type, computation_types.PlacementType)
     placement_uri = str(computation_proto.placement.uri)
     literal = placement_literals.uri_to_placement_literal(placement_uri)
     return cls(literal)
예제 #12
0
 def from_proto(cls, computation_proto):
     """Builds an instance from a computation proto holding a lambda.

     Args:
       computation_proto: An instance of pb.Computation whose 'computation'
         oneof is expected to be 'lambda'.

     Returns:
       An instance of `cls` built from the lambda's parameter name, the
       parameter type of the proto's declared function type, and the
       recursively deserialized result.
     """
     _check_computation_oneof(computation_proto, 'lambda')
     # `lambda` is a reserved word in Python, so the field is read via getattr.
     lambda_proto = getattr(computation_proto, 'lambda')
     parameter_name = str(lambda_proto.parameter_name)
     parameter_type = type_serialization.deserialize_type(
         computation_proto.type.function.parameter)
     result = ComputationBuildingBlock.from_proto(lambda_proto.result)
     return cls(parameter_name, parameter_type, result)
예제 #13
0
  def __init__(self, value, scope=None, type_spec=None):
    """Embeds `value` in a lambda executor and derives its type signature.

    Four internal representations of `value` are supported:

    * An `executor_value_base.ExecutorValue`: a value (functional or not)
      already embedded in the target executor.

    * An unprocessed `pb.Computation`: a function that has yet to be invoked.
      Only functional values take this form; non-functional constructs are
      expected to be processed on the fly.

    * A Python coroutine callable taking a single `LambdaExecutorValue` (or
      `None`) and returning a `LambdaExecutorValue`. Its type signature is
      always functional.

    * A single-level `anonymous_tuple.AnonymousTuple` whose elements are
      instances of this class (in any of the forms listed here).

    Args:
      value: The internal representation, in one of the forms above.
      scope: Optional `LambdaExecutorScope`. Permitted only when `value` is an
        unprocessed `pb.Computation`; must be `None` otherwise, since a scope
        has no meaning for the other forms.
      type_spec: A `computation_types.FunctionType`, required when `value` is
        a callable; must be `None` for every other form, whose type is derived
        from `value` itself.
    """
    if isinstance(value, executor_value_base.ExecutorValue):
      py_typecheck.check_none(scope)
      py_typecheck.check_none(type_spec)
      type_spec = value.type_signature
    elif isinstance(value, pb.Computation):
      if scope is not None:
        py_typecheck.check_type(scope, LambdaExecutorScope)
      py_typecheck.check_none(type_spec)
      proto_type = type_serialization.deserialize_type(value.type)
      type_spec = type_utils.get_function_type(proto_type)
    elif callable(value):
      py_typecheck.check_none(scope)
      py_typecheck.check_type(type_spec, computation_types.FunctionType)
    else:
      py_typecheck.check_type(value, anonymous_tuple.AnonymousTuple)
      py_typecheck.check_none(scope)
      py_typecheck.check_none(type_spec)
      member_types = []
      for member_name, member in anonymous_tuple.to_elements(value):
        py_typecheck.check_type(member, LambdaExecutorValue)
        # Unnamed members contribute a bare type; named ones a (name, type).
        if member_name is None:
          member_types.append(member.type_signature)
        else:
          member_types.append((member_name, member.type_signature))
      type_spec = computation_types.NamedTupleType(member_types)
    self._value = value
    self._scope = scope
    self._type_signature = type_spec
예제 #14
0
 def test_serialize_tensorflow_with_no_parameter(self):
   """Checks serialization of a no-argument TF function returning 99."""
   comp = tensorflow_serialization.serialize_py_fn_as_tf_computation(
       lambda: tf.constant(99), None, context_stack_impl.context_stack)
   self.assertEqual(
       str(type_serialization.deserialize_type(comp.type)), '( -> int32)')
   self.assertEqual(comp.WhichOneof('computation'), 'tensorflow')
   # Use the session as a context manager so its resources are released even
   # if `run` raises, instead of leaking an unclosed session.
   with tf.Session() as sess:
     results = sess.run(
         tf.import_graph_def(comp.tensorflow.graph_def, None,
                             [comp.tensorflow.result.tensor.tensor_name]))
   self.assertEqual(results, [99])
 def test_basic_functionality_of_placement_class(self):
     """Checks construction, repr and serialization of Placement."""
     clients = computation_building_blocks.Placement(placements.CLIENTS)
     self.assertEqual(str(clients.type_signature), 'placement')
     self.assertEqual(clients.uri, 'clients')
     self.assertEqual(repr(clients), 'Placement(\'clients\')')
     self.assertEqual(clients.tff_repr, 'CLIENTS')
     clients_proto = clients.proto
     self.assertEqual(
         type_serialization.deserialize_type(clients_proto.type),
         clients.type_signature)
     self.assertEqual(clients_proto.WhichOneof('computation'), 'placement')
     self.assertEqual(clients_proto.placement.uri, clients.uri)
     self._serialize_deserialize_roundtrip_test(clients)
 def test_basic_functionality_of_reference_class(self):
     """Checks construction, repr and serialization of Reference."""
     ref = computation_building_blocks.Reference('foo', tf.int32)
     self.assertEqual(ref.name, 'foo')
     self.assertEqual(str(ref.type_signature), 'int32')
     self.assertEqual(repr(ref), 'Reference(\'foo\', TensorType(tf.int32))')
     self.assertEqual(ref.tff_repr, 'foo')
     ref_proto = ref.proto
     self.assertEqual(
         type_serialization.deserialize_type(ref_proto.type),
         ref.type_signature)
     self.assertEqual(ref_proto.WhichOneof('computation'), 'reference')
     self.assertEqual(ref_proto.reference.name, ref.name)
     self._serialize_deserialize_roundtrip_test(ref)
 def test_serialize_tensorflow_with_simple_add_three_lambda(self):
   """Checks serialization of `x + 3` over int32 and runs it on 1000."""
   comp = tensorflow_serialization.serialize_py_fn_as_tf_computation(
       lambda x: x + 3, tf.int32, context_stack_impl.context_stack)
   self.assertEqual(
       str(type_serialization.deserialize_type(comp.type)), '(int32 -> int32)')
   self.assertEqual(comp.WhichOneof('computation'), 'tensorflow')
   parameter = tf.constant(1000)
   # Use the session as a context manager so its resources are released even
   # if `run` raises, instead of leaking an unclosed session.
   with tf.Session() as sess:
     results = sess.run(
         tf.import_graph_def(
             serialization_utils.unpack_graph_def(comp.tensorflow.graph_def),
             {comp.tensorflow.parameter.tensor.tensor_name: parameter},
             [comp.tensorflow.result.tensor.tensor_name]))
   self.assertEqual(results, [1003])
예제 #18
0
 def test_basic_intrinsic_functionality_plus_canonical_typecheck(self):
   """Checks a generic_plus intrinsic typed as (<int32,int32> -> int32)."""
   plus_type = computation_types.FunctionType([tf.int32, tf.int32], tf.int32)
   plus = building_blocks.Intrinsic('generic_plus', plus_type)
   self.assertEqual(str(plus.type_signature), '(<int32,int32> -> int32)')
   self.assertEqual(plus.uri, 'generic_plus')
   self.assertEqual(plus.compact_representation(), 'generic_plus')
   plus_proto = plus.proto
   self.assertEqual(
       type_serialization.deserialize_type(plus_proto.type),
       plus.type_signature)
   self.assertEqual(plus_proto.WhichOneof('computation'), 'intrinsic')
   self.assertEqual(plus_proto.intrinsic.uri, plus.uri)
   self._serialize_deserialize_roundtrip_test(plus)
예제 #19
0
 def test_basic_functionality_of_data_class(self):
   """Checks construction, repr and serialization of Data."""
   int_sequence = computation_types.SequenceType(tf.int32)
   data = building_blocks.Data('/tmp/mydata', int_sequence)
   self.assertEqual(str(data.type_signature), 'int32*')
   self.assertEqual(data.uri, '/tmp/mydata')
   expected_repr = 'Data(\'/tmp/mydata\', SequenceType(TensorType(tf.int32)))'
   self.assertEqual(repr(data), expected_repr)
   self.assertEqual(data.compact_representation(), '/tmp/mydata')
   data_proto = data.proto
   self.assertEqual(
       type_serialization.deserialize_type(data_proto.type),
       data.type_signature)
   self.assertEqual(data_proto.WhichOneof('computation'), 'data')
   self.assertEqual(data_proto.data.uri, data.uri)
   self._serialize_deserialize_roundtrip_test(data)
 def test_basic_functionality_of_intrinsic_class(self):
     """Checks construction, repr and serialization of Intrinsic."""
     add_one_type = computation_types.FunctionType(tf.int32, tf.int32)
     add_one = computation_building_blocks.Intrinsic('add_one', add_one_type)
     self.assertEqual(str(add_one.type_signature), '(int32 -> int32)')
     self.assertEqual(add_one.uri, 'add_one')
     expected_repr = (
         'Intrinsic(\'add_one\', '
         'FunctionType(TensorType(tf.int32), TensorType(tf.int32)))')
     self.assertEqual(repr(add_one), expected_repr)
     self.assertEqual(add_one.tff_repr, 'add_one')
     add_one_proto = add_one.proto
     self.assertEqual(
         type_serialization.deserialize_type(add_one_proto.type),
         add_one.type_signature)
     self.assertEqual(add_one_proto.WhichOneof('computation'), 'intrinsic')
     self.assertEqual(add_one_proto.intrinsic.uri, add_one.uri)
     self._serialize_deserialize_roundtrip_test(add_one)
예제 #21
0
    def _serialize_deserialize_roundtrip_test(self, type_list):
        """Asserts that each type survives a serialize/deserialize round trip.

        Each entry is converted to a type, serialized, deserialized and
        serialized again. The two types must have equal reprs and be
        equivalent, and the two protos must have equal reprs.

        Args:
          type_list: A list of instances of computation_types.Type or things
            convertible to it.
        """
        for candidate in type_list:
            first_type = computation_types.to_type(candidate)
            first_proto = type_serialization.serialize_type(first_type)
            second_type = type_serialization.deserialize_type(first_proto)
            second_proto = type_serialization.serialize_type(second_type)
            self.assertEqual(repr(first_type), repr(second_type))
            self.assertEqual(repr(first_proto), repr(second_proto))
            equivalent = type_utils.are_equivalent_types(first_type,
                                                         second_type)
            self.assertTrue(equivalent)
def wrap_graph_parameter_as_tuple(comp, name=None):
    """Wraps the parameter of `comp` in a single-element tuple binding.

    Intended as a preprocessing step for `pad_graph_inputs_to_match_type`, so
    that function can assume its argument `comp` always has a tuple binding
    instead of having to deal with an unwrapped tensor or sequence binding.

    Args:
      comp: Instance of `computation_building_blocks.CompiledComputation` whose
        parameter we wish to wrap in a tuple binding.
      name: Optional string argument, the name to assign to the element type in
        the constructed tuple. Defaults to `None`.

    Returns:
      A transformed version of comp representing exactly the same computation,
      but accepting a tuple containing one element--the parameter of `comp`.

    Raises:
      TypeError: If `comp` is not a
        `computation_building_blocks.CompiledComputation`.
    """
    py_typecheck.check_type(comp,
                            computation_building_blocks.CompiledComputation)
    if name is not None:
        py_typecheck.check_type(name, six.string_types)
    original_proto = comp.proto
    original_type = type_serialization.deserialize_type(original_proto.type)
    tf_proto = original_proto.tensorflow

    # The graph itself is reused unchanged; only the parameter binding and the
    # declared function type are rewritten to describe a one-element tuple.
    wrapped_binding = pb.TensorFlow.Binding(
        tuple=pb.TensorFlow.NamedTupleBinding(element=[tf_proto.parameter]))
    wrapped_type = computation_types.FunctionType(
        [(name, original_type.parameter)], original_type.result)

    wrapped_proto = pb.Computation(
        type=type_serialization.serialize_type(wrapped_type),
        tensorflow=pb.TensorFlow(graph_def=tf_proto.graph_def,
                                 initialize_op=tf_proto.initialize_op,
                                 parameter=wrapped_binding,
                                 result=tf_proto.result))
    return computation_building_blocks.CompiledComputation(wrapped_proto)
 def test_basic_functionality_of_selection_class(self):
     """Checks name- and index-based Selection, validation and serialization."""
     source = computation_building_blocks.Reference('foo', [('bar', tf.int32),
                                                           ('baz', tf.bool)])
     compact = computation_building_blocks.compact_representation
     repr_prefix = (
         'Selection(Reference(\'foo\', NamedTupleType(['
         '(\'bar\', TensorType(tf.int32)), (\'baz\', TensorType(tf.bool))]))')

     by_name_bar = computation_building_blocks.Selection(source, name='bar')
     self.assertEqual(by_name_bar.name, 'bar')
     self.assertEqual(by_name_bar.index, None)
     self.assertEqual(str(by_name_bar.type_signature), 'int32')
     self.assertEqual(repr(by_name_bar), repr_prefix + ', name=\'bar\')')
     self.assertEqual(compact(by_name_bar), 'foo.bar')
     by_name_baz = computation_building_blocks.Selection(source, name='baz')
     self.assertEqual(str(by_name_baz.type_signature), 'bool')
     self.assertEqual(compact(by_name_baz), 'foo.baz')
     # A name that is not an element of the tuple is rejected.
     with self.assertRaises(ValueError):
         _ = computation_building_blocks.Selection(source, name='bak')

     by_index_0 = computation_building_blocks.Selection(source, index=0)
     self.assertEqual(by_index_0.name, None)
     self.assertEqual(by_index_0.index, 0)
     self.assertEqual(str(by_index_0.type_signature), 'int32')
     self.assertEqual(repr(by_index_0), repr_prefix + ', index=0)')
     self.assertEqual(compact(by_index_0), 'foo[0]')
     by_index_1 = computation_building_blocks.Selection(source, index=1)
     self.assertEqual(str(by_index_1.type_signature), 'bool')
     self.assertEqual(compact(by_index_1), 'foo[1]')
     # Out-of-range indices, including negative ones, are rejected.
     for bad_index in (2, -1):
         with self.assertRaises(ValueError):
             _ = computation_building_blocks.Selection(source, index=bad_index)

     bar_proto = by_name_bar.proto
     self.assertEqual(type_serialization.deserialize_type(bar_proto.type),
                      by_name_bar.type_signature)
     self.assertEqual(bar_proto.WhichOneof('computation'), 'selection')
     self.assertEqual(str(bar_proto.selection.source), str(source.proto))
     self.assertEqual(bar_proto.selection.name, 'bar')
     for selection in (by_name_bar, by_name_baz, by_index_0, by_index_1):
         self._serialize_deserialize_roundtrip_test(selection)
예제 #24
0
def serialize_value(value, type_spec=None):
    """Serializes a value into `executor_pb2.Value`.

    Args:
      value: A value to be serialized.
      type_spec: Optional type spec, a `tff.Type` or something convertible to it.

    Returns:
      A tuple `(value_proto, ret_type_spec)` where `value_proto` is an instance
      of `executor_pb2.Value` with the serialized content of `value`, and the
      returned `ret_type_spec` is an instance of `tff.Type` that represents the
      TFF type of the serialized value.

    Raises:
      TypeError: If the arguments are of the wrong types.
      ValueError: If the value is malformed (e.g. a container whose number of
        elements differs from that of a `NamedTupleType` `type_spec`), or if
        the combination of its Python type and `type_spec` is not supported.
    """
    type_spec = computation_types.to_type(type_spec)
    if isinstance(value, computation_pb2.Computation):
        type_spec = type_utils.reconcile_value_type_with_type_spec(
            type_serialization.deserialize_type(value.type), type_spec)
        return executor_pb2.Value(computation=value), type_spec
    elif isinstance(value, computation_impl.ComputationImpl):
        # Unwrap to the underlying proto and recurse with the reconciled type.
        return serialize_value(
            computation_impl.ComputationImpl.get_proto(value),
            type_utils.reconcile_value_with_type_spec(value, type_spec))
    elif isinstance(type_spec, computation_types.TensorType):
        return serialize_tensor_value(value, type_spec)
    elif isinstance(type_spec, computation_types.NamedTupleType):
        type_elements = anonymous_tuple.to_elements(type_spec)
        val_elements = anonymous_tuple.to_elements(
            anonymous_tuple.from_container(value))
        # `zip` below would silently drop trailing elements on a length
        # mismatch, yielding a proto that disagrees with the returned
        # `type_spec`; reject such values up front instead.
        if len(type_elements) != len(val_elements):
            raise ValueError(
                'Cannot serialize a value with {} elements as TFF type {}, '
                'which has {} elements.'.format(
                    len(val_elements), str(type_spec), len(type_elements)))
        tup_elems = []
        for (e_name, e_type), (_, e_val) in zip(type_elements, val_elements):
            e_proto, _ = serialize_value(e_val, e_type)
            # Empty or missing names are passed as `None` rather than ''.
            tup_elems.append(
                executor_pb2.Value.Tuple.Element(
                    name=e_name if e_name else None, value=e_proto))
        result_proto = (executor_pb2.Value(tuple=executor_pb2.Value.Tuple(
            element=tup_elems)))
        return result_proto, type_spec
    else:
        raise ValueError(
            'Unable to serialize value with Python type {} and {} TFF type.'.
            format(str(py_typecheck.type_string(type(value))),
                   str(type_spec) if type_spec is not None else 'unknown'))
    def test_serialize_tensorflow_with_data_set_sum_lambda(self):
        """Checks serialization of a TF function summing an int64 dataset."""

        def _legacy_dataset_reducer_example(ds):
            return ds.reduce(np.int64(0), lambda x, y: x + y)

        comp = tensorflow_serialization.serialize_py_func_as_tf_computation(
            _legacy_dataset_reducer_example,
            computation_types.SequenceType(tf.int64),
            context_stack_impl.context_stack)
        self.assertEqual(str(type_serialization.deserialize_type(comp.type)),
                         '(int64* -> int64)')
        self.assertEqual(comp.WhichOneof('computation'), 'tensorflow')
        parameter = tf.data.Dataset.range(5)
        # The dataset is fed through an iterator string handle. The session is
        # used as a context manager so its resources are released even if
        # `run` raises, instead of leaking an unclosed session.
        with tf.Session() as sess:
            results = sess.run(
                tf.import_graph_def(
                    comp.tensorflow.graph_def, {
                        comp.tensorflow.parameter.sequence.iterator_string_handle_name:
                        (parameter.make_one_shot_iterator().string_handle())
                    }, [comp.tensorflow.result.tensor.tensor_name]))
        # 0 + 1 + 2 + 3 + 4 == 10.
        self.assertEqual(results, [10])
예제 #26
0
  def __init__(self, computation_proto, context_stack, annotated_type=None):
    """Creates a ComputationImpl that wraps `computation_proto`.

    Args:
      computation_proto: The pb.Computation protocol buffer that represents
        the computation.
      context_stack: The context stack to use.
      annotated_type: Optional `computation_types.Type` carrying additional
        annotations; when given, it is used in place of the type deserialized
        from `computation_proto.type`.

    Raises:
      TypeError: if `annotated_type` is not `None` and is not compatible with
        `computation_proto.type`.
    """
    py_typecheck.check_type(computation_proto, pb.Computation)
    py_typecheck.check_type(context_stack, context_stack_base.ContextStack)
    proto_type = type_serialization.deserialize_type(computation_proto.type)
    py_typecheck.check_type(proto_type, computation_types.Type)

    if annotated_type is None:
      type_spec = proto_type
    else:
      py_typecheck.check_type(annotated_type, computation_types.Type)
      # The annotations live in a NamedTupleTypeWithPyContainerType subclass
      # that does not override __eq__, so a compatible annotated type still
      # compares equal to the plain deserialized type.
      if proto_type != annotated_type:
        raise TypeError(
            'annotated_type not compatible with computation_proto.type\n'
            'computation_proto.type: {!s}\n'
            'annotated_type: {!s}'.format(proto_type, annotated_type)
        )
      type_spec = annotated_type

    type_utils.check_well_formed(type_spec)

    # The underlying framework has no notion of a no-argument lambda, yet every
    # Python-facing computation must look like a function to be invoked, so a
    # non-functional type is presented as a function taking no argument.
    if not isinstance(type_spec, computation_types.FunctionType):
      type_spec = computation_types.FunctionType(None, type_spec)

    super(ComputationImpl, self).__init__(type_spec, context_stack)
    self._computation_proto = computation_proto
    def test_serialize_tensorflow_with_data_set_sum_lambda(self):
        """Checks serialization of a TF function summing an int64 dataset."""

        def _legacy_dataset_reducer_example(ds):
            return ds.reduce(np.int64(0), lambda x, y: x + y)

        comp, extra_type_spec = tensorflow_serialization.serialize_py_fn_as_tf_computation(
            _legacy_dataset_reducer_example,
            computation_types.SequenceType(tf.int64),
            context_stack_impl.context_stack)
        self.assertEqual(str(type_serialization.deserialize_type(comp.type)),
                         '(int64* -> int64)')
        self.assertEqual(str(extra_type_spec), '(int64* -> int64)')
        self.assertEqual(comp.WhichOneof('computation'), 'tensorflow')
        parameter = tf.data.Dataset.range(5)
        # The dataset is fed as a variant tensor. The session is used as a
        # context manager so its resources are released even if `run` raises,
        # instead of leaking an unclosed session.
        with tf.compat.v1.Session() as sess:
            results = sess.run(
                tf.import_graph_def(
                    serialization_utils.unpack_graph_def(
                        comp.tensorflow.graph_def), {
                            comp.tensorflow.parameter.sequence.variant_tensor_name:
                            tf.data.experimental.to_variant(parameter)
                        }, [comp.tensorflow.result.tensor.tensor_name]))
        # 0 + 1 + 2 + 3 + 4 == 10.
        self.assertEqual(results, [10])
예제 #28
0
def deserialize_value(value_proto):
    """Deserializes a value (of any type) from `executor_pb2.Value`.

    Args:
      value_proto: An instance of `executor_pb2.Value`.

    Returns:
      A tuple `(value, type_spec)`, where `value` is a deserialized
      representation of the transmitted value (e.g., Numpy array, or a
      `pb.Computation` instance), and `type_spec` is the TFF type of that
      value. Depending on which `value` field is set, `type_spec` is:
        * `tensor`: the type reported by `deserialize_tensor_value`;
        * `computation`: the type deserialized from the computation proto's
          own `type` field;
        * `tuple`: a `computation_types.NamedTupleType` built from the
          element types (named where the element has a non-empty name);
        * `sequence`: the type reported by `deserialize_sequence_value`.

    Raises:
      TypeError: If the arguments are of the wrong types.
      ValueError: If the value is malformed, including when the `value`
        oneof is unset or holds a kind this function does not handle.
    """
    py_typecheck.check_type(value_proto, executor_pb2.Value)
    which_value = value_proto.WhichOneof('value')
    if which_value == 'tensor':
        return deserialize_tensor_value(value_proto)
    elif which_value == 'computation':
        return (value_proto.computation,
                type_serialization.deserialize_type(
                    value_proto.computation.type))
    elif which_value == 'tuple':
        val_elems = []
        type_elems = []
        for e in value_proto.tuple.element:
            # An empty proto name means the element is unnamed.
            name = e.name if e.name else None
            e_val, e_type = deserialize_value(e.value)
            val_elems.append((name, e_val))
            type_elems.append((name, e_type) if name else e_type)
        return (anonymous_tuple.AnonymousTuple(val_elems),
                computation_types.NamedTupleType(type_elems))
    elif which_value == 'sequence':
        return deserialize_sequence_value(value_proto.sequence)
    else:
        # Reached for unknown kinds and for an unset oneof (`None`).
        raise ValueError(
            'Unable to deserialize a value of type {}.'.format(which_value))
예제 #29
0
    def __init__(self, proto, name=None):
        """Wraps an already-built `pb.Computation` proto.

        Args:
          proto: The `pb.Computation` carrying the computation logic.
          name: Optional string used only for debugging. When omitted (None),
            a name is derived as the lowercase hexadecimal form of the
            Adler-32 checksum of the proto's `repr`, masked to 32 bits.

        Raises:
          TypeError: if the arguments are of the wrong types.
        """
        py_typecheck.check_type(proto, pb.Computation)
        has_explicit_name = name is not None
        if has_explicit_name:
            py_typecheck.check_type(name, six.string_types)
        declared_type = type_serialization.deserialize_type(proto.type)
        super(CompiledComputation, self).__init__(declared_type)
        self._proto = proto
        if has_explicit_name:
            self._name = name
        else:
            # Derive a deterministic debugging name from the proto's text form.
            checksum = zlib.adler32(six.b(repr(proto))) & 0xFFFFFFFF
            self._name = '{:x}'.format(checksum)
예제 #30
0
    def __init__(self, computation_proto, context_stack):
        """Builds a ComputationImpl around a serialized computation.

        Args:
          computation_proto: The `pb.Computation` protocol buffer that
            describes the computation.
          context_stack: The `context_stack_base.ContextStack` to use.
        """
        py_typecheck.check_type(computation_proto, pb.Computation)
        py_typecheck.check_type(context_stack, context_stack_base.ContextStack)
        declared_type = type_serialization.deserialize_type(
            computation_proto.type)
        py_typecheck.check_type(declared_type, computation_types.Type)
        type_utils.check_well_formed(declared_type)

        # The framework that composes computations has no notion of a
        # no-argument lambda, but every Python-side computation must be
        # invocable, so a non-function type is wrapped as a FunctionType
        # whose parameter is None and whose result is the declared type.
        if isinstance(declared_type, computation_types.FunctionType):
            callable_type = declared_type
        else:
            callable_type = computation_types.FunctionType(None, declared_type)

        super(ComputationImpl, self).__init__(callable_type, context_stack)
        self._computation_proto = computation_proto