Beispiel #1
0
def main(_):
    """Runs a BERT distillation experiment configured through gin.

    The gin files and bindings named by the flags are parsed first. The
    experiment config is then built, and a Keras mixed-precision policy is
    optionally installed. Finally, a distribution strategy is created and a
    `distillation.BertDistillationTask` is handed to
    `train_lib.run_experiment`.

    Args:
      _: Unused positional argument (presumably argv from the app runner).
    """
    logging.info('Parsing config files...')
    gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_params)
    params = get_exp_config()
    runtime = params.runtime

    # A 'mixed_float16' (GPU) or 'mixed_bfloat16' (TPU) policy can speed up
    # the model noticeably; when no dtype is configured the policy is left
    # untouched.
    precision_dtype = runtime.mixed_precision_dtype
    if precision_dtype:
        performance.set_mixed_precision_policy(precision_dtype)

    strategy = distribute_utils.get_distribution_strategy(
        distribution_strategy=runtime.distribution_strategy,
        all_reduce_alg=runtime.all_reduce_alg,
        num_gpus=runtime.num_gpus,
        tpu_address=runtime.tpu)

    # The task is constructed inside the strategy scope, presumably so that
    # anything it creates is managed by the strategy.
    trainer_cfg = params.trainer
    with strategy.scope():
        task = distillation.BertDistillationTask(
            strategy=strategy,
            progressive=trainer_cfg.progressive,
            optimizer_config=trainer_cfg.optimizer_config,
            task_config=params.task)

    train_lib.run_experiment(
        distribution_strategy=strategy,
        task=task,
        mode=FLAGS.mode,
        params=params,
        model_dir=FLAGS.model_dir)
Beispiel #2
0
def main(_):
    """Runs the experiment described by the gin configs and parsed flags.

    The steps, in order:
      1. Parse the gin files/bindings and build the experiment params.
      2. Serialize the params into the model dir, for training modes only.
      3. Optionally install a mixed-precision policy.
      4. Create the distribution strategy, including any model-parallel
         options supplied by the runtime config.
      5. Build the task and run the experiment.
      6. Save the gin config.

    Args:
      _: Unused positional argument (presumably argv from the app runner).
    """
    gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_params)
    params = train_utils.parse_configuration(FLAGS)
    model_dir = FLAGS.model_dir
    mode = FLAGS.mode

    # Only modes containing 'train' write the yaml config. Pure eval jobs skip
    # it so a continuous-eval job cannot race the train job on the same file.
    if 'train' in mode:
        train_utils.serialize_config(params, model_dir)

    runtime = params.runtime
    # 'mixed_float16' suits GPUs and 'mixed_bfloat16' suits TPUs; with no
    # dtype configured the policy is left untouched.
    if runtime.mixed_precision_dtype:
        performance.set_mixed_precision_policy(runtime.mixed_precision_dtype)

    strategy = distribute_utils.get_distribution_strategy(
        distribution_strategy=runtime.distribution_strategy,
        all_reduce_alg=runtime.all_reduce_alg,
        num_gpus=runtime.num_gpus,
        tpu_address=runtime.tpu,
        **runtime.model_parallelism())
    with strategy.scope():
        task = task_factory.get_task(params.task, logging_dir=model_dir)

    train_lib.run_experiment(
        distribution_strategy=strategy,
        task=task,
        mode=mode,
        params=params,
        model_dir=model_dir)

    train_utils.save_gin_config(mode, model_dir)
Beispiel #3
0
def main(_):
  """Runs a progressive masked-LM experiment configured through gin.

  The gin files and bindings are parsed and the experiment params are built.
  A distribution strategy is then created, and a
  `masked_lm.ProgressiveMaskedLM` task (constructed under that strategy's
  scope) is passed to `train_lib.run_experiment`.

  Args:
    _: Unused positional argument (presumably argv from the app runner).
  """
  gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_params)
  params = get_exp_config()
  runtime = params.runtime

  strategy = distribution_utils.get_distribution_strategy(
      distribution_strategy=runtime.distribution_strategy,
      all_reduce_alg=runtime.all_reduce_alg,
      num_gpus=runtime.num_gpus,
      tpu_address=runtime.tpu)

  task_cfg = params.task
  trainer_cfg = params.trainer
  with strategy.scope():
    # The task receives both the regular and the small training-data configs.
    task = masked_lm.ProgressiveMaskedLM(
        strategy=strategy,
        progressive_config=trainer_cfg.progressive,
        optimizer_config=trainer_cfg.optimizer_config,
        train_data_config=task_cfg.train_data,
        small_train_data_config=task_cfg.small_train_data,
        task_config=task_cfg)

  train_lib.run_experiment(
      distribution_strategy=strategy,
      task=task,
      mode=FLAGS.mode,
      params=params,
      model_dir=FLAGS.model_dir)
Beispiel #4
0
    def test_end_to_end(self, distribution_strategy, flag_mode, run_post_eval):
        """Runs a progressive experiment end to end, then in continuous eval.

        The returned logs are checked after both runs. The previous version
        only printed the continuous-eval logs, so that run was never
        asserted on.

        Args:
          distribution_strategy: Strategy the task is built and run under.
          flag_mode: Mode passed to `train_lib.run_experiment` for the first
            run.
          run_post_eval: Forwarded to `train_lib.run_experiment`; logs are
            expected to be non-empty exactly when this is true.
        """
        model_dir = self.get_temp_dir()
        experiment_config = cfg.ExperimentConfig(
            trainer=prog_trainer_lib.ProgressiveTrainerConfig(),
            task=ProgTaskConfig())
        experiment_config = params_dict.override_params_dict(
            experiment_config, self._test_config, is_strict=False)

        with distribution_strategy.scope():
            task = task_factory.get_task(experiment_config.task,
                                         logging_dir=model_dir)

        def _run_and_check_logs(mode):
            """Runs the experiment in `mode` and checks the returned logs."""
            _, logs = train_lib.run_experiment(
                distribution_strategy=distribution_strategy,
                task=task,
                mode=mode,
                params=experiment_config,
                model_dir=model_dir,
                run_post_eval=run_post_eval)
            # Logs should be returned only when post-evaluation was requested.
            if run_post_eval:
                self.assertNotEmpty(logs)
            else:
                self.assertEmpty(logs)

        _run_and_check_logs(flag_mode)

        # The remaining checks rely on a checkpoint from a training run, so
        # they are skipped for pure eval.
        if flag_mode == 'eval':
            return
        self.assertNotEmpty(
            tf.io.gfile.glob(os.path.join(model_dir, 'checkpoint')))
        # Tests continuous evaluation against the checkpoint written above.
        _run_and_check_logs('continuous_eval')