def validate(summarizer, validate_dataset, language='en', top_n=32,
             num_gpus=4, batch_size=4):
    """Score the summarizer with ROUGE on a small slice of the validation set.

    Intended to be used optionally as a validation hook during fine tuning.

    Args:
        summarizer (BertSumAbs): The summarizer under fine tuning.
        validate_dataset (SummarizationDataset): Dataset for validation.
        language (str, optional): Language passed through to
            `compute_rouge_python`. Defaults to 'en'.
        top_n (int, optional): Size passed to `validate_dataset.shorten` to
            pick the subset that is scored. Defaults to 32.
        num_gpus (int, optional): Forwarded to `summarizer.predict`.
            Defaults to 4.
        batch_size (int, optional): Forwarded to `summarizer.predict`.
            Defaults to 4.

    Returns:
        The value returned by `compute_rouge_python` for the subset
        (presumably a mapping of ROUGE metrics -- confirm against that
        helper).

    """
    shortened_dataset = validate_dataset.shorten(top_n)
    reference_summaries = [
        " ".join(t).rstrip("\n") for t in shortened_dataset.get_target()
    ]
    generated_summaries = summarizer.predict(shortened_dataset,
                                             num_gpus=num_gpus,
                                             batch_size=batch_size)
    # Internal sanity check: one prediction is expected per reference.
    assert len(generated_summaries) == len(reference_summaries), (
        "got {} predictions for {} references".format(
            len(generated_summaries), len(reference_summaries)))

    print("###################")
    # Show one example pair for eyeballing; skip it when the subset is empty
    # so the preview itself cannot raise IndexError.
    if generated_summaries:
        print("prediction is {}".format(generated_summaries[0]))
        print("reference is {}".format(reference_summaries[0]))

    rouge_score = compute_rouge_python(cand=generated_summaries,
                                       ref=reference_summaries,
                                       language=language)
    return rouge_score
예제 #2
0
def test_compute_rouge_python(rouge_test_data):
    """Check compute_rouge_python on in-memory candidates and references.

    Each ROUGE-1/2/L recall, precision and F value is compared against the
    module-level expected constants within ABS_TOL.

    Args:
        rouge_test_data: Fixture providing "candidates" and "references".
    """
    rouge_python = compute_rouge_python(cand=rouge_test_data["candidates"],
                                        ref=rouge_test_data["references"])

    # pytest.approx only takes effect when compared with ==; calling it on
    # its own builds an object and checks nothing.
    assert rouge_python["rouge-1"]["r"] == pytest.approx(R1R, abs=ABS_TOL)
    assert rouge_python["rouge-1"]["p"] == pytest.approx(R1P, abs=ABS_TOL)
    assert rouge_python["rouge-1"]["f"] == pytest.approx(R1F, abs=ABS_TOL)
    assert rouge_python["rouge-2"]["r"] == pytest.approx(R2R, abs=ABS_TOL)
    assert rouge_python["rouge-2"]["p"] == pytest.approx(R2P, abs=ABS_TOL)
    assert rouge_python["rouge-2"]["f"] == pytest.approx(R2F, abs=ABS_TOL)
    assert rouge_python["rouge-l"]["r"] == pytest.approx(RLR, abs=ABS_TOL)
    assert rouge_python["rouge-l"]["p"] == pytest.approx(RLP, abs=ABS_TOL)
    assert rouge_python["rouge-l"]["f"] == pytest.approx(RLF, abs=ABS_TOL)
예제 #3
0
def test_compute_rouge_python_hi(rouge_test_data):
    """Check compute_rouge_python with language="hi".

    Each ROUGE-1/2/L recall, precision and F value is compared against the
    module-level ``*_hi`` expected constants within ABS_TOL.

    Args:
        rouge_test_data: Fixture providing "candidates_hi" and
            "references_hi".
    """
    rouge_python = compute_rouge_python(cand=rouge_test_data["candidates_hi"],
                                        ref=rouge_test_data["references_hi"],
                                        language="hi")

    # pytest.approx only takes effect when compared with ==; calling it on
    # its own builds an object and checks nothing.
    assert rouge_python["rouge-1"]["r"] == pytest.approx(R1R_hi, abs=ABS_TOL)
    assert rouge_python["rouge-1"]["p"] == pytest.approx(R1P_hi, abs=ABS_TOL)
    assert rouge_python["rouge-1"]["f"] == pytest.approx(R1F_hi, abs=ABS_TOL)
    assert rouge_python["rouge-2"]["r"] == pytest.approx(R2R_hi, abs=ABS_TOL)
    assert rouge_python["rouge-2"]["p"] == pytest.approx(R2P_hi, abs=ABS_TOL)
    assert rouge_python["rouge-2"]["f"] == pytest.approx(R2F_hi, abs=ABS_TOL)
    assert rouge_python["rouge-l"]["r"] == pytest.approx(RLR_hi, abs=ABS_TOL)
    assert rouge_python["rouge-l"]["p"] == pytest.approx(RLP_hi, abs=ABS_TOL)
    assert rouge_python["rouge-l"]["f"] == pytest.approx(RLF_hi, abs=ABS_TOL)
예제 #4
0
def test_compute_rouge_python_file(rouge_test_data, tmp):
    """Check compute_rouge_python when its inputs are file paths.

    Candidates and references are written one per line into two files under
    ``tmp`` and scored with ``is_input_files=True``. Each ROUGE-1/2/L recall,
    precision and F value is compared against the module-level expected
    constants within ABS_TOL.

    Args:
        rouge_test_data: Fixture providing "candidates" and "references".
        tmp: Fixture giving a directory path in which the files are created.
    """
    tmp_cand_file = os.path.join(tmp, "cand.txt")
    tmp_ref_file = os.path.join(tmp, "ref.txt")

    with open(tmp_cand_file, "w") as f:
        for s in rouge_test_data["candidates"]:
            f.write(s + "\n")
    with open(tmp_ref_file, "w") as f:
        for s in rouge_test_data["references"]:
            f.write(s + "\n")

    rouge_python = compute_rouge_python(cand=tmp_cand_file,
                                        ref=tmp_ref_file,
                                        is_input_files=True)

    # pytest.approx only takes effect when compared with ==; calling it on
    # its own builds an object and checks nothing.
    assert rouge_python["rouge-1"]["r"] == pytest.approx(R1R, abs=ABS_TOL)
    assert rouge_python["rouge-1"]["p"] == pytest.approx(R1P, abs=ABS_TOL)
    assert rouge_python["rouge-1"]["f"] == pytest.approx(R1F, abs=ABS_TOL)
    assert rouge_python["rouge-2"]["r"] == pytest.approx(R2R, abs=ABS_TOL)
    assert rouge_python["rouge-2"]["p"] == pytest.approx(R2P, abs=ABS_TOL)
    assert rouge_python["rouge-2"]["f"] == pytest.approx(R2F, abs=ABS_TOL)
    assert rouge_python["rouge-l"]["r"] == pytest.approx(RLR, abs=ABS_TOL)
    assert rouge_python["rouge-l"]["p"] == pytest.approx(RLP, abs=ABS_TOL)
    assert rouge_python["rouge-l"]["f"] == pytest.approx(RLF, abs=ABS_TOL)
예제 #5
0
# Prediction batch size and the containers that collect, per dataset and per
# model name, the generated summaries and their ROUGE scores.
batch_size = 250
rouge_scores = {}
predictions = {}

# When True, only the first five samples of each dataset are evaluated.
TEST = False


for dataset in ['swiss', 'bundes']:
    predictions[dataset] = {}
    rouge_scores[dataset] = {}
    print("Dataset: ", dataset)
    n = 5 if TEST else len(torch_tests[dataset])
    print("Sample size:", n)

    for train_name, summarizer in summarizers.items():
        print("model name: ", train_name)
        if "lead" not in train_name:
            # Regular model: generate summaries for the first n test samples.
            model_predictions = summarizer.predict(
                torch_tests[dataset][:n],
                num_gpus=0,
                batch_size=batch_size,
                sentence_separator=sentence_separator,
            )
        else:
            # "lead" entries take their summaries from `leads` rather than
            # from calling the summarizer.
            print("IN HERE")
            model_predictions = leads[train_name][:n]
        predictions[dataset][train_name] = model_predictions

        rouge_scores[dataset][train_name] = compute_rouge_python(
            cand=model_predictions,
            ref=target[dataset][:n],
        )


# print out the calculated rouge scores
pprint.pprint(rouge_scores)
예제 #6
0
def main():
    """Fine-tune the S2S abstractive summarizer, then predict and score.

    Initializes the NCCL process group, builds the processor and summarizer,
    loads the train and test datasets, and runs `fit`. After all processes
    synchronize, the process whose local rank is -1 or 0 generates
    summaries for the test set, writes them to OUTPUT_FILE, reads the "tgt"
    field of every record in `test_ds`, and prints the ROUGE result.
    Configuration comes from module-level constants and `args`.
    """
    torch.distributed.init_process_group(
        backend="nccl",
        timeout=datetime.timedelta(days=0, seconds=5400),
    )

    # Ranks other than -1/0 wait here until rank 0 reaches its own barrier
    # below (presumably so rank 0 completes any one-time setup first).
    if args.local_rank not in (-1, 0):
        torch.distributed.barrier()

    processor = S2SAbsSumProcessor(model_name=MODEL_NAME)
    abs_summarizer = S2SAbstractiveSummarizer(
        model_name=MODEL_NAME,
        max_seq_length=MAX_SEQ_LENGTH,
        max_source_seq_length=MAX_SOURCE_SEQ_LENGTH,
        max_target_seq_length=MAX_TARGET_SEQ_LENGTH,
    )

    # Rank 0 joins the barrier the other ranks are waiting on.
    if args.local_rank == 0:
        torch.distributed.barrier()

    train_dataset = processor.s2s_dataset_from_json_or_file(
        train_ds,
        train_mode=True,
        local_rank=args.local_rank,
    )
    test_dataset = processor.s2s_dataset_from_json_or_file(
        test_ds,
        train_mode=False,
        local_rank=args.local_rank,
    )

    abs_summarizer.fit(
        train_dataset=train_dataset,
        per_gpu_batch_size=TRAIN_PER_GPU_BATCH_SIZE,
        gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS,
        learning_rate=LEARNING_RATE,
        warmup_steps=WARMUP_STEPS,
        max_steps=MAX_STEPS,
        fp16=args.fp16,
        fp16_opt_level=args.fp16_opt_level,
        local_rank=args.local_rank,
        save_model_to_dir=".",
    )

    # Every process synchronizes once training has finished.
    torch.distributed.barrier()

    # Prediction and scoring run only on local rank -1 or 0.
    if args.local_rank not in (-1, 0):
        return

    generated = abs_summarizer.predict(
        test_dataset=test_dataset,
        per_gpu_batch_size=TEST_PER_GPU_BATCH_SIZE,
        beam_size=BEAM_SIZE,
        forbid_ignore_word=FORBID_IGNORE_WORD,
        fp16=args.fp16,
    )

    # Preview the first few generated summaries.
    for summary in generated[:5]:
        print(summary)

    with open(OUTPUT_FILE, "w", encoding="utf-8") as out_file:
        out_file.writelines(line + "\n" for line in generated)

    with jsonlines.open(test_ds) as reader:
        references = [item["tgt"] for item in reader]

    # Preview the first few reference summaries.
    for reference in references[:5]:
        print(reference)

    rouge = compute_rouge_python(cand=generated, ref=references)
    print(rouge)