def main(_):
    """Entry point: import user modules, log the command line, then run.

    Modules named by --module_import are imported first, and the task cache
    directories are set when --tasks_cache_dir is given. The exact command
    line is then saved to the first unused file among ``command``,
    ``command.1``, ``command.2``, ... under --model_dir. Finally, gin defaults
    and flags are parsed and control is handed to ``utils.run``.
    """
    for module_name in FLAGS.module_import or ():
        importlib.import_module(module_name)

    cache_dirs = FLAGS.tasks_cache_dir
    if cache_dirs:
        t5.data.set_global_cache_dirs(cache_dirs)

    # Let gin resolve config files that ship in this package's "gin" folder.
    packaged_gin_dir = pkg_resources.resource_filename(__name__, "gin")
    gin.add_config_file_search_path(packaged_gin_dir)

    model_dir = FLAGS.model_dir
    tf.io.gfile.makedirs(model_dir)
    # Probe command, command.1, command.2, ... so earlier logs are kept.
    command_path = os.path.join(model_dir, "command")
    attempt = 0
    while tf.io.gfile.exists(command_path):
        attempt += 1
        command_path = os.path.join(model_dir, "command.{}".format(attempt))
    with tf.io.gfile.GFile(command_path, "w") as command_file:
        command_file.write(" ".join(sys.argv))

    utils.parse_gin_defaults_and_flags()
    utils.run(
        tpu_job_name=FLAGS.tpu_job_name,
        tpu=FLAGS.tpu,
        gcp_project=FLAGS.gcp_project,
        tpu_zone=FLAGS.tpu_zone,
        model_dir=model_dir,
    )
예제 #2
0
def main(_):
    """Parse the packaged default gin config plus user overrides, then run.

    The default config file (``_DEFAULT_CONFIG_FILE`` next to this module) is
    parsed first, so values from --gin_file / --gin_param override it.
    """
    default_config_path = os.path.join(
        os.path.dirname(__file__), _DEFAULT_CONFIG_FILE)
    gin.parse_config_file(default_config_path)
    gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_param)

    run_kwargs = dict(
        tpu_job_name=FLAGS.tpu_job_name,
        tpu=FLAGS.tpu,
        gcp_project=FLAGS.gcp_project,
        tpu_zone=FLAGS.tpu_zone,
        model_dir=FLAGS.model_dir,
    )
    utils.run(**run_kwargs)
예제 #3
0
def run(**kwargs) -> None:
    """Runs a T5 model for training, finetuning, evaluation etc.

    Thin wrapper around ``utils.run``. It disables TF2 behavior and, when the
    gin-bound ``utils.run.mode`` is "eval", makes sure Python's recursion
    limit is large enough before delegating.

    Args:
      **kwargs: forwarded unchanged to ``utils.run``.
    """
    tf.disable_v2_behavior()

    if gin.query_parameter("utils.run.mode") == "eval":
        # Increase the recursion limit, see: https://github.com/pltrdy/rouge/issues/19
        length = gin.query_parameter("utils.run.sequence_length").get(
            "inputs", 512)
        batch_size = 1024  # TODO: do not hardcode batch_size for recursionlimit calc
        required_limit = batch_size * length + 10
        # Only ever raise the limit; never lower one that is already larger.
        if required_limit > sys.getrecursionlimit():
            sys.setrecursionlimit(required_limit)

    utils.run(**kwargs)
예제 #4
0
def main(_):
    """Import extra modules, record the invocation, parse gin, then run.

    The command line is written to the first unused file among ``command``,
    ``command.1``, ``command.2``, ... in --model_dir, so the logs of earlier
    invocations in the same directory are not overwritten.
    """
    if FLAGS.module_import:
        for module in FLAGS.module_import:
            importlib.import_module(module)

    tf.io.gfile.makedirs(FLAGS.model_dir)
    # Pick a fresh file name instead of clobbering a previous run's log.
    suffix = 0
    command_filename = os.path.join(FLAGS.model_dir, "command")
    while tf.io.gfile.exists(command_filename):
        suffix += 1
        command_filename = os.path.join(FLAGS.model_dir,
                                        "command.{}".format(suffix))
    with tf.io.gfile.GFile(command_filename, "w") as f:
        f.write(" ".join(sys.argv))

    utils.parse_gin_defaults_and_flags()
    utils.run(tpu_job_name=FLAGS.tpu_job_name,
              tpu=FLAGS.tpu,
              gcp_project=FLAGS.gcp_project,
              tpu_zone=FLAGS.tpu_zone,
              model_dir=FLAGS.model_dir)
예제 #5
0
def main(_):
    """Load gin configuration, build the TFDS dataset and run the model.

    The packaged default config (``_DEFAULT_CONFIG_FILE`` next to this module)
    is parsed first so that --gin_file / --gin_param values override it. The
    dataset is built only after gin parsing, from --data_dir, and is passed to
    ``utils.run`` together with the dtype, TPU and autostack flags.
    """
    default_config_path = os.path.join(
        os.path.dirname(__file__), _DEFAULT_CONFIG_FILE)
    gin.parse_config_file(default_config_path)
    gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_param)

    tfds_dataset = utils.get_tfds_dataset(data_dir=FLAGS.data_dir)
    utils.run(
        dataset=tfds_dataset,
        tpu_job_name=FLAGS.tpu_job_name,
        master_dtype=FLAGS.master_dtype,
        slice_dtype=FLAGS.slice_dtype,
        activation_dtype=FLAGS.activation_dtype,
        tpu=FLAGS.tpu,
        gcp_project=FLAGS.gcp_project,
        tpu_zone=FLAGS.tpu_zone,
        autostack=FLAGS.autostack,
        model_dir=FLAGS.model_dir,
    )
예제 #6
0
def main(_):
    """Import extra modules, record the invocation, parse gin, then run.

    The command line is saved to the first unused file among ``command``,
    ``command.1``, ``command.2``, ... in --model_dir before ``utils.run``
    is invoked.
    """
    for module_name in FLAGS.module_import or ():
        importlib.import_module(module_name)

    model_dir = FLAGS.model_dir
    tf.io.gfile.makedirs(model_dir)

    def _candidate(index):
        # "command" for index 0, then "command.1", "command.2", ...
        basename = "command" if index == 0 else "command.{}".format(index)
        return os.path.join(model_dir, basename)

    index = 0
    while tf.io.gfile.exists(_candidate(index)):
        index += 1
    with tf.io.gfile.GFile(_candidate(index), "w") as handle:
        handle.write(" ".join(sys.argv))

    utils.parse_gin_defaults_and_flags()
    utils.run(
        tpu_job_name=FLAGS.tpu_job_name,
        tpu=FLAGS.tpu,
        gcp_project=FLAGS.gcp_project,
        tpu_zone=FLAGS.tpu_zone,
        model_dir=model_dir,
    )
예제 #7
0
def main(_):
    """Entry point: configure data/gin, log the command, then run the model.

    Steps, in order:
      1. Import any modules listed in --module_import.
      2. Optionally override the TFDS data dir (--t5_tfds_data_dir), add extra
         task cache dirs, and add the package's "gin" directory to gin's
         config search path.
      3. Best-effort write of sys.argv to model_dir/command[.N].
      4. Parse gin defaults and flags.
      5. Dispatch either through the MtfModel API (--use_model_api, selected
         by --mode) or through utils.run.

    Raises:
      ValueError: if finetune/export is requested without exactly one
        checkpoint, or if --mode is not one of train/eval/finetune/predict/
        export while using the Model API.
    """
    if FLAGS.module_import:
        for module in FLAGS.module_import:
            importlib.import_module(module)

    if FLAGS.t5_tfds_data_dir:
        t5.data.set_tfds_data_dir_override(FLAGS.t5_tfds_data_dir)
    t5.data.add_global_cache_dirs(FLAGS.additional_task_cache_dirs)

    # Add search path for gin files stored in package.
    gin.add_config_file_search_path(
        pkg_resources.resource_filename(__name__, "gin"))
    try:
        # Record the exact command line under the first unused name among
        # command, command.1, command.2, ... so earlier runs are preserved.
        tf.io.gfile.makedirs(FLAGS.model_dir)
        suffix = 0
        command_filename = os.path.join(FLAGS.model_dir, "command")
        while tf.io.gfile.exists(command_filename):
            suffix += 1
            command_filename = os.path.join(FLAGS.model_dir,
                                            "command.{}".format(suffix))
        with tf.io.gfile.GFile(command_filename, "w") as f:
            f.write(" ".join(sys.argv))
    except tf.errors.PermissionDeniedError:
        logging.info(
            "No write access to model directory. Skipping command logging.")

    utils.parse_gin_defaults_and_flags()

    if FLAGS.use_model_api:
        model = mtf_model.MtfModel(tpu_job_name=FLAGS.tpu_job_name,
                                   tpu=FLAGS.tpu,
                                   gcp_project=FLAGS.gcp_project,
                                   tpu_zone=FLAGS.tpu_zone,
                                   model_dir=FLAGS.model_dir)

        # Translate --checkpoint_mode into the checkpoint_steps value passed
        # to the model: -1 for "latest", "all" for "all", otherwise the
        # explicit integer steps from --checkpoint_steps.
        if FLAGS.checkpoint_mode == "latest":
            checkpoint_steps = -1
        elif FLAGS.checkpoint_mode == "all":
            checkpoint_steps = "all"
        else:
            checkpoint_steps = [int(c) for c in FLAGS.checkpoint_steps]

        if FLAGS.mode == "train":
            model.train(mixture_or_task_name=FLAGS.mixture_or_task,
                        steps=FLAGS.train_steps)
        elif FLAGS.mode == "eval":
            model.eval(mixture_or_task_name=FLAGS.mixture_or_task,
                       checkpoint_steps=checkpoint_steps,
                       summary_dir=FLAGS.eval_summary_dir,
                       split=FLAGS.eval_split)
        elif FLAGS.mode == "finetune":

            if not (FLAGS.checkpoint_mode == "latest" or
                    (FLAGS.checkpoint_mode == "specific"
                     and len(FLAGS.checkpoint_steps) == 1)):
                raise ValueError(
                    "Must specify a single checkpoint for finetuning a model.")

            # finetune takes a single step, not a list.
            if isinstance(checkpoint_steps, list):
                checkpoint_steps = checkpoint_steps[0]

            model.finetune(mixture_or_task_name=FLAGS.mixture_or_task,
                           steps=FLAGS.train_steps,
                           pretrained_model_dir=FLAGS.pretrained_model_dir,
                           checkpoint_steps=checkpoint_steps)
        elif FLAGS.mode == "predict":
            model.predict(checkpoint_steps=checkpoint_steps,
                          input_file=FLAGS.input_file,
                          output_file=FLAGS.output_file)
        elif FLAGS.mode == "export":
            if not (FLAGS.checkpoint_mode == "latest" or
                    (FLAGS.checkpoint_mode == "specific"
                     and len(FLAGS.checkpoint_steps) == 1)):
                raise ValueError(
                    "Must specify a single checkpoint for exporting a model.")

            # export takes a single step, not a list.
            if isinstance(checkpoint_steps, list):
                checkpoint_steps = checkpoint_steps[0]

            model.export(export_dir=FLAGS.export_dir,
                         checkpoint_step=checkpoint_steps)
        else:
            # Fail loudly instead of exiting successfully without doing work.
            raise ValueError("--mode flag must be set when using Model API.")
    else:
        utils.run(tpu_job_name=FLAGS.tpu_job_name,
                  tpu=FLAGS.tpu,
                  gcp_project=FLAGS.gcp_project,
                  tpu_zone=FLAGS.tpu_zone,
                  model_dir=FLAGS.model_dir)
예제 #8
0
def main(_):
    """Entry point: set up data/gin paths, log the command, then dispatch.

    Steps, in order:
      1. Import any modules listed in --module_import.
      2. Optionally override the TFDS data dir (--t5_tfds_data_dir) and add
         the package's "gin" directory to gin's config search path.
      3. Best-effort write of sys.argv to model_dir/commands/command[.N].
      4. Parse gin defaults and flags without finalizing, rebind
         run.vocabulary, then finalize the gin config.
      5. Add extra task cache dirs (--additional_task_cache_dirs).
      6. Dispatch either through the MtfModel API (--use_model_api, selected
         by --mode) or through utils.run.

    Raises:
      ValueError: if --checkpoint_steps is set without
        --checkpoint_mode=specific; if finetune/export_* is requested without
        exactly one checkpoint; if --mode is unrecognized while using the
        Model API; or if --mode is set while not using the Model API.
    """
    if FLAGS.module_import:
        for module in FLAGS.module_import:
            importlib.import_module(module)

    if FLAGS.t5_tfds_data_dir:
        t5.data.set_tfds_data_dir_override(FLAGS.t5_tfds_data_dir)

    # Add search path for gin files stored in package.
    gin.add_config_file_search_path(
        pkg_resources.resource_filename(__name__, "gin"))
    # Best-effort: record the exact command line under model_dir/commands,
    # using the first unused name among command, command.1, command.2, ...
    # PermissionDenied / InvalidArgument errors are logged and then ignored.
    try:
        suffix = 0
        command_dir = os.path.join(FLAGS.model_dir, "commands")
        tf.io.gfile.makedirs(command_dir)
        command_filename = os.path.join(command_dir, "command")
        while tf.io.gfile.exists(command_filename):
            suffix += 1
            command_filename = os.path.join(command_dir,
                                            "command.{}".format(suffix))
        with tf.io.gfile.GFile(command_filename, "w") as f:
            f.write(" ".join(sys.argv))
    except (tf.errors.PermissionDeniedError, tf.errors.InvalidArgumentError):
        logging.info(
            "No write access to model directory. Skipping command logging.")

    # skip_unknown receives the value of --skip_all_gin_unknowns when that is
    # truthy; otherwise it receives the tuple of deprecated gin references
    # (the built-in list plus --additional_deprecated_gin_references).
    # The config is left unfinalized so run.vocabulary can be rebound below.
    utils.parse_gin_defaults_and_flags(
        skip_unknown=(FLAGS.skip_all_gin_unknowns
                      or (mesh_transformer.DEPRECATED_GIN_REFERENCES +
                          tuple(FLAGS.additional_deprecated_gin_references))),
        finalize_config=False)
    # We must override this binding explicitly since it is set to a deprecated
    # function or class in many existing configs.
    gin.bind_parameter("run.vocabulary", mesh_transformer.get_vocabulary())
    gin.finalize()

    # Set cache dir after loading gin to avoid unintentionally overriding it.
    t5.data.add_global_cache_dirs(FLAGS.additional_task_cache_dirs)

    if FLAGS.use_model_api:
        model = mtf_model.MtfModel(tpu_job_name=FLAGS.tpu_job_name,
                                   tpu=FLAGS.tpu,
                                   gcp_project=FLAGS.gcp_project,
                                   tpu_zone=FLAGS.tpu_zone,
                                   tpu_topology=FLAGS.tpu_topology,
                                   model_parallelism=FLAGS.model_parallelism,
                                   model_dir=FLAGS.model_dir,
                                   batch_size=FLAGS.batch_size,
                                   sequence_length={
                                       "inputs": FLAGS.input_sequence_length,
                                       "targets": FLAGS.target_sequence_length
                                   })

        # --checkpoint_steps is only meaningful with checkpoint_mode=specific.
        if FLAGS.checkpoint_mode != "specific" and FLAGS.checkpoint_steps:
            raise ValueError(
                "checkpoint_mode is set to %s and checkpoint_steps is "
                "also set. To use a particular checkpoint, please set "
                "checkpoint_mode to 'specific'. For other modes, please "
                "ensure that checkpoint_steps is not set." %
                FLAGS.checkpoint_mode)

        # Translate --checkpoint_mode into the checkpoint_steps value passed
        # to the model: -1 for "latest", "all" for "all", otherwise the
        # explicit integer steps from --checkpoint_steps.
        if FLAGS.checkpoint_mode == "latest":
            checkpoint_steps = -1
        elif FLAGS.checkpoint_mode == "all":
            checkpoint_steps = "all"
        else:
            checkpoint_steps = [int(c) for c in FLAGS.checkpoint_steps]

        if FLAGS.mode == "train":
            model.train(mixture_or_task_name=FLAGS.mixture_or_task,
                        steps=FLAGS.train_steps)
        elif FLAGS.mode == "eval":
            model.eval(mixture_or_task_name=FLAGS.mixture_or_task,
                       checkpoint_steps=checkpoint_steps,
                       summary_dir=FLAGS.eval_summary_dir,
                       split=FLAGS.eval_split)
        elif FLAGS.mode == "finetune":
            if not (FLAGS.checkpoint_mode == "latest" or
                    (FLAGS.checkpoint_mode == "specific"
                     and len(FLAGS.checkpoint_steps) == 1)):
                raise ValueError(
                    "Must specify a single checkpoint for finetuning a model.")

            # finetune is given a single step rather than a list.
            if isinstance(checkpoint_steps, list):
                checkpoint_steps = checkpoint_steps[0]

            model.finetune(mixture_or_task_name=FLAGS.mixture_or_task,
                           steps=FLAGS.train_steps,
                           pretrained_model_dir=FLAGS.pretrained_model_dir,
                           checkpoint_steps=checkpoint_steps)
        elif FLAGS.mode == "predict":
            model.predict(
                checkpoint_steps=checkpoint_steps,
                input_file=FLAGS.input_file,
                output_file=FLAGS.output_file,
                beam_size=FLAGS.beam_size,
                temperature=FLAGS.temperature,
                keep_top_k=FLAGS.keep_top_k,
            )
        elif FLAGS.mode == "score":
            model.score(FLAGS.input_file,
                        FLAGS.target_file,
                        scores_file=FLAGS.output_file,
                        checkpoint_steps=checkpoint_steps)
        elif FLAGS.mode in ("export_predict", "export_score"):
            if not (FLAGS.checkpoint_mode == "latest" or
                    (FLAGS.checkpoint_mode == "specific"
                     and len(FLAGS.checkpoint_steps) == 1)):
                raise ValueError(
                    "Must specify a single checkpoint for exporting a model.")

            # export is given a single step rather than a list.
            if isinstance(checkpoint_steps, list):
                checkpoint_steps = checkpoint_steps[0]

            # Both export modes share model.export; they differ only in
            # eval_with_score.
            model.export(export_dir=FLAGS.export_dir,
                         checkpoint_step=checkpoint_steps,
                         beam_size=FLAGS.beam_size,
                         temperature=FLAGS.temperature,
                         keep_top_k=FLAGS.keep_top_k,
                         eval_with_score=(FLAGS.mode == "export_score"))
        else:
            raise ValueError("--mode flag must be set when using Model API.")
    else:
        if FLAGS.mode:
            raise ValueError(
                "--mode flag should only be set when using Model API.")
        # Without --tpu, bind float32 slice/activation dtypes. The config was
        # finalized above, so it is temporarily unlocked for these bindings.
        if not FLAGS.tpu:
            with gin.unlock_config():
                gin.bind_parameter("utils.get_variable_dtype.slice_dtype",
                                   "float32")
                gin.bind_parameter("utils.get_variable_dtype.activation_dtype",
                                   "float32")
        utils.run(tpu_job_name=FLAGS.tpu_job_name,
                  tpu=FLAGS.tpu,
                  gcp_project=FLAGS.gcp_project,
                  tpu_zone=FLAGS.tpu_zone,
                  model_dir=FLAGS.model_dir)
예제 #9
0
def main(_):
    """Entry point: configure data/gin, log the command, then dispatch.

    Steps, in order:
      1. Import any modules listed in --module_import.
      2. Optionally override the TFDS data dir (--t5_tfds_data_dir), add extra
         task cache dirs, and add the package's "gin" directory to gin's
         config search path.
      3. Best-effort write of sys.argv to model_dir/command[.N].
      4. Parse gin defaults and flags.
      5. Dispatch either through the MtfModel API (--use_model_api, selected
         by --mode) or through utils.run.

    Raises:
      ValueError: if --checkpoint_steps is set without
        --checkpoint_mode=specific; if finetune/export is requested without
        exactly one checkpoint; if --mode is unrecognized while using the
        Model API; or if --mode is set while not using the Model API.
    """
    if FLAGS.module_import:
        for module in FLAGS.module_import:
            importlib.import_module(module)

    if FLAGS.t5_tfds_data_dir:
        t5.data.set_tfds_data_dir_override(FLAGS.t5_tfds_data_dir)
    t5.data.add_global_cache_dirs(FLAGS.additional_task_cache_dirs)

    # Add search path for gin files stored in package.
    gin.add_config_file_search_path(
        pkg_resources.resource_filename(__name__, "gin"))
    # Best-effort: record the exact command line in model_dir under the first
    # unused name among command, command.1, command.2, ...; a
    # PermissionDeniedError is logged and then ignored.
    try:
        tf.io.gfile.makedirs(FLAGS.model_dir)
        suffix = 0
        command_filename = os.path.join(FLAGS.model_dir, "command")
        while tf.io.gfile.exists(command_filename):
            suffix += 1
            command_filename = os.path.join(FLAGS.model_dir,
                                            "command.{}".format(suffix))
        with tf.io.gfile.GFile(command_filename, "w") as f:
            f.write(" ".join(sys.argv))
    except tf.errors.PermissionDeniedError:
        logging.info(
            "No write access to model directory. Skipping command logging.")

    utils.parse_gin_defaults_and_flags()

    if FLAGS.use_model_api:
        model = mtf_model.MtfModel(tpu_job_name=FLAGS.tpu_job_name,
                                   tpu=FLAGS.tpu,
                                   gcp_project=FLAGS.gcp_project,
                                   tpu_zone=FLAGS.tpu_zone,
                                   model_dir=FLAGS.model_dir)

        # --checkpoint_steps is only meaningful with checkpoint_mode=specific.
        if FLAGS.checkpoint_mode != "specific" and FLAGS.checkpoint_steps:
            raise ValueError(
                "checkpoint_mode is set to %s and checkpoint_steps is "
                "also set. To use a particular checkpoint, please set "
                "checkpoint_mode to 'specific'. For other modes, please "
                "ensure that checkpoint_steps is not set." %
                FLAGS.checkpoint_mode)

        # Translate --checkpoint_mode into the checkpoint_steps value passed
        # to the model: -1 for "latest", "all" for "all", otherwise the
        # explicit integer steps from --checkpoint_steps.
        if FLAGS.checkpoint_mode == "latest":
            checkpoint_steps = -1
        elif FLAGS.checkpoint_mode == "all":
            checkpoint_steps = "all"
        else:
            checkpoint_steps = [int(c) for c in FLAGS.checkpoint_steps]

        if FLAGS.mode == "train":
            model.train(mixture_or_task_name=FLAGS.mixture_or_task,
                        steps=FLAGS.train_steps)
        elif FLAGS.mode == "eval":
            model.eval(mixture_or_task_name=FLAGS.mixture_or_task,
                       checkpoint_steps=checkpoint_steps,
                       summary_dir=FLAGS.eval_summary_dir,
                       split=FLAGS.eval_split)
        elif FLAGS.mode == "finetune":
            if not (FLAGS.checkpoint_mode == "latest" or
                    (FLAGS.checkpoint_mode == "specific"
                     and len(FLAGS.checkpoint_steps) == 1)):
                raise ValueError(
                    "Must specify a single checkpoint for finetuning a model.")

            # finetune is given a single step rather than a list.
            if isinstance(checkpoint_steps, list):
                checkpoint_steps = checkpoint_steps[0]

            model.finetune(mixture_or_task_name=FLAGS.mixture_or_task,
                           steps=FLAGS.train_steps,
                           pretrained_model_dir=FLAGS.pretrained_model_dir,
                           checkpoint_steps=checkpoint_steps)
        elif FLAGS.mode == "predict":
            model.predict(checkpoint_steps=checkpoint_steps,
                          input_file=FLAGS.input_file,
                          output_file=FLAGS.output_file)
        elif FLAGS.mode == "score":
            model.score(FLAGS.input_file,
                        FLAGS.target_file,
                        scores_file=FLAGS.output_file,
                        checkpoint_steps=checkpoint_steps)
        elif FLAGS.mode == "export":
            if not (FLAGS.checkpoint_mode == "latest" or
                    (FLAGS.checkpoint_mode == "specific"
                     and len(FLAGS.checkpoint_steps) == 1)):
                raise ValueError(
                    "Must specify a single checkpoint for exporting a model.")

            # export is given a single step rather than a list.
            if isinstance(checkpoint_steps, list):
                checkpoint_steps = checkpoint_steps[0]

            # Export overrides the model's batch size with --export_batch_size.
            model.batch_size = FLAGS.export_batch_size
            model.export(export_dir=FLAGS.export_dir,
                         checkpoint_step=checkpoint_steps)
        else:
            raise ValueError("--mode flag must be set when using Model API.")
    else:
        if FLAGS.mode:
            raise ValueError(
                "--mode flag should only be set when using Model API.")
        # Without --tpu, bind float32 slice/activation dtypes; the gin config
        # is temporarily unlocked for these bindings.
        if not FLAGS.tpu:
            with gin.unlock_config():
                gin.bind_parameter("utils.get_variable_dtype.slice_dtype",
                                   "float32")
                gin.bind_parameter("utils.get_variable_dtype.activation_dtype",
                                   "float32")
        utils.run(tpu_job_name=FLAGS.tpu_job_name,
                  tpu=FLAGS.tpu,
                  gcp_project=FLAGS.gcp_project,
                  tpu_zone=FLAGS.tpu_zone,
                  model_dir=FLAGS.model_dir)
예제 #10
0
def main(_):
    """Entry point: configure data/gin, log the command, then run the model.

    With --use_model_api, builds an MtfModel and dispatches on --mode:
    "train" trains, "eval" evaluates, and any other value predicts.
    Checkpoints are resolved only for eval/predict, so training a fresh
    model_dir does not require existing checkpoints. Without the Model API,
    delegates to ``utils.run``.

    Raises:
      ValueError: if --checkpoint_mode is "latest" for eval/predict but no
        ``model.ckpt-<step>.index`` files exist in --model_dir.
    """
    if FLAGS.module_import:
        for module in FLAGS.module_import:
            importlib.import_module(module)

    if FLAGS.t5_tfds_data_dir:
        # Pass the same flag that was just checked (not FLAGS.tfds_data_dir).
        t5.data.set_tfds_data_dir_override(FLAGS.t5_tfds_data_dir)
    t5.data.add_global_cache_dirs(FLAGS.additional_task_cache_dirs)

    # Add search path for gin files stored in package.
    gin.add_config_file_search_path(
        pkg_resources.resource_filename(__name__, "gin"))

    # Record the exact command line under the first unused name among
    # command, command.1, command.2, ... in model_dir.
    tf.io.gfile.makedirs(FLAGS.model_dir)
    suffix = 0
    command_filename = os.path.join(FLAGS.model_dir, "command")
    while tf.io.gfile.exists(command_filename):
        suffix += 1
        command_filename = os.path.join(FLAGS.model_dir,
                                        "command.{}".format(suffix))
    with tf.io.gfile.GFile(command_filename, "w") as f:
        f.write(" ".join(sys.argv))

    utils.parse_gin_defaults_and_flags()

    if FLAGS.use_model_api:
        model = mtf_model.MtfModel(tpu_job_name=FLAGS.tpu_job_name,
                                   tpu=FLAGS.tpu,
                                   gcp_project=FLAGS.gcp_project,
                                   tpu_zone=FLAGS.tpu_zone,
                                   model_dir=FLAGS.model_dir)

        if FLAGS.mode == "train":
            model.train(mixture_or_task_name=FLAGS.mixture_or_task,
                        steps=FLAGS.train_steps)
            return

        # Resolve checkpoints only now: a model_dir that has not been trained
        # yet has no checkpoints, and training must not depend on them.
        if FLAGS.checkpoint_mode == "latest":
            checkpoint_step = _latest_checkpoint_step(FLAGS.model_dir)
        elif FLAGS.checkpoint_mode == "all":
            checkpoint_step = "all"
        else:
            checkpoint_step = [int(c) for c in FLAGS.checkpoint_steps]

        if FLAGS.mode == "eval":
            model.eval(mixture_or_task_name=FLAGS.mixture_or_task,
                       checkpoint_step=checkpoint_step,
                       summary_dir=FLAGS.eval_summary_dir,
                       split=FLAGS.eval_split)
        else:
            model.predict(checkpoint_step=checkpoint_step,
                          input_file=FLAGS.input_file,
                          output_file=FLAGS.output_file)

    else:
        utils.run(tpu_job_name=FLAGS.tpu_job_name,
                  tpu=FLAGS.tpu,
                  gcp_project=FLAGS.gcp_project,
                  tpu_zone=FLAGS.tpu_zone,
                  model_dir=FLAGS.model_dir)


def _latest_checkpoint_step(model_dir):
    """Return the highest step among ``model.ckpt-<step>.index`` files.

    Args:
      model_dir: directory to search for checkpoint index files.

    Returns:
      The largest integer ``<step>`` found.

    Raises:
      ValueError: if no matching checkpoint index file exists in model_dir.
    """
    # os.path.join supplies the separator that plain string concatenation
    # dropped when model_dir had no trailing slash.
    index_paths = tf.io.gfile.glob(os.path.join(model_dir, "model.*index"))
    steps = []
    for path in index_paths:
        match = re.search(r"ckpt-(\d+)\.index$", path)
        if match:
            steps.append(int(match.group(1)))
    if not steps:
        raise ValueError(
            "checkpoint_mode is 'latest' but no model.ckpt-*.index files "
            "were found in %s" % model_dir)
    return max(steps)