Пример #1
0
  def test_remove_unused_flags_with_empty_optimizer(self):
    """An empty `optimizer` value must be rejected with a ValueError."""
    hparam_dict = collections.OrderedDict(optimizer='')
    # Passed to assertRaisesRegex, so each '.' matches any character.
    expected_message = ('The flag optimizer was not set. '
                        'Unable to determine the relevant optimizer.')

    with self.assertRaisesRegex(ValueError, expected_message):
      optimizer_utils.remove_unused_flags(
          prefix=None, hparam_dict=hparam_dict)
Пример #2
0
def _get_hparam_flags():
  """Returns an ordered dictionary of pertinent hyperparameter flags.

  The result is built in three steps:
    1. Start from the current values of the shared flags.
    2. Merge in the optimizer flags, after pruning them with
       `optimizer_utils.remove_unused_flags` for the `client` prefix and
       then for the `server` prefix.
    3. If `FLAGS.task` has an entry in `TASK_FLAGS`, merge in the values of
       that task's flags.

  Later steps overwrite earlier ones on key collisions (plain `dict.update`).
  """
  hparams = utils_impl.lookup_flag_values(shared_flags)

  optimizer_hparams = utils_impl.lookup_flag_values(optimizer_flags)
  for role in ('client', 'server'):
    optimizer_hparams = optimizer_utils.remove_unused_flags(
        role, optimizer_hparams)
  hparams.update(optimizer_hparams)

  selected_task = FLAGS.task
  if selected_task in TASK_FLAGS:
    hparams.update(utils_impl.lookup_flag_values(TASK_FLAGS[selected_task]))

  return hparams
Пример #3
0
 def test_remove_unused_flags_without_prefix(self):
   """With no prefix, flags of the unselected optimizer (adam) are dropped."""
   hparam_dict = collections.OrderedDict([
       ('optimizer', 'sgd'),
       ('value', 0.1),
       ('sgd_momentum', 0.3),
       ('adam_momentum', 0.5),
   ])
   # Flags that must survive, with the values they must keep.
   expected = collections.OrderedDict([
       ('optimizer', 'sgd'),
       ('value', 0.1),
       ('sgd_momentum', 0.3),
   ])

   relevant_hparam_dict = optimizer_utils.remove_unused_flags(
       prefix=None, hparam_dict=hparam_dict)

   self.assertCountEqual(relevant_hparam_dict.keys(), list(expected))
   for flag_name, flag_value in expected.items():
     self.assertEqual(relevant_hparam_dict[flag_name], flag_value)
Пример #4
0
  def test_remove_unused_flags_with_prefix(self):
    """With prefix 'client', only unused client optimizer flags are dropped.

    `non_client_value` does not start with the `client_` prefix, so it is
    expected to pass through; `client_adam_momentum` belongs to an
    optimizer other than the selected `sgd` and is expected to be removed.
    """
    hparam_dict = collections.OrderedDict([
        ('client_optimizer', 'sgd'),
        ('non_client_value', 0.1),
        ('client_sgd_momentum', 0.3),
        ('client_adam_momentum', 0.5),
    ])
    expected = collections.OrderedDict([
        ('client_optimizer', 'sgd'),
        ('non_client_value', 0.1),
        ('client_sgd_momentum', 0.3),
    ])

    relevant_hparam_dict = optimizer_utils.remove_unused_flags(
        'client', hparam_dict)

    self.assertCountEqual(relevant_hparam_dict.keys(), list(expected))
    for flag_name, flag_value in expected.items():
      self.assertEqual(relevant_hparam_dict[flag_name], flag_value)
Пример #5
0
  def test_removal_with_standard_default_values(self):
    """Falsy flag values are kept or dropped by optimizer, not by value.

    The client optimizer is 'adam', so `client_sgd_momentum` is expected to
    be removed, while the Adam flags holding `None` and `False` must be kept
    with exactly those values.
    """
    hparam_dict = collections.OrderedDict([('client_optimizer', 'adam'),
                                           ('non_client_value', 0),
                                           ('client_sgd_momentum', 0),
                                           ('client_adam_param1', None),
                                           ('client_adam_param2', False)])

    relevant_hparam_dict = optimizer_utils.remove_unused_flags(
        'client', hparam_dict)
    expected_flag_names = [
        'client_optimizer', 'non_client_value', 'client_adam_param1',
        'client_adam_param2'
    ]
    self.assertCountEqual(relevant_hparam_dict.keys(), expected_flag_names)
    self.assertEqual(relevant_hparam_dict['client_optimizer'], 'adam')
    self.assertEqual(relevant_hparam_dict['non_client_value'], 0)
    self.assertIsNone(relevant_hparam_dict['client_adam_param1'])
    # assertEqual(x, False) would also pass for 0, because 0 == False.
    # Pin the exact bool object so a False -> 0 coercion is caught.
    self.assertIs(relevant_hparam_dict['client_adam_param2'], False)
Пример #6
0
 def test_remove_unused_flags_without_optimizer_flag(self):
   """A missing `<prefix>_optimizer` flag must raise a ValueError."""
   # `client_opt_fn` deliberately does not match the `client_optimizer`
   # flag name that remove_unused_flags looks for with prefix 'client'.
   hparam_dict = collections.OrderedDict(
       [('client_opt_fn', 'sgd'), ('client_sgd_momentum', 0.3)])
   expected_error = 'The flag client_optimizer was not defined.'

   with self.assertRaisesRegex(ValueError, expected_error):
     optimizer_utils.remove_unused_flags('client', hparam_dict)