def compute_loss(self, output, target):
    """Deep-supervision loss: sum over resolutions of ds_loss_weights[i] * (CE + soft dice).

    `output` and `target` are matched lists with one entry per deep-supervision
    resolution. Dice is computed without the background channel ([:, 1:],
    do_bg=False convention). When batch_dice is on, the dice nominator and
    denominator are all-gathered across DDP ranks and summed so every rank
    computes the same global batch dice.

    Returns the accumulated scalar loss tensor.
    """
    total_loss = None
    for i, out in enumerate(output):
        tgt = target[i]
        # spatial axes start after (batch, channel)
        spatial_axes = tuple(range(2, out.dim()))

        # the network emits raw logits; dice needs class probabilities
        probs = softmax_helper(out)

        tp, fp, fn, _ = get_tp_fp_fn_tn(probs, tgt, spatial_axes, mask=None)

        # accumulate only nominator/denominator (2 tensors instead of 3);
        # [:, 1:] drops the background class
        numer = 2 * tp[:, 1:]
        denom = 2 * tp[:, 1:] + fp[:, 1:] + fn[:, 1:]

        if self.batch_dice:
            # gather per-rank terms so batch dice spans the whole DDP batch
            numer = awesome_allgather_function.apply(numer).sum(0)
            denom = awesome_allgather_function.apply(denom).sum(0)

        ce = self.ce_loss(out, tgt[:, 0].long())

        # 1e-5 smoothing penalizes false positives when tp is 0
        dice = (- (numer + 1e-5) / (denom + 1e-5)).mean()

        weighted = self.ds_loss_weights[i] * (ce + dice)
        total_loss = weighted if total_loss is None else total_loss + weighted
    return total_loss
def run_online_evaluation(self, output, target):
    """Accumulate hard-dice statistics for the highest-resolution output.

    Thresholds sigmoid(output[0]) at 0.5, computes per-class tp/fp/fn over
    batch + spatial axes, and appends the pseudo-dice and the raw counts to
    the online evaluation lists on self.
    """
    highest_res_out = output[0]
    highest_res_tgt = target[0]
    with torch.no_grad():
        # hard binary prediction from sigmoid activations
        hard_pred = (torch.sigmoid(highest_res_out) > 0.5).float()
        # reduce over batch and spatial dims, keep the class dim
        axes = (0, 2, 3, 4) if self.threeD else (0, 2, 3)
        tp, fp, fn, _ = get_tp_fp_fn_tn(hard_pred, highest_res_tgt, axes=axes)
        tp_np = tp.detach().cpu().numpy()
        fp_np = fp.detach().cpu().numpy()
        fn_np = fn.detach().cpu().numpy()
        # 1e-8 guards against division by zero for absent classes
        dice = (2 * tp_np) / (2 * tp_np + fp_np + fn_np + 1e-8)
        self.online_eval_foreground_dc.append(list(dice))
        self.online_eval_tp.append(list(tp_np))
        self.online_eval_fp.append(list(fp_np))
        self.online_eval_fn.append(list(fn_np))
def run_iteration(self, data_generator, do_backprop=True, run_online_evaluation=False):
    """Disabled training iteration (sigmoid/region-based variant).

    Raises NotImplementedError immediately: this class has not been adapted
    to pytorch amp. Everything below the raise is UNREACHABLE and kept only
    as reference for the intended iteration logic.
    """
    raise NotImplementedError("this class has not been changed to work with pytorch amp yet!")
    data_dict = next(data_generator)
    data = data_dict['data']
    target = data_dict['target']

    data = maybe_to_torch(data)
    target = maybe_to_torch(target)

    if torch.cuda.is_available():
        data = to_cuda(data, gpu_id=None)
        target = to_cuda(target, gpu_id=None)

    self.optimizer.zero_grad()

    output = self.network(data)
    # free the input activations' reference as early as possible
    del data

    total_loss = None
    for i in range(len(output)):
        # Starting here it gets spicy!
        axes = tuple(range(2, len(output[i].size())))

        # network does not do softmax. We need to do softmax for dice
        # NOTE(review): the comment above says softmax but the code applies
        # sigmoid (region-based variant?) — unreachable anyway; confirm intent.
        output_softmax = torch.sigmoid(output[i])

        # get the tp, fp and fn terms we need
        tp, fp, fn, _ = get_tp_fp_fn_tn(output_softmax, target[i], axes, mask=None)
        # for dice, compute nominator and denominator so that we have to accumulate only 2 instead of 3 variables
        # do_bg=False in nnUNetTrainer -> [:, 1:]
        nominator = 2 * tp[:, 1:]
        denominator = 2 * tp[:, 1:] + fp[:, 1:] + fn[:, 1:]

        if self.batch_dice:
            # for DDP we need to gather all nominator and denominator terms from all GPUS to do proper batch dice
            nominator = awesome_allgather_function.apply(nominator)
            denominator = awesome_allgather_function.apply(denominator)
            nominator = nominator.sum(0)
            denominator = denominator.sum(0)
        else:
            pass

        # NOTE(review): target passed without [:, 0].long() here, unlike
        # compute_loss elsewhere in this file — presumably self.ce_loss
        # accepts this target layout; verify.
        ce_loss = self.ce_loss(output[i], target[i])

        # we smooth by 1e-5 to penalize false positives if tp is 0
        dice_loss = (- (nominator + 1e-5) / (denominator + 1e-5)).mean()
        if total_loss is None:
            total_loss = self.ds_loss_weights[i] * (ce_loss + dice_loss)
        else:
            total_loss += self.ds_loss_weights[i] * (ce_loss + dice_loss)

    if run_online_evaluation:
        with torch.no_grad():
            # evaluate only the highest-resolution head
            output = output[0]
            target = target[0]
            out_sigmoid = torch.sigmoid(output)
            out_sigmoid = (out_sigmoid > 0.5).float()
            if self.threeD:
                axes = (2, 3, 4)
            else:
                axes = (2, 3)
            tp, fp, fn, _ = get_tp_fp_fn_tn(out_sigmoid, target, axes=axes)
            # gather counts from all DDP ranks before aggregating
            tp_hard = awesome_allgather_function.apply(tp)
            fp_hard = awesome_allgather_function.apply(fp)
            fn_hard = awesome_allgather_function.apply(fn)
            # print_if_rank0("after allgather", tp_hard.shape)
            # print_if_rank0("after sum", tp_hard.shape)
            # NOTE(review): called with three count arrays, implying a
            # run_online_evaluation(tp, fp, fn) signature — differs from the
            # (output, target) version elsewhere in this file; confirm.
            self.run_online_evaluation(tp_hard.detach().cpu().numpy().sum(0),
                                       fp_hard.detach().cpu().numpy().sum(0),
                                       fn_hard.detach().cpu().numpy().sum(0))
    del target

    if do_backprop:
        # plain backward unless fp16 via apex amp is available and usable
        if not self.fp16 or amp is None or not torch.cuda.is_available():
            total_loss.backward()
        else:
            with amp.scale_loss(total_loss, self.optimizer) as scaled_loss:
                scaled_loss.backward()
        # gradient clipping at max-norm 12, matching nnUNet convention
        _ = clip_grad_norm_(self.network.parameters(), 12)
        self.optimizer.step()

    return total_loss.detach().cpu().numpy()
def run_iteration(self, data_generator, do_backprop=True, run_online_evaluation=False):
    """Run one DDP training/eval iteration.

    Pulls a batch from data_generator, forwards it, accumulates the
    deep-supervision CE + soft-dice loss (with cross-rank allgather of dice
    terms when batch_dice is set), optionally computes hard per-class
    tp/fp/fn for online evaluation, then backprops with optional apex-amp
    loss scaling and gradient clipping.

    Returns the loss as a detached numpy scalar.
    """
    data_dict = next(data_generator)
    data = data_dict['data']
    target = data_dict['target']

    data = maybe_to_torch(data)
    target = maybe_to_torch(target)

    # NOTE(review): no cuda-availability guard here (unlike the amp-disabled
    # variant) — this path assumes a GPU is present.
    data = to_cuda(data, gpu_id=None)
    target = to_cuda(target, gpu_id=None)

    self.optimizer.zero_grad()

    output = self.network(data)
    # free the input reference before the loss loop to reduce peak memory
    del data

    total_loss = None
    for i in range(len(output)):
        # Starting here it gets spicy!
        axes = tuple(range(2, len(output[i].size())))

        # network does not do softmax. We need to do softmax for dice
        output_softmax = softmax_helper(output[i])

        # get the tp, fp and fn terms we need
        tp, fp, fn, _ = get_tp_fp_fn_tn(output_softmax, target[i], axes, mask=None)
        # for dice, compute nominator and denominator so that we have to accumulate only 2 instead of 3 variables
        # do_bg=False in nnUNetTrainer -> [:, 1:]
        nominator = 2 * tp[:, 1:]
        denominator = 2 * tp[:, 1:] + fp[:, 1:] + fn[:, 1:]

        if self.batch_dice:
            # for DDP we need to gather all nominator and denominator terms from all GPUS to do proper batch dice
            nominator = awesome_allgather_function.apply(nominator)
            denominator = awesome_allgather_function.apply(denominator)
            nominator = nominator.sum(0)
            denominator = denominator.sum(0)
        else:
            pass

        # NOTE(review): target passed as-is, whereas compute_loss in this
        # file uses target[i][:, 0].long() — presumably self.ce_loss here
        # handles the channel dim / dtype itself; confirm against its type.
        ce_loss = self.ce_loss(output[i], target[i])

        # we smooth by 1e-5 to penalize false positives if tp is 0
        dice_loss = (-(nominator + 1e-5) / (denominator + 1e-5)).mean()
        if total_loss is None:
            total_loss = self.ds_loss_weights[i] * (ce_loss + dice_loss)
        else:
            total_loss += self.ds_loss_weights[i] * (ce_loss + dice_loss)

    if run_online_evaluation:
        with torch.no_grad():
            # hard metrics on the highest-resolution head only
            num_classes = output[0].shape[1]
            output_seg = output[0].argmax(1)
            # drop the (apparent) singleton channel dim of the target
            target = target[0][:, 0]
            axes = tuple(range(1, len(target.shape)))
            # per-sample, per-foreground-class count buffers on the same device
            tp_hard = torch.zeros(
                (target.shape[0], num_classes - 1)).to(output_seg.device.index)
            fp_hard = torch.zeros(
                (target.shape[0], num_classes - 1)).to(output_seg.device.index)
            fn_hard = torch.zeros(
                (target.shape[0], num_classes -
                 1)).to(output_seg.device.index)
            for c in range(1, num_classes):
                tp_hard[:, c - 1] = sum_tensor(
                    (output_seg == c).float() * (target == c).float(), axes=axes)
                fp_hard[:, c - 1] = sum_tensor(
                    (output_seg == c).float() * (target != c).float(), axes=axes)
                fn_hard[:, c - 1] = sum_tensor(
                    (output_seg != c).float() * (target == c).float(), axes=axes)

            # tp_hard, fp_hard, fn_hard = get_tp_fp_fn((output_softmax > (1 / num_classes)).float(), target,
            #                                          axes, None)
            # print_if_rank0("before allgather", tp_hard.shape)
            # sum over the batch, keep a leading axis ([None]) so the
            # allgather below concatenates one row per DDP rank
            tp_hard = tp_hard.sum(0, keepdim=False)[None]
            fp_hard = fp_hard.sum(0, keepdim=False)[None]
            fn_hard = fn_hard.sum(0, keepdim=False)[None]

            tp_hard = awesome_allgather_function.apply(tp_hard)
            fp_hard = awesome_allgather_function.apply(fp_hard)
            fn_hard = awesome_allgather_function.apply(fn_hard)
            # print_if_rank0("after allgather", tp_hard.shape)

            # print_if_rank0("after sum", tp_hard.shape)

            # final .sum(0) collapses the per-rank rows into global counts
            self.run_online_evaluation(
                tp_hard.detach().cpu().numpy().sum(0),
                fp_hard.detach().cpu().numpy().sum(0),
                fn_hard.detach().cpu().numpy().sum(0))
    del target

    if do_backprop:
        # plain backward unless fp16 via apex amp is enabled and importable
        if not self.fp16 or amp is None:
            total_loss.backward()
        else:
            with amp.scale_loss(total_loss, self.optimizer) as scaled_loss:
                scaled_loss.backward()
        # gradient clipping at max-norm 12, matching nnUNet convention
        _ = clip_grad_norm_(self.network.parameters(), 12)
        self.optimizer.step()

    return total_loss.detach().cpu().numpy()
def forward(self, x, y=None, return_hard_tp_fp_fn=False): res = super(Generic_UNet_DP, self).forward(x) # regular Generic_UNet forward pass if y is None: return res else: # compute ce loss if self._deep_supervision and self.do_ds: ce_losses = [self.ce_loss(res[0], y[0]).unsqueeze(0)] tps = [] fps = [] fns = [] res_softmax = softmax_helper(res[0]) tp, fp, fn, _ = get_tp_fp_fn_tn(res_softmax, y[0]) tps.append(tp) fps.append(fp) fns.append(fn) for i in range(1, len(y)): ce_losses.append(self.ce_loss(res[i], y[i]).unsqueeze(0)) res_softmax = softmax_helper(res[i]) tp, fp, fn, _ = get_tp_fp_fn_tn(res_softmax, y[i]) tps.append(tp) fps.append(fp) fns.append(fn) ret = ce_losses, tps, fps, fns else: ce_loss = self.ce_loss(res, y).unsqueeze(0) # tp fp and fn need the output to be softmax res_softmax = softmax_helper(res) tp, fp, fn, _ = get_tp_fp_fn_tn(res_softmax, y) ret = ce_loss, tp, fp, fn if return_hard_tp_fp_fn: if self._deep_supervision and self.do_ds: output = res[0] target = y[0] else: target = y output = res with torch.no_grad(): num_classes = output.shape[1] output_softmax = softmax_helper(output) output_seg = output_softmax.argmax(1) target = target[:, 0] axes = tuple(range(1, len(target.shape))) tp_hard = torch.zeros( (target.shape[0], num_classes - 1)).to(output_seg.device.index) fp_hard = torch.zeros( (target.shape[0], num_classes - 1)).to(output_seg.device.index) fn_hard = torch.zeros( (target.shape[0], num_classes - 1)).to(output_seg.device.index) for c in range(1, num_classes): tp_hard[:, c - 1] = sum_tensor( (output_seg == c).float() * (target == c).float(), axes=axes) fp_hard[:, c - 1] = sum_tensor( (output_seg == c).float() * (target != c).float(), axes=axes) fn_hard[:, c - 1] = sum_tensor( (output_seg != c).float() * (target == c).float(), axes=axes) tp_hard = tp_hard.sum(0, keepdim=False)[None] fp_hard = fp_hard.sum(0, keepdim=False)[None] fn_hard = fn_hard.sum(0, keepdim=False)[None] ret = *ret, tp_hard, fp_hard, fn_hard return ret