def reconstruct(self, coeff):
        ''' Reconstructs a batch of images from its pyramid coefficients.

        The high-pass residual ``coeff[0]`` and the spectrum rebuilt from the
        remaining levels (``coeff[1:]``, via ``self._reconstruct_levels``) are
        combined in the shifted frequency domain and transformed back.

        Args:
            coeff (list): Pyramid coefficients. ``coeff[0]`` is a tensor whose
                dims 1 and 2 define the frequency grid, ``coeff[1]`` is a
                sequence holding one entry per orientation band.

        Returns:
            torch.Tensor: Real part of the inverse DFT of the combined spectrum.

        Raises:
            ValueError: If ``len(coeff[1])`` differs from ``self.nbands``.
        '''

        # ValueError is a subclass of Exception, so callers that caught the
        # previous bare Exception keep working.
        if self.nbands != len(coeff[1]):
            raise ValueError("Unmatched number of orientations")

        # NOTE(review): shape[2] is bound to `height` and shape[1] to `width`,
        # exactly as build() does; confirm this matches prepare_grid's
        # expected argument order.
        height, width = coeff[0].shape[2], coeff[0].shape[1]
        log_rad, angle = math_utils.prepare_grid(height, width)

        # Radial lookup tables for the initial low-/high-pass split.
        Xrcos, Yrcos = math_utils.rcosFn(1, -0.5)
        Yrcos = np.sqrt(Yrcos)
        # abs() guards against a tiny negative argument caused by rounding.
        YIrcos = np.sqrt(np.abs(1 - Yrcos**2))

        lo0mask = pointOp(log_rad, YIrcos, Xrcos)
        hi0mask = pointOp(log_rad, Yrcos, Xrcos)

        # Expand to [1, rows, cols, 1] so the masks broadcast over the batch
        # dimension and the trailing real/imag dimension.
        lo0mask = torch.from_numpy(lo0mask).float()[None, :, :, None].to(self.device)
        hi0mask = torch.from_numpy(hi0mask).float()[None, :, :, None].to(self.device)

        # Start recursive reconstruction of everything below the residual.
        tempdft = self._reconstruct_levels(coeff[1:], log_rad, Xrcos, Yrcos, angle)

        # Spectrum of the high-pass residual.
        # NOTE(review): torch.rfft / torch.ifft are legacy APIs; newer PyTorch
        # releases expose these transforms through torch.fft instead.
        hidft = torch.rfft(coeff[0], signal_ndim=2, onesided=False)
        hidft = math_utils.batch_fftshift2d(hidft)

        outdft = tempdft * lo0mask + hidft * hi0mask

        reconstruction = math_utils.batch_ifftshift2d(outdft)
        reconstruction = torch.ifft(reconstruction, signal_ndim=2)
        reconstruction = torch.unbind(reconstruction, -1)[0]  # real part

        return reconstruction
    def build(self, im_batch):
        ''' Decomposes a batch of grayscale images into a steerable pyramid.

        The batch is transformed to the (shifted) frequency domain, split into
        a high-pass residual and a low-pass part, and the low-pass part is
        decomposed recursively by ``self._build_levels``.

        Args:
            im_batch (torch.Tensor): Batch of images of shape [N,1,H,W],
                dtype torch.float32, located on ``self.device``.

        Returns:
            list: ``[hi0_real, *levels]`` where ``hi0_real`` is the real part
                of the high-pass residual and ``levels`` is whatever
                ``self._build_levels`` returns.

        Raises:
            AssertionError: If the device, dtype or shape of ``im_batch`` is
                not as described above.
            RuntimeError: If the image is too small for ``self.height`` levels.
        '''

        assert im_batch.device == self.device, 'Devices invalid (pyr = {}, batch = {})'.format(self.device, im_batch.device)
        assert im_batch.dtype == torch.float32, 'Image batch must be torch.float32'
        assert im_batch.dim() == 4, 'Image batch must be of shape [N,C,H,W]'
        assert im_batch.shape[1] == 1, 'Second dimension must be 1 encoding grayscale image'

        im_batch = im_batch.squeeze(1)  # flatten channels dim -> [N,H,W]
        # NOTE(review): shape[2] is bound to `height` and shape[1] to `width`;
        # confirm this matches prepare_grid's expected argument order.
        height, width = im_batch.shape[2], im_batch.shape[1]

        # Check whether image size is sufficient for number of levels
        if self.height > int(np.floor(np.log2(min(width, height))) - 2):
            raise RuntimeError('Cannot build {} levels, image too small.'.format(self.height))

        # Prepare a grid
        log_rad, angle = math_utils.prepare_grid(height, width)

        # Radial transition function (a raised cosine in log-frequency):
        Xrcos, Yrcos = math_utils.rcosFn(1, -0.5)
        Yrcos = np.sqrt(Yrcos)

        # Complementary low-pass table. abs() keeps the argument non-negative
        # if rounding pushes Yrcos**2 marginally above 1, which would
        # otherwise yield NaN; this is the same form reconstruct() uses.
        YIrcos = np.sqrt(np.abs(1 - Yrcos**2))

        lo0mask = pointOp(log_rad, YIrcos, Xrcos)
        hi0mask = pointOp(log_rad, Yrcos, Xrcos)

        # Expand to [1, rows, cols, 1] so the masks broadcast over the batch
        # dimension and the trailing real/imag dimension.
        lo0mask = torch.from_numpy(lo0mask).float()[None, :, :, None].to(self.device)
        hi0mask = torch.from_numpy(hi0mask).float()[None, :, :, None].to(self.device)

        # Fourier transform (2D) and shifting
        # NOTE(review): torch.rfft / torch.ifft are legacy APIs; newer PyTorch
        # releases expose these transforms through torch.fft instead.
        batch_dft = torch.rfft(im_batch, signal_ndim=2, onesided=False)
        batch_dft = math_utils.batch_fftshift2d(batch_dft)

        # Low-pass
        lo0dft = batch_dft * lo0mask

        # Start recursively building the pyramids
        coeff = self._build_levels(lo0dft, log_rad, angle, Xrcos, Yrcos, self.height-1)

        # High-pass residual: filter, transform back, keep the real part and
        # put it at the front of the coefficient list.
        hi0dft = batch_dft * hi0mask
        hi0 = math_utils.batch_ifftshift2d(hi0dft)
        hi0 = torch.ifft(hi0, signal_ndim=2)
        hi0_real = torch.unbind(hi0, -1)[0]
        coeff.insert(0, hi0_real)
        return coeff
    def _build_levels(self, lodft, log_rad, angle, Xrcos, Yrcos, height):
        ''' Recursively builds the pyramid levels from a low-pass spectrum.

        Each recursion step extracts ``self.nbands`` oriented band-pass
        images from ``lodft``, then crops and low-pass filters the spectrum
        and recurses on the result.

        Args:
            lodft (torch.Tensor): Shifted spectrum, indexed here as
                [batch, rows, cols, real/imag].
            log_rad (np.ndarray): 2D grid passed to ``pointOp`` for the
                radial masks; cropped together with ``lodft``.
            angle (np.ndarray): 2D grid passed to ``pointOp`` for the
                angular masks; cropped together with ``lodft``.
            Xrcos, Yrcos: Abscissa / ordinate of the radial lookup table.
            height (int): Remaining recursion depth; ``<= 1`` ends recursion.

        Returns:
            list: One list of ``self.nbands`` real band tensors per recursion
                step, followed by the real low-pass residual as last element.
        '''

        if height <= 1:
            # Base case: transform the remaining low-pass spectrum back and
            # keep only its real part.
            lo0 = math_utils.batch_ifftshift2d(lodft)
            lo0 = torch.ifft(lo0, signal_ndim=2)
            lo0_real = torch.unbind(lo0, -1)[0]
            return [lo0_real]

        # Shift the radial lookup table for this level.
        Xrcos = Xrcos - np.log2(self.scale_factor)

        ####################################################################
        ####################### Orientation bandpass #######################
        ####################################################################

        himask = pointOp(log_rad, Yrcos, Xrcos)
        himask = torch.from_numpy(himask[None, :, :, None]).float().to(self.device)

        order = self.nbands - 1
        const = np.power(2, 2*order) * np.square(factorial(order)) / (self.nbands * factorial(2*order))
        #SF ADAPTATION: Modified from the below:
        #Ycosn = 2*np.sqrt(const) * np.power(np.cos(self.Xcosn), order) * (np.abs(self.alpha) < np.pi/2) # [n,]
        Ycosn = np.sqrt(const) * np.power(np.cos(self.Xcosn), order)

        # The complex factor does not change between bands; look it up once.
        fact_real = self.complex_fact_construct.real
        fact_imag = self.complex_fact_construct.imag

        # Loop through all orientation bands
        orientations = []
        for b in range(self.nbands):

            anglemask = pointOp(angle, Ycosn, self.Xcosn + np.pi*b/self.nbands)
            anglemask = anglemask[None, :, :, None]  # for broadcasting
            anglemask = torch.from_numpy(anglemask).float().to(self.device)

            # Bandpass filtering
            banddft = lodft * anglemask * himask

            # Now multiply with complex number
            # (x+yi)(u+vi) = (xu-yv) + (xv+yu)i
            band_re, band_im = torch.unbind(banddft, -1)
            banddft_real = fact_real*band_re - fact_imag*band_im
            banddft_imag = fact_real*band_im + fact_imag*band_re
            banddft = torch.stack((banddft_real, banddft_imag), -1)

            band = math_utils.batch_ifftshift2d(banddft)
            band = torch.ifft(band, signal_ndim=2)
            #SF ADAPTATION: For SF pyramid, just take real part of band.
            band = band[..., 0]  # Real part is first entry in last dimension
            orientations.append(band)

        ####################################################################
        ######################## Subsample lowpass #########################
        ####################################################################

        # Don't consider batch_size and imag/real dim
        dims = np.array(lodft.shape[1:3])

        # Both are integer arrays of length 2 (one entry per spatial axis)
        low_ind_start = (np.ceil((dims+0.5)/2) - np.ceil((np.ceil((dims-0.5)/2)+0.5)/2)).astype(int)
        low_ind_end = (low_ind_start + np.ceil((dims-0.5)/2)).astype(int)
        rows = slice(low_ind_start[0], low_ind_end[0])
        cols = slice(low_ind_start[1], low_ind_end[1])

        # Crop the grids to the same window as the spectrum
        log_rad = log_rad[rows, cols]
        angle = angle[rows, cols]

        # Actual subsampling: cropping the shifted spectrum
        lodft = lodft[:, rows, cols, :]

        # Low-pass filtering. abs() is applied before the square root so a
        # marginally negative argument (rounding) cannot produce NaN; taking
        # abs() after the square root would not remove the NaN.
        YIrcos = np.sqrt(np.abs(1 - Yrcos**2))
        lomask = pointOp(log_rad, YIrcos, Xrcos)
        lomask = torch.from_numpy(lomask[None, :, :, None]).float()
        lomask = lomask.to(self.device)

        # Multiplication in the frequency domain
        # (corresponds to a convolution in the spatial domain)
        lodft = lomask * lodft

        ####################################################################
        ####################### Recursion next level #######################
        ####################################################################

        coeff = self._build_levels(lodft, log_rad, angle, Xrcos, Yrcos, height-1)
        coeff.insert(0, orientations)

        return coeff
Ejemplo n.º 4
0
# NumPy reference path: undo the frequency shift, then apply the inverse 2D
# FFT. `fft_numpy` is defined before this excerpt.
ifft_numpy1 = np.fft.ifftshift(fft_numpy)
ifft_numpy = np.fft.ifft2(ifft_numpy1)

################################################################################
# Torch
# Repeat the same shift / inverse-transform round trip with torch and compare
# the outcome against the NumPy reference above.

device = torch.device('cpu')

# `im` is defined before this excerpt; it is indexed here as a 2D array.
im_torch = torch.from_numpy(im[None, :, :])  # add batch dim
im_torch = im_torch.to(device)

# fft = complex-to-complex, rfft = real-to-complex
# onesided=False requests the full spectrum rather than only one half of it.
fft_torch = torch.rfft(im_torch, signal_ndim=2, onesided=False)
fft_torch = fft_utils.batch_fftshift2d(fft_torch)

# Undo the shift and transform back.
ifft_torch = fft_utils.batch_ifftshift2d(fft_torch)
ifft_torch = torch.ifft(ifft_torch, signal_ndim=2, normalized=False)

# Convert the torch result, whose last dimension is treated as the
# (real, imag) pair, into a complex NumPy array.
# .numpy() is called without .cpu(); that is fine here because `device` is
# the CPU.
ifft_torch_to_numpy = ifft_torch.numpy()
ifft_torch_to_numpy = np.split(ifft_torch_to_numpy, 2,
                               -1)  # complex => real/imag
# np.split returned two arrays; squeezing drops their trailing length-1 axis
# and stacks them so index 0 is the real part and index 1 the imaginary part.
ifft_torch_to_numpy = np.squeeze(ifft_torch_to_numpy, -1)
ifft_torch_to_numpy = ifft_torch_to_numpy[0] + 1j * ifft_torch_to_numpy[1]
# NOTE(review): the batch dim added above is not removed on this path, so the
# comparison presumably relies on broadcasting against `ifft_numpy` - confirm.
# `tolerance` is defined before this excerpt.
all_close_ifft = np.allclose(ifft_numpy, ifft_torch_to_numpy, atol=tolerance)
print('ifft all close: ', all_close_ifft)

# Same real/imag -> complex conversion for the shifted forward spectrum; here
# .squeeze() removes every length-1 axis (including the batch dim) first.
fft_torch = fft_torch.cpu().numpy().squeeze()
fft_torch = np.split(fft_torch, 2, -1)  # complex => real/imag
fft_torch = np.squeeze(fft_torch, -1)
fft_torch = fft_torch[0] + 1j * fft_torch[1]