Exemple #1
0
#!/usr/bin/env python
# coding: utf-8

# This script prepares data for pegasus/gigaword eval: it dumps the tfds
# gigaword test split into two parallel, line-aligned text files
# (data/test.source and data/test.target), one record per line.

# 0. pip install pegasus
# 1. ./process.py

from pegasus.data import all_datasets
from pathlib import Path

input_pattern = "tfds:gigaword"
split = "test"
# shuffle_files=False keeps the record order deterministic so the
# source/target files stay aligned line-for-line across runs.
ds_test = all_datasets.get_dataset(input_pattern + "-" + split,
                                   shuffle_files=False)

save_path = Path("data")
save_path.mkdir(parents=True, exist_ok=True)
src_path = save_path / "test.source"
tgt_path = save_path / "test.target"
# Explicit utf-8 so the dump does not depend on the locale's default encoding.
with open(src_path, 'wt', encoding='utf-8') as src_file, \
        open(tgt_path, 'wt', encoding='utf-8') as tgt_file:
    for i, d in enumerate(ds_test):
        src = d["inputs"].numpy().decode()
        tgt = d["targets"].numpy().decode()

        # remove articles with no summary (and summaries with no article)
        if src and tgt:
            # Encode embedded newlines so each record occupies one line.
            src = src.replace('\n', '<n>')
            tgt = tgt.replace('\n', '<n>')
            src_file.write(src + '\n')
            # BUG FIX: the visible original wrote only the source side,
            # leaving test.target empty and the two files misaligned.
            tgt_file.write(tgt + '\n')
Exemple #2
0
 def test_multiple_tfds(self, input_pattern):
     """Load a multi-dataset tfds pattern and verify its output with task ids."""
     ds = all_datasets.get_dataset(input_pattern, False)
     self.check_output(ds, None, task_id=True)
Exemple #3
0
 def test_corpus_tfds(self, input_pattern):
     """Load each standard split of a tfds corpus dataset and verify output."""
     splits = ("train", "validation", "test")
     for split in splits:
         ds = all_datasets.get_dataset(input_pattern + "-" + split, False)
         self.check_output(ds, False)
Exemple #4
0
 def test_tfds_kwargs(self):
     """Load a tfds_transformed spec carrying shard/take kwargs and verify it."""
     pattern = "tfds_transformed:common_crawl-train-shard_100-take_50"
     ds = all_datasets.get_dataset(pattern, False)
     self.check_output(ds, False)
Exemple #5
0
 def test_supervised_files(self, input_pattern):
     """Load a supervised file-based dataset; its examples must be dicts."""
     ds = all_datasets.get_dataset(input_pattern, False)
     first_example = next(iter(ds))
     self.assertIsInstance(first_example, dict)
     self.check_output(ds, True)