def test_ece():
    """Exercise ``ECE`` with three bins: update, value, plot and reset.

    Two rounds of two samples each are fed to the calculator one sample at
    a time. After each round the per-bin ``samples`` and ``tp`` arrays and
    the rounded ``value`` are compared against fixed expectations. Between
    the rounds the calculator is asked to plot to a file and is then reset.
    """
    # Function-scope import so the test does not depend on the module's
    # import block already providing ``tempfile``.
    import tempfile

    ece_calculator = ECE(n_bins=3)

    # Round 1: rows are raw scores (they do not sum to 1).
    pred = torch.FloatTensor([[-40, 50, 10], [10, 80, 10]])
    target = torch.LongTensor([[2], [1]])
    for scores, label in zip(pred, target):
        # unsqueeze(0) restores the leading batch dimension of size 1.
        ece_calculator.update(output=scores.unsqueeze(0), target=label.unsqueeze(0))

    assert np.allclose(ece_calculator.samples, [0, 0, 2])
    assert np.allclose(ece_calculator.tp, [0, 0, 1])
    assert round(ece_calculator.value, 2) == 0.5

    # Plot into a private temporary directory: a fixed ``tmp`` folder in the
    # working directory could clobber an existing one on cleanup, and would
    # be left behind whenever the plot call or the assertion failed.
    with tempfile.TemporaryDirectory() as pth:
        figure_path = os.path.join(pth, 'figure.png')
        ece_calculator.plot(pth=figure_path)
        assert os.path.exists(figure_path)

    ece_calculator.reset()

    # Round 2 (after reset): rows already sum to 1.
    pred = torch.FloatTensor([[0.4, 0.5, 0.1], [0.1, 0.8, 0.1]])
    target = torch.LongTensor([[2], [1]])
    for scores, label in zip(pred, target):
        ece_calculator.update(output=scores.unsqueeze(0), target=label.unsqueeze(0))

    assert np.allclose(ece_calculator.samples, [0, 1, 1])
    assert np.allclose(ece_calculator.tp, [0, 0, 1])
    assert round(ece_calculator.value, 2) == 0.25
def test_ece():
    """Feed two samples to a 3-bin ``ECE`` and check its accumulated state.

    Each sample is passed on its own, as a batch of size one. Afterwards the
    per-bin ``samples`` and ``tp`` arrays and the rounded ``value`` must
    match the fixed expectations below.

    NOTE(review): another function named ``test_ece`` appears earlier in
    this source. If both live in the same module, this definition replaces
    the earlier one and only this test is collected - confirm and rename
    one of them if so.
    """
    calibration = ECE(n_bins=3)

    # Multiclass case: one row of class scores per sample, one label each.
    probabilities = torch.FloatTensor([[0.4, 0.5, 0.1], [0.1, 0.8, 0.1]])
    labels = torch.LongTensor([[2], [1]])

    for row, label in zip(probabilities, labels):
        # Re-add the leading batch dimension dropped by row iteration.
        single_output = row.unsqueeze(0)
        single_target = label.unsqueeze(0)
        calibration.update(output=single_output, target=single_target)

    expected_samples = [0, 1, 1]
    expected_tp = [0, 0, 1]
    assert np.allclose(calibration.samples, expected_samples)
    assert np.allclose(calibration.tp, expected_tp)
    assert round(calibration.value, 2) == 0.25
def __init__(self, wrapper: ModelWrapper, num_classes: int, lr: float, reg_factor: float, mu: float = None):
    """Build a new wrapped model: ``wrapper.model`` followed by a linear layer.

    Args:
        wrapper: Existing wrapper whose ``model`` attribute becomes the first
            stage of the new sequential model (the same module object is
            reused, not copied).
        num_classes: Input and output width of the added linear layer; also
            passed to the ``ECE_PerCLs`` metric.
        lr: Stored as ``self.lr``; not otherwise used in this method.
        reg_factor: Stored as ``self.reg_factor``; also the fallback for ``mu``.
        mu: Stored as ``self.mu``. When ``None`` (the default),
            ``reg_factor`` is used instead. An explicit ``0`` is kept as-is.
    """
    self.num_classes = num_classes
    self.criterion = nn.CrossEntropyLoss()
    self.lr = lr
    self.reg_factor = reg_factor
    # Compare against None explicitly: ``mu or reg_factor`` would silently
    # replace a deliberate ``mu=0.0`` with ``reg_factor``.
    self.mu = reg_factor if mu is None else mu

    # Square (num_classes x num_classes) layer appended after the wrapped model.
    self.dirichlet_linear = nn.Linear(self.num_classes, self.num_classes)
    self.model = nn.Sequential(wrapper.model, self.dirichlet_linear)

    # A fresh wrapper around the extended model, with its own metrics.
    self.wrapper = ModelWrapper(self.model, self.criterion)
    # NOTE(review): both metrics are registered under the same name "ece".
    # If add_metric keys metrics by name, the second call replaces the
    # first - confirm against ModelWrapper.add_metric.
    self.wrapper.add_metric("ece", lambda: ECE())
    self.wrapper.add_metric("ece", lambda: ECE_PerCLs(num_classes))