Example #1
0
    def log_current_visual(self, tb_logger, img_name, current_step):
        """Log the HR/SR comparison grid and the landmark figure to TensorBoard."""
        rgb_range = self.opt['rgb_range']
        visuals = self.get_current_visual(need_np=False)

        # One tile for the ground-truth HR image, then one per SR output.
        tiles = [util.quantize(visuals['HR'].squeeze(0), rgb_range)]
        for sr in visuals['SR']:
            tiles.append(util.quantize(sr.squeeze(0), rgb_range))

        grid = thutil.make_grid(torch.stack(tiles),
                                nrow=len(tiles),
                                padding=5,
                                normalize=True,
                                scale_each=True)
        tb_logger.add_image(img_name + '_SR', grid, global_step=current_step)

        #fig = self.get_current_heatmap_pair()
        #tb_logger.add_figure(img_name + '_Heatmap', fig, global_step=current_step)
        landmark_fig = self.get_current_landmark_pair()
        tb_logger.add_figure(img_name + '_Landmark',
                             landmark_fig,
                             global_step=current_step)
        # Release the matplotlib figure so repeated logging does not leak memory.
        plt.close(landmark_fig)
Example #2
0
    def save_current_visual(self, img_name):
        """Save the HR/SR comparison grid and the landmark figure under visual_dir/img_name."""
        rgb_range = self.opt['rgb_range']
        visuals = self.get_current_visual(need_np=False)

        # One tile for the ground-truth HR image, then one per SR output.
        tiles = [util.quantize(visuals['HR'].squeeze(0), rgb_range)]
        for sr in visuals['SR']:
            tiles.append(util.quantize(sr.squeeze(0), rgb_range))

        grid = thutil.make_grid(torch.stack(tiles), nrow=len(tiles), padding=5)
        # CHW tensor -> HWC uint8 ndarray for OpenCV.
        grid_nd = grid.byte().permute(1, 2, 0).numpy()

        save_dir = os.path.join(self.visual_dir, img_name)
        if not os.path.exists(save_dir):
            os.mkdir(save_dir)
        cv2.imwrite(os.path.join(save_dir, 'SR_step_%d.png' % (self.step)),
                    grid_nd[:, :, ::-1])  # rgb2bgr

        #fig = self.get_current_heatmap_pair()
        #fig.savefig(
        #    os.path.join(save_dir, 'Heatmap_step_%d.png' % (self.step)))
        fig = self.get_current_landmark_pair()
        fig.savefig(
            os.path.join(save_dir, 'Landmark_step_%d.png' % (self.step)))
        # Release the matplotlib figure so repeated saving does not leak memory.
        plt.close(fig)
Example #3
0
 def save_current_visual(self, epoch, iter):
     """Save an HR|SR side-by-side grid for this epoch, when it is a save step."""
     if epoch % self.save_vis_step != 0:
         return
     rgb_range = self.opt['rgb_range']
     visuals = self.get_current_visual(need_np=False)
     tiles = [util.quantize(visuals['HR'].squeeze(0), rgb_range),
              util.quantize(visuals['SR'].squeeze(0), rgb_range)]
     grid = thutil.make_grid(torch.stack(tiles), nrow=2, padding=5)
     # CHW tensor -> HWC uint8 ndarray for image writing.
     grid_nd = grid.byte().permute(1, 2, 0).numpy()
     misc.imsave(os.path.join(self.visual_dir, 'epoch_%d_img_%d.png' % (epoch, iter + 1)),
                 grid_nd)
 def save_current_visual(self, epoch, iter):
     """Save the current SR output image for this epoch, when it is a save step."""
     if epoch % self.save_vis_step != 0:
         return
     visuals = self.get_current_visual(need_np=False)
     sr = util.quantize(visuals['SR'], self.opt['rgb_range'])
     # Stack to a batch of one, take element 0 back out as uint8 CHW,
     # then reorder to HWC for the image writer.
     sr_nd = torch.stack([sr]).byte().numpy()[0]
     sr_nd = np.transpose(sr_nd, (1, 2, 0))
     file_name = 'epoch_%d_img_%d.png' % (epoch, iter + 1)
     print(file_name)
     imageio.imwrite(os.path.join(self.visual_dir, file_name), sr_nd)
Example #5
0
 def get_current_visual_list(self):
     """Return quantized CPU copies of the visible, IR, fused, and averaged images.

     In training mode the perceptual-fusion image is appended as well.
     """
     def _quantized(tensor):
         # First element of the batch, as a quantized float CPU tensor.
         return util.quantize(tensor.data[0].float().cpu())

     vis_list = [
         _quantized(self.img_vis),
         _quantized(self.img_ir),
         _quantized(self.img_fuse1),
         _quantized(self.img_fuse2),
         # Mean of the two fusion outputs.
         util.quantize(0.5 * (self.img_fuse1.data[0].float().cpu() +
                              self.img_fuse2.data[0].float().cpu())),
     ]
     if self.opt['is_train']:
         vis_list.append(_quantized(self.img_pf))
     return vis_list
Example #6
0
def main():
    """Run a trained RCGAN model over the test set.

    Saves a VIS/IR/fused comparison grid per image under <model_path>/cmp
    and the fused outputs alone under <model_path>/results.
    """
    parser = argparse.ArgumentParser(description='Test RCGAN model')
    parser.add_argument('-opt', type=str, required=True, help='Path to options JSON file.')
    opt = option.parse(parser.parse_args().opt)
    opt = option.dict_to_nonedict(opt)

    # create test dataloader
    dataset_opt = opt['datasets']['test']
    if dataset_opt is None:
        raise ValueError("test dataset_opt is None!")
    test_set = create_dataset(dataset_opt)
    test_loader = create_dataloader(test_set, dataset_opt)

    if test_loader is None:
        raise ValueError("The test data does not exist")

    solver = RCGANModel(opt)
    solver.model_pth = opt['model_path']
    solver.results_dir = os.path.join(opt['model_path'], 'results')
    solver.cmp_dir = os.path.join(opt['model_path'], 'cmp')

    # load model
    model_pth = os.path.join(solver.model_pth, 'RCGAN_model.pth')
    # BUGFIX: os.path.join never returns None, so the old `if model_pth is
    # None` check was dead code; verify the checkpoint file actually exists.
    if not os.path.isfile(model_pth):
        raise ValueError("model_pth' is required.")
    print('[Loading model from %s...]' % model_pth)
    model_dict = torch.load(model_pth)
    solver.model['netG'].load_state_dict(model_dict['state_dict_G'])

    print('=> Done.')
    print('[Start Testing]')

    test_bar = tqdm(test_loader)
    fused_list = []  # fused output images, written after the loop
    path_list = []   # base file names, taken from the VIS input path

    if not os.path.exists(solver.cmp_dir):
        os.makedirs(solver.cmp_dir)

    for batch in test_bar:
        solver.feed_data(batch)
        solver.test()
        visuals_list = solver.get_current_visual_list()  # fetch current iteration results as cpu tensor
        visuals = solver.get_current_visual()  # fetch current iteration results as cpu tensor

        # Side-by-side comparison grid of all current visuals.
        images = torch.stack(visuals_list)
        saveimg = thutil.make_grid(images, nrow=3, padding=5)
        saveimg_nd = saveimg.byte().permute(1, 2, 0).numpy()
        img_name = os.path.splitext(os.path.basename(batch['VIS_path'][0]))[0]
        imageio.imwrite(os.path.join(solver.cmp_dir, 'comp_%s.bmp' % (img_name)), saveimg_nd)

        # Collect the fused image itself (HWC uint8) for later saving.
        fused_img = visuals['img_fuse']
        fused_img = np.transpose(util.quantize(fused_img).numpy(), (1, 2, 0)).astype(np.uint8).squeeze()
        fused_list.append(fused_img)
        path_list.append(img_name)

    save_img_path = solver.results_dir
    if not os.path.exists(save_img_path):
        os.makedirs(save_img_path)

    for img, img_name in zip(fused_list, path_list):
        imageio.imwrite(os.path.join(solver.results_dir, img_name + '.bmp'), img)

    test_bar.close()
def main():
    """Train a super-resolution model and checkpoint its progress.

    Parses the -opt JSON options file, builds train/val dataloaders, trains
    for NUM_EPOCH epochs, validates every `solver.val_step` epochs (PSNR and
    SSIM on border-cropped images), keeps the best checkpoint by PSNR, and
    finally writes per-epoch statistics to train_results.csv.
    """
    # os.environ['CUDA_VISIBLE_DEVICES']='1' # You can specify your GPU device here.
    parser = argparse.ArgumentParser(description='Train Super Resolution Models')
    parser.add_argument('-opt', type=str, required=True, help='Path to options JSON file.')
    opt = option.parse(parser.parse_args().opt)

    if opt['train']['resume'] is False:
        util.mkdir_and_rename(opt['path']['exp_root'])  # rename old experiments if exists
        util.mkdirs((path for key, path in opt['path'].items() if not key == 'exp_root' and \
                     not key == 'pretrain_G' and not key == 'pretrain_D'))
        option.save(opt)
        opt = option.dict_to_nonedict(opt)  # Convert to NoneDict, which return None for missing key.
    else:
        opt = option.dict_to_nonedict(opt)
        if opt['train']['resume_path'] is None:
            raise ValueError("The 'resume_path' does not declarate")

    if opt['exec_debug']:
        # Debug runs: short schedule on the small debug datasets.
        NUM_EPOCH = 100
        opt['datasets']['train']['dataroot_HR'] = opt['datasets']['train']['dataroot_HR_debug']
        opt['datasets']['train']['dataroot_LR'] = opt['datasets']['train']['dataroot_LR_debug']

    else:
        NUM_EPOCH = int(opt['train']['num_epochs'])

    # random seed (draw one if no manual seed was configured)
    seed = opt['train']['manual_seed']
    if seed is None:
        seed = random.randint(1, 10000)
    print("Random Seed: ", seed)
    random.seed(seed)
    torch.manual_seed(seed)

    # create train and val dataloader
    for phase, dataset_opt in opt['datasets'].items():
        if phase == 'train':
            train_set = create_dataset(dataset_opt)
            train_loader = create_dataloader(train_set, dataset_opt)
            print('Number of train images in [%s]: %d' % (dataset_opt['name'], len(train_set)))
        elif phase == 'val':
            val_set = create_dataset(dataset_opt)
            val_loader = create_dataloader(val_set, dataset_opt)
            print('Number of val images in [%s]: %d' % (dataset_opt['name'], len(val_set)))
        elif phase == 'test':
            pass
        else:
            raise NotImplementedError("Phase [%s] is not recognized." % phase)

    if train_loader is None:
        raise ValueError("The training data does not exist")

    if opt['mode'] == 'sr':
        solver = SRModel(opt)
    else:
        # BUGFIX: the original `assert 'Invalid opt.mode [%s] ...'` asserted a
        # non-empty string literal, which is always truthy and never fired,
        # leaving `solver` unbound. Fail explicitly instead.
        raise ValueError('Invalid opt.mode [%s] for SRModel class!' % opt['mode'])

    solver.summary(train_set[0]['LR'].size())
    solver.net_init()
    print('[Start Training]')

    start_time = time.time()

    start_epoch = 1
    if opt['train']['resume']:
        start_epoch = solver.load()

    for epoch in range(start_epoch, NUM_EPOCH + 1):
        # Per-epoch initialization
        solver.training_loss = 0.0
        epoch_loss_log = 0.0

        if opt['mode'] == 'sr':
            training_results = {'batch_size': 0, 'training_loss': 0.0}
        else:
            pass    # TODO
        train_bar = tqdm(train_loader)

        # Train model
        for iter, batch in enumerate(train_bar):
            solver.feed_data(batch)
            iter_loss = solver.train_step()
            epoch_loss_log += iter_loss.item()
            batch_size = batch['LR'].size(0)
            training_results['batch_size'] += batch_size

            if opt['mode'] == 'sr':
                # Accumulate the sum of per-sample losses; divided by the
                # accumulated batch_size after the epoch for the mean.
                training_results['training_loss'] += iter_loss * batch_size
                train_bar.set_description(desc='[%d/%d] Loss: %.4f ' % (
                    epoch, NUM_EPOCH, iter_loss))
            else:
                pass    # TODO

        solver.last_epoch_loss = epoch_loss_log / (len(train_bar))

        train_bar.close()
        time_elapse = time.time() - start_time
        start_time = time.time()
        print('Train Loss: %.4f' % (training_results['training_loss'] / training_results['batch_size']))

        # validate
        val_results = {'batch_size': 0, 'val_loss': 0.0, 'psnr': 0.0, 'ssim': 0.0}

        if epoch % solver.val_step == 0 and epoch != 0:
            print('[Validating...]')
            start_time = time.time()
            solver.val_loss = 0.0

            vis_index = 1

            for iter, batch in enumerate(val_loader):
                visuals_list = []

                solver.feed_data(batch)
                iter_loss = solver.test(opt['chop'])
                batch_size = batch['LR'].size(0)
                val_results['batch_size'] += batch_size

                visuals = solver.get_current_visual()   # float cpu tensor

                sr_img = np.transpose(util.quantize(visuals['SR'], opt['rgb_range']).numpy(), (1,2,0)).astype(np.uint8)
                gt_img = np.transpose(util.quantize(visuals['HR'], opt['rgb_range']).numpy(), (1,2,0)).astype(np.uint8)

                # calculate PSNR/SSIM with a `scale`-pixel border cropped off
                # (common SR evaluation practice)
                crop_size = opt['scale']
                cropped_sr_img = sr_img[crop_size:-crop_size, crop_size:-crop_size, :]
                cropped_gt_img = gt_img[crop_size:-crop_size, crop_size:-crop_size, :]

                val_results['val_loss'] += iter_loss * batch_size

                val_results['psnr'] += util.calc_psnr(cropped_sr_img, cropped_gt_img)
                val_results['ssim'] += util.calc_ssim(cropped_sr_img, cropped_gt_img)

                if opt['mode'] == 'srgan':
                    pass    # TODO

                # Save an HR|SR side-by-side grid for visual inspection.
                visuals_list.extend([util.quantize(visuals['HR'].squeeze(0), opt['rgb_range']),
                                     util.quantize(visuals['SR'].squeeze(0), opt['rgb_range'])])

                images = torch.stack(visuals_list)
                img = thutil.make_grid(images, nrow=2, padding=5)
                ndarr = img.byte().permute(1, 2, 0).numpy()
                misc.imsave(os.path.join(solver.vis_dir, 'epoch_%d_%d.png' % (epoch, vis_index)), ndarr)
                vis_index += 1

            avg_psnr = val_results['psnr']/val_results['batch_size']
            avg_ssim = val_results['ssim']/val_results['batch_size']
            print('Valid Loss: %.4f | Avg. PSNR: %.4f | Avg. SSIM: %.4f | Learning Rate: %f'%(val_results['val_loss']/val_results['batch_size'], avg_psnr, avg_ssim, solver.current_learning_rate()))

            # BUGFIX: elapsed time was computed as start - now (always
            # negative); flipped to now - start.
            time_elapse = time.time() - start_time

            # tensorboard visualization
            solver.training_loss = training_results['training_loss'] / training_results['batch_size']
            solver.val_loss = val_results['val_loss'] / val_results['batch_size']

            solver.tf_log(epoch)

            # statistics
            if opt['mode'] == 'sr' :
                solver.results['training_loss'].append(solver.training_loss.cpu().data.item())
                solver.results['val_loss'].append(solver.val_loss.cpu().data.item())
                solver.results['psnr'].append(avg_psnr)
                solver.results['ssim'].append(avg_ssim)
            else:
                pass    # TODO

            # Checkpoint; flag it as the new best when validation PSNR improved.
            is_best = False
            if solver.best_prec < solver.results['psnr'][-1]:
                solver.best_prec = solver.results['psnr'][-1]
                is_best = True

            solver.save(epoch, is_best)

        # update lr
        solver.update_learning_rate(epoch)

    # NOTE(review): when training is resumed mid-run the results lists may
    # hold fewer than NUM_EPOCH rows and this index would mismatch — confirm.
    data_frame = pd.DataFrame(
        data={'training_loss': solver.results['training_loss']
            , 'val_loss': solver.results['val_loss']
            , 'psnr': solver.results['psnr']
            , 'ssim': solver.results['ssim']
              },
        index=range(1, NUM_EPOCH+1)
    )
    data_frame.to_csv(os.path.join(solver.results_dir, 'train_results.csv'),
                      index_label='Epoch')
Example #8
0
def main():
    """Train a super-resolution model (sr / srgan / fi / msan / sr_curriculum).

    Parses the -opt JSON options file, builds train/val dataloaders, trains
    for NUM_EPOCH epochs, validates every `solver.val_step` epochs (PSNR and
    SSIM on Y-channel, border-cropped images), keeps the best checkpoint by
    PSNR, and finally writes per-epoch statistics to train_results.csv.
    """
    # os.environ['CUDA_VISIBLE_DEVICES']="1" # You can specify your GPU device here. I failed to perform it by `torch.cuda.set_device()`.
    parser = argparse.ArgumentParser(
        description='Train Super Resolution Models')
    parser.add_argument('-opt',
                        type=str,
                        required=True,
                        help='Path to options JSON file.')
    opt = option.parse(parser.parse_args().opt)

    if opt['train']['resume'] is False:
        util.mkdir_and_rename(
            opt['path']['exp_root'])  # rename old experiments if exists
        util.mkdirs((path for key, path in opt['path'].items() if not key == 'exp_root' and \
                     not key == 'pretrain_G' and not key == 'pretrain_D'))
        option.save(opt)
        opt = option.dict_to_nonedict(
            opt)  # Convert to NoneDict, which return None for missing key.
    else:
        opt = option.dict_to_nonedict(opt)
        if opt['train']['resume_path'] is None:
            raise ValueError("The 'resume_path' does not declarate")

    if opt['exec_debug']:
        # Debug runs: short schedule on the small debug datasets.
        NUM_EPOCH = 100
        opt['datasets']['train']['dataroot_HR'] = opt['datasets']['train'][
            'dataroot_HR_debug']
        opt['datasets']['train']['dataroot_LR'] = opt['datasets']['train'][
            'dataroot_LR_debug']

    else:
        NUM_EPOCH = int(opt['train']['num_epochs'])

    # random seed (draw one if no manual seed was configured)
    seed = opt['train']['manual_seed']
    if seed is None:
        seed = random.randint(1, 10000)
    print("Random Seed: ", seed)
    random.seed(seed)
    torch.manual_seed(seed)

    # create train and val dataloader
    for phase, dataset_opt in opt['datasets'].items():
        if phase == 'train':
            train_set = create_dataset(dataset_opt)
            train_loader = create_dataloader(train_set, dataset_opt)
            print('Number of train images in [%s]: %d' %
                  (dataset_opt['name'], len(train_set)))
        elif phase == 'val':
            val_set = create_dataset(dataset_opt)
            val_loader = create_dataloader(val_set, dataset_opt)
            print('Number of val images in [%s]: %d' %
                  (dataset_opt['name'], len(val_set)))
        elif phase == 'test':
            pass
        else:
            raise NotImplementedError("Phase [%s] is not recognized." % phase)

    if train_loader is None:
        raise ValueError("The training data does not exist")

    # All modes that share the common per-epoch bookkeeping below.
    supported_modes = ('sr', 'srgan', 'sr_curriculum', 'fi', 'msan')

    # TODO: design an exp that can obtain the location of the biggest error
    # 'sr', 'fi' and 'msan' all use the SRModel1 solver.
    if opt['mode'] in ('sr', 'fi', 'msan'):
        solver = SRModel1(opt)
    elif opt['mode'] == 'srgan':
        solver = SRModelGAN(opt)
    elif opt['mode'] == 'sr_curriculum':
        solver = SRModelCurriculum(opt)
    else:
        # BUGFIX: the original chain had no else-branch, so an unknown mode
        # crashed later with an unrelated NameError on `solver`.
        raise ValueError("Invalid opt.mode [%s]" % opt['mode'])

    solver.summary(train_set[0]['LR'].size())
    solver.net_init()
    print('[Start Training]')

    start_time = time.time()

    start_epoch = 1
    if opt['train']['resume']:
        start_epoch = solver.load()

    for epoch in range(start_epoch, NUM_EPOCH + 1):
        # Per-epoch initialization
        solver.training_loss = 0.0
        epoch_loss_log = 0.0

        if opt['mode'] in supported_modes:
            training_results = {'batch_size': 0, 'training_loss': 0.0}
        else:
            pass  # TODO
        train_bar = tqdm(train_loader)

        # Train model
        for iter, batch in enumerate(train_bar):
            solver.feed_data(batch)
            iter_loss = solver.train_step()
            epoch_loss_log += iter_loss.item()
            batch_size = batch['LR'].size(0)
            training_results['batch_size'] += batch_size

            # The five supported modes all accumulated loss * batch_size and
            # printed the same progress line; the original duplicated that
            # block per mode ('sr_curriculum' used iter_loss.data, which is
            # the equivalent detached value).
            if opt['mode'] in supported_modes:
                loss_value = (iter_loss.data
                              if opt['mode'] == 'sr_curriculum' else iter_loss)
                training_results['training_loss'] += loss_value * batch_size
                train_bar.set_description(desc='[%d/%d] Loss: %.4f ' %
                                          (epoch, NUM_EPOCH, iter_loss))
            else:
                pass  # TODO

        solver.last_epoch_loss = epoch_loss_log / (len(train_bar))

        train_bar.close()
        time_elapse = time.time() - start_time
        start_time = time.time()
        print('Train Loss: %.4f' % (training_results['training_loss'] /
                                    training_results['batch_size']))

        # validate
        val_results = {
            'batch_size': 0,
            'val_loss': 0.0,
            'psnr': 0.0,
            'ssim': 0.0
        }

        if epoch % solver.val_step == 0 and epoch != 0:
            print('[Validating...]')
            start_time = time.time()
            solver.val_loss = 0.0

            for iter, batch in enumerate(val_loader):
                solver.feed_data(batch)
                iter_loss = solver.test(opt['chop'])
                batch_size = batch['LR'].size(0)
                val_results['batch_size'] += batch_size

                visuals = solver.get_current_visual()  # float cpu tensor

                sr_img = np.transpose(
                    util.quantize(visuals['SR'], opt['rgb_range']).numpy(),
                    (1, 2, 0)).astype(np.uint8)
                gt_img = np.transpose(
                    util.quantize(visuals['HR'], opt['rgb_range']).numpy(),
                    (1, 2, 0)).astype(np.uint8)

                # calculate PSNR/SSIM with a `scale`-pixel border cropped off
                crop_size = opt['scale']
                cropped_sr_img = sr_img[crop_size:-crop_size,
                                        crop_size:-crop_size, :]
                cropped_gt_img = gt_img[crop_size:-crop_size,
                                        crop_size:-crop_size, :]

                # Evaluate on the Y channel of YCbCr (metrics computed on
                # luminance only, scaled back to [0, 255]).
                cropped_sr_img = cropped_sr_img / 255.
                cropped_gt_img = cropped_gt_img / 255.
                cropped_sr_img = rgb2ycbcr(cropped_sr_img).astype(np.float32)
                cropped_gt_img = rgb2ycbcr(cropped_gt_img).astype(np.float32)

                # (removed: a large commented-out UICM colorfulness-metric
                # experiment lived here)

                val_results['val_loss'] += iter_loss * batch_size

                val_results['psnr'] += util.calc_psnr(cropped_sr_img * 255,
                                                      cropped_gt_img * 255)
                val_results['ssim'] += util.compute_ssim1(
                    cropped_sr_img * 255, cropped_gt_img * 255)

                if opt['mode'] == 'srgan':
                    pass  # TODO

                # (removed: commented-out `opt['save_image']` grid-saving code)

            avg_psnr = val_results['psnr'] / val_results['batch_size']
            avg_ssim = val_results['ssim'] / val_results['batch_size']
            print(
                'Valid Loss: %.4f | Avg. PSNR: %.4f | Avg. SSIM: %.4f | Learning Rate: %f'
                % (val_results['val_loss'] / val_results['batch_size'],
                   avg_psnr, avg_ssim, solver.current_learning_rate()))

            # BUGFIX: elapsed time was computed as start - now (always
            # negative); flipped to now - start.
            time_elapse = time.time() - start_time

            # tensorboard visualization
            solver.training_loss = training_results[
                'training_loss'] / training_results['batch_size']
            solver.val_loss = val_results['val_loss'] / val_results[
                'batch_size']

            solver.tf_log(epoch)

            # statistics
            if opt['mode'] in supported_modes:
                solver.results['training_loss'].append(
                    solver.training_loss.cpu().data.item())
                solver.results['val_loss'].append(
                    solver.val_loss.cpu().data.item())
                solver.results['psnr'].append(avg_psnr)
                solver.results['ssim'].append(avg_ssim)
            else:
                pass  # TODO

            # Checkpoint; flag as the new best when validation PSNR improved.
            is_best = False
            if solver.best_prec < solver.results['psnr'][-1]:
                solver.best_prec = solver.results['psnr'][-1]
                is_best = True

            print(
                '#############################################################'
            )
            print(solver.best_prec)
            print(solver.results['psnr'][-1])
            print(
                '***************************************************************'
            )
            solver.save(epoch, is_best)

        # update lr
        solver.update_learning_rate(epoch)

    # NOTE(review): when training is resumed mid-run the results lists may
    # hold fewer than NUM_EPOCH rows and this index would mismatch — confirm.
    data_frame = pd.DataFrame(data={
        'training_loss': solver.results['training_loss'],
        'val_loss': solver.results['val_loss'],
        'psnr': solver.results['psnr'],
        'ssim': solver.results['ssim']
    },
                              index=range(1, NUM_EPOCH + 1))
    data_frame.to_csv(os.path.join(solver.results_dir, 'train_results.csv'),
                      index_label='Epoch')
Example #9
0
def main():
    """Test a trained SR model on one benchmark set and save the SR images.

    Loads the x<scale> checkpoint, runs every test image through the model,
    reports per-image and average PSNR, and writes the SR outputs to
    ./eval/SR/BI/<model>/<benchmark>/x<scale>.
    """
    # os.environ['CUDA_VISIBLE_DEVICES']='1' # You can specify your GPU device here. I failed to perform it by `torch.cuda.set_device()`.
    parser = argparse.ArgumentParser(
        description='Test Super Resolution Models')
    parser.add_argument('-opt',
                        type=str,
                        required=True,
                        help='Path to options JSON file.')
    opt = option.parse(parser.parse_args().opt)
    opt = option.dict_to_nonedict(opt)

    # Initialization
    scale = opt['scale']
    dataroot_HR = opt['datasets']['test']['dataroot_HR']
    network_opt = opt['networks']['G']
    if network_opt['which_model'] == "feedback":
        # Encode the feedback architecture hyper-parameters into the name.
        model_name = "%s_f%dt%du%ds%d" % (
            network_opt['which_model'], network_opt['num_features'],
            network_opt['num_steps'], network_opt['num_units'],
            network_opt['num_stages'])
    else:
        model_name = network_opt['which_model']

    # Identify which known benchmark the HR root points at
    # (list.index raises ValueError if none of BENCHMARK matches).
    bm_list = [dataroot_HR.find(bm) >= 0 for bm in BENCHMARK]
    bm_idx = bm_list.index(True)
    bm_name = BENCHMARK[bm_idx]

    # create test dataloader
    dataset_opt = opt['datasets']['test']
    if dataset_opt is None:
        raise ValueError("test dataset_opt is None!")
    test_set = create_dataset(dataset_opt)
    test_loader = create_dataloader(test_set, dataset_opt)

    if test_loader is None:
        raise ValueError("The training data does not exist")

    if opt['mode'] == 'sr':
        solver = SRModel(opt)
    elif opt['mode'] == 'sr_curriculum':
        solver = SRModelCurriculum(opt)
    else:
        raise NotImplementedError

    # load model
    model_pth = os.path.join(solver.model_pth,
                             'x' + str(opt['scale']) + '.pth')
    # model_pth = os.path.join(solver.model_pth, 'epoch', 'checkpoint.pth')
    # model_pth = solver.model_pth
    # BUGFIX: os.path.join never returns None, so the old `if model_pth is
    # None` check was dead code; verify the checkpoint file actually exists.
    if not os.path.isfile(model_pth):
        raise ValueError("model_pth' is required.")
    print('[Loading model from %s...]' % model_pth)
    model_dict = torch.load(model_pth, map_location='cpu')
    if 'state_dict' in model_dict.keys():
        solver.model.load_state_dict(model_dict['state_dict'])
    else:
        if model_name == "rcan_ours":
            # Prefix keys with 'module.' so a bare state dict matches a
            # DataParallel-wrapped model.
            new_model_dict = OrderedDict()
            for key, value in model_dict.items():
                new_key = 'module.' + key
                new_model_dict[new_key] = value
            model_dict = new_model_dict

        solver.model.load_state_dict(model_dict)
    print('=> Done.')
    print('[Start Testing]')

    start_time = time.time()

    # we only forward one epoch at test stage, so no need to load epoch, best_prec, results from .pth file
    # we only save images and .pth for evaluation. Calculating statistics are handled by matlab.
    # do crop for efficiency
    test_bar = tqdm(test_loader)
    sr_list = []    # SR output images, written after the loop
    path_list = []  # base file names, taken from the HR input path
    psnr_list = []  # per-image PSNR values

    total_psnr = 0.
    for batch in test_bar:
        solver.feed_data(batch)
        solver.test(opt['chop'])
        visuals = solver.get_current_visual(
        )  # fetch current iteration results as cpu tensor

        sr_img = np.transpose(
            util.quantize(visuals['SR'], opt['rgb_range']).numpy(),
            (1, 2, 0)).astype(np.uint8)
        gt_img = np.transpose(
            util.quantize(visuals['HR'], opt['rgb_range']).numpy(),
            (1, 2, 0)).astype(np.uint8)

        # calculate PSNR with a `scale`-pixel border cropped off
        crop_size = opt['scale']
        cropped_sr_img = sr_img[crop_size:-crop_size, crop_size:-crop_size, :]
        cropped_gt_img = gt_img[crop_size:-crop_size, crop_size:-crop_size, :]

        psnr = util.calc_psnr(cropped_sr_img, cropped_gt_img)
        psnr_list.append(psnr)
        total_psnr += psnr

        sr_list.append(sr_img)
        path_list.append(
            os.path.splitext(os.path.basename(batch['HR_path'][0]))[0])

    # Report per-image and average PSNR.
    print("=====================================")
    # save_txt_path = os.path.join(solver.model_pth, '%s_x%d.txt'%(bm_name, scale))
    line_list = []
    line = "Method : %s\nTest set : %s\nScale : %d " % (model_name, bm_name,
                                                        scale)
    line_list.append(line + '\n')
    print(line)
    for value, img_name in zip(psnr_list, path_list):
        line = "Image name : %s PSNR = %.2f " % (img_name, value)
        line_list.append(line + '\n')
        print(line)
    line = "Average PSNR is %.2f" % (total_psnr / len(test_bar))
    line_list.append(line)
    print(line)

    # save results
    # f = open(save_txt_path, 'w')
    # f.writelines(line_list)
    # f.close()

    save_img_path = os.path.join('./eval/SR/BI', model_name, bm_name,
                                 "x%d" % scale)
    if not os.path.exists(save_img_path):
        os.makedirs(save_img_path)

    for img, img_name in zip(sr_list, path_list):
        misc.imsave(
            os.path.join(save_img_path,
                         img_name.replace('HR', model_name) + '.png'), img)

    test_bar.close()
    # BUGFIX: elapsed time was computed as start - now (always negative);
    # flipped to now - start.
    time_elapse = time.time() - start_time