from absl import flags

# Assumed import path: help_wrap and the define_* helpers used below match
# the flag utilities in the TensorFlow official models repo; PARAMS_MAP
# (referenced in define_transformer_flags) is expected to come from the
# importing module.
from official.utils.flags import core as flags_core


# Example 1
def define_translate_flags():
    """Define flags used for translation script."""
    # Model flags
    flags.DEFINE_string(
        name="model_dir",
        short_name="md",
        default="/tmp/transformer_model",
        help=flags_core.help_wrap(
            "Directory containing Transformer model checkpoints."))
    flags.DEFINE_enum(
        name="param_set",
        short_name="mp",
        default="big",
        enum_values=["base", "big", "hkh"],
        help=flags_core.help_wrap(
            "Parameter set to use when creating and training the model. The "
            "parameters define the input shape (batch size and max length), "
            "model configuration (size of embedding, # of hidden layers, etc.), "
            "and various other settings. The big parameter set increases the "
            "default batch size, embedding/hidden size, and filter size. For a "
            "complete list of parameters, please see model/model_params.py."))
    flags.DEFINE_string(
        name="vocab_file",
        short_name="vf",
        default=None,
        help=flags_core.help_wrap(
            "Path to vocabulary file. If data_download.py was used to "
            "download and encode the training data, look in the data_dir to find "
            "the vocab file."))
    flags.mark_flag_as_required("vocab_file")
    flags.DEFINE_string(
        name="subword_option",
        short_name="so",
        default="bpe",
        help=flags_core.help_wrap(
            "Subword option: '' (no subword processing), 'bpe' (byte-pair "
            "encoding), or 'spm' (SentencePiece)."))

    flags.DEFINE_string(
        name="text",
        default=None,
        help=flags_core.help_wrap(
            "Text to translate. Output will be printed to console."))
    flags.DEFINE_string(
        name="file",
        default=None,
        help=flags_core.help_wrap(
            "File containing text to translate. Translation will be printed "
            "to console and, if --file_out is provided, saved to an output "
            "file."))
    flags.DEFINE_string(
        name="file_out",
        default=None,
        help=flags_core.help_wrap(
            "If --file flag is specified, save translation to this file."))
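
# Hedged usage sketch (not part of the original script): how the translate
# flags are typically consumed once defined.  The print calls stand in for
# the real model-loading and inference code, which lives elsewhere.
def _example_translate_main(unused_argv):
    if flags.FLAGS.text:
        print("Translating text: %s" % flags.FLAGS.text)
    elif flags.FLAGS.file:
        print("Translating file %s (output: %s)" %
              (flags.FLAGS.file, flags.FLAGS.file_out or "console"))


# Typical wiring (sketch): call define_translate_flags(), then hand
# _example_translate_main to absl.app.run().


# Example 2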
def define_compute_bleu_flags():
    """Add flags for computing BLEU score."""
    flags.DEFINE_string(
        name="translation",
        default=None,
        help=flags_core.help_wrap("File containing translated text."))
    flags.mark_flag_as_required("translation")

    flags.DEFINE_string(
        name="reference",
        default=None,
        help=flags_core.help_wrap("File containing reference translation."))
    flags.mark_flag_as_required("reference")

    flags.DEFINE_enum(
        name="bleu_variant",
        short_name="bv",
        default="both",
        enum_values=["both", "uncased", "cased"],
        case_sensitive=False,
        help=flags_core.help_wrap(
            "Specify one or more BLEU variants to calculate. Variants: \"cased\""
            ", \"uncased\", or \"both\"."))
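
# Hedged sketch of how --bleu_variant is typically expanded downstream; the
# helper name is hypothetical, not part of the original script.  With
# case_sensitive=False, absl maps e.g. "Cased" back to the canonical
# lowercase enum value before it reaches this point.
def _example_bleu_variants():
    if flags.FLAGS.bleu_variant == "both":
        return ["uncased", "cased"]
    return [flags.FLAGS.bleu_variant]


# Example 3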
def define_data_download_flags():
    """Define flags for the data download and preparation script."""
    flags.DEFINE_string(
        name="data_dir",
        short_name="dd",
        default="/hdd/data/iwslt18/open-subtitles/tf_data_subtoken",
        help=flags_core.help_wrap("Directory to write the prepared data to."))

    flags.DEFINE_string(name="vocab_prefix",
                        short_name="vp",
                        default="vocab.subtoken",
                        help=flags_core.help_wrap(
                            "Filename prefix for the generated vocabulary "
                            "file."))

    flags.DEFINE_integer(name="vocab_size",
                         short_name="vs",
                         default=16000,
                         help=flags_core.help_wrap(
                             "Target vocabulary size."))

    flags.DEFINE_string(
        name="train_prefix",
        short_name="tp",
        default="/hdd/data/iwslt18/open-subtitles/data.tok/train.tok.clean",
        help=flags_core.help_wrap(
            "Path prefix of the tokenized training data files."))

    flags.DEFINE_string(
        name="dev_prefix",
        short_name="dp",
        default="/hdd/data/iwslt18/open-subtitles/data.tok/dev.tok.clean",
        help=flags_core.help_wrap(
            "Path prefix of the tokenized dev data files."))

    flags.DEFINE_string(name="src",
                        short_name="src",
                        default="eu",
                        help=flags_core.help_wrap("Source language code."))

    flags.DEFINE_string(name="tgt",
                        short_name="tgt",
                        default="en",
                        help=flags_core.help_wrap("Target language code."))

    flags.DEFINE_bool(
        name="search",
        default=True,
        help=flags_core.help_wrap(
            "If set, use binary search to find the vocabulary set with size "
            "closest to the target size."))

# Example 4
def define_data_download_flags():
    """Define flags for the data download and preparation script."""
    flags.DEFINE_string(name="data_dir",
                        short_name="dd",
                        default="/hdd/data/iwslt18/open-subtitles/tf_data_dir",
                        help=flags_core.help_wrap(
                            "Directory to write the prepared data to."))

    flags.DEFINE_string(
        name="vocab_prefix",
        short_name="vp",
        default="/hdd/data/iwslt18/open-subtitles/data.tok/vocab.bpe.16000",
        help=flags_core.help_wrap("Path (prefix) of the vocabulary file."))

    flags.DEFINE_bool(name="share_vocab",
                      short_name="sv",
                      default=True,
                      help=flags_core.help_wrap(
                          "Whether source and target languages share a "
                          "single vocabulary."))

    flags.DEFINE_string(
        name="train_prefix",
        short_name="tp",
        default=("/hdd/data/iwslt18/open-subtitles/data.tok/"
                 "train.tok.clean.bpe.16000"),
        help=flags_core.help_wrap(
            "Path prefix of the tokenized training data files."))

    flags.DEFINE_string(
        name="dev_prefix",
        short_name="dp",
        default=("/hdd/data/iwslt18/open-subtitles/data.tok/"
                 "dev.tok.clean.bpe.16000"),
        help=flags_core.help_wrap(
            "Path prefix of the tokenized dev data files."))

    flags.DEFINE_string(name="src",
                        short_name="src",
                        default="eu",
                        help=flags_core.help_wrap("Source language code."))

    flags.DEFINE_string(name="tgt",
                        short_name="tgt",
                        default="en",
                        help=flags_core.help_wrap("Target language code."))
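
# Hedged sketch of what --share_vocab typically implies (hypothetical helper,
# assumed semantics): one vocabulary file serves both languages when shared;
# otherwise per-language files are derived from the prefix plus the --src and
# --tgt codes.
def _example_vocab_paths(vocab_prefix, src, tgt, share_vocab):
    if share_vocab:
        return vocab_prefix, vocab_prefix
    return "%s.%s" % (vocab_prefix, src), "%s.%s" % (vocab_prefix, tgt)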

# Example 5
def define_transformer_flags():
    """Add flags and flag validators for running transformer_main."""
    # Add common flags (data_dir, model_dir, train_epochs, etc.).
    flags_core.define_base()
    flags_core.define_performance(num_parallel_calls=True,
                                  inter_op=False,
                                  intra_op=False,
                                  synthetic_data=True,
                                  max_train_steps=False,
                                  dtype=False,
                                  all_reduce_alg=True)
    flags_core.define_benchmark()
    flags_core.define_device(tpu=True)

    # Set flags from the flags_core module as "key flags" so they're listed when
    # the '-h' flag is used. Without this line, the flags defined above are
    # only shown in the full `--helpful` help text.
    flags.adopt_module_key_flags(flags_core)

    # Add transformer-specific flags
    flags.DEFINE_enum(
        name="param_set",
        short_name="ps",
        default="big",
        enum_values=PARAMS_MAP.keys(),
        help=flags_core.help_wrap(
            "Parameter set to use when creating and training the model. The "
            "parameters define the input shape (batch size and max length), "
            "model configuration (size of embedding, # of hidden layers, etc.), "
            "and various other settings. The big parameter set increases the "
            "default batch size, embedding/hidden size, and filter size. For a "
            "complete list of parameters, please see model/model_params.py."))

    flags.DEFINE_bool(
        name="static_batch",
        default=False,
        help=flags_core.help_wrap(
            "Whether the batches in the dataset should have static shapes. In "
            "general, this setting should be False. Dynamic shapes allow the "
            "inputs to be grouped so that the number of padding tokens is "
            "minimized, and helps model training. In cases where the input shape "
            "must be static (e.g. running on TPU), this setting will be ignored "
            "and static batching will always be used."))

    # Flags for training with steps (may be used for debugging)
    flags.DEFINE_integer(
        name="train_steps",
        short_name="ts",
        default=None,
        help=flags_core.help_wrap("The number of steps used to train."))
    flags.DEFINE_integer(
        name="steps_between_evals",
        short_name="sbe",
        default=1000,
        help=flags_core.help_wrap(
            "The number of training steps to run between evaluations. This is "
            "used if --train_steps is defined."))

    # BLEU score computation
    flags.DEFINE_string(
        name="bleu_source",
        short_name="bls",
        default=None,
        help=flags_core.help_wrap(
            "Path to source file containing text to translate when "
            "calculating the official BLEU score. Both --bleu_source and "
            "--bleu_ref must be set. Use the flag --stop_threshold to stop "
            "the script based on the uncased BLEU score."))
    flags.DEFINE_string(
        name="bleu_ref",
        short_name="blr",
        default=None,
        help=flags_core.help_wrap(
            "Path to reference file containing the correct translations, "
            "used when calculating the official BLEU score. Both "
            "--bleu_source and --bleu_ref must be set. Use the flag "
            "--stop_threshold to stop the script based on the uncased BLEU "
            "score."))
    flags.DEFINE_string(
        name="vocab_file",
        short_name="vf",
        default=None,
        help=flags_core.help_wrap(
            "Path to subtoken vocabulary file. If data_download.py was used to "
            "download and encode the training data, look in the data_dir to find "
            "the vocab file."))
    flags.DEFINE_string(
        name="subword_option",
        short_name="so",
        default="bpe",
        help=flags_core.help_wrap(
            "Subword option: '' (no subword processing), 'bpe' (byte-pair "
            "encoding), or 'spm' (SentencePiece)."))

    flags.DEFINE_bool(
        name="gpu_allow_growth",
        short_name="gag",
        default=True,
        help=flags_core.help_wrap("Allow GPU memory to grow dynamically."))
    flags.DEFINE_float(
        name="gpu_memory_fraction",
        short_name="gmf",
        default=0.5,
        help=flags_core.help_wrap(
            "Fraction of total GPU memory the process may allocate."))
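
    # Sketch of how these two GPU flags typically feed a TF1-style session
    # config (assumption; the session-building code is not in this snippet):
    #   config = tf.ConfigProto()
    #   config.gpu_options.allow_growth = flags.FLAGS.gpu_allow_growth
    #   config.gpu_options.per_process_gpu_memory_fraction = (
    #       flags.FLAGS.gpu_memory_fraction)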

    flags_core.set_defaults(data_dir="/tmp/translate_ende",
                            model_dir="/tmp/transformer_model",
                            batch_size=None,
                            train_epochs=None)

    @flags.multi_flags_validator(
        ["train_epochs", "train_steps"],
        message=("Both --train_steps and --train_epochs were set. Only one "
                 "may be defined."))
    def _check_train_limits(flag_dict):
        return (flag_dict["train_epochs"] is None or
                flag_dict["train_steps"] is None)

    @flags.multi_flags_validator(
        ["bleu_source", "bleu_ref"],
        message="Both or neither --bleu_source and --bleu_ref must be defined."
    )
    def _check_bleu_files(flags_dict):
        return (flags_dict["bleu_source"] is None) == (
            flags_dict["bleu_ref"] is None)

    @flags.multi_flags_validator(
        ["bleu_source", "bleu_ref", "vocab_file"],
        message="--vocab_file must be defined if --bleu_source and --bleu_ref "
        "are defined.")
    def _check_bleu_vocab_file(flags_dict):
        if flags_dict["bleu_source"] and flags_dict["bleu_ref"]:
            return flags_dict["vocab_file"] is not None
        return True

    @flags.multi_flags_validator(
        ["export_dir", "vocab_file"],
        message="--vocab_file must be defined if --export_dir is set.")
    def _check_export_vocab_file(flags_dict):
        if flags_dict["export_dir"]:
            return flags_dict["vocab_file"] is not None
        return True

    flags_core.require_cloud_storage(["data_dir", "model_dir", "export_dir"])
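
# Hedged end-to-end sketch (not part of the original module): once
# define_transformer_flags() has run, the validators above fire at flag-parse
# time, so passing both --train_epochs and --train_steps aborts with the
# _check_train_limits message before main() is entered.
#
#     from absl import app as absl_app
#
#     def main(_):
#         print("param_set: %s" % flags.FLAGS.param_set)
#
#     if __name__ == "__main__":
#         define_transformer_flags()
#         absl_app.run(main)  # e.g. --ts=100 --sbe=50 trains in step mode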