def test_binary_jaccard_loss():
    """Check ``JaccardLoss`` in binary mode against hand-computed values.

    Predictions are given as 0/1 values with ``from_logits=False``, so they
    are presumably used as probabilities without a sigmoid. On that basis the
    expected loss is ``1 - |pred & target| / |pred | target|``. Every case
    feeds tensors of the same rank, ``(N, C, H, W) = (1, 1, 1, W)``. Every
    comparison uses the same absolute tolerance.
    """
    eps = 1e-5
    criterion = JaccardLoss(mode="binary", from_logits=False)

    def check(pred, target, expected):
        # One helper for all cases keeps the shapes and the tolerance
        # consistent. Previously some predictions were rank 3 while the
        # targets were rank 4, and some asserts used ``rel`` instead of ``abs``.
        y_pred = torch.tensor(pred).view(1, 1, 1, -1)
        y_true = torch.tensor(target).view(1, 1, 1, -1)
        loss = criterion(y_pred, y_true)
        assert float(loss) == pytest.approx(expected, abs=eps), (
            f"pred={pred} target={target}"
        )

    # Ideal cases: prediction equals target, so the loss is 0.
    check([1.0], [1], 0.0)
    check([1.0, 0.0, 1.0], [1, 0, 1], 0.0)
    check([0.0, 0.0, 0.0], [0, 0, 0], 0.0)

    # Empty target: the test expects 0.0 even though every prediction is a
    # false positive. In other words, a target with no positive pixels is
    # expected to contribute no loss.
    check([1.0, 1.0, 1.0], [0, 0, 0], 0.0)

    # Worst cases: non-empty target with no overlap, so the loss is 1.
    check([1.0, 0.0, 1.0], [0, 1, 0], 1.0)
    check([0.0, 0.0, 0.0], [1, 1, 1], 1.0)
def __init__(self, hparams):
    """Build the network, optionally warm-start it, and register the losses.

    Args:
        hparams: Hyper-parameter mapping.
            - ``hparams["model"]`` is passed to ``object_from_dict`` to
              instantiate the network.
            - If the key ``"resume_from_checkpoint"`` is present, its value
              is used as the file path of the weights loaded into that
              network.
    """
    super().__init__()
    self.hparams = hparams

    config = self.hparams
    self.model = object_from_dict(config["model"])

    if "resume_from_checkpoint" in config:
        # The rename map presumably strips a "model." key prefix from the
        # stored weights so they match ``self.model``. TODO: confirm against
        # state_dict_from_disk.
        prefix_map: Dict[str, str] = {"model.": ""}
        weights = state_dict_from_disk(
            file_path=config["resume_from_checkpoint"],
            rename_in_layers=prefix_map,
        )
        self.model.load_state_dict(weights)

    # (name, weight, criterion) triples. The weights 0.1 / 0.9 are
    # presumably combined as a weighted sum elsewhere. TODO: confirm.
    jaccard = JaccardLoss(mode="binary", from_logits=True)
    focal = BinaryFocalLoss()
    self.losses = [("jaccard", 0.1, jaccard), ("focal", 0.9, focal)]
def __init__(self, hparams):
    """Build the network, optionally restore it, and register the losses.

    Args:
        hparams: Hyper-parameter mapping.
            - ``hparams["model"]`` is passed to ``object_from_dict`` to
              instantiate the network.
            - An optional ``"resume_from_checkpoint"`` entry names a
              checkpoint whose ``"state_dict"`` entry is loaded into the
              network.
            - The ``"sync_bn"`` entry (read without a default) decides
              whether the network is wrapped by
              ``apex.parallel.convert_syncbn_model``.
    """
    super().__init__()
    self.hparams = hparams

    config = self.hparams
    self.model = object_from_dict(config["model"])

    if "resume_from_checkpoint" in config:
        # The rename map presumably strips a "model." key prefix from the
        # stored weights so they match ``self.model``. TODO: confirm against
        # load_checkpoint.
        prefix_map: Dict[str, str] = {"model.": ""}
        restored = load_checkpoint(
            file_path=config["resume_from_checkpoint"],
            rename_in_layers=prefix_map,
        )
        self.model.load_state_dict(restored["state_dict"])

    # The SyncBN conversion runs after the weights are loaded.
    # NOTE(review): this reads the raw ``hparams`` argument rather than
    # ``self.hparams``, and there is no default for a missing "sync_bn" entry.
    use_sync_bn = hparams["sync_bn"]
    if use_sync_bn:
        self.model = apex.parallel.convert_syncbn_model(self.model)

    # (name, weight, criterion) triples. The weights 0.1 / 0.9 are
    # presumably combined as a weighted sum elsewhere. TODO: confirm.
    jaccard = JaccardLoss(mode="binary", from_logits=True)
    focal = BinaryFocalLoss()
    self.losses = [("jaccard", 0.1, jaccard), ("focal", 0.9, focal)]
def test_multilabel_jaccard_loss():
    """Check ``JaccardLoss`` in multilabel mode on a two-channel example.

    Every tensor has shape ``(1, 2, 4)``: one sample, two label channels and
    four positions. Predictions are 0/1 values with ``from_logits=False``.
    Expected losses follow ``1 - |pred & target| / |pred | target|``.
    """
    tolerance = 1e-5
    loss_fn = JaccardLoss(mode="multilabel", from_logits=False)

    split = [[1.0, 1.0, 0.0, 0.0], [0.0, 0.0, 1.0, 1.0]]
    cases = [
        # Ideal case: prediction equals target.
        (torch.tensor([split]), torch.tensor([split]), 0.0),
        # Worst case: target is the exact complement of the prediction.
        (torch.tensor([split]), 1 - torch.tensor([split]), 1.0),
        # Each channel overlaps on 1 position out of a union of 3: 1 - 1/3.
        (
            torch.tensor([[[0.0, 1.0, 1.0, 0.0], [0.0, 1.0, 1.0, 0.0]]]),
            torch.tensor([[[1.0, 1.0, 0.0, 0.0], [1.0, 1.0, 0.0, 0.0]]]),
            1.0 - 1.0 / 3.0,
        ),
    ]

    for prediction, target, expected in cases:
        value = loss_fn(prediction, target)
        assert float(value) == pytest.approx(expected, abs=tolerance)