Exemplo n.º 1
0
    def test_simple_training(self):
        """Federated averaging on one fixed batch should reduce the loss.

        Runs three rounds of the iterative process on a single client that
        holds a single deterministic batch. Before each round, and once after
        the last, it copies the server model weights into a compiled Keras
        model and evaluates that batch. The Keras result is discarded; this
        only exercises weight assignment. Finally it asserts that the mean
        loss of rounds 2-3 is below the loss of round 1.
        """
        process = build_federated_averaging_process(models.model_fn)
        server_state = process.initialize()
        Batch = collections.namedtuple('Batch', ['x', 'y'])  # pylint: disable=invalid-name

        # Test out manually setting weights:
        keras_model = models.create_keras_model(compile_model=True)

        def deterministic_batch():
            features = np.ones([1, 784], dtype=np.float32)
            labels = np.ones([1, 1], dtype=np.int64)
            return Batch(x=features, y=labels)

        batch = tff.tf_computation(deterministic_batch)()
        # One client holding a single batch.
        federated_data = [[batch]]

        def sync_and_evaluate(state):
            tff.learning.assign_weights_to_keras_model(keras_model,
                                                       state.model)
            # The loss computed here differs from the one TFF reports
            # because of the Dropout layer, so it is not compared.
            keras_model.test_on_batch(batch.x, batch.y)

        round_losses = []
        for _round in range(3):
            sync_and_evaluate(server_state)
            server_state, round_loss = process.next(server_state,
                                                    federated_data)
            round_losses.append(round_loss)
        sync_and_evaluate(server_state)

        first_loss, *later_losses = round_losses
        self.assertLess(np.mean(later_losses), first_loss)
Exemplo n.º 2
0
    def test_self_contained_example_custom_model(self):
        """Two federated rounds with MnistTrainableModel should lower the loss.

        Builds the averaging process from the custom model class and trains on
        a single client's data for two rounds. Each round's `outputs.loss` is
        recorded, and the second must be below the first.
        """
        # A single client whose dataset comes from the client-data factory.
        train_data = [create_client_data()()]

        trainer = build_federated_averaging_process(MnistTrainableModel)
        state = trainer.initialize()

        round_losses = []
        for _round in range(2):
            state, metrics = trainer.next(state, train_data)
            round_losses.append(metrics.loss)

        first_loss, second_loss = round_losses
        self.assertLess(second_loss, first_loss)
Exemplo n.º 3
0
    def test_self_contained_example_keras_model(self):
        """Two federated rounds with a compiled Keras model should lower the loss.

        Takes the first element of a single client's dataset as the sample
        batch and wraps a simple Keras model with
        `tff.learning.from_compiled_keras_model`. It then trains for two rounds
        and asserts that the second round's `outputs.loss` is below the first.
        """
        train_data = [create_client_data()()]
        # The first element of the client dataset, evaluated to concrete values.
        sample_batch = self.evaluate(next(iter(train_data[0])))

        # Defined after sample_batch so the closure's dependency is explicit.
        def model_fn():
            keras_model = models.create_simple_keras_model()
            return tff.learning.from_compiled_keras_model(keras_model,
                                                          sample_batch)

        trainer = build_federated_averaging_process(model_fn)
        state = trainer.initialize()

        round_losses = []
        for _round in range(2):
            state, metrics = trainer.next(state, train_data)
            round_losses.append(metrics.loss)

        first_loss, second_loss = round_losses
        self.assertLess(second_loss, first_loss)
Exemplo n.º 4
0
 def test_something(self):
     """Checks the process type and the type of its federated data argument.

     The built process must be a `tff.utils.IterativeProcess`. The second
     parameter of its `next` computation must print as a CLIENTS-placed
     sequence of `<x, y>` batches.
     """
     expected_type = '{<x=float32[?,784],y=int64[?,1]>*}@CLIENTS'
     process = build_federated_averaging_process(models.model_fn)
     self.assertIsInstance(process, tff.utils.IterativeProcess)
     next_parameters = process.next.type_signature.parameter
     self.assertEqual(str(next_parameters[1]), expected_type)