def forward(self, x, y):
    shp_x = x.shape

    if self.batch_dice:
        axes = [0] + list(range(2, len(shp_x)))
    else:
        axes = list(range(2, len(shp_x)))

    if self.apply_nonlin is not None:
        net_output = self.apply_nonlin(x)  # (b, c, x, y, z)
    else:
        net_output = x  # assume the caller already applied the nonlinearity

    gt_onehot = gt2onehot(net_output, y, axes)  # (b, c, x, y, z)

    intersection = sum_tensor(net_output * gt_onehot, axes, keepdim=False)
    ground_o = sum_tensor(gt_onehot ** 2, axes, keepdim=False)
    pred_o = sum_tensor(net_output ** 2, axes, keepdim=False)

    dc = 2.0 * (intersection + self.smooth) / (ground_o + pred_o + self.smooth)

    if not self.do_bg:
        if self.batch_dice:
            dc = dc[1:]
        else:
            dc = dc[:, 1:]
    dc = dc.mean()

    return -dc
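# Hedged usage sketch (not part of the original trainer): a standalone numeric check of the
# squared-denominator soft dice computed in the forward above. _sum_tensor is a minimal stand-in
# for the project's sum_tensor helper; the shapes and the 1e-5 smooth value are assumptions.
import torch
import torch.nn.functional as F


def _sum_tensor(inp, axes, keepdim=False):
    # sum over the given axes, mirroring what sum_tensor is expected to do
    for ax in sorted(axes, reverse=True):
        inp = inp.sum(int(ax), keepdim=keepdim)
    return inp


if __name__ == "__main__":
    pred = torch.softmax(torch.randn(2, 3, 8, 8, 8), dim=1)        # (b, c, x, y, z)
    gt = torch.randint(0, 3, (2, 8, 8, 8))                          # label map
    gt_onehot = F.one_hot(gt, 3).permute(0, 4, 1, 2, 3).float()     # (b, c, x, y, z)

    axes = (0, 2, 3, 4)                                             # batch dice: batch + spatial axes
    intersection = _sum_tensor(pred * gt_onehot, axes)
    ground_o = _sum_tensor(gt_onehot ** 2, axes)
    pred_o = _sum_tensor(pred ** 2, axes)
    dc = 2.0 * (intersection + 1e-5) / (ground_o + pred_o + 1e-5)
    print(-dc[1:].mean())                                           # do_bg=False -> drop channel 0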
def run_online_evaluation(self, output, target):
    with torch.no_grad():
        num_classes = output.shape[1]
        output_softmax = softmax_helper(output)
        output_seg = output_softmax.argmax(1)
        target = target[:, 0]
        axes = tuple(range(1, len(target.shape)))
        tp_hard = torch.zeros((target.shape[0], num_classes - 1)).to(output_seg.device.index)
        fp_hard = torch.zeros((target.shape[0], num_classes - 1)).to(output_seg.device.index)
        fn_hard = torch.zeros((target.shape[0], num_classes - 1)).to(output_seg.device.index)
        for c in range(1, num_classes):
            tp_hard[:, c - 1] = sum_tensor((output_seg == c).float() * (target == c).float(), axes=axes)
            fp_hard[:, c - 1] = sum_tensor((output_seg == c).float() * (target != c).float(), axes=axes)
            fn_hard[:, c - 1] = sum_tensor((output_seg != c).float() * (target == c).float(), axes=axes)

        tp_hard = tp_hard.sum(0, keepdim=False).detach().cpu().numpy()
        fp_hard = fp_hard.sum(0, keepdim=False).detach().cpu().numpy()
        fn_hard = fn_hard.sum(0, keepdim=False).detach().cpu().numpy()

        self.online_eval_foreground_dc.append(list((2 * tp_hard) / (2 * tp_hard + fp_hard + fn_hard + 1e-8)))
        self.online_eval_tp.append(list(tp_hard))
        self.online_eval_fp.append(list(fp_hard))
        self.online_eval_fn.append(list(fn_hard))
def run_online_evaluation(self, output, target):
    with torch.no_grad():
        num_classes = output[0].shape[1]
        output_seg = output[0].argmax(1)
        target = target[0][:, 0]
        axes = tuple(range(1, len(target.shape)))
        tp_hard = torch.zeros((target.shape[0], num_classes - 1)).to(output_seg.device.index)
        fp_hard = torch.zeros((target.shape[0], num_classes - 1)).to(output_seg.device.index)
        fn_hard = torch.zeros((target.shape[0], num_classes - 1)).to(output_seg.device.index)
        for c in range(1, num_classes):
            tp_hard[:, c - 1] = sum_tensor((output_seg == c).float() * (target == c).float(), axes=axes)
            fp_hard[:, c - 1] = sum_tensor((output_seg == c).float() * (target != c).float(), axes=axes)
            fn_hard[:, c - 1] = sum_tensor((output_seg != c).float() * (target == c).float(), axes=axes)

        # tp_hard, fp_hard, fn_hard = get_tp_fp_fn((output_softmax > (1 / num_classes)).float(), target,
        #                                          axes, None)
        # print_if_rank0("before allgather", tp_hard.shape)
        tp_hard = tp_hard.sum(0, keepdim=False)[None]
        fp_hard = fp_hard.sum(0, keepdim=False)[None]
        fn_hard = fn_hard.sum(0, keepdim=False)[None]

        tp_hard = awesome_allgather_function.apply(tp_hard)
        fp_hard = awesome_allgather_function.apply(fp_hard)
        fn_hard = awesome_allgather_function.apply(fn_hard)

        tp_hard = tp_hard.detach().cpu().numpy().sum(0)
        fp_hard = fp_hard.detach().cpu().numpy().sum(0)
        fn_hard = fn_hard.detach().cpu().numpy().sum(0)
        self.online_eval_foreground_dc.append(list((2 * tp_hard) / (2 * tp_hard + fp_hard + fn_hard + 1e-8)))
        self.online_eval_tp.append(list(tp_hard))
        self.online_eval_fp.append(list(fp_hard))
        self.online_eval_fn.append(list(fn_hard))
def soft_dice_per_batch_2(net_output, gt, smooth=1., smooth_in_nom=1., background_weight=1,
                          rebalance_weights=None, square_nominator=False, square_denom=False):
    if rebalance_weights is not None and len(rebalance_weights) != gt.shape[1]:
        rebalance_weights = rebalance_weights[1:]  # this is the case when use_bg=False
    axes = tuple([0] + list(range(2, len(net_output.size()))))
    tp = sum_tensor(net_output * gt, axes, keepdim=False)
    fn = sum_tensor((1 - net_output) * gt, axes, keepdim=False)
    fp = sum_tensor(net_output * (1 - gt), axes, keepdim=False)
    weights = torch.ones(tp.shape)
    weights[0] = background_weight
    if net_output.device.type == "cuda":
        weights = weights.cuda(net_output.device.index)
    if rebalance_weights is not None:
        rebalance_weights = torch.from_numpy(rebalance_weights).float()
        if net_output.device.type == "cuda":
            rebalance_weights = rebalance_weights.cuda(net_output.device.index)
        tp = tp * rebalance_weights
        fn = fn * rebalance_weights

    nominator = tp
    if square_nominator:
        nominator = nominator ** 2

    if square_denom:
        denom = 2 * tp ** 2 + fp ** 2 + fn ** 2
    else:
        denom = 2 * tp + fp + fn

    result = (- ((2 * nominator + smooth_in_nom) / (denom + smooth)) * weights).mean()
    return result
def get_tp_fp_fn(net_output, gt, axes=None, mask=None, square=False):
    """
    net_output must be (b, c, x, y(, z))
    gt must be a label map (shape (b, 1, x, y(, z)) OR shape (b, x, y(, z))) or one hot encoding (b, c, x, y(, z))
    if mask is provided it must have shape (b, 1, x, y(, z))
    :param net_output:
    :param gt:
    :param axes:
    :param mask: mask must be 1 for valid pixels and 0 for invalid pixels
    :param square: if True then fp, tp and fn will be squared before summation
    :return:
    """
    if axes is None:
        axes = tuple(range(2, len(net_output.size())))

    shp_x = net_output.shape
    shp_y = gt.shape

    with torch.no_grad():
        if len(shp_x) != len(shp_y):
            gt = gt.view((shp_y[0], 1, *shp_y[1:]))

        if all([i == j for i, j in zip(net_output.shape, gt.shape)]):
            # if this is the case then gt is probably already a one hot encoding
            y_onehot = gt
        else:
            gt = gt.long()
            y_onehot = torch.zeros(shp_x)
            if net_output.device.type == "cuda":
                y_onehot = y_onehot.cuda(net_output.device.index)
            y_onehot.scatter_(1, gt, 1)

    tp = net_output * y_onehot
    fp = net_output * (1 - y_onehot)
    fn = (1 - net_output) * y_onehot

    if mask is not None:
        tp = torch.stack(tuple(x_i * mask[:, 0] for x_i in torch.unbind(tp, dim=1)), dim=1)
        fp = torch.stack(tuple(x_i * mask[:, 0] for x_i in torch.unbind(fp, dim=1)), dim=1)
        fn = torch.stack(tuple(x_i * mask[:, 0] for x_i in torch.unbind(fn, dim=1)), dim=1)

    if square:
        tp = tp ** 2
        fp = fp ** 2
        fn = fn ** 2

    tp = sum_tensor(tp, axes, keepdim=False)
    fp = sum_tensor(fp, axes, keepdim=False)
    fn = sum_tensor(fn, axes, keepdim=False)

    return tp, fp, fn
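# Hedged usage sketch for get_tp_fp_fn above: the dummy shapes, CPU tensors, and the 1e-5 smoothing
# are assumptions; it presumes get_tp_fp_fn and its sum_tensor helper are available in this module.
import torch

if __name__ == "__main__":
    pred = torch.softmax(torch.randn(2, 4, 16, 16), dim=1)   # (b, c, x, y)
    gt = torch.randint(0, 4, (2, 1, 16, 16))                  # label map, (b, 1, x, y)

    tp, fp, fn = get_tp_fp_fn(pred, gt)                        # per-sample, per-class sums, shape (b, c)
    dice = (2 * tp + 1e-5) / (2 * tp + fp + fn + 1e-5)
    print(dice.shape, dice.mean().item())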
def soft_dice(net_output, gt, smooth=1., smooth_in_nom=1., square_nominator=False, square_denom=False):
    axes = tuple(range(2, len(net_output.size())))
    if square_nominator:
        # square the per-voxel products before summation
        intersect = sum_tensor((net_output * gt) ** 2, axes, keepdim=False)
    else:
        intersect = sum_tensor(net_output * gt, axes, keepdim=False)

    if square_denom:
        denom = sum_tensor(net_output ** 2 + gt ** 2, axes, keepdim=False)
    else:
        denom = sum_tensor(net_output + gt, axes, keepdim=False)

    result = (- ((2 * intersect + smooth_in_nom) / (denom + smooth))).mean()
    return result
def forward(self, x, y, loss_mask=None):
    shp_x = x.shape
    shp_y = y.shape

    if self.batch_dice:
        axes = [0] + list(range(2, len(shp_x)))
    else:
        axes = list(range(2, len(shp_x)))

    if len(shp_x) != len(shp_y):
        y = y.view((shp_y[0], 1, *shp_y[1:]))

    if all([i == j for i, j in zip(x.shape, y.shape)]):
        # if this is the case then gt is probably already a one hot encoding
        y_onehot = y
    else:
        gt = y.long()
        y_onehot = torch.zeros(shp_x)
        if x.device.type == "cuda":
            y_onehot = y_onehot.cuda(x.device.index)
        y_onehot.scatter_(1, gt, 1)

    if self.apply_nonlin is not None:
        x = self.apply_nonlin(x)

    if not self.do_bg:
        x = x[:, 1:]
        y_onehot = y_onehot[:, 1:]

    tp, fp, fn, _ = get_tp_fp_fn_tn(x, y_onehot, axes, loss_mask, self.square)

    # GDL weight computation, we use 1/V
    volumes = sum_tensor(y_onehot, axes) + 1e-6  # add some eps to prevent div by zero
    if self.square_volumes:
        volumes = volumes ** 2

    # apply weights
    tp = tp / volumes
    fp = fp / volumes
    fn = fn / volumes

    # sum over classes
    if self.batch_dice:
        axis = 0
    else:
        axis = 1

    tp = tp.sum(axis, keepdim=False)
    fp = fp.sum(axis, keepdim=False)
    fn = fn.sum(axis, keepdim=False)

    # compute dice
    dc = (2 * tp + self.smooth) / (2 * tp + fp + fn + self.smooth)

    dc = dc.mean()
    return -dc
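# Toy illustration (an assumption-laden sketch with made-up per-class counts, not nnU-Net's API) of
# the 1/V weighting used in the GDL forward above: dividing per-class tp/fp/fn by the class volume
# before summing over classes keeps rare classes from being drowned out by large ones.
import torch

if __name__ == "__main__":
    tp = torch.tensor([900.0, 40.0, 5.0])          # per-class true positive mass
    fp = torch.tensor([50.0, 10.0, 2.0])
    fn = torch.tensor([30.0, 15.0, 4.0])
    volumes = torch.tensor([1000.0, 60.0, 10.0])   # per-class reference volume (sum of one-hot gt)

    tp_w, fp_w, fn_w = tp / volumes, fp / volumes, fn / volumes
    gdl_dice = (2 * tp_w.sum() + 1e-5) / (2 * tp_w.sum() + fp_w.sum() + fn_w.sum() + 1e-5)
    plain_dice = (2 * tp.sum() + 1e-5) / (2 * tp.sum() + fp.sum() + fn.sum() + 1e-5)
    print(gdl_dice.item(), plain_dice.item())      # the weighted score is pulled down by the poorly segmented rare classes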
def forward(self, x, y, loss_mask=None):
    shp_x = x.shape
    shp_y = y.shape

    if self.batch_dice:
        axes = [0] + list(range(2, len(shp_x)))
    else:
        axes = list(range(2, len(shp_x)))

    if self.apply_nonlin is not None:
        x = self.apply_nonlin(x)

    with torch.no_grad():
        if len(shp_x) != len(shp_y):
            y = y.view((shp_y[0], 1, *shp_y[1:]))

        if all([i == j for i, j in zip(x.shape, y.shape)]):
            # if this is the case then gt is probably already a one hot encoding
            y_onehot = y
        else:
            y = y.long()
            y_onehot = torch.zeros(shp_x)
            if x.device.type == "cuda":
                y_onehot = y_onehot.cuda(x.device.index)
            y_onehot.scatter_(1, y, 1).float()

    intersect = x * y_onehot
    # values in the denominator get smoothed
    denominator = x ** 2 + y_onehot ** 2

    # aggregation was previously done in get_tp_fp_fn, but needs to be done here now (needs to be done after
    # squaring)
    intersect = sum_tensor(intersect, axes, False) + self.smooth
    denominator = sum_tensor(denominator, axes, False) + self.smooth

    dc = 2 * intersect / denominator

    if not self.do_bg:
        if self.batch_dice:
            dc = dc[1:]
        else:
            dc = dc[:, 1:]
    dc = dc.mean()

    return -dc
def forward(self, net_output, gt, loss_mask=None):
    shp_x = net_output.shape
    shp_y = gt.shape
    # class_num = shp_x[1]

    with torch.no_grad():
        if len(shp_x) != len(shp_y):
            gt = gt.view((shp_y[0], 1, *shp_y[1:]))

        if all([i == j for i, j in zip(net_output.shape, gt.shape)]):
            # if this is the case then gt is probably already a one hot encoding
            y_onehot = gt
        else:
            gt = gt.long()
            y_onehot = torch.zeros(shp_x)
            if net_output.device.type == "cuda":
                y_onehot = y_onehot.cuda(net_output.device.index)
            y_onehot.scatter_(1, gt, 1)

    if self.batch_dice:
        axes = [0] + list(range(2, len(shp_x)))
    else:
        axes = list(range(2, len(shp_x)))

    if self.apply_nonlin is not None:
        softmax_output = self.apply_nonlin(net_output)
    else:
        softmax_output = net_output  # assume probabilities are passed in directly

    # no object value
    bg_onehot = 1 - y_onehot
    squared_error = (y_onehot - softmax_output) ** 2
    specificity_part = sum_tensor(squared_error * y_onehot, axes) / (sum_tensor(y_onehot, axes) + self.smooth)
    sensitivity_part = sum_tensor(squared_error * bg_onehot, axes) / (sum_tensor(bg_onehot, axes) + self.smooth)

    ss = self.r * specificity_part + (1 - self.r) * sensitivity_part

    if not self.do_bg:
        if self.batch_dice:
            ss = ss[1:]
        else:
            ss = ss[:, 1:]
    ss = ss.mean()

    return ss
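# Minimal numeric sketch (the flattened shapes and values are assumptions) of the split computed in
# the forward above: the squared error is averaged separately over the ground-truth foreground and
# background of each class, and the two parts are mixed with weight r.
import torch

if __name__ == "__main__":
    softmax_output = torch.tensor([[0.8, 0.2], [0.3, 0.7]])   # 2 voxels (flattened), 2 classes
    y_onehot = torch.tensor([[1.0, 0.0], [0.0, 1.0]])
    bg_onehot = 1 - y_onehot
    squared_error = (y_onehot - softmax_output) ** 2

    r, smooth = 0.1, 1e-5
    fg_part = (squared_error * y_onehot).sum(0) / (y_onehot.sum(0) + smooth)
    bg_part = (squared_error * bg_onehot).sum(0) / (bg_onehot.sum(0) + smooth)
    ss = r * fg_part + (1 - r) * bg_part
    print(ss)   # per-class error, lower is better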
def mget_tp_fp_fn(net_output, gt, focal_conduct, axes=None, mask=None, square=False):
    """
    net_output must be (b, c, x, y(, z))
    gt must be a label map (shape (b, 1, x, y(, z)) OR shape (b, x, y(, z))) or one hot encoding (b, c, x, y(, z))
    if mask is provided it must have shape (b, 1, x, y(, z))
    :param net_output:
    :param gt:
    :param axes:
    :param mask: mask must be 1 for valid pixels and 0 for invalid pixels
    :param square: if True then fp, tp and fn will be squared before summation
    :return:
    """
    if axes is None:
        axes = tuple(range(2, len(net_output.size())))

    shp_x = net_output.shape
    shp_y = gt.shape

    with torch.no_grad():
        if len(shp_x) != len(shp_y):
            gt = gt.view((shp_y[0], 1, *shp_y[1:]))

        if all([i == j for i, j in zip(net_output.shape, gt.shape)]):
            # if this is the case then gt is probably already a one hot encoding
            y_onehot = gt
        else:
            gt = gt.long()
            y_onehot = torch.zeros(shp_x)
            if net_output.device.type == "cuda":
                y_onehot = y_onehot.cuda(net_output.device.index)
            y_onehot.scatter_(1, gt, 1)

    # here is a little dummy... for simplicity, focal_conduct is [N, C] and y_onehot is [B, C, H, W, D]
    # first reshape [N, C] to [B, H, W, D, C]
    focal_conduct = torch.reshape(focal_conduct, (net_output.shape[0], net_output.shape[2], net_output.shape[3],
                                                  net_output.shape[4], net_output.shape[1]))
    # then transpose [B, H, W, D, C] back to [B, C, H, W, D]
    focal_conduct = focal_conduct.transpose(4, 3)
    focal_conduct = focal_conduct.transpose(3, 2)
    focal_conduct = focal_conduct.transpose(2, 1)

    # this is like focal Tversky loss, but it would suppress too much
    # focal_fp = net_output * (1 - y_onehot) * focal_conduct2
    focal_fp = net_output * (1 - y_onehot)
    focal_fn = (1 - net_output) * y_onehot * focal_conduct
    tp = net_output * y_onehot
    fp = net_output * (1 - y_onehot)
    fn = (1 - net_output) * y_onehot

    if mask is not None:
        tp = torch.stack(tuple(x_i * mask[:, 0] for x_i in torch.unbind(tp, dim=1)), dim=1)
        focal_fp = torch.stack(tuple(x_i * mask[:, 0] for x_i in torch.unbind(focal_fp, dim=1)), dim=1)
        focal_fn = torch.stack(tuple(x_i * mask[:, 0] for x_i in torch.unbind(focal_fn, dim=1)), dim=1)
        fp = torch.stack(tuple(x_i * mask[:, 0] for x_i in torch.unbind(fp, dim=1)), dim=1)
        fn = torch.stack(tuple(x_i * mask[:, 0] for x_i in torch.unbind(fn, dim=1)), dim=1)

    if square:
        tp = tp ** 2
        focal_fp = focal_fp ** 2
        focal_fn = focal_fn ** 2
        fp = fp ** 2
        fn = fn ** 2

    tp = sum_tensor(tp, axes, keepdim=False)
    focal_fp = sum_tensor(focal_fp, axes, keepdim=False)
    focal_fn = sum_tensor(focal_fn, axes, keepdim=False)
    fp = sum_tensor(fp, axes, keepdim=False)
    fn = sum_tensor(fn, axes, keepdim=False)

    return focal_fp, focal_fn, tp, fp, fn
def run_iteration(self, data_generator, do_backprop=True, run_online_evaluation=False):
    data_dict = next(data_generator)
    data = data_dict['data']
    target = data_dict['target']

    data = maybe_to_torch(data)
    target = maybe_to_torch(target)

    data = to_cuda(data, gpu_id=None)
    target = to_cuda(target, gpu_id=None)

    self.optimizer.zero_grad()

    output = self.network(data)
    del data

    total_loss = None
    for i in range(len(output)):
        # Starting here it gets spicy!
        axes = tuple(range(2, len(output[i].size())))

        # network does not do softmax. We need to do softmax for dice
        output_softmax = softmax_helper(output[i])

        # get the tp, fp and fn terms we need
        tp, fp, fn, _ = get_tp_fp_fn_tn(output_softmax, target[i], axes, mask=None)
        # for dice, compute nominator and denominator so that we have to accumulate only 2 instead of 3 variables
        # do_bg=False in nnUNetTrainer -> [:, 1:]
        nominator = 2 * tp[:, 1:]
        denominator = 2 * tp[:, 1:] + fp[:, 1:] + fn[:, 1:]

        if self.batch_dice:
            # for DDP we need to gather all nominator and denominator terms from all GPUs to do proper batch dice
            nominator = awesome_allgather_function.apply(nominator)
            denominator = awesome_allgather_function.apply(denominator)
            nominator = nominator.sum(0)
            denominator = denominator.sum(0)
        else:
            pass

        ce_loss = self.ce_loss(output[i], target[i])

        # we smooth by 1e-5 to penalize false positives if tp is 0
        dice_loss = (-(nominator + 1e-5) / (denominator + 1e-5)).mean()
        if total_loss is None:
            total_loss = self.ds_loss_weights[i] * (ce_loss + dice_loss)
        else:
            total_loss += self.ds_loss_weights[i] * (ce_loss + dice_loss)

    if run_online_evaluation:
        with torch.no_grad():
            num_classes = output[0].shape[1]
            output_seg = output[0].argmax(1)
            target = target[0][:, 0]
            axes = tuple(range(1, len(target.shape)))
            tp_hard = torch.zeros((target.shape[0], num_classes - 1)).to(output_seg.device.index)
            fp_hard = torch.zeros((target.shape[0], num_classes - 1)).to(output_seg.device.index)
            fn_hard = torch.zeros((target.shape[0], num_classes - 1)).to(output_seg.device.index)
            for c in range(1, num_classes):
                tp_hard[:, c - 1] = sum_tensor((output_seg == c).float() * (target == c).float(), axes=axes)
                fp_hard[:, c - 1] = sum_tensor((output_seg == c).float() * (target != c).float(), axes=axes)
                fn_hard[:, c - 1] = sum_tensor((output_seg != c).float() * (target == c).float(), axes=axes)

            # tp_hard, fp_hard, fn_hard = get_tp_fp_fn((output_softmax > (1 / num_classes)).float(), target,
            #                                          axes, None)
            # print_if_rank0("before allgather", tp_hard.shape)
            tp_hard = tp_hard.sum(0, keepdim=False)[None]
            fp_hard = fp_hard.sum(0, keepdim=False)[None]
            fn_hard = fn_hard.sum(0, keepdim=False)[None]

            tp_hard = awesome_allgather_function.apply(tp_hard)
            fp_hard = awesome_allgather_function.apply(fp_hard)
            fn_hard = awesome_allgather_function.apply(fn_hard)
            # print_if_rank0("after allgather", tp_hard.shape)
            # print_if_rank0("after sum", tp_hard.shape)

            self.run_online_evaluation(tp_hard.detach().cpu().numpy().sum(0),
                                       fp_hard.detach().cpu().numpy().sum(0),
                                       fn_hard.detach().cpu().numpy().sum(0))
    del target

    if do_backprop:
        if not self.fp16 or amp is None:
            total_loss.backward()
        else:
            with amp.scale_loss(total_loss, self.optimizer) as scaled_loss:
                scaled_loss.backward()
        _ = clip_grad_norm_(self.network.parameters(), 12)
        self.optimizer.step()

    return total_loss.detach().cpu().numpy()
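# Why the nominator/denominator are allgathered and summed before dividing in run_iteration above
# (a toy sketch with made-up per-GPU numbers, not the actual DDP code path): batch dice over the
# combined batch is not the mean of per-GPU dice scores.
import torch

if __name__ == "__main__":
    nom = torch.tensor([[8.0], [2.0]])     # 2 * tp per "GPU", one foreground class
    den = torch.tensor([[10.0], [20.0]])   # 2 * tp + fp + fn per "GPU"

    mean_of_per_gpu_dice = (nom / den).mean()   # 0.45
    true_batch_dice = nom.sum(0) / den.sum(0)   # 10 / 30 = 0.333...
    print(mean_of_per_gpu_dice.item(), true_batch_dice.item())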
def forward(self, x, y=None, return_hard_tp_fp_fn=False):
    res = super(Generic_UNet_DP, self).forward(x)  # regular Generic_UNet forward pass

    if y is None:
        return res
    else:
        # compute ce loss
        if self._deep_supervision and self.do_ds:
            ce_losses = [self.ce_loss(res[0], y[0]).unsqueeze(0)]
            tps = []
            fps = []
            fns = []

            res_softmax = softmax_helper(res[0])
            tp, fp, fn, _ = get_tp_fp_fn_tn(res_softmax, y[0])
            tps.append(tp)
            fps.append(fp)
            fns.append(fn)
            for i in range(1, len(y)):
                ce_losses.append(self.ce_loss(res[i], y[i]).unsqueeze(0))
                res_softmax = softmax_helper(res[i])
                tp, fp, fn, _ = get_tp_fp_fn_tn(res_softmax, y[i])
                tps.append(tp)
                fps.append(fp)
                fns.append(fn)
            ret = ce_losses, tps, fps, fns
        else:
            ce_loss = self.ce_loss(res, y).unsqueeze(0)

            # tp fp and fn need the output to be softmax
            res_softmax = softmax_helper(res)
            tp, fp, fn, _ = get_tp_fp_fn_tn(res_softmax, y)
            ret = ce_loss, tp, fp, fn

        if return_hard_tp_fp_fn:
            if self._deep_supervision and self.do_ds:
                output = res[0]
                target = y[0]
            else:
                target = y
                output = res

            with torch.no_grad():
                num_classes = output.shape[1]
                output_softmax = softmax_helper(output)
                output_seg = output_softmax.argmax(1)
                target = target[:, 0]
                axes = tuple(range(1, len(target.shape)))
                tp_hard = torch.zeros((target.shape[0], num_classes - 1)).to(output_seg.device.index)
                fp_hard = torch.zeros((target.shape[0], num_classes - 1)).to(output_seg.device.index)
                fn_hard = torch.zeros((target.shape[0], num_classes - 1)).to(output_seg.device.index)
                for c in range(1, num_classes):
                    tp_hard[:, c - 1] = sum_tensor((output_seg == c).float() * (target == c).float(), axes=axes)
                    fp_hard[:, c - 1] = sum_tensor((output_seg == c).float() * (target != c).float(), axes=axes)
                    fn_hard[:, c - 1] = sum_tensor((output_seg != c).float() * (target == c).float(), axes=axes)

                tp_hard = tp_hard.sum(0, keepdim=False)[None]
                fp_hard = fp_hard.sum(0, keepdim=False)[None]
                fn_hard = fn_hard.sum(0, keepdim=False)[None]

            ret = *ret, tp_hard, fp_hard, fn_hard
        return ret
def soft_dice_per_batch_2(net_output, gt, smooth=1., smooth_in_nom=1., background_weight=1,
                          rebalance_weights=None, square_nominator=False, square_denom=False):
    if rebalance_weights is not None and len(rebalance_weights) != gt.shape[1]:
        rebalance_weights = rebalance_weights[1:]  # this is the case when use_bg=False
    # observed: rebalance_weights is None
    axes = tuple([0] + list(range(2, len(net_output.size()))))  # e.g. (0, 2, 3, 4)
    # observed shapes: net_output torch.Size([8, 2, 192, 192, 48]), gt torch.Size([8, 2, 192, 192, 48])
    tp = sum_tensor(net_output * gt, axes, keepdim=False)
    # tp has shape torch.Size([2]), e.g. tensor([62684.4570, 82510.1562])
    fn = sum_tensor((1 - net_output) * gt, axes, keepdim=False)
    # fn has shape torch.Size([2]), e.g. tensor([195664.5312, 103144.8438])
    fp = sum_tensor(net_output * (1 - gt), axes, keepdim=False)
    # fp has shape torch.Size([2]), e.g. tensor([3610596., 6475380.])
    weights = torch.ones(tp.shape)  # shape torch.Size([2])
    weights[0] = background_weight  # background_weight is 1 by default
    if net_output.device.type == "cuda":
        weights = weights.cuda(net_output.device.index)
    if rebalance_weights is not None:
        rebalance_weights = torch.from_numpy(rebalance_weights).float()
        if net_output.device.type == "cuda":
            rebalance_weights = rebalance_weights.cuda(net_output.device.index)
        tp = tp * rebalance_weights
        fn = fn * rebalance_weights

    nominator = tp
    if square_nominator:
        nominator = nominator ** 2

    if square_denom:
        denom = 2 * tp ** 2 + fp ** 2 + fn ** 2
    else:
        denom = 2 * tp + fp + fn

    # the previous version returned the plain weighted soft dice, set as negative:
    # result = (- ((2 * nominator + smooth_in_nom) / (denom + smooth)) * weights).mean()
    dice_1 = ((2 * nominator + smooth_in_nom) / (denom + smooth)) * weights
    # instead, combine background and foreground as a weighted sum of powered negative-log dice terms
    result_1 = torch.pow(-torch.log(dice_1[0]), 0.3) * 0.4 + torch.pow(-torch.log(dice_1[1]), 0.3) * 0.6
    return result_1
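# Numeric sketch of the composite objective returned by soft_dice_per_batch_2 above: a weighted sum
# of powered negative-log dice terms, 0.4 for background (index 0) and 0.6 for foreground (index 1).
# The per-class dice values below are made up for illustration.
import torch

if __name__ == "__main__":
    dice_1 = torch.tensor([0.90, 0.55])   # hypothetical [background, foreground] soft dice
    result_1 = torch.pow(-torch.log(dice_1[0]), 0.3) * 0.4 + torch.pow(-torch.log(dice_1[1]), 0.3) * 0.6
    print(result_1.item())                # shrinks toward 0 as both dice values approach 1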