Ejemplo n.º 1
0
def _write_hparam_flags():
  """Collects the relevant hyperparameter flags and saves them as a CSV file.

  The shared flags are merged with the optimizer flags. Before merging, the
  optimizer flags are pruned with `optimizer_utils.remove_unused_flags`, once
  for the 'client' prefix and once for the 'server' prefix. The merged mapping
  is written with `utils_impl.atomic_write_series_to_csv` to
  `<root_output_dir>/results/<experiment_name>/hparams.csv`. The directory is
  created first if it does not exist.

  Returns nothing; the result is only persisted to disk.
  """
  collected = utils_impl.lookup_flag_values(shared_flags)

  # Keep only optimizer flags belonging to the chosen client/server optimizers.
  optimizer_hparams = utils_impl.lookup_flag_values(optimizer_flags)
  for role in ('client', 'server'):
    optimizer_hparams = optimizer_utils.remove_unused_flags(
        role, optimizer_hparams)
  collected.update(optimizer_hparams)

  output_dir = os.path.join(
      FLAGS.root_output_dir, 'results', FLAGS.experiment_name)
  utils_impl.create_directory_if_not_exists(output_dir)
  csv_path = os.path.join(output_dir, 'hparams.csv')
  utils_impl.atomic_write_series_to_csv(collected, csv_path)
Ejemplo n.º 2
0
def get_hparam_flags():
    """Builds the mapping of hyperparameter flags relevant to this run.

    Entries are merged in this order:
      1. the shared flags;
      2. the optimizer flags, pruned with `optimizer_utils.remove_unused_flags`
         for the 'client' and then the 'server' prefix;
      3. the task-specific `gld_flags`.

    Returns:
      The mapping produced by `utils_impl.lookup_flag_values(shared_flags)`,
      updated in place with the other groups.
    """
    merged = utils_impl.lookup_flag_values(shared_flags)

    pruned_optimizer_hparams = utils_impl.lookup_flag_values(optimizer_flags)
    for prefix in ('client', 'server'):
        # Drop flags for optimizers not selected under this prefix.
        pruned_optimizer_hparams = optimizer_utils.remove_unused_flags(
            prefix, pruned_optimizer_hparams)
    merged.update(pruned_optimizer_hparams)

    merged.update(utils_impl.lookup_flag_values(gld_flags))
    return merged
Ejemplo n.º 3
0
    def test_remove_unused_flags_with_empty_optimizer(self):
        """An empty 'optimizer' value makes remove_unused_flags raise."""
        flags_with_blank_optimizer = collections.OrderedDict(
            [('optimizer', '')])
        expected_message = ('The flag optimizer was not set. '
                            'Unable to determine the relevant optimizer.')

        with self.assertRaisesRegex(ValueError, expected_message):
            optimizer_utils.remove_unused_flags(
                prefix=None, hparam_dict=flags_with_blank_optimizer)
Ejemplo n.º 4
0
def _write_hparam_flags():
    """Collects hyperparameter flags into one mapping and writes it to CSV.

    Groups are merged in this order:
      1. shared flags;
      2. optimizer flags, pruned with `optimizer_utils.remove_unused_flags`
         for the 'client' and 'server' prefixes. A third prefix is also
         pruned: 'finetune' when the task is 'stackoverflow_nwp_finetune',
         'reconstruction' otherwise;
      3. task-specific flags, when `FLAGS.task` is a key of `TASK_FLAGS`;
      4. finetune flags, only for 'stackoverflow_nwp_finetune';
      5. reconstruction flags, DP flags, then run flags.

    The result is written with `utils_impl.atomic_write_series_to_csv` to
    `<root_output_dir>/results/<experiment_name>/hparams.csv`. The directory
    is created first if needed.
    """
    task_name = FLAGS.task
    is_finetune_task = task_name == 'stackoverflow_nwp_finetune'

    hparam_dict = utils_impl.lookup_flag_values(shared_flags)

    # Keep only the optimizer flags relevant to the chosen optimizers.
    # NOTE(review): only one of 'finetune'/'reconstruction' is pruned, so the
    # other prefix's optimizer flags (if any) are kept unpruned. Confirm this
    # is intended.
    optimizer_prefixes = (
        'client',
        'server',
        'finetune' if is_finetune_task else 'reconstruction',
    )
    opt_flag_dict = utils_impl.lookup_flag_values(optimizer_flags)
    for prefix in optimizer_prefixes:
        opt_flag_dict = optimizer_utils.remove_unused_flags(
            prefix, opt_flag_dict)
    hparam_dict.update(opt_flag_dict)

    # Remaining flag groups, listed in the order they must be merged.
    extra_flag_groups = []
    if task_name in TASK_FLAGS:
        extra_flag_groups.append(TASK_FLAGS[task_name])
    if is_finetune_task:
        extra_flag_groups.append(finetune_flags)
    extra_flag_groups.extend([recon_flags, dp_flags, run_flags])
    for flag_group in extra_flag_groups:
        hparam_dict.update(utils_impl.lookup_flag_values(flag_group))

    results_dir = os.path.join(
        FLAGS.root_output_dir, 'results', FLAGS.experiment_name)
    utils_impl.create_directory_if_not_exists(results_dir)
    utils_impl.atomic_write_series_to_csv(
        hparam_dict, os.path.join(results_dir, 'hparams.csv'))
Ejemplo n.º 5
0
 def test_remove_unused_flags_without_prefix(self):
     """With prefix=None, flags for the unselected optimizer are dropped."""
     all_flags = collections.OrderedDict([
         ('optimizer', 'sgd'),
         ('value', 0.1),
         ('sgd_momentum', 0.3),
         ('adam_momentum', 0.5),
     ])
     pruned = optimizer_utils.remove_unused_flags(
         prefix=None, hparam_dict=all_flags)

     expected = collections.OrderedDict([
         ('optimizer', 'sgd'),
         ('value', 0.1),
         ('sgd_momentum', 0.3),
     ])
     self.assertCountEqual(pruned.keys(), list(expected))
     for flag_name, flag_value in expected.items():
         self.assertEqual(pruned[flag_name], flag_value)
Ejemplo n.º 6
0
    def test_remove_unused_flags_with_prefix(self):
        """Under the 'client' prefix, only the selected optimizer's flags stay.

        The unrelated 'non_client_value' entry is kept unchanged, while
        'client_adam_momentum' (the unselected optimizer) is removed.
        """
        client_flags = collections.OrderedDict([
            ('client_optimizer', 'sgd'),
            ('non_client_value', 0.1),
            ('client_sgd_momentum', 0.3),
            ('client_adam_momentum', 0.5),
        ])
        kept = optimizer_utils.remove_unused_flags('client', client_flags)

        expected = collections.OrderedDict([
            ('client_optimizer', 'sgd'),
            ('non_client_value', 0.1),
            ('client_sgd_momentum', 0.3),
        ])
        self.assertCountEqual(kept.keys(), list(expected))
        for flag_name, flag_value in expected.items():
            self.assertEqual(kept[flag_name], flag_value)
Ejemplo n.º 7
0
    def test_removal_with_standard_default_values(self):
        """Falsy values (0, None, False) do not change which flags are kept.

        With 'adam' selected, 'client_sgd_momentum' is removed even though its
        value is 0. The adam flags holding None and False are retained.
        """
        client_flags = collections.OrderedDict([
            ('client_optimizer', 'adam'),
            ('non_client_value', 0),
            ('client_sgd_momentum', 0),
            ('client_adam_param1', None),
            ('client_adam_param2', False),
        ])
        kept = optimizer_utils.remove_unused_flags('client', client_flags)

        self.assertCountEqual(kept.keys(), [
            'client_optimizer',
            'non_client_value',
            'client_adam_param1',
            'client_adam_param2',
        ])
        self.assertEqual(kept['client_optimizer'], 'adam')
        self.assertEqual(kept['non_client_value'], 0)
        self.assertIsNone(kept['client_adam_param1'])
        self.assertEqual(kept['client_adam_param2'], False)
Ejemplo n.º 8
0
def main(argv):
    """Runs `centralized_main.run_centralized` configured from flag values.

    Args:
      argv: Command-line arguments left over after flag parsing. Only the
        program name is allowed.

    Raises:
      app.UsageError: If any extra positional arguments are given.
    """
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')

    # Keep only the hyperparameters of the selected 'centralized' optimizer.
    raw_hparams = utils_impl.lookup_flag_values(hparam_flags)
    hparams_dict = optimizer_utils.remove_unused_flags('centralized',
                                                       raw_hparams)

    optimizer = optimizer_utils.create_optimizer_fn_from_flags('centralized')()
    architecture_kwargs = dict(
        vocab_size=FLAGS.vocab_size,
        d_embed=FLAGS.d_embed,
        d_model=FLAGS.d_model,
        d_hidden=FLAGS.d_hidden,
        num_heads=FLAGS.num_heads,
        num_layers=FLAGS.num_layers,
        dropout=FLAGS.dropout,
    )

    centralized_main.run_centralized(
        optimizer,
        FLAGS.num_epochs,
        FLAGS.batch_size,
        experiment_name=FLAGS.experiment_name,
        root_output_dir=FLAGS.root_output_dir,
        max_batches=FLAGS.max_batches,
        hparams_dict=hparams_dict,
        **architecture_kwargs)
Ejemplo n.º 9
0
def main(argv):
    """Runs `centralized_main.run_centralized` on the GLD dataset from flags.

    `FLAGS.dataset_type == 'gld160k'` selects `DatasetType.GLD160K`. Any other
    value selects `DatasetType.GLD23K`.

    Args:
      argv: Command-line arguments left over after flag parsing. Only the
        program name is allowed.

    Raises:
      app.UsageError: If any extra positional arguments are given.
    """
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')

    # Keep only the hyperparameters of the selected 'centralized' optimizer.
    raw_hparams = utils_impl.lookup_flag_values(hparam_flags)
    hparams_dict = optimizer_utils.remove_unused_flags('centralized',
                                                       raw_hparams)

    dataset_type = (dataset.DatasetType.GLD160K
                    if FLAGS.dataset_type == 'gld160k'
                    else dataset.DatasetType.GLD23K)

    optimizer = optimizer_utils.create_optimizer_fn_from_flags('centralized')()
    centralized_main.run_centralized(
        optimizer=optimizer,
        image_size=FLAGS.image_size,
        num_epochs=FLAGS.num_epochs,
        batch_size=FLAGS.batch_size,
        num_groups=FLAGS.num_groups,
        dataset_type=dataset_type,
        experiment_name=FLAGS.experiment_name,
        root_output_dir=FLAGS.root_output_dir,
        dropout_prob=FLAGS.dropout_prob,
        hparams_dict=hparams_dict)
Ejemplo n.º 10
0
 def test_remove_unused_flags_without_optimizer_flag(self):
     """A missing 'client_optimizer' key makes remove_unused_flags raise."""
     flags_missing_optimizer = collections.OrderedDict([
         ('client_opt_fn', 'sgd'),
         ('client_sgd_momentum', 0.3),
     ])
     expected_message = 'The flag client_optimizer was not defined.'
     with self.assertRaisesRegex(ValueError, expected_message):
         optimizer_utils.remove_unused_flags('client', flags_missing_optimizer)