def coldfilt(X, ha, hb, highpass=False, mode='symmetric'):
    """Filter down the columns of ``X`` with ``ha``/``hb`` and halve the rows.

    The rows of ``X`` are re-indexed with ``symm_pad`` (defined elsewhere;
    presumably a symmetric-extension index helper) and split into two
    alternating sample streams. The streams are filtered by ``ha`` and ``hb``
    respectively with a row stride of 2. The two results are then
    re-interleaved row by row.

    Args:
        X: tensor of shape ``(batch, ch, r, c)``, where ``r`` must be a
            multiple of 4. ``None`` or a 0-dim tensor yields a
            ``(1, 1, 1, 1)`` zero tensor.
        ha, hb: filters whose dim 2 holds the ``m`` taps; presumably shaped
            ``(1, 1, m, 1)`` -- TODO confirm against callers.
        highpass: if True, each interleaved pair of output rows starts with
            the ``hb`` result; otherwise it starts with the ``ha`` result.
        mode: padding mode; only ``'symmetric'`` is implemented.

    Returns:
        Tensor of shape ``(batch, ch, r // 2, c)``.

    Raises:
        ValueError: if ``r`` is not a multiple of 4.
        NotImplementedError: if ``mode`` is not ``'symmetric'``.
    """
    if X is None or X.shape == torch.Size([]):
        # X.device would raise AttributeError on None, so only read it when
        # X exists.
        device = None if X is None else X.device
        return torch.zeros(1, 1, 1, 1, device=device)

    batch, ch, r, c = X.shape
    r2 = r // 2
    if r % 4 != 0:
        raise ValueError('No. of rows in X must be a multiple of 4\n' +
                         'X was {}'.format(X.shape))
    if mode != 'symmetric':
        raise NotImplementedError()

    m = ha.shape[2]
    xe = symm_pad(r, m)
    X1 = X[:, :, xe[2::2]]
    X2 = X[:, :, xe[3::2]]
    # One grouped conv applies both filters. Channels [0, ch) hold X1 and meet
    # the ha copies; channels [ch, 2ch) hold X2 and meet the hb copies.
    h = torch.cat((ha.repeat(ch, 1, 1, 1), hb.repeat(ch, 1, 1, 1)), dim=0)
    Y = F.conv2d(torch.cat((X1, X2), dim=1), h, stride=(2, 1),
                 groups=ch * 2)
    Y1 = Y[:, :ch]
    Y2 = Y[:, ch:]

    # Stacking on a new axis just before the column axis and merging it into
    # the row axis interleaves the ROWS of the two outputs.
    first, second = (Y2, Y1) if highpass else (Y1, Y2)
    return torch.stack((first, second), dim=-2).view(batch, ch, r2, c)
def colfilter(X, h):
    """Filter down the columns (dim 2) of ``X`` with ``h``, one filter per channel.

    Rows are re-indexed with ``symm_pad`` (defined elsewhere; presumably a
    symmetric-extension index helper) before a grouped ``F.conv2d``.

    NOTE(review): ``colfilter`` is defined again later in this file with an
    extra ``mode`` argument. If both live in the same module, that later
    definition replaces this one at import time.

    Args:
        X: presumably a 4-D tensor ``(batch, ch, r, c)``. ``None``, a 0-dim
            tensor or an empty 1-D tensor yields a ``(1, 1, 1, 1)`` zero
            tensor.
        h: filter whose dim 2 holds the taps; presumably ``(1, 1, L, 1)`` --
            TODO confirm.

    Returns:
        The grouped convolution of the row-extended ``X`` with ``h``.
    """
    if X is None or X.shape in (torch.Size([]), torch.Size([0])):
        # X.device would raise AttributeError on None.
        device = None if X is None else X.device
        return torch.zeros(1, 1, 1, 1, device=device)
    ch, r = X.shape[1:3]
    m = h.shape[2] // 2
    xe = symm_pad(r, m)
    return F.conv2d(X[:, :, xe], h.repeat(ch, 1, 1, 1), groups=ch)
def rowdfilt(X, ha, hb, highpass=False, mode='symmetric'):
    """Filter along the rows of ``X`` with ``ha``/``hb`` and halve the columns.

    The columns of ``X`` are re-indexed with ``symm_pad`` (defined elsewhere;
    presumably a symmetric-extension index helper) and split into two
    alternating sample streams. The streams are filtered by ``ha`` and ``hb``
    respectively with a column stride of 2. The two results are then
    re-interleaved column by column.

    Args:
        X: tensor of shape ``(batch, ch, r, c)``, where ``c`` must be a
            multiple of 4. ``None`` or a 0-dim tensor yields a
            ``(1, 1, 1, 1)`` zero tensor.
        ha, hb: filters with ``m`` taps on dim 2; they are reshaped to
            ``(1, 1, 1, m)`` row filters here.
        highpass: if True, each interleaved pair of output columns starts with
            the ``hb`` result; otherwise it starts with the ``ha`` result.
        mode: padding mode; only ``'symmetric'`` is implemented.

    Returns:
        Tensor of shape ``(batch, ch, r, c // 2)``.

    Raises:
        ValueError: if ``c`` is not a multiple of 4.
        NotImplementedError: if ``mode`` is not ``'symmetric'``.
    """
    if X is None or X.shape == torch.Size([]):
        # X.device would raise AttributeError on None.
        device = None if X is None else X.device
        return torch.zeros(1, 1, 1, 1, device=device)

    batch, ch, r, c = X.shape
    c2 = c // 2
    if c % 4 != 0:
        raise ValueError('No. of cols in X must be a multiple of 4\n' +
                         'X was {}'.format(X.shape))
    if mode != 'symmetric':
        raise NotImplementedError()

    m = ha.shape[2]
    xe = symm_pad(c, m)
    # Channels [0, ch) carry one sample stream and meet the ha copies;
    # channels [ch, 2ch) carry the other and meet the hb copies.
    X = torch.cat((X[:, :, :, xe[2::2]], X[:, :, :, xe[3::2]]), dim=1)
    h = torch.cat((ha.reshape(1, 1, 1, m).repeat(ch, 1, 1, 1),
                   hb.reshape(1, 1, 1, m).repeat(ch, 1, 1, 1)), dim=0)
    X = F.conv2d(X, h, stride=(1, 2), groups=ch * 2)

    # Stacking on a new trailing axis and merging it into the column axis
    # interleaves the COLUMNS of the two outputs -> (batch, ch, r, c // 2).
    if highpass:
        first, second = X[:, ch:], X[:, :ch]
    else:
        first, second = X[:, :ch], X[:, ch:]
    return torch.stack((first, second), dim=-1).view(batch, ch, r, c2)
def rowfilter(X, h):
    """Filter along the rows (dim 3) of ``X`` with ``h``, one filter per channel.

    Columns are re-indexed with ``symm_pad`` (defined elsewhere; presumably a
    symmetric-extension index helper). ``h`` is transposed from a column
    filter to a row filter before a grouped ``F.conv2d``.

    NOTE(review): ``rowfilter`` is defined again later in this file with an
    extra ``mode`` argument. If both live in the same module, that later
    definition replaces this one at import time.

    Args:
        X: 4-D tensor ``(batch, ch, r, c)``. ``None``, a 0-dim tensor or an
            empty 1-D tensor yields a ``(1, 1, 1, 1)`` zero tensor.
        h: filter whose dim 2 holds the taps; it is transposed on dims 2/3.

    Returns:
        The grouped convolution of the column-extended ``X`` with ``h``.
    """
    if X is None or X.shape in (torch.Size([]), torch.Size([0])):
        # X.device would raise AttributeError on None.
        device = None if X is None else X.device
        return torch.zeros(1, 1, 1, 1, device=device)
    ch, _, c = X.shape[1:]
    m = h.shape[2] // 2
    xe = symm_pad(c, m)
    h = h.transpose(2, 3).contiguous()
    return F.conv2d(X[:, :, :, xe], h.repeat(ch, 1, 1, 1), groups=ch)
def colfilter(X, h, mode='symmetric'):
    """Filter down the columns (dim 2) of ``X`` with ``h``, one filter per channel.

    Args:
        X: presumably a 4-D tensor ``(batch, ch, r, c)``. ``None``, a 0-dim
            tensor or an empty 1-D tensor yields a ``(1, 1, 1, 1)`` zero
            tensor.
        h: filter whose dim 2 holds the taps; presumably ``(1, 1, L, 1)`` --
            TODO confirm.
        mode: ``'symmetric'`` re-indexes rows with ``symm_pad`` (defined
            elsewhere; presumably a symmetric-extension index helper). Any
            other value zero-pads ``h.shape[2] // 2`` rows on each side via
            ``F.conv2d(padding=...)``.

    Returns:
        The grouped convolution of the (extended or padded) ``X`` with ``h``.
    """
    if X is None or X.shape in (torch.Size([]), torch.Size([0])):
        # X.device would raise AttributeError on None.
        device = None if X is None else X.device
        return torch.zeros(1, 1, 1, 1, device=device)
    ch, r = X.shape[1:3]
    m = h.shape[2] // 2
    if mode == 'symmetric':
        xe = symm_pad(r, m)
        y = F.conv2d(X[:, :, xe], h.repeat(ch, 1, 1, 1), groups=ch)
    else:
        y = F.conv2d(X, h.repeat(ch, 1, 1, 1), groups=ch, padding=(m, 0))
    return y
def colifilt(X, ha, hb, highpass=False, mode='symmetric'):
    """Filter down the columns of ``X`` with ``ha``/``hb`` and double the rows.

    Each filter is split into its odd-indexed taps (``[1::2]``) and its
    even-indexed taps (``[::2]``). Four row-sampled copies of ``X`` (taken
    from the ``symm_pad`` index array, which is defined elsewhere and is
    presumably a symmetric-extension helper) are filtered by four of these
    half-filters in one grouped conv. The four results are then interleaved
    row by row.

    Args:
        X: tensor of shape ``(batch, ch, r, c)``, where ``r`` must be even.
            ``None`` or a 0-dim tensor yields a ``(1, 1, 1, 1)`` zero tensor.
        ha, hb: filters with ``m`` taps on dim 2. ``m`` is presumably even:
            the odd/even halves must have equal length for the concatenation
            below to succeed.
        highpass: swaps the sample streams within each pair fed to the filters.
        mode: NOTE(review): accepted but not used by this function.

    Returns:
        Tensor of shape ``(batch, ch, 2 * r, c)``.

    Raises:
        ValueError: if ``r`` is odd.
    """
    if X is None or X.shape == torch.Size([]):
        # X.device would raise AttributeError on None.
        device = None if X is None else X.device
        return torch.zeros(1, 1, 1, 1, device=device)

    m = ha.shape[2]
    m2 = m // 2
    hao = ha[:, :, 1::2]
    hae = ha[:, :, ::2]
    hbo = hb[:, :, 1::2]
    hbe = hb[:, :, ::2]

    batch, ch, r, c = X.shape
    if r % 2 != 0:
        raise ValueError('No. of rows in X must be a multiple of 2.\n' +
                         'X was {}'.format(X.shape))
    xe = symm_pad(r, m2)

    # The parity of m2 decides which half-filters are used and which row
    # samples each one sees. The order of `rows` pairs up with `filters`.
    if m2 % 2 == 0:
        filters = (hae, hbe, hao, hbo)
        if highpass:
            rows = (xe[1:-2:2], xe[:-2:2], xe[3::2], xe[2::2])
        else:
            rows = (xe[:-2:2], xe[1:-2:2], xe[2::2], xe[3::2])
    else:
        filters = (hao, hbo, hae, hbe)
        if highpass:
            rows = (xe[2:-1:2], xe[1:-1:2], xe[2:-1:2], xe[1:-1:2])
        else:
            rows = (xe[1:-1:2], xe[2:-1:2], xe[1:-1:2], xe[2:-1:2])

    # Channel block k holds sample stream k and meets ch copies of filter k.
    X = torch.cat([X[:, :, idx] for idx in rows], dim=1)
    h = torch.cat([f.repeat(ch, 1, 1, 1) for f in filters], dim=0)
    X = F.conv2d(X, h, groups=4 * ch)

    # Stack the 4 results, each (batch, ch, r // 2, c), on a new dim 3 and
    # merge it into the rows. Output row 4*i + k is row i of result k.
    X = torch.stack(
        [X[:, :ch], X[:, ch:2 * ch], X[:, 2 * ch:3 * ch], X[:, 3 * ch:]],
        dim=3).view(batch, ch, r * 2, c)
    return X
def rowfilter(X, h, mode='symmetric'):
    """Filter along the rows (dim 3) of ``X`` with ``h``, one filter per channel.

    ``h`` is given as a column filter (taps on dim 2) and is transposed to a
    row filter before the grouped ``F.conv2d``.

    Args:
        X: 4-D tensor ``(batch, ch, r, c)``. ``None``, a 0-dim tensor or an
            empty 1-D tensor yields a ``(1, 1, 1, 1)`` zero tensor.
        h: filter whose dim 2 holds the taps.
        mode: ``'symmetric'`` re-indexes columns with ``symm_pad`` (defined
            elsewhere; presumably a symmetric-extension index helper). Any
            other value zero-pads ``h.shape[2] // 2`` columns on each side via
            ``F.conv2d(padding=...)``.

    Returns:
        The grouped convolution of the (extended or padded) ``X`` with ``h``.
    """
    if X is None or X.shape in (torch.Size([]), torch.Size([0])):
        # X.device would raise AttributeError on None.
        device = None if X is None else X.device
        return torch.zeros(1, 1, 1, 1, device=device)
    ch, _, c = X.shape[1:]
    m = h.shape[2] // 2
    h = h.transpose(2, 3).contiguous()
    if mode == 'symmetric':
        xe = symm_pad(c, m)
        y = F.conv2d(X[:, :, :, xe], h.repeat(ch, 1, 1, 1), groups=ch)
    else:
        y = F.conv2d(X, h.repeat(ch, 1, 1, 1), groups=ch, padding=(0, m))
    return y
def rowifilt(X, ha, hb, highpass=False, mode='symmetric'):
    """Filter along the rows of ``X`` with ``ha``/``hb`` and double the columns.

    Each filter is split into its odd-indexed taps (``[1::2]``) and its
    even-indexed taps (``[::2]``). These halves are reshaped to row filters.
    Four column-sampled copies of ``X`` (taken from the ``symm_pad`` index
    array, which is defined elsewhere and is presumably a symmetric-extension
    helper) are filtered in one grouped conv. The four results are then
    interleaved column by column.

    Args:
        X: tensor of shape ``(batch, ch, r, c)``, where ``c`` must be even.
            ``None`` or a 0-dim tensor yields a ``(1, 1, 1, 1)`` zero tensor.
        ha, hb: filters with ``m`` taps on dim 2. ``m`` is presumably even:
            each half-filter is reshaped to ``m // 2`` taps below.
        highpass: swaps the sample streams within each pair (1<->2, 3<->4).
        mode: NOTE(review): accepted but not used by this function.

    Returns:
        Tensor of shape ``(batch, ch, r, 2 * c)``.

    Raises:
        ValueError: if ``c`` is odd.
    """
    if X is None or X.shape == torch.Size([]):
        # X.device would raise AttributeError on None.
        device = None if X is None else X.device
        return torch.zeros(1, 1, 1, 1, device=device)

    m = ha.shape[2]
    m2 = m // 2
    hao = ha[:, :, 1::2]
    hae = ha[:, :, ::2]
    hbo = hb[:, :, 1::2]
    hbe = hb[:, :, ::2]

    batch, ch, r, c = X.shape
    if c % 2 != 0:
        raise ValueError('No. of cols in X must be a multiple of 2.\n' +
                         'X was {}'.format(X.shape))
    xe = symm_pad(c, m2)

    # The parity of m2 decides which half-filters are used and which column
    # samples each one sees. The order of `cols` pairs up with `filters`.
    if m2 % 2 == 0:
        filters = (hae, hbe, hao, hbo)
        cols = [xe[:-2:2], xe[1:-2:2], xe[2::2], xe[3::2]]
    else:
        filters = (hao, hbo, hae, hbe)
        cols = [xe[1:-1:2], xe[2:-1:2], xe[1:-1:2], xe[2:-1:2]]
    if highpass:
        cols = [cols[1], cols[0], cols[3], cols[2]]

    # Channel block k holds sample stream k and meets ch copies of filter k,
    # reshaped so the taps run along the column axis.
    h = torch.cat([f.repeat(ch, 1, 1, 1) for f in filters],
                  dim=0).reshape(4 * ch, 1, 1, m2)
    X = torch.cat([X[:, :, :, idx] for idx in cols], dim=1)
    Y = F.conv2d(X, h, groups=4 * ch)

    # Stack the 4 results, each (batch, ch, r, c // 2), on a new trailing
    # dim 4 and merge it into the columns. Output column 4*j + k is column j
    # of result k.
    Y = torch.stack(
        [Y[:, :ch], Y[:, ch:2 * ch], Y[:, 2 * ch:3 * ch], Y[:, 3 * ch:]],
        dim=4).view(batch, ch, r, c * 2)
    return Y