예제 #1
0
  async def compute_federated_aggregate(
      self,
      arg: FederatedResolvingStrategyValue) -> FederatedResolvingStrategyValue:
    """Aggregates a federated value into a single result via `report`.

    All values are handled by this single executor, so the two-tier
    accumulate/merge scheme collapses into one reduction with `accumulate`
    (delegated to `self.reduce`). The `merge` component is never used. The
    reduced value is then fed to `report` through `compute_federated_apply`.

    Args:
      arg: A value whose internal representation is a 5-element
        `structure.Struct` of `(value, zero, accumulate, merge, report)`, and
        whose type signature is parsed by
        `executor_utils.parse_federated_aggregate_argument_types`.

    Returns:
      The result of `compute_federated_apply` applying `report` to the
      reduced value.

    Raises:
      Whatever `py_typecheck` raises if the internal representation is not a
      5-element `structure.Struct` or the reduced value is not
      federated-typed, and whatever `check_equivalent_to` raises if the
      reduced member type does not match `report`'s parameter type.
    """
    parsed_types = executor_utils.parse_federated_aggregate_argument_types(
        arg.type_signature)
    _, zero_type, accumulate_type, _, report_type = parsed_types

    components = arg.internal_representation
    py_typecheck.check_type(components, structure.Struct)
    py_typecheck.check_len(components, 5)
    # Component 3 (`merge`) is deliberately ignored: with a single executor
    # there is no second aggregation tier that would need it.
    val, zero, accumulate, _, report = components

    # The raw `zero` may be an embedded value, a list of embedded values (for
    # federated values), or a `Struct`; wrapping it in a
    # `FederatedResolvingStrategyValue` guarantees `reduce` gets an
    # `ExecutorValue` rather than a bare `Struct`.
    wrapped_zero = FederatedResolvingStrategyValue(zero, zero_type)
    pre_report = await self.reduce(val, wrapped_zero, accumulate,
                                   accumulate_type)

    pre_report_type = pre_report.type_signature
    py_typecheck.check_type(pre_report_type, computation_types.FederatedType)
    pre_report_type.member.check_equivalent_to(report_type.parameter)

    apply_value = structure.Struct([(None, report),
                                    (None, pre_report.internal_representation)])
    apply_arg = FederatedResolvingStrategyValue(
        apply_value,
        computation_types.StructType((report_type, pre_report_type)))
    return await self.compute_federated_apply(apply_arg)
예제 #2
0
    async def compute_federated_aggregate(
        self, arg: FederatedResolvingStrategyValue
    ) -> FederatedResolvingStrategyValue:
        """Aggregates a federated value into a single result via `report`.

        All aggregation happens on this single executor, so the two-tier
        accumulate/merge scheme collapses into one reduction with `accumulate`
        (delegated to `self.reduce`). `merge` is discarded. The reduced value
        is then fed to `report` through `compute_federated_apply`.

        Args:
          arg: A value whose internal representation is a 5-element
            `structure.Struct` of `(value, zero, accumulate, merge, report)`,
            and whose type signature is parsed by
            `executor_utils.parse_federated_aggregate_argument_types`.

        Returns:
          The result of `compute_federated_apply` applying `report` to the
          reduced value.

        Raises:
          Whatever `py_typecheck` raises if the internal representation is not
          a 5-element `structure.Struct` or the reduced value is not
          federated-typed, and whatever `check_equivalent_to` raises if the
          reduced member type does not match `report`'s parameter type.
        """
        val_type, zero_type, accumulate_type, merge_type, report_type = (
            executor_utils.parse_federated_aggregate_argument_types(
                arg.type_signature))
        # `zero_type` is kept: it is needed to re-wrap `zero` below.
        del val_type, merge_type
        py_typecheck.check_type(arg.internal_representation, structure.Struct)
        py_typecheck.check_len(arg.internal_representation, 5)
        val, zero, accumulate, merge, report = arg.internal_representation

        # Discard `merge`. Since all aggregation happens on a single executor,
        # there's no need for this additional layer.
        del merge

        # The raw `zero` component may be an embedded value, a list of embedded
        # values (for federated values), or a `structure.Struct`. Re-wrap it in
        # a `FederatedResolvingStrategyValue` so `reduce` always receives an
        # executor value of type `zero_type` rather than a bare `Struct`.
        zero = FederatedResolvingStrategyValue(zero, zero_type)
        pre_report = await self.reduce(val, zero, accumulate, accumulate_type)

        py_typecheck.check_type(pre_report.type_signature,
                                computation_types.FederatedType)
        pre_report.type_signature.member.check_equivalent_to(
            report_type.parameter)

        return await self.compute_federated_apply(
            FederatedResolvingStrategyValue(
                structure.Struct([(None, report),
                                  (None, pre_report.internal_representation)]),
                computation_types.StructType(
                    (report_type, pre_report.type_signature))))
    async def compute_federated_aggregate(
        self, arg: FederatedComposingStrategyValue
    ) -> FederatedComposingStrategyValue:
        """Aggregates across child executors, then merges and reports on the server.

        Each target executor runs a `FEDERATED_AGGREGATE` intrinsic over its
        slice of the value. That intrinsic uses an identity `report`, so each
        child returns an un-reported partial aggregate of type `zero_type`. The
        partials are moved to the server executor and folded together with
        `merge` in the order the children finish. `report` is applied once to
        the folded result.

        Args:
          arg: A value whose internal representation is a 5-element
            `structure.Struct` of `(value, zero, accumulate, merge, report)`.
            Element 0 must be a list with one entry per target executor.

        Returns:
          A `FederatedComposingStrategyValue` holding the server executor's
          result of `report`, typed `at_server(report_type.result)`.

        Raises:
          Whatever `py_typecheck` raises if the representation is not a
          5-element `Struct`, element 0 is not a list of the right length, or
          a child value is not an `ExecutorValue`.
        """
        value_type, zero_type, accumulate_type, merge_type, report_type = (
            executor_utils.parse_federated_aggregate_argument_types(
                arg.type_signature))
        py_typecheck.check_type(arg.internal_representation, structure.Struct)
        py_typecheck.check_len(arg.internal_representation, 5)
        val = arg.internal_representation[0]
        py_typecheck.check_type(val, list)
        py_typecheck.check_len(val, len(self._target_executors))
        # An identity `report` lets each child return its partial aggregate
        # unchanged, so the real `report` runs only once, on the server.
        identity_report, identity_report_type = tensorflow_computation_factory.create_identity(
            zero_type)
        aggr_type = computation_types.FunctionType(
            computation_types.StructType([
                value_type, zero_type, accumulate_type, merge_type,
                identity_report_type
            ]), computation_types.at_server(zero_type))
        aggr_comp = executor_utils.create_intrinsic_comp(
            intrinsic_defs.FEDERATED_AGGREGATE, aggr_type)
        # Materialize `zero` (selection index 1) so it can be re-embedded in
        # every child executor via `create_value`.
        zero = await (await
                      self._executor.create_selection(arg, index=1)).compute()
        accumulate = arg.internal_representation[2]
        merge = arg.internal_representation[3]
        report = arg.internal_representation[4]

        async def _child_fn(ex, v):
            # Runs the identity-report aggregate on child executor `ex` for its
            # value `v`, and re-embeds the computed partial in the server
            # executor as a `zero_type` value.
            py_typecheck.check_type(v, executor_value_base.ExecutorValue)
            arg_values = [
                ex.create_value(zero, zero_type),
                ex.create_value(accumulate, accumulate_type),
                ex.create_value(merge, merge_type),
                ex.create_value(identity_report, identity_report_type)
            ]
            # Note: the inner `await asyncio.gather(*arg_values)` is evaluated
            # while building the outer call's arguments, so it finishes before
            # the outer `gather` starts.
            aggr_func, aggr_args = await asyncio.gather(
                ex.create_value(aggr_comp, aggr_type),
                ex.create_struct([v] +
                                 list(await asyncio.gather(*arg_values))))
            child_result = await (await ex.create_call(aggr_func,
                                                       aggr_args)).compute()
            result_at_server = await self._server_executor.create_value(
                child_result, zero_type)
            return result_at_server

        # `as_completed` wraps the child coroutines in tasks, so the children
        # run concurrently with the parent `merge`/`report` creation below.
        # The returned iterator yields them in completion order.
        val_futures = asyncio.as_completed(
            [_child_fn(c, v) for c, v in zip(self._target_executors, val)])
        parent_merge, parent_report = await asyncio.gather(
            self._server_executor.create_value(merge, merge_type),
            self._server_executor.create_value(report, report_type))
        # NOTE(review): assumes at least one target executor; with an empty
        # `val`, `next()` raises StopIteration, which becomes a RuntimeError
        # inside a coroutine.
        merge_result = await next(val_futures)
        # Partials are folded in completion order, which is not deterministic.
        # This presumably relies on `merge` being associative and commutative;
        # confirm with the intrinsic's contract.
        for next_val_future in val_futures:
            next_val = await next_val_future
            merge_arg = await self._server_executor.create_struct(
                [merge_result, next_val])
            merge_result = await self._server_executor.create_call(
                parent_merge, merge_arg)
        report_result = await self._server_executor.create_call(
            parent_report, merge_result)
        return FederatedComposingStrategyValue(
            report_result, computation_types.at_server(report_type.result))
    async def compute_federated_aggregate(
        self, arg: FederatedComposingStrategyValue
    ) -> FederatedComposingStrategyValue:
        """Aggregates across child executors, then merges and reports on the server.

        Each target executor runs a `FEDERATED_AGGREGATE` intrinsic whose
        `report` is the identity, producing a partial aggregate of type
        `zero_type`. All partials are gathered, re-embedded in the server
        executor, folded left-to-right with `merge`, and passed once through
        `report`.

        Args:
          arg: A value whose internal representation is a 5-element
            `anonymous_tuple.AnonymousTuple` of
            `(value, zero, accumulate, merge, report)`. Element 0 must be a
            list with one entry per target executor.

        Returns:
          A `FederatedComposingStrategyValue` holding the server executor's
          result of `report`, typed `at_server(report_type.result)`.
        """
        parsed_types = executor_utils.parse_federated_aggregate_argument_types(
            arg.type_signature)
        (value_type, zero_type, accumulate_type, merge_type,
         report_type) = parsed_types
        members = arg.internal_representation
        py_typecheck.check_type(members, anonymous_tuple.AnonymousTuple)
        py_typecheck.check_len(members, 5)
        per_child_values = members[0]
        py_typecheck.check_type(per_child_values, list)
        py_typecheck.check_len(per_child_values, len(self._target_executors))

        # Children aggregate with an identity `report`; the real `report` is
        # applied exactly once on the server after merging.
        identity_report = tensorflow_computation_factory.create_identity(
            zero_type)
        identity_report_type = type_factory.unary_op(zero_type)
        intrinsic_param_type = computation_types.NamedTupleType([
            value_type, zero_type, accumulate_type, merge_type,
            identity_report_type
        ])
        aggr_type = computation_types.FunctionType(
            intrinsic_param_type, type_factory.at_server(zero_type))
        aggr_comp = executor_utils.create_intrinsic_comp(
            intrinsic_defs.FEDERATED_AGGREGATE, aggr_type)

        zero_selection = await self._executor.create_selection(arg, index=1)
        zero = await zero_selection.compute()
        accumulate = members[2]
        merge = members[3]
        report = members[4]

        async def _child_fn(ex, v):
            # Computes the identity-report aggregate of `v` on child `ex`.
            py_typecheck.check_type(v, executor_value_base.ExecutorValue)
            pending_func = ex.create_value(aggr_comp, aggr_type)
            embedded_operands = await asyncio.gather(
                ex.create_value(zero, zero_type),
                ex.create_value(accumulate, accumulate_type),
                ex.create_value(merge, merge_type),
                ex.create_value(identity_report, identity_report_type))
            pending_args = ex.create_tuple([v, *embedded_operands])
            aggr_func, aggr_args = await asyncio.gather(pending_func,
                                                        pending_args)
            call_result = await ex.create_call(aggr_func, aggr_args)
            return await call_result.compute()

        child_partials = await asyncio.gather(*[
            _child_fn(child, child_value)
            for child, child_value in zip(self._target_executors,
                                          per_child_values)
        ])
        server = self._server_executor
        server_partials = await asyncio.gather(
            *[server.create_value(p, zero_type) for p in child_partials])
        parent_merge, parent_report = await asyncio.gather(
            server.create_value(merge, merge_type),
            server.create_value(report, report_type))

        # Left fold of the partials with `merge`, in child order.
        merged, remaining = server_partials[0], server_partials[1:]
        for partial in remaining:
            pair = await server.create_tuple([merged, partial])
            merged = await server.create_call(parent_merge, pair)

        reported = await server.create_call(parent_report, merged)
        return FederatedComposingStrategyValue(
            reported, type_factory.at_server(report_type.result))
    async def compute_federated_aggregate(
        self, arg: FederatedResolvingStrategyValue
    ) -> FederatedResolvingStrategyValue:
        """Aggregates a federated value by reducing it, then applying `report`.

        This is a simple initial implementation: it forwards `(value, zero,
        accumulate)` to `compute_federated_reduce()` and ignores `merge`. A
        fuller implementation could use `merge` to parallelize the
        aggregation, reducing the cost from linear (in the number of clients)
        to sub-linear.

        Args:
          arg: A value whose internal representation is a 5-element
            `anonymous_tuple.AnonymousTuple` of
            `(value, zero, accumulate, merge, report)`.

        Returns:
          The result of `compute_federated_apply` applying `report` to the
          reduced value.
        """
        # TODO(b/134543154): Expand this implementation to take advantage of the
        # parallelism afforded by `merge`.
        (val_type, zero_type, accumulate_type, _unused_merge_type,
         report_type) = executor_utils.parse_federated_aggregate_argument_types(
             arg.type_signature)
        members = arg.internal_representation
        py_typecheck.check_type(members, anonymous_tuple.AnonymousTuple)
        py_typecheck.check_len(members, 5)

        value, zero, accumulate = members[0], members[1], members[2]
        reduce_arg = FederatedResolvingStrategyValue(
            anonymous_tuple.AnonymousTuple(
                [(None, value), (None, zero), (None, accumulate)]),
            computation_types.NamedTupleType(
                (val_type, zero_type, accumulate_type)))
        pre_report = await self.compute_federated_reduce(reduce_arg)

        pre_report_type = pre_report.type_signature
        py_typecheck.check_type(pre_report_type,
                                computation_types.FederatedType)
        pre_report_type.member.check_equivalent_to(report_type.parameter)

        report = members[4]
        apply_arg = FederatedResolvingStrategyValue(
            anonymous_tuple.AnonymousTuple(
                [(None, report), (None, pre_report.internal_representation)]),
            computation_types.NamedTupleType((report_type, pre_report_type)))
        return await self.compute_federated_apply(apply_arg)