def training(self):
    """Run iter_valid training iterations and return the average loss and
    class-wise Dice scores as a dictionary of scalars."""
    class_num  = self.config['network']['class_num']
    iter_valid = self.config['training']['iter_valid']
    train_loss = 0.0
    train_dice_list = []
    self.net.train()
    for it in range(iter_valid):
        try:
            data = next(self.trainIter)
        except StopIteration:
            # restart the data loader when the training set is exhausted
            self.trainIter = iter(self.train_loader)
            data = next(self.trainIter)
        # get the inputs
        inputs      = self.convert_tensor_type(data['image'])
        labels_prob = self.convert_tensor_type(data['label_prob'])

        # # Optional debug block: save each image, label (and pixel weight,
        # # if a pix_w tensor is available) to disk for visual inspection.
        # for i in range(inputs.shape[0]):
        #     image_i = inputs[i][0]
        #     label_i = labels_prob[i][1]
        #     pixw_i  = pix_w[i][0]
        #     print(image_i.shape, label_i.shape, pixw_i.shape)
        #     image_name  = "temp/image_{0:}_{1:}.nii.gz".format(it, i)
        #     label_name  = "temp/label_{0:}_{1:}.nii.gz".format(it, i)
        #     weight_name = "temp/weight_{0:}_{1:}.nii.gz".format(it, i)
        #     save_nd_array_as_image(image_i, image_name, reference_name = None)
        #     save_nd_array_as_image(label_i, label_name, reference_name = None)
        #     save_nd_array_as_image(pixw_i, weight_name, reference_name = None)
        # continue

        inputs, labels_prob = inputs.to(self.device), labels_prob.to(self.device)

        # zero the parameter gradients
        self.optimizer.zero_grad()

        # forward + backward + optimize
        outputs = self.net(inputs)
        loss    = self.get_loss_value(data, inputs, outputs, labels_prob)
        loss.backward()
        self.optimizer.step()
        self.scheduler.step()
        train_loss = train_loss + loss.item()

        # get dice evaluation for each class
        if isinstance(outputs, (tuple, list)):
            outputs = outputs[0]
        outputs_argmax = torch.argmax(outputs, dim = 1, keepdim = True)
        soft_out = get_soft_label(outputs_argmax, class_num, self.tensor_type)
        soft_out, labels_prob = reshape_prediction_and_ground_truth(soft_out, labels_prob)
        dice_list = get_classwise_dice(soft_out, labels_prob)
        train_dice_list.append(dice_list.cpu().numpy())

    train_avg_loss = train_loss / iter_valid
    train_cls_dice = np.asarray(train_dice_list).mean(axis = 0)
    train_avg_dice = train_cls_dice.mean()
    train_scalers  = {'loss': train_avg_loss, 'avg_dice': train_avg_dice,
                      'class_dice': train_cls_dice}
    return train_scalers
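# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the trainer): the Dice bookkeeping above
# assumes that get_soft_label() turns the argmax prediction into a one-hot
# tensor and that get_classwise_dice() returns one Dice score per class on
# tensors reshaped to [N, C].  The helper below shows one way such a
# class-wise Dice could be computed; the function name, the expected [N, C]
# layout and the epsilon value are assumptions of this sketch only, and torch
# is assumed to be imported at module level as in the surrounding code.
def _classwise_dice_sketch(soft_pred, soft_gt, eps = 1e-5):
    """soft_pred, soft_gt: float tensors of shape [N, C] holding one-hot
    (or soft) class probabilities for N voxels and C classes."""
    intersect = torch.sum(soft_pred * soft_gt, dim = 0)
    pred_vol  = torch.sum(soft_pred, dim = 0)
    gt_vol    = torch.sum(soft_gt, dim = 0)
    return (2.0 * intersect + eps) / (pred_vol + gt_vol + eps)
# ---------------------------------------------------------------------------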
def validation(self):
    """Run patch-based inference on the validation set with volume_infer()
    and return the average loss and class-wise Dice scores."""
    class_num           = self.config['network']['class_num']
    mini_batch_size     = self.config['testing']['mini_batch_size']
    mini_patch_inshape  = self.config['testing']['mini_patch_input_shape']
    mini_patch_outshape = self.config['testing']['mini_patch_output_shape']
    mini_patch_stride   = self.config['testing']['mini_patch_stride']
    output_num          = self.config['testing'].get('output_num', 1)
    valid_loss = 0.0
    valid_dice_list = []
    validIter = iter(self.valid_loader)
    with torch.no_grad():
        self.net.eval()
        for data in validIter:
            inputs      = self.convert_tensor_type(data['image'])
            labels_prob = self.convert_tensor_type(data['label_prob'])

            outputs = volume_infer(inputs, self.net, self.device, class_num,
                                   mini_batch_size, mini_patch_inshape,
                                   mini_patch_outshape, mini_patch_stride,
                                   output_num)
            outputs = self.convert_tensor_type(torch.from_numpy(outputs))
            # the tensors are on CPU when calculating the loss for validation data
            loss = self.get_loss_value(data, inputs, outputs, labels_prob)
            valid_loss = valid_loss + loss.item()

            if isinstance(outputs, (tuple, list)):
                outputs = outputs[0]
            outputs_argmax = torch.argmax(outputs, dim = 1, keepdim = True)
            soft_out = get_soft_label(outputs_argmax, class_num, self.tensor_type)
            soft_out, labels_prob = reshape_prediction_and_ground_truth(
                soft_out, labels_prob)
            dice_list = get_classwise_dice(soft_out, labels_prob)
            valid_dice_list.append(dice_list.cpu().numpy())

    valid_avg_loss = valid_loss / len(self.valid_loader)
    valid_cls_dice = np.asarray(valid_dice_list).mean(axis = 0)
    valid_avg_dice = valid_cls_dice.mean()
    valid_scalers  = {'loss': valid_avg_loss, 'avg_dice': valid_avg_dice,
                      'class_dice': valid_cls_dice}
    return valid_scalers
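# ---------------------------------------------------------------------------
# Illustrative example (assumption): the patch-based validation above reads
# the following keys from the 'testing' section of the configuration.  Only
# the key names come from the code above; the concrete values below are made
# up for this sketch.
_example_testing_config = {
    'mini_batch_size'         : 4,
    'mini_patch_input_shape'  : [64, 64, 64],
    'mini_patch_output_shape' : [64, 64, 64],
    'mini_patch_stride'       : [32, 32, 32],
    'output_num'              : 1,   # optional; defaults to 1 when absent
}
# ---------------------------------------------------------------------------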
def validation(self):
    """Alternative implementation of validation() that delegates sliding-window
    inference to an Inferer object built from the 'testing' configuration.
    If both definitions are kept in the same class, this later one overrides
    the patch-based version above."""
    class_num = self.config['network']['class_num']
    infer_cfg = self.config['testing']
    infer_cfg['class_num'] = class_num
    valid_loss_list = []
    valid_dice_list = []
    validIter = iter(self.valid_loader)
    with torch.no_grad():
        self.net.eval()
        infer_obj = Inferer(self.net, infer_cfg)
        for data in validIter:
            inputs      = self.convert_tensor_type(data['image'])
            labels_prob = self.convert_tensor_type(data['label_prob'])
            inputs, labels_prob = inputs.to(self.device), labels_prob.to(self.device)
            batch_n = inputs.shape[0]
            outputs = infer_obj.run(inputs)

            loss = self.get_loss_value(data, inputs, outputs, labels_prob)
            valid_loss_list.append(loss.item())

            if isinstance(outputs, (tuple, list)):
                outputs = outputs[0]
            outputs_argmax = torch.argmax(outputs, dim = 1, keepdim = True)
            soft_out = get_soft_label(outputs_argmax, class_num, self.tensor_type)
            # compute the Dice scores for each item in the batch separately
            for i in range(batch_n):
                soft_out_i, labels_prob_i = reshape_prediction_and_ground_truth(
                    soft_out[i:i+1], labels_prob[i:i+1])
                temp_dice = get_classwise_dice(soft_out_i, labels_prob_i)
                valid_dice_list.append(temp_dice.cpu().numpy())

    valid_avg_loss = np.asarray(valid_loss_list).mean()
    valid_cls_dice = np.asarray(valid_dice_list).mean(axis = 0)
    valid_avg_dice = valid_cls_dice.mean()
    valid_scalers  = {'loss': valid_avg_loss, 'avg_dice': valid_avg_dice,
                      'class_dice': valid_cls_dice}
    return valid_scalers
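# ---------------------------------------------------------------------------
# Illustrative usage sketch (assumption, not part of the library): both
# training() and validation() return a dict with 'loss', 'avg_dice' and
# 'class_dice'.  A surrounding loop could consume them roughly as follows;
# the function name and the max_iter / iter_valid arguments are hypothetical.
def _train_valid_loop_sketch(trainer, max_iter, iter_valid):
    for it in range(0, max_iter, iter_valid):
        train_scalars = trainer.training()
        valid_scalars = trainer.validation()
        print("iteration {0:}: train loss {1:.4f}, train dice {2:.4f}, "
              "valid loss {3:.4f}, valid dice {4:.4f}".format(
              it + iter_valid, train_scalars['loss'], train_scalars['avg_dice'],
              valid_scalars['loss'], valid_scalars['avg_dice']))
# ---------------------------------------------------------------------------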