예제 #1
0
def main(weights_path='./experiments/backup/full/RRDB_GAN.h5',
         export_path='./experiments/pretrained_models/RRDB_GAN'):
    """Export the trained generator to TensorFlow SavedModel format.

    Options are read from the YAML file given with ``-opt`` (falling back to
    ``./options/train_ESRGAN.yml``). The generator built by
    ``make_generator`` gets its weights from *weights_path* and is re-wrapped
    behind an ``Input`` of shape ``(None, None, 1)`` before being saved.

    Args:
        weights_path: ``.h5`` weights file loaded into the generator.
        export_path: destination directory of the SavedModel.
    """
    ## options
    parser = argparse.ArgumentParser()
    parser.add_argument('-opt', type=str, help='Path to option YAML file.')
    args = parser.parse_args()
    if args.opt is not None:
        opt = parse(args.opt)
    else:
        opt = parse('./options/train_ESRGAN.yml')

    # Instantiate the generator
    model = make_generator(opt, print_summary=True)

    # Load the weights from a .h5 file
    model.load_weights(weights_path)

    # An Input with unspecified height/width lets the exported model accept
    # inputs of arbitrary spatial size (needed for the SavedModel format).
    # Named ``inputs`` so the builtin ``input`` is not shadowed.
    inputs = tf.keras.layers.Input(shape=(None, None, 1))
    outputs = model(inputs)
    export_model = tf.keras.models.Model(inputs, outputs)

    # Save the model with the SavedModel format
    export_model.save(export_path, save_format='tf')
예제 #2
0
파일: train.py 프로젝트: sunchang2017/IKC
def main():
    """Train the IKC pipeline: SFTMD first, then the Predictor and Corrector.

    Command-line options:
        -opt_P / -opt_C / -opt_F: option YAML files of the Predictor,
            Corrector and SFTMD network respectively.
        --launcher: ``none`` for single-process training, ``pytorch`` to
            initialise distributed training via ``init_dist``.
        --local_rank: accepted on the command line; not read here.
    """
    #### setup options of three networks
    parser = argparse.ArgumentParser()
    parser.add_argument('-opt_P', type=str, help='Path to option YAML file of Predictor.')
    parser.add_argument('-opt_C', type=str, help='Path to option YAML file of Corrector.')
    parser.add_argument('-opt_F', type=str, help='Path to option YAML file of SFTMD_Net.')
    parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none',
                        help='job launcher')
    parser.add_argument('--local_rank', type=int, default=0)
    args = parser.parse_args()
    opt_P = option.parse(args.opt_P, is_train=True)
    opt_C = option.parse(args.opt_C, is_train=True)
    opt_F = option.parse(args.opt_F, is_train=True)

    # convert to NoneDict, which returns None for missing keys
    opt_P = option.dict_to_nonedict(opt_P)
    opt_C = option.dict_to_nonedict(opt_C)
    opt_F = option.dict_to_nonedict(opt_F)

    # only the 'sftmd' sub-section of the SFTMD options is used from here on
    opt_F = opt_F['sftmd']

    #### random seed
    seed = opt_P['train']['manual_seed']
    if seed is None:
        seed = random.randint(1, 10000)
    util.set_random_seed(seed)

    # Build a PCA projection from 30000 randomly generated kernels; the
    # resulting matrix is handed to both training stages below.
    batch_ker = util.random_batch_kernel(batch=30000, l=opt_P['kernel_size'],
                                         sig_min=0.2, sig_max=4.0, rate_iso=1.0,
                                         scaling=3, tensor=False)
    print('batch kernel shape: {}'.format(batch_ker.shape))
    # flatten each kernel into one row before PCA
    batch_ker = batch_ker.reshape((np.size(batch_ker, 0), -1))
    pca_matrix = util.PCA(batch_ker, k=opt_P['code_length']).float()
    print('PCA matrix shape: {}'.format(pca_matrix.shape))

    #### distributed training settings
    distributed = args.launcher != 'none'
    for net_opt in (opt_P, opt_F, opt_C):
        net_opt['dist'] = distributed
    if distributed:
        init_dist()
        world_size = torch.distributed.get_world_size()  # number of processes in the group
        rank = torch.distributed.get_rank()  # rank of this process
    else:
        # Single process. world_size must still be bound: both training
        # calls below take it as an argument (previously a NameError).
        world_size = 1
        rank = -1
        print('Disabled distributed training.')

    torch.backends.cudnn.benchmark = True
    # torch.backends.cudnn.deterministic = True

    ###### SFTMD train ######
    SFTMD_train(opt_F, rank, world_size, pca_matrix)

    ###### Predictor&Corrector train ######
    IKC_train(opt_P, opt_C, opt_F, rank, world_size, pca_matrix)
예제 #3
0
def main():
    """Parse the training options and hand them to ``train``.

    The option YAML path comes from ``-opt``; when it is omitted the default
    ``./options/train_ESRGAN.yml`` is used.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument('-opt', type=str, help='Path to option YAML file.')
    yaml_path = cli.parse_args().opt
    if yaml_path is None:
        yaml_path = './options/train_ESRGAN.yml'
    train(parse(yaml_path))
예제 #4
0
파일: op_counter.py 프로젝트: gcrth/sr
def main(input_size=(1, 3, 320, 180)):
    """Report MACs and parameter count of the generator ``model.netG``.

    The options file given with ``--opt`` (now required) is parsed, the model
    is created with ``create_model`` and ``profile`` is run on a zero tensor
    of shape *input_size* placed on the GPU. Results are formatted with
    ``clever_format`` and printed.

    Args:
        input_size: shape of the dummy input tensor fed to ``netG``
            (defaults to the previously hard-coded ``(1, 3, 320, 180)``).
    """
    #### options
    parser = argparse.ArgumentParser()
    # required: option.parse cannot handle a missing path
    parser.add_argument('--opt', type=str, required=True, help='Path to option YAML file.')
    args = parser.parse_args()
    opt = option.parse(args.opt, is_train=True)

    # convert to NoneDict, which returns None for missing keys
    opt = option.dict_to_nonedict(opt)

    #### random seed
    seed = opt['train']['manual_seed']
    if seed is None:
        seed = random.randint(1, 10000)
    print('Random seed: {}'.format(seed))
    util.set_random_seed(seed)

    torch.backends.cudnn.benchmark = False

    #### create model
    model = create_model(opt)

    #### op counting
    print('Start counting')
    var_L = torch.zeros(*input_size).cuda()
    print('netG')
    macs, params = profile(model.netG, inputs=(var_L, ))
    macs, params = clever_format([macs, params], "%.5f")
    print('macs:{},params:{}'.format(macs, params))
예제 #5
0
def main():
    """Run generator pretraining unless its checkpoint already exists.

    Options come from the YAML file given via ``-opt`` (default
    ``./options/train_ESRGAN.yml``). If ``opt['path']['pretrained_model_G']``
    exists on disk, a message is printed and nothing is trained; otherwise
    ``pretrain(opt)`` is called.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument('-opt', type=str, help='Path to option YAML file.')
    yaml_path = cli.parse_args().opt
    opt = parse('./options/train_ESRGAN.yml' if yaml_path is None else yaml_path)

    checkpoint = opt['path']['pretrained_model_G']
    if os.path.exists(checkpoint):
        print("Pretrained model already found: " + checkpoint + ", aborting")
        return
    print("Pretrain checkpoint not found, starting pretraining...")
    pretrain(opt)
예제 #6
0
def downloadModel():
    """Parse test options, build the solver and run one warm-up SR pass.

    The ``-opt`` options JSON is required. ``./results/LR/Test/!.png`` is
    copied into ``./results/LR/MyImage/`` as both ``!.png`` and ``!!.png``,
    ``SR`` is run once, and the ``!!.png`` copy is removed again — also when
    ``SR`` raises.

    Returns:
        tuple: ``(solver, opt, model_name)``.
    """
    # parse the JSON options
    parser = argparse.ArgumentParser(
        description='Test Super Resolution Models')
    parser.add_argument('-opt',
                        type=str,
                        required=True,
                        help='Path to options JSON file.')
    opt = option.parse(parser.parse_args().opt)
    opt = option.dict_to_nonedict(opt)

    # derive the display name of the model from the parsed options
    model_name = opt['networks']['which_model'].upper()
    if opt['self_ensemble']:
        model_name += 'plus'

    # load the model from the JSON options
    solver = create_solver(opt)

    # run SR once on the test image before real SR work starts
    shutil.copy('./results/LR/Test/!.png', './results/LR/MyImage/!.png')
    shutil.copy('./results/LR/Test/!.png', './results/LR/MyImage/!!.png')
    try:
        SR(solver, opt, model_name)
    finally:
        # do not leave the temporary copy behind if SR fails
        os.remove('./results/LR/MyImage/!!.png')
    return solver, opt, model_name
예제 #7
0
class SISR:
    """Super-resolution test harness configured from a JSON options file.

    Constructing an instance parses ``-opt`` from the command line, builds one
    test dataloader per entry in ``opt['datasets']``, creates the solver and
    prints a summary line. Everything built is kept on the instance.
    """

    def __init__(self):
        parser = argparse.ArgumentParser(description='Test Super Resolution Models')
        parser.add_argument('-opt', type=str, required=True, help='Path to options JSON file.')
        opt = option.parse(parser.parse_args().opt)
        opt = option.dict_to_nonedict(opt)

        # initial configure
        scale = opt['scale']
        degrad = opt['degradation']
        network_opt = opt['networks']
        model_name = network_opt['which_model'].upper()
        if opt['self_ensemble']:
            model_name += 'plus'

        # create test dataloaders (the lists must exist before the loop fills them)
        bm_names = []
        test_loaders = []
        for _, dataset_opt in sorted(opt['datasets'].items()):
            test_set = create_dataset(dataset_opt)
            test_loader = create_dataloader(test_set, dataset_opt)
            test_loaders.append(test_loader)
            print('===> Test Dataset: [%s]   Number of images: [%d]' % (test_set.name(), len(test_set)))
            bm_names.append(test_set.name())

        # create solver (and load model)
        solver = create_solver(opt)

        # keep the configured pieces on the instance for later use
        self.opt = opt
        self.scale = scale
        self.degrad = degrad
        self.model_name = model_name
        self.test_loaders = test_loaders
        self.bm_names = bm_names
        self.solver = solver

        print("==================================================")
        print("Method: %s || Scale: %d || Degradation: %s" % (model_name, scale, degrad))
예제 #8
0
def main():
    """Run the model on every configured test dataset and convert its outputs to images.

    Options come from the required ``-opt`` file. For PPON models
    (``opt['model'] == 'ppon'``) the output key is ``'img_p'`` instead of
    ``'SR'``, and ``img_c`` / ``img_s`` outputs are converted as well.

    NOTE(review): this example is truncated — the final
    ``if opt['model'] == 'ppon':`` has no body, so the image-saving step is
    missing from this snippet.
    NOTE(review): both ``if opt['model'] == 'ppon':`` lines inside the loop are
    indented with tab characters while the rest uses spaces; Python 3 rejects
    this mix with TabError — re-indent with spaces.
    """
    # options
    parser = argparse.ArgumentParser()
    parser.add_argument('-opt', type=str, required=True, help='Path to options JSON file.')
    opt = option.parse(parser.parse_args().opt, is_train=False)
    #util.mkdirs((path for key, path in opt['path'].items() if not key == 'pretrain_model_G'))
    opt = option.dict_to_nonedict(opt)

    #util.setup_logger(None, opt['path']['log'], 'test.log', level=logging.INFO, screen=True)
    #logger = logging.getLogger('base')
    #logger.info(option.dict2str(opt))
    # Create test dataset and dataloader
    test_loaders = []
    for phase, dataset_opt in sorted(opt['datasets'].items()):
        test_set = create_dataset(dataset_opt)
        test_loader = create_dataloader(test_set, dataset_opt)
        #logger.info('Number of test images in [{:s}]: {:d}'.format(dataset_opt['name'], len(test_set)))
        test_loaders.append(test_loader)

    # Create model
    model = create_model(opt)

    # Key of the final output image in model.get_current_visuals()
    modelKey = 'SR'
    if opt['model'] == 'ppon':
        modelKey = 'img_p'
        print('Model is recognized as PPON model')

    for test_loader in test_loaders:
        test_set_name = test_loader.dataset.opt['name']
        print('\nTesting [{:s}]...'.format(test_set_name))
        #logger.info('\nTesting [{:s}]...'.format(test_set_name))
        test_start_time = time.time()
        #dataset_dir = os.path.join(opt['path']['results_root'], test_set_name)
        # Output directory is the dataset's dataroot_HR itself; save_img_path
        # below is built inside it.
        dataset_dir = test_loader.dataset.opt['dataroot_HR']
        #util.mkdir(dataset_dir)

        idx = 0
        for data in test_loader:
            idx += 1
            # NOTE(review): HR data is never requested — the conditional that
            # used to decide this is commented out at the end of the line.
            need_HR = False #if test_loader.dataset.opt['dataroot_HR'] is None else True

            model.feed_data(data, need_HR=need_HR)
            img_path = data['LR_path'][0]
            img_name = os.path.splitext(os.path.basename(img_path))[0]

            model.test()  # test
            visuals = model.get_current_visuals(need_HR=need_HR)

            # NOTE(review): the next line is tab-indented (TabError) — see docstring.
			if opt['model'] == 'ppon':
                sr_img_c = util.tensor2img(visuals['img_c'])
                sr_img_s = util.tensor2img(visuals['img_s']) 
            sr_img = util.tensor2img(visuals[modelKey])

            # save images
            # NOTE(review): baseinput, modelname, idx and test_start_time are
            # not used in the visible part of this example.
            baseinput = os.path.splitext(os.path.basename(img_path))[0][:-8]
            model_path = opt['path']['pretrain_model_G']
            modelname = os.path.splitext(os.path.basename(model_path))[0]
            save_img_path = os.path.join(dataset_dir, img_name + '.png')
            # NOTE(review): example truncated here — this `if` has no body.
			if opt['model'] == 'ppon':
예제 #9
0
def load_model(conf_path):
    """Build a model from the option file *conf_path* and load its generator weights.

    ``gpu_ids`` is cleared before the options are converted and the model is
    created (presumably so no GPU placement happens — confirm). The weights
    path is read from the ``model_path`` option via ``opt_get`` (default
    ``None``) and loaded into ``model.netG``.

    Returns:
        tuple: ``(model, opt)``.
    """
    cfg = option.parse(conf_path, is_train=False)
    cfg['gpu_ids'] = None
    cfg = option.dict_to_nonedict(cfg)
    net = create_model(cfg)

    weights_file = opt_get(cfg, ['model_path'], None)
    net.load_network(load_path=weights_file, network=net.netG)
    return net, cfg
예제 #10
0
 def load_model(self, opt_path='./options/test/test_EDVR_M_AI4KHDR.yml'):
     """Create the test model described by an option YAML file.

     Args:
         opt_path: option file to parse; defaults to the EDVR-M AI4KHDR test
             configuration that used to be hard-coded here.

     Returns:
         The model returned by ``create_model``.
     """
     opt = option.parse(opt_path, is_train=False)
     # inference runs in a single process
     opt['dist'] = False
     opt = option.dict_to_nonedict(opt)
     # NOTE: this is a process-wide cuDNN setting, not local to this model
     torch.backends.cudnn.benchmark = True
     #### create model
     return create_model(opt)
예제 #11
0
def main():
    """Parse the training options passed with ``-opt``; CUDA must be available.

    NOTE(review): this example looks truncated — the parsed ``opt`` is never
    used before the function returns.
    """
    assert torch.cuda.is_available()

    cli = argparse.ArgumentParser()
    cli.add_argument('-opt', type=str, required=True, help='Path to option JSON file.')
    raw_opt = option.parse(cli.parse_args().opt, is_train=True)
    # Convert to NoneDict, which return None for missing key.
    opt = option.dict_to_nonedict(raw_opt)
def setup(opts):
    """Create the solver for the model directory named by ``opts["scale"]``.

    Options are read from ``<scale>/config.json`` and the weights path passed
    to ``create_solver`` is ``<scale>/model.pth``. The parsed options are also
    published through the module-level ``opt`` global.

    Returns:
        The solver built by ``create_solver``.
    """
    global opt
    model_dir = opts["scale"]
    weights_path = "/".join([model_dir, "model.pth"])
    config_path = "/".join([model_dir, "config.json"])

    opt = option.parse(config_path)
    opt = option.dict_to_nonedict(opt)

    return create_solver(opt, weights_path)
예제 #13
0
def get_options(json_path):
    """Load test options from *json_path*, create output directories and set up logging.

    Every directory in ``opt['path']`` except ``pretrain_model_G`` is created.
    ``util.setup_logger`` is called for the ``'test'`` phase, and the full
    option dump is written through the ``'base'`` logger.

    Returns:
        tuple: ``(opt, logger)`` — the options converted with
        ``option.dict_to_nonedict`` and the ``'base'`` logger.
    """
    opt = option.parse(json_path, False)
    dirs_to_create = (
        path for key, path in opt['path'].items() if key != 'pretrain_model_G'
    )
    util.mkdirs(dirs_to_create)
    opt = option.dict_to_nonedict(opt)

    util.setup_logger(None, opt['path']['log'], 'test', level=logging.INFO, screen=True)
    base_logger = logging.getLogger('base')
    base_logger.info(option.dict2str(opt))
    return opt, base_logger
예제 #14
0
def run(pretrained_path,
        output_path,
        model_name='SRFBN',
        scale=4,
        degrad='BI',
        opt='options/test/test_SRFBN_example.json',
        input_path='./results/LR/MyImage/chip.png'):
    """Super-resolve one image with an SRFBN network and write the result.

    Args:
        pretrained_path: checkpoint file; if it holds a ``'state_dict'``
            entry, that entry is used as the weights.
        output_path: file the SR image is written to with ``imageio.imwrite``.
        model_name: not used (kept for backward compatibility).
        scale: upscaling factor passed to ``define_net``.
        degrad: not used (kept for backward compatibility).
        opt: options JSON path; it is parsed, but the network below is built
            from a hard-coded configuration.
        input_path: low-resolution image to process (previously hard-coded).

    Side effects:
        Saves the whole model object to ``./model.pt``.
    """
    opt = option.parse(opt)
    opt = option.dict_to_nonedict(opt)
    # model = create_model(opt)
    model = define_net({
        "scale": scale,
        "which_model": "SRFBN",
        "num_features": 64,
        "in_channels": 3,
        "out_channels": 3,
        "num_steps": 4,
        "num_groups": 6
    })

    img = common.read_img(input_path, 'img')

    # move the last axis first (assumes read_img returns HWC) and add a batch axis
    np_transpose = np.ascontiguousarray(img.transpose((2, 0, 1)))
    lr_tensor = torch.unsqueeze(torch.from_numpy(np_transpose).float(), 0)

    # The model and input stay on the CPU, so load the checkpoint there too;
    # this also works for GPU-saved checkpoints on machines without CUDA.
    checkpoint = torch.load(pretrained_path, map_location='cpu')
    if 'state_dict' in checkpoint:
        checkpoint = checkpoint['state_dict']
    model.load_state_dict(checkpoint)
    torch.save(model, './model.pt')

    model.eval()
    with torch.no_grad():
        # NOTE(review): [0] takes the first element of the network output; if
        # it returns one image per feedback step, the final one is [-1] — confirm.
        sr = model(lr_tensor)[0]
    # Back to channels-last; clip first because casting out-of-range floats to
    # uint8 is not well-defined (values would wrap instead of saturating).
    hwc = np.transpose(sr.data[0].float().cpu().numpy(), (1, 2, 0))
    visuals = np.clip(hwc, 0, 255).astype(np.uint8)
    imageio.imwrite(output_path, visuals)
예제 #15
0
def _crop_border(img, crop_size):
    """Return *img* without ``crop_size`` pixels on each side of the first two axes.

    A non-positive ``crop_size`` returns an unmodified copy. Slicing only the
    first two axes is equivalent for 2-D and 3-D (channel-last) arrays.
    """
    if crop_size <= 0:
        return img.copy()
    return img[crop_size:-crop_size, crop_size:-crop_size]


def _validate(model, val_loader, opt, current_step):
    """Run *model* over *val_loader*, save every SR image and return the mean Y-channel PSNR.

    Returns ``float('nan')`` if the loader yields no images.
    """
    # border crop for PSNR: 'crop_scale' when configured, otherwise 'scale'
    crop_size = opt['crop_scale'] if opt['crop_scale'] is not None else opt['scale']
    total_psnr = 0.0
    count = 0
    for val_data in val_loader:
        count += 1
        img_name = os.path.splitext(os.path.basename(val_data['LR_path'][0]))[0]
        img_dir = os.path.join(opt['path']['val_images'], img_name)
        util.mkdir(img_dir)

        model.feed_data(val_data)
        model.test()

        visuals = model.get_current_visuals()
        sr_img = util.tensor2img(visuals['SR'])  # uint8
        gt_img = util.tensor2img(visuals['HR'])  # uint8

        # Save SR images for reference
        save_img_path = os.path.join(
            img_dir, '{:s}_{:d}.png'.format(img_name, current_step))
        util.save_img(sr_img, save_img_path)

        # PSNR on the Y channel only, after cropping the border
        cropped_sr_img_y = bgr2ycbcr(_crop_border(sr_img, crop_size), only_y=True)
        cropped_gt_img_y = bgr2ycbcr(_crop_border(gt_img, crop_size), only_y=True)
        total_psnr += util.psnr(cropped_sr_img_y, cropped_gt_img_y)
    return total_psnr / count if count else float('nan')


def main():
    """Train an SR model from the options JSON given with the required ``-opt``.

    Creates the experiment directories (renaming an existing experiment
    root), redirects stdout to ``PrintLogger``, seeds ``random`` and
    ``torch``, builds the train (mandatory) and val (optional) loaders and
    runs ``opt['train']['niter']`` iterations. Logging, checkpointing and
    validation happen every ``print_freq``, ``save_checkpoint_freq`` and
    ``val_freq`` iterations; a final ``'latest'`` model is saved at the end.
    """
    # options
    parser = argparse.ArgumentParser()
    parser.add_argument('-opt',
                        type=str,
                        required=True,
                        help='Path to option JSON file.')
    opt = option.parse(parser.parse_args().opt, is_train=True)

    # rename old experiments if exists
    util.mkdir_and_rename(opt['path']['experiments_root'])
    util.mkdirs((path for key, path in opt['path'].items()
                 if key not in ('experiments_root', 'pretrain_model_G', 'pretrain_model_D')))
    option.save(opt)
    # Convert to NoneDict, which return None for missing key.
    opt = option.dict_to_nonedict(opt)

    # print to file and std_out simultaneously
    sys.stdout = PrintLogger(opt['path']['log'])

    # random seed
    seed = opt['train']['manual_seed']
    if seed is None:
        seed = random.randint(1, 10000)
    print("Random Seed: ", seed)
    random.seed(seed)
    torch.manual_seed(seed)

    # create train and val dataloader; each stays None if its phase is absent
    train_loader = None
    val_loader = None
    for phase, dataset_opt in opt['datasets'].items():
        if phase == 'train':
            train_set = create_dataset(dataset_opt)
            train_size = int(
                math.ceil(len(train_set) / dataset_opt['batch_size']))
            print('Number of train images: {:,d}, iters: {:,d}'.format(
                len(train_set), train_size))
            total_iters = int(opt['train']['niter'])
            total_epoches = int(math.ceil(total_iters / train_size))
            print('Total epoches needed: {:d} for iters {:,d}'.format(
                total_epoches, total_iters))
            train_loader = create_dataloader(train_set, dataset_opt)
        elif phase == 'val':
            val_set = create_dataset(dataset_opt)
            val_loader = create_dataloader(val_set, dataset_opt)
            print('Number of val images in [{:s}]: {:d}'.format(
                dataset_opt['name'], len(val_set)))
        else:
            raise NotImplementedError(
                'Phase [{:s}] is not recognized.'.format(phase))
    assert train_loader is not None

    # Create model
    model = create_model(opt)
    # create logger
    logger = Logger(opt)

    current_step = 0
    start_time = time.time()
    print('---------- Start training -------------')
    for epoch in range(total_epoches):
        for train_data in train_loader:
            current_step += 1
            if current_step > total_iters:
                break

            # training
            model.feed_data(train_data)
            model.optimize_parameters(current_step)

            time_elapsed = time.time() - start_time
            start_time = time.time()

            # log
            if current_step % opt['logger']['print_freq'] == 0:
                logs = model.get_current_log()
                print_rlt = OrderedDict()
                print_rlt['model'] = opt['model']
                print_rlt['epoch'] = epoch
                print_rlt['iters'] = current_step
                print_rlt['time'] = time_elapsed
                for k, v in logs.items():
                    print_rlt[k] = v
                print_rlt['lr'] = model.get_current_learning_rate()
                logger.print_format_results('train', print_rlt)

            # save models
            if current_step % opt['logger']['save_checkpoint_freq'] == 0:
                print('Saving the model at the end of iter {:d}.'.format(
                    current_step))
                model.save(current_step)

            # validation (skipped when no 'val' dataset is configured)
            if val_loader is not None and current_step % opt['train']['val_freq'] == 0:
                print('---------- validation -------------')
                val_start = time.time()
                avg_psnr = _validate(model, val_loader, opt, current_step)
                val_time = time.time() - val_start
                # Save to log
                print_rlt = OrderedDict()
                print_rlt['model'] = opt['model']
                print_rlt['epoch'] = epoch
                print_rlt['iters'] = current_step
                print_rlt['time'] = val_time
                print_rlt['psnr'] = avg_psnr
                logger.print_format_results('val', print_rlt)
                print('-----------------------------------')
                # restart the iteration timer so the next training log does
                # not include the validation time
                start_time = time.time()

            # update learning rate
            model.update_learning_rate()

    print('Saving the final model.')
    model.save('latest')
    print('End of training.')
예제 #16
0
def main():
    """Train an image or video restoration model from an option YAML file.

    Command-line options:
        -opt: path to the option YAML file.
        --launcher: ``none`` disables distributed training; ``pytorch``
            initialises it through ``init_dist``.
        --local_rank: accepted on the command line; not read here.

    The process with ``rank <= 0`` (the only process, or rank 0 when
    distributed) creates experiment directories, the file logger and the
    optional TensorBoard writer, and saves checkpoints and the final model.
    """
    # options
    parser = argparse.ArgumentParser()
    parser.add_argument("-opt", type=str, help="Path to option YAML file.")
    parser.add_argument("--launcher",
                        choices=["none", "pytorch"],
                        default="none",
                        help="job launcher")
    parser.add_argument("--local_rank", type=int, default=0)
    args = parser.parse_args()
    opt = option.parse(args.opt, is_train=True)

    # distributed training settings
    if args.launcher == "none":  # disabled distributed training
        opt["dist"] = False
        rank = -1
        world_size = 1  # only read on distributed code paths; bound for safety
        print("Disabled distributed training.")
    else:
        opt["dist"] = True
        init_dist()
        world_size = torch.distributed.get_world_size()
        rank = torch.distributed.get_rank()

    # loading resume state if exists
    if opt["path"].get("resume_state", None):
        # distributed resuming: all load into default GPU
        device_id = torch.cuda.current_device()
        resume_state = torch.load(
            opt["path"]["resume_state"],
            map_location=lambda storage, loc: storage.cuda(device_id))
        option.check_resume(opt, resume_state["iter"])  # check resume options
    else:
        resume_state = None

    # TensorBoard writer. It is only created on rank <= 0 when enabled, so
    # later code tests ``tb_logger is not None``; previously the final
    # ``tb_logger.close()`` raised NameError whenever it was never created.
    tb_logger = None

    # mkdir and loggers
    if rank <= 0:  # normal training (rank -1) OR distributed training (rank 0)
        if resume_state is None:
            # rename experiment folder if exists
            util.mkdir_and_rename(opt["path"]["experiments_root"])
            util.mkdirs(
                (path for key, path in opt["path"].items()
                 if not key == "experiments_root"
                 and "pretrain_model" not in key and "resume" not in key))

        # config loggers. Before it, the log will not work
        util.setup_logger("base",
                          opt["path"]["log"],
                          "train_" + opt["name"],
                          level=logging.INFO,
                          screen=True,
                          tofile=True)
        logger = logging.getLogger("base")
        logger.info(option.dict2str(opt))
        # tensorboard logger
        if opt["use_tb_logger"] and "debug" not in opt["name"]:
            version = float(torch.__version__[0:3])
            if version >= 1.1:  # PyTorch 1.1
                from torch.utils.tensorboard import SummaryWriter
            else:
                # adjacent literals: the old backslash continuation embedded
                # a long run of spaces in the message
                logger.info("You are using PyTorch {}. "
                            "Tensorboard will use [tensorboardX]".format(version))
                from tensorboardX import SummaryWriter
            tb_logger = SummaryWriter(log_dir="../tb_logger/" + opt["name"])
    else:
        util.setup_logger("base",
                          opt["path"]["log"],
                          "train",
                          level=logging.INFO,
                          screen=True)
        logger = logging.getLogger("base")

    # convert to NoneDict, which returns None for missing keys
    opt = option.dict_to_nonedict(opt)

    # random seed
    seed = opt["train"]["manual_seed"]
    if seed is None:
        seed = random.randint(1, 10000)
    if rank <= 0:
        logger.info("Random seed: {}".format(seed))
    util.set_random_seed(seed)

    torch.backends.cudnn.benchmark = True
    # torch.backends.cudnn.deterministic = True

    # create train and val dataloader
    dataset_ratio = 200  # enlarge the size of each epoch
    train_loader = None  # stays None if no "train" phase is configured
    val_set = val_loader = None
    for phase, dataset_opt in opt["datasets"].items():
        if phase == "train":
            train_set = create_dataset(dataset_opt)
            train_size = int(
                math.ceil(len(train_set) / dataset_opt["batch_size"]))
            total_iters = int(opt["train"]["niter"])
            total_epochs = int(math.ceil(total_iters / train_size))
            if opt["dist"]:
                train_sampler = DistIterSampler(train_set, world_size, rank,
                                                dataset_ratio)
                total_epochs = int(
                    math.ceil(total_iters / (train_size * dataset_ratio)))
            else:
                train_sampler = None
            train_loader = create_dataloader(train_set, dataset_opt, opt,
                                             train_sampler)
            if rank <= 0:
                logger.info(
                    "Number of train images: {:,d}, iters: {:,d}".format(
                        len(train_set), train_size))
                logger.info("Total epochs needed: {:d} for iters {:,d}".format(
                    total_epochs, total_iters))
        elif phase == "val":
            val_set = create_dataset(dataset_opt)
            val_loader = create_dataloader(val_set, dataset_opt, opt, None)
            if rank <= 0:
                logger.info("Number of val images in [{:s}]: {:d}".format(
                    dataset_opt["name"], len(val_set)))
        else:
            raise NotImplementedError(
                "Phase [{:s}] is not recognized.".format(phase))
    assert train_loader is not None

    # create model
    model = create_model(opt)
    print("Model created!")

    # resume training
    if resume_state:
        logger.info("Resuming training from epoch: {}, iter: {}.".format(
            resume_state["epoch"], resume_state["iter"]))

        start_epoch = resume_state["epoch"]
        current_step = resume_state["iter"]
        model.resume_training(resume_state)  # handle optimizers and schedulers
    else:
        current_step = 0
        start_epoch = 0

    # training
    logger.info("Start training from epoch: {:d}, iter: {:d}".format(
        start_epoch, current_step))
    for epoch in range(start_epoch, total_epochs + 1):
        if opt["dist"]:
            train_sampler.set_epoch(epoch)
        for _, train_data in enumerate(train_loader):
            current_step += 1
            if current_step > total_iters:
                break
            # update learning rate
            model.update_learning_rate(current_step,
                                       warmup_iter=opt["train"]["warmup_iter"])

            # training
            model.feed_data(train_data)
            model.optimize_parameters(current_step)

            # log
            if current_step % opt["logger"]["print_freq"] == 0:
                logs = model.get_current_log()
                message = "[epoch:{:3d}, iter:{:8,d}, lr:(".format(
                    epoch, current_step)
                for v in model.get_current_learning_rate():
                    message += "{:.3e},".format(v)
                message += ")] "
                for k, v in logs.items():
                    message += "{:s}: {:.4e} ".format(k, v)
                    # tensorboard logger (only exists on rank <= 0)
                    if tb_logger is not None:
                        tb_logger.add_scalar(k, v, current_step)
                if rank <= 0:
                    logger.info(message)
            # validation
            if opt["datasets"].get(
                    "val",
                    None) and current_step % opt["train"]["val_freq"] == 0:
                # image restoration validation
                if opt["model"] in ["sr", "srgan"] and rank <= 0:
                    # does not support multi-GPU validation
                    pbar = util.ProgressBar(len(val_loader))
                    avg_psnr = 0.0
                    idx = 0
                    for val_data in val_loader:
                        idx += 1
                        img_name = os.path.splitext(
                            os.path.basename(val_data["LQ_path"][0]))[0]
                        img_dir = os.path.join(opt["path"]["val_images"],
                                               img_name)
                        util.mkdir(img_dir)

                        model.feed_data(val_data)
                        model.test()

                        visuals = model.get_current_visuals()
                        sr_img = util.tensor2img(visuals["rlt"])  # uint8
                        gt_img = util.tensor2img(visuals["GT"])  # uint8

                        # Save SR images for reference
                        save_img_path = os.path.join(
                            img_dir,
                            "{:s}_{:d}.png".format(img_name, current_step))
                        util.save_img(sr_img, save_img_path)

                        # calculate PSNR
                        sr_img, gt_img = util.crop_border([sr_img, gt_img],
                                                          opt["scale"])
                        avg_psnr += util.calculate_psnr(sr_img, gt_img)
                        pbar.update("Test {}".format(img_name))

                    avg_psnr = avg_psnr / idx

                    # log
                    logger.info("# Validation # PSNR: {:.4e}".format(avg_psnr))
                    # tensorboard logger
                    if tb_logger is not None:
                        tb_logger.add_scalar("psnr", avg_psnr, current_step)
                else:  # video restoration validation
                    if opt["dist"]:
                        # multi-GPU testing
                        psnr_rlt = {}  # with border and center frames
                        if rank == 0:
                            pbar = util.ProgressBar(len(val_set))
                        for idx in range(rank, len(val_set), world_size):
                            val_data = val_set[idx]
                            val_data["LQs"].unsqueeze_(0)
                            val_data["GT"].unsqueeze_(0)
                            folder = val_data["folder"]
                            idx_d, max_idx = val_data["idx"].split("/")
                            idx_d, max_idx = int(idx_d), int(max_idx)
                            if psnr_rlt.get(folder, None) is None:
                                psnr_rlt[folder] = torch.zeros(
                                    max_idx,
                                    dtype=torch.float32,
                                    device="cuda")
                            model.feed_data(val_data)
                            model.test()
                            visuals = model.get_current_visuals()
                            rlt_img = util.tensor2img(visuals["rlt"])  # uint8
                            gt_img = util.tensor2img(visuals["GT"])  # uint8
                            # calculate PSNR
                            psnr_rlt[folder][idx_d] = util.calculate_psnr(
                                rlt_img, gt_img)

                            if rank == 0:
                                for _ in range(world_size):
                                    pbar.update("Test {} - {}/{}".format(
                                        folder, idx_d, max_idx))
                        # collect data
                        # NOTE(review): one reduce per folder key in this
                        # rank's psnr_rlt; if ranks end up with different
                        # folder sets the collectives will not match — confirm
                        # every rank sees every folder.
                        for _, v in psnr_rlt.items():
                            dist.reduce(v, 0)
                        dist.barrier()

                        if rank == 0:
                            psnr_rlt_avg = {}
                            psnr_total_avg = 0.0
                            for k, v in psnr_rlt.items():
                                psnr_rlt_avg[k] = torch.mean(v).cpu().item()
                                psnr_total_avg += psnr_rlt_avg[k]
                            psnr_total_avg /= len(psnr_rlt)
                            log_s = "# Validation # PSNR: {:.4e}:".format(
                                psnr_total_avg)
                            for k, v in psnr_rlt_avg.items():
                                log_s += " {}: {:.4e}".format(k, v)
                            logger.info(log_s)
                            if tb_logger is not None:
                                tb_logger.add_scalar("psnr_avg",
                                                     psnr_total_avg,
                                                     current_step)
                                for k, v in psnr_rlt_avg.items():
                                    tb_logger.add_scalar(k, v, current_step)
                    else:
                        pbar = util.ProgressBar(len(val_loader))
                        psnr_rlt = {}  # with border and center frames
                        psnr_rlt_avg = {}
                        psnr_total_avg = 0.0
                        for val_data in val_loader:
                            folder = val_data["folder"][0]
                            idx_d, max_id = val_data["idx"][0].split("/")
                            # border = val_data['border'].item()
                            if psnr_rlt.get(folder, None) is None:
                                psnr_rlt[folder] = []

                            model.feed_data(val_data)
                            model.test()
                            visuals = model.get_current_visuals()
                            rlt_img = util.tensor2img(visuals["rlt"])  # uint8
                            gt_img = util.tensor2img(visuals["GT"])  # uint8
                            lq_img = util.tensor2img(visuals["LQ"][2])  # uint8

                            img_dir = opt["path"]["val_images"]
                            util.mkdir(img_dir)
                            save_img_path = os.path.join(
                                img_dir, "{}.png".format(idx_d))
                            util.save_img(np.hstack((lq_img, rlt_img, gt_img)),
                                          save_img_path)

                            # calculate PSNR
                            psnr = util.calculate_psnr(rlt_img, gt_img)
                            psnr_rlt[folder].append(psnr)
                            pbar.update("Test {} - {}".format(folder, idx_d))
                        for k, v in psnr_rlt.items():
                            psnr_rlt_avg[k] = sum(v) / len(v)
                            psnr_total_avg += psnr_rlt_avg[k]
                        psnr_total_avg /= len(psnr_rlt)
                        log_s = "# Validation # PSNR: {:.4e}:".format(
                            psnr_total_avg)
                        for k, v in psnr_rlt_avg.items():
                            log_s += " {}: {:.4e}".format(k, v)
                        logger.info(log_s)
                        if tb_logger is not None:
                            tb_logger.add_scalar("psnr_avg", psnr_total_avg,
                                                 current_step)
                            for k, v in psnr_rlt_avg.items():
                                tb_logger.add_scalar(k, v, current_step)

            # save models and training states
            if current_step % opt["logger"]["save_checkpoint_freq"] == 0:
                if rank <= 0:
                    logger.info("Saving models and training states.")
                    model.save(current_step)
                    model.save_training_state(epoch, current_step)

    if rank <= 0:
        logger.info("Saving the final model.")
        model.save("latest")
        logger.info("End of training.")
        if tb_logger is not None:
            tb_logger.close()
예제 #17
0
파일: test.py 프로젝트: BlueAmulet/BasicSR
def main():
    """Run a trained SR model over every test dataset in the options file.

    Parses ``-opt`` (path to an options file), creates every directory listed
    under ``opt['path']`` except ``pretrain_model_G``, sets up a ``test.log``
    logger, builds one dataloader per entry in ``opt['datasets']`` and a model
    via ``create_model``. For each dataset it saves every SR output to
    ``<results_root>/<dataset name>/<image name><suffix>.png``. When the
    dataset has a ``dataroot_HR`` it also logs per-image and average
    PSNR/SSIM. For 3-channel images it additionally logs the Y-channel
    PSNR/SSIM.
    """
    # options
    parser = argparse.ArgumentParser()
    parser.add_argument('-opt',
                        type=str,
                        required=True,
                        help='Path to options file.')
    opt = option.parse(parser.parse_args().opt, is_train=False)
    util.mkdirs((path for key, path in opt['path'].items()
                 if not key == 'pretrain_model_G'))
    opt = option.dict_to_nonedict(opt)

    util.setup_logger(None,
                      opt['path']['log'],
                      'test.log',
                      level=logging.INFO,
                      screen=True)
    logger = logging.getLogger('base')
    logger.info(option.dict2str(opt))

    # Create test dataset and dataloader
    test_loaders = []
    # Temporary: znorm is switched on for ALL datasets as soon as any one of
    # them asks for it. Honouring it per dataset would need a flag per loader.
    znorm = False
    for phase, dataset_opt in sorted(opt['datasets'].items()):
        test_set = create_dataset(dataset_opt)
        test_loader = create_dataloader(test_set, dataset_opt)
        logger.info('Number of test images in [{:s}]: {:d}'.format(
            dataset_opt['name'], len(test_set)))
        test_loaders.append(test_loader)
        if dataset_opt['znorm']:
            znorm = True

    # Create model
    model = create_model(opt)

    def to_img(tensor):
        """Convert a model output tensor to a uint8 image via util.tensor2img."""
        # With znorm the tensor is mapped from [-1, 1]; otherwise tensor2img's
        # default range is used.
        if znorm:
            return util.tensor2img(tensor, min_max=(-1, 1))
        return util.tensor2img(tensor)

    def shave(img, border):
        """Drop ``border`` pixels from every spatial edge of an HxW(xC) array.

        A falsy border (0 or None) returns the image unchanged instead of
        slicing it down to nothing (``img[0:-0]`` is empty).
        """
        if not border:
            return img
        return img[border:-border, border:-border, ...]

    # Optional suffix appended to every saved file name.
    suffix = opt['suffix'] or ''

    for test_loader in test_loaders:
        test_set_name = test_loader.dataset.opt['name']
        logger.info('\nTesting [{:s}]...'.format(test_set_name))
        dataset_dir = os.path.join(opt['path']['results_root'], test_set_name)
        util.mkdir(dataset_dir)

        test_results = OrderedDict()
        test_results['psnr'] = []
        test_results['ssim'] = []
        test_results['psnr_y'] = []
        test_results['ssim_y'] = []

        # Ground-truth availability depends only on the dataset options. Decide
        # it once per loader so it is defined (and not stale from a previous
        # dataset) even when this loader yields no batches.
        need_HR = test_loader.dataset.opt['dataroot_HR'] is not None

        for data in test_loader:
            model.feed_data(data, need_HR=need_HR)
            img_path = data['LR_path'][0]
            img_name = os.path.splitext(os.path.basename(img_path))[0]

            model.test()  # test
            visuals = model.get_current_visuals(need_HR=need_HR)
            sr_img = to_img(visuals['SR'])  # uint8

            # save images
            save_img_path = os.path.join(dataset_dir,
                                         img_name + suffix + '.png')
            util.save_img(sr_img, save_img_path)

            if not need_HR:
                logger.info(img_name)
                continue

            # calculate PSNR and SSIM on border-cropped images (crop = scale)
            gt_img = to_img(visuals['HR']) / 255.
            sr_img = sr_img / 255.
            crop_border = test_loader.dataset.opt['scale']
            cropped_sr_img = shave(sr_img, crop_border)
            cropped_gt_img = shave(gt_img, crop_border)

            psnr = util.calculate_psnr(cropped_sr_img * 255,
                                       cropped_gt_img * 255)
            ssim = util.calculate_ssim(cropped_sr_img * 255,
                                       cropped_gt_img * 255)
            test_results['psnr'].append(psnr)
            test_results['ssim'].append(ssim)

            # Y-channel metrics only for 3-channel images; 2-D images have no
            # shape[2], so check ndim first.
            if gt_img.ndim == 3 and gt_img.shape[2] == 3:  # RGB image
                cropped_sr_img_y = shave(bgr2ycbcr(sr_img, only_y=True),
                                         crop_border)
                cropped_gt_img_y = shave(bgr2ycbcr(gt_img, only_y=True),
                                         crop_border)
                psnr_y = util.calculate_psnr(cropped_sr_img_y * 255,
                                             cropped_gt_img_y * 255)
                ssim_y = util.calculate_ssim(cropped_sr_img_y * 255,
                                             cropped_gt_img_y * 255)
                test_results['psnr_y'].append(psnr_y)
                test_results['ssim_y'].append(ssim_y)
                logger.info('{:20s} - PSNR: {:.6f} dB; SSIM: {:.6f}; PSNR_Y: {:.6f} dB; SSIM_Y: {:.6f}.'\
                    .format(img_name, psnr, ssim, psnr_y, ssim_y))
            else:
                logger.info(
                    '{:20s} - PSNR: {:.6f} dB; SSIM: {:.6f}.'.format(
                        img_name, psnr, ssim))

        # Averages only when at least one image was scored (avoids 0-division).
        if need_HR and test_results['psnr']:  # metrics
            ave_psnr = sum(test_results['psnr']) / len(test_results['psnr'])
            ave_ssim = sum(test_results['ssim']) / len(test_results['ssim'])
            logger.info('----Average PSNR/SSIM results for {}----\n\tPSNR: {:.6f} dB; SSIM: {:.6f}\n'\
                    .format(test_set_name, ave_psnr, ave_ssim))
            if test_results['psnr_y'] and test_results['ssim_y']:
                ave_psnr_y = sum(test_results['psnr_y']) / len(
                    test_results['psnr_y'])
                ave_ssim_y = sum(test_results['ssim_y']) / len(
                    test_results['ssim_y'])
                logger.info('----Y channel, average PSNR/SSIM----\n\tPSNR_Y: {:.6f} dB; SSIM_Y: {:.6f}\n'\
                    .format(ave_psnr_y, ave_ssim_y))
예제 #18
0
import torch
import numpy as np
import pandas as pd
from tqdm import tqdm

import options.options as option
from torch import utils as vutils
from data import create_dataset, create_dataloader
from solvers import create_solver
from utils import util

# Both modules are used below but were missing from this script's imports.
import argparse
import random

# The flag must be spelled `deterministic`: assigning a misspelled name only
# adds an unused attribute to the module and leaves the flag unset.
torch.backends.cudnn.deterministic = True

parser = argparse.ArgumentParser(description='Train Super Resolution Models in FedProx')
parser.add_argument('-opt', type=str, required=True, help='Path to options JSON file.')
opt = option.parse(parser.parse_args().opt)

##### Random seed #####
# `manual_seed` is optional; draw a random seed when it is absent or None.
seed = opt['solver'].get('manual_seed')
if seed is None:
    seed = random.randint(1, 10000)
print("=====> Random Seed: %d" % seed)
torch.manual_seed(seed)


##### Hyperparameters for federated learning #####
num_clients = opt['fed']['num_clients']
# Always sample at least one client per round (FedAvg's m = max(C*K, 1));
# int() truncation alone can yield 0 for small sample fractions.
num_selected = max(1, int(num_clients * opt['fed']['sample_fraction']))
num_rounds = opt['fed']['num_rounds']
client_epochs = opt['fed']['epochs']

예제 #19
0
def main():
    """Benchmark a trained SR solver on every test set in the options file.

    Parses ``-opt`` (options JSON path) and builds one dataloader per entry
    in ``opt['datasets']`` plus a solver via ``create_solver``. For each test
    set it:

    - times ``solver.test()`` per batch;
    - prints per-image PSNR/SSIM and their averages, only for datasets whose
      class name contains ``'LRHR'``;
    - writes the SR images with ``imageio.imwrite`` under ``./results/SR/``.

    Test sets that yield no batches are reported and skipped.
    """
    parser = argparse.ArgumentParser(
        description='Test Super Resolution Models')
    parser.add_argument('-opt',
                        type=str,
                        required=True,
                        help='Path to options JSON file.')
    opt = option.parse(parser.parse_args().opt)
    opt = option.dict_to_nonedict(opt)

    # initial configure
    scale = opt['scale']
    degrad = opt['degradation']
    network_opt = opt['networks']
    model_name = network_opt['which_model'].upper()
    if opt['self_ensemble']:
        model_name += 'plus'

    # create test dataloader
    bm_names = []
    test_loaders = []
    for _, dataset_opt in sorted(opt['datasets'].items()):
        test_set = create_dataset(dataset_opt)
        test_loader = create_dataloader(test_set, dataset_opt)
        test_loaders.append(test_loader)
        print('===> Test Dataset: [%s]   Number of images: [%d]' %
              (test_set.name(), len(test_set)))
        bm_names.append(test_set.name())

    # create solver (and load model)
    solver = create_solver(opt)
    # Test phase
    print('===> Start Test')
    print("==================================================")
    print("Method: %s || Scale: %d || Degradation: %s" %
          (model_name, scale, degrad))

    for bm, test_loader in zip(bm_names, test_loaders):
        print("Test set : [%s]" % bm)

        sr_list = []
        path_list = []

        total_psnr = []
        total_ssim = []
        total_time = []

        # Only paired datasets (class name contains 'LRHR') provide HR images.
        need_HR = 'LRHR' in test_loader.dataset.__class__.__name__
        num_batches = len(test_loader)

        for step, batch in enumerate(test_loader, 1):

            solver.feed_data(batch, need_HR=need_HR)

            # calculate forward time
            t0 = time.time()
            solver.test()
            elapsed = time.time() - t0
            total_time.append(elapsed)

            visuals = solver.get_current_visual(need_HR=need_HR)
            sr_list.append(visuals['SR'])
            lr_name = os.path.basename(batch['LR_path'][0])

            # calculate PSNR/SSIM metrics on Python
            if need_HR:
                psnr, ssim = util.calc_metrics(visuals['SR'],
                                               visuals['HR'],
                                               crop_border=scale)
                total_psnr.append(psnr)
                total_ssim.append(ssim)
                # Output name: HR file name with 'HR' replaced by the model name.
                path_list.append(
                    os.path.basename(batch['HR_path'][0]).replace(
                        'HR', model_name))
                print(
                    "[%d/%d] %s || PSNR(dB)/SSIM: %.2f/%.4f || Timer: %.4f sec ."
                    % (step, num_batches, lr_name, psnr, ssim, elapsed))
            else:
                path_list.append(lr_name)
                print("[%d/%d] %s || Timer: %.4f sec ." %
                      (step, num_batches, lr_name, elapsed))

        if not total_time:
            # Empty loader: nothing to average (would divide by zero) or save.
            print("---- No images in [%s], skipping ----" % bm)
            continue

        if need_HR:
            print("---- Average PSNR(dB) /SSIM /Speed(s) for [%s] ----" % bm)
            print("PSNR: %.2f      SSIM: %.4f      Speed: %.4f" %
                  (sum(total_psnr) / len(total_psnr), sum(total_ssim) /
                   len(total_ssim), sum(total_time) / len(total_time)))
        else:
            print("---- Average Speed(s) for [%s] is %.4f sec ----" %
                  (bm, sum(total_time) / len(total_time)))

        # save SR results for further evaluation on MATLAB
        if need_HR:
            save_img_path = os.path.join('./results/SR/' + degrad, model_name,
                                         bm, "x%d" % scale)
        else:
            save_img_path = os.path.join('./results/SR/' + bm, model_name,
                                         "x%d" % scale)

        print("===> Saving SR images of [%s]... Save Path: [%s]\n" %
              (bm, save_img_path))

        os.makedirs(save_img_path, exist_ok=True)
        for img, name in zip(sr_list, path_list):
            imageio.imwrite(os.path.join(save_img_path, name), img)

    print("==================================================")
    print("===> Finished !")
예제 #20
0
파일: train.py 프로젝트: W-yk/SR
def main():
    """Train an SR model described by an options file, validating periodically.

    Parses ``-opt`` (options JSON path). Either resumes from
    ``opt['path']['resume_state']`` or starts fresh, renaming any existing
    experiments folder and creating the output directories. It then sets up
    the 'base' and 'val' loggers, optional TensorBoard logging, the random
    seed and cuDNN autotuning, and trains for ``opt['train']['niter']``
    iterations.

    Periodic work during training:

    - every ``val_freq`` steps, when a 'val' dataset is configured: saves SR
      images and logs PSNR and IS;
    - every ``save_checkpoint_freq`` steps: saves the model and training state.

    The final model is saved as 'latest'.

    Raises:
        NotImplementedError: for a dataset phase other than 'train'/'val'.
        ValueError: if no 'train' dataset is configured.
    """
    # options
    parser = argparse.ArgumentParser()
    parser.add_argument('-opt',
                        type=str,
                        required=True,
                        help='Path to option JSON file.')
    opt = option.parse(parser.parse_args().opt, is_train=True)
    # Convert to NoneDict, which returns None for missing keys.
    opt = option.dict_to_nonedict(opt)

    # train from scratch OR resume training
    if opt['path']['resume_state']:  # resuming training
        resume_state = torch.load(opt['path']['resume_state'])
    else:  # training from scratch
        resume_state = None
        util.mkdir_and_rename(
            opt['path']['experiments_root'])  # rename old folder if exists
        util.mkdirs((path for key, path in opt['path'].items()
                     if not key == 'experiments_root'
                     and 'pretrain_model' not in key and 'resume' not in key))

    # config loggers. Before it, the log will not work
    util.setup_logger(None,
                      opt['path']['log'],
                      'train',
                      level=logging.INFO,
                      screen=True)
    util.setup_logger('val', opt['path']['log'], 'val', level=logging.INFO)
    logger = logging.getLogger('base')

    if resume_state:
        logger.info('Resuming training from epoch: {}, iter: {}.'.format(
            resume_state['epoch'], resume_state['iter']))
        option.check_resume(opt)  # check resume options

    logger.info(option.dict2str(opt))
    # tensorboard logger (disabled for 'debug' runs)
    use_tb = opt['use_tb_logger'] and 'debug' not in opt['name']
    tb_logger = None
    if use_tb:
        from tensorboardX import SummaryWriter
        tb_logger = SummaryWriter(log_dir='../tb_logger/' + opt['name'])

    # random seed
    seed = opt['train']['manual_seed']
    if seed is None:
        seed = random.randint(1, 10000)
    logger.info('Random seed: {}'.format(seed))
    util.set_random_seed(seed)

    # Must be spelled `benchmark`: a misspelled name silently creates a new
    # module attribute and leaves cuDNN autotuning switched off.
    torch.backends.cudnn.benchmark = True
    # torch.backends.cudnn.deterministic = True

    # create train and val dataloader
    train_loader = None
    val_loader = None
    for phase, dataset_opt in opt['datasets'].items():
        if phase == 'train':
            train_set = create_dataset(dataset_opt)
            train_size = int(
                math.ceil(len(train_set) / dataset_opt['batch_size']))
            logger.info('Number of train images: {:,d}, iters: {:,d}'.format(
                len(train_set), train_size))
            total_iters = int(opt['train']['niter'])
            total_epochs = int(math.ceil(total_iters / train_size))
            logger.info('Total epochs needed: {:d} for iters {:,d}'.format(
                total_epochs, total_iters))
            train_loader = create_dataloader(train_set, dataset_opt)
        elif phase == 'val':
            val_set = create_dataset(dataset_opt)
            val_loader = create_dataloader(val_set, dataset_opt)
            logger.info('Number of val images in [{:s}]: {:d}'.format(
                dataset_opt['name'], len(val_set)))
        else:
            raise NotImplementedError(
                'Phase [{:s}] is not recognized.'.format(phase))
    if train_loader is None:
        raise ValueError("No 'train' dataset configured in opt['datasets'].")

    # create model
    model = create_model(opt)

    def validate(epoch, current_step):
        """Evaluate on val_loader: save SR images, log average PSNR and IS.

        PSNR is computed after shaving ``opt['scale']`` pixels off each border.
        An empty validation set logs a warning instead of dividing by zero.
        """
        avg_psnr = 0.0
        avg_IS = 0.0
        idx = 0
        crop_size = opt['scale']
        for val_data in val_loader:
            idx += 1
            img_name = os.path.splitext(
                os.path.basename(val_data['LR_path'][0]))[0]
            img_dir = os.path.join(opt['path']['val_images'], img_name)
            util.mkdir(img_dir)

            model.feed_data(val_data)
            model.test()

            visuals = model.get_current_visuals()
            sr_img = util.tensor2img(visuals['SR'])  # uint8
            gt_img = util.tensor2img(visuals['HR'])  # uint8

            # Save SR images for reference
            save_img_path = os.path.join(
                img_dir, '{:s}_{:d}.png'.format(img_name, current_step))
            util.save_img(sr_img, save_img_path)

            # calculate IS
            avg_IS += model.get_IS()
            # calculate PSNR
            gt_img = gt_img / 255.
            sr_img = sr_img / 255.
            cropped_sr_img = sr_img[crop_size:-crop_size,
                                    crop_size:-crop_size, :]
            cropped_gt_img = gt_img[crop_size:-crop_size,
                                    crop_size:-crop_size, :]
            avg_psnr += util.calculate_psnr(cropped_sr_img * 255,
                                            cropped_gt_img * 255)

        if idx == 0:
            logger.warning('Validation set is empty; skipping metrics.')
            return

        avg_psnr = avg_psnr / idx
        avg_IS = avg_IS / idx
        # log
        logger.info('# Validation # PSNR: {:.4e} IS : {:.4e}'.format(
            avg_psnr, avg_IS))
        logger_val = logging.getLogger('val')  # validation logger
        logger_val.info(
            '<epoch:{:3d}, iter:{:8,d}> psnr: {:.4e} is: {:.4e} '.
            format(epoch, current_step, avg_psnr, avg_IS))
        # tensorboard logger
        if use_tb:
            tb_logger.add_scalar('psnr', avg_psnr, current_step)
            tb_logger.add_scalar('IS', avg_IS, current_step)

    # resume training
    if resume_state:
        start_epoch = resume_state['epoch']
        current_step = resume_state['iter']
        model.resume_training(resume_state)  # handle optimizers and schedulers
    else:
        current_step = 0
        start_epoch = 0

    # training
    logger.info('Start training from epoch: {:d}, iter: {:d}'.format(
        start_epoch, current_step))
    for epoch in range(start_epoch, total_epochs):
        for train_data in train_loader:
            current_step += 1
            if current_step > total_iters:
                break
            # update learning rate
            model.update_learning_rate()

            # training
            model.feed_data(train_data)
            model.optimize_parameters(current_step)

            # log
            if current_step % opt['logger']['print_freq'] == 0:
                logs = model.get_current_log()
                message = '<epoch:{:3d}, iter:{:8,d}, lr:{:.3e}> '.format(
                    epoch, current_step, model.get_current_learning_rate())
                for k, v in logs.items():
                    message += '{:s}: {:.4e} '.format(k, v)
                    # tensorboard logger
                    if use_tb:
                        tb_logger.add_scalar(k, v, current_step)
                logger.info(message)

            # validation (only when a 'val' dataset was configured)
            if (val_loader is not None
                    and current_step % opt['train']['val_freq'] == 0):
                validate(epoch, current_step)

            # save models and training states
            if current_step % opt['logger']['save_checkpoint_freq'] == 0:
                logger.info('Saving models and training states.')
                model.save(current_step)
                model.save_training_state(epoch, current_step)

    logger.info('Saving the final model.')
    model.save('latest')
    logger.info('End of training.')
    if tb_logger is not None:
        tb_logger.close()
예제 #21
0
import utils.util as util
from options import options

import os

from CIFARTrainer import Trainer

if __name__ == '__main__':
    # Parse options, create this experiment's checkpoint directory, attach a
    # file logger inside it, and run the CIFAR trainer.
    opt = options.parse()
    run_dir = os.path.join(opt.checkpoints_dir, opt.name)
    util.mkdirs(run_dir)

    log_path = os.path.join(run_dir, 'logger.log')
    logger = util.get_logger(log_path)

    trainer = Trainer(opt, logger)
    trainer.train()
예제 #22
0
def main():
    """Meta-train a generator initialisation, then fine-tune a copy of it.

    Reads command-line arguments from the module-level ``parser`` and a
    dataset options file (``opt.options``).

    Outer phase: repeatedly calls ``train`` (defined elsewhere) with
    ``G_init``, ``model_D`` and their optimizers. It runs until the global
    ``STEPS`` reaches ``(opt.inner_loop_steps + 1) * opt.max_updates`` or the
    user presses Ctrl-C, and calls ``save_checkpoint`` after every epoch.

    Fine-tuning phase: copies ``G_init`` into ``model_G`` and trains it for
    one pass over the training loader. The loss is ``criterion_G`` plus an
    L2 penalty towards ``G_init``'s weights. The learning rate is halved for
    each step-count threshold that ``STEPS`` has passed.

    Side effects: rebinds the module globals ``opt``, ``model_G``,
    ``model_D``, ``netContent``, ``writer`` and ``STEPS``. It also sets
    ``CUDA_VISIBLE_DEVICES`` and writes TensorBoard logs, sample images and
    checkpoints.

    Raises:
        Exception: if CUDA is not available.
        ValueError: if the training dataloader could not be created.
    """

    global opt, model_G, model_D, netContent, writer, STEPS

    opt = parser.parse_args()
    options = option.parse(opt.options)
    print(opt)

    # One sub-folder per hyper-parameter combination for logs/samples/ckpts.
    out_folder = "steps({})_lrIN({})_lrOUT({})_lambda(mseIN={},mseOUT={},vgg={},adv={},preserve={})".format(
        opt.inner_loop_steps, opt.lr_inner, opt.lr_outer,
        opt.mse_loss_coefficient_inner, opt.mse_loss_coefficient_outer,
        opt.vgg_loss_coefficient, opt.adversarial_loss_coefficient,
        opt.preservation_loss_coefficient)

    writer = SummaryWriter(logdir=os.path.join(opt.logs_dir, out_folder),
                           comment="-srgan-")

    opt.sample_dir = os.path.join(opt.sample_dir, out_folder)
    opt.fine_sample_dir = os.path.join(opt.fine_sample_dir, out_folder)

    opt.checkpoint_file_init = os.path.join(opt.checkpoint_dir,
                                            "init/" + out_folder)
    opt.checkpoint_file_final = os.path.join(opt.checkpoint_dir,
                                             "final/" + out_folder)
    opt.checkpoint_file_fine = os.path.join(opt.fine_checkpoint_dir,
                                            out_folder)

    os.environ["CUDA_VISIBLE_DEVICES"] = opt.gpus
    if not torch.cuda.is_available():
        raise Exception(
            "No GPU found or Wrong gpu id, please run without --cuda")

    opt.seed = random.randint(1, 10000)
    print("Random Seed: ", opt.seed)
    torch.manual_seed(opt.seed)
    torch.cuda.manual_seed(opt.seed)

    cudnn.benchmark = True

    print("===> Loading datasets")
    dataset_opt = options['datasets']['train']
    dataset_opt['batch_size'] = opt.batchSize
    print(dataset_opt)
    train_set = create_dataset(dataset_opt)
    training_data_loader = create_dataloader(train_set, dataset_opt)
    print('===> Train Dataset: %s   Number of images: [%d]' %
          (train_set.name(), len(train_set)))
    if training_data_loader is None:
        raise ValueError("[Error] The training data does not exist")

    print('===> Loading VGG model')
    netVGG = models.vgg19()
    if os.path.isfile('data/vgg19-dcbb9e9d.pth'):
        netVGG.load_state_dict(torch.load('data/vgg19-dcbb9e9d.pth'))
    else:
        netVGG.load_state_dict(
            model_zoo.load_url(
                'https://download.pytorch.org/models/vgg19-dcbb9e9d.pth'))

    class _content_model(nn.Module):
        """VGG19 feature extractor without its final layer (content loss)."""

        def __init__(self):
            super(_content_model, self).__init__()
            self.feature = nn.Sequential(
                *list(netVGG.features.children())[:-1])

        def forward(self, x):
            out = self.feature(x)
            return out

    G_init = _NetG(opt).cuda()
    model_D = _NetD().cuda()
    netContent = _content_model().cuda()
    criterion_G = GeneratorLoss(netContent, writer, STEPS).cuda()
    criterion_D = nn.BCELoss().cuda()

    # optionally copy weights from a checkpoint
    if opt.pretrained:
        assert os.path.isfile(opt.pretrained)
        print("=> loading model '{}'".format(opt.pretrained))
        weights = torch.load(opt.pretrained)
        # The checkpoint stores a whole module under 'model'.
        G_init.load_state_dict(weights['model'].state_dict())

    print("===> Setting Optimizer")
    # The outer (meta) optimizer updates the shared initialisation G_init.
    optimizer_G_outer = optim.Adam(G_init.parameters(), lr=opt.lr_outer)
    optimizer_D = optim.Adam(model_D.parameters(), lr=opt.lr_disc)

    print("===> Pre-fetching validation data for monitoring training")
    test_dump_file = 'data/dump/Test5.pickle'

    if os.path.isfile(test_dump_file):
        # NOTE(review): pickle.load is only safe on a trusted local cache.
        with open(test_dump_file, 'rb') as p:
            images_test = pickle.load(p)
        images_hr = images_test['images_hr']
        images_lr = images_test['images_lr']
        print("===>Loading Checkpoint Test images")
    else:
        images_hr, images_lr = create_val_ims()
        print("===>Creating Checkpoint Test images")

    print("===> Training")
    epoch = opt.start_epoch
    try:
        while STEPS < (opt.inner_loop_steps + 1) * opt.max_updates:
            last_model_G = train(training_data_loader, optimizer_G_outer,
                                 optimizer_D, G_init, model_D, criterion_G,
                                 criterion_D, epoch, STEPS, writer)
            assert last_model_G is not None
            save_checkpoint(images_hr, images_lr, G_init, last_model_G, epoch)
            epoch += 1
    except KeyboardInterrupt:
        print("KeyboardInterrupt HANDLED! Running the final epoch on G_init")
    epoch += 1

    # Fine-tuning LR: halve opt.lr_inner for each threshold STEPS has passed.
    if STEPS < 5e4:
        lr_finetune = opt.lr_inner
    elif STEPS < 1e5:
        lr_finetune = opt.lr_inner / 2
    elif STEPS < 2e5:
        lr_finetune = opt.lr_inner / 4
    elif STEPS < 4e5:
        lr_finetune = opt.lr_inner / 8
    elif STEPS < 8e5:
        lr_finetune = opt.lr_inner / 16
    else:
        lr_finetune = opt.lr_inner / 32

    model_G = deepcopy(G_init)
    optimizer_G_inner = optim.Adam(model_G.parameters(), lr=lr_finetune)
    model_G.train()
    optimizer_G_inner.zero_grad()
    # Flattened snapshot of G_init's trainable weights. It is detached once
    # here because it is only used as a constant target for the preservation
    # penalty; keeping its autograd graph alive serves no purpose.
    init_parameters = torch.cat(
        [p.view(-1) for k, p in G_init.named_parameters()
         if p.requires_grad]).detach()

    # Fine-tuning uses the inner-loop loss: no adversarial term, VGG on.
    opt.adversarial_loss = False
    opt.vgg_loss = True
    opt.mse_loss_coefficient = opt.mse_loss_coefficient_inner

    start_time = dt.datetime.now()
    # Number of batches (not images) in one pass over the training loader.
    total_num_examples = len(training_data_loader)
    for iteration, batch in enumerate(training_data_loader, 1):
        input, target = Variable(batch['LR']), Variable(batch['HR'],
                                                        requires_grad=False)
        # Assumes loader images are in [0, 255]; rescale to [0, 1].
        input = input.cuda() / 255
        target = target.cuda() / 255
        STEPS += 1
        output = model_G(input)
        fake_out = None
        optimizer_G_inner.zero_grad()
        loss_g_inner = criterion_G(fake_out, output, target, opt)
        curr_parameters = torch.cat([
            p.view(-1) for k, p in model_G.named_parameters()
            if p.requires_grad
        ])
        # L2 penalty keeping the fine-tuned weights close to G_init.
        preservation_loss = ((init_parameters - curr_parameters)**2).sum()
        loss_g_inner += preservation_loss
        loss_g_inner.backward()
        optimizer_G_inner.step()
        writer.add_scalar("Loss_G_finetune", loss_g_inner.item(), STEPS)
        if iteration % 5 == 0:
            fine_sample_img = torch_utils.make_grid(torch.cat(
                [output.detach().clone(), target], dim=0),
                                                    padding=2,
                                                    normalize=False)
            if not os.path.exists(opt.fine_sample_dir):
                os.makedirs(opt.fine_sample_dir)
            torch_utils.save_image(fine_sample_img,
                                   os.path.join(
                                       opt.fine_sample_dir,
                                       "Epoch-{}--Iteration-{}.png".format(
                                           epoch, iteration)),
                                   padding=5)

            # Time is whole seconds since the previous progress line; the
            # format string now has a placeholder for it (it was silently
            # dropped before).
            print("===> Finetuning Epoch[{}]({}/{}): G_Loss(finetune): {:.3} "
                  "Time: {}s".format(epoch, iteration, total_num_examples,
                                     loss_g_inner.item(),
                                     (dt.datetime.now() - start_time).seconds))
            start_time = dt.datetime.now()
            save_checkpoint(images_hr,
                            images_lr,
                            None,
                            model_G,
                            iteration,
                            finetune=True)
    save_checkpoint(images_hr,
                    images_lr,
                    None,
                    model_G,
                    total_num_examples,
                    finetune=True)
예제 #23
0
def main():
    #### options
    parser = argparse.ArgumentParser()
    parser.add_argument('-opt', type=str, help='Path to option YMAL file.')
    parser.add_argument('--launcher',
                        choices=['none', 'pytorch'],
                        default='none',
                        help='job launcher')
    parser.add_argument('--local_rank', type=int, default=0)
    args = parser.parse_args()
    opt = option.parse(args.opt, is_train=True)

    #### distributed training settings
    if args.launcher == 'none':  # disabled distributed training
        opt['dist'] = False
        rank = -1
        print('Disabled distributed training.')
    else:
        opt['dist'] = True
        init_dist()
        world_size = torch.distributed.get_world_size()
        rank = torch.distributed.get_rank()

    #### loading resume state if exists
    if opt['path'].get('resume_state', None):
        # distributed resuming: all load into default GPU
        device_id = torch.cuda.current_device()
        resume_state = torch.load(
            opt['path']['resume_state'],
            map_location=lambda storage, loc: storage.cuda(device_id))
        option.check_resume(opt, resume_state['iter'])  # check resume options
    else:
        resume_state = None

    #### mkdir and loggers
    if rank <= 0:  # normal training (rank -1) OR distributed training (rank 0)
        if resume_state is None:
            util.mkdir_and_rename(
                opt['path']
                ['experiments_root'])  # rename experiment folder if exists
            util.mkdirs(
                (path for key, path in opt['path'].items()
                 if not key == 'experiments_root'
                 and 'pretrain_model' not in key and 'resume' not in key))

        # config loggers. Before it, the log will not work
        util.setup_logger('base',
                          opt['path']['log'],
                          'train_' + opt['name'],
                          level=logging.INFO,
                          screen=True,
                          tofile=True)
        util.setup_logger('val',
                          opt['path']['log'],
                          'val_' + opt['name'],
                          level=logging.INFO,
                          screen=True,
                          tofile=True)
        logger = logging.getLogger('base')
        logger.info(option.dict2str(opt))
        # tensorboard logger
        if opt['use_tb_logger'] and 'debug' not in opt['name']:
            version = float(torch.__version__[0:3])
            if version >= 1.1:  # PyTorch 1.1
                from torch.utils.tensorboard import SummaryWriter
            else:
                logger.info(
                    'You are using PyTorch {}. Tensorboard will use [tensorboardX]'
                    .format(version))
                from tensorboardX import SummaryWriter
            tb_logger = SummaryWriter(log_dir='../tb_logger/' + opt['name'])
    else:
        util.setup_logger('base',
                          opt['path']['log'],
                          'train',
                          level=logging.INFO,
                          screen=True)
        logger = logging.getLogger('base')

    # convert to NoneDict, which returns None for missing keys
    opt = option.dict_to_nonedict(opt)

    #### random seed
    seed = opt['train']['manual_seed']
    if seed is None:
        seed = random.randint(1, 10000)
    if rank <= 0:
        logger.info('Random seed: {}'.format(seed))
    util.set_random_seed(seed)

    torch.backends.cudnn.benchmark = True
    # torch.backends.cudnn.deterministic = True

    #### create train and val dataloader
    dataset_ratio = 200  # enlarge the size of each epoch
    for phase, dataset_opt in opt['datasets'].items():
        if phase == 'train':
            train_set = create_dataset(dataset_opt)
            train_size = int(
                math.ceil(len(train_set) / dataset_opt['batch_size']))
            total_iters = int(opt['train']['niter'])
            total_epochs = int(math.ceil(total_iters / train_size))
            if opt['dist']:
                train_sampler = DistIterSampler(train_set, world_size, rank,
                                                dataset_ratio)
                total_epochs = int(
                    math.ceil(total_iters / (train_size * dataset_ratio)))
            else:
                train_sampler = None
            train_loader = create_dataloader(train_set, dataset_opt, opt,
                                             train_sampler)
            if rank <= 0:
                logger.info(
                    'Number of train images: {:,d}, iters: {:,d}'.format(
                        len(train_set), train_size))
                logger.info('Total epochs needed: {:d} for iters {:,d}'.format(
                    total_epochs, total_iters))
        elif phase == 'val':
            val_set = create_dataset(dataset_opt)
            val_loader = create_dataloader(val_set, dataset_opt, opt, None)
            if rank <= 0:
                logger.info('Number of val images in [{:s}]: {:d}'.format(
                    dataset_opt['name'], len(val_set)))
        else:
            raise NotImplementedError(
                'Phase [{:s}] is not recognized.'.format(phase))
    assert train_loader is not None

    #### create model
    model = create_model(opt)

    #### resume training
    if resume_state:
        logger.info('Resuming training from epoch: {}, iter: {}.'.format(
            resume_state['epoch'], resume_state['iter']))

        start_epoch = resume_state['epoch']
        current_step = resume_state['iter']
        model.resume_training(resume_state)  # handle optimizers and schedulers
    else:
        current_step = 0
        start_epoch = 0

    #### training
    logger.info('Start training from epoch: {:d}, iter: {:d}'.format(
        start_epoch, current_step))
    for epoch in range(start_epoch, total_epochs + 1):
        if opt['dist']:
            train_sampler.set_epoch(epoch)
        for _, train_data in enumerate(train_loader):
            current_step += 1
            if current_step > total_iters:
                break
            #### update learning rate
            model.update_learning_rate(current_step,
                                       warmup_iter=opt['train']['warmup_iter'])

            #### training
            model.feed_data(train_data)
            model.optimize_parameters(current_step)

            #### log
            if current_step % opt['logger']['print_freq'] == 0:
                logs = model.get_current_log()
                message = '<epoch:{:3d}, iter:{:8,d}, lr:{:.3e}> '.format(
                    epoch, current_step, model.get_current_learning_rate())
                for k, v in logs.items():
                    message += '{:s}: {:.4e} '.format(k, v)
                    # tensorboard logger
                    if opt['use_tb_logger'] and 'debug' not in opt['name']:
                        if rank <= 0:
                            tb_logger.add_scalar(k, v, current_step)
                if rank <= 0:
                    logger.info(message)

            # validation
            if current_step % opt['train']['val_freq'] == 0 and rank <= 0:
                avg_psnr = 0.0
                idx = 0
                for val_data in val_loader:
                    idx += 1
                    img_name = os.path.splitext(
                        os.path.basename(val_data['LQ_path'][0]))[0]
                    img_dir = os.path.join(opt['path']['val_images'], img_name)
                    util.mkdir(img_dir)

                    model.feed_data(val_data)
                    model.test()

                    visuals = model.get_current_visuals()
                    sr_img = util.tensor2img(visuals['SR'])  # uint8
                    gt_img = util.tensor2img(visuals['GT'])  # uint8

                    # Save SR images for reference
                    save_img_path = os.path.join(
                        img_dir,
                        '{:s}_{:d}.png'.format(img_name, current_step))
                    util.save_img(sr_img, save_img_path)

                    # calculate PSNR
                    crop_size = opt['scale']
                    gt_img = gt_img / 255.
                    sr_img = sr_img / 255.
                    cropped_sr_img = sr_img[crop_size:-crop_size,
                                            crop_size:-crop_size, :]
                    cropped_gt_img = gt_img[crop_size:-crop_size,
                                            crop_size:-crop_size, :]
                    avg_psnr += util.calculate_psnr(cropped_sr_img * 255,
                                                    cropped_gt_img * 255)

                avg_psnr = avg_psnr / idx

                # log
                logger.info('# Validation # PSNR: {:.4e}'.format(avg_psnr))
                logger_val = logging.getLogger('val')  # validation logger
                logger_val.info(
                    '<epoch:{:3d}, iter:{:8,d}> psnr: {:.4e}'.format(
                        epoch, current_step, avg_psnr))
                # tensorboard logger
                if opt['use_tb_logger'] and 'debug' not in opt['name']:
                    tb_logger.add_scalar('psnr', avg_psnr, current_step)

            #### save models and training states
            if current_step % opt['logger']['save_checkpoint_freq'] == 0:
                if rank <= 0:
                    logger.info('Saving models and training states.')
                    model.save(current_step)
                    model.save_training_state(epoch, current_step)

    if rank <= 0:
        logger.info('Saving the final model.')
        model.save('latest')
        logger.info('End of training.')
예제 #24
0
from data.LQGT_dataset import LQGTDataset
from utils import crash_on_ipy

import options.options as option

# Smoke-test the LQ/GT training dataset: parse the ESRGAN_M training options
# and print a one-line summary (type, length, LQ shape, GT shape) of every
# sample the dataset yields.
opt_file = './options/train/train_ESRGAN_M.yml'
opt = option.parse(opt_file, is_train=True)
train_dataset_opt = opt['datasets']['train']
test_set = LQGTDataset(train_dataset_opt)

for xx in test_set:
    lq_tensor, gt_tensor = xx['LQ'], xx['GT']
    print(type(xx), len(xx), lq_tensor.shape, gt_tensor.shape)
예제 #25
0
from collections import OrderedDict

import options.options as option
import utils.util as util
from data.util import bgr2ycbcr
from data import create_dataset, create_dataloader
from models import create_model

if __name__ == '__main__':
    # Test-time entry point: parse the option file, create output folders and
    # the logger, then build a dataset for every configured test phase.
    # argparse and logging are used below but are not among this example's
    # imports, so bring them into scope here.
    import argparse
    import logging

    # options
    parser = argparse.ArgumentParser()
    parser.add_argument('--opt',
                        type=str,
                        default="options/test/test_ESRGAN.json",
                        help='Path to options JSON file.')
    opt = option.parse(parser.parse_args().opt, is_train=False)
    # Create every configured path except the pretrained generator entry,
    # which names an existing model file rather than a folder to create.
    util.mkdirs((path for key, path in opt['path'].items()
                 if key != 'pretrain_model_G'))
    # convert to NoneDict, which returns None for missing keys
    opt = option.dict_to_nonedict(opt)

    util.setup_logger(None,
                      opt['path']['log'],
                      'test.log',
                      level=logging.INFO,
                      screen=True)
    logger = logging.getLogger('base')
    logger.info(option.dict2str(opt))
    # Create test dataset and dataloader
    test_loaders = []
    for phase, dataset_opt in sorted(opt['datasets'].items()):
        test_set = create_dataset(dataset_opt)
        # NOTE(review): the example ends here; building the dataloader and
        # appending it to test_loaders appears to have been truncated — confirm.
예제 #26
0
def main():
    """Run a trained RCGAN generator over the test set and save its outputs.

    Reads the option file given by ``-opt``, loads ``RCGAN_model.pth`` from
    ``opt['model_path']`` into ``solver.model['netG']``, and for every test
    batch writes a comparison grid to ``<model_path>/cmp/comp_<name>.bmp`` and
    the fused image to ``<model_path>/results/<name>.bmp``.

    Raises:
        ValueError: if the test dataset options are missing, the test loader
            could not be created, or ``model_path`` is not configured.
        FileNotFoundError: if ``<model_path>/RCGAN_model.pth`` does not exist.
    """
    parser = argparse.ArgumentParser(description='Test RCGAN model')
    parser.add_argument('-opt', type=str, required=True, help='Path to options JSON file.')
    opt = option.parse(parser.parse_args().opt)
    opt = option.dict_to_nonedict(opt)

    # create test dataloader
    dataset_opt = opt['datasets']['test']
    if dataset_opt is None:
        raise ValueError("test dataset_opt is None!")
    test_set = create_dataset(dataset_opt)
    test_loader = create_dataloader(test_set, dataset_opt)

    if test_loader is None:
        raise ValueError("The test data does not exist")

    # A missing key in the NoneDict yields None, which os.path.join would
    # reject with an opaque TypeError, so validate it explicitly first.
    model_dir = opt['model_path']
    if model_dir is None:
        raise ValueError("'model_path' is required.")

    solver = RCGANModel(opt)
    solver.model_pth = model_dir
    solver.results_dir = os.path.join(model_dir, 'results')
    solver.cmp_dir = os.path.join(model_dir, 'cmp')

    # load model
    model_pth = os.path.join(model_dir, 'RCGAN_model.pth')
    # os.path.join never returns None; check that the checkpoint file exists.
    if not os.path.isfile(model_pth):
        raise FileNotFoundError("Model checkpoint not found: %s" % model_pth)
    print('[Loading model from %s...]' % model_pth)
    model_dict = torch.load(model_pth)
    solver.model['netG'].load_state_dict(model_dict['state_dict_G'])

    print('=> Done.')
    print('[Start Testing]')

    os.makedirs(solver.cmp_dir, exist_ok=True)
    os.makedirs(solver.results_dir, exist_ok=True)

    test_bar = tqdm(test_loader)
    for batch in test_bar:
        solver.feed_data(batch)
        solver.test()
        visuals_list = solver.get_current_visual_list()  # fetch current iteration results as cpu tensor
        visuals = solver.get_current_visual()  # fetch current iteration results as cpu tensor
        img_name = os.path.splitext(os.path.basename(batch['VIS_path'][0]))[0]

        # Comparison grid of all current visuals, three per row.
        images = torch.stack(visuals_list)
        saveimg = thutil.make_grid(images, nrow=3, padding=5)
        saveimg_nd = saveimg.byte().permute(1, 2, 0).numpy()
        imageio.imwrite(os.path.join(solver.cmp_dir, 'comp_%s.bmp' % img_name), saveimg_nd)

        # Fused output: move the first axis last (presumably CHW -> HWC),
        # cast to uint8 and drop any size-1 axes before writing.
        fused_img = visuals['img_fuse']
        fused_img = np.transpose(util.quantize(fused_img).numpy(), (1, 2, 0)).astype(np.uint8).squeeze()
        # Write immediately rather than buffering every image in memory.
        imageio.imwrite(os.path.join(solver.results_dir, img_name + '.bmp'), fused_img)

    test_bar.close()
예제 #27
0
import gym
import torch
from util import NormalizedActions
from collections import deque
from agent import Agent
import numpy as np

from options import options


# NOTE(review): this rebinds the name `options` (the callable imported from the
# options module above) to its return value, so the original callable is no
# longer reachable under that name in this script.
options = options()

# Parsed command-line options; `batch` is read from them (presumably the
# training batch size — confirm against options.py).
opts = options.parse()
batch = opts.batch

# Gym BipedalWalker-v2 environment wrapped in NormalizedActions (from util;
# presumably rescales actions to the env's bounds — confirm).
# NOTE(review): the '-v2' id may not be registered in newer gym releases — verify.
env = NormalizedActions(gym.make('BipedalWalker-v2'))

# Notebook display/plotting helpers used by plot().
from IPython.display import clear_output
import matplotlib.pyplot as plt

# Agent constructed for this environment.
policy = Agent(env)
def plot(frame_idx, rewards):
    """Redraw the reward curve, replacing the previous notebook output.

    Args:
        frame_idx: value shown as the episode number in the title.
        rewards: sequence of rewards to plot; its last element is shown in
            the title, so it must be non-empty.
    """
    clear_output(True)
    plt.figure(figsize=(20, 5))
    # First cell of a 1x3 grid (same as subplot(131)).
    plt.subplot(1, 3, 1)
    latest_reward = rewards[-1]
    plt.title('Episode %s. reward: %s' % (frame_idx, latest_reward))
    plt.plot(rewards)
    plt.show()

# Reward history passed to plot(); presumably filled by a training loop that
# follows this snippet (not shown here) — confirm.
rewards = []
예제 #28
0
def main():
    """Train a model from a YAML option file, single-process or distributed.

    Command-line arguments:
        -opt: path to the option YAML file.
        --launcher: 'none' for single-process training; 'pytorch' calls
            ``init_dist()`` and trains with torch.distributed.
        --local_rank: accepted for launcher compatibility; not read here.

    On rank <= 0 this creates experiment folders and loggers (text, optional
    TensorBoard, optional Weights & Biases). It then builds the train/val
    loaders and optionally resumes from ``opt['path']['resume_state']``. The
    training loop does periodic logging, validation and checkpointing, and a
    final 'latest' model is saved at the end.
    """
    #### options
    parser = argparse.ArgumentParser()
    parser.add_argument('-opt', type=str, help='Path to option YAML file.')
    parser.add_argument('--launcher',
                        choices=['none', 'pytorch'],
                        default='none',
                        help='job launcher')
    parser.add_argument('--local_rank', type=int, default=0)
    args = parser.parse_args()
    opt = option.parse(args.opt, is_train=True)

    #### distributed training settings
    if args.launcher == 'none':  # disabled distributed training
        opt['dist'] = False
        rank = -1
        world_size = 1  # only read on distributed paths; defined for clarity
        print('Disabled distributed training.')
    else:
        opt['dist'] = True
        init_dist()
        world_size = torch.distributed.get_world_size()
        rank = torch.distributed.get_rank()

    #### loading resume state if exists
    if opt['path'].get('resume_state', None):
        # distributed resuming: all load into default GPU
        device_id = torch.cuda.current_device()
        resume_state = torch.load(
            opt['path']['resume_state'],
            map_location=lambda storage, loc: storage.cuda(device_id))
        option.check_resume(opt, resume_state['iter'])  # check resume options
    else:
        resume_state = None

    #### mkdir and loggers
    if rank <= 0:  # normal training (rank -1) OR distributed training (rank 0)
        if resume_state is None:
            util.mkdir_and_rename(
                opt['path']
                ['experiments_root'])  # rename experiment folder if exists
            util.mkdirs(
                (path for key, path in opt['path'].items() if
                 not key == 'experiments_root' and 'pretrain_model' not in key
                 and 'resume' not in key and 'wandb_load_run_path' not in key))

        # config loggers. Before it, the log will not work
        util.setup_logger('base',
                          opt['path']['log'],
                          'train_' + opt['name'],
                          level=logging.INFO,
                          screen=True,
                          tofile=True)
        logger = logging.getLogger('base')
        logger.info(option.dict2str(opt))
        # tensorboard logger
        if opt['use_tb_logger'] and 'debug' not in opt['name']:
            version = float(torch.__version__[0:3])
            if version >= 1.1:  # PyTorch 1.1
                from torch.utils.tensorboard import SummaryWriter
            else:
                logger.info(
                    'You are using PyTorch {}. Tensorboard will use [tensorboardX]'
                    .format(version))
                from tensorboardX import SummaryWriter
            tb_logger = SummaryWriter(log_dir='../tb_logger/' + opt['name'])
        if opt['use_wandb_logger'] and 'debug' not in opt['name']:
            json_path = os.path.join(os.path.expanduser('~'),
                                     '.wandb_api_keys.json')
            if os.path.exists(json_path):
                with open(json_path, 'r') as j:
                    json_file = json.load(j)
                # NOTE(review): the key is looked up under a hard-coded user
                # name; other users' key files will raise KeyError here.
                os.environ['WANDB_API_KEY'] = json_file['ryul99']
            wandb.init(project="mmsr", config=opt, sync_tensorboard=True)
    else:
        util.setup_logger('base',
                          opt['path']['log'],
                          'train',
                          level=logging.INFO,
                          screen=True)
        logger = logging.getLogger('base')

    # convert to NoneDict, which returns None for missing keys
    opt = option.dict_to_nonedict(opt)

    #### random seed
    seed = opt['train']['manual_seed']
    if seed is None:
        seed = random.randint(1, 10000)
    if rank <= 0:
        logger.info('Random seed: {}'.format(seed))
        if opt['use_wandb_logger'] and 'debug' not in opt['name']:
            wandb.config.update({'random_seed': seed})
    util.set_random_seed(seed)

    torch.backends.cudnn.benchmark = True
    # torch.backends.cudnn.deterministic = True

    #### create train and val dataloader
    dataset_ratio = 200  # enlarge the size of each epoch
    # Initialised so a missing 'train' phase trips the assert below instead
    # of raising NameError.
    train_loader = None
    for phase, dataset_opt in opt['datasets'].items():
        if phase == 'train':
            train_set = create_dataset(dataset_opt)
            train_size = int(
                math.ceil(len(train_set) / dataset_opt['batch_size']))
            total_iters = int(opt['train']['niter'])
            total_epochs = int(math.ceil(total_iters / train_size))
            if opt['dist']:
                train_sampler = DistIterSampler(train_set, world_size, rank,
                                                dataset_ratio)
                total_epochs = int(
                    math.ceil(total_iters / (train_size * dataset_ratio)))
            else:
                train_sampler = None
            train_loader = create_dataloader(train_set, dataset_opt, opt,
                                             train_sampler)
            if rank <= 0:
                logger.info(
                    'Number of train images: {:,d}, iters: {:,d}'.format(
                        len(train_set), train_size))
                logger.info('Total epochs needed: {:d} for iters {:,d}'.format(
                    total_epochs, total_iters))
        elif phase == 'val':
            val_set = create_dataset(dataset_opt)
            val_loader = create_dataloader(val_set, dataset_opt, opt, None)
            if rank <= 0:
                logger.info('Number of val images in [{:s}]: {:d}'.format(
                    dataset_opt['name'], len(val_set)))
        else:
            raise NotImplementedError(
                'Phase [{:s}] is not recognized.'.format(phase))
    assert train_loader is not None, 'No train dataset configured.'

    #### create model
    model = create_model(opt)

    #### resume training
    if resume_state:
        logger.info('Resuming training from epoch: {}, iter: {}.'.format(
            resume_state['epoch'], resume_state['iter']))

        start_epoch = resume_state['epoch']
        current_step = resume_state['iter']
        model.resume_training(resume_state)  # handle optimizers and schedulers
    else:
        current_step = 0
        start_epoch = 0

    #### training
    logger.info('Start training from epoch: {:d}, iter: {:d}'.format(
        start_epoch, current_step))
    for epoch in range(start_epoch, total_epochs + 1):
        if opt['dist']:
            train_sampler.set_epoch(epoch)
        for train_data in train_loader:
            current_step += 1
            if current_step > total_iters:
                break
            #### update learning rate
            model.update_learning_rate(current_step,
                                       warmup_iter=opt['train']['warmup_iter'])

            #### training
            model.feed_data(train_data,
                            noise_mode=opt['datasets']['train']['noise_mode'],
                            noise_rate=opt['datasets']['train']['noise_rate'])
            model.optimize_parameters(current_step)

            #### log
            if current_step % opt['logger']['print_freq'] == 0:
                logs = model.get_current_log()
                message = '[epoch:{:3d}, iter:{:8,d}, lr:('.format(
                    epoch, current_step)
                for v in model.get_current_learning_rate():
                    message += '{:.3e},'.format(v)
                message += ')] '
                for k, v in logs.items():
                    message += '{:s}: {:.4e} '.format(k, v)
                    # tensorboard logger
                    if opt['use_tb_logger'] and 'debug' not in opt['name']:
                        if rank <= 0:
                            tb_logger.add_scalar(k, v, current_step)
                    if opt['use_wandb_logger'] and 'debug' not in opt['name']:
                        if rank <= 0:
                            wandb.log({k: v}, step=current_step)
                if rank <= 0:
                    logger.info(message)
            #### validation
            if opt['datasets'].get(
                    'val',
                    None) and current_step % opt['train']['val_freq'] == 0:
                if opt['model'] in ['sr', 'srgan']:
                    # Image restoration validation does not support multi-GPU
                    # validation: only rank <= 0 runs it. Other ranks skip it
                    # and must NOT fall through to the video branch, whose
                    # reduce/barrier rank 0 would never join.
                    if rank <= 0:
                        pbar = util.ProgressBar(len(val_loader))
                        avg_psnr = 0.
                        idx = 0
                        for val_data in val_loader:
                            idx += 1
                            img_name = os.path.splitext(
                                os.path.basename(val_data['LQ_path'][0]))[0]
                            img_dir = os.path.join(opt['path']['val_images'],
                                                   img_name)
                            util.mkdir(img_dir)

                            model.feed_data(
                                val_data,
                                noise_mode=opt['datasets']['val']['noise_mode'],
                                noise_rate=opt['datasets']['val']['noise_rate'])
                            model.test()

                            visuals = model.get_current_visuals()
                            sr_img = util.tensor2img(visuals['rlt'])  # uint8
                            gt_img = util.tensor2img(visuals['GT'])  # uint8

                            # Save SR images for reference
                            save_img_path = os.path.join(
                                img_dir,
                                '{:s}_{:d}.png'.format(img_name, current_step))
                            util.save_img(sr_img, save_img_path)

                            # calculate PSNR
                            sr_img, gt_img = util.crop_border([sr_img, gt_img],
                                                              opt['scale'])
                            avg_psnr += util.calculate_psnr(sr_img, gt_img)
                            pbar.update('Test {}'.format(img_name))

                        avg_psnr = avg_psnr / idx

                        # log
                        logger.info(
                            '# Validation # PSNR: {:.4e}'.format(avg_psnr))
                        # tensorboard logger
                        if opt['use_tb_logger'] and 'debug' not in opt['name']:
                            tb_logger.add_scalar('psnr', avg_psnr,
                                                 current_step)
                        if opt['use_wandb_logger'] and 'debug' not in opt[
                                'name']:
                            wandb.log({'psnr': avg_psnr}, step=current_step)
                else:  # video restoration validation
                    if opt['dist']:
                        # multi-GPU testing: each rank handles every
                        # world_size-th sample, then results are summed on rank 0
                        psnr_rlt = {}  # with border and center frames
                        if rank == 0:
                            pbar = util.ProgressBar(len(val_set))
                        for idx in range(rank, len(val_set), world_size):
                            val_data = val_set[idx]
                            val_data['LQs'].unsqueeze_(0)
                            val_data['GT'].unsqueeze_(0)
                            folder = val_data['folder']
                            idx_d, max_idx = val_data['idx'].split('/')
                            idx_d, max_idx = int(idx_d), int(max_idx)
                            if psnr_rlt.get(folder, None) is None:
                                psnr_rlt[folder] = torch.zeros(
                                    max_idx,
                                    dtype=torch.float32,
                                    device='cuda')
                            # tmp = torch.zeros(max_idx, dtype=torch.float32, device='cuda')
                            model.feed_data(val_data,
                                            noise_mode=opt['datasets']['val']
                                            ['noise_mode'],
                                            noise_rate=opt['datasets']['val']
                                            ['noise_rate'])
                            model.test()
                            visuals = model.get_current_visuals()
                            rlt_img = util.tensor2img(visuals['rlt'])  # uint8
                            gt_img = util.tensor2img(visuals['GT'])  # uint8
                            # calculate PSNR
                            psnr_rlt[folder][idx_d] = util.calculate_psnr(
                                rlt_img, gt_img)

                            if rank == 0:
                                for _ in range(world_size):
                                    pbar.update('Test {} - {}/{}'.format(
                                        folder, idx_d, max_idx))
                        # # collect data
                        # NOTE(review): this assumes every rank created the
                        # same folders in the same order; a folder some rank
                        # never visited would make the reduce calls mismatch.
                        for v in psnr_rlt.values():
                            dist.reduce(v, 0)
                        dist.barrier()

                        if rank == 0:
                            psnr_rlt_avg = {}
                            psnr_total_avg = 0.
                            for k, v in psnr_rlt.items():
                                psnr_rlt_avg[k] = torch.mean(v).cpu().item()
                                psnr_total_avg += psnr_rlt_avg[k]
                            psnr_total_avg /= len(psnr_rlt)
                            log_s = '# Validation # PSNR: {:.4e}:'.format(
                                psnr_total_avg)
                            for k, v in psnr_rlt_avg.items():
                                log_s += ' {}: {:.4e}'.format(k, v)
                            logger.info(log_s)
                            if opt['use_tb_logger'] and 'debug' not in opt[
                                    'name']:
                                tb_logger.add_scalar('psnr_avg',
                                                     psnr_total_avg,
                                                     current_step)
                                for k, v in psnr_rlt_avg.items():
                                    tb_logger.add_scalar(k, v, current_step)
                            if opt['use_wandb_logger'] and 'debug' not in opt[
                                    'name']:
                                lq_img, rlt_img, gt_img = map(
                                    util.tensor2img, [
                                        visuals['LQ'], visuals['rlt'],
                                        visuals['GT']
                                    ])
                                wandb.log({'psnr_avg': psnr_total_avg},
                                          step=current_step)
                                wandb.log(psnr_rlt_avg, step=current_step)
                                wandb.log(
                                    {
                                        'Validation Image': [
                                            wandb.Image(lq_img[:, :,
                                                               [2, 1, 0]],
                                                        caption='LQ'),
                                            wandb.Image(rlt_img[:, :,
                                                                [2, 1, 0]],
                                                        caption='output'),
                                            wandb.Image(gt_img[:, :,
                                                               [2, 1, 0]],
                                                        caption='GT'),
                                        ]
                                    },
                                    step=current_step)
                    else:
                        pbar = util.ProgressBar(len(val_loader))
                        psnr_rlt = {}  # with border and center frames
                        psnr_rlt_avg = {}
                        psnr_total_avg = 0.
                        for val_data in val_loader:
                            folder = val_data['folder'][0]
                            idx_d = val_data['idx'].item()
                            # border = val_data['border'].item()
                            if psnr_rlt.get(folder, None) is None:
                                psnr_rlt[folder] = []

                            model.feed_data(val_data,
                                            noise_mode=opt['datasets']['val']
                                            ['noise_mode'],
                                            noise_rate=opt['datasets']['val']
                                            ['noise_rate'])
                            model.test()
                            visuals = model.get_current_visuals()
                            rlt_img = util.tensor2img(visuals['rlt'])  # uint8
                            gt_img = util.tensor2img(visuals['GT'])  # uint8

                            # calculate PSNR
                            psnr = util.calculate_psnr(rlt_img, gt_img)
                            psnr_rlt[folder].append(psnr)
                            pbar.update('Test {} - {}'.format(folder, idx_d))
                        for k, v in psnr_rlt.items():
                            psnr_rlt_avg[k] = sum(v) / len(v)
                            psnr_total_avg += psnr_rlt_avg[k]
                        psnr_total_avg /= len(psnr_rlt)
                        log_s = '# Validation # PSNR: {:.4e}:'.format(
                            psnr_total_avg)
                        for k, v in psnr_rlt_avg.items():
                            log_s += ' {}: {:.4e}'.format(k, v)
                        logger.info(log_s)
                        if opt['use_tb_logger'] and 'debug' not in opt['name']:
                            tb_logger.add_scalar('psnr_avg', psnr_total_avg,
                                                 current_step)
                            for k, v in psnr_rlt_avg.items():
                                tb_logger.add_scalar(k, v, current_step)
                        if opt['use_wandb_logger'] and 'debug' not in opt[
                                'name']:
                            lq_img, rlt_img, gt_img = map(
                                util.tensor2img,
                                [visuals['LQ'], visuals['rlt'], visuals['GT']])
                            wandb.log({'psnr_avg': psnr_total_avg},
                                      step=current_step)
                            wandb.log(psnr_rlt_avg, step=current_step)
                            wandb.log(
                                {
                                    'Validation Image': [
                                        wandb.Image(lq_img[:, :, [2, 1, 0]],
                                                    caption='LQ'),
                                        wandb.Image(rlt_img[:, :, [2, 1, 0]],
                                                    caption='output'),
                                        wandb.Image(gt_img[:, :, [2, 1, 0]],
                                                    caption='GT'),
                                    ]
                                },
                                step=current_step)

            #### save models and training states
            if current_step % opt['logger']['save_checkpoint_freq'] == 0:
                if rank <= 0:
                    logger.info('Saving models and training states.')
                    model.save(current_step)
                    model.save_training_state(epoch, current_step)

    if rank <= 0:
        logger.info('Saving the final model.')
        model.save('latest')
        logger.info('End of training.')
        if opt['use_tb_logger'] and 'debug' not in opt['name']:
            tb_logger.close()
def main():
    """Run tiled EDVR super-resolution inference over folders of frame sequences.

    Command-line arguments select the input folder (one sub-folder per
    sequence), the output folder, the checkpoint, the option YAML describing
    the network (``network_G`` section), the screen-change notation JSON and
    how sequence sub-folders are sharded across processes: the sorted list is
    split into ``--gpu_number`` contiguous shards and this process handles
    shard ``--gpu_index``.

    Each frame is restored from ``N_in`` neighbouring frames, whose indices
    come from ``data_util.index_generation_process_screen_change_withlog_fixbug``.
    The frame stack is replication-padded so it splits into whole
    ``split_lengthY x split_lengthX`` patches. Each patch is fed to the model
    with an extra ``PAD``-pixel context margin on every side. The de-padded
    patch outputs are stitched into one image, cropped to ``scale`` times the
    input size, and written as ``<frame>.png``.
    """
    import json  # only needed to read the screen-change notation file

    #################
    # configurations
    #################
    parser = argparse.ArgumentParser()
    parser.add_argument("--input_path", type=str, required=True)
    parser.add_argument("--output_path", type=str, required=True)
    parser.add_argument("--model_path", type=str, required=True)
    parser.add_argument("--gpu_id", type=str, required=True)
    parser.add_argument("--gpu_number", type=str, required=True)
    parser.add_argument("--gpu_index", type=str, required=True)
    parser.add_argument("--screen_notation", type=str, required=True)
    parser.add_argument('--opt',
                        type=str,
                        required=True,
                        help='Path to option YAML file.')
    args = parser.parse_args()
    opt = option.parse(args.opt, is_train=False)

    gpu_number = int(args.gpu_number)
    gpu_index = int(args.gpu_index)

    PAD = 32  # context margin (LR pixels) added on every side of a patch

    total_run_time = AverageMeter()

    # NOTE: CUDA_VISIBLE_DEVICES only takes effect if set before CUDA is
    # initialised in this process; the model is moved to the GPU further below.
    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu_id)
    print('export CUDA_VISIBLE_DEVICES=' + str(args.gpu_id))
    device = torch.device('cuda')

    data_mode = 'sharp_bicubic'  # Vid4 | sharp_bicubic | blur_bicubic | blur | blur_comp
    # Vid4: SR
    # REDS4: sharp_bicubic (SR-clean), blur_bicubic (SR-blur);
    #        blur (deblur-clean), blur_comp (deblur-compression).
    stage = 1  # 1 or 2, use two stage strategy for REDS dataset.
    flip_test = False

    input_folder = args.input_path
    result_folder = args.output_path
    model_path = args.model_path

    # create results folder
    os.makedirs(result_folder, exist_ok=True)

    # number of neighbouring LR frames used to restore one HR frame
    N_in = 7 if data_mode == 'Vid4' else 5

    #### model: architecture hyper-parameters come from the option file
    net_opt = opt['network_G']
    model = EDVR_arch.EDVR(nf=net_opt['nf'],
                           nframes=net_opt['nframes'],
                           groups=net_opt['groups'],
                           front_RBs=net_opt['front_RBs'],
                           back_RBs=net_opt['back_RBs'],
                           predeblur=net_opt['predeblur'],
                           HR_in=net_opt['HR_in'],
                           w_TSA=net_opt['w_TSA'])

    #### dataset
    if data_mode == 'Vid4':
        test_dataset_folder = '../datasets/Vid4/BIx4'
    elif stage == 1:
        test_dataset_folder = input_folder
    else:
        test_dataset_folder = '../results/REDS-EDVR_REDS_SR_L_flipx4'
        print('You should modify the test_dataset_folder path for stage 2')

    #### evaluation
    # temporal padding mode passed to the frame-index generator
    if data_mode in ('Vid4', 'sharp_bicubic'):
        padding = 'new_info'
    else:
        padding = 'replicate'
    save_imgs = True

    save_folder = os.path.join(result_folder, data_mode)
    util.mkdirs(save_folder)
    util.setup_logger('base',
                      save_folder,
                      'test',
                      level=logging.INFO,
                      screen=True,
                      tofile=True)
    logger = logging.getLogger('base')

    #### log info
    logger.info('Data: {} - {}'.format(data_mode, test_dataset_folder))
    logger.info('Padding mode: {}'.format(padding))
    logger.info('Model path: {}'.format(model_path))
    logger.info('Save images: {}'.format(save_imgs))
    logger.info('Flip test: {}'.format(flip_test))

    #### set up the models
    model.load_state_dict(torch.load(model_path), strict=True)
    model.eval()
    model = model.to(device)

    subfolder_l = sorted(glob.glob(osp.join(test_dataset_folder, '*')))

    # load screen change notation
    with open(args.screen_notation) as f:
        frame_notation = json.load(f)

    # keep only this process's contiguous shard of the sequence folders
    subfolder_n = len(subfolder_l)
    shard_start = int(subfolder_n * gpu_index / gpu_number)
    shard_end = int(subfolder_n * (gpu_index + 1) / gpu_number)
    subfolder_l = subfolder_l[shard_start:shard_end]

    #### tiling parameters (loop-invariant)
    split_lengthY = 180  # patch height in LR pixels
    split_lengthX = 320  # patch width in LR pixels
    # HR/LR size ratio: output coordinates below are input coordinates * scale
    scale = 4
    # constant context margin around the whole (already patch-aligned) frame
    pader = torch.nn.ReplicationPad2d([PAD, PAD, PAD, PAD])

    for subfolder in subfolder_l:
        input_subfolder = os.path.split(subfolder)[1]
        print("Evaluate Folders: ", input_subfolder)

        save_subfolder = osp.join(save_folder, input_subfolder)

        img_path_l = sorted(glob.glob(osp.join(subfolder, '*')))
        max_idx = len(img_path_l)
        if save_imgs:
            util.mkdirs(save_subfolder)

        #### read LQ images
        imgs_LQ = data_util.read_img_seq(subfolder)  # Num x 3 x H x W

        # process each image
        for img_idx, img_path in enumerate(img_path_l):
            # per-frame timer (the original measured cumulative time since start)
            frame_start = time.time()

            img_name = osp.splitext(osp.basename(img_path))[0]

            # pick the N_in neighbour frames, respecting screen changes
            select_idx, log1, log2, nota = data_util.index_generation_process_screen_change_withlog_fixbug(
                input_subfolder,
                frame_notation,
                img_idx,
                max_idx,
                N_in,
                padding=padding)

            if log1 is not None:
                logger.info('screen change')
                logger.info(nota)
                logger.info(log1)
                logger.info(log2)

            imgs_in = imgs_LQ.index_select(
                0, torch.LongTensor(select_idx)).to(device)  # N C H W

            intHeight_ori = imgs_in.shape[-2]
            intWidth_ori = imgs_in.shape[-1]

            # extra bottom/right padding so the frame splits into whole patches
            intPaddingBottom_ = (-intHeight_ori) % split_lengthY
            intPaddingRight_ = (-intWidth_ori) % split_lengthX

            pader0 = torch.nn.ReplicationPad2d(
                [0, intPaddingRight_, 0, intPaddingBottom_])
            print("Init pad right/bottom " + str(intPaddingRight_) + " / " +
                  str(intPaddingBottom_))

            # N C (H + padB + 2*PAD) (W + padR + 2*PAD)
            imgs_in = pader(pader0(imgs_in))

            split_numY = (intHeight_ori + intPaddingBottom_) // split_lengthY
            split_numX = (intWidth_ori + intPaddingRight_) // split_lengthX

            # stitched HR canvas for the patch-aligned frame (BGR, HWC)
            y_all = np.zeros((split_numY * split_lengthY * scale,
                              split_numX * split_lengthX * scale, 3),
                             dtype=np.uint8)

            with torch.no_grad():
                for split_j, split_i in itertools.product(
                        range(split_numY), range(split_numX)):
                    top = split_j * split_lengthY
                    left = split_i * split_lengthX

                    # patch plus PAD context on every side (padded coordinates)
                    X0 = imgs_in[:, :, top:top + split_lengthY + 2 * PAD,
                                 left:left + split_lengthX + 2 * PAD]
                    X0 = torch.unsqueeze(X0, 0)  # N C H W -> 1 N C H W

                    if flip_test:
                        output = util.flipx4_forward(model, X0)
                    else:
                        output = util.single_forward(model, X0)

                    # drop the context margin, expressed in HR pixels
                    output_depadded = output[0, :,
                                             PAD * scale:(PAD + split_lengthY) * scale,
                                             PAD * scale:(PAD + split_lengthX) * scale]
                    output_depadded = output_depadded.squeeze(0)
                    patch = util.tensor2img(output_depadded)

                    y_all[top * scale:(top + split_lengthY) * scale,
                          left * scale:(left + split_lengthX) * scale, :] = \
                        np.round(patch).astype(np.uint8)

            # remove the HR counterpart of the bottom/right alignment padding
            y_all = np.ascontiguousarray(
                y_all[:intHeight_ori * scale, :intWidth_ori * scale])

            if save_imgs:
                cv2.imwrite(
                    osp.join(save_subfolder, '{}.png'.format(img_name)), y_all)

            frame_time = time.time() - frame_start
            print("*****************current image process time \t " +
                  str(frame_time) + "s ******************")
            total_run_time.update(frame_time, 1)

            logger.info('{} : {:3d} - {:25} \t'.format(input_subfolder,
                                                       img_idx + 1, img_name))

    logger.info('Data: {} - {}'.format(data_mode, test_dataset_folder))
    logger.info('Padding mode: {}'.format(padding))
    logger.info('Model path: {}'.format(model_path))
    logger.info('Save images: {}'.format(save_imgs))
    logger.info('Flip test: {}'.format(flip_test))
예제 #30
0
def main():
    """Jointly train the IKC Predictor and Corrector networks.

    Three option YAML files are parsed:
      * ``-opt_P`` configures the Predictor.
      * ``-opt_C`` configures the Corrector.
      * ``-opt_F`` configures SFTMD; only its ``sftmd`` sub-section is used.

    Per training iteration:
      * Random blur kernels are applied to ``GT`` to produce LR images and
        PCA-encoded kernel maps.
      * The Predictor is optimised to estimate the kernel map.
      * For ``opt_C['step']`` rounds, SFTMD (``model_F``, test mode only)
        produces an SR image from the current estimate, and the Corrector is
        optimised to refine that estimate.

    Validation, checkpointing and TensorBoard logging happen only on rank <= 0.
    """
    #### setup options of three networks
    parser = argparse.ArgumentParser()
    parser.add_argument('-opt_P',
                        type=str,
                        help='Path to option YMAL file of Predictor.')
    parser.add_argument('-opt_C',
                        type=str,
                        help='Path to option YMAL file of Corrector.')
    parser.add_argument('-opt_F',
                        type=str,
                        help='Path to option YMAL file of SFTMD_Net.')
    parser.add_argument('--launcher',
                        choices=['none', 'pytorch'],
                        default='none',
                        help='job launcher')
    parser.add_argument('--local_rank', type=int, default=0)
    args = parser.parse_args()
    opt_P = option.parse(args.opt_P, is_train=True)
    opt_C = option.parse(args.opt_C, is_train=True)
    opt_F = option.parse(args.opt_F, is_train=True)

    # convert to NoneDict, which returns None for missing keys
    opt_P = option.dict_to_nonedict(opt_P)
    opt_C = option.dict_to_nonedict(opt_C)
    opt_F = option.dict_to_nonedict(opt_F)

    # choose small opt for SFTMD test, fill path of pre-trained model_F
    opt_F = opt_F['sftmd']

    # create PCA matrix of enough kernel
    batch_ker = util.random_batch_kernel(batch=30000,
                                         l=opt_P['kernel_size'],
                                         sig_min=0.2,
                                         sig_max=4.0,
                                         rate_iso=1.0,
                                         scaling=3,
                                         tensor=False)
    print('batch kernel shape: {}'.format(batch_ker.shape))
    b = np.size(batch_ker, 0)
    batch_ker = batch_ker.reshape((b, -1))
    pca_matrix = util.PCA(batch_ker, k=opt_P['code_length']).float()
    print('PCA matrix shape: {}'.format(pca_matrix.shape))

    def make_prepro():
        """Build the degradation pipeline that turns GT into (LR image, kernel map).

        A fresh instance is built per call, as the original code did at every
        training and validation iteration.
        """
        return util.SRMDPreprocessing(opt_P['scale'],
                                      pca_matrix,
                                      para_input=opt_P['code_length'],
                                      kernel=opt_P['kernel_size'],
                                      noise=False,
                                      cuda=True,
                                      sig_min=0.2,
                                      sig_max=4.0,
                                      rate_iso=1.0,
                                      scaling=3,
                                      rate_cln=0.2,
                                      noise_high=0.0)

    #### distributed training settings
    if args.launcher == 'none':  # disabled distributed training
        opt_P['dist'] = False
        opt_F['dist'] = False
        opt_C['dist'] = False
        rank = -1
        print('Disabled distributed training.')
    else:
        opt_P['dist'] = True
        opt_F['dist'] = True
        opt_C['dist'] = True
        init_dist()
        # number of processes in the current process group
        world_size = torch.distributed.get_world_size()
        # rank of this process within the group
        rank = torch.distributed.get_rank()

    torch.backends.cudnn.benchmark = True
    # torch.backends.cudnn.deterministic = True

    ###### Predictor&Corrector train ######

    #### loading resume state if exists
    if opt_P['path'].get('resume_state', None):
        # distributed resuming: all load into default GPU
        device_id = torch.cuda.current_device()
        resume_state = torch.load(
            opt_P['path']['resume_state'],
            map_location=lambda storage, loc: storage.cuda(device_id))
        option.check_resume(opt_P,
                            resume_state['iter'])  # check resume options
    else:
        resume_state = None

    # Stays None unless TensorBoard logging is enabled on rank <= 0; every use
    # below checks for None so disabled/non-main ranks never touch it.
    tb_logger = None

    #### mkdir and loggers
    if rank <= 0:  # normal training (rank -1) OR distributed training (rank 0-7)
        if resume_state is None:
            # Predictor and Corrector experiment folders
            for opt_net in (opt_P, opt_C):
                util.mkdir_and_rename(
                    opt_net['path']
                    ['experiments_root'])  # rename experiment folder if exists
                util.mkdirs(
                    (path for key, path in opt_net['path'].items()
                     if not key == 'experiments_root'
                     and 'pretrain_model' not in key and 'resume' not in key))

        # config loggers. Before it, the log will not work
        util.setup_logger('base',
                          opt_P['path']['log'],
                          'train_' + opt_P['name'],
                          level=logging.INFO,
                          screen=True,
                          tofile=True)
        util.setup_logger('val',
                          opt_P['path']['log'],
                          'val_' + opt_P['name'],
                          level=logging.INFO,
                          screen=True,
                          tofile=True)
        logger = logging.getLogger('base')
        logger.info(option.dict2str(opt_P))
        logger.info(option.dict2str(opt_C))
        # tensorboard logger
        if opt_P['use_tb_logger'] and 'debug' not in opt_P['name']:
            version = float(torch.__version__[0:3])
            if version >= 1.1:  # PyTorch 1.1
                from torch.utils.tensorboard import SummaryWriter
            else:
                logger.info(
                    'You are using PyTorch {}. Tensorboard will use [tensorboardX]'
                    .format(version))
                from tensorboardX import SummaryWriter
            tb_logger = SummaryWriter(log_dir='../tb_logger/' + opt_P['name'])
    else:
        util.setup_logger('base',
                          opt_P['path']['log'],
                          'train',
                          level=logging.INFO,
                          screen=True)
        logger = logging.getLogger('base')

    #### random seed
    seed = opt_P['train']['manual_seed']
    if seed is None:
        seed = random.randint(1, 10000)
    if rank <= 0:
        logger.info('Random seed: {}'.format(seed))
    util.set_random_seed(seed)

    #### create train and val dataloader
    dataset_ratio = 200  # enlarge the size of each epoch
    train_loader = None
    val_loader = None
    for phase, dataset_opt in opt_P['datasets'].items():
        if phase == 'train':
            train_set = create_dataset(dataset_opt)
            train_size = int(
                math.ceil(len(train_set) / dataset_opt['batch_size']))
            total_iters = int(opt_P['train']['niter'])
            total_epochs = int(math.ceil(total_iters / train_size))
            if opt_P['dist']:
                train_sampler = DistIterSampler(train_set, world_size, rank,
                                                dataset_ratio)
                total_epochs = int(
                    math.ceil(total_iters / (train_size * dataset_ratio)))
            else:
                train_sampler = None
            train_loader = create_dataloader(train_set, dataset_opt, opt_P,
                                             train_sampler)
            if rank <= 0:
                logger.info(
                    'Number of train images: {:,d}, iters: {:,d}'.format(
                        len(train_set), train_size))
                logger.info('Total epochs needed: {:d} for iters {:,d}'.format(
                    total_epochs, total_iters))
        elif phase == 'val':
            val_set = create_dataset(dataset_opt)
            val_loader = create_dataloader(val_set, dataset_opt, opt_P, None)
            if rank <= 0:
                logger.info('Number of val images in [{:s}]: {:d}'.format(
                    dataset_opt['name'], len(val_set)))
        else:
            raise NotImplementedError(
                'Phase [{:s}] is not recognized.'.format(phase))
    assert train_loader is not None
    assert val_loader is not None

    #### create model
    model_F = create_model(opt_F)  #load pretrained model of SFTMD
    model_P = create_model(opt_P)
    model_C = create_model(opt_C)

    #### resume training
    if resume_state:
        logger.info('Resuming training from epoch: {}, iter: {}.'.format(
            resume_state['epoch'], resume_state['iter']))

        start_epoch = resume_state['epoch']
        current_step = resume_state['iter']
        # handle optimizers and schedulers
        # NOTE(review): only model_P's training state is restored, although
        # model_C's training state is also saved below; confirm this is intended.
        model_P.resume_training(resume_state)
    else:
        current_step = 0
        start_epoch = 0

    #### training
    logger.info('Start training from epoch: {:d}, iter: {:d}'.format(
        start_epoch, current_step))
    for epoch in range(start_epoch, total_epochs + 1):
        if opt_P['dist']:
            train_sampler.set_epoch(epoch)
        for train_data in train_loader:
            current_step += 1
            if current_step > total_iters:
                break

            #### preprocessing for LR_img and kernel map
            LR_img, ker_map = make_prepro()(train_data['GT'])

            #### training Predictor
            model_P.feed_data(LR_img, ker_map)
            model_P.optimize_parameters(current_step)
            P_visuals = model_P.get_current_visuals()
            est_ker_map = P_visuals['Batch_est_ker_map']

            #### log of model_P
            if current_step % opt_P['logger']['print_freq'] == 0:
                logs = model_P.get_current_log()
                message = 'Predictor <epoch:{:3d}, iter:{:8,d}, lr:{:.3e}> '.format(
                    epoch, current_step, model_P.get_current_learning_rate())
                for k, v in logs.items():
                    message += '{:s}: {:.4e} '.format(k, v)
                    # tensorboard logger
                    if tb_logger is not None:
                        tb_logger.add_scalar(k, v, current_step)
                if rank <= 0:
                    logger.info(message)

            #### training Corrector
            for step in range(opt_C['step']):
                # Test SFTMD to produce SR images from the current estimate
                model_F.feed_data(train_data, LR_img, est_ker_map)
                model_F.test()
                F_visuals = model_F.get_current_visuals()
                SR_img = F_visuals['Batch_SR']

                # train corrector given SR image and estimated kernel map
                model_C.feed_data(SR_img, est_ker_map, ker_map)
                model_C.optimize_parameters(current_step)
                C_visuals = model_C.get_current_visuals()
                est_ker_map = C_visuals['Batch_est_ker_map']

                #### log of model_C
                if current_step % opt_C['logger']['print_freq'] == 0:
                    logs = model_C.get_current_log()
                    message = 'Corrector <epoch:{:3d}, iter:{:8,d}, lr:{:.3e}> '.format(
                        epoch, current_step,
                        model_C.get_current_learning_rate())
                    for k, v in logs.items():
                        message += '{:s}: {:.4e} '.format(k, v)
                        # tensorboard logger (only exists if opt_P enabled it)
                        if (tb_logger is not None and opt_C['use_tb_logger']
                                and 'debug' not in opt_C['name']):
                            tb_logger.add_scalar(k, v, current_step)
                    if rank <= 0:
                        logger.info(message)

            # validation, to produce ker_map_list(fake)
            if current_step % opt_P['train']['val_freq'] == 0 and rank <= 0:
                avg_psnr = 0.0
                idx = 0  # number of (image, corrector step) PSNR samples
                step = 0
                for val_data in val_loader:
                    LR_img, ker_map = make_prepro()(val_data['GT'])
                    single_img_psnr = 0.0

                    # valid Predictor
                    model_P.feed_data(LR_img, ker_map)
                    model_P.test()
                    P_visuals = model_P.get_current_visuals()
                    est_ker_map = P_visuals['Batch_est_ker_map']

                    # corrector steps are numbered from 1 in logs/file names
                    for step in range(1, opt_C['step'] + 1):
                        idx += 1
                        # Test SFTMD to produce SR images
                        model_F.feed_data(val_data, LR_img, est_ker_map)
                        model_F.test()
                        F_visuals = model_F.get_current_visuals()
                        SR_img = F_visuals['Batch_SR']

                        model_C.feed_data(SR_img, est_ker_map, ker_map)
                        model_C.test()
                        C_visuals = model_C.get_current_visuals()
                        est_ker_map = C_visuals['Batch_est_ker_map']

                        sr_img = util.tensor2img(F_visuals['SR'])  # uint8
                        gt_img = util.tensor2img(F_visuals['GT'])  # uint8

                        # Save SR images for reference
                        img_name = os.path.splitext(
                            os.path.basename(val_data['LQ_path'][0]))[0]
                        img_dir = os.path.join(opt_P['path']['val_images'],
                                               img_name)
                        util.mkdir(img_dir)

                        save_img_path = os.path.join(
                            img_dir, '{:s}_{:d}_{:d}.png'.format(
                                img_name, current_step, step))
                        util.save_img(sr_img, save_img_path)

                        # calculate PSNR, ignoring a `scale`-pixel border
                        crop_size = opt_P['scale']
                        gt_img = gt_img / 255.
                        sr_img = sr_img / 255.
                        cropped_sr_img = sr_img[crop_size:-crop_size,
                                                crop_size:-crop_size, :]
                        cropped_gt_img = gt_img[crop_size:-crop_size,
                                                crop_size:-crop_size, :]
                        step_psnr = util.calculate_psnr(
                            cropped_sr_img * 255, cropped_gt_img * 255)
                        logger.info(
                            '<epoch:{:3d}, iter:{:8,d}, step:{:3d}> img:{:s}, psnr: {:.4f}'
                            .format(epoch, current_step, step, img_name,
                                    step_psnr))
                        single_img_psnr += step_psnr
                        avg_psnr += step_psnr

                    if step > 0:  # no PSNR samples when opt_C['step'] == 0
                        avg_single_img_psnr = single_img_psnr / step
                        logger.info(
                            '<epoch:{:3d}, iter:{:8,d}, step:{:3d}> img:{:s}, average psnr: {:.4f}'
                            .format(epoch, current_step, step, img_name,
                                    avg_single_img_psnr))

                # NaN (not a ZeroDivisionError) when no PSNR samples were taken
                avg_psnr = avg_psnr / idx if idx else float('nan')

                # log
                logger.info('# Validation # PSNR: {:.4f}'.format(avg_psnr))
                logger_val = logging.getLogger('val')  # validation logger
                logger_val.info(
                    '<epoch:{:3d}, iter:{:8,d}, step:{:3d}> psnr: {:.4f}'.
                    format(epoch, current_step, step, avg_psnr))
                # tensorboard logger
                if tb_logger is not None:
                    tb_logger.add_scalar('psnr', avg_psnr, current_step)

            #### save models and training states
            if current_step % opt_P['logger']['save_checkpoint_freq'] == 0:
                if rank <= 0:
                    logger.info('Saving models and training states.')
                    model_P.save(current_step)
                    model_P.save_training_state(epoch, current_step)
                    model_C.save(current_step)
                    model_C.save_training_state(epoch, current_step)

    if rank <= 0:
        logger.info('Saving the final model.')
        model_P.save('latest')
        model_C.save('latest')
        logger.info('End of Predictor and Corrector training.')
    # the logger only exists when TensorBoard logging was enabled on rank <= 0
    if tb_logger is not None:
        tb_logger.close()