def test_periodic_eval_fn_map(self):
  """Runs 5 rounds with eval_frequency=3 and two periodic eval fns.

  The fake eval fns are handed the states (and, for the train-clients fake,
  the client ids) they are expected to observe. The test then checks the
  final training state and that exactly the two eval fns produced
  '*eval*' entries under root_dir.
  """
  config = federated_experiment.FederatedExperimentConfig(
      root_dir=self.create_tempdir(), num_rounds=5, eval_frequency=3)
  sampler = FakeClientSampler()
  algorithm = fake_algorithm()
  # Presumably keyed by round number; the fakes compare against these.
  states_to_see = {1: (1, 1), 3: (3, 6)}
  client_ids_to_see = {1: [1], 3: [3]}
  eval_fns = {
      'test_eval': FakeEvaluationFn(self, states_to_see),
      'train_eval': FakeTrainClientsEvaluationFn(
          self, states_to_see, client_ids_to_see),
  }
  final_state = federated_experiment.run_federated_experiment(
      algorithm=algorithm,
      init_state=algorithm.init(),
      client_sampler=sampler,
      config=config,
      periodic_eval_fn_map=eval_fns)
  self.assertEqual(final_state, (5, 15))
  # Each eval fn is expected to own an output entry named after its key.
  expected_outputs = [os.path.join(config.root_dir, key) for key in eval_fns]
  self.assertCountEqual(
      glob.glob(os.path.join(config.root_dir, '*eval*')), expected_outputs)
def test_run_federated_experiment_periodic_eval_fn_map(self):
  """Every periodic eval fn leaves an output path named after its key."""
  config = federated_experiment.FederatedExperimentConfig(
      root_dir=self.create_tempdir(),
      num_rounds=5,
      num_clients_per_round=2,
      eval_frequency=3)
  mock_algorithm = test_util.MockFederatedAlgorithm(num_examples=3)
  data = mock_algorithm.federated_data
  model = mock_algorithm.model

  eval_fns = collections.OrderedDict()
  eval_fns['client_eval_1'] = federated_experiment.ClientEvaluationFn(
      data, model, config)
  eval_fns['client_eval_2'] = federated_experiment.ClientEvaluationFn(
      data, model, config)
  eval_fns['full_eval_1'] = federated_experiment.FullEvaluationFn(data, model)

  federated_experiment.run_federated_experiment(
      config, mock_algorithm, periodic_eval_fn_map=eval_fns)

  # Check each key in insertion order: client_eval_1, client_eval_2, full_eval_1.
  for key in eval_fns:
    self.assertTrue(os.path.exists(os.path.join(config.root_dir, key)))
def test_checkpoint(self):
  """Checks that checkpoints are written periodically and picked up by a rerun.

  Both sub-tests share one config, and therefore one root_dir. The second run
  passes init_state=None and is expected to resume from the checkpoint left
  on disk by the first run.
  """
  # Shared setup is hoisted out of the subTests. The restore sub-test then no
  # longer depends on names that were only bound inside the first subTest body.
  config = federated_experiment.FederatedExperimentConfig(
      root_dir=self.create_tempdir(), num_rounds=5, checkpoint_frequency=3)
  algorithm = fake_algorithm()

  with self.subTest('checkpoint init'):
    state = federated_experiment.run_federated_experiment(
        algorithm=algorithm,
        init_state=(1, -1),
        client_sampler=FakeClientSampler(),
        config=config)
    self.assertEqual(state, (6, 14))
    self.assertCountEqual(
        glob.glob(os.path.join(config.root_dir, 'checkpoint_*')),
        [os.path.join(config.root_dir, 'checkpoint_00000003')])

  with self.subTest('checkpoint restore'):
    # Restored state is (4, 5). FakeSampler produces clients [5, 6].
    state = federated_experiment.run_federated_experiment(
        algorithm=algorithm,
        init_state=None,
        client_sampler=FakeClientSampler(1),
        config=config)
    self.assertEqual(state, (6, 16))
    self.assertCountEqual(
        glob.glob(os.path.join(config.root_dir, 'checkpoint_*')),
        [os.path.join(config.root_dir, 'checkpoint_00000004')])
def get(self):
  """Returns a FederatedExperimentConfig built from the registered flags.

  Each config field is read, via _get_flag, from the flag of the same name.
  """
  field_names = ('root_dir', 'num_rounds', 'checkpoint_frequency',
                 'num_checkpoints_to_keep', 'eval_frequency')
  field_values = {name: self._get_flag(name) for name in field_names}
  return federated_experiment.FederatedExperimentConfig(**field_values)
def test_get(self):
  """The flag-backed config reflects both baseline values and flag overrides.

  The 'default' case expects root_dir='foo' and num_rounds=1234. Presumably
  those values are set up elsewhere in this test class; confirm against the
  class setup.
  """
  flags_holder = self.FEDERATED_EXPERIMENT_CONFIG

  with self.subTest('default'):
    actual = flags_holder.get()
    self.assertEqual(
        actual,
        federated_experiment.FederatedExperimentConfig(
            root_dir='foo', num_rounds=1234))

  with self.subTest('custom'):
    # Flag names and config field names coincide, so one mapping drives both.
    overrides = dict(
        root_dir='bar',
        num_rounds=567,
        checkpoint_frequency=2,
        num_checkpoints_to_keep=3,
        eval_frequency=4)
    with flagsaver.flagsaver(**overrides):
      actual = flags_holder.get()
      self.assertEqual(
          actual, federated_experiment.FederatedExperimentConfig(**overrides))
def test_no_eval(self):
  """Training with no eval fn maps still runs all 5 rounds to (5, 15)."""
  experiment_config = federated_experiment.FederatedExperimentConfig(
      root_dir=self.create_tempdir(), num_rounds=5)
  sampler = FakeClientSampler()
  algo = fake_algorithm()
  initial_state = algo.init()
  final_state = federated_experiment.run_federated_experiment(
      algorithm=algo,
      init_state=initial_state,
      client_sampler=sampler,
      config=experiment_config)
  self.assertEqual(final_state, (5, 15))
def test_run_federated_experiment_final_eval_fn_map(self):
  """A final eval fn named 'full_eval' writes 'full_eval.tsv' under root_dir."""
  config = federated_experiment.FederatedExperimentConfig(
      root_dir=self.create_tempdir(), num_rounds=5, num_clients_per_round=2)
  mock_algorithm = test_util.MockFederatedAlgorithm(num_examples=3)
  full_eval_fn = federated_experiment.FullEvaluationFn(
      mock_algorithm.federated_data, mock_algorithm.model)
  final_eval_fns = collections.OrderedDict([('full_eval', full_eval_fn)])

  federated_experiment.run_federated_experiment(
      config, mock_algorithm, final_eval_fn_map=final_eval_fns)

  tsv_path = os.path.join(config.root_dir, 'full_eval.tsv')
  self.assertTrue(os.path.exists(tsv_path))
def test_final_eval_fn_map(self):
  """Runs a final eval fn after 5 rounds and checks its .tsv output exists.

  The fake eval fn is told to expect the state (5, 15), keyed by 5 (presumably
  the last round number).
  """
  config = federated_experiment.FederatedExperimentConfig(
      root_dir=self.create_tempdir(), num_rounds=5, eval_frequency=3)
  client_sampler = FakeClientSampler()
  algorithm = fake_algorithm()
  expected_state = {5: (5, 15)}
  state = federated_experiment.run_federated_experiment(
      algorithm=algorithm,
      init_state=algorithm.init(),
      client_sampler=client_sampler,
      config=config,
      final_eval_fn_map={
          'final_eval': FakeEvaluationFn(self, expected_state)
      })
  self.assertEqual(state, (5, 15))
  # Check the exact path directly. Globbing a literal path was only an
  # indirect existence check, and it would misfire if root_dir contained glob
  # metacharacters such as '[' or '*'.
  self.assertTrue(
      os.path.exists(os.path.join(config.root_dir, 'final_eval.tsv')))
def __init__(self, name: Optional[str] = None):
  """Registers one flag per FederatedExperimentConfig field.

  Args:
    name: Optional name, forwarded unchanged to the base class constructor.
  """
  super().__init__(name)
  # Throwaway instance used only to read the config's default field values.
  # The root_dir/num_rounds placeholders passed here are never used.
  defaults = federated_experiment.FederatedExperimentConfig(
      root_dir='', num_rounds=-1)
  # root_dir and num_rounds are registered with a None default.
  self._string('root_dir', None, 'Root directory of experiment outputs')
  integer_flags = (
      ('num_rounds', None, 'Number of federated training rounds'),
      ('checkpoint_frequency', defaults.checkpoint_frequency,
       'Checkpoint frequency in rounds. If <= 0, no checkpointing is done.'),
      ('num_checkpoints_to_keep', defaults.num_checkpoints_to_keep,
       'Maximum number of checkpoints to keep'),
      ('eval_frequency', defaults.eval_frequency,
       'Evaluation frequency in rounds. If <= 0, no evaluation is done.'),
  )
  for flag_name, flag_default, help_text in integer_flags:
    self._integer(flag_name, flag_default, help_text)
def test_run_federated_experiment(self):
  """Runs the mock algorithm end to end.

  Checks the final state's count against rounds * clients per round *
  examples, and checks that a checkpoint for round 3 was written.
  """
  config = federated_experiment.FederatedExperimentConfig(
      root_dir=self.create_tempdir(),
      num_rounds=5,
      num_clients_per_round=2,
      checkpoint_frequency=3)
  examples_per_client = 3
  mock_algorithm = test_util.MockFederatedAlgorithm(
      num_examples=examples_per_client)

  final_state = federated_experiment.run_federated_experiment(
      config, mock_algorithm)

  total_examples = (
      config.num_rounds * config.num_clients_per_round * examples_per_client)
  self.assertEqual(final_state.count, total_examples)
  checkpoint_path = os.path.join(config.root_dir, 'checkpoint_00000003')
  self.assertTrue(os.path.exists(checkpoint_path))