def test_type_properties(self, modulus, value_type, symmetric_range):
  """Checks SecureModularSumFactory process types and secure-only aggregation.

  Verifies that the factory is unweighted, that `initialize`/`next` have the
  expected type signatures (empty server state, empty measurements), and that
  `next` contains no unsecure aggregation.
  """
  factory_ = secure.SecureModularSumFactory(
      modulus=modulus, symmetric_range=symmetric_range)
  self.assertIsInstance(factory_, factory.UnweightedAggregationFactory)
  value_type = computation_types.to_type(value_type)
  process = factory_.create(value_type)
  self.assertIsInstance(process, aggregation_process.AggregationProcess)

  # This factory carries no state and reports no measurements: both are `()`.
  expected_state_type = computation_types.at_server(
      computation_types.to_type(()))
  expected_measurements_type = expected_state_type

  expected_initialize_type = computation_types.FunctionType(
      parameter=None, result=expected_state_type)
  self.assertTrue(
      process.initialize.type_signature.is_equivalent_to(
          expected_initialize_type))

  expected_next_type = computation_types.FunctionType(
      parameter=collections.OrderedDict(
          state=expected_state_type,
          value=computation_types.at_clients(value_type)),
      result=measured_process.MeasuredProcessOutput(
          state=expected_state_type,
          result=computation_types.at_server(value_type),
          measurements=expected_measurements_type))
  self.assertTrue(
      process.next.type_signature.is_equivalent_to(expected_next_type))

  # Only translate the checker's AssertionError into a test failure; any other
  # exception (e.g. the checker rejecting its input) surfaces as a test error
  # instead of being mislabeled as "non-secure aggregation".
  try:
    static_assert.assert_not_contains_unsecure_aggregation(process.next)
  except AssertionError:
    self.fail('Factory returned an AggregationProcess containing '
              'non-secure aggregation.')
def test_weighted_secure_aggregator_only_contains_secure_aggregation(self):
  """The weighted secure aggregator's `next` uses only secure aggregation."""
  # Named `process` rather than `aggregator` to avoid shadowing any module of
  # that name imported by the test file.
  process = model_update_aggregator.secure_aggregator(
      weighted=True).create(_float_matrix_type, _float_type)
  # Narrow catch: only the checker's AssertionError means "unsecure
  # aggregation found"; other exceptions propagate as real errors.
  try:
    static_assert.assert_not_contains_unsecure_aggregation(process.next)
  except AssertionError:
    self.fail('Secure aggregator contains non-secure aggregation.')
def test_no_unsecure_aggregation_with_secure_aggregator(self):
  """FedSGD built from secure model/metrics aggregators is secure-only."""
  secure_model_agg = model_update_aggregator.secure_aggregator()
  process = fed_sgd.build_fed_sgd(
      model_examples.LinearRegression,
      model_aggregator=secure_model_agg,
      metrics_aggregator=aggregator.secure_sum_then_finalize)
  static_assert.assert_not_contains_unsecure_aggregation(process.next)
def test_ddp_secure_aggregator_only_contains_secure_aggregation(self):
  """The distributed-DP secure aggregator's `next` is secure-only."""
  # Named `process` rather than `aggregator` to avoid shadowing any module of
  # that name imported by the test file.
  process = model_update_aggregator.ddp_secure_aggregator(
      noise_multiplier=1e-2,
      expected_clients_per_round=10).create(_float_matrix_type)
  # Narrow catch: only the checker's AssertionError means "unsecure
  # aggregation found"; other exceptions propagate as real errors.
  try:
    static_assert.assert_not_contains_unsecure_aggregation(process.next)
  except AssertionError:
    self.fail('Secure aggregator contains non-secure aggregation.')
def test_secure_sum(self, dp_mechanism):
  """Hierarchical histogram with `enable_secure_sum=True` is secure-only."""
  histogram_comp = hihi.build_hierarchical_histogram_computation(
      lower_bound=0,
      upper_bound=10,
      num_bins=5,
      dp_mechanism=dp_mechanism,
      enable_secure_sum=True,
  )
  static_assert.assert_not_contains_unsecure_aggregation(histogram_comp)
def test_unweighted_mime_lite_with_only_secure_aggregation(self):
  """Unweighted Mime Lite built from secure aggregators is secure-only."""
  # One secure aggregator instance serves both the model-update and the
  # full-gradient aggregation slots.
  secure_agg = model_update_aggregator.secure_aggregator(weighted=False)
  process = mime.build_unweighted_mime_lite(
      model_examples.LinearRegression,
      base_optimizer=sgdm.build_sgdm(learning_rate=0.01, momentum=0.9),
      model_aggregator=secure_agg,
      full_gradient_aggregator=secure_agg,
      metrics_aggregator=metrics_aggregator.secure_sum_then_finalize,
  )
  static_assert.assert_not_contains_unsecure_aggregation(process.next)
def test_unweighted_fed_avg_with_only_secure_aggregation(self):
  """Unweighted FedAvg built from secure aggregators is secure-only."""

  def client_optimizer_fn():
    return tf.keras.optimizers.SGD(1.0)

  secure_model_agg = model_update_aggregator.secure_aggregator(weighted=False)
  process = fed_avg.build_unweighted_fed_avg(
      model_examples.LinearRegression,
      client_optimizer_fn=client_optimizer_fn,
      model_aggregator=secure_model_agg,
      metrics_aggregator=aggregator.secure_sum_then_finalize)
  static_assert.assert_not_contains_unsecure_aggregation(process.next)
def test_secure_estimation_true_only_contains_secure_aggregation(self):
  """A no-noise QEProcess with `secure_estimation=True` is secure-only."""
  secure_process = QEProcess.no_noise(
      initial_estimate=1.0,
      target_quantile=0.5,
      learning_rate=1.0,
      secure_estimation=True)
  # Narrow catch: only the checker's AssertionError means "unsecure
  # aggregation found"; other exceptions propagate as real errors.
  try:
    static_assert.assert_not_contains_unsecure_aggregation(secure_process.next)
  except AssertionError:
    self.fail('Computation contains non-secure aggregation.')
def test_construction_with_only_secure_aggregation(self):
  """FedAvg with optimizer schedule and secure aggregators is secure-only."""
  build_fn = (
      fed_avg_with_optimizer_schedule
      .build_weighted_fed_avg_with_optimizer_schedule)
  secure_model_agg = model_update_aggregator.secure_aggregator(weighted=True)
  process = build_fn(
      model_examples.LinearRegression,
      client_learning_rate_fn=lambda x: 0.5,
      client_optimizer_fn=tf.keras.optimizers.SGD,
      model_aggregator=secure_model_agg,
      metrics_aggregator=aggregator.secure_sum_then_finalize)
  static_assert.assert_not_contains_unsecure_aggregation(process.next)
def test_weighted_fed_avg_with_only_secure_aggregation(self):
  """Weighted FedAvg on a functional model with secure aggregators is secure-only."""
  functional_model = test_models.build_functional_linear_regression()
  client_optimizer = sgdm.build_sgdm(learning_rate=0.1)
  secure_model_agg = model_update_aggregator.secure_aggregator(weighted=True)
  # `model_fn` is explicitly None because the functional `model` is supplied.
  process = fed_avg.build_weighted_fed_avg(
      model_fn=None,
      model=functional_model,
      client_optimizer_fn=client_optimizer,
      model_aggregator=secure_model_agg,
      metrics_aggregator=aggregator.secure_sum_then_finalize)
  static_assert.assert_not_contains_unsecure_aggregation(process.next)
def test_default_value_ranges_returns_correct_results(
    self, metric_finalizers, local_unfinalized_metrics_at_clients,
    expected_aggregated_metrics):
  """Checks secure_sum_then_finalize with default bounds.

  Verifies that the computation is secure-only, that no values were clipped
  for any secure-sum factory key, and that the finalized metrics match
  `expected_aggregated_metrics`.
  """
  aggregator_computation = aggregator.secure_sum_then_finalize(
      metric_finalizers=metric_finalizers,
      local_unfinalized_metrics_type=type_conversions.type_from_tensors(
          local_unfinalized_metrics_at_clients[0]))
  # Narrow catch: only the checker's AssertionError means "unsecure
  # aggregation found"; other exceptions propagate as real errors.
  try:
    static_assert.assert_not_contains_unsecure_aggregation(
        aggregator_computation)
  except AssertionError:
    self.fail(
        'Metric aggregation contains non-secure summation aggregation')

  aggregated_metrics = aggregator_computation(
      local_unfinalized_metrics_at_clients)

  no_clipped_values = collections.OrderedDict(
      secure_upper_clipped_count=0,
      secure_lower_clipped_count=0,
      secure_upper_threshold=aggregator.DEFAULT_SECURE_UPPER_BOUND,
      secure_lower_threshold=aggregator.DEFAULT_SECURE_LOWER_BOUND)

  # Collect the distinct factory keys (in first-seen order) that the secure
  # sum is expected to report measurements under. The dict is used as an
  # ordered set; the value `1` is unused.
  factory_keys = collections.OrderedDict()
  for value in tf.nest.flatten(local_unfinalized_metrics_at_clients[0]):
    tensor = tf.constant(value)
    if tensor.dtype.is_floating:
      lower = float(aggregator.DEFAULT_SECURE_LOWER_BOUND)
      upper = float(aggregator.DEFAULT_SECURE_UPPER_BOUND)
    elif tensor.dtype.is_integer:
      lower = int(aggregator.DEFAULT_SECURE_LOWER_BOUND)
      upper = int(aggregator.DEFAULT_SECURE_UPPER_BOUND)
    else:
      raise TypeError(
          f'Expected float or int, found tensors of dtype {tensor.dtype}.')
    factory_key = aggregator._create_factory_key(  # pylint: disable=protected-access
        lower, upper, tensor.dtype)
    factory_keys[factory_key] = 1

  expected_measurements = collections.OrderedDict(
      (factory_key, no_clipped_values) for factory_key in factory_keys)
  # Split off the secure-sum measurements so the remaining metrics can be
  # compared directly against the expected finalized values.
  secure_sum_measurements = aggregated_metrics.pop('secure_sum_measurements')
  self.assertAllClose(secure_sum_measurements, expected_measurements)
  self.assertAllClose(
      aggregated_metrics, expected_aggregated_metrics, rtol=1e-5, atol=1e-5)
def test_type_properties_adaptive_bounds(self, value_type, dtype):
  """Checks SecureSumFactory types with estimation-process bounds.

  The upper and lower thresholds come from `_test_estimation_process`. The
  server state is expected to be a pair of the processes' reported threshold
  types. Also checks that `next` contains no unsecure aggregation.
  """
  upper_bound_process = _test_estimation_process(1)
  lower_bound_process = _test_estimation_process(-1)
  secure_sum_f = secure.SecureSumFactory(
      upper_bound_threshold=upper_bound_process,
      lower_bound_threshold=lower_bound_process)
  self.assertIsInstance(secure_sum_f, factory.UnweightedAggregationFactory)
  value_type = computation_types.to_type(value_type)
  process = secure_sum_f.create(value_type)
  self.assertIsInstance(process, aggregation_process.AggregationProcess)

  # State holds one threshold for each of the two bound processes.
  threshold_type = upper_bound_process.report.type_signature.result.member
  expected_state_type = computation_types.at_server(
      computation_types.to_type((threshold_type, threshold_type)))
  expected_measurements_type = _measurements_type(dtype)

  expected_initialize_type = computation_types.FunctionType(
      parameter=None, result=expected_state_type)
  self.assertTrue(
      process.initialize.type_signature.is_equivalent_to(
          expected_initialize_type))

  expected_next_type = computation_types.FunctionType(
      parameter=collections.OrderedDict(
          state=expected_state_type,
          value=computation_types.at_clients(value_type)),
      result=measured_process.MeasuredProcessOutput(
          state=expected_state_type,
          result=computation_types.at_server(value_type),
          measurements=expected_measurements_type))
  self.assertTrue(
      process.next.type_signature.is_equivalent_to(expected_next_type))

  # Narrow catch: only the checker's AssertionError means "unsecure
  # aggregation found"; other exceptions propagate as real errors.
  try:
    static_assert.assert_not_contains_unsecure_aggregation(process.next)
  except AssertionError:
    self.fail('Factory returned an AggregationProcess containing '
              'non-secure aggregation.')
def test_fails_on_bothagg(self):
  """A computation mixing secure and unsecure aggregation is rejected."""
  self.assertRaises(
      AssertionError,
      static_assert.assert_not_contains_unsecure_aggregation,
      secure_and_unsecure_aggregation,
  )
def test_passes_on_secagg(self):
  """A computation using only secure aggregation passes the check."""
  check_no_unsecure = static_assert.assert_not_contains_unsecure_aggregation
  check_no_unsecure(secure_aggregation)
def test_type_properties(self, value_type, mechanism):
  """Checks the distributed-DP factory's process types and secure-only use.

  The expected types are derived from an equivalent factory stack built by
  hand: secure modular sum -> DP -> discretization -> L2 clipping -> Hadamard
  rotation -> concat. The test compares `initialize`/`next` signatures against
  that stack and checks that `next` contains no unsecure aggregation.
  """
  ddp_factory = _make_test_factory(mechanism=mechanism)
  self.assertIsInstance(ddp_factory, factory.UnweightedAggregationFactory)
  value_type = computation_types.to_type(value_type)
  process = ddp_factory.create(value_type)
  self.assertIsInstance(process, aggregation_process.AggregationProcess)

  # The state is a nested object with component factory states. Construct
  # test factories directly and compare the signatures.
  modsum_f = secure.SecureModularSumFactory(2**15, True)

  if mechanism == 'distributed_dgauss':
    dp_query = tfp.DistributedDiscreteGaussianSumQuery(
        l2_norm_bound=10.0, local_stddev=10.0)
  else:
    dp_query = tfp.DistributedSkellamSumQuery(
        l1_norm_bound=10.0, l2_norm_bound=10.0, local_stddev=10.0)

  dp_f = differential_privacy.DifferentiallyPrivateFactory(dp_query, modsum_f)
  discrete_f = discretization.DiscretizationFactory(dp_f)
  l2clip_f = robust.clipping_factory(
      clipping_norm=10.0, inner_agg_factory=discrete_f)
  rot_f = rotation.HadamardTransformFactory(inner_agg_factory=l2clip_f)
  expected_process = concat.concat_factory(rot_f).create(value_type)

  # Check init_fn/state.
  expected_init_type = expected_process.initialize.type_signature
  expected_state_type = expected_init_type.result
  actual_init_type = process.initialize.type_signature
  self.assertTrue(actual_init_type.is_equivalent_to(expected_init_type))

  # Check next_fn/measurements.
  tensor2type = type_conversions.type_from_tensors
  discrete_state = discrete_f.create(
      computation_types.to_type(tf.float32)).initialize()
  dp_query_state = dp_query.initial_global_state()
  dp_query_metrics_type = tensor2type(dp_query.derive_metrics(dp_query_state))
  expected_measurements_type = collections.OrderedDict(
      l2_clip=robust.NORM_TF_TYPE,
      scale_factor=tensor2type(discrete_state['scale_factor']),
      scaled_inflated_l2=tensor2type(dp_query_state.l2_norm_bound),
      scaled_local_stddev=tensor2type(dp_query_state.local_stddev),
      actual_num_clients=tf.int32,
      padded_dim=tf.int32,
      dp_query_metrics=dp_query_metrics_type)
  expected_next_type = computation_types.FunctionType(
      parameter=collections.OrderedDict(
          state=expected_state_type,
          value=computation_types.at_clients(value_type)),
      result=measured_process.MeasuredProcessOutput(
          state=expected_state_type,
          result=computation_types.at_server(value_type),
          measurements=computation_types.at_server(
              expected_measurements_type)))
  actual_next_type = process.next.type_signature
  self.assertTrue(actual_next_type.is_equivalent_to(expected_next_type))

  # Narrow catch: only the checker's AssertionError means "unsecure
  # aggregation found"; other exceptions propagate as real errors.
  try:
    static_assert.assert_not_contains_unsecure_aggregation(process.next)
  except AssertionError:
    self.fail('Factory returned an AggregationProcess containing '
              'non-secure aggregation.')
def test_no_unsecure_aggregation_with_secure_metrics_finalizer(self):
  """Federated evaluation with a secure metrics aggregator is secure-only."""
  secure_metrics_agg = aggregator.secure_sum_then_finalize
  evaluation = federated_evaluation.build_federated_evaluation(
      _model_fn_from_keras, metrics_aggregator=secure_metrics_agg)
  static_assert.assert_not_contains_unsecure_aggregation(evaluation)