Ejemplo n.º 1
0
 def _brain_mask(self,
                 row,
                 median_radius=4,
                 numpass=1,
                 autocrop=False,
                 vol_idx=None,
                 dilate=10):
     """Compute (or reuse) a brain mask for one subject/session.

     The mask is obtained by running ``median_otsu`` on the image returned
     by ``self._b0(row)``. It is saved as ``_brain_mask.nii.gz``, together
     with a ``_brain_mask.json`` sidecar that records the parameters used.
     Nothing is recomputed if the mask file exists, unless
     ``self.force_recompute`` is set.

     Parameters
     ----------
     row : mapping
         Per-subject record understood by ``self._get_fname`` and
         ``self._b0``.
     median_radius, numpass, autocrop, dilate :
         Forwarded to ``median_otsu`` by keyword.
     vol_idx :
         Only recorded in the sidecar; it is not forwarded to
         ``median_otsu`` (the input is presumably already a single mean
         volume -- TODO confirm).

     Returns
     -------
     str
         Path to the brain-mask NIfTI file.
     """
     brain_mask_file = self._get_fname(row, '_brain_mask.nii.gz')
     if self.force_recompute or not op.exists(brain_mask_file):
         b0_file = self._b0(row)
         mean_b0_img = nib.load(b0_file)
         mean_b0 = mean_b0_img.get_fdata()
         # Pass everything by keyword: DIPY moved ``vol_idx`` to the second
         # position of ``median_otsu``'s signature, so positional passing
         # can silently bind ``median_radius`` to ``vol_idx``.
         _, brain_mask = median_otsu(mean_b0,
                                     median_radius=median_radius,
                                     numpass=numpass,
                                     autocrop=autocrop,
                                     dilate=dilate)
         be_img = nib.Nifti1Image(brain_mask.astype(int),
                                  mean_b0_img.affine)
         nib.save(be_img, brain_mask_file)
         # Record every parameter that influenced the mask, including
         # ``dilate`` (previously missing from the sidecar).
         meta = dict(source=b0_file,
                     median_radius=median_radius,
                     numpass=numpass,
                     autocrop=autocrop,
                     vol_idx=vol_idx,
                     dilate=dilate)
         meta_fname = self._get_fname(row, '_brain_mask.json')
         afd.write_json(meta_fname, meta)
     return brain_mask_file
Ejemplo n.º 2
0
    def _mapping(self, row):
        """Compute (or reuse) the SyN mapping from DWI space to the template.

        The mapping is computed with ``reg.syn_register_dwi`` and saved with
        ``reg.write_mapping``. It is stored under a name that says whether
        an affine pre-alignment (``self.use_prealign``) was used, and gets a
        JSON sidecar with the same stem.

        Parameters
        ----------
        row : mapping
            Per-subject record; ``'gtab'`` and ``'dwi_file'`` are read.

        Returns
        -------
        str
            Path to the saved mapping file.
        """
        if self.use_prealign:
            suffix = '_mapping_from-DWI_to_MNI_xfm'
        else:
            suffix = '_mapping_from-DWI_to_MNI_xfm_without_prealign'
        mapping_file = self._get_fname(row, suffix + '.nii.gz')

        if self.force_recompute or not op.exists(mapping_file):
            gtab = row['gtab']
            if self.use_prealign:
                reg_prealign = np.load(self._reg_prealign(row))
            else:
                reg_prealign = None

            # Only the mapping is needed; the warped b0 is discarded.
            _, mapping = reg.syn_register_dwi(
                row['dwi_file'],
                gtab,
                template=self.reg_template,
                prealign=reg_prealign)

            if self.use_prealign:
                mapping.codomain_world2grid = np.linalg.inv(reg_prealign)

            reg.write_mapping(mapping, mapping_file)
            # Name the sidecar after the mapping it describes. The old
            # fixed name ("_mapping_reg_prealign.json") was shared by the
            # with- and without-prealign variants, which overwrote each
            # other.
            meta_fname = self._get_fname(row, suffix + '.json')
            meta = dict(type="displacementfield")
            afd.write_json(meta_fname, meta)

        return mapping_file
Ejemplo n.º 3
0
def plot_tract_profiles(subses_dict, scalars, tracking_params,
                        segmentation_params, segmentation_imap):
    """Create one tract-profile plot per scalar.

    Plots are written to a ``tract_profile_plots`` folder next to the path
    that ``get_fname`` returns for each scalar. A ``_profile_plots.json``
    sidecar records the total time spent.

    Parameters
    ----------
    subses_dict : dict
        Subject/session description passed to ``get_fname``.
    scalars : iterable
        Scalar names, or objects exposing a ``.name`` attribute.
    tracking_params, segmentation_params :
        Passed to ``get_fname`` to build file names.
    segmentation_imap : dict
        Must contain ``"profiles_file"``.

    Returns
    -------
    list of str
        Paths handed to ``visualize_tract_profiles`` as ``file_name``.
    """
    start_time = time()
    plot_files = []
    for scalar in scalars:
        scalar_name = scalar if isinstance(scalar, str) else scalar.name
        default_fname = get_fname(
            subses_dict, f'_{scalar_name}_profile_plots',
            tracking_params=tracking_params,
            segmentation_params=segmentation_params)
        plots_dir = op.join(op.dirname(default_fname), "tract_profile_plots")
        plot_fname = op.join(plots_dir, op.basename(default_fname))
        os.makedirs(op.abspath(plots_dir), exist_ok=True)

        visualize_tract_profiles(
            segmentation_imap["profiles_file"],
            scalar=scalar_name,
            file_name=plot_fname,
            n_boot=100)
        plot_files.append(plot_fname)

    meta_fname = get_fname(
        subses_dict, '_profile_plots.json',
        tracking_params=tracking_params,
        segmentation_params=segmentation_params)
    afd.write_json(meta_fname, dict(Timing=time() - start_time))
    return plot_files
Ejemplo n.º 4
0
 def prealign(self, afq_object, row, save=True):
     """Affinely pre-align the subject registration image to the template.

     The affine is cached as ``_prealign_from-DWI_to-MNI_xfm.npy``, with a
     JSON sidecar, only when ``save`` is true. When it is computed, the
     elapsed time is added to ``row['timing']['Registration_pre_align']``.

     Parameters
     ----------
     afq_object :
         Provides ``_get_fname``, ``_reg_img``, ``reg_subject`` and
         ``reg_template_img``.
     row : mapping
         Per-subject record; its ``'timing'`` dict is updated in place.
     save : bool
         If true, return the path of the ``.npy`` file; otherwise return
         the affine array itself.
     """
     prealign_file = afq_object._get_fname(
         row, '_prealign_from-DWI_to-MNI_xfm.npy')
     if op.exists(prealign_file):
         return prealign_file if save else np.load(prealign_file)

     reg_subject_img, _ = afq_object._reg_img(
         afq_object.reg_subject, True, row)
     tic = time()
     _, affine_xform = affine_registration(
         reg_subject_img,
         afq_object.reg_template_img,
         **self.affine_kwargs)
     timing = row['timing']
     timing['Registration_pre_align'] = \
         timing['Registration_pre_align'] + time() - tic

     if not save:
         return affine_xform

     np.save(prealign_file, affine_xform)
     afd.write_json(
         afq_object._get_fname(row, '_prealign_from-DWI_to-MNI_xfm.json'),
         dict(type="rigid"))
     return prealign_file
Ejemplo n.º 5
0
    def get_for_row(self, afq_object, row):
        """Return the subject-to-template mapping for ``row``.

        If the mapping file from ``self.get_fnames`` does not exist yet, it
        is generated with ``self.gen_mapping``, written together with a
        JSON sidecar, and its run time is added to
        ``row['timing']['Registration']``. The mapping is then read back
        with ``reg.read_mapping``. When ``self.use_prealign`` is set, the
        inverse of the prealign affine is passed as ``prealign``.
        """
        mapping_file, meta_fname = self.get_fnames(
            self.extension, afq_object, row)

        reg_prealign = (
            np.load(self.prealign(afq_object, row))
            if self.use_prealign else None)

        if not op.exists(mapping_file):
            template_img, template_sls = \
                afq_object._reg_img(afq_object.reg_template, False, row)
            subject_img, subject_sls = \
                afq_object._reg_img(afq_object.reg_subject, True, row)

            tic = time()
            new_mapping = self.gen_mapping(
                afq_object, row, template_img, template_sls,
                subject_img, subject_sls, reg_prealign)
            timing = row['timing']
            timing['Registration'] = \
                timing['Registration'] + time() - tic

            reg.write_mapping(new_mapping, mapping_file)
            afd.write_json(meta_fname, dict(type="displacementfield"))

        prealign_inv = (
            np.linalg.inv(reg_prealign) if self.use_prealign else None)
        return reg.read_mapping(
            mapping_file,
            row['dwi_file'],
            afq_object.reg_template_img,
            prealign=prealign_inv)
Ejemplo n.º 6
0
def export_bundles(subses_dict, clean_bundles_file, bundles_file, bundle_dict,
                   tracking_params, segmentation_params):
    """Write every bundle of the cleaned and uncleaned tractograms to its own file.

    Streamlines are selected by their per-streamline ``'bundle'`` label
    (matched against ``bundle_dict[bundle]['uid']``) and saved into
    ``<results_dir>/clean_bundles`` and ``<results_dir>/bundles``
    respectively. Each file gets a JSON sidecar naming its source. The
    ``"whole_brain"`` entry of ``bundle_dict`` is skipped.

    Parameters
    ----------
    subses_dict : dict
        Needs ``'dwi_file'`` and ``'results_dir'``; also passed to
        ``get_fname``.
    clean_bundles_file, bundles_file : str
        Tractogram files to split.
    bundle_dict : dict
        Bundle name -> dict with a ``'uid'`` key.
    tracking_params, segmentation_params :
        Passed to ``get_fname`` to build file names.

    Returns
    -------
    bool
        Always ``True``.
    """
    img = nib.load(subses_dict['dwi_file'])
    # The world -> voxel transform is the same for every bundle.
    inv_affine = np.linalg.inv(img.affine)
    for this_bundles_file, folder in zip([clean_bundles_file, bundles_file],
                                         ['clean_bundles', 'bundles']):
        bundles_dir = op.join(subses_dict['results_dir'], folder)
        os.makedirs(bundles_dir, exist_ok=True)
        trk = nib.streamlines.load(this_bundles_file)
        tg = trk.tractogram
        streamlines = tg.streamlines
        for bundle in bundle_dict:
            if bundle == "whole_brain":
                continue
            uid = bundle_dict[bundle]['uid']
            idx = np.where(tg.data_per_streamline['bundle'] == uid)[0]
            this_sl = dtu.transform_tracking_output(
                streamlines[idx], inv_affine)

            this_tgm = StatefulTractogram(this_sl, img, Space.VOX)
            fname = op.split(
                get_fname(subses_dict, f'-{bundle}'
                          f'_tractography.trk',
                          tracking_params=tracking_params,
                          segmentation_params=segmentation_params))
            fname = op.join(bundles_dir, fname[1])
            logger.info(f"Saving {fname}")
            save_tractogram(this_tgm, fname, bbox_valid_check=False)
            meta = dict(source=this_bundles_file)
            # Replace only the final extension. ``fname.split('.')[0]``
            # truncated the path at the first dot anywhere in it, e.g. in a
            # parent directory name.
            meta_fname = op.splitext(fname)[0] + '.json'
            afd.write_json(meta_fname, meta)
    return True
Ejemplo n.º 7
0
        def wrapper_as_file(*args, **kwargs):
            # Return the output path built by ``get_fname(subses_dict,
            # suffix, ...)``. ``func`` is only called if that file does not
            # exist yet. ``func`` must return ``(result, meta)``: result is
            # saved as NIfTI, as a tractogram, or via ``.to_csv`` depending
            # on its type, and meta goes to a JSON sidecar.
            subses_dict = get_args(func, ["subses_dict"], args)[0]
            tracking_params = (
                get_args(func, ["tracking_params"], args)[0]
                if include_track else None)
            segmentation_params = (
                get_args(func, ["segmentation_params"], args)[0]
                if include_seg else None)

            this_file = get_fname(
                subses_dict, suffix,
                tracking_params=tracking_params,
                segmentation_params=segmentation_params)
            if op.exists(this_file):
                return this_file

            img_trk_or_csv, meta = func(*args, **kwargs)

            logger.info(f"Saving {this_file}")
            if isinstance(img_trk_or_csv, nib.Nifti1Image):
                nib.save(img_trk_or_csv, this_file)
            elif isinstance(img_trk_or_csv, StatefulTractogram):
                save_tractogram(
                    img_trk_or_csv, this_file, bbox_valid_check=False)
            else:
                # Anything else is expected to expose ``to_csv``
                # (presumably a pandas DataFrame -- confirm with callers).
                img_trk_or_csv.to_csv(this_file)

            meta_suffix = suffix.split('.')[0] + '.json'
            afd.write_json(
                get_fname(
                    subses_dict, meta_suffix,
                    tracking_params=tracking_params,
                    segmentation_params=segmentation_params),
                meta)
            return this_file
Ejemplo n.º 8
0
def export_rois(subses_dict, bundle_dict, mapping, dwi_affine):
    """Warp every bundle ROI into subject space and save it as NIfTI.

    ROIs are written to ``<results_dir>/ROIs``. Files that already exist
    are not recomputed but are still listed in the result. Each newly
    written ROI gets an (empty) JSON sidecar.

    Parameters
    ----------
    subses_dict : dict
        Needs ``'results_dir'``; also passed to ``get_fname``.
    bundle_dict : dict
        Bundle name -> dict with parallel ``'ROIs'`` and ``'rules'``
        sequences (a truthy rule marks an inclusion ROI).
    mapping :
        Passed to ``auv.transform_inverse_roi``.
    dwi_affine : array
        Affine used for the saved ROI images.

    Returns
    -------
    dict
        ``{'rois_file': {bundle: [roi_path, ...]}}``.
    """
    rois_dir = op.join(subses_dict['results_dir'], 'ROIs')
    os.makedirs(rois_dir, exist_ok=True)
    roi_files = {}
    for bundle in bundle_dict:
        roi_files[bundle] = []
        for ii, roi in enumerate(bundle_dict[bundle]['ROIs']):
            if bundle_dict[bundle]['rules'][ii]:
                inclusion = 'include'
            else:
                inclusion = 'exclude'

            fname = op.split(
                get_fname(
                    subses_dict,
                    f'_desc-ROI-{bundle}-{ii + 1}-{inclusion}.nii.gz'))

            fname = op.join(rois_dir, fname[1])
            if not op.exists(fname):
                warped_roi = auv.transform_inverse_roi(
                    roi,
                    mapping,
                    bundle_name=bundle)

                # Cast to float32, so that it can be read in by MI-Brain:
                logger.info(f"Saving {fname}")
                nib.save(
                    nib.Nifti1Image(
                        warped_roi.astype(np.float32),
                        dwi_affine), fname)
                meta = dict()
                # Strip the whole ".nii.gz" double extension.
                # ``fname.split('.')[0]`` cut the path at the first dot
                # anywhere in it (e.g. in a parent directory name).
                if fname.endswith('.nii.gz'):
                    meta_stem = fname[:-len('.nii.gz')]
                else:
                    meta_stem = op.splitext(fname)[0]
                afd.write_json(meta_stem + '.json', meta)
            roi_files[bundle].append(fname)
    return {'rois_file': roi_files}
Ejemplo n.º 9
0
    def _export_bundles(self, row):
        """Write every bundle of the cleaned and uncleaned tractograms to its own file.

        The tractograms come from ``self._clean_bundles`` and
        ``self._segment``. Bundles are written to
        ``<row['results_dir']>/clean_bundles`` and ``.../bundles``
        respectively, each with a JSON sidecar naming its source. The
        ``"whole_brain"`` entry of ``self.bundle_dict`` is skipped.
        """
        odf_model = self.tracking_params['odf_model']
        directions = self.tracking_params['directions']
        seg_algo = self.segmentation_params['seg_algo']
        # The world -> voxel transform is the same for every bundle.
        inv_affine = np.linalg.inv(row['dwi_affine'])

        for func, folder in zip([self._clean_bundles, self._segment],
                                ['clean_bundles', 'bundles']):
            bundles_file = func(row)

            bundles_dir = op.join(row['results_dir'], folder)
            os.makedirs(bundles_dir, exist_ok=True)
            trk = nib.streamlines.load(bundles_file)
            tg = trk.tractogram
            streamlines = tg.streamlines
            for bundle in self.bundle_dict:
                if bundle == "whole_brain":
                    continue
                uid = self.bundle_dict[bundle]['uid']
                idx = np.where(tg.data_per_streamline['bundle'] == uid)[0]
                this_sl = dtu.transform_tracking_output(
                    streamlines[idx], inv_affine)

                this_tgm = StatefulTractogram(this_sl, row['dwi_img'],
                                              Space.VOX)

                fname = op.split(
                    self._get_fname(
                        row, f'_space-RASMM_model-{odf_model}_desc-'
                        f'{directions}-{seg_algo}-{bundle}'
                        f'_tractography.trk'))
                # ``bundles_dir`` already contains the results directory.
                # Joining it onto ``fname[0]`` as well prepended that
                # directory a second time whenever ``results_dir`` was a
                # relative path.
                fname = op.join(bundles_dir, fname[1])
                save_tractogram(this_tgm, fname, bbox_valid_check=False)
                meta = dict(source=bundles_file)
                # Replace only the final extension; ``split('.')[0]``
                # truncated at the first dot anywhere in the path.
                meta_fname = op.splitext(fname)[0] + '.json'
                afd.write_json(meta_fname, meta)
Ejemplo n.º 10
0
 def _csd(
     self,
     row,
     response=None,
     sh_order=8,
     lambda_=1,
     tau=0.1,
 ):
     """Fit (or reuse) a CSD model and return the path to its SH coefficients.

     The fit is done with ``csd_fit`` inside the brain mask from
     ``self._brain_mask``. Nothing is recomputed if the output exists,
     unless ``self.force_recompute`` is set.

     Parameters
     ----------
     row : mapping
         Needs ``'dwi_file'``, ``'gtab'`` and ``'dwi_affine'``.
     response, sh_order, lambda_, tau :
         Forwarded to ``csd_fit`` and recorded in the JSON sidecar.

     Returns
     -------
     str
         Path to ``_model-CSD_diffmodel.nii.gz``.
     """
     csd_params_file = self._get_fname(row, '_model-CSD_diffmodel.nii.gz')
     if not self.force_recompute and op.exists(csd_params_file):
         return csd_params_file

     dwi_data = nib.load(row['dwi_file']).get_fdata()
     gtab = row['gtab']
     brain_mask = nib.load(self._brain_mask(row)).get_fdata()
     fitted = csd_fit(gtab, dwi_data, mask=brain_mask, response=response,
                      sh_order=sh_order, lambda_=lambda_, tau=tau)
     nib.save(nib.Nifti1Image(fitted.shm_coeff, row['dwi_affine']),
              csd_params_file)

     sidecar = dict(SphericalHarmonicDegree=sh_order,
                    ResponseFunctionTensor=response,
                    SphericalHarmonicBasis="DESCOTEAUX",
                    ModelURL=f"{DIPY_GH}reconst/csdeconv.py",
                    lambda_=lambda_,
                    tau=tau)
     afd.write_json(self._get_fname(row, '_model-CSD_diffmodel.json'),
                    sidecar)
     return csd_params_file
Ejemplo n.º 11
0
    def get_for_subses(self,
                       subses_dict,
                       reg_subject,
                       reg_template,
                       subject_sls=None,
                       template_sls=None):
        """Return the subject-to-template mapping, generating it if missing.

        A missing mapping is produced with ``self.gen_mapping``, written
        with ``reg.write_mapping``, and its generation time is stored in the
        JSON sidecar. The mapping is then read back with
        ``reg.read_mapping``. When ``self.use_prealign`` is set, the inverse
        of the prealign affine is passed as ``prealign``.
        """
        mapping_file, meta_fname = self.get_fnames(self.extension, subses_dict)

        if self.use_prealign:
            prealign_path = self.prealign(
                subses_dict, reg_subject, reg_template)
            reg_prealign = np.load(prealign_path)
        else:
            reg_prealign = None

        if not op.exists(mapping_file):
            tic = time()
            new_mapping = self.gen_mapping(
                subses_dict, reg_subject, reg_template,
                subject_sls, template_sls, reg_prealign)
            elapsed = time() - tic

            reg.write_mapping(new_mapping, mapping_file)
            afd.write_json(
                meta_fname, dict(type="displacementfield", timing=elapsed))

        prealign_inv = (
            np.linalg.inv(reg_prealign) if self.use_prealign else None)
        return reg.read_mapping(mapping_file,
                                subses_dict['dwi_file'],
                                reg_template,
                                prealign=prealign_inv)
Ejemplo n.º 12
0
    def get_for_row(self, afq_object, row):
        """Return the path of this scalar's NIfTI file for ``row``.

        If the file is missing, the data comes from ``self.get_data``. It
        is saved via ``afq_object.log_and_save_nii`` with
        ``row['dwi_affine']``, and the returned metadata is written to a
        JSON sidecar.
        """
        scalar_file = afq_object._get_fname(row, f'_model-{self.name}.nii.gz')
        if op.exists(scalar_file):
            return scalar_file

        scalar_data, meta = self.get_data(afq_object, row)
        scalar_img = nib.Nifti1Image(scalar_data, row['dwi_affine'])
        afq_object.log_and_save_nii(scalar_img, scalar_file)
        afd.write_json(
            afq_object._get_fname(row, f'_model-{self.name}.json'), meta)
        return scalar_file
Ejemplo n.º 13
0
 def _dti_cfa(self, row):
     """Compute (or reuse) the DTI directionally-encoded (color) FA map.

     The map is taken from ``self._dti_fit(row).color_fa``, saved with
     ``row['dwi_affine']`` and accompanied by an empty JSON sidecar.

     Returns
     -------
     str
         Path to ``_model-DTI_desc-DEC_FA.nii.gz``.
     """
     dti_cfa_file = self._get_fname(row, '_model-DTI_desc-DEC_FA.nii.gz')
     if not self.force_recompute and op.exists(dti_cfa_file):
         return dti_cfa_file

     color_fa = self._dti_fit(row).color_fa
     nib.save(nib.Nifti1Image(color_fa, row['dwi_affine']), dti_cfa_file)
     afd.write_json(self._get_fname(row, '_model-DTI_desc-DEC_FA.json'), {})
     return dti_cfa_file
Ejemplo n.º 14
0
 def _dti_md(self, row):
     """Compute (or reuse) the DTI mean-diffusivity map.

     The map is taken from ``self._dti_fit(row).md``, saved with
     ``row['dwi_affine']`` and accompanied by an empty JSON sidecar.

     Returns
     -------
     str
         Path to ``_model-DTI_MD.nii.gz``.
     """
     dti_md_file = self._get_fname(row, '_model-DTI_MD.nii.gz')
     if not self.force_recompute and op.exists(dti_md_file):
         return dti_md_file

     mean_diffusivity = self._dti_fit(row).md
     nib.save(nib.Nifti1Image(mean_diffusivity, row['dwi_affine']),
              dti_md_file)
     afd.write_json(self._get_fname(row, '_model-DTI_MD.json'), {})
     return dti_md_file
Ejemplo n.º 15
0
    def _dti_pdd(self, row):
        """Compute (or reuse) the DTI principal-diffusion-direction map.

        The x component of every direction vector is negated before saving.
        The map is saved with ``row['dwi_affine']`` and accompanied by an
        empty JSON sidecar.

        Returns
        -------
        str
            Path to ``_model-DTI_PDD.nii.gz``.
        """
        dti_pdd_file = self._get_fname(row, '_model-DTI_PDD.nii.gz')
        if self.force_recompute or not op.exists(dti_pdd_file):
            tf = self._dti_fit(row)
            # Index the single direction axis explicitly. ``.squeeze()``
            # also dropped any spatial axis of length 1, which would save an
            # image whose shape no longer matches the volume. This assumes
            # ``tf.directions`` is shaped (..., 1, 3), as for DIPY's
            # TensorFit -- TODO confirm if other fit types are used.
            # Multiplying (instead of assigning in place) yields a new
            # array, so the fit object's data, which ``directions`` may be
            # a view of, is not modified when inverting x.
            pdd = tf.directions[..., 0, :] * np.array([-1, 1, 1])

            nib.save(nib.Nifti1Image(pdd, row['dwi_affine']), dti_pdd_file)
            meta_fname = self._get_fname(row, '_model-DTI_PDD.json')
            meta = dict()
            afd.write_json(meta_fname, meta)
        return dti_pdd_file
Ejemplo n.º 16
0
 def _b0(self, row):
     """Compute (or reuse) the mean b0 image of the DWI data.

     The volumes flagged by ``row['gtab'].b0s_mask`` are averaged along the
     last axis. The result is saved with the DWI affine, and a JSON
     sidecar records the gradient table's ``b0_threshold`` and the source
     file.

     Returns
     -------
     str
         Path to ``_b0.nii.gz``.
     """
     b0_file = self._get_fname(row, '_b0.nii.gz')
     if self.force_recompute or not op.exists(b0_file):
         img = nib.load(row['dwi_file'])
         data = img.get_fdata()
         gtab = row['gtab']
         # Average the b0 volumes. The previous ``~gtab.b0s_mask``
         # selected the diffusion-weighted volumes instead, contradicting
         # the output name and the recorded ``b0_threshold``.
         mean_b0 = np.mean(data[..., gtab.b0s_mask], -1)
         mean_b0_img = nib.Nifti1Image(mean_b0, img.affine)
         nib.save(mean_b0_img, b0_file)
         meta = dict(b0_threshold=gtab.b0_threshold, source=row['dwi_file'])
         meta_fname = self._get_fname(row, '_b0.json')
         afd.write_json(meta_fname, meta)
     return b0_file
Ejemplo n.º 17
0
    def _segment(self, row):
        """Segment the whole-brain tractography of ``row`` into bundles.

        Streamlines from ``self._streamlines`` are segmented with
        ``seg.Segmentation(**self.segmentation_params)``, using the mapping
        from ``self._mapping`` and, when ``self.use_prealign`` is set, the
        affine from ``self._reg_prealign``. The bundles are saved as one
        labelled tractogram with a JSON sidecar. If
        ``segmentation_params['return_idx']`` is true, the per-bundle
        streamline indices are also written to ``<stem>_idx.json``.

        Returns
        -------
        str
            Path to the segmented ``.trk`` file.
        """
        odf_model = self.tracking_params["odf_model"]
        directions = self.tracking_params["directions"]
        seg_algo = self.segmentation_params["seg_algo"]
        bundles_file = self._get_fname(
            row, f'_space-RASMM_model-{odf_model}_desc-{directions}-'
            f'{seg_algo}_tractography.trk')

        if self.force_recompute or not op.exists(bundles_file):
            streamlines_file = self._streamlines(row)

            img = nib.load(row['dwi_file'])
            tg = load_tractogram(streamlines_file, img, Space.VOX)
            if self.use_prealign:
                reg_prealign = np.load(self._reg_prealign(row))
            else:
                reg_prealign = None

            segmentation = seg.Segmentation(**self.segmentation_params)
            bundles = segmentation.segment(self.bundle_dict,
                                           tg,
                                           row['dwi_file'],
                                           row['bval_file'],
                                           row['bvec_file'],
                                           reg_template=self.reg_template,
                                           mapping=self._mapping(row),
                                           reg_prealign=reg_prealign)

            # Strip only the trailing ".trk". ``split('.')[0]`` cut the path
            # at its first dot, e.g. inside a directory name.
            bundles_stem = op.splitext(bundles_file)[0]

            if self.segmentation_params['return_idx']:
                idx = {
                    bundle: bundles[bundle]['idx'].tolist()
                    for bundle in self.bundle_dict
                }
                afd.write_json(bundles_stem + '_idx.json', idx)
                bundles = {
                    bundle: bundles[bundle]['sl']
                    for bundle in self.bundle_dict
                }

            tgram = aus.bundles_to_tgram(bundles, self.bundle_dict, img)
            save_tractogram(tgram, bundles_file)
            meta = dict(source=streamlines_file,
                        Parameters=self.segmentation_params)
            afd.write_json(bundles_stem + '.json', meta)

        return bundles_file
Ejemplo n.º 18
0
    def _tract_profiles(self, row, weighting=None):
        """Compute (or reuse) the profile of every scalar along every bundle.

        Profiles are computed with ``afq_profile`` on the cleaned bundles
        from ``self._clean_bundles``. They are written as a long-format CSV
        with columns ``profiles``, ``bundle``, ``node`` and ``scalar``,
        plus a JSON sidecar.

        Parameters
        ----------
        row : mapping
            Per-subject record; ``row["dwi_affine"]`` is passed to
            ``afq_profile``.
        weighting : optional
            Currently unused; kept for interface compatibility.

        Returns
        -------
        str
            Path to the ``_profiles.csv`` file.
        """
        profiles_file = self._get_fname(row, '_profiles.csv')
        if self.force_recompute or not op.exists(profiles_file):
            bundles_file = self._clean_bundles(row)
            # Map bundle uid -> bundle name (excluding "whole_brain").
            reverse_dict = {
                self.bundle_dict[k]['uid']: k
                for k in self.bundle_dict.keys()
                if k != "whole_brain"
            }

            bundle_names = []
            profiles = []
            node_numbers = []
            scalar_names = []

            trk = nib.streamlines.load(bundles_file)
            bundle_labels = trk.tractogram.data_per_streamline['bundle']
            # Group streamlines by label once instead of re-scanning the
            # labels for every scalar.
            streamlines_by_label = [
                (b, trk.streamlines[np.where(bundle_labels == b)[0]])
                for b in np.unique(bundle_labels)
            ]
            for scalar in self.scalars:
                scalar_file = self._scalar_dict[scalar](self, row)
                scalar_data = nib.load(scalar_file).get_fdata()
                for b, this_sl in streamlines_by_label:
                    bundle_name = reverse_dict[b]
                    this_profile = afq_profile(scalar_data, this_sl,
                                               row["dwi_affine"])
                    nodes = list(np.arange(this_profile.shape[0]))
                    bundle_names.extend([bundle_name] * len(nodes))
                    node_numbers.extend(nodes)
                    scalar_names.extend([scalar] * len(nodes))
                    profiles.extend(list(this_profile))

            profile_dframe = pd.DataFrame(
                dict(profiles=profiles,
                     bundle=bundle_names,
                     node=node_numbers,
                     scalar=scalar_names))
            profile_dframe.to_csv(profiles_file)
            meta = dict(source=bundles_file,
                        parameters=get_default_args(afq_profile))
            # Replace only the final ".csv"; ``split('.')[0]`` truncated at
            # the first dot anywhere in the path.
            meta_fname = op.splitext(profiles_file)[0] + '.json'
            afd.write_json(meta_fname, meta)

        return profiles_file
Ejemplo n.º 19
0
        def get_for_subses_getter(
                subses_dict, dwi_affine, reg_template, mapping):
            # Return the path of this scalar's NIfTI file. If it is missing,
            # build it from ``self.get_data`` (saved with ``dwi_affine``)
            # and write the accompanying metadata to a JSON sidecar.
            scalar_file = get_fname(subses_dict, f'_model-{self.name}.nii.gz')
            if op.exists(scalar_file):
                return scalar_file

            scalar_data, meta = self.get_data(
                subses_dict, dwi_affine, reg_template, mapping)
            nib.save(nib.Nifti1Image(scalar_data, dwi_affine), scalar_file)
            afd.write_json(
                get_fname(subses_dict, f'_model-{self.name}.json'), meta)
            return scalar_file
Ejemplo n.º 20
0
 def _dti(self, row):
     """Fit (or reuse) a diffusion tensor model and return its parameter file.

     The fit is done with ``dti_fit`` inside the brain mask from
     ``self._brain_mask``. The model parameters are saved with
     ``row['dwi_affine']``. The sidecar records ``FitMethod="WLS"`` and no
     outlier rejection; whether ``dti_fit`` really uses WLS is not visible
     here.

     Returns
     -------
     str
         Path to ``_model-DTI_diffmodel.nii.gz``.
     """
     dti_params_file = self._get_fname(row, '_model-DTI_diffmodel.nii.gz')
     if not self.force_recompute and op.exists(dti_params_file):
         return dti_params_file

     dwi_data = nib.load(row['dwi_file']).get_fdata()
     gtab = row['gtab']
     brain_mask = nib.load(self._brain_mask(row)).get_fdata()
     fitted = dti_fit(gtab, dwi_data, mask=brain_mask)
     nib.save(nib.Nifti1Image(fitted.model_params, row['dwi_affine']),
              dti_params_file)

     sidecar = dict(Parameters=dict(FitMethod="WLS"),
                    OutlierRejection=False,
                    ModelURL=f"{DIPY_GH}reconst/dti.py")
     afd.write_json(self._get_fname(row, '_model-DTI_diffmodel.json'),
                    sidecar)
     return dti_params_file
Ejemplo n.º 21
0
 def _reg_prealign(self, row):
     """Compute (or reuse) an affine pre-alignment of the b0 to the MNI template.

     The b0 image from ``self._b0`` is registered to
     ``dpd.read_mni_template()`` with ``reg.affine_registration``. The
     affine is saved as a ``.npy`` file, with a JSON sidecar.

     Returns
     -------
     str
         Path to ``_prealign_from-DWI_to-MNI_xfm.npy``.
     """
     prealign_file = self._get_fname(row,
                                     '_prealign_from-DWI_to-MNI_xfm.npy')
     if not self.force_recompute and op.exists(prealign_file):
         return prealign_file

     moving_img = nib.load(self._b0(row))
     static_img = dpd.read_mni_template()
     _, affine_xform = reg.affine_registration(
         moving_img.get_fdata(), static_img.get_fdata(),
         moving_img.affine, static_img.affine)
     np.save(prealign_file, affine_xform)
     afd.write_json(
         self._get_fname(row, '_prealign_from-DWI_to-MNI_xfm.json'),
         dict(type="rigid"))
     return prealign_file
Ejemplo n.º 22
0
    def _streamlines(self, row):
        """Run (or reuse) whole-brain tractography for ``row``.

        The ODF model parameters come from ``self._dti`` or ``self._csd``,
        depending on ``tracking_params["odf_model"]``. The white-matter mask
        from ``self._wm_mask`` is used as both seed and stop mask. Tracking
        is done with ``aft.track``. The tractogram is saved in voxel space,
        and a JSON sidecar describes the run.

        Returns
        -------
        str
            Path to the tractography ``.trk`` file.

        Raises
        ------
        ValueError
            If tractography must be computed and ``odf_model`` is neither
            ``"DTI"`` nor ``"CSD"``.
        """
        odf_model = self.tracking_params["odf_model"]
        directions = self.tracking_params["directions"]

        streamlines_file = self._get_fname(
            row, f'_space-RASMM_model-{odf_model}_desc-{directions}' +
            '_tractography.trk')

        if self.force_recompute or not op.exists(streamlines_file):
            if odf_model == "DTI":
                params_file = self._dti(row)
            elif odf_model == "CSD":
                params_file = self._csd(row)
            else:
                # Without this branch ``params_file`` stayed unbound and
                # the call below failed with an UnboundLocalError.
                raise ValueError(
                    f"Unsupported odf_model {odf_model!r}; "
                    "expected 'DTI' or 'CSD'")
            wm_mask_fname = self._wm_mask(row)
            wm_mask = nib.load(wm_mask_fname).get_fdata().astype(bool)
            # NOTE: the masks are stored on the shared
            # ``self.tracking_params`` dict, as before.
            self.tracking_params['seed_mask'] = wm_mask
            self.tracking_params['stop_mask'] = wm_mask
            sft = aft.track(params_file, **self.tracking_params)
            sft.to_vox()
            meta_directions = {"det": "deterministic", "prob": "probabilistic"}

            meta = dict(TractographyClass="local",
                        TractographyMethod=meta_directions[
                            self.tracking_params["directions"]],
                        Count=len(sft.streamlines),
                        Seeding=dict(
                            ROI=wm_mask_fname,
                            n_seeds=self.tracking_params["n_seeds"],
                            random_seeds=self.tracking_params["random_seeds"]),
                        Constraints=dict(AnatomicalImage=wm_mask_fname),
                        Parameters=dict(
                            Units="mm",
                            StepSize=self.tracking_params["step_size"],
                            MinimumLength=self.tracking_params["min_length"],
                            MaximumLength=self.tracking_params["max_length"],
                            Unidirectional=False))

            # Save the data before its sidecar, so a failed save does not
            # leave a sidecar without a tractogram behind.
            save_tractogram(sft, streamlines_file, bbox_valid_check=False)
            meta_fname = self._get_fname(
                row, f'_space-RASMM_model-{odf_model}_desc-'
                f'{directions}_tractography.json')
            afd.write_json(meta_fname, meta)

        return streamlines_file
Ejemplo n.º 23
0
 def prealign(self, subses_dict, reg_subject, reg_template, save=True):
     """Affinely pre-align ``reg_subject`` to ``reg_template``.

     The affine comes from ``affine_registration(..., **self.affine_kwargs)``.
     When ``save`` is true it is cached as
     ``_prealign_from-DWI_to-MNI_xfm.npy``, with a JSON sidecar holding the
     registration time.

     Parameters
     ----------
     subses_dict : dict
         Passed to ``get_fname`` to build file names.
     reg_subject, reg_template :
         Moving and static images for ``affine_registration``.
     save : bool
         If true, return the path of the ``.npy`` file; otherwise return
         the affine array itself.
     """
     prealign_file = get_fname(subses_dict,
                               '_prealign_from-DWI_to-MNI_xfm.npy')
     if op.exists(prealign_file):
         return prealign_file if save else np.load(prealign_file)

     tic = time()
     _, affine_xform = affine_registration(reg_subject, reg_template,
                                           **self.affine_kwargs)
     elapsed = time() - tic
     if not save:
         return affine_xform

     np.save(prealign_file, affine_xform)
     afd.write_json(
         get_fname(subses_dict, '_prealign_from-DWI_to-MNI_xfm.json'),
         dict(type="rigid", timing=elapsed))
     return prealign_file
Ejemplo n.º 24
0
    def _export_rois(self, row):
        """Warp every bundle ROI into subject space and save it as NIfTI.

        The mapping from ``self._mapping`` is read back with
        ``reg.read_mapping`` (using the inverse prealign affine when
        ``self.use_prealign`` is set). Each ROI is inverse-transformed,
        thresholded at > 0 and passed through ``auv.patch_up_roi``. ROIs are
        written to ``<row['results_dir']>/ROIs``, each with an (empty) JSON
        sidecar.
        """
        if self.use_prealign:
            reg_prealign = np.load(self._reg_prealign(row))
            reg_prealign_inv = np.linalg.inv(reg_prealign)
        else:
            reg_prealign_inv = None

        mapping = reg.read_mapping(self._mapping(row),
                                   row['dwi_file'],
                                   self.reg_template,
                                   prealign=reg_prealign_inv)

        rois_dir = op.join(row['results_dir'], 'ROIs')
        os.makedirs(rois_dir, exist_ok=True)

        for bundle in self.bundle_dict:
            for ii, roi in enumerate(self.bundle_dict[bundle]['ROIs']):

                if self.bundle_dict[bundle]['rules'][ii]:
                    inclusion = 'include'
                else:
                    inclusion = 'exclude'

                warped_roi = auv.patch_up_roi((mapping.transform_inverse(
                    roi.get_fdata(), interpolation='linear')) > 0).astype(int)

                fname = op.split(
                    self._get_fname(
                        row,
                        f'_desc-ROI-{bundle}-{ii + 1}-{inclusion}.nii.gz'))

                # ``rois_dir`` already contains the results directory.
                # Joining it onto ``fname[0]`` as well prepended that
                # directory a second time whenever ``results_dir`` was a
                # relative path.
                fname = op.join(rois_dir, fname[1])

                # Cast to float32, so that it can be read in by MI-Brain:
                nib.save(
                    nib.Nifti1Image(warped_roi.astype(np.float32),
                                    row['dwi_affine']), fname)
                meta = dict()
                # Strip the whole ".nii.gz" double extension;
                # ``split('.')[0]`` cut the path at its first dot.
                if fname.endswith('.nii.gz'):
                    meta_stem = fname[:-len('.nii.gz')]
                else:
                    meta_stem = op.splitext(fname)[0]
                afd.write_json(meta_stem + '.json', meta)
Ejemplo n.º 25
0
    def _wm_mask(self, row, wm_fa_thresh=0.2):
        """Compute (or reuse) a white-matter mask on the DWI grid.

        If ``row.index`` contains ``'seg_file'``, voxels whose segmentation
        label is in ``self.wm_labels`` are selected and resampled onto the
        DWI grid with ``reg.resample``. Otherwise the mask is
        ``FA > wm_fa_thresh``, using the FA map from ``self._dti_fa``.
        Either way the mask is passed through ``binary_dilation`` before
        it is saved, and a JSON sidecar records its source.

        Returns
        -------
        str
            Path to ``_wm_mask.nii.gz``.
        """
        wm_mask_file = self._get_fname(row, '_wm_mask.nii.gz')
        if not self.force_recompute and op.exists(wm_mask_file):
            return wm_mask_file

        dwi_img = nib.load(row['dwi_file'])
        dwi_data = dwi_img.get_fdata()

        if 'seg_file' in row.index:
            # A white-matter segmentation is available: use its labels.
            seg_img = nib.load(row['seg_file'])
            seg_labels = seg_img.get_fdata()
            # Per voxel, count how many of the white-matter labels match.
            label_hits = np.stack(
                [seg_labels == label for label in self.wm_labels],
                -1).sum(-1)
            # Bring the label mask onto the DWI grid.
            resampled = reg.resample(label_hits, dwi_data[..., 0],
                                     seg_img.affine, dwi_img.affine)
            mask = np.round(resampled).astype(int)
            meta = dict(source=row['seg_file'], wm_labels=self.wm_labels)
        else:
            # No segmentation: threshold the DTI FA map instead.
            fa_fname = self._dti_fa(row)
            mask = nib.load(fa_fname).get_fdata() > wm_fa_thresh
            meta = dict(source=fa_fname, fa_threshold=wm_fa_thresh)

        # Dilate to be sure to reach the gray matter.
        mask = binary_dilation(mask) > 0

        nib.save(nib.Nifti1Image(mask.astype(int), row['dwi_affine']),
                 wm_mask_file)
        afd.write_json(self._get_fname(row, '_wm_mask.json'), meta)
        return wm_mask_file
Ejemplo n.º 26
0
def viz_indivBundle(subses_dict,
                    dwi_affine,
                    viz_backend,
                    bundle_dict,
                    data_imap,
                    mapping_imap,
                    segmentation_imap,
                    tracking_params,
                    segmentation_params,
                    reg_template,
                    best_scalar,
                    xform_volume_indiv=False,
                    cbv_lims_indiv=[None, None],
                    xform_color_by_volume_indiv=False,
                    volume_opacity_indiv=0.3,
                    n_points_indiv=40):
    """Render and save one visualization per bundle.

    For every bundle in ``bundle_dict`` a fresh figure is built from:

    * the b0 volume;
    * that bundle's cleaned streamlines, colored by the ``best_scalar`` map;
    * the endpoint ROIs, when ``segmentation_params["filter_by_endpoints"]``
      is set;
    * the bundle's ROIs from ``mapping_imap["rois_file"]``.

    Each figure is written to ``<results_dir>/viz_bundles``. A GIF is
    written unless the backend name contains ``"no_gif"``; an HTML file is
    written when it contains ``"plotly"``.

    Parameters
    ----------
    subses_dict : dict
        Subject/session info; must contain ``'results_dir'`` plus whatever
        ``get_fname`` needs.
    dwi_affine : array
        Affine of the DWI data. Axes whose diagonal entry is negative are
        flipped in the rendering.
    viz_backend : object
        Backend exposing ``visualize_volume``, ``visualize_bundles``,
        ``visualize_roi``, ``create_gif`` and a ``backend`` name string.
    bundle_dict : dict
        Maps bundle names to dicts holding at least ``'uid'``.
    data_imap, mapping_imap, segmentation_imap : dict
        Upstream pipeline outputs: b0/scalar files, the mapping and ROIs,
        the cleaned bundles file and the scalar dict.
    tracking_params, segmentation_params : dict
        Used to name outputs. ``segmentation_params`` also controls whether
        endpoint ROIs are drawn.
    reg_template : image  (presumably nibabel -- ``.affine`` is used)
        Template in which the endpoint ROIs / AAL atlas are defined.
    best_scalar : str
        Scalar name; ``data_imap[best_scalar + "_file"]`` colors streamlines.
    xform_volume_indiv, xform_color_by_volume_indiv : bool
        Passed to ``_viz_prepare_vol`` for the b0 and coloring volumes.
    cbv_lims_indiv : list
        Color limits for the coloring volume. A shallow copy is handed to
        the backend for each bundle, so the default list is never shared.
    volume_opacity_indiv : float
        Opacity of the b0 volume.
    n_points_indiv : int
        Passed to the backend as ``n_points``.

    Returns
    -------
    bool
        Always ``True``; a ``_vizIndiv.json`` sidecar records the run time.
    """
    import copy

    mapping = mapping_imap["mapping"]
    scalar_dict = segmentation_imap["scalar_dict"]
    volume = data_imap["b0_file"]
    color_by_volume = data_imap[best_scalar + "_file"]

    start_time = time()
    volume = _viz_prepare_vol(
        volume, xform_volume_indiv, mapping, scalar_dict)
    color_by_volume = _viz_prepare_vol(
        color_by_volume, xform_color_by_volume_indiv, mapping, scalar_dict)

    # Flip every axis whose affine diagonal entry is negative.
    flip_axes = [dwi_affine[i, i] < 0 for i in range(3)]

    roi_dir = op.join(subses_dict['results_dir'], 'viz_bundles')

    def _bundle_fname(bundle_name, suffix):
        # Same basename that get_fname produces, but placed in roi_dir.
        full_fname = get_fname(
            subses_dict,
            f'_{bundle_name}_viz{suffix}',
            tracking_params=tracking_params,
            segmentation_params=segmentation_params)
        return op.join(roi_dir, op.split(full_fname)[1])

    for bundle_name in bundle_dict.keys():
        logger.info(f"Generating {bundle_name} visualization...")
        uid = bundle_dict[bundle_name]['uid']
        figure = viz_backend.visualize_volume(
            volume,
            opacity=volume_opacity_indiv,
            flip_axes=flip_axes,
            interact=False,
            inline=False)
        try:
            figure = viz_backend.visualize_bundles(
                segmentation_imap["clean_bundles_file"],
                color_by_volume=color_by_volume,
                # Copy per call: any limits the backend fills in for one
                # bundle must not leak into the next bundle or into the
                # shared default list used by later calls.
                cbv_lims=copy.copy(cbv_lims_indiv),
                bundle_dict=bundle_dict,
                bundle=uid,
                n_points=n_points_indiv,
                flip_axes=flip_axes,
                interact=False,
                inline=False,
                figure=figure)
        except ValueError:
            logger.info(
                "No streamlines found to visualize for "
                + bundle_name)

        if segmentation_params["filter_by_endpoints"]:
            warped_rois = []
            endpoint_info = segmentation_params["endpoint_info"]
            if endpoint_info is not None:
                start_p = endpoint_info[bundle_name]['startpoint']
                end_p = endpoint_info[bundle_name]['endpoint']
                for pp in [start_p, end_p]:
                    pp = resample(
                        pp.get_fdata(),
                        reg_template,
                        pp.affine,
                        reg_template.affine).get_fdata()

                    atlas_roi = np.zeros(pp.shape)
                    atlas_roi[np.where(pp > 0)] = 1
                    warped_roi = auv.transform_inverse_roi(
                        atlas_roi,
                        mapping,
                        bundle_name=bundle_name)
                    warped_rois.append(warped_roi)
            else:
                aal_atlas = afd.read_aal_atlas(reg_template)
                atlas = aal_atlas['atlas'].get_fdata()
                aal_targets = afd.bundles_to_aal(
                    [bundle_name], atlas=atlas)[0]
                for targ in aal_targets:
                    if targ is not None:
                        aal_roi = np.zeros(atlas.shape[:3])
                        aal_roi[targ[:, 0],
                                targ[:, 1],
                                targ[:, 2]] = 1
                        warped_roi = auv.transform_inverse_roi(
                            aal_roi,
                            mapping,
                            bundle_name=bundle_name)
                        warped_rois.append(warped_roi)
            for i, roi in enumerate(warped_rois):
                figure = viz_backend.visualize_roi(
                    roi,
                    name=f"{bundle_name} endpoint ROI {i}",
                    flip_axes=flip_axes,
                    inline=False,
                    interact=False,
                    figure=figure)

        for i, roi in enumerate(mapping_imap["rois_file"][bundle_name]):
            figure = viz_backend.visualize_roi(
                roi,
                name=f"{bundle_name} ROI {i}",
                flip_axes=flip_axes,
                inline=False,
                interact=False,
                figure=figure)

        os.makedirs(roi_dir, exist_ok=True)
        if "no_gif" not in viz_backend.backend:
            viz_backend.create_gif(figure, _bundle_fname(bundle_name, '.gif'))
        if "plotly" in viz_backend.backend:
            figure.write_html(_bundle_fname(bundle_name, '.html'))

    meta_fname = get_fname(
        subses_dict, '_vizIndiv.json',
        tracking_params=tracking_params,
        segmentation_params=segmentation_params)
    meta = dict(Timing=time() - start_time)
    afd.write_json(meta_fname, meta)
    return True
Ejemplo n.º 27
0
def viz_bundles(subses_dict,
                dwi_affine,
                viz_backend,
                bundle_dict,
                data_imap,
                mapping_imap,
                segmentation_imap,
                tracking_params,
                segmentation_params,
                best_scalar,
                xform_volume=False,
                cbv_lims=[None, None],
                xform_color_by_volume=False,
                volume_opacity=0.3,
                n_points=40):
    """Render all cleaned bundles in a single figure and save it.

    The figure shows the b0 volume with every cleaned bundle overlaid,
    colored by the ``best_scalar`` map. A GIF is written unless the backend
    name contains ``"no_gif"``, and an HTML file is written when it contains
    ``"plotly"``. A ``_viz.json`` sidecar records the run time.

    Parameters
    ----------
    subses_dict : dict
        Subject/session info used by ``get_fname`` to name outputs.
    dwi_affine : array
        Affine of the DWI data. Axes whose diagonal entry is negative are
        flipped in the rendering.
    viz_backend : object
        Backend exposing ``visualize_volume``, ``visualize_bundles``,
        ``create_gif`` and a ``backend`` name string.
    bundle_dict : dict
        Bundle definitions forwarded to the backend.
    data_imap, mapping_imap, segmentation_imap : dict
        Upstream pipeline outputs: b0/scalar files, the mapping, the cleaned
        bundles file and the scalar dict.
    tracking_params, segmentation_params : dict
        Used to name outputs.
    best_scalar : str
        Scalar name; ``data_imap[best_scalar + "_file"]`` colors streamlines.
    xform_volume, xform_color_by_volume : bool
        Passed to ``_viz_prepare_vol`` for the b0 and coloring volumes.
    cbv_lims : list
        Color limits for the coloring volume. A shallow copy is handed to
        the backend, so the default list is never shared between calls.
    volume_opacity : float
        Opacity of the b0 volume.
    n_points : int
        Passed to the backend as ``n_points``.

    Returns
    -------
    figure
        The figure returned by ``viz_backend.visualize_bundles``.
    """
    import copy

    mapping = mapping_imap["mapping"]
    scalar_dict = segmentation_imap["scalar_dict"]
    volume = data_imap["b0_file"]
    color_by_volume = data_imap[best_scalar + "_file"]
    start_time = time()
    volume = _viz_prepare_vol(volume, xform_volume, mapping, scalar_dict)
    color_by_volume = _viz_prepare_vol(
        color_by_volume, xform_color_by_volume, mapping, scalar_dict)

    # Flip every axis whose affine diagonal entry is negative.
    flip_axes = [dwi_affine[i, i] < 0 for i in range(3)]

    figure = viz_backend.visualize_volume(
        volume,
        opacity=volume_opacity,
        flip_axes=flip_axes,
        interact=False,
        inline=False)

    figure = viz_backend.visualize_bundles(
        segmentation_imap["clean_bundles_file"],
        color_by_volume=color_by_volume,
        # Copy so limits the backend may fill in cannot persist in the
        # shared default list across calls.
        cbv_lims=copy.copy(cbv_lims),
        bundle_dict=bundle_dict,
        n_points=n_points,
        flip_axes=flip_axes,
        interact=False,
        inline=False,
        figure=figure)

    if "no_gif" not in viz_backend.backend:
        fname = get_fname(
            subses_dict, '_viz.gif',
            tracking_params=tracking_params,
            segmentation_params=segmentation_params)

        viz_backend.create_gif(figure, fname)
    if "plotly" in viz_backend.backend:
        fname = get_fname(
            subses_dict, '_viz.html',
            tracking_params=tracking_params,
            segmentation_params=segmentation_params)

        figure.write_html(fname)
    meta_fname = get_fname(
        subses_dict, '_viz.json',
        tracking_params=tracking_params,
        segmentation_params=segmentation_params)
    meta = dict(Timing=time() - start_time)
    afd.write_json(meta_fname, meta)
    return figure
Ejemplo n.º 28
0
    def _clean_bundles(self, row):
        """Clean every segmented bundle and save them as one tractogram.

        The result is only recomputed when ``self.force_recompute`` is set
        or the output file does not exist yet.

        Each bundle in ``self.bundle_dict`` except ``"whole_brain"`` is
        processed in three steps:

        1. select its streamlines from the segmentation output by ``'uid'``;
        2. clean them with ``seg.clean_bundle(..., **self.clean_params)``;
        3. append the result to the combined tractogram.

        A JSON sidecar records the source file and ``seg.clean_bundle``'s
        default arguments. When ``self.clean_params['return_idx']`` is true,
        the retained indices are also written. They are mapped through the
        segmentation step's index file and saved to
        ``clean_bundles_file.split('.')[0] + '_idx.json'``.

        Parameters
        ----------
        row : pandas.Series  (presumably)
            Per-subject/session record; ``'dwi_img'`` and ``'dwi_affine'``
            are read from it.

        Returns
        -------
        str
            Path to the cleaned tractography file.
        """
        odf_model = self.tracking_params['odf_model']
        directions = self.tracking_params['directions']
        seg_algo = self.segmentation_params['seg_algo']
        clean_bundles_file = self._get_fname(
            row, f'_space-RASMM_model-{odf_model}_desc-{directions}-'
            f'{seg_algo}-clean_tractography.trk')

        if self.force_recompute or not op.exists(clean_bundles_file):
            bundles_file = self._segment(row)
            want_idx = self.clean_params['return_idx']

            sft = load_tractogram(bundles_file, row['dwi_img'], Space.VOX)

            tgram = nib.streamlines.Tractogram([], {'bundle': []})
            return_idx = {}
            # Parsed contents of the segmentation step's index file. It is
            # read at most once, on first use, instead of being re-opened
            # and re-parsed for every bundle.
            seg_idx = None

            for b in self.bundle_dict.keys():
                if b == "whole_brain":
                    continue
                uid = self.bundle_dict[b]['uid']
                idx = np.where(sft.data_per_streamline['bundle'] == uid)[0]
                this_tg = StatefulTractogram(sft.streamlines[idx],
                                             row['dwi_img'], Space.VOX)
                this_tg = seg.clean_bundle(this_tg, **self.clean_params)
                if want_idx:
                    # With return_idx, clean_bundle yields a
                    # (tractogram, indices) pair.
                    this_tg, this_idx = this_tg
                    if seg_idx is None:
                        # NOTE(review): split('.')[0] cuts at the first dot
                        # anywhere in the path; presumably it mirrors how
                        # the index file was named when written -- verify.
                        idx_file = bundles_file.split('.')[0] + '_idx.json'
                        with open(idx_file) as ff:
                            seg_idx = json.load(ff)
                    return_idx[b] = np.array(seg_idx[b])[this_idx].tolist()
                this_tgram = nib.streamlines.Tractogram(
                    this_tg.streamlines,
                    data_per_streamline={
                        'bundle': (len(this_tg) * [uid])
                    },
                    affine_to_rasmm=row['dwi_affine'])
                tgram = aus.add_bundles(tgram, this_tgram)
            save_tractogram(
                StatefulTractogram(
                    tgram.streamlines,
                    sft,
                    Space.VOX,
                    data_per_streamline=tgram.data_per_streamline),
                clean_bundles_file)

            # NOTE(review): this records clean_bundle's *default* arguments,
            # not the values in self.clean_params actually used above --
            # confirm this is intended.
            seg_args = get_default_args(seg.clean_bundle)
            for k in seg_args:
                if callable(seg_args[k]):
                    seg_args[k] = seg_args[k].__name__

            meta = dict(source=bundles_file, Parameters=seg_args)
            meta_fname = clean_bundles_file.split('.')[0] + '.json'
            afd.write_json(meta_fname, meta)

            if want_idx:
                afd.write_json(
                    clean_bundles_file.split('.')[0] + '_idx.json', return_idx)

        return clean_bundles_file