Пример #1
0
 def _federated_zip_at_clients(self, arg):
   """Zips several `CLIENTS`-placed values into one `CLIENTS`-placed tuple.

   Each element of `arg` must be a non-all-equal federated value placed at
   `CLIENTS` whose payload is a Python list. The lists are combined position
   by position into `AnonymousTuple`s via `anonymous_tuple.from_container`.

   Args:
     arg: A `ComputedValue` whose type is a `NamedTupleType` of federated types
       and whose value is an `AnonymousTuple` of lists.

   Returns:
     A `ComputedValue` whose value is a list of `AnonymousTuple`s and whose
     type is `type_factory.at_clients` applied to an unnamed `NamedTupleType`
     of the member types. Because Python's `zip` is used, lists of unequal
     length are truncated to the shortest one.
   """
   tuple_type = arg.type_signature
   py_typecheck.check_type(tuple_type, computation_types.NamedTupleType)
   py_typecheck.check_type(arg.value, anonymous_tuple.AnonymousTuple)
   client_lists = []
   member_types = []
   for position in range(len(tuple_type)):
     values_at_clients = arg.value[position]
     py_typecheck.check_type(values_at_clients, list)
     client_lists.append(values_at_clients)
     element_type = tuple_type[position]
     type_utils.check_federated_type(element_type, None, placements.CLIENTS,
                                     False)
     member_types.append(element_type.member)
   per_client_tuples = [
       anonymous_tuple.from_container(row) for row in zip(*client_lists)
   ]
   result_type = type_factory.at_clients(
       computation_types.NamedTupleType(member_types))
   return ComputedValue(per_client_tuples, result_type)
Пример #2
0
 async def _compute_intrinsic_federated_zip_at_server(self, arg):
   """Zips a pair of all-equal `SERVER` values into one `SERVER`-placed pair.

   Args:
     arg: A value whose type is a 2-element `NamedTupleType` of all-equal
       federated types placed at `SERVER`, and whose internal representation
       is a 2-element `AnonymousTuple`.

   Returns:
     A `CompositeValue` wrapping the tuple created by the parent executor from
     the two internal elements, typed as the `SERVER`-placed `NamedTupleType`
     of the two member types.
   """
   pair_type = arg.type_signature
   pair_value = arg.internal_representation
   py_typecheck.check_type(pair_type, computation_types.NamedTupleType)
   py_typecheck.check_len(pair_type, 2)
   py_typecheck.check_type(pair_value, anonymous_tuple.AnonymousTuple)
   py_typecheck.check_len(pair_value, 2)
   positions = (0, 1)
   for i in positions:
     type_utils.check_federated_type(
         pair_type[i], placement=placement_literals.SERVER, all_equal=True)
   zipped = await self._parent_executor.create_tuple(
       [pair_value[i] for i in positions])
   member_tuple_type = computation_types.NamedTupleType(
       [pair_type[i].member for i in positions])
   return CompositeValue(zipped, type_factory.at_server(member_tuple_type))
Пример #3
0
 async def _compute_intrinsic_federated_collect(self, arg):
     """Collects `CLIENTS`-placed values into a sequence placed at `SERVER`.

     Every client value is computed concurrently (`asyncio.gather`). The
     results are then embedded as a single sequence in the first `SERVER`
     target executor.

     Args:
       arg: A value with a `FederatedType` placed at `CLIENTS` whose internal
         representation is a list of values exposing `compute()`.

     Returns:
       A `FederatingExecutorValue` wrapping a one-element list holding the
       collected sequence, typed as an all-equal `SERVER`-placed sequence of
       the client member type.
     """
     federated_type = arg.type_signature
     py_typecheck.check_type(federated_type, computation_types.FederatedType)
     type_utils.check_federated_type(federated_type,
                                     placement=placement_literals.CLIENTS)
     client_values = arg.internal_representation
     py_typecheck.check_type(client_values, list)
     element_type = federated_type.member
     server_executor = self._target_executors[placement_literals.SERVER][0]
     computed_items = await asyncio.gather(
         *(item.compute() for item in client_values))
     sequence_type = computation_types.SequenceType(element_type)
     embedded_sequence = await server_executor.create_value(
         computed_items, sequence_type)
     result_type = computation_types.FederatedType(
         sequence_type, placement_literals.SERVER, all_equal=True)
     return FederatingExecutorValue([embedded_sequence], result_type)
Пример #4
0
 async def _compute_intrinsic_federated_apply(self, arg):
     """Applies a function to an all-equal `SERVER`-placed value.

     Args:
       arg: A value whose internal representation is a 2-element
         `AnonymousTuple` `(fn, val)`. `fn` is a `pb.Computation` with a
         `FunctionType`. `val` is an `ExecutorValue` whose type is an
         all-equal federated type at `SERVER` matching `fn`'s parameter.

     Returns:
       A `CompositeValue` wrapping the parent executor's call of the embedded
       `fn` on `val`, typed as `fn`'s result placed at `SERVER`.
     """
     fn_and_val = arg.internal_representation
     py_typecheck.check_type(fn_and_val, anonymous_tuple.AnonymousTuple)
     py_typecheck.check_len(fn_and_val, 2)
     fn_type = arg.type_signature[0]
     py_typecheck.check_type(fn_type, computation_types.FunctionType)
     type_utils.check_federated_type(arg.type_signature[1],
                                     fn_type.parameter,
                                     placement_literals.SERVER,
                                     all_equal=True)
     fn, val = fn_and_val[0], fn_and_val[1]
     py_typecheck.check_type(fn, pb.Computation)
     py_typecheck.check_type(val, executor_value_base.ExecutorValue)
     parent = self._parent_executor
     embedded_fn = await parent.create_value(fn, fn_type)
     call_result = await parent.create_call(embedded_fn, val)
     return CompositeValue(call_result, type_factory.at_server(fn_type.result))
Пример #5
0
 def _federated_reduce(self, arg):
   """Left-folds non-all-equal `CLIENTS` values with a binary operator.

   Args:
     arg: A `ComputedValue` whose type is a `NamedTupleType` of
       `(value, zero, op)`. `value` is a non-all-equal federated type at
       `CLIENTS` and `op` is a `FunctionType` whose parameter accepts
       `[zero_type, member_type]`. Its value holds the per-client items, the
       zero, and the Python callable implementing `op`.

   Returns:
     The result of `self._federated_value_at_server` applied to the final
     accumulator. The accumulator starts at `zero` and is replaced by `op`'s
     result for each client item in order.
   """
   py_typecheck.check_type(arg.type_signature,
                           computation_types.NamedTupleType)
   items_type = arg.type_signature[0]
   type_utils.check_federated_type(items_type, None, placements.CLIENTS,
                                   False)
   zero_type = arg.type_signature[1]
   op_type = arg.type_signature[2]
   py_typecheck.check_type(op_type, computation_types.FunctionType)
   type_utils.check_assignable_from(op_type.parameter,
                                    [zero_type, items_type.member])
   accumulator = ComputedValue(arg.value[1], zero_type)
   op = arg.value[2]
   for item in arg.value[0]:
     operands = anonymous_tuple.AnonymousTuple([(None, accumulator.value),
                                                (None, item)])
     accumulator = op(ComputedValue(operands, op_type.parameter))
   return self._federated_value_at_server(accumulator)
Пример #6
0
    async def _compute_intrinsic_federated_zip_at_clients(self, arg):
        """Zips a pair of `CLIENTS`-placed values, one child executor at a time.

        Both elements must be federated values at `CLIENTS` whose internal
        representation is a list with one entry per child executor. Each child
        runs the `FEDERATED_ZIP_AT_CLIENTS` intrinsic on its own pair of
        entries, and all children run concurrently via `asyncio.gather`.
        Element names present in `arg`'s type are kept in the result type.

        Args:
          arg: A value whose type is a 2-element `NamedTupleType` of
            `CLIENTS`-placed federated types and whose internal representation
            is a 2-element `AnonymousTuple` of lists.

        Returns:
          A `CompositeValue` holding the per-child results, typed as the
          `CLIENTS`-placed `NamedTupleType` of the two member types.
        """
        signature = arg.type_signature
        representation = arg.internal_representation
        py_typecheck.check_type(signature, computation_types.NamedTupleType)
        py_typecheck.check_len(signature, 2)
        py_typecheck.check_type(representation, anonymous_tuple.AnonymousTuple)
        py_typecheck.check_len(representation, 2)
        names = [name for name, _ in anonymous_tuple.to_elements(signature)]
        client_lists = [representation[0], representation[1]]
        field_types = [signature[0], signature[1]]
        for i in (0, 1):
            type_utils.check_federated_type(
                field_types[i], placement=placement_literals.CLIENTS)
            # Re-express each field through `type_factory.at_clients` so the
            # intrinsic signature built below uses that canonical form.
            field_types[i] = type_factory.at_clients(field_types[i].member)
            py_typecheck.check_type(client_lists[i], list)
            py_typecheck.check_len(client_lists[i], len(self._child_executors))

        def _maybe_named(name, value):
            # Attach the element name only when the input element had one.
            return (name, value) if name else value

        member_tuple_type = computation_types.NamedTupleType(
            [_maybe_named(names[i], field_types[i].member) for i in (0, 1)])
        result_type = type_factory.at_clients(member_tuple_type)
        zip_type = computation_types.FunctionType(
            computation_types.NamedTupleType(
                [_maybe_named(names[i], field_types[i]) for i in (0, 1)]),
            result_type)
        zip_comp = executor_utils.create_intrinsic_comp(
            intrinsic_defs.FEDERATED_ZIP_AT_CLIENTS, zip_type)

        async def _zip_on_child(child, first, second):
            py_typecheck.check_type(first, executor_value_base.ExecutorValue)
            py_typecheck.check_type(second, executor_value_base.ExecutorValue)
            embedded_zip = await child.create_value(zip_comp, zip_type)
            embedded_pair = await child.create_tuple(
                anonymous_tuple.AnonymousTuple([(names[0], first),
                                                (names[1], second)]))
            return await child.create_call(embedded_zip, embedded_pair)

        per_child_results = await asyncio.gather(*(
            _zip_on_child(child, first, second)
            for child, first, second in zip(self._child_executors,
                                            *client_lists)))
        return CompositeValue(per_child_results, result_type)
Пример #7
0
 def test_check_federated_type(self):
   """`check_federated_type` accepts matching specs and rejects mismatches."""
   type_spec = computation_types.FederatedType(tf.int32, placements.CLIENTS,
                                               False)
   # Each entry is (member, placement, all_equal); `None` means "don't check".
   matching = [
       (tf.int32, placements.CLIENTS, False),
       (tf.int32, None, None),
       (None, placements.CLIENTS, None),
       (None, None, False),
   ]
   for member, placement, all_equal in matching:
     type_utils.check_federated_type(type_spec, member, placement, all_equal)
   mismatching = [
       (tf.bool, None, None),
       (None, placements.SERVER, None),
       (None, None, True),
   ]
   for member, placement, all_equal in mismatching:
     self.assertRaises(TypeError, type_utils.check_federated_type, type_spec,
                       member, placement, all_equal)
Пример #8
0
def create_federated_map(fn, arg):
  r"""Creates a called federated map.

            Call
           /    \
  Intrinsic      Tuple
                 |
                 [Comp, Comp]

  The intrinsic is typed as `(<T->U, {T}@CLIENTS> -> {U}@CLIENTS)`, where
  `T->U` is the type signature of `fn`.

  Args:
    fn: A functional `computation_building_blocks.ComputationBuildingBlock` to
      use as the function.
    arg: A `computation_building_blocks.ComputationBuildingBlock` to use as the
      argument. Its type must be a federated type placed at `CLIENTS`.

  Returns:
    A `computation_building_blocks.Call`.

  Raises:
    TypeError: If any of the types do not match, including when `arg` is not
      placed at `CLIENTS`.
  """
  py_typecheck.check_type(fn,
                          computation_building_blocks.ComputationBuildingBlock)
  py_typecheck.check_type(fn.type_signature, computation_types.FunctionType)
  py_typecheck.check_type(arg,
                          computation_building_blocks.ComputationBuildingBlock)
  # The intrinsic type below is always built with CLIENTS placement. Reject any
  # other placement up front instead of constructing an intrinsic whose
  # parameter type disagrees with `arg`.
  type_utils.check_federated_type(
      arg.type_signature, placement=placement_literals.CLIENTS)
  parameter_type = computation_types.FederatedType(fn.type_signature.parameter,
                                                   placement_literals.CLIENTS,
                                                   False)
  result_type = computation_types.FederatedType(fn.type_signature.result,
                                                placement_literals.CLIENTS,
                                                False)
  intrinsic_type = computation_types.FunctionType(
      (fn.type_signature, parameter_type), result_type)
  intrinsic = computation_building_blocks.Intrinsic(
      intrinsic_defs.FEDERATED_MAP.uri, intrinsic_type)
  tup = computation_building_blocks.Tuple((fn, arg))
  return computation_building_blocks.Call(intrinsic, tup)
Пример #9
0
async def compute_intrinsic_federated_broadcast(
    executor: executor_base.Executor, arg: executor_value_base.ExecutorValue
) -> executor_value_base.ExecutorValue:
  """Re-embeds an all-equal `SERVER` value as an all-equal `CLIENTS` value.

  The value of `arg` is computed and then handed back to `executor` under a
  `CLIENTS`-placed, all-equal federated type with the same member type.

  Args:
    executor: The executor in which to embed the broadcast result.
    arg: The value to broadcast, already embedded in `executor`. Its type must
      be a federated type placed at `tff.SERVER` with `all_equal=True`.

  Returns:
    The broadcast value embedded in `executor`.

  Raises:
    TypeError: If the arguments are of the wrong types.
  """
  py_typecheck.check_type(executor, executor_base.Executor)
  py_typecheck.check_type(arg, executor_value_base.ExecutorValue)
  server_type = arg.type_signature
  type_utils.check_federated_type(
      server_type, placement=placement_literals.SERVER, all_equal=True)
  payload = await arg.compute()
  clients_type = computation_types.FederatedType(
      server_type.member, placement_literals.CLIENTS, all_equal=True)
  return await executor.create_value(payload, clients_type)
Пример #10
0
def fit_argument(arg, type_spec, context):
  """Fits the given argument `arg` to match the given parameter `type_spec`.

  Fitting recurses through named tuples and federated types. An all-equal
  federated value can be expanded into a non-all-equal one by replicating its
  member once per participant. That expansion asks `context` for the
  placement's cardinality.

  Args:
    arg: The argument to fit, an instance of `ComputedValue`.
    type_spec: The type of the parameter to fit to, an instance of `tff.Type` or
      something convertible to it.
    context: The context in which to perform the fitting, either an instance of
      `ComputationContext`, or `None` if unspecified. A context is required
      when an all-equal federated value must be expanded into a non-all-equal
      one.

  Returns:
    An instance of `ComputedValue` with the payload from `arg`, but matching
    the `type_spec` in the given context.

  Raises:
    TypeError: If the types mismatch.
    ValueError: If the value is invalid or does not fit the requested type,
      including when an all-equal value must be expanded but `context` is
      `None`.
  """
  py_typecheck.check_type(arg, ComputedValue)
  type_spec = computation_types.to_type(type_spec)
  py_typecheck.check_type(type_spec, computation_types.Type)
  if context is not None:
    py_typecheck.check_type(context, ComputationContext)
  type_utils.check_assignable_from(type_spec, arg.type_signature)
  if arg.type_signature == type_spec:
    return arg
  elif isinstance(type_spec, computation_types.NamedTupleType):
    py_typecheck.check_type(arg.value, anonymous_tuple.AnonymousTuple)
    result_elements = []
    for idx, (elem_name,
              elem_type) in enumerate(anonymous_tuple.to_elements(type_spec)):
      elem_val = ComputedValue(arg.value[idx], arg.type_signature[idx])
      # Compare the element's type (not the `ComputedValue` wrapper) with the
      # target type, and only recurse when a conversion may be needed.
      if elem_val.type_signature != elem_type:
        elem_val = fit_argument(elem_val, elem_type, context)
      result_elements.append((elem_name, elem_val.value))
    return ComputedValue(
        anonymous_tuple.AnonymousTuple(result_elements), type_spec)
  elif isinstance(type_spec, computation_types.FederatedType):
    type_utils.check_federated_type(
        arg.type_signature, placement=type_spec.placement)
    if arg.type_signature.all_equal:
      member_val = ComputedValue(arg.value, arg.type_signature.member)
      if type_spec.member != arg.type_signature.member:
        member_val = fit_argument(member_val, type_spec.member, context)
      if type_spec.all_equal:
        return ComputedValue(member_val.value, type_spec)
      # Expanding into a non-all-equal value needs the number of participants
      # at the placement, which only the context can supply.
      if context is None:
        raise ValueError(
            'Cannot fit all-equal {} into non-all-equal {} without a context '
            'to supply the cardinality.'.format(arg.type_signature, type_spec))
      cardinality = context.get_cardinality(type_spec.placement)
      return ComputedValue([member_val.value for _ in range(cardinality)],
                           type_spec)
    elif type_spec.all_equal:
      raise TypeError('Cannot fit a non all-equal {} into all-equal {}.'.format(
          arg.type_signature, type_spec))
    else:
      py_typecheck.check_type(arg.value, list)

      def _fit_member_val(x):
        x_val = ComputedValue(x, arg.type_signature.member)
        return fit_argument(x_val, type_spec.member, context).value

      return ComputedValue([_fit_member_val(x) for x in arg.value], type_spec)
  else:
    # TODO(b/113123634): Possibly add more conversions, e.g., for tensor types.
    return arg