示例#1
0
文件: eval.py 项目: delldu/BigGraph
    def eval(
        self,
        scores: Scores,
        batch_edges: EdgeList,
    ) -> Stats:
        """Compute ranking statistics for one batch of scored edges.

        Each side (lhs, rhs) is scored only when it has at least one
        negative score. The rank of a positive is one plus the number of
        negatives whose score is greater than or equal to it; a randomized
        AUC is obtained from ``compute_randomized_auc`` for the same side.

        The indexing (``unsqueeze(1)`` / ``sum(1)``) suggests positives are
        1-D and negatives 2-D with one row per edge -- TODO confirm.

        NOTE(review): if neither side has negatives, ``average_of_sums`` is
        called with no arguments and the AUC expression divides by zero --
        confirm callers always supply negatives for at least one side.
        """
        num_edges = len(batch_edges)

        side_ranks = []
        side_aucs = []
        for pos_scores, neg_scores in (
            (scores.lhs_pos, scores.lhs_neg),
            (scores.rhs_pos, scores.rhs_neg),
        ):
            if neg_scores.nelement() == 0:
                # Nothing to rank against on this side.
                continue
            side_ranks.append(
                (neg_scores >= pos_scores.unsqueeze(1)).sum(1) + 1)
            side_aucs.append(
                compute_randomized_auc(pos_scores, neg_scores, num_edges))

        mean_rank = average_of_sums(*side_ranks)
        mean_reciprocal_rank = average_of_sums(
            *(r.float().reciprocal() for r in side_ranks))
        hits_at_1 = average_of_sums(*(r.le(1) for r in side_ranks))
        hits_at_10 = average_of_sums(*(r.le(10) for r in side_ranks))
        hits_at_50 = average_of_sums(*(r.le(50) for r in side_ranks))

        # The AUC is scaled by the batch size here because it will later be
        # averaged over count.
        scaled_auc = num_edges * sum(side_aucs) / len(side_aucs)

        return Stats(
            pos_rank=mean_rank,
            mrr=mean_reciprocal_rank,
            r1=hits_at_1,
            r10=hits_at_10,
            r50=hits_at_50,
            auc=scaled_auc,
            count=num_edges)
示例#2
0
    def eval(
        self,
        scores: Scores,
        batch_edges: EdgeList,
    ) -> Stats:
        """Compute ranking statistics for one batch, using both sides.

        Unlike a guarded variant, this always scores both the lhs and the
        rhs side. The rank of a positive is one plus the number of negatives
        whose score is greater than or equal to it; a randomized AUC is
        obtained from ``compute_randomized_auc`` for each side.

        The indexing (``unsqueeze(1)`` / ``sum(1)``) suggests positives are
        1-D and negatives 2-D with one row per edge -- TODO confirm.
        """
        num_edges = len(batch_edges)

        def rank_of_positives(pos_scores, neg_scores):
            # 1-based rank: count negatives scoring at least as high, plus one.
            return (neg_scores >= pos_scores.unsqueeze(1)).sum(1) + 1

        both_ranks = (
            rank_of_positives(scores.lhs_pos, scores.lhs_neg),
            rank_of_positives(scores.rhs_pos, scores.rhs_neg),
        )

        auc_lhs = compute_randomized_auc(
            scores.lhs_pos, scores.lhs_neg, num_edges)
        auc_rhs = compute_randomized_auc(
            scores.rhs_pos, scores.rhs_neg, num_edges)

        # The AUC is scaled by the batch size here because it will later be
        # averaged over count.
        scaled_auc = num_edges * (auc_lhs + auc_rhs) / 2

        return Stats(
            pos_rank=average_of_sums(*both_ranks),
            mrr=average_of_sums(
                *(r.float().reciprocal() for r in both_ranks)),
            r1=average_of_sums(*(r.le(1) for r in both_ranks)),
            r10=average_of_sums(*(r.le(10) for r in both_ranks)),
            r50=average_of_sums(*(r.le(50) for r in both_ranks)),
            auc=scaled_auc,
            count=num_edges)
示例#3
0
    def _process_one_batch(self, model: MultiRelationEmbedder,
                           batch_edges: EdgeList) -> Stats:
        """Score one batch with ``model`` and return its loss and rank stats.

        The model is run with gradient tracking disabled, the scores are
        passed through ``self._adjust_scores``, and the loss comes from
        ``self.calc_loss``. Each side (lhs, rhs) contributes ranks and a
        randomized AUC only when it has at least one negative score.

        NOTE(review): if neither side has negatives, ``average_of_sums`` is
        called with no arguments and the AUC expression divides by zero --
        confirm callers always supply negatives for at least one side.
        """
        with torch.no_grad():
            scores = model(batch_edges)

        self._adjust_scores(scores, batch_edges)

        num_edges = len(batch_edges)
        batch_loss = self.calc_loss(scores, batch_edges)

        rank_list, auc_list = [], []
        sides = (
            (scores.lhs_pos, scores.lhs_neg),
            (scores.rhs_pos, scores.rhs_neg),
        )
        for positives, negatives in sides:
            if negatives.nelement() > 0:
                # 1-based rank: negatives scoring at least as high, plus one.
                rank_list.append(
                    (negatives >= positives.unsqueeze(1)).sum(1) + 1)
                auc_list.append(
                    compute_randomized_auc(positives, negatives, num_edges))

        def hits_at(cutoff):
            # Aggregate of the indicator "rank <= cutoff" over scored sides.
            return average_of_sums(*(r.le(cutoff) for r in rank_list))

        return Stats(
            loss=float(batch_loss),
            pos_rank=average_of_sums(*rank_list),
            mrr=average_of_sums(
                *(r.float().reciprocal() for r in rank_list)),
            r1=hits_at(1),
            r10=hits_at(10),
            r50=hits_at(50),
            # The AUC is scaled by the batch size here because it will later
            # be averaged over count.
            auc=num_edges * sum(auc_list) / len(auc_list),
            count=num_edges,
        )