Esempio n. 1
0
  def test_run(self):
    """Ten rounds of MimeLite on a toy problem should lower loss and raise accuracy.

    The model is evaluated on the combined data of all clients before and
    after training, and the two sets of metrics are compared.
    """
    n_classes = 10
    toy_data, toy_model = core.test_util.create_toy_example(
        num_clients=10,
        num_clusters=4,
        num_classes=n_classes,
        num_examples=5,
        seed=0)
    prng = core.PRNGSequence(0)
    momentum_optimizer = core.get_optimizer(
        core.OptimizerName.MOMENTUM, learning_rate=0.2, momentum=0.9)
    mime_hparams = mime_lite.MimeLiteHParams(
        train_data_hparams=core.ClientDataHParams(batch_size=100),
        combined_data_hparams=core.ClientDataHParams(batch_size=100),
        server_learning_rate=1.0)
    algorithm = mime_lite.MimeLite(
        federated_data=toy_data,
        model=toy_model,
        base_optimizer=momentum_optimizer,
        hparams=mime_hparams,
        rng_seq=prng,
    )
    eval_dataset = core.create_tf_dataset_for_clients(toy_data).batch(50)

    def _evaluate(params):
      # Same evaluation data and model for the "before" and "after" snapshots.
      return core.evaluate_single_client(
          dataset=eval_dataset, model=toy_model, params=params)

    state = algorithm.init_state()
    before = _evaluate(state.params)
    for _round in range(10):
      state = algorithm.run_round(state, toy_data.client_ids)
    after = _evaluate(state.params)

    self.assertLess(after['loss'], before['loss'])
    self.assertGreater(after['accuracy'], before['accuracy'])
Esempio n. 2
0
    def test_run(self):
        """Ten FedAvg rounds should lower loss and raise accuracy.

        The client optimizer is SGD and the server optimizer is momentum.
        Evaluation uses the combined data of every toy client, batched by 50.
        """
        toy_data, toy_model = core.test_util.create_toy_example(
            num_clients=10, num_clusters=4, num_classes=10, num_examples=5,
            seed=0)
        eval_dataset = core.create_tf_dataset_for_clients(toy_data).batch(50)

        client_opt = core.get_optimizer(core.OptimizerName.SGD,
                                        learning_rate=0.1)
        server_opt = core.get_optimizer(core.OptimizerName.MOMENTUM,
                                        learning_rate=2.0,
                                        momentum=0.9)
        fedavg_hparams = fed_avg.FedAvgHParams(
            train_data_hparams=core.ClientDataHParams(batch_size=100,
                                                      num_epochs=1))
        algorithm = fed_avg.FedAvg(federated_data=toy_data,
                                   model=toy_model,
                                   client_optimizer=client_opt,
                                   server_optimizer=server_opt,
                                   hparams=fedavg_hparams,
                                   rng_seq=core.PRNGSequence(0))

        state = algorithm.init_state()
        before = core.evaluate_single_client(dataset=eval_dataset,
                                             model=toy_model,
                                             params=state.params)
        completed_rounds = 0
        while completed_rounds < 10:
            state = algorithm.run_round(state, toy_data.client_ids)
            completed_rounds += 1

        after = core.evaluate_single_client(dataset=eval_dataset,
                                            model=toy_model,
                                            params=state.params)
        self.assertLess(after['loss'], before['loss'])
        self.assertGreater(after['accuracy'], before['accuracy'])
Esempio n. 3
0
 def __call__(self, state: Any, round_num: int) -> core.MetricResults:
     """Evaluates `state.params` on a random subset of clients for this round.

     The sampling RNG comes from `get_pseudo_random_state` and is seeded with
     `round_num` and `self._sample_client_random_seed`. From it,
     `self._num_clients_per_round` distinct client ids are drawn without
     replacement. Their data is merged into one dataset and evaluated as a
     single client.

     Args:
       state: Object exposing a `params` attribute to evaluate.
       round_num: Current round; used to seed client sampling.

     Returns:
       The metrics returned by `core.evaluate_single_client`.
     """
     rng = get_pseudo_random_state(round_num, self._sample_client_random_seed)
     drawn = rng.choice(self._federated_data.client_ids,
                        size=self._num_clients_per_round,
                        replace=False)
     # Materialize as a plain list before handing the ids to `core`.
     sampled_ids = list(drawn)
     merged_dataset = core.create_tf_dataset_for_clients(
         self._federated_data, client_ids=sampled_ids)
     return core.evaluate_single_client(
         merged_dataset, self._model, state.params)
Esempio n. 4
0
    def test_run(self):
        """HypCluster keeps `num_clusters` models and assigns clients sensibly.

        Trains for 10 rounds on a toy federated dataset, then checks that:
          * the state holds exactly `num_clusters` parameter sets, and
          * `hyp_cluster.maximization` places every client in the cluster whose
            params give that client the lowest evaluation loss.
        """
        num_classes = 10
        num_clusters = 3
        federated_data, model = core.test_util.create_toy_example(
            num_clients=10,
            num_clusters=num_clusters,
            num_classes=num_classes,
            num_examples=5,
            seed=0)
        # Shared by training and the maximization check so both see each
        # client's data preprocessed identically. Previously two independent
        # copies could silently drift apart.
        data_hparams = core.ClientDataHParams(batch_size=5)
        rng_seq = core.PRNGSequence(0)
        algorithm = hyp_cluster.HypCluster(
            federated_data=federated_data,
            model=model,
            client_optimizer=core.get_optimizer(core.OptimizerName.SGD,
                                                learning_rate=0.1),
            server_optimizer=core.get_optimizer(core.OptimizerName.SGD,
                                                learning_rate=1.0),
            hparams=hyp_cluster.HypClusterHParams(
                train_data_hparams=data_hparams,
                num_clusters=num_clusters),
            rng_seq=rng_seq,
        )

        state = algorithm.init_state()
        for _ in range(10):
            state = algorithm.run_round(state, federated_data.client_ids)

        with self.subTest('num_clusters'):
            self.assertLen(state.cluster_params, num_clusters)

        with self.subTest('maximization'):
            cluster_client_ids = hyp_cluster.maximization(
                federated_data, federated_data.client_ids, model,
                state.cluster_params, data_hparams)
            for cluster_id, client_ids in enumerate(cluster_client_ids):
                for client_id in client_ids:
                    dataset = core.preprocess_tf_dataset(
                        federated_data.create_tf_dataset_for_client(client_id),
                        data_hparams)
                    cluster_loss = [
                        core.evaluate_single_client(dataset, model,
                                                    params)['loss']
                        for params in state.cluster_params
                    ]
                    # Clients are clustered by empirical loss, so the assigned
                    # cluster should be the one with the lowest loss.
                    self.assertEqual(cluster_id,
                                     cluster_loss.index(min(cluster_loss)),
                                     msg=f'client {client_id!r}')
Esempio n. 5
0
 def __call__(self, state: Any, round_num: int) -> core.MetricResults:
     """Evaluates `state.params` on the fixed dataset `self._dataset`.

     Args:
       state: Object exposing a `params` attribute to evaluate.
       round_num: Ignored here. Presumably kept so this matches a shared
         evaluation-callback signature; confirm against callers.

     Returns:
       The metrics returned by `core.evaluate_single_client`.
     """
     del round_num  # Unused: the evaluation data does not depend on the round.
     dataset, model = self._dataset, self._model
     return core.evaluate_single_client(dataset, model, state.params)
Esempio n. 6
0
 def __call__(self, state: Any, round_num: int) -> core.MetricResults:
   """Evaluates `state.params` on the clients sampled for `round_num`.

   Client ids come from `self._sample_clients(round_num)`. Their datasets are
   combined into one dataset and evaluated as if it were a single client.

   Args:
     state: Object exposing a `params` attribute to evaluate.
     round_num: Current round, forwarded to the client sampler.

   Returns:
     The metrics returned by `core.evaluate_single_client`.
   """
   sampled_ids = self._sample_clients(round_num)
   merged_dataset = core.create_tf_dataset_for_clients(
       self._federated_data, client_ids=sampled_ids)
   return core.evaluate_single_client(
       merged_dataset, self._model, state.params)