Code Example #1
def _setup_datasets(dataset_name, tokenizer, root, vocab, split_, year, language):
    if tokenizer is None:
        tokenizer = get_tokenizer('basic_english')

    split = _check_default_set(split_, ('train', 'test', 'valid'), dataset_name)

    # Build the vocabulary from the raw train split when none is supplied.
    if vocab is None:
        if 'train' not in split:
            raise TypeError("Must pass a vocab if train is not selected.")
        # WMTNewsCrawl is served from the experimental raw namespace and
        # additionally needs the year and language arguments.
        if dataset_name == 'WMTNewsCrawl':
            raw_train = experimental_raw.DATASETS[dataset_name](root=root, split='train',
                                                                year=year, language=language)
        else:
            raw_train, = raw.DATASETS[dataset_name](root=root, split=('train',))
        logger_.info('Building Vocab based on train data')
        vocab = build_vocab(raw_train, tokenizer)
    logger_.info('Vocab has %d entries', len(vocab))

    # Numericalize one line of text into a LongTensor of token ids.
    def text_transform(line):
        return torch.tensor([vocab[token] for token in tokenizer(line)], dtype=torch.long)

    if dataset_name == 'WMTNewsCrawl':
        raw_datasets = experimental_raw.DATASETS[dataset_name](root=root, split=split,
                                                               year=year, language=language)
    else:
        raw_datasets = raw.DATASETS[dataset_name](root=root, split=split)
    # Materialize each requested split and numericalize it eagerly.
    raw_data = {name: list(map(text_transform, raw_dataset))
                for name, raw_dataset in zip(split, raw_datasets)}
    logger_.info('Building datasets for {}'.format(split))
    return _wrap_datasets(tuple(LanguageModelingDataset(raw_data[item], vocab, text_transform)
                                for item in split), split_)
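
For context, this helper backs the experimental language modeling datasets. A minimal usage sketch, assuming 'WikiText2' is one of the names registered in raw.DATASETS and that the module-level dependencies (torch, get_tokenizer, raw, logger_, and so on) are in scope; the year and language arguments are only consulted for WMTNewsCrawl:

# Hypothetical invocation; one LanguageModelingDataset per requested split.
train, valid, test = _setup_datasets('WikiText2', tokenizer=None, root='.data',
                                     vocab=None, split_=('train', 'valid', 'test'),
                                     year=None, language=None)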
Code Example #2
def _setup_datasets(dataset_name, root, ngrams, vocab, tokenizer, split_):
    if tokenizer is None:
        tokenizer = get_tokenizer("basic_english")
    # Tokenize, then augment each token list with n-grams.
    text_transform = sequential_transforms(tokenizer, ngrams_func(ngrams))
    split = _check_default_set(split_, ('train', 'test'), dataset_name)
    raw_datasets = raw.DATASETS[dataset_name](root=root, split=split)
    # Materialize the raw text iterable datasets into in-memory lists
    raw_data = {
        name: list(raw_dataset)
        for name, raw_dataset in zip(split, raw_datasets)
    }

    if vocab is None:
        if "train" not in split:
            raise TypeError("Must pass a vocab if train is not selected.")
        logger_.info('Building Vocab based on train data')
        vocab = build_vocab(raw_data["train"], text_transform)
    logger_.info('Vocab has %d entries', len(vocab))
    # Extend the transform: tokens -> vocab indices -> LongTensor.
    text_transform = sequential_transforms(text_transform, vocab_func(vocab),
                                           totensor(dtype=torch.long))
    if dataset_name == 'IMDB':
        # IMDB ships string labels ('pos'/'neg'); map them to integers first.
        label_transform = sequential_transforms(
            lambda x: 1 if x == 'pos' else 0, totensor(dtype=torch.long))
    else:
        label_transform = sequential_transforms(totensor(dtype=torch.long))
    logger_.info('Building datasets for {}'.format(split))
    return _wrap_datasets(
        tuple(
            TextClassificationDataset(raw_data[item], vocab, (label_transform,
                                                              text_transform))
            for item in split), split_)
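
A minimal usage sketch for the classification helper, assuming 'AG_NEWS' is registered in raw.DATASETS with 'train' and 'test' splits:

# Hypothetical invocation; vocab=None triggers vocab construction from train.
train, test = _setup_datasets('AG_NEWS', root='.data', ngrams=2, vocab=None,
                              tokenizer=None, split_=('train', 'test'))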
Code Example #3
File: question_answer.py Project: LongerVision/text
def _setup_datasets(dataset_name, root, vocab, tokenizer, split_):
    if tokenizer is None:
        tokenizer = get_tokenizer('basic_english')
    text_transform = sequential_transforms(tokenizer)
    split = _check_default_set(split_, ('train', 'dev'), dataset_name)
    raw_datasets = raw.DATASETS[dataset_name](root=root, split=split)
    raw_data = {
        name: list(raw_dataset)
        for name, raw_dataset in zip(split, raw_datasets)
    }
    if vocab is None:
        if 'train' not in split:
            raise TypeError("Must pass a vocab if train is not selected.")

        # For vocab building, yield the tokenized context, question, and
        # answers of each example as a single token sequence.
        def apply_transform(data):
            for (_context, _question, _answers, _ans_pos) in data:
                tok_ans = []
                for item in _answers:
                    tok_ans += text_transform(item)
                yield text_transform(_context) + text_transform(
                    _question) + tok_ans

        logger_.info('Building Vocab based on train data')
        vocab = build_vocab_from_iterator(apply_transform(raw_data['train']),
                                          specials=['<unk>', '<pad>'])
        vocab.set_default_index(vocab['<unk>'])
    logger_.info('Vocab has %d entries', len(vocab))
    text_transform = sequential_transforms(text_transform, vocab_func(vocab),
                                           totensor(dtype=torch.long))
    # Answer positions are already integers, so they only need tensor conversion.
    transforms = {
        'context': text_transform,
        'question': text_transform,
        'answers': text_transform,
        'ans_pos': totensor(dtype=torch.long)
    }
    logger_.info('Building datasets for {}'.format(split))
    return _wrap_datasets(
        tuple(
            QuestionAnswerDataset(raw_data[item], vocab, transforms)
            for item in split), split_)
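
A minimal usage sketch, assuming 'SQuAD1' is registered in raw.DATASETS with 'train' and 'dev' splits:

# Hypothetical invocation; the vocab is built from the train contexts,
# questions, and answers as shown above.
train, dev = _setup_datasets('SQuAD1', root='.data', vocab=None,
                             tokenizer=None, split_=('train', 'dev'))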
Code Example #4
def _setup_datasets(dataset_name, root, vocabs, split_):
    split = _check_default_set(split_, ('train', 'valid', 'test'),
                               dataset_name)
    raw_iter_tuple = raw.DATASETS[dataset_name](root=root, split=split)
    raw_data = {}
    for name, raw_iter in zip(split, raw_iter_tuple):
        raw_data[name] = list(raw_iter)

    if vocabs is None:
        if "train" not in split:
            raise TypeError("Must pass a vocab if train is not selected.")
        logger_.info('Building Vocab based on train data')
        vocabs = build_vocab(raw_data["train"])
    else:
        if not isinstance(vocabs, list):
            raise TypeError("vocabs must be an instance of list")

        # Use the first materialized split to infer how many columns
        # (and therefore how many vocabs) the data has.
        notnone_data = None
        for key in raw_data.keys():
            if raw_data[key] is not None:
                notnone_data = raw_data[key]
                break
        if len(vocabs) != len(notnone_data[0]):
            raise ValueError(
                "Number of vocabs must match the number of columns "
                "in the data")

    # One numericalizing transform (vocab lookup -> LongTensor) per column.
    transformers = [
        sequential_transforms(vocab_func(vocabs[idx]),
                              totensor(dtype=torch.long))
        for idx in range(len(vocabs))
    ]
    logger_.info('Building datasets for {}'.format(split))
    return _wrap_datasets(
        tuple(
            SequenceTaggingDataset(raw_data[item], vocabs, transformers)
            for item in split), split_)
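
A minimal usage sketch, assuming 'UDPOS' is registered in raw.DATASETS and that the local build_vocab helper returns one vocab per data column:

# Hypothetical invocation; vocabs=None builds the vocab list from train.
train, valid, test = _setup_datasets('UDPOS', root='.data', vocabs=None,
                                     split_=('train', 'valid', 'test'))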
Code Example #5
File: translation.py Project: LongerVision/text
def _setup_datasets(dataset_name, split_, root, vocab, tokenizer, **kwargs):
    split = _check_default_set(split_, ('train', 'valid', 'test'),
                               dataset_name)
    src_vocab, tgt_vocab = vocab
    if tokenizer is None:
        # Default to spaCy tokenizers for the German-to-English pair.
        src_tokenizer = get_tokenizer("spacy", language='de_core_news_sm')
        tgt_tokenizer = get_tokenizer("spacy", language='en_core_web_sm')
    elif isinstance(tokenizer, tuple):
        if len(tokenizer) == 2:
            src_tokenizer, tgt_tokenizer = tokenizer
        else:
            raise ValueError("tokenizer must have length of two for"
                             "source and target")
    else:
        raise ValueError(
            "tokenizer must be an instance of tuple with length two "
            "or None")

    # WMT14 is served from the experimental raw namespace.
    if dataset_name == 'WMT14':
        raw_datasets = experimental_raw.DATASETS[dataset_name](split=split,
                                                               root=root,
                                                               **kwargs)
    else:
        raw_datasets = raw.DATASETS[dataset_name](split=split,
                                                  root=root,
                                                  **kwargs)
    raw_data = {
        name: list(raw_dataset)
        for name, raw_dataset in zip(split, raw_datasets)
    }
    src_text_vocab_transform = sequential_transforms(src_tokenizer)
    tgt_text_vocab_transform = sequential_transforms(tgt_tokenizer)

    if src_vocab is None:
        if 'train' not in split:
            raise TypeError("Must pass a vocab if train is not selected.")
        logger_.info('Building src Vocab based on train data')
        # index=0 selects the source side of each (src, tgt) pair.
        src_vocab = build_vocab(raw_data["train"],
                                src_text_vocab_transform,
                                index=0)
    else:
        if not isinstance(src_vocab, Vocab):
            raise TypeError("Passed src vocabulary is not of type Vocab")
    logger_.info('src Vocab has %d entries', len(src_vocab))

    if tgt_vocab is None:
        if 'train' not in split:
            raise TypeError("Must pass a vocab if train is not selected.")
        logger_.info('Building tgt Vocab based on train data')
        # index=1 selects the target side of each (src, tgt) pair.
        tgt_vocab = build_vocab(raw_data["train"],
                                tgt_text_vocab_transform,
                                index=1)
    else:
        if not isinstance(tgt_vocab, Vocab):
            raise TypeError("Passed tgt vocabulary is not of type Vocab")
    logger_.info('tgt Vocab has %d entries', len(tgt_vocab))

    logger_.info('Building datasets for {}'.format(split))
    datasets = []
    for key in split:
        src_text_transform = sequential_transforms(src_text_vocab_transform,
                                                   vocab_func(src_vocab),
                                                   totensor(dtype=torch.long))
        tgt_text_transform = sequential_transforms(tgt_text_vocab_transform,
                                                   vocab_func(tgt_vocab),
                                                   totensor(dtype=torch.long))
        datasets.append(
            TranslationDataset(raw_data[key], (src_vocab, tgt_vocab),
                               (src_text_transform, tgt_text_transform)))

    return _wrap_datasets(tuple(datasets), split_)
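
A minimal usage sketch, assuming 'Multi30k' is registered in raw.DATASETS. The vocab argument must be a (src_vocab, tgt_vocab) pair; passing (None, None) builds both from the train split with the default German/English spaCy tokenizers:

# Hypothetical invocation; requires the de_core_news_sm and en_core_web_sm
# spaCy models to be installed for the default tokenizers.
train, valid, test = _setup_datasets('Multi30k', split_=('train', 'valid', 'test'),
                                     root='.data', vocab=(None, None),
                                     tokenizer=None)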