Beispiel #1
0
    def get_for_subses(self,
                       subses_dict,
                       reg_subject,
                       reg_template,
                       subject_sls=None,
                       template_sls=None):
        """
        Return the registration mapping for one subject/session.

        If the mapping file (from ``self.get_fnames``) does not exist yet, it
        is generated with ``self.gen_mapping`` and written to disk. A JSON
        sidecar recording the generation time is written alongside it. The
        mapping is then always read back from that file.

        Parameters
        ----------
        subses_dict : dict
            Subject/session information. Its ``'dwi_file'`` entry is passed
            to ``reg.read_mapping``.
        reg_subject, reg_template
            Registration inputs forwarded to ``self.prealign`` (when
            prealignment is used) and ``self.gen_mapping``. ``reg_template``
            is also passed to ``reg.read_mapping``.
        subject_sls, template_sls : optional
            Streamlines forwarded to ``self.gen_mapping``.

        Returns
        -------
        The object returned by ``reg.read_mapping``.
        """
        mapping_path, meta_path = self.get_fnames(self.extension, subses_dict)

        prealign_affine = None
        if self.use_prealign:
            prealign_affine = np.load(
                self.prealign(subses_dict, reg_subject, reg_template))

        if not op.exists(mapping_path):
            # Generate and time the mapping only when no cached file exists.
            t0 = time()
            new_mapping = self.gen_mapping(
                subses_dict, reg_subject, reg_template,
                subject_sls, template_sls, prealign_affine)
            elapsed = time() - t0

            reg.write_mapping(new_mapping, mapping_path)
            afd.write_json(
                meta_path, dict(type="displacementfield", timing=elapsed))

        # read_mapping receives the inverse of the prealignment affine.
        inverse_prealign = (np.linalg.inv(prealign_affine)
                            if self.use_prealign else None)
        return reg.read_mapping(mapping_path,
                                subses_dict['dwi_file'],
                                reg_template,
                                prealign=inverse_prealign)
Beispiel #2
0
    def get_for_row(self, afq_object, row):
        """
        Return the registration mapping for one row of an AFQ object.

        If the mapping file (from ``self.get_fnames``) is missing, the
        template and subject registration images are built with
        ``afq_object._reg_img`` and the mapping is generated with
        ``self.gen_mapping``. The mapping is then written to disk with a JSON
        sidecar. The mapping is always read back from the file.

        Parameters
        ----------
        afq_object
            The AFQ object. Its ``_reg_img``, ``reg_template``,
            ``reg_subject`` and ``reg_template_img`` members are used.
        row : dict-like
            Per-subject row. ``row['dwi_file']`` is read, and the time spent
            generating the mapping is added to
            ``row['timing']['Registration']``.

        Returns
        -------
        The object returned by ``reg.read_mapping``.
        """
        mapping_path, meta_path = self.get_fnames(
            self.extension, afq_object, row)

        prealign_affine = None
        if self.use_prealign:
            prealign_affine = np.load(self.prealign(afq_object, row))

        if not op.exists(mapping_path):
            template_img, template_sls = afq_object._reg_img(
                afq_object.reg_template, False, row)
            subject_img, subject_sls = afq_object._reg_img(
                afq_object.reg_subject, True, row)

            t0 = time()
            new_mapping = self.gen_mapping(
                afq_object, row,
                template_img, template_sls,
                subject_img, subject_sls,
                prealign_affine)
            # Add the elapsed time to the running registration total.
            row['timing']['Registration'] = (
                row['timing']['Registration'] + time() - t0)

            reg.write_mapping(new_mapping, mapping_path)
            afd.write_json(meta_path, dict(type="displacementfield"))

        # read_mapping receives the inverse of the prealignment affine.
        inverse_prealign = (np.linalg.inv(prealign_affine)
                            if self.use_prealign else None)
        return reg.read_mapping(
            mapping_path,
            row['dwi_file'],
            afq_object.reg_template_img,
            prealign=inverse_prealign)
Beispiel #3
0
    def _mapping(self, row):
        """
        Return the path of the DWI-to-MNI SyN mapping file for ``row``.

        The file name depends on ``self.use_prealign``. The mapping is
        (re)computed with ``reg.syn_register_dwi`` when
        ``self.force_recompute`` is set or the file does not exist. When
        prealignment is used, the mapping's ``codomain_world2grid`` is set to
        the inverse prealignment affine before saving.

        NOTE(review): the metadata file is always named
        ``..._mapping_reg_prealign.json``, even when prealignment is off --
        confirm this is intended.
        """
        if self.use_prealign:
            suffix = '_mapping_from-DWI_to_MNI_xfm.nii.gz'
        else:
            suffix = '_mapping_from-DWI_to_MNI_xfm_without_prealign.nii.gz'
        mapping_file = self._get_fname(row, suffix)

        # Cached result is reused unless recomputation is forced.
        if not self.force_recompute and op.exists(mapping_file):
            return mapping_file

        gtab = row['gtab']
        reg_prealign = (np.load(self._reg_prealign(row))
                        if self.use_prealign else None)

        _, mapping = reg.syn_register_dwi(
            row['dwi_file'], gtab,
            template=self.reg_template,
            prealign=reg_prealign)

        if self.use_prealign:
            mapping.codomain_world2grid = np.linalg.inv(reg_prealign)

        reg.write_mapping(mapping, mapping_file)
        afd.write_json(self._get_fname(row, '_mapping_reg_prealign.json'),
                       dict(type="displacementfield"))
        return mapping_file
Beispiel #4
0
def test_syn_registration():
    """
    SyN-register the b0 subset to the T2 subset and round-trip the mapping
    through a NIfTI file.

    Checks that the reloaded mapping warps the data identically and matches
    the original attribute by attribute.
    """
    with nbtmp.InTemporaryDirectory() as tmpdir:
        syn_kwargs = dict(moving_affine=hardi_affine,
                          static_affine=MNI_T2_affine,
                          step_length=0.1,
                          metric='CC',
                          dim=3,
                          level_iters=[5, 5, 5],
                          sigma_diff=2.0,
                          radius=1,
                          prealign=None)
        warped_moving, mapping = syn_registration(subset_b0, subset_t2,
                                                  **syn_kwargs)
        npt.assert_equal(warped_moving.shape, subset_t2.shape)

        mapping_fname = op.join(tmpdir, 'mapping.nii.gz')
        write_mapping(mapping, mapping_fname)
        file_mapping = read_mapping(mapping_fname, subset_b0_img,
                                    subset_t2_img)

        # Applying the reloaded mapping must reproduce the original warp.
        npt.assert_equal(file_mapping.transform(subset_b0), warped_moving)

        # Every instance attribute must survive the round trip unchanged.
        for attr in vars(mapping):
            assert np.all(getattr(mapping, attr) ==
                          getattr(file_mapping, attr))
Beispiel #5
0
def test_syn_registration():
    """
    Check that a SyN mapping survives a write/read round trip.

    The reloaded mapping must warp the data identically and have identical
    attributes.
    """
    with nbtmp.InTemporaryDirectory() as tmpdir:
        warped_moving, mapping = syn_registration(
            subset_b0, subset_t2,
            moving_affine=hardi_affine, static_affine=MNI_T2_affine,
            step_length=0.1, metric='CC', dim=3,
            level_iters=[10, 10, 5], sigma_diff=2.0, prealign=None)
        npt.assert_equal(warped_moving.shape, subset_t2.shape)

        mapping_fname = op.join(tmpdir, 'mapping.nii.gz')
        write_mapping(mapping, mapping_fname)
        file_mapping = read_mapping(mapping_fname, subset_b0_img,
                                    subset_t2_img)

        # Same effect on the data ...
        warped_from_file = file_mapping.transform(subset_b0)
        npt.assert_equal(warped_from_file, warped_moving)

        # ... and identical, attribute by attribute.
        for attr in vars(mapping):
            original_value = getattr(mapping, attr)
            reloaded_value = getattr(file_mapping, attr)
            assert np.all(original_value == reloaded_value)
Beispiel #6
0
def _mapping(row, force_recompute=False):
    """
    Return the path to ``row``'s SyN mapping file.

    The mapping is computed first when the file is missing or
    ``force_recompute`` is true. It registers ``row['dwi_file']`` (gradient
    table ``row['gtab']``) to the template from ``dpd.read_mni_template()``.
    """
    mapping_file = _get_fname(row, '_mapping.nii.gz')
    if op.exists(mapping_file) and not force_recompute:
        return mapping_file

    gtab = row['gtab']
    template_img = dpd.read_mni_template()
    mapping = reg.syn_register_dwi(row['dwi_file'], gtab,
                                   template=template_img)
    reg.write_mapping(mapping, mapping_file)
    return mapping_file
Beispiel #7
0
def _mapping(row, force_recompute=False):
    """
    Compute (if needed) and return the path of ``row``'s mapping file.

    Registration against the MNI template from ``dpd.read_mni_template()``
    runs only when ``_mapping.nii.gz`` is absent or ``force_recompute`` is
    true.
    """
    mapping_file = _get_fname(row, '_mapping.nii.gz')
    needs_compute = not op.exists(mapping_file) or force_recompute
    if needs_compute:
        gradient_tab = row['gtab']
        mni_template = dpd.read_mni_template()
        computed = reg.syn_register_dwi(row['dwi_file'],
                                        gradient_tab,
                                        template=mni_template)
        reg.write_mapping(computed, mapping_file)
    return mapping_file
Beispiel #8
0
def main():
    """
    Segment CST and ILF fiber groups for the data listed in ``config.json``.

    Steps:

    * Read ``data_file``, ``data_bval`` and ``data_bvec`` from
      ``./config.json``.
    * Fit DTI unless ``./dti_FA.nii.gz`` already exists.
    * Load ``csa_prob.trk`` and keep every 100th streamline.
    * Register the DWI to the MNI T2 template, writing ``./mapping.nii.gz``.
    * Segment left/right CST and ILF with ``seg.segment``.
    """
    with open('config.json') as config_json:
        config = json.load(config_json)

    data_file = str(config['data_file'])
    data_bval = str(config['data_bval'])
    data_bvec = str(config['data_bvec'])

    img = nib.load(data_file)

    print("Calculating DTI...")
    if not op.exists('./dti_FA.nii.gz'):
        dti_params = dti.fit_dti(data_file, data_bval, data_bvec, out_dir='.')
    else:
        dti_params = {'FA': './dti_FA.nii.gz', 'params': './dti_params.nii.gz'}

    tg = nib.streamlines.load('csa_prob.trk').tractogram
    # Map the streamlines into img's voxel coordinates (inverse affine).
    streamlines = tg.apply_affine(np.linalg.inv(img.affine)).streamlines

    # Use only a small portion of the streamlines, for expedience:
    streamlines = streamlines[::100]

    templates = afd.read_templates()
    bundle_names = ["CST", "ILF"]

    # Each bundle is defined by two distinct waypoint ROIs: roi1 and roi2.
    bundles = {
        name + hemi: {
            'ROIs': [templates[name + '_roi1' + hemi],
                     templates[name + '_roi2' + hemi]],
            'rules': [True, True],
        }
        for name in bundle_names
        for hemi in ['_R', '_L']
    }

    print("Registering to template...")
    MNI_T2_img = dpd.read_mni_template()
    bvals, bvecs = read_bvals_bvecs(data_bval, data_bvec)
    gtab = gradient_table(bvals, bvecs, b0_threshold=100)
    mapping = reg.syn_register_dwi(data_file, gtab)
    reg.write_mapping(mapping, './mapping.nii.gz')

    print("Segmenting fiber groups...")
    fiber_groups = seg.segment(data_file,
                               data_bval,
                               data_bvec,
                               streamlines,
                               bundles,
                               reg_template=MNI_T2_img,
                               mapping=mapping,
                               as_generator=False,
                               affine=img.affine)
    """
Beispiel #9
0
def _mapping(row, reg_template, force_recompute=False):
    """
    Return the path to ``row``'s prealigned SyN mapping to ``reg_template``.

    The mapping is computed and saved first when the file is missing or
    ``force_recompute`` is true. The prealignment affine comes from
    ``_reg_prealign``, and its inverse is stored as the mapping's
    ``codomain_world2grid`` before writing.
    """
    mapping_file = _get_fname(row, '_mapping.nii.gz')
    if op.exists(mapping_file) and not force_recompute:
        return mapping_file

    gtab = row['gtab']
    prealign = np.load(_reg_prealign(row, force_recompute=force_recompute))
    _, mapping = reg.syn_register_dwi(row['dwi_file'], gtab,
                                      template=reg_template,
                                      prealign=prealign)
    mapping.codomain_world2grid = np.linalg.inv(prealign)
    reg.write_mapping(mapping, mapping_file)
    return mapping_file
Beispiel #10
0
def test_slr_registration():
    """
    Register subject streamlines to the HCP 16-bundle atlas with SLR.

    The resulting mapping is round-tripped through a ``.npy`` file and
    checked for equal warping and identical attributes.
    """
    # Subject streamlines.
    streamlines = afd.read_stanford_hardi_tractography()[
        'tractography_subsampled.trk']

    # Atlas streamlines.
    afd.fetch_hcp_atlas_16_bundles()
    atlas_fname = op.join(afd.afq_home, 'hcp_atlas_16_bundles',
                          'Atlas_in_MNI_Space_16_bundles', 'whole_brain',
                          'whole_brain_MNI.trk')
    hcp_atlas = load_tractogram(atlas_fname, 'same', bbox_valid_check=False)

    with nbtmp.InTemporaryDirectory() as tmpdir:
        mapping = slr_registration(
            streamlines, hcp_atlas.streamlines,
            moving_affine=subset_b0_img.affine,
            static_affine=subset_t2_img.affine,
            moving_shape=subset_b0_img.shape,
            static_shape=subset_t2_img.shape,
            progressive=False,
            greater_than=10,
            rm_small_clusters=1,
            rng=np.random.RandomState(seed=8))
        warped_moving = mapping.transform(subset_b0)
        npt.assert_equal(warped_moving.shape, subset_t2.shape)

        mapping_fname = op.join(tmpdir, 'mapping.npy')
        write_mapping(mapping, mapping_fname)
        file_mapping = read_mapping(mapping_fname, subset_b0_img,
                                    subset_t2_img)

        # The reloaded mapping must warp the data identically ...
        npt.assert_equal(file_mapping.transform(subset_b0), warped_moving)

        # ... and carry identical attributes.
        for attr in vars(mapping):
            assert np.all(getattr(mapping, attr) ==
                          getattr(file_mapping, attr))
Beispiel #11
0
MNI_T2_img = afd.read_mni_template()

if op.exists(op.join(working_dir, 'mapping.nii.gz')):
    # Reuse the mapping cached by a previous run.
    mapping = reg.read_mapping(op.join(working_dir, 'mapping.nii.gz'), img,
                               MNI_T2_img)
else:
    import dipy.core.gradients as dpg
    gtab = dpg.gradient_table(hardi_fbval, hardi_fbvec)
    # Mean of the b0 volumes serves as the moving image for the affine step.
    b0 = np.mean(img.get_fdata()[..., gtab.b0s_mask], -1)
    # The affine registration of b0 to the template supplies the
    # prealignment for the subsequent non-linear (SyN) registration.
    _, prealign = affine_registration(b0, MNI_T2_img.get_fdata(),
                                      img.affine, MNI_T2_img.affine)
    warped_hardi, mapping = reg.syn_register_dwi(hardi_fdata, gtab,
                                                 prealign=prealign)
    reg.write_mapping(mapping, op.join(working_dir, 'mapping.nii.gz'))

##########################################################################
# Read in bundle specification
# -------------------------------------------
# The waypoint ROIs, in addition to bundle probability maps, are stored in this
# data structure. The templates are first resampled into the MNI space, before
# they are brought into the subject's individual native space.
# For speed, we only segment two bundles here.

bundles = api.BundleDict(
    ["CST", "ARC"],
    resample_to=MNI_T2_img,
)

##########################################################################
Beispiel #12
0
                             out_dir='.')
else:
    dti_params = {'FA': './dti_FA.nii.gz', 'params': './dti_params.nii.gz'}

FA_img = nib.load(dti_params['FA'])
FA_data = FA_img.get_fdata()

print("Registering to template...")
MNI_T2_img = afd.read_mni_template()
if op.exists('mapping.nii.gz'):
    # Reuse the mapping cached by a previous run.
    mapping = reg.read_mapping('./mapping.nii.gz', img, MNI_T2_img)
else:
    import dipy.core.gradients as dpg
    gtab = dpg.gradient_table(hardi_fbval, hardi_fbvec)
    warped_hardi, mapping = reg.syn_register_dwi(
        hardi_fdata, gtab, template=MNI_T2_img)
    reg.write_mapping(mapping, './mapping.nii.gz')

bundle_names = ["CST", "UF", "CC_ForcepsMajor", "CC_ForcepsMinor"]
bundles = api.make_bundle_dict(bundle_names=bundle_names, seg_algo="reco")

print("Tracking...")
if not op.exists('dti_streamlines_reco.trk'):
    seed_roi = np.zeros(img.shape[:-1])
    for bundle in bundles:
        if bundle != 'whole_brain':
            sl_xform = dts.Streamlines(
                dtu.transform_tracking_output(bundles[bundle]['sl'],
                                              MNI_T2_img.affine))
Beispiel #13
0
    dti_params = dti.fit_dti(fdata, fbval, fbvec,
                             out_dir='.', mask=brain_mask)
else:
    dti_params = {'FA': './dti_FA.nii.gz',
                  'MD': './dti_MD.nii.gz',
                  'RD': './dti_RD.nii.gz',
                  'AD': './dti_AD.nii.gz',
                  'params': './dti_params.nii.gz'}

print("Registering to template...")
MNI_T2_img = dpd.read_mni_template()
if op.exists('mapping.nii.gz'):
    # Reuse the mapping cached by a previous run.
    mapping = reg.read_mapping('./mapping.nii.gz', img, MNI_T2_img)
else:
    import dipy.core.gradients as dpg
    gtab = dpg.gradient_table(fbval, fbvec)
    # No explicit template is passed to syn_register_dwi here.
    mapping = reg.syn_register_dwi(fdata, gtab)
    reg.write_mapping(mapping, './mapping.nii.gz')

print("Tracking...")
if not op.exists('dti_streamlines.trk'):
    FA = nib.load(dti_params["FA"]).get_data()
    wm_mask = np.zeros_like(FA)
    wm_mask[FA > 0.2] = 1
    step_size = 1
    min_length_mm = 50
    streamlines = dts.Streamlines(
        aft.track(dti_params['params'],
                  directions="det",
                  seed_mask=wm_mask,
                  seeds=2,
Beispiel #14
0
def main():
    """
    Segment CST and ILF fiber groups from a tractogram and save them as .tck.

    Steps:

    * Read ``data_file``, ``data_bval``, ``data_bvec`` and ``tck_data`` from
      ``./config.json``.
    * Register the DWI to the MNI T2 template, reusing ``./mapping.nii.gz``
      when it exists.
    * Segment left/right CST and ILF with ``seg.segment``.
    * Write one ``<bundle>.tck`` file per fiber group into ``./tract1/``.
    """
    # Local import: the MNI template is required below (reading a cached
    # mapping and segmentation) and must be defined in this scope.
    import dipy.data as dpd

    with open('config.json') as config_json:
        config = json.load(config_json)

    # Paths to data
    data_file = str(config['data_file'])
    data_bval = str(config['data_bval'])
    data_bvec = str(config['data_bvec'])

    img = nib.load(data_file)

    tg = nib.streamlines.load(config['tck_data']).tractogram
    # Map the streamlines into img's voxel coordinates (inverse affine).
    streamlines = tg.apply_affine(np.linalg.inv(img.affine)).streamlines

    templates = afd.read_templates()
    bundle_names = ["CST", "ILF"]

    # Each bundle is defined by two distinct waypoint ROIs: roi1 and roi2.
    bundles = {
        name + hemi: {
            'ROIs': [templates[name + '_roi1' + hemi],
                     templates[name + '_roi2' + hemi]],
            'rules': [True, True],
        }
        for name in bundle_names
        for hemi in ['_R', '_L']
    }

    print("Registering to template...")
    # Needed on both branches below and by seg.segment, so load it up front.
    MNI_T2_img = dpd.read_mni_template()
    if not os.path.exists('mapping.nii.gz'):
        gtab = gradient_table(data_bval, data_bvec)
        mapping = reg.syn_register_dwi(data_file, gtab)
        reg.write_mapping(mapping, './mapping.nii.gz')
    else:
        mapping = reg.read_mapping('./mapping.nii.gz', img, MNI_T2_img)

    print("Segmenting fiber groups...")
    fiber_groups = seg.segment(data_file,
                               data_bval,
                               data_bvec,
                               streamlines,
                               bundles,
                               reg_template=MNI_T2_img,
                               mapping=mapping,
                               as_generator=False,
                               affine=img.affine)

    out_dir = os.path.join(os.getcwd(), 'tract1')
    os.makedirs(out_dir, exist_ok=True)

    for fg in fiber_groups:
        fg_streamlines = fiber_groups[fg]
        trg = nib.streamlines.Tractogram(fg_streamlines,
                                         affine_to_rasmm=img.affine)
        nib.streamlines.save(trg, os.path.join(out_dir, fg + ".tck"))