Example 1
    def __call__(self, probs: Tensor, target: Tensor,
                 bounds: Tensor) -> Tensor:
        assert simplex(probs) and simplex(target)
        assert probs.shape == target.shape

        b, c, w, h = probs.shape  # type: Tuple[int, int, int, int]
        k = bounds.shape[2]  # scalar or vector
        value: Tensor = self.__fn__(probs[:, self.idc, ...])
        lower_b = bounds[:, self.idc, :, 0]
        upper_b = bounds[:, self.idc, :, 1]

        assert value.shape == (b, self.C, k), value.shape
        assert lower_b.shape == upper_b.shape == (b, self.C, k), lower_b.shape

        too_big: Tensor = (value > upper_b).type(self.dtype)
        too_small: Tensor = (value < lower_b).type(self.dtype)

        big_pen: Tensor = (value - upper_b)**2
        small_pen: Tensor = (value - lower_b)**2

        res = too_big * big_pen + too_small * small_pen

        loss: Tensor = res / (w * h)

        return loss.mean()
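Throughout these examples, a few helper predicates (simplex, sset, one_hot) and the probs2one_hot conversion are called but never defined. As a minimal sketch, assuming the per-pixel-probability semantics implied by the calls above, they could look like this; the bodies are illustrative guesses inferred from usage, not the original implementations.

import torch
from torch import Tensor

def uniq(t: Tensor) -> set:
    # Distinct values present in the tensor
    return set(torch.unique(t).cpu().numpy().tolist())

def sset(t: Tensor, values) -> bool:
    # True if every value of t belongs to `values`
    return uniq(t).issubset(values)

def simplex(t: Tensor, axis: int = 1) -> bool:
    # True if t sums to 1 along `axis`, i.e. behaves like per-pixel class probabilities
    return torch.allclose(t.sum(axis), torch.ones_like(t.sum(axis)))

def one_hot(t: Tensor, axis: int = 1) -> bool:
    # A one-hot encoding is a simplex whose values are exactly 0 or 1
    return simplex(t, axis) and sset(t, [0, 1])

def probs2one_hot(probs: Tensor) -> Tensor:
    # Argmax over the class axis, re-encoded as a one-hot tensor of the same shape
    seg = probs.argmax(dim=1)
    return torch.nn.functional.one_hot(seg, probs.shape[1]).movedim(-1, 1).type(probs.dtype)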
Example 2
    def __call__(self, probs: Tensor, target: Tensor) -> Tensor:
        assert simplex(probs)
        assert simplex(target)
        assert probs.shape == target.shape

        B, K, *xyz = probs.shape  # type: ignore

        pc = cast(Tensor, probs[:, self.idc, ...].type(torch.float32))
        tc = cast(Tensor, target[:, self.idc, ...].type(torch.float32))
        assert pc.shape == tc.shape == (B, len(self.idc), *xyz)

        target_dm_npy: np.ndarray = np.stack([one_hot2hd_dist(tc[b].cpu().detach().numpy())
                                              for b in range(B)], axis=0)
        assert target_dm_npy.shape == tc.shape == pc.shape
        tdm: Tensor = torch.tensor(target_dm_npy, device=probs.device, dtype=torch.float32)

        pred_segmentation: Tensor = probs2one_hot(probs).cpu().detach()
        pred_dm_npy: np.ndarray = np.stack([one_hot2hd_dist(pred_segmentation[b, self.idc, ...].numpy())
                                            for b in range(B)], axis=0)
        assert pred_dm_npy.shape == tc.shape == pc.shape
        pdm: Tensor = torch.tensor(pred_dm_npy, device=probs.device, dtype=torch.float32)

        delta = (pc - tc)**2
        dtm = tdm**2 + pdm**2

        multipled = einsum("bkwh,bkwh->bkwh", delta, dtm)

        loss = multipled.mean()

        return loss
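Example 2 relies on a one_hot2hd_dist helper that turns a one-hot mask into per-class distance maps. A plausible sketch, based only on how the result is used here (squared and combined with the squared prediction error), is given below; the exact original definition is not shown in the example.

import numpy as np
from scipy.ndimage import distance_transform_edt

def one_hot2hd_dist(seg: np.ndarray) -> np.ndarray:
    # seg: (K, *spatial) one-hot mask for the selected classes
    res = np.zeros_like(seg, dtype=np.float64)
    for k in range(len(seg)):
        posmask = seg[k].astype(bool)
        if posmask.any():
            res[k] = distance_transform_edt(posmask)
    return res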
Example 3
    def __call__(self, probs: Tensor, target: Tensor, bounds: Tensor) -> Tensor:
        assert simplex(probs) and simplex(target)
        assert probs.shape == target.shape

        target_sizes: Tensor = bounds[:, self.idc, :, 1]  # Dim of 1, upper and lower are the same
        volume_size: Tensor = einsum("bck->ck", target_sizes)

        lower_b = volume_size * (1 - self.margin)
        upper_b = volume_size * (1 + self.margin)

        _, _2, w, h = probs.shape  # type: Tuple[int, int, int, int]
        k = bounds.shape[2]  # scalar or vector
        value: Tensor = self.__fn__(probs[:, self.idc, ...]).sum(dim=0)
        assert value.shape == (self.C, k), value.shape
        assert lower_b.shape == upper_b.shape == (self.C, k), lower_b.shape

        too_big: Tensor = (value > upper_b).type(torch.float32)
        too_small: Tensor = (value < lower_b).type(torch.float32)

        big_pen: Tensor = (value - upper_b) ** 2
        small_pen: Tensor = (value - lower_b) ** 2

        res = too_big * big_pen + too_small * small_pen

        loss: Tensor = res / (w * h)

        return loss.mean()
Example 4
    def __call__(self, probs: Tensor, target: Tensor, _: Tensor) -> Tensor:
        assert simplex(probs) and simplex(target)

        pc = probs[:, self.idc, ...].type(torch.float32)
        tc = target[:, self.idc, ...].type(torch.float32)

        w: Tensor = 1 / (
            (einsum("bcwh->bc", tc).type(torch.float32) + 1e-10)**2)
        intersection: Tensor = w * einsum("bcwh,bcwh->bc", pc, tc)
        union: Tensor = w * (einsum("bcwh->bc", pc) + einsum("bcwh->bc", tc))

        divided: Tensor = 1 - 2 * (einsum("bc->b", intersection) +
                                   1e-10) / (einsum("bc->b", union) + 1e-10)

        loss_gde = divided.mean()

        log_p: Tensor = (probs[:, self.idc, ...] + 1e-10).log()
        mask_weighted = torch.einsum(
            "bcwh,c->bcwh", [tc, Tensor(self.weights).to(tc.device)])
        loss_ce = -torch.einsum("bcwh,bcwh->", [mask_weighted, log_p])
        loss_ce /= tc.sum() + 1e-10
        loss = loss_ce + self.lamb * loss_gde
        #print(loss_ce.item(),self.lamb*loss_gde.item())

        return loss
Example 5
    def __call__(self, probs: Tensor, target: Tensor) -> Tensor:
        assert simplex(probs) and simplex(target)

        pc = probs[:, self.idc, ...].type(torch.float32)
        tc = target[:, self.idc, ...].type(torch.float32)

        #OPTION 1: instead of dynamically changing weights batch-based, keep them static based on input weights
        if self.opt == 1:
            w: Tensor = 1 / ((self.weights + 1e-10)**2)
            intersection: Tensor = w * einsum("bkwh,bkwh->bk", pc, tc)
            union: Tensor = w * (einsum("bkwh->bk", pc) +
                                 einsum("bkwh->bk", tc))

            divided: Tensor = 1 - (2 * einsum("bk->b", intersection) +
                                   1e-10) / (einsum("bk->b", union) + 1e-10)

        #OPTION 2: imitate the computation that happens if you put in multiple/per-class GDL losses as args
        else:  #if self.opt==2:
            w: Tensor = 1 / (
                (einsum("bkwh->bk", tc).type(torch.float32) + 1e-10)**2)
            intersection: Tensor = w * einsum("bkwh,bkwh->bk", pc, tc)
            union: Tensor = w * (einsum("bkwh->bk", pc) +
                                 einsum("bkwh->bk", tc))

            divided: Tensor = self.weights.sum() - 2 * einsum(
                "bk->b",
                (intersection + 1e-10) / (union + 1e-10) * self.weights)

        loss = divided.mean()

        return loss
Example 6
    def __call__(self, probs: Tensor, target: Tensor, _: Tensor) -> Tensor:
        assert simplex(probs) and simplex(target)

        log_p: Tensor = (probs[:, self.idc, ...] + 1e-10).log()
        mask: Tensor = target[:, self.idc, ...].type(torch.float32)

        loss = -einsum("bcwh,bcwh->", mask, log_p)
        loss /= mask.sum() + 1e-10

        return loss
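For context, a hypothetical invocation of the partial cross-entropy above could look like the following. The enclosing class and its idc attribute are assumptions, since only __call__ is shown in the example.

import torch
import torch.nn.functional as F

logits = torch.randn(2, 4, 64, 64, requires_grad=True)      # (B, K, W, H) network output
labels = torch.randint(0, 4, (2, 64, 64))                    # integer ground truth

probs = F.softmax(logits, dim=1)                              # simplex along the class axis
target = F.one_hot(labels, num_classes=4).permute(0, 3, 1, 2).float()  # one-hot, same shape as probs

# criterion = CrossEntropy(idc=[1, 2, 3])   # hypothetical constructor storing self.idc
# loss = criterion(probs, target, None)     # the third argument is unused by this loss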
Example 7
    def __call__(self, probs: Tensor, target: Tensor,
                 bounds: Tensor) -> Tensor:
        assert simplex(probs) and simplex(target)

        log_p: Tensor = (probs[:, self.idc, ...] + 1e-10).log()
        mask: Tensor = target[:, self.idc, ...].type(torch.float32)
        mask_weighted = torch.einsum(
            "bcwh,c->bcwh", [mask, Tensor(self.weights).to(mask.device)])
        loss = -torch.einsum("bcwh,bcwh->", [mask_weighted, log_p])
        loss /= mask.sum() + 1e-10
        return loss
Example 8
    def __call__(self, probs: Tensor, target: Tensor, bounds: Tensor) -> Tensor:
        assert simplex(probs) and simplex(target) and sset(target, [0, 1])
        assert probs.shape == target.shape

        with torch.no_grad():
            fake_mask: Tensor = torch.zeros_like(probs)
            for i in range(len(probs)):
                fake_mask[i] = self.pathak_generator(probs[i], target[i], bounds[i])
                self.holder_size = fake_mask[i].sum()

        return super().__call__(probs, fake_mask, bounds)
Example 9
    def __call__(self, probs: Tensor, target: Tensor, _: Tensor,
                 ___) -> Tensor:
        assert simplex(probs) and simplex(target)

        log_p: Tensor = (probs[:, self.idc, ...] + 1e-10).log()
        mask: Tensor = cast(Tensor, target[:, self.idc,
                                           ...].type(torch.float32))

        loss = -einsum(f"bk{self.nd},bk{self.nd}->", mask, log_p)
        loss /= mask.sum() + 1e-10

        return loss
Example 10
    def __call__(self, probs: Tensor, target: Tensor) -> Tensor:
        assert simplex(probs) and simplex(target)

        masked_probs: Tensor = probs[:, self.idc, ...]
        log_p: Tensor = (masked_probs + 1e-10).log()
        mask: Tensor = cast(Tensor, target[:, self.idc, ...].type(torch.float32))

        w: Tensor = (1 - masked_probs)**self.gamma
        loss = - einsum("bkwh,bkwh,bkwh->", w, mask, log_p)
        loss /= mask.sum() + 1e-10

        return loss
Example 11
    def __call__(self, probs: Tensor, target: Tensor, _: Tensor) -> Tensor:
        assert simplex(probs) and simplex(target)

        pc = probs[:, self.idc, ...].type(torch.float32)
        tc = target[:, self.idc, ...].type(torch.float32)

        intersection: Tensor = einsum("bcwh,bcwh->bc", pc, tc)
        union: Tensor = (einsum("bcwh->bc", pc) + einsum("bcwh->bc", tc))

        divided: Tensor = 1 - (2 * intersection + 1e-10) / (union + 1e-10)

        loss = divided.mean()

        return loss
Example 12
    def __call__(self, probs: Tensor, target: Tensor) -> Tensor:
        assert simplex(probs) and simplex(target)

        pc = probs[:, self.idc, ...].type(torch.float32)
        tc = target[:, self.idc, ...].type(torch.float32)

        w: Tensor = 1 / ((einsum("bkwh->bk", tc).type(torch.float32) + 1e-10) ** 2)
        intersection: Tensor = w * einsum("bkwh,bkwh->bk", pc, tc)
        union: Tensor = w * (einsum("bkwh->bk", pc) + einsum("bkwh->bk", tc))

        divided: Tensor = 1 - 2 * (einsum("bk->b", intersection) + 1e-10) / (einsum("bk->b", union) + 1e-10)

        loss = divided.mean()

        return loss
Example 13
    def __call__(self, probs: Tensor, target: Tensor, bounds) -> Tensor:
        #print('bounds',torch.round(bounds*10**2)/10**2)
        assert simplex(probs)  # and simplex(target)  # Actually, does not care about second part

        b, _, w, h = probs.shape  # type: Tuple[int, int, int, int]

        est_prop: Tensor = self.__fn__(probs, self.power)

        if self.curi:
            #print(bounds)
            if self.ivd:
                bounds = bounds[:, :, 0]
                #gt_prop = gt_prop[:,:,0]
            gt_prop = torch.ones_like(est_prop) * bounds / (w * h)
            gt_prop = gt_prop[:, :, 0]

        else:
            gt_prop: Tensor = self.__fn__(target, self.power)  # the power here is actually useless if we have 0/1 gt labels
        if not self.curi:
            gt_prop = gt_prop.squeeze(2)
        est_prop = est_prop.squeeze(2)
        log_est_prop: Tensor = (est_prop + 1e-10).log()

        loss = -torch.einsum("bc,bc->", [gt_prop, log_est_prop])

        assert loss.requires_grad == probs.requires_grad  # Handle the case for validation
        #print(loss)
        return loss
Example 14
    def __call__(self, probs: Tensor, dist_maps: Tensor, _: Tensor) -> Tensor:
        """
        net_output: (batch_size, 2, x,y,z)
        target: ground truth, shape: (batch_size, 1, x,y,z)
        """
        assert simplex(probs)
        assert not one_hot(dist_maps)

        pc = probs[:, self.idc, ...].type(torch.float32)

        with torch.no_grad():
            pc_dist = torch.from_numpy(
                compute_edts_forhdloss(pc.detach().cpu().numpy() > 0.5))
            gt_dist = dist_maps[:, self.idc, ...].type(torch.float32)

        pos_log = torch.where(gt_dist > 0, torch.log(gt_dist), gt_dist)
        twos_log = torch.where(pos_log < 0, -torch.log(-pos_log), pos_log)

        pc_pos_log = torch.where(pc_dist > 0, torch.log(pc_dist), pc_dist)
        pc_twos_log = torch.where(pc_pos_log < 0, -torch.log(-pc_pos_log),
                                  pc_pos_log)

        if pc_twos_log.device != pc.device:
            pc_twos_log = pc_twos_log.to(pc.device).type(torch.float32)

        multipled = einsum("bcwh,bcwh->bcwh", pc, twos_log)

        return multipled.mean()
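Example 14 calls compute_edts_forhdloss on the thresholded predictions. A plausible definition, assumed from its usage as a per-sample Euclidean distance transform of the binary mask and of its complement, is sketched below; it is not taken from the example itself.

import numpy as np
from scipy.ndimage import distance_transform_edt

def compute_edts_forhdloss(segmentation: np.ndarray) -> np.ndarray:
    # segmentation: (B, ...) boolean foreground mask per sample
    res = np.zeros(segmentation.shape, dtype=np.float64)
    for i in range(segmentation.shape[0]):
        posmask = segmentation[i].astype(bool)
        if posmask.any():
            res[i] = distance_transform_edt(posmask) + distance_transform_edt(~posmask)
    return res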
Example 15
    def __call__(self, probs: Tensor, target: Tensor, bounds: Tensor) -> Tensor:
        def penalty(z: Tensor) -> Tensor:
            assert z.shape == ()

            return torch.max(torch.zeros_like(z), z)**2

        assert simplex(probs)  # and simplex(target)  # Actually, does not care about second part
        assert probs.shape == target.shape

        b, _, w, h = probs.shape  # type: Tuple[int, int, int, int]
        _, _, k, two = bounds.shape  # scalar or vector
        assert two == 2
        # assert k == 1  # Keep it simple for now
        value: Tensor = self.__fn__(probs[:, self.idc, ...])
        lower_b = bounds[:, self.idc, :, 0]
        upper_b = bounds[:, self.idc, :, 1]

        assert value.shape == (b, self.C, k), value.shape
        assert lower_b.shape == upper_b.shape == (b, self.C, k), lower_b.shape

        upper_z: Tensor = (value - upper_b).type(torch.float32).flatten()
        lower_z: Tensor = (lower_b - value).type(torch.float32).flatten()

        upper_penalty: Tensor = reduce(add, (penalty(e) for e in upper_z))
        lower_penalty: Tensor = reduce(add, (penalty(e) for e in lower_z))

        res: Tensor = upper_penalty + lower_penalty

        loss: Tensor = res.sum() / (w * h)
        assert loss.requires_grad == probs.requires_grad  # Handle the case for validation

        return loss
Example 16
    def __call__(self, probs: Tensor, target: Tensor, bounds: Tensor) -> Tensor:
        def log_barrier(z: Tensor) -> Tensor:
            assert z.shape == ()

            if z <= - 1 / self.t**2:
                return - torch.log(-z) / self.t
            else:
                return self.t * z - np.log(1 / (self.t**2)) / self.t + 1 / self.t

        assert simplex(probs)  # and simplex(target)  # Actually, does not care about second part
        assert probs.shape == target.shape

        b, _, w, h = probs.shape  # type: Tuple[int, int, int, int]
        _, _, k, two = bounds.shape  # scalar or vector
        assert two == 2
        # assert k == 1  # Keep it simple for now
        value: Tensor = self.__fn__(probs[:, self.idc, ...])
        lower_b = bounds[:, self.idc, :, 0]
        upper_b = bounds[:, self.idc, :, 1]

        assert value.shape == (b, self.C, k), value.shape
        assert lower_b.shape == upper_b.shape == (b, self.C, k), lower_b.shape

        upper_z: Tensor = (value - upper_b).type(torch.float32).flatten()
        lower_z: Tensor = (lower_b - value).type(torch.float32).flatten()

        upper_barrier: Tensor = reduce(add, (log_barrier(e) for e in upper_z))
        lower_barrier: Tensor = reduce(add, (log_barrier(e) for e in lower_z))

        res: Tensor = upper_barrier + lower_barrier

        loss: Tensor = res.sum() / (w * h)
        assert loss.requires_grad == probs.requires_grad  # Handle the case for validation

        return loss
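Example 16 applies the extended log-barrier element by element through a Python-level reduce. Assuming self.t is a positive scalar, an equivalent vectorized form could be written as below; this is only a sketch, and the clamp merely guards the branch that torch.where discards, so values and gradients match the scalar version on the selected branch.

import numpy as np
import torch
from torch import Tensor

def log_barrier_vec(z: Tensor, t: float) -> Tensor:
    # Extended log-barrier: -log(-z)/t where z <= -1/t**2, linear extension elsewhere
    cond = z <= -1.0 / (t ** 2)
    safe_z = z.clamp(max=-1e-30)                  # keep the unused branch free of log(<=0)
    smooth = -torch.log(-safe_z) / t
    linear = t * z - np.log(1.0 / (t ** 2)) / t + 1.0 / t
    return torch.where(cond, smooth, linear).sum()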
Example 17
    def __call__(self, probs: Tensor, target: Tensor, bounds: Tensor, _) -> Tensor:
        assert simplex(probs)  # and simplex(target)  # Actually, does not care about second part
        assert probs.shape == target.shape

        # b, _, w, h = probs.shape  # type: Tuple[int, int, int, int]
        b: int
        b, _, *im_shape = probs.shape
        _, _, k, two = bounds.shape  # scalar or vector
        assert two == 2

        value: Tensor = cast(Tensor, self.__fn__(probs[:, self.idc, ...]))
        lower_b = bounds[:, self.idc, :, 0]
        upper_b = bounds[:, self.idc, :, 1]

        assert value.shape == (b, self.C, k), value.shape
        assert lower_b.shape == upper_b.shape == (b, self.C, k), lower_b.shape

        upper_z: Tensor = cast(Tensor, (value - upper_b).type(torch.float32)).flatten()
        lower_z: Tensor = cast(Tensor, (lower_b - value).type(torch.float32)).flatten()

        upper_penalty: Tensor = reduce(add, (self.penalty(e) for e in upper_z))
        lower_penalty: Tensor = reduce(add, (self.penalty(e) for e in lower_z))

        res: Tensor = upper_penalty + lower_penalty

        loss: Tensor = res.sum() / reduce(mul, im_shape)
        assert loss.requires_grad == probs.requires_grad  # Handle the case for validation

        return loss
Example 18
    def __call__(self, probs: Tensor, target: Tensor, bounds) -> Tensor:
        #print('bounds',torch.round(bounds*10**2)/10**2)
        assert simplex(probs)  # and simplex(target)  # Actually, does not care about second part

        b, _, w, h = probs.shape  # type: Tuple[int, int, int, int]
        #if self.fgt:
        #    two = bounds.shape  # scalar or vector
        #else:
        #    _,_,k,two = bounds.shape
        #assert two == 2

        # est_prop is the proportion estimated by the network
        est_prop: Tensor = self.__fn__(probs, self.power)
        #print('est_prop',torch.round(est_prop*10**2)/10**2)
        # gt_prop is the proportion in the ground truth
        if self.curi:
            bounds = bounds[:,:,0,0] 
            #print(bounds.shape)
            gt_prop = torch.ones_like(est_prop)*bounds/(w*h)
            #gt_prop1: Tensor = self.__fn__(target,self.power) # the power here is actually useless if we have 0/1 gt labels
            #print(gt_prop,gt_prop1)
        else:
            gt_prop: Tensor = self.__fn__(target, self.power)  # the power here is actually useless if we have 0/1 gt labels
        #gt_prop = (gt_prop/(w*h)).type(torch.float32).flatten()

        #value = (value/(w*h)).type(torch.float32).flatten()
        #print(gt_prop.shape)
        log_est_prop: Tensor = (est_prop + 1e-10).log()
        #print(log_est_prop.shape)
        loss = - torch.einsum("bc,bc->", [gt_prop, log_est_prop])

        assert loss.requires_grad == probs.requires_grad  # Handle the case for validation
        #print(loss)
        return loss
Example 19
    def __call__(self, probs: Tensor, _: Tensor, __: Tensor,
                 box_prior: List[List[Tuple[Tensor, Tensor]]]) -> Tensor:
        assert simplex(probs)

        B: int = probs.shape[0]
        assert len(box_prior) == B

        sublosses = []
        for b in range(B):
            for k in self.idc:
                masks, bounds = box_prior[b][k]

                sizes: Tensor = einsum('wh,nwh->n', probs[b, k], masks)

                assert sizes.shape == bounds.shape == (masks.shape[0],), (sizes.shape, bounds.shape, masks.shape)
                shifted: Tensor = bounds - sizes

                init = torch.zeros((), dtype=torch.float32, requires_grad=probs.requires_grad, device=probs.device)
                sublosses.append(reduce(add, (self.barrier(v) for v in shifted), init))

        loss: Tensor = reduce(add, sublosses)

        assert loss.dtype == torch.float32
        assert loss.shape == (), loss.shape

        return loss
Example 20
    def __call__(self, probs: Tensor, target: Tensor, _: Tensor, __) -> Tensor:
        assert simplex(probs) and simplex(target)

        b: int
        b, _, *im_shape = probs.shape

        probs_m: Tensor = probs[:, self.idc, ...]
        target_m: Tensor = cast(Tensor, target[:, self.idc, ...].type(torch.float32))

        nd: str = self.nd
        # Compute the size for each class, masked by the target pixels (where target ==1)
        masked_sizes: Tensor = einsum(f"bk{nd},bk{nd}->bk", probs_m, target_m).flatten()

        # We want that size to be <= so no shift is needed
        res: Tensor = reduce(add, (self.penalty(e) for e in masked_sizes))  # type: ignore

        loss: Tensor = res / reduce(mul, im_shape)
        assert loss.shape == ()
        assert loss.requires_grad == probs.requires_grad  # Handle the case for validation

        return loss
Example 21
    def __call__(self, probs: Tensor, target: Tensor, bounds: Tensor) -> Tensor:
        assert simplex(probs)
        assert probs.shape == target.shape
        assert len(self.mask_idc) == 1, "Cannot handle more than one at a time, I guess"

        b, c, w, h = probs.shape

        fake_probs: Tensor = torch.zeros_like(probs, dtype=torch.float32)
        for i in range(len(probs)):
            low: Tensor = bounds[i, self.mask_idc][0, 0, 0]
            high: Tensor = bounds[i, self.mask_idc][0, 0, 1]

            res = self.pathak_generator(probs[i].detach(), target[i].detach(), low, high)
            assert simplex(res, axis=0)
            assert res.shape == (c, w, h)

            fake_probs[i] = res
        fake_mask: Tensor = probs2one_hot(fake_probs)
        assert fake_mask.shape == probs.shape == target.shape

        return super().__call__(probs, fake_mask, bounds)
Example 22
    def __call__(self, probs: Tensor, dist_maps: Tensor, _: Tensor) -> Tensor:
        assert simplex(probs)
        assert not one_hot(dist_maps)

        pc = probs[:, self.idc, ...].type(torch.float32)
        dc = dist_maps[:, self.idc, ...].type(torch.float32)

        multipled = einsum("bcwh,bcwh->bcwh", pc, dc)

        loss = multipled.mean()

        return loss
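The dist_maps argument consumed by Example 22 (and by Examples 14, 26 and 28) is precomputed outside the loss. A plausible precomputation is a signed distance map of the one-hot ground truth, negative inside the object and positive outside; the following is an assumed sketch of such a helper, not code shown in the examples.

import numpy as np
from scipy.ndimage import distance_transform_edt

def one_hot2dist(seg: np.ndarray) -> np.ndarray:
    # seg: (K, *spatial) one-hot ground truth; returns one signed distance map per class
    res = np.zeros_like(seg, dtype=np.float64)
    for k in range(len(seg)):
        posmask = seg[k].astype(bool)
        if posmask.any():
            negmask = ~posmask
            res[k] = distance_transform_edt(negmask) * negmask \
                - (distance_transform_edt(posmask) - 1) * posmask
    return res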
Example 23
    def __call__(self, probs: Tensor, target: Tensor,
                 bounds: Tensor) -> Tensor:
        def penalty(z: Tensor) -> Tensor:
            assert z.shape == ()

            return torch.max(torch.zeros_like(z), z)**2

        assert simplex(probs)  # and simplex(target)  # Actually, does not care about second part
        #assert probs.shape == target.shape
        #print(probs.shape, target.shape)
        b, _, w, h = probs.shape  # type: Tuple[int, int, int, int]
        #print(bounds.shape)
        if len(bounds.shape) == 3:
            bounds = torch.unsqueeze(bounds, 2)
        #print(bounds.shape)
        _, _, k, two = bounds.shape  # scalar or vector
        #print(bounds.shape,"bounds shape")
        assert two == 2
        # assert k == 1  # Keep it simple for now
        value: Tensor = self.__fn__(probs[:, self.idc, ...], self.power)
        #print(value.shape,"value shape")
        lower_b = bounds[:, self.idc, :, 0]
        upper_b = bounds[:, self.idc, :, 1]
        #print("value",value,"lb",lower_b)

        if len(value.shape) == 2:  # then it's the normalized soft size ... to ameliorate later
            value = value.unsqueeze(2)
            lower_b = lower_b / (w * h)
            upper_b = upper_b / (w * h)
        #    print(np.around(lower_b.cpu().numpy().flatten()), np.around(upper_b.cpu().numpy().flatten()), np.around(value.cpu().detach().numpy().flatten()))

        assert value.shape == (b, self.C, k), value.shape
        assert lower_b.shape == upper_b.shape == (b, self.C, k), lower_b.shape

        upper_z: Tensor = (value - upper_b).type(torch.float32).flatten()
        lower_z: Tensor = (lower_b - value).type(torch.float32).flatten()

        upper_penalty: Tensor = reduce(add, (penalty(e) for e in upper_z))
        lower_penalty: Tensor = reduce(add, (penalty(e) for e in lower_z))
        #count_up: Tensor = reduce(add, (penalty(e)>0 for e in lower_z))
        #count_low: Tensor = reduce(add, (penalty(e)>0 for e in upper_z))

        res: Tensor = upper_penalty + lower_penalty
        #count = count_up + count_low
        loss: Tensor = res.sum() / (w * h)**2
        #loss: Tensor = res.sum()
        assert loss.requires_grad == probs.requires_grad  # Handle the case for validation
        #print(round(loss.item(),1))
        return loss
Example 24
    def __call__(self, probs: Tensor, _: Tensor, __: Tensor, ___) -> Tensor:
        assert simplex(probs)

        B, K, *_ = probs.shape  # type: ignore

        lengths: Tensor = soft_length(probs[:, self.class_pair, ...])
        assert lengths.shape == (B, 2, 1), lengths.shape

        loss: Tensor = self.penalty(self.bounds[0] -
                                    lengths[0]) + self.penalty(lengths[1] -
                                                               self.bounds[1])
        assert loss.shape == (2, ), loss.shape

        return loss.mean()
Example 25
    def pathak_generator(self, probs: Tensor, target: Tensor, bounds) -> Tensor:
        _, w, h = probs.shape

        # Replace the probabilities with certainty for the few weak labels that we have
        weak_labels = target.clone()  # copy so that zeroing the ignored classes below does not modify target
        weak_labels[self.ignore, ...] = 0
        assert not simplex(weak_labels) and simplex(target)
        lower, upper = bounds[-1]

        labeled_pixels = weak_labels.any(axis=0)
        assert w * h == (labeled_pixels.sum() + (~labeled_pixels).sum())  # make sure all pixels are covered
        scribbled_probs = weak_labels + einsum("cwh,wh->cwh", probs, ~labeled_pixels)
        assert simplex(scribbled_probs)

        u: Tensor
        max_iter: int = 100
        lr: float = 0.00005
        b: Tensor = Tensor([-lower, upper])
        beta: Tensor = torch.zeros(2, dtype=torch.float32)
        f: Tensor = torch.zeros(2, *probs.shape)
        f[0, ...] = -1
        f[1, ...] = 1

        for i in range(max_iter):
            exped = - einsum("i,icwh->cwh", beta, f).exp()
            u_star = einsum('cwh,cwh->cwh', probs, exped)
            u_star /= u_star.sum(axis=0)
            assert simplex(u_star)

            d_beta = einsum("cwh,icwh->i", u_star, f) - b
            n_beta = torch.max(torch.zeros_like(beta), beta + lr * d_beta)

            u = u_star
            beta = n_beta

        return probs2one_hot(u)
Example 26
    def boundaryLoss(self, logits, target) -> Tensor:
        idc = [1]
        dist_maps = target['dist_map']
        probs = F.softmax(logits, dim=1)
        #probs = logits
        assert simplex(probs)
        assert not one_hot(dist_maps)

        pc = probs[:, idc, ...].type(torch.float32)
        dc = dist_maps[:, idc, ...].type(torch.float32)

        multipled = einsum("bkwh,bkwh->bkwh", pc, dc)  # einsum shorthand: (subscripts of a, subscripts of b -> output subscripts); here an element-wise product

        loss = multipled.mean()
        return loss
Example 27
    def __call__(self, probs: Tensor, target: Tensor, bounds: Tensor) -> Tensor:
        def penalty(z: Tensor) -> Tensor:
            assert z.shape == ()

            return torch.max(torch.zeros_like(z), z)**2

        assert simplex(probs)  # and simplex(target)  # Actually, does not care about second part
        #print(probs.shape,target.shape)
        #assert probs.shape == target.shape

        b, _, w, h = probs.shape  # type: Tuple[int, int, int, int]
        _, _, k, two = bounds.shape  # scalar or vector
        assert two == 2
        # assert k == 1  # Keep it simple for now
        value: Tensor = self.__fn__(probs[:, self.idc, ...])
        lower_b = bounds[:, self.idc, :, 0]
        upper_b = bounds[:, self.idc, :, 1]
        #if torch.rand(1).item()>0.999:
        #    print(lower_b, upper_b, value)

        gamma: float = 0.01
        assert value.shape == (b, self.C, k), value.shape
        assert lower_b.shape == upper_b.shape == (b, self.C, k), lower_b.shape

        upper_z: Tensor = (value - upper_b).type(torch.float32).flatten()
        lower_z: Tensor = (lower_b - value).type(torch.float32).flatten()

        g_beta1 = -upper_z
        g_beta2 = lower_z

        #print(g_beta1)
        n_beta1_1 = max(0, gamma * g_beta1[0])
        n_beta1_2 = max(0, gamma * g_beta1[1])
        n_beta2_1 = max(0, gamma * g_beta2[0])
        n_beta2_2 = max(0, gamma * g_beta2[1])


        res: Tensor = g_beta1[0]*n_beta1_1 + g_beta1[1]*n_beta1_2+ g_beta2[0]*n_beta2_1+ g_beta2[1]*n_beta2_2

        #loss: Tensor = res.sum() / (w * h)
        loss: Tensor = res.sum()

        assert loss.requires_grad == probs.requires_grad  # Handle the case for validation

        return loss
Example 28
    def __call__(self, probs: Tensor, dist_maps: Tensor) -> Tensor:
        assert simplex(probs)
        assert not one_hot(dist_maps)

        pc = probs[:, self.idc, ...].type(torch.float32)
        dc = dist_maps[:, self.idc, ...].type(torch.float32)

        multipled = einsum("bkwh,bkwh->bkwh", pc, dc)

        #OPTION 1: do a soooort-of weighted mean by hand?
        #    weightedall = torch.dot(einsum("bkwh->k", pc), self.weights)
        #    weighted = torch.dot(einsum("bkwh->k", multipled), self.weights)
        #    loss = weighted / (weightedall + 1e-10) #kind of weighted mean?

        #OPTION 2: Simulate  the computation that happens if you put in multiple/per-class BLs in args
        loss = torch.dot(multipled.mean(dim=(0, 2, 3)), self.weights)

        return loss
Example 29
    def __call__(self, probs: Tensor, target: Tensor, bounds) -> Tuple[Tensor, Tensor, Tensor]:
        assert simplex(probs)  # and simplex(target)  # Actually, does not care about second part
        b, _, w, h = probs.shape  # type: Tuple[int, int, int, int]
        predicted_mask = probs2one_hot(probs).detach()
        est_prop_mask = self.__fn__(predicted_mask, self.power).squeeze(2)
        est_prop: Tensor = self.__fn__(probs, self.power)
        if self.curi:
            if self.ivd:
                bounds = bounds[:, :, 0]
                bounds = bounds.unsqueeze(2)
            gt_prop = torch.ones_like(est_prop) * bounds / (w * h)
            gt_prop = gt_prop[:, :, 0]
        else:
            gt_prop: Tensor = self.__fn__(target, self.power)  # the power here is actually useless if we have 0/1 gt labels
        if not self.curi:
            gt_prop = gt_prop.squeeze(2)
        est_prop = est_prop.squeeze(2)
        log_est_prop: Tensor = (est_prop + 1e-10).log()

        log_gt_prop: Tensor = (gt_prop + 1e-10).log()
        log_est_prop_mask: Tensor = (est_prop_mask + 1e-10).log()

        loss_cons_prior = -torch.einsum(
            "bc,bc->", [est_prop, log_gt_prop]) + torch.einsum(
                "bc,bc->", [est_prop, log_est_prop])
        # Adding division by batch_size to normalise
        loss_cons_prior /= b
        log_p: Tensor = (probs + 1e-10).log()
        mask: Tensor = probs.type((torch.float32))
        mask_weighted = torch.einsum(
            "bcwh,c->bcwh", [mask, Tensor(self.weights).to(mask.device)])
        loss_se = -torch.einsum("bcwh,bcwh->", [mask_weighted, log_p])
        loss_se /= mask.sum() + 1e-10

        assert loss_se.requires_grad == probs.requires_grad  # Handle the case for validation

        return self.lamb_se * loss_se, self.lamb_consprior * loss_cons_prior, est_prop
Example 30
    def __call__(self, probs: Tensor, target: Tensor, bounds: Tensor,
                 filenames: List[str]) -> Tensor:
        assert simplex(probs)  # and simplex(target)  # Actually, does not care about second part
        assert probs.shape == target.shape

        # b, _, w, h = probs.shape  # type: Tuple[int, int, int, int]
        b: int
        b, _, *im_shape = probs.shape
        _, _, k, two = bounds.shape  # scalar or vector
        assert two == 2

        value: Tensor = cast(Tensor, self.__fn__(probs[:, self.idc, ...]))
        lower_b = bounds[:, self.idc, :, 0]
        upper_b = bounds[:, self.idc, :, 1]

        assert value.shape == (b, self.C, k), value.shape
        assert lower_b.shape == upper_b.shape == (b, self.C, k), lower_b.shape

        upper_z: Tensor = cast(Tensor,
                               (value - upper_b).type(torch.float32)).reshape(
                                   b, self.C * k)
        lower_z: Tensor = cast(Tensor,
                               (lower_b - value).type(torch.float32)).reshape(
                                   b, self.C * k)
        assert len(upper_z) == len(lower_b) == len(filenames)

        upper_penalty: Tensor = self.penalty(upper_z)
        lower_penalty: Tensor = self.penalty(lower_z)
        assert upper_penalty.numel() == lower_penalty.numel() == upper_z.numel(
        ) == lower_z.numel()

        # f for flattened axis
        res: Tensor = einsum("f->", upper_penalty) + einsum(
            "f->", lower_penalty)

        loss: Tensor = res.sum() / reduce(mul, im_shape)
        assert loss.requires_grad == probs.requires_grad  # Handle the case for validation

        return loss