async def create_selection(self, source, index=None, name=None):
  """Creates (or reuses) a cached selection of one element of a struct value.

  Exactly one of `index` or `name` must be given. The selection is keyed by a
  string identifier derived from `source.identifier` (`'<id>[<index>]'` or
  `'<id>.<name>'`), so repeated selections of the same element reuse the same
  `CachedValue` instead of calling the target executor again.

  Args:
    source: A `CachedValue` whose `type_signature` is a
      `computation_types.NamedTupleType`.
    index: Optional position of the element to select. Must be `None` when
      `name` is given.
    name: Optional name of the element to select. Must be `None` when `index`
      is given.

  Returns:
    The `CachedValue` that represents the selected element.

  Raises:
    TypeError: If `source` or its type signature has the wrong type, or if
      `index`/`name` are not supplied as described above (raised by the
      `py_typecheck` checks).
    Exception: Any exception produced by the target executor while computing
      the selection is re-raised after the cache has been cleared.
  """
  py_typecheck.check_type(source, CachedValue)
  py_typecheck.check_type(source.type_signature,
                          computation_types.NamedTupleType)
  # Awaited unconditionally, so a failure of `source` propagates even when the
  # selection itself is already cached.
  source_val = await source.target_future
  if index is not None:
    py_typecheck.check_none(name)
    identifier_str = '{}[{}]'.format(source.identifier, index)
    type_spec = source.type_signature[index]
  else:
    py_typecheck.check_not_none(name)
    identifier_str = '{}.{}'.format(source.identifier, name)
    type_spec = getattr(source.type_signature, name)
  identifier = CachedValueIdentifier(identifier_str)
  try:
    cached_value = self._cache[identifier]
  except KeyError:
    target_future = asyncio.ensure_future(
        self._target_executor.create_selection(
            source_val, index=index, name=name))
    cached_value = CachedValue(identifier, None, type_spec, target_future)
    self._cache[identifier] = cached_value
  try:
    target_value = await cached_value.target_future
  except Exception:
    # Without this, the failed future would stay cached and every later
    # selection of the same element would re-raise the stale error instead of
    # retrying. Clearing the whole cache is conservative; in some cases only
    # this entry may actually need to be evicted.
    self._cache = {}
    raise
  type_utils.check_assignable_from(type_spec, target_value.type_signature)
  return cached_value
async def create_selection(self, source, index=None, name=None):
  """Returns a `CachedValue` for one element selected out of `source`.

  The element is chosen either by position (`index`) or by name (`name`).
  Exactly one of the two may be given. Results are memoized in `self._cache`
  under an identifier built from `source.identifier`.

  Args:
    source: A `CachedValue` whose `type_signature` is a
      `computation_types.StructType`.
    index: Position of the element to select, or `None` to select by `name`.
    name: Name of the element to select. Must be `None` if `index` is given.

  Returns:
    The (possibly previously cached) `CachedValue` for the selected element.

  Raises:
    TypeError: If the `py_typecheck` checks on the arguments fail.
    Exception: Whatever the target executor raises. In that case the whole
      cache is discarded before the exception is re-raised.
  """
  py_typecheck.check_type(source, CachedValue)
  py_typecheck.check_type(source.type_signature,
                          computation_types.StructType)
  # Resolve the source first; its failure propagates even on a cache hit.
  resolved_source = await source.target_future

  if index is None:
    py_typecheck.check_not_none(name)
    key_text = '{}.{}'.format(source.identifier, name)
    selected_type = getattr(source.type_signature, name)
  else:
    py_typecheck.check_none(name)
    key_text = '{}[{}]'.format(source.identifier, index)
    selected_type = source.type_signature[index]
  key = CachedValueIdentifier(key_text)

  try:
    entry = self._cache[key]
  except KeyError:
    entry = CachedValue(
        key, None, selected_type,
        asyncio.ensure_future(
            self._target_executor.create_selection(
                resolved_source, index=index, name=name)))
    self._cache[key] = entry

  try:
    resolved_entry = await entry.target_future
  except Exception:
    # TODO(b/145514490): This is a bit heavy handed, there maybe caches where
    # only the current cache item needs to be invalidated; however this
    # currently only occurs when an inner RemoteExecutor has the backend go
    # down.
    self._cache = {}
    raise

  selected_type.check_assignable_from(resolved_entry.type_signature)
  return entry
def __init__(self,
             comp: Optional[computation_base.Computation],
             arg: Any,
             arg_type: computation_types.Type,
             cardinality: Optional[int] = None):
  """Constructs this data descriptor from the given computation and argument.

  Args:
    comp: The computation that materializes the data, of some type `(T -> U)`
      where `T` is the type of the argument `arg` and `U` is the type of the
      materialized data that's being produced. This can be `None`, in which
      case it's assumed to be an identity function (and `T` in that case must
      be identical to `U`).
    arg: The argument to be passed as input to `comp` if `comp` is not `None`,
      or to be treated as the computed result. Must be recognized by the TFF
      runtime as a payload of type `T`.
    arg_type: The type of the argument (`T` references above). An instance of
      `tff.Type`.
    cardinality: The number of clients represented by this DataDescriptor.
      Required (must not be `None`) when the type signature is federated and
      placed at clients. Must be `None` when the type signature is federated
      with any other placement. It is neither checked nor recorded for
      non-federated type signatures.

  Raises:
    ValueError: if the arguments don't satisfy the constraints listed above.
    TypeError: if `cardinality` is `None` for a federated type placed at
      clients (raised by `py_typecheck.check_not_none`), or is not `None` for
      a federated type with another placement (raised by
      `py_typecheck.check_none`, presumably also a `TypeError`).
  """
  super().__init__(comp, arg, arg_type)
  # Maps placement -> cardinality. It is only ever populated for values placed
  # at CLIENTS.
  self._cardinality = {}
  if self._type_signature.is_federated():
    if self._type_signature.placement is placements.CLIENTS:
      py_typecheck.check_not_none(cardinality)
      self._cardinality[placements.CLIENTS] = cardinality
    else:
      py_typecheck.check_none(cardinality)
def infer_cardinalities(value, type_spec):
  """Infers cardinalities from Python `value`.

  Allows for any Python object to represent a federated value; enforcing
  particular representations is not the job of this inference function, but
  rather ingestion functions lower in the stack.

  Args:
    value: Python object from which to infer TFF placement cardinalities.
    type_spec: The TFF type spec for `value`, determining the semantics for
      inferring cardinalities. That is, we only pull the cardinality off of
      federated types.

  Returns:
    Dict mapping placements to cardinalities. It is empty when `type_spec`
    contains no federated type that is not `all_equal`.

  Raises:
    ValueError: If conflicting cardinalities are inferred from `value`.
    TypeError: If the arguments are of the wrong types, or if `type_spec` is a
      federated type which is not `all_equal` but the yet-to-be-embedded
      `value` does not support `len()` (is not a `collections.abc.Sized`).
  """
  # `collections.Sized` was a deprecated alias removed in Python 3.10; the ABC
  # has always lived in `collections.abc`.
  from collections.abc import Sized

  py_typecheck.check_not_none(value)
  py_typecheck.check_type(type_spec, computation_types.Type)
  if type_spec.is_federated():
    if type_spec.all_equal:
      return {}
    py_typecheck.check_type(value, Sized)
    # The number of entries in the value is taken as the cardinality of its
    # placement.
    return {type_spec.placement: len(value)}
  elif type_spec.is_tuple():
    anonymous_tuple_value = anonymous_tuple.from_container(
        value, recursive=False)
    cardinality_dict = {}
    for idx, (_, elem_type) in enumerate(
        anonymous_tuple.to_elements(type_spec)):
      cardinality_dict = merge_cardinalities(
          cardinality_dict,
          infer_cardinalities(anonymous_tuple_value[idx], elem_type))
    return cardinality_dict
  else:
    return {}
def infer_cardinalities(value, type_spec):
  """Infers placement cardinalities from the Python `value`.

  Encodes the TFF convention that, before ingestion, a federated value whose
  type is not declared all-equal must be supplied as a Python `list`. Its
  length is taken as the cardinality of the placement.

  Args:
    value: The Python object whose placement cardinalities are inferred.
    type_spec: The TFF type of `value`. Only federated types (possibly nested
      inside named tuples) contribute cardinalities.

  Returns:
    A dict mapping placements to inferred cardinalities. It is empty when no
    non-all-equal federated type is present.

  Raises:
    ValueError: If different parts of `value` imply conflicting
      cardinalities.
    TypeError: If the arguments are of the wrong types, or if a non-all-equal
      federated part of `value` is not a Python `list`.
  """
  py_typecheck.check_not_none(value)
  py_typecheck.check_type(type_spec, computation_types.Type)

  if isinstance(type_spec, computation_types.FederatedType):
    if type_spec.all_equal:
      return {}
    py_typecheck.check_type(value, list)
    return {type_spec.placement: len(value)}

  if not isinstance(type_spec, computation_types.NamedTupleType):
    return {}

  # Recurse into each element of the tuple, folding the per-element results
  # together so that conflicting cardinalities are detected.
  struct_value = anonymous_tuple.from_container(value, recursive=False)
  merged = {}
  member_types = anonymous_tuple.to_elements(type_spec)
  for position, (_, member_type) in enumerate(member_types):
    member_cardinalities = infer_cardinalities(struct_value[position],
                                               member_type)
    merged = merge_cardinalities(merged, member_cardinalities)
  return merged
def check_valid_federated_weighted_mean_argument_tuple_type(type_spec):
  """Validates the argument type of a federated weighted mean.

  The argument must be a 2-element named tuple. Each element must pass
  `check_federated_type` for placement `CLIENTS` and have an
  average-compatible member type. The member of the second element (the
  weight) must also be a scalar `TensorType`.

  Args:
    type_spec: An instance of `tff.Type` or something convertible to it.

  Raises:
    TypeError: If any of the checks above fails.
  """
  normalized = computation_types.to_type(type_spec)
  py_typecheck.check_not_none(normalized)
  py_typecheck.check_type(normalized, computation_types.NamedTupleType)
  if len(normalized) != 2:
    raise TypeError('Expected a 2-tuple, found {}.'.format(normalized))

  for _, element_type in anonymous_tuple.iter_elements(normalized):
    check_federated_type(element_type, None, placement_literals.CLIENTS, False)
    if is_average_compatible(element_type.member):
      continue
    raise TypeError(
        'Expected average-compatible args, got {} from argument of type {}.'
        .format(element_type.member, normalized))

  # The second element carries the weights, which must be rank-0 tensors.
  weight_member = normalized[1].member
  py_typecheck.check_type(weight_member, computation_types.TensorType)
  if weight_member.shape.ndims != 0:
    raise TypeError('Expected scalar weight, got {}.'.format(weight_member))
def test_check_not_none(self):
  """`check_not_none` accepts a value and rejects `None` with a TypeError."""
  accepted_value = 10
  rejected_value = None
  custom_message = 'foo'

  # A non-None value passes silently.
  py_typecheck.check_not_none(accepted_value)

  # None is rejected with a TypeError...
  with self.assertRaises(TypeError):
    py_typecheck.check_not_none(rejected_value)

  # ...and a caller-supplied message appears in the error.
  with self.assertRaisesRegex(TypeError, custom_message):
    py_typecheck.check_not_none(rejected_value, custom_message)