def test_deserialize_federated_value_promotes_types(self):
  """Checks the member type deserialized from mixed struct members.

  The same payload is serialized twice: once as a struct with an unnamed
  int32 element and once as a struct whose int32 element is named 'a'. Both
  members are placed in a non-all-equal federated value whose type proto
  carries only a placement, and the deserialized type is expected to be the
  named struct type placed at clients.
  """
  payload = [10]
  int32_type = tf.int32

  unnamed_struct_type = computation_types.StructType(
      [(None, computation_types.to_type(int32_type))])
  unnamed_member_pb, _ = value_serialization.serialize_value(
      payload, unnamed_struct_type)

  named_struct_type = computation_types.StructType(
      [('a', computation_types.to_type(int32_type))])
  named_member_pb, _ = value_serialization.serialize_value(
      payload, named_struct_type)

  # Only the placement is borrowed from a fully specified clients type; the
  # member field of the federated type proto is left unset.
  clients_type_pb = type_serialization.serialize_type(
      computation_types.at_clients(int32_type))
  memberless_federated_type_pb = computation_pb2.FederatedType(
      placement=clients_type_pb.federated.placement, all_equal=False)

  members = [named_member_pb, unnamed_member_pb]
  value_pb = executor_pb2.Value(
      federated=executor_pb2.Value.Federated(
          type=memberless_federated_type_pb, value=members))

  result = value_serialization.deserialize_value(value_pb)
  _, actual_type = result

  expected_type = computation_types.at_clients(named_struct_type)
  type_test_utils.assert_types_identical(actual_type, expected_type)
def test_deserialize_federated_value_with_incompatible_member_types_raises(
    self):
  """Expects `TypeError` when federated members have int32 and float32 types.

  An int32 member and a float32 member are combined in a non-all-equal
  federated value whose type proto carries only a placement; deserializing
  that value must raise.
  """
  int_value = 10
  int_type = computation_types.to_type(tf.int32)
  int_pb, _ = executor_serialization.serialize_value(int_value, int_type)

  float_value = 10.
  float_type = computation_types.to_type(tf.float32)
  float_pb, _ = executor_serialization.serialize_value(float_value, float_type)

  # Reuse the placement of a fully specified clients type, but leave the
  # member of the federated type proto unset.
  clients_type_pb = type_serialization.serialize_type(
      computation_types.at_clients(tf.int32))
  memberless_federated_type_pb = computation_pb2.FederatedType(
      placement=clients_type_pb.federated.placement, all_equal=False)

  value_pb = executor_pb2.Value(
      federated=executor_pb2.Value.Federated(
          type=memberless_federated_type_pb, value=[int_pb, float_pb]))

  # Sanity-check that every proto built above is an `executor_pb2.Value`.
  self.assertIsInstance(int_pb, executor_pb2.Value)
  self.assertIsInstance(float_pb, executor_pb2.Value)
  self.assertIsInstance(value_pb, executor_pb2.Value)

  with self.assertRaises(TypeError):
    executor_serialization.deserialize_value(value_pb)
def test_serialize_type_with_federated_bool(self):
  """Compares the proto for an all-equal bool at clients to a hand-built one."""
  federated_bool = computation_types.FederatedType(tf.bool, placements.CLIENTS,
                                                   True)
  actual_proto = type_serialization.serialize_type(federated_bool)

  # Build the expected proto piece by piece: placement spec, then the
  # federated type wrapping a scalar bool tensor member.
  clients_spec = pb.PlacementSpec(
      value=pb.Placement(uri=placements.CLIENTS.uri))
  expected_federated = pb.FederatedType(
      placement=clients_spec,
      all_equal=True,
      member=_create_scalar_tensor_type(tf.bool))
  expected_proto = pb.Type(federated=expected_federated)

  self.assertEqual(actual_proto, expected_proto)
def serialize_type(type_spec):
  """Converts `type_spec` into its `pb.Type` protocol buffer representation.

  Tensor, sequence, named tuple, function, placement, and federated types are
  handled. Federated types are only serialized when their placement is a
  `placement_literals.PlacementLiteral`. Nested types (sequence elements,
  tuple elements, function parameter/result, federated member) are serialized
  recursively.

  Args:
    type_spec: Either an instance of computation_types.Type, or something
      convertible to it by computation_types.to_type(), or None.

  Returns:
    The corresponding instance of `pb.Type`, or `None` if the argument was
    `None`.

  Raises:
    TypeError: if the argument is of the wrong type.
    NotImplementedError: for type variants for which serialization is not
      implemented, including federated types whose placement is not a
      placement literal.
  """
  # TODO(b/113112885): Implement serialization of the remaining types.
  if type_spec is None:
    return None

  target = computation_types.to_type(type_spec)
  py_typecheck.check_type(target, computation_types.Type)

  if isinstance(target, computation_types.TensorType):
    return pb.Type(tensor=_to_tensor_type_proto(target))

  if isinstance(target, computation_types.SequenceType):
    element_proto = serialize_type(target.element)
    return pb.Type(sequence=pb.SequenceType(element=element_proto))

  if isinstance(target, computation_types.NamedTupleType):
    element_protos = []
    for name, element_type in anonymous_tuple.iter_elements(target):
      element_protos.append(
          pb.NamedTupleType.Element(
              name=name, value=serialize_type(element_type)))
    return pb.Type(tuple=pb.NamedTupleType(element=element_protos))

  if isinstance(target, computation_types.FunctionType):
    parameter_proto = serialize_type(target.parameter)
    result_proto = serialize_type(target.result)
    return pb.Type(
        function=pb.FunctionType(
            parameter=parameter_proto, result=result_proto))

  if isinstance(target, computation_types.PlacementType):
    return pb.Type(placement=pb.PlacementType())

  if isinstance(target, computation_types.FederatedType):
    # Only literal placements can be written out as a placement URI.
    if not isinstance(target.placement, placement_literals.PlacementLiteral):
      raise NotImplementedError(
          'Serialization of federated types with placements specifications '
          'of type {} is not currently implemented yet.'.format(
              type(target.placement)))
    member_proto = serialize_type(target.member)
    placement_spec = pb.PlacementSpec(
        value=pb.Placement(uri=target.placement.uri))
    return pb.Type(
        federated=pb.FederatedType(
            member=member_proto,
            placement=placement_spec,
            all_equal=target.all_equal))

  raise NotImplementedError
def serialize_type(
    type_spec: Optional[computation_types.Type]) -> Optional[pb.Type]:
  """Converts `type_spec` into its `pb.Type` protocol buffer representation.

  Tensor, sequence, struct, function, placement, and federated types are
  handled; nested types are serialized recursively. Successful results are
  stored in `_type_serialization_cache` keyed by `type_spec`, and a cached
  proto is returned as-is on later calls with an equal key.

  Args:
    type_spec: A `computation_types.Type`, or `None`.

  Returns:
    The corresponding instance of `pb.Type`, or `None` if the argument was
    `None`.

  Raises:
    TypeError: if the argument is of the wrong type.
    NotImplementedError: for type variants for which serialization is not
      implemented.
  """
  if type_spec is None:
    return None

  previously_serialized = _type_serialization_cache.get(type_spec)
  if previously_serialized is not None:
    return previously_serialized

  if type_spec.is_tensor():
    serialized = pb.Type(tensor=_to_tensor_type_proto(type_spec))
  elif type_spec.is_sequence():
    element_proto = serialize_type(type_spec.element)
    serialized = pb.Type(sequence=pb.SequenceType(element=element_proto))
  elif type_spec.is_struct():
    element_protos = []
    for name, element_type in structure.iter_elements(type_spec):
      element_protos.append(
          pb.StructType.Element(name=name, value=serialize_type(element_type)))
    serialized = pb.Type(struct=pb.StructType(element=element_protos))
  elif type_spec.is_function():
    parameter_proto = serialize_type(type_spec.parameter)
    result_proto = serialize_type(type_spec.result)
    serialized = pb.Type(
        function=pb.FunctionType(
            parameter=parameter_proto, result=result_proto))
  elif type_spec.is_placement():
    serialized = pb.Type(placement=pb.PlacementType())
  elif type_spec.is_federated():
    member_proto = serialize_type(type_spec.member)
    placement_spec = pb.PlacementSpec(
        value=pb.Placement(uri=type_spec.placement.uri))
    serialized = pb.Type(
        federated=pb.FederatedType(
            member=member_proto,
            placement=placement_spec,
            all_equal=type_spec.all_equal))
  else:
    # Nothing is cached for unsupported variants.
    raise NotImplementedError

  _type_serialization_cache[type_spec] = serialized
  return serialized
def test_deserialize_federated_all_equal_value_takes_first_element(self):
  """Deserializes repeated members with an all-equal type hint.

  Five copies of one serialized int32 tensor are packed into a federated
  value whose type proto carries only the CLIENTS placement. Deserializing
  with an all-equal clients type hint is expected to return that hint as the
  type and a value equal to 10.
  """
  client_count = 5
  member_pb, _ = value_serialization.serialize_value(10, TensorType(tf.int32))

  clients_placement_spec = computation_pb2.PlacementSpec(
      value=computation_pb2.Placement(uri=placements.CLIENTS.uri))
  federated_pb = executor_pb2.Value.Federated(
      value=[member_pb] * client_count,
      type=computation_pb2.FederatedType(placement=clients_placement_spec))
  value_pb = executor_pb2.Value(federated=federated_pb)

  type_hint = computation_types.FederatedType(
      tf.int32, placements.CLIENTS, all_equal=True)
  result = value_serialization.deserialize_value(value_pb, type_hint)
  actual_value, actual_type = result

  type_test_utils.assert_types_identical(actual_type, type_hint)
  self.assertAllEqual(actual_value, 10)
def test_deserialize_federated_value_with_unset_member_type(self):
  """Deserializes a federated value whose type proto has no member set.

  The federated type proto copies the placement and `all_equal` flag from a
  serialized `at_clients(tf.int32)` type but leaves the member unset. The
  deserialized type is expected to be int32 at clients and the value `[10]`.
  """
  int32_type = computation_types.to_type(tf.int32)
  member_pb, _ = value_serialization.serialize_value(10, int32_type)

  clients_federated_pb = type_serialization.serialize_type(
      computation_types.at_clients(tf.int32)).federated
  memberless_federated_type_pb = computation_pb2.FederatedType(
      placement=clients_federated_pb.placement,
      all_equal=clients_federated_pb.all_equal)

  value_pb = executor_pb2.Value(
      federated=executor_pb2.Value.Federated(
          type=memberless_federated_type_pb, value=[member_pb]))

  result = value_serialization.deserialize_value(value_pb)
  actual_value, actual_type = result

  type_test_utils.assert_types_identical(
      actual_type, computation_types.at_clients(tf.int32))
  self.assertEqual(actual_value, [10])