def SQuAD2(root, split):
    """Return a raw iterable dataset over the SQuAD2 JSON file for *split*.

    Downloads the split's JSON file (with md5 verification) into *root*
    and wraps the row generator in a ``_RawTextIterableDataset``.
    """
    data_file = download_from_url(
        URL[split], root=root, hash_value=MD5[split], hash_type='md5')
    return _RawTextIterableDataset(
        DATASET_NAME, NUM_LINES[split], _create_data_from_json(data_file))
def PennTreebank(root, split):
    """Return a raw iterable dataset over the Penn Treebank text for *split*.

    Downloads the split's file (with md5 verification) into *root* and
    yields its lines through ``_read_text_iterator``.
    """
    data_path = download_from_url(
        URL[split], root=root, hash_value=MD5[split], hash_type='md5')
    logging.info('Creating {} data'.format(split))
    return _RawTextIterableDataset(
        DATASET_NAME, NUM_LINES[split], _read_text_iterator(data_path))
def WikiText103(root, split):
    """Return a raw iterable dataset over the WikiText103 file for *split*.

    Downloads and extracts the dataset archive, locates the file matching
    *split* among the extracted members, and iterates its lines.
    """
    archive = download_from_url(URL, root=root, hash_value=MD5,
                                hash_type='md5')
    members = extract_archive(archive)
    data_path = _find_match(split, members)
    logging.info('Creating {} data'.format(split))
    return _RawTextIterableDataset(
        'WikiText103', NUM_LINES[split],
        iter(io.open(data_path, encoding="utf8")))
def AG_NEWS(root, split):
    """Return a raw iterable dataset over the AG_NEWS CSV for *split*.

    The CSV is downloaded (with md5 verification) to ``<root>/<split>.csv``
    and parsed row by row by ``_create_data_from_csv``.
    """
    csv_path = download_from_url(URL[split],
                                 root=root,
                                 path=os.path.join(root, split + ".csv"),
                                 hash_value=MD5[split],
                                 hash_type='md5')
    return _RawTextIterableDataset(
        DATASET_NAME, NUM_LINES[split], _create_data_from_csv(csv_path))
def _create_raw_text_iterable_dataset(description=None, path=None):
    """Build a ``_RawTextIterableDataset`` from the CSV file at *path*.

    The wrapper requires ``full_num_lines`` up front, so the file is
    parsed twice: a first pass only counts the rows, then a fresh
    iterator is created for actual consumption by the dataset.

    Args:
        description: Dataset name/description forwarded to the wrapper.
        path: Path to the CSV file to read.
    """
    # First pass: exhaust a throwaway iterator just to count rows
    # (needed for the full_num_lines argument below).
    length = sum(1 for _ in _create_data_from_csv(path))
    # Second pass: fresh iterator handed to the dataset.
    iterator = _create_data_from_csv(path)
    return _RawTextIterableDataset(description=description,
                                   full_num_lines=length,
                                   iterator=iterator)
def UDPOS(root, split):
    """Return a raw iterable dataset over the UDPOS IOB file for *split*.

    Downloads and extracts the archive, then picks the file for *split*
    ('valid' maps to the archive's ``dev.txt``).
    """
    archive = download_from_url(URL, root=root, hash_value=MD5,
                                hash_type='md5')
    members = extract_archive(archive)
    # The archive names its validation file "dev.txt".
    target_name = "dev.txt" if split == 'valid' else split + ".txt"
    data_path = _find_match(target_name, members)
    return _RawTextIterableDataset(
        DATASET_NAME, NUM_LINES[split], _create_data_from_iob(data_path))
def EnWik9(root, split):
    """Return a raw iterable dataset over the EnWik9 text.

    Downloads and extracts the archive; the single extracted file is
    iterated line by line.
    """
    archive = download_from_url(URL, root=root, hash_value=MD5,
                                hash_type='md5')
    members = extract_archive(archive)
    data_path = members[0]  # the archive contains exactly one data file
    logging.info('Creating {} data'.format(split))
    return _RawTextIterableDataset(
        DATASET_NAME, NUM_LINES[split], _read_text_iterator(data_path))
def DBpedia(root, split):
    """Return a raw iterable dataset over the DBpedia CSV for *split*.

    Downloads the tar archive to ``<root>/<_PATH>``, extracts it, and
    parses the ``<split>.csv`` member.
    """
    archive = download_from_url(URL, root=root,
                                path=os.path.join(root, _PATH),
                                hash_value=MD5, hash_type='md5')
    members = extract_archive(archive)
    data_path = _find_match(split + '.csv', members)
    return _RawTextIterableDataset(
        DATASET_NAME, NUM_LINES[split], _create_data_from_csv(data_path))
def SogouNews(root, split):
    """Return a raw iterable dataset over the SogouNews CSV for *split*.

    Downloads, extracts, and md5-validates the split's file in one step,
    then parses it row by row.
    """
    data_path = _download_extract_validate(
        root, URL, MD5,
        os.path.join(root, _PATH),
        os.path.join(root, _EXTRACTED_FILES[split]),
        _EXTRACTED_FILES_MD5[split],
        hash_type="md5")
    logging.info('Creating {} data'.format(split))
    return _RawTextIterableDataset(
        DATASET_NAME, NUM_LINES[split], _create_data_from_csv(data_path))
def AG_NEWS(root, split):
    """Return a raw iterable dataset over the AG_NEWS CSV for *split*.

    Each yielded item is ``(label, text)`` where *label* is the integer
    class from the first CSV column and *text* joins the rest.
    """
    def _rows(data_path):
        # Parse the CSV lazily: first column is the integer label,
        # remaining columns are joined into a single text field.
        with io.open(data_path, encoding="utf8") as f:
            for row in unicode_csv_reader(f):
                yield int(row[0]), ' '.join(row[1:])

    csv_path = download_from_url(URL[split],
                                 root=root,
                                 path=os.path.join(root, split + ".csv"),
                                 hash_value=MD5[split],
                                 hash_type='md5')
    return _RawTextIterableDataset(
        "AG_NEWS", NUM_LINES[split], _rows(csv_path))
def IMDB(root, split):
    """Return a raw iterable dataset of ``(label, review_text)`` for *split*.

    Downloads and extracts the IMDB archive, then yields the contents of
    every 'pos'/'neg' review file belonging to *split*.
    """
    def generate_imdb_data(key, extracted_files):
        for fname in extracted_files:
            # Skip the bundled URL listings.
            if 'urls' in fname:
                continue
            # Only review files for the requested split carry a label.
            if key not in fname:
                continue
            if 'pos' not in fname and 'neg' not in fname:
                continue
            label = 'pos' if 'pos' in fname else 'neg'
            with io.open(fname, encoding="utf8") as f:
                yield label, f.read()

    archive = download_from_url(URL, root=root, hash_value=MD5,
                                hash_type='md5')
    members = extract_archive(archive)
    return _RawTextIterableDataset(
        DATASET_NAME, NUM_LINES[split], generate_imdb_data(split, members))
def IMDB(root, split):
    """Return a raw iterable dataset of ``(label, review_text)`` for *split*.

    Uses the path layout ``.../<split>/<label>/<file>`` of the extracted
    archive to select review files and derive their 'pos'/'neg' label.
    """
    def generate_imdb_data(key, extracted_files):
        for fname in extracted_files:
            # Last three path components are <split>/<label>/<file>.
            *_, part_split, part_label, _basename = Path(fname).parts
            if key == part_split and part_label in ('pos', 'neg'):
                with io.open(fname, encoding="utf8") as f:
                    yield part_label, f.read()

    archive = download_from_url(URL, root=root, hash_value=MD5,
                                hash_type='md5')
    members = extract_archive(archive)
    return _RawTextIterableDataset(
        DATASET_NAME, NUM_LINES[split], generate_imdb_data(split, members))
def CoNLL2000Chunking(root, split):
    """Return a raw iterable dataset over the CoNLL2000 chunking data.

    Downloads into a dataset-specific subfolder (the upstream filenames
    are generic), validates, extracts, and parses the IOB file using a
    space separator.
    """
    # Dataset-specific subfolder avoids clashes between generic
    # download filenames of different datasets.
    root = os.path.join(root, 'conll2000chunking')
    archive_path = os.path.join(root, split + ".txt.gz")
    data_filename = _download_extract_validate(
        root, URL[split], MD5[split], archive_path,
        os.path.join(root, _EXTRACTED_FILES[split]),
        _EXTRACTED_FILES_MD5[split],
        hash_type="md5")
    logging.info('Creating {} data'.format(split))
    return _RawTextIterableDataset(
        "CoNLL2000Chunking", NUM_LINES[split],
        _create_data_from_iob(data_filename, " "))
def DBpedia(root, split):
    """Return a raw iterable dataset over the DBpedia CSV for *split*.

    Each yielded item is ``(label, text)``: the integer class from the
    first CSV column and the remaining columns joined with spaces.
    """
    def _rows(data_path):
        with io.open(data_path, encoding="utf8") as f:
            for row in unicode_csv_reader(f):
                yield int(row[0]), ' '.join(row[1:])

    archive = download_from_url(URL, root=root,
                                path=os.path.join(root, _PATH),
                                hash_value=MD5, hash_type='md5')
    members = extract_archive(archive)
    data_path = _find_match(split + '.csv', members)
    return _RawTextIterableDataset(
        "DBpedia", NUM_LINES[split], _rows(data_path))
def AmazonReviewFull(root, split):
    """Return a raw iterable dataset over the AmazonReviewFull CSV.

    Each yielded item is ``(label, text)``: the integer rating from the
    first CSV column and the remaining columns joined with spaces.
    """
    def _rows(data_path):
        with io.open(data_path, encoding="utf8") as f:
            for row in unicode_csv_reader(f):
                yield int(row[0]), ' '.join(row[1:])

    data_path = _download_extract_validate(
        root, URL, MD5,
        os.path.join(root, _PATH),
        os.path.join(root, _EXTRACTED_FILES[split]),
        _EXTRACTED_FILES_MD5[split],
        hash_type="md5")
    logging.info('Creating {} data'.format(split))
    return _RawTextIterableDataset(
        "AmazonReviewFull", NUM_LINES[split], _rows(data_path))
def WMTNewsCrawl(root, split, year=2010, language='en'):
    """Return a raw iterable dataset over WMT News Crawl text.

    Args:
        root: Directory where the dataset is saved.
        split: Split identifier (used only for logging and line counts).
        year: Crawl year; must be one of ``_AVAILABLE_YEARS``.
        language: Language code; must be one of ``_AVAILABLE_LANGUAGES``.

    Raises:
        ValueError: If *year* or *language* is not available.
    """
    if year not in _AVAILABLE_YEARS:
        raise ValueError(
            "{} not available. Please choose from years {}".format(
                year, _AVAILABLE_YEARS))
    if language not in _AVAILABLE_LANGUAGES:
        raise ValueError(
            "{} not available. Please choose from languages {}".format(
                language, _AVAILABLE_LANGUAGES))
    data_path = _download_extract_validate(
        root, URL, MD5, _PATH,
        _EXTRACTED_FILES[language],
        _EXTRACTED_FILES_MD5[language],
        hash_type="md5")
    logging.info('Creating {} data'.format(split))
    return _RawTextIterableDataset(
        DATASET_NAME, NUM_LINES[split], _read_text_iterator(data_path))
def Multi30k(root, split, language_pair=('de', 'en')):
    """Multi30k dataset

    Reference: http://www.statmt.org/wmt16/multimodal-task.html#task1

    Args:
        root: Directory where the datasets are saved. Default: ".data"
        split: split or splits to be returned. Can be a string or tuple
            of strings. Default: ('train', 'valid', 'test')
        language_pair: tuple or list containing src and tgt language.
            Available options are ('de','en') and ('en', 'de')
    """
    assert (
        len(language_pair) == 2
    ), 'language_pair must contain only 2 elements: src and tgt language respectively'
    assert (tuple(sorted(language_pair)) == (
        'de', 'en')), "language_pair must be either ('de','en') or ('en', 'de')"

    downloaded_file = os.path.basename(URL[split])

    def _fetch_side(lang):
        # Download/extract/validate the split file for one language side.
        return _download_extract_validate(
            root, URL[split], MD5[split],
            os.path.join(root, downloaded_file),
            os.path.join(
                root,
                _EXTRACTED_FILES_INFO[split]['file_prefix'] + '.' + lang),
            _EXTRACTED_FILES_INFO[split]['md5'][lang])

    src_path = _fetch_side(language_pair[0])
    trg_path = _fetch_side(language_pair[1])
    # Pair source and target sentences line by line.
    paired = zip(_read_text_iterator(src_path), _read_text_iterator(trg_path))
    return _RawTextIterableDataset(DATASET_NAME, NUM_LINES[split], paired)
def IWSLT2016(root='.data', split=('train', 'valid', 'test'),
              language_pair=('de', 'en'), valid_set='tst2013',
              test_set='tst2014'):
    """IWSLT2016 dataset

    The available datasets include following:

    **Language pairs**:

    +-----+-----+-----+-----+-----+-----+
    |     |'en' |'fr' |'de' |'cs' |'ar' |
    +-----+-----+-----+-----+-----+-----+
    |'en' |     |  x  |  x  |  x  |  x  |
    +-----+-----+-----+-----+-----+-----+
    |'fr' |  x  |     |     |     |     |
    +-----+-----+-----+-----+-----+-----+
    |'de' |  x  |     |     |     |     |
    +-----+-----+-----+-----+-----+-----+
    |'cs' |  x  |     |     |     |     |
    +-----+-----+-----+-----+-----+-----+
    |'ar' |  x  |     |     |     |     |
    +-----+-----+-----+-----+-----+-----+

    **valid/test sets**:
    ['dev2010', 'tst2010', 'tst2011', 'tst2012', 'tst2013', 'tst2014']

    For additional details refer to source website:
    https://wit3.fbk.eu/2016-01

    Args:
        root: Directory where the datasets are saved. Default: ".data"
        split: split or splits to be returned. Can be a string or tuple
            of strings. Default: ('train', 'valid', 'test')
        language_pair: tuple or list containing src and tgt language
        valid_set: a string to identify validation set.
        test_set: a string to identify test set.

    Examples:
        >>> from torchtext.datasets import IWSLT2016
        >>> train_iter, valid_iter, test_iter = IWSLT2016()
        >>> src_sentence, tgt_sentence = next(train_iter)

    Raises:
        ValueError: If the language pair, valid set, or test set is
            not supported.
        FileNotFoundError: If expected data files are missing after
            extraction.
    """
    # Maps a split name to the set identifier used for NUM_LINES lookup.
    num_lines_set_identifier = {
        'train': 'train',
        'valid': valid_set,
        'test': test_set
    }

    # --- Argument validation -------------------------------------------
    if not isinstance(language_pair, list) and not isinstance(language_pair, tuple):
        raise ValueError("language_pair must be list or tuple but got {} instead".format(type(language_pair)))

    assert (len(language_pair) == 2), 'language_pair must contain only 2 elements: src and tgt language respectively'

    src_language, tgt_language = language_pair[0], language_pair[1]

    if src_language not in SUPPORTED_DATASETS['language_pair']:
        raise ValueError("src_language '{}' is not valid. Supported source languages are {}".
                         format(src_language, list(SUPPORTED_DATASETS['language_pair'])))

    if tgt_language not in SUPPORTED_DATASETS['language_pair'][src_language]:
        raise ValueError("tgt_language '{}' is not valid for give src_language '{}'. Supported target language are {}".
                         format(tgt_language, src_language,
                                SUPPORTED_DATASETS['language_pair'][src_language]))

    if valid_set not in SUPPORTED_DATASETS['valid_test'] or valid_set in SET_NOT_EXISTS[language_pair]:
        raise ValueError("valid_set '{}' is not valid for given language pair {}. Supported validation sets are {}".
                         format(valid_set, language_pair,
                                [s for s in SUPPORTED_DATASETS['valid_test'] if s not in SET_NOT_EXISTS[language_pair]]))

    if test_set not in SUPPORTED_DATASETS['valid_test'] or test_set in SET_NOT_EXISTS[language_pair]:
        # BUGFIX: message previously formatted valid_set here, reporting
        # the wrong identifier when the *test* set was invalid.
        raise ValueError("test_set '{}' is not valid for give language pair {}. Supported test sets are {}".
                         format(test_set, language_pair,
                                [s for s in SUPPORTED_DATASETS['valid_test'] if s not in SET_NOT_EXISTS[language_pair]]))

    # --- Expected extracted filenames per split ------------------------
    train_filenames = ('train.{}-{}.{}'.format(src_language, tgt_language, src_language),
                       'train.{}-{}.{}'.format(src_language, tgt_language, tgt_language))
    valid_filenames = ('IWSLT{}.TED.{}.{}-{}.{}'.format(SUPPORTED_DATASETS['year'], valid_set, src_language, tgt_language, src_language),
                       'IWSLT{}.TED.{}.{}-{}.{}'.format(SUPPORTED_DATASETS['year'], valid_set, src_language, tgt_language, tgt_language))
    test_filenames = ('IWSLT{}.TED.{}.{}-{}.{}'.format(SUPPORTED_DATASETS['year'], test_set, src_language, tgt_language, src_language),
                      'IWSLT{}.TED.{}.{}-{}.{}'.format(SUPPORTED_DATASETS['year'], test_set, src_language, tgt_language, tgt_language))

    src_train, tgt_train = train_filenames
    src_eval, tgt_eval = valid_filenames
    src_test, tgt_test = test_filenames

    # --- Download and extraction ---------------------------------------
    extracted_files = []  # list of paths to the extracted files
    dataset_tar = download_from_url(SUPPORTED_DATASETS['URL'], root=root,
                                    hash_value=SUPPORTED_DATASETS['MD5'],
                                    path=os.path.join(root, SUPPORTED_DATASETS['_PATH']),
                                    hash_type='md5')
    # IWSLT's URL downloads a multilingual tgz; extracting it materializes
    # the per-language-pair archives, from which we pick the one we need.
    extract_archive(dataset_tar)
    src_language = train_filenames[0].split(".")[-1]
    tgt_language = train_filenames[1].split(".")[-1]
    languages = "-".join([src_language, tgt_language])
    iwslt_tar = '{}/{}/texts/{}/{}/{}.tgz'.format(
        root, SUPPORTED_DATASETS['_PATH'].split(".")[0],
        src_language, tgt_language, languages)
    extracted_files.extend(extract_archive(iwslt_tar))

    # --- Clean the xml and tag files in the archive --------------------
    file_archives = []
    for fname in extracted_files:
        if 'xml' in fname:
            _clean_xml_file(fname)
            file_archives.append(os.path.splitext(fname)[0])
        elif "tags" in fname:
            _clean_tags_file(fname)
            file_archives.append(fname.replace('.tags', ''))
        else:
            file_archives.append(fname)

    data_filenames = {
        "train": _construct_filepaths(file_archives, src_train, tgt_train),
        "valid": _construct_filepaths(file_archives, src_eval, tgt_eval),
        "test": _construct_filepaths(file_archives, src_test, tgt_test)
    }
    for key in data_filenames.keys():
        if len(data_filenames[key]) == 0 or data_filenames[key] is None:
            raise FileNotFoundError(
                "Files are not found for data type {}".format(key))

    src_data_iter = _read_text_iterator(data_filenames[split][0])
    tgt_data_iter = _read_text_iterator(data_filenames[split][1])

    def _iter(src_data_iter, tgt_data_iter):
        # Pair source and target sentences line by line.
        for item in zip(src_data_iter, tgt_data_iter):
            yield item

    return _RawTextIterableDataset(
        DATASET_NAME,
        NUM_LINES[split][num_lines_set_identifier[split]][tuple(sorted(language_pair))],
        _iter(src_data_iter, tgt_data_iter))
def Multi30k(root, split, task='task1', language_pair=('de', 'en'),
             train_set="train", valid_set="val",
             test_set="test_2016_flickr"):
    """Multi30k Dataset

    The available datasets include following:

    **Language pairs (task1)**:

    +-----+-----+-----+-----+-----+
    |     |'en' |'cs' |'de' |'fr' |
    +-----+-----+-----+-----+-----+
    |'en' |     |  x  |  x  |  x  |
    +-----+-----+-----+-----+-----+
    |'cs' |  x  |     |  x  |  x  |
    +-----+-----+-----+-----+-----+
    |'de' |  x  |  x  |     |  x  |
    +-----+-----+-----+-----+-----+
    |'fr' |  x  |  x  |  x  |     |
    +-----+-----+-----+-----+-----+

    **Language pairs (task2)**:

    +-----+-----+-----+
    |     |'en' |'de' |
    +-----+-----+-----+
    |'en' |     |  x  |
    +-----+-----+-----+
    |'de' |  x  |     |
    +-----+-----+-----+

    For additional details refer to source:
    https://github.com/multi30k/dataset

    Args:
        root: Directory where the datasets are saved. Default: ".data"
        split: split or splits to be returned. Can be a string or tuple
            of strings. Default: ('train', 'valid', 'test')
        task: Indicate the task
        language_pair: tuple or list containing src and tgt language
        train_set: A string to identify train set.
        valid_set: A string to identify validation set.
        test_set: A string to identify test set.

    Examples:
        >>> from torchtext.experimental.datasets.raw import Multi30k
        >>> train_iter, valid_iter, test_iter = Multi30k()
        >>> src_sentence, tgt_sentence = next(train_iter)
    """
    # --- Argument validation: task, languages, set identifiers ---------
    if task not in SUPPORTED_DATASETS.keys():
        raise ValueError(
            'task {} is not supported. Valid options are {}'.format(
                task, SUPPORTED_DATASETS.keys()))

    assert (
        len(language_pair) == 2
    ), 'language_pair must contain only 2 elements: src and tgt language respectively'

    if language_pair[0] not in SUPPORTED_DATASETS[task].keys():
        raise ValueError(
            "Source language '{}' is not supported. Valid options for task '{}' are {}"
            .format(language_pair[0], task,
                    list(SUPPORTED_DATASETS[task].keys())))

    if language_pair[1] not in SUPPORTED_DATASETS[task].keys():
        raise ValueError(
            "Target language '{}' is not supported. Valid options for task '{}' are {}"
            .format(language_pair[1], task,
                    list(SUPPORTED_DATASETS[task].keys())))

    # Each set identifier must both exist for the source language and
    # contain the expected substring ('train'/'val'/'test').
    if train_set not in SUPPORTED_DATASETS[task][
            language_pair[0]].keys() or 'train' not in train_set:
        raise ValueError(
            "'{}' is not a valid train set identifier. valid options for task '{}' and language pair {} are {}"
            .format(train_set, task, language_pair, [
                k for k in SUPPORTED_DATASETS[task][language_pair[0]].keys()
                if 'train' in k
            ]))

    if valid_set not in SUPPORTED_DATASETS[task][
            language_pair[0]].keys() or 'val' not in valid_set:
        raise ValueError(
            "'{}' is not a valid valid set identifier. valid options for task '{}' and language pair {} are {}"
            .format(valid_set, task, language_pair, [
                k for k in SUPPORTED_DATASETS[task][language_pair[0]].keys()
                if 'val' in k
            ]))

    if test_set not in SUPPORTED_DATASETS[task][
            language_pair[0]].keys() or 'test' not in test_set:
        raise ValueError(
            "'{}' is not a valid test set identifier. valid options for task '{}' and language pair {} are {}"
            .format(test_set, task, language_pair, [
                k for k in SUPPORTED_DATASETS[task][language_pair[0]].keys()
                if 'test' in k
            ]))

    # --- Expected per-language filenames for each split ----------------
    train_filenames = [
        "{}.{}".format(train_set, language_pair[0]),
        "{}.{}".format(train_set, language_pair[1])
    ]
    valid_filenames = [
        "{}.{}".format(valid_set, language_pair[0]),
        "{}.{}".format(valid_set, language_pair[1])
    ]
    test_filenames = [
        "{}.{}".format(test_set, language_pair[0]),
        "{}.{}".format(test_set, language_pair[1])
    ]

    if split == 'train':
        src_file, tgt_file = train_filenames
    elif split == 'valid':
        src_file, tgt_file = valid_filenames
    else:
        src_file, tgt_file = test_filenames

    extracted_files = []  # list of paths to the extracted files

    # Select only the URLs/md5s whose archive contains our two files.
    current_url = []
    current_md5 = []
    current_filenames = [src_file, tgt_file]
    for url, md5 in zip(URL[split], MD5[split]):
        if any(f in url for f in current_filenames):
            current_url.append(url)
            current_md5.append(md5)

    # Download and extract each matching archive.
    for url, md5 in zip(current_url, current_md5):
        dataset_tar = download_from_url(url,
                                        path=os.path.join(
                                            root,
                                            os.path.basename(url)),
                                        root=root,
                                        hash_value=md5,
                                        hash_type='md5')
        extracted_files.extend(extract_archive(dataset_tar))

    file_archives = extracted_files

    data_filenames = {
        split: _construct_filepaths(file_archives, src_file, tgt_file),
    }
    for key in data_filenames:
        if len(data_filenames[key]) == 0 or data_filenames[key] is None:
            raise FileNotFoundError(
                "Files are not found for data type {}".format(key))

    assert data_filenames[split][
        0] is not None, "Internal Error: File not found for reading"
    assert data_filenames[split][
        1] is not None, "Internal Error: File not found for reading"

    src_data_iter = _read_text_iterator(data_filenames[split][0])
    tgt_data_iter = _read_text_iterator(data_filenames[split][1])

    def _iter(src_data_iter, tgt_data_iter):
        # Pair source and target sentences line by line.
        for item in zip(src_data_iter, tgt_data_iter):
            yield item

    # Maps split name back to the chosen set identifier for NUM_LINES.
    set_identifier = {
        'train': train_set,
        'valid': valid_set,
        'test': test_set,
    }

    return _RawTextIterableDataset(
        "Multi30k", SUPPORTED_DATASETS[task][language_pair[0]][
            set_identifier[split]]['NUM_LINES'],
        _iter(src_data_iter, tgt_data_iter))
def WMT14(root, split, language_pair=('de', 'en'),
          train_set='train.tok.clean.bpe.32000',
          valid_set='newstest2013.tok.bpe.32000',
          test_set='newstest2014.tok.bpe.32000'):
    """WMT14 Dataset

    The available datasets include following:

    **Language pairs**:

    +-----+-----+-----+
    |     |'en' |'de' |
    +-----+-----+-----+
    |'en' |     |  x  |
    +-----+-----+-----+
    |'de' |  x  |     |
    +-----+-----+-----+

    Args:
        root: Directory where the datasets are saved. Default: ".data"
        split: split or splits to be returned. Can be a string or tuple
            of strings. Default: ('train', 'valid', 'test')
        language_pair: tuple or list containing src and tgt language
        train_set: A string to identify train set.
        valid_set: A string to identify validation set.
        test_set: A string to identify test set.

    Examples:
        >>> from torchtext.datasets import WMT14
        >>> train_iter, valid_iter, test_iter = WMT14()
        >>> src_sentence, tgt_sentence = next(train_iter)

    Raises:
        ValueError: If the language pair or a set identifier is invalid.
        FileNotFoundError: If expected data files are missing after
            extraction.
    """
    supported_language = ['en', 'de']
    # Valid/test files are both named "newstest..."; selecting on 'test'
    # is intentional for both.
    supported_train_set = [s for s in NUM_LINES if 'train' in s]
    supported_valid_set = [s for s in NUM_LINES if 'test' in s]
    supported_test_set = [s for s in NUM_LINES if 'test' in s]

    # --- Argument validation -------------------------------------------
    assert (
        len(language_pair) == 2
    ), 'language_pair must contain only 2 elements: src and tgt language respectively'
    if language_pair[0] not in supported_language:
        raise ValueError(
            "Source language '{}' is not supported. Valid options are {}".
            format(language_pair[0], supported_language))
    if language_pair[1] not in supported_language:
        raise ValueError(
            "Target language '{}' is not supported. Valid options are {}".
            format(language_pair[1], supported_language))
    if train_set not in supported_train_set:
        raise ValueError(
            "'{}' is not a valid train set identifier. valid options are {}".
            format(train_set, supported_train_set))
    if valid_set not in supported_valid_set:
        raise ValueError(
            "'{}' is not a valid valid set identifier. valid options are {}".
            format(valid_set, supported_valid_set))
    if test_set not in supported_test_set:
        # BUGFIX: message previously said "valid set identifier" (copied
        # from the valid_set branch) when the *test* set was invalid.
        raise ValueError(
            "'{}' is not a valid test set identifier. valid options are {}".
            format(test_set, supported_test_set))

    # --- Expected per-language filenames for each split ----------------
    train_filenames = '{}.{}'.format(train_set, language_pair[0]), '{}.{}'.format(
        train_set, language_pair[1])
    valid_filenames = '{}.{}'.format(valid_set, language_pair[0]), '{}.{}'.format(
        valid_set, language_pair[1])
    test_filenames = '{}.{}'.format(test_set, language_pair[0]), '{}.{}'.format(
        test_set, language_pair[1])

    if split == 'train':
        src_file, tgt_file = train_filenames
    elif split == 'valid':
        src_file, tgt_file = valid_filenames
    else:
        src_file, tgt_file = test_filenames

    # --- Download and extraction ---------------------------------------
    dataset_tar = download_from_url(URL, root=root, hash_value=MD5,
                                    path=os.path.join(root, _PATH),
                                    hash_type='md5')
    extracted_files = extract_archive(dataset_tar)

    data_filenames = {
        split: _construct_filepaths(extracted_files, src_file, tgt_file),
    }
    for key in data_filenames:
        if len(data_filenames[key]) == 0 or data_filenames[key] is None:
            raise FileNotFoundError(
                "Files are not found for data type {}".format(key))

    assert data_filenames[split][
        0] is not None, "Internal Error: File not found for reading"
    assert data_filenames[split][
        1] is not None, "Internal Error: File not found for reading"

    src_data_iter = _read_text_iterator(data_filenames[split][0])
    tgt_data_iter = _read_text_iterator(data_filenames[split][1])

    def _iter(src_data_iter, tgt_data_iter):
        # Pair source and target sentences line by line.
        for item in zip(src_data_iter, tgt_data_iter):
            yield item

    return _RawTextIterableDataset(
        DATASET_NAME, NUM_LINES[os.path.splitext(src_file)[0]],
        _iter(src_data_iter, tgt_data_iter))