コード例 #1
0
def test_basic():
    """Compare ``coral_loss`` output with hard-coded references per reduction.

    Three samples with four level columns each are scored. The per-sample
    losses (``reduction=None``), their sum (``reduction='sum'``) and the
    result of the default reduction (which equals the mean of the three
    per-sample references) are checked with ``rtol=1e-03, atol=1e-05``.
    """
    # 0/1 level targets, one row per sample (3 x 4).
    targets = torch.tensor([
        [1., 1., 0., 0.],
        [1., 0., 0., 0.],
        [1., 1., 1., 1.],
    ])
    # Raw scores with the same 3 x 4 shape as the targets.
    scores = torch.tensor([
        [2.1, 1.8, -2.1, -1.8],
        [1.9, -1., -1.5, -1.3],
        [1.9, 1.8, 1.7, 1.6],
    ])

    # Each case: (extra keyword arguments for coral_loss, reference value).
    # An empty dict exercises the library's default reduction.
    cases = (
        ({'reduction': None}, torch.tensor([0.5370, 0.8951, 0.6441])),
        ({'reduction': 'sum'}, torch.tensor(2.0761)),
        ({}, torch.tensor(0.6920)),
    )
    for extra_kwargs, reference in cases:
        result = coral_loss(scores, targets, **extra_kwargs)
        assert torch.allclose(result, reference, rtol=1e-03, atol=1e-05)
コード例 #2
0
    def shared_step(self, batch):
        """Compute the CORAL loss and summed label-error statistics for a batch.

        Args:
            batch: An ``(x, y)`` pair; ``x`` is the model input and ``y`` the
                class labels (presumably a 1-D tensor of integer labels —
                TODO confirm against the data loader).

        Returns:
            A tuple ``(loss, mae, mse, n_samples)``. ``loss`` is
            ``coral_loss(logits, levels)`` with its default reduction;
            ``mae`` and ``mse`` are the *summed* (not averaged) absolute and
            squared differences between predicted and true labels over the
            batch; ``n_samples`` is ``x.shape[0]``, presumably so callers can
            average the sums across batches.
        """
        x, y = batch
        levels = levels_from_labelbatch(y, num_classes=self.num_classes)
        logits, probas = self(x)
        # levels_from_labelbatch builds its tensor with torch.tensor(...),
        # i.e. on the default (CPU) device. Move it next to the logits so the
        # loss also works when the model runs on an accelerator; this is a
        # no-op when both are already on the same device. dtype is left as-is.
        levels = levels.to(logits.device)
        loss = coral_loss(logits, levels)

        # Cast to float so the error terms below are floating point.
        predicted_labels = proba_to_label(probas).float()

        n_samples = x.shape[0]
        errors = predicted_labels - y
        mae = torch.sum(torch.abs(errors))
        mse = torch.sum(errors ** 2)
        return loss, mae, mse, n_samples
コード例 #3
0
 def f():
     """Call ``coral_loss`` with ``reduction='something'``.

     ``logits`` and ``levels`` are not parameters or locals, so they are
     resolved as module-level globals; the call's result, if any, is
     discarded and ``f`` returns None.

     NOTE(review): 'something' is not one of the reduction values used in
     the other examples (None, 'sum', or the default); this presumably
     demonstrates the invalid-argument path — confirm against coral_loss.
     """
     options = {'reduction': 'something'}
     coral_loss(logits, levels, **options)
コード例 #4
0
 def f():
     """Call ``coral_loss`` on ``logits``/``levels`` with its default reduction.

     ``logits`` and ``levels`` are not parameters or locals, so they are
     resolved as module-level globals; the call's result is discarded and
     ``f`` returns None.
     """
     positional = (logits, levels)
     coral_loss(*positional)