Пример #1
0
def j2j_train(model_name,
              dataset_name,
              data_dir=None,
              output_dir=None,
              config_file=None,
              config=None):
    """Main function to train the given model on the given dataset.

    Args:
      model_name: The name of the model to train. When set, it is bound to
        `train_fn.model` as `@models.<model_name>`.
      dataset_name: The name of the dataset to train on.
      data_dir: Directory where the data is located.
      output_dir: Directory where to put the logs and checkpoints.
      config_file: the gin configuration file to use.
      config: gin bindings used to override gin parameters; either a list of
        binding strings or a single string (which is wrapped into a list).
        The caller's object is never modified.
    """
    gin.bind_parameter("train_fn.dataset", dataset_name)

    # Work on a private copy so the caller's list is not extended in place.
    if config is None:
        bindings = []
    elif isinstance(config, str):
        bindings = [config]
    else:
        bindings = list(config)

    # Gate on the argument that is actually concatenated below, rather than
    # on a global flag that may disagree with it.
    if model_name:
        bindings.append("train_fn.model=@models." + model_name)
    gin.parse_config_files_and_bindings(config_file, bindings)

    if output_dir:
        # MakeDirs also creates missing parents and succeeds when the
        # directory already exists, so no separate existence check is needed.
        tf.gfile.MakeDirs(output_dir)
        config_path = os.path.join(output_dir, "gin.config")
        # TODO(lukaszkaiser): why is the file empty if there's no provided config?
        # NOTE(review): gin's operative config presumably only lists
        # configurables that have already been called, and train_fn has not
        # run yet at this point -- confirm against the gin documentation.
        with tf.gfile.Open(config_path, "w") as f:
            f.write(gin.operative_config_str())
    j2j.train_fn(data_dir, output_dir=output_dir)
Пример #2
0
def main(argv):
    """Entry point: run the jax trainer if requested, else the t2t schedule.

    Args:
      argv: command-line arguments; everything after the first element is
        passed to `set_hparams_from_args` when the list is non-empty.
    """
    tf.logging.set_verbosity(tf.logging.INFO)

    if FLAGS.jax:
        # Hacking main v1 flags to work with jax.
        bindings = [
            "train_fn.train_steps=" + str(FLAGS.train_steps),
            "train_fn.eval_steps=" + str(FLAGS.eval_steps),
            "train_fn.eval_frequency=" + str(FLAGS.local_eval_frequency),
        ]
        if FLAGS.hparams:
            # Each comma-separated entry of --hparams becomes one gin binding.
            bindings += str(FLAGS.hparams).split(",")
        data_dir = os.path.expanduser(FLAGS.data_dir)
        output_dir = os.path.expanduser(FLAGS.output_dir)

        gin.bind_parameter("train_fn.dataset", FLAGS.problem)
        bindings.append("train_fn.model=@" + FLAGS.model)
        if FLAGS.hparams_set:
            gin_files = [os.path.expanduser(FLAGS.hparams_set)]
        else:
            gin_files = []
        gin.parse_config_files_and_bindings(gin_files, bindings)
        j2j.train_fn(data_dir=data_dir, output_dir=output_dir)
        return

    usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)

    # If we just have to print the registry, do that and exit early.
    maybe_log_registry_and_exit()

    # Create HParams.
    if argv:
        set_hparams_from_args(argv[1:])
    hparams = create_hparams()

    if FLAGS.schedule in ("train", "train_eval_and_decode"):
        mlperf_log.transformer_print(key=mlperf_log.RUN_START, hparams=hparams)
    if FLAGS.schedule == "run_std_server":
        run_std_server()
    mlperf_log.transformer_print(
        key=mlperf_log.RUN_SET_RANDOM_SEED,
        value=FLAGS.random_seed,
        hparams=hparams)
    trainer_lib.set_random_seed(FLAGS.random_seed)

    if FLAGS.cloud_mlengine:
        cloud_mlengine.launch()
        return

    if FLAGS.generate_data:
        generate_data()

    # A job directory reported by cloud_mlengine overrides --output_dir.
    if cloud_mlengine.job_dir():
        FLAGS.output_dir = cloud_mlengine.job_dir()

    experiment_fn = create_experiment_fn()
    run_config = create_run_config(hparams)
    experiment = experiment_fn(run_config, hparams)
    if is_chief():
        save_metadata(hparams)
    execute_schedule(experiment)
    if FLAGS.schedule != "train":
        mlperf_log.transformer_print(key=mlperf_log.RUN_FINAL, hparams=hparams)
Пример #3
0
def main(_):
    """Set up gin, expand the directory flags and start training.

    The single positional argument is ignored.
    """
    _setup_gin()

    # Expand "~" only for directories that were actually provided; an unset
    # (falsy) flag value is forwarded to train_fn unchanged.
    data_dir = FLAGS.data_dir
    output_dir = FLAGS.output_dir
    if data_dir:
        data_dir = os.path.expanduser(data_dir)
    if output_dir:
        output_dir = os.path.expanduser(output_dir)

    j2j.train_fn(data_dir, output_dir=output_dir)
Пример #4
0
def main(_):
    """Set up gin, validate the directory flags and start training.

    The single positional argument is ignored. The output directory falls
    back to `_default_output_dir()` when `FLAGS.output_dir` is unset.

    Raises:
      ValueError: if no data directory or no output directory is available.
    """
    _setup_gin()

    # Setup directories
    data_dir = FLAGS.data_dir
    output_dir = FLAGS.output_dir or _default_output_dir()
    # Explicit raises instead of `assert`: assertions are stripped under
    # `python -O`, which would let a missing directory reach expanduser().
    if not data_dir:
        raise ValueError("Must specify a data directory")
    if not output_dir:
        raise ValueError("Must specify an output directory")
    j2j.log("Using output_dir %s" % output_dir)

    data_dir = os.path.expanduser(data_dir)
    output_dir = os.path.expanduser(output_dir)

    j2j.train_fn(data_dir=data_dir, output_dir=output_dir)