Example #1
from absl import flags

# Assumption: `common_flags` comes from the surrounding BERT package
# (official/nlp/bert in tensorflow/models); adjust the path to your checkout.
from official.nlp.bert import common_flags


def define_common_squad_flags():
    """Defines common flags used by SQuAD tasks."""
    flags.DEFINE_enum(
        'mode', 'train_and_eval', [
            'train_and_eval', 'train_and_predict', 'train', 'eval', 'predict',
            'export_only'
        ], 'One of {"train_and_eval", "train_and_predict", '
        '"train", "eval", "predict", "export_only"}. '
        '`train_and_eval`: train & predict to json files & compute eval metrics. '
        '`train_and_predict`: train & predict to json files. '
        '`train`: only trains the model. '
        '`eval`: predict answers from squad json file & compute eval metrics. '
        '`predict`: predict answers from the squad json file. '
        '`export_only`: will take the latest checkpoint inside '
        'model_dir and export a `SavedModel`.')
    flags.DEFINE_string('train_data_path', '',
                        'Training data path with train tfrecords.')
    flags.DEFINE_string(
        'input_meta_data_path', None,
        'Path to file that contains meta data about input '
        'to be used for training and evaluation.')
    # Model training specific flags.
    flags.DEFINE_integer('train_batch_size', 32,
                         'Total batch size for training.')
    # Predict processing related.
    flags.DEFINE_string(
        'predict_file', None, 'SQuAD prediction json file path. '
        '`predict` mode supports multiple files: one can use '
        'wildcard to specify multiple files and it can also be '
        'multiple file patterns separated by comma. Note that '
        '`eval` mode only supports a single predict file.')
    flags.DEFINE_bool(
        'do_lower_case', True,
        'Whether to lower case the input text. Should be True for uncased '
        'models and False for cased models.')
    flags.DEFINE_float(
        'null_score_diff_threshold', 0.0,
        'If null_score - best_non_null is greater than the threshold, '
        'predict null. This is only used for SQuAD v2.')
    flags.DEFINE_bool(
        'verbose_logging', False,
        'If true, all of the warnings related to data processing will be '
        'printed. A number of warnings are expected for a normal SQuAD '
        'evaluation.')
    flags.DEFINE_integer('predict_batch_size', 8,
                         'Total batch size for prediction.')
    flags.DEFINE_integer(
        'n_best_size', 20,
        'The total number of n-best predictions to generate in the '
        'nbest_predictions.json output file.')
    flags.DEFINE_integer(
        'max_answer_length', 30,
        'The maximum length of an answer that can be generated. This is needed '
        'because the start and end predictions are not conditioned on one '
        'another.')

    common_flags.define_common_bert_flags()
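
A minimal usage sketch, not part of the original example: with absl-py
installed, calling define_common_squad_flags() before app.run() makes the
flags above available on FLAGS once argv has been parsed. The flags printed
below are arbitrary picks for illustration.

from absl import app
from absl import flags

FLAGS = flags.FLAGS


def main(_):
    # Flag values are populated after app.run() parses sys.argv.
    print(FLAGS.mode, FLAGS.train_batch_size, FLAGS.n_best_size)


if __name__ == '__main__':
    define_common_squad_flags()
    app.run(main)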
Example #2
from absl import flags

# Assumption: `common_flags` and `input_pipeline` come from the surrounding
# BERT package (official/nlp/bert in tensorflow/models).
from official.nlp.bert import common_flags
from official.nlp.bert import input_pipeline

flags.DEFINE_integer(
    'max_seq_length', 128,
    'The maximum total input sequence length after WordPiece tokenization. '
    'Sequences longer than this will be truncated, and sequences shorter '
    'than this will be padded.')
flags.DEFINE_integer('max_predictions_per_seq', 20,
                     'Maximum predictions per sequence_output.')
flags.DEFINE_integer('train_batch_size', 32, 'Total batch size for training.')
flags.DEFINE_integer('num_steps_per_epoch', 1000,
                     'Total number of training steps to run per epoch.')
flags.DEFINE_float('warmup_steps', 10000,
                   'Warmup steps for Adam weight decay optimizer.')
flags.DEFINE_bool('use_next_sentence_label', True,
                  'Whether to use next sentence label to compute final loss.')
flags.DEFINE_integer(
    'train_summary_interval', 0, 'Step interval for training '
    'summaries. If the value is a negative number, '
    'then training summaries are not enabled.')

common_flags.define_common_bert_flags()

FLAGS = flags.FLAGS


def get_pretrain_dataset_fn(input_file_pattern,
                            seq_length,
                            max_predictions_per_seq,
                            global_batch_size,
                            use_next_sentence_label=True):
    """Returns input dataset from input file string."""
    def _dataset_fn(ctx=None):
        """Returns tf.data.Dataset for distributed BERT pretraining."""
        input_patterns = input_file_pattern.split(',')
        batch_size = ctx.get_per_replica_batch_size(global_batch_size)
        train_dataset = input_pipeline.create_pretrain_dataset(
            input_patterns,
            seq_length,
            max_predictions_per_seq,
            batch_size,
            is_training=True,
            input_pipeline_context=ctx,
            use_next_sentence_label=use_next_sentence_label)
        return train_dataset

    return _dataset_fn
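
A hedged usage sketch (the file pattern and shape values are illustrative
placeholders, not from the example): under a tf.distribute strategy, the
returned _dataset_fn is typically handed to
strategy.distribute_datasets_from_function, which invokes it once per
replica with an InputContext, matching the `ctx` parameter above.

import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()

# Illustrative values; a real run would read these from the flags above.
dataset_fn = get_pretrain_dataset_fn(
    input_file_pattern='/tmp/pretrain/tf_examples-*.tfrecord',
    seq_length=128,
    max_predictions_per_seq=20,
    global_batch_size=32)

# Each replica receives a per-replica shard built by _dataset_fn(ctx).
train_ds = strategy.distribute_datasets_from_function(dataset_fn)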