Esempio n. 1
0
def test_varnet(shape, chans, center_fractions, accelerations, mask_center):
    """Forward per-sample masked k-space through a small VarNet.

    Every batch element is masked separately (always with seed 123), the
    pieces are concatenated back along the first axis, and the model output
    is checked to have trailing dimensions equal to ``x.shape[2:4]``.
    """
    mask_func = RandomMaskFunc(center_fractions, accelerations)
    x = create_input(shape)

    kspace_parts, mask_parts = [], []
    for start in range(x.shape[0]):
        # Slice instead of index so the leading axis is kept for torch.cat.
        single = x[start:start + 1]
        masked, sample_mask, _ = transforms.apply_mask(single,
                                                       mask_func,
                                                       seed=123)
        kspace_parts.append(masked)
        mask_parts.append(sample_mask)

    masked_kspace = torch.cat(kspace_parts)
    sampling_mask = torch.cat(mask_parts)

    model = VarNet(
        num_cascades=2,
        sens_chans=4,
        sens_pools=2,
        chans=chans,
        pools=2,
        mask_center=mask_center,
    )

    prediction = model(masked_kspace, sampling_mask.byte())

    assert prediction.shape[1:] == x.shape[2:4]
Esempio n. 2
0
def test_mask_types(mask_type):
    """Check ``transforms.apply_mask`` against the mask function it wraps.

    For several input shapes and (center fraction, acceleration) settings the
    test verifies that masking leaves NumPy's global RNG state untouched,
    that the returned mask and low-frequency count equal what the mask
    function yields for the same seed, and that the output is zero wherever
    the mask is zero.
    """
    rng_state_before = np.random.get_state()
    shapes = ((4, 32, 32, 2), (2, 64, 32, 2), (1, 33, 24, 2))
    settings = (
        ([0.08], [4]),
        ([0.04], [8]),
        ([0.04, 0.08], [4, 8]),
    )

    for shape in shapes:
        for center_fractions, accelerations in settings:
            mask_func = create_mask_for_mask_type(mask_type, center_fractions,
                                                  accelerations)
            expected_mask, expected_num_low_frequencies = mask_func(shape,
                                                                    seed=123)
            x = create_input(shape)
            masked, mask, num_low_frequencies = transforms.apply_mask(
                x, mask_func, seed=123)

            rng_state_after = np.random.get_state()
            assert (rng_state_before[1] == rng_state_after[1]).all()
            assert masked.shape == x.shape
            assert mask.shape == expected_mask.shape

            mask_np = mask.numpy()
            assert np.all(expected_mask.numpy() == mask_np)

            # Zeroing the masked-out entries again must not change anything.
            masked_np = masked.numpy()
            assert np.all(np.where(mask_np == 0, 0, masked_np) == masked_np)
            assert num_low_frequencies == expected_num_low_frequencies
Esempio n. 3
0
def test_varnet_num_sense_lines(shape, chans, center_fractions, accelerations,
                                mask_center):
    """Check the sensitivity net's low-frequency bookkeeping and a forward pass.

    When ``mask_center`` is exactly ``True`` the test asks the model's
    ``sens_net`` for the padding and low-frequency count derived from the
    mask and verifies that the corresponding span of the squeezed mask is all
    ones. The model is then run with ``num_low_frequencies=4`` and the output
    shape is compared with ``x.shape[2:4]``.
    """
    mask_func = RandomMaskFunc(center_fractions, accelerations)
    x = create_input(shape)
    masked_kspace, mask, num_low_freqs = transforms.apply_mask(x,
                                                               mask_func,
                                                               seed=123)

    model = VarNet(
        num_cascades=2,
        sens_chans=4,
        sens_pools=2,
        chans=chans,
        pools=2,
        mask_center=mask_center,
    )

    if mask_center is True:
        pad, net_low_freqs = model.sens_net.get_pad_and_num_low_freqs(
            mask, num_low_freqs)
        assert net_low_freqs == num_low_freqs

        first, last = int(pad), int(pad + net_low_freqs)
        center_lines = mask.squeeze()[first:last].to(torch.int8)
        all_ones = torch.ones([int(net_low_freqs)], dtype=torch.int8)
        assert torch.allclose(center_lines, all_ones)

    prediction = model(masked_kspace, mask.byte(), num_low_frequencies=4)

    assert prediction.shape[1:] == x.shape[2:4]
Esempio n. 4
0
    def __getitem__(self, i: int):
        """Load example ``i`` and replace its k-space/target with a masked split.

        A random mask (center fraction 0.04, acceleration 8) is applied to the
        slice's k-space. Index 0 along the third axis of the masked result
        becomes ``target``; ``kspace`` minus that array becomes the ``kspace``
        that is passed on. Both replace the values read from the file.

        Args:
            i: Index into ``self.examples``.

        Returns:
            ``self.transform(kspace, mask, target, attrs, fname.name,
            dataslice)``, or that argument tuple itself when
            ``self.transform`` is ``None``.
        """
        fname, dataslice, metadata = self.examples[i]

        with h5py.File(fname, "r") as hf:
            kspace = hf["kspace"][dataslice]

            # --- split k-space into a "loss" part and a "training" part ---

            mask_func = RandomMaskFunc(center_fractions=[0.04], accelerations=[8])  # Create the mask function object
            # NOTE(review): the `mask` returned here is overwritten further
            # down by the file's own "mask" entry (or None) — confirm it is
            # not needed.
            masked_kspace, mask = T.apply_mask(T.to_tensor(kspace), mask_func)   # Apply the mask to k-space
            # Keeps only index 0 of the third axis; presumably the real
            # component if the last axis is (real, imag) — TODO confirm.
            loss_masked_kspace = masked_kspace[:,:,0].numpy()
            trn_masked_kspace = kspace - loss_masked_kspace

            # --- values stored in the file ---

            mask = np.asarray(hf["mask"]) if "mask" in hf else None

            # NOTE(review): this `target` is replaced by loss_masked_kspace
            # two statements below, so the value read here is never returned.
            target = hf[self.recons_key][dataslice] if self.recons_key in hf else None

            # --- hand on the split arrays instead of the stored ones ---
            kspace = trn_masked_kspace
            target = loss_masked_kspace

            # File attributes first, then per-example metadata on top.
            attrs = dict(hf.attrs)
            attrs.update(metadata)

        if self.transform is None:
            sample = (kspace, mask, target, attrs, fname.name, dataslice)
        else:
            sample = self.transform(kspace, mask, target, attrs, fname.name, dataslice)

        return sample
Esempio n. 5
0
    def __call__(self, kspace, mask, target, attrs, fname, slice_num):
        """
        Data Transformer that returns the (optionally masked) k-space together
        with the attributes needed to run the BART reconstruction.

        Args:
            kspace (numpy.array): Input k-space; converted to a tensor with
                ``T.to_tensor``.
            mask: Mask passed in by the dataset. It does not appear in the
                returned tuple.
            target (numpy.array, optional): Target image. Not used here.
            attrs (dict): Acquisition related information. ``"recon_size"`` is
                always read; ``"num_low_frequency"``, ``"acquisition"`` and
                ``"acceleration"`` are read depending on ``self.retrieve_acc``
                and ``self.reg_wt``.
            fname (str): File name; its characters seed the mask when
                ``self.use_seed`` is set.
            slice_num (int): Serial number of the slice.

        Returns:
            tuple: tuple containing:
                masked_kspace (torch.Tensor): k-space after masking, or the
                    unmasked tensor when ``self.mask_func`` is falsy.
                reg_wt (float): Regularization parameter, either
                    ``self.reg_wt`` or the value looked up in
                    ``cs_config.yaml``.
                fname (str): File name containing the current data item.
                slice_num (int): The index of the current slice in the volume.
                crop_size (tuple): First two entries of
                    ``attrs["recon_size"]``.
                num_low_freqs (int): ``attrs["num_low_frequency"]`` when
                    ``self.retrieve_acc`` is set, otherwise ``None``.

        Raises:
            ValueError: If the acquisition is missing from the configuration
                or the acceleration is neither 4 nor 8.

        Note:
            The configuration lookup uses ``args.challenge``; ``args`` is not
            defined in this method and must exist in an enclosing scope.
        """
        kspace = T.to_tensor(kspace)

        # apply mask
        masked_kspace = kspace
        if self.mask_func:
            seed = tuple(map(ord, fname)) if self.use_seed else None
            masked_kspace, mask = T.apply_mask(kspace, self.mask_func, seed)

        num_low_freqs = attrs["num_low_frequency"] if self.retrieve_acc else None

        if self.retrieve_acc and self.reg_wt is None:
            # No explicit weight given: look it up per acquisition/acceleration.
            acquisition = attrs["acquisition"]
            acceleration = attrs["acceleration"]

            with open("cs_config.yaml", "r") as f:
                param_dict = yaml.safe_load(f)

            challenge_params = param_dict[args.challenge]
            if acquisition not in challenge_params:
                raise ValueError(f"Invalid acquisition protocol: {acquisition}")
            if acceleration not in (4, 8):
                raise ValueError(f"Invalid acceleration factor: {acceleration}")

            reg_wt = challenge_params[acquisition][acceleration]
        else:
            reg_wt = self.reg_wt

        recon_size = attrs["recon_size"]
        crop_size = (recon_size[0], recon_size[1])

        return (masked_kspace, reg_wt, fname, slice_num, crop_size, num_low_freqs)
Esempio n. 6
0
def load_data_from_pathlist(path):
    """Build normalized zero-filled inputs and targets from part of ``path``.

    Only the first ``len(path) // 3`` entries are used. For each of those
    files every slice's k-space is converted to a float tensor, masked with a
    random mask (center fraction 0.08, acceleration 4), inverse
    Fourier transformed, center-cropped to 320x320 and reduced to magnitude.
    The per-file stacks and the file targets are concatenated along the
    first axis. The inputs are normalized with their own mean and standard
    deviation, the targets with those same statistics, and both are clamped
    to [-6, 6].

    Args:
        path: Sequence of dataset file paths accepted by ``load_dataset``.

    Returns:
        Tuple ``(target_image_tensor, total_sampled_image_tensor)``.

    Note:
        With fewer than three paths no file is processed and the final
        ``torch.cat`` calls receive empty lists.
    """
    use_num = len(path) // 3
    total_target_list = []
    total_sampled_image_list = []

    for h5_path in path[:use_num]:
        total_kspace, slices_num, target = load_dataset(h5_path)

        # A new mask function per file, as before.
        mask_func = RandomMaskFunc(center_fractions=[0.08],
                                   accelerations=[4])

        sampled_image_list = []
        for i in range(slices_num):
            slice_kspace_tensor = T.to_tensor(total_kspace[i]).float()
            # The mask itself is not needed afterwards, only the masked data.
            masked_kspace, _ = T.apply_mask(slice_kspace_tensor, mask_func)
            sampled_image = fastmri.ifft2c(masked_kspace)
            sampled_image = T.complex_center_crop(sampled_image, (320, 320))
            sampled_image_list.append(fastmri.complex_abs(sampled_image))

        total_sampled_image_list.append(
            torch.stack(sampled_image_list, dim=0))
        total_target_list.append(T.to_tensor(target))

    total_target = torch.cat(total_target_list, dim=0)
    total_sampled_image_tensor = torch.cat(total_sampled_image_list, dim=0)

    # Normalize inputs with their own statistics and reuse them for targets.
    total_sampled_image_tensor, mean, std = T.normalize_instance(
        total_sampled_image_tensor, eps=1e-11)
    total_sampled_image_tensor = total_sampled_image_tensor.clamp(-6, 6)
    target_image_tensor = T.normalize(total_target, mean, std, eps=1e-11)
    target_image_tensor = target_image_tensor.clamp(-6, 6)

    return target_image_tensor, total_sampled_image_tensor
Esempio n. 7
0
def test_apply_mask(shape, num_low_frequencies, accelerations):
    """Check ``transforms.apply_mask`` with a ``RandomMask`` mask function.

    The mask returned by ``apply_mask`` must equal the one the mask function
    produces for the same seed, and multiplying the output by the mask must
    leave the output unchanged.
    """
    mask_func = RandomMask(num_low_frequencies, accelerations)
    expected_mask, _ = mask_func(shape, seed=123)

    x = create_input(shape)
    masked, mask, _ = transforms.apply_mask(x, mask_func, seed=123)

    assert masked.shape == x.shape
    assert mask.shape == expected_mask.shape
    assert np.all(expected_mask.numpy() == mask.numpy())

    remasked = (masked * mask).numpy()
    assert np.all(remasked == masked.numpy())
    def __call__(
        self,
        kspace: np.ndarray,
        mask: np.ndarray,
        target: np.ndarray,
        attrs: Dict,
        fname: str,
        slice_num: int,
    ) -> Tuple[torch.Tensor, torch.Tensor, str, int, float, Dict,
               torch.Tensor]:
        """
        Mask the k-space and, if configured, produce a loss mask for it.

        Args:
            kspace: Input k-space; converted to a tensor with ``T.to_tensor``.
            mask: Mask from the dataset. It is not part of the return value.
            target: Target image, or ``None``.
            attrs: Acquisition related information stored in the HDF5 object.
            fname: File name; its characters seed the mask and the splitter
                when ``self.use_seed`` is set.
            slice_num: Serial number of the slice.

        Returns:
            tuple containing:
                masked_kspace: k-space after ``self.mask_func`` was applied,
                    or the unmasked tensor when no mask function is set.
                target: Target converted to a torch.Tensor, or a one-element
                    zero tensor when ``target`` is ``None``.
                fname: File name.
                slice_num: Serial number of the slice.
                max_value: ``attrs["max"]`` if present, otherwise 0.0.
                attrs: The ``attrs`` argument, passed through unchanged.
                mask_loss: Result of ``self.splitter_func`` called with the
                    masked k-space shape and the seed (presumably a tensor —
                    the splitter is not visible here), or a one-element zero
                    tensor when no splitter is set.
        """
        kspace = T.to_tensor(kspace)

        # check for max value
        max_value = attrs.get("max", 0.0)

        # apply subsampling mask
        if self.mask_func:
            seed = None if not self.use_seed else tuple(map(ord, fname))
            masked_kspace, mask = T.apply_mask(kspace, self.mask_func, seed)
        else:
            masked_kspace = kspace

        # The splitter receives the same file-name-derived seed as the mask.
        if self.splitter_func:
            seed = None if not self.use_seed else tuple(map(ord, fname))
            mask_loss = self.splitter_func(masked_kspace.shape, seed)
        else:
            mask_loss = torch.Tensor([0])

        # convert target (placeholder tensor when there is none)
        if target is not None:
            target = T.to_tensor(target)
        else:
            target = torch.Tensor([0])

        return masked_kspace, target, fname, slice_num, max_value, attrs, mask_loss
Esempio n. 9
0
def load_data(file_dir_path):
    """Load fully sampled and randomly under-sampled magnitude images.

    Every file returned by ``get_files(file_dir_path)`` is loaded. For each
    slice the k-space is inverse Fourier transformed to a magnitude image,
    once directly and once after applying a random mask (center fraction
    0.08, acceleration 4). Slices are stacked per file and the per-file
    stacks are stacked again along a new leading axis.

    Because of the final ``torch.stack``, all files must yield the same
    number of slices and the same image size.

    Args:
        file_dir_path: Directory (or other locator) understood by
            ``get_files``.

    Returns:
        Tuple ``(total_image_tensor, total_sampled_image_tensor)``. Both
        shapes are also printed.
    """
    total_image_list = []
    total_sampled_image_list = []

    # Load each file in turn; previously index 0 was loaded on every pass.
    for h5_path in get_files(file_dir_path):
        total_kspace, slices_num = load_dataset(h5_path)

        # A new mask function per file, as before.
        mask_func = RandomMaskFunc(center_fractions=[0.08],
                                   accelerations=[4])

        image_list = []
        sampled_image_list = []
        for i in range(slices_num):
            slice_kspace_tensor = T.to_tensor(total_kspace[i])

            # Fully sampled reference image.
            slice_image = fastmri.ifft2c(slice_kspace_tensor)
            image_list.append(fastmri.complex_abs(slice_image))

            # Under-sampled image; the mask itself is not needed afterwards.
            masked_kspace, _ = T.apply_mask(slice_kspace_tensor, mask_func)
            sampled_image = fastmri.ifft2c(masked_kspace)
            sampled_image_list.append(fastmri.complex_abs(sampled_image))

        total_image_list.append(torch.stack(image_list, dim=0))
        total_sampled_image_list.append(
            torch.stack(sampled_image_list, dim=0))

    total_image_tensor = torch.stack(total_image_list, dim=0)
    total_sampled_image_tensor = torch.stack(total_sampled_image_list, dim=0)
    print(total_image_tensor.shape)
    print(total_sampled_image_tensor.shape)
    return total_image_tensor, total_sampled_image_tensor
Esempio n. 10
0
def test_apply_mask(shape, center_fractions, accelerations):
    """Check ``transforms.apply_mask`` against ``RandomMaskFunc`` directly.

    Verifies that NumPy's global RNG state is unchanged by masking, that the
    applied mask equals the one produced by the mask function for the same
    seed, and that the output is zero wherever the mask is zero.
    """
    rng_state_before = np.random.get_state()

    mask_func = RandomMaskFunc(center_fractions, accelerations)
    expected_mask = mask_func(shape, seed=123)
    x = create_input(shape)
    masked, mask = transforms.apply_mask(x, mask_func, seed=123)

    rng_state_after = np.random.get_state()
    assert (rng_state_before[1] == rng_state_after[1]).all()
    assert masked.shape == x.shape
    assert mask.shape == expected_mask.shape

    mask_np = mask.numpy()
    assert np.all(expected_mask.numpy() == mask_np)

    # Zeroing the masked-out entries again must not change anything.
    masked_np = masked.numpy()
    assert np.all(np.where(mask_np == 0, 0, masked_np) == masked_np)
Esempio n. 11
0
def test_varnet(shape, out_chans, chans, center_fractions, accelerations):
    """Run a small VarNet on masked k-space and check the output shape.

    ``out_chans`` is accepted but not used by this test; it is kept so the
    signature stays compatible with whatever supplies the arguments.
    """
    mask_func = RandomMaskFunc(center_fractions, accelerations)
    x = create_input(shape)
    output, mask = transforms.apply_mask(x, mask_func, seed=123)

    # Forward the supplied channel count instead of a fixed 4 so that every
    # ``chans`` value handed to the test is actually exercised.
    varnet = VarNet(num_cascades=2,
                    sens_chans=4,
                    sens_pools=2,
                    chans=chans,
                    pools=2)

    y = varnet(output, mask.byte())

    assert y.shape[1:] == x.shape[2:4]
Esempio n. 12
0
def test_apply_mask(shape, center_fractions, accelerations):
    """Check mask, output and low-frequency count from ``apply_mask``.

    The low-frequency count reported by the mask function must equal
    ``round(cf * shape[-2])`` for one of the given center fractions, and
    ``transforms.apply_mask`` must return the same mask and count for the
    same seed without touching NumPy's global RNG state.
    """
    rng_state_before = np.random.get_state()

    mask_func = RandomMaskFunc(center_fractions, accelerations)
    expected_mask, expected_num_low_frequencies = mask_func(shape, seed=123)
    allowed_counts = [round(cf * shape[-2]) for cf in center_fractions]
    assert expected_num_low_frequencies in allowed_counts

    x = create_input(shape)
    masked, mask, num_low_frequencies = transforms.apply_mask(x,
                                                              mask_func,
                                                              seed=123)

    rng_state_after = np.random.get_state()
    assert (rng_state_before[1] == rng_state_after[1]).all()
    assert masked.shape == x.shape
    assert mask.shape == expected_mask.shape

    mask_np = mask.numpy()
    assert np.all(expected_mask.numpy() == mask_np)

    # Zeroing the masked-out entries again must not change anything.
    masked_np = masked.numpy()
    assert np.all(np.where(mask_np == 0, 0, masked_np) == masked_np)
    assert num_low_frequencies == expected_num_low_frequencies
Esempio n. 13
0
# Display the magnitude of `slice_image_rss`, which is defined earlier in the
# notebook.
plt.imshow(np.abs(slice_image_rss.numpy()), cmap='gray')


# So far, we have been looking at fully-sampled data. We can simulate
# under-sampled data by creating a mask and applying it to k-space.

# In[18]:


from fastmri.data.subsample import RandomMaskFunc
mask_func = RandomMaskFunc(center_fractions=[0.08], accelerations=[4])  # Create the mask function object


# In[19]:


# `slice_kspace2` comes from an earlier cell.
masked_kspace, mask = T.apply_mask(slice_kspace2, mask_func)   # Apply the mask to k-space


# Let's see what the subsampled image looks like:

# In[20]:


sampled_image = fastmri.ifft2c(masked_kspace)           # Apply the inverse Fourier transform to get the complex image
sampled_image_abs = fastmri.complex_abs(sampled_image)   # Compute the absolute value to get a real image
# NOTE(review): `sampled_image_rss` is computed here, but the next line
# displays `sampled_image_abs` with index list [0] — confirm which one was
# meant to be shown.
sampled_image_rss = fastmri.rss(sampled_image_abs, dim=0)
show_coils(sampled_image_abs, [0], cmap='gray')

# Load a model from a checkpoint file. `SSVarNetModule` is imported under the
# alias `VarNetModule`.
ckpt_path = "fastmri_examples/varnet/varnet/varnet_demo/checkpoints/epoch=4-step=52114.ckpt"
from fastmri.pl_modules import SSVarNetModule as VarNetModule
model = VarNetModule.load_from_checkpoint(ckpt_path)
Esempio n. 14
0
def data_transform(kspace, mask, target, data_attributes, filename, slice_num):
    """Convert raw k-space to a tensor and return its masked version.

    Only ``kspace`` is used; the other arguments are accepted and ignored.
    The mask function is the ``mask_func`` found in the enclosing scope.
    """
    kspace_tensor = transforms.to_tensor(kspace)
    masked_kspace, _applied_mask = transforms.apply_mask(kspace_tensor,
                                                         mask_func)
    return masked_kspace
Esempio n. 15
0
    def __call__(self, kspace, mask, target, attrs, fname, slice_num):
        """
        Prepare masked k-space, its sampling mask and bookkeeping values.

        Args:
            kspace (numpy.array): Input k-space; converted to a tensor with
                ``T.to_tensor``.
            mask (numpy.array): Mask from the dataset. Used only when
                ``self.mask_func`` is falsy.
            target (numpy.array): Target image, or ``None``.
            attrs (dict): Acquisition related information. ``"padding_left"``,
                ``"padding_right"`` and ``"recon_size"`` are always read;
                ``"max"`` is read when a target is given.
            fname (str): File name; its characters seed the mask when
                ``self.use_seed`` is set.
            slice_num (int): Serial number of the slice.

        Returns:
            (tuple): tuple containing:
                masked_kspace (torch.Tensor): k-space after applying the
                    sampling mask, or the unmasked tensor when no mask
                    function is set.
                mask (torch.Tensor): The sampling mask, converted with
                    ``.byte()``.
                target (torch.Tensor): The target image, or a zero scalar
                    tensor when there is none.
                fname (str): File name.
                slice_num (int): The slice index.
                max_value (float): ``attrs["max"]``, or 0.0 without a target.
                crop_size (torch.Tensor): First two entries of
                    ``attrs["recon_size"]``.
        """
        if target is None:
            target = torch.tensor(0)
            max_value = 0.0
        else:
            target = T.to_tensor(target)
            max_value = attrs["max"]

        kspace = T.to_tensor(kspace)
        seed = tuple(map(ord, fname)) if self.use_seed else None
        acq_start = attrs["padding_left"]
        acq_end = attrs["padding_right"]

        recon_size = attrs["recon_size"]
        crop_size = torch.tensor([recon_size[0], recon_size[1]])

        if self.mask_func:
            masked_kspace, mask = T.apply_mask(
                kspace, self.mask_func, seed, (acq_start, acq_end)
            )
        else:
            masked_kspace = kspace
            # Reshape the supplied mask to size 1 on every axis except the
            # second-to-last, which takes the k-space extent of that axis.
            mask_shape = [1] * len(kspace.shape)
            mask_shape[-2] = kspace.shape[-2]
            mask = torch.from_numpy(mask.reshape(*mask_shape).astype(np.float32))
            # Zero everything outside [acq_start, acq_end) along the third axis.
            mask[:, :, :acq_start] = 0
            mask[:, :, acq_end:] = 0

        return (
            masked_kspace,
            mask.byte(),
            target,
            fname,
            slice_num,
            max_value,
            crop_size,
        )
Esempio n. 16
0
# Scale the target by its largest magnitude, then center-crop it and take the
# square root of the sum of squares over axis 0.
# NOTE(review): `crop_size` is read here but only assigned two statements
# below, so it must already exist from earlier code — confirm the intended
# order.
target = target / np.max(np.abs(target))
target = np.sqrt(np.sum(T.center_crop(target, crop_size) ** 2, 0))

crop_size = (320, 320)
mask_func = create_mask_for_mask_type(mask_type_str="random", center_fractions=[0.08], accelerations=[4])

# NOTE(review): `slice` must have been bound earlier; otherwise this indexes
# with the builtin `slice` type.
_kspace = T.to_tensor(kspace)[slice]
masked_kspace, mask = T.apply_mask(_kspace, mask_func)

# Combine the two entries of the last axis into one complex array and apply a
# centered inverse 2-D FFT over the last two axes.
linear_recon = masked_kspace[..., 0] + 1j * masked_kspace[..., 1]
linear_recon = np.fft.fftshift(
    np.fft.ifft2(np.fft.ifftshift(linear_recon, axes=(-2, -1)), axes=(-2, -1)),
    axes=(-2, -1))
linear_recon = linear_recon / np.max(np.abs(linear_recon))
# NOTE(review): `linear_recon` is complex at this point, so `** 2` is a
# complex square rather than |z|**2 — confirm this is intended.
linear_recon = np.sqrt(np.sum(T.center_crop(linear_recon, (320, 320)) ** 2, 0))
Esempio n. 17
0
    def __call__(self, target_ksp, target_im, attrs, fname, slice):
        """
        Build one training example from fully sampled k-space.

        The k-space and target image are converted to tensors, the k-space is
        optionally coil-compressed and then masked. Depending on ``self.args``
        sensitivity maps are estimated, GRAPPA kernels are loaded and applied
        to the input and/or target, and input, target k-space and target
        image are divided by a scale derived from the input image.

        Args:
            target_ksp: Fully sampled k-space as a NumPy array (it is summed
                with ``axis=(0,1)`` and passed to ``transforms.to_tensor``).
            target_im: Target image; converted with ``transforms.to_tensor``.
            attrs: Mapping with acquisition attributes. ``'padding_left'`` and
                ``'padding_right'`` are read unless offsets are calculated
                directly; ``'mask_offset'`` is used for the ``'val'``
                partition when present.
            fname: File name; its characters seed the mask when
                ``self.use_seed`` is set, and it names the GRAPPA files.
            slice: Slice index used to select the GRAPPA kernel.

        Returns:
            OrderedDict with keys ``input``, ``target``, ``target_im``,
            ``mask``, ``grappa_kernel``, ``scale``, ``attrs_dict``, ``fname``,
            ``slice``, ``num_lf`` and ``sens_map``.
        """
        # Keep the original NumPy array; it is used for the offset
        # calculation and for the sensitivity-map estimate below.
        kspace_np = target_ksp
        target_im = transforms.to_tensor(target_im)
        target_ksp = transforms.to_tensor(target_ksp)

        if self.args.coil_compress_coils:
            target_ksp = transforms.coil_compress(target_ksp, self.args.coil_compress_coils)

        if self.args.calculate_offsets_directly:
            krow = kspace_np.sum(axis=(0,1)) # flatten to a single row
            width = len(krow)
            # Index of the first nonzero entry; also passed on as mask offset.
            offset = (krow != 0).argmax()
            acq_start = offset
            # One past the index of the last nonzero entry.
            acq_end = width - (krow[::-1] != 0).argmax() #exclusive
        else:
            offset = None # Mask will pick randomly
            if self.partition == 'val' and 'mask_offset' in attrs:
                offset = attrs['mask_offset']

            acq_start = attrs['padding_left']
            acq_end = attrs['padding_right']

        seed = None if not self.use_seed else tuple(map(ord, fname))
        input_ksp, mask, num_lf = transforms.apply_mask(
            target_ksp, self.mask_func, 
            seed, offset,
            (acq_start, acq_end))

        # Placeholder that is returned when sensitivities are not computed.
        sens_map = torch.Tensor(0)
        if self.args.compute_sensitivities:
            # Window of `num_lf` entries centered on the last axis.
            # NOTE(review): the estimate uses `kspace_np`, i.e. the input
            # before any coil compression — confirm this is intended.
            start_of_center_mask = (kspace_np.shape[-1] - num_lf + 1) // 2
            end_of_center_mask = start_of_center_mask + num_lf
            sens_map = est_sens_maps(kspace_np, start_of_center_mask, end_of_center_mask)
            sens_map = transforms.to_tensor(sens_map)

        if self.args.grappa_input:
            # Replace the masked input with its GRAPPA-processed version.
            with h5py.File(self.args.grappa_input_path / self.partition / fname, 'r') as hf:
                kernel = transforms.to_tensor(hf['kernel'][slice])
                input_ksp = transforms.apply_grappa(input_ksp, kernel, target_ksp, mask)

        # Placeholder that is returned when no GRAPPA kernel path is set.
        grappa_kernel = torch.Tensor(0)
        if self.args.grappa_path is not None:
            with h5py.File(self.args.grappa_path / self.partition / fname, 'r') as hf:
                grappa_kernel = transforms.to_tensor(hf['kernel'][slice])

        if self.args.grappa_target:
            # Apply GRAPPA to a copy of the target k-space and recompute the
            # target image from the result.
            with h5py.File(self.args.grappa_target_path / self.partition / fname, 'r') as hf:
                kernel = transforms.to_tensor(hf['kernel'][slice])
                target_ksp = transforms.apply_grappa(target_ksp.clone(), kernel, target_ksp, mask, sample_accel=2)
                target_im = transforms.root_sum_of_squares(transforms.complex_abs(transforms.ifft2(target_ksp)))

        input_im = transforms.ifft2(input_ksp)
        if not self.args.scale_inputs:
            scale = torch.Tensor([1.])
        else:
            abs_input = transforms.complex_abs(input_im)
            if self.args.scale_type == 'max':
                scale = torch.max(abs_input)
            else:
                scale = torch.mean(abs_input)

            # In-place division of all three tensors by the same scale.
            input_ksp /= scale
            target_ksp /= scale
            target_im /= scale

        # Reshape the scale to (1, 1, 1).
        scale = scale.view([1, 1, 1])
        attrs_dict = dict(**attrs)

        return OrderedDict(
            input = input_ksp,
            target = target_ksp,
            target_im = target_im,
            mask = mask,
            grappa_kernel = grappa_kernel,
            scale = scale,
            attrs_dict = attrs_dict,
            fname = fname,
            slice = slice,
            num_lf = num_lf,
            sens_map = sens_map,
        )
Esempio n. 18
0
    def __call__(self, kspace, mask, target, attrs, fname, slice_num):
        """
        Turn one k-space slice into a normalized zero-filled image.

        Args:
            kspace (numpy.array): Input k-space; converted to a tensor with
                ``transforms.to_tensor``.
            mask (numpy.array): Mask from the dataset. It is not part of the
                return value.
            target (numpy.array): Target image, or ``None``.
            attrs (dict): Acquisition related information; ``"recon_size"`` is
                read when there is no target.
            fname (str): File name; its characters seed the mask when
                ``self.use_seed`` is set.
            slice_num (int): Serial number of the slice.

        Returns:
            (tuple): tuple containing:
                image (torch.Tensor): Zero-filled input image, normalized and
                    clamped to [-6, 6].
                target (torch.Tensor): Target cropped, normalized with the
                    image statistics and clamped to [-6, 6], or a one-element
                    zero tensor when there is no target.
                mean: Mean value used for normalization.
                std: Standard deviation value used for normalization.
                fname (str): File name.
                slice_num (int): Serial number of the slice.
        """
        kspace = transforms.to_tensor(kspace)

        # apply mask
        masked_kspace = kspace
        if self.mask_func:
            seed = tuple(map(ord, fname)) if self.use_seed else None
            masked_kspace, mask = transforms.apply_mask(
                kspace, self.mask_func, seed)

        # inverse Fourier transform to get zero filled solution
        image = fastmri.ifft2c(masked_kspace)

        # crop input to correct size
        if target is None:
            crop_size = (attrs["recon_size"][0], attrs["recon_size"][1])
        else:
            crop_size = (target.shape[-2], target.shape[-1])

        # check for FLAIR 203
        side = image.shape[-2]
        if side < crop_size[1]:
            crop_size = (side, side)

        # crop, then absolute value
        image = transforms.complex_center_crop(image, crop_size)
        image = fastmri.complex_abs(image)

        # apply Root-Sum-of-Squares if multicoil data
        if self.which_challenge == "multicoil":
            image = fastmri.rss(image)

        # normalize input
        image, mean, std = transforms.normalize_instance(image, eps=1e-11)
        image = image.clamp(-6, 6)

        # normalize target with the statistics of the input image
        if target is None:
            target = torch.Tensor([0])
        else:
            target = transforms.to_tensor(target)
            target = transforms.center_crop(target, crop_size)
            target = transforms.normalize(target, mean, std, eps=1e-11)
            target = target.clamp(-6, 6)

        return image, target, mean, std, fname, slice_num
    def __call__(
        self,
        kspace: np.ndarray,
        mask: np.ndarray,
        target: np.ndarray,
        attrs: Dict,
        fname: str,
        slice_num: int,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, str,
               int, float]:
        """
        Turn one k-space slice into a normalized zero-filled image.

        Args:
            kspace: Input k-space; converted to a tensor with ``T.to_tensor``.
            mask: Mask from the dataset. It is not part of the return value.
            target: Target image, or ``None``.
            attrs: Acquisition related information stored in the HDF5 object.
            fname: File name; its characters seed the mask when
                ``self.use_seed`` is set.
            slice_num: Serial number of the slice.

        Returns:
            tuple containing:
                image: Zero-filled input image, normalized and clamped to
                    [-6, 6].
                target: Target cropped, normalized with the image statistics
                    and clamped, or a one-element zero tensor in test mode or
                    when there is no target.
                mean: Mean value used for normalization.
                std: Standard deviation value used for normalization.
                fname: File name.
                slice_num: Serial number of the slice.
                max_value: ``attrs["max"]`` if present, otherwise 0.0.
        """
        kspace = T.to_tensor(kspace)

        # check for max value
        max_value = attrs["max"] if "max" in attrs else 0.0

        # apply mask
        masked_kspace = kspace
        if self.mask_func:
            seed = tuple(map(ord, fname)) if self.use_seed else None
            masked_kspace, mask = T.apply_mask(kspace, self.mask_func, seed)

        # inverse Fourier transform to get zero filled solution
        image = fastmri.ifft2c(masked_kspace)

        side = image.shape[-2]
        if self.test_mode:
            # In test mode the crop is always a square of the image's
            # second-to-last extent.
            crop_size = (side, side)
        else:
            # crop input to correct size
            if target is None:
                crop_size = (attrs["recon_size"][0], attrs["recon_size"][1])
            else:
                crop_size = (target.shape[-2], target.shape[-1])

            # check for FLAIR 203
            if side < crop_size[1]:
                crop_size = (side, side)

        # crop, then absolute value
        image = T.complex_center_crop(image, crop_size)
        image = fastmri.complex_abs(image)

        # apply Root-Sum-of-Squares if multicoil data
        if self.which_challenge == "multicoil":
            image = fastmri.rss(image)

        # normalize input
        image, mean, std = T.normalize_instance(image, eps=1e-11)
        image = image.clamp(-6, 6)

        # normalize target with the statistics of the input image
        if self.test_mode or target is None:
            target = torch.Tensor([0])
        else:
            target = T.to_tensor(target)
            target = T.center_crop(target, crop_size)
            target = T.normalize(target, mean, std, eps=1e-11)
            target = target.clamp(-6, 6)

        return image, target, mean, std, fname, slice_num, max_value