Ejemplo n.º 1
0
    def forward(self, loss_input_dict):
        """Compute a soft Dice loss with optional image, pixel and class weights.

        The per-image weight is multiplied into the per-pixel weight map, the
        combined map is passed to ``get_classwise_dice``, and the resulting
        class-wise Dice scores are averaged using the class weights.

        Args:
            loss_input_dict (dict): Loss inputs. Keys read here:
                'prediction'   -- network output, or a list/tuple whose first
                                  element is used.
                'ground_truth' -- target in the same layout as the prediction.
                'softmax'      -- if truthy, softmax over dim 1 is applied to
                                  the prediction first.
                'image_weight', 'pixel_weight', 'class_weight' -- optional; a
                                  missing key or a ``None`` value means that
                                  weight is not applied.

        Returns:
            Scalar tensor equal to ``1 - weighted average Dice``.
        """
        predict = loss_input_dict['prediction']
        soft_y = loss_input_dict['ground_truth']
        img_w = loss_input_dict.get('image_weight')
        pix_w = loss_input_dict.get('pixel_weight')
        cls_w = loss_input_dict.get('class_weight')
        softmax = loss_input_dict['softmax']

        if isinstance(predict, (list, tuple)):
            predict = predict[0]
        tensor_dim = len(predict.size())
        if softmax:
            predict = nn.Softmax(dim=1)(predict)

        # Combine pixel weight and image weight. The image weight holds one
        # value per sample, so reshape it to (N, 1, ..., 1) with the same rank
        # as the prediction so that it broadcasts over the spatial weight map.
        if img_w is not None:
            img_w = img_w.reshape([-1] + [1] * (tensor_dim - 1))
            if pix_w is None:
                # No pixel map given: spread each image weight over all pixels
                # of that image as a single-channel map (presumably broadcast
                # over classes inside get_classwise_dice -- TODO confirm).
                pix_w = img_w.expand(predict.shape[0], 1, *predict.shape[2:])
            else:
                pix_w = pix_w * img_w
        if pix_w is not None:
            pix_w = reshape_tensor_to_2D(pix_w)

        predict = reshape_tensor_to_2D(predict)
        soft_y = reshape_tensor_to_2D(soft_y)
        dice_score = get_classwise_dice(predict, soft_y, pix_w)

        if cls_w is None:
            average_dice = dice_score.mean()
        else:
            average_dice = (dice_score * cls_w).sum() / cls_w.sum()
        return 1.0 - average_dice
Ejemplo n.º 2
0
    def forward(self, loss_input_dict):
        """Exponential logarithmic loss.

        Mixes a Dice term and a class-weighted cross-entropy term, each built
        from ``(-log x) ** self.gamma``, using ``self.w_dice`` as the weight
        of the Dice term.

        Args:
            loss_input_dict (dict): needs 'prediction', 'ground_truth' and
                'softmax' (apply softmax over dim 1 of the prediction when
                truthy).

        Returns:
            Scalar tensor ``w_dice * dice_term + (1 - w_dice) * ce_term``.
        """
        prob = loss_input_dict['prediction']
        target = loss_input_dict['ground_truth']
        if loss_input_dict['softmax']:
            prob = nn.Softmax(dim=1)(prob)
        prob = reshape_tensor_to_2D(prob)
        target = reshape_tensor_to_2D(target)

        # Dice term: map each score into [0.01, 0.99] so the log stays finite.
        dice = 0.01 + get_classwise_dice(prob, target) * 0.98
        dice_term = torch.mean(torch.pow(-torch.log(dice), self.gamma))

        # Cross-entropy term. Per-class weights are (1 / (mean target + 0.1))
        # ** 0.5; probabilities are clipped the same way as the Dice scores.
        class_freq = torch.mean(target, dim=0)
        class_w = torch.pow(1.0 / (class_freq + 0.1), 0.5)
        clipped = 0.01 + prob * 0.98
        weighted_ce = class_w * torch.pow(-torch.log(clipped), self.gamma)
        ce_term = torch.mean(torch.sum(target * weighted_ce, dim=1))

        return dice_term * self.w_dice + ce_term * (1.0 - self.w_dice)
Ejemplo n.º 3
0
    def training(self):
        """Run ``iter_valid`` training iterations and return summary scalars.

        Batches are drawn from ``self.trainIter``. When it is exhausted, a new
        iterator over ``self.train_loader`` is created. Each iteration does
        forward, backward, ``optimizer.step()`` and ``scheduler.step()``, and
        records the class-wise Dice of the arg-max prediction.

        Returns:
            dict with 'loss' (mean loss over the iterations), 'avg_dice' (mean
            of the class-wise Dice) and 'class_dice' (class-wise Dice averaged
            over the iterations).
        """
        class_num = self.config['network']['class_num']
        iter_valid = self.config['training']['iter_valid']
        self.net.train()
        loss_sum = 0
        batch_dices = []
        for _ in range(iter_valid):
            try:
                batch = next(self.trainIter)
            except StopIteration:
                # Restart from the beginning of the training loader.
                self.trainIter = iter(self.train_loader)
                batch = next(self.trainIter)

            images = self.convert_tensor_type(batch['image']).to(self.device)
            labels = self.convert_tensor_type(batch['label_prob']).to(self.device)

            # forward + backward + optimize
            self.optimizer.zero_grad()
            outputs = self.net(images)
            loss = self.get_loss_value(batch, images, outputs, labels)
            loss.backward()
            self.optimizer.step()
            self.scheduler.step()
            loss_sum += loss.item()

            # Class-wise Dice of the hard (arg-max) prediction.
            if isinstance(outputs, (tuple, list)):
                outputs = outputs[0]
            hard_pred = torch.argmax(outputs, dim=1, keepdim=True)
            pred_soft = get_soft_label(hard_pred, class_num, self.tensor_type)
            pred_soft, labels = reshape_prediction_and_ground_truth(pred_soft, labels)
            batch_dices.append(get_classwise_dice(pred_soft, labels).cpu().numpy())

        avg_loss = loss_sum / iter_valid
        cls_dice = np.asarray(batch_dices).mean(axis=0)
        return {'loss': avg_loss,
                'avg_dice': cls_dice.mean(),
                'class_dice': cls_dice}
Ejemplo n.º 4
0
    def forward(self, loss_input_dict):
        """Soft Dice loss with a ``1 / self.beta`` exponent on the Dice scores.

        Args:
            loss_input_dict (dict): needs 'prediction', 'ground_truth' and
                'softmax' (apply softmax over dim 1 of the prediction when
                truthy).

        Returns:
            Scalar tensor: the mean of ``1 - dice ** (1 / beta)`` over the
            class-wise Dice scores, each first mapped into [0.01, 0.99].
        """
        prob, target = loss_input_dict['prediction'], loss_input_dict['ground_truth']
        if loss_input_dict['softmax']:
            prob = nn.Softmax(dim=1)(prob)
        prob, target = reshape_tensor_to_2D(prob), reshape_tensor_to_2D(target)

        dice = 0.01 + get_classwise_dice(prob, target, None) * 0.98
        per_class_loss = 1.0 - torch.pow(dice, 1.0 / self.beta)
        return torch.mean(per_class_loss)
Ejemplo n.º 5
0
    def validation(self):
        """Evaluate the network on ``self.valid_loader`` using ``volume_infer``.

        Each batch is inferred through ``volume_infer`` with the mini-patch
        settings from ``config['testing']``. The loss is computed with
        ``get_loss_value``, and the class-wise Dice of the arg-max prediction
        is recorded.

        Returns:
            dict with 'loss' (mean loss per batch), 'avg_dice' (mean of the
            class-wise Dice) and 'class_dice' (class-wise Dice averaged over
            the batches).
        """
        class_num = self.config['network']['class_num']
        test_cfg = self.config['testing']
        mini_batch_size = test_cfg['mini_batch_size']
        mini_patch_inshape = test_cfg['mini_patch_input_shape']
        mini_patch_outshape = test_cfg['mini_patch_output_shape']
        mini_patch_stride = test_cfg['mini_patch_stride']
        output_num = test_cfg.get('output_num', 1)

        valid_loss_list = []
        valid_dice_list = []
        with torch.no_grad():
            self.net.eval()
            for data in self.valid_loader:
                inputs = self.convert_tensor_type(data['image'])
                labels_prob = self.convert_tensor_type(data['label_prob'])

                outputs = volume_infer(inputs, self.net, self.device,
                                       class_num, mini_batch_size,
                                       mini_patch_inshape, mini_patch_outshape,
                                       mini_patch_stride, output_num)
                outputs = self.convert_tensor_type(torch.from_numpy(outputs))
                # The tensors are on CPU when calculating loss for validation
                # data: inputs/labels are never moved to self.device here.
                loss = self.get_loss_value(data, inputs, outputs, labels_prob)
                valid_loss_list.append(loss.item())

                if isinstance(outputs, (tuple, list)):
                    outputs = outputs[0]
                outputs_argmax = torch.argmax(outputs, dim=1, keepdim=True)
                soft_out = get_soft_label(outputs_argmax, class_num,
                                          self.tensor_type)
                soft_out, labels_prob = reshape_prediction_and_ground_truth(
                    soft_out, labels_prob)
                dice_list = get_classwise_dice(soft_out, labels_prob)
                valid_dice_list.append(dice_list.cpu().numpy())

        # Average over the batches actually processed instead of calling
        # len() on the loader iterator, which not every iterable supports.
        valid_avg_loss = np.asarray(valid_loss_list).mean()
        valid_cls_dice = np.asarray(valid_dice_list).mean(axis=0)
        valid_avg_dice = valid_cls_dice.mean()

        valid_scalers = {'loss': valid_avg_loss, 'avg_dice': valid_avg_dice,
                         'class_dice': valid_cls_dice}
        return valid_scalers
Ejemplo n.º 6
0
    def validation(self):
        """Evaluate the network on ``self.valid_loader`` using an ``Inferer``.

        The inference config is ``config['testing']`` plus 'class_num'. It is
        built on a copy, so ``self.config`` is not modified. Dice is computed
        per sample, so 'class_dice' is averaged over samples, not batches.

        Returns:
            dict with 'loss' (mean loss per batch), 'avg_dice' (mean of the
            class-wise Dice) and 'class_dice' (class-wise Dice averaged over
            all validation samples).
        """
        class_num = self.config['network']['class_num']
        # Shallow copy: adding 'class_num' must not leak into self.config.
        infer_cfg = dict(self.config['testing'])
        infer_cfg['class_num'] = class_num

        valid_loss_list = []
        valid_dice_list = []
        with torch.no_grad():
            self.net.eval()
            infer_obj = Inferer(self.net, infer_cfg)
            for data in self.valid_loader:
                inputs = self.convert_tensor_type(data['image'])
                labels_prob = self.convert_tensor_type(data['label_prob'])
                inputs = inputs.to(self.device)
                labels_prob = labels_prob.to(self.device)
                batch_n = inputs.shape[0]
                outputs = infer_obj.run(inputs)

                # inputs and labels_prob were moved to self.device above, so
                # the loss is not necessarily computed on CPU here.
                loss = self.get_loss_value(data, inputs, outputs, labels_prob)
                valid_loss_list.append(loss.item())

                if isinstance(outputs, (tuple, list)):
                    outputs = outputs[0]
                outputs_argmax = torch.argmax(outputs, dim=1, keepdim=True)
                soft_out = get_soft_label(outputs_argmax, class_num,
                                          self.tensor_type)
                for i in range(batch_n):
                    soft_out_i, labels_prob_i = reshape_prediction_and_ground_truth(
                        soft_out[i:i + 1], labels_prob[i:i + 1])
                    temp_dice = get_classwise_dice(soft_out_i, labels_prob_i)
                    valid_dice_list.append(temp_dice.cpu().numpy())

        valid_avg_loss = np.asarray(valid_loss_list).mean()
        valid_cls_dice = np.asarray(valid_dice_list).mean(axis=0)
        valid_avg_dice = valid_cls_dice.mean()

        valid_scalers = {'loss': valid_avg_loss, 'avg_dice': valid_avg_dice,
                         'class_dice': valid_cls_dice}
        return valid_scalers