コード例 #1
0
 def test_periodic_eval_fn_map(self):
   """Periodic eval fns each produce one output named after their key."""
   config = federated_experiment.FederatedExperimentConfig(
       root_dir=self.create_tempdir(), num_rounds=5, eval_frequency=3)
   sampler = FakeClientSampler()
   algorithm = fake_algorithm()
   # Keyed by round number; presumably the fakes check the state / client ids
   # they receive against these mappings.
   round_to_state = {1: (1, 1), 3: (3, 6)}
   round_to_client_ids = {1: [1], 3: [3]}
   eval_fns = {
       'test_eval': FakeEvaluationFn(self, round_to_state),
       'train_eval': FakeTrainClientsEvaluationFn(self, round_to_state,
                                                  round_to_client_ids),
   }
   final_state = federated_experiment.run_federated_experiment(
       algorithm=algorithm,
       init_state=algorithm.init(),
       client_sampler=sampler,
       config=config,
       periodic_eval_fn_map=eval_fns)
   self.assertEqual(final_state, (5, 15))
   # Exactly one '*eval*' entry per map key must appear under root_dir.
   produced = glob.glob(os.path.join(config.root_dir, '*eval*'))
   expected = [os.path.join(config.root_dir, key) for key in eval_fns]
   self.assertCountEqual(produced, expected)
コード例 #2
0
    def test_run_federated_experiment_periodic_eval_fn_map(self):
        """Every periodic eval fn gets an output path under root_dir."""
        config = federated_experiment.FederatedExperimentConfig(
            root_dir=self.create_tempdir(),
            num_rounds=5,
            num_clients_per_round=2,
            eval_frequency=3)
        mock_algorithm = test_util.MockFederatedAlgorithm(num_examples=3)
        data = mock_algorithm.federated_data
        model = mock_algorithm.model

        eval_fns = collections.OrderedDict()
        eval_fns['client_eval_1'] = federated_experiment.ClientEvaluationFn(
            data, model, config)
        eval_fns['client_eval_2'] = federated_experiment.ClientEvaluationFn(
            data, model, config)
        eval_fns['full_eval_1'] = federated_experiment.FullEvaluationFn(
            data, model)

        federated_experiment.run_federated_experiment(
            config, mock_algorithm, periodic_eval_fn_map=eval_fns)

        # Each map key must have produced a path of the same name.
        for key in eval_fns:
            self.assertTrue(
                os.path.exists(os.path.join(config.root_dir, key)))
コード例 #3
0
  def test_checkpoint(self):
    """Checkpoints are written during a run and restored when init_state=None."""
    checkpoint_glob = 'checkpoint_*'

    with self.subTest('checkpoint init'):
      config = federated_experiment.FederatedExperimentConfig(
          root_dir=self.create_tempdir(), num_rounds=5, checkpoint_frequency=3)
      sampler = FakeClientSampler()
      algorithm = fake_algorithm()
      final_state = federated_experiment.run_federated_experiment(
          algorithm=algorithm,
          init_state=(1, -1),
          client_sampler=sampler,
          config=config)
      self.assertEqual(final_state, (6, 14))
      found = glob.glob(os.path.join(config.root_dir, checkpoint_glob))
      self.assertCountEqual(
          found, [os.path.join(config.root_dir, 'checkpoint_00000003')])

    with self.subTest('checkpoint restore'):
      # Restored state is (4, 5). FakeClientSampler(1) produces clients [5, 6].
      final_state = federated_experiment.run_federated_experiment(
          algorithm=algorithm,
          init_state=None,
          client_sampler=FakeClientSampler(1),
          config=config)
      self.assertEqual(final_state, (6, 16))
      found = glob.glob(os.path.join(config.root_dir, checkpoint_glob))
      self.assertCountEqual(
          found, [os.path.join(config.root_dir, 'checkpoint_00000004')])
コード例 #4
0
 def get(self):
     """Returns a FederatedExperimentConfig built from the current flag values."""
     # Each config keyword shares its name with the flag that supplies it;
     # flags are read in this order.
     field_names = ('root_dir', 'num_rounds', 'checkpoint_frequency',
                    'num_checkpoints_to_keep', 'eval_frequency')
     field_values = {field: self._get_flag(field) for field in field_names}
     return federated_experiment.FederatedExperimentConfig(**field_values)
コード例 #5
0
 def test_get(self):
     """get() mirrors default flag values and honours flag overrides."""
     overrides = dict(
         root_dir='bar',
         num_rounds=567,
         checkpoint_frequency=2,
         num_checkpoints_to_keep=3,
         eval_frequency=4)
     with self.subTest('default'):
         want = federated_experiment.FederatedExperimentConfig(
             root_dir='foo', num_rounds=1234)
         self.assertEqual(self.FEDERATED_EXPERIMENT_CONFIG.get(), want)
     with self.subTest('custom'):
         # The same values set the flags and build the expected config.
         with flagsaver.flagsaver(**overrides):
             want = federated_experiment.FederatedExperimentConfig(**overrides)
             self.assertEqual(self.FEDERATED_EXPERIMENT_CONFIG.get(), want)
コード例 #6
0
 def test_no_eval(self):
   """Runs an experiment with no eval fn maps and checks the final state."""
   experiment_config = federated_experiment.FederatedExperimentConfig(
       root_dir=self.create_tempdir(), num_rounds=5)
   sampler = FakeClientSampler()
   algorithm = fake_algorithm()
   final_state = federated_experiment.run_federated_experiment(
       algorithm=algorithm,
       init_state=algorithm.init(),
       client_sampler=sampler,
       config=experiment_config)
   self.assertEqual(final_state, (5, 15))
コード例 #7
0
    def test_run_federated_experiment_final_eval_fn_map(self):
        """A final eval fn produces `<key>.tsv` under root_dir."""
        config = federated_experiment.FederatedExperimentConfig(
            root_dir=self.create_tempdir(),
            num_rounds=5,
            num_clients_per_round=2)
        mock_algorithm = test_util.MockFederatedAlgorithm(num_examples=3)
        full_eval = federated_experiment.FullEvaluationFn(
            mock_algorithm.federated_data, mock_algorithm.model)

        federated_experiment.run_federated_experiment(
            config,
            mock_algorithm,
            final_eval_fn_map=collections.OrderedDict(full_eval=full_eval))

        tsv_path = os.path.join(config.root_dir, 'full_eval.tsv')
        self.assertTrue(os.path.exists(tsv_path))
コード例 #8
0
 def test_final_eval_fn_map(self):
   """A final eval fn produces `<key>.tsv` under root_dir."""
   config = federated_experiment.FederatedExperimentConfig(
       root_dir=self.create_tempdir(), num_rounds=5, eval_frequency=3)
   sampler = FakeClientSampler()
   algorithm = fake_algorithm()
   # Keyed by round number; presumably the fake checks it only sees the
   # round-5 (final) state.
   final_eval = FakeEvaluationFn(self, {5: (5, 15)})
   final_state = federated_experiment.run_federated_experiment(
       algorithm=algorithm,
       init_state=algorithm.init(),
       client_sampler=sampler,
       config=config,
       final_eval_fn_map={'final_eval': final_eval})
   self.assertEqual(final_state, (5, 15))
   tsv_path = os.path.join(config.root_dir, 'final_eval.tsv')
   self.assertCountEqual(glob.glob(tsv_path), [tsv_path])
コード例 #9
0
 def __init__(self, name: Optional[str] = None):
     """Registers the flags that populate a FederatedExperimentConfig."""
     super().__init__(name)
     # Built only to read the config's default field values; root_dir and
     # num_rounds get placeholder values (presumably required fields).
     defaults = federated_experiment.FederatedExperimentConfig(
         root_dir='', num_rounds=-1)
     self._string('root_dir', None, 'Root directory of experiment outputs')
     # (flag name, default, help) for each integer flag, in registration order.
     integer_flags = (
         ('num_rounds', None, 'Number of federated training rounds'),
         ('checkpoint_frequency', defaults.checkpoint_frequency,
          'Checkpoint frequency in rounds. If <= 0, no checkpointing is done.'),
         ('num_checkpoints_to_keep', defaults.num_checkpoints_to_keep,
          'Maximum number of checkpoints to keep'),
         ('eval_frequency', defaults.eval_frequency,
          'Evaluation frequency in rounds. If <= 0, no evaluation is done.'),
     )
     for flag_name, flag_default, flag_help in integer_flags:
         self._integer(flag_name, flag_default, flag_help)
コード例 #10
0
    def test_run_federated_experiment(self):
        """Runs all rounds, checks the example count and the round-3 checkpoint."""
        examples_per_client = 3
        config = federated_experiment.FederatedExperimentConfig(
            root_dir=self.create_tempdir(),
            num_rounds=5,
            num_clients_per_round=2,
            checkpoint_frequency=3)
        mock_algorithm = test_util.MockFederatedAlgorithm(
            num_examples=examples_per_client)

        final_state = federated_experiment.run_federated_experiment(
            config, mock_algorithm)

        # Presumably the mock counts every example it processes across all
        # sampled clients in all rounds.
        expected_count = (
            config.num_rounds * config.num_clients_per_round *
            examples_per_client)
        self.assertEqual(final_state.count, expected_count)
        checkpoint_path = os.path.join(config.root_dir, 'checkpoint_00000003')
        self.assertTrue(os.path.exists(checkpoint_path))