예제 #1
0
def test(net, epoch_state):
    """Denoise a single noisy clip with ``net``, log quality metrics and save frames.

    The clean clip is loaded from ``opt.dataset_dir``, corrupted with zero-mean
    Gaussian noise of standard deviation ``opt.val_noiseL`` and passed through
    ``net``. The output is compared with the centre ``opt.output_frame``
    frames of the clean clip using ``psnr`` / ``cal_ssim``, the same two
    metrics on the outputs of ``video_diff`` (reported as TPSNR / TSSIM), and
    ``cal_strred``. Results are printed and appended to ``test_delta.txt``;
    the denoised frames are written with ``plt.imsave`` into ``opt.save_dir``.

    Args:
        net: denoising network, called as ``net(noisy_clip)``.
        epoch_state: epoch number written into the log line.
    """
    def _mean(values):
        return float(np.array(values).mean())

    clean_seq = load_video(opt.dataset_dir)
    print('testing........')
    clean_seq = clean_seq.cuda()

    # Ground truth is the centre output_frame frames of the input window (dim 2).
    first = int(opt.input_frame / 2 - opt.output_frame / 2)
    last = int(opt.input_frame / 2 + opt.output_frame / 2)
    gt = clean_seq[:, :, first:last, :, :]

    # empty_like already allocates on clean_seq's (CUDA) device.
    noise = torch.empty_like(clean_seq).normal_(mean=0, std=opt.val_noiseL)
    noisy_clean_seq = clean_seq + noise

    # Inference only: without no_grad the output would require grad, which
    # makes the later ``.numpy()`` call raise and wastes memory on the graph.
    with torch.no_grad():
        # CUDA runs asynchronously; synchronize so the timer covers the real
        # forward pass and not just the kernel launches.
        torch.cuda.synchronize()
        start_time = time.time()
        denoise_seq = net(noisy_clean_seq)
        torch.cuda.synchronize()
        avg_time = time.time() - start_time

    N, C, L, H, W = gt.shape
    denoise_seq = denoise_seq.view(N, C, L, H, W)
    # NOTE(review): the meaning of the constant 7 passed to video_diff is not
    # visible here; kept as-is.
    diff_denoise_seq, diff_gt = video_diff(denoise_seq, gt, 7)
    psnr_list = [psnr(denoise_seq.detach(), gt.detach())]
    ssim_list = [cal_ssim(denoise_seq, gt)]
    tpsnr_list = [psnr(diff_denoise_seq.detach(), diff_gt.detach())]
    tssim_list = [cal_ssim(diff_denoise_seq, diff_gt)]
    strred_score = cal_strred(denoise_seq, gt)
    print(strred_score)

    print(
        'valid PSNR---%f, SSIM---%4f, TPSNR---%4f, TSSIM---%4f, average time ----%f'
        % (_mean(psnr_list), _mean(ssim_list), _mean(tpsnr_list),
           _mean(tssim_list), avg_time))
    # ``with`` guarantees the log file is closed even if the write fails.
    with open("test_delta.txt", "a+") as f:
        f.write(
            "Epoch: %d, PSNR: %.3f, SSIM---%4f, TPSNR---%4f, TSSIM---%4f, run time: %.5f---"
            % (epoch_state, _mean(psnr_list), _mean(ssim_list),
               _mean(tpsnr_list), _mean(tssim_list), avg_time) + "\n")

    mkdir_if_not_exist(opt.save_dir)
    # (N, C, L, H, W) -> (N, L, H, W, C): one HWC image per output frame.
    save_out = denoise_seq.permute(0, 2, 3, 4, 1).cpu().numpy()
    # Iterate over the frames actually produced; a fixed count raises
    # IndexError whenever the network outputs fewer frames.
    for frame_idx in range(save_out.shape[1]):
        output_img = np.clip(save_out[0, frame_idx, ...], 0.0, 1.0)
        plt.imsave(os.path.join(opt.save_dir, str_format % frame_idx),
                   output_img)
예제 #2
0
def test(net, epoch_state):
    """Evaluate ``net`` on every clip of the validation set, log metrics, save frames.

    Each clip from ``ValidSetLoader(opt.dataset_dir, ...)`` is corrupted with
    zero-mean Gaussian noise of standard deviation ``opt.val_noiseL`` and
    denoised by ``net``. The output is compared with the centre
    ``opt.output_frame`` frames of the clean clip (PSNR, SSIM, the same two
    metrics on ``video_diff`` outputs as TPSNR/TSSIM, and STRRED). Frames are
    written under ``opt.save_dir + <clip name>``, where clip names come from
    ``sep_testlist.txt``. Mean metrics and the mean per-clip forward time are
    printed and appended to ``test_delta.txt``.

    Args:
        net: denoising network, called as ``net(noisy_clip)``.
        epoch_state: epoch number written into the log line.

    Raises:
        RuntimeError: if the validation set yields no clips.
    """
    def _mean(values):
        return float(np.array(values).mean())

    valid_set = ValidSetLoader(opt.dataset_dir, patch_size=opt.patch_size,
                               input_frame=opt.input_frame)
    valid_loader = DataLoader(dataset=valid_set, num_workers=opt.threads,
                              batch_size=1, shuffle=False)
    # One clip name per line; assumed to be in the same order as the loader
    # yields clips (shuffle=False) — TODO confirm against ValidSetLoader.
    with open(opt.dataset_dir + '/sep_testlist.txt', 'r') as f:
        test_list = f.read().splitlines()

    psnr_list = []
    ssim_list = []
    tpsnr_list = []
    tssim_list = []
    strred_list = []
    total_time = 0.0

    # Ground truth is the centre output_frame frames of the input window (dim 2).
    first = int(opt.input_frame / 2 - opt.output_frame / 2)
    last = int(opt.input_frame / 2 + opt.output_frame / 2)

    print('testing........')
    for idx_iter, clean_seq in enumerate(valid_loader):
        print("processing............", idx_iter)
        clean_seq = clean_seq.cuda()
        gt = clean_seq[:, :, first:last, :, :]
        # empty_like already allocates on clean_seq's (CUDA) device.
        noise = torch.empty_like(clean_seq).normal_(mean=0, std=opt.val_noiseL)
        noisy_clean_seq = clean_seq + noise

        # Inference only: without no_grad the output requires grad and the
        # ``.numpy()`` call below raises.
        with torch.no_grad():
            # Synchronize so the timer measures the forward pass itself,
            # not just asynchronous CUDA kernel launches.
            torch.cuda.synchronize()
            start_time = time.time()
            denoise_seq = net(noisy_clean_seq)
            torch.cuda.synchronize()
            total_time += time.time() - start_time

        N, C, L, H, W = gt.shape
        denoise_seq = denoise_seq.view(N, C, L, H, W)
        # NOTE(review): the meaning of the constant 7 passed to video_diff is
        # not visible here; kept as-is.
        diff_denoise_seq, diff_gt = video_diff(denoise_seq, gt, 7)
        psnr_list.append(psnr(denoise_seq.detach(), gt.detach()))
        ssim_list.append(cal_ssim(denoise_seq, gt))
        tpsnr_list.append(psnr(diff_denoise_seq.detach(), diff_gt.detach()))
        tssim_list.append(cal_ssim(diff_denoise_seq, diff_gt))
        strred_list.append(cal_strred(denoise_seq, gt))

        clip_name = test_list[idx_iter]
        print(clip_name)
        # Plain concatenation is kept so output paths match previous runs.
        clip_dir = opt.save_dir + clip_name
        mkdir_if_not_exist(clip_dir)
        # (N, C, L, H, W) -> (N, L, H, W, C): one HWC image per output frame.
        save_out = denoise_seq.permute(0, 2, 3, 4, 1).cpu().numpy()
        for frame_idx in range(save_out.shape[1]):
            output_img = np.clip(save_out[0, frame_idx, ...], 0.0, 1.0)
            plt.imsave(os.path.join(clip_dir, str_format % frame_idx), output_img)

    num_clips = len(psnr_list)
    if num_clips == 0:
        raise RuntimeError('validation set is empty: %s' % opt.dataset_dir)
    # Divide by the number of clips; the last enumerate index is one less.
    avg_time = total_time / num_clips

    print('valid PSNR---%f, SSIM---%4f, TPSNR---%4f, TSSIM---%4f,STRRED_SCORE---%4f,  average time ----%f'
          % (_mean(psnr_list), _mean(ssim_list), _mean(tpsnr_list),
             _mean(tssim_list), _mean(strred_list), avg_time))
    with open("test_delta.txt", "a+") as f:
        f.write("Epoch: %d, PSNR: %.3f, SSIM---%4f, TPSNR---%4f, TSSIM---%4f, STRRED_SCORE---%4f, run time: %.5f---"
                % (epoch_state, _mean(psnr_list), _mean(ssim_list),
                   _mean(tpsnr_list), _mean(tssim_list), _mean(strred_list),
                   avg_time) + "\n")
예제 #3
0
def train(train_loader, scale_factor, epoch_num):
    """Train ``Net(scale_factor)`` with an MSE loss, checkpointing after every epoch.

    Each batch from ``train_loader`` is an ``(LR, HR)`` pair; the network
    output for ``LR`` is compared with ``HR[:, :, 3, :, :]`` (index 3 along
    dim 2, presumably the centre frame of the window — TODO confirm).
    Adam with a StepLR schedule (``opt.lr``, ``opt.step``, ``opt.gamma``) is
    used. When ``opt.resume`` is set, weights, epoch counter, loss/PSNR
    history and (if present in the checkpoint) optimizer and scheduler state
    are restored. After every epoch the mean loss/PSNR are printed, a
    checkpoint is written via ``save_checkpoint`` and ``valid(net)`` is run.

    Args:
        train_loader: iterable yielding ``(LR, HR)`` tensor batches.
        scale_factor: passed to ``Net`` and embedded in checkpoint filenames.
        epoch_num: exclusive upper bound of the (0-based) epoch index.
    """
    net = Net(scale_factor).cuda()

    epoch_state = 0
    loss_list = []
    psnr_list = []
    ckpt = None

    if opt.resume:
        ckpt = torch.load(opt.resume)
        net.load_state_dict(ckpt['state_dict'])
        epoch_state = ckpt['epoch']
        loss_list = ckpt['loss']
        psnr_list = ckpt['psnr']

    optimizer = torch.optim.Adam(net.parameters(), lr=opt.lr)
    criterion_MSE = torch.nn.MSELoss().cuda()
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                step_size=opt.step,
                                                gamma=opt.gamma)

    if ckpt is not None:
        if 'optimizer' in ckpt and 'scheduler' in ckpt:
            optimizer.load_state_dict(ckpt['optimizer'])
            scheduler.load_state_dict(ckpt['scheduler'])
        else:
            # Checkpoints without optimizer/scheduler state: replay the StepLR
            # schedule so the learning rate matches the resumed epoch instead
            # of restarting from opt.lr.
            import warnings
            with warnings.catch_warnings():
                # StepLR warns about step() before optimizer.step(); expected here.
                warnings.simplefilter('ignore', UserWarning)
                for _ in range(epoch_state):
                    scheduler.step()

    for idx_epoch in range(epoch_state, epoch_num):
        loss_epoch = []
        psnr_epoch = []
        for LR, HR in train_loader:
            LR, HR = LR.cuda(), HR.cuda()
            target = HR[:, :, 3, :, :]
            SR = net(LR)

            loss = criterion_MSE(SR, target)
            loss_epoch.append(loss.item())
            # Detach so per-batch PSNR values cannot keep autograd graphs alive.
            psnr_epoch.append(psnr(SR.detach(), target.detach()))

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        scheduler.step()

        mean_loss = float(np.array(loss_epoch).mean())
        mean_psnr = float(np.array(psnr_epoch).mean())
        loss_list.append(mean_loss)
        psnr_list.append(mean_psnr)
        print(time.ctime()[4:-5] +
              ' Epoch---%d, loss_epoch---%f, PSNR---%f' %
              (idx_epoch + 1, mean_loss, mean_psnr))
        save_checkpoint(
            {
                'epoch': idx_epoch + 1,
                'state_dict': net.state_dict(),
                'loss': loss_list,
                'psnr': psnr_list,
                # Extra keys so a resumed run continues the same LR schedule.
                'optimizer': optimizer.state_dict(),
                'scheduler': scheduler.state_dict(),
            },
            save_path=opt.save,
            filename='model' + str(scale_factor) + '_epoch' +
            str(idx_epoch + 1) + '.pth.tar')
        valid(net)
예제 #4
0
def valid(net):
    """Print the mean PSNR of ``net`` over the validation loader.

    PSNR is computed per batch between ``net(LR)`` and ``HR[:, :, 3, :, :]``
    and averaged over batches. The network is evaluated in eval mode without
    autograd; its previous train/eval mode is restored afterwards.

    NOTE(review): the validation set is built from ``opt.train_dataset_dir``;
    confirm this is intended rather than a separate validation directory.

    Args:
        net: network called as ``net(LR)``.
    """
    valid_set = ValidSetLoader(opt.train_dataset_dir,
                               scale_factor=opt.scale_factor,
                               inType=opt.inType)
    # shuffle=False: the metric is a mean of per-batch means, so shuffling
    # makes the result vary between runs whenever the last batch is partial.
    valid_loader = DataLoader(dataset=valid_set,
                              num_workers=opt.threads,
                              batch_size=8,
                              shuffle=False)
    psnr_list = []
    was_training = net.training
    net.eval()
    try:
        # No gradients are needed for evaluation; avoids building graphs.
        with torch.no_grad():
            for LR, HR in valid_loader:
                LR, HR = LR.cuda(), HR.cuda()
                SR = net(LR)
                psnr_list.append(psnr(SR, HR[:, :, 3, :, :]))
    finally:
        # The caller keeps training afterwards, so restore the original mode.
        net.train(was_training)
    print('valid PSNR---%f' % (float(np.array(psnr_list).mean())))