Exemplo n.º 1
0
def validate(val_loader, model, logger, epoch, current_step, val_dataset_opt):
    """Run one validation pass, save the SR outputs and log the mean PSNR.

    The model is put into eval mode for the pass and is always put back into
    train mode afterwards, even if validation raises.

    Args:
        val_loader: iterable of validation batches; each batch is a dict with
            an ``'LR_path'`` entry (only its first element is used).
        model: model wrapper providing ``eval``, ``train``, ``feed_data``,
            ``val`` and ``get_current_visuals``.
        logger: object providing ``print_format_results(phase, results)``.
        epoch: current epoch, only logged.
        current_step: current iteration, logged and used in file names.
        val_dataset_opt: validation dataset options; ``scale + 2`` pixels are
            cropped from every border before PSNR is computed.

    Note:
        This function reads a module-level ``opt`` dict
        (``opt['path']['val_images']`` and ``opt['model']``) that is not
        passed in as a parameter.
    """
    print('---------- validation -------------')
    val_start_time = time.time()
    model.eval()  # Change to eval mode. It is important for BN layers.
    try:
        # Border (in pixels) excluded from the PSNR computation.
        crop_size = val_dataset_opt['scale'] + 2
        avg_psnr = 0.0
        idx = 0
        for val_data in val_loader:
            idx += 1
            img_path = val_data['LR_path'][0]
            img_name = os.path.splitext(os.path.basename(img_path))[0]
            img_dir = os.path.join(opt['path']['val_images'], img_name)
            util.mkdir(img_dir)

            model.feed_data(val_data, volatile=True)
            model.val()

            visuals = model.get_current_visuals()
            sr_img = util.tensor2img_np(visuals['SR'])  # uint8
            gt_img = util.tensor2img_np(visuals['HR'])  # uint8

            # Crop both images to their common height/width so they can be
            # compared pixel-wise.
            h_min = min(sr_img.shape[0], gt_img.shape[0])
            w_min = min(sr_img.shape[1], gt_img.shape[1])
            sr_img = sr_img[0:h_min, 0:w_min, :]
            gt_img = gt_img[0:h_min, 0:w_min, :]

            cropped_sr_img = sr_img[crop_size:-crop_size,
                                    crop_size:-crop_size, :]
            cropped_gt_img = gt_img[crop_size:-crop_size,
                                    crop_size:-crop_size, :]

            # Save SR images for reference
            save_img_path = os.path.join(img_dir,
                                         '%s_%s.png' % (img_name, current_step))
            util.save_img_np(sr_img.squeeze(), save_img_path)

            # TODO: support Y-channel-only PSNR driven by
            # val_dataset_opt['metric_mode'].
            avg_psnr += util.psnr(cropped_sr_img, cropped_gt_img)

        # An empty loader reports NaN instead of raising ZeroDivisionError.
        avg_psnr = avg_psnr / idx if idx else float('nan')

        val_elapsed = time.time() - val_start_time
        # Save to log
        print_rlt = OrderedDict()
        print_rlt['model'] = opt['model']
        print_rlt['epoch'] = epoch
        print_rlt['iters'] = current_step
        print_rlt['time'] = val_elapsed
        print_rlt['psnr'] = avg_psnr
        logger.print_format_results('val', print_rlt)
    finally:
        model.train()  # change back to train mode.
    print('-----------------------------------')
Exemplo n.º 2
0
def validate(val_loader, opt, model, current_step, epoch, logger):
    """Run one validation pass, save generated images and log PSNR / LPIPS.

    Args:
        val_loader: iterable of validation batches; each batch is a dict with
            ``'LR'`` (indexed at ``shape[0]``, ``shape[2]``, ``shape[3]``) and
            ``'LR_path'`` entries.
        opt: option dict; reads ``opt['path']['val_images']``,
            ``opt['train']['zero_code']``, ``opt['name']`` and ``opt['model']``.
        model: model wrapper providing ``gen_code``, ``feed_data``, ``test``,
            ``get_current_visuals``, ``get_loss`` and ``save``.
        current_step: current iteration, logged and used in file names; when
            it is 0 the model is also saved.
        epoch: current epoch, only logged.
        logger: object providing ``print_format_results(phase, results)``.
    """
    print('---------- validation -------------')
    start_time = time.time()

    # Latent codes are all-zero when 'zero_code' is set, Gaussian otherwise.
    tensor_type = torch.zeros if opt['train']['zero_code'] else torch.randn

    avg_psnr = 0.0
    avg_lpips = 0.0
    idx = 0
    for val_data in val_loader:
        idx += 1
        img_name = os.path.splitext(os.path.basename(
            val_data['LR_path'][0]))[0]
        img_dir = os.path.join(opt['path']['val_images'], img_name)
        util.mkdir(img_dir)

        lr_shape = val_data['LR'].shape
        code = model.gen_code(lr_shape[0],
                              lr_shape[2],
                              lr_shape[3],
                              tensor_type=tensor_type)
        model.feed_data(val_data, code=code)
        model.test()

        visuals = model.get_current_visuals()
        sr_img = util.tensor2img(visuals['HR_pred'])  # uint8
        gt_img = util.tensor2img(visuals['HR'])  # uint8

        # Save generated images for reference
        save_img_path = os.path.join(
            img_dir, '{:s}_{:s}_{:d}.png'.format(opt['name'], img_name,
                                                 current_step))
        util.save_img(sr_img, save_img_path)

        # calculate PSNR on the full (uncropped) images
        avg_psnr += util.psnr(sr_img, gt_img)

        # The summed get_loss(level=-1) is reported as 'lpips' -- presumably
        # a per-sample LPIPS distance; confirm against the model class.
        avg_lpips += torch.sum(model.get_loss(level=-1))

    if current_step == 0:
        print('Saving the model at the end of iter {:d}.'.format(current_step))
        model.save(current_step)

    # An empty loader reports NaN instead of raising ZeroDivisionError.
    if idx:
        avg_psnr = avg_psnr / idx
        avg_lpips = avg_lpips / idx
    else:
        avg_psnr = float('nan')
        avg_lpips = float('nan')
    time_elapsed = time.time() - start_time
    # Save to log
    print_rlt = OrderedDict()
    print_rlt['model'] = opt['model']
    print_rlt['epoch'] = epoch
    print_rlt['iters'] = current_step
    print_rlt['time'] = time_elapsed
    print_rlt['psnr'] = avg_psnr
    print_rlt['lpips'] = avg_lpips
    logger.print_format_results('val', print_rlt)
    print('-----------------------------------')
Exemplo n.º 3
0
def validate(val_loader, opt, model, current_step, epoch, logger):
    """Run one validation pass with a Gaussian latent code and log PSNR/LPIPS.

    Args:
        val_loader: iterable of validation batches; each batch is a dict with
            ``'LR'`` (indexed at ``shape[0]``, ``shape[2]``, ``shape[3]``) and
            ``'LR_path'`` entries.
        opt: option dict; reads ``opt['path']['val_images']``,
            ``opt['network_G']['in_code_nc']``, ``opt['name']`` (its 3rd and
            4th ``'_'``-separated fields are used in file names) and
            ``opt['model']``.
        model: model wrapper providing ``feed_data``, ``test``,
            ``get_current_visuals``, ``get_loss`` and ``save``.
        current_step: current iteration, logged and used in file names; when
            it is 0 the model is also saved.
        epoch: current epoch, only logged.
        logger: object providing ``print_format_results(phase, results)``.
    """
    print('---------- validation -------------')
    start_time = time.time()

    # Constant for the whole pass: code channels and the name fields used in
    # the saved file names.
    in_code_nc = int(opt['network_G']['in_code_nc'])
    name_parts = opt['name'].split("_")
    arch_name = name_parts[2]
    run_index = name_parts[3]

    avg_psnr = 0.0
    avg_lpips = 0.0
    idx = 0
    for val_data in val_loader:
        idx += 1
        img_name = os.path.splitext(os.path.basename(
            val_data['LR_path'][0]))[0]
        img_dir = os.path.join(opt['path']['val_images'], img_name)
        util.mkdir(img_dir)

        lr_shape = val_data['LR'].shape
        code_val_0 = torch.randn(lr_shape[0], in_code_nc,
                                 lr_shape[2], lr_shape[3])
        model.feed_data(val_data, code=[code_val_0])
        model.test()

        visuals = model.get_current_visuals()
        sr_img = util.tensor2img(visuals['SR'])  # uint8
        gt_img = util.tensor2img(visuals['HR'])  # uint8

        # Save SR images for reference
        save_img_path = os.path.join(
            img_dir, 'HyperRIM_{:s}_{:s}_{:s}_x2_{:d}.png'.format(
                arch_name, run_index, img_name, current_step))
        util.save_img(sr_img, save_img_path)

        # calculate PSNR on the full (uncropped) images
        avg_psnr += util.psnr(sr_img, gt_img)

        # The summed get_loss(level=-1) is reported as 'lpips' -- presumably
        # a per-sample LPIPS distance; confirm against the model class.
        avg_lpips += torch.sum(model.get_loss(level=-1))

    if current_step == 0:
        print('Saving the model at the end of iter {:d}.'.format(current_step))
        model.save(current_step)

    # An empty loader reports NaN instead of raising ZeroDivisionError.
    if idx:
        avg_psnr = avg_psnr / idx
        avg_lpips = avg_lpips / idx
    else:
        avg_psnr = float('nan')
        avg_lpips = float('nan')
    time_elapsed = time.time() - start_time
    # Save to log
    print_rlt = OrderedDict()
    print_rlt['model'] = opt['model']
    print_rlt['epoch'] = epoch
    print_rlt['iters'] = current_step
    print_rlt['time'] = time_elapsed
    print_rlt['psnr'] = avg_psnr
    print_rlt['lpips'] = avg_lpips
    logger.print_format_results('val', print_rlt)
    print('-----------------------------------')
Exemplo n.º 4
0
def get_sim(root_path, ori_path, n_samples):
    """Return the mean PSNR and SSIM between reference images and samples.

    For every reference image matched by the glob pattern ``ori_path`` (in
    sorted order), the ``n_samples`` generated images
    ``root_path + <stem> + "_<j>.png"`` for ``j`` in ``0 .. n_samples-1`` are
    loaded and compared against it.

    Args:
        root_path: string prefix of the generated sample files (typically a
            directory path ending in a separator); it is concatenated as-is.
        ori_path: glob pattern matching the reference images.
        n_samples: number of generated samples per reference image.

    Returns:
        ``(mean_psnr, mean_ssim)`` over all compared image pairs.

    Raises:
        ValueError: if no pair was compared (``ori_path`` matched nothing or
            ``n_samples`` < 1).
    """
    import os  # used for portable basename / extension splitting

    total_psnr = 0.
    total_ssim = 0.
    count = 0
    for ref_path in sorted(glob.glob(ori_path)):
        # Strip the directory and only the final extension, so stems that
        # contain dots (e.g. "img.001.png" -> "img.001") are kept intact.
        img_name = os.path.splitext(os.path.basename(ref_path))[0]
        img0_np = load_image(ref_path)
        for j in range(n_samples):
            img1_np = load_image(root_path + img_name + "_" + str(j) + ".png")
            total_psnr += util.psnr(img0_np, img1_np)
            total_ssim += util.ssim(img0_np, img1_np, multichannel=True)
            count += 1

    if count == 0:
        raise ValueError(
            'No image pairs compared: pattern {!r} matched no files or '
            'n_samples < 1.'.format(ori_path))
    return total_psnr / count, total_ssim / count
Exemplo n.º 5
0
def main():
    """Train a model described by a JSON option file.

    Command line: ``-opt <path>`` (required), the path to the option JSON file.

    Steps:
        1. Parse the options and prepare the experiment directories.
        2. Route stdout through ``PrintLogger``.
        3. Seed ``random`` and ``torch``.
        4. Build the train/val loaders.
        5. Train for ``opt['train']['niter']`` iterations with periodic
           logging, checkpointing, validation (PSNR on the Y channel) and
           learning-rate updates.

    The final model is saved as ``'latest'``.

    Raises:
        ValueError: if the options define no ``'train'`` dataset.
        NotImplementedError: for a dataset phase other than 'train'/'val'.
    """
    # options
    parser = argparse.ArgumentParser()
    parser.add_argument('-opt',
                        type=str,
                        required=True,
                        help='Path to option JSON file.')
    opt = option.parse(parser.parse_args().opt, is_train=True)

    util.mkdir_and_rename(
        opt['path']['experiments_root'])  # rename old experiments if exists
    util.mkdirs((path for key, path in opt['path'].items()
                 if key not in ('experiments_root', 'pretrain_model_G',
                                'pretrain_model_D')))
    option.save(opt)
    # Convert to NoneDict, which returns None for missing keys.
    opt = option.dict_to_nonedict(opt)

    # print to file and std_out simultaneously
    sys.stdout = PrintLogger(opt['path']['log'])

    # random seed
    seed = opt['train']['manual_seed']
    if seed is None:
        seed = random.randint(1, 10000)
    print("Random Seed: ", seed)
    random.seed(seed)
    torch.manual_seed(seed)

    # create train and val dataloader
    # Pre-bind so a missing phase gives a clear error instead of NameError.
    train_loader = None
    val_loader = None
    for phase, dataset_opt in opt['datasets'].items():
        if phase == 'train':
            train_set = create_dataset(dataset_opt)
            train_size = int(
                math.ceil(len(train_set) / dataset_opt['batch_size']))
            print('Number of train images: {:,d}, iters: {:,d}'.format(
                len(train_set), train_size))
            total_iters = int(opt['train']['niter'])
            total_epoches = int(math.ceil(total_iters / train_size))
            print('Total epoches needed: {:d} for iters {:,d}'.format(
                total_epoches, total_iters))
            train_loader = create_dataloader(train_set, dataset_opt)
        elif phase == 'val':
            val_set = create_dataset(dataset_opt)
            val_loader = create_dataloader(val_set, dataset_opt)
            print('Number of val images in [{:s}]: {:d}'.format(
                dataset_opt['name'], len(val_set)))
        else:
            raise NotImplementedError(
                'Phase [{:s}] is not recognized.'.format(phase))
    if train_loader is None:
        raise ValueError('No "train" dataset is defined in the options.')

    # Create model
    model = create_model(opt)
    # create logger
    logger = Logger(opt)

    current_step = 0
    start_time = time.time()
    print('---------- Start training -------------')
    for epoch in range(total_epoches):
        for train_data in train_loader:
            current_step += 1
            if current_step > total_iters:
                break

            # training
            model.feed_data(train_data)
            model.optimize_parameters(current_step)

            time_elapsed = time.time() - start_time
            start_time = time.time()

            # log
            if current_step % opt['logger']['print_freq'] == 0:
                logs = model.get_current_log()
                print_rlt = OrderedDict()
                print_rlt['model'] = opt['model']
                print_rlt['epoch'] = epoch
                print_rlt['iters'] = current_step
                print_rlt['time'] = time_elapsed
                for k, v in logs.items():
                    print_rlt[k] = v
                print_rlt['lr'] = model.get_current_learning_rate()
                logger.print_format_results('train', print_rlt)

            # save models
            if current_step % opt['logger']['save_checkpoint_freq'] == 0:
                print('Saving the model at the end of iter {:d}.'.format(
                    current_step))
                model.save(current_step)

            # validation (skipped when no 'val' dataset is configured)
            if (val_loader is not None
                    and current_step % opt['train']['val_freq'] == 0):
                print('---------- validation -------------')
                val_start_time = time.time()

                # Border (in pixels) excluded from PSNR: 'crop_scale' when
                # given, otherwise the scale factor; <= 0 means no crop.
                if opt['crop_scale'] is not None:
                    crop_size = opt['crop_scale']
                else:
                    crop_size = opt['scale']

                avg_psnr = 0.0
                idx = 0
                for val_data in val_loader:
                    idx += 1
                    img_name = os.path.splitext(
                        os.path.basename(val_data['LR_path'][0]))[0]
                    img_dir = os.path.join(opt['path']['val_images'], img_name)
                    util.mkdir(img_dir)

                    model.feed_data(val_data)
                    model.test()

                    visuals = model.get_current_visuals()
                    sr_img = util.tensor2img(visuals['SR'])  # uint8
                    gt_img = util.tensor2img(visuals['HR'])  # uint8

                    # Save SR images for reference
                    save_img_path = os.path.join(
                        img_dir, '{:s}_{:d}.png'.format(img_name, current_step))
                    util.save_img(sr_img, save_img_path)

                    # calculate PSNR on the Y channel only; the trailing
                    # Ellipsis keeps a channel axis if the image has one.
                    if crop_size <= 0:
                        cropped_sr_img = sr_img
                        cropped_gt_img = gt_img
                    else:
                        cropped_sr_img = sr_img[crop_size:-crop_size,
                                                crop_size:-crop_size, ...]
                        cropped_gt_img = gt_img[crop_size:-crop_size,
                                                crop_size:-crop_size, ...]
                    cropped_sr_img_y = bgr2ycbcr(cropped_sr_img, only_y=True)
                    cropped_gt_img_y = bgr2ycbcr(cropped_gt_img, only_y=True)
                    avg_psnr += util.psnr(cropped_sr_img_y, cropped_gt_img_y)

                # An empty loader reports NaN instead of ZeroDivisionError.
                avg_psnr = avg_psnr / idx if idx else float('nan')
                val_elapsed = time.time() - val_start_time
                # Save to log
                print_rlt = OrderedDict()
                print_rlt['model'] = opt['model']
                print_rlt['epoch'] = epoch
                print_rlt['iters'] = current_step
                print_rlt['time'] = val_elapsed
                print_rlt['psnr'] = avg_psnr
                logger.print_format_results('val', print_rlt)
                print('-----------------------------------')
                # Restart the training timer so validation time is not
                # reported as part of the next training iteration.
                start_time = time.time()

            # update learning rate
            model.update_learning_rate()

    print('Saving the final model.')
    model.save('latest')
    print('End of training.')
Exemplo n.º 6
0
        img_name = os.path.splitext(os.path.basename(img_path))[0]

        model.test()  # test
        visuals = model.get_current_visuals(need_HR=need_HR)

        sr_img = util.tensor2img(visuals['SR'])  # uint8

        if need_HR:  # load GT image and calculate psnr
            gt_img = util.tensor2img(visuals['HR'])

            crop_border = test_loader.dataset.opt['scale']
            cropped_sr_img = sr_img[crop_border:-crop_border,
                                    crop_border:-crop_border, :]
            cropped_gt_img = gt_img[crop_border:-crop_border,
                                    crop_border:-crop_border, :]
            psnr = util.psnr(cropped_sr_img, cropped_gt_img)
            ssim = util.ssim(cropped_sr_img, cropped_gt_img, multichannel=True)
            test_results['psnr'].append(psnr)
            test_results['ssim'].append(ssim)
            if gt_img.shape[2] == 3:  # RGB image
                cropped_sr_img_y = rgb2ycbcr(cropped_sr_img, only_y=True)
                cropped_gt_img_y = rgb2ycbcr(cropped_gt_img, only_y=True)
                psnr_y = util.psnr(cropped_sr_img_y, cropped_gt_img_y)
                ssim_y = util.ssim(cropped_sr_img_y,
                                   cropped_gt_img_y,
                                   multichannel=False)
                test_results['psnr_y'].append(psnr_y)
                test_results['ssim_y'].append(ssim_y)
                print('{:20s} - PSNR: {:.4f} dB; SSIM: {:.4f}; PSNR_Y: {:.4f} dB; SSIM_Y: {:.4f}.'\
                    .format(img_name, psnr, ssim, psnr_y, ssim_y))
            else:
Exemplo n.º 7
0
def validate(val_loader, opt, model, current_step, epoch, logger):
    """Run one validation pass of a two-code model and log the mean PSNR.

    For every batch two latent codes are drawn, at 2x and at 4x the LR
    height/width, and fed to the model together with the batch. The codes
    are:

    - all-zero when ``opt['train']['zero_code']`` is truthy;
    - uniform (``torch.rand``) when ``opt['train']['rand_code']`` is truthy;
    - Gaussian (``torch.randn``) otherwise.

    Args:
        val_loader: iterable of validation batches; each batch is a dict with
            ``'LR'`` (indexed at ``shape[0]``, ``shape[2]``, ``shape[3]``) and
            ``'LR_path'`` entries.
        opt: option dict; reads ``opt['train']``,
            ``opt['network_G']['in_code_nc']``, ``opt['path']['val_images']``,
            ``opt['scale']`` (border cropped before PSNR) and ``opt['model']``.
        model: model wrapper providing ``feed_data``, ``test``,
            ``get_current_visuals`` and ``save``.
        current_step: current iteration, logged and used in file names; when
            it is 0 the model is also saved.
        epoch: current epoch, only logged.
        logger: object providing ``print_format_results(phase, results)``.
    """
    print('---------- validation -------------')
    start_time = time.time()

    # Pick the code generator once; it is the same for every batch.
    train_opt = opt['train']
    if 'zero_code' in train_opt and train_opt['zero_code']:
        make_code = torch.zeros
    elif 'rand_code' in train_opt and train_opt['rand_code']:
        make_code = torch.rand
    else:
        make_code = torch.randn
    in_code_nc = int(opt['network_G']['in_code_nc'])
    # Border (in pixels) excluded from PSNR; no crop when unset or <= 0.
    crop_size = opt['scale'] or 0

    avg_psnr = 0.0
    idx = 0
    for val_data in val_loader:
        idx += 1
        img_name = os.path.splitext(os.path.basename(
            val_data['LR_path'][0]))[0]
        img_dir = os.path.join(opt['path']['val_images'], img_name)
        util.mkdir(img_dir)

        lr_shape = val_data['LR'].shape
        # Draw the 2x code first, then the 4x code.
        code_val_0 = make_code(lr_shape[0], in_code_nc,
                               lr_shape[2] * 2, lr_shape[3] * 2)
        code_val_1 = make_code(lr_shape[0], in_code_nc,
                               lr_shape[2] * 4, lr_shape[3] * 4)
        model.feed_data(val_data, code=[code_val_0, code_val_1])
        model.test()

        visuals = model.get_current_visuals()
        sr_img = util.tensor2img(visuals['SR'])  # uint8
        gt_img = util.tensor2img(visuals['HR'])  # uint8

        # Save SR images for reference
        save_img_path = os.path.join(
            img_dir, 'caffe_{:s}_x4_{:d}.png'.format(img_name, current_step))
        util.save_img(sr_img, save_img_path)

        # calculate PSNR (the trailing Ellipsis keeps any channel axis)
        if crop_size > 0:
            sr_img = sr_img[crop_size:-crop_size, crop_size:-crop_size, ...]
            gt_img = gt_img[crop_size:-crop_size, crop_size:-crop_size, ...]
        avg_psnr += util.psnr(sr_img, gt_img)

    if current_step == 0:
        print('Saving the model at the end of iter {:d}.'.format(current_step))
        model.save(current_step)

    # An empty loader reports NaN instead of raising ZeroDivisionError.
    avg_psnr = avg_psnr / idx if idx else float('nan')
    time_elapsed = time.time() - start_time
    # Save to log
    print_rlt = OrderedDict()
    print_rlt['model'] = opt['model']
    print_rlt['epoch'] = epoch
    print_rlt['iters'] = current_step
    print_rlt['time'] = time_elapsed
    print_rlt['psnr'] = avg_psnr
    logger.print_format_results('val', print_rlt)
    print('-----------------------------------')
Exemplo n.º 8
0
        # For generating multiple samples of the same input image
        for run_index in range(multiple):
            code = model.gen_code(data['LR'].shape[0], data['LR'].shape[2],
                                  data['LR'].shape[3])
            model.feed_data(data, code=code, need_HR=need_HR)
            model.test()

            img_path = data['LR_path'][0]
            img_name = os.path.splitext(os.path.basename(img_path))[0]

            visuals = model.get_current_visuals(need_HR=need_HR)
            sr_img = util.tensor2img(visuals['HR_pred'])  # uint8

            if need_HR:  # load target image and calculate metric scores
                gt_img = util.tensor2img(visuals['HR'])
                psnr = util.psnr(sr_img, gt_img)
                ssim = util.ssim(sr_img, gt_img, multichannel=True)
                lpips = torch.sum(model.get_loss(level=-1))
                test_results['psnr'].append(psnr)
                test_results['ssim'].append(ssim)
                test_results['lpips'].append(lpips)
                print('{:20s} - LPIPS: {:.4f}; PSNR: {:.4f} dB; SSIM: {:.4f}.'.
                      format(img_name, lpips, psnr, ssim))
            else:
                print(img_name)

            save_img_path = os.path.join(dataset_dir,
                                         img_name + '_%d.png' % run_index)
            util.save_img(sr_img, save_img_path)

    if need_HR:  # metrics
Exemplo n.º 9
0
def main():
    """Train a model described by a JSON option file.

    Command line: ``-opt <path>`` (required), the path to the option JSON file.

    Steps:
        1. Parse the options and prepare the experiment directories.
        2. Route stdout through ``PrintLogger``.
        3. Build the train/val loaders.
        4. Train for ``opt['train']['niter']`` iterations with periodic
           logging, checkpointing, validation (PSNR/SSIM; on the Y channel
           for 3-channel images) and learning-rate updates.

    The final model is saved as ``'latest'``.

    Raises:
        ValueError: if the options define no ``'train'`` dataset.
        NotImplementedError: for a dataset phase other than 'train'/'val'.
    """
    # Settings
    parser = argparse.ArgumentParser()
    parser.add_argument('-opt', type=str, required=True, help='Path to option JSON file.')
    opt = option.parse(parser.parse_args().opt)  # load and initialize settings

    util.mkdir_and_rename(opt['path']['experiments_root'])  # rename old experiments if exists
    util.mkdirs((path for key, path in opt['path'].items()
                 if key not in ('experiments_root', 'saved_model')))
    option.save(opt)
    opt = option.dict_to_nonedict(opt)  # Convert to NoneDict, which returns None for missing keys.

    # Route stdout through PrintLogger, which writes to opt['path']['log'].
    sys.stdout = PrintLogger(opt['path']['log'])

    # create train and val dataloader
    # Pre-bind so a missing phase gives a clear error instead of NameError.
    train_loader = None
    val_loader = None
    for phase, dataset_opt in opt['datasets'].items():
        if phase == 'train':
            train_set = create_dataset(dataset_opt)
            train_size = int(math.ceil(len(train_set) / dataset_opt['batch_size']))
            print('Number of train images: {:,d}, iters: {:,d}'.format(len(train_set), train_size))
            total_iters = int(opt['train']['niter'])
            total_epoches = int(math.ceil(total_iters / train_size))
            print('Total epoches needed: {:d} for iters {:,d}'.format(total_epoches, total_iters))
            train_loader = create_dataloader(train_set, dataset_opt)
        elif phase == 'val':
            val_set = create_dataset(dataset_opt)
            val_loader = create_dataloader(val_set, dataset_opt)
            print('Number of val images in [{:s}]: {:d}'.format(dataset_opt['name'], len(val_set)))
        else:
            raise NotImplementedError('Phase [{:s}] is not recognized.'.format(phase))
    if train_loader is None:
        raise ValueError('No "train" dataset is defined in the options.')

    # Create model
    model = create_model(opt)
    # Create logger
    logger = Logger(opt)

    current_step = 0
    start_time = time.time()
    print('---------- Start training -------------')
    for epoch in range(total_epoches):
        for train_data in train_loader:
            current_step += 1
            if current_step > total_iters:
                break

            # training
            model.feed_data(train_data)
            model.optimize_parameters(current_step)

            time_elapsed = time.time() - start_time
            start_time = time.time()

            # log
            if current_step % opt['logger']['print_freq'] == 0:
                logs = model.get_current_log()
                print_rlt = OrderedDict()
                print_rlt['model'] = opt['model']
                print_rlt['epoch'] = epoch
                print_rlt['iters'] = current_step
                print_rlt['time'] = time_elapsed
                for k, v in logs.items():
                    print_rlt[k] = v
                print_rlt['lr'] = model.get_current_learning_rate()
                logger.print_format_results('train', print_rlt)

            # save models
            if current_step % opt['logger']['save_checkpoint_freq'] == 0:
                print('Saving the model at the end of iter {:d}.'.format(current_step))
                model.save(current_step)

            # validation (skipped when no 'val' dataset is configured)
            if (val_loader is not None
                    and current_step % opt['train']['val_freq'] == 0):
                print('---------- validation -------------')
                val_start_time = time.time()
                # Border (in pixels) excluded from the metrics; no crop when
                # 'scale' is unset or non-positive.
                crop_border = opt['scale'] or 0

                avg_psnr = 0.0
                avg_ssim = 0.0
                idx = 0
                for val_data in val_loader:
                    idx += 1
                    img_name = os.path.splitext(os.path.basename(val_data['GT_path'][0]))[0]
                    img_dir = os.path.join(opt['path']['val_images'], img_name)
                    util.mkdir(img_dir)

                    model.feed_data(val_data)
                    model.test()

                    visuals = model.get_current_visuals()
                    out_img = util.tensor2img(visuals['Output'])
                    gt_img = util.tensor2img(visuals['ground_truth'])  # uint8

                    # Save output images for reference
                    save_img_path = os.path.join(
                        img_dir, '{:s}_{:d}.png'.format(img_name, current_step))
                    util.save_img(out_img, save_img_path)

                    # calculate PSNR / SSIM; give 2-D (grayscale) images a
                    # channel axis so both cases share the indexing below.
                    if len(gt_img.shape) == 2:
                        gt_img = np.expand_dims(gt_img, axis=2)
                        out_img = np.expand_dims(out_img, axis=2)
                    if crop_border > 0:
                        cropped_out_img = out_img[crop_border:-crop_border, crop_border:-crop_border, :]
                        cropped_gt_img = gt_img[crop_border:-crop_border, crop_border:-crop_border, :]
                    else:
                        cropped_out_img = out_img
                        cropped_gt_img = gt_img
                    if gt_img.shape[2] == 3:  # 3-channel image: compare Y channel only
                        cropped_out_img_y = bgr2ycbcr(cropped_out_img, only_y=True)
                        cropped_gt_img_y = bgr2ycbcr(cropped_gt_img, only_y=True)
                        avg_psnr += util.psnr(cropped_out_img_y, cropped_gt_img_y)
                        avg_ssim += util.ssim(cropped_out_img_y, cropped_gt_img_y, multichannel=False)
                    else:
                        avg_psnr += util.psnr(cropped_out_img, cropped_gt_img)
                        avg_ssim += util.ssim(cropped_out_img, cropped_gt_img, multichannel=True)

                # An empty loader reports NaN instead of ZeroDivisionError.
                if idx:
                    avg_psnr = avg_psnr / idx
                    avg_ssim = avg_ssim / idx
                else:
                    avg_psnr = float('nan')
                    avg_ssim = float('nan')
                val_elapsed = time.time() - val_start_time
                # Save to log
                print_rlt = OrderedDict()
                print_rlt['model'] = opt['model']
                print_rlt['epoch'] = epoch
                print_rlt['iters'] = current_step
                print_rlt['time'] = val_elapsed
                print_rlt['psnr'] = avg_psnr
                print_rlt['ssim'] = avg_ssim
                logger.print_format_results('val', print_rlt)
                print('-----------------------------------')
                # Restart the training timer so validation time is not
                # reported as part of the next training iteration.
                start_time = time.time()

            # update learning rate
            model.update_learning_rate()

    print('Saving the final model.')
    model.save('latest')
    print('End of training.')
Exemplo n.º 10
0
def main():
    """Train a code-conditioned SR model on a month/day batch schedule.

    Command line: ``-opt <path>`` (required), the path to the option JSON file.

    Schedule:

    - Each "month" is one pass over ``train_loader``.
    - For every loaded batch, a latent code is computed once:
      ``get_code_for_data`` when ``opt['train']['use_dci']`` is truthy,
      ``get_code`` otherwise.
    - ``num_day`` optimisation steps ("days") are then taken on consecutive,
      wrap-around slices of ``batch_size_per_day`` samples from that batch.

    Training stops after ``opt['train']['niter']`` steps. Logging,
    checkpointing, validation (PSNR) and learning-rate updates happen at
    their configured frequencies. The final model is saved as ``'latest'``.

    Raises:
        ValueError: if the options define no ``'train'`` dataset.
        NotImplementedError: for a dataset phase other than 'train'/'val'.
    """
    # options
    parser = argparse.ArgumentParser()
    parser.add_argument('-opt',
                        type=str,
                        required=True,
                        help='Path to option JSON file.')
    opt = option.parse(parser.parse_args().opt, is_train=True)

    util.mkdir_and_rename(
        opt['path']['experiments_root'])  # rename old experiments if exists
    util.mkdirs((path for key, path in opt['path'].items()
                 if key not in ('experiments_root', 'pretrain_model_G',
                                'pretrain_model_D')))
    option.save(opt)
    # Convert to NoneDict, which returns None for missing keys.
    opt = option.dict_to_nonedict(opt)

    # print to file and std_out simultaneously
    sys.stdout = PrintLogger(opt['path']['log'])

    # random seed
    seed = opt['train']['manual_seed']
    if seed is None:
        seed = random.randint(1, 10000)
    print("Random Seed: ", seed)
    random.seed(seed)
    torch.manual_seed(seed)

    # create train and val dataloader
    # Pre-bind so a missing phase gives a clear error instead of NameError.
    train_loader = None
    val_loader = None
    for phase, dataset_opt in opt['datasets'].items():
        if phase == 'train':
            train_set = create_dataset(dataset_opt)
            train_size = int(
                math.ceil(len(train_set) / dataset_opt['batch_size']))
            print('Number of train images: {:,d}, iters: {:,d}'.format(
                len(train_set), train_size))
            total_iters = int(opt['train']['niter'])
            total_epoches = int(math.ceil(total_iters / train_size))
            print('Total epoches needed: {:d} for iters {:,d}'.format(
                total_epoches, total_iters))
            train_loader = create_dataloader(train_set, dataset_opt)
            batch_size_per_day = int(dataset_opt['batch_size_per_day'])
            num_month = int(opt['train']['num_month'])
            num_day = int(opt['train']['num_day'])
            # Default to False when 'use_dci' is not configured.
            use_dci = (opt['train']['use_dci']
                       if 'use_dci' in opt['train'] else False)
        elif phase == 'val':
            val_set = create_dataset(dataset_opt)
            val_loader = create_dataloader(val_set, dataset_opt)
            print('Number of val images in [{:s}]: {:d}'.format(
                dataset_opt['name'], len(val_set)))
        else:
            raise NotImplementedError(
                'Phase [{:s}] is not recognized.'.format(phase))
    if train_loader is None:
        raise ValueError('No "train" dataset is defined in the options.')

    # Create model
    model = create_model(opt)
    # create logger
    logger = Logger(opt)

    current_step = 0
    start_time = time.time()
    print('---------- Start training -------------')
    for epoch in range(num_month):
        if current_step >= total_iters:
            break
        for train_data in train_loader:
            # Stop before computing codes for batches that would not be used.
            if current_step >= total_iters:
                break
            # get the code
            if use_dci:
                cur_month_code = get_code_for_data(model, train_data, opt)
            else:
                cur_month_code = get_code(model, train_data, opt)
            # Slice by the actual batch size so a smaller final batch is
            # still indexed within bounds.
            batch_size_per_month = train_data['LR'].shape[0]
            for j in range(num_day):
                if current_step >= total_iters:
                    break
                current_step += 1
                # get the sliced data: batch_size_per_day consecutive
                # samples, wrapping around the end of the month batch.
                cur_day_batch_start_idx = (
                    j * batch_size_per_day) % batch_size_per_month
                cur_day_batch_end_idx = cur_day_batch_start_idx + batch_size_per_day
                if cur_day_batch_end_idx > batch_size_per_month:
                    cur_day_batch_idx = np.hstack(
                        (np.arange(cur_day_batch_start_idx,
                                   batch_size_per_month),
                         np.arange(cur_day_batch_end_idx -
                                   batch_size_per_month)))
                else:
                    cur_day_batch_idx = slice(cur_day_batch_start_idx,
                                              cur_day_batch_end_idx)

                cur_day_train_data = {
                    'LR': train_data['LR'][cur_day_batch_idx],
                    'HR': train_data['HR'][cur_day_batch_idx]
                }
                code = cur_month_code[cur_day_batch_idx]

                # training
                model.feed_data(cur_day_train_data, code=code)
                model.optimize_parameters(current_step)

                time_elapsed = time.time() - start_time
                start_time = time.time()

                # log
                if current_step % opt['logger']['print_freq'] == 0:
                    logs = model.get_current_log()
                    print_rlt = OrderedDict()
                    print_rlt['model'] = opt['model']
                    print_rlt['epoch'] = epoch
                    print_rlt['iters'] = current_step
                    print_rlt['time'] = time_elapsed
                    for k, v in logs.items():
                        print_rlt[k] = v
                    print_rlt['lr'] = model.get_current_learning_rate()
                    logger.print_format_results('train', print_rlt)

                # save models
                if current_step % opt['logger']['save_checkpoint_freq'] == 0:
                    print('Saving the model at the end of iter {:d}.'.format(
                        current_step))
                    model.save(current_step)

                # validation (skipped when no 'val' dataset is configured)
                if (val_loader is not None
                        and current_step % opt['train']['val_freq'] == 0):
                    print('---------- validation -------------')
                    val_start_time = time.time()

                    # Constant for this validation run:
                    # - code generator: all-zero, uniform or Gaussian;
                    # - number of code channels;
                    # - run index (3rd '_' field of the experiment name);
                    # - border cropped before PSNR (no crop when unset/<= 0).
                    if 'zero_code' in opt['train'] and opt['train']['zero_code']:
                        make_code = torch.zeros
                    elif 'rand_code' in opt['train'] and opt['train']['rand_code']:
                        make_code = torch.rand
                    else:
                        make_code = torch.randn
                    in_code_nc = int(opt['network_G']['in_code_nc'])
                    run_index = opt['name'].split("_")[2]
                    crop_size = opt['scale'] or 0

                    avg_psnr = 0.0
                    idx = 0
                    for val_data in val_loader:
                        idx += 1
                        img_name = os.path.splitext(
                            os.path.basename(val_data['LR_path'][0]))[0]
                        img_dir = os.path.join(opt['path']['val_images'],
                                               img_name)
                        util.mkdir(img_dir)

                        lr_shape = val_data['LR'].shape
                        code_val = make_code(lr_shape[0], in_code_nc,
                                             lr_shape[2], lr_shape[3])
                        model.feed_data(val_data, code=code_val)
                        model.test()

                        visuals = model.get_current_visuals()
                        sr_img = util.tensor2img(visuals['SR'])  # uint8
                        gt_img = util.tensor2img(visuals['HR'])  # uint8

                        # Save SR images for reference
                        save_img_path = os.path.join(
                            img_dir, 'srim_{:s}_{:s}_{:d}.png'.format(
                                run_index, img_name, current_step))
                        util.save_img(sr_img, save_img_path)

                        # calculate PSNR (the trailing Ellipsis keeps any
                        # channel axis)
                        if crop_size > 0:
                            sr_img = sr_img[crop_size:-crop_size,
                                            crop_size:-crop_size, ...]
                            gt_img = gt_img[crop_size:-crop_size,
                                            crop_size:-crop_size, ...]
                        avg_psnr += util.psnr(sr_img, gt_img)

                    # An empty loader reports NaN instead of ZeroDivisionError.
                    avg_psnr = avg_psnr / idx if idx else float('nan')
                    val_elapsed = time.time() - val_start_time
                    # Save to log
                    print_rlt = OrderedDict()
                    print_rlt['model'] = opt['model']
                    print_rlt['epoch'] = epoch
                    print_rlt['iters'] = current_step
                    print_rlt['time'] = val_elapsed
                    print_rlt['psnr'] = avg_psnr
                    logger.print_format_results('val', print_rlt)
                    print('-----------------------------------')
                    # Restart the training timer so validation time is not
                    # reported as part of the next training step.
                    start_time = time.time()

                # update learning rate
                model.update_learning_rate()

    print('Saving the final model.')
    model.save('latest')
    print('End of training.')