def test_simple_training(self):
    """Run three rounds of the averaging process and check the loss drops.

    The process is built from `models.model_fn` and fed the same
    deterministic all-ones batch (wrapped as `[[batch]]`) on every round.
    Before each round, and once after the last one, the current server
    weights are pushed into a separately created Keras model to exercise
    manual weight assignment. The test passes when the mean loss of the
    rounds after the first is strictly below the loss of the first round.
    """
    num_rounds = 3

    it_process = build_federated_averaging_process(models.model_fn)
    server_state = it_process.initialize()
    Batch = collections.namedtuple('Batch', ['x', 'y'])  # pylint: disable=invalid-name

    # Keras model used to test out manually setting weights.
    keras_model = models.create_keras_model(compile_model=True)

    def deterministic_batch():
        """Return a fixed batch whose features and labels are all ones."""
        features = np.ones([1, 784], dtype=np.float32)
        labels = np.ones([1, 1], dtype=np.int64)
        return Batch(x=features, y=labels)

    batch = tff.tf_computation(deterministic_batch)()
    federated_data = [[batch]]

    def keras_evaluate(state):
        """Load `state.model` into the Keras model and evaluate the batch.

        The evaluation result is discarded; only the fact that assignment
        and evaluation succeed matters here.
        """
        tff.learning.assign_weights_to_keras_model(keras_model, state.model)
        # N.B. The loss computed here won't match the loss computed by TFF
        # because of the Dropout layer.
        keras_model.test_on_batch(batch.x, batch.y)

    round_losses = []
    for _ in range(num_rounds):
        keras_evaluate(server_state)
        round_output = it_process.next(server_state, federated_data)
        server_state, round_loss = round_output
        round_losses.append(round_loss)
    # One more evaluation with the weights produced by the final round.
    keras_evaluate(server_state)

    initial_loss = round_losses[0]
    later_losses = round_losses[1:]
    self.assertLess(np.mean(later_losses), initial_loss)
def metrics_hook(state, metrics, round_num):
    """Record the Keras-evaluated loss of the current server model.

    A fresh compiled Keras model is created on every call, the weights in
    `state.model` are copied into it, and the value returned by
    `test_on_batch` on the enclosing scope's `batch` is appended to the
    enclosing scope's `loss_list`.

    Args:
        state: Object exposing the model weights as `state.model`.
        metrics: Unused.
        round_num: Unused.
    """
    # Both arguments are part of the hook signature but are not needed here.
    del metrics, round_num

    eval_model = models.create_keras_model(compile_model=True)
    tff.learning.assign_weights_to_keras_model(eval_model, state.model)
    batch_loss = eval_model.test_on_batch(batch.x, batch.y)
    loss_list.append(batch_loss)