def main(_):
    """Log the launch command under --model_dir and start the run.

    In order: imports every module named in --module_import, passes
    --tasks_cache_dir to `t5.data.set_global_cache_dirs` when it is set,
    registers the packaged "gin" directory as a gin config search path,
    writes the space-joined `sys.argv` to a not-yet-existing "command" /
    "command.N" file inside --model_dir, parses gin defaults and flags, and
    finally calls `utils.run` with the TPU-related flags.

    Args:
      _: unused positional argument.
    """
    for module_name in FLAGS.module_import or ():
        importlib.import_module(module_name)

    cache_dirs = FLAGS.tasks_cache_dir
    if cache_dirs:
        t5.data.set_global_cache_dirs(cache_dirs)

    # Add search path for gin files stored in package.
    packaged_gin_dir = pkg_resources.resource_filename(__name__, "gin")
    gin.add_config_file_search_path(packaged_gin_dir)

    model_dir = FLAGS.model_dir
    tf.io.gfile.makedirs(model_dir)

    # Probe "command", "command.1", "command.2", ... and stop at the first
    # name that does not exist yet, so earlier command logs are kept.
    attempt = 0
    while True:
        basename = "command.{}".format(attempt) if attempt else "command"
        command_path = os.path.join(model_dir, basename)
        if not tf.io.gfile.exists(command_path):
            break
        attempt += 1

    with tf.io.gfile.GFile(command_path, "w") as command_file:
        command_file.write(" ".join(sys.argv))

    utils.parse_gin_defaults_and_flags()
    utils.run(
        tpu_job_name=FLAGS.tpu_job_name,
        tpu=FLAGS.tpu,
        gcp_project=FLAGS.gcp_project,
        tpu_zone=FLAGS.tpu_zone,
        model_dir=FLAGS.model_dir,
    )
Пример #2
0
def main(_):
    """Write the launch command to --model_dir/command and start the run.

    Imports every module named in --module_import, creates --model_dir,
    writes the space-joined `sys.argv` to a file named "command" in it (the
    file is opened in "w" mode, so an existing one is replaced), parses gin
    defaults and flags, and calls `utils.run` with the TPU-related flags.

    Args:
      _: unused positional argument.
    """
    for module_name in FLAGS.module_import or ():
        importlib.import_module(module_name)

    tf.io.gfile.makedirs(FLAGS.model_dir)
    command_path = os.path.join(FLAGS.model_dir, "command")
    command_line = " ".join(sys.argv)
    with tf.io.gfile.GFile(command_path, "w") as command_file:
        command_file.write(command_line)

    utils.parse_gin_defaults_and_flags()
    utils.run(
        tpu_job_name=FLAGS.tpu_job_name,
        tpu=FLAGS.tpu,
        gcp_project=FLAGS.gcp_project,
        tpu_zone=FLAGS.tpu_zone,
        model_dir=FLAGS.model_dir,
    )
Пример #3
0
def main(_):
    """Log the launch command under --model_dir and start the run.

    Imports every module named in --module_import, creates --model_dir,
    writes the space-joined `sys.argv` to the first of "command",
    "command.1", "command.2", ... that does not exist yet, parses gin
    defaults and flags, and calls `utils.run` with the TPU-related flags.

    Args:
      _: unused positional argument.
    """
    for module_name in FLAGS.module_import or ():
        importlib.import_module(module_name)

    model_dir = FLAGS.model_dir
    tf.io.gfile.makedirs(model_dir)

    # Find a file name that has not been used by a previous invocation.
    attempt = 0
    while True:
        basename = "command.{}".format(attempt) if attempt else "command"
        command_path = os.path.join(model_dir, basename)
        if not tf.io.gfile.exists(command_path):
            break
        attempt += 1

    with tf.io.gfile.GFile(command_path, "w") as command_file:
        command_file.write(" ".join(sys.argv))

    utils.parse_gin_defaults_and_flags()
    utils.run(
        tpu_job_name=FLAGS.tpu_job_name,
        tpu=FLAGS.tpu,
        gcp_project=FLAGS.gcp_project,
        tpu_zone=FLAGS.tpu_zone,
        model_dir=FLAGS.model_dir,
    )
Пример #4
0
def main(_):
    """Set up data/gin configuration and run either `utils.run` or the Model API.

    In order: imports every module named in --module_import, applies the
    TFDS data-dir override when --t5_tfds_data_dir is set, adds
    --additional_task_cache_dirs as global cache dirs, registers the packaged
    "gin" directory as a gin search path, logs the command line to a fresh
    "command[.N]" file in --model_dir (skipped with an info log on a
    permission error), and parses gin defaults and flags.

    Without --use_model_api the work is delegated to `utils.run`. With it, an
    `mtf_model.MtfModel` is built and --mode selects one of train, eval,
    finetune, predict or export.

    Args:
      _: unused positional argument.

    Raises:
      ValueError: if finetune/export is requested without exactly one
        checkpoint, or if --mode is not one of the supported modes when the
        Model API is used (previously this case silently did nothing).
    """
    for module_name in FLAGS.module_import or ():
        importlib.import_module(module_name)

    if FLAGS.t5_tfds_data_dir:
        t5.data.set_tfds_data_dir_override(FLAGS.t5_tfds_data_dir)
    t5.data.add_global_cache_dirs(FLAGS.additional_task_cache_dirs)

    # Add search path for gin files stored in package.
    gin.add_config_file_search_path(
        pkg_resources.resource_filename(__name__, "gin"))

    # Command logging is best effort: a read-only model directory must not
    # prevent the run itself.
    try:
        tf.io.gfile.makedirs(FLAGS.model_dir)
        suffix = 0
        command_filename = os.path.join(FLAGS.model_dir, "command")
        while tf.io.gfile.exists(command_filename):
            suffix += 1
            command_filename = os.path.join(FLAGS.model_dir,
                                            "command.{}".format(suffix))
        with tf.io.gfile.GFile(command_filename, "w") as f:
            f.write(" ".join(sys.argv))
    except tf.errors.PermissionDeniedError:
        logging.info(
            "No write access to model directory. Skipping command logging.")

    utils.parse_gin_defaults_and_flags()

    if not FLAGS.use_model_api:
        utils.run(tpu_job_name=FLAGS.tpu_job_name,
                  tpu=FLAGS.tpu,
                  gcp_project=FLAGS.gcp_project,
                  tpu_zone=FLAGS.tpu_zone,
                  model_dir=FLAGS.model_dir)
        return

    model = mtf_model.MtfModel(tpu_job_name=FLAGS.tpu_job_name,
                               tpu=FLAGS.tpu,
                               gcp_project=FLAGS.gcp_project,
                               tpu_zone=FLAGS.tpu_zone,
                               model_dir=FLAGS.model_dir)

    # -1 is passed through for "latest", the string "all" for every
    # checkpoint, otherwise the explicit list of steps from the flag.
    if FLAGS.checkpoint_mode == "latest":
        checkpoint_steps = -1
    elif FLAGS.checkpoint_mode == "all":
        checkpoint_steps = "all"
    else:
        checkpoint_steps = [int(c) for c in FLAGS.checkpoint_steps]

    def single_checkpoint(error_message):
        """Return the one requested checkpoint step or raise ValueError."""
        if not (FLAGS.checkpoint_mode == "latest" or
                (FLAGS.checkpoint_mode == "specific"
                 and len(FLAGS.checkpoint_steps) == 1)):
            raise ValueError(error_message)
        if isinstance(checkpoint_steps, list):
            return checkpoint_steps[0]
        return checkpoint_steps

    if FLAGS.mode == "train":
        model.train(mixture_or_task_name=FLAGS.mixture_or_task,
                    steps=FLAGS.train_steps)
    elif FLAGS.mode == "eval":
        model.eval(mixture_or_task_name=FLAGS.mixture_or_task,
                   checkpoint_steps=checkpoint_steps,
                   summary_dir=FLAGS.eval_summary_dir,
                   split=FLAGS.eval_split)
    elif FLAGS.mode == "finetune":
        finetune_step = single_checkpoint(
            "Must specify a single checkpoint for finetuning a model.")
        model.finetune(mixture_or_task_name=FLAGS.mixture_or_task,
                       steps=FLAGS.train_steps,
                       pretrained_model_dir=FLAGS.pretrained_model_dir,
                       checkpoint_steps=finetune_step)
    elif FLAGS.mode == "predict":
        model.predict(checkpoint_steps=checkpoint_steps,
                      input_file=FLAGS.input_file,
                      output_file=FLAGS.output_file)
    elif FLAGS.mode == "export":
        export_step = single_checkpoint(
            "Must specify a single checkpoint for exporting a model.")
        model.export(export_dir=FLAGS.export_dir,
                     checkpoint_step=export_step)
    else:
        # Fail loudly instead of building a model and then doing nothing.
        raise ValueError("--mode flag must be set when using Model API.")
Пример #5
0
def main(_):
    """Configure data and gin, then run `utils.run` or the Model API.

    In order: imports every module named in --module_import, applies the
    TFDS data-dir override when --t5_tfds_data_dir is set, registers the
    packaged "gin" directory as a gin search path, logs the command line to a
    fresh "command[.N]" file in <model_dir>/commands (skipped with an info
    log on a permission or invalid-argument error), parses gin with the
    configured unknown-reference skipping, rebinds "run.vocabulary", finalizes
    gin, and adds --additional_task_cache_dirs as global cache dirs.

    With --use_model_api an `mtf_model.MtfModel` is built and --mode selects
    train, eval, finetune, predict, score, export_predict or export_score.
    Otherwise `utils.run` is called.

    Args:
      _: unused positional argument.

    Raises:
      ValueError: on an inconsistent --checkpoint_mode / --checkpoint_steps
        pair, when finetune/export does not resolve to a single checkpoint,
        when --mode is missing or unsupported with the Model API, or when
        --mode is set without the Model API.
    """
    for module_name in FLAGS.module_import or ():
        importlib.import_module(module_name)

    tfds_dir = FLAGS.t5_tfds_data_dir
    if tfds_dir:
        t5.data.set_tfds_data_dir_override(tfds_dir)

    # Add search path for gin files stored in package.
    packaged_gin_dir = pkg_resources.resource_filename(__name__, "gin")
    gin.add_config_file_search_path(packaged_gin_dir)

    # Best-effort command logging: failures to write are logged, not raised.
    try:
        command_dir = os.path.join(FLAGS.model_dir, "commands")
        tf.io.gfile.makedirs(command_dir)
        attempt = 0
        while True:
            basename = "command.{}".format(attempt) if attempt else "command"
            command_path = os.path.join(command_dir, basename)
            if not tf.io.gfile.exists(command_path):
                break
            attempt += 1
        with tf.io.gfile.GFile(command_path, "w") as command_file:
            command_file.write(" ".join(sys.argv))
    except (tf.errors.PermissionDeniedError, tf.errors.InvalidArgumentError):
        logging.info(
            "No write access to model directory. Skipping command logging.")

    # Either skip every unknown gin reference, or only the deprecated ones.
    skip_unknown = (
        FLAGS.skip_all_gin_unknowns
        or (mesh_transformer.DEPRECATED_GIN_REFERENCES +
            tuple(FLAGS.additional_deprecated_gin_references)))
    utils.parse_gin_defaults_and_flags(skip_unknown=skip_unknown,
                                       finalize_config=False)
    # We must override this binding explicitly since it is set to a deprecated
    # function or class in many existing configs.
    gin.bind_parameter("run.vocabulary", mesh_transformer.get_vocabulary())
    gin.finalize()

    # Set cache dir after loading gin to avoid unintentionally overriding it.
    t5.data.add_global_cache_dirs(FLAGS.additional_task_cache_dirs)

    if not FLAGS.use_model_api:
        if FLAGS.mode:
            raise ValueError(
                "--mode flag should only be set when using Model API.")
        if not FLAGS.tpu:
            # Without a TPU, bind both variable dtypes to float32.
            with gin.unlock_config():
                for dtype_binding in (
                        "utils.get_variable_dtype.slice_dtype",
                        "utils.get_variable_dtype.activation_dtype"):
                    gin.bind_parameter(dtype_binding, "float32")
        utils.run(
            tpu_job_name=FLAGS.tpu_job_name,
            tpu=FLAGS.tpu,
            gcp_project=FLAGS.gcp_project,
            tpu_zone=FLAGS.tpu_zone,
            model_dir=FLAGS.model_dir,
        )
        return

    model = mtf_model.MtfModel(
        tpu_job_name=FLAGS.tpu_job_name,
        tpu=FLAGS.tpu,
        gcp_project=FLAGS.gcp_project,
        tpu_zone=FLAGS.tpu_zone,
        tpu_topology=FLAGS.tpu_topology,
        model_parallelism=FLAGS.model_parallelism,
        model_dir=FLAGS.model_dir,
        batch_size=FLAGS.batch_size,
        sequence_length={
            "inputs": FLAGS.input_sequence_length,
            "targets": FLAGS.target_sequence_length,
        },
    )

    if FLAGS.checkpoint_mode != "specific" and FLAGS.checkpoint_steps:
        raise ValueError(
            "checkpoint_mode is set to %s and checkpoint_steps is "
            "also set. To use a particular checkpoint, please set "
            "checkpoint_mode to 'specific'. For other modes, please "
            "ensure that checkpoint_steps is not set." %
            FLAGS.checkpoint_mode)

    if FLAGS.checkpoint_mode == "latest":
        checkpoint_steps = -1
    elif FLAGS.checkpoint_mode == "all":
        checkpoint_steps = "all"
    else:
        checkpoint_steps = [int(step) for step in FLAGS.checkpoint_steps]

    def single_checkpoint(error_message):
        """Return the one requested checkpoint step or raise ValueError."""
        is_single = (FLAGS.checkpoint_mode == "latest" or
                     (FLAGS.checkpoint_mode == "specific"
                      and len(FLAGS.checkpoint_steps) == 1))
        if not is_single:
            raise ValueError(error_message)
        if isinstance(checkpoint_steps, list):
            return checkpoint_steps[0]
        return checkpoint_steps

    mode = FLAGS.mode
    if mode == "train":
        model.train(mixture_or_task_name=FLAGS.mixture_or_task,
                    steps=FLAGS.train_steps)
    elif mode == "eval":
        model.eval(mixture_or_task_name=FLAGS.mixture_or_task,
                   checkpoint_steps=checkpoint_steps,
                   summary_dir=FLAGS.eval_summary_dir,
                   split=FLAGS.eval_split)
    elif mode == "finetune":
        finetune_step = single_checkpoint(
            "Must specify a single checkpoint for finetuning a model.")
        model.finetune(mixture_or_task_name=FLAGS.mixture_or_task,
                       steps=FLAGS.train_steps,
                       pretrained_model_dir=FLAGS.pretrained_model_dir,
                       checkpoint_steps=finetune_step)
    elif mode == "predict":
        model.predict(checkpoint_steps=checkpoint_steps,
                      input_file=FLAGS.input_file,
                      output_file=FLAGS.output_file,
                      beam_size=FLAGS.beam_size,
                      temperature=FLAGS.temperature,
                      keep_top_k=FLAGS.keep_top_k)
    elif mode == "score":
        model.score(FLAGS.input_file,
                    FLAGS.target_file,
                    scores_file=FLAGS.output_file,
                    checkpoint_steps=checkpoint_steps)
    elif mode in ("export_predict", "export_score"):
        export_step = single_checkpoint(
            "Must specify a single checkpoint for exporting a model.")
        model.export(export_dir=FLAGS.export_dir,
                     checkpoint_step=export_step,
                     beam_size=FLAGS.beam_size,
                     temperature=FLAGS.temperature,
                     keep_top_k=FLAGS.keep_top_k,
                     eval_with_score=(FLAGS.mode == "export_score"))
    else:
        raise ValueError("--mode flag must be set when using Model API.")
Пример #6
0
def main(_):
    """Preview the "processed_cctk" task, then run CAE-T5 via the Model API.

    Steps, in order:
      1. Import every module named in --module_import.
      2. Register the packaged "gin" directory as a gin search path.
      3. Build the model directory path from --base_dir, --model_dir_name
         (plus "_<counter>" when --model_dir_counter >= 0) and --model_size,
         and log the command line to a fresh "command[.N]" file there
         (skipped with an info log on a permission error).
      4. Parse gin defaults and flags.
      5. Print five preprocessed validation examples of "processed_cctk".
      6. Patch the T5 bitransformer/model-fn hooks, build an `MtfModel_ll`
         with per-size hyperparameters and run --mode finetune, eval or
         predict.

    Args:
      _: unused positional argument.

    Raises:
      NotImplementedError: if --use_model_api is not set.
      KeyError: if --model_size is not one of the sizes in the table below.
      ValueError: if --checkpoint_steps is given while --checkpoint_mode is
        not "specific", or if --mode is not finetune, eval or predict.
    """
    for module_name in FLAGS.module_import or ():
        importlib.import_module(module_name)

    # Add search path for gin files stored in package.
    gin.add_config_file_search_path(
        pkg_resources.resource_filename(__name__, "gin"))

    models_dir_name = FLAGS.model_dir_name
    if FLAGS.model_dir_counter >= 0:
        models_dir_name += "_%s" % str(FLAGS.model_dir_counter)
    models_dir = os.path.join(FLAGS.base_dir, models_dir_name)
    model_dir = os.path.join(models_dir, FLAGS.model_size)

    # Command logging is best effort: a read-only model directory must not
    # prevent the run itself.
    try:
        tf.io.gfile.makedirs(model_dir)
        suffix = 0
        command_filename = os.path.join(model_dir, "command")
        while tf.io.gfile.exists(command_filename):
            suffix += 1
            command_filename = os.path.join(model_dir,
                                            "command.{}".format(suffix))
        with tf.io.gfile.GFile(command_filename, "w") as f:
            f.write(" ".join(sys.argv))
    except tf.errors.PermissionDeniedError:
        logging.info(
            "No write access to model directory. Skipping command logging.")

    utils.parse_gin_defaults_and_flags()

    # Load and print a few examples.
    st_task = TaskRegistry_ll.get("processed_cctk")
    sequence_length = {
        "inputs": 64,
        "targets": 64,
        "attribute": 64,  # Or "attribute": 1 but packing not efficient...
        "codeprefixedtargets": 64,
        "controlcode": 64,
    }

    with gin.config_scope('caet5'):
        ds = st_task.get_dataset(split="validation",
                                 sequence_length=sequence_length)

    print("A few preprocessed validation examples...")
    for ex in tfds.as_numpy(ds.take(5)):
        print(ex)

    if not FLAGS.use_model_api:
        raise NotImplementedError()

    # Modifying original T5 in CAE-T5: these module attributes are replaced
    # before the model is constructed.
    transformer.make_bitransformer = make_bitransformer_ll
    utils.tpu_estimator_model_fn = tpu_estimator_model_fn_ll

    # model size -> (model_parallelism, train_batch_size, keep_checkpoint_max)
    model_parallelism, train_batch_size, keep_checkpoint_max = {
        "small": (1, 256, 16),
        "base": (2, 128, 8),
        "large": (8, 64, 4),
        "3B": (8, 16, 1),
        "11B": (8, 16, 1)
    }[FLAGS.model_size]

    model = MtfModel_ll(
        tpu_job_name=FLAGS.tpu_job_name,
        tpu=FLAGS.tpu,
        gcp_project=FLAGS.gcp_project,
        tpu_zone=FLAGS.tpu_zone,
        model_dir=model_dir,
        model_parallelism=model_parallelism,
        batch_size=train_batch_size,
        learning_rate_schedule=0.003,
        save_checkpoints_steps=2000,
        keep_checkpoint_max=keep_checkpoint_max,  # if ON_CLOUD else None,
        iterations_per_loop=100,
        model_type="bitransformer",
        unsupervised_attribute_transfer_metrics=True)

    if FLAGS.checkpoint_mode != "specific" and FLAGS.checkpoint_steps:
        raise ValueError(
            "checkpoint_mode is set to %s and checkpoint_steps is "
            "also set. To use a particular checkpoint, please set "
            "checkpoint_mode to 'specific'. For other modes, please "
            "ensure that checkpoint_steps is not set." %
            FLAGS.checkpoint_mode)

    if FLAGS.checkpoint_mode == "latest":
        checkpoint_steps = -1
    elif FLAGS.checkpoint_mode == "all":
        checkpoint_steps = "all"
    else:
        checkpoint_steps = [int(c) for c in FLAGS.checkpoint_steps]

    if FLAGS.mode == "finetune":
        pretrained_dir = os.path.join(FLAGS.base_pretrained_model_dir,
                                      FLAGS.model_size)

        model.finetune(mixture_or_task_name=FLAGS.mixture_or_task,
                       pretrained_model_dir=pretrained_dir,
                       finetune_steps=FLAGS.train_steps)

    elif FLAGS.mode == "eval":
        # Evaluation uses four times the training batch size.
        model.batch_size = train_batch_size * 4
        model.eval(mixture_or_task_name=FLAGS.mixture_or_task,
                   checkpoint_steps=checkpoint_steps,
                   summary_dir=FLAGS.eval_summary_dir,
                   split=FLAGS.eval_split)

    elif FLAGS.mode == "predict":
        # A non-positive --predict_batch_size keeps the training batch size.
        if FLAGS.predict_batch_size > 0:
            model.batch_size = FLAGS.predict_batch_size
        model.predict(checkpoint_steps=checkpoint_steps,
                      input_file=FLAGS.input_file,
                      output_file=FLAGS.output_file,
                      temperature=0)
    else:
        raise ValueError("--mode flag must be set when using Model API.")
Пример #7
0
def main(_):
    """Configure data and gin, then run `utils.run` or the Model API.

    In order: imports every module named in --module_import, applies the
    TFDS data-dir override when --t5_tfds_data_dir is set, adds
    --additional_task_cache_dirs as global cache dirs, registers the packaged
    "gin" directory as a gin search path, logs the command line to a fresh
    "command[.N]" file in --model_dir (skipped with an info log on a
    permission error), and parses gin defaults and flags.

    With --use_model_api an `mtf_model.MtfModel` is built and --mode selects
    train, eval, finetune, predict, score or export. Otherwise `utils.run`
    is called.

    Args:
      _: unused positional argument.

    Raises:
      ValueError: on an inconsistent --checkpoint_mode / --checkpoint_steps
        pair, when finetune/export does not resolve to a single checkpoint,
        when --mode is missing or unsupported with the Model API, or when
        --mode is set without the Model API.
    """
    for module_name in FLAGS.module_import or ():
        importlib.import_module(module_name)

    tfds_dir = FLAGS.t5_tfds_data_dir
    if tfds_dir:
        t5.data.set_tfds_data_dir_override(tfds_dir)
    t5.data.add_global_cache_dirs(FLAGS.additional_task_cache_dirs)

    # Add search path for gin files stored in package.
    packaged_gin_dir = pkg_resources.resource_filename(__name__, "gin")
    gin.add_config_file_search_path(packaged_gin_dir)

    # Best-effort command logging: a permission error is logged, not raised.
    try:
        tf.io.gfile.makedirs(FLAGS.model_dir)
        attempt = 0
        while True:
            basename = "command.{}".format(attempt) if attempt else "command"
            command_path = os.path.join(FLAGS.model_dir, basename)
            if not tf.io.gfile.exists(command_path):
                break
            attempt += 1
        with tf.io.gfile.GFile(command_path, "w") as command_file:
            command_file.write(" ".join(sys.argv))
    except tf.errors.PermissionDeniedError:
        logging.info(
            "No write access to model directory. Skipping command logging.")

    utils.parse_gin_defaults_and_flags()

    if not FLAGS.use_model_api:
        if FLAGS.mode:
            raise ValueError(
                "--mode flag should only be set when using Model API.")
        if not FLAGS.tpu:
            # Without a TPU, bind both variable dtypes to float32.
            with gin.unlock_config():
                for dtype_binding in (
                        "utils.get_variable_dtype.slice_dtype",
                        "utils.get_variable_dtype.activation_dtype"):
                    gin.bind_parameter(dtype_binding, "float32")
        utils.run(
            tpu_job_name=FLAGS.tpu_job_name,
            tpu=FLAGS.tpu,
            gcp_project=FLAGS.gcp_project,
            tpu_zone=FLAGS.tpu_zone,
            model_dir=FLAGS.model_dir,
        )
        return

    model = mtf_model.MtfModel(
        tpu_job_name=FLAGS.tpu_job_name,
        tpu=FLAGS.tpu,
        gcp_project=FLAGS.gcp_project,
        tpu_zone=FLAGS.tpu_zone,
        model_dir=FLAGS.model_dir,
    )

    if FLAGS.checkpoint_mode != "specific" and FLAGS.checkpoint_steps:
        raise ValueError(
            "checkpoint_mode is set to %s and checkpoint_steps is "
            "also set. To use a particular checkpoint, please set "
            "checkpoint_mode to 'specific'. For other modes, please "
            "ensure that checkpoint_steps is not set." %
            FLAGS.checkpoint_mode)

    if FLAGS.checkpoint_mode == "latest":
        checkpoint_steps = -1
    elif FLAGS.checkpoint_mode == "all":
        checkpoint_steps = "all"
    else:
        checkpoint_steps = [int(step) for step in FLAGS.checkpoint_steps]

    def single_checkpoint(error_message):
        """Return the one requested checkpoint step or raise ValueError."""
        is_single = (FLAGS.checkpoint_mode == "latest" or
                     (FLAGS.checkpoint_mode == "specific"
                      and len(FLAGS.checkpoint_steps) == 1))
        if not is_single:
            raise ValueError(error_message)
        if isinstance(checkpoint_steps, list):
            return checkpoint_steps[0]
        return checkpoint_steps

    mode = FLAGS.mode
    if mode == "train":
        model.train(mixture_or_task_name=FLAGS.mixture_or_task,
                    steps=FLAGS.train_steps)
    elif mode == "eval":
        model.eval(mixture_or_task_name=FLAGS.mixture_or_task,
                   checkpoint_steps=checkpoint_steps,
                   summary_dir=FLAGS.eval_summary_dir,
                   split=FLAGS.eval_split)
    elif mode == "finetune":
        finetune_step = single_checkpoint(
            "Must specify a single checkpoint for finetuning a model.")
        model.finetune(mixture_or_task_name=FLAGS.mixture_or_task,
                       steps=FLAGS.train_steps,
                       pretrained_model_dir=FLAGS.pretrained_model_dir,
                       checkpoint_steps=finetune_step)
    elif mode == "predict":
        model.predict(checkpoint_steps=checkpoint_steps,
                      input_file=FLAGS.input_file,
                      output_file=FLAGS.output_file)
    elif mode == "score":
        model.score(FLAGS.input_file,
                    FLAGS.target_file,
                    scores_file=FLAGS.output_file,
                    checkpoint_steps=checkpoint_steps)
    elif mode == "export":
        export_step = single_checkpoint(
            "Must specify a single checkpoint for exporting a model.")
        model.batch_size = FLAGS.export_batch_size
        model.export(export_dir=FLAGS.export_dir,
                     checkpoint_step=export_step)
    else:
        raise ValueError("--mode flag must be set when using Model API.")
Пример #8
0
def main(_):
    """Configure data and gin, then run `utils.run` or the Model API.

    In order: imports every module named in --module_import, applies the
    TFDS data-dir override when --t5_tfds_data_dir is set, adds
    --additional_task_cache_dirs as global cache dirs, registers the packaged
    "gin" directory as a gin search path, logs the command line to a fresh
    "command[.N]" file in --model_dir, and parses gin defaults and flags.

    With --use_model_api an `mtf_model.MtfModel` is built; --mode "train"
    trains, "eval" evaluates, and any other mode runs prediction. Otherwise
    `utils.run` is called.

    Args:
      _: unused positional argument.

    Raises:
      ValueError: if --checkpoint_mode is "latest" but no
        "model.*index" file is found in --model_dir.
    """
    for module_name in FLAGS.module_import or ():
        importlib.import_module(module_name)

    # Pass the same flag that is tested; the original read a differently
    # named flag (`tfds_data_dir`) here.
    if FLAGS.t5_tfds_data_dir:
        t5.data.set_tfds_data_dir_override(FLAGS.t5_tfds_data_dir)
    t5.data.add_global_cache_dirs(FLAGS.additional_task_cache_dirs)

    # Add search path for gin files stored in package.
    gin.add_config_file_search_path(
        pkg_resources.resource_filename(__name__, "gin"))

    tf.io.gfile.makedirs(FLAGS.model_dir)
    # Pick the first of "command", "command.1", ... that does not exist yet
    # so earlier command logs are kept.
    suffix = 0
    command_filename = os.path.join(FLAGS.model_dir, "command")
    while tf.io.gfile.exists(command_filename):
        suffix += 1
        command_filename = os.path.join(FLAGS.model_dir,
                                        "command.{}".format(suffix))
    with tf.io.gfile.GFile(command_filename, "w") as f:
        f.write(" ".join(sys.argv))

    utils.parse_gin_defaults_and_flags()

    if not FLAGS.use_model_api:
        utils.run(tpu_job_name=FLAGS.tpu_job_name,
                  tpu=FLAGS.tpu,
                  gcp_project=FLAGS.gcp_project,
                  tpu_zone=FLAGS.tpu_zone,
                  model_dir=FLAGS.model_dir)
        return

    model = mtf_model.MtfModel(tpu_job_name=FLAGS.tpu_job_name,
                               tpu=FLAGS.tpu,
                               gcp_project=FLAGS.gcp_project,
                               tpu_zone=FLAGS.tpu_zone,
                               model_dir=FLAGS.model_dir)

    if FLAGS.checkpoint_mode == "latest":
        # Join with a path separator so the pattern matches files inside
        # model_dir even when the flag has no trailing slash.
        index_files = tf.io.gfile.glob(
            os.path.join(FLAGS.model_dir, "model.*index"))
        # Strip everything up to "ckpt-" and the ".index" suffix to get the
        # step number from each file name.
        steps = [
            int(re.sub(r".*ckpt-", "", path).replace(".index", ""))
            for path in index_files
        ]
        if not steps:
            raise ValueError(
                "No checkpoints found in model directory %s." %
                FLAGS.model_dir)
        checkpoint_step = max(steps)
    elif FLAGS.checkpoint_mode == "all":
        checkpoint_step = "all"
    else:
        checkpoint_step = [int(c) for c in FLAGS.checkpoint_steps]

    if FLAGS.mode == "train":
        model.train(mixture_or_task_name=FLAGS.mixture_or_task,
                    steps=FLAGS.train_steps)
    elif FLAGS.mode == "eval":
        model.eval(mixture_or_task_name=FLAGS.mixture_or_task,
                   checkpoint_step=checkpoint_step,
                   summary_dir=FLAGS.eval_summary_dir,
                   split=FLAGS.eval_split)
    else:
        model.predict(checkpoint_step=checkpoint_step,
                      input_file=FLAGS.input_file,
                      output_file=FLAGS.output_file)