コード例 #1
0
    def test_simple_training(self):
        """Run three rounds of the iterative process and check the loss drops.

        The process built by `build_federated_averaging_process` is advanced
        three times on the same constant batch. Before every round, and once
        more after the last one, the server weights are copied into a
        separate Keras model and evaluated on that batch. The test passes
        when the mean of the losses reported for the later rounds is below
        the loss reported for the first round.
        """
        process = build_federated_averaging_process(models.model_fn)
        state = process.initialize()
        Batch = collections.namedtuple('Batch', ['x', 'y'])  # pylint: disable=invalid-name

        # Standalone Keras model used to exercise manual weight assignment.
        keras_model = models.create_keras_model(compile_model=True)

        def deterministic_batch():
            # All-ones features and labels keep the input identical each round.
            features = np.ones([1, 784], dtype=np.float32)
            labels = np.ones([1, 1], dtype=np.int64)
            return Batch(x=features, y=labels)

        batch = tff.tf_computation(deterministic_batch)()
        # Nested list around the single batch; presumably one client holding
        # one batch -- confirm against the process's expected input type.
        federated_data = [[batch]]

        def push_weights_and_evaluate(current_state):
            tff.learning.assign_weights_to_keras_model(
                keras_model, current_state.model)
            # N.B. The loss computed here won't match the
            # loss computed by TFF because of the Dropout layer.
            # The returned value is intentionally not kept.
            keras_model.test_on_batch(batch.x, batch.y)

        num_rounds = 3
        round_losses = []
        for _ in range(num_rounds):
            push_weights_and_evaluate(state)
            state, round_loss = process.next(state, federated_data)
            round_losses.append(round_loss)
        push_weights_and_evaluate(state)

        later_mean = np.mean(round_losses[1:])
        first_loss = round_losses[0]
        self.assertLess(later_mean, first_loss)
コード例 #2
0
 def metrics_hook(state, metrics, round_num):
     """Evaluate the weights in `state.model` and record the result.

     A freshly created, compiled Keras model receives the weights from
     `state.model`, is evaluated on `batch`, and the value returned by
     `test_on_batch` is appended to `loss_list`. Both `batch` and
     `loss_list` come from an enclosing scope that is not shown here.

     Args:
         state: Object exposing a `model` attribute holding the weights.
         metrics: Unused.
         round_num: Unused.
     """
     # Neither argument is needed; drop them explicitly.
     del metrics, round_num
     eval_model = models.create_keras_model(compile_model=True)
     tff.learning.assign_weights_to_keras_model(eval_model, state.model)
     batch_result = eval_model.test_on_batch(batch.x, batch.y)
     loss_list.append(batch_result)