def test_iterative_process_with_encoding(self):
    """Runs one round of a model-delta process built with encoded broadcast/mean.

    The broadcast uses the 'simple' test encoder and the delta aggregation
    uses the 'gather' test encoder. The test checks that
    `model_broadcast_state.trainable[0][0]` is 1 after `initialize()` and 2
    after a single `next()` call over three identical client datasets.

    NOTE(review): a method named `test_iterative_process_with_encoding` is
    defined again later in this file; if both live in the same class, the
    later definition replaces this one and these assertions never run —
    consider giving one of them a distinct name.
    """
    model_fn = model_examples.LinearRegression
    mean_aggregate_fn = encoding_utils.build_encoded_mean_from_model(
        model_fn, _test_encoder_fn('gather'))
    model_broadcast_fn = encoding_utils.build_encoded_broadcast_from_model(
        model_fn, _test_encoder_fn('simple'))

    def server_optimizer():
        return tf.keras.optimizers.SGD(learning_rate=1.0)

    process = optimizer_utils.build_model_delta_optimizer_process(
        model_fn=model_fn,
        model_to_client_delta_fn=DummyClientDeltaFn,
        server_optimizer_fn=server_optimizer,
        stateful_delta_aggregate_fn=mean_aggregate_fn,
        stateful_model_broadcast_fn=model_broadcast_fn)

    # Two examples in a single batch of size 2.
    examples = collections.OrderedDict([
        ('x', [[1.0, 2.0], [3.0, 4.0]]),
        ('y', [[5.0], [6.0]]),
    ])
    client_ds = tf.data.Dataset.from_tensor_slices(examples).batch(2)
    # The same dataset object is reused for all three clients.
    client_datasets = [client_ds] * 3

    def broadcast_value(server_state):
        return server_state.model_broadcast_state.trainable[0][0]

    server_state = process.initialize()
    self.assertEqual(broadcast_value(server_state), 1)

    server_state, _ = process.next(server_state, client_datasets)
    self.assertEqual(broadcast_value(server_state), 2)
def test_mean_from_model_raise_warning(self):
    """Building an encoded mean with the 'gather' encoder emits two warnings.

    Every warning raised during construction is recorded, and the test checks
    that exactly two were captured. It also checks that the result is a
    `computation_utils.StatefulAggregateFn`.
    """
    with warnings.catch_warnings(record=True) as recorded:
        # Disable de-duplication so every warning is captured.
        warnings.simplefilter('always')
        aggregate_fn = encoding_utils.build_encoded_mean_from_model(
            model_examples.LinearRegression, _test_encoder_fn('gather'))

    # `recorded` stays populated after the context restores the filters.
    self.assertLen(recorded, 2)
    self.assertIsInstance(aggregate_fn, computation_utils.StatefulAggregateFn)
def test_iterative_process_with_encoding(self):
    """Builds a model-delta process with encoded broadcast/mean and verifies it.

    The broadcast uses the 'simple' test encoder and the delta aggregation
    uses the 'gather' test encoder. The resulting process is handed to
    `self._verify_iterative_process` for checking.
    """
    model_fn = model_examples.LinearRegression
    mean_aggregate_fn = encoding_utils.build_encoded_mean_from_model(
        model_fn, _test_encoder_fn('gather'))
    model_broadcast_fn = encoding_utils.build_encoded_broadcast_from_model(
        model_fn, _test_encoder_fn('simple'))

    def server_optimizer():
        return tf.keras.optimizers.SGD(learning_rate=1.0)

    process = optimizer_utils.build_model_delta_optimizer_process(
        model_fn=model_fn,
        model_to_client_delta_fn=DummyClientDeltaFn,
        server_optimizer_fn=server_optimizer,
        stateful_delta_aggregate_fn=mean_aggregate_fn,
        stateful_model_broadcast_fn=model_broadcast_fn)
    self._verify_iterative_process(process)
def test_mean_from_model(self):
    """An encoded mean built with the 'gather' encoder is a StatefulAggregateFn."""
    aggregate_fn = encoding_utils.build_encoded_mean_from_model(
        model_examples.LinearRegression, _test_encoder_fn('gather'))
    self.assertIsInstance(aggregate_fn, tff.utils.StatefulAggregateFn)