def test_iterative_process_with_encoding(self):
    """Checks the broadcast state of a process with encoded aggregation and broadcast.

    The delta aggregation uses `build_encoded_mean_from_model` with the
    'gather' test encoder. The model broadcast uses
    `build_encoded_broadcast_from_model` with the 'simple' test encoder.
    After `initialize()`, element `[0][0]` of
    `model_broadcast_state.trainable` must equal 1. After one call to
    `next()` it must equal 2. Presumably the test encoder keeps a per-round
    counter in its state; confirm against `_test_encoder_fn`.
    """
    linear_model_fn = model_examples.LinearRegression
    mean_aggregator = encoding_utils.build_encoded_mean_from_model(
        linear_model_fn, _test_encoder_fn('gather'))
    model_broadcaster = encoding_utils.build_encoded_broadcast_from_model(
        linear_model_fn, _test_encoder_fn('simple'))

    def make_server_optimizer():
        return tf.keras.optimizers.SGD(learning_rate=1.0)

    process = optimizer_utils.build_model_delta_optimizer_process(
        model_fn=linear_model_fn,
        model_to_client_delta_fn=DummyClientDeltaFn,
        server_optimizer_fn=make_server_optimizer,
        stateful_delta_aggregate_fn=mean_aggregator,
        stateful_model_broadcast_fn=model_broadcaster)

    client_examples = collections.OrderedDict([
        ('x', [[1.0, 2.0], [3.0, 4.0]]),
        ('y', [[5.0], [6.0]]),
    ])
    client_dataset = tf.data.Dataset.from_tensor_slices(
        client_examples).batch(2)
    # The federated input is three references to the same batched dataset.
    federated_data = [client_dataset for _ in range(3)]

    server_state = process.initialize()
    self.assertEqual(server_state.model_broadcast_state.trainable[0][0], 1)
    server_state, _ = process.next(server_state, federated_data)
    self.assertEqual(server_state.model_broadcast_state.trainable[0][0], 2)
def test_broadcast_from_model(self):
    """Checks that building an encoded broadcast from a model warns twice.

    Also checks that the returned object is a
    `computation_utils.StatefulBroadcastFn`.
    """
    with warnings.catch_warnings(record=True) as caught:
        # 'always' records every warning instead of suppressing repeats, so
        # the count below is deterministic.
        warnings.simplefilter('always')
        result = encoding_utils.build_encoded_broadcast_from_model(
            model_examples.LinearRegression, _test_encoder_fn('simple'))
        self.assertLen(caught, 2)
    self.assertIsInstance(result, computation_utils.StatefulBroadcastFn)
def test_iterative_process_with_encoding(self):
    """Runs the shared iterative-process checks on an encoded process.

    The process uses the 'gather' test encoder for the mean aggregation of
    client deltas and the 'simple' test encoder for the model broadcast. The
    assertions themselves live in `self._verify_iterative_process`.
    """
    linear_model_fn = model_examples.LinearRegression
    encoded_mean = encoding_utils.build_encoded_mean_from_model(
        linear_model_fn, _test_encoder_fn('gather'))
    encoded_broadcast = encoding_utils.build_encoded_broadcast_from_model(
        linear_model_fn, _test_encoder_fn('simple'))
    process_kwargs = dict(
        model_fn=linear_model_fn,
        model_to_client_delta_fn=DummyClientDeltaFn,
        server_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=1.0),
        stateful_delta_aggregate_fn=encoded_mean,
        stateful_model_broadcast_fn=encoded_broadcast,
    )
    self._verify_iterative_process(
        optimizer_utils.build_model_delta_optimizer_process(**process_kwargs))
def test_iterative_process_with_encoding(self):
    """Checks the broadcast state when only the model broadcast is encoded.

    The broadcast is built from `TrainableLinearRegression` with the default
    test encoder. No `stateful_delta_aggregate_fn` is passed, so aggregation
    is whatever `build_model_delta_optimizer_process` does by default. Element
    `[0]` of `model_broadcast_state.trainable.a` must be 1 after
    `initialize()` and 2 after one call to `next()`.
    """
    trainable_model_fn = model_examples.TrainableLinearRegression
    encoded_broadcast = encoding_utils.build_encoded_broadcast_from_model(
        trainable_model_fn, _test_encoder_fn())
    process = optimizer_utils.build_model_delta_optimizer_process(
        model_fn=trainable_model_fn,
        model_to_client_delta_fn=DummyClientDeltaFn,
        server_optimizer_fn=(
            lambda: tf.keras.optimizers.SGD(learning_rate=1.0)),
        stateful_model_broadcast_fn=encoded_broadcast)

    examples = {'x': [[1., 2.], [3., 4.]], 'y': [[5.], [6.]]}
    batched = tf.data.Dataset.from_tensor_slices(examples).batch(2)
    client_datasets = [batched, batched, batched]

    server_state = process.initialize()
    self.assertEqual(server_state.model_broadcast_state.trainable.a[0], 1)
    server_state, _ = process.next(server_state, client_datasets)
    self.assertEqual(server_state.model_broadcast_state.trainable.a[0], 2)
def test_broadcast_from_model(self):
    """Checks that the encoded broadcast builder returns a `tff.utils.StatefulBroadcastFn`."""
    built = encoding_utils.build_encoded_broadcast_from_model(
        model_examples.LinearRegression, _test_encoder_fn('simple'))
    self.assertIsInstance(built, tff.utils.StatefulBroadcastFn)