Ejemplo n.º 1
0
def main(argv):
    """Train a teacher model (unless skipped), then run the distill phase.

    Args:
        argv: Command-line arguments. When non-empty, every element after
            the first is forwarded to ``t2t_trainer.set_hparams_from_args``.
    """
    tf.logging.set_verbosity(tf.logging.INFO)
    trainer_lib.set_random_seed(FLAGS.random_seed)
    usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)
    t2t_trainer.maybe_log_registry_and_exit()

    # With FLAGS.cloud_mlengine set, hand off to cloud_mlengine.launch()
    # and do nothing else in this process.
    if FLAGS.cloud_mlengine:
        cloud_mlengine.launch()
        return

    if FLAGS.generate_data:
        t2t_trainer.generate_data()

    # A truthy job directory reported by cloud_mlengine replaces the
    # configured output directory.
    if cloud_mlengine.job_dir():
        FLAGS.output_dir = cloud_mlengine.job_dir()

    if argv:
        t2t_trainer.set_hparams_from_args(argv[1:])

    def run_phase(phase_hparams):
        """Build the experiment for ``phase_hparams`` and run its schedule."""
        make_experiment = t2t_trainer.create_experiment_fn()
        config = t2t_trainer.create_run_config(phase_hparams)
        experiment = make_experiment(config, phase_hparams)
        if t2t_trainer.is_chief():
            t2t_trainer.save_metadata(phase_hparams)
        t2t_trainer.execute_schedule(experiment)

    # Remember the root directory now: FLAGS.output_dir is repointed at the
    # per-phase directory below.
    base_dir = FLAGS.output_dir
    teacher_dir = FLAGS.teacher_dir or os.path.join(base_dir, "teacher")

    # ---- Phase 1: teacher ----
    if not FLAGS.skip_teacher_training:
        teacher_hparams = t2t_trainer.create_hparams()
        teacher_hparams.distill_phase = "train"
        FLAGS.output_dir = teacher_dir
        run_phase(teacher_hparams)
    else:
        tf.logging.info("training teacher skipped")

    # ---- Phase 2: student ----
    # The hparams are created before FLAGS.output_dir is switched to the
    # student directory, and both directories are recorded on the hparams.
    student_hparams = t2t_trainer.create_hparams()
    student_hparams.add_hparam("teacher_dir", teacher_dir)
    student_hparams.distill_phase = "distill"
    student_dir = FLAGS.student_dir or os.path.join(base_dir, "student")
    FLAGS.output_dir = student_dir
    student_hparams.add_hparam("student_dir", student_dir)
    run_phase(student_hparams)
Ejemplo n.º 2
0
def main(argv):
  """Train a teacher model, then run the distill phase for the student.

  Both phases run inside the ``t2t_trainer.maybe_cloud_tpu()`` context.

  Args:
    argv: Command-line arguments. When non-empty, every element after the
      first is forwarded to ``t2t_trainer.set_hparams_from_args``.
  """
  tf.logging.set_verbosity(tf.logging.INFO)
  trainer_lib.set_random_seed(FLAGS.random_seed)
  usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)
  t2t_trainer.maybe_log_registry_and_exit()

  # With FLAGS.cloud_mlengine set, hand off to cloud_mlengine.launch() and
  # do nothing else in this process.
  if FLAGS.cloud_mlengine:
    cloud_mlengine.launch()
    return

  if FLAGS.generate_data:
    t2t_trainer.generate_data()

  # A truthy job directory reported by cloud_mlengine replaces the
  # configured output directory.
  if cloud_mlengine.job_dir():
    FLAGS.output_dir = cloud_mlengine.job_dir()

  if argv:
    t2t_trainer.set_hparams_from_args(argv[1:])

  def run_phase(phase_hparams):
    """Build the experiment for ``phase_hparams`` and run its schedule."""
    make_experiment = t2t_trainer.create_experiment_fn()
    config = t2t_trainer.create_run_config(phase_hparams)
    experiment = make_experiment(config, phase_hparams)
    if t2t_trainer.is_chief():
      t2t_trainer.save_metadata(phase_hparams)
    t2t_trainer.execute_schedule(experiment)

  with t2t_trainer.maybe_cloud_tpu():
    # Remember the root directory now: FLAGS.output_dir is repointed at the
    # per-phase directory below.
    base_dir = FLAGS.output_dir

    # ---- Phase 1: teacher ----
    teacher_hparams = t2t_trainer.create_hparams()
    teacher_hparams.distill_phase = "train"
    teacher_dir = os.path.join(base_dir, "teacher")
    FLAGS.output_dir = teacher_dir
    run_phase(teacher_hparams)

    # ---- Phase 2: student ----
    # The hparams are created before FLAGS.output_dir is switched to the
    # student directory; the teacher directory is recorded on the hparams.
    student_hparams = t2t_trainer.create_hparams()
    student_hparams.add_hparam("teacher_dir", teacher_dir)
    student_hparams.distill_phase = "distill"
    student_dir = os.path.join(base_dir, "student")
    FLAGS.output_dir = student_dir
    run_phase(student_hparams)