예제 #1
0
def get_dataloader(logger, args, input_file, is_training,
                   batch_size, num_epochs, tokenizer, index=None):
    """Build a ``MyDataLoader`` over the (possibly cached) features of a file.

    Features are cached as a pickle whose name is derived from ``input_file``
    and encodes ``args.max_seq_length``, the paragraph setting and
    ``args.max_n_answers``. When that cache exists it is loaded; otherwise the
    examples are read with ``read_squad_examples``, converted with
    ``convert_examples_to_features`` and, unless ``args.debug`` is set,
    written to the cache.

    Args:
        logger: Logger used for progress messages.
        args: Settings object. Reads ``n_paragraphs``, ``max_seq_length``,
            ``max_n_answers``, ``doc_stride``, ``max_query_length`` and
            ``debug``.
        input_file: Path of the data file to load.
        is_training: Whether the loader is built for training. When false,
            only the last comma-separated entry of ``args.n_paragraphs`` is
            used, features are built with ``max_n_answers=1``, and the
            examples are stored in the cache alongside the features.
        batch_size: Batch size handed to ``MyDataLoader``.
        num_epochs: Number of epochs, used only to compute the step count.
        tokenizer: Tokenizer forwarded to ``convert_examples_to_features``.
        index: Unused; kept so existing callers keep working.

    Returns:
        A tuple ``(dataloader, examples, flattened_features,
        num_train_steps)``. ``examples`` is ``None`` when the features came
        from a cache that has no ``'examples'`` entry (caches written with
        ``is_training=True`` do not store them).
    """
    gc.enable()
    n_paragraphs = args.n_paragraphs

    # presumably args.n_paragraphs is a string such as "10,20" -- TODO confirm.
    # Outside training only the last listed value is used.
    if (not is_training) and ',' in n_paragraphs:
        n_paragraphs = n_paragraphs.split(',')[-1]

    cache_suffix = '-{}-{}-{}.pkl'.format(args.max_seq_length, n_paragraphs,
                                          args.max_n_answers)
    feature_save_path = input_file.replace('.json', cache_suffix)
    if feature_save_path == input_file:
        # The name contains no '.json', so the replace was a no-op and the
        # cache path would be the input file itself: we would try to unpickle
        # the raw input, or overwrite it with the cache. Append instead.
        feature_save_path = input_file + cache_suffix
    gc.collect()

    if os.path.exists(feature_save_path):
        logger.info("Loading saved features from %s", feature_save_path)
        # NOTE: pickle.load can execute arbitrary code embedded in the file;
        # only point this at cache files from a trusted location.
        with open(feature_save_path, 'rb') as f:
            features = pkl.load(f)
        train_features = features['features']
        examples = features.get('examples', None)
    else:
        examples = read_squad_examples(logger=logger,
                                       args=args,
                                       input_file=input_file,
                                       debug=args.debug)

        train_features = convert_examples_to_features(
            logger=logger,
            args=args,
            examples=examples,
            tokenizer=tokenizer,
            max_seq_length=args.max_seq_length,
            doc_stride=args.doc_stride,
            max_query_length=args.max_query_length,
            max_n_answers=args.max_n_answers if is_training else 1,
            is_training=is_training)
        gc.collect()
        # Debug runs never write a cache.
        if not args.debug:
            logger.info("Saving features to: %s", feature_save_path)
            save_features = {'features': train_features}
            if not is_training:
                save_features['examples'] = examples
            with open(feature_save_path, 'wb') as f:
                gc.collect()
                pkl.dump(save_features, f)

    # train_features is a sequence of per-item feature sequences: the split
    # count sums the inner lengths, the step count uses the outer length.
    n_features = sum(len(f) for f in train_features)
    num_train_steps = int(len(train_features) / batch_size * num_epochs)

    if examples is not None:
        logger.info("  Num orig examples = %d", len(examples))
    logger.info("  Num split examples = %d", n_features)
    logger.info("  Batch size = %d", batch_size)
    if is_training:
        logger.info("  Num steps = %d", num_train_steps)

    dataloader = MyDataLoader(features=train_features,
                              batch_size=batch_size,
                              is_training=is_training)
    flattened_features = [f for _features in train_features for f in _features]
    return dataloader, examples, flattened_features, num_train_steps
예제 #2
0
 def load_dataloader(self, batch_size, is_training=None, do_return=False):
     """Create the data loader for ``self.dataset`` and store it on ``self``.

     The loader is built with a quarter of ``batch_size`` (truncated to an
     int). ``is_training`` overrides ``self.is_training`` when it is not
     ``None``. The new loader is returned only when ``do_return`` is truthy;
     otherwise the method returns ``None``.
     """
     reduced_batch_size = int(batch_size/4)
     if is_training is None:
         training_flag = self.is_training
     else:
         training_flag = is_training
     self.dataloader = MyDataLoader(self.args,
                                    self.dataset,
                                    batch_size=reduced_batch_size,
                                    is_training=training_flag)
     if not do_return:
         return None
     return self.dataloader
예제 #3
0
파일: QGData.py 프로젝트: sumit6597/AmbigQA
 def load_dataloader(self, do_return=False):
     """Create the data loader for ``self.dataset`` and store it on ``self``.

     The loader is built from ``self.args`` and ``self.dataset`` with
     ``self.is_training`` as its training flag. It is returned only when
     ``do_return`` is truthy; otherwise the method returns ``None``.
     """
     training_flag = self.is_training
     self.dataloader = MyDataLoader(self.args, self.dataset,
                                    is_training=training_flag)
     return self.dataloader if do_return else None