def main(_):
  """Run the Mesh TensorFlow program configured by flags and gin.

  Steps, in order:
    1. Import every module named in ``--module_import``.
    2. If ``--tasks_cache_dir`` is set, pass it to
       ``t5.data.set_global_cache_dirs``.
    3. Add this package's ``gin`` resource directory to gin's config-file
       search path.
    4. Create ``--model_dir`` and write ``sys.argv`` (space-joined) to the
       first name among ``command``, ``command.1``, ``command.2``, ... that
       does not exist there yet.
    5. Parse gin defaults/flags and call ``utils.run``.

  Args:
    _: Unused positional argument (presumably the argv list passed by the
      app runner — confirm).
  """
  for module_name in FLAGS.module_import or ():
    importlib.import_module(module_name)

  cache_dir = FLAGS.tasks_cache_dir
  if cache_dir:
    t5.data.set_global_cache_dirs(cache_dir)

  # Make gin config files bundled with this package resolvable by name.
  gin.add_config_file_search_path(
      pkg_resources.resource_filename(__name__, "gin"))

  tf.io.gfile.makedirs(FLAGS.model_dir)

  def _command_path(index):
    # Index 0 maps to "command", later indices to "command.<index>", so an
    # earlier run's log is never overwritten.
    name = "command.{}".format(index) if index else "command"
    return os.path.join(FLAGS.model_dir, name)

  index = 0
  while tf.io.gfile.exists(_command_path(index)):
    index += 1
  with tf.io.gfile.GFile(_command_path(index), "w") as command_file:
    command_file.write(" ".join(sys.argv))

  utils.parse_gin_defaults_and_flags()
  utils.run(
      tpu_job_name=FLAGS.tpu_job_name,
      tpu=FLAGS.tpu,
      gcp_project=FLAGS.gcp_project,
      tpu_zone=FLAGS.tpu_zone,
      model_dir=FLAGS.model_dir)
def main(_):
  """Import extra modules, log the command line, then run via ``utils.run``.

  The space-joined ``sys.argv`` is written to the first name among
  ``command``, ``command.1``, ``command.2``, ... that does not yet exist in
  ``--model_dir``. Repeated invocations against the same directory
  therefore keep a record of every run.

  Args:
    _: Unused positional argument (presumably the argv list passed by the
      app runner — confirm).
  """
  if FLAGS.module_import:
    for module in FLAGS.module_import:
      importlib.import_module(module)

  tf.io.gfile.makedirs(FLAGS.model_dir)
  # Pick the first free file name so the command log of an earlier run in
  # the same model_dir is never clobbered.
  suffix = 0
  command_filename = os.path.join(FLAGS.model_dir, "command")
  while tf.io.gfile.exists(command_filename):
    suffix += 1
    command_filename = os.path.join(FLAGS.model_dir,
                                    "command.{}".format(suffix))
  with tf.io.gfile.GFile(command_filename, "w") as f:
    f.write(" ".join(sys.argv))

  utils.parse_gin_defaults_and_flags()
  utils.run(tpu_job_name=FLAGS.tpu_job_name,
            tpu=FLAGS.tpu,
            gcp_project=FLAGS.gcp_project,
            tpu_zone=FLAGS.tpu_zone,
            model_dir=FLAGS.model_dir)
def main(_):
  """Import extra modules, log the command line, then run via ``utils.run``.

  ``--model_dir`` is created if needed. The space-joined ``sys.argv`` is
  written to the first name among ``command``, ``command.1``,
  ``command.2``, ... that does not exist there yet. Gin defaults and flags
  are then parsed before ``utils.run`` is called.

  Args:
    _: Unused positional argument (presumably the argv list passed by the
      app runner — confirm).
  """
  for module_name in FLAGS.module_import or ():
    importlib.import_module(module_name)

  tf.io.gfile.makedirs(FLAGS.model_dir)

  def _command_path(index):
    # "command" for the first run, "command.<index>" afterwards.
    name = "command.{}".format(index) if index else "command"
    return os.path.join(FLAGS.model_dir, name)

  index = 0
  while tf.io.gfile.exists(_command_path(index)):
    index += 1
  with tf.io.gfile.GFile(_command_path(index), "w") as command_file:
    command_file.write(" ".join(sys.argv))

  utils.parse_gin_defaults_and_flags()
  utils.run(
      tpu_job_name=FLAGS.tpu_job_name,
      tpu=FLAGS.tpu,
      gcp_project=FLAGS.gcp_project,
      tpu_zone=FLAGS.tpu_zone,
      model_dir=FLAGS.model_dir)
def main(_):
  """Entry point: set up data and gin, log the command, then run a mode.

  Set-up:
    * imports the modules listed in ``--module_import``;
    * applies ``--t5_tfds_data_dir`` (when set) and
      ``--additional_task_cache_dirs`` to ``t5.data``;
    * registers this package's ``gin`` resource directory with gin;
    * best-effort writes ``sys.argv`` to ``command``/``command.N`` under
      ``--model_dir``. This is skipped with an info log on
      ``PermissionDeniedError``;
    * parses gin defaults and flags.

  With ``--use_model_api``, ``--mode`` selects a method of
  ``mtf_model.MtfModel``: train, eval, finetune, predict or export.
  Otherwise control is handed to ``utils.run``.

  Args:
    _: Unused positional argument (presumably the argv list passed by the
      app runner — confirm).

  Raises:
    ValueError: if finetune/export is requested without exactly one
      checkpoint, or if ``--use_model_api`` is set and ``--mode`` is not
      one of the supported modes.
  """
  if FLAGS.module_import:
    for module in FLAGS.module_import:
      importlib.import_module(module)

  if FLAGS.t5_tfds_data_dir:
    t5.data.set_tfds_data_dir_override(FLAGS.t5_tfds_data_dir)
  t5.data.add_global_cache_dirs(FLAGS.additional_task_cache_dirs)

  # Add search path for gin files stored in package.
  gin.add_config_file_search_path(
      pkg_resources.resource_filename(__name__, "gin"))

  try:
    tf.io.gfile.makedirs(FLAGS.model_dir)
    suffix = 0
    command_filename = os.path.join(FLAGS.model_dir, "command")
    while tf.io.gfile.exists(command_filename):
      suffix += 1
      command_filename = os.path.join(FLAGS.model_dir,
                                      "command.{}".format(suffix))
    with tf.io.gfile.GFile(command_filename, "w") as f:
      f.write(" ".join(sys.argv))
  except tf.errors.PermissionDeniedError:
    # Command logging is best effort; a read-only model_dir must not stop
    # the run.
    logging.info(
        "No write access to model directory. Skipping command logging.")

  utils.parse_gin_defaults_and_flags()

  if FLAGS.use_model_api:
    model = mtf_model.MtfModel(tpu_job_name=FLAGS.tpu_job_name,
                               tpu=FLAGS.tpu,
                               gcp_project=FLAGS.gcp_project,
                               tpu_zone=FLAGS.tpu_zone,
                               model_dir=FLAGS.model_dir)

    # "latest" is encoded as -1, "all" as the string "all"; any other mode
    # uses the explicit list of integer steps.
    if FLAGS.checkpoint_mode == "latest":
      checkpoint_steps = -1
    elif FLAGS.checkpoint_mode == "all":
      checkpoint_steps = "all"
    else:
      checkpoint_steps = [int(c) for c in FLAGS.checkpoint_steps]

    if FLAGS.mode == "train":
      model.train(mixture_or_task_name=FLAGS.mixture_or_task,
                  steps=FLAGS.train_steps)
    elif FLAGS.mode == "eval":
      model.eval(mixture_or_task_name=FLAGS.mixture_or_task,
                 checkpoint_steps=checkpoint_steps,
                 summary_dir=FLAGS.eval_summary_dir,
                 split=FLAGS.eval_split)
    elif FLAGS.mode == "finetune":
      if not (FLAGS.checkpoint_mode == "latest" or
              (FLAGS.checkpoint_mode == "specific" and
               len(FLAGS.checkpoint_steps) == 1)):
        raise ValueError(
            "Must specify a single checkpoint for finetuning a model.")
      if isinstance(checkpoint_steps, list):
        checkpoint_steps = checkpoint_steps[0]
      model.finetune(mixture_or_task_name=FLAGS.mixture_or_task,
                     steps=FLAGS.train_steps,
                     pretrained_model_dir=FLAGS.pretrained_model_dir,
                     checkpoint_steps=checkpoint_steps)
    elif FLAGS.mode == "predict":
      model.predict(checkpoint_steps=checkpoint_steps,
                    input_file=FLAGS.input_file,
                    output_file=FLAGS.output_file)
    elif FLAGS.mode == "export":
      if not (FLAGS.checkpoint_mode == "latest" or
              (FLAGS.checkpoint_mode == "specific" and
               len(FLAGS.checkpoint_steps) == 1)):
        raise ValueError(
            "Must specify a single checkpoint for exporting a model.")
      if isinstance(checkpoint_steps, list):
        checkpoint_steps = checkpoint_steps[0]
      model.export(export_dir=FLAGS.export_dir,
                   checkpoint_step=checkpoint_steps)
    else:
      # Fail loudly instead of silently doing nothing on a missing or
      # misspelled --mode.
      raise ValueError(
          "--mode flag must be one of 'train', 'eval', 'finetune', "
          "'predict' or 'export' when using Model API; got %r."
          % (FLAGS.mode,))
  else:
    utils.run(tpu_job_name=FLAGS.tpu_job_name,
              tpu=FLAGS.tpu,
              gcp_project=FLAGS.gcp_project,
              tpu_zone=FLAGS.tpu_zone,
              model_dir=FLAGS.model_dir)
def main(_):
  """Entry point for the T5 Mesh TensorFlow runner.

  Set-up, in order:
    * import the modules listed in ``--module_import``;
    * apply ``--t5_tfds_data_dir`` to ``t5.data`` when it is set;
    * register this package's ``gin`` resource directory with gin;
    * best-effort write ``sys.argv`` to ``commands/command[.N]`` under
      ``--model_dir``. This is skipped with an info log on
      ``PermissionDeniedError`` or ``InvalidArgumentError``;
    * parse gin defaults and flags without finalizing, skipping unknown
      references. Either all unknowns are skipped
      (``--skip_all_gin_unknowns``) or only the deprecated ones are;
    * override ``run.vocabulary`` and finalize the gin config;
    * add ``--additional_task_cache_dirs`` to the T5 task cache dirs.

  Without ``--use_model_api``, the legacy ``utils.run`` path is taken. When
  ``--tpu`` is unset, float32 variable dtypes are bound first. With
  ``--use_model_api``, ``--mode`` selects an ``mtf_model.MtfModel`` method:
  train, eval, finetune, predict, score, export_predict or export_score.

  Args:
    _: Unused positional argument (presumably the argv list passed by the
      app runner — confirm).

  Raises:
    ValueError: in any of these cases:
      * the checkpoint flags conflict;
      * finetuning or exporting is requested without exactly one
        checkpoint;
      * ``--mode`` is missing or unknown under the Model API;
      * ``--mode`` is set without the Model API.
  """
  for module_name in FLAGS.module_import or ():
    importlib.import_module(module_name)

  tfds_dir = FLAGS.t5_tfds_data_dir
  if tfds_dir:
    t5.data.set_tfds_data_dir_override(tfds_dir)

  # Make gin config files bundled with this package resolvable by name.
  gin.add_config_file_search_path(
      pkg_resources.resource_filename(__name__, "gin"))

  commands_dir = os.path.join(FLAGS.model_dir, "commands")

  def _command_path(index):
    # "command" first, then "command.1", "command.2", ... so earlier logs
    # are kept.
    name = "command.{}".format(index) if index else "command"
    return os.path.join(commands_dir, name)

  try:
    tf.io.gfile.makedirs(commands_dir)
    index = 0
    while tf.io.gfile.exists(_command_path(index)):
      index += 1
    with tf.io.gfile.GFile(_command_path(index), "w") as command_file:
      command_file.write(" ".join(sys.argv))
  except (tf.errors.PermissionDeniedError, tf.errors.InvalidArgumentError):
    logging.info(
        "No write access to model directory. Skipping command logging.")

  if FLAGS.skip_all_gin_unknowns:
    skip_unknown = FLAGS.skip_all_gin_unknowns
  else:
    skip_unknown = (mesh_transformer.DEPRECATED_GIN_REFERENCES +
                    tuple(FLAGS.additional_deprecated_gin_references))
  utils.parse_gin_defaults_and_flags(
      skip_unknown=skip_unknown, finalize_config=False)
  # Many existing configs bind run.vocabulary to a deprecated function or
  # class, so it has to be overridden explicitly before finalizing.
  gin.bind_parameter("run.vocabulary", mesh_transformer.get_vocabulary())
  gin.finalize()
  # Done only after gin is loaded so the cache dirs are not unintentionally
  # overridden by it.
  t5.data.add_global_cache_dirs(FLAGS.additional_task_cache_dirs)

  if not FLAGS.use_model_api:
    if FLAGS.mode:
      raise ValueError(
          "--mode flag should only be set when using Model API.")
    if not FLAGS.tpu:
      with gin.unlock_config():
        gin.bind_parameter("utils.get_variable_dtype.slice_dtype", "float32")
        gin.bind_parameter("utils.get_variable_dtype.activation_dtype",
                           "float32")
    utils.run(
        tpu_job_name=FLAGS.tpu_job_name,
        tpu=FLAGS.tpu,
        gcp_project=FLAGS.gcp_project,
        tpu_zone=FLAGS.tpu_zone,
        model_dir=FLAGS.model_dir)
    return

  model = mtf_model.MtfModel(
      tpu_job_name=FLAGS.tpu_job_name,
      tpu=FLAGS.tpu,
      gcp_project=FLAGS.gcp_project,
      tpu_zone=FLAGS.tpu_zone,
      tpu_topology=FLAGS.tpu_topology,
      model_parallelism=FLAGS.model_parallelism,
      model_dir=FLAGS.model_dir,
      batch_size=FLAGS.batch_size,
      sequence_length={
          "inputs": FLAGS.input_sequence_length,
          "targets": FLAGS.target_sequence_length,
      })

  checkpoint_mode = FLAGS.checkpoint_mode
  if checkpoint_mode != "specific" and FLAGS.checkpoint_steps:
    raise ValueError(
        "checkpoint_mode is set to %s and checkpoint_steps is "
        "also set. To use a particular checkpoint, please set "
        "checkpoint_mode to 'specific'. For other modes, please "
        "ensure that checkpoint_steps is not set." % checkpoint_mode)

  # "latest" -> -1, "all" -> "all"; any other mode -> explicit int steps.
  checkpoint_steps = {"latest": -1, "all": "all"}.get(checkpoint_mode)
  if checkpoint_steps is None:
    checkpoint_steps = [int(step) for step in FLAGS.checkpoint_steps]

  def _single_checkpoint(error_message):
    """Return the one checkpoint step to use, raising ValueError if none."""
    if checkpoint_mode != "latest" and not (
        checkpoint_mode == "specific" and len(FLAGS.checkpoint_steps) == 1):
      raise ValueError(error_message)
    if isinstance(checkpoint_steps, list):
      return checkpoint_steps[0]
    return checkpoint_steps

  mode = FLAGS.mode
  if mode == "train":
    model.train(mixture_or_task_name=FLAGS.mixture_or_task,
                steps=FLAGS.train_steps)
  elif mode == "eval":
    model.eval(mixture_or_task_name=FLAGS.mixture_or_task,
               checkpoint_steps=checkpoint_steps,
               summary_dir=FLAGS.eval_summary_dir,
               split=FLAGS.eval_split)
  elif mode == "finetune":
    step = _single_checkpoint(
        "Must specify a single checkpoint for finetuning a model.")
    model.finetune(mixture_or_task_name=FLAGS.mixture_or_task,
                   steps=FLAGS.train_steps,
                   pretrained_model_dir=FLAGS.pretrained_model_dir,
                   checkpoint_steps=step)
  elif mode == "predict":
    model.predict(
        checkpoint_steps=checkpoint_steps,
        input_file=FLAGS.input_file,
        output_file=FLAGS.output_file,
        beam_size=FLAGS.beam_size,
        temperature=FLAGS.temperature,
        keep_top_k=FLAGS.keep_top_k,
    )
  elif mode == "score":
    model.score(FLAGS.input_file,
                FLAGS.target_file,
                scores_file=FLAGS.output_file,
                checkpoint_steps=checkpoint_steps)
  elif mode in ("export_predict", "export_score"):
    step = _single_checkpoint(
        "Must specify a single checkpoint for exporting a model.")
    model.export(export_dir=FLAGS.export_dir,
                 checkpoint_step=step,
                 beam_size=FLAGS.beam_size,
                 temperature=FLAGS.temperature,
                 keep_top_k=FLAGS.keep_top_k,
                 eval_with_score=(mode == "export_score"))
  else:
    raise ValueError("--mode flag must be set when using Model API.")
def main(_):
  """CAE-T5 entry point: build the model dir, sanity-check data, then run.

  The model directory is
  ``<--base_dir>/<--model_dir_name>[_<counter>]/<--model_size>``. The
  ``_<counter>`` part is appended only when ``--model_dir_counter >= 0``.

  After the command line is logged and gin is parsed, five preprocessed
  validation examples of the ``processed_cctk`` task are printed.

  With ``--use_model_api``, the T5 bitransformer builder and the TPU
  estimator model_fn are replaced by the CAE-T5 versions. ``--mode``
  (finetune, eval or predict) is then run on an ``MtfModel_ll``. Its model
  parallelism, batch size and checkpoint retention depend on
  ``--model_size``.

  Args:
    _: Unused positional argument (presumably the argv list passed by the
      app runner — confirm).

  Raises:
    NotImplementedError: if ``--use_model_api`` is not set.
    KeyError: if ``--model_size`` is not one of small, base, large, 3B, 11B.
    ValueError: if the checkpoint flags conflict or ``--mode`` is not one of
      finetune, eval or predict.
  """
  if FLAGS.module_import:
    for module in FLAGS.module_import:
      importlib.import_module(module)

  # Add search path for gin files stored in package.
  gin.add_config_file_search_path(
      pkg_resources.resource_filename(__name__, "gin"))

  models_dir_name = FLAGS.model_dir_name
  if FLAGS.model_dir_counter >= 0:
    models_dir_name += "_%s" % str(FLAGS.model_dir_counter)
  models_dir = os.path.join(FLAGS.base_dir, models_dir_name)
  model_dir = os.path.join(models_dir, FLAGS.model_size)

  try:
    tf.io.gfile.makedirs(model_dir)
    suffix = 0
    command_filename = os.path.join(model_dir, "command")
    while tf.io.gfile.exists(command_filename):
      suffix += 1
      command_filename = os.path.join(model_dir, "command.{}".format(suffix))
    with tf.io.gfile.GFile(command_filename, "w") as f:
      f.write(" ".join(sys.argv))
  except tf.errors.PermissionDeniedError:
    # Command logging is best effort; a read-only model_dir must not stop
    # the run.
    logging.info(
        "No write access to model directory. Skipping command logging.")

  utils.parse_gin_defaults_and_flags()

  # Load and print a few preprocessed validation examples as a sanity check.
  st_task = TaskRegistry_ll.get("processed_cctk")
  sequence_length = {"inputs": 64, "targets": 64}
  # Or "attribute": 1, but packing is not efficient then.
  sequence_length["attribute"] = 64
  sequence_length["codeprefixedtargets"] = 64
  sequence_length["controlcode"] = 64
  with gin.config_scope('caet5'):
    ds = st_task.get_dataset(split="validation",
                             sequence_length=sequence_length)
  print("A few preprocessed validation examples...")
  for ex in tfds.as_numpy(ds.take(5)):
    print(ex)

  if not FLAGS.use_model_api:
    raise NotImplementedError(
        "CAE-T5 only supports running with --use_model_api.")

  # Swap the T5 bitransformer builder and TPU estimator model_fn for their
  # CAE-T5 counterparts. These assignments mutate the imported modules, so
  # they affect every later user of `transformer` / `utils` in this process.
  transformer.make_bitransformer = make_bitransformer_ll
  utils.tpu_estimator_model_fn = tpu_estimator_model_fn_ll

  # Per-size (model_parallelism, train_batch_size, keep_checkpoint_max).
  model_parallelism, train_batch_size, keep_checkpoint_max = {
      "small": (1, 256, 16),
      "base": (2, 128, 8),
      "large": (8, 64, 4),
      "3B": (8, 16, 1),
      "11B": (8, 16, 1)
  }[FLAGS.model_size]

  model = MtfModel_ll(
      tpu_job_name=FLAGS.tpu_job_name,
      tpu=FLAGS.tpu,
      gcp_project=FLAGS.gcp_project,
      tpu_zone=FLAGS.tpu_zone,
      model_dir=model_dir,
      model_parallelism=model_parallelism,
      batch_size=train_batch_size,
      learning_rate_schedule=0.003,
      save_checkpoints_steps=2000,
      keep_checkpoint_max=keep_checkpoint_max,
      iterations_per_loop=100,
      model_type="bitransformer",
      unsupervised_attribute_transfer_metrics=True)

  if FLAGS.checkpoint_mode != "specific" and FLAGS.checkpoint_steps:
    raise ValueError(
        "checkpoint_mode is set to %s and checkpoint_steps is "
        "also set. To use a particular checkpoint, please set "
        "checkpoint_mode to 'specific'. For other modes, please "
        "ensure that checkpoint_steps is not set." % FLAGS.checkpoint_mode)

  # "latest" is encoded as -1, "all" as the string "all"; any other mode
  # uses the explicit list of integer steps.
  if FLAGS.checkpoint_mode == "latest":
    checkpoint_steps = -1
  elif FLAGS.checkpoint_mode == "all":
    checkpoint_steps = "all"
  else:
    checkpoint_steps = [int(c) for c in FLAGS.checkpoint_steps]

  if FLAGS.mode == "finetune":
    pretrained_dir = os.path.join(FLAGS.base_pretrained_model_dir,
                                  FLAGS.model_size)
    model.finetune(mixture_or_task_name=FLAGS.mixture_or_task,
                   pretrained_model_dir=pretrained_dir,
                   finetune_steps=FLAGS.train_steps)
  elif FLAGS.mode == "eval":
    # Evaluation runs with a batch four times the training batch size.
    model.batch_size = train_batch_size * 4
    model.eval(mixture_or_task_name=FLAGS.mixture_or_task,
               checkpoint_steps=checkpoint_steps,
               summary_dir=FLAGS.eval_summary_dir,
               split=FLAGS.eval_split)
  elif FLAGS.mode == "predict":
    if FLAGS.predict_batch_size > 0:
      model.batch_size = FLAGS.predict_batch_size
    # NOTE(review): temperature=0 presumably selects greedy decoding;
    # confirm in MtfModel_ll.predict.
    model.predict(checkpoint_steps=checkpoint_steps,
                  input_file=FLAGS.input_file,
                  output_file=FLAGS.output_file,
                  temperature=0)
  else:
    raise ValueError("--mode flag must be set when using Model API.")
def main(_):
  """Entry point for the T5 Mesh TensorFlow runner.

  Set-up performed before any work:
    * imports the modules listed in ``--module_import``;
    * applies ``--t5_tfds_data_dir`` (when set) and
      ``--additional_task_cache_dirs`` to ``t5.data``;
    * registers this package's ``gin`` resource directory with gin;
    * best-effort writes ``sys.argv`` to ``command``/``command.N`` under
      ``--model_dir``. This is skipped with an info log on
      ``PermissionDeniedError``;
    * parses gin defaults and flags.

  Without ``--use_model_api``, the legacy ``utils.run`` path is taken. When
  ``--tpu`` is unset, float32 variable dtypes are bound first. With
  ``--use_model_api``, ``--mode`` selects a method of
  ``mtf_model.MtfModel``: train, eval, finetune, predict, score or export.

  Args:
    _: Unused positional argument (presumably the argv list passed by the
      app runner — confirm).

  Raises:
    ValueError: in any of these cases:
      * the checkpoint flags conflict;
      * finetuning or exporting is requested without exactly one
        checkpoint;
      * ``--mode`` is missing or unknown under the Model API;
      * ``--mode`` is set without the Model API.
  """
  for module_name in FLAGS.module_import or ():
    importlib.import_module(module_name)

  tfds_dir = FLAGS.t5_tfds_data_dir
  if tfds_dir:
    t5.data.set_tfds_data_dir_override(tfds_dir)
  t5.data.add_global_cache_dirs(FLAGS.additional_task_cache_dirs)

  # Make gin config files bundled with this package resolvable by name.
  gin.add_config_file_search_path(
      pkg_resources.resource_filename(__name__, "gin"))

  def _command_path(index):
    # "command" first, then "command.1", "command.2", ... so earlier logs
    # are kept.
    name = "command.{}".format(index) if index else "command"
    return os.path.join(FLAGS.model_dir, name)

  try:
    tf.io.gfile.makedirs(FLAGS.model_dir)
    index = 0
    while tf.io.gfile.exists(_command_path(index)):
      index += 1
    with tf.io.gfile.GFile(_command_path(index), "w") as command_file:
      command_file.write(" ".join(sys.argv))
  except tf.errors.PermissionDeniedError:
    logging.info(
        "No write access to model directory. Skipping command logging.")

  utils.parse_gin_defaults_and_flags()

  if not FLAGS.use_model_api:
    if FLAGS.mode:
      raise ValueError(
          "--mode flag should only be set when using Model API.")
    if not FLAGS.tpu:
      with gin.unlock_config():
        gin.bind_parameter("utils.get_variable_dtype.slice_dtype", "float32")
        gin.bind_parameter("utils.get_variable_dtype.activation_dtype",
                           "float32")
    utils.run(
        tpu_job_name=FLAGS.tpu_job_name,
        tpu=FLAGS.tpu,
        gcp_project=FLAGS.gcp_project,
        tpu_zone=FLAGS.tpu_zone,
        model_dir=FLAGS.model_dir)
    return

  model = mtf_model.MtfModel(
      tpu_job_name=FLAGS.tpu_job_name,
      tpu=FLAGS.tpu,
      gcp_project=FLAGS.gcp_project,
      tpu_zone=FLAGS.tpu_zone,
      model_dir=FLAGS.model_dir)

  checkpoint_mode = FLAGS.checkpoint_mode
  if checkpoint_mode != "specific" and FLAGS.checkpoint_steps:
    raise ValueError(
        "checkpoint_mode is set to %s and checkpoint_steps is "
        "also set. To use a particular checkpoint, please set "
        "checkpoint_mode to 'specific'. For other modes, please "
        "ensure that checkpoint_steps is not set." % checkpoint_mode)

  # "latest" -> -1, "all" -> "all"; any other mode -> explicit int steps.
  checkpoint_steps = {"latest": -1, "all": "all"}.get(checkpoint_mode)
  if checkpoint_steps is None:
    checkpoint_steps = [int(step) for step in FLAGS.checkpoint_steps]

  def _single_checkpoint(error_message):
    """Return the one checkpoint step to use, raising ValueError if none."""
    if checkpoint_mode != "latest" and not (
        checkpoint_mode == "specific" and len(FLAGS.checkpoint_steps) == 1):
      raise ValueError(error_message)
    if isinstance(checkpoint_steps, list):
      return checkpoint_steps[0]
    return checkpoint_steps

  mode = FLAGS.mode
  if mode == "train":
    model.train(mixture_or_task_name=FLAGS.mixture_or_task,
                steps=FLAGS.train_steps)
  elif mode == "eval":
    model.eval(mixture_or_task_name=FLAGS.mixture_or_task,
               checkpoint_steps=checkpoint_steps,
               summary_dir=FLAGS.eval_summary_dir,
               split=FLAGS.eval_split)
  elif mode == "finetune":
    step = _single_checkpoint(
        "Must specify a single checkpoint for finetuning a model.")
    model.finetune(mixture_or_task_name=FLAGS.mixture_or_task,
                   steps=FLAGS.train_steps,
                   pretrained_model_dir=FLAGS.pretrained_model_dir,
                   checkpoint_steps=step)
  elif mode == "predict":
    model.predict(checkpoint_steps=checkpoint_steps,
                  input_file=FLAGS.input_file,
                  output_file=FLAGS.output_file)
  elif mode == "score":
    model.score(FLAGS.input_file,
                FLAGS.target_file,
                scores_file=FLAGS.output_file,
                checkpoint_steps=checkpoint_steps)
  elif mode == "export":
    step = _single_checkpoint(
        "Must specify a single checkpoint for exporting a model.")
    model.batch_size = FLAGS.export_batch_size
    model.export(export_dir=FLAGS.export_dir, checkpoint_step=step)
  else:
    raise ValueError("--mode flag must be set when using Model API.")
def main(_):
  """Entry point: set up data and gin, log the command, then run a mode.

  Set-up:
    * imports the modules listed in ``--module_import``;
    * applies ``--t5_tfds_data_dir`` (when set) and
      ``--additional_task_cache_dirs`` to ``t5.data``;
    * registers this package's ``gin`` resource directory with gin;
    * writes ``sys.argv`` to ``command``/``command.N`` under
      ``--model_dir``;
    * parses gin defaults and flags.

  With ``--use_model_api``, ``--mode`` selects train, eval or, for any
  other value, predict on an ``mtf_model.MtfModel``. Otherwise control is
  handed to ``utils.run``.

  Args:
    _: Unused positional argument (presumably the argv list passed by the
      app runner — confirm).

  Raises:
    ValueError: if ``--checkpoint_mode=latest`` and ``--model_dir`` holds
      no ``model.*index`` checkpoint files.
  """
  if FLAGS.module_import:
    for module in FLAGS.module_import:
      importlib.import_module(module)

  if FLAGS.t5_tfds_data_dir:
    # Pass the same flag that guards this call.
    t5.data.set_tfds_data_dir_override(FLAGS.t5_tfds_data_dir)
  t5.data.add_global_cache_dirs(FLAGS.additional_task_cache_dirs)

  # Add search path for gin files stored in package.
  gin.add_config_file_search_path(
      pkg_resources.resource_filename(__name__, "gin"))

  tf.io.gfile.makedirs(FLAGS.model_dir)
  suffix = 0
  command_filename = os.path.join(FLAGS.model_dir, "command")
  while tf.io.gfile.exists(command_filename):
    suffix += 1
    command_filename = os.path.join(FLAGS.model_dir,
                                    "command.{}".format(suffix))
  with tf.io.gfile.GFile(command_filename, "w") as f:
    f.write(" ".join(sys.argv))

  utils.parse_gin_defaults_and_flags()

  if FLAGS.use_model_api:
    model = mtf_model.MtfModel(tpu_job_name=FLAGS.tpu_job_name,
                               tpu=FLAGS.tpu,
                               gcp_project=FLAGS.gcp_project,
                               tpu_zone=FLAGS.tpu_zone,
                               model_dir=FLAGS.model_dir)

    if FLAGS.checkpoint_mode == "latest":
      # Join instead of concatenating so that a model_dir without a
      # trailing slash is still searched inside the directory.
      index_files = tf.io.gfile.glob(
          os.path.join(FLAGS.model_dir, "model.*index"))
      if not index_files:
        raise ValueError(
            "No checkpoint index files matching 'model.*index' found in %s."
            % FLAGS.model_dir)
      # Index file names are expected to end in "ckpt-<step>.index"; the
      # step is whatever lies between those two markers.
      checkpoint_step = max(
          int(re.sub(r".*ckpt-", "", path).replace(".index", ""))
          for path in index_files)
    elif FLAGS.checkpoint_mode == "all":
      checkpoint_step = "all"
    else:
      checkpoint_step = [int(c) for c in FLAGS.checkpoint_steps]

    if FLAGS.mode == "train":
      model.train(mixture_or_task_name=FLAGS.mixture_or_task,
                  steps=FLAGS.train_steps)
    elif FLAGS.mode == "eval":
      model.eval(mixture_or_task_name=FLAGS.mixture_or_task,
                 checkpoint_step=checkpoint_step,
                 summary_dir=FLAGS.eval_summary_dir,
                 split=FLAGS.eval_split)
    else:
      model.predict(checkpoint_step=checkpoint_step,
                    input_file=FLAGS.input_file,
                    output_file=FLAGS.output_file)
  else:
    utils.run(tpu_job_name=FLAGS.tpu_job_name,
              tpu=FLAGS.tpu,
              gcp_project=FLAGS.gcp_project,
              tpu_zone=FLAGS.tpu_zone,
              model_dir=FLAGS.model_dir)