예제 #1
0
    def test_fed_avg_with_adaptive_client_and_server(self, optimizer):
        """Checks that loss drops when client and server share an optimizer type.

        Args:
          optimizer: An optimizer constructor supplied by the test's
            parameterization (not visible here). It must accept an `epsilon`
            keyword, which is fixed to 0.01 for both client and server.
            Presumably an adaptive Keras optimizer class such as Adam; confirm
            against the parameterization decorator.
        """
        num_rounds = 5
        data = [[_batch_fn()]]

        # The client and server each get their own partial with identical
        # hyperparameters.
        iterative_process = fed_avg_client_opt.build_iterative_process(
            _uncompiled_model_builder,
            client_optimizer_fn=functools.partial(optimizer, epsilon=0.01),
            server_optimizer_fn=functools.partial(optimizer, epsilon=0.01))

        # `_run_rounds` appears to return (final_state, per_round_outputs).
        _, per_round = self._run_rounds(iterative_process, data, num_rounds)
        last_loss = per_round[num_rounds - 1]['loss']
        first_loss = per_round[0]['loss']
        self.assertLess(last_loss, first_loss)
예제 #2
0
    def test_fed_avg_decreases_loss(self):
        """Checks that plain SGD on client and server lowers loss over five rounds."""
        num_rounds = 5
        data = [[_batch_fn()]]

        # Both sides use the bare SGD class with its default hyperparameters.
        sgd = tf.keras.optimizers.SGD
        iterative_process = fed_avg_client_opt.build_iterative_process(
            _uncompiled_model_builder,
            client_optimizer_fn=sgd,
            server_optimizer_fn=sgd)

        # `_run_rounds` appears to return (final_state, per_round_outputs).
        _, per_round = self._run_rounds(iterative_process, data, num_rounds)
        last_loss = per_round[num_rounds - 1]['loss']
        first_loss = per_round[0]['loss']
        self.assertLess(last_loss, first_loss)
예제 #3
0
    def test_state_types(self):
        """Checks the types of the server state after a single training round.

        The client uses Adam with epsilon=0.01 and the server uses SGD. After
        one round, the returned state must be a `ServerState` whose `model`
        field is a `tff.learning.ModelWeights`.
        """
        data = [[_batch_fn()]]

        adam_client = functools.partial(tf.keras.optimizers.Adam, epsilon=0.01)
        iterative_process = fed_avg_client_opt.build_iterative_process(
            _uncompiled_model_builder,
            client_optimizer_fn=adam_client,
            server_optimizer_fn=tf.keras.optimizers.SGD)

        final_state, _ = self._run_rounds(iterative_process, data, 1)
        self.assertIsInstance(final_state, fed_avg_client_opt.ServerState)
        self.assertIsInstance(final_state.model, tff.learning.ModelWeights)
예제 #4
0
    def test_client_state_aggregate_mean(self, n, m):
        """Checks that client optimizer state is mean-aggregated across clients.

        Two clients hold `n` and `m` batches respectively. The client optimizer
        weight function always returns 1.0, so after one round the aggregated
        `iterations` counter is expected to equal `(n + m) // 2`. This
        presumably relies on each batch producing one optimizer step; confirm
        against the implementation.

        Args:
          n: Number of batches for the first client (from parameterization).
          m: Number of batches for the second client (from parameterization).
        """
        # Build the first client's data (n batches), then the second's (m).
        federated_data = [[_batch_fn() for _ in range(count)]
                          for count in (n, m)]

        iterative_process = fed_avg_client_opt.build_iterative_process(
            _uncompiled_model_builder,
            client_optimizer_fn=functools.partial(
                tf.keras.optimizers.Adam, epsilon=0.01),
            server_optimizer_fn=tf.keras.optimizers.SGD,
            # Uniform weighting, so the aggregate is a plain mean.
            client_opt_weight_fn=lambda x: 1.0)

        state, _ = iterative_process.next(iterative_process.initialize(),
                                          federated_data)

        expected_iterations = (n + m) // 2
        self.assertEqual(state.client_optimizer_state.iterations,
                         expected_iterations)
예제 #5
0
    def test_get_model_weights(self):
        """Checks that `get_model_weights` mirrors `state.model` at every round.

        The check runs once on the initial state and again after each of
        three training rounds.
        """
        federated_data = [[_batch_fn()]]

        iterative_process = fed_avg_client_opt.build_iterative_process(
            _uncompiled_model_builder,
            client_optimizer_fn=tf.keras.optimizers.SGD,
            server_optimizer_fn=tf.keras.optimizers.SGD)

        def check_weights(current):
            # Type check first, then compare the trainable weights against
            # the state's own model field.
            self.assertIsInstance(
                iterative_process.get_model_weights(current),
                tff.learning.ModelWeights)
            self.assertAllClose(
                current.model.trainable,
                iterative_process.get_model_weights(current).trainable)

        state = iterative_process.initialize()
        check_weights(state)

        round_index = 0
        while round_index < 3:
            state, _ = iterative_process.next(state, federated_data)
            check_weights(state)
            round_index += 1