Esempio n. 1
0
    async def _map(self, arg, all_equal=None):
        """Applies a function to each member of a federated value.

        Args:
          arg: A strategy value validated by `_check_arg_is_structure` whose
            internal representation has two elements `(fn, val)`. `fn` is a
            `pb.Computation` of type `arg.type_signature[0]` (a
            `FunctionType`). `val` is a list of `ExecutorValue`s paired
            positionally with the target executors for the value's placement.
          all_equal: Optional override for the result's `all_equal` bit. When
            `None`, the bit is copied from the input's federated type.

        Returns:
          A `FederatedResolvingStrategyValue` holding one call result per child
          executor, typed as `fn`'s result at the input's placement.

        Raises:
          ValueError: If an all-equal result is requested for an input that is
            not all-equal.
        """
        self._check_arg_is_structure(arg)
        py_typecheck.check_len(arg.internal_representation, 2)
        fn_type = arg.type_signature[0]
        py_typecheck.check_type(fn_type, computation_types.FunctionType)
        federated_type = arg.type_signature[1]
        py_typecheck.check_type(federated_type, computation_types.FederatedType)
        # `None` is falsy, so this guard only fires for an explicit request.
        if all_equal and not federated_type.all_equal:
            raise ValueError(
                'Cannot map a non-all_equal argument into an all_equal result.'
            )
        if all_equal is None:
            all_equal = federated_type.all_equal
        comp = arg.internal_representation[0]
        py_typecheck.check_type(comp, pb.Computation)
        members = arg.internal_representation[1]
        py_typecheck.check_type(members, list)
        for member in members:
            py_typecheck.check_type(member, executor_value_base.ExecutorValue)
        placement = federated_type.placement
        self._check_strategy_compatible_with_placement(placement)
        child_executors = self._target_executors[placement]

        async def _invoke_on(executor, member_value):
            # Embed the function in this child, then call it on the member.
            embedded_fn = await executor.create_value(comp, fn_type)
            return await executor.create_call(embedded_fn, member_value)

        call_results = await asyncio.gather(*(
            _invoke_on(executor, member_value)
            for member_value, executor in zip(members, child_executors)))
        result_type = computation_types.FederatedType(fn_type.result,
                                                      placement,
                                                      all_equal=all_equal)
        return FederatedResolvingStrategyValue(call_results, result_type)
    async def compute_federated_aggregate(
        self, arg: FederatedComposingStrategyValue
    ) -> FederatedComposingStrategyValue:
        """Aggregates a federated value whose members live on child executors.

        Each child runs a `FEDERATED_AGGREGATE` intrinsic over its own share of
        the value, with an identity `report` so the child returns its partial
        aggregate. The partials are embedded in the server executor and folded
        together with `merge` in the order in which they complete. The real
        `report` is then applied on the server.

        Args:
          arg: A strategy value whose internal representation is a 5-element
            `structure.Struct` `(value, zero, accumulate, merge, report)`.
            `value` is a list with one entry per target executor.

        Returns:
          A `FederatedComposingStrategyValue` placed at the server with
          `report`'s result type.
        """
        (value_type, zero_type, accumulate_type, merge_type,
         report_type) = executor_utils.parse_federated_aggregate_argument_types(
             arg.type_signature)
        components = arg.internal_representation
        py_typecheck.check_type(components, structure.Struct)
        py_typecheck.check_len(components, 5)
        per_child_values = components[0]
        py_typecheck.check_type(per_child_values, list)
        py_typecheck.check_len(per_child_values, len(self._target_executors))
        identity_report, identity_report_type = (
            tensorflow_computation_factory.create_identity(zero_type))
        child_aggregate_type = computation_types.FunctionType(
            computation_types.StructType([
                value_type, zero_type, accumulate_type, merge_type,
                identity_report_type
            ]), computation_types.at_server(zero_type))
        child_aggregate_comp = executor_utils.create_intrinsic_comp(
            intrinsic_defs.FEDERATED_AGGREGATE, child_aggregate_type)
        zero_selection = await self._executor.create_selection(arg, index=1)
        zero = await zero_selection.compute()
        accumulate = components[2]
        merge = components[3]
        report = components[4]

        async def _aggregate_on_child(child, child_value):
            py_typecheck.check_type(child_value,
                                    executor_value_base.ExecutorValue)
            embedded_operands = await asyncio.gather(
                child.create_value(zero, zero_type),
                child.create_value(accumulate, accumulate_type),
                child.create_value(merge, merge_type),
                child.create_value(identity_report, identity_report_type))
            aggregate_fn, aggregate_arg = await asyncio.gather(
                child.create_value(child_aggregate_comp, child_aggregate_type),
                child.create_struct([child_value] + list(embedded_operands)))
            call = await child.create_call(aggregate_fn, aggregate_arg)
            partial = await call.compute()
            # Move the child's partial aggregate over to the server executor.
            return await self._server_executor.create_value(partial, zero_type)

        # `as_completed` schedules every child aggregation right away, so they
        # run while the server-side `merge` / `report` are being embedded.
        completion_order = asyncio.as_completed([
            _aggregate_on_child(child, child_value) for child, child_value in
            zip(self._target_executors, per_child_values)
        ])
        server_merge, server_report = await asyncio.gather(
            self._server_executor.create_value(merge, merge_type),
            self._server_executor.create_value(report, report_type))
        accumulated = await next(completion_order)
        for pending in completion_order:
            partial_at_server = await pending
            pair = await self._server_executor.create_struct(
                [accumulated, partial_at_server])
            accumulated = await self._server_executor.create_call(
                server_merge, pair)
        reported = await self._server_executor.create_call(
            server_report, accumulated)
        return FederatedComposingStrategyValue(
            reported, computation_types.at_server(report_type.result))
Esempio n. 3
0
def _check_iterative_process_compatible_with_canonical_form(
    initialize_tree, next_tree):
  """Tests compatibility with `tff.backends.mapreduce.CanonicalForm`.

  Args:
    initialize_tree: An instance of `building_blocks.ComputationBuildingBlock`
      that maps to the `initialize` property of a `tff.utils.IterativeProcess`.
      Its type signature must be a `computation_types.FederatedType`.
    next_tree: An instance of `building_blocks.ComputationBuildingBlock` that
      maps to the `next` property of a `tff.utils.IterativeProcess`. Its type
      signature must be a `computation_types.FunctionType` whose parameter is
      a two-element `NamedTupleType` and whose result is a `NamedTupleType`.

  Raises:
    TypeError: If the arguments are of the wrong types, or if the result of
      `next_tree` does not have length 2 or 3.
  """
  py_typecheck.check_type(initialize_tree,
                          building_blocks.ComputationBuildingBlock)
  py_typecheck.check_type(initialize_tree.type_signature,
                          computation_types.FederatedType)
  py_typecheck.check_type(next_tree, building_blocks.ComputationBuildingBlock)
  py_typecheck.check_type(next_tree.type_signature,
                          computation_types.FunctionType)
  py_typecheck.check_type(next_tree.type_signature.parameter,
                          computation_types.NamedTupleType)
  py_typecheck.check_len(next_tree.type_signature.parameter, 2)
  py_typecheck.check_type(next_tree.type_signature.result,
                          computation_types.NamedTupleType)
  # The result may have 2 or 3 elements. A duplicate parameter-length check
  # that used to sit here was a copy-paste leftover and has been removed.
  next_result_len = len(next_tree.type_signature.result)
  if next_result_len != 2 and next_result_len != 3:
    raise TypeError(
        'Expected length of 2 or 3, found {}.'.format(next_result_len))
  async def _map(self, arg, all_equal=None):
    """Maps a function over a federated value spread across child executors.

    Each child executor is sent a `FEDERATED_MAP` intrinsic call that applies
    the function to that child's share of the value.

    Args:
      arg: A strategy value whose internal representation is a two-element
        `structure.Struct` `(fn, val)`. `fn` is a `pb.Computation` and `val`
        is a list of `ExecutorValue`s paired positionally with
        `self._target_executors`.
      all_equal: Optional override for the result's `all_equal` bit. When
        `None`, it is inherited from the input's federated type.

    Returns:
      A `FederatedComposingStrategyValue` with one mapped value per child.

    Raises:
      ValueError: If an all-equal result is requested for an input that is
        not all-equal.
    """
    internal = arg.internal_representation
    py_typecheck.check_type(internal, structure.Struct)
    py_typecheck.check_len(internal, 2)
    mapped_fn_type = arg.type_signature[0]
    py_typecheck.check_type(mapped_fn_type, computation_types.FunctionType)
    value_type = arg.type_signature[1]
    py_typecheck.check_type(value_type, computation_types.FederatedType)
    if all_equal is not None:
      if all_equal and not value_type.all_equal:
        raise ValueError(
            'Cannot map a non-all_equal argument into an all_equal result.')
    else:
      all_equal = value_type.all_equal
    fn_proto = internal[0]
    py_typecheck.check_type(fn_proto, pb.Computation)
    child_values = internal[1]
    py_typecheck.check_type(child_values, list)

    map_type = computation_types.FunctionType(
        [mapped_fn_type,
         computation_types.at_clients(mapped_fn_type.parameter)],
        computation_types.at_clients(mapped_fn_type.result))
    map_comp = executor_utils.create_intrinsic_comp(
        intrinsic_defs.FEDERATED_MAP, map_type)

    async def _map_on_child(child, child_value):
      py_typecheck.check_type(child_value, executor_value_base.ExecutorValue)
      embedded_fn = await child.create_value(fn_proto, mapped_fn_type)
      intrinsic, intrinsic_arg = await asyncio.gather(
          child.create_value(map_comp, map_type),
          child.create_struct([embedded_fn, child_value]))
      return await child.create_call(intrinsic, intrinsic_arg)

    mapped = await asyncio.gather(*[
        _map_on_child(child, child_value)
        for child, child_value in zip(self._target_executors, child_values)
    ])
    return FederatedComposingStrategyValue(
        mapped,
        computation_types.FederatedType(
            mapped_fn_type.result, value_type.placement, all_equal=all_equal))
def _check_iterative_process_compatible_with_canonical_form(
    initialize_tree, next_tree):
  """Tests compatibility with `tff.backends.mapreduce.CanonicalForm`.

  Args:
    initialize_tree: An instance of `building_blocks.ComputationBuildingBlock`
      representing the `initialize` component of a
      `tff.templates.IterativeProcess`. Its type is validated with
      `_check_type_is_no_arg_fn`. Its result must be a
      `computation_types.FederatedType` placed at `placements.SERVER`.
    next_tree: An instance of `building_blocks.ComputationBuildingBlock`
      representing the `next` component of a
      `tff.templates.IterativeProcess`. Its type must be a
      `computation_types.FunctionType` whose parameter is a two-element
      `StructType` and whose result is a two-element `StructType`.

  Raises:
    TypeError: If the arguments are of the wrong types, or if the result of
      `next_tree` does not have length 2.
  """
  py_typecheck.check_type(initialize_tree,
                          building_blocks.ComputationBuildingBlock)
  init_tree_ty = initialize_tree.type_signature
  _check_type_is_no_arg_fn(init_tree_ty, TypeError)
  _check_type(init_tree_ty.result, computation_types.FederatedType, TypeError)
  _check_placement(init_tree_ty.result, placements.SERVER, TypeError)
  py_typecheck.check_type(next_tree, building_blocks.ComputationBuildingBlock)
  py_typecheck.check_type(next_tree.type_signature,
                          computation_types.FunctionType)
  py_typecheck.check_type(next_tree.type_signature.parameter,
                          computation_types.StructType)
  py_typecheck.check_len(next_tree.type_signature.parameter, 2)
  py_typecheck.check_type(next_tree.type_signature.result,
                          computation_types.StructType)
  # A duplicate parameter-length check that used to sit here was a copy-paste
  # leftover; the result length is validated explicitly below.
  next_result_len = len(next_tree.type_signature.result)
  if next_result_len != 2:
    raise TypeError('Expected length of 2, found {}.'.format(next_result_len))
  async def compute_federated_secure_sum(
      self, arg: federated_resolving_strategy.FederatedResolvingStrategyValue
  ) -> federated_resolving_strategy.FederatedResolvingStrategyValue:
    """Computes a federated secure sum without any cryptography.

    The members are summed with `compute_federated_sum`. The sum is then
    passed with the constant `2**bitwidth - 1` to a `tf.math.mod` binary
    operator, applied on the server via `compute_federated_map_all_equal`. A
    warning is logged that no cryptography is involved.

    Args:
      arg: A strategy value whose internal representation is a two-element
        `structure.Struct` `(value, bitwidth)`. `bitwidth` is an embedded value
        whose computed result is used as the exponent in `2**bitwidth - 1`.

    Returns:
      The server-placed result of applying the modulus operator to the sum.
    """
    components = arg.internal_representation
    py_typecheck.check_type(components, structure.Struct)
    py_typecheck.check_len(components, 2)
    logging.warning(
        'The implementation of the `tff.federated_secure_sum` intrinsic '
        'provided by the `tff.backends.test` runtime uses no cryptography.')

    server_executor = self._target_executors[placement_literals.SERVER][0]
    summand = federated_resolving_strategy.FederatedResolvingStrategyValue(
        components[0], arg.type_signature[0])
    bitwidth_value = await components[1].compute()
    bitwidth_type = arg.type_signature[1]
    # NOTE(review): this value is used like a bit mask, but it is combined with
    # `tf.math.mod`. For non-negative integers, `x mod (2**b - 1)` differs from
    # `x mod 2**b` (== `x & (2**b - 1)`). Confirm which semantics is intended.
    modulus = 2**bitwidth_value - 1
    total, modulus_value, mod_fn = await asyncio.gather(
        self.compute_federated_sum(summand),
        executor_utils.embed_tf_constant(self._executor, bitwidth_type,
                                         modulus),
        executor_utils.embed_tf_binary_operator(self._executor, bitwidth_type,
                                                tf.math.mod))
    mod_operands = await server_executor.create_struct([
        total.internal_representation[0],
        modulus_value.internal_representation,
    ])
    mod_operands_type = computation_types.FederatedType(
        mod_operands.type_signature, placement_literals.SERVER, all_equal=True)
    map_arg = federated_resolving_strategy.FederatedResolvingStrategyValue(
        structure.Struct([
            (None, mod_fn.internal_representation),
            (None, [mod_operands]),
        ]),
        computation_types.StructType([mod_fn.type_signature,
                                      mod_operands_type]))
    return await self.compute_federated_map_all_equal(map_arg)
Esempio n. 7
0
    async def compute_federated_aggregate(
        self, arg: FederatedResolvingStrategyValue
    ) -> FederatedResolvingStrategyValue:
        """Computes `federated_aggregate` by reducing on a single executor.

        `merge` is discarded because every member is reduced in one place. The
        reduced value's member type is checked against `report`'s parameter
        type, and `report` is then applied via `compute_federated_apply`.

        Args:
          arg: A strategy value whose internal representation is a 5-element
            `structure.Struct` `(value, zero, accumulate, merge, report)`.

        Returns:
          The result of `compute_federated_apply` on `(report, reduced value)`.
        """
        val_type, zero_type, accumulate_type, merge_type, report_type = (
            executor_utils.parse_federated_aggregate_argument_types(
                arg.type_signature))
        del val_type, merge_type
        py_typecheck.check_type(arg.internal_representation, structure.Struct)
        py_typecheck.check_len(arg.internal_representation, 5)
        val, zero, accumulate, merge, report = arg.internal_representation

        # Discard `merge`. Since all aggregation happens on a single executor,
        # there's no need for this additional layer.
        del merge

        # Re-wrap `zero` in a `FederatedResolvingStrategyValue` so that `reduce`
        # receives an `ExecutorValue` rather than a raw internal representation.
        # The internal representation can be an embedded value, a list of
        # embedded values (for federated values), or a `Struct`.
        zero = FederatedResolvingStrategyValue(zero, zero_type)
        pre_report = await self.reduce(val, zero, accumulate, accumulate_type)

        py_typecheck.check_type(pre_report.type_signature,
                                computation_types.FederatedType)
        pre_report.type_signature.member.check_equivalent_to(
            report_type.parameter)

        return await self.compute_federated_apply(
            FederatedResolvingStrategyValue(
                structure.Struct([(None, report),
                                  (None, pre_report.internal_representation)]),
                computation_types.StructType(
                    (report_type, pre_report.type_signature))))
Esempio n. 8
0
 async def _map(self, arg, all_equal=None):
   """Maps a function over a federated value, one call per child executor.

   Args:
     arg: An executor value validated by `_check_arg_is_anonymous_tuple`
       whose internal representation is `(fn, val)`. `fn` is a
       `pb.Computation` of type `arg.type_signature[0]`, and `val` is a list
       of `ExecutorValue`s paired positionally with the child executors for
       the value's placement.
     all_equal: Optional override for the result's `all_equal` bit. When
       `None`, the bit is inherited from the input's federated type.

   Returns:
     A `FederatingExecutorValue` holding one call result per child executor.

   Raises:
     ValueError: If an all-equal result is requested for an input that is not
       all-equal.
   """
   self._check_arg_is_anonymous_tuple(arg)
   py_typecheck.check_len(arg.internal_representation, 2)
   function_type = arg.type_signature[0]
   py_typecheck.check_type(function_type, computation_types.FunctionType)
   federated_type = arg.type_signature[1]
   py_typecheck.check_type(federated_type, computation_types.FederatedType)
   if all_equal is not None:
     if all_equal and not federated_type.all_equal:
       raise ValueError(
           'Cannot map a non-all_equal argument into an all_equal result.')
   else:
     all_equal = federated_type.all_equal
   function_proto = arg.internal_representation[0]
   py_typecheck.check_type(function_proto, pb.Computation)
   members = arg.internal_representation[1]
   py_typecheck.check_type(members, list)
   for member in members:
     py_typecheck.check_type(member, executor_value_base.ExecutorValue)
   executors = self._target_executors[federated_type.placement]
   # Embed the function in every child first, then call each embedded copy
   # on the member that lives in the same child.
   embedded_fns = await asyncio.gather(
       *[executor.create_value(function_proto, function_type)
         for executor in executors])
   outputs = await asyncio.gather(*[
       executor.create_call(embedded_fn, member_value)
       for executor, embedded_fn, member_value in zip(executors, embedded_fns,
                                                      members)
   ])
   return FederatingExecutorValue(
       outputs,
       computation_types.FederatedType(
           function_type.result, federated_type.placement,
           all_equal=all_equal))
Esempio n. 9
0
  async def compute_federated_aggregate(
      self,
      arg: FederatedResolvingStrategyValue) -> FederatedResolvingStrategyValue:
    """Computes `federated_aggregate` by reducing on a single executor.

    `merge` is dropped because every member is reduced in one place. The
    reduced value's member type must be equivalent to `report`'s parameter
    type, and `report` is then applied via `compute_federated_apply`.

    Args:
      arg: A strategy value whose internal representation is a 5-element
        `structure.Struct` `(value, zero, accumulate, merge, report)`.

    Returns:
      The result of `compute_federated_apply` on `(report, reduced value)`.
    """
    (value_type, zero_type, accumulate_type, merge_type,
     report_type) = executor_utils.parse_federated_aggregate_argument_types(
         arg.type_signature)
    del value_type, merge_type
    components = arg.internal_representation
    py_typecheck.check_type(components, structure.Struct)
    py_typecheck.check_len(components, 5)
    values, zero, accumulate, merge, report = components

    # `merge` combines partial aggregates held by different executors. Here
    # everything is reduced on one executor, so it is not needed.
    del merge

    # `zero`'s internal representation may be an embedded value, a list of
    # embedded values, or a `Struct`; wrapping it guarantees `reduce` gets an
    # `ExecutorValue`.
    zero_value = FederatedResolvingStrategyValue(zero, zero_type)
    reduced = await self.reduce(values, zero_value, accumulate,
                                accumulate_type)

    reduced_type = reduced.type_signature
    py_typecheck.check_type(reduced_type, computation_types.FederatedType)
    reduced_type.member.check_equivalent_to(report_type.parameter)

    apply_arg = FederatedResolvingStrategyValue(
        structure.Struct([(None, report),
                          (None, reduced.internal_representation)]),
        computation_types.StructType((report_type, reduced_type)))
    return await self.compute_federated_apply(apply_arg)
Esempio n. 10
0
  async def _compute_intrinsic_federated_map(self, arg):
    """Computes a CLIENTS-placed `federated_map` across child executors.

    Each child executor is sent a `FEDERATED_MAP` intrinsic call over its
    share of the clients' values.

    Args:
      arg: A value whose internal representation is an
        `anonymous_tuple.AnonymousTuple` `(fn, val)`. `fn` is a
        `pb.Computation`, and `val` is a list of `ExecutorValue`s paired
        positionally with `self._child_executors`.

    Returns:
      A `CompositeValue` holding one mapped value per child, typed as `fn`'s
      result placed at CLIENTS.
    """
    py_typecheck.check_type(arg.internal_representation,
                            anonymous_tuple.AnonymousTuple)
    py_typecheck.check_len(arg.internal_representation, 2)
    mapped_fn_type = arg.type_signature[0]
    py_typecheck.check_type(mapped_fn_type, computation_types.FunctionType)
    type_utils.check_federated_type(arg.type_signature[1],
                                    mapped_fn_type.parameter,
                                    placement_literals.CLIENTS)
    fn_proto, child_values = (arg.internal_representation[0],
                              arg.internal_representation[1])
    py_typecheck.check_type(fn_proto, pb.Computation)
    py_typecheck.check_type(child_values, list)

    clients_map_type = computation_types.FunctionType(
        [mapped_fn_type, type_factory.at_clients(mapped_fn_type.parameter)],
        type_factory.at_clients(mapped_fn_type.result))
    clients_map_comp = executor_utils.create_intrinsic_comp(
        intrinsic_defs.FEDERATED_MAP, clients_map_type)

    async def _map_on(child, child_value):
      py_typecheck.check_type(child_value, executor_value_base.ExecutorValue)
      embedded_fn = await child.create_value(fn_proto, mapped_fn_type)
      intrinsic, intrinsic_arg = await asyncio.gather(
          child.create_value(clients_map_comp, clients_map_type),
          child.create_tuple([embedded_fn, child_value]))
      return await child.create_call(intrinsic, intrinsic_arg)

    per_child = await asyncio.gather(*[
        _map_on(child, child_value)
        for child, child_value in zip(self._child_executors, child_values)
    ])
    return CompositeValue(per_child,
                          type_factory.at_clients(mapped_fn_type.result))
Esempio n. 11
0
def parse_federated_aggregate_argument_types(type_spec):
    """Verifies and splits the argument type of a `federated_aggregate`.

    The five constituents are checked against each other:

    - `accumulate` must be equivalent to a reduction op built from the zero
      type and the value's member type.
    - `merge` must be equivalent to a binary op on the zero type.
    - `report` must be a function whose parameter is equivalent to the zero
      type.

    Args:
      type_spec: An instance of `computation_types.StructType` with exactly
        five elements; the first must be a `computation_types.FederatedType`.

    Returns:
      A tuple of (value_type, zero_type, accumulate_type, merge_type,
      report_type) for the 5 type constituents.
    """
    py_typecheck.check_type(type_spec, computation_types.StructType)
    py_typecheck.check_len(type_spec, 5)
    constituents = tuple(type_spec[index] for index in range(5))
    value_type, zero_type, accumulate_type, merge_type, report_type = (
        constituents)
    py_typecheck.check_type(value_type, computation_types.FederatedType)
    expected_accumulate = type_factory.reduction_op(zero_type,
                                                    value_type.member)
    accumulate_type.check_equivalent_to(expected_accumulate)
    merge_type.check_equivalent_to(type_factory.binary_op(zero_type))
    py_typecheck.check_type(report_type, computation_types.FunctionType)
    report_type.parameter.check_equivalent_to(zero_type)
    return constituents
 async def compute_federated_value(
     self, value: Any, type_signature: computation_types.Type
 ) -> FederatedComposingStrategyValue:
   """Embeds `value`, of federated type `type_signature`, in the children.

   - SERVER values must be all-equal and are created on the server executor.
   - All-equal CLIENTS values are created on every target executor.
   - Other CLIENTS values must be a list with one entry per client. The list
     is sliced across target executors according to `_get_cardinalities()`.

   Args:
     value: The value to embed.
     type_signature: A federated type describing `value`.

   Returns:
     A `FederatedComposingStrategyValue` wrapping the embedded value(s).

   Raises:
     ValueError: For a non-all-equal SERVER value or an unexpected placement.
   """
   placement = type_signature.placement
   if placement == placement_literals.SERVER:
     if not type_signature.all_equal:
       raise ValueError(
           'Expected an all equal value at the `SERVER` placement, '
           'found {}.'.format(type_signature))
     embedded = await self._server_executor.create_value(
         value, type_signature.member)
     return FederatedComposingStrategyValue(embedded, type_signature)
   if placement == placement_literals.CLIENTS:
     if type_signature.all_equal:
       broadcast = await asyncio.gather(*[
           child.create_value(value, type_signature)
           for child in self._target_executors
       ])
       return FederatedComposingStrategyValue(broadcast, type_signature)
     py_typecheck.check_type(value, list)
     cardinalities = await self._get_cardinalities()
     py_typecheck.check_len(value, sum(cardinalities))
     pending = []
     start = 0
     for child, count in zip(self._target_executors, cardinalities):
       stop = start + count
       pending.append(child.create_value(value[start:stop], type_signature))
       start = stop
     return FederatedComposingStrategyValue(await asyncio.gather(*pending),
                                            type_signature)
   raise ValueError('Unexpected placement {}.'.format(placement))
Esempio n. 13
0
 async def _encrypt_values_on_clients(self, val, sender, receiver):
     """Encrypts a sender-placed federated value for a single receiver.

     Args:
       val: A federated value placed at `sender` whose internal representation
         has one entry per sender child executor.
       sender: The placement holding `val`, the sender's secret key, and the
         receiver's public key.
       receiver: The placement whose public key is used; it must map to
         exactly one child executor.

     Returns:
       A `FederatedResolvingStrategyValue` with one encrypted value per sender
       child, placed at `sender` with `val`'s `all_equal` bit.

     Side effects:
       Records `val`'s member type in `self._input_type_cache`.
     """
     ###
     # Case 2: sender=CLIENTS
     #     plaintext: Fed(Tensor, CLIENTS, all_equal=False)
     #     pk_receiver: Fed(Tensor, CLIENTS, all_equal=True)
     #     sk_sender: Fed(Tensor, CLIENTS, all_equal=False)
     #   Returns:
     #     encrypted_values: Fed(Tensor, CLIENTS, all_equal=False)
     ###
     ### Check proper key placement
     sk_sender = self.key_references.get_secret_key(sender)
     pk_receiver = self.key_references.get_public_key(receiver)
     # Use explicit checks rather than `assert`, which is stripped when Python
     # runs with `-O`. Both keys must be available at the sender.
     type_analysis.check_federated_type(sk_sender.type_signature,
                                        placement=sender)
     type_analysis.check_federated_type(pk_receiver.type_signature,
                                        placement=sender)
     ### Check placement cardinalities
     snd_children = self.strategy._get_child_executors(sender)
     rcv_children = self.strategy._get_child_executors(receiver)
     py_typecheck.check_len(rcv_children, 1)
     ### Check value cardinalities
     type_analysis.check_federated_type(val.type_signature,
                                        placement=sender)
     federated_value_internals = [
         val.internal_representation, pk_receiver.internal_representation,
         sk_sender.internal_representation
     ]
     for v in federated_value_internals:
         py_typecheck.check_len(v, len(snd_children))
     ### Materialize encryptor function definition & type spec
     input_type = val.type_signature.member
     self._input_type_cache = input_type
     pk_rcv_type = pk_receiver.type_signature.member
     sk_snd_type = sk_sender.type_signature.member
     pk_element_type = pk_rcv_type
     encryptor_arg_spec = (input_type, pk_element_type, sk_snd_type)
     encryptor_proto, encryptor_type = utils.materialize_computation_from_cache(
         sodium_comp.make_encryptor, self._encryptor_cache,
         encryptor_arg_spec)
     ### Encrypt values and return them
     encryptor_fns = asyncio.gather(*[
         snd_child.create_value(encryptor_proto, encryptor_type)
         for snd_child in snd_children
     ])
     encryptor_args = asyncio.gather(*[
         snd_child.create_struct([v, pk, sk]) for v, pk, sk, snd_child in
         zip(*federated_value_internals, snd_children)
     ])
     encryptor_fns, encryptor_args = await asyncio.gather(
         encryptor_fns, encryptor_args)
     encrypted_values = [
         snd_child.create_call(encryptor, arg) for encryptor, arg, snd_child
         in zip(encryptor_fns, encryptor_args, snd_children)
     ]
     return federated_resolving_strategy.FederatedResolvingStrategyValue(
         await asyncio.gather(*encrypted_values),
         tff.FederatedType(encryptor_type.result,
                           sender,
                           all_equal=val.type_signature.all_equal))
    async def compute_federated_aggregate(
        self, arg: FederatedComposingStrategyValue
    ) -> FederatedComposingStrategyValue:
        """Aggregates a federated value whose members live on child executors.

        Each child runs a `FEDERATED_AGGREGATE` intrinsic with an identity
        `report`, yielding its partial aggregate. All partials are embedded in
        the server executor and folded left-to-right (in child order) with
        `merge`. The real `report` is then applied on the server.

        Args:
          arg: A strategy value whose internal representation is a 5-element
            `anonymous_tuple.AnonymousTuple`
            `(value, zero, accumulate, merge, report)`. `value` is a list with
            one entry per target executor.

        Returns:
          A `FederatedComposingStrategyValue` placed at the server with
          `report`'s result type.
        """
        value_type, zero_type, accumulate_type, merge_type, report_type = (
            executor_utils.parse_federated_aggregate_argument_types(
                arg.type_signature))
        components = arg.internal_representation
        py_typecheck.check_type(components, anonymous_tuple.AnonymousTuple)
        py_typecheck.check_len(components, 5)
        child_values = components[0]
        py_typecheck.check_type(child_values, list)
        py_typecheck.check_len(child_values, len(self._target_executors))
        identity_comp = tensorflow_computation_factory.create_identity(
            zero_type)
        identity_type = type_factory.unary_op(zero_type)
        child_aggregate_type = computation_types.FunctionType(
            computation_types.NamedTupleType([
                value_type, zero_type, accumulate_type, merge_type,
                identity_type
            ]), type_factory.at_server(zero_type))
        child_aggregate_comp = executor_utils.create_intrinsic_comp(
            intrinsic_defs.FEDERATED_AGGREGATE, child_aggregate_type)
        zero_selection = await self._executor.create_selection(arg, index=1)
        zero = await zero_selection.compute()
        accumulate = components[2]
        merge = components[3]
        report = components[4]

        async def _partial_aggregate(child, child_value):
            py_typecheck.check_type(child_value,
                                    executor_value_base.ExecutorValue)
            operands = await asyncio.gather(
                child.create_value(zero, zero_type),
                child.create_value(accumulate, accumulate_type),
                child.create_value(merge, merge_type),
                child.create_value(identity_comp, identity_type))
            aggregate_fn, aggregate_arg = await asyncio.gather(
                child.create_value(child_aggregate_comp, child_aggregate_type),
                child.create_tuple([child_value] + list(operands)))
            aggregate_result = await child.create_call(aggregate_fn,
                                                       aggregate_arg)
            return await aggregate_result.compute()

        partials = await asyncio.gather(*[
            _partial_aggregate(child, child_value)
            for child, child_value in zip(self._target_executors, child_values)
        ])
        server = self._server_executor
        partials_at_server = await asyncio.gather(
            *[server.create_value(partial, zero_type) for partial in partials])
        server_merge, server_report = await asyncio.gather(
            server.create_value(merge, merge_type),
            server.create_value(report, report_type))
        accumulated = partials_at_server[0]
        for partial in partials_at_server[1:]:
            pair = await server.create_tuple([accumulated, partial])
            accumulated = await server.create_call(server_merge, pair)
        reported = await server.create_call(server_report, accumulated)
        return FederatedComposingStrategyValue(
            reported, type_factory.at_server(report_type.result))
Esempio n. 15
0
 async def _decrypt_values_on_clients(self, val, sender, receiver):
     """Decrypts a federated value on every child executor of `receiver`.

     The sender's public key and the receiver's secret key must both be placed
     at `receiver`. `val` and both keys' internal representations must have
     one entry per receiver child executor. The decryptor is obtained from
     `sodium_comp.make_decryptor` through `self._decryptor_cache`, using
     `self._input_type_cache.dtype`.
     NOTE(review): `_input_type_cache` must already be set, presumably by a
     prior encryption call; confirm.

     Returns:
       A `FederatedResolvingStrategyValue` with one decrypted value per
       receiver child, placed at `receiver` with `val`'s `all_equal` bit.
     """
     ### Check proper key placement
     pk_sender = self.key_references.get_public_key(sender)
     sk_receiver = self.key_references.get_secret_key(receiver)
     for key in (pk_sender, sk_receiver):
         type_analysis.check_federated_type(key.type_signature,
                                            placement=receiver)
     pk_snd_type = pk_sender.type_signature.member
     sk_rcv_type = sk_receiver.type_signature.member
     ### Check value cardinalities
     rcv_children = self.strategy._get_child_executors(receiver)
     per_child_inputs = (val.internal_representation,
                         pk_sender.internal_representation,
                         sk_receiver.internal_representation)
     num_children = len(rcv_children)
     for internals in per_child_inputs:
         py_typecheck.check_len(internals, num_children)
     ### Materialize decryptor type_spec & function definition
     # The member type is the whole tuple consumed to decrypt one value; its
     # first element must be a tensor type.
     ciphertext_type = val.type_signature.member
     py_typecheck.check_type(ciphertext_type[0], tff.TensorType)
     py_typecheck.check_type(pk_snd_type, tff.TensorType)
     decryptor_proto, decryptor_type = utils.materialize_computation_from_cache(
         sodium_comp.make_decryptor,
         self._decryptor_cache,
         (ciphertext_type, pk_snd_type, sk_rcv_type),
         orig_tensor_dtype=self._input_type_cache.dtype)
     ### Decrypt values and return them
     fns_future = asyncio.gather(*[
         child.create_value(decryptor_proto, decryptor_type)
         for child in rcv_children
     ])
     args_future = asyncio.gather(*[
         child.create_struct([ciphertext, pk, sk])
         for ciphertext, pk, sk, child in zip(*per_child_inputs, rcv_children)
     ])
     embedded_fns, embedded_args = await asyncio.gather(fns_future,
                                                        args_future)
     decrypted = await asyncio.gather(*[
         child.create_call(embedded_fn, embedded_arg)
         for embedded_fn, embedded_arg, child in zip(
             embedded_fns, embedded_args, rcv_children)
     ])
     result_type = tff.FederatedType(decryptor_type.result,
                                     receiver,
                                     all_equal=val.type_signature.all_equal)
     return federated_resolving_strategy.FederatedResolvingStrategyValue(
         decrypted, result_type)
Esempio n. 16
0
    async def compute_federated_select(
        self, arg: FederatedComposingStrategyValue
    ) -> FederatedComposingStrategyValue:
        """Runs `federated_select` on every child executor.

        The server value and the maximum key are computed once. They are then
        re-embedded in each child together with `select_fn`, and each child
        runs a `FEDERATED_SELECT` intrinsic on its own slice of client keys.

        Args:
          arg: A strategy value whose internal representation is a
            `structure.Struct` `(client_keys, max_key, server_val, select_fn)`.
            `client_keys` is a list with one entry per target executor,
            `max_key` and `server_val` are `ExecutorValue`s, and `select_fn`
            is a `pb.Computation`.

        Returns:
          A `FederatedComposingStrategyValue` with one result per child, typed
          as a CLIENTS-placed sequence of `select_fn`'s result type.
        """
        _, max_key_type, server_val_type, select_fn_type = arg.type_signature
        components = arg.internal_representation
        py_typecheck.check_type(components, structure.Struct)
        client_keys, max_key, server_val, select_fn = components
        py_typecheck.check_type(client_keys, list)
        py_typecheck.check_len(client_keys, len(self._target_executors))
        py_typecheck.check_type(max_key, executor_value_base.ExecutorValue)
        py_typecheck.check_type(server_val, executor_value_base.ExecutorValue)
        py_typecheck.check_type(select_fn, pb.Computation)
        server_val_plain, max_key_plain = await asyncio.gather(
            server_val.compute(), max_key.compute())

        def _clients_sequence_type():
            return computation_types.at_clients(
                computation_types.SequenceType(select_fn_type.result))

        select_type = computation_types.FunctionType(arg.type_signature,
                                                     _clients_sequence_type())
        select_pb = executor_utils.create_intrinsic_comp(
            intrinsic_defs.FEDERATED_SELECT, select_type)

        async def _select_on_child(child, keys_for_child):
            embedded_max_key, embedded_server_val, embedded_select_fn = (
                await asyncio.gather(
                    child.create_value(max_key_plain, max_key_type),
                    child.create_value(server_val_plain, server_val_type),
                    child.create_value(select_fn, select_fn_type)))
            select_arg = structure.Struct([
                (None, keys_for_child),
                (None, embedded_max_key),
                (None, embedded_server_val),
                (None, embedded_select_fn),
            ])
            embedded_select, embedded_arg = await asyncio.gather(
                child.create_value(select_pb, select_type),
                child.create_struct(select_arg))
            return await child.create_call(embedded_select, embedded_arg)

        per_child_results = await asyncio.gather(*[
            _select_on_child(child, keys_for_child)
            for child, keys_for_child in zip(self._target_executors,
                                             client_keys)
        ])
        return FederatedComposingStrategyValue(list(per_child_results),
                                               _clients_sequence_type())
Esempio n. 17
0
 async def _compute_intrinsic_federated_apply(self, arg):
   """Applies a function to an all-equal SERVER value on the parent executor.

   Args:
     arg: A value whose internal representation is an
       `anonymous_tuple.AnonymousTuple` `(fn, val)`. `fn` is a
       `pb.Computation` and `val` is an `ExecutorValue`. `val`'s type must be
       an all-equal SERVER-placed value of `fn`'s parameter type.

   Returns:
     A `CompositeValue` wrapping the call result, typed as `fn`'s result at
     the server.
   """
   py_typecheck.check_type(arg.internal_representation,
                           anonymous_tuple.AnonymousTuple)
   py_typecheck.check_len(arg.internal_representation, 2)
   fn_type = arg.type_signature[0]
   py_typecheck.check_type(fn_type, computation_types.FunctionType)
   type_utils.check_federated_type(
       arg.type_signature[1], fn_type.parameter, placement_literals.SERVER,
       all_equal=True)
   fn_proto, server_value = (arg.internal_representation[0],
                             arg.internal_representation[1])
   py_typecheck.check_type(fn_proto, pb.Computation)
   py_typecheck.check_type(server_value, executor_value_base.ExecutorValue)
   parent = self._parent_executor
   embedded_fn = await parent.create_value(fn_proto, fn_type)
   call_result = await parent.create_call(embedded_fn, server_value)
   return CompositeValue(call_result, type_factory.at_server(fn_type.result))
Esempio n. 18
0
 async def _decrypt_values_on_singleton(self, val, sender, receiver):
     """Decrypts a structure of values on the single executor of `receiver`.

     Each element of `val` is decrypted on the sole child executor of
     `receiver`. Decryption uses the matching public key of `sender`, which
     must already be placed at `receiver`, together with the receiver's
     secret key.

     Args:
       val: A value whose type is a `tff.StructType`. Its first element must
         be an all-equal federated type at `receiver`, and its length must
         equal the number of elements in the sender's public-key type.
       sender: The placement whose public keys are used.
       receiver: The placement holding the secret key. It must be backed by
         exactly one child executor.

     Returns:
       A `FederatedResolvingStrategyValue` holding one decrypted, all-equal
       value at `receiver` per decrypted element.
     """
     # Both keys must be present at the receiving placement.
     sender_pk = self.key_references.get_public_key(sender)
     receiver_sk = self.key_references.get_secret_key(receiver)
     for key in (sender_pk, receiver_sk):
         type_analysis.check_federated_type(key.type_signature,
                                            placement=receiver)
     sender_pk_type = sender_pk.type_signature.member
     receiver_sk_type = receiver_sk.type_signature.member

     # The receiver must be a singleton. One public key is expected per
     # sender executor, and exactly one secret key is expected.
     sender_executors = self.strategy._get_child_executors(sender)
     receiver_executors = self.strategy._get_child_executors(receiver)
     py_typecheck.check_len(receiver_executors, 1)
     target = receiver_executors[0]
     py_typecheck.check_len(sender_pk.internal_representation,
                            len(sender_executors))
     py_typecheck.check_len(receiver_sk.internal_representation, 1)

     # Validate the ciphertext structure and materialize the decryptor.
     val_sig = val.type_signature
     py_typecheck.check_type(val_sig, tff.StructType)
     type_analysis.check_federated_type(val_sig[0],
                                        placement=receiver,
                                        all_equal=True)
     ciphertext_type = val_sig[0].member
     # The member type is a struct carrying what is needed to decrypt one
     # value.
     py_typecheck.check_type(ciphertext_type, tff.StructType)
     py_typecheck.check_type(sender_pk_type, tff.StructType)
     py_typecheck.check_len(val_sig, len(sender_pk_type))
     decryptor_proto, decryptor_type = utils.materialize_computation_from_cache(
         sodium_comp.make_decryptor,
         self._decryptor_cache,
         (ciphertext_type, sender_pk_type[0], receiver_sk_type),
         orig_tensor_dtype=self._input_type_cache.dtype)

     # Decrypt every element on the receiver's executor.
     ciphertexts = val.internal_representation
     secret_key = receiver_sk.internal_representation[0]
     decryptor_fn = await target.create_value(decryptor_proto, decryptor_type)
     call_args = await asyncio.gather(*[
         target.create_struct([ciphertext, public_key, secret_key])
         for ciphertext, public_key in zip(ciphertexts,
                                           sender_pk.internal_representation)
     ])
     plaintexts = await asyncio.gather(
         *[target.create_call(decryptor_fn, a) for a in call_args])
     member_type = decryptor_type.result
     return federated_resolving_strategy.FederatedResolvingStrategyValue(
         structure.from_container(plaintexts),
         tff.StructType([
             tff.FederatedType(member_type, receiver, all_equal=True)
             for _ in plaintexts
         ]))
 async def compute_federated_apply(
     self,
     arg: FederatedComposingStrategyValue) -> FederatedComposingStrategyValue:
   """Runs a function on a `SERVER`-placed value in the server executor.

   Args:
     arg: A value whose internal representation is a two-element
       `structure.Struct` of (`pb.Computation`, `ExecutorValue`). Its type
       pairs a `FunctionType` with an all-equal federated type at `SERVER`
       compatible with the function's parameter.

   Returns:
     A `FederatedComposingStrategyValue` holding the call result, typed as
     the function's result placed at `SERVER`.
   """
   parts = arg.internal_representation
   py_typecheck.check_type(parts, structure.Struct)
   py_typecheck.check_len(parts, 2)
   fn_type = arg.type_signature[0]
   py_typecheck.check_type(fn_type, computation_types.FunctionType)
   type_analysis.check_federated_type(
       arg.type_signature[1],
       fn_type.parameter,
       placement_literals.SERVER,
       all_equal=True)
   fn_proto, operand = parts[0], parts[1]
   py_typecheck.check_type(fn_proto, pb.Computation)
   py_typecheck.check_type(operand, executor_value_base.ExecutorValue)
   server = self._server_executor
   fn_value = await server.create_value(fn_proto, fn_type)
   call_result = await server.create_call(fn_value, operand)
   return FederatedComposingStrategyValue(
       call_result, computation_types.at_server(fn_type.result))
 async def compute_federated_zip_at_server(
     self,
     arg: FederatedComposingStrategyValue) -> FederatedComposingStrategyValue:
   """Zips two all-equal `SERVER` values into one `SERVER`-placed struct.

   Args:
     arg: A value typed as a two-element `StructType` of all-equal federated
       types at `SERVER`. Its internal representation is a two-element
       `structure.Struct`.

   Returns:
     A `FederatedComposingStrategyValue` wrapping a struct built in the
     server executor, typed as the struct of both member types at `SERVER`.
   """
   signature = arg.type_signature
   members = arg.internal_representation
   py_typecheck.check_type(signature, computation_types.StructType)
   py_typecheck.check_len(signature, 2)
   py_typecheck.check_type(members, structure.Struct)
   py_typecheck.check_len(members, 2)
   for index in range(2):
     type_analysis.check_federated_type(
         signature[index],
         placement=placement_literals.SERVER,
         all_equal=True)
   zipped = await self._server_executor.create_struct(
       [members[0], members[1]])
   zipped_type = computation_types.StructType(
       [signature[0].member, signature[1].member])
   return FederatedComposingStrategyValue(
       zipped, computation_types.at_server(zipped_type))
Esempio n. 21
0
 async def _compute_intrinsic_federated_zip_at_server(self, arg):
   """Zips two all-equal `SERVER` values into one `SERVER`-placed tuple.

   Args:
     arg: A value typed as a two-element `NamedTupleType` of all-equal
       federated types at `SERVER`. Its internal representation is a
       two-element `AnonymousTuple`.

   Returns:
     A `CompositeValue` wrapping a tuple created in the parent executor,
     typed as the tuple of both member types placed at `SERVER`.
   """
   sig = arg.type_signature
   rep = arg.internal_representation
   py_typecheck.check_type(sig, computation_types.NamedTupleType)
   py_typecheck.check_len(sig, 2)
   py_typecheck.check_type(rep, anonymous_tuple.AnonymousTuple)
   py_typecheck.check_len(rep, 2)
   first_type, second_type = sig[0], sig[1]
   for element_type in (first_type, second_type):
     type_utils.check_federated_type(
         element_type, placement=placement_literals.SERVER, all_equal=True)
   pair = await self._parent_executor.create_tuple([rep[0], rep[1]])
   pair_type = computation_types.NamedTupleType(
       [first_type.member, second_type.member])
   return CompositeValue(pair, type_factory.at_server(pair_type))
Esempio n. 22
0
def _check_key_inputter(fn_value):
    """Validates the type signature of a `key_inputter` computation.

    The computation must return two elements: an encryption key that is a
    `tff.TensorType`, and a decryption key that is a two-element
    `tff.StructType` of `tff.TensorType`s.

    Args:
      fn_value: A value whose `type_signature` must be a `tff.FunctionType`.

    Raises:
      ValueError: If the result, or the decryption key within it, does not
        have exactly two elements.
      Type mismatches are reported by `py_typecheck.check_type`.
    """
    fn_type = fn_value.type_signature
    py_typecheck.check_type(fn_type, tff.FunctionType)
    try:
        py_typecheck.check_len(fn_type.result, 2)
    except ValueError as e:
        raise ValueError('Expected 2 elements in the output of key_inputter, '
                         'found {}.'.format(len(fn_type.result))) from e
    ek_type, dk_type = fn_type.result
    py_typecheck.check_type(ek_type, tff.TensorType)
    py_typecheck.check_type(dk_type, tff.StructType)
    try:
        py_typecheck.check_len(dk_type, 2)
    except ValueError as e:
        # Report the size of the decryption key itself. The whole result is
        # known to have exactly 2 elements at this point.
        raise ValueError(
            'Expected a two element tuple for the decryption key from '
            'key_inputter, found {} elements.'.format(len(dk_type))) from e
    py_typecheck.check_type(dk_type[0], tff.TensorType)
    py_typecheck.check_type(dk_type[1], tff.TensorType)
    async def compute_federated_aggregate(
        self, arg: FederatedResolvingStrategyValue
    ) -> FederatedResolvingStrategyValue:
        """Computes `federated_aggregate` via `federated_reduce` and `apply`.

        Args:
          arg: A value whose internal representation is a five-element
            `AnonymousTuple` (value, zero, accumulate, merge, report), with a
            type signature accepted by
            `executor_utils.parse_federated_aggregate_argument_types`.

        Returns:
          The result of `compute_federated_apply` running `report` on the
          reduced value.
        """
        elements = arg.internal_representation
        value_type, zero_type, accumulate_type, _, report_type = (
            executor_utils.parse_federated_aggregate_argument_types(
                arg.type_signature))
        py_typecheck.check_type(elements, anonymous_tuple.AnonymousTuple)
        py_typecheck.check_len(elements, 5)

        # This is a simple initial implementation that forwards to
        # `federated_reduce()` and ignores `merge` (element 3). A more complete
        # implementation could use the parallelism afforded by `merge` to cut
        # the cost from linear (in the number of clients) to sub-linear.
        # TODO(b/134543154): Expand this implementation to take advantage of
        # the parallelism afforded by `merge`.
        reduce_arg = FederatedResolvingStrategyValue(
            anonymous_tuple.AnonymousTuple([(None, elements[0]),
                                            (None, elements[1]),
                                            (None, elements[2])]),
            computation_types.NamedTupleType(
                (value_type, zero_type, accumulate_type)))
        reduced = await self.compute_federated_reduce(reduce_arg)

        # The reduced value must be federated and feedable into `report`.
        py_typecheck.check_type(reduced.type_signature,
                                computation_types.FederatedType)
        reduced.type_signature.member.check_equivalent_to(
            report_type.parameter)

        report_fn = elements[4]
        apply_arg = FederatedResolvingStrategyValue(
            anonymous_tuple.AnonymousTuple([
                (None, report_fn), (None, reduced.internal_representation)
            ]),
            computation_types.NamedTupleType(
                (report_type, reduced.type_signature)))
        return await self.compute_federated_apply(apply_arg)
Esempio n. 24
0
 async def compute_federated_secure_sum(self, arg):
     """Sums a `tff.CLIENTS` tensor using Paillier encryption.

     The Paillier setup runs once, on first use. The pipeline is:
       1. Reshape the client values to matrix form.
       2. Encrypt them on the clients.
       3. Move the ciphertexts to the aggregator and sum them there.
       4. Move the encrypted sum to `tff.SERVER` and decrypt it.
       5. Reshape the result back to the input shape.

     Args:
       arg: A two-element structure whose first element is a tensor placed at
         `tff.CLIENTS`. Only the first element is read here.

     Returns:
       The decrypted sum, reshaped to the input tensor's shape. It is
       decrypted with `export_dtype` set to the input dtype.
     """
     self._check_arg_is_structure(arg)
     py_typecheck.check_len(arg.internal_representation, 2)
     summand_type = arg.type_signature[0]
     type_analysis.check_federated_type(summand_type, placement=tff.CLIENTS)
     py_typecheck.check_type(summand_type.member, tff.TensorType)
     original_dtype = summand_type.member.dtype

     # Paillier setup phase, performed once.
     if self._requires_setup:
         await self._paillier_setup()
         self._requires_setup = False

     # Remember the input shape. Any non-rank-2 input is flattened into a
     # single row ([1, num_elements]).
     original_shape = summand_type.member.shape
     needs_reshape = len(original_shape) != 2
     summands = await self._executor.create_selection(arg, index=0)
     if needs_reshape:
         summands = await self._compute_reshape_on_tensor(
             summands, output_shape=[1, original_shape.num_elements()])

     # Encrypt on the clients, then add the ciphertexts at the aggregator.
     ciphertexts = await self._compute_paillier_encryption(
         self.encryption_key_clients, summands)
     ciphertexts = await self._move(ciphertexts, tff.CLIENTS,
                                    paillier_placement.AGGREGATOR)
     ciphertext_sum = await self._compute_paillier_sum(
         self.encryption_key_paillier, ciphertexts)

     # Ship the encrypted sum to the server and decrypt it there.
     ciphertext_sum = await self._move(ciphertext_sum,
                                       paillier_placement.AGGREGATOR,
                                       tff.SERVER)
     plaintext_sum = await self._compute_paillier_decryption(
         self.decryption_key,
         self.encryption_key_server,
         ciphertext_sum,
         export_dtype=original_dtype)
     # Restore the caller's original tensor shape.
     return await self._compute_reshape_on_tensor(
         plaintext_sum, output_shape=original_shape.as_list())
Esempio n. 25
0
 async def _compute_reshape_on_tensor(self, tensor, output_shape):
     """Reshapes a federated tensor on every executor backing its placement.

     Args:
       tensor: A federated value whose member type has a `dtype`. Its
         internal representation must hold one value per child executor of
         its placement.
       output_shape: The target shape. It is passed to the cached reshape
         computation and used for the result's member type.

     Returns:
       A `FederatedResolvingStrategyValue` holding the reshaped values, with
       the same placement and `all_equal` bit as `tensor`.
     """
     tensor_type = tensor.type_signature.member
     reshaper_proto, reshaper_type = utils.materialize_computation_from_cache(
         paillier_comp.make_reshape_tensor,
         self._reshape_function_cache,
         arg_spec=(tensor_type, ),
         output_shape=output_shape)
     tensor_placement = tensor.type_signature.placement
     children = self._get_child_executors(tensor_placement)
     # Values are zipped with executors below, so the counts must agree.
     py_typecheck.check_len(tensor.internal_representation, len(children))
     reshaper_fns = await asyncio.gather(*[
         ex.create_value(reshaper_proto, reshaper_type) for ex in children
     ])
     reshaped_tensors = await asyncio.gather(*[
         ex.create_call(fn, arg) for ex, fn, arg in zip(
             children, reshaper_fns, tensor.internal_representation)
     ])
     output_tensor_spec = tff.FederatedType(
         tff.TensorType(tensor_type.dtype, output_shape), tensor_placement,
         tensor.type_signature.all_equal)
     return federated_resolving_strategy.FederatedResolvingStrategyValue(
         reshaped_tensors, output_tensor_spec)
Esempio n. 26
0
 async def _compute_paillier_encryption(
     self, client_encryption_keys: federated_resolving_strategy.
     FederatedResolvingStrategyValue,
     clients_value: federated_resolving_strategy.
     FederatedResolvingStrategyValue):
     """Encrypts each client's value with the key held by that client.

     Args:
       client_encryption_keys: Encryption keys at `tff.CLIENTS`, one per
         client executor.
       clients_value: Values at `tff.CLIENTS`, one per client executor.

     Returns:
       A `FederatedResolvingStrategyValue` of the encryptor's results at
       `tff.CLIENTS`. It carries the `all_equal` bit of `clients_value`.
     """
     executors = self._get_child_executors(tff.CLIENTS)
     key_reps = client_encryption_keys.internal_representation
     value_reps = clients_value.internal_representation
     py_typecheck.check_len(key_reps, len(executors))
     py_typecheck.check_len(value_reps, len(executors))
     encryptor_proto, encryptor_type = utils.lift_to_computation_spec(
         self._paillier_encryptor,
         input_arg_type=tff.StructType(
             (client_encryption_keys.type_signature.member,
              clients_value.type_signature.member)))

     # Embed the encryptor and build its (key, value) argument on every
     # client executor concurrently.
     fns_future = asyncio.gather(*[
         ex.create_value(encryptor_proto, encryptor_type) for ex in executors
     ])
     args_future = asyncio.gather(*[
         ex.create_struct((key, value))
         for ex, key, value in zip(executors, key_reps, value_reps)
     ])
     fns, call_args = await asyncio.gather(fns_future, args_future)
     ciphertexts = await asyncio.gather(*[
         ex.create_call(fn, call_arg)
         for ex, fn, call_arg in zip(executors, fns, call_args)
     ])
     result_type = tff.FederatedType(encryptor_type.result, tff.CLIENTS,
                                     clients_value.type_signature.all_equal)
     return federated_resolving_strategy.FederatedResolvingStrategyValue(
         ciphertexts, result_type)
Esempio n. 27
0
 async def _share_public_key(self, key_owner, key_receiver):
     """Copies `key_owner`'s public key onto the executors of `key_receiver`.

     Only two layouts are supported:
       - a list of keys, copied onto a single receiving executor;
       - one key, copied onto every receiving executor.

     The key references of `key_owner` are then updated with the shared
     public key.
     """
     public_key = self.key_references.get_public_key(key_owner)
     targets = self.strategy._get_child_executors(key_receiver)
     key_value = await public_key.compute()
     key_type = public_key.type_signature.member
     if not isinstance(key_value, list):
         # A single key (one tensor), replicated to every receiving executor.
         pending = [ex.create_value(key_value, key_type) for ex in targets]
         shared_type = tff.FederatedType(key_type,
                                         key_receiver,
                                         all_equal=True)
     else:
         # Several keys, all placed on the one receiving executor.
         py_typecheck.check_len(targets, 1)
         target = targets[0]
         pending = [target.create_value(k, key_type) for k in key_value]
         shared_type = tff.FederatedType(
             type_conversions.infer_type(key_value), key_receiver)
     shared_key = federated_resolving_strategy.FederatedResolvingStrategyValue(
         await asyncio.gather(*pending), shared_type)
     self.key_references.update_keys(key_owner, public_key=shared_key)
    async def compute_federated_zip_at_clients(
        self, arg: FederatedComposingStrategyValue
    ) -> FederatedComposingStrategyValue:
        """Zips two `CLIENTS`-placed values by delegating to each target.

        Each target executor runs a `FEDERATED_ZIP_AT_CLIENTS` intrinsic call
        over its own share of the two operands. Element names of the argument
        struct carry over into the zipped member type.

        Args:
          arg: A value typed as a two-element `StructType` of federated types
            at `CLIENTS`. Its internal representation is a two-element
            `AnonymousTuple` of per-executor lists.

        Returns:
          A `FederatedComposingStrategyValue` holding one zipped value per
          target executor, typed as the zipped struct at `CLIENTS`.
        """
        signature = arg.type_signature
        py_typecheck.check_type(signature, computation_types.StructType)
        py_typecheck.check_len(signature, 2)
        py_typecheck.check_type(arg.internal_representation,
                                anonymous_tuple.AnonymousTuple)
        py_typecheck.check_len(arg.internal_representation, 2)
        names = [name for name, _ in anonymous_tuple.to_elements(signature)]
        operands = [arg.internal_representation[i] for i in (0, 1)]
        operand_types = [signature[i] for i in (0, 1)]
        for idx in range(2):
            type_analysis.check_federated_type(
                operand_types[idx], placement=placement_literals.CLIENTS)
            # Normalize to the plain CLIENTS-placed form of the member type.
            operand_types[idx] = type_factory.at_clients(
                operand_types[idx].member)
            py_typecheck.check_type(operands[idx], list)
            py_typecheck.check_len(operands[idx], len(self._target_executors))

        def _maybe_named(name, element_type):
            # Keep the element name only when the input struct had one.
            return (name, element_type) if name else element_type

        result_type = type_factory.at_clients(
            computation_types.StructType([
                _maybe_named(names[idx], operand_types[idx].member)
                for idx in range(2)
            ]))
        zip_type = computation_types.FunctionType(
            computation_types.StructType([
                _maybe_named(names[idx], operand_types[idx])
                for idx in range(2)
            ]), result_type)
        zip_comp = executor_utils.create_intrinsic_comp(
            intrinsic_defs.FEDERATED_ZIP_AT_CLIENTS, zip_type)

        async def _zip_on(executor, first, second):
            py_typecheck.check_type(first, executor_value_base.ExecutorValue)
            py_typecheck.check_type(second, executor_value_base.ExecutorValue)
            fn = await executor.create_value(zip_comp, zip_type)
            struct_arg = await executor.create_struct(
                anonymous_tuple.AnonymousTuple([(names[0], first),
                                                (names[1], second)]))
            return await executor.create_call(fn, struct_arg)

        zipped = await asyncio.gather(*[
            _zip_on(executor, first, second)
            for executor, first, second in zip(self._target_executors,
                                               operands[0], operands[1])
        ])
        return FederatedComposingStrategyValue(zipped, result_type)
Esempio n. 29
0
 async def create_value(self, value, type_spec=None):
   """Embeds `value` of type `type_spec` in this composite executor.

   The value is handled by kind:
     - Intrinsics and TensorFlow/lambda computations are wrapped as-is.
     - Named-tuple values are created element by element.
     - `SERVER`-placed values are created in the parent executor.
     - All-equal `CLIENTS` values are created in every child executor.
     - Non-all-equal `CLIENTS` lists are split across the children according
       to their cardinalities.
     - Anything else is created in the parent executor.

   Args:
     value: The value to embed.
     type_spec: The TFF type of `value`. It is converted with
       `computation_types.to_type` and must then be a
       `computation_types.Type`.

   Returns:
     The embedded value.

   Raises:
     TypeError: If an intrinsic is used with an incompatible type.
     ValueError: For an unrecognized intrinsic, a non-all_equal `SERVER`
       value, or an unexpected placement.
     NotImplementedError: For an unsupported kind of computation.
   """
   type_spec = computation_types.to_type(type_spec)
   py_typecheck.check_type(type_spec, computation_types.Type)
   if isinstance(value, intrinsic_defs.IntrinsicDef):
     if not type_utils.is_concrete_instance_of(type_spec,
                                               value.type_signature):  # pytype: disable=attribute-error
       raise TypeError('Incompatible type {} used with intrinsic {}.'.format(
           type_spec, value.uri))  # pytype: disable=attribute-error
     else:
       return CompositeValue(value, type_spec)
   elif isinstance(value, pb.Computation):
     which_computation = value.WhichOneof('computation')
     if which_computation in ['tensorflow', 'lambda']:
       return CompositeValue(value, type_spec)
     elif which_computation == 'intrinsic':
       intr = intrinsic_defs.uri_to_intrinsic_def(value.intrinsic.uri)
       if intr is None:
         raise ValueError('Encountered an unrecognized intrinsic "{}".'.format(
             value.intrinsic.uri))
       py_typecheck.check_type(intr, intrinsic_defs.IntrinsicDef)
       return await self.create_value(intr, type_spec)
     else:
       raise NotImplementedError(
           'Unimplemented computation type {}.'.format(which_computation))
   elif isinstance(type_spec, computation_types.NamedTupleType):
     v_el = anonymous_tuple.to_elements(anonymous_tuple.from_container(value))
     t_el = anonymous_tuple.to_elements(type_spec)
     items = await asyncio.gather(
         *[self.create_value(v, t) for (_, v), (_, t) in zip(v_el, t_el)])
     # `create_tuple` is a coroutine and must be awaited. Otherwise callers
     # awaiting `create_value` would receive an un-run coroutine object.
     return await self.create_tuple(
         anonymous_tuple.AnonymousTuple([
             (k, i) for (k, _), i in zip(t_el, items)
         ]))
   elif isinstance(type_spec, computation_types.FederatedType):
     if type_spec.placement == placement_literals.SERVER:
       if type_spec.all_equal:
         return CompositeValue(
             await self._parent_executor.create_value(value, type_spec.member),
             type_spec)
       else:
         raise ValueError('A non-all_equal value on the server is unexpected.')
     elif type_spec.placement == placement_literals.CLIENTS:
       if type_spec.all_equal:
         # The same value is created in every child executor.
         return CompositeValue(
             await asyncio.gather(*[
                 c.create_value(value, type_spec)
                 for c in self._child_executors
             ]), type_spec)
       else:
         py_typecheck.check_type(value, list)
         # Child cardinalities are fetched once and cached as a future.
         if self._cardinalities is None:
           self._cardinalities = asyncio.ensure_future(
               self._get_cardinalities())
         cardinalities = await self._cardinalities
         py_typecheck.check_len(cardinalities, len(self._child_executors))
         count = sum(cardinalities)
         py_typecheck.check_len(value, count)
         # Hand each child the contiguous slice of `value` matching its
         # cardinality.
         result = []
         offset = 0
         for c, n in zip(self._child_executors, cardinalities):
           new_offset = offset + n
           # The slice operator is not supported on all the types `value`
           # supports.
           # pytype: disable=unsupported-operands
           result.append(c.create_value(value[offset:new_offset], type_spec))
           # pytype: enable=unsupported-operands
           offset = new_offset
         return CompositeValue(await asyncio.gather(*result), type_spec)
     else:
       raise ValueError('Unexpected placement {}.'.format(type_spec.placement))
   else:
     return CompositeValue(
         await self._parent_executor.create_value(value, type_spec), type_spec)
Esempio n. 30
0
    def __init__(self,
                 initialize,
                 prepare,
                 work,
                 zero,
                 accumulate,
                 merge,
                 report,
                 bitwidth,
                 update,
                 server_state_label=None,
                 client_data_label=None):
        """Constructs a representation of a MapReduce-like iterative process.

        Note: All the computations supplied here as arguments must be
        TensorFlow computations, i.e., instances of `tff.Computation`
        constructed by the `tff.tf_computation` decorator/wrapper.

        Args:
          initialize: The computation that produces the initial server state.
          prepare: The computation that prepares the input for the clients.
          work: The client-side work computation.
          zero: The computation that produces the initial state for
            accumulators.
          accumulate: The computation that adds a client update to an
            accumulator.
          merge: The computation to use for merging pairs of accumulators.
          report: The computation that produces the final server-side
            aggregate for the top level accumulator (the global update).
          bitwidth: The computation that produces the bitwidth for secure sum.
          update: The computation that takes the global update and the server
            state and produces the new server state, as well as server-side
            output.
          server_state_label: Optional string label for the server state.
          client_data_label: Optional string label for the client data.

        Raises:
          TypeError: If the Python or TFF types of the arguments are invalid
            or not compatible with each other.
          ValueError: If `accumulate` or `merge` does not take exactly two
            arguments (raised by `py_typecheck.check_len`).
          AssertionError: If the manner in which the given TensorFlow
            computations are represented by TFF does not match what this
            code is expecting (this is an internal error that requires code
            update).
        """
        for label, comp in [
            ('initialize', initialize),
            ('prepare', prepare),
            ('work', work),
            ('zero', zero),
            ('accumulate', accumulate),
            ('merge', merge),
            ('report', report),
            ('bitwidth', bitwidth),
            ('update', update),
        ]:
            py_typecheck.check_type(comp, computation_base.Computation, label)

            # TODO(b/130633916): Remove private access once an appropriate API for it
            # becomes available.
            comp_proto = comp._computation_proto  # pylint: disable=protected-access

            if not isinstance(comp_proto, computation_pb2.Computation):
                # Explicitly raised to force it to be done in non-debug mode as well.
                raise AssertionError(
                    'Cannot find the embedded computation definition.')
            which_comp = comp_proto.WhichOneof('computation')
            if which_comp != 'tensorflow':
                raise TypeError(
                    'Expected all computations supplied as arguments to '
                    'be plain TensorFlow, found {}.'.format(which_comp))

        def is_assignable_from_or_both_none(first, second):
            # A missing (None) type only matches another missing type.
            if first is None:
                return second is None
            return first.is_assignable_from(second)

        prepare_arg_type = prepare.type_signature.parameter
        init_result_type = initialize.type_signature.result
        if not is_assignable_from_or_both_none(prepare_arg_type,
                                               init_result_type):
            raise TypeError(
                'The `prepare` computation expects an argument of type {}, '
                'which does not match the result type {} of `initialize`.'.
                format(prepare_arg_type, init_result_type))

        if (not work.type_signature.parameter.is_struct()
                or len(work.type_signature.parameter) != 2):
            raise TypeError(
                'The `work` computation expects an argument of type {} that is not '
                'a two-tuple.'.format(work.type_signature.parameter))

        work_2nd_arg_type = work.type_signature.parameter[1]
        prepare_result_type = prepare.type_signature.result
        if not is_assignable_from_or_both_none(work_2nd_arg_type,
                                               prepare_result_type):
            raise TypeError(
                'The `work` computation expects an argument tuple with type {} as '
                'the second element (the initial client state from the server), '
                'which does not match the result type {} of `prepare`.'.format(
                    work_2nd_arg_type, prepare_result_type))

        if (not work.type_signature.result.is_struct()
                or len(work.type_signature.result) != 2):
            raise TypeError(
                'The `work` computation returns a result  of type {} that is not a '
                'two-tuple.'.format(work.type_signature.result))

        py_typecheck.check_type(zero.type_signature,
                                computation_types.FunctionType)

        py_typecheck.check_type(accumulate.type_signature,
                                computation_types.FunctionType)
        py_typecheck.check_len(accumulate.type_signature.parameter, 2)
        accumulate.type_signature.parameter[0].check_assignable_from(
            zero.type_signature.result)
        accumulate_2nd_arg_type = accumulate.type_signature.parameter[1]
        work_client_update_type = work.type_signature.result[0]
        if not is_assignable_from_or_both_none(accumulate_2nd_arg_type,
                                               work_client_update_type):

            raise TypeError(
                'The `accumulate` computation expects a second argument of type {}, '
                'which does not match the expected {} as implied by the type '
                'signature of `work`.'.format(accumulate_2nd_arg_type,
                                              work_client_update_type))
        accumulate.type_signature.parameter[0].check_assignable_from(
            accumulate.type_signature.result)

        py_typecheck.check_type(merge.type_signature,
                                computation_types.FunctionType)
        py_typecheck.check_len(merge.type_signature.parameter, 2)
        merge.type_signature.parameter[0].check_assignable_from(
            accumulate.type_signature.result)
        merge.type_signature.parameter[1].check_assignable_from(
            accumulate.type_signature.result)
        merge.type_signature.parameter[0].check_assignable_from(
            merge.type_signature.result)

        py_typecheck.check_type(report.type_signature,
                                computation_types.FunctionType)
        report.type_signature.parameter.check_assignable_from(
            merge.type_signature.result)

        py_typecheck.check_type(bitwidth.type_signature,
                                computation_types.FunctionType)

        expected_update_parameter_type = computation_types.to_type([
            initialize.type_signature.result,
            [report.type_signature.result, work.type_signature.result[1]],
        ])
        if not is_assignable_from_or_both_none(update.type_signature.parameter,
                                               expected_update_parameter_type):
            raise TypeError(
                'The `update` computation expects an argument of type {}, '
                'which does not match the expected {} as implied by the type '
                'signatures of `initialize`, `report`, and `work`.'.format(
                    update.type_signature.parameter,
                    expected_update_parameter_type))

        if (not update.type_signature.result.is_struct()
                or len(update.type_signature.result) != 2):
            raise TypeError(
                'The `update` computation returns a result  of type {} that is not '
                'a two-tuple.'.format(update.type_signature.result))

        updated_state_type = update.type_signature.result[0]
        # Use the same None-aware comparison as the checks above. If `prepare`
        # takes no argument, `prepare_arg_type` is None, and calling
        # `is_assignable_from` on it directly would raise AttributeError
        # instead of the intended TypeError.
        if not is_assignable_from_or_both_none(prepare_arg_type,
                                               updated_state_type):
            raise TypeError(
                'The `update` computation returns a result tuple whose first element '
                f'(the updated state type of the server) is type:\n'
                f'{updated_state_type}\n'
                f'which is not assignable to the state parameter type of `prepare`:\n'
                f'{prepare_arg_type}')

        self._initialize = initialize
        self._prepare = prepare
        self._work = work
        self._zero = zero
        self._accumulate = accumulate
        self._merge = merge
        self._report = report
        self._bitwidth = bitwidth
        self._update = update

        if server_state_label is not None:
            py_typecheck.check_type(server_state_label, str)
        self._server_state_label = server_state_label
        if client_data_label is not None:
            py_typecheck.check_type(client_data_label, str)
        self._client_data_label = client_data_label