def forward(self, kspace_outputs, targets, extra_params):
        if kspace_outputs.size(0) > 1:
            raise NotImplementedError('Only one slice at a time for now.')

        kspace_targets = targets['kspace_targets']

        # For removing width dimension padding. Recall that k-space form has 2 as last dim size.
        left = (kspace_outputs.size(-1) - kspace_targets.size(-2)) // 2
        right = left + kspace_targets.size(-2)

        # Cropping width dimension by pad.
        kspace_recons = nchw_to_kspace(kspace_outputs[..., left:right])

        assert kspace_recons.shape == kspace_targets.shape, 'Reconstruction and target sizes are different.'
        assert (kspace_recons.size(-3) % 2 == 0) and (kspace_recons.size(-2) % 2 == 0), \
            'Not impossible but not expected to have sides with odd lengths.'

        # Removing weighting.
        if self.weighted:
            weighting = extra_params['weightings']
            kspace_recons = kspace_recons / weighting

        if self.replace:  # Replace with original k-space if replace=True
            mask = extra_params['masks']
            kspace_recons = kspace_recons * (1 - mask) + kspace_targets * mask

        cmg_recons = ifft2(kspace_recons)
        img_recons = complex_abs(cmg_recons)
        recons = {
            'kspace_recons': kspace_recons,
            'cmg_recons': cmg_recons,
            'img_recons': img_recons
        }

        return recons  # Returning scaled reconstructions. Not rescaled.
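# A minimal sketch of the `replace` step used above: it is a data-consistency
# operation in which every sampled k-space location is overwritten with the measured
# value and only unsampled locations keep the network output. Shapes and variable
# names below are illustrative assumptions, not taken from the original code.
import torch

kspace_recons_demo = torch.randn(1, 15, 640, 328, 2)      # network reconstruction in k-space form
kspace_targets_demo = torch.randn(1, 15, 640, 328, 2)     # measured (zero-filled) k-space
mask_demo = (torch.rand(1, 1, 1, 328, 1) < 0.25).float()  # broadcastable column sampling mask

merged = kspace_recons_demo * (1 - mask_demo) + kspace_targets_demo * mask_demo
# Sampled columns come straight from the measurement, the rest from the model.
assert torch.equal(merged * mask_demo, kspace_targets_demo * mask_demo)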
    def forward(self, cmg_output, targets, extra_params):
        if cmg_output.size(0) > 1:
            raise NotImplementedError('Only one at a time for now.')

        cmg_target = targets['cmg_targets']
        cmg_recon = nchw_to_kspace(cmg_output)
        assert cmg_recon.shape == cmg_target.shape, 'Reconstruction and target sizes are different.'
        assert (cmg_recon.size(-3) % 2 == 0) and (cmg_recon.size(-2) % 2 == 0), \
            'Not impossible but not expected to have sides with odd lengths.'

        if self.residual_acs:  # Adding the semi-k-space of the ACS as a residual. Necessary due to complex cropping.
            raise NotImplementedError('Not ready yet.')
            # cmg_acs = targets['cmg_acss']
            # cmg_recon = cmg_recon + cmg_acs

        kspace_recon = fft2(cmg_recon)
        img_recon = complex_abs(cmg_recon)

        recons = {
            'kspace_recons': kspace_recon,
            'cmg_recons': cmg_recon,
            'img_recons': img_recon
        }

        if self.challenge == 'multicoil':
            rss_recon = center_crop(img_recon, (self.resolution, self.resolution)) * extra_params['cmg_scales']
            rss_recon = root_sum_of_squares(rss_recon, dim=1).squeeze()
            recons['rss_recons'] = rss_recon

        return recons  # recons are not rescaled except rss_recons.
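# root_sum_of_squares is not defined in this snippet. Below is a minimal sketch that
# is consistent with how it is called above (dim=1 is the coil dimension); the
# project's own implementation may differ in details.
import torch

def root_sum_of_squares_sketch(tensor, dim=1):
    """Combine per-coil magnitude images into one image: sqrt of the sum of squares over coils."""
    return torch.sqrt((tensor ** 2).sum(dim=dim))

coil_images = torch.rand(1, 15, 320, 320)             # one magnitude image per coil
rss = root_sum_of_squares_sketch(coil_images, dim=1)  # -> [1, 320, 320]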
    def forward(self, semi_kspace_outputs, targets, extra_params):
        if semi_kspace_outputs.size(0) > 1:
            raise NotImplementedError('Only one at a time for now.')

        semi_kspace_targets = targets['semi_kspace_targets']
        # For removing width dimension padding. Recall that k-space form has 2 as last dim size.
        left = (semi_kspace_outputs.size(-1) - semi_kspace_targets.size(-2)) // 2
        right = left + semi_kspace_targets.size(-2)

        # Cropping width dimension by pad.
        semi_kspace_recons = nchw_to_kspace(semi_kspace_outputs[..., left:right])

        assert semi_kspace_recons.shape == semi_kspace_targets.shape, 'Reconstruction and target sizes are different.'
        assert (semi_kspace_recons.size(-3) % 2 == 0) and (semi_kspace_recons.size(-2) % 2 == 0), \
            'Not impossible but not expected to have sides with odd lengths.'

        # Removing weighting.
        if self.weighted:
            weighting = extra_params['weightings']
            semi_kspace_recons = semi_kspace_recons / weighting

        if self.residual_acs:
            num_low_freqs = extra_params['num_low_frequency']
            acs_mask = find_acs_mask(semi_kspace_recons, num_low_freqs)
            semi_kspace_recons = semi_kspace_recons + acs_mask * semi_kspace_targets

        if self.replace:
            mask = extra_params['masks']
            semi_kspace_recons = semi_kspace_recons * (1 - mask) + semi_kspace_targets * mask

        kspace_recons = fft1(semi_kspace_recons, direction='height')
        cmg_recons = ifft1(semi_kspace_recons, direction='width')
        img_recons = complex_abs(cmg_recons)

        recons = {
            'semi_kspace_recons': semi_kspace_recons,
            'kspace_recons': kspace_recons,
            'cmg_recons': cmg_recons,
            'img_recons': img_recons
        }

        if self.challenge == 'multicoil':
            rss_recons = center_crop(img_recons, (self.resolution, self.resolution))
            rss_recons = root_sum_of_squares(rss_recons, dim=1).squeeze()
            # The inputs were divided by this scale, so it is multiplied back here.
            rss_recons *= extra_params['sk_scales']
            recons['rss_recons'] = rss_recons

        return recons  # Returning scaled reconstructions. Not rescaled. RSS images are rescaled.
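# The semi-k-space above is Fourier-transformed along one spatial axis only: an FFT
# along height turns it into full k-space, while an inverse FFT along width returns
# the complex coil images (see the fft1/ifft1 calls above). A small torch.fft sketch
# of that relationship, assuming centered orthonormal transforms; the project's own
# fft1/ifft1 conventions may differ.
import torch

def centered_fft_sketch(x, dim):
    return torch.fft.fftshift(torch.fft.fft(torch.fft.ifftshift(x, dim=dim), dim=dim, norm='ortho'), dim=dim)

def centered_ifft_sketch(x, dim):
    return torch.fft.fftshift(torch.fft.ifft(torch.fft.ifftshift(x, dim=dim), dim=dim, norm='ortho'), dim=dim)

image = torch.randn(15, 64, 32, dtype=torch.complex128)  # complex coil images
semi_kspace = centered_fft_sketch(image, dim=-1)         # k-space along width, image along height
full_kspace = centered_fft_sketch(semi_kspace, dim=-2)   # FFT along height -> full k-space

assert torch.allclose(centered_ifft_sketch(semi_kspace, dim=-1), image)
assert torch.allclose(centered_fft_sketch(centered_fft_sketch(image, dim=-2), dim=-1), full_kspace)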
    def forward(self, kspace_output, kspace_target, extra_params):
        k_scale, mask = extra_params

        # For removing width dimension padding. Recall that k-space form has 2 as last dim size.
        left = (kspace_output.size(-1) - kspace_target.size(-2)) // 2
        right = left + kspace_target.size(-2)

        # Processing to k-space form. This is where the batch_size == 1 is important.
        # 1. Crop padding. 2. Reshape to kspace shape. 3. Unweight k-space values. 4. Rescale to original scale.
        kspace_recon = exp_weighting(nchw_to_kspace(kspace_output[..., left:right]), scale=self.log_amp_scale) * k_scale

        assert kspace_recon.size() == kspace_target.size(), 'Reconstruction and target sizes do not match.'
        kspace_recon = kspace_recon * (1 - mask) + kspace_target * mask
        return kspace_recon
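# exp_weighting and log_amp_scale are project-specific and are not shown in this
# snippet. The usage above implies the network operates on amplitude-compressed
# k-space and exp_weighting undoes that compression before rescaling by k_scale.
# Purely as an illustration (an assumption, not the project's actual definition),
# a sign-preserving log/exp compression pair could look like this:
import torch

def log_weighting_sketch(tensor, scale=1.0):
    # Compress dynamic range: sign-preserving log of the scaled magnitude.
    return torch.sign(tensor) * torch.log1p(scale * tensor.abs())

def exp_weighting_sketch(tensor, scale=1.0):
    # Inverse of the compression above.
    return torch.sign(tensor) * torch.expm1(tensor.abs()) / scale

x = torch.randn(3, 4, dtype=torch.float64) * 100
y = log_weighting_sketch(x, scale=10.0)
assert torch.allclose(exp_weighting_sketch(y, scale=10.0), x)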
    def forward(self, tensor, out_shape):  # Using out_shape only works for batch size of 1.
        """
        Args:
            tensor (torch.Tensor): Input tensor of shape [batch_size, in_chans, height, width].
            out_shape (tuple): Shape [batch_size, num_coils, true_height, true_width].
                Note that in_chans = 2 * num_coils.
        Returns:
            (torch.Tensor): Image reconstruction of shape out_shape, i.e.
                [batch_size, num_coils, true_height, true_width].
        """
        stack = list()
        output = tensor
        # Apply down-sampling layers
        for layer in self.down_sample_layers:
            output = layer(output)
            stack.append(output)
            output = F.max_pool2d(output, kernel_size=2)

        output = self.conv(output)

        # Apply up-sampling layers
        for layer in self.up_sample_layers:
            output = F.interpolate(output,
                                   scale_factor=2,
                                   mode='bilinear',
                                   align_corners=False)
            output = torch.cat((output, stack.pop()), dim=1)
            output = layer(output)

        output = self.conv2(output)  # End of learning.

        # For removing width dimension padding. Recall that k-space form has 2 as last dim size.
        left = (output.size(-1) - out_shape[-1]) // 2  # This depends on mini-batch size being 1 to work.
        right = left + out_shape[-1]

        # Previously, cropping was done with [pad:-pad]. However, this fails catastrophically when pad == 0,
        # since [0:-0] becomes [0:0], which creates an empty tensor and wipes out all values.

        # Cropping width dimension by pad.
        output = output[..., left:right]

        # Processing to k-space form.
        output = nchw_to_kspace(output)

        # Convert to image.
        output = complex_abs(ifft2(output))

        assert output.size() == out_shape  # Checking just in case.
        return output
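# The comment inside the method above is why every crop in this file uses
# [..., left:right] rather than [..., pad:-pad]: when the padding happens to be zero,
# [0:-0] is read as [0:0] and the slice comes out empty. A tiny demonstration:
import torch

x = torch.arange(10).reshape(1, 1, 1, 10)
pad = 0  # no width padding was needed for this tensor

print(x[..., pad:-pad].shape)  # torch.Size([1, 1, 1, 0]) -> everything is wiped out

target_width = 10
left = (x.size(-1) - target_width) // 2
right = left + target_width
print(x[..., left:right].shape)  # torch.Size([1, 1, 1, 10]) -> correct no-op crop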
    def forward(self, kspace_output, c_img_target, extra_params):
        kspace_target, k_scale, mask = extra_params

        # For removing width dimension padding. Recall that k-space form has 2 as last dim size.
        left = (kspace_output.size(-1) - kspace_target.size(-2)) // 2
        right = left + kspace_target.size(-2)

        # Cropping width dimension by pad. Multiply by scales to restore the original scaling.
        kspace_recon = exp_weighting(nchw_to_kspace(kspace_output[..., left:right]), scale=self.log_amp_scale) * k_scale
        assert kspace_recon.size() == kspace_target.size(), 'Reconstruction and target sizes do not match.'

        kspace_recon = kspace_recon * (1 - mask) + kspace_target * mask
        c_img_recons = ifft2(kspace_recon)

        return c_img_recons
    def forward(self, kspace_outputs, targets, extra_params):
        if kspace_outputs.size(0) > 1:
            raise NotImplementedError('Only one slice at a time for now.')

        kspace_targets = targets['kspace_targets']

        # For removing width dimension padding. Recall that k-space form has 2 as last dim size.
        left = (kspace_outputs.size(-1) - kspace_targets.size(-2)) // 2
        right = left + kspace_targets.size(-2)

        # Cropping width dimension by pad.
        kspace_recons = nchw_to_kspace(kspace_outputs[..., left:right])
        assert kspace_recons.shape == kspace_targets.shape, 'Reconstruction and target sizes are different.'
        assert (kspace_recons.size(-3) % 2 == 0) and (kspace_recons.size(-2) % 2 == 0), \
            'Not impossible but not expected to have sides with odd lengths.'

        # Removing weighting.
        if self.weighted:
            weighting = extra_params['weightings']
            kspace_recons = kspace_recons / weighting

        if self.residual_acs:
            num_low_freqs = extra_params['num_low_frequency']
            acs_mask = find_acs_mask(kspace_recons, num_low_freqs)
            kspace_recons = kspace_recons + acs_mask * kspace_targets

        if self.replace:  # Replace with original k-space if replace=True
            mask = extra_params['masks']
            kspace_recons = kspace_recons * (1 - mask) + kspace_targets * mask

        cmg_recons = ifft2(kspace_recons)
        img_recons = complex_abs(cmg_recons)
        recons = {
            'kspace_recons': kspace_recons,
            'cmg_recons': cmg_recons,
            'img_recons': img_recons
        }

        if img_recons.size(1) == 15:  # Multi-coil data (15 coils); build the RSS image.
            top = (img_recons.size(-2) - self.resolution) // 2
            left = (img_recons.size(-1) - self.resolution) // 2
            rss_recon = img_recons[:, :, top:top + self.resolution, left:left + self.resolution]
            rss_recon = root_sum_of_squares(rss_recon, dim=1).squeeze()  # rss_recon is 2D after the squeeze.
            recons['rss_recons'] = rss_recon

        return recons  # Returning scaled reconstructions. Not rescaled.
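# The branch above crops the image center inline; earlier methods call a center_crop
# helper for the same purpose. A minimal sketch consistent with that usage (the
# project's own helper may handle edge cases differently):
import torch

def center_crop_sketch(data, shape):
    """Crop the last two dimensions of `data` to `shape` around the center."""
    assert 0 < shape[0] <= data.size(-2) and 0 < shape[1] <= data.size(-1)
    top = (data.size(-2) - shape[0]) // 2
    left = (data.size(-1) - shape[1]) // 2
    return data[..., top:top + shape[0], left:left + shape[1]]

img = torch.rand(1, 15, 640, 368)
print(center_crop_sketch(img, (320, 320)).shape)  # torch.Size([1, 15, 320, 320])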
    def _cmg_output(cmg_output, targets, extra_params):
        cmg_target = targets['cmg_targets']
        cmg_recon = nchw_to_kspace(cmg_output)  # Assumes data was cropped already.
        assert cmg_recon.shape == cmg_target.shape, 'Reconstruction and target sizes are different.'
        assert (cmg_recon.size(-3) % 2 == 0) and (cmg_recon.size(-2) % 2 == 0), \
            'Not impossible but not expected to have sides with odd lengths.'
        cmg_recon = cmg_recon + targets['cmg_inputs']  # Residual of complex input.
        kspace_recon = fft2(cmg_recon)
        img_recon = complex_abs(cmg_recon)
        recons = {
            'kspace_recons': kspace_recon,
            'cmg_recons': cmg_recon,
            'img_recons': img_recon
        }
        return recons
    def forward(self, kspace_output, kspace_target, extra_params):
        if not kspace_output.size(0) == 1:
            raise NotImplementedError('Only single batch for now.')

        scaling, mask = extra_params

        # For removing width dimension padding. Recall that k-space form has 2 as last dim size.
        left = (kspace_output.size(-1) - kspace_target.size(-2)) // 2
        right = left + kspace_target.size(-2)

        # Cropping width dimension by pad. Multiply by scales to restore the original scaling.
        k_output = kspace_output[..., left:right] * scaling

        # Processing to k-space form. This is where the batch_size == 1 is important.
        kspace_recon = nchw_to_kspace(k_output)

        assert kspace_recon.size() == kspace_target.size(), 'Reconstruction and target sizes do not match.'

        return kspace_recon
    def forward(self, kspace_output, c_img_target, extra_params):
        kspace_target, scaling, mask = extra_params

        # For removing width dimension padding. Recall that k-space form has 2 as last dim size.
        left = (kspace_output.size(-1) - kspace_target.size(-2)) // 2
        right = left + kspace_target.size(-2)

        # Cropping width dimension by pad. Multiply by scales to restore the original scaling.
        k_output = kspace_output[..., left:right] * scaling

        # Processing to k-space form. This is where the batch_size == 1 is important.
        kspace_recon = nchw_to_kspace(k_output)

        assert kspace_recon.size() == kspace_target.size(), 'Reconstruction and target sizes do not match.'

        kspace_recon = kspace_recon * (1 - mask) + kspace_target * mask

        c_img_recons = ifft2(kspace_recon)

        return c_img_recons
    def forward(self, semi_kspace_outputs, targets, extra_params):
        if semi_kspace_outputs.size(0) > 1:
            raise NotImplementedError('Only one batch at a time for now.')

        semi_kspace_targets = targets['semi_kspace_targets']
        # For removing width dimension padding. Recall that k-space form has 2 as last dim size.
        left = (semi_kspace_outputs.size(-1) - semi_kspace_targets.size(-2)) // 2
        right = left + semi_kspace_targets.size(-2)

        # Cropping width dimension by pad.
        semi_kspace_recons = nchw_to_kspace(semi_kspace_outputs[..., left:right])

        assert semi_kspace_recons.shape == semi_kspace_targets.shape, 'Reconstruction and target sizes are different.'
        assert (semi_kspace_recons.size(-3) % 2 == 0) and (semi_kspace_recons.size(-2) % 2 == 0), \
            'Not impossible but not expected to have sides with odd lengths.'

        # Removing weighting.
        if self.weighted:
            weighting = extra_params['weightings']
            semi_kspace_recons = semi_kspace_recons / weighting

        if self.replace:
            mask = extra_params['masks']
            semi_kspace_recons = semi_kspace_recons * (1 - mask) + semi_kspace_targets * mask

        kspace_recons = fft1(semi_kspace_recons, direction=self.direction)
        cmg_recons = ifft1(semi_kspace_recons, direction=self.recon_direction)
        img_recons = complex_abs(cmg_recons)

        recons = {
            'semi_kspace_recons': semi_kspace_recons,
            'kspace_recons': kspace_recons,
            'cmg_recons': cmg_recons,
            'img_recons': img_recons
        }

        return recons  # Returning scaled reconstructions. Not rescaled.
    def forward(self, semi_kspace_outputs, targets, extra_params):
        if semi_kspace_outputs.size(0) > 1:
            raise NotImplementedError('Only one at a time for now.')

        semi_kspace_recons = nchw_to_kspace(semi_kspace_outputs)
        semi_kspace_targets = targets['semi_kspace_targets']
        assert semi_kspace_recons.shape == semi_kspace_targets.shape, 'Reconstruction and target sizes are different.'
        assert (semi_kspace_recons.size(-3) % 2 == 0) and (semi_kspace_recons.size(-2) % 2 == 0), \
            'Not impossible but not expected to have sides with odd lengths.'

        # Removing weighting.
        if self.weighted:
            weighting = extra_params['weightings']
            semi_kspace_recons = semi_kspace_recons / weighting

        if self.residual_acs:  # Adding the semi-k-space of the ACS as a residual. Necessary due to complex cropping.
            semi_kspace_acs = targets['semi_kspace_acss']
            semi_kspace_recons = semi_kspace_recons + semi_kspace_acs

        kspace_recons = fft1(semi_kspace_recons, direction='height')
        cmg_recons = ifft1(semi_kspace_recons, direction='width')
        img_recons = complex_abs(cmg_recons)

        recons = {
            'semi_kspace_recons': semi_kspace_recons,
            'kspace_recons': kspace_recons,
            'cmg_recons': cmg_recons,
            'img_recons': img_recons
        }

        if self.challenge == 'multicoil':
            rss_recons = root_sum_of_squares(img_recons, dim=1).squeeze()
            rss_recons *= extra_params['sk_scales']
            recons['rss_recons'] = rss_recons

        return recons  # Returning scaled reconstructions. Not rescaled. RSS images are rescaled.
    def forward(self, k_output, targets, scales):
        """
            Output post-processing for output k-space tensor with batch size of 1.
            This is experimental and is subject to change.
            Planning on taking k-space outputs from CNNs, then transforming them into original k-space shapes.
            No batch size planned yet.

            Args:
                k_output (torch.Tensor): CNN output of k-space. Expected to have batch-size of 1.
                targets (torch.Tensor): Target image domain data.
                scales (torch.Tensor): scaling factor used to divide the input k-space slice data.
            Returns:
                kspace (torch.Tensor): kspace in original shape with batch dimension in-place.
        """

        # For removing width dimension padding. Recall that k-space form has 2 as last dim size.
        left = (k_output.size(-1) - targets.size(-1)) // 2  # This depends on mini-batch size being 1 to work.
        right = left + targets.size(-1)

        # Previously, cropping was done with [pad:-pad]. However, this fails catastrophically when pad == 0,
        # since [0:-0] becomes [0:0], which creates an empty tensor and wipes out all values.

        # Cropping width dimension by pad. Multiply by scales to restore the original scaling.
        k_output = k_output[..., left:right] * scales

        # Processing to k-space form. This is where the batch_size == 1 is important.
        kspace_recons = nchw_to_kspace(k_output)

        # Convert to image.
        image_recons = complex_abs(ifft2(kspace_recons))

        assert image_recons.size() == targets.size(), 'Reconstruction and target sizes do not match.'

        return image_recons, kspace_recons
    def restore_orig_shape(k_slice, target_slice):
        # Compute the width padding that was added and crop it away. Slicing with
        # [left_pad:left_pad + width] instead of [left_pad:-right_pad] avoids the
        # empty-slice failure when the padding happens to be zero.
        left_pad = (k_slice.size(-1) - target_slice.size(-1)) // 2
        k_slice = k_slice[..., left_pad:left_pad + target_slice.size(-1)]
        return nchw_to_kspace(k_slice)
import numpy as np
import torch
from time import time

# to_tensor, kspace_to_nchw, k_slice_to_chw and nchw_to_kspace are the project's own
# k-space <-> NCHW conversion utilities and must be imported from its data transform
# module for this script to run.

k1 = np.random.uniform(size=(32, 15, 640, 328))
k2 = np.random.uniform(size=(32, 15, 640, 328))

k = k1 + k2 * 1j  # Complex k-space data: 32 slices, 15 coils.

kt = to_tensor(k)  # [32, 15, 640, 328, 2] with real/imag stacked in the last dimension.

tic = time()
nchw = kspace_to_nchw(kt)  # [32, 30, 640, 328]
mid = kt.shape[1]  # Number of coils (not used below).

# The batched conversion must match the per-slice conversion.
for idx, kts in enumerate(kt):
    temp = k_slice_to_chw(kts)
    print(idx, torch.eq(nchw[idx], temp).all())

# Channel 2 * coil + component: channel 17 should be the imaginary part of coil 8.
chan = 17
ri = chan % 2
sli = chan // 2

print(torch.eq(torch.squeeze(nchw[3, chan, ...]), torch.squeeze(kt[3, sli, ..., ri])).all())

# nchw_to_kspace must invert kspace_to_nchw exactly.
kspace = nchw_to_kspace(nchw)
toc = time() - tic

print(torch.eq(kt, kspace).all(), toc)
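# The script above verifies that kspace_to_nchw interleaves the real and imaginary
# parts of each coil into channels (channel 2 * coil + component, as the chan % 2 and
# chan // 2 indexing shows) and that nchw_to_kspace inverts it exactly. Minimal
# sketches consistent with those checks (the project's own implementations may differ):
import numpy as np
import torch

def to_tensor_sketch(array):
    # Complex ndarray -> real tensor with a trailing dimension of size 2 (real, imag).
    return torch.from_numpy(np.stack((array.real, array.imag), axis=-1))

def kspace_to_nchw_sketch(tensor):
    # [N, C, H, W, 2] -> [N, 2C, H, W]; channel 2*c + r holds component r of coil c.
    n, c, h, w, _ = tensor.shape
    return tensor.permute(0, 1, 4, 2, 3).reshape(n, 2 * c, h, w)

def nchw_to_kspace_sketch(tensor):
    # [N, 2C, H, W] -> [N, C, H, W, 2]; inverse of the function above.
    n, c2, h, w = tensor.shape
    return tensor.reshape(n, c2 // 2, 2, h, w).permute(0, 1, 3, 4, 2)

small = np.random.uniform(size=(2, 3, 8, 6)) + 1j * np.random.uniform(size=(2, 3, 8, 6))
kt_small = to_tensor_sketch(small)
assert torch.equal(nchw_to_kspace_sketch(kspace_to_nchw_sketch(kt_small)), kt_small)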