Example #1
    def run_eval(self, filename, num_ngs):
        """Evaluate the given file and returns some evaluation metrics.

        Args:
            filename (str): A file name that will be evaluated.
            num_ngs (int): The number of negative samples per positive instance.

        Returns:
            dict: A dictionary that contains evaluation metrics.
        """

        load_sess = self.sess
        preds = []
        labels = []
        group_preds = []
        group_labels = []
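        # Each consecutive block of (num_ngs + 1) rows forms one evaluation group:
        # one positive instance plus its num_ngs negatives.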
        group = num_ngs + 1

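        # Score each mini-batch and accumulate both flat (point-wise) and
        # grouped (pair-wise) predictions and labels.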
        for batch_data_input in self.iterator.load_data_from_file(
                filename, min_seq_length=self.min_seq_length, batch_num_ngs=0):
            if batch_data_input:
                step_pred, step_labels = self.eval(load_sess, batch_data_input)
                preds.extend(np.reshape(step_pred, -1))
                labels.extend(np.reshape(step_labels, -1))
                group_preds.extend(np.reshape(step_pred, (-1, group)))
                group_labels.extend(np.reshape(step_labels, (-1, group)))

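        # Point-wise metrics are computed over all instances, pair-wise metrics
        # within each group.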
        res = cal_metric(labels, preds, self.hparams.metrics)
        res_pairwise = cal_metric(group_labels, group_preds,
                                  self.hparams.pairwise_metrics)
        res.update(res_pairwise)
        return res
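
The excerpt does not include cal_metric itself. As a rough illustration only, the following self-contained NumPy sketch (the hit_at_1 function and the sample data are hypothetical, not part of the library) shows the grouping idea behind the reshape above: flat scores are cut into blocks of num_ngs + 1 and a simple group-wise hit rate is computed.

    import numpy as np

    # Illustrative stand-in only: regroup flat scores into (num_ngs + 1)-sized
    # blocks, mirroring the reshape in run_eval above, then compute a hit rate.
    def hit_at_1(labels, preds, num_ngs):
        group = num_ngs + 1
        group_labels = np.reshape(labels, (-1, group))
        group_preds = np.reshape(preds, (-1, group))
        # A group counts as a hit when its highest-scored item is a true positive.
        top = np.argmax(group_preds, axis=1)
        return float(np.mean(group_labels[np.arange(len(top)), top]))

    # Two groups, each with one positive followed by two negatives (num_ngs=2).
    labels = [1, 0, 0, 1, 0, 0]
    preds = [0.9, 0.2, 0.1, 0.3, 0.8, 0.4]
    print(hit_at_1(labels, preds, num_ngs=2))  # 0.5: only the first group ranks its positive first
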
Example #2
    def run_eval(self, filename):
        """Evaluate the given file and returns some evaluation metrics.

        Args:
            filename (str): A file name that will be evaluated.

        Returns:
            dict: A dictionary that contains evaluation metrics.
        """
        load_sess = self.sess
        preds = []
        labels = []
        imp_indexs = []
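        # imp_indexs records the impression index of every prediction so the
        # results can later be regrouped per impression.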
        for batch_data_input, imp_index, data_size in self.iterator.load_data_from_file(
                filename):
            step_pred, step_labels = self.eval(load_sess, batch_data_input)
            preds.extend(np.reshape(step_pred, -1))
            labels.extend(np.reshape(step_labels, -1))
            imp_indexs.extend(np.reshape(imp_index, -1))
        res = cal_metric(labels, preds, self.hparams.metrics)
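        # If pair-wise metrics are configured, group labels and predictions by
        # impression index and score each group.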
        if "pairwise_metrics" in self.hparams.values():
            group_labels, group_preds = self.group_labels(
                labels, preds, imp_indexs)
            res_pairwise = cal_metric(group_labels, group_preds,
                                      self.hparams.pairwise_metrics)
            res.update(res_pairwise)
        return res
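
The group_labels helper called above is not shown in the excerpt. The sketch below is a self-contained, hypothetical stand-in for its assumed behaviour (bucketing flat labels and predictions by impression index); it is not the library's exact implementation.

    from collections import defaultdict

    # Hypothetical stand-in for the group_labels helper: bucket flat labels and
    # predictions by impression index so pair-wise metrics can be computed per
    # impression.
    def group_labels(labels, preds, group_keys):
        labels_by_key = defaultdict(list)
        preds_by_key = defaultdict(list)
        for label, pred, key in zip(labels, preds, group_keys):
            labels_by_key[key].append(label)
            preds_by_key[key].append(pred)
        keys = list(labels_by_key)
        return [labels_by_key[k] for k in keys], [preds_by_key[k] for k in keys]

    labels = [1, 0, 0, 1]
    preds = [0.9, 0.3, 0.2, 0.7]
    imp_indexs = [0, 0, 1, 1]
    print(group_labels(labels, preds, imp_indexs))  # ([[1, 0], [0, 1]], [[0.9, 0.3], [0.2, 0.7]])
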
Example #3
    def run_eval(self, filename):
        """Evaluate the given file and returns some evaluation metrics.

        Args:
            filename (str): A file name that will be evaluated.

        Returns:
            dict: A dictionary containing evaluation metrics.
        """
        load_sess = self.sess
        group_preds = []
        group_labels = []

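        # Predictions and labels are collected per batch; this variant reports
        # only the pair-wise metrics.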
        for (
                batch_data_input,
                newsid_list,
                data_size,
        ) in self.iterator.load_data_from_file(filename):
            if batch_data_input:
                step_pred, step_labels = self.eval(load_sess, batch_data_input)
                group_preds.extend(step_pred)
                group_labels.extend(step_labels)

        res = cal_metric(group_labels, group_preds,
                         self.hparams.pairwise_metrics)
        return res
Example #4
    def run_eval(self, news_filename, behaviors_file):
        """Evaluate the given file and returns some evaluation metrics.

        Args:
            news_filename (str): A news file name that will be evaluated.
            behaviors_file (str): A behaviors file name that will be evaluated.

        Returns:
            dict: A dictionary that contains evaluation metrics.
        """

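        # Use the faster scoring path when the model supports it; both paths
        # return grouped labels and predictions.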
        if self.support_quick_scoring:
            _, group_labels, group_preds = self.run_fast_eval(
                news_filename, behaviors_file)
        else:
            _, group_labels, group_preds = self.run_slow_eval(
                news_filename, behaviors_file)
        res = cal_metric(group_labels, group_preds, self.hparams.metrics)
        return res