def test_hetero_cross_entropy_ce_only(pred): """Should behave as normal cross entropy when no superclass index is specified. """ # (1,2,3) target = torch.LongTensor([[ [-2, 1, -2], [-2, -2, 1], ]]) # Height available = torch.BoolTensor([ [True, True, True, True], ]) actual_loss = hetero_cross_entropy(pred, target, available, ignore_index=-2) actual_loss.backward() # This should always work actual_loss = actual_loss.detach().numpy() # Compute what the expected value should be pred_valid = torch.FloatTensor([[ [0.5, 0.4, 0.25, 2], [0.6, -2, 0.8, 1], ]]).permute(0, 2, 1) # (1, 4, 2) target_valid = torch.LongTensor([[1, 1]]) expected = F.cross_entropy(pred_valid, target_valid) # This should be 3.0005 assert np.isclose(actual_loss, expected.detach().numpy())
def test_hetero_cross_entropy_weight_tensor_no_class(pred): """Test both parts (ce_loss + super_loss) combined plus weight""" target = torch.LongTensor([[ [-1, 1, -2], [-2, -2, 1], ]]) available = torch.BoolTensor([ [True, True, False, False], ]) weight = torch.Tensor([ [1, 2, 3], [4, 5, 6], ]).reshape(1, 1, 2, 3) actual_loss = hetero_cross_entropy(pred, target, available, ignore_index=-2, super_index=-1, weight=weight) actual_loss.backward() # This should always work actual_loss = actual_loss.detach().numpy() assert np.isclose(actual_loss, 14.342173)
def test_hetero_cross_entropy_smoothing_complete_alpha_near_zero(pred): """Test both parts (ce_loss + super_loss) combined + label smoothing""" target = torch.LongTensor([[ [-1, 1, -2], [-2, -2, 1], ]]) # Height available = torch.BoolTensor([ [True, True, False, False], ]) pred_softmax = F.softmax(pred, dim=1) # For inspecting/debugging purposes """ Predicted classes (where X means doesn't matter AT ALL): tensor([[[2, 3, x], [x, x, 3]]]) [0,0] -> 2 is the interesting one here. If the loss is working, this prediction should increase 2/3 class prob # pred - learning_rate * grad should get us closer. """ actual_loss = hetero_cross_entropy( pred, target, available, ignore_index=-2, super_index=-1, alpha=0.00001, ) actual_loss.backward() # This should always work pred_grad = pred.grad.detach().numpy() assert not pred_grad[:, :, 0, 2].any() assert not pred_grad[:, :, 1, 0].any() assert not pred_grad[:, :, 1, 1].any() actual_loss = actual_loss.detach().numpy() assert np.isclose(actual_loss, 3.4782794, rtol=0.01) updated_pred = pred - (0.1 * pred.grad) updated_pred_softmax = F.softmax(updated_pred, dim=1) assert (updated_pred_softmax[0, 1, 0, 1] > pred_softmax[0, 1, 0, 1] ) # prediction should be more sure of 1 target assert (updated_pred_softmax[0, 1, 1, 2] > pred_softmax[0, 1, 1, 2] ) # prediction should be more sure of 1 target # The classes that are in the dataset should now have lower probabilities # for the pixel that's marked as unlabeled. assert updated_pred_softmax[0, 0, 0, 0] < pred_softmax[0, 0, 0, 0] assert updated_pred_softmax[0, 1, 0, 0] < pred_softmax[0, 1, 0, 0]
def test_hetero_cross_entropy_super_only_simple(): pred = torch.Tensor([-1, 0, 1, 2]).reshape(1, 4, 1) # logits target = torch.full((1, 1), -1, dtype=torch.long) available = torch.BoolTensor([True, True, False, False]).reshape(1, -1) actual_loss = hetero_cross_entropy(pred, target, available, ignore_index=-2, super_index=-1) actual_loss.backward() actual_loss = actual_loss.detach().numpy() assert np.isclose(0.12692809, actual_loss)
def test_hetero_cross_entropy_super_only(pred): target = torch.LongTensor([[ [-1, -1, -1], [-1, -1, -1], ]]) # Height available = torch.BoolTensor([ [True, False, False, False], ] # Only class 0 is available. ) actual_loss = hetero_cross_entropy(pred, target, available, ignore_index=-2, super_index=-1) actual_loss.backward() # This should always work actual_loss = actual_loss.detach().numpy() assert np.isclose(actual_loss, 0.14451277)
def test_hetero_cross_entropy_weight_int(pred): """Test both parts (ce_loss + super_loss) combined plus weight""" target = torch.LongTensor([[ [-1, 1, -2], [-2, -2, 1], ]]) available = torch.BoolTensor([ [True, True, False, False], ]) actual_loss = hetero_cross_entropy(pred, target, available, ignore_index=-2, super_index=-1, weight=2) actual_loss.backward() # This should always work actual_loss = actual_loss.detach().numpy() assert np.isclose(actual_loss, 2 * 3.4782794)
def test_hetero_cross_entropy_all_invalid(pred): """This primarily tests that when all invalid data is provided (i.e. the returned loss is 0) that the ``backward()`` method still works. """ target = torch.LongTensor([[ [-2, -2, -2], [-2, -2, -2], ]]) # Height available = torch.BoolTensor([ [True, True, True, True], ]) actual_loss = hetero_cross_entropy(pred, target, available, ignore_index=-2, super_index=-1) actual_loss.backward() # This should always work actual_loss = actual_loss.detach().numpy() assert actual_loss == 0