Code example #1
def circular_convolution_fft(keys, values, normalized=True, conj=False, cuda=False):
    '''
    For the circular convolution of x and y to be equivalent to their
    linear convolution, you must pad the vectors with zeros to length
    at least N + L - 1 before you take the DFT. After you invert the
    product of the DFTs, retain only the first N + L - 1 elements.
    '''
    assert values.dim() == keys.dim() == 2, "only 2 dims supported"
    assert values.size(-1) % 2 == keys.size(-1) % 2 == 0, "need last dim to be divisible by 2"
    batch_size, keys_feature_size = keys.size(0), keys.size(1)
    values_feature_size = values.size(1)
    required_size = keys_feature_size + values_feature_size - 1
    required_size = required_size + 1 if required_size % 2 != 0 else required_size

    # complex conjugate
    keys = Complex(keys).conj().unstack() if conj else keys

    # reshape to [batch, [real, imag]]
    half = keys.size(-1) // 2
    keys = torch.cat([keys[:, 0:half].unsqueeze(2), keys[:, half:].unsqueeze(2)], -1)
    values = torch.cat([values[:, 0:half].unsqueeze(2), values[:, half:].unsqueeze(2)], -1)

    # do the fft, ifft and return num_required
    kf = torch.fft(keys, signal_ndim=1, normalized=normalized)
    vf = torch.fft(values, signal_ndim=1, normalized=normalized)
    kvif = torch.ifft(kf * vf, signal_ndim=1, normalized=normalized)  # optional truncation: [:, 0:required_size]

    return Complex(kvif[:, :, 0], kvif[:, :, 1]).unstack().view(batch_size, -1)
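Note: every example on this page uses the pre-1.8 torch.fft / torch.ifft functions with a signal_ndim argument; these were removed in PyTorch 1.8 in favor of the torch.fft module operating on complex tensors. As a sketch of the same circular convolution under the modern API (the function name and the plain real [batch, n] input layout are our assumptions, not the original project's):

import torch

def circular_convolution_fft_modern(keys, values, conj=False):
    # real-to-complex FFT along the last dim; conjugating the keys turns
    # the circular convolution into a circular correlation
    kf = torch.fft.rfft(keys, dim=-1)
    kf = kf.conj() if conj else kf
    vf = torch.fft.rfft(values, dim=-1)
    return torch.fft.irfft(kf * vf, n=keys.size(-1), dim=-1)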
Code example #2
File: so3_fft.py Project: janithPet/s2cnn-1
def so3_ifft(x, for_grad=False, b_out=None):
    '''
    :param x: [l * m * n, ..., complex]
    '''
    assert x.size(-1) == 2
    nspec = x.size(0)
    b_in = round((3 / 4 * nspec)**(1 / 3))
    assert nspec == b_in * (4 * b_in**2 - 1) // 3
    if b_out is None:
        b_out = b_in
    batch_size = x.size()[1:-1]

    x = x.view(nspec, -1, 2)  # [l * m * n, batch, complex] (nspec, nbatch, 2)
    '''
    :param x: [l * m * n, batch, complex] (b_in (4 b_in**2 - 1) // 3, nbatch, 2)
    :return: [batch, beta, alpha, gamma, complex] (nbatch, 2 b_out, 2 b_out, 2 b_out, 2)
    '''
    nbatch = x.size(1)

    wigner = _setup_wigner(
        b_out, nl=b_in, weighted=for_grad,
        device=x.device)  # [beta, l * m * n] (2 * b_out, nspec)

    output = x.new_empty((nbatch, 2 * b_out, 2 * b_out, 2 * b_out, 2))
    if x.is_cuda and x.dtype == torch.float32:
        cuda_kernel = _setup_so3ifft_cuda_kernel(b_in=b_in,
                                                 b_out=b_out,
                                                 nbatch=nbatch,
                                                 real_output=False,
                                                 device=x.device.index)
        cuda_kernel(x, wigner, output)  # [batch, beta, m, n, complex]
    else:
        output.fill_(0)
        for l in range(min(b_in, b_out)):
            s = slice(l * (4 * l**2 - 1) // 3,
                      l * (4 * l**2 - 1) // 3 + (2 * l + 1)**2)
            out = torch.einsum("mnzc,bmn->zbmnc",
                               (x[s].view(2 * l + 1, 2 * l + 1, -1, 2),
                                wigner[:, s].view(-1, 2 * l + 1, 2 * l + 1)))
            l1 = min(l, b_out - 1)  # if b_out < b_in
            output[:, :, :l1 + 1, :l1 + 1] += out[:, :, l:l + l1 + 1,
                                                  l:l + l1 + 1]
            if l > 0:
                output[:, :, -l1:, :l1 + 1] += out[:, :, l - l1:l,
                                                   l:l + l1 + 1]
                output[:, :, :l1 + 1, -l1:] += out[:, :, l:l + l1 + 1,
                                                   l - l1:l]
                output[:, :, -l1:, -l1:] += out[:, :, l - l1:l, l - l1:l]

    output = torch.ifft(output, 2) * output.size(-2)**2  # [batch, beta, alpha, gamma, complex]
    output = output.view(*batch_size, 2 * b_out, 2 * b_out, 2 * b_out, 2)
    return output
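On PyTorch >= 1.8 the final unnormalized inverse transform over the alpha and gamma axes would use complex tensors; a drop-in sketch for the last lines above, assuming output keeps its [..., 2] real/imag layout:

out_c = torch.view_as_complex(output.contiguous())
out_c = torch.fft.ifft2(out_c, dim=(-2, -1)) * out_c.size(-1)**2  # same scaling as above
output = torch.view_as_real(out_c)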
Code example #3
    def backward(ctx, grad_output):

        # extract saved tensors for gradient update
        device, fft_res, phasemask_term, phase_emitter = ctx.device, ctx.fft_res, ctx.phasemask_term, ctx.phase_emitter
        normfactor, Nbatch, Nemitters, H, W = ctx.normfactor, ctx.Nbatch, ctx.Nemitters, ctx.H, ctx.W

        # gradient w.r.t the single-emitter images
        grad_input = grad_output.data

        # depth-wise normalization factor
        for i in range(Nbatch):
            for j in range(Nemitters):
                grad_input[i, j, :, :] = grad_input[i, j, :, :] * normfactor[i, j]

        # gradient of abs squared
        grad_abs_square = torch.zeros((Nbatch, Nemitters, H, W, 2)).to(device)
        grad_abs_square[:, :, :, :, 0] = 2 * grad_input * fft_res[:, :, :, :, 0]
        grad_abs_square[:, :, :, :, 1] = 2 * grad_input * fft_res[:, :, :, :, 1]

        # calculate the centered inverse fourier transform on the H, W dims
        grad_fft = batch_fftshift2d(
            torch.ifft(batch_ifftshift2d(grad_abs_square), 2, True))

        # gradient w.r.t phase mask phase_term
        grad_phasemask_term = torch.zeros((Nbatch, Nemitters, H, W, 2)).to(device)
        grad_phasemask_term[:, :, :, :, 0] = (grad_fft[:, :, :, :, 0] * phase_emitter[:, :, :, :, 0]
                                              + grad_fft[:, :, :, :, 1] * phase_emitter[:, :, :, :, 1])
        grad_phasemask_term[:, :, :, :, 1] = (-grad_fft[:, :, :, :, 0] * phase_emitter[:, :, :, :, 1]
                                              + grad_fft[:, :, :, :, 1] * phase_emitter[:, :, :, :, 0])

        # gradient w.r.t the phasemask 4D
        grad_phasemask4D = (-grad_phasemask_term[:, :, :, :, 0] * phasemask_term[:, :, :, :, 1]
                            + grad_phasemask_term[:, :, :, :, 1] * phasemask_term[:, :, :, :, 0])

        # sum to get the final gradient
        grad_phasemask = grad_phasemask4D.sum(0).sum(0)

        return grad_phasemask, None, None, None
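The two hand-rolled blocks above are complex products: grad_phasemask_term is grad_fft * conj(phase_emitter), and grad_phasemask4D is the imaginary part after a further multiply by conj(phasemask_term). With complex dtypes (PyTorch >= 1.8) the same gradient could be sketched as:

g = torch.view_as_complex(grad_fft.contiguous())
p = torch.view_as_complex(phase_emitter.contiguous())
m = torch.view_as_complex(phasemask_term.contiguous())
grad_phasemask4D = (g * p.conj() * m.conj()).imag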
Code example #4
def forward_operator_from_real(x, mask):
    """ Forward operator for real images
    :param x: real input image
    :param mask: mask of radial lines
    :return: x_new
    """
    x_new = torch.rfft(x, signal_ndim=3, onesided=False) / x.shape[1]
    x_new[:, :, :, 0] = torch.mul(torch.from_numpy(mask).float().cuda(), x_new[:, :, :, 0])
    x_new[:, :, :, 1] = torch.mul(torch.from_numpy(mask).float().cuda(), x_new[:, :, :, 1])
    x_new = torch.ifft(x_new, signal_ndim=3) * x.shape[1]

    return x_new
Code example #5
def forward_operator(x, mask):
    """ Forward operator for complex images
    :param x: complex input image
    :param mask: mask of radial lines
    :return: x_new
    """
    x_new = torch.fft(x, signal_ndim=3) / x.shape[1]
    x_new[:, :, :, 0] = torch.mul(torch.from_numpy(mask).float().cuda(), x_new[:, :, :, 0])
    x_new[:, :, :, 1] = torch.mul(torch.from_numpy(mask).float().cuda(), x_new[:, :, :, 1])
    x_new = torch.ifft(x_new, signal_ndim=3) * x.shape[1]

    return x_new
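Examples #4 and #5 share the same masked round trip through k-space. A minimal modern-API sketch (assuming x is a real or complex tensor whose last three dims are the transformed axes, and mask broadcasts over them):

import torch

def forward_operator_modern(x, mask):
    m = torch.from_numpy(mask).to(x.device)
    k = torch.fft.fftn(x, dim=(-3, -2, -1)) / x.shape[1]
    return torch.fft.ifftn(k * m, dim=(-3, -2, -1)) * x.shape[1]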
Code example #6
File: utils.py Project: theocohen/neural-holography
def ifft2(tensor_re, tensor_im, shift=False):
    """Applies a 2D ifft to the complex tensor represented by tensor_re and _im"""
    tensor_out = torch.stack((tensor_re, tensor_im), 4)

    if shift:
        tensor_out = ifftshift(tensor_out)
    (tensor_out_re, tensor_out_im) = torch.ifft(tensor_out, 2, True).split(1, 4)

    tensor_out_re = tensor_out_re.squeeze(4)
    tensor_out_im = tensor_out_im.squeeze(4)

    return tensor_out_re, tensor_out_im
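The positional True above is the old normalized flag, i.e. an orthonormal transform. A complex-dtype sketch of the same helper (PyTorch >= 1.8):

import torch

def ifft2_modern(tensor_re, tensor_im, shift=False):
    t = torch.complex(tensor_re, tensor_im)
    if shift:
        t = torch.fft.ifftshift(t, dim=(-2, -1))
    t = torch.fft.ifft2(t, norm="ortho")
    return t.real, t.imag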
Code example #7
File: wrappers.py Project: zhenchen16/adorym
def ifft(var_real, var_imag, axis=-1, backend='autograd', normalize=False):
    if backend == 'autograd':
        var = var_real + 1j * var_imag
        norm = None if not normalize else 'ortho'
        var = anp.fft.ifft(var, axis=axis, norm=norm)
        return anp.real(var), anp.imag(var)
    elif backend == 'pytorch':
        var = tc.stack([var_real, var_imag], dim=-1)
        var = tc.ifft(var, signal_ndim=1, normalized=normalize)
        var_real, var_imag = tc.split(var, 1, dim=-1)
        slicer = [slice(None)] * (len(var_real.shape) - 1) + [0]
        return var_real[tuple(slicer)], var_imag[tuple(slicer)]
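Caveat: the old API always transforms the dim just before the stacked real/imag dim, so the pytorch branch above silently assumes axis == -1. A self-contained sketch of that branch with complex dtypes, honoring axis (PyTorch >= 1.8):

import torch

def ifft_pytorch_branch(var_real, var_imag, axis=-1, normalize=False):
    var = torch.complex(var_real, var_imag)
    var = torch.fft.ifft(var, dim=axis, norm='ortho' if normalize else 'backward')
    return var.real, var.imag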
Code example #8
File: fpm.py Project: diamond2nv/LearnedDesignFPM
 def cropMeasurements(self, measurements):
     # cropping
     measurements_cropped = torch.zeros(measurements.shape[0],
                                        self.Np_meas[0], self.Np_meas[1], 2)
     for img_idx in range(measurements.shape[0]):
         fmeas = utility.fftshift2(torch.fft(measurements[img_idx, ...], 2))
         tmp = fmeas[self.crops[0]:self.crops[1],
                     self.crops[2]:self.crops[3], :]
         measurements_cropped[img_idx,
                              ...] = (1 / self.scaling) * torch.ifft(
                                  utility.ifftshift2(tmp), 2)
     return measurements_cropped
Code example #9
def fft(x):
    """
    Layer that performs a fast Fourier-Transformation.
    """
    img_size = x.size(1) // 2
    # sort the incoming tensor in real and imaginary part
    arr_real = x[:, 0:img_size].reshape(-1, int(sqrt(img_size)), int(sqrt(img_size)))
    arr_imag = x[:, img_size:].reshape(-1, int(sqrt(img_size)), int(sqrt(img_size)))
    arr = torch.stack((arr_real, arr_imag), dim=-1)
    # perform fourier transformation and switch imaginary and real part
    arr_fft = torch.ifft(arr, 2).permute(0, 3, 2, 1).transpose(2, 3)
    return arr_fft
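A complex-dtype sketch producing the same [batch, 2, h, w] output (PyTorch >= 1.8):

arr_c = torch.complex(arr_real, arr_imag)
arr_ifft = torch.fft.ifft2(arr_c)
arr_fft = torch.stack((arr_ifft.real, arr_ifft.imag), dim=1)  # [batch, 2, sqrt(img_size), sqrt(img_size)]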
Code example #10
def s2_ifft(x, for_grad=False, b_out=None):
    '''
    :param x: [l * m, ..., complex]
    '''
    assert x.size(-1) == 2
    nspec = x.size(0)
    b_in = round(nspec**0.5)
    assert nspec == b_in**2
    if b_out is None:
        b_out = b_in
    assert b_out >= b_in
    batch_size = x.size()[1:-1]

    x = x.view(nspec, -1, 2)  # [l * m, batch, complex] (nspec, nbatch, 2)
    '''
    :param x: [l * m, batch, complex] (b_in**2, nbatch, 2)
    :return: [batch, beta, alpha, complex] (nbatch, 2 b_out, 2 * b_out, 2)
    '''
    nbatch = x.size(1)

    wigner = _setup_wigner(b_out, nl=b_in, weighted=for_grad, device=x.device)
    wigner = wigner.view(2 * b_out, -1)  # [beta, l * m] (2 * b_out, nspec)

    if x.is_cuda and x.dtype == torch.float32:
        import s2cnn.utils.cuda as cuda_utils
        cuda_kernel = _setup_s2ifft_cuda_kernel(b=b_out,
                                                nl=b_in,
                                                nbatch=nbatch,
                                                device=x.device.index)
        stream = cuda_utils.Stream(ptr=torch.cuda.current_stream().cuda_stream)
        output = x.new_empty((nbatch, 2 * b_out, 2 * b_out, 2))
        cuda_kernel(block=(1024, 1, 1),
                    grid=(cuda_utils.get_blocks(nbatch * (2 * b_out)**2,
                                                1024), 1, 1),
                    args=[x.data_ptr(),
                          wigner.data_ptr(),
                          output.data_ptr()],
                    stream=stream)
        # [batch, beta, m, complex] (nbatch, 2 * b_out, 2 * b_out, 2)
    else:
        output = x.new_zeros((nbatch, 2 * b_out, 2 * b_out, 2))
        for l in range(b_in):
            s = slice(l**2, l**2 + 2 * l + 1)
            out = torch.einsum("mzc,bm->zbmc", (x[s], wigner[:, s]))
            output[:, :, :l + 1] += out[:, :, -l - 1:]
            if l > 0:
                output[:, :, -l:] += out[:, :, :l]

    output = torch.ifft(output, 1) * output.size(-2)  # [batch, beta, alpha, complex]
    output = output.view(*batch_size, 2 * b_out, 2 * b_out, 2)
    return output
Code example #11
    def test_circulant(self):
        batch_size = 10
        n = 13
        for complex in [False, True]:
            dtype = torch.float32 if not complex else torch.complex64
            col = torch.randn(n, dtype=dtype)
            C = la.circulant(col.numpy())
            input = torch.randn(batch_size, n, dtype=dtype)
            out_torch = torch.tensor(input.detach().numpy() @ C.T)
            out_np = torch.tensor(np.fft.ifft(
                np.fft.fft(input.numpy()) * np.fft.fft(col.numpy())),
                                  dtype=dtype)
            self.assertTrue(
                torch.allclose(out_torch, out_np, self.rtol, self.atol))
            # Just to show how to implement circulant multiply with FFT
            if complex:
                input_f = view_as_complex(
                    torch.fft(view_as_real(input), signal_ndim=1))
                col_f = view_as_complex(
                    torch.fft(view_as_real(col), signal_ndim=1))
                prod_f = complex_mul(input_f, col_f)
                out_fft = view_as_complex(
                    torch.ifft(view_as_real(prod_f), signal_ndim=1))
                self.assertTrue(
                    torch.allclose(out_torch, out_fft, self.rtol, self.atol))
            for separate_diagonal in [True, False]:
                b = torch_butterfly.special.circulant(
                    col, transposed=False, separate_diagonal=separate_diagonal)
                out = b(input)
                self.assertTrue(
                    torch.allclose(out, out_torch, self.rtol, self.atol))

            row = torch.randn(n, dtype=dtype)
            C = la.circulant(row.numpy()).T
            input = torch.randn(batch_size, n, dtype=dtype)
            out_torch = torch.tensor(input.detach().numpy() @ C.T)
            # row is the reverse of col, except the 0-th element stays put
            # This corresponds to the same reversal in the frequency domain.
            # https://en.wikipedia.org/wiki/Discrete_Fourier_transform#Time_and_frequency_reversal
            row_f = np.fft.fft(row.numpy())
            row_f_reversed = np.hstack((row_f[:1], row_f[1:][::-1]))
            out_np = torch.tensor(np.fft.ifft(
                np.fft.fft(input.numpy()) * row_f_reversed),
                                  dtype=dtype)
            self.assertTrue(
                torch.allclose(out_torch, out_np, self.rtol, self.atol))
            for separate_diagonal in [True, False]:
                b = torch_butterfly.special.circulant(
                    row, transposed=True, separate_diagonal=separate_diagonal)
                out = b(input)
                self.assertTrue(
                    torch.allclose(out, out_torch, self.rtol, self.atol))
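With the modern API the circulant check in the complex branch collapses to a single line, since complex64 tensors feed torch.fft.fft directly (PyTorch >= 1.8); a sketch:

out_fft = torch.fft.ifft(torch.fft.fft(input, dim=-1) * torch.fft.fft(col), dim=-1)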
Code example #12
 def forward(self, x, mask):
     x_dim_0 = x.shape[0]
     x_dim_1 = x.shape[1]
     x_dim_2 = x.shape[2]
     x_dim_3 = x.shape[3]
     x = x.view(-1, x_dim_2, x_dim_3, 1)
     y = torch.zeros_like(x)
     z = torch.cat([x, y], 3)
     fftz = torch.fft(z, 2)
     z_hat = torch.ifft(fftz * mask, 2)
     x = z_hat[:, :, :, 0:1]
     x = x.view(x_dim_0, x_dim_1, x_dim_2, x_dim_3)
     return x
Code example #13
def ifft2(data: torch.Tensor,
          dim: Tuple[str, ...] = ('height', 'width'),
          centered: bool = True) -> torch.Tensor:
    """
    Apply centered two-dimensional Inverse Fast Fourier Transform

    Parameters
    ----------
    data : torch.Tensor
        Complex-valued input tensor.
    dim : tuple, list or int
        Dimensions over which to compute.
    centered : bool
        Whether to apply a centered ifft (center of kspace is in the center versus in the corners).
        For FastMRI dataset this has to be true and for the Calgary-Campinas dataset false.

    Returns
    -------
    torch.Tensor: the ifft of the input.
    """
    assert_complex(data)

    if centered:
        data = ifftshift(data, dim=dim)
    names = data.names
    # TODO: Fix when ifft supports named tensors
    # Verify whether half precision and if ifft is possible in this shape. Else do a typecast.
    if verify_fft_dtype_possible(data, dim):
        data = torch.ifft(data.rename(None), 2, normalized=True)
    else:
        data = torch.ifft(data.rename(None).float(), 2,
                          normalized=True).type(data.type())

    if any(names):
        data = data.refine_names(*names)

    if centered:
        data = fftshift(data, dim=dim)
    return data
Code example #14
 def forward(self, x_under, mask):
     x_under_per = x_under.permute(0, 2, 3, 1)
     x_zf_per = torch.ifft(x_under_per, 2)
     x_zf = x_zf_per.permute(0, 3, 1, 2)
     x_rec_dc = x_zf
     recimg = list()
     recimg.append(sigtoimage(x_zf))
     for i, l in enumerate(self.layer):
         x_rec = self.layer[i](x_rec_dc)
         x_res = x_rec_dc + x_rec
         x_rec_dc = self.dc(mask, x_res, x_under)
         recimg.append(sigtoimage(x_rec_dc))
     return recimg
Code example #15
def ifft2(data):
    """
    Apply centered 2-dimensional Inverse Fast Fourier Transform.
    Args:
        data (torch.Tensor): Complex valued input data containing at least 3 dimensions: dimensions
            -3 & -2 are spatial dimensions and dimension -1 has size 2. All other dimensions are
            assumed to be batch dimensions.
    Returns:
        torch.Tensor: The IFFT of the input.
    """
    assert data.size(-1) == 2
    data = torch.ifft(data, 2, normalized=True)
    return data
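A sketch of the same fastMRI-style helper on PyTorch >= 1.8, keeping the [..., H, W, 2] layout at the boundary:

import torch

def ifft2_modern(data):
    c = torch.view_as_complex(data.contiguous())
    c = torch.fft.ifft2(c, norm="ortho")  # normalized=True above maps to norm="ortho"
    return torch.view_as_real(c)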
Code example #16
def ifft(x):
    '''x.size()=[1, 2, h, w]'''
    # change to [n, 1, h, w, 2]
    x = complex_split(x)
    # move the spectrum center to the top-left corner
    x = cen_cor(x)
    # ifft
    x = torch.ifft(x, 2, normalized=True)
    # move the top-left corner back to the center
    x = cor_cen(x)
    # merge back to [1, 2, h, w]
    x = complex_merge(x)
    return x.to(device)
Code example #17
File: imageUtil.py Project: tinyRattar/CSMRI_0325
def imgFromSubF_pytorch(subF, returnComplex=False):
    subIm = torch.ifft(subF, 2, normalized=True)
    if (len(subIm.shape) == 4):
        subIm = subIm.permute(0, 3, 1, 2)
    else:
        subIm = subIm.permute(0, 4, 1, 2, 3)

    if (returnComplex):
        return subIm
    else:
        subIm = torch.sqrt(subIm[:, 0:1] * subIm[:, 0:1] +
                           subIm[:, 1:2] * subIm[:, 1:2])
        return subIm
Code example #18
File: fpm.py Project: sanjana-266/LearnedDesignFPM
    def generateMultiMeas(self, field, device="cpu"):
        self.pupils = self.pupils.to(self.device)
        self.planewaves = self.planewaves.to(self.device)
        self.P = self.P.to(self.device)

        output = mul_c(self.planewaves, field)
        output = torch.fft(output, 2)
        output = mul_c(self.P, output)
        output = torch.ifft(output, 2)
        output = abs2_c(output)
        multiMeas = torch.matmul(output.permute(1, 2, 0),
                                 self.C.permute(1, 0)).permute(2, 0, 1)
        return multiMeas
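With complex tensors the plane-wave/pupil forward model collapses to ordinary arithmetic; a sketch assuming planewaves, P and field are complex tensors (the mul_c / abs2_c helpers become * and .abs()):

out = planewaves * field
out = torch.fft.ifft2(torch.fft.fft2(out) * P)
intensity = out.abs()**2  # replaces abs2_c(output)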
Code example #19
    def forward(self, x_hat):
        # x_hat: n * n * 2
        s = 1 / (2 * pi)**2 * torch.mean(
            torch.mean((x_hat[:, :, 0]**2 + x_hat[:, :, 1]**2) *
                       (self.phi_hat_real**2 + self.phi_hat_imag**2), 0), 0)
        s = s.unsqueeze(0)
        J = self.psi_hat_real.size()[1]
        for i in range(J):
            # the first layer only uses Gabor wavelets, which have K angles
            temp_real = (x_hat[:, :, 0] * self.psi_hat_real[:self.K, i, :, :]
                         - x_hat[:, :, 1] * self.psi_hat_imag[:self.K, i, :, :])
            temp_imag = (x_hat[:, :, 0] * self.psi_hat_imag[:self.K, i, :, :]
                         + x_hat[:, :, 1] * self.psi_hat_real[:self.K, i, :, :])
            temp = torch.ifft(
                torch.cat((temp_real.unsqueeze(3), temp_imag.unsqueeze(3)), 3),
                2)  # K * n * n * 2

            temp2 = torch.rfft(torch.sqrt(temp[:, :, :, 0]**2 +
                                          temp[:, :, :, 1]**2 + 1e-8),
                               2,
                               onesided=False)  # K * n * n * 2

            a = 1 / (2 * pi)**2 * torch.mean(
                torch.mean(
                    (temp2[:, :, :, 0]**2 + temp2[:, :, :, 1]**2) *
                    (self.phi_hat_real**2 + self.phi_hat_imag**2), 2), 1)
            s = torch.cat((s, a), 0)
            if i < J - 1:
                temp3 = (temp2[:, :, :, 0]**2 +
                         temp2[:, :, :, 1]**2).unsqueeze(1).unsqueeze(2)
                if self.second_all:
                    temp4 = (self.psi_hat_real[:, :, :, :]**2 +
                             self.psi_hat_imag[:, :, :, :]**2).unsqueeze(0)
                else:
                    temp4 = (
                        self.psi_hat_real[:, (i + 1):J, :, :]**2 +
                        self.psi_hat_imag[:, (i + 1):J, :, :]**2).unsqueeze(0)
                b = 1 / (2 * pi)**2 * torch.mean(torch.mean(temp3 * temp4, 4),
                                                 3)
                s = torch.cat((s, b.flatten()), 0)

        return s
Code example #20
File: fpm.py Project: sanjana-266/LearnedDesignFPM
    def grad(self, field_est, device='cpu'):
        self.measurements = self.measurements.to(self.device)
        self.pupils = self.pupils.to(self.device)
        self.planewaves = self.planewaves.to(self.device)
        self.P = self.P.to(self.device)

        multiMeas = torch.matmul(self.measurements.permute(1, 2, 0),
                                 self.C.permute(1, 0)).permute(2, 0, 1)
        multiMeas = torch.abs(multiMeas)

        # simulate current estimate of measurements
        y = self.generateMultiMeas(field_est, device=device)

        # compute residual
        sqrty = torch.sqrt(y + EPS)
        residual = sqrty - torch.sqrt(multiMeas + EPS)
        cost = torch.sum(torch.pow(residual, 2)).detach()
        Ajx = residual / (sqrty + 1e-10)
        Ajx_c = torch.stack((Ajx, torch.zeros_like(Ajx)), dim=len(Ajx.shape))

        # compute gradient
        output = mul_c(self.planewaves, field_est)
        output = torch.fft(output, 2)
        output = mul_c(self.P, output)
        output = torch.ifft(output, 2)

        g = field_est * 0.
        for meas_index in range(self.Nmeas):
            output2 = mul_c(Ajx_c[meas_index, ...], output)
            output2 = mul_c(conj(self.planewaves), output2)
            output2 = torch.fft(output2, 2)
            output2 = mul_c(self.pupils, output2)
            output2 = torch.ifft(output2, 2)
            g_tmp = torch.matmul(output2.permute(1, 2, 3, 0),
                                 self.C[meas_index, :])
            g = g + g_tmp
#         return -1 * self.alpha * g, cost
        return g
Code example #21
def loss_QSMnet(outputs, QSMs, Masks, D):
    # l1 loss
    loss = lossL1()
    outputs = outputs[:, 0:1, ...]
    device = outputs.get_device()

    outputs_cplx = torch.zeros(*(outputs.size() + (2, ))).to(device)
    outputs_cplx[..., 0] = outputs

    QSMs_cplx = torch.zeros(*(QSMs.size() + (2, ))).to(device)
    QSMs_cplx[..., 0] = QSMs

    D = np.repeat(D[np.newaxis, np.newaxis, ..., np.newaxis],
                  outputs.size()[0],
                  axis=0)
    D_cplx = np.concatenate((D, np.zeros(D.shape)), axis=-1)
    D_cplx = torch.tensor(D_cplx, device=device).float()

    RDFs_outputs = torch.ifft(cplx_mlpy(torch.fft(outputs_cplx, 3), D_cplx), 3)
    RDFs_QSMs = torch.ifft(cplx_mlpy(torch.fft(QSMs_cplx, 3), D_cplx), 3)

    errl1 = loss(outputs * Masks, QSMs * Masks)
    errModel = loss(RDFs_outputs[..., 0] * Masks, RDFs_QSMs[..., 0] * Masks)
    errl1_grad = (loss(abs(dxp(outputs)) * Masks, abs(dxp(QSMs)) * Masks)
                  + loss(abs(dyp(outputs)) * Masks, abs(dyp(QSMs)) * Masks)
                  + loss(abs(dzp(outputs)) * Masks, abs(dzp(QSMs)) * Masks))
    errModel_grad = (loss(abs(dxp(RDFs_outputs[..., 0])) * Masks, abs(dxp(RDFs_QSMs[..., 0])) * Masks)
                     + loss(abs(dyp(RDFs_outputs[..., 0])) * Masks, abs(dyp(RDFs_QSMs[..., 0])) * Masks)
                     + loss(abs(dzp(RDFs_outputs[..., 0])) * Masks, abs(dzp(RDFs_QSMs[..., 0])) * Masks))
    errGrad = errl1_grad + errModel_grad
    return errl1 + errModel + 0.1 * errGrad
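The RDFs_* terms are a real 3-D convolution with the dipole kernel, carried out in k-space. A modern-API sketch, assuming a real map chi and a real kernel D broadcastable to it:

import torch

def dipole_convolve(chi, D):
    k = torch.fft.fftn(chi, dim=(-3, -2, -1))
    return torch.fft.ifftn(k * D, dim=(-3, -2, -1)).real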
Code example #22
def Fourier_based_Corruption(dataset, imgsize, position):
    CUDA_AVAILABLE = torch.cuda.is_available()
    N = imgsize
    origin_value = 1.0
    i, j = position[0], position[1]
    print("position({},{})".format(i, j))
    testset = dataset
    samples_size = dataset.__len__()
    samples = np.array(range(samples_size))
    loader = transforms.Compose([transforms.ToTensor()])

    F_base_vec = torch.zeros(
        (N, N, 2)).cuda() if (CUDA_AVAILABLE) else torch.zeros((N, N, 2))
    F_base_vec[i][j][0] = origin_value
    Uij = torch.ifft(F_base_vec, 2)[:, :, 0].cuda() if (
        CUDA_AVAILABLE) else torch.ifft(F_base_vec, 2)[:, :, 0]
    Uij /= torch.norm(Uij, p=2)
    result = torch.zeros(
        (samples_size, 3, N, N)).cuda() if (CUDA_AVAILABLE) else torch.zeros(
            (samples_size, 3, N, N))

    for k in range(samples_size):
        img = testset[samples[k]][0]
        img_array = loader(img)
        img_new_array = torch.zeros(
            (3, N, N)).cuda() if (CUDA_AVAILABLE) else torch.zeros((3, N, N))
        for channel in range(3):
            img_one_channel = torch.Tensor(img_array[channel, :, :])
            img_one_channel = img_one_channel.cuda() if (
                CUDA_AVAILABLE) else img_one_channel
            L2norm = torch.norm(img_one_channel, p=2) * 0.1
            r = 1
            rvUij = Uij * L2norm * r
            img_one_channel += rvUij
            img_new_array[channel, :, :] = img_one_channel
        result[k] = img_new_array
    result = result.cpu()
    return result
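The basis image Uij can be built without the manual (N, N, 2) packing; a sketch with the names from the snippet (PyTorch >= 1.8):

F_base = torch.zeros(N, N, dtype=torch.complex64)
F_base[i, j] = origin_value
Uij = torch.fft.ifft2(F_base).real
Uij = Uij / torch.norm(Uij, p=2)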
Code example #23
    def forward(self, x, y):
        x = x.squeeze(2)
        y = y.squeeze(2)
        x = x.permute([0, 2, 3, 1])
        y = y.permute([0, 2, 3, 1])

        cEs = self.batch_fftshift2d(torch.fft(x, 3, normalized=True))
        cEsp = self.complex_mult(cEs, self.prop)

        S = torch.ifft(self.batch_ifftshift2d(cEsp), 3, normalized=True)
        Se = S[:, :, :, 0]

        mse = torch.mean(torch.abs(Se - y[:, :, :, 0])) / 2
        return mse
Code example #24
 def test_ifft_unitary(self):
     batch_size = 10
     n = 16
     input = torch.randn(batch_size, n, dtype=torch.complex64)
     normalized = True
     out_torch = view_as_complex(
         torch.ifft(view_as_real(input),
                    signal_ndim=1,
                    normalized=normalized))
     for br_first in [True, False]:
         b = torch_butterfly.special.ifft_unitary(n, br_first=br_first)
         out = b(input)
         self.assertTrue(
             torch.allclose(out, out_torch, self.rtol, self.atol))
Code example #25
def ifft(x): # input is a pair (re, im) of tensors, each of size mbs x n
    a = x[0]
    b = x[1]
    bs = a.size()[0]
    nu = a.size()[1]
    a2 = a.view(bs, 1, nu)
    a3 = torch.transpose(a2, 1, 2)
    b2 = b.view(bs, 1, nu)
    b3 = torch.transpose(b2, 1, 2).view(bs, nu, 1)
    x_in = torch.cat([a3,b3], dim=2)
    p = torch.ifft(x_in, 1, normalized=True)
    out_re = p[:,:,0].view(bs, nu)
    out_im = p[:,:,1].view(bs, nu)
    return (out_re, out_im)
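The same pair-in/pair-out helper with complex dtypes; normalized=True maps to norm="ortho" (a sketch, PyTorch >= 1.8):

import torch

def ifft_modern(x):
    a, b = x  # real and imaginary parts, each of size mbs x n
    p = torch.fft.ifft(torch.complex(a, b), dim=-1, norm="ortho")
    return p.real, p.imag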
Code example #26
File: deconv.py Project: sourav22899/MELD
    def reverse(self, z, device='cpu'):
        with torch.no_grad():
            if self.fullInvFlag:
                ys = torch.stack((self.y, torch.zeros_like(self.y)), 2)
                Fy = torch.fft(ys, 2)
                AHy = torch.ifft(mul_c(conj(self.fpsf), Fy), 2)[..., 0]
                aAHy = self.alpha * AHy

                xkpaAHy = z - aAHy

                ts = torch.stack((xkpaAHy, torch.zeros_like(xkpaAHy)), 2)
                Ft = torch.fft(ts, 2)
                AHA = mul_c(conj(self.fpsf), self.fpsf)
                I = torch.zeros_like(AHA)
                I[..., 0] = 1
                ImaAHA = I - self.alpha * AHA
                x = torch.ifft(div_c(Ft, ImaAHA), 2)
                return x[..., 0]
            else:
                x = z
                for _ in range(self.T):
                    x = z - self.step(x)
                return x
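With complex dtypes the fixed-point inverse above becomes a ratio of spectra; a sketch assuming self.fpsf is a complex transfer function and self.y, z are real images:

Fy = torch.fft.fft2(self.y)
AHy = torch.fft.ifft2(self.fpsf.conj() * Fy).real
Ft = torch.fft.fft2(z - self.alpha * AHy)
x = torch.fft.ifft2(Ft / (1 - self.alpha * self.fpsf.abs()**2)).real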
Code example #27
def wiener_filt_torch(Y, R, Np, batch_dim=False):
    if not batch_dim:
        Y, R = Y.unsqueeze(0), R.unsqueeze(0)

    Yc = complexify(Y)
    W = torch.ifft(Yc, signal_ndim=2)

    n, s1, s2 = R.shape
    s2 *= 3

    S_auto = torch_fft_shift(W, dims=(1, 2))
    XR_cross = S_auto[:, :, :2 * R.shape[1]]

    _, t1, t2, _ = XR_cross.shape

    R_ = torch.zeros((n, t1, t2, 2), dtype=torch.float, device=W.device)
    R_[:, :s1, :R.shape[2], 0] = R
    F_R = torch.fft(R_, signal_ndim=2)
    F_R_conj = torch.clone(F_R)
    F_R_conj[..., 1] *= -1
    F_R_abs = (F_R**2).sum(-1, keepdim=True)

    F_SNR = Np * torch.norm(Y, dim=(1, 2))**2 / torch.norm(Y, p=1, dim=(1, 2))**2
    print('XR_cross.shape!', XR_cross.shape, torch.norm(Y, dim=(1, 2)).shape)
    F_SNR = F_SNR.reshape(n, 1, 1, 1)
    X_ = torch.ifft(complex_mult(torch.fft(XR_cross, signal_ndim=2), F_R) /
                    (F_R_abs + 1 / F_SNR),
                    signal_ndim=2)
    print('X_.shape', X_.shape, R.shape[1], R.shape[2])

    X = X_[:, R.shape[1]:, R.shape[2]:]

    if not batch_dim:
        X = X[0]

    return X
Code example #28
    def forward(self, x):
        if self.dr is not None:
            x = self.conv_dr_block(x)
        bsn = 1
        batchSize, dim, h, w = x.data.shape
        x_flat = x.permute(0, 2, 3, 1).contiguous().view(-1, dim)  # [batchSize * h * w, dim]
        y = torch.ones(batchSize, self.output_dim, device=x.device)

        for img in range(batchSize // bsn):
            segLen = bsn * h * w
            upper = batchSize * h * w
            interLarge = torch.arange(img * segLen,
                                      min(upper, (img + 1) * segLen),
                                      dtype=torch.long)
            interSmall = torch.arange(img * bsn,
                                      min(upper, (img + 1) * bsn),
                                      dtype=torch.long)
            batch_x = x_flat[interLarge, :]

            sketch1 = batch_x.mm(self.sparseM[0].to(x.device)).unsqueeze(2)
            sketch1 = torch.fft(
                torch.cat(
                    (sketch1, torch.zeros(sketch1.size(), device=x.device)),
                    dim=2), 1)

            sketch2 = batch_x.mm(self.sparseM[1].to(x.device)).unsqueeze(2)
            sketch2 = torch.fft(
                torch.cat(
                    (sketch2, torch.zeros(sketch2.size(), device=x.device)),
                    dim=2), 1)

            Re = sketch1[:, :, 0].mul(sketch2[:, :, 0]) - sketch1[:, :, 1].mul(
                sketch2[:, :, 1])
            Im = sketch1[:, :, 0].mul(sketch2[:, :, 1]) + sketch1[:, :, 1].mul(
                sketch2[:, :, 0])

            tmp_y = torch.ifft(
                torch.cat((Re.unsqueeze(2), Im.unsqueeze(2)), dim=2), 1)[:, :, 0]

            y[interSmall, :] = tmp_y.view(
                torch.numel(interSmall), h, w, self.output_dim).sum(dim=1).sum(dim=1)

        y = self._signed_sqrt(y)
        y = self._l2norm(y)
        return y
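The Re/Im bookkeeping above is the count-sketch product of Compact Bilinear Pooling: FFT both sketches, multiply elementwise in the complex domain, and keep the real part of the inverse. A complex-dtype sketch of that inner step, reusing the snippet's names:

sketch1 = torch.fft.fft(batch_x.mm(self.sparseM[0].to(x.device)), dim=-1)
sketch2 = torch.fft.fft(batch_x.mm(self.sparseM[1].to(x.device)), dim=-1)
tmp_y = torch.fft.ifft(sketch1 * sketch2, dim=-1).real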
Code example #29
def phase_correlation(a, b):
    B, H, W = a.size()
    # embed the real images as complex tensors with a zero imaginary channel
    # (expanding the real values into both channels, as the original did,
    # is not a valid complex representation)
    a = torch.stack((a, torch.zeros_like(a)), dim=-1)
    b = torch.stack((b, torch.zeros_like(b)), dim=-1)
    G_a = torch.fft(a, signal_ndim=2)
    G_b = torch.fft(b, signal_ndim=2)
    # complex conjugate: negate the imaginary channel (torch.conj is a
    # no-op on a real-valued [..., 2] tensor)
    conj_b = torch.stack((G_b[..., 0], -G_b[..., 1]), dim=-1)
    # elementwise complex product G_a * conj(G_b)
    re = G_a[..., 0] * conj_b[..., 0] - G_a[..., 1] * conj_b[..., 1]
    im = G_a[..., 0] * conj_b[..., 1] + G_a[..., 1] * conj_b[..., 0]
    # normalize by the magnitude so only the phase remains
    mag = torch.sqrt(re**2 + im**2 + 1e-8)
    R = torch.stack((re / mag, im / mag), dim=-1)
    r = torch.ifft(R, signal_ndim=2)[..., 0]  # real part of the correlation surface
    shift = r.view(B, -1).argmax(dim=1)
    shift = torch.cat(((shift // W).view(-1, 1), (shift % W).view(-1, 1)), dim=1)
    return shift
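With complex dtypes the whole routine shrinks considerably; a sketch assuming real [B, H, W] inputs (PyTorch >= 1.8):

import torch

def phase_correlation_modern(a, b):
    B, H, W = a.size()
    R = torch.fft.fft2(a) * torch.fft.fft2(b).conj()
    R = R / (R.abs() + 1e-8)     # keep only the phase
    r = torch.fft.ifft2(R).real  # correlation surface
    shift = r.view(B, -1).argmax(dim=1)
    return torch.stack((shift // W, shift % W), dim=1)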
Code example #30
File: iradon.py Project: AlbertZhangHIT/torch-radon
def filterProjections(radon_img, filter_mode, d=1.):
	length = radon_img.size(0)
	H = designFilter(filter_mode, length, d)
	p = torch.zeros(len(H), radon_img.size(1), 2) # p holds fft of projections
	p[0:length, :, 0] = radon_img # zero pad

	fp = torch.fft(p.permute(1,0,2), signal_ndim=1)

	H_expand = H.unsqueeze(0).expand([fp.size(0), fp.size(1)]).unsqueeze(-1).expand(*fp.size())
	fp = fp * H_expand # frequency domain filtering
	p = torch.ifft(fp, signal_ndim=1).permute(1,0,2)
	p = p[..., 0] # real part
	p = p[0:length, :] # truncate the filtered projections

	return p.contiguous() # calling 'contiguous' here is essential; otherwise it can cause a memory leak
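The filtering is a per-angle 1-D multiply in the frequency domain. A modern-API sketch, assuming a real radon_img of shape [length, n_angles] and a real filter H from designFilter:

import torch

def filter_projections_modern(radon_img, H):
    fp = torch.fft.fft(radon_img, n=len(H), dim=0)       # zero-pads to the filter length
    p = torch.fft.ifft(fp * H.unsqueeze(1), dim=0).real  # filter, invert, keep real part
    return p[:radon_img.size(0)].contiguous()            # truncate to the original length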
Code example #31
    def forward(self, x):
        image = x['image'].permute(0, 2, 3, 1) # prepare for torch.fft

        temp = torch.fft(image, signal_ndim=2, normalized=True)
        
        if self.noise:
            temp = (temp + self.noise * x['k']) / (1 + self.noise)
        else:
            temp = (1 - x['mask']) * temp + x['k']

        temp = torch.ifft(temp, signal_ndim=2, normalized=True)

        temp = temp.permute(0, 3, 1, 2).float() 
        x['image'] = temp
        return x
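A sketch of the data-consistency step above on PyTorch >= 1.8 (mask branch only, assuming x['mask'] and x['k'] are stored as complex tensors of the same spatial shape):

img = torch.view_as_complex(x['image'].permute(0, 2, 3, 1).contiguous())
k = torch.fft.fft2(img, norm="ortho")
k = (1 - x['mask']) * k + x['k']
x['image'] = torch.view_as_real(torch.fft.ifft2(k, norm="ortho")).permute(0, 3, 1, 2).float()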