def _deserialize_federated_value(
    value_proto: executor_pb2.Value) -> _DeserializeReturnType:
  """Decodes the `federated` field of `value_proto` into a (value, type) pair.

  Every member constituent is deserialized. For an `all_equal` value exactly
  one constituent is required and it is returned unwrapped; otherwise the
  list of deserialized constituents is returned.

  Args:
    value_proto: The `executor_pb2.Value` holding a federated value.

  Returns:
    A tuple `(value, type_spec)`, where `type_spec` is a
    `computation_types.FederatedType` built from the proto's placement URI,
    its `all_equal` bit, and a member type accumulated over all constituents.

  Raises:
    ValueError: If there are no constituents, or if an `all_equal` value does
      not have exactly one constituent.
  """
  federated = value_proto.federated
  all_equal = federated.type.all_equal
  placement_uri = federated.type.placement.value.uri
  if not federated.value:
    raise ValueError('Attempting to deserialize federated value with no data.')

  members = []
  # Running supertype of every member type deserialized so far.
  member_type = None
  for member_proto in federated.value:
    member_value, member_value_type = deserialize_value(member_proto)
    member_type = _ensure_deserialized_types_compatible(
        member_type, member_value_type)
    members.append(member_value)

  if all_equal:
    if len(members) != 1:
      raise ValueError(
          'Encountered an all_equal value with {} member constituents. '
          'Expected exactly 1.'.format(len(members)))
    result = members[0]
  else:
    result = members

  type_spec = computation_types.FederatedType(
      member_type,
      placement=placements.uri_to_placement_literal(placement_uri),
      all_equal=all_equal)
  return result, type_spec
def deserialize_cardinalities(
    serialized_cardinalities: Collection[serialization_bindings.Cardinality]
) -> CardinalitiesType:
  """Builds a placement-literal -> cardinality mapping from serialized specs.

  Args:
    serialized_cardinalities: Cardinality specs, each exposing
      `placement.uri` and `cardinality`.

  Returns:
    A dict keyed by the placement literal resolved from each spec's URI.
    If several specs resolve to the same placement, the last one wins.
  """
  return {
      placements.uri_to_placement_literal(spec.placement.uri): spec.cardinality
      for spec in serialized_cardinalities
  }
def deserialize_cardinalities(
    serialized_cardinalities: Collection[
        executor_pb2.SetCardinalitiesRequest.Cardinality]
) -> CardinalitiesType:
  """Builds a placement-literal -> cardinality mapping from request specs.

  Args:
    serialized_cardinalities: `SetCardinalitiesRequest.Cardinality` messages,
      each exposing `placement.uri` and `cardinality`.

  Returns:
    A dict keyed by the placement literal resolved from each spec's URI.
    If several specs resolve to the same placement, the last one wins.
  """
  return {
      placements.uri_to_placement_literal(spec.placement.uri): spec.cardinality
      for spec in serialized_cardinalities
  }
def _deserialize_federated_value(
    value_proto: executor_pb2.Value,
    type_hint: Optional[computation_types.Type] = None
) -> _DeserializeReturnType:
  """Decodes the `federated` field of `value_proto` into a (value, type) pair.

  The `all_equal` bit comes from `type_hint` when one is supplied, because the
  C++ runtime does not report it in returned values; otherwise it is read
  from the proto. For an `all_equal` value only the first member constituent
  is deserialized, and it is returned unwrapped. Otherwise a list of all
  deserialized constituents is returned.

  Args:
    value_proto: The `executor_pb2.Value` holding a federated value.
    type_hint: Optional type for the value. Its `all_equal` and `member`
      attributes are read, so it presumably must be a
      `computation_types.FederatedType` (this is not checked here).

  Returns:
    A tuple `(value, type_spec)`, where `type_spec` is a
    `computation_types.FederatedType` built from the proto's placement URI,
    the resolved `all_equal` bit, and a member type accumulated over the
    deserialized constituents.

  Raises:
    ValueError: If the proto carries no member constituents.
  """
  federated = value_proto.federated
  if not federated.value:
    raise ValueError(
        'Attempting to deserialize federated value with no data.')

  # Prefer the hint's `all_equal`: the C++ runtime does not report it.
  all_equal = (
      federated.type.all_equal if type_hint is None else type_hint.all_equal)
  placement_uri = federated.type.placement.value.uri
  member_hint = None if type_hint is None else type_hint.member

  # Optimization: an all_equal value only needs its first constituent.
  constituents = [federated.value[0]] if all_equal else federated.value

  members = []
  member_type = None  # Running supertype of all member types seen so far.
  for constituent in constituents:
    member_value, constituent_type = deserialize_value(constituent, member_hint)
    member_type = _ensure_deserialized_types_compatible(
        member_type, constituent_type)
    members.append(member_value)

  type_spec = computation_types.FederatedType(
      member_type,
      placement=placements.uri_to_placement_literal(placement_uri),
      all_equal=all_equal)
  return (members[0] if all_equal else members), type_spec
def deserialize_type(
    type_proto: Optional[pb.Type]) -> Optional[computation_types.Type]:
  """Converts a `pb.Type` message into a `computation_types.Type`.

  Handled variants:
    - tensor
    - sequence
    - struct
    - function
    - placement
    - federated, but only when its placement is given as a concrete `value`.

  Args:
    type_proto: A `pb.Type` message, or any object exposing the same interface
      (for example the pybind-wrapped C++ `Type` message), or `None`.

  Returns:
    The deserialized `computation_types.Type`. Returns `None` when
    `type_proto` is `None` or has no `type` variant set.

  Raises:
    NotImplementedError: For a federated type whose placement is not a
      concrete `value`, or for an unrecognized type variant.
  """
  if type_proto is None:
    return None
  variant = type_proto.WhichOneof('type')
  if variant is None:
    return None

  if variant == 'tensor':
    tensor = type_proto.tensor
    return computation_types.TensorType(
        dtype=tf.dtypes.as_dtype(tensor.dtype),
        shape=_to_tensor_shape(tensor))

  if variant == 'sequence':
    element_type = deserialize_type(type_proto.sequence.element)
    return computation_types.SequenceType(element_type)

  if variant == 'struct':
    named_elements = []
    for element in type_proto.struct.element:
      # An empty-string name denotes an unnamed element.
      if element.name == '':  # pylint: disable=g-explicit-bool-comparison
        element_name = None
      else:
        element_name = element.name
      named_elements.append((element_name, deserialize_type(element.value)))
    return computation_types.StructType(named_elements, convert=False)

  if variant == 'function':
    function = type_proto.function
    parameter_type = deserialize_type(function.parameter)
    result_type = deserialize_type(function.result)
    return computation_types.FunctionType(
        parameter=parameter_type, result=result_type)

  if variant == 'placement':
    return computation_types.PlacementType()

  if variant == 'federated':
    federated = type_proto.federated
    placement_kind = federated.placement.WhichOneof('placement')
    if placement_kind != 'value':
      raise NotImplementedError(
          'Deserialization of federated types with placement spec as {} '
          'is not currently implemented yet.'.format(placement_kind))
    member_type = deserialize_type(federated.member)
    placement = placements.uri_to_placement_literal(
        federated.placement.value.uri)
    return computation_types.FederatedType(
        member=member_type, placement=placement, all_equal=federated.all_equal)

  raise NotImplementedError('Unknown type variant {}.'.format(variant))
def test_something(self):
  # CLIENTS and SERVER must render to distinct strings.
  self.assertNotEqual(str(placements.CLIENTS), str(placements.SERVER))
  # Each literal's URI must resolve back to that very same literal object.
  for placement_literal in (placements.CLIENTS, placements.SERVER):
    resolved = placements.uri_to_placement_literal(placement_literal.uri)
    self.assertIs(resolved, placement_literal)