def __call__(self, kspace_target, target, attrs, file_name, slice_num):
        assert isinstance(kspace_target, torch.Tensor), \
            'k-space target was expected to be a PyTorch Tensor.'
        if kspace_target.dim() == 3:  # Collate function did not expand dimensions for single-coil data.
            kspace_target = kspace_target.expand(1, 1, -1, -1, -1)
        elif kspace_target.dim() == 4:  # Collate function did not expand dimensions for multi-coil data.
            kspace_target = kspace_target.expand(1, -1, -1, -1, -1)
        elif kspace_target.dim() != 5:  # Expanded k-space should have 5 dimensions.
            raise RuntimeError('k-space target has invalid shape!')

        if kspace_target.size(0) != 1:
            raise NotImplementedError('Batch size should be 1 for now.')

        with torch.no_grad():
            # Apply mask
            seed = tuple(map(ord, file_name))
            masked_kspace, mask, info = apply_info_mask(
                kspace_target, self.mask_func, seed)

            image = complex_abs(ifft2(masked_kspace))
            if self.crop_center:
                image = center_crop(image,
                                    shape=(self.resolution, self.resolution))
            else:  # Super-hackish temporary line.  # TODO: Fix this thing later!
                image = center_crop(image, shape=(352, image.size(-1)))

            margin = image.size(-1) % self.divisor
            if margin > 0:
                pad = [(self.divisor - margin) // 2,
                       (1 + self.divisor - margin) // 2]
            else:  # This is a fix to prevent padding by half the divisor when margin=0.
                pad = [0, 0]

            # This pads at the last dimension of a tensor with 0.
            image = F.pad(image, pad=pad, value=0)

            # Also a hack: the scale is the std of the center-cropped region.
            img_scale = torch.std(
                center_crop(image, shape=(self.resolution, self.resolution)))
            image /= img_scale

            extra_params = {'img_scales': img_scale, 'masks': mask}
            extra_params.update(info)
            extra_params.update(attrs)

            # Use plurals as keys to reduce confusion.
            targets = {'img_inputs': image}

        return image, targets, extra_params
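# The divisor padding above keeps the width compatible with networks that
# down-sample by fixed factors (e.g. U-Nets). A minimal self-contained sketch
# of the same logic; `pad_to_divisor` is a hypothetical helper, not part of
# the transforms above.
import torch
import torch.nn.functional as F

def pad_to_divisor(image, divisor):
    """Symmetrically zero-pad the last dimension until it is divisible by `divisor`."""
    margin = image.size(-1) % divisor
    if margin > 0:
        pad = [(divisor - margin) // 2, (1 + divisor - margin) // 2]
    else:  # Prevent padding by half the divisor when the width already fits.
        pad = [0, 0]
    return F.pad(image, pad=pad, value=0)

# Example: width 322 -> 336 when divisor=16 (pad of 7 on each side).
assert pad_to_divisor(torch.zeros(1, 1, 320, 322), divisor=16).size(-1) == 336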
def test_apply_uniform_mask(shape, center_fractions, accelerations):
    mask_func = UniformMaskFunc(center_fractions, accelerations)
    expected_mask, expected_info = mask_func(shape, seed=123)
    tensor = create_tensor(shape)
    output, mask, info = data_transforms.apply_info_mask(tensor,
                                                         mask_func,
                                                         seed=123)
    assert isinstance(info, dict)
    assert output.shape == tensor.shape
    assert mask.shape == expected_mask.shape
    assert np.all(expected_mask.numpy() == mask.numpy())
    assert np.all((output * mask).numpy() == output.numpy())
    assert expected_info == info
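# The test above pins down the contract of `apply_info_mask`: it returns the
# masked k-space, the mask, and an info dict, and masking is idempotent
# (output * mask == output). A hypothetical sketch consistent with those
# assertions, assuming mask_func(shape, seed) -> (mask, info) as with
# UniformMaskFunc above; the real implementation may differ.
def apply_info_mask_sketch(kspace, mask_func, seed=None):
    mask, info = mask_func(kspace.shape, seed)
    masked_kspace = kspace * mask  # Broadcasting zeroes out unsampled columns.
    return masked_kspace, mask, info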
    def __call__(self, kspace_target, target, attrs, file_name, slice_num):
        assert isinstance(kspace_target, torch.Tensor), \
            'k-space target was expected to be a PyTorch Tensor.'
        if kspace_target.dim() == 3:  # Collate function did not expand dimensions for single-coil data.
            kspace_target = kspace_target.expand(1, 1, -1, -1, -1)
        elif kspace_target.dim() == 4:  # Collate function did not expand dimensions for multi-coil data.
            kspace_target = kspace_target.expand(1, -1, -1, -1, -1)
        elif kspace_target.dim() != 5:  # Expanded k-space should have 5 dimensions.
            raise RuntimeError('k-space target has invalid shape!')

        if kspace_target.size(0) != 1:
            raise NotImplementedError('Batch size should be 1 for now.')

        with torch.no_grad():
            seed = None if not self.use_seed else tuple(map(ord, file_name))
            masked_kspace, mask, info = apply_info_mask(
                kspace_target, self.mask_func, seed)

            num_low_freqs = info['num_low_frequency']
            acs_mask = self.find_acs_mask(kspace_target, num_low_freqs)
            acs_kspace = kspace_target * acs_mask
            acs_image = complex_center_crop(ifft2(acs_kspace),
                                            shape=(self.resolution, self.resolution))
            semi_kspace_acs = fft1(acs_image, direction='width')

            complex_image = ifft2(masked_kspace)
            complex_image = complex_center_crop(complex_image,
                                                shape=(self.resolution,
                                                       self.resolution))
            # img_input is not actually an input but what the input would look like in the image domain.
            img_input = complex_abs(complex_image)

            # Direction is fixed due to challenge conditions.
            semi_kspace = fft1(complex_image, direction='width')
            weighting = self.weight_func(semi_kspace)
            semi_kspace *= weighting

            sk_scale = torch.std(semi_kspace)
            semi_kspace /= sk_scale
            inputs = kspace_to_nchw(semi_kspace)

            extra_params = {
                'sk_scales': sk_scale,
                'masks': mask,
                'weightings': weighting
            }
            extra_params.update(info)
            extra_params.update(attrs)

            # Recall that the Fourier transform is a linear transform.
            cmg_target = ifft2(kspace_target)
            cmg_target = complex_center_crop(cmg_target,
                                             shape=(self.resolution,
                                                    self.resolution))
            cmg_target /= sk_scale
            img_target = complex_abs(cmg_target)
            semi_kspace_target = fft1(cmg_target, direction='width')
            kspace_target = fft2(cmg_target)

            # Use plurals as keys to reduce confusion.
            targets = {
                'semi_kspace_targets': semi_kspace_target,
                'kspace_targets': kspace_target,
                'cmg_targets': cmg_target,
                'img_targets': img_target,
                'img_inputs': img_input,
                'semi_kspace_acss': semi_kspace_acs
            }

            if self.challenge == 'multicoil':
                targets['rss_targets'] = target

        return inputs, targets, extra_params
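# `find_acs_mask` is not shown in this section. A hypothetical sketch of how an
# ACS (auto-calibration signal) mask could be built from num_low_freqs, keeping
# only the central phase-encoding columns. The (batch, coil, height, width, 2)
# complex-last layout is an assumption taken from the dimension checks above.
import torch

def find_acs_mask_sketch(kspace, num_low_freqs):
    width = kspace.size(-2)
    mask = torch.zeros(width, dtype=kspace.dtype, device=kspace.device)
    left = (width - num_low_freqs + 1) // 2
    mask[left:left + num_low_freqs] = 1
    return mask.view(1, 1, 1, width, 1)  # Broadcasts over batch, coil, height, real/imag.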
    def __call__(self, kspace_target, target, attrs, file_name, slice_num):
        assert isinstance(kspace_target, torch.Tensor), \
            'k-space target was expected to be a PyTorch Tensor.'
        if kspace_target.dim() == 3:  # Collate function did not expand dimensions for single-coil data.
            kspace_target = kspace_target.expand(1, 1, -1, -1, -1)
        elif kspace_target.dim() == 4:  # Collate function did not expand dimensions for multi-coil data.
            kspace_target = kspace_target.expand(1, -1, -1, -1, -1)
        elif kspace_target.dim() != 5:  # Expanded k-space should have 5 dimensions.
            raise RuntimeError('k-space target has invalid shape!')

        if kspace_target.size(0) != 1:
            raise NotImplementedError('Batch size should be 1 for now.')

        with torch.no_grad():
            # Apply mask
            seed = None if not self.use_seed else tuple(map(ord, file_name))
            masked_kspace, mask, info = apply_info_mask(
                kspace_target, self.mask_func, seed)

            input_image = complex_abs(ifft2(masked_kspace))
            # Cropping is mandatory if RSS refers to the 320-resolution version, though not
            # if a larger image is used. However, removing target processing is worth the
            # loss of flexibility.
            input_image = center_crop(input_image,
                                      shape=(self.resolution, self.resolution))
            img_scale = torch.std(input_image)
            input_image /= img_scale

            # The mean is not subtracted; whether this helps or hurts is untested.

            if self.use_patch:
                assert self.resolution == 320
                left = torch.randint(low=0,
                                     high=self.resolution - self.patch_size,
                                     size=(1, )).squeeze()
                top = torch.randint(low=0,
                                    high=self.resolution - self.patch_size,
                                    size=(1, )).squeeze()

                input_image = input_image[:, :, top:top + self.patch_size,
                                          left:left + self.patch_size]
                target = target[top:top + self.patch_size,
                                left:left + self.patch_size]

            extra_params = {'img_scales': img_scale, 'masks': mask}
            extra_params.update(info)
            extra_params.update(attrs)

            # Data augmentation by flipping images up-down and left-right.
            if self.augment_data:
                flip_lr = torch.rand(()) < 0.5
                flip_ud = torch.rand(()) < 0.5

                if flip_lr and flip_ud:
                    input_image = torch.flip(input_image, dims=(-2, -1))
                    target = torch.flip(target, dims=(-2, -1))

                elif flip_ud:
                    input_image = torch.flip(input_image, dims=(-2, ))
                    target = torch.flip(target, dims=(-2, ))

                elif flip_lr:
                    input_image = torch.flip(input_image, dims=(-1, ))
                    target = torch.flip(target, dims=(-1, ))

            # Use plurals as keys to reduce confusion.
            input_rss = root_sum_of_squares(input_image, dim=1).squeeze()
            targets = {
                'img_inputs': input_image,
                'rss_targets': target,
                'rss_inputs': input_rss
            }

        return input_image, targets, extra_params
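# `root_sum_of_squares` above collapses the coil dimension. For reference, a
# minimal sketch of the standard RSS reduction over real-valued per-coil
# magnitudes along `dim`:
import torch

def rss_sketch(tensor, dim=1):
    return torch.sqrt((tensor ** 2).sum(dim=dim))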
    def __call__(self, kspace_target, target, attrs, file_name, slice_num):
        assert isinstance(kspace_target, torch.Tensor), \
            'k-space target was expected to be a PyTorch Tensor.'
        if kspace_target.dim() == 3:  # Collate function did not expand dimensions for single-coil data.
            kspace_target = kspace_target.expand(1, 1, -1, -1, -1)
        elif kspace_target.dim() == 4:  # Collate function did not expand dimensions for multi-coil data.
            kspace_target = kspace_target.expand(1, -1, -1, -1, -1)
        elif kspace_target.dim() != 5:  # Expanded k-space should have 5 dimensions.
            raise RuntimeError('k-space target has invalid shape!')

        if kspace_target.size(0) != 1:
            raise NotImplementedError('Batch size should be 1 for now.')

        with torch.no_grad():
            # Apply mask
            seed = None if not self.use_seed else tuple(map(ord, file_name))
            masked_kspace, mask, info = apply_info_mask(
                kspace_target, self.mask_func, seed)

            image = complex_abs(ifft2(masked_kspace))
            if self.crop_center:
                image = center_crop(image,
                                    shape=(self.resolution, self.resolution))
            img_scale = torch.std(image)
            image /= img_scale

            extra_params = {'img_scales': img_scale, 'masks': mask}
            extra_params.update(info)
            extra_params.update(attrs)

            img_target = complex_abs(ifft2(kspace_target))
            if self.crop_center:
                img_target = center_crop(img_target,
                                         shape=(self.resolution,
                                                self.resolution))
            img_target /= img_scale

            # Data augmentation by flipping images up-down and left-right.
            if self.augment_data:
                flip_lr = torch.rand(()) < 0.5
                flip_ud = torch.rand(()) < 0.5

                if flip_lr and flip_ud:
                    image = torch.flip(image, dims=(-2, -1))
                    img_target = torch.flip(img_target, dims=(-2, -1))
                    target = torch.flip(target, dims=(-2, -1))

                elif flip_ud:
                    image = torch.flip(image, dims=(-2, ))
                    img_target = torch.flip(img_target, dims=(-2, ))
                    target = torch.flip(target, dims=(-2, ))

                elif flip_lr:
                    image = torch.flip(image, dims=(-1, ))
                    img_target = torch.flip(img_target, dims=(-1, ))
                    target = torch.flip(target, dims=(-1, ))

            # Use plurals as keys to reduce confusion.
            targets = {'img_targets': img_target, 'img_inputs': image}

            if self.challenge == 'multicoil':
                targets['rss_targets'] = target

        return image, targets, extra_params
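# Both the input and the image target above are divided by the same img_scale,
# so a matching output transform would multiply the model output by the stored
# scale before computing metrics. A hedged sketch using the plural keys above:
def restore_scale_sketch(img_output, extra_params):
    return img_output * extra_params['img_scales']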
    def __call__(self, kspace_target, target, attrs, file_name, slice_num):
        assert isinstance(kspace_target, torch.Tensor), \
            'k-space target was expected to be a PyTorch Tensor.'
        if kspace_target.dim() == 3:  # Collate function did not expand dimensions for single-coil data.
            kspace_target = kspace_target.expand(1, 1, -1, -1, -1)
        elif kspace_target.dim() == 4:  # Collate function did not expand dimensions for multi-coil data.
            kspace_target = kspace_target.expand(1, -1, -1, -1, -1)
        elif kspace_target.dim() != 5:  # Expanded k-space should have 5 dimensions.
            raise RuntimeError('k-space target has invalid shape!')

        if kspace_target.size(0) != 1:
            raise NotImplementedError('Batch size should be 1 for now.')

        with torch.no_grad():
            # Apply mask
            seed = None if not self.use_seed else tuple(map(ord, file_name))
            masked_kspace, mask, info = apply_info_mask(
                kspace_target, self.mask_func, seed)

            # Complex image made from down-sampled k-space.
            complex_image = ifft2(masked_kspace)

            if self.crop_center:
                complex_image = complex_center_crop(complex_image,
                                                    shape=(self.resolution,
                                                           self.resolution))

            cmg_scale = torch.std(complex_image)
            complex_image /= cmg_scale

            extra_params = {'cmg_scales': cmg_scale, 'masks': mask}
            extra_params.update(info)
            extra_params.update(attrs)

            # Recall that the Fourier transform is a linear transform.
            kspace_target /= cmg_scale
            cmg_target = ifft2(kspace_target)

            if self.crop_center:
                cmg_target = complex_center_crop(cmg_target,
                                                 shape=(self.resolution,
                                                        self.resolution))

            # Data augmentation by flipping images up-down and left-right.
            if self.augment_data:  # No rotation implemented.
                flip_lr = torch.rand(()) < 0.5
                flip_ud = torch.rand(()) < 0.5

                if flip_lr and flip_ud:
                    # Last dim is real/complex dimension for complex image and target.
                    complex_image = torch.flip(complex_image, dims=(-3, -2))
                    cmg_target = torch.flip(cmg_target, dims=(-3, -2))
                    # target has only two dimensions: height and width.
                    target = torch.flip(target, dims=(-2, -1))
                    kspace_target = fft2(cmg_target)

                elif flip_ud:
                    complex_image = torch.flip(complex_image, dims=(-3, ))
                    cmg_target = torch.flip(cmg_target, dims=(-3, ))
                    target = torch.flip(target, dims=(-2, ))
                    kspace_target = fft2(cmg_target)

                elif flip_lr:
                    complex_image = torch.flip(complex_image, dims=(-2, ))
                    cmg_target = torch.flip(cmg_target, dims=(-2, ))
                    target = torch.flip(target, dims=(-1, ))
                    kspace_target = fft2(cmg_target)

            # The image target is obtained after flipping the complex image.
            # This removes the need to flip the image target.
            img_target = complex_abs(cmg_target)
            img_inputs = complex_abs(complex_image)

            # Use plurals as keys to reduce confusion.
            targets = {
                'kspace_targets': kspace_target,
                'cmg_targets': cmg_target,
                'img_targets': img_target,
                'cmg_inputs': complex_image,
                'img_inputs': img_inputs
            }

            if self.challenge == 'multicoil':
                targets['rss_targets'] = target

            # Creating concatenated image of real/imag/abs channels.
            concat_image = torch.cat(
                [complex_image, img_inputs.unsqueeze(dim=-1)], dim=-1)

            # Converting to NCHW format for CNN.
            inputs = kspace_to_nchw(concat_image)

        return inputs, targets, extra_params
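# `kspace_to_nchw` is assumed to fold the trailing real/imaginary dimension
# into the channel dimension so a CNN can consume the tensor. A hypothetical
# sketch for a (batch, coil, height, width, 2) layout:
import torch

def kspace_to_nchw_sketch(tensor):
    n, c, h, w, two = tensor.shape
    assert two == 2, 'Expected a trailing real/imaginary dimension of size 2.'
    return tensor.permute(0, 1, 4, 2, 3).reshape(n, 2 * c, h, w)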
    def __call__(self, kspace_target, target, attrs, file_name, slice_num):
        assert isinstance(kspace_target, torch.Tensor), \
            'k-space target was expected to be a PyTorch Tensor.'
        if kspace_target.dim() == 3:  # Collate function did not expand dimensions for single-coil data.
            kspace_target = kspace_target.expand(1, 1, -1, -1, -1)
        elif kspace_target.dim() == 4:  # Collate function did not expand dimensions for multi-coil data.
            kspace_target = kspace_target.expand(1, -1, -1, -1, -1)
        elif kspace_target.dim() != 5:  # Expanded k-space should have 5 dimensions.
            raise RuntimeError('k-space target has invalid shape!')

        if kspace_target.size(0) != 1:
            raise NotImplementedError('Batch size should be 1 for now.')

        with torch.no_grad():
            # Apply mask
            seed = None if not self.use_seed else tuple(map(ord, file_name))
            masked_kspace, mask, info = apply_info_mask(
                kspace_target, self.mask_func, seed)

            # img_input is not actually an input but what the input would look like in the image domain.
            img_input = complex_abs(ifft2(masked_kspace))

            semi_kspace = ifft1(masked_kspace, direction='height')

            weighting = self.weight_func(semi_kspace)
            semi_kspace *= weighting

            # The weighting slope is irrelevant here: after standardization the result is
            # identical for any slope. Reordering the operations could change that, but then
            # the inputs would no longer be standardized.
            sk_scale = torch.std(semi_kspace)

            semi_kspace /= sk_scale
            inputs = kspace_to_nchw(semi_kspace)

            extra_params = {
                'sk_scales': sk_scale,
                'masks': mask,
                'weightings': weighting
            }
            extra_params.update(info)
            extra_params.update(attrs)

            # Recall that the Fourier transform is a linear transform.
            kspace_target /= sk_scale
            cmg_target = ifft2(kspace_target)
            img_target = complex_abs(cmg_target)
            semi_kspace_target = ifft1(kspace_target, direction='height')

            # Use plurals as keys to reduce confusion.
            targets = {
                'semi_kspace_targets': semi_kspace_target,
                'kspace_targets': kspace_target,
                'cmg_targets': cmg_target,
                'img_targets': img_target,
                'img_inputs': img_input
            }

            if kspace_target.size(1) == 15:  # Multi-coil case: fastMRI multi-coil data has 15 coils.
                # Scaling is needed for metric comparison later.
                targets['rss_targets'] = target

            margin = inputs.size(-1) % self.divisor
            if margin > 0:
                pad = [(self.divisor - margin) // 2,
                       (1 + self.divisor - margin) // 2]
            else:  # This is a fix to prevent padding by half the divisor when margin=0.
                pad = [0, 0]

            # This pads at the last dimension of a tensor with 0.
            inputs = F.pad(inputs, pad=pad, value=0)

        return inputs, targets, extra_params
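# The "Fourier transform is a linear transform" comment justifies scaling
# k-space instead of the image: dividing k-space by sk_scale divides the
# reconstruction by the same factor. A quick demonstration with torch.fft
# (the repo's own ifft2 likely adds shifts and normalization, so this only
# illustrates linearity):
import torch

k = torch.randn(8, 8, dtype=torch.complex64)
scale = 3.7
assert torch.allclose(torch.fft.ifft2(k) / scale, torch.fft.ifft2(k / scale))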
    def __call__(self, kspace_target, target, attrs, file_name, slice_num):
        assert isinstance(kspace_target, torch.Tensor), \
            'k-space target was expected to be a PyTorch Tensor.'
        if kspace_target.dim() == 3:  # Collate function did not expand dimensions for single-coil data.
            kspace_target = kspace_target.expand(1, 1, -1, -1, -1)
        elif kspace_target.dim() == 4:  # Collate function did not expand dimensions for multi-coil data.
            kspace_target = kspace_target.expand(1, -1, -1, -1, -1)
        elif kspace_target.dim() != 5:  # Expanded k-space should have 5 dimensions.
            raise RuntimeError('k-space target has invalid shape!')

        if kspace_target.size(0) != 1:
            raise NotImplementedError('Batch size should be 1 for now.')

        with torch.no_grad():
            # Apply mask
            seed = None if not self.use_seed else tuple(map(ord, file_name))
            masked_kspace, mask, info = apply_info_mask(
                kspace_target, self.mask_func, seed)

            complex_image = ifft2(masked_kspace)
            cmg_target = ifft2(kspace_target)

            if self.crop_center:
                complex_image = complex_center_crop(complex_image,
                                                    shape=(self.resolution,
                                                           self.resolution))
                cmg_target = complex_center_crop(cmg_target,
                                                 shape=(self.resolution,
                                                        self.resolution))

            # Data augmentation by flipping images up-down and left-right.
            if self.augment_data:
                flip_lr = torch.rand(()) < 0.5
                flip_ud = torch.rand(()) < 0.5

                if flip_lr and flip_ud:
                    # Last dim is real/complex dimension for complex image and target.
                    complex_image = torch.flip(complex_image, dims=(-3, -2))
                    cmg_target = torch.flip(cmg_target, dims=(-3, -2))
                    # target has only two dimensions: height and width.
                    target = torch.flip(target, dims=(-2, -1))

                elif flip_ud:
                    complex_image = torch.flip(complex_image, dims=(-3, ))
                    cmg_target = torch.flip(cmg_target, dims=(-3, ))
                    target = torch.flip(target, dims=(-2, ))

                elif flip_lr:
                    complex_image = torch.flip(complex_image, dims=(-2, ))
                    cmg_target = torch.flip(cmg_target, dims=(-2, ))
                    target = torch.flip(target, dims=(-1, ))

            # Add pi to the angles so that the phase lies in [0, 2*pi] for better learning.
            phase_input = torch.atan2(complex_image[..., 1], complex_image[..., 0])
            phase_input += math.pi  # Don't forget to subtract the pi in the output transform!

            img_input = complex_abs(complex_image)
            img_scale = torch.std(img_input)
            img_input /= img_scale

            cmg_target /= img_scale
            img_target = complex_abs(cmg_target)
            # Reconstruct the k-space target after cropping and image augmentation.
            kspace_target = fft2(cmg_target)
            phase_target = torch.atan2(cmg_target[..., 1], cmg_target[..., 0])

            extra_params = {'img_scales': img_scale, 'masks': mask}
            extra_params.update(info)
            extra_params.update(attrs)

            # Use plurals as keys to reduce confusion.
            targets = {
                'kspace_targets': kspace_target,
                'cmg_targets': cmg_target,
                'img_targets': img_target,
                'phase_targets': phase_target,
                'img_inputs': img_input
            }

            if self.challenge == 'multicoil':
                targets['rss_targets'] = target

            # Pass the magnitude and phase images to the model as a pair.
            inputs = (img_input, phase_input)

        return inputs, targets, extra_params
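# Because pi was added to the phase input, the output transform must subtract
# it before rebuilding a complex image. A hedged sketch of the inverse step,
# assuming the model predicts magnitude and the shifted phase:
import math
import torch

def magnitude_phase_to_complex_sketch(magnitude, shifted_phase):
    phase = shifted_phase - math.pi  # Undo the offset added in the input transform.
    return torch.stack([magnitude * torch.cos(phase),
                        magnitude * torch.sin(phase)], dim=-1)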
    def __call__(self, kspace_target, target, attrs, file_name, slice_num):
        assert isinstance(kspace_target, torch.Tensor), \
            'k-space target was expected to be a PyTorch Tensor.'
        if kspace_target.dim() == 3:  # Collate function did not expand dimensions for single-coil data.
            kspace_target = kspace_target.expand(1, 1, -1, -1, -1)
        elif kspace_target.dim() == 4:  # Collate function did not expand dimensions for multi-coil data.
            kspace_target = kspace_target.expand(1, -1, -1, -1, -1)
        elif kspace_target.dim() != 5:  # Expanded k-space should have 5 dimensions.
            raise RuntimeError('k-space target has invalid shape!')

        if kspace_target.size(0) != 1:
            raise NotImplementedError('Batch size should be 1 for now.')

        with torch.no_grad():
            # Apply mask
            seed = None if not self.use_seed else tuple(map(ord, file_name))
            masked_kspace, mask, info = apply_info_mask(
                kspace_target, self.mask_func, seed)

            input_image = complex_abs(ifft2(masked_kspace))
            # Cropping is mandatory if RSS refers to the 320-resolution version, though not
            # if a larger image is used. However, removing target processing is worth the
            # loss of flexibility.
            input_image = center_crop(input_image,
                                      shape=(self.resolution, self.resolution))
            img_scale = torch.std(input_image)
            input_image /= img_scale

            extra_params = {'img_scales': img_scale, 'masks': mask}
            extra_params.update(info)
            extra_params.update(attrs)

            # 0 is a sentinel meaning no pre-computed RSS target; `is 0` relies on int caching.
            if isinstance(target, int) and target == 0:
                target = center_crop(
                    root_sum_of_squares(complex_abs(ifft2(kspace_target)),
                                        dim=1),
                    shape=(self.resolution, self.resolution)).squeeze()

            # Data augmentation by flipping images up-down and left-right.
            if self.augment_data:
                flip_lr = torch.rand(()) < 0.5
                flip_ud = torch.rand(()) < 0.5

                if flip_lr and flip_ud:
                    input_image = torch.flip(input_image, dims=(-2, -1))
                    target = torch.flip(target, dims=(-2, -1))

                elif flip_ud:
                    input_image = torch.flip(input_image, dims=(-2, ))
                    target = torch.flip(target, dims=(-2, ))

                elif flip_lr:
                    input_image = torch.flip(input_image, dims=(-1, ))
                    target = torch.flip(target, dims=(-1, ))

            # Use plurals as keys to reduce confusion.
            input_rss = root_sum_of_squares(input_image, dim=1).squeeze()
            targets = {
                'img_inputs': input_image,
                'rss_targets': target,
                'rss_inputs': input_rss
            }

            if self.fat_info:
                # Fat-suppressed acquisition flag, broadcast as a constant channel.
                fat_supp = extra_params['acquisition'] == 'CORPDFS_FBK'
                batch, _, height, width = input_image.shape
                fat_info = torch.ones(size=(batch, 1, height, width),
                                      device=self.device) * fat_supp
                input_image = torch.cat([input_image, fat_info], dim=1)

        return input_image, targets, extra_params
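# The fat-suppression flag above is broadcast as a constant extra channel so
# the network can condition on acquisition type ('CORPDFS_FBK' marks
# fat-suppressed fastMRI scans). The same pattern as a standalone sketch:
import torch

def append_flag_channel_sketch(image, flag):
    n, _, h, w = image.shape
    flag_channel = torch.full((n, 1, h, w), float(flag), device=image.device)
    return torch.cat([image, flag_channel], dim=1)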