def test_raises_not_implemented_error_with_intrinsic_def_federated_secure_sum(
      self):
    """Calling the `federated_secure_sum` intrinsic raises NotImplementedError."""
    test_executor = create_test_executor()
    intrinsic, intrinsic_type = (
        executor_test_utils.create_dummy_intrinsic_def_federated_secure_sum())
    client_values = [10, 11, 12]
    client_values_type = computation_types.at_clients(tf.int32, all_equal=False)
    bitwidth = 10
    bitwidth_type = computation_types.TensorType(tf.int32)

    embedded_fn = self.run_sync(
        test_executor.create_value(intrinsic, intrinsic_type))
    embedded_values = self.run_sync(
        test_executor.create_value(client_values, client_values_type))
    embedded_bitwidth = self.run_sync(
        test_executor.create_value(bitwidth, bitwidth_type))
    call_arg = self.run_sync(
        test_executor.create_struct([embedded_values, embedded_bitwidth]))

    with self.assertRaises(NotImplementedError):
        self.run_sync(test_executor.create_call(embedded_fn, call_arg))
    def test_returns_value_with_unplaced_type_and_clients(self, executor):
        """Lifting an unplaced value to CLIENTS gives an all-equal value."""
        raw_value, value_type = (
            executor_test_utils.create_dummy_value_unplaced())
        embedded = self.run_sync(executor.create_value(raw_value, value_type))

        federated = self.run_sync(
            executor_utils.compute_intrinsic_federated_value(
                executor, embedded, placements.CLIENTS))

        self.assertIsInstance(federated, executor_value_base.ExecutorValue)
        want_type = computation_types.at_clients(value_type, all_equal=True)
        self.assertEqual(federated.type_signature.compact_representation(),
                         want_type.compact_representation())
        self.assertEqual(self.run_sync(federated.compute()), 10.0)
Exemple #3
0
def _build_expected_broadcaster_next_signature():
    """Builds the `next` type signature expected of the test broadcaster.

    Shared by several tests below. The state is a SERVER-placed struct with
    `trainable` and `non_trainable` entries. The value is the
    `TestModelQuant` weights at SERVER. The result carries those weights at
    CLIENTS plus empty measurements at SERVER.
    """
    weights_type = model_utils.weights_type_from_model(TestModelQuant)
    state_type = computation_types.at_server(
        computation_types.StructType([
            ('trainable', [()]),
            ('non_trainable', []),
        ]))
    parameter = collections.OrderedDict(
        state=state_type,
        value=computation_types.at_server(weights_type))
    result = collections.OrderedDict(
        state=state_type,
        result=computation_types.at_clients(weights_type),
        measurements=computation_types.at_server(()))
    return computation_types.FunctionType(parameter=parameter, result=result)
    def test_changing_cardinalities_across_calls(self):
        """One computation accepts client lists of different sizes."""

        @computations.federated_computation(
            computation_types.at_clients(tf.int32))
        def identity(x):
            return x

        small_input = list(range(5))
        large_input = list(range(10))

        with executor_test_utils.install_executor(
                executor_stacks.local_executor_factory()):
            small_output = identity(small_input)
            large_output = identity(large_input)

        self.assertEqual(small_output, small_input)
        self.assertEqual(large_output, large_input)
    def test_returns_value_with_federated_type_at_server(
            self, executor, num_clients):
        """Broadcasting a SERVER value gives an all-equal CLIENTS value."""
        del num_clients  # Unused.
        server_value, server_type = (
            executor_test_utils.create_dummy_value_at_server())
        embedded = self.run_sync(
            executor.create_value(server_value, server_type))

        broadcast = self.run_sync(
            executor_utils.compute_intrinsic_federated_broadcast(
                executor, embedded))

        self.assertIsInstance(broadcast, executor_value_base.ExecutorValue)
        want_type = computation_types.at_clients(server_type.member,
                                                 all_equal=True)
        self.assertEqual(broadcast.type_signature.compact_representation(),
                         want_type.compact_representation())
        self.assertEqual(self.run_sync(broadcast.compute()), 10.0)
  def test_returns_value_with_intrinsic_def_federated_secure_sum(
      self, value, bitwidth, expected_result):
    """Secure sum of int32 client `value`s yields `expected_result`."""
    test_executor = create_test_executor()
    intrinsic, intrinsic_type = (
        executor_test_utils.create_dummy_intrinsic_def_federated_secure_sum())
    summands_type = computation_types.at_clients(tf.int32, all_equal=False)
    bitwidth_type = computation_types.TensorType(tf.int32)

    embedded_fn = self.run_sync(
        test_executor.create_value(intrinsic, intrinsic_type))
    embedded_summands = self.run_sync(
        test_executor.create_value(value, summands_type))
    embedded_bitwidth = self.run_sync(
        test_executor.create_value(bitwidth, bitwidth_type))
    call_arg = self.run_sync(
        test_executor.create_struct([embedded_summands, embedded_bitwidth]))
    output = self.run_sync(test_executor.create_call(embedded_fn, call_arg))

    self.assertIsInstance(output, executor_value_base.ExecutorValue)
    self.assert_types_identical(output.type_signature, intrinsic_type.result)
    self.assertEqual(self.run_sync(output.compute()), expected_result)
Exemple #7
0
            def create(
                self, value_type: factory.ValueType
            ) -> aggregation_process.AggregationProcess:
                """Creates an `AggregationProcess` for `value_type`.

                `initialize` delegates to `init_fn_impl` with the inner
                aggregation process. `next` delegates to `next_fn_impl` with
                the function from `make_clip_fn` and that inner process.
                """
                _check_value_type(value_type)
                inner_process = inner_agg_factory.create(value_type)
                clip_fn = make_clip_fn(value_type)

                @computations.federated_computation()
                def init_fn():
                    return init_fn_impl(inner_process)

                client_value_type = computation_types.at_clients(value_type)

                # Parameter names `state`/`value` name the traced parameter
                # struct, so they are kept as-is.
                @computations.federated_computation(
                    init_fn.type_signature.result, client_value_type)
                def next_fn(state, value):
                    return next_fn_impl(state, value, clip_fn, inner_process)

                return aggregation_process.AggregationProcess(init_fn, next_fn)
Exemple #8
0
  def test_federated_sum_reduces_to_aggregate(self):
    """Reducing `federated_sum` replaces it with `federated_aggregate`."""
    sum_uri = intrinsic_defs.FEDERATED_SUM.uri
    aggregate_uri = intrinsic_defs.FEDERATED_AGGREGATE.uri
    sum_type = computation_types.FunctionType(
        computation_types.at_clients(tf.float32),
        computation_types.at_server(tf.float32))
    intrinsic = building_blocks.Intrinsic(sum_uri, sum_type)

    sums_before = _count_intrinsics(intrinsic, sum_uri)
    reduced, modified = intrinsic_reductions.replace_intrinsics_with_bodies(
        intrinsic)
    sums_after = _count_intrinsics(reduced, sum_uri)
    aggregates_after = _count_intrinsics(reduced, aggregate_uri)

    self.assertTrue(modified)
    self.assert_types_identical(intrinsic.type_signature,
                                reduced.type_signature)
    self.assertGreater(sums_before, 0)
    self.assertEqual(sums_after, 0)
    self.assertGreater(aggregates_after, 0)
Exemple #9
0
def _create_tff_parallel_clients_with_dataset_reduce():
    """Returns a federated computation that reduces each client's dataset.

    Each client's `tf.int64` sequence is reduced with `+` via
    `Dataset.reduce` inside a `tf.function`. The reduction starts from a
    `tf.Variable` holding 1.
    """

    @tf.function
    def reduce_fn(x, y):
        return x + y

    @tf.function
    def dataset_reduce_fn(ds, initial_val):
        return ds.reduce(initial_val, reduce_fn)

    int64_sequence = computation_types.SequenceType(tf.int64)

    @computations.tf_computation(int64_sequence)
    def dataset_reduce_fn_wrapper(ds):
        return dataset_reduce_fn(ds, tf.Variable(np.int64(1.0)))

    @computations.federated_computation(
        computation_types.at_clients(int64_sequence))
    def parallel_client_run(client_datasets):
        return intrinsics.federated_map(dataset_reduce_fn_wrapper,
                                        client_datasets)

    return parallel_client_run
Exemple #10
0
    def test_type_properties(self, value_type, inner_agg_factory):
        """Checks the factory, `initialize` and `next` types of the DP factory."""
        agg_factory = dp_factory.DifferentiallyPrivateFactory(
            _test_dp_query, inner_agg_factory)
        self.assertIsInstance(agg_factory,
                              factory.UnweightedAggregationFactory)
        value_type = computation_types.to_type(value_type)
        process = agg_factory.create_unweighted(value_type)
        self.assertIsInstance(process, aggregation_process.AggregationProcess)

        global_state = _test_dp_query.initial_global_state()
        query_state_type = type_conversions.type_from_tensors(global_state)
        query_metrics_type = type_conversions.type_from_tensors(
            _test_dp_query.derive_metrics(global_state))

        # Without an inner factory, its state slot is empty.
        inner_state_type = tf.int32 if inner_agg_factory else ()
        server_state_type = computation_types.at_server(
            (query_state_type, inner_state_type))
        self.assertTrue(
            process.initialize.type_signature.is_equivalent_to(
                computation_types.FunctionType(parameter=None,
                                               result=server_state_type)))

        # Likewise for the inner measurements slot.
        inner_measurements_type = tf.int32 if inner_agg_factory else ()
        measurements_type = computation_types.at_server(
            collections.OrderedDict(
                query_metrics=query_metrics_type,
                record_agg_process=inner_measurements_type))
        next_type = computation_types.FunctionType(
            parameter=collections.OrderedDict(
                state=server_state_type,
                value=computation_types.at_clients(value_type)),
            result=measured_process.MeasuredProcessOutput(
                state=server_state_type,
                result=computation_types.at_server(value_type),
                measurements=measurements_type))
        self.assertTrue(
            process.next.type_signature.is_equivalent_to(next_type))
Exemple #11
0
def _check_bound_process(bound_process: estimation_process.EstimationProcess,
                         name: str):
    """Validates the type signatures of an `EstimationProcess` for bounds.

    The process must be an `EstimationProcess` meeting two requirements:

    - `next` has type signature
      (<state@SERVER, NORM_TF_TYPE@CLIENTS> -> state@SERVER).
    - `report` has type signature (state@SERVER -> NORM_TF_TYPE@SERVER).

    Args:
      bound_process: The process to validate.
      name: A string name used when formatting error messages.

    Raises:
      TypeError: If the `next` or `report` signatures do not satisfy the
        requirements above.
    """
    py_typecheck.check_type(bound_process,
                            estimation_process.EstimationProcess)
    next_signature = bound_process.next.type_signature

    param_type = next_signature.parameter
    if not (param_type.is_struct() and len(param_type) == 2):
        raise TypeError(f'`{name}.next` must take two arguments but found:\n'
                        f'{param_type}')

    clients_norm_type = computation_types.at_clients(NORM_TF_TYPE)
    value_param_type = param_type[1]
    if not value_param_type.is_assignable_from(clients_norm_type):
        raise TypeError(
            f'Second argument of `{name}.next` must be assignable from '
            f'{clients_norm_type} but found {value_param_type}')

    next_result_type = next_signature.result
    if not bound_process.state_type.is_assignable_from(next_result_type):
        raise TypeError(
            f'Result type of `{name}.next` must consist of state only '
            f'but found result type:\n{next_result_type}\n'
            f'while the state type is:\n{bound_process.state_type}')

    report_result_type = bound_process.report.type_signature.result
    server_value_type = computation_types.at_server(value_param_type.member)
    if not report_result_type.is_assignable_from(server_value_type):
        raise TypeError(
            f'Report type of `{name}.report` must be assignable from '
            f'{server_value_type} but found {report_result_type}.')
    def test_type_properties_with_inner_factory(self, value_type, weight_type):
        """Checks `MeanFactory` types when both inner sum factories are given."""
        sum_factory = aggregators_test_utils.SumPlusOneFactory()
        mean_f = mean_factory.MeanFactory(value_sum_factory=sum_factory,
                                          weight_sum_factory=sum_factory)
        self.assertIsInstance(mean_f, factory.WeightedAggregationFactory)
        value_type = computation_types.to_type(value_type)
        weight_type = computation_types.to_type(weight_type)
        process = mean_f.create_weighted(value_type, weight_type)
        self.assertIsInstance(process, aggregation_process.AggregationProcess)

        # The expected state and measurements have one int32 entry per inner
        # sum process.
        expected_state_type = computation_types.at_server(
            collections.OrderedDict(value_sum_process=tf.int32,
                                    weight_sum_process=tf.int32))
        expected_measurements_type = computation_types.at_server(
            collections.OrderedDict(value_sum_process=tf.int32,
                                    weight_sum_process=tf.int32))

        self.assertTrue(
            process.initialize.type_signature.is_equivalent_to(
                computation_types.FunctionType(parameter=None,
                                               result=expected_state_type)))

        expected_next_type = computation_types.FunctionType(
            parameter=collections.OrderedDict(
                state=expected_state_type,
                value=computation_types.at_clients(value_type),
                weight=computation_types.at_clients(weight_type)),
            result=measured_process.MeasuredProcessOutput(
                expected_state_type, computation_types.at_server(value_type),
                expected_measurements_type))
        self.assertTrue(
            process.next.type_signature.is_equivalent_to(expected_next_type))
Exemple #13
0
    def test_federated_aggregate_with_federated_zero_fails(self):
        """`federated_aggregate` rejects a `zero` placed at SERVER.

        `zero` must be assignable to the unplaced int32 accumulator type.
        Passing an int32@SERVER value must raise a `TypeError`.
        """
        zero = intrinsics.federated_value(0, placements.SERVER)

        # Adds one client element to the int32 accumulator.
        @computations.tf_computation([tf.int32, tf.int32])
        def accumulate(accu, elem):
            return accu + elem

        # Merges two partial accumulators by adding them.
        @computations.tf_computation([tf.int32, tf.int32])
        def merge(x, y):
            return x + y

        # Reports the final accumulator unchanged.
        @computations.tf_computation(tf.int32)
        def report(accu):
            return accu

        x = _mock_data_of_type(computation_types.at_clients(tf.int32))
        with self.assertRaisesRegex(
                TypeError, 'Expected `zero` to be assignable to type int32, '
                'but was of incompatible type int32@SERVER'):
            intrinsics.federated_aggregate(x, zero, accumulate, merge, report)
    def test_returns_value_with_intrinsic_def_federated_secure_sum(
            self, client_values, bitwidth, expected_result):
        """Secure sum of `client_values` yields `expected_result`."""
        test_executor = create_test_executor()
        value_type = computation_types.at_clients(
            type_conversions.infer_type(client_values[0]))
        bitwidth_type = type_conversions.infer_type(bitwidth)
        intrinsic, intrinsic_type = create_intrinsic_def_federated_secure_sum(
            value_type.member, bitwidth_type)

        embedded_fn = self.run_sync(
            test_executor.create_value(intrinsic, intrinsic_type))
        embedded_values = self.run_sync(
            test_executor.create_value(client_values, value_type))
        embedded_bitwidth = self.run_sync(
            test_executor.create_value(bitwidth, bitwidth_type))
        call_arg = self.run_sync(
            test_executor.create_struct([embedded_values, embedded_bitwidth]))
        output = self.run_sync(test_executor.create_call(embedded_fn, call_arg))

        self.assertIsInstance(output, executor_value_base.ExecutorValue)
        self.assert_types_identical(output.type_signature,
                                    intrinsic_type.result)
        computed = self.run_sync(output.compute())
        if not isinstance(expected_result, structure.Struct):
            self.assertEqual(computed, expected_result)
            return
        # Structured results are compared element-wise.
        structure.map_structure(self.assertAllEqual, computed, expected_result)
Exemple #15
0
    def test_type_properties(self, encoder_fn):
        """Checks the factory, `initialize` and `next` types of the encoded sum."""
        sum_factory = encoded_factory.EncodedSumFactory(encoder_fn)
        self.assertIsInstance(sum_factory, factory.UnweightedAggregationFactory)

        process = sum_factory.create_unweighted(_test_struct_type)
        self.assertIsInstance(process, aggregation_process.AggregationProcess)

        init_signature = process.initialize.type_signature
        self.assertIsNone(init_signature.parameter)
        state_type = init_signature.result
        # The state has one element per aggregated tensor.
        self.assertLen(state_type.member, 2)
        self.assertEqual(placements.SERVER, state_type.placement)

        next_type = computation_types.FunctionType(
            parameter=collections.OrderedDict(
                state=state_type,
                value=computation_types.at_clients(_test_struct_type)),
            result=measured_process.MeasuredProcessOutput(
                state=state_type,
                result=computation_types.at_server(_test_struct_type),
                measurements=computation_types.at_server(())))
        self.assertTrue(
            process.next.type_signature.is_equivalent_to(next_type))
Exemple #16
0
    def test_with_temperature_sensor_example(self, executor):
        """Computes a weighted mean of above-threshold counts over three clients.

        The client datasets hold 0..9, 0..19 and 0..29. With threshold 15 the
        per-client counts are 0, 4 and 14, weighted by dataset sizes 10, 20
        and 30, which gives (0*10 + 4*20 + 14*30) / 60 = 8.333...
        """
        float_sequence = computation_types.SequenceType(tf.float32)

        @computations.tf_computation(float_sequence, tf.float32)
        def count_over(ds, t):
            return ds.reduce(
                np.float32(0),
                lambda n, x: n + tf.cast(tf.greater(x, t), tf.float32))

        @computations.tf_computation(float_sequence)
        def count_total(ds):
            return ds.reduce(np.float32(0.0), lambda n, _: n + 1.0)

        @computations.federated_computation(
            computation_types.at_clients(float_sequence),
            computation_types.at_server(tf.float32))
        def comp(temperatures, threshold):
            above_threshold = intrinsics.federated_map(
                count_over,
                intrinsics.federated_zip([
                    temperatures,
                    intrinsics.federated_broadcast(threshold)
                ]))
            totals = intrinsics.federated_map(count_total, temperatures)
            return intrinsics.federated_mean(above_threshold, totals)

        with executor_test_utils.install_executor(executor):

            def to_float(x):
                return tf.cast(x, tf.float32)

            temperatures = [
                tf.data.Dataset.range(size).map(to_float)
                for size in (10, 20, 30)
            ]
            result = comp(temperatures, 15.0)
            self.assertAlmostEqual(result, 8.333, places=3)
 def test_at_clients(self):
   """`at_clients` wraps a type in a `FederatedType` placed at CLIENTS."""
   tensor_type = computation_types.TensorType(tf.bool)
   self.assertEqual(
       computation_types.at_clients(tensor_type),
       computation_types.FederatedType(tensor_type, placements.CLIENTS))
Exemple #18
0
def create_whimsy_called_federated_collect(value_type=tf.int32):
    """Returns a `federated_collect` building block over whimsy CLIENTS data.

    Args:
      value_type: Member type of the CLIENTS-placed `Data('data', ...)` block
        that is collected.
    """
    return building_block_factory.create_federated_collect(
        building_blocks.Data('data', computation_types.at_clients(value_type)))
async def compute_intrinsic_federated_weighted_mean(
    executor: executor_base.Executor,
    arg: executor_value_base.ExecutorValue,
    local_computation_factory: local_computation_factory_base.
    LocalComputationFactory = tensorflow_computation_factory.
    TensorFlowComputationFactory()
) -> executor_value_base.ExecutorValue:
    """Computes a federated weighted mean on the given `executor`.

    The mean is assembled from lower-level intrinsics in three stages:

    1. At CLIENTS, the values and weights are zipped and multiplied.
    2. The products and the weights are each summed to SERVER.
    3. The two sums are zipped and divided at SERVER.

    Args:
      executor: The executor to use.
      arg: The argument, already embedded in `executor`. Its type signature
        must pass `check_valid_federated_weighted_mean_argument_tuple_type`.
        Element 0 holds the CLIENTS-placed values and element 1 the
        CLIENTS-placed weights.
      local_computation_factory: An instance of `LocalComputationFactory` used
        to build local computations passed to federated operators (here, the
        per-client multiply). Defaults to a TensorFlow computation factory
        that generates TensorFlow code. The default instance is created once,
        when this function is defined, and is shared by every call that uses
        it.

    Returns:
      The weighted mean, embedded in `executor` and placed at SERVER.
    """
    type_analysis.check_valid_federated_weighted_mean_argument_tuple_type(
        arg.type_signature)
    # <values@CLIENTS, weights@CLIENTS> -> <value, weight>@CLIENTS.
    zip1_type = computation_types.FunctionType(
        computation_types.StructType([
            computation_types.at_clients(arg.type_signature[0].member),
            computation_types.at_clients(arg.type_signature[1].member)
        ]),
        computation_types.at_clients(
            computation_types.StructType(
                [arg.type_signature[0].member, arg.type_signature[1].member])))

    # Per-client `value * weight`, built by the injected factory.
    operand_type = zip1_type.result.member[0]
    scalar_type = zip1_type.result.member[1]
    multiply_comp_pb, multiply_comp_type = local_computation_factory.create_scalar_multiply_operator(
        operand_type, scalar_type)
    multiply_blk = building_blocks.CompiledComputation(
        multiply_comp_pb, type_signature=multiply_comp_type)
    # Maps the multiply over the zipped <value, weight> pairs at CLIENTS.
    map_type = computation_types.FunctionType(
        computation_types.StructType(
            [multiply_blk.type_signature, zip1_type.result]),
        computation_types.at_clients(multiply_blk.type_signature.result))

    # Sums the per-client products to SERVER (the numerator).
    sum1_type = computation_types.FunctionType(
        computation_types.at_clients(map_type.result.member),
        computation_types.at_server(map_type.result.member))

    # Sums the weights to SERVER (the denominator).
    sum2_type = computation_types.FunctionType(
        computation_types.at_clients(arg.type_signature[1].member),
        computation_types.at_server(arg.type_signature[1].member))

    # Zips <sum of products, total weight> at SERVER.
    zip2_type = computation_types.FunctionType(
        computation_types.StructType([sum1_type.result, sum2_type.result]),
        computation_types.at_server(
            computation_types.StructType(
                [sum1_type.result.member, sum2_type.result.member])))

    # NOTE(review): the divide is always built as TensorFlow here and does not
    # go through `local_computation_factory`; confirm this is intended.
    divide_blk = building_block_factory.create_tensorflow_binary_operator_with_upcast(
        zip2_type.result.member, tf.divide)

    # The helpers below wire the stages together. Each helper awaits only its
    # own inputs, and independent inputs are awaited concurrently via
    # `asyncio.gather`.
    async def _compute_multiply_fn():
        return await executor.create_value(multiply_blk.proto,
                                           multiply_blk.type_signature)

    async def _compute_multiply_arg():
        zip1_comp = create_intrinsic_comp(
            intrinsic_defs.FEDERATED_ZIP_AT_CLIENTS, zip1_type)
        zip_fn = await executor.create_value(zip1_comp, zip1_type)
        return await executor.create_call(zip_fn, arg)

    async def _compute_product_fn():
        map_comp = create_intrinsic_comp(intrinsic_defs.FEDERATED_MAP,
                                         map_type)
        return await executor.create_value(map_comp, map_type)

    async def _compute_product_arg():
        multiply_fn, multiply_arg = await asyncio.gather(
            _compute_multiply_fn(), _compute_multiply_arg())
        return await executor.create_struct((multiply_fn, multiply_arg))

    async def _compute_products():
        product_fn, product_arg = await asyncio.gather(_compute_product_fn(),
                                                       _compute_product_arg())
        return await executor.create_call(product_fn, product_arg)

    async def _compute_total_weight():
        sum2_comp = create_intrinsic_comp(intrinsic_defs.FEDERATED_SUM,
                                          sum2_type)
        # Element 1 of `arg` is the CLIENTS-placed weights.
        sum2_fn, sum2_arg = await asyncio.gather(
            executor.create_value(sum2_comp, sum2_type),
            executor.create_selection(arg, 1))
        return await executor.create_call(sum2_fn, sum2_arg)

    async def _compute_sum_of_products():
        sum1_comp = create_intrinsic_comp(intrinsic_defs.FEDERATED_SUM,
                                          sum1_type)
        sum1_fn, products = await asyncio.gather(
            executor.create_value(sum1_comp, sum1_type), _compute_products())
        return await executor.create_call(sum1_fn, products)

    async def _compute_zip2_fn():
        zip2_comp = create_intrinsic_comp(
            intrinsic_defs.FEDERATED_ZIP_AT_SERVER, zip2_type)
        return await executor.create_value(zip2_comp, zip2_type)

    async def _compute_zip2_arg():
        sum_of_products, total_weight = await asyncio.gather(
            _compute_sum_of_products(), _compute_total_weight())
        return await executor.create_struct([sum_of_products, total_weight])

    async def _compute_divide_fn():
        return await executor.create_value(divide_blk.proto,
                                           divide_blk.type_signature)

    async def _compute_divide_arg():
        zip_fn, zip_arg = await asyncio.gather(_compute_zip2_fn(),
                                               _compute_zip2_arg())
        return await executor.create_call(zip_fn, zip_arg)

    async def _compute_apply_fn():
        # Applies the divide to the zipped sums at SERVER.
        apply_type = computation_types.FunctionType(
            computation_types.StructType(
                [divide_blk.type_signature, zip2_type.result]),
            computation_types.at_server(divide_blk.type_signature.result))
        apply_comp = create_intrinsic_comp(intrinsic_defs.FEDERATED_APPLY,
                                           apply_type)
        return await executor.create_value(apply_comp, apply_type)

    async def _compute_apply_arg():
        divide_fn, divide_arg = await asyncio.gather(_compute_divide_fn(),
                                                     _compute_divide_arg())
        return await executor.create_struct([divide_fn, divide_arg])

    async def _compute_divided():
        apply_fn, apply_arg = await asyncio.gather(_compute_apply_fn(),
                                                   _compute_apply_arg())
        return await executor.create_call(apply_fn, apply_arg)

    return await _compute_divided()
Exemple #20
0
def create_dummy_value_at_clients_all_equal():
    """Returns a Python value and federated type at clients and all equal.

    The value is the float `10.0`. The type is an all-equal `tf.float32`
    placed at CLIENTS.
    """
    return 10.0, computation_types.at_clients(tf.float32, all_equal=True)
Exemple #21
0
def create_dummy_value_at_clients(number_of_clients: int = 3):
    """Returns a Python value and federated type at clients.

    Args:
      number_of_clients: How many per-client values to create.

    Returns:
      A tuple `(value, type_signature)`. `value` is the list `[10.0, 11.0, ...]`
      with `number_of_clients` entries. `type_signature` is `tf.float32`
      placed at CLIENTS.
    """
    values = list(map(float, range(10, 10 + number_of_clients)))
    return values, computation_types.at_clients(tf.float32)
Exemple #22
0
def create_dummy_intrinsic_def_federated_value_at_clients():
    """Returns `FEDERATED_VALUE_AT_CLIENTS` and a test function signature.

    The signature maps a `tf.float32` to an all-equal `tf.float32` placed at
    CLIENTS.
    """
    signature = computation_types.FunctionType(
        tf.float32, computation_types.at_clients(tf.float32, all_equal=True))
    return intrinsic_defs.FEDERATED_VALUE_AT_CLIENTS, signature
Exemple #23
0
def create_dummy_intrinsic_def_federated_sum():
    """Returns `FEDERATED_SUM` and a test function signature.

    The signature maps `tf.float32` at CLIENTS to `tf.float32` at SERVER.
    """
    clients_float = computation_types.at_clients(tf.float32)
    server_float = computation_types.at_server(tf.float32)
    return intrinsic_defs.FEDERATED_SUM, computation_types.FunctionType(
        clients_float, server_float)
Exemple #24
0
def create_dummy_intrinsic_def_federated_eval_at_clients():
    """Returns `FEDERATED_EVAL_AT_CLIENTS` and a test function signature.

    The signature takes a no-argument function returning `tf.float32` and
    yields `tf.float32` placed at CLIENTS.
    """
    thunk_type = computation_types.FunctionType(None, tf.float32)
    signature = computation_types.FunctionType(
        thunk_type, computation_types.at_clients(tf.float32))
    return intrinsic_defs.FEDERATED_EVAL_AT_CLIENTS, signature
Exemple #25
0
def create_dummy_intrinsic_def_federated_broadcast():
    """Returns `FEDERATED_BROADCAST` and a test function signature.

    The signature maps `tf.float32` at SERVER to an all-equal `tf.float32`
    placed at CLIENTS.
    """
    server_float = computation_types.at_server(tf.float32)
    clients_float = computation_types.at_clients(tf.float32, all_equal=True)
    return intrinsic_defs.FEDERATED_BROADCAST, computation_types.FunctionType(
        server_float, clients_float)
    async def compute_federated_secure_sum(
        self, arg: federated_resolving_strategy.FederatedResolvingStrategyValue
    ) -> federated_resolving_strategy.FederatedResolvingStrategyValue:
        """Emulates `tff.federated_secure_sum` without any cryptography.

        The summands are processed in five steps:

        1. Cast them to an integer type wide enough for the extended bitwidth.
        2. Mask them to that bitwidth.
        3. Sum them at the server.
        4. Mask the sum again.
        5. Re-embed the result with the original summand member type.

        Args:
          arg: A value whose internal representation is a 2-element
            `structure.Struct` of (client summands, bitwidth). Its type
            signature holds the summands' federated type at index 0 and the
            bitwidth's type at index 1.

        Returns:
          The modular sum embedded in the executor, typed as the summand
          member type placed at SERVER.

        Raises:
          TypeError: If the summands are not integer tensors (or structures
            of them), or if structured summands do not share the structure of
            the bitwidth.
        """
        logging.warning(
            'The implementation of the `tff.federated_secure_sum` intrinsic '
            'provided by the `tff.backends.test` runtime uses no cryptography.'
        )
        py_typecheck.check_type(arg.internal_representation, structure.Struct)
        py_typecheck.check_len(arg.internal_representation, 2)
        summands, bitwidth = await asyncio.gather(
            self.ingest_value(arg.internal_representation[0],
                              arg.type_signature[0]).compute(),
            self.ingest_value(arg.internal_representation[1],
                              arg.type_signature[1]).compute())
        summands_type = arg.type_signature[0].member
        if not type_analysis.is_structure_of_integers(summands_type):
            raise TypeError(
                'Cannot compute `federated_secure_sum` on summands that are not '
                'TensorType or StructType of TensorType. Got {t}'.format(
                    t=repr(summands_type)))
        if (summands_type.is_struct()
                and not structure.is_same_structure(summands_type, bitwidth)):
            # `bitwidth` is the computed value, not a typed executor value, so
            # its type is read from `arg.type_signature` instead.
            raise TypeError(
                'Cannot compute `federated_secure_sum` if summands and bitwidth are '
                'not the same structure. Got summands={s}, bitwidth={b}'.
                format(s=repr(summands_type), b=repr(arg.type_signature[1])))

        num_additional_bits = await self._compute_extra_bits_for_secagg()
        # Clamp to 64 bits, otherwise we can't represent the mask in TensorFlow.
        extended_bitwidth = _map_numpy_or_structure(
            bitwidth, fn=lambda b: min(b.numpy() + num_additional_bits, 64))
        logging.debug('Emulated secure sum effective bitwidth: %s',
                      extended_bitwidth)
        # Now we need to cast the summands into the integral type that is large
        # enough to represent the sum and the mask.
        summation_type_spec = _compute_summation_type_for_bitwidth(
            extended_bitwidth, summands_type)
        # `summands` is a list of all clients' summands. We map
        # `_extract_numpy_arrays` over the list, applying it pointwise to clients.
        summand_tensors = tf.nest.map_structure(_extract_numpy_arrays,
                                                summands)
        # Dtype conversion trick: pull the summand values out, and push them back
        # into the executor using the new dtypes decided based on bitwidth.
        casted_summands = await self._executor.create_value(
            summand_tensors, computation_types.at_clients(summation_type_spec))
        # To emulate SecAgg without the random masks, we must mask the summands to
        # the effective bitwidth. This isn't strictly necessary because we also
        # mask the sum result and modulus operator is distributive, but this more
        # accurately reflects the system.
        mask = await self._embed_tf_secure_sum_mask_value(
            summation_type_spec, extended_bitwidth)
        masked_summands = await self._compute_modulus(casted_summands, mask)
        # NOTE(review): this forces evaluation of the masked summands even when
        # debug logging is disabled.
        logging.debug('Computed masked modular summands as: %s', await
                      masked_summands.compute())
        # Then perform the sum and modulo operation (using powers of 2 bitmasking)
        # on the sum, using the computed effective bitwidth.
        sum_result = await self.compute_federated_sum(masked_summands)
        modular_sums = await self._compute_modulus(sum_result, mask)
        # Dtype conversion trick again, pull the modular sum values out, and push
        # them back into the executor using the dtype from the summands.
        modular_sum_values = _extract_numpy_arrays(await
                                                   modular_sums.compute())
        logging.debug('Computed modular sums as: %s', modular_sum_values)
        return await self._executor.create_value(
            modular_sum_values, computation_types.at_server(summands_type))
Exemple #27
0
def _encoded_next_fn(server_state_type, value_type, encoders):
    """Creates `next_fn` for the process returned by `EncodedSumFactory`.

    The structure of the implementation is roughly as follows:
    * Extract params for encoding/decoding from state (`get_params_fn`).
    * Encode values to be aggregated, placed at clients (`encode_fn`).
    * Call `federated_aggregate` operator, with decoding of the part which does
      not commute with sum, placed in its `accumulate_fn` arg.
    * Finish decoding the summed value placed at server (`decode_after_sum_fn`).
    * Update the state placed at server (`update_state_fn`).

    The TF computations are built in dependency order: later computations use
    `type_signature`s of earlier ones as their parameter types.

    Args:
      server_state_type: A `tff.Type` of the expected state placed at server.
      value_type: An unplaced `tff.Type` of the value to be aggregated.
      encoders: A collection of `GatherEncoder` objects.

    Returns:
      A `tff.Computation` for `EncodedSumFactory`, with the type signature of
      `(server_state_type, value_type@CLIENTS) ->
      MeasuredProcessOutput(server_state_type, value_type@SERVER, ()@SERVER)`
    """
    # Calls `get_params` on every encoder, pairing each encoder with the
    # matching substructure of `state` (mapping up to the structure of
    # `encoders`), then splits the per-encoder results into three groups.
    # NOTE(review): `_slice` is defined elsewhere; presumably it selects the
    # i-th element of each encoder's result tuple -- confirm.
    @computations.tf_computation(server_state_type.member)
    def get_params_fn(state):
        params = tree.map_structure_up_to(encoders,
                                          lambda e, s: e.get_params(s),
                                          encoders, state)
        encode_params = _slice(encoders, params, 0)
        decode_before_sum_params = _slice(encoders, params, 1)
        decode_after_sum_params = _slice(encoders, params, 2)
        return encode_params, decode_before_sum_params, decode_after_sum_params

    # Types of the three parameter groups, used to type later computations.
    encode_params_type = get_params_fn.type_signature.result[0]
    decode_before_sum_params_type = get_params_fn.type_signature.result[1]
    decode_after_sum_params_type = get_params_fn.type_signature.result[2]

    # TODO(b/139844355): Get rid of decode_before_sum_params.
    # We pass decode_before_sum_params to the encode method, because TFF currently
    # does not have a mechanism to make a tff.SERVER placed value available inside
    # of intrinsics.federated_aggregate - in production, this could mean an
    # intermediary aggregator node. So currently, we send the params to clients,
    # and ask them to send them back as part of the encoded structure.
    # Returns a 3-tuple: (encoded value, echoed decode_before_sum_params,
    # state update tensors produced by the encoders).
    @computations.tf_computation(value_type, encode_params_type,
                                 decode_before_sum_params_type)
    def encode_fn(x, encode_params, decode_before_sum_params):
        encoded_structure = tree.map_structure_up_to(
            encoders, lambda e, *args: e.encode(*args), encoders, x,
            encode_params)
        encoded_x = _slice(encoders, encoded_structure, 0)
        state_update_tensors = _slice(encoders, encoded_structure, 1)
        return encoded_x, decode_before_sum_params, state_update_tensors

    # Type of the third element of `encode_fn`'s result.
    state_update_tensors_type = encode_fn.type_signature.result[2]

    # This is not a @computations.tf_computation because it will be used below
    # when building the computations.tf_computations that will compose an
    # intrinsics.federated_aggregate...
    def decode_before_sum_tf_function(encoded_x, decode_before_sum_params):
        part_decoded_x = tree.map_structure_up_to(
            encoders, lambda e, *args: e.decode_before_sum(*args), encoders,
            encoded_x, decode_before_sum_params)
        # An int32 tensor of shape [1] holding 1. It is added together along
        # with the values in accumulate_fn/merge_fn, so after aggregation it
        # counts the summands; decode_after_sum_fn reads it as `num_summands`.
        one = tf.constant((1, ), tf.int32)
        return part_decoded_x, one

    # ...however, result type is needed to build the subsequent tf_computations.
    # `result[0:2]` is (encoded_x, decode_before_sum_params) from encode_fn.
    @computations.tf_computation(encode_fn.type_signature.result[0:2])
    def tmp_decode_before_sum_fn(encoded_x, decode_before_sum_params):
        return decode_before_sum_tf_function(encoded_x,
                                             decode_before_sum_params)

    # The (part_decoded_x, one) pair type; also the type of the summed values.
    part_decoded_x_type = tmp_decode_before_sum_fn.type_signature.result
    del tmp_decode_before_sum_fn  # Only needed for result type.

    # Runs at the server on the aggregated sum: unpacks the summed values and
    # the summand count, then applies each encoder's `decode_after_sum`.
    @computations.tf_computation(part_decoded_x_type,
                                 decode_after_sum_params_type)
    def decode_after_sum_fn(summed_values, decode_after_sum_params):
        part_decoded_aggregated_x, num_summands = summed_values
        return tree.map_structure_up_to(
            encoders,
            lambda e, x, params: e.decode_after_sum(x, params, num_summands),
            encoders, part_decoded_aggregated_x, decode_after_sum_params)

    # Produces the next server state by calling each encoder's `update_state`
    # with its slice of the state and of the aggregated state update tensors.
    @computations.tf_computation(server_state_type.member,
                                 state_update_tensors_type)
    def update_state_fn(state, state_update_tensors):
        return tree.map_structure_up_to(encoders,
                                        lambda e, *args: e.update_state(*args),
                                        encoders, state, state_update_tensors)

    # Computations for intrinsics.federated_aggregate.
    # The accumulator is an OrderedDict with keys `values` and
    # `state_update_tensors`.
    def _accumulator_value(values, state_update_tensors):
        return collections.OrderedDict(
            values=values, state_update_tensors=state_update_tensors)

    # Initial accumulator: zero tensors whose shapes/dtypes come from the
    # part-decoded value type and the state update tensors type.
    @computations.tf_computation
    def zero_fn():
        values = tf.nest.map_structure(
            lambda s: tf.zeros(s.shape, s.dtype),
            type_conversions.type_to_tf_tensor_specs(part_decoded_x_type))
        state_update_tensors = tf.nest.map_structure(
            lambda s: tf.zeros(s.shape, s.dtype),
            type_conversions.type_to_tf_tensor_specs(
                state_update_tensors_type))
        return _accumulator_value(values, state_update_tensors)

    accumulator_type = zero_fn.type_signature.result
    # One tuple of aggregation modes per encoder. It is passed to
    # `_accmulate_state_update_tensor` (defined elsewhere) to decide how each
    # state update tensor is combined.
    state_update_aggregation_modes = tf.nest.map_structure(
        lambda e: tuple(e.state_update_aggregation_modes), encoders)

    # Folds one client's encoded output into the accumulator. The part of
    # decoding that does not commute with summation happens here, before the
    # values are added.
    @computations.tf_computation(accumulator_type,
                                 encode_fn.type_signature.result)
    def accumulate_fn(acc, encoded_x):
        value, params, state_update_tensors = encoded_x
        part_decoded_value = decode_before_sum_tf_function(value, params)
        new_values = tf.nest.map_structure(tf.add, acc['values'],
                                           part_decoded_value)
        new_state_update_tensors = tf.nest.map_structure(
            _accmulate_state_update_tensor, acc['state_update_tensors'],
            state_update_tensors, state_update_aggregation_modes)
        return _accumulator_value(new_values, new_state_update_tensors)

    # Combines two partial accumulators the same way accumulate_fn does.
    @computations.tf_computation(accumulator_type, accumulator_type)
    def merge_fn(acc1, acc2):
        new_values = tf.nest.map_structure(tf.add, acc1['values'],
                                           acc2['values'])
        new_state_update_tensors = tf.nest.map_structure(
            _accmulate_state_update_tensor, acc1['state_update_tensors'],
            acc2['state_update_tensors'], state_update_aggregation_modes)
        return _accumulator_value(new_values, new_state_update_tensors)

    # Identity: the final accumulator is the aggregation result.
    @computations.tf_computation(accumulator_type)
    def report_fn(acc):
        return acc

    @computations.federated_computation(
        server_state_type, computation_types.at_clients(value_type))
    def next_fn(state, value):
        encode_params, decode_before_sum_params, decode_after_sum_params = (
            intrinsics.federated_map(get_params_fn, state))
        # Only the params needed at clients are broadcast;
        # decode_after_sum_params stays at the server, where it is used.
        encode_params = intrinsics.federated_broadcast(encode_params)
        decode_before_sum_params = intrinsics.federated_broadcast(
            decode_before_sum_params)

        encoded_values = intrinsics.federated_map(
            encode_fn, [value, encode_params, decode_before_sum_params])

        aggregated_values = intrinsics.federated_aggregate(
            encoded_values, zero_fn(), accumulate_fn, merge_fn, report_fn)

        decoded_values = intrinsics.federated_map(
            decode_after_sum_fn,
            [aggregated_values.values, decode_after_sum_params])

        updated_state = intrinsics.federated_map(
            update_state_fn, [state, aggregated_values.state_update_tensors])

        empty_metrics = intrinsics.federated_value((), placements.SERVER)
        return measured_process.MeasuredProcessOutput(
            state=updated_state,
            result=decoded_values,
            measurements=empty_metrics)

    return next_fn
@computations.federated_computation()
def test_init_fn():
    """Produces the initial test state: the integer 0 placed at SERVER."""
    initial_state = 0
    return intrinsics.federated_value(initial_state, placements.SERVER)


# Federated type of the value returned by `test_init_fn` (0 placed at SERVER).
test_state_type = test_init_fn.type_signature.result


@computations.tf_computation
def sum_sequence(s):
    """Adds up all elements of sequence `s`, starting from a zeros tensor.

    The zeros tensor takes its shape and dtype from `s.element_spec`.
    """
    element_spec = s.element_spec
    initial_total = tf.zeros(element_spec.shape, element_spec.dtype)
    return s.reduce(
        initial_total,
        lambda total, element: tf.nest.map_structure(tf.add, total, element))


# Federated type: a sequence of int32 elements placed at CLIENTS.
ClientIntSequenceType = computation_types.at_clients(
    computation_types.SequenceType(tf.int32))


def build_next_fn(server_init_fn):
    """Builds a federated `next_fn` whose state type matches `server_init_fn`.

    The returned computation sums each client's int32 sequence, sums those
    per-client totals across clients, and returns the unchanged state along
    with that total as the metrics of a `LearningProcessOutput`.
    """
    state_type = server_init_fn.type_signature.result

    @computations.federated_computation(state_type, ClientIntSequenceType)
    def next_fn(state, client_values):
        per_client_totals = intrinsics.federated_map(sum_sequence,
                                                     client_values)
        grand_total = intrinsics.federated_sum(per_client_totals)
        return LearningProcessOutput(state, grand_total)

    return next_fn


def build_report_fn(server_init_fn):
    @computations.tf_computation(server_init_fn.type_signature.result.member)
async def compute_intrinsic_federated_weighted_mean(
    executor: executor_base.Executor, arg: executor_value_base.ExecutorValue
) -> executor_value_base.ExecutorValue:
    """Computes a federated weighted mean on the given `executor`.

    The mean is assembled from lower-level intrinsics:
      1. `FEDERATED_ZIP_AT_CLIENTS` pairs each client's value with its weight.
      2. `FEDERATED_MAP` multiplies each pair (with dtype upcasting).
      3. Two `FEDERATED_SUM`s total the products and the weights.
      4. `FEDERATED_ZIP_AT_SERVER` pairs the two totals.
      5. `FEDERATED_APPLY` divides them (with dtype upcasting).

    Args:
      executor: The executor to use.
      arg: The argument embedded in `executor`. Its two elements are treated
        as client-placed values and weights, respectively.

    Returns:
      The result embedded in `executor`.
    """
    type_analysis.check_valid_federated_weighted_mean_argument_tuple_type(
        arg.type_signature)
    value_member = arg.type_signature[0].member
    weight_member = arg.type_signature[1].member

    # --- Type signatures of every intrinsic used below. ---
    client_zip_type = computation_types.FunctionType(
        computation_types.StructType([
            computation_types.at_clients(value_member),
            computation_types.at_clients(weight_member)
        ]),
        computation_types.at_clients(
            computation_types.StructType([value_member, weight_member])))

    multiply_op = building_block_factory.create_tensorflow_binary_operator_with_upcast(
        client_zip_type.result.member, tf.multiply)

    client_map_type = computation_types.FunctionType(
        computation_types.StructType(
            [multiply_op.type_signature, client_zip_type.result]),
        computation_types.at_clients(multiply_op.type_signature.result))

    products_sum_type = computation_types.FunctionType(
        computation_types.at_clients(client_map_type.result.member),
        computation_types.at_server(client_map_type.result.member))

    weights_sum_type = computation_types.FunctionType(
        computation_types.at_clients(weight_member),
        computation_types.at_server(weight_member))

    server_zip_type = computation_types.FunctionType(
        computation_types.StructType(
            [products_sum_type.result, weights_sum_type.result]),
        computation_types.at_server(
            computation_types.StructType([
                products_sum_type.result.member,
                weights_sum_type.result.member
            ])))

    divide_op = building_block_factory.create_tensorflow_binary_operator_with_upcast(
        server_zip_type.result.member, tf.divide)

    # --- Async pipeline, defined top-down from the final result. ---
    async def _apply_divide():
        """Step 5: federated_apply(divide, zipped server totals)."""
        fn, fn_arg = await asyncio.gather(_embed_apply(), _build_apply_arg())
        return await executor.create_call(fn, fn_arg)

    async def _embed_apply():
        apply_type = computation_types.FunctionType(
            computation_types.StructType(
                [divide_op.type_signature, server_zip_type.result]),
            computation_types.at_server(divide_op.type_signature.result))
        apply_intrinsic = create_intrinsic_comp(
            intrinsic_defs.FEDERATED_APPLY, apply_type)
        return await executor.create_value(apply_intrinsic, apply_type)

    async def _build_apply_arg():
        divide_value, zipped_totals = await asyncio.gather(
            _embed_divide_op(), _zip_at_server())
        return await executor.create_struct([divide_value, zipped_totals])

    async def _embed_divide_op():
        return await executor.create_value(divide_op.proto,
                                           divide_op.type_signature)

    async def _zip_at_server():
        """Step 4: zip the sum of products with the total weight."""
        fn, fn_arg = await asyncio.gather(_embed_server_zip(),
                                          _build_server_zip_arg())
        return await executor.create_call(fn, fn_arg)

    async def _embed_server_zip():
        server_zip_intrinsic = create_intrinsic_comp(
            intrinsic_defs.FEDERATED_ZIP_AT_SERVER, server_zip_type)
        return await executor.create_value(server_zip_intrinsic,
                                           server_zip_type)

    async def _build_server_zip_arg():
        products_total, weights_total = await asyncio.gather(
            _sum_products(), _sum_weights())
        return await executor.create_struct([products_total, weights_total])

    async def _sum_products():
        """Step 3a: federated_sum over the per-client products."""
        products_sum_intrinsic = create_intrinsic_comp(
            intrinsic_defs.FEDERATED_SUM, products_sum_type)
        fn, products = await asyncio.gather(
            executor.create_value(products_sum_intrinsic, products_sum_type),
            _map_products())
        return await executor.create_call(fn, products)

    async def _sum_weights():
        """Step 3b: federated_sum over the client weights (element 1)."""
        weights_sum_intrinsic = create_intrinsic_comp(
            intrinsic_defs.FEDERATED_SUM, weights_sum_type)
        fn, weights = await asyncio.gather(
            executor.create_value(weights_sum_intrinsic, weights_sum_type),
            executor.create_selection(arg, index=1))
        return await executor.create_call(fn, weights)

    async def _map_products():
        """Step 2: federated_map(multiply, zipped client pairs)."""
        fn, fn_arg = await asyncio.gather(_embed_client_map(),
                                          _build_map_arg())
        return await executor.create_call(fn, fn_arg)

    async def _embed_client_map():
        client_map_intrinsic = create_intrinsic_comp(
            intrinsic_defs.FEDERATED_MAP, client_map_type)
        return await executor.create_value(client_map_intrinsic,
                                           client_map_type)

    async def _build_map_arg():
        multiply_value, zipped_pairs = await asyncio.gather(
            _embed_multiply_op(), _zip_at_clients())
        return await executor.create_struct((multiply_value, zipped_pairs))

    async def _embed_multiply_op():
        return await executor.create_value(multiply_op.proto,
                                           multiply_op.type_signature)

    async def _zip_at_clients():
        """Step 1: zip each client's value with its weight."""
        client_zip_intrinsic = create_intrinsic_comp(
            intrinsic_defs.FEDERATED_ZIP_AT_CLIENTS, client_zip_type)
        fn = await executor.create_value(client_zip_intrinsic,
                                         client_zip_type)
        return await executor.create_call(fn, arg)

    return await _apply_divide()
def _clipped_sum(clip=2.0):
  """Returns a `ClippingFactory` with threshold `clip` around a `SumFactory`."""
  inner_factory = sum_factory.SumFactory()
  return clipping_factory.ClippingFactory(clip, inner_factory)


def _zeroed_mean(clip=2.0, norm_order=2.0):
  """Returns a `ZeroingFactory` around a `MeanFactory`.

  `clip` and `norm_order` are forwarded positionally to `ZeroingFactory`.
  """
  inner_factory = mean_factory.MeanFactory()
  return clipping_factory.ZeroingFactory(clip, inner_factory, norm_order)


def _zeroed_sum(clip=2.0, norm_order=2.0):
  """Returns a `ZeroingFactory` around a `SumFactory`.

  `clip` and `norm_order` are forwarded positionally to `ZeroingFactory`.
  """
  inner_factory = sum_factory.SumFactory()
  return clipping_factory.ZeroingFactory(clip, inner_factory, norm_order)


# Shorthand federated types: tf.float32 placed at SERVER and at CLIENTS.
_float_at_server = computation_types.at_server(tf.float32)
_float_at_clients = computation_types.at_clients(tf.float32)


@computations.federated_computation()
def _test_init_fn():
  """Produces the initial test state: the float 1.0 placed at SERVER."""
  initial_state = 1.0
  return intrinsics.federated_value(initial_state, placements.SERVER)


@computations.federated_computation(_float_at_server, _float_at_clients)
def _test_next_fn(state, value):
  """Ignores the client `value` and returns the server `state` plus 1.0."""
  del value  # Unused by this test computation.
  increment = computations.tf_computation(lambda x: x + 1., tf.float32)
  return intrinsics.federated_map(increment, state)


@computations.federated_computation(_float_at_server)