Example #1
def bundle_extract(atlas_track_path, atlas_bundle_path, target_track_path):

    time0 = time.time()

    atlas_file = atlas_track_path
    target_file = target_track_path

    print('loading data; elapsed time:', time.time() - time0)

    sft_atlas = load_trk(atlas_file, "same", bbox_valid_check=False)
    atlas = sft_atlas.streamlines
    atlas_header = create_tractogram_header(atlas_file,
                                            *sft_atlas.space_attributes)

    sft_target = load_trk(target_file, "same", bbox_valid_check=False)
    target = sft_target.streamlines
    target_header = create_tractogram_header(target_file,
                                             *sft_target.space_attributes)

    moved, transform, qb_centroids1, qb_centroids2 = whole_brain_slr(
        atlas,
        target,
        x0='affine',
        verbose=True,
        progressive=True,
        rng=np.random.RandomState(1984))

    bundle_track = StatefulTractogram(moved, target_header, Space.RASMM)
    save_trk(bundle_track, 'moved.trk', bbox_valid_check=False)

    np.save("slr_transform.npy", transform)

    model_bundle_file = atlas_bundle_path
    model_bundle = load_trk(model_bundle_file, "same", bbox_valid_check=False)
    model_bundle = model_bundle.streamlines

    print('starting bundle recognition; elapsed time:', time.time() - time0)

    rb = RecoBundles(moved, verbose=True, rng=np.random.RandomState(2001))

    recognized_bundle, bundle_labels = rb.recognize(model_bundle=model_bundle,
                                                    model_clust_thr=0,
                                                    reduction_thr=20,
                                                    reduction_distance='mam',
                                                    slr=True,
                                                    slr_metric='asymmetric',
                                                    pruning_distance='mam')

    bundle_track = StatefulTractogram(target[bundle_labels], target_header,
                                      Space.RASMM)
    return bundle_track
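
A minimal usage sketch (file names are hypothetical; assumes the dipy imports shown in Example #7 plus numpy and time):

extracted = bundle_extract('atlas.trk', 'atlas_AF_L.trk', 'whole_brain.trk')
save_trk(extracted, 'extracted_AF_L.trk', bbox_valid_check=False)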
Example #2
def save_roisubset(trkfile, roislist, roisexcel, labelmask):
    # loads a trk file, a list of ROI groups, the full structure => label
    # correspondence table and the label mask, then saves the tracts traversing each region

    trkdata = load_trk(trkfile, 'same')
    trkdata.to_vox()
    if hasattr(trkdata, 'space_attribute'):
        header = trkdata.space_attribute
    elif hasattr(trkdata, 'space_attributes'):
        header = trkdata.space_attributes
    trkstreamlines = trkdata.streamlines
    import pandas as pd
    df = pd.read_excel(roisexcel, sheet_name='Sheet1')
    df['Structure'] = df['Structure'].str.lower()

    for rois in roislist:

        labelslist = []
        for roi in rois:
            rslt_df = df.loc[df['Structure'] == roi.lower()]
            if rois[0].lower() == "wholebrain" or rois[0].lower() == "brain":
                labelslist = None
            else:
                labelslist = np.concatenate((labelslist, np.array(rslt_df.index2)))
        print(labelslist)
        if isempty(labelslist) and roi.lower() != "wholebrain" and roi.lower() != "brain":
            txt = "Warning: Unrecognized roi, will take whole brain as ROI. The roi specified was: " + roi
            print(txt)

        if isempty(labelslist):
            if labelmask is None:
                # fdwi_data is presumably a module-level 4D dwi volume in the original code
                roimask = (fdwi_data[:, :, :, 0] > 0)
            else:
                roimask = np.where(labelmask == 0, False, True)
        else:
            if labelmask is None:
                raise ("File not found error: labels requested but labels file could not be found at "+dwipath+ " for subject " + subject)
            roimask = np.zeros(np.shape(labelmask),dtype=int)
            for label in labelslist:
                roimask = roimask + (labelmask == label)

        trkroipath = trkfile.replace(".trk", "_" + "_".join(rois) + ".trk")
        if not os.path.exists(trkroipath):
            affinetemp = np.eye(4)
            # 'target' is presumably this codebase's streamline filter; note that
            # dipy's dipy.tracking.utils.target has no 'strict' keyword
            trkroistreamlines = target(trkstreamlines, affinetemp, roimask, include=True, strict="longstring")
            trkroistreamlines = Streamlines(trkroistreamlines)
            myheader = create_tractogram_header(trkroipath, *header)
            roi_sl = lambda: (s for s in trkroistreamlines)
            tract_save.save_trk_heavy_duty(trkroipath, streamlines=roi_sl,
                                           affine=header[0], header=myheader)
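
The function above leans on an isempty helper defined elsewhere in the codebase; a plausible minimal version (an assumption, not the original) is:

import numpy as np

def isempty(x):
    # treat None and zero-length sequences/arrays as empty
    return x is None or np.size(x) == 0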
Example #3
def make_atlas():
    print('Making atlas...')
    trk_list = []
    sft = None
    for t in tracts:
        print('\r' + t + '      ', end='')
        sft = load_trk(tractseg_dir + training_ids[1] + '/tracts/' + t +
                       '.trk',
                       'same',
                       bbox_valid_check=False)
        trk_list.append(sft.streamlines)
    out_file = training_paths[1] + 'T1w/Diffusion/atlas.trk'
    target_header = create_tractogram_header(out_file, *sft.space_attributes)
    sft_rec = StatefulTractogram(
        nib.streamlines.array_sequence.concatenate(trk_list, 0), target_header,
        Space.RASMM)
    save_trk(sft_rec, out_file, bbox_valid_check=False)
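
make_atlas() reads module-level globals; a plausible setup (all values hypothetical) would be:

tractseg_dir = '/data/tractseg/'                             # per-subject TractSeg outputs
training_ids = ['100307', '599469']                          # subject IDs (index 1 is used above)
training_paths = ['/data/hcp/100307/', '/data/hcp/599469/']  # per-subject HCP folders
tracts = ['AF_left', 'AF_right', 'CST_left', 'CST_right']    # bundle names to merge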
Example #4
def save_trk_header(filepath,
                    streamlines,
                    header,
                    affine=np.eye(4),
                    fix_streamlines=False,
                    verbose=False):

    myheader = create_tractogram_header(filepath, *header)
    trk_sl = lambda: (s for s in streamlines)
    if verbose:
        print(f'Saving streamlines to {filepath}')
        time1 = time.perf_counter()
    save_trk_heavy_duty(filepath,
                        streamlines=trk_sl,
                        affine=affine,
                        header=myheader,
                        fix_streamlines=fix_streamlines,
                        return_tractogram=False)
    if verbose:
        time2 = time.perf_counter()
        print(f'Saved in {time2 - time1:0.4f} seconds')
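
A usage sketch, assuming the header tuple comes from dipy's get_reference_info and 'streamlines' is any in-memory sequence of streamline arrays (file names hypothetical):

from dipy.io.utils import get_reference_info

header = get_reference_info('reference.nii.gz')  # (affine, dimensions, voxel_sizes, voxel_order)
save_trk_header('out.trk', streamlines, header, verbose=True)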
Example #5
def run_recobundles(input_folder, atlas_file, ground_truth_folder,
                    output_folder):
    print('Running Recobundles in ' + input_folder)

    # make a folder to save output
    try:
        Path(output_folder).mkdir(parents=True, exist_ok=True)
    except OSError:
        print('Could not create output dir. Aborting...')
        return

    # Uncomment for first exemplary use
    # target_file, target_folder = fetch_target_tractogram_hcp()
    # atlas_file, atlas_folder = fetch_bundle_atlas_hcp842()

    # target_file = get_target_tractogram_hcp()

    target_file = input_folder + 'whole_brain.trk'

    # use this line to select tracts if necessary
    sel_tracts = tracts
    # sel_bundle_paths = [data_path + 'tractseg/599469/tracts/AF_left.trk']
    # print(sel_bundle_paths)
    sft_atlas = load_trk(atlas_file, 'same', bbox_valid_check=False)
    atlas = sft_atlas.streamlines

    sft_target = load_trk(target_file, 'same', bbox_valid_check=False)
    target = sft_target.streamlines
    target_header = create_tractogram_header(target_file,
                                             *sft_target.space_attributes)

    # whole_brain_slr(static, moving, ...) returns the moving set registered to
    # static; here the atlas is moved into the target's space
    moved, transform, qb_centroids1, qb_centroids2 = whole_brain_slr(
        target, atlas, x0='affine', progressive=True)
    print(transform)
    sft_rec = StatefulTractogram(
        nib.streamlines.array_sequence.concatenate([moved, atlas], 0),
        target_header, Space.RASMM)
    save_trk(sft_rec, output_folder + 'test.trk', bbox_valid_check=False)
Example #6
def reducetractnumber(oldtrkfile, newtrkfilepath, getdata=False, ratio=10, return_affine=False, verbose=False):

    if verbose:
        print("Beginning to read " + oldtrkfile)
    trkdata = load_trk(oldtrkfile, "same")
    if verbose:
        print("loaded the file " + oldtrkfile)
    trkdata.to_vox()
    if hasattr(trkdata, 'space_attribute'):
        header = trkdata.space_attribute
    elif hasattr(trkdata, 'space_attributes'):
        header = trkdata.space_attributes
    affine = trkdata.affine  # public property on StatefulTractogram
    trkstreamlines = trkdata.streamlines

    ministream = []
    for idx, stream in enumerate(trkstreamlines):
        if (idx % ratio) == 0:
            ministream.append(stream)
    del trkstreamlines
    myheader = create_tractogram_header(newtrkfilepath, *header)
    ratioed_sl = lambda: (s for s in ministream)
    save_trk_heavy_duty(newtrkfilepath, streamlines=ratioed_sl,
                        affine=affine, header=myheader)
    if verbose:
        print("The file " + oldtrkfile + " was reduced to one "+str(ratio)+"th of its size and saved to "+newtrkfilepath)

    if getdata:
        if return_affine:
            return ministream, affine
        else:
            return ministream
    else:
        if return_affine:
            return affine
        else:
            return
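
A usage sketch (file names hypothetical):

# keep every 10th streamline and also get the reduced set back in memory
ministream, affine = reducetractnumber('whole_brain.trk', 'whole_brain_small.trk',
                                       getdata=True, ratio=10,
                                       return_affine=True, verbose=True)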
Example #7
from dipy.data.fetcher import (fetch_target_tractogram_hcp,
                               fetch_bundle_atlas_hcp842,
                               get_bundle_atlas_hcp842,
                               get_target_tractogram_hcp)
from dipy.io.stateful_tractogram import Space, StatefulTractogram
from dipy.io.streamline import load_trk, save_trk
from dipy.io.utils import create_tractogram_header
from dipy.viz import window, actor
"""
Download and read data for this tutorial
"""

target_file, target_folder = fetch_target_tractogram_hcp()
atlas_file, atlas_folder = fetch_bundle_atlas_hcp842()

atlas_file, all_bundles_files = get_bundle_atlas_hcp842()
target_file = get_target_tractogram_hcp()

sft_atlas = load_trk(atlas_file, "same", bbox_valid_check=False)
atlas = sft_atlas.streamlines
atlas_header = create_tractogram_header(atlas_file, *sft_atlas.space_attributes)

sft_target = load_trk(target_file, "same", bbox_valid_check=False)
target = sft_target.streamlines
target_header = create_tractogram_header(target_file,
                                         *sft_target.space_attributes)
"""
let's visualize atlas tractogram and target tractogram before registration
"""

interactive = False

ren = window.Renderer()
ren.SetBackground(1, 1, 1)
ren.add(actor.line(atlas, colors=(1, 0, 1)))
ren.add(actor.line(target, colors=(1, 1, 0)))
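
The snippet stops once the scene is assembled; the corresponding dipy tutorial then records and optionally shows it, roughly:

window.record(ren, out_path='tractograms_initial.png', size=(600, 600))
if interactive:
    window.show(ren)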
Example #8
def QCSA_tractmake(data, affine, vox_size, gtab, mask, masktype, header, step_size, peak_processes, outpathtrk, subject='NA',
                   ratio=1, overwrite=False, get_params=False, doprune=False, figspath=None, verbose=None):
    # Compute odfs in Brain Mask
    t2 = time()
    if os.path.isfile(outpathtrk) and not overwrite:
        txt = "Subject already saved at "+outpathtrk
        print(txt)
        streamlines_generator = None
        params = None
        return outpathtrk, streamlines_generator, params

    csa_model = CsaOdfModel(gtab, 6)
    if peak_processes == 1:
        parallel = False
    else:
        parallel = True
    if verbose:
        send_mail("Starting calculation of Constant solid angle model for subject " + subject,subject="CSA model start")

    wholemask = np.where(mask == 0, False, True)
    print(f"There are {peak_processes} and {parallel} here")
    csa_peaks = peaks_from_model(model=csa_model,
                                 data=data,
                                 sphere=peaks.default_sphere,  # issue with complete sphere
                                 mask=wholemask,
                                 relative_peak_threshold=.5,
                                 min_separation_angle=25,
                                 parallel=parallel,
                                 nbr_processes=peak_processes)

    duration = time() - t2
    if verbose:
        print(subject + ' CSA duration %.3f' % (duration,))

    t3 = time()


    if verbose:
        send_mail('Computing classifier for local tracking for subject ' + subject +
                  '; it has been ' + str(round(duration)) + ' seconds since the start of tractmaker',
                  subject="Seed computation")

        print('Computing classifier for local tracking for subject ' + subject)

    if masktype == "FA":
        #tensor_model = dti.TensorModel(gtab)
        #tenfit = tensor_model.fit(data, mask=labels > 0)
        #FA = fractional_anisotropy(tenfit.evals)
        FA_threshold = 0.05
        classifier = ThresholdStoppingCriterion(mask, FA_threshold)

        if figspath is not None:
            fig = plt.figure()
            mask_fa = mask.copy()
            mask_fa[mask_fa < FA_threshold] = 0
            plt.xticks([])
            plt.yticks([])
            plt.imshow(mask_fa[:, :, data.shape[2] // 2].T, cmap='gray', origin='lower',
                       interpolation='nearest')
            fig.tight_layout()
            fig.savefig(figspath + 'threshold_fa.png')
    else:
        classifier = BinaryStoppingCriterion(wholemask)

    # generates about 2 seeds per voxel
    # seeds = utils.random_seeds_from_mask(fa > .2, seeds_count=2,
    #                                      affine=np.eye(4))

    # generates about 2 million streamlines
    # seeds = utils.seeds_from_mask(fa > .2, density=1,
    #                              affine=np.eye(4))

    if verbose:
        print('Computing seeds')
    seeds = utils.seeds_from_mask(wholemask, density=1,
                                  affine=np.eye(4))

    #streamlines_generator = local_tracking.local_tracker(csa_peaks,classifier,seeds,affine=np.eye(4),step_size=step_size)
    if verbose:
        print('Computing the local tracking')
        duration = time() - t2
        send_mail('Start of the local tracking; it has been ' + str(round(duration)) +
                  ' seconds since the start of tractmaker', subject="Seed computation")

    #stepsize = 2 #(by default)
    stringstep = str(step_size)
    stringstep = stringstep.replace(".", "_")
    if verbose:
        print("stringstep is "+stringstep)

    streamlines_generator = LocalTracking(csa_peaks, classifier,
                                          seeds, affine=np.eye(4), step_size=step_size)

    if verbose:
        duration = time() - t2
        txt = 'About to save streamlines at ' + outpathtrk + '; it has been ' + \
              str(round(duration)) + ' seconds since the start of tractmaker'
        send_mail(txt, subject="Tract saving")

    cutoff = 2
    if doprune:
        streamlines_generator = prune_streamlines(list(streamlines_generator), data[:, :, :, 0], cutoff=cutoff,
                                                  verbose=verbose)
    # identical saving path whether pruned or not: keep every ratio-th streamline
    myheader = create_tractogram_header(outpathtrk, *header)
    sg = lambda: (s for i, s in enumerate(streamlines_generator) if i % ratio == 0)
    save_trk_heavy_duty(outpathtrk, streamlines=sg,
                        affine=affine, header=myheader,
                        shape=mask.shape, vox_size=vox_size)
    if verbose:
        duration = time() - t2
        txt = "Tract files were saved at "+outpathtrk + ',it has been ' + str(round(duration)) + \
              'seconds since the start of tractmaker'
        print(txt)
        send_mail(txt,subject="Tract saving" )

    # save everything - will generate a 20+ GBytes of data - hard to manipulate

    # possibly add parameter in csv file or other to decide whether to save large tractogram file
    # outpathfile=outpath+subject+"bmCSA_detr"+stringstep+".trk"
    # myheader=create_tractogram_header(outpathfile,*get_reference_info(fdwi))
    duration3 = time() - t2
    if verbose:
        print(duration3)
        print(subject + ' Tracking duration %.3f' % (duration3,))
        send_mail("Finished file save at "+outpathtrk+" with tracking duration of " + str(duration3) + "seconds",
                  subject="file save update" )

    if get_params:
        numtracts, minlength, maxlength, meanlength, stdlength = get_trk_params(streamlines_generator, verbose)
        params = [numtracts, minlength, maxlength, meanlength, stdlength]
        if verbose:
            print("For subject " + str(subject) + " the number of tracts is " + str(numtracts) + ", the minimum length is " +
                  str(minlength) + ", the maximum length is " + str(maxlength) + ", the mean length is " +
                  str(meanlength) + ", the std is " + str(stdlength))
    else:
        params = None

    return outpathtrk, streamlines_generator, params
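
A hypothetical call, assuming the dwi volume, gradient table, masks and header were loaded beforehand (all names are placeholders):

outpathtrk, sl_gen, params = QCSA_tractmake(data, affine, vox_size, gtab, mask, 'binary',
                                            header, step_size=2, peak_processes=4,
                                            outpathtrk='subj01_bmCSA_detr2.trk',
                                            subject='subj01', verbose=True, get_params=True)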
Example #9
def main():
    parser = _build_arg_parser()
    args = parser.parse_args()

    if args.verbose:
        logging.basicConfig(level=logging.DEBUG)

    assert_inputs_exist(
        parser,
        [args.in_sh, args.in_seed, args.in_map_include, args.map_exclude_file])
    assert_outputs_exist(parser, args, args.out_tractogram)

    if not nib.streamlines.is_supported(args.out_tractogram):
        parser.error('Invalid output streamline file format (must be trk or ' +
                     'tck): {0}'.format(args.out_tractogram))

    if not args.min_length > 0:
        parser.error('minL must be > 0, {}mm was provided.'.format(
            args.min_length))
    if args.max_length < args.min_length:
        parser.error(
            'maxL must be greater than minL (minL={}mm, maxL={}mm).'.format(
                args.min_length, args.max_length))

    if args.compress:
        if args.compress < 0.001 or args.compress > 1:
            logging.warning(
                'You are using an error rate of {}.\nWe recommend setting it '
                'between 0.001 and 1.\n0.001 will do almost nothing to the '
                'tracts while 1 will highly compress/linearize the tracts'.
                format(args.compress))

    if args.particles <= 0:
        parser.error('--particles must be >= 1.')

    if args.back_tracking <= 0:
        parser.error('PFT backtracking distance must be > 0.')

    if args.forward_tracking <= 0:
        parser.error('PFT forward tracking distance must be > 0.')

    if args.npv and args.npv <= 0:
        parser.error('Number of seeds per voxel must be > 0.')

    if args.nt and args.nt <= 0:
        parser.error('Total number of seeds must be > 0.')

    fodf_sh_img = nib.load(args.in_sh)
    if not np.allclose(np.mean(fodf_sh_img.header.get_zooms()[:3]),
                       fodf_sh_img.header.get_zooms()[0],
                       atol=1.e-3):
        parser.error(
            'SH file is not isotropic. Tracking cannot be run robustly.')

    tracking_sphere = HemiSphere.from_sphere(get_sphere('repulsion724'))

    # Check if sphere is unit, since we couldn't find such check in Dipy.
    if not np.allclose(np.linalg.norm(tracking_sphere.vertices, axis=1), 1.):
        raise RuntimeError('Tracking sphere should be unit normed.')

    sh_basis = args.sh_basis

    if args.algo == 'det':
        dgklass = DeterministicMaximumDirectionGetter
    else:
        dgklass = ProbabilisticDirectionGetter

    theta = get_theta(args.theta, args.algo)

    # Reminder for the future:
    # pmf_threshold == clip pmf under this
    # relative_peak_threshold is for initial directions filtering
    # min_separation_angle is the initial separation angle for peak extraction
    dg = dgklass.from_shcoeff(fodf_sh_img.get_fdata(dtype=np.float32),
                              max_angle=theta,
                              sphere=tracking_sphere,
                              basis_type=sh_basis,
                              pmf_threshold=args.sf_threshold,
                              relative_peak_threshold=args.sf_threshold_init)

    map_include_img = nib.load(args.in_map_include)
    map_exclude_img = nib.load(args.map_exclude_file)
    voxel_size = np.average(map_include_img.header['pixdim'][1:4])

    if not args.act:
        tissue_classifier = CmcStoppingCriterion(
            map_include_img.get_fdata(dtype=np.float32),
            map_exclude_img.get_fdata(dtype=np.float32),
            step_size=args.step_size,
            average_voxel_size=voxel_size)
    else:
        tissue_classifier = ActStoppingCriterion(
            map_include_img.get_fdata(dtype=np.float32),
            map_exclude_img.get_fdata(dtype=np.float32))

    if args.npv:
        nb_seeds = args.npv
        seed_per_vox = True
    elif args.nt:
        nb_seeds = args.nt
        seed_per_vox = False
    else:
        nb_seeds = 1
        seed_per_vox = True

    voxel_size = fodf_sh_img.header.get_zooms()[0]
    vox_step_size = args.step_size / voxel_size
    seed_img = nib.load(args.in_seed)
    seeds = track_utils.random_seeds_from_mask(
        get_data_as_mask(seed_img, dtype=bool),  # np.bool is deprecated/removed in recent numpy
        np.eye(4),
        seeds_count=nb_seeds,
        seed_count_per_voxel=seed_per_vox,
        random_seed=args.seed)

    # Note that max steps is used once for the forward pass, and
    # once for the backwards. This doesn't, in fact, control the real
    # max length
    max_steps = int(args.max_length / args.step_size) + 1
    pft_streamlines = ParticleFilteringTracking(
        dg,
        tissue_classifier,
        seeds,
        np.eye(4),
        max_cross=1,
        step_size=vox_step_size,
        maxlen=max_steps,
        pft_back_tracking_dist=args.back_tracking,
        pft_front_tracking_dist=args.forward_tracking,
        particle_count=args.particles,
        return_all=args.keep_all,
        random_seed=args.seed,
        save_seeds=args.save_seeds)

    scaled_min_length = args.min_length / voxel_size
    scaled_max_length = args.max_length / voxel_size

    if args.save_seeds:
        filtered_streamlines, seeds = \
            zip(*((s, p) for s, p in pft_streamlines
                  if scaled_min_length <= length(s) <= scaled_max_length))
        data_per_streamlines = {'seeds': lambda: seeds}
    else:
        filtered_streamlines = \
            (s for s in pft_streamlines
             if scaled_min_length <= length(s) <= scaled_max_length)
        data_per_streamlines = {}

    if args.compress:
        filtered_streamlines = (compress_streamlines(s, args.compress)
                                for s in filtered_streamlines)

    tractogram = LazyTractogram(lambda: filtered_streamlines,
                                data_per_streamlines,
                                affine_to_rasmm=seed_img.affine)

    filetype = nib.streamlines.detect_format(args.out_tractogram)
    reference = get_reference_info(seed_img)
    header = create_tractogram_header(filetype, *reference)

    # Use generator to save the streamlines on-the-fly
    nib.streamlines.save(tractogram, args.out_tractogram, header=header)
Example #10
            ROI_streamlines = [int(k) for k in select_streamlines]
            ROI_streamlines_all.append(ROI_streamlines)
            ROI_name = [ROI_list[ROI_tuple[0]], ROI_list[ROI_tuple[1]]]
            ROI_names.append(ROI_name)

            if ROI_streamlines:
                trk_ROI_streamlines = []
                for ROI_stream in ROI_streamlines:
                    trk_ROI_streamlines.append(trkstreamlines[ROI_stream])
                # trk_ROI_streamlines = trkstreamlines[ROI_streamlines]
                pathfile_name = os.path.join(
                    trk_outpath_subj, subj + str_identifier + "_" +
                    ROI_name[0][0:11] + "_" + ROI_name[1][0:11] + '.trk')
                ROI_sl = lambda: (s for s in trk_ROI_streamlines)
                myheader = create_tractogram_header(pathfile_name, *header)
                if trksave and not os.path.isfile(pathfile_name):
                    save_trk_heavy_duty(pathfile_name,
                                        streamlines=ROI_sl,
                                        affine=affine,
                                        header=myheader)
                affine_streams = np.eye(4)
                trkaffine = np.eye(4)
                #favals, mdvals, numtracts, minlength, maxlength, meanlength, stdlength = get_connectome_attributes(trk_ROI_streamlines, affine=trkaffine, fa=fa, md=None, verbose=True)

                metric = SumPointwiseEuclideanMetric(
                    feature=ArcLengthFeature())
                qb = QuickBundles(threshold=2., metric=metric)
                clusters = qb.cluster(trk_ROI_streamlines)

                if view:
Example #11
def save_roisubset(streamlines, roislist, roisexcel, labelmask, stringstep, ratios, trkpath, subject, affine, header):
    
    #atlas_legends = BIGGUS_DISKUS + "/atlases/CHASSSYMM3AtlasLegends.xlsx"
    
    df = pd.read_excel(roisexcel, sheet_name='Sheet1')
    df['Structure'] = df['Structure'].str.lower()    
    
    for rois in roislist:
        if len(rois) == 1:
            roiname = "_" + rois[0] + "_"
        elif len(rois) > 1:
            roiname = "_"
            for roi in rois:
                roiname = roiname + roi[0:4]
            roiname = roiname + "_"
            
        labelslist = []  # fimbria

        for roi in rois:
            rslt_df = df.loc[df['Structure'] == roi.lower()]
            if roi.lower() == "wholebrain" or roi.lower() == "brain":
                labelslist=None
            else:
                labelslist=np.concatenate((labelslist,np.array(rslt_df.index2)))

        if isempty(labelslist) and roi.lower() != "wholebrain" and roi.lower() != "brain":
            txt = "Warning: Unrecognized roi, will take whole brain as ROI. The roi specified was: " + roi
            print(txt)

        # bvec_orient = [-2, 1, 3]
    
        if isempty(labelslist):
            roimask = np.where(labelmask == 0, False, True)
        else:
            if labelmask is None:
                raise ("Bad label data, could not define ROI for streams")
            roimask = np.zeros(np.shape(labelmask),dtype=int)
            for label in labelslist:
                roimask = roimask + (labelmask == label)
        
        if not isempty(labelslist):
            trkroipath = trkpath + '/' + subject + roiname + "_stepsize_" + stringstep + '.trk'
            if not os.path.exists(trkroipath):
                affinetemp = np.eye(4)
                trkstreamlines = target(streamlines, affinetemp, roimask, include=True, strict="longstring")
                trkstreamlines = Streamlines(trkstreamlines)
                myheader = create_tractogram_header(trkroipath, *header)
                roi_sl = lambda: (s for s in trkstreamlines)
                save_trk_heavy_duty(trkroipath, streamlines=roi_sl,
                                    affine=affine, header=myheader)
            else:
                trkdata = load_trk(trkroipath, 'same')
                trkdata.to_vox()
                if hasattr(trkdata, 'space_attribute'):
                    header = trkdata.space_attribute
                elif hasattr(trkdata, 'space_attributes'):
                    header = trkdata.space_attributes
                trkstreamlines = trkdata.streamlines
                
        for ratio in ratios:
            if ratio != 1:
                trkroiminipath = trkpath + '/' + subject + '_ratio_' + str(ratio) + roiname + "_stepsize_" + stringstep + '.trk'
                if not os.path.exists(trkroiminipath):
                    ministream = []
                    for idx, stream in enumerate(trkstreamlines):
                        if (idx % ratio) == 0:
                            ministream.append(stream)
                    trkstreamlines = ministream
                    myheader = create_tractogram_header(trkroiminipath, *header)
                    ratioed_roi_sl_gen = lambda: (s for s in trkstreamlines)
                    if allsave:  # allsave is presumably a module-level flag in the original code
                        save_trk_heavy_duty(trkroiminipath, streamlines=ratioed_roi_sl_gen,
                                            affine=affine, header=myheader)
                else:
                    trkdata = load_trk(trkroiminipath, 'same')
                    trkdata.to_vox()
                    if hasattr(trkdata, 'space_attribute'):
                        header = trkdata.space_attribute
                    elif hasattr(trkdata, 'space_attributes'):
                        header = trkdata.space_attributes
                    trkstreamlines = trkdata.streamlines
Example #12
def main():
    parser = _build_arg_parser()
    args = parser.parse_args()

    if args.verbose:
        logging.basicConfig(level=logging.DEBUG)

    assert_inputs_exist(parser, [args.sh_file, args.seed_file, args.mask_file])
    assert_outputs_exist(parser, args, args.output_file)

    if not nib.streamlines.is_supported(args.output_file):
        parser.error('Invalid output streamline file format (must be trk or ' +
                     'tck): {0}'.format(args.output_file))

    if not args.min_length > 0:
        parser.error('minL must be > 0, {}mm was provided.'.format(
            args.min_length))
    if args.max_length < args.min_length:
        parser.error(
            'maxL must be greater than minL (minL={}mm, maxL={}mm).'.format(
                args.min_length, args.max_length))

    if args.compress:
        if args.compress < 0.001 or args.compress > 1:
            logging.warning(
                'You are using an error rate of {}.\nWe recommend setting it '
                'between 0.001 and 1.\n0.001 will do almost nothing to the '
                'tracts while 1 will highly compress/linearize the tracts'.
                format(args.compress))

    if args.npv and args.npv <= 0:
        parser.error('Number of seeds per voxel must be > 0.')

    if args.nt and args.nt <= 0:
        parser.error('Total number of seeds must be > 0.')

    mask_img = nib.load(args.mask_file)
    mask_data = mask_img.get_fdata()

    # Make sure the mask is isotropic. Else, the strategy used
    # when providing information to dipy (i.e. working as if in voxel space)
    # will not yield correct results.
    fodf_sh_img = nib.load(args.sh_file)
    if not np.allclose(np.mean(fodf_sh_img.header.get_zooms()[:3]),
                       fodf_sh_img.header.get_zooms()[0],
                       atol=1.e-3):
        parser.error(
            'SH file is not isotropic. Tracking cannot be run robustly.')

    if args.npv:
        nb_seeds = args.npv
        seed_per_vox = True
    elif args.nt:
        nb_seeds = args.nt
        seed_per_vox = False
    else:
        nb_seeds = 1
        seed_per_vox = True

    voxel_size = fodf_sh_img.header.get_zooms()[0]
    vox_step_size = args.step_size / voxel_size
    seed_img = nib.load(args.seed_file)
    seeds = track_utils.random_seeds_from_mask(
        seed_img.get_fdata(),
        np.eye(4),
        seeds_count=nb_seeds,
        seed_count_per_voxel=seed_per_vox,
        random_seed=args.seed)

    # Tracking is performed in voxel space
    max_steps = int(args.max_length / args.step_size) + 1
    streamlines = LocalTracking(_get_direction_getter(args, mask_data),
                                BinaryStoppingCriterion(mask_data),
                                seeds,
                                np.eye(4),
                                step_size=vox_step_size,
                                max_cross=1,
                                maxlen=max_steps,
                                fixedstep=True,
                                return_all=True,
                                random_seed=args.seed,
                                save_seeds=args.save_seeds)

    scaled_min_length = args.min_length / voxel_size
    scaled_max_length = args.max_length / voxel_size

    if args.save_seeds:
        filtered_streamlines, seeds = \
            zip(*((s, p) for s, p in streamlines
                  if scaled_min_length <= length(s) <= scaled_max_length))
        data_per_streamlines = {'seeds': lambda: seeds}
    else:
        filtered_streamlines = \
            (s for s in streamlines
             if scaled_min_length <= length(s) <= scaled_max_length)
        data_per_streamlines = {}

    if args.compress:
        filtered_streamlines = (compress_streamlines(s, args.compress)
                                for s in filtered_streamlines)

    tractogram = LazyTractogram(lambda: filtered_streamlines,
                                data_per_streamlines,
                                affine_to_rasmm=seed_img.affine)

    filetype = nib.streamlines.detect_format(args.output_file)
    reference = get_reference_info(seed_img)
    header = create_tractogram_header(filetype, *reference)

    # Use generator to save the streamlines on-the-fly
    nib.streamlines.save(tractogram, args.output_file, header=header)
Example #13
def main():
    t_init = perf_counter()
    parser = _build_arg_parser()
    args = parser.parse_args()

    if args.verbose:
        logging.basicConfig(level=logging.INFO)

    assert_inputs_exist(parser, [args.in_odf, args.in_mask, args.in_seed])
    assert_outputs_exist(parser, args, args.out_tractogram)
    if args.compress is not None:
        verify_compression_th(args.compress)

    odf_sh_img = nib.load(args.in_odf)
    mask = get_data_as_mask(nib.load(args.in_mask))
    seed_mask = get_data_as_mask(nib.load(args.in_seed))
    odf_sh = odf_sh_img.get_fdata(dtype=np.float32)

    t0 = perf_counter()
    if args.npv:
        nb_seeds = args.npv
        seed_per_vox = True
    elif args.nt:
        nb_seeds = args.nt
        seed_per_vox = False
    else:
        nb_seeds = 1
        seed_per_vox = True

    # Seeds are returned with origin `center`.
    # However, GPUTracker expects origin to be `corner`.
    # Therefore, we need to shift the seed positions by half voxel.
    seeds = random_seeds_from_mask(seed_mask,
                                   np.eye(4),
                                   seeds_count=nb_seeds,
                                   seed_count_per_voxel=seed_per_vox,
                                   random_seed=args.rng_seed) + 0.5
    logging.info('Generated {0} seed positions in {1:.2f}s.'.format(
        len(seeds),
        perf_counter() - t0))

    voxel_size = odf_sh_img.header.get_zooms()[0]
    vox_step_size = args.step_size / voxel_size
    vox_max_length = args.max_length / voxel_size
    vox_min_length = args.min_length / voxel_size
    min_strl_len = int(vox_min_length / vox_step_size) + 1
    max_strl_len = int(vox_max_length / vox_step_size) + 1

    # initialize tracking
    tracker = GPUTacker(odf_sh, mask, seeds, vox_step_size, min_strl_len,
                        max_strl_len, args.theta, args.sh_basis,
                        args.batch_size, args.forward_only, args.rng_seed)

    # wrapper for tracker.track() yielding one TractogramItem per
    # streamline for use with the LazyTractogram.
    def tracks_generator_wrapper():
        for strl, seed in tracker.track():
            # seed must be saved in voxel space, with origin `center`.
            dps = {'seeds': seed - 0.5} if args.save_seeds else {}

            # TODO: Investigate why the streamline must NOT be shifted to
            # origin `corner` for LazyTractogram.
            strl *= voxel_size  # in mm.
            if args.compress:
                strl = compress_streamlines(strl, args.compress)
            yield TractogramItem(strl, dps, {})

    # instantiate tractogram
    tractogram = LazyTractogram.from_data_func(tracks_generator_wrapper)
    tractogram.affine_to_rasmm = odf_sh_img.affine

    filetype = nib.streamlines.detect_format(args.out_tractogram)
    reference = get_reference_info(odf_sh_img)
    header = create_tractogram_header(filetype, *reference)

    # Use generator to save the streamlines on-the-fly
    nib.streamlines.save(tractogram, args.out_tractogram, header=header)
    logging.info('Saved tractogram to {0}.'.format(args.out_tractogram))

    # Total runtime
    logging.info('Total runtime of {0:.2f}s.'.format(perf_counter() - t_init))
Example #14
def LiFEvaluation(dwidata,
                  trk_streamlines,
                  gtab,
                  subject="lifesubj",
                  header=None,
                  roimask=None,
                  affine=None,
                  display=True,
                  outpathpickle=None,
                  outpathtrk=None,
                  processes=1,
                  outpathfig=None,
                  strproperty="",
                  verbose=None):
    """     Implementation of Linear Fascicle Evaluation, outputs histograms, evals

    Parameters
    ----------
    dwidata : array
        output trk filename
    trkdata : array
    gtab : array og bval & bvec table
    outpath: string
    folder location for resulting values and figures
    display : boolean, optional
    Condition to display the results (default = False)
    savefig: boolean, optional
    Condition to save the results in outpath (default = True)

    Defined by Pestilli, F., Yeatman, J, Rokem, A. Kay, K. and Wandell B.A. (2014).
    Validation and statistical inference in living connectomes and recreated by Dipy

    :param dwidata: array of diffusion data
    :param trkdata: array of tractography data obtained from dwi
    :param gtab: bval & bvec table
    :param outpath: location to save analysis outputs
    :param display:
    :param savefig:
    :return:
    """
    """""" """
    if not op.exists('lr-superiorfrontal.trk'):
        
    else:
        # We'll need to know where the corpus callosum is from these variables:
        from dipy.data import (read_stanford_labels,
                               fetch_stanford_t1,
                               read_stanford_t1)
        hardi_img, gtab, labels_img = read_stanford_labels()
        labels = labels_img.get_data()
        cc_slice = labels == 2
        fetch_stanford_t1()
        t1 = read_stanford_t1()
        t1_data = t1.get_data()
        data = hardi_img.get_data()
    """ """"""
    # Read the candidates from file in voxel space:

    if roimask is None:
        roimask = dwidata > 0
    else:
        dwidataroi = dwidata * np.repeat(
            roimask[:, :, :, None], np.shape(dwidata)[3], axis=3)

    print("verbose: " + str(verbose) + " outpathpickle: " + str(outpathpickle))
    fiber_model = life.FiberModel(gtab)
    # inv_affine must be used if the streamlines are in world space: we must apply
    # the inverse affine of the dwi when comparing the diffusion directions (gtab)
    # with the voxels of the trk
    #inv_affine = np.linalg.inv(hardi_img.affine)

    # fiber_fit will fit the streamlines to the original diffusion data
    if verbose:
        txt = "Begin the evaluation over " + str(
            np.size(trk_streamlines)) + " streamlines"
        print(txt)
        send_mail(txt, subject="LifE start msg ")

    fiber_fit = fiber_model.fit(dwidata,
                                trk_streamlines,
                                affine=np.eye(4),
                                processes=processes,
                                verbose=verbose)
    #fiber_fit_roi = fiber_model.fit(dwidataroi, trk_streamlines, affine=np.eye(4), processes=processes, verbose=verbose)
    optimized_sl = list(
        np.array(trk_streamlines)[np.where(fiber_fit.beta > 0)[0]])
    plt.ioff()
    if verbose:
        txt = "End of the evaluation over " + str(np.size(trk_streamlines))
        print(txt)
        send_mail(txt, subject="LifE status msg ")
    if outpathtrk is not None:
        outpathfile = str(
            outpathtrk) + subject + strproperty + "_lifeopt_test.trk"
        myheader = create_tractogram_header(outpathfile, *header)
        optimized_sl_gen = lambda: (s for s in optimized_sl)
        save_trk_heavy_duty(outpathfile,
                            streamlines=optimized_sl_gen,
                            affine=affine,
                            header=myheader)
        txt = ("Saved final trk at " + outpathfile)
        print(txt)
        send_mail(txt, subject="LifE save msg ")
        """
        except TypeError:
            txt=('Could not save new tractogram file, header of original trk file not properly implemented into '
                  'LifEvaluation')
            print(txt)
            send_mail(txt,subject="LifE error msg ")
        """
    """
    if interactive:
        ren = window.Renderer()
        ren.add(actor.streamtube(optimized_sl, cmap.line_colors(optimized_sl)))
        ren.add(ROI_actor)
        #ren.add(vol_actor)
        if interactive:
            window.show(ren)      
        if outpathfig is not None:
            print("reached windowrecord")
            window.record(ren, n_frames=1, out_path=outpathfig +'_life_optimized.png',
                size=(800, 800))
            print("did window record")
    """
    maxsize_var = 20525023825  # ~20 GB pickle-size limit

    sizebeta = getsize(fiber_fit.beta)
    if sizebeta < maxsize_var:
        picklepath = outpathpickle + subject + strproperty + '_beta.p'
        txt = ("fiber_fit.beta saved at " + picklepath)
        pickle.dump(fiber_fit.beta, open(picklepath, "wb"))
        if verbose:
            print(txt)
            send_mail(txt, subject="LifE save msg ")
    else:
        txt = (
            "Object fiber_fit.beta exceeded the imposed 20GB limit with a size of: "
            + str(sizebeta / 1e9) + " GB")
        print(txt)
        send_mail(txt, subject="LifE error msg")

    sizecoords = getsize(fiber_fit.vox_coords)
    if sizecoords < maxsize_var:
        picklepath = outpathpickle + subject + strproperty + '_voxcoords.p'
        txt = ("fiber_fit.voxcoords saved at " + picklepath)
        pickle.dump(fiber_fit.vox_coords, open(picklepath, "wb"))
        if verbose:
            print(txt)
            send_mail(txt, subject="LifE save msg ")
    else:
        txt = (
            "Object fiber_fit.vox_coords exceeded the imposed 20GB limit with a size of: "
            + str(sizecoords / 1e9) + " GB")
        print(txt)
        send_mail(txt, subject="LifE error msg")

    #predict diffusion data based on new model
    model_predict = fiber_fit.predict(
    )  #possible to predict based on different gtab or base signal (change gtab, S0)
    model_error = model_predict - fiber_fit.data  #compare original dwi data and the model fit, calculate error
    model_rmse = np.sqrt(
        np.mean(model_error[:, 10:]**2,
                -1))  #this is good, but must check ways to avoid overfitting
    #how does the model get built? add lasso? JS

    beta_baseline = np.zeros(
        fiber_fit.beta.shape[0]
    )  #baseline assumption where the streamlines weight is 0
    pred_weighted = np.reshape(
        opt.spdot(fiber_fit.life_matrix, beta_baseline),
        (fiber_fit.vox_coords.shape[0], np.sum(~gtab.b0s_mask)))
    mean_pred = np.empty((fiber_fit.vox_coords.shape[0], gtab.bvals.shape[0]))
    S0 = fiber_fit.b0_signal

    mean_pred[..., gtab.b0s_mask] = S0[:, None]
    mean_pred[..., ~gtab.b0s_mask] = \
        (pred_weighted + fiber_fit.mean_signal[:, None]) * S0[:, None]
    mean_error = mean_pred - fiber_fit.data
    mean_rmse = np.sqrt(np.mean(mean_error**2, -1))

    size_meanrmse = getsize(mean_rmse)
    if size_meanrmse < maxsize_var:
        picklepath = outpathpickle + subject + strproperty + '_mean_rmse.p'
        txt = ("mean_rmse saved at " + picklepath)
        pickle.dump(mean_rmse, open(picklepath, "wb"))
        if verbose:
            print(txt)
            send_mail(txt, subject="LifE save msg ")
    else:
        txt = (
            "Object mean_rmse exceeded the imposed 20GB limit with a size of: "
            + str(size_meanrmse / 1e9) + " GB")
        print(txt)
        send_mail(txt, subject="LifE error msg")

    size_modelrmse = getsize(model_rmse)
    if size_modelrmse < maxsize_var:
        picklepath = outpathpickle + subject + strproperty + '_model_rmse.p'
        txt = ("model_rmse saved at " + picklepath)
        pickle.dump(model_rmse, open(picklepath, "wb"))
        if verbose:
            print(txt)
            send_mail(txt, subject="LifE save msg ")
    else:
        txt = (
            "Object model_rmse exceeded the imposed 20GB limit with a size of: "
            + str(size_modelrmse / 1e9) + " GB")
        print(txt)
        send_mail(txt, subject="LifE error msg")

    if outpathfig is not None:
        try:
            import matplotlib.pyplot as myplot
            fig, ax = plt.subplots(1)
            ax.hist(fiber_fit.beta, bins=100, histtype='step')
            LifEcreate_fig(fiber_fit.beta,
                           mean_rmse,
                           model_rmse,
                           fiber_fit.vox_coords,
                           dwidata,
                           subject,
                           t1_data=dwidata[:, :, :, 0],
                           outpathfig=outpathfig,
                           interactive=False,
                           strproperty=strproperty,
                           verbose=verbose)
        except Exception:
            print(
                "Could not launch LifEcreate_fig, possibly a display/qsub issue "
                "(this is a template warning, to be improved upon)"
            )
    return model_error, mean_error
Example #15
def save_tractogram(sft, filename, bbox_valid_check=True):
    """ Save the stateful tractogram in any format (trk, tck, vtk, fib, dpy)

    Parameters
    ----------
    sft : StatefulTractogram
        The stateful tractogram to save
    filename : string
        Filename with valid extension

    Returns
    -------
    output : bool
        Did the saving work properly
    """

    _, extension = os.path.splitext(filename)
    if extension not in ['.trk', '.tck', '.vtk', '.fib', '.dpy']:
        raise TypeError('Output filename is not one of the supported formats')

    if bbox_valid_check and not sft.is_bbox_in_vox_valid():
        raise ValueError('Bounding box is not valid in voxel space, cannot ' +
                         'save a valid file if some coordinates are invalid')

    old_space = deepcopy(sft.space)
    old_shift = deepcopy(sft.shifted_origin)

    sft.to_rasmm()
    sft.to_center()

    timer = time.time()
    if extension in ['.trk', '.tck']:
        tractogram_type = detect_format(filename)
        header = create_tractogram_header(tractogram_type,
                                          *sft.space_attributes)
        new_tractogram = Tractogram(sft.streamlines,
                                    affine_to_rasmm=np.eye(4))

        if extension == '.trk':
            new_tractogram.data_per_point = sft.data_per_point
            new_tractogram.data_per_streamline = sft.data_per_streamline

        fileobj = tractogram_type(new_tractogram, header=header)
        nib.streamlines.save(fileobj, filename)

    elif extension in ['.vtk', '.fib']:
        save_vtk_streamlines(sft.streamlines, filename, binary=True)
    elif extension in ['.dpy']:
        dpy_obj = Dpy(filename, mode='w')
        dpy_obj.write_tracks(sft.streamlines)
        dpy_obj.close()

    logging.debug('Save %s with %s streamlines in %s seconds',
                  filename, len(sft), round(time.time() - timer, 3))

    if old_space == Space.VOX:
        sft.to_vox()
    elif old_space == Space.VOXMM:
        sft.to_voxmm()

    if old_shift:
        sft.to_corner()

    return True
Example #16
def main():
    parser = _build_arg_parser()
    args = parser.parse_args()

    if args.verbose:
        logging.basicConfig(level=logging.DEBUG)

    assert_inputs_exist(parser, [args.in_odf, args.in_seed, args.in_mask])
    assert_outputs_exist(parser, args, args.out_tractogram)

    if not nib.streamlines.is_supported(args.out_tractogram):
        parser.error('Invalid output streamline file format (must be trk or ' +
                     'tck): {0}'.format(args.out_tractogram))

    verify_streamline_length_options(parser, args)
    verify_compression_th(args.compress)
    verify_seed_options(parser, args)

    mask_img = nib.load(args.in_mask)
    mask_data = get_data_as_mask(mask_img, dtype=bool)

    # Make sure the data is isotropic. Else, the strategy used
    # when providing information to dipy (i.e. working as if in voxel space)
    # will not yield correct results.
    odf_sh_img = nib.load(args.in_odf)
    if not np.allclose(np.mean(odf_sh_img.header.get_zooms()[:3]),
                       odf_sh_img.header.get_zooms()[0], atol=1e-03):
        parser.error(
            'ODF SH file is not isotropic. Tracking cannot be run robustly.')

    if args.npv:
        nb_seeds = args.npv
        seed_per_vox = True
    elif args.nt:
        nb_seeds = args.nt
        seed_per_vox = False
    else:
        nb_seeds = 1
        seed_per_vox = True

    voxel_size = odf_sh_img.header.get_zooms()[0]
    vox_step_size = args.step_size / voxel_size
    seed_img = nib.load(args.in_seed)
    seeds = track_utils.random_seeds_from_mask(
        seed_img.get_fdata(dtype=np.float32),
        np.eye(4),
        seeds_count=nb_seeds,
        seed_count_per_voxel=seed_per_vox,
        random_seed=args.seed)

    # Tracking is performed in voxel space
    max_steps = int(args.max_length / args.step_size) + 1
    streamlines_generator = LocalTracking(
        _get_direction_getter(args),
        BinaryStoppingCriterion(mask_data),
        seeds, np.eye(4),
        step_size=vox_step_size, max_cross=1,
        maxlen=max_steps,
        fixedstep=True, return_all=True,
        random_seed=args.seed,
        save_seeds=args.save_seeds)

    scaled_min_length = args.min_length / voxel_size
    scaled_max_length = args.max_length / voxel_size

    if args.save_seeds:
        filtered_streamlines, seeds = \
            zip(*((s, p) for s, p in streamlines_generator
                  if scaled_min_length <= length(s) <= scaled_max_length))
        data_per_streamlines = {'seeds': lambda: seeds}
    else:
        filtered_streamlines = \
            (s for s in streamlines_generator
             if scaled_min_length <= length(s) <= scaled_max_length)
        data_per_streamlines = {}

    if args.compress:
        filtered_streamlines = (
            compress_streamlines(s, args.compress)
            for s in filtered_streamlines)

    tractogram = LazyTractogram(lambda: filtered_streamlines,
                                data_per_streamlines,
                                affine_to_rasmm=seed_img.affine)

    filetype = nib.streamlines.detect_format(args.out_tractogram)
    reference = get_reference_info(seed_img)
    header = create_tractogram_header(filetype, *reference)

    # Use generator to save the streamlines on-the-fly
    nib.streamlines.save(tractogram, args.out_tractogram, header=header)
Example #17
def save_tractogram(sft, filename, bbox_valid_check=True):
    """ Save the stateful tractogram in any format (trk, tck, vtk, fib, dpy)

    Parameters
    ----------
    sft : StatefulTractogram
        The stateful tractogram to save
    filename : string
        Filename with valid extension
    bbox_valid_check : bool
        Verification for negative voxel coordinates or values above the
        volume dimensions. Default is True, to enforce valid file.

    Returns
    -------
    output : bool
        True if the saving operation was successful
    """

    _, extension = os.path.splitext(filename)
    if extension not in ['.trk', '.tck', '.vtk', '.fib', '.dpy']:
        raise TypeError('Output filename is not one of the supported formats')

    if bbox_valid_check and not sft.is_bbox_in_vox_valid():
        raise ValueError('Bounding box is not valid in voxel space, cannot ' +
                         'load a valid file if some coordinates are ' +
                         'invalid. Please use the function ' +
                         'remove_invalid_streamlines to discard invalid ' +
                         'streamlines or set bbox_valid_check to False')

    old_space = deepcopy(sft.space)
    old_shift = deepcopy(sft.shifted_origin)

    sft.to_rasmm()
    sft.to_center()

    timer = time.time()
    if extension in ['.trk', '.tck']:
        tractogram_type = detect_format(filename)
        header = create_tractogram_header(tractogram_type,
                                          *sft.space_attributes)
        new_tractogram = Tractogram(sft.streamlines, affine_to_rasmm=np.eye(4))

        if extension == '.trk':
            new_tractogram.data_per_point = sft.data_per_point
            new_tractogram.data_per_streamline = sft.data_per_streamline

        fileobj = tractogram_type(new_tractogram, header=header)
        nib.streamlines.save(fileobj, filename)

    elif extension in ['.vtk', '.fib']:
        save_vtk_streamlines(sft.streamlines, filename, binary=True)
    elif extension in ['.dpy']:
        dpy_obj = Dpy(filename, mode='w')
        dpy_obj.write_tracks(sft.streamlines)
        dpy_obj.close()

    logging.debug('Save %s with %s streamlines in %s seconds', filename,
                  len(sft), round(time.time() - timer, 3))

    if old_space == Space.VOX:
        sft.to_vox()
    elif old_space == Space.VOXMM:
        sft.to_voxmm()

    if old_shift:
        sft.to_corner()

    return True
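
A quick round trip (file names hypothetical):

from dipy.io.streamline import load_trk

sft = load_trk('bundle.trk', 'same', bbox_valid_check=False)
save_tractogram(sft, 'bundle_copy.trk', bbox_valid_check=False)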
Example #18
def header_superpose_trk(target_path, origin_path, outpath=None):

    if not isinstance(origin_path, str):
        origin_trk = origin_path
    else:
        origin_trk = load_trk(origin_path, 'same')

    target_data, target_affine, vox_size, target_header, target_ref_info = extract_nii_info(
        target_path)

    if outpath is None:
        if isinstance(origin_path, str):
            warnings.warn("Will overwrite the old trk file; is this what you want?")
            permission = input("enter yes or y if you are ok with this: ")
            if permission.lower() == "yes" or permission.lower() == "y":
                outpath = origin_path
            else:
                raise Exception("Will not copy over old trk file")
        else:
            raise Exception("Need to specify an output path of some kind")

    trk_header = origin_trk.space_attributes
    trk_affine = origin_trk._affine
    trkstreamlines = origin_trk.streamlines
    if np.any(trk_header[1][0:3] != np.shape(target_data)[0:3]):
        raise TypeError(
            'Sizes of the originating matrices are different; recalculation not implemented'
        )
    if np.any(trk_affine != target_affine):
        test = 3
        if test == 1:
            trk_header = list(trk_header)
            trk_header[0] = target_affine
            trk_header = tuple(trk_header)
            myheader = create_tractogram_header(outpath, *trk_header)
            trk_sl = lambda: (s for s in trkstreamlines)
            save_trk_heavy_duty(outpath,
                                streamlines=trk_sl,
                                affine=target_affine,
                                header=myheader)
        elif test == 2:
            # least-squares mapping from the trk affine onto the target affine
            transform_matrix = (
                np.linalg.inv(np.transpose(trk_affine) @ trk_affine) @
                np.transpose(trk_affine) @ target_affine)
            from dipy.tracking.streamline import transform_streamlines
            myheader = create_tractogram_header(outpath, *trk_header)
            new_streamlines = transform_streamlines(trkstreamlines,
                                                    transform_matrix)
            trk_sl = lambda: (s for s in new_streamlines)
            save_trk_heavy_duty(outpath,
                                streamlines=trk_sl,
                                affine=trk_affine,
                                header=myheader)
        elif test == 3:
            myheader = create_tractogram_header(outpath, *target_ref_info)
            trk_sl = lambda: (s for s in trkstreamlines)
            save_trk_heavy_duty(outpath,
                                streamlines=trk_sl,
                                affine=target_affine,
                                header=myheader)
    else:
        print("No need to change affine, bring to new path")
        if isinstance(origin_path, str):
            copyfile(origin_path, outpath)
        else:
            myheader = create_tractogram_header(outpath, *trk_header)
            trk_sl = lambda: (s for s in trkstreamlines)
            save_trk_heavy_duty(outpath,
                                streamlines=trk_sl,
                                affine=target_affine,
                                header=myheader)
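
A usage sketch (file names hypothetical): conform a trk file's header to a reference volume and write the result to a new path:

header_superpose_trk('reference.nii.gz', 'tracts.trk', outpath='tracts_conformed.trk')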