示例#1
0
def serialize_type(type_spec):
  """Serializes 'type_spec' as a pb.Type.

  Note: Serialization is implemented for tensor, named tuple, sequence,
  function, placement, and federated types. Federated types are only supported
  when their placement is a `placement_literals.PlacementLiteral`. Any other
  type variant raises `NotImplementedError`.

  Args:
    type_spec: Either an instance of computation_types.Type, or something
      convertible to it by computation_types.to_type(), or None.

  Returns:
    The corresponding instance of `pb.Type`, or `None` if the argument was
      `None`.

  Raises:
    TypeError: if the argument is of the wrong type.
    NotImplementedError: for type variants for which serialization is not
      implemented.
  """
  # TODO(b/113112885): Implement serialization of the remaining types.
  if type_spec is None:
    return None
  target = computation_types.to_type(type_spec)
  py_typecheck.check_type(target, computation_types.Type)
  if isinstance(target, computation_types.TensorType):
    return pb.Type(tensor=_to_tensor_type_proto(target))
  elif isinstance(target, computation_types.SequenceType):
    # The element type is serialized recursively.
    return pb.Type(
        sequence=pb.SequenceType(element=serialize_type(target.element)))
  elif isinstance(target, computation_types.NamedTupleType):
    # `iter_elements` yields pairs: the first item becomes the element name,
    # the second is the element type, serialized recursively.
    return pb.Type(
        tuple=pb.NamedTupleType(element=[
            pb.NamedTupleType.Element(name=e[0], value=serialize_type(e[1]))
            for e in anonymous_tuple.iter_elements(target)
        ]))
  elif isinstance(target, computation_types.FunctionType):
    # If the parameter is None, the recursive call returns None through the
    # early return above.
    return pb.Type(
        function=pb.FunctionType(
            parameter=serialize_type(target.parameter),
            result=serialize_type(target.result)))
  elif isinstance(target, computation_types.PlacementType):
    return pb.Type(placement=pb.PlacementType())
  elif isinstance(target, computation_types.FederatedType):
    if isinstance(target.placement, placement_literals.PlacementLiteral):
      return pb.Type(
          federated=pb.FederatedType(
              member=serialize_type(target.member),
              placement=pb.PlacementSpec(
                  value=pb.Placement(uri=target.placement.uri)),
              all_equal=target.all_equal))
    else:
      raise NotImplementedError(
          'Serialization of federated types with placements specifications '
          'of type {} is not currently implemented yet.'.format(
              type(target.placement)))
  else:
    # Report which type could not be serialized instead of a bare error.
    raise NotImplementedError(
        'Serialization of type {} (Python class {}) is not implemented.'.format(
            target, type(target).__name__))
 def test_serialize_type_with_function(self):
     """Checks that `(int32, int32) -> bool` serializes to a function proto."""
     function_type = computation_types.FunctionType((tf.int32, tf.int32),
                                                    tf.bool)
     serialized = type_serialization.serialize_type(function_type)
     # The parameter is a two-element tuple of scalar int32 tensors without
     # element names.
     parameter_elements = [
         pb.NamedTupleType.Element(value=_create_scalar_tensor_type(tf.int32))
         for _ in range(2)
     ]
     parameter_proto = pb.Type(
         tuple=pb.NamedTupleType(element=parameter_elements))
     expected = pb.Type(
         function=pb.FunctionType(
             parameter=parameter_proto,
             result=_create_scalar_tensor_type(tf.bool)))
     self.assertEqual(serialized, expected)
def serialize_type(
        type_spec: Optional[computation_types.Type]) -> Optional[pb.Type]:
    """Serializes 'type_spec' as a pb.Type.

    Note: Serialization is implemented for tensor, named tuple, sequence,
    function, placement, and federated types. Any other type variant raises
    `NotImplementedError`.

    Args:
      type_spec: A `computation_types.Type`, or `None`.

    Returns:
      The corresponding instance of `pb.Type`, or `None` if the argument was
        `None`.

    Raises:
      TypeError: if the argument is neither `None` nor a
        `computation_types.Type`.
      NotImplementedError: for type variants for which serialization is not
        implemented.
    """
    if type_spec is None:
        return None
    # Enforce the documented contract. Without this check, a non-`Type`
    # argument would typically fail with an AttributeError on the `is_*()`
    # calls below.
    if not isinstance(type_spec, computation_types.Type):
        raise TypeError(
            'Expected a computation_types.Type or None, found {}.'.format(
                type(type_spec)))
    if type_spec.is_tensor():
        return pb.Type(tensor=_to_tensor_type_proto(type_spec))
    elif type_spec.is_sequence():
        # The element type is serialized recursively.
        return pb.Type(sequence=pb.SequenceType(
            element=serialize_type(type_spec.element)))
    elif type_spec.is_tuple():
        # `iter_elements` yields pairs: the first item becomes the element
        # name, the second is the element type, serialized recursively.
        return pb.Type(tuple=pb.NamedTupleType(element=[
            pb.NamedTupleType.Element(name=e[0], value=serialize_type(e[1]))
            for e in anonymous_tuple.iter_elements(type_spec)
        ]))
    elif type_spec.is_function():
        # If the parameter is None, the recursive call returns None through
        # the early return above.
        return pb.Type(function=pb.FunctionType(
            parameter=serialize_type(type_spec.parameter),
            result=serialize_type(type_spec.result)))
    elif type_spec.is_placement():
        return pb.Type(placement=pb.PlacementType())
    elif type_spec.is_federated():
        return pb.Type(
            federated=pb.FederatedType(member=serialize_type(type_spec.member),
                                       placement=pb.PlacementSpec(
                                           value=pb.Placement(
                                               uri=type_spec.placement.uri)),
                                       all_equal=type_spec.all_equal))
    else:
        # Report which type could not be serialized instead of a bare error.
        raise NotImplementedError(
            'Serialization of type {} (Python class {}) is not '
            'implemented.'.format(type_spec, type(type_spec).__name__))
 def test_serialize_type_with_tensor_tuple(self):
     """Checks that a mix of named and unnamed tensors serializes in order."""
     type_spec = [
         ('x', tf.int32),
         ('y', tf.string),
         tf.float32,
         ('z', tf.bool),
     ]
     serialized = type_serialization.serialize_type(type_spec)
     # Expected (name, dtype) pairs; a `None` name marks an unnamed element,
     # which must be built without setting the `name` field.
     expected_layout = (
         ('x', tf.int32),
         ('y', tf.string),
         (None, tf.float32),
         ('z', tf.bool),
     )
     expected_elements = []
     for element_name, dtype in expected_layout:
         tensor_proto = _create_scalar_tensor_type(dtype)
         if element_name is None:
             element = pb.NamedTupleType.Element(value=tensor_proto)
         else:
             element = pb.NamedTupleType.Element(
                 name=element_name, value=tensor_proto)
         expected_elements.append(element)
     expected = pb.Type(tuple=pb.NamedTupleType(element=expected_elements))
     self.assertEqual(serialized, expected)
 def _tuple_type_proto(elements):
     """Returns a `pb.Type` whose `tuple` field holds `elements`."""
     named_tuple_proto = pb.NamedTupleType(element=elements)
     return pb.Type(tuple=named_tuple_proto)