def test_simple_training(self):
    """Train three rounds on one fixed batch and expect the loss to fall.

    Before every round, and once after the last, the current server model is
    copied into a standalone Keras model and evaluated there. This exercises
    `tff.learning.assign_weights_to_keras_model` on the server state.
    The test asserts that the mean of the later rounds' losses is below the
    first round's loss.
    """
    process = build_federated_averaging_process(models.model_fn)
    server_state = process.initialize()
    Batch = collections.namedtuple('Batch', ['x', 'y'])  # pylint: disable=invalid-name

    # Standalone compiled Keras model that receives the server weights.
    keras_model = models.create_keras_model(compile_model=True)

    def deterministic_batch():
        features = np.ones([1, 784], dtype=np.float32)
        labels = np.ones([1, 1], dtype=np.int64)
        return Batch(x=features, y=labels)

    batch = tff.tf_computation(deterministic_batch)()
    # Federated input: a single client whose dataset is just `batch`.
    client_datasets = [[batch]]

    def sync_and_evaluate(state):
        tff.learning.assign_weights_to_keras_model(keras_model, state.model)
        # N.B. This loss is not expected to match the one TFF computes,
        # because of the model's Dropout layer.
        keras_model.test_on_batch(batch.x, batch.y)

    losses = []
    for _ in range(3):
        sync_and_evaluate(server_state)
        # The second value returned by `next` is recorded as the round's loss.
        server_state, round_loss = process.next(server_state, client_datasets)
        losses.append(round_loss)
    sync_and_evaluate(server_state)

    first_loss, *later_losses = losses
    self.assertLess(np.mean(later_losses), first_loss)
def test_self_contained_example_custom_model(self):
    """Two training rounds with `MnistTrainableModel` should lower the loss.

    The test runs FedAvg for two rounds on a single client's data. It asserts
    that the loss reported after the second round is below the first.
    """
    make_dataset = create_client_data()
    train_data = [make_dataset()]

    trainer = build_federated_averaging_process(MnistTrainableModel)
    state = trainer.initialize()

    round_losses = []
    for _ in range(2):
        state, outputs = trainer.next(state, train_data)
        round_losses.append(outputs.loss)  # track the per-round loss

    first_round_loss, second_round_loss = round_losses
    self.assertLess(second_round_loss, first_round_loss)
def test_self_contained_example_keras_model(self):
    """Two training rounds with a wrapped compiled Keras model should lower the loss.

    The model is built through `tff.learning.from_compiled_keras_model`, using a
    concrete batch taken from the client data as its sample batch. The test
    asserts that the second round's loss is below the first's.
    """
    make_dataset = create_client_data()
    train_data = [make_dataset()]
    # First element of the (only) client's dataset, evaluated to concrete values.
    sample_batch = self.evaluate(next(iter(train_data[0])))

    # Defined after `sample_batch` so the name it closes over is already bound.
    def model_fn():
        return tff.learning.from_compiled_keras_model(
            models.create_simple_keras_model(), sample_batch)

    trainer = build_federated_averaging_process(model_fn)
    state = trainer.initialize()

    round_losses = []
    for _ in range(2):
        state, outputs = trainer.next(state, train_data)
        round_losses.append(outputs.loss)  # track the per-round loss

    first_round_loss, second_round_loss = round_losses
    self.assertLess(second_round_loss, first_round_loss)
def test_something(self):
    """Check the type of the FedAvg process built from `models.model_fn`.

    The process must be a `tff.utils.IterativeProcess`. The second parameter
    of its `next` computation must be client-placed data: a sequence of
    batches with float32 `x` of shape [?, 784] and int64 `y` of shape [?, 1].
    """
    process = build_federated_averaging_process(models.model_fn)
    self.assertIsInstance(process, tff.utils.IterativeProcess)

    next_params = process.next.type_signature.parameter
    client_data_type = next_params[1]
    expected_type_str = '{<x=float32[?,784],y=int64[?,1]>*}@CLIENTS'
    self.assertEqual(str(client_data_type), expected_type_str)