def cs_total_variation(args, kspace, acquisition, acceleration, num_low_freqs):
    """
    Reconstruct one slice with ESPIRiT sensitivity estimation followed by
    Total-Variation-regularised compressed sensing, both run through BART.

    Args:
        args: Parsed arguments; ``challenge``, ``num_iters`` and
            ``resolution`` are read.
        kspace (torch.Tensor): Under-sampled k-space. For the 'singlecoil'
            challenge a leading coil axis is inserted before use.
            NOTE(review): presumably (coils, rows, cols, 2) with real/imag
            in the last axis — confirm against the caller.
        acquisition (str): Acquisition protocol; must be a key of
            ``REG_PARAM[args.challenge]``.
        acceleration (int): Acceleration factor; only 4 and 8 are accepted.
        num_low_freqs (int): Calibration region size passed to ``ecalib -r``.

    Returns:
        torch.Tensor: Magnitude reconstruction, center-cropped to at most
        ``args.resolution`` along each of its last two dimensions.

    Raises:
        ValueError: If the acquisition protocol or acceleration factor is
            not supported.
    """
    challenge = args.challenge
    weights_for_challenge = REG_PARAM[challenge]
    if acquisition not in weights_for_challenge:
        raise ValueError(f'Invalid acquisition protocol: {acquisition}')
    if acceleration not in {4, 8}:
        raise ValueError(f'Invalid acceleration factor: {acceleration}')

    if challenge == 'singlecoil':
        kspace = kspace.unsqueeze(0)
    # Move the coil axis behind the spatial axes and add a leading singleton
    # axis before handing the data to BART as a complex NumPy array.
    bart_input = tensor_to_complex_np(kspace.permute(1, 2, 0, 3).unsqueeze(0))

    # ESPIRiT sensitivity maps, calibrated on num_low_freqs central lines.
    sens_maps = bart.bart(1, f'ecalib -d0 -m1 -r {num_low_freqs}', bart_input)

    # TV-regularised parallel-imaging reconstruction.
    reg_wt = weights_for_challenge[acquisition][acceleration]
    tv_command = f'pics -d0 -S -R T:7:0:{reg_wt} -i {args.num_iters}'
    recon = bart.bart(1, tv_command, bart_input, sens_maps)
    magnitude = torch.from_numpy(np.abs(recon[0]))

    # Never ask for a crop larger than the reconstruction itself.
    crop_height = min(args.resolution, magnitude.shape[-2])
    crop_width = min(args.resolution, magnitude.shape[-1])
    return transforms.center_crop(magnitude, (crop_height, crop_width))
def resize(hparams, image, target):
    """
    Center-crop a complex image (and optional target) and normalise them.

    The crop size is the smallest of ``hparams.resolution``, the image size
    and, when given, the target size, taken separately per axis.

    Args:
        hparams: Hyper-parameters; ``resolution`` and ``challenge`` are read.
        image (torch.Tensor): Complex image; dim -3 is treated as height and
            dim -2 as width (presumably the last dim holds real/imag parts).
        target (numpy.array or None): Target image, or None.

    Returns:
        tuple: (cropped complex image,
                normalised magnitude image clamped to [-6, 6],
                normalised target clamped to [-6, 6], or ``torch.Tensor([0])``
                when ``target`` is None,
                mean, std) — mean/std are those of the magnitude image.
    """
    max_height = min(hparams.resolution, image.shape[-3])
    max_width = min(hparams.resolution, image.shape[-2])
    if target is not None:
        max_height = min(max_height, target.shape[-2])
        max_width = min(max_width, target.shape[-1])
    crop_size = (max_height, max_width)

    cropped = transforms.complex_center_crop(image, crop_size)
    magnitude = transforms.complex_abs(cropped)
    if hparams.challenge == "multicoil":
        # Multi-coil data: apply root-sum-of-squares.
        magnitude = transforms.root_sum_of_squares(magnitude)
    magnitude, mean, std = transforms.normalize_instance(magnitude, eps=1e-11)
    magnitude = magnitude.clamp(-6, 6)

    if target is None:
        return cropped, magnitude, torch.Tensor([0]), mean, std

    # The target is normalised with the *input's* statistics.
    target_tensor = transforms.center_crop(transforms.to_tensor(target), crop_size)
    target_tensor = transforms.normalize(target_tensor, mean, std, eps=1e-11).clamp(-6, 6)
    return cropped, magnitude, target_tensor, mean, std
def __call__(self, image):
    """
    Turn a DICOM pixel array into a normalised square tensor.

    The image is flipped along axis 0 and center-cropped to a square whose
    side is its shorter edge. It is then resized to
    ``self.resolution`` x ``self.resolution`` with bicubic interpolation,
    instance-normalised and clamped.

    Args:
        image (numpy.array): DICOM image.

    Returns:
        torch.Tensor: The normalised image, clamped to [-6, 6].
    """
    flipped = np.flip(image, 0)
    side = min(flipped.shape[0], flipped.shape[1])
    square = transforms.center_crop(flipped, (side, side))

    output_size = (self.resolution, self.resolution)
    resized = cv2.resize(square, dsize=output_size, interpolation=cv2.INTER_CUBIC)

    normalised, _mean, _std = transforms.normalize_instance(
        transforms.to_tensor(resized), eps=1e-11)
    return normalised.clamp(-6, 6)
def __call__(self, k_space, mask, target, attrs, f_name, slice):
    """
    Crop k-space in the image domain and normalise the target.

    The k-space is inverse-Fourier-transformed, center-cropped to
    ``self.resolution`` x ``self.resolution`` and transformed back. The
    cropped complex image supplies the mean/std used to normalise the target.

    Args:
        k_space (numpy.array): Input k-space of shape (num_coils, rows, cols, 2)
            for multi-coil data or (rows, cols, 2) for single coil data.
        mask (numpy.array): Mask from the test dataset (not used here).
        target (numpy.array): Target image.
        attrs (dict): Acquisition related information stored in the HDF5
            object (not used here).
        f_name (str): File name.
        slice (int): Serial number of the slice.

    Returns:
        (tuple): tuple containing:
            k_space (torch.Tensor): k-space of the cropped image
                (resolution x resolution x 2 for single-coil input).
            target (torch.Tensor): Target normalised with the cropped
                image's statistics and clamped to [-6, 6].
            f_name (str): File name.
            slice (int): Serial number of the slice.
    """
    k_space = transforms.to_tensor(k_space)
    full_image = transforms.ifft2(k_space)
    cropped_image = transforms.complex_center_crop(
        full_image, (self.resolution, self.resolution))
    k_space = transforms.fft2(cropped_image)

    # Only the statistics are needed: the normalised image itself is never
    # returned, so it is discarded instead of being clamped for nothing.
    _, mean, std = transforms.normalize_instance(cropped_image, eps=1e-11)

    # Normalize target
    target = transforms.to_tensor(target)
    target = transforms.center_crop(target, (self.resolution, self.resolution))
    target = transforms.normalize(target, mean, std, eps=1e-11)
    target = target.clamp(-6, 6)
    return k_space, target, f_name, slice
def __call__(self, img, fname, slice):
    """
    Convert an image to a center-cropped, normalised tensor.

    Args:
        img (numpy.array): Image whose axis 2 is moved to the front before
            cropping (presumably a channel axis — confirm against dataset).
        fname (str): File name; passed through unchanged.
        slice (int): Slice index; accepted for interface compatibility,
            not used.

    Returns:
        tuple: (image cropped to ``self.resolution`` square, normalised and
        clamped to [-6, 6], mean, std, fname)
    """
    side = self.resolution
    channels_first = transforms.to_tensor(img).permute(2, 0, 1)
    cropped = transforms.center_crop(channels_first, (side, side))
    normalised, mean, std = transforms.normalize_instance(cropped, eps=1e-11)
    return normalised.clamp(-6, 6), mean, std, fname
def test_step(self, batch, batch_idx):
    """
    Run the model on one test batch and collect the cropped reconstruction.

    Args:
        batch: Six-element sequence. Element 0 is passed to ``forward`` as the
            masked k-space and element 1 as the mask. Elements 3 and 4 are
            returned as ``fname`` and ``slice``; elements 2 and 5 are ignored.
        batch_idx (int): Batch index (not used).

    Returns:
        dict: ``'fname'`` and ``'slice'`` taken from the batch, and
        ``'output'``: the model output center-cropped to a
        ``self.hparams.resolution`` square, as a NumPy array on the CPU.
    """
    kspace_in, sampling_mask, _, file_names, slice_ids, _ = batch
    prediction = self.forward(kspace_in, sampling_mask)
    side = self.hparams.resolution
    prediction = T.center_crop(prediction, (side, side))
    return dict(
        fname=file_names,
        slice=slice_ids,
        output=prediction.cpu().numpy(),
    )
def __call__(self, kspace, mask, target, attrs, fname, slice):
    """
    Build a normalised zero-filled magnitude image (and target) from k-space.

    Args:
        kspace (numpy.array): Input k-space of shape (num_coils, rows, cols, 2)
            for multi-coil data or (rows, cols, 2) for single coil data.
        mask (numpy.array): Mask from the test dataset. Not read here: when
            ``self.mask_func`` is set a mask is generated by ``apply_mask``.
        target (numpy.array or None): Target image.
        attrs (dict): Acquisition related information stored in the HDF5
            object (not used here).
        fname (str): File name; its character codes seed the mask when
            ``self.use_seed`` is set.
        slice (int): Serial number of the slice.

    Returns:
        (tuple): tuple containing:
            image (torch.Tensor): Zero-filled input image, normalised and
                clamped to [-6, 6].
            target (torch.Tensor): Target normalised with the input's
                statistics and clamped to [-6, 6], or ``torch.Tensor([0])``
                when no target is given.
            mean (float): Mean value used for normalization.
            std (float): Standard deviation value used for normalization.
            fname (str): File name.
            slice (int): Serial number of the slice.
    """
    kspace_t = transforms.to_tensor(kspace)
    if self.mask_func:
        seed = tuple(map(ord, fname)) if self.use_seed else None
        kspace_t, _ = transforms.apply_mask(kspace_t, self.mask_func, seed)

    # Inverse FFT of the (possibly masked) k-space: the zero-filled solution.
    zero_filled = transforms.ifft2(kspace_t)

    # Crop no larger than the resolution, the image, or the target.
    crop_height = min(self.resolution, zero_filled.shape[-3])
    crop_width = min(self.resolution, zero_filled.shape[-2])
    if target is not None:
        crop_height = min(crop_height, target.shape[-2])
        crop_width = min(crop_width, target.shape[-1])
    crop_size = (crop_height, crop_width)

    magnitude = transforms.complex_abs(
        transforms.complex_center_crop(zero_filled, crop_size))
    if self.which_challenge == 'multicoil':
        magnitude = transforms.root_sum_of_squares(magnitude)
    magnitude, mean, std = transforms.normalize_instance(magnitude, eps=1e-11)
    magnitude = magnitude.clamp(-6, 6)

    if target is None:
        target_out = torch.Tensor([0])
    else:
        target_out = transforms.center_crop(transforms.to_tensor(target), crop_size)
        target_out = transforms.normalize(target_out, mean, std, eps=1e-11).clamp(-6, 6)

    return magnitude, target_out, mean, std, fname, slice
def evaluate(args, recons_key):
    """
    Compute reconstruction metrics over every target file in a directory.

    For each file in ``args.target_path``, the prediction with the same name
    in ``args.predictions_path`` is loaded. Both volumes are center-cropped to
    a square whose side is the target's last dimension, and the pair is
    pushed into a ``Metrics`` accumulator.

    Files failing the optional ``args.acquisition`` / ``args.acceleration``
    filters are skipped *before* the prediction file is opened. Predictions
    therefore only need to exist for the selected subset.

    Args:
        args: Parsed arguments. Reads ``target_path`` and ``predictions_path``
            (path objects supporting ``iterdir()`` and ``/``), plus the
            ``acquisition`` and ``acceleration`` filters (falsy = no filter).
        recons_key (str): Dataset key of the ground truth in target files.

    Returns:
        Metrics: The accumulator after all selected volumes were pushed.
    """
    metrics = Metrics(METRIC_FUNCS)
    for tgt_file in args.target_path.iterdir():
        with h5py.File(tgt_file, 'r') as target_h5:
            if args.acquisition and args.acquisition != target_h5.attrs['acquisition']:
                continue
            if args.acceleration and target_h5.attrs['acceleration'] != args.acceleration:
                continue
            target = target_h5[recons_key][()]

        with h5py.File(args.predictions_path / tgt_file.name, 'r') as recons_h5:
            recons = recons_h5['reconstruction'][()]

        # Square crop whose side is the target's last dimension.
        crop = (target.shape[-1], target.shape[-1])
        target = transforms.center_crop(target, crop)
        recons = transforms.center_crop(recons, crop)
        metrics.push(target, recons)
    return metrics
def k_space_to_image_with_mask(kspace, mask_func=None, seed=None, crop_size=(320, 320)):
    """
    Convert k-space to a normalised magnitude image, optionally masking it first.

    Args:
        kspace (torch.Tensor): k-space as accepted by ``transforms.apply_mask``
            and ``transforms.ifft2`` (presumably already a tensor with a
            trailing real/imag axis — confirm against callers).
        mask_func (callable, optional): Mask generator. When given, the
            k-space is under-sampled before the inverse FFT, producing a
            zero-filled reconstruction.
        seed (optional): Seed forwarded to ``transforms.apply_mask``.
        crop_size (tuple of int): (height, width) of the center crop applied
            to the magnitude image. Defaults to (320, 320), the previously
            hard-coded size.

    Returns:
        torch.Tensor: Instance-normalised magnitude image clamped to [-6, 6].
    """
    if mask_func:
        kspace, _ = transforms.apply_mask(kspace, mask_func, seed)
    image = transforms.complex_abs(transforms.ifft2(kspace))
    image = transforms.center_crop(image, crop_size)
    image, _mean, _std = transforms.normalize_instance(image, eps=1e-11)
    return image.clamp(-6, 6)
def cs_total_variation(args, kspace):
    """
    Reconstruct one slice with ESPIRiT sensitivity estimation followed by
    Total-Variation-regularised compressed sensing, both run through BART.

    Args:
        args: Parsed arguments; ``challenge``, ``reg_wt``, ``num_iters`` and
            ``resolution`` are read.
        kspace (torch.Tensor): Under-sampled k-space. For the 'singlecoil'
            challenge a leading coil axis is inserted before use.
            NOTE(review): presumably (coils, rows, cols, 2) with real/imag
            in the last axis — confirm against the caller.

    Returns:
        torch.Tensor: Magnitude reconstruction, center-cropped to
        ``args.resolution`` along each of its last two dimensions, or to the
        image's own size along any dimension that is smaller.
    """
    if args.challenge == 'singlecoil':
        kspace = kspace.unsqueeze(0)
    # Move the coil axis behind the spatial axes and add a leading singleton
    # axis before handing the data to BART as a complex NumPy array.
    kspace = kspace.permute(1, 2, 0, 3).unsqueeze(0)
    kspace = tensor_to_complex_np(kspace)

    # Estimate sensitivity maps
    sens_maps = bart.bart(1, 'ecalib -d0 -m1', kspace)

    # Use Total Variation Minimization to reconstruct the image
    pred = bart.bart(
        1, f'pics -d0 -S -R T:7:0:{args.reg_wt} -i {args.num_iters}', kspace,
        sens_maps)
    pred = torch.from_numpy(np.abs(pred[0]))

    # Clamp the crop to the image size so reconstructions smaller than
    # args.resolution are not over-cropped.
    # NOTE(review): assumes transforms.center_crop cannot pad — confirm.
    crop_height = min(args.resolution, pred.shape[-2])
    crop_width = min(args.resolution, pred.shape[-1])
    return transforms.center_crop(pred, (crop_height, crop_width))
def evaluate():
    """
    Compute metrics for every reconstruction under ``summary/<test_name>/rec``.

    Command-line arguments are parsed here. For each prediction file, the
    target file of the same name under ``<data_path>/<data_split>`` is read
    from its ``'data'`` dataset. The target is moved axis-2-first,
    center-cropped to 144 x 144, instance-normalised and clamped to [-6, 6].

    When the normalised target's maximum is non-zero, target and
    reconstruction are both min-max scaled to [0, 1]. The pair is then
    pushed into a ``Metrics`` accumulator.

    Returns:
        Metrics: The accumulator after all prediction files were processed.
    """
    args = create_arg_parser().parse_args()
    args.target_path = f'{args.data_path}/{args.data_split}'
    args.predictions_path = f'summary/{args.test_name}/rec'
    target_dir = pathlib.Path(args.target_path)

    metrics = Metrics(METRIC_FUNCS)
    for rcn_file in pathlib.Path(args.predictions_path).iterdir():
        # Open read-only explicitly: before h5py 3.0 the default mode was
        # 'a', which can create or modify files.
        with h5py.File(rcn_file, 'r') as recons_h5, h5py.File(
                target_dir / rcn_file.name, 'r') as target_h5:
            target = target_h5['data'][()]
            target = transforms.to_tensor(target)
            target = transforms.center_crop(target.permute(2, 0, 1), (144, 144))
            target, mean, std = transforms.normalize_instance(target, eps=1e-11)
            target = target.clamp(-6, 6)
            recons = recons_h5['reconstruction'][()]

            # Min-max scale both volumes to [0, 1]; skipped when the target
            # maximum is 0 (guards the division below).
            if target.max() != 0:
                target -= target.min()
                target /= target.max()
                recons -= recons.min()
                recons /= recons.max()
            # NOTE(review): confirm whether all-zero targets should instead be
            # excluded from the metrics (push placed after the guard).
            metrics.push(target.numpy(), recons.squeeze())
    return metrics
def test_center_crop(shape, target_shape):
    """Center-cropping a generated tensor yields exactly ``target_shape``."""
    data = create_input(shape)
    cropped = transforms.center_crop(data, target_shape)
    assert list(cropped.numpy().shape) == target_shape