def main(_):
    """Log the launch command under the model directory and start the run.

    Imports each module listed in --module_import, applies --tasks_cache_dir
    when it is set, registers the packaged "gin" directory as a gin search
    path, writes the command line to the first "command[.N]" file that does
    not exist yet in --model_dir, parses the gin configuration and finally
    calls `utils.run`.

    Args:
      _: unused positional argument.
    """
    for module_name in FLAGS.module_import or ():
        importlib.import_module(module_name)

    if FLAGS.tasks_cache_dir:
        t5.data.set_global_cache_dirs(FLAGS.tasks_cache_dir)

    # Add search path for gin files stored in package.
    packaged_gin_dir = pkg_resources.resource_filename(__name__, "gin")
    gin.add_config_file_search_path(packaged_gin_dir)

    tf.io.gfile.makedirs(FLAGS.model_dir)

    # Probe "command", "command.1", "command.2", ... until a free name is found.
    command_path = os.path.join(FLAGS.model_dir, "command")
    attempt = 0
    while tf.io.gfile.exists(command_path):
        attempt += 1
        command_path = os.path.join(
            FLAGS.model_dir, "command.{}".format(attempt))
    with tf.io.gfile.GFile(command_path, "w") as command_file:
        command_file.write(" ".join(sys.argv))

    utils.parse_gin_defaults_and_flags()

    run_kwargs = dict(
        tpu_job_name=FLAGS.tpu_job_name,
        tpu=FLAGS.tpu,
        gcp_project=FLAGS.gcp_project,
        tpu_zone=FLAGS.tpu_zone,
        model_dir=FLAGS.model_dir,
    )
    utils.run(**run_kwargs)
def main(_):
    """Parse the default gin config, apply user overrides and start the run.

    Args:
      _: unused positional argument.
    """
    # The defaults are parsed first, so any gin files/parameters supplied by
    # the user afterwards take precedence over them.
    default_config_path = os.path.join(
        os.path.dirname(__file__), _DEFAULT_CONFIG_FILE)
    gin.parse_config_file(default_config_path)
    gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_param)

    run_kwargs = dict(
        tpu_job_name=FLAGS.tpu_job_name,
        tpu=FLAGS.tpu,
        gcp_project=FLAGS.gcp_project,
        tpu_zone=FLAGS.tpu_zone,
        model_dir=FLAGS.model_dir,
    )
    utils.run(**run_kwargs)
def run(**kwargs) -> None:
    """Run a T5 model for training, finetuning, evaluation etc.

    Disables TF2 behavior, raises the interpreter recursion limit when the gin
    parameter "utils.run.mode" equals "eval", then forwards every keyword
    argument unchanged to `utils.run`.

    Args:
      **kwargs: passed straight through to `utils.run`.
    """
    tf.disable_v2_behavior()

    in_eval_mode = gin.query_parameter("utils.run.mode") == "eval"
    if in_eval_mode:
        # Increase the recursion limit, see:
        # https://github.com/pltrdy/rouge/issues/19
        sequence_length = gin.query_parameter("utils.run.sequence_length")
        input_length = sequence_length.get("inputs", 512)
        # TODO: do not hardcode batch_size for recursionlimit calc
        assumed_batch_size = 1024
        sys.setrecursionlimit(assumed_batch_size * input_length + 10)

    utils.run(**kwargs)
def main(_):
    """Write the launch command to the model directory and start the run.

    Imports each module listed in --module_import, creates --model_dir, writes
    the command line to the file "command" inside it (opened in "w" mode, so
    an existing file is replaced), parses the gin configuration and calls
    `utils.run`.

    Args:
      _: unused positional argument.
    """
    for module_name in FLAGS.module_import or ():
        importlib.import_module(module_name)

    tf.io.gfile.makedirs(FLAGS.model_dir)
    command_path = os.path.join(FLAGS.model_dir, "command")
    with tf.io.gfile.GFile(command_path, "w") as command_file:
        command_file.write(" ".join(sys.argv))

    utils.parse_gin_defaults_and_flags()

    run_kwargs = dict(
        tpu_job_name=FLAGS.tpu_job_name,
        tpu=FLAGS.tpu,
        gcp_project=FLAGS.gcp_project,
        tpu_zone=FLAGS.tpu_zone,
        model_dir=FLAGS.model_dir,
    )
    utils.run(**run_kwargs)
def main(_):
    """Parse gin config, load the TFDS dataset and start the run.

    Args:
      _: unused positional argument.
    """
    # The defaults are parsed first, so any gin files/parameters supplied by
    # the user afterwards take precedence over them.
    default_config_path = os.path.join(
        os.path.dirname(__file__), _DEFAULT_CONFIG_FILE)
    gin.parse_config_file(default_config_path)
    gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_param)

    tfds_dataset = utils.get_tfds_dataset(data_dir=FLAGS.data_dir)

    run_kwargs = dict(
        dataset=tfds_dataset,
        tpu_job_name=FLAGS.tpu_job_name,
        master_dtype=FLAGS.master_dtype,
        slice_dtype=FLAGS.slice_dtype,
        activation_dtype=FLAGS.activation_dtype,
        tpu=FLAGS.tpu,
        gcp_project=FLAGS.gcp_project,
        tpu_zone=FLAGS.tpu_zone,
        autostack=FLAGS.autostack,
        model_dir=FLAGS.model_dir,
    )
    utils.run(**run_kwargs)
def main(_):
    """Log the launch command under the model directory and start the run.

    Imports each module listed in --module_import, creates --model_dir, writes
    the command line to the first "command[.N]" file that does not exist yet,
    parses the gin configuration and calls `utils.run`.

    Args:
      _: unused positional argument.
    """
    for module_name in FLAGS.module_import or ():
        importlib.import_module(module_name)

    tf.io.gfile.makedirs(FLAGS.model_dir)

    # Probe "command", "command.1", "command.2", ... until a free name is found.
    command_path = os.path.join(FLAGS.model_dir, "command")
    attempt = 0
    while tf.io.gfile.exists(command_path):
        attempt += 1
        command_path = os.path.join(
            FLAGS.model_dir, "command.{}".format(attempt))
    with tf.io.gfile.GFile(command_path, "w") as command_file:
        command_file.write(" ".join(sys.argv))

    utils.parse_gin_defaults_and_flags()

    run_kwargs = dict(
        tpu_job_name=FLAGS.tpu_job_name,
        tpu=FLAGS.tpu,
        gcp_project=FLAGS.gcp_project,
        tpu_zone=FLAGS.tpu_zone,
        model_dir=FLAGS.model_dir,
    )
    utils.run(**run_kwargs)
def main(_):
    """Launch T5 through the Model API or through the gin-driven `utils.run`.

    Setup: imports --module_import modules, applies the TFDS data-dir override
    when --t5_tfds_data_dir is set, registers --additional_task_cache_dirs,
    adds the packaged "gin" directory to gin's search path, and records the
    command line in the first unused "command[.N]" file of --model_dir (a
    PermissionDeniedError during that step is logged and ignored).

    Without --use_model_api the call is delegated to `utils.run`. With it, an
    `MtfModel` is built and --mode selects train / eval / finetune / predict /
    export; any other --mode value falls through without doing anything.

    Args:
      _: unused positional argument.

    Raises:
      ValueError: when finetune or export is requested without exactly one
        checkpoint being selected.
    """
    for module_name in FLAGS.module_import or ():
        importlib.import_module(module_name)

    if FLAGS.t5_tfds_data_dir:
        t5.data.set_tfds_data_dir_override(FLAGS.t5_tfds_data_dir)
    t5.data.add_global_cache_dirs(FLAGS.additional_task_cache_dirs)

    # Add search path for gin files stored in package.
    packaged_gin_dir = pkg_resources.resource_filename(__name__, "gin")
    gin.add_config_file_search_path(packaged_gin_dir)

    try:
        tf.io.gfile.makedirs(FLAGS.model_dir)
        # Probe "command", "command.1", ... until a free name is found.
        command_path = os.path.join(FLAGS.model_dir, "command")
        attempt = 0
        while tf.io.gfile.exists(command_path):
            attempt += 1
            command_path = os.path.join(
                FLAGS.model_dir, "command.{}".format(attempt))
        with tf.io.gfile.GFile(command_path, "w") as command_file:
            command_file.write(" ".join(sys.argv))
    except tf.errors.PermissionDeniedError:
        logging.info(
            "No write access to model directory. Skipping command logging.")

    utils.parse_gin_defaults_and_flags()

    if not FLAGS.use_model_api:
        utils.run(
            tpu_job_name=FLAGS.tpu_job_name,
            tpu=FLAGS.tpu,
            gcp_project=FLAGS.gcp_project,
            tpu_zone=FLAGS.tpu_zone,
            model_dir=FLAGS.model_dir,
        )
        return

    model = mtf_model.MtfModel(
        tpu_job_name=FLAGS.tpu_job_name,
        tpu=FLAGS.tpu,
        gcp_project=FLAGS.gcp_project,
        tpu_zone=FLAGS.tpu_zone,
        model_dir=FLAGS.model_dir,
    )

    # Translate --checkpoint_mode into the value handed to the model methods.
    if FLAGS.checkpoint_mode == "latest":
        checkpoint_steps = -1
    elif FLAGS.checkpoint_mode == "all":
        checkpoint_steps = "all"
    else:
        checkpoint_steps = [int(step) for step in FLAGS.checkpoint_steps]

    if FLAGS.mode == "train":
        model.train(
            mixture_or_task_name=FLAGS.mixture_or_task,
            steps=FLAGS.train_steps,
        )
    elif FLAGS.mode == "eval":
        model.eval(
            mixture_or_task_name=FLAGS.mixture_or_task,
            checkpoint_steps=checkpoint_steps,
            summary_dir=FLAGS.eval_summary_dir,
            split=FLAGS.eval_split,
        )
    elif FLAGS.mode == "finetune":
        # Finetuning starts from exactly one checkpoint.
        if FLAGS.checkpoint_mode != "latest" and not (
                FLAGS.checkpoint_mode == "specific"
                and len(FLAGS.checkpoint_steps) == 1):
            raise ValueError(
                "Must specify a single checkpoint for finetuning a model.")
        if isinstance(checkpoint_steps, list):
            checkpoint_steps = checkpoint_steps[0]
        model.finetune(
            mixture_or_task_name=FLAGS.mixture_or_task,
            steps=FLAGS.train_steps,
            pretrained_model_dir=FLAGS.pretrained_model_dir,
            checkpoint_steps=checkpoint_steps,
        )
    elif FLAGS.mode == "predict":
        model.predict(
            checkpoint_steps=checkpoint_steps,
            input_file=FLAGS.input_file,
            output_file=FLAGS.output_file,
        )
    elif FLAGS.mode == "export":
        # Exporting uses exactly one checkpoint.
        if FLAGS.checkpoint_mode != "latest" and not (
                FLAGS.checkpoint_mode == "specific"
                and len(FLAGS.checkpoint_steps) == 1):
            raise ValueError(
                "Must specify a single checkpoint for exporting a model.")
        if isinstance(checkpoint_steps, list):
            checkpoint_steps = checkpoint_steps[0]
        model.export(
            export_dir=FLAGS.export_dir,
            checkpoint_step=checkpoint_steps,
        )
def main(_):
    """Launch T5 through the Model API or through the gin-driven `utils.run`.

    Setup: imports --module_import modules, applies the TFDS data-dir override
    when --t5_tfds_data_dir is set, adds the packaged "gin" directory to gin's
    search path, and records the command line in the first unused
    "command[.N]" file under "<model_dir>/commands" (PermissionDeniedError and
    InvalidArgumentError during that step are logged and ignored). Gin is then
    parsed without finalizing, "run.vocabulary" is rebound, the config is
    finalized, and --additional_task_cache_dirs is registered.

    Without --use_model_api the call is delegated to `utils.run`, forcing
    float32 slice/activation dtypes when no --tpu is given. With it, an
    `MtfModel` is built and --mode selects the operation.

    Args:
      _: unused positional argument.

    Raises:
      ValueError: if --checkpoint_steps is combined with a --checkpoint_mode
        other than "specific"; if finetune/export is requested without exactly
        one checkpoint; if --mode is missing or unrecognised with the Model
        API; or if --mode is set without the Model API.
    """
    for module_name in FLAGS.module_import or ():
        importlib.import_module(module_name)

    if FLAGS.t5_tfds_data_dir:
        t5.data.set_tfds_data_dir_override(FLAGS.t5_tfds_data_dir)

    # Add search path for gin files stored in package.
    packaged_gin_dir = pkg_resources.resource_filename(__name__, "gin")
    gin.add_config_file_search_path(packaged_gin_dir)

    try:
        attempt = 0
        command_dir = os.path.join(FLAGS.model_dir, "commands")
        tf.io.gfile.makedirs(command_dir)
        # Probe "command", "command.1", ... until a free name is found.
        command_path = os.path.join(command_dir, "command")
        while tf.io.gfile.exists(command_path):
            attempt += 1
            command_path = os.path.join(
                command_dir, "command.{}".format(attempt))
        with tf.io.gfile.GFile(command_path, "w") as command_file:
            command_file.write(" ".join(sys.argv))
    except (tf.errors.PermissionDeniedError, tf.errors.InvalidArgumentError):
        logging.info(
            "No write access to model directory. Skipping command logging.")

    skip_unknown = FLAGS.skip_all_gin_unknowns or (
        mesh_transformer.DEPRECATED_GIN_REFERENCES
        + tuple(FLAGS.additional_deprecated_gin_references))
    utils.parse_gin_defaults_and_flags(
        skip_unknown=skip_unknown, finalize_config=False)
    # We must override this binding explicitly since it is set to a deprecated
    # function or class in many existing configs.
    gin.bind_parameter("run.vocabulary", mesh_transformer.get_vocabulary())
    gin.finalize()

    # Set cache dir after loading gin to avoid unintentionally overriding it.
    t5.data.add_global_cache_dirs(FLAGS.additional_task_cache_dirs)

    if not FLAGS.use_model_api:
        if FLAGS.mode:
            raise ValueError(
                "--mode flag should only be set when using Model API.")
        if not FLAGS.tpu:
            with gin.unlock_config():
                gin.bind_parameter(
                    "utils.get_variable_dtype.slice_dtype", "float32")
                gin.bind_parameter(
                    "utils.get_variable_dtype.activation_dtype", "float32")
        utils.run(
            tpu_job_name=FLAGS.tpu_job_name,
            tpu=FLAGS.tpu,
            gcp_project=FLAGS.gcp_project,
            tpu_zone=FLAGS.tpu_zone,
            model_dir=FLAGS.model_dir,
        )
        return

    sequence_length = {
        "inputs": FLAGS.input_sequence_length,
        "targets": FLAGS.target_sequence_length
    }
    model = mtf_model.MtfModel(
        tpu_job_name=FLAGS.tpu_job_name,
        tpu=FLAGS.tpu,
        gcp_project=FLAGS.gcp_project,
        tpu_zone=FLAGS.tpu_zone,
        tpu_topology=FLAGS.tpu_topology,
        model_parallelism=FLAGS.model_parallelism,
        model_dir=FLAGS.model_dir,
        batch_size=FLAGS.batch_size,
        sequence_length=sequence_length,
    )

    if FLAGS.checkpoint_mode != "specific" and FLAGS.checkpoint_steps:
        raise ValueError(
            "checkpoint_mode is set to %s and checkpoint_steps is "
            "also set. To use a particular checkpoint, please set "
            "checkpoint_mode to 'specific'. For other modes, please "
            "ensure that checkpoint_steps is not set." % FLAGS.checkpoint_mode)

    # Translate --checkpoint_mode into the value handed to the model methods.
    if FLAGS.checkpoint_mode == "latest":
        checkpoint_steps = -1
    elif FLAGS.checkpoint_mode == "all":
        checkpoint_steps = "all"
    else:
        checkpoint_steps = [int(step) for step in FLAGS.checkpoint_steps]

    if FLAGS.mode == "train":
        model.train(
            mixture_or_task_name=FLAGS.mixture_or_task,
            steps=FLAGS.train_steps,
        )
    elif FLAGS.mode == "eval":
        model.eval(
            mixture_or_task_name=FLAGS.mixture_or_task,
            checkpoint_steps=checkpoint_steps,
            summary_dir=FLAGS.eval_summary_dir,
            split=FLAGS.eval_split,
        )
    elif FLAGS.mode == "finetune":
        # Finetuning starts from exactly one checkpoint.
        if FLAGS.checkpoint_mode != "latest" and not (
                FLAGS.checkpoint_mode == "specific"
                and len(FLAGS.checkpoint_steps) == 1):
            raise ValueError(
                "Must specify a single checkpoint for finetuning a model.")
        if isinstance(checkpoint_steps, list):
            checkpoint_steps = checkpoint_steps[0]
        model.finetune(
            mixture_or_task_name=FLAGS.mixture_or_task,
            steps=FLAGS.train_steps,
            pretrained_model_dir=FLAGS.pretrained_model_dir,
            checkpoint_steps=checkpoint_steps,
        )
    elif FLAGS.mode == "predict":
        model.predict(
            checkpoint_steps=checkpoint_steps,
            input_file=FLAGS.input_file,
            output_file=FLAGS.output_file,
            beam_size=FLAGS.beam_size,
            temperature=FLAGS.temperature,
            keep_top_k=FLAGS.keep_top_k,
        )
    elif FLAGS.mode == "score":
        model.score(
            FLAGS.input_file,
            FLAGS.target_file,
            scores_file=FLAGS.output_file,
            checkpoint_steps=checkpoint_steps,
        )
    elif FLAGS.mode in ("export_predict", "export_score"):
        # Exporting uses exactly one checkpoint.
        if FLAGS.checkpoint_mode != "latest" and not (
                FLAGS.checkpoint_mode == "specific"
                and len(FLAGS.checkpoint_steps) == 1):
            raise ValueError(
                "Must specify a single checkpoint for exporting a model.")
        if isinstance(checkpoint_steps, list):
            checkpoint_steps = checkpoint_steps[0]
        model.export(
            export_dir=FLAGS.export_dir,
            checkpoint_step=checkpoint_steps,
            beam_size=FLAGS.beam_size,
            temperature=FLAGS.temperature,
            keep_top_k=FLAGS.keep_top_k,
            eval_with_score=(FLAGS.mode == "export_score"),
        )
    else:
        raise ValueError("--mode flag must be set when using Model API.")
def main(_):
    """Launch T5 through the Model API or through the gin-driven `utils.run`.

    Setup: imports --module_import modules, applies the TFDS data-dir override
    when --t5_tfds_data_dir is set, registers --additional_task_cache_dirs,
    adds the packaged "gin" directory to gin's search path, and records the
    command line in the first unused "command[.N]" file of --model_dir (a
    PermissionDeniedError during that step is logged and ignored).

    Without --use_model_api the call is delegated to `utils.run`, forcing
    float32 slice/activation dtypes when no --tpu is given. With it, an
    `MtfModel` is built and --mode selects the operation.

    Args:
      _: unused positional argument.

    Raises:
      ValueError: if --checkpoint_steps is combined with a --checkpoint_mode
        other than "specific"; if finetune/export is requested without exactly
        one checkpoint; if --mode is missing or unrecognised with the Model
        API; or if --mode is set without the Model API.
    """
    for module_name in FLAGS.module_import or ():
        importlib.import_module(module_name)

    if FLAGS.t5_tfds_data_dir:
        t5.data.set_tfds_data_dir_override(FLAGS.t5_tfds_data_dir)
    t5.data.add_global_cache_dirs(FLAGS.additional_task_cache_dirs)

    # Add search path for gin files stored in package.
    packaged_gin_dir = pkg_resources.resource_filename(__name__, "gin")
    gin.add_config_file_search_path(packaged_gin_dir)

    try:
        tf.io.gfile.makedirs(FLAGS.model_dir)
        # Probe "command", "command.1", ... until a free name is found.
        command_path = os.path.join(FLAGS.model_dir, "command")
        attempt = 0
        while tf.io.gfile.exists(command_path):
            attempt += 1
            command_path = os.path.join(
                FLAGS.model_dir, "command.{}".format(attempt))
        with tf.io.gfile.GFile(command_path, "w") as command_file:
            command_file.write(" ".join(sys.argv))
    except tf.errors.PermissionDeniedError:
        logging.info(
            "No write access to model directory. Skipping command logging.")

    utils.parse_gin_defaults_and_flags()

    if not FLAGS.use_model_api:
        if FLAGS.mode:
            raise ValueError(
                "--mode flag should only be set when using Model API.")
        if not FLAGS.tpu:
            with gin.unlock_config():
                gin.bind_parameter(
                    "utils.get_variable_dtype.slice_dtype", "float32")
                gin.bind_parameter(
                    "utils.get_variable_dtype.activation_dtype", "float32")
        utils.run(
            tpu_job_name=FLAGS.tpu_job_name,
            tpu=FLAGS.tpu,
            gcp_project=FLAGS.gcp_project,
            tpu_zone=FLAGS.tpu_zone,
            model_dir=FLAGS.model_dir,
        )
        return

    model = mtf_model.MtfModel(
        tpu_job_name=FLAGS.tpu_job_name,
        tpu=FLAGS.tpu,
        gcp_project=FLAGS.gcp_project,
        tpu_zone=FLAGS.tpu_zone,
        model_dir=FLAGS.model_dir,
    )

    if FLAGS.checkpoint_mode != "specific" and FLAGS.checkpoint_steps:
        raise ValueError(
            "checkpoint_mode is set to %s and checkpoint_steps is "
            "also set. To use a particular checkpoint, please set "
            "checkpoint_mode to 'specific'. For other modes, please "
            "ensure that checkpoint_steps is not set."
            % FLAGS.checkpoint_mode)

    # Translate --checkpoint_mode into the value handed to the model methods.
    if FLAGS.checkpoint_mode == "latest":
        checkpoint_steps = -1
    elif FLAGS.checkpoint_mode == "all":
        checkpoint_steps = "all"
    else:
        checkpoint_steps = [int(step) for step in FLAGS.checkpoint_steps]

    if FLAGS.mode == "train":
        model.train(
            mixture_or_task_name=FLAGS.mixture_or_task,
            steps=FLAGS.train_steps,
        )
    elif FLAGS.mode == "eval":
        model.eval(
            mixture_or_task_name=FLAGS.mixture_or_task,
            checkpoint_steps=checkpoint_steps,
            summary_dir=FLAGS.eval_summary_dir,
            split=FLAGS.eval_split,
        )
    elif FLAGS.mode == "finetune":
        # Finetuning starts from exactly one checkpoint.
        if FLAGS.checkpoint_mode != "latest" and not (
                FLAGS.checkpoint_mode == "specific"
                and len(FLAGS.checkpoint_steps) == 1):
            raise ValueError(
                "Must specify a single checkpoint for finetuning a model.")
        if isinstance(checkpoint_steps, list):
            checkpoint_steps = checkpoint_steps[0]
        model.finetune(
            mixture_or_task_name=FLAGS.mixture_or_task,
            steps=FLAGS.train_steps,
            pretrained_model_dir=FLAGS.pretrained_model_dir,
            checkpoint_steps=checkpoint_steps,
        )
    elif FLAGS.mode == "predict":
        model.predict(
            checkpoint_steps=checkpoint_steps,
            input_file=FLAGS.input_file,
            output_file=FLAGS.output_file,
        )
    elif FLAGS.mode == "score":
        model.score(
            FLAGS.input_file,
            FLAGS.target_file,
            scores_file=FLAGS.output_file,
            checkpoint_steps=checkpoint_steps,
        )
    elif FLAGS.mode == "export":
        # Exporting uses exactly one checkpoint.
        if FLAGS.checkpoint_mode != "latest" and not (
                FLAGS.checkpoint_mode == "specific"
                and len(FLAGS.checkpoint_steps) == 1):
            raise ValueError(
                "Must specify a single checkpoint for exporting a model.")
        if isinstance(checkpoint_steps, list):
            checkpoint_steps = checkpoint_steps[0]
        # The export uses its own batch size, set on the model beforehand.
        model.batch_size = FLAGS.export_batch_size
        model.export(
            export_dir=FLAGS.export_dir,
            checkpoint_step=checkpoint_steps,
        )
    else:
        raise ValueError("--mode flag must be set when using Model API.")
def main(_):
    """Launch T5 through the Model API or through the gin-driven `utils.run`.

    Steps, in order:
      1. Import every module named by --module_import.
      2. Apply the TFDS data-dir override when --t5_tfds_data_dir is set and
         register --additional_task_cache_dirs.
      3. Add the packaged "gin" directory to gin's config search path.
      4. Write the command line to the first unused "command[.N]" file in
         --model_dir.
      5. Parse gin defaults and flags.
      6. With --use_model_api, build an `MtfModel` and train, eval or (for any
         other --mode) predict; otherwise call `utils.run`.

    Args:
      _: unused positional argument.

    Raises:
      ValueError: if --checkpoint_mode is "latest" for eval/predict but no
        checkpoint index file is found in --model_dir.
    """
    for module_name in FLAGS.module_import or ():
        importlib.import_module(module_name)

    if FLAGS.t5_tfds_data_dir:
        # Pass the flag that was just tested; the original read a different
        # flag (`tfds_data_dir`) than the one guarding this call.
        t5.data.set_tfds_data_dir_override(FLAGS.t5_tfds_data_dir)
    t5.data.add_global_cache_dirs(FLAGS.additional_task_cache_dirs)

    # Add search path for gin files stored in package.
    gin.add_config_file_search_path(
        pkg_resources.resource_filename(__name__, "gin"))

    tf.io.gfile.makedirs(FLAGS.model_dir)
    # Probe "command", "command.1", "command.2", ... until a free name is found.
    command_filename = os.path.join(FLAGS.model_dir, "command")
    suffix = 0
    while tf.io.gfile.exists(command_filename):
        suffix += 1
        command_filename = os.path.join(
            FLAGS.model_dir, "command.{}".format(suffix))
    with tf.io.gfile.GFile(command_filename, "w") as f:
        f.write(" ".join(sys.argv))

    utils.parse_gin_defaults_and_flags()

    if not FLAGS.use_model_api:
        utils.run(
            tpu_job_name=FLAGS.tpu_job_name,
            tpu=FLAGS.tpu,
            gcp_project=FLAGS.gcp_project,
            tpu_zone=FLAGS.tpu_zone,
            model_dir=FLAGS.model_dir,
        )
        return

    model = mtf_model.MtfModel(
        tpu_job_name=FLAGS.tpu_job_name,
        tpu=FLAGS.tpu,
        gcp_project=FLAGS.gcp_project,
        tpu_zone=FLAGS.tpu_zone,
        model_dir=FLAGS.model_dir,
    )

    if FLAGS.mode == "train":
        # Training does not take a checkpoint step, so checkpoints are not
        # resolved here; doing so up front failed on a model directory that
        # held no checkpoints yet.
        model.train(
            mixture_or_task_name=FLAGS.mixture_or_task,
            steps=FLAGS.train_steps,
        )
        return

    if FLAGS.checkpoint_mode == "latest":
        # Join with a path separator so the pattern also matches when
        # --model_dir is given without a trailing slash.
        index_files = tf.io.gfile.glob(
            os.path.join(FLAGS.model_dir, "model.*index"))
        if not index_files:
            raise ValueError(
                "No checkpoint index files matching 'model.*index' found in "
                "model directory %s." % FLAGS.model_dir)
        # Strip everything up to "ckpt-" and the ".index" suffix, leaving the
        # step number; the highest step is the latest checkpoint.
        checkpoint_step = max(
            int(re.sub(r".*ckpt-", "", path).replace(".index", ""))
            for path in index_files)
    elif FLAGS.checkpoint_mode == "all":
        checkpoint_step = "all"
    else:
        checkpoint_step = [int(c) for c in FLAGS.checkpoint_steps]

    if FLAGS.mode == "eval":
        model.eval(
            mixture_or_task_name=FLAGS.mixture_or_task,
            checkpoint_step=checkpoint_step,
            summary_dir=FLAGS.eval_summary_dir,
            split=FLAGS.eval_split,
        )
    else:
        model.predict(
            checkpoint_step=checkpoint_step,
            input_file=FLAGS.input_file,
            output_file=FLAGS.output_file,
        )