Пример #1
0
def test_ece():
    """Check ``ECE`` bin counts, per-bin true positives, value and plotting.

    A 3-bin ``ECE`` is fed two samples one at a time, first with
    large-magnitude scores, then (after ``reset()``) with scores whose rows
    sum to 1. The asserted ``samples`` / ``tp`` / ``value`` are the expected
    results for each scenario.
    """
    import tempfile  # local import: keeps this block self-contained

    ece_calculator = ECE(n_bins=3)

    # Multiclass classification with large-magnitude scores.
    pred = torch.FloatTensor([[-40, 50, 10], [10, 80, 10]])
    target = torch.LongTensor([[2], [1]])

    # Feed one sample at a time, each as a batch of size one.
    for row_pred, row_target in zip(pred, target):
        ece_calculator.update(output=row_pred.unsqueeze(0),
                              target=row_target.unsqueeze(0))

    assert np.allclose(ece_calculator.samples, [0, 0, 2])
    assert np.allclose(ece_calculator.tp, [0, 0, 1])
    assert round(ece_calculator.value, 2) == 0.5

    # A private temporary directory never clobbers an existing ./tmp and is
    # removed even when the assertion inside fails.
    with tempfile.TemporaryDirectory() as tmp_dir:
        fig_path = os.path.join(tmp_dir, 'figure.png')
        ece_calculator.plot(pth=fig_path)
        assert os.path.exists(fig_path)

    ece_calculator.reset()

    # Same calculator after reset(), now with scores whose rows sum to 1.
    pred = torch.FloatTensor([[0.4, 0.5, 0.1], [0.1, 0.8, 0.1]])
    target = torch.LongTensor([[2], [1]])

    for row_pred, row_target in zip(pred, target):
        ece_calculator.update(output=row_pred.unsqueeze(0),
                              target=row_target.unsqueeze(0))

    assert np.allclose(ece_calculator.samples, [0, 1, 1])
    assert np.allclose(ece_calculator.tp, [0, 0, 1])
    assert round(ece_calculator.value, 2) == 0.25
Пример #2
0
def test_ece():
    """Verify ``ECE`` bin counts, true positives and value on 2 samples / 3 bins."""
    calibrator = ECE(n_bins=3)

    # Two samples over three classes; each row of scores sums to 1.
    scores = torch.FloatTensor([[0.4, 0.5, 0.1], [0.1, 0.8, 0.1]])
    labels = torch.LongTensor([[2], [1]])

    # Push the samples through one by one, each as a batch of size one.
    for sample_scores, sample_label in zip(scores, labels):
        calibrator.update(output=sample_scores.unsqueeze(0),
                          target=sample_label.unsqueeze(0))

    expected_samples = [0, 1, 1]
    expected_tp = [0, 0, 1]
    assert np.allclose(calibrator.samples, expected_samples)
    assert np.allclose(calibrator.tp, expected_tp)
    assert round(calibrator.value, 2) == 0.25
Пример #3
0
    def __init__(self,
                 wrapper: ModelWrapper,
                 num_classes: int,
                 lr: float,
                 reg_factor: float,
                 mu: float = None):
        """Build a calibration head on top of ``wrapper.model``.

        A square ``nn.Linear(num_classes, num_classes)`` layer is appended
        after the wrapped model. The combined ``nn.Sequential`` is placed in a
        new ``ModelWrapper`` that uses a cross-entropy criterion.

        Args:
            wrapper: Wrapper whose ``.model`` is reused (the same module
                object, not a copy) as the first stage of ``self.model``.
            num_classes: Number of classes; sets both the input and output
                size of the appended linear layer.
            lr: Learning rate, stored as ``self.lr``.
            reg_factor: Regularization factor, stored as ``self.reg_factor``.
            mu: Second regularization factor, stored as ``self.mu``. Falls
                back to ``reg_factor`` only when ``None``.
        """
        self.num_classes = num_classes
        self.criterion = nn.CrossEntropyLoss()
        self.lr = lr
        self.reg_factor = reg_factor
        # `is None` rather than `or`, so an explicit mu=0.0 is honored
        # instead of being silently replaced by reg_factor.
        self.mu = reg_factor if mu is None else mu
        self.dirichlet_linear = nn.Linear(self.num_classes, self.num_classes)
        self.model = nn.Sequential(wrapper.model, self.dirichlet_linear)
        self.wrapper = ModelWrapper(self.model, self.criterion)

        # Each metric needs its own name; registering both as "ece" made the
        # per-class metric clash with (presumably overwrite) the global one.
        self.wrapper.add_metric("ece", lambda: ECE())
        self.wrapper.add_metric("ece_per_class", lambda: ECE_PerCLs(num_classes))