Example #1
def atlasRegistrationBatch(patientsDTI, i):
    #for i in range(len(patientsDTI)):
    sub = patientsDTI[i]

    #input/output paths from QSI prep
    preop3T = join(pathQSIPREP, "qsiprep", f"sub-{sub}", "anat",
                   f"sub-{sub}_desc-preproc_T1w.nii.gz")
    preop3Tmask = join(pathQSIPREP, "qsiprep", f"sub-{sub}", "anat",
                       f"sub-{sub}_desc-brain_mask.nii.gz")
    pathAtlasRegistration = join(pathTractography, f"sub-{sub}",
                                 "atlasRegistration")
    preop3Tbrain = join(pathAtlasRegistration, "brain.nii.gz")
    preop3TSTD = join(pathAtlasRegistration,
                      utils.baseSplitextNiiGz(preop3T)[2] + "_std.nii.gz")
    #preop3TmaskExpand = join(pathAtlasRegistration, "brain_maskExpanded.nii.gz")
    utils.checkPathAndMake(pathTractography, pathAtlasRegistration)

    #MNI registration
    if not utils.checkIfFileExists(
            join(pathAtlasRegistration, "mni_fnirt.nii.gz")):
        #extract brain
        getBrainFromMask(preop3T, preop3Tmask, preop3Tbrain)
        utils.show_slices(preop3Tbrain,
                          save=True,
                          saveFilename=join(pathAtlasRegistration,
                                            "brain.png"))
        #register MNI template to preop T1 space
        register_MNI_to_preopT1(preop3T, preop3Tbrain, MNItemplatePath,
                                MNItemplateBrainPath, "mni",
                                pathAtlasRegistration)
        utils.show_slices(preop3T,
                          save=True,
                          saveFilename=join(pathAtlasRegistration,
                                            "preop3T.png"))
        utils.show_slices(join(pathAtlasRegistration, "mni_flirt.nii.gz"),
                          save=True,
                          saveFilename=join(pathAtlasRegistration,
                                            "mni_flirt.png"))
        utils.show_slices(join(pathAtlasRegistration, "mni_fnirt.nii.gz"),
                          save=True,
                          saveFilename=join(pathAtlasRegistration,
                                            "mni_fnirt.png"))

    #apply warp to atlases
    applywarp_to_atlas(atlases.getAllAtlasPaths(),
                       preop3TSTD,
                       join(pathAtlasRegistration, "mni_warp.nii.gz"),
                       pathAtlasRegistration,
                       isDir=False)
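
# Usage sketch (not from the original source): atlasRegistrationBatch takes a
# subject index rather than looping internally (see the commented-out for loop
# above), so it can be dispatched across subjects in parallel. `patientsDTI`
# and the path globals are assumed to be defined elsewhere in this script.
if __name__ == "__main__":
    from multiprocessing import Pool
    with Pool(processes=4) as pool:
        pool.starmap(atlasRegistrationBatch,
                     [(patientsDTI, i) for i in range(len(patientsDTI))])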
Example #2
def batchStructuralConnectivity(patientsDTI, i, batchBegin, batchEnd,
                                pathTractography, pathQSIPREP, atlasList):
    #for i in range(len(patientsDTI)):
    sub = patientsDTI[i]
    print(f"\n\n\n\n{sub}")
    if utils.getSubType(sub) == "control": ses = "control3T"
    else: ses = "preop3T"
    #input/output paths from QSI prep
    pathTracts = join(pathTractography, f"sub-{sub}", "tracts")
    pathDWI = join(pathQSIPREP, "qsiprep", f"sub-{sub}", f"ses-{ses}", "dwi",
                   f"sub-{sub}_ses-{ses}_space-T1w_desc-preproc_dwi.nii.gz")
    pathTRK = f"{join(pathTracts, utils.baseSplitextNiiGz(pathDWI)[2])}.trk.gz"
    pathFIB = f"{join(pathTracts, utils.baseSplitextNiiGz(pathDWI)[2])}.fib.gz"
    preop3T = join(pathQSIPREP, "qsiprep", f"sub-{sub}", "anat",
                   f"sub-{sub}_desc-preproc_T1w.nii.gz")
    pathAtlasRegistration = join(pathTractography, f"sub-{sub}",
                                 "atlasRegistration")
    pathStructuralConnectivity = join(pathTractography, f"sub-{sub}",
                                      "structuralConnectivity")
    utils.checkPathAndMake(pathTractography, pathStructuralConnectivity)

    for a in range(batchBegin, batchEnd):
        atlas = join(pathAtlasRegistration,
                     utils.baseSplitextNiiGz(atlasList[a])[0])
        if utils.checkIfFileExists(atlas):
            output = join(pathStructuralConnectivity, f"sub-{sub}")
            if not utils.checkIfFileExists(
                    output +
                    f".{utils.baseSplitextNiiGz(atlasList[a])[2]}.count.pass.connectogram.txt",
                    returnOpposite=True):
                t0 = time.time()
                tractography.getStructuralConnectivity(
                    BIDS, dsiStudioSingularityPatah, pathFIB, pathTRK, preop3T,
                    atlas, output)
                utils.calculateTimeToComplete(t0, time.time(),
                                              len(patientsDTI), i)
def applywarp_to_atlas(atlas_directory, atlas_paths, preop_T1_std, mni_warp,
                       atlas_registration_output_path):
    full_atlas_paths = []
    for a in range(len(atlas_paths)):
        utils.checkPathError(
            f"{os.path.join(atlas_directory, atlas_paths[a])}")
        full_atlas_paths.append(
            f"{os.path.join(atlas_directory, atlas_paths[a])}")
    utils.checkPathError(mni_warp)
    utils.checkPathError(preop_T1_std)
    utils.checkPathError(atlas_registration_output_path)
    for a in range(len(full_atlas_paths)):
        atlasName = basename(splitext(splitext(full_atlas_paths[a])[0])[0])
        outputAtlasName = join(atlas_registration_output_path,
                               atlasName + ".nii.gz")
        #if not os.path.exists(outputAtlasName):
        print(
            f"\r{np.round((a+1)/len(full_atlas_paths)*100,2)}%   {atlasName[0:20]}                         ",
            end="\r")
        if not utils.checkIfFileExists(outputAtlasName, printBOOL=False):
            cmd = f"applywarp -i { full_atlas_paths[a]} -r {preop_T1_std} -w {mni_warp} --interp=nn -o {outputAtlasName}"
            os.system(cmd)
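
# Note: --interp=nn selects nearest-neighbour interpolation so integer atlas
# labels are not blended during resampling. A representative generated command
# (illustrative file names, not from the original source):
#
#   applywarp -i AAL3v1_1mm.nii.gz -r sub-X_desc-preproc_T1w_std.nii.gz \
#             -w mni_warp.nii.gz --interp=nn -o AAL3v1_1mm.nii.gz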
  
  feature_name = "absolute_slope"
  location_feature = join(BIDS, datasetiEEG_spread, "single_features", f"sub-{RID}" )
  location_abs_slope_basename = f"{splitext(fname)[0]}_{feature_name}.pickle"
  location_abs_slope_file = join(location_feature, location_abs_slope_basename)
  
  feature_name = "line_length"
  location_line_length_basename = f"{splitext(fname)[0]}_{feature_name}.pickle"
  location_line_length_file = join(location_feature, location_line_length_basename)
  
  feature_name = "power_broadband"
  location_power_broadband_basename = f"{splitext(fname)[0]}_{feature_name}.pickle"
  location_power_broadband = join(location_feature, location_power_broadband_basename)
  
 
  if utils.checkIfFileExists( spread_location_file , printBOOL=False) and utils.checkIfFileExists( location_abs_slope_file , printBOOL=False):
      #print("\n\n\n\nSPREAD FILE EXISTS\n\n\n\n")
  
      if model_ID == "WN" or model_ID == "CNN" or model_ID == "LSTM":
          with open(spread_location_file, 'rb') as f:
              [probWN, probCNN, probLSTM, data_scalerDS, channels, window,
               skipWindow, secondsBefore, secondsAfter] = pickle.load(f)
          
      
      if model_ID == "WN":
          #print(model_ID)
          prob_array = probWN
      elif model_ID == "CNN":
          #print(model_ID)
          prob_array = probCNN
      elif model_ID == "LSTM":
          #print(model_ID)
          prob_array = probLSTM
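
# A more compact equivalent of the if/elif chain above (a sketch, assuming the
# three probability arrays were unpickled as shown):
#
#   prob_array = {"WN": probWN, "CNN": probCNN, "LSTM": probLSTM}[model_ID]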
Example #5
 if utils.checkIfFileExistsGlob(T00) and utils.checkIfFileExistsGlob(T00electrodes): 
     #check if freesurfer has been run on preferred session
     if utils.checkIfFileExistsGlob( join(freesurferReconAllDir, f"sub-{sub}", f"ses-{preferredSurface}", "freesurfer", "surf", "lh.pial" ) ) and utils.checkIfFileExistsGlob( join(freesurferReconAllDir, f"sub-{sub}", f"ses-{preferredSurface}", "freesurfer", "surf", "rh.pial" ) ):
             sessionPath = np.array(glob( join(freesurferReconAllDir, f"sub-{sub}", f"ses-{preferredSurface}")))[0]
             session = basename(sessionPath)[4:]
             outputpath = join(BIDS, "derivatives", f"{derivativesOutput}", f"sub-{sub}", f"ses-{session}")
             utils.checkPathAndMake(outputpath, join(outputpath, "freesurfer", 'mri' ))
             utils.checkPathAndMake(outputpath, join(outputpath, "freesurfer", 'surf' ))
             utils.checkPathAndMake(outputpath, join(outputpath, "html"))
             utils.executeCommand(f"cp {join(sessionPath, 'freesurfer', 'mri' , 'orig_nu.mgz')}  {join(outputpath, 'freesurfer', 'mri' )} ")
             utils.executeCommand(f"cp {join(sessionPath, 'freesurfer', 'mri' , 'brain.mgz')}  {join(outputpath, 'freesurfer', 'mri' )} ")
             utils.executeCommand(f"cp {join(sessionPath, 'freesurfer', 'surf' , 'lh.pial')}  {join(outputpath, 'freesurfer', 'surf')} ")
             utils.executeCommand(f"cp {join(sessionPath, 'freesurfer', 'surf' , 'rh.pial')}  {join(outputpath, 'freesurfer', 'surf')} ")
             
     #check if freesurfer has been run on implant session
     elif utils.checkIfFileExists( join(freesurferReconAllDir, f"sub-{sub}", f"ses-{implantName}", "freesurfer", "surf", "lh.pial" ) ) and utils.checkIfFileExists( join(freesurferReconAllDir, f"sub-{sub}", f"ses-{implantName}", "freesurfer", "surf", "rh.pial" ) ):
             sessionPath = np.array(glob( join(freesurferReconAllDir, f"sub-{sub}", f"ses-{implantName}")))[0]
             session = basename(sessionPath)[4:]
             outputpath = join(BIDS, "derivatives", f"{derivativesOutput}", f"sub-{sub}", f"ses-{session}")
             utils.checkPathAndMake(outputpath, join(outputpath, "freesurfer", 'mri' ))
             utils.checkPathAndMake(outputpath, join(outputpath, "freesurfer", 'surf' ))
             utils.checkPathAndMake(outputpath, join(outputpath, "html"))
             utils.executeCommand(f"cp {join(sessionPath, 'freesurfer', 'mri' , 'orig_nu.mgz')}  {join(outputpath, 'freesurfer', 'mri' )} ")
             utils.executeCommand(f"cp {join(sessionPath, 'freesurfer', 'mri' , 'brain.mgz')}  {join(outputpath, 'freesurfer', 'mri' )} ")
             utils.executeCommand(f"cp {join(sessionPath, 'freesurfer', 'surf' , 'lh.pial')}  {join(outputpath, 'freesurfer', 'surf')} ")
             utils.executeCommand(f"cp {join(sessionPath, 'freesurfer', 'surf' , 'rh.pial')}  {join(outputpath, 'freesurfer', 'surf')} ")
     else:
         continue
     #check if lh and rh.pial have been moved to implantRenders derivatives path properly 
     if utils.checkIfFileDoesNotExist(f"{outputpath}/html/brain.glb", returnOpposite=False):
         if utils.checkIfFileExists(f"{join(outputpath, 'freesurfer', 'surf' ,'lh.pial')}" ) and utils.checkIfFileExists( f"{join(outputpath, 'freesurfer', 'surf' ,'rh.pial')}") :       
Example #6
def calculate_mean_rank_deep_learning(i,
                                      patientsWithseizures,
                                      version,
                                      threshold=0.6,
                                      smoothing=20,
                                      model_ID="WN",
                                      secondsAfter=180,
                                      secondsBefore=180,
                                      type_of_overlap="soz",
                                      override_soz=False,
                                      seconds_active=None,
                                      tanh=False):

    #%%
    RID = np.array(patientsWithseizures["subject"])[i]
    idKey = np.array(patientsWithseizures["idKey"])[i]
    seizure_length = patientsWithseizures.length[i]

    print(RID)
    #CHECKING IF SPREAD FILES EXIST

    fname = DataJson.get_fname_ictal(RID,
                                     "Ictal",
                                     idKey,
                                     dataset=datasetiEEG,
                                     session=session,
                                     startUsec=None,
                                     stopUsec=None,
                                     startKey="EEC",
                                     secondsBefore=secondsBefore,
                                     secondsAfter=secondsAfter)

    spread_location = join(BIDS, datasetiEEG_spread, f"v{version:03d}",
                           f"sub-{RID}")
    spread_location_file_basename = f"{splitext(fname)[0]}_spread.pickle"
    spread_location_file = join(spread_location, spread_location_file_basename)

    feature_name = "absolute_slope"
    location_feature = join(BIDS, datasetiEEG_spread, "single_features",
                            f"sub-{RID}")
    location_abs_slope_basename = f"{splitext(fname)[0]}_{feature_name}.pickle"
    location_abs_slope_file = join(location_feature,
                                   location_abs_slope_basename)

    feature_name = "line_length"
    location_line_length_basename = f"{splitext(fname)[0]}_{feature_name}.pickle"
    location_line_length_file = join(location_feature,
                                     location_line_length_basename)

    feature_name = "power_broadband"
    location_power_broadband_basename = f"{splitext(fname)[0]}_{feature_name}.pickle"
    location_power_broadband = join(location_feature,
                                    location_power_broadband_basename)

    if seconds_active is None:
        seconds = np.arange(0, 60 * 2 + 1, 1)
    else:
        seconds = seconds_active
    percent_active_vec = np.zeros(len(seconds))
    percent_active_vec[:] = np.nan

    if utils.checkIfFileExists(spread_location_file,
                               printBOOL=False) and utils.checkIfFileExists(
                                   location_abs_slope_file, printBOOL=False):
        #print("\n\n\n\nSPREAD FILE EXISTS\n\n\n\n")

        #Getting SOZ labels
        RID_keys = list(jsonFile["SUBJECTS"].keys())
        hup_num_all = [jsonFile["SUBJECTS"][x]["HUP"] for x in RID_keys]

        hup_int = hup_num_all[RID_keys.index(RID)]
        hup_int_pad = f"{hup_int:03d}"

        #i_patient = patients.index(f"HUP{hup_int_pad}")
        #HUP = patients[i_patient]
        #hup = int(HUP[3:])

        #channel_names = labels[i_patient]
        #soz_ind = np.where(soz[i_patient] == 1)[0]
        #soz_channel_names = np.array(channel_names)[soz_ind]

        #resected_ind = np.where(resect[i_patient] == 1)[0]
        #resected_channel_names = np.array(channel_names)[resected_ind]

        #ignore_ind = np.where(ignore[i_patient] == 1)[0]
        #ignore__channel_names = np.array(channel_names)[ignore_ind]

        #soz_channel_names = echobase.channel2std(soz_channel_names)
        #resected_channel_names = echobase.channel2std(resected_channel_names)
        #ignore__channel_names = echobase.channel2std(ignore__channel_names)

        #soz_channel_names = channel2std_ECoG(soz_channel_names)
        #resected_channel_names = channel2std_ECoG(resected_channel_names)
        #ignore__channel_names = channel2std_ECoG(ignore__channel_names)
        #%
        THRESHOLD = threshold
        SMOOTHING = smoothing  #in seconds

        if model_ID == "WN" or model_ID == "CNN" or model_ID == "LSTM":
            with open(spread_location_file, 'rb') as f:
                [
                    probWN, probCNN, probLSTM, data_scalerDS, channels, window,
                    skipWindow, secondsBefore, secondsAfter
                ] = pickle.load(f)

        if model_ID == "WN":
            #print(model_ID)
            prob_array = probWN
        elif model_ID == "CNN":
            #print(model_ID)
            prob_array = probCNN
        elif model_ID == "LSTM":
            #print(model_ID)
            prob_array = probLSTM
        elif model_ID == "absolute_slope":
            if utils.checkIfFileExists(location_abs_slope_file,
                                       printBOOL=False):
                with open(location_abs_slope_file, 'rb') as f:
                    [
                        abs_slope_normalized, abs_slope_normalized_tanh,
                        channels, window, skipWindow, secondsBefore,
                        secondsAfter
                    ] = pickle.load(f)
                if not tanh:
                    #abs_slope_normalized = utils.apply_arctanh(abs_slope_normalized_tanh)/1e-1
                    abs_slope_normalized = abs_slope_normalized / np.max(
                        abs_slope_normalized)
                    prob_array = abs_slope_normalized
                else:
                    prob_array = abs_slope_normalized_tanh
            else:
                print(
                    f"{i} {RID} file does not exist {location_abs_slope_file}\n"
                )
        elif model_ID == "line_length":
            if utils.checkIfFileExists(location_line_length_file,
                                       printBOOL=False):
                with open(location_line_length_file, 'rb') as f:
                    [
                        probLL, probLL_tanh, channels, window, skipWindow,
                        secondsBefore, secondsAfter
                    ] = pickle.load(f)
                if not tanh:
                    probLL = probLL / np.max(probLL)
                    prob_array = probLL
                else:
                    prob_array = probLL_tanh
            else:
                print(
                    f"{i} {RID} file does not exist {location_line_length_file}\n"
                )
        elif model_ID == "power_broadband":
            if utils.checkIfFileExists(location_power_broadband,
                                       printBOOL=False):
                with open(location_power_broadband, 'rb') as f:
                    [
                        power_total, power_total_tanh, channels, window,
                        skipWindow, secondsBefore, secondsAfter
                    ] = pickle.load(f)
                if not tanh:
                    #power_total = utils.apply_arctanh(power_total_tanh)/7e-2
                    power_total = power_total / np.max(power_total)
                    prob_array = power_total

                else:
                    prob_array = power_total_tanh

            else:
                print(
                    f"{i} {RID} file does not exist {location_power_broadband}\n"
                )
        else:
            print("model ID not recognized. Using Wavenet")
            prob_array = probWN

        #####
        seizure_start = int((secondsBefore - 0) / skipWindow)
        seizure_stop = int((secondsBefore + seizure_length) / skipWindow)

        probability_arr_movingAvg, probability_arr_threshold = prob_threshold_moving_avg(
            prob_array, fsds, skip, threshold=THRESHOLD, smoothing=SMOOTHING)
        #sns.heatmap( probability_arr_movingAvg.T )
        #sns.heatmap( probability_arr_threshold.T)
        spread_start, seizure_start, spread_start_loc, channel_order, channel_order_labels = get_start_times(
            secondsBefore, skipWindow, fsds, channels, 0, seizure_length,
            probability_arr_threshold)

        channel_order_labels = remove_EGG_and_ref(channel_order_labels)
        channels2 = remove_EGG_and_ref(channels)

        channel_order_labels = channel2std_ECoG(channel_order_labels)
        channels2 = channel2std_ECoG(channels2)

        #print(soz_channel_names)
        #print(resected_channel_names)
        #print(channel_order_labels)

        #remove ignore electrodes from channel_order_labels
        #ignore_index = np.intersect1d(  channel_order_labels, ignore__channel_names, return_indices=True)
        #channel_order_labels[-ignore_index[1]]
        #channel_order_labels = np.delete(channel_order_labels, ignore_index[1])

        print(i)
        RID = np.array(patientsWithseizures["subject"])[i]
        seizure = np.array(patientsWithseizures["idKey"])[i]
        seizure_length = patientsWithseizures.length[i]

        #atlas
        atlas = "BN_Atlas_246_1mm"
        atlas = "AAL3v1_1mm"
        #atlas = "AAL2"
        #atlas = "HarvardOxford-sub-ONLY_maxprob-thr25-1mm"

        atlas_names_short = list(atlas_files["STANDARD"].keys())
        atlas_names = [
            atlas_files["STANDARD"][x]["name"] for x in atlas_names_short
        ]
        ind = np.where(f"{atlas}.nii.gz" == np.array(atlas_names))[0][0]
        atlas_label_name = atlas_files["STANDARD"][
            atlas_names_short[ind]]["label"]

        atlas_label = pd.read_csv(join(paths.ATLAS_LABELS, atlas_label_name))
        atlas_label_names = np.array(atlas_label.iloc[1:, 1])
        atlas_label_region_numbers = np.array(atlas_label.iloc[1:, 0])

        connectivity_loc = join(
            paths.BIDS_DERIVATIVES_STRUCTURAL_MATRICES, f"sub-{RID}",
            "ses-research3Tv[0-9][0-9]", "matrices",
            f"sub-{RID}.{atlas}.count.pass.connectogram.txt")

        connectivity_loc_glob = glob.glob(connectivity_loc)

        if len(connectivity_loc_glob) > 0:
            connectivity_loc_path = connectivity_loc_glob[0]

            sc = utils.read_DSI_studio_Txt_files_SC(connectivity_loc_path)
            sc = sc / sc.max()
            #sc = utils.log_normalize_adj(sc)
            #sc=utils.log_normalize_adj(sc)
            sc_region_labels = utils.read_DSI_studio_Txt_files_SC_return_regions(
                connectivity_loc_path, atlas).astype(int)

            atlas_localization_path = join(
                paths.BIDS_DERIVATIVES_ATLAS_LOCALIZATION, f"sub-{RID}",
                f"ses-{session}",
                f"sub-{RID}_ses-{session}_desc-atlasLocalization.csv")
            if utils.checkIfFileExists(atlas_localization_path,
                                       printBOOL=False):
                atlas_localization = pd.read_csv(atlas_localization_path)

                atlas_localization.channel = channel2std_ECoG(
                    atlas_localization.channel)
                #get channels in hipp

                ##############################################################
                ##############################################################
                ##############################################################
                ##############################################################
                ##############################################################
                ##############################################################
                #find the activation time between channels

                #inputs computed above: spread_start (per-channel start
                #windows) and channels2 (standardized channel names)

                sc_vs_time = pd.DataFrame(columns=["ch1", "ch2", "time", "sc"])

                #leftover debug values for the disabled block below
                #ch1 = 1
                #ch2 = 10
                """
                for ch1 in range(len(channels2)):
                    for ch2 in range(len(channels2)):
                        time_between = abs(spread_start[ch1] - spread_start[ch2])*skipWindow
                        if spread_start[ch1]*skipWindow > seizure_length:
                            time_between = np.nan
                            
                        if spread_start[ch2]*skipWindow > seizure_length:
                            time_between = np.nan
                        
                        ch1_name = channels2[ch1]
                        ch2_name = channels2[ch2]
                        ch1_ind_atlas_loc_ind = np.where(ch1_name == atlas_localization.channel )[0]
                        ch2_ind_atlas_loc_ind = np.where(ch2_name == atlas_localization.channel )[0]
                        if len(ch1_ind_atlas_loc_ind) >0:
                            if len(ch2_ind_atlas_loc_ind)>0:
                                region_num1 = atlas_localization[f"{atlas}_region_number"][ch1_ind_atlas_loc_ind[0]]
                                region_num2 = atlas_localization[f"{atlas}_region_number"][ch2_ind_atlas_loc_ind[0]]
                                
                                
                                
                                if region_num1 in sc_region_labels:
                                    if region_num2 in sc_region_labels:
                                        if not region_num2 == region_num1:
                                            ch1_sc_ind = np.where(region_num1 == sc_region_labels)[0][0]
                                            ch2_sc_ind = np.where(region_num2 == sc_region_labels)[0][0]
                                            connectivity = sc[ch1_sc_ind, ch2_sc_ind]
                                            
                                            sc_vs_time = sc_vs_time.append( dict(ch1 = ch1_name, ch2 = ch2_name , time = time_between, sc = connectivity) , ignore_index=True)
                        
                sc_vs_time["inverse"] = 1/sc_vs_time["time"]
                fig, axes = utils.plot_make()
                sns.regplot(data = sc_vs_time, x = "sc", y = "time", ax = axes, scatter_kws=dict(s = 1))
                
                sc_vs_time.loc[sc_vs_time["time"] == 0, "inverse"] = 0
                
                sc_vs_time_nanfill = sc_vs_time.fillna(0)
                
                axes.set_ylim([0,1])            
                axes.set_xlim([0,0.5])            
                
                """
                #get the average time each region was active
                region_times = pd.DataFrame(columns=["region", "time"])
                for r in range(len(sc_region_labels)):
                    reg_num = sc_region_labels[r]
                    channels_in_reg = np.where(
                        reg_num ==
                        atlas_localization[f"{atlas}_region_number"])[0]

                    reg_starts = []
                    if len(channels_in_reg) > 0:
                        for ch in range(len(channels_in_reg)):
                            ch_name = atlas_localization.channel[
                                channels_in_reg[ch]]
                            ch_in_spread = np.where(ch_name == channels2)[0]
                            if len(ch_in_spread) > 0:
                                if spread_start[ch_in_spread[
                                        0]] * skipWindow > seizure_length:
                                    reg_starts.append(np.nan)
                                else:
                                    reg_starts.append(
                                        spread_start[ch_in_spread[0]] *
                                        skipWindow)
                        reg_mean = np.nanmean(reg_starts)
                    else:
                        reg_mean = np.nan
                    region_times = region_times.append(dict(region=reg_num,
                                                            time=reg_mean),
                                                       ignore_index=True)

                sc_vs_time_reg = pd.DataFrame(
                    columns=["reg1", "reg2", "time", "sc"])

                for r1 in range(len(sc)):
                    for r2 in range(r1 + 1, len(sc)):
                        connectivity = sc[r1, r2]
                        time_diff = abs(region_times.iloc[r1, 1] -
                                        region_times.iloc[r2, 1])
                        sc_vs_time_reg = sc_vs_time_reg.append(
                            dict(reg1=sc_region_labels[r1],
                                 reg2=sc_region_labels[r2],
                                 time=time_diff,
                                 sc=connectivity),
                            ignore_index=True)

                #fig, axes = utils.plot_make()
                #sns.scatterplot(data = sc_vs_time_reg, x = "sc", y = "time", linewidth = 0, s=5)

                fig, axes = utils.plot_make(size_length=5)
                g = sns.regplot(data=sc_vs_time_reg,
                                x="sc",
                                y="time",
                                scatter_kws=dict(linewidth=0, s=50),
                                ci=None,
                                line_kws=dict(lw=7, color="black"))

                x = sc_vs_time_reg["sc"]
                y = sc_vs_time_reg["time"]
                y_nanremoved = y[~np.isnan(y)]
                x_nanremoved = x[~np.isnan(y)]
                #corr = spearmanr(x_nanremoved, y_nanremoved)
                corr = pearsonr(x_nanremoved, y_nanremoved)
                corr_r = np.round(corr[0], 3)
                corr_p = np.round(corr[1], 8)

                axes.set_title(f"r = {corr_r}, p = {corr_p}")
                #axes.set_ylim([-0.033,0.2])
                for tick in axes.xaxis.get_major_ticks():
                    tick.label.set_fontsize(6)
                axes.tick_params(width=4)
                # change all spines
                for axis in ['top', 'bottom', 'left', 'right']:
                    axes.spines[axis].set_linewidth(6)

#%%
            plt.savefig(join(paths.SEIZURE_SPREAD_FIGURES, "connectivity",
                             f"sc_vs_spread_time_SINGLE_PATIENT_{RID}_01.pdf"),
                        bbox_inches='tight')

            #sns.regplot(data = sc_vs_time_reg, x = "time", y = "sc", scatter_kws=dict(s = 1))

            sc_vs_time_reg["inverse"] = 1 / sc_vs_time_reg["time"]
            #sc_vs_time_reg.loc[sc_vs_time["time"] == 0, "inverse"] = 0

            sc_vs_time_reg_fill = copy.deepcopy(sc_vs_time_reg)
            sc_vs_time_reg_fill = sc_vs_time_reg_fill.fillna(0)

            sns.scatterplot(data=sc_vs_time_reg_fill,
                            x="sc",
                            y="inverse",
                            linewidth=0,
                            s=5)

            fig, axes = utils.plot_make(size_length=5)
            g = sns.regplot(data=sc_vs_time_reg_fill,
                            x="sc",
                            y="time",
                            scatter_kws=dict(linewidth=0, s=20),
                            ci=None,
                            line_kws=dict(lw=5, color="black"))

            x = sc_vs_time_reg_fill["sc"]
            y = sc_vs_time_reg_fill["time"]
            y_nanremoved = y[~np.isnan(y)]
            x_nanremoved = x[~np.isnan(y)]
            #corr = spearmanr(x_nanremoved, y_nanremoved)
            corr = pearsonr(x_nanremoved, y_nanremoved)
            corr_r = np.round(corr[0], 2)
            corr_p = np.round(corr[1], 10)

            axes.set_title(f"{corr_r}, p = {corr_p}")
            #axes.set_ylim([-0.033,0.2])
            #avoid shadowing the function argument `i` used above
            for tick in axes.xaxis.get_major_ticks():
                tick.label.set_fontsize(6)
            axes.tick_params(width=4)
            # change all spines
            for axis in ['top', 'bottom', 'left', 'right']:
                axes.spines[axis].set_linewidth(6)

            plt.savefig(join(paths.SEIZURE_SPREAD_FIGURES, "connectivity",
                             f"sc_vs_spread_time_SINGLE_PATIENT_{RID}_02.pdf"),
                        bbox_inches='tight')
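
# prob_threshold_moving_avg is defined elsewhere in this codebase; a minimal
# sketch of its presumed behavior (boxcar smoothing of each channel's spread
# probabilities over `smoothing` seconds, then binarizing at `threshold`):
import numpy as np

def prob_threshold_moving_avg_sketch(prob_array, fsds, skip, threshold=0.6,
                                     smoothing=20):
    #window length in samples: `smoothing` seconds of windows at fsds/skip Hz
    w = max(int(smoothing * fsds / skip), 1)
    kernel = np.ones(w) / w
    #smooth each channel (axis 0 is time) with a centered moving average
    smoothed = np.apply_along_axis(
        lambda v: np.convolve(v, kernel, mode="same"), 0, prob_array)
    return smoothed, (smoothed > threshold).astype(int)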
Example #7
                 newNameFi = listOfFiles[fi].replace(old, new)
                 utils.executeCommand(f"mv {listOfFiles[fi]} {newNameFi}")



#%% add Intended for to fmaps
BIDSdir = BIDSserver

subFolders = glob(BIDSdir + "/*/")


for l in range(len(subFolders)):

    folders = glob(subFolders[l]+ "/*/")
    for f in range(len(folders)):
        if utils.checkIfFileExists( join( folders[f], "fmap")  ):
            session = basename(folders[f][:-1])
            subject = basename(subFolders[l][:-1])

            files = glob(join(folders[f], "fmap", "*.json"))
            for fi in range(len(files)):
                if "phasediff" in files[fi] or "_epi" in files[fi]:
                    with open(files[fi]) as fas:
                        jsonfile = json.load(fas)

                    if utils.checkIfFileExists(join(folders[f], "dwi", f"{subject}_{session}_dwi.nii.gz")):
                        jsonfile['IntendedFor'] = f"{session}/dwi/{subject}_{session}_dwi.nii.gz"
                        with open(files[fi], 'w', encoding='utf-8') as fas:
                            json.dump(jsonfile, fas, ensure_ascii=False, indent=4)
                    elif 'IntendedFor' in jsonfile:
                        #the original deleted the key only in memory; write the
                        #cleaned sidecar back so the change persists
                        del jsonfile['IntendedFor']
                        with open(files[fi], 'w', encoding='utf-8') as fas:
                            json.dump(jsonfile, fas, ensure_ascii=False, indent=4)
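
# Illustrative result (hypothetical subject/session names): after this loop a
# fieldmap sidecar would carry an IntendedFor path relative to the subject
# directory, per BIDS convention, e.g.
#
#   {
#       "IntendedFor": "ses-research3T/dwi/sub-RID0001_ses-research3T_dwi.nii.gz"
#   }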
Example #8
    preprocessed_location = join(BIDS, datasetiEEG_preprocessed,
                                 f"v{version:03d}", f"sub-{RID}")
    if version == 15:
        version_similar = 14
        preprocessed_location = join(BIDS, datasetiEEG_preprocessed,
                                     f"v{version_similar:03d}", f"sub-{RID}")
    utils.checkPathAndMake(preprocessed_location,
                           preprocessed_location,
                           printBOOL=False)

    preprocessed_file_basename = f"{splitext(fname)[0]}_preprocessed.pickle"
    preprocessed_file = join(preprocessed_location, preprocessed_file_basename)

    #check if the preprocessed file exists; if it does, load it instead of downloading the iEEG data and preprocessing it again
    if utils.checkIfFileExists(preprocessed_file, printBOOL=False):
        print(f"\n{RID} {i} PREPROCESSED FILE EXISTS")
        with open(preprocessed_file, 'rb') as f:
            [
                dataII_scaler, data_scaler, dataII_scalerDS, data_scalerDS,
                channels
            ] = pickle.load(f)
        df, fs, ictalStartIndex, ictalStopIndex = DataJson.get_precitalIctalPostictal(
            RID,
            "Ictal",
            idKey,
            username,
            password,
            BIDS=BIDS,
            dataset=datasetiEEG,
            session=session,
Example #9
def calculate_mean_rank_deep_learning(i, patientsWithseizures, version, threshold=0.6, smoothing = 20, model_ID="WN", secondsAfter=180, secondsBefore=180, tanh = False):
    #override_soz: if True and there are no SOZ markings, use the resection markings and assume those are the SOZ contacts
    RID = np.array(patientsWithseizures["subject"])[i]
    idKey = np.array(patientsWithseizures["idKey"])[i]
    seizure_length = patientsWithseizures.length[i]
    
    
    #CHECKING IF SPREAD FILES EXIST

    fname = DataJson.get_fname_ictal(RID, "Ictal", idKey, dataset= datasetiEEG, session = session, startUsec = None, stopUsec= None, startKey = "EEC", secondsBefore = secondsBefore, secondsAfter = secondsAfter )
    
    spread_location = join(BIDS, datasetiEEG_spread, f"v{version:03d}", f"sub-{RID}" )
    spread_location_file_basename = f"{splitext(fname)[0]}_spread.pickle"
    spread_location_file = join(spread_location, spread_location_file_basename)
    
    
    feature_name = "absolute_slope"
    location_feature = join(BIDS, datasetiEEG_spread, "single_features", f"sub-{RID}" )
    location_abs_slope_basename = f"{splitext(fname)[0]}_{feature_name}.pickle"
    location_abs_slope_file = join(location_feature, location_abs_slope_basename)
    
    feature_name = "line_length"
    location_line_length_basename = f"{splitext(fname)[0]}_{feature_name}.pickle"
    location_line_length_file = join(location_feature, location_line_length_basename)
    
    feature_name = "power_broadband"
    location_power_broadband_basename = f"{splitext(fname)[0]}_{feature_name}.pickle"
    location_power_broadband = join(location_feature, location_power_broadband_basename)
   
    
    if utils.checkIfFileExists( spread_location_file , printBOOL=False) and utils.checkIfFileExists( location_abs_slope_file , printBOOL=False):
        #print("\n\n\n\nSPREAD FILE EXISTS\n\n\n\n")
    
    
    
        #Getting SOZ labels
        RID_keys =  list(jsonFile["SUBJECTS"].keys() )
        hup_num_all = [jsonFile["SUBJECTS"][x]["HUP"]  for  x   in  RID_keys]
        
        hup_int = hup_num_all[RID_keys.index(RID)]
        hup_int_pad = f"{hup_int:03d}" 
        
        i_patient = patients.index(f"HUP{hup_int_pad}")
        HUP = patients[i_patient]
        hup = int(HUP[3:])
        
    
        
        channel_names = labels[i_patient]
        soz_ind = np.where(soz[i_patient] == 1)[0]
        soz_channel_names = np.array(channel_names)[soz_ind]
        
        resected_ind = np.where(resect[i_patient] == 1)[0]
        resected_channel_names = np.array(channel_names)[resected_ind]
        
        ignore_ind = np.where(ignore[i_patient] == 1)[0]
        ignore__channel_names = np.array(channel_names)[ignore_ind]
        
        soz_channel_names = echobase.channel2std(soz_channel_names)
        resected_channel_names = echobase.channel2std(resected_channel_names)
        #ignore__channel_names = echobase.channel2std(ignore__channel_names)
        

        
        soz_channel_names = channel2std_ECoG(soz_channel_names)
        resected_channel_names = channel2std_ECoG(resected_channel_names)
        ignore__channel_names = channel2std_ECoG(ignore__channel_names)
        #%
        THRESHOLD = threshold
        SMOOTHING = smoothing #in seconds
        
    
        
        if model_ID == "WN" or model_ID == "CNN" or model_ID == "LSTM":
            with open(spread_location_file, 'rb') as f:
                [probWN, probCNN, probLSTM, data_scalerDS, channels, window,
                 skipWindow, secondsBefore, secondsAfter] = pickle.load(f)
            
        
        if model_ID == "WN":
            #print(model_ID)
            prob_array= probWN
        elif model_ID == "CNN":
            #print(model_ID)
            prob_array= probCNN
        elif model_ID == "LSTM":
            #print(model_ID)
            prob_array= probLSTM
        elif model_ID == "absolute_slope":
            if utils.checkIfFileExists(location_abs_slope_file, printBOOL=False):
                with open(location_abs_slope_file, 'rb') as f:
                    [abs_slope_normalized, abs_slope_normalized_tanh, channels,
                     window, skipWindow, secondsBefore, secondsAfter] = pickle.load(f)
                if not tanh:
                    #abs_slope_normalized = utils.apply_arctanh(abs_slope_normalized_tanh)/1e-1 
                    abs_slope_normalized = abs_slope_normalized/np.max(abs_slope_normalized)
                    prob_array = abs_slope_normalized
                else:
                    prob_array= abs_slope_normalized_tanh
            else: 
                print(f"{i} {RID} file does not exist {location_abs_slope_file}\n")
        elif model_ID == "line_length":
            if utils.checkIfFileExists(location_line_length_file, printBOOL=False):
                with open(location_line_length_file, 'rb') as f:
                    [probLL, probLL_tanh, channels, window, skipWindow,
                     secondsBefore, secondsAfter] = pickle.load(f)
                if not tanh:
                    probLL = probLL/np.max(probLL)
                    prob_array= probLL
                else:
                    prob_array= probLL_tanh
            else: 
                print(f"{i} {RID} file does not exist {location_line_length_file}\n")
        elif model_ID == "power_broadband":
            if utils.checkIfFileExists(location_power_broadband, printBOOL=False):
                with open(location_power_broadband, 'rb') as f:
                    [power_total, power_total_tanh, channels, window,
                     skipWindow, secondsBefore, secondsAfter] = pickle.load(f)
                if not tanh:
                    #power_total = utils.apply_arctanh(power_total_tanh)/7e-2  
                    power_total = power_total/np.max(power_total)
                    prob_array=  power_total
                    
                else:
                    prob_array= power_total_tanh
            
            else: 
                print(f"{i} {RID} file does not exist {location_power_broadband}\n")
        else:
            print("model ID not recognized. Using Wavenet")
            prob_array= probWN
        
        #####
        seizure_start = int((secondsBefore-0)/skipWindow)
        seizure_stop = int((secondsBefore + seizure_length)/skipWindow)
        
        probability_arr_movingAvg, probability_arr_threshold = prob_threshold_moving_avg(prob_array, fsds, skip, threshold = THRESHOLD, smoothing = SMOOTHING)
        #sns.heatmap( probability_arr_movingAvg.T )      
        #sns.heatmap( probability_arr_threshold.T)    
        spread_start, seizure_start, spread_start_loc, channel_order, channel_order_labels = get_start_times(secondsBefore, skipWindow, fsds, channels, 0, seizure_length, probability_arr_threshold)
   
        
        channel_order_labels = remove_EGG_and_ref(channel_order_labels)
        channels2 = remove_EGG_and_ref(channels)
        
        channel_order_labels = channel2std_ECoG(channel_order_labels)
        channels2 = channel2std_ECoG(channels2)
        
        #print(soz_channel_names)
        #print(resected_channel_names)
        #print(channel_order_labels)
    
    
        #remove ignore electrodes from channel_order_labels
        #ignore_index = np.intersect1d(  channel_order_labels, ignore__channel_names, return_indices=True)
        #channel_order_labels[-ignore_index[1]]
        #channel_order_labels = np.delete(channel_order_labels, ignore_index[1])
        
        
        atlas_localization_path = join(paths.BIDS_DERIVATIVES_ATLAS_LOCALIZATION, f"sub-{RID}", f"ses-{session}", f"sub-{RID}_ses-{session}_desc-atlasLocalization.csv")
        if utils.checkIfFileExists(atlas_localization_path, printBOOL=False):
            atlas_localization = pd.read_csv(atlas_localization_path)
            
            
            atlas_localization.channel = channel2std_ECoG(atlas_localization.channel)
            
            
            for r in range(len(atlas_localization)):
                reg_AAL = atlas_localization.AAL_label[r]
                reg_BNA = atlas_localization.BN_Atlas_246_1mm_label[r]
                reg_HO = atlas_localization["HarvardOxford-combined_label"][r]
                
                
        coord_start_times = pd.DataFrame(columns = ["channel", "x", "y", "z", "start_time"])
        
        coord_start_times["channel"] = channels2
        
        for ch in range(len(coord_start_times)):
            x = np.array(atlas_localization[channels2[ch] == atlas_localization.channel]["x"])[0]
            y = np.array(atlas_localization[channels2[ch] == atlas_localization.channel]["y"])[0]
            z = np.array(atlas_localization[channels2[ch] == atlas_localization.channel]["z"])[0]

            coord_start_times.loc[coord_start_times["channel"] == coord_start_times["channel"][ch]   ,    "x"] = x
            coord_start_times.loc[coord_start_times["channel"] == coord_start_times["channel"][ch]   ,    "y"] = y
            coord_start_times.loc[coord_start_times["channel"] == coord_start_times["channel"][ch]   ,    "z"] = z
            
            
            channel_start = spread_start[coord_start_times["channel"][ch] == channels2 ] * skipWindow
            if len(channel_start) > 0:
                channel_start = channel_start[0]
                if channel_start > seizure_length:
                    channel_start = np.nan
            else:
                channel_start = np.nan
            coord_start_times.loc[coord_start_times["channel"] == coord_start_times["channel"][ch]   ,    "start_time"] = channel_start
          
            
        t1_image = glob.glob(join(paths.BIDS_DERIVATIVES_ATLAS_LOCALIZATION, f"sub-{RID}", "ses-implant01", "tmp", "orig_nu_std.nii.gz"))[0]
        t1_image_brain = glob.glob(join(paths.BIDS_DERIVATIVES_ATLAS_LOCALIZATION, f"sub-{RID}", "ses-implant01", "tmp", "brain_std.nii.gz"))[0]
        img = nib.load(t1_image)
        img_brain = nib.load(t1_image_brain)
        #utils.show_slices(img, data_type = "img")
        img_data = img.get_fdata()
        brain_data = img_brain.get_fdata()
        affine = img.affine
        shape = img_data.shape
        img_data_total = copy.deepcopy(img_data)
        img_data_total[  np.where(img_data_total != 0)  ] = 0
        
        img_data_N = copy.deepcopy(img_data)
        img_data_N[  np.where(img_data_N != 0)  ] = 0
        
        
       
        for ch in range(len(coord_start_times)):
            print(f"\r{ch}/{len(coord_start_times)}    ", end = "\r")
            coord = coord_start_times.iloc[ch]
            radius = 40
            
            img_data_sphere = copy.deepcopy(img_data)
            img_data_sphere[  np.where(img_data_sphere != 0)  ] = 0
            
            coordinates = np.array(coord[["x", "y", "z"]]).astype(float)
            coordinates_voxels = utils.transform_coordinates_to_voxel(coordinates, affine)
            x,y,z = coordinates_voxels[0],coordinates_voxels[1],coordinates_voxels[2]
            sphere = utils.make_sphere_from_point(img_data_sphere, x,   y,  z, radius = radius)
            if not np.isnan(coord_start_times.start_time[ch]):
                img_data_N = img_data_N + sphere
                sphere[np.where(sphere >0)] = coord_start_times.start_time[ch]
                img_data_total = img_data_total + sphere
            
            #utils.show_slices(sphere, data_type = "data")
            
        utils.show_slices(img_data_N, data_type = "data")
        utils.show_slices(img_data_total, data_type = "data", cmap = "mako")
        
        
        utils.show_slices(brain_data, data_type = "data")
        
    
    
        #voxels never covered by a sphere divide 0/0 and become nan by design
        with np.errstate(divide="ignore", invalid="ignore"):
            img_data_avg = img_data_total/img_data_N
        img_data_avg[np.where(brain_data <= 0)] = np.nan
        
        utils.show_slices(img, data_type = "img")
        utils.show_slices(img_data_avg, data_type = "data", cmap = "mako")
        
        #img_data_avg[np.isnan(img_data_avg)] = seizure_length
        
        low, middle, high = 0.33,0.48,0.7
        slices_t1 = [   img_data[:, int((img_data.shape[1]*low)), : ] , img_data[:, int(img_data.shape[1]*middle), :] , img_data[:, int(img_data.shape[1]*high), :]   ]
        slices_heat = [   img_data_avg[:, int((img_data_avg.shape[1]*low)), : ] , img_data_avg[:, int(img_data_avg.shape[1]*middle), :] , img_data_avg[:, int(img_data_avg.shape[1]*high), :]   ]
        slices_brain = [   brain_data[:, int((brain_data.shape[1]*low)), : ] , brain_data[:, int(brain_data.shape[1]*middle), :] , brain_data[:, int(brain_data.shape[1]*high), :]   ]
        
        cmap1 = "gray"
        cmap2 = "Wistia_r"
        """
        fig, axes = utils.plot_make()
        #sns.heatmap(slices_t1[1], cmap=cmap1, ax = axes, square = True)
        axes.imshow(slices_t1[1].T, cmap=cmap1, origin="lower")
        pos = axes.imshow(slices_heat[1].T, cmap=cmap2, origin="lower")
        fig.colorbar(pos, ax=axes)
        """
        slice_image = slices_heat[1]
        
        mask = np.where(~np.isnan(slice_image))
        interp = interpolate.NearestNDInterpolator(np.transpose(mask), slice_image[mask])
        filled_data = interp(*np.indices(slice_image.shape))
        
        filled_data_copy_gaussian = scipy.ndimage.gaussian_filter(filled_data, sigma = 2)
        
        filled_data_copy = copy.deepcopy(filled_data_copy_gaussian)
        filled_data_copy[np.where(slices_brain[1] <= 0)] = np.nan

        plt.style.use('default')
        cmap1 = "gray"
        cmap2 = "Spectral"
        fig, axes = utils.plot_make()
        axes.imshow(slices_t1[1].T, cmap=cmap1, origin="lower")
        pos = axes.imshow(filled_data_copy.T, cmap=cmap2, origin="lower")
        fig.colorbar(pos, ax=axes)
    

        
        plt.savefig(join(paths.SEIZURE_SPREAD_FIGURES, "spread_by_coordinates", "spread_by_coordinates2.pdf"))
    def get_FunctionalConnectivity(self,
                                   sub,
                                   idKey,
                                   username,
                                   password,
                                   BIDS,
                                   dataset,
                                   session,
                                   functionalConnectivityPath=None,
                                   secondsBefore=30,
                                   secondsAfter=30,
                                   startKey="EEC",
                                   fsds=256,
                                   montage="bipolar",
                                   FCtype="pearson"):

        #checking if the file is saved; if so, just pull from that file
        #(fname is required below, so an output path must be provided)
        if functionalConnectivityPath is None:
            raise ValueError("functionalConnectivityPath must be provided")
        FCname = utils.baseSplitext(
            self.getEEGfileName(sub, "Ictal", idKey, BIDS, dataset,
                                session, secondsBefore, secondsAfter))[:-4]
        fname = join(
            functionalConnectivityPath, FCname +
            f"functionalConnectivity_{FCtype}_{montage}_interictalPreictalIctalPostictal.pickle"
        )

        if utils.checkIfFileExists(fname, printBOOL=False):
            FC = utils.open_pickle(fname)
            return FC
        else:
            # get data
            AssociatedInterictalidKey = self.get_associatedInterictal(
                sub, idKey)
            seizure, fs, ictalStartIndex, ictalStopIndex = self.get_precitalIctalPostictal(
                sub,
                "Ictal",
                idKey,
                username,
                password,
                BIDS,
                dataset,
                session,
                secondsBefore=180,
                secondsAfter=180,
                load=True)

            interictal, fs = self.get_iEEGData(sub,
                                               "Interictal",
                                               AssociatedInterictalidKey,
                                               username,
                                               password,
                                               BIDS,
                                               dataset,
                                               session,
                                               startKey="Start",
                                               load=True)

            print("\nPreprocessing data: Filtering and Downsampling")
            ###filtering and downsampling
            if FCtype == "coherence":
                #Get Not prewhitened data (coherence measurement should not be used with pre-whitened data)
                _, _, _, seizureFilt, channels = echobase.preprocess(
                    seizure, fs, fs, montage=montage, prewhiten=False)
                _, _, _, interictalFilt, _ = echobase.preprocess(
                    interictal, fs, fs, montage=montage, prewhiten=False)
            else:
                #Get prewhitened data
                _, _, _, seizureFilt, channels = echobase.preprocess(
                    seizure, fs, fs, montage=montage, prewhiten=True)
                _, _, _, interictalFilt, _ = echobase.preprocess(
                    interictal, fs, fs, montage=montage, prewhiten=True)

            ictalStartIndexDS = int(ictalStartIndex * (fsds / fs))
            ictalStopIndexDS = int(ictalStopIndex * (fsds / fs))
            #downsample
            seizureFiltDS = self.downsample(seizureFilt, fs, fsds)
            interictalFiltDS = self.downsample(interictalFilt, fs, fsds)

            data = [
                interictalFiltDS, seizureFiltDS[:ictalStartIndexDS, :],
                seizureFiltDS[ictalStartIndexDS:ictalStopIndexDS + 1, :],
                seizureFiltDS[ictalStopIndexDS + 1:, :]
            ]
            data_FC = []
            if FCtype == "pearson":
                for x in range(4):
                    print(f"\n\n {x}/3")
                    data_FC.append(
                        list(echobase.pearson_wrapper(data[x], fsds)
                             [0]))  #ind at [0] because don't need the p-values
            if FCtype == "crossCorrelation":
                for x in range(4):
                    print(f"\n\n {x}/3")
                    data_FC.append(
                        list(echobase.crossCorrelation_wrapper(data[x], fsds)))
            if FCtype == "coherence":
                for x in range(4):
                    print(f"\n\n {x}/3")
                    data_FC.append(
                        list(echobase.coherence_wrapper(data[x], fsds)))

            FC = [channels, data_FC]
            utils.savePickle([channels, data_FC], fname)
            return FC
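
# Usage sketch (assumptions: `DataJson` is an instance of the class defining
# this method, and credentials/paths are configured as elsewhere in this file):
#
#   FC = DataJson.get_FunctionalConnectivity(
#       sub, idKey, username, password, BIDS, datasetiEEG, session,
#       functionalConnectivityPath=pathFC, montage="bipolar", FCtype="pearson")
#   channels, (interictalFC, preictalFC, ictalFC, postictalFC) = FC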
Example #11
    #os.system(cmd) cannot run this docker command inside a Python interactive shell; run the printed command in a terminal with the qsiprep environment activated

#02 Tractography
for i in range(len(patientsDTI)):
    sub = patientsDTI[i]
    print(f"\n\n\n\n{sub}")
    if utils.getSubType(sub) == "control": ses = "control3T"
    else: ses = "preop3T"
    #input/output paths from QSI prep
    pathDWI = join(pathQSIPREP, "qsiprep", f"sub-{sub}", f"ses-{ses}", "dwi",
                   f"sub-{sub}_ses-{ses}_space-T1w_desc-preproc_dwi.nii.gz")
    pathTracts = join(pathTractography, f"sub-{sub}", "tracts")
    trkName = f"{join(pathTracts, utils.baseSplitextNiiGz(pathDWI)[2])}.trk.gz"
    utils.checkPathAndMake(pathTractography, pathTracts)

    if not utils.checkIfFileExists(trkName, returnOpposite=False):
        t0 = time.time()
        tractography.getTracts(BIDS, dsiStudioSingularityPatah, pathDWI,
                               pathTracts)
        utils.calculateTimeToComplete(t0, time.time(), len(patientsDTI), i)

#%% 06 Structural Connectivity for all atlases


#01 Atlas Registration
def atlasRegistrationBatch(patientsDTI, i):
    #for i in range(len(patientsDTI)):
    sub = patientsDTI[i]

    #input/output paths from QSI prep
    preop3T = join(pathQSIPREP, "qsiprep", f"sub-{sub}", "anat",
Example #12
def mni_registration_to_T1(sub,
                           paths,
                           SESSION_RESEARCH3T="research3Tv[0-9][0-9]"):

    #getting directories and files

    #structural connectivity derivatives
    atlas_registration, structural_matrices, mni_images, mni_warp, ses, preop_T1, preop_T1_std, path_fib, path_trk = get_atlas_registration_and_tractography_paths(
        sub=sub, paths=paths, SESSION_RESEARCH3T=SESSION_RESEARCH3T)

    if not utils.checkIfFileExists(f"{mni_images}_fnirt.nii.gz",
                                   printBOOL=False):

        #QSI prep outputs
        preop_T1 = os.path.join(paths.BIDS_DERIVATIVES_QSIPREP, "qsiprep",
                                f"sub-{sub}", "anat",
                                f"sub-{sub}_desc-preproc_T1w.nii.gz")
        preop_T1_mask = os.path.join(paths.BIDS_DERIVATIVES_QSIPREP, "qsiprep",
                                     f"sub-{sub}", "anat",
                                     f"sub-{sub}_desc-brain_mask.nii.gz")
        preop_T1_bet = os.path.join(
            atlas_registration, f"sub-{sub}_desc-preproc_T1w_brain.nii.gz")

        #checking all images exist
        utils.checkPathError(preop_T1)
        utils.checkPathError(preop_T1_mask)
        utils.checkPathError(paths.MNI_TEMPLATE)
        utils.checkPathError(paths.MNI_TEMPLATE_BRAIN)

        #copying relevant images from QSI prep to registration folder
        cmd = f"cp  {preop_T1} {os.path.join(atlas_registration, os.path.basename(preop_T1))}"
        os.system(cmd)
        cmd = f"cp  {preop_T1_mask} {os.path.join(atlas_registration, os.path.basename(preop_T1_mask))}"
        os.system(cmd)
        cmd = f"fslmaths  {preop_T1} -mul {preop_T1_mask} {preop_T1_bet}"
        os.system(cmd)

        #intermediate file output names
        preop_T1 = os.path.join(atlas_registration, os.path.basename(preop_T1))
        preop_T1_mask = os.path.join(atlas_registration,
                                     os.path.basename(preop_T1_mask))

        preop_T1_std = f"{splitext(splitext(preop_T1)[0])[0]}_std.nii.gz"
        preop_T1_bet_std = f"{splitext(splitext(preop_T1_bet)[0])[0]}_std.nii.gz"

        #Begin Pipeline: Orient all images to standard RAS

        fillerString = "\n###########################" * 3
        print(
            f"\n\n{fillerString}\nPart 1 of 2\nReorientation of Images\nEstimated time: 10-30 seconds{fillerString}\nReorient all images to standard RAS\n"
        )
        cmd = f"fslreorient2std {preop_T1} {preop_T1_std}"
        print(cmd)
        os.system(cmd)
        cmd = f"fslreorient2std {preop_T1_bet} {preop_T1_bet_std}"
        print(cmd)
        os.system(cmd)

        #visualize
        utils.show_slices(f"{preop_T1_std}",
                          low=0.33,
                          middle=0.5,
                          high=0.66,
                          save=True,
                          saveFilename=join(atlas_registration, "pic_T1.png"))
        utils.show_slices(f"{preop_T1_bet_std}",
                          low=0.33,
                          middle=0.5,
                          high=0.66,
                          save=True,
                          saveFilename=join(atlas_registration,
                                            "pic_T1_brain.png"))

        #Registration of MNI to patient space (atlases are all in MNI space, so using this warp to apply to the atlases)
        print(
            f"\n\n{fillerString}\nPart 2 of 2\nMNI and atlas registration\nEstimated time: 1-2+ hours{fillerString}\nRegistration of MNI template to patient space\n"
        )

        #linear reg of MNI to preopT1 space
        if not utils.checkIfFileExists(f"{mni_images}_flirt.nii.gz",
                                       printBOOL=False):
            cmd = f"flirt -in {paths.MNI_TEMPLATE_BRAIN} -ref {preop_T1_bet_std} -dof 12 -out {mni_images}_flirt -omat {mni_images}_flirt.mat -v"
            print(cmd)
            os.system(cmd)
        #non linear reg of MNI to preopT1 space
        utils.show_slices(f"{mni_images}_flirt.nii.gz",
                          low=0.33,
                          middle=0.5,
                          high=0.66,
                          save=True,
                          saveFilename=join(atlas_registration,
                                            "pic_mni_to_T1_flirt.png"))
        print(
            "\n\nLinear registration of MNI template to image is done\n\nStarting Non-linear registration:\n\n\n"
        )
        if not utils.checkIfFileExists(f"{mni_images}_fnirt.nii.gz",
                                       printBOOL=False):
            cmd = f"fnirt --in={paths.MNI_TEMPLATE} --ref={preop_T1_std} --aff={mni_images}_flirt.mat --iout={mni_images}_fnirt -v --cout={mni_images}_coef --fout={mni_images}_warp"
            print(cmd)
            os.system(cmd)
        utils.show_slices(f"{mni_images}_fnirt.nii.gz",
                          low=0.33,
                          middle=0.5,
                          high=0.66,
                          save=True,
                          saveFilename=join(atlas_registration,
                                            "pic_mni_to_T1_fnirt.png"))
        print(f"\n\n{fillerString}\nDone{fillerString}\n\n\n\n")
    else:
        print("\n\n\n\nMNI registration already performed\n\n\n\n")
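
# Usage sketch (assuming `paths` is the project-wide paths module used above):
#
#   for sub in patientsDTI:
#       mni_registration_to_T1(sub, paths)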