コード例 #1
0
    def _decode_sent(self, args):
        """Parse one sentence with the current model and score it against gold.

        Args:
            args: the sentence id to decode (taken as a single ``args`` value,
                presumably so the method can be mapped over by a worker pool;
                confirm against caller).

        Returns:
            dict with keys:
                "sentid":       the decoded sentence id.
                "parsed":       True if the parser produced a result.
                "succ":         True iff the simplified parse is semantically
                                equal to the gold logical expression (False
                                when there is no parse).
                "parse":        ``str`` of the simplified parse ("None" when
                                there is no parse).
                "match_prec":   fraction of the parse's constants that also
                                occur in ``KB.expr_unigrams[sentid]``
                                (0.0 when there is no parse / no constants).
                "match_recall": fraction of ``KB.expr_unigrams[sentid]``
                                recovered by the parse (0.0 when there is no
                                parse or the gold unigram set is empty).
        """
        sentid = args
        string = Perceptron.KB.questions[sentid]
        goldexpr, _ = Perceptron.KB.logicalexprs[sentid]

        parse_result = Perceptron.parser.parse(string,
                                               filter_func=None,
                                               verbose=Perceptron.verbose)
        parse_result = simplify_expr(
            parse_result.get_expr()) if parse_result else None

        ret = {
            "sentid": sentid,
            "parsed": parse_result is not None,
            # bool() so a missing parse reports False rather than None.
            "succ": bool(parse_result and parse_result.semantic_eq(goldexpr)),
            "parse": str(parse_result),
            "match_prec": 0.0,
            "match_recall": 0.0
        }

        if parse_result:
            # Complex-typed constants (predicates) are compared by name,
            # all other constants by their string form.
            parse_unigrams = set(
                x.name if isinstance(x.type, ComplexType) else str(x)
                for x in collect_constants(parse_result)[0])
            gold_unigrams = Perceptron.KB.expr_unigrams[sentid]
            intersection = parse_unigrams & gold_unigrams
            if len(parse_unigrams) > 0:
                ret["match_prec"] = \
                    1.0 * len(intersection) / len(parse_unigrams)
            # Guard: a gold expression with no unigrams would otherwise
            # raise ZeroDivisionError.
            if len(gold_unigrams) > 0:
                ret["match_recall"] = \
                    1.0 * len(intersection) / len(gold_unigrams)
        return ret
コード例 #2
0
def decode_sentence(kb, sentid, weightfile):
    """Decode one sentence with pickled weights and log the derivation trace.

    Builds a fresh ``Parser`` over ``kb``, installs the weights unpickled
    from ``weightfile`` as ``State.model.weights``, parses
    ``kb.questions[sentid]`` and writes to ``LOGS`` the simplified
    expression followed by every traced state and its ``extrainfo``.

    Side effects: assigns ``State.model`` and ``State.ExtraInfoGen`` and
    calls ``ExprGenerator.setup()``.

    SECURITY: ``weightfile`` is unpickled, which can execute arbitrary code;
    only pass files from a trusted source.

    Returns:
        None.
    """
    indepkb = IndepKnowledgeBase()
    model = Model()

    parser = Parser(indepkb, kb, model, State)

    State.model = model
    # Binary mode: pickle data must not go through text-mode newline
    # translation. The context manager closes the handle promptly.
    with open(weightfile, "rb") as weights_in:
        State.model.weights = pickle.load(weights_in)
    State.ExtraInfoGen = ExprGenerator
    ExprGenerator.setup()

    ret = parser.parse(kb.questions[sentid])
    print >> LOGS, "============================="
    if not ret:
        # parse() gave no result (other callers treat a falsy result as
        # "no parse"); there is nothing to simplify or trace.
        print >> LOGS, "no parse found for sentence", sentid
        return
    print >> LOGS, simplify_expr(ret.get_expr())
    print >> LOGS, "TRACING"
    for s in ret.trace_states():
        print >> LOGS, s, s.extrainfo
コード例 #3
0
    def load_hgs(self, hgs):
        """Load decoding hypergraphs into reference beams.

        For every ``(sentid, hg)`` in ``hgs`` the hypergraph is converted with
        ``self.load_ref_hg`` and stored in ``self.ref_beams[sentid]``. A
        progress dot is logged about every 10% of the sentences (no dots when
        there are fewer than 10), and the total elapsed time is logged at the
        end.

        When ``_sanity_check`` is enabled, every state in the last step of
        each loaded beam is simplified and compared with the gold logical
        expression; mismatches are logged, not raised.

        Args:
            hgs: dict mapping sentence id -> hypergraph (presumably as built
                by the forced decoder; confirm against caller).
        """
        # Floor division keeps the progress step an int even under true
        # division (e.g. ``from __future__ import division``).
        stepsize = len(hgs) // 10
        start_t = time.time()
        for j, (sentid, hg) in enumerate(hgs.iteritems()):
            if stepsize != 0 and j % stepsize == 0:
                print >> LOGS, ".",

            beam = self.load_ref_hg(sentid, hg)

            self.ref_beams[sentid] = beam

            if _sanity_check:
                # The gold expression depends only on the sentence, so look
                # it up once rather than once per final-step state.
                goldexpr = Perceptron.KB.logicalexprs[sentid][0]
                for s in beam[-1]:
                    refexpr = simplify_expr(s.extrainfo)
                    if not goldexpr.semantic_eq(refexpr):
                        print >> LOGS, "inconsistant at sent", sentid
                        print >> LOGS, "gold", goldexpr
                        print >> LOGS, "ref", refexpr
        end_t = time.time()
        print >> LOGS, "taking %fs" % (end_t - start_t)
コード例 #4
0
    def decode(self, sentid):
        """Forced-decode sentence ``sentid`` toward its gold logical form.

        The parser is restricted, through a ``ConstraintKB`` and a filter on
        partial expressions, to the constants and constant bigrams that occur
        in the gold expression. Every state in the final beam step whose
        simplified expression is semantically equal to the gold expression is
        kept as a match. The derivations of all matches are then merged into
        one hypergraph.

        Side effect: ``ForcedDecoder.parser.KB`` is replaced by the
        constrained KB built here and is not restored afterwards.

        Returns:
            ``(num_matches, hg)`` where ``hg`` maps
            ``state_id -> (action, match, ruleid, incomings)`` and
            ``incomings`` is a list of predecessor id pairs ``(n1_id, n2_id)``,
            with ``None`` standing for an absent predecessor.
        """
        string = ForcedDecoder.questions[sentid]
        # Pair: [0] is the gold expression, [1] is installed as _TS.TE below.
        goldexpr = ForcedDecoder.logicalexprs[sentid]

        if ForcedDecoder.verbose:
            print >> logs, string
            print >> logs, goldexpr

        _TS.TE = goldexpr[1]

        constant_unigrams, constant_bigrams = collect_constants(goldexpr[0])

        # and-and / or-or bigrams are always allowed (presumably because
        # nested conjunctions/disjunctions are flattened in the gold form;
        # confirm).
        constant_bigrams.add(('and', 'and'))
        constant_bigrams.add(('or', 'or'))

        if ForcedDecoder.verbose:
            print >> logs, constant_unigrams
            print >> logs, constant_bigrams

        # Non-builtin gold constants become the predicate candidates of the
        # constrained KB, sorted by name for a deterministic order.
        predicate_candidates = sorted(
            [(expr, TypeEnv()) for expr in constant_unigrams
             if expr.name not in ConstraintKB.__builtin_constant__],
            key=lambda t: t[0].name)
        constant_unigram_names = set(c.name for c in constant_unigrams)

        cKB = ConstraintKB(ForcedDecoder.KB, predicate_candidates)
        cKB.adjs = ForcedDecoder.KB.adjs
        cKB.pmi = ForcedDecoder.KB.pmi

        ForcedDecoder.parser.KB = cKB

        def filter_expr(expr):
            """Reject partial expressions using constants/bigrams not in gold."""
            if expr:
                u, b = collect_constants(expr)
                if any(c.name not in constant_unigram_names for c in u):
                    return False
                return b.issubset(constant_bigrams)
            return True

        ForcedDecoder.parser.parse(string,
                                   filter_func=filter_expr,
                                   verbose=ForcedDecoder.verbose)

        if ForcedDecoder.verbose:
            print >> logs, "========== checking last beam step =========="

        matches = []
        for candidate in ForcedDecoder.parser.beam[-1]:
            if candidate.extrainfo:
                r = simplify_expr(candidate.extrainfo)
                if r.semantic_eq(goldexpr[0]):
                    matches.append(candidate)

        if ForcedDecoder.verbose:
            print >> logs, "========== %d matches ==========" % len(matches)

            for candidate in matches:
                print >> logs, candidate.trace_states()

        # Merge the matched derivations into one hypergraph. trace_states()
        # is required to yield predecessors before successors, so every
        # predecessor id must already be present in ``hg``; anything else is
        # an inconsistent trace.
        hg = {}
        for candidate in matches:
            for s in candidate.trace_states():
                if s.state_id in hg:
                    continue  # sub-derivation shared with an earlier match
                incomings = []
                for (n1, n2) in s.incomings:
                    n1id = n1.state_id if n1 is not None else None
                    n2id = n2.state_id if n2 is not None else None
                    # Compare with None explicitly: a falsy id such as 0 must
                    # still be checked for membership.
                    if (n1id is None or n1id in hg) and \
                            (n2id is None or n2id in hg):
                        incomings.append((n1id, n2id))
                    else:
                        print >> logs, s.state_id
                        print >> logs, n1id, n2id, hg.keys()
                        print >> logs, s.trace_states()
                        assert False, \
                            "predecessor state missing from hypergraph"
                hg[s.state_id] = (s.action, s.match, s.ruleid, incomings)

        return len(matches), hg
コード例 #5
0
    def learn_one(self, args):
        """Run one max-violation perceptron step on a single sentence.

        Args:
            args: ``(sentid, update_c)``. When ``Perceptron.ncpus > 0`` the
                weight delta stored in ``Perceptron.shared_memo`` is first
                applied via ``iadd_wstep(update_weights, update_c)``.
                ``sentid == -1`` performs only that weight sync.

        Returns:
            ``(status, ret)``:
              * ``(None, ret)``  when ``sentid == -1``;
              * ``(True, ret)``  when the Viterbi parse is semantically equal
                to the gold expression;
              * ``(False, ret)`` otherwise. If a gold (reference) beam is
                available, ``ret`` also carries ``"deltafeats"`` (gold minus
                Viterbi features at the step of maximum violation) and
                ``"sentid"``.
            ``ret`` always has "c", "start_t", "end_t", "match_prec" and
            "match_recall", and has "parse" whenever a parse was attempted.

        Raises:
            AssertionError: if a gold beam exists but no step has a
                non-negative Viterbi-minus-gold score difference, or (with
                ``_sanity_check``) if recomputed scores disagree.
        """
        sentid, update_c = args

        start_t = time.time()

        if Perceptron.ncpus > 0:
            # Rewind and unpickle the latest shared weight delta.
            Perceptron.shared_memo.seek(0)
            update_weights = pickle.load(Perceptron.shared_memo)
            Perceptron.model.weights.iadd_wstep(update_weights, update_c)
            Perceptron.c += 1

        ret = {
            "c": Perceptron.c,
            "start_t": start_t,
            "match_prec": 0,
            "match_recall": 0
        }

        if Perceptron.verbose:
            print >> LOGS, "sent", sentid
        if sentid == -1:
            ret["end_t"] = time.time()
            return None, ret

        string = Perceptron.KB.questions[sentid]
        goldexpr, _ = Perceptron.KB.logicalexprs[sentid]
        if Perceptron.verbose:
            print >> LOGS, string
            print >> LOGS, goldexpr

        # Reference beam: the preloaded one if available, otherwise forced-
        # decode on the fly when enabled.
        goldbeam = None
        if sentid in Perceptron.ref_beams:
            goldbeam = Perceptron.ref_beams[sentid]
        elif Perceptron.ontheflyfd:
            Perceptron.ForcedDecoder.parser.dp = False
            Perceptron.ForcedDecoder.parser.beamwidth = Perceptron.fdbeamsize
            (fdmatches, fdhg) = Perceptron.ForcedDecoder.decode(sentid)
            if fdmatches > 0:
                goldbeam = self.load_ref_hg(sentid, fdhg)

        Perceptron.parser.beamwidth = Perceptron.beamsize

        parse_result = Perceptron.parser.parse(string,
                                               filter_func=None,
                                               verbose=False)
        parse_result = simplify_expr(
            parse_result.get_expr()) if parse_result else None

        ret["parse"] = bool(parse_result)

        if Perceptron.verbose:
            print >> LOGS, "==>", parse_result

        if parse_result and parse_result.semantic_eq(goldexpr):
            ret["end_t"] = time.time()
            ret["match_recall"] = 1.0
            ret["match_prec"] = 1.0
            return True, ret

        if not goldbeam:
            # Wrong or missing parse, but no reference derivation to learn
            # from: report without an update.
            ret["end_t"] = time.time()
            return False, ret

        # Partial predicate match of the (incorrect) Viterbi parse.
        viterbi_unigrams = collect_constants(
            parse_result)[0] if parse_result else set()
        viterbi_unigrams = set(
            x.name if isinstance(x.type, ComplexType) else str(x)
            for x in viterbi_unigrams)
        gold_unigrams = Perceptron.KB.expr_unigrams[sentid]
        unigrams_intersection = viterbi_unigrams & gold_unigrams

        ret["match_prec"] = 1.0 * len(unigrams_intersection) / len(
            viterbi_unigrams) if len(viterbi_unigrams) > 0 else 0.0
        # Guard: an empty gold unigram set would raise ZeroDivisionError.
        ret["match_recall"] = 1.0 * len(unigrams_intersection) / len(
            gold_unigrams) if len(gold_unigrams) > 0 else 0.0

        gb_ = self.rerank_beam(goldbeam)
        if Perceptron.single_gold:
            # Keep only the derivation of the top state of the last reranked
            # gold step, bucketed by step.
            gb = [[] for _ in xrange(len(gb_))]
            for item in gb_[-1][0].trace_states():
                gb[item.step].append(item)
        else:
            gb = gb_

        # Max violation: the step with the largest non-negative
        # (Viterbi top score - gold top score); ties go to the later step.
        maxstep = -1
        maxdiff = -float("inf")

        for i in xrange(len(gb)):
            if len(gb[i]) > 0 and len(Perceptron.parser.beam[i]) > 0:
                scorediff = Perceptron.parser.beam[i][0].score - gb[i][
                    0].score

                if Perceptron.verbose:
                    golditem = gb[i][0]
                    viterbiitem = Perceptron.parser.beam[i][0]
                    # First predecessor of each item; it may be None
                    # (presumably an absent predecessor), so guard ``.score``.
                    goldpred = golditem.incomings[0][0] \
                        if golditem.incomings else None
                    viterbipred = viterbiitem.incomings[0][0] \
                        if viterbiitem.incomings else None
                    print >> LOGS, "at %d: %f"%(i, scorediff), \
                        "\t\tgoldbeam len", len(gb[i]), \
                        "viterbi beam len", len(Perceptron.parser.beam[i])
                    print >> LOGS, "\tgold score", golditem.score, \
                        "actioncost", golditem.actioncost, \
                        "inside", golditem.inside, "shift", golditem.shiftcost, \
                        golditem, golditem.action, golditem.match
                    print >> LOGS, "\t\tgold.incomings", golditem.incomings, \
                        goldpred.score if goldpred is not None else None
                    print >> LOGS, "\t\tgold.left", golditem.leftptrs
                    print >> LOGS, "\t\t", golditem.get_expr()
                    print >> LOGS, "\tviterbi score", viterbiitem.score, \
                        "actioncost", viterbiitem.actioncost, \
                        "inside", viterbiitem.inside, "shift", \
                        viterbiitem.shiftcost, viterbiitem, viterbiitem.action, viterbiitem.match
                    print >> LOGS, "\t\tviterbi.incomings", viterbiitem.incomings, \
                        viterbipred.score if viterbipred is not None else None
                    print >> LOGS, "\t\tviterbi.left", viterbiitem.leftptrs
                    print >> LOGS, "\t\t", viterbiitem.get_expr()

                if scorediff >= 0 and scorediff >= maxdiff:
                    maxdiff = scorediff
                    maxstep = i

        # Explicit raise instead of ``assert``: under ``python -O`` an assert
        # is stripped and maxstep == -1 would silently select the LAST step.
        if maxstep == -1:
            raise AssertionError("max violation not found")

        viterbistate = Perceptron.parser.beam[maxstep][0]
        goldstate = gb[maxstep][0]

        viterbifeats = viterbistate.recover_feats()
        goldfeats = goldstate.recover_feats()

        # Update direction: gold features minus Viterbi features (iaddc with
        # coefficient -1, called on goldfeats).
        deltafeats = goldfeats.iaddc(viterbifeats, -1)

        if _sanity_check:
            viterbiscore = Perceptron.model.weights.dot(viterbifeats)
            if abs(viterbiscore - viterbistate.score) > State.epsilon:
                print >> LOGS, "wrong viterbi score", viterbiscore, viterbistate.score, str(
                    viterbistate)
                print >> LOGS, viterbifeats
                print >> LOGS, State.model.eval_module.static_eval(
                    *(viterbistate.get_atomic_feats()))
                assert False
            # w . (gold - viterbi) should equal -(viterbi - gold) = -maxdiff.
            scorediff_ = Perceptron.model.weights.dot(deltafeats)
            if abs(scorediff_ + maxdiff) > State.epsilon:
                print >> LOGS, "wrong max violation", scorediff_, maxdiff
                assert False

        ret["deltafeats"] = deltafeats
        ret["sentid"] = sentid

        ret["end_t"] = time.time()

        return False, ret