def plot_from_dict(direc, plt_bkend=None, fontdir=None):
    """
    plot_from_dict(direc)

    Loads "hyperparameters.json" from a results directory and plots the 
    results it describes. Currently, only results whose hyperparameters 
    contain a "logregpar" entry are plotted (via plot_traces_scores()).

    Required args:
        - direc (Path): directory containing "hyperparameters.json" and the 
                        results to plot; figures are saved there too

    Optional args:
        - plt_bkend (str): mpl backend to use for plotting (e.g., "agg")
                           default: None
        - fontdir (Path) : directory in which additional fonts are stored
                           default: None
    """

    logger.info(f"Plotting from hyperparameters in: {direc}", 
        extra={"spacing": "\n"})

    results_dir = Path(direc)

    # configure backend/fonts before anything is drawn
    plot_util.manage_mpl(plt_bkend, fontdir=fontdir)

    hyperpars = file_util.loadfile("hyperparameters.json", fulldir=results_dir)

    has_logreg = "logregpar" in hyperpars.keys()
    if has_logreg:
        plot_traces_scores(hyperpars, savedir=results_dir)

    plot_util.cond_close_figs()
def plot_from_dict(dict_path, plt_bkend=None, fontdir=None, parallel=False, 
                   datetime=True, overwrite=False):
    """
    plot_from_dict(dict_path)

    Loads an analysis dictionary and passes it to the plotting function that 
    matches its analysis code (info["extrapar"]["analysis"]). Figures are 
    saved in the directory containing the dictionary.

    Required args:
        - dict_path (Path): path to dictionary to plot data from

    Optional args:
        - plt_bkend (str) : mpl backend to use for plotting (e.g., "agg")
                            default: None
        - fontdir (Path)  : path to directory where additional fonts are stored
                            default: None
        - parallel (bool) : if True, some of the analysis is parallelized 
                            across CPU cores (not used by this function)
                            default: False
        - datetime (bool) : figpar["save"] datetime parameter (whether to 
                            place figures in a datetime folder)
                            default: True
        - overwrite (bool): figpar["save"] overwrite parameter (whether to 
                            overwrite figures)
                            default: False
    """

    logger.info(f"Plotting from dictionary: {dict_path}", 
        extra={"spacing": "\n"})

    figpar = sess_plot_util.init_figpar(
        plt_bkend=plt_bkend, fontdir=fontdir, datetime=datetime, 
        overwrite=overwrite
        )
    plot_util.manage_mpl(cmap=False, **figpar["mng"])

    dict_path = Path(dict_path)
    info = file_util.loadfile(dict_path)
    savedir = dict_path.parent

    analysis = info["extrapar"]["analysis"]

    # analysis code -> plotting function
    #   "c": correlation between pupil and roi/run changes, for each session
    #   "r": difference correlation per ROI between stimuli
    plot_fcts = (
        ("c", plot_pup_diff_corr),
        ("r", plot_pup_roi_stim_corr),
    )
    for code, plot_fct in plot_fcts:
        if analysis == code:
            plot_fct(figpar=figpar, savedir=savedir, **info)
            break
    else:
        # no code matched
        warnings.warn(f"No plotting function for analysis {analysis}", 
            category=UserWarning, stacklevel=1)

    plot_util.cond_close_figs()
def run_analyses(sessions,
                 analysis_dict,
                 analyses,
                 datatype="roi",
                 seed=None,
                 parallel=False):
    """
    run_analyses(sessions, analysis_dict, analyses)

    Runs each requested analysis on the sessions, skipping analyses that do 
    not support the requested datatype. Each analysis function only receives 
    the entries of the parameter dictionary that match its argument names.

    Required args:
        - sessions (list)     : list of sessions, possibly nested
        - analysis_dict (dict): analysis parameter dictionary 
                                (see init_param_cont())
        - analyses (str)      : analyses to run (iterated one code at a time)

    Optional args:
        - datatype (str) : datatype ("run", "roi")
                           default: "roi"
        - seed (int)     : seed to use
                           default: None
        - parallel (bool): if True, some analyses are parallelized 
                           across CPU cores 
                           default: False

    Raises:
        - ValueError: if a requested analysis is not returned by 
                      get_analysis_fcts()
    """

    if len(sessions) == 0:
        logger.warning("No sessions meet these criteria.")
        return

    # changes backend and defaults
    plot_util.manage_mpl(cmap=False, **analysis_dict["figpar"]["mng"])
    sess_plot_util.update_plt_linpla()

    fct_dict = get_analysis_fcts()

    # work on a deep copy, so that the caller's dictionary is not modified
    args_dict = copy.deepcopy(analysis_dict)
    args_dict.update(seed=seed, parallel=parallel, datatype=datatype)

    for analysis in analyses:
        if analysis not in fct_dict:
            raise ValueError(f"{analysis} analysis not found.")

        fct, datatype_req = fct_dict[analysis]
        if datatype in datatype_req:
            fct_argnames = inspect.getfullargspec(fct).args
            fct_kwargs = gen_util.keep_dict_keys(args_dict, fct_argnames)
            fct(sessions=sessions, analysis=analysis, **fct_kwargs)

            plot_util.cond_close_figs()
# Example #4
# 0
def run_analyses(sessions, analysis_dict, analyses, seed=None,
                 parallel=False, plot_tc=True):
    """
    run_analyses(sessions, analysis_dict, analyses)

    Runs each requested analysis on the sessions. An analysis is only run if 
    its comparison requirement (as returned by get_analysis_fcts()) matches 
    whether the sessions are nested (i.e., sessions[0] is a list).

    Required args:
        - sessions (list)     : list of sessions, possibly nested
        - analysis_dict (dict): analysis parameter dictionary 
                                (see init_param_cont())
        - analyses (str)      : analyses to run (iterated one code at a time)

    Optional args:
        - seed (int)     : seed to use
                           default: None
        - parallel (bool): if True, some analyses are parallelized 
                           across CPU cores 
                           default: False
        - plot_tc (bool) : if True, tuning curves are plotted for each ROI  
                           default: True

    Raises:
        - ValueError: if a requested analysis is not returned by 
                      get_analysis_fcts()
    """

    if len(sessions) == 0:
        logger.warning("No sessions meet these criteria.")
        return

    nested = isinstance(sessions[0], list)

    # changes backend and defaults
    plot_util.manage_mpl(cmap=False, **analysis_dict["figpar"]["mng"])

    fct_dict = get_analysis_fcts()

    # work on a deep copy, so that the caller's dictionary is not modified
    args_dict = copy.deepcopy(analysis_dict)
    args_dict.update(seed=seed, parallel=parallel, plot_tc=plot_tc)

    for analysis in analyses:
        if analysis not in fct_dict:
            raise ValueError(f"{analysis} analysis not found.")

        fct, comp_req = fct_dict[analysis]
        if comp_req == nested:
            fct_argnames = inspect.getfullargspec(fct).args
            fct_kwargs = gen_util.keep_dict_keys(args_dict, fct_argnames)
            fct(sessions=sessions, analysis=analysis, **fct_kwargs)

            plot_util.cond_close_figs()
# Example #5
# 0
def plot_from_dict(dict_path,
                   plt_bkend=None,
                   fontdir=None,
                   plot_tc=True,
                   parallel=False,
                   datetime=True,
                   overwrite=False):
    """
    plot_from_dict(dict_path)

    Loads an analysis dictionary and replots it in the modified style 
    (modif=True), using the gen_plots function that matches its analysis code 
    (info["extrapar"]["analysis"]). Figure and axis title sizes are set to 
    "xx-large" in the global matplotlib rcParams.

    Required args:
        - dict_path (Path): path to dictionary to plot data from

    Optional args:
        - plt_bkend (str) : mpl backend to use for plotting (e.g., "agg")
                            default: None
        - fontdir (Path)  : path to directory where additional fonts are stored
                            default: None
        - plot_tc (bool)  : if True, tuning curves are plotted for each ROI 
                            (dummy argument)
                            default: True
        - parallel (bool) : if True, some of the analysis is parallelized 
                            across CPU cores (not used by this function)
                            default: False
        - datetime (bool) : figpar["save"] datetime parameter (whether to 
                            place figures in a datetime folder)
                            default: True
        - overwrite (bool): figpar["save"] overwrite parameter (whether to 
                            overwrite figures)
                            default: False
    """

    logger.info(f"Plotting from dictionary: {dict_path}",
                extra={"spacing": "\n"})

    figpar = sess_plot_util.init_figpar(plt_bkend=plt_bkend,
                                        fontdir=fontdir,
                                        datetime=datetime,
                                        overwrite=overwrite)
    plot_util.manage_mpl(cmap=False, **figpar["mng"])

    # larger titles for the modified plots (global setting)
    plt.rcParams.update({
        "figure.titlesize": "xx-large",
        "axes.titlesize": "xx-large",
    })

    dict_path = Path(dict_path)
    info = file_util.loadfile(dict_path)
    savedir = dict_path.parent

    analysis = info["extrapar"]["analysis"]

    # analysis code -> modified plotting function
    #   "t": average traces by quantile x unexpected, for each session
    #   "l": average traces by quantile, locked to unexpected, for each session
    plot_fcts = (
        ("t", gen_plots.plot_traces_by_qu_unexp_sess),
        ("l", gen_plots.plot_traces_by_qu_lock_sess),
    )
    for code, plot_fct in plot_fcts:
        if analysis == code:
            plot_fct(figpar=figpar, savedir=savedir, modif=True, **info)
            break
    else:
        # no code matched
        warnings.warn(f"No modified plotting option for analysis {analysis}",
                      category=UserWarning,
                      stacklevel=1)

    plot_util.cond_close_figs()
# Example #6
# 0
def plot_from_dict(dict_path, plt_bkend=None, fontdir=None, parallel=False, 
                   datetime=True, overwrite=False):
    """
    plot_from_dict(dict_path)

    Loads an analysis dictionary and passes it to the plotting function that 
    matches its analysis code (info["extrapar"]["analysis"]). Figures are 
    saved in the directory containing the dictionary.

    Required args:
        - dict_path (Path): path to dictionary to plot data from

    Optional args:
        - plt_bkend (str) : mpl backend to use for plotting (e.g., "agg")
                            default: None
        - fontdir (Path)  : path to directory where additional fonts are stored
                            default: None
        - parallel (bool) : if True, some of the plotting is parallelized 
                            across CPU cores (not used by this function)
                            default: False
        - datetime (bool) : figpar["save"] datetime parameter (whether to 
                            place figures in a datetime folder)
                            default: True
        - overwrite (bool): figpar["save"] overwrite parameter (whether to 
                            overwrite figures)
                            default: False
    """

    logger.info(f"Plotting from dictionary: {dict_path}", 
        extra={"spacing": "\n"})

    figpar = sess_plot_util.init_figpar(
        plt_bkend=plt_bkend, fontdir=fontdir, datetime=datetime, 
        overwrite=overwrite
        )
    plot_util.manage_mpl(cmap=False, **figpar["mng"])

    dict_path = Path(dict_path)
    info = file_util.loadfile(dict_path)
    savedir = dict_path.parent

    analysis = info["extrapar"]["analysis"]

    # analysis code -> plotting function
    #   "f": full traces, for each session
    #   "t": average traces by quantile x unexpected, for each session
    #   "l": average traces by quantile, locked to unexpected, for each session
    #   "m": magnitude of change in dF/F area from first to last quantile of 
    #        unexpected vs no unexpected sequences, for each session
    #   "a": autocorrelations
    plot_fcts = (
        ("f", plot_full_traces),
        ("t", plot_traces_by_qu_unexp_sess),
        ("l", plot_traces_by_qu_lock_sess),
        ("m", plot_mag_change),
        ("a", plot_autocorr),
    )
    for code, plot_fct in plot_fcts:
        if analysis == code:
            plot_fct(figpar=figpar, savedir=savedir, **info)
            break
    else:
        # no code matched
        warnings.warn(f"No plotting function for analysis {analysis}", 
            category=UserWarning, stacklevel=1)

    plot_util.cond_close_figs()
def plot_summ(output, savename, stimtype="gabors", comp="unexp", ctrl=False, 
              visflow_dir="both", fluor="dff", scale=True, CI=0.95, 
              plt_bkend=None, fontdir=None, modif=False):
    """
    plot_summ(output)

    Plots summary data for a specific comparison, for each datatype in a 
    separate figure and saves figures (as .svg) under 
    Path(output, f"figures_{fluor}").

    Nothing is plotted, and a RuntimeWarning is issued instead, if the 
    summary file is missing or empty, if no rows match the requested 
    criteria, or if the shuffled or non-shuffled rows are missing.

    Required args:
        - output (str)  : general directory in which summary dataframe 
                          is saved (runtype and q1v4 values are inferred from 
                          the directory name)
        - savename (str): name of the dataframe containing summary data to plot

    Optional args:
        - stimtype (str)   : stimulus type
                             default: "gabors"
        - comp (str)       : type of comparison
                             default: "unexp"
        - ctrl (bool)      : if True, control comparisons are analysed
                             default: False                           
        - visflow_dir (str): visual flow direction (ignored, i.e. set to 
                             "none", if stimtype is "gabors")
                             default: "both"
        - fluor (str)      : fluorescence trace type
                             default: "dff"
        - scale (bool)     : whether ROIs are scaled
                             default: True
        - CI (num)         : CI for shuffled data
                             default: 0.95
        - plt_bkend (str)  : mpl backend to use for plotting (e.g., "agg")
                             default: None
        - fontdir (str)    : directory in which additional fonts are located
                             default: None
        - modif (bool)     : if True, plots are made in a modified (simplified) 
                             way
                             default: False
    """

    plot_util.manage_mpl(plt_bkend, fontdir=fontdir)

    summ_scores_file = Path(output, savename)

    if not summ_scores_file.is_file():
        warnings.warn(f"{summ_scores_file} not found.", 
            category=RuntimeWarning, stacklevel=1)
        return
    summ_scores = file_util.loadfile(summ_scores_file)

    if len(summ_scores) == 0:
        warnings.warn(f"No data in {summ_scores_file}.", 
            category=RuntimeWarning, stacklevel=1)
        return

    # drop NaN lines
    summ_scores = summ_scores.loc[~summ_scores["epoch_n_mean"].isna()]

    data_types  = ["epoch_n", "test_acc", "test_acc_bal"]
    data_titles = ["Epoch nbrs", "Test accuracy", "Test accuracy (balanced)"]

    stats = ["mean", "sem", "sem"]
    shuff_stats = ["median"] + math_util.get_percentiles(CI)[1]

    # run information is encoded in the output directory name 
    # ("_q1v4" takes precedence over "_evu")
    output_str = str(output)
    q1v4 = "_q1v4" in output_str
    evu = (not q1v4) and ("_evu" in output_str)
    runtype = "pilot" if "pilot" in output_str else "prod"

    if stimtype == "gabors":
        visflow_dir = "none"
        stim_str = "gab"
        stim_str_pr = "gabors"
    else:
        visflow_dir = sess_gen_util.get_params(stimtype, visflow_dir)[0]
        if isinstance(visflow_dir, list) and len(visflow_dir) == 2:
            visflow_dir = "both"
        stim_str = sess_str_util.dir_par_str(visflow_dir, str_type="file")
        stim_str_pr = sess_str_util.dir_par_str(visflow_dir, str_type="print")

    scale_str = sess_str_util.scale_par_str(scale, "file")
    # NOTE(review): unlike the other *_pr strings, this one is derived from 
    # the "file" string type with underscores replaced; confirm intentional.
    scale_str_pr = sess_str_util.scale_par_str(scale, "file").replace("_", " ")
    ctrl_str = sess_str_util.ctrl_par_str(ctrl)
    ctrl_str_pr = sess_str_util.ctrl_par_str(ctrl, str_type="print")
    modif_str = "_modif" if modif else ""

    save_dir = Path(output, f"figures_{fluor}")
    save_dir.mkdir(exist_ok=True)

    cols = ["scale", "fluor", "visflow_dir", "runtype"]
    cri  = [scale, fluor, visflow_dir, runtype]
    plot_lines = gen_util.get_df_vals(summ_scores, cols, cri)
    cri_str = ", ".join([f"{col}: {crit}" for col, crit in zip(cols, cri)])
    if len(plot_lines) == 0: # no data
        warnings.warn(f"No data found for {cri_str}", category=RuntimeWarning, 
            stacklevel=1)
        return

    # both non-shuffled and shuffled data are required: warn about each 
    # missing one, then stop if any is missing
    shuffle_vals = plot_lines["shuffle"].tolist()
    missing_shuffs = [
        shuff for shuff in [False, True] if shuff not in shuffle_vals
        ]
    for shuff in missing_shuffs:
        warnings.warn(f"No shuffle={shuff} data found for {cri_str}", 
            category=RuntimeWarning, stacklevel=1)
    if missing_shuffs:
        return

    for data, data_title in zip(data_types, data_titles):
        if modif:
            title = (f"{stim_str_pr.capitalize()} {comp}{ctrl_str_pr}\n"
                f"{data_title}")
        else:
            title = (f"{stim_str_pr.capitalize()} {comp}{ctrl_str_pr} - "
                f"{data_title} for log regr on\n"
                f"{scale_str_pr} {fluor} data ({runtype})")
        title = title.replace("_", " ")

        # distinct name, so the `savename` argument is not overwritten
        fig_savename = (f"{data}_{stim_str}_{comp}{ctrl_str}{scale_str}"
            f"{modif_str}.svg")
        full_savename = Path(save_dir, fig_savename)
        plot_data_summ(plot_lines, data, stats, shuff_stats, title, 
            full_savename, CI, q1v4, evu, comp, modif)

    plot_util.cond_close_figs()
# Example #8
# 0
def run_single_panel(args, sessions=None, new_fig=False):
    """
    run_single_panel(args, sessions=None, new_fig=False)

    Runs analyses and plots a single panel.

    If helper_fcts.check_if_data_exists() reports that the analysis does not 
    need to be rerun, nothing else is done here (presumably the helper 
    replots existing data). If args.plot_only is set and no analysis data is 
    found, the panel is skipped with a warning. In both cases, the sessions 
    passed in are returned unchanged, so that callers chaining preloaded 
    sessions across panels do not lose them.

    Required args:
        - args (argparse.Namespace?): 
            parser arguments (accessed as attributes, e.g. args.plot_only)

    Optional args:
        - sessions (list):
            preloaded Session objects
            default: None
        - new_fig (bool):
            if True, a new figure is being plotted 
            (additional separator is logged)
            default: False

    Returns:
        - sessions (list):
            loaded Session objects, or the `sessions` passed in (possibly 
            None) if no analysis was run
    """

    analysis_dict = init_analysis(args)
    fig_panel_analysis = analysis_dict["figpar"]["fig_panel_analysis"]

    # changes backend and defaults
    plot_util.manage_mpl(cmap=False, **analysis_dict["figpar"]["mng"])
    sess_plot_util.update_plt_linpla()

    if args.plot_only:
        action = "Producing plot"
    else:
        action = "Running analysis and producing plot"

    sep = DOUBLE_SEP if new_fig else SEP
    logger.info(
        f"{sep}Fig. {fig_panel_analysis.figure}{fig_panel_analysis.panel}. "
        f"{action}: {fig_panel_analysis.description}",
        extra={"spacing": "\n"})

    # Log any relevant warnings to the console
    fig_panel_analysis.log_warnings(plot_only=args.plot_only)

    # Check if analysis needs to be rerun, and if not, replots only.
    run_analysis, data_path = helper_fcts.check_if_data_exists(
        analysis_dict["figpar"], overwrite_plot_only=args.plot_only,
        raise_no_data=False
        )

    if not run_analysis:
        # return the preloaded sessions (not None), so they can be reused
        return sessions
    if args.plot_only:
        logger.warning(
            f"Skipping plot, as no analysis data was found under {data_path}.",
            extra={"spacing": "\n"})
        return sessions

    sessions = init_sessions(
        analyspar=analysis_dict["analyspar"],
        sesspar=analysis_dict["sesspar"],
        mouse_df=args.mouse_df_path,
        datadir=args.datadir,
        sessions=sessions,
        roi=fig_panel_analysis.specific_params["roi"],
        run=fig_panel_analysis.specific_params["run"],
        pupil=fig_panel_analysis.specific_params["pupil"],
        parallel=args.parallel,
    )

    analysis_dict["seed"] = fig_panel_analysis.seed
    # parallelization is disabled in debug mode
    analysis_dict["parallel"] = bool(args.parallel and not args.debug)

    analysis_fct = fig_panel_analysis.analysis_fct
    analysis_dict_use = gen_util.keep_dict_keys(
        analysis_dict,
        inspect.getfullargspec(analysis_fct).args)

    analysis_fct(sessions=sessions, **analysis_dict_use)

    return sessions