def update(self, preds, target):
    """Accumulate a batch's recall@k into the running totals.

    For each sample the score is ``|topk(preds) & target| / min(k, |target|)``,
    i.e. the fraction of relevant labels retrieved in the top-k, capped so a
    sample with fewer than k relevant labels can still reach 1.0.

    Args:
        preds: score tensor, one score per class in the last dimension.
            (assumed shape ``(N, C)`` — TODO confirm against caller)
        target: binary relevance tensor of the same shape as ``preds``.

    Raises:
        ValueError: if ``preds`` and ``target`` shapes differ.

    Side effects: adds the per-sample scores to ``self.score`` and the batch
    size to ``self.num_sample``.
    """
    # Raise instead of assert: asserts are stripped under ``python -O``.
    if preds.shape != target.shape:
        raise ValueError(f"preds and target must have the same shape, got {preds.shape} and {target.shape}")
    # Binary mask of the k highest-scoring classes per sample.
    binary_topk_preds = select_topk(preds, self.top_k)
    target = target.to(dtype=torch.int)
    # Relevant labels actually retrieved in the top-k, per sample.
    num_relevant = torch.sum(binary_topk_preds & target, dim=-1)
    # Denominator min(k, num_positive): clamp is equivalent to the original
    # torch.min against a constant tensor, but avoids materializing a Python
    # list and a host->device ``.to(preds.device)`` copy.
    denom = target.sum(dim=-1).clamp(max=self.top_k)
    # Samples with no relevant labels give 0/0 -> nan; count them as 0.
    self.score += torch.nan_to_num(num_relevant / denom, nan=0.0, posinf=0.0).sum()
    self.num_sample += len(preds)
def _top2(x):
    """Return ``x`` binarized to its two highest-scoring entries."""
    k = 2
    return select_topk(x, k)
def _top1(x):
    """Return ``x`` binarized to its single highest-scoring entry."""
    k = 1
    return select_topk(x, k)
case = _check_classification_inputs( preds, target, threshold=threshold, num_classes=num_classes, multiclass=multiclass, top_k=top_k, ) if case in (DataType.BINARY, DataType.MULTILABEL) and not top_k: preds = (preds >= threshold).int() num_classes = num_classes if not multiclass else 2 if case == DataType.MULTILABEL and top_k: preds = select_topk(preds, top_k) if case in (DataType.MULTICLASS, DataType.MULTIDIM_MULTICLASS) or multiclass: if preds.is_floating_point(): num_classes = preds.shape[1] preds = select_topk(preds, top_k or 1) else: num_classes = num_classes if num_classes else max(preds.max(), target.max()) + 1 preds = to_onehot(preds, max(2, num_classes)) target = to_onehot(target, max(2, num_classes)) if multiclass is False: preds, target = preds[:, 1, ...], target[:, 1, ...] if (case in (DataType.MULTICLASS, DataType.MULTIDIM_MULTICLASS) and multiclass is not False) or multiclass: