# Example no. 1
# 0
def general_gaussian_blur_3D_periodic(u, fk, fi, fj):
    """Apply a separable periodic convolution with Fourier filters (fk, fi, fj) to ``u``.

    Typically used for an anisotropic Gaussian blur.  Each stage takes a
    full 1D ``torch.rfft`` along the current last axis, multiplies it with
    one filter through ``complex_mult``, inverts it and permutes the volume
    so the next axis becomes the last one.

    Args:
        u: real volume laid out as (Z, X, Y).
        fk, fi, fj: Fourier filters for the Z, X and Y axes respectively,
            in the form ``complex_mult`` expects.

    Returns:
        The filtered volume, again laid out as (Z, X, Y).
    """
    # (filter, permutation applied after filtering):
    #   (Z,X,Y) --fj--> (Z,Y,X) --fi--> (X,Y,Z) --fk--> (Z,X,Y)
    stages = (
        (fj, (0, 2, 1)),
        (fi, (2, 1, 0)),
        (fk, (2, 0, 1)),
    )
    volume = u
    for axis_filter, order in stages:
        spectrum = complex_mult(torch.rfft(volume, 1, onesided=False), axis_filter)
        volume = torch.irfft(spectrum, 1, onesided=False).permute(*order)
    return volume
def Decomposition(input, N):
    """Split ``input`` into two parts by masking a rectangle of its spectrum.

    The spectrum is ``rfft(input.unsqueeze(3), signal_ndim=3)``.  Entries
    with ``N // 2 <= x < width - N // 2`` and ``N // 2 <= y < height - N // 2``
    (dims 1 and 2 of the spectrum) form the second part; the remainder forms
    the first.  Both are inverted and the added axis 3 is squeezed away, so
    the two outputs sum to the input (up to round-off).

    NOTE(review): with DC at index 0 (no fftshift is visible here), the
    central rectangle holds the *highest* frequencies, so the High/Low
    naming may be swapped -- confirm.
    NOTE(review): if ``rfft``/``irfft`` are the legacy torch functions, the
    inverse runs without ``signal_sizes`` on a trailing axis of length 1;
    presumably they are project wrappers -- confirm.

    Args:
        input: real tensor; dims 1 and 2 index the masked frequencies.
        N: width of the border band (``N // 2`` per side) left unmasked.

    Returns:
        ``(high_part, low_part)``, named as in the original code.
    """
    high_freq = rfft(input.unsqueeze(3), signal_ndim=3)
    # Must start from zeros: the former torch.FloatTensor(shape) allocation
    # is uninitialised, so every entry outside the copied rectangle was
    # garbage that leaked into both outputs through the subtraction below.
    # zeros_like also follows the spectrum's dtype/device instead of
    # forcing CUDA.
    low_freq = torch.zeros_like(high_freq)

    width, height = high_freq.shape[1], high_freq.shape[2]
    half = N // 2
    # Vectorised form of the former per-element double loop over
    # range(half, width - half) x range(half, height - half); max() keeps
    # an empty range empty instead of letting a negative stop wrap around.
    rows = slice(half, max(width - half, half))
    cols = slice(half, max(height - half, half))
    low_freq[:, rows, cols] = high_freq[:, rows, cols]
    high_freq -= low_freq
    return (irfft(high_freq, signal_ndim=3).squeeze(3),
            irfft(low_freq, signal_ndim=3).squeeze(3))
# Example no. 3
# 0
    def forward(ctx, h1, s1, h2, s2, output_size, x, y, force_cpu_scatter_add=False):
        """Circularly convolve the count sketches of ``x`` and ``y`` via the FFT.

        Each input is count-sketched to ``output_size`` entries with
        ``CountSketchFn_forward``; the one-sided 1D spectra of the two sketches
        are multiplied with ``ComplexMultiply_forward`` and the product is
        inverted back to length ``output_size``.  Inputs and settings are
        stored on ``ctx`` for the backward pass.
        """
        ctx.save_for_backward(h1, s1, h2, s2, x, y)
        ctx.x_size = tuple(x.size())
        ctx.y_size = tuple(y.size())
        ctx.force_cpu_scatter_add = force_cpu_scatter_add
        ctx.output_size = output_size

        def sketch_spectrum(h, s, t):
            # Sketch one input and return the (real, imag) views of its
            # spectrum; the sketch itself is freed as soon as this returns.
            spec = torch.rfft(
                CountSketchFn_forward(h, s, output_size, t, force_cpu_scatter_add), 1)
            return spec.select(-1, 0), spec.select(-1, 1)

        re_fx, im_fx = sketch_spectrum(h1, s1, x)
        re_fy, im_fy = sketch_spectrum(h2, s2, y)

        # Pointwise complex product == circular convolution of the sketches.
        re_prod, im_prod = ComplexMultiply_forward(re_fx, im_fx, re_fy, im_fy)

        # Back to the real domain (the imaginary part should be zero).  The
        # sketch length must be given because one-sided spectra lose parity.
        packed = torch.stack((re_prod, im_prod), dim=re_prod.dim())
        return torch.irfft(packed, 1, signal_sizes=(output_size,))
# Example no. 4
# 0
 def inner():
     """Invert the scaled spectrum into an image batch.

     Reads ``scale``, ``spectrum_real_imag_t``, ``h``, ``w``, ``batch`` and
     ``channels`` from the enclosing scope.  The scaled spectrum is inverted
     with an orthonormal 2D ``torch.irfft`` to size ``(h, w)``, cropped to
     ``[:batch, :channels, :h, :w]`` and divided by 4.0.
     """
     # Magic constant from the Lucid library; increasing it seems to reduce
     # saturation.
     magic = 4.0
     pixels = torch.irfft(scale * spectrum_real_imag_t, 2, normalized=True,
                          signal_sizes=(h, w))
     return pixels[:batch, :channels, :h, :w] / magic
# Example no. 5
# 0
 def synth(self, y_orig, params):
     """Apply an exponentially decaying reverb to ``y_orig``.

     ``params[:, 0]`` drives the wet/dry mix and ``params[:, 1]`` the decay.
     The input is zero-padded by ``self.size`` samples and convolved with the
     synthesised impulse response as a product of 1D spectra (the impulse
     spectrum is detached); the padded-length result is returned.
     """
     wetdry = params[:, 0]
     decay = params[:, 1]
     # Pad the input sequence.  The spectrum must be taken of this padded
     # signal: the old code stored it back into ``y_orig`` and then called
     # torch.rfft(y, 1) before ``y`` was assigned.  Because ``y`` is assigned
     # later in this function, that always raised UnboundLocalError.
     y = nn.functional.pad(y_orig, (0, self.size), "constant", 0)
     Y_S = torch.rfft(y, 1)
     # Compute the current impulse response.
     # NOTE(review): ``identity`` is not defined in this method (other reverb
     # code in this file uses ``self.identity``), and ``wetdry`` has shape
     # (batch,), which may need ``unsqueeze(-1)`` to broadcast -- confirm.
     idx = torch.sigmoid(wetdry) * identity
     # NOTE(review): as before, this scales the padded *input* signal.  Its
     # length (input + self.size) broadcasts against ``dcy`` (length
     # self.size) only in degenerate cases; an impulse buffer such as
     # ``self.impulse`` was probably intended -- confirm.
     imp = torch.sigmoid(1 - wetdry) * y
     dcy = torch.exp(-(torch.exp(decay) + 2) *
                     torch.linspace(0, 1, self.size).to(y_orig.device))
     final_impulse = idx + imp * dcy
     # Pad the impulse response to the padded signal length.
     impulse = nn.functional.pad(final_impulse, (0, self.size), "constant", 0)
     if y.shape[-1] > self.size:
         impulse = nn.functional.pad(impulse,
                                     (0, y.shape[-1] - impulse.shape[-1]),
                                     "constant", 0)
     IR_S = torch.rfft(impulse.detach(), 1).expand_as(Y_S)
     # Apply the reverb: pointwise complex product of the two spectra.
     Y_S_CONV = torch.zeros_like(IR_S)
     Y_S_CONV[:, :, 0] = Y_S[:, :, 0] * IR_S[:, :, 0] - Y_S[:, :, 1] * IR_S[:, :, 1]
     Y_S_CONV[:, :, 1] = Y_S[:, :, 0] * IR_S[:, :, 1] + Y_S[:, :, 1] * IR_S[:, :, 0]
     # Invert the reverberated signal.
     return torch.irfft(Y_S_CONV, 1, signal_sizes=(y.shape[-1], ))
# Example no. 6
# 0
    def interaction_function(
        self,
        f: torch.FloatTensor,
        g: torch.FloatTensor,
    ) -> torch.FloatTensor:
        """Circular correlation of two embedding tensors along the last axis.

        Computes ``out[..., k] = sum_i f[..., i] * g[..., (i + k) % d]`` as
        ``irfft(conj(rfft(f)) * rfft(g))``.  The inputs only have to be
        broadcastable against each other.

        :param f: shape: (batch_size, num_entities, d)
            First embeddings (conjugated in the frequency domain).
        :param g: shape: (batch_size, num_entities, d)
            Second embeddings.
        :return: shape: (batch_size, num_entities, d)
            The circular correlation of ``f`` and ``g``.
        """
        import torch.fft  # explicit import is required on PyTorch 1.7

        d = f.shape[-1]
        a_fft = torch.fft.rfft(f, dim=-1)
        b_fft = torch.fft.rfft(g, dim=-1)
        # Conjugate, then a true complex (Hadamard) product.  The former code
        # negated frequency bin 1 (``a_fft[:, :, 1]``) instead of the
        # imaginary part, and multiplied real/imag pairs elementwise, which is
        # not complex multiplication.
        p_fft = torch.conj(a_fft) * b_fft
        # ``n=d`` restores the original length (also for odd d).
        return torch.fft.irfft(p_fft, n=d, dim=-1)
# Example no. 7
# 0
def InverseLPFinFreq(input, sigma):
    """Boost high frequencies of ``input`` with an inverted Gaussian gain.

    A 2D Gaussian of standard deviation ``sigma`` (in frequency bins),
    centred on DC, is peak-normalised to 1 and inverted.  This gives a gain of
    1 at DC that grows with frequency.  Every channel's 2D spectrum is
    multiplied by that gain and transformed back.

    Args:
        input: real tensor of shape (B, C, H, W); H and W may differ.
        sigma: Gaussian standard deviation, in frequency bins.

    Returns:
        Real tensor of shape (B, C, H, W) on ``input``'s device.

    Note:
        For small ``sigma`` the Gaussian underflows far from DC and the gain
        can become ``inf``.
    """
    import torch.fft  # explicit import is required on PyTorch 1.7

    height, width = input.size(2), input.size(3)
    # Frequency of each bin after an fftshift: bin i holds i - n // 2.
    # The old code built one axis for both dims (wrong for non-square
    # inputs) and started it at -(n // 2) + 1, putting the peak one bin off DC.
    fy = np.arange(height) - height // 2
    fx = np.arange(width) - width // 2
    xx, yy = np.meshgrid(fx, fy)  # both (H, W)
    gauss = np.exp(-(xx ** 2 + yy ** 2) / (2. * sigma ** 2))
    # Peak-normalise and invert (the former sum-normalisation cancelled out).
    gain = 1. / (gauss / gauss.max())
    gain = torch.from_numpy(gain).float().to(input.device)
    # Move DC back to index (0, 0) to match the unshifted spectrum.  This
    # replaces the hard-coded 128-bin swaps, which assumed a 256x256 input.
    gain = torch.roll(gain, shifts=(-(height // 2), -(width // 2)), dims=(0, 1))

    spectrum = torch.fft.fft2(input)
    # The gain is real, positive and symmetric in frequency.  Scaling the
    # complex spectrum therefore scales the magnitude and keeps the phase,
    # and the result stays Hermitian, so the inverse is real.
    return torch.fft.ifft2(spectrum * gain).real
# Example no. 8
# 0
 def _register_translation_gpu(self,frame,upsample_factor=20):
     """Estimate the shift of ``frame`` relative to the stored alignment frame.

     Whole-pixel step: the full spectrum of ``frame`` (``th.rfft`` over
     ``self.frame_dims`` dims) is multiplied with
     ``self.alignment_frame_fourier`` and inverse-transformed.  The arg-max
     of that cross-correlation gives the integer shift, wrapped to negative
     values past ``self.frame_midpoints``.  Sub-pixel step: the same product
     is refined on the CPU with ``self._upsampled_dft_cpu`` over a
     ``ceil(1.5 * upsample_factor)``-wide region around the peak.

     Returns:
         numpy float64 array with one shift per cross-correlation dimension,
         on a ``1 / upsample_factor`` pixel grid.
     """
     # Whole-pixel shift - compute the cross-correlation by an IFFT.
     img_fourier = th.rfft(frame, self.frame_dims, onesided=False)
     ref = self.alignment_frame_fourier
     # Allocate on the spectrum's device/dtype.  th.zeros(size) always made a
     # CPU tensor, so on a GPU this path copied the product to the host and
     # ran the inverse FFT and arg-max on the CPU.
     image_product = th.zeros_like(img_fourier)
     # Pointwise complex product, without a conjugate.
     # NOTE(review): cross-correlation needs conj() on one operand; this is
     # correct only if alignment_frame_fourier is stored pre-conjugated.
     image_product[:, :, :, 0] = (img_fourier[:, :, :, 0] * ref[:, :, :, 0]
                                  - img_fourier[:, :, :, 1] * ref[:, :, :, 1])
     image_product[:, :, :, 1] = (img_fourier[:, :, :, 0] * ref[:, :, :, 1]
                                  + img_fourier[:, :, :, 1] * ref[:, :, :, 0])
     cross_correlation = th.irfft(image_product, self.frame_dims, onesided=False,
                                  signal_sizes=frame.shape)
     # Locate maximum
     maxima = self._to_cpu(th.argmax(cross_correlation))
     maxima = np.unravel_index(maxima, cross_correlation.size(), order='C')
     maxima = np.asarray(maxima)
     shifts = np.array(maxima, dtype=np.float64)
     # The correlation is circular: peaks past the midpoint are negative
     # shifts, so wrap them by the axis length.
     wrap = shifts > self.frame_midpoints
     shifts[wrap] -= np.array(cross_correlation.shape)[wrap]
     # Round to the upsampled grid (author's note: numpy rounding behaves
     # like torch's here).
     shifts = np.round(shifts * upsample_factor) / upsample_factor
     upsampled_region_size = np.ceil(upsample_factor * 1.5)
     dftshift = np.fix(upsampled_region_size / 2.0)
     upsample_factor = np.array(upsample_factor, dtype=np.float64)
     normalization = (img_fourier.numel() * upsample_factor ** 2)
     sample_region_offset = dftshift - shifts * upsample_factor
     image_product = self._to_cpu(image_product)
     imag_part = 1j * image_product[:, :, :, 1]
     img_product_cpu = image_product[:, :, :, 0] + imag_part
     cross_correlation = self._upsampled_dft_cpu(img_product_cpu.conj(), upsampled_region_size,
                                                 upsample_factor, sample_region_offset).conj()

     cross_correlation /= normalization
     # Locate maximum and map back to original pixel grid
     maxima = np.array(np.unravel_index(np.argmax(np.abs(cross_correlation)),
                                        cross_correlation.shape, order='C'),
                       dtype=np.float64)
     maxima -= dftshift
     shifts = shifts + maxima / upsample_factor
     return shifts
# Example no. 9
# 0
def compute_amplitude_gradients_for_X(model, X):
    """Gradient of each output filter's mean w.r.t. the FFT amplitudes of ``X``.

    ``X`` is transformed with ``np.fft.rfft`` along axis 2 and split into
    amplitudes and phases.  Both become grad-requiring tensors on the model's
    device.  The signal is re-synthesised with ``torch.irfft`` and fed to
    ``model``.  For every filter ``i`` the mean of ``outs[:, i]`` is
    back-propagated and the amplitude gradient is recorded.

    Returns:
        float32 array of shape ``(n_filters,) + np.fft.rfft(X, axis=2).shape``.
    """
    device = next(model.parameters()).device
    spectrum = np.fft.rfft(X, axis=2)
    amps_th = to_tensor(np.abs(spectrum).astype(np.float32),
                        device=device).requires_grad_(True)
    phases_th = to_tensor(np.angle(spectrum).astype(np.float32),
                          device=device).requires_grad_(True)

    # Polar form back to the legacy (..., 2) real/imag layout.
    unit_phasors = torch.stack((torch.cos(phases_th), torch.sin(phases_th)), dim=-1)
    fft_coefs = (amps_th.unsqueeze(-1) * unit_phasors).squeeze(3)
    iffted = torch.irfft(fft_coefs, signal_ndim=1, signal_sizes=(X.shape[2], ))

    outs = model(iffted)
    n_filters = outs.shape[1]
    amp_grads_per_filter = np.full((n_filters, ) + spectrum.shape, np.nan,
                                   dtype=np.float32)
    for i_filter in range(n_filters):
        # retain_graph: the same forward graph is reused for every filter.
        torch.mean(outs[:, i_filter]).backward(retain_graph=True)
        amp_grads_per_filter[i_filter] = to_numpy(amps_th.grad.clone())
        amps_th.grad.zero_()
    assert not np.any(np.isnan(amp_grads_per_filter))
    return amp_grads_per_filter
# Example no. 10
# 0
    def forward(self, x):
        """Spectral 3D convolution over the lowest Fourier modes.

        ``x`` is transformed with an orthonormal one-sided 3D FFT over its
        last three dims.  The four corner blocks of retained modes (low/high
        along dims -3 and -2, the first ``n3`` along the halved last dim) are
        mixed with ``weights1``..``weights4`` via ``compl_mul3d``.  All other
        modes are zeroed before inverting back to ``x``'s spatial size.
        """
        n1, n2, n3 = self.n1, self.n2, self.n3
        batchsize = x.shape[0]

        x_ft = torch.rfft(x, 3, normalized=True, onesided=True)

        # (rows, cols) of each corner block and its weights; later entries
        # overwrite earlier ones where blocks overlap, as before.
        corners = (
            ((slice(None, n1), slice(None, n2)), self.weights1),
            ((slice(-n1, None), slice(None, n2)), self.weights2),
            ((slice(None, n1), slice(-n2, None)), self.weights3),
            ((slice(-n1, None), slice(-n2, None)), self.weights4),
        )
        products = [compl_mul3d(x_ft[:, :, rows, cols, :n3], weights)
                    for (rows, cols), weights in corners]

        # Size the output by the channel count compl_mul3d actually returns.
        # The old buffer used self.in_channels, which breaks whenever the
        # weights map to a different number of (output) channels.
        out_ft = torch.zeros(batchsize,
                             products[0].size(1),
                             x.size(-3),
                             x.size(-2),
                             x.size(-1) // 2 + 1,
                             2,
                             dtype=products[0].dtype,
                             device=x.device)
        for ((rows, cols), _), prod in zip(corners, products):
            out_ft[:, :, rows, cols, :n3] = prod

        return torch.irfft(out_ft,
                           3,
                           normalized=True,
                           onesided=True,
                           signal_sizes=(x.size(-3), x.size(-2), x.size(-1)))
# Example no. 11
# 0
    def get_feature_weights(self, patch):
        """Cut ``patch`` into an n x n grid of k x k tiles and reweight each spectrum.

        Args:
            patch: 4D tensor (batch, c, n*k, n*k).

        Returns:
            Tensor of shape (batch * n * n, c, k, k).  Each tile (passed
            through ``tile2d(., 1)`` first when ``self.tiled``) goes through
            an orthonormal one-sided 2D rfft, is multiplied by
            ``self.freq_coeff`` and is transformed back.
        """
        n, k = self.n, self.k
        assert patch.ndimension() == 4
        batch_size, c = patch.shape[0], patch.shape[1]
        assert patch.shape[2] == n * self.k and patch.shape[2] == patch.shape[3]
        # (B, c, n, k, n, k) -> (B, tile_row, tile_col, c, k, k)
        tiles = patch.view(batch_size, c, n, k, n, k).permute(0, 2, 4, 1, 3, 5).contiguous()
        tiles = tiles.view(batch_size, n * n, c, k, k)
        if self.tiled:
            tiles = tile2d(tiles, 1)
        tiles = tiles.view(batch_size * n * n, c, k, k)
        spectrum = torch.rfft(tiles, signal_ndim=2, normalized=True,
                              onesided=True) * self.freq_coeff
        return torch.irfft(spectrum, signal_ndim=2, normalized=True,
                           onesided=True, signal_sizes=tiles.shape[-2:])
def root_filter(img, num_filters=2):
    """Progressively flatten an image's spectrum by repeated root-magnitude division.

    ``img`` (H, W, C) is moved to a (1, C, H, W) float tensor on the CPU,
    passed through ``MinMaxNormalize`` and transformed with a full 2D FFT.
    The spectrum magnitude is raised to ``1 / num_filters``.  On each of
    ``num_filters`` iterations the spectrum is divided by that root once
    more (NaNs from 0/0 become 0), inverted, passed through
    ``MinMaxNormalize`` and stored in (H, W, C) layout.

    Returns:
        List holding ``img`` followed by the ``num_filters`` filtered images.
    """
    assert (num_filters > 1)
    chw = torch.from_numpy(img.transpose([2, 0, 1])).float().to('cpu')
    chw = MinMaxNormalize(chw)
    results = [img]
    chw.unsqueeze_(0)
    spectrum = torch.rfft(chw, signal_ndim=2, onesided=False, normalized=False)
    magnitude = (spectrum[..., 0]**2 + spectrum[..., 1]**2)**0.5
    root = magnitude**(1.0 / num_filters)
    for _ in range(num_filters):
        # The division accumulates: iteration j leaves |F|^(1 - j/num_filters).
        spectrum[..., 0] = spectrum[..., 0] / root
        spectrum[..., 1] = spectrum[..., 1] / root
        spectrum[torch.isnan(spectrum)] = 0
        restored = torch.irfft(spectrum, signal_ndim=2, onesided=False,
                               normalized=False)
        restored = MinMaxNormalize(restored).cpu().numpy()[0].transpose([1, 2, 0])
        results.append(restored)
    return results
# Example no. 13
# 0
 def forward(self, z):
     """Apply the module's exponential-decay reverb to the signal in ``z``.

     ``z`` is a ``(signal, conditions)`` pair; ``conditions`` is unused.
     The signal is zero-padded by ``self.size`` samples and multiplied in the
     1D Fourier domain with the (detached) spectrum of the impulse
     ``sigmoid(wetdry) * identity + sigmoid(1 - wetdry) * impulse * envelope``.
     The result is cropped back to the input length along dim 1.
     """
     signal, _conditions = z
     n = self.size
     padded = nn.functional.pad(signal, (0, n), "constant", 0)
     Y_S = torch.rfft(padded, 1)
     # Exponential decay envelope over the impulse length.
     ramp = torch.linspace(0, 1, n).to(signal.device)
     envelope = torch.exp(-(torch.exp(self.decay) + 2) * ramp)
     response = (torch.sigmoid(self.wetdry) * self.identity
                 + torch.sigmoid(1 - self.wetdry) * self.impulse * envelope)
     # Zero-pad the impulse to the padded signal length.
     response = nn.functional.pad(response, (0, n), "constant", 0)
     if padded.shape[-1] > n:
         response = nn.functional.pad(response,
                                      (0, padded.shape[-1] - response.shape[-1]),
                                      "constant", 0)
     IR_S = torch.rfft(response.detach(), 1).expand_as(Y_S)
     # Pointwise complex product == circular convolution.
     yr, yi = Y_S[:, :, 0], Y_S[:, :, 1]
     hr, hi = IR_S[:, :, 0], IR_S[:, :, 1]
     Y_S_CONV = torch.zeros_like(IR_S)
     Y_S_CONV[:, :, 0] = yr * hr - yi * hi
     Y_S_CONV[:, :, 1] = yr * hi + yi * hr
     wet = torch.irfft(Y_S_CONV, 1, signal_sizes=(padded.shape[-1], ))
     return wet[:, :signal.size(-1)]
# Example no. 14
# 0
    def _exp_decay_reverb(self, y, wetdry, decay):
        """Convolve each row of ``y`` with its own exponential-decay impulse.

        For batch item ``b`` the impulse is
        ``wetdry[b] + (1 - wetdry[b]) * self.impulse * exp(-(exp(decay[b]) + 2) * t)``
        with ``t = linspace(0, 1, self.size)``.  The convolution is a product
        of 1D spectra inverted to ``y``'s last-dim length.  The impulse
        spectrum is not detached, so gradients reach ``wetdry`` and ``decay``.
        """
        batch = len(wetdry)
        ramp = torch.linspace(0, 1, self.size).to(y.device)
        envelope = torch.exp(-(torch.exp(decay) + 2).unsqueeze(1) * ramp)
        direct = wetdry.unsqueeze(1).expand(-1, y.size(1))
        tail = (1 - wetdry).unsqueeze(1) * self.impulse.expand(batch, -1)
        response = direct + tail * envelope

        Y_S = torch.rfft(y, 1)
        IR_S = torch.rfft(response, 1).expand_as(Y_S)
        # Pointwise complex product == circular convolution.
        yr, yi = Y_S[:, :, 0], Y_S[:, :, 1]
        hr, hi = IR_S[:, :, 0], IR_S[:, :, 1]
        Y_S_CONV = torch.zeros_like(IR_S)
        Y_S_CONV[:, :, 0] = yr * hr - yi * hi
        Y_S_CONV[:, :, 1] = yr * hi + yi * hr
        return torch.irfft(Y_S_CONV, 1, signal_sizes=(y.shape[-1], ))
# Example no. 15
# 0
def test():
    """Compare four implementations of the 2D cross-correlation of an array with itself.

    The inputs are ``xcorr2_torch`` (two-argument and one-argument forms),
    ``xcorr2_torch_CPU`` and the inverse FFT of ``FourierMod2``, evaluated
    on a (2, 1, 5, 5) array whose first item has every row equal to
    ``[1, 2, 3, 4, 5]``.  The one-argument and FourierMod2 results have
    zero lag at index 0, so they are fftshift-ed before comparing.
    Requires CUDA.

    Returns:
        ``(diff_12, diff_13, diff_14)``: L2 norms of the differences between
        the first result and each of the others.  They are also printed;
        previously they were computed and discarded, so the check reported
        nothing.
    """
    a = np.zeros((2, 1, 5, 5))
    a[0, 0, :, :] = np.tile(np.arange(1., 6.), (5, 1))

    a_torch = torch.tensor(a, dtype=torch.float32, device='cuda')

    astara_1 = xcorr2_torch(a_torch, a_torch).to(device='cuda')
    astara_2 = xcorr2_torch_CPU(a_torch, a_torch).to(device='cuda')
    astara_3 = xcorr2_torch(a_torch).to(device='cuda')

    # Apply an fftshift to astara_3 to make it consistent with the other definitions
    astara_3 = np.fft.fftshift(np.array(astara_3.cpu()), (2, 3))
    astara_3 = torch.tensor(astara_3, dtype=torch.float32, device='cuda')

    n_batch, n_c, ha, wa = a.shape
    absFa2 = torch.zeros((n_batch, n_c, 2 * ha - 1, 2 * wa - 1, 2))
    absFa2[:, :, :, :, 0] = FourierMod2(a_torch)
    astara_4 = torch.irfft(absFa2,
                           signal_ndim=2,
                           onesided=False,
                           normalized=False)  # 0-lag is at 0

    # Apply an fftshift to astara_4 to make it consistent with the other definitions
    astara_4 = np.fft.fftshift(np.array(astara_4), (2, 3))
    astara_4 = torch.tensor(astara_4, dtype=torch.float32, device='cuda')

    diff_12 = torch.norm(astara_1 - astara_2, 2)
    diff_13 = torch.norm(astara_1 - astara_3, 2)
    diff_14 = torch.norm(astara_1 - astara_4, 2)
    print("diff_12={:.3e} diff_13={:.3e} diff_14={:.3e}".format(
        float(diff_12), float(diff_13), float(diff_14)))
    return diff_12, diff_13, diff_14
# Example no. 16
# 0
    def forward(self, x):
        """Remap the 3D amplitude spectrum of ``x`` with a learned pointwise map.

        ``W1, W2, B1, B2`` are repeated over the batch, moved to ``x``'s device
        and passed through ReLU.  With ``A = |rfft3(x)|`` (orthonormal,
        one-sided), the new amplitude is ``w2 * activation(w1 * g(A) + b1) + b2``,
        where ``g(A) = log(1 + A)`` if ``self.log`` and ``A`` otherwise.  The
        spectrum is rescaled to that amplitude (phase kept) and inverted to
        ``x.shape[1:]``.
        """
        relu = torch.nn.ReLU()
        w1, w2, b1, b2 = (relu(p.repeat(x.shape[0], 1, 1, 1).to(x.device))
                          for p in (self.W1, self.W2, self.B1, self.B2))

        rft_x = torch.rfft(x, signal_ndim=3, normalized=True, onesided=True)
        re, im = rft_x[..., 0], rft_x[..., 1]
        init_spectrum = torch.sqrt(torch.pow(re, 2) + torch.pow(im, 2))

        feature = torch.log(1 + init_spectrum) if self.log else init_spectrum
        spectrum = w2 * self.activation(w1 * feature + b1) + b2

        # Keep the phase and swap the amplitude; the epsilon avoids 0/0.
        denom = init_spectrum + 1e-16
        rescaled = torch.stack([re * spectrum / denom, im * spectrum / denom], dim=-1)
        return torch.irfft(rescaled,
                           signal_ndim=3,
                           normalized=True,
                           onesided=True,
                           signal_sizes=x.shape[1:])
# Example no. 17
# 0
 def mel2mcc(mel):
     """Convert a (B, D, T) mel-type spectrum into D cepstral coefficients per frame.

     Each frame's D values are treated as a real one-sided spectrum (zero
     imaginary part) and inverted with ``torch.irfft`` to length
     ``2 * (D - 1)``.  The first D coefficients are kept and the 0th is halved.

     Returns:
         Tensor of shape (B, D, T).
     """
     n_bins = mel.shape[1]
     frames = mel.transpose(1, 2)                                      # (B, T, D)
     spectrum = torch.stack([frames, torch.zeros_like(frames)], dim=-1)  # (B, T, D, 2)
     cepstrum = torch.irfft(spectrum, signal_ndim=1,
                            signal_sizes=(2 * (n_bins - 1),))
     mcc = cepstrum.transpose(1, 2)[:, :n_bins]
     mcc[:, 0] /= 2.
     return mcc
# Example no. 18
# 0
    def forward(self, x):
        """Fourier layer: keep the lowest modes, mix them with learned weights, invert.

        ``x`` is transformed with an orthonormal one-sided 3D FFT.  The four
        corner blocks of retained modes (first/last ``modes1`` along dim -3,
        first/last ``modes2`` along dim -2, first ``modes3`` along the halved
        last dim) are mixed with ``weights1``..``weights4`` via
        ``compl_mul3d`` into a zeroed ``self.out_channels`` spectrum, which is
        inverted back to ``x``'s spatial size.
        """
        m1, m2, m3 = self.modes1, self.modes2, self.modes3
        x_ft = torch.rfft(x, 3, normalized=True, onesided=True)

        out_ft = torch.zeros(x.shape[0],
                             self.out_channels,
                             x.size(-3),
                             x.size(-2),
                             x.size(-1) // 2 + 1,
                             2,
                             device=x.device)
        low1, high1 = slice(None, m1), slice(-m1, None)
        low2, high2 = slice(None, m2), slice(-m2, None)
        for rows, cols, weights in ((low1, low2, self.weights1),
                                    (high1, low2, self.weights2),
                                    (low1, high2, self.weights3),
                                    (high1, high2, self.weights4)):
            out_ft[:, :, rows, cols, :m3] = compl_mul3d(x_ft[:, :, rows, cols, :m3], weights)

        # Return to physical space.
        return torch.irfft(out_ft,
                           3,
                           normalized=True,
                           onesided=True,
                           signal_sizes=(x.size(-3), x.size(-2), x.size(-1)))
# Example no. 19
# 0
 def ifft_image(self, input):
     """Scale a legacy (..., 2) spectrum by ``self.scale`` and invert it to an image.

     The inverse is an orthonormal one-sided 2D ``torch.irfft`` to
     ``(self.h, self.w)``; the result is divided by 4.
     """
     target_size = (self.h, self.w)
     spectrum = input * self.scale
     image = torch.irfft(spectrum, 2, normalized=True, signal_sizes=target_size)
     return image / 4
# Example no. 20
# 0
def diagImpConvFFT(x,
                   B,
                   h,
                   mode=None):  # convolution using FFT of (I + h B' B)^-1
    """Apply ``(I + h B^T B)^{-1}`` to ``x`` using 2D FFTs (periodic boundary).

    Each stencil ``B[:, 0]`` (centre at ``((kh - 1) // 2, (kw - 1) // 2)``) is
    wrapped into an image of ``x``'s spatial size with its centre at (0, 0),
    which makes it diagonal in Fourier space.  Every channel's spectrum is
    multiplied by ``1 / (1 + h * |B_hat|^2)``, or by
    ``1 / (1 + h * |Re(B_hat)|)`` when ``mode == 'laplacian'``, and is then
    transformed back.

    Args:
        x: real tensor (N, C, H, W).
        B: stencils (C or 1, _, kh, kw); only ``B[:, 0]`` is used.
        h: step / weight of the ``B^T B`` term.
        mode: ``'laplacian'`` to use ``|Re(B_hat)|`` instead of ``|B_hat|^2``.

    Returns:
        Real tensor (N, C, H, W).
    """
    import torch.fft  # explicit import is required on PyTorch 1.7

    n = x.shape
    m = B.shape
    mid1 = (m[2] - 1) // 2
    mid2 = (m[3] - 1) // 2
    # Place each stencil top-left, then roll its centre to (0, 0).  This gives
    # the same wrap-around as the former four quadrant assignments, but it
    # also works for 1-wide stencils, where ``Bp[:, -0:]`` selected the whole
    # axis and the assignment failed.
    Bp = torch.zeros(m[0], n[2], n[3], device=B.device, dtype=B.dtype)
    Bp[:, :m[2], :m[3]] = B[:, 0]
    Bp = torch.roll(Bp, shifts=(-mid1, -mid2), dims=(1, 2))

    Bh = torch.fft.fft2(Bp)
    if mode == 'laplacian':  # don't need semi pos def if L is laplacian
        t = 1.0 / (h * torch.abs(Bh.real) + 1.0)
    else:
        t = 1.0 / (h * (Bh.real**2 + Bh.imag**2) + 1.0)
    # t is real and symmetric in frequency, so the filtered spectrum stays
    # Hermitian and its inverse is real.  Broadcasting over the batch
    # replaces the former per-sample loop.
    return torch.fft.ifft2(torch.fft.fft2(x) * t).real
# Example no. 21
# 0
def FFTConv(imgs, filt, plot=False):
    """Linear (zero-padded) 2D convolution of ``imgs`` with a centred filter via FFT.

    Args:
        imgs: real tensor (B, C, H, W).
        filt: real tensor (1, 1, h, w) with odd ``h`` and ``w``; its centre is
            element ``(h // 2, w // 2)`` (1-based ``ceil(size / 2)``).
        plot: if true, save figures of the filter spectrum and of the log-abs
            of the image spectrum's real/imag parts via ``save_fig_double``.

    Returns:
        ``(filtered, imgsFourier, filtFourier, filt)``:
        - ``filtered`` has shape (B, C, H, W).
        - The two spectra are full 2D spectra of the zero-padded inputs, in
          the legacy (..., 2) real/imag layout.
        - ``filt`` is the padded, centre-rolled filter.

    Raises:
        TypeError: if a filter dimension is even.
    """
    filtSize = np.array(filt.size()[2:])
    if np.any(filtSize % 2 == 0):
        raise TypeError("filter size {} should be odd".format(filtSize))

    imSize = np.array(imgs.size()[2:])

    # Pad to the full linear-convolution size so the circular FFT product
    # does not wrap around.
    fftSize = imSize + filtSize - 1

    # F.pad pairs run from the LAST dim backwards: (W_left, W_right, H_left,
    # H_right).  The old tuples put the height deficit on the width axis and
    # vice versa, which was only correct for square images and filters.
    imgs = F.pad(imgs, (0, int(fftSize[1] - imSize[1]), 0, int(fftSize[0] - imSize[0])))
    filt = F.pad(filt,
                 (0, int(fftSize[1] - filtSize[1]), 0, int(fftSize[0] - filtSize[0])))

    # shift the center to the upper left corner
    filt = roll_n(filt, 2, filtSize[0] // 2)
    filt = roll_n(filt, 3, filtSize[1] // 2)

    imgsFourier = torch.rfft(
        imgs, 2, onesided=False)  # rfft doesn't require complex input
    filtFourier = torch.rfft(filt, 2, onesided=False)

    # Extract the real and imaginary parts
    imgR, imgIm = torch.unbind(imgsFourier, -1)
    filtR, filtIm = torch.unbind(filtFourier, -1)

    if plot:
        save_fig_double(filtR.data.cpu(),
                        filtIm.data.cpu(),
                        './',
                        'CurrentCTF-Fourier',
                        iteration=None,
                        Title1='Real',
                        Title2='Imag')
        save_fig_double((imgR + 1e-8).abs().log().data.cpu(),
                        (imgIm + 1e-8).abs().log().data.cpu(),
                        './',
                        'CurrentProj-Fourier',
                        iteration=None,
                        Title1='Real',
                        Title2='Imag')

    # Element-wise complex multiplication
    imgFilteredR = imgR * filtR - imgIm * filtIm
    imgFilteredIm = imgIm * filtR + imgR * filtIm

    imgFiltered = torch.stack((imgFilteredR, imgFilteredIm), -1)
    imgFiltered = torch.irfft(imgFiltered, 2, onesided=False)

    return imgFiltered[:, :, :imSize[0], :imSize[1]], imgsFourier, filtFourier, filt
def mse_bp(hx, y, scale):
    """Frequency-weighted squared error between ``hx`` and ``y``.

    The residual ``hx - y`` is taken to the Fourier domain (orthonormal full
    2D ``torch.rfft``) and each frequency is weighted by
    ``1 / sqrt(mul_factor * M + eps_ignored * sigma**2 + eps)``.  Here ``M``
    is the square root of channel 0 of ``abs2`` of the spectrum of the
    bicubic kernel's ``H * H_flip`` product (``get_bicubic(scale)``),
    downsampled by ``scale`` and placed by ``pad_shift_filter``.

    Returns:
        Scalar tensor: mean of the squared weighted spectrum.
    """
    # This is a plain function, so the former ``self.device`` could only
    # raise NameError; use the device the inputs already live on.
    device = hx.device

    # first part: LS loss
    dip_loss_dif = (hx - y)
    dip_loss = torch.rfft(dip_loss_dif, 2, onesided=False,
                          normalized=True).to(device)

    # second part: BP loss
    eps = 1e-3
    mul_factor = 1e5
    eps_ignored = 0.01
    sigma = 0
    h = torch.from_numpy(get_bicubic(scale))
    h = torch.unsqueeze(h, 0).unsqueeze(0)
    conv_shape = (h.shape[2] + h.shape[2] - 1, h.shape[3] + h.shape[3] - 1)
    H = fft2(h, conv_shape[1], conv_shape[0])
    H_flip = fft2(flip(h), conv_shape[1], conv_shape[0])
    H_mul_H_flip = mul_complex(H, H_flip)
    H_mul_H_flip_ifft = torch.irfft(H_mul_H_flip,
                                    signal_ndim=2,
                                    normalized=True,
                                    onesided=False)
    h_downsampled = H_mul_H_flip_ifft[:, :, 1::scale, 1::scale]
    h_downsampled = pad_shift_filter(y, h_downsampled)

    H_downsampled = torch.rfft(h_downsampled,
                               signal_ndim=2,
                               normalized=True,
                               onesided=False)
    bp_loss = torch.sqrt(abs2(H_downsampled)[:, :, :, :, 0:1])
    bp_loss = mul_factor * bp_loss + eps_ignored * (sigma**2) + eps
    bp_loss = 1 / (torch.sqrt(bp_loss))
    # Same weight for the real and imaginary components.
    bp_loss = torch.repeat_interleave(bp_loss, 2, -1).to(device)
    loss_mat = bp_loss * dip_loss

    return torch.mean(loss_mat**2)
# Example no. 23
# 0
    def forward(self, z, x):
        """Siamese cross-correlation response, plus a correlation-filter response in training.

        ``self.features`` returns a pair for each input.  The template
        features ``z2`` are cross-correlated with the search features (grouped
        ``conv2d``, one group per batch item) and passed through
        ``self.adjust``.  In training ``z2`` comes from the centre crop
        ``z[:, :, 56:183, 56:183]``, and a second, Fourier-domain
        correlation-filter response built from the first feature outputs is
        also returned.  In evaluation ``z``'s own features are the template.
        """
        if self.is_train:
            z_tmp = z[:, :, 56:183, 56:183]
            _, z2 = self.features(z_tmp)
        z1, z = self.features(z)  # 5
        x1, x = self.features(x)  # 19
        if not self.is_train:
            # ``z2`` used to be bound only in training, so evaluation crashed
            # with UnboundLocalError at the conv2d below.
            # NOTE(review): assumes the eval-time exemplar is already
            # template-sized, so its own features are the template -- confirm.
            z2 = z

        # fast cross correlation
        n, c, h, w = x.size()
        x = x.view(1, n * c, h, w)
        out = F.conv2d(x, z2, groups=n)
        out = out.view(n, 1, out.size(-2), out.size(-1))

        # adjust the scale of responses
        out = self.adjust(out)

        # correlation_filter
        if self.is_train:
            zf = torch.rfft(z1, signal_ndim=2)
            xf = torch.rfft(x1, signal_ndim=2)
            # |zf|^2 summed over channels; the last dim is kept with size 1 so
            # it divides both the real and imaginary parts of yf below.
            kzzf = torch.sum(torch.sum(zf**2, dim=4, keepdim=True),
                             dim=1,
                             keepdim=True)
            kxzf = torch.sum(complex_mulconj(xf, zf), dim=1, keepdim=True)
            alphaf = self.yf / (kzzf + self.lambda0)
            # NOTE(review): without signal_sizes, irfft assumes an even width;
            # pass x1.shape[-2:] if odd feature sizes are possible.
            response = torch.irfft(complex_mul(kxzf, alphaf), signal_ndim=2)
            return out, response
        return out
# Example no. 24
# 0
def idst(x, expkp1=None):
    """Batch inverse DST along the last dimension, without coefficient normalisation.

    Uses the 2N padding trick with an inverse FFT (cf. TensorFlow r1.10
    ``tensorflow/python/ops/spectral_ops.py``):

    1. Multiply ``x`` by the complex factors ``expkp1`` (built with
       ``get_expkp1(N, ...)`` when not supplied; the original notes describe
       them as ``2*exp(1j*pi*u/(2N))``).
    2. Zero-pad the N complex values to 2N.
    3. Take a length-2N inverse FFT with real output.
    4. Keep outputs ``1 .. N`` and multiply by N.

    A 1D ``x`` gets a temporary batch axis.

    NOTE(review): the previous docstring gave a cosine-kernel formula copied
    from the IDCT; verify the exact sine kernel before relying on it.
    """
    N = x.size(-1)
    if expkp1 is None:
        expkp1 = get_expkp1(N, dtype=x.dtype, device=x.device)
    is_vector = x.dim() == 1

    # (..., N) -> (..., N, 2) complex pairs, then zero-pad the N axis to 2N
    # (the trailing pair axis is left alone).
    padded = F.pad(x.unsqueeze(-1).mul(expkp1), (0, 0, 0, N), 'constant', 0)
    if is_vector:
        padded = padded.unsqueeze(0)

    y = torch.irfft(padded, signal_ndim=1, normalized=False, onesided=False,
                    signal_sizes=[2 * N])[..., 1:N + 1]
    y.mul_(N)
    return y.squeeze(0) if is_vector else y
# Example no. 25
# 0
    def forward(self, z):
        """Cross-correlation response of the search patch with the learned filter ``self.wf``.

        :param z: the multiscale searching patch. Shape (num_scale, 3, crop_sz, crop_sz)
        :return response: the response of cross correlation. Shape (num_scale, 1, crop_sz, crop_sz)

        The windowed features of ``z`` are transformed with a one-sided 2D
        FFT.  They are multiplied by ``conj(self.wf)``, summed over channels
        and inverted back to the feature map's spatial size.
        """
        # obtain feature of z and add hanning window
        z = self.feature(z) * self.config.cos_window
        spatial_size = z.shape[-2:]

        # calculate fourier transform
        z = torch.rfft(z, signal_ndim=2)

        # conjugate multiplication (y1 - y2i)(p1 + p2i) = (y1p1 + y2p2) + (y1p2 - y2p1)i
        real_phiw = self.wf[..., 0] * z[..., 0] + self.wf[..., 1] * z[..., 1]
        imag_phiw = self.wf[..., 0] * z[..., 1] - z[..., 0] * self.wf[..., 1]
        phi_w = torch.stack((real_phiw, imag_phiw), -1)

        # Sum over channels, then invert.  The spatial size is passed
        # explicitly: without it irfft assumes an even width (2 * (bins - 1)),
        # so an odd crop_sz came back one column short.
        response = torch.irfft(torch.sum(phi_w, dim=1, keepdim=True),
                               signal_ndim=2, signal_sizes=spatial_size)

        return response
# Example no. 26
# 0
    def forward(self, x):
        '''
        Residual Fourier unit: convolve the spectrum's real/imag parts as channels.

        :param x: (b, c, h, w)
        :return: (b, c, h, w), ``self.alpha * spectral_branch(x) + x``
        '''
        channels = x.size(1)

        spectrum = torch.rfft(x, signal_ndim=2)  # (b, c, h, w/2+1, 2)
        # Fold the real/imag axis into channels: (b, 2c, h, w/2+1), ordered
        # c0.re, c0.im, c1.re, ...
        folded = spectrum.permute(0, 1, 4, 2, 3).reshape(
            spectrum.size(0), -1, *spectrum.shape[2:4])

        mixed = self.relu(self.bn(self.conv_layer(folded)))  # (b, 2c, h, w/2+1)

        # Unfold back to (b, c, h, w/2+1, 2).
        unfolded = mixed.reshape(mixed.size(0), channels, 2, *mixed.shape[2:]).permute(
            0, 1, 3, 4, 2).contiguous()

        output = torch.irfft(unfolded, signal_ndim=2, signal_sizes=x.size()[2:])
        return self.alpha * output + x
# Example no. 27
# 0
def idct_1d(X):
    """
    The inverse to DCT-II, which is a scaled Discrete Cosine Transform, Type III.
    Our definition of idct is that idct(dct(x)) == x.

    :param X: the input signal
    :return: the inverse DCT-II of the signal over the last dimension
    """
    shape = X.shape
    N = shape[-1]

    half = X.contiguous().view(-1, N) / 2

    k = torch.arange(N, dtype=X.dtype, device=X.device)[None, :] * np.pi / (2 * N)
    cos_k, sin_k = torch.cos(k), torch.sin(k)

    # Complex sequence: real part = half; imaginary part = 0 followed by
    # -half[N-1], ..., -half[1].
    seq_re = half
    seq_im = torch.cat([half[:, :1] * 0, -half.flip([1])[:, :-1]], dim=1)

    # Rotate by exp(i * k).
    rot_re = seq_re * cos_k - seq_im * sin_k
    rot_im = seq_re * sin_k + seq_im * cos_k
    spectrum = torch.stack([rot_re, rot_im], dim=2)

    v = torch.irfft(spectrum, 1, onesided=False)

    # Undo the even/odd interleaving used by the forward DCT.
    x = v.new_zeros(v.shape)
    x[:, ::2] += v[:, :N - (N // 2)]
    x[:, 1::2] += v.flip([1])[:, :N // 2]

    return x.view(*shape)
# Example no. 28
# 0
    def forward(self, z):
        """Cross-correlation response of the search patch with the learned filter ``self.wf``.

        :param z: the multiscale searching patch. Shape (num_scale, 3, crop_sz, crop_sz)
        :return response: the response of cross correlation. Shape (num_scale, 1, crop_sz, crop_sz)

        For every scale ``s``, the products ``conj(wf[0, l]) * rfft2(feat[s, l])``
        (computed by ``self.imag_mult``) are summed over channels ``l`` and
        inverted back to the feature map's spatial size.
        """
        # obtain feature of z and add hanning window
        z = self.feature(z) * self.config.cos_window

        num_scale, channels, height, width = z.shape
        zf = torch.rfft(z, 2)
        # Conjugate the filter.
        w_star = self.wf.clone().detach()
        w_star[:, :, :, :, 1] = w_star[:, :, :, :, 1] * -1
        # Accumulate on zf's device/dtype with zf's spatial shape.  The old
        # torch.cuda.FloatTensor forced CUDA and used crop_sz for both axes.
        output = zf.new_zeros((num_scale, 1) + tuple(zf.shape[2:]))
        for s in range(num_scale):
            for l in range(channels):
                # Use filter channel ``l``.  The old code always read
                # w_star[0, 1], correlating every channel with channel 1.
                out_real, out_imag = self.imag_mult(w_star[0, l, :, :, 0],
                                                    w_star[0, l, :, :, 1],
                                                    zf[s, l, :, :, 0],
                                                    zf[s, l, :, :, 1])
                output[s, 0, :, :, 0] += out_real
                output[s, 0, :, :, 1] += out_imag
        # Explicit size so odd widths are reconstructed correctly.
        return torch.irfft(output, 2, signal_sizes=(height, width))
# Example no. 29
# 0
def FDA_source_to_target(src_img, trg_img, L=0.1):
    """Give ``src_img`` the low-frequency amplitude spectrum of ``trg_img``.

    Both (B, C, H, W) images are transformed with full 2D FFTs and split into
    amplitude and phase by ``extract_ampl_phase``.  ``low_freq_mutate``
    replaces the low-frequency part of the source amplitude (extent
    controlled by ``L``) with the target's.  The mutated amplitude is then
    recombined with the source phase and inverted.

    Returns:
        Real tensor (B, C, H, W) on ``src_img``'s device: source content,
        target style.
    """
    # get fft of both source and target
    fft_src = torch.rfft(src_img.clone(), signal_ndim=2, onesided=False)
    fft_trg = torch.rfft(trg_img.clone(), signal_ndim=2, onesided=False)

    # extract amplitude and phase of both ffts (copies guard against the
    # helpers modifying their inputs in place)
    amp_src, pha_src = extract_ampl_phase(fft_src.clone())
    amp_trg, _ = extract_ampl_phase(fft_trg.clone())

    # replace the low frequency amplitude part of source with that from target
    amp_src_ = low_freq_mutate(amp_src.clone(), amp_trg.clone(), L=L)

    # Recompose the source spectrum.  Stacking keeps it on the inputs'
    # device; the former torch.zeros(..., dtype=torch.float) buffer was
    # always on the CPU, so GPU inputs produced a CPU result.
    fft_src_ = torch.stack((torch.cos(pha_src) * amp_src_,
                            torch.sin(pha_src) * amp_src_), dim=-1)

    # get the recomposed image: source content, target style
    _, _, imgH, imgW = src_img.size()
    return torch.irfft(fft_src_,
                       signal_ndim=2,
                       onesided=False,
                       signal_sizes=[imgH, imgW])
# Example no. 30
# 0
def fft_frequency_decompose(x, min_size):
    """Split ``x`` into frequency bands, each reconstructed at its own length.

    ``x`` (batch, channels, time) is transformed once with an orthonormal
    one-sided FFT.  For ``size = min_size, 2 * min_size, ...`` while
    ``size <= x.shape[-1]``, the first ``size // 2 + 1`` bins are kept (for
    every size above ``min_size``, only bins ``size // 4 .. size // 2``) and
    inverse-transformed to length ``size``.

    Returns:
        dict mapping each length ``size`` to its (batch, channels, size) band.
    """
    spectrum = torch.rfft(input=x, signal_ndim=1, normalized=True)
    bands = {}
    size = min_size
    while size <= x.shape[-1]:
        band = spectrum[:, :, :size // 2 + 1, :]
        if size > min_size:
            keep = torch.zeros(band.shape[2]).to(x.device)
            keep[size // 4:size // 2 + 1] = 1
            band = band * keep[None, None, :, None]
        recon = torch.irfft(input=band,
                            signal_ndim=1,
                            normalized=True,
                            signal_sizes=(size, ))
        bands[recon.shape[-1]] = recon
        size *= 2
    return bands