def forward(self, y, Fker, beta=None, z=None):
    """Compute ``xhat`` with a closed-form frequency-domain update.

    For every image ``i`` the spectrum of ``xhat[i, 0]`` is

        (conj(K_i) * F(y_i) + sum_0(beta[i] * conj(W) * W * F(z_i)))
        / (|K_i|^2 + sum_0(beta[i] * conj(W) * W))

    with ``K_i = Fker[i]``, ``W = Fwv(self.dec2d, shape)``, ``F = cf.fft`` and
    ``sum_0`` a ``torch.sum`` over dim 0. All complex arithmetic goes through
    the ``cf`` helpers, and the result is brought back with ``cf.ifft``.

    Args:
        y: batch of observations; only channel 0 of each image
            (``y[i, 0, ]``) is used. Presumably (N, C, H, W) — TODO confirm.
        Fker: per-image kernel spectra, indexed as ``Fker[i]``.
        beta: per-image weights indexed as ``beta[i]``. Defaults to
            ``self.beta`` viewed as (chn_num, 1, 1, 1), moved to CUDA and
            stacked once per image.
        z: auxiliary estimate with the same layout as ``y``; defaults to zeros.

    Returns:
        A tensor like ``y``. Channel 0 of each image holds the update and
        every other entry is zero. An empty batch yields an empty tensor.
    """
    im_num = y.shape[0]
    if z is None:
        z = torch.zeros_like(y)
    xhat = torch.zeros_like(y)
    if im_num == 0:
        # Nothing to solve; this also avoids torch.stack([]) below, which raises.
        return xhat
    if beta is None:
        beta = torch.stack(
            [self.beta.view(self.chn_num, 1, 1, 1).cuda()] * im_num, dim=0)
    # Every y[i, 0] has the same spatial size, so the filter-bank spectrum and
    # its energy term conj(W) * W do not depend on i. Build them once.
    shape = y[0, 0, ].size()[-2:]
    Fw = Fwv(self.dec2d, shape=shape).cuda()
    Fw_conj = cf.conj(Fw).cuda()
    Fw_energy = cf.mul(Fw_conj, Fw)
    for i in range(im_num):
        Fker_conj = cf.conj(Fker[i]).cuda()
        Fy = cf.fft(y[i, 0, ])
        Fz = cf.fft(z[i, 0, ]).cuda()
        Fx_num = cf.mul(Fker_conj, Fy) + torch.sum(
            beta[i] * cf.mul(Fw_conj, cf.mul(Fw, Fz)), dim=0)
        Fx_den = cf.abs_square(Fker[i], keepdim=True) + torch.sum(
            beta[i] * Fw_energy, dim=0)
        Fx = cf.div(Fx_num, Fx_den)
        xhat[i, 0, ] = cf.ifft(Fx)
    return xhat
def forward(self, y, Fker, z=None, u=None):
    """Compute ``xhat`` with a closed-form frequency-domain update.

    For every image ``i`` the spectrum of ``xhat[i, 0]`` is

        (conj(K_i) * F(y_i - u_i) + sum_0(lmd * conj(W) * W * F(z_i)))
        / (|K_i|^2 + sum_0(lmd * conj(W) * W))

    with ``K_i = Fker[i]``, ``W = Fwv(self.dec2d, shape)``, ``lmd = self.lmd``,
    ``F = cf.fft`` and ``sum_0`` a ``torch.sum`` over dim 0.

    Args:
        y: batch of observations; only channel 0 of each image is used.
            Presumably (N, C, H, W) — TODO confirm.
        Fker: per-image kernel spectra, indexed as ``Fker[i]``.
        z: auxiliary estimate with the same layout as ``y``; defaults to zeros.
        u: term subtracted from ``y`` before the transform; defaults to zeros.

    Returns:
        A tensor like ``y``. Channel 0 of each image holds the update and
        every other entry is zero. An empty batch yields an empty tensor.
    """
    if z is None:
        z = torch.zeros_like(y)
    if u is None:
        u = torch.zeros_like(y)
    im_num = y.shape[0]
    xhat = torch.zeros_like(y)
    if im_num == 0:
        return xhat
    # Every y[i, 0] has the same spatial size, and self.lmd does not change
    # inside the loop. The filter-bank spectrum and the whole regulariser in
    # the denominator are therefore the same for every image. Build them once.
    shape = y[0, 0, ].size()[-2:]
    Fw = Fwv(self.dec2d, shape=shape).cuda()
    Fw_conj = cf.conj(Fw).cuda()
    reg_den = torch.sum(self.lmd * cf.mul(Fw_conj, Fw), dim=0)
    for i in range(im_num):
        Fker_conj = cf.conj(Fker[i]).cuda()
        # subtract u to incorporate the prior approximation of noise
        Fy = cf.fft(y[i, 0, ] - u[i, 0, ])
        Fz = cf.fft(z[i, 0, ]).cuda()
        Fx_num = cf.mul(Fker_conj, Fy) + torch.sum(
            self.lmd * cf.mul(Fw_conj, cf.mul(Fw, Fz)), dim=0)
        Fx_den = cf.abs_square(Fker[i], keepdim=True) + reg_den
        Fx = cf.div(Fx_num, Fx_den)
        xhat[i, 0, ] = cf.ifft(Fx)
    return xhat
def __getitem__(self, item):
    """Return test sample ``item`` as a dict.

    ``item`` selects image ``item // self.ker_num`` and kernel
    ``item % self.ker_num``; file names use 1-based numbers.

    The estimated kernel comes from ``self.ker_dir`` and is normalised to
    unit sum. The true kernel comes from ``self.get_ker``. When
    ``self.taper == 'valid'`` the blurred image goes through
    ``pad_for_kernel(..., tr_ker, 'edge')``, then ``edgetaper(..., ker)``, and
    is cast to float32.

    Keys: 'bl', 'sp', 'Fker', 'padding', 'ker', 'tr_ker', 'tr_Fker', 'name'.
    'ker' and 'tr_ker' are 50x50 NaN-filled arrays with the kernel in the
    top-left corner.

    NOTE(review): blurred and kernel files are the first ``glob`` hit, and
    glob ordering is not guaranteed.
    """
    img_idx = item // self.ker_num
    ker_idx = item % self.ker_num

    def _nan_canvas(kernel):
        # Fixed 50x50 NaN array with ``kernel`` copied into the top-left corner.
        canvas = np.full([50, 50], np.nan)
        canvas[:kernel.shape[0], :kernel.shape[1]] = kernel
        return canvas

    def _split(size):
        # Split size - 1 into (larger half, smaller half).
        total = size - 1
        low = total // 2
        return total - low, low

    sharp = imread(os.path.join(self.sp_dir, 'im_%d.png' % (img_idx + 1)))
    blurred_hits = glob(
        os.path.join(self.bl_dir, 'im_%d_ker_%d*.png' % (img_idx + 1, ker_idx + 1)))
    blurred = imread(blurred_hits[0])
    est_hits = glob(
        os.path.join(self.ker_dir, 'k_%d_im_%d_*' % (ker_idx + 1, img_idx + 1)))
    est_ker = imread(est_hits[0])
    est_ker = est_ker / np.sum(est_ker)

    true_ker = self.get_ker(ker_idx)
    true_canvas = _nan_canvas(true_ker)
    true_spectrum = cf.fft(
        torch.FloatTensor(for_fft(true_ker, shape=np.shape(sharp)))).unsqueeze(0)

    if self.taper == 'valid':
        from utils.imtools import pad_for_kernel, edgetaper
        blurred = edgetaper(pad_for_kernel(blurred, true_ker, 'edge'), est_ker)
        blurred = blurred.astype(np.float32)

    est_canvas = _nan_canvas(est_ker)
    est_spectrum = cf.fft(
        torch.FloatTensor(for_fft(est_ker, shape=np.shape(sharp)))).unsqueeze(0)

    hx, hy = _split(est_ker.shape[0])
    wx, wy = _split(est_ker.shape[1])
    padding = np.array((hx, hy, wx, wy), dtype=np.int64)

    return {
        'bl': torch.from_numpy(blurred).unsqueeze(0),
        'sp': torch.from_numpy(sharp).unsqueeze(0),
        'Fker': est_spectrum,
        'padding': padding.copy(),
        'ker': est_canvas.copy(),
        'tr_ker': true_canvas.copy(),
        'tr_Fker': true_spectrum,
        'name': 'im_%d_ker_%d' % (img_idx + 1, ker_idx + 1),
    }
def __getitem__(self, item):
    """Load test sample ``item``: a blurred RGB image and its kernel.

    ``self.ker_file[item]`` is a kernel image path. The blurred image name
    and the sample name are parsed from that path.

    The blurred image is read from ``self.bl_dir + <name> + '.jpg'``, cast to
    float32 and divided by 255. The kernel is converted to gray if it is 3-D
    and normalised to unit sum. When ``self.taper == 'valid'``, channels 0..2
    each go through ``pad_for_kernel(..., ker, 'edge')`` and
    ``edgetaper(..., ker)`` and are then stacked. This assumes a 3-channel
    image — TODO confirm.

    Returns:
        dict with
          'bl'   -- ``torch.from_numpy(bl).unsqueeze(0)``,
          'Fker' -- ``cf.fft`` of the kernel laid out by ``for_fft`` at the
                    size of one (possibly padded) image channel, with a
                    leading unit dim,
          'ker'  -- the kernel in a 110x110 NaN-filled array (top-left),
          'name' -- the kernel file stem.
    """
    ker_name = self.ker_file[item]
    # Image stem: text right after the first "s/" in the path, up to the
    # last "_" followed by a digit or '*' ([\d*] is a character class).
    sp_name = re.findall(r'(?<=s\/).+(?=_[\d*])', ker_name)[0]
    # Sample name: same start, up to the ".png" extension. The dot is escaped
    # so that only a literal ".png" ends the match.
    name = re.findall(r'(?<=s\/).+(?=\.png)', ker_name)[0]
    bl = imread(self.bl_dir + sp_name + '.jpg').astype(np.float32) / 255
    ker = imread(ker_name)
    if ker.ndim == 3:
        ker = rgb2gray(ker)
    ker = ker / np.sum(ker)
    # NOTE: the kernel is deliberately not rotated by 180 degrees here.
    if self.taper == 'valid':
        from utils.imtools import pad_for_kernel, edgetaper
        bl = np.stack(
            [edgetaper(pad_for_kernel(bl[:, :, chn], ker, 'edge'),
                       ker).astype(np.float32)
             for chn in range(3)],
            axis=2)
    ker_pad = np.full([110, 110], np.nan)
    ker_pad[:ker.shape[0], :ker.shape[1]] = ker
    ker_mat = torch.FloatTensor(for_fft(ker, shape=np.shape(bl[:, :, 0])))
    Fker = cf.fft(ker_mat).unsqueeze(0)
    bl = torch.from_numpy(bl).unsqueeze(0)
    # NOTE(review): calling imshow for every sample inside __getitem__ looks
    # like leftover debugging. Confirm it is wanted in the data-loading path.
    imshow(bl, 'im%d_pad' % item)
    return {'bl': bl, 'Fker': Fker, 'ker': ker_pad.copy(), 'name': name}
def blurbyker(z, Fker):
    """Filter channel 0 of every ``z[i]`` with the kernel spectrum ``Fker[i]``.

    Each ``z[i, 0]`` is transformed with ``cf.fft``, multiplied by
    ``Fker[i]`` via ``cf.mul`` and brought back with ``cf.ifft``. The result
    has the shape of ``z``. Only channel 0 is filled; every other entry stays
    zero.
    """
    blurred = torch.zeros_like(z)
    for idx in range(z.size(0)):
        spectrum = cf.fft(z[idx, 0, ])
        blurred[idx, 0, ] = cf.ifft(cf.mul(Fker[idx], spectrum))
    return blurred
def forward(self, x, y, Fker):
    """Return ``self.net(y - k*x, x)``.

    ``k*x`` is built per image: ``x[i, 0]`` is transformed with ``cf.fft``,
    multiplied by ``Fker[i]`` via ``cf.mul`` and brought back with
    ``cf.ifft``. All other channels of that tensor stay zero.
    """
    blurred = torch.zeros_like(x)
    for idx in range(x.size(0)):
        spectrum = cf.fft(x[idx, 0, ])
        blurred[idx, 0, ] = cf.ifft(cf.mul(Fker[idx], spectrum))
    residual = y - blurred
    return self.net(residual, x)
def __getitem__(self, item):
    """Return training sample ``item`` as a dict.

    ``item`` selects image ``item // self.ker_num`` and kernel
    ``item % self.ker_num``; file names use 1-based numbers. The kernel comes
    from ``self.get_ker``.

    Keys:
      'bl'   -- blurred image tensor with a leading unit dim,
      'sp'   -- sharp image tensor with a leading unit dim,
      'Fker' -- ``cf.fft`` of the kernel laid out by ``for_fft`` at the sharp
                image size (no extra leading dim),
      'ker'  -- the kernel in a 50x50 NaN-filled array (top-left).
    """
    img_idx = item // self.ker_num
    ker_idx = item % self.ker_num
    sharp = imread(os.path.join(self.sp_dir, 'im_%d.png' % (img_idx + 1)))
    blurred = imread(
        os.path.join(self.bl_dir, 'im_%d_ker_%d.png' % (img_idx + 1, ker_idx + 1)))
    kernel = self.get_ker(ker_idx)
    canvas = np.full([50, 50], np.nan)
    canvas[:kernel.shape[0], :kernel.shape[1]] = kernel
    spectrum = cf.fft(torch.FloatTensor(for_fft(kernel, shape=np.shape(sharp))))
    return {
        'bl': torch.from_numpy(blurred).unsqueeze(0),
        'sp': torch.from_numpy(sharp).unsqueeze(0),
        'Fker': spectrum,
        'ker': canvas.copy(),
    }
def __getitem__(self, item):
    """Load test sample ``item``: a sharp/blurred RGB pair plus two kernels.

    ``self.ker_file[item]`` is the path of an estimated kernel. The sharp
    image name (between ``psf_`` and ``_kernel``) and the true-kernel name
    (between ``_kernel_`` and ``_1``) are parsed from it.

    Both kernels are converted to gray if 3-D and normalised to unit sum. The
    estimated kernel is also rotated by 180 degrees and then passed through
    ``center_ker(ker, tr_ker)``. When ``self.taper == 'valid'``, blurred
    channels 0..2 each go through ``pad_for_kernel(..., tr_ker, 'edge')`` and
    ``edgetaper(..., ker)``, are cast to float32 and stacked.

    Returns:
        dict with
          'bl'      -- blurred image tensor with a leading unit dim,
          'sp'      -- first three channels of the sharp image, leading unit dim,
          'Fker'    -- ``cf.fft`` of the estimated kernel laid out by
                       ``for_fft`` at the sharp image's spatial size,
                       with a leading unit dim,
          'padding' -- int64 array (hx, hy, wx, wy) from the estimated kernel size,
          'ker'     -- estimated kernel in a 110x110 NaN-filled array (top-left),
          'tr_ker'  -- true kernel in a 110x110 NaN-filled array (top-left),
          'tr_Fker' -- like 'Fker' but for the true kernel,
          'name'    -- ``<image>_<true kernel>``.
    """
    ker_name = self.ker_file[item]
    sp_name = re.findall(r'psf_([\s\S]*)_kernel', ker_name)[0]
    tr_ker_name = re.findall(r'_kernel_([\s\S]*)_1', ker_name)[0]
    sp = imread(self.sp_dir + sp_name + '.png')[:, :, :3]
    bl = imread(self.bl_dir + sp_name + '_kernel_' + tr_ker_name + '.png')

    def _load_ker(path):
        # Read a kernel image, collapse colour to gray only when it has a
        # channel axis, and normalise to unit sum.
        k = imread(path)
        if k.ndim == 3:
            k = rgb2gray(k)
        return k / np.sum(k)

    ker = np.rot90(_load_ker(ker_name), 2)
    tr_ker = _load_ker('./data/Lai_NK/kernels/kernel_' + tr_ker_name + '.png')
    ker = center_ker(ker, tr_ker)
    tr_ker_pad = np.full([110, 110], np.nan)
    tr_ker_pad[:tr_ker.shape[0], :tr_ker.shape[1]] = tr_ker
    if self.taper == 'valid':
        from utils.imtools import pad_for_kernel, edgetaper
        # Stack the tapered channels instead of writing them into
        # np.zeros_like(sp). That buffer inherits sp's dtype, which would
        # undo the float32 cast, or truncate the values for an integer dtype.
        bl = np.stack(
            [edgetaper(pad_for_kernel(bl[:, :, chn], tr_ker, 'edge'),
                       ker).astype(np.float32)
             for chn in range(3)],
            axis=2)
    spatial = np.shape(sp[:, :, 0])
    tr_ker_mat = torch.FloatTensor(for_fft(tr_ker, shape=spatial))
    tr_Fker = cf.fft(tr_ker_mat).unsqueeze(0)
    ker_pad = np.full([110, 110], np.nan)
    ker_pad[:ker.shape[0], :ker.shape[1]] = ker
    ker_mat = torch.FloatTensor(for_fft(ker, shape=spatial))
    Fker = cf.fft(ker_mat).unsqueeze(0)
    hy = (ker.shape[0] - 1) // 2
    hx = (ker.shape[0] - 1) - hy
    wy = (ker.shape[1] - 1) // 2
    wx = (ker.shape[1] - 1) - wy
    padding = np.array((hx, hy, wx, wy), dtype=np.int64)
    sp = torch.from_numpy(sp).unsqueeze(0)
    bl = torch.from_numpy(bl).unsqueeze(0)
    return {
        'bl': bl,
        'sp': sp,
        'Fker': Fker,
        'padding': padding.copy(),
        'ker': ker_pad.copy(),
        'tr_ker': tr_ker_pad.copy(),
        'tr_Fker': tr_Fker,
        'name': sp_name + '_' + tr_ker_name,
    }