class Solver(object):
    """Train, validate and test a U-Net-family segmentation network.

    All hyper-parameters and paths are read from ``config`` at construction
    time; the network itself is chosen by ``config.model_type`` (see
    :meth:`build_model`).
    """

    def __init__(self, config, train_loader, valid_loader, test_loader):
        # Data loaders
        self.train_loader = train_loader
        self.valid_loader = valid_loader
        self.test_loader = test_loader

        # Models
        self.unet = None
        self.optimizer = None
        self.img_ch = config.img_ch
        self.output_ch = config.output_ch
        self.criterion = nn.CrossEntropyLoss()
        self.augmentation_prob = config.augmentation_prob

        # Hyper-parameters (Adam learning rate and beta coefficients)
        self.lr = config.lr
        self.beta1 = config.beta1
        self.beta2 = config.beta2

        # Training settings
        self.num_epochs = config.num_epochs
        self.num_epochs_decay = config.num_epochs_decay
        self.batch_size = config.batch_size

        # Step size
        self.log_step = config.log_step
        self.val_step = config.val_step

        # Paths
        self.model_path = config.model_path
        self.result_path = config.result_path
        self.mode = config.mode

        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model_type = config.model_type
        self.t = config.t
        self.build_model()

    def build_model(self):
        """Instantiate ``self.unet`` for ``self.model_type`` plus its Adam optimizer.

        Raises:
            ValueError: if ``self.model_type`` is not one of 'U_Net', 'R2U_Net',
                'AttU_Net' or 'R2AttU_Net'.
        """
        if self.model_type == 'U_Net':
            self.unet = U_Net(img_ch=self.img_ch, output_ch=self.output_ch)
        elif self.model_type == 'R2U_Net':
            self.unet = R2U_Net(img_ch=self.img_ch, output_ch=self.output_ch, t=self.t)
        elif self.model_type == 'AttU_Net':
            self.unet = AttU_Net(img_ch=self.img_ch, output_ch=self.output_ch)
        elif self.model_type == 'R2AttU_Net':
            self.unet = R2AttU_Net(img_ch=self.img_ch, output_ch=self.output_ch, t=self.t)
        else:
            # Previously fell through with self.unet unset and failed later with an
            # opaque AttributeError on `.parameters()`.
            raise ValueError('Unknown model_type: %r' % (self.model_type,))

        self.optimizer = optim.Adam(list(self.unet.parameters()), self.lr,
                                    [self.beta1, self.beta2])
        self.unet.to(self.device)

    def print_network(self, model, name):
        """Print ``model``, ``name`` and the total number of parameters."""
        num_params = sum(p.numel() for p in model.parameters())
        print(model)
        print(name)
        print("The number of parameters: {}".format(num_params))

    def to_data(self, x):
        """Return the data of tensor ``x`` moved to the CPU."""
        # `.cpu()` is a no-op for tensors already on the CPU.
        return x.cpu().data

    def update_lr(self, g_lr, d_lr=None):
        """Set the learning rate of every optimizer param group to ``g_lr``.

        ``d_lr`` is kept for signature compatibility; this solver has a single
        optimizer, so it is ignored. (The old body read an undefined ``lr``.)
        """
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = g_lr

    def reset_grad(self):
        """Zero the gradient buffers."""
        self.unet.zero_grad()

    def compute_accuracy(self, SR, GT, threshold=0.5):
        """Return the fraction of elements where ``SR > threshold`` equals ``GT``.

        Both tensors are flattened; the thresholded prediction is cast to GT's
        dtype before the element-wise comparison.
        """
        SR_flat = SR.detach().reshape(-1).cpu()
        GT_flat = GT.detach().reshape(-1).cpu()
        pred = (SR_flat > threshold).to(GT_flat.dtype)
        return (GT_flat == pred).float().mean().item()

    def tensor2img(self, x):
        """Return a 0/255 mask that is 255 where channel 0 of ``x`` exceeds channel 1."""
        img = (x[:, 0, :, :] > x[:, 1, :, :]).float()
        img = img * 255
        return img

    @staticmethod
    def _batch_metrics(SR, GT):
        """Return [acc, SE, SP, PC, F1, JS, DC] for one batch via the get_* helpers."""
        return [get_accuracy(SR, GT), get_sensitivity(SR, GT), get_specificity(SR, GT),
                get_precision(SR, GT), get_F1(SR, GT), get_JS(SR, GT), get_DC(SR, GT)]

    @staticmethod
    def _mean(totals, num_batches):
        """Average running metric sums over the number of batches (zeros if none).

        NOTE(review): this assumes each get_* helper returns one score for the
        whole batch; the previous code divided by the number of images, which
        under-reports every metric by roughly the batch size.
        """
        n = max(num_batches, 1)
        return [t / n for t in totals]

    def _evaluate(self, loader):
        """Return the per-batch mean [acc, SE, SP, PC, F1, JS, DC] on ``loader``.

        Runs ``self.unet`` in eval mode without building autograd graphs.
        """
        self.unet.eval()
        totals = [0.] * 7
        num_batches = 0
        with torch.no_grad():
            for images, GT in loader:
                images = images.to(self.device)
                GT = GT.to(self.device)
                SR = self.unet(images)
                totals = [t + m for t, m in zip(totals, self._batch_metrics(SR, GT))]
                num_batches += 1
        return self._mean(totals, num_batches)

    def train(self):
        """Train ``self.unet``, keep the best validation checkpoint, then test it.

        If a checkpoint for the current hyper-parameters already exists in
        ``self.model_path`` it is loaded and the method returns. Otherwise the
        model is trained for ``self.num_epochs`` epochs, the weights with the
        highest validation ``JS + DC`` are saved, those weights are evaluated on
        ``self.test_loader``, and one row is appended to
        ``<result_path>/result.csv``.
        """
        unet_path = os.path.join(
            self.model_path,
            '%s-%d-%.4f-%d-%.4f.pkl' % (self.model_type, self.num_epochs, self.lr,
                                        self.num_epochs_decay, self.augmentation_prob))

        if os.path.isfile(unet_path):
            # Reuse the pretrained weights instead of training again.
            self.unet.load_state_dict(torch.load(unet_path, map_location=self.device))
            print('%s is Successfully Loaded from %s' % (self.model_type, unet_path))
            return

        lr = self.lr
        # JS + DC is never negative, so starting below 0 guarantees the first epoch
        # writes a checkpoint: the test phase always has a file to load and
        # best_epoch is always defined.
        best_unet_score = -1.
        best_epoch = -1

        for epoch in range(self.num_epochs):
            # ---------------------------- Training ----------------------------
            self.unet.train(True)
            epoch_loss = 0.
            totals = [0.] * 7  # running sums of acc, SE, SP, PC, F1, JS, DC
            num_batches = 0
            for images, GT in self.train_loader:
                images = images.to(self.device)
                GT = GT.to(self.device)  # GT: ground truth
                SR = self.unet(images)   # SR: segmentation result (raw network output)
                loss = self.criterion(SR, GT)
                epoch_loss += loss.item()

                # Backprop + optimize
                self.reset_grad()
                loss.backward()
                self.optimizer.step()

                totals = [t + m for t, m in zip(totals, self._batch_metrics(SR.detach(), GT))]
                num_batches += 1

            acc, SE, SP, PC, F1, JS, DC = self._mean(totals, num_batches)
            print('Epoch [%d/%d], Loss: %.4f, \n[Training] Acc: %.4f, SE: %.4f, SP: %.4f, PC: %.4f, F1: %.4f, JS: %.4f, DC: %.4f' % (
                epoch + 1, self.num_epochs, epoch_loss, acc, SE, SP, PC, F1, JS, DC))

            # Linearly decay the learning rate over the last num_epochs_decay epochs.
            if (epoch + 1) > (self.num_epochs - self.num_epochs_decay):
                lr -= (self.lr / float(self.num_epochs_decay))
                self.update_lr(lr)
                print('Decay learning rate to lr: {}.'.format(lr))

            # --------------------------- Validation ---------------------------
            acc, SE, SP, PC, F1, JS, DC = self._evaluate(self.valid_loader)
            unet_score = JS + DC
            print('[Validation] Acc: %.4f, SE: %.4f, SP: %.4f, PC: %.4f, F1: %.4f, JS: %.4f, DC: %.4f' % (
                acc, SE, SP, PC, F1, JS, DC))

            # Save the best U-Net model
            if unet_score > best_unet_score:
                best_unet_score = unet_score
                best_epoch = epoch
                print('Best %s model score : %.4f' % (self.model_type, best_unet_score))
                torch.save(self.unet.state_dict(), unet_path)

        # ------------------------------- Test -------------------------------
        # Reload the best checkpoint into a freshly built model.
        del self.unet
        self.build_model()
        self.unet.load_state_dict(torch.load(unet_path, map_location=self.device))
        # Evaluate on the held-out test loader (this previously re-used valid_loader).
        acc, SE, SP, PC, F1, JS, DC = self._evaluate(self.test_loader)
        with open(os.path.join(self.result_path, 'result.csv'), 'a',
                  encoding='utf-8', newline='') as f:
            csv.writer(f).writerow([
                self.model_type, acc, SE, SP, PC, F1, JS, DC, self.lr, best_epoch,
                self.num_epochs, self.num_epochs_decay, self.augmentation_prob
            ])
# NOTE(review): the line below is a collapsed fragment of what looks like a
# training entry script, not live code. Because it starts with '#', Python
# reads the whole line as a comment. Its statements use `args`, `device`,
# `logging` and `train_net`, none of which is defined on this line, and it
# ends in a `try:` with no body. Restore the original line breaks and the
# missing tail before re-enabling it. TODO confirm where this fragment belongs.
# - For 2 classes, use n_classes=1 # - For N > 2 classes, use n_classes=N #net = UNet(n_channels=3, n_classes=1, bilinear=True) #net = R2U_Net(n_channels=1, n_classes=1, bilinear=True) net = AttU_Net() ''' logging.info(f'Network:\n' f'\t{net.n_channels} input channels\n' f'\t{net.n_classes} output channels (classes)\n' f'\t{"Bilinear" if net.bilinear else "Transposed conv"} upscaling') ''' if args.load: net.load_state_dict(torch.load(args.load, map_location=device)) logging.info(f'Model loaded from {args.load}') net.to(device=device) # faster convolutions, but more memory # cudnn.benchmark = True try: train_net(net=net, epochs=args.epochs, batch_size=args.batchsize, lr=args.lr, device=device, img_scale=args.scale, val_percent=args.val / 100) except KeyboardInterrupt: torch.save(net.state_dict(), 'INTERRUPTED.pth') logging.info('Saved interrupt') try:
class Solver(object):
    """Train and validate a two-class (softmax) U-Net-family segmentation network.

    Hyper-parameters and paths are read from ``config``; the network is chosen
    by ``config.model_type`` (see :meth:`build_model`).
    """

    def __init__(self, config, train_loader, valid_loader, test_loader):
        # Data loaders
        self.train_loader = train_loader
        self.valid_loader = valid_loader
        self.test_loader = test_loader

        # Models
        self.unet = None
        self.optimizer = None
        self.img_ch = config.img_ch
        self.output_ch = config.output_ch
        self.criterion = dice_loss()
        self.augmentation_prob = config.augmentation_prob

        # Hyper-parameters (Adam learning rate and beta coefficients)
        self.lr = config.lr
        self.beta1 = config.beta1
        self.beta2 = config.beta2

        # Training settings
        self.num_epochs = config.num_epochs
        self.num_epochs_decay = config.num_epochs_decay
        self.batch_size = config.batch_size

        # Step size
        self.log_step = config.log_step
        self.val_step = config.val_step

        # Paths
        self.model_path = config.model_path
        self.result_path = config.result_path
        self.mode = config.mode

        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model_type = config.model_type
        self.t = config.t
        self.build_model()

    def build_model(self):
        """Instantiate ``self.unet`` (3 input channels, 2 classes) and its Adam optimizer.

        Raises:
            ValueError: if ``self.model_type`` is not one of 'U_Net', 'R2U_Net',
                'AttU_Net' or 'R2AttU_Net'.
        """
        # train() applies a softmax over dim=1 and reads class channel 1, so every
        # model needs 2 output channels. The non-U_Net variants used to be built
        # with 1, which made the softmax constant and the channel-1 slice empty.
        if self.model_type == 'U_Net':
            self.unet = U_Net(img_ch=3, output_ch=2)
        elif self.model_type == 'R2U_Net':
            self.unet = R2U_Net(img_ch=3, output_ch=2, t=self.t)
        elif self.model_type == 'AttU_Net':
            self.unet = AttU_Net(img_ch=3, output_ch=2)
        elif self.model_type == 'R2AttU_Net':
            self.unet = R2AttU_Net(img_ch=3, output_ch=2, t=self.t)
        else:
            raise ValueError('Unknown model_type: %r' % (self.model_type,))

        self.optimizer = optim.Adam(list(self.unet.parameters()), self.lr,
                                    [self.beta1, self.beta2])
        self.unet.to(self.device)

    def print_network(self, model, name):
        """Print ``model``, ``name`` and the total number of parameters."""
        num_params = sum(p.numel() for p in model.parameters())
        print(model)
        print(name)
        print("The number of parameters: {}".format(num_params))

    def to_data(self, x):
        """Return the data of tensor ``x`` moved to the CPU."""
        # `.cpu()` is a no-op for tensors already on the CPU.
        return x.cpu().data

    def update_lr(self, g_lr, d_lr=None):
        """Set the learning rate of every optimizer param group to ``g_lr``.

        ``d_lr`` is kept for signature compatibility and ignored. (The old body
        read an undefined ``lr``.)
        """
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = g_lr

    def reset_grad(self):
        """Zero the gradient buffers."""
        self.unet.zero_grad()

    def compute_accuracy(self, SR, GT, threshold=0.5):
        """Return the fraction of elements where ``SR > threshold`` equals ``GT``."""
        SR_flat = SR.detach().reshape(-1).cpu()
        GT_flat = GT.detach().reshape(-1).cpu()
        pred = (SR_flat > threshold).to(GT_flat.dtype)
        return (GT_flat == pred).float().mean().item()

    def tensor2img(self, x):
        """Return a 0/255 mask that is 255 where channel 0 of ``x`` exceeds channel 1."""
        img = (x[:, 0, :, :] > x[:, 1, :, :]).float()
        img = img * 255
        return img

    @staticmethod
    def _batch_metrics(SR_probs, GT):
        """Return [acc, SE, SP, PC, F1, JS, DC] for the class-1 channel of one batch.

        ``SR_probs`` is the 4-D softmax output and ``GT`` the label map with its
        channel axis squeezed away.
        """
        fg = SR_probs[:, 1:2, :, :]
        # Restore GT's channel axis so it has exactly fg's shape; (N,1,H,W) vs
        # (N,H,W) would otherwise broadcast to (N,N,H,W) in element-wise metrics.
        gt = GT.unsqueeze(1)
        return [get_accuracy(fg, gt), get_sensitivity(fg, gt), get_specificity(fg, gt),
                get_precision(fg, gt), get_F1(fg, gt), get_JS(fg, gt), get_DC(fg, gt)]

    @staticmethod
    def _mean(totals, num_batches):
        """Average running sums over the number of batches (zeros if none).

        NOTE(review): this assumes each get_* helper and the criterion return one
        value for the whole batch; the previous code divided by the image count.
        """
        n = max(num_batches, 1)
        return [t / n for t in totals]

    def _evaluate(self, loader):
        """Return the per-batch mean [acc, SE, SP, PC, F1, JS, DC] on ``loader``, without gradients."""
        self.unet.eval()
        totals = [0.] * 7
        num_batches = 0
        with torch.no_grad():
            for images, GT in loader:
                images = images.to(self.device)
                GT = GT.to(self.device).squeeze(1)
                SR_probs = F.softmax(self.unet(images), dim=1)
                totals = [t + m for t, m in zip(totals, self._batch_metrics(SR_probs, GT))]
                num_batches += 1
        return self._mean(totals, num_batches)

    def _save_train_viz(self, images, GT, SR_probs, vis_dir='./train_viz/'):
        """Write one ``viz_img`` overlay JPEG per sample of the batch into ``vis_dir``.

        The directory is wiped and recreated first, so it only holds the most
        recently visualised batch.
        """
        import shutil  # local import: the file's import block is not visible here

        # Replaces os.system('rm -rf ...') / os.system('mkdir ...').
        shutil.rmtree(vis_dir, ignore_errors=True)
        os.makedirs(vis_dir, exist_ok=True)
        images = images.permute(0, 2, 3, 1)  # NCHW -> NHWC for image writing
        for j in range(images.size(0)):
            img = images[j, :, :, :].detach().cpu().numpy().squeeze()
            # Stretch to 0..255 without mutating the source tensor in place, and
            # avoid dividing by zero on a constant image.
            img = img - np.min(img)
            peak = np.max(img)
            if peak > 0:
                img = img / (peak / 255.)
            img = img[:, :, ::-1].astype(np.uint8)  # reverse channel order for cv2
            viz = viz_img(img, GT[j, :, :], SR_probs[j, :, :, :])
            cv2.imwrite(vis_dir + '%d.jpg' % (j), viz)

    def train(self):
        """Train ``self.unet`` and keep the checkpoint with the best validation score.

        If a checkpoint for the current hyper-parameters already exists it is
        loaded instead. Otherwise the model is trained for ``self.num_epochs``
        epochs. Every 50 batches, and on the last batch, overlays of the current
        batch are written to ./train_viz/. The weights with the best validation
        ``JS + DC`` are saved, and at the end they are reloaded into a freshly
        built model in eval mode.
        """
        unet_path = os.path.join(
            self.model_path,
            '%s-%d-%.4f-%d-%.4f.pkl' % (self.model_type, self.num_epochs, self.lr,
                                        self.num_epochs_decay, self.augmentation_prob))

        if os.path.isfile(unet_path):
            self.unet.load_state_dict(torch.load(unet_path, map_location=self.device))
            print('%s is Successfully Loaded from %s' % (self.model_type, unet_path))
            return

        lr = self.lr
        # JS + DC is never negative, so -1 guarantees at least one saved checkpoint
        # for the reload at the end.
        best_unet_score = -1.
        for epoch in range(self.num_epochs):
            # ---------------------------- Training ----------------------------
            self.unet.train(True)
            epoch_loss = 0.
            totals = [0.] * 7  # running sums of acc, SE, SP, PC, F1, JS, DC
            num_batches = 0
            last_index = len(self.train_loader) - 1
            for i, (images, GT) in enumerate(self.train_loader):
                images = images.to(self.device)
                # GT presumably arrives as (N, 1, H, W); drop the channel axis for the loss.
                GT = GT.to(self.device).squeeze(1)

                # SR : Segmentation Result, turned into per-class probabilities.
                SR_probs = F.softmax(self.unet(images), dim=1)
                loss = self.criterion(SR_probs, GT.long())
                epoch_loss += loss.item()

                # Backprop + optimize
                self.reset_grad()
                loss.backward()
                self.optimizer.step()

                totals = [t + m for t, m in
                          zip(totals, self._batch_metrics(SR_probs.detach(), GT))]
                num_batches += 1

                if i % 50 == 0 or i == last_index:
                    self._save_train_viz(images, GT, SR_probs.detach())

            acc, SE, SP, PC, F1, JS, DC = self._mean(totals, num_batches)
            epoch_loss = self._mean([epoch_loss], num_batches)[0]
            print('Epoch [%d/%d], Loss: %.4f, \n[Training] Acc: %.4f, SE: %.4f, SP: %.4f, PC: %.4f, F1: %.4f, JS: %.4f, DC: %.4f' % (
                epoch + 1, self.num_epochs, epoch_loss, acc, SE, SP, PC, F1, JS, DC))

            # Linearly decay the learning rate over the last num_epochs_decay epochs.
            if (epoch + 1) > (self.num_epochs - self.num_epochs_decay):
                lr -= (self.lr / float(self.num_epochs_decay))
                self.update_lr(lr)
                print('Decay learning rate to lr: {}.'.format(lr))

            # --------------------------- Validation ---------------------------
            acc, SE, SP, PC, F1, JS, DC = self._evaluate(self.valid_loader)
            unet_score = JS + DC
            print('[Validation] Acc: %.4f, SE: %.4f, SP: %.4f, PC: %.4f, F1: %.4f, JS: %.4f, DC: %.4f' % (
                acc, SE, SP, PC, F1, JS, DC))

            # Save the best U-Net model
            if unet_score > best_unet_score:
                best_unet_score = unet_score
                print('Best %s model score : %.4f' % (self.model_type, best_unet_score))
                torch.save(self.unet.state_dict(), unet_path)

        # Reload the best checkpoint into a freshly built model, ready for testing.
        del self.unet
        self.build_model()
        self.unet.load_state_dict(torch.load(unet_path, map_location=self.device))
        self.unet.eval()
class Solver(object):
    """Train/validate/test a single-channel (sigmoid) U-Net-family segmenter.

    The training loss is ``BCE + lamda * soft-Dice``. Loss, accuracy and Dice
    curves are plotted with ``hl`` (hiddenlayer) after every epoch.
    """

    def __init__(self, config, train_loader, valid_loader, test_loader):
        # Data loaders
        self.train_loader = train_loader
        self.valid_loader = valid_loader
        self.test_loader = test_loader

        # Models
        self.unet = None
        self.optimizer = None
        self.img_ch = config.img_ch
        self.output_ch = config.output_ch
        self.bce_loss = torch.nn.BCELoss()
        self.augmentation_prob = config.augmentation_prob

        # Hyper-parameters
        self.lr = config.lr
        self.beta1 = config.beta1
        self.beta2 = config.beta2
        self.lamda = config.lamda  # weight of the Dice term in the loss

        # Training settings
        self.num_epochs = config.num_epochs
        self.num_epochs_decay = config.num_epochs_decay
        self.batch_size = config.batch_size
        self.save_model = config.save_model

        # Plots
        self.loss_history = hl.History()
        self.acc_history = hl.History()
        self.dc_history = hl.History()
        self.canvas = hl.Canvas()

        # Step size for plotting
        self.log_step = config.log_step
        self.val_step = config.val_step

        # Paths
        self.model_path = config.model_path
        self.result_path = config.result_path
        self.mode = config.mode

        # Model training properties
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model_type = config.model_type
        self.t = config.t
        # Epoch of the best saved checkpoint; set by train(), read by test().
        self.best_epoch = None
        self.build_model()

    def build_model(self):
        """Instantiate ``self.unet`` (3 input channels, 1 output) and its Adam optimizer.

        Raises:
            ValueError: if ``self.model_type`` is not a supported model name.
        """
        if self.model_type == 'U_Net':
            self.unet = U_Net(img_ch=3, output_ch=1)
        elif self.model_type == 'R2U_Net':
            self.unet = R2U_Net(img_ch=3, output_ch=1, t=self.t)
        elif self.model_type == 'AttU_Net':
            self.unet = AttU_Net(img_ch=3, output_ch=1)
        elif self.model_type == 'R2AttU_Net':
            self.unet = R2AttU_Net(img_ch=3, output_ch=1, t=self.t)
        else:
            raise ValueError('Unknown model_type: %r' % (self.model_type,))

        self.optimizer = optim.Adam(list(self.unet.parameters()), self.lr,
                                    [self.beta1, self.beta2])
        self.unet.to(self.device)

    def print_network(self, model, name):
        """Print ``model``, ``name`` and the total number of parameters."""
        num_params = sum(p.numel() for p in model.parameters())
        print(model)
        print(name)
        print("The number of parameters: {}".format(num_params))

    def dice_loss(self, pred, target, smooth=1.):
        """Return ``1 - (2*sum(pred*target) + smooth) / (sum(pred + target) + smooth)``.

        All elements are pooled, so the inputs may have any (matching) shape. The
        old ``view(32, -1)`` crashed whenever the batch size was not 32, for
        example on a final partial batch.
        """
        pred = pred.reshape(-1)
        target = target.reshape(-1)
        numerator = 2 * torch.sum(pred * target)
        denominator = torch.sum(pred + target)
        return 1 - (numerator + smooth) / (denominator + smooth)

    def _unet_path(self):
        """Return the checkpoint path that encodes the current hyper-parameters."""
        return os.path.join(
            self.model_path,
            '%s-%d-%.4f-%d-%.4f.pkl' % (self.model_type, self.num_epochs, self.lr,
                                        self.num_epochs_decay, self.augmentation_prob))

    def _loss(self, SR_probs, GT):
        """Return BCE + lamda * Dice between per-sample flattened probabilities and GT."""
        SR_flat = SR_probs.view(SR_probs.size(0), -1)
        GT_flat = GT.view(GT.size(0), -1)
        return self.bce_loss(SR_flat, GT_flat) + self.lamda * self.dice_loss(SR_flat, GT_flat)

    @staticmethod
    def _add_metrics(totals, SR, GT):
        """Return ``totals`` plus this batch's [acc, SE, SP, PC, F1, JS, DC]."""
        batch = (get_accuracy(SR, GT), get_sensitivity(SR, GT), get_specificity(SR, GT),
                 get_precision(SR, GT), get_F1(SR, GT), get_JS(SR, GT), get_DC(SR, GT))
        return [t + m for t, m in zip(totals, batch)]

    def _evaluate(self, loader, with_loss=True):
        """Evaluate ``self.unet`` on ``loader`` in eval mode, without gradients.

        Returns:
            ``(metrics, mean_loss, last_batch)``:
            - ``metrics`` is the per-batch mean of [acc, SE, SP, PC, F1, JS, DC];
            - ``mean_loss`` is the per-batch mean of :meth:`_loss`, or 0. when
              ``with_loss`` is False;
            - ``last_batch`` is the final ``(images, SR, GT)`` seen, or None if
              ``loader`` is empty.
        """
        self.unet.eval()
        totals = [0.] * 7
        total_loss = 0.
        num_batches = 0
        last_batch = None
        with torch.no_grad():
            for images, GT in loader:
                images = images.to(self.device)
                GT = GT.to(self.device)
                SR = torch.sigmoid(self.unet(images))
                if with_loss:
                    total_loss += self._loss(SR, GT).item()
                totals = self._add_metrics(totals, SR, GT)
                num_batches += 1
                last_batch = (images, SR, GT)
        n = max(num_batches, 1)
        return [t / n for t in totals], total_loss / n, last_batch

    def train(self):
        """Train ``self.unet``, plotting curves and saving the best validation checkpoint.

        If a checkpoint for the current hyper-parameters already exists it is
        loaded instead. Otherwise, after every epoch:
        - loss/accuracy/Dice curves are drawn on ``self.canvas``;
        - a grid of the last validation batch is saved under ``self.result_path``;
        - when ``self.save_model`` is set, the weights with the best validation
          ``JS + DC`` are written to disk.
        """
        # Debugging (uncomment to run a dummy forward pass)
        # a = torch.zeros((4, 3, 224, 224))
        # self.unet(a.to(self.device))
        unet_path = self._unet_path()

        if os.path.isfile(unet_path):
            # Load the pretrained weights
            self.unet.load_state_dict(torch.load(unet_path, map_location=self.device))
            print('%s is Successfully Loaded from %s' % (self.model_type, unet_path))
            return

        # NOTE: learning-rate decay is disabled in this solver; num_epochs_decay
        # only contributes to the checkpoint file name.
        # JS + DC is never negative, so -1 ensures a checkpoint is saved for test().
        best_unet_score = -1.
        for epoch in range(self.num_epochs):
            # ---------------------------- Training ----------------------------
            self.unet.train(True)
            epoch_loss = 0.
            totals = [0.] * 7
            num_batches = 0  # counted explicitly: the old `length = i` broke on an empty loader
            for images, GT in self.train_loader:
                images = images.to(self.device)
                GT = GT.to(self.device)

                self.optimizer.zero_grad()
                # SR : Segmentation Result as sigmoid probabilities
                SR_probs = torch.sigmoid(self.unet(images))
                loss = self._loss(SR_probs, GT)
                epoch_loss += loss.item()

                # Backprop
                loss.backward()
                self.optimizer.step()

                totals = self._add_metrics(totals, SR_probs.detach(), GT)
                num_batches += 1

            n = max(num_batches, 1)
            train_metrics = [t / n for t in totals]
            train_acc, train_dc = train_metrics[0], train_metrics[6]
            train_loss = epoch_loss / n

            # --------------------------- Validation ---------------------------
            valid_metrics, valid_loss, last_batch = self._evaluate(self.valid_loader)
            valid_acc, JS, valid_dc = valid_metrics[0], valid_metrics[5], valid_metrics[6]
            unet_score = JS + valid_dc

            self.loss_history.log(epoch + 1, train_loss=train_loss, valid_loss=valid_loss)
            self.acc_history.log(epoch + 1, train_acc=train_acc, valid_acc=valid_acc)
            self.dc_history.log(epoch + 1, train_dc=train_dc, valid_dc=valid_dc)

            with self.canvas:
                self.canvas.draw_plot(
                    [self.loss_history['train_loss'], self.loss_history['valid_loss']],
                    labels=['Train Loss', 'Valid loss'])
                self.canvas.draw_plot(
                    [self.acc_history['train_acc'], self.acc_history['valid_acc']],
                    labels=['Train Acc', 'Valid Acc'])
                self.canvas.draw_plot(
                    [self.dc_history['train_dc'], self.dc_history['valid_dc']],
                    labels=['Train Dice Coeff', 'Valid Dice Coeff'])

            if last_batch is not None:
                images, SR, GT = last_batch
                # Rows: input rescaled by (x + 1) / 2, prediction, ground truth
                # (1-channel maps repeated to 3 channels).
                grid_images = torch.cat([(images + 1) / 2,
                                         torch.cat([SR, SR, SR], dim=1),
                                         torch.cat([GT, GT, GT], dim=1)])
                grid = torchvision.utils.make_grid(grid_images, nrow=4)
                torchvision.utils.save_image(
                    grid,
                    os.path.join(self.result_path,
                                 '%s_valid_%d_image.png' % (self.model_type, epoch + 1)))

            # Save Best U-Net model
            if self.save_model and unet_score > best_unet_score:
                best_unet_score = unet_score
                self.best_epoch = epoch
                print('Best %s model score : %.4f' % (self.model_type, best_unet_score))
                torch.save(self.unet.state_dict(), unet_path)

    def test(self):
        """Reload the saved checkpoint and append its test-set metrics to result.csv.

        This used to reference undefined locals (``unet_path``, ``best_unet`` and
        ``best_epoch``) and evaluated on ``valid_loader``; it now recomputes the
        path, uses ``self.best_epoch`` and evaluates ``self.test_loader``.
        """
        unet_path = self._unet_path()
        del self.unet
        self.build_model()
        self.unet.load_state_dict(torch.load(unet_path, map_location=self.device))

        metrics, _, _ = self._evaluate(self.test_loader, with_loss=False)
        acc, SE, SP, PC, F1, JS, DC = metrics
        with open(os.path.join(self.result_path, 'result.csv'), 'a',
                  encoding='utf-8', newline='') as f:
            csv.writer(f).writerow([
                self.model_type, acc, SE, SP, PC, F1, JS, DC, self.lr, self.best_epoch,
                self.num_epochs, self.num_epochs_decay, self.augmentation_prob
            ])
class Solver(object):
    """Train, validate and test a single-channel (sigmoid + BCE) U-Net-family segmenter.

    Besides the best-scoring weights, the latest epoch's model and optimizer
    state are written to a separate "-last" checkpoint.
    """

    def __init__(self, config, train_loader, valid_loader, test_loader):
        # Data loaders
        self.train_loader = train_loader
        self.valid_loader = valid_loader
        self.test_loader = test_loader

        # Models
        self.unet = None
        self.optimizer = None
        self.img_ch = config.img_ch
        self.output_ch = config.output_ch
        self.criterion = torch.nn.BCELoss()
        self.augmentation_prob = config.augmentation_prob

        # Hyper-parameters (Adam learning rate and beta coefficients)
        self.lr = config.lr
        self.beta1 = config.beta1
        self.beta2 = config.beta2

        # Training settings
        self.num_epochs = config.num_epochs
        self.num_epochs_decay = config.num_epochs_decay
        self.batch_size = config.batch_size

        # Step size
        self.log_step = config.log_step
        self.val_step = config.val_step

        # Paths
        self.model_path = config.model_path
        self.result_path = config.result_path
        self.mode = config.mode

        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model_type = config.model_type
        self.t = config.t
        self.build_model()

    def build_model(self):
        """Instantiate ``self.unet`` (3 input channels, 1 output) and its Adam optimizer.

        Raises:
            ValueError: if ``self.model_type`` is not a supported model name.
        """
        if self.model_type == 'U_Net':
            self.unet = U_Net(img_ch=3, output_ch=1)
        elif self.model_type == 'R2U_Net':
            self.unet = R2U_Net(img_ch=3, output_ch=1, t=self.t)
        elif self.model_type == 'AttU_Net':
            self.unet = AttU_Net(img_ch=3, output_ch=1)
        elif self.model_type == 'R2AttU_Net':
            self.unet = R2AttU_Net(img_ch=3, output_ch=1, t=self.t)
        else:
            raise ValueError('Unknown model_type: %r' % (self.model_type,))

        self.optimizer = optim.Adam(list(self.unet.parameters()), self.lr,
                                    [self.beta1, self.beta2])
        self.unet.to(self.device)

    def reset_grad(self):
        """Zero the gradient buffers."""
        self.unet.zero_grad()

    def _unet_path(self):
        """Return the best-model checkpoint path encoding the current hyper-parameters."""
        return os.path.join(
            self.model_path,
            '%s-%d-%.4f-%d-%.4f.pkl' % (self.model_type, self.num_epochs, self.lr,
                                        self.num_epochs_decay, self.augmentation_prob))

    def _last_checkpoint_path(self):
        """Return the path of the latest-epoch checkpoint (kept apart from the best model)."""
        return os.path.splitext(self._unet_path())[0] + '-last.pkl'

    @staticmethod
    def _add_metrics(totals, SR, GT):
        """Return ``totals`` plus this batch's [acc, SE, SP, PC, F1, JS, DC]."""
        batch = (get_accuracy(SR, GT), get_sensitivity(SR, GT), get_specificity(SR, GT),
                 get_precision(SR, GT), get_F1(SR, GT), get_JS(SR, GT), get_DC(SR, GT))
        return [t + m for t, m in zip(totals, batch)]

    def _evaluate(self, loader):
        """Evaluate ``self.unet`` on ``loader`` in eval mode, without gradients.

        Returns:
            ``(metrics, last_batch)``: the per-batch mean of [acc, SE, SP, PC, F1,
            JS, DC], and the final ``(images, SR, GT)`` seen (None if empty).

        NOTE(review): averaging per batch assumes each get_* helper returns one
        score for the whole batch; the old code divided by the image count.
        """
        self.unet.eval()
        totals = [0.] * 7
        num_batches = 0
        last_batch = None
        with torch.no_grad():
            for images, GT in loader:
                images = images.to(self.device)
                GT = GT.to(self.device)
                SR = torch.sigmoid(self.unet(images))
                totals = self._add_metrics(totals, SR, GT)
                num_batches += 1
                last_batch = (images, SR, GT)
        n = max(num_batches, 1)
        return [t / n for t in totals], last_batch

    def train(self):
        """Train ``self.unet``, checkpoint every epoch, keep the best model, then test it.

        If the best-model checkpoint for the current hyper-parameters already
        exists it is loaded and the method returns. Otherwise, after every epoch:
        - the validation images, predictions and ground truth are saved as PNGs;
        - the latest model/optimizer state goes to the "-last" checkpoint;
        - the weights with the best validation ``JS + DC`` go to the best-model path.
        Finally the best weights are evaluated on ``self.test_loader`` and a row
        is appended to ``<result_path>/result.csv``.
        """
        unet_path = self._unet_path()
        print(unet_path)

        if os.path.isfile(unet_path):
            # Load the pretrained weights
            self.unet.load_state_dict(torch.load(unet_path, map_location=self.device))
            print('%s is Successfully Loaded from %s' % (self.model_type, unet_path))
            return

        last_path = self._last_checkpoint_path()
        lr = self.lr
        # JS + DC is never negative, so -1 guarantees a best checkpoint exists for
        # the test phase and best_epoch is always defined.
        best_unet_score = -1.
        best_epoch = -1
        last_loss = None  # most recent training-batch loss, stored in the "-last" checkpoint

        for epoch in range(self.num_epochs):
            # ---------------------------- Training ----------------------------
            self.unet.train(True)
            epoch_loss = 0.
            totals = [0.] * 7
            num_batches = 0
            for images, GT in self.train_loader:
                images = images.to(self.device)
                GT = GT.to(self.device)  # GT : Ground Truth

                # SR : Segmentation Result
                SR_probs = torch.sigmoid(self.unet(images))
                SR_flat = SR_probs.view(SR_probs.size(0), -1)
                GT_flat = GT.view(GT.size(0), -1)
                loss = self.criterion(SR_flat, GT_flat)
                last_loss = loss.item()
                epoch_loss += last_loss

                # Backprop + optimize
                self.reset_grad()
                loss.backward()
                self.optimizer.step()

                # Score probabilities, as validation does (the old code scored raw logits).
                totals = self._add_metrics(totals, SR_probs.detach(), GT)
                num_batches += 1

            n = max(num_batches, 1)
            acc, SE, SP, PC, F1, JS, DC = [t / n for t in totals]
            print('Epoch [%d/%d], Loss: %.4f, \n[Training] Acc: %.4f, SE: %.4f, SP: %.4f, PC: %.4f, F1: %.4f, JS: %.4f, DC: %.4f' % (
                epoch + 1, self.num_epochs, epoch_loss, acc, SE, SP, PC, F1, JS, DC))

            # Linearly decay the learning rate over the last num_epochs_decay epochs.
            if (epoch + 1) > (self.num_epochs - self.num_epochs_decay):
                lr -= (self.lr / float(self.num_epochs_decay))
                for param_group in self.optimizer.param_groups:
                    param_group['lr'] = lr
                print('Decay learning rate to lr: {}.'.format(lr))

            # --------------------------- Validation ---------------------------
            (acc, SE, SP, PC, F1, JS, DC), last_batch = self._evaluate(self.valid_loader)
            unet_score = JS + DC
            print('[Validation] Acc: %.4f, SE: %.4f, SP: %.4f, PC: %.4f, F1: %.4f, JS: %.4f, DC: %.4f' % (
                acc, SE, SP, PC, F1, JS, DC))

            if last_batch is not None:
                images, SR, GT = last_batch
                torchvision.utils.save_image(
                    images.data.cpu(),
                    os.path.join(self.result_path,
                                 '%s_valid_%d_image.png' % (self.model_type, epoch + 1)))
                torchvision.utils.save_image(
                    SR.data.cpu(),
                    os.path.join(self.result_path,
                                 '%s_valid_%d_SR.png' % (self.model_type, epoch + 1)))
                torchvision.utils.save_image(
                    GT.data.cpu(),
                    os.path.join(self.result_path,
                                 '%s_valid_%d_GT.png' % (self.model_type, epoch + 1)))

            # Save the current epoch's model and weights. This goes to its own file:
            # writing it to unet_path (as before) overwrote the best model and left a
            # dict that load_state_dict could not read.
            print('Actual %s model score : %.4f' % (self.model_type, unet_score))
            torch.save({
                'epoch': epoch,
                'model_state_dict': self.unet.state_dict(),
                'optimizer_state_dict': self.optimizer.state_dict(),
                'loss': last_loss,
            }, last_path)

            # Save the best U-Net model, only if it beats the previous best.
            if unet_score > best_unet_score:
                best_unet_score = unet_score
                best_epoch = epoch
                print('Best %s model score : %.4f' % (self.model_type, best_unet_score))
                torch.save(self.unet.state_dict(), unet_path)

        # ------------------------------- Test -------------------------------
        del self.unet
        self.build_model()
        self.unet.load_state_dict(torch.load(unet_path, map_location=self.device))
        # Evaluate on the held-out test loader (this previously re-used valid_loader).
        (acc, SE, SP, PC, F1, JS, DC), _ = self._evaluate(self.test_loader)
        with open(os.path.join(self.result_path, 'result.csv'), 'a',
                  encoding='utf-8', newline='') as f:
            csv.writer(f).writerow([
                self.model_type, acc, SE, SP, PC, F1, JS, DC, self.lr, best_epoch,
                self.num_epochs, self.num_epochs_decay, self.augmentation_prob
            ])
class Solver(object):
    """Train, validate, test and run inference for a U-Net-family segmenter.

    The network is one of ``U_Net``, ``R2U_Net``, ``AttU_Net`` or
    ``R2AttU_Net`` (chosen by ``config.model_type``) and is built with 3 input
    channels and 1 output channel. All three loaders are iterated as 4-tuples
    ``(images, GT, HW, filename)``; training and validation ignore the last two
    fields.
    """

    def __init__(self, config, train_loader, valid_loader, test_loader):
        # Data loader
        self.train_loader = train_loader
        self.valid_loader = valid_loader
        self.test_loader = test_loader

        # Models
        self.unet = None
        self.optimizer = None
        self.img_ch = config.img_ch
        self.output_ch = config.output_ch
        # BCELoss expects probabilities; training applies a sigmoid before it.
        self.criterion = torch.nn.BCELoss()
        self.augmentation_prob = config.augmentation_prob

        # Hyper-parameters
        self.lr = config.lr
        self.beta1 = config.beta1
        self.beta2 = config.beta2

        # Training settings
        self.num_epochs = config.num_epochs
        self.num_epochs_decay = config.num_epochs_decay
        self.batch_size = config.batch_size

        # Step size
        self.log_step = config.log_step
        self.val_step = config.val_step

        # Path
        self.model_path = config.model_path
        self.result_path = config.result_path
        self.mode = config.mode

        self.device = torch.device(
            'cuda' if torch.cuda.is_available() else 'cpu')
        self.model_type = config.model_type
        self.t = config.t
        self.build_model()

    def build_model(self):
        """Build the network selected by ``self.model_type`` plus its Adam optimizer.

        Raises:
            ValueError: if ``self.model_type`` names no supported architecture.
        """
        # NOTE(review): img_ch/output_ch are hard-coded to 3/1 although
        # self.img_ch/self.output_ch exist; kept as-is, confirm intent.
        if self.model_type == 'U_Net':
            self.unet = U_Net(img_ch=3, output_ch=1)
        elif self.model_type == 'R2U_Net':
            self.unet = R2U_Net(img_ch=3, output_ch=1, t=self.t)
        elif self.model_type == 'AttU_Net':
            self.unet = AttU_Net(img_ch=3, output_ch=1)
        elif self.model_type == 'R2AttU_Net':
            self.unet = R2AttU_Net(img_ch=3, output_ch=1, t=self.t)
        else:
            raise ValueError('Unknown model_type: %r' % (self.model_type,))
        self.optimizer = optim.Adam(list(self.unet.parameters()), self.lr,
                                    [self.beta1, self.beta2])
        self.unet.to(self.device)
        # self.print_network(self.unet, self.model_type)

    def print_network(self, model, name):
        """Print ``model``, ``name`` and the model's total parameter count."""
        num_params = sum(p.numel() for p in model.parameters())
        print(model)
        print(name)
        print("The number of parameters: {}".format(num_params))

    def to_data(self, x):
        """Return the data of tensor ``x``, moved to CPU when CUDA is available."""
        if torch.cuda.is_available():
            x = x.cpu()
        return x.data

    def update_lr(self, g_lr, d_lr=None):
        """Set the learning rate of every optimizer parameter group to ``g_lr``.

        ``d_lr`` is kept only for backward compatibility with the old
        generator/discriminator signature; this solver has a single optimizer,
        so it is ignored.
        """
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = g_lr

    def reset_grad(self):
        """Zero the gradient buffers."""
        self.unet.zero_grad()

    def compute_accuracy(self, SR, GT):
        """Return the fraction of elements where ``SR > 0.5`` agrees with ``GT``.

        Returns 0.0 for empty inputs.
        """
        SR_flat = SR.detach().reshape(-1).cpu()
        GT_flat = GT.detach().reshape(-1).cpu()
        if GT_flat.numel() == 0:
            return 0.
        matches = GT_flat == (SR_flat > 0.5).to(GT_flat.dtype)
        return float(matches.sum()) / matches.numel()

    def tensor2img(self, x):
        """Return 255 where channel 0 of ``x`` exceeds channel 1, else 0."""
        img = (x[:, 0, :, :] > x[:, 1, :, :]).float()
        img = img * 255
        return img

    @staticmethod
    def _accumulate_metrics(totals, SR, GT, *threshold):
        """Add one batch's (acc, SE, SP, PC, F1, JS, DC) values into ``totals``.

        ``threshold`` is forwarded to the metric helpers only when given, so
        their own default is used otherwise.
        """
        batch_values = (
            get_accuracy(SR, GT, *threshold),
            get_sensitivity(SR, GT, *threshold),
            get_specificity(SR, GT, *threshold),
            get_precision(SR, GT, *threshold),
            get_F1(SR, GT, *threshold),
            get_JS(SR, GT, *threshold),
            get_DC(SR, GT, *threshold),
        )
        for k, value in enumerate(batch_values):
            totals[k] += value

    @staticmethod
    def _average_metrics(totals, length):
        """Divide every accumulated metric by ``length`` (no-op when it is 0)."""
        # NOTE(review): the get_* helpers are defined elsewhere; if they return
        # one aggregate value per batch, dividing by the number of images
        # (rather than batches) under-reports averages when batch_size > 1.
        # Kept as in the original; confirm.
        if not length:
            return tuple(totals)
        return tuple(total / length for total in totals)

    def _train_one_epoch(self):
        """Run one optimisation pass over ``self.train_loader``.

        Returns ``(epoch_loss, (acc, SE, SP, PC, F1, JS, DC))``, where
        ``epoch_loss`` is the sum of per-batch losses.
        """
        self.unet.train(True)
        epoch_loss = 0
        totals = [0.] * 7
        length = 0
        for images, GT, _, _ in self.train_loader:
            # GT : Ground Truth
            images = images.to(self.device)
            GT = GT.to(self.device)

            # SR : Segmentation Result
            SR = self.unet(images)
            SR_probs = torch.sigmoid(SR)
            SR_flat = SR_probs.view(SR_probs.size(0), -1)
            GT_flat = GT.view(GT.size(0), -1)
            loss = self.criterion(SR_flat, GT_flat)
            epoch_loss += loss.item()

            # Backprop + optimize
            self.reset_grad()
            loss.backward()
            self.optimizer.step()

            # Metrics are computed on the raw network output, as in this
            # class's validation and test code.
            self._accumulate_metrics(totals, SR, GT)
            length += images.size(0)
        return epoch_loss, self._average_metrics(totals, length)

    def _evaluate(self, loader):
        """Return averaged (acc, SE, SP, PC, F1, JS, DC) of the network on ``loader``."""
        self.unet.eval()
        totals = [0.] * 7
        length = 0
        with torch.no_grad():
            for images, GT, _, _ in loader:
                images = images.to(self.device)
                GT = GT.to(self.device)
                SR = self.unet(images)
                self._accumulate_metrics(totals, SR, GT)
                length += images.size(0)
        return self._average_metrics(totals, length)

    def _fit(self, best_unet_score, save_path, show_path):
        """Train for ``self.num_epochs`` epochs, validating after each one.

        Whenever the validation score JS + DC beats ``best_unet_score``, the
        state dict is saved to ``save_path``. ``show_path`` selects the log
        message format. Returns the 0-based epoch of the last improvement, or
        ``None`` if the score never improved.
        """
        lr = self.lr
        best_epoch = None
        for epoch in range(self.num_epochs):
            epoch_loss, (acc, SE, SP, PC, F1, JS, DC) = self._train_one_epoch()

            # Print the log info
            print('Epoch [%d/%d], Loss: %.4f, \n[Training] Acc: %.4f, SE: %.4f, SP: %.4f, PC: %.4f, F1: %.4f, JS: %.4f, DC: %.4f' % (
                epoch + 1, self.num_epochs,
                epoch_loss,
                acc, SE, SP, PC, F1, JS, DC))

            # Decay the learning rate linearly over the last
            # num_epochs_decay epochs.
            if (epoch + 1) > (self.num_epochs - self.num_epochs_decay):
                lr -= (self.lr / float(self.num_epochs_decay))
                self.update_lr(lr)
                print('Decay learning rate to lr: {}.'.format(lr))

            #===================================== Validation ====================================#
            acc, SE, SP, PC, F1, JS, DC = self._evaluate(self.valid_loader)
            unet_score = JS + DC
            print(
                '[Validation] Acc: %.4f, SE: %.4f, SP: %.4f, PC: %.4f, F1: %.4f, JS: %.4f, DC: %.4f'
                % (acc, SE, SP, PC, F1, JS, DC))

            # Save Best U-Net model
            if unet_score > best_unet_score:
                best_unet_score = unet_score
                best_epoch = epoch
                if show_path:
                    print(
                        'Best %s model score : %.4f unet_path is ' %
                        (self.model_type, best_unet_score), save_path)
                else:
                    print('Best %s model score : %.4f' %
                          (self.model_type, best_unet_score))
                torch.save(self.unet.state_dict(), save_path)
        return best_epoch

    def train(self, pretrain, pre_bestscore):
        """Train the network, keeping the checkpoint with the best JS + DC score.

        Args:
            pretrain: 0 to derive the checkpoint file name from the
                hyper-parameters under ``self.model_path``; any other value
                uses ``self.model_path`` itself as the checkpoint file.
            pre_bestscore: score a fine-tuned model must beat before it is
                saved (used only when the checkpoint file already exists).

        If the checkpoint exists, it is loaded and fine-tuned, and improvements
        are saved next to it with a ``_pretrained`` suffix. Otherwise the
        network is trained from scratch and improvements overwrite the
        checkpoint. The best checkpoint is then evaluated on the test loader
        and one row is appended to ``result.csv``.
        """
        #====================================== Training ===========================================#
        if pretrain == 0:
            unet_path = os.path.join(
                self.model_path, '%s-%d-%.4f-%d-%.4f.pkl' %
                (self.model_type, self.num_epochs, self.lr,
                 self.num_epochs_decay, self.augmentation_prob))
        else:
            unet_path = self.model_path

        if os.path.isfile(unet_path):
            # Load the pretrained Encoder, then keep training it.
            self.unet.load_state_dict(torch.load(unet_path))
            print('%s is Successfully Loaded from %s' %
                  (self.model_type, unet_path))
            # Assumes a 4-character extension such as ".pkl".
            premodel_unet_path = unet_path[:-4] + '_pretrained' + '.pkl'
            self._fit(pre_bestscore, premodel_unet_path, show_path=True)
            return

        # Train for Encoder from scratch.
        best_epoch = self._fit(0., unet_path, show_path=False)

        #===================================== Test ====================================#
        if best_epoch is None:
            # Nothing was ever saved, so there is no checkpoint to test.
            print('No %s checkpoint was saved (validation score never improved); skipping test.'
                  % self.model_type)
            return
        del self.unet
        self.build_model()
        self.unet.load_state_dict(torch.load(unet_path))
        # Evaluate on held-out test data, not on the validation set that was
        # used to select this checkpoint.
        acc, SE, SP, PC, F1, JS, DC = self._evaluate(self.test_loader)

        with open(os.path.join(self.result_path, 'result.csv'), 'a',
                  encoding='utf-8', newline='') as f:
            wr = csv.writer(f)
            wr.writerow([
                self.model_type, acc, SE, SP, PC, F1, JS, DC, self.lr,
                best_epoch, self.num_epochs, self.num_epochs_decay,
                self.augmentation_prob
            ])

    def test(self, unet_path, result_savepath, mask_savepath, pre_savepath,
             threshold=0.5):
        """Evaluate the checkpoint at ``unet_path`` on the test loader and save images.

        For every batch, the prediction, ground-truth mask and contour image
        are written (only ``filename[0]`` of each batch is used as the file
        name). The averaged metrics and the image-level recall are printed;
        recall is the share of batches where any prediction exceeds
        ``threshold``.
        """
        self.build_model()
        self.unet.load_state_dict(torch.load(unet_path))
        self.unet.eval()

        totals = [0.] * 7
        length = 0
        num_recall = 0
        rm_mkdir(result_savepath)
        rm_mkdir(mask_savepath)
        rm_mkdir(pre_savepath)

        with torch.no_grad():
            for images, GT, HW, filename in self.test_loader:
                images = images.to(self.device)
                GT = GT.to(self.device)
                SR = self.unet(images)
                self._accumulate_metrics(totals, SR, GT, threshold)

                # recall positive: any pixel predicted above the threshold
                if (SR > threshold).any():
                    num_recall += 1

                SR_PIL_img = saveImg(SR, HW)
                GT_PIL_img = saveImg_GT(GT, HW)
                images_PIL_img = saveImg_contour(images, HW)
                SR_PIL_img.save(os.path.join(pre_savepath, filename[0] + ".png"))
                GT_PIL_img.save(os.path.join(mask_savepath, filename[0] + "_mask.png"))
                images_PIL_img.save(os.path.join(result_savepath, filename[0] + ".png"))
                length += images.size(0)

        acc, SE, SP, PC, F1, JS, DC = self._average_metrics(totals, length)
        recall = num_recall / length if length else 0.
        print('acc:{} DC:{} F1:{}'.format(acc, DC, F1))
        print('length:{} num_recall:{} recall:{}'.format(
            length, num_recall, recall))

    def predict(self, unet_path, raw_savepath, pre_savepath, zeros_savepath,
                threshold=0.5):
        """Run the checkpoint at ``unet_path`` over the test loader and save outputs.

        Batches with any prediction above ``threshold`` get their contour image
        and prediction saved. Other batches get a black 512x512 RGB PNG
        written to ``zeros_savepath``. The image-level recall is printed.
        """
        # scipy.misc.imsave no longer exists in current SciPy; torchvision is
        # already used elsewhere in this file for saving images.
        import torchvision

        self.build_model()
        self.unet.load_state_dict(torch.load(unet_path))
        self.unet.eval()

        num_recall = 0
        length = 0
        rm_mkdir(raw_savepath)
        rm_mkdir(pre_savepath)
        rm_mkdir(zeros_savepath)
        zeros_image = torch.zeros(3, 512, 512)  # all-black RGB placeholder

        with torch.no_grad():
            for images, _, HW, filename in self.test_loader:
                images = images.to(self.device)
                SR = self.unet(images)
                if (SR > threshold).any():
                    num_recall += 1
                    images_PIL_img = saveImg_contour(images, HW)
                    SR_PIL_img = saveImg(SR, HW)
                    images_PIL_img.save(os.path.join(raw_savepath, filename[0] + ".png"))
                    SR_PIL_img.save(os.path.join(pre_savepath, filename[0] + ".png"))
                else:
                    torchvision.utils.save_image(
                        zeros_image,
                        os.path.join(zeros_savepath, filename[0] + ".png"))
                length += images.size(0)

        recall = num_recall / length if length else 0.
        print('length:{} num_recall:{} recall:{}'.format(
            length, num_recall, recall))
class Solver(object): def __init__(self, config, train_valid_loader): # Data loader self.train_valid_loader = train_valid_loader # Models self.unet = None self.optimizer = None self.img_ch = config.img_ch self.output_ch = config.output_ch self.criterion = torch.nn.BCELoss() self.augmentation_prob = config.augmentation_prob # Hyper-parameters self.lr = config.lr self.beta1 = config.beta1 self.beta2 = config.beta2 # Training settings self.n_splits = config.n_splits self.num_epochs = config.num_epochs self.num_epochs_decay = config.num_epochs_decay self.batch_size = config.batch_size # Step size self.log_step = config.log_step self.val_step = config.val_step # Path self.model_path = config.model_path self.result_path = config.result_path self.mode = config.mode self.device = torch.device( 'cuda' if torch.cuda.is_available() else 'cpu') self.model_type = config.model_type self.t = config.t self.build_model() def build_model(self): """Build generator and discriminator.""" if self.model_type == 'U_Net': self.unet = U_Net(img_ch=3, output_ch=self.output_ch) elif self.model_type == 'R2U_Net': self.unet = R2U_Net(img_ch=3, output_ch=self.output_ch, t=self.t) elif self.model_type == 'AttU_Net': self.unet = AttU_Net(img_ch=3, output_ch=self.output_ch) elif self.model_type == 'R2AttU_Net': self.unet = R2AttU_Net(img_ch=3, output_ch=self.output_ch, t=self.t) self.optimizer = optim.Adam(list(self.unet.parameters()), self.lr, [self.beta1, self.beta2]) self.unet.to(self.device) # self.print_network(self.unet, self.model_type) def print_network(self, model, name): """Print out the network information.""" num_params = 0 for p in model.parameters(): num_params += p.numel() print(model) print(name) print("The number of parameters: {}".format(num_params)) def to_data(self, x): """Convert variable to tensor.""" if torch.cuda.is_available(): x = x.cpu() return x.data def update_lr(self, g_lr, d_lr): for param_group in self.optimizer.param_groups: param_group['lr'] = lr def 
reset_grad(self): """Zero the gradient buffers.""" self.unet.zero_grad() def reset_model(self): for layer in self.unet.children(): if hasattr(layer, 'reset_parameters'): layer.reset_parameters() self.optimizer = optim.Adam(list(self.unet.parameters()), self.lr, [self.beta1, self.beta2]) def compute_accuracy(self, SR, GT): SR_flat = SR.view(-1) GT_flat = GT.view(-1) acc = GT_flat.data.cpu() == (SR_flat.data.cpu() > 0.5) def tensor2img(self, x): img = (x[:, 0, :, :] > x[:, 1, :, :]).float() img = img * 255 return img def train(self): """Train encoder, generator and discriminator.""" #====================================== Training ===========================================# #===========================================================================================# # U-Net Train if False: pass #if os.path.isfile(unet_path): # # Load the pretrained Encoder # self.unet.load_state_dict(torch.load(unet_path)) # print('%s is Successfully Loaded from %s'%(self.model_type,unet_path)) else: # Train for Encoder kfold = KFold(n_splits=self.n_splits, shuffle=True) for fold, (train_index, valid_index) in enumerate( kfold.split(self.train_valid_loader.dataset)): print(f"Fold{fold} start") logging.info(f"Fold{fold} start") logging.info(f"train: {train_index} valid: {valid_index}") self.reset_model() lr = self.lr best_unet_score = 0. self.unet.train(True) epoch_loss = 0 for epoch in range(int(self.num_epochs / self.n_splits)): acc = 0. # Accuracy SE = 0. # Sensitivity (Recall) SP = 0. # Specificity PC = 0. # Precision F1 = 0. # F1 Score JS = 0. # Jaccard Similarity DC = 0. 
# Dice Coefficient length = 0 # GT : Ground Truth train_sampler = SubsetRandomSampler(train_index) valid_sampler = SubsetRandomSampler(valid_index) train_loader = torch.utils.data.DataLoader( self.train_valid_loader.dataset, batch_size=self.batch_size, sampler=train_sampler) valid_loader = torch.utils.data.DataLoader( self.train_valid_loader.dataset, batch_size=self.batch_size, sampler=valid_sampler) for i, (images, GT) in enumerate(train_loader): images = images.to(self.device) GT = GT.to(self.device) # SR : Segmentation Result SR = self.unet(images) SR_probs = torch.sigmoid(SR) SR_flat = SR_probs.view(SR_probs.size(0), -1) GT_flat = GT.view(GT.size(0), -1) loss = self.criterion(SR_flat, GT_flat) epoch_loss += loss.item() # Backprop + optimize self.reset_grad() loss.backward() self.optimizer.step() acc += get_accuracy(SR, GT) SE += get_sensitivity(SR, GT) SP += get_specificity(SR, GT) PC += get_precision(SR, GT) F1 += get_F1(SR, GT) JS += get_JS(SR, GT) DC += get_DC(SR, GT) length += images.size(0) acc = acc / length SE = SE / length SP = SP / length PC = PC / length F1 = F1 / length JS = JS / length DC = DC / length print("training ", DC, length) # Print the log info print('Epoch [%d/%d], Loss: %.4f, \n[Training] Acc: %.4f, SE: %.4f, SP: %.4f, PC: %.4f, F1: %.4f, JS: %.4f, DC: %.4f' % ( fold*self.num_epochs/self.n_splits+epoch+1, self.num_epochs, \ epoch_loss,\ acc,SE,SP,PC,F1,JS,DC)) logging.info('Epoch [%d/%d], Loss: %.4f, \n[Training] Acc: %.4f, SE: %.4f, SP: %.4f, PC: %.4f, F1: %.4f, JS: %.4f, DC: %.4f' % ( fold*self.num_epochs/self.n_splits+epoch+1, self.num_epochs, \ epoch_loss,\ acc,SE,SP,PC,F1,JS,DC)) # Decay learning rate if (epoch * self.n_splits + fold + 1) > (self.num_epochs - self.num_epochs_decay): lr -= (self.lr / float(self.num_epochs_decay)) for param_group in self.optimizer.param_groups: param_group['lr'] = lr print('Decay learning rate to lr: {}.'.format(lr)) #===================================== Validation 
====================================# self.unet.train(False) self.unet.eval() acc = 0. # Accuracy SE = 0. # Sensitivity (Recall) SP = 0. # Specificity PC = 0. # Precision F1 = 0. # F1 Score JS = 0. # Jaccard Similarity DC = 0. # Dice Coefficient length = 0 for i, (images, GT) in enumerate(valid_loader): images = images.to(self.device) GT = GT.to(self.device) SR = torch.sigmoid(self.unet(images)) acc += get_accuracy(SR, GT) SE += get_sensitivity(SR, GT) SP += get_specificity(SR, GT) PC += get_precision(SR, GT) F1 += get_F1(SR, GT) JS += get_JS(SR, GT) DC += get_DC(SR, GT) length += images.size(0) acc = acc / length SE = SE / length SP = SP / length PC = PC / length F1 = F1 / length JS = JS / length DC = DC / length unet_score = JS + DC print("valid ", DC, length) print( '[Validation] Acc: %.4f, SE: %.4f, SP: %.4f, PC: %.4f, F1: %.4f, JS: %.4f, DC: %.4f' % (acc, SE, SP, PC, F1, JS, DC)) logging.info( '[Validation] Acc: %.4f, SE: %.4f, SP: %.4f, PC: %.4f, F1: %.4f, JS: %.4f, DC: %.4f' % (acc, SE, SP, PC, F1, JS, DC)) ''' torchvision.utils.save_image(images.data.cpu(), os.path.join(self.result_path, '%s_valid_%d_image.png'%(self.model_type,epoch+1))) torchvision.utils.save_image(SR.data.cpu(), os.path.join(self.result_path, '%s_valid_%d_SR.png'%(self.model_type,epoch+1))) torchvision.utils.save_image(GT.data.cpu(), os.path.join(self.result_path, '%s_valid_%d_GT.png'%(self.model_type,epoch+1))) ''' # Save Best U-Net model print(f'model score: {unet_score} ({best_unet_score})') logging.info( (f'model score: {unet_score} ({best_unet_score})')) if unet_score > best_unet_score: best_unet_score = unet_score best_epoch = fold * self.num_epochs / self.n_splits + epoch + 1 best_unet = self.unet.state_dict() print('Best %s model score : %.4f' % (self.model_type, best_unet_score)) logging.info('Best %s model score : %.4f' % (self.model_type, best_unet_score)) unet_path = os.path.join( self.model_path, '%s-f%d-%d-%.3f.pkl' % (self.model_type, fold, best_epoch, DC)) 
torch.save(best_unet, unet_path) #===================================== Test ====================================# """
class Solver(object): def __init__(self, config, train_loader, valid_loader, test_loader, whole_slice_prediction_loader): # Data loader self.train_loader = train_loader self.valid_loader = valid_loader self.test_loader = test_loader self.whole_slice_prediction_loader = whole_slice_prediction_loader # Models self.unet = None self.optimizer = None self.img_ch = config.img_ch self.output_ch = config.output_ch #self.criterion = torch.nn.BCELoss() self.augmentation_prob = config.augmentation_prob self.inverse_ratio = config.inverse_ratio # Hyper-parameters self.initial_lr = config.lr self.current_lr = config.lr self.optimizer_choice = config.optimizer_choice if config.optimizer_choice == 'Adam': self.beta1 = config.beta1 self.beta2 = config.beta2 elif config.optimizer_choice == 'SGD': self.momentum = config.momentum else: print('No such optimizer available') # Training settings self.num_epochs = config.num_epochs #self.num_epochs_decay = config.num_epochs_decay self.batch_size = config.batch_size self.PPorLS = config.PPorLS # Step size self.log_step = config.log_step self.val_step = config.val_step self.batch_val_num = config.val_freq_batch # Path self.model_path = config.model_path self.result_path = config.result_path self.result_img_path = config.result_img_path self.mode = config.mode self.device = torch.device( 'cuda' if torch.cuda.is_available() else 'cpu') self.model_type = config.model_type self.t = config.t self.build_model() def build_model(self): """Build generator and discriminator.""" if self.model_type == 'U_Net': self.unet = U_Net(img_ch=1, output_ch=1) elif self.model_type == 'R2U_Net': self.unet = R2U_Net(img_ch=1, output_ch=1, t=self.t) elif self.model_type == 'AttU_Net': self.unet = AttU_Net(img_ch=1, output_ch=1) elif self.model_type == 'R2AttU_Net': self.unet = R2AttU_Net(img_ch=1, output_ch=1, t=self.t) if self.optimizer_choice == 'Adam': self.optimizer = optim.Adam(list(self.unet.parameters()), self.initial_lr, [self.beta1, self.beta2]) elif 
self.optimizer_choice == 'SGD': self.optimizer = optim.SGD(list(self.unet.parameters()), self.initial_lr, self.momentum) else: pass self.unet.to(self.device) #self.print_network(self.unet, self.model_type) def print_network(self, model, name): """Print out the network information.""" num_params = 0 for p in model.parameters(): num_params += p.numel() print(model) print(name) print("The number of parameters: {}".format(num_params)) def dice_coeff_loss(self, y_pred, y_true): smooth = 1 y_true_flat = y_true.view(y_true.size(0), -1) y_pred_flat = y_pred.view(y_pred.size(0), -1) intersection = (y_true_flat * y_pred_flat).sum() return -(2. * intersection + smooth) / ((y_true_flat).sum() + (y_pred_flat).sum() + smooth) def RR_dice_coeff_loss(self, y_pred, y_true): smooth = 1e-6 y_true_flat = y_true.view(y_true.size(0), -1) y_pred_flat = y_pred.view(y_pred.size(0), -1) intersection = (y_true_flat * y_pred_flat).sum() inverse_y_true_flat = 1 - y_true_flat inverse_y_pred_flat = 1 - y_pred_flat inverse_intersection = (inverse_y_true_flat * inverse_y_pred_flat).sum() return -(2. * intersection + smooth) / ( (y_true_flat).sum() + (y_pred_flat).sum() + smooth) - (2. * inverse_intersection + smooth) / ( (inverse_y_true_flat).sum() + (inverse_y_pred_flat).sum() + smooth) def to_data(self, x): """Convert variable to tensor.""" if torch.cuda.is_available(): x = x.cpu() return x.data # Redefine the 'update_lr' function (R&R) def update_lr(self, new_lr): for param_group in self.optimizer.param_groups: param_group['lr'] = new_lr def run_batch_validation(self, epoch, batch_train): self.unet.train(False) self.unet.eval() acc = 0. # Accuracy SE = 0. # Sensitivity (Recall) SP = 0. # Specificity PC = 0. # Precision F1 = 0. # F1 Score JS = 0. # Jaccard Similarity DC = 0. 
# Dice Coefficient DC_RR = 0 length = 0 validation_batch_loss = 0 for batch, (images, GT) in enumerate(self.valid_loader): images = images.to(self.device) GT = GT.to(self.device) # Reshape the images and GT to 4-dimensional so that they can get fed to the conv2d layer. images = images.reshape(self.batch_size, self.img_ch, np.shape(images)[1], np.shape(images)[2]) GT = GT.reshape(self.batch_size, self.img_ch, np.shape(GT)[1], np.shape(GT)[2]) #SR = F.sigmoid(self.unet(images)) SR = torch.sigmoid(self.unet(images)) acc += get_accuracy(SR, GT) SE += get_sensitivity(SR, GT) SP += get_specificity(SR, GT) PC += get_precision(SR, GT) F1 += get_F1(SR, GT) JS += get_JS(SR, GT) DC_RR += get_DC_RR(SR, GT, inverse_ratio=self.inverse_ratio) DC += get_DC(SR, GT) length += images.size(0) # Compute the validation loss. SR = self.unet(images) SR_probs = torch.sigmoid(SR) SR_flat = SR_probs.view(SR_probs.size(0), -1) GT_flat = GT.view(GT.size(0), -1) # use the dice coefficient loss instead of the BCE loss. 
(R&R) validation_loss = self.dice_coeff_loss(SR_flat, GT_flat) validation_batch_loss += validation_loss.item() acc = acc / length SE = SE / length SP = SP / length PC = PC / length F1 = F1 / length JS = JS / length DC = DC / length DC_RR = DC_RR / length unet_score = DC_RR print('current batch: {}'.format(batch_train)) print('Current learning rate: {}'.format(self.current_lr)) print( 'Current Batch [%d] \n[Validation] Validation Loss: %.4f, Acc: %.4f, SE: %.4f, SP: %.4f, PC: %.4f, F1: %.4f, JS: %.4f, DC: %.4f, DC_RR: %.4f' % (batch_train + 1, validation_batch_loss, acc, SE, SP, PC, F1, JS, DC, DC_RR)) # Append validation loss to train loss history (R&R) f = open(os.path.join(self.result_path, 'model_validation_batch_history.csv'), 'a', encoding='utf-8', newline='') wr = csv.writer(f) wr.writerow([ 'Validation', 'Epoch [%d/%d]' % (epoch + 1, self.num_epochs), 'Batch [%d]' % (batch_train + 1), 'Validation loss: %.4f' % validation_batch_loss, 'Accuracy: %.4f' % acc, 'Sensitivity: %.4f' % SE, 'Specificity: %.4f' % SP, 'Precision: %.4f' % PC, 'F1 Score: %.4f' % F1, 'Jaccard Similarity: %.4f' % JS, 'Dice Coefficient: %.4f' % DC, 'RR_DC: %.4f' % DC_RR ]) self.unet.train(True) return (validation_batch_loss, unet_score) # Define adaptive learning rate handler (R&R) def adaptive_lr_handler(self, cooldown, min_lr, current_epoch, previous_update_epoch, plateau_ratio, adjustment_ratio, loss_history): if current_epoch > 1: if current_epoch - previous_update_epoch > cooldown: if (loss_history[-1] > loss_history[-2]) or (abs( (loss_history[-2] - loss_history[-1]) / loss_history[-2]) < plateau_ratio): if self.current_lr > min_lr: self.current_lr = adjustment_ratio * self.current_lr self.update_lr(self.current_lr) print( 'Validation loss stop decreasing. Adjust the learning rate to {}.' 
.format(self.current_lr)) return current_epoch def reset_grad(self): """Zero the gradient buffers.""" self.unet.zero_grad() def tensor2img(self, x): img = (x[:, 0, :, :] > x[:, 1, :, :]).float() img = img * 255 return img def train(self): """Train encoder, generator and discriminator.""" #====================================== Training ===========================================# #===========================================================================================# unet_path = os.path.join( self.model_path, '%s-%d-%.4f-%.4f-%s-%s-%.4f.pkl' % (self.model_type, self.num_epochs, self.initial_lr, self.augmentation_prob, self.PPorLS, self.optimizer_choice, self.inverse_ratio)) print('The U-Net path is {}'.format(unet_path)) # U-Net Train # Train loss history (R&R) train_loss_history = [] train_batch_loss_history = [] # Validation loss history (R&R) validation_loss_history = [] val_batch_loss_history = [] stop_training = False if os.path.isfile(unet_path): # Load the pretrained Encoder self.unet.load_state_dict(torch.load(unet_path)) print('%s is Successfully Loaded from %s' % (self.model_type, unet_path)) else: # Train for Encoder best_unet_score = 0. print('Start training. The initial learning rate is: {}'.format( self.initial_lr)) for epoch in range(self.num_epochs): self.unet.train(True) train_epoch_loss = 0 validation_epoch_loss = 0 if stop_training == True: break acc = 0. # Accuracy SE = 0. # Sensitivity (Recall) SP = 0. # Specificity PC = 0. # Precision F1 = 0. # F1 Score JS = 0. # Jaccard Similarity DC = 0. # Dice Coefficient DC_RR = 0 length = 0 for batch, (images, GT) in enumerate(self.train_loader): # GT : Ground Truth images = images.to(self.device) GT = GT.to(self.device) # Reshape the images and GT to 4-dimensional so that they can get fed to the conv2d layer. 
(R&R) images = images.reshape(self.batch_size, self.img_ch, np.shape(images)[1], np.shape(images)[2]) GT = GT.reshape(self.batch_size, self.img_ch, np.shape(GT)[1], np.shape(GT)[2]) # SR : Segmentation Result SR = self.unet(images) SR_probs = torch.sigmoid(SR) SR_flat = SR_probs.view(SR_probs.size(0), -1) GT_flat = GT.view(GT.size(0), -1) # Use dice coefficient loss instead of the BCE loss. (R&R) train_loss = self.dice_coeff_loss(SR_flat, GT_flat) train_epoch_loss += train_loss.item() # Backprop + optimize self.reset_grad() train_loss.backward() self.optimizer.step() acc += get_accuracy(SR, GT) SE += get_sensitivity(SR, GT) SP += get_specificity(SR, GT) PC += get_precision(SR, GT) F1 += get_F1(SR, GT) JS += get_JS(SR, GT) DC_RR += get_DC_RR(SR, GT, inverse_ratio=self.inverse_ratio) DC += get_DC(SR, GT) length += images.size(0) if epoch == 0: val_frequency = self.batch_val_num[0] else: val_frequency = self.batch_val_num[1] if batch % val_frequency == 0: # update learning rate and record the validation loss history validation_batch_loss, unet_score = self.run_batch_validation( epoch, batch) val_batch_loss_history.append(validation_batch_loss) train_batch_loss_history.append(train_epoch_loss) if unet_score > best_unet_score: best_unet_score = unet_score best_epoch = epoch best_unet = self.unet.state_dict() print('Best %s model score : %.4f' % (self.model_type, best_unet_score)) torch.save(best_unet, unet_path) # update learning rate batch_id = len(val_batch_loss_history) try: previous_batch_id = self.adaptive_lr_handler( 3, 0.01 * self.initial_lr, batch_id, previous_batch_id, 0.001, 0.5, val_batch_loss_history) except: previous_batch_id = self.adaptive_lr_handler( 3, 0.01 * self.initial_lr, batch_id, 0, 0.001, 0.5, val_batch_loss_history) if ((batch_id - 4) % 10 == 0) and ( batch_id > 8) or unet_score < 0.2 * best_unet_score: if (np.median(val_batch_loss_history[-10:-5]) >= np.median(val_batch_loss_history[-5:])): print( 'Validation loss stop decreasing. 
Stop training.' ) stop_training = True break if stop_training == True: break acc = acc / length SE = SE / length SP = SP / length PC = PC / length F1 = F1 / length JS = JS / length DC = DC / length DC_RR = DC_RR / length # Print the log info print('Epoch [%d/%d] \n[Training] Train Loss: %.4f, Acc: %.4f, SE: %.4f, SP: %.4f, PC: %.4f, F1: %.4f, JS: %.4f, DC: %.4f, DC_RR: %.4f' % (\ epoch + 1, self.num_epochs, train_epoch_loss,\ acc, SE, SP, PC, F1, JS, DC, DC_RR)) # Append train loss to train loss history (R&R) train_loss_history.append(train_epoch_loss) f = open(os.path.join(self.result_path, 'train_and_validation_history.csv'), 'a', \ encoding = 'utf-8', newline= '') wr = csv.writer(f) wr.writerow(['Training', 'Epoch [%d/%d]' % (epoch + 1, self.num_epochs), \ 'Train loss: %.4f' % train_epoch_loss,\ 'Accuracy: %.4f' % acc, 'Sensitivity: %.4f' % SE, 'Specificity: %.4f' % SP, 'Precision: %.4f'% PC, \ 'F1 Score: %.4f' % F1, 'Jaccard Similarity: %.4f' % JS, 'Dice Coefficient: %.4f' % DC, 'RR_DC: %.4f' % DC_RR]) f.close() #===================================== Validation ====================================# self.unet.train(False) self.unet.eval() acc = 0. # Accuracy SE = 0. # Sensitivity (Recall) SP = 0. # Specificity PC = 0. # Precision F1 = 0. # F1 Score JS = 0. # Jaccard Similarity DC = 0. # Dice Coefficient DC_RR = 0 length = 0 for batch, (images, GT) in enumerate(self.valid_loader): images = images.to(self.device) GT = GT.to(self.device) # Reshape the images and GT to 4-dimensional so that they can get fed to the conv2d layer. 
images = images.reshape(self.batch_size, self.img_ch, np.shape(images)[1], np.shape(images)[2]) GT = GT.reshape(self.batch_size, self.img_ch, np.shape(GT)[1], np.shape(GT)[2]) #SR = F.sigmoid(self.unet(images)) SR = torch.sigmoid(self.unet(images)) acc += get_accuracy(SR, GT) SE += get_sensitivity(SR, GT) SP += get_specificity(SR, GT) PC += get_precision(SR, GT) F1 += get_F1(SR, GT) JS += get_JS(SR, GT) DC_RR += get_DC_RR(SR, GT, inverse_ratio=self.inverse_ratio) DC += get_DC(SR, GT) length += images.size(0) # Compute the validation loss. SR = self.unet(images) SR_probs = torch.sigmoid(SR) SR_flat = SR_probs.view(SR_probs.size(0), -1) GT_flat = GT.view(GT.size(0), -1) # use the dice coefficient loss instead of the BCE loss. (R&R) validation_loss = self.dice_coeff_loss(SR_flat, GT_flat) validation_epoch_loss += validation_loss.item() acc = acc / length SE = SE / length SP = SP / length PC = PC / length F1 = F1 / length JS = JS / length DC = DC / length DC_RR = DC_RR / length unet_score = DC_RR print('Current learning rate: {}'.format(self.current_lr)) print( 'Epoch [%d/%d] \n[Validation] Validation Loss: %.4f, Acc: %.4f, SE: %.4f, SP: %.4f, PC: %.4f, F1: %.4f, JS: %.4f, DC: %.4f, DC_RR: %.4f' % (epoch + 1, self.num_epochs, validation_epoch_loss, acc, SE, SP, PC, F1, JS, DC, DC_RR)) # Append validation loss to train loss history (R&R) validation_loss_history.append(validation_epoch_loss) ''' torchvision.utils.save_image(images.data.cpu(), os.path.join(self.result_path, '%s_valid_%d_image.png'%(self.model_type,epoch+1))) torchvision.utils.save_image(SR.data.cpu(), os.path.join(self.result_path, '%s_valid_%d_SR.png'%(self.model_type,epoch+1))) torchvision.utils.save_image(GT.data.cpu(), os.path.join(self.result_path, '%s_valid_%d_GT.png'%(self.model_type,epoch+1))) ''' f = open(os.path.join(self.result_path, 'train_and_validation_history.csv'), 'a', \ encoding = 'utf-8', newline= '') wr = csv.writer(f) wr.writerow(['Validation', 'Epoch [%d/%d]' % (epoch + 1, 
self.num_epochs), \ 'Validation loss: %.4f' % validation_epoch_loss,\ 'Accuracy: %.4f' % acc, 'Sensitivity: %.4f' % SE, 'Specificity: %.4f' % SP, 'Precision: %.4f'% PC, \ 'F1 Score: %.4f' % F1, 'Jaccard Similarity: %.4f' % JS, 'Dice Coefficient: %.4f' % DC, 'RR_DC: %.4f' % DC_RR]) f.close() # Save Best U-Net model if unet_score > best_unet_score: best_unet_score = unet_score best_epoch = epoch best_unet = self.unet.state_dict() print('Best %s model score : %.4f' % (self.model_type, best_unet_score)) torch.save(best_unet, unet_path) # Early stop (R&R) #if (epoch > 8) and ((epoch - 4) % 5 == 0): # if (np.median(validation_loss_history[-10:-5]) >= np.median(validation_loss_history[-5:])): # print('Validation loss stop decreasing. Stop training.') # break if (len(validation_loss_history) > 1): if (validation_loss_history[-2] >= validation_loss_history[-1]): print( 'Validation loss stop decreasing. Stop training.') break del self.unet try: del best_unet except: print( 'Cannot delete the variable "best_unet": variable does not exist.' ) return train_loss_history, validation_loss_history, val_batch_loss_history, train_batch_loss_history def test(self): """Test encoder, generator and discriminator.""" #======================================= Test ====================================# #=================================================================================# unet_path = os.path.join( self.model_path, '%s-%d-%.4f-%.4f-%s-%s-%.4f.pkl' % (self.model_type, self.num_epochs, self.initial_lr, self.augmentation_prob, self.PPorLS, self.optimizer_choice, self.inverse_ratio)) self.build_model() self.unet.load_state_dict(torch.load(unet_path)) self.unet.train(False) self.unet.eval() acc = 0. # Accuracy SE = 0. # Sensitivity (Recall) SP = 0. # Specificity PC = 0. # Precision F1 = 0. # F1 Score JS = 0. # Jaccard Similarity DC = 0. 
# Dice Coefficient DC_RR = 0 length = 0 for i, (images, GT) in enumerate(self.test_loader): images = images.to(self.device) GT = GT.to(self.device) # Reshape the images and GT to 4-dimensional so that they can get fed to the conv2d layer. images = images.reshape(self.batch_size, self.img_ch, np.shape(images)[1], np.shape(images)[2]) GT = GT.reshape(self.batch_size, self.img_ch, np.shape(GT)[1], np.shape(GT)[2]) #SR = F.sigmoid(self.unet(images)) SR = torch.sigmoid(self.unet(images)) acc += get_accuracy(SR, GT) SE += get_sensitivity(SR, GT) SP += get_specificity(SR, GT) PC += get_precision(SR, GT) F1 += get_F1(SR, GT) JS += get_JS(SR, GT) DC_RR += get_DC_RR(SR, GT, inverse_ratio=self.inverse_ratio) DC += get_DC(SR, GT) length += images.size(0) np_img = np.squeeze(SR.cpu().detach().numpy()) np.save(self.result_img_path + str(i) + '.npy', np_img) acc = acc / length SE = SE / length SP = SP / length PC = PC / length F1 = F1 / length JS = JS / length DC = DC / length DC_RR = DC_RR / length print('model type: ', self.model_type, 'accuracy: ', acc, 'sensitivity: ', SE, 'specificity: ', SP, 'precision: ', PC, 'F1 score: ', F1, 'Jaccard similarity: ', JS, 'Dice Coefficient: ', DC, 'DC_RR: ', DC_RR) result_csv_path = '/home/raphael/Projects/DL-Lung_Nodule_LUNA16/Solutions/RaphaelRosalie-solution/patch-based_U-net/results/' f = open(os.path.join(result_csv_path, 'result_compare.csv'), 'a', encoding='utf-8', newline='') wr = csv.writer(f) wr.writerow([self.model_type, self.PPorLS, 'Accuracy: %.4f' % acc, 'Sensitivity: %.4f' % SE, 'Specificity: %.4f' % SP, 'Precision: %.4f'% PC, \ 'F1 Score: %.4f' % F1, 'Jaccard Similarity: %.4f' % JS, 'Dice Coefficient: %.4f' % DC, 'RR_DC: %.4f' % DC_RR, 'inverse_ratio: %.3f' % self.inverse_ratio]) f.close() def whole_slice_prediction(self): """Inference mode. 
Return whole slice prediction as a binary nodule mask.""" unet_path = os.path.join( self.model_path, '%s-%d-%.4f-%.4f-%s-%s.pkl' % (self.model_type, self.num_epochs, self.initial_lr, self.augmentation_prob, self.PPorLS, self.optimizer_choice)) self.build_model() self.unet.load_state_dict(torch.load(unet_path)) self.unet.train(False) self.unet.eval() acc = 0. # Accuracy SE = 0. # Sensitivity (Recall) SP = 0. # Specificity PC = 0. # Precision F1 = 0. # F1 Score JS = 0. # Jaccard Similarity DC = 0. # Dice Coefficient DC_RR = 0 length = 0 for batch, (images, GT) in enumerate(self.whole_slice_prediction_loader): images = images.to(self.device) GT = GT.to(self.device) # Reshape the images and GT to 4-dimensional so that they can get fed to the conv2d layer. images = images.reshape(self.batch_size, self.img_ch, np.shape(images)[1], np.shape(images)[2]) GT = GT.reshape(self.batch_size, self.img_ch, np.shape(GT)[1], np.shape(GT)[2]) #SR = F.sigmoid(self.unet(images)) SR = torch.sigmoid(self.unet(images)) acc += get_accuracy(SR, GT) SE += get_sensitivity(SR, GT) SP += get_specificity(SR, GT) PC += get_precision(SR, GT) F1 += get_F1(SR, GT) JS += get_JS(SR, GT) DC_RR += get_DC_RR(SR, GT) DC += get_DC(SR, GT) length += images.size(0) np_img = np.squeeze(SR.cpu().detach().numpy()) np.save(self.result_img_path + str(i) + '.npy', np_img) acc = acc / length SE = SE / length SP = SP / length PC = PC / length F1 = F1 / length JS = JS / length DC = DC / length DC_RR = DC_RR / length unet_score = DC