Example #1
0
    def run_evaluate(self, sess, type, data_path, test_case=1):
        """Evaluate Recall@k and MRR over the dataset at ``data_path``.

        Args:
            sess: active tf.Session holding the trained model graph.
            type: dataset-split identifier forwarded to DataProcess
                (name shadows the builtin ``type``; kept for caller
                compatibility).
            data_path: path to the evaluation data file.
            test_case: test-case selector forwarded to DataProcess.

        Returns:
            Tuple ``(k_list, recall_percentages, avg_mrr)`` where
            ``recall_percentages`` is a numpy array aligned with ``k_list``.
        """
        data_process = DataProcess(self.hparams,
                                   data_path,
                                   type,
                                   word2id=self.word2id,
                                   test_case=test_case)

        k_list = self.hparams.recall_k_list
        total_examples = 0
        total_correct = np.zeros([len(k_list)], dtype=np.int32)
        total_mrr = 0

        index = 0

        while True:
            batch_data = data_process.get_batch_data(
                self.hparams.dev_batch_size, 100)

            if batch_data is None:
                break

            # Only example_id is consumed below; the full tuple feeds
            # make_feed_dict, and batch_data[2] carries the targets.
            (context, _), (utterances,
                           _), _, _, _, example_id, candidates_id = batch_data

            # keep_prob = 1.0 (no dropout at evaluation time).
            pred_val, _ = sess.run([self.predictions, self.logits],
                                   feed_dict=self.make_feed_dict(
                                       batch_data, 1.0))

            pred_val = np.asarray(pred_val)
            num_correct, num_examples = evaluate_recall(
                pred_val, batch_data[2], k_list)
            total_mrr += mean_reciprocal_rank(pred_val, batch_data[2])

            total_examples += num_examples
            total_correct = np.add(total_correct, num_correct)

            # NOTE(review): index 5 assumes k_list has at least 6 entries
            # (largest k, e.g. recall@100) — confirm against hparams.
            if num_correct[5] != self.hparams.dev_batch_size:
                print(example_id, ":", index, num_correct[5])

            index += 1
            if index % 500 == 0:
                accumulated_accuracy = (total_correct / total_examples) * 100
                print("index : ", index, " | ", accumulated_accuracy)

        # Guard against an empty dataset: with zero batches the divisions
        # below would raise ZeroDivisionError.
        if index == 0 or total_examples == 0:
            self._logger.info("No evaluation batches found.")
            return k_list, np.zeros([len(k_list)]), 0.0

        # NOTE(review): denominator assumes every batch is exactly
        # dev_batch_size; a smaller final batch would bias avg_mrr low.
        avg_mrr = total_mrr / (self.hparams.dev_batch_size * index)
        recall_result = ""

        for i in range(len(k_list)):
            recall_result += "Recall@%s : " % k_list[i] + "%.2f%% | " % (
                (total_correct[i] / total_examples) * 100)
        self._logger.info(recall_result)
        self._logger.info("MRR: %.4f" % avg_mrr)

        return k_list, (total_correct / total_examples) * 100, avg_mrr
Example #2
0
    def evaluate(self, saved_file: str):
        """Restore a checkpoint and log Recall@k over the test split.

        Builds a fresh inference graph, restores weights from
        ``saved_file``, iterates the test data in batches, and logs the
        accumulated Recall@k for each k in ``k_list``.

        Args:
            saved_file: checkpoint path understood by tf.train.Saver.restore.
        """
        context = tf.placeholder(tf.int32, shape=[None, None], name="context")
        context_len = tf.placeholder(tf.int32,
                                     shape=[None],
                                     name="context_len")
        utterances = tf.placeholder(tf.int32,
                                    shape=[None, None, None],
                                    name="utterances")
        utterances_len = tf.placeholder(tf.int32,
                                        shape=[None, None],
                                        name="utterances_len")
        target = tf.placeholder(tf.int32, shape=[None], name="target")

        # logits
        with tf.variable_scope("inference", reuse=False):
            logits = self._inference(context, context_len, utterances,
                                     utterances_len)

        # Candidate indices sorted by descending score; ranks feed recall.
        predictions = argsort(logits, axis=1, direction='DESCENDING')

        k_list = [1, 2, 5, 10, 50, 100]
        total_examples = 0
        # Sized from k_list instead of a hard-coded 6 so the two stay in sync.
        total_correct = np.zeros([len(k_list)], dtype=np.int32)

        # Context manager guarantees the session is closed (was leaked before).
        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())

            saver = tf.train.Saver()
            saver.restore(sess, saved_file)

            data = DataProcess(self.hparams.valid_path, "test", self.word2id)

            while True:
                pad_batch_data = data.get_batch_data(self.hparams.batch_size)

                if pad_batch_data is None:
                    break
                (pad_context, context_len_batch), (
                    pad_utterances,
                    utterances_len_batch), target_batch = pad_batch_data

                feed_dict = {
                    context: pad_context,
                    context_len: context_len_batch,
                    utterances: pad_utterances,
                    utterances_len: utterances_len_batch,
                    target: target_batch
                }
                # Run the tensor directly — no list wrap, so no squeeze(0)
                # is needed afterwards.
                pred_val = np.asarray(sess.run(predictions,
                                               feed_dict=feed_dict))

                num_correct, num_examples = evaluate_recall(
                    pred_val, target_batch, k_list)

                total_examples += num_examples
                total_correct = np.add(total_correct, num_correct)

        # Guard against an empty test set (division by zero below).
        if total_examples == 0:
            self._logger.info("No evaluation batches found.")
            return

        recall_result = ""
        for i in range(len(k_list)):
            recall_result += "Recall@%s : " % k_list[i] + "%.2f%% | " % (
                (total_correct[i] / total_examples) * 100)
        self._logger.info(recall_result)