# Seed all RNGs for reproducibility; only touch CUDA state when a GPU run is
# actually requested, so CPU-only machines do not crash on torch.cuda.set_device.
if args.cuda and torch.cuda.is_available():
    torch.cuda.set_device(args.gpu)
    torch.cuda.manual_seed(args.seed)
elif torch.cuda.is_available() and not args.cuda:
    print('Warning: CUDA is available but not enabled; training will run on the CPU.')
np.random.seed(args.seed)
random.seed(args.seed)

logger = get_logger()

# Set up the train/dev/test iterators for the selected dataset
if args.dataset == 'SST-1':
    train_iter, dev_iter, test_iter = SST1.iters(args.data_dir, args.word_vectors_file, args.word_vectors_dir,
                                                 batch_size=args.batch_size, device=args.gpu,
                                                 unk_init=UnknownWordVecCache.unk)
elif args.dataset == 'SST-2':
    train_iter, dev_iter, test_iter = SST2.iters(args.data_dir, args.word_vectors_file, args.word_vectors_dir,
                                                 batch_size=args.batch_size, device=args.gpu,
                                                 unk_init=UnknownWordVecCache.unk)
elif args.dataset == 'Reuters':
    train_iter, dev_iter, test_iter = Reuters.iters(args.data_dir, args.word_vectors_file, args.word_vectors_dir,
                                                    batch_size=args.batch_size, device=args.gpu,
                                                    unk_init=UnknownWordVecCache.unk)
else:
    raise ValueError('Unrecognized dataset: {}'.format(args.dataset))

# Derive the model configuration from the loaded dataset
config = deepcopy(args)
config.dataset = train_iter.dataset
config.target_class = train_iter.dataset.NUM_CLASSES
config.words_num = len(train_iter.dataset.TEXT_FIELD.vocab)

print('Dataset: {}  Mode: {}'.format(args.dataset, args.mode))
print('Vocabulary size:', len(train_iter.dataset.TEXT_FIELD.vocab))
print('Number of target classes:', train_iter.dataset.NUM_CLASSES)
print('Train instances:', len(train_iter.dataset))
print('Dev instances:', len(dev_iter.dataset))
print('Test instances:', len(test_iter.dataset))
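# The function below is a minimal sketch (deliberately left uncalled) of how
# the iterators built above are typically consumed. It is an assumption, not
# part of this codebase: `model` stands for any nn.Module mapping batch.text
# to class logits, and the `text`/`label` batch attributes follow the
# torchtext field names these dataset classes use; verify them against your
# own fields before relying on this.
def _train_one_epoch_sketch(model, optimizer, train_iter):
    import torch.nn.functional as F
    model.train()
    for batch in train_iter:
        optimizer.zero_grad()
        logits = model(batch.text)                   # (batch_size, target_class) scores
        # Single-label objective; a multi-label dataset such as Reuters would
        # need F.binary_cross_entropy_with_logits instead.
        loss = F.cross_entropy(logits, batch.label)
        loss.backward()
        optimizer.step()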
def get_dataset(dataset_name, word_vectors_dir, word_vectors_file, batch_size, device,
                castor_dir="./", utils_trecqa="utils/trec_eval-9.0.5/trec_eval"):
    """Load a dataset by name.

    Returns (dataset_cls, embedding, train_loader, test_loader, dev_loader);
    note that test_loader precedes dev_loader, and dev_loader is None for msrvid.
    """
    if dataset_name == 'sick':
        dataset_root = os.path.join(castor_dir, os.pardir, 'Castor-data', 'datasets', 'sick/')
        train_loader, dev_loader, test_loader = SICK.iters(dataset_root, word_vectors_file, word_vectors_dir,
                                                           batch_size, device=device, unk_init=UnknownWordVecCache.unk)
        embedding = nn.Embedding.from_pretrained(SICK.TEXT_FIELD.vocab.vectors)
        return SICK, embedding, train_loader, test_loader, dev_loader
    elif dataset_name == 'msrvid':
        dataset_root = os.path.join(castor_dir, os.pardir, 'Castor-data', 'datasets', 'msrvid/')
        dev_loader = None  # MSRVID ships without a dev split
        train_loader, test_loader = MSRVID.iters(dataset_root, word_vectors_file, word_vectors_dir,
                                                 batch_size, device=device, unk_init=UnknownWordVecCache.unk)
        embedding = nn.Embedding.from_pretrained(MSRVID.TEXT_FIELD.vocab.vectors)
        return MSRVID, embedding, train_loader, test_loader, dev_loader
    elif dataset_name == 'trecqa':
        if not os.path.exists(os.path.join(castor_dir, utils_trecqa)):
            raise FileNotFoundError('TrecQA requires the trec_eval tool to run. Please run get_trec_eval.sh '
                                    'inside Castor/utils (as working directory) before continuing.')
        dataset_root = os.path.join(castor_dir, os.pardir, 'Castor-data', 'datasets', 'TrecQA/')
        train_loader, dev_loader, test_loader = TRECQA.iters(dataset_root, word_vectors_file, word_vectors_dir,
                                                             batch_size, device=device, unk_init=UnknownWordVecCache.unk)
        embedding = nn.Embedding.from_pretrained(TRECQA.TEXT_FIELD.vocab.vectors)
        return TRECQA, embedding, train_loader, test_loader, dev_loader
    elif dataset_name == 'wikiqa':
        if not os.path.exists(os.path.join(castor_dir, utils_trecqa)):
            raise FileNotFoundError('WikiQA requires the trec_eval tool to run. Please run get_trec_eval.sh '
                                    'inside Castor/utils (as working directory) before continuing.')
        dataset_root = os.path.join(castor_dir, os.pardir, 'Castor-data', 'datasets', 'WikiQA/')
        train_loader, dev_loader, test_loader = WikiQA.iters(dataset_root, word_vectors_file, word_vectors_dir,
                                                             batch_size, device=device, unk_init=UnknownWordVecCache.unk)
        embedding = nn.Embedding.from_pretrained(WikiQA.TEXT_FIELD.vocab.vectors)
        return WikiQA, embedding, train_loader, test_loader, dev_loader
    elif dataset_name == 'pit2015':
        dataset_root = os.path.join(castor_dir, os.pardir, 'Castor-data', 'datasets', 'SemEval-PIT2015/')
        train_loader, dev_loader, test_loader = PIT2015.iters(dataset_root, word_vectors_file, word_vectors_dir,
                                                              batch_size, device=device, unk_init=UnknownWordVecCache.unk)
        embedding = nn.Embedding.from_pretrained(PIT2015.TEXT_FIELD.vocab.vectors)
        return PIT2015, embedding, train_loader, test_loader, dev_loader
    elif dataset_name == 'twitterurl':
        # The Twitter-URL corpus reuses the PIT2015 dataset class; only the data path differs.
        dataset_root = os.path.join(castor_dir, os.pardir, 'Castor-data', 'datasets', 'Twitter-URL/')
        train_loader, dev_loader, test_loader = PIT2015.iters(dataset_root, word_vectors_file, word_vectors_dir,
                                                              batch_size, device=device, unk_init=UnknownWordVecCache.unk)
        embedding = nn.Embedding.from_pretrained(PIT2015.TEXT_FIELD.vocab.vectors)
        return PIT2015, embedding, train_loader, test_loader, dev_loader
    elif dataset_name == 'snli':
        dataset_root = os.path.join(castor_dir, os.pardir, 'Castor-data', 'datasets', 'snli_1.0/')
        train_loader, dev_loader, test_loader = SNLI.iters(dataset_root, word_vectors_file, word_vectors_dir,
                                                           batch_size, device=device, unk_init=UnknownWordVecCache.unk)
        embedding = nn.Embedding.from_pretrained(SNLI.TEXT_FIELD.vocab.vectors)
        return SNLI, embedding, train_loader, test_loader, dev_loader
    elif dataset_name == 'sts2014':
        dataset_root = os.path.join(castor_dir, os.pardir, 'Castor-data', 'datasets', 'STS-2014')
        train_loader, dev_loader, test_loader = STS2014.iters(dataset_root, word_vectors_file, word_vectors_dir,
                                                              batch_size, device=device, unk_init=UnknownWordVecCache.unk)
        embedding = nn.Embedding.from_pretrained(STS2014.TEXT_FIELD.vocab.vectors)
        return STS2014, embedding, train_loader, test_loader, dev_loader
    elif dataset_name == 'quora':
        dataset_root = os.path.join(castor_dir, os.pardir, 'Castor-data', 'datasets', 'quora/')
        train_loader, dev_loader, test_loader = Quora.iters(dataset_root, word_vectors_file, word_vectors_dir,
                                                            batch_size, device=device, unk_init=UnknownWordVecCache.unk)
        embedding = nn.Embedding.from_pretrained(Quora.TEXT_FIELD.vocab.vectors)
        return Quora, embedding, train_loader, test_loader, dev_loader
    elif dataset_name == 'reuters':
        dataset_root = os.path.join(castor_dir, os.pardir, 'Castor-data', 'datasets', 'Reuters-21578/')
        train_loader, dev_loader, test_loader = Reuters.iters(dataset_root, word_vectors_file, word_vectors_dir,
                                                              batch_size, device=device, unk_init=UnknownWordVecCache.unk)
        embedding = nn.Embedding.from_pretrained(Reuters.TEXT_FIELD.vocab.vectors)
        return Reuters, embedding, train_loader, test_loader, dev_loader
    elif dataset_name == 'aapd':
        dataset_root = os.path.join(castor_dir, os.pardir, 'Castor-data', 'datasets', 'AAPD/')
        train_loader, dev_loader, test_loader = AAPD.iters(dataset_root, word_vectors_file, word_vectors_dir,
                                                           batch_size, device=device, unk_init=UnknownWordVecCache.unk)
        embedding = nn.Embedding.from_pretrained(AAPD.TEXT_FIELD.vocab.vectors)
        return AAPD, embedding, train_loader, test_loader, dev_loader
    elif dataset_name == 'imdb':
        dataset_root = os.path.join(castor_dir, os.pardir, 'Castor-data', 'datasets', 'IMDB/')
        train_loader, dev_loader, test_loader = IMDB.iters(dataset_root, word_vectors_file, word_vectors_dir,
                                                           batch_size, device=device, unk_init=UnknownWordVecCache.unk)
        embedding = nn.Embedding.from_pretrained(IMDB.TEXT_FIELD.vocab.vectors)
        return IMDB, embedding, train_loader, test_loader, dev_loader
    else:
        raise ValueError('{} is not a valid dataset.'.format(dataset_name))
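# A hypothetical usage sketch for get_dataset, left uncalled. The word-vector
# file name, directory, and batch size below are placeholder assumptions, not
# values this repo mandates. It shows the return order (test_loader before
# dev_loader) and that `device` is passed straight through to the torchtext
# iterators, which under classic torchtext take a GPU index or -1 for CPU.
def _get_dataset_usage_sketch():
    dataset_cls, embedding, train_loader, test_loader, dev_loader = get_dataset(
        'reuters',
        word_vectors_dir='Castor-data/embeddings',  # assumed embeddings location
        word_vectors_file='glove.840B.300d.txt',    # assumed pretrained vectors
        batch_size=32,
        device=0)
    print('Target classes:', dataset_cls.NUM_CLASSES)
    print('Embedding matrix shape:', embedding.weight.shape)  # (vocab size, vector dim)
    print('Train batches per epoch:', len(train_loader))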