def test(self, sess, test_feed, num_batch=None, repeat=1, dest=sys.stdout):
        """Evaluate on ``test_feed``: log samples, report BLEU and intent accuracy.

        For each batch the model is run with ``use_prior=False`` and
        ``repeat`` samples per input. Each sample's label vector ``vec`` is
        projected through ``test_feed.sim`` (``np.matmul(vec, sim)``) and the
        argmax is taken as the predicted unseen-intent id. An example counts
        as correct when ANY of its ``repeat`` samples predicts the true label.

        After the pass, if any label vectors were collected, ``test_feed.sim``
        is overwritten with a matrix built from the mean label vector of each
        true unseen intent: entry ``[i, k]`` is intent ``k``'s share of the
        summed means at position ``i``. NaN/inf (empty classes, zero sums)
        become 0.

        Args:
            sess: object whose ``run(fetches, feed_dict)`` evaluates the graph
                (presumably a ``tf.Session``).
            test_feed: data feed providing ``next_batch()``, ``sim``,
                ``most_similarity``, ``batch_size``, ``num_batch`` and ``ptr``.
            num_batch: maximum number of batches to evaluate; ``None`` runs
                until ``next_batch()`` returns ``None``.
            repeat: number of samples decoded per input.
            dest: writable stream receiving the per-example log and reports.

        Returns:
            ``(latent_z, output_labels)``: the ``self.z`` values and the
            ``self.labels`` feed entries accumulated over all batches.
        """
        local_t = 0
        recall_bleus = []
        prec_bleus = []

        latent_z = []
        output_labels = []

        sim = test_feed.sim
        print('################################\n', sim)

        total = 0
        precision_count = 0
        # Label vectors collected per true unseen-intent id.
        clf = {v: [] for v in self.rev_unseen_intent.values()}
        # Token ids excluded from the reference response.
        skip_true_ids = (0, self.eos_id, self.go_id)

        report_pred_label = []
        report_true_label = []
        # Check the budget before fetching so exactly ``num_batch`` batches
        # are evaluated (``local_t > num_batch`` ran one extra batch and then
        # fetched and discarded another).
        while num_batch is None or local_t < num_batch:
            batch = test_feed.next_batch()
            if batch is None:
                break
            total += len(batch[1])
            feed_dict = self.batch_2_feed(
                batch,
                None,
                use_prior=False,
                repeat=repeat,
                most_similarity=test_feed.most_similarity)
            word_outs, label_prob, z = sess.run(
                [self.dec_out_words, self.my_label_prob, self.z], feed_dict)
            # One chunk per sample along axis 0.
            sample_words = np.split(word_outs, repeat, axis=0)
            sample_label = np.split(label_prob, repeat, axis=0)

            latent_z.extend(z)
            output_labels.extend(feed_dict[self.labels])

            true_outs = feed_dict[self.io_tokens]
            true_labels = feed_dict[self.labels]
            local_t += 1

            if dest is not sys.stdout:
                # Integer step of ~10% of the epoch; max(1, ...) avoids a
                # zero or fractional modulus when there are < 10 batches.
                report_every = max(1, test_feed.num_batch // 10)
                if local_t % report_every == 0:
                    progress = "%.2f >> " % (
                        test_feed.ptr / float(test_feed.num_batch))
                    print(progress)
                    dest.write(progress)

            for b_id in range(test_feed.batch_size):
                true_label = true_labels[b_id]
                dest.write("Batch %d index %d \n" % (local_t, b_id))
                # print the true outputs
                true_tokens = [
                    self.vocab[e] for e in true_outs[b_id].tolist()
                    if e not in skip_true_ids
                ]
                true_str = " ".join(true_tokens).replace(" ' ", "'")
                label_str = self.unseen_intent[true_label]
                dest.write("Target (%s) >> %s\n" % (label_str, true_str))
                # print the predicted outputs
                local_tokens = []
                hit = False
                for r_id in range(repeat):
                    pred_outs = sample_words[r_id]
                    vec = sample_label[r_id][b_id]
                    # Map the label vector onto unseen intents through ``sim``.
                    pred_label = np.argmax(np.matmul(vec, sim))
                    if pred_label == true_label:
                        hit = True
                    clf[true_label].append(vec)
                    pred_tokens = [
                        self.vocab[e] for e in pred_outs[b_id].tolist()
                        if e != self.eos_id and e != 0
                    ]
                    pred_str = " ".join(pred_tokens).replace(" ' ", "'")
                    dest.write(
                        "Sample %d (%s) >> %s\n" %
                        (r_id, self.unseen_intent[pred_label], pred_str))
                    local_tokens.append(pred_tokens)
                # Record the true label next to its prediction so both lists
                # stay aligned one-to-one for ``classification_report``.
                report_true_label.append(true_label)
                if hit:
                    precision_count += 1
                    report_pred_label.append(true_label)
                else:
                    # No sample hit: report the last sample's prediction.
                    report_pred_label.append(pred_label)

                max_bleu, avg_bleu = utils.get_bleu_stats(
                    true_tokens, local_tokens)
                recall_bleus.append(max_bleu)
                prec_bleus.append(avg_bleu)
                # make a new line for better readability
                dest.write("\n")

        # The most easily misclassified: mean label vector per true intent.
        # Classes with no vectors get an all-NaN mean (as np.mean of an empty
        # list would), shaped like the collected vectors.
        collected = [vecs for vecs in clf.values() if vecs]
        template = np.shape(collected[0][0]) if collected else ()
        count = {
            k: (np.mean(v, axis=0) if v else np.full(template, np.nan)).tolist()
            for k, v in clf.items()
        }
        print(count)
        dest.write(str(count) + '\n')
        if collected:
            # Rows ordered by intent id. This assumes the ids are 0..K-1, so
            # column k of the new ``sim`` corresponds to label k (as the old
            # hard-coded count[0]/count[1] version did) -- TODO confirm.
            means = np.array([count[k] for k in sorted(count)], dtype=float)
            with np.errstate(divide='ignore', invalid='ignore'):
                c = (means / means.sum(axis=0)).transpose()
            c[np.isnan(c)] = 0
            c[np.isinf(c)] = 0
            test_feed.sim = c

        avg_recall_bleu = float(np.mean(recall_bleus))
        avg_prec_bleu = float(np.mean(prec_bleus))
        avg_f1 = 2 * (avg_prec_bleu * avg_recall_bleu) / (
            avg_prec_bleu + avg_recall_bleu + 10e-12)
        report = "Avg recall BLEU %f, avg precision BLEU %f and F1 %f (only 1 reference response. Not final result)" \
                 % (avg_recall_bleu, avg_prec_bleu, avg_f1)
        print(report)
        dest.write(report + "\n")
        # float() keeps true division under Python 2; guard an empty test set.
        precision_rate = float(precision_count) / total if total else 0.0
        dest.write("total sample " + str(total) + ", correct sample " +
                   str(precision_count) + " precision rate is " +
                   str(precision_rate) + "\n")
        result = classification_report(report_true_label,
                                       report_pred_label,
                                       digits=6)
        dest.write(result + '\n')
        print("Done testing")

        return latent_z, output_labels
# Example #2
    def test(self,
             sess,
             test_feed,
             num_batch=None,
             repeat=5,
             dest=sys.stdout):  # TODO: repeat
        """Sample ``repeat`` responses per test example, log them, report BLEU.

        For each example, ``dest`` receives the following:
          * the context turns from index ``max(0, context_len - 5)`` up to
            the padded context width;
          * the reference response;
          * every sampled response.
        ``utils.get_bleu_stats`` is applied to the reference and the samples.
        Its first value is averaged as "recall" and its second as
        "precision". A summary line is printed and written at the end.

        Args:
            sess: object whose ``run(fetches, feed_dict)`` evaluates the graph
                (presumably a ``tf.Session``).
            test_feed: data feed providing ``next_batch()``, ``batch_size``,
                ``num_batch`` and ``ptr``.
            num_batch: maximum number of batches to evaluate; ``None`` runs
                until ``next_batch()`` returns ``None``.
            repeat: number of samples decoded per input.
            dest: writable stream receiving the per-example log and report.
        """
        local_t = 0
        recall_bleus = []
        prec_bleus = []
        # Token ids excluded from the reference response.
        skip_true_ids = (0, self.eos_id, self.go_id)

        # Check the budget before fetching so exactly ``num_batch`` batches
        # are evaluated and no batch is fetched only to be discarded.
        while num_batch is None or local_t < num_batch:
            batch = test_feed.next_batch()
            if batch is None:
                break
            feed_dict = self.batch_2_feed(batch, None, repeat=repeat)
            # NOTE when testing, this is where we get the predictions
            word_outs = np.array(sess.run(self.dec_out_words, feed_dict))

            # One chunk per sample along axis 0.
            # NOTE(review): np.split needs axis 0 divisible by ``repeat``; an
            # earlier debug run reported shape (1, 5, 8) here -- confirm.
            sample_words = np.split(word_outs, repeat, axis=0)

            # lists of true answers
            true_floor = feed_dict[self.floors]
            true_srcs = feed_dict[self.input_contexts]
            true_src_lens = feed_dict[self.context_lens]
            true_outs = feed_dict[self.output_tokens]
            local_t += 1

            if dest is not sys.stdout:
                # Integer step of ~10% of the epoch; max(1, ...) avoids a
                # zero modulus when there are fewer than 10 batches.
                report_every = max(1, test_feed.num_batch // 10)
                if local_t % report_every == 0:
                    # Progress marker without a newline (Py2/Py3 neutral).
                    sys.stdout.write(
                        "%.2f >> " %
                        (test_feed.ptr / float(test_feed.num_batch)))

            for b_id in range(test_feed.batch_size):
                # print the real/true dialog context
                dest.write("Batch %d index %d " % (local_t, b_id))
                start = np.maximum(0, true_src_lens[b_id] - 5)
                for t_id in range(start, true_srcs.shape[1], 1):
                    src_str = " ".join([
                        self.vocab[e] for e in true_srcs[b_id, t_id].tolist()
                        if e != 0
                    ])
                    dest.write("Src %d-%d: %s\n" %
                               (t_id, true_floor[b_id, t_id], src_str))

                # print the true outputs
                true_tokens = [
                    self.vocab[e] for e in true_outs[b_id].tolist()
                    if e not in skip_true_ids
                ]
                true_str = " ".join(true_tokens).replace(" ' ", "'")

                # print the predicted outputs
                dest.write("Target >> %s\n" % (true_str))
                local_tokens = []
                for r_id in range(repeat):
                    pred_outs = sample_words[r_id]

                    pred_tokens = [
                        self.vocab[e] for e in pred_outs[b_id].tolist()
                        if e != self.eos_id and e != 0
                    ]
                    pred_str = " ".join(pred_tokens).replace(" ' ", "'")
                    dest.write("Sample %d >> %s\n" % (r_id, pred_str))
                    local_tokens.append(pred_tokens)

                max_bleu, avg_bleu = utils.get_bleu_stats(
                    true_tokens, local_tokens)
                recall_bleus.append(max_bleu)
                prec_bleus.append(avg_bleu)
                # make a new line for better readability
                dest.write("\n")

        avg_recall_bleu = float(np.mean(recall_bleus))
        avg_prec_bleu = float(np.mean(prec_bleus))
        avg_f1 = 2 * (avg_prec_bleu * avg_recall_bleu) / (
            avg_prec_bleu + avg_recall_bleu + 10e-12)
        report = "Avg recall BLEU %f, avg precision BLEU %f and F1 %f (only 1 reference response. Not final result)" \
           % (avg_recall_bleu, avg_prec_bleu, avg_f1)
        print(report)
        dest.write(report + "\n")
        print("Done testing")
# Example #3
    def test(self, sess, test_feed, num_batch=None, repeat=5, dest=sys.stdout):
        """Sample responses from the prior, log them with topics, report BLEU.

        Each batch is fed with ``use_prior=True`` and ``repeat`` samples per
        input. For each example, ``dest`` receives the following:
          * every context turn up to its length;
          * the reference response with its true topic;
          * each sampled response with that sample's predicted topic (argmax
            of ``self.topic_logits`` for this example).
        ``utils.get_bleu_stats`` is applied to the reference and the samples.
        Its first value is averaged as "recall" and its second as
        "precision". A summary line is printed and written at the end.

        Args:
            sess: object whose ``run(fetches, feed_dict)`` evaluates the graph
                (presumably a ``tf.Session``).
            test_feed: data feed providing ``next_batch()``, ``batch_size``,
                ``num_batch`` and ``ptr``.
            num_batch: maximum number of batches to evaluate; ``None`` runs
                until ``next_batch()`` returns ``None``.
            repeat: number of samples decoded per input.
            dest: writable stream receiving the per-example log and report.
        """
        local_t = 0
        recall_bleus = []
        prec_bleus = []
        # Token ids excluded from the reference response.
        skip_true_ids = (0, self.eos_id, self.go_id)

        # Check the budget before fetching so exactly ``num_batch`` batches
        # are evaluated and no batch is fetched only to be discarded.
        while num_batch is None or local_t < num_batch:
            batch = test_feed.next_batch()
            if batch is None:
                break
            feed_dict = self.batch_2_feed(batch,
                                          None,
                                          use_prior=True,
                                          repeat=repeat)
            word_outs, topic_logits = sess.run(
                [self.dec_out_words, self.topic_logits], feed_dict)
            # One chunk per sample along axis 0.
            sample_words = np.split(word_outs, repeat, axis=0)
            sample_topic = np.split(topic_logits, repeat, axis=0)

            true_srcs = feed_dict[self.input_contexts]
            true_src_lens = feed_dict[self.context_lens]
            true_outs = feed_dict[self.output_tokens]
            true_topics = feed_dict[self.output_topics]
            local_t += 1

            if dest is not sys.stdout:
                # Integer step of ~10% of the epoch; max(1, ...) avoids a
                # zero or fractional modulus when there are < 10 batches.
                report_every = max(1, test_feed.num_batch // 10)
                if local_t % report_every == 0:
                    # Progress marker without a newline (Py2/Py3 neutral).
                    sys.stdout.write(
                        "%.2f >> " %
                        (test_feed.ptr / float(test_feed.num_batch)))

            for b_id in range(test_feed.batch_size):
                # print the dialog context
                dest.write("Batch %d index %d\n" % (local_t, b_id))
                for t_id in range(0, true_src_lens[b_id], 1):
                    src_str = " ".join([
                        self.vocab[e] for e in true_srcs[b_id, t_id].tolist()
                        if e != 0
                    ])
                    dest.write("Src %d: %s\n" % (t_id, src_str))
                # print the true outputs
                true_tokens = [
                    self.vocab[e] for e in true_outs[b_id].tolist()
                    if e not in skip_true_ids
                ]
                true_str = " ".join(true_tokens).replace(" ' ", "'")
                topic_str = self.topic_vocab[true_topics[b_id]]
                # print the predicted outputs
                dest.write("Target (%s) >> %s\n" % (topic_str, true_str))
                local_tokens = []
                for r_id in range(repeat):
                    pred_outs = sample_words[r_id]
                    # Topic predicted for THIS example; indexing with [0]
                    # gave every example the first example's topic.
                    pred_topic = np.argmax(sample_topic[r_id], axis=1)[b_id]
                    pred_tokens = [
                        self.vocab[e] for e in pred_outs[b_id].tolist()
                        if e != self.eos_id and e != 0
                    ]
                    pred_str = " ".join(pred_tokens).replace(" ' ", "'")
                    dest.write("Sample %d (%s) >> %s\n" %
                               (r_id, self.topic_vocab[pred_topic], pred_str))
                    local_tokens.append(pred_tokens)

                max_bleu, avg_bleu = utils.get_bleu_stats(
                    true_tokens, local_tokens)
                recall_bleus.append(max_bleu)
                prec_bleus.append(avg_bleu)
                # make a new line for better readability
                dest.write("\n")

        avg_recall_bleu = float(np.mean(recall_bleus))
        avg_prec_bleu = float(np.mean(prec_bleus))
        avg_f1 = 2 * (avg_prec_bleu * avg_recall_bleu) / (
            avg_prec_bleu + avg_recall_bleu + 10e-12)
        report = "Avg recall BLEU %f, avg precision BLEU %f and F1 %f (only 1 reference response. Not final result)" \
                 % (avg_recall_bleu, avg_prec_bleu, avg_f1)
        print(report)
        dest.write(report + "\n")
        print("Done testing")