def test_remove_unused_flags_with_empty_optimizer(self):
    """An empty 'optimizer' flag value must make remove_unused_flags raise.

    The expected message is a regular expression for `assertRaisesRegex`, so
    its literal periods are escaped; unescaped, each '.' would match any
    character and the assertion would accept slightly different messages.
    """
    hparam_dict = collections.OrderedDict([('optimizer', '')])

    with self.assertRaisesRegex(
            ValueError, r'The flag optimizer was not set\. '
            r'Unable to determine the relevant optimizer\.'):
        _ = optimizer_utils.remove_unused_flags(prefix=None,
                                                hparam_dict=hparam_dict)
 def test_remove_flags_with_optimizers_sharing_a_prefix(self):
     """Selecting 'adamW' keeps 'client_adamW_*' flags but drops 'client_adam_*'.

     'adam' is a prefix of 'adamW', so this guards against flags of the
     unselected optimizer surviving because their names look similar.
     """
     flags = collections.OrderedDict()
     flags['client_optimizer'] = 'adamW'
     flags['client_adam_momentum'] = 0.3
     flags['client_adamW_momentum'] = 0.5

     kept = optimizer_utils.remove_unused_flags('client', flags)

     expected = {'client_optimizer': 'adamW', 'client_adamW_momentum': 0.5}
     self.assertCountEqual(kept.keys(), list(expected))
     for name, value in expected.items():
         self.assertEqual(kept[name], value)
 def test_remove_unused_flags_without_prefix(self):
     """With prefix=None and 'sgd' selected, only 'adam_momentum' is pruned.

     The unrelated 'value' flag and the selected optimizer's 'sgd_momentum'
     flag must both be kept, along with the 'optimizer' flag itself.
     """
     flags = collections.OrderedDict(
         [('optimizer', 'sgd'), ('value', 0.1), ('sgd_momentum', 0.3),
          ('adam_momentum', 0.5)])

     kept = optimizer_utils.remove_unused_flags(prefix=None, hparam_dict=flags)

     expected = {'optimizer': 'sgd', 'value': 0.1, 'sgd_momentum': 0.3}
     self.assertCountEqual(kept.keys(), list(expected))
     for name, value in expected.items():
         self.assertEqual(kept[name], value)
示例#4
0
def _write_hparam_flags():
    """Gather the hyperparameter flags relevant to this run and save them to CSV.

    The written mapping combines, in this order:
      1. the values of `shared_flags`;
      2. the values of `optimizer_flags`, after
         `optimizer_utils.remove_unused_flags` has pruned them for the
         'client' prefix and then the 'server' prefix;
      3. the flags registered in `TASK_FLAGS` for `FLAGS.task`, if that task
         has an entry there.
    Later groups overwrite earlier ones on key collisions (plain `update`).
    The result goes to `<root_output_dir>/results/<experiment_name>/hparams.csv`,
    creating the results directory first if needed.
    """
    hparams = utils_impl.lookup_flag_values(shared_flags)

    # Keep only the optimizer flags that belong to the optimizer chosen for
    # each side; the client pass runs before the server pass.
    optimizer_hparams = utils_impl.lookup_flag_values(optimizer_flags)
    for side in ('client', 'server'):
        optimizer_hparams = optimizer_utils.remove_unused_flags(
            side, optimizer_hparams)
    hparams.update(optimizer_hparams)

    # Tasks without an entry in TASK_FLAGS contribute no extra flags.
    task = FLAGS.task
    if task in TASK_FLAGS:
        hparams.update(utils_impl.lookup_flag_values(TASK_FLAGS[task]))

    output_dir = os.path.join(FLAGS.root_output_dir, 'results',
                              FLAGS.experiment_name)
    utils_impl.create_directory_if_not_exists(output_dir)
    csv_path = os.path.join(output_dir, 'hparams.csv')
    utils_impl.atomic_write_series_to_csv(hparams, csv_path)
    def test_removal_with_standard_default_values(self):
        """Falsy flag values (0, None, False) are filtered by relevance only.

        With 'adam' selected for the client, 'client_sgd_momentum' (value 0)
        is dropped, while 'non_client_value' (0) and the 'client_adam_*'
        flags holding None and False are kept with their values intact.
        """
        flags = collections.OrderedDict([
            ('client_optimizer', 'adam'),
            ('non_client_value', 0),
            ('client_sgd_momentum', 0),
            ('client_adam_param1', None),
            ('client_adam_param2', False),
        ])

        kept = optimizer_utils.remove_unused_flags('client', flags)

        self.assertCountEqual(
            kept.keys(),
            ['client_optimizer', 'non_client_value', 'client_adam_param1',
             'client_adam_param2'])
        self.assertEqual(kept['client_optimizer'], 'adam')
        self.assertEqual(kept['non_client_value'], 0)
        self.assertIsNone(kept['client_adam_param1'])
        self.assertEqual(kept['client_adam_param2'], False)
 def test_remove_unused_flags_without_optimizer_flag(self):
     """A missing 'client_optimizer' key must make remove_unused_flags raise.

     The dict has only 'client_opt_fn' and 'client_sgd_momentum', so there is
     no 'client_optimizer' flag to select an optimizer from. The trailing
     period in the expected message is escaped because `assertRaisesRegex`
     treats the message as a regular expression.
     """
     hparam_dict = collections.OrderedDict([('client_opt_fn', 'sgd'),
                                            ('client_sgd_momentum', 0.3)])
     with self.assertRaisesRegex(
             ValueError, r'The flag client_optimizer was not defined\.'):
         _ = optimizer_utils.remove_unused_flags('client', hparam_dict)