def test_basic():
    """Check coral_loss against precomputed reference values for every
    reduction mode (None, 'sum', and the default mean)."""
    levels = torch.tensor(
        [[1., 1., 0., 0.],
         [1., 0., 0., 0.],
         [1., 1., 1., 1.]])
    logits = torch.tensor(
        [[2.1, 1.8, -2.1, -1.8],
         [1.9, -1., -1.5, -1.3],
         [1.9, 1.8, 1.7, 1.6]])

    # reduction=None: one loss value per example in the batch.
    actual = coral_loss(logits, levels, reduction=None)
    expected = torch.tensor([0.5370, 0.8951, 0.6441])
    assert torch.allclose(actual, expected, rtol=1e-03, atol=1e-05)

    # reduction='sum': scalar total over the batch.
    actual = coral_loss(logits, levels, reduction='sum')
    expected = torch.tensor(2.0761)
    assert torch.allclose(actual, expected, rtol=1e-03, atol=1e-05)

    # Default reduction: scalar batch mean.
    actual = coral_loss(logits, levels)
    expected = torch.tensor(0.6920)
    assert torch.allclose(actual, expected, rtol=1e-03, atol=1e-05)
def shared_step(self, batch):
    """Run one mini-batch through the model and compute CORAL loss.

    Returns ``(loss, mae, mse, n_samples)`` where ``mae``/``mse`` are
    *sums* over the batch (not means), so the caller can aggregate
    across batches before dividing by the total sample count.
    """
    features, targets = batch
    # Expand integer class labels into extended binary level vectors.
    levels = levels_from_labelbatch(targets, num_classes=self.num_classes)
    logits, probas = self(features)
    loss = coral_loss(logits, levels)
    # Convert predicted probabilities back to ordinal labels.
    predictions = proba_to_label(probas).float()
    residual = predictions - targets
    mae = torch.abs(residual).sum()
    mse = (residual ** 2).sum()
    return loss, mae, mse, features.shape[0]
def f():
    """Call coral_loss with an unsupported reduction string.

    Uses module-level ``logits``/``levels``; presumably executed inside a
    ``pytest.raises`` context to assert the invalid value is rejected —
    TODO confirm against the caller.
    """
    coral_loss(logits, levels, reduction='something')
def f():
    """Invoke coral_loss with default arguments, discarding the result.

    Relies on module-level ``logits``/``levels``; likely a smoke-test
    helper checking the default call path does not raise.
    """
    coral_loss(logits, levels)