def PennTreebank(root, split):
    """Return the requested Penn Treebank split as a raw-text iterable dataset.

    Args:
        root: Directory where the dataset file is downloaded to.
        split: Which split to load (key into the module's URL/MD5/NUM_LINES tables).
    """
    # Download (with MD5 verification) the file for this split.
    data_path = download_from_url(URL[split], root=root,
                                  hash_value=MD5[split], hash_type='md5')
    logging.info('Creating {} data'.format(split))
    line_iter = _read_text_iterator(data_path)
    return _RawTextIterableDataset(DATASET_NAME, NUM_LINES[split], line_iter)
def EnWik9(root, split):
    """Return the EnWik9 dataset as a raw-text iterable dataset.

    Args:
        root: Directory where the dataset archive is downloaded to.
        split: Split identifier used for logging and the NUM_LINES lookup.
    """
    # Download the archive (MD5-checked) and use the first extracted file.
    archive_path = download_from_url(URL, root=root,
                                     hash_value=MD5, hash_type='md5')
    data_path = extract_archive(archive_path)[0]
    logging.info('Creating {} data'.format(split))
    line_iter = _read_text_iterator(data_path)
    return _RawTextIterableDataset(DATASET_NAME, NUM_LINES[split], line_iter)
def Multi30k(root, split, language_pair=('de', 'en')):
    """Multi30k dataset

    Reference: http://www.statmt.org/wmt16/multimodal-task.html#task1

    Args:
        root: Directory where the datasets are saved. Default: ".data"
        split: split or splits to be returned. Can be a string or tuple of
            strings. Default: ('train', 'valid', 'test')
        language_pair: tuple or list containing src and tgt language.
            Available options are ('de','en') and ('en', 'de')
    """
    assert (
        len(language_pair) == 2
    ), 'language_pair must contain only 2 elements: src and tgt language respectively'
    assert (tuple(sorted(language_pair)) == (
        'de', 'en')), "language_pair must be either ('de','en') or ('en', 'de')"

    archive_name = os.path.basename(URL[split])

    def _language_path(lang):
        # Download the split archive (if needed), extract it, and validate the
        # per-language file against its recorded MD5; returns the file path.
        return _download_extract_validate(
            root, URL[split], MD5[split],
            os.path.join(root, archive_name),
            os.path.join(
                root,
                _EXTRACTED_FILES_INFO[split]['file_prefix'] + '.' + lang),
            _EXTRACTED_FILES_INFO[split]['md5'][lang])

    src_lines = _read_text_iterator(_language_path(language_pair[0]))
    tgt_lines = _read_text_iterator(_language_path(language_pair[1]))
    # Pair up source/target sentences line by line.
    return _RawTextIterableDataset(DATASET_NAME, NUM_LINES[split],
                                   zip(src_lines, tgt_lines))
def WMTNewsCrawl(root, split, year=2010, language='en'):
    """Return a WMT News Crawl split as a raw-text iterable dataset.

    Args:
        root: Directory where the dataset is downloaded to.
        split: Split identifier used for logging and the NUM_LINES lookup.
        year: Crawl year; must be one of ``_AVAILABLE_YEARS``.
        language: Language code; must be one of ``_AVAILABLE_LANGUAGES``.

    Raises:
        ValueError: If ``year`` or ``language`` is not available.
    """
    # Guard clauses: reject unsupported year/language up front.
    if year not in _AVAILABLE_YEARS:
        message = "{} not available. Please choose from years {}"
        raise ValueError(message.format(year, _AVAILABLE_YEARS))
    if language not in _AVAILABLE_LANGUAGES:
        message = "{} not available. Please choose from languages {}"
        raise ValueError(message.format(language, _AVAILABLE_LANGUAGES))

    data_path = _download_extract_validate(root, URL, MD5, _PATH,
                                           _EXTRACTED_FILES[language],
                                           _EXTRACTED_FILES_MD5[language],
                                           hash_type="md5")
    logging.info('Creating {} data'.format(split))
    return _RawTextIterableDataset(DATASET_NAME, NUM_LINES[split],
                                   _read_text_iterator(data_path))
def IWSLT2016(root='.data', split=('train', 'valid', 'test'), language_pair=('de', 'en'), valid_set='tst2013', test_set='tst2014'):
    """IWSLT2016 dataset

    The available datasets include following:

    **Language pairs**:

    +-----+-----+-----+-----+-----+-----+
    |     |'en' |'fr' |'de' |'cs' |'ar' |
    +-----+-----+-----+-----+-----+-----+
    |'en' |     |   x |  x  |  x  |  x  |
    +-----+-----+-----+-----+-----+-----+
    |'fr' |  x  |     |     |     |     |
    +-----+-----+-----+-----+-----+-----+
    |'de' |  x  |     |     |     |     |
    +-----+-----+-----+-----+-----+-----+
    |'cs' |  x  |     |     |     |     |
    +-----+-----+-----+-----+-----+-----+
    |'ar' |  x  |     |     |     |     |
    +-----+-----+-----+-----+-----+-----+

    **valid/test sets**: ['dev2010', 'tst2010', 'tst2011', 'tst2012', 'tst2013', 'tst2014']

    For additional details refer to source website: https://wit3.fbk.eu/2016-01

    Args:
        root: Directory where the datasets are saved. Default: ".data"
        split: split or splits to be returned. Can be a string or tuple of
            strings. Default: ('train', 'valid', 'test')
        language_pair: tuple or list containing src and tgt language
        valid_set: a string to identify validation set.
        test_set: a string to identify test set.

    Examples:
        >>> from torchtext.datasets import IWSLT2016
        >>> train_iter, valid_iter, test_iter = IWSLT2016()
        >>> src_sentence, tgt_sentence = next(train_iter)
    """
    # Maps each split name to the identifier used to look up NUM_LINES.
    num_lines_set_identifier = {
        'train': 'train',
        'valid': valid_set,
        'test': test_set
    }

    # --- input validation -------------------------------------------------
    if not isinstance(language_pair, (list, tuple)):
        raise ValueError("language_pair must be list or tuple but got {} instead".format(type(language_pair)))

    assert (len(language_pair) == 2), 'language_pair must contain only 2 elements: src and tgt language respectively'

    src_language, tgt_language = language_pair[0], language_pair[1]

    if src_language not in SUPPORTED_DATASETS['language_pair']:
        raise ValueError("src_language '{}' is not valid. Supported source languages are {}".
                         format(src_language, list(SUPPORTED_DATASETS['language_pair'])))

    if tgt_language not in SUPPORTED_DATASETS['language_pair'][src_language]:
        # Typo fix: "give" -> "given".
        raise ValueError("tgt_language '{}' is not valid for given src_language '{}'. Supported target language are {}".
                         format(tgt_language, src_language, SUPPORTED_DATASETS['language_pair'][src_language]))

    if valid_set not in SUPPORTED_DATASETS['valid_test'] or valid_set in SET_NOT_EXISTS[language_pair]:
        raise ValueError("valid_set '{}' is not valid for given language pair {}. Supported validation sets are {}".
                         format(valid_set, language_pair,
                                [s for s in SUPPORTED_DATASETS['valid_test'] if s not in SET_NOT_EXISTS[language_pair]]))

    if test_set not in SUPPORTED_DATASETS['valid_test'] or test_set in SET_NOT_EXISTS[language_pair]:
        # Bug fix: message previously formatted valid_set instead of test_set,
        # so the error reported the wrong identifier. Also "give" -> "given".
        raise ValueError("test_set '{}' is not valid for given language pair {}. Supported test sets are {}".
                         format(test_set, language_pair,
                                [s for s in SUPPORTED_DATASETS['valid_test'] if s not in SET_NOT_EXISTS[language_pair]]))

    # --- expected per-split file names ------------------------------------
    train_filenames = ('train.{}-{}.{}'.format(src_language, tgt_language, src_language),
                       'train.{}-{}.{}'.format(src_language, tgt_language, tgt_language))
    valid_filenames = ('IWSLT{}.TED.{}.{}-{}.{}'.format(SUPPORTED_DATASETS['year'], valid_set, src_language, tgt_language, src_language),
                       'IWSLT{}.TED.{}.{}-{}.{}'.format(SUPPORTED_DATASETS['year'], valid_set, src_language, tgt_language, tgt_language))
    test_filenames = ('IWSLT{}.TED.{}.{}-{}.{}'.format(SUPPORTED_DATASETS['year'], test_set, src_language, tgt_language, src_language),
                      'IWSLT{}.TED.{}.{}-{}.{}'.format(SUPPORTED_DATASETS['year'], test_set, src_language, tgt_language, tgt_language))

    src_train, tgt_train = train_filenames
    src_eval, tgt_eval = valid_filenames
    src_test, tgt_test = test_filenames

    extracted_files = []  # list of paths to the extracted files

    # --- download and extraction ------------------------------------------
    dataset_tar = download_from_url(SUPPORTED_DATASETS['URL'], root=root,
                                    hash_value=SUPPORTED_DATASETS['MD5'],
                                    path=os.path.join(root, SUPPORTED_DATASETS['_PATH']),
                                    hash_type='md5')
    # Extract the outer multilingual archive for its side effect: it writes
    # the per-language-pair inner .tgz files under `root`.
    extract_archive(dataset_tar)

    # IWSLT dataset's url downloads a multilingual tgz.
    # We need to take an extra step to pick out the specific language pair from it.
    languages = "-".join([src_language, tgt_language])
    iwslt_tar = '{}/{}/texts/{}/{}/{}.tgz'.format(
        root, SUPPORTED_DATASETS['_PATH'].split(".")[0],
        src_language, tgt_language, languages)
    extracted_files.extend(extract_archive(iwslt_tar))

    # Clean the xml and tag file in the archives
    file_archives = []
    for fname in extracted_files:
        if 'xml' in fname:
            _clean_xml_file(fname)
            file_archives.append(os.path.splitext(fname)[0])
        elif "tags" in fname:
            _clean_tags_file(fname)
            file_archives.append(fname.replace('.tags', ''))
        else:
            file_archives.append(fname)

    data_filenames = {
        "train": _construct_filepaths(file_archives, src_train, tgt_train),
        "valid": _construct_filepaths(file_archives, src_eval, tgt_eval),
        "test": _construct_filepaths(file_archives, src_test, tgt_test)
    }

    for key in data_filenames:
        # Bug fix: check for None BEFORE len() — the original order would
        # raise TypeError on a None entry instead of the intended error.
        if data_filenames[key] is None or len(data_filenames[key]) == 0:
            raise FileNotFoundError(
                "Files are not found for data type {}".format(key))

    src_data_iter = _read_text_iterator(data_filenames[split][0])
    tgt_data_iter = _read_text_iterator(data_filenames[split][1])

    def _iter(src_data_iter, tgt_data_iter):
        # Lazily yield aligned (src, tgt) sentence pairs.
        for item in zip(src_data_iter, tgt_data_iter):
            yield item

    return _RawTextIterableDataset(
        DATASET_NAME,
        NUM_LINES[split][num_lines_set_identifier[split]][tuple(sorted(language_pair))],
        _iter(src_data_iter, tgt_data_iter))
def WMT14(root, split, language_pair=('de', 'en'), train_set='train.tok.clean.bpe.32000', valid_set='newstest2013.tok.bpe.32000', test_set='newstest2014.tok.bpe.32000'):
    """WMT14 Dataset

    The available datasets include following:

    **Language pairs**:

    +-----+-----+-----+
    |     |'en' |'de' |
    +-----+-----+-----+
    |'en' |     |  x  |
    +-----+-----+-----+
    |'de' |  x  |     |
    +-----+-----+-----+

    Args:
        root: Directory where the datasets are saved. Default: ".data"
        split: split or splits to be returned. Can be a string or tuple of
            strings. Default: ('train', 'valid', 'test')
        language_pair: tuple or list containing src and tgt language
        train_set: A string to identify train set.
        valid_set: A string to identify validation set.
        test_set: A string to identify test set.

    Examples:
        >>> from torchtext.datasets import WMT14
        >>> train_iter, valid_iter, test_iter = WMT14()
        >>> src_sentence, tgt_sentence = next(train_iter)
    """
    supported_language = ['en', 'de']
    # Valid/test identifiers are both drawn from the 'test' entries in
    # NUM_LINES (the validation set is a newstest file).
    supported_train_set = [s for s in NUM_LINES if 'train' in s]
    supported_valid_set = [s for s in NUM_LINES if 'test' in s]
    supported_test_set = [s for s in NUM_LINES if 'test' in s]

    # --- input validation -------------------------------------------------
    assert (len(language_pair) == 2), 'language_pair must contain only 2 elements: src and tgt language respectively'

    if language_pair[0] not in supported_language:
        raise ValueError(
            "Source language '{}' is not supported. Valid options are {}".
            format(language_pair[0], supported_language))

    if language_pair[1] not in supported_language:
        raise ValueError(
            "Target language '{}' is not supported. Valid options are {}".
            format(language_pair[1], supported_language))

    if train_set not in supported_train_set:
        raise ValueError(
            "'{}' is not a valid train set identifier. valid options are {}".
            format(train_set, supported_train_set))

    if valid_set not in supported_valid_set:
        raise ValueError(
            "'{}' is not a valid valid set identifier. valid options are {}".
            format(valid_set, supported_valid_set))

    if test_set not in supported_test_set:
        # Bug fix: message previously said "valid set" (copy-paste from the
        # valid_set branch); it now correctly says "test set".
        raise ValueError(
            "'{}' is not a valid test set identifier. valid options are {}".
            format(test_set, supported_test_set))

    # --- per-split file names ---------------------------------------------
    train_filenames = '{}.{}'.format(train_set, language_pair[0]), '{}.{}'.format(train_set, language_pair[1])
    valid_filenames = '{}.{}'.format(valid_set, language_pair[0]), '{}.{}'.format(valid_set, language_pair[1])
    test_filenames = '{}.{}'.format(test_set, language_pair[0]), '{}.{}'.format(test_set, language_pair[1])

    if split == 'train':
        src_file, tgt_file = train_filenames
    elif split == 'valid':
        src_file, tgt_file = valid_filenames
    else:
        src_file, tgt_file = test_filenames

    # --- download and extraction ------------------------------------------
    dataset_tar = download_from_url(URL, root=root, hash_value=MD5,
                                    path=os.path.join(root, _PATH),
                                    hash_type='md5')
    extracted_files = extract_archive(dataset_tar)

    data_filenames = {
        split: _construct_filepaths(extracted_files, src_file, tgt_file),
    }
    for key in data_filenames:
        # Bug fix: check for None BEFORE len() — the original order would
        # raise TypeError on a None entry instead of the intended error.
        if data_filenames[key] is None or len(data_filenames[key]) == 0:
            raise FileNotFoundError(
                "Files are not found for data type {}".format(key))

    assert data_filenames[split][0] is not None, "Internal Error: File not found for reading"
    assert data_filenames[split][1] is not None, "Internal Error: File not found for reading"

    src_data_iter = _read_text_iterator(data_filenames[split][0])
    tgt_data_iter = _read_text_iterator(data_filenames[split][1])

    def _iter(src_data_iter, tgt_data_iter):
        # Lazily yield aligned (src, tgt) sentence pairs.
        for item in zip(src_data_iter, tgt_data_iter):
            yield item

    return _RawTextIterableDataset(DATASET_NAME,
                                   NUM_LINES[os.path.splitext(src_file)[0]],
                                   _iter(src_data_iter, tgt_data_iter))
def Multi30k(root, split, task='task1', language_pair=('de', 'en'), train_set="train", valid_set="val", test_set="test_2016_flickr"):
    """Multi30k Dataset

    The available datasets include following:

    **Language pairs (task1)**:

    +-----+-----+-----+-----+-----+
    |     |'en' |'cs' |'de' |'fr' |
    +-----+-----+-----+-----+-----+
    |'en' |     |  x  |  x  |  x  |
    +-----+-----+-----+-----+-----+
    |'cs' |  x  |     |  x  |  x  |
    +-----+-----+-----+-----+-----+
    |'de' |  x  |  x  |     |  x  |
    +-----+-----+-----+-----+-----+
    |'fr' |  x  |  x  |  x  |     |
    +-----+-----+-----+-----+-----+

    **Language pairs (task2)**:

    +-----+-----+-----+
    |     |'en' |'de' |
    +-----+-----+-----+
    |'en' |     |  x  |
    +-----+-----+-----+
    |'de' |  x  |     |
    +-----+-----+-----+

    For additional details refer to source: https://github.com/multi30k/dataset

    Args:
        root: Directory where the datasets are saved. Default: ".data"
        split: split or splits to be returned. Can be a string or tuple of
            strings. Default: ('train', 'valid', 'test')
        task: Indicate the task
        language_pair: tuple or list containing src and tgt language
        train_set: A string to identify train set.
        valid_set: A string to identify validation set.
        test_set: A string to identify test set.

    Examples:
        >>> from torchtext.experimental.datasets.raw import Multi30k
        >>> train_iter, valid_iter, test_iter = Multi30k()
        >>> src_sentence, tgt_sentence = next(train_iter)
    """
    # --- input validation -------------------------------------------------
    if task not in SUPPORTED_DATASETS.keys():
        raise ValueError(
            'task {} is not supported. Valid options are {}'.format(
                task, SUPPORTED_DATASETS.keys()))

    assert (
        len(language_pair) == 2
    ), 'language_pair must contain only 2 elements: src and tgt language respectively'

    if language_pair[0] not in SUPPORTED_DATASETS[task].keys():
        raise ValueError(
            "Source language '{}' is not supported. Valid options for task '{}' are {}"
            .format(language_pair[0], task,
                    list(SUPPORTED_DATASETS[task].keys())))

    if language_pair[1] not in SUPPORTED_DATASETS[task].keys():
        raise ValueError(
            "Target language '{}' is not supported. Valid options for task '{}' are {}"
            .format(language_pair[1], task,
                    list(SUPPORTED_DATASETS[task].keys())))

    # Each set identifier must both exist in the task's table AND contain the
    # expected substring ('train' / 'val' / 'test') for its role.
    if train_set not in SUPPORTED_DATASETS[task][
            language_pair[0]].keys() or 'train' not in train_set:
        raise ValueError(
            "'{}' is not a valid train set identifier. valid options for task '{}' and language pair {} are {}"
            .format(train_set, task, language_pair, [
                k for k in SUPPORTED_DATASETS[task][language_pair[0]].keys()
                if 'train' in k
            ]))

    if valid_set not in SUPPORTED_DATASETS[task][
            language_pair[0]].keys() or 'val' not in valid_set:
        raise ValueError(
            "'{}' is not a valid valid set identifier. valid options for task '{}' and language pair {} are {}"
            .format(valid_set, task, language_pair, [
                k for k in SUPPORTED_DATASETS[task][language_pair[0]].keys()
                if 'val' in k
            ]))

    if test_set not in SUPPORTED_DATASETS[task][
            language_pair[0]].keys() or 'test' not in test_set:
        raise ValueError(
            "'{}' is not a valid test set identifier. valid options for task '{}' and language pair {} are {}"
            .format(test_set, task, language_pair, [
                k for k in SUPPORTED_DATASETS[task][language_pair[0]].keys()
                if 'test' in k
            ]))

    # --- per-split file names (one per language) --------------------------
    train_filenames = [
        "{}.{}".format(train_set, language_pair[0]),
        "{}.{}".format(train_set, language_pair[1])
    ]
    valid_filenames = [
        "{}.{}".format(valid_set, language_pair[0]),
        "{}.{}".format(valid_set, language_pair[1])
    ]
    test_filenames = [
        "{}.{}".format(test_set, language_pair[0]),
        "{}.{}".format(test_set, language_pair[1])
    ]

    if split == 'train':
        src_file, tgt_file = train_filenames
    elif split == 'valid':
        src_file, tgt_file = valid_filenames
    else:
        src_file, tgt_file = test_filenames

    extracted_files = []  # list of paths to the extracted files

    # Select only the (url, md5) pairs whose url mentions one of the two
    # files we actually need for this split/language pair.
    current_url = []
    current_md5 = []
    current_filenames = [src_file, tgt_file]
    for url, md5 in zip(URL[split], MD5[split]):
        if any(f in url for f in current_filenames):
            current_url.append(url)
            current_md5.append(md5)

    # Download (MD5-checked) and extract each selected archive.
    for url, md5 in zip(current_url, current_md5):
        dataset_tar = download_from_url(
            url,
            path=os.path.join(root, os.path.basename(url)),
            root=root,
            hash_value=md5,
            hash_type='md5')
        extracted_files.extend(extract_archive(dataset_tar))

    file_archives = extracted_files

    data_filenames = {
        split: _construct_filepaths(file_archives, src_file, tgt_file),
    }
    for key in data_filenames:
        # NOTE(review): the None check after len() is unreachable if the
        # entry is None — len(None) raises TypeError first; verify upstream.
        if len(data_filenames[key]) == 0 or data_filenames[key] is None:
            raise FileNotFoundError(
                "Files are not found for data type {}".format(key))

    assert data_filenames[split][
        0] is not None, "Internal Error: File not found for reading"
    assert data_filenames[split][
        1] is not None, "Internal Error: File not found for reading"

    src_data_iter = _read_text_iterator(data_filenames[split][0])
    tgt_data_iter = _read_text_iterator(data_filenames[split][1])

    def _iter(src_data_iter, tgt_data_iter):
        # Lazily yield aligned (src, tgt) sentence pairs.
        for item in zip(src_data_iter, tgt_data_iter):
            yield item

    # Maps the split name back to the set identifier used for NUM_LINES.
    set_identifier = {
        'train': train_set,
        'valid': valid_set,
        'test': test_set,
    }
    return _RawTextIterableDataset(
        DATASET_NAME, SUPPORTED_DATASETS[task][language_pair[0]][
            set_identifier[split]]['NUM_LINES'],
        _iter(src_data_iter, tgt_data_iter))