コード例 #1
0
    def training(self):
        """Run one round of ``iter_valid`` training iterations.

        Batches are pulled from ``self.trainIter``. When that iterator is
        exhausted it is re-created from ``self.train_loader``, so a round may
        span an epoch boundary. Every iteration does the following:
          - a forward pass;
          - a loss computed by ``self.get_loss_value``;
          - back-propagation;
          - an optimizer step;
          - a scheduler step. The scheduler is therefore stepped once per
            iteration, not once per round.

        Returns:
            dict: A dict with three keys:
              - ``'loss'``: the mean of the per-iteration loss values.
              - ``'class_dice'``: the per-iteration arrays returned by
                ``get_classwise_dice``, averaged over iterations (presumably
                one entry per class).
              - ``'avg_dice'``: the mean over all entries of ``'class_dice'``.

        Raises:
            ValueError: If ``config['training']['iter_valid']`` is not
                positive; the averages would otherwise divide by zero.
        """
        class_num  = self.config['network']['class_num']
        iter_valid = self.config['training']['iter_valid']
        if iter_valid <= 0:
            raise ValueError("config['training']['iter_valid'] must be "
                             "positive, got {0:}".format(iter_valid))
        train_loss = 0
        train_dice_list = []
        self.net.train()
        for _ in range(iter_valid):
            # Fetch the next batch, restarting the loader at the end of an epoch.
            try:
                data = next(self.trainIter)
            except StopIteration:
                self.trainIter = iter(self.train_loader)
                data = next(self.trainIter)
            inputs      = self.convert_tensor_type(data['image'])
            labels_prob = self.convert_tensor_type(data['label_prob'])
            inputs, labels_prob = inputs.to(self.device), labels_prob.to(self.device)

            # zero the parameter gradients, then forward + backward + optimize
            self.optimizer.zero_grad()
            outputs = self.net(inputs)
            loss = self.get_loss_value(data, inputs, outputs, labels_prob)
            loss.backward()
            self.optimizer.step()
            self.scheduler.step()
            train_loss = train_loss + loss.item()

            # Dice evaluation of this batch. If the network returns several
            # outputs (tuple/list), only the first one is evaluated.
            if isinstance(outputs, (tuple, list)):
                outputs = outputs[0]
            # argmax over dim 1, assumed to be the class/channel axis.
            outputs_argmax = torch.argmax(outputs, dim = 1, keepdim = True)
            soft_out       = get_soft_label(outputs_argmax, class_num, self.tensor_type)
            soft_out, labels_prob = reshape_prediction_and_ground_truth(soft_out, labels_prob)
            dice_list = get_classwise_dice(soft_out, labels_prob)
            train_dice_list.append(dice_list.cpu().numpy())
        train_avg_loss = train_loss / iter_valid
        train_cls_dice = np.asarray(train_dice_list).mean(axis = 0)
        train_avg_dice = train_cls_dice.mean()

        train_scalers = {'loss': train_avg_loss, 'avg_dice': train_avg_dice,
                         'class_dice': train_cls_dice}
        return train_scalers
コード例 #2
0
    def validation(self):
        """Evaluate the network on every batch of ``self.valid_loader``.

        Each batch is inferred by ``volume_infer`` using the mini-batch and
        mini-patch settings from ``config['testing']``. ``volume_infer``
        receives ``self.device`` and returns a NumPy array, which is
        converted back to a tensor. Inputs and labels are not moved to
        ``self.device`` in this method, so the loss and Dice are computed on
        the tensors as returned by ``convert_tensor_type``.

        Returns:
            dict: A dict with three keys:
              - ``'loss'``: the mean loss per validation batch.
              - ``'class_dice'``: the arrays from ``get_classwise_dice``,
                averaged over batches.
              - ``'avg_dice'``: the mean over all entries of ``'class_dice'``.

        Raises:
            ValueError: If ``self.valid_loader`` yields no batches.
        """
        class_num = self.config['network']['class_num']
        test_cfg  = self.config['testing']
        mini_batch_size     = test_cfg['mini_batch_size']
        mini_patch_inshape  = test_cfg['mini_patch_input_shape']
        mini_patch_outshape = test_cfg['mini_patch_output_shape']
        mini_patch_stride   = test_cfg['mini_patch_stride']
        output_num          = test_cfg.get('output_num', 1)
        valid_loss_list = []
        valid_dice_list = []
        with torch.no_grad():
            self.net.eval()
            for data in self.valid_loader:
                inputs = self.convert_tensor_type(data['image'])
                labels_prob = self.convert_tensor_type(data['label_prob'])

                outputs = volume_infer(inputs, self.net, self.device,
                                       class_num, mini_batch_size,
                                       mini_patch_inshape, mini_patch_outshape,
                                       mini_patch_stride, output_num)
                outputs = self.convert_tensor_type(torch.from_numpy(outputs))
                loss = self.get_loss_value(data, inputs, outputs, labels_prob)
                valid_loss_list.append(loss.item())

                if isinstance(outputs, (tuple, list)):
                    outputs = outputs[0]
                outputs_argmax = torch.argmax(outputs, dim=1, keepdim=True)
                soft_out = get_soft_label(outputs_argmax, class_num,
                                          self.tensor_type)
                soft_out, labels_prob = reshape_prediction_and_ground_truth(
                    soft_out, labels_prob)
                dice_list = get_classwise_dice(soft_out, labels_prob)
                valid_dice_list.append(dice_list.cpu().numpy())

        # Average over the batches actually seen rather than len(iterator).
        # A plain iterator has no len(), and an empty loader would divide by zero.
        if not valid_loss_list:
            raise ValueError("validation loader yielded no batches")
        valid_avg_loss = sum(valid_loss_list) / len(valid_loss_list)
        valid_cls_dice = np.asarray(valid_dice_list).mean(axis=0)
        valid_avg_dice = valid_cls_dice.mean()

        valid_scalers = {'loss': valid_avg_loss, 'avg_dice': valid_avg_dice,
                         'class_dice': valid_cls_dice}
        return valid_scalers
コード例 #3
0
    def validation(self):
        """Evaluate the network on ``self.valid_loader`` using an ``Inferer``.

        A single ``Inferer`` is built from a copy of ``config['testing']``
        extended with ``class_num``. Each batch is moved to ``self.device``
        and passed through ``Inferer.run``. The loss is averaged over
        batches, whereas Dice is computed separately for every sample in a
        batch and averaged over samples.

        Returns:
            dict: A dict with three keys:
              - ``'loss'``: the mean loss per validation batch.
              - ``'class_dice'``: the per-sample arrays from
                ``get_classwise_dice``, averaged over all samples.
              - ``'avg_dice'``: the mean over all entries of ``'class_dice'``.

        Raises:
            ValueError: If ``self.valid_loader`` yields no batches.
        """
        class_num = self.config['network']['class_num']
        # Copy, so that adding 'class_num' does not write into the shared
        # self.config['testing'] dict on every validation round.
        infer_cfg = dict(self.config['testing'])
        infer_cfg['class_num'] = class_num

        valid_loss_list = []
        valid_dice_list = []
        with torch.no_grad():
            self.net.eval()
            infer_obj = Inferer(self.net, infer_cfg)
            for data in self.valid_loader:
                inputs = self.convert_tensor_type(data['image'])
                labels_prob = self.convert_tensor_type(data['label_prob'])
                inputs, labels_prob = inputs.to(self.device), labels_prob.to(
                    self.device)
                batch_n = inputs.shape[0]
                outputs = infer_obj.run(inputs)

                # Inputs and labels are on self.device here, not on the CPU.
                # The loss is computed wherever Inferer.run placed its outputs,
                # presumably the same device -- TODO confirm.
                loss = self.get_loss_value(data, inputs, outputs, labels_prob)
                valid_loss_list.append(loss.item())

                if isinstance(outputs, (tuple, list)):
                    outputs = outputs[0]
                outputs_argmax = torch.argmax(outputs, dim=1, keepdim=True)
                soft_out = get_soft_label(outputs_argmax, class_num,
                                          self.tensor_type)
                # Dice is computed per sample, so every sample weighs equally
                # in the final mean regardless of batch composition.
                for i in range(batch_n):
                    soft_out_i, labels_prob_i = reshape_prediction_and_ground_truth(
                        soft_out[i:i+1], labels_prob[i:i+1])
                    temp_dice = get_classwise_dice(soft_out_i, labels_prob_i)
                    valid_dice_list.append(temp_dice.cpu().numpy())

        # Without this check an empty loader would silently yield NaN metrics.
        if not valid_loss_list:
            raise ValueError("validation loader yielded no batches")
        valid_avg_loss = np.asarray(valid_loss_list).mean()
        valid_cls_dice = np.asarray(valid_dice_list).mean(axis=0)
        valid_avg_dice = valid_cls_dice.mean()

        valid_scalers = {'loss': valid_avg_loss, 'avg_dice': valid_avg_dice,
                         'class_dice': valid_cls_dice}
        return valid_scalers