Example #1
0
 def test_read_ref_trans_counts(self):
     """
     Check a few lemma counts read from the sample de-en reference file.
     """
     ref_path = config["test_data_dir"] + "/lemma_sample_out_de-en.ref"
     counts = read_ref_trans_counts(ref_path)
     # all expectations concern the entry stored under "test" -> '1'
     segment_counts = counts["test"]['1']
     expected = (
         ("the", 6),
         ("union", 5),
         ("side", 1),
     )
     for lemma, expected_count in expected:
         assert segment_counts[lemma] == expected_count
Example #2
0
def accuracy_score(graphs, ref_fname, score_attr):
    """
    Compute approximate accuracy score
    
    Parameters
    ----------
    graphs: list of TransGraph instances
    ref_fname: str
        name of file containing lemmatized reference translation 
        in mteval format
    score_attr: str
        scoring attribute on edge (normally a model score)
        
    Returns
    -------
    Accuracy(correct, incorrect, ignored, score): named tuple
        score is correct / (correct + incorrect), or 0.0 if no source
        lemma had a scored translation
        
    Notes
    -----
    The score is approximate, because there is no alignment between source
    and target lemmas, so we cannot be sure what the correct translation of a
    lemma is. However, if the predicted translation occurs anywhere in the
    reference translations of the sentence, we can guess that it is a correct
    translation. In practice, this works reasonably well for content words,
    which tend to occur just once in a sentence. Not so for function words
    like articles or pronouns, which are likely to occur multiple times in
    the same sentence. False positives are thus to be expected.
    
    If none of the translation edges for a source lemma contains the score
    attribute, it is assumed that there is no model/prediction for it,
    and it is ignored for the purpose of calculating accuracy.
    
    Graphs and reference segments are paired positionally with zip, so if
    their numbers differ the surplus items are silently left out.
    """
    ref_trans_counts = read_ref_trans_counts(ref_fname, flatten=True)
    correct, incorrect, ignored = 0, 0, 0
    
    for graph, lemma_counts in zip(graphs, ref_trans_counts):
        log.debug(graph)
        log.debug("source lemmas: {}".format(graph.source_lemmas()))
        log.debug("reference lemma counts: {}".format(lemma_counts))
        
        for u in graph.source_nodes_iter():
            log.debug(u"checking source node {!r} with lemma {!r}".format(
                u, graph.lemma(u)))
            score, v = graph.max_score(u, score_attr)
            # NOTE(review): a falsy best node is treated as "no scored
            # translation"; confirm max_score never returns a falsy node id
            if v:
                target_lemma = graph.lemma(v)
                log.debug("  best translation is node {!r} with lemma {!r} "
                          "({}={:.3f})".format(
                              v, target_lemma, score_attr, score))
                # the lemma is lowercased before lookup in the reference counts
                if target_lemma.lower() in lemma_counts:
                    correct += 1
                    log.debug("    which is correct :-)")
                else:
                    incorrect += 1
                    log.debug("    which is NOT correct :-(")
            else:
                ignored += 1
                log.debug("  none of its translation edges have score "
                          "attribute {!r}".format(score_attr))
                
    # ignored lemmas do not take part in the accuracy calculation
    evaluated = correct + incorrect
    if evaluated:
        score = correct / float(evaluated)
    else:
        # Logger.warn is a deprecated alias; use Logger.warning
        log.warning("zero correct and zero incorrect; assuming zero accuracy")
        score = 0.0
        
    result = Accuracy(correct, incorrect, ignored, score)
    log.info(result)
    return result
                



    
Example #3
0
 def __call__(self, obj, *args, **kwargs):
     """
     Reset the reference count iterator, then delegate to Scorer.__call__.
     """
     # Source and reference documents are assumed to come in the same
     # order (by docid), so the counts can simply be consumed in sequence.
     flat_counts = read_ref_trans_counts(self.ref_fname, flatten=True)
     self.counts = iter(flat_counts)
     Scorer.__call__(self, obj, *args, **kwargs)
Example #4
0
def trans_diff(inf, score_attrs, ref_fname=None, colwidth=32, outf=None):
    """
    Report translation differences
    
    Outputs all cases where translations differ when selected on score_attr.
    If reference translations are provided, it also shows the reference
    translation sentences as well as a guess of the reference lemma(s) per
    source lemma.
    
    Parameters
    ----------
    inf: list or str
        list of TransGraph instances or filename of pickled graphs
    score_attrs: list of strings
        list of scoring attributes (at least two)
    ref_fname: str
        filename of reference translations in mteval xml format
    colwidth: int
        column width
    outf: file or str
        file or filename for output; defaults to a utf-8 writer around
        sys.stdout, looked up at call time. A file opened here from a
        filename is closed on return; a file object passed in is left open.
        
    Notes
    -----
    Does not support multi-word expressions
    
    A translation without a score is reported as "__NONE__" with "n/a" in
    place of the score.
    """
    assert len(score_attrs) > 1
    
    if isinstance(inf, basestring):
        # NOTE: unpickling can execute arbitrary code; only load trusted files.
        # NOTE(review): the file is opened in the default (text) mode as
        # before; confirm this matches how the pickle was written.
        with open(inf) as pickle_file:
            inf = cPickle.load(pickle_file)
            
    if outf is None:
        # resolved per call, so a redirected sys.stdout is honoured
        outf = codecs.getwriter('utf8')(sys.stdout)
        
    # only close the output if this function opened it
    opened_here = isinstance(outf, basestring)
    if opened_here:
        outf = codecs.open(outf, "w", encoding="utf-8")
        
    try:
        no_cols = 1 + len(score_attrs)
        
        if ref_fname:
            ref_trans = read_ref_trans(ref_fname, flatten=True)
            ref_counts = read_ref_trans_counts(ref_fname, flatten=True)
            no_cols += 1
            
        # without references no translation can be marked as matching;
        # with references this is replaced per source node below
        ref_lemmas = set()
        
        bar = no_cols * colwidth * u"=" + u"\n"
        subbar = no_cols * colwidth * u"-" + u"\n"
        
        for i, graph in enumerate(inf):
            diffs = graph_trans_diff(graph, score_attrs)
            if not diffs:
                continue
            
            # segment header
            outf.write(bar)
            outf.write(u"SEGMENT {} (id={})\n".format(graph.graph.get("n"),
                                                      graph.graph.get("id")))
            outf.write(bar + u"\n")
            outf.write(u"SRC:   {}\n".format(
                graph.source_string()))
            if ref_fname:
                for ref_sentence in ref_trans[i]:
                    outf.write(u"REF:   {}\n".format(ref_sentence))
                    
            # column headers
            outf.write(u"\n")
            outf.write(u"SRC LEMPOS:".ljust(colwidth))
            for attr in score_attrs:
                outf.write((attr.upper() + ":").ljust(colwidth))
            if ref_fname:
                outf.write(u"REF TRANS:".ljust(colwidth))
            outf.write(u"\n" + subbar)
            
            # one row per source node whose best translations differ
            for source_node, max_scores in diffs.iteritems():
                if ref_fname:
                    ref_lemmas = get_ref_lemmas(graph, source_node,
                                                ref_counts[i])
                    
                outf.write(graph.lempos(source_node).ljust(colwidth))
                
                for score, target_node in max_scores:
                    if score is not None:
                        target_lemma = graph.lemma(target_node)
                        score_str = u"{:.4f}".format(score)
                    else:
                        # a missing score cannot be formatted as a float
                        target_lemma = u"__NONE__"
                        score_str = u"n/a"
                        
                    pair = u"{}: {}: {}".format(
                        "+" if target_lemma in ref_lemmas else "-",
                        score_str,
                        target_lemma)
                    outf.write(pair.ljust(colwidth))
                    
                if ref_fname:
                    outf.write(", ".join(ref_lemmas or ["---"]))
                    
                outf.write(u"\n")
            outf.write(u"\n")
    finally:
        if opened_here:
            outf.close()