Code example #1
def create_fft_plots(sample, model, epoch):
    train_code = model.encode(sample.view(1, -1))
    train_code = train_code[0]
    fig = plt.figure()
    train_code_pad = torch.zeros(100).cuda()
    train_code_pad[:len(train_code)] = train_code
    train_code_complex = torch.stack(
        (train_code_pad, torch.zeros_like(train_code_pad)), dim=1)
    H = torch.fft(train_code_complex, 1,
                  normalized=True).cpu().detach().numpy()
    plt.plot([np.sqrt(H[i, 0]**2 + H[i, 1]**2) for i in range(len(H))])

    filter_real = model.conv1.conv_real.weight.data.view(-1)
    filter_imag = model.conv1.conv_imag.weight.data.view(-1)
    filter_complex = torch.stack((filter_real, filter_imag), dim=1)
    print(filter_complex.shape)
    lowpass_fft = torch.fft(filter_complex, 1,
                            normalized=False).cpu().detach().numpy()
    plt.plot([
        np.sqrt(lowpass_fft[i, 0]**2 + lowpass_fft[i, 1]**2)
        for i in range(len(lowpass_fft))
    ])
    plt.title('Epoch ' + str(epoch))
    plt.savefig('../results/images/fft_none/fft_%s.png' %
                (str(epoch).zfill(4)))
    fig.clf()
    plt.close()
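Note: every excerpt on this page uses the legacy function-style torch.fft(input, signal_ndim, ...) API, which stores complex values in a trailing dimension of size 2 and was removed in PyTorch 1.8. A minimal migration sketch, assuming torch >= 1.8 (names here are illustrative only):

import torch

x = torch.randn(8, 100, 2)  # legacy layout: last dim holds (real, imag)

# legacy (torch <= 1.7): X = torch.fft(x, signal_ndim=1, normalized=True)
xc = torch.view_as_complex(x.contiguous())  # complex64 tensor, shape [8, 100]
Xc = torch.fft.fft(xc, norm="ortho")        # norm="ortho" matches normalized=True
X = torch.view_as_real(Xc)                  # back to the [..., 2] layout if needed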
Code example #2
    def reconstruct_matrix(self, signal, calibration_matrix,
                           eigval1_reciprocal, eigval2_reciprocal):
        """
        Recovers the random matrix.

        Parameters
        ----------

        signal: ComplexTensor,
            Tensor with the signal values
        calibration_matrix: torch.Tensor,
            calibration matrix (the partial one)
        eigval1_reciprocal: ComplexTensor,
            reciprocal eigenvalues of the first circulant block of the
            partial calibration matrix.
        eigval2_reciprocal: ComplexTensor,
            reciprocal eigenvalues of the second circulant block of the
            partial calibration matrix.

        Returns
        -------
        reconstructed_A: ComplexTensor,
            reconstructed transmission matrix. If batch size < rows, it is
            a batch of rows.

        """
        start = time()

        if self.solver == "least-square":

            inv_calibration_matrix = torch.pinverse(calibration_matrix)
            reconstructed_A = ComplexTensor(
                real=torch.matmul(signal.real, inv_calibration_matrix),
                imag=torch.matmul(signal.imag, inv_calibration_matrix))

        elif self.solver == "fft":
            signal = signal.conj().stack()

            signal1_star = signal[:, :self.n_signals // 2]
            signal2_star = signal[:, self.n_signals // 2:]

            fft_buffer = torch.fft(signal1_star, signal_ndim=1)

            block1 = ComplexTensor(real=fft_buffer[:, :, 0],
                                   imag=fft_buffer[:, :, 1])
            block1 = block1.batch_elementwise(eigval1_reciprocal)

            fft_buffer = torch.fft(signal2_star, signal_ndim=1)

            block2 = ComplexTensor(real=fft_buffer[:, :, 0],
                                   imag=fft_buffer[:, :, 1])

            block2 = block2.batch_elementwise(eigval2_reciprocal)

            reconstructed_A = torch.ifft((block1 + block2).stack(),
                                         signal_ndim=1)
            reconstructed_A = ComplexTensor(real=reconstructed_A[:, :, 0],
                                            imag=reconstructed_A[:, :,
                                                                 1]).conj()

        else:
            raise ValueError("Unknown solver: %s" % self.solver)

        self.time_logger["solver"] += time() - start

        return reconstructed_A
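The "fft" branch works because a circulant matrix is diagonalized by the DFT, so applying the inverse of each circulant block reduces to an elementwise multiply by the reciprocal eigenvalues in Fourier space. A minimal sketch of that identity with the modern complex API (assuming torch >= 1.8; circulant_solve is a hypothetical helper, c is the first column of the circulant matrix):

import torch

def circulant_solve(c, b):
    # eigenvalues of a circulant matrix C are fft(c), so C x = b
    # is solved by x = ifft(fft(b) / fft(c))
    return torch.fft.ifft(torch.fft.fft(b) / torch.fft.fft(c)).real

c = torch.tensor([4.0, 1.0, 0.0, 1.0])  # first column of a well-conditioned C
b = torch.randn(4)
x = circulant_solve(c, b)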
Code example #3
def one_hot_add(inputs, shift):
    """Performs (inputs + shift) % vocab_size in the one-hot space.
    Args:
        inputs: Tensor of shape `[..., vocab_size]`. Typically a soft/hard one-hot
        Tensor.
        shift: Tensor of shape `[..., vocab_size]`. Typically a soft/hard one-hot
        Tensor specifying how much to shift the corresponding one-hot vector in
        inputs. Soft values perform a "weighted shift": for example,
        shift=[0.2, 0.3, 0.5] performs a linear combination of 0.2 * shifting by
        zero; 0.3 * shifting by one; and 0.5 * shifting by two.
    Returns:
        Tensor of same shape and dtype as inputs.
    """
    inputs = torch.stack((inputs, torch.zeros_like(inputs)), dim=-1)
    shift = torch.stack((shift, torch.zeros_like(shift)), dim=-1)
    # batched FFT over the vocab dimension; the trailing dim holds (real, imag)
    inputs_fft = torch.fft(inputs, 1)
    shift_fft = torch.fft(shift, 1)
    # complex multiplication in the frequency domain implements the circular
    # convolution, i.e. the modular shift
    result_fft_real = (inputs_fft[..., 0] * shift_fft[..., 0] -
                       inputs_fft[..., 1] * shift_fft[..., 1])
    result_fft_imag = (inputs_fft[..., 0] * shift_fft[..., 1] +
                       inputs_fft[..., 1] * shift_fft[..., 0])
    result_fft = torch.stack((result_fft_real, result_fft_imag), dim=-1)
    # keep only the real part, matching the docstring
    return torch.ifft(result_fft, 1)[..., 0]
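A hypothetical usage sketch (it only runs on a legacy PyTorch < 1.8, where the function-style torch.fft above exists): shifting a one-hot 2 by a one-hot 2 with vocab size 5 lands on index (2 + 2) % 5 = 4.

import torch
import torch.nn.functional as F

inputs = F.one_hot(torch.tensor([2]), num_classes=5).float()  # [[0,0,1,0,0]]
shift = F.one_hot(torch.tensor([2]), num_classes=5).float()   # shift by two
print(one_hot_add(inputs, shift).round())  # ~ [[0., 0., 0., 0., 1.]]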
Code example #4
 def test_fft2d(self):
     batch_size = 10
     n1 = 16
     n2 = 32
     input = torch.randn(batch_size, n2, n1, dtype=torch.complex64)
     for normalized in [False, True]:
         out_torch = view_as_complex(
             torch.fft(view_as_real(input),
                       signal_ndim=2,
                       normalized=normalized))
         # Just to show how fft2d is exactly 2 ffts on each dimension
         input_f = view_as_complex(
             torch.fft(view_as_real(input),
                       signal_ndim=1,
                       normalized=normalized))
         out_fft = view_as_complex(
             torch.fft(view_as_real(input_f.transpose(-1, -2)),
                       signal_ndim=1,
                       normalized=normalized)).transpose(-1, -2)
         self.assertTrue(
             torch.allclose(out_torch, out_fft, self.rtol, self.atol))
         for br_first in [True, False]:
             for flatten in [False, True]:
                 b = torch_butterfly.special.fft2d(n1,
                                                   n2,
                                                   normalized=normalized,
                                                   br_first=br_first,
                                                   flatten=flatten)
                 out = b(input)
                 self.assertTrue(
                     torch.allclose(out, out_torch, self.rtol, self.atol))
Code example #5
def cross_corr_and_conv(x, y, pad=False, real=True):
    
    if pad:
        x = zero_pad(x)
        y = zero_pad(y)
    
    x = complexify(x)
    y = complexify(y)
    
    xf = torch.fft(x, signal_ndim=2)
    yf = torch.fft(y, signal_ndim=2)
    
    convf = complex_mult(xf, yf)
    
    yf_conj = yf.clone()  # clone so the in-place conjugation cannot modify yf
    yf_conj[..., 1] *= -1
    
    corrf = complex_mult(xf, yf_conj)
    
    conv = torch.ifft(convf, signal_ndim=2)
    corr = torch.ifft(corrf, signal_ndim=2)
    
    if real:
        conv = conv[..., 0]
        corr = corr[..., 0]
    
    return corr, conv
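zero_pad, complexify, and complex_mult are project helpers that this excerpt does not show. As an assumption, a minimal sketch of what the latter two likely do in the legacy [..., 2] real/imag layout:

import torch

def complexify(x):
    # append a zero imaginary channel to a real tensor
    return torch.stack((x, torch.zeros_like(x)), dim=-1)

def complex_mult(a, b):
    # elementwise complex product on the (real, imag) layout
    real = a[..., 0] * b[..., 0] - a[..., 1] * b[..., 1]
    imag = a[..., 0] * b[..., 1] + a[..., 1] * b[..., 0]
    return torch.stack((real, imag), dim=-1)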
Code example #6
    def get_target_tensor(self,
                          input,
                          target_is_real,
                          degree,
                          mask,
                          pred_and_gt=None):

        if target_is_real:
            target_tensor = torch.ones_like(input)
            target_tensor[:] = degree

        else:
            target_tensor = torch.zeros_like(input)
            if not self.use_mse_as_energy:
                if degree != 1:
                    target_tensor[:] = degree
            else:
                pred, gt = pred_and_gt
                if self.options.dataroot == "KNEE_RAW":
                    gt = center_crop(gt, [368, 320])
                    pred = center_crop(pred, [368, 320])
                w = gt.shape[2]
                ks_gt = fft(gt, normalized=True)
                ks_input = fft(pred, normalized=True)
                ks_row_mse = F.mse_loss(ks_input, ks_gt, reduction='none').sum(
                    1, keepdim=True).sum(2, keepdim=True).squeeze() / (2 * w)
                energy = torch.exp(-ks_row_mse * self.gamma)

                target_tensor[:] = energy
            # force the observed k-space rows to always have target 1
            for i in range(mask.shape[0]):
                idx = torch.nonzero(mask[i, 0, 0, :])
                target_tensor[i, idx] = 1
        return target_tensor
Code example #7
    def forward(self, x):
        """
        x: input Tensor of shape [batch_size, input_dim, height, width].
        """
        batch_size, input_dim, height, width = x.size()
        assert input_dim == self.input_dim

        x_flat = x.permute(0, 2, 3, 1).contiguous().view(-1, self.input_dim).to(self.device)

        sketch_1 = x_flat.mm(self.sparse_sketch_matrix1)
        sketch_2 = x_flat.mm(self.sparse_sketch_matrix2)

        # Build real+imag arrays to compute FFT, with imag = 0
        sketch_1 = torch.stack((sketch_1, torch.zeros(sketch_1.shape).to(self.device)), dim=-1)
        sketch_2 = torch.stack((sketch_2, torch.zeros(sketch_2.shape).to(self.device)), dim=-1)

        fft1 = torch.fft(sketch_1, signal_ndim=1)
        fft2 = torch.fft(sketch_2, signal_ndim=1)
        del sketch_1, sketch_2

        # Element-wise complex product
        real1, imag1 = fft1.transpose(0, -1)
        real2, imag2 = fft2.transpose(0, -1)
        prod = torch.stack((real1 * real2 - imag1 * imag2,
            real1 * imag2 + imag1 * real2), dim=0).transpose(0, -1)
        del real1, real2, imag1, imag2

        cbp_flat = torch.ifft(prod, signal_ndim=1)[..., 0]

        cbp = cbp_flat.view(batch_size, height, width, self.output_dim)

        if self.sum_pool:
            cbp = cbp.sum(dim=1).sum(dim=1)

        return cbp
Code example #8
File: fft_loss.py Project: BIGWangYuDong/MCD-Net
 def forward(self, img1, img2):
     zeros = torch.zeros(img1.size()).cuda(img1.device)
     loss = nn.L1Loss(reduction='mean')(
         torch.fft(torch.stack((img1, zeros), -1), 2),
         torch.fft(torch.stack((img2, zeros), -1), 2))
     loss = self.loss_weight * loss
     return loss
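A minimal modern sketch of the same spectral L1 loss (assuming torch >= 1.8; fft_l1_loss is a hypothetical name). fft2 accepts real input directly, so the zero-channel stacking becomes unnecessary:

import torch
import torch.nn.functional as F

def fft_l1_loss(img1, img2, loss_weight=1.0):
    f1 = torch.view_as_real(torch.fft.fft2(img1))
    f2 = torch.view_as_real(torch.fft.fft2(img2))
    return loss_weight * F.l1_loss(f1, f2)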
Code example #9
File: fpm.py Project: sourav22899/MELD
    def grad(self, field_est, device='cpu'):
        self.measurements = self.measurements.to(self.device)
        self.pupils = self.pupils.to(self.device)
        self.planewaves = self.planewaves.to(self.device)
        self.P = self.P.to(self.device)
        
        multiMeas = torch.matmul(self.measurements.permute(1, 2, 0),
                                 self.C.permute(1, 0)).permute(2, 0, 1)
        multiMeas = torch.abs(multiMeas)
        
        # simulate current estimate of measurements
        y = self.generateMultiMeas(field_est,device=device)

        # compute residual
        sqrty = torch.sqrt(y + EPS)
        residual = sqrty - torch.sqrt(multiMeas + EPS)
        cost = torch.sum(torch.pow(residual, 2)).detach()
        Ajx = residual / (sqrty + 1e-10)
        Ajx_c = torch.stack((Ajx, torch.zeros_like(Ajx)), dim=len(Ajx.shape))
        
        # compute gradient
        output = mul_c(self.planewaves, field_est)
        output = torch.fft(output, 2)
        output = mul_c(self.P, output)
        output = torch.ifft(output, 2)

        g = field_est * 0.
        for meas_index in range(self.Nmeas):
            output2 = mul_c(Ajx_c[meas_index, ...], output)
            output2 = mul_c(conj(self.planewaves), output2)
            output2 = torch.fft(output2, 2)
            output2 = mul_c(self.pupils, output2)
            output2 = torch.ifft(output2, 2)
            g_tmp = torch.matmul(output2.permute(1, 2, 3, 0),
                                 self.C[meas_index, :])
            g = g + g_tmp
        return g
Code example #10
def circular_convolution_fft(keys, values, normalized=True, conj=False, cuda=False):
    '''
    For the circular convolution of x and y to be equivalent to their linear
    convolution, you must pad the vectors with zeros to length at least
    N + L - 1 before you take the DFT. After you invert the product of the
    DFTs, retain only the first N + L - 1 elements.
    '''
    assert values.dim() == keys.dim() == 2, "only 2 dims supported"
    assert values.size(-1) % 2 == keys.size(-1) % 2 == 0, "need last dim to be divisible by 2"
    batch_size, keys_feature_size = keys.size(0), keys.size(1)
    values_feature_size = values.size(1)
    required_size = keys_feature_size + values_feature_size - 1
    required_size = required_size + 1 if required_size % 2 != 0 else required_size

    # conj transpose
    keys = Complex(keys).conj().unstack() if conj else keys

    # reshape to [batch, [real, imag]]
    half = keys.size(-1) // 2
    keys = torch.cat([keys[:, 0:half].unsqueeze(2), keys[:, half:].unsqueeze(2)], -1)
    values = torch.cat([values[:, 0:half].unsqueeze(2), values[:, half:].unsqueeze(2)], -1)

    # do the fft, ifft and return num_required
    kf = torch.fft(keys, signal_ndim=1, normalized=normalized)
    vf = torch.fft(values, signal_ndim=1, normalized=normalized)
    kvif = torch.ifft(kf*vf, signal_ndim=1, normalized=normalized)#[:, 0:required_size]

    return Complex(kvif[:, :, 0], kvif[:, :, 1]).unstack().view(batch_size, -1)
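For comparison, plain circular convolution of real 1-D batches is much shorter with the modern complex API. A sketch assuming torch >= 1.8, without this function's packing of the real and imaginary halves into one feature vector:

import torch

def circular_convolution(keys, values):
    kf = torch.fft.fft(keys, dim=-1)
    vf = torch.fft.fft(values, dim=-1)
    return torch.fft.ifft(kf * vf, dim=-1).real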
Code example #11
File: utils.py Project: overjetdental/ilo
def partial_circulant_torch(inputs, filters, indices, sign_pattern):
    '''
    Multiplies `inputs` by a partial circulant matrix: applies a random sign
    pattern, circularly convolves each row with `filters` via the FFT, and
    returns only the entries selected by `indices`.
    '''
    n = np.prod(inputs.shape[1:])
    bs = inputs.shape[0]
    input_reshape = inputs.reshape(bs, n)
    input_sign = input_reshape * sign_pattern

    def to_complex(tensor):
        zeros = torch.zeros_like(tensor)
        concat = torch.cat((tensor, zeros), axis=0)
        reshape = concat.view(2, -1, n)
        return reshape.permute(1, 2, 0)

    complex_input = to_complex(input_sign)
    complex_filter = to_complex(filters)
    input_fft = torch.fft(complex_input, 1)
    filter_fft = torch.fft(complex_filter, 1)
    output_fft = torch.zeros_like(input_fft)
    # is there a simpler way to do complex multiplies in pytorch?
    output_fft[:, :, 0] = input_fft[:, :, 0] * filter_fft[:, :, 0] - input_fft[:, :, 1] * filter_fft[:, :, 1]
    output_fft[:, :, 1] = input_fft[:, :, 1] * filter_fft[:, :, 0] + input_fft[:, :, 0] * filter_fft[:, :, 1]
    output_ifft = torch.ifft(output_fft, 1)
    output_real = output_ifft[:, :, 0]
    return output_real[:, indices]
Code example #12
File: networkUtil.py Project: Cassie317/CSMRI
def kspaceFuse(x1, x2):
    lout = []
    for xin in [x1, x2]:
        if (len(xin.shape) == 4):
            if (xin.shape[1] == 1):
                emptyImag = torch.zeros_like(xin)
                xin_c = torch.cat([xin, emptyImag], 1).permute(0, 2, 3, 1)
            else:
                xin_c = xin.permute(0, 2, 3, 1)
        elif (len(xin.shape) == 5):
            if (xin.shape[1] == 1):
                emptyImag = torch.zeros_like(xin)
                xin_c = torch.cat([xin, emptyImag], 1).permute(0, 2, 3, 4, 1)
            else:
                xin_c = xin.permute(0, 2, 3, 4, 1)
        else:
            assert False, "xin shape length has to be 4(2d) or 5(3d)"
        lout.append(xin_c)
    x1c, x2c = lout

    x1f = torch.fft(x1c, 2, normalized=True)
    x2f = torch.fft(x2c, 2, normalized=True)

    xout_f = x1f + x2f

    xout = torch.ifft(xout_f, 2, normalized=True)
    if (len(x1.shape) == 4):
        xout = xout.permute(0, 3, 1, 2)
    else:
        xout = xout.permute(0, 4, 1, 2, 3)
    if (x1.shape[1] == 1):
        # single-channel input: return the magnitude of the complex result
        xout = torch.sqrt(xout[:, 0:1] * xout[:, 0:1] +
                          xout[:, 1:2] * xout[:, 1:2])

    return xout
Code example #13
    def forward(self, bottom1, bottom2):
        assert bottom1.size(1) == self.input_dim1 and \
               bottom2.size(1) == self.input_dim2

        batch_size, _, height, width = bottom1.size()

        bottom1_flat = bottom1.permute(0, 2, 3, 1).contiguous().view(
            -1, self.input_dim1)
        bottom2_flat = bottom2.permute(0, 2, 3, 1).contiguous().view(
            -1, self.input_dim2)

        sketch_1 = bottom1_flat.mm(self.sparse_sketch_matrix1)
        sketch_2 = bottom2_flat.mm(self.sparse_sketch_matrix2)
        sketch_1 = torch.stack((sketch_1, torch.zeros_like(sketch_1)), 2)
        sketch_2 = torch.stack((sketch_2, torch.zeros_like(sketch_2)), 2)

        fft1 = torch.fft(sketch_1, 1).split(1, dim=-1)
        fft1_real = fft1[0].squeeze()
        fft1_imag = fft1[1].squeeze()

        fft2 = torch.fft(sketch_2, 1).split(1, dim=-1)
        fft2_real = fft2[0].squeeze()
        fft2_imag = fft2[1].squeeze()

        fft_product_real = fft1_real.mul(fft2_real) - fft1_imag.mul(fft2_imag)
        fft_product_imag = fft1_real.mul(fft2_imag) + fft1_imag.mul(fft2_real)
        fft_product = torch.stack((fft_product_real, fft_product_imag), dim=2)

        cbp_flat = torch.ifft(fft_product, 1).split(1, dim=-1)[0].squeeze()
        cbp = cbp_flat.view(batch_size, height, width, self.output_dim)

        if self.sum_pool:
            cbp = cbp.sum(dim=1).sum(dim=1)

        return cbp
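The sparse sketch matrices are built elsewhere in the class. For context, a hedged sketch of the standard Count Sketch projection such compact bilinear pooling layers are usually built from (make_sketch_matrix is a hypothetical helper, not this project's code):

import torch

def make_sketch_matrix(input_dim, output_dim, seed=0):
    g = torch.Generator().manual_seed(seed)
    h = torch.randint(output_dim, (input_dim,), generator=g)  # random buckets
    s = torch.randint(2, (input_dim,), generator=g) * 2 - 1   # random +/-1 signs
    m = torch.zeros(input_dim, output_dim)
    m[torch.arange(input_dim), h] = s.float()
    return m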
Code example #14
def inv_filt_torch(Y, R, batch_dim=False):
    if not batch_dim:
        Y, R = Y.unsqueeze(0), R.unsqueeze(0)

    Yc = complexify(Y)
    W = torch.ifft(Yc, signal_ndim=2)

    n, s1, s2 = R.shape
    s2 *= 3

    S_auto = torch_fft_shift(W, dims=(1, 2))
    XR_cross = S_auto[:, :, :2 * R.shape[1]]

    _, t1, t2, _ = XR_cross.shape

    R_ = torch.zeros((n, t1, t2, 2), dtype=torch.float, device=W.device)
    R_[:, :s1, :R.shape[2], 0] = R
    F_R = torch.fft(R_, signal_ndim=2)
    F_R_conj = torch.clone(F_R)
    F_R_conj[..., 1] *= -1
    F_R_abs = (F_R**2).sum(-1, keepdim=True)

    X_ = torch.ifft(complex_mult(torch.fft(XR_cross, signal_ndim=2), F_R) /
                    F_R_abs,
                    signal_ndim=2)

    X = X_[:, R.shape[1]:, R.shape[2]:]

    if not batch_dim:
        X = X[0]

    return X
Code example #15
def orth_phase22(im2, loc, devid):
    """
    Given a batch of images and a tensor of local maxima, this function returns
    a tuple consisting of the phase and the orthogonal phase centered at the
    local maxima.
    """
    #im2 (1, P_c, N, N, 2)
    size = im2.size(-2)
    phase = torch.atan2(im2[...,1], im2[...,0])  # (1, P_c, N, N)
    #  unwrapping
    z = torch.arange(-size//2,size//2).unsqueeze(0).unsqueeze(0)
    z = z.repeat(tuple(loc.size()[:2])+(1,)).type(torch.cuda.FloatTensor)  # (1, P_c, N)
    z1 = z.unsqueeze(-1).repeat(1, 1, 1, size)  # (1, P_c, N, N)
    z2 = z.unsqueeze(-2).repeat(1, 1, size, 1)  # (1, P_c, N, N)
    z = z1**2 + z2**2
    del z1; del z2
    z = shift2(z, -torch.cuda.FloatTensor([size//2,size//2]).unsqueeze(0).unsqueeze(0).unsqueeze(0).repeat(1,loc.size(1),1,1), devid)
    z = z.squeeze().unsqueeze(0).unsqueeze(-1)
    cplx_phase = torch.stack((torch.cos(phase), torch.sin(phase)), dim=-1)
    cpi = complex_mul(torch.ifft(z*torch.fft(cplx_phase,2),2),conjugate(cplx_phase))[...,1]
    lin_sp_c_dx = torch.ifft(torch.fft(torch.stack((cpi, 0*cpi.clone()),dim=-1),2)/z, 2)[...,0]
    shifted_phase = shift2(phase, loc, devid)
    lin_sp_c_dx = lin_sp_c_dx - lin_sp_c_dx[:,:,0,0].unsqueeze(-1).unsqueeze(-1)
    lin_sp_c_dx = lin_sp_c_dx.unsqueeze(2) + shifted_phase[:,:,:,0,0].unsqueeze(-1).unsqueeze(-1)

    t_s_phase = torch.transpose(lin_sp_c_dx, -1, -2)  # (1, P_c, m, N, N)
    del(shifted_phase)
    if loc.size(2) == 1:
        t_s_phase = torch.flip(t_s_phase.unsqueeze(-2), [-2,-3]).squeeze().unsqueeze(0).unsqueeze(0)  # (1, P_c, m, N, N)
    else:
        t_s_phase = torch.flip(t_s_phase.unsqueeze(-2), [-2,-3]).squeeze().unsqueeze(0)  # (1, P_c, m, N, N)
    orth_ph = unshift2(t_s_phase, loc, devid)  # (1, P_c, m, N, N)
    phase_ = unshift2(lin_sp_c_dx, loc, devid)
    del(t_s_phase)
    return phase_, orth_ph
Code example #16
File: Lens.py Project: pvjosue/WaveBlocks
    def propagate_focal_to_back(self, u1):
        # Based on the function propFF out of the book "Computational Fourier
        # Optics. A MATLAB Tutorial". There you can find more information.

        wavelength = self.optic_config.PSF_config.wvl
        [M, N] = u1.shape[-3:-1]

        # source sample interval
        dx1 = self.sampling_rate

        # observation sidelength
        L2 = wavelength * self.focal_length / dx1

        # observation sample interval
        # dx2 = wavelength * self.focal_length / L1

        # filter input with aperture mask
        # mask = self.mask.unsqueeze(0).unsqueeze(0).repeat((u1.shape[0], u1.shape[1], 1, 1, 1))
        # u1 = torch.mul(u1, self.TransferFunctionIncoherent)

        # output field: the shift order depends on the parity of the input
        # size (fftshift and ifftshift differ for odd lengths)
        if M % 2 == 1:
            u2 = torch.mul(
                self.TransferFunctionIncoherent,
                ob.batch_fftshift2d(torch.fft(ob.batch_ifftshift2d(u1),
                                              2))) * dx1 * dx1
        else:
            u2 = torch.mul(
                self.TransferFunctionIncoherent,
                ob.batch_ifftshift2d(torch.fft(ob.batch_fftshift2d(u1),
                                               2))) * dx1 * dx1

        # multiply by precomputed coeff
        u2 = ob.mulComplex(u2, self.coefU1minus)
        return u2, L2
Code example #17
def synthesis(target, test_id, ind, scat, n, min_error, err_it, nit, is_complex=False, initial_type='gaussian'):
    if torch.cuda.is_available():
        target = target.cuda()
    print(is_complex)
    # set up target
    if is_complex:
        target_hat = torch.fft(target,2)
        s_target = scat(target_hat)
        if initial_type == 'gaussian':
            x0 = torch.randn(n,n,2)
        elif initial_type == 'uniform':
            x0 = torch.rand(n,n,2)
    else:
        target_hat = torch.rfft(target, 2, onesided=False)
        s_target = scat(target_hat)
        if initial_type == 'gaussian':
            x0 = torch.randn(n,n)
        elif initial_type == 'uniform':
            x0 = torch.rand(n,n)
    x0 = Dealias(x0)
    
    if torch.cuda.is_available():
        s_target = s_target.cuda()
        x0 = x0.cuda()
    x0 = Variable(x0, requires_grad=True)
    
    if is_complex:
        x0_hat = torch.fft(x0, 2)
    else:
        x0_hat = torch.rfft(x0, 2, onesided=False)
    s0 = scat(x0_hat)
    loss = nn.MSELoss()
    optimizer = optim.Adam([x0], lr=lr)
    output = loss(s_target, s0)
    l0 = output
    error = []
    count = 0
    while output / l0 > min_error:
        optimizer.zero_grad() 
        if is_complex:
            x0_hat = torch.fft(x0, 2)
        else:
            x0_hat = torch.rfft(x0, 2, onesided=False)
        s0 = scat(x0_hat)
        output = loss(s_target, s0)
        if count % err_it == 0:
            error.append(output.item())
        output.backward()
        optimizer.step()
        output = loss(s_target, s0)
        if count % nit == 0:
            print(output.data.cpu().numpy())
            np.save('./result%d/syn_error%d.npy'%(test_id, ind), np.asarray(error))
            np.save('./result%d/syn_result%d.npy'%(test_id, ind), x0.data.cpu().numpy())
#            plot_image(x0.data.cpu().numpy(), test_id, ind, count, nit)
        count += 1
    print('error reduced by: ', output / l0)
    print('error supposed reduced by: ', min_error)
    np.save('./result%d/syn_error%d.npy'%(test_id, ind), np.asarray(error))
    np.save('./result%d/syn_result%d.npy'%(test_id, ind), x0.data.cpu().numpy())
Code example #18
def cross_conv_corr(x, x_hat):
    if len(x.shape) != 3 or len(x_hat.shape) != 3:
        raise RuntimeError(
            'Expects images with dimensions (batch_idx, nx, ny)')

    x_comp = torch.zeros(x.shape[0], x.shape[1], x.shape[2], 2)
    x_comp[:, :, :, 0] = x

    x_hat_comp = torch.zeros(x.shape[0], x.shape[1], x.shape[2], 2)
    x_hat_comp[:, :, :, 0] = x_hat

    fx = torch.fft(x_comp, signal_ndim=2)
    fx_hat = torch.fft(x_hat_comp, signal_ndim=2)

    cross_conv_fft = torch.stack([
        fx[:, :, :, 0] * fx_hat[:, :, :, 0] -
        fx[:, :, :, 1] * fx_hat[:, :, :, 1],
        fx[:, :, :, 0] * fx_hat[:, :, :, 1] +
        fx[:, :, :, 1] * fx_hat[:, :, :, 0]
    ], -1)
    cross_corr_fft = torch.stack([
        fx[:, :, :, 0] * fx_hat[:, :, :, 0] +
        fx[:, :, :, 1] * fx_hat[:, :, :, 1],
        fx[:, :, :, 0] * fx_hat[:, :, :, 1] -
        fx[:, :, :, 1] * fx_hat[:, :, :, 0]
    ], -1)

    cross_conv = torch.ifft(cross_conv_fft, signal_ndim=2)
    cross_corr = torch.ifft(cross_corr_fft, signal_ndim=2)
    return cross_conv[..., 0], cross_corr[..., 0]
Code example #19
 def __call__(self, x, masks=None):
     if masks is None:
         y = self.mask.view(1, *self.mask.shape, 1) * torch.fft(
             x.permute(0, 2, 3, 1), signal_ndim=2, normalized=True)
     else:
         y = masks.view(*masks.shape, 1) * torch.fft(
             x.permute(0, 2, 3, 1), signal_ndim=2, normalized=True)
     return y.permute(0, 3, 1, 2)
Code example #20
def calCC(aT, bT):
    # cross-correlation via the FFT: ifft(conj(fft(a)) * fft(b)), real part
    N = aT.shape[0]
    cT = torch.zeros((N, 2))
    aT = torch.fft(aT, 1)
    bT = torch.fft(bT, 1)
    cT[:, 0] = aT[:, 0] * bT[:, 0] + aT[:, 1] * bT[:, 1]
    cT[:, 1] = aT[:, 0] * bT[:, 1] - aT[:, 1] * bT[:, 0]
    return torch.ifft(cT, 1)[:, 0]
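The same cross-correlation with the modern complex API (a sketch assuming torch >= 1.8 and the same [N, 2] real/imag input layout; calCC_modern is a hypothetical name):

import torch

def calCC_modern(aT, bT):
    a = torch.view_as_complex(aT.contiguous())
    b = torch.view_as_complex(bT.contiguous())
    return torch.fft.ifft(torch.fft.fft(a).conj() * torch.fft.fft(b)).real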
Code example #21
def custom_fft(iT, real=True):
    # real input: make it complex by stacking a zero imaginary channel
    if real:
        iTC = torch.stack([iT, torch.zeros_like(iT, requires_grad=False)],
                          dim=-1)
        return torch.fft(iTC, signal_ndim=2, normalized=normalizeFFT)
    # input already carries a (real, imag) last dimension: transform directly
    else:
        return torch.fft(iT, signal_ndim=2, normalized=normalizeFFT)
Code example #22
def lower_level_mixed_derivs(x, w, y, S, alpha, eps, A, reg_func):
    # Compute mixed second derivatives of the lower level objective function for use in the adjoint method
    Fw = torch.fft(w, 2, normalized=True)
    DwDy = -S.view(1, *S.shape, 1)**2 * Fw
    DwDalpha = torch.sum(w * A.T(reg_func.grad(A(x))))
    Fx = torch.fft(x, 2, normalized=True)
    DwDS = torch.sum(Fw * 2 * S.view(1, *S.shape, 1) * (Fx - y), dim=(0, 3))
    DwDeps = torch.sum(w * x)
    return DwDy, DwDS, DwDalpha, DwDeps
Code example #23
 def forward(self, h, r, t):
     h_e, r_e, t_e = self.embed(h, r, t)
     r_e = F.normalize(r_e, p=2, dim=-1)
     h_e = torch.stack((h_e, torch.zeros_like(h_e)), -1)
     t_e = torch.stack((t_e, torch.zeros_like(t_e)), -1)
     hf = torch.fft(h_e, 1)
     tf = torch.fft(t_e, 1)
     # circular correlation ifft(conj(hf) * tf), written out explicitly:
     # torch.conj is a no-op on this real-valued (real, imag) stacking, and
     # the complex product needs the cross terms
     prod = torch.stack((hf[..., 0] * tf[..., 0] + hf[..., 1] * tf[..., 1],
                         hf[..., 0] * tf[..., 1] - hf[..., 1] * tf[..., 0]),
                        -1)
     e, _ = torch.unbind(torch.ifft(prod, 1), -1)
     return -torch.sigmoid(torch.sum(r_e * e, 1))
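The score is the HolE circular correlation ifft(conj(fft(h)) * fft(t)). A minimal modern sketch (assuming torch >= 1.8 and real-valued embedding tensors; hole_score is a hypothetical name):

import torch

def hole_score(h_e, r_e, t_e):
    e = torch.fft.ifft(torch.fft.fft(h_e).conj() * torch.fft.fft(t_e)).real
    return -torch.sigmoid((r_e * e).sum(-1))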
Code example #24
File: loss.py Project: mmrebuttal/code
 def forward(self, input, recon):
     # stack a zero imaginary channel before the legacy complex 2-D FFT
     input = torch.stack((input, torch.zeros_like(input)), 4)
     recon = torch.stack((recon, torch.zeros_like(recon)), 4)
     input_fft = torch.fft(input, 2)
     recon_fft = torch.fft(recon, 2)
     fft_loss = self.mse_loss(input_fft, recon_fft)
     return fft_loss
Code example #25
    def forward(self, data_streams):
        """
        Main forward pass of the model.

        :param data_streams: DataStreams({'images',**})
        :type data_streams: ``ptp.datatypes.DataStreams``
        """
        # Unpack DataStreams.
        enc_img = data_streams[self.key_image_encodings]
        enc_q = data_streams[self.key_question_encodings]

        sketch_pm_img = self.image_sketch_projection_matrix
        sketch_pm_q = self.question_sketch_projection_matrix

        # Project both batches.
        sketch_img = enc_img.mm(sketch_pm_img)
        sketch_q = enc_q.mm(sketch_pm_q)

        # Add imaginary parts (with zeros).
        sketch_img_reim = torch.stack([
            sketch_img,
            torch.zeros(sketch_img.shape).type(self.app_state.FloatTensor)
        ], dim=2)
        sketch_q_reim = torch.stack([
            sketch_q,
            torch.zeros(sketch_q.shape).type(self.app_state.FloatTensor)
        ], dim=2)

        # Perform FFT.
        # Returns the real and the imaginary parts together as one tensor of the same shape of input.
        fft_img = torch.fft(sketch_img_reim, signal_ndim=1)
        fft_q = torch.fft(sketch_q_reim, signal_ndim=1)

        # Get real and imaginary parts.
        real1 = fft_img[:, :, 0]
        imag1 = fft_img[:, :, 1]
        real2 = fft_q[:, :, 0]
        imag2 = fft_q[:, :, 1]

        # Calculate product.
        fft_product = torch.stack(
            [real1 * real2 - imag1 * imag2, real1 * imag2 + imag1 * real2],
            dim=-1)
        #print("fft_product=",fft_product)

        # Inverse FFT.
        cbp = torch.ifft(fft_product, signal_ndim=1)[:, :, 0]
        #print("cbp=",cbp)

        # Add predictions to datadict.
        data_streams.publish({self.key_outputs: cbp})
Code example #26
    def dispersion(xpol, ypol, length):
        frequency_domain_xpol = torch.fft(xpol, 1)
        frequency_domain_ypol = torch.fft(ypol, 1)
        frequency_domain_xpol = complex_exp(frequency_domain_xpol, D * length)
        frequency_domain_ypol = complex_exp(frequency_domain_ypol, D * length)

        xpol_time_domain = torch.ifft(frequency_domain_xpol, 1)
        ypol_time_domain = torch.ifft(frequency_domain_ypol, 1)

        return xpol_time_domain, ypol_time_domain
Code example #27
    def perform(self, x, k0, mask, sensitivity):
        """
        transform to x-f space with subtraction of average temporal frame in multi-coil setting
        :param x: input image with shape [nt, nx, ny, 2]
        :param k0: undersampled k-space data [nt, ns, nx, ny, 2]
        :param mask: undersampling mask [nt, ns, nx, ny, 2]
        :param sensitivity: sensitivity maps [nt, ns, nx, ny, 2]
        :return: difference data; DC baseline
        """

        x = complex_multiply(x[..., 0].unsqueeze(1), x[..., 1].unsqueeze(1),
                             sensitivity[..., 0], sensitivity[..., 1])
        k = torch.fft(x, 2, normalized=self.normalized)
        if self.divide_by_n:
            k_avg = torch.div(torch.sum(k, 0), k.shape[0])
        else:
            k_avg = torch.div(torch.sum(k0, 0),
                              torch.clamp(torch.sum(mask, 0), min=1))

        ns, nx, ny, nc = k_avg.shape
        k_avg = k_avg.view(1, ns, nx, ny, nc)
        k_avg = k_avg.repeat(k.shape[0], 1, 1, 1, 1)

        # subtract the temporal average frame
        k_diff = torch.sub(k, k_avg)
        x_diff = torch.ifft(k_diff, 2, normalized=self.normalized)
        Sx_diff = complex_multiply(x_diff[..., 0], x_diff[..., 1],
                                   sensitivity[...,
                                               0], -sensitivity[..., 1]).sum(
                                                   dim=1)  # [nt, nx, ny, 2]

        # transform to x-f space to get the baseline
        x_avg = torch.ifft(k_avg, 2, normalized=self.normalized)
        Sx_avg = complex_multiply(x_avg[..., 0], x_avg[...,
                                                       1], sensitivity[..., 0],
                                  -sensitivity[..., 1]).sum(dim=1)

        Sx_avg = Sx_avg.permute(1, 2, 0, 3)  # [nx, ny, nt, 2]
        x_f_avg = fftshift_pytorch(torch.fft(ifftshift_pytorch(Sx_avg,
                                                               axes=[-2]),
                                             1,
                                             normalized=self.normalized),
                                   axes=[-2])
        x_f_avg = x_f_avg.permute(2, 0, 1, 3)

        # difference data
        Sx_diff = Sx_diff.permute(1, 2, 0, 3)  # [nx, ny, nt, 2]
        x_f_diff = fftshift_pytorch(torch.fft(ifftshift_pytorch(Sx_diff,
                                                                axes=[-2]),
                                              1,
                                              normalized=self.normalized),
                                    axes=[-2])
        x_f_diff = x_f_diff.permute(2, 0, 1, 3)

        return x_f_diff, x_f_avg
Code example #28
def bench(batch_size: int, d: int, hw: int, num_iter: int):
    if not torch.cuda.is_available():
        print("GPU is not available")
        return

    device = torch.device('cuda:0')

    torch.set_grad_enabled(False)

    # BxDxHxWx2
    inp = torch.randn(batch_size, d, hw, hw, 2, device=device)

    # warmup
    outp = torch.fft(inp, 3)
    inp_ = torch.ifft(outp, 3)

    # fft
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    with contexttimer.Timer() as t:
        for it in range(num_iter):
            outp = torch.fft(inp, 3)
    end.record()
    torch.cuda.synchronize()
    elapsed = start.elapsed_time(end) / 1e3
    tps = num_iter / elapsed
    fft_time_consume = elapsed

    del outp, inp

    outp = torch.randn(batch_size, d, hw, hw, 2, device=device)

    # ifft
    start.record()
    with contexttimer.Timer() as t:
        for it in range(num_iter):
            inp_ = torch.ifft(outp, 3)
    end.record()
    torch.cuda.synchronize()
    elapsed = start.elapsed_time(end) / 1e3
    itps = num_iter / elapsed
    ifft_time_consume = elapsed

    print(
        json.dumps({
            "TPS": tps,
            "fft_elapsed": fft_time_consume,
            "ITPS": itps,
            "ifft_elapsed": ifft_time_consume,
            "n": num_iter,
            "batch_size": batch_size,
            "D_size": d,
            "HW_size": hw,
        }))
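A hypothetical invocation (the argument values are only an example): benchmark 100 iterations of a 3-D FFT and inverse FFT over batches of 10 volumes of size 64 x 128 x 128:

bench(batch_size=10, d=64, hw=128, num_iter=100)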
Code example #29
    def grad(self, x):
        ys = torch.stack((self.y, torch.zeros_like(self.y)), 2)
        Fy = torch.fft(ys, 2)
        AHy = mul_c(conj(self.fpsf), Fy)

        xs = torch.stack((x, torch.zeros_like(x)), 2)
        Fx = torch.fft(xs, 2)
        AHA = mul_c(conj(self.fpsf), self.fpsf)
        AHAx = mul_c(AHA, Fx)

        return torch.ifft(AHAx - AHy, 2)[..., 0]
Code example #30
    def test_fft_function_clobbered(self, device):
        t = torch.randn((100, 2), device=device)
        eager_result = fft_fn(t, 1)

        def method_fn(t):
            return t.fft(1)
        scripted_method_fn = torch.jit.script(method_fn)

        self.assertEqual(scripted_method_fn(t), eager_result)

        with self.assertRaisesRegex(TypeError, "'module' object is not callable"):
            torch.fft(t, 1)
Code example #31
    def test_circulant(self):
        batch_size = 10
        n = 13
        for complex in [False, True]:
            dtype = torch.float32 if not complex else torch.complex64
            col = torch.randn(n, dtype=dtype)
            C = la.circulant(col.numpy())
            input = torch.randn(batch_size, n, dtype=dtype)
            out_torch = torch.tensor(input.detach().numpy() @ C.T)
            out_np = torch.tensor(np.fft.ifft(
                np.fft.fft(input.numpy()) * np.fft.fft(col.numpy())),
                                  dtype=dtype)
            self.assertTrue(
                torch.allclose(out_torch, out_np, self.rtol, self.atol))
            # Just to show how to implement circulant multiply with FFT
            if complex:
                input_f = view_as_complex(
                    torch.fft(view_as_real(input), signal_ndim=1))
                col_f = view_as_complex(
                    torch.fft(view_as_real(col), signal_ndim=1))
                prod_f = complex_mul(input_f, col_f)
                out_fft = view_as_complex(
                    torch.ifft(view_as_real(prod_f), signal_ndim=1))
                self.assertTrue(
                    torch.allclose(out_torch, out_fft, self.rtol, self.atol))
            for separate_diagonal in [True, False]:
                b = torch_butterfly.special.circulant(
                    col, transposed=False, separate_diagonal=separate_diagonal)
                out = b(input)
                self.assertTrue(
                    torch.allclose(out, out_torch, self.rtol, self.atol))

            row = torch.randn(n, dtype=dtype)
            C = la.circulant(row.numpy()).T
            input = torch.randn(batch_size, n, dtype=dtype)
            out_torch = torch.tensor(input.detach().numpy() @ C.T)
            # row is the reverse of col, except the 0-th element stays put
            # This corresponds to the same reversal in the frequency domain.
            # https://en.wikipedia.org/wiki/Discrete_Fourier_transform#Time_and_frequency_reversal
            row_f = np.fft.fft(row.numpy())
            row_f_reversed = np.hstack((row_f[:1], row_f[1:][::-1]))
            out_np = torch.tensor(np.fft.ifft(
                np.fft.fft(input.numpy()) * row_f_reversed),
                                  dtype=dtype)
            self.assertTrue(
                torch.allclose(out_torch, out_np, self.rtol, self.atol))
            for separate_diagonal in [True, False]:
                b = torch_butterfly.special.circulant(
                    row, transposed=True, separate_diagonal=separate_diagonal)
                out = b(input)
                self.assertTrue(
                    torch.allclose(out, out_torch, self.rtol, self.atol))