コード例 #1
0
def serialize_type(
    type_spec: Optional[computation_types.Type]) -> Optional[pb.Type]:
  """Serializes 'type_spec' as a pb.Type.

  Serialization is implemented for tensor, sequence, struct, function,
  placement, and federated types. Nested types (sequence elements, struct
  members, function parameter/result, federated members) are serialized by
  recursive calls to this function.

  Results are memoized in `_type_serialization_cache`: on a cache hit the
  previously built `pb.Type` instance itself is returned, not a copy.

  NOTE(review): a non-`None` argument is not validated up front; there is no
  explicit check that it is a `computation_types.Type`. Confirm whether any
  caller relies on a dedicated `TypeError` for wrong-typed arguments.

  Args:
    type_spec: A `computation_types.Type`, or `None`.

  Returns:
    The corresponding instance of `pb.Type`, or `None` if the argument was
      `None`.

  Raises:
    NotImplementedError: for type variants for which serialization is not
      implemented.
  """
  if type_spec is None:
    return None

  # Fast path: reuse the proto built by an earlier call for this type.
  cached_proto = _type_serialization_cache.get(type_spec, None)
  if cached_proto is not None:
    return cached_proto

  if type_spec.is_tensor():
    proto = pb.Type(tensor=_to_tensor_type_proto(type_spec))
  elif type_spec.is_sequence():
    proto = pb.Type(
        sequence=pb.SequenceType(element=serialize_type(type_spec.element)))
  elif type_spec.is_struct():
    # Each struct member is a (name, type) pair; the name is forwarded to the
    # proto exactly as yielded by `structure.iter_elements`.
    proto = pb.Type(
        struct=pb.StructType(element=[
            pb.StructType.Element(
                name=name, value=serialize_type(element_type))
            for name, element_type in structure.iter_elements(type_spec)
        ]))
  elif type_spec.is_function():
    # `serialize_type` maps a `None` parameter/result to `None`.
    proto = pb.Type(
        function=pb.FunctionType(
            parameter=serialize_type(type_spec.parameter),
            result=serialize_type(type_spec.result)))
  elif type_spec.is_placement():
    proto = pb.Type(placement=pb.PlacementType())
  elif type_spec.is_federated():
    proto = pb.Type(
        federated=pb.FederatedType(
            member=serialize_type(type_spec.member),
            placement=pb.PlacementSpec(
                value=pb.Placement(uri=type_spec.placement.uri)),
            all_equal=type_spec.all_equal))
  else:
    # Name the offending class so the failure is actionable for the caller.
    raise NotImplementedError(
        'Serialization is not implemented for type variant {}.'.format(
            type(type_spec).__name__))

  # Only successfully built protos are cached; the raise above skips this.
  _type_serialization_cache[type_spec] = proto
  return proto
コード例 #2
0
 def test_serialize_type_with_function(self):
     """Checks the proto produced for a two-parameter function type.

     The input is `FunctionType((tf.int32, tf.int32), tf.bool)`. The expected
     proto has a struct parameter with two unnamed elements and a result,
     each value built with `_create_scalar_tensor_type`.
     """
     function_type = computation_types.FunctionType((tf.int32, tf.int32),
                                                    tf.bool)
     serialized = type_serialization.serialize_type(function_type)

     # Both parameter elements are left unnamed, matching the plain tuple
     # passed to `FunctionType` above.
     parameter_elements = [
         pb.StructType.Element(value=_create_scalar_tensor_type(tf.int32))
         for _ in range(2)
     ]
     parameter_proto = pb.Type(
         struct=pb.StructType(element=parameter_elements))
     result_proto = _create_scalar_tensor_type(tf.bool)
     wanted = pb.Type(
         function=pb.FunctionType(parameter=parameter_proto,
                                  result=result_proto))

     self.assertEqual(serialized, wanted)
コード例 #3
0
 def test_serialize_type_with_tensor_tuple(self):
     """Checks the proto produced for a struct mixing named and unnamed fields.

     The struct has named members 'x', 'y' and 'z' plus one unnamed member in
     third position; the expected proto lists the same four elements in the
     same order, leaving `name` unset for the unnamed one.
     """
     struct_type = computation_types.StructType([
         ('x', tf.int32),
         ('y', tf.string),
         tf.float32,
         ('z', tf.bool),
     ])
     serialized = type_serialization.serialize_type(struct_type)

     # (name, dtype) for every expected element; `None` marks the unnamed one.
     layout = (
         ('x', tf.int32),
         ('y', tf.string),
         (None, tf.float32),
         ('z', tf.bool),
     )
     wanted_elements = []
     for field_name, dtype in layout:
         fields = {'value': _create_scalar_tensor_type(dtype)}
         if field_name is not None:
             fields['name'] = field_name
         wanted_elements.append(pb.StructType.Element(**fields))
     wanted = pb.Type(struct=pb.StructType(element=wanted_elements))

     self.assertEqual(serialized, wanted)
コード例 #4
0
 def _tuple_type_proto(elements):
     """Wraps `elements` in a struct-typed `pb.Type` message.

     Args:
       elements: Value passed as the `element` field of `pb.StructType`.

     Returns:
       A `pb.Type` whose `struct` field is that `pb.StructType`.
     """
     struct_proto = pb.StructType(element=elements)
     return pb.Type(struct=struct_proto)