コード例 #1
0
ファイル: train_manager.py プロジェクト: srush/tf-fork
    def decode(self, forest, early_stop=False):
        """Beam-search decode ``forest`` and compare the result with the oracle.

        The oracle translation and feature vector come from
        ``forest.compute_oracle`` called with an empty ``Vector()``,
        ``model_weight=0.0`` and ``bleu_weight=1.0``.  The model hypothesis
        comes from ``self.decoder.beam_search`` with beam size ``FLAGS.beam``.
        Both translations, their BLEU rescoring results and their model
        scores (``self.weights.dot``) are reported on stderr.

        Args:
            forest: object providing ``compute_oracle`` and a ``bleu`` scorer
                with a ``rescore`` method.
            early_stop: accepted for interface compatibility; this method
                never reads it.

        Returns:
            The pair ``(forest.bleu, delta_feats)`` where ``delta_feats`` is
            ``self.trim(-oracle_fv + hyp_fv)``.
        """
        # Only the translation and feature vector of the oracle are needed;
        # its BLEU is recomputed below through forest.bleu.rescore.
        _, oracle_trans, oracle_fv, _ = forest.compute_oracle(
            Vector(), model_weight=0.0, bleu_weight=1.0)

        _, hyp_trans, hyp_fv = self.decoder.beam_search(forest, b=FLAGS.beam)

        # The result of this call is not used here (it used to feed a sanity
        # assert against hyp_fv["lm"], now disabled).  The call itself is
        # kept because its side effects, if any, are not visible from here.
        self._add_back_language_model(hyp_trans)

        # Overwrite the oracle's "lm" feature with the added-back LM value.
        oracle_fv["lm"] = self._add_back_language_model(oracle_trans)

        hyp_bleu = forest.bleu.rescore(hyp_trans)
        oracle_bleu = forest.bleu.rescore(oracle_trans)

        print >> sys.stderr, "Test: %s \n Oracle: %s" % (hyp_trans, oracle_trans)
        print >> sys.stderr, "BLEU Test: %s | Oracle: %s" % (hyp_bleu, oracle_bleu)

        hyp_model_score = self.weights.dot(hyp_fv)
        oracle_model_score = self.weights.dot(oracle_fv)
        print >> sys.stderr, "MODEL Test: %s | Oracle: %s" % (hyp_model_score, oracle_model_score)

        # Hypothesis features minus oracle features, trimmed.  An extra
        # self.clip(...) step around the trim is currently disabled.
        negated_oracle = -oracle_fv
        delta_feats = self.trim(negated_oracle + hyp_fv)

        # The last rescore before returning is on the hypothesis, not the
        # oracle.  NOTE(review): presumably rescore updates state held by
        # forest.bleu, so the returned scorer reflects the hypothesis -- confirm.
        forest.bleu.rescore(hyp_trans)
        return (forest.bleu, delta_feats)
コード例 #2
0
ファイル: train_manager.py プロジェクト: srush/tf-fork
    def decode(self, forest, verbose=False):
        # decoding
        # def non_local_scorer(cedge, ders):
        #       hyp = cedge.assemble(ders)
        #       return ((0.0, Vector()),  hyp, hyp)
        #     decoder = CubePruning(MarginalDecoder.FeatureAdder(self.weights), non_local_scorer, 20, 5, find_min=False)
        #     best = decoder.run(forest.root)

        #     for i in range(min(10, len(best))):
        #       print "  Best   Trans: %s"%best[i].full_derivation
        #       forest.bleu.rescore(best[i].full_derivation)
        #       print "  Best BLEU   Score: %s"% (forest.bleu.score_ratio_str())
        #       print "  Best Score: %s"% (best[i].score[0])

        #     test_bleu = forest.bleu.rescore((best[0].full_derivation))
        # oracle_forest, oracle_item = oracle.oracle_extracter(forest, self.weights, None, 5, 2, extract=1)
        (score, subtree, fv) = forest.bestparse(self.weights, use_min=False)
        print "textlenght", fv["Basic/text-length"], self.weights["Basic/text-length"], score
        if verbose:

            print "Ref Tran %s" % forest.refs
            (oracle_score, oracle_subtree, oracle_fv, _) = forest.compute_oracle(self.weights)
            print "Oracle: " + oracle_subtree

            # oracle_forest, oracle_item = oracle.oracle_extracter(forest, self.weights, None, 100, 2, extract=100)
            # (oracle_forest_score , oracle_forest_subtree, oracle_forest_fv)  = oracle_forest.bestparse(self.weights, use_min=False)

            # example_marg, partition  = fast_inside_outside.collect_marginals(forest, self.weights)
            #       oracle_marg, oracle_partition  = fast_inside_outside.collect_marginals(oracle_forest, self.weights)
            #       print "Oracle forest likelihood: ",oracle_partition - partition

            #       print "Oracle Forest Score: ", oracle_forest_score
            #       print "Oracle Forest Results: ", oracle_forest_subtree
            #       print "Oracle Forest Bleu: ", forest.bleu.rescore(oracle_forest_subtree)
            #       parse_diff(self.weights,fv, oracle_forest_fv)

            print "Best: ", subtree
            print "Best Score: ", (score, self.weights.dot(oracle_fv))
            print "Local Score: ", forest.bleu.rescore(subtree)
            # (score, best_trans, best_fv) = self.decoder.beam_search(forest, b=FLAGS.beam)

            print "\n\n\n"
        # print delta_feats
        forest.bleu.rescore(subtree)
        return (forest.bleu, None)