Example #1
def main(args):
    text_encoder = TextEncoder(args.encoder_path, args.bpe_path)
    with jsonl.open(args.original_file, gzip=True) as test_file:
        data = test_file.read()

    with jsonl.open(args.out_file, gzip=True) as out_file:
        out_file.write(data[-args.n:])
Example #2
def main(args):
    text_encoder = TextEncoder(args.encoder_path, args.bpe_path)
    train_split, val_split, test_split = load_splits(args.splits_file)
    summaries = os.listdir(args.summary_dir)

    num_summaries = 0
    train_data, val_data, test_data = [], [], []
    for file_name in tqdm(summaries):
        summary_data = load_summary(os.path.join(args.summary_dir, file_name))
        if len(summary_data["summary"]) == 0 or len(summary_data["text"]) == 0:
            continue
        summary_data["summary"] = encode_line(summary_data["summary"],
                                              text_encoder)
        summary_data["text"] = encode_line(summary_data["text"], text_encoder)
        file_id = file_name.split(".")[0]
        if file_id in train_split:
            train_data.append(summary_data)
            num_summaries += 1
        elif file_id in val_split:
            val_data.append(summary_data)
            num_summaries += 1
        elif file_id in test_split:
            test_data.append(summary_data)
            num_summaries += 1

    with jsonl.open(args.train_file, gzip=True) as train_file:
        train_file.write(train_data)
    with jsonl.open(args.val_file, gzip=True) as val_file:
        val_file.write(val_data)
    with jsonl.open(args.test_file, gzip=True) as test_file:
        test_file.write(test_data)
    print("Number of successful conversions: {}".format(num_summaries))
Example #3
def run(dataset: str, summaries_file: str, model: typing.Callable, budget: int):
    model = SafeSummerizer(model, budget)
    with jsonl.open(dataset, gzip=True) as dataset_fh:
        with jsonl.open(summaries_file, gzip=True) as summaries_fh:
            summaries_fh.delete()  # clean out output file
            progressbar = tqdm()
            for document in dataset_fh:
                summary = model(document)
                summaries_fh.appendline({"system": summary})
                progressbar.update()
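# SafeSummerizer is defined elsewhere in this project and is not part of this
# listing. Purely as an illustrative sketch (treating `budget` as a word limit
# is an assumption), a wrapper with the same interface might look like:
import typing


class SafeSummerizer:

    def __init__(self, model: typing.Callable, budget: int):
        self.model = model
        self.budget = budget

    def __call__(self, document: dict) -> str:
        try:
            summary = self.model(document)
        except Exception:
            # Fall back to the start of the article if the model fails.
            summary = document.get("text", "")
        # Enforce the budget as a maximum number of words.
        return " ".join(summary.split()[:self.budget])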
Example #4
def filter_urls(urls_file: str, out: str, prop_name: str = "original"):
    """Simple filter heuristic for noisy articles based on a hURL heuristic."""
    with jsonl.open(urls_file, gzip=True) as urls_file_fh:
        with jsonl.open(out, gzip=True) as output_fh:
            for snapshot_entry in tqdm(urls_file_fh):
                url = snapshot_entry[prop_name]
                is_asset = _is_asset(url)
                is_hurl = _is_hurl(url)

                if not is_asset and is_hurl:
                    output_fh.appendline(snapshot_entry)
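# _is_asset and _is_hurl are helpers not included in this listing. A minimal sketch
# of what such URL heuristics could look like (the patterns below are assumptions,
# not the project's actual rules):
import re
from urllib.parse import urlparse

# Hypothetical asset check: file extensions that indicate images, styles, scripts, etc.
_ASSET_SUFFIXES = (".jpg", ".jpeg", ".png", ".gif", ".css", ".js", ".pdf", ".xml")


def _is_asset(url: str) -> bool:
    path = urlparse(url).path.lower()
    return path.endswith(_ASSET_SUFFIXES)


def _is_hurl(url: str) -> bool:
    # Hypothetical "hURL" heuristic: keep URLs whose path looks like a hyphenated
    # article slug (e.g. /2016/05/some-headline-words), which noisy pages rarely have.
    path = urlparse(url).path.lower()
    return bool(re.search(r"/[a-z0-9]+(?:-[a-z0-9]+){2,}", path))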
Example #5
def main(system, dataset, summaries, keys):

    print("Starting", system, "Docker image.")

    process = Popen(
        [
            "docker", "run", "--rm", "--name", system, "-a", "stdin", "-a",
            "stdout", "-i", system
        ],
        stdin=PIPE,
        stdout=PIPE,
    )

    dataset_file = jsonl.open(dataset, gzip=True)

    # Check the size of the dataset,
    # as a sanity check and for the progress bar.

    print("Loading articles... ", end="", flush=True)
    dataset_length = len(dataset_file)
    print("found", dataset_length, "articles.\n")

    # Start new thread to feed summaries into container.

    Thread(target=_writer,
           args=(process, dataset_file, keys.split(","))).start()

    # Start progress bar.

    progress = tqdm(
        readiter(process.stdout),
        total=dataset_length,
        desc="Running " + system,
    )

    # Prepare to decode summaries.

    is_json = True

    with jsonl.open(summaries, gzip=True) as summaries_file:

        summaries_file.delete()

        with progress as output:
            for line in output:

                summaries_file.appendline({"system": line})

    print("\nRun complete. Next, evaluate with newsroom-score.")
Example #6
def main(args):
    text_encoder = TextEncoder(args.encoder_path, args.bpe_path)
    num_summaries = 0
    out_data = []
    with jsonl.open(args.in_file, gzip=True) as in_file:
        data = in_file.read()
        for entry in tqdm(data):
            if entry["summary"] is None or entry["text"] is None:
                continue
            entry["summary"] = encode_line(entry["summary"], text_encoder)
            entry["text"] = encode_line(entry["text"], text_encoder)
            num_summaries += 1
            out_data.append(entry)
    with jsonl.open(args.out_file, gzip=True) as out_file:
        out_file.write(out_data)
    print("Number of successful conversions: {}".format(num_summaries))
Example #7
def main(args):
    args.data_path = '/home/yilin10945/summary/data/newsroom/train.data'
    args.target_path = './data/train_features.json'

    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    bert_model = BertModel.from_pretrained('bert-base-uncased')
    bert_model.eval()
    bert_model = bert_model.to(device)
    bert_model_cpu = BertModel.from_pretrained('bert-base-uncased')
    bert_model_cpu.eval()

    train_data = jsonl.open(args.data_path, gzip=True).read()
    with open(args.target_path, 'w') as fp:
        for idx, doc in enumerate(tqdm.tqdm(train_data)):
            if doc['text']:
                doc_features = list()
                for sen in doc['text'].split('\n\n'):
                    if len(sen) > 1:
                        input_tok = tokenizer.tokenize(sen[:512])
                        input_list = tokenizer.convert_tokens_to_ids(input_tok)
                        try:
                            input_tensor = torch.tensor(input_list).view(1, -1).to(device)
                            encoded_layers, pooled_output = bert_model(input_tensor)
                        except RuntimeError:
                            # Fall back to the CPU model (e.g. on GPU out-of-memory).
                            input_tensor = torch.tensor(input_list).view(1, -1)
                            encoded_layers, pooled_output = bert_model_cpu(input_tensor)
                        features = encoded_layers[-1].view(-1, 768).tolist()
                        doc_features.extend(features)
                doc['id'] = idx
                doc['bert_features'] = doc_features
                print(json.dumps(doc), file=fp)
Example #8
def __init__(self, data_file, encoder, max_size=None, subset=None):
    with jsonl.open(data_file, gzip=True) as f:
        self.data = f.read()
    if subset is not None:
        self.data = [x for x in self.data if x["density_bin"] == subset]
    random.shuffle(self.data)
    if max_size is not None:
        self.data = self.data[:max_size]
    self.encoder = encoder
Example #9
def run_parallel(dataset: str, summaries: str, model: typing.Callable, budget: int):
    model = SafeSummerizer(model, budget)
    with jsonl.open(dataset, gzip=True) as dataset_fh:
        with jsonl.open(summaries, gzip=True) as summaries_fh:
            summaries_fh.delete()  # clean out output file
            progressbar = tqdm()

            chunk = []
            for document in dataset_fh:
                chunk.append(document)

                if len(chunk) >= CHUNK_SIZE:
                    with ProcessPoolExecutor(WORKERS) as pool:
                        results = pool.map(model, chunk)
                        results = list(results)
                        for summary in results:
                            summaries_fh.appendline({"system": summary})
                        chunk = []

                    progressbar.update(len(results))

            # Process any documents left over after the last full chunk.
            if chunk:
                with ProcessPoolExecutor(WORKERS) as pool:
                    results = list(pool.map(model, chunk))
                    for summary in results:
                        summaries_fh.appendline({"system": summary})
                progressbar.update(len(results))
Example #10
def split_dataset(dataset: str, thin_dev: str, thin_test: str):
    # employing convention to reduce number of arguments
    assert dataset.endswith(
        ".dataset"), "dataset name doesn't follow convention"
    assert thin_dev.endswith(".thin"), "..."
    assert thin_test.endswith(".thin"), "..."

    train_out = f"{dataset[:-8]}.train.dataset"
    dev_out = f"{dataset[:-8]}.dev.dataset"
    test_out = f"{dataset[:-8]}.test.dataset"

    for filename in [train_out, dev_out, test_out]:
        assert not exists(
            filename), "output file (%s) already exists." % filename

    with jsonl.open(thin_dev, gzip=True) as dev_fh:
        dev_ids = set(doc["archive"] for doc in tqdm(dev_fh))

    with jsonl.open(thin_test, gzip=True) as test_fh:
        test_ids = set(doc["archive"] for doc in tqdm(test_fh))

    # Disable the formatter, as it messes up the formatting below. Its take on it ain't pretty.
    # fmt: off
    with jsonl.open(dataset, gzip=True) as dataset_fh,\
         jsonl.open(train_out, gzip=True) as train_out_fh,\
         jsonl.open(dev_out, gzip=True) as dev_out_fh,\
         jsonl.open(test_out, gzip=True) as test_out_fh:
        for doc in tqdm(dataset_fh):
            if doc["archive"] in dev_ids:
                dev_out_fh.appendline(doc)
            elif doc["archive"] in test_ids:
                test_out_fh.appendline(doc)
            else:
                train_out_fh.appendline(doc)
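# Hypothetical invocation: with these (assumed) file names, the naming convention
# above derives newsroom.train.dataset, newsroom.dev.dataset and newsroom.test.dataset.
split_dataset("newsroom.dataset", "dev.thin", "test.thin")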
Example #11
def build_thin_from_dataset(dataset: str, thin: str):
    """
    Builds a thin-file from a dataset-file for distribution. This is useful when a dataset
    already exists and one wishes to distribute it in an appropriately legal manner.

    Recall that a thin-file contains metrics and information to download the dataset.
    """
    desired_cols = set([
        "archive",
        "date",
        "density",
        "coverage",
        "compression",
        "compression_bin",
        "coverage_bin",
        "density_bin",
    ])
    with jsonl.open(dataset, gzip=True) as dataset_fh:
        with jsonl.open(thin, gzip=True) as thin_fh:
            for doc in tqdm(dataset_fh):
                thin_entry = {col: doc[col] for col in desired_cols}
                thin_fh.appendline(thin_entry)
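# For illustration only (file name assumed): reading the thin file back yields one
# dictionary per article containing just the metric and download columns selected above.
with jsonl.open("newsroom.thin", gzip=True) as thin_fh:
    for entry in thin_fh:
        print(entry["archive"], entry["density_bin"], entry["coverage"])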
Example #12
def trim_and_transform(example_generator, new_filename, transformation,
                       constraint):
    oldcount, newcount = 0, 0
    if os.path.isfile(new_filename):
        os.remove(new_filename)
    with jsonl.open(new_filename, gzip=True) as newfile:
        for line in example_generator:
            oldcount += 1
            line = transformation(line)
            if constraint(line):
                newcount += 1
                newfile.appendline(line)
            if oldcount % 1000 == 0:
                print(oldcount)
    print('# of old lines: %i, # of new lines: %i' % (oldcount, newcount))
Example #13
def main(args):
    text_encoder = TextEncoder(args.encoder_path, args.bpe_path)
    num_summaries = 0
    out_data = []
    with open(args.src_file) as src_file, open(args.tgt_file) as tgt_file:
        src_lines = src_file.readlines()
        tgt_lines = tgt_file.readlines()
        for i in tqdm(range(len(src_lines))):
            num_summaries += 1
            out_data.append({
                "summary":
                encode_line(tgt_lines[i].strip(), text_encoder),
                "text":
                encode_line(src_lines[i].strip(), text_encoder)
            })
    with jsonl.open(args.out_file, gzip=True) as out_file:
        out_file.write(out_data)
    print("Number of successful conversions: {}".format(num_summaries))
Example #14
def evaluate(args):
    summarizer = summarizer_factory(args.model_class)

    with jsonl.open(args.data, gzip=True) as dataset:
        for i, entry in enumerate(dataset):
            if args.num_texts is not None and i >= args.num_texts:
                break

            ref_summary = entry["summary"]

            try:
                summary = summarizer.summarize(entry["text"])

            except ValueError:
                # Handles "input must have more than one sentence"
                summary = entry["text"]

            print(json.dumps({
                "sum": summary,
                "ref": ref_summary
            }),
                  file=args.output,
                  flush=True)
Example #15
from newsroom import jsonl

# Read entire file:

with jsonl.open("dev.data", gzip=True) as train_file:
    train = train_file.read()

# Read file entry by entry:

with jsonl.open("dev.data", gzip=True) as train_file:
    for entry in train_file:
        print(entry["summary"], entry["text"])
Example #16
    entry['processed']['ner_summary'] = ner_summary
    entry['text'] = text.lower()
    entry['summary'] = summary.lower()
    return entry


root_dir = sys.argv[1]

filetypes = ['train', 'dev', 'test']

vocab = Counter()
pvocab = Counter()

for filetype in filetypes:
    filename = "{}/{}.data".format(root_dir, filetype)
    with jsonl.open(filename, gzip=True) as _file:
        entries = [_ for _ in _file]

    print("processing {} files...".format(len(entries)))

    pbar = tqdm(total=len(entries))

    processed_entries = []
    pool = Pool(cpu_count())
    for _ in pool.imap_unordered(run, entries):
        processed_entries.append(_)
        pbar.update(1)

    pool.close()

    assert len(processed_entries) == len(entries)
Example #17
def main(archive, urldiff, dataset, workers, chunksize, lang):

    if archive is None and urldiff is None:

        print("One of --archive or --urldiff required.")

        return

    elif urldiff:

        # Check to see if the dataset contains all URLs.

        required = set()

        print("Comparing URL file to dataset...")

        with open(urldiff, "rt") as urls_file:

            for line in urls_file:
                required.add(line.strip())

        with jsonl.open(dataset, gzip = True) as dataset_file:

            for article in dataset_file.readlines(ignore_errors = True):

                url = article.get("archive", article.get("url"))
                required.discard(url)

        if len(required) > 0:

            print(len(required), "URLs missing:\n")

            for url in required:

                print(url)

        else:

            print("Dataset complete.")

        return

    previously = set()
    todo = set()

    if os.path.isfile(dataset):

        print("Comparing archive and dataset files: ", end = "")

        with jsonl.open(dataset, gzip = True) as dataset_file:

            for article in dataset_file.readlines(ignore_errors = True):

                url = article.get("archive", article.get("url"))
                previously.add(url)

        print("found", len(previously), "finished summaries... ", end = "")

    else:

        print("Loading downloaded summaries: ", end = "")

    with jsonl.open(archive, gzip = True) as archive_file:

        for article in archive_file.readlines(ignore_errors = True):

            url = article.get("archive", article.get("url"))
            todo.add(url)

    todo -= previously

    print("found", len(todo), "new summaries.\n")

    with tqdm(total = len(todo), desc = "Extracting Summaries") as progress:
        with jsonl.open(archive, gzip = True) as archive_file:
            with jsonl.open(dataset, gzip = True) as dataset_file:

                chunk = []

                def process_chunk():

                    with ProcessPoolExecutor(workers) as ex:
                        results = list(ex.map(Article.process, chunk))
                        results = [r for r in results if r is not None]

                        # Compute statistics.

                        for result in results:
                            if (result["text"] is None) or (result["summary"] is None):
                                continue
                            # Unlike the original implementation, this only skips
                            # missing fields; to also skip empty summaries or bodies:
                            # if not result["text"] or not result["summary"]:
                            #     continue

                            fragments = Fragments(result["summary"], result["text"], lang=lang)

                            result["density"] = fragments.density()
                            result["coverage"] = fragments.coverage()
                            result["compression"] = fragments.compression()

                            result["compression"] = fragments.compression()
                            result["coverage"] = fragments.coverage()
                            result["density"] = fragments.density()

                            for measure in ("compression", "coverage", "density"):

                                result[measure + "_bin"] = binner(
                                    result[measure],
                                    cutoffs[measure],
                                    levels[measure])

                        dataset_file.append(results)
                        progress.update(len(results))

                for article in archive_file.readlines(ignore_errors = True):

                    url = article.get("archive", article.get("url"))
                    if url not in todo: continue

                    chunk.append(article)

                    if len(chunk) >= chunksize:

                        process_chunk()
                        chunk = []

                process_chunk()

    print("\nExtraction complete.")
Example #18
from newsroom import jsonl
from gensim.summarization.summarizer import summarize

WORD_COUNT = 50

with jsonl.open("input.dataset", gzip=True) as inputs:
    with jsonl.open("textrank.summaries", gzip=True) as outputs:
        outputs.delete()
        for article in inputs:
            try:
                summary = summarize(article["text"], word_count=WORD_COUNT)
            except ValueError:
                # Handles "input must have more than one sentence"
                summary = article["text"]
            outputs.appendline({"system": summary.replace("\n", " ")})
Example #19
#! -*- coding: utf-8 -*-

from __future__ import print_function
import random

from newsroom import jsonl
from newsroom.analyze import Fragments


# Extraction Analysis
# Read file entry by entry:
# with open("../data/Newsroom/train.label.info.jsonl", "r+", encoding="utf8") as f:
#     for item in jsonlines.Reader(f):

with jsonl.open("../../data/Newsroom/train.label.info.jsonl") as train_file:
    for entry in train_file:
        print(entry["summary"], entry["text"])

        # Compute stats on each training example:
        summary, text = entry["summary"], entry["text"]
        fragments = Fragments("".join(summary), "\n".join(text))

        # Print paper metrics:
        print("Coverage:",    fragments.coverage())
        print("Density:",     fragments.density())
        print("Compression:", fragments.compression())

        # Extractive fragments oracle:
        print("List of extractive fragments:")
        print(fragments.strings())
Example #20
from newsroom import jsonl

dir = 'D:\\Documents\\Classes\\CS224n\\project'
import os
# Read entire file:

with jsonl.open(os.path.join(dir, 'thin\\train.data'),
                gzip=True) as train_file:
    train = train_file.read()

# Read file entry by entry:

with jsonl.open("./thin/train.data", gzip=True) as train_file:
    for entry in train_file:
        print(entry["summary"], entry["text"])
Example #21
    entry['processed']['pos_summary'] = pos_summary
    entry['processed']['ner_summary'] = ner_summary
    entry['text'] = text.lower()
    entry['summary'] = summary.lower()
    return entry

root_dir = sys.argv[1]

filetypes = ['train', 'dev', 'test']

vocab = Counter()
pvocab = Counter()

for filetype in filetypes:
    filename = "{}/{}.data".format(root_dir, filetype)
    with jsonl.open(filename, gzip = True) as _file:
        entries = [_ for _ in _file]

    print("processing {} files...".format(len(entries)))

    pbar = tqdm(total=len(entries))

    processed_entries = []
    pool = Pool(cpu_count())
    for _ in pool.imap_unordered(run, entries):
        processed_entries.append(_)
        pbar.update(1)

    pool.close()

    assert len(processed_entries) == len(entries)
Example #22

import os
from newsroom import jsonl

index = 1
cc_num = 0
out_dir = "../data/Newsroom/train"
os.makedirs(out_dir, exist_ok=True)
with jsonl.open("../data/Newsroom/train.jsonl.gz", gzip=True) as train_file:
    for entry in train_file:
        summary_str = entry["summary"].strip()
        content_str = entry["text"].strip()
        # print("************************************")
        # print(summary_str)
        # print("-------------------------------")
        # print(content_str)
        pos = summary_str.find("<?xml:")
        if pos >= 0:
            continue
        if content_str.find(summary_str) >= 0:
            cc_num += 1
        text_file_path = os.path.join(out_dir, "%06d.txt" % index)
        index += 1
        with open(text_file_path, "w") as fw:
            fw.write(summary_str.replace("\n", ".") + "\n" + content_str)


print("summary in content num = %d " % cc_num)          #
print("total txt num = %d " % index)
Example #23
def preprocess_newsroom_datafile(filename, new_filename):
    with jsonl.open(filename, gzip=True) as oldfile:
        trim_and_transform(oldfile, new_filename, newsroom_preprocess,
                           newsroom_constraint)
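# newsroom_preprocess and newsroom_constraint are defined elsewhere in that project.
# To show the shape trim_and_transform (Example #12) expects, a hypothetical
# transformation/constraint pair might look like:
def newsroom_preprocess(entry):
    # Assumed transformation: lower-case and whitespace-normalize both fields.
    entry["text"] = " ".join((entry.get("text") or "").lower().split())
    entry["summary"] = " ".join((entry.get("summary") or "").lower().split())
    return entry


def newsroom_constraint(entry):
    # Assumed constraint: keep entries whose summary and article are non-trivial.
    return len(entry["summary"].split()) >= 8 and len(entry["text"].split()) >= 50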
Example #24
def main(urls, thin, archive, exactness, diff, **downloader_args):

    if not urls and not thin:

        print("Either --urls or --thin must be defined.")
        return

    # If the archive file exists, only download what we need.
    # Open the file and read all previously downloaded URLs.

    if not os.path.isfile(archive):

        done = set()

    else:

        print("Loading previously downloaded summaries:", end = " ")

        with jsonl.open(archive, gzip = True) as f:

            done = {ln["archive"] for ln in f}
            print(len(done), "downloaded summaries...", end = " ")

    # Read the URL file or thin.

    if urls:

        with open(urls, "r") as f:

            urls = [ln.strip() for ln in f]

    elif thin:

        with jsonl.open(thin, gzip = True) as f:

            urls = [entry["archive"] for entry in f]

    # Which URLs are remaining?

    todo = [url for url in urls if url not in done]
    size = round(0.00002 * len(todo), 1)  # rough size estimate in GB (~20 KB per page)

    # If --diff argument is enabled, just print undownloaded article URLs.

    if diff:

        print("there are", len(todo), "URLs not downloaded:\n")

        for url in todo:
            print(url)

        return

    else:

        print(len(todo), "new summaries (about", size, "GB).")

    # Truncate url dates if they can't be downloaded.

    if exactness is not None:

        exactness_map = {_exactness(url, exactness): url for url in todo}
        todo = list(exactness_map.keys())

    # Randomize todo to prevent "hard" pages from collecting at start.

    random.shuffle(todo)

    # Initialize the Archive scraper to start downloading.

    print("If pages fail to download now, re-run script when finished.\n")

    scraper = Downloader(**downloader_args)
    downloads = scraper.download(todo)

    # Progress bar arguments.

    progress = tqdm(
        total = len(todo),
        desc = "Downloading Summaries"
    )

    # Append the downloaded files to the compressed file.
    # (We checked earlier, so we won't overwrite older downloads.)

    errors = 0
    progress.update(0)

    try:

        with jsonl.open(archive, gzip = True) as f:

            for article in downloads:

                if article is None:

                    errors += 1

                    if errors % 10 == 0:

                        print()
                        print(errors, "pages need re-downloading later.")

                else:

                    # Rename url -> archive for consistency.

                    article["archive"] = article["url"]
                    del article["url"]

                    # Keep track of how much this article was truncated.

                    if exactness is not None:

                        exactness_archive = article["archive"]
                        real_archive = exactness_map[exactness_archive]

                        article["exactness_factor"] = exactness
                        article["exactness_archive"] = exactness_archive
                        article["archive"] = real_archive

                    # Write updated dictionary to JSON file.

                    f.appendline(article)

                progress.update(1)

        if errors > 0:

            print("\n\nRerun the script:", errors, "pages failed to download.")
            print("- Try running with a lower --workers count (default = 16).")
            print("- Check which URLs are left with the --diff flag.")
            print("- Last resort: --exactness X to truncate dates to X digits.")
            print("  (e.g., --exactness 4 will download the closest year.)")

        else:

            print("\n\nDownload complete. Next, run newsroom-extract.")

    except KeyboardInterrupt:

        print("\n\nDownload aborted with progress preserved.")
        print("Run script again to resume from this point.")
Example #25
def main(dataset, summaries, scores, rouge, stemmed, workers, chunksize):

    rouges = rouge.upper().split(",")

    with jsonl.open(dataset, gzip=True) as a:
        with jsonl.open(summaries, gzip=True) as s:
            with jsonl.open(scores, gzip=True) as f:

                # If scores file exists, delete it.
                # (So we write, rather than appending.)

                f.delete()
                size = len(s)
                chunk = []

                with tqdm(total=size, desc="Evaluating") as progress:

                    def process_chunk():

                        with ProcessPoolExecutor(workers) as ex:

                            results = list(ex.map(compute_rouge, chunk))

                            f.append(results)
                            progress.update(len(results))

                    for aline, sline in zip(a, s):

                        chunk.append([aline, sline, rouges, stemmed])

                        if len(chunk) >= chunksize:

                            process_chunk()
                            chunk = []

                    process_chunk()

    aggregate = {}
    with jsonl.open(scores, gzip=True) as f:

        for entry in f:
            for k, v in entry.items():

                if not k.startswith("rouge_"):
                    continue

                if k not in aggregate:
                    aggregate[k] = []
                aggregate[k].append(v)

    print("\nScoring complete. Overall statistics:\n")
    for name, scores in aggregate.items():

        pretty = name \
            .title() \
            .replace("_", " ") \
            .replace("Rouge ", "ROUGE-") \
            .replace("Fscore", "F-Score:  ") \
            .replace("Precision", "Precision:") \
            .replace("Recall", "Recall:   ")

        print(pretty, round(sum(scores) / len(scores) * 100, 4))

    print()
    print("Next, run newsroom-tables for detailed statistics.")
Example #26
        sen_arr.append(sen)
    summary = ' '.join(sen_arr)
    sen_arr = []
    for sen in nltk.sent_tokenize(title):
        sen = nltk.word_tokenize(sen)
        sen = ['<s>']+sen+['</s>']
        sen = ' '.join(sen)
        sen_arr.append(sen)
    title = ' '.join(sen_arr)

    sen_arr = [title, summary, article]
    
    return '<sec>'.join(sen_arr)

fout = open('plain_data/test.txt', 'w')
fp = jsonl.open('extract_data/test.jsonl', gzip=True)
cnt = 0
batcher = []
for line in fp:
    cnt += 1
    print(cnt)
    batcher.append(line)
    if len(batcher) == 64:
        pool = Pool(processes=16)
        result = pool.map(process_data, batcher)
        pool.terminate()
        for itm in result:
            if len(itm) > 1:
                fout.write(itm+'\n')
        batcher = []

# Process any lines left over after the last full batch of 64.
if batcher:
    pool = Pool(processes=16)
    result = pool.map(process_data, batcher)
    pool.terminate()
    for itm in result:
        if len(itm) > 1:
            fout.write(itm+'\n')

fout.close()
Example #27
def preprocess_newsroom(filter_level=0,
                        raw_dir='~/newsroom-data-dir',
                        output_dir='~/newsroom-data-dir/processed-data',
                        write_inverse=False):
    '''
    Preprocess the Newsroom dataset according to different filter levels.
    - filter_level=0: no special processing
    - filter_level=1: remove corrupted text in source articles and summaries (e.g. some source
        articles contain only captions for photos; some bad summaries such as "Collection of all
        USATODAY.com coverage of People, including articles, videos, photos, and quotes.")
    - filter_level=2: entity hallucination filtering in addition to corrupted text removal.
        A summary sentence is removed if it contains a named entity not found in the source document.
    '''

    from newsroom import jsonl

    def source_bad(source, is_print=False):
        if len(source.split()) < 50:
            if is_print:
                print(source)
            return True
        # if source.startswith("'Image ") or \
        # source.startswith('"Photo: ') or \
        if (source.startswith('Image ') and source[6] in "0123456789") or \
            source.startswith("Photo: ") or \
            '"params":' in source :
            if is_print:
                print(source)
            return True
        return False

    def summary_bad(summary, is_print=False):
        if len(summary.split()) < 8:
            if is_print:
                print(summary)
            return True
        if re.search(re.escape('on FoxNews.com'), summary, re.IGNORECASE) or \
            re.search(re.escape('from FoxNews.com'), summary, re.IGNORECASE) or \
            re.search(re.escape('Collection of all USATODAY.com'), summary, re.IGNORECASE) or \
            re.search(re.escape('washingtonpost.com'), summary, re.IGNORECASE):
            if is_print:
                print(summary)
            return True
        return False

    nlp = spacy.load("en_core_web_lg")
    split_types = ['val', 'train', 'test']
    for split_type in split_types:
        if split_type == 'train':
            in_file = os.path.join(raw_dir, 'train.dataset')
        elif split_type == 'val':
            in_file = os.path.join(raw_dir, 'dev.dataset')
        elif split_type == 'test':
            in_file = os.path.join(raw_dir, 'test.dataset')
        else:
            print(
                "ERROR! split_type must be one of the following: train, val and test!"
            )
        count = 0
        output_source = split_type + '.source'
        output_target = split_type + '.target'
        num_lines = sum(1 for _ in jsonl.open(in_file, gzip=True))
        with open(os.path.join(output_dir, output_source), 'w') as source_f, \
            open(os.path.join(output_dir, output_target), 'w') as target_f:
            with jsonl.open(in_file, gzip=True) as f:
                for entry in tqdm(f, total=num_lines):
                    if entry['summary'] and entry['text']:
                        summary = " ".join(entry['summary'].split('\n'))
                        if filter_level > 0:
                            if summary_bad(summary) or source_bad(
                                    entry['text']):
                                if write_inverse and filter_level <= 1:
                                    source_f.write(entry['text'].strip(
                                    ).encode('unicode-escape').decode(
                                    ).replace('\\\\', '\\') + '\n')
                                    target_f.write(summary.strip() + '\n')
                                    # target_f.write(summary.strip().encode('unicode-escape').decode().replace('\\\\', '\\') + '\n')
                                    # source_f.write(repr(entry['text'].strip()) + '\n')
                                    # target_f.write(repr(entry['summary'].strip()) + '\n')
                                    count += 1
                                continue
                            filtered_summary = select_summary_sentences(
                                nlp, entry['text'], summary, filter_level)
                            if not filtered_summary:
                                if write_inverse:
                                    source_f.write(entry['text'].strip(
                                    ).encode('unicode-escape').decode(
                                    ).replace('\\\\', '\\') + '\n')
                                    target_f.write(summary.strip() + '\n')
                                    # target_f.write(summary.strip().encode('unicode-escape').decode().replace('\\\\', '\\') + '\n')
                                    count += 1
                                continue
                        if not write_inverse:
                            source_f.write(
                                entry['text'].strip().encode('unicode-escape').
                                decode().replace('\\\\', '\\') + '\n')
                            target_f.write(summary.strip() + '\n')
                            # target_f.write(
                            #     summary.strip().encode('unicode-escape').decode().replace('\\\\', '\\') + '\n')
                            count += 1
        print("Wrote {} lines in {}".format(
            count, os.path.join(output_dir, output_source)))
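# select_summary_sentences (used at filter_level 2) is not included above. A rough
# sketch of the behavior described in the docstring, using the already-loaded spaCy
# pipeline, might look like (assumption, not the project's actual implementation):
def select_summary_sentences(nlp, source_text, summary, filter_level):
    # Keep only summary sentences whose named entities all appear, as surface
    # strings, somewhere in the source article.
    if filter_level < 2:
        return summary
    source_lower = source_text.lower()
    kept = []
    for sent in nlp(summary).sents:
        entities = [ent.text.strip().lower() for ent in sent.ents]
        if all(ent in source_lower for ent in entities):
            kept.append(sent.text.strip())
    return " ".join(kept)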