def SQuAD1(root, split):
    extracted_files = download_from_url(URL[split], root=root, hash_value=MD5[split], hash_type='md5')
    return RawTextIterableDataset('SQuAD1', NUM_LINES[split],
                                  _create_data_from_json(extracted_files))

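# Usage sketch (illustrative; the exact field layout is determined by
# _create_data_from_json -- a (context, question, answers, answer_start)
# tuple is assumed here):
#
#     >>> train_iter = SQuAD1(root='.data', split='train')
#     >>> context, question, answers, answer_start = next(train_iter)
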
def WikiText103(root, split):
    dataset_tar = download_from_url(URL, root=root, hash_value=MD5, hash_type='md5')
    extracted_files = extract_archive(dataset_tar)
    path = find_match(split, extracted_files)
    logging.info('Creating {} data'.format(split))
    return RawTextIterableDataset('WikiText103', NUM_LINES[split],
                                  iter(io.open(path, encoding="utf8")))

def PennTreebank(root, split):
    path = download_from_url(URL[split], root=root, hash_value=MD5[split], hash_type='md5')
    logging.info('Creating {} data'.format(split))
    return RawTextIterableDataset('PennTreebank', NUM_LINES[split],
                                  iter(io.open(path, encoding="utf8")))

def CoNLL2000Chunking(root, split):
    # Create a dataset specific subfolder to deal with generic download filenames
    root = os.path.join(root, 'conll2000chunking')
    path = os.path.join(root, split + ".txt.gz")
    data_filename = download_extract_validate(root, URL[split], MD5[split], path,
                                              os.path.join(root, _EXTRACTED_FILES[split]),
                                              _EXTRACTED_FILES_MD5[split],
                                              hash_type="md5")
    logging.info('Creating {} data'.format(split))
    return RawTextIterableDataset("CoNLL2000Chunking", NUM_LINES[split],
                                  _create_data_from_iob(data_filename, " "))

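# Usage sketch (illustrative; _create_data_from_iob is assumed here to yield
# per-sentence lists of tokens, POS tags, and chunk tags):
#
#     >>> train_iter = CoNLL2000Chunking(root='.data', split='train')
#     >>> tokens, pos_tags, chunk_tags = next(train_iter)
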
def UDPOS(root, split):
    dataset_tar = download_from_url(URL, root=root, hash_value=MD5, hash_type='md5')
    extracted_files = extract_archive(dataset_tar)
    if split == 'valid':
        path = find_match("dev.txt", extracted_files)
    else:
        path = find_match(split + ".txt", extracted_files)
    return RawTextIterableDataset("UDPOS", NUM_LINES[split],
                                  _create_data_from_iob(path))

def AG_NEWS(root, split):
    def _create_data_from_csv(data_path):
        with io.open(data_path, encoding="utf8") as f:
            reader = unicode_csv_reader(f)
            for row in reader:
                yield int(row[0]), ' '.join(row[1:])

    path = download_from_url(URL[split], root=root,
                             path=os.path.join(root, split + ".csv"),
                             hash_value=MD5[split],
                             hash_type='md5')
    return RawTextIterableDataset("AG_NEWS", NUM_LINES[split],
                                  _create_data_from_csv(path))

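# Usage sketch: the inline _create_data_from_csv generator yields one
# (integer label, joined text) pair per CSV row.
#
#     >>> train_iter = AG_NEWS(root='.data', split='train')
#     >>> label, text = next(train_iter)
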
def YelpReviewFull(root, split):
    def _create_data_from_csv(data_path):
        with io.open(data_path, encoding="utf8") as f:
            reader = unicode_csv_reader(f)
            for row in reader:
                yield int(row[0]), ' '.join(row[1:])

    dataset_tar = download_from_url(URL, root=root,
                                    path=os.path.join(root, _PATH),
                                    hash_value=MD5, hash_type='md5')
    extracted_files = extract_archive(dataset_tar)
    path = find_match(split + '.csv', extracted_files)
    return RawTextIterableDataset("YelpReviewFull", NUM_LINES[split],
                                  _create_data_from_csv(path))

def IMDB(root, split):
    def generate_imdb_data(key, extracted_files):
        for fname in extracted_files:
            if 'urls' in fname:
                continue
            elif key in fname and ('pos' in fname or 'neg' in fname):
                with io.open(fname, encoding="utf8") as f:
                    label = 'pos' if 'pos' in fname else 'neg'
                    yield label, f.read()

    dataset_tar = download_from_url(URL, root=root, hash_value=MD5, hash_type='md5')
    extracted_files = extract_archive(dataset_tar)
    iterator = generate_imdb_data(split, extracted_files)
    return RawTextIterableDataset("IMDB", NUM_LINES[split], iterator)

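# Usage sketch: generate_imdb_data yields ('pos' or 'neg', review text) pairs,
# one per review file in the requested split.
#
#     >>> train_iter = IMDB(root='.data', split='train')
#     >>> label, review = next(train_iter)
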
def AmazonReviewFull(root, split):
    def _create_data_from_csv(data_path):
        with io.open(data_path, encoding="utf8") as f:
            reader = unicode_csv_reader(f)
            for row in reader:
                yield int(row[0]), ' '.join(row[1:])

    path = download_extract_validate(root, URL, MD5, os.path.join(root, _PATH),
                                     os.path.join(root, _EXTRACTED_FILES[split]),
                                     _EXTRACTED_FILES_MD5[split],
                                     hash_type="md5")
    logging.info('Creating {} data'.format(split))
    return RawTextIterableDataset("AmazonReviewFull", NUM_LINES[split],
                                  _create_data_from_csv(path))

def WMTNewsCrawl(root, split, year=2010, language='en'):
    if year not in _AVAILABLE_YEARS:
        raise ValueError(
            "{} not available. Please choose from years {}".format(
                year, _AVAILABLE_YEARS))
    if language not in _AVAILABLE_LANGUAGES:
        raise ValueError(
            "{} not available. Please choose from languages {}".format(
                language, _AVAILABLE_LANGUAGES))
    path = download_extract_validate(root, URL, MD5, _PATH,
                                     _EXTRACTED_FILES[language],
                                     _EXTRACTED_FILES_MD5[language],
                                     hash_type="md5")
    logging.info('Creating {} data'.format(split))
    return RawTextIterableDataset("WMTNewsCrawl", NUM_LINES[split],
                                  iter(io.open(path, encoding="utf8")))

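# Usage sketch: the returned dataset iterates over the raw text lines of the
# extracted news-crawl file for the chosen language (the year argument is only
# validated against _AVAILABLE_YEARS).
#
#     >>> train_iter = WMTNewsCrawl(root='.data', split='train', year=2010, language='en')
#     >>> line = next(train_iter)
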
def IWSLT2017(root='.data', split=('train', 'valid', 'test'), language_pair=('de', 'en')):
    """IWSLT2017 dataset

    The available datasets include the following:

    **Language pairs**:

    +-----+-----+-----+-----+-----+-----+
    |     |'en' |'nl' |'de' |'it' |'ro' |
    +-----+-----+-----+-----+-----+-----+
    |'en' |     |  x  |  x  |  x  |  x  |
    +-----+-----+-----+-----+-----+-----+
    |'nl' |  x  |     |  x  |  x  |  x  |
    +-----+-----+-----+-----+-----+-----+
    |'de' |  x  |  x  |     |  x  |  x  |
    +-----+-----+-----+-----+-----+-----+
    |'it' |  x  |  x  |  x  |     |  x  |
    +-----+-----+-----+-----+-----+-----+
    |'ro' |  x  |  x  |  x  |  x  |     |
    +-----+-----+-----+-----+-----+-----+

    For additional details refer to source website: https://wit3.fbk.eu/2017-01

    Args:
        root: Directory where the datasets are saved. Default: ".data"
        split: split or splits to be returned. Can be a string or tuple of strings. Default: ('train', 'valid', 'test')
        language_pair: tuple or list containing src and tgt language

    Examples:
        >>> from torchtext.datasets import IWSLT2017
        >>> train_iter, valid_iter, test_iter = IWSLT2017()
        >>> src_sentence, tgt_sentence = next(train_iter)
    """
    valid_set = 'dev2010'
    test_set = 'tst2010'

    num_lines_set_identifier = {
        'train': 'train',
        'valid': valid_set,
        'test': test_set
    }

    if not isinstance(language_pair, list) and not isinstance(language_pair, tuple):
        raise ValueError("language_pair must be list or tuple but got {} instead".format(type(language_pair)))

    assert (len(language_pair) == 2), 'language_pair must contain only 2 elements: src and tgt language respectively'

    src_language, tgt_language = language_pair[0], language_pair[1]

    if src_language not in SUPPORTED_DATASETS['language_pair']:
        raise ValueError("src_language '{}' is not valid. Supported source languages are {}".
                         format(src_language, list(SUPPORTED_DATASETS['language_pair'])))

    if tgt_language not in SUPPORTED_DATASETS['language_pair'][src_language]:
        raise ValueError("tgt_language '{}' is not valid for given src_language '{}'. Supported target languages are {}".
                         format(tgt_language, src_language, SUPPORTED_DATASETS['language_pair'][src_language]))

    train_filenames = ('train.{}-{}.{}'.format(src_language, tgt_language, src_language),
                       'train.{}-{}.{}'.format(src_language, tgt_language, tgt_language))
    valid_filenames = ('IWSLT{}.TED.{}.{}-{}.{}'.format(SUPPORTED_DATASETS['year'], valid_set, src_language, tgt_language, src_language),
                       'IWSLT{}.TED.{}.{}-{}.{}'.format(SUPPORTED_DATASETS['year'], valid_set, src_language, tgt_language, tgt_language))
    test_filenames = ('IWSLT{}.TED.{}.{}-{}.{}'.format(SUPPORTED_DATASETS['year'], test_set, src_language, tgt_language, src_language),
                      'IWSLT{}.TED.{}.{}-{}.{}'.format(SUPPORTED_DATASETS['year'], test_set, src_language, tgt_language, tgt_language))

    src_train, tgt_train = train_filenames
    src_eval, tgt_eval = valid_filenames
    src_test, tgt_test = test_filenames

    extracted_files = []  # list of paths to the extracted files

    dataset_tar = download_from_url(SUPPORTED_DATASETS['URL'], root=root,
                                    hash_value=SUPPORTED_DATASETS['MD5'],
                                    path=os.path.join(root, SUPPORTED_DATASETS['_PATH']),
                                    hash_type='md5')
    extracted_dataset_tar = extract_archive(dataset_tar)
    # IWSLT dataset's url downloads a multilingual tgz.
    # We need to take an extra step to pick out the specific language pair from it.
    src_language = train_filenames[0].split(".")[-1]
    tgt_language = train_filenames[1].split(".")[-1]
    iwslt_tar = os.path.join(root, SUPPORTED_DATASETS['_PATH'].split(".")[0],
                             'texts/DeEnItNlRo/DeEnItNlRo', 'DeEnItNlRo-DeEnItNlRo.tgz')
    extracted_dataset_tar = extract_archive(iwslt_tar)
    extracted_files.extend(extracted_dataset_tar)

    # Clean the xml and tag file in the archives
    file_archives = []
    for fname in extracted_files:
        if 'xml' in fname:
            _clean_xml_file(fname)
            file_archives.append(os.path.splitext(fname)[0])
        elif "tags" in fname:
            _clean_tags_file(fname)
            file_archives.append(fname.replace('.tags', ''))
        else:
            file_archives.append(fname)

    data_filenames = {
        "train": _construct_filepaths(file_archives, src_train, tgt_train),
        "valid": _construct_filepaths(file_archives, src_eval, tgt_eval),
        "test": _construct_filepaths(file_archives, src_test, tgt_test)
    }

    for key in data_filenames:
        if len(data_filenames[key]) == 0 or data_filenames[key] is None:
            raise FileNotFoundError(
                "Files are not found for data type {}".format(key))

    src_data_iter = _read_text_iterator(data_filenames[split][0])
    tgt_data_iter = _read_text_iterator(data_filenames[split][1])

    def _iter(src_data_iter, tgt_data_iter):
        for item in zip(src_data_iter, tgt_data_iter):
            yield item

    return RawTextIterableDataset("IWSLT2017",
                                  NUM_LINES[split][num_lines_set_identifier[split]][tuple(sorted(language_pair))],
                                  _iter(src_data_iter, tgt_data_iter))

def Multi30k(root, split, task='task1', language_pair=('de', 'en'),
             train_set="train", valid_set="val", test_set="test_2016_flickr"):
    """Multi30k Dataset

    The available datasets include the following:

    **Language pairs (task1)**:

    +-----+-----+-----+-----+-----+
    |     |'en' |'cs' |'de' |'fr' |
    +-----+-----+-----+-----+-----+
    |'en' |     |  x  |  x  |  x  |
    +-----+-----+-----+-----+-----+
    |'cs' |  x  |     |  x  |  x  |
    +-----+-----+-----+-----+-----+
    |'de' |  x  |  x  |     |  x  |
    +-----+-----+-----+-----+-----+
    |'fr' |  x  |  x  |  x  |     |
    +-----+-----+-----+-----+-----+

    **Language pairs (task2)**:

    +-----+-----+-----+
    |     |'en' |'de' |
    +-----+-----+-----+
    |'en' |     |  x  |
    +-----+-----+-----+
    |'de' |  x  |     |
    +-----+-----+-----+

    For additional details refer to source: https://github.com/multi30k/dataset

    Args:
        root: Directory where the datasets are saved. Default: ".data"
        split: split or splits to be returned. Can be a string or tuple of strings. Default: ('train', 'valid', 'test')
        task: Indicate the task
        language_pair: tuple or list containing src and tgt language
        train_set: A string to identify train set.
        valid_set: A string to identify validation set.
        test_set: A string to identify test set.

    Examples:
        >>> from torchtext.experimental.datasets.raw import Multi30k
        >>> train_iter, valid_iter, test_iter = Multi30k()
        >>> src_sentence, tgt_sentence = next(train_iter)
    """
    if task not in SUPPORTED_DATASETS.keys():
        raise ValueError(
            'task {} is not supported. Valid options are {}'.format(
                task, SUPPORTED_DATASETS.keys()))

    assert (
        len(language_pair) == 2
    ), 'language_pair must contain only 2 elements: src and tgt language respectively'

    if language_pair[0] not in SUPPORTED_DATASETS[task].keys():
        raise ValueError(
            "Source language '{}' is not supported. Valid options for task '{}' are {}"
            .format(language_pair[0], task, list(SUPPORTED_DATASETS[task].keys())))

    if language_pair[1] not in SUPPORTED_DATASETS[task].keys():
        raise ValueError(
            "Target language '{}' is not supported. Valid options for task '{}' are {}"
            .format(language_pair[1], task, list(SUPPORTED_DATASETS[task].keys())))

    if train_set not in SUPPORTED_DATASETS[task][language_pair[0]].keys() or 'train' not in train_set:
        raise ValueError(
            "'{}' is not a valid train set identifier. Valid options for task '{}' and language pair {} are {}"
            .format(train_set, task, language_pair, [
                k for k in SUPPORTED_DATASETS[task][language_pair[0]].keys()
                if 'train' in k
            ]))

    if valid_set not in SUPPORTED_DATASETS[task][language_pair[0]].keys() or 'val' not in valid_set:
        raise ValueError(
            "'{}' is not a valid valid set identifier. Valid options for task '{}' and language pair {} are {}"
            .format(valid_set, task, language_pair, [
                k for k in SUPPORTED_DATASETS[task][language_pair[0]].keys()
                if 'val' in k
            ]))

    if test_set not in SUPPORTED_DATASETS[task][language_pair[0]].keys() or 'test' not in test_set:
        raise ValueError(
            "'{}' is not a valid test set identifier. Valid options for task '{}' and language pair {} are {}"
            .format(test_set, task, language_pair, [
                k for k in SUPPORTED_DATASETS[task][language_pair[0]].keys()
                if 'test' in k
            ]))

    train_filenames = [
        "{}.{}".format(train_set, language_pair[0]),
        "{}.{}".format(train_set, language_pair[1])
    ]
    valid_filenames = [
        "{}.{}".format(valid_set, language_pair[0]),
        "{}.{}".format(valid_set, language_pair[1])
    ]
    test_filenames = [
        "{}.{}".format(test_set, language_pair[0]),
        "{}.{}".format(test_set, language_pair[1])
    ]

    if split == 'train':
        src_file, tgt_file = train_filenames
    elif split == 'valid':
        src_file, tgt_file = valid_filenames
    else:
        src_file, tgt_file = test_filenames

    extracted_files = []  # list of paths to the extracted files

    current_url = []
    current_md5 = []
    current_filenames = [src_file, tgt_file]
    for url, md5 in zip(URL[split], MD5[split]):
        if any(f in url for f in current_filenames):
            current_url.append(url)
            current_md5.append(md5)

    for url, md5 in zip(current_url, current_md5):
        dataset_tar = download_from_url(url,
                                        path=os.path.join(root, os.path.basename(url)),
                                        root=root,
                                        hash_value=md5,
                                        hash_type='md5')
        extracted_files.extend(extract_archive(dataset_tar))

    file_archives = extracted_files

    data_filenames = {
        split: _construct_filepaths(file_archives, src_file, tgt_file),
    }
    for key in data_filenames:
        if len(data_filenames[key]) == 0 or data_filenames[key] is None:
            raise FileNotFoundError(
                "Files are not found for data type {}".format(key))

    assert data_filenames[split][0] is not None, "Internal Error: File not found for reading"
    assert data_filenames[split][1] is not None, "Internal Error: File not found for reading"

    src_data_iter = _read_text_iterator(data_filenames[split][0])
    tgt_data_iter = _read_text_iterator(data_filenames[split][1])

    def _iter(src_data_iter, tgt_data_iter):
        for item in zip(src_data_iter, tgt_data_iter):
            yield item

    set_identifier = {
        'train': train_set,
        'valid': valid_set,
        'test': test_set,
    }

    return RawTextIterableDataset(
        "Multi30k",
        SUPPORTED_DATASETS[task][language_pair[0]][set_identifier[split]]['NUM_LINES'],
        _iter(src_data_iter, tgt_data_iter))

def WMT14(root, split, language_pair=('de', 'en'),
          train_set='train.tok.clean.bpe.32000',
          valid_set='newstest2013.tok.bpe.32000',
          test_set='newstest2014.tok.bpe.32000'):
    """WMT14 Dataset

    The available datasets include the following:

    **Language pairs**:

    +-----+-----+-----+
    |     |'en' |'de' |
    +-----+-----+-----+
    |'en' |     |  x  |
    +-----+-----+-----+
    |'de' |  x  |     |
    +-----+-----+-----+

    Args:
        root: Directory where the datasets are saved. Default: ".data"
        split: split or splits to be returned. Can be a string or tuple of strings. Default: ('train', 'valid', 'test')
        language_pair: tuple or list containing src and tgt language
        train_set: A string to identify train set.
        valid_set: A string to identify validation set.
        test_set: A string to identify test set.

    Examples:
        >>> from torchtext.datasets import WMT14
        >>> train_iter, valid_iter, test_iter = WMT14()
        >>> src_sentence, tgt_sentence = next(train_iter)
    """
    supported_language = ['en', 'de']
    supported_train_set = [s for s in NUM_LINES if 'train' in s]
    supported_valid_set = [s for s in NUM_LINES if 'test' in s]
    supported_test_set = [s for s in NUM_LINES if 'test' in s]

    assert (
        len(language_pair) == 2
    ), 'language_pair must contain only 2 elements: src and tgt language respectively'

    if language_pair[0] not in supported_language:
        raise ValueError(
            "Source language '{}' is not supported. Valid options are {}".
            format(language_pair[0], supported_language))

    if language_pair[1] not in supported_language:
        raise ValueError(
            "Target language '{}' is not supported. Valid options are {}".
            format(language_pair[1], supported_language))

    if train_set not in supported_train_set:
        raise ValueError(
            "'{}' is not a valid train set identifier. Valid options are {}".
            format(train_set, supported_train_set))

    if valid_set not in supported_valid_set:
        raise ValueError(
            "'{}' is not a valid valid set identifier. Valid options are {}".
            format(valid_set, supported_valid_set))

    if test_set not in supported_test_set:
        raise ValueError(
            "'{}' is not a valid test set identifier. Valid options are {}".
            format(test_set, supported_test_set))

    train_filenames = '{}.{}'.format(train_set, language_pair[0]), '{}.{}'.format(train_set, language_pair[1])
    valid_filenames = '{}.{}'.format(valid_set, language_pair[0]), '{}.{}'.format(valid_set, language_pair[1])
    test_filenames = '{}.{}'.format(test_set, language_pair[0]), '{}.{}'.format(test_set, language_pair[1])

    if split == 'train':
        src_file, tgt_file = train_filenames
    elif split == 'valid':
        src_file, tgt_file = valid_filenames
    else:
        src_file, tgt_file = test_filenames

    root = os.path.join(root, 'wmt14')
    dataset_tar = download_from_url(URL, root=root, hash_value=MD5,
                                    path=os.path.join(root, _PATH), hash_type='md5')
    extracted_files = extract_archive(dataset_tar)

    # Clean the xml and tag file in the archives
    file_archives = []
    for fname in extracted_files:
        if 'xml' in fname:
            _clean_xml_file(fname)
            file_archives.append(os.path.splitext(fname)[0])
        elif "tags" in fname:
            _clean_tags_file(fname)
            file_archives.append(fname.replace('.tags', ''))
        else:
            file_archives.append(fname)

    data_filenames = {
        split: _construct_filepaths(file_archives, src_file, tgt_file),
    }
    for key in data_filenames:
        if len(data_filenames[key]) == 0 or data_filenames[key] is None:
            raise FileNotFoundError(
                "Files are not found for data type {}".format(key))

    assert data_filenames[split][0] is not None, "Internal Error: File not found for reading"
    assert data_filenames[split][1] is not None, "Internal Error: File not found for reading"

    src_data_iter = _read_text_iterator(data_filenames[split][0])
    tgt_data_iter = _read_text_iterator(data_filenames[split][1])

    def _iter(src_data_iter, tgt_data_iter):
        for item in zip(src_data_iter, tgt_data_iter):
            yield item

    return RawTextIterableDataset("WMT14",
                                  NUM_LINES[os.path.splitext(src_file)[0]],
                                  _iter(src_data_iter, tgt_data_iter))