def __init__(self, root_path: str, subject: str, patch_size, step, target_shape=(1, 256, 256, 192)):
    """Load one subject's T1 volume, pad it, and precompute its patch slices.

    The image is read from ``<root_path>/<subject>/T1/T1.nii.gz`` via
    ``ToNumpyArray`` and padded with ``PadToShape`` to ``target_shape``.
    Patch slices are then built with ``SliceBuilder(...).build_slices`` using
    ``keep_centered_on_foreground=True``.

    Args:
        root_path: Dataset root directory.
        subject: Name of the subject sub-directory under ``root_path``.
        patch_size: Patch size forwarded to ``SliceBuilder``.
        step: Sliding-window step forwarded to ``SliceBuilder``.
        target_shape: Shape the loaded volume is padded to. Defaults to the
            formerly hard-coded ``(1, 256, 256, 192)``.
    """
    self._image_path = os.path.join(root_path, subject, "T1", "T1.nii.gz")
    # Pass a list, exactly like the formerly hard-coded literal, so PadToShape
    # sees the same argument type as before.
    self._image = PadToShape(target_shape=list(target_shape))(ToNumpyArray()(self._image_path))
    self._slices = SliceBuilder(self._image, patch_size=patch_size, step=step).build_slices(
        keep_centered_on_foreground=True)
    self._image_max = self._image.max()
    self._patch_size = patch_size
    self._step = step
def _crop_to_content(self, t1, labels):
    """Crop an image/label pair to its content and pad it to the normalized shape.

    The crop+pad pipeline is (re)built on every call and kept on
    ``self._transform_crop``. Failures while transforming are best-effort:
    the exception is logged as a warning and ``None`` is returned.

    Args:
        t1: Image volume, stored as ``Sample.x``.
        labels: Label volume, stored as ``Sample.y``.

    Returns:
        ``(x, y)`` taken from the transformed sample, or ``None`` on failure.
    """
    crop_pipeline = transforms.Compose([CropToContent(), PadToShape(self._normalized_shape)])
    self._transform_crop = crop_pipeline
    try:
        cropped = crop_pipeline(Sample(x=t1, y=labels, is_labeled=True))
        cropped_pair = cropped.x, cropped.y
    except Exception as error:
        self.LOGGER.warning(error)
        return None
    return cropped_pair
def __init__(self, root_dir_iseg: str, root_dir_mrbrains: str, output_dir: str, params: dict = None):
    """Store the dataset roots and output directory and build the loading pipeline.

    Args:
        root_dir_iseg: Root directory of the iSEG data.
        root_dir_mrbrains: Root directory of the MRBrainS data.
        output_dir: Directory the processed volumes are written under.
        params: Optional extra parameters, stored as-is on ``self._params``.
    """
    self._root_dir_iseg, self._root_dir_mrbrains, self._output_dir = (
        root_dir_iseg, root_dir_mrbrains, output_dir)
    # NOTE(review): presumably a common shape large enough for both datasets;
    # see compute_normalized_shape_from_images_in.
    self._normalized_shape = self.compute_normalized_shape_from_images_in(
        root_dir_iseg, root_dir_mrbrains)
    loading_steps = [
        ToNumpyArray(),
        CropToContent(),
        PadToShape(self._normalized_shape),
    ]
    self._transforms = transforms.Compose(loading_steps)
    self._params = params
def __init__(self, images: List[str], patch_size: Tuple[int, int, int, int],
             reconstructed_image_size: Tuple[int, int, int, int], step: Tuple[int, int, int, int],
             batch_size: int = 5, models: List[torch.nn.Module] = None, normalize: bool = False,
             is_ground_truth: bool = False, normalize_and_segment: bool = False, is_multimodal=False,
             alpha=0.0, prob_bias=0.0, snr=0.0, prob_noise=0.0):
    """Load and pad the input images and precompute one overlap map per image.

    Every entry of ``images`` is passed through ``ToNumpyArray`` and then
    ``PadToShape(target_shape=reconstructed_image_size)``. For each padded
    image, ``SliceBuilder(image, patch_size, step).build_overlap_map()`` is
    stored in ``self._overlap_maps`` (same order as ``self._images``).

    Args:
        images: Inputs accepted by ``ToNumpyArray`` (annotated as ``str``;
            presumably file paths).
        patch_size: Patch size forwarded to ``SliceBuilder``.
        reconstructed_image_size: Shape each image is padded to.
        step: Sliding-window step forwarded to ``SliceBuilder``.
        batch_size: Stored on ``self._batch_size``.
        models: Stored on ``self._models``; may be ``None``.
        normalize, is_ground_truth, normalize_and_segment, is_multimodal:
            Flags stored as-is for use by other methods.
        alpha, prob_bias, snr, prob_noise: Values stored as-is for use by
            other methods.
    """
    self._patch_size = patch_size
    self._reconstructed_image_size = reconstructed_image_size
    self._step = step
    self._models = models
    self._do_normalize = normalize
    self._is_ground_truth = is_ground_truth
    self._do_normalize_and_segment = normalize_and_segment
    self._is_multimodal = is_multimodal
    self._batch_size = batch_size
    self._alpha = alpha
    self._snr = snr
    self._prob_bias = prob_bias
    self._prob_noise = prob_noise

    # The loading pipeline does not depend on the image, so build it once
    # instead of once per loop iteration.
    load_and_pad = Compose([
        ToNumpyArray(),
        PadToShape(target_shape=self._reconstructed_image_size),
    ])
    self._images = [load_and_pad(image) for image in images]
    self._overlap_maps = [
        SliceBuilder(image, self._patch_size, self._step).build_overlap_map()
        for image in self._images
    ]
class SingleImageDataset(Dataset):
    """Patch dataset over a single subject's padded T1 volume.

    Item ``i`` is ``image[slices[i]]`` returned as a float32 tensor. The
    slices come from ``SliceBuilder.build_slices``.
    """

    def __init__(self, root_path: str, subject: str, patch_size, step, target_shape=(1, 256, 256, 192)):
        """Load one subject's T1 volume, pad it, and precompute its patch slices.

        The image is read from ``<root_path>/<subject>/T1/T1.nii.gz`` via
        ``ToNumpyArray`` and padded with ``PadToShape`` to ``target_shape``.
        Slices are built with ``keep_centered_on_foreground=True``.

        Args:
            root_path: Dataset root directory.
            subject: Name of the subject sub-directory under ``root_path``.
            patch_size: Patch size forwarded to ``SliceBuilder``.
            step: Sliding-window step forwarded to ``SliceBuilder``.
            target_shape: Shape the loaded volume is padded to. Defaults to
                the formerly hard-coded ``(1, 256, 256, 192)``.
        """
        self._image_path = os.path.join(root_path, subject, "T1", "T1.nii.gz")
        # Pass a list, exactly like the formerly hard-coded literal.
        self._image = PadToShape(target_shape=list(target_shape))(ToNumpyArray()(self._image_path))
        self._slices = SliceBuilder(self._image, patch_size=patch_size, step=step).build_slices(
            keep_centered_on_foreground=True)
        self._image_max = self._image.max()
        self._patch_size = patch_size
        self._step = step

    def __getitem__(self, index):
        """Return patch ``index`` as a new float32 tensor (``requires_grad=False``).

        Raises:
            IndexError: If ``index`` is out of range.
        """
        # Errors are allowed to propagate. Returning None here would break
        # batch collation, and because Dataset defines no __iter__, swallowing
        # IndexError would make ``for patch in dataset`` loop forever.
        patch = self._image[self._slices[index]]
        # torch.tensor copies the NumPy view into a fresh tensor of the same shape.
        return torch.tensor(patch, dtype=torch.float32)

    def __len__(self):
        """Return the number of precomputed patch slices."""
        return len(self._slices)

    @property
    def image_shape(self):
        """Shape of the padded volume."""
        return self._image.shape
def run(self, prefix="standardize_"):
    """Jointly standardize MRBrainS and iSEG images and write them with their labels.

    Images are discovered by walking both roots:

    * MRBrainS: every file whose name matches ``^T.*\\.nii``.
    * iSEG: every file matching ``.*T.*\\.nii``, skipping directories named
      ``ROI``, ``label`` or ``Normalized``.

    Each image is loaded with ``self._transforms``. Files that fail to load
    are logged as warnings and skipped. All loaded volumes are standardized
    together with one global mean and standard deviation, then written as
    NIfTI to
    ``<output_dir>/<MRBrainS|iSEG>/Dual_Standardized/<dir name>/<prefix><file>``.
    The dataset folder is chosen by the root the file was found under.

    Label volumes are also written under the same layout, without a prefix:
    MRBrainS files matching ``^LabelsFor.*\\.nii``, and every file in an iSEG
    directory whose name contains ``ROI`` or ``label``. They are only passed
    through ``CropToContent`` and ``PadToShape(self._normalized_shape)``.
    Errors while writing labels propagate.

    Args:
        prefix: String prepended to the file name of each standardized image.
    """
    excluded_dirs = ["ROI", "label", "Normalized"]
    mrbrains_image_pattern = re.compile(r"^T.*\.nii")
    iseg_image_pattern = re.compile(r".*T.*\.nii")
    mrbrains_label_pattern = re.compile(r"^LabelsFor.*\.nii")

    volumes = []  # loaded image arrays
    entries = []  # (dataset, dir name, file name); entries[i] describes volumes[i]

    def dir_name(root):
        # Last path component of a directory yielded by os.walk.
        return os.path.basename(os.path.normpath(root))

    def output_dir(dataset, name):
        # Create (if needed) and return <output_dir>/<dataset>/Dual_Standardized/<name>.
        path = os.path.join(self._output_dir, dataset, "Dual_Standardized", name)
        os.makedirs(path, exist_ok=True)
        return path

    def collect_images(root_dir, dataset, pattern, skip_dirs):
        # Load every matching image under root_dir; unreadable files are logged and skipped.
        for root, _, files in os.walk(root_dir):
            name = dir_name(root)
            if name in skip_dirs:
                continue
            for file in filter(pattern.search, files):
                self.LOGGER.info("Processing: {}".format(file))
                try:
                    volume = self._transforms(os.path.join(root, file))
                except Exception as e:
                    self.LOGGER.warning(e)
                    continue
                # Record metadata only after a successful load, so a failed
                # file can never shift names/directories of later images.
                volumes.append(volume)
                entries.append((dataset, name, file))

    def write_label(root, file, dataset, name):
        # Crop/pad a label volume to the common shape and save it (no standardization).
        target = os.path.join(output_dir(dataset, name), file)
        transforms.Compose([
            ToNumpyArray(),
            CropToContent(),
            PadToShape(self._normalized_shape),
            ToNifti1Image(),
            NiftiToDisk(target),
        ])(os.path.join(root, file))

    collect_images(self._root_dir_mrbrains, "MRBrainS", mrbrains_image_pattern, ())
    collect_images(self._root_dir_iseg, "iSEG", iseg_image_pattern, excluded_dirs)

    # Skip standardization entirely when nothing loaded (avoids mean/std of an empty array).
    if volumes:
        images = np.array(volumes).astype(np.float32)
        # A single mean/std over every voxel of every image, shared by both datasets.
        standardized = np.subtract(images, np.mean(images)) / np.std(images)
        for volume, (dataset, name, file) in zip(standardized, entries):
            target = os.path.join(output_dir(dataset, name), prefix + file)
            transforms.Compose([ToNifti1Image(), NiftiToDisk(target)])(volume)

    for root, _, files in os.walk(self._root_dir_mrbrains):
        name = dir_name(root)
        for file in filter(mrbrains_label_pattern.search, files):
            write_label(root, file, "MRBrainS", name)

    for root, _, files in os.walk(self._root_dir_iseg):
        name = dir_name(root)
        if "ROI" in name or "label" in name:
            for file in files:
                write_label(root, file, "iSEG", name)
return np.sqrt(js_div) if __name__ == '__main__': model = Unet(1, 1, True, True) iseg_csv = "/mnt/md0/Data/iSEG_scaled/Training/output_iseg_images.csv" mrbrains_csv = "/mnt/md0/Data/MRBrainS_scaled/DataNii/TrainingData/output_mrbrains_images.csv" abide_csv = "/mnt/md0/Data/ABIDE_scaled/output_abide_images.csv" iseg_csv = pandas.read_csv(iseg_csv) mrbrains_csv = pandas.read_csv(mrbrains_csv) abide_csv = pandas.read_csv(abide_csv).sample(75) c, d, h, w = 1, 256, 256, 192 transform = transforms.Compose([ToNumpyArray(), PadToShape((c, d, h, w))]) iseg_inputs = torch.tensor([transform(image) for image in iseg_csv["T1"]]) mrbrains_inputs = torch.tensor( [transform(image) for image in mrbrains_csv["T1"]]) abide_inputs = torch.tensor( [transform(image) for image in abide_csv["T1"]]) generated_iseg = transform( "/mnt/md0/Research/DualUNet/Reconstructed_Normalized_iSEG_Image_80.nii.gz" ) generated_mrbrains = transform( "/mnt/md0/Research/DualUNet/Reconstructed_Normalized_MRBrainS_Image_80.nii.gz" ) generated_abide = transform( "/mnt/md0/Research/DualUNet/Reconstructed_Normalized_ABIDE_Image_80.nii.gz" ) segmentation_iseg = transform(