示例#1
0
 async def create_selection(self, source, index=None, name=None):
     """Selects one element of a cached tuple value, memoizing the result.

     The selection is keyed by an identifier derived from the source's
     identifier plus the index (`src[i]`) or the name (`src.name`). A cache hit
     reuses the stored `CachedValue`; a miss schedules the selection on the
     target executor and stores the new entry before awaiting it.

     Args:
       source: A `CachedValue` whose type signature is a
         `computation_types.NamedTupleType`.
       index: Position of the element to select. When not `None`, `name` is
         checked to be `None`.
       name: Name of the element to select. Checked to be not `None` when
         `index` is `None`.

     Returns:
       The `CachedValue` held in the cache for this selection.

     Raises:
       Exception: Any error raised while awaiting the target executor's
         selection is re-raised, after the failed entry has been evicted from
         the cache.
     """
     py_typecheck.check_type(source, CachedValue)
     py_typecheck.check_type(source.type_signature,
                             computation_types.NamedTupleType)
     source_val = await source.target_future
     if index is not None:
         py_typecheck.check_none(name)
         identifier_str = '{}[{}]'.format(source.identifier, index)
         type_spec = source.type_signature[index]
     else:
         py_typecheck.check_not_none(name)
         identifier_str = '{}.{}'.format(source.identifier, name)
         type_spec = getattr(source.type_signature, name)
     identifier = CachedValueIdentifier(identifier_str)
     try:
         cached_value = self._cache[identifier]
     except KeyError:
         target_future = asyncio.ensure_future(
             self._target_executor.create_selection(source_val,
                                                    index=index,
                                                    name=name))
         cached_value = CachedValue(identifier, None, type_spec,
                                    target_future)
         self._cache[identifier] = cached_value
     try:
         target_value = await cached_value.target_future
     except Exception:
         # A future that finished with an exception re-raises it on every
         # await. Leaving it cached would make every later selection with this
         # identifier fail without ever retrying, so drop the entry. The
         # identity check avoids evicting a fresh entry that another caller may
         # have stored under the same identifier in the meantime.
         if self._cache.get(identifier) is cached_value:
             del self._cache[identifier]
         raise
     type_utils.check_assignable_from(type_spec,
                                      target_value.type_signature)
     return cached_value
示例#2
0
 async def create_selection(self, source, index=None, name=None):
   """Returns the cached selection of one element from a cached struct value.

   The cache key is built from the source identifier and either the index
   (`src[i]`) or the name (`src.name`). On a cache miss the selection is
   scheduled on the target executor and stored before it is awaited.

   Args:
     source: A `CachedValue` whose type signature is a
       `computation_types.StructType`.
     index: Position of the element to select. When not `None`, `name` is
       checked to be `None`.
     name: Name of the element to select. Checked to be not `None` when `index`
       is `None`.

   Returns:
     The `CachedValue` stored in the cache for this selection.

   Raises:
     Exception: Whatever awaiting the selection raises is propagated, after the
       whole cache has been reset.
   """
   py_typecheck.check_type(source, CachedValue)
   py_typecheck.check_type(source.type_signature, computation_types.StructType)
   source_val = await source.target_future

   if index is None:
     # Selection by name.
     py_typecheck.check_not_none(name)
     key_text = '{}.{}'.format(source.identifier, name)
     selected_type = getattr(source.type_signature, name)
   else:
     # Selection by position; a name must not be given as well.
     py_typecheck.check_none(name)
     key_text = '{}[{}]'.format(source.identifier, index)
     selected_type = source.type_signature[index]

   identifier = CachedValueIdentifier(key_text)
   try:
     entry = self._cache[identifier]
   except KeyError:
     selection_coro = self._target_executor.create_selection(
         source_val, index=index, name=name)
     pending = asyncio.ensure_future(selection_coro)
     entry = CachedValue(identifier, None, selected_type, pending)
     self._cache[identifier] = entry

   try:
     resolved = await entry.target_future
   except Exception:
     # TODO(b/145514490): Dropping the entire cache is heavy handed; in some
     # cases only this entry would need to be invalidated. At present this
     # only happens when an inner RemoteExecutor loses its backend.
     self._cache = {}
     raise

   selected_type.check_assignable_from(resolved.type_signature)
   return entry
示例#3
0
  def __init__(self,
               comp: Optional[computation_base.Computation],
               arg: Any,
               arg_type: computation_types.Type,
               cardinality: Optional[int] = None):
    """Builds the data descriptor and records its client cardinality, if any.

    Args:
      comp: Computation of type `(T -> U)` that materializes the data, where
        `T` is the type of `arg` and `U` is the type of the produced data. May
        be `None`, in which case it is treated as the identity (so `T` must be
        identical to `U`).
      arg: Input handed to `comp` when `comp` is not `None`; otherwise the
        value treated as the computed result. Must be a payload of type `T`
        that the TFF runtime recognizes.
      arg_type: A `tff.Type` instance describing `arg` (the `T` above).
      cardinality: Number of clients this descriptor represents. Checked to be
        not `None` when the resulting type is federated and placed at
        `placements.CLIENTS`; checked to be `None` for any other federated
        placement; not inspected for non-federated types.

    Raises:
      ValueError: If the arguments violate the constraints above (presumably
        raised by the base-class constructor, which is not visible here).
    """
    super().__init__(comp, arg, arg_type)
    self._cardinality = {}

    signature = self._type_signature
    if not signature.is_federated():
      # Nothing to record for non-federated data.
      return

    if signature.placement is not placements.CLIENTS:
      # Only client-placed data carries a cardinality.
      py_typecheck.check_none(cardinality)
      return

    py_typecheck.check_not_none(cardinality)
    self._cardinality[placements.CLIENTS] = cardinality
def infer_cardinalities(value, type_spec):
  """Infers cardinalities from Python `value`.

  Allows for any sized Python object to represent a federated value; enforcing
  particular representations is not the job of this inference function, but
  rather ingestion functions lower in the stack.

  Args:
    value: Python object from which to infer TFF placement cardinalities.
    type_spec: The TFF type spec for `value`, determining the semantics for
      inferring cardinalities. That is, we only pull the cardinality off of
      federated types.

  Returns:
    Dict mapping placements to cardinalities. Empty for all-equal federated
    types and for types that are neither federated nor tuples. For tuple types
    it is the merge of the cardinalities inferred from each element.

  Raises:
    ValueError: If conflicting cardinalities are inferred from `value`.
    TypeError: If the arguments are of the wrong types, or if `type_spec` is
      a federated type which is not `all_equal` but the yet-to-be-embedded
      `value` is not `Sized` (i.e. does not support `len()`).
  """
  # `collections.Sized` was a deprecated alias that Python 3.10 removed; the
  # ABC lives in `collections.abc`. Imported locally so this function does not
  # depend on the submodule having been loaded elsewhere.
  from collections import abc as collections_abc

  py_typecheck.check_not_none(value)
  py_typecheck.check_type(type_spec, computation_types.Type)
  if type_spec.is_federated():
    if type_spec.all_equal:
      # An all-equal value is a single item and says nothing about how many
      # participants there are.
      return {}
    py_typecheck.check_type(value, collections_abc.Sized)
    return {type_spec.placement: len(value)}
  elif type_spec.is_tuple():
    # Only the top level is converted; nested containers are converted by the
    # recursive calls below.
    anonymous_tuple_value = anonymous_tuple.from_container(
        value, recursive=False)
    cardinality_dict = {}
    for idx, (_,
              elem_type) in enumerate(anonymous_tuple.to_elements(type_spec)):
      cardinality_dict = merge_cardinalities(
          cardinality_dict,
          infer_cardinalities(anonymous_tuple_value[idx], elem_type))
    return cardinality_dict
  else:
    return {}
示例#5
0
def infer_cardinalities(value, type_spec):
    """Derives placement cardinalities from a Python `value`.

    Encodes the TFF convention that a federated value whose type is not
    declared all-equal is represented, before ingestion, as a Python list.

    Args:
      value: Python object from which to infer TFF placement cardinalities.
      type_spec: The TFF type spec for `value`. Cardinalities are only read
        off federated types; tuple types are traversed element by element.

    Returns:
      Dict of cardinalities keyed by placement. Empty for all-equal federated
      types and for types that are neither federated nor named tuples.

    Raises:
      ValueError: If conflicting cardinalities are inferred from `value`.
      TypeError: If the arguments are of the wrong types, or if `type_spec` is
        a federated type which is not `all_equal` but `value` is not a Python
        `list`.
    """
    py_typecheck.check_not_none(value)
    py_typecheck.check_type(type_spec, computation_types.Type)

    if isinstance(type_spec, computation_types.FederatedType):
        if not type_spec.all_equal:
            # One list item per participant at this placement.
            py_typecheck.check_type(value, list)
            return {type_spec.placement: len(value)}
        return {}

    if not isinstance(type_spec, computation_types.NamedTupleType):
        return {}

    # Only the top level is converted here; nested containers are handled by
    # the recursive calls.
    members = anonymous_tuple.from_container(value, recursive=False)
    merged = {}
    type_elements = anonymous_tuple.to_elements(type_spec)
    for position, element in enumerate(type_elements):
        _, member_type = element
        merged = merge_cardinalities(
            merged, infer_cardinalities(members[position], member_type))
    return merged
示例#6
0
def check_valid_federated_weighted_mean_argument_tuple_type(type_spec):
    """Validates that `type_spec` can be the argument of a federated weighted mean.

    The type must be a 2-element named tuple type. Each element is passed to
    `check_federated_type` with `placement_literals.CLIENTS` and must have an
    average-compatible member, and the member of the second element (the
    weight) must be a scalar tensor type.

    Args:
      type_spec: An instance of `tff.Type` or something convertible to it.

    Raises:
      TypeError: If the check fails.
    """
    type_spec = computation_types.to_type(type_spec)
    py_typecheck.check_not_none(type_spec)
    py_typecheck.check_type(type_spec, computation_types.NamedTupleType)

    arity = len(type_spec)
    if arity != 2:
        raise TypeError('Expected a 2-tuple, found {}.'.format(type_spec))

    for _, operand_type in anonymous_tuple.iter_elements(type_spec):
        check_federated_type(operand_type, None, placement_literals.CLIENTS,
                             False)
        if is_average_compatible(operand_type.member):
            continue
        raise TypeError(
            'Expected average-compatible args, got {} from argument of type {}.'
            .format(operand_type.member, type_spec))

    # The second element is the weight; it has to be a rank-0 tensor.
    weight_type = type_spec[1].member
    py_typecheck.check_type(weight_type, computation_types.TensorType)
    if weight_type.shape.ndims != 0:
        raise TypeError('Expected scalar weight, got {}.'.format(weight_type))
示例#7
0
 def test_check_not_none(self):
   """Checks that `check_not_none` accepts values and rejects `None`."""
   # A non-None argument must pass without raising.
   py_typecheck.check_not_none(10)
   # `None` is rejected with a TypeError.
   self.assertRaises(TypeError, py_typecheck.check_not_none, None)
   # The optional second argument must show up in the error message.
   self.assertRaisesRegex(TypeError, 'foo', py_typecheck.check_not_none,
                          None, 'foo')