Example 1
import os

import torch.nn as nn

# SICK, MSRVID, TRECQA, WikiQA and UnknownWordVecCache come from the
# surrounding Castor codebase (dataset wrappers and the unknown-word
# vector cache); their exact import paths depend on the repository layout.


def get_dataset(dataset_name, word_vectors_dir, word_vectors_file, batch_size, device, castor_dir="./", utils_trecqa="utils/trec_eval-9.0.5/trec_eval"):
     if dataset_name == 'sick':
         dataset_root = os.path.join(castor_dir, os.pardir, 'Castor-data', 'sick/')
         train_loader, dev_loader, test_loader = SICK.iters(dataset_root, word_vectors_file, word_vectors_dir, batch_size, device=device, unk_init=UnknownWordVecCache.unk)
         embedding = nn.Embedding.from_pretrained(SICK.TEXT_FIELD.vocab.vectors)
         return SICK, embedding, train_loader, test_loader, dev_loader
     elif dataset_name == 'msrvid':
         dataset_root = os.path.join(castor_dir, os.pardir, 'Castor-data', 'msrvid/')
         dev_loader = None
         train_loader, test_loader = MSRVID.iters(dataset_root, word_vectors_file, word_vectors_dir, batch_size, device=device, unk_init=UnknownWordVecCache.unk)
         embedding = nn.Embedding.from_pretrained(MSRVID.TEXT_FIELD.vocab.vectors)
         return MSRVID, embedding, train_loader, test_loader, dev_loader
     elif dataset_name == 'trecqa':
         if not os.path.exists(os.path.join(castor_dir, utils_trecqa)):
             raise FileNotFoundError('TrecQA requires the trec_eval tool to run. Please run get_trec_eval.sh inside Castor/utils (as working directory) before continuing.')
         dataset_root = os.path.join(castor_dir, os.pardir, 'Castor-data', 'TrecQA/')
         train_loader, dev_loader, test_loader = TRECQA.iters(dataset_root, word_vectors_file, word_vectors_dir, batch_size, device=device, unk_init=UnknownWordVecCache.unk)
         embedding = nn.Embedding.from_pretrained(TRECQA.TEXT_FIELD.vocab.vectors)
         return TRECQA, embedding, train_loader, test_loader, dev_loader
     elif dataset_name == 'wikiqa':
         if not os.path.exists(os.path.join(castor_dir, utils_trecqa)):
             raise FileNotFoundError('WikiQA requires the trec_eval tool to run. Please run get_trec_eval.sh inside Castor/utils (as working directory) before continuing.')
         dataset_root = os.path.join(castor_dir, os.pardir, 'Castor-data', 'WikiQA/')
         train_loader, dev_loader, test_loader = WikiQA.iters(dataset_root, word_vectors_file, word_vectors_dir, batch_size, device=device, unk_init=UnknownWordVecCache.unk)
         embedding = nn.Embedding.from_pretrained(WikiQA.TEXT_FIELD.vocab.vectors)
         return WikiQA, embedding, train_loader, test_loader, dev_loader
     else:
         raise ValueError('{} is not a valid dataset.'.format(dataset_name))
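
A minimal call-site sketch for the signature above. The vector directory, file name, and batch size are illustrative assumptions, and device handling depends on the torchtext version (older releases expect an integer index instead of a torch.device).

import torch

# All names below are placeholders, not values fixed by get_dataset itself.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
dataset_cls, embedding, train_loader, test_loader, dev_loader = get_dataset(
    'sick',                 # one of: sick, msrvid, trecqa, wikiqa
    'embeddings',           # directory holding the word-vector file
    'glove.840B.300d.txt',  # word-vector file inside that directory
    batch_size=64,
    device=device)

for batch in train_loader:  # the torchtext iterator yields Batch objects
    pass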
Example 2
def get_dataset(dataset_name, word_vectors_dir, word_vectors_file, batch_size, device):
     if dataset_name == 'sick':
         dataset_root = os.path.join(os.pardir, 'data', 'sick/')
         train_loader, dev_loader, test_loader = SICK.iters(dataset_root, word_vectors_file, word_vectors_dir, batch_size, device=device, unk_init=UnknownWordVecCache.unk)
         embedding_dim = SICK.TEXT_FIELD.vocab.vectors.size()
         embedding = nn.Embedding(embedding_dim[0], embedding_dim[1])
         embedding.weight = nn.Parameter(SICK.TEXT_FIELD.vocab.vectors)
         return SICK, embedding, train_loader, test_loader, dev_loader
     elif dataset_name == 'sts':
         dataset_root = os.path.join(os.pardir, 'data', 'sts/')
         train_loader, dev_loader, test_loader = STS.iters(dataset_root, word_vectors_file, word_vectors_dir, batch_size, device=device, unk_init=UnknownWordVecCache.unk)
         embedding_dim = STS.TEXT_FIELD.vocab.vectors.size()
         embedding = nn.Embedding(embedding_dim[0], embedding_dim[1])
         embedding.weight = nn.Parameter(STS.TEXT_FIELD.vocab.vectors)
         return STS, embedding, train_loader, test_loader, dev_loader
     elif dataset_name == 'msrvid':
         dataset_root = os.path.join(os.pardir, 'data', 'msrvid/')
         dev_loader = None
         train_loader, test_loader = MSRVID.iters(dataset_root, word_vectors_file, word_vectors_dir, batch_size, device=device, unk_init=UnknownWordVecCache.unk)
         embedding_dim = MSRVID.TEXT_FIELD.vocab.vectors.size()
         embedding = nn.Embedding(embedding_dim[0], embedding_dim[1])
         embedding.weight = nn.Parameter(MSRVID.TEXT_FIELD.vocab.vectors)
         return MSRVID, embedding, train_loader, test_loader, dev_loader
     elif dataset_name == 'trecqa':
         if not os.path.exists('./utils/trec_eval-9.0.5/trec_eval'):
             raise FileNotFoundError('TrecQA requires the trec_eval tool to run. Please run get_trec_eval.sh inside utils/ (as working directory) before continuing.')
         dataset_root = os.path.join(os.pardir, 'data', 'TrecQA/')
         train_loader, dev_loader, test_loader = TRECQA.iters(dataset_root, word_vectors_file, word_vectors_dir, batch_size, device=device, unk_init=UnknownWordVecCache.unk)
         embedding_dim = TRECQA.TEXT_FIELD.vocab.vectors.size()
         embedding = nn.Embedding(embedding_dim[0], embedding_dim[1])
         embedding.weight = nn.Parameter(TRECQA.TEXT_FIELD.vocab.vectors)
         return TRECQA, embedding, train_loader, test_loader, dev_loader
     elif dataset_name == 'wikiqa':
         if not os.path.exists('./utils/trec_eval-9.0.5/trec_eval'):
             raise FileNotFoundError('WikiQA requires the trec_eval tool to run. Please run get_trec_eval.sh inside utils/ (as working directory) before continuing.')
         dataset_root = os.path.join(os.pardir, 'data', 'WikiQA/')
         train_loader, dev_loader, test_loader = WikiQA.iters(dataset_root, word_vectors_file, word_vectors_dir, batch_size, device=device, unk_init=UnknownWordVecCache.unk)
         embedding_dim = WikiQA.TEXT_FIELD.vocab.vectors.size()
         embedding = nn.Embedding(embedding_dim[0], embedding_dim[1])
         embedding.weight = nn.Parameter(WikiQA.TEXT_FIELD.vocab.vectors)
         return WikiQA, embedding, train_loader, test_loader, dev_loader
     else:
         raise ValueError('{} is not a valid dataset.'.format(dataset_name))
Example 3
def get_dataset(args):
    if args.dataset == 'sick':
        train_loader, dev_loader, test_loader = SICK.iters(
            batch_size=args.batch_size, device=args.device, shuffle=True)

        embedding_dim = SICK.TEXT.vocab.vectors.size()
        embedding = nn.Embedding(embedding_dim[0], embedding_dim[1])
        embedding.weight = nn.Parameter(SICK.TEXT.vocab.vectors)
        embedding.weight.requires_grad = False

        return SICK, train_loader, dev_loader, test_loader, embedding
    elif args.dataset == 'wikiqa':
        train_loader, dev_loader, test_loader = WikiQA.iters(
            batch_size=args.batch_size, device=args.device, shuffle=True)

        embedding_dim = WikiQA.TEXT.vocab.vectors.size()
        embedding = nn.Embedding(embedding_dim[0], embedding_dim[1])
        embedding.weight = nn.Parameter(WikiQA.TEXT.vocab.vectors)
        embedding.weight.requires_grad = False

        return WikiQA, train_loader, dev_loader, test_loader, embedding
    else:
        raise ValueError(f'Unrecognized dataset: {args.dataset}')
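
Example 3 freezes the pretrained matrix by hand (nn.Embedding, then nn.Parameter, then requires_grad = False). On PyTorch 0.4.1 and later those three steps collapse into nn.Embedding.from_pretrained, which Examples 1, 4, and 5 use; freeze=True is the default. A small self-contained check, where the random tensor stands in for TEXT.vocab.vectors:

import torch
import torch.nn as nn

vectors = torch.randn(100, 300)  # stand-in for SICK.TEXT.vocab.vectors

# Copies the weights and freezes them, matching the manual pattern above.
embedding = nn.Embedding.from_pretrained(vectors, freeze=True)
assert embedding.weight.requires_grad is False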
Example 4
def get_dataset(dataset_name,
                word_vectors_dir,
                word_vectors_file,
                batch_size,
                device,
                castor_dir="./",
                utils_trecqa="utils/trec_eval-9.0.5/trec_eval"):
     if dataset_name == 'sick':
         dataset_root = os.path.join(castor_dir, os.pardir, 'Castor-data',
                                     'datasets', 'sick/')
         train_loader, dev_loader, test_loader = SICK.iters(
             dataset_root,
             word_vectors_file,
             word_vectors_dir,
             batch_size,
             device=device,
             unk_init=UnknownWordVecCache.unk)
         embedding = nn.Embedding.from_pretrained(
             SICK.TEXT_FIELD.vocab.vectors)
         return SICK, embedding, train_loader, test_loader, dev_loader
     elif dataset_name == 'msrvid':
         dataset_root = os.path.join(castor_dir, os.pardir, 'Castor-data',
                                     'datasets', 'msrvid/')
         dev_loader = None
         train_loader, test_loader = MSRVID.iters(
             dataset_root,
             word_vectors_file,
             word_vectors_dir,
             batch_size,
             device=device,
             unk_init=UnknownWordVecCache.unk)
         embedding = nn.Embedding.from_pretrained(
             MSRVID.TEXT_FIELD.vocab.vectors)
         return MSRVID, embedding, train_loader, test_loader, dev_loader
     elif dataset_name == 'trecqa':
         if not os.path.exists(os.path.join(castor_dir, utils_trecqa)):
             raise FileNotFoundError(
                 'TrecQA requires the trec_eval tool to run. Please run get_trec_eval.sh inside Castor/utils (as working directory) before continuing.'
             )
         dataset_root = os.path.join(castor_dir, os.pardir, 'Castor-data',
                                     'datasets', 'TrecQA/')
         train_loader, dev_loader, test_loader = TRECQA.iters(
             dataset_root,
             word_vectors_file,
             word_vectors_dir,
             batch_size,
             device=device,
             unk_init=UnknownWordVecCache.unk)
         embedding = nn.Embedding.from_pretrained(
             TRECQA.TEXT_FIELD.vocab.vectors)
         return TRECQA, embedding, train_loader, test_loader, dev_loader
     elif dataset_name == 'wikiqa':
         if not os.path.exists(os.path.join(castor_dir, utils_trecqa)):
             raise FileNotFoundError(
                 'WikiQA requires the trec_eval tool to run. Please run get_trec_eval.sh inside Castor/utils (as working directory) before continuing.'
             )
         dataset_root = os.path.join(castor_dir, os.pardir, 'Castor-data',
                                     'datasets', 'WikiQA/')
         train_loader, dev_loader, test_loader = WikiQA.iters(
             dataset_root,
             word_vectors_file,
             word_vectors_dir,
             batch_size,
             device=device,
             unk_init=UnknownWordVecCache.unk)
         embedding = nn.Embedding.from_pretrained(
             WikiQA.TEXT_FIELD.vocab.vectors)
         return WikiQA, embedding, train_loader, test_loader, dev_loader
     elif dataset_name == 'pit2015':
         dataset_root = os.path.join(castor_dir, os.pardir, 'Castor-data',
                                     'datasets', 'SemEval-PIT2015/')
         train_loader, dev_loader, test_loader = PIT2015.iters(
             dataset_root,
             word_vectors_file,
             word_vectors_dir,
             batch_size,
             device=device,
             unk_init=UnknownWordVecCache.unk)
         embedding = nn.Embedding.from_pretrained(
             PIT2015.TEXT_FIELD.vocab.vectors)
         return PIT2015, embedding, train_loader, test_loader, dev_loader
     elif dataset_name == 'twitterurl':
         dataset_root = os.path.join(castor_dir, os.pardir, 'Castor-data',
                                     'datasets', 'Twitter-URL/')
         train_loader, dev_loader, test_loader = PIT2015.iters(
             dataset_root,
             word_vectors_file,
             word_vectors_dir,
             batch_size,
             device=device,
             unk_init=UnknownWordVecCache.unk)
         embedding = nn.Embedding.from_pretrained(
             PIT2015.TEXT_FIELD.vocab.vectors)
         return PIT2015, embedding, train_loader, test_loader, dev_loader
     elif dataset_name == 'snli':
         dataset_root = os.path.join(castor_dir, os.pardir, 'Castor-data',
                                     'datasets', 'snli_1.0/')
         train_loader, dev_loader, test_loader = SNLI.iters(
             dataset_root,
             word_vectors_file,
             word_vectors_dir,
             batch_size,
             device=device,
             unk_init=UnknownWordVecCache.unk)
         embedding = nn.Embedding.from_pretrained(
             SNLI.TEXT_FIELD.vocab.vectors)
         return SNLI, embedding, train_loader, test_loader, dev_loader
     elif dataset_name == 'sts2014':
         dataset_root = os.path.join(castor_dir, os.pardir, 'Castor-data',
                                     'datasets', 'STS-2014')
         train_loader, dev_loader, test_loader = STS2014.iters(
             dataset_root,
             word_vectors_file,
             word_vectors_dir,
             batch_size,
             device=device,
             unk_init=UnknownWordVecCache.unk)
         embedding = nn.Embedding.from_pretrained(
             STS2014.TEXT_FIELD.vocab.vectors)
         return STS2014, embedding, train_loader, test_loader, dev_loader
     elif dataset_name == "quora":
         dataset_root = os.path.join(castor_dir, os.pardir, 'Castor-data',
                                     'datasets', 'quora/')
         train_loader, dev_loader, test_loader = Quora.iters(
             dataset_root,
             word_vectors_file,
             word_vectors_dir,
             batch_size,
             device=device,
             unk_init=UnknownWordVecCache.unk)
         embedding = nn.Embedding.from_pretrained(
             Quora.TEXT_FIELD.vocab.vectors)
         return Quora, embedding, train_loader, test_loader, dev_loader
     elif dataset_name == 'reuters':
         dataset_root = os.path.join(castor_dir, os.pardir, 'Castor-data',
                                     'datasets', 'Reuters-21578/')
         train_loader, dev_loader, test_loader = Reuters.iters(
             dataset_root,
             word_vectors_file,
             word_vectors_dir,
             batch_size,
             device=device,
             unk_init=UnknownWordVecCache.unk)
         embedding = nn.Embedding.from_pretrained(
             Reuters.TEXT_FIELD.vocab.vectors)
         return Reuters, embedding, train_loader, test_loader, dev_loader
     elif dataset_name == 'aapd':
         dataset_root = os.path.join(castor_dir, os.pardir, 'Castor-data',
                                     'datasets', 'AAPD/')
         train_loader, dev_loader, test_loader = AAPD.iters(
             dataset_root,
             word_vectors_file,
             word_vectors_dir,
             batch_size,
             device=device,
             unk_init=UnknownWordVecCache.unk)
         embedding = nn.Embedding.from_pretrained(
             AAPD.TEXT_FIELD.vocab.vectors)
         return AAPD, embedding, train_loader, test_loader, dev_loader
     elif dataset_name == 'imdb':
         dataset_root = os.path.join(castor_dir, os.pardir, 'Castor-data',
                                     'datasets', 'IMDB/')
         train_loader, dev_loader, test_loader = IMDB.iters(
             dataset_root,
             word_vectors_file,
             word_vectors_dir,
             batch_size,
             device=device,
             unk_init=UnknownWordVecCache.unk)
         embedding = nn.Embedding.from_pretrained(
             IMDB.TEXT_FIELD.vocab.vectors)
         return IMDB, embedding, train_loader, test_loader, dev_loader
     else:
         raise ValueError('{} is not a valid dataset.'.format(dataset_name))
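
Every branch in Example 4 repeats the same steps: build dataset_root, call .iters, wrap the vocabulary vectors in an embedding, and return. A hedged sketch of a table-driven alternative, restricted to the first four datasets and omitting the trec_eval existence check for brevity; DATASETS and get_dataset_from_table are illustrative names, not part of the original code:

DATASETS = {
    'sick': (SICK, 'sick/'),
    'msrvid': (MSRVID, 'msrvid/'),
    'trecqa': (TRECQA, 'TrecQA/'),
    'wikiqa': (WikiQA, 'WikiQA/'),
}

def get_dataset_from_table(dataset_name, word_vectors_dir, word_vectors_file,
                           batch_size, device, castor_dir="./"):
    if dataset_name not in DATASETS:
        raise ValueError('{} is not a valid dataset.'.format(dataset_name))
    dataset_cls, subdir = DATASETS[dataset_name]
    dataset_root = os.path.join(castor_dir, os.pardir, 'Castor-data',
                                'datasets', subdir)
    loaders = dataset_cls.iters(dataset_root, word_vectors_file,
                                word_vectors_dir, batch_size, device=device,
                                unk_init=UnknownWordVecCache.unk)
    # MSRVID has no dev split, so .iters returns only two loaders there.
    if len(loaders) == 2:
        train_loader, test_loader = loaders
        dev_loader = None
    else:
        train_loader, dev_loader, test_loader = loaders
    embedding = nn.Embedding.from_pretrained(dataset_cls.TEXT_FIELD.vocab.vectors)
    return dataset_cls, embedding, train_loader, test_loader, dev_loader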
Example 5
def get_dataset(dataset_name, word_vectors_dir, word_vectors_file,
                batch_size, device):
     trec_eval_path = os.path.join(
         os.path.dirname(os.path.realpath(__file__)),
         'utils/trec_eval-9.0.5/trec_eval')
     if dataset_name == 'sick':
         dataset_root = os.path.join(
             os.path.dirname(os.path.realpath(__file__)), os.pardir,
             'Castor-data', 'datasets', 'sick/')
         train_loader, dev_loader, test_loader = SICK.iters(
             dataset_root,
             word_vectors_file,
             word_vectors_dir,
             batch_size,
             device=device,
             unk_init=UnknownWordVecCache.unk)
         embedding = nn.Embedding.from_pretrained(
             SICK.TEXT_FIELD.vocab.vectors)
         return SICK, embedding, train_loader, test_loader, dev_loader
     elif dataset_name == 'sts':
         dataset_root = os.path.join(
             os.path.dirname(os.path.realpath(__file__)), os.pardir,
             'Castor-data', 'datasets', 'sts/')
         train_loader, dev_loader, test_loader = STS.iters(
             dataset_root,
             word_vectors_file,
             word_vectors_dir,
             batch_size,
             device=device,
             unk_init=UnknownWordVecCache.unk)
         embedding = nn.Embedding.from_pretrained(
             STS.TEXT_FIELD.vocab.vectors)
         return STS, embedding, train_loader, test_loader, dev_loader
     elif dataset_name == 'msrp':
         dataset_root = os.path.join(
             os.path.dirname(os.path.realpath(__file__)), os.pardir,
             'Castor-data', 'datasets', 'msrp/')
         train_loader, dev_loader, test_loader = MSRP.iters(
             dataset_root,
             word_vectors_file,
             word_vectors_dir,
             batch_size,
             device=device,
             unk_init=UnknownWordVecCache.unk)
         embedding = nn.Embedding.from_pretrained(
             MSRP.TEXT_FIELD.vocab.vectors)
         return MSRP, embedding, train_loader, test_loader, dev_loader
     elif dataset_name == 'msrvid':
         dataset_root = os.path.join(
             os.path.dirname(os.path.realpath(__file__)), os.pardir,
             'Castor-data', 'datasets', 'msrvid/')
         dev_loader = None
         train_loader, test_loader = MSRVID.iters(
             dataset_root,
             word_vectors_file,
             word_vectors_dir,
             batch_size,
             device=device,
             unk_init=UnknownWordVecCache.unk)
         embedding = nn.Embedding.from_pretrained(
             MSRVID.TEXT_FIELD.vocab.vectors)
         return MSRVID, embedding, train_loader, test_loader, dev_loader
     elif dataset_name == 'trecqa':
         if not os.path.exists(trec_eval_path):
             raise FileNotFoundError(
                 'TrecQA requires the trec_eval tool to run. Please run get_trec_eval.sh inside utils/ (as working directory) before continuing.'
             )
         dataset_root = os.path.join(
             os.path.dirname(os.path.realpath(__file__)), os.pardir,
             'Castor-data', 'datasets', 'TrecQA/')
         train_loader, dev_loader, test_loader = TRECQA.iters(
             dataset_root,
             word_vectors_file,
             word_vectors_dir,
             batch_size,
             device=device,
             unk_init=UnknownWordVecCache.unk)
         embedding = nn.Embedding.from_pretrained(
             TRECQA.TEXT_FIELD.vocab.vectors)
         return TRECQA, embedding, train_loader, test_loader, dev_loader
     elif dataset_name == 'wikiqa':
         if not os.path.exists(trec_eval_path):
             raise FileNotFoundError(
                 'WikiQA requires the trec_eval tool to run. Please run get_trec_eval.sh inside utils/ (as working directory) before continuing.'
             )
         dataset_root = os.path.join(
             os.path.dirname(os.path.realpath(__file__)), os.pardir,
             'Castor-data', 'datasets', 'WikiQA/')
         train_loader, dev_loader, test_loader = WikiQA.iters(
             dataset_root,
             word_vectors_file,
             word_vectors_dir,
             batch_size,
             device=device,
             unk_init=UnknownWordVecCache.unk)
         embedding = nn.Embedding.from_pretrained(
             WikiQA.TEXT_FIELD.vocab.vectors)
         return WikiQA, embedding, train_loader, test_loader, dev_loader
     elif dataset_name == 'semeval':
         if not os.path.exists(trec_eval_path):
             raise FileNotFoundError(
                 'Semeval requires the trec_eval tool to run. Please run get_trec_eval.sh inside Castor/utils (as working directory) before continuing.'
             )
         dataset_root = './semeval/'
         train_loader, dev_loader, test_loader = CQA.iters(
             dataset_root,
             word_vectors_file,
             word_vectors_dir,
             batch_size,
             device=device,
             unk_init=UnknownWordVecCache.unk)
         embedding = nn.Embedding.from_pretrained(
             CQA.TEXT_FIELD.vocab.vectors)
         return CQA, embedding, train_loader, test_loader, dev_loader
     else:
         raise ValueError('{} is not a valid dataset.'.format(dataset_name))