Ejemplo n.º 1
0
def main():
    """Jointly train the Predictor and Corrector using a pre-built SFTMD model.

    Command-line arguments:
        -opt_P / -opt_C / -opt_F: paths to the option files of the Predictor,
            the Corrector and SFTMD_Net respectively.
        --launcher: 'none' for single-process training, 'pytorch' to
            initialise torch.distributed via ``init_dist()``.
        --local_rank: accepted for the PyTorch launcher; not read directly here.

    Flow: parse options, build a PCA matrix from randomly generated kernels,
    configure (optional) distributed training, loggers and data loaders, then
    for every training batch run one Predictor update followed by
    ``opt_C['step']`` Corrector updates. Validation and checkpointing happen
    at the frequencies given in ``opt_P``.

    NOTE(review): only ``model_P`` restores optimizer/scheduler state from the
    resume file; ``model_C`` is not resumed here -- confirm this is intended.
    """
    #### setup options of three networks
    parser = argparse.ArgumentParser()
    parser.add_argument('-opt_P',
                        type=str,
                        help='Path to option YMAL file of Predictor.')
    parser.add_argument('-opt_C',
                        type=str,
                        help='Path to option YMAL file of Corrector.')
    parser.add_argument('-opt_F',
                        type=str,
                        help='Path to option YMAL file of SFTMD_Net.')
    parser.add_argument('--launcher',
                        choices=['none', 'pytorch'],
                        default='none',
                        help='job launcher')
    parser.add_argument('--local_rank', type=int, default=0)
    args = parser.parse_args()
    opt_P = option.parse(args.opt_P, is_train=True)
    opt_C = option.parse(args.opt_C, is_train=True)
    opt_F = option.parse(args.opt_F, is_train=True)

    # convert to NoneDict, which returns None for missing keys
    opt_P = option.dict_to_nonedict(opt_P)
    opt_C = option.dict_to_nonedict(opt_C)
    opt_F = option.dict_to_nonedict(opt_F)

    # choose small opt for SFTMD test, fill path of pre-trained model_F
    opt_F = opt_F['sftmd']

    # create PCA matrix of enough kernel
    batch_ker = util.random_batch_kernel(batch=30000,
                                         l=opt_P['kernel_size'],
                                         sig_min=0.2,
                                         sig_max=4.0,
                                         rate_iso=1.0,
                                         scaling=3,
                                         tensor=False)
    print('batch kernel shape: {}'.format(batch_ker.shape))
    b = np.size(batch_ker, 0)
    # flatten every kernel into one row before PCA
    batch_ker = batch_ker.reshape((b, -1))
    pca_matrix = util.PCA(batch_ker, k=opt_P['code_length']).float()
    print('PCA matrix shape: {}'.format(pca_matrix.shape))

    #### distributed training settings
    if args.launcher == 'none':  # disabled distributed training
        opt_P['dist'] = False
        opt_F['dist'] = False
        opt_C['dist'] = False
        rank = -1
        print('Disabled distributed training.')
    else:
        opt_P['dist'] = True
        opt_F['dist'] = True
        opt_C['dist'] = True
        init_dist()
        # number of processes in the current process group
        world_size = torch.distributed.get_world_size()
        # rank of this process within the group
        rank = torch.distributed.get_rank()

    torch.backends.cudnn.benchmark = True
    # torch.backends.cudnn.deterministic = True

    ###### Predictor&Corrector train ######

    #### loading resume state if exists
    if opt_P['path'].get('resume_state', None):
        # distributed resuming: all load into default GPU
        device_id = torch.cuda.current_device()
        resume_state = torch.load(
            opt_P['path']['resume_state'],
            map_location=lambda storage, loc: storage.cuda(device_id))
        option.check_resume(opt_P,
                            resume_state['iter'])  # check resume options
    else:
        resume_state = None

    #### mkdir and loggers
    # The writer only exists on the logging process when TensorBoard is
    # enabled; keeping an explicit None lets every later use (including the
    # final close) be guarded instead of raising NameError.
    tb_logger = None
    if rank <= 0:  # normal training (rank -1) OR distributed training (rank 0)
        if resume_state is None:
            # Predictor path
            util.mkdir_and_rename(
                opt_P['path']
                ['experiments_root'])  # rename experiment folder if exists
            util.mkdirs(
                (path for key, path in opt_P['path'].items()
                 if not key == 'experiments_root'
                 and 'pretrain_model' not in key and 'resume' not in key))
            # Corrector path
            util.mkdir_and_rename(
                opt_C['path']
                ['experiments_root'])  # rename experiment folder if exists
            util.mkdirs(
                (path for key, path in opt_C['path'].items()
                 if not key == 'experiments_root'
                 and 'pretrain_model' not in key and 'resume' not in key))

        # config loggers. Before it, the log will not work
        util.setup_logger('base',
                          opt_P['path']['log'],
                          'train_' + opt_P['name'],
                          level=logging.INFO,
                          screen=True,
                          tofile=True)
        util.setup_logger('val',
                          opt_P['path']['log'],
                          'val_' + opt_P['name'],
                          level=logging.INFO,
                          screen=True,
                          tofile=True)
        logger = logging.getLogger('base')
        logger.info(option.dict2str(opt_P))
        logger.info(option.dict2str(opt_C))
        # tensorboard logger
        if opt_P['use_tb_logger'] and 'debug' not in opt_P['name']:
            version = float(torch.__version__[0:3])
            if version >= 1.1:  # PyTorch 1.1
                from torch.utils.tensorboard import SummaryWriter
            else:
                logger.info(
                    'You are using PyTorch {}. Tensorboard will use [tensorboardX]'
                    .format(version))
                from tensorboardX import SummaryWriter
            tb_logger = SummaryWriter(log_dir='../tb_logger/' + opt_P['name'])
    else:
        util.setup_logger('base',
                          opt_P['path']['log'],
                          'train',
                          level=logging.INFO,
                          screen=True)
        logger = logging.getLogger('base')

    #### random seed
    seed = opt_P['train']['manual_seed']
    if seed is None:
        seed = random.randint(1, 10000)
    if rank <= 0:
        logger.info('Random seed: {}'.format(seed))
    util.set_random_seed(seed)

    torch.backends.cudnn.benchmark = True
    # torch.backends.cudnn.deterministic = True

    #### create train and val dataloader
    dataset_ratio = 200  # enlarge the size of each epoch
    # Pre-bind so that a missing phase trips the assertions below rather than
    # an unrelated NameError.
    train_loader = None
    val_loader = None
    for phase, dataset_opt in opt_P['datasets'].items():
        if phase == 'train':
            train_set = create_dataset(dataset_opt)
            train_size = int(
                math.ceil(len(train_set) / dataset_opt['batch_size']))
            total_iters = int(opt_P['train']['niter'])
            total_epochs = int(math.ceil(total_iters / train_size))
            if opt_P['dist']:
                train_sampler = DistIterSampler(train_set, world_size, rank,
                                                dataset_ratio)
                total_epochs = int(
                    math.ceil(total_iters / (train_size * dataset_ratio)))
            else:
                train_sampler = None
            train_loader = create_dataloader(train_set, dataset_opt, opt_P,
                                             train_sampler)
            if rank <= 0:
                logger.info(
                    'Number of train images: {:,d}, iters: {:,d}'.format(
                        len(train_set), train_size))
                logger.info('Total epochs needed: {:d} for iters {:,d}'.format(
                    total_epochs, total_iters))
        elif phase == 'val':
            val_set = create_dataset(dataset_opt)
            val_loader = create_dataloader(val_set, dataset_opt, opt_P, None)
            if rank <= 0:
                logger.info('Number of val images in [{:s}]: {:d}'.format(
                    dataset_opt['name'], len(val_set)))
        else:
            raise NotImplementedError(
                'Phase [{:s}] is not recognized.'.format(phase))
    assert train_loader is not None
    assert val_loader is not None

    #### create model
    model_F = create_model(opt_F)  #load pretrained model of SFTMD
    model_P = create_model(opt_P)
    model_C = create_model(opt_C)

    #### resume training
    if resume_state:
        logger.info('Resuming training from epoch: {}, iter: {}.'.format(
            resume_state['epoch'], resume_state['iter']))

        start_epoch = resume_state['epoch']
        current_step = resume_state['iter']
        model_P.resume_training(
            resume_state)  # handle optimizers and schedulers
    else:
        current_step = 0
        start_epoch = 0

    #### training
    logger.info('Start training from epoch: {:d}, iter: {:d}'.format(
        start_epoch, current_step))
    for epoch in range(start_epoch, total_epochs + 1):
        if opt_P['dist']:
            train_sampler.set_epoch(epoch)
        for _, train_data in enumerate(train_loader):
            current_step += 1
            if current_step > total_iters:
                break
            #### update learning rate, schedulers
            # model.update_learning_rate(current_step, warmup_iter=opt_P['train']['warmup_iter'])

            #### preprocessing for LR_img and kernel map
            prepro = util.SRMDPreprocessing(opt_P['scale'],
                                            pca_matrix,
                                            para_input=opt_P['code_length'],
                                            kernel=opt_P['kernel_size'],
                                            noise=False,
                                            cuda=True,
                                            sig_min=0.2,
                                            sig_max=4.0,
                                            rate_iso=1.0,
                                            scaling=3,
                                            rate_cln=0.2,
                                            noise_high=0.0)
            LR_img, ker_map = prepro(train_data['GT'])

            #### training Predictor
            model_P.feed_data(LR_img, ker_map)
            model_P.optimize_parameters(current_step)
            P_visuals = model_P.get_current_visuals()
            est_ker_map = P_visuals['Batch_est_ker_map']

            #### log of model_P
            if current_step % opt_P['logger']['print_freq'] == 0:
                logs = model_P.get_current_log()
                message = 'Predictor <epoch:{:3d}, iter:{:8,d}, lr:{:.3e}> '.format(
                    epoch, current_step, model_P.get_current_learning_rate())
                for k, v in logs.items():
                    message += '{:s}: {:.4e} '.format(k, v)
                    # tensorboard logger (writer exists only on rank <= 0
                    # with TensorBoard enabled for the Predictor)
                    if tb_logger is not None:
                        tb_logger.add_scalar(k, v, current_step)
                if rank <= 0:
                    logger.info(message)

            #### training Corrector
            for step in range(opt_C['step']):
                # test SFTMD for corresponding SR image
                model_F.feed_data(train_data, LR_img, est_ker_map)
                model_F.test()
                F_visuals = model_F.get_current_visuals()
                SR_img = F_visuals['Batch_SR']
                # Test SFTMD to produce SR images

                # train corrector given SR image and estimated kernel map
                model_C.feed_data(SR_img, est_ker_map, ker_map)
                model_C.optimize_parameters(current_step)
                C_visuals = model_C.get_current_visuals()
                # the corrected estimate feeds the next correction step
                est_ker_map = C_visuals['Batch_est_ker_map']

                #### log of model_C
                if current_step % opt_C['logger']['print_freq'] == 0:
                    logs = model_C.get_current_log()
                    message = 'Corrector <epoch:{:3d}, iter:{:8,d}, lr:{:.3e}> '.format(
                        epoch, current_step,
                        model_C.get_current_learning_rate())
                    for k, v in logs.items():
                        message += '{:s}: {:.4e} '.format(k, v)
                        # tensorboard logger: the writer is created from the
                        # Predictor options, so it may be absent even when
                        # the Corrector options request TensorBoard.
                        if (tb_logger is not None and opt_C['use_tb_logger']
                                and 'debug' not in opt_C['name']):
                            if rank <= 0:
                                tb_logger.add_scalar(k, v, current_step)
                    if rank <= 0:
                        logger.info(message)

            # validation, to produce ker_map_list(fake)
            if current_step % opt_P['train']['val_freq'] == 0 and rank <= 0:
                avg_psnr = 0.0
                idx = 0  # total number of (image, correction step) pairs
                for _, val_data in enumerate(val_loader):
                    prepro = util.SRMDPreprocessing(
                        opt_P['scale'],
                        pca_matrix,
                        para_input=opt_P['code_length'],
                        kernel=opt_P['kernel_size'],
                        noise=False,
                        cuda=True,
                        sig_min=0.2,
                        sig_max=4.0,
                        rate_iso=1.0,
                        scaling=3,
                        rate_cln=0.2,
                        noise_high=0.0)
                    LR_img, ker_map = prepro(val_data['GT'])
                    single_img_psnr = 0.0

                    # valid Predictor
                    model_P.feed_data(LR_img, ker_map)
                    model_P.test()
                    P_visuals = model_P.get_current_visuals()
                    est_ker_map = P_visuals['Batch_est_ker_map']

                    # 1-based step index, used in file names and log lines
                    for step in range(1, opt_C['step'] + 1):
                        idx += 1
                        model_F.feed_data(val_data, LR_img, est_ker_map)
                        model_F.test()
                        F_visuals = model_F.get_current_visuals()
                        SR_img = F_visuals['Batch_SR']
                        # Test SFTMD to produce SR images

                        model_C.feed_data(SR_img, est_ker_map, ker_map)
                        model_C.test()
                        C_visuals = model_C.get_current_visuals()
                        est_ker_map = C_visuals['Batch_est_ker_map']

                        sr_img = util.tensor2img(F_visuals['SR'])  # uint8
                        gt_img = util.tensor2img(F_visuals['GT'])  # uint8

                        # Save SR images for reference
                        img_name = os.path.splitext(
                            os.path.basename(val_data['LQ_path'][0]))[0]
                        img_dir = os.path.join(opt_P['path']['val_images'],
                                               img_name)
                        # img_dir = os.path.join(opt_F['path']['val_images'], str(current_step), '_', str(step))
                        util.mkdir(img_dir)

                        save_img_path = os.path.join(
                            img_dir, '{:s}_{:d}_{:d}.png'.format(
                                img_name, current_step, step))
                        util.save_img(sr_img, save_img_path)

                        # calculate PSNR on images cropped by `scale` pixels
                        # on every border
                        crop_size = opt_P['scale']
                        gt_img = gt_img / 255.
                        sr_img = sr_img / 255.
                        cropped_sr_img = sr_img[crop_size:-crop_size,
                                                crop_size:-crop_size, :]
                        cropped_gt_img = gt_img[crop_size:-crop_size,
                                                crop_size:-crop_size, :]
                        step_psnr = util.calculate_psnr(
                            cropped_sr_img * 255, cropped_gt_img * 255)
                        logger.info(
                            '<epoch:{:3d}, iter:{:8,d}, step:{:3d}> img:{:s}, psnr: {:.4f}'
                            .format(epoch, current_step, step, img_name,
                                    step_psnr))
                        single_img_psnr += step_psnr
                        # reuse the value instead of computing PSNR twice
                        avg_psnr += step_psnr

                    avg_single_img_psnr = single_img_psnr / step
                    logger.info(
                        '<epoch:{:3d}, iter:{:8,d}, step:{:3d}> img:{:s}, average psnr: {:.4f}'
                        .format(epoch, current_step, step, img_name,
                                avg_single_img_psnr))

                avg_psnr = avg_psnr / idx

                # log
                logger.info('# Validation # PSNR: {:.4f}'.format(avg_psnr))
                logger_val = logging.getLogger('val')  # validation logger
                logger_val.info(
                    '<epoch:{:3d}, iter:{:8,d}, step:{:3d}> psnr: {:.4f}'.
                    format(epoch, current_step, step, avg_psnr))
                # tensorboard logger
                if tb_logger is not None:
                    tb_logger.add_scalar('psnr', avg_psnr, current_step)

            #### save models and training states
            if current_step % opt_P['logger']['save_checkpoint_freq'] == 0:
                if rank <= 0:
                    logger.info('Saving models and training states.')
                    model_P.save(current_step)
                    model_P.save_training_state(epoch, current_step)
                    model_C.save(current_step)
                    model_C.save_training_state(epoch, current_step)

    if rank <= 0:
        logger.info('Saving the final model.')
        model_P.save('latest')
        model_C.save('latest')
        logger.info('End of Predictor and Corrector training.')
    # Only close the writer if one was actually created on this process.
    if tb_logger is not None:
        tb_logger.close()
Ejemplo n.º 2
0
def SFTMD_train(opt_F, rank, world_size, pca_matrix):
    """Train the SFTMD network described by ``opt_F``.

    Args:
        opt_F: option mapping for SFTMD; this function reads its 'path',
            'name', 'use_tb_logger', 'datasets', 'train', 'logger', 'scale'
            and 'dist' entries.
        rank: process rank; ``rank <= 0`` is the process that creates
            directories, logs, validates and saves checkpoints.
        world_size: number of processes, forwarded to ``DistIterSampler``
            when ``opt_F['dist']`` is set.
        pca_matrix: PCA projection passed to ``util.SRMDPreprocessing``.
    """
    #### loading resume state if exists
    if opt_F['path'].get('resume_state', None):
        # distributed resuming: all load into default GPU
        device_id = torch.cuda.current_device()
        resume_state = torch.load(opt_F['path']['resume_state'],
                                  map_location=lambda storage, loc: storage.cuda(device_id))
        option.check_resume(opt_F, resume_state['iter'])  # check resume options
    else:
        resume_state = None

    #### mkdir and loggers
    # Stays None unless this process creates a TensorBoard writer, so later
    # uses and the final close can be guarded.
    tb_logger = None
    if rank <= 0:
        if resume_state is None:
            util.mkdir_and_rename(
                opt_F['path']['experiments_root'])  # rename experiment folder if exists
            util.mkdirs((path for key, path in opt_F['path'].items() if not key == 'experiments_root'
                         and 'pretrain_model' not in key and 'resume' not in key))

        # config loggers. Before it, the log will not work
        util.setup_logger('base', opt_F['path']['log'], 'train_' + opt_F['name'], level=logging.INFO,
                          screen=True, tofile=True)
        util.setup_logger('val', opt_F['path']['log'], 'val_' + opt_F['name'], level=logging.INFO,
                          screen=True, tofile=True)
        logger = logging.getLogger('base')
        logger.info(option.dict2str(opt_F))
        # tensorboard logger
        if opt_F['use_tb_logger'] and 'debug' not in opt_F['name']:
            version = float(torch.__version__[0:3])
            if version >= 1.1:  # PyTorch 1.1
                from torch.utils.tensorboard import SummaryWriter
            else:
                logger.info(
                    'You are using PyTorch {}. Tensorboard will use [tensorboardX]'.format(version))
                from tensorboardX import SummaryWriter
            tb_logger = SummaryWriter(log_dir='../tb_logger/' + opt_F['name'])
    else:
        util.setup_logger('base', opt_F['path']['log'], 'train', level=logging.INFO, screen=True)
        logger = logging.getLogger('base')

    #### create train and val dataloader
    dataset_ratio = 200   # enlarge the size of each epoch
    # Pre-bind so that a missing phase trips the assertions below rather than
    # an unrelated NameError.
    train_loader = None
    val_loader = None
    for phase, dataset_opt in opt_F['datasets'].items():
        if phase == 'train':
            train_set = create_dataset(dataset_opt)
            train_size = int(math.ceil(len(train_set) / dataset_opt['batch_size']))
            total_iters = int(opt_F['train']['niter'])
            total_epochs = int(math.ceil(total_iters / train_size))
            if opt_F['dist']:
                train_sampler = DistIterSampler(train_set, world_size, rank, dataset_ratio)
                total_epochs = int(math.ceil(total_iters / (train_size * dataset_ratio)))
            else:
                train_sampler = None
            train_loader = create_dataloader(train_set, dataset_opt, opt_F, train_sampler)
            if rank <= 0:
                logger.info('Number of train images: {:,d}, iters: {:,d}'.format(
                    len(train_set), train_size))
                logger.info('Total epochs needed: {:d} for iters {:,d}'.format(
                    total_epochs, total_iters))
        elif phase == 'val':
            val_set = create_dataset(dataset_opt)
            val_loader = create_dataloader(val_set, dataset_opt, opt_F, None)
            if rank <= 0:
                logger.info('Number of val images in [{:s}]: {:d}'.format(
                    dataset_opt['name'], len(val_set)))
        else:
            raise NotImplementedError('Phase [{:s}] is not recognized.'.format(phase))
    assert train_loader is not None
    assert val_loader is not None

    #### create model
    model_F = create_model(opt_F)

    #### resume training
    if resume_state:
        logger.info('Resuming training from epoch: {}, iter: {}.'.format(
            resume_state['epoch'], resume_state['iter']))

        start_epoch = resume_state['epoch']
        current_step = resume_state['iter']
        model_F.resume_training(resume_state)  # handle optimizers and schedulers
    else:
        current_step = 0
        start_epoch = 0

    #### training
    logger.info('Start training from epoch: {:d}, iter: {:d}'.format(start_epoch, current_step))
    for epoch in range(start_epoch, total_epochs + 1):
        if opt_F['dist']:
            train_sampler.set_epoch(epoch)
        for _, train_data in enumerate(train_loader):
            current_step += 1
            if current_step > total_iters:
                break
            #### preprocessing for LR_img and kernel map
            prepro = util.SRMDPreprocessing(opt_F['scale'], pca_matrix, para_input=10, kernel=21, noise=False, cuda=True,
                                            sig_min=0.2, sig_max=4.0, rate_iso=1.0, scaling=3,
                                            rate_cln=0.2, noise_high=0.0)
            LR_img, ker_map = prepro(train_data['GT'])

            #### update learning rate, schedulers
            model_F.update_learning_rate(current_step, warmup_iter=opt_F['train']['warmup_iter'])

            #### training
            model_F.feed_data(train_data, LR_img, ker_map)
            model_F.optimize_parameters(current_step)

            #### log
            if current_step % opt_F['logger']['print_freq'] == 0:
                logs = model_F.get_current_log()
                message = '<epoch:{:3d}, iter:{:8,d}, lr:{:.3e}> '.format(
                    epoch, current_step, model_F.get_current_learning_rate())
                for k, v in logs.items():
                    message += '{:s}: {:.4e} '.format(k, v)
                    # tensorboard logger (writer exists only on rank <= 0
                    # with TensorBoard enabled)
                    if tb_logger is not None:
                        tb_logger.add_scalar(k, v, current_step)
                if rank <= 0:
                    logger.info(message)

            # validation
            if current_step % opt_F['train']['val_freq'] == 0 and rank <= 0:
                avg_psnr = 0.0
                idx = 0
                for _, val_data in enumerate(val_loader):
                    idx += 1
                    #### preprocessing for LR_img and kernel map
                    # NOTE(review): training passes para_input=10, kernel=21
                    # while validation passes para_input=15 and relies on the
                    # helper's default kernel -- confirm the mismatch is
                    # intentional.
                    prepro = util.SRMDPreprocessing(opt_F['scale'], pca_matrix, para_input=15, noise=False, cuda=True,
                                                    sig_min=0.2, sig_max=4.0, rate_iso=1.0, scaling=3,
                                                    rate_cln=0.2, noise_high=0.0)
                    LR_img, ker_map = prepro(val_data['GT'])

                    model_F.feed_data(val_data, LR_img, ker_map)
                    model_F.test()

                    visuals = model_F.get_current_visuals()
                    sr_img = util.tensor2img(visuals['SR'])  # uint8
                    gt_img = util.tensor2img(visuals['GT'])  # uint8

                    # Save SR images for reference
                    img_name = os.path.splitext(os.path.basename(val_data['LQ_path'][0]))[0]
                    #img_dir = os.path.join(opt_F['path']['val_images'], img_name)
                    img_dir = os.path.join(opt_F['path']['val_images'], str(current_step))
                    util.mkdir(img_dir)

                    save_img_path = os.path.join(img_dir,'{:s}_{:d}.png'.format(img_name, current_step))
                    util.save_img(sr_img, save_img_path)

                    # calculate PSNR on images cropped by `scale` pixels on
                    # every border
                    crop_size = opt_F['scale']
                    gt_img = gt_img / 255.
                    sr_img = sr_img / 255.
                    cropped_sr_img = sr_img[crop_size:-crop_size, crop_size:-crop_size, :]
                    cropped_gt_img = gt_img[crop_size:-crop_size, crop_size:-crop_size, :]
                    avg_psnr += util.calculate_psnr(cropped_sr_img * 255, cropped_gt_img * 255)

                avg_psnr = avg_psnr / idx

                # log
                logger.info('# Validation # PSNR: {:.4e}'.format(avg_psnr))
                logger_val = logging.getLogger('val')  # validation logger
                logger_val.info('<epoch:{:3d}, iter:{:8,d}> psnr: {:.4e}'.format(epoch, current_step, avg_psnr))
                # tensorboard logger
                if tb_logger is not None:
                    tb_logger.add_scalar('psnr', avg_psnr, current_step)


            #### save models and training states
            if current_step % opt_F['logger']['save_checkpoint_freq'] == 0:
                if rank <= 0:
                    logger.info('Saving models and training states.')
                    model_F.save(current_step)
                    model_F.save_training_state(epoch, current_step)

    if rank <= 0:
        logger.info('Saving the final model.')
        model_F.save('latest')
        logger.info('End of SFTMD training.')
    # Flush and release the TensorBoard writer if this process created one.
    if tb_logger is not None:
        tb_logger.close()
Ejemplo n.º 3
0
def generate_mod_LR_bic():
    """Build HR / LR / bicubic / blurred-LR test images and their kernel maps.

    For every ``.png`` file in the hard-coded source directory this:
      * crops the image so both sides are multiples of ``mod_scale`` (HR),
      * produces a blurred, downsampled image via ``util.SRMDPreprocessing``
        (LRblur) together with its kernel map,
      * produces a resized LR image with ``imresize_np`` and resizes it back
        up (Bic),
      * writes all four images under the save directory.

    The stacked kernel maps are saved to ``./BSDS100_kermap.pth``.

    Exits the process with status 1 if the source directory is missing.

    Raises:
        IOError: if an image listed in the source directory cannot be decoded.
    """
    # Local import keeps this block self-contained; sys.exit replaces the
    # site-provided ``exit`` helper, which is not guaranteed to exist.
    import sys

    # set parameters
    up_scale = 4
    mod_scale = 4
    # set data dir
    sourcedir = '/mnt/yjchai/SR_data/BSDS100'  #'/mnt/yjchai/SR_data/DIV2K_test_HR' #'/mnt/yjchai/SR_data/Flickr2K/Flickr2K_HR'
    savedir = '/mnt/yjchai/SR_data/BSDS100_test'  #'/mnt/yjchai/SR_data/DIV2K_test' #'/mnt/yjchai/SR_data/Flickr2K_train'

    # set random seed
    util.set_random_seed(0)

    # load PCA matrix of enough kernel
    print('load PCA matrix')
    pca_matrix = torch.load('/media/sdc/yjchai/IKC/codes/pca_matrix.pth')
    print('PCA matrix shape: {}'.format(pca_matrix.shape))

    saveHRpath = os.path.join(savedir, 'HR', 'x' + str(mod_scale))
    saveLRpath = os.path.join(savedir, 'LR', 'x' + str(up_scale))
    saveBicpath = os.path.join(savedir, 'Bic', 'x' + str(up_scale))
    saveLRblurpath = os.path.join(savedir, 'LRblur', 'x' + str(up_scale))

    if not os.path.isdir(sourcedir):
        print('Error: No source data found')
        # Non-zero status: a missing source directory is a failure.
        sys.exit(1)

    # makedirs also creates savedir and the per-type parent folders, so the
    # whole output tree is built in one pass.
    for out_dir in (saveHRpath, saveLRpath, saveBicpath, saveLRblurpath):
        if os.path.isdir(out_dir):
            print('It will cover ' + str(out_dir))
        else:
            os.makedirs(out_dir)

    filepaths = sorted(
        [f for f in os.listdir(sourcedir) if f.endswith('.png')])
    print(filepaths)
    num_files = len(filepaths)

    kernel_map_tensor = torch.zeros(
        (num_files, 1, 10))  # each kernel map: 1*10

    # prepare data with augementation
    for i, filename in enumerate(filepaths):
        print('No.{} -- Processing {}'.format(i, filename))
        # read image
        image_path = os.path.join(sourcedir, filename)
        image = cv2.imread(image_path)
        if image is None:
            # cv2.imread signals failure by returning None, not by raising
            raise IOError('Failed to read image: {}'.format(image_path))

        width = int(np.floor(image.shape[1] / mod_scale))
        height = int(np.floor(image.shape[0] / mod_scale))
        # modcrop
        if len(image.shape) == 3:
            image_HR = image[0:mod_scale * height, 0:mod_scale * width, :]
        else:
            image_HR = image[0:mod_scale * height, 0:mod_scale * width]
        img_HR = util.img2tensor(image_HR)
        C, H, W = img_HR.size()
        # LR_blur, by random gaussian kernel
        prepro = util.SRMDPreprocessing(up_scale,
                                        pca_matrix,
                                        para_input=10,
                                        kernel=21,
                                        noise=False,
                                        cuda=True,
                                        sig_min=0.2,
                                        sig_max=4.0,
                                        rate_iso=1.0,
                                        scaling=3,
                                        rate_cln=0.2,
                                        noise_high=0.0)
        # add a leading batch dimension of 1 before preprocessing
        LR_img, ker_map = prepro(img_HR.view(1, C, H, W))
        image_LR_blur = util.tensor2img(LR_img)
        # LR
        image_LR = imresize_np(image_HR, 1 / up_scale, True)
        # bic
        image_Bic = imresize_np(image_LR, up_scale, True)

        cv2.imwrite(os.path.join(saveHRpath, filename), image_HR)
        cv2.imwrite(os.path.join(saveLRpath, filename), image_LR)
        cv2.imwrite(os.path.join(saveBicpath, filename), image_Bic)
        cv2.imwrite(os.path.join(saveLRblurpath, filename), image_LR_blur)

        kernel_map_tensor[i] = ker_map
    # save dataset corresponding kernel maps
    torch.save(kernel_map_tensor, './BSDS100_kermap.pth')
    print("Image Blurring & Down smaple Done: X" + str(up_scale))
Ejemplo n.º 4
0
def generate_mod_LR_bic():
    """Generate blurred-LR, bicubic-LR and bicubic-upsampled copies of a PNG set.

    Every ``.png`` file found in ``sourcedir`` is mod-cropped so that its
    height and width are multiples of ``mod_scale``.  For each image the
    function writes, below ``savedir``:

    * ``LRblur/x<up_scale>/sig<sig>_<name>`` -- the output of
      ``util.SRMDPreprocessing`` (one file per value of ``sig``),
    * ``HR/x<mod_scale>/sig<sig>_<name>`` -- the mod-cropped image, stored
      once per ``sig`` under the same name as the LRblur file,
    * ``LR/x<up_scale>/<name>`` -- ``imresize_np`` downscale by
      ``1 / up_scale``,
    * ``Bic/x<up_scale>/<name>`` -- the LR image resized back by ``up_scale``.

    Kernel maps returned by the preprocessing are not saved.

    Raises:
        SystemExit: with status 1 when ``sourcedir`` does not exist.
        IOError: when OpenCV cannot decode one of the source images.
    """
    # set parameters
    up_scale = 4
    mod_scale = 4
    # set data dir
    sourcedir = "/data/DIV2K_public/gt_k_x4"  #'/mnt/yjchai/SR_data/DIV2K_test_HR' #'/mnt/yjchai/SR_data/Flickr2K/Flickr2K_HR'
    savedir = "/data/DIV2KRK_public/x4HRblur.lmdb"  #'/mnt/yjchai/SR_data/DIV2K_test' #'/mnt/yjchai/SR_data/Flickr2K_train'

    # set random seed
    util.set_random_seed(0)

    # load PCA matrix of enough kernel
    print("load PCA matrix")
    pca_matrix = torch.load(
        "/data/IKC/pca_aniso_matrix.pth", map_location=lambda storage, loc: storage
    )
    print("PCA matrix shape: {}".format(pca_matrix.shape))

    saveHRpath = os.path.join(savedir, "HR", "x" + str(mod_scale))
    saveLRpath = os.path.join(savedir, "LR", "x" + str(up_scale))
    saveBicpath = os.path.join(savedir, "Bic", "x" + str(up_scale))
    saveLRblurpath = os.path.join(savedir, "LRblur", "x" + str(up_scale))

    if not os.path.isdir(sourcedir):
        print("Error: No source data found")
        # Exit with a failure status so calling scripts can detect the error.
        raise SystemExit(1)

    # os.makedirs also creates savedir and the intermediate HR/LR/Bic/LRblur
    # folders, so a single call per leaf directory is enough.
    for out_dir in (saveHRpath, saveLRpath, saveBicpath, saveLRblurpath):
        if os.path.isdir(out_dir):
            print("It will cover " + str(out_dir))
        else:
            os.makedirs(out_dir)

    filepaths = sorted(f for f in os.listdir(sourcedir) if f.endswith(".png"))
    print(filepaths)

    # prepare data with augmentation
    for i, filename in enumerate(filepaths):
        print("No.{} -- Processing {}".format(i, filename))
        # read image
        src_path = os.path.join(sourcedir, filename)
        image = cv2.imread(src_path)
        if image is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise IOError("Failed to read image: {}".format(src_path))

        # modcrop: trim height/width down to multiples of mod_scale.  Slicing
        # only the first two axes works for both 2-D and 3-D arrays.
        height = image.shape[0] // mod_scale
        width = image.shape[1] // mod_scale
        image_HR = image[: mod_scale * height, : mod_scale * width]

        # LR_blur, by random gaussian kernel
        img_HR = util.img2tensor(image_HR)
        C, H, W = img_HR.size()
        # NOTE(review): `sig` is only used to label the output files below;
        # the preprocessing is built with random=True and sig=0, so the blur
        # presumably does not depend on the loop value -- TODO confirm.
        for sig in np.linspace(1.8, 3.2, 8):
            prepro = util.SRMDPreprocessing(
                up_scale,
                pca_matrix,
                random=True,
                para_input=10,
                kernel=11,
                noise=False,
                cuda=True,
                sig=0,
                sig_min=0.6,
                sig_max=5,
                rate_iso=0,
                scaling=3,
                rate_cln=0.2,
                noise_high=0.0,
            )  # random(sig_min, sig_max) | stable kernel(sig)
            LR_img, _ = prepro(img_HR.view(1, C, H, W))
            image_LR_blur = util.tensor2img(LR_img)
            out_name = "sig{}_{}".format(sig, filename)
            cv2.imwrite(os.path.join(saveLRblurpath, out_name), image_LR_blur)
            cv2.imwrite(os.path.join(saveHRpath, out_name), image_HR)

        # LR
        image_LR = imresize_np(image_HR, 1 / up_scale, True)
        # bic
        image_Bic = imresize_np(image_LR, up_scale, True)

        cv2.imwrite(os.path.join(saveLRpath, filename), image_LR)
        cv2.imwrite(os.path.join(saveBicpath, filename), image_Bic)

    print("Image Blurring & Down smaple Done: X" + str(up_scale))
Ejemplo n.º 5
0
def main():
    """Entry point: train the network configured by the ``-opt`` YAML file.

    Steps, in order: parse the command line and the option file, seed the
    RNGs, load the PCA matrix from ``opt["pca_matrix_path"]``, optionally
    initialise distributed training, load a resume state when one is
    configured, set up loggers (plus TensorBoard on ``rank <= 0``), build the
    train/val dataloaders and the model, then run the training loop with
    periodic logging, validation (average PSNR over the ``val`` set) and
    checkpointing.

    ``rank`` is -1 for non-distributed runs.  Experiment-folder creation,
    validation and checkpointing only happen when ``rank <= 0``.
    """
    #### setup options
    parser = argparse.ArgumentParser()
    parser.add_argument("-opt",
                        type=str,
                        help="Path to option YMAL file of Predictor.")
    parser.add_argument("--launcher",
                        choices=["none", "pytorch"],
                        default="none",
                        help="job launcher")
    parser.add_argument("--local_rank", type=int, default=0)
    args = parser.parse_args()
    opt = option.parse(args.opt, is_train=True)

    # convert to NoneDict, which returns None for missing keys
    opt = option.dict_to_nonedict(opt)

    #### set random seed
    seed = opt["train"]["manual_seed"]
    if seed is None:
        seed = random.randint(1, 10000)
    util.set_random_seed(seed)

    # load PCA matrix of enough kernel
    print("load PCA matrix")
    pca_matrix = torch.load(opt["pca_matrix_path"],
                            map_location=lambda storage, loc: storage)
    print("PCA matrix shape: {}".format(pca_matrix.shape))

    #### distributed training settings
    if args.launcher == "none":  # disabled distributed training
        opt["dist"] = False
        rank = -1
        print("Disabled distributed training.")
    else:
        opt["dist"] = True
        init_dist()
        # number of processes in the current process group
        world_size = torch.distributed.get_world_size()
        # rank of this process within the group
        rank = torch.distributed.get_rank()

    torch.backends.cudnn.benchmark = True
    # torch.backends.cudnn.deterministic = True

    #### loading resume state if exists
    if opt["path"].get("resume_state", None):
        # distributed resuming: all load into default GPU
        device_id = torch.cuda.current_device()
        resume_state = torch.load(
            opt["path"]["resume_state"],
            map_location=lambda storage, loc: storage.cuda(device_id),
        )
        option.check_resume(opt, resume_state["iter"])  # check resume options
    else:
        resume_state = None

    #### mkdir and loggers
    # Stays None unless a TensorBoard writer is actually created below, so
    # every later use (including the final close) can be guarded safely.
    tb_logger = None
    if rank <= 0:  # normal training (rank -1) OR distributed training (rank 0)
        if resume_state is None:
            # rename experiment folder if exists
            util.mkdir_and_rename(opt["path"]["experiments_root"])
            util.mkdirs(
                (path for key, path in opt["path"].items()
                 if not key == "experiments_root"
                 and "pretrain_model" not in key and "resume" not in key))
            # Re-point ./log at the parent of the experiment folder; drop a
            # stale entry first (lexists also sees dangling symlinks).
            if os.path.lexists("./log"):
                os.remove("./log")
            os.symlink(os.path.join(opt["path"]["experiments_root"], ".."),
                       "./log")

        # config loggers. Before it, the log will not work
        util.setup_logger(
            "base",
            opt["path"]["log"],
            "train_" + opt["name"],
            level=logging.INFO,
            screen=True,
            tofile=True,
        )
        util.setup_logger(
            "val",
            opt["path"]["log"],
            "val_" + opt["name"],
            level=logging.INFO,
            screen=True,
            tofile=True,
        )
        logger = logging.getLogger("base")
        logger.info(option.dict2str(opt))
        # tensorboard logger
        if opt["use_tb_logger"] and "debug" not in opt["name"]:
            version = float(torch.__version__[0:3])
            if version >= 1.1:  # PyTorch 1.1
                from torch.utils.tensorboard import SummaryWriter
            else:
                logger.info(
                    "You are using PyTorch {}. Tensorboard will use [tensorboardX]"
                    .format(version))
                from tensorboardX import SummaryWriter
            tb_logger = SummaryWriter(log_dir="log/tb_logger/" + opt["name"])
    else:
        util.setup_logger("base",
                          opt["path"]["log"],
                          "train",
                          level=logging.INFO,
                          screen=True)
        logger = logging.getLogger("base")

    #### create train and val dataloader
    dataset_ratio = 200  # enlarge the size of each epoch
    # Pre-bind so the asserts below fire (instead of a NameError) when a
    # phase is missing from the option file.
    train_loader = None
    val_loader = None
    for phase, dataset_opt in opt["datasets"].items():
        if phase == "train":
            train_set = create_dataset(dataset_opt)
            train_size = int(
                math.ceil(len(train_set) / dataset_opt["batch_size"]))
            total_iters = int(opt["train"]["niter"])
            total_epochs = int(math.ceil(total_iters / train_size))
            if opt["dist"]:
                train_sampler = DistIterSampler(train_set, world_size, rank,
                                                dataset_ratio)
                total_epochs = int(
                    math.ceil(total_iters / (train_size * dataset_ratio)))
            else:
                train_sampler = None
            train_loader = create_dataloader(train_set, dataset_opt, opt,
                                             train_sampler)
            if rank <= 0:
                logger.info(
                    "Number of train images: {:,d}, iters: {:,d}".format(
                        len(train_set), train_size))
                logger.info("Total epochs needed: {:d} for iters {:,d}".format(
                    total_epochs, total_iters))
        elif phase == "val":
            val_set = create_dataset(dataset_opt)
            val_loader = create_dataloader(val_set, dataset_opt, opt, None)
            if rank <= 0:
                logger.info("Number of val images in [{:s}]: {:d}".format(
                    dataset_opt["name"], len(val_set)))
        else:
            raise NotImplementedError(
                "Phase [{:s}] is not recognized.".format(phase))
    assert train_loader is not None
    assert val_loader is not None

    #### create model
    model = create_model(opt)

    #### resume training
    if resume_state:
        logger.info("Resuming training from epoch: {}, iter: {}.".format(
            resume_state["epoch"], resume_state["iter"]))

        start_epoch = resume_state["epoch"]
        current_step = resume_state["iter"]
        model.resume_training(resume_state)  # handle optimizers and schedulers
    else:
        current_step = 0
        start_epoch = 0

    # Builds (LR image, kernel map) pairs from the GT batches during training.
    prepro = util.SRMDPreprocessing(
        opt["scale"],
        pca_matrix,
        random=True,
        para_input=opt["code_length"],
        kernel=opt["kernel_size"],
        noise=False,
        cuda=True,
        sig=None,
        sig_min=opt["sig_min"],
        sig_max=opt["sig_max"],
        rate_iso=1.0,
        scaling=3,
        rate_cln=0.2,
        noise_high=0.0,
    )

    #### training
    logger.info("Start training from epoch: {:d}, iter: {:d}".format(
        start_epoch, current_step))
    for epoch in range(start_epoch, total_epochs + 1):
        if opt["dist"]:
            train_sampler.set_epoch(epoch)
        for _, train_data in enumerate(train_loader):
            current_step += 1
            if current_step > total_iters:
                break
            #### preprocessing for LR_img and kernel map
            LR_img, ker_map = prepro(train_data["GT"])
            # Snap values to multiples of 1/255 (presumably to mimic 8-bit
            # image storage -- TODO confirm).
            LR_img = (LR_img * 255).round() / 255
            #### training step
            model.feed_data(LR_img, train_data["GT"], ker_map)
            model.optimize_parameters(current_step)
            model.update_learning_rate(current_step,
                                       warmup_iter=opt["train"]["warmup_iter"])
            # NOTE(review): the result is not used before validation
            # overwrites it; kept in case the call has side effects.
            visuals = model.get_current_visuals()

            #### log
            if current_step % opt["logger"]["print_freq"] == 0:
                logs = model.get_current_log()
                message = "Predictor <epoch:{:3d}, iter:{:8,d}, lr:{:.3e}> ".format(
                    epoch, current_step, model.get_current_learning_rate())
                for k, v in logs.items():
                    message += "{:s}: {:.4e} ".format(k, v)
                    # tensorboard logger
                    if tb_logger is not None:
                        tb_logger.add_scalar(k, v, current_step)
                if rank <= 0:
                    logger.info(message)

            # validation
            if current_step % opt["train"]["val_freq"] == 0 and rank <= 0:
                avg_psnr = 0.0
                idx = 0
                for _, val_data in enumerate(val_loader):

                    # LR images come from the val set itself rather than
                    # from prepro.
                    LR_img = val_data["LQ"]
                    lr_img = util.tensor2img(
                        LR_img)  # save LR image for reference

                    model.feed_data(LR_img, val_data["GT"])
                    model.test()
                    visuals = model.get_current_visuals()

                    # Save images for reference
                    img_name = os.path.splitext(
                        os.path.basename(val_data["LQ_path"][0]))[0]
                    img_dir = os.path.join(opt["path"]["val_images"], img_name)
                    util.mkdir(img_dir)
                    save_lr_path = os.path.join(img_dir,
                                                "{:s}_LR.png".format(img_name))
                    util.save_img(lr_img, save_lr_path)

                    sr_img = util.tensor2img(visuals["SR"])  # uint8
                    gt_img = util.tensor2img(visuals["GT"])  # uint8

                    save_img_path = os.path.join(
                        img_dir,
                        "{:s}_{:d}.png".format(img_name, current_step))
                    util.save_img(sr_img, save_img_path)

                    # calculate PSNR on images with a `scale`-pixel border
                    # removed
                    crop_size = opt["scale"]
                    gt_img = gt_img / 255.0
                    sr_img = sr_img / 255.0
                    cropped_sr_img = sr_img[crop_size:-crop_size,
                                            crop_size:-crop_size, :]
                    cropped_gt_img = gt_img[crop_size:-crop_size,
                                            crop_size:-crop_size, :]

                    avg_psnr += util.calculate_psnr(cropped_sr_img * 255,
                                                    cropped_gt_img * 255)
                    idx += 1

                avg_psnr = avg_psnr / idx

                # log
                logger.info("# Validation # PSNR: {:.6f}".format(avg_psnr))
                logger_val = logging.getLogger("val")  # validation logger
                logger_val.info(
                    "<epoch:{:3d}, iter:{:8,d}, psnr: {:.6f}".format(
                        epoch, current_step, avg_psnr))
                # tensorboard logger
                if tb_logger is not None:
                    tb_logger.add_scalar("psnr", avg_psnr, current_step)

            #### save models and training states
            if current_step % opt["logger"]["save_checkpoint_freq"] == 0:
                if rank <= 0:
                    logger.info("Saving models and training states.")
                    model.save(current_step)
                    model.save_training_state(epoch, current_step)

    if rank <= 0:
        logger.info("Saving the final model.")
        model.save("latest")
        logger.info("End of Predictor and Corrector training.")
    # The writer only exists on rank <= 0 with TensorBoard enabled; closing
    # it unconditionally raised NameError in every other configuration.
    if tb_logger is not None:
        tb_logger.close()
Ejemplo n.º 6
0
    test_results['psnr_y'] = []
    test_results['ssim_y'] = []

    for test_data in test_loader:
        need_GT = False if test_loader.dataset.opt[
            'dataroot_GT'] is None else True
        img_path = test_data['GT_path'][0] if need_GT else test_data[
            'LQ_path'][0]
        img_name = os.path.splitext(os.path.basename(img_path))[0]
        #### preprocessing for LR_img and kernel map
        prepro = util.SRMDPreprocessing(opt_F['scale'],
                                        pca_matrix,
                                        para_input=15,
                                        noise=False,
                                        cuda=True,
                                        sig_min=0.2,
                                        sig_max=4.0,
                                        rate_iso=1.0,
                                        scaling=3,
                                        rate_cln=0.2,
                                        noise_high=0.0)
        LR_img, ker_map = prepro(test_data['GT'])

        model_F.feed_data(test_data, LR_img, ker_map)
        model_F.test()

        F_visuals = model_F.get_current_visuals()

        sr_img = util.tensor2img(F_visuals['SR'])  # uint8

        # save images
Ejemplo n.º 7
0
    def train(self):
        """Run one training epoch and save the resulting weights.

        The epoch number is ``self.scheduler.last_epoch + 1``.  While it is
        at most ``args.epochs_encoder`` only ``self.model_E`` is trained with
        the contrastive loss; afterwards ``self.model`` is trained with the
        sum of the contrastive loss and ``self.loss``.  The learning rate is
        set by hand on every optimizer param group using a step-wise decay
        rule chosen by phase.

        After the epoch, the state dict of ``self.model.get_model()`` is
        written to ``<ckp.dir>/model/model_<epoch>.pt`` with every entry whose
        key contains ``'E.encoder_k'`` or ``'queue'`` removed.
        """
        self.scheduler.step()
        self.loss.step()
        epoch = self.scheduler.last_epoch + 1
        args = self.args
        encoder_only = epoch <= args.epochs_encoder

        # Step-wise learning rate: one decay rule per training phase.
        if encoder_only:
            decay_steps = epoch // args.lr_decay_encoder
            lr = args.lr_encoder * (args.gamma_encoder ** decay_steps)
        else:
            decay_steps = (epoch - args.epochs_encoder) // args.lr_decay_sr
            lr = args.lr_sr * (args.gamma_sr ** decay_steps)
        for group in self.optimizer.param_groups:
            group['lr'] = lr

        self.ckp.write_log(
            '[Epoch {}]\tLearning rate: {:.2e}'.format(epoch, Decimal(lr)))
        self.loss.start_log()
        self.model.train()

        # Callable that turns an HR batch into (LR batch, blur kernels).
        degrade = util.SRMDPreprocessing(
            self.scale[0],
            kernel_size=args.blur_kernel,
            blur_type=args.blur_type,
            sig_min=args.sig_min,
            sig_max=args.sig_max,
            lambda_min=args.lambda_min,
            lambda_max=args.lambda_max,
            noise=args.noise,
        )

        timer = utility.timer()
        losses_contrast = utility.AverageMeter()
        losses_sr = utility.AverageMeter()

        for batch, (hr, _, _idx_scale) in enumerate(self.loader_train):
            hr = hr.cuda()  # (b, n, c, h, w) per the original annotation
            lr_imgs, _kernels = degrade(hr)

            self.optimizer.zero_grad()
            timer.tic()

            # forward
            if encoder_only:
                # Encoder phase: the two slices along dim 1 act as the
                # query / key inputs of model_E.
                _, output, target = self.model_E(im_q=lr_imgs[:, 0, ...],
                                                 im_k=lr_imgs[:, 1, ...])
                contrast_term = self.contrast_loss(output, target)
                loss = contrast_term

                losses_contrast.update(contrast_term.item())
            else:
                # Full-network phase: SR loss against the first HR slice plus
                # the contrastive loss.
                sr, output, target = self.model(lr_imgs)
                sr_term = self.loss(sr, hr[:, 0, ...])
                contrast_term = self.contrast_loss(output, target)
                loss = contrast_term + sr_term

                losses_sr.update(sr_term.item())
                losses_contrast.update(contrast_term.item())

            # backward
            loss.backward()
            self.optimizer.step()
            timer.hold()

            # Progress line every `print_every` batches.
            if (batch + 1) % args.print_every != 0:
                continue
            seen = (batch + 1) * args.batch_size
            total = len(self.loader_train.dataset)
            if encoder_only:
                self.ckp.write_log(
                    'Epoch: [{:03d}][{:04d}/{:04d}]\t'
                    'Loss [contrastive loss: {:.3f}]\t'
                    'Time [{:.1f}s]'.format(epoch, seen, total,
                                            losses_contrast.avg,
                                            timer.release()))
            else:
                self.ckp.write_log(
                    'Epoch: [{:04d}][{:04d}/{:04d}]\t'
                    'Loss [SR loss:{:.3f} | contrastive loss: {:.3f}]\t'
                    'Time [{:.1f}s]'.format(epoch, seen, total,
                                            losses_sr.avg,
                                            losses_contrast.avg,
                                            timer.release()))

        self.loss.end_log(len(self.loader_train))

        # save model, minus the 'E.encoder_k' and 'queue' entries
        net = self.model.get_model()
        model_dict = net.state_dict()
        dropped = [
            name for name in model_dict
            if 'E.encoder_k' in name or 'queue' in name
        ]
        for name in dropped:
            del model_dict[name]
        save_path = os.path.join(self.ckp.dir, 'model',
                                 'model_{}.pt'.format(epoch))
        torch.save(model_dict, save_path)
Ejemplo n.º 8
0
    def test(self):
        """Evaluate the model on ``self.loader_test`` and log PSNR / SSIM.

        For every entry of ``self.scale`` the test set is degraded with a
        fixed (``random=False``) ``util.SRMDPreprocessing`` instance,
        super-resolved under ``torch.no_grad()``, quantized, and scored with
        ``utility.calc_psnr`` / ``utility.calc_ssim``.  The mean PSNR is
        stored in the last row of ``self.ckp.log`` and both means are written
        to the checkpoint log.  SR outputs are saved when
        ``args.save_results`` is set.
        """
        self.ckp.write_log('\nEvaluation:')
        self.ckp.add_log(torch.zeros(1, len(self.scale)))
        self.model.eval()

        timer_test = utility.timer()
        args = self.args

        with torch.no_grad():
            for idx_scale, scale in enumerate(self.scale):
                self.loader_test.dataset.set_scale(idx_scale)
                psnr_sum = 0
                ssim_sum = 0

                # Degradation with the fixed parameters taken from args.
                degrade = util.SRMDPreprocessing(
                    self.scale[0],
                    kernel_size=args.blur_kernel,
                    blur_type=args.blur_type,
                    sig=args.sig,
                    lambda_1=args.lambda_1,
                    lambda_2=args.lambda_2,
                    theta=args.theta,
                    noise=args.noise,
                )

                for hr, filename, _ in self.loader_test:
                    # (b, 1, c, h, w) per the original annotation
                    hr = self.crop_border(hr.cuda(), scale)
                    lr, _ = degrade(hr, random=False)
                    hr = hr[:, 0, ...]  # drop dim 1 -> (b, c, h, w)

                    # inference (only this call is timed)
                    timer_test.tic()
                    sr = self.model(lr[:, 0, ...])
                    timer_test.hold()

                    sr = utility.quantize(sr, args.rgb_range)
                    hr = utility.quantize(hr, args.rgb_range)

                    # metrics
                    is_benchmark = self.loader_test.dataset.benchmark
                    psnr_sum += utility.calc_psnr(sr,
                                                  hr,
                                                  scale,
                                                  args.rgb_range,
                                                  benchmark=is_benchmark)
                    ssim_sum += utility.calc_ssim(sr,
                                                  hr,
                                                  scale,
                                                  benchmark=is_benchmark)

                    # save results
                    if args.save_results:
                        self.ckp.save_results(filename[0], [sr], scale)

                num_batches = len(self.loader_test)
                mean_psnr = psnr_sum / num_batches
                mean_ssim = ssim_sum / num_batches

                self.ckp.log[-1, idx_scale] = mean_psnr
                self.ckp.write_log(
                    '[Epoch {}---{} x{}]\tPSNR: {:.3f} SSIM: {:.4f}'.format(
                        args.resume, args.data_test, scale, mean_psnr,
                        mean_ssim))
Ejemplo n.º 9
0
def degradation_model():
    """Write blurred, downsampled copies of the ``.bmp`` images in ``sourcedir``.

    A PCA matrix is fitted on 1000 kernels drawn by
    ``util.random_batch_kernel`` and handed to ``util.SRMDPreprocessing``
    (``random=False``, ``sig=0.6``, kernel size 15).  Each source image is
    mod-cropped to a multiple of ``mod_scale``, passed through that
    preprocessing, and saved as
    ``<savedir>/LRblur_sig0.6/X<scale>/sig0.6_<name>``.

    Only the blurred LR images are written: no HR, plain-LR or bicubic
    copies are produced and the returned kernel maps are not saved.

    Raises:
        SystemExit: with status 1 when ``sourcedir`` does not exist.
        IOError: when OpenCV cannot decode one of the source images.
    """
    # set parameters
    scale = 4
    mod_scale = 4
    sig = 0.6  # kernel width passed to the preprocessing and used in file names

    # set data dir
    sourcedir = 'D:/NEU/ImageRestoration/div2k/Set5/HR/'
    savedir = 'D:/NEU/ImageRestoration/div2k/Set5/'
    # set random seed
    util.set_random_seed(0)

    # Build the PCA matrix from freshly sampled kernels instead of loading a
    # precomputed one from disk.
    batch_ker = util.random_batch_kernel(batch=1000,
                                         k=15,
                                         sig_min=0.2,
                                         sig_max=4.0,
                                         rate_iso=1.0,
                                         scaling=3,
                                         tensor=False)
    print('batch kernel shape: {}'.format(batch_ker.shape))
    b = np.size(batch_ker, 0)
    batch_ker = batch_ker.reshape((b, -1))  # flatten each kernel to one row

    dim_pca = 15
    # original author's note: torch.Size([225, 15])
    pca_matrix = util.PCA(batch_ker, dim_pca).float()
    print('PCA matrix shape: {}'.format(pca_matrix.shape))

    saveLRblurpath = os.path.join(savedir, 'LRblur_sig0.6', 'X' + str(scale))

    if not os.path.isdir(sourcedir):
        print('Error: No source data found')
        # Exit with a failure status so calling scripts can detect the error.
        raise SystemExit(1)

    # os.makedirs also creates savedir and the LRblur_sig0.6 folder.
    if os.path.isdir(saveLRblurpath):
        print('It will cover ' + str(saveLRblurpath))
    else:
        os.makedirs(saveLRblurpath)

    # Remember to change the extension filter to match the source images,
    # otherwise nothing is generated.
    filepaths = sorted(f for f in os.listdir(sourcedir) if f.endswith('.bmp'))

    # prepare data with augmentation
    for i, filename in enumerate(filepaths):
        print('No.{} -- Processing {}'.format((i + 1), filename))
        # read image (once; the file used to be decoded twice)
        src_path = os.path.join(sourcedir, filename)
        image = cv2.imread(src_path)
        if image is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise IOError('Failed to read image: {}'.format(src_path))

        # modcrop: trim height/width down to multiples of mod_scale.  Slicing
        # only the first two axes works for both 2-D and 3-D arrays.
        height = image.shape[0] // mod_scale
        width = image.shape[1] // mod_scale
        image_HR = image[:mod_scale * height, :mod_scale * width]

        # LR_blur, by gaussian kernel
        img_HR = util.img2tensor(image_HR)
        C, H, W = img_HR.size()

        prepro = util.SRMDPreprocessing(
            scale,
            pca_matrix,
            random=False,
            kernel=15,
            noise=False,
            cuda=True,
            sig=sig,
            sig_min=0.2,
            sig_max=4.0,
            rate_iso=1.0,
            scaling=4,
            rate_cln=0.2,
            noise_high=0.0)  # random(sig_min, sig_max) | stable kernel(sig)

        LR_img, _ = prepro(img_HR.view(1, C, H, W))
        image_LR_blur = util.tensor2img(LR_img)
        cv2.imwrite(
            os.path.join(saveLRblurpath, 'sig{}_{}'.format(sig, filename)),
            image_LR_blur)

    # Original author's note: on Windows, paths must not contain special
    # characters.
    print("Image Blurring & Down smaple Done: X" + str(scale))
Ejemplo n.º 10
0
    test_results['ssim_y'] = []

    for test_data in test_loader:
        need_GT = False if test_loader.dataset.opt[
            'dataroot_GT'] is None else True
        img_path = test_data['GT_path'][0] if need_GT else test_data[
            'LQ_path'][0]
        img_name = os.path.splitext(os.path.basename(img_path))[0]
        #### preprocessing for LR_img and kernel map
        prepro = util.SRMDPreprocessing(
            opt_F['scale'],
            pca_matrix,
            random=False,
            para_input=opt_F['code_length'],
            noise=False,
            cuda=True,
            sig=opt_F['sig'],
            sig_min=opt_F['sig_min'],
            sig_max=opt_F['sig_max'],
            rate_iso=1.0,
            scaling=3,
            rate_cln=0.2,
            noise_high=0.0)  # random(sig_min, sig_max) | stable kernel(sig)
        LR_img, ker_map = prepro(test_data['GT'])

        model_F.feed_data(test_data, LR_img, ker_map)
        model_F.test()

        F_visuals = model_F.get_current_visuals()

        sr_img = util.tensor2img(F_visuals['SR'])  # uint8