示例#1
0
def _setup_datasets(root='.data', ngrams=1, vocab=None, include_unk=False):
    file_list = os.listdir(root)

    for fname in file_list:
        if fname.endswith('DSL-TRAIN.txt'):
            train_csv_path = os.path.join(root, fname)
        if fname.endswith('DSL-TEST-GOLD.txt'):
            test_csv_path = os.path.join(root, fname)

    if vocab is None:
        logging.info('Building Vocab based on {}'.format(train_csv_path))
        vocab = build_vocab_from_iterator(_csv_iterator(
            train_csv_path, ngrams))
    else:
        if not isinstance(vocab, Vocab):
            raise TypeError("Passed vocabulary is not of type Vocab")
    logging.info('Vocab has {} entries'.format(len(vocab)))
    logging.info('Creating training data')
    train_data, train_labels = _create_data_from_iterator(
        vocab, _csv_iterator(train_csv_path, ngrams, yield_cls=True),
        include_unk)
    logging.info('Creating testing data')
    test_data, test_labels = _create_data_from_iterator(
        vocab, _csv_iterator(test_csv_path, ngrams, yield_cls=True),
        include_unk)
    if len(train_labels ^ test_labels) > 0:
        raise ValueError("Training and test labels don't match")
    return (TextClassificationDataset(vocab, train_data, train_labels),
            TextClassificationDataset(vocab, test_data, test_labels))
示例#2
0
def make_data(path_root='../data/ag_news_csv.tgz',
              ngrams=2,
              vocab=None,
              include_unk=False):
    """Extract a CSV archive and build train/test ``TextClassificationDataset`` objects.

    Args:
        path_root: Path of the archive handed to ``extract_archive``; it must
            contain files ending with ``'train.csv'`` and ``'test.csv'``.
        ngrams: n-gram order forwarded to ``_csv_iterator``.
        vocab: Optional pre-built vocabulary. When ``None``, one is built
            from the training CSV.
        include_unk: Forwarded to ``_create_data_from_iterator``.

    Returns:
        A ``(train_dataset, test_dataset)`` tuple sharing the same vocab.

    Raises:
        FileNotFoundError: If the extracted files lack a train or test CSV.
        ValueError: If the train and test label sets differ.
    """
    extracted_files = extract_archive(path_root)

    train_csv_path = None
    test_csv_path = None
    for fname in extracted_files:
        if fname.endswith('train.csv'):
            train_csv_path = fname
        if fname.endswith('test.csv'):
            test_csv_path = fname

    # Fail with a clear message instead of an UnboundLocalError further down.
    if train_csv_path is None:
        raise FileNotFoundError(
            "No 'train.csv' found in archive {}".format(path_root))
    if test_csv_path is None:
        raise FileNotFoundError(
            "No 'test.csv' found in archive {}".format(path_root))

    # Honour a caller-supplied vocabulary; only build one when none is given.
    if vocab is None:
        logging.info('Building Vocab based on {}'.format(train_csv_path))
        vocab = build_vocab_from_iterator(
            _csv_iterator(train_csv_path, ngrams))

    logging.info('Vocab has {} entries'.format(len(vocab)))
    logging.info('Creating training data')

    train_data, train_labels = _create_data_from_iterator(
        vocab, _csv_iterator(train_csv_path, ngrams, yield_cls=True),
        include_unk)
    logging.info('Creating testing data')
    test_data, test_labels = _create_data_from_iterator(
        vocab, _csv_iterator(test_csv_path, ngrams, yield_cls=True),
        include_unk)
    # Any label present in only one split is treated as an error.
    if len(train_labels ^ test_labels) > 0:
        raise ValueError("Training and test labels don't match")
    return (TextClassificationDataset(vocab, train_data, train_labels),
            TextClassificationDataset(vocab, test_data, test_labels))
示例#3
0
def _setup_datasets(dataset_name,
                    root='.data',
                    ngrams=1,
                    vocab=None,
                    include_unk=False):
    """Build train/test datasets from the AG_NEWS CSV archive stored in ``root``.

    No download is performed: ``<root>/ag_news_csv.tar.gz`` must already
    exist.

    Args:
        dataset_name: Currently unused.
            NOTE(review): the archive name is AG_NEWS-specific; confirm
            whether other dataset names are meant to be supported.
        root: Directory holding ``ag_news_csv.tar.gz``.
        ngrams: n-gram order forwarded to ``_csv_iterator``.
        vocab: Optional pre-built ``Vocab``. When ``None``, one is built from
            the training CSV.
        include_unk: Forwarded to ``_create_data_from_iterator``.

    Returns:
        A ``(train_dataset, test_dataset, vocab)`` tuple.

    Raises:
        FileNotFoundError: If the extracted files lack a train or test CSV.
        TypeError: If ``vocab`` is given but is not a ``Vocab``.
        ValueError: If the train and test label sets differ.
    """
    import os

    # Resolve the archive relative to ``root``; the default ``root`` yields
    # the previous hard-coded path '.data/ag_news_csv.tar.gz'.
    extracted_files = extract_archive(os.path.join(root, 'ag_news_csv.tar.gz'))

    train_csv_path = None
    test_csv_path = None
    for fname in extracted_files:
        if fname.endswith('train.csv'):
            train_csv_path = fname
        if fname.endswith('test.csv'):
            test_csv_path = fname

    # Fail with a clear message instead of an UnboundLocalError further down.
    if train_csv_path is None:
        raise FileNotFoundError("No 'train.csv' found under {}".format(root))
    if test_csv_path is None:
        raise FileNotFoundError("No 'test.csv' found under {}".format(root))

    if vocab is None:
        print('Building Vocab based on {}'.format(train_csv_path))
        vocab = build_vocab_from_iterator(_csv_iterator(
            train_csv_path, ngrams))
    else:
        if not isinstance(vocab, Vocab):
            raise TypeError("Passed vocabulary is not of type Vocab")
    print('Vocab has {} entries'.format(len(vocab)))
    print('Creating training data')
    train_data, train_labels = _create_data_from_iterator(
        vocab, _csv_iterator(train_csv_path, ngrams, yield_cls=True),
        include_unk)
    print('Creating testing data')
    test_data, test_labels = _create_data_from_iterator(
        vocab, _csv_iterator(test_csv_path, ngrams, yield_cls=True),
        include_unk)
    # Any label present in only one split is treated as an error.
    if len(train_labels ^ test_labels) > 0:
        raise ValueError("Training and test labels don't match")
    return (TextClassificationDataset(vocab, train_data, train_labels),
            TextClassificationDataset(vocab, test_data, test_labels), vocab)
示例#4
0
def prepairData(path, ngrams=NGRAMS, vocab=None):
    """Build train/test datasets from ``train.csv`` and ``test.csv`` inside ``path``.

    Args:
        path: Directory containing ``train.csv`` and ``test.csv``. A trailing
            path separator is optional.
        ngrams: n-gram order forwarded to ``torch_text._csv_iterator``.
        vocab: Optional pre-built ``Vocab``. When ``None``, one is built from
            ``train.csv``.

    Returns:
        A ``(train_dataset, test_dataset)`` tuple. Returns ``None`` after
        logging an error if ``path`` is not a directory.

    Raises:
        TypeError: If ``vocab`` is given but is not a ``Vocab``.
        ValueError: If the train and test label sets differ.
    """
    if not os.path.isdir(path):
        logging.error('Data path err')
        return

    # os.path.join works whether or not ``path`` ends with a separator;
    # plain concatenation produced e.g. 'datatrain.csv' for path='data'.
    train_csv_path = os.path.join(path, 'train.csv')
    test_csv_path = os.path.join(path, 'test.csv')

    if vocab is None:
        logging.info('Building Vocab based on {}'.format(train_csv_path))
        vocab = torch_text.build_vocab_from_iterator(
            torch_text._csv_iterator(train_csv_path, ngrams))
    else:
        if not isinstance(vocab, Vocab):
            raise TypeError("Passed vocabulary is not of type Vocab")

    logging.info('Creating training data')
    train_data, train_labels = torch_text._create_data_from_iterator(
        vocab,
        torch_text._csv_iterator(train_csv_path, ngrams, yield_cls=True),
        include_unk=False)
    logging.info('Creating testing data')
    test_data, test_labels = torch_text._create_data_from_iterator(
        vocab,
        torch_text._csv_iterator(test_csv_path, ngrams, yield_cls=True),
        include_unk=False)
    # Any label present in only one split is treated as an error.
    if len(train_labels ^ test_labels) > 0:
        raise ValueError("Training and test labels don't match")
    return (torch_text.TextClassificationDataset(vocab, train_data,
                                                 train_labels),
            torch_text.TextClassificationDataset(vocab, test_data,
                                                 test_labels))
示例#5
0
def setup_datasets(train_csv_path, test_csv_path, include_unk=False):
    """Build train/test ``TextClassificationDataset`` objects from two CSV files.

    The vocabulary is built from the training CSV only, using the
    module-level ``NGRAMS`` n-gram order, and is then shared by both splits.

    Args:
        train_csv_path: Path of the training CSV.
        test_csv_path: Path of the test CSV.
        include_unk: Forwarded to ``_create_data_from_iterator``.

    Returns:
        A ``(train_dataset, test_dataset)`` tuple.
    """
    shared_vocab = build_vocab_from_iterator(
        _csv_iterator(train_csv_path, NGRAMS))

    def _encode(csv_path):
        # Numericalize one CSV with the shared vocab. ``label=0`` is passed
        # through to this project's ``_csv_iterator``; its meaning is defined
        # there.
        rows = _csv_iterator(csv_path, NGRAMS, yield_cls=True, label=0)
        return _create_data_from_iterator(shared_vocab, rows, include_unk)

    train_data, train_labels = _encode(train_csv_path)
    test_data, test_labels = _encode(test_csv_path)

    train_set = TextClassificationDataset(shared_vocab, train_data, train_labels)
    test_set = TextClassificationDataset(shared_vocab, test_data, test_labels)
    return train_set, test_set
示例#6
0
def loadData(train_csv_path, test_csv_path, ngrams, include_unk=False):
    """Build train/test ``TextClassificationDataset`` objects from two CSV files.

    The vocabulary is built from the training CSV only and shared by both
    splits.

    Args:
        train_csv_path: Path of the training CSV.
        test_csv_path: Path of the test CSV.
        ngrams: n-gram order forwarded to ``_csv_iterator``.
        include_unk: Forwarded to ``_create_data_from_iterator``. Defaults to
            ``False``, which matches the previously hard-coded value.

    Returns:
        A ``(train_dataset, test_dataset)`` tuple.
    """
    vocab = build_vocab_from_iterator(_csv_iterator(train_csv_path, ngrams))

    train_data, train_labels = _create_data_from_iterator(
        vocab, _csv_iterator(train_csv_path, ngrams, yield_cls=True),
        include_unk)
    test_data, test_labels = _create_data_from_iterator(
        vocab, _csv_iterator(test_csv_path, ngrams, yield_cls=True),
        include_unk)

    return (TextClassificationDataset(vocab, train_data, train_labels),
            TextClassificationDataset(vocab, test_data, test_labels))