def define_common_squad_flags():
  """Registers the command-line flags shared by the SQuAD tasks.

  The flags cover the run mode, the input data locations, the training batch
  size and the options that control prediction post-processing. After those
  are registered, `common_flags.define_common_bert_flags()` and
  `common_flags.define_gin_flags()` are invoked, in that order.
  """
  # Run modes accepted by the `mode` flag; the help text below spells out what
  # each of them does.
  supported_modes = [
      'train_and_eval',
      'train_and_predict',
      'train',
      'eval',
      'predict',
      'export_only',
  ]
  mode_help = (
      'One of {"train_and_eval", "train_and_predict", '
      '"train", "eval", "predict", "export_only"}. '
      '`train_and_eval`: train & predict to json files & compute eval metrics. '
      '`train_and_predict`: train & predict to json files. '
      '`train`: only trains the model. '
      '`eval`: predict answers from squad json file & compute eval metrics. '
      '`predict`: predict answers from the squad json file. '
      '`export_only`: will take the latest checkpoint inside '
      'model_dir and export a `SavedModel`.')
  flags.DEFINE_enum(
      name='mode',
      default='train_and_eval',
      enum_values=supported_modes,
      help=mode_help)

  # Input locations.
  flags.DEFINE_string(
      name='train_data_path',
      default='',
      help='Training data path with train tfrecords.')
  flags.DEFINE_string(
      name='input_meta_data_path',
      default=None,
      help=('Path to file that contains meta data about input '
            'to be used for training and evaluation.'))

  # Training-specific settings.
  flags.DEFINE_integer(
      name='train_batch_size',
      default=32,
      help='Total batch size for training.')

  # Prediction and answer post-processing settings.
  flags.DEFINE_string(
      name='predict_file',
      default=None,
      help='Prediction data path with train tfrecords.')
  flags.DEFINE_bool(
      name='do_lower_case',
      default=True,
      help=('Whether to lower case the input text. Should be True for uncased '
            'models and False for cased models.'))
  flags.DEFINE_float(
      name='null_score_diff_threshold',
      default=0.0,
      help=('If null_score - best_non_null is greater than the threshold, '
            'predict null. This is only used for SQuAD v2.'))
  flags.DEFINE_bool(
      name='verbose_logging',
      default=False,
      help=('If true, all of the warnings related to data processing will be '
            'printed. A number of warnings are expected for a normal SQuAD '
            'evaluation.'))
  flags.DEFINE_integer(
      name='predict_batch_size',
      default=8,
      help='Total batch size for prediction.')
  flags.DEFINE_integer(
      name='n_best_size',
      default=20,
      help=('The total number of n-best predictions to generate in the '
            'nbest_predictions.json output file.'))
  flags.DEFINE_integer(
      name='max_answer_length',
      default=30,
      help=('The maximum length of an answer that can be generated. This is needed '
            'because the start and end predictions are not conditioned on one '
            'another.'))

  # Flags owned by the shared `common_flags` module.
  common_flags.define_common_bert_flags()
  common_flags.define_gin_flags()
# Model training specific flags. flags.DEFINE_integer( 'max_seq_length', 128, 'The maximum total input sequence length after WordPiece tokenization. ' 'Sequences longer than this will be truncated, and sequences shorter ' 'than this will be padded.') flags.DEFINE_integer('max_predictions_per_seq', 20, 'Maximum predictions per sequence_output.') flags.DEFINE_integer('train_batch_size', 32, 'Total batch size for training.') flags.DEFINE_integer('num_steps_per_epoch', 1000, 'Total number of training steps to run per epoch.') flags.DEFINE_float('warmup_steps', 10000, 'Warmup steps for Adam weight decay optimizer.') common_flags.define_common_bert_flags() common_flags.define_gin_flags() FLAGS = flags.FLAGS def get_pretrain_dataset_fn(input_file_pattern, seq_length, max_predictions_per_seq, global_batch_size): """Returns input dataset from input file string.""" def _dataset_fn(ctx=None): """Returns tf.data.Dataset for distributed BERT pretraining.""" input_patterns = input_file_pattern.split(',') batch_size = ctx.get_per_replica_batch_size(global_batch_size) train_dataset = input_pipeline.create_pretrain_dataset( input_patterns, seq_length, max_predictions_per_seq,