def __call__(self, kspace_target, target, attrs, file_name, slice_num):
        assert isinstance(
            kspace_target, torch.Tensor
        ), 'k-space target was expected to be a Pytorch Tensor.'
        if kspace_target.dim(
        ) == 3:  # If the collate function does not expand dimensions for single-coil.
            kspace_target = kspace_target.expand(1, 1, -1, -1, -1)
        elif kspace_target.dim(
        ) == 4:  # If the collate function does not expand dimensions for multi-coil.
            kspace_target = kspace_target.expand(1, -1, -1, -1, -1)
        elif kspace_target.dim(
        ) != 5:  # Expanded k-space should have 5 dimensions.
            raise RuntimeError('k-space target has invalid shape!')

        if kspace_target.size(0) != 1:
            raise NotImplementedError('Batch size should be 1 for now.')

        with torch.no_grad():
            # Apply mask
            seed = tuple(map(ord, file_name))
            masked_kspace, mask, info = apply_info_mask(
                kspace_target, self.mask_func, seed)

            image = complex_abs(ifft2(masked_kspace))
            if self.crop_center:
                image = center_crop(image,
                                    shape=(self.resolution, self.resolution))
            else:  # Super-hackish temporary line.  # TODO: Fix this thing later!
                image = center_crop(image, shape=(352, image.size(-1)))

            margin = image.size(-1) % self.divisor
            if margin > 0:
                pad = [(self.divisor - margin) // 2,
                       (1 + self.divisor - margin) // 2]
            else:  # This is a fix to prevent padding by half the divisor when margin=0.
                pad = [0, 0]

            # This pads at the last dimension of a tensor with 0.
            image = F.pad(image, pad=pad, value=0)

            img_scale = torch.std(
                center_crop(image, shape=(self.resolution,
                                          self.resolution)))  # Also a hack!
            image /= img_scale

            extra_params = {'img_scales': img_scale, 'masks': mask}
            extra_params.update(info)
            extra_params.update(attrs)

            # Use plurals as keys to reduce confusion.
            targets = {'img_inputs': image}

        return image, targets, extra_params
    def forward(self, img_output, targets, extra_params):
        """Strip width padding from an image output and collect reconstructions.

        Args:
            img_output: network output whose last dimension may be wider than
                the target's. Only batch size 1 is supported.
            targets: must contain ``'img_targets'``.
            extra_params: ``'img_scales'`` is read for the multi-coil challenge.

        Returns:
            dict with ``'img_recons'`` and, for multi-coil, ``'rss_recons'``.
            Only ``'rss_recons'`` is multiplied back by ``'img_scales'``.
        """
        if img_output.size(0) > 1:
            raise NotImplementedError('Only one at a time for now.')

        img_target = targets['img_targets']
        target_width = img_target.size(-1)
        # Undo symmetric width padding by taking the centred target_width columns.
        start = (img_output.size(-1) - target_width) // 2
        # ReLU removes negative values, which are impossible for magnitudes anyway.
        img_recon = F.relu(img_output[..., start:start + target_width])

        assert img_recon.shape == img_target.shape, 'Reconstruction and target sizes are different.'

        recons = {'img_recons': img_recon}
        if self.challenge != 'multicoil':
            return recons

        scaled_crop = center_crop(img_recon, (self.resolution, self.resolution)) * extra_params['img_scales']
        recons['rss_recons'] = root_sum_of_squares(scaled_crop, dim=1).squeeze()
        return recons
    def forward(self, cmg_output, targets, extra_params):
        """Convert a complex-image network output into k-space, complex and magnitude reconstructions.

        Args:
            cmg_output: network output, converted with ``nchw_to_kspace``. Only
                batch size 1 is supported.
            targets: must contain ``'cmg_targets'`` with the converted shape.
            extra_params: ``'cmg_scales'`` is read for the multi-coil challenge.

        Returns:
            dict with ``'kspace_recons'``, ``'cmg_recons'``, ``'img_recons'`` and,
            for multi-coil, ``'rss_recons'`` (the only rescaled entry).

        Raises:
            NotImplementedError: for batch size > 1 or when ``self.residual_acs``
                is set.
        """
        if cmg_output.size(0) > 1:
            raise NotImplementedError('Only one at a time for now.')

        cmg_target = targets['cmg_targets']
        cmg_recon = nchw_to_kspace(cmg_output)
        assert cmg_recon.shape == cmg_target.shape, 'Reconstruction and target sizes are different.'
        assert (cmg_recon.size(-3) % 2 == 0) and (cmg_recon.size(-2) % 2 == 0), \
            'Not impossible but not expected to have sides with odd lengths.'

        if self.residual_acs:
            # Intended: add the ACS semi-k-space as a residual (needed because of
            # complex cropping). Not implemented yet.
            raise NotImplementedError('Not ready yet.')

        recons = {
            'kspace_recons': fft2(cmg_recon),
            'cmg_recons': cmg_recon,
            'img_recons': complex_abs(cmg_recon),
        }

        if self.challenge == 'multicoil':
            cropped = center_crop(recons['img_recons'], (self.resolution, self.resolution))
            recons['rss_recons'] = root_sum_of_squares(cropped * extra_params['cmg_scales'], dim=1).squeeze()

        return recons
    def forward(self, outputs, targets, extra_params):
        """Combine magnitude and phase outputs into complex-image and k-space reconstructions.

        Args:
            outputs: pair ``(img_recon, phase_recon)``. The phase was shifted by
                +pi in the input transform and is shifted back here. Only batch
                size 1 is supported.
            targets: must contain ``'img_targets'``.
            extra_params: ``'img_scales'`` is read for the multi-coil challenge.

        Returns:
            dict with ``'img_recons'``, ``'phase_recons'``, ``'cmg_recons'``,
            ``'kspace_recons'`` and, for multi-coil, ``'rss_recons'`` (the only
            rescaled entry).
        """
        magnitude, shifted_phase = outputs

        if magnitude.size(0) > 1:
            raise NotImplementedError('Only one at a time for now.')

        img_target = targets['img_targets']
        magnitude = F.relu(magnitude)  # Magnitudes must be non-negative.
        assert magnitude.shape == img_target.shape, 'Reconstruction and target sizes are different.'

        # Undo the +pi pre-processing. No clamping is applied since the loss is MSE.
        phase = shifted_phase - math.pi
        real_part = magnitude * torch.cos(phase)
        imag_part = magnitude * torch.sin(phase)
        cmg_recon = torch.stack([real_part, imag_part], dim=-1)
        kspace_recon = fft2(cmg_recon)

        recons = dict(
            img_recons=magnitude,
            phase_recons=phase,
            cmg_recons=cmg_recon,
            kspace_recons=kspace_recon,
        )

        if self.challenge == 'multicoil':
            scaled = center_crop(magnitude, shape=(self.resolution, self.resolution)) * extra_params['img_scales']
            recons['rss_recons'] = root_sum_of_squares(scaled, dim=1).squeeze()

        return recons
    def __call__(self, masked_kspace, target, attrs, file_name, slice_num):
        """Turn already-masked k-space into a scaled, center-cropped magnitude image.

        Args:
            masked_kspace: k-space tensor of rank 3, 4 or 5. Lower ranks get
                leading singleton dimensions added with ``expand``.
            target, file_name, slice_num: not used by this transform.
            attrs: mapping merged into the returned ``extra_params``.

        Returns:
            ``(input_image, targets, extra_params)``. ``targets`` holds the image
            under ``'img_inputs'`` and its root-sum-of-squares over dim 1 under
            ``'rss_inputs'``. ``extra_params`` holds ``'img_scales'`` plus ``attrs``.

        Raises:
            RuntimeError: if the rank is unsupported.
            NotImplementedError: if the batch dimension is not 1.
        """
        assert isinstance(masked_kspace, torch.Tensor), \
            'k-space target was expected to be a Pytorch Tensor.'

        rank = masked_kspace.dim()
        if rank == 3:  # Single-coil data the collate function did not expand.
            masked_kspace = masked_kspace.expand(1, 1, -1, -1, -1)
        elif rank == 4:  # Multi-coil data the collate function did not expand.
            masked_kspace = masked_kspace.expand(1, -1, -1, -1, -1)
        elif rank != 5:
            raise RuntimeError('k-space target has invalid shape!')

        if masked_kspace.size(0) != 1:
            raise NotImplementedError('Batch size should be 1 for now.')

        with torch.no_grad():
            magnitude = complex_abs(ifft2(masked_kspace))
            # Cropping is required if the RSS target is the 320 version. It is always
            # done so that targets need no extra processing, at some loss of flexibility.
            input_image = center_crop(magnitude, shape=(self.resolution, self.resolution))
            img_scale = torch.std(input_image)
            input_image /= img_scale

            extra_params = {'img_scales': img_scale}
            extra_params.update(attrs)

            input_rss = root_sum_of_squares(input_image, dim=1).squeeze()
            targets = {'img_inputs': input_image, 'rss_inputs': input_rss}

        return input_image, targets, extra_params
    def forward(self, semi_kspace_outputs, targets, extra_params):
        """Turn a padded semi-k-space network output into k-space, complex and magnitude reconstructions.

        Args:
            semi_kspace_outputs: network output whose last dimension is the
                padded width. It is cropped to the target width and converted
                with ``nchw_to_kspace``. Only batch size 1 is supported.
            targets: must contain ``'semi_kspace_targets'``.
            extra_params: keys read are
                - ``'weightings'`` if ``self.weighted``;
                - ``'num_low_frequency'`` if ``self.residual_acs``;
                - ``'masks'`` if ``self.replace``;
                - ``'sk_scales'`` for the multi-coil challenge.

        Returns:
            dict with ``'semi_kspace_recons'``, ``'kspace_recons'``,
            ``'cmg_recons'``, ``'img_recons'`` and, for multi-coil,
            ``'rss_recons'``. Only ``'rss_recons'`` is rescaled by ``'sk_scales'``.
        """
        if semi_kspace_outputs.size(0) > 1:
            raise NotImplementedError('Only one at a time for now.')

        semi_kspace_targets = targets['semi_kspace_targets']
        # The outputs keep width in the last dim. The targets have a trailing
        # size-2 axis, so their width is at dim -2. Remove symmetric width padding.
        left = (semi_kspace_outputs.size(-1) - semi_kspace_targets.size(-2)) // 2
        right = left + semi_kspace_targets.size(-2)
        semi_kspace_recons = nchw_to_kspace(semi_kspace_outputs[..., left:right])

        assert semi_kspace_recons.shape == semi_kspace_targets.shape, 'Reconstruction and target sizes are different.'
        assert (semi_kspace_recons.size(-3) % 2 == 0) and (semi_kspace_recons.size(-2) % 2 == 0), \
            'Not impossible but not expected to have sides with odd lengths.'

        # Remove the weighting applied on the input side.
        if self.weighted:
            semi_kspace_recons = semi_kspace_recons / extra_params['weightings']

        if self.residual_acs:
            num_low_freqs = extra_params['num_low_frequency']
            acs_mask = find_acs_mask(semi_kspace_recons, num_low_freqs)
            semi_kspace_recons = semi_kspace_recons + acs_mask * semi_kspace_targets

        if self.replace:
            # Use target values where mask is 1 and the reconstruction where mask is 0.
            mask = extra_params['masks']
            semi_kspace_recons = semi_kspace_recons * (1 - mask) + semi_kspace_targets * mask

        kspace_recons = fft1(semi_kspace_recons, direction='height')
        cmg_recons = ifft1(semi_kspace_recons, direction='width')
        img_recons = complex_abs(cmg_recons)

        recons = {
            'semi_kspace_recons': semi_kspace_recons,
            'kspace_recons': kspace_recons,
            'cmg_recons': cmg_recons,
            'img_recons': img_recons
        }

        if self.challenge == 'multicoil':
            rss_recons = center_crop(img_recons, (self.resolution, self.resolution))
            rss_recons = root_sum_of_squares(rss_recons, dim=1).squeeze()
            # Multiply out-of-place. ``squeeze()`` returns a view of the RSS output, and
            # an in-place ``*=`` bumps its autograd version counter. If root_sum_of_squares
            # ends in an op that saves its output (e.g. torch.sqrt), backward would fail.
            # The inputs were divided by this scale, so it is multiplied back here.
            rss_recons = rss_recons * extra_params['sk_scales']
            recons['rss_recons'] = rss_recons

        return recons  # Returning scaled reconstructions. Not rescaled. RSS images are rescaled.
    def forward(self, img_output, targets, extra_params):
        """Crop, clamp and rescale an image output into the final reconstruction.

        Args:
            img_output: network output. Only batch size 1 is supported.
            targets: not used here.
            extra_params: must contain ``'img_scales'``, the factor the inputs
                were divided by.

        Returns:
            The ``resolution x resolution`` centre crop, clamped at 0 and
            multiplied by ``'img_scales'``. For multi-coil it is combined over
            dim 1 by root-sum-of-squares. Singleton dimensions are squeezed.
        """
        if img_output.size(0) > 1:
            raise NotImplementedError('Only one at a time for now.')

        # Removing values below 0, which are impossible anyway.
        img_recon = F.relu(
            center_crop(img_output, shape=(self.resolution, self.resolution)))
        # Out-of-place multiply: ReLU saves its output for backward, so an in-place
        # ``*=`` here would make any later backward pass raise a version-counter error.
        img_recon = img_recon * extra_params['img_scales']

        if self.challenge == 'multicoil':
            img_recon = root_sum_of_squares(img_recon, dim=1)

        return img_recon.squeeze()
# Example #8 (score: 0)
    def forward(self, semi_kspace_output: Tensor, targets: dict,
                extra_params: dict):
        """Convert a 5-D semi-k-space output into semi-k-space, complex and magnitude reconstructions.

        ``semi_kspace_output`` must be 5-D with size 2 at dim 1. It is permuted
        to NCHW2 order. If its shape still differs from the target, it is
        centre-cropped along dim -2 to the target width.

        Args:
            semi_kspace_output: 5-D network output, batch size 1 only.
            targets: must contain ``'semi_kspace_targets'``.
            extra_params: keys read are
                - ``'weightings'`` if ``self.weighted``;
                - ``'masks'`` if ``self.replace``;
                - ``'sk_scales'`` for multi-coil.

        Returns:
            dict with ``'semi_kspace_recons'``, ``'cmg_recons'``, ``'img_recons'``
            and, for multi-coil, ``'rss_recons'`` (the only rescaled entry).
        """
        assert semi_kspace_output.dim() == 5 and semi_kspace_output.size(1) == 2, 'Invalid shape!'
        if semi_kspace_output.size(0) > 1:
            raise NotImplementedError('Only one at a time for now.')

        semi_kspace_target = targets['semi_kspace_targets']
        # Move the size-2 axis from dim 1 to the end (NCHW2 layout).
        semi_kspace_recon = semi_kspace_output.permute(dims=(0, 2, 3, 4, 1))

        if semi_kspace_recon.shape != semi_kspace_target.shape:
            target_width = semi_kspace_target.size(-2)
            start = (semi_kspace_recon.size(-2) - target_width) // 2
            semi_kspace_recon = semi_kspace_recon[..., start:start + target_width, :]

        assert semi_kspace_recon.shape == semi_kspace_target.shape, 'Reconstruction and target sizes are different.'
        assert (semi_kspace_recon.size(-3) % 2 == 0) and (semi_kspace_recon.size(-2) % 2 == 0), \
            'Not impossible but not expected to have sides with odd lengths.'

        if self.weighted:
            semi_kspace_recon = semi_kspace_recon / extra_params['weightings']

        if self.replace:
            # Target values where mask is 1, reconstruction where mask is 0.
            mask = extra_params['masks']
            semi_kspace_recon = semi_kspace_target * mask + (1 - mask) * semi_kspace_recon

        cmg_recon = ifft1(semi_kspace_recon, direction='width')
        img_recon = complex_abs(cmg_recon)
        recons = dict(
            semi_kspace_recons=semi_kspace_recon,
            cmg_recons=cmg_recon,
            img_recons=img_recon,
        )

        if self.challenge == 'multicoil':
            scaled = center_crop(img_recon, (self.resolution, self.resolution)) * extra_params['sk_scales']
            recons['rss_recons'] = root_sum_of_squares(scaled, dim=1).squeeze()

        return recons
    def forward(self, output, targets, extra_params):
        """Dispatch ``output`` by ``self.output_mode`` and optionally add an RSS image.

        Args:
            output: network output. Only batch size 1 is supported.
            targets: passed through to the mode-specific helper.
            extra_params: passed through. ``'cmg_scales'`` is read for multi-coil.

        Returns:
            The dict from ``self._cmg_output`` or ``self._img_output``. For the
            multi-coil challenge, ``'rss_recons'`` is added. It is computed from
            the ``resolution x resolution`` centre crop of ``'img_recons'``,
            multiplied by ``'cmg_scales'``.

        Raises:
            NotImplementedError: for batch size > 1 or an unknown ``output_mode``.
        """
        if output.size(0) > 1:
            raise NotImplementedError('Batch size is expected to be 1.')

        if self.output_mode == 'cmg':
            recons = self._cmg_output(output, targets, extra_params)
        elif self.output_mode == 'img':
            recons = self._img_output(output, targets, extra_params)
        else:
            raise NotImplementedError('Invalid output mode.')

        if self.challenge == 'multicoil':
            rss_recon = center_crop(recons['img_recons'],
                                    shape=(self.resolution, self.resolution))
            # Scale out-of-place. A centre crop is a slicing view of recons['img_recons'],
            # so an in-place ``*=`` would silently rescale the centre of the returned
            # image reconstruction. It could also break autograd on that tensor.
            rss_recon = rss_recon * extra_params['cmg_scales']
            rss_recon = root_sum_of_squares(rss_recon, dim=1).squeeze()
            recons['rss_recons'] = rss_recon

        return recons
    def __call__(self, masked_kspace, target, attrs, file_name, slice_num):
        """Turn masked k-space into a scaled magnitude image, optionally center-cropped.

        Args:
            masked_kspace: k-space tensor of rank 3, 4 or 5. Lower ranks get
                leading singleton dimensions added with ``expand``.
            target, file_name, slice_num: not used by this transform.
            attrs: mapping merged into the returned ``extra_params``.

        Returns:
            ``(image, targets, extra_params)``. ``image`` is divided by its own
            std, ``targets`` holds it under ``'img_inputs'``, and ``extra_params``
            holds ``'img_scales'`` plus ``attrs``. No width padding is applied.

        Raises:
            RuntimeError: if the rank is unsupported.
            NotImplementedError: if the batch dimension is not 1.
        """
        assert isinstance(masked_kspace, torch.Tensor), \
            'k-space target was expected to be a Pytorch Tensor.'

        rank = masked_kspace.dim()
        if rank == 3:  # Single-coil data the collate function did not expand.
            masked_kspace = masked_kspace.expand(1, 1, -1, -1, -1)
        elif rank == 4:  # Multi-coil data the collate function did not expand.
            masked_kspace = masked_kspace.expand(1, -1, -1, -1, -1)
        elif rank != 5:
            raise RuntimeError('k-space target has invalid shape!')

        if masked_kspace.size(0) != 1:
            raise NotImplementedError('Batch size should be 1 for now.')

        with torch.no_grad():
            image = complex_abs(ifft2(masked_kspace))
            if self.crop_center:
                image = center_crop(image, shape=(self.resolution, self.resolution))

            img_scale = torch.std(image)
            image /= img_scale

            extra_params = {'img_scales': img_scale}
            extra_params.update(attrs)

            targets = {'img_inputs': image}

        return image, targets, extra_params
def test_center_crop(shape, target_shape):
    """Check that ``data_transforms.center_crop`` produces exactly ``target_shape``.

    ``target_shape`` must be a list, because it is compared against a list of
    the cropped array's dimensions.
    """
    source = create_tensor(shape)
    cropped = data_transforms.center_crop(source, target_shape).numpy()
    assert [*cropped.shape] == target_shape
    def __call__(self, kspace_target, target, attrs, file_name, slice_num):
        """Mask k-space, build a scaled input image, and optionally crop a random patch and flip.

        Args:
            kspace_target: k-space tensor of rank 3, 4 or 5. Lower ranks get
                leading singleton dimensions added with ``expand``.
            target: RSS target image. It is indexed as ``target[rows, cols]``
                when patching, so it is presumably 2-D (TODO confirm).
            attrs: mapping merged into the returned ``extra_params``.
            file_name: used as the mask seed when ``self.use_seed`` is true.
            slice_num: not used by this transform.

        Returns:
            ``(input_image, targets, extra_params)``. ``targets`` holds
            ``'img_inputs'``, ``'rss_targets'`` and ``'rss_inputs'``.
            ``extra_params`` holds ``'img_scales'``, ``'masks'``, the mask
            ``info`` entries and ``attrs``.

        Raises:
            RuntimeError: if the rank is unsupported.
            NotImplementedError: if the batch dimension is not 1.
        """
        assert isinstance(
            kspace_target, torch.Tensor
        ), 'k-space target was expected to be a Pytorch Tensor.'
        if kspace_target.dim(
        ) == 3:  # If the collate function does not expand dimensions for single-coil.
            kspace_target = kspace_target.expand(1, 1, -1, -1, -1)
        elif kspace_target.dim(
        ) == 4:  # If the collate function does not expand dimensions for multi-coil.
            kspace_target = kspace_target.expand(1, -1, -1, -1, -1)
        elif kspace_target.dim(
        ) != 5:  # Expanded k-space should have 5 dimensions.
            raise RuntimeError('k-space target has invalid shape!')

        if kspace_target.size(0) != 1:
            raise NotImplementedError('Batch size should be 1 for now.')

        with torch.no_grad():
            # Apply mask
            seed = None if not self.use_seed else tuple(map(ord, file_name))
            masked_kspace, mask, info = apply_info_mask(
                kspace_target, self.mask_func, seed)

            input_image = complex_abs(ifft2(masked_kspace))
            # Cropping is mandatory if RSS means the 320 version. Not so if a larger image is used.
            # However, I think that removing target processing is worth the loss of flexibility.
            input_image = center_crop(input_image,
                                      shape=(self.resolution, self.resolution))
            img_scale = torch.std(input_image)
            input_image /= img_scale

            # No subtraction by the mean. I do not know if this is a good idea or not.

            if self.use_patch:
                assert self.resolution == 320
                # torch.randint's ``high`` is exclusive. Adding 1 makes the last valid
                # offset (resolution - patch_size) reachable. It also allows
                # patch_size == resolution, where high would otherwise be 0 and raise.
                num_offsets = self.resolution - self.patch_size + 1
                left = torch.randint(low=0, high=num_offsets, size=(1, )).squeeze()
                top = torch.randint(low=0, high=num_offsets, size=(1, )).squeeze()

                input_image = input_image[:, :, top:top + self.patch_size,
                                          left:left + self.patch_size]
                target = target[top:top + self.patch_size,
                                left:left + self.patch_size]

            extra_params = {'img_scales': img_scale, 'masks': mask}
            extra_params.update(info)
            extra_params.update(attrs)

            # Data augmentation by flipping images up-down and left-right.
            if self.augment_data:
                flip_lr = torch.rand(()) < 0.5
                flip_ud = torch.rand(()) < 0.5

                if flip_lr and flip_ud:
                    input_image = torch.flip(input_image, dims=(-2, -1))
                    target = torch.flip(target, dims=(-2, -1))

                elif flip_ud:
                    input_image = torch.flip(input_image, dims=(-2, ))
                    target = torch.flip(target, dims=(-2, ))

                elif flip_lr:
                    input_image = torch.flip(input_image, dims=(-1, ))
                    target = torch.flip(target, dims=(-1, ))

            # Use plurals as keys to reduce confusion.
            input_rss = root_sum_of_squares(input_image, dim=1).squeeze()
            targets = {
                'img_inputs': input_image,
                'rss_targets': target,
                'rss_inputs': input_rss
            }

        return input_image, targets, extra_params
    def __call__(self, kspace_target, target, attrs, file_name, slice_num):
        """Mask k-space and produce scaled input and target magnitude images.

        The masked input image and the fully sampled target image are both
        optionally centre-cropped. Both are divided by the std of the input
        image. With ``self.augment_data``, the input, the image target and
        ``target`` are flipped together along the same randomly chosen axes.

        Args:
            kspace_target: k-space tensor of rank 3, 4 or 5. Lower ranks get
                leading singleton dimensions added with ``expand``.
            target: passed through as ``'rss_targets'`` for multi-coil and
                flipped if augmenting.
            attrs: mapping merged into the returned ``extra_params``.
            file_name: used as the mask seed when ``self.use_seed`` is true.
            slice_num: not used by this transform.

        Returns:
            ``(image, targets, extra_params)``. ``targets`` holds
            ``'img_targets'`` and ``'img_inputs'``, plus ``'rss_targets'`` for
            multi-coil.

        Raises:
            RuntimeError: if the rank is unsupported.
            NotImplementedError: if the batch dimension is not 1.
        """
        assert isinstance(kspace_target, torch.Tensor), \
            'k-space target was expected to be a Pytorch Tensor.'

        rank = kspace_target.dim()
        if rank == 3:  # Single-coil data the collate function did not expand.
            kspace_target = kspace_target.expand(1, 1, -1, -1, -1)
        elif rank == 4:  # Multi-coil data the collate function did not expand.
            kspace_target = kspace_target.expand(1, -1, -1, -1, -1)
        elif rank != 5:
            raise RuntimeError('k-space target has invalid shape!')

        if kspace_target.size(0) != 1:
            raise NotImplementedError('Batch size should be 1 for now.')

        def to_magnitude(kspace):
            # Magnitude image of ``kspace``, centre-cropped when configured.
            magnitude = complex_abs(ifft2(kspace))
            if self.crop_center:
                magnitude = center_crop(magnitude, shape=(self.resolution, self.resolution))
            return magnitude

        with torch.no_grad():
            seed = tuple(map(ord, file_name)) if self.use_seed else None
            masked_kspace, mask, info = apply_info_mask(kspace_target, self.mask_func, seed)

            image = to_magnitude(masked_kspace)
            img_scale = torch.std(image)
            image /= img_scale

            extra_params = {'img_scales': img_scale, 'masks': mask}
            extra_params.update(info)
            extra_params.update(attrs)

            # The target is normalized with the *input* image's scale.
            img_target = to_magnitude(kspace_target)
            img_target /= img_scale

            if self.augment_data:
                flip_lr = torch.rand(()) < 0.5
                flip_ud = torch.rand(()) < 0.5
                if flip_lr and flip_ud:
                    flip_dims = (-2, -1)
                elif flip_ud:
                    flip_dims = (-2, )
                elif flip_lr:
                    flip_dims = (-1, )
                else:
                    flip_dims = None

                if flip_dims is not None:
                    image = torch.flip(image, dims=flip_dims)
                    img_target = torch.flip(img_target, dims=flip_dims)
                    target = torch.flip(target, dims=flip_dims)

            targets = {'img_targets': img_target, 'img_inputs': image}
            if self.challenge == 'multicoil':
                targets['rss_targets'] = target

        return image, targets, extra_params
# Example #14 (score: 0)
    def __call__(self, kspace_target, target, attrs, file_name, slice_num):
        """Mask k-space, build a scaled input image, and collect RSS inputs/targets.

        Args:
            kspace_target: k-space tensor of rank 3, 4 or 5. Lower ranks get
                leading singleton dimensions added with ``expand``.
            target: RSS target image, or the integer ``0`` as a sentinel. The
                sentinel means "compute the centre-cropped RSS target from
                ``kspace_target``".
            attrs: mapping merged into the returned ``extra_params``. When
                ``self.fat_info`` is set it must provide ``'acquisition'``.
            file_name: used as the mask seed when ``self.use_seed`` is true.
            slice_num: not used by this transform.

        Returns:
            ``(input_image, targets, extra_params)``. ``targets`` holds
            ``'img_inputs'``, ``'rss_targets'`` and ``'rss_inputs'``. When
            ``self.fat_info`` is set, the returned ``input_image`` has one extra
            channel (all 1.0 for ``'CORPDFS_FBK'`` acquisitions, else 0.0), while
            ``targets['img_inputs']`` keeps the image without it.

        Raises:
            RuntimeError: if the rank is unsupported.
            NotImplementedError: if the batch dimension is not 1.
        """
        assert isinstance(
            kspace_target, torch.Tensor
        ), 'k-space target was expected to be a Pytorch Tensor.'
        if kspace_target.dim(
        ) == 3:  # If the collate function does not expand dimensions for single-coil.
            kspace_target = kspace_target.expand(1, 1, -1, -1, -1)
        elif kspace_target.dim(
        ) == 4:  # If the collate function does not expand dimensions for multi-coil.
            kspace_target = kspace_target.expand(1, -1, -1, -1, -1)
        elif kspace_target.dim(
        ) != 5:  # Expanded k-space should have 5 dimensions.
            raise RuntimeError('k-space target has invalid shape!')

        if kspace_target.size(0) != 1:
            raise NotImplementedError('Batch size should be 1 for now.')

        with torch.no_grad():
            # Apply mask
            seed = None if not self.use_seed else tuple(map(ord, file_name))
            masked_kspace, mask, info = apply_info_mask(
                kspace_target, self.mask_func, seed)

            input_image = complex_abs(ifft2(masked_kspace))
            # Cropping is mandatory if RSS means the 320 version. Not so if a larger image is used.
            # However, I think that removing target processing is worth the loss of flexibility.
            input_image = center_crop(input_image,
                                      shape=(self.resolution, self.resolution))
            img_scale = torch.std(input_image)
            input_image /= img_scale

            extra_params = {'img_scales': img_scale, 'masks': mask}
            extra_params.update(info)
            extra_params.update(attrs)

            # ``target is 0`` relied on CPython small-int caching and is a SyntaxWarning
            # since Python 3.8. Test for the exact-int sentinel explicitly instead. The
            # type check also keeps tensors away from the ambiguous ``==`` comparison.
            if type(target) is int and target == 0:
                target = center_crop(
                    root_sum_of_squares(complex_abs(ifft2(kspace_target)),
                                        dim=1),
                    shape=(self.resolution, self.resolution)).squeeze()

            # Data augmentation by flipping images up-down and left-right.
            if self.augment_data:
                flip_lr = torch.rand(()) < 0.5
                flip_ud = torch.rand(()) < 0.5

                if flip_lr and flip_ud:
                    input_image = torch.flip(input_image, dims=(-2, -1))
                    target = torch.flip(target, dims=(-2, -1))

                elif flip_ud:
                    input_image = torch.flip(input_image, dims=(-2, ))
                    target = torch.flip(target, dims=(-2, ))

                elif flip_lr:
                    input_image = torch.flip(input_image, dims=(-1, ))
                    target = torch.flip(target, dims=(-1, ))

            # Use plurals as keys to reduce confusion.
            input_rss = root_sum_of_squares(input_image, dim=1).squeeze()
            targets = {
                'img_inputs': input_image,
                'rss_targets': target,
                'rss_inputs': input_rss
            }

            if self.fat_info:
                fat_supp = extra_params[
                    'acquisition'] == 'CORPDFS_FBK'  # Fat suppressed acquisition.
                batch, _, height, width = input_image.shape
                # Build the constant channel on the image's own device and dtype.
                # Otherwise torch.cat fails when self.device differs from where
                # input_image lives.
                fat_info = torch.full((batch, 1, height, width),
                                      float(fat_supp),
                                      dtype=input_image.dtype,
                                      device=input_image.device)
                input_image = torch.cat([input_image, fat_info], dim=1)

        return input_image, targets, extra_params
# Example #15 (score: 0)
import torch
import h5py
import numpy as np

from data.data_transforms import root_sum_of_squares, center_crop, ifft2, to_tensor, complex_abs

def _rss_matches_stored(path):
    """Report whether the RSS image recomputed from ``path``'s k-space matches the stored one.

    Reads ``'kspace'`` and ``'reconstruction_rss'`` from the HDF5 file. It takes
    the coil-wise magnitude images, combines them by root-sum-of-squares over
    dim 1, centre-crops to 320x320, and compares with ``np.allclose``.
    """
    with h5py.File(path, 'r') as hf:
        raw_kspace = hf['kspace'][()]
        stored_rss = hf['reconstruction_rss'][()]

    coil_images = complex_abs(ifft2(to_tensor(raw_kspace)))
    combined = root_sum_of_squares(coil_images, dim=1)
    recomputed = center_crop(combined, shape=(320, 320)).squeeze().numpy()
    return np.allclose(recomputed, stored_rss)


print(_rss_matches_stored('/media/veritas/D/FastMRI/multicoil_val/file1000229.h5'))