def test_override_params_dict_using_csv_string(self):
     params = params_dict.ParamsDict({
         'a': 1,
         'b': {
             'b1': 2,
             'b2': [2, 3],
         },
         'd': {
             'd1': {
                 'd2': 'hello'
             }
         },
         'e': False
     })
     override_csv_string = "b.b2=[3,4], d.d1.d2='hi, world', e=gs://test"
     params = params_dict.override_params_dict(params,
                                               override_csv_string,
                                               is_strict=True)
     self.assertEqual(1, params.a)
     self.assertEqual(2, params.b.b1)
     self.assertEqual([3, 4], params.b.b2)
     self.assertEqual('hi, world', params.d.d1.d2)
     self.assertEqual('gs://test', params.e)
     # Test different float formats
     override_csv_string = 'b.b2=-1.e-3, d.d1.d2=+0.001, e=1e+3, a=-1.5E-3'
     params = params_dict.override_params_dict(params,
                                               override_csv_string,
                                               is_strict=True)
     self.assertEqual(-1e-3, params.b.b2)
     self.assertEqual(0.001, params.d.d1.d2)
     self.assertEqual(1e3, params.e)
     self.assertEqual(-1.5e-3, params.a)
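The test above exercises the CSV-string form of the override. As a rough sketch based on these examples (the module path and the YAML file name below are assumptions), the same override_params_dict entry point also accepts a plain dict, a YAML- or JSON-style string, or a path to a YAML file:

# Minimal sketch; '/tmp/override.yaml' is a placeholder path.
from official.modeling.hyperparams import params_dict

params = params_dict.ParamsDict({'a': 1, 'b': {'b1': 2}})
params = params_dict.override_params_dict(params, {'a': 2}, is_strict=True)        # dict
params = params_dict.override_params_dict(params, 'b.b1=5', is_strict=True)        # CSV string
params = params_dict.override_params_dict(params, 'a: 3', is_strict=True)          # YAML string
params = params_dict.override_params_dict(params, '/tmp/override.yaml', is_strict=True)  # YAML file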
Example #2
def _override_exp_config_by_file(exp_config, exp_config_files):
  """Overrides an `ExperimentConfig` object by files."""
  for exp_config_file in exp_config_files:
    if not tf.io.gfile.exists(exp_config_file):
      raise ValueError('%s does not exist.' % exp_config_file)
    params_dict.override_params_dict(
        exp_config, exp_config_file, is_strict=True)

  return exp_config
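A minimal call sketch for the helper above (the config file names are placeholders): files are applied in order, so a later file replaces values set by an earlier one for any keys they share.

# Hypothetical config file names; each is applied with is_strict=True.
exp_config = _override_exp_config_by_file(
    exp_config, ['configs/base_experiment.yaml', 'configs/tpu_overrides.yaml'])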
Example #3
def run(callbacks=None):
    keras_utils.set_session_config(enable_xla=FLAGS.enable_xla)

    params = config_factory.config_generator(FLAGS.model)

    params = params_dict.override_params_dict(params,
                                              FLAGS.config_file,
                                              is_strict=True)

    params = params_dict.override_params_dict(params,
                                              FLAGS.params_override,
                                              is_strict=True)
    params.override(
        {
            'strategy_type': FLAGS.strategy_type,
            'model_dir': FLAGS.model_dir,
            'strategy_config': executor.strategy_flags_dict(),
        },
        is_strict=False)
    params.validate()
    params.lock()
    pp = pprint.PrettyPrinter()
    params_str = pp.pformat(params.as_dict())
    logging.info('Model Parameters: {}'.format(params_str))

    train_input_fn = None
    eval_input_fn = None
    training_file_pattern = FLAGS.training_file_pattern or params.train.train_file_pattern
    eval_file_pattern = FLAGS.eval_file_pattern or params.eval.eval_file_pattern
    if not training_file_pattern and not eval_file_pattern:
        raise ValueError(
            'Must provide at least one of training_file_pattern and '
            'eval_file_pattern.')

    if training_file_pattern:
        # Use global batch size for single host.
        train_input_fn = input_reader.InputFn(
            file_pattern=training_file_pattern,
            params=params,
            mode=input_reader.ModeKeys.TRAIN,
            batch_size=params.train.batch_size)

    if eval_file_pattern:
        eval_input_fn = input_reader.InputFn(
            file_pattern=eval_file_pattern,
            params=params,
            mode=input_reader.ModeKeys.PREDICT_WITH_GT,
            batch_size=params.eval.batch_size,
            num_examples=params.eval.eval_samples)
    return run_executor(params,
                        train_input_fn=train_input_fn,
                        eval_input_fn=eval_input_fn,
                        callbacks=callbacks)
Example #4
 def test_end_to_end_multi_eval(self, distribution_strategy, flag_mode):
   model_dir = self.get_temp_dir()
   experiment_config = configs.MultiEvalExperimentConfig(
       task=test_utils.FooConfig(),
       eval_tasks=(configs.TaskRoutine(
           task_name='foo', task_config=test_utils.FooConfig(), eval_steps=2),
                   configs.TaskRoutine(
                       task_name='bar',
                       task_config=test_utils.BarConfig(),
                       eval_steps=3)))
   experiment_config = params_dict.override_params_dict(
       experiment_config, self._test_config, is_strict=False)
   with distribution_strategy.scope():
     train_task = task_factory.get_task(experiment_config.task)
     eval_tasks = [
         task_factory.get_task(config.task_config, name=config.task_name)
         for config in experiment_config.eval_tasks
     ]
   train_lib.run_experiment_with_multitask_eval(
       distribution_strategy=distribution_strategy,
       train_task=train_task,
       eval_tasks=eval_tasks,
       mode=flag_mode,
       params=experiment_config,
       model_dir=model_dir)
Example #5
 def test_override_params_dict_using_yaml_string(self):
   params = params_dict.ParamsDict({
       'a': 1, 'b': 2.5, 'c': [3, 4], 'd': 'hello', 'e': False})
   override_yaml_string = "'b': 5.2\n'c': [30, 40]"
   params = params_dict.override_params_dict(
       params, override_yaml_string, is_strict=True)
   self.assertEqual(1, params.a)
   self.assertEqual(5.2, params.b)
   self.assertEqual([30, 40], params.c)
   self.assertEqual('hello', params.d)
   self.assertEqual(False, params.e)
Example #6
 def test_override_params_dict_using_csv_string(self):
   params = params_dict.ParamsDict({
       'a': 1, 'b': {'b1': 2, 'b2': [2, 3],},
       'd': {'d1': {'d2': 'hello'}}, 'e': False})
   override_csv_string = "b.b2=[3,4], d.d1.d2='hi, world', e=gs://test"
   params = params_dict.override_params_dict(
       params, override_csv_string, is_strict=True)
   self.assertEqual(1, params.a)
   self.assertEqual(2, params.b.b1)
   self.assertEqual([3, 4], params.b.b2)
   self.assertEqual('hi, world', params.d.d1.d2)
   self.assertEqual('gs://test', params.e)
Example #7
 def test_override_params_dict_using_json_string(self):
   params = params_dict.ParamsDict({
       'a': 1, 'b': {'b1': 2, 'b2': [2, 3],},
       'd': {'d1': {'d2': 'hello'}}, 'e': False})
   override_json_string = "{ b: { b2: [3, 4] }, d: { d1: { d2: 'hi' } } }"
   params = params_dict.override_params_dict(
       params, override_json_string, is_strict=True)
   self.assertEqual(1, params.a)
   self.assertEqual(2, params.b.b1)
   self.assertEqual([3, 4], params.b.b2)
   self.assertEqual('hi', params.d.d1.d2)
   self.assertEqual(False, params.e)
Example #8
def run():
    """Runs NHNet using Keras APIs."""
    if FLAGS.enable_mlir_bridge:
        tf.config.experimental.enable_mlir_bridge()

    strategy = distribute_utils.get_distribution_strategy(
        distribution_strategy=FLAGS.distribution_strategy,
        tpu_address=FLAGS.tpu)
    if strategy:
        logging.info("***** Number of cores used : %d",
                     strategy.num_replicas_in_sync)

    params = models.get_model_params(FLAGS.model_type)
    params = params_dict.override_params_dict(params,
                                              FLAGS.params_override,
                                              is_strict=True)
    params.override(
        {
            "len_title": FLAGS.len_title,
            "len_passage": FLAGS.len_passage,
            "num_hidden_layers": FLAGS.num_encoder_layers,
            "num_decoder_layers": FLAGS.num_decoder_layers,
            "passage_list": [
                chr(ord("b") + i) for i in range(FLAGS.num_nhnet_articles)
            ],
        },
        is_strict=False)
    stats = {}
    if "train" in FLAGS.mode:
        stats = train(params, strategy)
    if "eval" in FLAGS.mode:
        timeout = 0 if FLAGS.mode == "train_and_eval" else FLAGS.eval_timeout
        # Uses padded decoding for TPU. Always uses cache.
        padded_decode = isinstance(strategy,
                                   tf.distribute.experimental.TPUStrategy)
        params.override({
            "padded_decode": padded_decode,
        }, is_strict=False)
        stats = evaluation.continuous_eval(
            strategy,
            params,
            model_type=FLAGS.model_type,
            eval_file_pattern=FLAGS.eval_file_pattern,
            batch_size=FLAGS.eval_batch_size,
            eval_steps=FLAGS.eval_steps,
            model_dir=FLAGS.model_dir,
            timeout=timeout)
    return stats
Example #9
def _get_params_from_flags(flags_obj: flags.FlagValues):
    """Get ParamsDict from flags."""
    model = flags_obj.model_type.lower()
    dataset = flags_obj.dataset.lower()
    # TODO(zhz): Why is this a ParamsDict type
    params = configs.get_config(model=model, dataset=dataset)

    flags_overrides = {
        'model_dir': flags_obj.model_dir,
        'mode': flags_obj.mode,
        'model': {
            'name': model,
        },
        'runtime': {
            'run_eagerly': flags_obj.run_eagerly,
            'tpu': flags_obj.tpu,
        },
        'train_dataset': {
            'data_dir': flags_obj.data_dir,
        },
        'validation_dataset': {
            'data_dir': flags_obj.data_dir,
        },
        'train': {
            'time_history': {
                'log_steps': flags_obj.log_steps,
            },
        },
    }

    overriding_configs = (flags_obj.config_file, flags_obj.params_override,
                          flags_overrides)

    pp = pprint.PrettyPrinter()

    logging.info('Base params: %s', pp.pformat(params.as_dict()))

    for param in overriding_configs:
        logging.info('Overriding params: %s', param)
        # Set is_strict to false because we can have dynamic dict parameters.
        params = params_dict.override_params_dict(params,
                                                  param,
                                                  is_strict=False)

    params.validate()
    params.lock()

    logging.info('Final model parameters: %s', pp.pformat(params.as_dict()))
    return params
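The loop above applies the overriding configs in order, so --params_override takes precedence over --config_file, and the flag-derived dict wins over both. A rough equivalent written as explicit calls (the config file name and the data directory are placeholders):

# Later calls replace keys set by earlier ones, so the last override wins.
params = configs.get_config(model=model, dataset=dataset)
params = params_dict.override_params_dict(params, 'my_config.yaml', is_strict=False)              # --config_file
params = params_dict.override_params_dict(params, 'train_dataset.data_dir=/data', is_strict=False)  # --params_override
params = params_dict.override_params_dict(params, flags_overrides, is_strict=False)               # flag-derived dict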
Example #10
 def test_override_params_dict_using_yaml_file(self):
   params = params_dict.ParamsDict({
       'a': 1, 'b': 2.5, 'c': [3, 4], 'd': 'hello', 'e': False})
   override_yaml_file = self.write_temp_file(
       'params.yaml', r"""
       b: 5.2
       c: [30, 40]
       """)
   params = params_dict.override_params_dict(
       params, override_yaml_file, is_strict=True)
   self.assertEqual(1, params.a)
   self.assertEqual(5.2, params.b)
   self.assertEqual([30, 40], params.c)
   self.assertEqual('hello', params.d)
   self.assertEqual(False, params.e)
Example #11
 def test_end_to_end(self, distribution_strategy, flag_mode):
   model_dir = self.get_temp_dir()
   experiment_config = configs.MultiTaskExperimentConfig(
       task=configs.MultiTaskConfig(
           task_routines=(
               configs.TaskRoutine(
                   task_name='foo', task_config=test_utils.FooConfig()),
               configs.TaskRoutine(
                   task_name='bar', task_config=test_utils.BarConfig()))))
   experiment_config = params_dict.override_params_dict(
       experiment_config, self._test_config, is_strict=False)
   with distribution_strategy.scope():
     test_multitask = multitask.MultiTask.from_config(experiment_config.task)
     model = test_utils.MockMultiTaskModel()
   train_lib.run_experiment(
       distribution_strategy=distribution_strategy,
       task=test_multitask,
       model=model,
       mode=flag_mode,
       params=experiment_config,
       model_dir=model_dir)
Example #12
    def test_end_to_end(self, distribution_strategy, flag_mode, run_post_eval):
        model_dir = self.get_temp_dir()
        experiment_config = cfg.ExperimentConfig(
            trainer=prog_trainer_lib.ProgressiveTrainerConfig(),
            task=ProgTaskConfig())
        experiment_config = params_dict.override_params_dict(experiment_config,
                                                             self._test_config,
                                                             is_strict=False)

        with distribution_strategy.scope():
            task = task_factory.get_task(experiment_config.task,
                                         logging_dir=model_dir)

        _, logs = train_lib.run_experiment(
            distribution_strategy=distribution_strategy,
            task=task,
            mode=flag_mode,
            params=experiment_config,
            model_dir=model_dir,
            run_post_eval=run_post_eval)

        if run_post_eval:
            self.assertNotEmpty(logs)
        else:
            self.assertEmpty(logs)

        if flag_mode == 'eval':
            return
        self.assertNotEmpty(
            tf.io.gfile.glob(os.path.join(model_dir, 'checkpoint')))
        # Tests continuous evaluation.
        _, logs = train_lib.run_experiment(
            distribution_strategy=distribution_strategy,
            task=task,
            mode='continuous_eval',
            params=experiment_config,
            model_dir=model_dir,
            run_post_eval=run_post_eval)
        print(logs)
Example #13
def run(callbacks=None):
    keras_utils.set_session_config(enable_xla=FLAGS.enable_xla)

    params = config_factory.config_generator(FLAGS.model)

    params = params_dict.override_params_dict(params,
                                              FLAGS.config_file,
                                              is_strict=True)

    params = params_dict.override_params_dict(params,
                                              FLAGS.params_override,
                                              is_strict=True)
    params.override(
        {
            'strategy_type': FLAGS.strategy_type,
            'model_dir': FLAGS.model_dir,
            'strategy_config': executor.strategy_flags_dict(),
        },
        is_strict=False)

    # Make sure use_tpu and strategy_type are in sync.
    params.use_tpu = (params.strategy_type == 'tpu')

    if not params.use_tpu:
        params.override(
            {
                'architecture': {
                    'use_bfloat16': False,
                },
                'norm_activation': {
                    'use_sync_bn': False,
                },
            },
            is_strict=True)

    params.validate()
    params.lock()
    pp = pprint.PrettyPrinter()
    params_str = pp.pformat(params.as_dict())
    logging.info('Model Parameters: %s', params_str)

    train_input_fn = None
    eval_input_fn = None
    training_file_pattern = FLAGS.training_file_pattern or params.train.train_file_pattern
    eval_file_pattern = FLAGS.eval_file_pattern or params.eval.eval_file_pattern
    if not training_file_pattern and not eval_file_pattern:
        raise ValueError(
            'Must provide at least one of training_file_pattern and '
            'eval_file_pattern.')

    if training_file_pattern:
        # Use global batch size for single host.
        train_input_fn = input_reader.InputFn(
            file_pattern=training_file_pattern,
            params=params,
            mode=input_reader.ModeKeys.TRAIN,
            batch_size=params.train.batch_size)

    if eval_file_pattern:
        eval_input_fn = input_reader.InputFn(
            file_pattern=eval_file_pattern,
            params=params,
            mode=input_reader.ModeKeys.PREDICT_WITH_GT,
            batch_size=params.eval.batch_size,
            num_examples=params.eval.eval_samples)

    if callbacks is None:
        callbacks = []

    if FLAGS.log_steps:
        callbacks.append(
            keras_utils.TimeHistory(
                batch_size=params.train.batch_size,
                log_steps=FLAGS.log_steps,
            ))

    return run_executor(params,
                        FLAGS.mode,
                        checkpoint_path=FLAGS.checkpoint_path,
                        train_input_fn=train_input_fn,
                        eval_input_fn=eval_input_fn,
                        callbacks=callbacks)