def define_common_squad_flags():
  """Defines common flags used by SQuAD tasks.

  Registers the run mode, data-path, batch-size and prediction-related flags
  shared by SQuAD train/eval/predict entry points, then delegates to
  `common_flags.define_common_bert_flags()` for the flags common to all BERT
  tasks. Must be called before flag parsing.
  """
  flags.DEFINE_enum(
      'mode', 'train_and_eval', [
          'train_and_eval', 'train_and_predict', 'train', 'eval', 'predict',
          'export_only'
      ],
      'One of {"train_and_eval", "train_and_predict", '
      '"train", "eval", "predict", "export_only"}. '
      '`train_and_eval`: train & predict to json files & compute eval metrics. '
      '`train_and_predict`: train & predict to json files. '
      '`train`: only trains the model. '
      '`eval`: predict answers from squad json file & compute eval metrics. '
      '`predict`: predict answers from the squad json file. '
      '`export_only`: will take the latest checkpoint inside '
      'model_dir and export a `SavedModel`.')
  flags.DEFINE_string('train_data_path', '',
                      'Training data path with train tfrecords.')
  flags.DEFINE_string(
      'input_meta_data_path', None,
      'Path to file that contains meta data about input '
      'to be used for training and evaluation.')

  # Model training specific flags.
  flags.DEFINE_integer('train_batch_size', 32, 'Total batch size for training.')

  # Predict processing related.
  flags.DEFINE_string(
      'predict_file', None,
      'SQuAD prediction json file path. '
      '`predict` mode supports multiple files: one can use '
      'wildcard to specify multiple files and it can also be '
      'multiple file patterns separated by comma. Note that '
      '`eval` mode only supports a single predict file.')
  flags.DEFINE_bool(
      'do_lower_case', True,
      'Whether to lower case the input text. Should be True for uncased '
      'models and False for cased models.')
  flags.DEFINE_float(
      'null_score_diff_threshold', 0.0,
      'If null_score - best_non_null is greater than the threshold, '
      'predict null. This is only used for SQuAD v2.')
  # NOTE: this help string was broken across physical lines in the garbled
  # source; rejoined here so the literal is valid Python again.
  flags.DEFINE_bool(
      'verbose_logging', False,
      'If true, all of the warnings related to data processing will be '
      'printed. A number of warnings are expected for a normal SQuAD '
      'evaluation.')
  flags.DEFINE_integer('predict_batch_size', 8,
                       'Total batch size for prediction.')
  flags.DEFINE_integer(
      'n_best_size', 20,
      'The total number of n-best predictions to generate in the '
      'nbest_predictions.json output file.')
  flags.DEFINE_integer(
      'max_answer_length', 30,
      'The maximum length of an answer that can be generated. This is needed '
      'because the start and end predictions are not conditioned on one '
      'another.')

  # Flags shared by all BERT tasks (model dir, learning rate, etc.).
  common_flags.define_common_bert_flags()
# NOTE(review): SOURCE is truncated here — the string fragment below is the
# tail of a flag help message (likely for a sequence-length flag) whose
# flags.DEFINE_* call starts outside this view.
'than this will be padded.')
flags.DEFINE_integer('max_predictions_per_seq', 20,
                     'Maximum predictions per sequence_output.')
flags.DEFINE_integer('train_batch_size', 32, 'Total batch size for training.')
flags.DEFINE_integer('num_steps_per_epoch', 1000,
                     'Total number of training steps to run per epoch.')
flags.DEFINE_float('warmup_steps', 10000,
                   'Warmup steps for Adam weight decay optimizer.')
flags.DEFINE_bool('use_next_sentence_label', True,
                  'Whether to use next sentence label to compute final loss.')
# NOTE(review): `train_summary_interval` holds a step interval — the default
# is the integer 0 and the help text talks about negative values — yet it is
# registered with DEFINE_bool. This looks like it should be
# flags.DEFINE_integer; confirm against callers before changing.
flags.DEFINE_bool(
    'train_summary_interval', 0, 'Step interval for training '
    'summaries. If the value is a negative number, '
    'then training summaries are not enabled.')
# Flags shared by all BERT tasks (model dir, learning rate, etc.).
common_flags.define_common_bert_flags()

FLAGS = flags.FLAGS


def get_pretrain_dataset_fn(input_file_pattern, seq_length,
                            max_predictions_per_seq, global_batch_size,
                            use_next_sentence_label=True):
  """Returns input dataset from input file string.

  Args:
    input_file_pattern: Comma-separated list of input file patterns (the
      function splits on ',' before building the dataset).
    seq_length: Maximum sequence length — presumably forwarded to the input
      pipeline; its use is outside this view, TODO confirm.
    max_predictions_per_seq: Maximum masked-LM predictions per sequence —
      presumably forwarded to the input pipeline; TODO confirm.
    global_batch_size: Global (across-replica) batch size; converted to a
      per-replica batch size inside `_dataset_fn`.
    use_next_sentence_label: Whether the dataset includes the
      next-sentence-prediction label.

  Returns:
    A callable `_dataset_fn(ctx)` suitable for
    `tf.distribute` input functions.
  """

  def _dataset_fn(ctx=None):
    """Returns tf.data.Dataset for distributed BERT pretraining."""
    input_patterns = input_file_pattern.split(',')
    # `ctx` is a tf.distribute.InputContext; note this dereferences `ctx`
    # unconditionally, so the default of None would raise here.
    batch_size = ctx.get_per_replica_batch_size(global_batch_size)
    # NOTE(review): SOURCE is truncated mid-call below — the remaining
    # arguments to create_pretrain_dataset() are outside this view.
    train_dataset = input_pipeline.create_pretrain_dataset(