Example no. 1
def SQuAD1(root, split):
    extracted_files = download_from_url(URL[split],
                                        root=root,
                                        hash_value=MD5[split],
                                        hash_type='md5')
    return RawTextIterableDataset('SQuAD1', NUM_LINES[split],
                                  _create_data_from_json(extracted_files))
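
A minimal usage sketch for the snippet above. It assumes the module-level URL, MD5, and NUM_LINES constants and the download_from_url / _create_data_from_json helpers from torchtext's experimental raw-dataset module are already in scope; the per-record field layout is not shown in the snippet, so the sketch only inspects one record.

# Hypothetical usage, assuming the snippet's module-level constants and helpers.
train_iter = SQuAD1(root='.data', split='train')
record = next(iter(train_iter))
# The record is assumed to bundle context, question, answers and answer offsets,
# as suggested by _create_data_from_json; verify against the installed version.
print(record)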
Example no. 2
def WikiText103(root, split):
    dataset_tar = download_from_url(URL, root=root, hash_value=MD5, hash_type='md5')
    extracted_files = extract_archive(dataset_tar)

    path = find_match(split, extracted_files)
    logging.info('Creating {} data'.format(split))
    return RawTextIterableDataset('WikiText103',
                                  NUM_LINES[split], iter(io.open(path, encoding="utf8")))
Example no. 3
def PennTreebank(root, split):
    path = download_from_url(URL[split],
                             root=root,
                             hash_value=MD5[split],
                             hash_type='md5')
    logging.info('Creating {} data'.format(split))
    return RawTextIterableDataset('PennTreebank', NUM_LINES[split],
                                  iter(io.open(path, encoding="utf8")))
Example no. 4
def CoNLL2000Chunking(root, split):
    # Create a dataset specific subfolder to deal with generic download filenames
    root = os.path.join(root, 'conll2000chunking')
    path = os.path.join(root, split + ".txt.gz")
    data_filename = download_extract_validate(root, URL[split], MD5[split], path, os.path.join(root, _EXTRACTED_FILES[split]),
                                              _EXTRACTED_FILES_MD5[split], hash_type="md5")
    logging.info('Creating {} data'.format(split))
    return RawTextIterableDataset("CoNLL2000Chunking", NUM_LINES[split],
                                  _create_data_from_iob(data_filename, " "))
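
A usage sketch, assuming the surrounding module's constants and the download_extract_validate / _create_data_from_iob helpers are in scope; the exact shape of each IOB record depends on that helper, so the sketch only inspects the first item.

# Hypothetical usage; the record layout (e.g. tokens, POS tags, chunk tags)
# is determined by _create_data_from_iob and should be inspected, not assumed.
train_iter = CoNLL2000Chunking(root='.data', split='train')
first_record = next(iter(train_iter))
print(first_record)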
Example no. 5
def UDPOS(root, split):
    dataset_tar = download_from_url(URL, root=root, hash_value=MD5, hash_type='md5')
    extracted_files = extract_archive(dataset_tar)
    if split == 'valid':
        path = find_match("dev.txt", extracted_files)
    else:
        path = find_match(split + ".txt", extracted_files)
    return RawTextIterableDataset("UDPOS", NUM_LINES[split],
                                  _create_data_from_iob(path))
Example no. 6
def AG_NEWS(root, split):
    def _create_data_from_csv(data_path):
        with io.open(data_path, encoding="utf8") as f:
            reader = unicode_csv_reader(f)
            for row in reader:
                yield int(row[0]), ' '.join(row[1:])

    path = download_from_url(URL[split],
                             root=root,
                             path=os.path.join(root, split + ".csv"),
                             hash_value=MD5[split],
                             hash_type='md5')
    return RawTextIterableDataset("AG_NEWS", NUM_LINES[split],
                                  _create_data_from_csv(path))
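
A usage sketch: per the nested _create_data_from_csv generator, every record is an (integer label, text) pair. The call below assumes the module-level URL, MD5, and NUM_LINES constants are in scope, as in the snippet.

# Hypothetical usage, reusing the (label, text) layout yielded above.
test_iter = AG_NEWS(root='.data', split='test')
label, text = next(iter(test_iter))   # AG_NEWS labels are the integers 1-4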
Example no. 7
def YelpReviewFull(root, split):
    def _create_data_from_csv(data_path):
        with io.open(data_path, encoding="utf8") as f:
            reader = unicode_csv_reader(f)
            for row in reader:
                yield int(row[0]), ' '.join(row[1:])

    dataset_tar = download_from_url(URL, root=root,
                                    path=os.path.join(root, _PATH),
                                    hash_value=MD5, hash_type='md5')
    extracted_files = extract_archive(dataset_tar)

    path = find_match(split + '.csv', extracted_files)
    return RawTextIterableDataset("YelpReviewFull", NUM_LINES[split],
                                  _create_data_from_csv(path))
Example no. 8
def IMDB(root, split):
    def generate_imdb_data(key, extracted_files):
        for fname in extracted_files:
            if 'urls' in fname:
                continue
            elif key in fname and ('pos' in fname or 'neg' in fname):
                with io.open(fname, encoding="utf8") as f:
                    label = 'pos' if 'pos' in fname else 'neg'
                    yield label, f.read()

    dataset_tar = download_from_url(URL,
                                    root=root,
                                    hash_value=MD5,
                                    hash_type='md5')
    extracted_files = extract_archive(dataset_tar)
    iterator = generate_imdb_data(split, extracted_files)
    return RawTextIterableDataset("IMDB", NUM_LINES[split], iterator)
Example no. 9
def AmazonReviewFull(root, split):
    def _create_data_from_csv(data_path):
        with io.open(data_path, encoding="utf8") as f:
            reader = unicode_csv_reader(f)
            for row in reader:
                yield int(row[0]), ' '.join(row[1:])

    path = download_extract_validate(root,
                                     URL,
                                     MD5,
                                     os.path.join(root, _PATH),
                                     os.path.join(root,
                                                  _EXTRACTED_FILES[split]),
                                     _EXTRACTED_FILES_MD5[split],
                                     hash_type="md5")
    logging.info('Creating {} data'.format(split))
    return RawTextIterableDataset("AmazonReviewFull", NUM_LINES[split],
                                  _create_data_from_csv(path))
Example no. 10
def WMTNewsCrawl(root, split, year=2010, language='en'):
    if year not in _AVAILABLE_YEARS:
        raise ValueError(
            "{} not available. Please choose from years {}".format(
                year, _AVAILABLE_YEARS))
    if language not in _AVAILABLE_LANGUAGES:
        raise ValueError(
            "{} not available. Please choose from languages {}".format(
                language, _AVAILABLE_LANGUAGES))
    path = download_extract_validate(root,
                                     URL,
                                     MD5,
                                     _PATH,
                                     _EXTRACTED_FILES[language],
                                     _EXTRACTED_FILES_MD5[language],
                                     hash_type="md5")
    logging.info('Creating {} data'.format(split))
    return RawTextIterableDataset("WMTNewsCrawl", NUM_LINES[split],
                                  iter(io.open(path, encoding="utf8")))
Example no. 11
def IWSLT2017(root='.data', split=('train', 'valid', 'test'), language_pair=('de', 'en')):
    """IWSLT2017 dataset

    The available datasets include the following:

    **Language pairs**:

    +-----+-----+-----+-----+-----+-----+
    |     |'en' |'nl' |'de' |'it' |'ro' |
    +-----+-----+-----+-----+-----+-----+
    |'en' |     |   x |  x  |  x  |  x  |
    +-----+-----+-----+-----+-----+-----+
    |'nl' |  x  |     |  x  |  x  |  x  |
    +-----+-----+-----+-----+-----+-----+
    |'de' |  x  |   x |     |  x  |  x  |
    +-----+-----+-----+-----+-----+-----+
    |'it' |  x  |   x |  x  |     |  x  |
    +-----+-----+-----+-----+-----+-----+
    |'ro' |  x  |   x |  x  |  x  |     |
    +-----+-----+-----+-----+-----+-----+


    For additional details refer to the source website: https://wit3.fbk.eu/2017-01

    Args:
        root: Directory where the datasets are saved. Default: ".data"
        split: split or splits to be returned. Can be a string or tuple of strings. Default: ('train', 'valid', 'test')
        language_pair: tuple or list containing src and tgt language

    Examples:
        >>> from torchtext.datasets import IWSLT2017
        >>> train_iter, valid_iter, test_iter = IWSLT2017()
        >>> src_sentence, tgt_sentence = next(train_iter)

    """

    valid_set = 'dev2010'
    test_set = 'tst2010'

    num_lines_set_identifier = {
        'train': 'train',
        'valid': valid_set,
        'test': test_set
    }

    if not isinstance(language_pair, list) and not isinstance(language_pair, tuple):
        raise ValueError("language_pair must be list or tuple but got {} instead".format(type(language_pair)))

    assert (len(language_pair) == 2), 'language_pair must contain only 2 elements: src and tgt language respectively'

    src_language, tgt_language = language_pair[0], language_pair[1]

    if src_language not in SUPPORTED_DATASETS['language_pair']:
        raise ValueError("src_language '{}' is not valid. Supported source languages are {}".
                         format(src_language, list(SUPPORTED_DATASETS['language_pair'])))

    if tgt_language not in SUPPORTED_DATASETS['language_pair'][src_language]:
        raise ValueError("tgt_language '{}' is not valid for give src_language '{}'. Supported target language are {}".
                         format(tgt_language, src_language, SUPPORTED_DATASETS['language_pair'][src_language]))

    train_filenames = ('train.{}-{}.{}'.format(src_language, tgt_language, src_language),
                       'train.{}-{}.{}'.format(src_language, tgt_language, tgt_language))
    valid_filenames = ('IWSLT{}.TED.{}.{}-{}.{}'.format(SUPPORTED_DATASETS['year'], valid_set, src_language, tgt_language, src_language),
                       'IWSLT{}.TED.{}.{}-{}.{}'.format(SUPPORTED_DATASETS['year'], valid_set, src_language, tgt_language, tgt_language))
    test_filenames = ('IWSLT{}.TED.{}.{}-{}.{}'.format(SUPPORTED_DATASETS['year'], test_set, src_language, tgt_language, src_language),
                      'IWSLT{}.TED.{}.{}-{}.{}'.format(SUPPORTED_DATASETS['year'], test_set, src_language, tgt_language, tgt_language))

    src_train, tgt_train = train_filenames
    src_eval, tgt_eval = valid_filenames
    src_test, tgt_test = test_filenames

    extracted_files = []  # list of paths to the extracted files
    dataset_tar = download_from_url(SUPPORTED_DATASETS['URL'], root=root, hash_value=SUPPORTED_DATASETS['MD5'], path=os.path.join(root, SUPPORTED_DATASETS['_PATH']), hash_type='md5')
    extracted_dataset_tar = extract_archive(dataset_tar)
    # IWSLT dataset's url downloads a multilingual tgz.
    # We need to take an extra step to pick out the specific language pair from it.
    src_language = train_filenames[0].split(".")[-1]
    tgt_language = train_filenames[1].split(".")[-1]

    iwslt_tar = os.path.join(root, SUPPORTED_DATASETS['_PATH'].split(".")[0], 'texts/DeEnItNlRo/DeEnItNlRo', 'DeEnItNlRo-DeEnItNlRo.tgz')
    extracted_dataset_tar = extract_archive(iwslt_tar)
    extracted_files.extend(extracted_dataset_tar)

    # Clean the xml and tag file in the archives
    file_archives = []
    for fname in extracted_files:
        if 'xml' in fname:
            _clean_xml_file(fname)
            file_archives.append(os.path.splitext(fname)[0])
        elif "tags" in fname:
            _clean_tags_file(fname)
            file_archives.append(fname.replace('.tags', ''))
        else:
            file_archives.append(fname)

    data_filenames = {
        "train": _construct_filepaths(file_archives, src_train, tgt_train),
        "valid": _construct_filepaths(file_archives, src_eval, tgt_eval),
        "test": _construct_filepaths(file_archives, src_test, tgt_test)
    }
    for key in data_filenames:
        if data_filenames[key] is None or len(data_filenames[key]) == 0:
            raise FileNotFoundError(
                "Files are not found for data type {}".format(key))

    src_data_iter = _read_text_iterator(data_filenames[split][0])
    tgt_data_iter = _read_text_iterator(data_filenames[split][1])

    def _iter(src_data_iter, tgt_data_iter):
        for item in zip(src_data_iter, tgt_data_iter):
            yield item

    return RawTextIterableDataset("IWSLT2017", NUM_LINES[split][num_lines_set_identifier[split]][tuple(sorted(language_pair))], _iter(src_data_iter, tgt_data_iter))
Example no. 12
def Multi30k(root,
             split,
             task='task1',
             language_pair=('de', 'en'),
             train_set="train",
             valid_set="val",
             test_set="test_2016_flickr"):
    """Multi30k Dataset

    The available datasets include the following:

    **Language pairs (task1)**:

    +-----+-----+-----+-----+-----+
    |     |'en' |'cs' |'de' |'fr' |
    +-----+-----+-----+-----+-----+
    |'en' |     |   x |  x  |  x  |
    +-----+-----+-----+-----+-----+
    |'cs' |  x  |     |  x  |  x  |
    +-----+-----+-----+-----+-----+
    |'de' |  x  |   x |     |  x  |
    +-----+-----+-----+-----+-----+
    |'fr' |  x  |   x |  x  |     |
    +-----+-----+-----+-----+-----+

    **Language pairs (task2)**:

    +-----+-----+-----+
    |     |'en' |'de' |
    +-----+-----+-----+
    |'en' |     |   x |
    +-----+-----+-----+
    |'de' |  x  |     |
    +-----+-----+-----+

    For additional details refer to the source: https://github.com/multi30k/dataset

    Args:
        root: Directory where the datasets are saved. Default: ".data"
        split: split or splits to be returned. Can be a string or tuple of strings. Default: ('train', 'valid', 'test')
        task: Indicates the task ('task1' or 'task2'). Default: 'task1'
        language_pair: tuple or list containing src and tgt language
        train_set: A string to identify train set.
        valid_set: A string to identify validation set.
        test_set: A string to identify test set.

    Examples:
        >>> from torchtext.experimental.datasets.raw import Multi30k
        >>> train_iter, valid_iter, test_iter = Multi30k()
        >>> src_sentence, tgt_sentence = next(train_iter)
    """

    if task not in SUPPORTED_DATASETS.keys():
        raise ValueError(
            'task {} is not supported. Valid options are {}'.format(
                task, SUPPORTED_DATASETS.keys()))

    assert (
        len(language_pair) == 2
    ), 'language_pair must contain only 2 elements: src and tgt language respectively'

    if language_pair[0] not in SUPPORTED_DATASETS[task].keys():
        raise ValueError(
            "Source language '{}' is not supported. Valid options for task '{}' are {}"
            .format(language_pair[0], task,
                    list(SUPPORTED_DATASETS[task].keys())))

    if language_pair[1] not in SUPPORTED_DATASETS[task].keys():
        raise ValueError(
            "Target language '{}' is not supported. Valid options for task '{}' are {}"
            .format(language_pair[1], task,
                    list(SUPPORTED_DATASETS[task].keys())))

    if train_set not in SUPPORTED_DATASETS[task][
            language_pair[0]].keys() or 'train' not in train_set:
        raise ValueError(
            "'{}' is not a valid train set identifier. valid options for task '{}' and language pair {} are {}"
            .format(train_set, task, language_pair, [
                k for k in SUPPORTED_DATASETS[task][language_pair[0]].keys()
                if 'train' in k
            ]))

    if valid_set not in SUPPORTED_DATASETS[task][
            language_pair[0]].keys() or 'val' not in valid_set:
        raise ValueError(
            "'{}' is not a valid valid set identifier. valid options for task '{}' and language pair {} are {}"
            .format(valid_set, task, language_pair, [
                k for k in SUPPORTED_DATASETS[task][language_pair[0]].keys()
                if 'val' in k
            ]))

    if test_set not in SUPPORTED_DATASETS[task][
            language_pair[0]].keys() or 'test' not in test_set:
        raise ValueError(
            "'{}' is not a valid test set identifier. valid options for task '{}' and language pair {} are {}"
            .format(test_set, task, language_pair, [
                k for k in SUPPORTED_DATASETS[task][language_pair[0]].keys()
                if 'test' in k
            ]))

    train_filenames = [
        "{}.{}".format(train_set, language_pair[0]),
        "{}.{}".format(train_set, language_pair[1])
    ]
    valid_filenames = [
        "{}.{}".format(valid_set, language_pair[0]),
        "{}.{}".format(valid_set, language_pair[1])
    ]
    test_filenames = [
        "{}.{}".format(test_set, language_pair[0]),
        "{}.{}".format(test_set, language_pair[1])
    ]

    if split == 'train':
        src_file, tgt_file = train_filenames
    elif split == 'valid':
        src_file, tgt_file = valid_filenames
    else:
        src_file, tgt_file = test_filenames

    extracted_files = []  # list of paths to the extracted files

    current_url = []
    current_md5 = []

    current_filenames = [src_file, tgt_file]
    for url, md5 in zip(URL[split], MD5[split]):
        if any(f in url for f in current_filenames):
            current_url.append(url)
            current_md5.append(md5)

    for url, md5 in zip(current_url, current_md5):
        dataset_tar = download_from_url(url,
                                        path=os.path.join(
                                            root, os.path.basename(url)),
                                        root=root,
                                        hash_value=md5,
                                        hash_type='md5')
        extracted_files.extend(extract_archive(dataset_tar))

    file_archives = extracted_files

    data_filenames = {
        split: _construct_filepaths(file_archives, src_file, tgt_file),
    }

    for key in data_filenames:
        if data_filenames[key] is None or len(data_filenames[key]) == 0:
            raise FileNotFoundError(
                "Files are not found for data type {}".format(key))

    assert data_filenames[split][
        0] is not None, "Internal Error: File not found for reading"
    assert data_filenames[split][
        1] is not None, "Internal Error: File not found for reading"
    src_data_iter = _read_text_iterator(data_filenames[split][0])
    tgt_data_iter = _read_text_iterator(data_filenames[split][1])

    def _iter(src_data_iter, tgt_data_iter):
        for item in zip(src_data_iter, tgt_data_iter):
            yield item

    set_identifier = {
        'train': train_set,
        'valid': valid_set,
        'test': test_set,
    }

    return RawTextIterableDataset(
        "Multi30k", SUPPORTED_DATASETS[task][language_pair[0]][
            set_identifier[split]]['NUM_LINES'],
        _iter(src_data_iter, tgt_data_iter))
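
As with IWSLT2017, the body treats split as a single key into URL, MD5, and data_filenames, so the sketch below calls it once per split; any tuple handling is assumed to live in a wrapper.

# Hypothetical usage for the default task1 German-English pair.
train_iter = Multi30k(root='.data', split='train', task='task1',
                      language_pair=('de', 'en'))
de_sentence, en_sentence = next(iter(train_iter))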
Example no. 13
def WMT14(root,
          split,
          language_pair=('de', 'en'),
          train_set='train.tok.clean.bpe.32000',
          valid_set='newstest2013.tok.bpe.32000',
          test_set='newstest2014.tok.bpe.32000'):
    """WMT14 Dataset

    The available datasets include the following:

    **Language pairs**:

    +-----+-----+-----+
    |     |'en' |'de' |
    +-----+-----+-----+
    |'en' |     |   x |
    +-----+-----+-----+
    |'de' |  x  |     |
    +-----+-----+-----+


    Args:
        root: Directory where the datasets are saved. Default: ".data"
        split: split or splits to be returned. Can be a string or tuple of strings. Default: ('train', 'valid', 'test')
        language_pair: tuple or list containing src and tgt language
        train_set: A string to identify train set.
        valid_set: A string to identify validation set.
        test_set: A string to identify test set.

    Examples:
        >>> from torchtext.datasets import WMT14
        >>> train_iter, valid_iter, test_iter = WMT14()
        >>> src_sentence, tgt_sentence = next(train_iter)
    """

    supported_language = ['en', 'de']
    supported_train_set = [s for s in NUM_LINES if 'train' in s]
    supported_valid_set = [s for s in NUM_LINES if 'test' in s]
    supported_test_set = [s for s in NUM_LINES if 'test' in s]

    assert (
        len(language_pair) == 2
    ), 'language_pair must contain only 2 elements: src and tgt language respectively'

    if language_pair[0] not in supported_language:
        raise ValueError(
            "Source language '{}' is not supported. Valid options are {}".
            format(language_pair[0], supported_language))

    if language_pair[1] not in supported_language:
        raise ValueError(
            "Target language '{}' is not supported. Valid options are {}".
            format(language_pair[1], supported_language))

    if train_set not in supported_train_set:
        raise ValueError(
            "'{}' is not a valid train set identifier. valid options are {}".
            format(train_set, supported_train_set))

    if valid_set not in supported_valid_set:
        raise ValueError(
            "'{}' is not a valid valid set identifier. valid options are {}".
            format(valid_set, supported_valid_set))

    if test_set not in supported_test_set:
        raise ValueError(
            "'{}' is not a valid valid set identifier. valid options are {}".
            format(test_set, supported_test_set))

    train_filenames = '{}.{}'.format(train_set,
                                     language_pair[0]), '{}.{}'.format(
                                         train_set, language_pair[1])
    valid_filenames = '{}.{}'.format(valid_set,
                                     language_pair[0]), '{}.{}'.format(
                                         valid_set, language_pair[1])
    test_filenames = '{}.{}'.format(test_set,
                                    language_pair[0]), '{}.{}'.format(
                                        test_set, language_pair[1])

    if split == 'train':
        src_file, tgt_file = train_filenames
    elif split == 'valid':
        src_file, tgt_file = valid_filenames
    else:
        src_file, tgt_file = test_filenames

    root = os.path.join(root, 'wmt14')
    dataset_tar = download_from_url(URL,
                                    root=root,
                                    hash_value=MD5,
                                    path=os.path.join(root, _PATH),
                                    hash_type='md5')
    extracted_files = extract_archive(dataset_tar)

    # Clean the xml and tag file in the archives
    file_archives = []
    for fname in extracted_files:
        if 'xml' in fname:
            _clean_xml_file(fname)
            file_archives.append(os.path.splitext(fname)[0])
        elif "tags" in fname:
            _clean_tags_file(fname)
            file_archives.append(fname.replace('.tags', ''))
        else:
            file_archives.append(fname)

    data_filenames = {
        split: _construct_filepaths(file_archives, src_file, tgt_file),
    }

    for key in data_filenames:
        if data_filenames[key] is None or len(data_filenames[key]) == 0:
            raise FileNotFoundError(
                "Files are not found for data type {}".format(key))

    assert data_filenames[split][
        0] is not None, "Internal Error: File not found for reading"
    assert data_filenames[split][
        1] is not None, "Internal Error: File not found for reading"
    src_data_iter = _read_text_iterator(data_filenames[split][0])
    tgt_data_iter = _read_text_iterator(data_filenames[split][1])

    def _iter(src_data_iter, tgt_data_iter):
        for item in zip(src_data_iter, tgt_data_iter):
            yield item

    return RawTextIterableDataset("WMT14",
                                  NUM_LINES[os.path.splitext(src_file)[0]],
                                  _iter(src_data_iter, tgt_data_iter))
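
A usage sketch under the same single-split assumption; the defaults select the BPE-tokenized newstest2013/newstest2014 files named in the signature.

# Hypothetical usage: the 'valid' split reads the newstest2013.tok.bpe.32000 files.
valid_iter = WMT14(root='.data', split='valid', language_pair=('de', 'en'))
de_sentence, en_sentence = next(iter(valid_iter))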