def SogouNews(root, split):
    """Build an iterable raw-text dataset for one SogouNews split.

    Downloads the archive if necessary, extracts and md5-validates the
    split's CSV file, then streams rows from it lazily.

    Args:
        root: directory where the dataset files are stored.
        split: name of the split to load (key into the module tables).

    Returns:
        A ``_RawTextIterableDataset`` yielding parsed CSV rows.
    """
    archive_path = os.path.join(root, _PATH)
    extracted_path = os.path.join(root, _EXTRACTED_FILES[split])
    data_path = _download_extract_validate(
        root,
        URL,
        MD5,
        archive_path,
        extracted_path,
        _EXTRACTED_FILES_MD5[split],
        hash_type="md5",
    )
    logging.info('Creating {} data'.format(split))
    return _RawTextIterableDataset(
        DATASET_NAME, NUM_LINES[split], _create_data_from_csv(data_path)
    )
def Multi30k(root, split, language_pair=('de', 'en')):
    """Multi30k dataset

    Reference: http://www.statmt.org/wmt16/multimodal-task.html#task1

    Args:
        root: Directory where the datasets are saved. Default: ".data"
        split: split or splits to be returned. Can be a string or tuple of strings. Default: ('train', 'valid', 'test')
        language_pair: tuple or list containing src and tgt language. Available options are ('de','en') and ('en', 'de')

    Returns:
        A ``_RawTextIterableDataset`` yielding (src_line, tgt_line) pairs.

    Raises:
        ValueError: if ``language_pair`` is not a two-element ('de','en')
            or ('en','de') pair.
    """
    # Validate with explicit raises: ``assert`` statements are stripped
    # under ``python -O``, which would silently disable these checks.
    if len(language_pair) != 2:
        raise ValueError(
            'language_pair must contain only 2 elements: src and tgt language respectively')
    if tuple(sorted(language_pair)) != ('de', 'en'):
        raise ValueError(
            "language_pair must be either ('de','en') or ('en', 'de')")

    downloaded_file = os.path.basename(URL[split])

    def _language_file(language):
        # Download (cached after the first call) and extract/validate the
        # per-language text file for this split.
        return _download_extract_validate(
            root,
            URL[split],
            MD5[split],
            os.path.join(root, downloaded_file),
            os.path.join(
                root,
                _EXTRACTED_FILES_INFO[split]['file_prefix'] + '.' + language),
            _EXTRACTED_FILES_INFO[split]['md5'][language])

    src_path = _language_file(language_pair[0])
    trg_path = _language_file(language_pair[1])
    src_data_iter = _read_text_iterator(src_path)
    trg_data_iter = _read_text_iterator(trg_path)
    return _RawTextIterableDataset(DATASET_NAME, NUM_LINES[split],
                                   zip(src_data_iter, trg_data_iter))
def CoNLL2000Chunking(root, split):
    """Build an iterable dataset over the CoNLL-2000 chunking IOB data.

    Args:
        root: directory where the dataset files are stored.
        split: name of the split to load (key into the module tables).

    Returns:
        A ``_RawTextIterableDataset`` yielding IOB-parsed examples.
    """
    # A dataset-specific subfolder keeps the generic download filenames
    # from colliding with other datasets stored under the same root.
    root = os.path.join(root, 'conll2000chunking')
    archive_path = os.path.join(root, split + ".txt.gz")
    extracted_path = os.path.join(root, _EXTRACTED_FILES[split])
    data_filename = _download_extract_validate(
        root,
        URL[split],
        MD5[split],
        archive_path,
        extracted_path,
        _EXTRACTED_FILES_MD5[split],
        hash_type="md5",
    )
    logging.info('Creating {} data'.format(split))
    return _RawTextIterableDataset(
        "CoNLL2000Chunking",
        NUM_LINES[split],
        _create_data_from_iob(data_filename, " "),
    )
def AmazonReviewFull(root, split):
    """Build an iterable dataset over one AmazonReviewFull CSV split.

    Args:
        root: directory where the dataset files are stored.
        split: name of the split to load (key into the module tables).

    Returns:
        A ``_RawTextIterableDataset`` yielding (label, text) pairs where
        label is the integer class and text joins the remaining columns.
    """

    def _create_data_from_csv(data_path):
        # Lazily yield one (int label, space-joined text) pair per CSV row.
        with io.open(data_path, encoding="utf8") as f:
            for row in unicode_csv_reader(f):
                label, *text_columns = row
                yield int(label), ' '.join(text_columns)

    archive_path = os.path.join(root, _PATH)
    extracted_path = os.path.join(root, _EXTRACTED_FILES[split])
    data_path = _download_extract_validate(
        root,
        URL,
        MD5,
        archive_path,
        extracted_path,
        _EXTRACTED_FILES_MD5[split],
        hash_type="md5",
    )
    logging.info('Creating {} data'.format(split))
    return _RawTextIterableDataset(
        "AmazonReviewFull", NUM_LINES[split], _create_data_from_csv(data_path)
    )
def WMTNewsCrawl(root, split, year=2010, language='en'):
    """Build an iterable dataset over WMT News Crawl monolingual text.

    Args:
        root: directory where the dataset files are stored.
        split: name of the split (used for line counts and logging).
        year: crawl year; must be one of ``_AVAILABLE_YEARS``.
        language: language code; must be one of ``_AVAILABLE_LANGUAGES``.

    Returns:
        A ``_RawTextIterableDataset`` yielding one line of text at a time.

    Raises:
        ValueError: if ``year`` or ``language`` is not available.
    """
    # Guard clauses: reject unsupported year/language up front.
    if year not in _AVAILABLE_YEARS:
        message = "{} not available. Please choose from years {}".format(
            year, _AVAILABLE_YEARS)
        raise ValueError(message)
    if language not in _AVAILABLE_LANGUAGES:
        message = "{} not available. Please choose from languages {}".format(
            language, _AVAILABLE_LANGUAGES)
        raise ValueError(message)
    data_path = _download_extract_validate(
        root,
        URL,
        MD5,
        _PATH,
        _EXTRACTED_FILES[language],
        _EXTRACTED_FILES_MD5[language],
        hash_type="md5",
    )
    logging.info('Creating {} data'.format(split))
    return _RawTextIterableDataset(
        DATASET_NAME, NUM_LINES[split], _read_text_iterator(data_path)
    )