def load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=False):
    """Load SQuAD features from a cache file, or build them and cache them.

    The cache file lives in ``args.data_dir`` (or the current directory when
    that is empty) and its name is built from the split (``dev``/``train``),
    the last non-empty path component of ``args.model_name_or_path`` and
    ``args.max_seq_length``.

    Args:
        args: Namespace providing ``local_rank``, ``data_dir``,
            ``model_name_or_path``, ``max_seq_length``, ``overwrite_cache``,
            ``predict_file``, ``train_file``, ``version_2_with_negative``,
            ``doc_stride``, ``max_query_length`` and ``threads``.
        tokenizer: Tokenizer forwarded to
            ``squad_convert_examples_to_features``.
        evaluate: Load the dev split when true, the train split otherwise.
        output_examples: Also return the examples and the features.

    Returns:
        ``dataset`` alone, or ``(dataset, examples, features)`` when
        ``output_examples`` is true.

    Raises:
        ImportError: If the examples have to come from
            ``tensorflow_datasets`` and that package is not installed.
    """
    if args.local_rank not in [-1, 0] and not evaluate:
        # Make sure only the first process in distributed training process the
        # dataset, and the others will use the cache
        torch.distributed.barrier()

    # Load data features from cache or dataset file
    input_dir = args.data_dir if args.data_dir else "."
    cached_features_file = os.path.join(
        input_dir,
        "cached_{}_{}_{}".format(
            "dev" if evaluate else "train",
            list(filter(None, args.model_name_or_path.split("/"))).pop(),
            str(args.max_seq_length),
        ),
    )

    # Init features and dataset from cache if it exists
    if os.path.exists(cached_features_file) and not args.overwrite_cache:
        logger.info("Loading features from cached file %s", cached_features_file)
        # NOTE: torch.load unpickles the file; only load caches written by
        # this script.
        features_and_dataset = torch.load(cached_features_file)
        features, dataset, examples = (
            features_and_dataset["features"],
            features_and_dataset["dataset"],
            features_and_dataset["examples"],
        )
    else:
        logger.info("Creating features from dataset file at %s", input_dir)

        if not args.data_dir and ((evaluate and not args.predict_file) or
                                  (not evaluate and not args.train_file)):
            # No local files were given: fall back to tensorflow_datasets.
            try:
                import tensorflow_datasets as tfds
            except ImportError as err:
                raise ImportError(
                    "If not data_dir is specified, tensorflow_datasets needs to be installed."
                ) from err

            if args.version_2_with_negative:
                # Logger.warn is a deprecated alias of Logger.warning.
                logger.warning(
                    "tensorflow_datasets does not handle version 2 of SQuAD.")

            tfds_examples = tfds.load("squad")
            examples = SquadV1Processor().get_examples_from_dataset(
                tfds_examples, evaluate=evaluate)
        else:
            processor = (SquadV2Processor()
                         if args.version_2_with_negative else SquadV1Processor())
            if evaluate:
                examples = processor.get_dev_examples(
                    args.data_dir, filename=args.predict_file)
            else:
                examples = processor.get_train_examples(
                    args.data_dir, filename=args.train_file)

        features, dataset = squad_convert_examples_to_features(
            examples=examples,
            tokenizer=tokenizer,
            max_seq_length=args.max_seq_length,
            doc_stride=args.doc_stride,
            max_query_length=args.max_query_length,
            is_training=not evaluate,
            return_dataset="pt",
            threads=args.threads,
        )

        # Only the main process (or a non-distributed run) writes the cache.
        if args.local_rank in [-1, 0]:
            logger.info("Saving features into cached file %s", cached_features_file)
            torch.save(
                {
                    "features": features,
                    "dataset": dataset,
                    "examples": examples,
                },
                cached_features_file,
            )

    if args.local_rank == 0 and not evaluate:
        # Make sure only the first process in distributed training process the
        # dataset, and the others will use the cache
        torch.distributed.barrier()

    if output_examples:
        return dataset, examples, features
    return dataset
def main():
    """Fine-tune and/or prepare evaluation of a TF question-answering model.

    Parses ``ModelArguments``, ``DataTrainingArguments`` and
    ``TFTrainingArguments`` from the command line, loads the config,
    tokenizer and model, builds the SQuAD train/eval datasets that were
    requested, and trains with ``TFTrainer`` when ``do_train`` is set.

    Raises:
        ValueError: If the output directory exists, is not empty, training is
            requested and ``overwrite_output_dir`` is not set.
        ImportError: If ``use_tfds`` is set and ``tensorflow_datasets`` is
            not installed.
    """
    # See all possible arguments in src/transformers/training_args.py
    # or by passing the --help flag to this script.
    # We now keep distinct sets of args, for a cleaner separation of concerns.
    parser = HfArgumentParser(
        (ModelArguments, DataTrainingArguments, TFTrainingArguments))
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    if (os.path.exists(training_args.output_dir)
            and os.listdir(training_args.output_dir)
            and training_args.do_train
            and not training_args.overwrite_output_dir):
        raise ValueError(
            f"Output directory ({training_args.output_dir}) already exists and is not empty. Use --overwrite_output_dir to overcome."
        )

    # Setup logging
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    logger.info(
        "n_replicas: %s, distributed training: %s, 16-bits training: %s",
        training_args.n_replicas,
        bool(training_args.n_replicas > 1),
        training_args.fp16,
    )
    logger.info("Training/evaluation parameters %s", training_args)

    # Prepare Question-Answering task
    # Load pretrained model and tokenizer
    #
    # Distributed training:
    # The .from_pretrained methods guarantee that only one local process can concurrently
    # download model & vocab.
    config = AutoConfig.from_pretrained(
        model_args.config_name
        if model_args.config_name else model_args.model_name_or_path,
        cache_dir=model_args.cache_dir,
    )
    tokenizer = AutoTokenizer.from_pretrained(
        model_args.tokenizer_name
        if model_args.tokenizer_name else model_args.model_name_or_path,
        cache_dir=model_args.cache_dir,
        use_fast=model_args.use_fast,
    )

    with training_args.strategy.scope():
        model = TFAutoModelForQuestionAnswering.from_pretrained(
            model_args.model_name_or_path,
            from_pt=bool(".bin" in model_args.model_name_or_path),
            config=config,
            cache_dir=model_args.cache_dir,
        )

    # Get datasets
    if data_args.use_tfds:
        if data_args.version_2_with_negative:
            # Logger.warn is a deprecated alias of Logger.warning.
            logger.warning(
                "tensorflow_datasets does not handle version 2 of SQuAD. "
                "Switch to version 1 automatically")

        try:
            import tensorflow_datasets as tfds
        except ImportError as err:
            raise ImportError(
                "If not data_dir is specified, tensorflow_datasets needs to be installed."
            ) from err

        tfds_examples = tfds.load("squad", data_dir=data_args.data_dir)
        train_examples = (SquadV1Processor().get_examples_from_dataset(
            tfds_examples, evaluate=False) if training_args.do_train else None)
        eval_examples = (SquadV1Processor().get_examples_from_dataset(
            tfds_examples, evaluate=True) if training_args.do_eval else None)
    else:
        processor = (SquadV2Processor()
                     if data_args.version_2_with_negative else SquadV1Processor())
        train_examples = (processor.get_train_examples(data_args.data_dir)
                          if training_args.do_train else None)
        eval_examples = (processor.get_dev_examples(data_args.data_dir)
                         if training_args.do_eval else None)

    # Build each dataset only when its phase was requested. The cardinality
    # assertion is applied inside the same guard: calling ``.apply`` on a
    # ``None`` dataset (or ``len`` on ``None`` examples) would crash.
    train_dataset = None
    if training_args.do_train:
        train_dataset = squad_convert_examples_to_features(
            examples=train_examples,
            tokenizer=tokenizer,
            max_seq_length=data_args.max_seq_length,
            doc_stride=data_args.doc_stride,
            max_query_length=data_args.max_query_length,
            is_training=True,
            return_dataset="tf",
        )
        train_dataset = train_dataset.apply(
            tf.data.experimental.assert_cardinality(len(train_examples)))

    eval_dataset = None
    if training_args.do_eval:
        eval_dataset = squad_convert_examples_to_features(
            examples=eval_examples,
            tokenizer=tokenizer,
            max_seq_length=data_args.max_seq_length,
            doc_stride=data_args.doc_stride,
            max_query_length=data_args.max_query_length,
            is_training=False,
            return_dataset="tf",
        )
        eval_dataset = eval_dataset.apply(
            tf.data.experimental.assert_cardinality(len(eval_examples)))

    # Initialize our Trainer
    trainer = TFTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
    )

    # Training
    if training_args.do_train:
        trainer.train()
        trainer.save_model()
        tokenizer.save_pretrained(training_args.output_dir)
def load_combined_examples(args, evaluate=False):
    """Build a dataset of stacked per-model logits for ensembling.

    Deprecated sadly.

    Loads the SQuAD examples, then reads a pickled
    ``[features, all_results, tokenizer]`` triple from
    ``args.saved_processed_data_dir`` and stacks the start/end logits of
    every saved model into tensors.

    Args:
        args: Namespace providing ``local_rank``, ``data_dir``,
            ``predict_file``, ``train_file``, ``version_2_with_negative`` and
            ``saved_processed_data_dir``.
        evaluate: Use the dev split and ``saved_data_dev.pkl`` when true,
            the train split and ``saved_data_train.pkl`` otherwise.

    Returns:
        ``(examples, features, dataset, tokenizer, n_models)`` where
        ``n_models`` is ``len(combined_all_results)``.

    Raises:
        ImportError: If the examples have to come from
            ``tensorflow_datasets`` and that package is not installed.
        AssertionError: If one of the hard-coded sanity checks fails.
    """
    if args.local_rank not in [-1, 0] and not evaluate:
        # Make sure only the first process in distributed training process the
        # dataset, and the others will use the cache
        torch.distributed.barrier()

    if not args.data_dir and ((evaluate and not args.predict_file) or
                              (not evaluate and not args.train_file)):
        try:
            import tensorflow_datasets as tfds
        except ImportError as err:
            raise ImportError(
                "If not data_dir is specified, tensorflow_datasets needs to be installed."
            ) from err

        # Logger.warn is a deprecated alias of Logger.warning.
        if args.version_2_with_negative:
            logger.warning(
                "tensorflow_datasets does not handle version 2 of SQuAD.")
        logger.warning("Something went wrong!")

        tfds_examples = tfds.load("squad")
        examples = SquadV1Processor().get_examples_from_dataset(
            tfds_examples, evaluate=evaluate)
    else:
        processor = (SquadV2Processor()
                     if args.version_2_with_negative else SquadV1Processor())
        if evaluate:
            examples = processor.get_dev_examples(args.data_dir,
                                                  filename=args.predict_file)
            # Sanity check for loading the correct example
            assert examples[0].question_text == 'In what country is Normandy located?', \
                'Invalid dev file!'
        else:
            # Normal get train examples
            examples = processor.get_train_examples(args.data_dir,
                                                    filename=args.train_file)
            # Sanity check for loading the correct example
            assert examples[0].question_text == 'When did Beyonce start becoming popular?', \
                'Invalid train file!'

    assert args.saved_processed_data_dir, 'args.saved_processed_data_dir not defined!'
    ensemble_dir = args.saved_processed_data_dir
    saved_name = 'saved_data_dev.pkl' if evaluate else 'saved_data_train.pkl'
    # NOTE: pickle.load executes code embedded in the file; only point
    # saved_processed_data_dir at data produced by a trusted run.
    with open(os.path.join(ensemble_dir, saved_name), 'rb') as f:
        saved_data = pickle.load(f)

    # saved_data: [features, all_results, tokenizer]
    features, combined_all_results, tokenizer = saved_data
    # The original message was ``print(...)``, which evaluates to None, so the
    # AssertionError carried no text; use a real message instead.
    assert np.array_equal(
        [f.start_position for f in features[0]],
        [f.start_position for f in features[1]],
    ), "Models of the same family must share the same features"

    # Same family same feature and tokenizer, so we pick the first one
    features = features[0]
    tokenizer = tokenizer[0]

    # Convert to Tensors and build dataset
    all_input_ids = torch.tensor([f.input_ids for f in features],
                                 dtype=torch.long)
    all_predict_start_logits = []
    all_predict_end_logits = []
    for all_results in combined_all_results:
        all_predict_start_logits.append(
            torch.tensor([s.start_logits for s in all_results],
                         dtype=torch.float))
        all_predict_end_logits.append(
            torch.tensor([s.end_logits for s in all_results],
                         dtype=torch.float))

    # Stack along a new leading axis (one entry per saved model), then swap
    # the first two axes.
    all_predict_start_logits = torch.stack(all_predict_start_logits).permute(
        1, 0, 2)
    all_predict_end_logits = torch.stack(all_predict_end_logits).permute(
        1, 0, 2)

    if evaluate:
        all_example_indices = torch.arange(all_input_ids.size(0),
                                           dtype=torch.long)
        dataset = TensorDataset(all_predict_start_logits,
                                all_predict_end_logits, all_example_indices)
        expected_examples = 6078
    else:
        all_start_positions = torch.tensor(
            [f.start_position for f in features], dtype=torch.long)
        all_end_positions = torch.tensor([f.end_position for f in features],
                                         dtype=torch.long)
        dataset = TensorDataset(all_predict_start_logits,
                                all_predict_end_logits, all_start_positions,
                                all_end_positions)
        expected_examples = 130319

    # Hard-coded size check for the expected dev / train files.
    assert len(examples) == expected_examples

    if args.local_rank == 0 and not evaluate:
        # Make sure only the first process in distributed training process the
        # dataset, and the others will use the cache
        torch.distributed.barrier()

    return examples, features, dataset, tokenizer, len(combined_all_results)
def load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=False):
    """Return SQuAD features for distillation, reading or writing a cache.

    The cache file is stored next to the input file (``args.predict_file``
    when evaluating, ``args.train_file`` otherwise) and is named
    ``cached_distillation_<split>_<model>_<max_seq_length>``.

    Args:
        args: Namespace providing ``local_rank``, ``predict_file``,
            ``train_file``, ``model_name_or_path``, ``max_seq_length``,
            ``overwrite_cache``, ``version_2_with_negative``, ``data_dir``,
            ``doc_stride``, ``max_query_length`` and ``threads``.
        tokenizer: Tokenizer forwarded to
            ``squad_convert_examples_to_features``.
        evaluate: Work on the dev split when true, the train split otherwise.
        output_examples: Also return the examples and the features.

    Returns:
        ``dataset`` alone, or ``(dataset, examples, features)`` when
        ``output_examples`` is true.

    Raises:
        DeprecationWarning: If the cache file lacks one of the ``features``,
            ``dataset`` or ``examples`` entries.
    """
    if args.local_rank not in (-1, 0) and not evaluate:
        # Every process except the first waits here; the first one builds
        # the cache that the others then read.
        torch.distributed.barrier()

    # Work out where the cache for this split / model / length lives.
    input_file = args.predict_file if evaluate else args.train_file
    cache_dir = os.path.dirname(input_file)
    split_name = "dev" if evaluate else "train"
    model_tag = [part for part in args.model_name_or_path.split("/") if part].pop()
    cache_name = "cached_distillation_{}_{}_{}".format(
        split_name, model_tag, str(args.max_seq_length))
    cached_features_file = os.path.join(cache_dir, cache_name)

    cache_usable = os.path.exists(cached_features_file) and not args.overwrite_cache
    if cache_usable:
        logger.info("Loading features from cached file %s", cached_features_file)
        cached = torch.load(cached_features_file)

        try:
            features = cached["features"]
            dataset = cached["dataset"]
            examples = cached["examples"]
        except KeyError:
            raise DeprecationWarning(
                "You seem to be loading features from an older version of this script please delete the "
                "file %s in order for it to be created again" % cached_features_file)
    else:
        logger.info("Creating features from dataset file at %s", input_file)

        if args.version_2_with_negative:
            processor = SquadV2Processor()
        else:
            processor = SquadV1Processor()

        if evaluate:
            examples = processor.get_dev_examples(args.data_dir,
                                                  filename=args.predict_file)
        else:
            examples = processor.get_train_examples(args.data_dir,
                                                    filename=args.train_file)

        features, dataset = squad_convert_examples_to_features(
            examples=examples,
            tokenizer=tokenizer,
            max_seq_length=args.max_seq_length,
            doc_stride=args.doc_stride,
            max_query_length=args.max_query_length,
            is_training=not evaluate,
            return_dataset="pt",
            threads=args.threads,
        )

        # Only the first process (or a non-distributed run) writes the cache.
        if args.local_rank in (-1, 0):
            logger.info("Saving features into cached file %s", cached_features_file)
            payload = {
                "features": features,
                "dataset": dataset,
                "examples": examples,
            }
            torch.save(payload, cached_features_file)

    if args.local_rank == 0 and not evaluate:
        # The first process is done; release the processes waiting above.
        torch.distributed.barrier()

    return (dataset, examples, features) if output_examples else dataset
def load_and_cache_examples(args,
                            tokenizer,
                            evaluate=False,
                            use_aug_path=False,
                            output_examples=False
                            ) -> torch.utils.data.TensorDataset:
    """Loads SQuAD-like data features from dataset file (or cache)

    Parameters
    ----------
    args : kitanaqa.trainer.arguments.ModelArguments
        A set of arguments related to the model. Specifically, the following
        arguments are used in this function:
        - args.train_file_path : str
            Path to the training data file
        - args.do_aug : bool
            Flag to specify whether to use the augmented training set.
            If True, will be merged with the original training set specified
            in train_file_path. The default value is False.
        - args.aug_file_path : str
            Path for augmented train dataset
        - args.data_dir : str
            Path for data files
        - args.model_name_or_path : str
            Path to pretrained model or model identifier from
            huggingface.co/models
        - args.max_seq_length : Optional[int]
            Max length for the input tokens, specified to the Transformer
            model defined in `model_name_or_path`
        - args.overwrite_cache : Bool
            Overwrite cached data on load
        - args.predict_file_path : Dict[str, str]
            Paths for eval datasets, where the key is the data file tag, and
            the value is the data file path. Multiple file paths may be given
            for evaluation; they are currently cached together in one file.
        - args.version_2_with_negative : Bool
            Flag that specifies to use the SQuAD v2.0 preprocessors.
            The default value is False.
        - args.doc_stride : Optional[int]
            Corresponds to the doc_stride input param for some Huggingface
            Transformer models.
        - args.max_query_length : Optional[int]
            Max length for the query segment in the Transformer model input.
    tokenizer :
        The Transformer model tokenizer used to preprocess the data.
    evaluate : Optional(Bool)
        A flag to set the trainer task to either train or evaluate.
        The default value is False.
    use_aug_path : Optional(Bool)
        A flag to define whether to use the aug_file_path or the
        train_file_path. If True, the augmented data path is used when loading
        and caching the data, and the cache file is tagged ``aug`` instead of
        ``train`` so it cannot collide with the original training cache.
    output_examples : Optional(Bool)
        A flag to define whether the examples and features should be returned
        by the data preprocessor. If False, the preprocessor only returns the
        dataset. This is necessary if the Trainer is used for evaluation or in
        a pipeline where training is followed by evaluation.

    Returns
    -------
    torch.utils.data.TensorDataset
        The dataset containing the data to be used for training or evaluation.

    Important Notes:
    - If the output_examples is True, examples and features also are returned,
      as the tuple ``(dataset, examples, features)``.
    - If evaluate = True, the output will be a dictionary for which the keys
      are the name of the datasets used for evaluation and the values are the
      dataset (and optionally the examples and features).

    Raises
    ------
    ImportError
        If the examples have to come from ``tensorflow_datasets`` and that
        package is not installed.
    """
    # Only logged, not raised: execution continues after this message.
    if not args.train_file_path and not (args.do_aug and args.aug_file_path):
        logging.error(
            'load_and_cache_examples requires one of either \"train_file_path\", \"aug_file_path\"'
        )

    # Use the augmented data or the original training data
    train_or_aug_path = args.train_file_path if not use_aug_path else args.aug_file_path

    # Give augmented training data its own cache tag. With a shared "train"
    # tag the second of two calls (original / augmented) would silently
    # reload the features cached by the first one.
    if evaluate:
        split_tag = "dev"
    elif use_aug_path:
        split_tag = "aug"
    else:
        split_tag = "train"

    input_dir = args.data_dir if args.data_dir else "."
    cached_features_file = os.path.join(
        input_dir,
        "cached_{}_{}_{}".format(
            split_tag,
            list(filter(None, args.model_name_or_path.split("/"))).pop(),
            str(args.max_seq_length),
        ),
    )

    # Init features and dataset from cache if it exists
    if os.path.exists(cached_features_file) and not args.overwrite_cache:
        logger.info("Loading features from cached file %s", cached_features_file)
        features_and_dataset = torch.load(cached_features_file)
        features, dataset, examples = (
            features_and_dataset["features"],
            features_and_dataset["dataset"],
            features_and_dataset["examples"],
        )
    else:
        logger.info("Creating features from dataset file at %s", input_dir)

        if not args.data_dir and ((evaluate and not args.predict_file_path) or
                                  (not evaluate and not train_or_aug_path)):
            try:
                import tensorflow_datasets as tfds
            except ImportError as err:
                raise ImportError(
                    "If not data_dir is specified, tensorflow_datasets needs to be installed."
                ) from err

            if args.version_2_with_negative:
                # Logger.warn is a deprecated alias of Logger.warning.
                logger.warning(
                    "tensorflow_datasets does not handle version 2 of SQuAD.")

            tfds_examples = tfds.load("squad")
            # NOTE(review): with evaluate=True the code below calls
            # ``examples.items()``, which presumably fails on what this
            # returns — confirm whether this path is ever used for evaluation.
            examples = SquadV1Processor().get_examples_from_dataset(
                tfds_examples, evaluate=evaluate)
        elif evaluate:
            # when does it concatenate if eval and train are both true?
            examples = {}
            processor = (AlumSquadV2Processor() if args.version_2_with_negative
                         else AlumSquadV1Processor())
            for predict_sets, predict_paths in args.predict_file_path.items():
                examples[predict_sets] = processor.alum_get_dev_examples(
                    args.data_dir, filename=predict_paths)
                logger.info("Evaluation Data is fetched for %s.", predict_sets)
        else:
            processor = (SquadV2Processor()
                         if args.version_2_with_negative else SquadV1Processor())
            examples = processor.get_train_examples(
                args.data_dir, filename=train_or_aug_path)

        if not evaluate:
            features, dataset = squad_convert_examples_to_features(
                examples=examples,
                tokenizer=tokenizer,
                max_seq_length=args.max_seq_length,
                doc_stride=args.doc_stride,
                max_query_length=args.max_query_length,
                is_training=not evaluate,
                return_dataset="pt",
                #threads=args.threads,
            )
        else:
            #TODO: Incremental Cache - The current version will cache all the eval files together.
            features, dataset = {}, {}
            for predict_sets, example in examples.items():
                set_features, set_dataset = alum_squad_convert_examples_to_features(
                    examples=example,
                    tokenizer=tokenizer,
                    max_seq_length=args.max_seq_length,
                    doc_stride=args.doc_stride,
                    max_query_length=args.max_query_length,
                    return_dataset="pt",
                    #threads=args.threads,
                )
                features[predict_sets] = set_features
                dataset[predict_sets] = set_dataset
                logger.info(
                    "Feature Extraction for Evaluation Data from %s is Finished.",
                    predict_sets)

        # Both the train and the eval path persist the same three entries.
        logger.info("Saving features into cached file %s", cached_features_file)
        torch.save(
            {
                "features": features,
                "dataset": dataset,
                "examples": examples,
            },
            cached_features_file,
        )

    if output_examples:
        return dataset, examples, features
    return dataset
# Feature-conversion settings passed to squad_convert_examples_to_features.
doc_stride = 128
max_query_length = 64

# Load pretrained model and tokenizer
config_class, model_class, tokenizer_class = (BertConfig,
                                              BertForQuestionAnswering,
                                              BertTokenizer)
config = config_class.from_pretrained(model_name_or_path, cache_dir=cache_dir)
tokenizer = tokenizer_class.from_pretrained(model_name_or_path,
                                            do_lower_case=True,
                                            cache_dir=cache_dir)
model = model_class.from_pretrained(model_name_or_path,
                                    from_tf=False,
                                    config=config,
                                    cache_dir=cache_dir)

# Read the dev examples from predict_file and convert a prefix of them into
# evaluation features plus a PyTorch dataset.
processor = SquadV1Processor()
examples = processor.get_dev_examples(None, filename=predict_file)

features, dataset = squad_convert_examples_to_features(
    examples=examples[:total_samples],  # convert enough examples for this notebook
    tokenizer=tokenizer,
    max_seq_length=max_seq_length,
    doc_stride=doc_stride,
    max_query_length=max_query_length,
    is_training=False,
    return_dataset='pt')

# exist_ok avoids the check-then-create race of a separate os.path.exists test.
output_dir = "./models/squad"
os.makedirs(output_dir, exist_ok=True)
def run_squad_and_get_results(
    model: tf.keras.Model,  # Must be QuestionAnswering model, not PreTraining
    tokenizer: PreTrainedTokenizer,
    run_name: str,
    fsx_prefix: str,
    per_gpu_batch_size: int,
    checkpoint_frequency: Optional[int],
    validate_frequency: Optional[int],
    evaluate_frequency: Optional[int],
    learning_rate: float,
    warmup_steps: int,
    total_steps: int,
    dataset: str,
    dummy_eval: bool = False,
) -> Dict:
    """Fine-tune ``model`` on SQuAD and return the latest evaluation metrics.

    Trains for ``total_steps`` steps with AdamW driven by a
    ``LinearWarmupPolyDecaySchedule``. Horovod rank 0 additionally drives the
    progress bar, validation, evaluation, checkpointing and TensorBoard
    summaries.

    Args:
        model: Question-answering model to fine-tune.
        tokenizer: Tokenizer passed to the dataset and evaluation helpers.
        run_name: Name used in checkpoint and log paths.
        fsx_prefix: Root directory for data, checkpoints and logs. A value
            starting with ``/opt/ml`` disables the progress bar.
        per_gpu_batch_size: Batch size for the train and validation datasets.
        checkpoint_frequency: Steps between checkpoints; a falsy value means
            1000000.
        validate_frequency: Steps between validation runs; a falsy value
            means 1000000.
        evaluate_frequency: Steps between evaluation runs; a falsy value
            means 1000000.
        learning_rate: Peak learning rate of the schedule.
        warmup_steps: Warmup steps of the schedule.
        total_steps: Number of training steps.
        dataset: One of ``"squadv1"``, ``"squadv2"`` or ``"debug"``.
        dummy_eval: Use a fixed metrics dict instead of running evaluation.

    Returns:
        On rank 0, the metrics dict of the most recent evaluation. Every
        other rank returns ``None``.

    Raises:
        ValueError: If ``dataset`` is not one of the supported names.
    """
    checkpoint_frequency = checkpoint_frequency or 1000000
    validate_frequency = validate_frequency or 1000000
    evaluate_frequency = evaluate_frequency or 1000000
    is_sagemaker = fsx_prefix.startswith("/opt/ml")
    disable_tqdm = is_sagemaker

    schedule = LinearWarmupPolyDecaySchedule(
        max_learning_rate=learning_rate,
        end_learning_rate=0,
        warmup_steps=warmup_steps,
        total_steps=total_steps,
    )
    optimizer = tfa.optimizers.AdamW(weight_decay=0.0, learning_rate=schedule)
    optimizer = tf.train.experimental.enable_mixed_precision_graph_rewrite(
        optimizer, loss_scale="dynamic")  # AMP

    if dataset == "squadv1":
        train_filename = "train-v1.1.json"
        val_filename = "dev-v1.1.json"
        processor = SquadV1Processor()
    elif dataset == "squadv2":
        train_filename = "train-v2.0.json"
        val_filename = "dev-v2.0.json"
        processor = SquadV2Processor()
    elif dataset == "debug":
        train_filename = "dev-v2.0.json"
        val_filename = "dev-v2.0.json"
        processor = SquadV2Processor()
    else:
        # An ``assert`` is stripped under ``python -O``; raise explicitly.
        raise ValueError(
            "--dataset must be one of ['squadv1', 'squadv2', 'debug']")

    data_dir = f"{fsx_prefix}/squad_data"

    train_dataset = get_dataset(
        tokenizer=tokenizer,
        processor=processor,
        data_dir=data_dir,
        filename=train_filename,
        per_gpu_batch_size=per_gpu_batch_size,
        shard=True,
        shuffle=True,
        repeat=True,
        drop_remainder=True,
    )

    if hvd.rank() == 0:
        logger.info(f"Starting finetuning on {dataset}")
        # ``total`` must be passed by keyword: tqdm's first positional
        # parameter is the iterable, so a bare int leaves the bar without a
        # total.
        pbar = tqdm.tqdm(total=total_steps, disable=disable_tqdm)
        summary_writer = None  # Only create a writer if we make it through a successful step
        val_dataset = get_dataset(
            tokenizer=tokenizer,
            processor=processor,
            data_dir=data_dir,
            filename=val_filename,
            per_gpu_batch_size=per_gpu_batch_size,
            shard=False,
            shuffle=True,
            drop_remainder=False,
        )

    # Need to re-wrap every time this function is called
    # Wrapping train_step gives an error with optimizer initialization on the second pass
    # of run_squad_and_get_results(). Bug report at https://github.com/tensorflow/tensorflow/issues/38875
    # Discussion at https://github.com/tensorflow/tensorflow/issues/27120
    global train_step
    train_step = rewrap_tf_function(train_step)

    for step, batch in enumerate(train_dataset):
        # Current learning rate, computed only for logging below.
        learning_rate = schedule(step=tf.constant(step, dtype=tf.float32))
        loss, acc, exact_match, f1, precision, recall = train_step(
            model=model, optimizer=optimizer, batch=batch)

        # Broadcast model after the first step so parameters and optimizer are initialized
        if step == 0:
            hvd.broadcast_variables(model.variables, root_rank=0)
            hvd.broadcast_variables(optimizer.variables(), root_rank=0)

        is_final_step = step >= total_steps - 1
        if hvd.rank() == 0:
            # Step 0 satisfies every modulo test, so ``results`` and the
            # validation values are always bound before they are read below.
            do_checkpoint = (step % checkpoint_frequency == 0) or is_final_step
            do_validate = (step % validate_frequency == 0) or is_final_step
            do_evaluate = (step % evaluate_frequency == 0) or is_final_step

            pbar.update(1)
            description = f"Loss: {loss:.3f}, Acc: {acc:.3f}, EM: {exact_match:.3f}, F1: {f1:.3f}"
            pbar.set_description(description)

            if do_validate:
                logger.info("Running validation")
                (
                    val_loss,
                    val_acc,
                    val_exact_match,
                    val_f1,
                    val_precision,
                    val_recall,
                ) = run_validation(model=model, val_dataset=val_dataset)
                description = (
                    f"Step {step} validation - Loss: {val_loss:.3f}, Acc: {val_acc:.3f}, "
                    f"EM: {val_exact_match:.3f}, F1: {val_f1:.3f}")
                logger.info(description)

            if do_evaluate:
                logger.info("Running evaluation")
                if dummy_eval:
                    results = {
                        "exact": 0.8169797018445212,
                        "f1": 4.4469722448269335,
                        "total": 11873,
                        "HasAns_exact": 0.15182186234817813,
                        "HasAns_f1": 7.422216845956518,
                        "HasAns_total": 5928,
                        "NoAns_exact": 1.4802354920100924,
                        "NoAns_f1": 1.4802354920100924,
                        "NoAns_total": 5945,
                        "best_exact": 50.07159100480081,
                        "best_exact_thresh": 0.0,
                        "best_f1": 50.0772059855695,
                        "best_f1_thresh": 0.0,
                    }
                else:
                    results: Dict = get_evaluation_metrics(
                        model=model,
                        tokenizer=tokenizer,
                        data_dir=data_dir,
                        filename=val_filename,
                        per_gpu_batch_size=32,
                    )
                print_eval_metrics(results=results, step=step, dataset=dataset)

            if do_checkpoint:
                checkpoint_path = (
                    f"{fsx_prefix}/checkpoints/albert-squad/{run_name}-step{step}.ckpt"
                )
                logger.info(f"Saving checkpoint at {checkpoint_path}")
                model.save_weights(checkpoint_path)

            if summary_writer is None:
                summary_writer = tf.summary.create_file_writer(
                    f"{fsx_prefix}/logs/albert-squad/{run_name}")
            with summary_writer.as_default():
                tf.summary.scalar("learning_rate", learning_rate, step=step)
                tf.summary.scalar("train_loss", loss, step=step)
                tf.summary.scalar("train_acc", acc, step=step)
                tf.summary.scalar("train_exact", exact_match, step=step)
                tf.summary.scalar("train_f1", f1, step=step)
                tf.summary.scalar("train_precision", precision, step=step)
                tf.summary.scalar("train_recall", recall, step=step)
                if do_validate:
                    tf.summary.scalar("val_loss", val_loss, step=step)
                    tf.summary.scalar("val_acc", val_acc, step=step)
                    tf.summary.scalar("val_exact", val_exact_match, step=step)
                    tf.summary.scalar("val_f1", val_f1, step=step)
                    tf.summary.scalar("val_precision", val_precision, step=step)
                    tf.summary.scalar("val_recall", val_recall, step=step)
                # And the eval metrics
                tensorboard_eval_metrics(summary_writer=summary_writer,
                                         results=results,
                                         step=step,
                                         dataset=dataset)

        if is_final_step:
            break

    del train_dataset

    # Can we return a value only on a single rank?
    if hvd.rank() == 0:
        pbar.close()
        logger.info(f"Finished finetuning, job name {run_name}")
        return results
    # ``results`` is only ever assigned on rank 0.
    return None
def data_save(args, tokenizer, evaluate=False):
    """Convert SQuAD examples to features and push them through a DataSaver.

    Every batch of the converted dataset is sent as a dict of named tensors.
    For evaluation the features and examples are additionally written to
    temporary files whose paths are handed to the saver with
    ``filetype=True``.

    Args:
        args: Namespace providing ``version_2_with_negative``,
            ``tokenizer_name``, ``doc_stride``, ``max_seq_length``,
            ``max_query_length``, ``predict_file``, ``train_file`` and
            ``threads``.
        tokenizer: Tokenizer forwarded to
            ``squad_convert_examples_to_features``.
        evaluate: Process the dev file when true, the train file otherwise.
    """
    # NOTE(review): endpoint and credentials are hard-coded here; consider
    # moving them to configuration.
    data_config = DataConfig(
        endpoint="127.0.0.1:9000",
        access_key="minio",
        secret_key="miniosecretkey",
        dataset_name="SQuAD1.1" if not args.version_2_with_negative else "SQuAD2.0",
        additional={
            "mode": "train" if not evaluate else "test",
            "framework": "pytorch",
            "version": 1.1 if not args.version_2_with_negative else 2.0,
            "model_name": args.tokenizer_name,
            "doc_stride": args.doc_stride,
            "max_seq_length": args.max_seq_length,
            "max_query_length": args.max_query_length,
        },
        attributes=train_attributes if not evaluate else eval_attributes,
    )

    processor = (SquadV2Processor()
                 if args.version_2_with_negative else SquadV1Processor())
    if evaluate:
        examples = processor.get_dev_examples(None, filename=args.predict_file)
    else:
        examples = processor.get_train_examples(None, filename=args.train_file)

    features, dataset = squad_convert_examples_to_features(
        examples=examples,
        tokenizer=tokenizer,
        max_seq_length=args.max_seq_length,
        doc_stride=args.doc_stride,
        max_query_length=args.max_query_length,
        is_training=not evaluate,
        return_dataset="pt",
        threads=args.threads,
    )

    # Name given to each positional tensor of a batch, per mode.
    if evaluate:
        field_names = (
            "all_input_ids",
            "all_attention_masks",
            "all_token_type_ids",
            "all_feature_index",
            "all_cls_index",
            "all_p_mask",
        )
    else:
        field_names = (
            "all_input_ids",
            "all_attention_masks",
            "all_token_type_ids",
            "all_start_positions",
            "all_end_positions",
            "all_cls_index",
            "all_p_mask",
            "all_is_impossible",
        )

    data_saver = DataSaver(config=data_config)
    try:
        dataloader = DataLoader(dataset, batch_size=8, num_workers=args.threads)
        for batch in tqdm(dataloader):
            # Index explicitly so a batch with too few tensors still raises.
            data_saver({
                name: batch[position]
                for position, name in enumerate(field_names)
            })

        if evaluate:
            # NOTE(review): the source indentation was lost; this upload is
            # assumed to belong to the evaluation branch only — confirm.
            #
            # NamedTemporaryFile creates the file atomically, unlike the
            # deprecated, race-prone tempfile.mktemp. delete=False keeps the
            # files on disk after closing, as before, so the saver can read
            # them by path.
            with tempfile.NamedTemporaryFile(suffix="features",
                                             delete=False) as features_file:
                torch.save(features, features_file)
            with tempfile.NamedTemporaryFile(suffix="examples",
                                             delete=False) as examples_file:
                torch.save(examples, examples_file)
            data_saver(
                {
                    "features": features_file.name,
                    "examples": examples_file.name,
                },
                filetype=True)
    finally:
        # Release the connection even when conversion or an upload fails.
        data_saver.disconnect()