def test_zero_type_properties_with_zeroed_count_agg_factory(self, value_type):
  """Checks type signatures when a custom zeroed-count sum factory is used."""
  agg_factory = robust.zeroing_factory(
      zeroing_norm=1.0,
      inner_agg_factory=sum_factory.SumFactory(),
      norm_order=2.0,
      zeroed_count_sum_factory=aggregator_test_utils.SumPlusOneFactory())
  value_type = computation_types.to_type(value_type)
  process = agg_factory.create(value_type)
  self.assertIsInstance(process, aggregation_process.AggregationProcess)

  # The custom zeroed-count aggregator contributes an int32 slot to state.
  state_type = computation_types.at_server(
      collections.OrderedDict(
          zeroing_norm=(), inner_agg=(), zeroed_count_agg=tf.int32))
  self.assertTrue(
      process.initialize.type_signature.is_equivalent_to(
          computation_types.FunctionType(parameter=None, result=state_type)))

  measurements_type = computation_types.at_server(
      collections.OrderedDict(
          zeroing=(),
          zeroing_norm=robust.NORM_TF_TYPE,
          zeroed_count=robust.COUNT_TF_TYPE))
  expected_next_type = computation_types.FunctionType(
      parameter=collections.OrderedDict(
          state=state_type, value=computation_types.at_clients(value_type)),
      result=measured_process.MeasuredProcessOutput(
          state=state_type,
          result=computation_types.at_server(value_type),
          measurements=measurements_type))
  self.assertTrue(
      process.next.type_signature.is_equivalent_to(expected_next_type))
def test_custom_model_zeroing_clipping_aggregator_factory(self):
  """Trains two rounds with a zeroing aggregator whose norm is infinite."""
  client_data = create_emnist_client_data()
  train_data = [client_data(), client_data()]

  def loss_fn():
    return tf.keras.losses.SparseCategoricalCrossentropy()

  def metrics_fn():
    return [
        counters.NumExamplesCounter(),
        counters.NumBatchesCounter(),
        tf.keras.metrics.SparseCategoricalAccuracy()
    ]

  # An infinite zeroing norm means no client update gets clipped or zeroed.
  aggregation_factory = robust.zeroing_factory(
      zeroing_norm=float('inf'), inner_agg_factory=mean.MeanFactory())

  # Disable reconstruction via 0 learning rate to ensure post-recon loss
  # matches exact expectations round 0 and decreases by the next round.
  trainer = training_process.build_training_process(
      MnistModel,
      loss_fn=loss_fn,
      metrics_fn=metrics_fn,
      server_optimizer_fn=_get_keras_optimizer_fn(0.01),
      client_optimizer_fn=_get_keras_optimizer_fn(0.001),
      reconstruction_optimizer_fn=_get_keras_optimizer_fn(0.0),
      aggregation_factory=aggregation_factory,
      dataset_split_fn=reconstruction_utils.simple_dataset_split_fn)

  state = trainer.initialize()
  outputs, states = [], []
  for _ in range(2):
    state, output = trainer.next(state, train_data)
    outputs.append(output)
    states.append(state)

  # All weights and biases start at 0, so initial logits are all 0 and the
  # softmax probabilities are uniform over 10 classes; the negative log
  # likelihood is therefore -ln(1/10). This holds only on expectation, so
  # the tolerance is loosened.
  self.assertAllClose(
      outputs[0]['train']['loss'], tf.math.log(10.0), rtol=1e-4)
  self.assertLess(outputs[1]['train']['loss'], outputs[0]['train']['loss'])
  self.assertNotAllClose(
      states[0].model.trainable, states[1].model.trainable)

  # Each round sees 6 reconstruction and 6 training examples; only the
  # training examples are counted by the metrics.
  self.assertEqual(outputs[0]['train']['num_examples'], 6.0)
  self.assertEqual(outputs[1]['train']['num_examples'], 6.0)
  # Likewise 4 reconstruction and 4 training batches; only training ones
  # are counted.
  self.assertEqual(outputs[0]['train']['num_batches'], 4.0)
  self.assertEqual(outputs[1]['train']['num_batches'], 4.0)
def _default_zeroing(
    inner_factory: factory.AggregationFactory,
    secure_estimation: bool = False) -> factory.AggregationFactory:
  """The default adaptive zeroing wrapper."""
  # Adapts very quickly to a value somewhat higher than the highest values
  # seen so far.
  zeroing_norm = quantile_estimation.PrivateQuantileEstimationProcess.no_noise(
      initial_estimate=10.0,
      target_quantile=0.98,
      learning_rate=math.log(10.0),
      multiplier=2.0,
      increment=1.0,
      secure_estimation=secure_estimation)
  if not secure_estimation:
    return robust.zeroing_factory(zeroing_norm, inner_factory)
  # In secure mode the per-client zeroed indicator (0 or 1) is summed with a
  # secure aggregator as well.
  count_factory = secure.SecureSumFactory(
      upper_bound_threshold=1, lower_bound_threshold=0)
  return robust.zeroing_factory(
      zeroing_norm, inner_factory, zeroed_count_sum_factory=count_factory)
def _default_zeroing(
    inner_factory: factory.AggregationFactory
) -> factory.AggregationFactory:
  """The default adaptive zeroing wrapper."""
  # The norm estimate chases the 0.98 quantile with a large learning rate so
  # it adapts very quickly to a value somewhat above the highest values seen.
  norm_process = quantile_estimation.PrivateQuantileEstimationProcess.no_noise(
      initial_estimate=10.0,
      target_quantile=0.98,
      learning_rate=math.log(10.0),
      multiplier=2.0,
      increment=1.0)
  return robust.zeroing_factory(norm_process, inner_factory)
def test_increasing_zero_clip_sum(self):
  """Zeroing and clipping with non-integer norms that grow each round."""
  # The zeroing norm grows by 0.75 per round; the clipping norm by 0.25.
  @computations.federated_computation(_float_at_server, _float_at_clients)
  def zeroing_next_fn(state, value):
    del value
    return intrinsics.federated_map(
        computations.tf_computation(lambda x: x + 0.75, tf.float32), state)

  @computations.federated_computation(_float_at_server, _float_at_clients)
  def clipping_next_fn(state, value):
    del value
    return intrinsics.federated_map(
        computations.tf_computation(lambda x: x + 0.25, tf.float32), state)

  zeroing_norm_process = estimation_process.EstimationProcess(
      _test_init_fn, zeroing_next_fn, _test_report_fn)
  clipping_norm_process = estimation_process.EstimationProcess(
      _test_init_fn, clipping_next_fn, _test_report_fn)

  agg_factory = robust.zeroing_factory(
      zeroing_norm_process, _clipped_sum(clipping_norm_process))
  process = agg_factory.create(computation_types.to_type(tf.float32))

  client_data = [1.0, 2.0, 3.0]
  # Per-round expectations for clients [1.0, 2.0, 3.0] as the norms grow:
  # (zeroing_norm, clipping_norm, zeroed_count, clipped_count, result).
  per_round_expectations = [
      (1.0, 1.0, 2, 0, 1.0),
      (1.75, 1.25, 2, 0, 1.0),
      (2.5, 1.5, 1, 1, 2.5),
      (3.25, 1.75, 0, 2, 4.5),
      (4.0, 2.0, 0, 1, 5.0),
  ]
  state = process.initialize()
  for zero_norm, clip_norm, zeroed, clipped, result in per_round_expectations:
    output = process.next(state, client_data)
    state = output.state
    self.assertAllClose(zero_norm, output.measurements['zeroing_norm'])
    self.assertAllClose(
        clip_norm, output.measurements['zeroing']['clipping_norm'])
    self.assertEqual(zeroed, output.measurements['zeroed_count'])
    self.assertEqual(clipped, output.measurements['zeroing']['clipped_count'])
    self.assertAllClose(result, output.result)
def _zeroed_sum(clip=2.0, norm_order=2.0):
  """Returns a zeroing sum aggregator with the given norm bound and order."""
  return robust.zeroing_factory(
      zeroing_norm=clip,
      inner_agg_factory=sum_factory.SumFactory(),
      norm_order=norm_order)
def _zeroed_mean(clip=2.0, norm_order=2.0):
  """Returns a zeroing mean aggregator with the given norm bound and order."""
  return robust.zeroing_factory(
      zeroing_norm=clip,
      inner_agg_factory=mean.MeanFactory(),
      norm_order=norm_order)