Example #1
def _setup_datasets(dataset_name, separator, root, data_select):
    data_select = check_default_set(data_select,
                                    target_select=('train', 'valid', 'test'))
    extracted_files = []
    if isinstance(URLS[dataset_name], list):
        for f in URLS[dataset_name]:
            dataset_tar = download_from_url(f, root=root)
            extracted_files.extend(extract_archive(dataset_tar))
    elif isinstance(URLS[dataset_name], str):
        dataset_tar = download_from_url(URLS[dataset_name], root=root)
        extracted_files.extend(extract_archive(dataset_tar))
    else:
        raise ValueError(
            "URLS for {} has to be in a form of list or string".format(
                dataset_name))

    data_filenames = {
        "train": _construct_filepath(extracted_files, "train.txt"),
        "valid": _construct_filepath(extracted_files, "dev.txt"),
        "test": _construct_filepath(extracted_files, "test.txt")
    }
    return tuple(
        RawTextIterableDataset(
            dataset_name, NUM_LINES[dataset_name],
            _create_data_from_iob(data_filenames[item], separator)
        ) if data_filenames[item] is not None else None
        for item in data_select)
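
A note on the helper: _construct_filepath is not shown in this snippet. A minimal sketch of the behavior it is assumed to have (return the extracted path that ends with the given file name, or None when no such file exists):

def _construct_filepath(paths, file_suffix):
    # Assumed helper, not part of the original snippet: pick the extracted path
    # ending with file_suffix (e.g. "train.txt"); return None if it is missing.
    for path in paths:
        if path.endswith(file_suffix):
            return path
    return None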
Example #2
def _setup_datasets(dataset_name, separator, root=".data"):

    extracted_files = []
    if isinstance(URLS[dataset_name], list):
        for f in URLS[dataset_name]:
            dataset_tar = download_from_url(f, root=root)
            extracted_files.extend(extract_archive(dataset_tar))
    elif isinstance(URLS[dataset_name], str):
        dataset_tar = download_from_url(URLS[dataset_name], root=root)
        extracted_files.extend(extract_archive(dataset_tar))
    else:
        raise ValueError(
            "URLS for {} has to be in a form of list or string".format(
                dataset_name))

    data_filenames = {
        "train": _construct_filepath(extracted_files, "train.txt"),
        "valid": _construct_filepath(extracted_files, "dev.txt"),
        "test": _construct_filepath(extracted_files, "test.txt")
    }

    datasets = []
    for key in data_filenames.keys():
        if data_filenames[key] is not None:
            datasets.append(
                RawSequenceTaggingIterableDataset(
                    _create_data_from_iob(data_filenames[key], separator)))
        else:
            datasets.append(None)

    return datasets
Example #3
def _setup_datasets(dataset_name, root, data_select):
    data_select = check_default_set(data_select,
                                    target_select=('train', 'test'))
    if dataset_name == 'AG_NEWS':
        extracted_files = [
            download_from_url(URLS[dataset_name][item],
                              root=root,
                              hash_value=MD5['AG_NEWS'][item],
                              hash_type='md5') for item in ('train', 'test')
        ]
    else:
        dataset_tar = download_from_url(URLS[dataset_name],
                                        root=root,
                                        hash_value=MD5[dataset_name],
                                        hash_type='md5')
        extracted_files = extract_archive(dataset_tar)

    cvs_path = {}
    for fname in extracted_files:
        if fname.endswith('train.csv'):
            cvs_path['train'] = fname
        if fname.endswith('test.csv'):
            cvs_path['test'] = fname
    return tuple(
        RawTextIterableDataset(dataset_name, NUM_LINES[dataset_name][item],
                               _create_data_from_csv(cvs_path[item]))
        for item in data_select)
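
_create_data_from_csv is assumed here to yield (label, text) pairs from each CSV row; a plausible sketch, not necessarily the project's exact implementation:

import csv
import io

def _create_data_from_csv(data_path):
    # Hedged sketch of the iterator assumed above: yield (label, text) per row.
    with io.open(data_path, encoding="utf8") as f:
        for row in csv.reader(f):
            yield int(row[0]), ' '.join(row[1:])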
Example #4
def _setup_datasets(dataset_name,
                    tokenizer=get_tokenizer("basic_english"),
                    root='.data',
                    vocab=None,
                    removed_tokens=[],
                    data_select=('train', 'test', 'valid'),
                    min_freq=1):

    if isinstance(data_select, str):
        data_select = [data_select]
    if not set(data_select).issubset(set(('train', 'test', 'valid'))):
        raise TypeError('data_select is not supported!')

    if dataset_name == 'PennTreebank':
        extracted_files = []
        select_to_index = {'train': 0, 'test': 1, 'valid': 2}
        extracted_files = [
            download_from_url(URLS['PennTreebank'][select_to_index[key]],
                              root=root) for key in data_select
        ]
    else:
        dataset_tar = download_from_url(URLS[dataset_name], root=root)
        extracted_files = [
            os.path.join(root, d) for d in extract_archive(dataset_tar)
        ]

    _path = {}
    for item in data_select:
        _path[item] = _get_datafile_path(item, extracted_files)

    if vocab is None:
        if 'train' not in _path.keys():
            raise TypeError("Must pass a vocab if train is not selected.")
        logging.info('Building Vocab based on {}'.format(_path['train']))
        txt_iter = iter(
            tokenizer(row) for row in io.open(_path['train'], encoding="utf8"))
        vocab = build_vocab_from_iterator(txt_iter, min_freq=min_freq)
        logging.info('Vocab has {} entries'.format(len(vocab)))
    else:
        if not isinstance(vocab, Vocab):
            raise TypeError("Passed vocabulary is not of type Vocab")

    data = {}
    for item in _path.keys():
        data[item] = []
        logging.info('Creating {} data'.format(item))
        txt_iter = iter(
            tokenizer(row) for row in io.open(_path[item], encoding="utf8"))
        _iter = numericalize_tokens_from_iterator(vocab, txt_iter,
                                                  removed_tokens)
        for tokens in _iter:
            data[item] += [token_id for token_id in tokens]

    for key in data_select:
        if data[key] == []:
            raise TypeError('Dataset {} is empty!'.format(key))

    return tuple(
        LanguageModelingDataset(torch.tensor(data[d]).long(), vocab)
        for d in data_select)
Example #5
def _setup_datasets(dataset_name, root, split, offset):
    if dataset_name == 'AG_NEWS':
        extracted_files = [
            download_from_url(URLS[dataset_name][item],
                              root=root,
                              path=os.path.join(root,
                                                _PATHS[dataset_name][item]),
                              hash_value=MD5['AG_NEWS'][item],
                              hash_type='md5') for item in ('train', 'test')
        ]
    else:
        dataset_tar = download_from_url(URLS[dataset_name],
                                        root=root,
                                        path=os.path.join(
                                            root, _PATHS[dataset_name]),
                                        hash_value=MD5[dataset_name],
                                        hash_type='md5')
        extracted_files = extract_archive(dataset_tar)

    cvs_path = {}
    for fname in extracted_files:
        if fname.endswith('train.csv'):
            cvs_path['train'] = fname
        if fname.endswith('test.csv'):
            cvs_path['test'] = fname
    return [
        RawTextIterableDataset(dataset_name,
                               NUM_LINES[dataset_name][item],
                               _create_data_from_csv(cvs_path[item]),
                               offset=offset) for item in split
    ]
Example #6
def _setup_datasets(dataset_name, separator, root, split, offset):
    extracted_files = []
    if isinstance(URLS[dataset_name], dict):
        for name, item in URLS[dataset_name].items():
            dataset_tar = download_from_url(item,
                                            root=root,
                                            hash_value=MD5[dataset_name][name],
                                            hash_type='md5')
            extracted_files.extend(extract_archive(dataset_tar))
    elif isinstance(URLS[dataset_name], str):
        dataset_tar = download_from_url(URLS[dataset_name],
                                        root=root,
                                        hash_value=MD5[dataset_name],
                                        hash_type='md5')
        extracted_files.extend(extract_archive(dataset_tar))
    else:
        raise ValueError(
            "URLS for {} has to be in a form of dictionary or string".format(
                dataset_name))

    data_filenames = {
        "train": _construct_filepath(extracted_files, "train.txt"),
        "valid": _construct_filepath(extracted_files, "dev.txt"),
        "test": _construct_filepath(extracted_files, "test.txt")
    }
    return [
        RawTextIterableDataset(
            dataset_name,
            NUM_LINES[dataset_name][item],
            _create_data_from_iob(data_filenames[item], separator),
            offset=offset) if data_filenames[item] is not None else None
        for item in split
    ]
Example #7
def _setup_datasets(dataset_name, root, split, year, language, offset):
    if dataset_name == 'PennTreebank':
        extracted_files = [download_from_url(URLS['PennTreebank'][key],
                                             root=root, hash_value=MD5['PennTreebank'][key],
                                             hash_type='md5') for key in split]
    else:
        dataset_tar = download_from_url(URLS[dataset_name], root=root, hash_value=MD5[dataset_name], hash_type='md5')
        extracted_files = extract_archive(dataset_tar)

    if dataset_name == 'WMTNewsCrawl':
        file_name = 'news.{}.{}.shuffled'.format(year, language)
        extracted_files = [f for f in extracted_files if file_name in f]

    path = {}
    for item in split:
        for fname in extracted_files:
            if item in fname:
                path[item] = fname

    datasets = []
    for item in split:
        logging.info('Creating {} data'.format(item))
        datasets.append(RawTextIterableDataset(dataset_name,
                                               NUM_LINES[dataset_name][item], iter(io.open(path[item], encoding="utf8")), offset=offset))

    return datasets
Example #8
    def _run_pipeline_step(
        self,
        inputs: PytorchTrainInputs,
        outputs: PytorchTrainOutputs,
    ):
        print("Inside run pipeline!!!! for training step")

        if inputs.source_code:
            print("Inside source code block!!!")
            print(inputs.source_code[0])
            download_from_url(inputs.source_code[0],
                              root=inputs.source_code_path[0])
            print("download successfull")

            entry_point = ["ls", "-R", "/pvc/input"]
            run_code = subprocess.run(entry_point, stdout=subprocess.PIPE)
            print("Checking downloaded file!!!")
            print(run_code.stdout)

        if inputs.container_entrypoint:
            print("Inside entry point container block")
            entry_point = inputs.container_entrypoint
            entry_point.append(json.dumps(inputs.input_data))
            entry_point.append(json.dumps(inputs.output_data))
            entry_point.append(json.dumps(inputs.input_parameters))
            run_code = subprocess.run(entry_point, stdout=subprocess.PIPE)
            print(run_code.stdout)
Example #9
def _setup_datasets(dataset_name,
                    train_filenames,
                    valid_filenames,
                    test_filenames,
                    root='.data'):
    if not isinstance(train_filenames, tuple) or not isinstance(valid_filenames, tuple) \
            or not isinstance(test_filenames, tuple):
        raise ValueError("All filenames must be tuples")

    src_train, tgt_train = train_filenames
    src_eval, tgt_eval = valid_filenames
    src_test, tgt_test = test_filenames

    extracted_files = []
    if isinstance(URLS[dataset_name], list):
        for f in URLS[dataset_name]:
            dataset_tar = download_from_url(f, root=root)
            extracted_files.extend(extract_archive(dataset_tar))
    elif isinstance(URLS[dataset_name], str):
        dataset_tar = download_from_url(URLS[dataset_name], root=root)
        extracted_files.extend(extract_archive(dataset_tar))
    else:
        raise ValueError(
            "URLS for {} has to be in a form of list or string".format(
                dataset_name))

    # Clean the xml and tag file in the archives
    file_archives = []
    for fname in extracted_files:
        if 'xml' in fname:
            _clean_xml_file(fname)
            file_archives.append(os.path.splitext(fname)[0])
        elif "tags" in fname:
            _clean_tags_file(fname)
            file_archives.append(fname.replace('.tags', ''))
        else:
            file_archives.append(fname)

    data_filenames = defaultdict(dict)
    data_filenames = {
        "train": _construct_filepaths(file_archives, src_train, tgt_train),
        "valid": _construct_filepaths(file_archives, src_eval, tgt_eval),
        "test": _construct_filepaths(file_archives, src_test, tgt_test)
    }

    for key in data_filenames.keys():
        if data_filenames[key] is None or len(data_filenames[key]) == 0:
            raise FileNotFoundError(
                "Files are not found for data type {}".format(key))

    datasets = []
    for key in data_filenames.keys():
        src_data_iter = _read_text_iterator(data_filenames[key][0])
        tgt_data_iter = _read_text_iterator(data_filenames[key][1])

        datasets.append(
            RawTranslationIterableDataset(src_data_iter, tgt_data_iter))

    return tuple(datasets)
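
The _construct_filepaths helper is not shown; it is assumed to pair up the source and target files for one split. A minimal sketch under that assumption:

def _construct_filepaths(paths, src_filename, tgt_filename):
    # Assumed helper, not part of the original snippet: match the extracted
    # files against the requested source/target names for one split.
    src_path, tgt_path = None, None
    for p in paths:
        if p.endswith(src_filename):
            src_path = p
        elif p.endswith(tgt_filename):
            tgt_path = p
    return src_path, tgt_path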
Example #10
    def download(self):
        import os

        if not os.path.isfile(self.sample_file):
            if not os.path.isfile('genwiki.zip'):
                from torchtext.utils import download_from_url
                print('[Info] No existing data detected. Start downloading...')
                download_from_url(self.url_data, root='.')
            os.system('unzip genwiki.zip')
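
The method above shells out to unzip. A hedged alternative sketch that stays within torchtext.utils (assuming the same genwiki.zip layout; a variant, not the author's code):

    def download(self):
        # Alternative sketch: use torchtext's extract_archive instead of `unzip`.
        import os
        from torchtext.utils import download_from_url, extract_archive

        if not os.path.isfile(self.sample_file):
            if not os.path.isfile('genwiki.zip'):
                print('[Info] No existing data detected. Start downloading...')
                download_from_url(self.url_data, root='.')
            extract_archive('genwiki.zip', to_path='.')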
Example #11
def prepare_data(train_path,
                 val_path,
                 test_path,
                 dh_path,
                 load_from_dump=True,
                 bs=16):
    if not load_from_dump:
        url_base = 'https://raw.githubusercontent.com/multi30k/dataset/master/data/task1/raw/'
        train_urls = ('train.de.gz', 'train.en.gz')
        val_urls = ('val.de.gz', 'val.en.gz')
        test_urls = ('test_2016_flickr.de.gz', 'test_2016_flickr.en.gz')

        train_filepaths = [
            extract_archive(download_from_url(url_base + url))[0]
            for url in train_urls
        ]
        val_filepaths = [
            extract_archive(download_from_url(url_base + url))[0]
            for url in val_urls
        ]
        test_filepaths = [
            extract_archive(download_from_url(url_base + url))[0]
            for url in test_urls
        ]

        m_dh = DataHandler(train_filepaths)

        train_data = m_dh.data_process(train_filepaths)
        val_data = m_dh.data_process(val_filepaths)
        test_data = m_dh.data_process(test_filepaths)

        dump_data(train_data, train_path)
        dump_data(val_data, val_path)
        dump_data(test_data, test_path)
        dump_data(m_dh, dh_path)
    else:
        train_data = load_dump(train_path)
        val_data = load_dump(val_path)
        test_data = load_dump(test_path)
        m_dh = load_dump(dh_path)

    train_loader = DataLoader(train_data,
                              batch_size=bs,
                              shuffle=True,
                              collate_fn=generate_batch)
    valid_loader = DataLoader(val_data,
                              batch_size=bs,
                              shuffle=False,
                              collate_fn=generate_batch)
    test_loader = DataLoader(test_data,
                             batch_size=bs,
                             shuffle=False,
                             collate_fn=generate_batch)
    return train_loader, valid_loader, test_loader, m_dh
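
dump_data and load_dump are not shown above; a minimal pickle-based sketch consistent with how they are called:

import pickle

def dump_data(obj, path):
    # Assumed helper: serialize the processed data (or the DataHandler) to disk.
    with open(path, 'wb') as f:
        pickle.dump(obj, f)

def load_dump(path):
    # Assumed helper: load a previously dumped object back from disk.
    with open(path, 'rb') as f:
        return pickle.load(f)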
Example #12
def _setup_datasets(dataset_name, root, data_select, year, language):
    data_select = check_default_set(data_select, ('train', 'test', 'valid'))
    if isinstance(data_select, str):
        data_select = [data_select]
    if not set(data_select).issubset(set(('train', 'test', 'valid'))):
        raise TypeError('data_select is not supported!')

    if dataset_name == 'PennTreebank':
        extracted_files = []
        select_to_index = {'train': 0, 'test': 1, 'valid': 2}
        extracted_files = [
            download_from_url(URLS['PennTreebank'][select_to_index[key]],
                              root=root,
                              hash_value=MD5['PennTreebank'][key],
                              hash_type='md5') for key in data_select
        ]
    elif dataset_name == 'WMTNewsCrawl':
        if not (data_select == ['train']
                or set(data_select).issubset(set(('train', )))):
            raise ValueError("WMTNewsCrawl only creates a training dataset. "
                             "data_select should be 'train' "
                             "or ('train',), got {}.".format(data_select))
        dataset_tar = download_from_url(URLS[dataset_name],
                                        root=root,
                                        hash_value=MD5['WMTNewsCrawl'],
                                        hash_type='md5')
        extracted_files = extract_archive(dataset_tar)
        file_name = 'news.{}.{}.shuffled'.format(year, language)
        extracted_files = [f for f in extracted_files if file_name in f]
    else:
        dataset_tar = download_from_url(URLS[dataset_name],
                                        root=root,
                                        hash_value=MD5[dataset_name],
                                        hash_type='md5')
        extracted_files = extract_archive(dataset_tar)

    _path = {}
    for item in data_select:
        for fname in extracted_files:
            if item in fname:
                _path[item] = fname

    data = {}
    for item in _path.keys():
        logging.info('Creating {} data'.format(item))
        data[item] = iter(io.open(_path[item], encoding="utf8"))

    return tuple(
        RawTextIterableDataset(dataset_name, NUM_LINES[dataset_name][item],
                               data[item]) for item in data_select)
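
A hedged usage sketch: the returned RawTextIterableDataset objects are plain iterables over raw text lines. The 'WikiText2' key is an assumption about URLS/MD5/NUM_LINES and is not taken from the snippet:

# Hedged usage sketch; 'WikiText2' as a dataset key is an assumption.
train, test, valid = _setup_datasets('WikiText2', '.data',
                                     ('train', 'test', 'valid'),
                                     year=None, language=None)
for line in train:
    print(line)  # each item is one raw line of text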
Example #13
    def test_builtin_pretrained_sentencepiece_processor(self):
        sp_model_path = download_from_url(PRETRAINED_SP_MODEL['text_unigram_25000'])
        spm_tokenizer = sentencepiece_tokenizer(sp_model_path)
        _path = os.path.join(self.project_root, '.data', 'text_unigram_25000.model')
        os.remove(_path)
        test_sample = 'the pretrained spm model names'
        ref_results = ['\u2581the', '\u2581pre', 'trained', '\u2581sp', 'm', '\u2581model', '\u2581names']
        self.assertEqual(spm_tokenizer(test_sample), ref_results)

        sp_model_path = download_from_url(PRETRAINED_SP_MODEL['text_bpe_25000'])
        spm_transform = sentencepiece_processor(sp_model_path)
        _path = os.path.join(self.project_root, '.data', 'text_bpe_25000.model')
        os.remove(_path)
        test_sample = 'the pretrained spm model names'
        ref_results = [13, 1465, 12824, 304, 24935, 5771, 3776]
        self.assertEqual(spm_transform(test_sample), ref_results)
Example #14
def SQuAD2(root, split):
    extracted_files = download_from_url(URL[split],
                                        root=root,
                                        hash_value=MD5[split],
                                        hash_type='md5')
    return _RawTextIterableDataset(DATASET_NAME, NUM_LINES[split],
                                   _create_data_from_json(extracted_files))
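
A hedged usage sketch for SQuAD2; the field layout of each yielded example depends on _create_data_from_json, which is not shown here:

# Hedged usage sketch: the returned object is an iterable of parsed examples.
train_dataset = SQuAD2(root='.data', split='train')
first_example = next(iter(train_dataset))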
Example #15
def _setup_datasets(dataset_name, root='.data', ngrams=1, vocab=None, include_unk=False):
    dataset_tar = download_from_url(URLS[dataset_name], root=root)
    extracted_files = extract_archive(dataset_tar)

    for fname in extracted_files:
        if fname.endswith('train.csv'):
            train_csv_path = fname
        if fname.endswith('test.csv'):
            test_csv_path = fname

    if vocab is None:
        logging.info('Building Vocab based on {}'.format(train_csv_path))
        vocab = build_vocab_from_iterator(_csv_iterator(train_csv_path, ngrams))
    else:
        if not isinstance(vocab, Vocab):
            raise TypeError("Passed vocabulary is not of type Vocab")
    logging.info('Vocab has {} entries'.format(len(vocab)))
    logging.info('Creating training data')
    train_data, train_labels = _create_data_from_iterator(
        vocab, _csv_iterator(train_csv_path, ngrams, yield_cls=True), include_unk)
    logging.info('Creating testing data')
    test_data, test_labels = _create_data_from_iterator(
        vocab, _csv_iterator(test_csv_path, ngrams, yield_cls=True), include_unk)
    if len(train_labels ^ test_labels) > 0:
        raise ValueError("Training and test labels don't match")
    return (TextClassificationDataset(vocab, train_data, train_labels),
            TextClassificationDataset(vocab, test_data, test_labels))
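
_csv_iterator is assumed to stream (optionally labelled) ngram token iterators from a CSV file. A plausible sketch, with the zero-based label convention marked as an assumption:

from torchtext.data.utils import get_tokenizer, ngrams_iterator
import csv
import io

def _csv_iterator(data_path, ngrams, yield_cls=False):
    # Hedged sketch of the iterator assumed above: tokenize each row's text,
    # expand it into ngrams, and optionally yield the class label as well.
    # The zero-based label handling is an assumption.
    tokenizer = get_tokenizer('basic_english')
    with io.open(data_path, encoding='utf8') as f:
        for row in csv.reader(f):
            tokens = ngrams_iterator(tokenizer(' '.join(row[1:])), ngrams)
            if yield_cls:
                yield int(row[0]) - 1, tokens
            else:
                yield tokens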
Example #16
def IMDB(root='.data', split=('train', 'test'), offset=0):
    """ Defines raw IMDB datasets.

    Create supervised learning dataset: IMDB

    Separately returns the raw training and test datasets

    Args:
        root: Directory where the datasets are saved. Default: ".data"
        split: a string or tuple for the returned datasets. Default: ('train', 'test')
            By default, both datasets (train, test) are generated. Users could also choose any one or two of them,
            for example ('train', 'test') or just a string 'train'.
        offset: the line number at which to start reading. Default: 0

    Examples:
        >>> train, test = torchtext.experimental.datasets.raw.IMDB()
    """
    split_ = check_default_set(split, ('train', 'test'), 'IMDB')
    dataset_tar = download_from_url(URLS['IMDB'],
                                    root=root,
                                    hash_value=MD5['IMDB'],
                                    hash_type='md5')
    extracted_files = extract_archive(dataset_tar)
    return wrap_datasets(
        tuple(
            RawTextIterableDataset("IMDB",
                                   NUM_LINES["IMDB"][item],
                                   generate_imdb_data(item, extracted_files),
                                   offset=offset) for item in split_), split)
Example #17
def FastText(language="en", unk_tensor=None, root=".data", validate_file=True, num_cpus=32):
    r"""Create a FastText Vectors object.

    Args:
        language (str): the language to use for FastText. The list of supported language options
                        can be found at https://fasttext.cc/docs/en/language-identification.html
        unk_tensor (Tensor): a 1d tensor representing the vector associated with an unknown token
        root (str): folder used to store downloaded files in. Default: '.data'.
        validate_file (bool): flag to determine whether to validate the downloaded files checksum.
                              Should be `False` when running tests with a local asset.
        num_cpus (int): the number of cpus to use when loading the vectors from file. Default: 32.

    Returns:
        torchtext.experimental.vectors.Vector: a Vectors object.

    Raises:
        ValueError: if duplicate tokens are found in FastText file.

    """
    url = "https://dl.fbaipublicfiles.com/fasttext/vectors-wiki/wiki.{}.vec".format(language)

    checksum = None
    if validate_file:
        checksum = CHECKSUMS_FAST_TEXT.get(url, None)

    downloaded_file_path = download_from_url(url, root=root, hash_value=checksum)
    cpp_vectors_obj, dup_tokens = _load_token_and_vectors_from_file(downloaded_file_path, ' ', num_cpus, unk_tensor)

    if dup_tokens:
        raise ValueError("Found duplicate tokens in file: {}".format(str(dup_tokens)))

    vectors_obj = Vectors(cpp_vectors_obj)
    return vectors_obj
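
A hedged usage sketch; indexing the returned Vectors object by token string is assumed to be supported by the experimental Vectors interface:

# Hedged usage sketch (the indexing interface is an assumption).
vectors = FastText(language="en")
hello_vec = vectors["hello"]  # 1-d tensor for the token, or unk_tensor if unknown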
Example #18
File: test_utils.py Project: zivlir/text
    def test_download_extract_zip(self):
        # create root directory for downloading data
        root = '.data'
        if not os.path.exists(root):
            os.makedirs(root)

        # ensure archive is not already downloaded, if it is then delete
        url = 'https://bitbucket.org/sivareddyg/public/downloads/en-ud-v2.zip'
        target_archive_path = os.path.join(root, 'en-ud-v2.zip')
        conditional_remove(target_archive_path)

        # download archive and ensure is in correct location
        archive_path = utils.download_from_url(url)
        assert target_archive_path == archive_path

        # extract files and ensure they are correct
        files = utils.extract_archive(archive_path)
        assert files == [
            'en-ud-v2/', 'en-ud-v2/en-ud-tag.v2.dev.txt',
            'en-ud-v2/en-ud-tag.v2.test.txt',
            'en-ud-v2/en-ud-tag.v2.train.txt', 'en-ud-v2/LICENSE.txt',
            'en-ud-v2/README.txt'
        ]

        # remove files and archive
        for f in files:
            conditional_remove(os.path.join(root, f))
        os.rmdir(os.path.join(root, 'en-ud-v2'))
        conditional_remove(archive_path)
Example #19
def _setup_qa_datasets(dataset_name,
                       tokenizer=get_tokenizer("basic_english"),
                       root='.data',
                       vocab=None,
                       removed_tokens=[],
                       data_select=('train', 'dev')):

    if isinstance(data_select, str):
        data_select = [data_select]
    if not set(data_select).issubset(set(('train', 'dev'))):
        raise TypeError('data_select is not supported!')

    extracted_files = []
    select_to_index = {'train': 0, 'dev': 1}
    extracted_files = [
        download_from_url(URLS[dataset_name][select_to_index[key]], root=root)
        for key in data_select
    ]

    squad_data = {}
    for item in data_select:
        with open(extracted_files[select_to_index[item]]) as json_file:
            raw_data = json.load(json_file)['data']
            squad_data[item] = process_raw_json_data(raw_data)

    if vocab is None:
        if 'train' not in squad_data.keys():
            raise TypeError("Must pass a vocab if train is not selected.")
        logging.info('Building Vocab based on train data')
        vocab = build_vocab_from_iterator(
            squad_iterator(squad_data['train'], tokenizer))
    else:
        if not isinstance(vocab, Vocab):
            raise TypeError("Passed vocabulary is not of type Vocab")
    logging.info('Vocab has {} entries'.format(len(vocab)))

    data = {}
    for item in data_select:
        data_iter = create_data_from_iterator(vocab, squad_data[item],
                                              tokenizer)
        tensor_data = []
        for context, question, _ans in data_iter:
            iter_data = {
                'context':
                torch.tensor([token_id for token_id in context]).long(),
                'question':
                torch.tensor([token_id for token_id in question]).long(),
                'answers': [],
                'ans_pos': []
            }
            for (_answer, ans_start_id, ans_end_id) in _ans:
                iter_data['answers'].append(
                    torch.tensor([token_id for token_id in _answer]).long())
                iter_data['ans_pos'].append(
                    torch.tensor([ans_start_id, ans_end_id]).long())
            tensor_data.append(iter_data)
        data[item] = tensor_data

    return tuple(
        QuestionAnswerDataset(data[item], vocab) for item in data_select)
Example #20
File: test_utils.py Project: zivlir/text
    def test_download_extract_to_path(self):
        # create root directory for downloading data
        root = '.data'
        if not os.path.exists(root):
            os.makedirs(root)

        # create directory to extract archive to
        to_path = '.new_data'
        if not os.path.exists(to_path):
            os.makedirs(to_path)

        # ensure archive is not already downloaded, if it is then delete
        url = 'http://www.quest.dcs.shef.ac.uk/wmt16_files_mmt/validation.tar.gz'
        target_archive_path = os.path.join(root, 'validation.tar.gz')
        conditional_remove(target_archive_path)

        # download archive and ensure is in correct location
        archive_path = utils.download_from_url(url)
        assert target_archive_path == archive_path

        # extract files and ensure they are in the to_path directory
        files = utils.extract_archive(archive_path, to_path)
        assert files == [
            os.path.join(to_path, 'val.de'),
            os.path.join(to_path, 'val.en')
        ]

        # remove files and archive
        for f in files:
            conditional_remove(f)
        conditional_remove(archive_path)
Example #21
def prepare_data(device="cpu", train_batch_size=20, eval_batch_size=20, data_dir=None):
    url = "https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-v1.zip"

    download_path = ".data_wikitext_2_v1"
    extract_path = None
    if data_dir:
        download_path = os.path.join(data_dir, "download")
        os.makedirs(download_path, exist_ok=True)
        download_path = os.path.join(download_path, "wikitext-2-v1.zip")

        extract_path = os.path.join(data_dir, "extracted")
        os.makedirs(extract_path, exist_ok=True)

    test_filepath, valid_filepath, train_filepath = extract_archive(
        download_from_url(url, root=download_path), to_path=extract_path
    )
    tokenizer = get_tokenizer("basic_english")
    vocab = build_vocab_from_iterator(map(tokenizer, iter(io.open(train_filepath, encoding="utf8"))))

    def data_process(raw_text_iter):
        data = [torch.tensor([vocab[token] for token in tokenizer(item)], dtype=torch.long) for item in raw_text_iter]
        return torch.cat(tuple(filter(lambda t: t.numel() > 0, data)))

    train_data = data_process(iter(io.open(train_filepath, encoding="utf8")))
    val_data = data_process(iter(io.open(valid_filepath, encoding="utf8")))
    test_data = data_process(iter(io.open(test_filepath, encoding="utf8")))

    device = torch.device(device)

    train_data = batchify(train_data, train_batch_size, device)
    val_data = batchify(val_data, eval_batch_size, device)
    test_data = batchify(test_data, eval_batch_size, device)

    return train_data, val_data, test_data
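
batchify is not defined in the snippet above; a minimal sketch following the standard PyTorch language-modeling tutorial layout:

import torch

def batchify(data, batch_size, device):
    # Hedged sketch: trim the flat token stream so it divides evenly into
    # batch_size columns, then reshape so each column is a contiguous sequence.
    n_batches = data.size(0) // batch_size
    data = data.narrow(0, 0, n_batches * batch_size)
    return data.view(batch_size, -1).t().contiguous().to(device)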
Example #22
def _generate_imdb_data_iterators(dataset_name, root, ngrams, tokenizer,
                                  data_select):
    if not tokenizer:
        tokenizer = get_tokenizer("basic_english")

    if isinstance(data_select, str):
        data_select = [data_select]
    if not set(data_select).issubset(set(('train', 'test'))):
        raise TypeError(
            'Given data selection {} is not supported!'.format(data_select))

    dataset_tar = download_from_url(URLS[dataset_name], root=root)
    extracted_files = extract_archive(dataset_tar)

    iters_group = {}
    if 'train' in data_select:
        iters_group['vocab'] = _imdb_iterator('train', extracted_files,
                                              tokenizer, ngrams)
    for item in data_select:
        iters_group[item] = _imdb_iterator(item,
                                           extracted_files,
                                           tokenizer,
                                           ngrams,
                                           yield_cls=True)
    return iters_group
Example #23
def _download_extract_validate(root,
                               url,
                               url_md5,
                               downloaded_file,
                               extracted_file,
                               extracted_file_md5,
                               hash_type="sha256"):
    root = os.path.abspath(root)
    downloaded_file = os.path.abspath(downloaded_file)
    extracted_file = os.path.abspath(extracted_file)
    if os.path.exists(extracted_file):
        with open(os.path.join(root, extracted_file), 'rb') as f:
            if validate_file(f, extracted_file_md5, hash_type):
                return extracted_file

    dataset_tar = download_from_url(url,
                                    path=os.path.join(root, downloaded_file),
                                    hash_value=url_md5,
                                    hash_type=hash_type)
    extracted_files = extract_archive(dataset_tar)
    assert os.path.exists(
        extracted_file
    ), "extracted_file [{}] was not found in the archive [{}]".format(
        extracted_file, extracted_files)

    return extracted_file
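
A hedged call sketch; the URL, file names, and checksums below are placeholders, not values from any real dataset:

# Hedged usage sketch with placeholder values.
train_path = _download_extract_validate(
    root='.data',
    url='https://example.com/dataset.tar.gz',
    url_md5='<md5 of the archive>',
    downloaded_file='dataset.tar.gz',
    extracted_file='.data/dataset/train.txt',
    extracted_file_md5='<md5 of train.txt>',
    hash_type='md5')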
Example #24
    def test_download_extract_gz(self):
        # create root directory for downloading data
        root = os.path.abspath('.data')
        if not os.path.exists(root):
            os.makedirs(root)

        # ensure archive is not already downloaded, if it is then delete
        url = 'https://raw.githubusercontent.com/multi30k/dataset/master/data/task2/raw/val.5.en.gz'
        target_archive_path = os.path.join(root, 'val.5.en.gz')
        conditional_remove(target_archive_path)

        # download archive and ensure is in correct location
        archive_path = utils.download_from_url(url)
        assert target_archive_path == archive_path

        # extract files and ensure they are correct
        files = utils.extract_archive(archive_path)
        assert files == [os.path.join(root, 'val.5.en')]

        # extract files with overwrite option True
        files = utils.extract_archive(archive_path, overwrite=True)
        assert files == [os.path.join(root, 'val.5.en')]

        # remove files and archive
        for f in files:
            conditional_remove(f)
        conditional_remove(archive_path)
Example #25
def setup_datasets(dataset_name,
                   root='.data',
                   vocab_size=20000,
                   include_unk=False):
    dataset_tar = download_from_url(URLS[dataset_name], root=root)
    extracted_files = extract_archive(dataset_tar)

    for fname in extracted_files:
        if fname.endswith('train.csv'):
            train_csv_path = fname
        if fname.endswith('test.csv'):
            test_csv_path = fname

    # generate sentencepiece  pretrained tokenizer
    if not path.exists('m_user.model'):
        logging.info('Generate SentencePiece pretrained tokenizer...')
        generate_sp_model(train_csv_path, vocab_size)

    sp_model = load_sp_model("m_user.model")
    sp_generator = sentencepiece_numericalizer(sp_model)
    train_data, train_labels = _create_data_with_sp_transform(
        sp_generator, train_csv_path)
    test_data, test_labels = _create_data_with_sp_transform(
        sp_generator, test_csv_path)

    if len(train_labels ^ test_labels) > 0:
        raise ValueError("Training and test labels don't match")
    return (text_classification.TextClassificationDataset(
        None, train_data, train_labels),
            text_classification.TextClassificationDataset(
                None, test_data, test_labels))
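
_create_data_with_sp_transform is assumed to numericalize each CSV row with the sentencepiece generator and to collect the label set; a plausible sketch:

import csv
import io
import torch

def _create_data_with_sp_transform(sp_generator, data_path):
    # Hedged sketch: numericalize each row's text with the sentencepiece
    # numericalizer and gather (label, token_ids) pairs plus the label set.
    data = []
    labels = set()
    with io.open(data_path, encoding='utf8') as f:
        for row in csv.reader(f):
            label = int(row[0]) - 1          # zero-based labels are an assumption
            token_ids = next(sp_generator([' '.join(row[1:])]))
            data.append((label, torch.tensor(token_ids)))
            labels.add(label)
    return data, labels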
Example #26
def download_extract_archive(url, raw_folder, dataset_name):
    """Download and extract the dataset if its CSV files are not already in raw_folder."""

    train_csv_path = os.path.join(raw_folder, dataset_name + '_csv',
                                  'train.csv')
    test_csv_path = os.path.join(raw_folder, dataset_name + '_csv', 'test.csv')
    if os.path.exists(train_csv_path) and os.path.exists(test_csv_path):
        return

    os.makedirs(raw_folder, exist_ok=True)
    filename = dataset_name + '_csv.tar.gz'
    url = url
    path = os.path.join(raw_folder, filename)
    download_from_url(url, path)
    extract_archive(path, raw_folder, remove_finished=True)

    logging.info('Dataset %s downloaded.' % dataset_name)
Example #27
def AG_NEWS(root, split):
    path = download_from_url(URL[split],
                             root=root,
                             path=os.path.join(root, split + ".csv"),
                             hash_value=MD5[split],
                             hash_type='md5')
    return _RawTextIterableDataset(DATASET_NAME, NUM_LINES[split],
                                   _create_data_from_csv(path))
Example #28
def WikiText103(root, split):
    dataset_tar = download_from_url(URL, root=root, hash_value=MD5, hash_type='md5')
    extracted_files = extract_archive(dataset_tar)

    path = _find_match(split, extracted_files)
    logging.info('Creating {} data'.format(split))
    return _RawTextIterableDataset('WikiText103',
                                   NUM_LINES[split], iter(io.open(path, encoding="utf8")))
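
_find_match is assumed to pick the extracted file whose name contains the requested split; a minimal sketch:

def _find_match(match, file_list):
    # Assumed helper, not shown in the snippet: return the first extracted path
    # whose name contains the requested split, e.g. 'train' -> 'wiki.train.tokens'.
    for fname in file_list:
        if match in fname:
            return fname
    return None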
Example #29
def PennTreebank(root, split):
    path = download_from_url(URL[split],
                             root=root,
                             hash_value=MD5[split],
                             hash_type='md5')
    logging.info('Creating {} data'.format(split))
    return RawTextIterableDataset('PennTreebank', NUM_LINES[split],
                                  iter(io.open(path, encoding="utf8")))
Example #30
def PennTreebank(root, split):
    path = download_from_url(URL[split],
                             root=root,
                             hash_value=MD5[split],
                             hash_type='md5')
    logging.info('Creating {} data'.format(split))
    return _RawTextIterableDataset(DATASET_NAME, NUM_LINES[split],
                                   _read_text_iterator(path))