    def reconstruct(self, coeff):
        ''' Reconstructs a batch of images from their complex steerable pyramid
        coefficients, i.e. the inverse of build().

        Args:
            coeff (list): pyramid coefficients as returned by build()

        Returns:
            torch.Tensor: batch of reconstructed images of shape [N,H,W]
        '''
        if self.nbands != len(coeff[1]):
            raise Exception("Unmatched number of orientations")

        height, width = coeff[0].shape[1], coeff[0].shape[2]  # coeff[0] has shape [N,H,W]
        log_rad, angle = math_utils.prepare_grid(height, width)

        Xrcos, Yrcos = math_utils.rcosFn(1, -0.5)
        Yrcos  = np.sqrt(Yrcos)
        YIrcos = np.sqrt(np.abs(1 - Yrcos**2))

        lo0mask = pointOp(log_rad, YIrcos, Xrcos)
        hi0mask = pointOp(log_rad, Yrcos, Xrcos)

        # Note that we expand dims to support broadcasting later
        lo0mask = torch.from_numpy(lo0mask).float()[None,:,:,None].to(self.device)
        hi0mask = torch.from_numpy(hi0mask).float()[None,:,:,None].to(self.device)

        # Start recursive reconstruction
        tempdft = self._reconstruct_levels(coeff[1:], log_rad, Xrcos, Yrcos, angle)

        hidft = torch.rfft(coeff[0], signal_ndim=2, onesided=False)
        hidft = math_utils.batch_fftshift2d(hidft)

        outdft = tempdft * lo0mask + hidft * hi0mask

        reconstruction = math_utils.batch_ifftshift2d(outdft)
        reconstruction = torch.ifft(reconstruction, signal_ndim=2)
        reconstruction = torch.unbind(reconstruction, -1)[0]  # real

        return reconstruction

    def build(self, im_batch):
        ''' Decomposes a batch of images into a complex steerable pyramid. 
        The pyramid typically has ~4 levels and 4-8 orientations. 
        
        Args:
            im_batch (torch.Tensor): Batch of images of shape [N,C,H,W]
        
        Returns:
            pyramid: list of torch.Tensor objects storing the pyramid coefficients
        '''
        
        assert im_batch.device == self.device, 'Devices invalid (pyr = {}, batch = {})'.format(self.device, im_batch.device)
        assert im_batch.dtype == torch.float32, 'Image batch must be torch.float32'
        assert im_batch.dim() == 4, 'Image batch must be of shape [N,C,H,W]'
        assert im_batch.shape[1] == 1, 'Second dimension must be 1 encoding grayscale image'

        im_batch = im_batch.squeeze(1)  # flatten channels dim
        height, width = im_batch.shape[1], im_batch.shape[2]  # im_batch now has shape [N,H,W]
        
        # Check whether image size is sufficient for number of levels
        if self.height > int(np.floor(np.log2(min(width, height))) - 2):
            raise RuntimeError('Cannot build {} levels, image too small.'.format(self.height))
        
        # Prepare a grid
        log_rad, angle = math_utils.prepare_grid(height, width)

        # Radial transition function (a raised cosine in log-frequency):
        Xrcos, Yrcos = math_utils.rcosFn(1, -0.5)
        Yrcos = np.sqrt(Yrcos)

        YIrcos = np.sqrt(1 - Yrcos**2)

        lo0mask = pointOp(log_rad, YIrcos, Xrcos)
        hi0mask = pointOp(log_rad, Yrcos, Xrcos)

        # Note that we expand dims to support broadcasting later
        lo0mask = torch.from_numpy(lo0mask).float()[None,:,:,None].to(self.device)
        hi0mask = torch.from_numpy(hi0mask).float()[None,:,:,None].to(self.device)

        # Fourier transform (2D) and shifting
        batch_dft = torch.rfft(im_batch, signal_ndim=2, onesided=False)
        batch_dft = math_utils.batch_fftshift2d(batch_dft)

        # Low-pass
        lo0dft = batch_dft * lo0mask

        # Start recursively building the pyramids
        coeff = self._build_levels(lo0dft, log_rad, angle, Xrcos, Yrcos, self.height-1)

        # High-pass
        hi0dft = batch_dft * hi0mask
        hi0 = math_utils.batch_ifftshift2d(hi0dft)
        hi0 = torch.ifft(hi0, signal_ndim=2)
        hi0_real = torch.unbind(hi0, -1)[0]
        coeff.insert(0, hi0_real)
        return coeff

    def _reconstruct_levels(self, coeff, log_rad, Xrcos, Yrcos, angle):
        ''' Recursively reconstructs the frequency-domain image from the pyramid levels. '''

        if len(coeff) == 1:
            dft = torch.rfft(coeff[0], signal_ndim=2, onesided=False)
            dft = math_utils.batch_fftshift2d(dft)
            return dft

        Xrcos = Xrcos - np.log2(self.scale_factor)

        ####################################################################
        ####################### Orientation Residue ########################
        ####################################################################

        himask = pointOp(log_rad, Yrcos, Xrcos)
        himask = torch.from_numpy(himask[None,:,:,None]).float().to(self.device)

        lutsize = 1024
        Xcosn = np.pi * np.array(range(-(2*lutsize+1), (lutsize+2)))/lutsize
        order = self.nbands - 1
        const = np.power(2, 2*order) * np.square(factorial(order)) / (self.nbands * factorial(2*order))
        Ycosn = np.sqrt(const) * np.power(np.cos(Xcosn), order)

        orientdft = torch.zeros_like(coeff[0][0])
        for b in range(self.nbands):

            anglemask = pointOp(angle, Ycosn, Xcosn + np.pi * b/self.nbands)
            anglemask = anglemask[None,:,:,None]  # for broadcasting
            anglemask = torch.from_numpy(anglemask).float().to(self.device)

            banddft = torch.fft(coeff[0][b], signal_ndim=2)
            banddft = math_utils.batch_fftshift2d(banddft)

            banddft = banddft * anglemask * himask
            banddft = torch.unbind(banddft, -1)
            banddft_real = self.complex_fact_reconstruct.real*banddft[0] - self.complex_fact_reconstruct.imag*banddft[1]
            banddft_imag = self.complex_fact_reconstruct.real*banddft[1] + self.complex_fact_reconstruct.imag*banddft[0]
            banddft = torch.stack((banddft_real, banddft_imag), -1)

            orientdft = orientdft + banddft

        ####################################################################
        ########## The lowpass component is upsampled and convolved ###########
        ####################################################################
        
        dims = np.array(coeff[0][0].shape[1:3])
        
        lostart = (np.ceil((dims+0.5)/2) - np.ceil((np.ceil((dims-0.5)/2)+0.5)/2)).astype(np.int32)
        loend = lostart + np.ceil((dims-0.5)/2).astype(np.int32)

        nlog_rad = log_rad[lostart[0]:loend[0], lostart[1]:loend[1]]
        nangle = angle[lostart[0]:loend[0], lostart[1]:loend[1]]
        YIrcos = np.sqrt(np.abs(1 - Yrcos**2))

        # Filtering
        lomask = pointOp(nlog_rad, YIrcos, Xrcos)
        lomask = torch.from_numpy(lomask[None,:,:,None])
        lomask = lomask.float().to(self.device)

        ################################################################################

        # Recursive call for image reconstruction        
        nresdft = self._reconstruct_levels(coeff[1:], nlog_rad, Xrcos, Yrcos, nangle)

        resdft = torch.zeros_like(coeff[0][0]).to(self.device)
        resdft[:,lostart[0]:loend[0], lostart[1]:loend[1],:] = nresdft * lomask

        return resdft + orientdft
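
For reference, here is a minimal round-trip sketch of how the build() and reconstruct() methods above might be driven. The class name `SteerablePyramid` and its constructor signature are illustrative assumptions; the excerpt only shows the methods, not the class definition.

# Usage sketch (assumption: the enclosing class is called `SteerablePyramid`
# and takes (height, nbands, scale_factor, device); names are illustrative).
import torch

device = torch.device('cpu')
pyr = SteerablePyramid(height=5, nbands=4, scale_factor=2, device=device)

# Batch of 8 grayscale 256x256 images, shape [N,C,H,W] with C == 1
im_batch = torch.rand(8, 1, 256, 256, dtype=torch.float32, device=device)

coeff = pyr.build(im_batch)     # [highpass, oriented bands per level, ..., lowpass]
recon = pyr.reconstruct(coeff)  # reconstructed batch of shape [N,H,W]

print('max abs reconstruction error:', (recon - im_batch.squeeze(1)).abs().max().item())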
Example #4

################################################################################
# NumPy (fft_numpy is the fftshift-ed 2D FFT of the grayscale input image `im`)

fft_numpy_ang_viz = np.angle(fft_numpy)  # phase spectrum, e.g. for visualization

ifft_numpy1 = np.fft.ifftshift(fft_numpy)
ifft_numpy = np.fft.ifft2(ifft_numpy1)

################################################################################
# Torch

device = torch.device('cpu')

im_torch = torch.from_numpy(im[None, :, :])  # add batch dim
im_torch = im_torch.to(device)

# fft = complex-to-complex, rfft = real-to-complex
fft_torch = torch.rfft(im_torch, signal_ndim=2, onesided=False)
fft_torch = fft_utils.batch_fftshift2d(fft_torch)

ifft_torch = fft_utils.batch_ifftshift2d(fft_torch)
ifft_torch = torch.ifft(ifft_torch, signal_ndim=2, normalized=False)

ifft_torch_to_numpy = ifft_torch.numpy()
ifft_torch_to_numpy = np.split(ifft_torch_to_numpy, 2,
                               -1)  # complex => real/imag
ifft_torch_to_numpy = np.squeeze(ifft_torch_to_numpy, -1)
ifft_torch_to_numpy = ifft_torch_to_numpy[0] + 1j * ifft_torch_to_numpy[1]
all_close_ifft = np.allclose(ifft_numpy, ifft_torch_to_numpy, atol=tolerance)
print('ifft all close: ', all_close_ifft)

fft_torch = fft_torch.cpu().numpy().squeeze()
fft_torch = np.split(fft_torch, 2, -1)  # complex => real/imag
fft_torch = np.squeeze(fft_torch, -1)
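
The fragment ends before the forward-FFT comparison; one plausible continuation, mirroring the ifft check above, recombines the real and imaginary parts and compares against the NumPy spectrum:

# Continuation sketch: compare the torch spectrum against the NumPy one,
# mirroring the ifft comparison earlier in this example.
fft_torch_complex = fft_torch[0] + 1j * fft_torch[1]
all_close_fft = np.allclose(fft_numpy, fft_torch_complex, atol=tolerance)
print('fft all close: ', all_close_fft)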