def test_remove_unused_flags_with_empty_optimizer(self):
  """An empty `optimizer` value makes `remove_unused_flags` raise."""
  flags_with_blank_optimizer = collections.OrderedDict()
  flags_with_blank_optimizer['optimizer'] = ''
  # The message is split across two literals purely for line length; the
  # concatenated pattern is what assertRaisesRegex searches for.
  expected_message = ('The flag optimizer was not set. '
                      'Unable to determine the relevant optimizer.')
  with self.assertRaisesRegex(ValueError, expected_message):
    optimizer_utils.remove_unused_flags(
        prefix=None, hparam_dict=flags_with_blank_optimizer)
def _get_hparam_flags():
  """Gathers the hyperparameter flags that are relevant to this run.

  Starts from the values of `shared_flags`, then merges in the values of
  `optimizer_flags` after `optimizer_utils.remove_unused_flags` has pruned
  them for the 'client' prefix and then for the 'server' prefix. Finally, if
  the selected `FLAGS.task` has an entry in `TASK_FLAGS`, the values of that
  task's flags are merged in as well.

  Returns:
    The dictionary produced by `utils_impl.lookup_flag_values(shared_flags)`
    (described as ordered by the original docstring), updated with the pruned
    optimizer flags and any task-specific flags.
  """
  hparam_dict = utils_impl.lookup_flag_values(shared_flags)

  # Prune optimizer flags that do not belong to the chosen client optimizer,
  # then those that do not belong to the chosen server optimizer.
  optimizer_hparams = utils_impl.lookup_flag_values(optimizer_flags)
  for optimizer_prefix in ('client', 'server'):
    optimizer_hparams = optimizer_utils.remove_unused_flags(
        optimizer_prefix, optimizer_hparams)
  hparam_dict.update(optimizer_hparams)

  # Only tasks registered in TASK_FLAGS contribute task-specific flags.
  selected_task = FLAGS.task
  if selected_task in TASK_FLAGS:
    hparam_dict.update(
        utils_impl.lookup_flag_values(TASK_FLAGS[selected_task]))
  return hparam_dict
def test_remove_unused_flags_without_prefix(self):
  """With no prefix, flags for the unselected optimizer are dropped."""
  hparam_dict = collections.OrderedDict([
      ('optimizer', 'sgd'),
      ('value', 0.1),
      ('sgd_momentum', 0.3),
      ('adam_momentum', 0.5),
  ])
  pruned = optimizer_utils.remove_unused_flags(
      prefix=None, hparam_dict=hparam_dict)

  # 'adam_momentum' must be gone; everything else keeps its value.
  expected = collections.OrderedDict([
      ('optimizer', 'sgd'),
      ('value', 0.1),
      ('sgd_momentum', 0.3),
  ])
  self.assertCountEqual(pruned.keys(), list(expected))
  for flag_name, flag_value in expected.items():
    self.assertEqual(pruned[flag_name], flag_value)
def test_remove_unused_flags_with_prefix(self):
  """With a prefix, only that prefix's unused optimizer flags are dropped."""
  hparam_dict = collections.OrderedDict([
      ('client_optimizer', 'sgd'),
      ('non_client_value', 0.1),
      ('client_sgd_momentum', 0.3),
      ('client_adam_momentum', 0.5),
  ])
  pruned = optimizer_utils.remove_unused_flags('client', hparam_dict)

  # 'client_adam_momentum' must be gone; flags without the prefix survive.
  expected = collections.OrderedDict([
      ('client_optimizer', 'sgd'),
      ('non_client_value', 0.1),
      ('client_sgd_momentum', 0.3),
  ])
  self.assertCountEqual(pruned.keys(), list(expected))
  for flag_name, flag_value in expected.items():
    self.assertEqual(pruned[flag_name], flag_value)
def test_removal_with_standard_default_values(self):
  """Flags holding falsy defaults (0, None, False) are pruned only by name.

  The selected client optimizer is 'adam', so only 'client_sgd_momentum' is
  expected to be removed. The remaining flags, including those whose values
  are 0, None or False, must be kept with their original values.
  """
  hparam_dict = collections.OrderedDict([('client_optimizer', 'adam'),
                                         ('non_client_value', 0),
                                         ('client_sgd_momentum', 0),
                                         ('client_adam_param1', None),
                                         ('client_adam_param2', False)])
  relevant_hparam_dict = optimizer_utils.remove_unused_flags(
      'client', hparam_dict)
  expected_flag_names = [
      'client_optimizer', 'non_client_value', 'client_adam_param1',
      'client_adam_param2'
  ]
  self.assertCountEqual(relevant_hparam_dict.keys(), expected_flag_names)
  self.assertEqual(relevant_hparam_dict['client_optimizer'], 'adam')
  self.assertEqual(relevant_hparam_dict['non_client_value'], 0)
  self.assertIsNone(relevant_hparam_dict['client_adam_param1'])
  # assertEqual(x, False) would also accept 0 (since 0 == False); assertIs
  # pins the exact False singleton that was passed in.
  self.assertIs(relevant_hparam_dict['client_adam_param2'], False)
def test_remove_unused_flags_without_optimizer_flag(self):
  """A missing `client_optimizer` key makes pruning raise ValueError."""
  hparam_dict = collections.OrderedDict()
  # 'client_opt_fn' is deliberately not the expected 'client_optimizer' key.
  hparam_dict['client_opt_fn'] = 'sgd'
  hparam_dict['client_sgd_momentum'] = 0.3
  missing_flag_pattern = 'The flag client_optimizer was not defined.'
  with self.assertRaisesRegex(ValueError, missing_flag_pattern):
    optimizer_utils.remove_unused_flags('client', hparam_dict)