def _write_hparam_flags():
  """Creates an ordered dictionary of hyperparameter flags and writes to CSV.

  The resulting hyperparameters are written to
  `<root_output_dir>/results/<experiment_name>/hparams.csv`. Returns nothing.
  """
  # NOTE(review): original docstring claimed this *returns* the dictionary;
  # the function actually writes it to disk and returns None (its sibling in
  # the reconstruction experiment documents this correctly).
  hparam_dict = utils_impl.lookup_flag_values(shared_flags)

  # Update with optimizer flags corresponding to the chosen optimizers.
  opt_flag_dict = utils_impl.lookup_flag_values(optimizer_flags)
  opt_flag_dict = optimizer_utils.remove_unused_flags('client', opt_flag_dict)
  opt_flag_dict = optimizer_utils.remove_unused_flags('server', opt_flag_dict)
  hparam_dict.update(opt_flag_dict)

  results_dir = os.path.join(FLAGS.root_output_dir, 'results',
                             FLAGS.experiment_name)
  utils_impl.create_directory_if_not_exists(results_dir)
  hparam_file = os.path.join(results_dir, 'hparams.csv')
  utils_impl.atomic_write_series_to_csv(hparam_dict, hparam_file)
def get_hparam_flags():
  """Returns an ordered dictionary of pertinent hyperparameter flags."""
  hparams = utils_impl.lookup_flag_values(shared_flags)

  # Merge in only the optimizer flags that match the selected client and
  # server optimizers, discarding flags for optimizers that were not chosen.
  optimizer_hparams = utils_impl.lookup_flag_values(optimizer_flags)
  for prefix in ('client', 'server'):
    optimizer_hparams = optimizer_utils.remove_unused_flags(
        prefix, optimizer_hparams)
  hparams.update(optimizer_hparams)

  # Merge in the task-specific (Google Landmark) flags.
  hparams.update(utils_impl.lookup_flag_values(gld_flags))
  return hparams
def test_remove_unused_flags_with_empty_optimizer(self):
  # An empty-string optimizer value must be rejected with a clear error.
  hparam_dict = collections.OrderedDict([('optimizer', '')])
  expected_regex = ('The flag optimizer was not set. '
                    'Unable to determine the relevant optimizer.')
  with self.assertRaisesRegex(ValueError, expected_regex):
    _ = optimizer_utils.remove_unused_flags(
        prefix=None, hparam_dict=hparam_dict)
def _write_hparam_flags():
  """Creates an ordered dictionary of hyperparameter flags and writes to CSV."""
  hparam_dict = utils_impl.lookup_flag_values(shared_flags)

  # Keep only the optimizer flags that correspond to the chosen optimizers.
  opt_flag_dict = utils_impl.lookup_flag_values(optimizer_flags)
  opt_flag_dict = optimizer_utils.remove_unused_flags('client', opt_flag_dict)
  opt_flag_dict = optimizer_utils.remove_unused_flags('server', opt_flag_dict)
  # The finetune task uses a 'finetune' optimizer; all others use a
  # 'reconstruction' optimizer — prune whichever prefix applies.
  is_finetune_task = FLAGS.task == 'stackoverflow_nwp_finetune'
  extra_prefix = 'finetune' if is_finetune_task else 'reconstruction'
  opt_flag_dict = optimizer_utils.remove_unused_flags(
      extra_prefix, opt_flag_dict)
  hparam_dict.update(opt_flag_dict)

  # Update with task-specific flags, when the task defines any.
  if FLAGS.task in TASK_FLAGS:
    hparam_dict.update(utils_impl.lookup_flag_values(TASK_FLAGS[FLAGS.task]))

  # Finetuning flags only apply to the finetune task.
  if is_finetune_task:
    hparam_dict.update(utils_impl.lookup_flag_values(finetune_flags))

  # Reconstruction, differential-privacy, and run flags always apply.
  hparam_dict.update(utils_impl.lookup_flag_values(recon_flags))
  hparam_dict.update(utils_impl.lookup_flag_values(dp_flags))
  hparam_dict.update(utils_impl.lookup_flag_values(run_flags))

  results_dir = os.path.join(FLAGS.root_output_dir, 'results',
                             FLAGS.experiment_name)
  utils_impl.create_directory_if_not_exists(results_dir)
  hparam_file = os.path.join(results_dir, 'hparams.csv')
  utils_impl.atomic_write_series_to_csv(hparam_dict, hparam_file)
def test_remove_unused_flags_without_prefix(self):
  # With no prefix, only flags for the selected optimizer ('sgd') survive;
  # flags for other optimizers ('adam_momentum') are dropped.
  hparam_dict = collections.OrderedDict([
      ('optimizer', 'sgd'),
      ('value', 0.1),
      ('sgd_momentum', 0.3),
      ('adam_momentum', 0.5),
  ])
  result = optimizer_utils.remove_unused_flags(
      prefix=None, hparam_dict=hparam_dict)
  self.assertCountEqual(result.keys(), ['optimizer', 'value', 'sgd_momentum'])
  self.assertEqual(result['optimizer'], 'sgd')
  self.assertEqual(result['value'], 0.1)
  self.assertEqual(result['sgd_momentum'], 0.3)
def test_remove_unused_flags_with_prefix(self):
  # Prefixed pruning: 'client_adam_*' flags are removed because the client
  # optimizer is 'sgd'; non-prefixed flags pass through untouched.
  hparam_dict = collections.OrderedDict([
      ('client_optimizer', 'sgd'),
      ('non_client_value', 0.1),
      ('client_sgd_momentum', 0.3),
      ('client_adam_momentum', 0.5),
  ])
  result = optimizer_utils.remove_unused_flags('client', hparam_dict)
  self.assertCountEqual(
      result.keys(),
      ['client_optimizer', 'non_client_value', 'client_sgd_momentum'])
  self.assertEqual(result['client_optimizer'], 'sgd')
  self.assertEqual(result['non_client_value'], 0.1)
  self.assertEqual(result['client_sgd_momentum'], 0.3)
def test_removal_with_standard_default_values(self):
  # Flags for the selected optimizer ('adam') survive even when they hold
  # falsy defaults (None, False); unselected-optimizer flags are removed.
  hparam_dict = collections.OrderedDict([
      ('client_optimizer', 'adam'),
      ('non_client_value', 0),
      ('client_sgd_momentum', 0),
      ('client_adam_param1', None),
      ('client_adam_param2', False),
  ])
  result = optimizer_utils.remove_unused_flags('client', hparam_dict)
  self.assertCountEqual(
      result.keys(),
      ['client_optimizer', 'non_client_value', 'client_adam_param1',
       'client_adam_param2'])
  self.assertEqual(result['client_optimizer'], 'adam')
  self.assertEqual(result['non_client_value'], 0)
  self.assertIsNone(result['client_adam_param1'])
  self.assertEqual(result['client_adam_param2'], False)
def main(argv):
  """Runs centralized training configured entirely from command-line flags."""
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')

  # Record only the hparams relevant to the chosen centralized optimizer.
  hparams_dict = utils_impl.lookup_flag_values(hparam_flags)
  hparams_dict = optimizer_utils.remove_unused_flags('centralized',
                                                     hparams_dict)

  # Instantiate the optimizer from flags before handing it to the runner.
  optimizer = optimizer_utils.create_optimizer_fn_from_flags('centralized')()
  centralized_main.run_centralized(
      optimizer,
      FLAGS.num_epochs,
      FLAGS.batch_size,
      vocab_size=FLAGS.vocab_size,
      d_embed=FLAGS.d_embed,
      d_model=FLAGS.d_model,
      d_hidden=FLAGS.d_hidden,
      num_heads=FLAGS.num_heads,
      num_layers=FLAGS.num_layers,
      dropout=FLAGS.dropout,
      experiment_name=FLAGS.experiment_name,
      root_output_dir=FLAGS.root_output_dir,
      max_batches=FLAGS.max_batches,
      hparams_dict=hparams_dict)
def main(argv):
  """Runs centralized GLD training configured from command-line flags."""
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')

  # Record only the hparams relevant to the chosen centralized optimizer.
  hparams_dict = utils_impl.lookup_flag_values(hparam_flags)
  hparams_dict = optimizer_utils.remove_unused_flags('centralized',
                                                     hparams_dict)

  # Map the string flag onto the dataset enum; GLD23K is the default.
  if FLAGS.dataset_type == 'gld160k':
    dataset_type = dataset.DatasetType.GLD160K
  else:
    dataset_type = dataset.DatasetType.GLD23K

  centralized_main.run_centralized(
      optimizer=optimizer_utils.create_optimizer_fn_from_flags(
          'centralized')(),
      image_size=FLAGS.image_size,
      num_epochs=FLAGS.num_epochs,
      batch_size=FLAGS.batch_size,
      num_groups=FLAGS.num_groups,
      dataset_type=dataset_type,
      experiment_name=FLAGS.experiment_name,
      root_output_dir=FLAGS.root_output_dir,
      dropout_prob=FLAGS.dropout_prob,
      hparams_dict=hparams_dict)
def test_remove_unused_flags_without_optimizer_flag(self):
  # A prefixed dict lacking the '<prefix>_optimizer' key must raise.
  hparam_dict = collections.OrderedDict([
      ('client_opt_fn', 'sgd'),
      ('client_sgd_momentum', 0.3),
  ])
  with self.assertRaisesRegex(
      ValueError, 'The flag client_optimizer was not defined.'):
    _ = optimizer_utils.remove_unused_flags('client', hparam_dict)