Example #1
def main(cfg):
    video_name = cfg.video_name
    upscale_factor = cfg.upscale_factor
    use_gpu = cfg.gpu_mode

    # load the LR test frames (the Y-channel cube plus the upsampled Cb/Cr channels)
    test_set = TestsetLoader('data/' + video_name, upscale_factor)
    test_loader = DataLoader(test_set, num_workers=1, batch_size=1, shuffle=False)

    # build the SOF-VSR network and restore the trained weights for this scale
    net = SOFVSR(upscale_factor=upscale_factor)
    ckpt = torch.load('./log/SOFVSR_x' + str(upscale_factor) + '.pth')
    net.load_state_dict(ckpt)
    if use_gpu:
        net.cuda()


    for idx_iter, (LR_y_cube, SR_cb, SR_cr) in enumerate(test_loader):
        LR_y_cube = Variable(LR_y_cube)
        if use_gpu:
            LR_y_cube = LR_y_cube.cuda()
        SR_y = net(LR_y_cube)

        # stack Y with the upsampled Cb/Cr, convert YCbCr -> RGB, and quantize to 8-bit
        SR_y = np.array(SR_y.data)
        SR_y = SR_y[np.newaxis, :, :]

        SR_ycbcr = np.concatenate((SR_y, SR_cb, SR_cr), axis=0).transpose(1, 2, 0)
        SR_rgb = ycbcr2rgb(SR_ycbcr) * 255.0
        SR_rgb = np.clip(SR_rgb, 0, 255)
        SR_rgb = ToPILImage()(SR_rgb.astype(np.uint8))

        if not os.path.exists('results/' + video_name):
            os.mkdir('results/' + video_name)
        SR_rgb.save('results/'+video_name+'/sr_'+ str(idx_iter+2).rjust(2,'0') + '.png')
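
The example above leans on project-specific helpers (TestsetLoader, SOFVSR, ycbcr2rgb). For reference, a minimal, self-contained sketch of just the final save step, assuming a hypothetical float RGB image in [0, 1], could look like this:

import numpy as np
from torchvision.transforms import ToPILImage

# hypothetical float image in [0, 1], shape (H, W, 3)
img = np.random.rand(64, 64, 3).astype(np.float32)

rgb = np.clip(img * 255.0, 0, 255).astype(np.uint8)  # scale and clip to the 8-bit range
ToPILImage()(rgb).save('demo.png')                   # HWC uint8 array -> PIL image -> PNG
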
Example #2
    def extract_gall_feat(gall_loader):
        print('Extracting Gallery Feature...')
        ptr = 0
        gall_feat = np.zeros((ngall, feature_dim))

        # buffers for the raw RGB inputs and the learned maps, averaged later
        rgbs = np.zeros((ngall, 256, 128, 3))
        learns = np.zeros((ngall, 256, 128, 3))

        for batch_idx, (input, label) in enumerate(gall_loader):
            batch_num = input.size(0)

            model.eval()
            with torch.no_grad():
                input = input.to('cuda')

                #feat = model(input)
                feat, rgb, x = model(input)

                # move the intermediate maps to CPU before filling the numpy buffers
                rgbs[ptr:ptr + batch_num, :] = rgb.detach().cpu().numpy()
                learns[ptr:ptr + batch_num, :] = x.detach().cpu().numpy()

                gall_feat[ptr:ptr + batch_num, :] = feat.detach().cpu().numpy()
                ptr = ptr + batch_num

        # average the accumulated maps over the whole gallery and save them as images
        from torchvision.transforms import ToPILImage
        img1 = np.mean(rgbs, axis=0)
        img2 = np.mean(learns, axis=0)

        img1 = ToPILImage()(img1.astype(np.uint8))
        img2 = ToPILImage()(img2.astype(np.uint8))

        img1.save('regdb_RGB.jpg')
        img2.save('regdb_X.jpg')

        return gall_feat
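
The core pattern here, independent of the re-ID model, is to preallocate a numpy buffer of shape (num_items, feature_dim) and fill it batch by batch under torch.no_grad(). A minimal sketch with a toy dataset and a stand-in model (all names below are hypothetical):

import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

# toy stand-ins for the gallery loader and the embedding network
dataset = TensorDataset(torch.randn(10, 3, 32, 32), torch.zeros(10))
loader = DataLoader(dataset, batch_size=4, shuffle=False)
model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 8))
feature_dim = 8

model.eval()
feats = np.zeros((len(dataset), feature_dim))
ptr = 0
with torch.no_grad():
    for x, _ in loader:
        out = model(x)                                  # (batch, feature_dim)
        feats[ptr:ptr + x.size(0)] = out.cpu().numpy()  # fill the preallocated buffer
        ptr += x.size(0)
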
Example #3
def main(cfg):
    video_name = cfg.video_name
    upscale_factor = cfg.upscale_factor
    use_gpu = cfg.gpu_mode

    test_set = TestsetLoader('data/test/' + video_name, upscale_factor)
    test_loader = DataLoader(test_set,
                             num_workers=1,
                             batch_size=1,
                             shuffle=False)
    net = SOFVSR(upscale_factor=upscale_factor)
    ckpt = torch.load('./log/SOFVSR_x' + str(upscale_factor) + '.pth')
    net.load_state_dict(ckpt)
    if use_gpu:
        net.cuda()

    for idx_iter, (LR_y_cube, SR_cb, SR_cr) in enumerate(test_loader):
        LR_y_cube = Variable(LR_y_cube)
        if use_gpu:
            LR_y_cube = LR_y_cube.cuda()
            if cfg.chop_forward:
                # crop borders so that height and width are multiples of 16 before chopping
                _, _, h, w = LR_y_cube.size()
                h = int(h // 16) * 16
                w = int(w // 16) * 16
                LR_y_cube = LR_y_cube[:, :, :h, :w]
                SR_cb = SR_cb[:, :h * upscale_factor, :w * upscale_factor]
                SR_cr = SR_cr[:, :h * upscale_factor, :w * upscale_factor]
                SR_y = chop_forward(LR_y_cube, net, cfg.upscale_factor)
            else:
                SR_y = net(LR_y_cube)
            SR_y = SR_y.cpu()
        else:
            SR_y = net(LR_y_cube)

        SR_y = np.array(SR_y.data)
        SR_y = SR_y[np.newaxis, :, :]

        SR_ycbcr = np.concatenate((SR_y, SR_cb, SR_cr),
                                  axis=0).transpose(1, 2, 0)
        SR_rgb = ycbcr2rgb(SR_ycbcr) * 255.0
        SR_rgb = np.clip(SR_rgb, 0, 255)
        SR_rgb = ToPILImage()(SR_rgb.astype(np.uint8))

        if not os.path.exists('results/' + video_name):
            os.mkdir('results/' + video_name)
        SR_rgb.save('results/' + video_name + '/sr_' +
                    str(idx_iter + 2).rjust(2, '0') + '.png')
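
The chop_forward branch above first trims the input so both spatial dimensions are multiples of 16. That cropping step in isolation, on a hypothetical tensor, looks roughly like this:

import torch

x = torch.randn(1, 3, 270, 482)        # hypothetical LR clip, N x C x H x W
_, _, h, w = x.size()
h, w = (h // 16) * 16, (w // 16) * 16  # largest multiples of 16 that fit
x = x[:, :, :h, :w]                    # crop borders before chopping into patches
print(x.shape)                         # torch.Size([1, 3, 256, 480])
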
Example #4
def main():
    
	# Note: the model was trained with DataParallel, so it must also be wrapped in
	# DataParallel at test time; otherwise loading the weights fails with errors
	# such as: Missing key(s) in state_dict: ... (e.g. conv1.weight)
	print('testing processing....')

	# build the model and wrap it in DataParallel
	test_model = VRCNN(opt.upscale_factor)
	test_model = torch.nn.DataParallel(test_model,device_ids=gpus_list,output_device=gpus_list[1])

	test_model = test_model.cuda(gpus_list[0])

	print('---------- Networks architecture -------------')
	print_network(test_model)
	print('----------------------------------------------')

	# load the pretrained SR weights
	model_name = os.path.join(opt.model_save_folder,opt.exp_name,opt.test_model)
	print('model_name=',model_name)
	if os.path.exists(model_name):
		pretrained_dict=torch.load(model_name,map_location=lambda storage, loc: storage)
		model_dict=test_model.state_dict()
		# 1. filter out unnecessary keys
		pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
		# 2. overwrite entries in the existing state dict
		model_dict.update(pretrained_dict)
		test_model.load_state_dict(model_dict)
		print('Pre-trained SR model is loaded.')

	if not os.path.exists(opt.pre_result):
		os.mkdir(opt.pre_result)

	with open(opt.train_log + '/psnr_ssim-xr-200.txt', 'a') as psnr_ssim:
		with torch.no_grad():
			ave_psnr = 0
			ave_ssim = 0
			single_ave_psnr = 0
			single_ave_ssim = 0
			numb = 2
			valSet = ValidationsetLoader(opt.val_dataset_hr,opt.val_dataset_lr)
			valLoader = DataLoader(dataset=valSet,batch_size=opt.test_val_batchSize,shuffle=False)
			val_bar = tqdm(valLoader)
			for data in val_bar:
				test_model.eval()
				# dual_net.eval()
				batch_lr_y, label, SR_cb,SR_cr,idx,bicubic_restore = data
				batch_lr_y,label = Variable(batch_lr_y).cuda(gpus_list[0]), Variable(label).cuda(gpus_list[0])
				output = test_model(batch_lr_y)

				SR_ycbcr = np.concatenate((np.array(output.squeeze(0).data.cpu()), SR_cb, SR_cr), axis=0).transpose(1,2,0)            
				SR_rgb = ycbcr2rgb(SR_ycbcr) * 255.0
				SR_rgb = np.clip(SR_rgb, 0, 255)
				SR_rgb = ToPILImage()(SR_rgb.astype(np.uint8))
				#ToTensor() ---image(0-255)==>image(0-1), (H,W,C)==>(C,H,W)
				SR_rgb = ToTensor()(SR_rgb)

				# utils.save_image writes the tensor to an image file; a mini-batch tensor is first tiled into a grid via utils.make_grid and then saved
				if not os.path.exists(opt.pre_result+'/'+opt.exp_name):
					os.mkdir(opt.pre_result+'/'+opt.exp_name)
				utils.save_image(SR_rgb, opt.pre_result+'/' +opt.exp_name +'/' + 'my'+str(numb).rjust(3,'0')+'.png') 
				numb = numb + 1

				psnr_value =  psnr(np.array(torch.squeeze(label).data.cpu())*255,np.array(torch.squeeze(output).data.cpu())*255)
				ave_psnr = ave_psnr + psnr_value
				single_ave_psnr = single_ave_psnr + psnr_value
				ssim_value =  calculate_ssim(np.array(torch.squeeze(label).data.cpu())*255,np.array(torch.squeeze(output).data.cpu())*255)
				ave_ssim = ave_ssim + ssim_value
				single_ave_ssim = single_ave_ssim + ssim_value
				
				val_bar.set_description('===>{}th video {}th frame, wsPSNR:{:.4f} dB,wsSSIM:{:.6f}'.format(idx // 98 + 1,idx % 98 + 1,psnr_value,ssim_value))
				
				if idx == 293 or idx == 97 or idx == 195 or idx == 391:
					print("===> {}th video Avg. wsPSNR: {:.4f} dB".format(idx // 98+1,single_ave_psnr / 98))
					print("===> {}th video Avg. wsSSIM: {:.6f}".format(idx // 98+1,single_ave_ssim / 98))
					psnr_ssim.write('===>{}th video avg wsPSNR:{:.4f} dB,wsSSIM:{:.6f}\n'.format(idx // 98+1,single_ave_psnr / 98,single_ave_ssim / 98))
					single_ave_psnr = 0
					single_ave_ssim = 0

			print("===> All Avg. wsPSNR: {:.4f} dB".format(ave_psnr / len(valLoader)))
			print("===> ALL Avg. wsSSIM: {:.6f}".format(ave_ssim / len(valLoader)))
			psnr_ssim.write('===>all videos avg wsPSNR:{:.4f} dB,wsSSIM:{:.6f}\n'.format(ave_psnr / len(valLoader),ave_ssim / len(valLoader)))

	print('testing finished!')
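
The checkpoint-loading block in this example uses the common "partial state_dict" recipe: drop keys the current model does not have, then update and reload. A self-contained sketch of that recipe on a tiny throwaway model (the checkpoint dict below is made up for illustration):

import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 2))

# pretend checkpoint: one key matches the model, one is stale
pretrained_dict = {'0.weight': torch.ones(4, 4), 'old_layer.weight': torch.ones(2, 2)}

model_dict = model.state_dict()
# 1. filter out keys the current model does not have
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
# 2. overwrite the matching entries and reload
model_dict.update(pretrained_dict)
model.load_state_dict(model_dict)
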
Example #5
    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        if self.augment:
            # Set up the same image transforms for the chunk
            self.flip_p = random.random()
            self.hflip_p = random.random()
            # self.rot = RandomRotation.get_params((0, 90))
            self.color_transform = ColorJitter.get_params(brightness=(0.3,
                                                                      1.5),
                                                          contrast=(0.8, 1.2),
                                                          saturation=(0.8,
                                                                      1.2),
                                                          hue=(0, 0))

        # Construct target signals
        target = tr.tensor(self.label[idx])

        # Construct networks input
        A = tr.empty(self.C, self.H, self.W, dtype=tr.float)
        M = tr.empty(self.C, self.H, self.W, dtype=tr.float)

        with h5py.File(self.db_path, 'r') as db:
            frames = db['frames']
            img1 = frames[idx, :, :, :]
            img2 = frames[idx + 1, :, :, :]

            # ----------------------------
            # Crop baby with yolo
            # ----------------------------
            if self.crop and self.is_bbox:
                bbox = db['bbox'][idx, :]
                y1, y2, x1, x2 = bbox[0], bbox[1], bbox[2], bbox[3]

                # check to be inside image size
                if y2 > img1.shape[0]:
                    y2 = img1.shape[0]
                if x2 > img1.shape[1]:
                    x2 = img1.shape[1]
                if y1 < 0:
                    y1 = 0
                if x1 < 0:
                    x1 = 0
                # check validity
                if y2 - y1 < 1 or x2 - x1 < 1:
                    y1 = x1 = 0
                    y2, x2 = img1.shape[:2]

                img1 = img1[y1:y2, x1:x2, :]
                img2 = img2[y1:y2, x1:x2, :]
            elif self.crop and not self.is_bbox:
                x1, y1, x2, y2 = babybox(self.yolo, img1, self.device)
                img1 = img1[y1:y2, x1:x2, :]
                img2 = img2[y1:y2, x1:x2, :]

        # Downsample image
        try:
            # note: cv2.resize takes its target size as (width, height)
            img1 = cv2.resize(img1, (self.W, self.H),
                              interpolation=cv2.INTER_CUBIC)
            img2 = cv2.resize(img2, (self.W, self.H),
                              interpolation=cv2.INTER_CUBIC)
        except:
            print('\n--------- ERROR! -----------\nUsual cv empty error')
            print(f'Shape of img1: {img1.shape}; Shape of img2: {img2.shape}')
            print(f'bbox: {bbox}')
            print(f'This is at idx: {idx}')
            exit(666)

        if self.augment:
            img1 = ToPILImage()(img1)
            img2 = ToPILImage()(img2)
            if self.flip_p > 0.5:
                img1 = TF.vflip(img1)
                img2 = TF.vflip(img2)
            if self.hflip_p > 0.5:
                img1 = TF.hflip(img1)
                img2 = TF.hflip(img2)

            # img1 = TF.rotate(img1, self.rot)
            # img2 = TF.rotate(img2, self.rot)

            img1 = self.color_transform(img1)
            img2 = self.color_transform(img2)

            img1 = tr.from_numpy(np.array(img1).astype(np.float32))
            img2 = tr.from_numpy(np.array(img2).astype(np.float32))
            img1 = img1.permute(2, 0, 1)
            img2 = img2.permute(2, 0, 1)
        else:
            img1 = tr.from_numpy(img1.astype(np.float32))
            img2 = tr.from_numpy(img2.astype(np.float32))
            # Swap axes because  numpy image: H x W x C | torch image: C X H X W
            img1 = img1.permute(2, 0, 1)
            img2 = img2.permute(2, 0, 1)

        # 2.) construct the normalized frame difference for motion stream input
        M = tr.div(img2 - img1, img1 + img2 + 1)  # +1 for numerical stability
        # M = tr.sub(M, tr.mean(M, (1, 2)).view(3, 1, 1))  # spatial intensity norm for each channel

        A = img1 / 255.  # convert image to [0, 1]
        A = tr.sub(A,
                   tr.mean(A, (1, 2)).view(
                       3, 1, 1))  # spatial intensity norm for each channel

        sample = ((A, M), target)

        # Video shape: C x D x H X W
        return sample
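
The motion/appearance construction at the end is the part most easily lifted out: the motion input is the frame difference normalized by local brightness, and the appearance input is the first frame scaled to [0, 1] with its per-channel spatial mean removed. A minimal sketch with random stand-in frames:

import torch as tr

# two hypothetical consecutive frames, C x H x W, values in [0, 255]
img1 = tr.rand(3, 64, 64) * 255
img2 = tr.rand(3, 64, 64) * 255

# motion stream: normalized frame difference (+1 avoids division by zero)
M = tr.div(img2 - img1, img1 + img2 + 1)

# appearance stream: scale to [0, 1], subtract the per-channel spatial mean
A = img1 / 255.
A = tr.sub(A, tr.mean(A, (1, 2)).view(3, 1, 1))
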