def test_remove_unused_flags_with_empty_optimizer(self):
    """An empty-string `optimizer` flag value is rejected with a ValueError."""
    flags_with_blank_optimizer = collections.OrderedDict()
    flags_with_blank_optimizer['optimizer'] = ''
    expected_message = ('The flag optimizer was not set. '
                        'Unable to determine the relevant optimizer.')

    # No prefix is supplied, so the flag under inspection is plain `optimizer`.
    with self.assertRaisesRegex(ValueError, expected_message):
        optimizer_utils.remove_unused_flags(
            prefix=None, hparam_dict=flags_with_blank_optimizer)
def test_remove_flags_with_optimizers_sharing_a_prefix(self):
    """Flags of `adam` are dropped when the selected optimizer is `adamW`.

    The name `adam` is a string prefix of `adamW`; the test expects
    `client_adam_momentum` to be removed while `client_adamW_momentum` is kept.
    """
    all_flags = collections.OrderedDict()
    all_flags['client_optimizer'] = 'adamW'
    all_flags['client_adam_momentum'] = 0.3
    all_flags['client_adamW_momentum'] = 0.5

    kept_flags = optimizer_utils.remove_unused_flags('client', all_flags)

    expected_flags = {
        'client_optimizer': 'adamW',
        'client_adamW_momentum': 0.5,
    }
    self.assertCountEqual(kept_flags.keys(), list(expected_flags))
    for flag_name, flag_value in expected_flags.items():
        self.assertEqual(kept_flags[flag_name], flag_value)
def test_remove_unused_flags_without_prefix(self):
    """With `prefix=None`, unprefixed flags of the unselected optimizer go.

    `optimizer` is `'sgd'`, so the test expects `adam_momentum` to be removed
    and `optimizer`, `value` and `sgd_momentum` to be kept with their values.
    """
    all_flags = collections.OrderedDict()
    all_flags['optimizer'] = 'sgd'
    all_flags['value'] = 0.1
    all_flags['sgd_momentum'] = 0.3
    all_flags['adam_momentum'] = 0.5

    kept_flags = optimizer_utils.remove_unused_flags(
        prefix=None, hparam_dict=all_flags)

    expected_flags = {
        'optimizer': 'sgd',
        'value': 0.1,
        'sgd_momentum': 0.3,
    }
    self.assertCountEqual(kept_flags.keys(), list(expected_flags))
    for flag_name, flag_value in expected_flags.items():
        self.assertEqual(kept_flags[flag_name], flag_value)
def _write_hparam_flags():
    """Gathers hyperparameter flag values and writes them to `hparams.csv`.

    The values are assembled from three sources, later ones overriding earlier
    ones on key collisions: the shared flags, the optimizer flags (filtered
    through `optimizer_utils.remove_unused_flags` for the `client` and then
    the `server` prefix), and the flags registered in `TASK_FLAGS` for
    `FLAGS.task`, if any.

    The file is written to
    `<FLAGS.root_output_dir>/results/<FLAGS.experiment_name>/hparams.csv`.
    """
    flag_values = utils_impl.lookup_flag_values(shared_flags)

    # Keep only the optimizer flags relevant to the chosen optimizers.
    optimizer_values = utils_impl.lookup_flag_values(optimizer_flags)
    for prefix in ('client', 'server'):
        optimizer_values = optimizer_utils.remove_unused_flags(
            prefix, optimizer_values)
    flag_values.update(optimizer_values)

    # Layer the task-specific flags on top when the task declares any.
    task_name = FLAGS.task
    if task_name in TASK_FLAGS:
        flag_values.update(utils_impl.lookup_flag_values(TASK_FLAGS[task_name]))

    output_dir = os.path.join(
        FLAGS.root_output_dir, 'results', FLAGS.experiment_name)
    utils_impl.create_directory_if_not_exists(output_dir)
    utils_impl.atomic_write_series_to_csv(
        flag_values, os.path.join(output_dir, 'hparams.csv'))
def test_removal_with_standard_default_values(self):
    """Falsy flag values (0, None, False) are preserved for kept flags.

    With `client_optimizer` set to `'adam'`, the test expects only
    `client_sgd_momentum` to be removed. The remaining flags, including ones
    whose values are `0`, `None` and `False`, must be returned unchanged.
    """
    hparam_dict = collections.OrderedDict([
        ('client_optimizer', 'adam'),
        ('non_client_value', 0),
        ('client_sgd_momentum', 0),
        ('client_adam_param1', None),
        ('client_adam_param2', False),
    ])

    relevant_hparam_dict = optimizer_utils.remove_unused_flags(
        'client', hparam_dict)

    expected_flag_names = [
        'client_optimizer',
        'non_client_value',
        'client_adam_param1',
        'client_adam_param2',
    ]
    self.assertCountEqual(relevant_hparam_dict.keys(), expected_flag_names)
    self.assertEqual(relevant_hparam_dict['client_optimizer'], 'adam')
    self.assertEqual(relevant_hparam_dict['non_client_value'], 0)
    self.assertIsNone(relevant_hparam_dict['client_adam_param1'])
    # Identity check rather than equality: `0 == False` in Python, so
    # assertEqual could not tell a preserved `False` from a value that was
    # replaced by another falsy value such as 0.
    self.assertIs(relevant_hparam_dict['client_adam_param2'], False)
def test_remove_unused_flags_without_optimizer_flag(self):
    """A dict lacking the `client_optimizer` key is rejected with a ValueError.

    The input carries `client_opt_fn` instead of `client_optimizer`, so the
    flag naming the optimizer for the `client` prefix is absent.
    """
    flags_missing_optimizer = collections.OrderedDict()
    flags_missing_optimizer['client_opt_fn'] = 'sgd'
    flags_missing_optimizer['client_sgd_momentum'] = 0.3
    expected_message = 'The flag client_optimizer was not defined.'

    with self.assertRaisesRegex(ValueError, expected_message):
        optimizer_utils.remove_unused_flags('client', flags_missing_optimizer)