def test_serialize_tensor_type(self, dtype, shape):
  """Checks that a `TensorType` round-trips into the expected `pb.Type`.

  The serialized proto must carry the dtype enum of `dtype` and the dims that
  `_shape_to_dims` derives from `shape`.
  """
  serialized = type_serialization.serialize_type(
      computation_types.TensorType(dtype, shape))
  tensor_proto = pb.TensorType(
      dtype=dtype.as_datatype_enum, dims=_shape_to_dims(shape))
  self.assertEqual(serialized, pb.Type(tensor=tensor_proto))
def serialize_type(type_spec):
  """Serializes 'type_spec' as a pb.Type.

  NOTE: Serialization is currently implemented for tensor, sequence, named
  tuple, function, placement, and federated types. Federated types are only
  supported when their placement is a `placement_literals.PlacementLiteral`.

  Args:
    type_spec: Either an instance of computation_types.Type, or something
      convertible to it by computation_types.to_type(), or None.

  Returns:
    The corresponding instance of `pb.Type`, or `None` if the argument was
    `None`.

  Raises:
    TypeError: if the argument is of the wrong type.
    NotImplementedError: for type variants for which serialization is not
      implemented.
  """
  # TODO(b/113112885): Implement serialization of the remaining types.
  if type_spec is None:
    return None
  target = computation_types.to_type(type_spec)
  py_typecheck.check_type(target, computation_types.Type)
  if isinstance(target, computation_types.TensorType):
    # Delegate to the shared helper so the shape is encoded as `dims` /
    # `unknown_rank`, the same encoding `_to_tensor_type_proto` produces.
    # The previous `shape=` field was inconsistent with that encoding.
    return pb.Type(tensor=_to_tensor_type_proto(target))
  elif isinstance(target, computation_types.SequenceType):
    return pb.Type(
        sequence=pb.SequenceType(element=serialize_type(target.element)))
  elif isinstance(target, computation_types.NamedTupleType):
    # Each element is a (name, type) pair; the element types are serialized
    # recursively.
    return pb.Type(
        tuple=pb.NamedTupleType(element=[
            pb.NamedTupleType.Element(name=e[0], value=serialize_type(e[1]))
            for e in anonymous_tuple.to_elements(target)
        ]))
  elif isinstance(target, computation_types.FunctionType):
    return pb.Type(
        function=pb.FunctionType(
            parameter=serialize_type(target.parameter),
            result=serialize_type(target.result)))
  elif isinstance(target, computation_types.PlacementType):
    return pb.Type(placement=pb.PlacementType())
  elif isinstance(target, computation_types.FederatedType):
    if isinstance(target.placement, placement_literals.PlacementLiteral):
      return pb.Type(
          federated=pb.FederatedType(
              member=serialize_type(target.member),
              placement=pb.PlacementSpec(
                  value=pb.Placement(uri=target.placement.uri)),
              all_equal=target.all_equal))
    else:
      raise NotImplementedError(
          'Serialization of federated types with placements specifications '
          'of type {} is not currently implemented yet.'.format(
              type(target.placement)))
  else:
    # Name the unsupported type rather than raising a message-less error.
    raise NotImplementedError(
        'Serialization of type {} is not currently implemented.'.format(
            type(target)))
def _to_tensor_type_proto(
    tensor_type: computation_types.TensorType) -> pb.TensorType:
  """Builds the `pb.TensorType` message for `tensor_type`.

  Args:
    tensor_type: The `computation_types.TensorType` to convert.

  Returns:
    A `pb.TensorType` whose `dtype` is the enum of `tensor_type.dtype.base_dtype`.
    If the rank is known, `dims` lists every dimension, with -1 standing in
    for a dimension whose size is unknown. If the rank is unknown
    (`shape.dims is None`), `dims` is left unset and `unknown_rank` is True.
  """
  shape_dims = tensor_type.shape.dims
  rank_is_unknown = shape_dims is None
  # -1 is the sentinel for a dimension of unknown size.
  dims = None if rank_is_unknown else [
      -1 if dim.value is None else dim.value for dim in shape_dims
  ]
  base_dtype = tensor_type.dtype.base_dtype
  return pb.TensorType(
      dtype=base_dtype.as_datatype_enum,
      dims=dims,
      unknown_rank=rank_is_unknown)
def _create_scalar_tensor_type(dtype):
  """Returns a `pb.Type` wrapping a `pb.TensorType` for `dtype` with no dims set.

  Only the dtype enum is populated. `dims` stays empty and `unknown_rank` is
  left at its default, which presumably encodes a scalar (rank-0) tensor, as
  the name suggests.
  """
  scalar_tensor = pb.TensorType(dtype=dtype.as_datatype_enum)
  return pb.Type(tensor=scalar_tensor)