def _visualize_outputs(c_img_recons, c_img_targets, smoothing_factor=8):
    """Build visualisation grids for complex-image reconstructions and targets.

    Args:
        c_img_recons: Complex reconstructed images (layout assumed to be what
            ``complex_abs``/``fft2`` expect — confirm against those helpers).
        c_img_targets: Complex target images in the same layout.
        smoothing_factor: Passed unchanged to ``make_k_grid`` for both k-space grids.

    Returns:
        ``(kspace_recons, kspace_targets, image_recons, image_targets, image_deltas)``:
        the two k-space grids, followed by the three grids returned by
        ``make_grid_triplet`` for the image magnitudes.
    """
    magnitude_recons = complex_abs(c_img_recons)
    magnitude_targets = complex_abs(c_img_targets)

    # Each k-space grid is built from the Fourier transform of the complex image.
    kspace_recons, kspace_targets = (
        make_k_grid(fft2(c_img), smoothing_factor)
        for c_img in (c_img_recons, c_img_targets)
    )

    image_recons, image_targets, image_deltas = make_grid_triplet(
        magnitude_recons, magnitude_targets)
    return kspace_recons, kspace_targets, image_recons, image_targets, image_deltas
    def forward(self, cmg_output, targets, extra_params):
        """Turn the network's complex-image output into a dict of reconstructions.

        Args:
            cmg_output: Network output; converted with ``nchw_to_kspace`` before use.
                Only a leading (batch) size of 1 is supported.
            targets: Dict holding at least ``'cmg_targets'``.
            extra_params: Dict holding ``'cmg_scales'`` (read only for the
                multi-coil challenge).

        Returns:
            Dict with ``'kspace_recons'``, ``'cmg_recons'`` and ``'img_recons'``, plus
            ``'rss_recons'`` for the multi-coil challenge. Only ``'rss_recons'`` is
            multiplied by ``cmg_scales``.

        Raises:
            NotImplementedError: If the batch size exceeds 1 or ``self.residual_acs`` is set.
        """
        if cmg_output.size(0) > 1:
            raise NotImplementedError('Only one at a time for now.')

        cmg_target = targets['cmg_targets']
        cmg_recon = nchw_to_kspace(cmg_output)
        assert cmg_recon.shape == cmg_target.shape, 'Reconstruction and target sizes are different.'
        assert (cmg_recon.size(-3) % 2 == 0) and (cmg_recon.size(-2) % 2 == 0), \
            'Not impossible but not expected to have sides with odd lengths.'

        # The residual-ACS path (adding the semi-k-space of the ACS, needed because of
        # complex cropping) has not been implemented yet.
        if self.residual_acs:
            raise NotImplementedError('Not ready yet.')

        recons = {
            'kspace_recons': fft2(cmg_recon),
            'cmg_recons': cmg_recon,
            'img_recons': complex_abs(cmg_recon),
        }

        if self.challenge == 'multicoil':
            crop_shape = (self.resolution, self.resolution)
            scaled_crop = center_crop(recons['img_recons'], crop_shape) * extra_params['cmg_scales']
            recons['rss_recons'] = root_sum_of_squares(scaled_crop, dim=1).squeeze()

        # Everything except 'rss_recons' is left unscaled.
        return recons
    def forward(self, outputs, targets, extra_params):
        """Combine a predicted magnitude and phase into a dict of reconstructions.

        Args:
            outputs: Pair ``(img_recon, phase_recon)``. The phase is expected to carry
                the ``+pi`` offset added by the input transform; it is removed here.
            targets: Dict holding at least ``'img_targets'``.
            extra_params: Dict holding ``'img_scales'`` (read only for the
                multi-coil challenge).

        Returns:
            Dict with ``'img_recons'``, ``'phase_recons'``, ``'cmg_recons'`` and
            ``'kspace_recons'``, plus ``'rss_recons'`` for multi-coil (the only entry
            multiplied by ``img_scales``).

        Raises:
            NotImplementedError: If the batch size exceeds 1.
        """
        magnitude, phase = outputs

        if magnitude.size(0) > 1:
            raise NotImplementedError('Only one at a time for now.')

        img_target = targets['img_targets']
        # Magnitudes cannot be negative, so clip them with ReLU.
        magnitude = F.relu(magnitude)
        assert magnitude.shape == img_target.shape, 'Reconstruction and target sizes are different.'

        # Undo the +pi shift from the input transform. No clamping: the loss is MSE.
        phase = phase - math.pi
        real_part = magnitude * torch.cos(phase)
        imag_part = magnitude * torch.sin(phase)
        cmg_recon = torch.stack([real_part, imag_part], dim=-1)

        recons = {
            'img_recons': magnitude,
            'phase_recons': phase,
            'cmg_recons': cmg_recon,
            'kspace_recons': fft2(cmg_recon),
        }

        if self.challenge == 'multicoil':
            crop_shape = (self.resolution, self.resolution)
            scaled_crop = center_crop(magnitude, shape=crop_shape) * extra_params['img_scales']
            recons['rss_recons'] = root_sum_of_squares(scaled_crop, dim=1).squeeze()

        # Everything except 'rss_recons' is left unscaled.
        return recons
def test_fft2(shape):
    """Check ``data_transforms.fft2`` against a centred, orthonormal NumPy FFT.

    ``shape`` is expected to be a list (``shape + [2]`` is used); a trailing size-2
    axis is appended for the real/imaginary parts.
    """
    tensor = create_tensor(shape + [2])

    torch_result = data_transforms.fft2(tensor).numpy()
    torch_complex = torch_result[..., 0] + 1j * torch_result[..., 1]

    # Reference: ifftshift -> orthonormal fft2 -> fftshift over the last two axes.
    axes = (-2, -1)
    reference = np.fft.fftshift(
        np.fft.fft2(
            np.fft.ifftshift(data_transforms.tensor_to_complex_np(tensor), axes),
            norm='ortho'),
        axes)

    assert np.allclose(torch_complex, reference)
 def _cmg_output(cmg_output, targets, extra_params):
     """Build reconstructions from a network output treated as a complex-image residual.

     The output is converted with ``nchw_to_kspace`` and added to
     ``targets['cmg_inputs']``. ``extra_params`` is accepted but not used.

     Returns:
         Dict with ``'kspace_recons'``, ``'cmg_recons'`` and ``'img_recons'``.
     """
     cmg_target = targets['cmg_targets']
     # Assumes data was cropped already.
     residual = nchw_to_kspace(cmg_output)
     assert residual.shape == cmg_target.shape, 'Reconstruction and target sizes are different.'
     assert (residual.size(-3) % 2 == 0) and (residual.size(-2) % 2 == 0), \
         'Not impossible but not expected to have sides with odd lengths.'

     # The network predicts a residual on top of the complex input.
     cmg_recon = residual + targets['cmg_inputs']
     return {
         'kspace_recons': fft2(cmg_recon),
         'cmg_recons': cmg_recon,
         'img_recons': complex_abs(cmg_recon),
     }
Exemple #6
0
    def forward(self, cmg_output: Tensor, targets: dict, extra_params: dict):
        """Turn a 5-D, channel-second complex network output into reconstructions.

        Args:
            cmg_output: 5-D tensor whose dim 1 has size 2 (asserted). It is permuted to
                move that axis last. Only batch size 1 is supported.
            targets: Dict holding ``'kspace_targets'``.
            extra_params: Dict holding ``'masks'`` (used when ``self.replace_kspace`` is
                set) and ``'cmg_scales'`` (used for the multi-coil challenge).

        Returns:
            Dict containing ``'rss_recons'`` for the multi-coil challenge; empty
            otherwise.

        Raises:
            NotImplementedError: If the batch size exceeds 1.
        """
        assert cmg_output.dim() == 5 and cmg_output.size(1) == 2, 'Invalid shape!'
        if cmg_output.size(0) > 1:
            raise NotImplementedError('Only one at a time for now.')

        kspace_target = targets['kspace_targets']
        # Move the size-2 channel axis (dim 1) to the end.
        cmg_recon = cmg_output.permute(dims=(0, 2, 3, 4, 1))

        # On a shape mismatch, centre-crop the recon along dim -2 to the target's size.
        if cmg_recon.shape != kspace_target.shape:
            target_width = kspace_target.size(-2)
            start = (cmg_recon.size(-2) - target_width) // 2
            cmg_recon = cmg_recon[..., start:start + target_width, :]

        assert cmg_recon.shape == kspace_target.shape, 'Reconstruction and target sizes are different.'
        assert (cmg_recon.size(-3) % 2 == 0) and (cmg_recon.size(-2) % 2 == 0), \
            'Not impossible but not expected to have sides with odd lengths.'

        kspace_recon = fft2(cmg_recon)
        if self.replace_kspace:
            # Take target values where mask applies and predicted values elsewhere.
            mask = extra_params['masks']
            kspace_recon = kspace_target * mask + (1 - mask) * kspace_recon
            cmg_recon = ifft2(kspace_recon)

        img_recon = complex_abs(cmg_recon)

        # NOTE(review): kspace/cmg/img recons are computed but not returned, so the
        # single-coil challenge gets an empty dict — confirm this is intended.
        recons = {}
        if self.challenge == 'multicoil':
            crop_shape = (self.resolution, self.resolution)
            scaled_crop = center_crop(img_recon, crop_shape) * extra_params['cmg_scales']
            recons['rss_recons'] = root_sum_of_squares(scaled_crop, dim=1).squeeze()

        # Only 'rss_recons' is rescaled.
        return recons
import torch
import numpy as np
from data.data_transforms import ifft2, fft2

# Round-trip precision check: scale a random tensor up, apply ifft2 then fft2,
# scale back down, and print statistics of the element-wise ratio to the
# original (ideally close to 1).
# The experiment was written for 'cuda:1'; fall back to another device when that
# GPU is absent so the script does not crash on single-GPU or CPU-only machines.
if torch.cuda.device_count() > 1:
    device = 'cuda:1'
elif torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

a = torch.rand(20, 40, 60, 92, 2, device=device)
b = ifft2(a * 1E8)
c = fft2(b) * 1E-8

# eps keeps the division finite where an element of `a` is exactly 0.
eps = np.finfo(np.float64).eps
# Compute the ratio once rather than once per statistic.
ratio = c / (a + eps)

print(torch.max(ratio))
print(torch.min(ratio))
print(torch.mean(ratio).cpu().numpy())
def check_invertible():
    """Print whether ``ifft2(fft2(x) * 100) / 100`` recovers ``x`` per ``torch.allclose``.

    ``x`` is a random float64 tensor of shape ``(4, 6, 8, 12, 2)`` in ``[-64, 960)``.
    """
    original = torch.rand(4, 6, 8, 12, 2, dtype=torch.float64) * 1024 - 64
    round_trip = ifft2(fft2(original) * 100) / 100
    print(torch.allclose(original, round_trip))
import torch
from data.data_transforms import ifft2, fft2, complex_abs

# Check whether flipping an image (dim -2, dim -3, or both) leaves the summed
# k-space magnitude unchanged. Prints three booleans: (lr, ud, both) vs. unflipped.
# NOTE(review): the last dim is presumably real/imaginary, making -3/-2 the
# height/width axes — confirm against fft2/complex_abs.
image = torch.rand(10, 20, 30, 2)
lr_flip = torch.flip(image, dims=[-2])
ud_flip = torch.flip(image, dims=[-3])
all_flip = torch.flip(image, dims=[-3, -2])

# k-space of the original and of each flipped variant.
kspace = fft2(image)
lr_kspace = fft2(lr_flip)
ud_kspace = fft2(ud_flip)
all_kspace = fft2(all_flip)

# Total magnitude over every k-space sample.
absolute = torch.sum(complex_abs(kspace))
lr_abs = torch.sum(complex_abs(lr_kspace))
ud_abs = torch.sum(complex_abs(ud_kspace))
all_abs = torch.sum(complex_abs(all_kspace))

# Compare each flipped total to the unflipped one (default allclose tolerances).
a = torch.allclose(absolute, lr_abs)
b = torch.allclose(absolute, ud_abs)
c = torch.allclose(absolute, all_abs)

print(a, b, c)


    def __call__(self, kspace_target, target, attrs, file_name, slice_num):
        """Build semi-k-space network inputs, targets and extra parameters for one slice.

        Processing steps:
            1. Mask the k-space with ``self.mask_func`` (via ``apply_info_mask``).
            2. Inverse-transform the masked k-space and centre-crop it to
               ``(self.resolution, self.resolution)``.
            3. Transform along the width with ``fft1`` and weight the result with
               ``self.weight_func``.
            4. Divide by its standard deviation (``sk_scale``) and convert with
               ``kspace_to_nchw``.

        Args:
            kspace_target: Fully-sampled k-space tensor with 3, 4 or 5 dims. It is
                expanded to 5 dims with a leading batch dim of 1.
            target: Stored as ``targets['rss_targets']`` for the multi-coil challenge.
            attrs: Merged into ``extra_params``.
            file_name: When ``self.use_seed`` is set, its character codes seed the mask.
            slice_num: Not used in this method.

        Returns:
            ``(inputs, targets, extra_params)``.

        Raises:
            AssertionError: If ``kspace_target`` is not a tensor.
            RuntimeError: If ``kspace_target`` does not have 3, 4 or 5 dims.
            NotImplementedError: If the batch size is not 1.
        """
        assert isinstance(
            kspace_target, torch.Tensor
        ), 'k-space target was expected to be a Pytorch Tensor.'
        if kspace_target.dim(
        ) == 3:  # If the collate function does not expand dimensions for single-coil.
            kspace_target = kspace_target.expand(1, 1, -1, -1, -1)
        elif kspace_target.dim(
        ) == 4:  # If the collate function does not expand dimensions for multi-coil.
            kspace_target = kspace_target.expand(1, -1, -1, -1, -1)
        elif kspace_target.dim(
        ) != 5:  # Expanded k-space should have 5 dimensions.
            raise RuntimeError('k-space target has invalid shape!')

        if kspace_target.size(0) != 1:
            raise NotImplementedError('Batch size should be 1 for now.')

        with torch.no_grad():
            # A per-file seed makes the mask deterministic for a given file name.
            seed = None if not self.use_seed else tuple(map(ord, file_name))
            masked_kspace, mask, info = apply_info_mask(
                kspace_target, self.mask_func, seed)

            # Build the ACS-only k-space (mask from find_acs_mask, presumably the
            # central low-frequency lines — confirm), then crop and transform it
            # along the width like the main input.
            num_low_freqs = info['num_low_frequency']
            acs_mask = self.find_acs_mask(kspace_target, num_low_freqs)
            acs_kspace = kspace_target * acs_mask
            semi_kspace_acs = fft1(complex_center_crop(
                ifft2(acs_kspace), shape=(self.resolution, self.resolution)),
                                   direction='width')

            complex_image = ifft2(masked_kspace)
            complex_image = complex_center_crop(complex_image,
                                                shape=(self.resolution,
                                                       self.resolution))
            # img_input is not actually an input but what the input would look like in the image domain.
            # NOTE(review): img_input is never divided by sk_scale, while img_target
            # below is — confirm the two are not meant to be compared on one scale.
            img_input = complex_abs(complex_image)

            # Direction is fixed due to challenge conditions.
            semi_kspace = fft1(complex_image, direction='width')
            weighting = self.weight_func(semi_kspace)
            semi_kspace *= weighting

            # Normalise the weighted semi-k-space; the same scale is applied to the targets.
            sk_scale = torch.std(semi_kspace)
            semi_kspace /= sk_scale
            inputs = kspace_to_nchw(semi_kspace)

            extra_params = {
                'sk_scales': sk_scale,
                'masks': mask,
                'weightings': weighting
            }
            extra_params.update(info)
            extra_params.update(attrs)

            # Recall that the Fourier transform is a linear transform.
            # Targets are derived from the cropped, scaled complex image so all of them
            # share the input's crop and scale.
            cmg_target = ifft2(kspace_target)
            cmg_target = complex_center_crop(cmg_target,
                                             shape=(self.resolution,
                                                    self.resolution))
            cmg_target /= sk_scale
            img_target = complex_abs(cmg_target)
            semi_kspace_target = fft1(cmg_target, direction='width')
            kspace_target = fft2(cmg_target)

            # Use plurals as keys to reduce confusion.
            targets = {
                'semi_kspace_targets': semi_kspace_target,
                'kspace_targets': kspace_target,
                'cmg_targets': cmg_target,
                'img_targets': img_target,
                'img_inputs': img_input,
                'semi_kspace_acss': semi_kspace_acs
            }

            if self.challenge == 'multicoil':
                targets['rss_targets'] = target

        return inputs, targets, extra_params
    def __call__(self, kspace_target, target, attrs, file_name, slice_num):
        """Build complex-image network inputs, targets and extra parameters for one slice.

        The masked k-space is inverse-transformed, optionally centre-cropped to
        ``(self.resolution, self.resolution)``, and divided by its standard deviation
        (``cmg_scale``). The targets are scaled by the same factor.

        Args:
            kspace_target: Fully-sampled k-space tensor with 3, 4 or 5 dims. It is
                expanded to 5 dims with a leading batch dim of 1. The caller's tensor
                is not modified.
            target: Image-domain target. It is flipped together with the images when
                augmenting, and stored as ``targets['rss_targets']`` for multi-coil.
            attrs: Merged into ``extra_params``.
            file_name: When ``self.use_seed`` is set, its character codes seed the mask.
            slice_num: Not used in this method.

        Returns:
            ``(inputs, targets, extra_params)``, where ``inputs`` is the scaled masked
            complex image with its magnitude appended on the last dim, converted with
            ``kspace_to_nchw``. ``targets['kspace_targets']`` always matches
            ``targets['cmg_targets']`` in crop and flip.

        Raises:
            AssertionError: If ``kspace_target`` is not a tensor.
            RuntimeError: If ``kspace_target`` does not have 3, 4 or 5 dims.
            NotImplementedError: If the batch size is not 1.
        """
        assert isinstance(
            kspace_target, torch.Tensor
        ), 'k-space target was expected to be a Pytorch Tensor.'
        if kspace_target.dim() == 3:  # Collate function did not expand single-coil dims.
            kspace_target = kspace_target.expand(1, 1, -1, -1, -1)
        elif kspace_target.dim() == 4:  # Collate function did not expand multi-coil dims.
            kspace_target = kspace_target.expand(1, -1, -1, -1, -1)
        elif kspace_target.dim() != 5:  # Expanded k-space should have 5 dimensions.
            raise RuntimeError('k-space target has invalid shape!')

        if kspace_target.size(0) != 1:
            raise NotImplementedError('Batch size should be 1 for now.')

        with torch.no_grad():
            # Apply mask; a per-file seed makes it deterministic for a given file name.
            seed = None if not self.use_seed else tuple(map(ord, file_name))
            masked_kspace, mask, info = apply_info_mask(
                kspace_target, self.mask_func, seed)

            # Complex image made from down-sampled k-space.
            complex_image = ifft2(masked_kspace)
            if self.crop_center:
                complex_image = complex_center_crop(
                    complex_image, shape=(self.resolution, self.resolution))

            cmg_scale = torch.std(complex_image)
            complex_image /= cmg_scale  # Fresh tensor from ifft2, safe to modify in place.

            extra_params = {'cmg_scales': cmg_scale, 'masks': mask}
            extra_params.update(info)
            extra_params.update(attrs)

            # The Fourier transform is linear, so scaling the k-space scales the image
            # by the same factor. Divide out of place: `kspace_target` is either the
            # caller's tensor or an `expand` view sharing its storage, so `/=` would
            # silently modify the input data.
            kspace_target = kspace_target / cmg_scale
            cmg_target = ifft2(kspace_target)
            if self.crop_center:
                cmg_target = complex_center_crop(
                    cmg_target, shape=(self.resolution, self.resolution))

            # Data augmentation by flipping up-down and/or left-right. No rotation.
            flipped = False
            if self.augment_data:
                flip_lr = torch.rand(()) < 0.5
                flip_ud = torch.rand(()) < 0.5
                # Complex tensors end in (height, width, real/imag); `target` has
                # only (height, width), so its dims are shifted by one.
                image_dims = []
                target_dims = []
                if flip_ud:
                    image_dims.append(-3)
                    target_dims.append(-2)
                if flip_lr:
                    image_dims.append(-2)
                    target_dims.append(-1)
                if image_dims:
                    complex_image = torch.flip(complex_image, dims=image_dims)
                    cmg_target = torch.flip(cmg_target, dims=image_dims)
                    target = torch.flip(target, dims=target_dims)
                    flipped = True

            # Rebuild the k-space target whenever cmg_target was cropped or flipped.
            # Otherwise its shape (cropped vs. uncropped) would depend on whether a
            # random flip happened.
            if self.crop_center or flipped:
                kspace_target = fft2(cmg_target)

            # Image-domain tensors come from the (possibly flipped) complex tensors,
            # so they need no separate flipping.
            img_target = complex_abs(cmg_target)
            img_inputs = complex_abs(complex_image)

            # Use plurals as keys to reduce confusion.
            targets = {
                'kspace_targets': kspace_target,
                'cmg_targets': cmg_target,
                'img_targets': img_target,
                'cmg_inputs': complex_image,
                'img_inputs': img_inputs
            }

            if self.challenge == 'multicoil':
                targets['rss_targets'] = target

            # Concatenate real/imag/abs channels on the last dim, then convert for the CNN.
            concat_image = torch.cat(
                [complex_image, img_inputs.unsqueeze(dim=-1)], dim=-1)
            inputs = kspace_to_nchw(concat_image)

        return inputs, targets, extra_params
Exemple #12
0
    def __call__(self, kspace_target, target, attrs, file_name, slice_num):
        """Build magnitude/phase network inputs, targets and extra parameters for one slice.

        Args:
            kspace_target: Fully-sampled k-space tensor with 3, 4 or 5 dims. It is
                expanded to 5 dims with a leading batch dim of 1.
            target: Image-domain target. It is flipped together with the images when
                augmenting, and stored as ``targets['rss_targets']`` for multi-coil.
            attrs: Merged into ``extra_params``.
            file_name: When ``self.use_seed`` is set, its character codes seed the mask.
            slice_num: Not used in this method.

        Returns:
            ``(inputs, targets, extra_params)``, where ``inputs`` is the pair
            ``(img_input, phase_input)``:

            - ``img_input``: the magnitude of the masked image divided by its own
              standard deviation (``img_scale``).
            - ``phase_input``: its phase shifted by ``+pi``.

            The targets are divided by the same ``img_scale``.

        Raises:
            AssertionError: If ``kspace_target`` is not a tensor.
            RuntimeError: If ``kspace_target`` does not have 3, 4 or 5 dims.
            NotImplementedError: If the batch size is not 1.
        """
        assert isinstance(
            kspace_target, torch.Tensor
        ), 'k-space target was expected to be a Pytorch Tensor.'

        # The collate function may leave out the batch (and, for single-coil, coil) dims.
        ndim = kspace_target.dim()
        if ndim == 3:
            kspace_target = kspace_target.expand(1, 1, -1, -1, -1)
        elif ndim == 4:
            kspace_target = kspace_target.expand(1, -1, -1, -1, -1)
        elif ndim != 5:
            raise RuntimeError('k-space target has invalid shape!')

        if kspace_target.size(0) != 1:
            raise NotImplementedError('Batch size should be 1 for now.')

        with torch.no_grad():
            seed = tuple(map(ord, file_name)) if self.use_seed else None
            masked_kspace, mask, info = apply_info_mask(
                kspace_target, self.mask_func, seed)

            complex_image = ifft2(masked_kspace)
            cmg_target = ifft2(kspace_target)
            if self.crop_center:
                crop_shape = (self.resolution, self.resolution)
                complex_image = complex_center_crop(complex_image, shape=crop_shape)
                cmg_target = complex_center_crop(cmg_target, shape=crop_shape)

            if self.augment_data:
                # Left-right is drawn before up-down so seeded runs draw the same values.
                flip_lr = torch.rand(()) < 0.5
                flip_ud = torch.rand(()) < 0.5
                # Complex tensors end in (height, width, real/imag); `target` has only
                # (height, width), so its flip dims are shifted by one.
                image_dims = tuple(
                    dim for dim, enabled in ((-3, flip_ud), (-2, flip_lr)) if enabled)
                if image_dims:
                    target_dims = tuple(dim + 1 for dim in image_dims)
                    complex_image = torch.flip(complex_image, dims=image_dims)
                    cmg_target = torch.flip(cmg_target, dims=image_dims)
                    target = torch.flip(target, dims=target_dims)

            # Shift the phase from atan2's [-pi, pi] into [0, 2pi] for learning.
            # The output transform must subtract pi again.
            phase_input = torch.atan2(complex_image[..., 1], complex_image[..., 0]) + math.pi

            img_input = complex_abs(complex_image)
            img_scale = torch.std(img_input)
            img_input /= img_scale

            cmg_target /= img_scale
            img_target = complex_abs(cmg_target)
            # Rebuild the k-space target after cropping, augmentation and scaling.
            kspace_target = fft2(cmg_target)
            phase_target = torch.atan2(cmg_target[..., 1], cmg_target[..., 0])

            extra_params = dict(img_scales=img_scale, masks=mask)
            extra_params.update(info)
            extra_params.update(attrs)

            # Plural keys, to make clear each value is batched.
            targets = {
                'kspace_targets': kspace_target,
                'cmg_targets': cmg_target,
                'img_targets': img_target,
                'phase_targets': phase_target,
                'img_inputs': img_input,
            }

            if self.challenge == 'multicoil':
                targets['rss_targets'] = target

            # The network receives a (magnitude, phase) pair; no NCHW conversion happens here.
            inputs = img_input, phase_input

        return inputs, targets, extra_params