def __call__(self, probs: Tensor, target: Tensor, bounds: Tensor) -> Tensor:
    """Squared penalty for per-sample values of ``self.__fn__`` outside their bounds.

    ``self.__fn__`` is applied to the class channels ``self.idc`` and compared with
    ``bounds[:, idc, :, 0]`` (lower) and ``bounds[:, idc, :, 1]`` (upper). The
    squared violation on whichever side is broken is divided by ``w * h`` and
    averaged over all entries.
    """
    assert simplex(probs) and simplex(target)
    assert probs.shape == target.shape

    batch, _, width, height = probs.shape  # type: Tuple[int, int, int, int]
    n_bounds = bounds.shape[2]  # scalar or vector

    measured: Tensor = self.__fn__(probs[:, self.idc, ...])
    low = bounds[:, self.idc, :, 0]
    high = bounds[:, self.idc, :, 1]
    assert measured.shape == (batch, self.C, n_bounds), measured.shape
    assert low.shape == high.shape == (batch, self.C, n_bounds), low.shape

    # 0/1 masks (in self.dtype) select which side, if any, is violated.
    above: Tensor = (measured > high).type(self.dtype)
    below: Tensor = (measured < low).type(self.dtype)
    penalty = above * (measured - high)**2 + below * (measured - low)**2

    return (penalty / (width * height)).mean()
def __call__(self, probs: Tensor, target: Tensor) -> Tensor:
    """Distance-map-weighted squared error between prediction and target.

    For the classes in ``self.idc`` the loss is
    ``mean((p - t)^2 * (dt(t)^2 + dt(pred)^2))`` where ``dt`` is
    ``one_hot2hd_dist`` applied per sample (on CPU, via numpy) and ``pred`` is
    the one-hot arg-max of ``probs``. The distance maps carry no gradient.

    Args:
        probs: softmax output ``(B, K, *spatial)``; must satisfy ``simplex``.
        target: one-hot ground truth, same shape as ``probs``.

    Returns:
        Scalar mean over batch, selected classes and all spatial positions.
    """
    assert simplex(probs)
    assert simplex(target)
    assert probs.shape == target.shape

    B, _, *xyz = probs.shape  # type: ignore

    pc = cast(Tensor, probs[:, self.idc, ...].type(torch.float32))
    tc = cast(Tensor, target[:, self.idc, ...].type(torch.float32))
    assert pc.shape == tc.shape == (B, len(self.idc), *xyz)

    target_dm_npy: np.ndarray = np.stack([one_hot2hd_dist(tc[b].cpu().detach().numpy())
                                          for b in range(B)], axis=0)
    assert target_dm_npy.shape == tc.shape == pc.shape
    tdm: Tensor = torch.tensor(target_dm_npy, device=probs.device, dtype=torch.float32)

    pred_segmentation: Tensor = probs2one_hot(probs).cpu().detach()
    pred_dm_npy: np.ndarray = np.stack([one_hot2hd_dist(pred_segmentation[b, self.idc, ...].numpy())
                                        for b in range(B)], axis=0)
    assert pred_dm_npy.shape == tc.shape == pc.shape
    pdm: Tensor = torch.tensor(pred_dm_npy, device=probs.device, dtype=torch.float32)

    delta = (pc - tc)**2
    dtm = tdm**2 + pdm**2

    # Plain element-wise product: identical to einsum("bkwh,bkwh->bkwh") for 2D
    # inputs, and also valid for the N-D (``*xyz``) shapes unpacked above.
    multipled = delta * dtm

    return multipled.mean()
def __call__(self, probs: Tensor, target: Tensor, bounds: Tensor) -> Tensor:
    """Squared penalty keeping batch-summed values within a margin of the batch volume.

    The reference volume is the batch sum of ``bounds[:, idc, :, 1]``. Allowed
    range is ``volume * (1 -/+ self.margin)``. ``self.__fn__`` of the selected
    classes is summed over the batch and compared with that range. Squared
    violations are divided by ``w * h`` and averaged.
    """
    assert simplex(probs) and simplex(target)
    assert probs.shape == target.shape

    # Dim of 1: upper and lower bounds are the same, so index 1 holds the size.
    per_sample_sizes: Tensor = bounds[:, self.idc, :, 1]
    total_size: Tensor = einsum("bck->ck", per_sample_sizes)
    low = total_size * (1 - self.margin)
    high = total_size * (1 + self.margin)

    _, _, width, height = probs.shape  # type: Tuple[int, int, int, int]
    n_bounds = bounds.shape[2]  # scalar or vector

    measured: Tensor = self.__fn__(probs[:, self.idc, ...]).sum(dim=0)
    assert measured.shape == (self.C, n_bounds), measured.shape
    assert low.shape == high.shape == (self.C, n_bounds), low.shape

    above: Tensor = (measured > high).type(torch.float32)
    below: Tensor = (measured < low).type(torch.float32)
    penalty = above * (measured - high) ** 2 + below * (measured - low) ** 2

    return (penalty / (width * height)).mean()
def __call__(self, probs: Tensor, target: Tensor, _: Tensor) -> Tensor:
    """Class-weighted cross-entropy plus ``self.lamb`` times generalised Dice loss.

    Both terms use the channels ``self.idc``. The CE term scales each class by
    ``self.weights`` and is normalised by the number of target pixels. The Dice
    term weights classes by the inverse squared target volume and is averaged over
    the batch. The third argument is ignored.
    """
    assert simplex(probs) and simplex(target)

    pred = probs[:, self.idc, ...].type(torch.float32)
    gt = target[:, self.idc, ...].type(torch.float32)

    # Generalised Dice: class weight = 1 / (target volume)^2.
    class_w: Tensor = 1 / ((einsum("bcwh->bc", gt).type(torch.float32) + 1e-10)**2)
    inter: Tensor = class_w * einsum("bcwh,bcwh->bc", pred, gt)
    denom: Tensor = class_w * (einsum("bcwh->bc", pred) + einsum("bcwh->bc", gt))
    num_b = einsum("bc->b", inter) + 1e-10
    den_b = einsum("bc->b", denom) + 1e-10
    gdl_term = (1 - 2 * num_b / den_b).mean()

    # Cross-entropy restricted to target pixels, scaled per class.
    log_p: Tensor = (probs[:, self.idc, ...] + 1e-10).log()
    class_scale = Tensor(self.weights).to(gt.device)
    weighted_gt = torch.einsum("bcwh,c->bcwh", [gt, class_scale])
    ce_term = -torch.einsum("bcwh,bcwh->", [weighted_gt, log_p])
    ce_term /= gt.sum() + 1e-10

    return ce_term + self.lamb * gdl_term
def __call__(self, probs: Tensor, target: Tensor) -> Tensor:
    """Generalised Dice loss with user-supplied class weights.

    ``self.opt == 1``: class weights are the static ``1 / self.weights^2``, and a
    single Dice ratio is formed per sample.
    Any other ``self.opt``: class weights come from the batch (``1 / volume^2``).
    Per-class Dice ratios are combined with ``self.weights``, which imitates
    summing several per-class GDL losses.
    Returns the batch mean.
    """
    assert simplex(probs) and simplex(target)

    pred = probs[:, self.idc, ...].type(torch.float32)
    gt = target[:, self.idc, ...].type(torch.float32)

    static_weights: bool = self.opt == 1
    if static_weights:
        class_w: Tensor = 1 / ((self.weights + 1e-10)**2)
    else:
        class_w = 1 / ((einsum("bkwh->bk", gt).type(torch.float32) + 1e-10)**2)

    inter: Tensor = class_w * einsum("bkwh,bkwh->bk", pred, gt)
    denom: Tensor = class_w * (einsum("bkwh->bk", pred) + einsum("bkwh->bk", gt))

    if static_weights:
        per_sample = 1 - (2 * einsum("bk->b", inter) + 1e-10) / (einsum("bk->b", denom) + 1e-10)
    else:
        ratio = (inter + 1e-10) / (denom + 1e-10) * self.weights
        per_sample = self.weights.sum() - 2 * einsum("bk->b", ratio)

    return per_sample.mean()
def __call__(self, probs: Tensor, target: Tensor, _: Tensor) -> Tensor:
    """Partial cross-entropy over the channels ``self.idc``.

    Sums ``-target * log(probs + 1e-10)`` over everything and divides by the
    number of labelled target entries (plus 1e-10). The third argument is ignored.
    """
    assert simplex(probs) and simplex(target)

    labels: Tensor = target[:, self.idc, ...].type(torch.float32)
    log_probs: Tensor = (probs[:, self.idc, ...] + 1e-10).log()

    total = -einsum("bcwh,bcwh->", labels, log_probs)
    return total / (labels.sum() + 1e-10)
def __call__(self, probs: Tensor, target: Tensor, bounds: Tensor) -> Tensor:
    """Class-weighted partial cross-entropy over the channels ``self.idc``.

    Each class's contribution is multiplied by the matching entry of
    ``self.weights``. The result is normalised by the unweighted count of target
    entries. ``bounds`` is ignored.
    """
    assert simplex(probs) and simplex(target)

    log_probs: Tensor = (probs[:, self.idc, ...] + 1e-10).log()
    labels: Tensor = target[:, self.idc, ...].type((torch.float32))

    class_scale = Tensor(self.weights).to(labels.device)
    scaled_labels = torch.einsum("bcwh,c->bcwh", [labels, class_scale])

    total = -torch.einsum("bcwh,bcwh->", [scaled_labels, log_probs])
    return total / (labels.sum() + 1e-10)
def __call__(self, probs: Tensor, target: Tensor, bounds: Tensor) -> Tensor:
    """Evaluate the parent loss against per-sample masks from ``pathak_generator``.

    Masks are generated without gradient tracking, one sample at a time.
    ``self.holder_size`` is overwritten for each sample and ends up holding the
    sum of the last generated mask.
    """
    assert simplex(probs) and simplex(target) and sset(target, [0, 1])
    assert probs.shape == target.shape

    with torch.no_grad():
        fake_mask: Tensor = torch.zeros_like(probs)
        for idx, (sample_probs, sample_target) in enumerate(zip(probs, target)):
            fake_mask[idx] = self.pathak_generator(sample_probs, sample_target, bounds[idx])
            self.holder_size = fake_mask[idx].sum()

    return super().__call__(probs, fake_mask, bounds)
def __call__(self, probs: Tensor, target: Tensor, _: Tensor, ___) -> Tensor:
    """Partial cross-entropy over ``self.idc`` for any spatial rank.

    ``self.nd`` holds the einsum letters for the spatial axes. The sum of
    ``-target * log(probs + 1e-10)`` is divided by the number of labelled
    entries. The last two arguments are ignored.
    """
    assert simplex(probs) and simplex(target)

    labels: Tensor = cast(Tensor, target[:, self.idc, ...].type(torch.float32))
    log_probs: Tensor = (probs[:, self.idc, ...] + 1e-10).log()

    spec: str = f"bk{self.nd},bk{self.nd}->"
    total = -einsum(spec, labels, log_probs)
    return total / (labels.sum() + 1e-10)
def __call__(self, probs: Tensor, target: Tensor) -> Tensor:
    """Focal loss over the channels ``self.idc``.

    Each labelled entry contributes ``-(1 - p)^gamma * log(p + 1e-10)``. The sum
    is normalised by the number of labelled entries (plus 1e-10).
    """
    assert simplex(probs) and simplex(target)

    sel_probs: Tensor = probs[:, self.idc, ...]
    labels: Tensor = cast(Tensor, target[:, self.idc, ...].type(torch.float32))

    focal_w: Tensor = (1 - sel_probs)**self.gamma
    log_probs: Tensor = (sel_probs + 1e-10).log()

    total = -einsum("bkwh,bkwh,bkwh->", focal_w, labels, log_probs)
    return total / (labels.sum() + 1e-10)
def __call__(self, probs: Tensor, target: Tensor, _: Tensor) -> Tensor:
    """Soft Dice loss per (sample, class), averaged.

    For each sample and class in ``self.idc`` the value is
    ``1 - (2 * |P ∩ T| + 1e-10) / (|P| + |T| + 1e-10)``. The third argument is
    ignored.
    """
    assert simplex(probs) and simplex(target)

    pred = probs[:, self.idc, ...].type(torch.float32)
    gt = target[:, self.idc, ...].type(torch.float32)

    overlap: Tensor = einsum("bcwh,bcwh->bc", pred, gt)
    pred_mass = einsum("bcwh->bc", pred)
    gt_mass = einsum("bcwh->bc", gt)
    total: Tensor = pred_mass + gt_mass

    dice_loss: Tensor = 1 - (2 * overlap + 1e-10) / (total + 1e-10)
    return dice_loss.mean()
def __call__(self, probs: Tensor, target: Tensor) -> Tensor:
    """Generalised Dice loss (classes weighted by inverse squared target volume).

    Weighted overlaps and weighted sizes are summed over classes, giving one
    Dice ratio per sample. Returns the batch mean of ``1 - 2 * ratio``.
    """
    assert simplex(probs) and simplex(target)

    pred = probs[:, self.idc, ...].type(torch.float32)
    gt = target[:, self.idc, ...].type(torch.float32)

    gt_volume = einsum("bkwh->bk", gt).type(torch.float32)
    class_w: Tensor = 1 / ((gt_volume + 1e-10) ** 2)

    weighted_overlap: Tensor = class_w * einsum("bkwh,bkwh->bk", pred, gt)
    weighted_total: Tensor = class_w * (einsum("bkwh->bk", pred) + einsum("bkwh->bk", gt))

    numerator = einsum("bk->b", weighted_overlap) + 1e-10
    denominator = einsum("bk->b", weighted_total) + 1e-10
    per_sample: Tensor = 1 - 2 * numerator / denominator

    return per_sample.mean()
def __call__(self, probs: Tensor, target: Tensor, bounds) -> Tensor:
    """Cross-entropy between reference class proportions and predicted proportions.

    Returns ``-sum_{b,c} gt_prop[b, c] * log(est_prop[b, c] + 1e-10)``, summed
    (not averaged) over batch and classes.

    Args:
        probs: softmax output; must satisfy ``simplex``. It is unpacked as
            ``(b, _, w, h)``, so it must be 4-D. ``self.__fn__(probs, self.power)``
            gives the estimated proportions.
        target: only read when ``self.curi`` is falsy, as the source of the
            reference proportions via ``self.__fn__``.
        bounds: only read when ``self.curi`` is truthy; divided by ``w * h`` to
            give the reference proportions. NOTE(review): the exact layout of
            ``bounds`` and of ``self.__fn__``'s output is not visible here. The
            ``[:, :, 0]`` / ``squeeze(2)`` calls suggest a trailing size-1 axis;
            confirm.
    """
    assert simplex(probs)  # target is deliberately not checked: only probs must be a simplex
    b, _, w, h = probs.shape  # type: Tuple[int, int, int, int]
    est_prop: Tensor = self.__fn__(probs, self.power)
    if self.curi:
        # Reference proportions come from the externally supplied bounds.
        if self.ivd:
            bounds = bounds[:, :, 0]
        # Broadcast bounds to est_prop's shape, then turn sizes into proportions.
        gt_prop = torch.ones_like(est_prop) * bounds / (w * h)
        gt_prop = gt_prop[:, :, 0]
    else:
        gt_prop: Tensor = self.__fn__(target, self.power)  # the power here is actually useless if we have 0/1 gt labels
    if not self.curi:
        gt_prop = gt_prop.squeeze(2)
    # Drop axis 2 so both operands of the "bc,bc->" einsum below are 2-D.
    est_prop = est_prop.squeeze(2)
    log_est_prop: Tensor = (est_prop + 1e-10).log()
    loss = -torch.einsum("bc,bc->", [gt_prop, log_est_prop])
    assert loss.requires_grad == probs.requires_grad  # Handle the case for validation
    return loss
def __call__(self, probs: Tensor, dist_maps: Tensor, _: Tensor) -> Tensor:
    """Boundary-style loss against a log-compressed ground-truth distance map.

    The ground-truth distances ``d`` for classes ``self.idc`` are transformed as
    follows:
      1. ``x = log(d)`` where ``d > 0``, else ``x = d``.
      2. ``-log(-x)`` where ``x < 0``, else ``x``.
    The result is multiplied element-wise with the predicted probabilities and
    averaged.

    Args:
        probs: network softmax output; must satisfy ``simplex``. It must be 4-D
            because of the ``bcwh`` einsum below.
        dist_maps: precomputed ground-truth distance maps, laid out like
            ``probs``; must not be one-hot.
        _: unused.

    Returns:
        Scalar mean of ``probs[:, idc] * T(dist_maps[:, idc])``.
    """
    assert simplex(probs)
    assert not one_hot(dist_maps)

    pc = probs[:, self.idc, ...].type(torch.float32)
    gt_dist = dist_maps[:, self.idc, ...].type(torch.float32)

    # torch.where evaluates both branches; log of non-positive values yields
    # nan/-inf in the unselected branch only. gt_dist carries no gradient, so
    # those values never reach backward.
    pos_log = torch.where(gt_dist > 0, torch.log(gt_dist), gt_dist)
    twos_log = torch.where(pos_log < 0, -torch.log(-pos_log), pos_log)

    # NOTE(review): an earlier version also computed a distance transform of the
    # thresholded prediction (compute_edts_forhdloss on CPU) and log-transformed
    # it, but never used the result in the loss. That dead per-step CPU
    # round-trip has been removed. If the prediction distance map was meant to
    # contribute, wire it in explicitly.
    multipled = einsum("bcwh,bcwh->bcwh", pc, twos_log)
    return multipled.mean()
def __call__(self, probs: Tensor, target: Tensor, bounds: Tensor) -> Tensor:
    """Squared-hinge penalty on values of ``self.__fn__`` outside ``[lower, upper]``.

    Each scalar violation ``max(0, z)^2`` is accumulated one element at a time.
    Upper and lower totals are added and divided by ``w * h``. Only ``probs``
    must be a simplex; ``target`` is used only for the shape check.
    """
    def squared_hinge(z: Tensor) -> Tensor:
        assert z.shape == ()
        return torch.max(torch.zeros_like(z), z)**2

    assert simplex(probs)
    assert probs.shape == target.shape

    b, _, w, h = probs.shape  # type: Tuple[int, int, int, int]
    _, _, k, pair = bounds.shape
    assert pair == 2

    measured: Tensor = self.__fn__(probs[:, self.idc, ...])
    low = bounds[:, self.idc, :, 0]
    high = bounds[:, self.idc, :, 1]
    assert measured.shape == (b, self.C, k), measured.shape
    assert low.shape == high.shape == (b, self.C, k), low.shape

    over: Tensor = (measured - high).type(torch.float32).flatten()
    under: Tensor = (low - measured).type(torch.float32).flatten()

    total: Tensor = reduce(add, map(squared_hinge, over)) + reduce(add, map(squared_hinge, under))
    loss: Tensor = total.sum() / (w * h)

    assert loss.requires_grad == probs.requires_grad  # Handle the case for validation
    return loss
def __call__(self, probs: Tensor, target: Tensor, bounds: Tensor) -> Tensor:
    """Extended log-barrier on the bound gaps of ``self.__fn__`` outputs.

    For each gap ``z`` (``value - upper`` and ``lower - value``):
      - if ``z <= -1/t^2``, the term is ``-log(-z) / t``;
      - otherwise it is the linear extension ``t*z - log(1/t^2)/t + 1/t``.
    All terms are summed and divided by ``w * h``.
    """
    t = self.t

    def barrier(z: Tensor) -> Tensor:
        assert z.shape == ()
        if z <= -1 / t**2:
            return - torch.log(-z) / t
        return t * z + -np.log(1 / (t**2)) / t + 1 / t

    assert simplex(probs)
    assert probs.shape == target.shape

    b, _, w, h = probs.shape  # type: Tuple[int, int, int, int]
    _, _, k, pair = bounds.shape
    assert pair == 2

    measured: Tensor = self.__fn__(probs[:, self.idc, ...])
    low = bounds[:, self.idc, :, 0]
    high = bounds[:, self.idc, :, 1]
    assert measured.shape == (b, self.C, k), measured.shape
    assert low.shape == high.shape == (b, self.C, k), low.shape

    over: Tensor = (measured - high).type(torch.float32).flatten()
    under: Tensor = (low - measured).type(torch.float32).flatten()

    total: Tensor = reduce(add, map(barrier, over)) + reduce(add, map(barrier, under))
    loss: Tensor = total.sum() / (w * h)

    assert loss.requires_grad == probs.requires_grad  # Handle the case for validation
    return loss
def __call__(self, probs: Tensor, target: Tensor, bounds: Tensor, _) -> Tensor:
    """Bound-violation penalty (``self.penalty``) for inputs of any spatial rank.

    ``self.penalty`` is applied to every scalar gap (``value - upper`` and
    ``lower - value``). The sum is divided by the number of spatial positions.
    """
    assert simplex(probs)  # target is deliberately not checked
    assert probs.shape == target.shape

    b: int
    b, _, *spatial = probs.shape
    _, _, k, pair = bounds.shape
    assert pair == 2

    measured: Tensor = cast(Tensor, self.__fn__(probs[:, self.idc, ...]))
    low = bounds[:, self.idc, :, 0]
    high = bounds[:, self.idc, :, 1]
    assert measured.shape == (b, self.C, k), measured.shape
    assert low.shape == high.shape == (b, self.C, k), low.shape

    over: Tensor = (measured - high).type(torch.float32).flatten()
    under: Tensor = (low - measured).type(torch.float32).flatten()

    total: Tensor = reduce(add, map(self.penalty, over)) + reduce(add, map(self.penalty, under))
    n_pixels = reduce(mul, spatial)
    loss: Tensor = total.sum() / n_pixels

    assert loss.requires_grad == probs.requires_grad  # Handle the case for validation
    return loss
def __call__(self, probs: Tensor, target: Tensor, bounds) -> Tensor:
    """Cross-entropy between reference and predicted class proportions.

    Estimated proportions are ``self.__fn__(probs, self.power)``. The reference
    depends on ``self.curi``:
      - truthy: ``bounds[:, :, 0, 0] / (w * h)``, broadcast to the estimate's shape;
      - falsy: ``self.__fn__(target, self.power)``.
    Returns ``-sum gt * log(est + 1e-10)`` over batch and classes.
    """
    assert simplex(probs)  # target is deliberately not checked
    b, _, w, h = probs.shape  # type: Tuple[int, int, int, int]

    predicted: Tensor = self.__fn__(probs, self.power)

    # The power is irrelevant for 0/1 ground-truth labels.
    reference: Tensor = (torch.ones_like(predicted) * bounds[:, :, 0, 0] / (w * h)
                         if self.curi
                         else self.__fn__(target, self.power))

    log_predicted: Tensor = (predicted + 1e-10).log()
    loss = -torch.einsum("bc,bc->", [reference, log_predicted])

    assert loss.requires_grad == probs.requires_grad  # Handle the case for validation
    return loss
def __call__(self, probs: Tensor, _: Tensor, __: Tensor, box_prior: List[List[Tuple[Tensor, Tensor]]]) -> Tensor:
    """Sum of ``self.barrier`` terms over box-prior masks.

    For each sample ``b`` and class ``k`` in ``self.idc``, ``box_prior[b][k]``
    gives ``(masks, bounds)``. The soft size of ``probs[b, k]`` under each mask is
    compared with its bound, and ``self.barrier`` is applied to
    ``bounds - sizes``. All terms are summed into a float32 scalar.
    """
    assert simplex(probs)

    B: int = probs.shape[0]
    assert len(box_prior) == B

    def box_term(b: int, k: int) -> Tensor:
        masks, bounds = box_prior[b][k]
        sizes: Tensor = einsum('wh,nwh->n', probs[b, k], masks)
        assert sizes.shape == bounds.shape == (masks.shape[0],), (sizes.shape, bounds.shape, masks.shape)

        # Zero start value keeps the result defined (and on the right device)
        # even when a sample has no masks.
        start = torch.zeros((), dtype=torch.float32, requires_grad=probs.requires_grad, device=probs.device)
        return reduce(add, map(self.barrier, bounds - sizes), start)

    sublosses = [box_term(b, k) for b in range(B) for k in self.idc]
    loss: Tensor = reduce(add, sublosses)

    assert loss.dtype == torch.float32
    assert loss.shape == (), loss.shape
    return loss
def __call__(self, probs: Tensor, target: Tensor, _: Tensor, __) -> Tensor:
    """Penalise predicted mass that falls on the target pixels of ``self.idc``.

    The per-(sample, class) size of ``probs`` masked by ``target`` is passed to
    ``self.penalty`` unshifted (the desired size is <= 0). The total is divided
    by the number of spatial positions.
    """
    assert simplex(probs) and simplex(target)

    b: int
    b, _, *spatial = probs.shape

    sel_probs: Tensor = probs[:, self.idc, ...]
    sel_target: Tensor = cast(Tensor, target[:, self.idc, ...].type(torch.float32))

    axes: str = self.nd
    spec: str = f"bk{axes},bk{axes}->bk"
    masked_sizes: Tensor = einsum(spec, sel_probs, sel_target).flatten()

    total: Tensor = reduce(add, map(self.penalty, masked_sizes))  # type: ignore
    loss: Tensor = total / reduce(mul, spatial)

    assert loss.shape == ()
    assert loss.requires_grad == probs.requires_grad  # Handle the case for validation
    return loss
def __call__(self, probs: Tensor, target: Tensor, bounds: Tensor) -> Tensor:
    """Evaluate the parent loss against one-hot masks from ``pathak_generator``.

    For each sample, the lower/upper bounds of the single class in
    ``self.mask_idc`` are passed to ``self.pathak_generator`` together with the
    detached probabilities and target. The returned distributions are converted
    with ``probs2one_hot`` and used as the fake target.
    """
    assert simplex(probs)
    assert probs.shape == target.shape
    assert len(self.mask_idc) == 1, "Cannot handle more at the time, I guess"

    b, c, w, h = probs.shape

    fake_probs: Tensor = torch.zeros_like(probs, dtype=torch.float32)
    for idx, (sample_probs, sample_target) in enumerate(zip(probs, target)):
        class_bounds: Tensor = bounds[idx, self.mask_idc]
        low: Tensor = class_bounds[0, 0, 0]
        high: Tensor = class_bounds[0, 0, 1]

        generated = self.pathak_generator(sample_probs.detach(), sample_target.detach(), low, high)
        assert simplex(generated, axis=0)
        assert generated.shape == (c, w, h)
        fake_probs[idx] = generated

    fake_mask: Tensor = probs2one_hot(fake_probs)
    assert fake_mask.shape == probs.shape == target.shape

    return super().__call__(probs, fake_mask, bounds)
def __call__(self, probs: Tensor, dist_maps: Tensor, _: Tensor) -> Tensor:
    """Boundary loss: mean of probabilities times precomputed distance maps.

    Only the channels ``self.idc`` are used. ``dist_maps`` must not be one-hot.
    The third argument is ignored.
    """
    assert simplex(probs)
    assert not one_hot(dist_maps)

    sel_probs = probs[:, self.idc, ...].type(torch.float32)
    sel_dists = dist_maps[:, self.idc, ...].type(torch.float32)

    weighted = einsum("bcwh,bcwh->bcwh", sel_probs, sel_dists)
    return weighted.mean()
def __call__(self, probs: Tensor, target: Tensor, bounds: Tensor) -> Tensor:
    """Squared-hinge bound penalty that also accepts normalised soft sizes.

    Bounds are accepted as ``(b, C, 2)`` (a ``k`` axis is inserted) or
    ``(b, C, k, 2)``. If ``self.__fn__`` returns a 2-D tensor, it is treated as a
    normalised soft size: a ``k`` axis is added and the bounds are divided by
    ``w * h`` to match. The summed penalty is divided by ``(w * h)**2``.
    """
    def squared_hinge(z: Tensor) -> Tensor:
        assert z.shape == ()
        return torch.max(torch.zeros_like(z), z)**2

    assert simplex(probs)  # target is deliberately not checked

    b, _, w, h = probs.shape  # type: Tuple[int, int, int, int]
    area = w * h

    if len(bounds.shape) == 3:
        bounds = torch.unsqueeze(bounds, 2)
    _, _, k, pair = bounds.shape
    assert pair == 2

    measured: Tensor = self.__fn__(probs[:, self.idc, ...], self.power)
    low = bounds[:, self.idc, :, 0]
    high = bounds[:, self.idc, :, 1]

    if len(measured.shape) == 2:
        # Normalised soft size: bring the bounds to the same scale.
        measured = measured.unsqueeze(2)
        low = low / area
        high = high / area

    assert measured.shape == (b, self.C, k), measured.shape
    assert low.shape == high.shape == (b, self.C, k), low.shape

    over: Tensor = (measured - high).type(torch.float32).flatten()
    under: Tensor = (low - measured).type(torch.float32).flatten()

    total: Tensor = reduce(add, map(squared_hinge, over)) + reduce(add, map(squared_hinge, under))
    loss: Tensor = total.sum() / area**2

    assert loss.requires_grad == probs.requires_grad  # Handle the case for validation
    return loss
def __call__(self, probs: Tensor, _: Tensor, __: Tensor, ___) -> Tensor:
    """Penalty keeping soft lengths of ``self.class_pair`` within ``self.bounds``.

    ``soft_length`` of the two selected classes must have shape ``(B, 2, 1)``.
    ``self.penalty`` is applied to ``bounds[0] - lengths[0]`` (lower side) and
    ``lengths[1] - bounds[1]`` (upper side). The sum is averaged.

    NOTE(review): ``lengths[0]`` / ``lengths[1]`` index the *batch* axis (first
    and second sample), not the two classes of ``class_pair``. Perhaps
    ``lengths[:, 0]`` / ``lengths[:, 1]`` was intended. This is left unchanged;
    confirm against the intended semantics and ``self.penalty``'s output shape.
    """
    assert simplex(probs)

    batch_size, _n_classes, *_spatial = probs.shape  # type: ignore

    lengths: Tensor = soft_length(probs[:, self.class_pair, ...])
    assert lengths.shape == (batch_size, 2, 1), lengths.shape

    lower_term = self.penalty(self.bounds[0] - lengths[0])
    upper_term = self.penalty(lengths[1] - self.bounds[1])
    loss: Tensor = lower_term + upper_term
    assert loss.shape == (2, ), loss.shape

    return loss.mean()
def pathak_generator(self, probs: Tensor, target: Tensor, bounds) -> Tensor:
    """Project ``probs`` onto a size-constrained distribution and return it one-hot.

    Presumably this follows Pathak et al.'s constrained-CNN dual ascent, as the
    method name suggests. The projected distribution is
    ``u ∝ q * exp(-beta . f)``, where ``q`` is ``probs`` with weakly-labelled
    pixels clamped to their label. ``beta >= 0`` are the multipliers for the
    constraints ``-sum(u) <= -lower`` and ``sum(u) <= upper``, updated by
    projected gradient ascent on the dual.

    Args:
        probs: per-pixel class distribution ``(c, w, h)``; simplex over ``c``.
        target: one-hot weak labels, same shape as ``probs``. It is not modified.
        bounds: ``bounds[-1]`` is unpacked into ``(lower, upper)``.
            NOTE(review): this requires ``bounds[-1]`` to have exactly two
            entries; confirm the layout callers pass.

    Returns:
        ``probs2one_hot`` of the final projected distribution.

    NOTE(review): ``f`` spans every class channel. For a simplex ``u`` the
    constrained sum is then always ``w * h``, so the constraint may be meant to
    target a single class. This is left as-is; confirm.
    """
    _, w, h = probs.shape

    # Clone: ``target[...]`` is a view, so zeroing ignored channels in place would
    # silently corrupt the caller's target (and break the simplex check below).
    weak_labels: Tensor = target.clone().type(probs.dtype)
    weak_labels[self.ignore, ...] = 0
    assert not simplex(weak_labels) and simplex(target)

    lower, upper = bounds[-1]

    labeled_pixels: Tensor = (weak_labels != 0).any(dim=0)  # bool mask (w, h)
    assert w * h == (labeled_pixels.sum() + (~labeled_pixels).sum())  # make sure all pixels are covered

    # Replace the probabilities with certainty for the few weak labels that we have.
    # einsum needs matching dtypes, so the boolean mask is cast first.
    unlabeled: Tensor = (~labeled_pixels).type(probs.dtype)
    scribbled_probs: Tensor = weak_labels + einsum("cwh,wh->cwh", probs, unlabeled)
    assert simplex(scribbled_probs)

    max_iter: int = 100
    lr: float = 0.00005

    dtype, device = probs.dtype, probs.device
    b: Tensor = torch.tensor([-float(lower), float(upper)], dtype=dtype, device=device)
    beta: Tensor = torch.zeros(2, dtype=dtype, device=device)
    f: Tensor = torch.zeros(2, *probs.shape, dtype=dtype, device=device)
    f[0, ...] = -1
    f[1, ...] = 1

    u: Tensor = scribbled_probs
    for _ in range(max_iter):
        # Parenthesise before exp(): ``- x.exp()`` would negate *after* the
        # exponential, which flips the exponent sign once u_star is normalised.
        exped = (-einsum("i,icwh->cwh", beta, f)).exp()
        u_star = einsum('cwh,cwh->cwh', scribbled_probs, exped)
        u_star /= u_star.sum(dim=0)
        assert simplex(u_star)

        # Dual gradient (constraint residual), then projection onto beta >= 0.
        d_beta = einsum("cwh,icwh->i", u_star, f) - b
        beta = torch.max(torch.zeros_like(beta), beta + lr * d_beta)

        u = u_star

    return probs2one_hot(u)
def boundaryLoss(self, logits, target) -> Tensor:
    """Boundary loss for class 1, computed from raw logits.

    Args:
        logits: raw network outputs; softmax is taken over dim 1.
        target: mapping whose ``'dist_map'`` entry holds precomputed distance
            maps laid out like ``logits`` (must not be one-hot).

    Returns:
        Mean of ``softmax(logits)[:, 1] * target['dist_map'][:, 1]``.
    """
    class_ids = [1]
    distances = target['dist_map']
    probabilities = F.softmax(logits, dim=1)

    assert simplex(probabilities)
    assert not one_hot(distances)

    sel_probs = probabilities[:, class_ids, ...].type(torch.float32)
    sel_dists = distances[:, class_ids, ...].type(torch.float32)

    # einsum shorthand: (subscripts of a, subscripts of b -> output subscripts).
    # Identical subscripts everywhere means a plain element-wise product.
    weighted = einsum("bkwh,bkwh->bkwh", sel_probs, sel_dists)
    return weighted.mean()
def __call__(self, probs: Tensor, target: Tensor, bounds: Tensor) -> Tensor:
    """Multiplier-weighted bound terms for the first two flattened gap values.

    ``self.__fn__`` of the selected classes is compared with ``bounds[..., 0]``
    (lower) and ``bounds[..., 1]`` (upper). For the first two entries of each
    flattened gap ``g``, the term ``g * max(0, 0.01 * g)`` is added. The result
    is not normalised by the image area.

    NOTE(review): the upper gap is ``upper_b - value``, which is positive when the
    upper bound is *satisfied*. The lower gap is ``lower_b - value``, which is
    positive when the lower bound is *violated*. The two sides therefore act in
    opposite situations; confirm this is intended. The block also assumes at
    least two flattened entries.
    """
    assert simplex(probs)  # target is deliberately not checked

    b, _, w, h = probs.shape  # type: Tuple[int, int, int, int]
    _, _, k, pair = bounds.shape
    assert pair == 2

    measured: Tensor = self.__fn__(probs[:, self.idc, ...])
    low = bounds[:, self.idc, :, 0]
    high = bounds[:, self.idc, :, 1]

    step: float = 0.01
    assert measured.shape == (b, self.C, k), measured.shape
    assert low.shape == high.shape == (b, self.C, k), low.shape

    upper_gap = -((measured - high).type(torch.float32).flatten())
    lower_gap = (low - measured).type(torch.float32).flatten()

    gaps = (upper_gap[0], upper_gap[1], lower_gap[0], lower_gap[1])
    terms = [g * max(0, step * g) for g in gaps]
    total: Tensor = terms[0] + terms[1] + terms[2] + terms[3]
    loss: Tensor = total.sum()

    assert loss.requires_grad == probs.requires_grad  # Handle the case for validation
    return loss
def __call__(self, probs: Tensor, dist_maps: Tensor) -> Tensor:
    """Class-weighted boundary loss.

    The per-class mean of ``probs * dist_maps`` (over batch and both spatial axes)
    is dotted with ``self.weights``. This simulates summing several per-class
    boundary losses with those weights.
    """
    assert simplex(probs)
    assert not one_hot(dist_maps)

    sel_probs = probs[:, self.idc, ...].type(torch.float32)
    sel_dists = dist_maps[:, self.idc, ...].type(torch.float32)

    weighted = einsum("bkwh,bkwh->bkwh", sel_probs, sel_dists)
    per_class_mean = weighted.mean(dim=(0, 2, 3))

    return torch.dot(per_class_mean, self.weights)
def __call__(self, probs: Tensor, target: Tensor, bounds) -> "Tuple[Tensor, Tensor, Tensor]":
    """Self-entropy loss plus KL(predicted || prior) proportion-consistency loss.

    Args:
        probs: softmax output; must satisfy ``simplex`` and is unpacked as
            ``(b, c, w, h)``.
        target: only read when ``self.curi`` is falsy, as the source of the
            reference proportions via ``self.__fn__``.
        bounds: only read when ``self.curi`` is truthy; divided by ``w * h`` to
            give the prior proportions. NOTE(review): the exact layout of
            ``bounds`` and of ``self.__fn__``'s output is not visible here.

    Returns:
        A 3-tuple:
          - ``self.lamb_se * loss_se``, where ``loss_se`` is the class-weighted
            self-entropy ``-sum w_c p log p`` divided by ``sum(p)``;
          - ``self.lamb_consprior * loss_cons_prior``, where ``loss_cons_prior``
            is ``sum est * (log est - log gt)`` divided by the batch size;
          - ``est_prop``, the (squeezed) predicted proportions.
    """
    assert simplex(probs)  # target is deliberately not checked
    b, _, w, h = probs.shape  # type: Tuple[int, int, int, int]

    # (A one-hot "predicted mask" proportion and its log used to be computed here
    # on every call but were never used; that dead work has been removed.)
    est_prop: Tensor = self.__fn__(probs, self.power)
    if self.curi:
        # Prior proportions come from the supplied bounds.
        if self.ivd:
            bounds = bounds[:, :, 0]
            bounds = bounds.unsqueeze(2)
        gt_prop = torch.ones_like(est_prop) * bounds / (w * h)
        gt_prop = gt_prop[:, :, 0]
    else:
        # The power here is actually useless if we have 0/1 gt labels.
        gt_prop: Tensor = self.__fn__(target, self.power)
    if not self.curi:
        gt_prop = gt_prop.squeeze(2)
    est_prop = est_prop.squeeze(2)

    log_est_prop: Tensor = (est_prop + 1e-10).log()
    log_gt_prop: Tensor = (gt_prop + 1e-10).log()

    # KL(est_prop || gt_prop), summed over batch and classes.
    loss_cons_prior = -torch.einsum("bc,bc->", [est_prop, log_gt_prop]) + torch.einsum(
        "bc,bc->", [est_prop, log_est_prop])
    # Adding division by batch_size to normalise
    loss_cons_prior /= b

    # Class-weighted self-entropy of the prediction.
    log_p: Tensor = (probs + 1e-10).log()
    mask: Tensor = probs.type((torch.float32))
    mask_weighted = torch.einsum("bcwh,c->bcwh", [mask, Tensor(self.weights).to(mask.device)])
    loss_se = -torch.einsum("bcwh,bcwh->", [mask_weighted, log_p])
    loss_se /= mask.sum() + 1e-10

    assert loss_se.requires_grad == probs.requires_grad  # Handle the case for validation

    return self.lamb_se * loss_se, self.lamb_consprior * loss_cons_prior, est_prop
def __call__(self, probs: Tensor, target: Tensor, bounds: Tensor, filenames: List[str]) -> Tensor:
    """Vectorised bound-violation penalty on ``self.__fn__`` of the selected classes.

    Gaps ``value - upper`` and ``lower - value`` are reshaped to
    ``(b, C * k)`` (one row per file in ``filenames``) and passed through
    ``self.penalty`` in one call. The total is divided by the number of spatial
    positions.

    Args:
        probs: softmax output ``(b, c, *spatial)``; must satisfy ``simplex``.
        target: only used for the shape check.
        bounds: ``(b, c, k, 2)`` lower/upper bounds.
        filenames: one entry per sample; only its length is checked.
    """
    assert simplex(probs)  # target is deliberately not checked
    assert probs.shape == target.shape

    b: int
    b, _, *im_shape = probs.shape
    _, _, k, two = bounds.shape  # scalar or vector
    assert two == 2

    value: Tensor = cast(Tensor, self.__fn__(probs[:, self.idc, ...]))
    lower_b = bounds[:, self.idc, :, 0]
    upper_b = bounds[:, self.idc, :, 1]
    assert value.shape == (b, self.C, k), value.shape
    assert lower_b.shape == upper_b.shape == (b, self.C, k), lower_b.shape

    upper_z: Tensor = cast(Tensor, (value - upper_b).type(torch.float32)).reshape(b, self.C * k)
    lower_z: Tensor = cast(Tensor, (lower_b - value).type(torch.float32)).reshape(b, self.C * k)
    assert len(upper_z) == len(lower_b) == len(filenames)

    upper_penalty: Tensor = self.penalty(upper_z)
    lower_penalty: Tensor = self.penalty(lower_z)
    assert upper_penalty.numel() == lower_penalty.numel() == upper_z.numel() == lower_z.numel()

    # The z tensors are 2-D (b, C * k). einsum("f->") only accepts a 1-D operand
    # and raises on them unless self.penalty happens to flatten. .sum() reduces
    # over every element regardless of the penalty's output rank.
    res: Tensor = upper_penalty.sum() + lower_penalty.sum()
    loss: Tensor = res / reduce(mul, im_shape)

    assert loss.requires_grad == probs.requires_grad  # Handle the case for validation
    return loss