def test_varnet(shape, chans, center_fractions, accelerations, mask_center):
    """Run a small VarNet on per-sample masked input and check the output shape.

    Every entry along the leading axis of the generated input is masked on its
    own (always with ``seed=123``); the masked pieces and their masks are then
    concatenated again before the forward pass. The trailing dimensions of the
    network output must equal ``x.shape[2:4]``.
    """
    sampler = RandomMaskFunc(center_fractions, accelerations)
    x = create_input(shape)

    masked_parts = []
    mask_parts = []
    for idx in range(x.shape[0]):
        # Slice with idx:idx + 1 so the leading axis survives for torch.cat.
        masked_part, mask_part, _ = transforms.apply_mask(
            x[idx:idx + 1], sampler, seed=123
        )
        masked_parts.append(masked_part)
        mask_parts.append(mask_part)

    masked_batch = torch.cat(masked_parts)
    mask_batch = torch.cat(mask_parts)

    model = VarNet(
        num_cascades=2,
        sens_chans=4,
        sens_pools=2,
        chans=chans,
        pools=2,
        mask_center=mask_center,
    )

    prediction = model(masked_batch, mask_batch.byte())

    assert prediction.shape[1:] == x.shape[2:4]
def test_mask_types(mask_type):
    """Check ``apply_mask`` against the mask function built for ``mask_type``.

    For several input shapes and (center fraction, acceleration) settings the
    test verifies that:

    * the global NumPy random state is unchanged after masking,
    * the masked output keeps the input shape,
    * the returned mask equals the mask produced by calling the mask function
      directly with the same seed,
    * the output is zero wherever the mask is zero,
    * the reported number of low frequencies matches the mask function's.
    """
    shapes = ((4, 32, 32, 2), (2, 64, 32, 2), (1, 33, 24, 2))
    # Paired element-wise: the n-th center fractions go with the n-th
    # accelerations.
    settings = tuple(zip(([0.08], [0.04], [0.04, 0.08]), ([4], [8], [4, 8])))

    rng_before = np.random.get_state()

    for shape in shapes:
        for center_fractions, accelerations in settings:
            sampler = create_mask_for_mask_type(
                mask_type, center_fractions, accelerations
            )
            reference_mask, reference_num_low = sampler(shape, seed=123)

            x = create_input(shape)
            masked, mask, num_low = transforms.apply_mask(x, sampler, seed=123)

            # Seeded masking must not touch the global NumPy generator.
            rng_after = np.random.get_state()
            assert (rng_before[1] == rng_after[1]).all()

            assert masked.shape == x.shape
            assert mask.shape == reference_mask.shape
            assert np.all(reference_mask.numpy() == mask.numpy())

            # Zeroing the masked-out positions again must be a no-op.
            masked_np = masked.numpy()
            rezeroed = np.where(mask.numpy() == 0, 0, masked_np)
            assert np.all(rezeroed == masked_np)

            assert num_low == reference_num_low
def test_varnet_num_sense_lines(shape, chans, center_fractions, accelerations,
                                mask_center):
    """Check the low-frequency bookkeeping of VarNet's sensitivity network.

    When ``mask_center`` is ``True`` the padding and low-frequency count
    reported by ``sens_net.get_pad_and_num_low_freqs`` must agree with the
    value returned by ``apply_mask``, and the mask must be all ones on that
    range. In every case a forward pass with ``num_low_frequencies=4`` must
    yield an output whose trailing dimensions equal ``x.shape[2:4]``.
    """
    sampler = RandomMaskFunc(center_fractions, accelerations)
    x = create_input(shape)
    masked, mask, num_low_freqs = transforms.apply_mask(x, sampler, seed=123)

    model = VarNet(
        num_cascades=2,
        sens_chans=4,
        sens_pools=2,
        chans=chans,
        pools=2,
        mask_center=mask_center,
    )

    if mask_center is True:
        pad, net_low_freqs = model.sens_net.get_pad_and_num_low_freqs(
            mask, num_low_freqs
        )
        assert net_low_freqs == num_low_freqs

        start = int(pad)
        stop = int(pad + net_low_freqs)
        center_lines = mask.squeeze()[start:stop].to(torch.int8)
        all_sampled = torch.ones([int(net_low_freqs)], dtype=torch.int8)
        assert torch.allclose(center_lines, all_sampled)

    prediction = model(masked, mask.byte(), num_low_frequencies=4)

    assert prediction.shape[1:] == x.shape[2:4]
def __getitem__(self, i: int):
    """Load example ``i`` and split its k-space into an input/target pair.

    The slice's k-space is read from the HDF5 file, a random mask
    (``center_fractions=[0.04]``, ``accelerations=[8]``) is applied to it, and
    the sample is then built from two parts:

    * ``target``: channel 0 of the masked k-space tensor, as a NumPy array;
    * ``kspace``: the original k-space minus that array.

    Args:
        i: Index into ``self.examples``.

    Returns:
        ``(kspace, mask, target, attrs, fname.name, dataslice)`` when
        ``self.transform`` is ``None``; otherwise whatever ``self.transform``
        returns for those same arguments. ``mask`` is the file's ``"mask"``
        dataset if present, else ``None``.
    """
    fname, dataslice, metadata = self.examples[i]

    # Only the reads happen while the file is open; the stored reconstruction
    # under ``self.recons_key`` is not read because the target is always
    # replaced by the masked k-space below.
    with h5py.File(fname, "r") as hf:
        kspace = hf["kspace"][dataslice]
        mask = np.asarray(hf["mask"]) if "mask" in hf else None
        attrs = dict(hf.attrs)
        attrs.update(metadata)

    mask_func = RandomMaskFunc(center_fractions=[0.04], accelerations=[8])
    # The mask returned by apply_mask is discarded; ``mask`` keeps the value
    # read from the file.
    masked_kspace, _ = T.apply_mask(T.to_tensor(kspace), mask_func)

    # NOTE(review): only index 0 of the last axis is kept -- presumably the
    # real channel of a (rows, cols, 2) single-coil tensor, which would drop
    # the imaginary part. Confirm this is intended.
    target = masked_kspace[:, :, 0].numpy()
    kspace = kspace - target

    if self.transform is None:
        return (kspace, mask, target, attrs, fname.name, dataslice)
    return self.transform(kspace, mask, target, attrs, fname.name, dataslice)
def __call__(self, kspace, mask, target, attrs, fname, slice_num):
    """
    Prepare one k-space example for a BART-based reconstruction.

    The k-space is converted to a tensor and, if a mask function is
    configured, sub-sampled with it. The regularization weight is either the
    configured ``self.reg_wt`` or, when ``self.retrieve_acc`` is set and no
    weight is configured, looked up in ``cs_config.yaml``.

    Args:
        kspace (numpy.array): Input k-space of shape (num_coils, rows,
            cols, 2) for multi-coil data or (rows, cols, 2) for single coil
            data.
        mask (numpy.array): Incoming mask; not part of the returned tuple.
        target (numpy.array, optional): Target image; unused here.
        attrs (dict): Acquisition related information stored in the HDF5
            object. ``"recon_size"`` is always read; ``"num_low_frequency"``,
            ``"acquisition"`` and ``"acceleration"`` are read when
            ``self.retrieve_acc`` requires them.
        fname (str): File name; its characters seed the mask when
            ``self.use_seed`` is set.
        slice_num (int): Serial number of the slice.

    Returns:
        tuple: tuple containing:
            masked_kspace (torch.Tensor): Sub-sampled k-space (the unmasked
                tensor when no mask function is configured).
            reg_wt (float): Regularization parameter.
            fname (str): File name containing the current data item.
            slice_num (int): The index of the current slice in the volume.
            crop_size (tuple): First two entries of ``attrs["recon_size"]``.
            num_low_freqs (int): ``attrs["num_low_frequency"]`` when
                ``self.retrieve_acc`` is set, otherwise ``None``.

    Raises:
        ValueError: If the acquisition or acceleration found in ``attrs`` has
            no entry to look up.
    """
    kspace = T.to_tensor(kspace)

    # Without a mask function the k-space passes through untouched.
    masked_kspace = kspace
    if self.mask_func:
        seed = tuple(map(ord, fname)) if self.use_seed else None
        masked_kspace, mask = T.apply_mask(kspace, self.mask_func, seed)

    num_low_freqs = attrs["num_low_frequency"] if self.retrieve_acc else None

    reg_wt = self.reg_wt
    if self.retrieve_acc and reg_wt is None:
        acquisition = attrs["acquisition"]
        acceleration = attrs["acceleration"]

        with open("cs_config.yaml", "r") as f:
            param_dict = yaml.safe_load(f)

        # NOTE(review): ``args`` is not defined inside this method;
        # presumably a module-level namespace -- confirm.
        challenge_params = param_dict[args.challenge]
        if acquisition not in challenge_params:
            raise ValueError(f"Invalid acquisition protocol: {acquisition}")
        if acceleration not in (4, 8):
            raise ValueError(f"Invalid acceleration factor: {acceleration}")

        reg_wt = param_dict[args.challenge][acquisition][acceleration]

    recon_size = attrs["recon_size"]
    crop_size = (recon_size[0], recon_size[1])

    return (masked_kspace, reg_wt, fname, slice_num, crop_size, num_low_freqs)
def load_data_from_pathlist(path):
    """Build normalized zero-filled inputs and targets from a list of files.

    Only the first ``len(path) // 3`` entries of ``path`` are used. For each
    of them every k-space slice is randomly sub-sampled
    (``center_fractions=[0.08]``, ``accelerations=[4]``), transformed with an
    inverse FFT, center-cropped to 320 x 320 and reduced to its magnitude.
    The sub-sampled images of all files are concatenated, instance-normalized
    and clamped to ``[-6, 6]``; the targets are normalized with the same mean
    and standard deviation and clamped likewise.

    Args:
        path: Sequence of file paths accepted by ``load_dataset``.

    Returns:
        ``(target_image_tensor, total_sampled_image_tensor)``.
    """
    use_num = len(path) // 3

    total_target_list = []
    total_sampled_image_list = []

    for h5_num in range(use_num):
        total_kspace, slices_num, target = load_dataset(path[h5_num])

        # One mask function per file, as before.
        mask_func = RandomMaskFunc(center_fractions=[0.08], accelerations=[4])

        sampled_image_list = []
        for i in range(slices_num):
            slice_kspace_tensor = T.to_tensor(total_kspace[i]).float()

            # The returned mask is not needed for the zero-filled image.
            masked_kspace, _ = T.apply_mask(slice_kspace_tensor, mask_func)

            sampled_image = fastmri.ifft2c(masked_kspace)
            sampled_image = T.complex_center_crop(sampled_image, (320, 320))
            sampled_image_list.append(fastmri.complex_abs(sampled_image))

        total_sampled_image_list.append(torch.stack(sampled_image_list, dim=0))
        total_target_list.append(T.to_tensor(target))

    total_target = torch.cat(total_target_list, dim=0)
    total_sampled_image_tensor = torch.cat(total_sampled_image_list, dim=0)

    # Normalization statistics come from the sub-sampled images and are
    # reused for the targets.
    total_sampled_image_tensor, mean, std = T.normalize_instance(
        total_sampled_image_tensor, eps=1e-11
    )
    total_sampled_image_tensor = total_sampled_image_tensor.clamp(-6, 6)

    target_image_tensor = T.normalize(total_target, mean, std, eps=1e-11)
    target_image_tensor = target_image_tensor.clamp(-6, 6)

    return target_image_tensor, total_sampled_image_tensor
def test_apply_mask(shape, num_low_frequencies, accelerations):
    """Check ``apply_mask`` against a directly generated ``RandomMask``.

    With the same seed, the mask returned by ``apply_mask`` must equal the
    one produced by the mask function itself, the output must keep the input
    shape, and multiplying the output by the mask must leave it unchanged.
    """
    sampler = RandomMask(num_low_frequencies, accelerations)
    reference_mask, _ = sampler(shape, seed=123)

    x = create_input(shape)
    masked, mask, _ = transforms.apply_mask(x, sampler, seed=123)

    assert masked.shape == x.shape
    assert mask.shape == reference_mask.shape
    assert np.all(reference_mask.numpy() == mask.numpy())

    # Masking an already masked tensor must not change it.
    remasked = (masked * mask).numpy()
    assert np.all(remasked == masked.numpy())
def __call__(
    self,
    kspace: np.ndarray,
    mask: np.ndarray,
    target: np.ndarray,
    attrs: Dict,
    fname: str,
    slice_num: int,
) -> Tuple[torch.Tensor, torch.Tensor, str, int, float, Dict, torch.Tensor]:
    """
    Sub-sample the k-space and, optionally, draw a second "loss" mask.

    Args:
        kspace: Input k-space of shape (num_coils, rows, cols) for
            multi-coil data or (rows, cols) for single coil data.
        mask: Mask from the test dataset. It is replaced by the generated
            mask when ``self.mask_func`` is set and is not returned.
        target: Target image, or ``None``.
        attrs: Acquisition related information stored in the HDF5 object.
        fname: File name; its characters seed the masks when
            ``self.use_seed`` is set.
        slice_num: Serial number of the slice.

    Returns:
        tuple containing:
            masked_kspace: k-space tensor after applying ``self.mask_func``
                (the unmasked tensor when no mask function is configured).
            target: Target converted to a torch.Tensor, or
                ``torch.Tensor([0])`` when no target was given.
            fname: File name.
            slice_num: Serial number of the slice.
            max_value: ``attrs["max"]`` if present, otherwise ``0.0``.
            attrs: The attribute dictionary, passed through unchanged.
            mask_loss: Result of ``self.splitter_func`` called with the
                masked k-space shape and the seed, or ``torch.Tensor([0])``
                when no splitter function is configured.
    """
    kspace = T.to_tensor(kspace)

    # check for max value
    max_value = attrs["max"] if "max" in attrs.keys() else 0.0

    # The same file-name-derived seed is used for both masks; it is only
    # computed when something will actually consume it.
    seed = None
    if self.use_seed and (self.mask_func or self.splitter_func):
        seed = tuple(map(ord, fname))

    # apply subsampling mask
    if self.mask_func:
        masked_kspace, mask = T.apply_mask(kspace, self.mask_func, seed)
    else:
        masked_kspace = kspace

    if self.splitter_func:
        mask_loss = self.splitter_func(masked_kspace.shape, seed)
    else:
        mask_loss = torch.Tensor([0])

    # Convert the target; a one-element placeholder stands in for "no target".
    if target is not None:
        target = T.to_tensor(target)
    else:
        target = torch.Tensor([0])

    return masked_kspace, target, fname, slice_num, max_value, attrs, mask_loss
def load_data(file_dir_path):
    """Load fully-sampled and randomly sub-sampled magnitude images.

    For every file returned by ``get_files(file_dir_path)`` each k-space
    slice is turned into two magnitude images: one from the full k-space and
    one from the k-space after a random mask (``center_fractions=[0.08]``,
    ``accelerations=[4]``). Slices are stacked per file and the per-file
    tensors are stacked again, so all files must yield tensors of the same
    shape.

    Args:
        file_dir_path: Directory (or other locator) understood by
            ``get_files``.

    Returns:
        ``(total_image_tensor, total_sampled_image_tensor)``, both with a
        leading file axis followed by a slice axis.
    """
    file_path = get_files(file_dir_path)

    total_image_list = []
    total_sampled_image_list = []

    for h5_num in range(len(file_path)):
        # Load the file that belongs to this iteration (previously index 0
        # was loaded every time, duplicating the first file).
        total_kspace, slices_num = load_dataset(file_path[h5_num])

        # One mask function per file.
        mask_func = RandomMaskFunc(center_fractions=[0.08], accelerations=[4])

        image_list = []
        sampled_image_list = []
        for i in range(slices_num):
            slice_kspace_tensor = T.to_tensor(total_kspace[i])

            # Fully-sampled magnitude image.
            slice_image = fastmri.ifft2c(slice_kspace_tensor)
            image_list.append(fastmri.complex_abs(slice_image))

            # Sub-sampled magnitude image; the mask itself is not needed.
            masked_kspace, _ = T.apply_mask(slice_kspace_tensor, mask_func)
            sampled_image = fastmri.ifft2c(masked_kspace)
            sampled_image_list.append(fastmri.complex_abs(sampled_image))

        total_image_list.append(torch.stack(image_list, dim=0))
        total_sampled_image_list.append(torch.stack(sampled_image_list, dim=0))

    total_image_tensor = torch.stack(total_image_list, dim=0)
    total_sampled_image_tensor = torch.stack(total_sampled_image_list, dim=0)

    print(total_image_tensor.shape)
    print(total_sampled_image_tensor.shape)

    return total_image_tensor, total_sampled_image_tensor
def test_apply_mask(shape, center_fractions, accelerations):
    """Check ``apply_mask`` against a directly generated ``RandomMaskFunc`` mask.

    Verifies that seeded masking leaves the global NumPy random state alone,
    preserves the input shape, reproduces the mask obtained from the mask
    function with the same seed, and yields zeros wherever the mask is zero.
    """
    rng_before = np.random.get_state()

    sampler = RandomMaskFunc(center_fractions, accelerations)
    reference_mask = sampler(shape, seed=123)

    x = create_input(shape)
    masked, mask = transforms.apply_mask(x, sampler, seed=123)

    # Seeded masking must not touch the global NumPy generator.
    rng_after = np.random.get_state()
    assert (rng_before[1] == rng_after[1]).all()

    assert masked.shape == x.shape
    assert mask.shape == reference_mask.shape
    assert np.all(reference_mask.numpy() == mask.numpy())

    # Zeroing the masked-out positions again must be a no-op.
    masked_np = masked.numpy()
    rezeroed = np.where(mask.numpy() == 0, 0, masked_np)
    assert np.all(rezeroed == masked_np)
def test_varnet(shape, out_chans, chans, center_fractions, accelerations):
    """Run a small VarNet on masked input and check the output shape.

    The input is masked with a seeded ``RandomMaskFunc`` and passed through a
    two-cascade VarNet built with the requested ``chans``. The trailing
    dimensions of the output must equal ``x.shape[2:4]``.

    ``out_chans`` is accepted for signature compatibility with the test's
    parametrization but is not used by this test.
    """
    mask_func = RandomMaskFunc(center_fractions, accelerations)
    x = create_input(shape)
    output, mask = transforms.apply_mask(x, mask_func, seed=123)

    # Use the ``chans`` argument so each requested width is actually
    # exercised (it was previously fixed to 4).
    varnet = VarNet(
        num_cascades=2,
        sens_chans=4,
        sens_pools=2,
        chans=chans,
        pools=2,
    )

    y = varnet(output, mask.byte())

    assert y.shape[1:] == x.shape[2:4]
def test_apply_mask(shape, center_fractions, accelerations):
    """Check ``apply_mask`` and the reported number of low frequencies.

    The low-frequency count returned by the mask function must be one of
    ``round(cf * shape[-2])`` for the given center fractions. With the same
    seed, ``apply_mask`` must reproduce that mask and count, keep the input
    shape, leave the global NumPy random state untouched, and produce zeros
    wherever the mask is zero.
    """
    rng_before = np.random.get_state()

    sampler = RandomMaskFunc(center_fractions, accelerations)
    reference_mask, reference_num_low = sampler(shape, seed=123)

    allowed_counts = [round(cf * shape[-2]) for cf in center_fractions]
    assert reference_num_low in allowed_counts

    x = create_input(shape)
    masked, mask, num_low = transforms.apply_mask(x, sampler, seed=123)

    # Seeded masking must not touch the global NumPy generator.
    rng_after = np.random.get_state()
    assert (rng_before[1] == rng_after[1]).all()

    assert masked.shape == x.shape
    assert mask.shape == reference_mask.shape
    assert np.all(reference_mask.numpy() == mask.numpy())

    # Zeroing the masked-out positions again must be a no-op.
    masked_np = masked.numpy()
    rezeroed = np.where(mask.numpy() == 0, 0, masked_np)
    assert np.all(rezeroed == masked_np)

    assert num_low == reference_num_low
plt.imshow(np.abs(slice_image_rss.numpy()), cmap='gray') # So far, we have been looking at fully-sampled data. We can simulate under-sampled data by creating a mask and applying it to k-space. # In[18]: from fastmri.data.subsample import RandomMaskFunc mask_func = RandomMaskFunc(center_fractions=[0.08], accelerations=[4]) # Create the mask function object # In[19]: masked_kspace, mask = T.apply_mask(slice_kspace2, mask_func) # Apply the mask to k-space # Let's see what the subsampled image looks like: # In[20]: sampled_image = fastmri.ifft2c(masked_kspace) # Apply Inverse Fourier Transform to get the complex image sampled_image_abs = fastmri.complex_abs(sampled_image) # Compute absolute value to get a real image sampled_image_rss = fastmri.rss(sampled_image_abs, dim=0) show_coils(sampled_image_abs, [0], cmap='gray') ckpt_path = "fastmri_examples/varnet/varnet/varnet_demo/checkpoints/epoch=4-step=52114.ckpt" from fastmri.pl_modules import SSVarNetModule as VarNetModule model = VarNetModule.load_from_checkpoint(ckpt_path)
def data_transform(kspace, mask, target, data_attributes, filename, slice_num):
    """Convert k-space to a tensor, mask it, and return only the masked data.

    The mask function is taken from the enclosing scope (``mask_func``); all
    arguments other than ``kspace`` are ignored.
    """
    kspace_tensor = transforms.to_tensor(kspace)
    masked, _applied_mask = transforms.apply_mask(kspace_tensor, mask_func)
    return masked
def __call__(self, kspace, mask, target, attrs, fname, slice_num):
    """
    Produce masked k-space together with the sampling mask that applies to it.

    With a configured mask function the k-space is sub-sampled and the
    generated mask is returned. Without one, the k-space is passed through
    and the incoming ``mask`` is reshaped so that it is singleton everywhere
    except the second-to-last axis, then zeroed outside
    ``[padding_left, padding_right)``.

    Args:
        kspace (numpy.array): Input k-space of shape (num_coils, rows,
            cols, 2) for multi-coil data or (rows, cols, 2) for single coil
            data.
        mask (numpy.array): Mask from the test dataset; only used when no
            mask function is configured.
        target (numpy.array): Target image, or None.
        attrs (dict): Acquisition related information stored in the HDF5
            object. Reads "padding_left", "padding_right", "recon_size" and,
            when a target is given, "max".
        fname (str): File name; its characters seed the mask when
            ``self.use_seed`` is set.
        slice_num (int): Serial number of the slice.

    Returns:
        (tuple): tuple containing:
            masked_kspace (torch.Tensor): k-space after applying sampling
                mask.
            mask (torch.Tensor): The applied sampling mask, as bytes.
            target (torch.Tensor): The target image, or ``torch.tensor(0)``
                when none was given.
            fname (str): File name.
            slice_num (int): The slice index.
            max_value (float): ``attrs["max"]``, or 0.0 without a target.
            crop_size (torch.Tensor): the size to crop the final image.
    """
    if target is None:
        target, max_value = torch.tensor(0), 0.0
    else:
        target, max_value = T.to_tensor(target), attrs["max"]

    kspace = T.to_tensor(kspace)
    seed = tuple(map(ord, fname)) if self.use_seed else None

    acq_start, acq_end = attrs["padding_left"], attrs["padding_right"]
    recon_size = attrs["recon_size"]
    crop_size = torch.tensor([recon_size[0], recon_size[1]])

    if self.mask_func:
        masked_kspace, mask = T.apply_mask(
            kspace, self.mask_func, seed, (acq_start, acq_end)
        )
    else:
        masked_kspace = kspace

        # Singleton on every axis except the second-to-last one.
        mask_shape = [1] * len(kspace.shape)
        mask_shape[-2] = kspace.shape[-2]
        mask = torch.from_numpy(
            mask.reshape(*mask_shape).astype(np.float32)
        )

        # Nothing counts as sampled outside the acquired range.
        mask[:, :, :acq_start] = 0
        mask[:, :, acq_end:] = 0

    return (
        masked_kspace,
        mask.byte(),
        target,
        fname,
        slice_num,
        max_value,
        crop_size,
    )
# Scale the target to unit peak magnitude, crop it, and combine along axis 0.
# NOTE(review): `crop_size` is read here but only assigned on the statement
# after next; it must already be defined by earlier code -- confirm.
# NOTE(review): if `target` is complex, the square below should also be taken
# on its magnitude (as done for `linear_recon`) -- confirm its dtype.
target = target / np.max(np.abs(target))
target = np.sqrt(np.sum(T.center_crop(target, crop_size) ** 2, 0))

crop_size = (320, 320)

# Randomly sub-sample the selected slice of k-space.
mask_func = create_mask_for_mask_type(mask_type_str="random", center_fractions=[0.08], accelerations=[4])
_kspace = T.to_tensor(kspace)[slice]
masked_kspace, mask = T.apply_mask(_kspace, mask_func)

# Rebuild complex values from the two trailing channels, then apply a
# centered 2-D inverse FFT over the last two axes.
linear_recon = masked_kspace[..., 0] + 1j * masked_kspace[..., 1]
linear_recon = np.fft.fftshift(np.fft.ifft2(np.fft.ifftshift(linear_recon, axes=(-2, -1)), axes=(-2, -1)), axes=(-2, -1))
linear_recon = linear_recon / np.max(np.abs(linear_recon))

# Root-sum-of-squares along axis 0. The data is complex here, so the
# magnitude must be squared; squaring the complex values directly would give
# a complex result.
linear_recon = np.sqrt(np.sum(np.abs(T.center_crop(linear_recon, (320, 320))) ** 2, 0))
def __call__(self, target_ksp, target_im, attrs, fname, slice):
    """
    Build one training example from fully-sampled k-space and its target image.

    The k-space is optionally coil-compressed, then sub-sampled with
    ``self.mask_func``. Depending on flags in ``self.args`` the method also
    estimates sensitivity maps, applies stored GRAPPA kernels to the input
    and/or target k-space, and rescales the data by the max or mean magnitude
    of the zero-filled input image.

    NOTE(review): the nesting shown here was reconstructed from a
    whitespace-collapsed source; the placement of the 'mask_offset' check,
    of the ``target_im`` recomputation and of ``scale.view`` should be
    confirmed against the original file.

    Args:
        target_ksp: Fully-sampled k-space as a NumPy array.
        target_im: Target image as a NumPy array.
        attrs: Acquisition attributes; 'padding_left' / 'padding_right' and
            optionally 'mask_offset' are read unless offsets are calculated
            directly from the data.
        fname: File name; its characters seed the mask when
            ``self.use_seed`` is set, and it names the GRAPPA kernel files.
        slice: Slice index used to select the GRAPPA kernel.

    Returns:
        OrderedDict with keys ``input``, ``target``, ``target_im``, ``mask``,
        ``grappa_kernel``, ``scale``, ``attrs_dict``, ``fname``, ``slice``,
        ``num_lf`` and ``sens_map``.
    """
    # Keep the NumPy k-space for offset and sensitivity computations.
    kspace_np = target_ksp
    target_im = transforms.to_tensor(target_im)
    target_ksp = transforms.to_tensor(target_ksp)

    if self.args.coil_compress_coils:
        target_ksp = transforms.coil_compress(target_ksp, self.args.coil_compress_coils)

    if self.args.calculate_offsets_directly:
        # Locate the first and last non-zero entries of the summed k-space.
        krow = kspace_np.sum(axis=(0,1)) # flatten to a single row
        width = len(krow)
        offset = (krow != 0).argmax()
        acq_start = offset
        acq_end = width - (krow[::-1] != 0).argmax() #exclusive
    else:
        offset = None # Mask will pick randomly
        if self.partition == 'val' and 'mask_offset' in attrs:
            offset = attrs['mask_offset']

        acq_start = attrs['padding_left']
        acq_end = attrs['padding_right']

    seed = None if not self.use_seed else tuple(map(ord, fname))
    input_ksp, mask, num_lf = transforms.apply_mask(
        target_ksp, self.mask_func, seed, offset,
        (acq_start, acq_end))

    # Placeholder, replaced below when sensitivity maps are requested.
    sens_map = torch.Tensor(0)
    if self.args.compute_sensitivities:
        # Center a window of num_lf columns on the last k-space axis.
        start_of_center_mask = (kspace_np.shape[-1] - num_lf + 1) // 2
        end_of_center_mask = start_of_center_mask + num_lf
        sens_map = est_sens_maps(kspace_np, start_of_center_mask, end_of_center_mask)
        sens_map = transforms.to_tensor(sens_map)

    if self.args.grappa_input:
        with h5py.File(self.args.grappa_input_path / self.partition / fname, 'r') as hf:
            kernel = transforms.to_tensor(hf['kernel'][slice])
        input_ksp = transforms.apply_grappa(input_ksp, kernel, target_ksp, mask)

    # Placeholder, replaced below when a kernel path is configured.
    grappa_kernel = torch.Tensor(0)
    if self.args.grappa_path is not None:
        with h5py.File(self.args.grappa_path / self.partition / fname, 'r') as hf:
            grappa_kernel = transforms.to_tensor(hf['kernel'][slice])

    if self.args.grappa_target:
        with h5py.File(self.args.grappa_target_path / self.partition / fname, 'r') as hf:
            kernel = transforms.to_tensor(hf['kernel'][slice])
        target_ksp = transforms.apply_grappa(target_ksp.clone(), kernel, target_ksp, mask, sample_accel=2)
        # The target image is recomputed from the modified target k-space.
        target_im = transforms.root_sum_of_squares(transforms.complex_abs(transforms.ifft2(target_ksp)))

    input_im = transforms.ifft2(input_ksp)
    if not self.args.scale_inputs:
        scale = torch.Tensor([1.])
    else:
        abs_input = transforms.complex_abs(input_im)
        if self.args.scale_type == 'max':
            scale = torch.max(abs_input)
        else:
            scale = torch.mean(abs_input)

        # In-place division: these tensors are modified directly.
        input_ksp /= scale
        target_ksp /= scale
        target_im /= scale

    scale = scale.view([1, 1, 1])
    attrs_dict = dict(**attrs)

    return OrderedDict(
        input = input_ksp,
        target = target_ksp,
        target_im = target_im,
        mask = mask,
        grappa_kernel = grappa_kernel,
        scale = scale,
        attrs_dict = attrs_dict,
        fname = fname,
        slice = slice,
        num_lf = num_lf,
        sens_map = sens_map,
    )
def __call__(self, kspace, mask, target, attrs, fname, slice_num):
    """
    Turn (optionally sub-sampled) k-space into a normalized zero-filled image.

    Args:
        kspace (numpy.array): Input k-space of shape (num_coils, rows,
            cols, 2) for multi-coil data or (rows, cols, 2) for single coil
            data.
        mask (numpy.array): Mask from the test dataset; replaced by the
            generated mask when a mask function is configured.
        target (numpy.array): Target image, or None.
        attrs (dict): Acquisition related information stored in the HDF5
            object; "recon_size" is read when no target is given.
        fname (str): File name; its characters seed the mask when
            ``self.use_seed`` is set.
        slice_num (int): Serial number of the slice.

    Returns:
        (tuple): tuple containing:
            image (torch.Tensor): Zero-filled input image, normalized and
                clamped to [-6, 6].
            target (torch.Tensor): Target image cropped, normalized with the
                input's statistics and clamped to [-6, 6], or
                ``torch.Tensor([0])`` when no target was given.
            mean (float): Mean value used for normalization.
            std (float): Standard deviation value used for normalization.
            fname (str): File name.
            slice_num (int): Serial number of the slice.
    """
    kspace = transforms.to_tensor(kspace)

    # Sub-sample only when a mask function is configured.
    masked_kspace = kspace
    if self.mask_func:
        seed = tuple(map(ord, fname)) if self.use_seed else None
        masked_kspace, mask = transforms.apply_mask(
            kspace, self.mask_func, seed
        )

    # Zero-filled solution.
    image = fastmri.ifft2c(masked_kspace)

    # Crop size follows the target when there is one, else the header.
    if target is None:
        crop_size = (attrs["recon_size"][0], attrs["recon_size"][1])
    else:
        crop_size = (target.shape[-2], target.shape[-1])

    # check for FLAIR 203
    image_extent = image.shape[-2]
    if image_extent < crop_size[1]:
        crop_size = (image_extent, image_extent)

    image = transforms.complex_center_crop(image, crop_size)
    image = fastmri.complex_abs(image)

    # Combine coils with Root-Sum-of-Squares for multicoil data.
    if self.which_challenge == "multicoil":
        image = fastmri.rss(image)

    image, mean, std = transforms.normalize_instance(image, eps=1e-11)
    image = image.clamp(-6, 6)

    if target is None:
        target = torch.Tensor([0])
    else:
        # The target reuses the input's mean and std.
        target = transforms.to_tensor(target)
        target = transforms.center_crop(target, crop_size)
        target = transforms.normalize(target, mean, std, eps=1e-11)
        target = target.clamp(-6, 6)

    return image, target, mean, std, fname, slice_num
def __call__(
    self,
    kspace: np.ndarray,
    mask: np.ndarray,
    target: np.ndarray,
    attrs: Dict,
    fname: str,
    slice_num: int,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, str, int, float]:
    """
    Turn (optionally sub-sampled) k-space into a normalized zero-filled image.

    In test mode the crop is a square whose side is ``image.shape[-2]`` and
    the target is ignored.

    Args:
        kspace: Input k-space of shape (num_coils, rows, cols) for
            multi-coil data or (rows, cols) for single coil data.
        mask: Mask from the test dataset; replaced by the generated mask
            when a mask function is configured.
        target: Target image, or None.
        attrs: Acquisition related information stored in the HDF5 object.
        fname: File name; its characters seed the mask when
            ``self.use_seed`` is set.
        slice_num: Serial number of the slice.

    Returns:
        tuple containing:
            image: Zero-filled input image, normalized and clamped to
                [-6, 6].
            target: Target image cropped, normalized with the input's
                statistics and clamped to [-6, 6]; ``torch.Tensor([0])`` in
                test mode or when no target was given.
            mean: Mean value used for normalization.
            std: Standard deviation value used for normalization.
            fname: File name.
            slice_num: Serial number of the slice.
            max_value: ``attrs["max"]`` if present, otherwise 0.0.
    """
    kspace = T.to_tensor(kspace)

    # check for max value
    max_value = attrs["max"] if "max" in attrs.keys() else 0.0

    # Sub-sample only when a mask function is configured.
    masked_kspace = kspace
    if self.mask_func:
        seed = tuple(map(ord, fname)) if self.use_seed else None
        masked_kspace, mask = T.apply_mask(kspace, self.mask_func, seed)

    # Zero-filled solution.
    image = fastmri.ifft2c(masked_kspace)

    image_extent = image.shape[-2]
    if self.test_mode:
        crop_size = (image_extent, image_extent)
    else:
        # Crop size follows the target when there is one, else the header.
        if target is None:
            crop_size = (attrs["recon_size"][0], attrs["recon_size"][1])
        else:
            crop_size = (target.shape[-2], target.shape[-1])

        # check for FLAIR 203
        if image_extent < crop_size[1]:
            crop_size = (image_extent, image_extent)

    image = T.complex_center_crop(image, crop_size)
    image = fastmri.complex_abs(image)

    # Combine coils with Root-Sum-of-Squares for multicoil data.
    if self.which_challenge == "multicoil":
        image = fastmri.rss(image)

    image, mean, std = T.normalize_instance(image, eps=1e-11)
    image = image.clamp(-6, 6)

    if self.test_mode or target is None:
        target = torch.Tensor([0])
    else:
        # The target reuses the input's mean and std.
        target = T.to_tensor(target)
        target = T.center_crop(target, crop_size)
        target = T.normalize(target, mean, std, eps=1e-11)
        target = target.clamp(-6, 6)

    return image, target, mean, std, fname, slice_num, max_value