def j2j_train(model_name, dataset_name,
              data_dir=None, output_dir=None, config_file=None, config=None):
  """Main function to train the given model on the given dataset.

  Args:
    model_name: The name of the model to train.
    dataset_name: The name of the dataset to train on.
    data_dir: Directory where the data is located.
    output_dir: Directory where to put the logs and checkpoints.
    config_file: The gin configuration file to use.
    config: List of strings (in gin format) to override gin parameters.
  """
  gin.bind_parameter("train_fn.dataset", dataset_name)
  if model_name:
    config = [] if config is None else config
    config += ["train_fn.model=@models." + model_name]
  gin.parse_config_files_and_bindings(config_file, config)
  if output_dir:
    if not tf.gfile.Exists(output_dir):
      tf.gfile.MkDir(output_dir)
    config_path = os.path.join(output_dir, "gin.config")
    # TODO(lukaszkaiser): why is the file empty if there's no provided config?
    with tf.gfile.Open(config_path, "w") as f:
      f.write(gin.operative_config_str())
  j2j.train_fn(data_dir, output_dir=output_dir)
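# Hedged usage sketch (not from the original source): one way to invoke
# j2j_train directly. The model/dataset names and the step-count override
# below are placeholders, not guaranteed entries in the models registry.
#
#   j2j_train(
#       model_name="Transformer",
#       dataset_name="mnist",
#       data_dir=os.path.expanduser("~/data"),
#       output_dir=os.path.expanduser("~/train"),
#       config=["train_fn.train_steps=1000"])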
def main(argv):
  tf.logging.set_verbosity(tf.logging.INFO)

  if FLAGS.jax:
    # Hacking main v1 flags to work with jax.
    config_strs = []
    config_strs.append("train_fn.train_steps=" + str(FLAGS.train_steps))
    config_strs.append("train_fn.eval_steps=" + str(FLAGS.eval_steps))
    config_strs.append(
        "train_fn.eval_frequency=" + str(FLAGS.local_eval_frequency))
    if FLAGS.hparams:
      config_strs.extend(str(FLAGS.hparams).split(","))
    data_dir = os.path.expanduser(FLAGS.data_dir)
    output_dir = os.path.expanduser(FLAGS.output_dir)
    gin.bind_parameter("train_fn.dataset", FLAGS.problem)
    config_strs += ["train_fn.model=@" + FLAGS.model]
    config_files = []
    if FLAGS.hparams_set:
      config_files = [os.path.expanduser(FLAGS.hparams_set)]
    gin.parse_config_files_and_bindings(config_files, config_strs)
    j2j.train_fn(data_dir=data_dir, output_dir=output_dir)
    return

  usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)

  # If we just have to print the registry, do that and exit early.
  maybe_log_registry_and_exit()

  # Create HParams.
  if argv:
    set_hparams_from_args(argv[1:])
  hparams = create_hparams()

  if FLAGS.schedule == "train" or FLAGS.schedule == "train_eval_and_decode":
    mlperf_log.transformer_print(key=mlperf_log.RUN_START, hparams=hparams)
  if FLAGS.schedule == "run_std_server":
    run_std_server()
  mlperf_log.transformer_print(
      key=mlperf_log.RUN_SET_RANDOM_SEED, value=FLAGS.random_seed,
      hparams=hparams)
  trainer_lib.set_random_seed(FLAGS.random_seed)

  if FLAGS.cloud_mlengine:
    cloud_mlengine.launch()
    return

  if FLAGS.generate_data:
    generate_data()

  if cloud_mlengine.job_dir():
    FLAGS.output_dir = cloud_mlengine.job_dir()

  exp_fn = create_experiment_fn()
  exp = exp_fn(create_run_config(hparams), hparams)
  if is_chief():
    save_metadata(hparams)
  execute_schedule(exp)
  if FLAGS.schedule != "train":
    mlperf_log.transformer_print(key=mlperf_log.RUN_FINAL, hparams=hparams)
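# Illustration (hedged): with, e.g., --model=transformer --train_steps=1000
# --eval_steps=10 --local_eval_frequency=100 and no --hparams, the jax branch
# above reduces to the gin bindings:
#
#   config_strs = [
#       "train_fn.train_steps=1000",
#       "train_fn.eval_steps=10",
#       "train_fn.eval_frequency=100",
#       "train_fn.model=@transformer",
#   ]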
def main(_):
  _setup_gin()

  # Setup directories
  data_dir, output_dir = FLAGS.data_dir, FLAGS.output_dir
  # `x and os.path.expanduser(x)` expands a provided path but leaves an
  # unset (None/empty) flag untouched.
  data_dir = data_dir and os.path.expanduser(data_dir)
  output_dir = output_dir and os.path.expanduser(output_dir)

  j2j.train_fn(data_dir, output_dir=output_dir)
def main(_):
  _setup_gin()

  # Setup directories
  data_dir = FLAGS.data_dir
  output_dir = FLAGS.output_dir or _default_output_dir()
  assert data_dir, "Must specify a data directory"
  assert output_dir, "Must specify an output directory"
  j2j.log("Using output_dir %s" % output_dir)
  data_dir = os.path.expanduser(data_dir)
  output_dir = os.path.expanduser(output_dir)

  j2j.train_fn(data_dir=data_dir, output_dir=output_dir)
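# Hypothetical sketch: _default_output_dir is referenced above but not defined
# in this snippet. A minimal stand-in (the directory layout and the use of
# FLAGS.model here are assumptions, not the original implementation):

import datetime


def _default_output_dir():
  """Returns a timestamped default output directory (hypothetical sketch)."""
  dir_name = "{model}_{timestamp}".format(
      model=FLAGS.model,
      timestamp=datetime.datetime.now().strftime("%Y%m%d_%H%M"))
  # Left unexpanded on purpose: main() calls os.path.expanduser on the result.
  return os.path.join("~", "train", dir_name)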