def test_focal_class_is_scribtable():
    """Ensure a TorchScript-compiled FocalLoss matches the eager module.

    The same logits/targets are fed to an eager ``FocalLoss`` and to a
    ``torch.jit.script``-compiled one; the two results must be close.
    """
    logits = torch.Tensor([-1, 2, 0]).float()
    labels = torch.Tensor([1, 0, 1])

    eager_value = losses.FocalLoss()(logits, labels)

    scripted_fn = torch.jit.script(losses.FocalLoss())
    scripted_value = scripted_fn(logits, labels)

    assert torch.allclose(eager_value, scripted_value)
def test_binary_focal_3d():
    """Binary FocalLoss on volumetric input equals the loss on flattened input.

    A 5-D ``(BS, 1, D, H, W)`` tensor and the same data viewed as
    ``(BS, 1, D*H*W)`` must yield the same loss value.
    """
    vol_shape = (BS, 1, IM_SIZE, IM_SIZE, IM_SIZE)
    volume_logits = torch.rand(*vol_shape)
    volume_labels = torch.randint(0, 2, vol_shape).float()

    loss_on_volume = losses.FocalLoss(mode="binary")(volume_logits, volume_labels)

    flat_logits = volume_logits.view(BS, 1, -1)
    flat_labels = volume_labels.view(BS, 1, -1)
    loss_on_flat = losses.FocalLoss(mode="binary")(flat_logits, flat_labels)

    assert torch.allclose(loss_on_volume, loss_on_flat)
def test_focal_loss_modes():
    """Cross-check FocalLoss modes and ``ignore_label`` handling.

    1. ``multilabel`` mode on ``TARGET_IMG_MULTILABEL`` must equal
       ``multiclass`` mode on ``TARGET_IMG_MULTICLASS`` (the multilabel
       target is presumably the one-hot encoding of the multiclass one --
       defined elsewhere in this module).
    2. In ``multiclass`` mode, setting a 2x2 spatial corner of the target to
       ``ignore_label`` must remove exactly that corner's contribution from
       the summed loss.
    3. Same as (2) for ``binary`` mode.
    """
    # --- multilabel (one-hot) == multiclass (index) -------------------------
    fl_multiclass = losses.FocalLoss(mode="multiclass", reduction="sum")(INP_IMG, TARGET_IMG_MULTICLASS)
    fl_multilabel = losses.FocalLoss(mode="multilabel", reduction="sum")(INP_IMG, TARGET_IMG_MULTILABEL)
    # The two modes take different code paths, so the summation order may
    # differ; compare with a tolerance instead of exact float equality.
    assert torch.allclose(fl_multiclass, fl_multilabel)

    # --- ignore_label in multiclass mode ------------------------------------
    fl = losses.FocalLoss(mode="multiclass", reduction="none")(INP_IMG, TARGET_IMG_MULTICLASS)
    # The unreduced loss is indexed with a channel axis; the class-index
    # target below is indexed without one.
    loss_diff = fl[:, :, :2, :2].sum()
    y_true = TARGET_IMG_MULTICLASS.clone()
    y_true[:, :2, :2] = -100
    fl_i = losses.FocalLoss(mode="multiclass", reduction="sum", ignore_label=-100)(INP_IMG, y_true)
    assert torch.allclose(fl.sum() - loss_diff, fl_i)

    # --- ignore_label in binary mode ----------------------------------------
    fl = losses.FocalLoss(mode="binary", reduction="none")(INP_IMG_BINARY, TARGET_IMG_BINARY)
    loss_diff = fl[:, :, :2, :2].sum()
    y_true = TARGET_IMG_BINARY.clone()
    y_true[:, :, :2, :2] = -100
    fl_i = losses.FocalLoss(mode="binary", reduction="sum", ignore_label=-100)(INP_IMG_BINARY, y_true)
    assert torch.allclose(fl.sum() - loss_diff, fl_i)
def test_focal_incorrect_reduction():
    """An unknown ``reduction`` must be rejected at construction time."""
    unsupported = "some_reduction"
    with pytest.raises(ValueError):
        losses.FocalLoss(reduction=unsupported)
def test_focal_incorrect_mode():
    """An unknown ``mode`` must be rejected at construction time."""
    unsupported = "some_mode"
    with pytest.raises(ValueError):
        losses.FocalLoss(mode=unsupported)