def __call__(self, sample):
    """
    Attach a ``"sensitivity_map"`` to ``sample`` according to ``self.type_of_map``.

    Parameters
    ----------
    sample: dict
        For ``"unit"`` it must contain ``"kspace"``, a named tensor whose last dimension is
        named ``"complex"``. For ``"rss_estimate"`` it is passed to ``self.estimate_acs_image``.

    Returns
    -------
    data dictionary
        ``sample`` with ``"sensitivity_map"`` set. For ``"unit"``, an already present map is
        returned untouched.

    Raises
    ------
    NotImplementedError
        If ``"unit"`` is requested and the last k-space dimension is not named ``"complex"``.
    ValueError
        If ``self.type_of_map`` is neither ``"unit"`` nor ``"rss_estimate"``.
    """
    if self.type_of_map == "unit":
        # TODO(kp) Figure out a way to skip this class entirely if sensitivity map already in sample and est = false
        # (kp) Keep an already existing sensitivity map from being altered.
        if "sensitivity_map" in sample:
            return sample
        kspace = sample["kspace"]
        # TODO(jt): Named variant, this assumes the complex channel is last.
        # Validate before allocating anything.
        if kspace.names[-1] != "complex":
            raise NotImplementedError("Assuming last channel is complex.")
        # Allocate directly on the k-space device instead of building on CPU and copying over.
        sensitivity_map = torch.zeros(kspace.shape, dtype=torch.float32, device=kspace.device)
        # Channel 0 of the complex axis (presumably the real part) is 1, the other channel stays 0.
        sensitivity_map[..., 0] = 1.0
        sample["sensitivity_map"] = sensitivity_map.refine_names(*kspace.names)
    elif self.type_of_map == "rss_estimate":
        acs_image = self.estimate_acs_image(sample)
        # Normalise each coil image by the root-sum-of-squares over coils.
        # NOTE(review): presumably safe_divide maps zero-RSS pixels to 0 -- confirm.
        acs_image_rss = T.root_sum_of_squares(acs_image, dim="coil").align_as(acs_image)
        sample["sensitivity_map"] = T.safe_divide(acs_image, acs_image_rss)
    else:
        raise ValueError(
            f"Expected type of map to be either `unit` or `rss_estimate`. Got {self.type_of_map}."
        )
    return sample
def estimate_sensitivity_map(self, sample):
    """
    Estimate coil sensitivity maps from the ACS-masked k-space of ``sample``.

    The ACS-masked k-space is back-projected with ``self.backward_operator`` and each coil
    image is divided by the root-sum-of-squares over the ``"coil"`` dimension; pixels where
    that RSS is exactly zero are set to zero.

    Parameters
    ----------
    sample: dict
        Must contain ``self.kspace_key`` and ``"acs_mask"``.

    Returns
    -------
    torch.Tensor or dict
        The estimated sensitivity map, carrying the names of the back-projected image.
        For single-coil data the unchanged ``sample`` is returned instead (see NOTE below).
    """
    kspace_data = sample[self.kspace_key]

    # NOTE(review): assumes the coil axis is dimension 0 -- confirm.
    if kspace_data.shape[0] == 1:
        warnings.warn(
            "Single-coil data, skipping estimation of sensitivity map. "
            "This warning will be displayed only once."
        )
        # NOTE(review): this branch returns the sample dict while the normal path returns a
        # tensor; a caller storing the result as the sensitivity map would store the dict.
        # Confirm the intended return value for single-coil data.
        return sample

    if "sensitivity_map" in sample:
        warnings.warn(
            "`sensitivity_map` is given, but will be overwritten. "
            "This warning will be displayed only once."
        )

    kspace_acs = T.apply_mask(kspace_data, sample["acs_mask"], return_mask=False)

    # Get complex-valued data solution
    image = self.backward_operator(kspace_acs)
    rss_image = T.root_sum_of_squares(image, dim="coil").align_as(image)

    # TODO(jt): Safe divide.
    # Names are dropped for torch.where and restored afterwards; zero-RSS pixels become 0.
    sensitivity_mask = torch.where(
        rss_image.rename(None) == 0,
        torch.tensor([0.0], dtype=rss_image.dtype).to(rss_image.device),
        (image / rss_image).rename(None),
    ).refine_names(*image.names)

    return sensitivity_mask
def __call__(self, sample: Dict[str, Any]):
    """
    Mask the k-space of a sample and crop it, either via a random image crop or a central k-space crop.

    Parameters
    ----------
    sample: dict
        Must contain ``"kspace"`` and ``"filename"``. An optional ``"sampling_mask"`` takes
        precedence over ``self.mask_func``; an optional ``"sensitivity_map"`` is cropped along.

    Returns
    -------
    data dictionary
        The same dict with ``"kspace"`` removed and ``"target"``, ``"masked_kspace"``,
        ``"sampling_mask"`` (and ``"sensitivity_map"`` when one was given) set.
    """
    kspace = sample["kspace"]
    sensitivity_map = sample.get("sensitivity_map", None)
    filename = sample["filename"]

    # A mask supplied by the dataset wins over ``self.mask_func``.
    if "sampling_mask" not in sample:
        mask_func = self.mask_func
    else:
        if self.mask_func is not None:
            warnings.warn(
                "`sampling_mask` is passed by the Dataset class, yet `mask_func` is also set. "
                "This will be ignored and the `sampling_mask` will be used instead. "
                "Be aware of this as it can lead to unexpected results. "
                "This warning will be issued only once."
            )
        mask_func = sample["sampling_mask"]

    # A filename-derived seed makes the generated mask reproducible per file.
    seed = tuple(map(ord, str(filename))) if self.use_seed else None

    use_image_crop = np.random.random() >= self.kspace_crop_probability
    if use_image_crop:
        # Crop first, then mask the cropped k-space.
        # NOTE(review): a dataset-provided ``sampling_mask`` presumably has to match the
        # cropped k-space shape here -- confirm.
        kspace, backprojected_kspace, sensitivity_map = self.__random_image_crop(kspace, sensitivity_map)
        masked_kspace, sampling_mask = T.apply_mask(kspace, mask_func, seed)
    else:
        # Mask the full k-space first, then crop everything centrally in k-space.
        masked_kspace, sampling_mask = T.apply_mask(kspace, mask_func, seed)
        kspace, masked_kspace, sampling_mask, backprojected_kspace, sensitivity_map = self.__central_kspace_crop(
            kspace, masked_kspace, sampling_mask, sensitivity_map
        )

    sample["target"] = T.root_sum_of_squares(backprojected_kspace, dim="coil")
    del sample["kspace"]
    sample["masked_kspace"] = masked_kspace
    sample["sampling_mask"] = sampling_mask

    if sensitivity_map is not None:
        sample["sensitivity_map"] = sensitivity_map

    return sample
def test_root_sum_of_squares_complex(shape, dims):
    """Compare ``transforms.root_sum_of_squares`` on complex-valued input with a NumPy reference."""
    # Append a trailing axis of size 2, presumably holding real and imaginary parts.
    complex_shape = shape + [2]
    data = create_input(complex_shape, named=True)  # noqa
    actual = transforms.root_sum_of_squares(data, dims).numpy()

    reference_input = tensor_to_complex_numpy(data)
    # The named "coils" reduction corresponds to axis 0 of the NumPy array.
    axis = 0 if dims == "coils" else dims
    expected = np.sqrt(np.sum(np.abs(reference_input) ** 2, axis))

    assert np.allclose(actual, expected)
def __call__(self, sample):
    """
    Store a root-sum-of-squares image of the ACS-masked k-space as ``"body_coil_image"``.

    The ACS mask is generated for the current k-space shape (all dimensions except the
    first), because the k-space may already have been cropped.
    """
    kspace = sample["kspace"]

    if self.use_seed:
        seed = tuple(map(ord, str(sample["filename"])))
    else:
        seed = None

    acs_mask = self.mask_func(kspace.shape[1:], seed, return_acs=True)
    # NOTE(review): the purpose of ``+ 0.0`` is not visible here; it is kept deliberately.
    acs_kspace = acs_mask * kspace + 0.0

    coil_images = self.backward_operator(acs_kspace)
    sample["body_coil_image"] = T.root_sum_of_squares(coil_images, dim="coil")
    return sample
def __call__(self, sample: Dict[str, Any]):
    """
    Optionally crop the sample in image space, then (re)compute and mask its k-space.

    Parameters
    ----------
    sample: dict
        Must contain ``"kspace"`` and ``"sampling_mask"``. Whichever of ``"sensitivity_map"``
        and ``"input_image"`` are present are cropped together with the back-projected k-space.

    Returns
    -------
    data dictionary
        ``sample`` with ``"target"``, ``"masked_kspace"``, ``"sampling_mask"`` and ``"kspace"``
        (the possibly cropped k-space) set; present croppable images are replaced by their crops.
    """
    kspace = sample["kspace"]

    # Image-space croppable objects
    croppable_images = ["sensitivity_map", "input_image"]
    sampling_mask = sample["sampling_mask"]
    backprojected_kspace = self.backward_operator(kspace)

    # TODO: Also create a kspace-like crop function
    if self.crop:
        # Only keys that are actually present are passed to the crop function. Remember exactly
        # which ones, so each cropped output (after the image) is written back to its own key.
        present_keys = [key for key in croppable_images if key in sample]
        cropped_output = self.crop_func(
            [backprojected_kspace, *[sample[key] for key in present_keys]],
            self.crop,
            contiguous=True,
        )
        backprojected_kspace = cropped_output[0]
        for key, cropped in zip(present_keys, cropped_output[1:]):
            sample[key] = cropped

        # Compute new k-space for the cropped input_image
        kspace = self.forward_operator(backprojected_kspace)

    # NOTE(review): after a crop, a precomputed ``sampling_mask`` tensor presumably must match
    # the cropped k-space shape -- confirm against T.apply_mask.
    masked_kspace, sampling_mask = T.apply_mask(kspace, sampling_mask)

    sample["target"] = T.root_sum_of_squares(backprojected_kspace, dim="coil")
    sample["masked_kspace"] = masked_kspace
    sample["sampling_mask"] = sampling_mask
    sample["kspace"] = kspace  # The cropped kspace
    # ``sample["sensitivity_map"]`` already holds the cropped (or untouched) map; re-storing a
    # copy captured before cropping would undo the crop.
    return sample
def __call__(self, sample):
    """
    Compute a reconstruction target from the k-space stored under ``self.kspace_key``.

    ``self.type_reconstruction`` is matched case-insensitively:
    ``'complex'`` sums the coil images, ``'rss'`` takes the root-sum-of-squares over coils,
    and ``'sense'`` is not implemented. The result is stored under ``self.target_key``.

    Raises
    ------
    ValueError
        For ``'sense'`` when ``sample`` has no ``'sensitivity_map'``.
    NotImplementedError
        For ``'sense'`` otherwise.
    """
    kspace_data = sample[self.kspace_key]

    # Get complex-valued image solution
    image = self.backward_operator(kspace_data)

    # Normalise once so every mode (not only 'rss') accepts any casing.
    reconstruction_type = self.type_reconstruction.lower()
    if reconstruction_type == 'complex':
        sample[self.target_key] = image.sum('coil')
    elif reconstruction_type == 'rss':
        sample[self.target_key] = transforms.root_sum_of_squares(image, dim='coil')
    elif reconstruction_type == 'sense':
        if 'sensitivity_map' not in sample:
            raise ValueError('Sensitivity map is required for SENSE reconstruction.')
        raise NotImplementedError('SENSE is not implemented.')

    return sample
def __call__(self, sample):
    """
    Compute a reconstruction target from the k-space stored under ``self.kspace_key``.

    ``self.type_reconstruction`` is matched case-insensitively:
    ``"complex"`` sums the coil images, ``"rss"`` takes the root-sum-of-squares over coils,
    and ``"sense"`` is not implemented. The result is stored under ``self.target_key``.

    Raises
    ------
    ValueError
        For ``"sense"`` when ``sample`` has no ``"sensitivity_map"``.
    NotImplementedError
        For ``"sense"`` otherwise.
    """
    kspace_data = sample[self.kspace_key]

    # Get complex-valued data solution
    image = self.backward_operator(kspace_data)

    # Normalise once so every mode (not only "rss") accepts any casing.
    reconstruction_type = self.type_reconstruction.lower()
    if reconstruction_type == "complex":
        sample[self.target_key] = image.sum("coil")
    elif reconstruction_type == "rss":
        sample[self.target_key] = T.root_sum_of_squares(image, dim="coil")
    elif reconstruction_type == "sense":
        if "sensitivity_map" not in sample:
            raise ValueError("Sensitivity map is required for SENSE reconstruction.")
        raise NotImplementedError("SENSE is not implemented.")

    return sample
def __call__(self, sample: Dict[str, Any]):
    """
    Mask the k-space of a sample with ``self.mask_func`` and crop it, either via a random image crop or a central k-space crop.

    Parameters
    ----------
    sample: dict
        Must contain ``'kspace'`` and ``'filename'``; an optional ``'sensitivity_map'`` is
        cropped along. A dataset-provided ``'sampling_mask'`` combined with a set
        ``self.mask_func`` is rejected with ``NotImplementedError``.

    Returns
    -------
    data dictionary
        The same dict with ``'kspace'`` removed and ``'target'``, ``'masked_kspace'``,
        ``'sampling_mask'`` (and ``'sensitivity_map'`` when one was given) set.
    """
    kspace = sample['kspace']
    sensitivity_map = sample.get('sensitivity_map', None)
    filename = sample['filename']

    has_dataset_mask = 'sampling_mask' in sample
    if has_dataset_mask and self.mask_func is not None:
        warnings.warn('`sampling_mask` is passed by the Dataset class, yet `mask_func` is also set. '
                      'This will be ignored and the `sampling_mask` will be used instead. '
                      'Be aware of this as it can lead to unexpected results. '
                      'This warning will be issued only once.')
        # Combining a dataset mask with this cropping scheme is deliberately unsupported here.
        raise NotImplementedError('This is required when a mask is present,'
                                  ' but in this case this should be applied differently!')

    # A filename-derived seed makes the generated mask reproducible per file.
    seed = tuple(map(ord, str(filename))) if self.use_seed else None

    use_image_crop = np.random.random() >= self.kspace_crop_probability
    if use_image_crop:
        # Crop first, then mask the cropped k-space.
        kspace, backprojected_kspace, sensitivity_map = self.__random_image_crop(kspace, sensitivity_map)
        masked_kspace, sampling_mask = transforms.apply_mask(kspace, self.mask_func, seed)
    else:
        # Mask the full k-space first, then crop everything centrally in k-space.
        masked_kspace, sampling_mask = transforms.apply_mask(kspace, self.mask_func, seed)
        kspace, masked_kspace, sampling_mask, backprojected_kspace, sensitivity_map = self.__central_kspace_crop(
            kspace, masked_kspace, sampling_mask, sensitivity_map)

    sample['target'] = transforms.root_sum_of_squares(backprojected_kspace, dim='coil')
    del sample['kspace']
    sample['masked_kspace'] = masked_kspace
    sample['sampling_mask'] = sampling_mask

    if sensitivity_map is not None:
        sample['sensitivity_map'] = sensitivity_map

    return sample
def __call__(self, sample):
    """
    Attach a ``"sensitivity_map"`` to ``sample`` according to ``self.type_of_map``.

    Parameters
    ----------
    sample: dict
        For ``"unit"`` it must contain ``"kspace"``, a named tensor whose last dimension is
        named ``"complex"``. For ``"rss_estimate"`` it is passed to ``self.estimate_acs_image``.

    Returns
    -------
    data dictionary
        ``sample`` with ``"sensitivity_map"`` set (overwriting any existing map).

    Raises
    ------
    NotImplementedError
        If ``"unit"`` is requested and the last k-space dimension is not named ``"complex"``.
    ValueError
        If ``self.type_of_map`` is neither ``"unit"`` nor ``"rss_estimate"``.
    """
    if self.type_of_map == "unit":
        kspace = sample["kspace"]
        # TODO(jt): Named variant, this assumes the complex channel is last.
        # Validate before allocating anything.
        if kspace.names[-1] != "complex":
            raise NotImplementedError("Assuming last channel is complex.")
        # Allocate directly on the k-space device instead of building on CPU and copying over.
        sensitivity_map = torch.zeros(kspace.shape, dtype=torch.float32, device=kspace.device)
        # Channel 0 of the complex axis (presumably the real part) is 1, the other channel stays 0.
        sensitivity_map[..., 0] = 1.0
        sample["sensitivity_map"] = sensitivity_map.refine_names(*kspace.names)
    elif self.type_of_map == "rss_estimate":
        acs_image = self.estimate_acs_image(sample)
        # Normalise each coil image by the root-sum-of-squares over coils.
        # NOTE(review): presumably safe_divide maps zero-RSS pixels to 0 -- confirm.
        acs_image_rss = T.root_sum_of_squares(acs_image, dim="coil").align_as(acs_image)
        sample["sensitivity_map"] = T.safe_divide(acs_image, acs_image_rss)
    else:
        raise ValueError(
            f"Expected type of map to be either `unit` or `rss_estimate`. Got {self.type_of_map}."
        )
    return sample
def test_root_sum_of_squares_real(shape, dims):
    """Compare ``transforms.root_sum_of_squares`` on real-valued input with a NumPy reference."""
    data = create_input(shape, named=True)  # noqa
    actual = transforms.root_sum_of_squares(data, dims).numpy()

    squared = data.numpy() ** 2
    expected = np.sqrt(np.sum(squared, axis=dims))

    assert np.allclose(actual, expected)