Beispiel #1
0
    def __getitem__(self, i):
        """Return the ``(kspace, target)`` pair for example ``i``.

        The stored k-space slice is transformed to image space, cropped,
        downsampled with ``tf.image.resize`` to ``(IMG_H, IMG_W)`` and
        transformed back, so the returned k-space is the (shifted) FFT of the
        returned image. Both outputs are divided by the same fixed
        normalization constant.

        Args:
            i: index into ``self.examples``, whose entries are
                ``(fname, slice_id)`` pairs.

        Returns:
            ``(kspace, target)``: complex tensors of shape ``(IMG_H, IMG_W)``;
            ``kspace`` has its zero frequency moved to the center by
            ``fftshift``.
        """
        fname, slice_id = self.examples[i]
        # Only the raw slice is read while the file is open: indexing an h5py
        # dataset returns an in-memory numpy array, so the file can be
        # released before the comparatively slow FFT / resize work below.
        with h5py.File(fname, "r") as data:
            kspace = data["kspace"][slice_id]

        # Build a complex torch tensor from the stored slice.
        kspace = torch.from_numpy(np.stack([kspace.real, kspace.imag], axis=-1))
        kspace = torch.view_as_complex(kspace)

        # k-space -> image space. For ifft, norm="forward" means the inverse
        # transform applies no 1/n scaling.
        kspace = ifftshift(kspace, dim=(0, 1))
        target = ifft(kspace, dim=(0, 1), norm="forward")
        target = ifftshift(target, dim=(0, 1))

        # Crop out ends (100 entries per side on axis 0, 24 on axis 1),
        # keeping real/imag as a trailing channel axis for tf.image.resize.
        target = np.stack([target.real, target.imag], axis=-1)
        target = target[100:-100, 24:-24, :]

        # Downsample in image space.
        target = tf.image.resize(
            target,
            (IMG_H, IMG_W),
            method="lanczos5",
            antialias=True,
        ).numpy()

        # Get k-space of the cropped, downsampled image.
        target = torch.view_as_complex(torch.from_numpy(target))
        kspace = fftshift(target, dim=(0, 1))
        kspace = fft(kspace, dim=(0, 1))
        # Move the zero-frequency (DC) component to the center of k-space.
        # Note that the original fastMRI code did not do this.
        kspace = fftshift(kspace, dim=(0, 1))

        # Normalize using mean of k-space in training data.
        norm_const = 7.072103529760345e-07
        target /= norm_const
        kspace /= norm_const

        return kspace, target
Beispiel #2
0
    def __getitem__(self, i):
        """Return the augmented ``(kspace, target)`` pair for example ``i``.

        The stored k-space slice is transformed to image space, passed through
        ``self.deform``, cropped, downsampled with ``tf.image.resize`` to
        ``(IMG_H, IMG_W)`` and transformed back to k-space. Both outputs are
        divided by the same fixed normalization constant.

        Args:
            i: index into ``self.examples``, whose entries are
                ``(fname, slice_id)`` pairs.

        Returns:
            ``(kspace, target)``: complex tensors, presumably of shape
            ``(IMG_H, IMG_W)`` (assumes ``self.deform`` keeps the two-channel
            real/imag layout — TODO confirm); ``kspace`` has its zero
            frequency moved to the center by ``fftshift``.
        """
        fname, slice_id = self.examples[i]
        # Only the raw slice is read while the file is open: indexing an h5py
        # dataset returns an in-memory numpy array, so the file is released
        # before the augmentation / FFT / resize work below.
        with h5py.File(fname, "r") as data:
            kspace = data["kspace"][slice_id]

        kspace = torch.from_numpy(np.stack([kspace.real, kspace.imag], axis=-1))
        kspace = torch.view_as_complex(kspace)
        kspace = ifftshift(kspace, dim=(0, 1))
        # For ifft, norm="forward" means no 1/n scaling on the inverse.
        target = ifft(kspace, dim=(0, 1), norm="forward")
        target = ifftshift(target, dim=(0, 1))

        # Augment: real/imag are stacked on a leading channel axis for
        # self.deform, whose numpy output is then moved to channels-last.
        target = torch.stack([target.real, target.imag])
        target = self.deform(target)  # outputs numpy
        target = torch.from_numpy(target)
        target = target.permute(1, 2, 0).contiguous()

        # Crop out ends (100 entries per side on axis 0, 24 on axis 1).
        target = target.numpy()[100:-100, 24:-24, :]
        # Downsample in image space.
        target = tf.image.resize(
            target,
            (IMG_H, IMG_W),
            method="lanczos5",
            antialias=True,
        ).numpy()

        # view_as_complex needs a trailing size-2 dim with unit stride, hence
        # the explicit contiguous().
        target = torch.from_numpy(target).contiguous()
        target = torch.view_as_complex(target)

        # Shifted FFT back to k-space, zero frequency moved to the center.
        kspace = fftshift(target, dim=(0, 1))
        kspace = fft(kspace, dim=(0, 1))
        kspace = fftshift(kspace, dim=(0, 1))

        # Normalize using mean of k-space in training data.
        norm_const = 7.072103529760345e-07
        target /= norm_const
        kspace /= norm_const

        return kspace, target
Beispiel #3
0
    def reconstruct_from_pyramid(self, pyramid):
        """Rebuild an image from its frequency-domain pyramid.

        Filters are recomputed via ``self.calculate_filters()`` when the
        pyramid's orientation count, level count or image size differ from
        the current settings.

        Args:
            pyramid: list of per-level dicts. ``pyramid[0]['h']`` is the
                high-pass residual (its size is taken as the image size),
                ``pyramid[level]['b']`` holds the band images of ``level``
                for every entry except the last, and ``pyramid[-1]['l']`` is
                the low-pass residual.

        Returns:
            Real tensor: the real part of the inverse FFT of the recombined,
            centred spectrum.
        """
        # Work out pyramid parameters, and check they match current settings.
        n_orientations = len(pyramid[0]['b'])
        # The last entry only holds the low-pass residual; the loop below
        # uses band/low-pass filters for levels 0..len(pyramid)-2 only, so
        # that is the number of filter levels needed. Using len(pyramid)
        # computed one unused extra level and forced a recomputation on
        # every call.
        n_levels = len(pyramid) - 1
        size = pyramid[0]['h'].size()
        if (n_orientations != self.n_orientations
                or n_levels != self.n_levels
                or size != self.image_size):
            self.n_orientations = n_orientations
            self.image_size = size
            self.n_levels = n_levels
            self.calculate_filters()

        curr = fftshift(fft2(pyramid[-1]['l']))
        for level in range(len(pyramid) - 2, -1, -1):
            # Upsample 2x by zero-padding the centred spectrum. The buffer
            # follows curr's dtype and device, so GPU inputs work and
            # double-precision spectra are not silently cast to complex64.
            height, width = curr.size(-2), curr.size(-1)
            tmp = torch.zeros(
                (*curr.shape[:-2], height * 2, width * 2),
                dtype=curr.dtype,
                device=curr.device,
            )
            offsety = height // 2
            offsetx = width // 2
            tmp[..., offsety:3 * offsety, offsetx:3 * offsetx] = curr * 4
            curr = tmp * self.L_FILT[level]

            # Add back the band-pass components of this level.
            for b, band_filt in enumerate(self.B_FILT[level]):
                curr += band_filt * fftshift(fft2(pyramid[level]['b'][b]))

        reconstruction = (curr * self.L0_FILT
                          + fftshift(fft2(pyramid[0]['h'])) * self.H0_FILT)

        return torch.real(ifft2(ifftshift(reconstruction)))
Beispiel #4
0
    def construct_pyramid(self, image, n_levels, n_orientations):
        """Decompose ``image`` into a frequency-domain pyramid.

        Filters are recomputed via ``self.calculate_filters()`` whenever the
        image size, level count or orientation count changes.

        Args:
            image: real tensor; filtering is applied to
                ``fftshift(fft2(image))`` and downsampling crops its last two
                dimensions.
            n_levels: number of band-pass levels.
            n_orientations: stored for ``calculate_filters``.

        Returns:
            List of ``n_levels + 1`` dicts. Entry 0 also holds ``'h'`` (the
            high-pass residual) and ``'l'`` (the low-pass image at full
            size). Entries ``0..n_levels-1`` hold ``'b'``, the list of band
            images of that level. The last entry holds ``'l'``, the low-pass
            residual.
        """
        if (image.size() != self.image_size
                or n_levels != self.n_levels
                or n_orientations != self.n_orientations):
            # Need to recalculate the filters.
            self.image_size = image.size()
            self.n_levels = n_levels
            self.n_orientations = n_orientations
            self.calculate_filters()
        ft = fftshift(fft2(image))

        h0 = self.H0_FILT * ft
        l0 = self.L0_FILT * ft
        curr_level = {
            'h': torch.real(ifft2(ifftshift(h0))),
            'l': torch.real(ifft2(ifftshift(l0))),
        }
        pyramid = []

        # Apply band-pass filters (B) and downsample iteratively.
        spectrum = l0
        for i in range(self.n_levels):
            curr_level['b'] = [
                torch.real(ifft2(ifftshift(spectrum * band_filt)))
                for band_filt in self.B_FILT[i]
            ]

            # Low-pass filter (L), then downsample by keeping the central
            # half of the centred spectrum along each of the last two dims.
            # The division already allocates a new tensor, so no clone is
            # needed.
            l1 = spectrum * self.L_FILT[i]
            quarter_h, quarter_w = l1.size(-2) // 4, l1.size(-1) // 4
            spectrum = l1[..., quarter_h:3 * quarter_h,
                          quarter_w:3 * quarter_w] / 4

            pyramid.append(curr_level)
            curr_level = {}

        # Low-pass residual.
        curr_level['l'] = torch.real(ifft2(ifftshift(spectrum)))
        pyramid.append(curr_level)
        return pyramid
    def forward(self, x, up_feat_in):
        """Split ``x`` into low/high-frequency parts, fuse them, emit outputs.

        The spectrum over the last two dims is split by
        ``self.guassian_low_high_pass_filter``. Each part is brought back to
        the spatial domain as a magnitude and projected. The projections are
        concatenated with ``x``, passed through ``self.out_project`` and added
        back to ``x``. When ``up_feat_in`` is given (and at least one flag is
        set), it is merged in with ``self.upsample_add``.

        Returns:
            ``(self.up(f), self.smooth(f))`` if both ``self.up_flag`` and
            ``self.smf_flag`` are truthy, ``self.up(f)`` if only ``up_flag``
            is, ``self.smooth(f)`` if only ``smf_flag`` is, ``None`` otherwise.
        """
        spatial = (-2, -1)
        spectrum = fft.fftshift(fft.fftn(x, dim=spatial), dim=spatial)
        low_spec, high_spec = self.guassian_low_high_pass_filter(spectrum)

        def back_to_space(spec):
            # Undo the centring shift, invert the FFT, keep the magnitude.
            return torch.abs(
                fft.ifftn(fft.ifftshift(spec, dim=spatial), dim=spatial))

        low_mag = back_to_space(low_spec)
        high_mag = back_to_space(high_spec)

        fused = torch.cat(
            [x, self.low_project(low_mag), self.high_project(high_mag)],
            dim=1)
        # Skip connection around the output projection.
        fused = self.out_project(fused) + x

        use_up = self.up_flag
        use_smooth = self.smf_flag
        if not (use_up or use_smooth):
            # NOTE(review): with both flags off nothing is returned — confirm
            # that callers never construct the module this way.
            return None

        if up_feat_in is not None:
            fused = self.upsample_add(up_feat_in, fused)

        if use_up and use_smooth:
            up_feature = self.up(fused)
            smooth_feature = self.smooth(fused)
            return up_feature, smooth_feature
        if use_up:
            return self.up(fused)
        return self.smooth(fused)
Beispiel #6
0
    def forward(self, x):
        """Fuse high-frequency features of ``x`` with features of ``x``.

        The high-pass output of ``self.guassian_low_high_pass_filter``
        (applied to the centred spatial spectrum) is brought back to the
        spatial domain as a magnitude and fed to ``self.sp``. ``x`` itself is
        fed to ``self.cp``. The concatenated features go through
        ``self.head``, and the logits are bilinearly upsampled by
        ``self.block_size``.

        Args:
            x: input tensor; assumed (N, C, H, W) with the last two dims
                spatial — TODO confirm against callers.

        Returns:
            Upsampled logits from ``F.interpolate``.
        """
        # Transform only the spatial dims. The previous fftn/fftshift/ifftn
        # defaults (dim=None) also ran FFTs along the batch and channel
        # axes. That is wasted work, and it couples samples in a batch
        # whenever the filter is not constant along those axes.
        spatial = (-2, -1)
        freq_shift = fft.fftshift(fft.fftn(x, dim=spatial), dim=spatial)

        # Only the high-pass half of the filter output is used here.
        _, high_freq_shift = self.guassian_low_high_pass_filter(freq_shift)
        high_freq_ishift = fft.ifftshift(high_freq_shift, dim=spatial)
        high_freq_x = torch.abs(fft.ifftn(high_freq_ishift, dim=spatial))

        feat_rgb = self.sp(high_freq_x)
        feat_dct = self.cp(x)
        feat_fuse = torch.cat((feat_rgb, feat_dct), dim=1)
        logits = self.head(feat_fuse)
        return F.interpolate(logits, scale_factor=self.block_size,
                             mode='bilinear', align_corners=True)
Beispiel #7
0
            return torch.complex(y * _real(x), y * imag(x))
    xreal, yreal = py.make_list(real, 2)
    if xreal and yreal:
        return torch.mul(x, y)
    elif xreal:
        return x.unsqueeze(-1) * y
    elif yreal:
        return x * y.unsqueeze(-1)
    else:
        return complex(
            _real(x) * _real(y) - imag(x) * imag(y),
            _real(x) * imag(y) + imag(x) * _real(y))


if _torch_has_fftshift:
    fftshift = lambda x, dim, real=False: fft_mod.fftshift(
        torch.as_tensor(x), dim)
else:

    def fftshift(x, dim=None, real=False):
        """Move the first value to the center of the tensor.

        Notes
        -----
        .. If the dimension has an even shape, the center is the first
            position *after* the middle of the tensor: `c = s//2`
        .. This function triggers a copy of the data.
        .. If the dimension has an even shape, `fftshift` and `ifftshift`
           are equivalent.

        Parameters
        ----------
Beispiel #8
0
    plt.legend(['real', 'imag'])
    plt.subplot(222)
    plt.plot(t * 1e6, th.angle(St))
    plt.xlabel('Time/us')
    plt.subplot(223)
    plt.plot(t * 1e6, th.real(Sr))
    plt.plot(t * 1e6, th.imag(Sr))
    plt.xlabel('Time/us')
    plt.legend(['real', 'imag'])
    plt.subplot(224)
    plt.plot(t * 1e6, th.angle(Sr))
    plt.xlabel('Time/us')
    plt.show()

    # ---Frequency domain
    Yt = fftshift(fft(fftshift(St, dim=0), dim=0), dim=0)
    Yr = fftshift(fft(fftshift(Sr, dim=0), dim=0), dim=0)

    # ---Plot signals
    plt.figure(figsize=(10, 8))
    plt.subplot(221)
    plt.plot(t * 1e6, th.real(St))
    plt.grid()
    plt.title('Real part')
    plt.xlabel('Time/μs')
    plt.ylabel('Amplitude')
    plt.subplot(222)
    plt.plot(t * 1e6, th.imag(St))
    plt.grid()
    plt.title('Imaginary part')
    plt.xlabel('Time/μs')
Beispiel #9
0
def fft(x, n=None, axis=0, norm="backward", shift=False):
    """FFT in torchsar

    FFT in torchsar.

    Parameters
    ----------
    x : {torch array}
        complex representation is supported. Since torch1.7 and above support complex array,
        when :attr:`x` is in real-representation formation(last dimension is 2, real, imag),
        we will change the representation in complex formation, after FFT, it will be change back.
    n : int, optional
        number of fft points (the default is None --> equals to signal dimension)
    axis : int, optional
        axis of fft (the default is 0, which the first dimension). For a
        real-representation :attr:`x`, the axis is counted on the original
        tensor and must not be the trailing (real, imag) dimension.
    norm : {None or str}, optional
        Normalization mode. For the forward transform (fft()), these correspond to:
        - "forward" - normalize by ``1/n``
        - "backward" - no normalization (default; also used when ``None``)
        - "ortho" - normalize by ``1/sqrt(n)`` (making the FFT orthonormal)
    shift : bool, optional
        shift the zero frequency to center (the default is False)

    Returns
    -------
    y : {torch array}
        fft results torch array with the same type as :attr:`x`

    Raises
    ------
    ValueError
        nfft is smaller than signal dimension, or :attr:`axis` selects the
        (real, imag) dimension of a real-representation :attr:`x`.
    """

    if norm is None:
        norm = 'backward'

    if (x.size(-1) == 2) and (not th.is_complex(x)):
        # The trailing (real, imag) dim disappears after view_as_complex, so
        # it cannot be the transform axis. Previously axis=-1 was silently
        # remapped to axis 0 (the wrong dimension).
        if axis in (-1, x.dim() - 1):
            raise ValueError(
                'axis refers to the (real, imag) dimension of a '
                'real-representation tensor!')
        realflag = True
        x = th.view_as_complex(x)
        # Negative axes count from the end; dropping the trailing dim
        # shifts them by one.
        if axis < 0:
            axis += 1
    else:
        realflag = False

    d = x.size(axis)
    if n is None:
        n = d
    if d < n:
        x = padfft(x, n, axis, shift)
    elif d > n:
        raise ValueError('nfft is small than signal dimension!')

    if shift:
        y = thfft.fftshift(thfft.fft(thfft.fftshift(x, dim=axis),
                                     n=n,
                                     dim=axis,
                                     norm=norm),
                           dim=axis)
    else:
        y = thfft.fft(x, n=n, dim=axis, norm=norm)

    if realflag:
        y = th.view_as_real(y)

    return y