Example #1
0
def _deserialize_federated_value(
    value_proto: executor_pb2.Value) -> _DeserializeReturnType:
  """Converts a federated `executor_pb2.Value` into a Python value and type.

  Every member constituent carried by the proto is deserialized in order, and
  the member types are merged into a single compatible member type.

  Args:
    value_proto: A `Value` proto whose `federated` field is populated.

  Returns:
    A `(value, type_spec)` pair. `value` is the list of deserialized members,
    or the sole member itself when the proto is marked `all_equal`.
    `type_spec` is a `computation_types.FederatedType` built from the merged
    member type, the proto's placement URI and its `all_equal` bit.

  Raises:
    ValueError: If the proto carries no member values, or if it is marked
      `all_equal` but does not carry exactly one member value.
  """
  federated = value_proto.federated
  is_all_equal = federated.type.all_equal
  uri = federated.type.placement.value.uri
  if not federated.value:
    raise ValueError('Attempting to deserialize federated value with no data.')
  members = []
  # Running supertype of every member type seen so far (None until the first
  # member has been deserialized).
  merged_type = None
  for member_proto in federated.value:
    member_value, member_type = deserialize_value(member_proto)
    merged_type = _ensure_deserialized_types_compatible(
        merged_type, member_type)
    members.append(member_value)
  if is_all_equal and len(members) != 1:
    raise ValueError(
        'Encountered an all_equal value with {} member constituents. '
        'Expected exactly 1.'.format(len(members)))
  result = members[0] if is_all_equal else members
  placement = placements.uri_to_placement_literal(uri)
  federated_type = computation_types.FederatedType(
      merged_type, placement=placement, all_equal=is_all_equal)
  return result, federated_type
Example #2
0
def deserialize_cardinalities(
    serialized_cardinalities: Collection[serialization_bindings.Cardinality]
) -> CardinalitiesType:
    """Builds a placement-literal -> cardinality mapping from serialized specs.

    Args:
      serialized_cardinalities: Cardinality entries, each carrying a
        `placement.uri` and a `cardinality` value.

    Returns:
      A dict keyed by the placement literal resolved from each entry's URI.
      When several entries resolve to the same literal, the last one wins.
    """
    return {
        placements.uri_to_placement_literal(spec.placement.uri):
            spec.cardinality
        for spec in serialized_cardinalities
    }
Example #3
0
def deserialize_cardinalities(
    serialized_cardinalities: Collection[
        executor_pb2.SetCardinalitiesRequest.Cardinality]
) -> CardinalitiesType:
  """Maps each placement literal to the cardinality serialized for it.

  Args:
    serialized_cardinalities: `SetCardinalitiesRequest.Cardinality` entries,
      each naming a placement by URI together with its cardinality.

  Returns:
    A dict from placement literal to cardinality. If two entries resolve to
    the same placement, the later entry overwrites the earlier one.
  """
  return {
      placements.uri_to_placement_literal(entry.placement.uri):
          entry.cardinality
      for entry in serialized_cardinalities
  }
def _deserialize_federated_value(
    value_proto: executor_pb2.Value,
    type_hint: Optional[computation_types.Type] = None
) -> _DeserializeReturnType:
    """Deserializes a `Value` proto holding a federated value.

    Args:
      value_proto: The proto to deserialize; its `federated.value` must be
        non-empty.
      type_hint: Optional federated type for the value. When given, its
        `all_equal` bit is used instead of the proto's, and its `member` type
        is forwarded as the hint for every member value.

    Returns:
      A `(value, type_spec)` tuple. For `all_equal` values, `value` is the
      single deserialized member; otherwise it is a list with one entry per
      member proto. `type_spec` is a `computation_types.FederatedType`.

    Raises:
      ValueError: If the proto contains no member values.
    """
    federated = value_proto.federated
    if not federated.value:
        raise ValueError(
            'Attempting to deserialize federated value with no data.')
    # The C++ runtime does not use `all_equal` and does not report it in the
    # values it returns, so the bit carried by the caller's type hint (when
    # provided) takes precedence over the one in the proto.
    if type_hint is None:
        is_all_equal = federated.type.all_equal
        member_hint = None
    else:
        is_all_equal = type_hint.all_equal
        member_hint = type_hint.member
    uri = federated.type.placement.value.uri
    # Optimization: for an `all_equal` value only the first member proto is
    # deserialized; any further members are presumed identical and skipped.
    member_protos = (
        [federated.value[0]] if is_all_equal else federated.value)
    members = []
    merged_type = None  # Supertype of all member types seen so far.
    for member_proto in member_protos:
        member_value, member_type = deserialize_value(member_proto, member_hint)
        merged_type = _ensure_deserialized_types_compatible(
            merged_type, member_type)
        members.append(member_value)
    federated_type = computation_types.FederatedType(
        merged_type,
        placement=placements.uri_to_placement_literal(uri),
        all_equal=is_all_equal)
    return (members[0] if is_all_equal else members), federated_type
def deserialize_type(
        type_proto: Optional[pb.Type]) -> Optional[computation_types.Type]:
    """Converts a serialized type proto into a `computation_types.Type`.

    The `type` oneof variants handled are `tensor`, `sequence`, `struct`,
    `function`, `placement` and `federated` (the latter only when its placement
    is given by `value` URI). Nested types are deserialized recursively.

    Args:
      type_proto: An object exposing the same interface as `pb.Type` (for
        example, a pybind-wrapped C++ `Type` protocol buffer message), or
        `None`.

    Returns:
      The matching `computation_types.Type`, or `None` when `type_proto` is
      `None` or has no `type` variant set.

    Raises:
      NotImplementedError: For an unrecognized type variant, or for a
        federated type whose placement is not specified by value.
    """
    if type_proto is None:
        return None
    variant = type_proto.WhichOneof('type')
    if variant is None:
        return None

    if variant == 'tensor':
        tensor = type_proto.tensor
        dtype = tf.dtypes.as_dtype(tensor.dtype)
        shape = _to_tensor_shape(tensor)
        return computation_types.TensorType(dtype=dtype, shape=shape)

    if variant == 'sequence':
        element_type = deserialize_type(type_proto.sequence.element)
        return computation_types.SequenceType(element_type)

    if variant == 'struct':
        named_elements = []
        for element in type_proto.struct.element:
            # An empty element name is reported as None (an unnamed element).
            name = element.name
            if name == '':  # pylint: disable=g-explicit-bool-comparison
                name = None
            named_elements.append((name, deserialize_type(element.value)))
        return computation_types.StructType(named_elements, convert=False)

    if variant == 'function':
        function = type_proto.function
        parameter = deserialize_type(function.parameter)
        result = deserialize_type(function.result)
        return computation_types.FunctionType(
            parameter=parameter, result=result)

    if variant == 'placement':
        return computation_types.PlacementType()

    if variant == 'federated':
        federated = type_proto.federated
        placement_kind = federated.placement.WhichOneof('placement')
        if placement_kind != 'value':
            raise NotImplementedError(
                'Deserialization of federated types with placement spec as {} '
                'is not currently implemented yet.'.format(placement_kind))
        member = deserialize_type(federated.member)
        placement = placements.uri_to_placement_literal(
            federated.placement.value.uri)
        return computation_types.FederatedType(
            member=member,
            placement=placement,
            all_equal=federated.all_equal)

    raise NotImplementedError('Unknown type variant {}.'.format(variant))
Example #6
0
 def test_something(self):
   """Checks CLIENTS/SERVER are distinct and round-trip through their URIs."""
   clients, server = placements.CLIENTS, placements.SERVER
   self.assertNotEqual(str(clients), str(server))
   for expected in (clients, server):
     resolved = placements.uri_to_placement_literal(expected.uri)
     self.assertIs(resolved, expected)