コード例 #1
0
ファイル: misc.py プロジェクト: qa276390/tf-models
def define_transformer_flags():
    """Add flags and flag validators for running transformer_main.

    The function registers, in order:
      1. The shared flag groups provided by ``flags_core`` (base, performance,
         benchmark and device).
      2. Keras-style training/profiling flags (train_steps, callbacks, ...).
      3. Transformer-specific flags (param_set, static_batch, max_length,
         BLEU files, vocab_file, mode).
      4. Transformer defaults for some of the shared flags.
      5. Cross-flag validators:
           * --train_epochs must not be None when --mode is 'train'.
           * --bleu_source and --bleu_ref must be both set or both unset.
           * --vocab_file is required when both BLEU files are given.
           * --vocab_file is required when --export_dir is set.

    The function only has side effects on the global absl flag registry and
    returns None.
    """
    # Add common flags (data_dir, model_dir, train_epochs, etc.).
    flags_core.define_base()
    flags_core.define_performance(num_parallel_calls=True,
                                  inter_op=False,
                                  intra_op=False,
                                  synthetic_data=True,
                                  max_train_steps=False,
                                  dtype=True,
                                  loss_scale=True,
                                  all_reduce_alg=True,
                                  enable_xla=True)

    # Additional performance flags
    # TODO(b/76028325): Remove when generic layout optimizer is ready.
    flags.DEFINE_boolean(
        name='enable_grappler_layout_optimizer',
        default=True,
        help='Enable Grappler layout optimizer. Currently Grappler can '
        'de-optimize fp16 graphs by forcing NCHW layout for all '
        'convolutions and batch normalizations, and this flag allows to '
        'disable it.')

    flags_core.define_benchmark()
    flags_core.define_device(tpu=True)

    flags.DEFINE_integer(
        name='train_steps',
        short_name='ts',
        default=300000,
        help=flags_core.help_wrap('The number of steps used to train.'))
    flags.DEFINE_integer(
        name='steps_between_evals',
        short_name='sbe',
        default=1000,
        help=flags_core.help_wrap(
            'The Number of training steps to run between evaluations. This is '
            'used if --train_steps is defined.'))
    flags.DEFINE_boolean(name='enable_time_history',
                         default=True,
                         help='Whether to enable TimeHistory callback.')
    flags.DEFINE_boolean(name='enable_tensorboard',
                         default=False,
                         help='Whether to enable Tensorboard callback.')
    flags.DEFINE_boolean(name='enable_metrics_in_training',
                         default=False,
                         help='Whether to enable metrics during training.')
    flags.DEFINE_string(
        name='profile_steps',
        default=None,
        help='Save profiling data to model dir at given range of steps. The '
        'value must be a comma separated pair of positive integers, specifying '
        'the first and last step to profile. For example, "--profile_steps=2,4" '
        'triggers the profiler to process 3 steps, starting from the 2nd step. '
        'Note that profiler has a non-trivial performance overhead, and the '
        'output file can be gigantic if profiling many steps.')
    # Set flags from the flags_core module as 'key flags' so they're listed when
    # the '-h' flag is used. Without this line, the flags defined above are
    # only shown in the full `--helpful` help text.
    flags.adopt_module_key_flags(flags_core)

    # Add transformer-specific flags
    flags.DEFINE_enum(
        name='param_set',
        short_name='mp',
        default='big',
        enum_values=PARAMS_MAP.keys(),
        help=flags_core.help_wrap(
            'Parameter set to use when creating and training the model. The '
            'parameters define the input shape (batch size and max length), '
            'model configuration (size of embedding, # of hidden layers, etc.), '
            'and various other settings. The big parameter set increases the '
            'default batch size, embedding/hidden size, and filter size. For a '
            'complete list of parameters, please see model/model_params.py.'))

    flags.DEFINE_bool(
        name='static_batch',
        short_name='sb',
        default=False,
        help=flags_core.help_wrap(
            'Whether the batches in the dataset should have static shapes. In '
            'general, this setting should be False. Dynamic shapes allow the '
            'inputs to be grouped so that the number of padding tokens is '
            'minimized, and helps model training. In cases where the input shape '
            'must be static (e.g. running on TPU), this setting will be ignored '
            'and static batching will always be used.'))
    flags.DEFINE_integer(
        name='max_length',
        short_name='ml',
        default=256,
        help=flags_core.help_wrap(
            'Max sentence length for Transformer. Default is 256. Note: Usually '
            'it is more effective to use a smaller max length if static_batch is '
            'enabled, e.g. 64.'))

    # Flags for training with steps (may be used for debugging)
    flags.DEFINE_integer(
        name='validation_steps',
        short_name='vs',
        default=64,
        help=flags_core.help_wrap('The number of steps used in validation.'))

    # BLEU score computation
    flags.DEFINE_string(
        name='bleu_source',
        short_name='bls',
        default=None,
        help=flags_core.help_wrap(
            'Path to source file containing text translate when calculating the '
            'official BLEU score. Both --bleu_source and --bleu_ref must be set. '
            'Use the flag --stop_threshold to stop the script based on the '
            'uncased BLEU score.'))
    # The help text used to be a verbatim copy of --bleu_source; it now
    # describes the reference file so the two flags can be told apart in -h.
    flags.DEFINE_string(
        name='bleu_ref',
        short_name='blr',
        default=None,
        help=flags_core.help_wrap(
            'Path to reference file containing the reference translations to '
            'compare against when calculating the official BLEU score. Both '
            '--bleu_source and --bleu_ref must be set. Use the flag '
            '--stop_threshold to stop the script based on the uncased BLEU '
            'score.'))
    flags.DEFINE_string(
        name='vocab_file',
        short_name='vf',
        default=None,
        help=flags_core.help_wrap(
            'Path to subtoken vocabulary file. If data_download.py was used to '
            'download and encode the training data, look in the data_dir to find '
            'the vocab file.'))
    flags.DEFINE_string(
        name='mode',
        default='train',
        help=flags_core.help_wrap('mode: train, eval, or predict'))

    # Override the defaults of flags that were declared by flags_core above.
    flags_core.set_defaults(data_dir='/tmp/translate_ende',
                            model_dir='/tmp/transformer_model',
                            batch_size=None,
                            train_epochs=10)

    # The validators below are registered through their decorators and are
    # never referenced by name, hence the pylint suppression.
    # pylint: disable=unused-variable
    @flags.multi_flags_validator(
        ['mode', 'train_epochs'],
        message='--train_epochs must be defined in train mode')
    def _check_train_limits(flag_dict):
        """Require --train_epochs whenever --mode is 'train'."""
        if flag_dict['mode'] == 'train':
            return flag_dict['train_epochs'] is not None
        return True

    @flags.multi_flags_validator(
        ['bleu_source', 'bleu_ref'],
        message='Both or neither --bleu_source and --bleu_ref must be defined.'
    )
    def _check_bleu_files(flags_dict):
        """Accept only when both BLEU files are set or both are unset."""
        return (flags_dict['bleu_source'] is None) == (flags_dict['bleu_ref']
                                                       is None)

    @flags.multi_flags_validator(
        ['bleu_source', 'bleu_ref', 'vocab_file'],
        message='--vocab_file must be defined if --bleu_source and --bleu_ref '
        'are defined.')
    def _check_bleu_vocab_file(flags_dict):
        """Require --vocab_file when both BLEU files are provided."""
        if flags_dict['bleu_source'] and flags_dict['bleu_ref']:
            return flags_dict['vocab_file'] is not None
        return True

    @flags.multi_flags_validator(
        ['export_dir', 'vocab_file'],
        message='--vocab_file must be defined if --export_dir is set.')
    def _check_export_vocab_file(flags_dict):
        """Require --vocab_file when --export_dir is set."""
        if flags_dict['export_dir']:
            return flags_dict['vocab_file'] is not None
        return True

    # pylint: enable=unused-variable

    # NOTE(review): presumably adds a check that these paths live on cloud
    # storage in some configurations (e.g. TPU) -- confirm in flags_core.
    flags_core.require_cloud_storage(['data_dir', 'model_dir', 'export_dir'])
def define_transformer_flags():
    """Add flags and flag validators for running transformer_main.

    The function registers, in order:
      1. ``--max_length``.
      2. The shared flag groups provided by ``flags_core`` (base, performance,
         benchmark and device).
      3. Transformer-specific flags (param_set, static_batch, train_steps,
         steps_between_evals, BLEU files, vocab_file,
         save_checkpoints_steps).
      4. Transformer defaults for some of the shared flags.
      5. Cross-flag validators:
           * --train_epochs and --train_steps may not both be set.
           * --bleu_source and --bleu_ref must be both set or both unset.
           * --vocab_file is required when both BLEU files are given.
           * --vocab_file is required when --export_dir is set.

    The function only has side effects on the global absl flag registry and
    returns None.
    """
    flags.DEFINE_integer(name="max_length",
                         short_name="ml",
                         default=None,
                         help=flags_core.help_wrap("Max length."))

    # Add common flags (data_dir, model_dir, train_epochs, etc.).
    flags_core.define_base(clean=True,
                           train_epochs=True,
                           epochs_between_evals=True,
                           stop_threshold=True,
                           num_gpu=True,
                           hooks=True,
                           export_dir=True,
                           distribution_strategy=True)
    flags_core.define_performance(num_parallel_calls=True,
                                  inter_op=False,
                                  intra_op=False,
                                  synthetic_data=True,
                                  max_train_steps=False,
                                  dtype=True,
                                  all_reduce_alg=True)
    flags_core.define_benchmark()
    flags_core.define_device(tpu=True)

    # Set flags from the flags_core module as "key flags" so they're listed when
    # the '-h' flag is used. Without this line, the flags defined above are
    # only shown in the full `--helpful` help text.
    flags.adopt_module_key_flags(flags_core)

    # Add transformer-specific flags
    flags.DEFINE_enum(
        name="param_set",
        short_name="mp",
        default="big",
        enum_values=PARAMS_MAP.keys(),
        help=flags_core.help_wrap(
            "Parameter set to use when creating and training the model. The "
            "parameters define the input shape (batch size and max length), "
            "model configuration (size of embedding, # of hidden layers, etc.), "
            "and various other settings. The big parameter set increases the "
            "default batch size, embedding/hidden size, and filter size. For a "
            "complete list of parameters, please see model/model_params.py."))

    flags.DEFINE_bool(
        name="static_batch",
        default=False,
        help=flags_core.help_wrap(
            "Whether the batches in the dataset should have static shapes. In "
            "general, this setting should be False. Dynamic shapes allow the "
            "inputs to be grouped so that the number of padding tokens is "
            "minimized, and helps model training. In cases where the input shape "
            "must be static (e.g. running on TPU), this setting will be ignored "
            "and static batching will always be used."))

    # Flags for training with steps (may be used for debugging)
    flags.DEFINE_integer(
        name="train_steps",
        short_name="ts",
        default=None,
        help=flags_core.help_wrap("The number of steps used to train."))
    flags.DEFINE_integer(
        name="steps_between_evals",
        short_name="sbe",
        default=1000,
        help=flags_core.help_wrap(
            "The Number of training steps to run between evaluations. This is "
            "used if --train_steps is defined."))

    # BLEU score computation
    flags.DEFINE_string(
        name="bleu_source",
        short_name="bls",
        default=None,
        help=flags_core.help_wrap(
            "Path to source file containing text translate when calculating the "
            "official BLEU score. Both --bleu_source and --bleu_ref must be set. "
            "Use the flag --stop_threshold to stop the script based on the "
            "uncased BLEU score."))
    # The help text used to be a verbatim copy of --bleu_source; it now
    # describes the reference file so the two flags can be told apart in -h.
    flags.DEFINE_string(
        name="bleu_ref",
        short_name="blr",
        default=None,
        help=flags_core.help_wrap(
            "Path to reference file containing the reference translations to "
            "compare against when calculating the official BLEU score. Both "
            "--bleu_source and --bleu_ref must be set. Use the flag "
            "--stop_threshold to stop the script based on the uncased BLEU "
            "score."))
    flags.DEFINE_string(
        name="vocab_file",
        short_name="vf",
        default=None,
        help=flags_core.help_wrap(
            "Path to subtoken vocabulary file. If data_download.py was used to "
            "download and encode the training data, look in the data_dir to find "
            "the vocab file."))
    # The help text was previously the stray fragment "the vocab file." left
    # over from the flag above; it now describes this flag.
    flags.DEFINE_integer(name="save_checkpoints_steps",
                         short_name="scs",
                         default=50000,
                         help=flags_core.help_wrap(
                             "Number of training steps between checkpoint "
                             "saves."))

    # Override the defaults of flags that were declared by flags_core above.
    flags_core.set_defaults(data_dir="/tmp/translate_ende",
                            model_dir="/tmp/transformer_model",
                            batch_size=None,
                            train_epochs=None)

    # The validators below are registered through their decorators; they are
    # never called by name.
    @flags.multi_flags_validator(
        ["train_epochs", "train_steps"],
        message=
        "Both --train_steps and --train_epochs were set. Only one may be "
        "defined.")
    def _check_train_limits(flag_dict):
        """Reject the combination of --train_epochs and --train_steps."""
        return flag_dict["train_epochs"] is None or flag_dict[
            "train_steps"] is None

    @flags.multi_flags_validator(
        ["bleu_source", "bleu_ref"],
        message="Both or neither --bleu_source and --bleu_ref must be defined."
    )
    def _check_bleu_files(flags_dict):
        """Accept only when both BLEU files are set or both are unset."""
        return (flags_dict["bleu_source"] is None) == (flags_dict["bleu_ref"]
                                                       is None)

    @flags.multi_flags_validator(
        ["bleu_source", "bleu_ref", "vocab_file"],
        message="--vocab_file must be defined if --bleu_source and --bleu_ref "
        "are defined.")
    def _check_bleu_vocab_file(flags_dict):
        """Require --vocab_file when both BLEU files are provided."""
        if flags_dict["bleu_source"] and flags_dict["bleu_ref"]:
            return flags_dict["vocab_file"] is not None
        return True

    @flags.multi_flags_validator(
        ["export_dir", "vocab_file"],
        message="--vocab_file must be defined if --export_dir is set.")
    def _check_export_vocab_file(flags_dict):
        """Require --vocab_file when --export_dir is set."""
        if flags_dict["export_dir"]:
            return flags_dict["vocab_file"] is not None
        return True

    # NOTE(review): presumably adds a check that these paths live on cloud
    # storage in some configurations (e.g. TPU) -- confirm in flags_core.
    flags_core.require_cloud_storage(["data_dir", "model_dir", "export_dir"])
コード例 #3
0
def define_transformer_flags():
  """Add flags and flag validators for running transformer_main.

  The function registers, in order:
    1. The shared flag groups provided by ``flags_core`` (base, performance,
       benchmark and device).
    2. Transformer-specific flags (param_set, static_batch, train_steps,
       steps_between_evals).
    3. Threading, learning-rate, checkpoint/logging, dropout and optimizer
       override flags.
    4. BLEU-related flags (bleu_source, bleu_ref, vocab_file).
    5. Transformer defaults for some of the shared flags.
    6. Cross-flag validators:
         * --train_epochs and --train_steps may not both be set.
         * --bleu_source and --bleu_ref must be both set or both unset.
         * --vocab_file is required when both BLEU files are given.
         * --vocab_file is required when --export_dir is set.

  The function only has side effects on the global absl flag registry and
  returns None.
  """
  # Add common flags (data_dir, model_dir, train_epochs, etc.).
  flags_core.define_base()
  flags_core.define_performance(
      num_parallel_calls=True,
      inter_op=True,
      intra_op=True,
      synthetic_data=True,
      max_train_steps=False,
      dtype=False,
      all_reduce_alg=False
  )
  flags_core.define_benchmark()
  flags_core.define_device(tpu=True)

  # Set flags from the flags_core module as "key flags" so they're listed when
  # the '-h' flag is used. Without this line, the flags defined above are
  # only shown in the full `--helpful` help text.
  flags.adopt_module_key_flags(flags_core)

  # Add transformer-specific flags
  flags.DEFINE_enum(
      name="param_set", short_name="mp", default="big",
      enum_values=PARAMS_MAP.keys(),
      help=flags_core.help_wrap(
          "Parameter set to use when creating and training the model. The "
          "parameters define the input shape (batch size and max length), "
          "model configuration (size of embedding, # of hidden layers, etc.), "
          "and various other settings. The big parameter set increases the "
          "default batch size, embedding/hidden size, and filter size. For a "
          "complete list of parameters, please see model/model_params.py."))

  flags.DEFINE_bool(
      name="static_batch", default=False,
      help=flags_core.help_wrap(
          "Whether the batches in the dataset should have static shapes. In "
          "general, this setting should be False. Dynamic shapes allow the "
          "inputs to be grouped so that the number of padding tokens is "
          "minimized, and helps model training. In cases where the input shape "
          "must be static (e.g. running on TPU), this setting will be ignored "
          "and static batching will always be used."))

  # Flags for training with steps (may be used for debugging)
  flags.DEFINE_integer(
      name="train_steps", short_name="ts", default=None,
      help=flags_core.help_wrap("The number of steps used to train."))
  flags.DEFINE_integer(
      name="steps_between_evals", short_name="sbe", default=1000,
      help=flags_core.help_wrap(
          "The Number of training steps to run between evaluations. This is "
          "used if --train_steps is defined."))

  # Add intra_op and inter_op flags as arguments.
  flags.DEFINE_integer(
      name="intra_op", default=None,
      help=flags_core.help_wrap("The number of intra_op_parallelism threads"))
  flags.DEFINE_integer(
      name="inter_op", default=None,
      help=flags_core.help_wrap("The number of inter_op_parallelism threads"))

  # Flags that override the learning rate, decay, warmup, max_length and
  # related info from the params file.
  flags.DEFINE_float(
      name="learning_rate", default=2.0,
      help=flags_core.help_wrap("Learning rate"))

  # learning_rate_decay_rate is not used anywhere. Added just to be in sync
  # with the params file.
  flags.DEFINE_float(
      name="learning_rate_decay_rate", default=1.0,
      help=flags_core.help_wrap("Learning rate decay rate"))

  flags.DEFINE_integer(
      name="learning_rate_warmup_steps", default=16000,
      help=flags_core.help_wrap("Learning rate warmup steps"))

  flags.DEFINE_integer(
      name="max_length", default=256,
      help=flags_core.help_wrap("Maximum number of tokens per example"))

  flags.DEFINE_integer(
      name="vocab_size", default=33708,
      help=flags_core.help_wrap(
          "Number of tokens defined in the vocabulary file"))

  flags.DEFINE_integer(
      name="save_checkpoints_secs", default=3600,
      help=flags_core.help_wrap("Save checkpoints every mentioned seconds"))
  flags.DEFINE_integer(
      name="log_step_count_steps", default=100,
      help=flags_core.help_wrap(
          "Frequency in steps at which loss and global step/sec is logged"))

  # Learning rate decay scheme. The literals are implicitly concatenated, so
  # each one carries its own separating space; the first one previously
  # lacked it and the help read "scheme.Can be".
  flags.DEFINE_integer(
      name="lr_scheme", default=1,
      help=flags_core.help_wrap(
          "Type of learning rate decay scheme. "
          "Can be 0,1 or 2."
          " 0 - constant learning rate"
          " 1 - does noam"
          " 2 - does linear lr growth and inverse sqrt decay"))

  flags.DEFINE_float(
      name="warmup_init_lr", default=1e-07,
      help=flags_core.help_wrap("Initial learning rate for the warm up phase"))

  flags.DEFINE_float(
      name="layer_postprocess_dropout", default=0.1,
      help=flags_core.help_wrap("Dropout value"))

  # Optimizer selection and its parameters.
  flags.DEFINE_string(
      name="opt_alg", short_name="opt", default="lazyadam",
      help=flags_core.help_wrap("Optimizer algorithm to be used"))

  flags.DEFINE_float(
      name="optimizer_sgd_momentum", short_name="sgdm", default=None,
      help=flags_core.help_wrap("Value for SGD's momentum param"))

  flags.DEFINE_float(
      name="optimizer_rms_decay", short_name="rmsd", default=0.9,
      help=flags_core.help_wrap("RMSProp Decay value"))

  flags.DEFINE_float(
      name="optimizer_rms_momentum", short_name="rmsm", default=0.0,
      help=flags_core.help_wrap("RMSProp momentum value"))

  flags.DEFINE_float(
      name="optimizer_rms_epsilon", short_name="rmse", default=1e-10,
      help=flags_core.help_wrap("RMSProp epsilon value"))

  # BLEU score computation
  flags.DEFINE_string(
      name="bleu_source", short_name="bls", default=None,
      help=flags_core.help_wrap(
          "Path to source file containing text translate when calculating the "
          "official BLEU score. Both --bleu_source and --bleu_ref must be set. "
          "Use the flag --stop_threshold to stop the script based on the "
          "uncased BLEU score."))
  # The help text used to be a verbatim copy of --bleu_source; it now
  # describes the reference file so the two flags can be told apart in -h.
  flags.DEFINE_string(
      name="bleu_ref", short_name="blr", default=None,
      help=flags_core.help_wrap(
          "Path to reference file containing the reference translations to "
          "compare against when calculating the official BLEU score. Both "
          "--bleu_source and --bleu_ref must be set. Use the flag "
          "--stop_threshold to stop the script based on the uncased BLEU "
          "score."))
  flags.DEFINE_string(
      name="vocab_file", short_name="vf", default=None,
      help=flags_core.help_wrap(
          "Path to subtoken vocabulary file. If data_download.py was used to "
          "download and encode the training data, look in the data_dir to find "
          "the vocab file."))

  # Override the defaults of flags that were declared by flags_core above.
  flags_core.set_defaults(data_dir="/tmp/translate_ende",
                          model_dir="/tmp/transformer_model",
                          batch_size=None,
                          train_epochs=None)

  # The validators below are registered through their decorators; they are
  # never called by name.
  @flags.multi_flags_validator(
      ["train_epochs", "train_steps"],
      message="Both --train_steps and --train_epochs were set. Only one may be "
              "defined.")
  def _check_train_limits(flag_dict):
    """Reject the combination of --train_epochs and --train_steps."""
    return flag_dict["train_epochs"] is None or flag_dict["train_steps"] is None

  @flags.multi_flags_validator(
      ["bleu_source", "bleu_ref"],
      message="Both or neither --bleu_source and --bleu_ref must be defined.")
  def _check_bleu_files(flags_dict):
    """Accept only when both BLEU files are set or both are unset."""
    return (flags_dict["bleu_source"] is None) == (
        flags_dict["bleu_ref"] is None)

  @flags.multi_flags_validator(
      ["bleu_source", "bleu_ref", "vocab_file"],
      message="--vocab_file must be defined if --bleu_source and --bleu_ref "
              "are defined.")
  def _check_bleu_vocab_file(flags_dict):
    """Require --vocab_file when both BLEU files are provided."""
    if flags_dict["bleu_source"] and flags_dict["bleu_ref"]:
      return flags_dict["vocab_file"] is not None
    return True

  @flags.multi_flags_validator(
      ["export_dir", "vocab_file"],
      message="--vocab_file must be defined if --export_dir is set.")
  def _check_export_vocab_file(flags_dict):
    """Require --vocab_file when --export_dir is set."""
    if flags_dict["export_dir"]:
      return flags_dict["vocab_file"] is not None
    return True

  # NOTE(review): presumably adds a check that these paths live on cloud
  # storage in some configurations (e.g. TPU) -- confirm in flags_core.
  flags_core.require_cloud_storage(["data_dir", "model_dir", "export_dir"])
コード例 #4
0
def define_transformer_flags():
  """Add flags and flag validators for running transformer_main.

  The function registers, in order:
    1. The shared flag groups provided by ``flags_core`` (base, performance,
       benchmark and device).
    2. Transformer-specific flags (param_set, static_batch, train_steps,
       steps_between_evals, BLEU files, vocab_file).
    3. Transformer defaults for some of the shared flags.
    4. Cross-flag validators:
         * --train_epochs and --train_steps may not both be set.
         * --bleu_source and --bleu_ref must be both set or both unset.
         * --vocab_file is required when both BLEU files are given.
         * --vocab_file is required when --export_dir is set.

  The function only has side effects on the global absl flag registry and
  returns None.
  """
  # Add common flags (data_dir, model_dir, train_epochs, etc.).
  flags_core.define_base()
  flags_core.define_performance(
      num_parallel_calls=True,
      inter_op=False,
      intra_op=False,
      synthetic_data=True,
      max_train_steps=False,
      dtype=False,
      all_reduce_alg=True
  )
  flags_core.define_benchmark()
  flags_core.define_device(tpu=True)

  # Set flags from the flags_core module as "key flags" so they're listed when
  # the '-h' flag is used. Without this line, the flags defined above are
  # only shown in the full `--helpful` help text.
  flags.adopt_module_key_flags(flags_core)

  # Add transformer-specific flags
  flags.DEFINE_enum(
      name="param_set", short_name="mp", default="big",
      enum_values=PARAMS_MAP.keys(),
      help=flags_core.help_wrap(
          "Parameter set to use when creating and training the model. The "
          "parameters define the input shape (batch size and max length), "
          "model configuration (size of embedding, # of hidden layers, etc.), "
          "and various other settings. The big parameter set increases the "
          "default batch size, embedding/hidden size, and filter size. For a "
          "complete list of parameters, please see model/model_params.py."))

  flags.DEFINE_bool(
      name="static_batch", default=False,
      help=flags_core.help_wrap(
          "Whether the batches in the dataset should have static shapes. In "
          "general, this setting should be False. Dynamic shapes allow the "
          "inputs to be grouped so that the number of padding tokens is "
          "minimized, and helps model training. In cases where the input shape "
          "must be static (e.g. running on TPU), this setting will be ignored "
          "and static batching will always be used."))

  # Flags for training with steps (may be used for debugging)
  flags.DEFINE_integer(
      name="train_steps", short_name="ts", default=None,
      help=flags_core.help_wrap("The number of steps used to train."))
  flags.DEFINE_integer(
      name="steps_between_evals", short_name="sbe", default=1000,
      help=flags_core.help_wrap(
          "The Number of training steps to run between evaluations. This is "
          "used if --train_steps is defined."))

  # BLEU score computation
  flags.DEFINE_string(
      name="bleu_source", short_name="bls", default=None,
      help=flags_core.help_wrap(
          "Path to source file containing text translate when calculating the "
          "official BLEU score. Both --bleu_source and --bleu_ref must be set. "
          "Use the flag --stop_threshold to stop the script based on the "
          "uncased BLEU score."))
  # The help text used to be a verbatim copy of --bleu_source; it now
  # describes the reference file so the two flags can be told apart in -h.
  flags.DEFINE_string(
      name="bleu_ref", short_name="blr", default=None,
      help=flags_core.help_wrap(
          "Path to reference file containing the reference translations to "
          "compare against when calculating the official BLEU score. Both "
          "--bleu_source and --bleu_ref must be set. Use the flag "
          "--stop_threshold to stop the script based on the uncased BLEU "
          "score."))
  flags.DEFINE_string(
      name="vocab_file", short_name="vf", default=None,
      help=flags_core.help_wrap(
          "Path to subtoken vocabulary file. If data_download.py was used to "
          "download and encode the training data, look in the data_dir to find "
          "the vocab file."))

  # Override the defaults of flags that were declared by flags_core above.
  flags_core.set_defaults(data_dir="/tmp/translate_ende",
                          model_dir="/tmp/transformer_model",
                          batch_size=None,
                          train_epochs=None)

  # The validators below are registered through their decorators; they are
  # never called by name.
  @flags.multi_flags_validator(
      ["train_epochs", "train_steps"],
      message="Both --train_steps and --train_epochs were set. Only one may be "
              "defined.")
  def _check_train_limits(flag_dict):
    """Reject the combination of --train_epochs and --train_steps."""
    return flag_dict["train_epochs"] is None or flag_dict["train_steps"] is None

  @flags.multi_flags_validator(
      ["bleu_source", "bleu_ref"],
      message="Both or neither --bleu_source and --bleu_ref must be defined.")
  def _check_bleu_files(flags_dict):
    """Accept only when both BLEU files are set or both are unset."""
    return (flags_dict["bleu_source"] is None) == (
        flags_dict["bleu_ref"] is None)

  @flags.multi_flags_validator(
      ["bleu_source", "bleu_ref", "vocab_file"],
      message="--vocab_file must be defined if --bleu_source and --bleu_ref "
              "are defined.")
  def _check_bleu_vocab_file(flags_dict):
    """Require --vocab_file when both BLEU files are provided."""
    if flags_dict["bleu_source"] and flags_dict["bleu_ref"]:
      return flags_dict["vocab_file"] is not None
    return True

  @flags.multi_flags_validator(
      ["export_dir", "vocab_file"],
      message="--vocab_file must be defined if --export_dir is set.")
  def _check_export_vocab_file(flags_dict):
    """Require --vocab_file when --export_dir is set."""
    if flags_dict["export_dir"]:
      return flags_dict["vocab_file"] is not None
    return True

  # NOTE(review): presumably adds a check that these paths live on cloud
  # storage in some configurations (e.g. TPU) -- confirm in flags_core.
  flags_core.require_cloud_storage(["data_dir", "model_dir", "export_dir"])
コード例 #5
0
def define_transformer_flags():
    """Add flags and flag validators for running transformer_main.

    The function registers, in order:
      1. The shared flag groups provided by ``flags_core`` (base, performance,
         benchmark and device). ``define_base`` is called with ``multi_gpu``,
         ``num_gpu`` and ``export_dir`` set to False.
      2. Transformer-specific flags (param_set, static_batch, train_steps,
         steps_between_evals, BLEU files, vocab_file).
      3. Transformer defaults for some of the shared flags.
      4. Cross-flag validators:
           * --train_epochs and --train_steps may not both be set.
           * When both --bleu_source and --bleu_ref are set, those two files
             and <data_dir>/<vocab_file> must exist.

    The function only has side effects on the global absl flag registry and
    returns None.
    """
    # Add common flags (data_dir, model_dir, train_epochs, etc.).
    flags_core.define_base(multi_gpu=False, num_gpu=False, export_dir=False)
    flags_core.define_performance(num_parallel_calls=True,
                                  inter_op=False,
                                  intra_op=False,
                                  synthetic_data=False,
                                  max_train_steps=False,
                                  dtype=False)
    flags_core.define_benchmark()
    flags_core.define_device(tpu=True)

    # Set flags from the flags_core module as "key flags" so they're listed when
    # the '-h' flag is used. Without this line, the flags defined above are
    # only shown in the full `--helpful` help text.
    flags.adopt_module_key_flags(flags_core)

    # Add transformer-specific flags
    flags.DEFINE_enum(
        name="param_set",
        short_name="mp",
        default="big",
        enum_values=["base", "big", "tiny"],
        help=flags_core.help_wrap(
            "Parameter set to use when creating and training the model. The "
            "parameters define the input shape (batch size and max length), "
            "model configuration (size of embedding, # of hidden layers, etc.), "
            "and various other settings. The big parameter set increases the "
            "default batch size, embedding/hidden size, and filter size. For a "
            "complete list of parameters, please see model/model_params.py."))

    flags.DEFINE_bool(
        name="static_batch",
        default=False,
        help=flags_core.help_wrap(
            "Whether the batches in the dataset should have static shapes. In "
            "general, this setting should be False. Dynamic shapes allow the "
            "inputs to be grouped so that the number of padding tokens is "
            "minimized, and helps model training. In cases where the input shape "
            "must be static (e.g. running on TPU), this setting will be ignored "
            "and static batching will always be used."))

    # Flags for training with steps (may be used for debugging)
    flags.DEFINE_integer(
        name="train_steps",
        short_name="ts",
        default=None,
        help=flags_core.help_wrap("The number of steps used to train."))
    flags.DEFINE_integer(
        name="steps_between_evals",
        short_name="sbe",
        default=1000,
        help=flags_core.help_wrap(
            "The Number of training steps to run between evaluations. This is "
            "used if --train_steps is defined."))

    # BLEU score computation
    flags.DEFINE_string(
        name="bleu_source",
        short_name="bls",
        default=None,
        help=flags_core.help_wrap(
            "Path to source file containing text translate when calculating the "
            "official BLEU score. --bleu_source, --bleu_ref, and --vocab_file "
            "must be set. Use the flag --stop_threshold to stop the script based "
            "on the uncased BLEU score."))
    # The help text used to be a verbatim copy of --bleu_source; it now
    # describes the reference file so the two flags can be told apart in -h.
    flags.DEFINE_string(
        name="bleu_ref",
        short_name="blr",
        default=None,
        help=flags_core.help_wrap(
            "Path to reference file containing the reference translations to "
            "compare against when calculating the official BLEU score. "
            "--bleu_source, --bleu_ref, and --vocab_file must be set. Use the "
            "flag --stop_threshold to stop the script based on the uncased "
            "BLEU score."))
    flags.DEFINE_string(
        name="vocab_file",
        short_name="vf",
        default=VOCAB_FILE,
        help=flags_core.help_wrap(
            "Name of vocabulary file containing subtokens for subtokenizing the "
            "bleu_source file. This file is expected to be in the directory "
            "defined by --data_dir."))

    # Override the defaults of flags that were declared by flags_core above.
    flags_core.set_defaults(data_dir="/tmp/translate_ende",
                            model_dir="/tmp/transformer_model",
                            batch_size=None,
                            train_epochs=None)

    # The validators below are registered through their decorators; they are
    # never called by name.
    @flags.multi_flags_validator(
        ["train_epochs", "train_steps"],
        message=
        "Both --train_steps and --train_epochs were set. Only one may be "
        "defined.")
    def _check_train_limits(flag_dict):
        """Reject the combination of --train_epochs and --train_steps."""
        return flag_dict["train_epochs"] is None or flag_dict[
            "train_steps"] is None

    @flags.multi_flags_validator(
        ["data_dir", "bleu_source", "bleu_ref", "vocab_file"],
        message="--bleu_source, --bleu_ref, and/or --vocab_file don't exist. "
        "Please ensure that the file paths are correct.")
    def _check_bleu_files(flags_dict):
        """Validate files when bleu_source and bleu_ref are defined."""
        if flags_dict["bleu_source"] is None or flags_dict["bleu_ref"] is None:
            return True
        # The vocab file is looked up relative to --data_dir.
        vocab_file_path = os.path.join(flags_dict["data_dir"],
                                       flags_dict["vocab_file"])
        required_paths = (flags_dict["bleu_source"], flags_dict["bleu_ref"],
                          vocab_file_path)
        # A generator lets all() stop at the first missing file instead of
        # issuing every existence check up front.
        return all(tf.gfile.Exists(path) for path in required_paths)

    # NOTE(review): presumably adds a check that these paths live on cloud
    # storage in some configurations (e.g. TPU) -- confirm in flags_core.
    flags_core.require_cloud_storage(["data_dir", "model_dir"])