Example #1
def filter_grid(x: Tensor) -> Tuple[Tensor, Tensor]:
    r"""Returns the (quadrant-shifted) frequency grid for :math:`x`.

    Args:
        x: An input tensor, :math:`(*, H, W)`.

    Returns:
        The radius and phase tensors, both :math:`(H, W)`.

    Example:
        >>> x = torch.rand(5, 5)
        >>> r, phi = filter_grid(x)
        >>> r
        tensor([[0.0000, 0.2500, 0.5000, 0.5000, 0.2500],
                [0.2500, 0.3536, 0.5590, 0.5590, 0.3536],
                [0.5000, 0.5590, 0.7071, 0.7071, 0.5590],
                [0.5000, 0.5590, 0.7071, 0.7071, 0.5590],
                [0.2500, 0.3536, 0.5590, 0.5590, 0.3536]])
        >>> phi
        tensor([[-0.0000, -1.5708, -1.5708,  1.5708,  1.5708],
                [-0.0000, -0.7854, -1.1071,  1.1071,  0.7854],
                [-0.0000, -0.4636, -0.7854,  0.7854,  0.4636],
                [-3.1416, -2.6779, -2.3562,  2.3562,  2.6779],
                [-3.1416, -2.3562, -2.0344,  2.0344,  2.3562]])
    """

    # Normalized frequency coordinates in [-1/2, 1/2] along H and W
    u, v = [(torch.arange(n).to(x) - n // 2) / (n - n % 2)
            for n in x.shape[-2:]]
    # Quadrant-shift so the zero frequency sits at index (0, 0)
    u, v = fft.ifftshift(u[:, None]), fft.ifftshift(v[None, :])

    r = (u**2 + v**2).sqrt()  # radial frequency
    phi = torch.atan2(-v, u)  # phase angle

    return r, phi
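
The returned grid is quadrant-shifted to match an unshifted ``fft2`` layout, which makes it convenient for building frequency masks directly. A small sketch of that use (the cutoff 0.25 is illustrative, not from the source):

    import torch
    import torch.fft as fft

    x = torch.rand(32, 32)
    r, phi = filter_grid(x)
    lowpass = (r <= 0.25).to(x)                # ideal low-pass mask, aligned with fft2's layout
    y = fft.ifft2(fft.fft2(x) * lowpass).real  # filter without any extra fftshift calls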
Example #2
File: ffts.py  Project: aisari/torchsar
def ifft(x, n=None, axis=0, norm="backward", shift=False):
    """IFFT in torchsar

    IFFT in torchsar, since ifft in torch only supports complex-complex transformation,
    for real ifft, we insert imaginary part with zeros (torch.stack((x,torch.zeros_like(x), dim=-1))),
    also you can use torch's rifft.

    Parameters
    ----------
    x : {torch array}
        both complex and real representation are supported. Since torch does not
        support complex array, when :attr:`x` is complex, we will change the representation
        in real formation(last dimension is 2, real, imag), after IFFT, it will be change back.
    n : int, optional
        number of ifft points (the default is None --> equals to signal dimension)
    axis : int, optional
        axis of ifft (the default is 0, which the first dimension)
    norm : bool, optional
        Normalization mode. For the backward transform (ifft()), these correspond to:
        - "forward" - no normalization
        - "backward" - normalize by ``1/n`` (default)
        - "ortho" - normalize by 1``/sqrt(n)`` (making the IFFT orthonormal)
    shift : bool, optional
        shift the zero frequency to center (the default is False)

    Returns
    -------
    y : {torch array}
        ifft results torch array with the same type as :attr:`x`

    Raises
    ------
    ValueError
        nfft is small than signal dimension.
    """

    if norm is None:
        norm = 'backward'

    # Real representation (trailing dimension holds real/imag): view as complex.
    if (x.size(-1) == 2) and (not th.is_complex(x)):
        realflag = True
        x = th.view_as_complex(x)
        if axis < 0:
            # view_as_complex drops the trailing dimension, shifting negative axes.
            axis += 1
    else:
        realflag = False

    if shift:
        # Center the zero frequency before and after the transform.
        y = thfft.ifftshift(thfft.ifft(thfft.ifftshift(x, dim=axis),
                                       n=n,
                                       dim=axis,
                                       norm=norm),
                            dim=axis)
    else:
        y = thfft.ifft(x, n=n, dim=axis, norm=norm)

    if realflag:
        # Restore the real (..., 2) representation.
        y = th.view_as_real(y)

    return y
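
A minimal usage sketch of the real-representation path described in the docstring (the input tensor here is illustrative; ``th`` is ``torch``, as in the snippet):

    import torch as th

    x = th.randn(8)                               # real-valued signal
    xc = th.stack((x, th.zeros_like(x)), dim=-1)  # (..., 2) real/imag representation
    y = ifft(xc, axis=0)                          # same (..., 2) representation back
    assert y.shape == (8, 2)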
Example #3
    def __getitem__(self, i):
        fname, slice_id = self.examples[i]
        with h5py.File(fname, "r") as data:
            kspace = data["kspace"][slice_id]
            kspace = torch.from_numpy(np.stack([kspace.real, kspace.imag], axis=-1))

            # For PyTorch 1.8+, which offers native complex dtypes (complex64)
            kspace = torch.view_as_complex(kspace)
            kspace = ifftshift(kspace, dim=(0, 1))
            # norm=forward means no normalization
            target = ifft(kspace, dim=(0, 1), norm="forward")
            target = ifftshift(target, dim=(0, 1))

            # Plot images to confirm fft worked
            # t_img = complex_magnitude(target)
            # print(t_img.dtype, t_img.shape)
            # plt.imshow(t_img)
            # plt.show()
            # plt.imshow(target.real)
            # plt.show()

            # center crop and resize
            # target = torch.unsqueeze(target, dim=0)
            # target = center_crop(target, (128, 128))
            # target = torch.squeeze(target)

            # Crop out ends
            target = np.stack([target.real, target.imag], axis=-1)
            target = target[100:-100, 24:-24, :]

            # Downsample in image space
            shape = target.shape
            target = tf.image.resize(
                target,
                (IMG_H, IMG_W),
                method="lanczos5",
                # preserve_aspect_ratio=True,
                antialias=True,
            ).numpy()

            # Get kspace of cropped image
            target = torch.view_as_complex(torch.from_numpy(target))
            kspace = fftshift(target, dim=(0, 1))
            kspace = fft(kspace, dim=(0, 1))
            # Recenter k-space so the zero-frequency (low spatial frequency) content
            # sits in the middle. Note that original fastmri code did not do this...
            kspace = fftshift(kspace, dim=(0, 1))

            # Normalize using mean of k-space in training data
            target /= 7.072103529760345e-07
            kspace /= 7.072103529760345e-07

        return kspace, target
Example #4
    def construct_pyramid(self, image, n_levels, n_orientations):
        if image.size() != self.image_size or \
            n_levels != self.n_levels or \
            n_orientations != self.n_orientations:
            # Need to recalculate the filters.
            self.image_size = image.size()
            self.n_levels = n_levels
            self.n_orientations = n_orientations
            self.calculate_filters()
        ft = fftshift(fft2(image))

        curr_level = {}
        pyramid = []
        h0 = self.H0_FILT * ft
        curr_level['h'] = torch.real(ifft2(ifftshift(h0)))

        l0 = self.L0_FILT * ft
        curr_level['l'] = torch.real(ifft2(ifftshift(l0)))

        # apply the bandpass filters (B) and downsample iteratively; save each pyramid level
        _last = l0
        for i in range(self.n_levels):
            curr_level['b'] = []
            for j in range(len(self.B_FILT[i])):
                lb = _last * self.B_FILT[i][j]
                curr_level['b'].append(torch.real(ifft2(ifftshift(lb))))

            # apply the lowpass filter (L) to the image (in the Fourier domain) before downsampling
            l1 = _last * self.L_FILT[i]

            # Downsample by keeping the central half of the DFT in each dimension;
            # dividing by 4 preserves amplitude under ifft2's 1/N normalization.
            down_size = [l1.size(-2) // 4, l1.size(-1) // 4]
            down_image = l1[:, :, down_size[0]:3 * down_size[0],
                            down_size[1]:3 * down_size[1]] / 4
            _last = down_image.clone()
            pyramid.append(curr_level)
            curr_level = {}

        # lowpass residual
        curr_level['l'] = torch.real(ifft2(ifftshift(_last)))
        pyramid.append(curr_level)
        return pyramid
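
The downsampling inside the loop works by cropping the central half of the centered spectrum. The same idea as a standalone sketch (the function name and the divisible-by-4 assumption are mine, not the source's):

    import torch
    import torch.fft as fft

    def fourier_downsample(img: torch.Tensor) -> torch.Tensor:
        """Halve H and W by keeping the central half of the centered spectrum."""
        spec = fft.fftshift(fft.fft2(img), dim=(-2, -1))
        h4, w4 = img.size(-2) // 4, img.size(-1) // 4  # assumes H, W divisible by 4
        spec = spec[..., h4:3 * h4, w4:3 * w4] / 4     # /4 keeps amplitude under ifft2's 1/N scaling
        return fft.ifft2(fft.ifftshift(spec, dim=(-2, -1))).real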
Example #5
    def ifftshift(x, dim=None):
        """Move the center value to the front of the tensor.

        Notes
        -----
        .. If the dimension has an even shape, the center is the first
            position *after* the middle of the tensor: `c = s//2`
        .. This function triggers a copy of the data.
        .. If the dimension has an even shape, `fftshift` and `ifftshift`
           are equivalent.

        Parameters
        ----------
        x : tensor
            Input tensor
        dim : [sequence of] int, default=all
            Dimensions to shift

        Returns
        -------
        x : tensor
            Shifted tensor

        """
        x = torch.as_tensor(x)
        if _torch_has_fftshift:
            if isinstance(dim, range):
                dim = tuple(dim)
            return fft_mod.ifftshift(x, dim)

        if dim is None:
            dim = list(range(x.dim()))
        dim = py.make_list(dim)
        if len(dim) > 1:
            x = x.clone()  # clone to get an additional buffer

        y = torch.empty_like(x)
        slicer = [slice(None)] * x.dim()
        for d in dim:
            # move back to front
            pre = list(slicer)
            pre[d] = slice(None, (x.shape[d] + 1) // 2)
            post = list(slicer)
            post[d] = slice(x.shape[d] // 2, None)
            y[tuple(pre)] = x[tuple(post)]
            # move front to back
            pre = list(slicer)
            pre[d] = slice(None, x.shape[d] // 2)
            post = list(slicer)
            post[d] = slice((x.shape[d] + 1) // 2, None)
            y[tuple(post)] = x[tuple(pre)]
            # exchange buffers
            x, y = y, x

        return x
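
A quick sanity check one might run against PyTorch's built-in, on a torch build that has ``torch.fft.ifftshift`` (the test tensor is illustrative):

    import torch

    x = torch.arange(5.)
    assert torch.equal(ifftshift(x), torch.fft.ifftshift(x))  # tensor([2., 3., 4., 0., 1.])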
Example #6
    def __getitem__(self, i):
        fname, slice_id = self.examples[i]
        with h5py.File(fname, "r") as data:
            kspace = data["kspace"][slice_id]
            kspace = torch.from_numpy(np.stack([kspace.real, kspace.imag], axis=-1))
            kspace = torch.view_as_complex(kspace)
            kspace = ifftshift(kspace, dim=(0, 1))
            target = ifft(kspace, dim=(0, 1), norm="forward")
            target = ifftshift(target, dim=(0, 1))

            # transform
            target = torch.stack([target.real, target.imag])
            target = self.deform(target)  # outputs numpy
            target = torch.from_numpy(target)
            target = target.permute(1, 2, 0).contiguous()
            # center crop and resize
            # target = center_crop(target, (128, 128))
            # target = resize(target, (128,128))

            # Crop out ends
            target = target.numpy()[100:-100, 24:-24, :]
            # Downsample in image space
            target = tf.image.resize(
                target,
                (IMG_H, IMG_W),
                method="lanczos5",
                # preserve_aspect_ratio=True,
                antialias=True,
            ).numpy()

            # Making contiguous is necessary for complex view
            target = torch.from_numpy(target)
            target = target.contiguous()
            target = torch.view_as_complex(target)

            kspace = fftshift(target, dim=(0, 1))
            kspace = fft(kspace, dim=(0, 1))
            kspace = fftshift(kspace, dim=(0, 1))

            # Normalize using mean of k-space in training data
            target /= 7.072103529760345e-07
            kspace /= 7.072103529760345e-07

        return kspace, target
Example #7
    def forward(self, x, up_feat_in):
        # split the feature into low- and high-frequency components
        freq_x = fft.fftn(x, dim=(-2, -1))
        freq_shift = fft.fftshift(freq_x, dim=(-2, -1))

        # low_freq_shift = self.easy_low_pass_filter(freq_x)
        # high_freq_shift = self.easy_high_pass_filter(freq_x)
        low_freq_shift, high_freq_shift = self.guassian_low_high_pass_filter(
            freq_shift)

        low_freq_ishift = fft.ifftshift(low_freq_shift, dim=(-2, -1))
        high_freq_ishift = fft.ifftshift(high_freq_shift, dim=(-2, -1))

        _low_freq_x = torch.abs(fft.ifftn(low_freq_ishift, dim=(-2, -1)))
        _high_freq_x = torch.abs(fft.ifftn(high_freq_ishift, dim=(-2, -1)))

        low_freq_x = self.low_project(_low_freq_x)
        high_freq_x = self.high_project(_high_freq_x)

        feat = torch.cat([x, low_freq_x, high_freq_x], dim=1)
        context = self.out_project(feat)
        fuse_feature = context + x  # optional skip connection

        if self.up_flag and self.smf_flag:
            if up_feat_in is not None:
                fuse_feature = self.upsample_add(up_feat_in, fuse_feature)
            up_feature = self.up(fuse_feature)
            smooth_feature = self.smooth(fuse_feature)
            return up_feature, smooth_feature

        if self.up_flag and not self.smf_flag:
            if up_feat_in is not None:
                fuse_feature = self.upsample_add(up_feat_in, fuse_feature)
            up_feature = self.up(fuse_feature)
            return up_feature

        if not self.up_flag and self.smf_flag:
            if up_feat_in is not None:
                fuse_feature = self.upsample_add(up_feat_in, fuse_feature)
            smooth_feature = self.smooth(fuse_feature)
            return smooth_feature
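
The ``guassian_low_high_pass_filter`` method called above is not shown in the snippet; a minimal sketch of what such a split can look like (the function name, the cutoff ``d0``, and the body are my assumptions, not the source's implementation):

    import torch

    def gaussian_low_high_pass(freq_shift: torch.Tensor, d0: float = 10.0):
        """Split a centered spectrum into Gaussian low- and high-pass parts."""
        h, w = freq_shift.shape[-2:]
        cy, cx = h // 2, w // 2
        y = torch.arange(h, device=freq_shift.device) - cy
        x = torch.arange(w, device=freq_shift.device) - cx
        d2 = y[:, None] ** 2 + x[None, :] ** 2         # squared distance from the center
        lp = torch.exp(-d2 / (2 * d0 ** 2))            # Gaussian low-pass mask
        return freq_shift * lp, freq_shift * (1 - lp)  # low / high frequency parts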
Example #8
    def forward(self, x):
        # self.writer = writer
        freq_x = fft.fftn(x)
        freq_shift = fft.fftshift(freq_x)
        
        # low_freq_shift = self.easy_low_pass_filter(freq_x)
        # high_freq_shift = self.easy_high_pass_filter(freq_x)
        low_freq_shift, high_freq_shift = self.guassian_low_high_pass_filter(freq_shift)

        # low_freq_ishift = fft.ifftshift(low_freq_shift)
        high_freq_ishift = fft.ifftshift(high_freq_shift)
        
        # _low_freq_x = torch.abs(fft.ifftn(low_freq_ishift))
        _high_freq_x = torch.abs(fft.ifftn(high_freq_ishift))

        feat_rgb = self.sp(_high_freq_x)
        feat_dct = self.cp(x)
        feat_fuse = torch.cat((feat_rgb, feat_dct), dim=1)
        logits = self.head(feat_fuse)
        out = F.interpolate(logits, scale_factor=self.block_size, mode='bilinear', \
            align_corners=True)
        return out
Example #9
    def reconstruct_from_pyramid(self, pyramid):
        # Work out pyramid parameters, and check they match current settings.
        n_orientations = len(pyramid[0]['b'])
        n_levels = len(pyramid)
        size = pyramid[0]['h'].size()
        if n_orientations != self.n_orientations or \
            n_levels != self.n_levels or \
            size != self.image_size:
            self.n_orientations = n_orientations
            self.image_size = size
            self.n_levels = n_levels
            self.calculate_filters()

        curr = fftshift(fft2(pyramid[-1]['l']))
        for l in range(len(pyramid) - 2, -1, -1):
            # Upsample the current reconstruction
            tmp = torch.zeros([
                curr.size(0),
                curr.size(1),
                curr.size(2) * 2,
                curr.size(3) * 2
            ],
                              dtype=torch.complex64)
            offsety = curr.size(-2) // 2
            offsetx = curr.size(-1) // 2
            # Zero-pad the centered spectrum; *4 offsets ifft2's 1/N normalization.
            tmp[:, :, offsety:3 * offsety, offsetx:3 * offsetx] = curr * 4
            curr = tmp

            curr = curr * self.L_FILT[l]

            for b in range(len(self.B_FILT[l])):
                curr += self.B_FILT[l][b] * fftshift(fft2(pyramid[l]['b'][b]))

        reconstruction = curr * self.L0_FILT + fftshift(fft2(
            pyramid[0]['h'])) * self.H0_FILT

        return torch.real(ifft2(ifftshift(reconstruction)))
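
The zero-padding step above is the inverse of the spectral cropping used in construct_pyramid. As a standalone sketch (the function name is illustrative):

    import torch
    import torch.fft as fft

    def fourier_upsample(img: torch.Tensor) -> torch.Tensor:
        """Double H and W by zero-padding the centered spectrum."""
        spec = fft.fftshift(fft.fft2(img), dim=(-2, -1))
        h, w = img.size(-2), img.size(-1)
        out = torch.zeros(*img.shape[:-2], 2 * h, 2 * w, dtype=spec.dtype)
        out[..., h // 2:h // 2 + h, w // 2:w // 2 + w] = spec * 4  # *4 offsets ifft2's 1/N scaling
        return fft.ifft2(fft.ifftshift(out, dim=(-2, -1))).real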
Example #10
    def ifftshift(x, dim=None):
        # Thin wrapper: delegate to torch.fft's ifftshift, accepting range dims.
        x = torch.as_tensor(x)
        if isinstance(dim, range):
            dim = tuple(dim)
        return fft_mod.ifftshift(x, dim)
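
A minimal usage sketch (assuming ``fft_mod`` is ``torch.fft``, as in the longer variant of this wrapper in Example #5):

    import torch
    import torch.fft as fft_mod

    x = torch.arange(6).reshape(2, 3)
    y = ifftshift(x, dim=range(2))  # range is accepted and converted to a tuple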