コード例 #1
0
    def test_type_properties(self, modulus, value_type, symmetric_range):
        """Checks types and secure-only aggregation of a modular-sum process.

        Builds a `SecureModularSumFactory`, creates a process for `value_type`,
        verifies the `initialize` and `next` type signatures, and verifies that
        `next` contains no non-secure aggregation.
        """
        factory_ = secure.SecureModularSumFactory(
            modulus=modulus, symmetric_range=symmetric_range)
        self.assertIsInstance(factory_, factory.UnweightedAggregationFactory)
        value_type = computation_types.to_type(value_type)
        process = factory_.create(value_type)
        self.assertIsInstance(process, aggregation_process.AggregationProcess)

        # Both state and measurements are expected to be an empty tuple placed
        # at the server.
        expected_state_type = computation_types.at_server(
            computation_types.to_type(()))
        expected_measurements_type = expected_state_type

        expected_initialize_type = computation_types.FunctionType(
            parameter=None, result=expected_state_type)
        self.assertTrue(
            process.initialize.type_signature.is_equivalent_to(
                expected_initialize_type))

        expected_next_type = computation_types.FunctionType(
            parameter=collections.OrderedDict(
                state=expected_state_type,
                value=computation_types.at_clients(value_type)),
            result=measured_process.MeasuredProcessOutput(
                state=expected_state_type,
                result=computation_types.at_server(value_type),
                measurements=expected_measurements_type))
        self.assertTrue(
            process.next.type_signature.is_equivalent_to(expected_next_type))
        try:
            static_assert.assert_not_contains_unsecure_aggregation(
                process.next)
        except AssertionError as e:
            # The checker reports violations via AssertionError. Surface its
            # message instead of discarding it. Other exceptions propagate as
            # test errors rather than being masked as failures.
            self.fail('Factory returned an AggregationProcess containing '
                      f'non-secure aggregation: {e}')
コード例 #2
0
 def test_weighted_secure_aggregator_only_contains_secure_aggregation(self):
     """The weighted secure aggregator's `next` uses only secure aggregation."""
     aggregator = model_update_aggregator.secure_aggregator(
         weighted=True).create(_float_matrix_type, _float_type)
     try:
         static_assert.assert_not_contains_unsecure_aggregation(
             aggregator.next)
     except AssertionError as e:
         # Keep the checker's diagnostic instead of swallowing it with a bare
         # `except`, which would also catch KeyboardInterrupt/SystemExit.
         self.fail(f'Secure aggregator contains non-secure aggregation: {e}')
コード例 #3
0
 def test_no_unsecure_aggregation_with_secure_aggregator(self):
     """FedSGD built from secure aggregators contains no non-secure aggregation."""
     secure_model_agg = model_update_aggregator.secure_aggregator()
     fed_sgd_process = fed_sgd.build_fed_sgd(
         model_examples.LinearRegression,
         model_aggregator=secure_model_agg,
         metrics_aggregator=aggregator.secure_sum_then_finalize)
     next_computation = fed_sgd_process.next
     static_assert.assert_not_contains_unsecure_aggregation(next_computation)
コード例 #4
0
 def test_ddp_secure_aggregator_only_contains_secure_aggregation(self):
     """The DDP secure aggregator's `next` uses only secure aggregation."""
     aggregator = model_update_aggregator.ddp_secure_aggregator(
         noise_multiplier=1e-2,
         expected_clients_per_round=10).create(_float_matrix_type)
     try:
         static_assert.assert_not_contains_unsecure_aggregation(
             aggregator.next)
     except AssertionError as e:
         # Report the checker's message rather than discarding it. A bare
         # `except` would also convert interrupts into test failures.
         self.fail(f'Secure aggregator contains non-secure aggregation: {e}')
コード例 #5
0
 def test_secure_sum(self, dp_mechanism):
     """Hierarchical histogram with secure sum enabled is secure-only."""
     histogram_kwargs = dict(
         lower_bound=0,
         upper_bound=10,
         num_bins=5,
         dp_mechanism=dp_mechanism,
         enable_secure_sum=True)
     histogram_computation = hihi.build_hierarchical_histogram_computation(
         **histogram_kwargs)
     static_assert.assert_not_contains_unsecure_aggregation(
         histogram_computation)
コード例 #6
0
ファイル: mime_test.py プロジェクト: tensorflow/federated
 def test_unweighted_mime_lite_with_only_secure_aggregation(self):
     """Unweighted Mime Lite built from secure aggregators is secure-only."""
     secure_agg = model_update_aggregator.secure_aggregator(weighted=False)
     momentum_optimizer = sgdm.build_sgdm(learning_rate=0.01, momentum=0.9)
     mime_process = mime.build_unweighted_mime_lite(
         model_examples.LinearRegression,
         base_optimizer=momentum_optimizer,
         # The same secure factory aggregates both model updates and the
         # full gradient.
         model_aggregator=secure_agg,
         full_gradient_aggregator=secure_agg,
         metrics_aggregator=metrics_aggregator.secure_sum_then_finalize)
     static_assert.assert_not_contains_unsecure_aggregation(mime_process.next)
コード例 #7
0
ファイル: fed_avg_test.py プロジェクト: tensorflow/federated
 def test_unweighted_fed_avg_with_only_secure_aggregation(self):
     """Unweighted FedAvg with secure aggregators is secure-only."""
     secure_model_agg = model_update_aggregator.secure_aggregator(
         weighted=False)
     fed_avg_process = fed_avg.build_unweighted_fed_avg(
         model_examples.LinearRegression,
         client_optimizer_fn=lambda: tf.keras.optimizers.SGD(1.0),
         model_aggregator=secure_model_agg,
         metrics_aggregator=aggregator.secure_sum_then_finalize)
     next_fn = fed_avg_process.next
     static_assert.assert_not_contains_unsecure_aggregation(next_fn)
コード例 #8
0
 def test_secure_estimation_true_only_contains_secure_aggregation(self):
     """With `secure_estimation=True`, `next` uses only secure aggregation."""
     secure_process = QEProcess.no_noise(initial_estimate=1.0,
                                         target_quantile=0.5,
                                         learning_rate=1.0,
                                         secure_estimation=True)
     try:
         static_assert.assert_not_contains_unsecure_aggregation(
             secure_process.next)
     except AssertionError as e:
         # Catch only the checker's AssertionError and keep its message.
         self.fail(f'Computation contains non-secure aggregation: {e}')
コード例 #9
0
 def test_construction_with_only_secure_aggregation(self):
     """Scheduled-optimizer FedAvg with secure aggregators is secure-only."""
     build_process = (
         fed_avg_with_optimizer_schedule
         .build_weighted_fed_avg_with_optimizer_schedule)
     secure_model_agg = model_update_aggregator.secure_aggregator(
         weighted=True)
     scheduled_process = build_process(
         model_examples.LinearRegression,
         client_learning_rate_fn=lambda x: 0.5,
         client_optimizer_fn=tf.keras.optimizers.SGD,
         model_aggregator=secure_model_agg,
         metrics_aggregator=aggregator.secure_sum_then_finalize)
     static_assert.assert_not_contains_unsecure_aggregation(
         scheduled_process.next)
コード例 #10
0
ファイル: fed_avg_test.py プロジェクト: tensorflow/federated
 def test_weighted_fed_avg_with_only_secure_aggregation(self):
     """Weighted FedAvg over a functional model is secure-only."""
     functional_model = test_models.build_functional_linear_regression()
     client_optimizer = sgdm.build_sgdm(learning_rate=0.1)
     secure_model_agg = model_update_aggregator.secure_aggregator(
         weighted=True)
     weighted_process = fed_avg.build_weighted_fed_avg(
         model_fn=None,
         model=functional_model,
         client_optimizer_fn=client_optimizer,
         model_aggregator=secure_model_agg,
         metrics_aggregator=aggregator.secure_sum_then_finalize)
     static_assert.assert_not_contains_unsecure_aggregation(
         weighted_process.next)
コード例 #11
0
    def test_default_value_ranges_returns_correct_results(
            self, metric_finalizers, local_unfinalized_metrics_at_clients,
            expected_aggregated_metrics):
        """Checks `secure_sum_then_finalize` with its default value ranges.

        Verifies three things. The built computation contains only secure
        aggregation. Each secure-sum factory reports zero clipped values at the
        default bounds. The finalized metrics match
        `expected_aggregated_metrics`.
        """
        aggregator_computation = aggregator.secure_sum_then_finalize(
            metric_finalizers=metric_finalizers,
            local_unfinalized_metrics_type=type_conversions.type_from_tensors(
                local_unfinalized_metrics_at_clients[0]))
        try:
            static_assert.assert_not_contains_unsecure_aggregation(
                aggregator_computation)
        except AssertionError as e:
            # Keep the checker's diagnostic instead of swallowing it with a
            # bare `except`.
            self.fail(
                'Metric aggregation contains non-secure summation aggregation: '
                f'{e}')

        aggregated_metrics = aggregator_computation(
            local_unfinalized_metrics_at_clients)

        # Measurements expected from each secure-sum factory when nothing is
        # clipped at the default bounds.
        no_clipped_values = collections.OrderedDict(
            secure_upper_clipped_count=0,
            secure_lower_clipped_count=0,
            secure_upper_threshold=aggregator.DEFAULT_SECURE_UPPER_BOUND,
            secure_lower_threshold=aggregator.DEFAULT_SECURE_LOWER_BOUND)

        # Collect the distinct factory keys, built from (lower, upper, dtype),
        # over the leaves of one client's metrics. The dict serves as an
        # insertion-ordered set; its values are unused.
        factory_keys = collections.OrderedDict()
        for value in tf.nest.flatten(local_unfinalized_metrics_at_clients[0]):
            tensor = tf.constant(value)
            if tensor.dtype.is_floating:
                lower = float(aggregator.DEFAULT_SECURE_LOWER_BOUND)
                upper = float(aggregator.DEFAULT_SECURE_UPPER_BOUND)
            elif tensor.dtype.is_integer:
                lower = int(aggregator.DEFAULT_SECURE_LOWER_BOUND)
                upper = int(aggregator.DEFAULT_SECURE_UPPER_BOUND)
            else:
                raise TypeError(
                    f'Expected float or int, found tensors of dtype {tensor.dtype}.'
                )
            factory_key = aggregator._create_factory_key(
                lower, upper, tensor.dtype)
            factory_keys[factory_key] = 1

        expected_measurements = collections.OrderedDict(
            (factory_key, no_clipped_values) for factory_key in factory_keys)
        secure_sum_measurements = aggregated_metrics.pop(
            'secure_sum_measurements')
        self.assertAllClose(secure_sum_measurements, expected_measurements)
        self.assertAllClose(aggregated_metrics,
                            expected_aggregated_metrics,
                            rtol=1e-5,
                            atol=1e-5)
コード例 #12
0
    def test_type_properties_adaptive_bounds(self, value_type, dtype):
        """Checks types and secure-only aggregation with adaptive bounds.

        Builds a `SecureSumFactory` whose upper and lower thresholds come from
        test estimation processes. Then verifies the process type signatures
        and that `next` contains no non-secure aggregation.
        """
        upper_bound_process = _test_estimation_process(1)
        lower_bound_process = _test_estimation_process(-1)
        secure_sum_f = secure.SecureSumFactory(
            upper_bound_threshold=upper_bound_process,
            lower_bound_threshold=lower_bound_process)
        self.assertIsInstance(secure_sum_f,
                              factory.UnweightedAggregationFactory)
        value_type = computation_types.to_type(value_type)
        process = secure_sum_f.create(value_type)
        self.assertIsInstance(process, aggregation_process.AggregationProcess)

        # The expected state is a pair of thresholds at the server, each typed
        # as the member type of the estimation process's `report` result.
        threshold_type = upper_bound_process.report.type_signature.result.member
        expected_state_type = computation_types.at_server(
            computation_types.to_type((threshold_type, threshold_type)))
        expected_measurements_type = _measurements_type(dtype)

        expected_initialize_type = computation_types.FunctionType(
            parameter=None, result=expected_state_type)
        self.assertTrue(
            process.initialize.type_signature.is_equivalent_to(
                expected_initialize_type))

        expected_next_type = computation_types.FunctionType(
            parameter=collections.OrderedDict(
                state=expected_state_type,
                value=computation_types.at_clients(value_type)),
            result=measured_process.MeasuredProcessOutput(
                state=expected_state_type,
                result=computation_types.at_server(value_type),
                measurements=expected_measurements_type))
        self.assertTrue(
            process.next.type_signature.is_equivalent_to(expected_next_type))
        try:
            static_assert.assert_not_contains_unsecure_aggregation(
                process.next)
        except AssertionError as e:
            # Surface the checker's message instead of swallowing it with a
            # bare `except`.
            self.fail('Factory returned an AggregationProcess containing '
                      f'non-secure aggregation: {e}')
コード例 #13
0
 def test_fails_on_bothagg(self):
     """A computation mixing secure and non-secure aggregation is rejected."""
     self.assertRaises(
         AssertionError,
         static_assert.assert_not_contains_unsecure_aggregation,
         secure_and_unsecure_aggregation)
コード例 #14
0
 def test_passes_on_secagg(self):
     """A computation using only secure aggregation passes the check."""
     check_secure_only = static_assert.assert_not_contains_unsecure_aggregation
     check_secure_only(secure_aggregation)
コード例 #15
0
    def test_type_properties(self, value_type, mechanism):
        """Checks types and secure-only aggregation of the DDP factory process.

        Builds the same factory chain directly from its component factories:
        concat -> Hadamard rotation -> L2 clipping -> discretization ->
        DP -> secure modular sum. Then compares type signatures against it and
        verifies that `next` contains no non-secure aggregation.
        """
        ddp_factory = _make_test_factory(mechanism=mechanism)
        self.assertIsInstance(ddp_factory,
                              factory.UnweightedAggregationFactory)
        value_type = computation_types.to_type(value_type)
        process = ddp_factory.create(value_type)
        self.assertIsInstance(process, aggregation_process.AggregationProcess)

        # The state is a nested object with component factory states. Construct
        # test factories directly and compare the signatures.
        modsum_f = secure.SecureModularSumFactory(2**15, True)

        if mechanism == 'distributed_dgauss':
            dp_query = tfp.DistributedDiscreteGaussianSumQuery(
                l2_norm_bound=10.0, local_stddev=10.0)
        else:
            dp_query = tfp.DistributedSkellamSumQuery(l1_norm_bound=10.0,
                                                      l2_norm_bound=10.0,
                                                      local_stddev=10.0)

        dp_f = differential_privacy.DifferentiallyPrivateFactory(
            dp_query, modsum_f)
        discrete_f = discretization.DiscretizationFactory(dp_f)
        l2clip_f = robust.clipping_factory(clipping_norm=10.0,
                                           inner_agg_factory=discrete_f)
        rot_f = rotation.HadamardTransformFactory(inner_agg_factory=l2clip_f)
        expected_process = concat.concat_factory(rot_f).create(value_type)

        # Check init_fn/state.
        expected_init_type = expected_process.initialize.type_signature
        expected_state_type = expected_init_type.result
        actual_init_type = process.initialize.type_signature
        self.assertTrue(actual_init_type.is_equivalent_to(expected_init_type))

        # Check next_fn/measurements.
        tensor2type = type_conversions.type_from_tensors
        discrete_state = discrete_f.create(
            computation_types.to_type(tf.float32)).initialize()
        dp_query_state = dp_query.initial_global_state()
        dp_query_metrics_type = tensor2type(
            dp_query.derive_metrics(dp_query_state))
        expected_measurements_type = collections.OrderedDict(
            l2_clip=robust.NORM_TF_TYPE,
            scale_factor=tensor2type(discrete_state['scale_factor']),
            scaled_inflated_l2=tensor2type(dp_query_state.l2_norm_bound),
            scaled_local_stddev=tensor2type(dp_query_state.local_stddev),
            actual_num_clients=tf.int32,
            padded_dim=tf.int32,
            dp_query_metrics=dp_query_metrics_type)
        expected_next_type = computation_types.FunctionType(
            parameter=collections.OrderedDict(
                state=expected_state_type,
                value=computation_types.at_clients(value_type)),
            result=measured_process.MeasuredProcessOutput(
                state=expected_state_type,
                result=computation_types.at_server(value_type),
                measurements=computation_types.at_server(
                    expected_measurements_type)))
        actual_next_type = process.next.type_signature
        self.assertTrue(actual_next_type.is_equivalent_to(expected_next_type))
        try:
            static_assert.assert_not_contains_unsecure_aggregation(
                process.next)
        except AssertionError as e:
            # Surface the checker's message. Unexpected exceptions propagate
            # as test errors instead of being masked by a bare `except`.
            self.fail('Factory returned an AggregationProcess containing '
                      f'non-secure aggregation: {e}')
コード例 #16
0
 def test_no_unsecure_aggregation_with_secure_metrics_finalizer(self):
     """Federated evaluation with secure metric finalization is secure-only."""
     secure_metrics_agg = aggregator.secure_sum_then_finalize
     evaluation_comp = federated_evaluation.build_federated_evaluation(
         _model_fn_from_keras, metrics_aggregator=secure_metrics_agg)
     static_assert.assert_not_contains_unsecure_aggregation(evaluation_comp)