Beispiel #1
0
def main(_):
    """Run continuous finetuning as configured by the command-line flags.

    Gin files and bindings are parsed first. The experiment parameters are
    then built from ``FLAGS`` and passed, together with the model directory,
    to ``continuous_finetune_lib.run_continuous_finetune``. The config is
    serialized before the run and the gin config is saved after it.

    Args:
        _: Ignored. Presumably the leftover argv list handed over by the app
            runner -- confirm against the caller.
    """
    # TODO(b/177863554): consolidate to nlp/train.py
    gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_params)

    experiment_params = train_utils.parse_configuration(FLAGS)
    output_dir = FLAGS.model_dir
    train_utils.serialize_config(experiment_params, output_dir)

    continuous_finetune_lib.run_continuous_finetune(
        FLAGS.mode,
        experiment_params,
        output_dir,
        pretrain_steps=FLAGS.pretrain_steps,
    )

    # Saved last, once the finetuning call has returned.
    train_utils.save_gin_config(FLAGS.mode, output_dir)
Beispiel #2
0
    def testContinuousFinetune(self):
        """Checks continuous finetuning started from freshly trained checkpoints.

        A short 'train' run first saves checkpoints into a temporary
        directory, which the flag overrides point ``task.init_checkpoint`` at.
        Continuous finetuning is then run with the overridden flags, and the
        test asserts that the returned metrics contain 'best_acc' and that no
        'checkpoint' file exists under ``FLAGS.model_dir`` afterwards.
        """
        pretrain_steps = 1
        ckpt_source_dir = self.get_temp_dir()

        # Every trainer knob is set to 1 so that each phase is a single step.
        optimizer_overrides = {
            'optimizer': {'type': 'sgd'},
            'learning_rate': {'type': 'constant'},
        }
        trainer_overrides = {
            'continuous_eval_timeout': 1,
            'steps_per_loop': 1,
            'train_steps': 1,
            'validation_steps': 1,
            'best_checkpoint_export_subdir': 'best_ckpt',
            'best_checkpoint_eval_metric': 'acc',
            'optimizer_config': optimizer_overrides,
        }
        flag_overrides = {
            'experiment': 'mock',
            'mode': 'continuous_train_and_eval',
            'model_dir': self._model_dir,
            'params_override': {
                'task': {'init_checkpoint': ckpt_source_dir},
                'trainer': trainer_overrides,
            },
        }

        with flagsaver.flagsaver(**flag_overrides):
            # Phase 1: a plain 'train' run that saves the source checkpoints.
            train_params = train_utils.parse_configuration(flags.FLAGS)
            strategy = tf.distribute.get_strategy()
            with strategy.scope():
                task = task_factory.get_task(
                    train_params.task, logging_dir=ckpt_source_dir)
            _ = train_lib.run_experiment(
                distribution_strategy=strategy,
                task=task,
                mode='train',
                params=train_params,
                model_dir=ckpt_source_dir)

            # Phase 2: continuous finetuning driven by the overridden flags.
            finetune_params = train_utils.parse_configuration(FLAGS)
            eval_metrics = continuous_finetune_lib.run_continuous_finetune(
                FLAGS.mode,
                finetune_params,
                FLAGS.model_dir,
                run_post_eval=True,
                pretrain_steps=pretrain_steps)
            self.assertIn('best_acc', eval_metrics)

            # No 'checkpoint' file is expected in the finetune model dir.
            checkpoint_file = os.path.join(FLAGS.model_dir, 'checkpoint')
            self.assertFalse(tf.io.gfile.exists(checkpoint_file))
Beispiel #3
0
def main(_):
  """Run the experiment selected by ``FLAGS.mode``.

  The 'continuous_train_and_eval' mode is dispatched to
  ``continuous_finetune_lib.run_continuous_finetune``. Every other mode goes
  through ``train_lib.run_experiment`` with a distribution strategy built from
  ``params.runtime``. The gin config is saved at the end in both cases.

  Args:
    _: Ignored. Presumably the leftover argv list handed over by the app
      runner -- confirm against the caller.
  """
  gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_params)
  experiment_params = train_utils.parse_configuration(FLAGS)
  output_dir = FLAGS.model_dir

  # Only modes containing 'train' write the yaml config. A pure eval job
  # (e.g. continuous eval) could otherwise race the train job on that file.
  if 'train' in FLAGS.mode:
    train_utils.serialize_config(experiment_params, output_dir)

  if FLAGS.mode != 'continuous_train_and_eval':
    runtime = experiment_params.runtime

    # Install the mixed-precision policy when one is configured.
    # 'mixed_float16' (GPUs) and 'mixed_bfloat16' (TPUs) can change model
    # speed significantly; loss_scale only matters when the dtype is float16.
    precision_dtype = runtime.mixed_precision_dtype
    if precision_dtype:
      performance.set_mixed_precision_policy(precision_dtype)

    strategy = distribute_utils.get_distribution_strategy(
        distribution_strategy=runtime.distribution_strategy,
        all_reduce_alg=runtime.all_reduce_alg,
        num_gpus=runtime.num_gpus,
        tpu_address=runtime.tpu,
        **runtime.model_parallelism())

    # The task is created inside the strategy scope.
    with strategy.scope():
      task = task_factory.get_task(
          experiment_params.task, logging_dir=output_dir)

    train_lib.run_experiment(
        distribution_strategy=strategy,
        task=task,
        mode=FLAGS.mode,
        params=experiment_params,
        model_dir=output_dir)
  else:
    continuous_finetune_lib.run_continuous_finetune(
        FLAGS.mode,
        experiment_params,
        output_dir,
        pretrain_steps=FLAGS.pretrain_steps)

  train_utils.save_gin_config(FLAGS.mode, output_dir)