Example #1
import logging
import os


def SogouNews(root, split):
    # Download the archive, extract the file for the requested split, and
    # validate both against their MD5 checksums. URL, MD5, _PATH,
    # _EXTRACTED_FILES, _EXTRACTED_FILES_MD5, DATASET_NAME, NUM_LINES and the
    # private helpers are module-level names of the dataset module.
    path = _download_extract_validate(root,
                                      URL,
                                      MD5,
                                      os.path.join(root, _PATH),
                                      os.path.join(root,
                                                   _EXTRACTED_FILES[split]),
                                      _EXTRACTED_FILES_MD5[split],
                                      hash_type="md5")
    logging.info('Creating {} data'.format(split))
    return _RawTextIterableDataset(DATASET_NAME, NUM_LINES[split],
                                   _create_data_from_csv(path))
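
A minimal usage sketch for the function above, assuming the module-level constants and helpers it references are in scope; the '.data' root and the (label, text) item shape are assumptions, the latter based on how _create_data_from_csv is typically defined (see Example #4):

# Hypothetical usage: iterate over (label, text) pairs from the train split.
train_iter = SogouNews(root='.data', split='train')
label, text = next(iter(train_iter))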
Example #2
import os


def Multi30k(root, split, language_pair=('de', 'en')):
    """Multi30k dataset

    Reference: http://www.statmt.org/wmt16/multimodal-task.html#task1

    Args:
        root: Directory where the datasets are saved
        split: the split to be returned: 'train', 'valid' or 'test'
        language_pair: tuple or list containing src and tgt language.
            Available options are ('de', 'en') and ('en', 'de')
    """
    assert (
        len(language_pair) == 2
    ), 'language_pair must contain exactly 2 elements: src and tgt language respectively'
    assert (tuple(sorted(language_pair)) == (
        'de',
        'en')), "language_pair must be either ('de','en') or ('en', 'de')"

    downloaded_file = os.path.basename(URL[split])

    # The same archive is downloaded once; one file per language is extracted
    # from it and validated against its MD5 checksum.
    src_path = _download_extract_validate(
        root, URL[split], MD5[split], os.path.join(root, downloaded_file),
        os.path.join(
            root, _EXTRACTED_FILES_INFO[split]['file_prefix'] + '.' +
            language_pair[0]),
        _EXTRACTED_FILES_INFO[split]['md5'][language_pair[0]])
    trg_path = _download_extract_validate(
        root, URL[split], MD5[split], os.path.join(root, downloaded_file),
        os.path.join(
            root, _EXTRACTED_FILES_INFO[split]['file_prefix'] + '.' +
            language_pair[1]),
        _EXTRACTED_FILES_INFO[split]['md5'][language_pair[1]])

    # Pair the source and target files line by line.
    src_data_iter = _read_text_iterator(src_path)
    trg_data_iter = _read_text_iterator(trg_path)

    return _RawTextIterableDataset(DATASET_NAME, NUM_LINES[split],
                                   zip(src_data_iter, trg_data_iter))
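
A usage sketch; because of the zip at the end, each item is an aligned (source, target) sentence pair. The '.data' root is an illustrative assumption:

# Hypothetical usage: first German/English sentence pair of the train split.
train_iter = Multi30k(root='.data', split='train', language_pair=('de', 'en'))
de_line, en_line = next(iter(train_iter))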
Example #3
import logging
import os


def CoNLL2000Chunking(root, split):
    # Create a dataset-specific subfolder to deal with generic download filenames
    root = os.path.join(root, 'conll2000chunking')
    path = os.path.join(root, split + ".txt.gz")
    # Download the gzipped IOB file for this split, extract it and validate
    # its MD5 checksum.
    data_filename = _download_extract_validate(root,
                                               URL[split],
                                               MD5[split],
                                               path,
                                               os.path.join(
                                                   root,
                                                   _EXTRACTED_FILES[split]),
                                               _EXTRACTED_FILES_MD5[split],
                                               hash_type="md5")
    logging.info('Creating {} data'.format(split))
    return _RawTextIterableDataset("CoNLL2000Chunking", NUM_LINES[split],
                                   _create_data_from_iob(data_filename, " "))
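
A usage sketch; the exact item shape depends on _create_data_from_iob, which is assumed here to group one sentence's space-separated columns (tokens and their tags) into a single sample:

# Hypothetical usage: fetch the first sentence's token/tag columns.
train_iter = CoNLL2000Chunking(root='.data', split='train')
sample = next(iter(train_iter))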
Example #4
import io
import logging
import os


def AmazonReviewFull(root, split):
    def _create_data_from_csv(data_path):
        # The first CSV column holds the star rating; the remaining columns
        # are joined to form the review text.
        with io.open(data_path, encoding="utf8") as f:
            reader = unicode_csv_reader(f)  # CSV reader from torchtext.utils
            for row in reader:
                yield int(row[0]), ' '.join(row[1:])

    path = _download_extract_validate(root,
                                      URL,
                                      MD5,
                                      os.path.join(root, _PATH),
                                      os.path.join(root,
                                                   _EXTRACTED_FILES[split]),
                                      _EXTRACTED_FILES_MD5[split],
                                      hash_type="md5")
    logging.info('Creating {} data'.format(split))
    return _RawTextIterableDataset("AmazonReviewFull", NUM_LINES[split],
                                   _create_data_from_csv(path))
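
A usage sketch; here the item shape is explicit in the inner _create_data_from_csv: an integer rating followed by the review text. The '.data' root is an illustrative assumption:

# Hypothetical usage: first (rating, review) pair of the train split.
train_iter = AmazonReviewFull(root='.data', split='train')
rating, review = next(iter(train_iter))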
Example #5
import logging
import os


def WMTNewsCrawl(root, split, year=2010, language='en'):
    # Reject unsupported year/language choices before downloading anything.
    if year not in _AVAILABLE_YEARS:
        raise ValueError(
            "{} not available. Please choose from years {}".format(
                year, _AVAILABLE_YEARS))
    if language not in _AVAILABLE_LANGUAGES:
        raise ValueError(
            "{} not available. Please choose from languages {}".format(
                language, _AVAILABLE_LANGUAGES))
    # Download and MD5-validate the archive and the per-language extracted
    # file, both placed under root.
    path = _download_extract_validate(root,
                                      URL,
                                      MD5,
                                      os.path.join(root, _PATH),
                                      os.path.join(root,
                                                   _EXTRACTED_FILES[language]),
                                      _EXTRACTED_FILES_MD5[language],
                                      hash_type="md5")
    logging.info('Creating {} data'.format(split))
    return _RawTextIterableDataset(DATASET_NAME, NUM_LINES[split],
                                   _read_text_iterator(path))
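
A usage sketch; _read_text_iterator yields raw lines, so each item is one sentence from the monolingual news crawl. The argument values simply restate the defaults, and the '.data' root and 'train' split name are illustrative assumptions:

# Hypothetical usage: first line of the 2010 English news crawl.
data_iter = WMTNewsCrawl(root='.data', split='train', year=2010, language='en')
line = next(iter(data_iter))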