def serialize_type(type_spec):
  """Serializes 'type_spec' as a pb.Type.

  Note: Serialization is currently implemented for tensor, sequence, named
  tuple, function, placement, and federated types. Federated types are only
  supported when their placement is a `placement_literals.PlacementLiteral`.

  Args:
    type_spec: Either an instance of computation_types.Type, or something
      convertible to it by computation_types.to_type(), or None.

  Returns:
    The corresponding instance of `pb.Type`, or `None` if the argument was
      `None`.

  Raises:
    TypeError: if the argument is of the wrong type.
    NotImplementedError: for type variants for which serialization is not
      implemented.
  """
  # TODO(b/113112885): Implement serialization of the remaining types.
  if type_spec is None:
    return None
  target = computation_types.to_type(type_spec)
  py_typecheck.check_type(target, computation_types.Type)
  if isinstance(target, computation_types.TensorType):
    return pb.Type(tensor=_to_tensor_type_proto(target))
  elif isinstance(target, computation_types.SequenceType):
    return pb.Type(
        sequence=pb.SequenceType(element=serialize_type(target.element)))
  elif isinstance(target, computation_types.NamedTupleType):
    # Elements are serialized in order; each is a (name, type) pair.
    return pb.Type(
        tuple=pb.NamedTupleType(element=[
            pb.NamedTupleType.Element(name=e[0], value=serialize_type(e[1]))
            for e in anonymous_tuple.iter_elements(target)
        ]))
  elif isinstance(target, computation_types.FunctionType):
    # `target.parameter` may be None (presumably for no-argument functions);
    # serialize_type maps None to None, leaving the proto field unset.
    return pb.Type(
        function=pb.FunctionType(
            parameter=serialize_type(target.parameter),
            result=serialize_type(target.result)))
  elif isinstance(target, computation_types.PlacementType):
    return pb.Type(placement=pb.PlacementType())
  elif isinstance(target, computation_types.FederatedType):
    if isinstance(target.placement, placement_literals.PlacementLiteral):
      return pb.Type(
          federated=pb.FederatedType(
              member=serialize_type(target.member),
              placement=pb.PlacementSpec(
                  value=pb.Placement(uri=target.placement.uri)),
              all_equal=target.all_equal))
    else:
      raise NotImplementedError(
          'Serialization of federated types with placements specifications '
          'of type {} is not currently implemented yet.'.format(
              type(target.placement)))
  else:
    # Name the offending type so callers can tell what was unsupported.
    raise NotImplementedError(
        'Serialization of type {} is not currently implemented.'.format(
            target))
def test_serialize_type_with_function(self):
  """Checks that an (int32, int32) -> bool function serializes correctly."""
  function_type = computation_types.FunctionType((tf.int32, tf.int32), tf.bool)
  actual_proto = type_serialization.serialize_type(function_type)

  # The parameter is an unnamed two-element tuple of scalar int32 tensors.
  parameter_proto = pb.Type(tuple=pb.NamedTupleType(element=[
      pb.NamedTupleType.Element(value=_create_scalar_tensor_type(tf.int32))
      for _ in range(2)
  ]))
  # The result is a scalar bool tensor.
  result_proto = _create_scalar_tensor_type(tf.bool)
  expected_proto = pb.Type(
      function=pb.FunctionType(parameter=parameter_proto,
                               result=result_proto))
  self.assertEqual(actual_proto, expected_proto)
def serialize_type(
    type_spec: Optional[computation_types.Type]) -> Optional[pb.Type]:
  """Serializes 'type_spec' as a pb.Type.

  Note: Serialization is currently implemented for tensor, sequence, named
  tuple, function, placement, and federated types.

  Args:
    type_spec: A `computation_types.Type`, or `None`.

  Returns:
    The corresponding instance of `pb.Type`, or `None` if the argument was
      `None`.

  Raises:
    TypeError: if the argument is neither a `computation_types.Type` nor
      `None`.
    NotImplementedError: for type variants for which serialization is not
      implemented.
  """
  if type_spec is None:
    return None
  # Enforce the documented contract up front instead of failing later with
  # an AttributeError on the first `is_*()` call.
  if not isinstance(type_spec, computation_types.Type):
    raise TypeError(
        'Expected a computation_types.Type or None, found {}.'.format(
            type(type_spec)))
  if type_spec.is_tensor():
    return pb.Type(tensor=_to_tensor_type_proto(type_spec))
  elif type_spec.is_sequence():
    return pb.Type(sequence=pb.SequenceType(
        element=serialize_type(type_spec.element)))
  elif type_spec.is_tuple():
    # Elements are serialized in order; each is a (name, type) pair.
    return pb.Type(tuple=pb.NamedTupleType(element=[
        pb.NamedTupleType.Element(name=e[0], value=serialize_type(e[1]))
        for e in anonymous_tuple.iter_elements(type_spec)
    ]))
  elif type_spec.is_function():
    return pb.Type(function=pb.FunctionType(
        parameter=serialize_type(type_spec.parameter),
        result=serialize_type(type_spec.result)))
  elif type_spec.is_placement():
    return pb.Type(placement=pb.PlacementType())
  elif type_spec.is_federated():
    # NOTE(review): assumes `type_spec.placement` exposes a `uri` attribute
    # (i.e. is a placement literal) — confirm for all federated types.
    return pb.Type(
        federated=pb.FederatedType(
            member=serialize_type(type_spec.member),
            placement=pb.PlacementSpec(
                value=pb.Placement(uri=type_spec.placement.uri)),
            all_equal=type_spec.all_equal))
  else:
    # Name the offending type so callers can tell what was unsupported.
    raise NotImplementedError(
        'Serialization of type {} is not currently implemented.'.format(
            type_spec))
def test_serialize_type_with_tensor_tuple(self):
  """Checks that a mixed named/unnamed tuple of scalars serializes in order."""
  actual_proto = type_serialization.serialize_type([
      ('x', tf.int32),
      ('y', tf.string),
      tf.float32,
      ('z', tf.bool),
  ])

  def _scalar_element(dtype, name=None):
    # Unnamed elements are built without a `name` field at all.
    if name is None:
      return pb.NamedTupleType.Element(
          value=_create_scalar_tensor_type(dtype))
    return pb.NamedTupleType.Element(
        name=name, value=_create_scalar_tensor_type(dtype))

  expected_elements = [
      _scalar_element(tf.int32, 'x'),
      _scalar_element(tf.string, 'y'),
      _scalar_element(tf.float32),
      _scalar_element(tf.bool, 'z'),
  ]
  expected_proto = pb.Type(
      tuple=pb.NamedTupleType(element=expected_elements))
  self.assertEqual(actual_proto, expected_proto)
def _tuple_type_proto(elements):
  """Wraps `elements` in a `pb.Type` whose `tuple` field is a NamedTupleType.

  Args:
    elements: The value for `pb.NamedTupleType.element`, presumably a
      sequence of `pb.NamedTupleType.Element` protos.

  Returns:
    A `pb.Type` with its `tuple` field set.
  """
  named_tuple_proto = pb.NamedTupleType(element=elements)
  return pb.Type(tuple=named_tuple_proto)