Example #1
import glob
import os

import cv2
import numpy as np
import scipy.io as scio
import torch

import common


def main():
    # Load the HR test images
    ext = '.png'
    dir_hr = 'D:/xImageDataset/benchmark/Set5/HR'

    files = sorted(glob.glob(os.path.join(dir_hr, '**', '*' + ext), recursive=True))

    for f_hr in files:

        # Read the image and crop its top-left square region
        hr = cv2.imread(f_hr)
        h, w, _ = np.shape(hr)
        cut_size = min(h, w)
        hr = hr[0:cut_size, 0:cut_size, :]

        hr = cv2.resize(hr, (256, 256), interpolation=cv2.INTER_CUBIC)
        cv2.imshow('HR ' + str(np.shape(hr)[0]), hr)
        cv2.waitKey(10)
        
        # Simulate the FlatCam measurement and add sensor noise
        img = torch.from_numpy((cv2.cvtColor(hr, cv2.COLOR_BGR2RGB)).transpose(2, 0, 1)).float().cuda()
        fc_meas = common.flatcamSamp(torch.unsqueeze(img / 255, 0))
        fc_meas_n = common.apply_noise(fc_meas, nSig=10)

        # Simulated reconstruction        
        rec_sim = common.flatcamRecSimple(fc_meas_n)
        x_bayer, x_norm, x_, x_nonneg, rec_org = common.flatcamRecOrg(fc_meas_n)        

        print('Simulated measurement')
        
        # Save the ground truth, measurements, and all reconstructions to a .mat file
        file_name = os.path.splitext(os.path.basename(f_hr))[0]
        scio.savemat(file_name + '.mat',
                     {'img': torch.squeeze(img.permute(1, 2, 0)).cpu().numpy(),
                      'fc_meas': torch.squeeze(fc_meas).cpu().numpy(),
                      'fc_meas_n': torch.squeeze(fc_meas_n).cpu().numpy(),
                      'x_bayer': torch.squeeze(x_bayer).cpu().numpy(),
                      'x_norm': torch.squeeze(x_norm).cpu().numpy(),
                      'x_': torch.squeeze(x_).cpu().numpy(),
                      'x_nonneg': torch.squeeze(x_nonneg).cpu().numpy(),
                      'rec_org': torch.squeeze(rec_org).permute(1, 2, 0).cpu().numpy(),
                      'rec_sim': torch.squeeze(rec_sim).permute(1, 2, 0).cpu().numpy()})
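
A minimal entry point for running this example as a script, assuming the imports above and a CUDA-capable GPU (the tensors are moved to the GPU with .cuda()):

if __name__ == '__main__':
    main()
    cv2.destroyAllWindows()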
Example #2
    def train(self):
        self.loss.step()        
        epoch = self.optimizer.get_last_epoch() + 1
        if self.args.resume > 0:
            epoch = self.args.resume + 1

        lr = self.optimizer.get_lr()

        self.ckp.write_log(
            '[Epoch {}]\tLearning rate: {:.2e}'.format(epoch, Decimal(lr))
        )
        self.loss.start_log()
        self.model.train()

        timer_data, timer_model = utility.timer(), utility.timer()
        
        for batch, (lr, hr, _, idx_scale) in enumerate(self.loader_train):

            #if batch > 10:
            #    continue

            _, hr = self.prepare(lr, hr)

            timer_data.hold()
            timer_model.tic()
            self.optimizer.zero_grad()

            img = utility.quantize(hr, self.args.rgb_range)
            if self.args.is_fcSim:
                img = common.flatcamSamp(img / self.args.rgb_range)
                img = common.apply_noise(img, self.args.sigma)
                img = common.Raw2Bayer(img)
                img = common.make_separable(img)
                
                #scio.savemat( 'train_test_sig' + str(self.args.sigma) + '_' + str(batch) +'.mat', 
                #                { 'hr' : torch.squeeze(hr).permute(0, 2, 3, 1).detach().cpu().numpy(),
                #                  'hr2' : torch.squeeze(hr2).permute(0, 2, 3, 1).detach().cpu().numpy(),  
                #                  'sim_fc' : torch.squeeze(sim_fc).detach().cpu().numpy(),
                #                  'sim_fc_noise' : torch.squeeze(sim_fc_noise).detach().cpu().numpy(),
                #                  'sim_fc_bayer' : torch.squeeze(sim_fc_bayer).detach().cpu().numpy(),
                #                  'sim_fc_bayerNorm' : torch.squeeze(sim_fc_bayerNorm).detach().cpu().numpy(),})
            
            sr = self.model(img, idx_scale)
            if self.args.model == 'kcsres_mwcnn2':
                # Assumption: this model's forward pass also returns an initial
                # reconstruction; supervise both outputs against the ground truth.
                sr, sr_init = sr
                loss = self.loss(sr, hr) + self.loss(sr_init, hr)
            else:
                loss = self.loss(sr, hr)

            loss.backward()
            if self.args.gclip > 0:
                utils.clip_grad_value_(
                    self.model.parameters(),
                    self.args.gclip
                )
            self.optimizer.step()

            timer_model.hold()

            if (batch + 1) % self.args.print_every == 0:
                self.ckp.write_log('[{}/{}]\t{}\t{:.1f}+{:.1f}s'.format(
                    (batch + 1) * self.args.batch_size,
                    len(self.loader_train.dataset),
                    self.loss.display_loss(batch),
                    timer_model.release(),
                    timer_data.release()))

            timer_data.tic()

        self.loss.end_log(len(self.loader_train))
        self.error_last = self.loss.log[-1, -1]
        self.optimizer.schedule()
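
The method above times data loading and model updates through `utility.timer` objects via `tic()`, `hold()`, and `release()`. A minimal sketch of a timer with that interface, assuming the usual accumulate-and-reset semantics (the project's actual `utility.timer` may differ):

import time

class Timer:
    """Accumulating timer: tic() starts, hold() accumulates, release() resets."""
    def __init__(self):
        self.acc = 0.0
        self.t0 = time.time()

    def tic(self):
        # Start (or restart) timing from now.
        self.t0 = time.time()

    def hold(self):
        # Add the elapsed time since the last tic() to the accumulator.
        self.acc += time.time() - self.t0

    def release(self):
        # Return the accumulated time and reset the accumulator.
        elapsed, self.acc = self.acc, 0.0
        return elapsed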
Example #3
    def test(self):
        torch.set_grad_enabled(False)

        epoch = self.optimizer.get_last_epoch()
        self.ckp.write_log('\nEvaluation:')
        self.ckp.add_log(
            torch.zeros(1, len(self.loader_test), len(self.scale))
        )
        self.model.eval()

        save_folder = 'Results_DL/' + self.args.save + '/' + self.args.data_test[0] + '/'
        if not os.path.exists(save_folder):
            os.makedirs(save_folder)

        timer_test = utility.timer()
        # if self.args.save_results: self.ckp.begin_background()
        for idx_data, d in enumerate(self.loader_test):
            for idx_scale, scale in enumerate(self.scale):
                d.dataset.set_scale(idx_scale)
                for lr, hr, filename, _ in tqdm(d, ncols=80):
                    _, hr = self.prepare(lr, hr)    

                    # Crop the ground truth to a square before simulating the measurement
                    _, _, h, w = hr.size()
                    idx = min(h, w)
                    hr = hr[:, :, 0:idx, 0:idx]

                    img = utility.quantize(hr, self.args.rgb_range)
                    if self.args.is_fcSim:
                        img = common.flatcamSamp(img / self.args.rgb_range)
                        img = common.apply_noise(img, self.args.sigma)
                        img = common.Raw2Bayer(img)
                        img = common.make_separable(img)
                        #img = sim_fc_bayerNorm

                    sr = self.model(img, idx_scale)
                    sr = utility.quantize(sr, self.args.rgb_range)

                    if self.args.test_only:
                        plt.imsave(save_folder + filename[0] + '.png',
                                   torch.squeeze(sr).permute(1, 2, 0).detach().cpu().numpy() / self.args.rgb_range)
                        plt.imsave(save_folder + '__Org_' + filename[0] + '.png',
                                   torch.squeeze(hr).permute(1, 2, 0).detach().cpu().numpy() / self.args.rgb_range)
                        
                    save_list = [sr]
                    #print('\n')
                    #print(hr.size())
                    #print(sr.size())
                    self.ckp.log[-1, idx_data, idx_scale] += utility.calc_psnr(
                        sr, hr, scale, self.args.rgb_range, dataset=d
                    )
                    if self.args.save_gt:
                        save_list.extend([lr, hr])
                        #print(cur_psnr, init_psnr)

                    if self.args.save_results:
                        self.ckp.save_results(d, filename[0], save_list, scale)

                self.ckp.log[-1, idx_data, idx_scale] /= len(d)
                best = self.ckp.log.max(0)
                self.ckp.write_log(
                    '[{} x{}]\tPSNR: {:.3f} (Best: {:.3f} @epoch {})'.format(
                        d.dataset.name,
                        scale,
                        self.ckp.log[-1, idx_data, idx_scale],
                        best[0][idx_data, idx_scale],
                        best[1][idx_data, idx_scale] + 1
                    )
                )

        self.ckp.write_log('Forward: {:.2f}s\n'.format(timer_test.toc()))
        self.ckp.write_log('Saving...')

        # if self.args.save_results:             self.ckp.end_background()

        if not self.args.test_only:
            self.ckp.save(self, epoch, is_best=(best[1][0, 0] + 1 == epoch))

        self.ckp.write_log(
            'Total: {:.2f}s\n'.format(timer_test.toc()), refresh=True
        )

        torch.set_grad_enabled(True)
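
Both the training and test loops call `utility.quantize` to clip tensors back onto the valid pixel grid before computing losses and PSNR. A minimal sketch of such a helper, assuming the common convention of rounding to 8-bit levels scaled by `rgb_range` (the project's own implementation may differ):

import torch

def quantize(img, rgb_range):
    # Map to [0, 255], round to integer levels, clamp, and map back to rgb_range.
    pixel_range = 255 / rgb_range
    return img.mul(pixel_range).clamp(0, 255).round().div(pixel_range)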
Example #4
    def train(self):
        self.loss.step()
        epoch = self.optimizer.get_last_epoch() + 1
        if self.args.resume > 0:
            epoch = self.args.resume + 1

        lr = self.optimizer.get_lr()

        self.ckp.write_log('[Epoch {}]\tLearning rate: {:.2e}'.format(
            epoch, Decimal(lr)))
        self.loss.start_log()
        self.model.train()

        timer_data, timer_model = utility.timer(), utility.timer()

        for batch, (lr, hr, _, idx_scale) in enumerate(self.loader_train):

            #if batch > 10:
            #    continue

            _, hr = self.prepare(lr, hr)
            #hr =  hr/self.args.rgb_range

            timer_data.hold()
            timer_model.tic()

            self.optimizer.zero_grad()

            # Initial Reconstruction
            img = utility.quantize(hr, self.args.rgb_range)
            if self.args.is_fcSim:
                img = common.flatcamSamp(img / self.args.rgb_range)
                img = common.apply_noise(img, self.args.sigma)
                img = common.Raw2Bayer(img)
                img = common.make_separable(img)

            sr0 = self.model_init(img, idx_scale)

            # Enhance reconstruction
            sr = self.model(sr0, idx_scale)

            loss = self.loss(sr, hr)

            if self.args.model == 'kcsres_mwcnn2':
                # Also supervise the initial reconstruction sr0 produced by model_init.
                loss = loss + self.loss(sr0, hr)

            loss.backward()
            if self.args.gclip > 0:
                utils.clip_grad_value_(self.model.parameters(),
                                       self.args.gclip)
            self.optimizer.step()

            timer_model.hold()

            if (batch + 1) % self.args.print_every == 0:
                self.ckp.write_log('[{}/{}]\t{}\t{:.1f}+{:.1f}s'.format(
                    (batch + 1) * self.args.batch_size,
                    len(self.loader_train.dataset),
                    self.loss.display_loss(batch), timer_model.release(),
                    timer_data.release()))

            timer_data.tic()

        self.loss.end_log(len(self.loader_train))
        self.error_last = self.loss.log[-1, -1]
        self.optimizer.schedule()
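
This variant splits inference into an initial reconstruction network (`self.model_init`) followed by an enhancement network (`self.model`). A toy sketch of that two-stage composition, with hypothetical module names purely for illustration:

import torch.nn as nn

class TwoStagePipeline(nn.Module):
    # Illustrative only: wraps an initial reconstruction network and an
    # enhancement network so a single forward pass returns both outputs.
    def __init__(self, init_net: nn.Module, enhance_net: nn.Module):
        super().__init__()
        self.init_net = init_net
        self.enhance_net = enhance_net

    def forward(self, x):
        sr0 = self.init_net(x)       # coarse reconstruction from the measurement
        sr = self.enhance_net(sr0)   # refined, enhanced reconstruction
        return sr, sr0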