Example #1
    def __call__(self, kspace, target, attrs, fname, slice):
        """
        Args:
            kspace (numpy.array): Input k-space of shape (num_coils, rows, cols, 2) for multi-coil
                data or (rows, cols, 2) for single coil data.
            target (numpy.array): Target image
            attrs (dict): Acquisition related information stored in the HDF5 object.
            fname (str): File name
            slice (int): Serial number of the slice.
        Returns:
            (tuple): tuple containing:
                image (torch.Tensor): Normalized zero-filled input image.
                target_train (torch.Tensor): Magnitude target for training.
                target_kspace_train (torch.Tensor): k-space of the normalized complex target.
                mean (float): Mean used for complex normalization.
                std (float): Standard deviation used for complex normalization.
                mask (torch.Tensor): Applied sampling mask (None if masking is disabled).
                mean_abs (float): Mean of the magnitude image.
                std_abs (float): Standard deviation of the magnitude image.
                target_inference (torch.Tensor): Unmodified target for evaluation.
                attrs['max'] (float): Maximum value over the volume.
                attrs['norm'] (np.float32): L2 norm of the entire volume.
        """

        target_inference = transforms.to_tensor(target)

        kspace = transforms.to_tensor(kspace)
        target = transforms.ifft2(kspace)

        # Apply mask
        seed = None if not self.use_seed else tuple(map(ord, fname))
        if self.use_mask:
            mask = transforms.get_mask(kspace, self.mask_func, seed)
            masked_kspace = mask * kspace
        else:
            mask = None
            masked_kspace = kspace

        image = transforms.ifft2(masked_kspace)
        image_crop = transforms.complex_center_crop(
            image, (self.resolution, self.resolution))
        _, mean, std = transforms.normalize_instance_complex(image_crop,
                                                             eps=1e-11)

        image_abs = transforms.complex_abs(image_crop)
        image_abs, mean_abs, std_abs = transforms.normalize_instance(image_abs,
                                                                     eps=1e-11)

        image = transforms.normalize(image, mean, std)

        target_image_complex_norm = transforms.normalize(target, mean, std)
        target_kspace_train = transforms.fft2(target_image_complex_norm)

        target = transforms.complex_center_crop(target, (320, 320))
        target = transforms.complex_abs(target)
        target_train = target

        if RENORM:
            target_train = transforms.normalize(target_train, mean_abs,
                                                std_abs)

        if CLAMP:
            image = image.clamp(-6, 6)
            target_train = target_train.clamp(-6, 6)

        return image, target_train, target_kspace_train, mean, std, mask, mean_abs, std_abs, target_inference, attrs[
            'max'], attrs['norm'].astype(np.float32)
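For reference, these snippets appear to be built on the fastMRI transforms module; a minimal sketch of the instance-normalization helpers matching that convention (an assumption; `normalize_instance_complex` is presumably the same statistics computed over the complex-valued tensor):

def normalize(data, mean, stddev, eps=0.0):
    # shift by the mean and scale by the standard deviation
    return (data - mean) / (stddev + eps)

def normalize_instance(data, eps=0.0):
    # normalize a tensor by its own mean/std and return the statistics
    mean = data.mean()
    std = data.std()
    return normalize(data, mean, std, eps), mean, std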
Example #2
 def __call__(self, k_space, mask, target, attrs, f_name, slice):
     """
     Args:
         k_space (numpy.array): Input k-space of shape (num_coils, rows, cols, 2) for multi-coil
             data or (rows, cols, 2) for single coil data.
         mask (numpy.array): Mask from the test dataset
         target (numpy.array): Target image
         attrs (dict): Acquisition related information stored in the HDF5 object.
         f_name (str): File name
         slice (int): Serial number of the slice.
     Returns:
         (tuple): tuple containing:
             k_space (torch.Tensor): k-space (resolution x resolution x 2)
             target (torch.Tensor): Target image converted to a torch Tensor.
             f_name (str): File name
             slice (int): Serial number of the slice.
     """
     k_space = transforms.to_tensor(k_space)
     full_image = transforms.ifft2(k_space)
     cropped_image = transforms.complex_center_crop(
         full_image, (self.resolution, self.resolution))
     k_space = transforms.fft2(cropped_image)
     # Normalize input
     cropped_image, mean, std = transforms.normalize_instance(cropped_image,
                                                              eps=1e-11)
     cropped_image = cropped_image.clamp(-6, 6)
     # Normalize target
     target = transforms.to_tensor(target)
     target = transforms.center_crop(target,
                                     (self.resolution, self.resolution))
     target = transforms.normalize(target, mean, std, eps=1e-11)
     target = target.clamp(-6, 6)
     return k_space, target, f_name, slice
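A sketch of how a transform like this is typically consumed (DataTransform and SliceData are assumed names here, following the fastMRI training scripts; not confirmed by this snippet):

transform = DataTransform(resolution=320)  # object whose __call__ is shown above
dataset = SliceData(root='singlecoil_val', transform=transform,
                    challenge='singlecoil')
k_space, target, f_name, slice_idx = dataset[0]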
Example #3
def nufft(input, coord, oversamp=1.25, width=4.0, n=128, device='cuda'):
    ndim = coord.shape[-1]
    beta = numpy.pi * (((width / oversamp) * (oversamp - 0.5)) ** 2 - 0.8) ** 0.5
    os_shape = _get_oversamp_shape(input.shape, ndim, oversamp)

    output = input.clone()

    # Apodize
    output = _apodize(output, ndim, oversamp, width, beta, device)

    # Zero-pad
    output = output / util.prod(input.shape[-ndim:]) ** 0.5
    output = util.resize(output, os_shape, device=device)

    # FFT
    output = output.permute(0, 1, 3, 4, 2)
    output = transforms.fft2(output)
    output = output.permute(0, 1, 4, 2, 3)

    # Interpolate
    coord = _scale_coord(coord, input.shape, oversamp, device)
    kernel = _get_kaiser_bessel_kernel(n, width, beta, coord.dtype, device)
    output = interp.interpolate(output, width, kernel, coord, device)

    return output
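The `beta` above is the Kaiser-Bessel shape parameter from Beatty et al. (2005), with kernel width W (`width`) and oversampling ratio sigma (`oversamp`):

    beta = pi * sqrt( ((W / sigma) * (sigma - 0.5))**2 - 0.8 )

The division by `util.prod(input.shape[-ndim:]) ** 0.5` appears intended to preserve the orthonormal FFT scaling after zero-padding onto the oversampled grid.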
Example #4
    def __call__(self, kspace, target, attrs, fname, slice):
        kspace = transforms.to_tensor(kspace)
        image = transforms.ifft2_regular(kspace)
        image = transforms.complex_center_crop(
            image, (self.resolution, self.resolution))
        # image = transforms.complex_abs(image)
        image, mean, std = transforms.normalize_instance(image, eps=1e-11)
        image = image.clamp(-6, 6)
        kspace = transforms.fft2(image)  # NOTE: this recomputed k-space is not returned below

        return image, mean, std, fname, slice
Example #5
def test_fft2(shape):
    shape = shape + [2]
    input = create_input(shape)
    out_torch = transforms.fft2(input).numpy()
    out_torch = out_torch[..., 0] + 1j * out_torch[..., 1]

    input_numpy = utils.tensor_to_complex_np(input)
    input_numpy = np.fft.ifftshift(input_numpy, (-2, -1))
    out_numpy = np.fft.fft2(input_numpy, norm='ortho')
    out_numpy = np.fft.fftshift(out_numpy, (-2, -1))
    assert np.allclose(out_torch, out_numpy)
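The test pins down the convention used throughout: `transforms.fft2` is a centered, orthonormally normalized 2D FFT over the last two spatial axes, with real/imaginary parts stacked in a trailing dimension of size 2. A minimal re-implementation against the modern `torch.fft` module (an assumption; these snippets predate it) would be:

import torch

def fft2_centered(data: torch.Tensor) -> torch.Tensor:
    # data: (..., H, W, 2) with real/imag stacked on the last dim
    assert data.size(-1) == 2
    x = torch.view_as_complex(data.contiguous())
    x = torch.fft.ifftshift(x, dim=(-2, -1))
    x = torch.fft.fft2(x, dim=(-2, -1), norm='ortho')
    x = torch.fft.fftshift(x, dim=(-2, -1))
    return torch.view_as_real(x)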
Example #6
    def __call__(self, kspace, target, attrs, fname, slice):
        kspace_rect = transforms.to_tensor(kspace)  ##rectangular kspace

        image_rect = transforms.ifft2(kspace_rect)  ##rectangular FS image
        image_square = transforms.complex_center_crop(
            image_rect,
            (self.resolution, self.resolution))  ##cropped to FS square image
        kspace_square = self.c3object.apply(
            transforms.fft2(image_square))  #* 10000  ##kspace of square image

        if self.augmentation:
            kspace_square = self.augmentation.apply(kspace_square)

        image_square = ifft_c3(kspace_square)

        # Apply mask
        seed = None if not self.use_seed else tuple(map(ord, fname))
        masked_kspace_square, mask = transforms.apply_mask(
            kspace_square, self.mask_func, seed)  ##ZF square kspace

        # Inverse Fourier Transform to get zero filled solution
        # image = transforms.ifft2(masked_kspace)
        image_square_us = ifft_c3(
            masked_kspace_square)  ## US square complex image

        # Crop input image
        # image = transforms.complex_center_crop(image, (self.resolution, self.resolution))
        # Absolute value
        # image = transforms.complex_abs(image)
        image_square_abs = transforms.complex_abs(
            image_square_us)  ## US square real image

        # Apply Root-Sum-of-Squares if multicoil data
        # if self.which_challenge == 'multicoil':
        #     image = transforms.root_sum_of_squares(image)
        # Normalize input
        # image, mean, std = transforms.normalize_instance(image, eps=1e-11)
        _, mean, std = transforms.normalize_instance(image_square_abs,
                                                     eps=1e-11)
        # image = image.clamp(-6, 6)

        # target = transforms.to_tensor(target)
        target = image_square.permute(2, 0, 1)
        # Normalize target
        # target = transforms.normalize(target, mean, std, eps=1e-11)
        # target = target.clamp(-6, 6)
        # return image, target, mean, std, attrs['norm'].astype(np.float32)

        # return masked_kspace_square.permute((2,0,1)), image, image_square.permute(2,0,1), mean, std, attrs['norm'].astype(np.float32)

        # ksp, zf, target, me, st, nor
        return masked_kspace_square.permute((2,0,1)), image_square_us.permute((2,0,1)), \
            target,  \
            mean, std, attrs['norm'].astype(np.float32)
Example #7
def data_for_training(rawdata, sensitivity, mask_func, norm=True):
    ''' normalize each slice using complex absolute max value'''
    
    coils, Ny, Nx, ps = rawdata.shape
   
    # shift data: modulating k-space by (-1)^(x+y) shifts the reconstructed
    # image by (Ny/2, Nx/2), i.e. it acts like an fftshift in image space
    x, y = np.meshgrid(np.arange(1, Nx + 1), np.arange(1, Ny + 1))
    adjust = (-1) ** (x + y)
    shift_kspace = T.ifftshift(rawdata, dim=(-3, -2)) * torch.from_numpy(adjust).view(1, Ny, Nx, 1).float()
 
    # apply masks
    
    shape = np.array(shift_kspace.shape)
    shape[:-3] = 1
    mask = mask_func(shape)
    mask = T.ifftshift(mask)  # shift mask

    # undersample
    masked_kspace = torch.where(mask == 0, torch.Tensor([0]), shift_kspace)
    masks = mask.repeat(coils, Ny, 1, ps)

    img_gt, img_und = T.ifft2(shift_kspace), T.ifft2(masked_kspace)
    
    if norm:
        # normalize k-space by the max complex magnitude of the zero-filled
        # recon; at inference time there is no ground-truth image, so the
        # undersampled reconstruction is used
        norm = T.complex_abs(img_und).max()
        if norm < 1e-6:
            norm = 1e-6
    else:
        norm = 1
    
    # normalize data to learn more effectively    
    img_gt, img_und = img_gt/norm, img_und/norm

    rawdata_und = masked_kspace/norm  # faster

    sense_gt = cobmine_all_coils(img_gt, sensitivity)
    
    sense_und = cobmine_all_coils(img_und, sensitivity) 

    sense_und_kspace = T.fft2(sense_und) 
        
    return sense_und, sense_gt, sense_und_kspace, rawdata_und, masks, sensitivity
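A self-contained check of the modulation identity used in the "shift data" step above (pure NumPy, even N; illustrative only):

import numpy as np

N = 8
img = np.random.randn(N, N) + 1j * np.random.randn(N, N)
ksp = np.fft.fft2(img)

x, y = np.meshgrid(np.arange(1, N + 1), np.arange(1, N + 1))
mod = (-1.0) ** (x + y)

# multiplying k-space by (-1)^(x+y) shifts the image by (N/2, N/2),
# which for even N is exactly np.fft.fftshift
assert np.allclose(np.fft.ifft2(ksp * mod), np.fft.fftshift(img))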
Example #8
def sample_vector(k_space, vector):
    # TODO: make the loop parallel using torch parallel loops
    # this function will use grid sample ato nufft to sample the k_space
    # for each sample in k_space batch there will be a sampling vector in vector parameter
    # the vector that is given is with values between -resolution to resolution
    images = torch.zeros_like(k_space)
    k_space = k_space.permute(0, 3, 1, 2)
    for i in range(k_space.shape[0]):
        space = k_space[i].unsqueeze(0)

        # we need to swap x,y in the sampling vector
        # because grid_sample expects (y, x) ordering
        sampling_vector = torch.zeros_like(vector[i])
        sampling_vector[..., 0] = vector[i][..., 1]
        sampling_vector[..., 1] = vector[i][..., 0]
        # normalize the vector to be in the right range for sampling
        # the values are between -1 and 1
        sampling_vector = sampling_vector.unsqueeze(0).unsqueeze(0)
        normalized_sampling_vector = (
            sampling_vector + (k_space.shape[2] / 2)) / (k_space.shape[2] - 1)
        normalized_sampling_vector = 2 * normalized_sampling_vector - 1
        sampled_k_space = torch.nn.functional.grid_sample(
            space,
            normalized_sampling_vector,
            mode='bilinear',
            padding_mode='zeros',
            align_corners=False).unsqueeze(2)
        # for the nufft the indices should be in the pixel-index domain,
        # not normalized to [-1, 1], so we use the original sampling vector
        image = nufft.nufft_adjoint(sampled_k_space,
                                    vector[i],
                                    space.shape,
                                    device=k_space.device).squeeze(0)
        images[i] = fft2(image)
    return images
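The normalization above maps a pixel coordinate v in [-N/2, N/2 - 1] to 2*(v + N/2)/(N - 1) - 1 in [-1, 1]: for N = 320, v = -160 lands on -1 and v = 159 lands on +1. Mapping the extreme pixel centers onto exactly ±1 matches `align_corners=True` semantics; the call passes `align_corners=False`, under which ±1 denote the outer pixel edges, so the grid is offset by roughly half a pixel relative to pixel centers.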
Example #9
def tosquare(ksp, shp):
    rec = T.ifft2(ksp)
    sz = rec.shape  # NOTE: unused as written

    # c3m is assumed to be defined at module scope in the original source
    return c3m * T.fft2(T.complex_center_crop(rec, shp)) * 100000
Example #10
def X_I_operator(img):
    return transforms.fft2(img)
Example #11
def reducedimension(kspace, resolution):
    image = croppedimage(kspace, resolution)
    kspace = transforms.fft2(image)
    return kspace
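`croppedimage` is not shown; a plausible sketch consistent with the surrounding examples (an assumed helper, not confirmed by the source) is:

def croppedimage(kspace, resolution):
    # reconstruct the complex image, then center-crop it to a square
    image = transforms.ifft2(kspace)
    return transforms.complex_center_crop(image, (resolution, resolution))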
Example #12
    def __call__(self, kspace, target, attrs, fname, slice):
        """
        Args:
            kspace (numpy.array): Input k-space of shape (num_coils, rows, cols, 2) for multi-coil
                data or (rows, cols, 2) for single coil data.
            target (numpy.array): Target image
            attrs (dict): Acquisition related information stored in the HDF5 object.
            fname (str): File name
            slice (int): Serial number of the slice.
        Returns:
            (tuple): tuple containing:
                masked_kspace_square (torch.Tensor): Undersampled square k-space, channels first.
                image_square_us (torch.Tensor): Zero-filled complex image, channels first.
                target (torch.Tensor): Fully sampled complex square image, channels first.
                mean (float): Mean value used for normalization.
                std (float): Standard deviation value used for normalization.
                norm (np.float32): L2 norm of the entire volume.
        """
        kspace_rect = transforms.to_tensor(kspace)  ##rectangular kspace

        image_rect = transforms.ifft2(kspace_rect)  ##rectangular FS image
        image_square = transforms.complex_center_crop(
            image_rect,
            (self.resolution, self.resolution))  ##cropped to FS square image
        kspace_square = transforms.fft2(image_square)  ##kspace of square image

        if self.augmentation:
            kspace_square = self.augmentation.apply(kspace_square)
            image_square = transforms.ifft2(kspace_square)

        # Apply mask
        seed = None if not self.use_seed else tuple(map(ord, fname))
        masked_kspace_square, mask = transforms.apply_mask(
            kspace_square, self.mask_func, seed)  ##ZF square kspace

        # Inverse Fourier Transform to get zero filled solution
        # image = transforms.ifft2(masked_kspace)
        image_square_us = transforms.ifft2(
            masked_kspace_square)  ## US square complex image

        # Crop input image
        # image = transforms.complex_center_crop(image, (self.resolution, self.resolution))
        # Absolute value
        # image = transforms.complex_abs(image)
        image_square_abs = transforms.complex_abs(
            image_square_us)  ## US square real image

        # Apply Root-Sum-of-Squares if multicoil data
        # if self.which_challenge == 'multicoil':
        #     image = transforms.root_sum_of_squares(image)
        # Normalize input
        # image, mean, std = transforms.normalize_instance(image, eps=1e-11)
        _, mean, std = transforms.normalize_instance(image_square_abs,
                                                     eps=1e-11)
        # image = image.clamp(-6, 6)

        # target = transforms.to_tensor(target)
        target = image_square.permute(2, 0, 1)
        # Normalize target
        # target = transforms.normalize(target, mean, std, eps=1e-11)
        # target = target.clamp(-6, 6)
        # return image, target, mean, std, attrs['norm'].astype(np.float32)

        # return masked_kspace_square.permute((2,0,1)), image, image_square.permute(2,0,1), mean, std, attrs['norm'].astype(np.float32)

        # ksp, zf, target, me, st, nor
        return masked_kspace_square.permute((2,0,1)), image_square_us.permute((2,0,1)), \
            target,  \
            mean, std, attrs['norm'].astype(np.float32)
Example #13
    def __call__(self, kspace, target, attrs, fname, slice):
        kspace_rect = transforms.to_tensor(kspace)  ##rectangular kspace

        image_rect = transforms.ifft2(kspace_rect)  ##rectangular FS image
        image_square = transforms.complex_center_crop(
            image_rect,
            (self.resolution, self.resolution))  ##cropped to FS square image

        kspace_square = self.c3object.apply(
            transforms.fft2(image_square)) * 10000  ##kspace of square image
        image_square2 = ifft_c3(kspace_square)  ##for training domain_transform

        if self.augmentation:
            kspace_square = self.augmentation.apply(kspace_square)

        # image_square = ifft_c3(kspace_square)

        # Apply mask
        seed = None if not self.use_seed else tuple(map(ord, fname))
        masked_kspace_square, mask = transforms.apply_mask(
            kspace_square, self.mask_func, seed)  ##ZF square kspace

        # Inverse Fourier Transform to get zero filled solution
        # image = transforms.ifft2(masked_kspace)
        us_image_square = ifft_c3(
            masked_kspace_square)  ## US square complex image

        # Crop input image
        # image = transforms.complex_center_crop(image, (self.resolution, self.resolution))
        # Absolute value
        # image = transforms.complex_abs(image)
        us_image_square_abs = transforms.complex_abs(
            us_image_square)  ## US square real image
        us_image_square_rss = transforms.root_sum_of_squares(
            us_image_square_abs, dim=0)

        stacked_kspace_square = []
        for i in range(kspace_square.shape[0]):
            stacked_kspace_square.append(kspace_square[i, :, :, 0])
            stacked_kspace_square.append(kspace_square[i, :, :, 1])

        stacked_kspace_square = torch.stack(stacked_kspace_square)

        stacked_masked_kspace_square = []
        # masked_kspace_square = transforms.to_tensor(masked_kspace_square)
        # for i in range(len(masked_kspace_square[:,0,0,0])):
        # stacked_masked_kspace_square.stack(masked_kspace_square[i,:,:,0],masked_kspace_square[i,:,:,1])

        for i in range(masked_kspace_square.shape[0]):
            stacked_masked_kspace_square.append(masked_kspace_square[i, :, :, 0])
            stacked_masked_kspace_square.append(masked_kspace_square[i, :, :, 1])

        stacked_masked_kspace_square = torch.stack(stacked_masked_kspace_square)

        stacked_image_square = []
        for i in range(image_square2.shape[0]):
            stacked_image_square.append(image_square2[i, :, :, 0])
            stacked_image_square.append(image_square2[i, :, :, 1])

        stacked_image_square = torch.stack(stacked_image_square)




        return stacked_kspace_square, stacked_masked_kspace_square, stacked_image_square, \
            us_image_square_rss, \
            target * 10000
        # mean, std, attrs['norm'].astype(np.float32)
Example #14
 def sens_expand(x):
     return T.fft2(T.complex_mul(x, sens_maps))
Example #15
def data_consistency(k, k0, mask, noise_lvl=None):
    """
    k    - input in k-space (predicted filling)
    k0   - initially sampled elements in k-space
    mask - corresponding nonzero location
    """
    v = noise_lvl
    if v is not None:  # noisy case
        out = (1 - mask) * k + mask * (k + v * k0) / (1 + v)
    else:  # noiseless case
        out = (1 - mask) * k + mask * k0
    return out
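A tiny worked example of the noiseless branch (illustrative values): wherever mask == 1 the originally acquired k0 is restored, elsewhere the prediction k passes through.

import torch

k = torch.tensor([[1.0, 2.0], [3.0, 4.0]])     # predicted k-space
k0 = torch.tensor([[10.0, 0.0], [30.0, 0.0]])  # acquired k-space
mask = torch.tensor([[1.0, 0.0], [1.0, 0.0]])  # 1 = sampled location

out = (1 - mask) * k + mask * k0
# out -> [[10., 2.], [30., 4.]]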

ifft_func = lambda ksp: ifft2(ksp.permute((0, 2, 3, 1))).permute((0, 3, 1, 2))
fft_func = lambda im: fft2(im.permute((0, 2, 3, 1))).permute((0, 3, 1, 2))


class DataConsistencyInKspace(nn.Module):
    """ Create data consistency operator
    Warning: note that FFT2 (by the default of torch.fft) is applied to the last 2 axes of the input.
    This method detects if the input tensor is 4-dim (2D data) or 5-dim (3D data)
    and applies FFT2 to the (nx, ny) axis.
    """

    def __init__(self, noise_lvl=None, norm='ortho'):
        super(DataConsistencyInKspace, self).__init__()
        self.normalized = norm == 'ortho'
        self.noise_lvl = noise_lvl

    def forward(self, *input, **kwargs):
Example #16
def project_to_consistent_subspace(output, input, mask):
    reconstructed_kspace = transforms.fft2(output)
    original_kspace = transforms.fft2(input)
    new_kspace = (1 - mask) * reconstructed_kspace + mask * original_kspace
    return transforms.ifft2(new_kspace), original_kspace, new_kspace
Example #17
def get_attack_loss_new(model, ori_target, loss_f=torch.nn.MSELoss(reduction='none'),
    xs=np.random.randint(low=100, high=320-100, size=(16,)),
    ys=np.random.randint(low=100, high=320-100, size=(16,)),
    shape=(320, 320), n_pixel_range=(10, 11), train=False, optimizer=None):
    # NOTE: the random xs/ys defaults are evaluated once at definition time,
    # so calls that omit them all reuse the same coordinates

    input_o = ori_target.unsqueeze(1).to(args.device)
    input_o = input_o.clone()
    
    #input_o = transforms.complex_abs(ori_input.clone())
    #input_o, mean, std = transforms.normalize_instance(ori_target.unsqueeze(1).clone())
    #input_o = torch.clamp(input_o, -6, 6)

    #perturb_noise = perturb_noise_init(x=x, y=y, shape=shape, n_pixel_range=n_pixel_range)
    p_max = input_o.max().cpu()
    #p_min = (p_max - input.min()) / 2
    #p_min = (p_max - input_o.min())
    p_min = input_o.min().cpu()
    perturb_noise = [perturb_noise_init(x=x, y=y, shape=shape, n_pixel_range=n_pixel_range, pixel_value_range=(p_min, p_max)) for x, y in zip(xs, ys)]
    perturb_noise = np.stack(perturb_noise)
            
    # perturb the target to get the perturbed image
    #perturb_noise = np.expand_dims(perturb_noise, axis=0)
    #perturb_noise = np.stack((perturb_noise,)*ori_target.shape(0), -1)

    seed = np.random.randint(999999999)
    
    
    perturb_noise = transforms.to_tensor(perturb_noise).unsqueeze(1).to(args.device)
    
    if not args.fnaf_eval_control:
        input_o += perturb_noise
    target = input_o.clone()
    
    #print(input_o.shape)
    input_o = np.complex64(input_o.cpu().numpy())
    input_o = transforms.to_tensor(input_o)
    input_o = transforms.fft2(input_o)
    input_o, mask = transforms.apply_mask(input_o, mask_f, seed)
    input_o = transforms.ifft2(input_o)
    
    image = transforms.complex_abs(input_o).to(args.device)
    image, mean, std = transforms.normalize_instance(image, eps=1e-11)
    image = image.clamp(-6, 6)
    
    target = transforms.normalize(target, mean, std, eps=1e-11)
    target = target.clamp(-6, 6)

    #information_loss = loss_f(og_image.squeeze(1), image.squeeze(1)).mean(-1).mean(-1).cpu().numpy()
    #information_loss = np.array([0]*len(xs))

    # apply the perturbed image to the model to get the loss
    if train:
        output = model(image).squeeze(1)
    else:
        with torch.no_grad():
            output = model(image).squeeze(1)
            
    #perturb_noise_tensor = transforms.to_tensor(perturb_noise).to(args.device, dtype=torch.double)
    perturb_noise = perturb_noise.squeeze(1)
    perturb_noise_tensor = perturb_noise
    
    perturb_noise = perturb_noise.cpu().numpy()
        
    mask = adjusted_mask((perturb_noise > 0).astype(np.double))
    #mask = (perturb_noise > 0).astype(np.double)
    

        
    target = target.squeeze(1)
    mask_0 = transforms.to_tensor(mask).to(args.device)

    loss = loss_f((output*mask_0), (target*mask_0))

    if train:
        b_loss = loss.sum() / mask_0.sum() * 1 + loss_f(output, target).mean()
        b_loss.backward()
        optimizer.step()
        loss = loss.detach()

        loss = loss.mean(-1).mean(-1).cpu().numpy()
    # NOTE: when train is False, the unreduced loss tensor is returned as-is

    # information_loss_list.append(information_loss)
    # xs_list.append(xs)
    # ys_list.append(ys)
    
    
    return loss
Example #18
def to_k_space(image):
    #image = image.numpy()
    image = np.complex64(image)
    image = transforms.to_tensor(image)
    return transforms.fft2(image)
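A quick usage sketch for `to_k_space` (illustrative input; `transforms` as in the other examples):

image = np.random.randn(320, 320).astype(np.float32)
ksp = to_k_space(image)  # torch.Tensor of shape (320, 320, 2)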
Example #19
 def forward(self, input_image):
     k_space = fft2(input_image)
     reconstructed_k_space = self.K_space_reconstruction(k_space)
     image = ifft2(reconstructed_k_space)
     reconstructed_image = self.Unet_model(image)
     return reconstructed_image
Example #20
    def __call__(self, kspace, target, attrs, fname, slice):
        """
        Args:
            kspace (numpy.array): Input k-space of shape (num_coils, rows, cols, 2) for multi-coil
                data or (rows, cols, 2) for single coil data.
            target (numpy.array): Target image
            attrs (dict): Acquisition related information stored in the HDF5 object.
            fname (str): File name
            slice (int): Serial number of the slice.
        Returns:
            (tuple): tuple containing:
                image (torch.Tensor): Scaled zero-filled input image, channels first (2, H, W).
                target (torch.Tensor): Scaled fully sampled target image, channels first (2, H, W).
        """
        kspace = transforms.to_tensor(kspace)
        gt = transforms.ifft2(kspace)
        gt = transforms.complex_center_crop(gt, (self.resolution, self.resolution))
        kspace = transforms.fft2(gt)

        # Apply mask
        seed = None if not self.use_seed else tuple(map(ord, fname))
        masked_kspace, mask = transforms.apply_mask(kspace, self.mask_func, seed)
        # Inverse Fourier Transform to get zero filled solution
        image = transforms.ifft2(masked_kspace)
        masked_kspace = transforms.fft2_nshift(image)
        # Crop input image
        image = transforms.complex_center_crop(image, (self.resolution, self.resolution))
        # Absolute value
        image_mod = transforms.complex_abs(image).max()
        image_r = image[:, :, 0]*6.0/image_mod
        image_i = image[:, :, 1]*6.0/image_mod
        # image_r = image[:, :, 0]
        # image_i = image[:, :, 1]
        # Apply Root-Sum-of-Squares if multicoil data
        if self.which_challenge == 'multicoil':
            image = transforms.root_sum_of_squares(image)  # NOTE: overwritten below
        # Normalize input

        image = np.stack((image_r, image_i), axis=-1)
        image = image.transpose((2, 0, 1))
        image = transforms.to_tensor(image)

        target = transforms.ifft2(kspace)
        target = transforms.complex_center_crop(target, (self.resolution, self.resolution))
        # Normalize target
        target_r = target[:, :, 0]*6.0/image_mod
        target_i = target[:, :, 1]*6.0/image_mod
        # target_r = target[:, :, 0]
        # target_i = target[:, :, 1]

        target = np.stack((target_r, target_i), axis=-1)
        target = target.transpose((2, 0, 1))
        target = transforms.to_tensor(target)

        image_mod = np.stack((image_mod, image_mod), axis=0)
        image_mod = transforms.to_tensor(image_mod)

        norm = attrs['norm'].astype(np.float32)
        norm = np.stack((norm, norm), axis=-1)
        norm = transforms.to_tensor(norm)

        mask = mask.expand(kspace.shape)
        mask = mask.transpose(0, 2).transpose(1, 2)
        mask = transforms.ifftshift(mask)

        masked_kspace = masked_kspace.transpose(0, 2).transpose(1, 2)

        # NOTE: mask, masked_kspace, image_mod, and norm are prepared above
        # but not returned as written
        return image, target