def test_deserialize_federated_value_promotes_types(self):
    """Unnamed and named struct members deserialize to the named struct type.

    The same value is serialized once under an unnamed single-int32 struct and
    once under a struct whose element is named 'a'. Both protos are packed into
    a CLIENTS federated value whose type proto has no `member`. The
    deserialized type is expected to be CLIENTS-placed with the named struct
    as its member.
    """
    value = [10]
    unnamed_struct_type = computation_types.StructType(
        [(None, computation_types.to_type(tf.int32))])
    unnamed_member_pb, _ = value_serialization.serialize_value(
        value, unnamed_struct_type)
    named_struct_type = computation_types.StructType(
        [('a', computation_types.to_type(tf.int32))])
    named_member_pb, _ = value_serialization.serialize_value(
        value, named_struct_type)
    clients_placement_pb = type_serialization.serialize_type(
        computation_types.at_clients(tf.int32)).federated.placement

    # `member` is intentionally left unset on the federated type proto.
    member_less_type_pb = computation_pb2.FederatedType(
        placement=clients_placement_pb, all_equal=False)
    value_pb = executor_pb2.Value(
        federated=executor_pb2.Value.Federated(
            type=member_less_type_pb,
            value=[named_member_pb, unnamed_member_pb]))

    _, result_type = value_serialization.deserialize_value(value_pb)

    type_test_utils.assert_types_identical(
        result_type, computation_types.at_clients(named_struct_type))
Exemplo n.º 2
0
  def test_deserialize_federated_value_with_incompatible_member_types_raises(
      self):
    """An int32 member mixed with a float32 member raises `TypeError`."""
    int_member_pb, _ = executor_serialization.serialize_value(
        10, computation_types.to_type(tf.int32))
    float_member_pb, _ = executor_serialization.serialize_value(
        10., computation_types.to_type(tf.float32))
    clients_placement_pb = type_serialization.serialize_type(
        computation_types.at_clients(tf.int32)).federated.placement

    # The federated type proto carries only placement/all_equal; no member.
    federated_value_pb = executor_pb2.Value(
        federated=executor_pb2.Value.Federated(
            type=computation_pb2.FederatedType(
                placement=clients_placement_pb, all_equal=False),
            value=[int_member_pb, float_member_pb]))

    for proto in (int_member_pb, float_member_pb, federated_value_pb):
      self.assertIsInstance(proto, executor_pb2.Value)

    with self.assertRaises(TypeError):
      executor_serialization.deserialize_value(federated_value_pb)
Exemplo n.º 3
0
 def test_serialize_type_with_federated_bool(self):
     """An all-equal CLIENTS-placed bool serializes to the expected proto."""
     federated_bool_type = computation_types.FederatedType(
         tf.bool, placements.CLIENTS, True)
     serialized = type_serialization.serialize_type(federated_bool_type)

     clients_spec = pb.PlacementSpec(
         value=pb.Placement(uri=placements.CLIENTS.uri))
     expected = pb.Type(
         federated=pb.FederatedType(
             placement=clients_spec,
             all_equal=True,
             member=_create_scalar_tensor_type(tf.bool)))

     self.assertEqual(serialized, expected)
Exemplo n.º 4
0
def serialize_type(type_spec):
  """Serializes 'type_spec' as a pb.Type.

  Serialization is implemented for tensor, sequence, named tuple, function,
  placement, and federated types. Federated types are supported only when
  their placement is a `placement_literals.PlacementLiteral`; any other kind
  of placement specification raises `NotImplementedError`.

  Args:
    type_spec: Either an instance of computation_types.Type, or something
      convertible to it by computation_types.to_type(), or None.

  Returns:
    The corresponding instance of `pb.Type`, or `None` if the argument was
      `None`.

  Raises:
    TypeError: if the argument is of the wrong type.
    NotImplementedError: for type variants for which serialization is not
      implemented.
  """
  # TODO(b/113112885): Implement serialization of the remaining types.
  if type_spec is None:
    return None
  target = computation_types.to_type(type_spec)
  py_typecheck.check_type(target, computation_types.Type)
  if isinstance(target, computation_types.TensorType):
    return pb.Type(tensor=_to_tensor_type_proto(target))
  elif isinstance(target, computation_types.SequenceType):
    return pb.Type(
        sequence=pb.SequenceType(element=serialize_type(target.element)))
  elif isinstance(target, computation_types.NamedTupleType):
    # Each element is a (name, type) pair.
    return pb.Type(
        tuple=pb.NamedTupleType(element=[
            pb.NamedTupleType.Element(name=e[0], value=serialize_type(e[1]))
            for e in anonymous_tuple.iter_elements(target)
        ]))
  elif isinstance(target, computation_types.FunctionType):
    return pb.Type(
        function=pb.FunctionType(
            parameter=serialize_type(target.parameter),
            result=serialize_type(target.result)))
  elif isinstance(target, computation_types.PlacementType):
    return pb.Type(placement=pb.PlacementType())
  elif isinstance(target, computation_types.FederatedType):
    if isinstance(target.placement, placement_literals.PlacementLiteral):
      return pb.Type(
          federated=pb.FederatedType(
              member=serialize_type(target.member),
              placement=pb.PlacementSpec(
                  value=pb.Placement(uri=target.placement.uri)),
              all_equal=target.all_equal))
    else:
      raise NotImplementedError(
          'Serialization of federated types with placements specifications '
          'of type {} is not currently implemented yet.'.format(
              type(target.placement)))
  else:
    # Name the unsupported type so the failure is diagnosable.
    raise NotImplementedError(
        'Serialization of type {} is not implemented.'.format(type(target)))
Exemplo n.º 5
0
def serialize_type(
    type_spec: Optional[computation_types.Type]) -> Optional[pb.Type]:
  """Serializes 'type_spec' as a pb.Type.

  Serialization is implemented for tensor, sequence, struct, function,
  placement, and federated types. Results are memoized in the module-level
  `_type_serialization_cache`, keyed by `type_spec`.

  Args:
    type_spec: A `computation_types.Type`, or `None`.

  Returns:
    The corresponding instance of `pb.Type`, or `None` if the argument was
      `None`.

  Raises:
    TypeError: if the argument is of the wrong type.
    NotImplementedError: for type variants for which serialization is not
      implemented.
  """
  if type_spec is None:
    return None
  # Honor the documented contract. Without this check a non-Type argument
  # fails with an AttributeError on `.is_tensor()` below, or with an
  # incidental hashing TypeError in the cache lookup.
  if not isinstance(type_spec, computation_types.Type):
    raise TypeError(
        'Expected `type_spec` to be a `computation_types.Type` or `None`, '
        'found {}.'.format(type(type_spec)))
  # NOTE(review): the cached proto object itself is returned, so every caller
  # receives the same instance; mutating a returned proto in place would
  # affect later cache hits. Callers should treat the result as read-only.
  cached_proto = _type_serialization_cache.get(type_spec, None)
  if cached_proto is not None:
    return cached_proto
  if type_spec.is_tensor():
    proto = pb.Type(tensor=_to_tensor_type_proto(type_spec))
  elif type_spec.is_sequence():
    proto = pb.Type(
        sequence=pb.SequenceType(element=serialize_type(type_spec.element)))
  elif type_spec.is_struct():
    # Each element is a (name, type) pair; the name may be None.
    proto = pb.Type(
        struct=pb.StructType(element=[
            pb.StructType.Element(name=e[0], value=serialize_type(e[1]))
            for e in structure.iter_elements(type_spec)
        ]))
  elif type_spec.is_function():
    proto = pb.Type(
        function=pb.FunctionType(
            parameter=serialize_type(type_spec.parameter),
            result=serialize_type(type_spec.result)))
  elif type_spec.is_placement():
    proto = pb.Type(placement=pb.PlacementType())
  elif type_spec.is_federated():
    proto = pb.Type(
        federated=pb.FederatedType(
            member=serialize_type(type_spec.member),
            placement=pb.PlacementSpec(
                value=pb.Placement(uri=type_spec.placement.uri)),
            all_equal=type_spec.all_equal))
  else:
    # Name the unsupported type so the failure is diagnosable.
    raise NotImplementedError(
        'Serialization of type {} is not implemented.'.format(type(type_spec)))

  _type_serialization_cache[type_spec] = proto
  return proto
 def test_deserialize_federated_all_equal_value_takes_first_element(self):
   """An all-equal CLIENTS type hint collapses the members to a single 10."""
   num_clients = 5
   member_pb, _ = value_serialization.serialize_value(
       10, TensorType(tf.int32))
   clients_spec = computation_pb2.PlacementSpec(
       value=computation_pb2.Placement(uri=placements.CLIENTS.uri))
   # The proto's type leaves `all_equal` unset; all-equal comes from the hint.
   value_pb = executor_pb2.Value(
       federated=executor_pb2.Value.Federated(
           value=[member_pb] * num_clients,
           type=computation_pb2.FederatedType(placement=clients_spec)))
   type_hint = computation_types.FederatedType(
       tf.int32, placements.CLIENTS, all_equal=True)

   result_value, result_type = value_serialization.deserialize_value(
       value_pb, type_hint)

   type_test_utils.assert_types_identical(result_type, type_hint)
   self.assertAllEqual(result_value, 10)
  def test_deserialize_federated_value_with_unset_member_type(self):
    """A federated value whose type proto omits `member` still deserializes."""
    member_pb, _ = value_serialization.serialize_value(
        10, computation_types.to_type(tf.int32))
    clients_federated_pb = type_serialization.serialize_type(
        computation_types.at_clients(tf.int32)).federated

    # Copy placement and all_equal, but deliberately leave `member` unset.
    member_less_type_pb = computation_pb2.FederatedType(
        placement=clients_federated_pb.placement,
        all_equal=clients_federated_pb.all_equal)
    value_pb = executor_pb2.Value(
        federated=executor_pb2.Value.Federated(
            type=member_less_type_pb, value=[member_pb]))

    result_value, result_type = value_serialization.deserialize_value(
        value_pb)

    type_test_utils.assert_types_identical(
        result_type, computation_types.at_clients(tf.int32))
    self.assertEqual(result_value, [10])