Exemplo n.º 1
0
    def __init__(self,
                 device,
                 seeds=None,
                 dataset="100307",
                 step_width=0.8,
                 b_val=1000,
                 action_space=100,
                 grid_dim=(3, 3, 3),
                 max_steps=2000,
                 fa_threshold=0.2,
                 bundles_path="data/gt_bundles/",
                 odf_mode="CSD"):
        """Load a dataset and prepare the tensors used by the environment.

        Parameters
        ----------
        device :
            Torch device on which all tensors created here are placed.
        seeds : optional
            Seed points. A torch tensor or anything ``torch.as_tensor``
            accepts (e.g. a NumPy array). When None, seeds are generated
            from the dataset's binary mask via ``utils.seeds_from_mask``.
        dataset : str
            ``'ISMRM'`` loads ``data/ISMRM2015/``; any other value is used
            as the sub-folder name below ``data/HCP/``.
        step_width : float
            Stored as ``self.step_width``.
        b_val : int
            Forwarded to ``DataPreprocessor.crop``.
        action_space : int
            Forwarded to ``get_uniform_hemisphere_with_points`` and used as
            the size of the ``Discrete`` action space.
        grid_dim : tuple
            Converted to an array and forwarded to ``get_grid``.
        max_steps : int
            Stored as ``self.max_steps``; sizes the history buffers.
        fa_threshold : float
            Stored as ``self.fa_threshold``.
        bundles_path : str
            Wrapped in ``Path`` and forwarded to ``self._init_na``.
        odf_mode : str
            Forwarded to ``self._init_odf``.
        """
        print("Loading dataset # ", dataset)
        self.device = device
        preprocessor = DataPreprocessor().normalize().crop(b_val).fa_estimate()
        if dataset == 'ISMRM':
            self.dataset = preprocessor.get_ismrm("data/ISMRM2015/")
        else:
            self.dataset = preprocessor.get_hcp(f"data/HCP/{dataset}/")

        # Action directions are the vertices of the generated hemisphere.
        self.sphere = get_uniform_hemisphere_with_points(
            action_space=action_space)
        self.directions = torch.from_numpy(
            self.sphere.vertices).to(device=device)
        self.grid = torch.from_numpy(get_grid(
            np.array(grid_dim))).to(device=device)
        self.action_space = Discrete(action_space)

        if seeds is None:
            # NOTE(review): the dataset affine is passed here although the
            # message below states that seeds are expected in IJK - confirm
            # which coordinate system seeds_from_mask returns in this setup.
            seeds = utils.seeds_from_mask(self.dataset.binary_mask,
                                          self.dataset.aff)
        # as_tensor leaves tensors untouched and converts array-likes, so
        # both user-supplied tensors and generated NumPy seeds are accepted.
        self.seeds = torch.as_tensor(seeds).to(device=device).float()  # IJK
        print("[I] seeds are supposed to be in IJK")
        self.max_steps = max_steps
        self.step_width = step_width

        # DWI signal processed by Resample100, wrapped in an interpolator.
        self.dwi = torch.from_numpy(Resample100().process(
            self.dataset, None, self.dataset.dwi)).to(device=device).float()
        self.dwi_processor = TorchGridInterpolator(self.dwi)
        self.binary_mask = torch.from_numpy(
            self.dataset.binary_mask).to(device=device)
        # The FA volume gets a trailing singleton dimension before being
        # handed to the interpolator.
        self.fa_interpolator = TorchGridInterpolator(
            torch.from_numpy(
                self.dataset.fa).to(device=device).unsqueeze(-1).float())
        self.fa_threshold = fa_threshold

        self._init_na(Path(bundles_path))
        self._init_odf(odf_mode=odf_mode)

        # Dataset affine and its inverse as float tensors on the device.
        self.ras_aff = torch.from_numpy(
            self.dataset.aff).to(device=device).float()
        self.ijk_aff = self.ras_aff.inverse().float()

        self.state: Optional[TractographyState] = None
        self.no_steps = 0
        # History buffers sized by max_steps. self.tract_masks is not set in
        # this method - presumably by _init_na above (TODO confirm).
        self.state_history = torch.zeros((self.max_steps + 1, 3))
        self.na_reward_history = torch.zeros(
            (self.max_steps, self.tract_masks.shape[-1]))
        self.reset()
Exemplo n.º 2
0
    def __init__(self,
                 rotate=True,
                 grid_dimension=(3, 3, 3),
                 grid_spacing=1.0,
                 postprocessing=None,
                 normalize=None,
                 normalize_mean=(9.8811e-01, 2.6814e-04, 1.2876e-03),
                 normalize_std=(0.0262, 0.1064, 0.1078)):
        """Store the processing options and build the interpolation grid.

        Parameters
        ----------
        rotate : bool, optional
            Indicates whether the grid should be rotated along the fiber,
            by default True
        grid_dimension : tuple, list or numpy.ndarray, optional
            Grid dimension (X,Y,Z) of the interpolation grid,
            by default (3, 3, 3)
        grid_spacing : float, optional
            Grid spacing, by default 1.0
        postprocessing : data.postprocessing, optional
            The postprocessing to be done on the interpolated DWI,
            by default None
        normalize : bool, optional
            Indicates whether data should be normalized. When None (the
            default) the value of `rotate` is used.
        normalize_mean : tuple, list or numpy.ndarray, optional
            Mean for normalization; only stored when both `rotate` and
            `normalize` are truthy.
        normalize_std : tuple, list or numpy.ndarray, optional
            Std for normalization; only stored when both `rotate` and
            `normalize` are truthy.
        """
        if isinstance(grid_dimension, (tuple, list)):
            grid_dimension = np.array(grid_dimension)

        normalize = normalize if normalize is not None else rotate

        self.options = SimpleNamespace()

        if rotate and normalize:
            if isinstance(normalize_mean, (tuple, list)):
                normalize_mean = np.array(normalize_mean)

            if isinstance(normalize_std, (tuple, list)):
                normalize_std = np.array(normalize_std)

            self.options.normalize_mean = normalize_mean
            self.options.normalize_std = normalize_std

        self.options.rotate = rotate
        self.options.normalize = normalize
        self.options.grid_dimension = grid_dimension
        self.options.grid_spacing = grid_spacing
        self.options.postprocessing = postprocessing
        self.grid = get_grid(grid_dimension) * grid_spacing

        # `postprocessing` defaults to None, so its id must not be
        # dereferenced unconditionally.
        postprocessing_id = None if postprocessing is None else postprocessing.id
        self.id = "RegressionProcessing-r{}-grid{}x{}x{}-spacing{}-postprocessing-{}".format(
            rotate, *grid_dimension, grid_spacing, postprocessing_id)
Exemplo n.º 3
0
    def __init__(self, data_container: DataContainer,
                 points: Optional[int] = None,
                 sphere="repulsion100", postprocessing: PostprocessingOption = None,
                 grid_dimension: tuple = (3, 3, 3), grid_spacing: float = 1.0):
        """Collect the mask points of a data container and set up the ODF.

        Parameters
        ----------
        data_container : DataContainer
            Stored as ``self.dataset``; its ``binary_mask`` provides the
            candidate points and its ``to_ras`` converts them.
        points : int, optional
            When given, that many mask points are drawn at random without
            replacement; otherwise all of them are kept.
        sphere : str or sphere object
            A string is resolved through ``get_sphere``; any other value is
            stored unchanged.
        postprocessing : PostprocessingOption, optional
            Stored as ``self.postprocessing``.
        grid_dimension : tuple
            Passed to ``get_grid`` (tuples are converted to arrays first).
        grid_spacing : float
            Factor the grid returned by ``get_grid`` is multiplied with.
        """
        super().__init__()

        self.postprocessing = postprocessing

        resolved_sphere = get_sphere(sphere) if isinstance(sphere, str) else sphere
        self.dataset = data_container
        self.sphere = resolved_sphere

        # Indices of the non-zero mask entries, one index tuple per row.
        mask_indices = np.array(self.dataset.binary_mask.nonzero()).T
        if points is not None:
            chosen_rows = np.random.choice(mask_indices.shape[0], points, replace=False)
            mask_indices = mask_indices[chosen_rows]
        self.points_ijk = mask_indices
        self.points = self.dataset.to_ras(self.points_ijk)

        grid_shape = (np.array(grid_dimension)
                      if isinstance(grid_dimension, tuple) else grid_dimension)
        self.grid = get_grid(grid_shape) * grid_spacing
        self._setup_odf()
    def __init__(self,
                 device,
                 seeds=None,
                 step_width=0.8,
                 dataset='100307',
                 grid_dim=(3, 3, 3),
                 max_l2_dist_to_state=0.1,
                 tracking_in_RAS=True,
                 fa_threshold=0.1,
                 b_val=1000,
                 odf_state=True,
                 odf_mode="CSD",
                 action_space=100,
                 pFolderBundles="data/gt_bundles/"):
        """Set up the (deprecated) tracking environment.

        Parameters
        ----------
        device :
            Torch device for the direction and tract-mask tensors.
        seeds : optional
            Seed points stored as ``self.seeds``. When None, seeds are
            derived from a DTI fit of the dataset (FA >= 0.2).
        step_width : float
            Stored as ``self.step_width``.
        dataset : str
            ``'ISMRM'`` loads ``data/ISMRM2015/``; any other value is used
            as the sub-folder name below ``data/HCP/``.
        grid_dim : tuple
            Converted to an array and forwarded to ``get_grid``.
        max_l2_dist_to_state : float
            Stored as ``self.maxL2dist_to_State``.
        tracking_in_RAS : bool
            Stored as ``self.tracking_in_RAS``.
        fa_threshold : float
            Stored as ``self.fa_threshold`` (the seed mask below uses a
            fixed threshold of 0.2, not this value).
        b_val : int
            Forwarded to ``DataPreprocessor.crop``.
        odf_state : bool
            When truthy, ``interpolate_odf_at_state`` is used as the state
            interpolation function instead of ``interpolate_dwi_at_state``.
        odf_mode : str
            Stored as ``self.odf_mode``.
        action_space : int
            Number of random points drawn for the action hemisphere.
        pFolderBundles : str
            Folder whose entries are each loaded as a ``FiberBundleDataset``.
        """
        print("DEPRECATED! Dont use anymore.")

        # Attributes declared up front; reset() is called further below.
        self.state_history = None
        self.reference_seed_point_ijk = None
        self.points_visited = None
        self.past_reward = None
        self.reward = None
        self.stepCounter = None
        self.done = None
        self.seed_index = None
        self.step_angles = None
        self.line = None
        self.na_reward_history = None
        self.av_na_reward = None
        self.past_bundle = None

        print("Loading dataset # ", dataset)
        self.device = device
        preprocessor = DataPreprocessor().normalize().crop(b_val).fa_estimate()
        if dataset == 'ISMRM':
            self.dataset = preprocessor.get_ismrm("data/ISMRM2015/")
        else:
            self.dataset = preprocessor.get_hcp(f"data/HCP/{dataset}/")

        self.step_width = step_width
        self.dtype = torch.FloatTensor  # vs. torch.cuda.FloatTensor
        self.dti_model = None
        self.dti_fit = None
        self.odf_interpolator = None
        self.sh_coefficient = None
        self.odf_mode = odf_mode

        # Fixed NumPy seed so the random action directions are reproducible.
        # Note that this reseeds the global NumPy RNG.
        np.random.seed(42)
        phi = np.pi * np.random.rand(action_space)
        theta = 2 * np.pi * np.random.rand(action_space)
        sphere = HemiSphere(theta=theta,
                            phi=phi)  # Sphere(theta=theta, phi=phi)
        sphere, _ = disperse_charges(
            sphere, 5000)  # enforce uniform distribution of our points
        self.sphere = sphere
        self.sphere_odf = sphere

        # -- interpolation function of state's value --
        self.state_interpol_func = self.interpolate_dwi_at_state
        if odf_state:
            print("Interpolating ODF as state Value")
            self.state_interpol_func = self.interpolate_odf_at_state

        self.directions = torch.from_numpy(self.sphere.vertices).to(device)
        no_actions, _ = self.directions.shape
        self.directions_odf = torch.from_numpy(
            self.sphere_odf.vertices).to(device)

        self.action_space = Discrete(
            no_actions)  # spaces.Discrete(no_actions+1)
        self.dwi_postprocessor = Resample(
            sphere=get_sphere('repulsion100'))  # resample(sphere=sphere)
        self.referenceStreamline_ijk = None
        self.grid = get_grid(np.array(grid_dim))
        self.maxL2dist_to_State = max_l2_dist_to_state
        self.tracking_in_RAS = tracking_in_RAS

        # -- load streamlines --
        self.fa_threshold = fa_threshold
        self.maxSteps = 2000

        # -- init seeds --
        self.seeds = seeds
        if self.seeds is None:
            # self.dti_fit was set to None above, so this currently always
            # triggers the ODF initialisation.
            if self.dti_fit is None:
                self._init_odf()

            # Local DTI fit used only to build the seed mask; it is not
            # stored on the instance.
            dti_model = dti.TensorModel(self.dataset.gtab, fit_method='LS')
            dti_fit = dti_model.fit(self.dataset.dwi,
                                    mask=self.dataset.binary_mask)

            # Binarise the FA map at a fixed threshold of 0.2.
            fa_img = dti_fit.fa
            seed_mask = fa_img.copy()
            seed_mask[seed_mask >= 0.2] = 1
            seed_mask[seed_mask < 0.2] = 0

            seeds = utils.seeds_from_mask(seed_mask,
                                          affine=np.eye(4),
                                          density=1)  # tracking in IJK
            self.seeds = torch.from_numpy(seeds)

        # -- init bundles for neuroanatomical reward --
        print("Init tract masks for neuroanatomical reward")
        self.bundleNames = os.listdir(pFolderBundles)
        # os.path.join keeps the path valid even when pFolderBundles has no
        # trailing separator (plain string concatenation did not).
        fibers = [
            FiberBundleDataset(
                path_to_files=os.path.join(pFolderBundles, bundle_name),
                dataset=self.dataset).tractMask
            for bundle_name in self.bundleNames
        ]

        self.tractMasksAllBundles = torch.stack(fibers, dim=0).to(self.device)

        # -- set default values --
        self.reset()

        # -- init observation space --
        obs_shape = self.get_observation_from_state(self.state).shape
        self.observation_space = Box(low=0, high=150, shape=obs_shape)
Exemplo n.º 5
0
    def __init__(self,
                 device,
                 seeds=None,
                 step_width=0.8,
                 dataset='100307',
                 grid_dim=(3, 3, 3),
                 max_l2_dist_to_state=0.1,
                 tracking_in_RAS=False,
                 fa_threshold=0.1,
                 b_val=1000,
                 odf_state=True,
                 odf_mode="CSD",
                 action_space=100,
                 pFolderBundles="data/gt_bundles/",
                 rnd_seed=2342):
        """Set up the tracking environment with torch-based interpolators.

        Parameters
        ----------
        device :
            Torch device on which the tensors created here are placed.
        seeds : optional
            Seed points stored as ``self.seeds``. When None, seeds are
            derived from a DTI fit of the dataset (FA >= 0.2).
        step_width : float
            Stored as ``self.step_width``.
        dataset : str
            ``'ISMRM'`` loads ``data/ISMRM2015/``; any other value is used
            as the sub-folder name below ``data/HCP/``.
        grid_dim : tuple
            Converted to an array and forwarded to ``get_grid``.
        max_l2_dist_to_state : float
            Stored as ``self.maxL2dist_to_State``.
        tracking_in_RAS : bool
            Stored as ``self.tracking_in_RAS``.
        fa_threshold : float
            Stored as ``self.fa_threshold`` (the seed mask below uses a
            fixed threshold of 0.2, not this value).
        b_val : int
            Forwarded to ``DataPreprocessor.crop``.
        odf_state : bool
            When truthy, ``interpolate_odf_at_state`` is used as the state
            interpolation function instead of ``interpolate_dwi_at_state``.
        odf_mode : str
            Stored as ``self.odf_mode``.
        action_space : int
            Number of points requested from ``random_uniform_on_sphere``.
        pFolderBundles : str
            Folder whose entries are each loaded as a ``FiberBundleDataset``.
        rnd_seed : int
            Forwarded to ``set_seed`` before the sphere points are drawn.
        """
        print(
            "Will be deprecated by NARLTractEnvironment as soon as Jos fixes all bugs in the reward function."
        )

        # Attributes declared up front; reset() is called at the end.
        self.state_history = None
        self.reference_seed_point_ijk = None
        self.points_visited = None
        self.past_reward = None
        self.reward = None
        self.stepCounter = None
        self.done = None
        self.seed_index = None
        self.step_angles = None
        self.line = None
        self.na_reward_history = None
        self.av_na_reward = None
        self.past_bundle = None

        print("Loading dataset # ", dataset)
        self.device = device
        preprocessor = DataPreprocessor().normalize().crop(b_val).fa_estimate()
        if dataset == 'ISMRM':
            self.dataset = preprocessor.get_ismrm("data/ISMRM2015/")
        else:
            self.dataset = preprocessor.get_hcp(f"data/HCP/{dataset}/")

        self.step_width = step_width
        self.dtype = torch.FloatTensor  # vs. torch.cuda.FloatTensor
        self.dti_model = None
        self.dti_fit = None
        self.odf_interpolator = None
        self.sh_coefficient = None
        self.odf_mode = odf_mode

        # DWI volume processed by Resample100 and moved to the device as a
        # float tensor (no axis permutation is applied here).
        self.dwi = torch.from_numpy(Resample100().process(
            self.dataset, None, self.dataset.dwi)).to(device=device).float()

        # Seed the RNG(s) via set_seed before sampling action directions.
        set_seed(rnd_seed)
        sphere_points = random_uniform_on_sphere(n=action_space)

        self.sphere = HemiSphere(xyz=sphere_points)
        self.sphere_odf = self.sphere

        # -- interpolation function of state's value --
        self.state_interpol_func = self.interpolate_dwi_at_state
        if odf_state:
            print("Interpolating ODF as state Value")
            self.state_interpol_func = self.interpolate_odf_at_state

        self.directions = torch.from_numpy(self.sphere.vertices).to(device)
        no_actions, _ = self.directions.shape
        self.directions_odf = torch.from_numpy(
            self.sphere_odf.vertices).to(device)

        self.action_space = Discrete(
            no_actions)  # spaces.Discrete(no_actions+1)
        self.dwi_postprocessor = Resample(
            sphere=get_sphere('repulsion100'))  # resample(sphere=sphere)
        self.referenceStreamline_ijk = None
        self.grid = get_grid(np.array(grid_dim))
        self.grid = torch.from_numpy(self.grid).to(self.device)
        self.maxL2dist_to_State = max_l2_dist_to_state
        self.tracking_in_RAS = tracking_in_RAS

        # -- load streamlines --
        self.fa_threshold = fa_threshold
        self.maxSteps = 2000

        # -- init seeds --
        self.seeds = seeds
        if self.seeds is None:
            # self.dti_fit was set to None above, so this currently always
            # triggers the ODF initialisation.
            if self.dti_fit is None:
                self._init_odf()

            # Local DTI fit used only to build the seed mask; it is not
            # stored on the instance.
            dti_model = dti.TensorModel(self.dataset.gtab, fit_method='LS')
            dti_fit = dti_model.fit(self.dataset.dwi,
                                    mask=self.dataset.binary_mask)

            # Binarise the FA map at a fixed threshold of 0.2.
            fa_img = dti_fit.fa
            seed_mask = fa_img.copy()
            seed_mask[seed_mask >= 0.2] = 1
            seed_mask[seed_mask < 0.2] = 0

            seeds = utils.seeds_from_mask(seed_mask,
                                          affine=np.eye(4),
                                          density=1)  # tracking in IJK
            self.seeds = torch.from_numpy(seeds)

        # -- init bundles for neuroanatomical reward --
        print("Init tract masks for neuroanatomical reward")
        self.bundleNames = os.listdir(pFolderBundles)
        # os.path.join keeps the path valid even when pFolderBundles has no
        # trailing separator (plain string concatenation did not).
        fibers = [
            FiberBundleDataset(
                path_to_files=os.path.join(pFolderBundles, bundle_name),
                dataset=self.dataset).tractMask
            for bundle_name in self.bundleNames
        ]

        ## Define our interpolators
        # Masks are stacked along a new leading axis, which is then moved to
        # the last position.
        self.tractMasks = torch.stack(fibers, dim=0).to(self.device).permute(
            (1, 2, 3, 0))  # [X,Y,Z,C]
        print(self.tractMasks.shape)
        self.tractMask_interpolator = TorchGridInterpolator(self.tractMasks)
        self.binary_mask = torch.from_numpy(
            self.dataset.binary_mask).to(device=device)
        # Scalar volumes get a trailing singleton dimension before being
        # handed to the interpolator.
        self.fa_interpolator = TorchGridInterpolator(
            torch.from_numpy(
                self.dataset.fa).to(device=device).unsqueeze(-1).float())
        self.dwi_interpolator = TorchGridInterpolator(self.dwi.to(self.device))
        self.brainMask_interpolator = TorchGridInterpolator(
            torch.from_numpy(self.dataset.binary_mask).to(
                self.device).unsqueeze(-1).float())

        # -- set default values --
        self.reset()
Exemplo n.º 6
0
    def __init__(self,
                 device,
                 seeds=None,
                 step_width=0.8,
                 dataset='100307',
                 grid_dim=(3, 3, 3),
                 max_l2_dist_to_state=0.1,
                 tracking_in_RAS=True,
                 fa_threshold=0.1,
                 b_val=1000,
                 max_angle=80.,
                 odf_state=True,
                 odf_mode="CSD"):
        """Set up the tracking environment using the 'repulsion100' sphere.

        Parameters
        ----------
        device :
            Torch device for the direction tensors.
        seeds : optional
            Seed points stored as ``self.seeds``. When None, seeds are
            derived from a DTI fit of the dataset (FA >= 0.2).
        step_width : float
            Stored as ``self.step_width``.
        dataset : str
            ``'ISMRM'`` loads ``data/ISMRM2015/``; any other value is used
            as the sub-folder name below ``data/HCP/``.
        grid_dim : tuple
            Converted to an array and forwarded to ``get_grid``.
        max_l2_dist_to_state : float
            Stored as ``self.maxL2dist_to_State``.
        tracking_in_RAS : bool
            Stored as ``self.tracking_in_RAS``.
        fa_threshold : float
            Stored as ``self.fa_threshold`` (the seed mask below uses a
            fixed threshold of 0.2, not this value).
        b_val : int
            Forwarded to ``DataPreprocessor.crop``.
        max_angle : float
            Angle in degrees; its cosine is passed to
            ``self._set_adjacency_matrix`` as similarity threshold.
        odf_state : bool
            When truthy, ``interpolate_odf_at_state`` is used as the state
            interpolation function instead of ``interpolate_dwi_at_state``.
        odf_mode : str
            Stored as ``self.odf_mode``.
        """
        # Attributes declared up front; reset() is called further below.
        self.state_history = None
        self.reference_seed_point_ijk = None
        self.points_visited = None
        self.past_reward = None
        self.reward = None
        self.stepCounter = None
        self.done = None
        self.seed_index = None
        self.step_angles = None
        self.line = None
        print("Loading dataset # ", dataset)
        self.device = device
        preprocessor = DataPreprocessor().normalize().crop(b_val).fa_estimate()
        if dataset == 'ISMRM':
            self.dataset = preprocessor.get_ismrm(f"data/ISMRM2015/")
        else:
            self.dataset = preprocessor.get_hcp(f"data/HCP/{dataset}/")

        self.step_width = step_width
        self.dtype = torch.FloatTensor  # vs. torch.cuda.FloatTensor
        self.dti_model = None
        self.dti_fit = None
        self.odf_interpolator = None
        self.sha_coefficient_placeholder = None if False else None  # placeholder removed below
        self.odf_mode = odf_mode