示例#1
0
    def inference(self, tst, output, quiet=False):
        """Run the model over a test set and write per-sentence visualizations.

        ``tst`` is iterated in minibatches of ``self.config.batch_size``. Each
        batch is fed through ``self.get_feed_dict`` (with a trailing ``0.0``
        argument) and evaluated with ``self.sess.run``. Each sentence's fetched
        values are handed to a ``Visualize`` object that writes to ``output``.

        ``self.config.mode`` selects what is fetched and printed:

        * ``"sentence"``: fetches ``self.output`` plus ``self.last_src`` /
          ``self.last_tgt``. If ``tst.annotated``, the batch is scored with
          ``Score.add_batch(out_batch, sign_batch)``. Vectors are printed
          through ``Visualize.print_vectors``.
        * any other value: fetches the alignment, source/target aggregations,
          source/target outputs, last states and cosine similarity. If
          ``tst.annotated``, both aggregations are scored with
          ``Score.add_batch_tokens``. Each sentence is rendered as SVG
          (``show_svg``), as a matrix (``show_matrix``), or as vectors filtered
          by ``show_last`` / ``show_aggr`` / ``show_align``.

        Args:
            tst: test set consumed by ``minibatches``. The summary reads its
                ``annotated``, ``nsrc``, ``ntgt``, ``nunk_src``, ``nunk_tgt``,
                ``ndiv_src`` and ``ndiv_tgt`` attributes.
            output: writable text stream given to ``Visualize``. It also
                receives the surrounding HTML tags when
                ``self.config.show_svg`` is set.
            quiet: forwarded to ``Visualize.print_vectors``.

        Side effects:
            Writes visualizations to ``output``. When ``tst.annotated`` is
            true, writes one summary line (word counts, %div, %unk, A/P/R/F
            and the confusion counts) to ``sys.stderr``.
        """

        def pct(num, den):
            # Guard against an empty source/target side of the test set;
            # the unguarded division raised ZeroDivisionError.
            return 100.0 * num / den if den else 0.0

        if self.config.show_svg:
            output.write("<html>\n<body>\n")
        score = Score()
        n_sents = 0
        for (src_batch, tgt_batch, raw_src_batch, raw_tgt_batch,
             sign_src_batch, sign_tgt_batch, sign_batch, len_src_batch,
             len_tgt_batch) in minibatches(tst, self.config.batch_size):
            # NOTE(review): the trailing 0.0 is presumably a dropout/keep
            # setting for inference — confirm against get_feed_dict.
            fd = self.get_feed_dict(src_batch, tgt_batch, sign_src_batch,
                                    sign_tgt_batch, sign_batch, len_src_batch,
                                    len_tgt_batch, 0.0)

            if self.config.mode == "sentence":
                out_batch, last_src_batch, last_tgt_batch = self.sess.run(
                    [self.output, self.last_src, self.last_tgt], feed_dict=fd)
                if tst.annotated:
                    score.add_batch(out_batch, sign_batch)
                for i_sent in range(len(out_batch)):
                    n_sents += 1
                    v = Visualize(output, n_sents, raw_src_batch[i_sent],
                                  raw_tgt_batch[i_sent], out_batch[i_sent])
                    last_src = []
                    last_tgt = []
                    if self.config.show_last:
                        last_src = last_src_batch[i_sent]
                        last_tgt = last_tgt_batch[i_sent]
                    v.print_vectors(last_src,
                                    last_tgt,
                                    aggr_src=[],
                                    aggr_tgt=[],
                                    align=[],
                                    quiet=quiet)
            else:
                (align_batch, aggr_src_batch, aggr_tgt_batch, out_src_batch,
                 out_tgt_batch, last_src_batch, last_tgt_batch,
                 sim_batch) = self.sess.run(
                     [self.align, self.aggregation_src, self.aggregation_tgt,
                      self.out_src, self.out_tgt, self.last_src,
                      self.last_tgt, self.cos_similarity],
                     feed_dict=fd)
                if tst.annotated:
                    score.add_batch_tokens(aggr_src_batch, sign_src_batch,
                                           len_src_batch)
                    score.add_batch_tokens(aggr_tgt_batch, sign_tgt_batch,
                                           len_tgt_batch)
                for i_sent in range(len(align_batch)):
                    n_sents += 1
                    v = Visualize(output, n_sents, raw_src_batch[i_sent],
                                  raw_tgt_batch[i_sent], sim_batch[i_sent])
                    if self.config.show_svg:
                        v.print_svg(aggr_src_batch[i_sent],
                                    aggr_tgt_batch[i_sent],
                                    align_batch[i_sent])
                    elif self.config.show_matrix:
                        v.print_matrix(aggr_src_batch[i_sent],
                                       aggr_tgt_batch[i_sent],
                                       align_batch[i_sent])
                    else:
                        # Each component is printed only when its config flag
                        # is set; otherwise an empty list is passed.
                        last_src = []
                        last_tgt = []
                        aggr_src = []
                        aggr_tgt = []
                        align = []
                        if self.config.show_last:
                            last_src = last_src_batch[i_sent]
                            last_tgt = last_tgt_batch[i_sent]
                        if self.config.show_aggr:
                            aggr_src = aggr_src_batch[i_sent]
                            aggr_tgt = aggr_tgt_batch[i_sent]
                        if self.config.show_align:
                            align = align_batch[i_sent]
                        v.print_vectors(last_src,
                                        last_tgt,
                                        aggr_src,
                                        aggr_tgt,
                                        align,
                                        quiet=quiet)

        if tst.annotated:
            score.update()
            unk_s = pct(tst.nunk_src, tst.nsrc)
            unk_t = pct(tst.nunk_tgt, tst.ntgt)
            div_s = pct(tst.ndiv_src, tst.nsrc)
            div_t = pct(tst.ndiv_tgt, tst.ntgt)
            sys.stderr.write(
                'TEST words={}/{} %div={:.2f}/{:.2f} %unk={:.2f}/{:.2f} (A{:.4f},P{:.4f},R{:.4f},F{:.4f})'
                ' (TP:{},TN:{},FP:{},FN:{})\n'.format(tst.nsrc, tst.ntgt,
                                                      div_s, div_t, unk_s,
                                                      unk_t, score.A, score.P,
                                                      score.R, score.F,
                                                      score.TP, score.TN,
                                                      score.FP, score.FN))

        if self.config.show_svg:
            output.write("</body>\n</html>\n")
示例#2
0
    def inference(self, tst):
        """Run the model over a test set, print visualizations and report stats.

        ``tst`` is iterated in minibatches of ``self.config.batch_size``. For
        each batch, the method fetches ``self.align``, ``self.snt_src``,
        ``self.snt_tgt``, ``self.align_src``, ``self.align_tgt`` and
        ``self.cos_similarity`` via ``self.sess.run``. Each sentence is
        rendered through ``Visualize``, either as SVG (``show_svg``), as a
        matrix (``show_matrix``), or as vectors
        (``show_sim`` / ``show_align``).

        When ``tst.annotated`` is true, every batch is added to a ``Score``:
        * If ``self.config.error == 'lse'``, the source and target alignments
          are concatenated along axis 1 and scored against the concatenated
          gold ``ali_src_batch`` / ``ali_tgt_batch``.
        * Otherwise, ``align`` is scored against ``ali_batch``.

        Args:
            tst: test set consumed by ``minibatches``. The summary reads its
                ``annotated``, ``nsrc``, ``ntgt``, ``nunk_src``, ``nunk_tgt``,
                ``nones``, ``nlnks``, ``npair``, ``nunpair``, ``ndelete``,
                ``nextend`` and ``nreplace`` attributes.

        Side effects:
            Prints visualizations (and HTML wrapper tags when ``show_svg`` is
            set) to stdout. Writes the score summary (if annotated) and the
            count of positive-similarity predictions to ``sys.stderr``.
        """

        def pct(num, den):
            # Guard against a zero denominator (e.g. an empty test set);
            # the unguarded division raised ZeroDivisionError.
            return 100.0 * num / den if den else 0.0

        # print(...) with a single argument produces identical output under
        # Python 2 and Python 3; the print statement is a Python 3 syntax error.
        if self.config.show_svg:
            print("<html>\n<body>")
        score = Score()
        n_pos = 0
        n_sents = 0

        for (src_batch, tgt_batch, ali_batch, ali_src_batch, ali_tgt_batch,
             sim_batch, raw_src_batch, raw_tgt_batch, len_src_batch,
             len_tgt_batch) in minibatches(tst, self.config.batch_size):
            # NOTE(review): the trailing 0.0 is presumably a dropout/keep
            # setting for inference — confirm against get_feed_dict.
            fd = self.get_feed_dict(src_batch, tgt_batch, ali_batch,
                                    ali_src_batch, ali_tgt_batch, sim_batch,
                                    len_src_batch, len_tgt_batch, 0.0)

            align, snt_src, snt_tgt, align_src, align_tgt, sim = self.sess.run(
                [
                    self.align, self.snt_src, self.snt_tgt, self.align_src,
                    self.align_tgt, self.cos_similarity
                ],
                feed_dict=fd)
            # Count entries of `sim` (indexed per sentence below) that are
            # strictly positive.
            n_pos += int(np.count_nonzero(np.greater(sim, 0)))

            if tst.annotated:
                if self.config.error == 'lse':
                    score.add_batch(
                        np.concatenate([align_src, align_tgt], 1),
                        np.concatenate([ali_src_batch, ali_tgt_batch], 1), sim,
                        sim_batch, 0.0, 0.0, 0.0)
                else:
                    score.add_batch(align, ali_batch, sim, sim_batch, 0.0, 0.0,
                                    0.0)

            for i_sent in range(len(align)):
                n_sents += 1
                v = Visualize(n_sents, src_batch[i_sent], tgt_batch[i_sent],
                              raw_src_batch[i_sent], raw_tgt_batch[i_sent],
                              sim[i_sent], align[i_sent], align_src[i_sent],
                              align_tgt[i_sent], snt_src[i_sent],
                              snt_tgt[i_sent], self.config.mark_unks)
                if self.config.show_svg:
                    v.print_svg()
                elif self.config.show_matrix:
                    v.print_matrix()
                else:
                    v.print_vectors(self.config.show_sim,
                                    self.config.show_align)

        if tst.annotated:
            curr_time = time.strftime("[%Y-%m-%d_%X]", time.localtime())
            sys.stderr.write('{} TEST ({})'.format(curr_time,
                                                   score.summarize()))
            unk_s = pct(tst.nunk_src, tst.nsrc)
            unk_t = pct(tst.nunk_tgt, tst.ntgt)
            sys.stderr.write(
                ' Test set: words={}/{} %ones={:.2f} pair={} unpair={} delete={} extend={} replace={} %unk={:.2f}/{:.2f}\n'
                .format(tst.nsrc, tst.ntgt, pct(tst.nones, tst.nlnks),
                        tst.npair, tst.nunpair, tst.ndelete, tst.nextend,
                        tst.nreplace, unk_s, unk_t))

        if self.config.show_svg:
            print("</body>\n</html>")
        sys.stderr.write(
            "Predicted similar {} out of {} examples {:.2f}%\n".format(
                n_pos, n_sents, pct(n_pos, n_sents)))