Example #1
    def compute_loss(self, output, target):
        total_loss = None
        for i in range(len(output)):
            # Starting here it gets spicy!
            axes = tuple(range(2, len(output[i].size())))

            # network does not do softmax. We need to do softmax for dice
            output_softmax = softmax_helper(output[i])

            # get the tp, fp and fn terms we need
            tp, fp, fn, _ = get_tp_fp_fn_tn(output_softmax, target[i], axes, mask=None)
            # for dice, compute the numerator and denominator so that we only have to accumulate 2 variables instead of 3
            # do_bg=False in nnUNetTrainer -> [:, 1:]
            nominator = 2 * tp[:, 1:]
            denominator = 2 * tp[:, 1:] + fp[:, 1:] + fn[:, 1:]

            if self.batch_dice:
                # for DDP we need to gather the numerator and denominator terms from all GPUs to compute a proper batch dice
                nominator = awesome_allgather_function.apply(nominator)
                denominator = awesome_allgather_function.apply(denominator)
                nominator = nominator.sum(0)
                denominator = denominator.sum(0)

            ce_loss = self.ce_loss(output[i], target[i][:, 0].long())

            # we smooth by 1e-5 to penalize false positives if tp is 0
            dice_loss = (- (nominator + 1e-5) / (denominator + 1e-5)).mean()
            if total_loss is None:
                total_loss = self.ds_loss_weights[i] * (ce_loss + dice_loss)
            else:
                total_loss += self.ds_loss_weights[i] * (ce_loss + dice_loss)
        return total_loss
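Note: `softmax_helper` and `get_tp_fp_fn_tn` are used but not defined in this listing. As orientation only, a minimal sketch of what such a helper could compute for soft (softmax) predictions is given below; the name `soft_tp_fp_fn_tn` and the exact target handling are assumptions, not the project's actual implementation.

    import torch

    def soft_tp_fp_fn_tn(net_output, gt, axes):
        # Hypothetical stand-in for get_tp_fp_fn_tn with soft predictions.
        # net_output: (b, c, x, y[, z]) probabilities, gt: (b, 1, x, y[, z]) integer labels.
        gt_onehot = torch.zeros_like(net_output).scatter_(1, gt.long(), 1)
        tp = (net_output * gt_onehot).sum(dim=axes)
        fp = (net_output * (1 - gt_onehot)).sum(dim=axes)
        fn = ((1 - net_output) * gt_onehot).sum(dim=axes)
        tn = ((1 - net_output) * (1 - gt_onehot)).sum(dim=axes)
        return tp, fp, fn, tn

With `axes = tuple(range(2, ...))` as above, each returned tensor has shape (batch, classes), which is why the background channel can then be dropped with `[:, 1:]`.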
Example #2
    def run_online_evaluation(self, output, target):
        output = output[0]
        target = target[0]
        with torch.no_grad():
            out_sigmoid = torch.sigmoid(output)
            out_sigmoid = (out_sigmoid > 0.5).float()

            if self.threeD:
                axes = (0, 2, 3, 4)
            else:
                axes = (0, 2, 3)

            tp, fp, fn, _ = get_tp_fp_fn_tn(out_sigmoid, target, axes=axes)

            tp_hard = tp.detach().cpu().numpy()
            fp_hard = fp.detach().cpu().numpy()
            fn_hard = fn.detach().cpu().numpy()

            self.online_eval_foreground_dc.append(list((2 * tp_hard) / (2 * tp_hard + fp_hard + fn_hard + 1e-8)))
            self.online_eval_tp.append(list(tp_hard))
            self.online_eval_fp.append(list(fp_hard))
            self.online_eval_fn.append(list(fn_hard))
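The lists filled here hold one tp/fp/fn entry per class and per batch. A minimal sketch of how they could be pooled into per-class Dice scores at the end of an epoch (the function `epoch_dice` is hypothetical, not part of the snippets above):

    import numpy as np

    def epoch_dice(online_eval_tp, online_eval_fp, online_eval_fn, eps=1e-8):
        # Pool the per-batch counts over the epoch, then compute one Dice value per class.
        tp = np.sum(online_eval_tp, axis=0)
        fp = np.sum(online_eval_fp, axis=0)
        fn = np.sum(online_eval_fn, axis=0)
        return 2 * tp / (2 * tp + fp + fn + eps)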
Example #3
    def run_iteration(self, data_generator, do_backprop=True, run_online_evaluation=False):
        raise NotImplementedError("this class has not been changed to work with pytorch amp yet!")
        data_dict = next(data_generator)
        data = data_dict['data']
        target = data_dict['target']

        data = maybe_to_torch(data)
        target = maybe_to_torch(target)

        if torch.cuda.is_available():
            data = to_cuda(data, gpu_id=None)
            target = to_cuda(target, gpu_id=None)

        self.optimizer.zero_grad()

        output = self.network(data)
        del data

        total_loss = None

        for i in range(len(output)):
            # Starting here it gets spicy!
            axes = tuple(range(2, len(output[i].size())))

            # this variant does not apply a nonlinearity in the network; we need sigmoid probabilities for the dice term
            output_sigmoid = torch.sigmoid(output[i])

            # get the tp, fp and fn terms we need
            tp, fp, fn, _ = get_tp_fp_fn_tn(output_sigmoid, target[i], axes, mask=None)
            # for dice, compute the numerator and denominator so that we only have to accumulate 2 variables instead of 3
            # do_bg=False in nnUNetTrainer -> [:, 1:]
            nominator = 2 * tp[:, 1:]
            denominator = 2 * tp[:, 1:] + fp[:, 1:] + fn[:, 1:]

            if self.batch_dice:
                # for DDP we need to gather the numerator and denominator terms from all GPUs to compute a proper batch dice
                nominator = awesome_allgather_function.apply(nominator)
                denominator = awesome_allgather_function.apply(denominator)
                nominator = nominator.sum(0)
                denominator = denominator.sum(0)

            ce_loss = self.ce_loss(output[i], target[i])

            # we smooth by 1e-5 to penalize false positives if tp is 0
            dice_loss = (- (nominator + 1e-5) / (denominator + 1e-5)).mean()
            if total_loss is None:
                total_loss = self.ds_loss_weights[i] * (ce_loss + dice_loss)
            else:
                total_loss += self.ds_loss_weights[i] * (ce_loss + dice_loss)

        if run_online_evaluation:
            with torch.no_grad():
                output = output[0]
                target = target[0]
                out_sigmoid = torch.sigmoid(output)
                out_sigmoid = (out_sigmoid > 0.5).float()

                if self.threeD:
                    axes = (2, 3, 4)
                else:
                    axes = (2, 3)

                tp, fp, fn, _ = get_tp_fp_fn_tn(out_sigmoid, target, axes=axes)

                tp_hard = awesome_allgather_function.apply(tp)
                fp_hard = awesome_allgather_function.apply(fp)
                fn_hard = awesome_allgather_function.apply(fn)
                # print_if_rank0("after allgather", tp_hard.shape)

                # print_if_rank0("after sum", tp_hard.shape)

                self.run_online_evaluation(tp_hard.detach().cpu().numpy().sum(0),
                                           fp_hard.detach().cpu().numpy().sum(0),
                                           fn_hard.detach().cpu().numpy().sum(0))
        del target

        if do_backprop:
            if not self.fp16 or amp is None or not torch.cuda.is_available():
                total_loss.backward()
            else:
                with amp.scale_loss(total_loss, self.optimizer) as scaled_loss:
                    scaled_loss.backward()
            _ = clip_grad_norm_(self.network.parameters(), 12)
            self.optimizer.step()

        return total_loss.detach().cpu().numpy()
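`awesome_allgather_function` is only used, never defined, in these examples. Judging from its use, it is a differentiable all_gather along the first dimension; a minimal sketch of such a function is below (an assumption about its behaviour, not the project's actual implementation):

    import torch
    import torch.distributed as dist

    class AllGatherGrad(torch.autograd.Function):
        # Hypothetical differentiable all_gather: concatenate tensors from all ranks
        # along dim 0 in forward, hand each rank its own gradient slice in backward.
        @staticmethod
        def forward(ctx, tensor):
            gathered = [torch.zeros_like(tensor) for _ in range(dist.get_world_size())]
            dist.all_gather(gathered, tensor)
            ctx.n_local = tensor.shape[0]
            return torch.cat(gathered, dim=0)

        @staticmethod
        def backward(ctx, grad_output):
            rank = dist.get_rank()
            return grad_output[rank * ctx.n_local:(rank + 1) * ctx.n_local]

Calling `AllGatherGrad.apply(...)` would then mirror the `awesome_allgather_function.apply(...)` calls above.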
Example #4
    def run_iteration(self,
                      data_generator,
                      do_backprop=True,
                      run_online_evaluation=False):
        data_dict = next(data_generator)
        data = data_dict['data']
        target = data_dict['target']

        data = maybe_to_torch(data)
        target = maybe_to_torch(target)

        data = to_cuda(data, gpu_id=None)
        target = to_cuda(target, gpu_id=None)

        self.optimizer.zero_grad()

        output = self.network(data)
        del data

        total_loss = None
        for i in range(len(output)):
            # Starting here it gets spicy!
            axes = tuple(range(2, len(output[i].size())))

            # network does not do softmax. We need to do softmax for dice
            output_softmax = softmax_helper(output[i])

            # get the tp, fp and fn terms we need
            tp, fp, fn, _ = get_tp_fp_fn_tn(output_softmax,
                                            target[i],
                                            axes,
                                            mask=None)
            # for dice, compute the numerator and denominator so that we only have to accumulate 2 variables instead of 3
            # do_bg=False in nnUNetTrainer -> [:, 1:]
            nominator = 2 * tp[:, 1:]
            denominator = 2 * tp[:, 1:] + fp[:, 1:] + fn[:, 1:]

            if self.batch_dice:
                # for DDP we need to gather the numerator and denominator terms from all GPUs to compute a proper batch dice
                nominator = awesome_allgather_function.apply(nominator)
                denominator = awesome_allgather_function.apply(denominator)
                nominator = nominator.sum(0)
                denominator = denominator.sum(0)

            ce_loss = self.ce_loss(output[i], target[i])

            # we smooth by 1e-5 to penalize false positives if tp is 0
            dice_loss = (-(nominator + 1e-5) / (denominator + 1e-5)).mean()
            if total_loss is None:
                total_loss = self.ds_loss_weights[i] * (ce_loss + dice_loss)
            else:
                total_loss += self.ds_loss_weights[i] * (ce_loss + dice_loss)

        if run_online_evaluation:
            with torch.no_grad():
                num_classes = output[0].shape[1]
                output_seg = output[0].argmax(1)
                target = target[0][:, 0]
                axes = tuple(range(1, len(target.shape)))
                tp_hard = torch.zeros(
                    (target.shape[0],
                     num_classes - 1)).to(output_seg.device.index)
                fp_hard = torch.zeros(
                    (target.shape[0],
                     num_classes - 1)).to(output_seg.device.index)
                fn_hard = torch.zeros(
                    (target.shape[0],
                     num_classes - 1)).to(output_seg.device.index)
                for c in range(1, num_classes):
                    tp_hard[:, c - 1] = sum_tensor(
                        (output_seg == c).float() * (target == c).float(),
                        axes=axes)
                    fp_hard[:, c - 1] = sum_tensor(
                        (output_seg == c).float() * (target != c).float(),
                        axes=axes)
                    fn_hard[:, c - 1] = sum_tensor(
                        (output_seg != c).float() * (target == c).float(),
                        axes=axes)

                # tp_hard, fp_hard, fn_hard = get_tp_fp_fn((output_softmax > (1 / num_classes)).float(), target,
                #                                         axes, None)
                # print_if_rank0("before allgather", tp_hard.shape)
                tp_hard = tp_hard.sum(0, keepdim=False)[None]
                fp_hard = fp_hard.sum(0, keepdim=False)[None]
                fn_hard = fn_hard.sum(0, keepdim=False)[None]

                tp_hard = awesome_allgather_function.apply(tp_hard)
                fp_hard = awesome_allgather_function.apply(fp_hard)
                fn_hard = awesome_allgather_function.apply(fn_hard)
                # print_if_rank0("after allgather", tp_hard.shape)

                # print_if_rank0("after sum", tp_hard.shape)

                self.run_online_evaluation(
                    tp_hard.detach().cpu().numpy().sum(0),
                    fp_hard.detach().cpu().numpy().sum(0),
                    fn_hard.detach().cpu().numpy().sum(0))
        del target

        if do_backprop:
            if not self.fp16 or amp is None:
                total_loss.backward()
            else:
                with amp.scale_loss(total_loss, self.optimizer) as scaled_loss:
                    scaled_loss.backward()
            _ = clip_grad_norm_(self.network.parameters(), 12)
            self.optimizer.step()

        return total_loss.detach().cpu().numpy()
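`sum_tensor`, used in the hard tp/fp/fn loop above, simply reduces a tensor over the given axes. A minimal stand-in (the name `sum_over_axes` is hypothetical):

    def sum_over_axes(inp, axes, keepdim=False):
        # Reduce over several axes, highest axis first so the remaining indices stay valid.
        for ax in sorted(set(int(a) for a in axes), reverse=True):
            inp = inp.sum(dim=ax, keepdim=keepdim)
        return inp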
Example #5
    def forward(self, x, y=None, return_hard_tp_fp_fn=False):
        res = super(Generic_UNet_DP,
                    self).forward(x)  # regular Generic_UNet forward pass

        if y is None:
            return res
        else:
            # compute ce loss
            if self._deep_supervision and self.do_ds:
                ce_losses = [self.ce_loss(res[0], y[0]).unsqueeze(0)]
                tps = []
                fps = []
                fns = []

                res_softmax = softmax_helper(res[0])
                tp, fp, fn, _ = get_tp_fp_fn_tn(res_softmax, y[0])
                tps.append(tp)
                fps.append(fp)
                fns.append(fn)
                for i in range(1, len(y)):
                    ce_losses.append(self.ce_loss(res[i], y[i]).unsqueeze(0))
                    res_softmax = softmax_helper(res[i])
                    tp, fp, fn, _ = get_tp_fp_fn_tn(res_softmax, y[i])
                    tps.append(tp)
                    fps.append(fp)
                    fns.append(fn)
                ret = ce_losses, tps, fps, fns
            else:
                ce_loss = self.ce_loss(res, y).unsqueeze(0)

                # tp, fp and fn need the output to be softmax probabilities
                res_softmax = softmax_helper(res)

                tp, fp, fn, _ = get_tp_fp_fn_tn(res_softmax, y)

                ret = ce_loss, tp, fp, fn

            if return_hard_tp_fp_fn:
                if self._deep_supervision and self.do_ds:
                    output = res[0]
                    target = y[0]
                else:
                    target = y
                    output = res

                with torch.no_grad():
                    num_classes = output.shape[1]
                    output_softmax = softmax_helper(output)
                    output_seg = output_softmax.argmax(1)
                    target = target[:, 0]
                    axes = tuple(range(1, len(target.shape)))
                    tp_hard = torch.zeros(
                        (target.shape[0],
                         num_classes - 1)).to(output_seg.device.index)
                    fp_hard = torch.zeros(
                        (target.shape[0],
                         num_classes - 1)).to(output_seg.device.index)
                    fn_hard = torch.zeros(
                        (target.shape[0],
                         num_classes - 1)).to(output_seg.device.index)
                    for c in range(1, num_classes):
                        tp_hard[:, c - 1] = sum_tensor(
                            (output_seg == c).float() * (target == c).float(),
                            axes=axes)
                        fp_hard[:, c - 1] = sum_tensor(
                            (output_seg == c).float() * (target != c).float(),
                            axes=axes)
                        fn_hard[:, c - 1] = sum_tensor(
                            (output_seg != c).float() * (target == c).float(),
                            axes=axes)

                    tp_hard = tp_hard.sum(0, keepdim=False)[None]
                    fp_hard = fp_hard.sum(0, keepdim=False)[None]
                    fn_hard = fn_hard.sum(0, keepdim=False)[None]

                    ret = *ret, tp_hard, fp_hard, fn_hard
            return ret
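Under DataParallel each replica runs this forward pass, so the caller receives the CE losses and tp/fp/fn counts with an extra leading dimension and still has to reduce them to a single scalar loss. A rough sketch for the non-deep-supervision case, assuming the background channel is excluded as in the other examples (an illustration, not the project's actual trainer code):

    def combine_dp_outputs(ce_loss, tp, fp, fn, smooth=1e-5):
        # ce_loss: per-GPU CE values; tp/fp/fn: stacked counts of shape (n, num_classes).
        # Batch dice over the pooled counts, background (channel 0) dropped.
        numerator = 2 * tp[:, 1:].sum(0)
        denominator = (2 * tp[:, 1:] + fp[:, 1:] + fn[:, 1:]).sum(0)
        dice_loss = (-(numerator + smooth) / (denominator + smooth)).mean()
        return ce_loss.mean() + dice_loss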