Esempio n. 1
0
 def test_serialize_tensor_type(self, dtype, shape):
   """Checks serialization of `TensorType(dtype, shape)` to a tensor proto.

   The expected proto carries the dtype's `as_datatype_enum` and the dims
   produced by `_shape_to_dims(shape)`. Presumably this test is parameterized
   over (dtype, shape) pairs by a decorator outside this view.
   """
   tensor_spec = computation_types.TensorType(dtype, shape)
   serialized = type_serialization.serialize_type(tensor_spec)
   expected_tensor = pb.TensorType(
       dtype=dtype.as_datatype_enum, dims=_shape_to_dims(shape))
   self.assertEqual(serialized, pb.Type(tensor=expected_tensor))
def serialize_type(type_spec):
    """Serializes 'type_spec' as a pb.Type.

    NOTE: Serialization is implemented for tensor, sequence, named tuple,
    function, placement, and federated types. Federated types are supported
    only when their placement is a `placement_literals.PlacementLiteral`.

    Args:
      type_spec: Either an instance of computation_types.Type, or something
        convertible to it by computation_types.to_type(), or None.

    Returns:
      The corresponding instance of `pb.Type`, or `None` if the argument was
        `None`.

    Raises:
      TypeError: if the argument is of the wrong type.
      NotImplementedError: for type variants for which serialization is not
        implemented.
    """
    # TODO(b/113112885): Implement serialization of the remaining types.
    if type_spec is None:
        return None
    target = computation_types.to_type(type_spec)
    py_typecheck.check_type(target, computation_types.Type)
    if isinstance(target, computation_types.TensorType):
        # Delegate to the shared helper so every tensor type is encoded the
        # same way: `dims` (-1 for unknown sizes) plus `unknown_rank`.
        return pb.Type(tensor=_to_tensor_type_proto(target))
    elif isinstance(target, computation_types.SequenceType):
        return pb.Type(sequence=pb.SequenceType(
            element=serialize_type(target.element)))
    elif isinstance(target, computation_types.NamedTupleType):
        # Each element is a (name, type) pair; the type is serialized
        # recursively.
        return pb.Type(tuple=pb.NamedTupleType(element=[
            pb.NamedTupleType.Element(name=e[0], value=serialize_type(e[1]))
            for e in anonymous_tuple.to_elements(target)
        ]))
    elif isinstance(target, computation_types.FunctionType):
        return pb.Type(function=pb.FunctionType(
            parameter=serialize_type(target.parameter),
            result=serialize_type(target.result)))
    elif isinstance(target, computation_types.PlacementType):
        return pb.Type(placement=pb.PlacementType())
    elif isinstance(target, computation_types.FederatedType):
        if isinstance(target.placement, placement_literals.PlacementLiteral):
            return pb.Type(federated=pb.FederatedType(
                member=serialize_type(target.member),
                placement=pb.PlacementSpec(value=pb.Placement(
                    uri=target.placement.uri)),
                all_equal=target.all_equal))
        else:
            raise NotImplementedError(
                'Serialization of federated types with placements specifications '
                'of type {} is not currently implemented yet.'.format(
                    type(target.placement)))
    else:
        # Name the unsupported type so callers can tell what failed.
        raise NotImplementedError(
            'Serialization of type {} is not implemented.'.format(target))
def _to_tensor_type_proto(
        tensor_type: computation_types.TensorType) -> pb.TensorType:
    """Converts `tensor_type` into a `pb.TensorType`.

    The dtype is reduced to its `base_dtype` before taking the datatype enum.
    If the shape has unknown rank (`shape.dims is None`), `dims` is passed as
    None and `unknown_rank` is set to True. Otherwise each dimension whose
    `value` is None is encoded as -1.
    """
    static_dims = tensor_type.shape.dims
    if static_dims is not None:
        dim_sizes = [
            -1 if dim.value is None else dim.value for dim in static_dims
        ]
    else:
        dim_sizes = None
    return pb.TensorType(
        dtype=tensor_type.dtype.base_dtype.as_datatype_enum,
        dims=dim_sizes,
        unknown_rank=dim_sizes is None,
    )
Esempio n. 4
0
def _create_scalar_tensor_type(dtype):
    """Wraps a `pb.TensorType` for `dtype` with no dims set in a `pb.Type`.

    Only the dtype's `as_datatype_enum` is filled in. Per the function name,
    this is intended to describe a scalar tensor.
    """
    scalar_tensor = pb.TensorType(dtype=dtype.as_datatype_enum)
    return pb.Type(tensor=scalar_tensor)