def test_iterative_process_fails_with_dp_agg_and_none_client_weighting(self):

  def loss_fn():
    return tf.keras.losses.SparseCategoricalCrossentropy()

  def metrics_fn():
    return [
        NumExamplesCounter(),
        NumBatchesCounter(),
        tf.keras.metrics.SparseCategoricalAccuracy()
    ]

  # No values should be changed, but working with inf directly zeroes out all
  # updates. Prefer a very large value that can still be handled in
  # multiplication/division.
  gaussian_sum_query = tfp.GaussianSumQuery(l2_norm_clip=1e10, stddev=0)
  dp_sum_factory = differential_privacy.DifferentiallyPrivateFactory(
      query=gaussian_sum_query,
      record_aggregation_factory=sum_factory.SumFactory())
  dp_mean_factory = _DPMean(dp_sum_factory)

  with self.assertRaisesRegex(ValueError, 'unweighted aggregator'):
    training_process.build_training_process(
        MnistModel,
        loss_fn=loss_fn,
        metrics_fn=metrics_fn,
        server_optimizer_fn=_get_keras_optimizer_fn(0.01),
        client_optimizer_fn=_get_keras_optimizer_fn(0.001),
        reconstruction_optimizer_fn=_get_keras_optimizer_fn(0.0),
        aggregation_factory=dp_mean_factory,
        client_weighting=None,
        dataset_split_fn=reconstruction_utils.simple_dataset_split_fn)
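# `_DPMean` is defined elsewhere in this test module. The sketch below is a
# hypothetical reconstruction: an unweighted aggregation factory that divides
# a differentially private sum by the number of contributing clients. The
# `factory` and `aggregation_process_lib` module aliases are assumed imports,
# analogous to the TFF modules already used in this file; the real helper may
# differ in detail.
class _DPMean(factory.UnweightedAggregationFactory):
  """Unweighted mean built on top of a differentially private sum factory."""

  def __init__(self, dp_sum_factory):
    self._dp_sum = dp_sum_factory

  def create(self, value_type):
    dp_sum_process = self._dp_sum.create(value_type)

    @computations.tf_computation(value_type, tf.int32)
    def div(value, count):
      # Divide each tensor in the summed update by the client count.
      return tf.nest.map_structure(
          lambda v: v / tf.cast(count, v.dtype), value)

    @computations.federated_computation(
        dp_sum_process.initialize.type_signature.result,
        computation_types.at_clients(value_type))
    def next_fn(state, value):
      one_at_clients = intrinsics.federated_value(1, placements.CLIENTS)
      dp_sum = dp_sum_process.next(state, value)
      client_count = intrinsics.federated_sum(one_at_clients)
      return measured_process_lib.MeasuredProcessOutput(
          state=dp_sum.state,
          result=intrinsics.federated_map(div, (dp_sum.result, client_count)),
          measurements=dp_sum.measurements)

    return aggregation_process_lib.AggregationProcess(
        initialize_fn=dp_sum_process.initialize, next_fn=next_fn)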
def test_process_construction_calls_model_fn(self):
  # Assert that the process building does not call `model_fn` too many
  # times. `model_fn` can potentially be expensive (loading weights,
  # processing, etc.).
  mock_model_fn = mock.Mock(side_effect=local_recon_model_fn)
  training_process.build_training_process(
      model_fn=mock_model_fn,
      loss_fn=tf.keras.losses.SparseCategoricalCrossentropy,
      client_optimizer_fn=_get_keras_optimizer_fn())
  # TODO(b/186451541): Reduce the number of calls to model_fn.
  self.assertEqual(mock_model_fn.call_count, 4)
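# `_get_keras_optimizer_fn` is a module-level helper; a minimal sketch of the
# assumed behavior (the optimizer class and default learning rate here are
# illustrative). It returns a no-argument callable so that each TFF
# computation can construct its own fresh, stateful Keras optimizer:
def _get_keras_optimizer_fn(learning_rate=0.1):
  return lambda: tf.keras.optimizers.SGD(learning_rate=learning_rate)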
def test_keras_local_layer_metrics_empty_list(self):

  def loss_fn():
    return tf.keras.losses.SparseCategoricalCrossentropy()

  def metrics_fn():
    return []

  it_process = training_process.build_training_process(
      local_recon_model_fn,
      loss_fn=loss_fn,
      metrics_fn=metrics_fn,
      client_optimizer_fn=_get_keras_optimizer_fn(0.0001),
      reconstruction_optimizer_fn=_get_keras_optimizer_fn(0.001))
  server_state = it_process.initialize()

  client_data = create_emnist_client_data()
  federated_data = [client_data(), client_data()]

  server_state, output = it_process.next(server_state, federated_data)

  expected_keys = ['broadcast', 'aggregation', 'train']
  self.assertCountEqual(output.keys(), expected_keys)

  expected_train_keys = ['loss']
  self.assertCountEqual(output['train'].keys(), expected_train_keys)
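# `create_emnist_client_data` is defined elsewhere in this module. A minimal
# sketch consistent with how the tests use it: it returns a `client_data`
# closure accepting `batch_size` and `max_examples` and producing a dataset
# matching the <x=float32[?,784],y=int32[?,1]> element spec. The feature
# values and labels below are placeholders, assuming three examples per
# client and a default batch size of 2 (which the example/batch-count
# assertions in this file rely on):
def create_emnist_client_data():
  emnist_data = collections.OrderedDict(
      x=[[0.1] * 784, [0.2] * 784, [0.3] * 784],
      y=[[5], [5], [9]])
  dataset = tf.data.Dataset.from_tensor_slices(emnist_data)

  def client_data(batch_size=2, max_examples=None):
    client_dataset = dataset
    if max_examples is not None:
      client_dataset = client_dataset.take(max_examples)
    return client_dataset.batch(batch_size)

  return client_data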
def test_keras_local_layer(self, optimizer_fn):

  def loss_fn():
    return tf.keras.losses.SparseCategoricalCrossentropy()

  def metrics_fn():
    return [
        NumExamplesCounter(),
        NumBatchesCounter(),
        tf.keras.metrics.SparseCategoricalAccuracy()
    ]

  it_process = training_process.build_training_process(
      local_recon_model_fn,
      loss_fn=loss_fn,
      metrics_fn=metrics_fn,
      client_optimizer_fn=optimizer_fn(0.001),
      reconstruction_optimizer_fn=optimizer_fn(0.001))
  server_state = it_process.initialize()

  client_data = create_emnist_client_data()
  federated_data = [
      client_data(batch_size=1, max_examples=2),
      client_data(batch_size=2)
  ]

  server_states = []
  outputs = []
  loss_list = []
  for _ in range(5):
    server_state, output = it_process.next(server_state, federated_data)
    server_states.append(server_state)
    outputs.append(output)
    loss_list.append(output['train']['loss'])

  self.assertNotAllClose(server_states[0].model.trainable,
                         server_states[1].model.trainable)
  self.assertLess(np.mean(loss_list[2:]), np.mean(loss_list[:2]))

  expected_keys = ['broadcast', 'aggregation', 'train']
  self.assertCountEqual(outputs[0].keys(), expected_keys)

  expected_train_keys = [
      'sparse_categorical_accuracy', 'loss', 'num_examples_total',
      'num_batches_total'
  ]
  self.assertCountEqual(outputs[0]['train'].keys(), expected_train_keys)

  # In each of the first two rounds, each client has one post-reconstruction
  # batch with a single example.
  self.assertEqual(outputs[0]['train']['num_examples_total'], 2)
  self.assertEqual(outputs[0]['train']['num_batches_total'], 2)
  self.assertEqual(outputs[1]['train']['num_examples_total'], 2)
  self.assertEqual(outputs[1]['train']['num_batches_total'], 2)

  expected_aggregation_keys = ['mean_weight', 'mean_value']
  self.assertCountEqual(outputs[0]['aggregation'].keys(),
                        expected_aggregation_keys)
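# The bare `NumExamplesCounter` and `NumBatchesCounter` (as opposed to the
# `counters.` variants used in other tests) are assumed to be local legacy
# metrics reporting under the `*_total` keys asserted above. A hypothetical
# sketch, subclassing `tf.keras.metrics.Sum`:
class NumExamplesCounter(tf.keras.metrics.Sum):
  """Counts the total number of examples seen during training."""

  def __init__(self, name='num_examples_total', dtype=tf.float32):
    super().__init__(name=name, dtype=dtype)

  def update_state(self, y_true, y_pred, sample_weight=None):
    return super().update_state(tf.shape(y_pred)[0])


class NumBatchesCounter(tf.keras.metrics.Sum):
  """Counts the total number of batches seen during training."""

  def __init__(self, name='num_batches_total', dtype=tf.float32):
    super().__init__(name=name, dtype=dtype)

  def update_state(self, y_true, y_pred, sample_weight=None):
    return super().update_state(1)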
def test_custom_model_zeroing_clipping_aggregator_factory(self):
  client_data = create_emnist_client_data()
  train_data = [client_data(), client_data()]

  def loss_fn():
    return tf.keras.losses.SparseCategoricalCrossentropy()

  def metrics_fn():
    return [
        counters.NumExamplesCounter(),
        counters.NumBatchesCounter(),
        tf.keras.metrics.SparseCategoricalAccuracy()
    ]

  # No values should be clipped or zeroed.
  aggregation_factory = robust.zeroing_factory(
      zeroing_norm=float('inf'), inner_agg_factory=mean.MeanFactory())

  # Disable reconstruction via a 0 learning rate to ensure the post-recon loss
  # matches exact expectations in round 0 and decreases by the next round.
  trainer = training_process.build_training_process(
      MnistModel,
      loss_fn=loss_fn,
      metrics_fn=metrics_fn,
      server_optimizer_fn=_get_keras_optimizer_fn(0.01),
      client_optimizer_fn=_get_keras_optimizer_fn(0.001),
      reconstruction_optimizer_fn=_get_keras_optimizer_fn(0.0),
      aggregation_factory=aggregation_factory,
      dataset_split_fn=reconstruction_utils.simple_dataset_split_fn)
  state = trainer.initialize()

  outputs = []
  states = []
  for _ in range(2):
    state, output = trainer.next(state, train_data)
    outputs.append(output)
    states.append(state)

  # All weights and biases are initialized to 0, so initial logits are all 0
  # and softmax probabilities are uniform over the 10 classes, making the
  # negative log likelihood -ln(1/10). This holds in expectation, so increase
  # the tolerance.
  self.assertAllClose(
      outputs[0]['train']['loss'], tf.math.log(10.0), rtol=1e-4)
  self.assertLess(outputs[1]['train']['loss'], outputs[0]['train']['loss'])
  self.assertNotAllClose(states[0].model.trainable, states[1].model.trainable)

  # Expect 6 reconstruction examples and 6 training examples. Only training
  # examples are included in the metrics.
  self.assertEqual(outputs[0]['train']['num_examples'], 6.0)
  self.assertEqual(outputs[1]['train']['num_examples'], 6.0)

  # Expect 4 reconstruction batches and 4 training batches. Only training
  # batches are included in the metrics.
  self.assertEqual(outputs[0]['train']['num_batches'], 4.0)
  self.assertEqual(outputs[1]['train']['num_batches'], 4.0)
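# `MnistModel` is a hand-written reconstruction model defined elsewhere in
# this module. A hypothetical sketch, assuming the reconstruction
# `model_lib.ReconstructionModel` interface with zero-initialized variables
# (the -ln(1/10) initial-loss assertions above rely on zero initial logits);
# the global/local variable split shown here is illustrative:
class MnistModel(model_lib.ReconstructionModel):
  """Linear MNIST classifier with a global kernel and a local bias."""

  def __init__(self):
    self._weights = tf.Variable(tf.zeros([784, 10]), name='weights')
    self._bias = tf.Variable(tf.zeros([10]), name='bias')

  @property
  def global_trainable_variables(self):
    return [self._weights]

  @property
  def global_non_trainable_variables(self):
    return []

  @property
  def local_trainable_variables(self):
    return [self._bias]

  @property
  def local_non_trainable_variables(self):
    return []

  @property
  def input_spec(self):
    return collections.OrderedDict(
        x=tf.TensorSpec([None, 784], tf.float32),
        y=tf.TensorSpec([None, 1], tf.int32))

  @tf.function
  def forward_pass(self, batch_input, training=True):
    del training  # The forward pass is identical in training and inference.
    logits = tf.matmul(batch_input['x'], self._weights) + self._bias
    # The loss functions in these tests use the default `from_logits=False`,
    # so return softmax probabilities; zero weights give uniform predictions.
    return model_lib.BatchOutput(
        predictions=tf.nn.softmax(logits),
        labels=batch_input['y'],
        num_examples=tf.shape(batch_input['x'])[0])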
def test_server_update_with_inf_weight_is_noop(self):
  client_data = create_emnist_client_data()
  federated_data = [client_data()]
  client_weight_fn = lambda x: np.inf

  it_process = training_process.build_training_process(
      local_recon_model_fn,
      loss_fn=tf.keras.losses.SparseCategoricalCrossentropy,
      client_optimizer_fn=_get_keras_optimizer_fn(0.001),
      reconstruction_optimizer_fn=_get_keras_optimizer_fn(0.001),
      client_weighting=client_weight_fn)

  state, _, initial_state = self._run_rounds(it_process, federated_data, 1)
  self.assertAllClose(state.model.trainable, initial_state.model.trainable,
                      1e-8)
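# `_run_rounds` is a helper method on this test class; a minimal sketch of
# the assumed behavior, following how the tests unpack its result (final
# state, per-round train metrics, and the initial state so that no-op server
# updates can be detected):
def _run_rounds(self, it_process, federated_data, num_rounds):
  train_outputs = []
  initial_state = it_process.initialize()
  state = initial_state
  for _ in range(num_rounds):
    state, metrics = it_process.next(state, federated_data)
    train_outputs.append(metrics['train'])
  return state, train_outputs, initial_state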
def test_fed_recon_with_custom_client_weight_fn(self):
  client_data = create_emnist_client_data()
  federated_data = [client_data()]

  def client_weight_fn(local_outputs):
    return 1.0 / (1.0 + local_outputs['loss'][-1])

  it_process = training_process.build_training_process(
      local_recon_model_fn,
      loss_fn=tf.keras.losses.SparseCategoricalCrossentropy,
      client_optimizer_fn=_get_tff_optimizer(0.0001),
      reconstruction_optimizer_fn=_get_tff_optimizer(0.001),
      client_weighting=client_weight_fn)

  _, train_outputs, _ = self._run_rounds(it_process, federated_data, 5)
  self.assertLess(
      np.mean([train_outputs[-1]['loss'], train_outputs[-2]['loss']]),
      train_outputs[0]['loss'])
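# `_get_tff_optimizer`, in contrast with `_get_keras_optimizer_fn`, presumably
# returns a TFF-native optimizer. A sketch assuming the `sgdm` module from
# `tensorflow_federated.python.learning.optimizers` is imported under that
# name; the default learning rate is illustrative:
def _get_tff_optimizer(learning_rate=0.1):
  return sgdm.build_sgdm(learning_rate=learning_rate)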
def test_keras_local_layer_client_weighting_enum_uniform(self):

  def loss_fn():
    return tf.keras.losses.SparseCategoricalCrossentropy()

  def metrics_fn():
    return [
        NumExamplesCounter(),
        NumBatchesCounter(),
        tf.keras.metrics.SparseCategoricalAccuracy()
    ]

  it_process = training_process.build_training_process(
      local_recon_model_fn,
      loss_fn=loss_fn,
      metrics_fn=metrics_fn,
      client_optimizer_fn=_get_keras_optimizer_fn(0.001),
      reconstruction_optimizer_fn=_get_keras_optimizer_fn(0.001),
      client_weighting=client_weight_lib.ClientWeighting.UNIFORM,
      dataset_split_fn=reconstruction_utils.simple_dataset_split_fn)
  server_state = it_process.initialize()

  client_data = create_emnist_client_data()
  federated_data = [client_data(max_examples=2), client_data()]

  server_state, output = it_process.next(server_state, federated_data)

  expected_keys = ['broadcast', 'aggregation', 'train']
  self.assertCountEqual(output.keys(), expected_keys)

  expected_train_keys = [
      'sparse_categorical_accuracy', 'loss', 'num_examples_total',
      'num_batches_total'
  ]
  self.assertCountEqual(output['train'].keys(), expected_train_keys)
  self.assertEqual(output['train']['num_examples_total'], 5)
  self.assertEqual(output['train']['num_batches_total'], 3)

  # Ensure we are using a weighted aggregator.
  expected_aggregation_keys = ['mean_weight', 'mean_value']
  self.assertCountEqual(output['aggregation'].keys(),
                        expected_aggregation_keys)
def test_keras_local_layer_client_weighting_enum_num_examples(self):

  def loss_fn():
    return tf.keras.losses.SparseCategoricalCrossentropy()

  def metrics_fn():
    return [
        counters.NumExamplesCounter(),
        counters.NumBatchesCounter(),
        tf.keras.metrics.SparseCategoricalAccuracy()
    ]

  it_process = training_process.build_training_process(
      local_recon_model_fn,
      loss_fn=loss_fn,
      metrics_fn=metrics_fn,
      client_optimizer_fn=_get_keras_optimizer_fn(0.001),
      reconstruction_optimizer_fn=_get_keras_optimizer_fn(0.001),
      client_weighting=client_weight_lib.ClientWeighting.NUM_EXAMPLES)
  server_state = it_process.initialize()

  client_data = create_emnist_client_data()
  federated_data = [client_data(max_examples=2), client_data()]

  server_state, output = it_process.next(server_state, federated_data)

  expected_keys = ['broadcast', 'aggregation', 'train']
  self.assertCountEqual(output.keys(), expected_keys)

  expected_train_keys = [
      'sparse_categorical_accuracy', 'loss', 'num_examples', 'num_batches'
  ]
  self.assertCountEqual(output['train'].keys(), expected_train_keys)

  # Only one client has a post-reconstruction batch, with one example.
  self.assertEqual(output['train']['num_examples'], 1)
  self.assertEqual(output['train']['num_batches'], 1)

  # Ensure we are using a weighted aggregator.
  expected_aggregation_keys = ['mean_weight', 'mean_value']
  self.assertCountEqual(output['aggregation'].keys(),
                        expected_aggregation_keys)
def test_custom_model_multiple_epochs(self, optimizer_fn):
  client_data = create_emnist_client_data()
  train_data = [client_data(), client_data()]

  def loss_fn():
    return tf.keras.losses.SparseCategoricalCrossentropy()

  def metrics_fn():
    return [
        counters.NumExamplesCounter(),
        counters.NumBatchesCounter(),
        tf.keras.metrics.SparseCategoricalAccuracy()
    ]

  dataset_split_fn = reconstruction_utils.build_dataset_split_fn(
      recon_epochs=3, post_recon_epochs=4, post_recon_steps_max=3)

  trainer = training_process.build_training_process(
      MnistModel,
      loss_fn=loss_fn,
      metrics_fn=metrics_fn,
      client_optimizer_fn=optimizer_fn(0.001),
      reconstruction_optimizer_fn=optimizer_fn(0.001),
      dataset_split_fn=dataset_split_fn)
  state = trainer.initialize()

  outputs = []
  states = []
  for _ in range(2):
    state, output = trainer.next(state, train_data)
    outputs.append(output)
    states.append(state)

  self.assertLess(outputs[1]['train']['loss'], outputs[0]['train']['loss'])
  self.assertNotAllClose(states[0].model.trainable, states[1].model.trainable)

  self.assertEqual(outputs[0]['train']['num_examples'], 10.0)
  self.assertEqual(outputs[1]['train']['num_examples'], 10.0)
  self.assertEqual(outputs[0]['train']['num_batches'], 6.0)
  self.assertEqual(outputs[1]['train']['num_batches'], 6.0)
def test_build_train_iterative_process(self, optimizer_fn):

  def loss_fn():
    return tf.keras.losses.SparseCategoricalCrossentropy()

  def metrics_fn():
    return [
        NumExamplesCounter(),
        NumBatchesCounter(),
        tf.keras.metrics.SparseCategoricalAccuracy()
    ]

  it_process = training_process.build_training_process(
      local_recon_model_fn,
      loss_fn=loss_fn,
      metrics_fn=metrics_fn,
      client_optimizer_fn=optimizer_fn())
  self.assertIsInstance(it_process, iterative_process_lib.IterativeProcess)

  federated_data_type = it_process.next.type_signature.parameter[1]
  self.assertEqual(
      str(federated_data_type), '{<x=float32[?,784],y=int32[?,1]>*}@CLIENTS')
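# `local_recon_model_fn` is the module-level model builder used throughout
# these tests. A hypothetical sketch, assuming the reconstruction
# `keras_utils.from_keras_model` wrapper: the dense output layer is kept
# local (reconstructed on each client) while the rest of the model is
# global. The input spec matches the federated data type asserted above;
# the layer sizes are illustrative:
def local_recon_model_fn():
  global_layers = [tf.keras.layers.Dense(10, input_shape=(784,))]
  local_layers = [tf.keras.layers.Dense(10, activation='softmax')]
  keras_model = tf.keras.Sequential(global_layers + local_layers)
  input_spec = collections.OrderedDict(
      x=tf.TensorSpec([None, 784], tf.float32),
      y=tf.TensorSpec([None, 1], tf.int32))
  return keras_utils.from_keras_model(
      keras_model=keras_model,
      global_layers=global_layers,
      local_layers=local_layers,
      input_spec=input_spec)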
def test_get_model_weights(self):
  client_data = create_emnist_client_data()
  federated_data = [client_data()]

  it_process = training_process.build_training_process(
      local_recon_model_fn,
      loss_fn=tf.keras.losses.SparseCategoricalCrossentropy,
      client_optimizer_fn=_get_keras_optimizer_fn(0.001),
      reconstruction_optimizer_fn=_get_keras_optimizer_fn(0.001))
  state = it_process.initialize()

  self.assertIsInstance(
      it_process.get_model_weights(state), model_utils.ModelWeights)
  self.assertAllClose(state.model.trainable,
                      it_process.get_model_weights(state).trainable)

  for _ in range(3):
    state, _ = it_process.next(state, federated_data)
    self.assertIsInstance(
        it_process.get_model_weights(state), model_utils.ModelWeights)
    self.assertAllClose(state.model.trainable,
                        it_process.get_model_weights(state).trainable)
def test_keras_local_layer_custom_broadcaster(self):

  def loss_fn():
    return tf.keras.losses.SparseCategoricalCrossentropy()

  def metrics_fn():
    return [
        NumExamplesCounter(),
        NumBatchesCounter(),
        tf.keras.metrics.SparseCategoricalAccuracy()
    ]

  model_weights_type = type_conversions.type_from_tensors(
      reconstruction_utils.get_global_variables(local_recon_model_fn()))

  def build_custom_stateful_broadcaster(
      model_weights_type) -> measured_process_lib.MeasuredProcess:
    """Builds a `MeasuredProcess` that wraps `tff.federated_broadcast`."""

    @computations.federated_computation()
    def test_server_initialization():
      return intrinsics.federated_value(2.0, placements.SERVER)

    @computations.federated_computation(
        computation_types.FederatedType(tf.float32, placements.SERVER),
        computation_types.FederatedType(model_weights_type,
                                        placements.SERVER),
    )
    def stateful_broadcast(state, value):
      empty_metrics = intrinsics.federated_value(1.0, placements.SERVER)
      return measured_process_lib.MeasuredProcessOutput(
          state=state,
          result=intrinsics.federated_broadcast(value),
          measurements=empty_metrics)

    return measured_process_lib.MeasuredProcess(
        initialize_fn=test_server_initialization, next_fn=stateful_broadcast)

  it_process = training_process.build_training_process(
      local_recon_model_fn,
      loss_fn=loss_fn,
      metrics_fn=metrics_fn,
      client_optimizer_fn=_get_keras_optimizer_fn(0.001),
      reconstruction_optimizer_fn=_get_keras_optimizer_fn(0.001),
      dataset_split_fn=reconstruction_utils.simple_dataset_split_fn,
      broadcast_process=build_custom_stateful_broadcaster(
          model_weights_type=model_weights_type))
  server_state = it_process.initialize()

  # Ensure initialization of the broadcaster produces the expected state.
  self.assertEqual(server_state.model_broadcast_state, 2.0)

  client_data = create_emnist_client_data()
  federated_data = [client_data(), client_data()]

  server_state, output = it_process.next(server_state, federated_data)

  expected_keys = ['broadcast', 'aggregation', 'train']
  self.assertCountEqual(output.keys(), expected_keys)

  expected_train_keys = [
      'sparse_categorical_accuracy', 'loss', 'num_examples_total',
      'num_batches_total'
  ]
  self.assertCountEqual(output['train'].keys(), expected_train_keys)
  self.assertEqual(output['broadcast'], 1.0)
def test_execution_with_custom_dp_query(self):
  client_data = create_emnist_client_data()
  train_data = [client_data(), client_data()]

  def loss_fn():
    return tf.keras.losses.SparseCategoricalCrossentropy()

  def metrics_fn():
    return [
        NumExamplesCounter(),
        NumBatchesCounter(),
        tf.keras.metrics.SparseCategoricalAccuracy()
    ]

  # No values should be changed, but working with inf directly zeroes out all
  # updates. Prefer a very large value that can still be handled in
  # multiplication/division.
  gaussian_sum_query = tfp.GaussianSumQuery(l2_norm_clip=1e10, stddev=0)
  dp_sum_factory = differential_privacy.DifferentiallyPrivateFactory(
      query=gaussian_sum_query,
      record_aggregation_factory=sum_factory.SumFactory())
  dp_mean_factory = _DPMean(dp_sum_factory)

  # Disable reconstruction via a 0 learning rate to ensure the post-recon loss
  # matches exact expectations in round 0 and decreases by the next round.
  trainer = training_process.build_training_process(
      MnistModel,
      loss_fn=loss_fn,
      metrics_fn=metrics_fn,
      server_optimizer_fn=_get_keras_optimizer_fn(0.01),
      client_optimizer_fn=_get_keras_optimizer_fn(0.001),
      reconstruction_optimizer_fn=_get_keras_optimizer_fn(0.0),
      aggregation_factory=dp_mean_factory,
      dataset_split_fn=reconstruction_utils.simple_dataset_split_fn,
      client_weighting=client_weight_lib.ClientWeighting.UNIFORM,
  )
  state = trainer.initialize()

  outputs = []
  states = []
  for _ in range(2):
    state, output = trainer.next(state, train_data)
    outputs.append(output)
    states.append(state)

  # All weights and biases are initialized to 0, so initial logits are all 0
  # and softmax probabilities are uniform over the 10 classes, making the
  # negative log likelihood -ln(1/10). This holds in expectation, so increase
  # the tolerance.
  self.assertAllClose(
      outputs[0]['train']['loss'], tf.math.log(10.0), rtol=1e-4)
  self.assertLess(outputs[1]['train']['loss'], outputs[0]['train']['loss'])
  self.assertNotAllClose(states[0].model.trainable, states[1].model.trainable)

  # Expect 6 reconstruction examples and 6 training examples. Only training
  # examples are included in the metrics.
  self.assertEqual(outputs[0]['train']['num_examples_total'], 6.0)
  self.assertEqual(outputs[1]['train']['num_examples_total'], 6.0)

  # Expect 4 reconstruction batches and 4 training batches. Only training
  # batches are included in the metrics.
  self.assertEqual(outputs[0]['train']['num_batches_total'], 4.0)
  self.assertEqual(outputs[1]['train']['num_batches_total'], 4.0)