예제 #1
0
def multi_class_correct(y_pre: Tensor,
                        y_true: Tensor,
                        threshold=0.5,
                        device='cpu') -> Tensor:
    """Count the samples whose predicted class matches the target class.

    Both inputs are reduced with ``argmax`` along dimension 1, and the
    resulting class indices are compared element-wise.

    Args:
        y_pre: Prediction scores; presumably shaped (batch, classes) --
            confirm against callers.
        y_true: Targets in the same layout as ``y_pre`` (looks like one-hot
            or score vectors, since ``argmax`` is applied to it too).
        threshold: Not used by this function; kept in the signature for
            call-site compatibility.
        device: Device the match mask is moved to before it is summed.

    Returns:
        A tensor holding the number of matching positions.
    """
    predicted_idx = y_pre.argmax(dim=1)
    target_idx = y_true.argmax(dim=1)
    # Cast the boolean mask to int first, then move it to the target device.
    matches = predicted_idx.eq(target_idx).to(dtype=torch.int).to(device)
    return matches.sum()
예제 #2
0
def accuracy(outputs:Tensor, actual:Tensor, dim=-1)->Tensor:
    """Return the percentage of predictions that equal the labels.

    The predicted class is ``outputs.argmax(dim)``; it is compared
    element-wise with ``actual``.

    Args:
        outputs: Score tensor; the class axis is selected by ``dim``.
        actual: Class-index labels. Expected to have the same shape as
            ``outputs.argmax(dim)`` (otherwise broadcasting applies --
            callers should verify shapes).
        dim: Dimension of ``outputs`` that holds the per-class scores.

    Returns:
        An integer tensor with the accuracy in percent, rounded down.

    Note:
        Empty input is not special-cased; the resulting division by zero
        is left to torch.
    """
    correct = (actual == outputs.argmax(dim))
    # Divide by the total number of compared elements, not just the size of
    # the first dimension, so multi-dimensional labels are not over-counted.
    return 100 * correct.sum() // correct.numel()