def run_online_evaluation(self, output, target):
    """Accumulate hard (argmax-based) per-class foreground statistics for online evaluation.

    For every foreground class ``c`` in ``1 .. num_classes - 1`` the true positive,
    false positive and false negative voxel counts of the argmax segmentation are
    summed over the batch and appended to ``self.online_eval_tp`` / ``self.online_eval_fp``
    / ``self.online_eval_fn``; the per-class Dice ``2TP / (2TP + FP + FN + 1e-8)`` is
    appended to ``self.online_eval_foreground_dc``.

    :param output: network output, class axis at dim 1 (``output.shape[1]`` = number of classes)
    :param target: label map; only channel 0 is used (``target[:, 0]``), so presumably
        shaped (b, 1, x, y(, z)) — confirm against callers
    """
    with torch.no_grad():
        num_classes = output.shape[1]
        output_softmax = softmax_helper(output)
        output_seg = output_softmax.argmax(1)
        target = target[:, 0]
        axes = tuple(range(1, len(target.shape)))

        # Allocate the counters directly on the prediction's device. ``.to(device.index)``
        # would interpret the index as a CUDA ordinal and break on non-CUDA devices.
        counts_shape = (target.shape[0], num_classes - 1)
        tp_hard = torch.zeros(counts_shape, device=output_seg.device)
        fp_hard = torch.zeros(counts_shape, device=output_seg.device)
        fn_hard = torch.zeros(counts_shape, device=output_seg.device)
        for c in range(1, num_classes):
            # binary masks computed once per class and reused for all three counts
            pred_c = (output_seg == c).float()
            gt_c = (target == c).float()
            tp_hard[:, c - 1] = sum_tensor(pred_c * gt_c, axes=axes)
            fp_hard[:, c - 1] = sum_tensor(pred_c * (1 - gt_c), axes=axes)
            fn_hard[:, c - 1] = sum_tensor((1 - pred_c) * gt_c, axes=axes)

        # sum over the batch -> one count per foreground class
        tp_hard = tp_hard.sum(0, keepdim=False).detach().cpu().numpy()
        fp_hard = fp_hard.sum(0, keepdim=False).detach().cpu().numpy()
        fn_hard = fn_hard.sum(0, keepdim=False).detach().cpu().numpy()

        self.online_eval_foreground_dc.append(
            list((2 * tp_hard) / (2 * tp_hard + fp_hard + fn_hard + 1e-8)))
        self.online_eval_tp.append(list(tp_hard))
        self.online_eval_fp.append(list(fp_hard))
        self.online_eval_fn.append(list(fn_hard))
def forward(self, x, y, loss_mask=None):
    """Return the negative generalized Dice (GDL) between ``x`` and ``y``.

    Soft tp/fp/fn are weighted per class by ``1 / V`` (or ``1 / V**2`` when
    ``self.square_volumes``), where ``V`` is the class volume in the one hot target,
    then summed over classes before the smoothed Dice is formed.

    :param x: network output (b, c, x, y(, z)); ``self.apply_nonlin`` is applied if set
    :param y: label map ((b, 1, ...) or (b, ...)) or one hot encoding with the same shape as ``x``
    :param loss_mask: optional mask, forwarded to ``get_tp_fp_fn_tn``
    :return: negative mean generalized Dice (scalar tensor)
    """
    shp_x = x.shape
    shp_y = y.shape

    if self.batch_dice:
        axes = [0] + list(range(2, len(shp_x)))
    else:
        axes = list(range(2, len(shp_x)))

    if len(shp_x) != len(shp_y):
        y = y.view((shp_y[0], 1, *shp_y[1:]))

    if all(i == j for i, j in zip(x.shape, y.shape)):
        # if this is the case then gt is probably already a one hot encoding
        y_onehot = y
    else:
        gt = y.long()
        # created directly on x's device so any backend works, not only CUDA
        y_onehot = torch.zeros(shp_x, device=x.device)
        y_onehot.scatter_(1, gt, 1)

    if self.apply_nonlin is not None:
        x = self.apply_nonlin(x)

    if not self.do_bg:
        x = x[:, 1:]
        y_onehot = y_onehot[:, 1:]

    tp, fp, fn, _ = get_tp_fp_fn_tn(x, y_onehot, axes, loss_mask, self.square)

    # GDL weight computation, we use 1/V
    volumes = sum_tensor(y_onehot, axes) + 1e-6  # add some eps to prevent div by zero

    if self.square_volumes:
        volumes = volumes ** 2

    # apply weights
    tp = tp / volumes
    fp = fp / volumes
    fn = fn / volumes

    # sum over classes: with batch_dice the batch axis was already reduced, so the
    # class axis is 0; otherwise the tensors are (b, c) and the class axis is 1
    axis = 0 if self.batch_dice else 1
    tp = tp.sum(axis, keepdim=False)
    fp = fp.sum(axis, keepdim=False)
    fn = fn.sum(axis, keepdim=False)

    # compute dice
    dc = (2 * tp + self.smooth) / (2 * tp + fp + fn + self.smooth)

    dc = dc.mean()

    return -dc
def forward(self, x, y, loss_mask=None):
    """Return the negative soft Dice using squared terms in the denominator.

    Per class (and per sample unless ``self.batch_dice``) this computes
    ``2 * (sum(x * y) + smooth) / (sum(x**2 + y**2) + smooth)`` and returns the negated
    mean. When ``self.do_bg`` is False the class at index 0 is dropped before averaging.

    :param x: network output (b, c, x, y(, z)); ``self.apply_nonlin`` is applied if set
    :param y: label map ((b, 1, ...) or (b, ...)) or one hot encoding with the same shape as ``x``
    :param loss_mask: accepted for interface compatibility; not used by this loss
    :return: negative mean Dice (scalar tensor)
    """
    shp_x = x.shape
    shp_y = y.shape

    if self.batch_dice:
        axes = [0] + list(range(2, len(shp_x)))
    else:
        axes = list(range(2, len(shp_x)))

    if self.apply_nonlin is not None:
        x = self.apply_nonlin(x)

    with torch.no_grad():
        if len(shp_x) != len(shp_y):
            y = y.view((shp_y[0], 1, *shp_y[1:]))

        if all(i == j for i, j in zip(x.shape, y.shape)):
            # if this is the case then gt is probably already a one hot encoding
            y_onehot = y
        else:
            y = y.long()
            # created directly on x's device (any backend, not only CUDA); torch.zeros
            # is already float, and scatter_ fills it in place, so no cast is needed
            y_onehot = torch.zeros(shp_x, device=x.device)
            y_onehot.scatter_(1, y, 1)

    intersect = x * y_onehot
    # values in the denominator get smoothed
    denominator = x ** 2 + y_onehot ** 2

    # aggregation has to happen here (after squaring) rather than in get_tp_fp_fn_tn
    intersect = sum_tensor(intersect, axes, False) + self.smooth
    denominator = sum_tensor(denominator, axes, False) + self.smooth

    dc = 2 * intersect / denominator

    if not self.do_bg:
        # with batch_dice the class axis is 0, otherwise dc is (b, c)
        if self.batch_dice:
            dc = dc[1:]
        else:
            dc = dc[:, 1:]
    dc = dc.mean()

    return -dc
def run_online_evaluation(self, output, target):
    """Accumulate hard (argmax-based) per-class foreground statistics, aggregated over workers.

    Only ``output[0]`` / ``target[0]`` are evaluated (presumably the full-resolution
    output under deep supervision — confirm). Batch-summed TP/FP/FN counts per
    foreground class are passed through ``awesome_allgather_function`` (presumably an
    all-gather across distributed workers — confirm) and summed over the gathered axis.
    The results are appended to ``self.online_eval_tp`` / ``_fp`` / ``_fn`` and the
    per-class Dice ``2TP / (2TP + FP + FN + 1e-8)`` to ``self.online_eval_foreground_dc``.

    :param output: sequence whose first element is the network output, class axis at dim 1
    :param target: sequence whose first element is the label map; channel 0 is used
    """
    with torch.no_grad():
        num_classes = output[0].shape[1]
        output_seg = output[0].argmax(1)
        target = target[0][:, 0]
        axes = tuple(range(1, len(target.shape)))

        # Allocate the counters directly on the prediction's device. ``.to(device.index)``
        # would interpret the index as a CUDA ordinal and break on non-CUDA devices.
        counts_shape = (target.shape[0], num_classes - 1)
        tp_hard = torch.zeros(counts_shape, device=output_seg.device)
        fp_hard = torch.zeros(counts_shape, device=output_seg.device)
        fn_hard = torch.zeros(counts_shape, device=output_seg.device)
        for c in range(1, num_classes):
            # binary masks computed once per class and reused for all three counts
            pred_c = (output_seg == c).float()
            gt_c = (target == c).float()
            tp_hard[:, c - 1] = sum_tensor(pred_c * gt_c, axes=axes)
            fp_hard[:, c - 1] = sum_tensor(pred_c * (1 - gt_c), axes=axes)
            fn_hard[:, c - 1] = sum_tensor((1 - pred_c) * gt_c, axes=axes)

        # sum over the batch, keeping a leading axis of size 1 for the gather
        tp_hard = tp_hard.sum(0, keepdim=False)[None]
        fp_hard = fp_hard.sum(0, keepdim=False)[None]
        fn_hard = fn_hard.sum(0, keepdim=False)[None]

        tp_hard = awesome_allgather_function.apply(tp_hard)
        fp_hard = awesome_allgather_function.apply(fp_hard)
        fn_hard = awesome_allgather_function.apply(fn_hard)

        # reduce over the gathered leading axis
        tp_hard = tp_hard.detach().cpu().numpy().sum(0)
        fp_hard = fp_hard.detach().cpu().numpy().sum(0)
        fn_hard = fn_hard.detach().cpu().numpy().sum(0)

        self.online_eval_foreground_dc.append(
            list((2 * tp_hard) / (2 * tp_hard + fp_hard + fn_hard + 1e-8)))
        self.online_eval_tp.append(list(tp_hard))
        self.online_eval_fp.append(list(fp_hard))
        self.online_eval_fn.append(list(fn_hard))
def get_tp_fp_fn_tn(net_output, gt, axes=None, mask=None, square=False):
    """
    Compute soft true positive, false positive, false negative and true negative terms.

    net_output must be (b, c, x, y(, z))
    gt must be a label map (shape (b, 1, x, y(, z)) OR shape (b, x, y(, z))) or one hot encoding (b, c, x, y(, z))
    if mask is provided it must have shape (b, 1, x, y(, z))

    :param net_output: network output (e.g. softmax probabilities), (b, c, x, y(, z))
    :param gt: label map or one hot encoding, see above. A boolean one hot encoding is
        cast to net_output's dtype (``1 - gt`` is undefined for bool tensors).
    :param axes: axes to sum over; defaults to all spatial axes (2, 3, ...). ``()`` = no summation
    :param mask: mask must be 1 for valid pixels and 0 for invalid pixels
    :param square: if True then tp, fp, fn and tn will be squared before summation
    :return: tuple (tp, fp, fn, tn)
    """
    if axes is None:
        axes = tuple(range(2, len(net_output.size())))

    shp_x = net_output.shape
    shp_y = gt.shape

    with torch.no_grad():
        if len(shp_x) != len(shp_y):
            # (b, x, y(, z)) -> (b, 1, x, y(, z)); reshape, unlike view, also accepts
            # non-contiguous label maps
            gt = gt.reshape((shp_y[0], 1, *shp_y[1:]))

        if all(i == j for i, j in zip(net_output.shape, gt.shape)):
            # if this is the case then gt is probably already a one hot encoding
            y_onehot = gt
            if y_onehot.dtype == torch.bool:
                y_onehot = y_onehot.to(net_output.dtype)
        else:
            gt = gt.long()
            # created directly on net_output's device so any backend works, not only CUDA
            y_onehot = torch.zeros(shp_x, device=net_output.device)
            y_onehot.scatter_(1, gt, 1)

    tp = net_output * y_onehot
    fp = net_output * (1 - y_onehot)
    fn = (1 - net_output) * y_onehot
    tn = (1 - net_output) * (1 - y_onehot)

    if mask is not None:
        # the (b, 1, ...) mask broadcasts over the class axis
        channel_mask = mask[:, :1]
        tp = tp * channel_mask
        fp = fp * channel_mask
        fn = fn * channel_mask
        tn = tn * channel_mask

    if square:
        tp = tp ** 2
        fp = fp ** 2
        fn = fn ** 2
        tn = tn ** 2

    if len(axes) > 0:
        tp = sum_tensor(tp, axes, keepdim=False)
        fp = sum_tensor(fp, axes, keepdim=False)
        fn = sum_tensor(fn, axes, keepdim=False)
        tn = sum_tensor(tn, axes, keepdim=False)

    return tp, fp, fn, tn
def forward(self, x, y=None, return_hard_tp_fp_fn=False):
    """Run the regular Generic_UNet forward pass and optionally compute losses in place.

    Computing the losses inside forward presumably lets each data-parallel replica
    evaluate them on its own device — confirm against the trainer.

    :param x: network input
    :param y: targets; a sequence indexed like the outputs when deep supervision is
        active (``self._deep_supervision and self.do_ds``), otherwise a single tensor.
        If None, only the network output is returned.
    :param return_hard_tp_fp_fn: if True, additionally return hard (argmax-based)
        foreground TP/FP/FN counts of the first / only output, each shaped
        (1, num_classes - 1)
    :return: ``res`` if ``y`` is None; otherwise ``(ce_loss, tp, fp, fn)``, where with
        deep supervision each entry is a list with one element per output scale,
        followed by ``(tp_hard, fp_hard, fn_hard)`` when ``return_hard_tp_fp_fn``
    """
    res = super(Generic_UNet_DP, self).forward(x)  # regular Generic_UNet forward pass

    if y is None:
        return res

    deep_supervision = self._deep_supervision and self.do_ds

    # compute ce loss and soft tp/fp/fn (tp fp and fn need the output to be softmax)
    if deep_supervision:
        ce_losses, tps, fps, fns = [], [], [], []
        for i in range(len(y)):
            ce_losses.append(self.ce_loss(res[i], y[i]).unsqueeze(0))
            tp, fp, fn, _ = get_tp_fp_fn_tn(softmax_helper(res[i]), y[i])
            tps.append(tp)
            fps.append(fp)
            fns.append(fn)
        ret = ce_losses, tps, fps, fns
    else:
        ce_loss = self.ce_loss(res, y).unsqueeze(0)
        tp, fp, fn, _ = get_tp_fp_fn_tn(softmax_helper(res), y)
        ret = ce_loss, tp, fp, fn

    if return_hard_tp_fp_fn:
        if deep_supervision:
            output, target = res[0], y[0]
        else:
            output, target = res, y

        with torch.no_grad():
            num_classes = output.shape[1]
            output_seg = softmax_helper(output).argmax(1)
            target = target[:, 0]
            axes = tuple(range(1, len(target.shape)))

            # Allocate the counters directly on the prediction's device.
            # ``.to(device.index)`` would interpret the index as a CUDA ordinal and
            # break on non-CUDA devices.
            counts_shape = (target.shape[0], num_classes - 1)
            tp_hard = torch.zeros(counts_shape, device=output_seg.device)
            fp_hard = torch.zeros(counts_shape, device=output_seg.device)
            fn_hard = torch.zeros(counts_shape, device=output_seg.device)
            for c in range(1, num_classes):
                # binary masks computed once per class and reused for all three counts
                pred_c = (output_seg == c).float()
                gt_c = (target == c).float()
                tp_hard[:, c - 1] = sum_tensor(pred_c * gt_c, axes=axes)
                fp_hard[:, c - 1] = sum_tensor(pred_c * (1 - gt_c), axes=axes)
                fn_hard[:, c - 1] = sum_tensor((1 - pred_c) * gt_c, axes=axes)

            # sum over the batch, keeping a leading axis of size 1
            tp_hard = tp_hard.sum(0, keepdim=False)[None]
            fp_hard = fp_hard.sum(0, keepdim=False)[None]
            fn_hard = fn_hard.sum(0, keepdim=False)[None]

        ret = (*ret, tp_hard, fp_hard, fn_hard)

    return ret