def deserialize_value(value_proto):
  """Deserializes a value (of any type) from `executor_pb2.Value`.

  Args:
    value_proto: An instance of `executor_pb2.Value`.

  Returns:
    A tuple `(value, type_spec)`, where `value` is a deserialized
    representation of the transmitted value (e.g., Numpy array, an
    `anonymous_tuple.AnonymousTuple`, a dataset, or a `pb.Computation`
    instance), and `type_spec` is an instance of `tff.Type` (not necessarily a
    `tff.TensorType`; e.g. tuples yield a `tff.NamedTupleType`) that
    represents its type.

  Raises:
    TypeError: If the arguments are of the wrong types.
    ValueError: If the value is malformed.
  """
  py_typecheck.check_type(value_proto, executor_pb2.Value)
  which_value = value_proto.WhichOneof('value')
  if which_value == 'tensor':
    return deserialize_tensor_value(value_proto)
  elif which_value == 'computation':
    return (value_proto.computation,
            type_serialization.deserialize_type(
                value_proto.computation.type))
  elif which_value == 'tuple':
    val_elems = []
    type_elems = []
    for e in value_proto.tuple.element:
      # An empty name string marks an unnamed element.
      name = e.name if e.name else None
      e_val, e_type = deserialize_value(e.value)
      val_elems.append((name, e_val))
      type_elems.append((name, e_type) if name else e_type)
    return (anonymous_tuple.AnonymousTuple(val_elems),
            computation_types.NamedTupleType(type_elems))
  elif which_value == 'sequence':
    return deserialize_sequence_value(value_proto.sequence)
  elif which_value == 'federated':
    type_spec = type_serialization.deserialize_type(
        computation_pb2.Type(federated=value_proto.federated.type))
    value = []
    for item in value_proto.federated.value:
      item_value, item_type = deserialize_value(item)
      # Every member constituent must fit the declared member type.
      type_utils.check_assignable_from(type_spec.member, item_type)
      value.append(item_value)
    if type_spec.all_equal:
      # An all_equal federated value is carried as a single constituent and
      # is unwrapped from the list.
      if len(value) == 1:
        value = value[0]
      else:
        raise ValueError(
            'Expected an all_equal value with exactly one member '
            'constituent, found {}.'.format(len(value)))
    return value, type_spec
  else:
    raise ValueError(
        'Unable to deserialize a value of type {}.'.format(which_value))
def test_basic_functionality_of_call_class(self):
  """Checks type, accessors, repr, validation and proto form of `Call`."""
  fn_ref = building_blocks.Reference(
      'foo', computation_types.FunctionType(tf.int32, tf.bool))
  arg_ref = building_blocks.Reference('bar', tf.int32)
  call = building_blocks.Call(fn_ref, arg_ref)
  self.assertEqual(str(call.type_signature), 'bool')
  self.assertIs(call.function, fn_ref)
  self.assertIs(call.argument, arg_ref)
  expected_repr = ('Call(Reference(\'foo\', '
                   'FunctionType(TensorType(tf.int32), TensorType(tf.bool))), '
                   'Reference(\'bar\', TensorType(tf.int32)))')
  self.assertEqual(repr(call), expected_repr)
  self.assertEqual(call.compact_representation(), 'foo(bar)')
  # Omitting the argument of a function that declares a parameter fails...
  with self.assertRaises(TypeError):
    building_blocks.Call(fn_ref)
  # ...and so does passing an argument of the wrong type.
  float_ref = building_blocks.Reference('bak', tf.float32)
  with self.assertRaises(TypeError):
    building_blocks.Call(fn_ref, float_ref)
  call_proto = call.proto
  self.assertEqual(
      type_serialization.deserialize_type(call_proto.type),
      call.type_signature)
  self.assertEqual(call_proto.WhichOneof('computation'), 'call')
  self.assertEqual(str(call_proto.call.function), str(fn_ref.proto))
  self.assertEqual(str(call_proto.call.argument), str(arg_ref.proto))
  self._serialize_deserialize_roundtrip_test(call)
def test_basic_functionality_of_block_class(self):
  """Checks locals, result, repr and proto form of a two-local `Block`."""
  pair_type = (tf.int32, tf.int32)
  block = building_blocks.Block([
      ('x', building_blocks.Reference('arg', pair_type)),
      ('y',
       building_blocks.Selection(
           building_blocks.Reference('x', pair_type), index=0)),
  ], building_blocks.Reference('y', tf.int32))
  self.assertEqual(str(block.type_signature), 'int32')
  local_summaries = [(local_name, local_comp.compact_representation())
                     for local_name, local_comp in block.locals]
  self.assertEqual(local_summaries, [('x', 'arg'), ('y', 'x[0]')])
  self.assertEqual(block.result.compact_representation(), 'y')
  self.assertEqual(
      repr(block),
      'Block([(\'x\', Reference(\'arg\', '
      'NamedTupleType([TensorType(tf.int32), TensorType(tf.int32)]))), '
      '(\'y\', Selection(Reference(\'x\', '
      'NamedTupleType([TensorType(tf.int32), TensorType(tf.int32)])), '
      'index=0))], '
      'Reference(\'y\', TensorType(tf.int32)))')
  self.assertEqual(block.compact_representation(), '(let x=arg,y=x[0] in y)')
  block_proto = block.proto
  self.assertEqual(
      type_serialization.deserialize_type(block_proto.type),
      block.type_signature)
  self.assertEqual(block_proto.WhichOneof('computation'), 'block')
  self.assertEqual(str(block_proto.block.result), str(block.result.proto))
  # Each serialized local must match the corresponding Python-side local.
  for position, local_proto in enumerate(block_proto.block.local):
    expected_name, expected_value = block.locals[position]
    self.assertEqual(local_proto.name, expected_name)
    self.assertEqual(str(local_proto.value), str(expected_value.proto))
  self._serialize_deserialize_roundtrip_test(block)
def test_intrinsic_class_succeeds_simple_federated_map(self):
  """Builds a concrete `federated_map` intrinsic and checks its proto form."""
  mapping_fn_type = computation_types.FunctionType(tf.int32, tf.float32)
  clients_arg_type = computation_types.FederatedType(
      mapping_fn_type.parameter, placements.CLIENTS)
  clients_result_type = computation_types.FederatedType(
      mapping_fn_type.result, placements.CLIENTS)
  map_type = computation_types.FunctionType(
      [mapping_fn_type, clients_arg_type], clients_result_type)
  map_intrinsic = building_blocks.Intrinsic(intrinsic_defs.FEDERATED_MAP.uri,
                                            map_type)
  self.assertIsInstance(map_intrinsic, building_blocks.Intrinsic)
  self.assertEqual(
      str(map_intrinsic.type_signature),
      '(<(int32 -> float32),{int32}@CLIENTS> -> {float32}@CLIENTS)')
  self.assertEqual(map_intrinsic.uri, 'federated_map')
  self.assertEqual(map_intrinsic.compact_representation(), 'federated_map')
  map_proto = map_intrinsic.proto
  self.assertEqual(
      type_serialization.deserialize_type(map_proto.type),
      map_intrinsic.type_signature)
  self.assertEqual(map_proto.WhichOneof('computation'), 'intrinsic')
  self.assertEqual(map_proto.intrinsic.uri, map_intrinsic.uri)
  self._serialize_deserialize_roundtrip_test(map_intrinsic)
def test_basic_functionality_of_tuple_class(self):
  """Checks naming, indexing, iteration, repr and proto form of `Tuple`."""
  foo_ref = building_blocks.Reference('foo', tf.int32)
  bar_ref = building_blocks.Reference('bar', tf.bool)
  tup = building_blocks.Tuple([foo_ref, ('y', bar_ref)])
  # An element named with the empty string is rejected.
  with self.assertRaises(ValueError):
    _ = building_blocks.Tuple([('', bar_ref)])
  self.assertIsInstance(tup, anonymous_tuple.AnonymousTuple)
  self.assertEqual(str(tup.type_signature), '<int32,y=bool>')
  self.assertEqual(
      repr(tup),
      'Tuple([(None, Reference(\'foo\', TensorType(tf.int32))), (\'y\', '
      'Reference(\'bar\', TensorType(tf.bool)))])')
  self.assertEqual(tup.compact_representation(), '<foo,y=bar>')
  self.assertEqual(dir(tup), ['y'])
  self.assertIs(tup.y, bar_ref)
  self.assertLen(tup, 2)
  self.assertIs(tup[0], foo_ref)
  self.assertIs(tup[1], bar_ref)
  element_reprs = [element.compact_representation() for element in tup]
  self.assertEqual(','.join(element_reprs), 'foo,bar')
  tup_proto = tup.proto
  self.assertEqual(
      type_serialization.deserialize_type(tup_proto.type),
      tup.type_signature)
  self.assertEqual(tup_proto.WhichOneof('computation'), 'tuple')
  # The unnamed element is serialized with an empty name.
  self.assertEqual([element.name for element in tup_proto.tuple.element],
                   ['', 'y'])
  self._serialize_deserialize_roundtrip_test(tup)
def run_tensorflow(computation_proto, arg=None):
  """Runs a TensorFlow computation with argument `arg`.

  The computation is deserialized into a fresh `tf.Graph`, and its result is
  fetched in a `tf.compat.v1.Session` bound to that graph.

  Args:
    computation_proto: An instance of `pb.Computation`.
    arg: The argument to invoke the computation with, or `None` if the
      computation does not declare a parameter type and does not expect one.
      It is ignored when the computation has no parameter.

  Returns:
    The result of the computation.
  """
  with tf.Graph().as_default() as graph:
    fn_type = type_serialization.deserialize_type(computation_proto.type)
    param_type = fn_type.parameter
    if param_type is None:
      stamped_arg = None
    else:
      stamped_arg = _stamp_value_into_graph(arg, param_type, graph)
    init_op, result = (
        tensorflow_deserialization.deserialize_and_call_tf_computation(
            computation_proto, stamped_arg, graph))
  with tf.compat.v1.Session(graph=graph) as sess:
    # The initializer is run first, when the deserialized computation has one.
    if init_op:
      sess.run(init_op)
    return tensorflow_utils.fetch_value_in_session(sess, result)
def test_returns_coputation(self):
  """`create_lambda_empty_tuple` yields a no-arg function returning `<>`."""
  comp_proto = computation_factory.create_lambda_empty_tuple()
  self.assertIsInstance(comp_proto, pb.Computation)
  self.assertEqual(
      type_serialization.deserialize_type(comp_proto.type),
      computation_types.FunctionType(None, []))
def test_serialize_tensorflow_with_structured_type_signature(self):
  """Python container types survive in the returned extra type spec."""
  batch_cls = collections.namedtuple('BatchType', ['x', 'y'])
  output_cls = collections.namedtuple('OutputType', ['A', 'B'])
  comp, extra_type_spec = tensorflow_serialization.serialize_py_fn_as_tf_computation(
      lambda z: output_cls(2.0 * tf.cast(z.x, tf.float32), 3.0 * z.y),
      batch_cls(tf.int32, (tf.float32, [2])),
      context_stack_impl.context_stack)
  expected_signature = '(<x=int32,y=float32[2]> -> <A=float32,B=float32[2]>)'
  self.assertEqual(
      str(type_serialization.deserialize_type(comp.type)), expected_signature)
  self.assertEqual(comp.WhichOneof('computation'), 'tensorflow')
  self.assertEqual(str(extra_type_spec), expected_signature)
  container_type = computation_types.NamedTupleTypeWithPyContainerType
  # Both the parameter and the result must remember their Python containers.
  for spec, py_container in ((extra_type_spec.parameter, batch_cls),
                             (extra_type_spec.result, output_cls)):
    self.assertIsInstance(spec, container_type)
    self.assertIs(container_type.get_container_type(spec), py_container)
def test_serialize_tensorflow_with_dataset_not_optimized(self):
  """Serialized dataset reducers must not contain dataset optimization ops."""

  @tf.function
  def test_foo(ds):
    return ds.reduce(np.int64(0), lambda x, y: x + y)

  def legacy_dataset_reducer_example(ds):
    return test_foo(ds)

  comp, extra_type_spec = tensorflow_serialization.serialize_py_fn_as_tf_computation(
      legacy_dataset_reducer_example,
      computation_types.SequenceType(tf.int64),
      context_stack_impl.context_stack)
  self.assertEqual(
      str(type_serialization.deserialize_type(comp.type)), '(int64* -> int64)')
  self.assertEqual(str(extra_type_spec), '(int64* -> int64)')
  self.assertEqual(comp.WhichOneof('computation'), 'tensorflow')
  parameter = tf.data.Dataset.range(5)

  graph_def = serialization_utils.unpack_graph_def(comp.tensorflow.graph_def)
  # Op names must be spelled exactly; a misspelled name can never match and
  # would make the check pass trivially.
  self.assertGraphDoesNotContainOps(graph_def,
                                    ['OptimizeDataset', 'ModelDataset'])
  fetches = tf.import_graph_def(
      graph_def, {
          comp.tensorflow.parameter.sequence.variant_tensor_name:
              tf.data.experimental.to_variant(parameter)
      }, [comp.tensorflow.result.tensor.tensor_name])
  # Use the session as a context manager so it is closed deterministically.
  with tf.compat.v1.Session() as sess:
    results = sess.run(fetches)
  self.assertEqual(results, [10])
def deserialize_sequence_value(sequence_value_proto):
  """Deserializes a `tf.data.Dataset`.

  Args:
    sequence_value_proto: `Sequence` protocol buffer message.

  Returns:
    A tuple of `(tf.data.Dataset, tff.SequenceType)`.

  Raises:
    TypeError: If `sequence_value_proto` is not an
      `executor_pb2.Value.Sequence`.
    NotImplementedError: If the sequence is encoded in any form other than
      `zipped_saved_model`.
  """
  py_typecheck.check_type(sequence_value_proto, executor_pb2.Value.Sequence)
  which_value = sequence_value_proto.WhichOneof('value')
  if which_value == 'zipped_saved_model':
    ds = tensorflow_serialization.deserialize_dataset(
        sequence_value_proto.zipped_saved_model)
  else:
    raise NotImplementedError(
        'Deserializing Sequences encoded as {!s} has not been implemented'
        .format(which_value))

  element_type = type_serialization.deserialize_type(
      sequence_value_proto.element_type)

  # If a serialized dataset had elements of nested structures of tensors (e.g.
  # `dict`, `OrderedDict`), the deserialized dataset will return `dict`,
  # `tuple`, or `namedtuple` (loses `collections.OrderedDict` in a conversion).
  #
  # Since the dataset will only be used inside TFF, we wrap the dictionary
  # coming from TF in an `OrderedDict` when necessary (a type that both TF and
  # TFF understand), using the field order stored in the TFF type stored during
  # serialization.
  ds = tensorflow_utils.coerce_dataset_elements_to_tff_type_spec(
      ds, element_type)

  return ds, computation_types.SequenceType(element=element_type)
def __init__(self, value, scope=None, type_spec=None):
  """Wraps `value` as a value embedded in a lambda executor.

  Four internal representations are accepted for `value`:

  * An `executor_value_base.ExecutorValue` embedded in the target executor
    (functional or not). Its type is taken from `value.type_signature`.
  * An unprocessed `pb.Computation` standing for a function not yet invoked.
    Its type is deserialized from the proto and must be functional.
  * A Python coroutine callable that takes one `LambdaExecutorValue` (or
    `None`) and returns a `LambdaExecutorValue`. Its type must be supplied
    explicitly and is always functional.
  * A single-level `anonymous_tuple.AnonymousTuple` whose elements are
    `LambdaExecutorValue` instances of any of these forms. Its type is
    assembled from the element types.

  Args:
    value: The internal representation of the value, in one of the forms
      above.
    scope: An optional `LambdaExecutorScope`. It may only be given when `value`
      is a `pb.Computation`; for every other form it must be `None`.
    type_spec: A `computation_types.FunctionType`. It is required when `value`
      is a callable; for every other form it must be `None`, because the type
      is derived from `value` itself.
  """
  if isinstance(value, executor_value_base.ExecutorValue):
    py_typecheck.check_none(scope)
    py_typecheck.check_none(type_spec)
    type_spec = value.type_signature  # pytype: disable=attribute-error
  elif isinstance(value, pb.Computation):
    if scope is not None:
      py_typecheck.check_type(scope, LambdaExecutorScope)
    py_typecheck.check_none(type_spec)
    proto_type = type_serialization.deserialize_type(value.type)
    type_spec = type_utils.get_function_type(proto_type)
  elif callable(value):
    py_typecheck.check_none(scope)
    py_typecheck.check_type(type_spec, computation_types.FunctionType)
  else:
    py_typecheck.check_type(value, anonymous_tuple.AnonymousTuple)
    py_typecheck.check_none(scope)
    py_typecheck.check_none(type_spec)
    member_types = []
    for member_name, member in anonymous_tuple.iter_elements(value):
      py_typecheck.check_type(member, LambdaExecutorValue)
      # Unnamed members contribute a bare type; named ones a (name, type) pair.
      member_types.append(member.type_signature if member_name is None else
                          (member_name, member.type_signature))
    type_spec = computation_types.NamedTupleType(member_types)
  self._value = value
  self._scope = scope
  self._type_signature = type_spec
def test_returns_computation_sequence(self):
  """A lambda identity over `int32*` has type `(int32* -> int32*)`."""
  int_sequence = computation_types.SequenceType(tf.int32)
  identity = computation_factory.create_lambda_identity(int_sequence)
  self.assertIsInstance(identity, pb.Computation)
  self.assertEqual(
      type_serialization.deserialize_type(identity.type),
      type_factory.unary_op(int_sequence))
def test_returns_computation(self, type_signature, value):
  """A TF identity has a unary-op type and returns its input unchanged."""
  identity = tensorflow_computation_factory.create_identity(type_signature)
  self.assertIsInstance(identity, pb.Computation)
  self.assertEqual(
      type_serialization.deserialize_type(identity.type),
      type_factory.unary_op(type_signature))
  self.assertEqual(test_utils.run_tensorflow(identity, value), value)
def test_returns_computation_tuple_named(self):
  """A lambda identity over a named tuple has a unary-op type."""
  named_tuple_type = computation_types.NamedTupleType([('a', tf.int32),
                                                       ('b', tf.float32)])
  identity = computation_factory.create_lambda_identity(named_tuple_type)
  self.assertIsInstance(identity, pb.Computation)
  self.assertEqual(
      type_serialization.deserialize_type(identity.type),
      type_factory.unary_op(named_tuple_type))
def test_returns_coputation(self):
  """The TF empty-tuple computation has type `( -> <>)` and yields `<>`."""
  empty_tuple_comp = tensorflow_computation_factory.create_empty_tuple()
  self.assertIsInstance(empty_tuple_comp, pb.Computation)
  self.assertEqual(
      type_serialization.deserialize_type(empty_tuple_comp.type),
      computation_types.FunctionType(None, []))
  empty_tuple = anonymous_tuple.AnonymousTuple([])
  result = test_utils.run_tensorflow(empty_tuple_comp, empty_tuple)
  self.assertEqual(result, empty_tuple)
def test_basic_functionality_of_placement_class(self):
  """Checks type, uri, repr and proto form of a CLIENTS `Placement`."""
  clients = building_blocks.Placement(placements.CLIENTS)
  self.assertEqual(str(clients.type_signature), 'placement')
  self.assertEqual(clients.uri, 'clients')
  self.assertEqual(repr(clients), 'Placement(\'clients\')')
  self.assertEqual(clients.compact_representation(), 'CLIENTS')
  clients_proto = clients.proto
  self.assertEqual(
      type_serialization.deserialize_type(clients_proto.type),
      clients.type_signature)
  self.assertEqual(clients_proto.WhichOneof('computation'), 'placement')
  self.assertEqual(clients_proto.placement.uri, clients.uri)
  self._serialize_deserialize_roundtrip_test(clients)
def test_returns_computation_sequence(self):
  """A TF identity over `int32*` returns the sequence elements unchanged."""
  int_sequence = computation_types.SequenceType(tf.int32)
  identity = tensorflow_computation_factory.create_identity(int_sequence)
  self.assertIsInstance(identity, pb.Computation)
  self.assertEqual(
      type_serialization.deserialize_type(identity.type),
      type_factory.unary_op(int_sequence))
  elements = [10] * 3
  self.assertEqual(test_utils.run_tensorflow(identity, elements), elements)
def test_basic_functionality_of_reference_class(self):
  """Checks name, type, repr and proto form of a `Reference`."""
  foo_ref = building_blocks.Reference('foo', tf.int32)
  self.assertEqual(foo_ref.name, 'foo')
  self.assertEqual(str(foo_ref.type_signature), 'int32')
  self.assertEqual(repr(foo_ref), 'Reference(\'foo\', TensorType(tf.int32))')
  self.assertEqual(foo_ref.compact_representation(), 'foo')
  ref_proto = foo_ref.proto
  self.assertEqual(
      type_serialization.deserialize_type(ref_proto.type),
      foo_ref.type_signature)
  self.assertEqual(ref_proto.WhichOneof('computation'), 'reference')
  self.assertEqual(ref_proto.reference.name, foo_ref.name)
  self._serialize_deserialize_roundtrip_test(foo_ref)
def test_serialize_tensorflow_with_no_parameter(self):
  """A parameterless TF function serializes to `( -> int32)` and runs."""
  comp, extra_type_spec = tensorflow_serialization.serialize_py_fn_as_tf_computation(
      lambda: tf.constant(99), None, context_stack_impl.context_stack)
  self.assertEqual(
      str(type_serialization.deserialize_type(comp.type)), '( -> int32)')
  self.assertEqual(str(extra_type_spec), '( -> int32)')
  self.assertEqual(comp.WhichOneof('computation'), 'tensorflow')
  fetches = tf.import_graph_def(
      serialization_utils.unpack_graph_def(comp.tensorflow.graph_def), None,
      [comp.tensorflow.result.tensor.tensor_name])
  # Use the session as a context manager so it is closed deterministically.
  with tf.compat.v1.Session() as sess:
    results = sess.run(fetches)
  self.assertEqual(results, [99])
def test_basic_intrinsic_functionality_plus_canonical_typecheck(self):
  """Checks type, uri and proto form of a `generic_plus` intrinsic."""
  plus_type = computation_types.FunctionType([tf.int32, tf.int32], tf.int32)
  plus = building_blocks.Intrinsic('generic_plus', plus_type)
  self.assertEqual(str(plus.type_signature), '(<int32,int32> -> int32)')
  self.assertEqual(plus.uri, 'generic_plus')
  self.assertEqual(plus.compact_representation(), 'generic_plus')
  plus_proto = plus.proto
  self.assertEqual(
      type_serialization.deserialize_type(plus_proto.type),
      plus.type_signature)
  self.assertEqual(plus_proto.WhichOneof('computation'), 'intrinsic')
  self.assertEqual(plus_proto.intrinsic.uri, plus.uri)
  self._serialize_deserialize_roundtrip_test(plus)
def test_returns_computation_with_tensor_float(self):
  """A float32[3] constant computation yields three copies of the value."""
  value = 10.0
  vector_type = computation_types.TensorType(tf.float32, [3])
  constant_comp = tensorflow_computation_factory.create_constant(
      value, vector_type)
  self.assertIsInstance(constant_comp, pb.Computation)
  self.assertEqual(
      type_serialization.deserialize_type(constant_comp.type),
      computation_types.FunctionType(None, vector_type))
  expected_value = [value] * 3
  actual_value = test_utils.run_tensorflow(constant_comp, expected_value)
  self.assertCountEqual(actual_value, expected_value)
def test_returns_computation_with_tuple_unnamed(self):
  """An unnamed 3-tuple constant computation yields three copies of value."""
  value = 10
  triple_type = computation_types.NamedTupleType([tf.int32] * 3)
  constant_comp = tensorflow_computation_factory.create_constant(
      value, triple_type)
  self.assertIsInstance(constant_comp, pb.Computation)
  self.assertEqual(
      type_serialization.deserialize_type(constant_comp.type),
      computation_types.FunctionType(None, triple_type))
  expected_value = [value] * 3
  actual_value = test_utils.run_tensorflow(constant_comp, expected_value)
  self.assertCountEqual(actual_value, expected_value)
def test_returns_computation(self, type_signature, count, value):
  """`create_replicate_input` returns a `count`-tuple of copies of its arg."""
  replicate = tensorflow_computation_factory.create_replicate_input(
      type_signature, count)
  self.assertIsInstance(replicate, pb.Computation)
  self.assertEqual(
      type_serialization.deserialize_type(replicate.type),
      computation_types.FunctionType(type_signature,
                                     [type_signature] * count))
  replicas = test_utils.run_tensorflow(replicate, value)
  self.assertEqual(replicas,
                   anonymous_tuple.AnonymousTuple([(None, value)] * count))
def test_returns_computation_tuple_named(self):
  """A TF identity over a named tuple returns its input unchanged."""
  named_type = [('a', tf.int32), ('b', tf.float32)]
  identity = tensorflow_computation_factory.create_identity(named_type)
  self.assertIsInstance(identity, pb.Computation)
  self.assertEqual(
      type_serialization.deserialize_type(identity.type),
      type_factory.unary_op(named_type))
  sample = anonymous_tuple.AnonymousTuple([('a', 10), ('b', 10.0)])
  self.assertEqual(test_utils.run_tensorflow(identity, sample), sample)
def test_returns_computation(self, value, type_signature, expected_result):
  """A constant computation has type `( -> T)` and yields `expected_result`."""
  constant_comp = tensorflow_computation_factory.create_constant(
      value, type_signature)
  self.assertIsInstance(constant_comp, pb.Computation)
  self.assertEqual(
      type_serialization.deserialize_type(constant_comp.type),
      computation_types.FunctionType(None, type_signature))
  actual_result = test_utils.run_tensorflow(constant_comp)
  # List expectations are compared as multisets; others by equality.
  compare = (
      self.assertCountEqual
      if isinstance(expected_result, list) else self.assertEqual)
  compare(actual_result, expected_result)
def test_serialize_tensorflow_with_simple_add_three_lambda(self):
  """`x + 3` serializes to `(int32 -> int32)` and maps 1000 to 1003."""
  comp, extra_type_spec = tensorflow_serialization.serialize_py_fn_as_tf_computation(
      lambda x: x + 3, tf.int32, context_stack_impl.context_stack)
  self.assertEqual(
      str(type_serialization.deserialize_type(comp.type)), '(int32 -> int32)')
  self.assertEqual(str(extra_type_spec), '(int32 -> int32)')
  self.assertEqual(comp.WhichOneof('computation'), 'tensorflow')
  parameter = tf.constant(1000)
  fetches = tf.import_graph_def(
      serialization_utils.unpack_graph_def(comp.tensorflow.graph_def),
      {comp.tensorflow.parameter.tensor.tensor_name: parameter},
      [comp.tensorflow.result.tensor.tensor_name])
  # Use the session as a context manager so it is closed deterministically.
  with tf.compat.v1.Session() as sess:
    results = sess.run(fetches)
  self.assertEqual(results, [1003])
def _serialize_deserialize_roundtrip_test(self, type_list):
  """Checks that each type survives a serialize/deserialize round trip.

  For every entry, the type is serialized, deserialized, and serialized
  again. Both the types and the protos from the two passes must have equal
  reprs, and the two types must be equivalent.

  Args:
    type_list: A list of instances of computation_types.Type or things
      convertible to it.
  """
  for type_like in type_list:
    first_type = computation_types.to_type(type_like)
    first_proto = type_serialization.serialize_type(first_type)
    second_type = type_serialization.deserialize_type(first_proto)
    second_proto = type_serialization.serialize_type(second_type)
    self.assertEqual(repr(first_type), repr(second_type))
    self.assertEqual(repr(first_proto), repr(second_proto))
    self.assertTrue(type_utils.are_equivalent_types(first_type, second_type))
def test_basic_functionality_of_intrinsic_class(self):
  """Checks type, uri, repr and proto form of an `Intrinsic`."""
  add_one_type = computation_types.FunctionType(tf.int32, tf.int32)
  add_one = building_blocks.Intrinsic('add_one', add_one_type)
  self.assertEqual(str(add_one.type_signature), '(int32 -> int32)')
  self.assertEqual(add_one.uri, 'add_one')
  expected_repr = ('Intrinsic(\'add_one\', '
                   'FunctionType(TensorType(tf.int32), TensorType(tf.int32)))')
  self.assertEqual(repr(add_one), expected_repr)
  self.assertEqual(add_one.compact_representation(), 'add_one')
  add_one_proto = add_one.proto
  self.assertEqual(
      type_serialization.deserialize_type(add_one_proto.type),
      add_one.type_signature)
  self.assertEqual(add_one_proto.WhichOneof('computation'), 'intrinsic')
  self.assertEqual(add_one_proto.intrinsic.uri, add_one.uri)
  self._serialize_deserialize_roundtrip_test(add_one)
def test_basic_functionality_of_data_class(self):
  """Checks type, uri, repr and proto form of a `Data` block."""
  int_sequence = computation_types.SequenceType(tf.int32)
  data = building_blocks.Data('/tmp/mydata', int_sequence)
  self.assertEqual(str(data.type_signature), 'int32*')
  self.assertEqual(data.uri, '/tmp/mydata')
  self.assertEqual(
      repr(data), 'Data(\'/tmp/mydata\', SequenceType(TensorType(tf.int32)))')
  self.assertEqual(data.compact_representation(), '/tmp/mydata')
  data_proto = data.proto
  self.assertEqual(
      type_serialization.deserialize_type(data_proto.type),
      data.type_signature)
  self.assertEqual(data_proto.WhichOneof('computation'), 'data')
  self.assertEqual(data_proto.data.uri, data.uri)
  self._serialize_deserialize_roundtrip_test(data)
async def embed_tf_scalar_constant(executor, type_spec, value):
  """Embeds the constant `value`, of TFF type `type_spec`, in `executor`.

  A no-argument TensorFlow computation producing `value` is built, embedded in
  `executor`, and then invoked there without an argument.

  Args:
    executor: An instance of `tff.framework.Executor`.
    type_spec: An instance of `tff.Type`.
    value: A scalar value.

  Returns:
    An instance of `tff.framework.ExecutorValue` containing an embedded value.
  """
  py_typecheck.check_type(executor, executor_base.Executor)
  constant_comp = tensorflow_computation_factory.create_constant(
      value, type_spec)
  constant_type = type_serialization.deserialize_type(constant_comp.type)
  embedded_comp = await executor.create_value(constant_comp, constant_type)
  return await executor.create_call(embedded_comp)