def __call__(self, kspace, target, attrs, fname, slice):
    """
    Build one training example from a k-space slice.

    Args:
        kspace (numpy.array): Input k-space of shape (num_coils, rows, cols, 2) for
            multi-coil data or (rows, cols, 2) for single coil data.
        target (numpy.array): Target image; only returned (as a tensor) for inference.
        attrs (dict): Acquisition related information stored in the HDF5 object;
            'max' and 'norm' are read.
        fname (str): File name; its characters seed the mask when ``self.use_seed``.
        slice (int): Serial number of the slice (unused).

    Returns:
        (tuple): tuple containing:
            image (torch.Tensor): Zero-filled complex image (full field of view),
                normalized with the statistics of its center crop; clamped to
                [-6, 6] when CLAMP is set.
            target_train (torch.Tensor): Magnitude of the fully-sampled image,
                center cropped to 320 x 320; renormalized with (mean_abs, std_abs)
                when RENORM is set and clamped to [-6, 6] when CLAMP is set.
            target_kspace_train (torch.Tensor): fft2 of the fully-sampled image
                normalized with (mean, std).
            mean, std: Statistics from ``normalize_instance_complex`` on the
                center-cropped zero-filled image.
            mask (torch.Tensor): Sampling mask; all ones when ``self.use_mask``
                is False.
            mean_abs, std_abs: Statistics of the center-cropped zero-filled
                magnitude image.
            target_inference (torch.Tensor): ``target`` converted to a tensor.
            max: ``attrs['max']``.
            norm (numpy.float32): ``attrs['norm']``.
    """
    import torch  # only needed for the all-ones fallback mask

    target_inference = transforms.to_tensor(target)
    kspace = transforms.to_tensor(kspace)
    # Fully-sampled complex image; the training targets are derived from it.
    target = transforms.ifft2(kspace)

    # Apply mask
    seed = None if not self.use_seed else tuple(map(ord, fname))
    if self.use_mask:
        mask = transforms.get_mask(kspace, self.mask_func, seed)
        masked_kspace = mask * kspace
    else:
        # `mask` is part of the return value, so it must exist on this path too
        # (it used to raise UnboundLocalError). All ones == every sample acquired.
        # NOTE(review): get_mask may return a smaller, broadcastable mask; confirm
        # consumers accept a full-size mask here.
        mask = torch.ones_like(kspace)
        masked_kspace = kspace

    image = transforms.ifft2(masked_kspace)
    image_crop = transforms.complex_center_crop(
        image, (self.resolution, self.resolution))
    # Normalization statistics are taken from the center crop only.
    _, mean, std = transforms.normalize_instance_complex(image_crop, eps=1e-11)
    image_abs = transforms.complex_abs(image_crop)
    _, mean_abs, std_abs = transforms.normalize_instance(image_abs, eps=1e-11)
    image = transforms.normalize(image, mean, std)

    target_image_complex_norm = transforms.normalize(target, mean, std)
    target_kspace_train = transforms.fft2(target_image_complex_norm)

    target = transforms.complex_center_crop(target, (320, 320))
    target = transforms.complex_abs(target)
    target_train = target
    if RENORM:
        target_train = transforms.normalize(target_train, mean_abs, std_abs)
    if CLAMP:
        image = image.clamp(-6, 6)
        target_train = target_train.clamp(-6, 6)

    return (image, target_train, target_kspace_train, mean, std, mask,
            mean_abs, std_abs, target_inference, attrs['max'],
            attrs['norm'].astype(np.float32))
def generate(generator, data, device):
    """
    Run the generator on a batch and project its output onto the measured data.

    Args:
        generator: Network applied to the channels-first (permuted) input.
        data: Batch tuple; only elements 0 (input), 2 (mean), 3 (std) and
            4 (mask) are used.
        device: Device the computation runs on.

    Returns:
        (output_consistent, output_network, target_kspace, output_kspace):
            the projected and the raw network outputs, each unnormalized with
            (mean, std), center cropped to 320 x 320 and converted to magnitude,
            followed by the two k-space tensors returned by
            project_to_consistent_subspace.
    """
    net_input, _, mean, std, mask, _, _, _ = data
    net_input = net_input.to(device)
    mask = mask.to(device)

    # The network works channels-first; the complex helpers expect (..., 2) last.
    raw = generator(net_input.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)

    consistent, target_kspace, output_kspace = project_to_consistent_subspace(
        raw, net_input, mask)

    # Add three trailing singleton dims so the statistics broadcast over (H, W, 2).
    # Assumes mean/std hold one value per sample (shape (B,)); confirm against the dataset.
    mean = mean.to(device).unsqueeze(1).unsqueeze(2).unsqueeze(3)
    std = std.to(device).unsqueeze(1).unsqueeze(2).unsqueeze(3)

    def _to_magnitude(complex_image):
        # Undo normalization, crop to 320 x 320 and take the complex magnitude.
        complex_image = transforms.unnormalize(complex_image, mean, std)
        complex_image = transforms.complex_center_crop(complex_image, (320, 320))
        return transforms.complex_abs(complex_image)

    return _to_magnitude(consistent), _to_magnitude(raw), target_kspace, output_kspace
def train_step(model, data, device):
    """
    Compute the training loss for one batch of k-space data.

    The model output is made data-consistent (measured samples are kept where
    ``mask`` is 1). The loss is the sum of:
      * an L1 loss between the normalized consistent k-space and the target, and
      * an L1 loss between 320 x 320 center-cropped magnitude images, both
        normalized with (mean_image, std_image); the target image is clamped
        to [-6, 6].

    Args:
        model: Network applied to the channels-first (permuted) input.
        data: Batch tuple (input, target, mean, std, mean_image, std_image, mask).
        device: Device the computation runs on.

    Returns:
        torch.Tensor: The scalar total loss.
    """
    kspace_in, kspace_target, mean, std, mean_image, std_image, mask = data
    kspace_in = kspace_in.to(device)
    mask = mask.to(device)
    kspace_target = kspace_target.to(device)

    predicted = model(kspace_in.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)
    # Data consistency: keep measured samples, use the prediction elsewhere.
    predicted = kspace_in * mask + (1 - mask) * predicted

    # k-space loss on the normalized consistent output and target.
    loss_kspace = F.l1_loss(predicted, kspace_target)

    mean = mean.to(device)
    std = std.to(device)
    # NOTE(review): mean/std are used without reshaping; confirm their shape
    # broadcasts correctly against (B, H, W, 2) for batch sizes > 1.
    kspace_target = transforms.unnormalize(kspace_target, mean, std)
    predicted = transforms.unnormalize(predicted, mean, std)

    def _magnitude(kspace):
        # k-space -> image -> 320 x 320 center crop -> magnitude.
        complex_image = transforms.ifft2(kspace)
        complex_image = transforms.complex_center_crop(complex_image, (320, 320))
        return transforms.complex_abs(complex_image)

    predicted_image = _magnitude(predicted)
    target_image = _magnitude(kspace_target)

    mean_image = mean_image.unsqueeze(1).unsqueeze(2).to(device)
    std_image = std_image.unsqueeze(1).unsqueeze(2).to(device)
    predicted_image = transforms.normalize(predicted_image, mean_image, std_image)
    target_image = transforms.normalize(target_image, mean_image, std_image).clamp(-6, 6)

    # Image loss on the *normalized* magnitude images.
    loss_image = F.l1_loss(predicted_image, target_image)
    return loss_kspace + loss_image
def __call__(self, k_space, mask, target, attrs, f_name, slice):
    """
    Restrict a slice to a centered resolution x resolution field of view and
    return its k-space together with a normalized target.

    Args:
        k_space (numpy.array): Input k-space of shape (num_coils, rows, cols, 2) for
            multi-coil data or (rows, cols, 2) for single coil data.
        mask (numpy.array): Mask from the test dataset (unused).
        target (numpy.array): Target image.
        attrs (dict): Acquisition related information stored in the HDF5 object (unused).
        f_name (str): File name.
        slice (int): Serial number of the slice.

    Returns:
        (tuple): tuple containing:
            k_space (torch.Tensor): k-space of the center-cropped image
                (resolution x resolution x 2 for single-coil input).
            target (torch.Tensor): Target center cropped to resolution x resolution,
                normalized with the mean/std of the cropped complex image and
                clamped to [-6, 6].
            f_name (str): File name.
            slice (int): Serial number of the slice.
    """
    k_space = transforms.to_tensor(k_space)
    full_image = transforms.ifft2(k_space)
    cropped_image = transforms.complex_center_crop(
        full_image, (self.resolution, self.resolution))
    k_space = transforms.fft2(cropped_image)

    # Only the statistics of the cropped image are needed. The normalized
    # (and formerly clamped) image was never used, so it is discarded.
    _, mean, std = transforms.normalize_instance(cropped_image, eps=1e-11)

    # Normalize target with the input statistics.
    target = transforms.to_tensor(target)
    target = transforms.center_crop(target, (self.resolution, self.resolution))
    target = transforms.normalize(target, mean, std, eps=1e-11)
    target = target.clamp(-6, 6)
    return k_space, target, f_name, slice
def __call__(self, kspace, target, attrs, fname, slice):
    """
    Build the normalized zero-filled magnitude image for one slice.

    Args:
        kspace (numpy.Array): k-space measurements
        target (numpy.Array): Target image (unused)
        attrs (dict): Acquisition related information stored in the HDF5 object (unused)
        fname (str or pathlib.Path): Name of / path to the input file; its
            characters seed the mask function
        slice (int): Serial number of the slice

    Returns:
        (tuple): tuple containing:
            image (torch.Tensor): Normalized zero-filled input image, clamped to [-6, 6]
            mean (float): Mean of the zero-filled image
            std (float): Standard deviation of the zero-filled image
            fname (str or pathlib.Path): The ``fname`` argument, unchanged
            slice (int): Serial number of the slice
    """
    kspace = transforms.to_tensor(kspace)
    if self.mask_func is not None:
        # str() lets the seed be derived from a pathlib.Path as well as a str;
        # map(ord, Path) raises TypeError because Path objects are not iterable.
        seed = tuple(map(ord, str(fname)))
        masked_kspace, _ = transforms.apply_mask(kspace, self.mask_func, seed)
    else:
        masked_kspace = kspace

    # Inverse Fourier Transform to get zero filled solution
    image = transforms.ifft2(masked_kspace)
    # Crop input image
    image = transforms.complex_center_crop(image, (self.resolution, self.resolution))
    # Absolute value
    image = transforms.complex_abs(image)
    # Apply Root-Sum-of-Squares if multicoil data
    if self.which_challenge == 'multicoil':
        image = transforms.root_sum_of_squares(image)
    # Normalize input
    image, mean, std = transforms.normalize_instance(image)
    image = image.clamp(-6, 6)
    return image, mean, std, fname, slice
def test_complex_center_crop(shape, target_shape):
    """complex_center_crop must crop the spatial dims and keep the trailing complex dim of size 2."""
    complex_shape = shape + [2]
    data = create_input(complex_shape)
    cropped = transforms.complex_center_crop(data, target_shape).numpy()
    expected = target_shape + [2]
    assert list(cropped.shape) == expected
def __call__(self, kspace, target, challenge, fname, slice_index):
    """
    Produce full and undersampled k-space, the sampling mask and a normalized target.

    Args:
        kspace (numpy.array): Fully-sampled k-space of one slice.
        target (numpy.array): Target image.
        challenge (str): 'multicoil' triggers root-sum-of-squares coil combination.
        fname (str): File name; seeds the mask when ``self.use_seed`` is set.
        slice_index (int): Serial number of the slice.

    Returns:
        (original_kspace, masked_kspace, mask, target, fname, slice_index).
        Both k-space tensors are passed through ``reducedimension`` when
        ``self.reduce`` is set (full k-space only) and through
        ``cartesianToPolar`` when ``self.polar`` is set. ``target`` is normalized
        with the zero-filled image statistics and clamped to [-6, 6].
    """
    full_kspace = transforms.to_tensor(kspace)
    if self.reduce:
        full_kspace = reducedimension(full_kspace, self.resolution)

    seed = tuple(map(ord, fname)) if self.use_seed else None
    undersampled, mask = transforms.apply_mask(full_kspace, self.mask_func, seed)

    # The zero-filled magnitude image is only needed for its normalization statistics.
    zero_filled = transforms.ifft2(undersampled)
    zero_filled = transforms.complex_center_crop(
        zero_filled, (self.resolution, self.resolution))
    zero_filled = transforms.complex_abs(zero_filled)
    if challenge == 'multicoil':
        zero_filled = transforms.root_sum_of_squares(zero_filled)
    _, mean, std = transforms.normalize_instance(zero_filled, eps=1e-11)

    target = transforms.to_tensor(target)
    target = transforms.normalize(target, mean, std, eps=1e-11).clamp(-6, 6)

    if self.polar:
        full_kspace = cartesianToPolar(full_kspace)
        undersampled = cartesianToPolar(undersampled)

    return full_kspace, undersampled, mask, target, fname, slice_index
def data_consistency_kspace(self, prediction, k_space_slice, mask):
    """
    Enforce data consistency between a network prediction and measured k-space.

    Args:
        prediction: Predicted complex image, channels-first (N x 2 x H x W); only
            its top-left 320 x 320 region is used.
        k_space_slice: Measured (undersampled) k-space, channels-last.
        mask: Sampling mask (assumed binary): 1 where k_space_slice was measured.

    Returns:
        Complex image, channels-first and center cropped to 320 x 320, whose
        k-space equals ``k_space_slice`` where ``mask`` is 1 and the prediction's
        k-space where ``mask`` is 0.
    """
    # N x 2 x H x W -> N x H x W x 2, restricted to the top-left 320 x 320 region.
    image = prediction[:, :, 0:320, 0:320].permute(0, 2, 3, 1)
    # Pad back to the measured data's size via self.proper_padding
    # (the original note says this gives 640 x 372 x 2).
    image = self.proper_padding(image, k_space_slice)

    predicted_kspace = fastmri.fft2c(image)
    merged_kspace = (1 - mask) * predicted_kspace + mask * k_space_slice

    image = fastmri.ifft2c(merged_kspace)
    image = transforms.complex_center_crop(image, (320, 320))
    # Back to channels-first: N x 2 x 320 x 320.
    return image.permute(0, 3, 1, 2)
def resize(hparams, image, target):
    """
    Center-crop a complex image (and optional target) and build a normalized magnitude.

    The crop size is the smallest of ``hparams.resolution``, the image's spatial
    size and, when given, the target's spatial size.

    Args:
        hparams: Object with ``resolution`` and ``challenge`` attributes.
        image (torch.Tensor): Complex image with a trailing dim of size 2.
        target (numpy.array or None): Target image.

    Returns:
        (image, image_abs, target, mean, std): the cropped complex image; its
        magnitude (RSS-combined for 'multicoil'), normalized and clamped to
        [-6, 6]; the target cropped, normalized with the same statistics and
        clamped (or torch.Tensor([0]) when absent); and the statistics.
    """
    crop_height = min(hparams.resolution, image.shape[-3])
    crop_width = min(hparams.resolution, image.shape[-2])
    if target is not None:
        crop_height = min(crop_height, target.shape[-2])
        crop_width = min(crop_width, target.shape[-1])
    crop_size = (crop_height, crop_width)

    image = transforms.complex_center_crop(image, crop_size)
    magnitude = transforms.complex_abs(image)
    if hparams.challenge == "multicoil":
        magnitude = transforms.root_sum_of_squares(magnitude)
    magnitude, mean, std = transforms.normalize_instance(magnitude, eps=1e-11)
    magnitude = magnitude.clamp(-6, 6)

    if target is None:
        target = torch.Tensor([0])
    else:
        target = transforms.to_tensor(target)
        target = transforms.center_crop(target, crop_size)
        target = transforms.normalize(target, mean, std, eps=1e-11).clamp(-6, 6)

    return image, magnitude, target, mean, std
def __call__(self, kspace, target, attrs, fname, slice):
    """
    Return the transforms.rfft2 spectrum of the normalized, cropped magnitude image.

    Args:
        kspace (numpy.Array): k-space measurements
        target (numpy.Array): Target image (unused)
        attrs (dict): Acquisition related information stored in the HDF5 object (unused)
        fname: Input file identifier, returned unchanged
        slice (int): Serial number of the slice

    Returns:
        (tuple): tuple containing:
            kspace (torch.Tensor): transforms.rfft2 of the magnitude image after
                normalization and clamping to [-6, 6]
            mean (float): Mean used to normalize the magnitude image
            std (float): Standard deviation used to normalize the magnitude image
            fname: The ``fname`` argument, unchanged
            slice (int): Serial number of the slice
    """
    image = transforms.ifft2(transforms.to_tensor(kspace))
    image = transforms.complex_center_crop(image, (self.resolution, self.resolution))
    # Coil combination happens on the complex (..., 2) tensor, before the magnitude.
    if self.which_challenge == 'multicoil':
        image = transforms.root_sum_of_squares(image)
    magnitude = transforms.complex_abs(image)
    magnitude, mean, std = transforms.normalize_instance(magnitude, eps=1e-11)
    spectrum = transforms.rfft2(magnitude.clamp(-6, 6))
    return spectrum, mean, std, fname, slice
def __call__(self, kspace, mask, target, attrs, fname, slice):
    """
    Build the normalized zero-filled input image and normalized target for a slice.

    Args:
        kspace (numpy.array): Input k-space of shape (num_coils, rows, cols, 2) for
            multi-coil data or (rows, cols, 2) for single coil data.
        mask (numpy.array): Mask from the test dataset; replaced by the generated
            mask when ``self.mask_func`` is set (otherwise unused).
        target (numpy.array or None): Target image.
        attrs (dict): Acquisition related information stored in the HDF5 object (unused).
        fname (str): File name.
        slice (int): Serial number of the slice.

    Returns:
        (tuple): tuple containing:
            image (torch.Tensor): Zero-filled magnitude image, normalized and
                clamped to [-6, 6].
            target (torch.Tensor): Target normalized with the input statistics and
                clamped, or torch.Tensor([0]) when no target is given.
            mean (float): Mean value used for normalization.
            std (float): Standard deviation value used for normalization.
            fname (str): File name.
            slice (int): Serial number of the slice.
    """
    kspace = transforms.to_tensor(kspace)

    if not self.mask_func:
        masked_kspace = kspace
    else:
        seed = tuple(map(ord, fname)) if self.use_seed else None
        masked_kspace, mask = transforms.apply_mask(kspace, self.mask_func, seed)

    # Zero-filled reconstruction.
    image = transforms.ifft2(masked_kspace)

    # Crop to at most self.resolution, and never larger than the target.
    crop_height = min(self.resolution, image.shape[-3])
    crop_width = min(self.resolution, image.shape[-2])
    if target is not None:
        crop_height = min(crop_height, target.shape[-2])
        crop_width = min(crop_width, target.shape[-1])
    crop_size = (crop_height, crop_width)

    image = transforms.complex_abs(transforms.complex_center_crop(image, crop_size))
    if self.which_challenge == 'multicoil':
        image = transforms.root_sum_of_squares(image)
    image, mean, std = transforms.normalize_instance(image, eps=1e-11)
    image = image.clamp(-6, 6)

    if target is None:
        target = torch.Tensor([0])
    else:
        target = transforms.to_tensor(target)
        target = transforms.center_crop(target, crop_size)
        target = transforms.normalize(target, mean, std, eps=1e-11).clamp(-6, 6)

    return image, target, mean, std, fname, slice
def __call__(self, kspace, target, attrs, fname, slice): kspace_rect = transforms.to_tensor(kspace) ##rectangular kspace image_rect = transforms.ifft2(kspace_rect) ##rectangular FS image image_square = transforms.complex_center_crop( image_rect, (self.resolution, self.resolution)) ##cropped to FS square image kspace_square = self.c3object.apply( transforms.fft2(image_square)) #* 10000 ##kspace of square iamge if self.augmentation: kspace_square = self.augmentation.apply(kspace_square) image_square = ifft_c3(kspace_square) # Apply mask seed = None if not self.use_seed else tuple(map(ord, fname)) masked_kspace_square, mask = transforms.apply_mask( kspace_square, self.mask_func, seed) ##ZF square kspace # Inverse Fourier Transform to get zero filled solution # image = transforms.ifft2(masked_kspace) image_square_us = ifft_c3( masked_kspace_square) ## US square complex image # Crop input image # image = transforms.complex_center_crop(image, (self.resolution, self.resolution)) # Absolute value # image = transforms.complex_abs(image) image_square_abs = transforms.complex_abs( image_square_us) ## US square real image # Apply Root-Sum-of-Squares if multicoil data # if self.which_challenge == 'multicoil': # image = transforms.root_sum_of_squares(image) # Normalize input # image, mean, std = transforms.normalize_instance(image, eps=1e-11) _, mean, std = transforms.normalize_instance(image_square_abs, eps=1e-11) # image = image.clamp(-6, 6) # target = transforms.to_tensor(target) target = image_square.permute(2, 0, 1) # Normalize target # target = transforms.normalize(target, mean, std, eps=1e-11) # target = target.clamp(-6, 6) # return image, target, mean, std, attrs['norm'].astype(np.float32) # return masked_kspace_square.permute((2,0,1)), image, image_square.permute(2,0,1), mean, std, attrs['norm'].astype(np.float32) # ksp, zf, target, me, st, nor return masked_kspace_square.permute((2,0,1)), image_square_us.permute((2,0,1)), \ target, \ mean, std, 
attrs['norm'].astype(np.float32)
def __call__(self, kspace, target, attrs, fname, slice):
    """
    Return the normalized, center-cropped complex image of a slice.

    Args:
        kspace (numpy.array): k-space of one slice.
        target (numpy.array): Unused.
        attrs (dict): Unused.
        fname: Input file identifier, returned unchanged.
        slice (int): Serial number of the slice.

    Returns:
        (image, mean, std, fname, slice): ``image`` is the complex image from
        transforms.ifft2_regular, center cropped to resolution x resolution,
        normalized with normalize_instance (eps=1e-11) and clamped to [-6, 6];
        ``mean``/``std`` are the statistics used for that normalization.
    """
    kspace = transforms.to_tensor(kspace)
    image = transforms.ifft2_regular(kspace)
    image = transforms.complex_center_crop(
        image, (self.resolution, self.resolution))
    image, mean, std = transforms.normalize_instance(image, eps=1e-11)
    image = image.clamp(-6, 6)
    # No k-space is returned, so no forward FFT of the normalized image is computed.
    return image, mean, std, fname, slice
def __call__(self, kspace, target, attrs, fname, slice):
    """
    Build the normalized zero-filled input and training target for one slice.

    Args:
        kspace (numpy.array): Input k-space of shape (num_coils, rows, cols, 2) for
            multi-coil data or (rows, cols, 2) for single coil data.
        target (numpy.array): Target image.
        attrs (dict): Acquisition related information stored in the HDF5 object;
            'norm' is read.
        fname (str): File name; seeds the mask when ``self.use_seed`` is set.
        slice (int): Serial number of the slice (unused).

    Returns:
        (tuple): tuple containing:
            image (torch.Tensor): Normalized zero-filled magnitude image
                (clamped to [-6, 6] when CLAMP is set).
            target_train (torch.Tensor): Target normalized with the input
                statistics (clamped when CLAMP is set), for training.
            mean (float): Mean value used for normalization.
            std (float): Standard deviation value used for normalization.
            norm (numpy.float32): attrs['norm'].
            target (torch.Tensor): Un-normalized target, for visualization.
    """
    kspace = transforms.to_tensor(kspace)

    seed = tuple(map(ord, fname)) if self.use_seed else None
    if self.use_mask:
        mask = transforms.get_mask(kspace, self.mask_func, seed)
        masked_kspace = mask * kspace
    else:
        masked_kspace = kspace

    zero_filled = transforms.complex_center_crop(
        transforms.ifft2(masked_kspace), (self.resolution, self.resolution))
    zero_filled = transforms.complex_abs(zero_filled)
    if self.which_challenge == 'multicoil':
        zero_filled = transforms.root_sum_of_squares(zero_filled)
    zero_filled, mean, std = transforms.normalize_instance(zero_filled, eps=1e-11)

    target = transforms.to_tensor(target)
    target_train = transforms.normalize(target, mean, std, eps=1e-11)

    if CLAMP:
        zero_filled = zero_filled.clamp(-6, 6)
        target_train = target_train.clamp(-6, 6)

    norm = attrs['norm'].astype(np.float32)
    return zero_filled, target_train, mean, std, norm, target
def __call__(self, kspace, target, attrs, fname, slice):
    """
    Build the (width-adjusted) masked k-space and normalized zero-filled image.

    Args:
        kspace (numpy.Array): k-space measurements
        target (numpy.Array): Target image (unused)
        attrs (dict): Acquisition related information stored in the HDF5 object (unused)
        fname (str or pathlib.Path): Name of / path to the input file; its
            characters seed the mask function
        slice (int): Serial number of the slice

    Returns:
        (tuple): tuple containing:
            masked_kspace (torch.Tensor): Masked k-space whose dim 1 is
                center-cropped or zero-padded to exactly ``self.kspace_x``
            image (torch.Tensor): Normalized zero-filled input image, clamped to [-6, 6]
            mean (float): Mean of the zero-filled image
            std (float): Standard deviation of the zero-filled image
            fname (str or pathlib.Path): The ``fname`` argument, unchanged
            slice (int): Serial number of the slice
    """
    kspace = transforms.to_tensor(kspace)
    if self.mask_func is not None:
        # str() so a pathlib.Path (not iterable) can also seed the mask.
        seed = tuple(map(ord, str(fname)))
        masked_kspace, _ = transforms.apply_mask(kspace, self.mask_func, seed)
    else:
        masked_kspace = kspace

    # Inverse Fourier Transform to get zero filled solution
    image = transforms.ifft2(masked_kspace)
    # Crop input image
    image = transforms.complex_center_crop(image, (self.resolution, self.resolution))
    # Absolute value
    image = transforms.complex_abs(image)
    # Apply Root-Sum-of-Squares if multicoil data
    if self.which_challenge == 'multicoil':
        image = transforms.root_sum_of_squares(image)
    # Normalize input
    image, mean, std = transforms.normalize_instance(image)
    image = image.clamp(-6, 6)

    # Fit dim 1 of the k-space to exactly self.kspace_x. The former slice
    # [extra//2 : -(extra//2)] returned an empty tensor for extra == 1 and the
    # wrong width for odd extra. The zero-pad path raised a shape mismatch for
    # odd negative extra.
    width = masked_kspace.shape[1]
    target_width = int(self.kspace_x)
    extra = int(width - target_width)
    if extra > 0:
        # Center crop.
        start = extra // 2
        masked_kspace = masked_kspace[:, start:start + target_width]
    elif extra < 0:
        # Zero pad, keeping dtype and device of the data.
        start = (-extra) // 2
        padded_shape = list(masked_kspace.shape)
        padded_shape[1] = target_width
        padded = masked_kspace.new_zeros(padded_shape)
        padded[:, start:start + width] = masked_kspace
        masked_kspace = padded

    # TODO: return the mask as well for exclusive updates
    return masked_kspace, image, mean, std, fname, slice
def to_spatial(self, kspace, resolution):
    """
    Convert (post-enhancement) k-space into a normalized magnitude image.

    Args:
        kspace (torch.Tensor): Complex k-space with a trailing dim of size 2.
        resolution (int): Side length of the square center crop.

    Returns:
        torch.Tensor: Cropped magnitude image, instance-normalized (eps=1e-11)
        and clamped to [-6, 6].
    """
    cropped = transforms.complex_center_crop(transforms.ifft2(kspace), (resolution, resolution))
    magnitude = transforms.complex_abs(cropped)
    normalized, _, _ = transforms.normalize_instance(magnitude, eps=1e-11)
    return normalized.clamp(-6, 6)
def inference(model, data, device):
    """
    Reconstruct a batch without gradients and return output/target magnitudes.

    The model's k-space prediction is made data-consistent with the input
    (measured samples kept where ``mask`` is 1). Both it and the target are then
    unnormalized, inverse-FFT'd, center cropped to 320 x 320 and converted to
    magnitude.

    Args:
        model: Network applied to the channels-first (permuted) input.
        data: Batch tuple (input, target, mean, std, _, _, mask).
        device: Device the computation runs on.

    Returns:
        (output, target): Magnitude images of the reconstruction and the target.
    """
    with torch.no_grad():
        kspace_in, kspace_target, mean, std, _, _, mask = data
        kspace_in = kspace_in.to(device)
        mask = mask.to(device)

        prediction = model(kspace_in.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)
        prediction = kspace_in * mask + (1 - mask) * prediction

        kspace_target = kspace_target.to(device)
        mean = mean.to(device)
        std = std.to(device)

        def _to_image(kspace):
            # Unnormalize, go to image space, crop and take the magnitude.
            kspace = transforms.unnormalize(kspace, mean, std)
            complex_image = transforms.ifft2(kspace)
            complex_image = transforms.complex_center_crop(complex_image, (320, 320))
            return transforms.complex_abs(complex_image)

        return _to_image(prediction), _to_image(kspace_target)
def generate(generator, data, device):
    """
    Run the generator on a batch and return its output as a magnitude image.

    Args:
        generator: Network applied to the channels-first (permuted) input.
        data: Batch tuple; only elements 0 (input), 2 (mean) and 3 (std) are used.
        device: Device the computation runs on.

    Returns:
        torch.Tensor: Generator output unnormalized with (mean, std), center
        cropped to 320 x 320 and converted to magnitude.
    """
    net_input, _, mean, std, _, _, _, _ = data
    net_input = net_input.to(device)
    output_network = generator(net_input.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)

    mean = mean.to(device)
    std = std.to(device)
    # One statistic per sample (shape (B,)) must be reshaped to (B, 1, 1, 1).
    # Otherwise it broadcasts against the trailing complex dim and mixes up
    # samples when batch size > 1. Other shapes are left untouched.
    if mean.dim() == 1:
        mean = mean.view(-1, 1, 1, 1)
    if std.dim() == 1:
        std = std.view(-1, 1, 1, 1)

    output_network = transforms.unnormalize(output_network, mean, std)
    output_network = transforms.complex_center_crop(output_network, (320, 320))
    return transforms.complex_abs(output_network)
def __call__(self, kspace, target, attrs, fname, slice):
    """
    Return an instance-normalized complex image and an instance-normalized target.

    Image and target are each normalized with their *own* statistics
    (normalize_instance, eps=1e-11). Those statistics are not returned: the
    returned mean and std are always 0.
    NOTE(review): confirm no caller needs real statistics to unnormalize outputs.

    Args:
        kspace (numpy.array): k-space of one slice.
        target (numpy.array): Target image.
        attrs (dict): HDF5 attributes; 'norm' is read.
        fname: Unused.
        slice (int): Unused.

    Returns:
        (image, target, mean, std, norm) with mean == std == 0 and
        norm = attrs['norm'] as numpy.float32.
    """
    image = transforms.ifft2_regular(transforms.to_tensor(kspace))
    image = transforms.complex_center_crop(image, (self.resolution, self.resolution))
    image, _, _ = transforms.normalize_instance(image, eps=1e-11)

    target, _, _ = transforms.normalize_instance(transforms.to_tensor(target), eps=1e-11)

    mean = std = 0
    return image, target, mean, std, attrs['norm'].astype(np.float32)
def kspacetoimage(kspace, args):
    """
    Convert k-space into a normalized zero-filled magnitude image.

    Args:
        kspace (torch.Tensor): Complex k-space with a trailing dim of size 2.
        args: Object with ``resolution`` (crop side length) and ``challenge``;
            'multicoil' triggers root-sum-of-squares coil combination.

    Returns:
        torch.Tensor: Magnitude image, instance-normalized (eps=1e-11) and
        clamped to [-6, 6].
    """
    cropped = transforms.complex_center_crop(
        transforms.ifft2(kspace), (args.resolution, args.resolution))
    magnitude = transforms.complex_abs(cropped)
    if args.challenge == 'multicoil':
        magnitude = transforms.root_sum_of_squares(magnitude)
    magnitude, _, _ = transforms.normalize_instance(magnitude, eps=1e-11)
    return magnitude.clamp(-6, 6)
def nkspacetoimage(args, kspace_fni, mean, std, eps=1e-11):
    """
    Turn normalized k-space into a normalized magnitude image.

    Args:
        args: Object with a ``resolution`` attribute (crop side length).
        kspace_fni (torch.Tensor): Complex k-space with a trailing dim of size 2;
            only index 0 along the first dim is kept after the inverse FFT.
        mean, std: Statistics used to undo the normalization. They are applied
            in image space, after the inverse FFT.
        eps (float): Stabilizer passed to normalize_instance. It was previously
            ignored in favour of a hard-coded 1e-11, which equals the default.

    Returns:
        torch.Tensor: Cropped magnitude image, instance-normalized and clamped
        to [-6, 6].
    """
    assert kspace_fni.size(-1) == 2
    # Centered 2-D inverse FFT over dims (-3, -2).
    # NOTE(review): torch.ifft was removed in PyTorch 1.8, and unlike
    # transforms.ifft2 this call does not pass normalized=True. Migrating to
    # torch.fft changes nothing numerically only if the same scaling is kept.
    image = transforms.ifftshift(kspace_fni, dim=(-3, -2))
    image = torch.ifft(image, 2)
    image = transforms.fftshift(image, dim=(-3, -2))
    # Undo the normalization.
    image = (image * std) + mean
    image = image[0]
    image = transforms.complex_center_crop(image, (args.resolution, args.resolution))
    # Absolute value
    image = transforms.complex_abs(image)
    # Normalize input
    image, _, _ = transforms.normalize_instance(image, eps=eps)
    return image.clamp(-6, 6)
def generate(generator, data, device):
    """
    Predict with the generator, project onto the measured data and return magnitudes.

    Args:
        generator: Network applied to the channels-first (permuted) input; its
            output is treated as a residual.
        data: Batch tuple; elements 0 (input), 2 (mean), 3 (std) and 4 (mask)
            are used.
        device: Device the computation runs on.

    Returns:
        torch.Tensor: Projected output unnormalized with (mean, std), center
        cropped to 320 x 320 and converted to magnitude.

    Raises:
        RuntimeError: If the module-level PROJECT flag is disabled. Without the
            projection there is no defined output; this previously surfaced as
            UnboundLocalError.
    """
    if not PROJECT:
        raise RuntimeError(
            'generate() requires PROJECT to be enabled: the residual is only '
            'turned into an output by project_to_consistent_subspace')

    net_input, _, mean, std, mask, _, _, _ = data
    net_input = net_input.to(device)
    mask = mask.to(device)

    # Use network to predict residual
    residual = generator(net_input.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)
    # Projection to consistent k-space
    output = project_to_consistent_subspace(residual, net_input, mask)

    mean = mean.to(device)
    std = std.to(device)
    # One statistic per sample (shape (B,)) must be reshaped to (B, 1, 1, 1) so it
    # does not broadcast against the trailing complex dim for batch size > 1.
    if mean.dim() == 1:
        mean = mean.view(-1, 1, 1, 1)
    if std.dim() == 1:
        std = std.view(-1, 1, 1, 1)

    output = transforms.unnormalize(output, mean, std)
    output = transforms.complex_center_crop(output, (320, 320))
    return transforms.complex_abs(output)
def save_zero_filled(data_dir, out_dir, which_challenge, resolution):
    """
    Compute zero-filled magnitude reconstructions for every file in a directory.

    Each file's full 'kspace' dataset is inverse-FFT'd, center cropped to at
    most resolution x resolution, converted to magnitude and, for 'multicoil',
    RSS-combined over dim 1. The results are keyed by file name and passed to
    save_reconstructions.

    Args:
        data_dir (pathlib.Path): Directory of HDF5 files containing 'kspace'.
        out_dir: Output location passed to save_reconstructions.
        which_challenge (str): 'multicoil' enables coil combination.
        resolution (int): Maximum crop side length.
    """
    reconstructions = {}
    for path in data_dir.iterdir():
        print("file:{}".format(path))
        with h5py.File(path, "r") as hf:
            kspace = transforms.to_tensor(hf['kspace'][()])

        image = transforms.ifft2(kspace)
        crop = (min(resolution, image.shape[-3]), min(resolution, image.shape[-2]))
        image = transforms.complex_abs(transforms.complex_center_crop(image, crop))
        if which_challenge == 'multicoil':
            image = transforms.root_sum_of_squares(image, dim=1)
        reconstructions[path.name] = image

    save_reconstructions(reconstructions, out_dir)
def data_transform(kspace, mask_function, target, data_attributes, filename, slice_num):
    """
    Preprocess one slice into network inputs.

    Args:
        kspace: Fully-sampled k-space of the slice.
        mask_function: Masking function handed to transforms.apply_mask. It used to
            be ignored because the body referred to an undefined name `mask_func`.
        target: Target image to be reconstructed from the k-space.
        data_attributes: Attributes of the HDF5 file; 'max' is read.
        filename: Unused.
        slice_num: Slice index, returned unchanged.

    Returns:
        (kspace_t, masked_image, target, mask, max_value, slice_num):
            kspace_t: full k-space tensor, instance-normalized.
            masked_image: masked k-space -> image, center cropped to 320 x 320,
                permuted with (2, 0, 1) (channels-first for single-coil input)
                and instance-normalized.
            target: instance-normalized target with a leading channel dim.
            mask: mask generated by the masking function.
            max_value: data_attributes['max'] (for the SSIM loss).
            slice_num: the slice index.
    """
    kspace_t = transforms.to_tensor(kspace)
    kspace_t = transforms.normalize_instance(kspace_t)[0]
    masked_kspace, mask = transforms.apply_mask(
        data=kspace_t, mask_func=mask_function)  # masked k-space and generated mask

    masked_image = fastmri.ifft2c(masked_kspace)  # complex image
    masked_image = transforms.complex_center_crop(masked_image, (320, 320))
    # H x W x 2 -> 2 x H x W, the channels-first layout expected by the network.
    masked_image = masked_image.permute(2, 0, 1)
    masked_image = transforms.normalize_instance(masked_image)[0]

    target = transforms.to_tensor(target)
    target = transforms.normalize_instance(target)[0]
    target = torch.unsqueeze(target, 0)  # add channel dimension

    return kspace_t, masked_image, target, mask, data_attributes['max'], slice_num
def __call__(self, kspace, target, attrs, fname, slice):
    """
    Build an (input, target) pair of complex images scaled by the peak magnitude
    of the zero-filled input.

    Args:
        kspace (numpy.array): Input k-space; the real/imag indexing below assumes
            single-coil (rows, cols, 2) data (TODO confirm before multicoil use).
        target (numpy.array): Unused; the target is recomputed from ``kspace``.
        attrs (dict): Unused.
        fname (str): File name; seeds the mask when ``self.use_seed`` is set.
        slice (int): Unused.

    Returns:
        (tuple): tuple containing:
            image (torch.Tensor): 2 x resolution x resolution zero-filled image
                (real, imag), scaled so its peak magnitude is 6.
            target (torch.Tensor): 2 x resolution x resolution fully-sampled
                image, scaled by the same factor.
    """
    kspace = transforms.to_tensor(kspace)
    # Restrict the fully-sampled data to a centered resolution x resolution field of view.
    gt = transforms.ifft2(kspace)
    gt = transforms.complex_center_crop(gt, (self.resolution, self.resolution))
    kspace = transforms.fft2(gt)

    # Apply mask
    seed = None if not self.use_seed else tuple(map(ord, fname))
    masked_kspace, _ = transforms.apply_mask(kspace, self.mask_func, seed)

    # Zero-filled solution.
    image = transforms.ifft2(masked_kspace)
    image = transforms.complex_center_crop(image, (self.resolution, self.resolution))

    # Peak magnitude of the zero-filled image. It is floored so an all-zero slice
    # yields zeros instead of NaN/inf.
    image_mod = transforms.complex_abs(image).max().clamp(min=1e-11)

    image_r = image[:, :, 0] * 6.0 / image_mod
    image_i = image[:, :, 1] * 6.0 / image_mod
    image = np.stack((image_r, image_i), axis=-1)
    image = image.transpose((2, 0, 1))  # channels-first
    image = transforms.to_tensor(image)

    target = transforms.ifft2(kspace)
    target = transforms.complex_center_crop(target, (self.resolution, self.resolution))
    # Same scale as the input so the pair stays consistent.
    target_r = target[:, :, 0] * 6.0 / image_mod
    target_i = target[:, :, 1] * 6.0 / image_mod
    target = np.stack((target_r, target_i), axis=-1)
    target = target.transpose((2, 0, 1))  # channels-first
    target = transforms.to_tensor(target)

    return image, target
def __call__(self, kspace, target, attrs, fname, slice):
    """
    Build the zero-filled input, training target and any precomputed reconstruction.

    Args:
        kspace (numpy.array): Input k-space of shape (num_coils, rows, cols, 2) for
            multi-coil data or (rows, cols, 2) for single coil data.
        target (numpy.array or None): Target image.
        attrs (dict): Acquisition related information stored in the HDF5 object;
            'norm' is read when a target is present.
        fname (str): File name; seeds the mask and locates a stored reconstruction.
        slice (int): Serial number of the slice.

    Returns:
        (tuple): tuple containing:
            image (torch.Tensor): Zero-filled magnitude image, normalized when
                ``self.normalize`` (clamped when CLAMP is also set).
            target_train (torch.Tensor or list): Target for training ([] if none).
            mean (float): Normalization mean, or -1.0 when not normalizing.
            std (float): Normalization std, or -1.0 when not normalizing.
            norm (numpy.float32 or float): attrs['norm'], or -1.0 without a target.
            target (torch.Tensor or list): Un-normalized target ([] if none).
            image_updated (torch.Tensor or list): Slice ``slice`` of the stored
                'reconstruction' for ``fname`` from the first existing directory
                (train, val, test), or [] when none exists.
    """
    kspace = transforms.to_tensor(kspace)

    seed = tuple(map(ord, fname)) if self.use_seed else None
    if self.use_mask:
        mask = transforms.get_mask(kspace, self.mask_func, seed)
        masked_kspace = mask * kspace
    else:
        masked_kspace = kspace

    image = transforms.ifft2(masked_kspace)
    image = transforms.complex_center_crop(image, (self.resolution, self.resolution))
    image = transforms.complex_abs(image)
    if self.which_challenge == 'multicoil':
        image = transforms.root_sum_of_squares(image)

    if self.normalize:
        image, mean, std = transforms.normalize_instance(image, eps=1e-11)
        if CLAMP:
            image = image.clamp(-6, 6)
    else:
        # Sentinel statistics meaning "not normalized".
        mean = std = -1.0

    if target is None:
        target_train = []
        target = []
        norm = -1.0
    else:
        target = transforms.to_tensor(target)
        target_train = target  # un-normalized target is also returned, for visualization
        if self.normalize:
            target_train = transforms.normalize(target, mean, std, eps=1e-11)
            if CLAMP:
                target_train = target_train.clamp(-6, 6)
        norm = attrs['norm'].astype(np.float32)

    # Precomputed reconstructions; the first directory containing fname wins.
    reconstruction_dirs = (
        '/home/manivasagam/code/fastMRIPrivate/models/unet_volumes/reconstructions_train/',
        '/home/manivasagam/code/fastMRIPrivate/models/unet_volumes/reconstructions_val/',
        '/home/manivasagam/code/fastMRIPrivate/models/unet_volumes/reconstructions_test/',
    )
    image_updated = []
    for directory in reconstruction_dirs:
        updated_fname = directory + fname
        if os.path.exists(updated_fname):
            with h5py.File(updated_fname, 'r') as hf:
                image_updated = hf['reconstruction'][slice]
            image_updated = transforms.to_tensor(image_updated)
            break

    return image, target_train, mean, std, norm, target, image_updated
def croppedimage(kspace, resolution):
    """Return the complex image of ``kspace`` center-cropped to resolution x resolution."""
    full_image = transforms.ifft2(kspace)
    return transforms.complex_center_crop(full_image, (resolution, resolution))
def __call__(self, kspace, target, attrs, fname, slice):
    """Build a square-cropped, undersampled example from rectangular k-space.

    The fully sampled image is center-cropped to (self.resolution, self.resolution)
    and re-transformed to k-space. That square k-space is optionally augmented
    and then masked.

    Args:
        kspace (numpy.array): Input k-space. Because the outputs are permuted with
            permute(2, 0, 1), the cropped tensors must be 3-D; this presumably means
            single-coil (rows, cols, 2) data -- TODO confirm for multi-coil use.
        target: Ignored. The training target is rebuilt from `kspace` instead.
        attrs (dict): Acquisition related information; only attrs['norm'] is used.
        fname (str): File name; used only to derive the mask seed when
            self.use_seed is set.
        slice (int): Unused.

    Returns:
        (tuple): tuple containing:
            masked_kspace (torch.Tensor): Undersampled square k-space with its last
                axis moved to the front.
            zero_filled (torch.Tensor): ifft2 of the undersampled square k-space,
                last axis moved to the front.
            target (torch.Tensor): Fully sampled square complex image (after
                augmentation, if any), last axis moved to the front.
            mean: Mean of the zero-filled magnitude image (normalize_instance).
            std: Standard deviation of the zero-filled magnitude image.
            norm (numpy.float32): attrs['norm'] cast to float32.
    """
    # Rectangular k-space -> rectangular fully sampled image -> square crop.
    kspace_rect = transforms.to_tensor(kspace)
    image_rect = transforms.ifft2(kspace_rect)
    image_square = transforms.complex_center_crop(
        image_rect, (self.resolution, self.resolution))
    kspace_square = transforms.fft2(image_square)

    if self.augmentation:
        kspace_square = self.augmentation.apply(kspace_square)
        # Keep the target consistent with the augmented k-space.
        image_square = transforms.ifft2(kspace_square)

    # Apply mask; the seed is derived from the file name when use_seed is set.
    seed = None if not self.use_seed else tuple(map(ord, fname))
    masked_kspace_square, _ = transforms.apply_mask(
        kspace_square, self.mask_func, seed)

    # Zero-filled (undersampled) square complex image and its magnitude.
    image_square_us = transforms.ifft2(masked_kspace_square)
    image_square_abs = transforms.complex_abs(image_square_us)

    # Only the statistics are needed; the inputs themselves are returned unnormalized.
    _, mean, std = transforms.normalize_instance(image_square_abs, eps=1e-11)

    target = image_square.permute(2, 0, 1)
    return (masked_kspace_square.permute((2, 0, 1)),
            image_square_us.permute((2, 0, 1)),
            target,
            mean, std, attrs['norm'].astype(np.float32))
def __call__(self, kspace, target, attrs, fname, slice):
    """Build a stacked-channel training example from rectangular k-space.

    The k-space is converted to a tensor, inverse-transformed, center-cropped to
    (self.resolution, self.resolution) and transformed back. The resulting square
    k-space is passed through self.c3object.apply and scaled by 10000.
    NOTE(review): what c3object / ifft_c3 do is not visible in this chunk
    (possibly coil compression) -- confirm.

    NOTE(review): this method is immediately followed by a bare triple-single-quote
    delimiter. From this chunk it is unclear whether that delimiter closes a string
    (which would make this whole method disabled code) or opens one -- confirm.

    Args:
        kspace: Raw k-space accepted by transforms.to_tensor. The stacking loops
            below index four axes, so the square k-space is assumed to be 4-D
            (presumably (coils, rows, cols, 2)) -- TODO confirm.
        target: Returned multiplied by 10000; it is not converted to a tensor,
            cropped or scaled like the other outputs.
        attrs: Unused (the norm output is commented out).
        fname (str): Used only to derive the mask seed when self.use_seed is set.
        slice: Unused.

    Returns:
        tuple of five items:
            stacked_kspace_square: (possibly augmented) square k-space. For each
                leading-axis entry, trailing components 0 and 1 (real/imag under
                the file's (..., 2) complex convention) are stacked as
                consecutive channels, giving 2 * N channels.
            stacked_masked_kspace_square: same layout for the masked k-space.
            stacked_image_square: same layout for ifft_c3 of the square k-space
                computed BEFORE augmentation.
            us_image_square_rss: root-sum-of-squares over dim 0 of the magnitude
                of the undersampled image.
            target * 10000.
    """
    kspace_rect = transforms.to_tensor(kspace)  # rectangular k-space
    image_rect = transforms.ifft2(kspace_rect)  # rectangular fully sampled image
    image_square = transforms.complex_center_crop(
        image_rect, (self.resolution, self.resolution))  # cropped to fully sampled square image
    kspace_square = self.c3object.apply(
        transforms.fft2(image_square)) * 10000  # k-space of square image, scaled by 10000
    # Fully sampled image target, taken BEFORE augmentation.
    # NOTE(review): the inputs below use the augmented k-space -- confirm this is intended.
    image_square2 = ifft_c3(kspace_square)  # for training domain_transform
    if self.augmentation:
        kspace_square = self.augmentation.apply(kspace_square)
    # image_square = ifft_c3(kspace_square)
    # Apply mask; the seed is derived from the file name when use_seed is set.
    seed = None if not self.use_seed else tuple(map(ord, fname))
    # `mask` is not used after this point.
    masked_kspace_square, mask = transforms.apply_mask(
        kspace_square, self.mask_func, seed)  # zero-filled square k-space
    # Inverse Fourier Transform to get zero filled solution
    # image = transforms.ifft2(masked_kspace)
    us_image_square = ifft_c3(
        masked_kspace_square)  # undersampled square complex image
    # Crop input image
    # image = transforms.complex_center_crop(image, (self.resolution, self.resolution))
    # Absolute value
    # image = transforms.complex_abs(image)
    us_image_square_abs = transforms.complex_abs(
        us_image_square)  # undersampled square magnitude image
    # Combine along dim 0 (presumably coils -- TODO confirm).
    us_image_square_rss = transforms.root_sum_of_squares(
        us_image_square_abs, dim=0)
    # Interleave trailing components 0/1 of each leading-axis entry as channels.
    stacked_kspace_square = []
    for i in (range(len(kspace_square[:, 0, 0, 0]))):
        stacked_kspace_square.append(kspace_square[i, :, :, 0])
        stacked_kspace_square.append(kspace_square[i, :, :, 1])
    stacked_kspace_square = torch.stack(stacked_kspace_square)
    # Same channel layout for the masked k-space.
    stacked_masked_kspace_square = []
    # masked_kspace_square = transforms.to_tensor(masked_kspace_square)
    # for i in range(len(masked_kspace_square[:,0,0,0])):
    # stacked_masked_kspace_square.stack(masked_kspace_square[i,:,:,0],masked_kspace_square[i,:,:,1])
    for i in (range(len(masked_kspace_square[:, 0, 0, 0]))):
        stacked_masked_kspace_square.append(masked_kspace_square[i, :, :, 0])
        stacked_masked_kspace_square.append(masked_kspace_square[i, :, :, 1])
    stacked_masked_kspace_square = torch.stack(
        stacked_masked_kspace_square)
    # Same channel layout for the pre-augmentation image. The loop length comes from
    # image_square but the data from image_square2. These presumably share the
    # leading size if c3object.apply preserves it -- TODO confirm.
    stacked_image_square = []
    for i in (range(len(image_square[:, 0, 0, 0]))):
        stacked_image_square.append(image_square2[i, :, :, 0])
        stacked_image_square.append(image_square2[i, :, :, 1])
    stacked_image_square = torch.stack(stacked_image_square)
    # The trailing backslash before the comment-only line below ends the statement
    # at `target *10000`; mean/std/norm are no longer returned.
    return stacked_kspace_square,stacked_masked_kspace_square , stacked_image_square , \
        us_image_square_rss , \
        target *10000 \
        #mean, std, attrs['norm'].astype(np.float32)
'''
def tosquare(ksp, shp, scale=100000):
    """Crop k-space to `shp` in the image domain and return the scaled result.

    Args:
        ksp: k-space tensor accepted by `T.ifft2`.
        shp: Crop shape passed to `T.complex_center_crop`.
        scale: Multiplier applied to the result. The default keeps the previously
            hard-coded 100000.
            NOTE(review): the dataset transform in this file scales its k-space
            by 10000 -- confirm which factor is intended here.

    Returns:
        `c3m * T.fft2(T.complex_center_crop(T.ifft2(ksp), shp)) * scale`.
        `c3m` is a module-level name defined outside this chunk; its meaning
        is not visible here.
    """
    cropped_image = T.complex_center_crop(T.ifft2(ksp), shp)
    # Same operation order as before (c3m * fft * scale) to keep results bit-identical.
    return c3m * T.fft2(cropped_image) * scale