Exemplo n.º 1
0
def visualize_from_kspace(kspace_recons, kspace_targets, smoothing_factor=4):
    """Build visualization grids for k-space reconstructions and their targets.

    Both k-space inputs are inverse-FFT'd to magnitude images. These go through
    ``make_grid_triplet``, which yields recon, target and delta grids. The raw
    k-space tensors are gridded with ``make_k_grid`` using ``smoothing_factor``.

    Assumes that all values are on the same scale and have the same shape.

    Returns:
        tuple: (k-space recon grid, k-space target grid, image recon grid,
        image target grid, image delta grid).
    """
    recon_images = complex_abs(ifft2(kspace_recons))
    target_images = complex_abs(ifft2(kspace_targets))
    recon_grid, target_grid, delta_grid = make_grid_triplet(recon_images, target_images)

    # Targets are gridded before reconstructions (same call order as before).
    kspace_target_grid = make_k_grid(kspace_targets, smoothing_factor)
    kspace_recon_grid = make_k_grid(kspace_recons, smoothing_factor)
    return kspace_recon_grid, kspace_target_grid, recon_grid, target_grid, delta_grid
    def forward(self, kspace_outputs, targets, extra_params):
        """Turn padded network k-space outputs into k-space, complex and magnitude reconstructions.

        Args:
            kspace_outputs (torch.Tensor): Network output whose last dimension may be wider
                than the target width. Only a single slice (size 1 along dim 0) is supported.
            targets (dict): Must contain 'kspace_targets'.
            extra_params (dict): Must contain 'weightings' when ``self.weighted`` is set and
                'masks' when ``self.replace`` is set.

        Returns:
            dict: 'kspace_recons', 'cmg_recons' and 'img_recons'. Values are returned at the
            network's scale (not rescaled).

        Raises:
            NotImplementedError: If more than one slice is given.
        """
        if kspace_outputs.size(0) > 1:
            raise NotImplementedError('Only one slice at a time for now.')

        kspace_targets = targets['kspace_targets']

        # Centre-crop the output width to the target width. The k-space form has a trailing
        # dimension of size 2, so the target's width is at index -2.
        target_width = kspace_targets.size(-2)
        start = (kspace_outputs.size(-1) - target_width) // 2
        kspace_recons = nchw_to_kspace(kspace_outputs[..., start:start + target_width])

        assert kspace_recons.shape == kspace_targets.shape, 'Reconstruction and target sizes are different.'
        assert kspace_recons.size(-3) % 2 == 0 and kspace_recons.size(-2) % 2 == 0, \
            'Not impossible but not expected to have sides with odd lengths.'

        if self.weighted:  # Undo the weighting by dividing it back out.
            kspace_recons = kspace_recons / extra_params['weightings']

        if self.replace:
            # Where the mask is 1 take the target k-space value; elsewhere keep the reconstruction.
            sampled = extra_params['masks']
            kspace_recons = kspace_recons * (1 - sampled) + kspace_targets * sampled

        cmg_recons = ifft2(kspace_recons)
        return {
            'kspace_recons': kspace_recons,
            'cmg_recons': cmg_recons,
            'img_recons': complex_abs(cmg_recons),
        }
Exemplo n.º 3
0
    def forward(self, k_output, targets, scales):
        """Crop, rescale and convert every output slice back to k-space and image form.

        Args:
            k_output (torch.Tensor): Batched network outputs, iterated along dim 0.
            targets (list): Per-slice target images. Each one's last-dim size sets the crop
                width, and its shape must equal the reconstructed image's shape.
            scales (list): Per-slice factors multiplied into the cropped output.

        Returns:
            tuple: (image_recons, kspace_recons), two lists with one entry per slice.
        """
        assert k_output.size(0) == len(targets) == len(scales)
        image_recons, kspace_recons = [], []

        for k_slice, target, scaling in zip(k_output, targets, scales):
            width = target.size(-1)
            offset = (k_slice.size(-1) - width) // 2

            kspace_slice = chw_to_k_slice(k_slice[..., offset:offset + width] * scaling)
            image_slice = complex_abs(ifft2(kspace_slice))

            assert image_slice.shape == target.shape
            image_recons.append(image_slice)
            kspace_recons.append(kspace_slice)

        return image_recons, kspace_recons
    def __call__(self, masked_kspace, target, attrs, file_name, slice_num):
        """Produce a standardized, centre-cropped magnitude image from masked k-space.

        Args:
            masked_kspace (torch.Tensor): k-space of rank 3, 4 or 5. Lower ranks get leading
                singleton dimensions so the tensor is 5-D.
            target: Unused here.
            attrs (dict): Merged into ``extra_params`` after 'img_scales'.
            file_name: Unused here.
            slice_num: Unused here.

        Returns:
            tuple: (input_image, targets, extra_params).
            - ``targets`` holds 'img_inputs' and 'rss_inputs'.
            - ``extra_params`` holds 'img_scales' plus the entries of ``attrs``.

        Raises:
            RuntimeError: If the tensor rank is not 3, 4 or 5.
            NotImplementedError: If the batch size is not 1.
        """
        assert isinstance(
            masked_kspace, torch.Tensor
        ), 'k-space target was expected to be a Pytorch Tensor.'

        rank = masked_kspace.dim()
        if rank == 3:  # Single-coil data the collate function did not expand.
            masked_kspace = masked_kspace.expand(1, 1, -1, -1, -1)
        elif rank == 4:  # Multi-coil data the collate function did not expand.
            masked_kspace = masked_kspace.expand(1, -1, -1, -1, -1)
        elif rank != 5:
            raise RuntimeError('k-space target has invalid shape!')

        if masked_kspace.size(0) != 1:
            raise NotImplementedError('Batch size should be 1 for now.')

        with torch.no_grad():
            # Cropping is mandatory if RSS means the fixed-resolution version. Dropping
            # target processing is considered worth the lost flexibility.
            cropped = center_crop(complex_abs(ifft2(masked_kspace)),
                                  shape=(self.resolution, self.resolution))
            img_scale = torch.std(cropped)
            cropped /= img_scale

            extra_params = {'img_scales': img_scale}
            extra_params.update(attrs)

            # Keys are plural to reduce confusion.
            targets = {
                'img_inputs': cropped,
                'rss_inputs': root_sum_of_squares(cropped, dim=1).squeeze(),
            }

        return cropped, targets, extra_params
Exemplo n.º 5
0
    def __call__(self, k_slice, target, attrs, file_name, slice_num):
        """
        Args:
            k_slice (numpy.array): Input k-space of shape (num_coils, height, width) for multi-coil
                data or (rows, cols) for single coil data. Must be complex with an even width.
            target (numpy.array): Target (320x320) image. May be None. Unused here; the label
                is recomputed from the full k-space instead.
            attrs (dict): Acquisition related information stored in the HDF5 object. Unused here.
            file_name (str): File name. Its character codes seed the mask when ``self.use_seed``.
            slice_num (int): Serial number of the slice. Unused here.
        Returns:
            (tuple): tuple containing:
                data_slice (torch.Tensor): Masked k-space converted to CHW format for CNNs,
                    where C=(2*num_coils). The width axis is zero-padded up to a multiple of
                    ``self.divisor``, split as evenly as possible between the two sides.
                    Auto-encoders with down-sampling regions need this; otherwise concatenation
                    in the decoder fails due to different sizes. Only the width dimension is
                    padded due to the nature of the dataset.
                target_slice (torch.Tensor): Coil-wise magnitude images of the full (unmasked,
                    amplified) k-space. Shape=(num_coils, H, W)
        Raises:
            TypeError: If ``k_slice`` is neither 2-D nor 3-D.
        """
        assert np.iscomplexobj(k_slice), 'kspace must be complex.'
        assert k_slice.shape[-1] % 2 == 0, 'k-space data width must be even.'

        if k_slice.ndim == 2:  # For singlecoil. Makes data processing later on much easier.
            k_slice = np.expand_dims(k_slice, axis=0)
        elif k_slice.ndim != 3:  # Prevents possible errors.
            raise TypeError('Invalid slice type')

        with torch.no_grad():  # Remove unnecessary gradient calculations.
            # Now a Tensor of (num_coils, height, width, 2), where 2 is (real, imag). It lives on
            # ``self.device`` and is multiplied by the amplification factor ``self.amp_fac``.
            k_slice = to_tensor(k_slice).to(device=self.device) * self.amp_fac
            # Labels come from the full (unmasked) k-space.
            target_slice = complex_abs(ifft2(k_slice))

            # Apply mask; the optional seed is derived from the file name's character codes.
            seed = None if not self.use_seed else tuple(map(ord, file_name))
            masked_kspace, mask = apply_mask(k_slice, self.mask_func, seed)

            data_slice = k_slice_to_chw(masked_kspace)

            # Zero-pad the last (width) dimension up to a multiple of self.divisor. When the
            # width is already divisible (margin == 0), the formula would pad by a full
            # divisor, so no padding is applied instead.
            margin = data_slice.shape[-1] % self.divisor
            if margin > 0:
                pad = [(self.divisor - margin) // 2,
                       (1 + self.divisor - margin) // 2]
            else:
                pad = [0, 0]
            data_slice = F.pad(data_slice, pad=pad, value=0)

        return data_slice, target_slice
def test_ifft2(shape):
    """Check ``data_transforms.ifft2`` against a centred, orthonormal NumPy inverse FFT."""
    tensor = create_tensor(shape + [2])  # Trailing dimension holds (real, imag).

    torch_out = data_transforms.ifft2(tensor).numpy()
    torch_out = torch_out[..., 0] + 1j * torch_out[..., 1]

    axes = (-2, -1)
    reference = np.fft.fftshift(
        np.fft.ifft2(
            np.fft.ifftshift(data_transforms.tensor_to_complex_np(tensor), axes),
            norm='ortho'),
        axes)
    assert np.allclose(torch_out, reference)
    def __call__(self, kspace_target, target, attrs, file_name, slice_num):
        """Mask the k-space, then build a standardized, width-padded magnitude image.

        Args:
            kspace_target (torch.Tensor): Full k-space of rank 3, 4 or 5. Lower ranks get
                leading singleton dimensions so the tensor is 5-D with batch size 1.
            target: Unused here.
            attrs (dict): Merged into ``extra_params`` last.
            file_name (str): Its character codes always seed the mask function.
            slice_num: Unused here.

        Returns:
            tuple: (image, targets, extra_params).
            - ``targets`` is {'img_inputs': image}.
            - ``extra_params`` holds 'img_scales', 'masks', the mask ``info`` entries and ``attrs``.

        Raises:
            RuntimeError: If the tensor rank is not 3, 4 or 5.
            NotImplementedError: If the batch size is not 1.
        """
        assert isinstance(
            kspace_target, torch.Tensor
        ), 'k-space target was expected to be a Pytorch Tensor.'

        ndim = kspace_target.dim()
        if ndim == 3:  # Single-coil data the collate function did not expand.
            kspace_target = kspace_target.expand(1, 1, -1, -1, -1)
        elif ndim == 4:  # Multi-coil data the collate function did not expand.
            kspace_target = kspace_target.expand(1, -1, -1, -1, -1)
        elif ndim != 5:
            raise RuntimeError('k-space target has invalid shape!')

        if kspace_target.size(0) != 1:
            raise NotImplementedError('Batch size should be 1 for now.')

        with torch.no_grad():
            masked_kspace, mask, info = apply_info_mask(
                kspace_target, self.mask_func, tuple(map(ord, file_name)))

            image = complex_abs(ifft2(masked_kspace))
            # TODO: the non-cropping branch hard-codes a height of 352 (temporary hack).
            if self.crop_center:
                crop_shape = (self.resolution, self.resolution)
            else:
                crop_shape = (352, image.size(-1))
            image = center_crop(image, shape=crop_shape)

            # Zero-pad the width up to a multiple of self.divisor. Skip padding when it is
            # already divisible, which avoids padding by a full divisor.
            remainder = image.size(-1) % self.divisor
            if remainder > 0:
                shortfall = self.divisor - remainder
                pad = [shortfall // 2, (shortfall + 1) // 2]
            else:
                pad = [0, 0]
            image = F.pad(image, pad=pad, value=0)

            # The scale is measured on the central resolution x resolution region (also a hack).
            img_scale = torch.std(
                center_crop(image, shape=(self.resolution, self.resolution)))
            image /= img_scale

            extra_params = {'img_scales': img_scale, 'masks': mask}
            extra_params.update(info)
            extra_params.update(attrs)

            targets = {'img_inputs': image}  # Plural keys reduce confusion.

        return image, targets, extra_params
Exemplo n.º 8
0
    def __call__(self, k_slice, target, attrs, file_name, slice_num):
        """Build a standardized complex-image input and a complex-image target in CHW form.

        Args:
            k_slice (numpy.ndarray): Complex k-space, 2-D (single-coil) or 3-D.
            target: Unused here.
            attrs: Unused here.
            file_name (str): Its character codes seed the mask when ``self.use_seed`` is set.
            slice_num: Unused here.

        Returns:
            tuple: (c_img_input, c_img_target, (c_scale, c_bias)).
            - The input is the masked complex image, divided by its std and then mean-subtracted.
            - Both tensors have their width trimmed to a multiple of ``self.divisor``.

        Raises:
            RuntimeError: If ``k_slice`` is neither 2-D nor 3-D.
        """
        assert np.iscomplexobj(k_slice), 'kspace must be complex.'

        if k_slice.ndim == 2:  # Give single-coil data a coil axis.
            k_slice = np.expand_dims(k_slice, axis=0)
        elif k_slice.ndim != 3:
            raise RuntimeError(
                'Invalid slice shape. Please check input shape.')

        with torch.no_grad():
            # (num_coils, height, width, 2) tensor, last dim being (real, imag).
            kspace_target = to_tensor(k_slice).to(device=self.device)
            c_img_target = k_slice_to_chw(ifft2(kspace_target))  # Assumes only C2C will be calculated.

            seed = tuple(map(ord, file_name)) if self.use_seed else None
            masked_kspace, mask = apply_mask(kspace_target, self.mask_func, seed)

            c_img_input = k_slice_to_chw(ifft2(masked_kspace))
            c_scale = torch.std(c_img_input)
            c_img_input *= (torch.tensor(1) / c_scale)
            c_bias = torch.mean(c_img_input)
            c_img_input -= c_bias

            # Trim the width down to a multiple of self.divisor instead of padding.
            margin = c_img_input.size(-1) % self.divisor
            if margin > 0:
                trim_left, trim_right = margin // 2, (margin + 1) // 2
                assert c_img_input.size() == c_img_target.size()
                c_img_input = c_img_input[..., trim_left:-trim_right]
                c_img_target = c_img_target[..., trim_left:-trim_right]

        assert c_img_input.size() == c_img_target.size()
        return c_img_input, c_img_target, (c_scale, c_bias)
Exemplo n.º 9
0
    def forward(self, tensor, out_shape):  # Using out_shape only works for batch size of 1.
        """
        Run the U-Net and convert its width-padded k-space output into magnitude images.

        Args:
            tensor (torch.Tensor): Input tensor of shape [batch_size, in_chans, height, width]
            out_shape (tuple): shape [batch_size, num_coils, true_height, true_width].
                Note that in_chans = 2 * num_coils
        Returns:
            (torch.Tensor): Magnitude images obtained by ``complex_abs(ifft2(...))`` of the
                cropped network output. Their size equals ``out_shape``.
        Raises:
            AssertionError: If the resulting size does not equal ``out_shape``.
        """
        skips = []
        output = tensor

        # Encoder: keep each stage's features for the skip connections, then halve resolution.
        for layer in self.down_sample_layers:
            output = layer(output)
            skips.append(output)
            output = F.max_pool2d(output, kernel_size=2)

        output = self.conv(output)

        # Decoder: double the resolution, then concatenate the matching encoder features
        # along the channel dimension.
        for layer in self.up_sample_layers:
            output = F.interpolate(output, scale_factor=2, mode='bilinear', align_corners=False)
            output = torch.cat((output, skips.pop()), dim=1)
            output = layer(output)

        output = self.conv2(output)  # End of learning.

        # Centre-crop the width back to the unpadded size; this depends on batch size 1.
        # An explicit end index is used because the old [pad:-pad] form produced an empty
        # tensor when pad == 0.
        true_width = out_shape[-1]
        left = (output.size(-1) - true_width) // 2
        output = output[..., left:left + true_width]

        # Convert to k-space form, then to magnitude images.
        output = complex_abs(ifft2(nchw_to_kspace(output)))

        assert output.size() == out_shape  # Checking just in case.
        return output
Exemplo n.º 10
0
    def forward(self, kspace_output, c_img_target, extra_params):
        """Crop, un-weight and rescale the k-space output, enforce data consistency, return images.

        Args:
            kspace_output (torch.Tensor): Network output whose last dim may be padded.
            c_img_target: Unused here.
            extra_params (tuple): (kspace_target, k_scale, mask).

        Returns:
            torch.Tensor: ``ifft2`` of the data-consistent k-space reconstruction.
        """
        kspace_target, k_scale, mask = extra_params

        # In k-space form the width is at index -2 (the trailing dim of size 2 is real/imag).
        width = kspace_target.size(-2)
        start = (kspace_output.size(-1) - width) // 2
        cropped = nchw_to_kspace(kspace_output[..., start:start + width])

        # exp_weighting presumably inverts the input log weighting; k_scale restores the scale.
        kspace_recon = exp_weighting(cropped, scale=self.log_amp_scale) * k_scale
        assert kspace_recon.size() == kspace_target.size(), 'Reconstruction and target sizes do not match.'

        # Where the mask is 1 take the target k-space; elsewhere keep the reconstruction.
        kspace_recon = kspace_recon * (1 - mask) + kspace_target * mask
        return ifft2(kspace_recon)
Exemplo n.º 11
0
    def forward(self, kspace_outputs, targets, extra_params):
        """Convert padded network k-space outputs into reconstructions, optionally with RSS.

        Args:
            kspace_outputs (torch.Tensor): Network output with possibly padded last dim.
                Only one slice (size 1 along dim 0) is supported.
            targets (dict): Must contain 'kspace_targets'.
            extra_params (dict): Needs the following keys, depending on flags:
                - 'weightings' when ``self.weighted``.
                - 'num_low_frequency' when ``self.residual_acs``.
                - 'masks' when ``self.replace``.

        Returns:
            dict: 'kspace_recons', 'cmg_recons', 'img_recons', plus 'rss_recons' when dim 1 of
            the images has size 15. Returned at the network's scale (not rescaled).

        Raises:
            NotImplementedError: If more than one slice is given.
        """
        if kspace_outputs.size(0) > 1:
            raise NotImplementedError('Only one slice at a time for now.')

        kspace_targets = targets['kspace_targets']

        # Centre-crop the width; in k-space form the width sits at index -2.
        target_width = kspace_targets.size(-2)
        start = (kspace_outputs.size(-1) - target_width) // 2
        kspace_recons = nchw_to_kspace(kspace_outputs[..., start:start + target_width])
        assert kspace_recons.shape == kspace_targets.shape, 'Reconstruction and target sizes are different.'
        assert kspace_recons.size(-3) % 2 == 0 and kspace_recons.size(-2) % 2 == 0, \
            'Not impossible but not expected to have sides with odd lengths.'

        if self.weighted:  # Divide the weighting back out.
            kspace_recons = kspace_recons / extra_params['weightings']

        if self.residual_acs:
            # Add the target k-space back wherever the ACS mask is set.
            acs_mask = find_acs_mask(kspace_recons, extra_params['num_low_frequency'])
            kspace_recons = kspace_recons + acs_mask * kspace_targets

        if self.replace:
            # Where the mask is 1 take the target k-space; elsewhere keep the reconstruction.
            mask = extra_params['masks']
            kspace_recons = kspace_recons * (1 - mask) + kspace_targets * mask

        cmg_recons = ifft2(kspace_recons)
        img_recons = complex_abs(cmg_recons)
        recons = {'kspace_recons': kspace_recons, 'cmg_recons': cmg_recons, 'img_recons': img_recons}

        # RSS image only when dim 1 has size 15 (presumably the coil count of multicoil data).
        if img_recons.size(1) == 15:
            res = self.resolution
            top = (img_recons.size(-2) - res) // 2
            left = (img_recons.size(-1) - res) // 2
            cropped = img_recons[:, :, top:top + res, left:left + res]
            recons['rss_recons'] = root_sum_of_squares(cropped, dim=1).squeeze()  # 2D result.

        return recons
    def __call__(self, masked_kspace, target, attrs, file_name, slice_num):
        """Build a standardized magnitude image (optionally centre-cropped) from masked k-space.

        Args:
            masked_kspace (torch.Tensor): k-space of rank 3, 4 or 5. Lower ranks get leading
                singleton dimensions so the tensor is 5-D.
            target: Unused here.
            attrs (dict): Merged into ``extra_params`` after 'img_scales'.
            file_name: Unused here.
            slice_num: Unused here.

        Returns:
            tuple: (image, {'img_inputs': image}, extra_params). ``extra_params`` holds
            'img_scales' plus the entries of ``attrs``.

        Raises:
            RuntimeError: If the tensor rank is not 3, 4 or 5.
            NotImplementedError: If the batch size is not 1.
        """
        assert isinstance(
            masked_kspace, torch.Tensor
        ), 'k-space target was expected to be a Pytorch Tensor.'

        num_dims = masked_kspace.dim()
        if num_dims == 3:  # Single-coil data the collate function did not expand.
            masked_kspace = masked_kspace.expand(1, 1, -1, -1, -1)
        elif num_dims == 4:  # Multi-coil data the collate function did not expand.
            masked_kspace = masked_kspace.expand(1, -1, -1, -1, -1)
        elif num_dims != 5:
            raise RuntimeError('k-space target has invalid shape!')

        if masked_kspace.size(0) != 1:
            raise NotImplementedError('Batch size should be 1 for now.')

        with torch.no_grad():
            image = complex_abs(ifft2(masked_kspace))
            if self.crop_center:
                image = center_crop(image, shape=(self.resolution, self.resolution))

            img_scale = torch.std(image)
            image /= img_scale

            extra_params = {'img_scales': img_scale}
            extra_params.update(attrs)

            targets = {'img_inputs': image}  # Plural keys reduce confusion.

        return image, targets, extra_params
Exemplo n.º 13
0
    def forward(self, kspace_output, c_img_target, extra_params):
        """Crop and rescale the k-space output, enforce data consistency, return complex images.

        Args:
            kspace_output (torch.Tensor): Network output whose last dim may be padded.
            c_img_target: Unused here.
            extra_params (tuple): (kspace_target, scaling, mask).

        Returns:
            torch.Tensor: ``ifft2`` of the data-consistent k-space reconstruction.
        """
        kspace_target, scaling, mask = extra_params

        # In k-space form the width is at index -2 (the trailing dim of size 2 is real/imag).
        width = kspace_target.size(-2)
        start = (kspace_output.size(-1) - width) // 2

        # Crop, restore the original scale, and convert to k-space form
        # (this is where batch size 1 matters).
        kspace_recon = nchw_to_kspace(kspace_output[..., start:start + width] * scaling)
        assert kspace_recon.size() == kspace_target.size(), 'Reconstruction and target sizes do not match.'

        # Where the mask is 1 take the target k-space; elsewhere keep the reconstruction.
        kspace_recon = kspace_recon * (1 - mask) + kspace_target * mask
        return ifft2(kspace_recon)
Exemplo n.º 14
0
    def __call__(self, k_slice, target, attrs, file_name, slice_num):
        """Build a std-normalized, width-padded CHW k-space input and magnitude labels.

        Args:
            k_slice (numpy.ndarray): Complex k-space, 2-D (single-coil) or 3-D.
            target: Unused here.
            attrs: Unused here.
            file_name (str): Its character codes seed the mask when ``self.use_seed`` is set.
            slice_num: Unused here.

        Returns:
            tuple: (data_slice, target_slice, scaling). This is a different output API from
            the other transforms.
            - ``target_slice`` is computed before normalization and so keeps the original scale.
            - ``scaling`` is the pseudo-std used to normalize.

        Raises:
            RuntimeError: If ``k_slice`` is neither 2-D nor 3-D.
        """
        assert np.iscomplexobj(k_slice), 'kspace must be complex.'

        if k_slice.ndim == 2:  # Give single-coil data a coil axis.
            k_slice = np.expand_dims(k_slice, axis=0)
        elif k_slice.ndim != 3:
            raise RuntimeError(
                'Invalid slice shape. Please check input shape.')

        with torch.no_grad():
            # (num_coils, height, width, 2) tensor, last dim being (real, imag).
            k_slice = to_tensor(k_slice).to(device=self.device)
            scaling = torch.std(k_slice)  # Pseudo-standard deviation for normalization.
            target_slice = complex_abs(ifft2(k_slice))  # Labels are not standardized.
            # Standardize the CNN input in place by multiplying with the reciprocal scale.
            k_slice *= (torch.ones(()) / scaling)

            seed = tuple(map(ord, file_name)) if self.use_seed else None
            masked_kspace, mask = apply_mask(k_slice, self.mask_func, seed)

            data_slice = k_slice_to_chw(masked_kspace)

            # Zero-pad the width up to a multiple of self.divisor (none if already divisible).
            margin = data_slice.size(-1) % self.divisor
            shortfall = self.divisor - margin
            pad = [shortfall // 2, (shortfall + 1) // 2] if margin > 0 else [0, 0]
            data_slice = F.pad(data_slice, pad=pad, value=0)

        return data_slice, target_slice, scaling
Exemplo n.º 15
0
    def __call__(self, k_slice, target_slice):
        """Turn a padded CHW k-space slice into a magnitude image matching ``target_slice``.

        Expected inputs:
        - ``k_slice`` is in (C, H, W) format and padded.
        - ``target_slice`` is a real-valued (Coil, Height, Width) tensor without padding,
          where C = 2 * Coil.
        - Single-coil data uses Coil = 1, so single- and multi-coil data share dimensions.

        Processing steps:
        1. ``self.k_slice_fn`` adjusts the k-space (e.g. removes padding).
        2. The k-space is transformed to the image domain.
        3. ``self.img_slice_fn`` post-processes the image.

        Raises:
            AssertionError: If the final size differs from ``target_slice``.
        """
        image = complex_abs(ifft2(self.k_slice_fn(k_slice, target_slice)))
        image = self.img_slice_fn(image, target_slice)

        assert image.size() == target_slice.size(), 'Shape conversion went wrong somewhere.'
        return image
Exemplo n.º 16
0
    def __call__(self, k_slice, target, attrs, file_name, slice_num):
        """Build a normalized, log-weighted, width-padded k-space input and complex-image target.

        Args:
            k_slice (numpy.ndarray): Complex k-space, 2-D (single-coil) or 3-D.
            target: Unused here.
            attrs: Unused here.
            file_name (str): Its character codes seed the mask when ``self.use_seed`` is set.
            slice_num: Unused here.

        Returns:
            tuple: (masked_kspace, c_img_target, (kspace_target, k_scale, mask)).
            - ``kspace_target`` gains a leading batch dimension.
            - ``k_scale`` is the pseudo-std used for normalization.

        Raises:
            RuntimeError: If ``k_slice`` is neither 2-D nor 3-D.
        """
        assert np.iscomplexobj(k_slice), 'kspace must be complex.'

        if k_slice.ndim == 2:  # Give single-coil data a coil axis.
            k_slice = np.expand_dims(k_slice, axis=0)
        elif k_slice.ndim != 3:
            raise RuntimeError(
                'Invalid slice shape. Please check input shape.')

        with torch.no_grad():
            # (num_coils, height, width, 2) tensor, last dim being (real, imag).
            kspace_target = to_tensor(k_slice).to(device=self.device)
            c_img_target = ifft2(kspace_target)

            seed = tuple(map(ord, file_name)) if self.use_seed else None
            masked_kspace, mask = apply_mask(kspace_target, self.mask_func, seed)

            # Standardize in place by multiplying with the reciprocal of the pseudo-std.
            k_scale = torch.std(masked_kspace)
            masked_kspace *= (torch.tensor(1) / k_scale)
            # Weighting is shape independent and maps zero inputs to zero outputs,
            # so it can be done before padding.
            net_input = log_weighting(k_slice_to_chw(masked_kspace), scale=self.log_amp_scale)

            # Zero-pad the width up to a multiple of self.divisor (none if already divisible).
            margin = net_input.size(-1) % self.divisor
            if margin > 0:
                shortfall = self.divisor - margin
                pad = [shortfall // 2, (shortfall + 1) // 2]
            else:
                pad = [0, 0]
            net_input = F.pad(net_input, pad=pad, value=0)

        return net_input, c_img_target, (kspace_target.unsqueeze(dim=0), k_scale, mask)
Exemplo n.º 17
0
    def forward(self, cmg_output: Tensor, targets: dict, extra_params: dict):
        """Post-process complex-image network outputs into (multicoil) RSS reconstructions.

        Args:
            cmg_output (Tensor): 5-D output whose dim 1 has size 2. It is permuted so that
                dimension becomes the last one. Only a batch size of 1 is supported.
            targets (dict): Must contain 'kspace_targets'.
            extra_params (dict): Needs 'masks' when ``self.replace_kspace`` is set and
                'cmg_scales' for the multicoil challenge.

        Returns:
            dict: Contains 'rss_recons' for the multicoil challenge, otherwise empty.
            'rss_recons' is multiplied by 'cmg_scales'; nothing else is rescaled.

        Raises:
            NotImplementedError: If the batch size is greater than 1.
        """
        assert cmg_output.dim() == 5 and cmg_output.size(
            1) == 2, 'Invalid shape!'
        if cmg_output.size(0) > 1:
            raise NotImplementedError('Only one at a time for now.')

        kspace_target = targets['kspace_targets']
        cmg_recon = cmg_output.permute(dims=(0, 2, 3, 4, 1))  # Convert back into NCHW2

        if cmg_recon.shape != kspace_target.shape:  # Crop the width (dim -2) to the target's.
            width = kspace_target.size(-2)
            left = (cmg_recon.size(-2) - width) // 2
            cmg_recon = cmg_recon[..., left:left + width, :]

        assert cmg_recon.shape == kspace_target.shape, 'Reconstruction and target sizes are different.'
        assert (cmg_recon.size(-3) % 2 == 0) and (cmg_recon.size(-2) % 2 == 0), \
            'Not impossible but not expected to have sides with odd lengths.'

        # The FFT round trip is only needed for k-space replacement, so it is not computed
        # otherwise. Where the mask is 1 the target k-space is used.
        if self.replace_kspace:
            mask = extra_params['masks']
            kspace_recon = fft2(cmg_recon)
            kspace_recon = kspace_target * mask + (1 - mask) * kspace_recon
            cmg_recon = ifft2(kspace_recon)

        recons = dict()

        if self.challenge == 'multicoil':
            img_recon = complex_abs(cmg_recon)
            rss_recon = center_crop(img_recon, (
                self.resolution, self.resolution)) * extra_params['cmg_scales']
            recons['rss_recons'] = root_sum_of_squares(rss_recon, dim=1).squeeze()

        return recons  # recons are not rescaled except rss_recons.
Exemplo n.º 18
0
    def forward(self, k_output, targets, scales):
        """
        Output post-processing for output k-space tensor with batch size of 1.
        This is experimental and is subject to change.

        Processing steps:
        1. Centre-crop the last dimension of ``k_output`` to the target width.
        2. Multiply by ``scales`` to restore the original scaling.
        3. Convert to k-space form and transform into magnitude images.

        Args:
            k_output (torch.Tensor): CNN output of k-space whose last dimension may be padded.
                It is passed to ``nchw_to_kspace``, so presumably NCHW. Expected to have
                batch-size of 1.
            targets (torch.Tensor): Target image domain data. Its last-dim size sets the crop
                width, and its size must match the reconstructed images.
            scales (torch.Tensor): scaling factor used to divide the input k-space slice data.
        Returns:
            tuple: (image_recons, kspace_recons).
            - image_recons has the same size as ``targets``.
            - kspace_recons is the rescaled k-space with the batch dimension in place.
        Raises:
            AssertionError: If the reconstructed image size differs from ``targets``.
        """
        # Centre-crop the width. An explicit end index is used because the old [pad:-pad]
        # form produced an empty tensor when pad == 0.
        # This depends on mini-batch size being 1 to work.
        width = targets.size(-1)
        left = (k_output.size(-1) - width) // 2

        # Crop and multiply by scales to restore the original scaling.
        k_output = k_output[..., left:left + width] * scales

        # Processing to k-space form. This is where the batch_size == 1 is important.
        kspace_recons = nchw_to_kspace(k_output)

        # Convert to magnitude images.
        image_recons = complex_abs(ifft2(kspace_recons))

        assert image_recons.size() == targets.size(
        ), 'Reconstruction and target sizes do not match.'

        return image_recons, kspace_recons
Exemplo n.º 19
0
    def __call__(self, kspace_target, target, attrs, file_name, slice_num):
        """Build magnitude/phase CNN inputs plus complex, magnitude, phase and k-space targets.

        Args:
            kspace_target (torch.Tensor): Full k-space of rank 3, 4 or 5. Lower ranks get
                leading singleton dimensions so the tensor is 5-D with batch size 1.
            target (torch.Tensor or None): RSS target image. It is flipped together with the
                inputs when augmenting and returned as 'rss_targets' for the multicoil
                challenge. May be None, in which case it is not flipped.
            attrs (dict): Merged into ``extra_params`` last.
            file_name (str): Its character codes seed the mask when ``self.use_seed`` is set.
            slice_num: Unused here.

        Returns:
            tuple: (inputs, targets, extra_params).
            - ``inputs`` is (img_input, phase_input).
            - ``targets`` holds 'kspace_targets', 'cmg_targets', 'img_targets',
              'phase_targets' and 'img_inputs' (plus 'rss_targets' for multicoil).
            - ``extra_params`` holds 'img_scales', 'masks', the mask info entries and ``attrs``.

        Raises:
            RuntimeError: If the tensor rank is not 3, 4 or 5.
            NotImplementedError: If the batch size is not 1.
        """
        assert isinstance(
            kspace_target, torch.Tensor
        ), 'k-space target was expected to be a Pytorch Tensor.'
        if kspace_target.dim() == 3:  # If the collate function does not expand dimensions for single-coil.
            kspace_target = kspace_target.expand(1, 1, -1, -1, -1)
        elif kspace_target.dim() == 4:  # If the collate function does not expand dimensions for multi-coil.
            kspace_target = kspace_target.expand(1, -1, -1, -1, -1)
        elif kspace_target.dim() != 5:  # Expanded k-space should have 5 dimensions.
            raise RuntimeError('k-space target has invalid shape!')

        if kspace_target.size(0) != 1:
            raise NotImplementedError('Batch size should be 1 for now.')

        with torch.no_grad():
            # Apply mask; the optional seed is derived from the file name's character codes.
            seed = None if not self.use_seed else tuple(map(ord, file_name))
            masked_kspace, mask, info = apply_info_mask(
                kspace_target, self.mask_func, seed)

            complex_image = ifft2(masked_kspace)
            cmg_target = ifft2(kspace_target)

            if self.crop_center:
                crop_shape = (self.resolution, self.resolution)
                complex_image = complex_center_crop(complex_image, shape=crop_shape)
                cmg_target = complex_center_crop(cmg_target, shape=crop_shape)

            # Data augmentation by flipping images up-down and/or left-right.
            if self.augment_data:
                flip_lr = torch.rand(()) < 0.5
                flip_ud = torch.rand(()) < 0.5

                # Complex tensors have a trailing (real, imag) dim, so their height and width
                # are dims -3 and -2. The target has only height and width (dims -2 and -1).
                if flip_lr and flip_ud:
                    cmg_dims, target_dims = (-3, -2), (-2, -1)
                elif flip_ud:
                    cmg_dims, target_dims = (-3,), (-2,)
                elif flip_lr:
                    cmg_dims, target_dims = (-2,), (-1,)
                else:
                    cmg_dims = target_dims = None

                if cmg_dims is not None:
                    complex_image = torch.flip(complex_image, dims=cmg_dims)
                    cmg_target = torch.flip(cmg_target, dims=cmg_dims)
                    if target is not None:  # torch.flip would fail on a missing target.
                        target = torch.flip(target, dims=target_dims)

            # atan2 returns angles in [-pi, pi]; adding pi shifts the phase into [0, 2pi].
            # Don't forget to remove the pi in the output transform!
            phase_input = torch.atan2(complex_image[..., 1], complex_image[..., 0]) + math.pi

            img_input = complex_abs(complex_image)
            img_scale = torch.std(img_input)
            img_input /= img_scale

            # Targets share the input's scale.
            cmg_target /= img_scale
            img_target = complex_abs(cmg_target)
            # Reconstruct the k-space target after cropping and augmentation.
            kspace_target = fft2(cmg_target)
            phase_target = torch.atan2(cmg_target[..., 1], cmg_target[..., 0])

            extra_params = {'img_scales': img_scale, 'masks': mask}
            extra_params.update(info)
            extra_params.update(attrs)

            # Use plurals as keys to reduce confusion.
            targets = {
                'kspace_targets': kspace_target,
                'cmg_targets': cmg_target,
                'img_targets': img_target,
                'phase_targets': phase_target,
                'img_inputs': img_input
            }

            if self.challenge == 'multicoil':
                targets['rss_targets'] = target

            # The network receives the standardized magnitude image and the shifted phase.
            inputs = (img_input, phase_input)

        return inputs, targets, extra_params
Exemplo n.º 20
0
import torch
import numpy as np
from data.data_transforms import ifft2, fft2

# Round-trip sanity check: does fft2(ifft2(x)) recover x when the data is scaled
# up by 1e8 before the inverse transform and scaled back down afterwards?
# Prints the max / min / mean of the element-wise ratio (recovered / original);
# values close to 1 mean the round trip is numerically faithful.

# 'cuda:1' was hard-coded; fall back to the first GPU or the CPU so the check
# also runs on machines with fewer than two GPUs.
if torch.cuda.device_count() > 1:
    device = 'cuda:1'
elif torch.cuda.is_available():
    device = 'cuda:0'
else:
    device = 'cpu'

a = torch.rand(20, 40, 60, 92, 2, device=device)
b = ifft2(a * 1E8)
c = fft2(b) * 1E-8

# Small offset guarding against division by an exact zero in `a`
# (torch.rand samples from [0, 1), so 0 is possible).
eps = np.finfo(np.float64).eps
# Computed once instead of three times (the tensor has ~8.8M elements).
ratio = c / (a + eps)

print(torch.max(ratio))
print(torch.min(ratio))
print(torch.mean(ratio).cpu().numpy())

# Alternative checks that were tried:
#   torch.all(a == c)                    -- exact equality
#   torch.allclose(a, c, rtol=0.01)
#   torch.sum(a != c) / a.numel()        -- fraction of differing elements
Exemplo n.º 21
0
    def __call__(self, kspace_target, target, attrs, file_name, slice_num):
        """Build width-direction semi-k-space inputs and targets for one slice.

        The masked k-space is taken to the image domain with ``ifft2``,
        center-cropped to ``self.resolution`` and transformed with
        ``fft1(..., direction='width')``. The result is multiplied by
        ``self.weight_func(...)`` and divided by its standard deviation
        (``sk_scale``). Targets come from the full k-space, cropped the same way
        and divided by the same ``sk_scale``.

        Args:
            kspace_target: Full k-space tensor with 3, 4 or 5 dimensions. It is
                expanded to 5 dimensions with a leading batch size of 1.
            target: Stored under ``'rss_targets'`` for the multi-coil challenge.
            attrs: Mapping merged into ``extra_params``.
            file_name: Converted to a deterministic mask seed when
                ``self.use_seed`` is set.
            slice_num: Unused here.

        Returns:
            ``(inputs, targets, extra_params)``, where ``inputs`` is the
            weighted, scaled semi-k-space converted with ``kspace_to_nchw``.

        Raises:
            RuntimeError: If ``kspace_target`` does not have 3, 4 or 5 dims.
            NotImplementedError: If the batch size is not 1.
        """
        assert isinstance(kspace_target, torch.Tensor), \
            'k-space target was expected to be a Pytorch Tensor.'

        # The collate function may not have added the batch dimension (4 dims,
        # multi-coil) or the batch and coil dimensions (3 dims, single-coil).
        leading_dims = {3: (1, 1, -1, -1, -1), 4: (1, -1, -1, -1, -1)}
        ndim = kspace_target.dim()
        if ndim in leading_dims:
            kspace_target = kspace_target.expand(*leading_dims[ndim])
        elif ndim != 5:
            raise RuntimeError('k-space target has invalid shape!')

        if kspace_target.size(0) != 1:
            raise NotImplementedError('Batch size should be 1 for now.')

        with torch.no_grad():
            seed = tuple(map(ord, file_name)) if self.use_seed else None
            masked_kspace, mask, info = apply_info_mask(kspace_target, self.mask_func, seed)
            crop_shape = (self.resolution, self.resolution)

            # ACS (low-frequency) part of the full k-space, brought to semi-k-space.
            # NOTE(review): unlike the other targets this is neither weighted nor
            # divided by sk_scale -- confirm that is intended.
            acs_mask = self.find_acs_mask(kspace_target, info['num_low_frequency'])
            acs_image = ifft2(kspace_target * acs_mask)
            semi_kspace_acs = fft1(complex_center_crop(acs_image, shape=crop_shape),
                                   direction='width')

            complex_image = complex_center_crop(ifft2(masked_kspace), shape=crop_shape)
            # Not a network input: what the input looks like in the image domain.
            img_input = complex_abs(complex_image)

            # The transform direction is fixed by the challenge conditions.
            semi_kspace = fft1(complex_image, direction='width')
            weighting = self.weight_func(semi_kspace)
            semi_kspace *= weighting
            sk_scale = torch.std(semi_kspace)
            semi_kspace /= sk_scale
            inputs = kspace_to_nchw(semi_kspace)

            extra_params = {'sk_scales': sk_scale, 'masks': mask, 'weightings': weighting}
            extra_params.update(info)
            extra_params.update(attrs)

            # The Fourier transform is linear, so dividing the image by sk_scale
            # divides every transform of it by the same factor.
            cmg_target = complex_center_crop(ifft2(kspace_target), shape=crop_shape)
            cmg_target /= sk_scale
            img_target = complex_abs(cmg_target)
            semi_kspace_target = fft1(cmg_target, direction='width')
            cropped_kspace_target = fft2(cmg_target)

            # Plural keys, to reduce confusion.
            targets = {
                'semi_kspace_targets': semi_kspace_target,
                'kspace_targets': cropped_kspace_target,
                'cmg_targets': cmg_target,
                'img_targets': img_target,
                'img_inputs': img_input,
                'semi_kspace_acss': semi_kspace_acs,
            }
            if self.challenge == 'multicoil':
                targets['rss_targets'] = target

        return inputs, targets, extra_params
Exemplo n.º 22
0
    def __call__(self, kspace_target, target, attrs, file_name, slice_num):
        """Build standardized magnitude-image inputs and RSS targets for one slice.

        The masked k-space is inverse Fourier transformed and converted to
        magnitude. It is center-cropped to ``self.resolution`` and divided by its
        standard deviation (``img_scale``). Optionally a random square patch of
        ``self.patch_size`` is extracted. Random up-down / left-right flips are
        applied to the input and the target together.

        Args:
            kspace_target: Full k-space tensor with 3, 4 or 5 dimensions. It is
                expanded to 5 dimensions with a leading batch size of 1.
            target: Target image returned as ``'rss_targets'``. It is indexed as
                a 2D (height, width) array when a patch is extracted.
            attrs: Mapping merged into ``extra_params``.
            file_name: Converted to a deterministic mask seed when
                ``self.use_seed`` is set.
            slice_num: Unused here.

        Returns:
            ``(input_image, targets, extra_params)``.

        Raises:
            RuntimeError: If ``kspace_target`` does not have 3, 4 or 5 dims.
            NotImplementedError: If the batch size is not 1.
        """
        assert isinstance(
            kspace_target, torch.Tensor
        ), 'k-space target was expected to be a Pytorch Tensor.'
        if kspace_target.dim(
        ) == 3:  # If the collate function does not expand dimensions for single-coil.
            kspace_target = kspace_target.expand(1, 1, -1, -1, -1)
        elif kspace_target.dim(
        ) == 4:  # If the collate function does not expand dimensions for multi-coil.
            kspace_target = kspace_target.expand(1, -1, -1, -1, -1)
        elif kspace_target.dim(
        ) != 5:  # Expanded k-space should have 5 dimensions.
            raise RuntimeError('k-space target has invalid shape!')

        if kspace_target.size(0) != 1:
            raise NotImplementedError('Batch size should be 1 for now.')

        with torch.no_grad():
            # Apply mask
            seed = None if not self.use_seed else tuple(map(ord, file_name))
            masked_kspace, mask, info = apply_info_mask(
                kspace_target, self.mask_func, seed)

            input_image = complex_abs(ifft2(masked_kspace))
            # Cropping is mandatory if RSS means the 320 version. Not so if a larger image is used.
            # However, I think that removing target processing is worth the loss of flexibility.
            input_image = center_crop(input_image,
                                      shape=(self.resolution, self.resolution))
            img_scale = torch.std(input_image)
            input_image /= img_scale

            # No subtraction by the mean. I do not know if this is a good idea or not.

            if self.use_patch:
                assert self.resolution == 320
                # torch.randint's `high` is exclusive. The +1 makes the last valid
                # offset (the bottom / right-most patch) reachable. It also keeps
                # patch_size == resolution from raising on an empty range.
                offset_high = self.resolution - self.patch_size + 1
                left = torch.randint(low=0, high=offset_high,
                                     size=(1, )).squeeze()
                top = torch.randint(low=0, high=offset_high,
                                    size=(1, )).squeeze()

                input_image = input_image[:, :, top:top + self.patch_size,
                                          left:left + self.patch_size]
                target = target[top:top + self.patch_size,
                                left:left + self.patch_size]

            extra_params = {'img_scales': img_scale, 'masks': mask}
            extra_params.update(info)
            extra_params.update(attrs)

            # Data augmentation by flipping images up-down and left-right.
            if self.augment_data:
                flip_lr = torch.rand(()) < 0.5
                flip_ud = torch.rand(()) < 0.5

                if flip_lr and flip_ud:
                    input_image = torch.flip(input_image, dims=(-2, -1))
                    target = torch.flip(target, dims=(-2, -1))

                elif flip_ud:
                    input_image = torch.flip(input_image, dims=(-2, ))
                    target = torch.flip(target, dims=(-2, ))

                elif flip_lr:
                    input_image = torch.flip(input_image, dims=(-1, ))
                    target = torch.flip(target, dims=(-1, ))

            # Use plurals as keys to reduce confusion.
            input_rss = root_sum_of_squares(input_image, dim=1).squeeze()
            targets = {
                'img_inputs': input_image,
                'rss_targets': target,
                'rss_inputs': input_rss
            }

        return input_image, targets, extra_params
Exemplo n.º 23
0
    def __call__(self, kspace_target, target, attrs, file_name, slice_num):
        """Build standardized magnitude-image inputs and targets for one slice.

        Both the masked and the full k-space are turned into magnitude images
        (``complex_abs(ifft2(...))``). When ``self.crop_center`` is set, both are
        center-cropped to ``self.resolution``. Both are divided by the standard
        deviation of the input image (``img_scale``). When ``self.augment_data``
        is set, the input, the image target and ``target`` are flipped together
        at random.

        Args:
            kspace_target: Full k-space tensor with 3, 4 or 5 dimensions. It is
                expanded to 5 dimensions with a leading batch size of 1.
            target: Stored under ``'rss_targets'`` for the multi-coil challenge.
            attrs: Mapping merged into ``extra_params``.
            file_name: Converted to a deterministic mask seed when
                ``self.use_seed`` is set.
            slice_num: Unused here.

        Returns:
            ``(image, targets, extra_params)``.

        Raises:
            RuntimeError: If ``kspace_target`` does not have 3, 4 or 5 dims.
            NotImplementedError: If the batch size is not 1.
        """
        assert isinstance(kspace_target, torch.Tensor), \
            'k-space target was expected to be a Pytorch Tensor.'

        # The collate function may not have added the batch dimension (4 dims,
        # multi-coil) or the batch and coil dimensions (3 dims, single-coil).
        leading_dims = {3: (1, 1, -1, -1, -1), 4: (1, -1, -1, -1, -1)}
        ndim = kspace_target.dim()
        if ndim in leading_dims:
            kspace_target = kspace_target.expand(*leading_dims[ndim])
        elif ndim != 5:
            raise RuntimeError('k-space target has invalid shape!')

        if kspace_target.size(0) != 1:
            raise NotImplementedError('Batch size should be 1 for now.')

        with torch.no_grad():
            seed = tuple(map(ord, file_name)) if self.use_seed else None
            masked_kspace, mask, info = apply_info_mask(kspace_target, self.mask_func, seed)

            def to_magnitude(kspace):
                # Magnitude image of `kspace`, center-cropped when requested.
                magnitude = complex_abs(ifft2(kspace))
                if self.crop_center:
                    magnitude = center_crop(magnitude,
                                            shape=(self.resolution, self.resolution))
                return magnitude

            image = to_magnitude(masked_kspace)
            img_scale = torch.std(image)
            image /= img_scale

            extra_params = {'img_scales': img_scale, 'masks': mask}
            extra_params.update(info)
            extra_params.update(attrs)

            img_target = to_magnitude(kspace_target)
            img_target /= img_scale

            # Random up-down (dim -2) and left-right (dim -1) flips, applied
            # identically to the input and both targets.
            if self.augment_data:
                flip_lr = torch.rand(()) < 0.5
                flip_ud = torch.rand(()) < 0.5
                flip_dims = ((-2, ) if flip_ud else ()) + ((-1, ) if flip_lr else ())
                if flip_dims:
                    image, img_target, target = (
                        torch.flip(tensor, dims=flip_dims)
                        for tensor in (image, img_target, target))

            # Plural keys, to reduce confusion.
            targets = {'img_targets': img_target, 'img_inputs': image}
            if self.challenge == 'multicoil':
                targets['rss_targets'] = target

        return image, targets, extra_params
Exemplo n.º 24
0
    def __call__(self, kspace_target, target, attrs, file_name, slice_num):
        """Build padded height-direction semi-k-space inputs and targets.

        ``ifft1(..., direction='height')`` of the masked k-space gives the
        semi-k-space. It is multiplied by ``self.weight_func(...)`` and divided
        by its standard deviation (``sk_scale``). It is then converted with
        ``kspace_to_nchw`` and zero-padded along the last dimension up to a
        multiple of ``self.divisor``. Targets are derived from the full k-space
        divided by the same ``sk_scale``.

        The caller's ``kspace_target`` tensor is not modified.

        Args:
            kspace_target: Full k-space tensor with 3, 4 or 5 dimensions. It is
                expanded to 5 dimensions with a leading batch size of 1.
            target: Stored under ``'rss_targets'`` when the data has more than
                one coil.
            attrs: Mapping merged into ``extra_params``.
            file_name: Converted to a deterministic mask seed when
                ``self.use_seed`` is set.
            slice_num: Unused here.

        Returns:
            ``(inputs, targets, extra_params)``.

        Raises:
            RuntimeError: If ``kspace_target`` does not have 3, 4 or 5 dims.
            NotImplementedError: If the batch size is not 1.
        """
        assert isinstance(
            kspace_target, torch.Tensor
        ), 'k-space target was expected to be a Pytorch Tensor.'
        if kspace_target.dim(
        ) == 3:  # If the collate function does not expand dimensions for single-coil.
            kspace_target = kspace_target.expand(1, 1, -1, -1, -1)
        elif kspace_target.dim(
        ) == 4:  # If the collate function does not expand dimensions for multi-coil.
            kspace_target = kspace_target.expand(1, -1, -1, -1, -1)
        elif kspace_target.dim(
        ) != 5:  # Expanded k-space should have 5 dimensions.
            raise RuntimeError('k-space target has invalid shape!')

        if kspace_target.size(0) != 1:
            raise NotImplementedError('Batch size should be 1 for now.')

        with torch.no_grad():
            # Apply mask
            seed = None if not self.use_seed else tuple(map(ord, file_name))
            masked_kspace, mask, info = apply_info_mask(
                kspace_target, self.mask_func, seed)

            # img_input is not actually an input but what the input would look like in the image domain.
            img_input = complex_abs(ifft2(masked_kspace))

            semi_kspace = ifft1(masked_kspace, direction='height')

            weighting = self.weight_func(semi_kspace)
            semi_kspace *= weighting

            # The slope is meaningless as the results always become the same after standardization no matter the slope.
            # The ordering could be changed to allow a difference, but this would make the inputs non-standardized.
            sk_scale = torch.std(semi_kspace)

            semi_kspace /= sk_scale
            inputs = kspace_to_nchw(semi_kspace)

            extra_params = {
                'sk_scales': sk_scale,
                'masks': mask,
                'weightings': weighting
            }
            extra_params.update(info)
            extra_params.update(attrs)

            # Recall that the Fourier transform is a linear transform.
            # Divide out-of-place. `kspace_target` is the caller's tensor, or an
            # `expand` view sharing its storage, so `/=` would silently rescale
            # the caller's data.
            kspace_target = kspace_target / sk_scale
            cmg_target = ifft2(kspace_target)
            img_target = complex_abs(cmg_target)
            semi_kspace_target = ifft1(kspace_target, direction='height')

            # Use plurals as keys to reduce confusion.
            targets = {
                'semi_kspace_targets': semi_kspace_target,
                'kspace_targets': kspace_target,
                'cmg_targets': cmg_target,
                'img_targets': img_target,
                'img_inputs': img_input
            }

            # More than one coil means multi-coil data. The previous check for
            # exactly 15 coils skipped multi-coil volumes with other coil counts.
            if kspace_target.size(1) > 1:
                # `target` is not divided by sk_scale here; the original note says
                # scaling is needed for metric comparison later.
                targets['rss_targets'] = target

            # Zero-pad the last dimension up to the next multiple of self.divisor,
            # split as evenly as possible between the two sides.
            margin = inputs.size(-1) % self.divisor
            if margin > 0:
                pad = [(self.divisor - margin) // 2,
                       (1 + self.divisor - margin) // 2]
            else:  # This is a fix to prevent padding by half the divisor when margin=0.
                pad = [0, 0]

            # This pads at the last dimension of a tensor with 0.
            inputs = F.pad(inputs, pad=pad, value=0)

        return inputs, targets, extra_params
Exemplo n.º 25
0
import torch
import h5py

from data.data_transforms import ifft2, to_tensor, root_sum_of_squares, center_crop, complex_center_crop, complex_abs

# Sanity check: the center-cropped root-sum-of-squares of the coil images
# should reproduce the stored 'reconstruction_rss' slice. This should hold even
# when the k-space is divided by cmg_scale before the inverse FFT and the image
# is multiplied back afterwards, because the FFT is linear.
file = '/media/veritas/D/FastMRI/multicoil_val/file1001798.h5'
sdx = 10  # Slice index inside the volume.

with h5py.File(file, mode='r') as hf:
    kspace, target = hf['kspace'][sdx], hf['reconstruction_rss'][sdx]

cmg_scale = 2E-5
scaled_images = ifft2(to_tensor(kspace) / cmg_scale)
cropped_images = complex_center_crop(scaled_images, shape=(320, 320))
recon = root_sum_of_squares(complex_abs(cropped_images * cmg_scale))
target = to_tensor(target)

print(torch.allclose(recon, target))
Exemplo n.º 26
0
    def __call__(self, kspace_target, target, attrs, file_name, slice_num):
        """Build standardized magnitude inputs and an RSS target for one slice.

        The masked k-space is inverse Fourier transformed and converted to
        magnitude. It is center-cropped to ``self.resolution`` and divided by its
        standard deviation (``img_scale``). Random up-down / left-right flips are
        applied to the input and the target together when ``self.augment_data``
        is set. When ``self.fat_info`` is set, a constant channel is appended to
        the returned input: 1 for the ``'CORPDFS_FBK'`` acquisition, 0 otherwise.
        This happens after ``targets`` is built, so ``targets['img_inputs']``
        does not contain that channel.

        Args:
            kspace_target: Full k-space tensor with 3, 4 or 5 dimensions. It is
                expanded to 5 dimensions with a leading batch size of 1.
            target: Target image returned as ``'rss_targets'``. If it is the
                integer ``0`` (presumably a placeholder when no stored
                reconstruction exists), it is recomputed as the center-cropped
                root-sum-of-squares of the full k-space images over dim 1.
            attrs: Mapping merged into ``extra_params``. Must provide
                ``'acquisition'`` when ``self.fat_info`` is set.
            file_name: Converted to a deterministic mask seed when
                ``self.use_seed`` is set.
            slice_num: Unused here.

        Returns:
            ``(input_image, targets, extra_params)``.

        Raises:
            RuntimeError: If ``kspace_target`` does not have 3, 4 or 5 dims.
            NotImplementedError: If the batch size is not 1.
        """
        assert isinstance(
            kspace_target, torch.Tensor
        ), 'k-space target was expected to be a Pytorch Tensor.'
        if kspace_target.dim(
        ) == 3:  # If the collate function does not expand dimensions for single-coil.
            kspace_target = kspace_target.expand(1, 1, -1, -1, -1)
        elif kspace_target.dim(
        ) == 4:  # If the collate function does not expand dimensions for multi-coil.
            kspace_target = kspace_target.expand(1, -1, -1, -1, -1)
        elif kspace_target.dim(
        ) != 5:  # Expanded k-space should have 5 dimensions.
            raise RuntimeError('k-space target has invalid shape!')

        if kspace_target.size(0) != 1:
            raise NotImplementedError('Batch size should be 1 for now.')

        with torch.no_grad():
            # Apply mask
            seed = None if not self.use_seed else tuple(map(ord, file_name))
            masked_kspace, mask, info = apply_info_mask(
                kspace_target, self.mask_func, seed)

            input_image = complex_abs(ifft2(masked_kspace))
            # Cropping is mandatory if RSS means the 320 version. Not so if a larger image is used.
            # However, I think that removing target processing is worth the loss of flexibility.
            input_image = center_crop(input_image,
                                      shape=(self.resolution, self.resolution))
            img_scale = torch.std(input_image)
            input_image /= img_scale

            extra_params = {'img_scales': img_scale, 'masks': mask}
            extra_params.update(info)
            extra_params.update(attrs)

            # `target is 0` relied on CPython's small-int caching and is a
            # SyntaxWarning since Python 3.8. Compare type and value instead
            # (bools are excluded, matching the old identity test). A tensor
            # target never matches, as before.
            if type(target) is int and target == 0:
                target = center_crop(
                    root_sum_of_squares(complex_abs(ifft2(kspace_target)),
                                        dim=1),
                    shape=(self.resolution, self.resolution)).squeeze()

            # Data augmentation by flipping images up-down and left-right.
            if self.augment_data:
                flip_lr = torch.rand(()) < 0.5
                flip_ud = torch.rand(()) < 0.5

                if flip_lr and flip_ud:
                    input_image = torch.flip(input_image, dims=(-2, -1))
                    target = torch.flip(target, dims=(-2, -1))

                elif flip_ud:
                    input_image = torch.flip(input_image, dims=(-2, ))
                    target = torch.flip(target, dims=(-2, ))

                elif flip_lr:
                    input_image = torch.flip(input_image, dims=(-1, ))
                    target = torch.flip(target, dims=(-1, ))

            # Use plurals as keys to reduce confusion.
            input_rss = root_sum_of_squares(input_image, dim=1).squeeze()
            targets = {
                'img_inputs': input_image,
                'rss_targets': target,
                'rss_inputs': input_rss
            }

            if self.fat_info:
                fat_supp = extra_params[
                    'acquisition'] == 'CORPDFS_FBK'  # Fat suppressed acquisition.
                batch, _, height, width = input_image.shape
                # Constant extra channel: 1 when fat-suppressed, 0 otherwise.
                fat_info = torch.ones(size=(batch, 1, height, width),
                                      device=self.device) * fat_supp
                input_image = torch.cat([input_image, fat_info], dim=1)

        return input_image, targets, extra_params
Exemplo n.º 27
0
def check_invertible():
    """Print whether an fft2 -> ifft2 round trip reproduces its input.

    A random float64 tensor with values in [-64, 960) is multiplied by 100 in
    the Fourier domain and divided by 100 after the inverse transform. Both
    transforms are linear, so the result should match the original up to
    floating-point error. Prints ``True`` when ``torch.allclose`` agrees.
    """
    original = torch.rand(4, 6, 8, 12, 2, dtype=torch.float64) * 1024 - 64
    spectrum = fft2(original) * 100
    round_trip = ifft2(spectrum) / 100
    print(torch.allclose(original, round_trip))
Exemplo n.º 28
0
import torch
import h5py
import numpy as np

from data.data_transforms import root_sum_of_squares, center_crop, ifft2, to_tensor, complex_abs

# Check that the stored 'reconstruction_rss' volume equals the root-sum-of-squares
# over dim 1 (presumably the coil axis) of the magnitude images from the full
# k-space, center-cropped to 320 x 320.
file = '/media/veritas/D/FastMRI/multicoil_val/file1000229.h5'
with h5py.File(file, 'r') as hf:
    kspace, rss = hf['kspace'][()], hf['reconstruction_rss'][()]

kspace = to_tensor(kspace)
coil_magnitudes = complex_abs(ifft2(kspace))
combined = root_sum_of_squares(coil_magnitudes, dim=1)
image = center_crop(combined, shape=(320, 320)).squeeze().numpy()

print(np.allclose(image, rss))
Exemplo n.º 29
0
    def __call__(self, kspace_target, target, attrs, file_name, slice_num):
        """Build real/imag/abs image-domain inputs and complex-image targets.

        The masked k-space is inverse Fourier transformed into a complex image.
        It is center-cropped to ``self.resolution`` when ``self.crop_center`` is
        set and divided by its standard deviation (``cmg_scale``). The full
        k-space is divided by the same ``cmg_scale`` to build the targets.

        When ``self.augment_data`` is set, the input, the complex target and
        ``target`` are flipped together at random. ``'kspace_targets'`` is
        always consistent with ``'cmg_targets'``: it is rebuilt with ``fft2``
        whenever the complex target was cropped or flipped. The network input
        concatenates the real, imaginary and magnitude channels and is converted
        with ``kspace_to_nchw``.

        The caller's ``kspace_target`` tensor is not modified.

        Args:
            kspace_target: Full k-space tensor with 3, 4 or 5 dimensions. It is
                expanded to 5 dimensions with a leading batch size of 1.
            target: 2D (height, width) target, stored under ``'rss_targets'``
                for the multi-coil challenge.
            attrs: Mapping merged into ``extra_params``.
            file_name: Converted to a deterministic mask seed when
                ``self.use_seed`` is set.
            slice_num: Unused here.

        Returns:
            ``(inputs, targets, extra_params)``.

        Raises:
            RuntimeError: If ``kspace_target`` does not have 3, 4 or 5 dims.
            NotImplementedError: If the batch size is not 1.
        """
        assert isinstance(
            kspace_target, torch.Tensor
        ), 'k-space target was expected to be a Pytorch Tensor.'
        if kspace_target.dim(
        ) == 3:  # If the collate function does not expand dimensions for single-coil.
            kspace_target = kspace_target.expand(1, 1, -1, -1, -1)
        elif kspace_target.dim(
        ) == 4:  # If the collate function does not expand dimensions for multi-coil.
            kspace_target = kspace_target.expand(1, -1, -1, -1, -1)
        elif kspace_target.dim(
        ) != 5:  # Expanded k-space should have 5 dimensions.
            raise RuntimeError('k-space target has invalid shape!')

        if kspace_target.size(0) != 1:
            raise NotImplementedError('Batch size should be 1 for now.')

        with torch.no_grad():
            # Apply mask
            seed = None if not self.use_seed else tuple(map(ord, file_name))
            masked_kspace, mask, info = apply_info_mask(
                kspace_target, self.mask_func, seed)

            # Complex image made from down-sampled k-space.
            complex_image = ifft2(masked_kspace)

            if self.crop_center:
                complex_image = complex_center_crop(complex_image,
                                                    shape=(self.resolution,
                                                           self.resolution))

            cmg_scale = torch.std(complex_image)
            complex_image /= cmg_scale

            extra_params = {'cmg_scales': cmg_scale, 'masks': mask}
            extra_params.update(info)
            extra_params.update(attrs)

            # Recall that the Fourier transform is a linear transform.
            # Divide out-of-place. `kspace_target` is the caller's tensor, or an
            # `expand` view sharing its storage, so `/=` would silently rescale
            # the caller's data.
            kspace_target = kspace_target / cmg_scale
            cmg_target = ifft2(kspace_target)

            if self.crop_center:
                cmg_target = complex_center_crop(cmg_target,
                                                 shape=(self.resolution,
                                                        self.resolution))

            # Data augmentation by flipping images up-down and left-right.
            flipped = False
            if self.augment_data:  # No rotation implemented.
                flip_lr = torch.rand(()) < 0.5
                flip_ud = torch.rand(()) < 0.5

                if flip_lr and flip_ud:
                    # Last dim is real/complex dimension for complex image and target.
                    complex_image = torch.flip(complex_image, dims=(-3, -2))
                    cmg_target = torch.flip(cmg_target, dims=(-3, -2))
                    target = torch.flip(target, dims=(
                        -2, -1))  # Has only two dimensions, height and width.
                    flipped = True

                elif flip_ud:
                    complex_image = torch.flip(complex_image, dims=(-3, ))
                    cmg_target = torch.flip(cmg_target, dims=(-3, ))
                    target = torch.flip(target, dims=(-2, ))
                    flipped = True

                elif flip_lr:
                    complex_image = torch.flip(complex_image, dims=(-2, ))
                    cmg_target = torch.flip(cmg_target, dims=(-2, ))
                    target = torch.flip(target, dims=(-1, ))
                    flipped = True

            # Rebuild the k-space target from the final complex target whenever
            # it was cropped or flipped. Previously this only happened inside the
            # flip branches. With crop_center=True, an un-flipped sample therefore
            # returned the full-size k-space while a flipped one returned the
            # cropped k-space.
            if self.crop_center or flipped:
                kspace_target = fft2(cmg_target)

            # The image target is obtained after flipping the complex image.
            # This removes the need to flip the image target.
            img_target = complex_abs(cmg_target)
            img_inputs = complex_abs(complex_image)

            # Use plurals as keys to reduce confusion.
            targets = {
                'kspace_targets': kspace_target,
                'cmg_targets': cmg_target,
                'img_targets': img_target,
                'cmg_inputs': complex_image,
                'img_inputs': img_inputs
            }

            if self.challenge == 'multicoil':
                targets['rss_targets'] = target

            # Creating concatenated image of real/imag/abs channels.
            concat_image = torch.cat(
                [complex_image, img_inputs.unsqueeze(dim=-1)], dim=-1)

            # Converting to NCHW format for CNN.
            inputs = kspace_to_nchw(concat_image)

        return inputs, targets, extra_params