def serialize_type(
    type_spec: Optional[computation_types.Type]) -> Optional[pb.Type]:
  """Serializes 'type_spec' as a pb.Type.

  Note: Serialization is implemented for tensor, sequence, struct (named
  tuple), function, placement, and federated types; any other variant raises
  `NotImplementedError`.

  Results are memoized in `_type_serialization_cache`, keyed by `type_spec`.
  On a cache hit the previously stored `pb.Type` object itself is returned
  (not a copy), so callers must treat the result as read-only: mutating it
  would alter what later callers receive for the same type.

  Args:
    type_spec: A `computation_types.Type`, or `None`.

  Returns:
    The corresponding instance of `pb.Type`, or `None` if the argument was
    `None`.

  Raises:
    TypeError: if the argument is of the wrong type.
    NotImplementedError: for type variants for which serialization is not
      implemented.
  """
  if type_spec is None:
    return None
  # Validate up front so a wrong argument surfaces as the documented TypeError
  # instead of an AttributeError from the `is_*()` probes below (or a hashing
  # error from the cache lookup).
  if not isinstance(type_spec, computation_types.Type):
    raise TypeError(
        'Expected `type_spec` to be a `computation_types.Type` or `None`, '
        f'found {type(type_spec)!r}.')

  cached_proto = _type_serialization_cache.get(type_spec, None)
  if cached_proto is not None:
    return cached_proto

  if type_spec.is_tensor():
    proto = pb.Type(tensor=_to_tensor_type_proto(type_spec))
  elif type_spec.is_sequence():
    proto = pb.Type(
        sequence=pb.SequenceType(element=serialize_type(type_spec.element)))
  elif type_spec.is_struct():
    # `structure.iter_elements` yields (name, member_type) pairs; each member
    # is serialized recursively.
    proto = pb.Type(
        struct=pb.StructType(element=[
            pb.StructType.Element(name=e[0], value=serialize_type(e[1]))
            for e in structure.iter_elements(type_spec)
        ]))
  elif type_spec.is_function():
    proto = pb.Type(
        function=pb.FunctionType(
            parameter=serialize_type(type_spec.parameter),
            result=serialize_type(type_spec.result)))
  elif type_spec.is_placement():
    proto = pb.Type(placement=pb.PlacementType())
  elif type_spec.is_federated():
    proto = pb.Type(
        federated=pb.FederatedType(
            member=serialize_type(type_spec.member),
            placement=pb.PlacementSpec(
                value=pb.Placement(uri=type_spec.placement.uri)),
            all_equal=type_spec.all_equal))
  else:
    raise NotImplementedError(
        f'Serialization of type {type_spec!r} is not implemented.')

  _type_serialization_cache[type_spec] = proto
  return proto
def test_serialize_type_with_function(self):
  """A (int32, int32) -> bool function serializes to a FunctionType proto."""
  function_type = computation_types.FunctionType((tf.int32, tf.int32), tf.bool)
  serialized = type_serialization.serialize_type(function_type)

  # The tuple parameter becomes a struct of two unnamed scalar int32 elements.
  parameter_elements = [
      pb.StructType.Element(value=_create_scalar_tensor_type(tf.int32)),
      pb.StructType.Element(value=_create_scalar_tensor_type(tf.int32)),
  ]
  parameter_proto = pb.Type(struct=pb.StructType(element=parameter_elements))
  result_proto = _create_scalar_tensor_type(tf.bool)
  expected = pb.Type(
      function=pb.FunctionType(parameter=parameter_proto, result=result_proto))

  self.assertEqual(serialized, expected)
def test_serialize_type_with_tensor_tuple(self):
  """A mixed named/unnamed struct of scalars serializes element by element."""
  struct_type = computation_types.StructType([
      ('x', tf.int32),
      ('y', tf.string),
      tf.float32,
      ('z', tf.bool),
  ])
  serialized = type_serialization.serialize_type(struct_type)

  # Named members carry their names; the bare tf.float32 member has none.
  x_element = pb.StructType.Element(
      name='x', value=_create_scalar_tensor_type(tf.int32))
  y_element = pb.StructType.Element(
      name='y', value=_create_scalar_tensor_type(tf.string))
  unnamed_element = pb.StructType.Element(
      value=_create_scalar_tensor_type(tf.float32))
  z_element = pb.StructType.Element(
      name='z', value=_create_scalar_tensor_type(tf.bool))
  expected = pb.Type(
      struct=pb.StructType(
          element=[x_element, y_element, unnamed_element, z_element]))

  self.assertEqual(serialized, expected)
def _tuple_type_proto(elements):
  """Wraps `elements` in a struct-valued `pb.Type`.

  Args:
    elements: The elements of the struct, passed straight through as the
      `element` field of a `pb.StructType` (presumably an iterable of
      `pb.StructType.Element` messages -- confirm at call sites).

  Returns:
    A `pb.Type` whose `struct` field is the `pb.StructType` built from
    `elements`.
  """
  struct_proto = pb.StructType(element=elements)
  return pb.Type(struct=struct_proto)