def _setup_datasets(root='.data', ngrams=1, vocab=None, include_unk=False):
    """Build train/test text-classification datasets from DSL files in *root*.

    Scans *root* for files ending in 'DSL-TRAIN.txt' and 'DSL-TEST-GOLD.txt',
    builds (or validates) a vocabulary from the training file, and converts
    both files into datasets.

    Args:
        root: directory containing the data files.
        ngrams: n-gram order passed to the CSV iterator.
        vocab: optional pre-built ``Vocab``; built from training data when None.
        include_unk: forwarded to ``_create_data_from_iterator``.

    Returns:
        Tuple of (train, test) ``TextClassificationDataset``.

    Raises:
        FileNotFoundError: if either expected data file is missing from *root*.
        TypeError: if *vocab* is given but is not a ``Vocab`` instance.
        ValueError: if the train and test label sets differ.
    """
    train_csv_path = test_csv_path = None
    for fname in os.listdir(root):
        if fname.endswith('DSL-TRAIN.txt'):
            train_csv_path = os.path.join(root, fname)
        if fname.endswith('DSL-TEST-GOLD.txt'):
            test_csv_path = os.path.join(root, fname)
    # Fail early with a clear message instead of a NameError further down
    # when one of the expected files is absent.
    if train_csv_path is None or test_csv_path is None:
        raise FileNotFoundError(
            'Expected DSL-TRAIN.txt and DSL-TEST-GOLD.txt under {}'.format(root))
    if vocab is None:
        logging.info('Building Vocab based on {}'.format(train_csv_path))
        vocab = build_vocab_from_iterator(_csv_iterator(train_csv_path, ngrams))
    elif not isinstance(vocab, Vocab):
        raise TypeError("Passed vocabulary is not of type Vocab")
    logging.info('Vocab has {} entries'.format(len(vocab)))
    logging.info('Creating training data')
    train_data, train_labels = _create_data_from_iterator(
        vocab, _csv_iterator(train_csv_path, ngrams, yield_cls=True), include_unk)
    logging.info('Creating testing data')
    test_data, test_labels = _create_data_from_iterator(
        vocab, _csv_iterator(test_csv_path, ngrams, yield_cls=True), include_unk)
    # Symmetric difference exposes any label present in only one split.
    if len(train_labels ^ test_labels) > 0:
        raise ValueError("Training and test labels don't match")
    return (TextClassificationDataset(vocab, train_data, train_labels),
            TextClassificationDataset(vocab, test_data, test_labels))
def make_data(path_root='../data/ag_news_csv.tgz', ngrams=2, vocab=None, include_unk=False):
    """Extract an AG-News-style archive and build train/test datasets.

    Args:
        path_root: path to the ``.tgz`` archive to extract.
        ngrams: n-gram order passed to the CSV iterator.
        vocab: optional pre-built vocabulary. Previously this argument was
            silently ignored and the vocab was always rebuilt; it is now
            honored when provided (None keeps the old behavior).
        include_unk: forwarded to ``_create_data_from_iterator``.

    Returns:
        Tuple of (train, test) ``TextClassificationDataset``.

    Raises:
        FileNotFoundError: if the archive lacks train.csv or test.csv.
        ValueError: if the train and test label sets differ.
    """
    extracted_files = extract_archive(path_root)
    train_csv_path = test_csv_path = None
    for fname in extracted_files:
        if fname.endswith('train.csv'):
            train_csv_path = fname
        if fname.endswith('test.csv'):
            test_csv_path = fname
    # Guard against an archive with unexpected contents (avoids NameError).
    if train_csv_path is None or test_csv_path is None:
        raise FileNotFoundError(
            'Archive {} did not contain train.csv and test.csv'.format(path_root))
    # Only build a vocabulary when the caller did not supply one; the old
    # code rebuilt it unconditionally, making the `vocab` parameter dead.
    if vocab is None:
        vocab = build_vocab_from_iterator(_csv_iterator(train_csv_path, ngrams))
    logging.info('Vocab has {} entries'.format(len(vocab)))
    logging.info('Creating training data')
    train_data, train_labels = _create_data_from_iterator(
        vocab, _csv_iterator(train_csv_path, ngrams, yield_cls=True), include_unk)
    logging.info('Creating testing data')
    test_data, test_labels = _create_data_from_iterator(
        vocab, _csv_iterator(test_csv_path, ngrams, yield_cls=True), include_unk)
    # Symmetric difference exposes any label present in only one split.
    if len(train_labels ^ test_labels) > 0:
        raise ValueError("Training and test labels don't match")
    return (TextClassificationDataset(vocab, train_data, train_labels),
            TextClassificationDataset(vocab, test_data, test_labels))
def _setup_datasets(dataset_name, root='.data', ngrams=1, vocab=None, include_unk=False):
    """Build (train, test, vocab) from a locally extracted AG-News archive.

    NOTE(review): the download step is disabled and the archive path is
    hard-coded to '.data/ag_news_csv.tar.gz'; *dataset_name* and *root* are
    currently not used to locate the data — confirm whether the download
    via ``download_from_url(URLS[dataset_name], root=root)`` should be
    restored.

    Args:
        dataset_name: dataset key (currently unused, see NOTE above).
        root: data root directory (currently unused, see NOTE above).
        ngrams: n-gram order passed to the CSV iterator.
        vocab: optional pre-built ``Vocab``; built from training data when None.
        include_unk: forwarded to ``_create_data_from_iterator``.

    Returns:
        Tuple of (train dataset, test dataset, vocab).

    Raises:
        FileNotFoundError: if the archive lacks train.csv or test.csv.
        TypeError: if *vocab* is given but is not a ``Vocab`` instance.
        ValueError: if the train and test label sets differ.
    """
    # dataset_tar = download_from_url(URLS[dataset_name], root=root)
    extracted_files = extract_archive('.data/ag_news_csv.tar.gz')
    train_csv_path = test_csv_path = None
    for fname in extracted_files:
        if fname.endswith('train.csv'):
            train_csv_path = fname
        if fname.endswith('test.csv'):
            test_csv_path = fname
    # Guard against an archive with unexpected contents (avoids NameError).
    if train_csv_path is None or test_csv_path is None:
        raise FileNotFoundError('Archive did not contain train.csv and test.csv')
    if vocab is None:
        # Use logging for progress messages, consistent with the sibling
        # dataset-building helpers in this file (was bare print).
        logging.info('Building Vocab based on {}'.format(train_csv_path))
        vocab = build_vocab_from_iterator(_csv_iterator(train_csv_path, ngrams))
    elif not isinstance(vocab, Vocab):
        raise TypeError("Passed vocabulary is not of type Vocab")
    logging.info('Vocab has {} entries'.format(len(vocab)))
    logging.info('Creating training data')
    train_data, train_labels = _create_data_from_iterator(
        vocab, _csv_iterator(train_csv_path, ngrams, yield_cls=True), include_unk)
    logging.info('Creating testing data')
    test_data, test_labels = _create_data_from_iterator(
        vocab, _csv_iterator(test_csv_path, ngrams, yield_cls=True), include_unk)
    # Symmetric difference exposes any label present in only one split.
    if len(train_labels ^ test_labels) > 0:
        raise ValueError("Training and test labels don't match")
    return (TextClassificationDataset(vocab, train_data, train_labels),
            TextClassificationDataset(vocab, test_data, test_labels),
            vocab)
def prepairData(path, ngrams=NGRAMS, vocab=None):
    """Build train/test datasets from ``<path>train.csv`` / ``<path>test.csv``.

    NOTE(review): paths are formed by string concatenation, so *path* must
    end with a separator (e.g. 'data/') — preserved as-is for compatibility.

    Args:
        path: directory prefix containing train.csv and test.csv.
        ngrams: n-gram order passed to the CSV iterator.
        vocab: optional pre-built ``Vocab``; built from training data when None.

    Returns:
        Tuple of (train, test) ``torch_text.TextClassificationDataset``.

    Raises:
        NotADirectoryError: if *path* is not an existing directory.
        TypeError: if *vocab* is given but is not a ``Vocab`` instance.
        ValueError: if the train and test label sets differ.
    """
    # Raise instead of logging and returning None: the old silent return
    # made callers crash later when unpacking the missing result tuple.
    if not os.path.isdir(path):
        logging.error('Data path err')
        raise NotADirectoryError('Data path does not exist: {}'.format(path))
    train_csv_path = path + 'train.csv'
    test_csv_path = path + 'test.csv'
    if vocab is None:
        logging.info('Building Vocab based on {}'.format(train_csv_path))
        vocab = torch_text.build_vocab_from_iterator(
            torch_text._csv_iterator(train_csv_path, ngrams))
    elif not isinstance(vocab, Vocab):
        raise TypeError("Passed vocabulary is not of type Vocab")
    train_data, train_labels = torch_text._create_data_from_iterator(
        vocab, torch_text._csv_iterator(train_csv_path, ngrams, yield_cls=True),
        include_unk=False)
    logging.info('Creating testing data')
    test_data, test_labels = torch_text._create_data_from_iterator(
        vocab, torch_text._csv_iterator(test_csv_path, ngrams, yield_cls=True),
        include_unk=False)
    # Symmetric difference exposes any label present in only one split.
    if len(train_labels ^ test_labels) > 0:
        raise ValueError("Training and test labels don't match")
    return (torch_text.TextClassificationDataset(vocab, train_data, train_labels),
            torch_text.TextClassificationDataset(vocab, test_data, test_labels))
def setup_datasets(train_csv_path, test_csv_path, include_unk=False):
    """Build a vocabulary from the training CSV, then numericalize both splits.

    Args:
        train_csv_path: path to the training CSV; also the vocab source.
        test_csv_path: path to the test CSV.
        include_unk: forwarded to ``_create_data_from_iterator``.

    Returns:
        Tuple of (train, test) ``TextClassificationDataset``.
    """
    vocab = build_vocab_from_iterator(_csv_iterator(train_csv_path, NGRAMS))
    # Both splits are processed identically, so handle them in one loop.
    datasets = []
    for csv_path in (train_csv_path, test_csv_path):
        data, labels = _create_data_from_iterator(
            vocab,
            _csv_iterator(csv_path, NGRAMS, yield_cls=True, label=0),
            include_unk)
        datasets.append(TextClassificationDataset(vocab, data, labels))
    return datasets[0], datasets[1]
def loadData(train_csv_path, test_csv_path, ngrams):
    """Create train/test datasets, building the vocabulary from the train CSV.

    Args:
        train_csv_path: path to the training CSV; also the vocab source.
        test_csv_path: path to the test CSV.
        ngrams: n-gram order passed to the CSV iterator.

    Returns:
        Tuple of (train, test) ``TextClassificationDataset``.
    """
    vocab = build_vocab_from_iterator(_csv_iterator(train_csv_path, ngrams))

    def _build(csv_path):
        # include_unk is fixed to False, matching the original call sites.
        data, labels = _create_data_from_iterator(
            vocab, _csv_iterator(csv_path, ngrams, yield_cls=True), False)
        return TextClassificationDataset(vocab, data, labels)

    return (_build(train_csv_path), _build(test_csv_path))