Example #1
0
 async def create_tuple(self, elements):
     """Builds a `CompositeValue` tuple out of already-embedded values.

     Args:
       elements: A container accepted by `anonymous_tuple.from_container`,
         whose elements must all be `CompositeValue` instances.

     Returns:
       A `CompositeValue` wrapping an `AnonymousTuple` of the elements' internal
       representations, typed as the `NamedTupleType` built from the elements'
       names and type signatures (unnamed elements contribute a bare type).

     Raises:
       Whatever `py_typecheck.check_type` raises (presumably `TypeError`) if an
       element is not a `CompositeValue`.
     """
     anon = anonymous_tuple.from_container(elements)
     representations = []
     element_types = []
     for key, member in anonymous_tuple.iter_elements(anon):
         py_typecheck.check_type(member, CompositeValue)
         representations.append((key, member.internal_representation))
         member_type = member.type_signature
         element_types.append(
             member_type if key is None else (key, member_type))
     packed_value = anonymous_tuple.AnonymousTuple(representations)
     packed_type = computation_types.NamedTupleType(element_types)
     return CompositeValue(packed_value, packed_type)
Example #2
0
 def _server_state_from_tff_result(self, result):
     """Rebuilds a `tff.learning.framework.ServerState` from a TFF result.

     With per-vector clipping enabled, `result.delta_aggregate_state` is
     iterated and each element is converted with `anonymous_tuple.to_odict`.
     Otherwise the whole state is converted at once. Model weights and the
     broadcast state become tuples, and the optimizer state becomes a list.
     """
     per_vector_clipping = self._per_vector_clipping
     aggregate_state = result.delta_aggregate_state
     if not per_vector_clipping:
         per_vector_aggregate_states = anonymous_tuple.to_odict(
             aggregate_state, recursive=True)
     else:
         per_vector_aggregate_states = [
             anonymous_tuple.to_odict(state, recursive=True)
             for _, state in anonymous_tuple.iter_elements(aggregate_state)
         ]
     model = result.model
     weights = tff.learning.ModelWeights(
         tuple(model.trainable), tuple(model.non_trainable))
     optimizer_state = list(result.optimizer_state)
     broadcast_state = tuple(result.model_broadcast_state)
     return tff.learning.framework.ServerState(
         weights, optimizer_state, per_vector_aggregate_states, broadcast_state)
Example #3
0
 def _traverse_tuple(comp, transform, context_tree, identifier_seq):
   """Post-order traversal step for tuple nodes.

   Consumes one identifier from `identifier_seq` and recursively transforms
   each element of `comp`. The node is rebuilt as a `building_blocks.Tuple`
   only when at least one element was modified. Finally `transform` is applied
   to the (possibly rebuilt) node.

   Returns:
     A pair `(comp, modified)`, where `modified` is truthy if any element or the
     node itself was modified.
   """
   next(identifier_seq)  # Advance the sequence; the identifier is unused here.
   new_elements = []
   children_modified = False
   for name, child in anonymous_tuple.iter_elements(comp):
     child, child_modified = _transform_postorder_with_symbol_bindings_switch(
         child, transform, context_tree, identifier_seq)
     new_elements.append((name, child))
     children_modified = children_modified or child_modified
   if children_modified:
     comp = building_blocks.Tuple(new_elements)
   comp, node_modified = transform(comp, context_tree)
   return comp, node_modified or children_modified
Example #4
0
 def _create_result_tensor(type_spec, scalar_value):
     """Builds a (nested) value of `type_spec` filled with `scalar_value`.

     A tensor type becomes a `tf.constant` with the spec's dtype and shape; the
     shape must be fully defined. Any other type is walked with
     `anonymous_tuple.iter_elements` and becomes a plain Python list of
     recursively built elements. Element names are dropped.
     """
     if not type_spec.is_tensor():
         return [
             _create_result_tensor(element_type, scalar_value)
             for _, element_type in anonymous_tuple.iter_elements(type_spec)
         ]
     type_spec.shape.assert_is_fully_defined()
     return tf.constant(
         scalar_value, dtype=type_spec.dtype, shape=type_spec.shape)
Example #5
0
def is_valid_bitwidth_type_for_value_type(bitwidth_type, value_type):
    """Whether or not `bitwidth_type` is a valid bitwidth type for `value_type`.

    Both arguments are first converted with `computation_types.to_type`. For
    two tensor types, the bitwidth must be an integer tensor with exactly one
    element. For two named tuple types, the element counts and names must match
    pairwise, and each pair must be valid recursively. Any other combination is
    invalid.
    """
    # NOTE: this function is primarily a helper for `intrinsic_factory.py`'s
    # `federated_secure_sum` function.
    bitwidth_type = computation_types.to_type(bitwidth_type)
    value_type = computation_types.to_type(value_type)

    def _both_are(cls):
        return isinstance(value_type, cls) and isinstance(bitwidth_type, cls)

    if _both_are(computation_types.TensorType):
        # Rather than requiring the same type as `value_type`, require a single
        # integer: there is one bitwidth per tensor.
        return bitwidth_type.dtype.is_integer and (
            bitwidth_type.shape.num_elements() == 1)
    if not _both_are(computation_types.NamedTupleType):
        return False

    bitwidth_elements = list(anonymous_tuple.iter_elements(bitwidth_type))
    value_elements = list(anonymous_tuple.iter_elements(value_type))
    if len(bitwidth_elements) != len(value_elements):
        return False
    for (b_name, b_type), (v_name, v_type) in zip(bitwidth_elements,
                                                  value_elements):
        if (b_name != v_name or
                not is_valid_bitwidth_type_for_value_type(b_type, v_type)):
            return False
    return True
Example #6
0
def create_variables(name, type_spec, **kwargs):
  """Creates a set of variables that matches the given `type_spec`.

  Unlike `tf.compat.v1.get_variable`, this method will always create new
  variables, and will not retrieve variables previously created with the same
  name.

  Args:
    name: The common name to use for the scope in which all of the variables are
      to be created. Nested elements are named `'<name>/<element name>'`, or
      `'<name>/<index>'` for unnamed elements; an empty `name` adds no prefix.
    type_spec: An instance of `tff.Type` or something convertible to it. The
      type signature may only be composed of tensor types and named tuples,
      possibly nested.
    **kwargs: Additional keyword args to pass to `tf.Variable` construction.

  Returns:
    Either a single zero-initialized variable when invoked with a tensor TFF
    type, or a nested structure of variables created in the appropriately-named
    variable scopes made up of anonymous tuples if invoked with a named tuple
    TFF type.

  Raises:
    TypeError: if `type_spec` is not a type signature composed of tensor and
      named tuple TFF types.
  """
  py_typecheck.check_type(name, str)
  type_spec = computation_types.to_type(type_spec)
  py_typecheck.check_type(type_spec, computation_types.Type)
  if isinstance(type_spec, computation_types.TensorType):
    return tf.Variable(
        initial_value=tf.zeros(dtype=type_spec.dtype, shape=type_spec.shape),
        name=name,
        **kwargs)
  elif isinstance(type_spec, computation_types.NamedTupleType):

    def _scoped_name(var_name, index):
      # Unnamed tuple elements fall back to their positional index.
      if var_name is None:
        var_name = str(index)
      if name:
        return '{}/{}'.format(name, var_name)
      else:
        return var_name

    type_elements_iter = anonymous_tuple.iter_elements(type_spec)
    return anonymous_tuple.AnonymousTuple(
        (k, create_variables(_scoped_name(k, i), v, **kwargs))
        for i, (k, v) in enumerate(type_elements_iter))
  else:
    raise TypeError(
        'Expected a TFF type signature composed of tensors and named tuples, '
        'found {}.'.format(type_spec))
 def _compute_tuple(self, comp, context):
   """Evaluates every element of a `building_blocks.Tuple`.

   Each element is computed in `context` and checked to be assignable to the
   element's declared type signature. The results are packed into a
   `ComputedValue` whose value is an `AnonymousTuple` and whose type is the
   corresponding `NamedTupleType`; elements with a falsy name become unnamed.
   """
   py_typecheck.check_type(comp, building_blocks.Tuple)
   values = []
   typed_names = []
   for key, element in anonymous_tuple.iter_elements(comp):
     computed = self._compute(element, context)
     type_utils.check_assignable_from(element.type_signature,
                                      computed.type_signature)
     values.append((key, computed.value))
     typed_names.append((key, computed.type_signature))
   result_value = anonymous_tuple.AnonymousTuple(values)
   result_type = computation_types.NamedTupleType(
       [(key, t) if key else t for key, t in typed_names])
   return ComputedValue(result_value, result_type)
 def _remove_selection_from_block_holding_tuple(comp):
   """Pushes a selection into a block whose result is a struct.

   A selection whose source is a block with a struct result is rewritten as a
   new `building_blocks.Block` with the same locals and the selected result
   element. A by-name selection is resolved to a positional index through the
   source's type signature.

   Returns:
     `(new_comp, True)` when the rewrite applies, otherwise `(comp, False)`.
   """
   applies = (comp.is_selection() and comp.source.is_block() and
              comp.source.result.is_struct())
   if not applies:
     return comp, False
   block = comp.source
   index = comp.index
   if index is None:
     element_names = [
         element_name for element_name, _ in anonymous_tuple.iter_elements(
             block.type_signature)
     ]
     index = element_names.index(comp.name)
   return building_blocks.Block(block.locals, block.result[index]), True
def _default_from_tff_result_fn(record):
  """Converts an `AnonymousTuple` to a dict or list where possible.

  Inputs that are not `AnonymousTuple` instances are returned unchanged. If
  `record._asdict()` succeeds, its result is returned. Otherwise, if no field
  is named, a list of the elements is returned.

  Raises:
    ValueError: If `record` is an `AnonymousTuple` with only some of its fields
      named.
  """
  if not isinstance(record, anonymous_tuple.AnonymousTuple):
    return record
  try:
    return record._asdict()
  except ValueError:
    # At least one field is unnamed. A fully unnamed tuple becomes a list; a
    # partially named one is not supported.
    if anonymous_tuple.name_list(record):
      raise ValueError(
          'Cannot construct a default from a TFF result that '
          'has partially named fields. TFF result: {!s}'.format(record))
    return [element for _, element in anonymous_tuple.iter_elements(record)]
    async def _delegate(self, value):
        """Pushes `value`, in its entirety, down into the target executor.

        Args:
          value: A `LambdaExecutorValue` in any supported form except a
            callable. A callable cannot be delegated and should only ever be
            constructed at evaluation time.

        Returns:
          An `executor_value_base.ExecutorValue` that represents the value
          embedded in the target executor.

        Raises:
          RuntimeError: If the internal representation of `value` is a
            callable.
        """
        py_typecheck.check_type(value, LambdaExecutorValue)
        representation = value.internal_representation

        if isinstance(representation, executor_value_base.ExecutorValue):
            # Already lives in the target executor.
            return representation

        if isinstance(representation, anonymous_tuple.AnonymousTuple):
            delegated = await asyncio.gather(
                *[self._delegate(element) for element in representation])
            names = [
                element_name for element_name, _ in
                anonymous_tuple.iter_elements(representation)
            ]
            return await self._target_executor.create_tuple(
                anonymous_tuple.AnonymousTuple(zip(names, delegated)))

        if callable(representation):
            raise RuntimeError(
                'Cannot delegate a callable to a target executor; it appears that '
                'the internal computation structure has been evaluated too deeply '
                '(this is an internal error that represents a bug in the runtime).'
            )

        py_typecheck.check_type(representation, pb.Computation)

        # TODO(b/134543154): This is the place to check for the computation we
        # are about to push down to the target executor making references to
        # something declared outside of its scope, in which case we'll have to
        # do a little bit more work to plumb things through.
        tree = building_blocks.ComputationBuildingBlock.from_proto(
            representation)
        tree_analysis.check_contains_no_unbound_references(tree)
        return await self._target_executor.create_value(
            representation, value.type_signature)
    async def create_selection(self, source, index=None, name=None):
        """A coroutine that creates a selection from `source`.

        Args:
          source: The `FederatingExecutorValue` to select from; its type
            signature must be a `computation_types.NamedTupleType`.
          index: An optional integer index. Either this, or `name` must be
            present.
          name: An optional string name. Either this, or `index` must be
            present. When both are given, `name` takes precedence.

        Returns:
          An instance of `FederatingExecutorValue` that represents the
          constructed selection.

        Raises:
          TypeError: If `source` is not a `FederatingExecutorValue` or its type
            signature is not a `tff.NamedTupleType`.
          ValueError: If both `index` and `name` are `None`, or if the internal
            representation of `source` is not a kind recognized by
            `FederatingExecutor`.
          KeyError: If `name` is not an element name of `source`'s type.
        """
        py_typecheck.check_type(source, FederatingExecutorValue)
        py_typecheck.check_type(source.type_signature,
                                computation_types.NamedTupleType)
        if index is None and name is None:
            raise ValueError(
                'Expected either `index` or `name` to be specified, found both are '
                '`None`.')
        if name is not None:
            # Resolve the name to a position; the selection below is by index.
            name_to_index = {
                element_name: i for i, (element_name, _) in enumerate(
                    anonymous_tuple.iter_elements(source.type_signature))
            }
            index = name_to_index[name]
        if isinstance(source.internal_representation,
                      anonymous_tuple.AnonymousTuple):
            val = source.internal_representation
            selected = val[index]
            return FederatingExecutorValue(selected,
                                           source.type_signature[index])
        elif isinstance(source.internal_representation,
                        executor_value_base.ExecutorValue):
            # Delegate the selection to the unplaced (`None`) child executor.
            val = source.internal_representation
            child = self._target_executors[None][0]
            return FederatingExecutorValue(
                await child.create_selection(val, index=index),
                source.type_signature[index])
        else:
            raise ValueError(
                'Unexpected internal representation while creating selection. '
                'Expected one of `AnonymousTuple` or value embedded in target '
                'executor, received {}'.format(source.internal_representation))
def serialize_type(
        type_spec: Optional[computation_types.Type]) -> Optional[pb.Type]:
    """Serializes `type_spec` as a `pb.Type`.

    Serialization is implemented for tensor, sequence, named tuple, function,
    placement and federated types.

    Args:
      type_spec: A `computation_types.Type`, or `None`.

    Returns:
      The corresponding instance of `pb.Type`, or `None` if the argument was
      `None`.

    Raises:
      TypeError: If the argument is neither `None` nor a
        `computation_types.Type`.
      NotImplementedError: For type variants for which serialization is not
        implemented.
    """
    if type_spec is None:
        return None
    if not isinstance(type_spec, computation_types.Type):
        # Without this check a wrong argument surfaces as an opaque
        # AttributeError from the `is_*()` calls below.
        raise TypeError(
            'Expected a `computation_types.Type` or `None`, found {}.'.format(
                type(type_spec)))
    if type_spec.is_tensor():
        return pb.Type(tensor=_to_tensor_type_proto(type_spec))
    elif type_spec.is_sequence():
        return pb.Type(sequence=pb.SequenceType(
            element=serialize_type(type_spec.element)))
    elif type_spec.is_tuple():
        return pb.Type(tuple=pb.NamedTupleType(element=[
            pb.NamedTupleType.Element(name=e[0], value=serialize_type(e[1]))
            for e in anonymous_tuple.iter_elements(type_spec)
        ]))
    elif type_spec.is_function():
        return pb.Type(function=pb.FunctionType(
            parameter=serialize_type(type_spec.parameter),
            result=serialize_type(type_spec.result)))
    elif type_spec.is_placement():
        return pb.Type(placement=pb.PlacementType())
    elif type_spec.is_federated():
        return pb.Type(
            federated=pb.FederatedType(member=serialize_type(type_spec.member),
                                       placement=pb.PlacementSpec(
                                           value=pb.Placement(
                                               uri=type_spec.placement.uri)),
                                       all_equal=type_spec.all_equal))
    else:
        raise NotImplementedError(
            'Serialization of type {} is not implemented.'.format(type_spec))
Example #13
0
def _get_hashable_key(value, type_spec):
    """Returns a hashable key for value `value` of TFF type `type_spec`.

    Args:
      value: An argument to `create_value()`.
      type_spec: The TFF type signature of `value`. It must not be `None`,
        because its `is_tuple()` / `is_federated()` methods are called
        unconditionally.

    Returns:
      A hashable key to use such that the same `value` always maps to the same
      key, and different ones map to different keys.

    Raises:
      TypeError: If `value` has a tuple type but cannot be converted to an
        `AnonymousTuple`.
    """
    # `collections.Hashable` was a deprecated alias that Python 3.10 removed;
    # the ABC lives in `collections.abc`.
    from collections.abc import Hashable  # pylint: disable=g-import-not-at-top

    if type_spec.is_tuple():
        if not isinstance(value, anonymous_tuple.AnonymousTuple):
            try:
                value = anonymous_tuple.from_container(value)
            except Exception as e:
                raise TypeError(
                    'Failed to convert value with type_spec {} to `AnonymousTuple`'
                    .format(repr(type_spec))) from e
        type_specs = anonymous_tuple.iter_elements(type_spec)
        r_elem = []
        for v, (field_name, field_type) in zip(value, type_specs):
            r_elem.append((field_name, _get_hashable_key(v, field_type)))
        return anonymous_tuple.AnonymousTuple(r_elem)
    elif type_spec.is_federated():
        if type_spec.all_equal:
            return _get_hashable_key(value, type_spec.member)
        else:
            return tuple(_get_hashable_key(x, type_spec.member) for x in value)
    elif isinstance(value, pb.Computation):
        return value.SerializeToString(deterministic=True)
    elif isinstance(value, np.ndarray):
        return ('<dtype={},shape={}>'.format(value.dtype,
                                             value.shape), value.tobytes())
    elif (isinstance(value, Hashable)
          and not isinstance(value, (tf.Tensor, tf.Variable))):
        # TODO(b/139200385): Currently Tensor and Variable returns True for
        #   `isinstance(value, Hashable)` even when it's not hashable.
        #   Hence this workaround.
        return value
    else:
        return HashableWrapper(value)
Example #14
0
def _get_hashable_key(value, type_spec):
    """Returns a hashable key for value `value` of TFF type `type_spec`.

    Args:
      value: An argument to `create_value()`.
      type_spec: The TFF type signature of `value`.

    Returns:
      A hashable key to use such that the same `value` always maps to the same
      key, and different ones map to different keys.
    """
    # `collections.Hashable` was a deprecated alias that Python 3.10 removed;
    # the ABC lives in `collections.abc`.
    from collections.abc import Hashable  # pylint: disable=g-import-not-at-top

    if isinstance(type_spec, computation_types.NamedTupleType):
        if isinstance(value, anonymous_tuple.AnonymousTuple):
            type_specs = anonymous_tuple.iter_elements(type_spec)
            r_elem = []
            for v, (field_name, field_type) in zip(value, type_specs):
                r_elem.append((field_name, _get_hashable_key(v, field_type)))
            return anonymous_tuple.AnonymousTuple(r_elem)
        else:
            return _get_hashable_key(anonymous_tuple.from_container(value),
                                     type_spec)
    elif isinstance(type_spec, computation_types.FederatedType):
        if type_spec.all_equal:
            return _get_hashable_key(value, type_spec.member)
        else:
            return tuple(_get_hashable_key(x, type_spec.member) for x in value)
    elif isinstance(value, pb.Computation):
        # Deterministic binary serialization captures the full proto contents.
        return value.SerializeToString(deterministic=True)
    elif isinstance(value, np.ndarray):
        # Key on the raw bytes. `str()` of an array is summarized with "..."
        # for large arrays and rounds floats, so distinct arrays could collide.
        return ('<dtype={},shape={}>'.format(value.dtype, value.shape),
                value.tobytes())
    elif (isinstance(value, Hashable)
          and not isinstance(value, (tf.Tensor, tf.Variable))):
        # TODO(b/139200385): Currently Tensor and Variable returns True for
        #   `isinstance(value, Hashable)` even when it's not hashable.
        #   Hence this workaround.
        return value
    else:
        return HashableWrapper(value)
Example #15
0
 async def create_tuple(self, elements):
   """Asks the remote service to build a tuple from remote values.

   Every element must be a `RemoteValue`. Falsy element names are sent as
   `None`, and those elements are left unnamed in the result type. The request
   goes over the bidirectional stream when one is open, otherwise through a
   unary `CreateTuple` call.

   Returns:
     A `RemoteValue` for the created tuple, typed as the matching
     `NamedTupleType`.
   """
   anon_tuple = anonymous_tuple.from_container(elements)
   request_elements = []
   element_types = []
   for key, remote in anonymous_tuple.iter_elements(anon_tuple):
     py_typecheck.check_type(remote, RemoteValue)
     request_elements.append(
         executor_pb2.CreateTupleRequest.Element(
             name=key if key else None, value_ref=remote.value_ref))
     if key:
       element_types.append((key, remote.type_signature))
     else:
       element_types.append(remote.type_signature)
   result_type = computation_types.NamedTupleType(element_types)
   request = executor_pb2.CreateTupleRequest(element=request_elements)
   if self._bidi_stream is not None:
     wrapped = executor_pb2.ExecuteRequest(create_tuple=request)
     response = (await self._bidi_stream.send_request(wrapped)).create_tuple
   else:
     response = _request(self._stub.CreateTuple, request)
   py_typecheck.check_type(response, executor_pb2.CreateTupleResponse)
   return RemoteValue(response.value_ref, result_type, self)
 async def compute(self):
     """Computes the embedded value and returns the result.

     Returns:
       * For a value embedded in a target executor: the result of its
         `compute()`.
       * For a federated value: the computed result of the first constituent
         when the type is all-equal, otherwise a list with every constituent's
         result.
       * For an `AnonymousTuple`: an `AnonymousTuple` of the computed elements,
         named after the elements of the type signature.

     Raises:
       RuntimeError: If an all-equal federated value has no constituents (the
         executor inferred zero clients), or if the value's representation is
         not supported by this executor.
     """
     if isinstance(self._value, executor_value_base.ExecutorValue):
         return await self._value.compute()
     elif isinstance(self._type_signature, computation_types.FederatedType):
         py_typecheck.check_type(self._value, list)
         if self._type_signature.all_equal:
             if not self._value:
                 # TODO(b/145936344): this happens when the executor has inferred the
                 # cardinality of clients as 0, which can happen in tff.Computation
                 # that only do a tff.federated_broadcast. This probably should be
                 # handled elsewhere.
                 raise RuntimeError(
                     'Arrived at a computation that inferred there are '
                     '0 clients. Try explicitly passing `num_clients` '
                     'parameter when constructing the executor.')
             # All-equal: only the first constituent is computed.
             vals = [self._value[0]]
         else:
             vals = self._value
         results = []
         for v in vals:
             py_typecheck.check_type(v, executor_value_base.ExecutorValue)
             results.append(v.compute())
         results = await asyncio.gather(*results)
         if self._type_signature.all_equal:
             return results[0]
         else:
             return results
     elif isinstance(self._value, anonymous_tuple.AnonymousTuple):
         gathered_values = await asyncio.gather(*[
             FederatingExecutorValue(v, t).compute()
             for v, t in zip(self._value, self._type_signature)
         ])
         type_elements_iter = anonymous_tuple.iter_elements(
             self._type_signature)
         return anonymous_tuple.AnonymousTuple(
             (k, v)
             for (k, _), v in zip(type_elements_iter, gathered_values))
     else:
         raise RuntimeError(
             'Computing values of type {} represented as {} is not supported in '
             'this executor.'.format(
                 self._type_signature,
                 py_typecheck.type_string(type(self._value))))
    async def compute(self):
        """Computes the embedded value and returns the result.

        An embedded executor value is computed directly. An `AnonymousTuple`
        is computed element-wise and re-packed with the names from the type
        signature. A list must carry a federated type: an all-equal value
        returns its first constituent's result, otherwise every constituent
        must compute to a list and the lists are concatenated in order.

        Raises:
          TypeError: If the embedded value is composed of values that are not
            embedded in the executor.
          RuntimeError: If the embedded value is not a kind supported by the
            `FederatingExecutor`.
        """
        value = self._value
        if isinstance(value, executor_value_base.ExecutorValue):
            return await value.compute()

        if isinstance(value, anonymous_tuple.AnonymousTuple):
            computed = await asyncio.gather(*[
                FederatedComposingStrategyValue(element, element_type).compute()
                for element, element_type in zip(value, self._type_signature)
            ])
            names = (element_name for element_name, _ in
                     anonymous_tuple.iter_elements(self._type_signature))
            return anonymous_tuple.AnonymousTuple(zip(names, computed))

        if isinstance(value, list):
            py_typecheck.check_type(self._type_signature,
                                    computation_types.FederatedType)
            for member in value:
                py_typecheck.check_type(member,
                                        executor_value_base.ExecutorValue)
            if self._type_signature.all_equal:
                return await value[0].compute()
            per_child = await asyncio.gather(
                *[member.compute() for member in value])
            flattened = []
            for child_result in per_child:
                py_typecheck.check_type(child_result, list)
                flattened.extend(child_result)
            return flattened

        raise RuntimeError(
            'Computing values of type {} represented as {} is not supported in '
            'this executor.'.format(
                self._type_signature,
                py_typecheck.type_string(type(self._value))))
def is_structure_of_integers(type_spec: computation_types.Type) -> bool:
    """Reports whether every leaf of `type_spec` is an integer tensor.

    A tensor qualifies iff its dtype is an integer type. A tuple qualifies when
    all of its elements do (so an empty tuple qualifies), and a federated type
    when its member does. Every other kind of type yields `False`.

    Args:
      type_spec: A `computation_types.Type`.

    Returns:
      `True` iff `type_spec` is a structure of integers, otherwise `False`.
    """
    py_typecheck.check_type(type_spec, computation_types.Type)
    if type_spec.is_tensor():
        py_typecheck.check_type(type_spec.dtype, tf.DType)
        return type_spec.dtype.is_integer
    if type_spec.is_tuple():
        for _, element_type in anonymous_tuple.iter_elements(type_spec):
            if not is_structure_of_integers(element_type):
                return False
        return True
    if type_spec.is_federated():
        return is_structure_of_integers(type_spec.member)
    return False
Example #19
0
 def generic_plus(arg):
   """Adds `arg[0]` and `arg[1]` when possible.

   When the generic operator applies directly, `tf.add` is used. Otherwise, for
   struct-typed operands, the elements are added pairwise and recursively, and
   the original element names are restored.

   Raises:
     TypeError: If the operands are neither directly addable nor structs.
   """
   x, y = arg[0], arg[1]
   _check_top_level_compatibility_with_generic_operators(x, y, 'Generic plus')
   if _generic_op_can_be_applied(x, y):
     return _apply_generic_op(tf.add, x, y)
   # TODO(b/136587334): Push this logic down a level
   if not x.type_signature.is_struct():
     raise TypeError('Generic plus encountered unexpected type {}, {}'.format(
         x.type_signature, y.type_signature))
   # This case is needed if federated types are nested deeply.
   names = [
       element_name
       for element_name, _ in anonymous_tuple.iter_elements(x.type_signature)
   ]
   summed = [
       value_impl.ValueImpl.get_comp(generic_plus([x[i], y[i]]))
       for i, _ in enumerate(names)
   ]
   named_sum = building_block_factory.create_named_tuple(
       building_blocks.Struct(summed), names)
   return value_impl.ValueImpl(named_sum, context_stack)
    def __init__(self, initialize_fn, next_fn):
        """Creates a `tff.utils.MeasuredProcess`.

        Args:
          initialize_fn: A no-arg `tff.Computation` that creates the initial
            state of the measured process.
          next_fn: A `tff.Computation` that defines an iterated function. If
            `initialize_fn` returns a type `S`, then `next_fn` must return a
            TFF type `<state=S,result=O,measurements=M>` (possibly federated),
            and accept either a single argument of type `S` or multiple
            arguments where the first argument must be of type `S`.

        Raises:
          TypeError: If `initialize_fn` and `next_fn` are not compatible
            function types, or `next_fn` does not have a valid output
            signature.
        """
        super().__init__(initialize_fn, next_fn)
        # Checks specific to MeasuredProcess, beyond the base-class checks.
        returned_type = next_fn.type_signature.result
        if isinstance(returned_type, computation_types.FederatedType):
            next_result_type = returned_type.member
        elif isinstance(returned_type, computation_types.StructType):
            next_result_type = returned_type
        else:
            raise TypeError(
                'MeasuredProcess must return a StructType (or '
                'FederatedType containing a StructType) with the signature '
                '<state=A,result=B,measurements=C>. Received a ({t}): {s}'.
                format(t=type(returned_type), s=returned_type))
        field_names = [
            field_name for field_name, _ in
            anonymous_tuple.iter_elements(next_result_type)
        ]
        if field_names != _RESULT_FIELD_NAMES:
            raise TypeError(
                'The return type of next_fn must match type signature '
                '<state=A,result=B,measurements=C>. Got: {!s}'.format(
                    next_result_type))
Example #21
0
def is_structure_of_integers(type_spec):
    """Reports whether every leaf of `type_spec` is an integer tensor.

    `type_spec` is first converted with `computation_types.to_type`. A tensor
    type qualifies iff its dtype is an integer type. A named tuple type
    qualifies when all of its elements do, and a federated type when its member
    does. Anything else yields `False`.

    Args:
      type_spec: Either an instance of computation_types.Type, or something
        convertible to it.

    Returns:
      `True` iff `type_spec` is a structure of integers, otherwise `False`.
    """
    type_spec = computation_types.to_type(type_spec)
    if isinstance(type_spec, computation_types.TensorType):
        py_typecheck.check_type(type_spec.dtype, tf.DType)
        return type_spec.dtype.is_integer
    if isinstance(type_spec, computation_types.NamedTupleType):
        for _, element_type in anonymous_tuple.iter_elements(type_spec):
            if not is_structure_of_integers(element_type):
                return False
        return True
    if isinstance(type_spec, computation_types.FederatedType):
        return is_structure_of_integers(type_spec.member)
    return False
  def __init__(self, value: LambdaValueInner, type_spec=None):
    """Wraps `value` as a value embedded in a lambda executor.

    `value` may take one of three forms:

    * An `executor_value_base.ExecutorValue` (embedded in the target executor);
      its own `type_signature` becomes this value's type.
    * A `ScopedLambda` (a `pb.Computation` lambda with some attached scope);
      `type_spec` must then be a `computation_types.FunctionType`.
    * An `anonymous_tuple.AnonymousTuple` whose elements are all
      `ReferenceResolvingExecutorValue` instances; the type becomes a
      `NamedTupleType` built from the element names and type signatures.

    Args:
      value: The internal representation, in one of the forms above.
      type_spec: Required (a `computation_types.FunctionType`) when `value` is a
        `ScopedLambda`; must be `None` for the other forms, whose type is
        implied.
    """
    super().__init__()
    if isinstance(value, executor_value_base.ExecutorValue):
      py_typecheck.check_none(type_spec)
      type_spec = value.type_signature  # pytype: disable=attribute-error
    elif isinstance(value, ScopedLambda):
      py_typecheck.check_type(type_spec, computation_types.FunctionType)
    else:
      py_typecheck.check_type(value, anonymous_tuple.AnonymousTuple)
      py_typecheck.check_none(type_spec)
      element_types = []
      for element_name, element in anonymous_tuple.iter_elements(value):
        py_typecheck.check_type(element, ReferenceResolvingExecutorValue)
        element_type = element.type_signature
        element_types.append(element_type if element_name is None else
                             (element_name, element_type))
      type_spec = computation_types.NamedTupleType(element_types)
    self._value = value
    self._type_signature = type_spec
Example #23
0
async def delegate_entirely_to_executor(arg, arg_type, executor):
    """Delegates `arg` in its entirety to the target executor.

    How each kind of `arg` is handled:

    * `pb.Computation`: embedded via `executor.create_value()`.
    * `anonymous_tuple.AnonymousTuple`: each element is delegated recursively
      (concurrently), then assembled via `executor.create_tuple()` using the
      element names from `arg_type`.
    * Anything else must be an `executor_value_base.ExecutorValue` and is
      assumed to already be owned by the target executor; it is returned as is.

    Args:
      arg: The object to delegate to the target executor.
      arg_type: The type of this object.
      executor: The target executor to use.

    Returns:
      An instance of `executor_value_base.ExecutorValue` that represents the
      result of delegation.

    Raises:
      TypeError: If the arguments are of the wrong types.
    """
    py_typecheck.check_type(executor, executor_base.Executor)
    py_typecheck.check_type(arg_type, computation_types.Type)
    if isinstance(arg, pb.Computation):
        return await executor.create_value(arg, arg_type)
    if not isinstance(arg, anonymous_tuple.AnonymousTuple):
        py_typecheck.check_type(arg, executor_value_base.ExecutorValue)
        return arg
    delegated = await asyncio.gather(*[
        delegate_entirely_to_executor(element, element_type, executor)
        for element, element_type in zip(arg, arg_type)
    ])
    element_names = [
        element_name
        for element_name, _ in anonymous_tuple.iter_elements(arg_type)
    ]
    return await executor.create_tuple(
        anonymous_tuple.AnonymousTuple(zip(element_names, delegated)))
def check_valid_federated_weighted_mean_argument_tuple_type(
        type_spec: computation_types.NamedTupleType):
    """Validates the argument type of a federated weighted mean.

  The argument must be a 2-element named tuple type. Both elements must pass
  `check_federated_type` with placement `CLIENTS` and have average-compatible
  member types. The member type of the second element (the weight) must
  additionally be a scalar `computation_types.TensorType`.

  Args:
    type_spec: A `computation_types.NamedTupleType`.

  Raises:
    TypeError: If the check fails.
  """
    py_typecheck.check_type(type_spec, computation_types.NamedTupleType)
    if len(type_spec) != 2:
        raise TypeError('Expected a 2-tuple, found {}.'.format(type_spec))
    for _, element_type in anonymous_tuple.iter_elements(type_spec):
        check_federated_type(element_type, None, placement_literals.CLIENTS,
                             False)
        # Only read `.member` after the federated check has passed.
        member_type = element_type.member
        if is_average_compatible(member_type):
            continue
        raise TypeError(
            'Expected average-compatible args, got {} from argument of type {}.'
            .format(member_type, type_spec))
    weight_type = type_spec[1].member
    py_typecheck.check_type(weight_type, computation_types.TensorType)
    is_scalar = weight_type.shape.ndims == 0
    if not is_scalar:
        raise TypeError('Expected scalar weight, got {}.'.format(weight_type))
Exemple #25
0
def is_average_compatible(type_spec):
    """Tells whether values of `type_spec` can be averaged.

  A type qualifies when it is a floating-point or complex tensor type, a named
  tuple type all of whose elements qualify (checked recursively), or a
  federated type whose member type qualifies. Any other type does not.

  Args:
    type_spec: An instance of `types.Type`, or something convertible to it.

  Returns:
    `True` iff `type_spec` is average-compatible, `False` otherwise.
  """
    normalized = computation_types.to_type(type_spec)
    if isinstance(normalized, computation_types.TensorType):
        element_dtype = normalized.dtype
        return element_dtype.is_floating or element_dtype.is_complex
    if isinstance(normalized, computation_types.NamedTupleType):
        element_types = (
            element_type
            for _, element_type in anonymous_tuple.iter_elements(normalized))
        return all(map(is_average_compatible, element_types))
    if isinstance(normalized, computation_types.FederatedType):
        return is_average_compatible(normalized.member)
    return False
Exemple #26
0
 def generic_multiply(arg):
     """Multiplies the two operands packed in `arg`.

     Args:
       arg: A pair-like value; `arg[0]` and `arg[1]` are the operands.

     Returns:
       The product produced by `_apply_generic_op(tf.multiply, ...)` when the
       generic operator applies directly. Otherwise, for operands of named
       tuple type, a `value_impl.ValueImpl` wrapping a named tuple of
       element-wise (recursive) products, which reuses the element names of
       the first operand's type.

     Raises:
       TypeError: If the generic operator does not apply and the first
         operand is not of named tuple type. (The top-level compatibility
         helper may also raise; see its definition.)
     """
     lhs, rhs = arg[0], arg[1]
     _check_top_level_compatibility_with_generic_operators(
         lhs, rhs, 'Generic multiply')
     if _generic_op_can_be_applied(lhs, rhs):
         return _apply_generic_op(tf.multiply, lhs, rhs)
     if not isinstance(lhs.type_signature, computation_types.NamedTupleType):
         raise TypeError(
             'Generic multiply encountered unexpected type {}, {}'.format(
                 lhs.type_signature, rhs.type_signature))
     # Recursing element by element is needed when federated types are nested
     # deeply inside named tuples.
     element_names = [
         name for name, _ in anonymous_tuple.iter_elements(lhs.type_signature)
     ]
     element_products = [
         value_impl.ValueImpl.get_comp(
             generic_multiply([lhs[index], rhs[index]]))
         for index, _ in enumerate(element_names)
     ]
     named_product = building_block_factory.create_named_tuple(
         building_blocks.Tuple(element_products), element_names)
     return value_impl.ValueImpl(named_product, context_stack)
def is_average_compatible(type_spec: computation_types.Type) -> bool:
    """Tells whether values of `type_spec` can be averaged.

  Floating-point and complex tensor types qualify. A named tuple type
  qualifies when every one of its elements does (checked recursively). A
  federated type qualifies when its member type does. Nothing else qualifies.

  Args:
    type_spec: a `computation_types.Type`.

  Returns:
    `True` iff `type_spec` is average-compatible, `False` otherwise.

  Raises:
    TypeError: If `type_spec` is not a `computation_types.Type`.
  """
    py_typecheck.check_type(type_spec, computation_types.Type)
    if type_spec.is_tensor():
        element_dtype = type_spec.dtype
        return element_dtype.is_floating or element_dtype.is_complex
    if type_spec.is_tuple():
        element_types = (
            element_type
            for _, element_type in anonymous_tuple.iter_elements(type_spec))
        return all(map(is_average_compatible, element_types))
    if type_spec.is_federated():
        return is_average_compatible(type_spec.member)
    return False
 def _generic_zero(self, type_spec):
   """Builds a `ComputedValue` holding the generic zero of `type_spec`.

   Tensor types yield the result of evaluating `tf.constant(0, ...)` with the
   tensor's dtype and shape. Named tuple types yield an `AnonymousTuple` of
   per-element zeros. All-equal federated types yield the zero of their
   member type.

   Args:
     type_spec: The TFF type to build a zero for.

   Returns:
     A `ComputedValue` whose type is `type_spec`.

   Raises:
     TypeError: For sequence, function, abstract and placement types, where a
       generic zero is not well-defined.
     NotImplementedError: For non-all_equal federated types and for any other
       type not handled above.
   """
   ill_defined_types = (computation_types.SequenceType,
                        computation_types.FunctionType,
                        computation_types.AbstractType,
                        computation_types.PlacementType)
   if isinstance(type_spec, computation_types.TensorType):
     # TODO(b/113116813): Replace this with something more efficient, probably
     # calling some helper method from Numpy.
     zero_graph = tf.Graph()
     with zero_graph.as_default():
       zero_tensor = tf.constant(0, type_spec.dtype, type_spec.shape)  # pytype: disable=attribute-error
       with tf.compat.v1.Session(graph=zero_graph) as session:
         zero_array = session.run(zero_tensor)
     return ComputedValue(zero_array, type_spec)
   if isinstance(type_spec, computation_types.NamedTupleType):
     zero_elements = [
         (name, self._generic_zero(element_type).value)
         for name, element_type in anonymous_tuple.iter_elements(type_spec)
     ]
     return ComputedValue(
         anonymous_tuple.AnonymousTuple(zero_elements), type_spec)
   if isinstance(type_spec, ill_defined_types):
     raise TypeError(
         'The generic_zero is not well-defined for TFF type {}.'.format(
             type_spec))
   if isinstance(type_spec, computation_types.FederatedType):
     if not type_spec.all_equal:  # pytype: disable=attribute-error
       # TODO(b/113116813): Implement this in terms of the generic placement
       # operator once it's been added to the mix.
       raise NotImplementedError(
           'Generic zero support for non-all_equal federated types is not '
           'implemented yet.')
     member_zero = self._generic_zero(type_spec.member)  # pytype: disable=attribute-error
     return ComputedValue(member_zero.value, type_spec)
   raise NotImplementedError(
       'Generic zero support for {} is not implemented yet.'.format(
           type_spec))
Exemple #29
0
def transform_postorder(comp, transform):
    """Traverses `comp` recursively postorder and replaces its constituents.

  For each element of `comp` viewed as an expression tree, the transformation
  `transform` is applied first to building blocks it is parameterized by, then
  the element itself. The transformation `transform` should act as an identity
  function on the kinds of elements (computation building blocks) it does not
  care to transform. This corresponds to a post-order traversal of the
  expression tree, i.e., parameters are always transformed left-to-right (in
  the order in which they are listed in building block constructors), then the
  parent is visited and transformed with the already-visited, and possibly
  transformed arguments in place.

  NOTE: In particular, in `Call(f,x)`, both `f` and `x` are arguments to `Call`.
  Therefore, `f` is transformed into `f'`, next `x` into `x'` and finally,
  `Call(f',x')` is transformed at the end.

  A parent building block is only reconstructed when at least one of its
  children was reported as modified. Otherwise the original object is passed
  to `transform`. Only the bound values and results are traversed; a `Lambda`'s
  parameter and a `Block`'s local names are carried over unchanged.

  Args:
    comp: A `building_blocks.ComputationBuildingBlock` to traverse and
      transform bottom-up.
    transform: The transformation to apply locally to each building block in
      `comp`. It is a Python function that accepts a building block at input,
      and should return a (building block, bool) tuple as output, where the
      building block is a `building_blocks.ComputationBuildingBlock`
      representing either the original building block or a transformed building
      block and the bool is a flag indicating if the building block was
      modified.

  Returns:
    The result of applying `transform` to parts of `comp` in a bottom-up
    fashion, along with a Boolean with the value `True` if `comp` was
    transformed and `False` if it was not.

  Raises:
    TypeError: If the arguments are of the wrong types.
    NotImplementedError: If the argument is a kind of computation building block
      that is currently not recognized.
  """
    py_typecheck.check_type(comp, building_blocks.ComputationBuildingBlock)
    # Leaf building blocks: there are no children to recurse into.
    if isinstance(comp, (
            building_blocks.CompiledComputation,
            building_blocks.Data,
            building_blocks.Intrinsic,
            building_blocks.Placement,
            building_blocks.Reference,
    )):
        return transform(comp)
    elif isinstance(comp, building_blocks.Selection):
        source, source_modified = transform_postorder(comp.source, transform)
        if source_modified:
            comp = building_blocks.Selection(source, comp.name, comp.index)
        comp, comp_modified = transform(comp)
        return comp, comp_modified or source_modified
    elif isinstance(comp, building_blocks.Tuple):
        # Elements are visited left-to-right; their names are preserved.
        elements = []
        elements_modified = False
        for key, value in anonymous_tuple.iter_elements(comp):
            value, value_modified = transform_postorder(value, transform)
            elements.append((key, value))
            elements_modified = elements_modified or value_modified
        if elements_modified:
            comp = building_blocks.Tuple(elements)
        comp, comp_modified = transform(comp)
        return comp, comp_modified or elements_modified
    elif isinstance(comp, building_blocks.Call):
        # The function is transformed before the (optional) argument.
        fn, fn_modified = transform_postorder(comp.function, transform)
        if comp.argument is not None:
            arg, arg_modified = transform_postorder(comp.argument, transform)
        else:
            arg, arg_modified = (None, False)
        if fn_modified or arg_modified:
            comp = building_blocks.Call(fn, arg)
        comp, comp_modified = transform(comp)
        return comp, comp_modified or fn_modified or arg_modified
    elif isinstance(comp, building_blocks.Lambda):
        # Only the body is traversed; the parameter name/type are reused as-is.
        result, result_modified = transform_postorder(comp.result, transform)
        if result_modified:
            comp = building_blocks.Lambda(comp.parameter_name,
                                          comp.parameter_type, result)
        comp, comp_modified = transform(comp)
        return comp, comp_modified or result_modified
    elif isinstance(comp, building_blocks.Block):
        # Locals are visited in declaration order, then the result.
        variables = []
        variables_modified = False
        for key, value in comp.locals:
            value, value_modified = transform_postorder(value, transform)
            variables.append((key, value))
            variables_modified = variables_modified or value_modified
        result, result_modified = transform_postorder(comp.result, transform)
        if variables_modified or result_modified:
            comp = building_blocks.Block(variables, result)
        comp, comp_modified = transform(comp)
        return comp, comp_modified or variables_modified or result_modified
    else:
        raise NotImplementedError(
            'Unrecognized computation building block: {}'.format(str(comp)))
Exemple #30
0
def to_value(
    arg: Any,
    type_spec,
    context_stack: context_stack_base.ContextStack,
) -> ValueImpl:
    """Converts `arg` into an instance of `tff.Value`.

  Apart from existing `tff.Value`s, the following kinds of arguments can be
  converted:

  * Lists, tuples, anonymous tuples, named tuples, `attrs` classes, and
    dictionaries, each turned into a `tff.Tuple`. Plain dictionaries are
    converted in sorted key order; `collections.OrderedDict`s keep their own
    order.
  * Placement literals, turned into a `tff.Placement`.
  * Computations and computation building blocks.
  * Python constants of type `str`, `int`, `float`, `bool`.
  * Numpy objects inheriting from `np.ndarray` or `np.generic` (the parent
    of numpy scalar types).

  Args:
    arg: Either an instance of `tff.Value`, or an argument convertible to
      `tff.Value`. The argument must not be `None`.
    type_spec: An optional `computation_types.Type` or value convertible to it
      by `computation_types.to_type` which specifies the desired type signature
      of the resulting value. This allows for disambiguating the target type
      (e.g., when two TFF types can be mapped to the same Python
      representations), or `None` if none available, in which case TFF tries to
      determine the type of the TFF value automatically. A sequence type here
      makes `arg` be wrapped as a sequence.
    context_stack: The context stack to use.

  Returns:
    An instance of `tff.Value` corresponding to the given `arg`, and of TFF type
    matching the `type_spec` if specified (not `None`).

  Raises:
    TypeError: if `arg` is of an unsupported type, or of a type that does not
      match `type_spec`. Raises explicit error message if TensorFlow constructs
      are encountered, as TensorFlow code should be sealed away from TFF
      federated context.
  """
    py_typecheck.check_type(context_stack, context_stack_base.ContextStack)
    if type_spec is not None:
        type_spec = computation_types.to_type(type_spec)
        type_utils.check_well_formed(type_spec)

    def _convert():
        # The checks run in a fixed order that matters: existing values and
        # building blocks are used directly, and a sequence type hint wins over
        # any structural interpretation of `arg`.
        if isinstance(arg, ValueImpl):
            return arg
        if isinstance(arg, building_blocks.ComputationBuildingBlock):
            return ValueImpl(arg, context_stack)
        if isinstance(arg, placement_literals.PlacementLiteral):
            return ValueImpl(building_blocks.Placement(arg), context_stack)
        if isinstance(arg, computation_base.Computation):
            proto = computation_impl.ComputationImpl.get_proto(arg)
            return ValueImpl(
                building_blocks.CompiledComputation(proto), context_stack)
        if isinstance(type_spec, computation_types.SequenceType):
            return _wrap_sequence_as_value(arg, type_spec.element,
                                           context_stack)
        if isinstance(arg, anonymous_tuple.AnonymousTuple):
            named_comps = [
                (name,
                 ValueImpl.get_comp(to_value(element, None, context_stack)))
                for name, element in anonymous_tuple.iter_elements(arg)
            ]
            return ValueImpl(building_blocks.Tuple(named_comps), context_stack)
        if py_typecheck.is_named_tuple(arg):
            return to_value(arg._asdict(), None, context_stack)  # pytype: disable=attribute-error
        if py_typecheck.is_attrs(arg):
            attr_fields = attr.asdict(
                arg, dict_factory=collections.OrderedDict, recurse=False)
            return to_value(attr_fields, None, context_stack)
        if isinstance(arg, dict):
            if isinstance(arg, collections.OrderedDict):
                ordered_items = arg.items()
            else:
                ordered_items = sorted(arg.items())
            keyed_comps = [
                (key, ValueImpl.get_comp(to_value(element, None,
                                                  context_stack)))
                for key, element in ordered_items
            ]
            return ValueImpl(building_blocks.Tuple(keyed_comps), context_stack)
        if isinstance(arg, (tuple, list)):
            element_comps = [
                ValueImpl.get_comp(to_value(element, None, context_stack))
                for element in arg
            ]
            return ValueImpl(
                building_blocks.Tuple(element_comps), context_stack)
        if isinstance(arg, tensorflow_utils.TENSOR_REPRESENTATION_TYPES):
            return _wrap_constant_as_value(arg, context_stack)
        if isinstance(arg, (tf.Tensor, tf.Variable)):
            raise TypeError(
                'TensorFlow construct {} has been encountered in a federated '
                'context. TFF does not support mixing TF and federated orchestration '
                'code. Please wrap any TensorFlow constructs with '
                '`tff.tf_computation`.'.format(arg))
        if isinstance(arg, function_utils.PolymorphicFunction):
            # TODO(b/129567727) remove this case when this is no longer an error
            raise TypeError(
                'Polymorphic computations cannot be converted to a TFF value. Consider '
                'explicitly specifying the argument types of a computation before '
                'passing it to a function that requires a TFF value (such as a TFF '
                'intrinsic like federated_map).')
        raise TypeError(
            'Unable to interpret an argument of type {} as a TFF value.'.
            format(py_typecheck.type_string(type(arg))))

    result = _convert()
    py_typecheck.check_type(result, ValueImpl)
    if type_spec is None:
        return result
    if not type_utils.is_assignable_from(type_spec, result.type_signature):
        raise TypeError(
            'The supplied argument maps to TFF type {}, which is incompatible with '
            'the requested type {}.'.format(result.type_signature, type_spec))
    return result