Example #1
def vimeo(img_root, lmdb_save_path):
    """Create lmdb for the vimeo dataset, each image with a fixed size
    GT: [3, 256, 448],     key: 00001_0001_4
    """
    #### configurations
    BATCH = 50000
    n_thread = 40
    ########################################################
    if not lmdb_save_path.endswith('.lmdb'):
        raise ValueError("lmdb_save_path must end with '.lmdb'.")
    if osp.exists(lmdb_save_path):
        print('Folder [{:s}] already exists. Exit...'.format(lmdb_save_path))
        sys.exit(1)

    #### read all the image paths to a list
    print('Reading image path list ...')
    txt_file = osp.join(img_root, 'sep_trainlist.txt')
    with open(txt_file, 'r') as f:
        lines = f.readlines()
        img_list = [line.strip() for line in lines]

    imgs, keys = [], []
    for item in img_list:
        key_pre = item.replace('/', '_')
        im_dir = osp.join(img_root, 'sequences', item)
        names = sorted(os.listdir(im_dir))
        for name in names:
            imgs.append(osp.join(im_dir, name))
            keys.append(key_pre + '_' + name[2])  # e.g. 'im4.png' -> frame index '4'
    im1 = cv2.imread(imgs[0], cv2.IMREAD_UNCHANGED)
    H, W, C = im1.shape
    print('data size per image is: ', im1.nbytes)
    data_size = im1.nbytes * len(imgs)
    env = lmdb.open(lmdb_save_path, map_size=data_size * 10)

    #### write data to lmdb
    txn = env.begin(write=True)
    for i in range(0, len(imgs), BATCH):
        batch_imgs = imgs[i:i + BATCH]
        batch_keys = keys[i:i + BATCH]
        batch_data = read_imgs_multi_thread(batch_imgs, batch_keys, n_thread)
        pbar = util.ProgressBar(len(batch_imgs))
        for k, v in batch_data.items():
            pbar.update('Write {}'.format(k))
            key_byte = k.encode('ascii')
            txn.put(key_byte, v)
        txn.commit()
        txn = env.begin(write=True)
    txn.commit()
    env.close()
    print('Finish writing lmdb.')

    #### create meta information
    meta_info = {}
    meta_info['name'] = 'vimeo_train'
    meta_info['resolution'] = '{}_{}_{}'.format(C, H, W)
    meta_info['keys'] = keys
    pickle.dump(meta_info, open(osp.join(lmdb_save_path, 'meta_info.pkl'),
                                "wb"))
    print('Finish creating lmdb meta info.')
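For context on how such an lmdb is consumed, here is a minimal read-back sketch. It assumes entries were written as raw uint8 image buffers by the function above and that meta_info.pkl carries the resolution string in C_H_W order; the function name is illustrative.

import pickle
import lmdb
import numpy as np
import os.path as osp

def read_vimeo_entry(lmdb_path, key):
    """Load one frame (e.g. key '00001_0001_4') back from the lmdb written above."""
    with open(osp.join(lmdb_path, 'meta_info.pkl'), 'rb') as f:
        meta = pickle.load(f)
    C, H, W = [int(x) for x in meta['resolution'].split('_')]
    env = lmdb.open(lmdb_path, readonly=True, lock=False, readahead=False)
    with env.begin(write=False) as txn:
        buf = txn.get(key.encode('ascii'))
    env.close()
    # Entries are raw image buffers, so a reshape restores the HWC layout.
    return np.frombuffer(buf, dtype=np.uint8).reshape(H, W, C)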
Example #2
def vimeo_test(img_root, lmdb_save_path):
    gt_root = osp.join(img_root, 'target')
    lq_root = osp.join(img_root, 'low_resolution')
    txt_file = osp.join(img_root, 'sep_testlist.txt')
    with open(txt_file, 'r') as f:
        lines = f.readlines()
        img_list = [line.strip() for line in lines]
    imgs, keys = [], []
    for item in img_list:
        gt_key = 'gt_' + item.replace('/', '_') + '_4'
        lq_key_pre = 'lq_' + item.replace('/', '_')
        gt_img = osp.join(gt_root, item, 'im4.png')
        imgs.append(gt_img)
        keys.append(gt_key)
        lq_im_dir = osp.join(lq_root, item)
        lq_names = sorted(os.listdir(lq_im_dir))
        for name in lq_names:
            imgs.append(osp.join(lq_im_dir, name))
            keys.append(lq_key_pre + '_' + name[2])

    im1 = cv2.imread(imgs[0], cv2.IMREAD_UNCHANGED)
    H, W, C = im1.shape
    im2 = cv2.imread(imgs[1], cv2.IMREAD_UNCHANGED)
    lH, lW, lC = im2.shape
    print('data size per image is: ', im1.nbytes)
    data_size = im1.nbytes * len(imgs)
    env = lmdb.open(lmdb_save_path, map_size=data_size * 10)

    #### write data to lmdb
    txn = env.begin(write=True)
    img_data = read_imgs_multi_thread(imgs, keys, 40)
    pbar = util.ProgressBar(len(imgs))
    for k, v in img_data.items():
        pbar.update('Write {}'.format(k))
        key_byte = k.encode('ascii')
        txn.put(key_byte, v)
    txn.commit()
    env.close()
    print('Finish writing lmdb.')

    #### create meta information
    meta_info = {}
    meta_info['name'] = 'vimeo_test'
    meta_info['gt_resolution'] = '{}_{}_{}'.format(C, H, W)
    meta_info['lq_resolution'] = '{}_{}_{}'.format(lC, lH, lW)
    meta_info['keys'] = keys
    pickle.dump(meta_info, open(osp.join(lmdb_save_path, 'meta_info.pkl'),
                                "wb"))
    print('Finish creating lmdb meta info.')
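The test lmdb mixes one GT frame and several LQ frames per sequence, told apart only by the gt_/lq_ key prefixes built above. A small helper sketch for reconstructing those keys, assuming the usual seven-frame Vimeo septuplets named im1.png..im7.png (the helper name is illustrative):

def vimeo_test_keys(item, n_frames=7):
    """Return the GT key and LQ keys written for one 'xxxxx/xxxx' list entry."""
    base = item.replace('/', '_')
    gt_key = 'gt_' + base + '_4'  # centre frame im4.png
    lq_keys = ['lq_' + base + '_' + str(i) for i in range(1, n_frames + 1)]
    return gt_key, lq_keys

# vimeo_test_keys('00001/0266') -> ('gt_00001_0266_4', ['lq_00001_0266_1', ...])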
Example #3
def read_imgs_multi_thread(imgs, keys, n_thread=40):
    #### read all images to memory (multiprocessing)
    dataset = {}  # store all image data. list cannot keep the order, use dict
    print('Read images with multiprocessing, #thread: {} ...'.format(n_thread))
    pbar = util.ProgressBar(len(imgs))

    def mycallback(arg):
        '''get the image data and update pbar'''
        key = arg[0]
        dataset[key] = arg[1]
        pbar.update('Reading {}'.format(key))

    pool = Pool(n_thread)
    for path, key in zip(imgs, keys):
        pool.apply_async(read_image_worker,
                         args=(path, key),
                         callback=mycallback)
    pool.close()
    pool.join()
    print('Finish reading {} images.'.format(len(imgs)))
    return dataset
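read_image_worker is referenced here but not shown in this section. A plausible minimal version, assuming it simply loads the image unchanged and echoes the key back so the callback can keep results ordered (this is an assumption, not the project's actual implementation):

import cv2

def read_image_worker(path, key):
    # Return (key, data) so mycallback can file the result under its key.
    img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
    return key, img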
Example #4
def main():
    #### options
    parser = argparse.ArgumentParser()
    parser.add_argument('-opt', type=str, help='Path to option YAML file.')
    parser.add_argument('--launcher',
                        choices=['none', 'pytorch'],
                        default='none',
                        help='job launcher')
    parser.add_argument('--local_rank', type=int, default=0)
    parser.add_argument('--exp_name', type=str, default='temp')
    parser.add_argument('--degradation_type', type=str, default=None)
    parser.add_argument('--sigma_x', type=float, default=None)
    parser.add_argument('--sigma_y', type=float, default=None)
    parser.add_argument('--theta', type=float, default=None)
    args = parser.parse_args()
    if args.exp_name == 'temp':
        opt = option.parse(args.opt, is_train=True)
    else:
        opt = option.parse(args.opt, is_train=True, exp_name=args.exp_name)

    # convert to NoneDict, which returns None for missing keys
    opt = option.dict_to_nonedict(opt)
    inner_loop_name = opt['train']['maml']['optimizer'][0] + str(
        opt['train']['maml']['adapt_iter']) + str(
            math.floor(math.log10(opt['train']['maml']['lr_alpha'])))
    meta_loop_name = opt['train']['optim'][0] + str(
        math.floor(math.log10(opt['train']['lr_G'])))

    if args.degradation_type is not None:
        if args.degradation_type == 'preset':
            opt['datasets']['val']['degradation_mode'] = args.degradation_type
        else:
            opt['datasets']['val']['degradation_type'] = args.degradation_type
    if args.sigma_x is not None:
        opt['datasets']['val']['sigma_x'] = args.sigma_x
    if args.sigma_y is not None:
        opt['datasets']['val']['sigma_y'] = args.sigma_y
    if args.theta is not None:
        opt['datasets']['val']['theta'] = args.theta
    if opt['datasets']['val']['degradation_mode'] == 'set':
        degradation_name = str(opt['datasets']['val']['degradation_type'])\
                  + '_' + str(opt['datasets']['val']['sigma_x']) \
                  + '_' + str(opt['datasets']['val']['sigma_y'])\
                  + '_' + str(opt['datasets']['val']['theta'])
    else:
        degradation_name = opt['datasets']['val']['degradation_mode']
    patch_name = 'p{}x{}'.format(
        opt['train']['maml']['patch_size'], opt['train']['maml']
        ['num_patch']) if opt['train']['maml']['use_patch'] else 'full'
    use_real_flag = '_ideal' if opt['train']['use_real'] else ''
    folder_name = opt[
        'name'] + '_' + degradation_name  # + '_' + inner_loop_name + meta_loop_name + '_' + degradation_name + '_' + patch_name + use_real_flag

    if args.exp_name != 'temp':
        folder_name = args.exp_name

    #### distributed training settings
    if args.launcher == 'none':  # disabled distributed training
        opt['dist'] = False
        rank = -1
        print('Disabled distributed training.')
    else:
        opt['dist'] = True
        init_dist()
        world_size = torch.distributed.get_world_size()
        rank = torch.distributed.get_rank()

    #### loading resume state if exists
    if opt['path'].get('resume_state', None):
        # distributed resuming: all load into default GPU
        device_id = torch.cuda.current_device()
        resume_state = torch.load(
            opt['path']['resume_state'],
            map_location=lambda storage, loc: storage.cuda(device_id))
        option.check_resume(opt, resume_state['iter'])  # check resume options
    else:
        resume_state = None

    #### mkdir and loggers
    if rank <= 0:  # normal training (rank -1) OR distributed training (rank 0)
        if resume_state is None:
            #util.mkdir_and_rename(
            #    opt['path']['experiments_root'])  # rename experiment folder if exists
            #util.mkdirs((path for key, path in opt['path'].items() if not key == 'experiments_root'
            #             and 'pretrain_model' not in key and 'resume' not in key))
            if not os.path.exists(opt['path']['experiments_root']):
                os.mkdir(opt['path']['experiments_root'])
                # raise ValueError('Path does not exists - check path')

        # config loggers. Before it, the log will not work
        util.setup_logger('base',
                          opt['path']['log'],
                          'train_' + opt['name'],
                          level=logging.INFO,
                          screen=True,
                          tofile=True)
        logger = logging.getLogger('base')
        #logger.info(option.dict2str(opt))
        # tensorboard logger
        if opt['use_tb_logger'] and 'debug' not in opt['name']:
            version = float(torch.__version__[0:3])
            if version >= 1.1:  # PyTorch 1.1
                from torch.utils.tensorboard import SummaryWriter
            else:
                logger.info(
                    'You are using PyTorch {}. Tensorboard will use [tensorboardX]'
                    .format(version))
                from tensorboardX import SummaryWriter
            tb_logger = SummaryWriter(log_dir='../tb_logger/' + folder_name)
    else:
        util.setup_logger('base',
                          opt['path']['log'],
                          'train',
                          level=logging.INFO,
                          screen=True)
        logger = logging.getLogger('base')

    #### random seed
    seed = opt['train']['manual_seed']
    if seed is None:
        seed = random.randint(1, 10000)
    if rank <= 0:
        logger.info('Random seed: {}'.format(seed))
    util.set_random_seed(seed)

    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

    #### create train and val dataloader
    dataset_ratio = 200  # enlarge the size of each epoch
    for phase, dataset_opt in opt['datasets'].items():
        if phase == 'train':
            pass
        elif phase == 'val':
            if '+' in opt['datasets']['val']['name']:
                val_set, val_loader = [], []
                valname_list = opt['datasets']['val']['name'].split('+')
                for i in range(len(valname_list)):
                    val_set.append(
                        create_dataset(
                            dataset_opt,
                            scale=opt['scale'],
                            kernel_size=opt['datasets']['train']
                            ['kernel_size'],
                            model_name=opt['network_E']['which_model_E'],
                            idx=i))
                    val_loader.append(
                        create_dataloader(val_set[-1], dataset_opt, opt, None))
            else:
                val_set = create_dataset(
                    dataset_opt,
                    scale=opt['scale'],
                    kernel_size=opt['datasets']['train']['kernel_size'],
                    model_name=opt['network_E']['which_model_E'])
                # val_set = loader.get_dataset(opt, train=False)
                val_loader = create_dataloader(val_set, dataset_opt, opt, None)
            if rank <= 0:
                logger.info('Number of val images in [{:s}]: {:d}'.format(
                    dataset_opt['name'], len(val_set)))
        else:
            raise NotImplementedError(
                'Phase [{:s}] is not recognized.'.format(phase))

    #### create model
    models = create_model(opt)
    assert len(models) == 2
    model, est_model = models[0], models[1]
    modelcp, est_modelcp = create_model(opt)
    _, est_model_fixed = create_model(opt)

    center_idx = (opt['datasets']['val']['N_frames']) // 2
    lr_alpha = opt['train']['maml']['lr_alpha']
    update_step = opt['train']['maml']['adapt_iter']

    pd_log = pd.DataFrame(
        columns=['PSNR_Bicubic', 'PSNR_Ours', 'SSIM_Bicubic', 'SSIM_Ours'])

    def crop(LR_seq, HR, num_patches_for_batch=4, patch_size=44):
        """
        Crop given patches.

        Args:
            LR_seq: (B=1) x T x C x H x W
            HR: (B=1) x C x H x W

            patch_size (int, optional):

        Return:
            B(=batch_size) x T x C x H x W
        """
        # Find the lowest resolution
        cropped_lr = []
        cropped_hr = []
        assert HR.size(0) == 1
        LR_seq_ = LR_seq[0]
        HR_ = HR[0]
        for _ in range(num_patches_for_batch):
            patch_lr, patch_hr = preprocessing.common_crop(
                LR_seq_, HR_, patch_size=patch_size // 2)
            cropped_lr.append(patch_lr)
            cropped_hr.append(patch_hr)

        cropped_lr = torch.stack(cropped_lr, dim=0)
        cropped_hr = torch.stack(cropped_hr, dim=0)

        return cropped_lr, cropped_hr

    # Single GPU
    # PSNR_rlt: psnr_bicubic (init), psnr_after_adaptation
    psnr_rlt = [{}, {}]
    # SSIM_rlt: ssim_init, ssim_after
    ssim_rlt = [{}, {}]
    pbar = util.ProgressBar(len(val_set))
    for val_data in val_loader:
        folder = val_data['folder'][0]
        idx_d = int(val_data['idx'][0].split('/')[0])
        if 'name' in val_data.keys():
            name = val_data['name'][0][center_idx][0]
        else:
            #name = '{}/{:08d}'.format(folder, idx_d)
            name = folder

        train_folder = os.path.join('../results_for_paper', folder_name, name)

        hr_train_folder = os.path.join(train_folder, 'hr')
        bic_train_folder = os.path.join(train_folder, 'bic')
        maml_train_folder = os.path.join(train_folder, 'maml')
        #slr_train_folder = os.path.join(train_folder, 'slr')

        # print(train_folder)
        if not os.path.exists(train_folder):
            os.makedirs(train_folder, exist_ok=False)
        if not os.path.exists(hr_train_folder):
            os.mkdir(hr_train_folder)
        if not os.path.exists(bic_train_folder):
            os.mkdir(bic_train_folder)
        if not os.path.exists(maml_train_folder):
            os.mkdir(maml_train_folder)
        #if not os.path.exists(slr_train_folder):
        #    os.mkdir(slr_train_folder)

        for i in range(len(psnr_rlt)):
            if psnr_rlt[i].get(folder, None) is None:
                psnr_rlt[i][folder] = []
        for i in range(len(ssim_rlt)):
            if ssim_rlt[i].get(folder, None) is None:
                ssim_rlt[i][folder] = []

        if idx_d % 10 != 5:
            #continue
            pass

        cropped_meta_train_data = {}
        meta_train_data = {}
        meta_test_data = {}

        # Make SuperLR seq using estimation model
        meta_train_data['GT'] = val_data['LQs'][:, center_idx]
        meta_test_data['LQs'] = val_data['LQs'][0:1]
        meta_test_data['GT'] = val_data['GT'][0:1, center_idx]
        # Check whether the batch size of each validation data is 1
        assert val_data['SuperLQs'].size(0) == 1

        if opt['network_G']['which_model_G'] == 'TOF':
            LQs = meta_test_data['LQs']
            B, T, C, H, W = LQs.shape
            LQs = LQs.reshape(B * T, C, H, W)
            Bic_LQs = F.interpolate(LQs,
                                    scale_factor=opt['scale'],
                                    mode='bicubic',
                                    align_corners=True)
            meta_test_data['LQs'] = Bic_LQs.reshape(B, T, C, H * opt['scale'],
                                                    W * opt['scale'])

        ## Before start training, first save the bicubic, real outputs
        # Bicubic
        modelcp.load_network(opt['path']['bicubic_G'], modelcp.netG)
        modelcp.feed_data(meta_test_data)
        modelcp.test()
        model_start_visuals = modelcp.get_current_visuals(need_GT=True)
        hr_image = util.tensor2img(model_start_visuals['GT'], mode='rgb')
        start_image = util.tensor2img(model_start_visuals['rlt'], mode='rgb')

        #####imageio.imwrite(os.path.join(hr_train_folder, '{:08d}.png'.format(idx_d)), hr_image)
        #####imageio.imwrite(os.path.join(bic_train_folder, '{:08d}.png'.format(idx_d)), start_image)
        psnr_rlt[0][folder].append(util.calculate_psnr(start_image, hr_image))
        ssim_rlt[0][folder].append(util.calculate_ssim(start_image, hr_image))

        modelcp.netG, est_modelcp.netE = deepcopy(model.netG), deepcopy(
            est_model.netE)

        ########## SLR LOSS Preparation ############
        est_model_fixed.load_network(opt['path']['fixed_E'],
                                     est_model_fixed.netE)

        optim_params = []
        for k, v in modelcp.netG.named_parameters():
            if v.requires_grad:
                optim_params.append(v)

        if not opt['train']['use_real']:
            for k, v in est_modelcp.netE.named_parameters():
                if v.requires_grad:
                    optim_params.append(v)

        if opt['train']['maml']['optimizer'] == 'Adam':
            inner_optimizer = torch.optim.Adam(
                optim_params,
                lr=lr_alpha,
                betas=(opt['train']['maml']['beta1'],
                       opt['train']['maml']['beta2']))
        elif opt['train']['maml']['optimizer'] == 'SGD':
            inner_optimizer = torch.optim.SGD(optim_params, lr=lr_alpha)
        else:
            raise NotImplementedError()

        # Inner Loop Update
        st = time.time()
        for i in range(update_step):
            # Make SuperLR seq using UPDATED estimation model
            if not opt['train']['use_real']:
                est_modelcp.feed_data(val_data)
                # est_model.test()
                est_modelcp.forward_without_optim()
                superlr_seq = est_modelcp.fake_L
                meta_train_data['LQs'] = superlr_seq
            else:
                meta_train_data['LQs'] = val_data['SuperLQs']

            if opt['network_G']['which_model_G'] == 'TOF':
                # Bicubic upsample to match the size
                LQs = meta_train_data['LQs']
                B, T, C, H, W = LQs.shape
                LQs = LQs.reshape(B * T, C, H, W)
                Bic_LQs = F.interpolate(LQs,
                                        scale_factor=opt['scale'],
                                        mode='bicubic',
                                        align_corners=True)
                meta_train_data['LQs'] = Bic_LQs.reshape(
                    B, T, C, H * opt['scale'], W * opt['scale'])

            # Update both modelcp + estmodelcp jointly
            inner_optimizer.zero_grad()
            if opt['train']['maml']['use_patch']:
                cropped_meta_train_data['LQs'], cropped_meta_train_data['GT'] = \
                    crop(meta_train_data['LQs'], meta_train_data['GT'],
                         opt['train']['maml']['num_patch'],
                         opt['train']['maml']['patch_size'])
                modelcp.feed_data(cropped_meta_train_data)
            else:
                modelcp.feed_data(meta_train_data)

            loss_train = modelcp.calculate_loss()

            ##################### SLR LOSS ###################
            est_model_fixed.feed_data(val_data)
            est_model_fixed.test()
            slr_initialized = est_model_fixed.fake_L
            slr_initialized = slr_initialized.to('cuda')
            if opt['network_G']['which_model_G'] == 'TOF':
                loss_train += 10 * F.l1_loss(
                    LQs.to('cuda').squeeze(0), slr_initialized)
            else:
                loss_train += 10 * F.l1_loss(meta_train_data['LQs'].to('cuda'),
                                             slr_initialized)

            loss_train.backward()
            inner_optimizer.step()

        et = time.time()
        update_time = et - st

        modelcp.feed_data(meta_test_data)
        modelcp.test()

        model_update_visuals = modelcp.get_current_visuals(need_GT=False)
        update_image = util.tensor2img(model_update_visuals['rlt'], mode='rgb')
        # Save and calculate final image
        imageio.imwrite(
            os.path.join(maml_train_folder, '{:08d}.png'.format(idx_d)),
            update_image)
        psnr_rlt[1][folder].append(util.calculate_psnr(update_image, hr_image))
        ssim_rlt[1][folder].append(util.calculate_ssim(update_image, hr_image))

        name_df = '{}/{:08d}'.format(folder, idx_d)
        if name_df in pd_log.index:
            pd_log.at[name_df, 'PSNR_Bicubic'] = psnr_rlt[0][folder][-1]
            pd_log.at[name_df, 'PSNR_Ours'] = psnr_rlt[1][folder][-1]
            pd_log.at[name_df, 'SSIM_Bicubic'] = ssim_rlt[0][folder][-1]
            pd_log.at[name_df, 'SSIM_Ours'] = ssim_rlt[1][folder][-1]
        else:
            pd_log.loc[name_df] = [
                psnr_rlt[0][folder][-1], psnr_rlt[1][folder][-1],
                ssim_rlt[0][folder][-1], ssim_rlt[1][folder][-1]
            ]

        pd_log.to_csv(
            os.path.join('../results_for_paper', folder_name,
                         'psnr_update.csv'))

        pbar.update(
            'Test {} - {}: I: {:.3f}/{:.4f} \tF+: {:.3f}/{:.4f} \tTime: {:.3f}s'
            .format(folder, idx_d, psnr_rlt[0][folder][-1],
                    ssim_rlt[0][folder][-1], psnr_rlt[1][folder][-1],
                    ssim_rlt[1][folder][-1], update_time))

    psnr_rlt_avg = {}
    psnr_total_avg = 0.
    # Average the bicubic-baseline PSNR (psnr_rlt[0]) per folder
    for k, v in psnr_rlt[0].items():
        psnr_rlt_avg[k] = sum(v) / len(v)
        psnr_total_avg += psnr_rlt_avg[k]
    psnr_total_avg /= len(psnr_rlt[0])
    log_s = '# Validation # Bic PSNR: {:.4e}:'.format(psnr_total_avg)
    for k, v in psnr_rlt_avg.items():
        log_s += ' {}: {:.4e}'.format(k, v)
    logger.info(log_s)

    psnr_rlt_avg = {}
    psnr_total_avg = 0.
    # Average the post-adaptation PSNR (psnr_rlt[1]) per folder
    for k, v in psnr_rlt[1].items():
        psnr_rlt_avg[k] = sum(v) / len(v)
        psnr_total_avg += psnr_rlt_avg[k]
    psnr_total_avg /= len(psnr_rlt[1])
    log_s = '# Validation # PSNR: {:.4e}:'.format(psnr_total_avg)
    for k, v in psnr_rlt_avg.items():
        log_s += ' {}: {:.4e}'.format(k, v)
    logger.info(log_s)

    ssim_rlt_avg = {}
    ssim_total_avg = 0.
    # Average the bicubic-baseline SSIM (ssim_rlt[0]) per folder
    for k, v in ssim_rlt[0].items():
        ssim_rlt_avg[k] = sum(v) / len(v)
        ssim_total_avg += ssim_rlt_avg[k]
    ssim_total_avg /= len(ssim_rlt[0])
    log_s = '# Validation # Bicubic SSIM: {:.4e}:'.format(ssim_total_avg)
    for k, v in ssim_rlt_avg.items():
        log_s += ' {}: {:.4e}'.format(k, v)
    logger.info(log_s)

    ssim_rlt_avg = {}
    ssim_total_avg = 0.
    # Average the post-adaptation SSIM (ssim_rlt[1]) per folder
    for k, v in ssim_rlt[1].items():
        ssim_rlt_avg[k] = sum(v) / len(v)
        ssim_total_avg += ssim_rlt_avg[k]
    ssim_total_avg /= len(ssim_rlt[1])
    log_s = '# Validation # SSIM: {:.4e}:'.format(ssim_total_avg)
    for k, v in ssim_rlt_avg.items():
        log_s += ' {}: {:.4e}'.format(k, v)
    logger.info(log_s)

    logger.info('End of evaluation.')
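The heart of the loop above is test-time adaptation: copy the meta-trained generator and estimator, take adapt_iter gradient steps on a self-supervised loss (reconstruct the observed LQ centre frame from the estimated super-LR sequence, plus the weighted SLR regulariser), then run the adapted model on the real input. A stripped-down sketch of that pattern with generic PyTorch modules; every name below is illustrative rather than the project's API:

import copy
import torch
import torch.nn.functional as F

def adapt_and_test(net, est_net, lq_seq, slr_target, steps=5, lr=1e-4):
    # Adapt copies so the meta-trained weights stay untouched.
    net_c, est_c = copy.deepcopy(net), copy.deepcopy(est_net)
    params = [p for m in (net_c, est_c) for p in m.parameters() if p.requires_grad]
    optim = torch.optim.Adam(params, lr=lr)
    centre = lq_seq.size(1) // 2
    for _ in range(steps):
        optim.zero_grad()
        superlr = est_c(lq_seq)                      # estimated down-degraded sequence
        sr = net_c(superlr)                          # try to recover the observed LQ centre frame
        loss = F.l1_loss(sr, lq_seq[:, centre])
        loss = loss + 10 * F.l1_loss(superlr, slr_target)  # SLR loss, weighted as above
        loss.backward()
        optim.step()
    with torch.no_grad():
        return net_c(lq_seq)                         # adapted output on the real input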
Example #5
def main():
    #### options
    parser = argparse.ArgumentParser()
    parser.add_argument('-opt',
                        type=str,
                        default='options/train/train_EDVR_woTSA_M.yml',
                        help='Path to option YAML file.')
    parser.add_argument('--set',
                        dest='set_opt',
                        default=None,
                        nargs=argparse.REMAINDER,
                        help='set options')
    args = parser.parse_args()
    opt = option.parse(args.opt, args.set_opt, is_train=True)

    #### loading resume state if exists
    if opt['path'].get('resume_state', None):
        # distributed resuming: all load into default GPU
        print('Training from state: {}'.format(opt['path']['resume_state']))
        device_id = torch.cuda.current_device()
        resume_state = torch.load(
            opt['path']['resume_state'],
            map_location=lambda storage, loc: storage.cuda(device_id))
        option.check_resume(opt, resume_state['iter'])  # check resume options
    elif opt['auto_resume']:
        exp_dir = opt['path']['experiments_root']
        # first time run: create dirs
        if not os.path.exists(exp_dir):
            os.makedirs(exp_dir)
            os.makedirs(opt['path']['models'])
            os.makedirs(opt['path']['training_state'])
            os.makedirs(opt['path']['val_images'])
            os.makedirs(opt['path']['tb_logger'])
            resume_state = None
        else:
            # detect experiment directory and get the latest state
            state_dir = opt['path']['training_state']
            state_files = [
                x for x in os.listdir(state_dir) if x.endswith('state')
            ]
            # no valid state detected
            if len(state_files) < 1:
                print(
                    'No previous training state found, train from start state')
                resume_state = None
            else:
                state_files = sorted(state_files,
                                     key=lambda x: int(x.split('.')[0]))
                latest_state = state_files[-1]
                print('Training from latest state: {}'.format(latest_state))
                latest_state_file = os.path.join(state_dir, latest_state)
                opt['path']['resume_state'] = latest_state_file
                device_id = torch.cuda.current_device()
                resume_state = torch.load(
                    latest_state_file,
                    map_location=lambda storage, loc: storage.cuda(device_id))
                option.check_resume(opt, resume_state['iter'])
    else:
        resume_state = None

    if resume_state is None and not opt['auto_resume'] and not opt['no_log']:
        util.mkdir_and_rename(
            opt['path']
            ['experiments_root'])  # rename experiment folder if exists
        util.mkdirs((path for key, path in opt['path'].items()
                     if not key == 'experiments_root'
                     and 'pretrain_model' not in key and 'resume' not in key))

    # config loggers. Before it, the log will not work
    util.setup_logger('base',
                      opt['path']['log'],
                      'train_' + opt['name'],
                      level=logging.INFO,
                      screen=True,
                      tofile=True)
    logger = logging.getLogger('base')
    logger.info(option.dict2str(opt))
    # tensorboard logger
    if opt['use_tb_logger'] and 'debug' not in opt['name']:
        version = float(torch.__version__[0:3])
        if version >= 1.2:  # PyTorch >= 1.2 has built-in TensorBoard support
            from torch.utils.tensorboard import SummaryWriter
        else:
            logger.info(
                'You are using PyTorch {}. Tensorboard will use [tensorboardX]'
                .format(version))
            from tensorboardX import SummaryWriter
        tb_logger = SummaryWriter(log_dir=opt['path']['tb_logger'])

    # convert to NoneDict, which returns None for missing keys
    opt = option.dict_to_nonedict(opt)

    #### random seed
    seed = opt['train']['manual_seed']
    if seed is None:
        seed = random.randint(1, 10000)
    logger.info('Random seed: {}'.format(seed))
    util.set_random_seed(seed)

    torch.backends.cudnn.benchmark = True
    # torch.backends.cudnn.deterministic = True

    #### create train and val dataloader
    if opt['datasets']['train']['ratio']:
        dataset_ratio = opt['datasets']['train']['ratio']
    else:
        dataset_ratio = 200  # enlarge the size of each epoch
    for phase, dataset_opt in opt['datasets'].items():
        if phase == 'train':
            train_set = create_dataset(dataset_opt)
            train_size = int(
                math.ceil(len(train_set) / dataset_opt['batch_size']))
            total_iters = int(opt['train']['niter'])
            total_epochs = int(
                math.ceil(total_iters / (train_size * dataset_ratio)))
            if dataset_opt['mode'] in ['MetaREDS', 'MetaREDSOnline']:
                train_sampler = MetaIterSampler(train_set,
                                                dataset_opt['batch_size'],
                                                len(opt['scale']),
                                                dataset_ratio)
            elif dataset_opt['mode'] in ['REDS', 'MultiREDS']:
                train_sampler = IterSampler(train_set,
                                            dataset_opt['batch_size'],
                                            dataset_ratio)
            else:
                train_sampler = None

            train_loader = create_dataloader(train_set, dataset_opt, opt,
                                             train_sampler)
            logger.info('Number of train images: {:,d}, iters: {:,d}'.format(
                len(train_set), train_size))
            logger.info('Total epochs needed: {:d} for iters {:,d}'.format(
                total_epochs, total_iters))
        elif phase == 'val':
            val_set = create_dataset(dataset_opt)
            val_loader = create_dataloader(val_set, dataset_opt, opt, None)
            logger.info('Number of val images in [{:s}]: {:d}'.format(
                dataset_opt['name'], len(val_set)))
        else:
            raise NotImplementedError(
                'Phase [{:s}] is not recognized.'.format(phase))

    #### create model
    model = create_model(opt)

    #### resume training
    if resume_state:
        logger.info('Resuming training from epoch: {}, iter: {}.'.format(
            resume_state['epoch'], resume_state['iter']))

        start_epoch = resume_state['epoch']
        current_step = resume_state['iter']
        model.resume_training(resume_state)  # handle optimizers and schedulers
    else:
        current_step = 0
        start_epoch = 0

    #### training
    logger.info('Start training from epoch: {:d}, iter: {:d}'.format(
        start_epoch, current_step))
    for epoch in range(start_epoch, total_epochs + 1):
        train_sampler.set_epoch(epoch)
        for _, train_data in enumerate(train_loader):
            current_step += 1
            if current_step > total_iters:
                break
            #### update learning rate
            model.update_learning_rate(current_step,
                                       warmup_iter=opt['train']['warmup_iter'])

            #### training
            model.feed_data(train_data)
            model.optimize_parameters(current_step)
            #### log
            if current_step % opt['logger']['print_freq'] == 0:
                logs = model.get_current_log()
                message = '[epoch:{:3d}, iter:{:8,d}, lr:('.format(
                    epoch, current_step)
                for v in model.get_current_learning_rate():
                    message += '{:.3e},'.format(v)
                message += ')] '
                for k, v in logs.items():
                    message += '{:s}: {:.4e} '.format(k, v)
                    # tensorboard logger
                    if opt['use_tb_logger'] and 'debug' not in opt['name']:
                        tb_logger.add_scalar(k, v, current_step)
                logger.info(message)
                print("PROGRESS: {:02d}%".format(
                    int(current_step / total_iters * 100)))
            #### validation
            if opt['datasets'].get(
                    'val',
                    None) and current_step % opt['train']['val_freq'] == 0:
                pbar = util.ProgressBar(len(val_loader))
                psnr_rlt = {}  # with border and center frames
                psnr_rlt_avg = {}
                psnr_total_avg = 0.
                for val_data in val_loader:
                    folder = val_data['folder'][0]
                    idx_d = val_data['idx'].item()
                    # border = val_data['border'].item()
                    if psnr_rlt.get(folder, None) is None:
                        psnr_rlt[folder] = []

                    model.feed_data(val_data)
                    model.test()
                    visuals = model.get_current_visuals()
                    rlt_img = util.tensor2img(visuals['rlt'])  # uint8
                    gt_img = util.tensor2img(visuals['GT'])  # uint8

                    # calculate PSNR
                    psnr = util.calculate_psnr(rlt_img, gt_img)
                    psnr_rlt[folder].append(psnr)
                    pbar.update('Test {} - {}'.format(folder, idx_d))
                for k, v in psnr_rlt.items():
                    psnr_rlt_avg[k] = sum(v) / len(v)
                    psnr_total_avg += psnr_rlt_avg[k]
                psnr_total_avg /= len(psnr_rlt)
                log_s = '# Validation # PSNR: {:.4e}:'.format(psnr_total_avg)
                for k, v in psnr_rlt_avg.items():
                    log_s += ' {}: {:.4e}'.format(k, v)
                logger.info(log_s)
                if opt['use_tb_logger'] and 'debug' not in opt['name']:
                    tb_logger.add_scalar('psnr_avg', psnr_total_avg,
                                         current_step)
                    for k, v in psnr_rlt_avg.items():
                        tb_logger.add_scalar(k, v, current_step)

            #### save models and training states
            if current_step % opt['logger']['save_checkpoint_freq'] == 0:
                logger.info('Saving models and training states.')
                model.save(current_step)
                model.save_training_state(epoch, current_step)

    logger.info('Saving the final model.')
    model.save('latest')
    logger.info('End of training.')
    tb_logger.close()
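One detail worth spelling out: because the sampler enlarges every epoch by dataset_ratio, each epoch yields roughly train_size * dataset_ratio iterations, so total_epochs is just the iteration budget divided by that. A worked sketch with made-up numbers (all counts are illustrative):

import math

def epochs_needed(n_images, batch_size, niter, dataset_ratio=200):
    """Mirror the epoch computation above."""
    train_size = int(math.ceil(n_images / batch_size))  # iters over the raw dataset
    return int(math.ceil(niter / (train_size * dataset_ratio)))

# e.g. epochs_needed(26000, 32, 600000) == math.ceil(600000 / (813 * 200)) == 4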
Example #6
def REDS(mode):
    """Create lmdb for the REDS dataset, each image with a fixed size
    GT: [3, 720, 1280], key: 000_00000000
    LR: [3, 180, 320], key: 000_00000000
    key: 000_00000000

    flow: downsampled flow: [3, 360, 320], keys: 000_00000005_[p2, p1, n1, n2]
        Each flow is calculated with the GT images by PWCNet and then downsampled by 1/4
        Flow map is quantized by mmcv and saved in png format
    """
    #### configurations
    read_all_imgs = False  # whether to read all images into memory with multiprocessing
    # Set to False to limit memory usage
    BATCH = 5000  # if read_all_imgs is False, commit to lmdb after every BATCH images
    if mode == 'train_sharp':
        img_folder = '../../datasets/REDS/train_sharp'
        lmdb_save_path = '../../datasets/REDS/train_sharp_wval.lmdb'
        H_dst, W_dst = 720, 1280
    elif mode == 'train_sharp_bicubic':
        img_folder = '../../datasets/REDS/train_sharp_bicubic'
        lmdb_save_path = '../../datasets/REDS/train_sharp_bicubic_wval.lmdb'
        H_dst, W_dst = 180, 320
    elif mode == 'train_blur_bicubic':
        img_folder = '../../datasets/REDS/train_blur_bicubic'
        lmdb_save_path = '../../datasets/REDS/train_blur_bicubic_wval.lmdb'
        H_dst, W_dst = 180, 320
    elif mode == 'train_blur':
        img_folder = '../../datasets/REDS/train_blur'
        lmdb_save_path = '../../datasets/REDS/train_blur_wval.lmdb'
        H_dst, W_dst = 720, 1280
    elif mode == 'train_blur_comp':
        img_folder = '../../datasets/REDS/train_blur_comp'
        lmdb_save_path = '../../datasets/REDS/train_blur_comp_wval.lmdb'
        H_dst, W_dst = 720, 1280
    elif mode == 'train_sharp_flowx4':
        img_folder = '../../datasets/REDS/train_sharp_flowx4'
        lmdb_save_path = '../../datasets/REDS/train_sharp_flowx4.lmdb'
        H_dst, W_dst = 360, 320
    n_thread = 40
    ########################################################
    if not lmdb_save_path.endswith('.lmdb'):
        raise ValueError("lmdb_save_path must end with '.lmdb'.")
    if osp.exists(lmdb_save_path):
        print('Folder [{:s}] already exists. Exit...'.format(lmdb_save_path))
        sys.exit(1)

    #### read all the image paths to a list
    print('Reading image path list ...')
    all_img_list = data_util._get_paths_from_images(img_folder)
    keys = []
    for img_path in all_img_list:
        split_rlt = img_path.split('/')
        folder = split_rlt[-2]
        img_name = split_rlt[-1].split('.png')[0]
        keys.append(folder + '_' + img_name)

    if read_all_imgs:
        #### read all images to memory (multiprocessing)
        dataset = {
        }  # store all image data. list cannot keep the order, use dict
        print('Read images with multiprocessing, #thread: {} ...'.format(
            n_thread))
        pbar = util.ProgressBar(len(all_img_list))

        def mycallback(arg):
            '''get the image data and update pbar'''
            key = arg[0]
            dataset[key] = arg[1]
            pbar.update('Reading {}'.format(key))

        pool = Pool(n_thread)
        for path, key in zip(all_img_list, keys):
            pool.apply_async(read_image_worker,
                             args=(path, key),
                             callback=mycallback)
        pool.close()
        pool.join()
        print('Finish reading {} images.\nWrite lmdb...'.format(
            len(all_img_list)))

    #### create lmdb environment
    data_size_per_img = cv2.imread(all_img_list[0],
                                   cv2.IMREAD_UNCHANGED).nbytes
    print('data size per image is: ', data_size_per_img)
    data_size = data_size_per_img * len(all_img_list)
    env = lmdb.open(lmdb_save_path, map_size=data_size * 10)

    #### write data to lmdb
    pbar = util.ProgressBar(len(all_img_list))
    txn = env.begin(write=True)
    for idx, (path, key) in enumerate(zip(all_img_list, keys)):
        pbar.update('Write {}'.format(key))
        key_byte = key.encode('ascii')
        data = dataset[key] if read_all_imgs else cv2.imread(
            path, cv2.IMREAD_UNCHANGED)
        if 'flow' in mode:
            H, W = data.shape
            assert H == H_dst and W == W_dst, 'different shape.'
        else:
            H, W, C = data.shape
            assert H == H_dst and W == W_dst and C == 3, 'different shape.'
        txn.put(key_byte, data)
        if not read_all_imgs and idx % BATCH == 0:
            txn.commit()
            txn = env.begin(write=True)
    txn.commit()
    env.close()
    print('Finish writing lmdb.')

    #### create meta information
    meta_info = {}
    meta_info['name'] = 'REDS_{}_wval'.format(mode)
    channel = 1 if 'flow' in mode else 3
    meta_info['resolution'] = '{}_{}_{}'.format(channel, H_dst, W_dst)
    meta_info['keys'] = keys
    pickle.dump(meta_info, open(osp.join(lmdb_save_path, 'meta_info.pkl'),
                                "wb"))
    print('Finish creating lmdb meta info.')
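The folder_imgname key used throughout comes straight from the last two path components. An equivalent helper using os.path instead of the raw split('/') above (the name is illustrative):

import os.path as osp

def reds_key(img_path):
    """'.../train_sharp/000/00000000.png' -> '000_00000000'."""
    folder = osp.basename(osp.dirname(img_path))
    img_name = osp.splitext(osp.basename(img_path))[0]
    return folder + '_' + img_name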
Example #7
def SR4K(mode):
    """Create lmdb for the 4k dataset, each image with a fixed size
    GT: [3, 3840, 2160], key: 000_00000000
    LR: [3, 960, 540], key: 000_00000000
    key: 000_00000000
    """
    #### configurations
    read_all_imgs = False  # whether to read all images into memory with multiprocessing
    # Set to False to limit memory usage
    BATCH = 500  # if read_all_imgs is False, commit to lmdb after every BATCH images
    train_txt = '/home/mcc/4khdr/train.txt'
    if mode == 'train_4k':
        img_folder = '/home/mcc/4khdr/image/4k'
        lmdb_save_path = '/home/mcc/4khdr/4k.lmdb'
        H_dst, W_dst = 1080, 1920
        BATCH = 1000
    elif mode == 'train_540p':
        img_folder = '/home/mcc/4khdr/image/540p'
        lmdb_save_path = '/home/mcc/4khdr/540p.lmdb'
        H_dst, W_dst = 270, 480
        BATCH = 5000
    n_thread = 12
    ########################################################
    if not lmdb_save_path.endswith('.lmdb'):
        raise ValueError("lmdb_save_path must end with '.lmdb'.")
    if osp.exists(lmdb_save_path):
        print('Folder [{:s}] already exists. Exit...'.format(lmdb_save_path))
        sys.exit(1)

    #### read all the image paths to a list
    print('Reading image path list ...')
    with open(train_txt, 'r') as f:
        train_list = [x.strip() for x in f.readlines()]
    all_img_list = []
    for dirpath, _, fnames in sorted(os.walk(img_folder)):
        if not osp.basename(dirpath)[:-2] in train_list:
            continue
        for fname in sorted(fnames):
            if fname.endswith('.png'):
                img_path = osp.join(dirpath, fname)
                all_img_list.append(img_path)

    keys = []
    for img_path in all_img_list:
        split_rlt = img_path.split('/')
        folder = split_rlt[-2]
        img_name = split_rlt[-1].split('.png')[0]
        keys.append(folder + '_' + img_name)

    if read_all_imgs:
        #### read all images to memory (multiprocessing)
        dataset = {
        }  # store all image data. list cannot keep the order, use dict
        print('Read images with multiprocessing, #thread: {} ...'.format(
            n_thread))
        pbar = util.ProgressBar(len(all_img_list))

        def mycallback(arg):
            '''get the image data and update pbar'''
            key = arg[0]
            dataset[key] = arg[1]
            pbar.update('Reading {}'.format(key))

        pool = Pool(n_thread)
        for path, key in zip(all_img_list, keys):
            pool.apply_async(read_image_worker,
                             args=(path, key),
                             callback=mycallback)
        pool.close()
        pool.join()
        print('Finish reading {} images.\nWrite lmdb...'.format(
            len(all_img_list)))

    #### create lmdb environment
    data_size_per_img = cv2.imread(all_img_list[0],
                                   cv2.IMREAD_UNCHANGED).nbytes
    print('data size per image is: ', data_size_per_img)
    data_size = data_size_per_img * len(all_img_list)
    env = lmdb.open(lmdb_save_path, map_size=data_size * 10)

    #### write data to lmdb
    pbar = util.ProgressBar(len(all_img_list))
    txn = env.begin(write=True)
    for idx, (path, key) in enumerate(zip(all_img_list, keys)):
        pbar.update('Write {}'.format(key))
        key_byte = key.encode('ascii')
        data = dataset[key] if read_all_imgs else cv2.imread(
            path, cv2.IMREAD_UNCHANGED)
        H, W, C = data.shape
        assert H == H_dst and W == W_dst and C == 3, 'different shape.'
        txn.put(key_byte, data)
        if not read_all_imgs and idx % BATCH == 0:
            txn.commit()
            txn = env.begin(write=True)
    txn.commit()
    env.close()
    print('Finish writing lmdb.')

    #### create meta information
    meta_info = {}
    meta_info['name'] = 'SR4K_{}'.format(mode)
    channel = 3
    meta_info['resolution'] = '{}_{}_{}'.format(channel, H_dst, W_dst)
    meta_info['keys'] = keys
    pickle.dump(meta_info, open(osp.join(lmdb_save_path, 'meta_info.pkl'),
                                "wb"))
    print('Finish creating lmdb meta info.')
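The batched-commit pattern above (put into a write transaction, commit every BATCH images, reopen) is what keeps memory bounded when read_all_imgs is False. A condensed sketch of that pattern in isolation; note it commits after every batch puts rather than on idx % BATCH == 0 (names are illustrative):

def write_lmdb_batched(env, items, batch=5000):
    """items: iterable of (key, image_data) pairs; env: an open lmdb environment."""
    txn = env.begin(write=True)
    for idx, (key, data) in enumerate(items):
        txn.put(key.encode('ascii'), data)
        if (idx + 1) % batch == 0:
            txn.commit()                      # flush this batch to disk
            txn = env.begin(write=True)       # start the next write transaction
    txn.commit()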
Example #8
def youku(mode):
    """Create lmdb for the youku dataset, each image with a fixed size
    GT: [3, 1080, 1920] or [3, 1152, 2048], key: 00000_000000
    LR: [3, 270, 480] or [3, 288, 512], key: 00000_000000
    """
    #### configurations
    read_all_imgs = False  # whether to read all images into memory with multiprocessing
    # Set to False to limit memory usage
    BATCH = 5000  # if read_all_imgs is False, commit to lmdb after every BATCH images
    if mode == 'gt':
        train_folder = '/media/tclwh2/public/youku/train/gt'
        val_folder = '/media/tclwh2/public/youku/val/gt'
        lmdb_save_path = '/media/tclwh2/public/youku/youku_train_gt.lmdb'
        H_dst, W_dst = (1080, 1152), (1920, 2048)
    elif mode == 'lq':
        train_folder = '/media/tclwh2/public/youku/train/lq'
        val_folder = '/media/tclwh2/public/youku/val/lq'
        lmdb_save_path = '/media/tclwh2/public/youku/youku_train_lq.lmdb'
        H_dst, W_dst = (270, 288), (480, 512)

    n_thread = 40
    ########################################################
    if not lmdb_save_path.endswith('.lmdb'):
        raise ValueError("lmdb_save_path must end with '.lmdb'.")
    if osp.exists(lmdb_save_path):
        print('Folder [{:s}] already exists. Exit...'.format(lmdb_save_path))
        sys.exit(1)

    #### read all the image paths to a list
    print('Reading image path list ...')
    train_img_list = data_util._get_paths_from_images(train_folder)
    val_img_list = data_util._get_paths_from_images(val_folder)
    all_img_list = sorted(train_img_list + val_img_list)
    keys = []
    for img_path in all_img_list:
        split_rlt = img_path.split('/')
        folder = split_rlt[-2]
        img_name = split_rlt[-1].split('.png')[0]
        keys.append(folder + '_' + img_name)

    if read_all_imgs:
        #### read all images to memory (multiprocessing)
        dataset = {
        }  # store all image data. list cannot keep the order, use dict
        print('Read images with multiprocessing, #thread: {} ...'.format(
            n_thread))
        pbar = util.ProgressBar(len(all_img_list))

        def mycallback(arg):
            '''get the image data and update pbar'''
            key = arg[0]
            dataset[key] = arg[1]
            pbar.update('Reading {}'.format(key))

        pool = Pool(n_thread)
        for path, key in zip(all_img_list, keys):
            pool.apply_async(read_image_worker,
                             args=(path, key),
                             callback=mycallback)
        pool.close()
        pool.join()
        print('Finish reading {} images.\nWrite lmdb...'.format(
            len(all_img_list)))

    #### create lmdb environment
    cnt_1 = 993 * 100
    cnt_2 = 7 * 100
    assert cnt_1 + cnt_2 == len(all_img_list)
    data_size_per_img_1 = cv2.imread(all_img_list[0],
                                     cv2.IMREAD_UNCHANGED).nbytes
    data_size_per_img_2 = cv2.imread(all_img_list[30 * 100],
                                     cv2.IMREAD_UNCHANGED).nbytes
    assert data_size_per_img_1 != data_size_per_img_2
    print('data size per image is: %d and %d' %
          (data_size_per_img_1, data_size_per_img_2))
    data_size = data_size_per_img_1 * cnt_1 + data_size_per_img_2 * cnt_2
    env = lmdb.open(lmdb_save_path, map_size=data_size * 10)

    #### write data to lmdb
    pbar = util.ProgressBar(len(all_img_list))
    txn = env.begin(write=True)
    for idx, (path, key) in enumerate(zip(all_img_list, keys)):
        pbar.update('Write {}'.format(key))
        key_byte = key.encode('ascii')
        data = dataset[key] if read_all_imgs else cv2.imread(
            path, cv2.IMREAD_UNCHANGED)
        H, W, C = data.shape
        assert H in H_dst and W in W_dst and C == 3, 'different shape.'
        txn.put(key_byte, data)
        if not read_all_imgs and idx % BATCH == 0:
            txn.commit()
            txn = env.begin(write=True)
    txn.commit()
    env.close()
    print('Finish writing lmdb.')

    #### create meta information
    meta_info = {}
    meta_info['name'] = 'youku_train_{}'.format(mode)
    channel = 3
    meta_info['resolution_2'] = '{}_{}_{}'.format(channel, H_dst[1], W_dst[1])
    meta_info['resolution'] = '{}_{}_{}'.format(channel, H_dst[0], W_dst[0])
    meta_info['res_2_list'] = '_'.join(
        ['31', '44', '54', '101', '121', '142', '177'])
    meta_info['keys'] = keys
    pickle.dump(meta_info, open(osp.join(lmdb_save_path, 'meta_info.pkl'),
                                "wb"))
    print('Finish creating lmdb meta info.')
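Since Youku stores two frame resolutions, the map_size estimate has to weight each per-image size by its frame count (the computation fixed above). A tiny sketch of that estimate, with the counts from the script (the function name is illustrative):

def estimate_map_size(size_1, cnt_1, size_2, cnt_2, margin=10):
    """Total raw payload times a safety margin, as in the scripts above."""
    return (size_1 * cnt_1 + size_2 * cnt_2) * margin

# e.g. with 993 * 100 frames of one size and 7 * 100 frames of the other:
# estimate_map_size(nbytes_1, 993 * 100, nbytes_2, 7 * 100)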
Example #9
def REDS():
    '''create lmdb for the REDS dataset, each image with fixed size
    GT: [3, 720, 1280], key: 000_00000000
    LR: [3, 180, 320], key: 000_00000000
    key: 000_00000000
    '''
    #### configurations
    mode = 'train_sharp'
    # train_sharp | train_sharp_bicubic | train_blur_bicubic| train_blur | train_blur_comp
    if mode == 'train_sharp':
        img_folder = '/home/xtwang/datasets/REDS/train_sharp'
        lmdb_save_path = '/home/xtwang/datasets/REDS/train_sharp_wval.lmdb'
        H_dst, W_dst = 720, 1280
    elif mode == 'train_sharp_bicubic':
        img_folder = '/home/xtwang/datasets/REDS/train_sharp_bicubic'
        lmdb_save_path = '/home/xtwang/datasets/REDS/train_sharp_bicubic_wval.lmdb'
        H_dst, W_dst = 180, 320
    elif mode == 'train_blur_bicubic':
        img_folder = '/home/xtwang/datasets/REDS/train_blur_bicubic'
        lmdb_save_path = '/home/xtwang/datasets/REDS/train_blur_bicubic_wval.lmdb'
        H_dst, W_dst = 180, 320
    elif mode == 'train_blur':
        img_folder = '/home/xtwang/datasets/REDS/train_blur'
        lmdb_save_path = '/home/xtwang/datasets/REDS/train_blur_wval.lmdb'
        H_dst, W_dst = 720, 1280
    elif mode == 'train_blur_comp':
        img_folder = '/home/xtwang/datasets/REDS/train_blur_comp'
        lmdb_save_path = '/home/xtwang/datasets/REDS/train_blur_comp_wval.lmdb'
        H_dst, W_dst = 720, 1280
    n_thread = 40
    ########################################################
    if not lmdb_save_path.endswith('.lmdb'):
        raise ValueError("lmdb_save_path must end with '.lmdb'.")
    #### whether the lmdb file exist
    if osp.exists(lmdb_save_path):
        print('Folder [{:s}] already exists. Exit...'.format(lmdb_save_path))
        sys.exit(1)

    #### read all the image paths to a list
    print('Reading image path list ...')
    all_img_list = data_util._get_paths_from_images(img_folder)
    keys = []
    for img_path in all_img_list:
        split_rlt = img_path.split('/')
        folder = split_rlt[-2]
        img_name = split_rlt[-1].split('.png')[0]
        keys.append(folder + '_' + img_name)

    #### read all images to memory (multiprocessing)
    dataset = {}  # store all image data. list cannot keep the order, use dict
    print('Read images with multiprocessing, #thread: {} ...'.format(n_thread))
    pbar = util.ProgressBar(len(all_img_list))

    def mycallback(arg):
        '''get the image data and update pbar'''
        key = arg[0]
        dataset[key] = arg[1]
        pbar.update('Reading {}'.format(key))

    pool = Pool(n_thread)
    for path, key in zip(all_img_list, keys):
        pool.apply_async(reading_image_worker,
                         args=(path, key),
                         callback=mycallback)
    pool.close()
    pool.join()
    print('Finish reading {} images.\nWrite lmdb...'.format(len(all_img_list)))

    #### create lmdb environment
    data_size_per_img = dataset['000_00000000'].nbytes
    if 'flow' in mode:
        data_size_per_img = dataset['000_00000002_n1'].nbytes
    print('data size per image is: ', data_size_per_img)
    data_size = data_size_per_img * len(all_img_list)
    env = lmdb.open(lmdb_save_path, map_size=data_size * 10)

    #### write data to lmdb
    pbar = util.ProgressBar(len(all_img_list))
    with env.begin(write=True) as txn:
        for key in keys:
            pbar.update('Write {}'.format(key))
            key_byte = key.encode('ascii')
            data = dataset[key]
            if 'flow' in mode:
                H, W = data.shape
                assert H == H_dst and W == W_dst, 'different shape.'
            else:
                H, W, C = data.shape  # fixed shape
                assert H == H_dst and W == W_dst and C == 3, 'different shape.'
            txn.put(key_byte, data)
    print('Finish writing lmdb.')

    #### create meta information
    meta_info = {}
    meta_info['name'] = 'REDS_{}_wval'.format(mode)
    if 'flow' in mode:
        meta_info['resolution'] = '{}_{}_{}'.format(1, H_dst, W_dst)
    else:
        meta_info['resolution'] = '{}_{}_{}'.format(3, H_dst, W_dst)
    meta_info['keys'] = keys
    pickle.dump(meta_info, open(osp.join(lmdb_save_path, 'meta_info.pkl'),
                                "wb"))
    print('Finish creating lmdb meta info.')
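After writing, it is cheap to confirm that the lmdb and its meta info agree. A small sanity-check sketch, assuming keys were stored ASCII-encoded as in the scripts above (the helper name is illustrative):

import pickle
import lmdb
import os.path as osp

def verify_lmdb(lmdb_path):
    """Return the keys listed in meta_info.pkl that are missing from the lmdb."""
    with open(osp.join(lmdb_path, 'meta_info.pkl'), 'rb') as f:
        meta = pickle.load(f)
    env = lmdb.open(lmdb_path, readonly=True, lock=False, readahead=False)
    with env.begin(write=False) as txn:
        missing = [k for k in meta['keys'] if txn.get(k.encode('ascii')) is None]
    env.close()
    return missing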
Example #10
File: train.py Project: zoq/BIN
def main():

    ############################################
    #
    #           set options
    #
    ############################################

    parser = argparse.ArgumentParser()
    parser.add_argument('--opt', type=str, help='Path to option YAML file.')
    parser.add_argument('--launcher',
                        choices=['none', 'pytorch'],
                        default='none',
                        help='job launcher')
    parser.add_argument('--local_rank', type=int, default=0)
    args = parser.parse_args()
    opt = option.parse(args.opt, is_train=True)

    ############################################
    #
    #           distributed training settings
    #
    ############################################

    if args.launcher == 'none':  # disabled distributed training
        opt['dist'] = False
        rank = -1
        print('Disabled distributed training.')
    else:
        opt['dist'] = True
        init_dist()
        world_size = torch.distributed.get_world_size()
        rank = torch.distributed.get_rank()

        print("Rank:", rank)
        print("World Size", world_size)
        print("------------------DIST-------------------------")

    ############################################
    #
    #           loading resume state if exists
    #
    ############################################

    if opt['path'].get('resume_state', None):
        # distributed resuming: all load into default GPU
        device_id = torch.cuda.current_device()
        resume_state = torch.load(
            opt['path']['resume_state'],
            map_location=lambda storage, loc: storage.cuda(device_id))
        option.check_resume(opt, resume_state['iter'])  # check resume options
    else:
        resume_state = None

    ############################################
    #
    #           mkdir and loggers
    #
    ############################################
    if 'debug' in opt['name']:
        debug_mode = True
    else:
        debug_mode = False

    if rank <= 0:  # normal training (rank -1) OR distributed training (rank 0)
        if resume_state is None:
            util.mkdir_and_rename(
                opt['path']
                ['experiments_root'])  # rename experiment folder if exists

            util.mkdirs(
                (path for key, path in opt['path'].items()
                 if not key == 'experiments_root'
                 and 'pretrain_model' not in key and 'resume' not in key))

        # config loggers. Before it, the log will not work
        util.setup_logger('base',
                          opt['path']['log'],
                          'train_' + opt['name'],
                          level=logging.INFO,
                          screen=True,
                          tofile=True)

        util.setup_logger('base_val',
                          opt['path']['log'],
                          'val_' + opt['name'],
                          level=logging.INFO,
                          screen=True,
                          tofile=True)

        logger = logging.getLogger('base')
        logger_val = logging.getLogger('base_val')

        logger.info(option.dict2str(opt))
        # tensorboard logger
        if opt['use_tb_logger'] and 'debug' not in opt['name']:
            version = float(torch.__version__[0:3])
            if version >= 1.1:  # PyTorch 1.1
                from torch.utils.tensorboard import SummaryWriter
            else:
                logger.info(
                    'You are using PyTorch {}. Tensorboard will use [tensorboardX]'
                    .format(version))
                from tensorboardX import SummaryWriter
            tb_logger = SummaryWriter(log_dir='../tb_logger/' + opt['name'])
    else:
        # config loggers. Before it, the log will not work
        util.setup_logger('base',
                          opt['path']['log'],
                          'train_',
                          level=logging.INFO,
                          screen=True)
        print("set train log")
        util.setup_logger('base_val',
                          opt['path']['log'],
                          'val_',
                          level=logging.INFO,
                          screen=True)
        print("set val log")
        logger = logging.getLogger('base')
        logger_val = logging.getLogger('base_val')

    # convert to NoneDict, which returns None for missing keys
    opt = option.dict_to_nonedict(opt)

    #### random seed
    seed = opt['train']['manual_seed']
    if seed is None:
        seed = random.randint(1, 10000)
    if rank <= 0:
        logger.info('Random seed: {}'.format(seed))
    util.set_random_seed(seed)

    torch.backends.cudnn.benchmark = True
    # torch.backends.cudnn.deterministic = True

    ############################################
    #
    #           create train and val dataloader
    #
    ############################################
    ####

    dataset_ratio = 200  # enlarge the size of each epoch
    for phase, dataset_opt in opt['datasets'].items():
        if phase == 'train':
            if opt['datasets']['train'].get('split', None):
                train_set, val_set = create_dataset(dataset_opt)
            else:
                train_set = create_dataset(dataset_opt)
            train_size = int(
                math.ceil(len(train_set) / dataset_opt['batch_size']))
            # total_iters = int(opt['train']['niter'])
            # total_epochs = int(math.ceil(total_iters / train_size))
            total_iters = train_size
            total_epochs = int(opt['train']['epoch'])
            if opt['dist']:
                train_sampler = DistIterSampler(train_set, world_size, rank,
                                                dataset_ratio)
                # total_epochs = int(math.ceil(total_iters / (train_size * dataset_ratio)))
                total_epochs = int(opt['train']['epoch'])
                if not opt['train']['enable']:
                    total_epochs = 1
            else:
                # train_sampler = None
                train_sampler = RandomBalancedSampler(train_set, train_size)
            train_loader = create_dataloader(train_set,
                                             dataset_opt,
                                             opt,
                                             train_sampler,
                                             vscode_debug=debug_mode)
            if rank <= 0:
                logger.info(
                    'Number of train images: {:,d}, iters: {:,d}'.format(
                        len(train_set), train_size))
                logger.info('Total epochs needed: {:d} for iters {:,d}'.format(
                    total_epochs, total_iters))
        elif phase == 'val':
            if not opt['datasets']['train'].get('split', None):
                val_set = create_dataset(dataset_opt)
            val_loader = create_dataloader(val_set,
                                           dataset_opt,
                                           opt,
                                           None,
                                           vscode_debug=debug_mode)
            if rank <= 0:
                logger.info('Number of val images in [{:s}]: {:d}'.format(
                    dataset_opt['name'], len(val_set)))
        else:
            raise NotImplementedError(
                'Phase [{:s}] is not recognized.'.format(phase))

    assert train_loader is not None

    ############################################
    #
    #          create model
    #
    ############################################
    ####

    model = create_model(opt)

    print("Model Created! ")

    #### resume training
    if resume_state:
        logger.info('Resuming training from epoch: {}, iter: {}.'.format(
            resume_state['epoch'], resume_state['iter']))

        start_epoch = resume_state['epoch']
        current_step = resume_state['iter']
        model.resume_training(resume_state)  # handle optimizers and schedulers
    else:
        current_step = 0
        start_epoch = 0
        print("Not Resume Training")

    ############################################
    #
    #          training
    #
    ############################################

    logger.info('Start training from epoch: {:d}, iter: {:d}'.format(
        start_epoch, current_step))
    model.train_AverageMeter()
    saved_total_loss = 10e10
    saved_total_PSNR = -1
    saved_total_SSIM = -1

    for epoch in range(start_epoch, total_epochs):

        ############################################
        #
        #          Start a new epoch
        #
        ############################################

        current_step = 0

        if opt['dist']:
            train_sampler.set_epoch(epoch)

        for train_idx, train_data in enumerate(train_loader):

            # print('current_step', current_step)

            if 'debug' in opt['name']:
                img_dir = os.path.join(opt['path']['train_images'])
                util.mkdir(img_dir)

                LQs = train_data['LQs']  # B N C H W

                if 'sr' not in opt['name']:
                    GTenh = train_data['GTenh']
                    GTinp = train_data['GTinp']

                    for imgs, name in zip([LQs, GTenh, GTinp],
                                          ['LQs', 'GTenh', 'GTinp']):
                        num = imgs.size(1)
                        for i in range(num):
                            img = util.tensor2img(imgs[0, i, ...])  # uint8
                            save_img_path = os.path.join(
                                img_dir, '{:4d}_{:s}_{:1d}.png'.format(
                                    train_idx, str(name), i))
                            util.save_img(img, save_img_path)
                else:
                    if 'GT' in train_data:
                        GT_name = 'GT'
                    elif 'GTs' in train_data:
                        GT_name = 'GTs'

                    GT = train_data[GT_name]
                    for imgs, name in zip([LQs, GT], ['LQs', GT_name]):
                        if name == 'GT':
                            img = util.tensor2img(imgs[0, ...])  # uint8
                            save_img_path = os.path.join(
                                img_dir, '{:4d}_{:s}_{:1d}.png'.format(
                                    train_idx, str(name), 0))
                            util.save_img(img, save_img_path)
                        else:  # 'GTs' and 'LQs' share the [B, N, C, H, W] layout
                            num = imgs.size(1)
                            for i in range(num):
                                img = util.tensor2img(imgs[:, i, ...])  # uint8
                                save_img_path = os.path.join(
                                    img_dir, '{:4d}_{:s}_{:1d}.png'.format(
                                        train_idx, str(name), i))
                                util.save_img(img, save_img_path)

                if train_idx >= 3:  # set the threshold to 0 to only run validation
                    break

            # if pre-load weight first do validation and skip the first epoch
            # if opt['path'].get('pretrain_model_G', None) and epoch == 0:
            #     epoch += 1
            #     break

            if not opt['train']['enable']:
                message_train_loss = 'None'
                break

            current_step += 1
            if current_step > total_iters:
                print("Total Iteration Reached !")
                break

            #### update learning rate
            if opt['train']['lr_scheme'] == 'ReduceLROnPlateau':
                pass
            else:
                model.update_learning_rate(
                    current_step, warmup_iter=opt['train']['warmup_iter'])

            #### training
            model.feed_data(train_data)

            model.optimize_parameters(current_step)

            model.train_AverageMeter_update()

            #### log
            if current_step % opt['logger']['print_freq'] == 0:
                logs_inst, logs_avg = model.get_current_log(
                )  # training loss  mode='train'
                message = '[epoch:{:3d}, iter:{:8,d}, lr:('.format(
                    epoch, current_step)
                for v in model.get_current_learning_rate():
                    message += '{:.3e},'.format(v)
                message += ')] '
                # if 'debug' in opt['name']:  # debug model print the instant loss
                #     for k, v in logs_inst.items():
                #         message += '{:s}: {:.4e} '.format(k, v)
                #         # tensorboard logger
                #         if opt['use_tb_logger'] and 'debug' not in opt['name']:
                #             if rank <= 0:
                #                 tb_logger.add_scalar(k, v, current_step)
                # for avg loss
                current_iters_epoch = epoch * total_iters + current_step
                for k, v in logs_avg.items():
                    message += '{:s}: {:.4e} '.format(k, v)
                    # tensorboard logger
                    if opt['use_tb_logger'] and 'debug' not in opt['name']:
                        if rank <= 0:
                            tb_logger.add_scalar(k, v, current_iters_epoch)
                if rank <= 0:
                    logger.info(message)

        # saving models
        if epoch == 1:
            save_filename = '{:04d}_{}.pth'.format(0, 'G')
            save_path = os.path.join(opt['path']['models'], save_filename)
            if os.path.exists(save_path):
                os.remove(save_path)

        save_filename = '{:04d}_{}.pth'.format(epoch - 1, 'G')
        save_path = os.path.join(opt['path']['models'], save_filename)
        if os.path.exists(save_path):
            os.remove(save_path)

        if rank <= 0:
            logger.info('Saving models and training states.')
            save_filename = '{:04d}'.format(epoch)
            model.save(save_filename)

        # ======================================================================= #
        #                  Main validation loop                                   #
        # ======================================================================= #

        if opt['datasets'].get('val', None):
            if opt['dist']:
                # multi-GPU testing
                psnr_rlt = {}  # with border and center frames
                psnr_rlt_avg = {}
                psnr_total_avg = 0.

                ssim_rlt = {}  # with border and center frames
                ssim_rlt_avg = {}
                ssim_total_avg = 0.

                val_loss_rlt = {}  # the averaged loss
                val_loss_rlt_avg = {}
                val_loss_total_avg = 0.

                if rank == 0:
                    pbar = util.ProgressBar(len(val_set))

                for idx in range(
                        rank, len(val_set),
                        world_size):  # distributed parallel validation
                    # print('idx', idx)

                    if 'debug' in opt['name']:
                        if (idx >= 3):
                            break

                    if (idx >= 1000):
                        break
                    val_data = val_set[idx]
                    # fetching by index returns a single sample, so add the batch dimension
                    val_data['LQs'].unsqueeze_(0)
                    val_data['GTenh'].unsqueeze_(0)
                    val_data['GTinp'].unsqueeze_(0)

                    key = val_data['key'][0]  # IMG_0034_00809
                    max_idx = len(val_set)
                    val_name = 'val_set'
                    num = model.get_info()  # each model has a different number of losses

                    if psnr_rlt.get(val_name, None) is None:
                        psnr_rlt[val_name] = torch.zeros([num, max_idx],
                                                         dtype=torch.float32,
                                                         device='cuda')

                    if ssim_rlt.get(val_name, None) is None:
                        ssim_rlt[val_name] = torch.zeros([num, max_idx],
                                                         dtype=torch.float32,
                                                         device='cuda')

                    if val_loss_rlt.get(val_name, None) is None:
                        val_loss_rlt[val_name] = torch.zeros(
                            [num, max_idx], dtype=torch.float32, device='cuda')

                    model.feed_data(val_data)

                    model.test()

                    avg_loss, loss_list = model.get_loss(ret=1)

                    save_enable = True
                    if idx >= 100:
                        save_enable = False

                    psnr_list, ssim_list = model.compute_current_psnr_ssim(
                        save=save_enable,
                        name=key,
                        save_path=opt['path']['val_images'])

                    # print('psnr_list',psnr_list)

                    assert len(loss_list) == num
                    assert len(psnr_list) == num

                    for i in range(num):
                        psnr_rlt[val_name][i, idx] = psnr_list[i]
                        ssim_rlt[val_name][i, idx] = ssim_list[i]
                        val_loss_rlt[val_name][i, idx] = loss_list[i]
                        # print('psnr_rlt[val_name][i, idx]',psnr_rlt[val_name][i, idx])
                        # print('ssim_rlt[val_name][i, idx]',ssim_rlt[val_name][i, idx])
                        # print('val_loss_rlt[val_name][i, idx] ',val_loss_rlt[val_name][i, idx] )

                    if rank == 0:
                        for _ in range(world_size):
                            pbar.update('Test {} - {}/{}'.format(
                                key, idx, max_idx))

                # collect per-rank results onto rank 0 before averaging
                for _, v in psnr_rlt.items():
                    for i in v:
                        dist.reduce(i, 0)

                for _, v in ssim_rlt.items():
                    for i in v:
                        dist.reduce(i, 0)

                for _, v in val_loss_rlt.items():
                    for i in v:
                        dist.reduce(i, 0)

                dist.barrier()

                if rank == 0:
                    psnr_rlt_avg = {}
                    psnr_total_avg = 0.
                    for k, v in psnr_rlt.items():  # key, value
                        # print('k', k, 'v', v, 'v.shape', v.shape)
                        psnr_rlt_avg[k] = []
                        for i in range(num):
                            non_zero_idx = v[i, :].nonzero()
                            # logger.info('non_zero_idx {}'.format(non_zero_idx.shape)) # check
                            matrix = v[i, :][non_zero_idx]
                            # print('matrix', matrix)
                            value = torch.mean(matrix).cpu().item()
                            # print('value', value)
                            psnr_rlt_avg[k].append(value)
                            psnr_total_avg += psnr_rlt_avg[k][i]
                    psnr_total_avg = psnr_total_avg / (len(psnr_rlt) * num)
                    log_p = '# Validation # Avg. PSNR: {:.2f},'.format(
                        psnr_total_avg)
                    for k, v in psnr_rlt_avg.items():
                        for i, it in enumerate(v):
                            log_p += ' {}: {:.2f}'.format(i, it)
                    logger.info(log_p)
                    logger_val.info(log_p)

                    # ssim
                    ssim_rlt_avg = {}
                    ssim_total_avg = 0.
                    for k, v in ssim_rlt.items():
                        ssim_rlt_avg[k] = []
                        for i in range(num):
                            non_zero_idx = v[i, :].nonzero()
                            # print('non_zero_idx', non_zero_idx)
                            matrix = v[i, :][non_zero_idx]
                            # print('matrix', matrix)
                            value = torch.mean(matrix).cpu().item()
                            # print('value', value)
                            ssim_rlt_avg[k].append(value)
                            ssim_total_avg += ssim_rlt_avg[k][i]
                    ssim_total_avg /= (len(ssim_rlt) * num)
                    log_s = '# Validation # Avg. SSIM: {:.2f},'.format(
                        ssim_total_avg)
                    for k, v in ssim_rlt_avg.items():
                        for i, it in enumerate(v):
                            log_s += ' {}: {:.2f}'.format(i, it)
                    logger.info(log_s)
                    logger_val.info(log_s)

                    # validation loss
                    val_loss_rlt_avg = {}
                    val_loss_total_avg = 0.
                    for k, v in val_loss_rlt.items():
                        # k, key, the folder name
                        # v, value, the torch matrix
                        val_loss_rlt_avg[k] = []  # loss0 - loss_N
                        for i in range(num):
                            non_zero_idx = v[i, :].nonzero()
                            # print('non_zero_idx', non_zero_idx)
                            matrix = v[i, :][non_zero_idx]
                            # print('matrix', matrix)
                            value = torch.mean(matrix).cpu().item()
                            # print('value', value)
                            val_loss_rlt_avg[k].append(value)
                            val_loss_total_avg += val_loss_rlt_avg[k][i]
                    val_loss_total_avg /= (len(val_loss_rlt) * num)
                    log_l = '# Validation # Avg. Loss: {:.4e},'.format(
                        val_loss_total_avg)
                    for k, v in val_loss_rlt_avg.items():
                        for i, it in enumerate(v):
                            log_l += ' {}: {:.4e}'.format(i, it)
                    logger.info(log_l)
                    logger_val.info(log_l)

                    message = ''
                    for v in model.get_current_learning_rate():
                        message += '{:.5e}'.format(v)

                    logger_val.info(
                        'Epoch {:02d}, LR {:s}, PSNR {:.4f}, SSIM {:.4f}, Val Loss {:.4e}'
                        .format(epoch, message, psnr_total_avg, ssim_total_avg,
                                val_loss_total_avg))

            else:
                pbar = util.ProgressBar(len(val_loader))

                model.val_loss_AverageMeter()
                model.val_AverageMeter_para()

                for val_inx, val_data in enumerate(val_loader):

                    # if 'debug' in opt['name']:
                    #     if (val_inx >= 10):
                    #         break

                    if val_inx >= 100:
                        break
                    save_enable = True

                    key = val_data['key'][0]

                    folder = key[:-6]
                    model.feed_data(val_data)

                    model.test()

                    avg_loss, loss_list = model.get_loss(ret=1)

                    model.val_loss_AverageMeter_update(loss_list, avg_loss)

                    psnr_list, ssim_list = model.compute_current_psnr_ssim(
                        save=save_enable,
                        name=key,
                        save_path=opt['path']['val_images'])

                    model.val_AverageMeter_para_update(psnr_list, ssim_list)

                    if 'debug' in opt['name']:
                        msg_psnr = ''
                        msg_ssim = ''
                        for i, psnr in enumerate(psnr_list):
                            msg_psnr += '{} :{:.02f} '.format(i, psnr)
                        for i, ssim in enumerate(ssim_list):
                            msg_ssim += '{} :{:.02f} '.format(i, ssim)

                        logger.info('{}_{:02d} {}'.format(
                            key, val_inx, msg_psnr))
                        logger.info('{}_{:02d} {}'.format(
                            key, val_inx, msg_ssim))

                    pbar.update('Test {} - {}'.format(key, val_inx))

                # total validation log

                lr = ''
                for v in model.get_current_learning_rate():
                    lr += '{:.5e}'.format(v)

                logs_avg, logs_psnr_avg, psnr_total_avg, ssim_total_avg, val_loss_total_avg = model.get_current_log(
                    mode='val')

                msg_logs_avg = ''
                for k, v in logs_avg.items():
                    msg_logs_avg += '{:s}: {:.4e} '.format(k, v)

                logger_val.info('Val-Epoch {:02d}, LR {:s}, {:s}'.format(
                    epoch, lr, msg_logs_avg))
                logger.info('Val-Epoch {:02d}, LR {:s}, {:s}'.format(
                    epoch, lr, msg_logs_avg))

                msg_logs_psnr_avg = ''
                for k, v in logs_psnr_avg.items():
                    msg_logs_psnr_avg += '{:s}: {:.4e} '.format(k, v)

                logger_val.info('Val-Epoch {:02d}, LR {:s}, {:s}'.format(
                    epoch, lr, msg_logs_psnr_avg))
                logger.info('Val-Epoch {:02d}, LR {:s}, {:s}'.format(
                    epoch, lr, msg_logs_psnr_avg))

                # tensorboard logger
                if opt['use_tb_logger'] and 'debug' not in opt['name']:
                    tb_logger.add_scalar('val_psnr', psnr_total_avg, epoch)
                    tb_logger.add_scalar('val_loss', val_loss_total_avg, epoch)

        ############################################
        #
        #          end of validation, save model
        #
        ############################################
        #
        if rank <= 0:
            logger.info("Finished an epoch, Check and Save the model weights")
            # checkpoint selection uses the validation loss rather than the training loss
            if saved_total_loss >= val_loss_total_avg:
                saved_total_loss = val_loss_total_avg
                #torch.save(model.state_dict(), args.save_path + "/best" + ".pth")
                model.save('best')
                logger.info(
                    "Best weights updated: validation loss decreased")
            else:
                logger.info(
                    "Best weights not updated: validation loss did not decrease")
            if saved_total_PSNR <= psnr_total_avg:
                saved_total_PSNR = psnr_total_avg
                model.save('bestPSNR')
                logger.info(
                    "Best PSNR weights updated: validation PSNR increased")
            else:
                logger.info(
                    "Best PSNR weights not updated: validation PSNR did not increase")

        ############################################
        #
        #          end of one epoch, schedule LR
        #
        ############################################

        model.train_AverageMeter_reset()

        # step the ReduceLROnPlateau scheduler on the averaged validation loss
        if opt['train']['lr_scheme'] == 'ReduceLROnPlateau':
            for scheduler in model.schedulers:
                scheduler.step(val_loss_total_avg)
    if rank <= 0:
        logger.info('Saving the final model.')
        model.save('last')
        logger.info('End of training.')
        if opt['use_tb_logger'] and 'debug' not in opt['name']:
            tb_logger.close()  # tb_logger exists only when the tensorboard logger was created
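
A minimal, single-process sketch of the non-zero averaging used in the multi-GPU validation above; the shapes and values below are invented purely for illustration.

import torch

num_heads, max_idx = 2, 6
psnr = torch.zeros([num_heads, max_idx], dtype=torch.float32)
# pretend rank 0 processed samples 0, 2, 4 and rank 1 processed samples 1, 3, 5;
# dist.reduce sums the per-rank tensors, so rank 0 ends up holding the union
psnr[:, [0, 2, 4]] = torch.tensor([[30.1, 31.5, 29.8], [27.0, 27.9, 26.5]])
psnr[:, [1, 3, 5]] = torch.tensor([[28.9, 30.2, 29.4], [26.8, 27.2, 27.1]])

for i in range(num_heads):
    non_zero_idx = psnr[i, :].nonzero()        # slots that were actually filled
    value = torch.mean(psnr[i, :][non_zero_idx]).item()
    print('output {}: avg PSNR {:.2f}'.format(i, value))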
Esempio n. 11
0
def SDR4k(mode):
    """Create lmdb for the REDS dataset, each image with a fixed size
    10bit: [3, 2160, 3840], key: 00000000_000
    4bit: [3, 2160, 3840], key: 00000000_000
    """
    #### configurations
    read_all_imgs = False  # whether to read all images into memory with multiprocessing
    # Set False when memory is limited
    BATCH = 5000  # if read_all_imgs is False, the lmdb transaction commits every BATCH images
    if mode == "10bit":
        # img_folder = "..\\..\\datasets\\SDR_10bit"  # for windows
        # lmdb_save_path = "..\\..\\datasets\\SDR_10bit.lmdb"  # for windows
        img_folder = "../../datasets/SDR4k/train/SDR_10BIT_patch"  # for linux
        lmdb_save_path = "../../datasets/SDR4k/train/SDR_10BIT_patch.lmdb"  # for linux
        # H_dst, W_dst = 2160, 3840
        H_dst, W_dst = 480, 480
    elif mode == "4bit":
        # img_folder = "..\\..\\datasets\\SDR_4bit"  # for windows
        # lmdb_save_path = "..\\..\\datasets\\SDR_4bit.lmdb"
        img_folder = "../../datasets/SDR4k/train/SDR_4BIT_patch"  # for linux
        lmdb_save_path = "../../datasets/SDR4k/train/SDR_4BIT_patch.lmdb"  # for linux
        # H_dst, W_dst = 2160, 3840
        H_dst, W_dst = 480, 480
    n_thread = 40
    ########################################################
    if not lmdb_save_path.endswith('.lmdb'):
        raise ValueError("lmdb_save_path must end with \'lmdb\'.")
    if osp.exists(lmdb_save_path):
        print('Folder [{:s}] already exists. Exit...'.format(lmdb_save_path))
        sys.exit(1)

    #### read all the image paths to a list
    print('Reading image path list ...')
    all_img_list = data_util._get_paths_from_images(img_folder)
    keys = []
    for img_path in all_img_list:
        split_rlt = img_path.split('/')  # for linux
        # split_rlt = img_path.split("\\")  # for windows
        folder = split_rlt[-2]
        img_name = split_rlt[-1].split('.png')[0]
        keys.append(folder + '_' + img_name)
        # keys: 00000000_000_000

    if read_all_imgs:
        #### read all images to memory (multiprocessing)
        dataset = {}  # store all image data. list cannot keep the order, use dict
        print('Read images with multiprocessing, #thread: {} ...'.format(n_thread))
        pbar = util.ProgressBar(len(all_img_list))

        def mycallback(arg):
            '''get the image data and update pbar'''
            key = arg[0]
            dataset[key] = arg[1]
            pbar.update('Reading {}'.format(key))

        pool = Pool(n_thread)
        for path, key in zip(all_img_list, keys):
            pool.apply_async(read_image_worker, args=(path, key), callback=mycallback)
        pool.close()
        pool.join()
        print('Finish reading {} images.\nWrite lmdb...'.format(len(all_img_list)))

    #### create lmdb environment
    data_size_per_img = cv2.imread(all_img_list[0], cv2.IMREAD_UNCHANGED).nbytes
    print('data size per image is: ', data_size_per_img)
    data_size = data_size_per_img * len(all_img_list)
    env = lmdb.open(lmdb_save_path, map_size=data_size * 10)

    #### write data to lmdb
    print("Start writing...")
    pbar = util.ProgressBar(len(all_img_list))
    txn = env.begin(write=True)
    rm_list = []
    for idx, (path, key) in enumerate(zip(all_img_list, keys)):
        pbar.update('Write {}'.format(key))
        key_byte = key.encode('ascii')
        data = dataset[key] if read_all_imgs else cv2.imread(path, cv2.IMREAD_UNCHANGED)
        H, W, C = data.shape
        assert H == H_dst and W == W_dst and C == 3, 'different shape.'
        txn.put(key_byte, data)
        # delete the image
        rm_list.append(path)
        if not read_all_imgs and idx % BATCH == 0:
            txn.commit()
            for img in rm_list:
                print("os.system('rm {}')".format(img))  # dry-run: only prints the remove command
            rm_list = []
            txn = env.begin(write=True)
    txn.commit()
    env.close()
    print('Finish writing lmdb.')

    #### create meta information
    meta_info = {}
    meta_info['name'] = 'SDR4k_{}'.format(mode)
    channel = 3
    meta_info['resolution'] = '{}_{}_{}'.format(channel, H_dst, W_dst)
    meta_info['keys'] = keys
    pickle.dump(meta_info, open(osp.join(lmdb_save_path, 'meta_info.pkl'), "wb"))
    print('Finish creating lmdb meta info.')
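
A hedged read-back sketch for an lmdb produced by SDR4k() above. The path reuses the placeholder from the function, and the uint16 dtype is an assumption for 10-bit PNGs read with cv2.IMREAD_UNCHANGED (8-bit sources would decode to uint8).

import os.path as osp
import pickle

import lmdb
import numpy as np

lmdb_path = '../../datasets/SDR4k/train/SDR_10BIT_patch.lmdb'  # placeholder path
meta = pickle.load(open(osp.join(lmdb_path, 'meta_info.pkl'), 'rb'))
C, H, W = [int(x) for x in meta['resolution'].split('_')]

env = lmdb.open(lmdb_path, readonly=True, lock=False, readahead=False, meminit=False)
with env.begin(write=False) as txn:
    buf = txn.get(meta['keys'][0].encode('ascii'))  # raw bytes written by txn.put
img = np.frombuffer(buf, dtype=np.uint16).reshape(H, W, C)  # dtype assumption for 10-bit data
print(meta['keys'][0], img.shape, img.dtype)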
Esempio n. 12
0
def OURS(mode="input"):
    '''create lmdb for our dataset, each image with fixed size
    GT: [3, H, W], key: 000000_000000
    LR: [3, H, W], key: 000000_000000
    key: 000000_00000
    ** Remember our data layout from before? {sub-directory name}_{image name}
    '''
    #### configurations
    # ** data mode: input / gt
    read_all_imgs = False  # whether to read all images into memory. Set False with limited memory
    BATCH = 5000  # if read_all_imgs is False, the lmdb transaction commits every BATCH images

    if mode == 'input':
        img_folder = './../../datasets/train/input'  # ** relative path to the input frames of our dataset
        lmdb_save_path = './../../datasets/train_input_wval.lmdb'  # ** path where the generated lmdb is stored
        '''The original script used absolute paths; here we use relative paths'''
        H_dst, W_dst = 480, 640  # frame size: H, W

    elif mode == 'gt':
        img_folder = './../../datasets/train/gt'  # ** relative path to the gt frames of our dataset
        lmdb_save_path = './../../datasets/train_gt_wval.lmdb'  # ** path where the generated lmdb is stored
        '''The original script used absolute paths; here we use relative paths'''
        H_dst, W_dst = 480, 640  # frame size: H, W

    n_thread = 2
    ########################################################
    if not lmdb_save_path.endswith('.lmdb'):
        raise ValueError(
            "lmdb_save_path must end with \'lmdb\'.")  # 保存格式必须以“.lmdb”结尾
    #### whether the lmdb file exist
    if osp.exists(lmdb_save_path):
        print('Folder [{:s}] already exists. Exit...'.format(
            lmdb_save_path))  # 文件是否已经存在
        sys.exit(1)

    #### read all the image paths to a list
    print('Reading image path list ...')
    all_img_list = data_util._get_paths_from_images(
        img_folder)  # collect the full paths of all frames under input/gt into a list
    keys = []
    for img_path in all_img_list:
        split_rlt = img_path.split('/')
        # take the sub-folder name xxxxxx
        a = split_rlt[-2]
        # take the frame name without the file extension xxxxxx
        b = split_rlt[-1].split('.jpg')[0]  # ** our images end with ".jpg"
        keys.append(a + '_' + b)

    if read_all_imgs:  # read_all_imgs = False, so this branch is skipped
        #### read all images to memory (multiprocessing)
        dataset = {
        }  # store all image data. list cannot keep the order, use dict
        print('Read images with multiprocessing, #thread: {} ...'.format(
            n_thread))
        pbar = util.ProgressBar(len(all_img_list))

        def mycallback(arg):
            '''get the image data and update pbar'''
            key = arg[0]
            dataset[key] = arg[1]
            pbar.update('Reading {}'.format(key))

        pool = Pool(n_thread)
        for path, key in zip(all_img_list, keys):
            pool.apply_async(read_image_worker,
                             args=(path, key),
                             callback=mycallback)
        pool.close()
        pool.join()
        print('Finish reading {} images.\nWrite lmdb...'.format(
            len(all_img_list)))

    #### create lmdb environment
    data_size_per_img = cv2.imread(
        all_img_list[0], cv2.IMREAD_UNCHANGED).nbytes  # size of one frame in bytes
    if 'flow' in mode:
        data_size_per_img = dataset['000_00000002_n1'].nbytes
    print('data size per image is: ', data_size_per_img)
    data_size = data_size_per_img * len(all_img_list)  # total space required
    env = lmdb.open(lmdb_save_path, map_size=data_size * 10)  # reserve roughly 10x that much space
    #### write data to lmdb
    pbar = util.ProgressBar(len(all_img_list))
    txn = env.begin(write=True)
    idx = 1
    for path, key in zip(all_img_list, keys):

        idx = idx + 1
        pbar.update('Write {}'.format(key))
        key_byte = key.encode('ascii')
        data = dataset[key] if read_all_imgs else cv2.imread(
            path, cv2.IMREAD_UNCHANGED)
        if 'flow' in mode:
            H, W = data.shape
            assert H == H_dst and W == W_dst, 'different shape.'
        else:
            H, W, C = data.shape  # fixed shape
            assert H == H_dst and W == W_dst and C == 3, 'different shape.'
        txn.put(key_byte, data)
        if not read_all_imgs and idx % BATCH == 1:
            txn.commit()
            txn = env.begin(write=True)
    txn.commit()
    env.close()
    print('Finish writing lmdb.')

    #### create meta information  # store metadata: name (str) + resolution (str)
    meta_info = {}
    meta_info['name'] = 'OURS_{}_wval'.format(mode)  # ** the dataset is now OURS
    if 'flow' in mode:
        meta_info['resolution'] = '{}_{}_{}'.format(1, H_dst, W_dst)
    else:
        meta_info['resolution'] = '{}_{}_{}'.format(3, H_dst, W_dst)
    meta_info['keys'] = keys
    pickle.dump(meta_info, open(osp.join(lmdb_save_path, 'meta_info.pkl'),
                                "wb"))
    print('Finish creating lmdb meta info.')
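
A tiny illustration of the {sub-directory name}_{image name} key convention built in OURS() above; the path below is made up.

# derive the lmdb key from a frame path, mirroring the loop in OURS()
img_path = './../../datasets/train/input/000023/00000107.jpg'  # hypothetical frame
split_rlt = img_path.split('/')
key = split_rlt[-2] + '_' + split_rlt[-1].split('.jpg')[0]
print(key)  # -> 000023_00000107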
def vimeo7():
    '''create lmdb for the Vimeo90K-7 frames dataset, each image with fixed size
    GT: [3, 256, 448]
        With 1st - 7th frames, e.g., 00001_0001_1, ..., 00001_0001_7
    LR: [3, 64, 112]
        With 1st - 7th frames, e.g., 00001_0001_1, ..., 00001_0001_7
    key:
        Use the folder and subfolder names, w/o the frame index, e.g., 00001_0001
    '''
    #### configurations
    mode = 'GT'  # GT | LR
    batch = 3000  # TODO: tune depending on your memory size
    if mode == 'GT':
        img_folder = '/data/datasets/SR/vimeo_septuplet/sequences/train'
        lmdb_save_path = '/data/datasets/SR/vimeo_septuplet/vimeo7_train_GT.lmdb'
        txt_file = '/data/datasets/SR/vimeo_septuplet/sep_trainlist.txt'
        H_dst, W_dst = 256, 448
    elif mode == 'LR':
        img_folder = '/data/datasets/SR/vimeo_septuplet/sequences_LR/LR/x4/train'
        lmdb_save_path = '/data/datasets/SR/vimeo_septuplet/vimeo7_train_LR7.lmdb'
        txt_file = '/data/datasets/SR/vimeo_septuplet/sep_trainlist.txt'
        H_dst, W_dst = 64, 112
    n_thread = 40
    ########################################################
    if not lmdb_save_path.endswith('.lmdb'):
        raise ValueError("lmdb_save_path must end with \'lmdb\'.")
    #### whether the lmdb file exist
    if osp.exists(lmdb_save_path):
        print('Folder [{:s}] already exists. Exit...'.format(lmdb_save_path))
        sys.exit(1)

    #### read all the image paths to a list
    print('Reading image path list ...')
    with open(txt_file) as f:
        train_l = f.readlines()
        train_l = [v.strip() for v in train_l]
    all_img_list = []
    keys = []
    for line in train_l:
        folder = line.split('/')[0]
        sub_folder = line.split('/')[1]
        file_l = glob.glob(osp.join(img_folder, folder, sub_folder) + '/*')
        all_img_list.extend(file_l)
        for j in range(7):
            keys.append('{}_{}_{}'.format(folder, sub_folder, j + 1))
    all_img_list = sorted(all_img_list)
    keys = sorted(keys)
    if mode == 'GT':
        all_img_list = [v for v in all_img_list if v.endswith('.png')]
    print('Calculating the total size of images...')
    data_size = sum(os.stat(v).st_size for v in all_img_list)

    #### images are read sequentially in the write loop below (no pre-loading with multiprocessing)
    print('Reading and writing images sequentially ...')

    #### create lmdb environment
    env = lmdb.open(lmdb_save_path, map_size=data_size * 10)
    txn = env.begin(write=True)  # txn is a Transaction object

    #### write data to lmdb
    pbar = util.ProgressBar(len(all_img_list))

    i = 0
    for path, key in zip(all_img_list, keys):
        pbar.update('Write {}'.format(key))
        img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
        key_byte = key.encode('ascii')
        H, W, C = img.shape  # fixed shape
        assert H == H_dst and W == W_dst and C == 3, 'different shape.'
        txn.put(key_byte, img)
        i += 1
        if i % batch == 1:
            txn.commit()
            txn = env.begin(write=True)

    txn.commit()
    env.close()
    print('Finish reading and writing {} images.'.format(len(all_img_list)))
    print('Finish writing lmdb.')

    #### create meta information
    meta_info = {}
    if mode == 'GT':
        meta_info['name'] = 'Vimeo7_train_GT'
    elif mode == 'LR':
        meta_info['name'] = 'Vimeo7_train_LR7'
    meta_info['resolution'] = '{}_{}_{}'.format(3, H_dst, W_dst)
    key_set = set()
    for key in keys:
        a, b, _ = key.split('_')
        key_set.add('{}_{}'.format(a, b))
    meta_info['keys'] = key_set
    pickle.dump(meta_info, open(osp.join(lmdb_save_path, 'Vimeo7_train_keys.pkl'), "wb"))
    print('Finish creating lmdb meta info.')
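
A small sketch of the meta-info key collapsing done at the end of vimeo7(): the per-frame keys are reduced to one clip-level key; the clip names below are made up.

# per-frame keys 00001_0001_1 ... 00001_0001_7 collapse to the clip key 00001_0001
keys = ['00001_0001_{}'.format(j + 1) for j in range(7)] + \
       ['00001_0266_{}'.format(j + 1) for j in range(7)]
key_set = set()
for key in keys:
    a, b, _ = key.split('_')
    key_set.add('{}_{}'.format(a, b))
print(sorted(key_set))  # -> ['00001_0001', '00001_0266']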
Esempio n. 14
0
def VideoSR(mode):
    """Create lmdb for the Video dataset, each image with a fixed size
    LR: [3, 540, 960], key: 000_00000000
    GT: [3, 2160, 3840], key: 000_00000000
    key: 000_00000000

    """
    #### configurations
    read_all_imgs = False  # whether to read all images into memory with multiprocessing
    # Set False when memory is limited
    BATCH = 2000  # if read_all_imgs is False, the lmdb transaction commits every BATCH images
    if mode == 'GT':
        img_folder = '/home/yhliu/AI4K/train1_HR_png/'
        lmdb_save_path = '/home/yhliu/AI4K/train1_HR.lmdb'
        H_dst, W_dst = 2160, 3840
    elif mode == 'LR':
        #img_folder = '/home/yhliu/AI4K/contest2/train2_LR_png/'
        #lmdb_save_path = '/home/yhliu/AI4K/contest2/train2_LR.lmdb'
        img_folder = '/home/yhliu/BasicSR/results/trainLR_35_ResNet_alpha_beta_decoder_3x3_IN_encoder_8HW_re_100k_220000/trainLR_35_ResNet_alpha_beta_decoder_3x3_IN_encoder_8HW_re_100k_220000/'
        lmdb_save_path = '/home/yhliu/AI4K/contest2/train2_LR_35_220000.lmdb'

        H_dst, W_dst = 540, 960

    n_thread = 40
    ########################################################
    if not lmdb_save_path.endswith('.lmdb'):
        raise ValueError("lmdb_save_path must end with \'lmdb\'.")
    if osp.exists(lmdb_save_path):
        print('Folder [{:s}] already exists. Exit...'.format(lmdb_save_path))
        sys.exit(1)

    #### read all the image paths to a list
    print('Reading image path list ...')
    all_img_list = data_util._get_paths_from_images(img_folder)
    keys = []
    for img_path in all_img_list:
        split_rlt = img_path.split('/')
        folder = split_rlt[-2]
        img_name = split_rlt[-1].split('.png')[0]
        keys.append(folder + '_' + img_name)

    if read_all_imgs:
        #### read all images to memory (multiprocessing)
        dataset = {
        }  # store all image data. list cannot keep the order, use dict
        print('Read images with multiprocessing, #thread: {} ...'.format(
            n_thread))
        pbar = util.ProgressBar(len(all_img_list))

        def mycallback(arg):
            '''get the image data and update pbar'''
            key = arg[0]
            dataset[key] = arg[1]
            pbar.update('Reading {}'.format(key))

        pool = Pool(n_thread)
        for path, key in zip(all_img_list, keys):
            pool.apply_async(read_image_worker,
                             args=(path, key),
                             callback=mycallback)
        pool.close()
        pool.join()
        print('Finish reading {} images.\nWrite lmdb...'.format(
            len(all_img_list)))

    #### create lmdb environment
    data_size_per_img = cv2.imread(all_img_list[0],
                                   cv2.IMREAD_UNCHANGED).nbytes
    print('data size per image is: ', data_size_per_img)
    data_size = data_size_per_img * len(all_img_list)
    env = lmdb.open(lmdb_save_path, map_size=data_size * 10)

    #### write data to lmdb
    pbar = util.ProgressBar(len(all_img_list))
    txn = env.begin(write=True)
    for idx, (path, key) in enumerate(zip(all_img_list, keys)):
        pbar.update('Write {}'.format(key))
        key_byte = key.encode('ascii')
        data = dataset[key] if read_all_imgs else cv2.imread(
            path, cv2.IMREAD_UNCHANGED)
        if 'flow' in mode:
            H, W = data.shape
            assert H == H_dst and W == W_dst, 'different shape.'
        else:
            H, W, C = data.shape
            assert H == H_dst and W == W_dst and C == 3, 'different shape.'
        txn.put(key_byte, data)
        if not read_all_imgs and idx % BATCH == 0:
            txn.commit()
            txn = env.begin(write=True)
    txn.commit()
    env.close()
    print('Finish writing lmdb.')

    #### create meta information
    meta_info = {}
    meta_info['name'] = 'AI4K_{}_train1'.format(mode)
    channel = 1 if 'flow' in mode else 3
    meta_info['resolution'] = '{}_{}_{}'.format(channel, H_dst, W_dst)
    meta_info['keys'] = keys
    pickle.dump(meta_info, open(osp.join(lmdb_save_path, 'meta_info.pkl'),
                                "wb"))
    print('Finish creating lmdb meta info.')
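
A minimal, self-contained sketch of the commit-every-BATCH pattern VideoSR() relies on; it writes a few dummy entries into a throwaway lmdb so that no single transaction grows unbounded.

import os.path as osp
import tempfile

import lmdb
import numpy as np

BATCH = 2
save_path = osp.join(tempfile.mkdtemp(), 'demo.lmdb')  # throwaway location
env = lmdb.open(save_path, map_size=10 * 1024 * 1024)
txn = env.begin(write=True)
for idx in range(5):
    txn.put('{:03d}'.format(idx).encode('ascii'), np.zeros(4, dtype=np.uint8))
    if idx % BATCH == 0:  # periodic commit, mirroring `idx % BATCH == 0` above
        txn.commit()
        txn = env.begin(write=True)
txn.commit()
env.close()
print('wrote 5 dummy entries to', save_path)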
Esempio n. 15
0
def main():
    #### options
    parser = argparse.ArgumentParser()
    parser.add_argument('-opt', type=str, help='Path to option YAML file.')
    parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none',
                        help='job launcher')
    parser.add_argument('--local_rank', type=int, default=0)
    parser.add_argument('--exp_name', type=str, default='temp')
    parser.add_argument('--degradation_type', type=str, default=None)
    parser.add_argument('--sigma_x', type=float, default=None)
    parser.add_argument('--sigma_y', type=float, default=None)
    parser.add_argument('--theta', type=float, default=None)
    args = parser.parse_args()
    if args.exp_name == 'temp':
        opt = option.parse(args.opt, is_train=False)
    else:
        opt = option.parse(args.opt, is_train=False, exp_name=args.exp_name)

    # convert to NoneDict, which returns None for missing keys
    opt = option.dict_to_nonedict(opt)
    inner_loop_name = opt['train']['maml']['optimizer'][0] + str(opt['train']['maml']['adapt_iter']) + str(math.floor(math.log10(opt['train']['maml']['lr_alpha'])))
    meta_loop_name = opt['train']['optim'][0] + str(math.floor(math.log10(opt['train']['lr_G'])))

    if args.degradation_type is not None:
        if args.degradation_type == 'preset':
            opt['datasets']['val']['degradation_mode'] = args.degradation_type
        else:
            opt['datasets']['val']['degradation_type'] = args.degradation_type
    if args.sigma_x is not None:
        opt['datasets']['val']['sigma_x'] = args.sigma_x
    if args.sigma_y is not None:
        opt['datasets']['val']['sigma_y'] = args.sigma_y
    if args.theta is not None:
        opt['datasets']['val']['theta'] = args.theta
    
    if 'degradation_mode' not in opt['datasets']['val'].keys():
        degradation_name = ''
    elif opt['datasets']['val']['degradation_mode'] == 'set':
        degradation_name = '_' + str(opt['datasets']['val']['degradation_type'])\
                  + '_' + str(opt['datasets']['val']['sigma_x']) \
                  + '_' + str(opt['datasets']['val']['sigma_y'])\
                  + '_' + str(opt['datasets']['val']['theta'])
    else:
        degradation_name = '_' + opt['datasets']['val']['degradation_mode']
    folder_name = opt['name'] + '_' + degradation_name

    if args.exp_name != 'temp':
        folder_name = args.exp_name

    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

    #### create train and val dataloader
    dataset_ratio = 200  # enlarge the size of each epoch
    for phase, dataset_opt in opt['datasets'].items():
        if phase == 'train':
            pass
        elif phase == 'val':
            if '+' in opt['datasets']['val']['name']:
                raise NotImplementedError('Do not use + signs in test mode')
            else:
                val_set = create_dataset(dataset_opt, scale=opt['scale'],
                                         kernel_size=opt['datasets']['train']['kernel_size'],
                                         model_name=opt['network_E']['which_model_E'])
                # val_set = loader.get_dataset(opt, train=False)
                val_loader = create_dataloader(val_set, dataset_opt, opt, None)

            print('Number of val images in [{:s}]: {:d}'.format(dataset_opt['name'], len(val_set)))
        else:
            raise NotImplementedError('Phase [{:s}] is not recognized.'.format(phase))

    #### create model
    models = create_model(opt)
    assert len(models) == 2
    model, est_model = models[0], models[1]
    modelcp, est_modelcp = create_model(opt)
    _, est_model_fixed = create_model(opt)

    center_idx = (opt['datasets']['val']['N_frames']) // 2
    lr_alpha = opt['train']['maml']['lr_alpha']
    update_step = opt['train']['maml']['adapt_iter']
    with_GT = False if opt['datasets']['val']['mode'] == 'demo' else True

    pd_log = pd.DataFrame(columns=['PSNR_Bicubic', 'PSNR_Ours', 'SSIM_Bicubic', 'SSIM_Ours'])

    def crop(LR_seq, HR, num_patches_for_batch=4, patch_size=44):
        """
        Crop given patches.

        Args:
            LR_seq: (B=1) x T x C x H x W
            HR: (B=1) x C x H x W

            patch_size (int, optional):

        Return:
            B(=batch_size) x T x C x H x W
        """
        # Find the lowest resolution
        cropped_lr = []
        cropped_hr = []
        assert HR.size(0) == 1
        LR_seq_ = LR_seq[0]
        HR_ = HR[0]
        for _ in range(num_patches_for_batch):
            patch_lr, patch_hr = preprocessing.common_crop(LR_seq_, HR_, patch_size=patch_size // 2)
            cropped_lr.append(patch_lr)
            cropped_hr.append(patch_hr)

        cropped_lr = torch.stack(cropped_lr, dim=0)
        cropped_hr = torch.stack(cropped_hr, dim=0)

        return cropped_lr, cropped_hr

    # Single GPU
    # PSNR_rlt: psnr_init (bicubic), psnr_after (ours)
    psnr_rlt = [{}, {}]
    # SSIM_rlt: ssim_init, ssim_after
    ssim_rlt = [{}, {}]

    pbar = util.ProgressBar(len(val_set))
    for val_data in val_loader:
        folder = val_data['folder'][0]
        idx_d = int(val_data['idx'][0].split('/')[0])
        if 'name' in val_data.keys():
            name = val_data['name'][0][center_idx][0]
        else:
            name = folder

        train_folder = os.path.join('../test_results', folder_name, name)
        maml_train_folder = os.path.join(train_folder, 'DynaVSR')

        if not os.path.exists(train_folder):
            os.makedirs(train_folder, exist_ok=False)
        if not os.path.exists(maml_train_folder):
            os.mkdir(maml_train_folder)

        for i in range(len(psnr_rlt)):
            if psnr_rlt[i].get(folder, None) is None:
                psnr_rlt[i][folder] = []
        for i in range(len(ssim_rlt)):
            if ssim_rlt[i].get(folder, None) is None:
                ssim_rlt[i][folder] = []
        
        cropped_meta_train_data = {}
        meta_train_data = {}
        meta_test_data = {}

        # Make SuperLR seq using estimation model
        meta_train_data['GT'] = val_data['LQs'][:, center_idx]
        meta_test_data['LQs'] = val_data['LQs'][0:1]
        meta_test_data['GT'] = val_data['GT'][0:1, center_idx] if with_GT else None
        # Check whether the batch size of each validation data is 1
        assert val_data['LQs'].size(0) == 1

        if opt['network_G']['which_model_G'] == 'TOF':
            LQs = meta_test_data['LQs']
            B, T, C, H, W = LQs.shape
            LQs = LQs.reshape(B*T, C, H, W)
            Bic_LQs = F.interpolate(LQs, scale_factor=opt['scale'], mode='bicubic', align_corners=True)
            meta_test_data['LQs'] = Bic_LQs.reshape(B, T, C, H*opt['scale'], W*opt['scale'])
        
        ## Before start testing
        # Bicubic Model Results
        modelcp.load_network(opt['path']['bicubic_G'], modelcp.netG)
        modelcp.feed_data(meta_test_data, need_GT=with_GT)
        modelcp.test()

        if with_GT:
            model_start_visuals = modelcp.get_current_visuals(need_GT=True)
            hr_image = util.tensor2img(model_start_visuals['GT'], mode='rgb')
            start_image = util.tensor2img(model_start_visuals['rlt'], mode='rgb')
            psnr_rlt[0][folder].append(util.calculate_psnr(start_image, hr_image))
            ssim_rlt[0][folder].append(util.calculate_ssim(start_image, hr_image))

        modelcp.netG, est_modelcp.netE = deepcopy(model.netG), deepcopy(est_model.netE)

        ########## SLR LOSS Preparation ############
        est_model_fixed.load_network(opt['path']['fixed_E'], est_model_fixed.netE)

        optim_params = []
        for k, v in modelcp.netG.named_parameters():
            if v.requires_grad:
                optim_params.append(v)
        
        if not opt['train']['use_real']:
            for k, v in est_modelcp.netE.named_parameters():
                if v.requires_grad:
                    optim_params.append(v)
        
        if opt['train']['maml']['optimizer'] == 'Adam':
            inner_optimizer = torch.optim.Adam(optim_params, lr=lr_alpha,
                                               betas=(
                                                   opt['train']['maml']['beta1'],
                                                   opt['train']['maml']['beta2']))
        elif opt['train']['maml']['optimizer'] == 'SGD':
            inner_optimizer = torch.optim.SGD(optim_params, lr=lr_alpha)
        else:
            raise NotImplementedError()

        # Inner Loop Update
        st = time.time()
        for i in range(update_step):
            # Make SuperLR seq using UPDATED estimation model
            if not opt['train']['use_real']:
                est_modelcp.feed_data(val_data)
                est_modelcp.forward_without_optim()
                superlr_seq = est_modelcp.fake_L
                meta_train_data['LQs'] = superlr_seq
            else:
                meta_train_data['LQs'] = val_data['SuperLQs']

            if opt['network_G']['which_model_G'] == 'TOF':
                # Bicubic upsample to match the size
                LQs = meta_train_data['LQs']
                B, T, C, H, W = LQs.shape
                LQs = LQs.reshape(B*T, C, H, W)
                Bic_LQs = F.interpolate(LQs, scale_factor=opt['scale'], mode='bicubic', align_corners=True)
                meta_train_data['LQs'] = Bic_LQs.reshape(B, T, C, H*opt['scale'], W*opt['scale'])

            # Update both modelcp + estmodelcp jointly
            inner_optimizer.zero_grad()
            if opt['train']['maml']['use_patch']:
                cropped_meta_train_data['LQs'], cropped_meta_train_data['GT'] = \
                    crop(meta_train_data['LQs'], meta_train_data['GT'],
                         opt['train']['maml']['num_patch'],
                         opt['train']['maml']['patch_size'])
                modelcp.feed_data(cropped_meta_train_data)
            else:
                modelcp.feed_data(meta_train_data)

            loss_train = modelcp.calculate_loss()
            
            ##################### SLR LOSS ###################
            est_model_fixed.feed_data(val_data)
            est_model_fixed.test()
            slr_initialized = est_model_fixed.fake_L
            slr_initialized = slr_initialized.to('cuda')
            if opt['network_G']['which_model_G'] == 'TOF':
                loss_train += 10 * F.l1_loss(LQs.to('cuda').squeeze(0), slr_initialized)
            else:
                loss_train += 10 * F.l1_loss(meta_train_data['LQs'].to('cuda'), slr_initialized)
            
            loss_train.backward()
            inner_optimizer.step()

        et = time.time()
        update_time = et - st

        modelcp.feed_data(meta_test_data, need_GT=with_GT)
        modelcp.test()

        model_update_visuals = modelcp.get_current_visuals(need_GT=False)
        update_image = util.tensor2img(model_update_visuals['rlt'], mode='rgb')
        # Save and calculate final image
        imageio.imwrite(os.path.join(maml_train_folder, '{:08d}.png'.format(idx_d)), update_image)

        if with_GT:
            psnr_rlt[1][folder].append(util.calculate_psnr(update_image, hr_image))
            ssim_rlt[1][folder].append(util.calculate_ssim(update_image, hr_image))

            name_df = '{}/{:08d}'.format(folder, idx_d)
            if name_df in pd_log.index:
                pd_log.at[name_df, 'PSNR_Bicubic'] = psnr_rlt[0][folder][-1]
                pd_log.at[name_df, 'PSNR_Ours'] = psnr_rlt[1][folder][-1]
                pd_log.at[name_df, 'SSIM_Bicubic'] = ssim_rlt[0][folder][-1]
                pd_log.at[name_df, 'SSIM_Ours'] = ssim_rlt[1][folder][-1]
            else:
                pd_log.loc[name_df] = [psnr_rlt[0][folder][-1],
                                    psnr_rlt[1][folder][-1],
                                    ssim_rlt[0][folder][-1], ssim_rlt[1][folder][-1]]

            pd_log.to_csv(os.path.join('../test_results', folder_name, 'psnr_update.csv'))

            pbar.update('Test {} - {}: I: {:.3f}/{:.4f} \tF+: {:.3f}/{:.4f} \tTime: {:.3f}s'
                            .format(folder, idx_d,
                                    psnr_rlt[0][folder][-1], ssim_rlt[0][folder][-1],
                                    psnr_rlt[1][folder][-1], ssim_rlt[1][folder][-1],
                                    update_time
                                    ))
        else:
            pbar.update()

    if with_GT:
        psnr_rlt_avg = {}
        psnr_total_avg = 0.
        # Average the initial (bicubic) PSNR per folder, i.e. psnr_rlt[0]
        for k, v in psnr_rlt[0].items():
            psnr_rlt_avg[k] = sum(v) / len(v)
            psnr_total_avg += psnr_rlt_avg[k]
        psnr_total_avg /= len(psnr_rlt[0])
        log_s = '# Validation # Bic PSNR: {:.4e}:'.format(psnr_total_avg)
        for k, v in psnr_rlt_avg.items():
            log_s += ' {}: {:.4e}'.format(k, v)
        print(log_s)

        psnr_rlt_avg = {}
        psnr_total_avg = 0.
        # Average the adapted (ours) PSNR per folder, i.e. psnr_rlt[1]
        for k, v in psnr_rlt[1].items():
            psnr_rlt_avg[k] = sum(v) / len(v)
            psnr_total_avg += psnr_rlt_avg[k]
        psnr_total_avg /= len(psnr_rlt[1])
        log_s = '# Validation # PSNR: {:.4e}:'.format(psnr_total_avg)
        for k, v in psnr_rlt_avg.items():
            log_s += ' {}: {:.4e}'.format(k, v)
        print(log_s)

        ssim_rlt_avg = {}
        ssim_total_avg = 0.
        # Average the initial (bicubic) SSIM per folder, i.e. ssim_rlt[0]
        for k, v in ssim_rlt[0].items():
            ssim_rlt_avg[k] = sum(v) / len(v)
            ssim_total_avg += ssim_rlt_avg[k]
        ssim_total_avg /= len(ssim_rlt[0])
        log_s = '# Validation # Bicubic SSIM: {:.4e}:'.format(ssim_total_avg)
        for k, v in ssim_rlt_avg.items():
            log_s += ' {}: {:.4e}'.format(k, v)
        print(log_s)

        ssim_rlt_avg = {}
        ssim_total_avg = 0.
        # Average the adapted (ours) SSIM per folder, i.e. ssim_rlt[1]
        for k, v in ssim_rlt[1].items():
            ssim_rlt_avg[k] = sum(v) / len(v)
            ssim_total_avg += ssim_rlt_avg[k]
        ssim_total_avg /= len(ssim_rlt[1])
        log_s = '# Validation # SSIM: {:.4e}:'.format(ssim_total_avg)
        for k, v in ssim_rlt_avg.items():
            log_s += ' {}: {:.4e}'.format(k, v)
        print(log_s)

    print('End of evaluation.')
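
A toy sketch of the test-time inner-loop adaptation pattern in main() above: copy the pretrained network, take a few gradient steps on a loss computed from the test clip itself, then run the final inference with the adapted copy. The tiny convolution, learning rate, and random tensors are stand-ins, not the actual models or losses.

import copy

import torch
import torch.nn as nn
import torch.nn.functional as F

net = nn.Conv2d(3, 3, 3, padding=1)          # stand-in for model.netG
adapted = copy.deepcopy(net)                 # keep the original weights untouched
inner_optimizer = torch.optim.Adam(adapted.parameters(), lr=1e-4)

lq = torch.rand(1, 3, 16, 16)                # pretend low-resolution input
for _ in range(5):                           # plays the role of update_step
    inner_optimizer.zero_grad()
    loss = F.l1_loss(adapted(lq), lq)        # stand-in for the patch / SLR losses
    loss.backward()
    inner_optimizer.step()

with torch.no_grad():
    output = adapted(lq)                     # final test-time inference
print(output.shape)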
Esempio n. 16
0
def vimeo90k():
    '''create lmdb for the Vimeo90K dataset, each image with fixed size
    GT: [3, 256, 448]
        Only need the 4th frame currently, e.g., 00001_0001_4
    LR: [3, 64, 112]
        With 1st - 7th frames, e.g., 00001_0001_1, ..., 00001_0001_7
    key:
        Use the folder and subfolder names, w/o the frame index, e.g., 00001_0001
    '''
    #### configurations
    mode = 'GT'  # GT | LR
    if mode == 'GT':
        img_folder = '/home/xtwang/datasets/vimeo90k/vimeo_septuplet/sequences'
        lmdb_save_path = '/home/xtwang/datasets/vimeo90k/vimeo90k_train_GT.lmdb'
        txt_file = '/home/xtwang/datasets/vimeo90k/vimeo_septuplet/sep_trainlist.txt'
        H_dst, W_dst = 256, 448
    elif mode == 'LR':
        img_folder = '/home/xtwang/datasets/vimeo90k/vimeo_septuplet_matlabLRx4/sequences'
        lmdb_save_path = '/home/xtwang/datasets/vimeo90k/vimeo90k_train_LR7frames.lmdb'
        txt_file = '/home/xtwang/datasets/vimeo90k/vimeo_septuplet/sep_trainlist.txt'
        H_dst, W_dst = 64, 112
    n_thread = 40
    ########################################################
    if not lmdb_save_path.endswith('.lmdb'):
        raise ValueError("lmdb_save_path must end with \'lmdb\'.")
    #### whether the lmdb file exist
    if osp.exists(lmdb_save_path):
        print('Folder [{:s}] already exists. Exit...'.format(lmdb_save_path))
        sys.exit(1)

    #### read all the image paths to a list
    print('Reading image path list ...')
    with open(txt_file) as f:
        train_l = f.readlines()
        train_l = [v.strip() for v in train_l]
    all_img_list = []
    keys = []
    for line in train_l:
        folder = line.split('/')[0]
        sub_folder = line.split('/')[1]
        file_l = glob.glob(osp.join(img_folder, folder, sub_folder) + '/*')
        all_img_list.extend(file_l)
        for j in range(7):
            keys.append('{}_{}_{}'.format(folder, sub_folder, j + 1))
    all_img_list = sorted(all_img_list)
    keys = sorted(keys)
    if mode == 'GT':  # read the 4th frame only for GT mode
        print('Only keep the 4th frame.')
        all_img_list = [v for v in all_img_list if v.endswith('im4.png')]
        keys = [v for v in keys if v.endswith('_4')]

    #### read all images to memory (multiprocessing)
    dataset = {}  # store all image data; pool results arrive out of order, so key them by image name in a dict
    print('Read images with multiprocessing, #thread: {} ...'.format(n_thread))
    pbar = util.ProgressBar(len(all_img_list))

    def mycallback(arg):
        '''get the image data and update pbar'''
        key = arg[0]
        dataset[key] = arg[1]
        pbar.update('Reading {}'.format(key))

    pool = Pool(n_thread)
    for path, key in zip(all_img_list, keys):
        pool.apply_async(reading_image_worker,
                         args=(path, key),
                         callback=mycallback)
    pool.close()
    pool.join()
    print('Finish reading {} images.\nWrite lmdb...'.format(len(all_img_list)))

    #### create lmdb environment
    data_size_per_img = dataset['00001_0001_4'].nbytes
    print('data size per image is: ', data_size_per_img)
    data_size = data_size_per_img * len(all_img_list)
    env = lmdb.open(lmdb_save_path, map_size=data_size * 10)

    #### write data to lmdb
    pbar = util.ProgressBar(len(all_img_list))
    with env.begin(write=True) as txn:
        for key in keys:
            pbar.update('Write {}'.format(key))
            key_byte = key.encode('ascii')
            data = dataset[key]
            H, W, C = data.shape  # fixed shape
            assert H == H_dst and W == W_dst and C == 3, 'different shape.'
            txn.put(key_byte, data)
    print('Finish writing lmdb.')

    #### create meta information
    meta_info = {}
    if mode == 'GT':
        meta_info['name'] = 'Vimeo90K_train_GT'
    elif mode == 'LR':
        meta_info['name'] = 'Vimeo90K_train_LR'
    meta_info['resolution'] = '{}_{}_{}'.format(3, H_dst, W_dst)
    key_set = set()
    for key in keys:
        a, b, _ = key.split('_')
        key_set.add('{}_{}'.format(a, b))
    meta_info['keys'] = key_set
    pickle.dump(meta_info, open(osp.join(lmdb_save_path, 'meta_info.pkl'),
                                "wb"))
    print('Finish creating lmdb meta info.')
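A minimal read-back sketch for the lmdb written above, assuming the GT configuration (resolution string '3_256_448') and the key format 00001_0001_4; it relies only on lmdb, pickle and numpy:

import pickle
import lmdb
import numpy as np
import os.path as osp

lmdb_path = '/home/xtwang/datasets/vimeo90k/vimeo90k_train_GT.lmdb'  # path from the GT branch above
meta = pickle.load(open(osp.join(lmdb_path, 'meta_info.pkl'), 'rb'))
C, H, W = [int(x) for x in meta['resolution'].split('_')]

env = lmdb.open(lmdb_path, readonly=True, lock=False, readahead=False)
with env.begin(write=False) as txn:
    buf = txn.get('00001_0001_4'.encode('ascii'))
img = np.frombuffer(buf, dtype=np.uint8).reshape(H, W, C)  # BGR uint8, exactly as written by cv2.imread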
Esempio n. 17
0
def main():
    #### options
    parser = argparse.ArgumentParser()
    parser.add_argument('-opt', type=str, help='Path to option YAML file.')
    parser.add_argument('--launcher',
                        choices=['none', 'pytorch', 'slurm'],
                        default='none',
                        help='job launcher')
    parser.add_argument('--local_rank', type=int, default=0)
    args = parser.parse_args()
    opt = option.parse(args.opt, is_train=True)
    #pdb.set_trace()
    #### distributed training settings
    if args.launcher == 'none':  # disabled distributed training
        opt['dist'] = False
        rank = -1
        print('Disabled distributed training.')
    else:
        opt['dist'] = True
        init_dist(args.launcher)
        world_size = torch.distributed.get_world_size()
        rank = torch.distributed.get_rank()

    #### loading resume state if exists
    if opt['path'].get('resume_state', None):
        # distributed resuming: all load into default GPU
        device_id = torch.cuda.current_device()
        resume_state = torch.load(
            opt['path']['resume_state'],
            map_location=lambda storage, loc: storage.cuda(device_id))
        option.check_resume(opt, resume_state['iter'])  # check resume options
    else:
        resume_state = None

    #### mkdir and loggers
    if rank <= 0:  # normal training (rank -1) OR distributed training (rank 0)
        if resume_state is None:
            util.mkdir_and_rename(
                opt['path']
                ['experiments_root'])  # rename experiment folder if exists
            util.mkdirs(
                (path for key, path in opt['path'].items()
                 if not key == 'experiments_root'
                 and 'pretrain_model' not in key and 'resume' not in key))

        # configure loggers (logging will not work before this point)
        util.setup_logger('base',
                          opt['path']['log'],
                          'train_' + opt['name'],
                          level=logging.INFO,
                          screen=True,
                          tofile=True)
        logger = logging.getLogger('base')
        logger.info(option.dict2str(opt))
        # tensorboard logger
        if opt['use_tb_logger'] and 'debug' not in opt['name']:
            version = float(torch.__version__[0:3])
            if version >= 1.1:  # PyTorch 1.1
                from torch.utils.tensorboard import SummaryWriter
            else:
                logger.info(
                    'You are using PyTorch {}. Tensorboard will use [tensorboardX]'
                    .format(version))
                from tensorboardX import SummaryWriter
            tb_logger = SummaryWriter(log_dir='../tb_logger/' + opt['name'])
    else:
        util.setup_logger('base',
                          opt['path']['log'],
                          'train',
                          level=logging.INFO,
                          screen=True)
        logger = logging.getLogger('base')

    # convert to NoneDict, which returns None for missing keys
    opt = option.dict_to_nonedict(opt)

    #### random seed
    seed = opt['train']['manual_seed']
    if seed is None:
        seed = random.randint(1, 10000)
    if rank <= 0:
        logger.info('Random seed: {}'.format(seed))
    util.set_random_seed(seed)

    torch.backends.cudnn.benchmark = True
    # torch.backends.cudnn.deterministic = True

    #### create train and val dataloader
    #pdb.set_trace()
    dataset_ratio = 1000  # enlarge the size of each epoch
    for phase, dataset_opt in opt['datasets'].items():
        if phase == 'train':
            train_set = create_dataset(dataset_opt)
            train_size = int(
                math.ceil(len(train_set) / dataset_opt['batch_size']))
            total_iters = int(opt['train']['niter'])
            total_epochs = int(math.ceil(total_iters / train_size))
            if opt['dist']:
                train_sampler = DistIterSampler(train_set, world_size, rank,
                                                dataset_ratio)
                total_epochs = int(
                    math.ceil(total_iters / (train_size * dataset_ratio)))
            else:
                train_sampler = None
            train_loader = create_dataloader(train_set, dataset_opt, opt,
                                             train_sampler)
            if rank <= 0:
                logger.info(
                    'Number of train images: {:,d}, iters: {:,d}'.format(
                        len(train_set), train_size))
                logger.info('Total epochs needed: {:d} for iters {:,d}'.format(
                    total_epochs, total_iters))
        elif phase == 'val':
            val_set = create_dataset(dataset_opt)
            val_loader = create_dataloader(val_set, dataset_opt, opt, None)
            if rank <= 0:
                logger.info('Number of val images in [{:s}]: {:d}'.format(
                    dataset_opt['name'], len(val_set)))
        else:
            raise NotImplementedError(
                'Phase [{:s}] is not recognized.'.format(phase))
    assert train_loader is not None

    #### create model
    model = create_model(opt)

    #### resume training
    if resume_state:
        logger.info('Resuming training from epoch: {}, iter: {}.'.format(
            resume_state['epoch'], resume_state['iter']))

        start_epoch = resume_state['epoch']
        current_step = resume_state['iter']
        model.resume_training(resume_state)  # handle optimizers and schedulers
    else:
        current_step = 0
        start_epoch = 0

    #### training
    logger.info('Start training from epoch: {:d}, iter: {:d}'.format(
        start_epoch, current_step))
    for epoch in range(start_epoch, total_epochs + 1):
        if opt['dist']:
            train_sampler.set_epoch(epoch)
        for _, train_data in enumerate(train_loader):
            current_step += 1
            if current_step > total_iters:
                break
            #### update learning rate
            model.update_learning_rate(current_step,
                                       warmup_iter=opt['train']['warmup_iter'])

            #### training
            model.feed_data(train_data)
            model.optimize_parameters(current_step)

            #### log
            if current_step % opt['logger']['print_freq'] == 0:
                logs = model.get_current_log()
                message = '[epoch:{:3d}, iter:{:8,d}, lr:('.format(
                    epoch, current_step)
                for v in model.get_current_learning_rate():
                    message += '{:.3e},'.format(v)
                message += ')] '
                for k, v in logs.items():
                    message += '{:s}: {:.4e} '.format(k, v)
                    # tensorboard logger
                    if opt['use_tb_logger'] and 'debug' not in opt['name']:
                        if rank <= 0:
                            tb_logger.add_scalar(k, v, current_step)
                if rank <= 0:
                    logger.info(message)
            #### validation
            if opt['datasets'].get(
                    'val',
                    None) and current_step % opt['train']['val_freq'] == 0:
                if rank <= 0:
                    # does not support multi-GPU validation
                    pbar = util.ProgressBar(len(val_loader))
                    avg_psnr = 0.
                    idx = 0
                    for val_data in val_loader:
                        idx += 1
                        img_name = os.path.splitext(
                            os.path.basename(val_data['LQ_path'][0]))[0]
                        img_dir = os.path.join(opt['path']['val_images'],
                                               img_name)
                        util.mkdir(img_dir)

                        model.feed_data(val_data)
                        model.test()

                        visuals = model.get_current_visuals()
                        sr_img = util.tensor2img(visuals['SR'])  # uint8
                        gt_img = util.tensor2img(visuals['GT'])  # uint8

                        # Save SR images for reference
                        save_img_path = os.path.join(
                            img_dir,
                            '{:s}_{:d}.png'.format(img_name, current_step))
                        util.save_img(sr_img, save_img_path)

                        # calculate PSNR
                        sr_img, gt_img = util.crop_border([sr_img, gt_img],
                                                          opt['scale'])
                        avg_psnr += util.calculate_psnr(sr_img, gt_img)
                        pbar.update('Test {}'.format(img_name))

                    avg_psnr = avg_psnr / idx

                    # log
                    logger.info('# Validation # PSNR: {:.4e}'.format(avg_psnr))
                    # tensorboard logger
                    if opt['use_tb_logger'] and 'debug' not in opt['name']:
                        tb_logger.add_scalar('psnr', avg_psnr, current_step)

            #### save models and training states
            if current_step % opt['logger']['save_checkpoint_freq'] == 0:
                if rank <= 0:
                    logger.info('Saving models and training states.')
                    model.save(current_step)
                    model.save_training_state(epoch, current_step)
            if current_step % 50000 == 0:
                torch.cuda.empty_cache()

        torch.cuda.empty_cache()

    if rank <= 0:
        logger.info('Saving the final model.')
        model.save('latest')
        logger.info('End of training.')
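util.calculate_psnr and util.crop_border are not shown in this listing; a standard PSNR sketch over uint8 images, with the border crop folded in, might look like the following (an assumption about their behaviour, not the project's own code):

import numpy as np

def psnr_cropped(img1, img2, crop_border=0):
    """PSNR between two uint8 images of the same shape, ignoring a border of crop_border pixels."""
    if crop_border > 0:
        img1 = img1[crop_border:-crop_border, crop_border:-crop_border, ...]
        img2 = img2[crop_border:-crop_border, crop_border:-crop_border, ...]
    mse = np.mean((img1.astype(np.float64) - img2.astype(np.float64)) ** 2)
    if mse == 0:
        return float('inf')
    return 20 * np.log10(255.0 / np.sqrt(mse))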
Esempio n. 18
0
def AI4K(model='gt'):
    # 'model' selects which lmdb to build: 'gt' (4K ground truth) or 'X4' (540p inputs)
    read_all_imgs = False
    BATCH = 700  # commit the lmdb transaction every BATCH images
    if model == 'gt':
        img_folder = 'dataset/gt'
        lmdb_save_path = 'dataset/train_gt_wval.lmdb'
        H_dst, W_dst = 2160, 3840
    if model == 'X4':
        img_folder = 'dataset/X4'
        lmdb_save_path = 'dataset/train_x4_wval.lmdb'
        H_dst, W_dst = 540, 960
    n_thread = 40
    ########################################################
    if not lmdb_save_path.endswith('.lmdb'):
        raise ValueError("lmdb_save_path must end with '.lmdb'.")
    #### whether the lmdb file exist
    # if osp.exists(lmdb_save_path):
    #     print('Folder [{:s}] already exists. Exit...'.format(lmdb_save_path))
    #     sys.exit(1)
    print('Reading image path list ...')
    all_clips_list = sorted(os.listdir(img_folder))
    all_clips_list_path = []
    for x in all_clips_list:
        all_clips_list_path.append(os.path.join(img_folder, x))

    keys = []
    all_imgs_path = []
    index_clip = 0
    for clips_path in all_clips_list_path:
        index_clip += 1
        if model == 'X4':
            for imgs_x4_path in data_util._get_paths_from_images(clips_path):
                all_imgs_path.append(imgs_x4_path)
            # build Vimeo-style keys 'clip_septuplet_frame' (e.g. 00001_0001_4); assumes 100 frames per clip
            for index_imgs_x4 in range(100):
                a = (index_imgs_x4 + 1) // 7 + 1
                b = (index_imgs_x4 + 1) % 7
                if b == 0:
                    b = 7
                c = '%.5d' % (index_clip) + '_' + '%.4d' % (a) + '_' + '%d' % (
                    b)
                keys.append(c)
        else:
            for index, imgs_path in enumerate(
                    data_util._get_paths_from_images(clips_path)):
                if index % 7 == 3:
                    all_imgs_path.append(imgs_path)
            # GT: keep only the centre (4th) frame of each septuplet; assumes 100 frames per clip
            for index_imgs_gt in range(100):
                if index_imgs_gt % 7 == 3:
                    a = (index_imgs_gt + 1) // 7 + 1
                    c = '%.5d' % (index_clip) + '_' + '%.4d' % (a) + '_4'
                    keys.append(c)

    data_size_per_img = cv2.imread(all_imgs_path[0],
                                   cv2.IMREAD_UNCHANGED).nbytes
    print('data size per image is: ', data_size_per_img)
    data_size = data_size_per_img * len(all_imgs_path)

    env = lmdb.open(lmdb_save_path, map_size=data_size * 10)
    pbar = util.ProgressBar(len(all_imgs_path))
    txn = env.begin(write=True)
    idx = 1
    for path, key in zip(all_imgs_path, keys):
        idx = idx + 1
        pbar.update('Write {}'.format(key))
        key_byte = key.encode('ascii')
        data = cv2.imread(path, cv2.IMREAD_UNCHANGED)
        H, W, C = data.shape  # fixed shape
        assert H == H_dst and W == W_dst and C == 3, 'different shape.'
        txn.put(key_byte, data)
        if not read_all_imgs and idx % BATCH == 1:
            txn.commit()
            txn = env.begin(write=True)
    txn.commit()
    env.close()
    print('Finish writing lmdb.')
    #### create meta information
    meta_info = {}
    if model == 'gt':
        meta_info['name'] = 'AI4K_train_GT'
    elif model == 'X4':
        meta_info['name'] = 'AI4K_train_X4'
    meta_info['resolution'] = '{}_{}_{}'.format(3, H_dst, W_dst)
    key_set = set()
    for key in keys:
        a, b, _ = key.split('_')
        key_set.add('{}_{}'.format(a, b))
    meta_info['keys'] = key_set
    pickle.dump(meta_info, open(osp.join(lmdb_save_path, 'meta_info.pkl'),
                                "wb"))
    print('Finish creating lmdb meta info.')
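A small inspection sketch for the meta_info.pkl files written above (paths taken from the configuration block at the top of the function); it only reports what was stored and makes no assumption about the two key sets matching:

import pickle
import os.path as osp

for lmdb_path in ['dataset/train_gt_wval.lmdb', 'dataset/train_x4_wval.lmdb']:
    meta = pickle.load(open(osp.join(lmdb_path, 'meta_info.pkl'), 'rb'))
    print(meta['name'], meta['resolution'], 'clip-level keys:', len(meta['keys']))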
Esempio n. 19
0
def main():
    #### options
    parser = argparse.ArgumentParser()
    parser.add_argument('-opt', type=str, help='Path to option YAML file.')
    parser.add_argument('--launcher',
                        choices=['none', 'pytorch'],
                        default='none',
                        help='job launcher')
    parser.add_argument('--local_rank', type=int, default=0)
    args = parser.parse_args()
    opt = option.parse(args.opt, is_train=True)

    #### distributed training settings
    if args.launcher == 'none':  # disabled distributed training
        opt['dist'] = False
        rank = -1
        print('Disabled distributed training.')
    else:
        opt['dist'] = True
        init_dist()
        world_size = torch.distributed.get_world_size()
        rank = torch.distributed.get_rank()

    #### loading resume state if exists
    if opt['path'].get('resume_state', None):
        # distributed resuming: all load into default GPU
        device_id = torch.cuda.current_device()
        resume_state = torch.load(
            opt['path']['resume_state'],
            map_location=lambda storage, loc: storage.cuda(device_id))
        option.check_resume(opt, resume_state['iter'])  # check resume options
    else:
        resume_state = None

    #### mkdir and loggers
    if rank <= 0:  # normal training (rank -1) OR distributed training (rank 0)
        if resume_state is None:
            util.mkdir_and_rename(
                opt['path']
                ['experiments_root'])  # rename experiment folder if exists
            util.mkdirs(
                (path for key, path in opt['path'].items() if
                 not key == 'experiments_root' and 'pretrain_model' not in key
                 and 'resume' not in key and 'wandb_load_run_path' not in key))

        # configure loggers (logging will not work before this point)
        util.setup_logger('base',
                          opt['path']['log'],
                          'train_' + opt['name'],
                          level=logging.INFO,
                          screen=True,
                          tofile=True)
        logger = logging.getLogger('base')
        logger.info(option.dict2str(opt))
        # tensorboard logger
        if opt['use_tb_logger'] and 'debug' not in opt['name']:
            version = float(torch.__version__[0:3])
            if version >= 1.1:  # PyTorch 1.1
                from torch.utils.tensorboard import SummaryWriter
            else:
                logger.info(
                    'You are using PyTorch {}. Tensorboard will use [tensorboardX]'
                    .format(version))
                from tensorboardX import SummaryWriter
            tb_logger = SummaryWriter(log_dir='../tb_logger/' + opt['name'])
        if opt['use_wandb_logger'] and 'debug' not in opt['name']:
            json_path = os.path.join(os.path.expanduser('~'),
                                     '.wandb_api_keys.json')
            if os.path.exists(json_path):
                with open(json_path, 'r') as j:
                    json_file = json.loads(j.read())
                    os.environ['WANDB_API_KEY'] = json_file['ryul99']
            wandb.init(project="mmsr", config=opt, sync_tensorboard=True)
    else:
        util.setup_logger('base',
                          opt['path']['log'],
                          'train',
                          level=logging.INFO,
                          screen=True)
        logger = logging.getLogger('base')

    # convert to NoneDict, which returns None for missing keys
    opt = option.dict_to_nonedict(opt)

    #### random seed
    seed = opt['train']['manual_seed']
    if seed is None:
        seed = random.randint(1, 10000)
    if rank <= 0:
        logger.info('Random seed: {}'.format(seed))
        if opt['use_wandb_logger'] and 'debug' not in opt['name']:
            wandb.config.update({'random_seed': seed})
    util.set_random_seed(seed)

    torch.backends.cudnn.benchmark = True
    # torch.backends.cudnn.deterministic = True

    #### create train and val dataloader
    dataset_ratio = 200  # enlarge the size of each epoch
    for phase, dataset_opt in opt['datasets'].items():
        if phase == 'train':
            train_set = create_dataset(dataset_opt)
            train_size = int(
                math.ceil(len(train_set) / dataset_opt['batch_size']))
            total_iters = int(opt['train']['niter'])
            total_epochs = int(math.ceil(total_iters / train_size))
            if opt['dist']:
                train_sampler = DistIterSampler(train_set, world_size, rank,
                                                dataset_ratio)
                total_epochs = int(
                    math.ceil(total_iters / (train_size * dataset_ratio)))
            else:
                train_sampler = None
            train_loader = create_dataloader(train_set, dataset_opt, opt,
                                             train_sampler)
            if rank <= 0:
                logger.info(
                    'Number of train images: {:,d}, iters: {:,d}'.format(
                        len(train_set), train_size))
                logger.info('Total epochs needed: {:d} for iters {:,d}'.format(
                    total_epochs, total_iters))
        elif phase == 'val':
            val_set = create_dataset(dataset_opt)
            val_loader = create_dataloader(val_set, dataset_opt, opt, None)
            if rank <= 0:
                logger.info('Number of val images in [{:s}]: {:d}'.format(
                    dataset_opt['name'], len(val_set)))
        else:
            raise NotImplementedError(
                'Phase [{:s}] is not recognized.'.format(phase))
    assert train_loader is not None

    #### create model
    model = create_model(opt)

    #### resume training
    if resume_state:
        logger.info('Resuming training from epoch: {}, iter: {}.'.format(
            resume_state['epoch'], resume_state['iter']))

        start_epoch = resume_state['epoch']
        current_step = resume_state['iter']
        model.resume_training(resume_state)  # handle optimizers and schedulers
    else:
        current_step = 0
        start_epoch = 0

    #### training
    logger.info('Start training from epoch: {:d}, iter: {:d}'.format(
        start_epoch, current_step))
    for epoch in range(start_epoch, total_epochs + 1):
        if opt['dist']:
            train_sampler.set_epoch(epoch)
        for _, train_data in enumerate(train_loader):
            current_step += 1
            if current_step > total_iters:
                break
            #### update learning rate
            model.update_learning_rate(current_step,
                                       warmup_iter=opt['train']['warmup_iter'])

            #### training
            model.feed_data(train_data,
                            noise_mode=opt['datasets']['train']['noise_mode'],
                            noise_rate=opt['datasets']['train']['noise_rate'])
            model.optimize_parameters(current_step)

            #### log
            if current_step % opt['logger']['print_freq'] == 0:
                logs = model.get_current_log()
                message = '[epoch:{:3d}, iter:{:8,d}, lr:('.format(
                    epoch, current_step)
                for v in model.get_current_learning_rate():
                    message += '{:.3e},'.format(v)
                message += ')] '
                for k, v in logs.items():
                    message += '{:s}: {:.4e} '.format(k, v)
                    # tensorboard logger
                    if opt['use_tb_logger'] and 'debug' not in opt['name']:
                        if rank <= 0:
                            tb_logger.add_scalar(k, v, current_step)
                    if opt['use_wandb_logger'] and 'debug' not in opt['name']:
                        if rank <= 0:
                            wandb.log({k: v}, step=current_step)
                if rank <= 0:
                    logger.info(message)
            #### validation
            if opt['datasets'].get(
                    'val',
                    None) and current_step % opt['train']['val_freq'] == 0:
                if opt['model'] in [
                        'sr', 'srgan'
                ] and rank <= 0:  # image restoration validation
                    # does not support multi-GPU validation
                    pbar = util.ProgressBar(len(val_loader))
                    avg_psnr = 0.
                    idx = 0
                    for val_data in val_loader:
                        idx += 1
                        img_name = os.path.splitext(
                            os.path.basename(val_data['LQ_path'][0]))[0]
                        img_dir = os.path.join(opt['path']['val_images'],
                                               img_name)
                        util.mkdir(img_dir)

                        model.feed_data(
                            val_data,
                            noise_mode=opt['datasets']['val']['noise_mode'],
                            noise_rate=opt['datasets']['val']['noise_rate'])
                        model.test()

                        visuals = model.get_current_visuals()
                        sr_img = util.tensor2img(visuals['rlt'])  # uint8
                        gt_img = util.tensor2img(visuals['GT'])  # uint8

                        # Save SR images for reference
                        save_img_path = os.path.join(
                            img_dir,
                            '{:s}_{:d}.png'.format(img_name, current_step))
                        util.save_img(sr_img, save_img_path)

                        # calculate PSNR
                        sr_img, gt_img = util.crop_border([sr_img, gt_img],
                                                          opt['scale'])
                        avg_psnr += util.calculate_psnr(sr_img, gt_img)
                        pbar.update('Test {}'.format(img_name))

                    avg_psnr = avg_psnr / idx

                    # log
                    logger.info('# Validation # PSNR: {:.4e}'.format(avg_psnr))
                    # tensorboard logger
                    if opt['use_tb_logger'] and 'debug' not in opt['name']:
                        tb_logger.add_scalar('psnr', avg_psnr, current_step)
                    if opt['use_wandb_logger'] and 'debug' not in opt['name']:
                        wandb.log({'psnr': avg_psnr}, step=current_step)
                else:  # video restoration validation
                    if opt['dist']:
                        # multi-GPU testing
                        psnr_rlt = {}  # with border and center frames
                        if rank == 0:
                            pbar = util.ProgressBar(len(val_set))
                        for idx in range(rank, len(val_set), world_size):
                            val_data = val_set[idx]
                            val_data['LQs'].unsqueeze_(0)
                            val_data['GT'].unsqueeze_(0)
                            folder = val_data['folder']
                            idx_d, max_idx = val_data['idx'].split('/')
                            idx_d, max_idx = int(idx_d), int(max_idx)
                            if psnr_rlt.get(folder, None) is None:
                                psnr_rlt[folder] = torch.zeros(
                                    max_idx,
                                    dtype=torch.float32,
                                    device='cuda')
                            # tmp = torch.zeros(max_idx, dtype=torch.float32, device='cuda')
                            model.feed_data(val_data,
                                            noise_mode=opt['datasets']['val']
                                            ['noise_mode'],
                                            noise_rate=opt['datasets']['val']
                                            ['noise_rate'])
                            model.test()
                            visuals = model.get_current_visuals()
                            rlt_img = util.tensor2img(visuals['rlt'])  # uint8
                            gt_img = util.tensor2img(visuals['GT'])  # uint8
                            # calculate PSNR
                            psnr_rlt[folder][idx_d] = util.calculate_psnr(
                                rlt_img, gt_img)

                            if rank == 0:
                                for _ in range(world_size):
                                    pbar.update('Test {} - {}/{}'.format(
                                        folder, idx_d, max_idx))
                        # collect results: each rank filled disjoint indices, so sum-reducing to rank 0 reconstructs every per-folder PSNR vector
                        for _, v in psnr_rlt.items():
                            dist.reduce(v, 0)
                        dist.barrier()

                        if rank == 0:
                            psnr_rlt_avg = {}
                            psnr_total_avg = 0.
                            for k, v in psnr_rlt.items():
                                psnr_rlt_avg[k] = torch.mean(v).cpu().item()
                                psnr_total_avg += psnr_rlt_avg[k]
                            psnr_total_avg /= len(psnr_rlt)
                            log_s = '# Validation # PSNR: {:.4e}:'.format(
                                psnr_total_avg)
                            for k, v in psnr_rlt_avg.items():
                                log_s += ' {}: {:.4e}'.format(k, v)
                            logger.info(log_s)
                            if opt['use_tb_logger'] and 'debug' not in opt[
                                    'name']:
                                tb_logger.add_scalar('psnr_avg',
                                                     psnr_total_avg,
                                                     current_step)
                                for k, v in psnr_rlt_avg.items():
                                    tb_logger.add_scalar(k, v, current_step)
                            if opt['use_wandb_logger'] and 'debug' not in opt[
                                    'name']:
                                lq_img, rlt_img, gt_img = map(
                                    util.tensor2img, [
                                        visuals['LQ'], visuals['rlt'],
                                        visuals['GT']
                                    ])
                                wandb.log({'psnr_avg': psnr_total_avg},
                                          step=current_step)
                                wandb.log(psnr_rlt_avg, step=current_step)
                                wandb.log(
                                    {
                                        'Validation Image': [
                                            wandb.Image(lq_img[:, :,
                                                               [2, 1, 0]],
                                                        caption='LQ'),
                                            wandb.Image(rlt_img[:, :,
                                                                [2, 1, 0]],
                                                        caption='output'),
                                            wandb.Image(gt_img[:, :,
                                                               [2, 1, 0]],
                                                        caption='GT'),
                                        ]
                                    },
                                    step=current_step)
                    else:
                        pbar = util.ProgressBar(len(val_loader))
                        psnr_rlt = {}  # with border and center frames
                        psnr_rlt_avg = {}
                        psnr_total_avg = 0.
                        for val_data in val_loader:
                            folder = val_data['folder'][0]
                            idx_d = val_data['idx'].item()
                            # border = val_data['border'].item()
                            if psnr_rlt.get(folder, None) is None:
                                psnr_rlt[folder] = []

                            model.feed_data(val_data,
                                            noise_mode=opt['datasets']['val']
                                            ['noise_mode'],
                                            noise_rate=opt['datasets']['val']
                                            ['noise_rate'])
                            model.test()
                            visuals = model.get_current_visuals()
                            rlt_img = util.tensor2img(visuals['rlt'])  # uint8
                            gt_img = util.tensor2img(visuals['GT'])  # uint8

                            # calculate PSNR
                            psnr = util.calculate_psnr(rlt_img, gt_img)
                            psnr_rlt[folder].append(psnr)
                            pbar.update('Test {} - {}'.format(folder, idx_d))
                        for k, v in psnr_rlt.items():
                            psnr_rlt_avg[k] = sum(v) / len(v)
                            psnr_total_avg += psnr_rlt_avg[k]
                        psnr_total_avg /= len(psnr_rlt)
                        log_s = '# Validation # PSNR: {:.4e}:'.format(
                            psnr_total_avg)
                        for k, v in psnr_rlt_avg.items():
                            log_s += ' {}: {:.4e}'.format(k, v)
                        logger.info(log_s)
                        if opt['use_tb_logger'] and 'debug' not in opt['name']:
                            tb_logger.add_scalar('psnr_avg', psnr_total_avg,
                                                 current_step)
                            for k, v in psnr_rlt_avg.items():
                                tb_logger.add_scalar(k, v, current_step)
                        if opt['use_wandb_logger'] and 'debug' not in opt[
                                'name']:
                            lq_img, rlt_img, gt_img = map(
                                util.tensor2img,
                                [visuals['LQ'], visuals['rlt'], visuals['GT']])
                            wandb.log({'psnr_avg': psnr_total_avg},
                                      step=current_step)
                            wandb.log(psnr_rlt_avg, step=current_step)
                            wandb.log(
                                {
                                    'Validation Image': [
                                        wandb.Image(lq_img[:, :, [2, 1, 0]],
                                                    caption='LQ'),
                                        wandb.Image(rlt_img[:, :, [2, 1, 0]],
                                                    caption='output'),
                                        wandb.Image(gt_img[:, :, [2, 1, 0]],
                                                    caption='GT'),
                                    ]
                                },
                                step=current_step)

            #### save models and training states
            if current_step % opt['logger']['save_checkpoint_freq'] == 0:
                if rank <= 0:
                    logger.info('Saving models and training states.')
                    model.save(current_step)
                    model.save_training_state(epoch, current_step)

    if rank <= 0:
        logger.info('Saving the final model.')
        model.save('latest')
        logger.info('End of training.')
        if opt['use_tb_logger'] and 'debug' not in opt['name']:
            tb_logger.close()
Esempio n. 20
0
def REDS():
    """create lmdb for the REDS dataset, each image with fixed size
    GT: [3, 720, 1280], key: 000_00000000
    LR: [3, 180, 320], key: 000_00000000
    key: 000_00000000
    """
    #### configurations
    mode = "train_sharp"
    read_all_imgs = False  # whether to read all images into memory; set False with limited memory
    BATCH = 5000  # commit the lmdb transaction every BATCH images when read_all_imgs is False
    # train_sharp | train_sharp_bicubic | train_blur_bicubic| train_blur | train_blur_comp
    if mode == "train_sharp":
        img_folder = "/home/xtwang/datasets/REDS/train_sharp"
        lmdb_save_path = "/home/xtwang/datasets/REDS/train_sharp_wval.lmdb"
        H_dst, W_dst = 720, 1280
    elif mode == "train_sharp_bicubic":
        img_folder = "/home/xtwang/datasets/REDS/train_sharp_bicubic"
        lmdb_save_path = "/home/xtwang/datasets/REDS/train_sharp_bicubic_wval.lmdb"
        H_dst, W_dst = 180, 320
    elif mode == "train_blur_bicubic":
        img_folder = "/home/xtwang/datasets/REDS/train_blur_bicubic"
        lmdb_save_path = "/home/xtwang/datasets/REDS/train_blur_bicubic_wval.lmdb"
        H_dst, W_dst = 180, 320
    elif mode == "train_blur":
        img_folder = "/home/xtwang/datasets/REDS/train_blur"
        lmdb_save_path = "/home/xtwang/datasets/REDS/train_blur_wval.lmdb"
        H_dst, W_dst = 720, 1280
    elif mode == "train_blur_comp":
        img_folder = "/home/xtwang/datasets/REDS/train_blur_comp"
        lmdb_save_path = "/home/xtwang/datasets/REDS/train_blur_comp_wval.lmdb"
        H_dst, W_dst = 720, 1280
    n_thread = 40
    ########################################################
    if not lmdb_save_path.endswith(".lmdb"):
        raise ValueError("lmdb_save_path must end with '.lmdb'.")
    #### whether the lmdb file exist
    if osp.exists(lmdb_save_path):
        print("Folder [{:s}] already exists. Exit...".format(lmdb_save_path))
        sys.exit(1)

    #### read all the image paths to a list
    print("Reading image path list ...")
    all_img_list = data_util._get_paths_from_images(img_folder)
    keys = []
    for img_path in all_img_list:
        split_rlt = img_path.split("/")
        a = split_rlt[-2]
        b = split_rlt[-1].split(".png")[0]
        keys.append(a + "_" + b)

    if read_all_imgs:
        #### read all images to memory (multiprocessing)
        dataset = {}  # store all image data; pool results arrive out of order, so key them by image name in a dict
        print("Read images with multiprocessing, #thread: {} ...".format(
            n_thread))
        pbar = util.ProgressBar(len(all_img_list))

        def mycallback(arg):
            """get the image data and update pbar"""
            key = arg[0]
            dataset[key] = arg[1]
            pbar.update("Reading {}".format(key))

        pool = Pool(n_thread)
        for path, key in zip(all_img_list, keys):
            pool.apply_async(reading_image_worker,
                             args=(path, key),
                             callback=mycallback)
        pool.close()
        pool.join()
        print("Finish reading {} images.\nWrite lmdb...".format(
            len(all_img_list)))

    #### create lmdb environment
    data_size_per_img = cv2.imread(all_img_list[0],
                                   cv2.IMREAD_UNCHANGED).nbytes
    if "flow" in mode:
        data_size_per_img = dataset["000_00000002_n1"].nbytes
    print("data size per image is: ", data_size_per_img)
    data_size = data_size_per_img * len(all_img_list)
    env = lmdb.open(lmdb_save_path, map_size=data_size * 10)

    #### write data to lmdb
    pbar = util.ProgressBar(len(all_img_list))
    txn = env.begin(write=True)
    idx = 1
    for path, key in zip(all_img_list, keys):
        idx = idx + 1
        pbar.update("Write {}".format(key))
        key_byte = key.encode("ascii")
        data = dataset[key] if read_all_imgs else cv2.imread(
            path, cv2.IMREAD_UNCHANGED)
        if "flow" in mode:
            H, W = data.shape
            assert H == H_dst and W == W_dst, "different shape."
        else:
            H, W, C = data.shape  # fixed shape
            assert H == H_dst and W == W_dst and C == 3, "different shape."
        txn.put(key_byte, data)
        if not read_all_imgs and idx % BATCH == 1:
            txn.commit()
            txn = env.begin(write=True)
    txn.commit()
    env.close()
    print("Finish writing lmdb.")

    #### create meta information
    meta_info = {}
    meta_info["name"] = "REDS_{}_wval".format(mode)
    if "flow" in mode:
        meta_info["resolution"] = "{}_{}_{}".format(1, H_dst, W_dst)
    else:
        meta_info["resolution"] = "{}_{}_{}".format(3, H_dst, W_dst)
    meta_info["keys"] = keys
    pickle.dump(meta_info, open(osp.join(lmdb_save_path, "meta_info.pkl"),
                                "wb"))
    print("Finish creating lmdb meta info.")
Esempio n. 21
0
def HDR(mode):
    """Create lmdb for the REDS dataset, each image with a fixed size
    GT: [3, 720, 1280], key: 000_00000000
    LR: [3, 180, 320], key: 000_00000000
    key: 000_00000000

    flow: downsampled flow: [3, 360, 320], keys: 000_00000005_[p2, p1, n1, n2]
        Each flow is calculated with the GT images by PWCNet and then downsampled by 1/4
        Flow map is quantized by mmcv and saved in png format
    """
    #### configurations
    read_all_imgs = False  # whether to read all images into memory with multiprocessing
    # Set False when memory is limited
    BATCH = args.batch  # commit the lmdb transaction every BATCH images when read_all_imgs is False
    if mode == 'train_sharp':
        # img_folder = '../../datasets/REDS/train_sharp'
        # lmdb_save_path = '../../datasets/REDS/train_sharp_wval.lmdb'
        img_folder = '/DATA/wangshen_data/REDS/train_sharp'
        lmdb_save_path = '/DATA/wangshen_data/REDS/train_sharp_wval.lmdb'
        H_dst, W_dst = 720, 1280
    elif mode == 'train_sharp_bicubic':
        # img_folder = '../../datasets/REDS/train_sharp_bicubic'
        # lmdb_save_path = '../../datasets/REDS/train_sharp_bicubic_wval.lmdb'
        img_folder = '/DATA/wangshen_data/REDS/train_sharp_bicubic'
        lmdb_save_path = '/DATA/wangshen_data/REDS/train_sharp_bicubic_wval.lmdb'
        H_dst, W_dst = 180, 320
    elif mode == 'train_blur_bicubic':
        img_folder = '../../datasets/REDS/train_blur_bicubic'
        lmdb_save_path = '../../datasets/REDS/train_blur_bicubic_wval.lmdb'
        H_dst, W_dst = 180, 320
    elif mode == 'train_blur':
        img_folder = '../../datasets/REDS/train_blur'
        lmdb_save_path = '../../datasets/REDS/train_blur_wval.lmdb'
        H_dst, W_dst = 720, 1280
    elif mode == 'train_blur_comp':
        img_folder = '../../datasets/REDS/train_blur_comp'
        lmdb_save_path = '../../datasets/REDS/train_blur_comp_wval.lmdb'
        H_dst, W_dst = 720, 1280
    elif mode == 'train_sharp_flowx4':
        img_folder = '../../datasets/REDS/train_sharp_flowx4'
        lmdb_save_path = '../../datasets/REDS/train_sharp_flowx4.lmdb'
        H_dst, W_dst = 360, 320
    elif mode == 'train_540p':
        img_folder = "/mnt/lustre/shanghai/cmic/home/xyz18/Dataset/train_540p"
        lmdb_save_path = '/mnt/lustre/shanghai/cmic/home/xyz18/Dataset/{}.lmdb'.format(
            args.name)
        H_dst, W_dst = 540, 960
    elif mode == 'train_4k':
        img_folder = "/mnt/lustre/shanghai/cmic/home/xyz18/Dataset/train_4k"
        lmdb_save_path = '/mnt/lustre/shanghai/cmic/home/xyz18/Dataset/{}.lmdb'.format(
            args.name)
        H_dst, W_dst = 2160, 3840
    elif mode == 'both':
        img_folder_S = "/mnt/lustre/shanghai/cmic/home/xyz18/Dataset/train_540p"
        lmdb_save_path_S = '/mnt/lustre/shanghai/cmic/home/xyz18/Dataset/{}_540p.lmdb'.format(
            args.name)
        H_dst_S, W_dst_S = 256, 256
        img_folder_L = "/mnt/lustre/shanghai/cmic/home/xyz18/Dataset/train_4k"
        lmdb_save_path_L = '/mnt/lustre/shanghai/cmic/home/xyz18/Dataset/{}_4k.lmdb'.format(
            args.name)
        H_dst_L, W_dst_L = 1024, 1024

    assert mode == 'both'  # only the paired 540p/4K path below is implemented

    N = 8  # divide one 4k into 8 parts

    n_thread = 40

    ########################################################

    import os
    import shutil
    if os.path.exists(lmdb_save_path_S):
        shutil.rmtree(lmdb_save_path_S)
    if os.path.exists(lmdb_save_path_L):
        shutil.rmtree(lmdb_save_path_L)

    if not lmdb_save_path_S.endswith('.lmdb'):
        raise ValueError("lmdb_save_path must end with '.lmdb'.")
    if osp.exists(lmdb_save_path_S):
        print('Folder [{:s}] already exists. Exit...'.format(lmdb_save_path_S))
        sys.exit(1)

    if not lmdb_save_path_L.endswith('.lmdb'):
        raise ValueError("lmdb_save_path must end with '.lmdb'.")
    if osp.exists(lmdb_save_path_L):
        print('Folder [{:s}] already exists. Exit...'.format(lmdb_save_path_L))
        sys.exit(1)

    #### read all the image paths to a list
    print('Reading image path list ...')
    all_img_list_S = data_util._get_paths_from_images_suzhou(
        img_folder_S, args.small)
    all_img_list_L = data_util._get_paths_from_images_suzhou(
        img_folder_L, args.small)

    keys_S = []
    keys_L = []
    for img_path in all_img_list_S:
        split_rlt = img_path.split('/')
        folder = split_rlt[-2]
        img_name = split_rlt[-1].split('.png')[0]
        keys_S.append(folder + '_' + img_name)

    for img_path in all_img_list_L:
        split_rlt = img_path.split('/')
        folder = split_rlt[-2]
        img_name = split_rlt[-1].split('.png')[0]
        keys_L.append(folder + '_' + img_name)

    assert keys_S == keys_L

    keys_patch = []
    for img_path in all_img_list_L:
        split_rlt = img_path.split('/')
        folder = split_rlt[-2]
        img_name = split_rlt[-1].split('.png')[0]
        for i in range(N):
            keys_patch.append(folder + '_' + img_name + '_' + str(i))

    keys = keys_S

    if read_all_imgs:  # never taken: read_all_imgs is hard-coded to False above
        # read all images to memory (multiprocessing)
        dataset = {}  # store all image data; pool results arrive out of order, so key them by image name in a dict
        print('Read images with multiprocessing, #thread: {} ...'.format(
            n_thread))
        pbar = util.ProgressBar(len(all_img_list_S))

        def mycallback(arg):
            '''get the image data and update pbar'''
            key = arg[0]
            dataset[key] = arg[1]
            pbar.update('Reading {}'.format(key))

        pool = Pool(n_thread)
        for path, key in zip(all_img_list_S, keys_S):
            pool.apply_async(read_image_worker,
                             args=(path, key),
                             callback=mycallback)
        pool.close()
        pool.join()
        print('Finish reading {} images.\nWrite lmdb...'.format(
            len(all_img_list_S)))

    #### create lmdb environment

    # for small pic
    data_size_per_img = cv2.imread(all_img_list_S[0],
                                   cv2.IMREAD_UNCHANGED).nbytes
    print('data size per small image is: ', data_size_per_img)
    data_size = data_size_per_img * len(all_img_list_S)
    # env_S = lmdb.open(lmdb_save_path_S, map_size=data_size * 10)

    # for large pic
    data_size_per_img = cv2.imread(all_img_list_L[0],
                                   cv2.IMREAD_UNCHANGED).nbytes
    print('data size per large image is: ', data_size_per_img)
    data_size = data_size_per_img * len(all_img_list_S)
    # env_L = lmdb.open(lmdb_save_path_L, map_size=data_size * 10)

    #### write data to lmdb
    pbar = util.ProgressBar(len(all_img_list_S))

    # txn_S = env_S.begin(write=True)
    # txn_L = env_L.begin(write=True)

    for idx_all, (path_S, path_L,
                  key) in enumerate(zip(all_img_list_S, all_img_list_L, keys)):

        # pbar.update('Write {}'.format(key))

        data_S = dataset[key] if read_all_imgs else cv2.imread(
            path_S, cv2.IMREAD_UNCHANGED)  # shape H W C  ndarray
        data_L = dataset[key] if read_all_imgs else cv2.imread(
            path_L, cv2.IMREAD_UNCHANGED)

        # detect the black letterbox rows at the top and bottom of each frame
        H_S = data_S.shape[0]  # 540
        W_S = data_S.shape[1]  # 960
        H_L = data_L.shape[0]
        W_L = data_L.shape[1]

        blank_1_S = 0
        blank_2_S = 0
        for i in range(H_S):
            if not sum(data_S[:, :, 0][i]) == 0:
                blank_1_S = i - 1
                # assert not sum(data_S[:, :, 0][i+1]) == 0
                break

        for i in range(H_S):
            if not sum(data_S[:, :, 0][H_S - i - 1]) == 0:
                blank_2_S = (H_S - 1) - i - 1
                # assert not sum(data_S[:, :, 0][blank_2_S-1]) == 0
                break
        print('LQ :', blank_1_S, blank_2_S)

        if blank_1_S == -1:
            print('LQ has no blank')

        blank_1_L = 0
        blank_2_L = 0
        for i in range(H_L):
            if not sum(data_L[:, :, 0][i]) == 0:
                blank_1_L = i - 1
                # assert not sum(data_L[:, :, 0][i + 1]) == 0
                break

        for i in range(H_L):
            if not sum(data_L[:, :, 0][H_L - i - 1]) == 0:
                blank_2_L = (H_L - 1) - i - 1
                # assert not sum(data_L[:, :, 0][blank_2_L - 1]) == 0
                break
        print('GT :', blank_1_L, blank_2_L)
        if blank_1_L == -1 and blank_2_L == H_L - 2:
            print('No blank', key)
            U_L1 = 56
            D_L2 = 2104
        else:
            # align the content boundaries to multiples of 4 so they map exactly onto 540p (x4-downscaled) row indices
            U_L1 = ((blank_1_L >> 2) + 1) << 2
            D_L2 = ((blank_2_L >> 2) - 1) << 2

        print('Content:', U_L1, D_L2)

        # crop into eight patches
        U_L1_d = U_L1 + H_dst_L
        D_L2_u = D_L2 - H_dst_L

        assert U_L1_d <= H_L
        assert D_L2_u >= 0

        # crop the content area into 2 rows x 4 columns of patches
        H_list = [U_L1, D_L2_u]
        W_list = [0, 1024, 2048, 2816]
        h_list = [U_L1 // 4, D_L2_u // 4]
        w_list = [0, 256, 512, 704]

        for h_idx, _ in enumerate(H_list):
            for w_idx, _ in enumerate(W_list):
                key_idx = key + '_' + str(4 * h_idx + w_idx)
                print(key_idx)
                key_byte = key_idx.encode('ascii')
                data_gt = data_L[H_list[h_idx]:(H_list[h_idx] + 1024),
                                 W_list[w_idx]:(W_list[w_idx] + 1024), :]
                data_lq = data_S[h_list[h_idx]:(h_list[h_idx] + 256),
                                 w_list[w_idx]:(w_list[w_idx] + 256), :]
                # txn_L.put(key_byte, data_gt.copy(order='C'))
                # txn_S.put(key_byte, data_lq.copy(order='C'))

        # if not read_all_imgs and idx_all % BATCH == 0:
        # txn_L.commit()
        # txn_S.commit()
        # txn_L = env_L.begin(write=True)
        # txn_S = env_S.begin(write=True)

    # txn_L.commit()
    # txn_S.commit()
    # env_L.close()
    # env_S.close()
    print('Finish writing lmdb.')

    #### create meta information
    meta_info = {}
    meta_info['name'] = 'HDR_{}_wval'.format(mode)
    channel = 1 if 'flow' in mode else 3
    meta_info['resolution'] = '{}_{}_{}'.format(channel, H_dst_L, W_dst_L)
    meta_info['keys'] = keys_patch
    pickle.dump(meta_info,
                open(osp.join(lmdb_save_path_L, 'meta_info.pkl'), "wb"))
    print('Finish creating lmdb meta info.')

    # for LQ
    meta_info = {}
    meta_info['name'] = 'HDR_{}_wval'.format(mode)
    channel = 1 if 'flow' in mode else 3
    meta_info['resolution'] = '{}_{}_{}'.format(channel, H_dst_S, W_dst_S)
    meta_info['keys'] = keys_patch
    pickle.dump(meta_info,
                open(osp.join(lmdb_save_path_S, 'meta_info.pkl'), "wb"))
    print('Finish creating lmdb meta info.')
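The row-scanning loops above can be expressed more compactly with numpy; a sketch (not the author's code) that returns the first and last rows whose channel 0 is not entirely zero:

import numpy as np

def content_rows(frame):
    """Return (first, last) indices of rows that are not fully black in channel 0, or None if every row is black."""
    nonblack = np.flatnonzero(frame[:, :, 0].sum(axis=1))
    if nonblack.size == 0:
        return None
    return int(nonblack[0]), int(nonblack[-1])

# e.g. top, bottom = content_rows(data_L); rows [0, top) and (bottom, H) are letterbox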
Esempio n. 22
0
def general_image_folder(opt):
    """Create lmdb for general image folders
    Users should define the keys, such as: '0321_s035' for DIV2K sub-images
    If all the images have the same resolution, it will only store one copy of resolution info.
        Otherwise, it will store every resolution info.
    """
    #### configurations
    read_all_imgs = False  # whether to read all images into memory with multiprocessing
    # Set False when memory is limited
    BATCH = 5000  # commit the lmdb transaction every BATCH images when read_all_imgs is False
    n_thread = 40
    ########################################################
    img_folder = opt['img_folder']
    lmdb_save_path = opt['lmdb_save_path']
    meta_info = {'name': opt['name']}
    if not lmdb_save_path.endswith('.lmdb'):
        raise ValueError("lmdb_save_path must end with '.lmdb'.")
    if osp.exists(lmdb_save_path):
        print('Folder [{:s}] already exists. Exit...'.format(lmdb_save_path))
        sys.exit(1)

    #### read all the image paths to a list
    print('Reading image path list ...')
    all_img_list = sorted(glob.glob(osp.join(img_folder, '*')))
    keys = []
    for img_path in all_img_list:
        keys.append(osp.splitext(osp.basename(img_path))[0])

    if read_all_imgs:
        #### read all images to memory (multiprocessing)
        dataset = {}  # store all image data; pool results arrive out of order, so key them by image name in a dict
        print('Read images with multiprocessing, #thread: {} ...'.format(
            n_thread))
        pbar = util.ProgressBar(len(all_img_list))

        def mycallback(arg):
            '''get the image data and update pbar'''
            key = arg[0]
            dataset[key] = arg[1]
            pbar.update('Reading {}'.format(key))

        pool = Pool(n_thread)
        for path, key in zip(all_img_list, keys):
            pool.apply_async(read_image_worker,
                             args=(path, key),
                             callback=mycallback)
        pool.close()
        pool.join()
        print('Finish reading {} images.\nWrite lmdb...'.format(
            len(all_img_list)))

    #### create lmdb environment
    data_size_per_img = cv2.imread(all_img_list[0],
                                   cv2.IMREAD_UNCHANGED).nbytes
    print('data size per image is: ', data_size_per_img)
    data_size = data_size_per_img * len(all_img_list)
    env = lmdb.open(lmdb_save_path, map_size=data_size * 10)

    #### write data to lmdb
    pbar = util.ProgressBar(len(all_img_list))
    txn = env.begin(write=True)
    resolutions = []
    for idx, (path, key) in enumerate(zip(all_img_list, keys)):
        pbar.update('Write {}'.format(key))
        key_byte = key.encode('ascii')
        data = dataset[key] if read_all_imgs else cv2.imread(
            path, cv2.IMREAD_UNCHANGED)
        if data.ndim == 2:
            H, W = data.shape
            C = 1
        else:
            H, W, C = data.shape
        txn.put(key_byte, data)
        resolutions.append('{:d}_{:d}_{:d}'.format(C, H, W))
        if not read_all_imgs and idx % BATCH == 0:
            txn.commit()
            txn = env.begin(write=True)
    txn.commit()
    env.close()
    print('Finish writing lmdb.')

    #### create meta information
    # check whether all the images are the same size
    assert len(keys) == len(resolutions)
    if len(set(resolutions)) <= 1:
        meta_info['resolution'] = [resolutions[0]]
        meta_info['keys'] = keys
        print('All images have the same resolution. Simplify the meta info.')
    else:
        meta_info['resolution'] = resolutions
        meta_info['keys'] = keys
        print(
            'Not all images have the same resolution. Save meta info for each image.'
        )

    pickle.dump(meta_info, open(osp.join(lmdb_save_path, 'meta_info.pkl'),
                                "wb"))
    print('Finish creating lmdb meta info.')
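A usage sketch for the function above, with hypothetical paths; only the three keys read by the function are required:

opt = {
    'name': 'DIV2K800_sub_GT',                       # stored as meta_info['name']
    'img_folder': '/path/to/DIV2K800_sub',           # folder containing the images
    'lmdb_save_path': '/path/to/DIV2K800_sub.lmdb',  # must end with '.lmdb'
}
general_image_folder(opt)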
Esempio n. 23
0
def main():
    # options
    parser = argparse.ArgumentParser()
    parser.add_argument("-opt", type=str, help="Path to option YAML file.")
    parser.add_argument("--launcher",
                        choices=["none", "pytorch"],
                        default="none",
                        help="job launcher")
    parser.add_argument("--local_rank", type=int, default=0)
    args = parser.parse_args()
    opt = option.parse(args.opt, is_train=True)

    # distributed training settings
    if args.launcher == "none":  # disabled distributed training
        opt["dist"] = False
        rank = -1
        print("Disabled distributed training.")
    else:
        opt["dist"] = True
        init_dist()
        world_size = torch.distributed.get_world_size()
        rank = torch.distributed.get_rank()

    # loading resume state if exists
    if opt["path"].get("resume_state", None):
        # distributed resuming: all load into default GPU
        device_id = torch.cuda.current_device()
        resume_state = torch.load(
            opt["path"]["resume_state"],
            map_location=lambda storage, loc: storage.cuda(device_id))
        option.check_resume(opt, resume_state["iter"])  # check resume options
    else:
        resume_state = None

    # mkdir and loggers
    if rank <= 0:  # normal training (rank -1) OR distributed training (rank 0)
        if resume_state is None:
            util.mkdir_and_rename(
                opt["path"]
                ["experiments_root"])  # rename experiment folder if exists
            util.mkdirs(
                (path for key, path in opt["path"].items()
                 if not key == "experiments_root"
                 and "pretrain_model" not in key and "resume" not in key))

        # config loggers. Before it, the log will not work
        util.setup_logger("base",
                          opt["path"]["log"],
                          "train_" + opt["name"],
                          level=logging.INFO,
                          screen=True,
                          tofile=True)
        logger = logging.getLogger("base")
        logger.info(option.dict2str(opt))
        # tensorboard logger
        if opt["use_tb_logger"] and "debug" not in opt["name"]:
            version = float(torch.__version__[0:3])
            if version >= 1.1:  # PyTorch 1.1
                from torch.utils.tensorboard import SummaryWriter
            else:
                logger.info("You are using PyTorch {}. \
                            Tensorboard will use [tensorboardX]".format(
                    version))
                from tensorboardX import SummaryWriter
            tb_logger = SummaryWriter(log_dir="../tb_logger/" + opt["name"])
    else:
        util.setup_logger("base",
                          opt["path"]["log"],
                          "train",
                          level=logging.INFO,
                          screen=True)
        logger = logging.getLogger("base")

    # convert to NoneDict, which returns None for missing keys
    opt = option.dict_to_nonedict(opt)

    # random seed
    seed = opt["train"]["manual_seed"]
    if seed is None:
        seed = random.randint(1, 10000)
    if rank <= 0:
        logger.info("Random seed: {}".format(seed))
    util.set_random_seed(seed)

    torch.backends.cudnn.benchmark = True
    # torch.backends.cudnn.deterministic = True

    # create train and val dataloader
    dataset_ratio = 200  # enlarge the size of each epoch
    for phase, dataset_opt in opt["datasets"].items():
        if phase == "train":
            train_set = create_dataset(dataset_opt)
            train_size = int(
                math.ceil(len(train_set) / dataset_opt["batch_size"]))
            total_iters = int(opt["train"]["niter"])
            total_epochs = int(math.ceil(total_iters / train_size))
            if opt["dist"]:
                train_sampler = DistIterSampler(train_set, world_size, rank,
                                                dataset_ratio)
                total_epochs = int(
                    math.ceil(total_iters / (train_size * dataset_ratio)))
            else:
                train_sampler = None
            train_loader = create_dataloader(train_set, dataset_opt, opt,
                                             train_sampler)
            if rank <= 0:
                logger.info(
                    "Number of train images: {:,d}, iters: {:,d}".format(
                        len(train_set), train_size))
                logger.info("Total epochs needed: {:d} for iters {:,d}".format(
                    total_epochs, total_iters))
        elif phase == "val":
            val_set = create_dataset(dataset_opt)
            val_loader = create_dataloader(val_set, dataset_opt, opt, None)
            if rank <= 0:
                logger.info("Number of val images in [{:s}]: {:d}".format(
                    dataset_opt["name"], len(val_set)))
        else:
            raise NotImplementedError(
                "Phase [{:s}] is not recognized.".format(phase))
    assert train_loader is not None

    # create model
    model = create_model(opt)
    print("Model created!")

    # resume training
    if resume_state:
        logger.info("Resuming training from epoch: {}, iter: {}.".format(
            resume_state["epoch"], resume_state["iter"]))

        start_epoch = resume_state["epoch"]
        current_step = resume_state["iter"]
        model.resume_training(resume_state)  # handle optimizers and schedulers
    else:
        current_step = 0
        start_epoch = 0

    # training
    logger.info("Start training from epoch: {:d}, iter: {:d}".format(
        start_epoch, current_step))
    for epoch in range(start_epoch, total_epochs + 1):
        if opt["dist"]:
            train_sampler.set_epoch(epoch)
        for _, train_data in enumerate(train_loader):
            current_step += 1
            if current_step > total_iters:
                break
            # update learning rate
            model.update_learning_rate(current_step,
                                       warmup_iter=opt["train"]["warmup_iter"])

            # training
            model.feed_data(train_data)
            model.optimize_parameters(current_step)

            # log
            if current_step % opt["logger"]["print_freq"] == 0:
                logs = model.get_current_log()
                message = "[epoch:{:3d}, iter:{:8,d}, lr:(".format(
                    epoch, current_step)
                for v in model.get_current_learning_rate():
                    message += "{:.3e},".format(v)
                message += ")] "
                for k, v in logs.items():
                    message += "{:s}: {:.4e} ".format(k, v)
                    # tensorboard logger
                    if opt["use_tb_logger"] and "debug" not in opt["name"]:
                        if rank <= 0:
                            tb_logger.add_scalar(k, v, current_step)
                if rank <= 0:
                    logger.info(message)
            # validation
            if opt["datasets"].get(
                    "val",
                    None) and current_step % opt["train"]["val_freq"] == 0:
                # image restoration validation
                if opt["model"] in ["sr", "srgan"] and rank <= 0:
                    # does not support multi-GPU validation
                    pbar = util.ProgressBar(len(val_loader))
                    avg_psnr = 0.0
                    idx = 0
                    for val_data in val_loader:
                        idx += 1
                        img_name = os.path.splitext(
                            os.path.basename(val_data["LQ_path"][0]))[0]
                        img_dir = os.path.join(opt["path"]["val_images"],
                                               img_name)
                        util.mkdir(img_dir)

                        model.feed_data(val_data)
                        model.test()

                        visuals = model.get_current_visuals()
                        sr_img = util.tensor2img(visuals["rlt"])  # uint8
                        gt_img = util.tensor2img(visuals["GT"])  # uint8

                        # Save SR images for reference
                        save_img_path = os.path.join(
                            img_dir,
                            "{:s}_{:d}.png".format(img_name, current_step))
                        util.save_img(sr_img, save_img_path)

                        # calculate PSNR
                        sr_img, gt_img = util.crop_border([sr_img, gt_img],
                                                          opt["scale"])
                        avg_psnr += util.calculate_psnr(sr_img, gt_img)
                        pbar.update("Test {}".format(img_name))

                    avg_psnr = avg_psnr / idx

                    # log
                    logger.info("# Validation # PSNR: {:.4e}".format(avg_psnr))
                    # tensorboard logger
                    if opt["use_tb_logger"] and "debug" not in opt["name"]:
                        tb_logger.add_scalar("psnr", avg_psnr, current_step)
                else:  # video restoration validation
                    if opt["dist"]:
                        # multi-GPU testing
                        psnr_rlt = {}  # with border and center frames
                        if rank == 0:
                            pbar = util.ProgressBar(len(val_set))
                        for idx in range(rank, len(val_set), world_size):
                            val_data = val_set[idx]
                            val_data["LQs"].unsqueeze_(0)
                            val_data["GT"].unsqueeze_(0)
                            folder = val_data["folder"]
                            idx_d, max_idx = val_data["idx"].split("/")
                            idx_d, max_idx = int(idx_d), int(max_idx)
                            if psnr_rlt.get(folder, None) is None:
                                psnr_rlt[folder] = torch.zeros(
                                    max_idx,
                                    dtype=torch.float32,
                                    device="cuda")
                            model.feed_data(val_data)
                            model.test()
                            visuals = model.get_current_visuals()
                            rlt_img = util.tensor2img(visuals["rlt"])  # uint8
                            gt_img = util.tensor2img(visuals["GT"])  # uint8
                            # calculate PSNR
                            psnr_rlt[folder][idx_d] = util.calculate_psnr(
                                rlt_img, gt_img)

                            if rank == 0:
                                for _ in range(world_size):
                                    pbar.update("Test {} - {}/{}".format(
                                        folder, idx_d, max_idx))
                        # collect data
                        for _, v in psnr_rlt.items():
                            dist.reduce(v, 0)
                        dist.barrier()

                        if rank == 0:
                            psnr_rlt_avg = {}
                            psnr_total_avg = 0.0
                            for k, v in psnr_rlt.items():
                                psnr_rlt_avg[k] = torch.mean(v).cpu().item()
                                psnr_total_avg += psnr_rlt_avg[k]
                            psnr_total_avg /= len(psnr_rlt)
                            log_s = "# Validation # PSNR: {:.4e}:".format(
                                psnr_total_avg)
                            for k, v in psnr_rlt_avg.items():
                                log_s += " {}: {:.4e}".format(k, v)
                            logger.info(log_s)
                            if opt["use_tb_logger"] and "debug" not in opt[
                                    "name"]:
                                tb_logger.add_scalar("psnr_avg",
                                                     psnr_total_avg,
                                                     current_step)
                                for k, v in psnr_rlt_avg.items():
                                    tb_logger.add_scalar(k, v, current_step)
                    else:
                        pbar = util.ProgressBar(len(val_loader))
                        psnr_rlt = {}  # with border and center frames
                        psnr_rlt_avg = {}
                        psnr_total_avg = 0.0
                        for val_data in val_loader:
                            folder = val_data["folder"][0]
                            idx_d, max_id = val_data["idx"][0].split("/")
                            # border = val_data['border'].item()
                            if psnr_rlt.get(folder, None) is None:
                                psnr_rlt[folder] = []

                            model.feed_data(val_data)
                            model.test()
                            visuals = model.get_current_visuals()
                            rlt_img = util.tensor2img(visuals["rlt"])  # uint8
                            gt_img = util.tensor2img(visuals["GT"])  # uint8
                            lq_img = util.tensor2img(visuals["LQ"][2])  # uint8

                            img_dir = opt["path"]["val_images"]
                            util.mkdir(img_dir)
                            save_img_path = os.path.join(
                                img_dir, "{}.png".format(idx_d))
                            util.save_img(np.hstack((lq_img, rlt_img, gt_img)),
                                          save_img_path)

                            # calculate PSNR
                            psnr = util.calculate_psnr(rlt_img, gt_img)
                            psnr_rlt[folder].append(psnr)
                            pbar.update("Test {} - {}".format(folder, idx_d))
                        for k, v in psnr_rlt.items():
                            psnr_rlt_avg[k] = sum(v) / len(v)
                            psnr_total_avg += psnr_rlt_avg[k]
                        psnr_total_avg /= len(psnr_rlt)
                        log_s = "# Validation # PSNR: {:.4e}:".format(
                            psnr_total_avg)
                        for k, v in psnr_rlt_avg.items():
                            log_s += " {}: {:.4e}".format(k, v)
                        logger.info(log_s)
                        if opt["use_tb_logger"] and "debug" not in opt["name"]:
                            tb_logger.add_scalar("psnr_avg", psnr_total_avg,
                                                 current_step)
                            for k, v in psnr_rlt_avg.items():
                                tb_logger.add_scalar(k, v, current_step)

            # save models and training states
            if current_step % opt["logger"]["save_checkpoint_freq"] == 0:
                if rank <= 0:
                    logger.info("Saving models and training states.")
                    model.save(current_step)
                    model.save_training_state(epoch, current_step)

    if rank <= 0:
        logger.info("Saving the final model.")
        model.save("latest")
        logger.info("End of training.")
        tb_logger.close()
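
The validation branches above average util.calculate_psnr over the validation images before logging. That utility is not shown in this listing; as a minimal sketch (not the repository's implementation), the usual PSNR between two uint8 images with peak value 255 is:

import numpy as np


def calculate_psnr_sketch(img1, img2):
    """Standard PSNR for uint8 images (hypothetical stand-in for util.calculate_psnr)."""
    img1 = img1.astype(np.float64)
    img2 = img2.astype(np.float64)
    mse = np.mean((img1 - img2) ** 2)
    if mse == 0:
        return float('inf')  # identical images
    return 20 * np.log10(255.0 / np.sqrt(mse))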
Esempio n. 24
0
def REDS(mode = 'train_sharp', overwrite = True):
    '''Create lmdb for the REDS dataset, each image with a fixed size.
    GT: [3, 720, 1280], key: 000_00000000
    LR: [3, 180, 320],  key: 000_00000000
    '''
    #### configurations
    read_all_imgs = False  # whether to read all images into memory. Set False with limited memory
    BATCH = 5000  # After BATCH images, lmdb commits, if read_all_imgs = False
    # train_sharp | train_sharp_bicubic | train_blur_bicubic| train_blur | train_blur_comp
    if mode == 'train_sharp':
        img_folder = osp.join(root,'datasets/REDS/train/sharp')
        lmdb_save_path = osp.join(root,'datasets/REDS/train/sharp_wval.lmdb')
        H_dst, W_dst = 720, 1280
    elif mode == 'train_sharp_bicubic':
        img_folder = osp.join(root,'datasets/REDS/train/sharp_bicubic')
        lmdb_save_path = osp.join(root,'datasets/REDS/train/sharp_bicubic_wval.lmdb')
        H_dst, W_dst = 180, 320
    elif mode == 'train_blur_bicubic':
        img_folder = osp.join(root,'datasets/REDS/train/blur_bicubic')
        lmdb_save_path = osp.join(root,'datasets/REDS/train/blur_bicubic_wval.lmdb')
        H_dst, W_dst = 180, 320
    elif mode == 'train_blur':
        img_folder = osp.join(root,'datasets/REDS/train/blur')
        lmdb_save_path = osp.join(root,'datasets/REDS/train/blur_wval.lmdb')
        H_dst, W_dst = 720, 1280
    elif mode == 'train_blur_comp':
        img_folder = osp.join(root,'datasets/REDS/train/blur_comp')
        lmdb_save_path = osp.join(root,'datasets/REDS/train/blur_comp_wval.lmdb')
        H_dst, W_dst = 720, 1280
    else:
        raise ValueError('Wrong dataset mode: {}'.format(mode))
    n_thread = 40
    ########################################################
    if not lmdb_save_path.endswith('.lmdb'):
        raise ValueError("lmdb_save_path must end with \'lmdb\'.")
    #### whether the lmdb file exist
    if not overwrite and osp.exists(lmdb_save_path):
        print(f'Folder [{lmdb_save_path}] already exists. Exit...')
        sys.exit(1)

    #### read all the image paths to a list
    print('Reading image path list ...')
    all_img_list = data_util._get_paths_from_images(img_folder)
    keys = []
    for img_path in all_img_list:
        split_rlt = img_path.split('/')
        a = split_rlt[-2]
        b = split_rlt[-1].split('.png')[0]
        keys.append(a + '_' + b)

    if read_all_imgs:
        #### read all images to memory (multiprocessing)
        dataset = {}  # store all image data. list cannot keep the order, use dict
        print(f'Read images with multiprocessing, #thread: {n_thread} ...')
        pbar = util.ProgressBar(len(all_img_list))

        def mycallback(arg):
            '''get the image data and update pbar'''
            key = arg[0]
            dataset[key] = arg[1]
            pbar.update(f'Reading {key}')

        pool = Pool(n_thread)
        for path, key in zip(all_img_list, keys):
            pool.apply_async(reading_image_worker, args=(path, key), callback=mycallback)
        pool.close()
        pool.join()
        print(f'Finish reading {len(all_img_list)} images.\nWrite lmdb...')

    #### create lmdb environment
    data_size_per_img = cv2.imread(all_img_list[0], cv2.IMREAD_UNCHANGED).nbytes
    if 'flow' in mode:
        data_size_per_img = dataset['000_00000002_n1'].nbytes
    print('data size per image is: ', data_size_per_img)
    data_size = data_size_per_img * len(all_img_list)
    env = lmdb.open(lmdb_save_path, map_size=data_size * 10)

    #### write data to lmdb
    pbar = util.ProgressBar(len(all_img_list))
    txn = env.begin(write=True)
    idx = 1
    for path, key in zip(all_img_list, keys):
        idx = idx + 1
        pbar.update(f'Write {key}')
        key_byte = key.encode('ascii')
        data = dataset[key] if read_all_imgs else cv2.imread(path, cv2.IMREAD_UNCHANGED)
        if 'flow' in mode:
            H, W = data.shape
            assert H == H_dst and W == W_dst, 'different shape.'
        else:
            H, W, C = data.shape  # fixed shape
            assert H == H_dst and W == W_dst and C == 3, 'different shape.'
        txn.put(key_byte, data)
        if not read_all_imgs and idx % BATCH == 1:
            txn.commit()
            txn = env.begin(write=True)
    txn.commit()
    env.close()
    print('Finish writing lmdb.')

    #### create meta information
    meta_info = {}
    meta_info['name'] = 'REDS_{}_wval'.format(mode)
    if 'flow' in mode:
        meta_info['resolution'] = f'1_{H_dst}_{W_dst}'
    else:
        meta_info['resolution'] = f'3_{H_dst}_{W_dst}'
    meta_info['keys'] = keys
    pickle.dump(meta_info, open(osp.join(lmdb_save_path, 'meta_info.pkl'), "wb"))
    print('Finish creating lmdb meta info.')
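
The keys written above follow the '<clip>_<frame>' pattern from the docstring (e.g. 000_00000000). A minimal sketch of how a consumer could assemble a centered window of frame keys for one clip; reds_neighbor_keys is a hypothetical helper, and the 0-99 frame range per clip is an assumption about the REDS layout:

def reds_neighbor_keys(clip, center_idx, n_frames=5):
    """Build lmdb keys for a window of REDS frames centered on center_idx."""
    half = n_frames // 2
    keys = []
    for i in range(center_idx - half, center_idx + half + 1):
        i = min(max(i, 0), 99)  # clamp at clip borders (assumes 100 frames per clip)
        keys.append('{:03d}_{:08d}'.format(clip, i))
    return keys

# e.g. reds_neighbor_keys(0, 2)
# -> ['000_00000000', '000_00000001', '000_00000002', '000_00000003', '000_00000004']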
Esempio n. 25
0
def vimeo90k():
    """create lmdb for the Vimeo90K dataset, each image with fixed size
    GT: [3, 256, 448]
        Only need the 4th frame currently, e.g., 00001_0001_4
    LR: [3, 64, 112]
        With 1st - 7th frames, e.g., 00001_0001_1, ..., 00001_0001_7
    key:
        Use the folder and subfolder names, w/o the frame index, e.g., 00001_0001
    """
    #### configurations
    mode = "GT"  # GT | LR
    read_all_imgs = False  # whether to read all images into memory. Set False with limited memory
    BATCH = 5000  # After BATCH images, lmdb commits, if read_all_imgs = False
    if mode == "GT":
        img_folder = "/home/xtwang/datasets/vimeo90k/vimeo_septuplet/sequences"
        lmdb_save_path = "/home/xtwang/datasets/vimeo90k/vimeo90k_train_GT.lmdb"
        txt_file = "/home/xtwang/datasets/vimeo90k/vimeo_septuplet/sep_trainlist.txt"
        H_dst, W_dst = 256, 448
    elif mode == "LR":
        img_folder = (
            "/home/xtwang/datasets/vimeo90k/vimeo_septuplet_matlabLRx4/sequences"
        )
        lmdb_save_path = "/home/xtwang/datasets/vimeo90k/vimeo90k_train_LR7frames.lmdb"
        txt_file = "/home/xtwang/datasets/vimeo90k/vimeo_septuplet/sep_trainlist.txt"
        H_dst, W_dst = 64, 112
    n_thread = 40
    ########################################################
    if not lmdb_save_path.endswith(".lmdb"):
        raise ValueError("lmdb_save_path must end with 'lmdb'.")
    #### whether the lmdb file exist
    if osp.exists(lmdb_save_path):
        print("Folder [{:s}] already exists. Exit...".format(lmdb_save_path))
        sys.exit(1)

    #### read all the image paths to a list
    print("Reading image path list ...")
    with open(txt_file) as f:
        train_l = f.readlines()
        train_l = [v.strip() for v in train_l]
    all_img_list = []
    keys = []
    for line in train_l:
        folder = line.split("/")[0]
        sub_folder = line.split("/")[1]
        file_l = glob.glob(osp.join(img_folder, folder, sub_folder) + "/*")
        all_img_list.extend(file_l)
        for j in range(7):
            keys.append("{}_{}_{}".format(folder, sub_folder, j + 1))
    all_img_list = sorted(all_img_list)
    keys = sorted(keys)
    if mode == "GT":  # read the 4th frame only for GT mode
        print("Only keep the 4th frame.")
        all_img_list = [v for v in all_img_list if v.endswith("im4.png")]
        keys = [v for v in keys if v.endswith("_4")]

    if read_all_imgs:
        #### read all images to memory (multiprocessing)
        dataset = {}  # store all image data. list cannot keep the order, use dict
        print("Read images with multiprocessing, #thread: {} ...".format(
            n_thread))
        pbar = util.ProgressBar(len(all_img_list))

        def mycallback(arg):
            """get the image data and update pbar"""
            key = arg[0]
            dataset[key] = arg[1]
            pbar.update("Reading {}".format(key))

        pool = Pool(n_thread)
        for path, key in zip(all_img_list, keys):
            pool.apply_async(reading_image_worker,
                             args=(path, key),
                             callback=mycallback)
        pool.close()
        pool.join()
        print("Finish reading {} images.\nWrite lmdb...".format(
            len(all_img_list)))

    #### create lmdb environment
    data_size_per_img = cv2.imread(all_img_list[0],
                                   cv2.IMREAD_UNCHANGED).nbytes
    print("data size per image is: ", data_size_per_img)
    data_size = data_size_per_img * len(all_img_list)
    env = lmdb.open(lmdb_save_path, map_size=data_size * 10)

    #### write data to lmdb
    pbar = util.ProgressBar(len(all_img_list))
    txn = env.begin(write=True)
    idx = 1
    for path, key in zip(all_img_list, keys):
        idx = idx + 1
        pbar.update("Write {}".format(key))
        key_byte = key.encode("ascii")
        data = dataset[key] if read_all_imgs else cv2.imread(
            path, cv2.IMREAD_UNCHANGED)
        H, W, C = data.shape  # fixed shape
        assert H == H_dst and W == W_dst and C == 3, "different shape."
        txn.put(key_byte, data)
        if not read_all_imgs and idx % BATCH == 1:
            txn.commit()
            txn = env.begin(write=True)
    txn.commit()
    env.close()
    print("Finish writing lmdb.")

    #### create meta information
    meta_info = {}
    if mode == "GT":
        meta_info["name"] = "Vimeo90K_train_GT"
    elif mode == "LR":
        meta_info["name"] = "Vimeo90K_train_LR"
    meta_info["resolution"] = "{}_{}_{}".format(3, H_dst, W_dst)
    key_set = set()
    for key in keys:
        a, b, _ = key.split("_")
        key_set.add("{}_{}".format(a, b))
    meta_info["keys"] = key_set
    pickle.dump(meta_info, open(osp.join(lmdb_save_path, "meta_info.pkl"),
                                "wb"))
    print("Finish creating lmdb meta info.")
Esempio n. 26
0
def MultiScaleREDS(img_root, lmdb_save_path, scales):
    """Create lmdb for the REDS dataset with multiple scales
    """
    #### configurations
    BATCH = 5000  # the lmdb transaction commits after every BATCH images
    n_thread = 40
    ########################################################
    if not lmdb_save_path.endswith('.lmdb'):
        raise ValueError("lmdb_save_path must end with \'lmdb\'.")
    if osp.exists(lmdb_save_path):
        print('Folder [{:s}] already exists. Exit...'.format(lmdb_save_path))
        sys.exit(1)

    print('Reading image path list ...')
    # all_img_list = get_paths_from_images(img_folder)
    scale_folders = sorted(os.listdir(img_root))
    all_imgs, all_keys = [], []
    resolution = {}
    for i, folder in enumerate(scale_folders):
        print('[{:02d}/{:02d}] Reading scale-folder: {:s} ...'.format(
            i, len(scale_folders), folder))
        folder_dir = osp.join(img_root, folder)
        sub_folders = sorted(os.listdir(folder_dir))
        for sub in sub_folders:
            sub_dir = osp.join(folder_dir, sub)
            img_names = sorted(os.listdir(sub_dir))
            imgs = [osp.join(sub_dir, name) for name in img_names]
            keys = [folder + '_' + sub + '_' + name[:-4] for name in img_names]
            all_imgs.extend(imgs)
            all_keys.extend(keys)
        resolution[folder] = cv2.imread(imgs[-1]).shape

    #### create lmdb environment
    data_size_per_img = cv2.imread(all_imgs[0], cv2.IMREAD_UNCHANGED).nbytes
    print('max data size per image is: ', data_size_per_img)
    data_size = data_size_per_img * len(all_imgs)
    env = lmdb.open(lmdb_save_path, map_size=data_size * 10)

    #### write data to lmdb
    txn = env.begin(write=True)
    for i in range(0, len(all_imgs), BATCH):
        imgs = all_imgs[i:i + BATCH]
        keys = all_keys[i:i + BATCH]
        batch_data = read_imgs_multi_thread(imgs, keys, n_thread)
        pbar = util.ProgressBar(len(imgs))
        for k, v in batch_data.items():
            pbar.update('Write {}'.format(k))
            key_byte = k.encode('ascii')
            txn.put(key_byte, v)
        txn.commit()
        txn = env.begin(write=True)
    txn.commit()
    env.close()
    print('Finish writing lmdb.')

    #### create meta information
    meta_info = {}
    meta_info['name'] = 'REDS_X1_X6_wval'
    channel = 3
    meta_info['resolution'] = resolution
    meta_info['keys'] = all_keys
    pickle.dump(meta_info, open(osp.join(lmdb_save_path, 'meta_info.pkl'),
                                "wb"))
    print('Finish creating lmdb meta info.')
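
Here meta_info['resolution'] is a dict mapping each scale folder to the (H, W, C) shape returned by cv2.imread, and keys look like '<scale_folder>_<clip>_<frame>'. A minimal reader sketch under the additional assumptions that scale folder names contain no underscore and that entries are raw np.uint8 buffers; read_multiscale_entry is a hypothetical helper:

import pickle

import lmdb
import numpy as np


def read_multiscale_entry(lmdb_path, key):
    """Decode one image from the multi-scale REDS lmdb written above (sketch)."""
    with open(lmdb_path + '/meta_info.pkl', 'rb') as f:
        meta = pickle.load(f)
    scale = key.split('_')[0]  # scale folder prefix (assumes no '_' in the folder name)
    H, W, C = meta['resolution'][scale]
    env = lmdb.open(lmdb_path, readonly=True, lock=False, readahead=False)
    with env.begin(write=False) as txn:
        buf = txn.get(key.encode('ascii'))
    env.close()
    return np.frombuffer(buf, dtype=np.uint8).reshape(H, W, C)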
Esempio n. 27
0
def vimeo90k(mode):
    """Create lmdb for the Vimeo90K dataset, each image with a fixed size
    GT: [3, 256, 448]
        Now only need the 4th frame, e.g., 00001_0001_4
    LR: [3, 64, 112]
        1st - 7th frames, e.g., 00001_0001_1, ..., 00001_0001_7
    key:
        Use the folder and subfolder names, w/o the frame index, e.g., 00001_0001

    flow: downsampled flow: [1, 128, 112], keys: 00001_0001_4_[p3, p2, p1, n1, n2, n3]
        Each flow is calculated with GT images by PWCNet and then downsampled by 1/4
        Flow map is quantized by mmcv and saved in png format
    """
    #### configurations
    read_all_imgs = False  # whether to read all images into memory with multiprocessing
    # Set False when memory is limited
    BATCH = 5000  # After BATCH images, lmdb commits, if read_all_imgs = False
    if mode == 'GT':
        img_folder = '../../datasets/vimeo90k/vimeo_septuplet/sequences'
        lmdb_save_path = '../../datasets/vimeo90k/vimeo90k_train_GT.lmdb'
        txt_file = '../../datasets/vimeo90k/vimeo_septuplet/sep_trainlist.txt'
        H_dst, W_dst = 256, 448
    elif mode == 'LR':
        img_folder = '../../datasets/vimeo90k/vimeo_septuplet_matlabLRx4/sequences'
        lmdb_save_path = '../../datasets/vimeo90k/vimeo90k_train_LR7frames.lmdb'
        txt_file = '../../datasets/vimeo90k/vimeo_septuplet/sep_trainlist.txt'
        H_dst, W_dst = 64, 112
    elif mode == 'flow':
        img_folder = '../../datasets/vimeo90k/vimeo_septuplet/sequences_flowx4'
        lmdb_save_path = '../../datasets/vimeo90k/vimeo90k_train_flowx4.lmdb'
        txt_file = '../../datasets/vimeo90k/vimeo_septuplet/sep_trainlist.txt'
        H_dst, W_dst = 128, 112
    else:
        raise ValueError('Wrong dataset mode: {}'.format(mode))
    n_thread = 40
    ########################################################
    if not lmdb_save_path.endswith('.lmdb'):
        raise ValueError("lmdb_save_path must end with \'lmdb\'.")
    if osp.exists(lmdb_save_path):
        print('Folder [{:s}] already exists. Exit...'.format(lmdb_save_path))
        sys.exit(1)

    #### read all the image paths to a list
    print('Reading image path list ...')
    with open(txt_file) as f:
        train_l = f.readlines()
        train_l = [v.strip() for v in train_l]
    all_img_list = []
    keys = []
    for line in train_l:
        folder = line.split('/')[0]
        sub_folder = line.split('/')[1]
        all_img_list.extend(
            glob.glob(osp.join(img_folder, folder, sub_folder, '*')))
        if mode == 'flow':
            for j in range(1, 4):
                keys.append('{}_{}_4_n{}'.format(folder, sub_folder, j))
                keys.append('{}_{}_4_p{}'.format(folder, sub_folder, j))
        else:
            for j in range(7):
                keys.append('{}_{}_{}'.format(folder, sub_folder, j + 1))
    all_img_list = sorted(all_img_list)
    keys = sorted(keys)
    if mode == 'GT':  # only read the 4th frame for the GT mode
        print('Only keep the 4th frame.')
        all_img_list = [v for v in all_img_list if v.endswith('im4.png')]
        keys = [v for v in keys if v.endswith('_4')]

    if read_all_imgs:
        #### read all images to memory (multiprocessing)
        dataset = {}  # store all image data. list cannot keep the order, use dict
        print('Read images with multiprocessing, #thread: {} ...'.format(
            n_thread))
        pbar = util.ProgressBar(len(all_img_list))

        def mycallback(arg):
            """get the image data and update pbar"""
            key = arg[0]
            dataset[key] = arg[1]
            pbar.update('Reading {}'.format(key))

        pool = Pool(n_thread)
        for path, key in zip(all_img_list, keys):
            pool.apply_async(read_image_worker,
                             args=(path, key),
                             callback=mycallback)
        pool.close()
        pool.join()
        print('Finish reading {} images.\nWrite lmdb...'.format(
            len(all_img_list)))

    #### write data to lmdb
    data_size_per_img = cv2.imread(all_img_list[0],
                                   cv2.IMREAD_UNCHANGED).nbytes
    print('data size per image is: ', data_size_per_img)
    data_size = data_size_per_img * len(all_img_list)
    env = lmdb.open(lmdb_save_path, map_size=data_size * 10)
    txn = env.begin(write=True)
    pbar = util.ProgressBar(len(all_img_list))
    for idx, (path, key) in enumerate(zip(all_img_list, keys)):
        pbar.update('Write {}'.format(key))
        key_byte = key.encode('ascii')
        data = dataset[key] if read_all_imgs else cv2.imread(
            path, cv2.IMREAD_UNCHANGED)
        if 'flow' in mode:
            H, W = data.shape
            assert H == H_dst and W == W_dst, 'different shape.'
        else:
            H, W, C = data.shape
            assert H == H_dst and W == W_dst and C == 3, 'different shape.'
        txn.put(key_byte, data)
        if not read_all_imgs and idx % BATCH == 0:
            txn.commit()
            txn = env.begin(write=True)
    txn.commit()
    env.close()
    print('Finish writing lmdb.')

    #### create meta information
    meta_info = {}
    if mode == 'GT':
        meta_info['name'] = 'Vimeo90K_train_GT'
    elif mode == 'LR':
        meta_info['name'] = 'Vimeo90K_train_LR'
    elif mode == 'flow':
        meta_info['name'] = 'Vimeo90K_train_flowx4'
    channel = 1 if 'flow' in mode else 3
    meta_info['resolution'] = '{}_{}_{}'.format(channel, H_dst, W_dst)
    key_set = set()
    for key in keys:
        if mode == 'flow':
            a, b, _, _ = key.split('_')
        else:
            a, b, _ = key.split('_')
        key_set.add('{}_{}'.format(a, b))
    meta_info['keys'] = list(key_set)
    pickle.dump(meta_info, open(osp.join(lmdb_save_path, 'meta_info.pkl'),
                                "wb"))
    print('Finish creating lmdb meta info.')
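
In flow mode the per-frame keys are 00001_0001_4_[p3, p2, p1, n1, n2, n3] while meta_info['keys'] stores only the clip key. A minimal sketch of expanding a clip key into the six flow keys (the p3..n3 ordering here is illustrative); vimeo_flow_keys is a hypothetical helper name:

def vimeo_flow_keys(clip_key):
    """Expand a clip key like '00001_0001' into the six flow keys of this lmdb."""
    keys = []
    for j in range(3, 0, -1):
        keys.append('{}_4_p{}'.format(clip_key, j))  # p3, p2, p1
    for j in range(1, 4):
        keys.append('{}_4_n{}'.format(clip_key, j))  # n1, n2, n3
    return keys

# e.g. vimeo_flow_keys('00001_0001')
# -> ['00001_0001_4_p3', '00001_0001_4_p2', '00001_0001_4_p1',
#     '00001_0001_4_n1', '00001_0001_4_n2', '00001_0001_4_n3']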
Esempio n. 28
0
def main():

    ############################################
    #
    #           set options
    #
    ############################################

    parser = argparse.ArgumentParser()
    parser.add_argument('--opt', type=str, help='Path to option YAML file.')
    parser.add_argument('--launcher',
                        choices=['none', 'pytorch'],
                        default='none',
                        help='job launcher')
    parser.add_argument('--local_rank', type=int, default=0)
    args = parser.parse_args()
    opt = option.parse(args.opt, is_train=True)

    ############################################
    #
    #           distributed training settings
    #
    ############################################

    if args.launcher == 'none':  # disabled distributed training
        opt['dist'] = False
        rank = -1
        print('Disabled distributed training.')
    else:
        opt['dist'] = True
        init_dist()
        world_size = torch.distributed.get_world_size()
        rank = torch.distributed.get_rank()

        print("Rank:", rank)
        print("------------------DIST-------------------------")

    ############################################
    #
    #           loading resume state if exists
    #
    ############################################

    if opt['path'].get('resume_state', None):
        # distributed resuming: all load into default GPU
        device_id = torch.cuda.current_device()
        resume_state = torch.load(
            opt['path']['resume_state'],
            map_location=lambda storage, loc: storage.cuda(device_id))
        option.check_resume(opt, resume_state['iter'])  # check resume options
    else:
        resume_state = None

    ############################################
    #
    #           mkdir and loggers
    #
    ############################################

    if rank <= 0:  # normal training (rank -1) OR distributed training (rank 0)
        if resume_state is None:
            util.mkdir_and_rename(
                opt['path']
                ['experiments_root'])  # rename experiment folder if exists

            util.mkdirs(
                (path for key, path in opt['path'].items()
                 if not key == 'experiments_root'
                 and 'pretrain_model' not in key and 'resume' not in key))

        # config loggers. Before it, the log will not work
        util.setup_logger('base',
                          opt['path']['log'],
                          'train_' + opt['name'],
                          level=logging.INFO,
                          screen=True,
                          tofile=True)

        util.setup_logger('base_val',
                          opt['path']['log'],
                          'val_' + opt['name'],
                          level=logging.INFO,
                          screen=True,
                          tofile=True)

        logger = logging.getLogger('base')
        logger_val = logging.getLogger('base_val')

        logger.info(option.dict2str(opt))
        # tensorboard logger
        if opt['use_tb_logger'] and 'debug' not in opt['name']:
            version = float(torch.__version__[0:3])
            if version >= 1.1:  # PyTorch 1.1
                from torch.utils.tensorboard import SummaryWriter
            else:
                logger.info(
                    'You are using PyTorch {}. Tensorboard will use [tensorboardX]'
                    .format(version))
                from tensorboardX import SummaryWriter
            tb_logger = SummaryWriter(log_dir='../tb_logger/' + opt['name'])
    else:
        # config loggers. Before it, the log will not work
        util.setup_logger('base',
                          opt['path']['log'],
                          'train_',
                          level=logging.INFO,
                          screen=True)

        print("set train log")

        util.setup_logger('base_val',
                          opt['path']['log'],
                          'val_',
                          level=logging.INFO,
                          screen=True)

        print("set val log")

        logger = logging.getLogger('base')

        logger_val = logging.getLogger('base_val')

    # convert to NoneDict, which returns None for missing keys
    opt = option.dict_to_nonedict(opt)

    #### random seed
    seed = opt['train']['manual_seed']
    if seed is None:
        seed = random.randint(1, 10000)
    if rank <= 0:
        logger.info('Random seed: {}'.format(seed))
    util.set_random_seed(seed)

    torch.backends.cudnn.benchmark = True
    # torch.backends.cudnn.deterministic = True

    ############################################
    #
    #           create train and val dataloader
    #
    ############################################
    ####

    # dataset_ratio = 200  # original setting: enlarge each epoch for the iteration-based sampler
    dataset_ratio = 1  # keep each epoch at its natural size
    for phase, dataset_opt in opt['datasets'].items():
        if phase == 'train':
            train_set = create_dataset(dataset_opt)
            train_size = int(
                math.ceil(len(train_set) / dataset_opt['batch_size']))
            # total_iters = int(opt['train']['niter'])
            # total_epochs = int(math.ceil(total_iters / train_size))

            total_iters = train_size
            total_epochs = int(opt['train']['epoch'])

            if opt['dist']:
                train_sampler = DistIterSampler(train_set, world_size, rank,
                                                dataset_ratio)
                # total_epochs = int(math.ceil(total_iters / (train_size * dataset_ratio)))
                total_epochs = int(opt['train']['epoch'])
                if opt['train']['enable'] == False:
                    total_epochs = 1
            else:
                train_sampler = None
            train_loader = create_dataloader(train_set, dataset_opt, opt,
                                             train_sampler)
            if rank <= 0:
                logger.info(
                    'Number of train images: {:,d}, iters: {:,d}'.format(
                        len(train_set), train_size))
                logger.info('Total epochs needed: {:d} for iters {:,d}'.format(
                    total_epochs, total_iters))
        elif phase == 'val':
            val_set = create_dataset(dataset_opt)
            val_loader = create_dataloader(val_set, dataset_opt, opt, None)
            if rank <= 0:
                logger.info('Number of val images in [{:s}]: {:d}'.format(
                    dataset_opt['name'], len(val_set)))
        else:
            raise NotImplementedError(
                'Phase [{:s}] is not recognized.'.format(phase))

    assert train_loader is not None

    ############################################
    #
    #          create model
    #
    ############################################
    ####

    model = create_model(opt)

    print("Model Created! ")

    #### resume training
    if resume_state:
        logger.info('Resuming training from epoch: {}, iter: {}.'.format(
            resume_state['epoch'], resume_state['iter']))

        start_epoch = resume_state['epoch']
        current_step = resume_state['iter']
        model.resume_training(resume_state)  # handle optimizers and schedulers
    else:
        current_step = 0
        start_epoch = 0
        print("Not Resume Training")

    ############################################
    #
    #          training
    #
    ############################################
    ####

    ####
    logger.info('Start training from epoch: {:d}, iter: {:d}'.format(
        start_epoch, current_step))
    Avg_train_loss = AverageMeter()  # total
    if (opt['train']['pixel_criterion'] == 'cb+ssim'):
        Avg_train_loss_pix = AverageMeter()
        Avg_train_loss_ssim = AverageMeter()
    elif (opt['train']['pixel_criterion'] == 'cb+ssim+vmaf'):
        Avg_train_loss_pix = AverageMeter()
        Avg_train_loss_ssim = AverageMeter()
        Avg_train_loss_vmaf = AverageMeter()
    elif (opt['train']['pixel_criterion'] == 'ssim'):
        Avg_train_loss_ssim = AverageMeter()
    elif (opt['train']['pixel_criterion'] == 'msssim'):
        Avg_train_loss_msssim = AverageMeter()
    elif (opt['train']['pixel_criterion'] == 'cb+msssim'):
        Avg_train_loss_pix = AverageMeter()
        Avg_train_loss_msssim = AverageMeter()

    saved_total_loss = 10e10
    saved_total_PSNR = -1

    for epoch in range(start_epoch, total_epochs):

        ############################################
        #
        #          Start a new epoch
        #
        ############################################

        # Turn into training mode
        #model = model.train()

        # reset total loss
        Avg_train_loss.reset()
        current_step = 0

        if (opt['train']['pixel_criterion'] == 'cb+ssim'):
            Avg_train_loss_pix.reset()
            Avg_train_loss_ssim.reset()
        elif (opt['train']['pixel_criterion'] == 'cb+ssim+vmaf'):
            Avg_train_loss_pix.reset()
            Avg_train_loss_ssim.reset()
            Avg_train_loss_vmaf.reset()
        elif (opt['train']['pixel_criterion'] == 'ssim'):
            Avg_train_loss_ssim.reset()
        elif (opt['train']['pixel_criterion'] == 'msssim'):
            Avg_train_loss_msssim.reset()
        elif (opt['train']['pixel_criterion'] == 'cb+msssim'):
            Avg_train_loss_pix.reset()
            Avg_train_loss_msssim.reset()

        if opt['dist']:
            train_sampler.set_epoch(epoch)

        for train_idx, train_data in enumerate(train_loader):

            if 'debug' in opt['name']:

                img_dir = os.path.join(opt['path']['train_images'])
                util.mkdir(img_dir)

                LQ = train_data['LQs']
                GT = train_data['GT']

                GT_img = util.tensor2img(GT)  # uint8

                save_img_path = os.path.join(
                    img_dir, '{:4d}_{:s}.png'.format(train_idx, 'debug_GT'))
                util.save_img(GT_img, save_img_path)

                for i in range(5):
                    LQ_img = util.tensor2img(LQ[0, i, ...])  # uint8
                    save_img_path = os.path.join(
                        img_dir,
                        '{:4d}_{:s}_{:1d}.png'.format(train_idx, 'debug_LQ',
                                                      i))
                    util.save_img(LQ_img, save_img_path)

                if (train_idx >= 3):
                    break

            if opt['train']['enable'] == False:
                message_train_loss = 'None'
                break

            current_step += 1
            if current_step > total_iters:
                print("Total Iteration Reached !")
                break
            #### update learning rate
            if opt['train']['lr_scheme'] == 'ReduceLROnPlateau':
                pass
            else:
                model.update_learning_rate(
                    current_step, warmup_iter=opt['train']['warmup_iter'])

            #### training
            model.feed_data(train_data)

            # if opt['train']['lr_scheme'] == 'ReduceLROnPlateau':
            #    model.optimize_parameters_without_schudlue(current_step)
            # else:
            model.optimize_parameters(current_step)

            if (opt['train']['pixel_criterion'] == 'cb+ssim'):
                Avg_train_loss.update(model.log_dict['total_loss'], 1)
                Avg_train_loss_pix.update(model.log_dict['l_pix'], 1)
                Avg_train_loss_ssim.update(model.log_dict['ssim_loss'], 1)
            elif (opt['train']['pixel_criterion'] == 'cb+ssim+vmaf'):
                Avg_train_loss.update(model.log_dict['total_loss'], 1)
                Avg_train_loss_pix.update(model.log_dict['l_pix'], 1)
                Avg_train_loss_ssim.update(model.log_dict['ssim_loss'], 1)
                Avg_train_loss_vmaf.update(model.log_dict['vmaf_loss'], 1)
            elif (opt['train']['pixel_criterion'] == 'ssim'):
                Avg_train_loss.update(model.log_dict['total_loss'], 1)
                Avg_train_loss_ssim.update(model.log_dict['ssim_loss'], 1)
            elif (opt['train']['pixel_criterion'] == 'msssim'):
                Avg_train_loss.update(model.log_dict['total_loss'], 1)
                Avg_train_loss_msssim.update(model.log_dict['msssim_loss'], 1)
            elif (opt['train']['pixel_criterion'] == 'cb+msssim'):
                Avg_train_loss.update(model.log_dict['total_loss'], 1)
                Avg_train_loss_pix.update(model.log_dict['l_pix'], 1)
                Avg_train_loss_msssim.update(model.log_dict['msssim_loss'], 1)
            else:
                Avg_train_loss.update(model.log_dict['l_pix'], 1)

            # add total train loss
            if (opt['train']['pixel_criterion'] == 'cb+ssim'):
                message_train_loss = ' pix_avg_loss: {:.4e}'.format(
                    Avg_train_loss_pix.avg)
                message_train_loss += ' ssim_avg_loss: {:.4e}'.format(
                    Avg_train_loss_ssim.avg)
                message_train_loss += ' total_avg_loss: {:.4e}'.format(
                    Avg_train_loss.avg)
            elif (opt['train']['pixel_criterion'] == 'cb+ssim+vmaf'):
                message_train_loss = ' pix_avg_loss: {:.4e}'.format(
                    Avg_train_loss_pix.avg)
                message_train_loss += ' ssim_avg_loss: {:.4e}'.format(
                    Avg_train_loss_ssim.avg)
                message_train_loss += ' vmaf_avg_loss: {:.4e}'.format(
                    Avg_train_loss_vmaf.avg)
                message_train_loss += ' total_avg_loss: {:.4e}'.format(
                    Avg_train_loss.avg)
            elif (opt['train']['pixel_criterion'] == 'ssim'):
                message_train_loss = ' ssim_avg_loss: {:.4e}'.format(
                    Avg_train_loss_ssim.avg)
                message_train_loss += ' total_avg_loss: {:.4e}'.format(
                    Avg_train_loss.avg)
            elif (opt['train']['pixel_criterion'] == 'msssim'):
                message_train_loss = ' msssim_avg_loss: {:.4e}'.format(
                    Avg_train_loss_msssim.avg)
                message_train_loss += ' total_avg_loss: {:.4e}'.format(
                    Avg_train_loss.avg)
            elif (opt['train']['pixel_criterion'] == 'cb+msssim'):
                message_train_loss = ' pix_avg_loss: {:.4e}'.format(
                    Avg_train_loss_pix.avg)
                message_train_loss += ' msssim_avg_loss: {:.4e}'.format(
                    Avg_train_loss_msssim.avg)
                message_train_loss += ' total_avg_loss: {:.4e}'.format(
                    Avg_train_loss.avg)
            else:
                message_train_loss = ' train_avg_loss: {:.4e}'.format(
                    Avg_train_loss.avg)

            #### log
            if current_step % opt['logger']['print_freq'] == 0:
                logs = model.get_current_log()
                message = '[epoch:{:3d}, iter:{:8,d}, lr:('.format(
                    epoch, current_step)
                for v in model.get_current_learning_rate():
                    message += '{:.3e},'.format(v)
                message += ')] '
                for k, v in logs.items():
                    message += '{:s}: {:.4e} '.format(k, v)
                    # tensorboard logger
                    if opt['use_tb_logger'] and 'debug' not in opt['name']:
                        if rank <= 0:
                            tb_logger.add_scalar(k, v, current_step)

                message += message_train_loss

                if rank <= 0:
                    logger.info(message)

        ############################################
        #
        #        end of one epoch, save epoch model
        #
        ############################################

        #### save models and training states
        # if current_step % opt['logger']['save_checkpoint_freq'] == 0:
        #     if rank <= 0:
        #         logger.info('Saving models and training states.')
        #         model.save(current_step)
        #         model.save('latest')
        #         # model.save_training_state(epoch, current_step)
        #         # todo delete previous weights
        #         previous_step = current_step - opt['logger']['save_checkpoint_freq']
        #         save_filename = '{}_{}.pth'.format(previous_step, 'G')
        #         save_path = os.path.join(opt['path']['models'], save_filename)
        #         if os.path.exists(save_path):
        #             os.remove(save_path)

        if epoch == 1:
            save_filename = '{:04d}_{}.pth'.format(0, 'G')
            save_path = os.path.join(opt['path']['models'], save_filename)
            if os.path.exists(save_path):
                os.remove(save_path)

        save_filename = '{:04d}_{}.pth'.format(epoch - 1, 'G')
        save_path = os.path.join(opt['path']['models'], save_filename)
        if os.path.exists(save_path):
            os.remove(save_path)

        if rank <= 0:
            logger.info('Saving models and training states.')
            save_filename = '{:04d}'.format(epoch)
            model.save(save_filename)
            # model.save('latest')
            # model.save_training_state(epoch, current_step)

        ############################################
        #
        #          end of one epoch, do validation
        #
        ############################################

        #### validation
        #if opt['datasets'].get('val', None) and current_step % opt['train']['val_freq'] == 0:
        if opt['datasets'].get('val', None):
            if opt['model'] in [
                    'sr', 'srgan'
            ] and rank <= 0:  # image restoration validation
                # does not support multi-GPU validation
                pbar = util.ProgressBar(len(val_loader))
                avg_psnr = 0.
                idx = 0
                for val_data in val_loader:
                    idx += 1
                    img_name = os.path.splitext(
                        os.path.basename(val_data['LQ_path'][0]))[0]
                    img_dir = os.path.join(opt['path']['val_images'], img_name)
                    util.mkdir(img_dir)

                    model.feed_data(val_data)
                    model.test()
                    visuals = model.get_current_visuals()
                    sr_img = util.tensor2img(visuals['rlt'])  # uint8
                    gt_img = util.tensor2img(visuals['GT'])  # uint8

                    # Save SR images for reference
                    save_img_path = os.path.join(
                        img_dir,
                        '{:s}_{:d}.png'.format(img_name, current_step))
                    #util.save_img(sr_img, save_img_path)

                    # calculate PSNR
                    sr_img, gt_img = util.crop_border([sr_img, gt_img],
                                                      opt['scale'])
                    avg_psnr += util.calculate_psnr(sr_img, gt_img)
                    pbar.update('Test {}'.format(img_name))

                avg_psnr = avg_psnr / idx

                # log
                logger.info('# Validation # PSNR: {:.4e}'.format(avg_psnr))
                # tensorboard logger
                if opt['use_tb_logger'] and 'debug' not in opt['name']:
                    tb_logger.add_scalar('psnr', avg_psnr, current_step)
            else:  # video restoration validation
                if opt['dist']:
                    # todo : multi-GPU testing
                    psnr_rlt = {}  # with border and center frames
                    psnr_rlt_avg = {}
                    psnr_total_avg = 0.

                    ssim_rlt = {}  # with border and center frames
                    ssim_rlt_avg = {}
                    ssim_total_avg = 0.

                    val_loss_rlt = {}
                    val_loss_rlt_avg = {}
                    val_loss_total_avg = 0.

                    if rank == 0:
                        pbar = util.ProgressBar(len(val_set))

                    for idx in range(rank, len(val_set), world_size):

                        print('idx', idx)

                        if 'debug' in opt['name']:
                            if (idx >= 3):
                                break

                        val_data = val_set[idx]
                        val_data['LQs'].unsqueeze_(0)
                        val_data['GT'].unsqueeze_(0)
                        folder = val_data['folder']
                        idx_d, max_idx = val_data['idx'].split('/')
                        idx_d, max_idx = int(idx_d), int(max_idx)

                        if psnr_rlt.get(folder, None) is None:
                            psnr_rlt[folder] = torch.zeros(max_idx,
                                                           dtype=torch.float32,
                                                           device='cuda')

                        if ssim_rlt.get(folder, None) is None:
                            ssim_rlt[folder] = torch.zeros(max_idx,
                                                           dtype=torch.float32,
                                                           device='cuda')

                        if val_loss_rlt.get(folder, None) is None:
                            val_loss_rlt[folder] = torch.zeros(
                                max_idx, dtype=torch.float32, device='cuda')

                        # tmp = torch.zeros(max_idx, dtype=torch.float32, device='cuda')
                        model.feed_data(val_data)
                        # model.test()
                        # model.test_stitch()

                        if opt['stitch'] == True:
                            model.test_stitch()
                        else:
                            model.test()  # large GPU memory

                        # visuals = model.get_current_visuals()
                        visuals = model.get_current_visuals(
                            save=True,
                            name='{}_{}'.format(folder, idx),
                            save_path=opt['path']['val_images'])

                        rlt_img = util.tensor2img(visuals['rlt'])  # uint8
                        gt_img = util.tensor2img(visuals['GT'])  # uint8

                        # calculate PSNR
                        psnr = util.calculate_psnr(rlt_img, gt_img)
                        psnr_rlt[folder][idx_d] = psnr

                        # calculate SSIM
                        ssim = util.calculate_ssim(rlt_img, gt_img)
                        ssim_rlt[folder][idx_d] = ssim

                        # calculate Val loss
                        val_loss = model.get_loss()
                        val_loss_rlt[folder][idx_d] = val_loss

                        logger.info(
                            '{}_{:02d} PSNR: {:.4f}, SSIM: {:.4f}'.format(
                                folder, idx, psnr, ssim))

                        if rank == 0:
                            for _ in range(world_size):
                                pbar.update('Test {} - {}/{}'.format(
                                    folder, idx_d, max_idx))

                    # # collect data
                    for _, v in psnr_rlt.items():
                        dist.reduce(v, 0)

                    for _, v in ssim_rlt.items():
                        dist.reduce(v, 0)

                    for _, v in val_loss_rlt.items():
                        dist.reduce(v, 0)

                    dist.barrier()

                    if rank == 0:
                        psnr_rlt_avg = {}
                        psnr_total_avg = 0.
                        for k, v in psnr_rlt.items():
                            psnr_rlt_avg[k] = torch.mean(v).cpu().item()
                            psnr_total_avg += psnr_rlt_avg[k]
                        psnr_total_avg /= len(psnr_rlt)
                        log_s = '# Validation # PSNR: {:.4e}:'.format(
                            psnr_total_avg)
                        for k, v in psnr_rlt_avg.items():
                            log_s += ' {}: {:.4e}'.format(k, v)
                        logger.info(log_s)

                        # ssim
                        ssim_rlt_avg = {}
                        ssim_total_avg = 0.
                        for k, v in ssim_rlt.items():
                            ssim_rlt_avg[k] = torch.mean(v).cpu().item()
                            ssim_total_avg += ssim_rlt_avg[k]
                        ssim_total_avg /= len(ssim_rlt)
                        log_s = '# Validation # SSIM: {:.4e}:'.format(
                            ssim_total_avg)
                        for k, v in ssim_rlt_avg.items():
                            log_s += ' {}: {:.4e}'.format(k, v)
                        logger.info(log_s)

                        # added
                        val_loss_rlt_avg = {}
                        val_loss_total_avg = 0.
                        for k, v in val_loss_rlt.items():
                            val_loss_rlt_avg[k] = torch.mean(v).cpu().item()
                            val_loss_total_avg += val_loss_rlt_avg[k]
                        val_loss_total_avg /= len(val_loss_rlt)
                        log_l = '# Validation # Loss: {:.4e}:'.format(
                            val_loss_total_avg)
                        for k, v in val_loss_rlt_avg.items():
                            log_l += ' {}: {:.4e}'.format(k, v)
                        logger.info(log_l)

                        message = ' '.join(
                            '{:.5e}'.format(v)
                            for v in model.get_current_learning_rate())

                        logger_val.info(
                            'Epoch {:02d}, LR {:s}, PSNR {:.4f}, SSIM {:.4f} Train {:s}, Val Total Loss {:.4e}'
                            .format(epoch, message, psnr_total_avg,
                                    ssim_total_avg, message_train_loss,
                                    val_loss_total_avg))

                        if opt['use_tb_logger'] and 'debug' not in opt['name']:
                            tb_logger.add_scalar('psnr_avg', psnr_total_avg,
                                                 current_step)
                            for k, v in psnr_rlt_avg.items():
                                tb_logger.add_scalar(k, v, current_step)
                            # add val loss
                            tb_logger.add_scalar('val_loss_avg',
                                                 val_loss_total_avg,
                                                 current_step)
                            for k, v in val_loss_rlt_avg.items():
                                tb_logger.add_scalar(k, v, current_step)

                else:  # single-GPU validation (ours)
                    pbar = util.ProgressBar(len(val_loader))
                    psnr_rlt = {}  # with border and center frames
                    psnr_rlt_avg = {}
                    psnr_total_avg = 0.

                    ssim_rlt = {}  # with border and center frames
                    ssim_rlt_avg = {}
                    ssim_total_avg = 0.

                    val_loss_rlt = {}
                    val_loss_rlt_avg = {}
                    val_loss_total_avg = 0.

                    for val_inx, val_data in enumerate(val_loader):

                        if 'debug' in opt['name'] and val_inx >= 5:
                            break

                        folder = val_data['folder'][0]
                        # idx_d = val_data['idx'].item()
                        idx_d = val_data['idx']
                        # border = val_data['border'].item()
                        if psnr_rlt.get(folder, None) is None:
                            psnr_rlt[folder] = []

                        if ssim_rlt.get(folder, None) is None:
                            ssim_rlt[folder] = []

                        if val_loss_rlt.get(folder, None) is None:
                            val_loss_rlt[folder] = []

                        # handle black (blank) border rows in the LQ input [B, N, C, H, W]

                        print(val_data['LQs'].size())

                        H_S = val_data['LQs'].size(3)  # 540
                        W_S = val_data['LQs'].size(4)  # 960

                        print(H_S)
                        print(W_S)

                        blank_1_S = 0
                        blank_2_S = 0

                        print(val_data['LQs'][0, 2, 0, :, :].size())

                        # scan from the top for the first non-black row
                        for i in range(H_S):
                            if not sum(val_data['LQs'][0, 2, 0, i, :]) == 0:
                                blank_1_S = i - 1
                                break

                        # scan from the bottom for the last non-black row
                        for i in range(H_S):
                            if not sum(val_data['LQs'][0, 2, 0,
                                                       H_S - i - 1, :]) == 0:
                                blank_2_S = (H_S - 1) - i - 1
                                break
                        print('LQ :', blank_1_S, blank_2_S)

                        if blank_1_S == -1:
                            print('LQ has no blank')
                            blank_1_S = 0
                            blank_2_S = H_S

                        # val_data['LQs'] = val_data['LQs'][:,:,:,blank_1_S:blank_2_S,:]

                        print("LQ", val_data['LQs'].size())

                        # end of process the black blank

                        model.feed_data(val_data)

                        if opt['stitch']:
                            model.test_stitch()
                        else:
                            model.test()  # uses more GPU memory

                        # zero the blank border rows in the HR output
                        # (<< 2 maps LQ row indices to the x4 HR scale)
                        blank_1_L = blank_1_S << 2
                        blank_2_L = blank_2_S << 2
                        print(blank_1_L, blank_2_L)

                        print(model.fake_H.size())

                        if not blank_1_S == 0:
                            # model.fake_H = model.fake_H[:,:,blank_1_L:blank_2_L,:]
                            model.fake_H[:, :, 0:blank_1_L, :] = 0
                            model.fake_H[:, :, blank_2_L:, :] = 0

                        # end of blank processing

                        visuals = model.get_current_visuals(
                            save=True,
                            name='{}_{:02d}'.format(folder, val_inx),
                            save_path=opt['path']['val_images'])

                        rlt_img = util.tensor2img(visuals['rlt'])  # uint8
                        gt_img = util.tensor2img(visuals['GT'])  # uint8

                        # calculate PSNR
                        psnr = util.calculate_psnr(rlt_img, gt_img)
                        psnr_rlt[folder].append(psnr)

                        # calculate SSIM
                        ssim = util.calculate_ssim(rlt_img, gt_img)
                        ssim_rlt[folder].append(ssim)

                        # val loss
                        val_loss = model.get_loss()
                        val_loss_rlt[folder].append(val_loss.item())

                        logger.info(
                            '{}_{:02d} PSNR: {:.4f}, SSIM: {:.4f}'.format(
                                folder, val_inx, psnr, ssim))

                        pbar.update('Test {} - {}'.format(folder, idx_d))

                    # average PSNR
                    for k, v in psnr_rlt.items():
                        psnr_rlt_avg[k] = sum(v) / len(v)
                        psnr_total_avg += psnr_rlt_avg[k]
                    psnr_total_avg /= len(psnr_rlt)
                    log_s = '# Validation # PSNR: {:.4e}:'.format(
                        psnr_total_avg)
                    for k, v in psnr_rlt_avg.items():
                        log_s += ' {}: {:.4e}'.format(k, v)
                    logger.info(log_s)

                    # average SSIM
                    for k, v in ssim_rlt.items():
                        ssim_rlt_avg[k] = sum(v) / len(v)
                        ssim_total_avg += ssim_rlt_avg[k]
                    ssim_total_avg /= len(ssim_rlt)
                    log_s = '# Validation # SSIM: {:.4e}:'.format(
                        ssim_total_avg)
                    for k, v in ssim_rlt_avg.items():
                        log_s += ' {}: {:.4e}'.format(k, v)
                    logger.info(log_s)

                    # average VMAF

                    # average Val LOSS
                    for k, v in val_loss_rlt.items():
                        val_loss_rlt_avg[k] = sum(v) / len(v)
                        val_loss_total_avg += val_loss_rlt_avg[k]
                    val_loss_total_avg /= len(val_loss_rlt)
                    log_l = '# Validation # Loss: {:.4e}:'.format(
                        val_loss_total_avg)
                    for k, v in val_loss_rlt_avg.items():
                        log_l += ' {}: {:.4e}'.format(k, v)
                    logger.info(log_l)

                    # total validation log

                    message = ' '.join(
                        '{:.5e}'.format(v)
                        for v in model.get_current_learning_rate())

                    logger_val.info(
                        'Epoch {:02d}, LR {:s}, PSNR {:.4f}, SSIM {:.4f} Train {:s}, Val Total Loss {:.4e}'
                        .format(epoch, message, psnr_total_avg, ssim_total_avg,
                                message_train_loss, val_loss_total_avg))

                    # end add

                    if opt['use_tb_logger'] and 'debug' not in opt['name']:
                        tb_logger.add_scalar('psnr_avg', psnr_total_avg,
                                             current_step)
                        for k, v in psnr_rlt_avg.items():
                            tb_logger.add_scalar(k, v, current_step)
                        # tb_logger.add_scalar('ssim_avg', ssim_total_avg, current_step)
                        # for k, v in ssim_rlt_avg.items():
                        #     tb_logger.add_scalar(k, v, current_step)
                        # add val loss
                        tb_logger.add_scalar('val_loss_avg',
                                             val_loss_total_avg, current_step)
                        for k, v in val_loss_rlt_avg.items():
                            tb_logger.add_scalar(k, v, current_step)

            ############################################
            #
            #          end of validation, save model
            #
            ############################################
            #
            logger.info("Finished an epoch, Check and Save the model weights")
            # we check the validation loss instead of training loss. OK~
            if saved_total_loss >= val_loss_total_avg:
                saved_total_loss = val_loss_total_avg
                #torch.save(model.state_dict(), args.save_path + "/best" + ".pth")
                model.save('best')
                logger.info(
                    "Best weights updated: validation loss decreased")

            else:
                logger.info(
                    "Weights not updated: validation loss did not decrease")

            if saved_total_PSNR <= psnr_total_avg:
                saved_total_PSNR = psnr_total_avg
                model.save('bestPSNR')
                logger.info(
                    "Best PSNR weights updated: validation PSNR increased")

            else:
                logger.info(
                    "Weights not updated: validation PSNR did not increase")

        ############################################
        #
        #          end of one epoch, schedule LR
        #
        ############################################

        # step the LR scheduler on the validation loss
        # (ReduceLROnPlateau lowers the LR when the loss stops decreasing)
        if opt['train']['lr_scheme'] == 'ReduceLROnPlateau':
            for scheduler in model.schedulers:
                scheduler.step(val_loss_total_avg)

    if rank <= 0:
        logger.info('Saving the final model.')
        model.save('last')
        logger.info('End of training.')
        tb_logger.close()
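
# --- Illustrative sketch (not from the original code): the single-GPU validation
# branch above detects black border rows in the centre LQ frame and zeroes the
# matching rows of the x4 HR output. Assuming an LQ tensor of shape [B, N, C, H, W]
# and a model output of shape [B, C, 4H, 4W], the same idea can be written as:
import torch


def zero_blank_rows(lq, fake_h, scale=4):
    """Zero the HR rows that correspond to all-black rows of the centre LQ frame."""
    center = lq.size(1) // 2
    row_energy = lq[0, center, 0].abs().sum(dim=1)   # [H]; 0 where the row is black
    nonblank = (row_energy > 0).nonzero()
    if nonblank.numel() == 0:                        # the whole frame is black
        return fake_h
    first, last = int(nonblank[0]), int(nonblank[-1])
    fake_h[:, :, :first * scale, :] = 0              # top black border, scaled to HR
    fake_h[:, :, (last + 1) * scale:, :] = 0         # bottom black border, scaled to HR
    return fake_h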
Esempio n. 29
0
def AI_4K(mode):
    '''Create lmdb for the AI 4K dataset, each image with a fixed size
    GT: [3, 2160, 3840], key: 000000_000000
    LR: [3, 540, 960],   key: 000000_000000
    ** Remember our data structure from earlier? {subfolder name}_{image name}
    '''
    # ** data mode: input / gt / test
    read_all_imgs = False  # whether to read all images into memory; set False with limited memory
    BATCH = 500  # commit to lmdb after every BATCH images if read_all_imgs is False

    if mode == 'input':
        img_folder = '/mnt/sdb/duan/EDVR/datasets/AI_4K/val/input'  # ** use a relative path pointing to our dataset's input
        lmdb_save_path = '/mnt/sdb/duan/EDVR/datasets/AI_4K/train_input_wval.lmdb'  # ** path where the generated lmdb will be stored
        '''the original used absolute paths; we use relative paths'''
        H_dst, W_dst = 540, 960  # frame size: H, W

    elif mode == 'gt':
        img_folder = '/mnt/sdb/duan/EDVR/datasets/AI_4K/train/gt'  # ** use a relative path pointing to our dataset's gt
        lmdb_save_path = '/mnt/sdb/duan/EDVR/datasets/AI_4K/train_gt_wval.lmdb'  # ** path where the generated lmdb will be stored
        '''the original used absolute paths; we use relative paths'''
        H_dst, W_dst = 2160, 3840  # frame size: H, W
    elif mode == 'test':
        img_folder = '/mnt/sdb/duan/EDVR/datasets/AI_4K/test/gt'  # ** use a relative path pointing to our dataset's test gt
        lmdb_save_path = '/mnt/sdb/duan/EDVR/datasets/AI_4K/test_gt_wval.lmdb'  # ** path where the generated lmdb will be stored
        '''the original used absolute paths; we use relative paths'''
        H_dst, W_dst = 2160, 3840  # frame size: H, W

    ########################################################
    if not lmdb_save_path.endswith('.lmdb'):
        raise ValueError(
            "lmdb_save_path must end with \'lmdb\'.")  # 保存格式必须以“.lmdb”结尾
    #### whether the lmdb file exist
    if osp.exists(lmdb_save_path):
        print('Folder [{:s}] already exists. Exit...'.format(
            lmdb_save_path))  # 文件是否已经存在
        sys.exit(1)

    #### read all the image paths to a list
    print('Reading image path list ...')
    all_img_list = data_util._get_paths_from_images(
        img_folder)  # get the full paths of all frames under input/gt as a list
    keys = []
    for img_path in all_img_list:
        split_rlt = img_path.split('/')
        # take the subfolder name xxxxxx
        a = split_rlt[-2]
        # take the frame name without the file extension xxxxxx
        b = split_rlt[-1].split('.png')[0]  # ** our images end with '.jpg'
        keys.append(a + '_' + b)

    #### create lmdb environment
    data_size_per_img = cv2.imread(
        all_img_list[0], cv2.IMREAD_UNCHANGED).nbytes  # size of each frame image (in bytes)

    print('data size per image is: ', data_size_per_img)
    data_size = data_size_per_img * len(all_img_list)  # total space needed
    env = lmdb.open(lmdb_save_path, map_size=data_size * 10)  # request this many bytes (x10 headroom)

    #### write data to lmdb
    pbar = util.ProgressBar(len(all_img_list))
    txn = env.begin(write=True)
    idx = 1
    for path, key in zip(all_img_list, keys):
        idx = idx + 1
        pbar.update('Write {}'.format(key))
        key_byte = key.encode('ascii')

        data = cv2.imread(path, cv2.IMREAD_UNCHANGED)
        H, W, C = data.shape  # fixed shape
        assert H == H_dst and W == W_dst and C == 3, 'different shape.'
        txn.put(key_byte, data)
        if not read_all_imgs and idx % BATCH == 1:
            txn.commit()
            txn = env.begin(write=True)
    txn.commit()
    env.close()
    print('Finish writing lmdb.')

    #### create meta information: name (str) + resolution (str) + keys
    meta_info = {}
    meta_info['name'] = 'OURS_{}_wval'.format(mode)  # ** the dataset is now OURS

    meta_info['resolution'] = '{}_{}_{}'.format(3, H_dst, W_dst)
    meta_info['keys'] = keys
    pickle.dump(meta_info, open(osp.join(lmdb_save_path, 'meta_info.pkl'),
                                "wb"))
    print('Finish creating lmdb meta info.')
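
# --- Illustrative sketch (assumption, not part of the original snippet): an entry
# written by AI_4K() can be read back using the keys and the fixed 'C_H_W'
# resolution string stored in meta_info.pkl. The helper name is hypothetical.
import os.path as osp
import pickle

import lmdb
import numpy as np


def read_one_frame(lmdb_path, key=None):
    meta = pickle.load(open(osp.join(lmdb_path, 'meta_info.pkl'), 'rb'))
    key = key if key is not None else meta['keys'][0]     # default: first stored key
    C, H, W = [int(s) for s in meta['resolution'].split('_')]
    env = lmdb.open(lmdb_path, readonly=True, lock=False, readahead=False)
    with env.begin(write=False) as txn:
        buf = txn.get(key.encode('ascii'))
    env.close()
    # the raw bytes are the cv2 BGR uint8 array written by txn.put() above
    return np.frombuffer(buf, dtype=np.uint8).reshape(H, W, C)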
Esempio n. 30
0
def main():
    #### options
    parser = argparse.ArgumentParser()
    parser.add_argument('--opt', type=str, help='Path to option YAML file.')
    args = parser.parse_args()
    opt = option.parse(args.opt, is_train=True)

    #### loading resume state if exists
    if 'resume_latest' in opt and opt['resume_latest']:
        if os.path.isdir(opt['path']['training_state']):
            name_state_files = os.listdir(opt['path']['training_state'])
            if len(name_state_files) > 0:
                latest_state_num = 0
                for name_state_file in name_state_files:
                    state_num = int(name_state_file.split('.')[0])
                    if state_num > latest_state_num:
                        latest_state_num = state_num
                opt['path']['resume_state'] = os.path.join(
                    opt['path']['training_state'], str(latest_state_num)+'.state')
            else:
                raise ValueError('No training state file found in {:s}'.format(
                    opt['path']['training_state']))
    if opt['path'].get('resume_state', None):
        device_id = torch.cuda.current_device()
        resume_state = torch.load(opt['path']['resume_state'],
                                  map_location=lambda storage, loc: storage.cuda(device_id))
        option.check_resume(opt, resume_state['iter'])  # check resume options
    else:
        resume_state = None

    #### mkdir and loggers
    if resume_state is None:
        util.mkdir_and_rename(
            opt['path']['experiments_root'])  # rename experiment folder if exists
        util.mkdirs((path for key, path in opt['path'].items() if not key == 'experiments_root'
                     and 'pretrain_model' not in key and 'resume' not in key))

    # config loggers. Before it, the log will not work
    util.setup_logger('base', opt['path']['log'], 'train_' + opt['name'], level=logging.INFO,
                      screen=True, tofile=True)
    logger = logging.getLogger('base')
    logger.info(option.dict2str(opt))
    # tensorboard logger
    if opt['use_tb_logger'] and 'debug' not in opt['name']:
        version = float(torch.__version__[0:3])
        if version >= 1.1:  # PyTorch 1.1
            from torch.utils.tensorboard import SummaryWriter
        else:
            logger.info(
                'You are using PyTorch {}. Tensorboard will use [tensorboardX]'.format(version))
            from tensorboardX import SummaryWriter
        tb_logger = SummaryWriter(log_dir='../tb_logger/' +
                                  opt['name'] + '_{}'.format(util.get_timestamp()))

    # convert to NoneDict, which returns None for missing keys
    opt = option.dict_to_nonedict(opt)

    #### random seed
    seed = opt['train']['manual_seed']
    if seed is None:
        seed = random.randint(1, 10000)
    logger.info('Random seed: {}'.format(seed))
    util.set_random_seed(seed)

    torch.backends.cudnn.benchmark = True
    # torch.backends.cudnn.deterministic = True

    #### create train and val dataloader
    dataset_ratio = 200  # enlarge the size of each epoch
    for phase, dataset_opt in opt['datasets'].items():
        if phase == 'train':
            train_set = create_dataset(dataset_opt)
            train_size = int(
                math.ceil(len(train_set) / dataset_opt['batch_size']))
            total_iters = int(opt['train']['niter'])
            total_epochs = int(math.ceil(total_iters / train_size))
            train_sampler = None
            train_loader = create_dataloader(
                train_set, dataset_opt, opt, train_sampler)
            logger.info('Number of train images: {:,d}, iters: {:,d}'.format(
                len(train_set), train_size))
            logger.info('Total epochs needed: {:d} for iters {:,d}'.format(
                total_epochs, total_iters))
        elif phase == 'val':
            val_set = create_dataset(dataset_opt)
            val_loader = create_dataloader(val_set, dataset_opt, opt, None)
            logger.info('Number of val images in [{:s}]: {:d}'.format(
                dataset_opt['name'], len(val_set)))
        else:
            raise NotImplementedError(
                'Phase [{:s}] is not recognized.'.format(phase))
    assert train_loader is not None

    #### create model
    model = create_model(opt)

    #### resume training
    if resume_state:
        logger.info('Resuming training from epoch: {}, iter: {}.'.format(
            resume_state['epoch'], resume_state['iter']))

        start_epoch = resume_state['epoch']
        current_step = resume_state['iter']
        model.resume_training(resume_state)  # handle optimizers and schedulers
    else:
        current_step = 0
        start_epoch = 0

    #### training
    is_time = False

    logger.info('Start training from epoch: {:d}, iter: {:d}'.format(
        start_epoch, current_step))

    if is_time:
        batch_time = AverageMeter('Time', ':6.3f')
        data_time = AverageMeter('Data', ':6.3f')

    for epoch in range(start_epoch, total_epochs + 1):
        if current_step > total_iters:
            break

        if is_time:
            torch.cuda.synchronize()
            end = time.time()
        for _, train_data in enumerate(train_loader):
            if 'adv_train' in opt:
                current_step += opt['adv_train']['m']
            else:
                current_step += 1
            if current_step > total_iters:
                break

            #### training
            model.feed_data(train_data)

            if is_time:
                torch.cuda.synchronize()
                data_time.update(time.time() - end)

            model.optimize_parameters(current_step)

            #### update learning rate
            model.update_learning_rate(
                current_step, warmup_iter=opt['train']['warmup_iter'])

            if is_time:
                torch.cuda.synchronize()
                batch_time.update(time.time() - end)

            #### log
            if current_step % opt['logger']['print_freq'] == 0:
                # FIXME remove debug
                debug = True
                if debug:
                    torch.cuda.empty_cache()
                logs = model.get_current_log()
                message = '[epoch:{:3d}, iter:{:8,d}, lr:('.format(
                    epoch, current_step)
                for v in model.get_current_learning_rate():
                    message += '{:.3e},'.format(v)
                message += ')] '
                for k, v in logs.items():
                    message += '{:s}: {:.4e} '.format(k, v)
                    # tensorboard logger
                    if opt['use_tb_logger'] and 'debug' not in opt['name']:
                        tb_logger.add_scalar(k, v, current_step)
                logger.info(message)
                if is_time:
                    logger.info(str(data_time))
                    logger.info(str(batch_time))

            #### validation
            if opt['datasets'].get('val', None) and current_step % opt['train']['val_freq'] == 0:
                if opt['model'] in ['sr', 'srgan']:  # image restoration validation
                    pbar = util.ProgressBar(len(val_loader))
                    avg_psnr = 0.
                    idx = 0
                    for val_data in val_loader:
                        idx += 1
                        img_name = os.path.splitext(
                            os.path.basename(val_data['LQ_path'][0]))[0]
                        img_dir = os.path.join(
                            opt['path']['val_images'], img_name)
                        util.mkdir(img_dir)

                        model.feed_data(val_data)
                        model.test()

                        visuals = model.get_current_visuals()
                        sr_img = util.tensor2img(visuals['rlt'])  # uint8
                        gt_img = util.tensor2img(visuals['GT'])  # uint8

                        # Save SR images for reference
                        save_img_path = os.path.join(img_dir,
                                                     '{:s}_{:d}.png'.format(img_name, current_step))
                        util.save_img(sr_img, save_img_path)

                        # calculate PSNR
                        sr_img, gt_img = util.crop_border(
                            [sr_img, gt_img], opt['scale'])
                        avg_psnr += util.calculate_psnr(sr_img, gt_img)
                        pbar.update('Test {}'.format(img_name))

                    avg_psnr = avg_psnr / idx

                    # log
                    logger.info('# Validation # PSNR: {:.4e}'.format(avg_psnr))
                    # tensorboard logger
                    if opt['use_tb_logger'] and 'debug' not in opt['name']:
                        tb_logger.add_scalar('psnr', avg_psnr, current_step)

            #### save models and training states
            if current_step % opt['logger']['save_checkpoint_freq'] == 0:
                logger.info('Saving models and training states.')
                model.save(current_step)
                model.save_training_state(epoch, current_step)
                tb_logger.flush()
            if is_time:
                torch.cuda.synchronize()
                end = time.time()

    logger.info('Saving the final model.')
    model.save('latest')
    logger.info('End of training.')
    tb_logger.close()
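
# --- Illustrative usage (assumption; the script and option file names are hypothetical):
#   python train.py --opt options/train/train_EDVR.yml
# With `resume_latest: true` in the YAML, main() picks the newest '*.state' file
# from opt['path']['training_state'] and resumes training from it.
if __name__ == '__main__':
    main()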