def test_run(self):
    """Runs MimeLite for several rounds and checks that loss and accuracy improve.

    Builds a toy federated dataset, trains with MimeLite for 10 rounds over all
    clients, and compares metrics on the pooled data before and after training.
    """
    num_classes = 10
    federated_data, model = core.test_util.create_toy_example(
        num_clients=10,
        num_clusters=4,
        num_classes=num_classes,
        num_examples=5,
        seed=0)
    rng_seq = core.PRNGSequence(0)

    base_optimizer = core.get_optimizer(
        core.OptimizerName.MOMENTUM, learning_rate=0.2, momentum=0.9)
    hparams = mime_lite.MimeLiteHParams(
        train_data_hparams=core.ClientDataHParams(batch_size=100),
        combined_data_hparams=core.ClientDataHParams(batch_size=100),
        server_learning_rate=1.0)
    algorithm = mime_lite.MimeLite(
        federated_data=federated_data,
        model=model,
        base_optimizer=base_optimizer,
        hparams=hparams,
        rng_seq=rng_seq,
    )

    # All clients' examples pooled into one dataset, used only for evaluation.
    eval_dataset = core.create_tf_dataset_for_clients(federated_data).batch(50)

    state = algorithm.init_state()
    init_metrics = core.evaluate_single_client(
        dataset=eval_dataset, model=model, params=state.params)

    num_rounds = 10
    for _ in range(num_rounds):
        state = algorithm.run_round(state, federated_data.client_ids)

    final_metrics = core.evaluate_single_client(
        dataset=eval_dataset, model=model, params=state.params)
    self.assertLess(final_metrics['loss'], init_metrics['loss'])
    self.assertGreater(final_metrics['accuracy'], init_metrics['accuracy'])
def test_run(self):
    """Runs FedAvg for several rounds and checks that loss and accuracy improve.

    Uses SGD on clients and momentum on the server, trains for 10 rounds over all
    clients, and compares metrics on the pooled data before and after training.
    """
    federated_data, model = core.test_util.create_toy_example(
        num_clients=10,
        num_clusters=4,
        num_classes=10,
        num_examples=5,
        seed=0)
    # All clients' examples pooled into one dataset, used only for evaluation.
    eval_dataset = core.create_tf_dataset_for_clients(federated_data).batch(50)

    client_optimizer = core.get_optimizer(
        core.OptimizerName.SGD, learning_rate=0.1)
    server_optimizer = core.get_optimizer(
        core.OptimizerName.MOMENTUM, learning_rate=2.0, momentum=0.9)
    hparams = fed_avg.FedAvgHParams(
        train_data_hparams=core.ClientDataHParams(batch_size=100, num_epochs=1))
    algorithm = fed_avg.FedAvg(
        federated_data=federated_data,
        model=model,
        client_optimizer=client_optimizer,
        server_optimizer=server_optimizer,
        hparams=hparams,
        rng_seq=core.PRNGSequence(0))

    state = algorithm.init_state()
    init_metrics = core.evaluate_single_client(
        dataset=eval_dataset, model=model, params=state.params)

    num_rounds = 10
    for _ in range(num_rounds):
        state = algorithm.run_round(state, federated_data.client_ids)

    final_metrics = core.evaluate_single_client(
        dataset=eval_dataset, model=model, params=state.params)
    self.assertLess(final_metrics['loss'], init_metrics['loss'])
    self.assertGreater(final_metrics['accuracy'], init_metrics['accuracy'])
def __call__(self, state: Any, round_num: int) -> core.MetricResults:
    """Evaluates `state.params` on a sampled subset of clients for this round.

    `self._num_clients_per_round` client ids are drawn without replacement from
    `self._federated_data.client_ids`. The random state comes from
    `get_pseudo_random_state(round_num, self._sample_client_random_seed)`.
    Presumably the same round and seed yield the same sample (verify against
    `get_pseudo_random_state`). The sampled clients' data is combined into one
    dataset and evaluated.

    Args:
      state: Algorithm state; only `state.params` is read.
      round_num: Current round number, used to seed client sampling.

    Returns:
      The metrics produced by `core.evaluate_single_client`.
    """
    rng = get_pseudo_random_state(round_num, self._sample_client_random_seed)
    sampled = rng.choice(
        self._federated_data.client_ids,
        size=self._num_clients_per_round,
        replace=False)
    client_ids = list(sampled)

    combined_dataset = core.create_tf_dataset_for_clients(
        self._federated_data, client_ids=client_ids)
    return core.evaluate_single_client(
        dataset=combined_dataset, model=self._model, params=state.params)
def test_run(self):
    """Trains HypCluster and checks the cluster count and the client assignment.

    After 10 rounds of training, the state must hold one set of params per
    cluster. Each client returned by `hyp_cluster.maximization` under a given
    cluster index must also have its lowest evaluation loss on that cluster's
    params.
    """
    num_classes = 10
    num_clusters = 3
    federated_data, model = core.test_util.create_toy_example(
        num_clients=10,
        num_clusters=num_clusters,
        num_classes=num_classes,
        num_examples=5,
        seed=0)
    rng_seq = core.PRNGSequence(0)

    client_optimizer = core.get_optimizer(
        core.OptimizerName.SGD, learning_rate=0.1)
    server_optimizer = core.get_optimizer(
        core.OptimizerName.SGD, learning_rate=1.0)
    hparams = hyp_cluster.HypClusterHParams(
        train_data_hparams=core.ClientDataHParams(batch_size=5),
        num_clusters=num_clusters)
    algorithm = hyp_cluster.HypCluster(
        federated_data=federated_data,
        model=model,
        client_optimizer=client_optimizer,
        server_optimizer=server_optimizer,
        hparams=hparams,
        rng_seq=rng_seq,
    )

    state = algorithm.init_state()
    num_rounds = 10
    for _ in range(num_rounds):
        state = algorithm.run_round(state, federated_data.client_ids)

    with self.subTest('num_clusters'):
        self.assertLen(state.cluster_params, num_clusters)

    with self.subTest('maximization'):
        eval_hparams = core.ClientDataHParams(batch_size=5)
        assignments = hyp_cluster.maximization(
            federated_data, federated_data.client_ids, model,
            state.cluster_params, eval_hparams)
        for assigned_cluster, members in enumerate(assignments):
            for client_id in members:
                client_dataset = core.preprocess_tf_dataset(
                    federated_data.create_tf_dataset_for_client(client_id),
                    eval_hparams)
                losses = [
                    core.evaluate_single_client(
                        client_dataset, model, params)['loss']
                    for params in state.cluster_params
                ]
                # Clients are clustered by empirical loss, so the assigned
                # cluster should be the one with the lowest loss for them.
                self.assertEqual(assigned_cluster, losses.index(min(losses)))
def __call__(self, state: Any, round_num: int) -> core.MetricResults:
    """Evaluates `state.params` on this evaluator's fixed `self._dataset`.

    Args:
      state: Algorithm state; only `state.params` is read.
      round_num: Ignored; the same dataset is evaluated every round.

    Returns:
      The metrics produced by `core.evaluate_single_client`.
    """
    del round_num  # Unused: evaluation does not depend on the round.
    return core.evaluate_single_client(
        dataset=self._dataset, model=self._model, params=state.params)
def __call__(self, state: Any, round_num: int) -> core.MetricResults:
    """Evaluates `state.params` on the clients chosen by `self._sample_clients`.

    The client ids returned for `round_num` are combined into a single dataset,
    and the params are evaluated on it.

    Args:
      state: Algorithm state; only `state.params` is read.
      round_num: Current round number, forwarded to `self._sample_clients`.

    Returns:
      The metrics produced by `core.evaluate_single_client`.
    """
    sampled_ids = self._sample_clients(round_num)
    combined_dataset = core.create_tf_dataset_for_clients(
        self._federated_data, client_ids=sampled_ids)
    return core.evaluate_single_client(
        dataset=combined_dataset, model=self._model, params=state.params)