def from_proto(cls, computation_proto):
  """Deserializes `computation_proto` into a concrete building block.

  The deserializer is chosen by the proto's `computation` oneof. The type
  derived from the rebuilt structure is then cross-checked against the type
  declared in the proto.

  Args:
    computation_proto: An instance of pb.Computation.

  Returns:
    An instance of a class that implements 'ComputationBuildingBlock' and that
    contains the deserialized logic from 'computation_proto'.

  Raises:
    NotImplementedError: if computation_proto contains a kind of computation
      for which deserialization has not been implemented yet.
    ValueError: if deserialization failed due to the argument being invalid,
      i.e. the derived type does not match the declared one.
  """
  py_typecheck.check_type(computation_proto, pb.Computation)
  which_kind = computation_proto.WhichOneof('computation')
  deserializer = cls._deserializer_dict.get(which_kind)
  if deserializer is None:
    raise NotImplementedError(
        'Deserialization for computations of type {} has not been '
        'implemented yet.'.format(which_kind))
  result = deserializer(computation_proto)
  declared_type = type_serialization.deserialize_type(computation_proto.type)
  derived_type = result.type_signature
  if not type_utils.are_equivalent_types(derived_type, declared_type):
    raise ValueError(
        'The type {} derived from the computation structure does not '
        'match the type {} declared in its signature'.format(
            str(derived_type), str(declared_type)))
  return result
def test_basic_functionality_of_tuple_class(self):
  """Checks construction, accessors, repr and proto form of `Tuple`."""
  foo = computation_building_blocks.Reference('foo', tf.int32)
  bar = computation_building_blocks.Reference('bar', tf.bool)
  tup = computation_building_blocks.Tuple([foo, ('y', bar)])
  # An empty-string element name is rejected with ValueError.
  with self.assertRaises(ValueError):
    _ = computation_building_blocks.Tuple([('', bar)])
  self.assertIsInstance(tup, anonymous_tuple.AnonymousTuple)
  self.assertEqual(str(tup.type_signature), '<int32,y=bool>')
  self.assertEqual(
      repr(tup),
      'Tuple([(None, Reference(\'foo\', TensorType(tf.int32))), (\'y\', '
      'Reference(\'bar\', TensorType(tf.bool)))])')
  self.assertEqual(tup.tff_repr, '<foo,y=bar>')
  self.assertEqual(dir(tup), ['y'])
  self.assertIs(tup.y, bar)
  self.assertLen(tup, 2)
  for position, expected in enumerate([foo, bar]):
    self.assertIs(tup[position], expected)
  self.assertEqual(','.join([elem.tff_repr for elem in tup]), 'foo,bar')
  # The serialized form keeps the declared type and the element names.
  tup_proto = tup.proto
  self.assertEqual(
      type_serialization.deserialize_type(tup_proto.type), tup.type_signature)
  self.assertEqual(tup_proto.WhichOneof('computation'), 'tuple')
  self.assertEqual([elem.name for elem in tup_proto.tuple.element], ['', 'y'])
  self._serialize_deserialize_roundtrip_test(tup)
def test_basic_functionality_of_lambda_class(self):
  """Checks a lambda that applies its `f` parameter twice to its `x`."""
  param_name = 'arg'
  param_type = [('f', computation_types.FunctionType(tf.int32, tf.int32)),
                ('x', tf.int32)]
  param_ref = computation_building_blocks.Reference(param_name, param_type)
  select_f = computation_building_blocks.Selection(param_ref, name='f')
  select_x = computation_building_blocks.Selection(param_ref, name='x')
  inner_call = computation_building_blocks.Call(select_f, select_x)
  outer_call = computation_building_blocks.Call(select_f, inner_call)
  fn = computation_building_blocks.Lambda(param_name, param_type, outer_call)
  self.assertEqual(str(fn.type_signature),
                   '(<f=(int32 -> int32),x=int32> -> int32)')
  self.assertEqual(fn.parameter_name, param_name)
  self.assertEqual(str(fn.parameter_type), '<f=(int32 -> int32),x=int32>')
  self.assertEqual(fn.result.tff_repr, 'arg.f(arg.f(arg.x))')
  param_type_repr = (
      'NamedTupleType(['
      '(\'f\', FunctionType(TensorType(tf.int32), TensorType(tf.int32))), '
      '(\'x\', TensorType(tf.int32))])')
  expected_repr = (
      'Lambda(\'arg\', {0}, '
      'Call(Selection(Reference(\'arg\', {0}), name=\'f\'), '
      'Call(Selection(Reference(\'arg\', {0}), name=\'f\'), '
      'Selection(Reference(\'arg\', {0}), name=\'x\'))))').format(
          param_type_repr)
  self.assertEqual(repr(fn), expected_repr)
  self.assertEqual(fn.tff_repr, '(arg -> arg.f(arg.f(arg.x)))')
  fn_proto = fn.proto
  self.assertEqual(
      type_serialization.deserialize_type(fn_proto.type), fn.type_signature)
  self.assertEqual(fn_proto.WhichOneof('computation'), 'lambda')
  # `lambda` is a Python keyword, so the field is read through getattr.
  lambda_proto = getattr(fn_proto, 'lambda')
  self.assertEqual(lambda_proto.parameter_name, param_name)
  self.assertEqual(str(lambda_proto.result), str(fn.result.proto))
  self._serialize_deserialize_roundtrip_test(fn)
def serialize_value(value, type_spec=None):
  """Serializes a value into `executor_pb2.Value`.

  Args:
    value: A value to be serialized.
    type_spec: Optional type spec, a `tff.Type` or something convertible to it.

  Returns:
    An instance of `executor_pb2.Value` with the serialized content of `value`.

  Raises:
    TypeError: If the arguments are of the wrong types.
    ValueError: If the value is malformed.
  """
  type_spec = computation_types.to_type(type_spec)
  if isinstance(value, computation_pb2.Computation):
    # A computation proto carries its own type; only cross-check it when the
    # caller also supplied one.
    if type_spec is not None:
      type_utils.reconcile_value_type_with_type_spec(
          type_serialization.deserialize_type(value.type), type_spec)
    return executor_pb2.Value(computation=value)
  if isinstance(value, computation_impl.ComputationImpl):
    # Unwrap to the underlying proto and recurse with the reconciled type.
    proto = computation_impl.ComputationImpl.get_proto(value)
    reconciled_type = type_utils.reconcile_value_with_type_spec(
        value, type_spec)
    return serialize_value(proto, reconciled_type)
  if isinstance(type_spec, computation_types.TensorType):
    return serialize_tensor_value(value, type_spec)
  raise ValueError(
      'Unable to serialize value with Python type {} and {} TFF type.'.format(
          str(py_typecheck.type_string(type(value))),
          str(type_spec) if type_spec is not None else 'unknown'))
def deserialize_value(value_proto):
  """Deserializes a value (of any supported kind) from `executor_pb2.Value`.

  Only `tensor` and `computation` values are handled here.

  Args:
    value_proto: An instance of `executor_pb2.Value`.

  Returns:
    A tuple `(value, type_spec)`. `value` is a deserialized representation of
    the transmitted value (e.g., Numpy array, or a `pb.Computation` instance).
    `type_spec` is the `tff.Type` of that value; for a computation it is the
    type deserialized from the proto's `type` field.

  Raises:
    TypeError: If the arguments are of the wrong types.
    ValueError: If the value is of a kind that cannot be deserialized.
  """
  py_typecheck.check_type(value_proto, executor_pb2.Value)
  kind = value_proto.WhichOneof('value')
  if kind == 'tensor':
    return deserialize_tensor_value(value_proto)
  if kind == 'computation':
    computation = value_proto.computation
    return (computation,
            type_serialization.deserialize_type(computation.type))
  raise ValueError('Unable to deserialize a value of type {}.'.format(kind))
def deserialize_sequence_value(sequence_value_proto):
  """Deserializes a `tf.data.Dataset`.

  Args:
    sequence_value_proto: `Sequence` protocol buffer message.

  Returns:
    A tuple of `(tf.data.Dataset, tff.Type)`, where the type is a
    `computation_types.SequenceType` over the serialized element type.

  Raises:
    NotImplementedError: If the sequence is encoded in any form other than
      `zipped_saved_model`.
  """
  py_typecheck.check_type(sequence_value_proto, executor_pb2.Value.Sequence)
  which_value = sequence_value_proto.WhichOneof('value')
  if which_value != 'zipped_saved_model':
    raise NotImplementedError(
        'Deserializing Sequences encoded as {!s} has not been implemented'
        .format(which_value))
  ds = tensorflow_serialization.deserialize_dataset(
      sequence_value_proto.zipped_saved_model)
  element_type = type_serialization.deserialize_type(
      sequence_value_proto.element_type)
  # If a serialized dataset had elements of nested structures of tensors (e.g.
  # `dict`, `OrderedDict`), the deserialized dataset will return `dict`,
  # `tuple`, or `namedtuple` (loses `collections.OrderedDict` in a conversion).
  #
  # Since the dataset will only be used inside TFF, we wrap the dictionary
  # coming from TF in an `OrderedDict` when necessary (a type that both TF and
  # TFF understand), using the field order stored in the TFF type stored during
  # serialization.
  ds = tensorflow_utils.coerce_dataset_elements_to_tff_type_spec(
      ds, element_type)
  return ds, computation_types.SequenceType(element=element_type)
def test_basic_functionality_of_block_class(self):
  """Checks `let x=arg,y=x[0] in y` across repr, tff_repr and proto form."""
  pair_type = (tf.int32, tf.int32)
  bindings = [
      ('x', computation_building_blocks.Reference('arg', pair_type)),
      ('y',
       computation_building_blocks.Selection(
           computation_building_blocks.Reference('x', pair_type), index=0)),
  ]
  block = computation_building_blocks.Block(
      bindings, computation_building_blocks.Reference('y', tf.int32))
  self.assertEqual(str(block.type_signature), 'int32')
  self.assertEqual([(name, val.tff_repr) for name, val in block.locals],
                   [('x', 'arg'), ('y', 'x[0]')])
  self.assertEqual(block.result.tff_repr, 'y')
  self.assertEqual(
      repr(block),
      'Block([(\'x\', Reference(\'arg\', '
      'NamedTupleType([TensorType(tf.int32), TensorType(tf.int32)]))), '
      '(\'y\', Selection(Reference(\'x\', '
      'NamedTupleType([TensorType(tf.int32), TensorType(tf.int32)])), '
      'index=0))], '
      'Reference(\'y\', TensorType(tf.int32)))')
  self.assertEqual(block.tff_repr, '(let x=arg,y=x[0] in y)')
  block_proto = block.proto
  self.assertEqual(
      type_serialization.deserialize_type(block_proto.type),
      block.type_signature)
  self.assertEqual(block_proto.WhichOneof('computation'), 'block')
  self.assertEqual(str(block_proto.block.result), str(block.result.proto))
  # Every serialized local must match the corresponding in-memory binding.
  for position, local_proto in enumerate(block_proto.block.local):
    expected_name, expected_value = block.locals[position]
    self.assertEqual(local_proto.name, expected_name)
    self.assertEqual(str(local_proto.value), str(expected_value.proto))
  self._serialize_deserialize_roundtrip_test(block)
def test_serialize_tensorflow_with_structured_type_signature(self):
  """Checks that Python container types survive in the extra type spec."""
  batch_type = collections.namedtuple('BatchType', ['x', 'y'])
  output_type = collections.namedtuple('OutputType', ['A', 'B'])
  expected_signature = '(<x=int32,y=float32[2]> -> <A=float32,B=float32[2]>)'
  comp, extra_type_spec = (
      tensorflow_serialization.serialize_py_fn_as_tf_computation(
          lambda z: output_type(2.0 * tf.cast(z.x, tf.float32), 3.0 * z.y),
          batch_type(tf.int32, (tf.float32, [2])),
          context_stack_impl.context_stack))
  self.assertEqual(
      str(type_serialization.deserialize_type(comp.type)), expected_signature)
  self.assertEqual(comp.WhichOneof('computation'), 'tensorflow')
  self.assertEqual(str(extra_type_spec), expected_signature)
  # Both the parameter and the result remember their Python container class.
  container_cls = computation_types.NamedTupleTypeWithPyContainerType
  expectations = ((extra_type_spec.parameter, batch_type),
                  (extra_type_spec.result, output_type))
  for part_type, expected_container in expectations:
    self.assertIsInstance(part_type, container_cls)
    self.assertIs(container_cls.get_container_type(part_type),
                  expected_container)
def test_basic_functionality_of_call_class(self):
  """Checks `Call` accessors, repr, type errors and proto form."""
  fn_type = computation_types.FunctionType(tf.int32, tf.bool)
  fn_ref = computation_building_blocks.Reference('foo', fn_type)
  arg_ref = computation_building_blocks.Reference('bar', tf.int32)
  call = computation_building_blocks.Call(fn_ref, arg_ref)
  self.assertEqual(str(call.type_signature), 'bool')
  self.assertIs(call.function, fn_ref)
  self.assertIs(call.argument, arg_ref)
  self.assertEqual(
      repr(call),
      'Call(Reference(\'foo\', '
      'FunctionType(TensorType(tf.int32), TensorType(tf.bool))), '
      'Reference(\'bar\', TensorType(tf.int32)))')
  self.assertEqual(call.tff_repr, 'foo(bar)')
  # Omitting the argument of a function that declares a parameter fails.
  with self.assertRaises(TypeError):
    computation_building_blocks.Call(fn_ref)
  # Passing an argument of a mismatched type fails too.
  mistyped_arg = computation_building_blocks.Reference('bak', tf.float32)
  with self.assertRaises(TypeError):
    computation_building_blocks.Call(fn_ref, mistyped_arg)
  call_proto = call.proto
  self.assertEqual(
      type_serialization.deserialize_type(call_proto.type),
      call.type_signature)
  self.assertEqual(call_proto.WhichOneof('computation'), 'call')
  self.assertEqual(str(call_proto.call.function), str(fn_ref.proto))
  self.assertEqual(str(call_proto.call.argument), str(arg_ref.proto))
  self._serialize_deserialize_roundtrip_test(call)
def test_intrinsic_class_succeeds_simple_federated_map(self):
  """Checks a `federated_map` intrinsic with a concrete int32->float32 type."""
  mapping_fn_type = computation_types.FunctionType(tf.int32, tf.float32)
  clients_arg_type = computation_types.FederatedType(
      mapping_fn_type.parameter, placements.CLIENTS)
  clients_result_type = computation_types.FederatedType(
      mapping_fn_type.result, placements.CLIENTS)
  map_type = computation_types.FunctionType(
      [mapping_fn_type, clients_arg_type], clients_result_type)
  intrinsic = building_blocks.Intrinsic(intrinsic_defs.FEDERATED_MAP.uri,
                                        map_type)
  self.assertIsInstance(intrinsic, building_blocks.Intrinsic)
  self.assertEqual(
      str(intrinsic.type_signature),
      '(<(int32 -> float32),{int32}@CLIENTS> -> {float32}@CLIENTS)')
  self.assertEqual(intrinsic.uri, 'federated_map')
  self.assertEqual(intrinsic.compact_representation(), 'federated_map')
  intrinsic_proto = intrinsic.proto
  self.assertEqual(
      type_serialization.deserialize_type(intrinsic_proto.type),
      intrinsic.type_signature)
  self.assertEqual(intrinsic_proto.WhichOneof('computation'), 'intrinsic')
  self.assertEqual(intrinsic_proto.intrinsic.uri, intrinsic.uri)
  self._serialize_deserialize_roundtrip_test(intrinsic)
def from_proto(cls, computation_proto):
  """Builds a placement building block from a `placement` computation proto.

  Checks that the proto's `computation` oneof is `placement` and that its
  declared type deserializes to a `computation_types.PlacementType`.
  """
  _check_computation_oneof(computation_proto, 'placement')
  declared_type = type_serialization.deserialize_type(computation_proto.type)
  py_typecheck.check_type(declared_type, computation_types.PlacementType)
  placement_uri = str(computation_proto.placement.uri)
  literal = placement_literals.uri_to_placement_literal(placement_uri)
  return cls(literal)
def from_proto(cls, computation_proto):
  """Builds a lambda building block from a `lambda` computation proto.

  The parameter type comes from the proto's declared function type. The body
  is deserialized recursively via `ComputationBuildingBlock.from_proto`.
  """
  _check_computation_oneof(computation_proto, 'lambda')
  # `lambda` is a Python keyword, so the oneof field is read via getattr.
  lambda_proto = getattr(computation_proto, 'lambda')
  parameter_name = str(lambda_proto.parameter_name)
  parameter_type = type_serialization.deserialize_type(
      computation_proto.type.function.parameter)
  body = ComputationBuildingBlock.from_proto(lambda_proto.result)
  return cls(parameter_name, parameter_type, body)
def __init__(self, value, scope=None, type_spec=None):
  """Wraps `value` as a value embedded in a lambda executor.

  `value` may take one of four internal forms:

  * An `executor_value_base.ExecutorValue` embedded in the target executor
    (functional or non-functional). Its type is read from the value itself.
  * An as-yet unprocessed `pb.Computation` representing a function still to
    be invoked (always of a functional type; non-functional constructs are
    processed on the fly). Its type is derived from the proto's `type`.
  * A Python coroutine callable that accepts a single `LambdaExecutorValue`
    (or `None`) and returns a `LambdaExecutorValue`. Its type must be passed
    explicitly as a `computation_types.FunctionType`.
  * A single-level `anonymous_tuple.AnonymousTuple` whose elements are
    `LambdaExecutorValue`s of any of these forms. Its type is assembled from
    the element types.

  Args:
    value: The internal representation, in one of the forms above.
    scope: Optional `LambdaExecutorScope`. Allowed only when `value` is a
      `pb.Computation`; it must be `None` otherwise (scope is meaningless for
      the other forms).
    type_spec: A `computation_types.FunctionType`, required when `value` is a
      callable. It must be `None` for every other form (the type is implied).
  """
  if isinstance(value, executor_value_base.ExecutorValue):
    py_typecheck.check_none(scope)
    py_typecheck.check_none(type_spec)
    signature = value.type_signature
  elif isinstance(value, pb.Computation):
    if scope is not None:
      py_typecheck.check_type(scope, LambdaExecutorScope)
    py_typecheck.check_none(type_spec)
    declared = type_serialization.deserialize_type(value.type)
    signature = type_utils.get_function_type(declared)
  elif callable(value):
    py_typecheck.check_none(scope)
    py_typecheck.check_type(type_spec, computation_types.FunctionType)
    signature = type_spec
  else:
    py_typecheck.check_type(value, anonymous_tuple.AnonymousTuple)
    py_typecheck.check_none(scope)
    py_typecheck.check_none(type_spec)
    element_types = []
    for element_name, element in anonymous_tuple.to_elements(value):
      py_typecheck.check_type(element, LambdaExecutorValue)
      element_type = element.type_signature
      # Unnamed elements contribute a bare type; named ones a (name, type).
      element_types.append(element_type if element_name is None else
                           (element_name, element_type))
    signature = computation_types.NamedTupleType(element_types)
  self._value = value
  self._scope = scope
  self._type_signature = signature
def test_serialize_tensorflow_with_no_parameter(self):
  """Serializes a no-argument TF function and runs the resulting graph."""
  comp = tensorflow_serialization.serialize_py_fn_as_tf_computation(
      lambda: tf.constant(99), None, context_stack_impl.context_stack)
  self.assertEqual(
      str(type_serialization.deserialize_type(comp.type)), '( -> int32)')
  self.assertEqual(comp.WhichOneof('computation'), 'tensorflow')
  # NOTE(review): other serializer tests pass `graph_def` through
  # `serialization_utils.unpack_graph_def`; confirm it is a plain GraphDef in
  # the serializer version this test targets.
  result_tensors = tf.import_graph_def(
      comp.tensorflow.graph_def, None,
      [comp.tensorflow.result.tensor.tensor_name])
  # Use the session as a context manager so its resources are released
  # deterministically instead of leaking until garbage collection.
  with tf.Session() as sess:
    results = sess.run(result_tensors)
  self.assertEqual(results, [99])
def test_basic_functionality_of_placement_class(self):
  """Checks the CLIENTS placement literal and its proto form."""
  placement = computation_building_blocks.Placement(placements.CLIENTS)
  self.assertEqual(str(placement.type_signature), 'placement')
  self.assertEqual(placement.uri, 'clients')
  self.assertEqual(repr(placement), 'Placement(\'clients\')')
  self.assertEqual(placement.tff_repr, 'CLIENTS')
  placement_proto = placement.proto
  self.assertEqual(
      type_serialization.deserialize_type(placement_proto.type),
      placement.type_signature)
  self.assertEqual(placement_proto.WhichOneof('computation'), 'placement')
  self.assertEqual(placement_proto.placement.uri, placement.uri)
  self._serialize_deserialize_roundtrip_test(placement)
def test_basic_functionality_of_reference_class(self):
  """Checks an int32 `Reference` and its proto form."""
  ref = computation_building_blocks.Reference('foo', tf.int32)
  self.assertEqual(ref.name, 'foo')
  self.assertEqual(str(ref.type_signature), 'int32')
  self.assertEqual(repr(ref), 'Reference(\'foo\', TensorType(tf.int32))')
  self.assertEqual(ref.tff_repr, 'foo')
  ref_proto = ref.proto
  self.assertEqual(
      type_serialization.deserialize_type(ref_proto.type), ref.type_signature)
  self.assertEqual(ref_proto.WhichOneof('computation'), 'reference')
  self.assertEqual(ref_proto.reference.name, ref.name)
  self._serialize_deserialize_roundtrip_test(ref)
def test_serialize_tensorflow_with_simple_add_three_lambda(self):
  """Serializes `x + 3` and evaluates the graph with x = 1000."""
  comp = tensorflow_serialization.serialize_py_fn_as_tf_computation(
      lambda x: x + 3, tf.int32, context_stack_impl.context_stack)
  self.assertEqual(
      str(type_serialization.deserialize_type(comp.type)), '(int32 -> int32)')
  self.assertEqual(comp.WhichOneof('computation'), 'tensorflow')
  parameter = tf.constant(1000)
  result_tensors = tf.import_graph_def(
      serialization_utils.unpack_graph_def(comp.tensorflow.graph_def),
      {comp.tensorflow.parameter.tensor.tensor_name: parameter},
      [comp.tensorflow.result.tensor.tensor_name])
  # Close the session deterministically rather than leaking it.
  with tf.Session() as sess:
    results = sess.run(result_tensors)
  self.assertEqual(results, [1003])
def test_basic_intrinsic_functionality_plus_canonical_typecheck(self):
  """Checks a `generic_plus` intrinsic over a pair of int32 values."""
  plus_type = computation_types.FunctionType([tf.int32, tf.int32], tf.int32)
  plus = building_blocks.Intrinsic('generic_plus', plus_type)
  self.assertEqual(str(plus.type_signature), '(<int32,int32> -> int32)')
  self.assertEqual(plus.uri, 'generic_plus')
  self.assertEqual(plus.compact_representation(), 'generic_plus')
  plus_proto = plus.proto
  self.assertEqual(
      type_serialization.deserialize_type(plus_proto.type),
      plus.type_signature)
  self.assertEqual(plus_proto.WhichOneof('computation'), 'intrinsic')
  self.assertEqual(plus_proto.intrinsic.uri, plus.uri)
  self._serialize_deserialize_roundtrip_test(plus)
def test_basic_functionality_of_data_class(self):
  """Checks a `Data` block pointing at an int32 sequence and its proto."""
  sequence_type = computation_types.SequenceType(tf.int32)
  data = building_blocks.Data('/tmp/mydata', sequence_type)
  self.assertEqual(str(data.type_signature), 'int32*')
  self.assertEqual(data.uri, '/tmp/mydata')
  self.assertEqual(
      repr(data), 'Data(\'/tmp/mydata\', SequenceType(TensorType(tf.int32)))')
  self.assertEqual(data.compact_representation(), '/tmp/mydata')
  data_proto = data.proto
  self.assertEqual(
      type_serialization.deserialize_type(data_proto.type),
      data.type_signature)
  self.assertEqual(data_proto.WhichOneof('computation'), 'data')
  self.assertEqual(data_proto.data.uri, data.uri)
  self._serialize_deserialize_roundtrip_test(data)
def test_basic_functionality_of_intrinsic_class(self):
  """Checks an `add_one` intrinsic of type (int32 -> int32)."""
  add_one_type = computation_types.FunctionType(tf.int32, tf.int32)
  add_one = computation_building_blocks.Intrinsic('add_one', add_one_type)
  self.assertEqual(str(add_one.type_signature), '(int32 -> int32)')
  self.assertEqual(add_one.uri, 'add_one')
  self.assertEqual(
      repr(add_one),
      'Intrinsic(\'add_one\', '
      'FunctionType(TensorType(tf.int32), TensorType(tf.int32)))')
  self.assertEqual(add_one.tff_repr, 'add_one')
  add_one_proto = add_one.proto
  self.assertEqual(
      type_serialization.deserialize_type(add_one_proto.type),
      add_one.type_signature)
  self.assertEqual(add_one_proto.WhichOneof('computation'), 'intrinsic')
  self.assertEqual(add_one_proto.intrinsic.uri, add_one.uri)
  self._serialize_deserialize_roundtrip_test(add_one)
def _serialize_deserialize_roundtrip_test(self, type_list):
  """Asserts each type survives serialize -> deserialize -> serialize intact.

  For every entry, both the type reprs and the serialized proto reprs must
  match across the round trip, and the two types must be equivalent.

  Args:
    type_list: A list of instances of computation_types.Type or things
      convertible to it.
  """
  for candidate in type_list:
    original_type = computation_types.to_type(candidate)
    original_proto = type_serialization.serialize_type(original_type)
    restored_type = type_serialization.deserialize_type(original_proto)
    restored_proto = type_serialization.serialize_type(restored_type)
    self.assertEqual(repr(original_type), repr(restored_type))
    self.assertEqual(repr(original_proto), repr(restored_proto))
    self.assertTrue(
        type_utils.are_equivalent_types(original_type, restored_type))
def wrap_graph_parameter_as_tuple(comp, name=None):
  """Wraps the parameter of `comp` in a one-element tuple binding.

  This is a preprocessing step for `pad_graph_inputs_to_match_type`. After it
  runs, that function can assume its argument always has a tuple binding and
  need not handle a bare tensor or sequence binding.

  Args:
    comp: Instance of `computation_building_blocks.CompiledComputation` whose
      parameter we wish to wrap in a tuple binding.
    name: Optional string; the name given to the single element of the
      constructed tuple type. Defaults to `None` (unnamed).

  Returns:
    A `computation_building_blocks.CompiledComputation` with the same graph,
    initialize op and result binding as `comp`. It accepts a tuple whose only
    element is the original parameter of `comp`.

  Raises:
    TypeError: If `comp` is not a
      `computation_building_blocks.CompiledComputation`.
  """
  py_typecheck.check_type(comp, computation_building_blocks.CompiledComputation)
  if name is not None:
    py_typecheck.check_type(name, six.string_types)
  original_proto = comp.proto
  original_type = type_serialization.deserialize_type(original_proto.type)
  tf_proto = original_proto.tensorflow
  # The existing binding becomes the sole element of a tuple binding.
  wrapped_binding = pb.TensorFlow.Binding(
      tuple=pb.TensorFlow.NamedTupleBinding(element=[tf_proto.parameter]))
  wrapped_type = computation_types.FunctionType(
      [(name, original_type.parameter)], original_type.result)
  wrapped_proto = pb.Computation(
      type=type_serialization.serialize_type(wrapped_type),
      tensorflow=pb.TensorFlow(
          graph_def=tf_proto.graph_def,
          initialize_op=tf_proto.initialize_op,
          parameter=wrapped_binding,
          result=tf_proto.result))
  return computation_building_blocks.CompiledComputation(wrapped_proto)
def test_basic_functionality_of_selection_class(self):
  """Checks name- and index-based `Selection` plus invalid selections."""
  source = computation_building_blocks.Reference(
      'foo', [('bar', tf.int32), ('baz', tf.bool)])
  source_repr = (
      'Selection(Reference(\'foo\', NamedTupleType(['
      '(\'bar\', TensorType(tf.int32)), (\'baz\', TensorType(tf.bool))]))')

  # Selection by name.
  by_bar = computation_building_blocks.Selection(source, name='bar')
  self.assertEqual(by_bar.name, 'bar')
  self.assertIsNone(by_bar.index)
  self.assertEqual(str(by_bar.type_signature), 'int32')
  self.assertEqual(repr(by_bar), source_repr + ', name=\'bar\')')
  self.assertEqual(
      computation_building_blocks.compact_representation(by_bar), 'foo.bar')
  by_baz = computation_building_blocks.Selection(source, name='baz')
  self.assertEqual(str(by_baz.type_signature), 'bool')
  self.assertEqual(
      computation_building_blocks.compact_representation(by_baz), 'foo.baz')
  with self.assertRaises(ValueError):
    _ = computation_building_blocks.Selection(source, name='bak')

  # Selection by index.
  at_0 = computation_building_blocks.Selection(source, index=0)
  self.assertIsNone(at_0.name)
  self.assertEqual(at_0.index, 0)
  self.assertEqual(str(at_0.type_signature), 'int32')
  self.assertEqual(repr(at_0), source_repr + ', index=0)')
  self.assertEqual(
      computation_building_blocks.compact_representation(at_0), 'foo[0]')
  at_1 = computation_building_blocks.Selection(source, index=1)
  self.assertEqual(str(at_1.type_signature), 'bool')
  self.assertEqual(
      computation_building_blocks.compact_representation(at_1), 'foo[1]')
  # Out-of-range and negative indices are rejected.
  for bad_index in (2, -1):
    with self.assertRaises(ValueError):
      _ = computation_building_blocks.Selection(source, index=bad_index)

  # Serialized form of a name-based selection.
  bar_proto = by_bar.proto
  self.assertEqual(
      type_serialization.deserialize_type(bar_proto.type),
      by_bar.type_signature)
  self.assertEqual(bar_proto.WhichOneof('computation'), 'selection')
  self.assertEqual(str(bar_proto.selection.source), str(source.proto))
  self.assertEqual(bar_proto.selection.name, 'bar')
  for selection in (by_bar, by_baz, at_0, at_1):
    self._serialize_deserialize_roundtrip_test(selection)
def serialize_value(value, type_spec=None):
  """Serializes a value into `executor_pb2.Value`.

  Args:
    value: A value to be serialized.
    type_spec: Optional type spec, a `tff.Type` or something convertible to it.

  Returns:
    A tuple `(value_proto, ret_type_spec)` where `value_proto` is an instance
    of `executor_pb2.Value` with the serialized content of `value`, and the
    returned `ret_type_spec` is an instance of `tff.Type` that represents the
    TFF type of the serialized value.

  Raises:
    TypeError: If the arguments are of the wrong types.
    ValueError: If the value is malformed, including a structured value whose
      number of elements differs from that of its `NamedTupleType`.
  """
  type_spec = computation_types.to_type(type_spec)
  if isinstance(value, computation_pb2.Computation):
    type_spec = type_utils.reconcile_value_type_with_type_spec(
        type_serialization.deserialize_type(value.type), type_spec)
    return executor_pb2.Value(computation=value), type_spec
  elif isinstance(value, computation_impl.ComputationImpl):
    return serialize_value(
        computation_impl.ComputationImpl.get_proto(value),
        type_utils.reconcile_value_with_type_spec(value, type_spec))
  elif isinstance(type_spec, computation_types.TensorType):
    return serialize_tensor_value(value, type_spec)
  elif isinstance(type_spec, computation_types.NamedTupleType):
    type_elements = list(anonymous_tuple.to_elements(type_spec))
    val_elements = list(
        anonymous_tuple.to_elements(anonymous_tuple.from_container(value)))
    # zip() below would silently drop surplus elements on either side and
    # emit a proto that disagrees with `type_spec`; reject such values.
    if len(val_elements) != len(type_elements):
      raise ValueError(
          'Expected a value with {} elements to match TFF type {}, but found '
          '{} elements.'.format(
              len(type_elements), type_spec, len(val_elements)))
    tup_elems = []
    for (e_name, e_type), (_, e_val) in zip(type_elements, val_elements):
      e_proto, _ = serialize_value(e_val, e_type)
      tup_elems.append(
          executor_pb2.Value.Tuple.Element(
              name=e_name if e_name else None, value=e_proto))
    result_proto = executor_pb2.Value(
        tuple=executor_pb2.Value.Tuple(element=tup_elems))
    return result_proto, type_spec
  else:
    raise ValueError(
        'Unable to serialize value with Python type {} and {} TFF type.'.format(
            str(py_typecheck.type_string(type(value))),
            str(type_spec) if type_spec is not None else 'unknown'))
def test_serialize_tensorflow_with_data_set_sum_lambda(self):
  """Serializes a dataset reduction and feeds it range(5) (sum == 10)."""

  def _legacy_dataset_reducer_example(ds):
    return ds.reduce(np.int64(0), lambda x, y: x + y)

  comp = tensorflow_serialization.serialize_py_func_as_tf_computation(
      _legacy_dataset_reducer_example,
      computation_types.SequenceType(tf.int64),
      context_stack_impl.context_stack)
  self.assertEqual(str(type_serialization.deserialize_type(comp.type)),
                   '(int64* -> int64)')
  self.assertEqual(comp.WhichOneof('computation'), 'tensorflow')
  parameter = tf.data.Dataset.range(5)
  result_tensors = tf.import_graph_def(
      comp.tensorflow.graph_def, {
          comp.tensorflow.parameter.sequence.iterator_string_handle_name:
              (parameter.make_one_shot_iterator().string_handle())
      }, [comp.tensorflow.result.tensor.tensor_name])
  # Close the session deterministically rather than leaking it.
  with tf.Session() as sess:
    results = sess.run(result_tensors)
  self.assertEqual(results, [10])
def __init__(self, computation_proto, context_stack, annotated_type=None):
  """Creates a ComputationImpl backed by `computation_proto`.

  Args:
    computation_proto: The protocol buffer that represents the computation, an
      instance of pb.Computation.
    context_stack: The context stack to use.
    annotated_type: Optional, type information with additional annotations
      that replaces the information in `computation_proto.type`.

  Raises:
    TypeError: if `annotated_type` is not `None` and is not compatible with
      `computation_proto.type`.
  """
  py_typecheck.check_type(computation_proto, pb.Computation)
  py_typecheck.check_type(context_stack, context_stack_base.ContextStack)
  signature = type_serialization.deserialize_type(computation_proto.type)
  py_typecheck.check_type(signature, computation_types.Type)
  if annotated_type is not None:
    py_typecheck.check_type(annotated_type, computation_types.Type)
    # Extra information is encoded in a NamedTupleTypeWithPyContainerType
    # subclass which does not override __eq__, so an annotated type should
    # still compare equal to the plain deserialized one.
    if signature != annotated_type:
      raise TypeError(
          'annotated_type not compatible with computation_proto.type\n'
          'computation_proto.type: {!s}\n'
          'annotated_type: {!s}'.format(signature, annotated_type))
    signature = annotated_type
  type_utils.check_well_formed(signature)
  # The composition framework has no no-argument lambdas, but every Python
  # computation must be invocable, so non-functional types are lifted to a
  # parameterless function type.
  signature = (
      signature if isinstance(signature, computation_types.FunctionType) else
      computation_types.FunctionType(None, signature))
  super(ComputationImpl, self).__init__(signature, context_stack)
  self._computation_proto = computation_proto
def test_serialize_tensorflow_with_data_set_sum_lambda(self):
  """Serializes a dataset reduction and feeds it range(5) (sum == 10)."""

  def _legacy_dataset_reducer_example(ds):
    return ds.reduce(np.int64(0), lambda x, y: x + y)

  comp, extra_type_spec = (
      tensorflow_serialization.serialize_py_fn_as_tf_computation(
          _legacy_dataset_reducer_example,
          computation_types.SequenceType(tf.int64),
          context_stack_impl.context_stack))
  self.assertEqual(str(type_serialization.deserialize_type(comp.type)),
                   '(int64* -> int64)')
  self.assertEqual(str(extra_type_spec), '(int64* -> int64)')
  self.assertEqual(comp.WhichOneof('computation'), 'tensorflow')
  parameter = tf.data.Dataset.range(5)
  result_tensors = tf.import_graph_def(
      serialization_utils.unpack_graph_def(comp.tensorflow.graph_def), {
          comp.tensorflow.parameter.sequence.variant_tensor_name:
              tf.data.experimental.to_variant(parameter)
      }, [comp.tensorflow.result.tensor.tensor_name])
  # Close the session deterministically rather than leaking it.
  with tf.compat.v1.Session() as sess:
    results = sess.run(result_tensors)
  self.assertEqual(results, [10])
def deserialize_value(value_proto):
  """Deserializes a value (of any supported kind) from `executor_pb2.Value`.

  Handles `tensor`, `computation`, `tuple` (recursively) and `sequence`
  values.

  Args:
    value_proto: An instance of `executor_pb2.Value`.

  Returns:
    A tuple `(value, type_spec)`. `value` is a deserialized representation of
    the transmitted value (e.g., Numpy array, a `pb.Computation` instance, or
    an `anonymous_tuple.AnonymousTuple` for tuples). `type_spec` is the
    `tff.Type` of that value, e.g. a `computation_types.NamedTupleType` for
    tuples.

  Raises:
    TypeError: If the arguments are of the wrong types.
    ValueError: If the value is of a kind that cannot be deserialized.
  """
  py_typecheck.check_type(value_proto, executor_pb2.Value)
  kind = value_proto.WhichOneof('value')
  if kind == 'tensor':
    return deserialize_tensor_value(value_proto)
  if kind == 'computation':
    computation = value_proto.computation
    return (computation,
            type_serialization.deserialize_type(computation.type))
  if kind == 'tuple':
    element_values = []
    element_types = []
    for element in value_proto.tuple.element:
      # Proto strings default to '', which stands for an unnamed element.
      element_name = element.name or None
      element_value, element_type = deserialize_value(element.value)
      element_values.append((element_name, element_value))
      element_types.append(
          (element_name, element_type) if element_name else element_type)
    return (anonymous_tuple.AnonymousTuple(element_values),
            computation_types.NamedTupleType(element_types))
  if kind == 'sequence':
    return deserialize_sequence_value(value_proto.sequence)
  raise ValueError('Unable to deserialize a value of type {}.'.format(kind))
def __init__(self, proto, name=None):
  """Creates a representation of a fully constructed computation.

  Args:
    proto: An instance of pb.Computation with the computation logic.
    name: An optional string name to associate with this computation, used
      only for debugging purposes. If None, a name is derived as the
      hexadecimal Adler-32 checksum of the proto's repr.

  Raises:
    TypeError: if the arguments are of the wrong types.
  """
  py_typecheck.check_type(proto, pb.Computation)
  if name is not None:
    py_typecheck.check_type(name, six.string_types)
  super(CompiledComputation, self).__init__(
      type_serialization.deserialize_type(proto.type))
  self._proto = proto
  if name is None:
    # Mask to 32 bits so the checksum renders as a non-negative hex string.
    checksum = zlib.adler32(six.b(repr(self._proto))) & 0xFFFFFFFF
    name = '{:x}'.format(checksum)
  self._name = name
def __init__(self, computation_proto, context_stack):
  """Creates a ComputationImpl backed by `computation_proto`.

  Args:
    computation_proto: The protocol buffer that represents the computation, an
      instance of pb.Computation.
    context_stack: The context stack to use.
  """
  py_typecheck.check_type(computation_proto, pb.Computation)
  py_typecheck.check_type(context_stack, context_stack_base.ContextStack)
  signature = type_serialization.deserialize_type(computation_proto.type)
  py_typecheck.check_type(signature, computation_types.Type)
  type_utils.check_well_formed(signature)
  # The composition framework has no no-argument lambdas, but every Python
  # computation must be invocable, so non-functional types are lifted to a
  # parameterless function type.
  signature = (
      signature if isinstance(signature, computation_types.FunctionType) else
      computation_types.FunctionType(None, signature))
  super(ComputationImpl, self).__init__(signature, context_stack)
  self._computation_proto = computation_proto