# Resolve a dataset name to its dataset class, a pretrained embedding layer,
# and data loaders. Note the return order: train, test, dev.
def get_dataset(dataset_name, word_vectors_dir, word_vectors_file, batch_size, device,
                castor_dir="./", utils_trecqa="utils/trec_eval-9.0.5/trec_eval"):
    if dataset_name == 'sick':
        dataset_root = os.path.join(castor_dir, os.pardir, 'Castor-data', 'sick/')
        train_loader, dev_loader, test_loader = SICK.iters(dataset_root, word_vectors_file, word_vectors_dir,
                                                           batch_size, device=device, unk_init=UnknownWordVecCache.unk)
        embedding = nn.Embedding.from_pretrained(SICK.TEXT_FIELD.vocab.vectors)
        return SICK, embedding, train_loader, test_loader, dev_loader
    elif dataset_name == 'msrvid':
        dataset_root = os.path.join(castor_dir, os.pardir, 'Castor-data', 'msrvid/')
        dev_loader = None  # MSRVID ships no dev split
        train_loader, test_loader = MSRVID.iters(dataset_root, word_vectors_file, word_vectors_dir,
                                                 batch_size, device=device, unk_init=UnknownWordVecCache.unk)
        embedding = nn.Embedding.from_pretrained(MSRVID.TEXT_FIELD.vocab.vectors)
        return MSRVID, embedding, train_loader, test_loader, dev_loader
    elif dataset_name == 'trecqa':
        if not os.path.exists(os.path.join(castor_dir, utils_trecqa)):
            raise FileNotFoundError('TrecQA requires the trec_eval tool to run. Please run get_trec_eval.sh '
                                    'inside Castor/utils (as working directory) before continuing.')
        dataset_root = os.path.join(castor_dir, os.pardir, 'Castor-data', 'TrecQA/')
        train_loader, dev_loader, test_loader = TRECQA.iters(dataset_root, word_vectors_file, word_vectors_dir,
                                                             batch_size, device=device, unk_init=UnknownWordVecCache.unk)
        embedding = nn.Embedding.from_pretrained(TRECQA.TEXT_FIELD.vocab.vectors)
        return TRECQA, embedding, train_loader, test_loader, dev_loader
    elif dataset_name == 'wikiqa':
        if not os.path.exists(os.path.join(castor_dir, utils_trecqa)):
            raise FileNotFoundError('WikiQA requires the trec_eval tool to run. Please run get_trec_eval.sh '
                                    'inside Castor/utils (as working directory) before continuing.')
        dataset_root = os.path.join(castor_dir, os.pardir, 'Castor-data', 'WikiQA/')
        train_loader, dev_loader, test_loader = WikiQA.iters(dataset_root, word_vectors_file, word_vectors_dir,
                                                             batch_size, device=device, unk_init=UnknownWordVecCache.unk)
        embedding = nn.Embedding.from_pretrained(WikiQA.TEXT_FIELD.vocab.vectors)
        return WikiQA, embedding, train_loader, test_loader, dev_loader
    else:
        raise ValueError('{} is not a valid dataset.'.format(dataset_name))
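# A minimal usage sketch for get_dataset above (the embeddings directory,
# file name, and batch size here are hypothetical; assumes PyTorch plus the
# Castor-data layout referenced in the FileNotFoundError messages). Note the
# return order again: train, test, dev.
import torch

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
dataset_cls, embedding, train_loader, test_loader, dev_loader = get_dataset(
    'sick', 'embeddings', 'glove.840B.300d.txt', batch_size=64, device=device)
for batch in train_loader:
    pass  # feed the batch and the frozen embedding layer to a model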
def get_ds(name, path, word_to_index, index_to_embedding, qmax, amax, char_min, num_neg):
    if name == 'yahooqa':
        dataset = YahooQA(path, word_to_index, index_to_embedding, qmax, amax, char_min, num_neg)
        tf.logging.info('YahooDS loaded')
    elif name == 'wikiqa':
        dataset = WikiQA(path, word_to_index, index_to_embedding, qmax, amax, char_min, num_neg)
        tf.logging.info('WikiQA loaded')
    else:
        # Fail fast rather than returning an undefined `dataset` for unknown names.
        raise ValueError('{} is not a valid dataset.'.format(name))
    return dataset
def get_dataset(dataset_name, word_vectors_dir, word_vectors_file, batch_size, device):
    if dataset_name == 'sick':
        dataset_root = os.path.join(os.pardir, 'data', 'sick/')
        train_loader, dev_loader, test_loader = SICK.iters(dataset_root, word_vectors_file, word_vectors_dir,
                                                           batch_size, device=device, unk_init=UnknownWordVecCache.unk)
        embedding_dim = SICK.TEXT_FIELD.vocab.vectors.size()
        embedding = nn.Embedding(embedding_dim[0], embedding_dim[1])
        embedding.weight = nn.Parameter(SICK.TEXT_FIELD.vocab.vectors)
        return SICK, embedding, train_loader, test_loader, dev_loader
    elif dataset_name == 'sts':
        dataset_root = os.path.join(os.pardir, 'data', 'sts/')
        train_loader, dev_loader, test_loader = STS.iters(dataset_root, word_vectors_file, word_vectors_dir,
                                                          batch_size, device=device, unk_init=UnknownWordVecCache.unk)
        embedding_dim = STS.TEXT_FIELD.vocab.vectors.size()
        embedding = nn.Embedding(embedding_dim[0], embedding_dim[1])
        embedding.weight = nn.Parameter(STS.TEXT_FIELD.vocab.vectors)
        return STS, embedding, train_loader, test_loader, dev_loader
    elif dataset_name == 'msrvid':
        dataset_root = os.path.join(os.pardir, 'data', 'msrvid/')
        dev_loader = None
        train_loader, test_loader = MSRVID.iters(dataset_root, word_vectors_file, word_vectors_dir,
                                                 batch_size, device=device, unk_init=UnknownWordVecCache.unk)
        embedding_dim = MSRVID.TEXT_FIELD.vocab.vectors.size()
        embedding = nn.Embedding(embedding_dim[0], embedding_dim[1])
        embedding.weight = nn.Parameter(MSRVID.TEXT_FIELD.vocab.vectors)
        return MSRVID, embedding, train_loader, test_loader, dev_loader
    elif dataset_name == 'trecqa':
        if not os.path.exists('./utils/trec_eval-9.0.5/trec_eval'):
            raise FileNotFoundError('TrecQA requires the trec_eval tool to run. Please run get_trec_eval.sh '
                                    'inside utils/ (as working directory) before continuing.')
        dataset_root = os.path.join(os.pardir, 'data', 'TrecQA/')
        train_loader, dev_loader, test_loader = TRECQA.iters(dataset_root, word_vectors_file, word_vectors_dir,
                                                             batch_size, device=device, unk_init=UnknownWordVecCache.unk)
        embedding_dim = TRECQA.TEXT_FIELD.vocab.vectors.size()
        embedding = nn.Embedding(embedding_dim[0], embedding_dim[1])
        embedding.weight = nn.Parameter(TRECQA.TEXT_FIELD.vocab.vectors)
        return TRECQA, embedding, train_loader, test_loader, dev_loader
    elif dataset_name == 'wikiqa':
        if not os.path.exists('./utils/trec_eval-9.0.5/trec_eval'):
            raise FileNotFoundError('WikiQA requires the trec_eval tool to run. Please run get_trec_eval.sh '
                                    'inside utils/ (as working directory) before continuing.')
        dataset_root = os.path.join(os.pardir, 'data', 'WikiQA/')
        train_loader, dev_loader, test_loader = WikiQA.iters(dataset_root, word_vectors_file, word_vectors_dir,
                                                             batch_size, device=device, unk_init=UnknownWordVecCache.unk)
        embedding_dim = WikiQA.TEXT_FIELD.vocab.vectors.size()
        embedding = nn.Embedding(embedding_dim[0], embedding_dim[1])
        embedding.weight = nn.Parameter(WikiQA.TEXT_FIELD.vocab.vectors)
        return WikiQA, embedding, train_loader, test_loader, dev_loader
    else:
        raise ValueError('{} is not a valid dataset.'.format(dataset_name))
def get_dataset(args):
    if args.dataset == 'sick':
        train_loader, dev_loader, test_loader = SICK.iters(
            batch_size=args.batch_size, device=args.device, shuffle=True)
        embedding_dim = SICK.TEXT.vocab.vectors.size()
        embedding = nn.Embedding(embedding_dim[0], embedding_dim[1])
        embedding.weight = nn.Parameter(SICK.TEXT.vocab.vectors)
        embedding.weight.requires_grad = False
        return SICK, train_loader, dev_loader, test_loader, embedding
    elif args.dataset == 'wikiqa':
        train_loader, dev_loader, test_loader = WikiQA.iters(
            batch_size=args.batch_size, device=args.device, shuffle=True)
        embedding_dim = WikiQA.TEXT.vocab.vectors.size()
        embedding = nn.Embedding(embedding_dim[0], embedding_dim[1])
        embedding.weight = nn.Parameter(WikiQA.TEXT.vocab.vectors)
        embedding.weight.requires_grad = False
        return WikiQA, train_loader, dev_loader, test_loader, embedding
    else:
        raise ValueError(f'Unrecognized dataset: {args.dataset}')
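# The manual pattern above (nn.Embedding + nn.Parameter + requires_grad=False)
# does the same thing as the nn.Embedding.from_pretrained call used by the
# other variants, since from_pretrained freezes the weights by default
# (freeze=True). A self-contained sketch of the equivalence:
import torch
import torch.nn as nn

vectors = torch.randn(100, 300)  # stand-in for TEXT.vocab.vectors
frozen = nn.Embedding.from_pretrained(vectors)  # freeze=True is the default
manual = nn.Embedding(vectors.size(0), vectors.size(1))
manual.weight = nn.Parameter(vectors)
manual.weight.requires_grad = False
assert torch.equal(frozen.weight, manual.weight)
assert not frozen.weight.requires_grad and not manual.weight.requires_grad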
if config.dataset == 'TREC':
    dataset_root = os.path.join(os.pardir, os.pardir, 'data', 'TrecQA/')
    train_iter, dev_iter, test_iter = TRECQA.iters(dataset_root, word_vectors_name, word_vectors_dir,
                                                   args.batch_size, device=args.gpu)
    embedding_dim = TRECQA.TEXT_FIELD.vocab.vectors.size()
    embedding = nn.Embedding(embedding_dim[0], embedding_dim[1])
    embedding.weight = nn.Parameter(TRECQA.TEXT_FIELD.vocab.vectors)
else:
    dataset_root = os.path.join(os.pardir, os.pardir, 'data', 'WikiQA/')
    train_iter, dev_iter, test_iter = WikiQA.iters(dataset_root, word_vectors_name, word_vectors_dir,
                                                   args.batch_size, device=args.gpu)
    embedding_dim = WikiQA.TEXT_FIELD.vocab.vectors.size()
    embedding = nn.Embedding(embedding_dim[0], embedding_dim[1])
    embedding.weight = nn.Parameter(WikiQA.TEXT_FIELD.vocab.vectors)
# Freeze the pretrained vectors for both datasets.
embedding.weight.requires_grad = False
snapshot_path = os.path.join(args.save_path, args.dataset, 'static_best_model.pt')
if args.gpu != -1:
    with torch.cuda.device(args.gpu):
        embedding = embedding.cuda()
print("Dataset {}".format(args.dataset))
def get_dataset(dataset_name, word_vectors_dir, word_vectors_file, batch_size, device,
                castor_dir="./", utils_trecqa="utils/trec_eval-9.0.5/trec_eval"):
    if dataset_name == 'sick':
        dataset_root = os.path.join(castor_dir, os.pardir, 'Castor-data', 'datasets', 'sick/')
        train_loader, dev_loader, test_loader = SICK.iters(
            dataset_root, word_vectors_file, word_vectors_dir, batch_size,
            device=device, unk_init=UnknownWordVecCache.unk)
        embedding = nn.Embedding.from_pretrained(SICK.TEXT_FIELD.vocab.vectors)
        return SICK, embedding, train_loader, test_loader, dev_loader
    elif dataset_name == 'msrvid':
        dataset_root = os.path.join(castor_dir, os.pardir, 'Castor-data', 'datasets', 'msrvid/')
        dev_loader = None
        train_loader, test_loader = MSRVID.iters(
            dataset_root, word_vectors_file, word_vectors_dir, batch_size,
            device=device, unk_init=UnknownWordVecCache.unk)
        embedding = nn.Embedding.from_pretrained(MSRVID.TEXT_FIELD.vocab.vectors)
        return MSRVID, embedding, train_loader, test_loader, dev_loader
    elif dataset_name == 'trecqa':
        if not os.path.exists(os.path.join(castor_dir, utils_trecqa)):
            raise FileNotFoundError('TrecQA requires the trec_eval tool to run. Please run get_trec_eval.sh '
                                    'inside Castor/utils (as working directory) before continuing.')
        dataset_root = os.path.join(castor_dir, os.pardir, 'Castor-data', 'datasets', 'TrecQA/')
        train_loader, dev_loader, test_loader = TRECQA.iters(
            dataset_root, word_vectors_file, word_vectors_dir, batch_size,
            device=device, unk_init=UnknownWordVecCache.unk)
        embedding = nn.Embedding.from_pretrained(TRECQA.TEXT_FIELD.vocab.vectors)
        return TRECQA, embedding, train_loader, test_loader, dev_loader
    elif dataset_name == 'wikiqa':
        if not os.path.exists(os.path.join(castor_dir, utils_trecqa)):
            raise FileNotFoundError('WikiQA requires the trec_eval tool to run. Please run get_trec_eval.sh '
                                    'inside Castor/utils (as working directory) before continuing.')
        dataset_root = os.path.join(castor_dir, os.pardir, 'Castor-data', 'datasets', 'WikiQA/')
        train_loader, dev_loader, test_loader = WikiQA.iters(
            dataset_root, word_vectors_file, word_vectors_dir, batch_size,
            device=device, unk_init=UnknownWordVecCache.unk)
        embedding = nn.Embedding.from_pretrained(WikiQA.TEXT_FIELD.vocab.vectors)
        return WikiQA, embedding, train_loader, test_loader, dev_loader
    elif dataset_name == 'pit2015':
        dataset_root = os.path.join(castor_dir, os.pardir, 'Castor-data', 'datasets', 'SemEval-PIT2015/')
        train_loader, dev_loader, test_loader = PIT2015.iters(
            dataset_root, word_vectors_file, word_vectors_dir, batch_size,
            device=device, unk_init=UnknownWordVecCache.unk)
        embedding = nn.Embedding.from_pretrained(PIT2015.TEXT_FIELD.vocab.vectors)
        return PIT2015, embedding, train_loader, test_loader, dev_loader
    elif dataset_name == 'twitterurl':
        # The Twitter-URL corpus reuses the PIT2015 dataset class.
        dataset_root = os.path.join(castor_dir, os.pardir, 'Castor-data', 'datasets', 'Twitter-URL/')
        train_loader, dev_loader, test_loader = PIT2015.iters(
            dataset_root, word_vectors_file, word_vectors_dir, batch_size,
            device=device, unk_init=UnknownWordVecCache.unk)
        embedding = nn.Embedding.from_pretrained(PIT2015.TEXT_FIELD.vocab.vectors)
        return PIT2015, embedding, train_loader, test_loader, dev_loader
    elif dataset_name == 'snli':
        dataset_root = os.path.join(castor_dir, os.pardir, 'Castor-data', 'datasets', 'snli_1.0/')
        train_loader, dev_loader, test_loader = SNLI.iters(
            dataset_root, word_vectors_file, word_vectors_dir, batch_size,
            device=device, unk_init=UnknownWordVecCache.unk)
        embedding = nn.Embedding.from_pretrained(SNLI.TEXT_FIELD.vocab.vectors)
        return SNLI, embedding, train_loader, test_loader, dev_loader
    elif dataset_name == 'sts2014':
        dataset_root = os.path.join(castor_dir, os.pardir, 'Castor-data', 'datasets', 'STS-2014')
        train_loader, dev_loader, test_loader = STS2014.iters(
            dataset_root, word_vectors_file, word_vectors_dir, batch_size,
            device=device, unk_init=UnknownWordVecCache.unk)
        embedding = nn.Embedding.from_pretrained(STS2014.TEXT_FIELD.vocab.vectors)
        return STS2014, embedding, train_loader, test_loader, dev_loader
    elif dataset_name == 'quora':
        dataset_root = os.path.join(castor_dir, os.pardir, 'Castor-data', 'datasets', 'quora/')
        train_loader, dev_loader, test_loader = Quora.iters(
            dataset_root, word_vectors_file, word_vectors_dir, batch_size,
            device=device, unk_init=UnknownWordVecCache.unk)
        embedding = nn.Embedding.from_pretrained(Quora.TEXT_FIELD.vocab.vectors)
        return Quora, embedding, train_loader, test_loader, dev_loader
    elif dataset_name == 'reuters':
        dataset_root = os.path.join(castor_dir, os.pardir, 'Castor-data', 'datasets', 'Reuters-21578/')
        train_loader, dev_loader, test_loader = Reuters.iters(
            dataset_root, word_vectors_file, word_vectors_dir, batch_size,
            device=device, unk_init=UnknownWordVecCache.unk)
        embedding = nn.Embedding.from_pretrained(Reuters.TEXT_FIELD.vocab.vectors)
        return Reuters, embedding, train_loader, test_loader, dev_loader
    elif dataset_name == 'aapd':
        dataset_root = os.path.join(castor_dir, os.pardir, 'Castor-data', 'datasets', 'AAPD/')
        train_loader, dev_loader, test_loader = AAPD.iters(
            dataset_root, word_vectors_file, word_vectors_dir, batch_size,
            device=device, unk_init=UnknownWordVecCache.unk)
        embedding = nn.Embedding.from_pretrained(AAPD.TEXT_FIELD.vocab.vectors)
        return AAPD, embedding, train_loader, test_loader, dev_loader
    elif dataset_name == 'imdb':
        dataset_root = os.path.join(castor_dir, os.pardir, 'Castor-data', 'datasets', 'IMDB/')
        train_loader, dev_loader, test_loader = IMDB.iters(
            dataset_root, word_vectors_file, word_vectors_dir, batch_size,
            device=device, unk_init=UnknownWordVecCache.unk)
        embedding = nn.Embedding.from_pretrained(IMDB.TEXT_FIELD.vocab.vectors)
        return IMDB, embedding, train_loader, test_loader, dev_loader
    else:
        raise ValueError('{} is not a valid dataset.'.format(dataset_name))
def get_dataset(dataset_name, word_vectors_dir, word_vectors_file, batch_size, device):
    trec_eval_path = os.path.join(os.path.dirname(os.path.realpath(__file__)),
                                  'utils/trec_eval-9.0.5/trec_eval')
    if dataset_name == 'sick':
        dataset_root = os.path.join(os.path.dirname(os.path.realpath(__file__)),
                                    os.pardir, 'Castor-data', 'datasets', 'sick/')
        train_loader, dev_loader, test_loader = SICK.iters(
            dataset_root, word_vectors_file, word_vectors_dir, batch_size,
            device=device, unk_init=UnknownWordVecCache.unk)
        embedding = nn.Embedding.from_pretrained(SICK.TEXT_FIELD.vocab.vectors)
        return SICK, embedding, train_loader, test_loader, dev_loader
    elif dataset_name == 'sts':
        dataset_root = os.path.join(os.path.dirname(os.path.realpath(__file__)),
                                    os.pardir, 'Castor-data', 'datasets', 'sts/')
        train_loader, dev_loader, test_loader = STS.iters(
            dataset_root, word_vectors_file, word_vectors_dir, batch_size,
            device=device, unk_init=UnknownWordVecCache.unk)
        embedding = nn.Embedding.from_pretrained(STS.TEXT_FIELD.vocab.vectors)
        return STS, embedding, train_loader, test_loader, dev_loader
    elif dataset_name == 'msrp':
        dataset_root = os.path.join(os.path.dirname(os.path.realpath(__file__)),
                                    os.pardir, 'Castor-data', 'datasets', 'msrp/')
        train_loader, dev_loader, test_loader = MSRP.iters(
            dataset_root, word_vectors_file, word_vectors_dir, batch_size,
            device=device, unk_init=UnknownWordVecCache.unk)
        embedding = nn.Embedding.from_pretrained(MSRP.TEXT_FIELD.vocab.vectors)
        return MSRP, embedding, train_loader, test_loader, dev_loader
    elif dataset_name == 'msrvid':
        dataset_root = os.path.join(os.path.dirname(os.path.realpath(__file__)),
                                    os.pardir, 'Castor-data', 'datasets', 'msrvid/')
        dev_loader = None
        train_loader, test_loader = MSRVID.iters(
            dataset_root, word_vectors_file, word_vectors_dir, batch_size,
            device=device, unk_init=UnknownWordVecCache.unk)
        embedding = nn.Embedding.from_pretrained(MSRVID.TEXT_FIELD.vocab.vectors)
        return MSRVID, embedding, train_loader, test_loader, dev_loader
    elif dataset_name == 'trecqa':
        if not os.path.exists(trec_eval_path):
            raise FileNotFoundError('TrecQA requires the trec_eval tool to run. Please run get_trec_eval.sh '
                                    'inside utils/ (as working directory) before continuing.')
        dataset_root = os.path.join(os.path.dirname(os.path.realpath(__file__)),
                                    os.pardir, 'Castor-data', 'datasets', 'TrecQA/')
        train_loader, dev_loader, test_loader = TRECQA.iters(
            dataset_root, word_vectors_file, word_vectors_dir, batch_size,
            device=device, unk_init=UnknownWordVecCache.unk)
        embedding = nn.Embedding.from_pretrained(TRECQA.TEXT_FIELD.vocab.vectors)
        return TRECQA, embedding, train_loader, test_loader, dev_loader
    elif dataset_name == 'wikiqa':
        if not os.path.exists(trec_eval_path):
            raise FileNotFoundError('WikiQA requires the trec_eval tool to run. Please run get_trec_eval.sh '
                                    'inside utils/ (as working directory) before continuing.')
        dataset_root = os.path.join(os.path.dirname(os.path.realpath(__file__)),
                                    os.pardir, 'Castor-data', 'datasets', 'WikiQA/')
        train_loader, dev_loader, test_loader = WikiQA.iters(
            dataset_root, word_vectors_file, word_vectors_dir, batch_size,
            device=device, unk_init=UnknownWordVecCache.unk)
        embedding = nn.Embedding.from_pretrained(WikiQA.TEXT_FIELD.vocab.vectors)
        return WikiQA, embedding, train_loader, test_loader, dev_loader
    elif dataset_name == 'semeval':
        if not os.path.exists(trec_eval_path):
            raise FileNotFoundError('SemEval requires the trec_eval tool to run. Please run get_trec_eval.sh '
                                    'inside utils/ (as working directory) before continuing.')
        dataset_root = './semeval/'
        train_loader, dev_loader, test_loader = CQA.iters(
            dataset_root, word_vectors_file, word_vectors_dir, batch_size,
            device=device, unk_init=UnknownWordVecCache.unk)
        embedding = nn.Embedding.from_pretrained(CQA.TEXT_FIELD.vocab.vectors)
        return CQA, embedding, train_loader, test_loader, dev_loader
    else:
        raise ValueError('{} is not a valid dataset.'.format(dataset_name))
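# Hedged usage sketch for the trec_eval-gated branches above (the embedding
# paths, batch size, and device here are hypothetical): the existence check
# fails before any data is loaded, so callers can surface the setup
# instruction directly.
try:
    dataset_cls, embedding, train_loader, test_loader, dev_loader = get_dataset(
        'trecqa', 'embeddings', 'glove.840B.300d.txt', 64, device)
except FileNotFoundError as err:
    print(err)  # build trec_eval via utils/get_trec_eval.sh, then retry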