def forward(self, semi_kspace_outputs, targets, extra_params):
        """Convert network outputs into semi-k-space, k-space and image reconstructions.

        The output is centre-cropped along its last (width) axis to the width of
        ``targets['semi_kspace_targets']``, converted with ``nchw_to_kspace`` and
        then post-processed according to instance flags, in this order:

        * ``self.weighted``: divide by ``extra_params['weightings']``.
        * ``self.residual_acs``: add ``acs_mask * semi_kspace_targets``, where
          ``acs_mask = find_acs_mask(recons, extra_params['num_low_frequency'])``.
        * ``self.replace``: blend in the targets as
          ``recons * (1 - mask) + targets * mask`` with ``mask = extra_params['masks']``.

        Args:
            semi_kspace_outputs: Network output; only a batch size of 1 is supported.
            targets: Dict holding at least ``'semi_kspace_targets'``.
            extra_params: Dict of per-sample values (see above, plus
                ``'sk_scales'`` for the multi-coil challenge).

        Returns:
            Dict with ``'semi_kspace_recons'``, ``'kspace_recons'``, ``'cmg_recons'``
            and ``'img_recons'`` (still scaled, not rescaled). When
            ``self.challenge == 'multicoil'`` it also holds ``'rss_recons'``, the
            centre-cropped root-sum-of-squares image multiplied back by
            ``extra_params['sk_scales']``.

        Raises:
            NotImplementedError: If the batch size is larger than 1.
        """
        if semi_kspace_outputs.size(0) > 1:
            raise NotImplementedError('Only one at a time for now.')

        semi_kspace_targets = targets['semi_kspace_targets']
        # For removing width dimension padding. The outputs have width as their last
        # dimension, while the k-space targets end in (..., width, 2).
        left = (semi_kspace_outputs.size(-1) -
                semi_kspace_targets.size(-2)) // 2
        right = left + semi_kspace_targets.size(-2)

        # Cropping width dimension by pad.
        semi_kspace_recons = nchw_to_kspace(semi_kspace_outputs[..., left:right])

        assert semi_kspace_recons.shape == semi_kspace_targets.shape, 'Reconstruction and target sizes are different.'
        assert (semi_kspace_recons.size(-3) % 2 == 0) and (semi_kspace_recons.size(-2) % 2 == 0), \
            'Not impossible but not expected to have sides with odd lengths.'

        # Removing weighting.
        if self.weighted:
            weighting = extra_params['weightings']
            semi_kspace_recons = semi_kspace_recons / weighting

        if self.residual_acs:
            num_low_freqs = extra_params['num_low_frequency']
            acs_mask = find_acs_mask(semi_kspace_recons, num_low_freqs)
            semi_kspace_recons = semi_kspace_recons + acs_mask * semi_kspace_targets

        if self.replace:
            mask = extra_params['masks']
            semi_kspace_recons = semi_kspace_recons * (
                1 - mask) + semi_kspace_targets * mask

        # Forward FFT along height undoes the height-wise inverse FFT of the
        # semi-k-space; inverse FFT along width completes the image transform.
        kspace_recons = fft1(semi_kspace_recons, direction='height')
        cmg_recons = ifft1(semi_kspace_recons, direction='width')
        img_recons = complex_abs(cmg_recons)

        recons = {
            'semi_kspace_recons': semi_kspace_recons,
            'kspace_recons': kspace_recons,
            'cmg_recons': cmg_recons,
            'img_recons': img_recons
        }

        if self.challenge == 'multicoil':
            rss_recons = center_crop(img_recons,
                                     (self.resolution, self.resolution))
            rss_recons = root_sum_of_squares(rss_recons, dim=1).squeeze()
            # The inputs were divided by this value, so it is multiplied back here.
            # Deliberately out-of-place: ``*=`` would modify (through the squeeze
            # view) a tensor autograd may have saved for backward — presumably
            # root_sum_of_squares ends in torch.sqrt, which saves its output.
            rss_recons = rss_recons * extra_params['sk_scales']
            recons['rss_recons'] = rss_recons

        return recons  # Returning scaled reconstructions. Not rescaled. RSS images are rescaled.
def test_ifft1_width(shape):
    """Compare ``data_transforms.ifft1`` along width with a NumPy reference.

    The reference is ``fftshift(ifft(ifftshift(x), norm='ortho'))`` taken over
    the last axis of the equivalent complex NumPy array.
    """
    # The trailing axis of size 2 holds the real and imaginary parts.
    tensor = create_tensor(shape + [2])

    result = data_transforms.ifft1(tensor, direction='width').numpy()
    result = result[..., 0] + 1j * result[..., 1]

    reference = data_transforms.tensor_to_complex_np(tensor)
    reference = np.fft.fftshift(
        np.fft.ifft(np.fft.ifftshift(reference, axes=-1), axis=-1, norm='ortho'),
        axes=-1,
    )

    assert np.allclose(result, reference)
Example #3
0
    def forward(self, semi_kspace_output: Tensor, targets: dict,
                extra_params: dict):
        """Turn a 5-D network output into semi-k-space and image reconstructions.

        Args:
            semi_kspace_output: 5-D output whose axis 1 has size 2; that axis is
                moved to the end to obtain the NCHW2 layout of the targets.
                Only a batch size of 1 is supported.
            targets: Dict holding ``'semi_kspace_targets'``.
            extra_params: Dict read for ``'weightings'`` (if ``self.weighted``),
                ``'masks'`` (if ``self.replace``) and ``'sk_scales'`` (multi-coil).

        Returns:
            Dict with ``'semi_kspace_recons'``, ``'cmg_recons'`` and
            ``'img_recons'`` (not rescaled) and, for the multi-coil challenge,
            ``'rss_recons'`` (rescaled by ``extra_params['sk_scales']``).

        Raises:
            NotImplementedError: If the batch size is larger than 1.
        """
        assert semi_kspace_output.dim() == 5 and semi_kspace_output.size(1) == 2, 'Invalid shape!'
        if semi_kspace_output.size(0) > 1:
            raise NotImplementedError('Only one at a time for now.')

        target = targets['semi_kspace_targets']
        # Move the size-2 axis from position 1 to the end (NCHW2).
        recon = semi_kspace_output.permute(dims=(0, 2, 3, 4, 1))

        if recon.shape != target.shape:
            # Centre-crop the width axis (second to last) to the target width.
            target_width = target.size(-2)
            start = (recon.size(-2) - target_width) // 2
            recon = recon[..., start:start + target_width, :]

        assert recon.shape == target.shape, 'Reconstruction and target sizes are different.'
        assert (recon.size(-3) % 2 == 0) and (recon.size(-2) % 2 == 0), \
            'Not impossible but not expected to have sides with odd lengths.'

        if self.weighted:
            recon = recon / extra_params['weightings']

        if self.replace:
            mask = extra_params['masks']
            # Target where mask is 1, reconstruction where mask is 0.
            recon = target * mask + (1 - mask) * recon

        cmg_recon = ifft1(recon, direction='width')
        img_recon = complex_abs(cmg_recon)

        recons = dict(
            semi_kspace_recons=recon,
            cmg_recons=cmg_recon,
            img_recons=img_recon,
        )

        if self.challenge == 'multicoil':
            cropped = center_crop(img_recon, (self.resolution, self.resolution))
            rescaled = cropped * extra_params['sk_scales']
            recons['rss_recons'] = root_sum_of_squares(rescaled, dim=1).squeeze()

        return recons  # Only rss_recons is rescaled.
    def forward(self, semi_kspace_outputs, targets, extra_params):
        """Build semi-k-space, k-space and image reconstructions from network outputs.

        The last (width) axis of ``semi_kspace_outputs`` is centre-cropped to the
        width of ``targets['semi_kspace_targets']`` before ``nchw_to_kspace``.
        ``self.direction`` selects the axis of the forward 1-D FFT producing
        ``'kspace_recons'``; ``self.recon_direction`` selects the axis of the
        inverse 1-D FFT producing ``'cmg_recons'``.

        Returns:
            Dict with ``'semi_kspace_recons'``, ``'kspace_recons'``,
            ``'cmg_recons'`` and ``'img_recons'``, all still scaled.

        Raises:
            NotImplementedError: If the batch size is larger than 1.
        """
        if semi_kspace_outputs.size(0) > 1:
            raise NotImplementedError('Only one batch at a time for now.')

        semi_kspace_targets = targets['semi_kspace_targets']
        # The k-space layout ends in (..., width, 2), hence size(-2) for the width.
        target_width = semi_kspace_targets.size(-2)
        start = (semi_kspace_outputs.size(-1) - target_width) // 2
        unpadded = semi_kspace_outputs[..., start:start + target_width]
        semi_kspace_recons = nchw_to_kspace(unpadded)

        assert semi_kspace_recons.shape == semi_kspace_targets.shape, 'Reconstruction and target sizes are different.'
        assert (semi_kspace_recons.size(-3) % 2 == 0) and (semi_kspace_recons.size(-2) % 2 == 0), \
            'Not impossible but not expected to have sides with odd lengths.'

        if self.weighted:
            # Undo the input weighting.
            semi_kspace_recons = semi_kspace_recons / extra_params['weightings']

        if self.replace:
            mask = extra_params['masks']
            semi_kspace_recons = semi_kspace_recons * (1 - mask) + semi_kspace_targets * mask

        kspace_recons = fft1(semi_kspace_recons, direction=self.direction)
        cmg_recons = ifft1(semi_kspace_recons, direction=self.recon_direction)

        return dict(
            semi_kspace_recons=semi_kspace_recons,
            kspace_recons=kspace_recons,
            cmg_recons=cmg_recons,
            img_recons=complex_abs(cmg_recons),
        )  # Scaled reconstructions; nothing is rescaled here.
    def forward(self, semi_kspace_outputs, targets, extra_params):
        """Convert un-padded network outputs into semi-k-space, k-space and images.

        The output is converted with ``nchw_to_kspace`` without any cropping and
        must match ``targets['semi_kspace_targets']`` in shape. If
        ``self.weighted``, it is divided by ``extra_params['weightings']``; if
        ``self.residual_acs``, ``targets['semi_kspace_acss']`` is added to it.

        Returns:
            Dict with ``'semi_kspace_recons'``, ``'kspace_recons'``, ``'cmg_recons'``
            and ``'img_recons'`` (still scaled). For the multi-coil challenge it
            also holds ``'rss_recons'``, the root-sum-of-squares image multiplied
            back by ``extra_params['sk_scales']``.

        Raises:
            NotImplementedError: If the batch size is larger than 1.
        """
        if semi_kspace_outputs.size(0) > 1:
            raise NotImplementedError('Only one at a time for now.')

        semi_kspace_recons = nchw_to_kspace(semi_kspace_outputs)
        semi_kspace_targets = targets['semi_kspace_targets']
        assert semi_kspace_recons.shape == semi_kspace_targets.shape, 'Reconstruction and target sizes are different.'
        assert (semi_kspace_recons.size(-3) % 2 == 0) and (semi_kspace_recons.size(-2) % 2 == 0), \
            'Not impossible but not expected to have sides with odd lengths.'

        # Removing weighting.
        if self.weighted:
            weighting = extra_params['weightings']
            semi_kspace_recons = semi_kspace_recons / weighting

        if self.residual_acs:  # Adding the semi-k-space of the ACS as a residual. Necessary due to complex cropping.
            semi_kspace_acs = targets['semi_kspace_acss']
            semi_kspace_recons = semi_kspace_recons + semi_kspace_acs

        kspace_recons = fft1(semi_kspace_recons, direction='height')
        cmg_recons = ifft1(semi_kspace_recons, direction='width')
        img_recons = complex_abs(cmg_recons)

        recons = {
            'semi_kspace_recons': semi_kspace_recons,
            'kspace_recons': kspace_recons,
            'cmg_recons': cmg_recons,
            'img_recons': img_recons
        }

        if self.challenge == 'multicoil':
            rss_recons = root_sum_of_squares(img_recons, dim=1).squeeze()
            # Out-of-place on purpose: ``*=`` would modify (through the squeeze
            # view) a tensor autograd may have saved for backward — presumably
            # root_sum_of_squares ends in torch.sqrt, which saves its output.
            rss_recons = rss_recons * extra_params['sk_scales']
            recons['rss_recons'] = rss_recons

        return recons  # Returning scaled reconstructions. Not rescaled. RSS images are rescaled.
    def __call__(self, kspace_target, target, attrs, file_name, slice_num):
        """Build network inputs, targets and extra parameters from one k-space slice.

        Args:
            kspace_target: Fully sampled k-space as a torch Tensor with 3, 4 or 5
                dimensions. 3-D (single-coil) and 4-D (multi-coil) inputs get
                leading size-1 dimensions so the result is 5-D with dimension 1
                indexing coils. Only a leading (batch) size of 1 is supported.
                The caller's tensor is not modified.
            target: Target image; stored unchanged as ``targets['rss_targets']``
                for multi-coil data.
            attrs: Mapping merged into the returned ``extra_params``.
            file_name: File name; its character codes seed the mask when
                ``self.use_seed`` is set.
            slice_num: Not used here.

        Returns:
            ``(inputs, targets, extra_params)``. ``inputs`` is the weighted
            semi-k-space divided by its standard deviation, in NCHW form and
            zero padded along the last dimension to a multiple of
            ``self.divisor``. ``targets`` holds the ground truths divided by the
            same scale. ``extra_params`` holds ``'sk_scales'``, ``'masks'`` and
            ``'weightings'``, plus the mask info and ``attrs``.

        Raises:
            RuntimeError: If ``kspace_target`` does not have 3, 4 or 5 dimensions.
            NotImplementedError: If the batch size is not 1.
        """
        assert isinstance(
            kspace_target, torch.Tensor
        ), 'k-space target was expected to be a Pytorch Tensor.'
        if kspace_target.dim() == 3:  # If the collate function does not expand dimensions for single-coil.
            kspace_target = kspace_target.expand(1, 1, -1, -1, -1)
        elif kspace_target.dim() == 4:  # If the collate function does not expand dimensions for multi-coil.
            kspace_target = kspace_target.expand(1, -1, -1, -1, -1)
        elif kspace_target.dim() != 5:  # Expanded k-space should have 5 dimensions.
            raise RuntimeError('k-space target has invalid shape!')

        if kspace_target.size(0) != 1:
            raise NotImplementedError('Batch size should be 1 for now.')

        with torch.no_grad():
            # Apply mask
            seed = None if not self.use_seed else tuple(map(ord, file_name))
            masked_kspace, mask, info = apply_info_mask(
                kspace_target, self.mask_func, seed)

            # img_input is not actually an input but what the input would look like in the image domain.
            img_input = complex_abs(ifft2(masked_kspace))

            semi_kspace = ifft1(masked_kspace, direction='height')

            weighting = self.weight_func(semi_kspace)
            semi_kspace *= weighting

            # The slope is meaningless as the results always become the same after standardization no matter the slope.
            # The ordering could be changed to allow a difference, but this would make the inputs non-standardized.
            sk_scale = torch.std(semi_kspace)

            semi_kspace /= sk_scale
            inputs = kspace_to_nchw(semi_kspace)

            extra_params = {
                'sk_scales': sk_scale,
                'masks': mask,
                'weightings': weighting
            }
            extra_params.update(info)
            extra_params.update(attrs)

            # Recall that the Fourier transform is a linear transform.
            # Out-of-place on purpose: kspace_target is the caller's tensor (or an
            # expand() view of it), so an in-place division would silently rescale
            # the caller's data.
            kspace_target = kspace_target / sk_scale
            cmg_target = ifft2(kspace_target)
            img_target = complex_abs(cmg_target)
            semi_kspace_target = ifft1(kspace_target, direction='height')

            # Use plurals as keys to reduce confusion.
            targets = {
                'semi_kspace_targets': semi_kspace_target,
                'kspace_targets': kspace_target,
                'cmg_targets': cmg_target,
                'img_targets': img_target,
                'img_inputs': img_input
            }

            # Dimension 1 indexes coils (see the expansion above), so more than
            # one entry means multi-coil data, whatever the coil count.
            if kspace_target.size(1) > 1:
                # Left unscaled (not divided by sk_scale): reconstructions must be
                # rescaled before metric comparison.
                targets['rss_targets'] = target

            margin = inputs.size(-1) % self.divisor
            if margin > 0:
                pad = [(self.divisor - margin) // 2,
                       (1 + self.divisor - margin) // 2]
            else:  # This is a fix to prevent padding by half the divisor when margin=0.
                pad = [0, 0]

            # This pads at the last dimension of a tensor with 0.
            inputs = F.pad(inputs, pad=pad, value=0)

        return inputs, targets, extra_params