예제 #1
0
def test_real_article():
    """
    Check that summarizing a real article reproduces the summary and score
    recorded during an earlier local run.

    The fixture ``test_article.json`` is opened relative to the current
    working directory. It supplies the article body under ``'article'`` and
    the recorded results under ``'expected_summary'`` and ``'expected_score'``.
    Note: slow -- roughly 20 seconds, not counting imports / model loading.
    """
    with open('test_article.json') as fixture_file:
        fixture = json.load(fixture_file)

    document = SingleDocument(document_id=0, raw={'body': fixture['article']})
    produced_summary, produced_score = generate_summary(document.spacy_text())

    # The summary must come back as text (unicode), not encoded bytes.
    assert isinstance(produced_summary, unicode)
    assert produced_summary == fixture['expected_summary']
    # The score is compared with a small tolerance instead of exact equality.
    assert abs(produced_score - fixture['expected_score']) < .001
예제 #2
0
def write_results(out_file):
    """
    Summarize each article in RESULTS_ARTICLE_DIR and write a TSV comparison.

    Every file in RESULTS_ARTICLE_DIR is processed in sorted filename order.
    The article id is the second ``_``-separated field of the filename's part
    before its first ``.`` (e.g. ``article_12.txt`` -> 12). The matching
    reference summary is read from ``RESULTS_ABSTRACT_DIR/abstract_<id>.txt``.
    For each article, a LexRank summary and a seq-to-seq summary are generated.
    The seq-to-seq summary, its elapsed time and its score are printed. One row
    is written to ``out_file`` with the columns Reference, Lexrank,
    Seq-to-seq, Score.

    Args:
        out_file: Path of the TSV file to create (truncated if it exists).
    """
    def _clean(field):
        # A tab or line break inside a field would shift columns or split the
        # row (e.g. f.read() on the abstract file can keep a trailing newline).
        for ch in ('\t', '\r', '\n'):
            field = field.replace(ch, ' ')
        return field.strip()

    # `with` guarantees the output file is closed even if a summarizer raises.
    with open(out_file, 'w') as out:
        out.write('\t'.join(['Reference', 'Lexrank', 'Seq-to-seq', 'Score']) +
                  '\n')

        for filename in sorted(os.listdir(RESULTS_ARTICLE_DIR)):
            article_id = int(filename.split('.')[0].split('_')[1])

            # Read the article as UTF-8 text and flatten it onto one line;
            # u'\xa0' (non-breaking space) is normalized to a plain space.
            with open(os.path.join(RESULTS_ARTICLE_DIR, filename)) as f:
                article_text = unicode(f.read(), 'utf-8')
            article_text = article_text.replace(u'\xa0', ' ').replace(
                '\t', ' ').replace('\n', ' ')

            # Read reference summary
            abstract_path = os.path.join(RESULTS_ABSTRACT_DIR,
                                         'abstract_%d.txt' % article_id)
            with open(abstract_path) as f:
                reference_summary = f.read()

            doc = SingleDocument(0, raw={'body': article_text})

            # Generate lexrank summary
            lexrank_summary = get_lexrank_summary(doc).encode('utf-8')

            # Generate seq-to-seq summary; the timing covers spaCy parsing
            # (doc.spacy_text()) as well as generation.
            t0 = time.time()
            seq_to_seq_summary, score = generate_summary(doc.spacy_text())
            seq_to_seq_summary = seq_to_seq_summary.encode('utf-8')

            # Single-argument print(...) outputs the same text as the former
            # Python 2 print statements and is also valid Python 3 syntax.
            print('####################')
            print(seq_to_seq_summary)
            print('Time: %s | Score: %s' % (time.time() - t0, score))

            # Write all results together, one sanitized row per article.
            out.write('\t'.join([
                _clean(reference_summary),
                _clean(lexrank_summary),
                _clean(seq_to_seq_summary),
                str(score),
            ]) + '\n')
            # Flush per row so completed rows reach disk before the next,
            # potentially slow, article is processed.
            out.flush()
예제 #3
0
def get_cable_results(data_file, out_file, max_cables=100, min_length=500):
    """
    Summarize cables loaded from a JSON file and write a TSV comparison.

    ``data_file`` is parsed with json.load and must contain a list of cable
    body texts. Each cable is encoded to UTF-8 below, so presumably these are
    unicode strings; confirm against the data producer. The first
    ``max_cables`` entries are considered. A cable whose ``len(doc.text())``
    is below ``min_length`` is skipped. For every remaining cable, one row is
    written to ``out_file`` with the columns Cable, Lexrank, Seq-to-seq.

    Args:
        data_file: Path to the JSON file holding the list of cables.
        out_file: Path of the TSV file to create (truncated if it exists).
        max_cables: How many leading cables to consider (default 100);
            ``None`` considers all of them.
        min_length: Minimum ``len(doc.text())`` required for a cable to be
            summarized (default 500).
    """
    # `with` guarantees the output file is closed even if a summarizer raises.
    with open(out_file, 'w') as out:
        out.write('\t'.join(['Cable', 'Lexrank', 'Seq-to-seq']) + '\n')

        with open(data_file) as f:
            cables = json.load(f)

        for cable in cables[:max_cables]:
            doc = SingleDocument(0, raw={'body': cable})
            if len(doc.text()) < min_length:
                continue

            lexrank = get_lexrank_summary(doc)
            # generate_summary returns (summary, score); only the text is kept.
            seq2seq = generate_summary(doc.spacy_text())[0]

            # Collapse tabs / line breaks so each cable stays on one TSV row.
            # (`field`, not `string`, to avoid shadowing the stdlib module.)
            out.write('\t'.join([
                field.encode('utf-8').replace('\t', ' ').replace(
                    '\r', ' ').replace('\n', ' ').strip()
                for field in [cable, lexrank, seq2seq]
            ]) + '\n')
            out.flush()
예제 #4
0
def test_incorrect_inputs():
    """generate_summary must raise AssertionError for unusable input.

    Two cases are checked in order: ``None`` and a raw unicode string.
    """
    for bad_input in (None, u'Random string'):
        with raises(AssertionError):
            generate_summary(bad_input)
예제 #5
0
def test_short_input():
    """A single short phrase is returned verbatim as its own summary, scored 0."""
    phrase = u'Short phrase.'
    nlp = get_spacy()
    result_summary, result_score = generate_summary(nlp(phrase))
    assert result_summary == phrase
    assert result_score == 0.0