def load_small_stim_pkl(stim_pkl, runtype="prod"):
    """
    load_small_stim_pkl(stim_pkl)

    Loads a smaller stimulus dictionary from the stimulus pickle file in which
    "posbyframe" for visual flow stimuli is not included.

    If it does not exist, small stimulus dictionary is created and saved as a
    pickle with "_small" appended to name.

    Reduces the pickle size about 10 fold.

    Required args:
        - stim_pkl (Path): full path name for the full stimulus pickle file

    Optional args:
        - runtype (str): runtype ("prod" or "pilot")
                         default: "prod"

    Returns:
        - stim_dict (dict): stimulus dictionary, either loaded from the
                            existing small pickle, or loaded from the full
                            pickle with the position-by-frame entries removed
    """

    stim_pkl = Path(stim_pkl)
    stim_pkl_no_ext = Path(stim_pkl.parent, stim_pkl.stem)
    small_stim_pkl_name = Path(f"{stim_pkl_no_ext}_small.pkl")

    # the small pickle already exists: nothing to create
    if small_stim_pkl_name.is_file():
        return file_util.loadfile(small_stim_pkl_name)

    # validate the runtype BEFORE loading the (large) full pickle, so that an
    # invalid value fails fast
    if runtype == "pilot":
        stim_par_key = "stimParams"
    elif runtype == "prod":
        stim_par_key = "stim_params"
    else:
        # presumably raises an error listing the accepted values
        gen_util.accepted_values_error("runtype", runtype, ["prod", "pilot"])

    logger.info("Creating smaller stimulus pickle.", extra={"spacing": TAB})
    stim_dict = file_util.loadfile(stim_pkl)

    # drop the position-by-frame data, which is stored under a different key
    # (and at a different nesting level) for pilot and production data
    for stimulus in stim_dict["stimuli"]:
        stim_par = stimulus[stim_par_key]
        if runtype == "pilot" and "posByFrame" in stim_par:
            stim_par.pop("posByFrame")
        elif runtype == "prod" and "square_params" in stim_par:
            stim_par["session_params"].pop("posbyframe")

    file_util.saveinfo(stim_dict, small_stim_pkl_name)

    return stim_dict
def get_roi_locations(roi_extract_dict):
    """
    get_roi_locations(roi_extract_dict)

    Returns ROI locations, extracted from ROI extraction dictionary.

    Required args:
        - roi_extract_dict (dict): ROI extraction dictionary (anything that is
                                   not a dictionary is passed to
                                   file_util.loadfile() to be loaded)

    Returns:
        - roi_locations (pd DataFrame): ROI locations dataframe, with one row
                                        per ROI and columns "id", "x", "y",
                                        "width", "height", "valid", "mask"
    """

    if not isinstance(roi_extract_dict, dict):
        roi_extract_dict = file_util.loadfile(roi_extract_dict)

    # single source of truth for the fields pulled from each ROI entry, and
    # for the dataframe column names
    columns = ["id", "x", "y", "width", "height", "valid", "mask"]

    # get data out of json and into dataframe (a missing field raises a
    # KeyError, rather than being silently filled in)
    roi_locations_list = [
        [roi[column] for column in columns]
        for roi in roi_extract_dict["rois"]
    ]

    roi_locations = pd.DataFrame(data=roi_locations_list, columns=columns)

    return roi_locations
def load_info_from_mouse_df(sessid, mouse_df="mouse_df.csv"):
    """
    load_info_from_mouse_df(sessid)

    Returns dictionary containing information from the mouse dataframe, for
    the requested session.

    Required args:
        - sessid (int): session ID

    Optional args:
        - mouse_df (Path): path name of dataframe containing information on
                           each session (or the dataframe itself, already
                           loaded). The following columns are read:
                               sessid, mouse_n, dandi_session_id, date, depth,
                               plane, line, mouseid, runtype, sess_n,
                               stim_seed, pass_fail, all_files, any_files,
                               notes
                           default: "mouse_df.csv"

    Returns:
        - df_dict (dict): dictionary with following keys:
            - all_files (bool): if True, all files have been acquired for
                                the session
            - any_files (bool): if True, some files have been acquired for
                                the session
            - dandi_id (str)  : Dandi session ID (read from the
                                "dandi_session_id" column)
            - date (int)      : session date (i.e., yyyymmdd), cast to int
            - depth           : recording depth (as stored in the dataframe)
            - plane (str)     : recording plane ("soma" or "dend")
            - line (str)      : mouse line (e.g., "L5-Rbp4")
            - mouse_n (int)   : mouse number (e.g., 1)
            - mouseid (int)   : mouse ID (6 digits)
            - notes (str)     : notes from the dataframe on the session
            - pass_fail (str) : whether session passed "P" or failed "F"
                                quality control
            - runtype (str)   : "prod" (production) or "pilot" data
            - sess_n (int)    : overall session number (e.g., 1)
            - stim_seed (int) : random seed used to generated stimulus
    """

    if isinstance(mouse_df, (str, Path)):
        mouse_df = file_util.loadfile(mouse_df)

    df_line = gen_util.get_df_vals(mouse_df, "sessid", sessid, single=True)

    def first_val(column):
        """Returns the first value found in the requested column."""
        return df_line[column].tolist()[0]

    def first_bool(column):
        """Returns the first value in the column, as int-cast boolean."""
        return bool(int(first_val(column)))

    df_dict = {
        "mouse_n"  : int(first_val("mouse_n")),
        "dandi_id" : first_val("dandi_session_id"),
        "date"     : int(first_val("date")),
        "depth"    : first_val("depth"),
        "plane"    : first_val("plane"),
        "line"     : first_val("line"),
        "mouseid"  : int(first_val("mouseid")),
        "runtype"  : first_val("runtype"),
        "sess_n"   : int(first_val("sess_n")),
        "stim_seed": int(first_val("stim_seed")),
        "pass_fail": first_val("pass_fail"),
        "all_files": first_bool("all_files"),
        "any_files": first_bool("any_files"),
        "notes"    : first_val("notes"),
    }

    return df_dict
def plot_from_dict(direc, plt_bkend=None, fontdir=None):
    """
    plot_from_dict(direc)

    Plots data from the hyperparameters dictionary ("hyperparameters.json")
    stored in the specified directory.

    Required args:
        - direc (Path): path to directory in which dictionaries to plot data
                        from are located

    Optional_args:
        - plt_bkend (str): mpl backend to use for plotting (e.g., "agg")
                           default: None
        - fontdir (Path) : directory in which additional fonts are stored
                           default: None
    """

    # logged before the path conversion, so that the path appears as passed
    logger.info(
        f"Plotting from hyperparameters in: {direc}",
        extra={"spacing": "\n"}
        )

    results_dir = Path(direc)

    plot_util.manage_mpl(plt_bkend, fontdir=fontdir)

    hyperpars = file_util.loadfile("hyperparameters.json", fulldir=results_dir)

    # only logistic regression results are plotted
    if "logregpar" in hyperpars:
        plot_traces_scores(hyperpars, savedir=results_dir)

    plot_util.cond_close_figs()
def load_basic_stimulus_table(stim_dict, stim_sync_h5, time_sync_h5, align_pkl,
                              sessid, runtype="prod"):
    """
    load_basic_stimulus_table(stim_dict, stim_sync_h5, time_sync_h5, align_pkl,
                              sessid)

    Loads the alignment dataframe (stim_df) from the alignment pickle, first
    creating and saving the pickle if it does not already exist.

    Arguments:
        - stim_dict (dict)   : experiment stim dictionary, loaded from pickle
        - stim_sync_h5 (Path): full path name of the experiment sync hdf5 file
        - time_sync_h5 (Path): full path name of the time synchronization hdf5
                               file
        - align_pkl (Path)   : full path name of the output pickle file to
                               create
        - sessid (int)       : session ID, needed the check whether this
                               session needs to be treated differently (e.g.,
                               for alignment bugs)

    Optional arguments:
        - runtype (str): runtype ("prod" or "pilot")
                         default: "prod"

    Returns:
        - df (pandas)            : basic stimulus table
        - stim_align (1D array)  : stimulus to 2p alignment array, cast to int
    """

    align_pkl = Path(align_pkl)

    # create the alignment pickle, if it doesn't exist yet
    pickle_missing = not align_pkl.is_file()
    if pickle_missing:
        logger.info(
            f"Stimulus alignment pickle not found in {align_pkl.parent}, and "
            "will be created.",
            extra={"spacing": TAB}
            )
        sess_sync_util.get_stim_frames(
            stim_dict, stim_sync_h5, time_sync_h5, align_pkl, sessid, runtype
            )

    align_dict = file_util.loadfile(align_pkl)

    return align_dict["stim_df"], align_dict["stim_align"].astype(int)
def plot_from_dict(dict_path, plt_bkend=None, fontdir=None, parallel=False,
                   datetime=True, overwrite=False):
    """
    plot_from_dict(dict_path)

    Plots data from dictionaries containing analysis parameters and results.

    The plotting function is selected from the analysis code stored under
    info["extrapar"]["analysis"]: "c" (pupil difference correlation) or
    "r" (difference correlation per ROI between stimuli). A warning is
    issued for any other code.

    Required args:
        - dict_path (Path): path to dictionary to plot data from

    Optional_args:
        - plt_bkend (str) : mpl backend to use for plotting (e.g., "agg")
                            default: None
        - fontdir (Path)  : path to directory where additional fonts are
                            stored
                            default: None
        - parallel (bool) : if True, some of the analysis is parallelized
                            across CPU cores (not used in this function)
                            default: False
        - datetime (bool) : figpar["save"] datatime parameter (whether to
                            place figures in a datetime folder)
                            default: True
        - overwrite (bool): figpar["save"] overwrite parameter (whether to
                            overwrite figures)
                            default: False
    """

    logger.info(
        f"Plotting from dictionary: {dict_path}", extra={"spacing": "\n"}
        )

    figpar = sess_plot_util.init_figpar(
        plt_bkend=plt_bkend,
        fontdir=fontdir,
        datetime=datetime,
        overwrite=overwrite,
        )
    plot_util.manage_mpl(cmap=False, **figpar["mng"])

    dict_path = Path(dict_path)
    info = file_util.loadfile(dict_path)
    savedir = dict_path.parent

    analysis = info["extrapar"]["analysis"]

    # select the plotting function for the analysis code
    plot_fct = None
    if analysis == "c": # pupil vs roi/run difference correlation
        plot_fct = plot_pup_diff_corr
    elif analysis == "r": # difference correlation per ROI between stimuli
        plot_fct = plot_pup_roi_stim_corr

    if plot_fct is None:
        warnings.warn(f"No plotting function for analysis {analysis}",
            category=UserWarning, stacklevel=1)
    else:
        plot_fct(figpar=figpar, savedir=savedir, **info)

    plot_util.cond_close_figs()
def get_motion_border(roi_extract_dict):
    """
    get_motion_border(roi_extract_dict)

    Returns motion border for motion corrected stack.

    Required args:
        - roi_extract_dict (dict): ROI extraction dictionary (anything that is
                                   not a dictionary is passed to
                                   file_util.loadfile() to be loaded)

    Returns:
        - motion border (list): motion border values, ordered as the
                                "x0", "x1", "y0", "y1" entries of
                                roi_extract_dict["motion_border"]
                                (NOTE(review): presumably right, left, down,
                                up shifts - confirm which y entry is which)
    """

    if not isinstance(roi_extract_dict, dict):
        roi_extract_dict = file_util.loadfile(roi_extract_dict)

    border_dict = roi_extract_dict["motion_border"]

    motion_border = []
    for coord in ("x0", "x1", "y0", "y1"):
        motion_border.append(border_dict[coord])

    return motion_border
def get_sessid_from_mouse_df(mouse_n=1, sess_n=1, runtype="prod",
                             mouse_df="mouse_df.csv"):
    """
    get_sessid_from_mouse_df()

    Returns session ID, based on the mouse number, session number, and
    runtype, based on the mouse dataframe.

    Optional args:
        - mouse_n (int)  : mouse number
                           default: 1
        - sess_n (int)   : session number
                           default: 1
        - runtype (str)  : type of data
                           default: "prod"
        - mouse_df (Path): path name of dataframe containing information on
                           each session (or the dataframe itself, already
                           loaded). Dataframe should have the following
                           columns:
                               sessid, mouse_n, sess_n, runtype
                           default: "mouse_df.csv"

    Returns:
        - sessid (int): session ID
    """

    if isinstance(mouse_df, (str, Path)):
        mouse_df = file_util.loadfile(mouse_df)

    # criteria identifying the session (mouse and session numbers are
    # cast to int before the lookup)
    labels = ["mouse_n", "sess_n", "runtype"]
    values = [int(mouse_n), int(sess_n), runtype]

    df_line = gen_util.get_df_vals(mouse_df, labels, values, single=True)

    return int(df_line["sessid"].tolist()[0])
def check_if_data_exists(figpar, filetype="json", overwrite_plot_only=False,
                         raise_no_data=True):
    """
    check_if_data_exists(figpar)

    Returns whether to rerun analysis, depending on whether data file already
    exists and figpar["save"]["overwrite"] is True or False.

    If the data file exists and is not to be overwritten, the figure is
    replotted from the existing data, using the plotting function of the
    figure/panel analysis object.

    Required args:
        - figpar (dict): dictionary containing figure parameters
            ["fig_panel_analysis"] (FigPanelAnalysis): figure/panel analysis
                                                       object
            ["dirs"]["figdir"] (Path)                : figure directory
            ["save"]["overwrite"] (bool)             : whether to overwrite
                                                       data and figure files

    Optional args:
        - filetype (str)            : type of data file expected
                                      default: "json"
        - overwrite_plot_only (bool): if True, data is replotted only.
                                      default: False
        - raise_no_data (bool)      : if True, an error is raised if
                                      overwrite_plot_only is True, but no
                                      analysis data is found.
                                      default: True

    Returns:
        - run_analysis (bool): if True, analysis should be run
        - data_path (Path)   : path to data (whether it exists, or not)
    """

    fig_panel_analysis = figpar["fig_panel_analysis"]

    savedir, savename = get_save_path(
        fig_panel_analysis, main_direc=figpar["dirs"]["figdir"]
        )
    datadir = get_datafile_save_path(savedir)
    data_path = file_util.add_ext(datadir.joinpath(savename), filetype)[0]

    # no existing data: run the analysis, unless only replotting was requested
    if not data_path.is_file():
        if overwrite_plot_only and raise_no_data:
            raise RuntimeError(
                "overwrite_plot_only is True, but no analysis data was found "
                f"under {data_path}")
        return True, data_path

    exists_str = f"Analysis data already exists under {data_path}."

    # existing data is to be overwritten: run the analysis
    if figpar["save"]["overwrite"] and not overwrite_plot_only:
        logger.warning(
            f"{exists_str}\nFile will be overwritten.",
            extra={"spacing": "\n"}
            )
        return True, data_path

    # existing data is kept: replot from it, and skip the analysis
    logger.warning(
        f"{exists_str}\nReplotting from existing file.\n"
        "To overwrite file, run script with the '--overwrite' "
        "argument, and without --plot_only.",
        extra={"spacing": "\n"}
        )
    info = file_util.loadfile(data_path)
    fig_panel_analysis.plot_fct(figpar=figpar, **info)

    return False, data_path
def plot_from_dict(dict_path, plt_bkend=None, fontdir=None, plot_tc=True,
                   parallel=False, datetime=True, overwrite=False):
    """
    plot_from_dict(dict_path)

    Plots data from dictionaries containing analysis parameters and results,
    using the modified plotting options (modif=True).

    The plotting function is selected from the analysis code stored under
    info["extrapar"]["analysis"]: "t" (traces) or "l" (unexpected locked
    traces). A warning is issued for any other code.

    Required args:
        - dict_path (Path): path to dictionary to plot data from

    Optional_args:
        - plt_bkend (str) : mpl backend to use for plotting (e.g., "agg")
                            default: None
        - fontdir (Path)  : path to directory where additional fonts are
                            stored
                            default: None
        - plot_tc (bool)  : if True, tuning curves are plotted for each ROI
                            (dummy argument)
                            default: True
        - parallel (bool) : if True, some of the analysis is parallelized
                            across CPU cores (not used in this function)
                            default: False
        - datetime (bool) : figpar["save"] datatime parameter (whether to
                            place figures in a datetime folder)
                            default: True
        - overwrite (bool): figpar["save"] overwrite parameter (whether to
                            overwrite figures)
                            default: False
    """

    logger.info(
        f"Plotting from dictionary: {dict_path}", extra={"spacing": "\n"}
        )

    figpar = sess_plot_util.init_figpar(
        plt_bkend=plt_bkend,
        fontdir=fontdir,
        datetime=datetime,
        overwrite=overwrite,
        )
    plot_util.manage_mpl(cmap=False, **figpar["mng"])

    # enlarge figure and axis titles
    for rc_key in ("figure.titlesize", "axes.titlesize"):
        plt.rcParams[rc_key] = "xx-large"

    dict_path = Path(dict_path)
    info = file_util.loadfile(dict_path)
    savedir = dict_path.parent

    analysis = info["extrapar"]["analysis"]

    # select the plotting function for the analysis code
    plot_fct = None
    if analysis == "t":
        # 1. average traces by quantile x unexpected for each session
        plot_fct = gen_plots.plot_traces_by_qu_unexp_sess
    elif analysis == "l":
        # 2. average traces by quantile, locked to unexpected for each session
        plot_fct = gen_plots.plot_traces_by_qu_lock_sess

    if plot_fct is None:
        warnings.warn(f"No modified plotting option for analysis {analysis}",
            category=UserWarning, stacklevel=1)
    else:
        plot_fct(figpar=figpar, savedir=savedir, modif=True, **info)

    plot_util.cond_close_figs()
def get_run_velocity(stim_sync_h5, stim_pkl="", stim_dict=None, filter_ks=5):
    """
    get_run_velocity(stim_sync_h5)

    Adapted from allensdk.brain_observatory.running_processing.__main__.main().
    Loads and calculates the linear running velocity from the raw running data.

    Required args:
        - stim_sync_h5 (Path): full path name of the stimulus sync h5 file

    Optional args:
        - stim_pkl (Path) : full path name of the experiment stim pickle file
                            default: ""
        - stim_dict (dict): stimulus dictionary, with keys "fps" and "items",
                            from which running velocity is extracted.
                            If not None, overrides stim_pkl.
                            default: None
        - filter_ks (int) : kernel size to use in median filtering the linear
                            running velocity (0 to skip filtering).
                            default: 5

    Returns:
        - running_velocity (array): array of length equal to the number of
                                    stimulus frames, each element corresponds
                                    to the linear running velocity for that
                                    stimulus frame
    """

    # the stimulus dictionary is only loaded if it was not passed directly
    if stim_dict is None:
        if stim_pkl == "":
            raise ValueError("Must provide either the pickle file name or the "
                "stimulus dictionary.")

        # check that the pickle file exists, before reading it
        file_util.checkfile(stim_pkl)
        stim_dict = file_util.loadfile(stim_pkl)

    # occasionally an extra set of frame times are acquired after the rest of
    # the signals. These are detected and removed.
    stim_fr_timestamps = sync_utilities.trim_discontiguous_times(
        get_stim_fr_timestamps(stim_sync_h5)
        )
    num_raw_timestamps = len(stim_fr_timestamps)

    raw_running_deg = running_main.running_from_stim_file(
        stim_dict, "dx", num_raw_timestamps
        )

    num_rotation_samples = len(raw_running_deg)
    if num_raw_timestamps != num_rotation_samples:
        raise ValueError(
            f"found {num_raw_timestamps} rising edges on the vsync line, "
            f"but only {num_rotation_samples} rotation samples")

    # for running alignment test analyses, the median duration is used, and
    # filtering is skipped
    if TEST_RUNNING_BLIPS:
        logger.warning("Pre-processing running data using median duration "
            "and no filter, for testing purposes.")
        velocity_kwargs = {"use_median_duration": True, "filter_ks": 0}
    else:
        velocity_kwargs = {
            "use_median_duration": False, "filter_ks": filter_ks
            }

    return calculate_running_velocity(
        stim_fr_timestamps=stim_fr_timestamps,
        raw_running_deg=raw_running_deg,
        wheel_radius=WHEEL_RADIUS,
        subject_position=SUBJECT_POSITION,
        **velocity_kwargs
        )
def plot_from_dict(dict_path, plt_bkend=None, fontdir=None, parallel=False,
                   datetime=True, overwrite=False):
    """
    plot_from_dict(dict_path)

    Plots data from dictionaries containing analysis parameters and results.

    The plotting function is selected from the analysis code stored under
    info["extrapar"]["analysis"]: "f" (full traces), "t" (traces),
    "l" (unexpected locked traces), "m" (magnitude change) or
    "a" (autocorrelations). A warning is issued for any other code.

    Required args:
        - dict_path (Path): path to dictionary to plot data from

    Optional_args:
        - plt_bkend (str) : mpl backend to use for plotting (e.g., "agg")
                            default: None
        - fontdir (Path)  : path to directory where additional fonts are
                            stored
                            default: None
        - parallel (bool) : if True, some of the plotting is parallelized
                            across CPU cores (not used in this function)
                            default: False
        - datetime (bool) : figpar["save"] datatime parameter (whether to
                            place figures in a datetime folder)
                            default: True
        - overwrite (bool): figpar["save"] overwrite parameter (whether to
                            overwrite figures)
                            default: False
    """

    logger.info(
        f"Plotting from dictionary: {dict_path}", extra={"spacing": "\n"}
        )

    figpar = sess_plot_util.init_figpar(
        plt_bkend=plt_bkend,
        fontdir=fontdir,
        datetime=datetime,
        overwrite=overwrite,
        )
    plot_util.manage_mpl(cmap=False, **figpar["mng"])

    dict_path = Path(dict_path)
    info = file_util.loadfile(dict_path)
    savedir = dict_path.parent

    analysis = info["extrapar"]["analysis"]

    # select the plotting function for the analysis code
    plot_fct = None
    if analysis == "f":
        # 0. full traces for each session
        plot_fct = plot_full_traces
    elif analysis == "t":
        # 1. average traces by quantile x unexpected for each session
        plot_fct = plot_traces_by_qu_unexp_sess
    elif analysis == "l":
        # 2. average traces by quantile, locked to unexpected for each session
        plot_fct = plot_traces_by_qu_lock_sess
    elif analysis == "m":
        # 3. magnitude of change in dF/F area from first to last quantile of
        # unexpected vs no unexpected sequences, for each session
        plot_fct = plot_mag_change
    elif analysis == "a":
        # 4. autocorrelations
        plot_fct = plot_autocorr

    if plot_fct is None:
        warnings.warn(f"No plotting function for analysis {analysis}",
            category=UserWarning, stacklevel=1)
    else:
        plot_fct(figpar=figpar, savedir=savedir, **info)

    plot_util.cond_close_figs()
def sess_per_mouse(mouse_df, mouse_n="any", sess_n=1, runtype="prod",
                   plane="any", line="any", pass_fail="P", incl="yes",
                   all_files=1, any_files=1, min_rois=1, closest=False,
                   omit_sess=[], omit_mice=[]):
    """
    sess_per_mouse(mouse_df)

    Returns list of session IDs (up to 1 per mouse) that fit the specified
    criteria.

    Required args:
        - mouse_df (Path): path name of dataframe containing information on
                           each session (or the dataframe itself, already
                           loaded)

    Optional args:
        - mouse_n (int or str)   : mouse number(s) of interest
                                   default: "any"
        - sess_n (int or str)    : session number(s) of interest
                                   (1, 2, 3, ... or "first", "last"; -1 is
                                   treated like "last")
                                   default: 1
        - runtype (str or list)  : runtype value(s) of interest
                                   ("pilot", "prod")
                                   default: "prod"
        - plane (str or list)    : plane value(s) of interest
                                   ("soma", "dend", "any")
                                   default: "any"
        - line (str or list)     : line value(s) of interest
                                   ("L5", "L23", "any")
                                   default: "any"
        - pass_fail (str or list): pass/fail values of interest
                                   ("P", "F", "any")
                                   default: "P"
        - incl (str)             : which sessions to include
                                   ("yes", "no", "any")
                                   default: "yes"
        - all_files (int or list): all_files values of interest (0, 1)
                                   default: 1
        - any_files (int or list): any_files values of interest (0, 1)
                                   default: 1
        - min_rois (int)         : min number of ROIs
                                   default: 1
        - closest (bool)         : if False, only exact session number is
                                   retained, otherwise the closest
                                   default: False
        - omit_sess (list)       : sessions to omit (not modified here)
                                   default: []
        - omit_mice (list)       : mice to omit (not modified here)
                                   default: []

    Returns:
        - sessids (list): sessions to analyse (1 per mouse)

    Raises:
        - ValueError   : if runtype is "any", or if sess_n is neither
                         "first", "last", nor convertible to an integer
        - RuntimeError : if no sessions meet the criteria
    """

    if isinstance(mouse_df, (str, Path)):
        mouse_df = file_util.loadfile(mouse_df)

    # identify label-type session requests BEFORE any integer conversion, as
    # int("first") or int("last") would raise a ValueError
    pick_first = str(sess_n) == "first"
    pick_last = str(sess_n) in ["last", "-1"]

    # numerical session number (only needed to find the closest session)
    orig_sess_n = None
    if not (pick_first or pick_last):
        orig_sess_n = int(sess_n)

    # consider all sessions, if the exact session number is not required
    if closest or pick_first or pick_last:
        sess_n = gen_util.get_df_label_vals(mouse_df, "sess_n", "any")

    if runtype == "any":
        raise ValueError("Must specify runtype (cannot be any), as there is "
            "overlap in mouse numbers.")

    # get list of mice that fit the criteria
    mouse_ns = get_sess_vals(
        mouse_df, "mouse_n", mouse_n, sess_n, runtype, plane, line, pass_fail,
        incl, all_files, any_files, min_rois, omit_sess, omit_mice,
        unique=True, sort=True
        )

    # get session ID each mouse based on criteria
    sessids = []
    for i in sorted(mouse_ns):
        sess_ns = get_sess_vals(
            mouse_df, "sess_n", i, sess_n, runtype, plane, line, pass_fail,
            incl, all_files, any_files, min_rois, omit_sess, omit_mice,
            sort=True
            )

        # skip mouse if no sessions meet criteria
        if len(sess_ns) == 0:
            continue

        # "last" must be checked before "not closest", as otherwise the first
        # session is retained when closest is False
        if pick_last:
            n = sess_ns[-1]
        # first session, or the exact session number (in which case, sess_ns
        # only contains sessions with that number)
        elif pick_first or not closest:
            n = sess_ns[0]
        # find closest sess number among possible sessions
        else:
            n = sess_ns[np.argmin(
                np.absolute([x - orig_sess_n for x in sess_ns])
                )]

        sessid = get_sess_vals(
            mouse_df, "sessid", i, n, runtype, plane, line, pass_fail, incl,
            all_files, any_files, min_rois, omit_sess, omit_mice
            )[0]
        sessids.append(sessid)

    if len(sessids) == 0:
        raise RuntimeError("No sessions meet the criteria.")

    return sessids
def get_stim_frames(pkl_file_name, stim_sync_h5, time_sync_h5, df_pkl_name,
                    sessid, runtype="prod"):
    """
    get_stim_frames(pkl_file_name, stim_sync_h5, time_sync_h5, df_pkl_name,
                    sessid)

    Pulls out the stimulus frame information from the stimulus pickle file, as
    well as synchronization information from the stimulus sync file, and stores
    synchronized stimulus frame information in the output pickle file along
    with the stimulus alignment array.

    Required args:
        - pkl_file_name (Path): full path name of the experiment stim pickle
                                file (or the stimulus dictionary itself,
                                already loaded)
        - stim_sync_h5 (Path) : full path name of the experiment sync hdf5 file
        - time_sync_h5 (Path) : full path to the time synchronization hdf5 file
        - df_pkl_name (Path)  : full path name of the output pickle file to
                                create
        - sessid (int)        : session ID, needed the check whether this
                                session needs to be treated differently
                                (e.g., for alignment bugs)

    Optional argument:
        - runtype (str): the type of run, either "pilot" or "prod"
                         default: "prod"

    Raises:
        - ValueError: if runtype is not recognized, if the number of stimuli
                      or of segments is not as expected, if a stimulus type
                      is not recognized, or if 2P frames are associated with
                      two stimulus segments
        - OSError   : if the output pickle could not be saved
    """

    # read the pickle file and call it "pkl"
    if isinstance(pkl_file_name, dict):
        pkl = pkl_file_name
    else:
        # check that the pickle file exists
        file_util.checkfile(pkl_file_name)
        pkl = file_util.loadfile(pkl_file_name, filetype="pickle")

    if runtype == "pilot":
        num_stimtypes = 2 # visual flow (bricks) and Gabors
        stim_par_key = "stimParams"
    elif runtype == "prod":
        num_stimtypes = 3 # 2 visual flow (bricks) and 1 set of Gabors
        stim_par_key = "stim_params"
    else:
        # an explicit error, instead of an unbound variable error further down
        raise ValueError(
            f"runtype must be 'pilot' or 'prod', but found {runtype!r}."
            )

    if len(pkl["stimuli"]) != num_stimtypes:
        raise ValueError(f"{num_stimtypes} stimuli types expected, but "
            f"{len(pkl['stimuli'])} found.")

    # get dataset object, sample frequency and vsyncs
    stim_vsync_fall_adj, valid_twop_vsync_fall = get_vsync_falls(stim_sync_h5)

    # calculate the alignment
    logger.info("Calculating stimulus alignment.")
    stimulus_alignment = Dataset2p.calculate_stimulus_alignment(
        stim_vsync_fall_adj, valid_twop_vsync_fall)

    # get the second stimulus alignment
    # (imported locally, presumably to avoid a circular import - confirm)
    from sess_util.sess_load_util import load_beh_sync_h5_data
    second_stimulus_alignment = load_beh_sync_h5_data(time_sync_h5)[2]

    # trim the second alignment, if it is exactly one value longer
    if len(second_stimulus_alignment) == len(stimulus_alignment) + 1:
        second_stimulus_alignment = second_stimulus_alignment[:-1]

    if int(sessid) in ADJUST_SECOND_ALIGNMENT:
        diff = second_stimulus_alignment - stimulus_alignment
        # most frequent difference. np.ravel() allows for both array modes
        # (older scipy versions) and scalar modes (scipy >= 1.11)
        adjustment = np.ravel(scist.mode(diff)[0])[0]
        stimulus_alignment += adjustment

    # compare alignments
    compare_alignments(stimulus_alignment, second_stimulus_alignment)

    # number of stimulus frames in the pre-blank period
    offset = int(pkl["pre_blank_sec"] * pkl["fps"])

    logger.info("Creating the stim_df:")

    # get number of segments expected and actually recorded for each stimulus
    segs = []
    segs_exp = []
    frames_per_seg = []
    stim_types = []
    stim_type_names = []

    for i in range(num_stimtypes):
        stimulus = pkl["stimuli"][i]

        # records the max num of segs in the frame list for each stimulus
        segs.extend([np.max(stimulus["frame_list"]) + 1])

        # calculates the expected number of segs based on fps,
        # display duration (s) and seg length
        fps = stimulus["fps"]

        name = stimulus[stim_par_key]["elemParams"]["name"]

        stim_type_names.extend([name])
        stim_types.extend([name[0]])
        if name == "bricks":
            frames_per_seg.extend([fps])
            segs_exp.extend([
                int(60. * np.sum(np.diff(stimulus["display_sequence"]))
                    / frames_per_seg[i])
                ])
        elif name == "gabors":
            frames_per_seg.extend([fps / 1000. * 300])
            # to exclude grey seg
            segs_exp.extend([
                int(60. * np.sum(np.diff(stimulus["display_sequence"]))
                    / frames_per_seg[i] * 4. / 5)
                ])
        else:
            raise ValueError(f"{name} stimulus type not recognized.")

        # check whether the actual number of frames is within a small range of
        # expected about two frames per sequence?
        # NOTE(review): the tolerance is computed from the display sequence of
        # stimulus 0 for all stimuli - confirm this is intended
        n_seq = pkl["stimuli"][0]["display_sequence"].shape[0] * 2
        if np.abs(segs[i] - segs_exp[i]) > n_seq:
            raise ValueError(
                f"Expected {segs_exp[i]} frames for stimulus {i}, "
                f"but found {segs[i]}.")

    total_stimsegs = np.sum(segs)

    stim_df = pd.DataFrame(
        index=list(range(total_stimsegs)),
        columns=[
            "stimType", "stimPar1", "stimPar2", "surp", "stimSeg",
            "GABORFRAME", "start_frame", "end_frame", "num_frames"
            ]
        )

    zz = 0
    # For gray-screen pre_blank
    pre_blank_columns = [
        "stimType", "stimPar1", "stimPar2", "surp", "stimSeg", "GABORFRAME"
        ]
    for column in pre_blank_columns:
        stim_df.loc[zz, column] = -1
    stim_df.loc[zz, "start_frame"] = stimulus_alignment[0] # 2p start frame
    stim_df.loc[zz, "end_frame"] = stimulus_alignment[offset] # 2p end frame
    stim_df.loc[zz, "num_frames"] = \
        (stimulus_alignment[offset] - stimulus_alignment[0])
    zz += 1

    for stype_n in range(num_stimtypes):
        logger.info(f"Stimtype: {stim_type_names[stype_n]}",
            extra={"spacing": TAB})
        movie_segs = pkl["stimuli"][stype_n]["frame_list"]

        for segment in range(segs[stype_n]):
            # stimulus frames belonging to the segment, converted to 2p frames
            seg_inds = np.where(movie_segs == segment)[0]
            start_frame = int(stimulus_alignment[seg_inds[0] + offset])
            end_frame = int(stimulus_alignment[seg_inds[-1] + 1 + offset])

            stim_df.loc[zz, "stimType"] = stim_types[stype_n][0]
            stim_df.loc[zz, "stimSeg"] = segment
            stim_df.loc[zz, "start_frame"] = start_frame
            stim_df.loc[zz, "end_frame"] = end_frame
            stim_df.loc[zz, "num_frames"] = end_frame - start_frame

            get_seg_params(
                stim_types, stype_n, stim_df, zz, pkl, segment, runtype)

            zz += 1

    # check whether any 2P frames are in associated to 2 stimuli
    overlap = np.any(
        (np.sort(stim_df["start_frame"])[1:] -
         np.sort(stim_df["end_frame"])[:-1]) < 0
        )
    if overlap:
        raise ValueError("Some 2P frames associated with two stimulus "
            "segments.")

    # create a dictionary for pickling
    stim_dict = {"stim_df": stim_df, "stim_align": stimulus_alignment}

    # store in the pickle file (the original error is chained, so that the
    # cause is not lost, and keyboard interrupts are not intercepted)
    try:
        file_util.saveinfo(stim_dict, df_pkl_name, overwrite=True)
    except Exception as err:
        raise OSError(
            f"Could not save stimulus pickle file {df_pkl_name}"
            ) from err
def load_stimulus_table(stim_dict, stim_sync_h5, time_sync_h5, align_pkl,
                        sessid, runtype="prod"):
    """
    load_stimulus_table(stim_dict, stim_sync_h5, time_sync_h5, align_pkl,
                        sessid)

    Retrieves and expands stimulus dataframe.

    Arguments:
        - stim_dict (dict)   : experiment stim dictionary, loaded from pickle
                               (or full path to load it from)
        - stim_sync_h5 (Path): full path name of the experiment sync hdf5 file
        - time_sync_h5 (Path): full path name of the time synchronization hdf5
                               file
        - align_pkl (Path)   : full path name of the output pickle file to
                               create
        - sessid (int)       : session ID, needed the check whether this
                               session needs to be treated differently (e.g.,
                               for alignment bugs)

    Optional args:
        - runtype (str): runtype ("prod" or "pilot")
                         default: "prod"

    Returns:
        - df (pandas)          : stimulus table
        - stim_align (1D array): stimulus to 2p alignment array

    Raises:
        - NotImplementedError: if a -1 value is left in one of the columns
                               that should be fully populated
    """

    # PRE-LOAD EVERYTHING TO AVOID RE-LOADING
    # load the stimulus dictionary, if a path was passed
    if not isinstance(stim_dict, dict):
        stim_dict = file_util.loadfile(stim_dict, filetype="pickle")

    # load dataframe as is (or trigger creation, if it doesn't exist)
    df, stim_align = load_basic_stimulus_table(
        stim_dict, stim_sync_h5, time_sync_h5, align_pkl, sessid,
        runtype=runtype
        )

    # load stimulus timestamps
    stimulus_timestamps = sess_sync_util.get_stim_fr_timestamps(
        stim_sync_h5, time_sync_h5=time_sync_h5, stim_align=stim_align
        )

    # CREATE DATAFRAME (the order of the steps is preserved)
    df = update_basic_stimulus_table(df) # updated column names
    df = add_stimulus_frame_num(df, stim_dict, stim_align, runtype=runtype)
    df = add_stimulus_locations(df, stim_dict, runtype=runtype)
    df = modify_segment_num(df) # segment numbers, per stimulus type
    df = add_gabor_grayscreen_rows(df, stim_align) # gabor grayscreens (G)
    df = add_grayscreen_stimulus_rows(df, stim_align) # other grayscreens
    df = add_time(df, stim_dict, stimulus_timestamps) # time columns
    df = add_stimulus_template_names_and_frames(df, runtype=runtype)

    # reorder rows
    df = df.sort_values("start_frame_twop").reset_index(drop=True)

    # verify that no stray -1s are left
    check_for_values = [
        "stimulus_template_name",
        "stimulus_type",
        "start_frame_stim_template",
        "start_frame_stim",
        "stop_frame_stim",
        "num_frames_stim",
        "start_frame_twop",
        "stop_frame_twop",
        "num_frames_twop",
        "start_time_sec",
        "stop_time_sec",
        "duration_sec",
        ]
    for column in check_for_values:
        stray_rows = df.loc[df[column] == -1]
        if not stray_rows.empty:
            raise NotImplementedError(
                f"Incorrect implemmentation. -1 found in {column} column of "
                "stimulus table.")

    # replace -1s with ""/np.nan/[]
    str_dtype_columns = ["gabor_frame", "main_flow_direction"]
    list_dtype_columns = [
        "gabor_orientations",
        "gabor_sizes",
        "gabor_locations_x",
        "gabor_locations_y",
        "square_locations_x",
        "square_locations_y",
        ]

    def array_or_empty(value):
        """Returns value if it is a numpy array, and an empty array if not."""
        if isinstance(value, np.ndarray):
            return value
        return np.asarray([])

    df[str_dtype_columns] = df[str_dtype_columns].replace([-1, "-1"], "")
    for column in df.columns:
        if column in list_dtype_columns:
            df[column] = df[column].apply(array_or_empty)
        else:
            df[column] = df[column].replace([-1, "-1"], np.nan)

    # drop the original stimulus segment column
    df = df.drop(columns="orig_stimulus_segment")

    return df, stim_align
def get_sess_vals(mouse_df, returnlab, mouse_n="any", sess_n="any",
                  runtype="any", plane="any", line="any", pass_fail="P",
                  incl="all", all_files=1, any_files=1, min_rois=1,
                  omit_sess=[], omit_mice=[], unique=True, sort=False):
    """
    get_sess_vals(mouse_df, returnlab)

    Returns list of values under the specified label that fit the specified
    criteria.

    Required args:
        - mouse_df (Path)        : path name of dataframe containing
                                   information on each session (or the
                                   dataframe itself, already loaded)
        - returnlab (str or list): label(s) from which to return values

    Optional args:
        - mouse_n (int, str or list)  : mouse number(s) of interest
                                        default: "any"
        - sess_n (int, str or list)   : session number(s) of interest
                                        default: "any"
        - runtype (str or list)       : runtype value(s) of interest
                                        ("pilot", "prod")
                                        default: "any"
        - plane (str or list)         : plane value(s) of interest
                                        ("soma", "dend", "any")
                                        default: "any"
        - line (str or list)          : line value(s) of interest
                                        ("L5", "L23", "any")
                                        default: "any"
        - pass_fail (str or list)     : pass/fail values of interest
                                        ("P", "F", "any")
                                        default: "P"
        - incl (str)                  : which sessions to include
                                        ("yes", "no", "any")
                                        default: "all"
        - all_files (str, int or list): all_files values of interest (0, 1)
                                        default: 1
        - any_files (str, int or list): any_files values of interest (0, 1)
                                        default: 1
        - min_rois (int)              : min number of ROIs
                                        default: 1
        - omit_sess (list)            : sessions to omit
                                        default: []
        - omit_mice (list)            : mice to omit
                                        default: []
        - unique (bool)               : whether to return a list of values
                                        without duplicates (only done if only
                                        one returnlab is provided)
                                        default: True
        - sort (bool)                 : whether to sort output values (only
                                        done if only one returnlab is
                                        provided)
                                        default: False

    Returns:
        - sess_vals (list): values from dataframe that correspond to criteria
    """

    if isinstance(mouse_df, (str, Path)):
        mouse_df = file_util.loadfile(mouse_df)

    # get depth values corresponding to the plane
    depth = depth_vals(plane, line)

    # criteria to collect values for, keyed by dataframe label (all session
    # IDs are initially considered)
    criteria = {
        "mouse_n"  : mouse_n,
        "sessid"   : "any",
        "sess_n"   : sess_n,
        "runtype"  : runtype,
        "depth"    : depth,
        "pass_fail": pass_fail,
        "incl"     : incl,
        "all_files": all_files,
        "any_files": any_files,
    }

    # for each label, collect values in a list
    (mouse_n, sessid, sess_n, runtype, depth, pass_fail, incl, all_files,
     any_files) = [
        gen_util.get_df_label_vals(mouse_df, label, vals)
        for label, vals in criteria.items()
        ]

    # remove omitted sessions from the session ID list, and omitted mice from
    # the mouse number list
    sessid = gen_util.remove_if(sessid, omit_sess)
    mouse_n = gen_util.remove_if(mouse_n, omit_mice)

    return get_mouse_df_vals(
        mouse_df, returnlab, mouse_n, sessid, sess_n, runtype, depth,
        pass_fail, incl, all_files, any_files, min_rois, unique, sort
        )
def get_roi_masks(mask_file=None, roi_extract_json=None, objectlist_txt=None,
                  mask_threshold=MASK_THRESHOLD, min_n_pix=MIN_N_PIX,
                  make_bool=True):
    """
    get_roi_masks()

    Returns ROI masks, loaded either from an h5 file, or from an ROI
    extraction json and object list, and passed through process_roi_masks().

    NOTE: If masks are loaded from roi_extract_json, they are already boolean.

    Optional args:
        - mask_file (Path)       : ROI mask h5. If None, roi_extract_json and
                                   objectlist_txt are used.
                                   default: None
        - roi_extract_json (Path): ROI extraction json (only needed if
                                   mask_file is None)
                                   default: None
        - objectlist_txt (Path)  : ROI object list txt (only needed if
                                   mask_file is None)
                                   default: None
        - mask_threshold (float) : minimum value in non-boolean mask to
                                   retain a pixel in an ROI mask
                                   default: MASK_THRESHOLD
        - min_n_pix (int)        : minimum number of pixels in an ROI below
                                   which, ROI is set to be empty
                                   default: MIN_N_PIX
        - make_bool (bool)       : if True, ROIs are converted to boolean
                                   before being returned
                                   default: True

    Returns:
        - roi_masks (3D array)       : ROI masks, structured as
                                       ROI x height x width
        - roi_ids (list or 1D array) : ID for each ROI (sorted array of
                                       cell specimen IDs if loaded from the
                                       json, list of indices if loaded from
                                       the h5)
    """

    from_json = mask_file is None

    if from_json and (roi_extract_json is None or objectlist_txt is None):
        raise ValueError("Must provide 'mask_file' or both "
            "'roi_extract_json' and 'objectlist_txt'.")

    if from_json:
        roi_extract_dict = file_util.loadfile(roi_extract_json)
        im_hei = roi_extract_dict["image"]["height"]
        im_wid = roi_extract_dict["image"]["width"]

        roi_metrics = get_roi_metrics(roi_extract_dict, objectlist_txt)

        # NOTE(review): ROIs are sorted by cell_specimen_id, but looked up
        # below by id. Presumably the two columns match - confirm.
        roi_ids = np.sort(roi_metrics.cell_specimen_id.values)

        # source data is boolean
        roi_masks = np.zeros((len(roi_ids), im_hei, im_wid), dtype=bool)
        for r, roi_id in enumerate(roi_ids):
            roi_row = roi_metrics[roi_metrics.id == roi_id].iloc[0]
            cropped_mask = np.asarray(roi_row["mask"])

            # place the cropped mask at its bounding box in a full frame
            top = int(roi_row.y)
            bottom = int(roi_row.y) + int(roi_row.height)
            left = int(roi_row.x)
            right = int(roi_row.x) + int(roi_row.width)

            full_frame = np.zeros((im_hei, im_wid), dtype=np.uint8)
            full_frame[top: bottom, left: right] = cropped_mask
            roi_masks[r] = full_frame

    else:
        with h5py.File(mask_file, "r") as f:
            roi_masks = f["data"][()] # not binary

        # no IDs stored with the h5 masks: ROIs are identified by index
        roi_ids = list(range(len(roi_masks)))

    roi_masks = process_roi_masks(
        roi_masks,
        mask_threshold=mask_threshold,
        min_n_pix=min_n_pix,
        make_bool=make_bool,
        )

    return roi_masks, roi_ids
def load_gen_stim_properties(stim_dict, stimtype="gabors", runtype="prod"):
    """
    load_gen_stim_properties(stim_dict)

    Returns dictionary with general stimulus properties, read from the first
    stimulus of the requested type found in the stimulus dictionary.

    Required args:
        - stim_dict (dict): experiment stim dictionary, loaded from pickle
                            (if not a dictionary, it is loaded as a pickle
                            file)

    Optional args:
        - stimtype (str): stimulus type ("gabors" or "visflow")
                          default: "gabors"
        - runtype (str) : runtype ("prod" or "pilot")
                          default: "prod"

    Returns:
        - gen_stim_props (dict): dictionary with stimulus properties.
            ["deg_per_pix"]: deg / pixel conversion used to generate stimuli
            ["exp_len_s"]  : duration of an expected seq (sec) [min, max]
            ["seg_len_s"]  : duration of each segment (sec)
            ["unexp_len_s"]: duration of an unexpected seq (sec) [min, max]
            ["win_size"]   : window size [wid, hei] (in pixels)

            if stimtype == "gabors":
            ["n_segs_per_seq"]: number of segments in a sequence (including G)
            ["phase"]         : phase (0-1)
            ["sf"]            : spatial frequency (cyc/pix)
            ["size_ran"]      : size range (in pixels)

            if stimtype == "visflow":
            ["speed"]         : visual flow speed (pix/sec)
    """

    if not isinstance(stim_dict, dict):
        stim_dict = file_util.loadfile(stim_dict, filetype="pickle")

    # check that the run type and stimulus type are recognized
    runtypes = ["prod", "pilot"]
    if runtype not in runtypes:
        gen_util.accepted_values_error("runtype", runtype, runtypes)

    stimtypes = ["gabors", "visflow"]
    if stimtype not in stimtypes:
        gen_util.accepted_values_error("stimtype", stimtype, stimtypes)

    # keys under which parameters are stored depend on runtype and stimtype
    stim_params_key = "stim_params"
    if runtype == "pilot":
        stim_params_key = "stimParams"

    stimtype_key = "square_params"
    if stimtype == "gabors":
        stimtype_key = "gabor_params"

    # identify all stimuli of the correct type in the dictionary
    matching_ns = []
    for n, stim_info in enumerate(stim_dict["stimuli"]):
        if stimtype_key in stim_info[stim_params_key]:
            matching_ns.append(n)

    if not matching_ns:
        raise RuntimeError(
            f"No {stimtype} stimulus found in stimulus dictionary.")

    # same general stimulus properties expected for all: use the first one
    stim_pars = stim_dict["stimuli"][matching_ns[0]][stim_params_key]

    sess_par_key = "session_params" if runtype == "prod" else "subj_params"
    window_par = stim_pars[sess_par_key]["windowpar"]

    gen_stim_props = dict()
    gen_stim_props["win_size"] = window_par[0]
    gen_stim_props["deg_per_pix"] = window_par[1]
    deg_per_pix = gen_stim_props["deg_per_pix"]

    type_pars = stim_pars[stimtype_key]
    gen_stim_props["exp_len_s"] = type_pars["reg_len"]
    gen_stim_props["unexp_len_s"] = type_pars["surp_len"]

    if stimtype != "gabors":
        gen_stim_props["seg_len_s"] = type_pars["seg_len"]
        gen_stim_props["speed"] = type_pars["speed"] / deg_per_pix
        return gen_stim_props

    gen_stim_props["seg_len_s"] = type_pars["im_len"]
    gen_stim_props["n_segs_per_seq"] = type_pars["n_im"] + 1 # for G
    gen_stim_props["phase"] = type_pars["phase"]
    gen_stim_props["sf"] = type_pars["sf"]

    # size range, converted from degrees to (rounded) pixels
    size_ran_pix = [
        np.around(size / deg_per_pix) for size in type_pars["size_ran"]
        ]

    # Gabor size conversion based on psychopy definition
    # full-width half-max -> 6 std
    size_conv = 1.0 / (2 * np.sqrt(2 * np.log(2))) * type_pars["sd"]
    gen_stim_props["size_ran"] = [
        int(np.around(size * size_conv)) for size in size_ran_pix
        ]

    return gen_stim_props
def load_stim_df_info(stim_pkl, stim_sync_h5, time_sync_h5, align_pkl, sessid,
                      runtype="prod"):
    """
    load_stim_df_info(stim_pkl, stim_sync_h5, time_sync_h5, align_pkl, sessid)

    Creates the alignment dataframe (stim_df) and saves it as a pickle
    in the session directory, if it does not already exist. Returns dataframe,
    alignment arrays, and frame rate.

    Required args:
        - stim_pkl (Path)    : full path name of the experiment stim pickle
                               file
        - stim_sync_h5 (Path): full path name of the experiment sync hdf5 file
        - time_sync_h5 (Path): full path name of the time synchronization hdf5
                               file
        - align_pkl (Path)   : full path name of the output pickle file to
                               create
        - sessid (int)       : session ID, needed to check whether this
                               session needs to be treated differently
                               (e.g., for alignment bugs)

    Optional args:
        - runtype (str): runtype ("prod" or "pilot")
                         default: "prod"

    Returns:
        - stim_df (pd DataFrame): stimulus alignment dataframe with columns:
                                    "stimtype", "unexp", "stim_seg", "gabfr",
                                    "gab_ori", "gabk", "visflow_dir",
                                    "visflow_size", "start_twop_fr",
                                    "end_twop_fr", "num_twop_fr"
        - stimtype_order (list) : stimulus type order
        - stim2twopfr (1D array): 2p frame numbers for each stimulus frame,
                                  as well as the flanking
                                  blank screen frames
        - twop_fps (num)        : mean 2p frames per second
        - twop_fr_stim (int)    : number of 2p frames recorded while stim
                                  was playing
    """

    align_pkl = Path(align_pkl)

    # create the alignment pickle, if it doesn't exist
    if not align_pkl.is_file():
        sessdir = align_pkl.parent
        logger.info(f"Stimulus alignment pickle not found in {sessdir}, and "
            "will be created.", extra={"spacing": TAB})
        sess_sync_util.get_stim_frames(
            stim_pkl, stim_sync_h5, time_sync_h5, align_pkl, sessid, runtype,
            )

    align = file_util.loadfile(align_pkl)

    new_col_names = {
        "GABORFRAME" : "gabfr",
        "surp"       : "unexp", # rename surprise to unexpected
        "stimType"   : "stimtype",
        "stimSeg"    : "stim_seg",
        "start_frame": "start_twop_fr",
        "end_frame"  : "end_twop_fr",
        "num_frames" : "num_twop_fr",
        }
    stim_df = align["stim_df"].rename(columns=new_col_names)

    # rename bricks -> visflow
    stim_df["stimtype"] = stim_df["stimtype"].replace({"b": "v"})

    stim_df = modify_visflow_segs(stim_df, runtype)
    stim_df = stim_df.sort_values("start_twop_fr").reset_index(drop=True)

    # note: STIMULI ARE NOT ORDERED IN THE PICKLE
    # stimulus types are listed in order of first appearance, once the
    # dataframe has been sorted by start frame
    stimtype_map = {"g": "gabors", "v": "visflow"}
    mapped_stimtypes = stim_df["stimtype"].map(stimtype_map).unique()
    stimtype_order = [
        stimtype for stimtype in mapped_stimtypes
        if stimtype in stimtype_map.values()
        ]

    # split stimPar1 and stimPar2 into all stimulus parameters
    par_cols = {
        "gab_ori"     : "stimPar1",
        "gabk"        : "stimPar2",
        "visflow_size": "stimPar1",
        "visflow_dir" : "stimPar2",
        }
    for new_col, src_col in par_cols.items():
        stim_df[new_col] = stim_df[src_col]
    stim_df = stim_df.drop(columns=["stimPar1", "stimPar2"])

    # set stimulus-specific columns to -1 for rows of another stimulus type
    not_gabors = stim_df["stimtype"] != "g"
    not_visflow = stim_df["stimtype"] != "v"
    for col in stim_df.columns:
        if "gab" in col:
            stim_df.loc[not_gabors, col] = -1
        if "visflow" in col:
            stim_df.loc[not_visflow, col] = -1

    # expand on direction info
    for direc in ["right", "left"]:
        direc_rows = (stim_df["visflow_dir"] == direc)
        stim_df.loc[direc_rows, "visflow_dir"] = \
            sess_gen_util.get_visflow_screen_mouse_direc(direc)

    stim_align = align["stim_align"]
    stim2twopfr = stim_align.astype("int")
    twop_fps = sess_sync_util.get_frame_rate(stim_sync_h5)[0]
    twop_fr_stim = int(max(stim_align))

    return stim_df, stimtype_order, stim2twopfr, twop_fps, twop_fr_stim
def plot_summ(output, savename, stimtype="gabors", comp="unexp", ctrl=False,
              visflow_dir="both", fluor="dff", scale=True, CI=0.95,
              plt_bkend=None, fontdir=None, modif=False):
    """
    plot_summ(output, savename)

    Plots summary data for a specific comparison, for each datatype in a
    separate figure and saves figures.

    If the summary file is missing or empty, or if no data (or only shuffled
    or only non shuffled data) matches the criteria, a warning is issued and
    nothing is plotted.

    Required args:
        - output (str)  : general directory in which summary dataframe
                          is saved (runtype and q1v4/evu values are inferred
                          from the directory name)
        - savename (str): name of the dataframe containing summary data to
                          plot

    Optional args:
        - stimtype (str)   : stimulus type
                             default: "gabors"
        - comp (str)       : type of comparison
                             default: "unexp"
        - ctrl (bool)      : if True, control comparisons are analysed
                             default: False
        - visflow_dir (str): visual flow direction
                             default: "both"
        - fluor (str)      : fluorescence trace type
                             default: "dff"
        - scale (bool)     : whether ROIs are scaled
                             default: True
        - CI (num)         : CI for shuffled data
                             default: 0.95
        - plt_bkend (str)  : mpl backend to use for plotting (e.g., "agg")
                             default: None
        - fontdir (str)    : directory in which additional fonts are located
                             default: None
        - modif (bool)     : if True, plots are made in a modified (simplified
                             way)
                             default: False
    """

    plot_util.manage_mpl(plt_bkend, fontdir=fontdir)

    summ_scores_file = Path(output, savename)

    if not summ_scores_file.is_file():
        warnings.warn(f"{summ_scores_file} not found.",
            category=RuntimeWarning, stacklevel=1)
        return

    summ_scores = file_util.loadfile(summ_scores_file)

    if len(summ_scores) == 0:
        warnings.warn(f"No data in {summ_scores_file}.",
            category=RuntimeWarning, stacklevel=1)
        return

    # drop NaN lines
    has_epoch_n = ~summ_scores["epoch_n_mean"].isna()
    summ_scores = summ_scores.loc[has_epoch_n]

    data_types = ["epoch_n", "test_acc", "test_acc_bal"]
    data_titles = ["Epoch nbrs", "Test accuracy", "Test accuracy (balanced)"]

    stats = ["mean", "sem", "sem"]
    shuff_stats = ["median"] + math_util.get_percentiles(CI)[1]

    # infer analysis type and runtype from the directory name
    output_str = str(output)
    q1v4 = "_q1v4" in output_str
    evu = (not q1v4) and ("_evu" in output_str)
    runtype = "pilot" if "pilot" in output_str else "prod"

    if stimtype == "gabors":
        visflow_dir = "none"
        stim_str = "gab"
        stim_str_pr = "gabors"
    else:
        visflow_dir = sess_gen_util.get_params(stimtype, visflow_dir)[0]
        if isinstance(visflow_dir, list) and len(visflow_dir) == 2:
            visflow_dir = "both"
        stim_str = sess_str_util.dir_par_str(visflow_dir, str_type="file")
        stim_str_pr = sess_str_util.dir_par_str(visflow_dir, str_type="print")

    scale_str = sess_str_util.scale_par_str(scale, "file")
    scale_str_pr = sess_str_util.scale_par_str(scale, "file").replace("_", " ")
    ctrl_str = sess_str_util.ctrl_par_str(ctrl)
    ctrl_str_pr = sess_str_util.ctrl_par_str(ctrl, str_type="print")
    modif_str = "_modif" if modif else ""

    save_dir = Path(output, f"figures_{fluor}")
    save_dir.mkdir(exist_ok=True)

    # select the dataframe lines matching the criteria
    cols = ["scale", "fluor", "visflow_dir", "runtype"]
    cri = [scale, fluor, visflow_dir, runtype]
    plot_lines = gen_util.get_df_vals(summ_scores, cols, cri)
    cri_str = ", ".join([f"{col}: {crit}" for col, crit in zip(cols, cri)])

    if len(plot_lines) == 0: # no data
        warnings.warn(f"No data found for {cri_str}", category=RuntimeWarning,
            stacklevel=1)
        return

    # shuffle or non shuffle missing
    missing_shuffs = [
        shuff for shuff in [False, True]
        if shuff not in plot_lines["shuffle"].tolist()
        ]
    for shuff in missing_shuffs:
        warnings.warn(f"No shuffle={shuff} data found for {cri_str}",
            category=RuntimeWarning, stacklevel=1)
    if missing_shuffs:
        return

    for data, data_title in zip(data_types, data_titles):
        if modif:
            title = (f"{stim_str_pr.capitalize()} {comp}{ctrl_str_pr}\n"
                f"{data_title}")
        else:
            title = (f"{stim_str_pr.capitalize()} {comp}{ctrl_str_pr} - "
                f"{data_title} for log regr on\n" +
                u"{} {} ".format(scale_str_pr, fluor) +
                f"data ({runtype})")

        # titles are displayed without underscores
        title = title.replace("_", " ")

        fig_name = (f"{data}_{stim_str}_{comp}{ctrl_str}{scale_str}"
            f"{modif_str}.svg")
        full_savename = Path(save_dir, fig_name)

        plot_data_summ(
            plot_lines, data, stats, shuff_stats, title, full_savename, CI,
            q1v4, evu, comp, modif
            )

    plot_util.cond_close_figs()
def plot_traces_scores(hyperpars, tr_stats=None, full_scores=None,
                       plot_wei=True, savedir=None):
    """
    plot_traces_scores(hyperpars)

    Plots training traces and scores for a logistic regression analysis run.

    Trace statistics and scores that are not passed are loaded from the save
    directory ("tr_stats.json" and "scores_df.csv"), if the files exist.
    Otherwise, a warning is issued, and the corresponding plot is skipped.

    Required args:
        - hyperpars (dict):
            ["analyspar"] (dict): dictionary with keys of analyspar named tuple
            ["extrapar"] (dict) : dictionary with extra parameters
                ["classes"] (list) : class names
                ["dirname"] (str)  : directory in which data and plots are
                                     saved
                ["loss_name"] (str): name of loss
                ["shuffle"] (bool) : if True, data was shuffled
            ["logregpar"] (dict): dictionary with keys of logregpar named tuple
            ["sesspar"] (dict)  : dictionary with keys of sesspar named tuple
            ["stimpar"] (dict)  : dictionary with keys of stimpar named tuple

    Optional args:
        - tr_stats (dict)           : dictionary of trace stats data
            ["n_rois"] (int)                  : number of ROIs
            ["train_ns"] (list)               : number of segments per class
            ["train_class_stats"] (3D array)  : training statistics,
                                                structured as
                                                class x stats (me, err) x
                                                frames
            ["xran"] (array-like)             : x values for frames

            optionally, if an additional named set (e.g., "test_Q4") is passed:
            ["set_ns"] (list)             : number of segments per class
            ["set_class_stats"] (3D array): trace statistics,
                                            structured as
                                            class x stats (me, err) x
                                            frames
                                      default: None
        - full_scores (pd DataFrame): dataframe in which scores are recorded,
                                      for each epoch
                                      default: None
        - plot_wei (bool or int)    : if True, weights are plotted in a
                                      subplot. Or if int, index of model to
                                      plot. If scores are loaded from file for
                                      an "sklearn" run, it is replaced with
                                      the first saved run number, if any.
                                      default: True
        - savedir (Path)            : directory in which to save figure (used
                                      instead of extrapar["dirname"], if
                                      passed)
                                      default: None
    """

    analyspar = hyperpars["analyspar"]
    sesspar = hyperpars["sesspar"]
    stimpar = hyperpars["stimpar"]
    logregpar = hyperpars["logregpar"]
    extrapar = hyperpars["extrapar"]

    if savedir is None:
        savedir = extrapar["dirname"]

    # load trace statistics from the save directory, if not provided
    if tr_stats is None:
        tr_stats_path = Path(savedir, "tr_stats.json")
        if not tr_stats_path.is_file():
            warnings.warn("No trace statistics found.",
                category=RuntimeWarning, stacklevel=1)
        else:
            tr_stats = file_util.loadfile(tr_stats_path)

    # load scores from the save directory, if not provided
    if full_scores is None:
        full_scores_path = Path(savedir, "scores_df.csv")
        if not full_scores_path.is_file():
            warnings.warn("No scores dataframe found.",
                category=RuntimeWarning, stacklevel=1)
        else:
            full_scores = file_util.loadfile(full_scores_path)

            # for sklearn runs, point plot_wei to the first saved run
            if plot_wei and logregpar["alg"] == "sklearn":
                saved_rows = full_scores.loc[full_scores["saved"] == 1]
                saved_run_ns = saved_rows["run_n"].tolist()
                if saved_run_ns:
                    plot_wei = saved_run_ns[0]

    if tr_stats is not None:
        plot_class_traces(
            analyspar, sesspar, stimpar, logregpar, tr_stats,
            extrapar["classes"], extrapar["shuffle"], plot_wei=plot_wei,
            modeldir=savedir, savedir=savedir
            )

    if full_scores is not None:
        plot_scores(
            analyspar, sesspar, stimpar, logregpar, extrapar, full_scores,
            savedir=savedir
            )