Exemple #1
0
    def predict2d(self,test_dataset,output='pic'):
        '''Super-resolve every 2D raw image of a dataset folder with the trained generator.

        A fresh ``Generator`` is built, its weights are restored through
        ``self.load_model()``, and it is run on each file of
        ``join(self.data_dir, test_dataset)`` that ``utils.is_raw_file``
        accepts, in sorted filename order. Every image is min/max rescaled to
        [0, 1] before being fed to the network, and the network output is
        clamped to [0, 1].

        Args:
            test_dataset: name of the sub-folder of ``self.data_dir`` that
                holds the input images; also used as the output sub-folder.
            output: ``'raw'`` rescales the result back to the input's original
                (min, max) range and writes it with ``utils.save_as_raw``
                under ``<self.save_dir>/SR-2D-raw/<test_dataset>``. Any other
                value saves through ``network_utils.save_img`` under
                ``<self.save_dir>/SR-2D-png/<test_dataset>``.

        Raises:
            Exception: if ``self.num_channels`` is not 1 (raised while
                processing the first image).
        '''
        # networks
        # NOTE(review): kernel is fixed to 3 here; presumably it has to match
        # the checkpoint restored by load_model() -- confirm.
        self.G = Generator(num_channels=self.num_channels, base_filter=self.filter, num_residuals=self.num_residuals,scale_factor=self.scale_factor,kernel=3)

        if self.gpu_mode:
            print('gpu mode')
            self.G.cuda()

        # load model
        self.load_model()

        # Inference only: the mode does not change between images, so set it once.
        self.G.eval()

        image_dir = join(self.data_dir, test_dataset)
        image_filenames = [join(image_dir, x)
                           for x in sorted(listdir(image_dir))
                           if utils.is_raw_file(x)]

        # The tensor conversion is identical for every image.
        lr_transform = Compose([ToTensor()])

        for img_num, img_fn in enumerate(image_filenames):
            print(img_fn)

            img = utils.read_and_reshape(img_fn).astype(float)
            # Remember the original intensity range so 'raw' output can be
            # mapped back to it after inference.
            minvalue = img.min()
            maxvalue = img.max()
            img = transforms_3d.rescale(img,original_scale=(minvalue,maxvalue),new_scale=(0,1))
            lr_img = lr_transform(Image.fromarray(img))

            if self.num_channels == 1:
                # add the batch dimension
                y_ = lr_img.unsqueeze(0)
            else:
                raise Exception("only accept 2d raw image file " )

            if self.gpu_mode:
                y_ = y_.cuda()

            # prediction; no autograd graph is needed for inference
            with torch.no_grad():
                recon_img = self.G(y_)

            recon_img = recon_img.cpu()[0].clamp(0, 1).numpy()
            if output=='raw':
                recon_img = transforms_3d.rescale(recon_img,original_scale=(0,1),new_scale=(minvalue,maxvalue)).astype(int)
                # basename() is separator-aware on every platform; splitting on
                # a literal backslash left the full path intact on POSIX, which
                # made the join below ignore save_dir for absolute inputs.
                img_filename = os.path.basename(img_fn)
                save_dir = os.path.join(self.save_dir, 'SR-2D-raw',test_dataset)
                if not os.path.exists(save_dir):
                    os.makedirs(save_dir)
                utils.save_as_raw(recon_img,os.path.join(save_dir,img_filename),dtype='uint16',prefix='SR')
            else:
                recon_img = torch.from_numpy(recon_img)
                save_dir = os.path.join(self.save_dir, 'SR-2D-png', test_dataset)

                network_utils.save_img(recon_img, img_num, save_dir=save_dir)
            torch.cuda.empty_cache()
        print('Single test result image is saved.')
Exemple #2
0
    def train(self):
        """Train the generator/discriminator pair (optional pre-training, then adversarial).

        Steps performed:

        1. Build the train/test loaders with ``self.load_ct_dataset``, the
           ``Generator`` and ``Discriminator``, a VGG19-based
           ``FeatureExtractor``, one Adam optimizer per network, and the
           L1 / MSE / BCE criteria.
        2. If ``self.load_model(is_pretrain=True)`` returns a falsy value,
           pre-train the generator alone with an L1 loss for
           ``self.epoch_pretrain`` epochs and store it with
           ``self.save_model(is_pretrain=True)``.
        3. Run ``self.num_epochs`` adversarial epochs. Both learning rates are
           halved whenever ``(epoch + 1) % 40 == 0``. After each epoch one
           fixed test image is super-resolved and saved under
           ``<self.save_dir>/train_result``; a checkpoint is written every
           ``self.save_epochs`` epochs.
        4. Compute the metric selected by ``self.metric`` (``'sc'``,
           ``'ssim'``, anything else means PSNR) for the last preview image,
           plot it and the loss curves, and save the final model.
        """
        # Weights of the adversarial, pixel-wise and VGG terms of the generator
        # loss, and the label-smoothing amount applied to the discriminator's
        # "real" targets.
        gan_factor=0.1
        mse_factor=1
        vgg_factor=self.vgg_factor
        smooth_factor=0.1

        # load dataset
        train_data_loader = self.load_ct_dataset(dataset=self.train_dataset, is_train=True,
                                                     is_registered=self.registered,
                                                     grayscale_corrected=self.grayscale_corrected)
        test_data_loader = self.load_ct_dataset(dataset=self.test_dataset, is_train=False,
                                                    is_registered=self.registered,
                                                    grayscale_corrected=self.grayscale_corrected)
        num_batches = len(train_data_loader)

        # networks
        self.G = Generator(num_channels=self.num_channels, base_filter=self.filter, num_residuals=self.num_residuals,scale_factor=self.scale_factor,kernel=self.kernel)
        self.D = Discriminator(num_channels=self.num_channels, base_filter=self.filter, image_size=self.crop_size)

        # weight initialization
        self.G.weight_init()
        self.D.weight_init()

        # For the content loss. The extractor is not handed to any optimizer,
        # so freeze it: this avoids computing (and endlessly accumulating)
        # gradients for its weights. Gradients still flow to its *input*.
        self.feature_extractor = FeatureExtractor(models.vgg19(pretrained=True),feature_layer=self.vgg_layer)
        self.feature_extractor.eval()
        for param in self.feature_extractor.parameters():
            param.requires_grad = False

        # Move the networks to the GPU *before* building the optimizers, as
        # the torch.optim documentation recommends.
        if self.gpu_mode:
            self.G.cuda()
            self.D.cuda()
            self.feature_extractor.cuda()

        # optimizer
        self.G_optimizer = optim.Adam(self.G.parameters(), lr=self.lr, betas=(0.9, 0.999))
        self.D_optimizer = optim.Adam(self.D.parameters(), lr=self.lr*self.lr_d, betas=(0.9, 0.999))

        # loss functions (they hold no parameters or buffers, so they need no
        # device transfer)
        self.L1_loss = nn.L1Loss()
        self.MSE_loss = nn.MSELoss()
        self.BCE_loss = nn.BCELoss()

        print('---------- Networks architecture -------------')
        network_utils.print_network(self.G)
        network_utils.print_network(self.D)
        print('----------------------------------------------')


        ################# Pre-train generator #################

        # Load pre-trained parameters of generator
        if not self.load_model(is_pretrain=True):
            # Pre-train the generator for self.epoch_pretrain epochs
            print('Pre-training is started.')
            self.G.train()
            for epoch in range(self.epoch_pretrain):
                for batch_idx, (lr, hr, _) in enumerate(train_data_loader):
                    # x_ is the high-resolution target, y_ the low-resolution input
                    if self.num_channels == 1:
                        x_ = hr
                        y_ = lr
                    else:
                        x_ = network_utils.norm(hr, vgg=True)
                        y_ = network_utils.norm(lr, vgg=True)

                    if self.gpu_mode:
                        x_ = x_.cuda()
                        y_ = y_.cuda()

                    # Train generator
                    self.G_optimizer.zero_grad()
                    recon_image = self.G(y_)

                    # Content loss
                    G_loss_pretrain = self.L1_loss(recon_image, x_)

                    # Back propagation
                    G_loss_pretrain.backward()
                    self.G_optimizer.step()

                    # log
                    print("Epoch: [%2d] [%4d/%4d] G_loss_pretrain: %.8f"
                          % ((epoch + 1), (batch_idx + 1), num_batches, G_loss_pretrain.item()))

            print('Pre-training is finished.')

            # Save pre-trained parameters of generator
            self.save_model(is_pretrain=True)

        ################# Adversarial train #################
        print('Training is started.')
        # Avg. losses
        G_avg_loss = []
        D_avg_loss = []

        # Fixed preview image. Index 20 is kept when available; smaller test
        # sets fall back to their last item instead of raising IndexError.
        test_set = test_data_loader.dataset
        test_lr, test_hr, test_bc = test_set[min(20, len(test_set) - 1)]
        test_lr = test_lr.unsqueeze(0)
        test_hr = test_hr.unsqueeze(0)
        test_bc = test_bc.unsqueeze(0)

        save_dir = os.path.join(self.save_dir, 'train_result')

        self.G.train()
        self.D.train()
        for epoch in range(self.num_epochs):
            self.G.train()
            self.D.train()
            if epoch==0:
                start_time=time.time()
            # learning rate is decayed by a factor of 2 every 40 epoch
            if (epoch + 1) % 40 == 0:
                for param_group in self.G_optimizer.param_groups:
                    param_group["lr"] /= 2.0
                print("Learning rate decay for G: lr={}".format(self.G_optimizer.param_groups[0]["lr"]))
                for param_group in self.D_optimizer.param_groups:
                    param_group["lr"] /= 2.0
                print("Learning rate decay for D: lr={}".format(self.D_optimizer.param_groups[0]["lr"]))

            G_epoch_loss = 0
            D_epoch_loss = 0
            for batch_idx, (lr, hr, _) in enumerate(train_data_loader):
                mini_batch = lr.size()[0]

                # x_ is the high-resolution target, y_ the low-resolution input
                if self.num_channels == 1:
                    x_ = hr
                    y_ = lr
                else:
                    x_ = network_utils.norm(hr, vgg=True)
                    y_ = network_utils.norm(lr, vgg=True)

                if self.gpu_mode:
                    x_ = x_.cuda()
                    y_ = y_.cuda()

                # labels, created directly on the device of the batch
                real_label = torch.ones(mini_batch, device=x_.device)
                fake_label = torch.zeros(mini_batch, device=x_.device)

                # ---- discriminator step ----
                self.D_optimizer.zero_grad()

                # reshape(-1) keeps a 1-D (batch,) shape even for a batch of
                # one, where squeeze() would produce a 0-dim tensor that no
                # longer matches the label shape.
                D_real_decision = self.D(x_).reshape(-1)
                D_real_loss = self.BCE_loss(D_real_decision, real_label*(1.0-smooth_factor))

                # The generator is not updated in this step, so its forward
                # pass needs no graph; any gradients it received here were
                # zeroed before the generator step anyway.
                with torch.no_grad():
                    fake_image = self.G(y_)
                D_fake_decision = self.D(fake_image).reshape(-1)
                D_fake_loss = self.BCE_loss(D_fake_decision, fake_label)

                D_loss = (D_real_loss + D_fake_loss)*gan_factor

                # Back propagation
                D_loss.backward()
                self.D_optimizer.step()

                # ---- generator step ----
                self.G_optimizer.zero_grad()

                recon_image = self.G(y_)
                D_fake_decision = self.D(recon_image).reshape(-1)

                # Adversarial loss: the generator wants its output judged real
                GAN_loss = self.BCE_loss(D_fake_decision, real_label)

                # Pixel-wise content loss (an L1 criterion, despite the
                # "mse" name used in the variable and the log line)
                mse_loss = self.L1_loss(recon_image, x_)

                # VGG feature loss.
                # NOTE(review): the .cpu() calls presumably exist because
                # network_utils.norm expects CPU tensors -- confirm; in GPU
                # mode they cost a device round-trip per batch.
                if self.num_channels == 1:
                    # replicate the single channel to the 3 channels VGG takes
                    x_VGG = network_utils.norm(hr.repeat(1,3,1,1).cpu(), vgg=True)
                    recon_VGG = network_utils.norm(recon_image.repeat(1,3,1,1).cpu(), vgg=True)
                else:
                    x_VGG = network_utils.norm(hr.cpu(), vgg=True)
                    recon_VGG = network_utils.norm(recon_image.cpu(), vgg=True)
                if self.gpu_mode:
                    x_VGG = x_VGG.cuda()
                    recon_VGG = recon_VGG.cuda()
                real_feature = self.feature_extractor(x_VGG)
                fake_feature = self.feature_extractor(recon_VGG)
                vgg_loss = self.L1_loss(fake_feature, real_feature.detach())

                # Back propagation
                mse_loss=mse_factor*mse_loss
                vgg_loss=vgg_factor*vgg_loss
                GAN_loss=gan_factor*GAN_loss
                G_loss = mse_loss +  vgg_loss + GAN_loss
                G_loss.backward()
                self.G_optimizer.step()

                # log
                G_epoch_loss += G_loss.item()
                D_epoch_loss += D_loss.item()

                print("Epoch: [%2d] [%4d/%4d] G_loss: %.8f, mse: %.4f,vgg: %.4f, gan: %.4f,D_loss: %.8f"
                      % ((epoch + 1), (batch_idx + 1), num_batches, G_loss.item(), mse_loss.item(),vgg_loss.item(),GAN_loss.item(),D_loss.item()))

            # avg. loss per epoch
            G_avg_loss.append(G_epoch_loss / num_batches)
            D_avg_loss.append(D_epoch_loss / num_batches)

            # prediction on the fixed preview image
            if self.num_channels == 1:
                y_ = test_lr
            else:
                y_ = network_utils.norm(test_lr, vgg=True)

            if self.gpu_mode:
                y_ = y_.cuda()

            # The preview is only saved/measured, never back-propagated.
            with torch.no_grad():
                recon_img = self.G(y_)
            if self.num_channels == 1:
                sr_img = recon_img.cpu()
            else:
                sr_img = network_utils.denorm(recon_img.cpu(), vgg=True)

            sr_img=sr_img[0]
            # save result image
            network_utils.save_img(sr_img, epoch + 1, save_dir=save_dir, is_training=True)
            if epoch==0:
                print('time for 1 epoch is :%.2f'%(time.time()-start_time))
            print('Result image at epoch %d is saved.' % (epoch + 1))

            # Save trained parameters of model
            if (epoch + 1) % self.save_epochs == 0:
                self.save_model(epoch + 1)

        # Evaluate the last preview image against the ground truth and the
        # third loader output (test_bc).
        if self.num_channels == 1:
            gt_img = test_hr[0][0].unsqueeze(0)
            lr_img = test_lr[0][0].unsqueeze(0)
            bc_img = test_bc[0][0].unsqueeze(0)
        else:
            gt_img = test_hr[0]
            lr_img = test_lr[0]
            bc_img = test_bc[0]

        if self.metric=='sc':
            metric_fn = network_utils.SC
        elif self.metric=='ssim':
            metric_fn = network_utils.SSIM
        else:
            metric_fn = network_utils.PSNR
        bc_metric = metric_fn(bc_img, gt_img)
        recon_metric = metric_fn(sr_img, gt_img)

        # plot result images
        result_imgs = [gt_img, lr_img, bc_img, sr_img]
        metrics = [None, None, bc_metric, recon_metric]
        network_utils.plot_test_result(result_imgs, metrics, self.num_epochs, save_dir=save_dir, is_training=True, index=self.metric)
        print('Training result image is saved.')


        # Plot avg. loss
        network_utils.plot_loss([G_avg_loss, D_avg_loss], self.num_epochs, save_dir=self.save_dir)
        print("Training is finished.")

        # Save final trained parameters of model
        self.save_model(epoch=None)
Exemple #3
0
    def test(self,test_dataset):
        """Evaluate the trained generator on one dataset and write images plus a summary.

        A fresh ``Generator`` is built and restored through
        ``self.load_model()``. Every image yielded by the loader returned from
        ``self.load_ct_dataset(dataset=[test_dataset], is_train=False, ...)``
        is super-resolved, saved with ``network_utils.save_img``, scored with
        the metric selected by ``self.metric`` (``'sc'``, ``'ssim'``, anything
        else means PSNR) and plotted with ``network_utils.plot_test_result``.
        The mean and standard deviation of the per-image metric are written to
        ``results.txt`` in ``<self.save_dir>/<test_dataset>``.

        Args:
            test_dataset: dataset name; passed to the loader and used as the
                output sub-folder of ``self.save_dir``.
        """
        # Local import: this method did not previously reference the name
        # ``torch`` directly, so it is brought into scope here.
        import torch

        # networks
        # NOTE(review): kernel is fixed to 3 here; presumably it has to match
        # the checkpoint restored by load_model() -- confirm.
        self.G = Generator(num_channels=self.num_channels, base_filter=self.filter, num_residuals=self.num_residuals,scale_factor=self.scale_factor,kernel=3)

        if self.gpu_mode:
            self.G.cuda()

        # load model
        self.load_model()

        # load dataset
        test_data_loader = self.load_ct_dataset(dataset=[test_dataset],
                                                    is_train=False,
                                                    is_registered=self.registered,
                                                    grayscale_corrected=self.grayscale_corrected)

        # Resolve the output folder up front so it is defined (and exists)
        # for the summary file even when the loader yields nothing.
        save_dir = os.path.join(self.save_dir, test_dataset)
        if not os.path.isdir(save_dir):
            os.makedirs(save_dir)

        # The metric choice does not change between images.
        if self.metric=='sc':
            metric_fn = network_utils.SC
        elif self.metric=='ssim':
            metric_fn = network_utils.SSIM
        else:
            metric_fn = network_utils.PSNR

        # Test
        print('Test is started.')
        img_num = 0
        # Count images, not batches: img_num below advances once per image.
        total_img_num = len(test_data_loader.dataset)
        self.G.eval()
        metric=[]
        # Pure evaluation: no autograd graph is needed.
        with torch.no_grad():
            for lr, hr, bc in test_data_loader:

                # input data (low resolution image)
                if self.num_channels == 1:
                    # keep only the first channel, as a (batch, 1, ...) tensor
                    y_ = lr[:, 0].unsqueeze(1)
                else:
                    y_ = network_utils.norm(lr, vgg=True)

                if self.gpu_mode:
                    y_ = y_.cuda()

                # prediction
                recon_imgs = self.G(y_)

                if self.num_channels == 1:
                    recon_imgs = recon_imgs.cpu()
                else:
                    recon_imgs = network_utils.denorm(recon_imgs.cpu(), vgg=True)

                for i, sr_img in enumerate(recon_imgs):
                    img_num += 1

                    # save result image
                    network_utils.save_img(sr_img, img_num, save_dir=save_dir)

                    # references for the metric / comparison plot
                    if self.num_channels == 1:
                        gt_img = hr[i][0].unsqueeze(0)
                        lr_img = lr[i][0].unsqueeze(0)
                        bc_img = bc[i][0].unsqueeze(0)
                    else:
                        gt_img = hr[i]
                        lr_img = lr[i]
                        bc_img = bc[i]

                    bc_metric = metric_fn(bc_img, gt_img)
                    recon_metric = metric_fn(sr_img, gt_img)

                    metric.append(recon_metric)
                    # plot result images
                    result_imgs = [gt_img, lr_img, bc_img, sr_img]
                    metrics = [None, None, bc_metric, recon_metric]
                    network_utils.plot_test_result(result_imgs, metrics, img_num, save_dir=save_dir, index=self.metric)

                    print('Test DB: %s, Saving result images...[%d/%d]' % (test_dataset, img_num, total_img_num))

        print('Test is finishied.')
        mean_metric=np.mean(metric)
        std_metric=np.std(metric)
        # os.path.join instead of a literal backslash, so the file lands
        # inside save_dir on every platform.
        save_fn = os.path.join(save_dir, 'results.txt')
        with open(save_fn,'w') as file:
            file.write('average metric value is: %.3f\n' %mean_metric)
            file.write('std of metric value is: %.3f' %std_metric)
Exemple #4
0
    def train(self):
        """Train ``self.model`` with an L1 pixel loss plus a weighted VGG feature loss.

        Steps performed:

        1. Build ``Net``, initialise its weights, create a VGG19-based
           ``FeatureExtractor``, an Adam optimizer and an L1 criterion.
        2. Build the train/test loaders with ``self.load_ct_dataset``.
        3. Run ``self.num_epochs`` epochs. The learning rate is halved
           whenever ``(epoch + 1) % 40 == 0``. After each epoch one fixed test
           image is super-resolved and saved under
           ``<self.save_dir>/train_result``; a checkpoint is written every
           ``self.save_epochs`` epochs.
        4. Compute the metric selected by ``self.metric`` (``'sc'``,
           ``'ssim'``, anything else means PSNR) for the last preview image,
           plot it and the loss curve, and save the final model.
        """
        # Local import: this method did not previously reference the name
        # ``torch`` directly, so it is brought into scope here.
        import torch

        vgg_factor = self.vgg_factor

        # networks, number of filters and residual blocks
        self.model = Net(num_channels=self.num_channels,
                         base_filter=self.filter,
                         num_residuals=self.num_residuals,
                         scale_factor=self.scale_factor,
                         kernel=self.kernel)

        # weight initialization
        self.model.weight_init()

        # For the content loss. The extractor is not handed to the optimizer,
        # so freeze it: this avoids computing (and endlessly accumulating)
        # gradients for its weights. Gradients still flow to its *input*.
        self.feature_extractor = FeatureExtractor(
            models.vgg19(pretrained=True), feature_layer=self.vgg_layer)
        self.feature_extractor.eval()
        for param in self.feature_extractor.parameters():
            param.requires_grad = False

        # loss function (holds no parameters or buffers, so it needs no
        # device transfer)
        self.L1_loss = nn.L1Loss()

        # Move the networks to the GPU *before* building the optimizer, as
        # the torch.optim documentation recommends.
        if self.gpu_mode:
            print('in gpu mode')
            self.model.cuda()
            self.feature_extractor.cuda()
        else:
            print('in cpu mode')

        # optimizer
        self.optimizer = optim.Adam(self.model.parameters(),
                                    lr=self.lr,
                                    betas=(0.9, 0.999),
                                    eps=1e-8)

        print('---------- Networks architecture -------------')
        network_utils.print_network(self.model)
        print('----------------------------------------------')

        # load dataset
        train_data_loader = self.load_ct_dataset(
            dataset=self.train_dataset,
            is_train=True,
            is_registered=self.registered,
            grayscale_corrected=self.grayscale_corrected)
        test_data_loader = self.load_ct_dataset(
            dataset=self.test_dataset,
            is_train=False,
            is_registered=self.registered,
            grayscale_corrected=self.grayscale_corrected)
        num_batches = len(train_data_loader)

        ################# Train #################
        print('Training is started.')
        avg_loss = []

        # Fixed preview image. Index 2 is kept when available; smaller test
        # sets fall back to their last item instead of raising IndexError.
        test_set = test_data_loader.dataset
        test_lr, test_hr, test_bc = test_set[min(2, len(test_set) - 1)]
        test_lr = test_lr.unsqueeze(0)
        test_hr = test_hr.unsqueeze(0)
        test_bc = test_bc.unsqueeze(0)

        save_dir = os.path.join(self.save_dir, 'train_result')

        self.model.train()
        for epoch in range(self.num_epochs):
            if epoch == 0:
                start_time = time.time()
            # learning rate is decayed by a factor of 2 every 40 epochs
            if (epoch + 1) % 40 == 0:
                for param_group in self.optimizer.param_groups:
                    param_group['lr'] /= 2.0
                print('Learning rate decay: lr={}'.format(
                    self.optimizer.param_groups[0]['lr']))

            epoch_loss = 0
            for batch_idx, (lr, hr, _) in enumerate(train_data_loader):
                if self.num_channels == 1:
                    # Keep only the first channel. Doing it on hr/lr themselves
                    # means the VGG target below is built from the same
                    # single-channel tensor as the pixel-loss target.
                    hr = hr[:, 0].unsqueeze(1)
                    lr = lr[:, 0].unsqueeze(1)

                # x_ is the high-resolution target, y_ the low-resolution input
                x_ = hr
                y_ = lr
                if self.gpu_mode:
                    x_ = x_.cuda()
                    y_ = y_.cuda()

                # update network
                self.optimizer.zero_grad()
                recon_image = self.model(y_)

                # VGG feature loss.
                # NOTE(review): the .cpu() calls presumably exist because
                # network_utils.norm expects CPU tensors -- confirm; in GPU
                # mode they cost a device round-trip per batch.
                if self.num_channels == 1:
                    # replicate the single channel to the 3 channels VGG takes
                    x_VGG = network_utils.norm(hr.repeat(1, 3, 1, 1).cpu(), vgg=True)
                    recon_VGG = network_utils.norm(recon_image.repeat(1, 3, 1, 1).cpu(), vgg=True)
                else:
                    x_VGG = network_utils.norm(hr.cpu(), vgg=True)
                    recon_VGG = network_utils.norm(recon_image.cpu(), vgg=True)
                if self.gpu_mode:
                    x_VGG = x_VGG.cuda()
                    recon_VGG = recon_VGG.cuda()

                real_feature = self.feature_extractor(x_VGG)
                fake_feature = self.feature_extractor(recon_VGG)
                vgg_loss = self.L1_loss(fake_feature, real_feature.detach())
                vgg_loss = vgg_loss * vgg_factor

                loss = self.L1_loss(recon_image, x_) + vgg_loss
                loss.backward()
                self.optimizer.step()

                # log
                epoch_loss += loss.item()
                print('Epoch: [%2d] [%4d/%4d] loss: %.8f vggloss: %.8f' %
                      ((epoch + 1), (batch_idx + 1), num_batches,
                       loss.item(), vgg_loss.item()))

            # avg. loss per epoch
            avg_loss.append(epoch_loss / num_batches)

            # prediction on the fixed preview image
            if self.num_channels == 1:
                y_ = test_lr[:, 0].unsqueeze(1)
            else:
                y_ = test_lr

            if self.gpu_mode:
                y_ = y_.cuda()

            # The preview is only saved/measured, never back-propagated.
            with torch.no_grad():
                recon_img = self.model(y_)
            sr_img = recon_img[0].cpu()

            # save result image
            network_utils.save_img(sr_img,
                                   epoch + 1,
                                   save_dir=save_dir,
                                   is_training=True)
            if epoch == 0:
                print('time for 1 epoch is :%.2f' % (time.time() - start_time))
            print('Result image at epoch %d is saved.' % (epoch + 1))

            # Save trained parameters of model
            if (epoch + 1) % self.save_epochs == 0:
                self.save_model(epoch + 1)

        # Evaluate the last preview image against the ground truth and the
        # third loader output (test_bc).
        if self.num_channels == 1:
            gt_img = test_hr[0][0].unsqueeze(0)
            lr_img = test_lr[0][0].unsqueeze(0)
            bc_img = test_bc[0][0].unsqueeze(0)
        else:
            gt_img = test_hr[0]
            lr_img = test_lr[0]
            bc_img = test_bc[0]

        if self.metric == 'sc':
            metric_fn = network_utils.SC
        elif self.metric == 'ssim':
            metric_fn = network_utils.SSIM
        else:
            metric_fn = network_utils.PSNR
        bc_metric = metric_fn(bc_img, gt_img)
        recon_metric = metric_fn(sr_img, gt_img)

        # plot result images
        result_imgs = [gt_img, lr_img, bc_img, sr_img]
        metrics = [None, None, bc_metric, recon_metric]
        network_utils.plot_test_result(result_imgs,
                                       metrics,
                                       self.num_epochs,
                                       save_dir=save_dir,
                                       is_training=True,
                                       index=self.metric)
        print('Training result image is saved.')

        # Plot avg. loss
        network_utils.plot_loss([avg_loss], self.num_epochs, save_dir=save_dir)
        print('Training is finished.')

        # Save final trained parameters of model
        self.save_model(epoch=None)