def predict2d(self, test_dataset, output='pic'):
    '''Run the trained generator on every 2-D raw image in a dataset folder.

    Loads the saved generator, reads each raw file under
    ``self.data_dir/test_dataset``, rescales it to [0, 1], super-resolves it,
    and saves the result either as a raw file (``output='raw'``, rescaled back
    to the input's original intensity range) or as a PNG (any other value).

    Fixes vs. original: uses ``os.path.basename`` instead of splitting on a
    Windows-only ``'\\'`` separator (which left the full path intact on POSIX
    and corrupted the save path), wraps inference in ``torch.no_grad()`` so no
    autograd graph is built, and hoists loop-invariant setup out of the loop.

    Parameters
    ----------
    test_dataset : str
        Sub-directory of ``self.data_dir`` containing the raw input images.
    output : str
        ``'raw'`` to save rescaled raw files, anything else to save PNGs.
    '''
    # Build the generator; kernel is fixed at 3 here (matches test()),
    # unlike train() which uses self.kernel.
    self.G = Generator(num_channels=self.num_channels, base_filter=self.filter,
                       num_residuals=self.num_residuals,
                       scale_factor=self.scale_factor, kernel=3)
    if self.gpu_mode:
        print('gpu mode')
        self.G.cuda()

    # Restore trained weights.
    self.load_model()

    # Collect raw files in deterministic (sorted) order.
    image_dir = join(self.data_dir, test_dataset)
    image_filenames = [join(image_dir, x)
                       for x in sorted(listdir(image_dir))
                       if utils.is_raw_file(x)]

    # Loop-invariant setup: eval mode and the tensor transform only need to
    # be established once, not per image.
    self.G.eval()
    lr_transform = Compose([ToTensor()])

    img_num = 0
    for img_fn in image_filenames:
        print(img_fn)
        img = utils.read_and_reshape(img_fn).astype(float)
        # Remember the original intensity range so raw output can be
        # rescaled back to it.
        minvalue = img.min()
        maxvalue = img.max()
        img = transforms_3d.rescale(img, original_scale=(minvalue, maxvalue),
                                    new_scale=(0, 1))
        img = Image.fromarray(img)
        lr_img = lr_transform(img)
        if self.num_channels == 1:
            # Add the batch dimension: (C, H, W) -> (1, C, H, W).
            y_ = lr_img.unsqueeze(0)
        else:
            raise Exception("only accept 2d raw image file ")
        if self.gpu_mode:
            y_ = y_.cuda()

        # Inference only — no_grad avoids building an autograd graph and
        # keeps GPU memory flat across the loop.
        with torch.no_grad():
            recon_img = self.G(y_)
        recon_img = recon_img.cpu()[0].clamp(0, 1).detach().numpy()

        if output == 'raw':
            # Map back to the input's original intensity range before saving.
            recon_img = transforms_3d.rescale(
                recon_img, original_scale=(0, 1),
                new_scale=(minvalue, maxvalue)).astype(int)
            # Portable basename extraction (original split on '\\' only
            # worked on Windows paths).
            img_filename = os.path.basename(img_fn)
            save_dir = os.path.join(self.save_dir, 'SR-2D-raw', test_dataset)
            os.makedirs(save_dir, exist_ok=True)
            utils.save_as_raw(recon_img, os.path.join(save_dir, img_filename),
                              dtype='uint16', prefix='SR')
        else:
            recon_img = torch.from_numpy(recon_img)
            save_dir = os.path.join(self.save_dir, 'SR-2D-png', test_dataset)
            network_utils.save_img(recon_img, img_num, save_dir=save_dir)
        img_num += 1
        # Release cached GPU memory between (potentially large) images.
        torch.cuda.empty_cache()
    print('Single test result image is saved.')
def train(self):
    """Adversarially train the SRGAN generator/discriminator pair.

    Pipeline: (1) build G, D, a VGG19 feature extractor and Adam optimizers;
    (2) if no pre-trained generator checkpoint exists, pre-train G alone with
    an L1 content loss for ``self.epoch_pretrain`` epochs; (3) run the main
    adversarial loop for ``self.num_epochs`` epochs, alternating a
    discriminator step (with one-sided label smoothing on real labels) and a
    generator step whose loss is mse_factor*L1 + vgg_factor*VGG-feature-L1 +
    gan_factor*BCE; (4) after every epoch, super-resolve one fixed test image
    and save it; (5) save metrics plot, loss curves, and final weights.

    NOTE(review): statement order in the inner loop is load-bearing — D is
    stepped on a detached-by-recomputation fake (G is re-run for the G step),
    so do not reorder the two optimizer phases.
    """
    # defining weight factor for GAN loss, MSE loss and VGG loss for the loss
    # function and label smoothing factor for discriminator
    gan_factor=0.1
    mse_factor=1
    vgg_factor=self.vgg_factor
    smooth_factor=0.1
    train_dataset=[]  # NOTE(review): unused local — dead assignment
    # load dataset
    train_data_loader = self.load_ct_dataset(dataset=self.train_dataset, is_train=True, is_registered=self.registered, grayscale_corrected=self.grayscale_corrected)
    test_data_loader = self.load_ct_dataset(dataset=self.test_dataset, is_train=False, is_registered=self.registered, grayscale_corrected=self.grayscale_corrected)
    # networks
    self.G = Generator(num_channels=self.num_channels, base_filter=self.filter, num_residuals=self.num_residuals,scale_factor=self.scale_factor,kernel=self.kernel)
    self.D = Discriminator(num_channels=self.num_channels, base_filter=self.filter, image_size=self.crop_size)
    # weigh initialization
    self.G.weight_init()
    self.D.weight_init()
    # For the content loss: frozen, pre-trained VGG19 truncated at
    # self.vgg_layer, used as a perceptual feature extractor.
    self.feature_extractor = FeatureExtractor(models.vgg19(pretrained=True),feature_layer=self.vgg_layer)
    # optimizer — D's learning rate is scaled by self.lr_d relative to G's.
    self.G_optimizer = optim.Adam(self.G.parameters(), lr=self.lr, betas=(0.9, 0.999))
    self.D_optimizer = optim.Adam(self.D.parameters(), lr=self.lr*self.lr_d, betas=(0.9, 0.999))
    # loss function
    if self.gpu_mode:
        self.G.cuda()
        self.D.cuda()
        self.feature_extractor.cuda()
        self.L1_loss = nn.L1Loss().cuda()
        self.MSE_loss = nn.MSELoss().cuda()
        self.BCE_loss = nn.BCELoss().cuda()
    else:
        self.MSE_loss = nn.MSELoss()
        self.BCE_loss = nn.BCELoss()
        self.L1_loss = nn.L1Loss()
    print('---------- Networks architecture -------------')
    network_utils.print_network(self.G)
    network_utils.print_network(self.D)
    print('----------------------------------------------')

    ################# Pre-train generator #################
    # Load pre-trained parameters of generator; only pre-train when no
    # checkpoint could be loaded.
    if not self.load_model(is_pretrain=True):
        # Pre-training generator with content (L1) loss only — stabilizes
        # the subsequent adversarial phase.
        print('Pre-training is started.')
        self.G.train()
        for epoch in range(self.epoch_pretrain):
            for iter, (lr, hr, _) in enumerate(train_data_loader):
                # input data (low resolution image); x_ = target HR, y_ = LR input
                if self.num_channels == 1:
                    x_ = hr
                    y_ = lr
                    #x_ = network_utils.norm(hr.repeat(1,3,1,1), vgg=True)
                    #x_ = torch.mean(x_,1,True)
                    #y_ = network_utils.norm(lr.repeat(1,3,1,1), vgg=True)
                    #y_ = torch.mean(y_,1, True)
                else:
                    x_ = network_utils.norm(hr, vgg=True)
                    y_ = network_utils.norm(lr, vgg=True)
                if self.gpu_mode:
                    x_ = x_.cuda()
                    y_ = y_.cuda()

                # Train generator
                self.G_optimizer.zero_grad()
                recon_image = self.G(y_)

                # Content losses
                content_loss = self.L1_loss(recon_image, x_)

                # Back propagation
                G_loss_pretrain = content_loss
                G_loss_pretrain.backward()
                self.G_optimizer.step()

                # log
                print("Epoch: [%2d] [%4d/%4d] G_loss_pretrain: %.8f" % ((epoch + 1), (iter + 1), len(train_data_loader), G_loss_pretrain.item()))
        print('Pre-training is finished.')

        # Save pre-trained parameters of generator
        self.save_model(is_pretrain=True)

    ################# Adversarial train #################
    print('Training is started.')

    # Avg. losses
    G_avg_loss = []
    D_avg_loss = []
    step = 0

    # Fixed test sample (index 20) reused every epoch for visual progress;
    # unsqueeze adds the batch dimension.
    test_lr, test_hr, test_bc = test_data_loader.dataset.__getitem__(20)
    test_lr = test_lr.unsqueeze(0)
    test_hr = test_hr.unsqueeze(0)
    test_bc = test_bc.unsqueeze(0)

    self.G.train()
    self.D.train()
    for epoch in range(self.num_epochs):
        self.G.train()
        self.D.train()
        if epoch==0:
            start_time=time.time()

        # learning rate is decayed by a factor of 2 every 40 epoch
        if (epoch + 1) % 40 == 0:
            for param_group in self.G_optimizer.param_groups:
                param_group["lr"] /= 2.0
            print("Learning rate decay for G: lr={}".format(self.G_optimizer.param_groups[0]["lr"]))
            for param_group in self.D_optimizer.param_groups:
                param_group["lr"] /= 2.0
            print("Learning rate decay for D: lr={}".format(self.D_optimizer.param_groups[0]["lr"]))

        G_epoch_loss = 0
        D_epoch_loss = 0
        for iter, (lr, hr, _) in enumerate(train_data_loader):
            # input data (low resolution image)
            mini_batch = lr.size()[0]

            if self.num_channels == 1:
                x_ = hr
                y_ = lr
            else:
                x_ = network_utils.norm(hr, vgg=True)
                y_ = network_utils.norm(lr, vgg=True)

            if self.gpu_mode:
                x_ = x_.cuda()
                y_ = y_.cuda()
                # labels
                real_label = torch.ones(mini_batch).cuda()
                fake_label = torch.zeros(mini_batch).cuda()
            else:
                # labels
                real_label = torch.ones(mini_batch)
                fake_label = torch.zeros(mini_batch)

            # Reset gradient
            self.D_optimizer.zero_grad()

            # Train discriminator with real data; real labels are smoothed
            # (1 - smooth_factor) per one-sided label smoothing.
            D_real_decision = self.D(x_)
            D_real_loss = self.BCE_loss(D_real_decision.squeeze(),real_label*(1.0-smooth_factor))

            # Train discriminator with fake data
            recon_image = self.G(y_)
            D_fake_decision = self.D(recon_image)
            D_fake_loss = self.BCE_loss(D_fake_decision.squeeze(), fake_label)

            D_loss = (D_real_loss + D_fake_loss)*gan_factor

            # Back propagation
            D_loss.backward()
            self.D_optimizer.step()

            # Reset gradient
            self.G_optimizer.zero_grad()

            # Train generator — G is re-run so the G step uses a fresh graph
            # (the previous graph was consumed by D's backward).
            recon_image = self.G(y_)
            D_fake_decision = self.D(recon_image)

            # Adversarial loss: G wants D to output "real" on fakes.
            GAN_loss = self.BCE_loss(D_fake_decision.squeeze(), real_label)

            # Content losses — NOTE(review): named mse_loss but actually L1.
            mse_loss = self.L1_loss(recon_image, x_)

            # Perceptual (VGG feature) loss; single-channel inputs are
            # tiled to 3 channels for VGG. The .cpu()/.cuda() round-trip
            # preserves the autograd graph.
            if self.num_channels == 1:
                x_VGG=hr.repeat(1,3,1,1).cpu()
                x_VGG = network_utils.norm(x_VGG, vgg=True)
                recon_VGG=recon_image.repeat(1,3,1,1).cpu()
                recon_VGG = network_utils.norm(recon_VGG, vgg=True)
            else:
                x_VGG = network_utils.norm(hr.cpu(), vgg=True)
                recon_VGG = network_utils.norm(recon_image.cpu(), vgg=True)
            if self.gpu_mode:
                x_VGG=x_VGG.cuda()
                recon_VGG=recon_VGG.cuda()
            real_feature = self.feature_extractor(x_VGG)
            fake_feature = self.feature_extractor(recon_VGG)
            vgg_loss = self.L1_loss(fake_feature, real_feature.detach())

            # Back propagation of the weighted composite generator loss.
            mse_loss=mse_factor*mse_loss
            vgg_loss=vgg_factor*vgg_loss
            GAN_loss=gan_factor*GAN_loss
            G_loss = mse_loss + vgg_loss + GAN_loss
            G_loss.backward()
            self.G_optimizer.step()

            # log
            G_epoch_loss += G_loss.item()
            D_epoch_loss += D_loss.item()
            #print("Epoch: [%2d] [%4d/%4d] G_loss: %.8f, D_loss: %.8f"
            #      % ((epoch + 1), (iter + 1), len(train_data_loader), G_loss.item(), D_loss.item()))
            print("Epoch: [%2d] [%4d/%4d] G_loss: %.8f, mse: %.4f,vgg: %.4f, gan: %.4f,D_loss: %.8f" % ((epoch + 1), (iter + 1), len(train_data_loader), G_loss.item(), mse_loss.item(),vgg_loss.item(),GAN_loss.item(),D_loss.item()))
            step += 1

        # avg. loss per epoch
        G_avg_loss.append(G_epoch_loss / len(train_data_loader))
        D_avg_loss.append(D_epoch_loss / len(train_data_loader))

        # prediction on the fixed test image for visual progress tracking
        if self.num_channels == 1:
            y_ = test_lr
            #y_ = network_utils.norm(test_lr.repeat(1,3,1,1), vgg=True)
            #y_ = torch.mean(y_,1,True)
        else:
            y_ = network_utils.norm(test_lr, vgg=True)
        if self.gpu_mode:
            y_ = y_.cuda()
        recon_img = self.G(y_)
        if self.num_channels == 1:
            sr_img=recon_img.cpu()
            #sr_img=network_utils.denorm(recon_img.repeat(1,3,1,1).cpu(),vgg=True)
            #sr_img=torch.mean(sr_img,1,True)
        else:
            sr_img = network_utils.denorm(recon_img.cpu(), vgg=True)
        sr_img=sr_img[0]

        # save result image
        save_dir = os.path.join(self.save_dir, 'train_result')
        network_utils.save_img(sr_img, epoch + 1, save_dir=save_dir, is_training=True)
        if epoch==0:
            print('time for 1 epoch is :%.2f'%(time.time()-start_time))
        print('Result image at epoch %d is saved.' % (epoch + 1))

        # Save trained parameters of model
        if (epoch + 1) % self.save_epochs == 0:
            self.save_model(epoch + 1)

    # calculate psnrs (on the last epoch's sr_img vs. the fixed test sample)
    if self.num_channels == 1:
        gt_img = test_hr[0][0].unsqueeze(0)
        lr_img = test_lr[0][0].unsqueeze(0)
        bc_img = test_bc[0][0].unsqueeze(0)
    else:
        gt_img = test_hr[0]
        lr_img = test_lr[0]
        bc_img = test_bc[0]
    # metric selection: structural content ('sc'), SSIM, or PSNR (default)
    if self.metric=='sc':
        bc_metric = network_utils.SC(bc_img, gt_img)
        recon_metric = network_utils.SC(sr_img, gt_img)
    elif self.metric=='ssim':
        bc_metric = network_utils.SSIM(bc_img, gt_img)
        recon_metric = network_utils.SSIM(sr_img, gt_img)
    else:
        bc_metric = network_utils.PSNR(bc_img, gt_img)
        recon_metric = network_utils.PSNR(sr_img, gt_img)

    # plot result images: [ground truth, low-res, bicubic, super-resolved]
    result_imgs = [gt_img, lr_img, bc_img, sr_img]
    metrics = [None, None, bc_metric, recon_metric]
    network_utils.plot_test_result(result_imgs, metrics, self.num_epochs, save_dir=save_dir, is_training=True, index=self.metric)
    print('Training result image is saved.')

    # Plot avg. loss
    network_utils.plot_loss([G_avg_loss, D_avg_loss], self.num_epochs, save_dir=self.save_dir)
    print("Training is finished.")

    # Save final trained parameters of model
    self.save_model(epoch=None)
def test(self, test_dataset):
    """Evaluate the trained generator on a test dataset.

    Loads the saved generator, super-resolves every image in the dataset,
    saves each result plus a comparison plot (ground truth / LR / bicubic /
    SR annotated with the chosen metric), and writes the mean and standard
    deviation of the per-image metric to ``results.txt`` in the output folder.

    Fixes vs. original: inference wrapped in ``torch.no_grad()`` (no autograd
    graph accumulates across batches), results path built with
    ``os.path.join`` instead of a Windows-only ``'\\'`` concatenation, and the
    misspelled completion message corrected.

    Parameters
    ----------
    test_dataset : str
        Name of the dataset to evaluate; also used as the output sub-folder.
    """
    # Build the generator; kernel fixed at 3 (matches predict2d()).
    self.G = Generator(num_channels=self.num_channels, base_filter=self.filter,
                       num_residuals=self.num_residuals,
                       scale_factor=self.scale_factor, kernel=3)
    if self.gpu_mode:
        self.G.cuda()

    # Restore trained weights.
    self.load_model()

    # load dataset
    test_data_loader = self.load_ct_dataset(dataset=[test_dataset],
                                            is_train=False,
                                            is_registered=self.registered,
                                            grayscale_corrected=self.grayscale_corrected)

    # Test
    print('Test is started.')
    img_num = 0
    total_img_num = len(test_data_loader)
    self.G.eval()
    metric = []
    save_dir = os.path.join(self.save_dir, test_dataset)
    for lr, hr, bc in test_data_loader:
        # input data (low resolution image); single-channel mode keeps only
        # the first channel.
        if self.num_channels == 1:
            y_ = lr[:, 0].unsqueeze(1)
        else:
            y_ = network_utils.norm(lr, vgg=True)
        if self.gpu_mode:
            y_ = y_.cuda()

        # prediction — no_grad: evaluation only, no graph needed.
        with torch.no_grad():
            recon_imgs = self.G(y_)
        if self.num_channels == 1:
            recon_imgs = recon_imgs.cpu()
        else:
            recon_imgs = network_utils.denorm(recon_imgs.cpu(), vgg=True)

        for i, recon_img in enumerate(recon_imgs):
            img_num += 1
            sr_img = recon_img

            # save result image
            network_utils.save_img(sr_img, img_num, save_dir=save_dir)

            # calculate the chosen quality metric against ground truth
            if self.num_channels == 1:
                gt_img = hr[i][0].unsqueeze(0)
                lr_img = lr[i][0].unsqueeze(0)
                bc_img = bc[i][0].unsqueeze(0)
            else:
                gt_img = hr[i]
                lr_img = lr[i]
                bc_img = bc[i]
            # 'sc' = structural content, 'ssim' = SSIM, default = PSNR;
            # bicubic is scored too, as the baseline.
            if self.metric == 'sc':
                bc_metric = network_utils.SC(bc_img, gt_img)
                recon_metric = network_utils.SC(sr_img, gt_img)
            elif self.metric == 'ssim':
                bc_metric = network_utils.SSIM(bc_img, gt_img)
                recon_metric = network_utils.SSIM(sr_img, gt_img)
            else:
                bc_metric = network_utils.PSNR(bc_img, gt_img)
                recon_metric = network_utils.PSNR(sr_img, gt_img)
            metric.append(recon_metric)

            # plot result images: [ground truth, LR, bicubic, SR]
            result_imgs = [gt_img, lr_img, bc_img, sr_img]
            metrics = [None, None, bc_metric, recon_metric]
            network_utils.plot_test_result(result_imgs, metrics, img_num,
                                           save_dir=save_dir, index=self.metric)

            print('Test DB: %s, Saving result images...[%d/%d]'
                  % (test_dataset, img_num, total_img_num))
    print('Test is finished.')

    # Summarize the per-image metric; portable path join (original used a
    # Windows-only '\\' string concatenation).
    mean_metric = np.mean(metric)
    std_metric = np.std(metric)
    save_fn = os.path.join(save_dir, 'results.txt')
    with open(save_fn, 'w+') as file:
        file.write('average metric value is: %.3f\n' % mean_metric)
        file.write('std of metric value is: %.3f' % std_metric)
def train(self):
    """Train the (non-adversarial) super-resolution CNN.

    Optimizes ``self.model`` (a plain ``Net``) with a composite loss of
    pixel-wise L1 plus a VGG19 perceptual feature L1 weighted by
    ``self.vgg_factor``, for ``self.num_epochs`` epochs with the learning
    rate halved every 40 epochs. After each epoch one fixed test image is
    super-resolved and saved; at the end, a metric comparison plot, the loss
    curve, and the final weights are saved.

    NOTE(review): unlike the GAN variant of train(), this never calls
    .cuda()/.cpu() on the per-epoch prediction input consistently with
    gpu_mode off — behavior mirrors the original exactly.
    """
    vgg_factor = self.vgg_factor
    train_dataset = []  # NOTE(review): unused local — dead assignment

    # networks, number of filters and resiudal blocks
    self.model = Net(num_channels=self.num_channels, base_filter=self.filter,
                     num_residuals=self.num_residuals,
                     scale_factor=self.scale_factor, kernel=self.kernel)

    # weigh initialization
    self.model.weight_init()

    # For the content loss: frozen pre-trained VGG19 truncated at
    # self.vgg_layer, used as a perceptual feature extractor.
    self.feature_extractor = FeatureExtractor(
        models.vgg19(pretrained=True), feature_layer=self.vgg_layer)

    # optimizer
    self.optimizer = optim.Adam(self.model.parameters(), lr=self.lr,
                                betas=(0.9, 0.999), eps=1e-8)

    # loss function
    if self.gpu_mode:
        print('in gpu mode')
        self.model.cuda()
        self.feature_extractor.cuda()
        self.L1_loss = nn.L1Loss().cuda()
    else:
        print('in cpu mode')
        self.L1_loss = nn.L1Loss()

    print('---------- Networks architecture -------------')
    network_utils.print_network(self.model)
    print('----------------------------------------------')

    # load dataset
    train_data_loader = self.load_ct_dataset(
        dataset=self.train_dataset, is_train=True,
        is_registered=self.registered,
        grayscale_corrected=self.grayscale_corrected)
    test_data_loader = self.load_ct_dataset(
        dataset=self.test_dataset, is_train=False,
        is_registered=self.registered,
        grayscale_corrected=self.grayscale_corrected)

    # set the logger
    #log_dir = os.path.join(self.save_dir, 'logs')
    #if not os.path.exists(log_dir):
    #    os.makedirs(log_dir)
    #logger = Logger(log_dir)

    ################# Train #################
    print('Training is started.')
    avg_loss = []
    step = 0

    # Fixed test sample (index 2) reused every epoch for visual progress;
    # unsqueeze adds the batch dimension.
    test_lr, test_hr, test_bc = test_data_loader.dataset.__getitem__(2)
    test_lr = test_lr.unsqueeze(0)
    test_hr = test_hr.unsqueeze(0)
    test_bc = test_bc.unsqueeze(0)

    self.model.train()
    for epoch in range(self.num_epochs):
        if epoch == 0:
            start_time = time.time()

        # learning rate is decayed by a factor of 2 every 40 epochs
        if (epoch + 1) % 40 == 0:
            for param_group in self.optimizer.param_groups:
                param_group['lr'] /= 2.0
            print('Learning rate decay: lr={}'.format(
                self.optimizer.param_groups[0]['lr']))

        epoch_loss = 0
        for iter, (lr, hr, _) in enumerate(train_data_loader):
            # input data (low resolution image); single-channel mode keeps
            # only the first channel. x_ = target HR, y_ = LR input.
            if self.num_channels == 1:
                x_ = hr[:, 0].unsqueeze(1)
                y_ = lr[:, 0].unsqueeze(1)
            else:
                x_ = hr
                y_ = lr

            if self.gpu_mode:
                x_ = x_.cuda()
                y_ = y_.cuda()

            # update network
            self.optimizer.zero_grad()
            recon_image = self.model(y_)

            # Perceptual (VGG feature) loss; single-channel inputs are tiled
            # to 3 channels for VGG. The .cpu()/.cuda() round-trip preserves
            # the autograd graph.
            if self.num_channels == 1:
                x_VGG = hr.repeat(1, 3, 1, 1).cpu()
                x_VGG = network_utils.norm(x_VGG, vgg=True)
                recon_VGG = recon_image.repeat(1, 3, 1, 1).cpu()
                recon_VGG = network_utils.norm(recon_VGG, vgg=True)
            else:
                x_VGG = network_utils.norm(hr.cpu(), vgg=True)
                recon_VGG = network_utils.norm(recon_image.cpu(), vgg=True)
            if self.gpu_mode:
                x_VGG = x_VGG.cuda()
                recon_VGG = recon_VGG.cuda()
            real_feature = self.feature_extractor(x_VGG)
            fake_feature = self.feature_extractor(recon_VGG)
            vgg_loss = self.L1_loss(fake_feature, real_feature.detach())
            vgg_loss = vgg_loss * vgg_factor

            # Composite loss: pixel L1 + weighted perceptual L1.
            loss = self.L1_loss(recon_image, x_) + vgg_loss
            loss.backward()
            self.optimizer.step()

            # log
            epoch_loss += loss.item()
            #print('Epoch: [%2d] [%4d/%4d] loss: %.8f' % ((epoch + 1), (iter + 1), len(train_data_loader), loss.item()))
            print('Epoch: [%2d] [%4d/%4d] loss: %.8f vggloss: %.8f' % ((epoch + 1), (iter + 1), len(train_data_loader), loss.item(), vgg_loss.item()))

            # tensorboard logging
            #logger.scalar_summary('loss', loss.data[0], step + 1)
            #step += 1

        # avg. loss per epoch
        avg_loss.append(epoch_loss / len(train_data_loader))

        # prediction on the fixed test image for visual progress tracking
        if self.num_channels == 1:
            y_ = test_lr[:, 0].unsqueeze(1)
        else:
            y_ = test_lr
        if self.gpu_mode:
            y_ = y_.cuda()

        recon_img = self.model(y_)
        sr_img = recon_img[0].cpu()

        # save result image
        save_dir = os.path.join(self.save_dir, 'train_result')
        network_utils.save_img(sr_img, epoch + 1, save_dir=save_dir, is_training=True)
        if epoch == 0:
            print('time for 1 epoch is :%.2f' % (time.time() - start_time))
        print('Result image at epoch %d is saved.' % (epoch + 1))

        # Save trained parameters of model
        if (epoch + 1) % self.save_epochs == 0:
            self.save_model(epoch + 1)

    # calculate psnrs (on the last epoch's sr_img vs. the fixed test sample)
    if self.num_channels == 1:
        gt_img = test_hr[0][0].unsqueeze(0)
        lr_img = test_lr[0][0].unsqueeze(0)
        bc_img = test_bc[0][0].unsqueeze(0)
    else:
        gt_img = test_hr[0]
        lr_img = test_lr[0]
        bc_img = test_bc[0]
    # metric selection: structural content ('sc'), SSIM, or PSNR (default)
    if self.metric == 'sc':
        bc_metric = network_utils.SC(bc_img, gt_img)
        recon_metric = network_utils.SC(sr_img, gt_img)
    elif self.metric == 'ssim':
        bc_metric = network_utils.SSIM(bc_img, gt_img)
        recon_metric = network_utils.SSIM(sr_img, gt_img)
    else:
        bc_metric = network_utils.PSNR(bc_img, gt_img)
        recon_metric = network_utils.PSNR(sr_img, gt_img)

    # plot result images: [ground truth, low-res, bicubic, super-resolved]
    result_imgs = [gt_img, lr_img, bc_img, sr_img]
    metrics = [None, None, bc_metric, recon_metric]
    network_utils.plot_test_result(result_imgs, metrics, self.num_epochs,
                                   save_dir=save_dir, is_training=True,
                                   index=self.metric)
    print('Training result image is saved.')

    # Plot avg. loss
    network_utils.plot_loss([avg_loss], self.num_epochs, save_dir=save_dir)
    print('Training is finished.')

    # Save final trained parameters of model
    self.save_model(epoch=None)