コード例 #1
0
def multiprocess_map(dset, num_processes, function, **kwargs):
    """Run ``shard_and_map`` over ``num_processes`` shards in a process pool.

    Each worker is handed one shard index in ``0 .. num_processes - 1`` as its
    positional argument. Its keyword arguments are fixed here: the backing file
    of ``dset``, the total shard count, ``function``, and any extra ``kwargs``.
    The per-shard results are concatenated with ``nlp.concatenate_datasets``.

    NOTE(review): only ``dset._data_files[0]`` is passed on. If the dataset is
    backed by several files, the others are not seen by the workers. It is
    presumably ``shard_and_map`` that reloads the file and selects the shard;
    confirm against its definition.
    """
    with multiprocessing.Pool(processes=num_processes) as pool:
        # Bind everything except the shard index, which pool.map supplies.
        worker = partial(
            shard_and_map,
            filename=dset._data_files[0]["filename"],
            num_shards=num_processes,
            function=function,
            **kwargs,
        )
        shard_indices = range(num_processes)
        shards = pool.map(worker, shard_indices)
    # Leaving the with-block terminates the pool; map() has already returned all results.
    return nlp.concatenate_datasets(shards)
コード例 #2
0
ファイル: test_arrow_dataset.py プロジェクト: wjj962464/nlp-1
    def test_concatenate(self):
        """Concatenating three datasets keeps every row and merges their info.

        ``dset1`` and ``dset2`` are built with a ``DatasetInfo`` carrying a
        description. ``dset3`` is built without one.
        """
        data1, data2, data3 = {"id": [0, 1, 2]}, {"id": [3, 4, 5]}, {"id": [6, 7]}
        info1 = DatasetInfo(description="Dataset1")
        info2 = DatasetInfo(description="Dataset2")
        dset1, dset2, dset3 = (
            Dataset.from_dict(data1, info=info1),
            Dataset.from_dict(data2, info=info2),
            Dataset.from_dict(data3),
        )

        dset_concat = concatenate_datasets([dset1, dset2, dset3])
        # Use assertEqual: the assertEquals alias was deprecated and removed in Python 3.12.
        self.assertEqual(len(dset_concat), len(dset1) + len(dset2) + len(dset3))
        # NOTE(review): the expected string presumably reflects descriptions
        # joined with "\n\n", with dset3 contributing an empty one; this is
        # tied to the library version under test.
        self.assertEqual(dset_concat.info.description, "Dataset1\n\nDataset2\n\n")
コード例 #3
0
    def test_concatenate(self):
        """Concatenating three datasets keeps every row and merges their info.

        ``dset1`` and ``dset2`` get a ``DatasetInfo`` with a description.
        ``dset3`` has its info explicitly cleared to ``None``, so the expected
        description contains only the first two.
        """
        data1 = {"id": [0, 1, 2]}
        data2 = {"id": [3, 4, 5]}
        data3 = {"id": [6, 7]}
        dset1 = Dataset.from_dict(data1)
        dset2 = Dataset.from_dict(data2)
        dset3 = Dataset.from_dict(data3)
        # Set the private info attribute directly so one input has no DatasetInfo at all.
        dset1._info = DatasetInfo(description="Dataset1")
        dset2._info = DatasetInfo(description="Dataset2")
        dset3._info = None

        dset_concat = concatenate_datasets([dset1, dset2, dset3])
        # Use assertEqual: the assertEquals alias was deprecated and removed in Python 3.12.
        self.assertEqual(len(dset_concat), len(dset1) + len(dset2) + len(dset3))
        self.assertEqual(dset_concat.info.description, "Dataset1\n\nDataset2")
                                            return_token_type_ids=True)
    return encodings

"""### MNLI"""

# Tokenise MNLI batch-wise with convert_to_features, dropping the raw
# idx/premise/hypothesis columns from the result.
mnli_encoded_dataset = mnli.map(
    convert_to_features,
    batched=True,
    remove_columns=['idx', 'premise', 'hypothesis'],
)
# Return only the model-input columns and the label, in "torch" format.
mnli_encoded_dataset.set_format(
    "torch",
    columns=['attention_mask', 'input_ids', 'token_type_ids', 'label'],
)

# Sanity report: row count, column count, column names, one per line.
print(
    mnli_encoded_dataset.num_rows,
    mnli_encoded_dataset.num_columns,
    mnli_encoded_dataset.column_names,
    sep="\n",
)

"""### XNLI"""

# Load every XNLI split, then merge the 'test' and 'validation' splits
# (in that order) into a single dataset.
xnli = nlp.load_dataset(path='xnli')
xnli = nlp.concatenate_datasets([xnli[split] for split in ('test', 'validation')])

def preprocess_xnli(example):
    """Pair each premise translation with the hypothesis in the same language.

    ``example`` is treated as a batch: its ``'premise'``, ``'hypothesis'``
    and ``'label'`` entries are zipped together as parallel sequences
    (presumably this is meant for a batched ``map`` call; confirm).

    Within each row, ``prem`` is used as a mapping from language code to
    premise text. ``hyp`` is used as a record with parallel ``'language'``
    and ``'translation'`` lists. For every premise language that also has a
    hypothesis translation, the function appends three values:

    - the premise text, to ``premise_output``;
    - the hypothesis text, to ``hypothesis_output``;
    - the row's label, to ``label_output``.

    NOTE(review): as visible here the body ends without a ``return``, so the
    function returns ``None`` and the three lists are discarded. The snippet
    looks truncated. Confirm that the real version returns the collected
    lists, e.g. as ``{"premise": ..., "hypothesis": ..., "label": ...}``.
    """
    premise_output = []
    hypothesis_output = []
    label_output = []
    for prem, hyp, lab in zip(example['premise'],  example['hypothesis'], example["label"]):
        label = lab
        langs = hyp['language']
        translations = hyp['translation']
        # Index the hypothesis translations by language for O(1) membership tests below.
        hypothesis = {k: v for k, v in zip(langs, translations)}
        # Iterating the premise mapping yields its keys (presumably language codes).
        for lang in prem:
            # Keep only languages present on both sides; others are skipped.
            if lang in hypothesis:
                premise_output += [prem[lang]]
                hypothesis_output += [hypothesis[lang]]
                label_output += [label]
コード例 #5
0
    dset = nlp.load_dataset("bookcorpus",
                            split="train",
                            cache_dir=args.cache_dir)
elif args.dataset == "wikibooks":
    dset_wikipedia = nlp.load_dataset("wikipedia",
                                      "20200501.en",
                                      split="train",
                                      cache_dir=args.cache_dir)
    dset_wikipedia.drop(columns=["title"])
    dset_wikipedia.features.pop("title")
    dset_books = nlp.load_dataset("bookcorpus",
                                  split="train",
                                  cache_dir=args.cache_dir)
    # Cast schemas, since one is nullable and one is not
    dset_wikipedia._data = dset_wikipedia.data.cast(dset_books._data.schema)
    dset = nlp.concatenate_datasets([dset_wikipedia, dset_books])
elif args.dataset == "c4":
    dset = nlp.load_dataset("c4", "en", cache_dir=args.cache_dir)
    assert False, "This dataset must be preprocessed beforehand"
else:
    assert False
print("Loaded dataset:", dset, dset[0])
# Validate with an explicit raise rather than `assert`. Asserts are stripped
# under `python -O`, which would let a dataset with extra/missing columns
# reach the filtering step unchecked.
if dset.column_names != ["text"]:
    raise ValueError(
        f"Dataset should have exactly one 'text' column, got {dset.column_names}"
    )

print("Filtering empty lines")
# Drop rows whose "text" is empty. The result is written to a cache file under
# args.cache_dir. load_from_cache_file presumably controls whether an existing
# cache file is reused.
dset = dset.filter(
    lambda ex: len(ex["text"]) > 0,
    cache_file_name=os.path.join(args.cache_dir, FILTER_CACHE),
    load_from_cache_file=load_from_cache_file,
)