Example 1
def rfft(input_tensor,
         signal_ndim=1,
         n=None,
         dim=-1,
         norm=None) -> torch.Tensor:
    check_fft_version()
    if "torch.fft" not in sys.modules:
        return torch.rfft(input_tensor, signal_ndim=signal_ndim)
    else:
        return torch.fft.rfft(input_tensor, n, dim, norm)
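A minimal, self-contained sketch of the same compatibility pattern (the names rfft_compat and the try/except import are illustrative, not part of the example above): prefer the torch.fft module (PyTorch 1.7+) and fall back to the deprecated torch.rfft when the module is absent. Note that the two branches return different layouts: torch.fft.rfft yields a complex tensor, while the legacy torch.rfft yields a real tensor with a trailing [real, imag] dimension.

import sys
import torch

try:
    import torch.fft  # puts "torch.fft" into sys.modules on PyTorch 1.7+
except ImportError:
    pass

def rfft_compat(x, n=None, dim=-1, norm=None):
    if "torch.fft" in sys.modules:
        # New API: complex output.
        return torch.fft.rfft(x, n=n, dim=dim, norm=norm)
    # Legacy API: real output with a trailing [real, imag] dimension;
    # n, dim and norm are ignored, as in the snippet above.
    return torch.rfft(x, signal_ndim=1)

print(rfft_compat(torch.randn(8)).shape)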
Example 2
    def forward(self, x):
        # batch_size x rho[0] + 2(len(rho) - 1)*rho[0]
        batch_len = len(x)
        mask = self.mask.to(x.device)
        grouped = scatter_add(x, mask).reshape(batch_len, len(self.rhos), self.num, 2).transpose(1, 2)

        signal = torch.irfft(grouped, 1)
        result = torch.rfft(self.nonlinearity(signal), 1, onesided=True).transpose(1, 2)
        result = result.reshape(batch_len, -1)[:, mask]
        return result
Example 3
def eval_torch_rfft2d(x, runs):
    for i in range(100):
        a = torch.rfft(x, signal_ndim=2, onesided=True)
    torch.cuda.synchronize()
    tt = time.time()
    for i in range(runs):
        a = torch.rfft(x, signal_ndim=2, onesided=True)
    torch.cuda.synchronize()
    print("torch.rfft2d takes %.7f ms" % ((time.time()-tt)/runs*1000))

    b = torch.irfft(a, signal_ndim=2, onesided=True, signal_sizes=x.shape)
    torch.cuda.synchronize()
    tt = time.time()
    for i in range(runs):
        b = torch.irfft(a, signal_ndim=2, onesided=True, signal_sizes=x.shape)
    torch.cuda.synchronize()
    print("torch.irfft2d takes %.7f ms" % ((time.time()-tt)/runs*1000))

    print("")
Example 4
def calculate_energy(frames):
    """
    Calculate energy of each frame by rfft.
    :param frame: (nframes, framelen)
    :return: (nframes, framelen//2) or (nframes, framelen//2 - 1)
             that equals to half frequencies
    """
    mag = torch.norm(torch.rfft(frames, 1), dim=2)[:, 1:]
    energy = mag**2
    return energy
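For reference, a hedged sketch of the same computation written against the torch.fft API (PyTorch 1.7+), which replaces the deprecated torch.rfft used above; calculate_energy_new is an illustrative name. It keeps the behaviour of dropping the DC bin before squaring.

import torch

def calculate_energy_new(frames):
    # One-sided spectrum magnitude per frame, DC bin dropped as above.
    mag = torch.abs(torch.fft.rfft(frames, dim=1))[:, 1:]
    return mag ** 2

frames = torch.randn(10, 400)
print(calculate_energy_new(frames).shape)  # torch.Size([10, 200])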
Example 5
    def apply(self, x):
        x = torch.rfft(x, signal_ndim=2, normalized=False, onesided=False)
        x = batch_fftshift2d(x)

        x = x * self.mask

        x = batch_ifftshift2d(x)
        x = torch.irfft(x, signal_ndim=2, normalized=False, onesided=False)

        return x
Example 6
def make_cepts2(X, T_pi):
    """Calculate the squared real cepstral coefficents."""
    Y = F.unfold(X, kernel_size=[T_pi, 1], stride=T_pi)
    Y = torch.transpose(Y, 1, 2)

    # Compute the power spectral density
    window = torch.Tensor(hann(Y.shape[-1])[np.newaxis, np.newaxis]).type(Y.dtype)
    Yf = torch.rfft(Y * window, 1, onesided=True)
    spect = Yf[:, :, :, 0]**2 + Yf[:, :, :, 1]**2
    spect = spect.mean(dim=1)
    spect = torch.cat([torch.flip(spect[:, 1:], dims=(1,)), spect], dim=1)

    # Log of the DFT of the autocorrelation
    logspect = torch.log(spect) - np.log(float(Y.shape[-1]))

    # Compute squared cepstral coefs (b_k^2)
    cepts = torch.rfft(logspect, 1, onesided=True) / float(Y.shape[-1])
    cepts = torch.sqrt(cepts[:, :, 0]**2 + cepts[:, :, 1]**2)
    return cepts**2
Example 7
def ccorr(a, b):
    """
	Compute circular correlation of two tensors.
	Parameters
	----------
	a: Tensor, 1D or 2D
	b: Tensor, 1D or 2D

	Notes
	-----
	Input a and b should have the same dimensions. And this operation supports broadcasting.

	Returns
	-------
	Tensor, having the same dimension as the input a.
	"""
    return th.irfft(com_mult(conj(th.rfft(a, 1)), th.rfft(b, 1)),
                    1,
                    signal_sizes=(a.shape[-1], ))
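A hedged equivalent of ccorr using the torch.fft API (PyTorch 1.7+); ccorr_new is an illustrative name. Circular correlation is computed as the inverse FFT of conj(FFT(a)) * FFT(b) along the last dimension, matching the irfft/rfft composition above.

import torch

def ccorr_new(a, b):
    # Inverse real FFT of conj(FFT(a)) * FFT(b) gives the circular correlation.
    return torch.fft.irfft(torch.conj(torch.fft.rfft(a)) * torch.fft.rfft(b),
                           n=a.shape[-1])

a = torch.randn(4, 16)
b = torch.randn(4, 16)
print(ccorr_new(a, b).shape)  # torch.Size([4, 16])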
Example 8
def FourierMod2_nopad(a):
    [n_batch,n_c,ha,wa]=a.shape
    mydevice=a.device
    assert n_c==1, "Only grayscale currently supported"
    a=a.view(n_batch,ha,wa)
    A=torch.rfft(a,signal_ndim=2,onesided=False,normalized=False)
    Ar = A[:, :, :, 0]
    Ai = A[:, :, :, 1]
    Aabs2 = Ar.abs()**2 + Ai.abs()**2  # Unlike the definition used in xcorr2, Aabs2 is not complex here.
    return Aabs2.reshape([n_batch,n_c,ha,wa]), Ar.reshape([n_batch,n_c,ha,wa]), Ai.reshape([n_batch,n_c,ha,wa])
Example 9
def dct1(x):
    """
    Discrete Cosine Transform, Type I
    :param x: the input signal
    :return: the DCT-I of the signal over the last dimension
    """
    x_shape = x.shape
    x = x.view(-1, x_shape[-1])

    return torch.rfft(torch.cat([x, x.flip([1])[:, 1:-1]], dim=1), 1)[:, :, 0].view(*x_shape)
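A hedged sketch of the same DCT-I trick with the torch.fft API (PyTorch 1.7+); dct1_new is an illustrative name. The signal is mirrored without repeating the endpoints, transformed with a real FFT, and the real part of the spectrum is the DCT-I.

import torch

def dct1_new(x):
    x_shape = x.shape
    x = x.reshape(-1, x_shape[-1])
    # Mirror the signal without duplicating the first and last samples.
    mirrored = torch.cat([x, x.flip([1])[:, 1:-1]], dim=1)
    # The real part of the rfft of the mirrored signal is the DCT-I.
    return torch.fft.rfft(mirrored, dim=1).real.reshape(x_shape)

x = torch.randn(3, 8)
print(dct1_new(x).shape)  # torch.Size([3, 8])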
Example 10
def data_solution(x, FB, FBC, F2B, FBFy, alpha, sf):
    FR = FBFy + torch.rfft(alpha*x, 2, onesided=False)
    x1 = cmul(FB, FR)
    FBR = torch.mean(splits(x1, sf), dim=-1, keepdim=False)
    invW = torch.mean(splits(F2B, sf), dim=-1, keepdim=False)
    invWBR = cdiv(FBR, csum(invW, alpha))
    FCBinvWBR = cmul(FBC, invWBR.repeat(1, 1, sf, sf, 1))
    FX = (FR-FCBinvWBR)/alpha.unsqueeze(-1)
    Xest = torch.irfft(FX, 2, onesided=False)
    return Xest
Example 11
def fft_old(z, x, label):
    # [batch, 32, 121, 61, 2]
    zf = torch.rfft(z, signal_ndim=2)
    xf = torch.rfft(x, signal_ndim=2)

    # [batch, 1, 121, 61, 1]
    kzzf = torch.sum(torch.sum(zf**2, dim=4, keepdim=True),
                     dim=1,
                     keepdim=True)

    # [batch, 1, 121, 61, 2]
    t = cn.mulconj(xf, zf)
    kxzf = torch.sum(t, dim=1, keepdim=True)

    # [batch, 1, 121, 61, 2]
    alphaf = label.to(device=z.device) / (kzzf + lambda0)

    # [batch, 1, 121, 121]
    return torch.irfft(cn.mul(kxzf, alphaf), signal_ndim=2)
Example 12
def D(x, Dh_DFT, Dv_DFT):
    x_DFT = torch.rfft(x, signal_ndim=1, onesided=False)
    x_DFT = torch.view_as_complex(x_DFT).cuda()
    Dh_x = torch.irfft(torch.view_as_real(Dh_DFT * x_DFT),
                       signal_ndim=1,
                       onesided=False)
    Dv_x = torch.irfft(torch.view_as_real(Dv_DFT * x_DFT),
                       signal_ndim=1,
                       onesided=False)
    return Dh_x, Dv_x
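A hedged new-API sketch of the same operation (D_new is an illustrative name): the legacy rfft/irfft calls with onesided=False amount to a full FFT along the last dimension, a complex multiplication by the precomputed derivative filters Dh_DFT and Dv_DFT (assumed complex tensors, as above), and keeping the real part of the inverse transform.

import torch

def D_new(x, Dh_DFT, Dv_DFT):
    # Full (two-sided) FFT along the last dimension, as onesided=False above.
    x_DFT = torch.fft.fft(x, dim=-1)
    Dh_x = torch.fft.ifft(Dh_DFT * x_DFT, dim=-1).real
    Dv_x = torch.fft.ifft(Dv_DFT * x_DFT, dim=-1).real
    return Dh_x, Dv_x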
Example 13
    def forward(self, x, edge_index, edge_weight=None):
        """"""
        edge_index, edge_weight = remove_self_loops(edge_index, edge_weight)

        row, col = edge_index
        num_nodes, num_edges, order_filter = x.size(0), row.size(0), self.weight.size(0)

        if edge_weight is None:
            edge_weight = x.new_ones((num_edges,))
        edge_weight = edge_weight.view(-1)
        assert edge_weight.size(0) == edge_index.size(1)

        deg = degree(row, num_nodes, dtype=x.dtype)

        # Compute normalized and rescaled Laplacian.
        deg = deg.pow(-0.5)
        deg[deg == float('inf')] = 0
        lap = -deg[row] * edge_weight * deg[col]

        def weight_mult(x, w):
            y = torch.einsum('fgrs,ifrs->igrs', w, x)
            return y

        def lap_mult(edge_index, lap, x):
            # Densify the sparse Laplacian; lap holds float edge values.
            L = torch.sparse.FloatTensor(edge_index, lap, torch.Size([x.shape[0], x.shape[0]])).to_dense()
            x_tilde = torch.einsum('ij,ifrs->jfrs', L, x)
            return x_tilde

        # Perform filter operation recurrently.
        horizon = x.shape[1]
        x = x.permute(0, 2, 1)
        x_hat = torch.rfft(x, 1, normalized=True, onesided=True)

        Tx_0 = x_hat

        y_hat = weight_mult(Tx_0, self.weight[0, :])

        if order_filter > 1:

            Tx_1 = lap_mult(edge_index, lap, x_hat)
            y_hat = y_hat + weight_mult(Tx_1, self.weight[1, :])

            for k in range(2, order_filter):
                Tx_2 = 2 * lap_mult(edge_index, lap, Tx_1) - Tx_0
                y_hat = y_hat + weight_mult(Tx_2, self.weight[k, :])

                Tx_0, Tx_1 = Tx_1, Tx_2

        y = torch.irfft(y_hat, 1, normalized=True, onesided=True, signal_sizes=(horizon,))
        y = y.permute(0, 2, 1)

        if self.bias is not None:
            y = y + self.bias

        return y
Example 14
    def initGT(self, device):
       
        
        if 'Ribo' in self.args.dataset:
             self.GroundTruth=torch.Tensor( mrcfile.open('./Datasets/Ribo/emd_2660.mrc').data).to(device)
             self.gen.X.data=self.GroundTruth
        elif 'Betagal' in self.args.dataset:
            print("Using 2.5 A molmap")
            fittedBetagal=torch.Tensor( mrcfile.open('./Datasets/Betagal-Synthetic/fitted_betagal_2.5A.mrc').data).to(device)
            fittedBetagal=torch.nn.functional.avg_pool3d(fittedBetagal.unsqueeze(0).unsqueeze(0), kernel_size=self.args.DownSampleRate, stride=self.args.DownSampleRate, padding=0).squeeze()
           
            n=self.args.VolumeSize
            
            GroundTruth= self.ExpandVolume( fittedBetagal, n, device)
            self.gen.X.data=GroundTruth.unsqueeze(0)
            self.GroundTruth=self.gen.X.data
         
                
                
        elif 'ABC' in self.args.dataset:
            n=self.args.VolumeSize
            
            if self.args.VolumeNumbers==1:
                vol=torch.Tensor( mrcfile.open('./Datasets/ABC-Synthetic/fitted_ABC.mrc').data).to(device)
                self.gen.X.data=self.ExpandVolume( vol, n, device).unsqueeze(0)
                
            else:    
                for i in range(self.args.VolumeNumbers):
                    num=4773+2*i
                    vol=torch.Tensor( mrcfile.open('./Datasets/ABC/fitted_'+str(num)+'.mrc').data).to(device)/4
                    self.gen.X.data[i]= self.ExpandVolume( vol, n, device).unsqueeze(0)

         
            self.GroundTruth=self.gen.X.data
            
        elif 'proteasome' in self.args.dataset:
            n=self.args.VolumeSize
           
                
            vol=torch.Tensor( mrcfile.open('./Datasets/proteasome-Synthetic/fitted_proteasome.mrc').data).to(device)
            self.gen.X.data[0]= self.ExpandVolume( vol, n, device)
            self.GroundTruth=self.gen.X.data
        
        elif 'serotonin' in self.args.dataset:
            n=self.args.VolumeSize
           
                
            vol=torch.Tensor( mrcfile.open('./Datasets/serotonin-Synthetic/fitted_serotonin.mrc').data).to(device)
            self.gen.X.data[0]= self.ExpandVolume( vol, n, device)
            self.GroundTruth=self.gen.X.data
            
            
        if self.args.VolumeDomain =='fourier' and self.args.FourierProjector==True:
            self.gen.X.data=fftshift(torch.rfft(ifftshift(self.GroundTruth, mode='real', signal_dim=3), 3, onesided=False),mode='complex', signal_dim=3)
            print("Fourier projector")
Example 15
def fft_conv2d(x, weight, bias=None):
    b, c, h, w = x.shape
    n_out, n_in, k, k = weight.shape
    total_pad_w, total_pad_h = k - 1, k - 1
    pad_lw, pad_lh = total_pad_w // 2, total_pad_h // 2
    pad_rw, pad_rh = total_pad_w - pad_lw, total_pad_h - pad_lh
    x = F.pad(x, (pad_lw, pad_rw, pad_lh, pad_rh))
    b, c, h, w = x.shape  # Get it again in case we padded

    fft_pad_lw, fft_pad_lh = (w - 1) // 2, (h - 1) // 2
    fft_pad_rw, fft_pad_rh = w - 1 - fft_pad_lw, h - 1 - fft_pad_lh
    weight_pad_lw, weight_pad_lh = (w - k) // 2, (h - k) // 2
    weight_pad_rw, weight_pad_rh = w - k - weight_pad_lw, h - k - weight_pad_lh
    weight_padded = F.pad(
        weight,
        (
            weight_pad_lw + fft_pad_lw,
            weight_pad_rw + fft_pad_rw,
            weight_pad_lh + fft_pad_lh,
            weight_pad_rh + fft_pad_rh,
        ),
    )
    x = F.pad(x, (fft_pad_lw, fft_pad_rw, fft_pad_lh, fft_pad_rh))
    x_fft = torch.rfft(x, 2)
    weight_fft = torch.rfft(weight_padded, 2)
    result_fft = compl_mul(x_fft, weight_fft)
    result = torch.irfft(result_fft,
                         2,
                         signal_sizes=(x.shape[-2], x.shape[-1]))
    b, c, h, w = result.shape  # Get it again in case we unpadded

    if bias is not None:
        result += bias
    res = result.clone()
    res[:, :, :int(np.floor(h / 2)), :] = result[:, :, int(np.ceil(h / 2)):, :]
    res[:, :, int(np.floor(h / 2)):, :] = result[:, :, :int(np.ceil(h / 2)), :]
    result = res.clone()
    res[:, :, :, :int(np.floor(w / 2))] = result[:, :, :, int(np.ceil(w / 2)):]
    res[:, :, :, int(np.floor(w / 2)):] = result[:, :, :, :int(np.ceil(w / 2))]
    res = res[:, :, pad_lh + fft_pad_lh:h - pad_rh - fft_pad_rh,
              pad_lw + fft_pad_lw:w - pad_rw - fft_pad_rw, ].contiguous()
    return res
Example 16
def synthesis(s_target, test_id, ind, scat, n, min_error, err_it, nit, is_complex = False, initial_type = 'gaussian'):
    print('shape of s:', s_target.size())
#     s_target = s_target / torch.sum(s_target) 
    if initial_type == 'gaussian':
        x0 = torch.randn(n,n)
    elif initial_type == 'uniform':
        x0 = torch.rand(n,n)

    if torch.cuda.is_available():
        s_target = s_target.cuda()
        x0 = x0.cuda()
    x0 = Variable(x0, requires_grad=True)
    
    x0_hat = torch.rfft(x0, 2, onesided = False)
    
    s0 = scat(x0_hat)
    loss = nn.MSELoss()
    optimizer = optim.Adam([x0], lr=lr)
    output = loss(s_target, s0)
    l0 = output
    error = []
    count = 0
    while output / l0 > min_error:
        optimizer.zero_grad() 
        x0_hat = torch.rfft(x0, 2, onesided = False)
        s0 = scat(x0_hat)
        output = loss(s_target, s0)
        if count % err_it ==0:   
            error.append(output.item())
        output.backward()
        optimizer.step()
        output = loss(s_target, s0)
        if count % nit == 0:
            print(output.data.cpu().numpy())
            np.save('./result%d/syn_error%d.npy'%(test_id, ind), np.asarray(error))
            np.save('./result%d/syn_result%d.npy'%(test_id, ind), x0.data.cpu().numpy())
            # plot_image(x0.data.cpu().numpy(), test_id, ind, count, nit)
        count += 1
    print('error reduced by: ', output / l0)
    print('error supposed reduced by: ', min_error)
    np.save('./result%d/syn_error%d.npy'%(test_id, ind), np.asarray(error))
    np.save('./result%d/syn_result%d.npy'%(test_id, ind), x0.data.cpu().numpy())
Example 17
    def CACF_update(self, inputs, lr=1.):

        inputs[0] = self.feature(inputs[0]) * self.config.cos_window
        inputs[1] = self.feature(inputs[1]) * self.config.cos_window
        inputs[2] = self.feature(inputs[2]) * self.config.cos_window
        inputs[3] = self.feature(inputs[3]) * self.config.cos_window
        inputs[4] = self.feature(inputs[4]) * self.config.cos_window
        zf = torch.rfft(inputs[0], signal_ndim=2)  # target region
        cf1 = torch.rfft(inputs[1], signal_ndim=2)  # contex 1 region
        cf2 = torch.rfft(inputs[2], signal_ndim=2)  # contex 2 region
        cf3 = torch.rfft(inputs[3], signal_ndim=2)  # contex 3 region
        cf4 = torch.rfft(inputs[4], signal_ndim=2)  # contex 4 region
        kzzf = torch.sum(torch.sum(zf**2, dim=4, keepdim=True),
                         dim=1,
                         keepdim=True)
        kccf1 = torch.sum(torch.sum(cf1**2, dim=4, keepdim=True),
                          dim=1,
                          keepdim=True)
        kccf2 = torch.sum(torch.sum(cf2**2, dim=4, keepdim=True),
                          dim=1,
                          keepdim=True)
        kccf3 = torch.sum(torch.sum(cf3**2, dim=4, keepdim=True),
                          dim=1,
                          keepdim=True)
        kccf4 = torch.sum(torch.sum(cf4**2, dim=4, keepdim=True),
                          dim=1,
                          keepdim=True)

        if lr > 0.99:
            alphaf = self.config.yf / (kzzf + self.config.lambda0)
        else:
            alphaf = self.config.yf / (kzzf + self.config.lambda0 +
                                       self.config.lambda1 *
                                       (kccf1 + kccf2 + kccf3 + kccf4))

        if lr > 0.99:
            self.model_alphaf = alphaf
            self.model_zf = zf
        else:
            self.model_alphaf = (
                1 - lr) * self.model_alphaf.data + lr * alphaf.data
            self.model_zf = (1 - lr) * self.model_zf.data + lr * zf.data
Example 18
    def forward(ctx,
                h1,
                s1,
                h2,
                s2,
                output_size,
                x,
                y,
                force_cpu_scatter_add=False):
        ctx.save_for_backward(h1, s1, h2, s2, x, y)
        ctx.x_size = tuple(x.size())
        ctx.y_size = tuple(y.size())
        ctx.force_cpu_scatter_add = force_cpu_scatter_add
        ctx.output_size = output_size

        # Compute the count sketch of each input
        px = CountSketchFn_forward(h1, s1, output_size, x,
                                   force_cpu_scatter_add)
        fx = torch.rfft(px, 1)
        re_fx = fx.select(-1, 0)
        im_fx = fx.select(-1, 1)
        del px
        py = CountSketchFn_forward(h2, s2, output_size, y,
                                   force_cpu_scatter_add)
        fy = torch.rfft(py, 1)
        re_fy = fy.select(-1, 0)
        im_fy = fy.select(-1, 1)
        del py

        # Convolution of the two sketch using an FFT.
        # Compute the FFT of each sketch

        # Complex multiplication
        re_prod, im_prod = ComplexMultiply_forward(re_fx, im_fx, re_fy, im_fy)

        # Back to real domain
        # The imaginary part should be zeros
        re = torch.irfft(torch.stack((re_prod, im_prod), re_prod.dim()),
                         1,
                         signal_sizes=(output_size, ))

        return re
Example 19
    def forward(self, input, FB, FBC, F2B, FBFy, alpha, sf):
        alpha = alpha[:, 1:2, ...]
        FR = FBFy + torch.rfft(alpha * input, 2, onesided=False)
        x1 = cmul(FB, FR)
        FBR = torch.mean(splits(x1, sf), dim=-1, keepdim=False)
        invW = torch.mean(splits(F2B, sf), dim=-1, keepdim=False)
        invWBR = cdiv(FBR, csum(invW, alpha))
        FCBinvWBR = cmul(FBC, invWBR.repeat(1, 1, sf, sf, 1))
        FX = (FR - FCBinvWBR) / alpha.unsqueeze(-1)
        Xest = torch.irfft(FX, 2, onesided=False)
        return Xest
Example 20
def rfft(t):
    # Real-to-complex Discrete Fourier Transform
    ver = torch.__version__
    major, minor, ver = ver.split('.')
    ver_int = int(major) * 100 + int(minor)
    if ver_int >= 108:
        ft = torch.fft.fft2(t)
        ft = torch.stack([torch.real(ft), torch.imag(ft)], dim=-1)
    else:
        ft = torch.rfft(t, 2, onesided=False)
    return ft
Example 21
    def forward(self, x):
        # Compute Fourier coefficients up to factor of e^(- something constant)
        x_ft = torch.rfft(x, 1, normalized=True, onesided=True)

        # Multiply relevant Fourier modes
        out_ft = torch.zeros(x_ft.shape, device=x.device)
        out_ft[:, :self.modes1, :] = x_ft[:, :self.modes1,  :]

        #Return to physical space
        x = torch.irfft(out_ft, 1, normalized=True, onesided=True, signal_sizes=(x.size(-1), ))
        return x
Example 22
    def forward(ctx, params, luts, inverse, mv):
        ctx.params = params
        ctx.luts = luts
        ctx.inverse = inverse
        sh = mv.shape
        spatial_dim = len(sh) - 2
        Fmv = torch.rfft(mv, spatial_dim, normalized=True)
        lagomorph_cuda.fluid_operator(Fmv, inverse,
                luts['cos'], luts['sin'], *params)
        return torch.irfft(Fmv, spatial_dim, normalized=True,
                signal_sizes=sh[2:])
Example 23
def rfft2(data):
    # (H x W) or (C x H x W) or (B x C x H x W)
    assert data.shape[-1] == data.shape[-2]
    assert (len(data.shape) == 2
            or (len(data.shape) == 3 and data.shape[0] == 1)
            or (len(data.shape) == 4 and data.shape[1] == 1))
    data = ifftshift(data, dim=(-2, -1))
    data = torch.rfft(data, 2, normalized=True, onesided=False)
    # Now complex valued with dim -1 as [real, imaginary] dimension
    data = fftshift(data, dim=(-3, -2))
    return data
Example 24
    def extract_power(self, frames):
        t_frames = torch.Tensor(frames).to(self.opts['device'])
        t_frames -= t_frames.mean(2, True)
        t_frames = t_frames * self.window

        t_frames = F.pad(t_frames, (0, self.next_power - self.opts['win_len']))
        spect = torch.rfft(t_frames, 1)
        power_spect = torch.pow(spect[:, :, :, 0], 2.0) + torch.pow(
            spect[:, :, :, 1], 2.0)
        # c1 =torch.matmul(power_spect, self.melbank)
        return power_spect
Example 25
    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            weight_decay = group['weight_decay']
            momentum = group['momentum']
            dampening = group['dampening']
            nesterov = group['nesterov']
            sigma = group['sigma']

            index = 0
            for p in group['params']:
                if p.grad is None:
                    continue
                # Laplacian smoothing gradient
                if sigma > 0:
                    if self.cs[index] is not None:
                        g = p.grad.data.view(-1)
                        fg = torch.rfft(g, 1, onesided=False)
                        g = torch.zeros_like(fg)
                        self.cs[index] = self.cs[index].to(fg.device)
                        g = complex_div(fg, self.cs[index])
                        g = torch.irfft(g, 1, onesided=False).view(
                            p.grad.data.size())
                        p.grad.data = g
                    index += 1

                d_p = p.grad.data
                if weight_decay != 0:
                    d_p.add_(weight_decay, p.data)
                if momentum != 0:
                    param_state = self.state[p]
                    if 'momentum_buffer' not in param_state:
                        buf = param_state['momentum_buffer'] = torch.clone(
                            d_p).detach()
                    else:
                        buf = param_state['momentum_buffer']
                        buf.mul_(momentum).add_(1 - dampening, d_p)
                    if nesterov:
                        d_p = d_p.add(momentum, buf)
                    else:
                        d_p = buf

                p.data.add_(-group['lr'], d_p)

        return loss
Example 26
def test_rfft(signal_ndim, normalized, onesided):
    input = torch.randn(4, 3, 8, 8)
    assert torch.allclose(torch.rfft(input, signal_ndim, normalized, onesided),
                          utils._rfft(input, signal_ndim, normalized,
                                      onesided),
                          atol=1e-4)
    assert torch.allclose(utils._irfft(
        utils._rfft(input, signal_ndim, normalized, onesided), signal_ndim,
        normalized, onesided),
                          input,
                          atol=1e-4)
Example 27
    def extract_fbank(self, wavs):
        t_frames = self.enframes(wavs)
        t_frames -= t_frames.mean(2, True)
        t_frames = t_frames * self.window

        t_frames = F.pad(t_frames, (0, self.next_power - self.opts['win_len']))
        spect = torch.rfft(t_frames, 1)
        power_spect = torch.pow(spect[:, :, :, 0], 2.0) + torch.pow(
            spect[:, :, :, 1], 2.0)
        c1 = torch.matmul(power_spect, self.melbank)
        return c1
Example 28
    def forward(self, x):

        fftfull = torch.rfft(x,2)
        xreal = fftfull[... , 0]
        xim = fftfull[... ,1]
        x = torch.cat((xreal.unsqueeze(1), xim.unsqueeze(1)), 1 ).unsqueeze( -3 )
        x = torch.index_select( x, -2, self.indF[0] )

        x   = self.hl0 * x 
        h0f = x.select( -3, 0 ).unsqueeze( -3 )
        l0f = x.select( -3, 1 ).unsqueeze( -3 )
        lf  = l0f 

        output = []

        for n in range( self.N ):

            bf = self.b[n] * lf 
            lf = self.l[n] * central_crop( lf ) 
            if self.hilb:
                hbf = self.s[n] * torch.cat( (bf.narrow(1,1,1), -bf.narrow(1,0,1)), 1 )
                bf  = torch.cat( ( bf , hbf ), -3 )
            if self.includeHF and n == 0:
                bf  = torch.cat( ( h0f,  bf ), -3 )

            output.append( bf )

        output.append( lf  ) 

        for n in range( len( output ) ):
            output[n] = torch.index_select( output[n], -2, self.indB[n] )
            sig_size = [output[n].shape[-2],(output[n].shape[-1]-1)*2]
            output[n] = torch.stack((output[n].select(1,0), output[n].select(1,1)),-1)
            output[n] = torch.irfft( output[n], 2, signal_sizes = sig_size)

        if self.includeHF:
            output.insert( 0, output[0].narrow( -3, 0, 1                    ) )
            output[1]       = output[1].narrow( -3, 1, output[1].size(-3)-1 )

        for n in range( len( output ) ):
            if self.hilb:
                if ((not self.includeHF) or 0 < n) and n < len(output)-1:
                    nfeat = output[n].size(-3)//2
                    o1 = output[n].narrow( -3,     0, nfeat ).unsqueeze(1)
                    o2 = -output[n].narrow( -3, nfeat, nfeat ).unsqueeze(1)
                    output[n] = torch.cat( (o2, o1), 1 ) 
                else:
                    output[n] = output[n].unsqueeze(1)

        for n in range( len( output ) ):
            if n>0:
                output[n] = output[n]*(2**(n-1))
                
        return output
Example 29
def p2o(psf, shape):
    otf = torch.zeros(psf.shape[:-2] + shape).type_as(psf)
    otf[..., :psf.shape[2], :psf.shape[3]].copy_(psf)
    for axis, axis_size in enumerate(psf.shape[2:]):
        otf = torch.roll(otf, -int(axis_size / 2), dims=axis + 2)
    otf = torch.rfft(otf, 2, onesided=False)
    # Zero out imaginary parts that are within FFT round-off error
    # (mirrors MATLAB's psf2otf).
    spatial = torch.tensor(psf.shape[-2:], dtype=psf.dtype, device=psf.device)
    n_ops = torch.sum(spatial * torch.log2(spatial))
    otf[..., 1][torch.abs(otf[..., 1]) < n_ops * 2.22e-16] = 0.0
    return otf
Example 30
    def forward(self, x):
        input_size = x.shape[2]
        projection_size_padded = \
            max(64, int(2 ** (2 * torch.tensor(input_size)).float().log2().ceil()))
        pad_width = projection_size_padded - input_size
        padded_tensor = F.pad(x, (0, 0, 0, pad_width))
        f = self._get_fourier_filter(padded_tensor.shape[2]).to(x.device)
        fourier_filter = self.create_filter(f)
        fourier_filter = fourier_filter.unsqueeze(-2)
        projection = torch.rfft(padded_tensor.transpose(2, 3), 1, onesided=False).transpose(2, 3) * fourier_filter
        return torch.irfft(projection.transpose(2, 3), 1, onesided=False).transpose(2, 3)[:, :, :input_size, :]