Example #1
    def forward(self, x, y, loss_mask=None):
        print("saw_loss.py:22")

        # create gt density map
        y_cpu = y.cpu().numpy()
        dm = np.empty_like(y_cpu[:, 0:1])
        for i in range(y.shape[0]):
            dm[i, 0] = self.sharpen(y_cpu[i, 0])
            # dm[i, 0] = self.pixel(y_cpu[i, 0])
        # self.save_img(dm, '/cluster/husvogt/debug_imgs/{:04d}_{:03d}.png')
        dm = torch.from_numpy(dm).cuda()
        y_n_ma = torch.sum(dm)
        x_n_ma = torch.sum(x[:, 3])  # -1: = 3:

        print("sum x:", x_n_ma)
        print("sum dm:", y_n_ma)

        l_ = self.dice(x[:, :2], y)  #, loss_mask=loss_mask)
        print("l_:", l_)
        # print("shapes", x.shape, dm.shape)
        l_dm = self.loss_density_map(softmax_helper(x[:, 2:]), dm)
        print("l_dm:", l_dm)
        l_n = self.loss_n_ma(x_n_ma, y_n_ma)
        print("l_n:", l_n)

        return l_ + l_dm + 1e-6 * l_n  # + l_dm + l_n
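Every example on this page runs `softmax_helper` on the raw network logits before computing overlap-based terms. A minimal sketch of that helper, assuming it is simply a softmax over the channel (class) dimension, as in nnU-Net:

import torch.nn.functional as F

def softmax_helper(x):
    # softmax over dim=1, i.e. the class channel of (batch, class, x, y, z) logits
    return F.softmax(x, dim=1)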
Example #2
    def run_online_evaluation(self, output, target):
        with torch.no_grad():
            num_classes = output.shape[1]
            output_softmax = softmax_helper(output)
            output_seg = output_softmax.argmax(1)
            target = target[:, 0]
            axes = tuple(range(1, len(target.shape)))
            tp_hard = torch.zeros(
                (target.shape[0], num_classes - 1)).to(output_seg.device.index)
            fp_hard = torch.zeros(
                (target.shape[0], num_classes - 1)).to(output_seg.device.index)
            fn_hard = torch.zeros(
                (target.shape[0], num_classes - 1)).to(output_seg.device.index)
            # hard (argmax-based) per-class TP / FP / FN, summed over the spatial axes
            for c in range(1, num_classes):
                tp_hard[:, c - 1] = sum_tensor(
                    (output_seg == c).float() * (target == c).float(),
                    axes=axes)
                fp_hard[:, c - 1] = sum_tensor(
                    (output_seg == c).float() * (target != c).float(),
                    axes=axes)
                fn_hard[:, c - 1] = sum_tensor(
                    (output_seg != c).float() * (target == c).float(),
                    axes=axes)

            tp_hard = tp_hard.sum(0, keepdim=False).detach().cpu().numpy()
            fp_hard = fp_hard.sum(0, keepdim=False).detach().cpu().numpy()
            fn_hard = fn_hard.sum(0, keepdim=False).detach().cpu().numpy()

            # per-class foreground Dice from the hard counts: 2*TP / (2*TP + FP + FN)
            self.online_eval_foreground_dc.append(
                list((2 * tp_hard) / (2 * tp_hard + fp_hard + fn_hard + 1e-8)))
            self.online_eval_tp.append(list(tp_hard))
            self.online_eval_fp.append(list(fp_hard))
            self.online_eval_fn.append(list(fn_hard))
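Example #2 reduces the hard TP/FP/FN maps with `sum_tensor`. A plausible sketch of that helper, assuming it just sums over the requested axes (reducing the highest axis first so the remaining axis indices stay valid):

import torch

def sum_tensor(inp, axes, keepdim=False):
    # sum over each requested axis, highest axis first
    for ax in sorted(set(int(a) for a in axes), reverse=True):
        inp = inp.sum(dim=ax, keepdim=keepdim)
    return inp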
Example #3
    def compute_loss(self, output, target):
        total_loss = None
        for i in range(len(output)):
            # Starting here it gets spicy!
            axes = tuple(range(2, len(output[i].size())))

            # network does not do softmax. We need to do softmax for dice
            output_softmax = softmax_helper(output[i])

            # get the tp, fp and fn terms we need
            tp, fp, fn, _ = get_tp_fp_fn_tn(output_softmax, target[i], axes, mask=None)
            # for dice, compute nominator and denominator so that we have to accumulate only 2 instead of 3 variables
            # do_bg=False in nnUNetTrainer -> [:, 1:]
            nominator = 2 * tp[:, 1:]
            denominator = 2 * tp[:, 1:] + fp[:, 1:] + fn[:, 1:]

            if self.batch_dice:
                # for DDP we need to gather all nominator and denominator terms from all GPUS to do proper batch dice
                nominator = awesome_allgather_function.apply(nominator)
                denominator = awesome_allgather_function.apply(denominator)
                nominator = nominator.sum(0)
                denominator = denominator.sum(0)
            else:
                pass

            ce_loss = self.ce_loss(output[i], target[i][:, 0].long())

            # we smooth by 1e-5 to penalize false positives if tp is 0
            dice_loss = (- (nominator + 1e-5) / (denominator + 1e-5)).mean()
            if total_loss is None:
                total_loss = self.ds_loss_weights[i] * (ce_loss + dice_loss)
            else:
                total_loss += self.ds_loss_weights[i] * (ce_loss + dice_loss)
        return total_loss
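A tiny numeric check of the soft-Dice term above: with nominator = 2*tp and denominator = 2*tp + fp + fn, the expression -(nominator + 1e-5) / (denominator + 1e-5) moves towards -1 as the overlap improves, so each deep-supervision level contributes ce_loss plus a negative Dice reward (the counts below are made up for illustration):

import torch

tp, fp, fn = torch.tensor([8.0]), torch.tensor([2.0]), torch.tensor([1.0])
nominator = 2 * tp                      # 16
denominator = 2 * tp + fp + fn          # 19
dice_loss = (-(nominator + 1e-5) / (denominator + 1e-5)).mean()
print(dice_loss)  # ~ -0.8421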
Example #4
    def forward(self, net_output, gt):
        """
        net_output: (batch_size, class, x,y,z)
        target: ground truth, shape: (batch_size, 1, x,y,z)
        bound: precomputed distance map, shape (batch_size, class, x,y,z)
        """
        net_output = softmax_helper(net_output)
        with torch.no_grad():
            if len(net_output.shape) != len(gt.shape):
                gt = gt.view((gt.shape[0], 1, *gt.shape[1:]))

            if all([i == j for i, j in zip(net_output.shape, gt.shape)]):
                # if this is the case then gt is probably already a one hot encoding
                y_onehot = gt
            else:
                gt = gt.long()
                y_onehot = torch.zeros(net_output.shape)
                if net_output.device.type == "cuda":
                    y_onehot = y_onehot.cuda(net_output.device.index)
                y_onehot.scatter_(1, gt, 1)
            gt_sdf = compute_sdf(y_onehot.cpu().numpy(), net_output.shape)

        phi = torch.from_numpy(gt_sdf)
        if phi.device != net_output.device:
            phi = phi.to(net_output.device).type(torch.float32)
        # pred = net_output[:, 1:, ...].type(torch.float32)
        # phi = phi[:,1:, ...].type(torch.float32)

        multipled = torch.einsum("bcxyz,bcxyz->bcxyz", net_output[:, 1:, ...], phi[:, 1:, ...])
        bd_loss = multipled.mean()

        return bd_loss
Example #5
    def forward(self, net_output, target, bound):
        """
        net_output: (batch_size, class, x,y,z)
        target: ground truth, shape: (batch_size, 1, x,y,z)
        bound: precomputed distance map, shape (batch_size, class, x,y,z)
        """
        net_output = softmax_helper(net_output)
        # print('net_output shape: ', net_output.shape)
        pc = net_output[:, 1:, ...].type(torch.float32)
        dc = bound[:, 1:, ...].type(torch.float32)

        multipled = torch.einsum("bcxyz,bcxyz->bcxyz", pc, dc)
        bd_loss = multipled.mean()

        return bd_loss
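The einsum in Examples #4 and #5 contracts nothing: "bcxyz,bcxyz->bcxyz" is an element-wise product that only makes the expected (batch, class, x, y, z) layout explicit. A self-contained check with dummy tensors:

import torch

pc = torch.rand(2, 1, 4, 8, 8)   # predicted foreground probabilities (dummy shapes)
dc = torch.rand(2, 1, 4, 8, 8)   # precomputed distance map for the same classes
assert torch.allclose(torch.einsum("bcxyz,bcxyz->bcxyz", pc, dc).mean(),
                      (pc * dc).mean())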
Example #6
    def forward(self, net_output, gt):
        """
        net_output: (batch_size, 2, x,y,z)
        target: ground truth, shape: (batch_size, 1, x,y,z)
        """
        net_output = softmax_helper(net_output)
        # one hot code for gt
        with torch.no_grad():
            if len(net_output.shape) != len(gt.shape):
                gt = gt.view((gt.shape[0], 1, *gt.shape[1:]))

            if all([i == j for i, j in zip(net_output.shape, gt.shape)]):
                # if this is the case then gt is probably already a one hot encoding
                y_onehot = gt
            else:
                gt = gt.long()
                y_onehot = torch.zeros(net_output.shape)
                if net_output.device.type == "cuda":
                    y_onehot = y_onehot.cuda(net_output.device.index)
                y_onehot.scatter_(1, gt, 1)

        gt_temp = gt[:, 0, ...].type(torch.float32)
        with torch.no_grad():
            dist = compute_edts_forPenalizedLoss(
                gt_temp.cpu().numpy() > 0.5) + 1.0
        # print('dist.shape: ', dist.shape)
        dist = torch.from_numpy(dist)

        if dist.device != net_output.device:
            dist = dist.to(net_output.device).type(torch.float32)

        tp = net_output * y_onehot
        tp = torch.sum(tp[:, 1, ...] * dist, (1, 2, 3))

        dc = (2 * tp + self.smooth) / (torch.sum(net_output[:, 1, ...],
                                                 (1, 2, 3)) +
                                       torch.sum(y_onehot[:, 1, ...],
                                                 (1, 2, 3)) + self.smooth)

        dc = dc.mean()

        return -dc
Example #7
    def forward(self, inp, target):
        target = target[:, 0].long()
        res = super(TopKThreshold, self).forward(inp, target)
        with torch.no_grad():
            prob = softmax_helper(inp)
            num_classes = inp.size()[1]
            i0 = 1
            i1 = 2

            while i1 < len(prob.shape): # this is ugly but torch only allows to transpose two axes at once
                prob = prob.transpose(i0, i1)
                i0 += 1
                i1 += 1

            prob = prob.contiguous()
            prob = prob.view(-1, num_classes)
            prob, _ = torch.max(prob,-1)

        # print('res.shape:', res.shape, 'inp.shape:', inp.shape, 'prob.shape:', prob.shape)
        res = res[prob<self.threshold]
        return res.mean()
Example #8
    def forward(self, net_output, gt):
        """
        net_output: (batch_size, c, x,y,z)
        target: ground truth, shape: (batch_size, c, x,y,z)
        """
        net_output = softmax_helper(net_output)
        # one hot code for gt
        with torch.no_grad():
            if len(net_output.shape) != len(gt.shape):
                gt = gt.view((gt.shape[0], 1, *gt.shape[1:]))

            if all([i == j for i, j in zip(net_output.shape, gt.shape)]):
                # if this is the case then gt is probably already a one hot encoding
                y_onehot = gt
            else:
                gt = gt.long()
                y_onehot = torch.zeros(net_output.shape)
                if net_output.device.type == "cuda":
                    y_onehot = y_onehot.cuda(net_output.device.index)
                y_onehot.scatter_(1, gt, 1)
        # print('hd loss.py', net_output.shape, y_onehot.shape)

        with torch.no_grad():
            pc_dist = compute_pred_dtm(net_output.cpu().numpy(), net_output.shape)
            gt_dist = compute_gt_dtm(y_onehot.cpu().numpy(), net_output.shape)
            dist = pc_dist**2 + gt_dist**2 # \alpha=2 in eq(8)
            # print('pc_dist.shape: ', pc_dist.shape, 'gt_dist.shape', gt_dist.shape)
        
        pred_error = (net_output - y_onehot)**2

        dist = torch.from_numpy(dist)
        if dist.device != pred_error.device:
            dist = dist.to(pred_error.device).type(torch.float32)

        multipled = torch.einsum("bcxyz,bcxyz->bcxyz", pred_error[:,1:,...], dist[:,1:,...])
        hd_loss = multipled.mean()

        return hd_loss
Example #9
    def forward(self, net_output, target):
        """
        net_output: (batch_size, 2, x,y,z)
        target: ground truth, shape: (batch_size, 1, x,y,z)
        """
        net_output = softmax_helper(net_output)
        pc = net_output[:, 1, ...].type(torch.float32)
        gt = target[:, 0, ...].type(torch.float32)
        with torch.no_grad():
            pc_dist = compute_edts_forhdloss(pc.cpu().numpy() > 0.5)
            gt_dist = compute_edts_forhdloss(gt.cpu().numpy() > 0.5)
        # print('pc_dist.shape: ', pc_dist.shape)

        pred_error = (gt - pc)**2
        dist = pc_dist**2 + gt_dist**2  # \alpha=2 in eq(8)

        dist = torch.from_numpy(dist)
        if dist.device != pred_error.device:
            dist = dist.to(pred_error.device).type(torch.float32)

        multipled = torch.einsum("bxyz,bxyz->bxyz", pred_error, dist)
        hd_loss = multipled.mean()

        return hd_loss
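Examples #8 and #9 depend on distance-transform helpers (`compute_edts_forhdloss`, `compute_pred_dtm`, `compute_gt_dtm`) that are not shown here. A plausible sketch of the binary case, assuming a per-sample Euclidean distance transform of both foreground and background via scipy; the helpers in the source repository may batch or normalize this differently:

import numpy as np
from scipy.ndimage import distance_transform_edt

def compute_edts_forhdloss(segmentation):
    # segmentation: boolean array of shape (batch_size, x, y, z)
    res = np.zeros(segmentation.shape, dtype=np.float32)
    for i in range(segmentation.shape[0]):
        posmask = segmentation[i].astype(bool)
        negmask = ~posmask
        # distance to the object boundary, measured on both sides of it
        res[i] = distance_transform_edt(posmask) + distance_transform_edt(negmask)
    return res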
Example #10
    def run_iteration(self,
                      data_generator,
                      do_backprop=True,
                      run_online_evaluation=False):
        data_dict = next(data_generator)
        data = data_dict['data']
        target = data_dict['target']

        data = maybe_to_torch(data)
        target = maybe_to_torch(target)

        data = to_cuda(data, gpu_id=None)
        target = to_cuda(target, gpu_id=None)

        self.optimizer.zero_grad()

        output = self.network(data)
        del data

        total_loss = None
        for i in range(len(output)):
            # Starting here it gets spicy!
            axes = tuple(range(2, len(output[i].size())))

            # network does not do softmax. We need to do softmax for dice
            output_softmax = softmax_helper(output[i])

            # get the tp, fp and fn terms we need
            tp, fp, fn, _ = get_tp_fp_fn_tn(output_softmax,
                                            target[i],
                                            axes,
                                            mask=None)
            # for dice, compute nominator and denominator so that we have to accumulate only 2 instead of 3 variables
            # do_bg=False in nnUNetTrainer -> [:, 1:]
            nominator = 2 * tp[:, 1:]
            denominator = 2 * tp[:, 1:] + fp[:, 1:] + fn[:, 1:]

            if self.batch_dice:
                # for DDP we need to gather all nominator and denominator terms from all GPUS to do proper batch dice
                nominator = awesome_allgather_function.apply(nominator)
                denominator = awesome_allgather_function.apply(denominator)
                nominator = nominator.sum(0)
                denominator = denominator.sum(0)
            else:
                pass

            ce_loss = self.ce_loss(output[i], target[i])

            # we smooth by 1e-5 to penalize false positives if tp is 0
            dice_loss = (-(nominator + 1e-5) / (denominator + 1e-5)).mean()
            if total_loss is None:
                total_loss = self.ds_loss_weights[i] * (ce_loss + dice_loss)
            else:
                total_loss += self.ds_loss_weights[i] * (ce_loss + dice_loss)

        if run_online_evaluation:
            with torch.no_grad():
                num_classes = output[0].shape[1]
                output_seg = output[0].argmax(1)
                target = target[0][:, 0]
                axes = tuple(range(1, len(target.shape)))
                tp_hard = torch.zeros(
                    (target.shape[0],
                     num_classes - 1)).to(output_seg.device.index)
                fp_hard = torch.zeros(
                    (target.shape[0],
                     num_classes - 1)).to(output_seg.device.index)
                fn_hard = torch.zeros(
                    (target.shape[0],
                     num_classes - 1)).to(output_seg.device.index)
                for c in range(1, num_classes):
                    tp_hard[:, c - 1] = sum_tensor(
                        (output_seg == c).float() * (target == c).float(),
                        axes=axes)
                    fp_hard[:, c - 1] = sum_tensor(
                        (output_seg == c).float() * (target != c).float(),
                        axes=axes)
                    fn_hard[:, c - 1] = sum_tensor(
                        (output_seg != c).float() * (target == c).float(),
                        axes=axes)

                # tp_hard, fp_hard, fn_hard = get_tp_fp_fn((output_softmax > (1 / num_classes)).float(), target,
                #                                         axes, None)
                # print_if_rank0("before allgather", tp_hard.shape)
                tp_hard = tp_hard.sum(0, keepdim=False)[None]
                fp_hard = fp_hard.sum(0, keepdim=False)[None]
                fn_hard = fn_hard.sum(0, keepdim=False)[None]

                tp_hard = awesome_allgather_function.apply(tp_hard)
                fp_hard = awesome_allgather_function.apply(fp_hard)
                fn_hard = awesome_allgather_function.apply(fn_hard)
                # print_if_rank0("after allgather", tp_hard.shape)

                # print_if_rank0("after sum", tp_hard.shape)

                self.run_online_evaluation(
                    tp_hard.detach().cpu().numpy().sum(0),
                    fp_hard.detach().cpu().numpy().sum(0),
                    fn_hard.detach().cpu().numpy().sum(0))
        del target

        if do_backprop:
            if not self.fp16 or amp is None:
                total_loss.backward()
            else:
                with amp.scale_loss(total_loss, self.optimizer) as scaled_loss:
                    scaled_loss.backward()
            _ = clip_grad_norm_(self.network.parameters(), 12)
            self.optimizer.step()

        return total_loss.detach().cpu().numpy()
Example #11
    def forward(self, x, y=None, return_hard_tp_fp_fn=False):
        res = super(Generic_UNet_DP,
                    self).forward(x)  # regular Generic_UNet forward pass

        if y is None:
            return res
        else:
            # compute ce loss
            if self._deep_supervision and self.do_ds:
                ce_losses = [self.ce_loss(res[0], y[0]).unsqueeze(0)]
                tps = []
                fps = []
                fns = []

                res_softmax = softmax_helper(res[0])
                tp, fp, fn, _ = get_tp_fp_fn_tn(res_softmax, y[0])
                tps.append(tp)
                fps.append(fp)
                fns.append(fn)
                for i in range(1, len(y)):
                    ce_losses.append(self.ce_loss(res[i], y[i]).unsqueeze(0))
                    res_softmax = softmax_helper(res[i])
                    tp, fp, fn, _ = get_tp_fp_fn_tn(res_softmax, y[i])
                    tps.append(tp)
                    fps.append(fp)
                    fns.append(fn)
                ret = ce_losses, tps, fps, fns
            else:
                ce_loss = self.ce_loss(res, y).unsqueeze(0)

                # tp fp and fn need the output to be softmax
                res_softmax = softmax_helper(res)

                tp, fp, fn, _ = get_tp_fp_fn_tn(res_softmax, y)

                ret = ce_loss, tp, fp, fn

            if return_hard_tp_fp_fn:
                if self._deep_supervision and self.do_ds:
                    output = res[0]
                    target = y[0]
                else:
                    target = y
                    output = res

                with torch.no_grad():
                    num_classes = output.shape[1]
                    output_softmax = softmax_helper(output)
                    output_seg = output_softmax.argmax(1)
                    target = target[:, 0]
                    axes = tuple(range(1, len(target.shape)))
                    tp_hard = torch.zeros(
                        (target.shape[0],
                         num_classes - 1)).to(output_seg.device.index)
                    fp_hard = torch.zeros(
                        (target.shape[0],
                         num_classes - 1)).to(output_seg.device.index)
                    fn_hard = torch.zeros(
                        (target.shape[0],
                         num_classes - 1)).to(output_seg.device.index)
                    for c in range(1, num_classes):
                        tp_hard[:, c - 1] = sum_tensor(
                            (output_seg == c).float() * (target == c).float(),
                            axes=axes)
                        fp_hard[:, c - 1] = sum_tensor(
                            (output_seg == c).float() * (target != c).float(),
                            axes=axes)
                        fn_hard[:, c - 1] = sum_tensor(
                            (output_seg != c).float() * (target == c).float(),
                            axes=axes)

                    tp_hard = tp_hard.sum(0, keepdim=False)[None]
                    fp_hard = fp_hard.sum(0, keepdim=False)[None]
                    fn_hard = fn_hard.sum(0, keepdim=False)[None]

                    ret = *ret, tp_hard, fp_hard, fn_hard
            return ret