Code example #1
 def __init__(self,
              color=True,
              burst_length=8,
              blind_est=False,
              kernel_size=[5],
              sep_conv=False,
              channel_att=False,
              spatial_att=False,
              upMode='bilinear',
              core_bias=False):
     super(Att_Weight_KPN_noise, self).__init__()
     self.Att_Weight_KPN = Att_Weight_KPN(color=color,
                                          burst_length=burst_length,
                                          blind_est=blind_est,
                                          kernel_size=kernel_size,
                                          sep_conv=sep_conv,
                                          channel_att=channel_att,
                                          spatial_att=spatial_att,
                                          upMode=upMode,
                                          core_bias=core_bias)
     self.noise_estimate = NoiseEstimate(color=color)
Code example #2
 def __init__(self,
              color=True,
              burst_length=8,
              blind_est=False,
              kernel_size=[5],
              sep_conv=False,
              channel_att=False,
              spatial_att=False,
              upMode='bilinear',
              core_bias=False,
              in_channel=3):
     super(Att_Weight_KPN_DGF, self).__init__()
     self.Att_Weight_KPN = Att_Weight_KPN(color=color,
                                          burst_length=burst_length,
                                          blind_est=blind_est,
                                          kernel_size=kernel_size,
                                          sep_conv=sep_conv,
                                          channel_att=channel_att,
                                          spatial_att=spatial_att,
                                          upMode=upMode,
                                          core_bias=core_bias,
                                          in_channel=in_channel)
     self.gf = ConvGuidedFilter2(radius=1)
Code example #3
def test_multi(args):
    color = True
    burst_length = 8
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if args.model_type == "attKPN":
        model = Att_KPN(color=color,
                        burst_length=burst_length,
                        blind_est=True,
                        kernel_size=[5],
                        sep_conv=False,
                        channel_att=True,
                        spatial_att=True,
                        upMode="bilinear",
                        core_bias=False)
    elif args.model_type == "attWKPN":
        model = Att_Weight_KPN(color=color,
                               burst_length=burst_length,
                               blind_est=True,
                               kernel_size=[5],
                               sep_conv=False,
                               channel_att=True,
                               spatial_att=True,
                               upMode="bilinear",
                               core_bias=False)
    elif args.model_type == "KPN":
        model = KPN(color=color,
                    burst_length=burst_length,
                    blind_est=True,
                    kernel_size=[5],
                    sep_conv=False,
                    channel_att=False,
                    spatial_att=False,
                    upMode="bilinear",
                    core_bias=False)
    else:
        print(" Model type not valid")
        return
    # model2 = KPN(
    #     color=color,
    #     burst_length=burst_length,
    #     blind_est=True,
    #     kernel_size=[5],
    #     sep_conv=False,
    #     channel_att=False,
    #     spatial_att=False,
    #     upMode="bilinear",
    #     core_bias=False
    # )
    checkpoint_dir = "checkpoints/" + args.checkpoint
    if not os.path.exists(checkpoint_dir) or len(
            os.listdir(checkpoint_dir)) == 0:
        print('No checkpoint file found in path: {}'.format(checkpoint_dir))
        return
    # load trained model
    ckpt = load_checkpoint(checkpoint_dir,
                           cuda=device.type == 'cuda',
                           best_or_latest=args.load_type)
    state_dict = ckpt['state_dict']
    # strip the `module.` prefix added by nn.DataParallel, if present
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        name = k[7:] if k.startswith('module.') else k
        new_state_dict[name] = v
    model.load_state_dict(new_state_dict)

    #############################################
    # checkpoint_dir = "checkpoints/" + "kpn"
    # if not os.path.exists(checkpoint_dir) or len(os.listdir(checkpoint_dir)) == 0:
    #     print('There is no any checkpoint file in path:{}'.format(checkpoint_dir))
    # # load trained model
    # ckpt = load_checkpoint(checkpoint_dir,cuda=device=='cuda')
    # state_dict = ckpt['state_dict']
    # new_state_dict = OrderedDict()
    # if not args.cuda:
    #     for k, v in state_dict.items():
    #         name = k[7:]  # remove `module.`
    #         new_state_dict[name] = v
    # # model.load_state_dict(ckpt['state_dict'])
    # model2.load_state_dict(new_state_dict)
    ###########################################
    print('The model has been loaded from epoch {}, n_iter {}.'.format(
        ckpt['epoch'], ckpt['global_iter']))
    # switch the eval mode
    model.to(device)
    model.eval()
    # model2.eval()
    # model= save_dict['state_dict']
    trans = transforms.ToPILImage()
    torch.manual_seed(0)
    noisy_path = sorted(glob.glob(args.noise_dir + "/*.png"))
    clean_path = [i.replace("noisy", "clean") for i in noisy_path]
    for i in range(len(noisy_path)):
        image_noise = load_data(noisy_path[i], burst_length)
        begin = time.time()
        image_noise_batch = image_noise.to(device)
        # print(image_noise.size())
        # print(image_noise_batch.size())
        burst_noise = image_noise_batch.to(device)
        if color:
            b, N, c, h, w = burst_noise.size()
            feedData = burst_noise.view(b, -1, h, w)
        else:
            feedData = burst_noise
        # print(feedData.size())
        pred_i, pred = model(feedData, burst_noise[:, 0:burst_length, ...])
        del pred_i
        # pred_i2, pred2 = model2(feedData, burst_noise[:, 0:burst_length, ...])
        # print("Time : ", time.time()-begin)
        pred = pred.detach().cpu()
        gt = transforms.ToTensor()(Image.open(clean_path[i]).convert('RGB'))
        # print(pred_i.size())
        # print(pred.size())
        # print(gt.size())
        gt = gt.unsqueeze(0)
        _, _, h_hr, w_hr = gt.size()
        _, _, h_lr, w_lr = pred.size()
        gt_down = F.interpolate(gt, (h_lr, w_lr),
                                mode='bilinear',
                                align_corners=True)
        pred_up = F.interpolate(pred, (h_hr, w_hr),
                                mode='bilinear',
                                align_corners=True)
        # print("After interpolate")
        # print(pred_up.size())
        # print(gt_down.size())
        psnr_t_up = calculate_psnr(pred_up, gt)
        ssim_t_up = calculate_ssim(pred_up, gt)
        psnr_t_down = calculate_psnr(pred, gt_down)
        ssim_t_down = calculate_ssim(pred, gt_down)
        print(i, "   UP   :  PSNR : ", str(psnr_t_up), " :  SSIM : ",
              str(ssim_t_up), " : DOWN   :  PSNR : ", str(psnr_t_down),
              " :  SSIM : ", str(ssim_t_down))

        if args.save_img != '':
            if not os.path.exists(args.save_img):
                os.makedirs(args.save_img)
            plt.figure(figsize=(15, 15))
            plt.imshow(np.array(trans(pred_up[0])))
            plt.title("denoise KPN split " + args.model_type, fontsize=25)
            image_name = noisy_path[i].split("/")[-1].split(".")[0]
            plt.axis("off")
            plt.suptitle(image_name + "   UP   :  PSNR : " + str(psnr_t_up) +
                         " :  SSIM : " + str(ssim_t_up),
                         fontsize=25)
            plt.savefig(os.path.join(
                args.save_img, image_name + "_" + args.checkpoint + '.png'),
                        pad_inches=0)

        # print(np.array(trans(mf8[0])))
        """
Code example #4
File: test.py Project: pminhtam/KPN_attention
def eval(args):
    color = args.color
    print('Eval Process......')
    burst_length = 8
    # print(args.checkpoint)
    checkpoint_dir = "checkpoints/" + args.checkpoint
    if not os.path.exists(checkpoint_dir) or len(
            os.listdir(checkpoint_dir)) == 0:
        print('No checkpoint file found in path: {}'.format(checkpoint_dir))
        return
    # the path for saving eval images
    eval_dir = "eval_img"
    if not os.path.exists(eval_dir):
        os.mkdir(eval_dir)

    # dataset and dataloader
    data_set = MultiLoader(noise_dir=args.noise_dir,
                           gt_dir=args.gt_dir,
                           image_size=args.image_size)
    data_loader = DataLoader(data_set,
                             batch_size=1,
                             shuffle=False,
                             num_workers=args.num_workers)

    # model here
    if args.model_type == "attKPN":
        model = Att_KPN(color=color,
                        burst_length=burst_length,
                        blind_est=True,
                        kernel_size=[5],
                        sep_conv=False,
                        channel_att=False,
                        spatial_att=False,
                        upMode="bilinear",
                        core_bias=False)
    elif args.model_type == "attWKPN":
        model = Att_Weight_KPN(color=color,
                               burst_length=burst_length,
                               blind_est=True,
                               kernel_size=[5],
                               sep_conv=False,
                               channel_att=False,
                               spatial_att=False,
                               upMode="bilinear",
                               core_bias=False)
    elif args.model_type == "KPN":
        model = KPN(color=color,
                    burst_length=burst_length,
                    blind_est=True,
                    kernel_size=[5],
                    sep_conv=False,
                    channel_att=False,
                    spatial_att=False,
                    upMode="bilinear",
                    core_bias=False)
    else:
        print(" Model type not valid")
        return
    if args.cuda:
        model = model.cuda()

    if args.mGPU:
        model = nn.DataParallel(model)
    # load trained model
    ckpt = load_checkpoint(checkpoint_dir, cuda=args.cuda)

    state_dict = ckpt['state_dict']
    if not args.mGPU:
        # strip the `module.` prefix added by nn.DataParallel, if present
        new_state_dict = OrderedDict()
        for k, v in state_dict.items():
            name = k[7:] if k.startswith('module.') else k
            new_state_dict[name] = v
        model.load_state_dict(new_state_dict)
    else:
        model.load_state_dict(state_dict)
    print('The model has been loaded from epoch {}, n_iter {}.'.format(
        ckpt['epoch'], ckpt['global_iter']))
    # torch.save(model.state_dict(), "model_state.pth")
    # exit(0)
    # switch the eval mode
    model.eval()

    # data_loader = iter(data_loader)
    trans = transforms.ToPILImage()

    with torch.no_grad():
        psnr = 0.0
        ssim = 0.0
        torch.manual_seed(0)
        for i, (burst_noise, gt) in enumerate(data_loader):
            if i < 100:
                # data = next(data_loader)
                if args.cuda:
                    burst_noise = burst_noise.cuda()
                    gt = gt.cuda()
                if color:
                    b, N, c, h, w = burst_noise.size()
                    feedData = burst_noise.view(b, -1, h, w)
                else:
                    feedData = burst_noise
                pred_i, pred = model(feedData, burst_noise[:, 0:burst_length,
                                                           ...])

                if not color:
                    psnr_t = calculate_psnr(pred.unsqueeze(1), gt.unsqueeze(1))
                    ssim_t = calculate_ssim(pred.unsqueeze(1), gt.unsqueeze(1))
                    psnr_noisy = calculate_psnr(
                        burst_noise[:, 0, ...].unsqueeze(1), gt.unsqueeze(1))
                else:
                    psnr_t = calculate_psnr(pred, gt)
                    ssim_t = calculate_ssim(pred, gt)
                    psnr_noisy = calculate_psnr(burst_noise[:, 0, ...], gt)

                psnr += psnr_t
                ssim += ssim_t

                pred = torch.clamp(pred, 0.0, 1.0)

                if args.cuda:
                    pred = pred.cpu()
                    gt = gt.cpu()
                    burst_noise = burst_noise.cpu()
                if args.save_img:
                    trans(burst_noise[0, 0, ...].squeeze()).save(os.path.join(
                        eval_dir,
                        '{}_noisy_{:.2f}dB.png'.format(i, psnr_noisy)),
                                                                 quality=100)
                    trans(pred.squeeze()).save(os.path.join(
                        eval_dir, '{}_pred_{:.2f}dB.png'.format(i, psnr_t)),
                                               quality=100)
                    trans(gt.squeeze()).save(os.path.join(
                        eval_dir, '{}_gt.png'.format(i)),
                                             quality=100)

                print('{}-th image is OK, with PSNR: {:.2f} , SSIM: {:.4f}'.
                      format(i, psnr_t, ssim_t))
            else:
                break
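
The loop above accumulates psnr and ssim over the first 100 images but the snippet ends before any summary is printed. A closing average, if one were added (an assumption, not shown in the original file), could look like:

    n_eval = min(100, len(data_loader))  # batch_size=1, so one image per batch
    print('Average over {} images -> PSNR: {:.2f} dB, SSIM: {:.4f}'.format(
        n_eval, psnr / n_eval, ssim / n_eval))
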
Code example #5
File: test_custom.py Project: pminhtam/KPN_attention
def test_multi(dir, image_size, args):
    num_workers = 1
    batch_size = 1
    color = True
    burst_length = 8
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if args.model_type == "attKPN":
        model = Att_KPN(color=color,
                        burst_length=burst_length,
                        blind_est=True,
                        kernel_size=[5],
                        sep_conv=False,
                        channel_att=False,
                        spatial_att=False,
                        upMode="bilinear",
                        core_bias=False)
    elif args.model_type == "attWKPN":
        model = Att_Weight_KPN(color=color,
                               burst_length=burst_length,
                               blind_est=True,
                               kernel_size=[5],
                               sep_conv=False,
                               channel_att=False,
                               spatial_att=False,
                               upMode="bilinear",
                               core_bias=False)
    elif args.model_type == "KPN":
        model = KPN(color=color,
                    burst_length=burst_length,
                    blind_est=True,
                    kernel_size=[5],
                    sep_conv=False,
                    channel_att=False,
                    spatial_att=False,
                    upMode="bilinear",
                    core_bias=False)
    else:
        print(" Model type not valid")
        return
    model2 = KPN(color=color,
                 burst_length=burst_length,
                 blind_est=True,
                 kernel_size=[5],
                 sep_conv=False,
                 channel_att=False,
                 spatial_att=False,
                 upMode="bilinear",
                 core_bias=False)
    checkpoint_dir = "checkpoints/" + args.checkpoint
    if not os.path.exists(checkpoint_dir) or len(
            os.listdir(checkpoint_dir)) == 0:
        print('No checkpoint file found in path: {}'.format(checkpoint_dir))
        return
    # load trained model
    ckpt = load_checkpoint(checkpoint_dir, cuda=device.type == 'cuda')
    state_dict = ckpt['state_dict']
    # strip the `module.` prefix added by nn.DataParallel, if present
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        name = k[7:] if k.startswith('module.') else k
        new_state_dict[name] = v
    model.load_state_dict(new_state_dict)

    checkpoint_dir = "checkpoints/" + "kpn"
    if not os.path.exists(checkpoint_dir) or len(
            os.listdir(checkpoint_dir)) == 0:
        print('No checkpoint file found in path: {}'.format(checkpoint_dir))
        return
    # load trained model
    ckpt = load_checkpoint(checkpoint_dir,
                           cuda=device.type == 'cuda',
                           best_or_latest=args.load_type)
    state_dict = ckpt['state_dict']
    # strip the `module.` prefix added by nn.DataParallel, if present,
    # and load the baseline weights into model2 (not model)
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        name = k[7:] if k.startswith('module.') else k
        new_state_dict[name] = v
    model2.load_state_dict(new_state_dict)
    print('The model has been loaded from epoch {}, n_iter {}.'.format(
        ckpt['epoch'], ckpt['global_iter']))
    # switch the eval mode
    model.to(device)
    model2.to(device)
    model.eval()
    model2.eval()
    # model= save_dict['state_dict']
    trans = transforms.ToPILImage()
    torch.manual_seed(0)
    for i in range(10):
        image_noise = load_data(dir, image_size, burst_length)
        begin = time.time()
        image_noise_batch = image_noise.to(device)
        print(image_noise_batch.size())
        burst_size = image_noise_batch.size()[1]
        burst_noise = image_noise_batch.to(device)
        if color:
            b, N, c, h, w = burst_noise.size()
            feedData = burst_noise.view(b, -1, h, w)
        else:
            feedData = burst_noise
        # print(feedData.size())
        pred_i, pred = model(feedData, burst_noise[:, 0:burst_length, ...])
        pred_i2, pred2 = model2(feedData, burst_noise[:, 0:burst_length, ...])
        pred = pred.detach().cpu()
        pred2 = pred2.detach().cpu()
        print("Time : ", time.time() - begin)
        print(pred_i.size())
        print(pred.size())
        if args.save_img != '':
            if not os.path.exists(args.save_img):
                os.makedirs(args.save_img)
            # print(np.array(trans(mf8[0])))
            plt.figure(figsize=(10, 3))
            plt.subplot(1, 3, 1)
            plt.imshow(np.array(trans(pred[0])))
            plt.title("denoise attKPN")
            plt.subplot(1, 3, 2)
            plt.imshow(np.array(trans(pred2[0])))
            plt.title("denoise KPN")
            # plt.show()
            plt.subplot(1, 3, 3)
            plt.imshow(np.array(trans(image_noise[0][0])))
            plt.title("noise ")
            image_name = str(i)
            plt.savefig(os.path.join(
                args.save_img, image_name + "_" + args.checkpoint + '.png'),
                        pad_inches=0)
Code example #6
File: train_lr.py Project: pminhtam/KPN_attention
def eval(args):
    color = True
    burst_length = args.burst_length
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if args.model_type == "attKPN":
        model = Att_KPN(color=color,
                        burst_length=burst_length,
                        blind_est=True,
                        kernel_size=[5],
                        sep_conv=False,
                        channel_att=True,
                        spatial_att=True,
                        upMode="bilinear",
                        core_bias=False)
    elif args.model_type == "attKPN_Wave":
        model = Att_KPN_Wavelet(color=color,
                                burst_length=1,
                                blind_est=True,
                                kernel_size=[5],
                                sep_conv=False,
                                channel_att=True,
                                spatial_att=True,
                                upMode="bilinear",
                                core_bias=False)
    elif args.model_type == "attWKPN":
        model = Att_Weight_KPN(color=color,
                               burst_length=1,
                               blind_est=True,
                               kernel_size=[5],
                               sep_conv=False,
                               channel_att=True,
                               spatial_att=True,
                               upMode="bilinear",
                               core_bias=False)
    elif args.model_type == "KPN":
        model = KPN(color=color,
                    burst_length=1,
                    blind_est=True,
                    kernel_size=[5],
                    sep_conv=False,
                    channel_att=False,
                    spatial_att=False,
                    upMode="bilinear",
                    core_bias=False)
    else:
        print(" Model type not valid")
        return
    checkpoint_dir = "checkpoints/" + args.checkpoint
    if not os.path.exists(checkpoint_dir) or len(
            os.listdir(checkpoint_dir)) == 0:
        print('No checkpoint file found in path: {}'.format(checkpoint_dir))
        return
    # load trained model
    ckpt = load_checkpoint(checkpoint_dir,
                           cuda=device.type == 'cuda',
                           best_or_latest=args.load_type)
    state_dict = ckpt['state_dict']

    # strip the `module.` prefix added by nn.DataParallel, if present
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        name = k[7:] if k.startswith('module.') else k
        new_state_dict[name] = v
    model.load_state_dict(new_state_dict)

    model.to(device)
    print('The model has been loaded from epoch {}, n_iter {}.'.format(
        ckpt['epoch'], ckpt['global_iter']))
    # switch the eval mode
    model.eval()
    # model= save_dict['state_dict']
    trans = transforms.ToPILImage()
    torch.manual_seed(0)
    all_noisy_imgs = scipy.io.loadmat(
        args.noise_dir)['siddplus_valid_noisy_srgb']
    all_clean_imgs = scipy.io.loadmat(args.gt_dir)['siddplus_valid_gt_srgb']
    i_imgs, _, _, _ = all_noisy_imgs.shape
    psnrs = []
    ssims = []
    for i_img in range(i_imgs):
        image_noise = transforms.ToTensor()(Image.fromarray(
            all_noisy_imgs[i_img]))
        image_noise_lr, image_noise_hr = load_data(image_noise, burst_length)
        burst_noise = image_noise_lr[:, 0:1, :, :, :].to(device)
        if color:
            b, N, c, h, w = burst_noise.size()
            feedData = burst_noise.view(b, -1, h, w)
        else:
            feedData = burst_noise
        # print(feedData.size())
        _, pred = model(feedData, burst_noise)
        pred = pred.detach().cpu()
        # print("Time : ", time.time()-begin)
        gt = transforms.ToTensor()(Image.fromarray(all_clean_imgs[i_img]))
        image_gt_lr, image_gt_hr = load_data(gt, burst_length)
        gt = image_gt_lr[:, 0, :, :, :]  # keep on CPU: pred was moved to CPU above
        # print(pred_i.size())
        # print(pred[0].size())
        psnr_t = calculate_psnr(pred, gt)
        ssim_t = calculate_ssim(pred, gt)
        psnrs.append(psnr_t)
        ssims.append(ssim_t)
        print(i_img, "   UP   :  PSNR : ", str(psnr_t), " :  SSIM : ",
              str(ssim_t))
        if args.save_img != '':
            if not os.path.exists(args.save_img):
                os.makedirs(args.save_img)
            plt.figure(figsize=(15, 15))
            plt.imshow(np.array(trans(pred[0])))
            plt.title("denoise KPN DGF " + args.model_type, fontsize=25)
            image_name = str(i_img)
            plt.axis("off")
            plt.suptitle(image_name + "   UP   :  PSNR : " + str(psnr_t) +
                         " :  SSIM : " + str(ssim_t),
                         fontsize=25)
            plt.savefig(os.path.join(
                args.save_img, image_name + "_" + args.checkpoint + '.png'),
                        pad_inches=0)
        """
        if args.save_img:
            # print(np.array(trans(mf8[0])))
            plt.figure(figsize=(30, 9))
            plt.subplot(1,3,1)
            plt.imshow(np.array(trans(pred[0])))
            plt.title("denoise DGF "+args.model_type, fontsize=26)
            plt.subplot(1,3,2)
            plt.imshow(np.array(trans(gt[0])))
            plt.title("gt ", fontsize=26)
            plt.subplot(1,3,3)
            plt.imshow(np.array(trans(image_noise_hr[0])))
            plt.title("noise ", fontsize=26)
            plt.axis("off")
            plt.suptitle(str(i)+"   UP   :  PSNR : "+ str(psnr_t)+" :  SSIM : "+ str(ssim_t), fontsize=26)
            plt.savefig("checkpoints/22_DGF_" + args.checkpoint+str(i)+'.png',pad_inches=0)
        """
    print("   AVG   :  PSNR : " + str(np.mean(psnrs)) + " :  SSIM : " +
          str(np.mean(ssims)))
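
calculate_psnr and calculate_ssim come from the repository's utility code and are not shown in these snippets. For orientation only, PSNR of tensors scaled to [0, 1] reduces to a few lines; the sketch below is a generic definition and not necessarily how the repo's calculate_psnr is implemented (it may clamp inputs or average per image):

import torch

def psnr(pred, gt, max_val=1.0):
    # peak signal-to-noise ratio in dB for tensors in [0, max_val]
    mse = torch.mean((pred - gt) ** 2)
    return 10.0 * torch.log10(max_val ** 2 / mse)
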
Code example #7
File: train_lr.py Project: pminhtam/KPN_attention
def train(num_workers, cuda, restart_train, mGPU):
    # torch.set_num_threads(num_threads)

    color = True
    batch_size = args.batch_size
    lr = 2e-4
    lr_decay = 0.89125093813
    n_epoch = args.epoch
    # num_workers = 8
    save_freq = args.save_every
    loss_freq = args.loss_every
    lr_step_size = 100
    burst_length = args.burst_length
    # checkpoint path
    checkpoint_dir = "checkpoints/" + args.checkpoint
    if not os.path.exists(checkpoint_dir):
        os.makedirs(checkpoint_dir)
    # logs path: cleared on every run; SummaryWriter recreates the directory
    logs_dir = "checkpoints/logs/" + args.checkpoint
    if not os.path.exists(logs_dir):
        os.makedirs(logs_dir)
    shutil.rmtree(logs_dir)
    log_writer = SummaryWriter(logs_dir)

    # dataset and dataloader
    if args.data_type == 'real':
        data_set = SingleLoader_DGF(noise_dir=args.noise_dir,
                                    gt_dir=args.gt_dir,
                                    image_size=args.image_size,
                                    burst_length=burst_length)
    elif args.data_type == "synth":
        data_set = SingleLoader_DGF_synth(gt_dir=args.gt_dir,
                                          image_size=args.image_size,
                                          burst_length=burst_length)
    else:
        print("Wrong type data")
        return
    data_loader = DataLoader(data_set,
                             batch_size=batch_size,
                             shuffle=True,
                             num_workers=num_workers)
    # model here
    if args.model_type == "attKPN":
        model = Att_KPN(color=color,
                        burst_length=burst_length,
                        blind_est=True,
                        kernel_size=[5],
                        sep_conv=False,
                        channel_att=True,
                        spatial_att=True,
                        upMode="bilinear",
                        core_bias=False)
    elif args.model_type == "attKPN_Wave":
        model = Att_KPN_Wavelet(color=color,
                                burst_length=1,
                                blind_est=True,
                                kernel_size=[5],
                                sep_conv=False,
                                channel_att=True,
                                spatial_att=True,
                                upMode="bilinear",
                                core_bias=False)
    elif args.model_type == "attWKPN":
        model = Att_Weight_KPN(color=color,
                               burst_length=1,
                               blind_est=True,
                               kernel_size=[5],
                               sep_conv=False,
                               channel_att=True,
                               spatial_att=True,
                               upMode="bilinear",
                               core_bias=False)
    elif args.model_type == "KPN":
        model = KPN(color=color,
                    burst_length=1,
                    blind_est=True,
                    kernel_size=[5],
                    sep_conv=False,
                    channel_att=False,
                    spatial_att=False,
                    upMode="bilinear",
                    core_bias=False)
    else:
        print(" Model type not valid")
        return
    if cuda:
        model = model.cuda()

    if mGPU:
        model = nn.DataParallel(model)
    model.train()

    # loss function here
    # loss_func = LossFunc(
    #     coeff_basic=1.0,
    #     coeff_anneal=1.0,
    #     gradient_L1=True,
    #     alpha=0.9998,
    #     beta=100.0
    # )
    loss_func = LossBasic()
    loss_func_i = LossAnneal_i()
    if args.wavelet_loss:
        print("Use wavelet loss")
        loss_func2 = WaveletLoss()
    # Optimizer here
    optimizer = optim.Adam(model.parameters(), lr=lr)

    optimizer.zero_grad()

    # learning rate scheduler here
    scheduler = lr_scheduler.StepLR(optimizer,
                                    step_size=lr_step_size,
                                    gamma=lr_decay)

    average_loss = MovingAverage(save_freq)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if not restart_train:
        try:
            checkpoint = load_checkpoint(checkpoint_dir,
                                         cuda=device.type == 'cuda',
                                         best_or_latest=args.load_type)
            start_epoch = checkpoint['epoch']
            global_step = checkpoint['global_iter']
            best_loss = checkpoint['best_loss']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            scheduler.load_state_dict(checkpoint['lr_scheduler'])
            print('=> loaded checkpoint (epoch {}, global_step {})'.format(
                start_epoch, global_step))
        except Exception:
            start_epoch = 0
            global_step = 0
            best_loss = np.inf
            print('=> no checkpoint file to be loaded.')
    else:
        start_epoch = 0
        global_step = 0
        best_loss = np.inf
        if os.path.exists(checkpoint_dir):
            pass
            # files = os.listdir(checkpoint_dir)
            # for f in files:
            #     os.remove(os.path.join(checkpoint_dir, f))
        else:
            os.mkdir(checkpoint_dir)
        print('=> training')

    for epoch in range(start_epoch, n_epoch):
        epoch_start_time = time.time()
        # decay the learning rate

        # print('='*20, 'lr={}'.format([param['lr'] for param in optimizer.param_groups]), '='*20)
        t1 = time.time()
        for step, (image_noise_hr, image_noise_lr, image_gt_hr,
                   image_gt_lr) in enumerate(data_loader):
            # print(burst_noise.size())
            # print(gt.size())
            if cuda:
                burst_noise = image_noise_lr[:, 0:1, :, :, :].cuda()
                # gt = image_gt_hr.cuda()
                gt = image_gt_lr[:, 0, :, :, :].cuda()
                # image_noise_hr = image_noise_hr.cuda()
            else:
                burst_noise = image_noise_lr[:, 0:1, :, :, :]
                gt = image_gt_lr[:, 0, :, :, :]
            if color:
                b, N, c, h, w = burst_noise.size()
                # print(image_noise_lr.size())
                feedData = burst_noise.view(b, -1, h, w)
            else:
                feedData = image_noise_lr
            # print('white_level', white_level, white_level.size())
            # print("feedData   : ",feedData.size())
            # print("burst_noise   : ",burst_noise.size())
            #
            pred_i, pred = model(feedData, burst_noise)
            #
            # loss_basic, loss_anneal = loss_func(pred_i, pred, gt, global_step)
            # print(pred.size())
            # print(gt.size())
            loss_basic = loss_func(pred, gt)
            # loss_i =loss_func_i(global_step, pred_i, image_gt_lr)
            loss = loss_basic
            if args.wavelet_loss:
                loss_wave = loss_func2(pred, gt)
                # print(loss_wave)
                loss = loss_basic + loss_wave
            # backward
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # update the average loss
            average_loss.update(loss)
            # global_step

            if not color:
                pred = pred.unsqueeze(1)
                gt = gt.unsqueeze(1)
            if global_step % loss_freq == 0:
                # calculate PSNR
                # print("burst_noise  : ",burst_noise.size())
                # print("gt   :  ",gt.size())
                # print("feedData   : ", feedData.size())
                psnr = calculate_psnr(pred, gt)
                ssim = calculate_ssim(pred, gt)

                # add scalars to tensorboardX
                log_writer.add_scalar('loss_basic', loss_basic, global_step)
                # log_writer.add_scalar('loss_anneal', loss_anneal, global_step)
                log_writer.add_scalar('loss_total', loss, global_step)
                log_writer.add_scalar('psnr', psnr, global_step)
                log_writer.add_scalar('ssim', ssim, global_step)

                # print
                print(
                    '{:-4d}\t| epoch {:2d}\t| step {:4d}\t| loss_basic: {:.4f}\t|'
                    ' loss: {:.4f}\t| PSNR: {:.2f}dB\t| SSIM: {:.4f}\t| time:{:.2f} seconds.'
                    .format(global_step, epoch, step, loss_basic, loss, psnr,
                            ssim,
                            time.time() - t1))
                t1 = time.time()

            if global_step % save_freq == 0:
                if average_loss.get_value() < best_loss:
                    is_best = True
                    best_loss = average_loss.get_value()
                else:
                    is_best = False

                save_dict = {
                    'epoch': epoch,
                    'global_iter': global_step,
                    'state_dict': model.state_dict(),
                    'best_loss': best_loss,
                    'optimizer': optimizer.state_dict(),
                    'lr_scheduler': scheduler.state_dict()
                }
                save_checkpoint(save_dict,
                                is_best,
                                checkpoint_dir,
                                global_step,
                                max_keep=10)
                print(
                    'Save   : {:-4d}\t| epoch {:2d}\t| step {:4d}\t| loss_basic: {:.4f}\t|'
                    ' loss: {:.4f}'.format(global_step, epoch, step,
                                           loss_basic, loss))
            global_step += 1
        print('Epoch {} is finished, time elapsed {:.2f} seconds.'.format(
            epoch,
            time.time() - epoch_start_time))
        lr_cur = [param['lr'] for param in optimizer.param_groups]
        if lr_cur[0] > 5e-6:
            scheduler.step()
        else:
            for param in optimizer.param_groups:
                param['lr'] = 5e-6
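
All of these functions read settings from a module-level args object (train() additionally receives num_workers, cuda, restart_train, and mGPU as plain parameters). The repository's real command-line flags are not visible in these snippets, so the parser below is purely illustrative: every flag name and default is an assumption that simply mirrors the args.* attributes referenced above, not the project's actual CLI.

import argparse

def build_arg_parser():
    # Hypothetical parser: flag names mirror the args.* attributes used in the
    # snippets above and may differ from the real KPN_attention command line.
    p = argparse.ArgumentParser(description='KPN_attention train/eval (illustrative)')
    p.add_argument('--model_type', default='attKPN')    # attKPN | attKPN_Wave | attWKPN | KPN
    p.add_argument('--checkpoint', default='att_kpn')   # subdirectory under checkpoints/
    p.add_argument('--load_type', default='best')       # forwarded to load_checkpoint
    p.add_argument('--data_type', default='synth')      # real | synth
    p.add_argument('--noise_dir', default='')
    p.add_argument('--gt_dir', default='')
    p.add_argument('--save_img', default='')
    p.add_argument('--image_size', type=int, default=128)
    p.add_argument('--burst_length', type=int, default=8)
    p.add_argument('--batch_size', type=int, default=4)
    p.add_argument('--epoch', type=int, default=100)
    p.add_argument('--save_every', type=int, default=100)
    p.add_argument('--loss_every', type=int, default=10)
    p.add_argument('--num_workers', type=int, default=4)
    p.add_argument('--color', action='store_true')
    p.add_argument('--cuda', action='store_true')
    p.add_argument('--mGPU', action='store_true')
    p.add_argument('--wavelet_loss', action='store_true')
    p.add_argument('--restart_train', action='store_true')
    return p

if __name__ == '__main__':
    args = build_arg_parser().parse_args()
    train(args.num_workers, args.cuda, args.restart_train, args.mGPU)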