def visualize(args, epoch, model, inference, data_loader, writer):
    """Log target / reconstruction / error image grids for one batch.

    Runs ``inference`` on the first batch of ``data_loader`` only (the loop
    breaks after one iteration). It writes grids tagged ``Target``,
    ``Reconstruction`` and ``Error`` to ``writer`` at step ``epoch``.

    Args:
        args: Namespace; only ``args.device`` is read here.
        epoch: Step passed to ``writer.add_image``.
        model: Model forwarded to ``inference``; switched to eval mode.
        inference: Callable ``inference(model, data, device=...)`` returning
            ``(output, target)`` tensors.
        data_loader: Iterable of batches.
        writer: Object providing ``add_image(tag, grid, step)``.
    """

    def save_image(image, tag):
        # Rescale to [0, 1] out of place. The previous in-place ``-=``/``/=``
        # mutated the caller's tensors, so the Error image was computed from
        # already-rescaled target/output instead of their real values.
        image = image - image.min()
        max_val = image.max()
        if max_val > 0:  # a constant image would otherwise become 0/0 = NaN
            image = image / max_val
        grid = torchvision.utils.make_grid(image, nrow=4, pad_value=1)
        writer.add_image(tag, grid, epoch)

    def prepare(image):
        # HACK to make images look good in tensorboard: clamp at +/-6 instance
        # standard deviations, then map back to the original scale.
        image, mean, std = transforms.normalize_instance(image)
        image = image.clamp(-6, 6)
        image = transforms.unnormalize(image, mean, std)
        return image.unsqueeze(1)  # [batch_sz, h, w] --> [batch_sz, 1, h, w]

    model.eval()
    with torch.no_grad():
        for data in data_loader:
            output, target = inference(model, data, device=args.device)
            output = prepare(output)
            target = prepare(target)
            # The former ``isinstance(output, dict)`` branch was unreachable:
            # ``output`` has already been used as a tensor at this point.
            save_image(target, 'Target')
            save_image(output, 'Reconstruction')
            save_image(torch.abs(target - output), 'Error')
            break
def evaluate(args):
    """Compute reconstruction metrics for every target volume of a data split.

    Ground truth is read from ``/home/tomerweiss/Datasets/singlecoil_<split>``.
    Reconstructions come from ``/home/liyon/PILOT/summary/<test_name>/rec``
    and are matched by file name. Both are normalized per instance and pushed
    into a ``Metrics`` accumulator.

    Args:
        args: Namespace with ``data_split``, ``test_name``, ``acquisition`` and
            ``resolution_degrading``. ``target_path`` and ``predictions_path``
            are written on it.

    Returns:
        The populated ``Metrics`` object.
    """
    args.target_path = f'/home/tomerweiss/Datasets/singlecoil_{args.data_split}'
    args.predictions_path = f'/home/liyon/PILOT/summary/{args.test_name}/rec'
    print('/home/liyon/PILOT/summary/' + args.test_name + '/rec')
    print("args: {}".format(args))
    metrics = Metrics(METRIC_FUNCS)

    for tgt_file in pathlib.Path(args.target_path).iterdir():
        if tgt_file.is_dir():
            continue
        # Open read-only. Older h5py versions defaulted to append mode, which
        # requires write access to the data.
        with h5py.File(tgt_file, 'r') as target, \
                h5py.File(args.predictions_path + '/' + tgt_file.name, 'r') as recons:
            # Keep only volumes of the requested acquisition. The previous
            # ``==`` test skipped exactly the volumes that were asked for.
            if args.acquisition and args.acquisition != target.attrs['acquisition']:
                continue
            # ``Dataset.value`` was removed in h5py 3.0; ``[()]`` reads it all.
            recons = recons['reconstruction'][()].squeeze(0)
            target = transforms.to_tensor(target['reconstruction_esc'][()])
            target = torch.nn.functional.avg_pool2d(target, args.resolution_degrading)
            target = center_crop_3d(target, recons.shape)
            target, _, _ = transforms.normalize_instance(target, eps=1e-11)
            recons, _, _ = transforms.normalize_instance(recons, eps=1e-11)
            metrics.push(target.numpy(), recons)
    return metrics
def mnormalize(masked_kspace):
    """Normalize the image behind ``masked_kspace`` and return its k-space.

    The zero-filled image is normalized per instance and then taken back to
    (centred) k-space with an unnormalized ``torch.fft``. ``nkspacetoimage``
    undoes this with ``torch.ifft`` followed by ``image * std + mean``. The
    statistics returned here must therefore be those of the image.

    Returns:
        (kspace_fni, mean, std): k-space of the normalized image, and the
        image mean / standard deviation used for the normalization.
    """
    # getting image from masked data
    image = transforms.ifft2(masked_kspace)
    # normalizing the image
    normalized_image, mean, std = transforms.normalize_instance(image, eps=1e-11)
    # getting kspace data from normalized image
    kspace_fni = transforms.ifftshift(normalized_image, dim=(-3, -2))
    kspace_fni = torch.fft(kspace_fni, 2)
    kspace_fni = transforms.fftshift(kspace_fni, dim=(-3, -2))
    # The result above used to be overwritten by
    # ``normalize_instance(masked_kspace)``. That discarded the FFT and
    # returned k-space statistics in place of the image statistics.
    return kspace_fni, mean, std
def __call__(self, kspace, target, attrs, fname, slice):
    """
    Args:
        kspace (numpy.array): Input k-space of shape (num_coils, rows, cols, 2) for multi-coil
            data or (rows, cols, 2) for single coil data.
        target (numpy.array): Target image; returned (as a tensor) as ``target_inference``.
        attrs (dict): Acquisition related information stored in the HDF5 object;
            ``attrs['max']`` and ``attrs['norm']`` are returned.
        fname (str): File name; seeds the mask when ``self.use_seed`` is set.
        slice (int): Serial number of the slice.
    Returns:
        (tuple): tuple containing:
            image (torch.Tensor): Zero-filled complex image normalized with the
                statistics of its center crop (clamped when CLAMP is set).
            target_train (torch.Tensor): Magnitude of the fully sampled image,
                center cropped to 320x320 (renormalized when RENORM is set,
                clamped when CLAMP is set).
            target_kspace_train (torch.Tensor): ``transforms.fft2`` of the fully
                sampled complex image normalized with ``mean`` / ``std``.
            mean, std: Statistics from ``normalize_instance_complex`` of the
                cropped zero-filled image.
            mask (torch.Tensor): Sampling mask (all ones when ``self.use_mask`` is False).
            mean_abs, std_abs: Statistics of the cropped zero-filled magnitude image.
            target_inference (torch.Tensor): ``target`` converted to a tensor.
            attrs['max'], and attrs['norm'] as float32.
    """
    target_inference = transforms.to_tensor(target)
    kspace = transforms.to_tensor(kspace)
    # The training target is rebuilt from the fully sampled k-space.
    target = transforms.ifft2(kspace)

    # Apply mask
    seed = None if not self.use_seed else tuple(map(ord, fname))
    if self.use_mask:
        mask = transforms.get_mask(kspace, self.mask_func, seed)
        masked_kspace = mask * kspace
    else:
        # ``mask`` used to stay unbound here, so the return statement raised
        # UnboundLocalError whenever use_mask was False. An all-ones mask
        # leaves the k-space unchanged.
        # NOTE(review): this mask has kspace's shape, which may differ from
        # get_mask's output shape; confirm downstream consumers accept it.
        mask = torch.ones_like(kspace)
        masked_kspace = kspace

    image = transforms.ifft2(masked_kspace)
    image_crop = transforms.complex_center_crop(
        image, (self.resolution, self.resolution))
    _, mean, std = transforms.normalize_instance_complex(image_crop, eps=1e-11)
    image_abs = transforms.complex_abs(image_crop)
    # Only the magnitude statistics are needed (used when RENORM is set).
    _, mean_abs, std_abs = transforms.normalize_instance(image_abs, eps=1e-11)

    # Normalize the full-size input and target with the crop's statistics.
    image = transforms.normalize(image, mean, std)
    target_image_complex_norm = transforms.normalize(target, mean, std)
    target_kspace_train = transforms.fft2(target_image_complex_norm)

    # NOTE(review): the target crop is fixed at 320x320, while the input
    # statistics use self.resolution; confirm this is intended.
    target = transforms.complex_center_crop(target, (320, 320))
    target = transforms.complex_abs(target)
    target_train = target
    if RENORM:
        target_train = transforms.normalize(target_train, mean_abs, std_abs)
    if CLAMP:
        image = image.clamp(-6, 6)
        target_train = target_train.clamp(-6, 6)
    return (image, target_train, target_kspace_train, mean, std, mask,
            mean_abs, std_abs, target_inference, attrs['max'],
            attrs['norm'].astype(np.float32))
def __call__(self, k_space, mask, target, attrs, f_name, slice):
    """
    Args:
        k_space (numpy.array): Input k-space of shape (num_coils, rows, cols, 2) for multi-coil
            data or (rows, cols, 2) for single coil data.
        mask (numpy.array): Mask from the test dataset (not used here).
        target (numpy.array): Target image
        attrs (dict): Acquisition related information stored in the HDF5 object (not used here).
        f_name (str): File name
        slice (int): Serial number of the slice.
    Returns:
        (tuple): tuple containing:
            k_space (torch.Tensor): k-space of the image center cropped to
                (self.resolution, self.resolution).
            target (torch.Tensor): Target center cropped, normalized with the
                cropped image's mean/std and clamped to [-6, 6].
            f_name (str): File name
            slice (int): Serial number of the slice.
    """
    k_space = transforms.to_tensor(k_space)
    full_image = transforms.ifft2(k_space)
    cropped_image = transforms.complex_center_crop(
        full_image, (self.resolution, self.resolution))
    k_space = transforms.fft2(cropped_image)

    # Only the cropped image's statistics are needed, to normalize the target.
    # The normalized and clamped image was previously computed and discarded.
    _, mean, std = transforms.normalize_instance(cropped_image, eps=1e-11)

    # Normalize target
    target = transforms.to_tensor(target)
    target = transforms.center_crop(target, (self.resolution, self.resolution))
    target = transforms.normalize(target, mean, std, eps=1e-11)
    target = target.clamp(-6, 6)
    return k_space, target, f_name, slice
def __call__(self, image, target, attrs, fname, slice):
    """Normalize an image/target pair using the image's instance statistics.

    Args:
        image (numpy.array): Input image.
        target (numpy.array): Target image.
        attrs (dict): Acquisition related information (not used here).
        fname (str): File name, passed through.
        slice (int): Serial number of the slice, passed through.

    Returns:
        (tuple): ``(image, target, mean, std, fname, slice)``. Both tensors are
        normalized by the image's mean/std and clamped to [-6, 6].
    """
    normalized_image, mean, std = transforms.normalize_instance(
        transforms.to_tensor(image), eps=1e-11)
    normalized_target = transforms.normalize(
        transforms.to_tensor(target), mean, std, eps=1e-11)
    return (normalized_image.clamp(-6, 6), normalized_target.clamp(-6, 6),
            mean, std, fname, slice)
def __call__(self, img, fname, slice):
    """Center-crop ``img`` to ``self.resolution`` and normalize it.

    The last axis of the 3-D tensor is moved to the front before cropping
    (presumably channels; confirm against the caller).

    Returns:
        (image, mean, std, fname): the crop normalized per instance and
        clamped to [-6, 6], its mean and std, and ``fname``. ``slice`` is
        accepted but not returned.
    """
    side = self.resolution
    leading_last_axis = transforms.to_tensor(img).permute(2, 0, 1)
    cropped = transforms.center_crop(leading_last_axis, (side, side))
    normalized, mean, std = transforms.normalize_instance(cropped, eps=1e-11)
    return normalized.clamp(-6, 6), mean, std, fname
def __call__(self, kspace, target, challenge, fname, slice_index):
    """Mask k-space and normalize the target with zero-filled image statistics.

    Returns:
        (original_kspace, masked_kspace, mask, target, fname, slice_index).
        The k-space tensors are passed through ``cartesianToPolar`` when
        ``self.polar`` is set. The target is normalized by the mean/std of the
        zero-filled magnitude image and clamped to [-6, 6].
    """
    full_kspace = transforms.to_tensor(kspace)
    if self.reduce:
        full_kspace = reducedimension(full_kspace, self.resolution)

    seed = tuple(map(ord, fname)) if self.use_seed else None
    under_kspace, mask = transforms.apply_mask(full_kspace, self.mask_func, seed)

    # Zero-filled reconstruction; only its normalization statistics are kept.
    zero_filled = transforms.ifft2(under_kspace)
    zero_filled = transforms.complex_center_crop(
        zero_filled, (self.resolution, self.resolution))
    zero_filled = transforms.complex_abs(zero_filled)
    if challenge == 'multicoil':
        zero_filled = transforms.root_sum_of_squares(zero_filled)
    _, mean, std = transforms.normalize_instance(zero_filled, eps=1e-11)

    normalized_target = transforms.normalize(
        transforms.to_tensor(target), mean, std, eps=1e-11).clamp(-6, 6)

    if self.polar:
        full_kspace = cartesianToPolar(full_kspace)
        under_kspace = cartesianToPolar(under_kspace)
    return full_kspace, under_kspace, mask, normalized_target, fname, slice_index
def __call__(self, kspace, target, attrs, fname, slice):
    """Return a per-instance normalized image and target plus the volume norm.

    The image is ``ifft2_regular`` of ``kspace``, center cropped to
    ``self.resolution``. The statistics of both normalizations are
    discarded, and the returned mean and std are always 0.

    Returns:
        (image, target, 0, 0, attrs['norm'] as float32)
    """
    spatial = transforms.ifft2_regular(transforms.to_tensor(kspace))
    spatial = transforms.complex_center_crop(
        spatial, (self.resolution, self.resolution))
    spatial, _, _ = transforms.normalize_instance(spatial, eps=1e-11)
    normalized_target, _, _ = transforms.normalize_instance(
        transforms.to_tensor(target), eps=1e-11)
    return spatial, normalized_target, 0, 0, attrs['norm'].astype(np.float32)
def resize(hparams, image, target):
    """Crop ``image`` (and ``target``) to a common size and normalize them.

    The crop size is the smallest of ``hparams.resolution``, the image's
    spatial size (axes -3/-2 of the complex tensor) and, when given, the
    target's spatial size (axes -2/-1).

    Returns:
        (image, image_abs, target, mean, std):
            - image: cropped complex image.
            - image_abs: its magnitude (RSS-combined for multicoil),
              normalized per instance and clamped to [-6, 6].
            - target: target cropped, normalized with the same mean/std and
              clamped, or ``torch.Tensor([0])`` when ``target`` is None.
            - mean, std: the magnitude normalization statistics.
    """
    crop_h = min(hparams.resolution, image.shape[-3])
    crop_w = min(hparams.resolution, image.shape[-2])
    if target is not None:
        crop_h = min(crop_h, target.shape[-2])
        crop_w = min(crop_w, target.shape[-1])
    crop_size = (crop_h, crop_w)

    image = transforms.complex_center_crop(image, crop_size)
    magnitude = transforms.complex_abs(image)
    if hparams.challenge == "multicoil":
        magnitude = transforms.root_sum_of_squares(magnitude)
    magnitude, mean, std = transforms.normalize_instance(magnitude, eps=1e-11)
    magnitude = magnitude.clamp(-6, 6)

    if target is None:
        target = torch.Tensor([0])
    else:
        target = transforms.center_crop(transforms.to_tensor(target), crop_size)
        target = transforms.normalize(target, mean, std, eps=1e-11).clamp(-6, 6)
    return image, magnitude, target, mean, std
def __call__(self, kspace, target, attrs, fname, slice):
    """
    Args:
        kspace (numpy.Array): k-space measurements
        target (numpy.Array): Target image (not used here)
        attrs (dict): Acquisition related information stored in the HDF5 object (not used here)
        fname (pathlib.Path): Path to the input file
        slice (int): Serial number of the slice
    Returns:
        (tuple): tuple containing:
            kspace (torch.Tensor): ``transforms.rfft2`` of the normalized,
                clamped magnitude image
            mean: Mean of the cropped magnitude image
            std: Standard deviation of the cropped magnitude image
            fname (pathlib.Path): Path to the input file
            slice (int): Serial number of the slice
    """
    spatial = transforms.ifft2(transforms.to_tensor(kspace))
    spatial = transforms.complex_center_crop(
        spatial, (self.resolution, self.resolution))
    if self.which_challenge == 'multicoil':
        # Coils are combined on the complex tensor before taking the magnitude.
        spatial = transforms.root_sum_of_squares(spatial)
    magnitude = transforms.complex_abs(spatial)
    magnitude, mean, std = transforms.normalize_instance(magnitude, eps=1e-11)
    return transforms.rfft2(magnitude.clamp(-6, 6)), mean, std, fname, slice
def __call__(self, kspace, target, attrs, fname, slice):
    """
    Args:
        kspace (numpy.array): Input k-space of shape (num_coils, rows, cols, 2) for multi-coil
            data or (rows, cols, 2) for single coil data.
        target (numpy.array): Target image (not used here).
        attrs (dict): Acquisition related information stored in the HDF5 object (not used here).
        fname (str): File name; seeds the mask when ``self.use_seed`` is set.
        slice (int): Serial number of the slice (not used here).
    Returns:
        (tuple): tuple containing:
            masked_kspace (torch.Tensor): Masked k-space normalized by
                ``normalize_instance_complex``.
            kspace (torch.Tensor): Fully sampled k-space normalized with the same mean/std.
            mean, std: Statistics from ``normalize_instance_complex(masked_kspace)``.
            mean_image, std_image: Statistics of the zero-filled image.
            mask (torch.Tensor): Sampling mask (all ones when ``self.use_mask`` is False).
    """
    kspace = transforms.to_tensor(kspace)
    # Apply mask
    seed = None if not self.use_seed else tuple(map(ord, fname))
    if self.use_mask:
        mask = transforms.get_mask(kspace, self.mask_func, seed)
        masked_kspace = mask * kspace
    else:
        # ``mask`` used to stay unbound here, so the return statement raised
        # UnboundLocalError whenever use_mask was False. An all-ones mask
        # leaves the k-space unchanged.
        # NOTE(review): this mask has kspace's shape, which may differ from
        # get_mask's output shape; confirm downstream consumers accept it.
        mask = torch.ones_like(kspace)
        masked_kspace = kspace
    image = transforms.ifft2(masked_kspace)
    _, mean_image, std_image = transforms.normalize_instance(image, eps=1e-11)
    masked_kspace, mean, std = transforms.normalize_instance_complex(masked_kspace)
    kspace = transforms.normalize(kspace, mean, std)
    return masked_kspace, kspace, mean, std, mean_image, std_image, mask
def __call__(self, kspace, target, attrs, fname, slice):
    """Zero-filled reconstruction of (optionally masked) k-space.

    Args:
        kspace (numpy.Array): k-space measurements
        target (numpy.Array): Target image (not used here)
        attrs (dict): Acquisition related information stored in the HDF5 object (not used here)
        fname (pathlib.Path): Path to the input file; seeds the mask
        slice (int): Serial number of the slice
    Returns:
        (tuple): tuple containing:
            image (torch.Tensor): Normalized zero-filled magnitude image, clamped to [-6, 6]
            mean: Mean of the zero-filled image
            std: Standard deviation of the zero-filled image
            fname (pathlib.Path): Path to the input file
            slice (int): Serial number of the slice
    """
    full_kspace = transforms.to_tensor(kspace)
    if self.mask_func is None:
        sampled_kspace = full_kspace
    else:
        sampled_kspace, _ = transforms.apply_mask(
            full_kspace, self.mask_func, tuple(map(ord, fname)))

    zero_filled = transforms.ifft2(sampled_kspace)
    zero_filled = transforms.complex_center_crop(
        zero_filled, (self.resolution, self.resolution))
    zero_filled = transforms.complex_abs(zero_filled)
    if self.which_challenge == 'multicoil':
        zero_filled = transforms.root_sum_of_squares(zero_filled)
    normalized, mean, std = transforms.normalize_instance(zero_filled)
    return normalized.clamp(-6, 6), mean, std, fname, slice
def __call__(self, image):
    """Flip, square-crop, resize and normalize a DICOM image.

    Args:
        image (numpy.array): DICOM image.
    Returns:
        image (torch.Tensor): The image flipped along axis 0 and center
        cropped to the largest square. It is resized to
        (self.resolution, self.resolution) with bicubic interpolation,
        normalized per instance and clamped to [-6, 6].
    """
    flipped = np.flip(image, 0)
    side = min(flipped.shape[0], flipped.shape[1])
    square = transforms.center_crop(flipped, (side, side))
    resized = cv2.resize(square, dsize=(self.resolution, self.resolution),
                         interpolation=cv2.INTER_CUBIC)
    normalized, _, _ = transforms.normalize_instance(
        transforms.to_tensor(resized), eps=1e-11)
    return normalized.clamp(-6, 6)
def evaluate(args):
    """Compute reconstruction metrics for every brainT1 target volume of a split.

    Ground truth is read from ``Datasets/brainT1/<split>`` (dataset ``data``).
    Reconstructions come from ``/summary/<test_name>/rec``, matched by file
    name. The target is brought to the reconstruction's shape, both are
    normalized per instance, and the pair is pushed into a ``Metrics``
    accumulator.

    Args:
        args: Namespace with ``data_split``, ``test_name``, ``acquisition`` and
            ``resolution_degrading``. ``target_path``, ``predictions_path`` and
            (in the 2-D branch) ``resolution`` are written on it.

    Returns:
        The populated ``Metrics`` object.
    """
    args.target_path = f'Datasets/brainT1/{args.data_split}'
    args.predictions_path = f'/summary/{args.test_name}/rec'
    # NOTE(review): the printed path has no leading '/', unlike
    # predictions_path; confirm which location is intended.
    print('summary/' + args.test_name + '/rec')
    print("args: {}".format(args))
    metrics = Metrics(METRIC_FUNCS)

    for tgt_file in pathlib.Path(args.target_path).iterdir():
        if tgt_file.is_dir():
            continue
        # Open read-only. Older h5py versions defaulted to append mode.
        with h5py.File(tgt_file, 'r') as target, \
                h5py.File(args.predictions_path + '/' + tgt_file.name, 'r') as recons:
            # Keep only volumes of the requested acquisition. The previous
            # ``==`` test skipped exactly the volumes that were asked for.
            if args.acquisition and args.acquisition != target.attrs['acquisition']:
                continue
            # ``Dataset.value`` was removed in h5py 3.0; ``[()]`` reads it all.
            recons = recons['reconstruction'][()].squeeze(0)
            target = transforms.to_tensor(target['data'][()])
            if recons.shape[0] == recons.shape[2]:
                # This means we look at the 3d resolution
                min_to_crop = min(target.shape[0], target.shape[1], target.shape[2])
                min_to_crop -= min_to_crop % 2  # make the cube side even
                target = pytorch_nufft.transforms.center_crop_3d(
                    target, (min_to_crop, min_to_crop, min_to_crop))
                target = ndimage.zoom(target, recons.shape[0] / min_to_crop)
                target = torch.from_numpy(target).float()
            else:
                target = torch.nn.functional.avg_pool2d(
                    target, args.resolution_degrading)
                args.resolution = min(
                    target.shape[0] // args.resolution_degrading, recons.shape[0])
                target = pytorch_nufft.transforms.center_crop_3d(
                    target, recons.shape)
            target, _, _ = transforms.normalize_instance(target, eps=1e-11)
            recons, _, _ = transforms.normalize_instance(recons, eps=1e-11)
            metrics.push(target.numpy(), recons)
    return metrics
def test_normalize_instance(shape):
    """normalize_instance reports the input's statistics and standardizes it."""
    data = create_input(shape)
    normalized, mean, stddev = transforms.normalize_instance(data)
    raw = data.numpy()
    result = normalized.numpy()
    # The returned statistics must describe the raw input.
    assert np.isclose(raw.mean(), mean, rtol=1e-2)
    assert np.isclose(raw.std(), stddev, rtol=1e-2)
    # The output should be approximately zero-mean with unit std.
    assert np.isclose(result.mean(), 0, rtol=1e-2, atol=1e-3)
    assert np.isclose(result.std(), 1, rtol=1e-2, atol=1e-3)
def __call__(self, ds_slice, gt_slice, attrs, file_name, s_idx, acc_fac):
    """Prepare a down-sampled slice for inference (no ground truth allowed).

    Returns:
        (ds_slice, file_name, s_idx, extra_params): the slice normalized per
        instance, clamped to [-6, 6] and given a leading axis. ``extra_params``
        holds ``mean``, ``std``, ``acc_fac`` and ``attrs``.
    """
    assert gt_slice is None
    with torch.autograd.no_grad():
        normalized, mean, std = normalize_instance(to_tensor(ds_slice))
        ds_slice = normalized.clamp(min=-6, max=6).unsqueeze(dim=0)
    assert isinstance(file_name, str) and isinstance(s_idx, int), 'Incorrect types!'
    extra_params = {'mean': mean, 'std': std, 'acc_fac': acc_fac, 'attrs': attrs}
    return ds_slice, file_name, s_idx, extra_params
def __call__(self, kspace, mask, target, attrs, fname, slice):
    """
    Args:
        kspace (numpy.array): Input k-space of shape (num_coils, rows, cols, 2) for multi-coil
            data or (rows, cols, 2) for single coil data.
        mask (numpy.array): Mask from the test dataset (replaced when self.mask_func is set).
        target (numpy.array): Target image, or None.
        attrs (dict): Acquisition related information stored in the HDF5 object.
        fname (str): File name
        slice (int): Serial number of the slice.
    Returns:
        (tuple): tuple containing:
            image (torch.Tensor): Zero-filled magnitude image, normalized and clamped to [-6, 6].
            target (torch.Tensor): Target cropped, normalized with the image's
                mean/std and clamped, or ``torch.Tensor([0])`` without a target.
            mean: Mean value used for normalization.
            std: Standard deviation value used for normalization.
            fname (str): File name
            slice (int): Serial number of the slice.
    """
    full_kspace = transforms.to_tensor(kspace)
    if self.mask_func:
        seed = tuple(map(ord, fname)) if self.use_seed else None
        sampled_kspace, mask = transforms.apply_mask(
            full_kspace, self.mask_func, seed)
    else:
        sampled_kspace = full_kspace

    zero_filled = transforms.ifft2(sampled_kspace)

    # Crop to self.resolution, but never beyond the image or target size.
    height_limits = [self.resolution, zero_filled.shape[-3]]
    width_limits = [self.resolution, zero_filled.shape[-2]]
    if target is not None:
        height_limits.append(target.shape[-2])
        width_limits.append(target.shape[-1])
    crop_size = (min(height_limits), min(width_limits))

    zero_filled = transforms.complex_center_crop(zero_filled, crop_size)
    magnitude = transforms.complex_abs(zero_filled)
    if self.which_challenge == 'multicoil':
        magnitude = transforms.root_sum_of_squares(magnitude)
    magnitude, mean, std = transforms.normalize_instance(magnitude, eps=1e-11)
    magnitude = magnitude.clamp(-6, 6)

    if target is None:
        target = torch.Tensor([0])
    else:
        target = transforms.center_crop(transforms.to_tensor(target), crop_size)
        target = transforms.normalize(target, mean, std, eps=1e-11).clamp(-6, 6)
    return magnitude, target, mean, std, fname, slice
def __call__(self, kspace, target, attrs, fname, slice):
    """Build a square, undersampled training example from rectangular k-space.

    The fully sampled image is center cropped to a square and taken back to
    k-space through ``self.c3object``. That k-space is optionally augmented,
    then masked. ``target`` is ignored: the returned target is rebuilt with
    ``ifft_c3`` from the (possibly augmented) square k-space.

    Returns:
        (masked_kspace, undersampled_image, target, mean, std, norm):
            - masked_kspace, undersampled_image, target are permuted with
              (2, 0, 1), moving the last axis first.
            - mean, std are statistics of the undersampled magnitude image.
            - norm is attrs['norm'] as float32.
    """
    rect_image = transforms.ifft2(transforms.to_tensor(kspace))
    square_image = transforms.complex_center_crop(
        rect_image, (self.resolution, self.resolution))
    square_kspace = self.c3object.apply(transforms.fft2(square_image))
    if self.augmentation:
        square_kspace = self.augmentation.apply(square_kspace)
    square_image = ifft_c3(square_kspace)

    seed = tuple(map(ord, fname)) if self.use_seed else None
    masked_square_kspace, mask = transforms.apply_mask(
        square_kspace, self.mask_func, seed)
    undersampled_image = ifft_c3(masked_square_kspace)
    _, mean, std = transforms.normalize_instance(
        transforms.complex_abs(undersampled_image), eps=1e-11)

    return (masked_square_kspace.permute((2, 0, 1)),
            undersampled_image.permute((2, 0, 1)),
            square_image.permute(2, 0, 1),
            mean, std,
            attrs['norm'].astype(np.float32))
def __call__(self, ds_slice, gt_slice, attrs, file_name, s_idx, acc_fac):
    """Normalize a down-sampled / ground-truth slice pair and augment it.

    Both slices use the down-sampled slice's mean/std. They are clamped to
    [-6, 6], given a leading axis and passed through ``self.augment_data``.

    Returns:
        (ds_slice, gt_slice, 0)
    """
    with torch.autograd.no_grad():
        ds_tensor, mean, std = normalize_instance(to_tensor(ds_slice))
        gt_tensor = normalize(to_tensor(gt_slice), mean, std)
        ds_tensor = ds_tensor.clamp(min=-6, max=6).unsqueeze(dim=0)
        gt_tensor = gt_tensor.clamp(min=-6, max=6).unsqueeze(dim=0)
        ds_tensor, gt_tensor = self.augment_data(ds_tensor, gt_tensor)
    return ds_tensor, gt_tensor, 0
def __call__(self, kspace, target, attrs, fname, slice):
    """Return the normalized, center-cropped complex image of ``kspace``.

    Returns:
        (image, mean, std, fname, slice): ``ifft2_regular`` of the k-space,
        center cropped to ``self.resolution``, normalized per instance and
        clamped to [-6, 6], together with its mean and std.
    """
    kspace = transforms.to_tensor(kspace)
    image = transforms.ifft2_regular(kspace)
    image = transforms.complex_center_crop(
        image, (self.resolution, self.resolution))
    image, mean, std = transforms.normalize_instance(image, eps=1e-11)
    image = image.clamp(-6, 6)
    # A ``transforms.fft2(image)`` used to be computed here and never
    # returned; that dead FFT has been removed.
    return image, mean, std, fname, slice
def train_step(model, data, device):
    """Compute the training loss for one sample.

    The target is normalized per instance. The input (replaced by
    ``image_updated`` when that is non-empty) is normalized per instance as
    well. Both are clamped to [-6, 6] when CLAMP is set. The loss is smooth-L1
    when SMOOTH is set, otherwise L1.

    Args:
        model: Network applied to the input with two leading singleton axes.
        data: 7-tuple ``(input, target, mean, std, norm, _, image_updated)``.
        device: Device the input and target are moved to.
    Returns:
        The scalar loss tensor.
    """
    net_input, target, _, _, _, _, image_updated = data
    target, _, _ = transforms.normalize_instance(target, eps=1e-11)
    if CLAMP:
        target = target.clamp(-6, 6)
    if len(image_updated) != 0:
        net_input = image_updated
    net_input, _, _ = transforms.normalize_instance(net_input, eps=1e-11)
    if CLAMP:
        net_input = net_input.clamp(-6, 6)
    batch = net_input.unsqueeze(0).unsqueeze(1).to(device)
    target = target.to(device)
    prediction = model(batch).squeeze(1).squeeze(0)
    loss_fn = F.smooth_l1_loss if SMOOTH else F.l1_loss
    return loss_fn(prediction, target)
def __call__(self, kspace, target, attrs, fname):
    """Build a (k-space, target) pair from a degraded, cropped target volume.

    ``kspace`` is ignored. The target is average-pooled by
    ``self.resolution_degrading`` and center cropped to
    (self.depth, self.resolution, self.resolution). It is then normalized
    per instance and clamped to [-6, 6]. The returned k-space is
    ``rfft3_regular`` of that target.

    Side effect: caps ``self.resolution`` at ``320 // self.resolution_degrading``.

    Returns:
        (kspace, target, mean, std)
    """
    volume = transforms.to_tensor(target[:])
    volume = torch.nn.functional.avg_pool2d(volume, self.resolution_degrading)
    self.resolution = min(320 // self.resolution_degrading, self.resolution)
    crop_shape = (self.depth, self.resolution, self.resolution)
    volume = pytorch_nufft.transforms.center_crop_3d(volume, crop_shape)
    volume, mean, std = transforms.normalize_instance(volume, eps=1e-11)
    volume = volume.clamp(-6, 6)
    return pytorch_nufft.transforms.rfft3_regular(volume), volume, mean, std
def __call__(self, kspace, target, attrs, fname, slice):
    """
    Args:
        kspace (numpy.array): Input k-space of shape (num_coils, rows, cols, 2) for multi-coil
            data or (rows, cols, 2) for single coil data.
        target (numpy.array): Target image
        attrs (dict): Acquisition related information stored in the HDF5 object;
            ``attrs['norm']`` is returned.
        fname (str): File name; seeds the mask when ``self.use_seed`` is set.
        slice (int): Serial number of the slice (not used here).
    Returns:
        (tuple): tuple containing:
            image (torch.Tensor): Zero-filled magnitude image, normalized (and
                clamped when CLAMP is set).
            target_train (torch.Tensor): Target normalized with the image's
                mean/std (clamped when CLAMP is set), for training.
            mean: Mean value used for normalization.
            std: Standard deviation value used for normalization.
            norm (numpy.float32): ``attrs['norm']``.
            target (torch.Tensor): Un-normalized target tensor (for visualization).
    """
    full_kspace = transforms.to_tensor(kspace)
    seed = tuple(map(ord, fname)) if self.use_seed else None
    if self.use_mask:
        mask = transforms.get_mask(full_kspace, self.mask_func, seed)
        sampled_kspace = mask * full_kspace
    else:
        sampled_kspace = full_kspace

    zero_filled = transforms.ifft2(sampled_kspace)
    zero_filled = transforms.complex_center_crop(
        zero_filled, (self.resolution, self.resolution))
    zero_filled = transforms.complex_abs(zero_filled)
    if self.which_challenge == 'multicoil':
        zero_filled = transforms.root_sum_of_squares(zero_filled)
    zero_filled, mean, std = transforms.normalize_instance(zero_filled, eps=1e-11)

    raw_target = transforms.to_tensor(target)
    target_train = transforms.normalize(raw_target, mean, std, eps=1e-11)
    if CLAMP:
        zero_filled = zero_filled.clamp(-6, 6)
        target_train = target_train.clamp(-6, 6)
    norm = attrs['norm'].astype(np.float32)
    return zero_filled, target_train, mean, std, norm, raw_target
def data_transform(kspace, mask_function, target, data_attributes, filename, slice_num):
    """
    Perform preprocessing of the kspace image, in order to get a proper input for the net.
    Should be invoked from the SliceData class.
    Args:
        - kspace: complete sampled kspace image
        - mask_function: masking function to apply to the kspace. It is used when
          callable; otherwise the module-level ``mask_func`` is used, as before.
        - target: the target image to be reconstructed from the kspace
        - data_attributes: attributes of the whole HDF5 file; ``'max'`` is read
        - filename: not used
        - slice_num: returned unchanged
    Returns:
        - kspace_t: fully sampled kspace, normalized per instance
        - masked_image: masked kspace in the image domain (``fastmri.ifft2c``),
          center cropped to 320 x 320, permuted with (2, 0, 1) and normalized
        - normalized_target: normalized target with a leading axis added
        - mask: mask generated by the masking function
        - max_value: data_attributes['max'] (for SSIM loss)
        - slice_num
    """
    # NOTE(review): the previous TODO said the mask function is passed from
    # outside, so callers may put something non-callable (e.g. a stored mask
    # or None) in this slot. The argument used to be ignored entirely in
    # favour of the module-level ``mask_func``; now it is honored when callable.
    active_mask_func = mask_function if callable(mask_function) else mask_func

    kspace_t = transforms.to_tensor(kspace)
    kspace_t = transforms.normalize_instance(kspace_t)[0]
    # apply mask: returns masked space and generated mask
    masked_kspace, mask = transforms.apply_mask(
        data=kspace_t, mask_func=active_mask_func)
    # Apply Inverse Fourier Transform to get the complex image
    masked_image = fastmri.ifft2c(masked_kspace)
    masked_image = transforms.complex_center_crop(masked_image, (320, 320))
    # Move the last axis first, toward the PyTorch c x h x w layout.
    masked_image = masked_image.permute(2, 0, 1)
    masked_image = transforms.normalize_instance(masked_image)[0]
    target = transforms.to_tensor(target)
    target = transforms.normalize_instance(target)[0]
    target = torch.unsqueeze(target, 0)  # add dimension
    return kspace_t, masked_image, target, mask, data_attributes['max'], slice_num
def __call__(self, kspace, target, attrs, fname, slice):
    """
    Args:
        kspace (numpy.Array): k-space measurements
        target (numpy.Array): Target image (not used here)
        attrs (dict): Acquisition related information stored in the HDF5 object (not used here)
        fname (pathlib.Path): Path to the input file; seeds the mask
        slice (int): Serial number of the slice
    Returns:
        (tuple): tuple containing:
            masked_kspace (torch.Tensor): (Masked) k-space whose dim 1 is
                center clipped or zero padded to ``self.kspace_x``
            image (torch.Tensor): Normalized zero-filled input image, clamped to [-6, 6]
            mean: Mean of the zero-filled image
            std: Standard deviation of the zero-filled image
            fname (pathlib.Path): Path to the input file
            slice (int): Serial number of the slice
    """
    kspace = transforms.to_tensor(kspace)
    if self.mask_func is not None:
        seed = tuple(map(ord, fname))
        masked_kspace, _ = transforms.apply_mask(kspace, self.mask_func, seed)
    else:
        masked_kspace = kspace
    # Inverse Fourier Transform to get zero filled solution
    image = transforms.ifft2(masked_kspace)
    # Crop input image
    image = transforms.complex_center_crop(image, (self.resolution, self.resolution))
    # Absolute value
    image = transforms.complex_abs(image)
    # Apply Root-Sum-of-Squares if multicoil data
    if self.which_challenge == 'multicoil':
        image = transforms.root_sum_of_squares(image)
    # Normalize input
    image, mean, std = transforms.normalize_instance(image)
    image = image.clamp(-6, 6)

    # Fit dim 1 of the k-space to self.kspace_x by center clipping or zero
    # padding. Explicit start/stop indices handle odd differences. The
    # previous symmetric ``extra//2`` slicing gave an empty tensor for
    # extra == 1 and width kspace_x + 1 for other odd positive extras. It
    # raised a shape-mismatch error when padding by an odd amount.
    width = masked_kspace.shape[1]
    extra = int(width - self.kspace_x)
    if extra > 0:
        start = extra // 2
        masked_kspace = masked_kspace[:, start:start + self.kspace_x, :]
    elif extra < 0:
        start = (-extra) // 2
        # new_zeros keeps the input's dtype and device.
        padded_kspace = masked_kspace.new_zeros(
            (masked_kspace.shape[0], self.kspace_x, masked_kspace.shape[2]))
        padded_kspace[:, start:start + width, :] = masked_kspace
        masked_kspace = padded_kspace
    # TODO return mask as well for exclusive updates
    return masked_kspace, image, mean, std, fname, slice
def inference(model, data, device):
    """Run ``model`` on one sample and map the output back to input scale.

    The input (replaced by ``image_updated`` when that is non-empty) is
    normalized per instance and clamped when CLAMP is set, then fed to the
    model with two leading singleton axes. The output is unnormalized with
    the statistics of that normalization.

    Args:
        model: Network to evaluate.
        data: 7-tuple ``(input, target, mean, std, norm, unnormalized_target,
            image_updated)``. Only ``input``, ``unnormalized_target`` and
            ``image_updated`` affect the result, because the incoming
            ``mean``/``std`` are overwritten by the normalization below.
        device: Device the input, statistics and target are moved to.
    Returns:
        (output, unnormalized_target)
    """
    net_input, _, mean, std, _, unnormalized_target, image_updated = data
    # ``target`` used to be normalized here, but the result was never used.
    if len(image_updated) != 0:
        net_input = image_updated
    net_input, mean, std = transforms.normalize_instance(net_input, eps=1e-11)
    if CLAMP:
        net_input = net_input.clamp(-6, 6)
    net_input = net_input.unsqueeze(0).unsqueeze(1).to(device)
    if len(unnormalized_target) != 0:
        unnormalized_target = unnormalized_target.to(device)
    output = model(net_input).squeeze(1).squeeze(0)
    # Add three leading axes so the statistics broadcast against the output.
    mean = mean.unsqueeze(0).unsqueeze(1).unsqueeze(2).to(device)
    std = std.unsqueeze(0).unsqueeze(1).unsqueeze(2).to(device)
    output = transforms.unnormalize(output, mean, std)
    return output, unnormalized_target
def to_spatial(self, kspace, resolution):
    '''Convert (enhanced) k-space into a normalized magnitude image.

    Args:
        kspace: pytorch tensor post enhancement
        resolution (int): side length of the square center crop
    Returns:
        Magnitude of the inverse FFT, center cropped to (resolution, resolution),
        normalized per instance and clamped to [-6, 6].
    '''
    cropped = transforms.complex_center_crop(
        transforms.ifft2(kspace), (resolution, resolution))
    normalized, _, _ = transforms.normalize_instance(
        transforms.complex_abs(cropped), eps=1e-11)
    return normalized.clamp(-6, 6)
def kspacetoimage(kspace, args):
    """Return the zero-filled magnitude image of ``kspace``, normalized and clamped.

    Args:
        kspace: k-space tensor accepted by ``transforms.ifft2``.
        args: Namespace providing ``resolution`` (crop side) and ``challenge``.
    Returns:
        The magnitude image, center cropped to (args.resolution, args.resolution).
        It is RSS-combined when ``args.challenge == 'multicoil'``, normalized
        per instance and clamped to [-6, 6].
    """
    side = args.resolution
    magnitude = transforms.complex_abs(
        transforms.complex_center_crop(transforms.ifft2(kspace), (side, side)))
    if args.challenge == 'multicoil':
        magnitude = transforms.root_sum_of_squares(magnitude)
    normalized, _, _ = transforms.normalize_instance(magnitude, eps=1e-11)
    return normalized.clamp(-6, 6)
def nkspacetoimage(args, kspace_fni, mean, std, eps=1e-11):
    """Turn k-space of a normalized image back into a normalized magnitude image.

    Applies the inverse (unnormalized ``torch.ifft``) of the centred FFT and
    undoes the image normalization with ``image * std + mean``. It then keeps
    element 0 of the leading axis, center crops to ``args.resolution``, takes
    the magnitude, normalizes per instance and clamps to [-6, 6].

    Args:
        args: Namespace; ``args.resolution`` sets the crop size.
        kspace_fni: k-space of a normalized image; its last axis must be 2.
        mean: Mean used to undo the image normalization.
        std: Standard deviation used to undo the image normalization.
        eps: Stabilizer for the final per-instance normalization. It was
            previously accepted but ignored; its default equals the value
            that used to be hard-coded.
    Returns:
        The normalized, clamped magnitude image.
    """
    assert kspace_fni.size(-1) == 2
    image = transforms.ifftshift(kspace_fni, dim=(-3, -2))
    image = torch.ifft(image, 2)
    image = transforms.fftshift(image, dim=(-3, -2))
    # denormalizing the image
    image = (image * std) + mean
    image = image[0]
    image = transforms.complex_center_crop(image, (args.resolution, args.resolution))
    # Absolute value
    image = transforms.complex_abs(image)
    # Normalize the output magnitude (its statistics are not returned)
    image, _, _ = transforms.normalize_instance(image, eps=eps)
    return image.clamp(-6, 6)