def forward(self, img_output, targets, extra_params):
        """Crop the width-padded image output to the target width and build recons.

        Args:
            img_output: Network output with batch size 1. Its last dimension may be
                wider than the target's because of symmetric padding.
            targets: Dict holding the target image under 'img_targets'.
            extra_params: Dict holding 'img_scales'; only read for multi-coil data.

        Returns:
            Dict with 'img_recons' (cropped, negative values clamped to 0) and, for
            multi-coil data, 'rss_recons' (center-cropped, multiplied by
            'img_scales', root-sum-of-squares over dim 1, squeezed).

        Raises:
            NotImplementedError: If the batch size is larger than 1.
        """
        if img_output.size(0) > 1:
            raise NotImplementedError('Only one at a time for now.')

        img_target = targets['img_targets']
        target_width = img_target.size(-1)
        # Undo the width padding by keeping the centered window of the target's width.
        start = (img_output.size(-1) - target_width) // 2
        cropped = img_output[..., start:start + target_width]
        # Negative intensities are impossible, so clamp them to zero.
        img_recon = F.relu(cropped)

        assert img_recon.shape == img_target.shape, 'Reconstruction and target sizes are different.'

        recons = {'img_recons': img_recon}
        if self.challenge != 'multicoil':
            return recons  # recons are not rescaled except rss_recons.

        crop_shape = (self.resolution, self.resolution)
        scaled = center_crop(img_recon, crop_shape) * extra_params['img_scales']
        recons['rss_recons'] = root_sum_of_squares(scaled, dim=1).squeeze()
        return recons  # recons are not rescaled except rss_recons.
    def forward(self, cmg_output, targets, extra_params):
        """Build k-space, complex-image and magnitude recons from a complex-image output.

        Args:
            cmg_output: Network output with batch size 1; converted with
                ``nchw_to_kspace`` before use.
            targets: Dict holding 'cmg_targets', used only for a shape check.
            extra_params: Dict holding 'cmg_scales'; only read for multi-coil data.

        Returns:
            Dict with 'kspace_recons' (``fft2`` of the complex recon), 'cmg_recons'
            and 'img_recons' (``complex_abs`` of the complex recon), plus
            'rss_recons' for multi-coil data. Only 'rss_recons' is rescaled.

        Raises:
            NotImplementedError: If the batch size is larger than 1, or if
                ``self.residual_acs`` is set (not implemented yet).
        """
        if cmg_output.size(0) > 1:
            raise NotImplementedError('Only one at a time for now.')

        cmg_target = targets['cmg_targets']
        cmg_recon = nchw_to_kspace(cmg_output)
        assert cmg_recon.shape == cmg_target.shape, 'Reconstruction and target sizes are different.'
        assert cmg_recon.size(-3) % 2 == 0 and cmg_recon.size(-2) % 2 == 0, \
            'Not impossible but not expected to have sides with odd lengths.'

        if self.residual_acs:
            # TODO: add targets['cmg_acss'] to cmg_recon as a residual once supported.
            raise NotImplementedError('Not ready yet.')

        recons = {
            'kspace_recons': fft2(cmg_recon),
            'cmg_recons': cmg_recon,
            'img_recons': complex_abs(cmg_recon),
        }

        if self.challenge == 'multicoil':
            crop_shape = (self.resolution, self.resolution)
            scaled = center_crop(recons['img_recons'], crop_shape) * extra_params['cmg_scales']
            recons['rss_recons'] = root_sum_of_squares(scaled, dim=1).squeeze()

        return recons  # recons are not rescaled except rss_recons.
    def forward(self, outputs, targets, extra_params):
        """Combine magnitude and phase outputs into complex-image and k-space recons.

        Args:
            outputs: Pair ``(img_recon, phase_recon)``; batch size must be 1.
            targets: Dict holding 'img_targets', used only for a shape check.
            extra_params: Dict holding 'img_scales'; only read for multi-coil data.

        Returns:
            Dict with 'img_recons' (magnitude clamped at 0), 'phase_recons' (phase
            shifted back by pi), 'cmg_recons' (real/imaginary stacked on the last
            dim), 'kspace_recons' (``fft2`` of the complex recon) and, for
            multi-coil data, 'rss_recons'. Only 'rss_recons' is rescaled.

        Raises:
            NotImplementedError: If the batch size is larger than 1.
        """
        img_recon, phase_recon = outputs
        if img_recon.size(0) > 1:
            raise NotImplementedError('Only one at a time for now.')

        img_target = targets['img_targets']
        magnitude = F.relu(img_recon)  # Magnitudes cannot be negative.
        assert magnitude.shape == img_target.shape, 'Reconstruction and target sizes are different.'

        # The input transform added pi to the phase; remove it. Not clamped since the loss is MSE.
        phase = phase_recon - math.pi
        real_part = magnitude * torch.cos(phase)
        imag_part = magnitude * torch.sin(phase)
        cmg_recon = torch.stack([real_part, imag_part], dim=-1)

        recons = {
            'img_recons': magnitude,
            'phase_recons': phase,
            'cmg_recons': cmg_recon,
            'kspace_recons': fft2(cmg_recon),
        }

        if self.challenge == 'multicoil':
            cropped = center_crop(magnitude, shape=(self.resolution, self.resolution))
            rss_recon = root_sum_of_squares(cropped * extra_params['img_scales'], dim=1)
            recons['rss_recons'] = rss_recon.squeeze()

        return recons  # recons are not rescaled except rss_recons.
    def __call__(self, masked_kspace, target, attrs, file_name, slice_num):
        """Turn already-masked k-space into a std-normalized, center-cropped magnitude input.

        Args:
            masked_kspace: k-space tensor with 3 (single-coil), 4 (multi-coil) or 5
                (already batched) dimensions; the batch size must be 1.
            target: Unused here.
            attrs: Dict merged into ``extra_params``.
            file_name: Unused here.
            slice_num: Unused here.

        Returns:
            Tuple ``(input_image, targets, extra_params)``. ``targets`` holds
            'img_inputs' and 'rss_inputs'; ``extra_params`` holds 'img_scales'
            (the std the input was divided by) plus ``attrs``.

        Raises:
            AssertionError: If ``masked_kspace`` is not a tensor.
            RuntimeError: If ``masked_kspace`` does not have 3, 4 or 5 dimensions.
            NotImplementedError: If the batch size is not 1.
        """
        assert isinstance(masked_kspace, torch.Tensor), 'k-space target was expected to be a Pytorch Tensor.'

        ndim = masked_kspace.dim()
        if ndim == 3:  # Collate function did not add batch and coil dims (single-coil).
            masked_kspace = masked_kspace.expand(1, 1, -1, -1, -1)
        elif ndim == 4:  # Collate function did not add the batch dim (multi-coil).
            masked_kspace = masked_kspace.expand(1, -1, -1, -1, -1)
        elif ndim != 5:  # Expanded k-space should have 5 dimensions.
            raise RuntimeError('k-space target has invalid shape!')

        if masked_kspace.size(0) != 1:
            raise NotImplementedError('Batch size should be 1 for now.')

        with torch.no_grad():
            coil_images = complex_abs(ifft2(masked_kspace))
            # Crop to the RSS resolution up front instead of processing the targets later.
            input_image = center_crop(coil_images, shape=(self.resolution, self.resolution))
            img_scale = torch.std(input_image)
            input_image /= img_scale

            extra_params = {'img_scales': img_scale}
            extra_params.update(attrs)

            # Plural keys are used to reduce confusion.
            targets = {
                'img_inputs': input_image,
                'rss_inputs': root_sum_of_squares(input_image, dim=1).squeeze(),
            }

        return input_image, targets, extra_params
    def forward(self, semi_kspace_outputs, targets, extra_params):
        """Crop a padded semi-k-space output and derive k-space, complex and magnitude recons.

        Args:
            semi_kspace_outputs: Network output with batch size 1. Its last dimension
                may carry width padding relative to the target.
            targets: Dict holding 'semi_kspace_targets'.
            extra_params: Dict read for 'weightings' (if ``self.weighted``),
                'num_low_frequency' (if ``self.residual_acs``), 'masks' (if
                ``self.replace``) and 'sk_scales' (multi-coil only).

        Returns:
            Dict with 'semi_kspace_recons', 'kspace_recons', 'cmg_recons',
            'img_recons' and, for multi-coil data, 'rss_recons'. Only
            'rss_recons' is multiplied back by the input scale.

        Raises:
            NotImplementedError: If the batch size is larger than 1.
        """
        if semi_kspace_outputs.size(0) > 1:
            raise NotImplementedError('Only one at a time for now.')

        semi_kspace_targets = targets['semi_kspace_targets']
        # Remove width padding. The k-space form has 2 as its last dim, so the target
        # width sits at dim -2 while the output width sits at dim -1.
        target_width = semi_kspace_targets.size(-2)
        left = (semi_kspace_outputs.size(-1) - target_width) // 2
        right = left + target_width

        semi_kspace_recons = nchw_to_kspace(semi_kspace_outputs[..., left:right])

        assert semi_kspace_recons.shape == semi_kspace_targets.shape, 'Reconstruction and target sizes are different.'
        assert (semi_kspace_recons.size(-3) % 2 == 0) and (semi_kspace_recons.size(-2) % 2 == 0), \
            'Not impossible but not expected to have sides with odd lengths.'

        # Removing weighting.
        if self.weighted:
            semi_kspace_recons = semi_kspace_recons / extra_params['weightings']

        if self.residual_acs:
            # Add the target's ACS region back onto the prediction.
            num_low_freqs = extra_params['num_low_frequency']
            acs_mask = find_acs_mask(semi_kspace_recons, num_low_freqs)
            semi_kspace_recons = semi_kspace_recons + acs_mask * semi_kspace_targets

        if self.replace:
            # Take target values where mask is set and predictions elsewhere
            # (presumably a binary sampling mask — TODO confirm).
            mask = extra_params['masks']
            semi_kspace_recons = semi_kspace_recons * (1 - mask) + semi_kspace_targets * mask

        kspace_recons = fft1(semi_kspace_recons, direction='height')
        cmg_recons = ifft1(semi_kspace_recons, direction='width')
        img_recons = complex_abs(cmg_recons)

        recons = {
            'semi_kspace_recons': semi_kspace_recons,
            'kspace_recons': kspace_recons,
            'cmg_recons': cmg_recons,
            'img_recons': img_recons
        }

        if self.challenge == 'multicoil':
            rss_recons = center_crop(img_recons, (self.resolution, self.resolution))
            rss_recons = root_sum_of_squares(rss_recons, dim=1).squeeze()
            # The inputs were divided by this value, so multiply it back. Done
            # out-of-place: root_sum_of_squares presumably ends in torch.sqrt, whose
            # output is saved for backward, so an in-place `*=` would make the
            # backward pass fail with an in-place modification error.
            rss_recons = rss_recons * extra_params['sk_scales']
            recons['rss_recons'] = rss_recons

        return recons  # Returning scaled reconstructions. Not rescaled. RSS images are rescaled.
    def forward(self, img_output, targets, extra_params):
        """Center-crop, clamp and rescale an image output.

        Args:
            img_output: Network output image with batch size 1.
            targets: Unused here; kept for interface compatibility.
            extra_params: Dict holding 'img_scales', the factor multiplied back in.

        Returns:
            The squeezed, rescaled image. For multi-coil data this is the
            root-sum-of-squares over dim 1.

        Raises:
            NotImplementedError: If the batch size is larger than 1.
        """
        if img_output.size(0) > 1:
            raise NotImplementedError('Only one at a time for now.')

        # Removing values below 0, which are impossible anyway.
        img_recon = F.relu(
            center_crop(img_output, shape=(self.resolution, self.resolution)))
        # Out-of-place on purpose: ReLU saves its output for the backward pass, so an
        # in-place `*=` here would make backward fail with an in-place modification error.
        img_recon = img_recon * extra_params['img_scales']

        if self.challenge == 'multicoil':
            img_recon = root_sum_of_squares(img_recon, dim=1)

        return img_recon.squeeze()
    def forward(self, kspace_outputs, targets, extra_params):
        """Crop a padded k-space output and derive complex-image and magnitude recons.

        Args:
            kspace_outputs: Network output with batch size 1. Its last dimension may
                carry width padding relative to the target.
            targets: Dict holding 'kspace_targets'.
            extra_params: Dict read for 'weightings' (if ``self.weighted``),
                'num_low_frequency' (if ``self.residual_acs``) and 'masks' (if
                ``self.replace``).

        Returns:
            Dict with 'kspace_recons', 'cmg_recons' and 'img_recons'. When dim 1
            has 15 coils it also holds 'rss_recons' (center crop of side
            ``self.resolution``, not rescaled).

        Raises:
            NotImplementedError: If the batch size is larger than 1.
        """
        if kspace_outputs.size(0) > 1:
            raise NotImplementedError('Only one slice at a time for now.')

        kspace_targets = targets['kspace_targets']

        # The k-space form has 2 as its last dim, so the target width is at dim -2;
        # the output keeps width on its last dim. Keep the centered target-width window.
        target_width = kspace_targets.size(-2)
        offset = (kspace_outputs.size(-1) - target_width) // 2
        kspace_recons = nchw_to_kspace(kspace_outputs[..., offset:offset + target_width])
        assert kspace_recons.shape == kspace_targets.shape, 'Reconstruction and target sizes are different.'
        assert (kspace_recons.size(-3) % 2 == 0) and (kspace_recons.size(-2) % 2 == 0), \
            'Not impossible but not expected to have sides with odd lengths.'

        if self.weighted:  # Undo the weighting applied on the input side.
            kspace_recons = kspace_recons / extra_params['weightings']

        if self.residual_acs:
            acs_mask = find_acs_mask(kspace_recons, extra_params['num_low_frequency'])
            kspace_recons = kspace_recons + acs_mask * kspace_targets

        if self.replace:  # Put the original k-space back where the mask is set.
            mask = extra_params['masks']
            kspace_recons = kspace_recons * (1 - mask) + kspace_targets * mask

        cmg_recons = ifft2(kspace_recons)
        img_recons = complex_abs(cmg_recons)
        recons = {
            'kspace_recons': kspace_recons,
            'cmg_recons': cmg_recons,
            'img_recons': img_recons
        }

        if img_recons.size(1) == 15:
            side = self.resolution
            top = (img_recons.size(-2) - side) // 2
            left = (img_recons.size(-1) - side) // 2
            cropped = img_recons[:, :, top:top + side, left:left + side]
            recons['rss_recons'] = root_sum_of_squares(cropped, dim=1).squeeze()  # 2D image

        return recons  # Returning scaled reconstructions. Not rescaled.
Beispiel #8
0
    def forward(self, semi_kspace_output: Tensor, targets: dict,
                extra_params: dict):
        """Convert a rank-5 semi-k-space output into semi-k-space, complex and magnitude recons.

        Args:
            semi_kspace_output: Rank-5 tensor whose dim 1 has size 2; that axis is
                moved to the end. The batch size must be 1.
            targets: Dict holding 'semi_kspace_targets'.
            extra_params: Dict read for 'weightings' (if ``self.weighted``), 'masks'
                (if ``self.replace``) and 'sk_scales' (multi-coil only).

        Returns:
            Dict with 'semi_kspace_recons', 'cmg_recons', 'img_recons' and, for
            multi-coil data, 'rss_recons'. Only 'rss_recons' is rescaled.

        Raises:
            AssertionError: If the output is not rank 5 with size 2 on dim 1.
            NotImplementedError: If the batch size is larger than 1.
        """
        assert semi_kspace_output.dim() == 5 and semi_kspace_output.size(1) == 2, 'Invalid shape!'
        if semi_kspace_output.size(0) > 1:
            raise NotImplementedError('Only one at a time for now.')

        semi_kspace_target = targets['semi_kspace_targets']
        # Move the size-2 axis from dim 1 to the end, back into NCHW2 layout.
        semi_kspace_recon = semi_kspace_output.permute(dims=(0, 2, 3, 4, 1))

        if semi_kspace_recon.shape != semi_kspace_target.shape:
            # Crop the recon left-right to the centered target-width window.
            offset = (semi_kspace_recon.size(-2) - semi_kspace_target.size(-2)) // 2
            width = semi_kspace_target.size(-2)
            semi_kspace_recon = semi_kspace_recon[..., offset:offset + width, :]

        assert semi_kspace_recon.shape == semi_kspace_target.shape, 'Reconstruction and target sizes are different.'
        assert (semi_kspace_recon.size(-3) % 2 == 0) and (semi_kspace_recon.size(-2) % 2 == 0), \
            'Not impossible but not expected to have sides with odd lengths.'

        if self.weighted:
            semi_kspace_recon = semi_kspace_recon / extra_params['weightings']

        if self.replace:
            # Where mask is set take the target, elsewhere keep the prediction.
            mask = extra_params['masks']
            semi_kspace_recon = semi_kspace_target * mask + (1 - mask) * semi_kspace_recon

        cmg_recon = ifft1(semi_kspace_recon, direction='width')
        img_recon = complex_abs(cmg_recon)

        recons = {
            'semi_kspace_recons': semi_kspace_recon,
            'cmg_recons': cmg_recon,
            'img_recons': img_recon
        }

        if self.challenge == 'multicoil':
            cropped = center_crop(img_recon, (self.resolution, self.resolution))
            rss_recon = root_sum_of_squares(cropped * extra_params['sk_scales'], dim=1)
            recons['rss_recons'] = rss_recon.squeeze()

        return recons  # recons are not rescaled except rss_recons.
Beispiel #9
0
def make_rss_slice(image, dim=1, resolution=320):
    """Center-crop a real-valued image batch and collapse coils into a CPU image.

    Args:
        image: Rank-4 real-valued tensor with batch size 1; dim 1 is checked for the
            coil count (1 for single-coil, 15 for multi-coil).
        dim: Dimension reduced by root_sum_of_squares for multi-coil input.
        resolution: Side length of the square center crop. If the image is smaller
            along an axis, that whole axis is kept.

    Returns:
        Detached, cropped, squeezed tensor on the CPU: the image itself for
        single-coil input, or the root-sum-of-squares over ``dim`` for 15 coils.

    Raises:
        AssertionError: If ``image`` is not rank 4.
        NotImplementedError: If the batch size is larger than 1.
        ValueError: If the number of coils is neither 1 nor 15.
    """
    assert image.dim() == 4, 'Real-valued images are expected.'
    if image.size(0) > 1:
        raise NotImplementedError('Batch size is expected to be 1.')

    # Clamp at 0: a negative start would wrap around and yield a wrong, tiny crop
    # when the image is smaller than `resolution`.
    top = max((image.size(-2) - resolution) // 2, 0)
    left = max((image.size(-1) - resolution) // 2, 0)

    image = image.detach()[:, :, top:top + resolution, left:left + resolution]

    # Copies to the CPU are synchronous: a non_blocking device-to-host copy may return
    # before the data has arrived, so a caller reading the result right away could
    # see incomplete values.
    if image.size(1) == 1:  # Single-coil
        return image.squeeze().to(device='cpu')
    if image.size(1) != 15:
        raise ValueError('Invalid number of coils for this dataset.')

    rss = root_sum_of_squares(image, dim=dim).squeeze()
    return rss.to(device='cpu')
    def forward(self, output, targets, extra_params):
        """Dispatch on ``self.output_mode`` and add the RSS recon for multi-coil data.

        Args:
            output: Network output with batch size 1.
            targets: Passed through to the mode-specific helper.
            extra_params: Passed to the helper; 'cmg_scales' is read for multi-coil data.

        Returns:
            The dict built by ``self._cmg_output`` ('cmg' mode) or
            ``self._img_output`` ('img' mode). For multi-coil data 'rss_recons' is
            added, computed from the center-cropped, rescaled 'img_recons'.

        Raises:
            NotImplementedError: If the batch size is larger than 1, or if the output
                mode is neither 'cmg' nor 'img'.
        """
        if output.size(0) > 1:
            raise NotImplementedError('Batch size is expected to be 1.')

        if self.output_mode == 'cmg':
            recons = self._cmg_output(output, targets, extra_params)
        elif self.output_mode == 'img':
            recons = self._img_output(output, targets, extra_params)
        else:
            raise NotImplementedError('Invalid output mode.')

        if self.challenge == 'multicoil':
            cropped = center_crop(recons['img_recons'],
                                  shape=(self.resolution, self.resolution))
            # Scale out-of-place: if center_crop returns a slicing view (as fastMRI's
            # does), an in-place `*=` would also rescale part of recons['img_recons'].
            rss_recon = cropped * extra_params['cmg_scales']
            recons['rss_recons'] = root_sum_of_squares(rss_recon, dim=1).squeeze()

        return recons
    def forward(self, semi_kspace_outputs, targets, extra_params):
        """Derive k-space, complex and magnitude recons from an unpadded semi-k-space output.

        Args:
            semi_kspace_outputs: Network output with batch size 1; converted with
                ``nchw_to_kspace``.
            targets: Dict holding 'semi_kspace_targets' and, if
                ``self.residual_acs``, 'semi_kspace_acss'.
            extra_params: Dict read for 'weightings' (if ``self.weighted``) and
                'sk_scales' (multi-coil only).

        Returns:
            Dict with 'semi_kspace_recons', 'kspace_recons', 'cmg_recons',
            'img_recons' and, for multi-coil data, 'rss_recons'. Only 'rss_recons'
            is multiplied back by the input scale.

        Raises:
            NotImplementedError: If the batch size is larger than 1.
        """
        if semi_kspace_outputs.size(0) > 1:
            raise NotImplementedError('Only one at a time for now.')

        semi_kspace_recons = nchw_to_kspace(semi_kspace_outputs)
        semi_kspace_targets = targets['semi_kspace_targets']
        assert semi_kspace_recons.shape == semi_kspace_targets.shape, 'Reconstruction and target sizes are different.'
        assert (semi_kspace_recons.size(-3) % 2 == 0) and (semi_kspace_recons.size(-2) % 2 == 0), \
            'Not impossible but not expected to have sides with odd lengths.'

        # Removing weighting.
        if self.weighted:
            semi_kspace_recons = semi_kspace_recons / extra_params['weightings']

        if self.residual_acs:  # Adding the semi-k-space of the ACS as a residual. Necessary due to complex cropping.
            semi_kspace_recons = semi_kspace_recons + targets['semi_kspace_acss']

        kspace_recons = fft1(semi_kspace_recons, direction='height')
        cmg_recons = ifft1(semi_kspace_recons, direction='width')
        img_recons = complex_abs(cmg_recons)

        recons = {
            'semi_kspace_recons': semi_kspace_recons,
            'kspace_recons': kspace_recons,
            'cmg_recons': cmg_recons,
            'img_recons': img_recons
        }

        if self.challenge == 'multicoil':
            # NOTE(review): unlike the sibling forwards, no center crop is applied before
            # the RSS here; presumably img_recons is already at the target resolution.
            rss_recons = root_sum_of_squares(img_recons, dim=1).squeeze()
            # Out-of-place: root_sum_of_squares presumably ends in torch.sqrt, whose output
            # is saved for backward, so an in-place `*=` would break the backward pass.
            rss_recons = rss_recons * extra_params['sk_scales']
            recons['rss_recons'] = rss_recons

        return recons  # Returning scaled reconstructions. Not rescaled. RSS images are rescaled.
def test_root_sum_of_squares(shape, dim):
    """Check data_transforms.root_sum_of_squares against a NumPy sqrt-of-sum-of-squares."""
    tensor = create_tensor(shape)
    actual = data_transforms.root_sum_of_squares(tensor, dim).numpy()
    expected = np.sqrt((tensor.numpy() ** 2).sum(axis=dim))
    assert np.allclose(actual, expected)
    def __call__(self, kspace_target, target, attrs, file_name, slice_num):
        """Mask k-space, build a normalized magnitude input, and optionally patch/augment it.

        Args:
            kspace_target: k-space tensor to be masked, with 3 (single-coil), 4
                (multi-coil) or 5 (already batched) dimensions; batch size must be 1.
            target: Target RSS image; cropped and flipped together with the input.
            attrs: Dict merged into ``extra_params``.
            file_name: Used to derive a deterministic mask seed when ``self.use_seed``.
            slice_num: Unused here.

        Returns:
            Tuple ``(input_image, targets, extra_params)``. ``targets`` holds
            'img_inputs', 'rss_targets' and 'rss_inputs'; ``extra_params`` holds
            'img_scales', 'masks', the mask info and ``attrs``.

        Raises:
            AssertionError: If ``kspace_target`` is not a tensor.
            RuntimeError: If ``kspace_target`` does not have 3, 4 or 5 dimensions.
            NotImplementedError: If the batch size is not 1.
        """
        assert isinstance(
            kspace_target, torch.Tensor
        ), 'k-space target was expected to be a Pytorch Tensor.'
        if kspace_target.dim(
        ) == 3:  # If the collate function does not expand dimensions for single-coil.
            kspace_target = kspace_target.expand(1, 1, -1, -1, -1)
        elif kspace_target.dim(
        ) == 4:  # If the collate function does not expand dimensions for multi-coil.
            kspace_target = kspace_target.expand(1, -1, -1, -1, -1)
        elif kspace_target.dim(
        ) != 5:  # Expanded k-space should have 5 dimensions.
            raise RuntimeError('k-space target has invalid shape!')

        if kspace_target.size(0) != 1:
            raise NotImplementedError('Batch size should be 1 for now.')

        with torch.no_grad():
            # Apply mask
            seed = None if not self.use_seed else tuple(map(ord, file_name))
            masked_kspace, mask, info = apply_info_mask(
                kspace_target, self.mask_func, seed)

            input_image = complex_abs(ifft2(masked_kspace))
            # Cropping is mandatory if RSS means the 320 version. Not so if a larger image is used.
            # However, I think that removing target processing is worth the loss of flexibility.
            input_image = center_crop(input_image,
                                      shape=(self.resolution, self.resolution))
            img_scale = torch.std(input_image)
            input_image /= img_scale

            # No subtraction by the mean. I do not know if this is a good idea or not.

            if self.use_patch:
                assert self.resolution == 320
                # torch.randint's `high` is exclusive: +1 so the last valid offset
                # (resolution - patch_size) can be drawn, and patch_size == resolution
                # no longer fails with an empty range.
                max_offset = self.resolution - self.patch_size
                left = torch.randint(low=0, high=max_offset + 1,
                                     size=(1, )).squeeze()
                top = torch.randint(low=0, high=max_offset + 1,
                                    size=(1, )).squeeze()

                input_image = input_image[:, :, top:top + self.patch_size,
                                          left:left + self.patch_size]
                target = target[top:top + self.patch_size,
                                left:left + self.patch_size]

            extra_params = {'img_scales': img_scale, 'masks': mask}
            extra_params.update(info)
            extra_params.update(attrs)

            # Data augmentation by flipping images up-down and left-right.
            if self.augment_data:
                flip_lr = torch.rand(()) < 0.5
                flip_ud = torch.rand(()) < 0.5
                # Up-down flips dim -2, left-right flips dim -1; both may apply at once.
                flip_dims = tuple(d for d, flag in ((-2, flip_ud), (-1, flip_lr)) if flag)
                if flip_dims:
                    input_image = torch.flip(input_image, dims=flip_dims)
                    target = torch.flip(target, dims=flip_dims)

            # Use plurals as keys to reduce confusion.
            input_rss = root_sum_of_squares(input_image, dim=1).squeeze()
            targets = {
                'img_inputs': input_image,
                'rss_targets': target,
                'rss_inputs': input_rss
            }

        return input_image, targets, extra_params
Beispiel #14
0
import torch
import h5py

from data.data_transforms import ifft2, to_tensor, root_sum_of_squares, center_crop, complex_center_crop, complex_abs

# Check that dividing k-space by a constant before the inverse FFT, then multiplying
# the cropped complex image back by it, reproduces the stored RSS reconstruction.
file = '/media/veritas/D/FastMRI/multicoil_val/file1001798.h5'
sdx = 10  # Slice index within the volume.
with h5py.File(file, mode='r') as hf:
    kspace = hf['kspace'][sdx]
    target = hf['reconstruction_rss'][sdx]

cmg_scale = 2E-5
scaled_cmg = ifft2(to_tensor(kspace) / cmg_scale)
cropped_cmg = complex_center_crop(scaled_cmg, shape=(320, 320))
recon = root_sum_of_squares(complex_abs(cropped_cmg * cmg_scale))
target = to_tensor(target)

matches = torch.allclose(recon, target)
print(matches)
Beispiel #15
0
    def __call__(self, kspace_target, target, attrs, file_name, slice_num):
        """Mask k-space and build a normalized magnitude input, optionally with fat info.

        Args:
            kspace_target: k-space tensor to be masked, with 3 (single-coil), 4
                (multi-coil) or 5 (already batched) dimensions; batch size must be 1.
            target: Target RSS image, or the integer ``0`` to have it computed from
                ``kspace_target`` (RSS over dim 1, center-cropped, squeezed).
            attrs: Dict merged into ``extra_params``. When ``self.fat_info`` is set,
                ``extra_params['acquisition']`` is read.
            file_name: Used to derive a deterministic mask seed when ``self.use_seed``.
            slice_num: Unused here.

        Returns:
            Tuple ``(input_image, targets, extra_params)``. ``targets`` holds
            'img_inputs', 'rss_targets' and 'rss_inputs'. When ``self.fat_info`` is
            set, the returned ``input_image`` gets one extra channel on dim 1 that
            is 1 for 'CORPDFS_FBK' acquisitions and 0 otherwise.

        Raises:
            AssertionError: If ``kspace_target`` is not a tensor.
            RuntimeError: If ``kspace_target`` does not have 3, 4 or 5 dimensions.
            NotImplementedError: If the batch size is not 1.
        """
        assert isinstance(
            kspace_target, torch.Tensor
        ), 'k-space target was expected to be a Pytorch Tensor.'
        if kspace_target.dim(
        ) == 3:  # If the collate function does not expand dimensions for single-coil.
            kspace_target = kspace_target.expand(1, 1, -1, -1, -1)
        elif kspace_target.dim(
        ) == 4:  # If the collate function does not expand dimensions for multi-coil.
            kspace_target = kspace_target.expand(1, -1, -1, -1, -1)
        elif kspace_target.dim(
        ) != 5:  # Expanded k-space should have 5 dimensions.
            raise RuntimeError('k-space target has invalid shape!')

        if kspace_target.size(0) != 1:
            raise NotImplementedError('Batch size should be 1 for now.')

        with torch.no_grad():
            # Apply mask
            seed = None if not self.use_seed else tuple(map(ord, file_name))
            masked_kspace, mask, info = apply_info_mask(
                kspace_target, self.mask_func, seed)

            input_image = complex_abs(ifft2(masked_kspace))
            # Cropping is mandatory if RSS means the 320 version. Not so if a larger image is used.
            # However, I think that removing target processing is worth the loss of flexibility.
            input_image = center_crop(input_image,
                                      shape=(self.resolution, self.resolution))
            img_scale = torch.std(input_image)
            input_image /= img_scale

            extra_params = {'img_scales': img_scale, 'masks': mask}
            extra_params.update(info)
            extra_params.update(attrs)

            # `target is 0` relied on CPython's small-int caching (and warns on 3.8+).
            # Compare by value instead, but only for plain ints so a tensor target is
            # never compared elementwise.
            if isinstance(target, int) and not isinstance(target, bool) and target == 0:
                target = center_crop(
                    root_sum_of_squares(complex_abs(ifft2(kspace_target)),
                                        dim=1),
                    shape=(self.resolution, self.resolution)).squeeze()

            # Data augmentation by flipping images up-down and left-right.
            if self.augment_data:
                flip_lr = torch.rand(()) < 0.5
                flip_ud = torch.rand(()) < 0.5
                # Up-down flips dim -2, left-right flips dim -1; both may apply at once.
                flip_dims = tuple(d for d, flag in ((-2, flip_ud), (-1, flip_lr)) if flag)
                if flip_dims:
                    input_image = torch.flip(input_image, dims=flip_dims)
                    target = torch.flip(target, dims=flip_dims)

            # Use plurals as keys to reduce confusion.
            input_rss = root_sum_of_squares(input_image, dim=1).squeeze()
            targets = {
                'img_inputs': input_image,
                'rss_targets': target,
                'rss_inputs': input_rss
            }

            if self.fat_info:
                fat_supp = extra_params[
                    'acquisition'] == 'CORPDFS_FBK'  # Fat suppressed acquisition.
                batch, _, height, width = input_image.shape
                fat_info = torch.ones(size=(batch, 1, height, width),
                                      device=self.device) * fat_supp
                input_image = torch.cat([input_image, fat_info], dim=1)

        return input_image, targets, extra_params
Beispiel #16
0
import torch
import h5py
import numpy as np

from data.data_transforms import root_sum_of_squares, center_crop, ifft2, to_tensor, complex_abs

# Check that the RSS of the inverse-FFT coil images, center-cropped to 320x320,
# matches the stored 'reconstruction_rss' for the whole volume.
file = '/media/veritas/D/FastMRI/multicoil_val/file1000229.h5'
with h5py.File(file, 'r') as hf:
    kspace = hf['kspace'][()]
    rss = hf['reconstruction_rss'][()]

kspace = to_tensor(kspace)
coil_magnitudes = complex_abs(ifft2(kspace))
combined = root_sum_of_squares(coil_magnitudes, dim=1)
image = center_crop(combined, shape=(320, 320)).squeeze().numpy()

matches = np.allclose(image, rss)
print(matches)