Example 1
    def run_iteration(self, data_generator, do_backprop=True, run_online_evaluation=False):
        """
        gradient clipping improves training stability

        :param data_generator:
        :param do_backprop:
        :param run_online_evaluation:
        :return:
        """
        data_dict = next(data_generator)
        data = data_dict['data']
        target = data_dict['target']

        data = maybe_to_torch(data)
        target = maybe_to_torch(target)

        if torch.cuda.is_available():
            data = to_cuda(data)
            target = to_cuda(target)

        self.optimizer.zero_grad()

        if self.fp16:
            with autocast():
                output, conv, conv1x1 = self.network(data)
                del data
                
                # l = self.loss(output, target)
                loss_weights = [0.7, 0.2, 0.1]
                l = loss_weights[0] * self.loss(output, target)
                # print("output loss", l)
                l += loss_weights[1] * self.loss(conv, target)
                # print("conv loss", l)
                l += loss_weights[2] * self.loss(conv1x1, target)
                # print("conv1x1 loss", l)

            if do_backprop:
                self.amp_grad_scaler.scale(l).backward()
                self.amp_grad_scaler.unscale_(self.optimizer)
                torch.nn.utils.clip_grad_norm_(self.network.parameters(), 12)
                self.amp_grad_scaler.step(self.optimizer)
                self.amp_grad_scaler.update()
        else:
            output = self.network(data)
            del data
            l = self.loss(output, target)

            if do_backprop:
                l.backward()
                torch.nn.utils.clip_grad_norm_(self.network.parameters(), 12)
                self.optimizer.step()

        if run_online_evaluation:
            self.run_online_evaluation(output, target)

        del target

        return l.detach().cpu().numpy()
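
The trainers above all share the same mixed-precision update: forward pass under autocast, scale the loss before backward, unscale the gradients, clip them to a maximum norm of 12, then step the optimizer and update the scaler. Below is a minimal, self-contained sketch of that pattern combined with the weighted auxiliary-output loss from Example 1; the toy network, optimizer settings and the 0.7/0.3 weights are illustrative assumptions, not part of any trainer shown here.

# Minimal sketch of the AMP + gradient clipping + weighted auxiliary loss pattern.
# TinyNet, the SGD settings and the loss weights are illustrative only.
import torch
import torch.nn as nn

class TinyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.backbone = nn.Linear(16, 16)
        self.main_head = nn.Linear(16, 4)
        self.aux_head = nn.Linear(16, 4)

    def forward(self, x):
        h = torch.relu(self.backbone(x))
        return self.main_head(h), self.aux_head(h)

device = "cuda" if torch.cuda.is_available() else "cpu"
net = TinyNet().to(device)
optimizer = torch.optim.SGD(net.parameters(), lr=0.01)
scaler = torch.cuda.amp.GradScaler(enabled=(device == "cuda"))
criterion = nn.CrossEntropyLoss()
loss_weights = [0.7, 0.3]  # main output weighted higher than the auxiliary one

data = torch.randn(8, 16, device=device)
target = torch.randint(0, 4, (8,), device=device)

optimizer.zero_grad()
with torch.cuda.amp.autocast(enabled=(device == "cuda")):
    main_out, aux_out = net(data)
    loss = loss_weights[0] * criterion(main_out, target) \
        + loss_weights[1] * criterion(aux_out, target)

scaler.scale(loss).backward()
scaler.unscale_(optimizer)                            # unscale before clipping
torch.nn.utils.clip_grad_norm_(net.parameters(), 12)  # same max norm as the trainers above
scaler.step(optimizer)
scaler.update()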
Example 2
    def _internal_maybe_mirror_and_pred_3D(self, x: Union[np.ndarray, torch.Tensor], mirror_axes: tuple,
                                           do_mirroring: bool = True,
                                           mult: Union[np.ndarray, torch.Tensor] = None) -> torch.Tensor:
        assert len(x.shape) == 5, 'x must be (b, c, x, y, z)'
        # everything in here takes place on the GPU. If x and mult are not yet on GPU this will be taken care of here
        # we now return a cuda tensor! Not numpy array!

        x = to_cuda(maybe_to_torch(x), gpu_id=self.get_device())
        result_torch = torch.zeros([1, self.num_classes] + list(x.shape[2:]),
                                   dtype=torch.float).cuda(self.get_device(), non_blocking=True)

        if mult is not None:
            mult = to_cuda(maybe_to_torch(mult), gpu_id=self.get_device())

        if do_mirroring:
            mirror_idx = 8
            num_results = 2 ** len(mirror_axes)
        else:
            mirror_idx = 1
            num_results = 1

        for m in range(mirror_idx):
            if m == 0:
                pred = self.inference_apply_nonlin(self(x))
                result_torch += 1 / num_results * pred

            if m == 1 and (2 in mirror_axes):
                pred = self.inference_apply_nonlin(self(torch.flip(x, (4, ))))
                result_torch += 1 / num_results * torch.flip(pred, (4,))

            if m == 2 and (1 in mirror_axes):
                pred = self.inference_apply_nonlin(self(torch.flip(x, (3, ))))
                result_torch += 1 / num_results * torch.flip(pred, (3,))

            if m == 3 and (2 in mirror_axes) and (1 in mirror_axes):
                pred = self.inference_apply_nonlin(self(torch.flip(x, (4, 3))))
                result_torch += 1 / num_results * torch.flip(pred, (4, 3))

            if m == 4 and (0 in mirror_axes):
                pred = self.inference_apply_nonlin(self(torch.flip(x, (2, ))))
                result_torch += 1 / num_results * torch.flip(pred, (2,))

            if m == 5 and (0 in mirror_axes) and (2 in mirror_axes):
                pred = self.inference_apply_nonlin(self(torch.flip(x, (4, 2))))
                result_torch += 1 / num_results * torch.flip(pred, (4, 2))

            if m == 6 and (0 in mirror_axes) and (1 in mirror_axes):
                pred = self.inference_apply_nonlin(self(torch.flip(x, (3, 2))))
                result_torch += 1 / num_results * torch.flip(pred, (3, 2))

            if m == 7 and (0 in mirror_axes) and (1 in mirror_axes) and (2 in mirror_axes):
                pred = self.inference_apply_nonlin(self(torch.flip(x, (4, 3, 2))))
                result_torch += 1 / num_results * torch.flip(pred, (4, 3, 2))

        if mult is not None:
            result_torch[:, :] *= mult

        return result_torch
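
Example 2 unrolls all 2**len(mirror_axes) flip combinations by hand (the eight m == ... branches) and accumulates 1 / num_results of each back-flipped prediction. The sketch below expresses the same test-time mirroring more compactly by iterating over every subset of the spatial axes; predict_fn and the toy input are placeholders and not part of the class above.

# Compact equivalent of the mirroring loop above: flip the input over every
# subset of the requested axes, predict, flip back, and average. predict_fn
# is a placeholder for the network call.
from itertools import chain, combinations

import torch

def mirror_and_average(predict_fn, x, mirror_axes=(0, 1, 2)):
    # map mirror axis indices (0, 1, 2) to the spatial dims (2, 3, 4) of a (b, c, x, y, z) tensor
    dims = [a + 2 for a in mirror_axes]
    subsets = list(chain.from_iterable(combinations(dims, r) for r in range(len(dims) + 1)))
    result = None
    for flip_dims in subsets:
        pred = predict_fn(torch.flip(x, flip_dims)) if flip_dims else predict_fn(x)
        pred = torch.flip(pred, flip_dims) if flip_dims else pred
        result = pred if result is None else result + pred
    return result / len(subsets)

# toy usage: identity "network" on a random 3D patch
x = torch.randn(1, 1, 8, 8, 8)
averaged = mirror_and_average(lambda t: t, x)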
Example 3
    def run_iteration_nas(self,
                          tr_data_generator,
                          val_data_generator,
                          do_backprop=True,
                          run_online_evaluation=False):

        tr_data_dict = next(tr_data_generator)
        tr_data = tr_data_dict['data']
        tr_target = tr_data_dict['target']

        val_data_dict = next(val_data_generator)
        val_data = val_data_dict['data']
        val_target = val_data_dict['target']

        tr_data = maybe_to_torch(tr_data)
        tr_target = maybe_to_torch(tr_target)
        val_data = maybe_to_torch(val_data)
        val_target = maybe_to_torch(val_target)

        if torch.cuda.is_available():
            tr_data = to_cuda(tr_data)
            tr_target = to_cuda(tr_target)
            val_target = to_cuda(val_target)
            val_data = to_cuda(val_data)
            #val_target = Variable(val_target, requires_grad=False).cuda()

        self.architect.step(tr_data,
                            tr_target,
                            val_data,
                            val_target,
                            self.lr,
                            self.optimizer,
                            unrolled=False)
        del val_data
        del val_target

        self.optimizer.zero_grad()

        output = self.network(tr_data)

        del tr_data
        loss = self.loss(output, tr_target)

        if run_online_evaluation:
            self.run_online_evaluation(output, tr_target)
        del tr_target

        if do_backprop:
            if not self.fp16 or amp is None or not torch.cuda.is_available():
                loss.backward()
            else:
                with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                    scaled_loss.backward()
            _ = clip_grad_norm_(self.network.parameters(), 12)
            self.optimizer.step()

        return loss.detach().cpu().numpy()
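
Example 3 interleaves an architecture update on a validation batch (self.architect.step(..., unrolled=False)) with an ordinary weight update on a training batch, in the usual DARTS-style bilevel scheme. The Architect class itself is not shown in these examples, so the sketch below only illustrates what a first-order (unrolled=False) architecture step typically amounts to; the names arch_optimizer and criterion are assumptions.

# Hedged sketch of a first-order architecture step: the architecture
# parameters are optimized on the validation batch, the network weights are
# optimized separately on the training batch (as in Example 3).
import torch

def first_order_architect_step(network, arch_optimizer, criterion, val_data, val_target):
    arch_optimizer.zero_grad()
    val_loss = criterion(network(val_data), val_target)
    val_loss.backward()
    arch_optimizer.step()  # only the architecture parameters are registered with this optimizer
    return val_loss.detach()

# usage sketch (assumed names): arch_params would be the architecture weights of a supernet
# arch_optimizer = torch.optim.Adam(arch_params, lr=3e-4, betas=(0.5, 0.999), weight_decay=1e-3)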
Example 4
    def run_iteration(self, data_generator, do_backprop=True, run_online_evaluation=False):
        """
        gradient clipping improves training stability

        :param data_generator:
        :param do_backprop:
        :param run_online_evaluation:
        :return:
        """
        data_dict = next(data_generator)
        data = data_dict['data']
        target = data_dict['target']

        data = maybe_to_torch(data)
        target = maybe_to_torch(target)

        if torch.cuda.is_available():
            data = to_cuda(data)
            target = to_cuda(target)

        self.optimizer.zero_grad()

        if self.fp16:
            with autocast():
                output = self.network(data)
                """print("!!!!!!!!!!!!!!!!!!!!!!!!")
                print("data",data.shape)

                print("output.shape",output.shape)

                print("target.shape", target[0].shape)"""
                del data
                l = self.loss(output, target)

            if do_backprop:
                self.amp_grad_scaler.scale(l).backward()
                self.amp_grad_scaler.unscale_(self.optimizer)
                torch.nn.utils.clip_grad_norm_(self.network.parameters(), 12)
                self.amp_grad_scaler.step(self.optimizer)
                self.amp_grad_scaler.update()
        else:
            output = self.network(data)
            del data
            l = self.loss(output, target)

            if do_backprop:
                l.backward()
                torch.nn.utils.clip_grad_norm_(self.network.parameters(), 12)
                self.optimizer.step()

        if run_online_evaluation:
            self.run_online_evaluation(output, target)

        del target

        return l.detach().cpu().numpy()
Example 5
    def run_iteration(self,
                      data_generator,
                      do_backprop=True,
                      run_online_evaluation=False):
        data_dict = next(data_generator)
        data = data_dict['data']
        target = data_dict['target']

        data = maybe_to_torch(data)
        target = maybe_to_torch(target)

        if torch.cuda.is_available():
            data = to_cuda(data)
            target = to_cuda(target)

        self.optimizer.zero_grad()

        if self.fp16:
            with autocast():
                ret = self.network(data,
                                   target,
                                   return_hard_tp_fp_fn=run_online_evaluation)
                if run_online_evaluation:
                    ces, tps, fps, fns, tp_hard, fp_hard, fn_hard = ret
                    self.run_online_evaluation(tp_hard, fp_hard, fn_hard)
                else:
                    ces, tps, fps, fns = ret
                del data, target
                l = self.compute_loss(ces, tps, fps, fns)

            if do_backprop:
                self.amp_grad_scaler.scale(l).backward()
                self.amp_grad_scaler.unscale_(self.optimizer)
                torch.nn.utils.clip_grad_norm_(self.network.parameters(), 12)
                self.amp_grad_scaler.step(self.optimizer)
                self.amp_grad_scaler.update()
        else:
            ret = self.network(data,
                               target,
                               return_hard_tp_fp_fn=run_online_evaluation)
            if run_online_evaluation:
                ces, tps, fps, fns, tp_hard, fp_hard, fn_hard = ret
                self.run_online_evaluation(tp_hard, fp_hard, fn_hard)
            else:
                ces, tps, fps, fns = ret
            del data, target
            l = self.compute_loss(ces, tps, fps, fns)

            if do_backprop:
                l.backward()
                torch.nn.utils.clip_grad_norm_(self.network.parameters(), 12)
                self.optimizer.step()

        return l.detach().cpu().numpy()
Example 6
    def _internal_maybe_mirror_and_pred_3D(self, x, num_repeats, mirror_axes, do_mirroring=True, mult=None):
        # everything in here takes place on the GPU. If x and mult are not yet on GPU this will be taken care of here
        # we now return a cuda tensor! Not numpy array!
        with torch.no_grad():
            x = to_cuda(maybe_to_torch(x), gpu_id=self.get_device())
            result_torch = torch.zeros([1, self.num_classes] + list(x.shape[2:]),
                                       dtype=torch.float).cuda(self.get_device(), non_blocking=True)
            if mult is not None:
                mult = to_cuda(maybe_to_torch(mult), gpu_id=self.get_device())

            num_results = num_repeats
            if do_mirroring:
                mirror_idx = 8
                num_results *= 2 ** len(mirror_axes)
            else:
                mirror_idx = 1

            for i in range(num_repeats):
                for m in range(mirror_idx):
                    if m == 0:
                        pred = self.inference_apply_nonlin(self(x))
                        result_torch += 1 / num_results * pred

                    if m == 1 and (2 in mirror_axes):
                        pred = self.inference_apply_nonlin(self(flip(x, 4)))
                        result_torch += 1 / num_results * flip(pred, 4)

                    if m == 2 and (1 in mirror_axes):
                        pred = self.inference_apply_nonlin(self(flip(x, 3)))
                        result_torch += 1 / num_results * flip(pred, 3)

                    if m == 3 and (2 in mirror_axes) and (1 in mirror_axes):
                        pred = self.inference_apply_nonlin(self(flip(flip(x, 4), 3)))
                        result_torch += 1 / num_results * flip(flip(pred, 4), 3)

                    if m == 4 and (0 in mirror_axes):
                        pred = self.inference_apply_nonlin(self(flip(x, 2)))
                        result_torch += 1 / num_results * flip(pred, 2)

                    if m == 5 and (0 in mirror_axes) and (2 in mirror_axes):
                        pred = self.inference_apply_nonlin(self(flip(flip(x, 4), 2)))
                        result_torch += 1 / num_results * flip(flip(pred, 4), 2)

                    if m == 6 and (0 in mirror_axes) and (1 in mirror_axes):
                        pred = self.inference_apply_nonlin(self(flip(flip(x, 3), 2)))
                        result_torch += 1 / num_results * flip(flip(pred, 3), 2)

                    if m == 7 and (0 in mirror_axes) and (1 in mirror_axes) and (2 in mirror_axes):
                        pred = self.inference_apply_nonlin(self(flip(flip(flip(x, 3), 2), 4)))
                        result_torch += 1 / num_results * flip(flip(flip(pred, 3), 2), 4)

            if mult is not None:
                result_torch[:, :] *= mult

        return result_torch
Example 7
    def run_iteration(self,
                      data_generator,
                      do_backprop=True,
                      run_online_evaluation=False):
        """
        gradient clipping improves training stability

        :param data_generator:
        :param do_backprop:
        :param run_online_evaluation:
        :return:
        """
        data_dict = next(data_generator)
        # length = get_length(data_generator)
        data = data_dict['data']
        target = data_dict['target']

        # self.x_tags = ['liver','spleen','pancreas','rightkidney','leftkidney'] #test-mk
        if self.x_tags is None:
            self.x_tags = [tag.lower() for tag in data_dict['tags']]
        y_tags = [tag.lower() for tag in data_dict['tags']]
        # print("------------------x_tags:",self.x_tags)
        # print("------------------y_tags:",y_tags)
        data = maybe_to_torch(data)
        target = maybe_to_torch(target)

        if torch.cuda.is_available():
            data = to_cuda(data)
            target = to_cuda(target)

        self.optimizer.zero_grad()

        output = self.network(data)

        del data
        # loss = self.loss(output, target,self.x_tags,y_tags,need_updateGT=need_updateGT)
        loss = self.loss(output, target, self.x_tags, y_tags)
        if run_online_evaluation:
            self.run_online_evaluation(output, target)
        del target

        if do_backprop:
            if not self.fp16 or amp is None or not torch.cuda.is_available():
                loss.backward()
            else:
                with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                    scaled_loss.backward()
            _ = clip_grad_norm_(self.network.parameters(), 12)
            self.optimizer.step()

        return loss.detach().cpu().numpy()
Example 8
    def run_iteration(self,
                      data_generator,
                      do_backprop=True,
                      run_online_evaluation=False):
        data_dict = next(data_generator)
        data = data_dict['data']
        target = data_dict['target']

        data = maybe_to_torch(data)
        target = maybe_to_torch(target)

        if torch.cuda.is_available():
            data = to_cuda(data, gpu_id=None)
            target = to_cuda(target, gpu_id=None)

        self.optimizer.zero_grad()

        if self.fp16:
            with autocast():
                output = self.network(data)
                del data
                l = self.compute_loss(output, target)

            if do_backprop:
                self.amp_grad_scaler.scale(l).backward()
                self.amp_grad_scaler.unscale_(self.optimizer)
                torch.nn.utils.clip_grad_norm_(self.network.parameters(), 12)
                self.amp_grad_scaler.step(self.optimizer)
                self.amp_grad_scaler.update()
        else:
            output = self.network(data)
            del data
            l = self.compute_loss(output, target)

            if do_backprop:
                l.backward()
                torch.nn.utils.clip_grad_norm_(self.network.parameters(), 12)
                self.optimizer.step()

        if run_online_evaluation:
            self.run_online_evaluation(output, target)

        del target

        return l.detach().cpu().numpy()
Example 9
    def run_iteration(self,
                      data_generator,
                      do_backprop=True,
                      run_online_evaluation=False):
        """
        gradient clipping improves training stability

        :param data_generator:
        :param do_backprop:
        :param run_online_evaluation:
        :return:
        """
        data_dict = next(data_generator)
        data = data_dict['data']
        target = data_dict['target']

        data = maybe_to_torch(data)
        target = maybe_to_torch(target)

        if torch.cuda.is_available():
            data = to_cuda(data)
            target = to_cuda(target)

        self.optimizer.zero_grad()

        output = self.network(data)

        del data
        loss = self.loss(output, target)

        if run_online_evaluation:
            self.run_online_evaluation(output, target)
        del target

        if do_backprop:
            if not self.fp16 or amp is None or not torch.cuda.is_available():
                loss.backward()
            else:
                with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                    scaled_loss.backward()
            _ = clip_grad_norm_(self.network.parameters(), 12)
            self.optimizer.step()

        return loss.detach().cpu().numpy()
Example 10
    def run_iteration(self,
                      data_generator,
                      do_backprop=True,
                      run_online_evaluation=False):
        data_dict = next(data_generator)
        data = data_dict['data']
        target = data_dict['target']

        data = maybe_to_torch(data)
        target = maybe_to_torch(target)

        if torch.cuda.is_available():
            data = to_cuda(data)
            target = to_cuda(target)

        self.optimizer.zero_grad()

        if self.fp16:
            with autocast():
                output = self.network(data)
                del data
                # lossparts added by Camila zxc
                l, lossparts = self.loss(output, target)

            if do_backprop:
                self.amp_grad_scaler.scale(l).backward()
                self.amp_grad_scaler.step(self.optimizer)
                self.amp_grad_scaler.update()
        else:
            output = self.network(data)
            del data
            # lossparts added by Camila zxc
            l, lossparts = self.loss(output, target)

            if do_backprop:
                l.backward()
                self.optimizer.step()

        if run_online_evaluation:
            self.run_online_evaluation(output, target)

        del target

        return l.detach().cpu().numpy(), lossparts
Example 11
    def run_iteration(self, data_generator, do_backprop=True, run_online_evaluation=False):
        raise NotImplementedError("this class has not been changed to work with pytorch amp yet!")
        data_dict = next(data_generator)
        data = data_dict['data']
        target = data_dict['target']

        data = maybe_to_torch(data)
        target = maybe_to_torch(target)

        if torch.cuda.is_available():
            data = to_cuda(data, gpu_id=None)
            target = to_cuda(target, gpu_id=None)

        self.optimizer.zero_grad()

        output = self.network(data)
        del data

        total_loss = None

        for i in range(len(output)):
            # Starting here it gets spicy!
            axes = tuple(range(2, len(output[i].size())))

            # network does not apply a final nonlinearity; apply sigmoid before computing the dice terms
            output_sigmoid = torch.sigmoid(output[i])

            # get the tp, fp and fn terms we need
            tp, fp, fn, _ = get_tp_fp_fn_tn(output_sigmoid, target[i], axes, mask=None)
            # for dice, compute nominator and denominator so that we have to accumulate only 2 instead of 3 variables
            # do_bg=False in nnUNetTrainer -> [:, 1:]
            nominator = 2 * tp[:, 1:]
            denominator = 2 * tp[:, 1:] + fp[:, 1:] + fn[:, 1:]

            if self.batch_dice:
                # for DDP we need to gather all nominator and denominator terms from all GPUS to do proper batch dice
                nominator = awesome_allgather_function.apply(nominator)
                denominator = awesome_allgather_function.apply(denominator)
                nominator = nominator.sum(0)
                denominator = denominator.sum(0)
            else:
                pass

            ce_loss = self.ce_loss(output[i], target[i])

            # we smooth by 1e-5 to penalize false positives if tp is 0
            dice_loss = (- (nominator + 1e-5) / (denominator + 1e-5)).mean()
            if total_loss is None:
                total_loss = self.ds_loss_weights[i] * (ce_loss + dice_loss)
            else:
                total_loss += self.ds_loss_weights[i] * (ce_loss + dice_loss)

        if run_online_evaluation:
            with torch.no_grad():
                output = output[0]
                target = target[0]
                out_sigmoid = torch.sigmoid(output)
                out_sigmoid = (out_sigmoid > 0.5).float()

                if self.threeD:
                    axes = (2, 3, 4)
                else:
                    axes = (2, 3)

                tp, fp, fn, _ = get_tp_fp_fn_tn(out_sigmoid, target, axes=axes)

                tp_hard = awesome_allgather_function.apply(tp)
                fp_hard = awesome_allgather_function.apply(fp)
                fn_hard = awesome_allgather_function.apply(fn)
                # print_if_rank0("after allgather", tp_hard.shape)

                # print_if_rank0("after sum", tp_hard.shape)

                self.run_online_evaluation(tp_hard.detach().cpu().numpy().sum(0),
                                           fp_hard.detach().cpu().numpy().sum(0),
                                           fn_hard.detach().cpu().numpy().sum(0))
        del target

        if do_backprop:
            if not self.fp16 or amp is None or not torch.cuda.is_available():
                total_loss.backward()
            else:
                with amp.scale_loss(total_loss, self.optimizer) as scaled_loss:
                    scaled_loss.backward()
            _ = clip_grad_norm_(self.network.parameters(), 12)
            self.optimizer.step()

        return total_loss.detach().cpu().numpy()
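
Examples 11 to 13 build their dice term from tp/fp/fn counts summed over the spatial axes (and, for batch dice, over the batch or all DDP workers). The sketch below is a simplified, generic version of that bookkeeping and of the negated, smoothed dice loss it produces; it is not the get_tp_fp_fn_tn helper imported by the trainers above.

# Generic sketch of the soft tp/fp/fn counts and the smoothed dice loss used
# in Examples 11-13. Not the library helper; a simplified reimplementation.
import torch

def soft_tp_fp_fn(pred, gt, axes):
    # pred, gt: (b, c, *spatial); pred already passed through sigmoid/softmax,
    # gt one-hot encoded with the same shape
    tp = (pred * gt).sum(dim=axes)
    fp = (pred * (1 - gt)).sum(dim=axes)
    fn = ((1 - pred) * gt).sum(dim=axes)
    return tp, fp, fn

def smoothed_dice_loss(pred, gt, axes=(2, 3, 4), smooth=1e-5, batch_dice=False):
    tp, fp, fn = soft_tp_fp_fn(pred, gt, axes)
    tp, fp, fn = tp[:, 1:], fp[:, 1:], fn[:, 1:]   # do_bg=False: drop the background channel
    if batch_dice:
        tp, fp, fn = tp.sum(0), fp.sum(0), fn.sum(0)
    dice = (2 * tp + smooth) / (2 * tp + fp + fn + smooth)
    return -dice.mean()

# toy usage on a random 3D batch with 3 classes
pred = torch.softmax(torch.randn(2, 3, 8, 8, 8), dim=1)
gt = torch.nn.functional.one_hot(torch.randint(0, 3, (2, 8, 8, 8)), 3).permute(0, 4, 1, 2, 3).float()
loss = smoothed_dice_loss(pred, gt)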
Example 12
    def run_iteration(self,
                      data_generator,
                      do_backprop=True,
                      run_online_evaluation=False):
        data_dict = next(data_generator)
        data = data_dict['data']
        target = data_dict['target']

        data = maybe_to_torch(data)
        target = maybe_to_torch(target)

        data = to_cuda(data, gpu_id=None)
        target = to_cuda(target, gpu_id=None)

        self.optimizer.zero_grad()

        output = self.network(data)
        del data

        total_loss = None
        for i in range(len(output)):
            # Starting here it gets spicy!
            axes = tuple(range(2, len(output[i].size())))

            # network does not do softmax. We need to do softmax for dice
            output_softmax = softmax_helper(output[i])

            # get the tp, fp and fn terms we need
            tp, fp, fn, _ = get_tp_fp_fn_tn(output_softmax,
                                            target[i],
                                            axes,
                                            mask=None)
            # for dice, compute nominator and denominator so that we have to accumulate only 2 instead of 3 variables
            # do_bg=False in nnUNetTrainer -> [:, 1:]
            nominator = 2 * tp[:, 1:]
            denominator = 2 * tp[:, 1:] + fp[:, 1:] + fn[:, 1:]

            if self.batch_dice:
                # for DDP we need to gather all nominator and denominator terms from all GPUS to do proper batch dice
                nominator = awesome_allgather_function.apply(nominator)
                denominator = awesome_allgather_function.apply(denominator)
                nominator = nominator.sum(0)
                denominator = denominator.sum(0)
            else:
                pass

            ce_loss = self.ce_loss(output[i], target[i])

            # we smooth by 1e-5 to penalize false positives if tp is 0
            dice_loss = (-(nominator + 1e-5) / (denominator + 1e-5)).mean()
            if total_loss is None:
                total_loss = self.ds_loss_weights[i] * (ce_loss + dice_loss)
            else:
                total_loss += self.ds_loss_weights[i] * (ce_loss + dice_loss)

        if run_online_evaluation:
            with torch.no_grad():
                num_classes = output[0].shape[1]
                output_seg = output[0].argmax(1)
                target = target[0][:, 0]
                axes = tuple(range(1, len(target.shape)))
                tp_hard = torch.zeros(
                    (target.shape[0],
                     num_classes - 1)).to(output_seg.device.index)
                fp_hard = torch.zeros(
                    (target.shape[0],
                     num_classes - 1)).to(output_seg.device.index)
                fn_hard = torch.zeros(
                    (target.shape[0],
                     num_classes - 1)).to(output_seg.device.index)
                for c in range(1, num_classes):
                    tp_hard[:, c - 1] = sum_tensor(
                        (output_seg == c).float() * (target == c).float(),
                        axes=axes)
                    fp_hard[:, c - 1] = sum_tensor(
                        (output_seg == c).float() * (target != c).float(),
                        axes=axes)
                    fn_hard[:, c - 1] = sum_tensor(
                        (output_seg != c).float() * (target == c).float(),
                        axes=axes)

                # tp_hard, fp_hard, fn_hard = get_tp_fp_fn((output_softmax > (1 / num_classes)).float(), target,
                #                                         axes, None)
                # print_if_rank0("before allgather", tp_hard.shape)
                tp_hard = tp_hard.sum(0, keepdim=False)[None]
                fp_hard = fp_hard.sum(0, keepdim=False)[None]
                fn_hard = fn_hard.sum(0, keepdim=False)[None]

                tp_hard = awesome_allgather_function.apply(tp_hard)
                fp_hard = awesome_allgather_function.apply(fp_hard)
                fn_hard = awesome_allgather_function.apply(fn_hard)
                # print_if_rank0("after allgather", tp_hard.shape)

                # print_if_rank0("after sum", tp_hard.shape)

                self.run_online_evaluation(
                    tp_hard.detach().cpu().numpy().sum(0),
                    fp_hard.detach().cpu().numpy().sum(0),
                    fn_hard.detach().cpu().numpy().sum(0))
        del target

        if do_backprop:
            if not self.fp16 or amp is None:
                total_loss.backward()
            else:
                with amp.scale_loss(total_loss, self.optimizer) as scaled_loss:
                    scaled_loss.backward()
            _ = clip_grad_norm_(self.network.parameters(), 12)
            self.optimizer.step()

        return total_loss.detach().cpu().numpy()
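
The online evaluation in Example 12 counts hard per-class agreement: argmax the logits, then sum voxel-wise true positives, false positives and false negatives per foreground class. The sketch below reproduces that counting generically; sum_tensor and the allgather wrapper used above are deliberately left out.

# Generic sketch of the hard per-class tp/fp/fn counting from Example 12's
# online evaluation. Illustration only; no DDP gathering here.
import torch

def hard_tp_fp_fn(logits, target, num_classes):
    # logits: (b, c, *spatial); target: (b, *spatial) integer labels
    seg = logits.argmax(1)
    axes = tuple(range(1, target.ndim))
    tp = torch.stack([((seg == c) & (target == c)).sum(axes).float() for c in range(1, num_classes)], dim=1)
    fp = torch.stack([((seg == c) & (target != c)).sum(axes).float() for c in range(1, num_classes)], dim=1)
    fn = torch.stack([((seg != c) & (target == c)).sum(axes).float() for c in range(1, num_classes)], dim=1)
    return tp, fp, fn

# toy usage
logits = torch.randn(2, 3, 8, 8, 8)
labels = torch.randint(0, 3, (2, 8, 8, 8))
tp, fp, fn = hard_tp_fp_fn(logits, labels, num_classes=3)
dice = (2 * tp) / (2 * tp + fp + fn + 1e-8)  # per-sample, per-foreground-class dice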
Example 13
    def run_iteration_nas(self, tr_data_generator, val_data_generator, do_backprop=True, run_online_evaluation=False):
        tr_data_dict = next(tr_data_generator)
        tr_data = tr_data_dict['data']
        tr_target = tr_data_dict['target']
        
        val_data_dict = next(val_data_generator)
        val_data = val_data_dict['data']
        val_target = val_data_dict['target']

        tr_data = maybe_to_torch(tr_data)
        tr_target = maybe_to_torch(tr_target)
        val_data = maybe_to_torch(val_data)
        val_target = maybe_to_torch(val_target)
        if torch.cuda.is_available():
            tr_data = to_cuda(tr_data)
            tr_target = to_cuda(tr_target)
            val_target = to_cuda(val_target)
            val_data = to_cuda(val_data)
        self.architect.step(tr_data, tr_target, val_data, val_target, self.lr, self.optimizer, unrolled=False)
        
        del val_data 
        del val_target
        self.optimizer.zero_grad()

        ###############
        ret = self.network(tr_data, tr_target, return_hard_tp_fp_fn=run_online_evaluation)
        if run_online_evaluation:
            ces, tps, fps, fns, tp_hard, fp_hard, fn_hard = ret
            tp_hard = tp_hard.detach().cpu().numpy().mean(0)
            fp_hard = fp_hard.detach().cpu().numpy().mean(0)
            fn_hard = fn_hard.detach().cpu().numpy().mean(0)
            self.online_eval_foreground_dc.append(list((2 * tp_hard) / (2 * tp_hard + fp_hard + fn_hard + 1e-8)))
            self.online_eval_tp.append(list(tp_hard))
            self.online_eval_fp.append(list(fp_hard))
            self.online_eval_fn.append(list(fn_hard))
        else:
            ces, tps, fps, fns = ret

        del tr_data, tr_target

        # we now need to effectively reimplement the loss
        loss = None
        for i in range(len(ces)):
            if not self.dice_do_BG:
                tp = tps[i][:, 1:]
                fp = fps[i][:, 1:]
                fn = fns[i][:, 1:]
            else:
                tp = tps[i]
                fp = fps[i]
                fn = fns[i]

            if self.batch_dice:
                tp = tp.sum(0)
                fp = fp.sum(0)
                fn = fn.sum(0)
            else:
                pass

            nominator = 2 * tp + self.dice_smooth
            denominator = 2 * tp + fp + fn + self.dice_smooth

            dice_loss = (- nominator / denominator).mean()
            if loss is None:
                loss = self.loss_weights[i] * (ces[i].mean() + dice_loss)
            else:
                loss += self.loss_weights[i] * (ces[i].mean() + dice_loss)
        ###########

        if do_backprop:
            if not self.fp16 or amp is None or not torch.cuda.is_available():
                loss.backward()
            else:
                with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                    scaled_loss.backward()
            _ = clip_grad_norm_(self.network.parameters(), 12)
            self.optimizer.step()
        return loss.detach().cpu().numpy()
Example 14
    def run_iteration(self, data_generator, do_backprop=True, run_online_evaluation=False):
        # with torch.autograd.set_detect_anomaly(True):
        """
        gradient clipping improves training stability

        :param data_generator:
        :param do_backprop:
        :param run_online_evaluation:
        :return:
        """
        data_dict = next(data_generator)
        data = data_dict['data']
        target = data_dict['target']

        #fig, ax = plt.subplots(2, 4, figsize=(40, 20), tight_layout=True)
        # for i in range(4):
        #    ax[0, i].imshow(target[0][i, 0].detach().cpu().numpy())
        #    ax[1, i].imshow(target[0][i, 1].detach().cpu().numpy())
        #    ax[1, i].set_title(torch.sum(target[0][i, 1]))
        #fig.savefig('/cluster/husvogt/debug_imgs/target.png')

        data = maybe_to_torch(data)
        target = maybe_to_torch(target)

        if torch.cuda.is_available():
            data = to_cuda(data)
            target = to_cuda(target)

        if not self.fp16:
            raise RuntimeError("this trainer supports only mixed precision (fp16) training")

        for idx, (opt, loss) in enumerate(self.opt_loss):
            opt.zero_grad()

            with autocast():
                output = [self.network(data)]
                if idx == 0:
                    l = loss(output[0][:, :2], target[0][:, :1])
                    self.pickle_losses['l_'].append(l.detach().cpu().numpy())
                elif idx == 1:
                    l = loss(output[0][:, 2:], target[0][:, 1:])
                    self.pickle_losses['l_dm'].append(l.detach().cpu().numpy())

                    for b in range(self.batch_size):
                        sum_p = torch.sum(output[0][b, 2])
                        sum_t = torch.sum(target[0][b, 1])
                        self.pickle_losses['sums_dm'].append((sum_p.detach().cpu().numpy(),
                                                              sum_t.detach().cpu().numpy()))
                elif idx == 2:
                    l = torch.Tensor([0.0]).cuda()
                    for b in range(self.batch_size):
                        sum_p = torch.sum(output[0][b, 2])
                        sum_t = torch.sum(target[0][b, 1])
                        l += torch.square(sum_t - sum_p)
                        self.pickle_losses['sums'].append((sum_p.detach().cpu().numpy(),
                                                           sum_t.detach().cpu().numpy()))
                    self.pickle_losses['l_n'].append(l.detach().cpu().numpy())
                    # if self.epoch < 200:
                    #    continue

            if do_backprop:
                self.amp_grad_scaler.scale(l).backward()
                self.amp_grad_scaler.unscale_(opt)
                torch.nn.utils.clip_grad_norm_(self.network.parameters(), 12)
                self.amp_grad_scaler.step(opt)
                self.amp_grad_scaler.update()

        del data

        if run_online_evaluation:
            self.run_online_evaluation(output, target)

        del target

        return l.detach().cpu().numpy()
Example 15
    def run_iteration(self,
                      data_generator,
                      do_backprop=True,
                      run_online_evaluation=False):
        # with torch.autograd.set_detect_anomaly(True):
        """
        gradient clipping improves training stability

        :param data_generator:
        :param do_backprop:
        :param run_online_evaluation:
        :return:
        """
        data_dict = next(data_generator)
        data = data_dict['data']
        target = data_dict['target']

        data = maybe_to_torch(data)
        target = maybe_to_torch(target)

        if torch.cuda.is_available():
            data = to_cuda(data)
            target = to_cuda(target)

        self.optimizer.zero_grad()

        if self.fp16:
            with autocast():
                output = self.network(data)
                del data
                if not self.deep_supervision:
                    output = [output]
                    l = self.loss(output[0], target[0])
                else:
                    l = self.loss(output, target)

            if do_backprop:
                self.amp_grad_scaler.scale(l).backward()
                self.amp_grad_scaler.unscale_(self.optimizer)
                torch.nn.utils.clip_grad_norm_(self.network.parameters(), 12)
                self.amp_grad_scaler.step(self.optimizer)
                self.amp_grad_scaler.update()
        else:
            output = self.network(data)
            del data
            if not self.deep_supervision:
                output = [output]
                l = self.loss(output[0], target[0])
            else:
                l = self.loss(output, target)

            if do_backprop:
                l.backward()
                torch.nn.utils.clip_grad_norm_(self.network.parameters(), 12)
                self.optimizer.step()

        if run_online_evaluation:
            self.run_online_evaluation(output, target)

        del target
        self.pickle_losses['l_'].append(l.detach().cpu().numpy())

        return l.detach().cpu().numpy()
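
Example 15 switches between a single-output loss and a deep-supervision loss over a list of outputs, and Examples 1, 11 and 12 weight the per-resolution terms with fixed coefficients. The sketch below shows one common way to package that as a loss wrapper; the halving weight scheme is an illustrative assumption, and the actual ds_loss_weights of these trainers are configured elsewhere.

# Hedged sketch of a deep-supervision loss wrapper: one base loss applied at
# every output resolution, combined with decaying weights.
import torch
import torch.nn as nn

class DeepSupervisionLoss(nn.Module):
    def __init__(self, base_loss, weights=None):
        super().__init__()
        self.base_loss = base_loss
        self.weights = weights

    def forward(self, outputs, targets):
        # outputs/targets: lists ordered from full resolution downwards
        if self.weights is None:
            w = torch.tensor([1 / 2 ** i for i in range(len(outputs))])
            w = (w / w.sum()).tolist()
        else:
            w = self.weights
        return sum(wi * self.base_loss(o, t) for wi, o, t in zip(w, outputs, targets))

# toy usage with two output resolutions
loss_fn = DeepSupervisionLoss(nn.MSELoss())
outs = [torch.randn(2, 1, 16, 16), torch.randn(2, 1, 8, 8)]
tgts = [torch.randn(2, 1, 16, 16), torch.randn(2, 1, 8, 8)]
l = loss_fn(outs, tgts)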