def plot_from_dict(direc, plt_bkend=None, fontdir=None):
    """
    plot_from_dict(direc)

    Loads the "hyperparameters.json" file located in a directory and, if it 
    contains a "logregpar" entry, plots traces and scores from it.

    Required args:
        - direc (Path): path to the directory containing the 
                        "hyperparameters.json" file (also passed to the 
                        plotting function as the save directory)
    
    Optional args:
        - plt_bkend (str): mpl backend to use for plotting (e.g., "agg")
                           default: None
        - fontdir (Path) : directory in which additional fonts are stored
                           default: None
    """

    # logged before conversion, so the path appears as the caller passed it
    logger.info(
        f"Plotting from hyperparameters in: {direc}", extra={"spacing": "\n"}
    )

    source_dir = Path(direc)

    plot_util.manage_mpl(plt_bkend, fontdir=fontdir)
    hyperpars = file_util.loadfile("hyperparameters.json", fulldir=source_dir)

    # only logistic regression hyperparameters have a plotting function here
    if "logregpar" in hyperpars:
        plot_traces_scores(hyperpars, savedir=source_dir)

    plot_util.cond_close_figs()
def plot_from_dict(dict_path, plt_bkend=None, fontdir=None, parallel=False, 
                   datetime=True, overwrite=False):
    """
    plot_from_dict(dict_path)

    Loads a dictionary of analysis parameters and results, and calls the 
    plotting function matching the analysis code stored under 
    info["extrapar"]["analysis"]. The directory containing the dictionary 
    file is passed to the plotting function as the save directory.

    Required args:
        - dict_path (Path): path to dictionary to plot data from
    
    Optional args:
        - plt_bkend (str) : mpl backend to use for plotting (e.g., "agg")
                            default: None
        - fontdir (Path)  : path to directory where additional fonts are stored
                            default: None
        - parallel (bool) : not used in the body of this function
                            default: False
        - datetime (bool) : figpar["save"] datetime parameter (whether to 
                            place figures in a datetime folder)
                            default: True
        - overwrite (bool): figpar["save"] overwrite parameter (whether to 
                            overwrite figures)
                            default: False
    """

    logger.info(
        f"Plotting from dictionary: {dict_path}", extra={"spacing": "\n"}
    )

    figpar = sess_plot_util.init_figpar(
        plt_bkend=plt_bkend, 
        fontdir=fontdir, 
        datetime=datetime, 
        overwrite=overwrite,
    )
    plot_util.manage_mpl(cmap=False, **figpar["mng"])

    dict_path = Path(dict_path)
    info = file_util.loadfile(dict_path)
    analysis = info["extrapar"]["analysis"]

    # analysis code -> plotting function
    plot_fcts = (
        # correlation between pupil and roi/run changes for each session
        ("c", plot_pup_diff_corr),
        # difference correlation per ROI between stimuli
        ("r", plot_pup_roi_stim_corr),
    )
    plot_fct = next(
        (fct for code, fct in plot_fcts if analysis == code), None
    )

    if plot_fct is None:
        warnings.warn(f"No plotting function for analysis {analysis}", 
            category=UserWarning, stacklevel=1)
    else:
        plot_fct(figpar=figpar, savedir=dict_path.parent, **info)

    plot_util.cond_close_figs()
def run_analyses(sessions,
                 analysis_dict,
                 analyses,
                 datatype="roi",
                 seed=None,
                 parallel=False):
    """
    run_analyses(sessions, analysis_dict, analyses)

    Runs each requested analysis on the sessions, skipping analyses that do 
    not apply to the requested datatype. Each analysis function receives only 
    the parameters it declares in its signature.

    Required args:
        - sessions (list)     : list of sessions, possibly nested
        - analysis_dict (dict): analysis parameter dictionary, which must 
                                contain analysis_dict["figpar"]["mng"] 
                                (see init_param_cont())
        - analyses (str)      : analyses to run (iterated over item by item)
    
    Optional args:
        - datatype (str) : datatype ("run", "roi")
                           default: "roi"
        - seed (int)     : seed to use
                           default: None
        - parallel (bool): if True, some analyses are parallelized 
                           across CPU cores 
                           default: False

    Raises:
        - ValueError if a requested analysis has no entry in the analysis 
          function dictionary.
    """

    if len(sessions) == 0:
        logger.warning("No sessions meet these criteria.")
        return

    # changes backend and defaults
    plot_util.manage_mpl(cmap=False, **analysis_dict["figpar"]["mng"])
    sess_plot_util.update_plt_linpla()

    fct_dict = get_analysis_fcts()

    # copy, so that the caller's dictionary is left untouched
    args_dict = copy.deepcopy(analysis_dict)
    args_dict.update(seed=seed, parallel=parallel, datatype=datatype)

    # run through analyses
    for analysis in analyses:
        if analysis not in fct_dict:
            raise ValueError(f"{analysis} analysis not found.")

        fct, datatype_req = fct_dict[analysis]
        if datatype not in datatype_req:
            continue

        # retain only the arguments that the analysis function accepts
        fct_arg_names = inspect.getfullargspec(fct).args
        fct_kwargs = gen_util.keep_dict_keys(args_dict, fct_arg_names)
        fct(sessions=sessions, analysis=analysis, **fct_kwargs)

        plot_util.cond_close_figs()
示例#4
0
def run_analyses(sessions, analysis_dict, analyses, seed=None,
                 parallel=False, plot_tc=True):
    """
    run_analyses(sessions, analysis_dict, analyses)

    Runs each requested analysis on the sessions, skipping analyses whose 
    comparison requirement does not match the sessions provided (sessions 
    are treated as a comparison if the first item is a list). Each analysis 
    function receives only the parameters it declares in its signature.

    Required args:
        - sessions (list)     : list of sessions, possibly nested
        - analysis_dict (dict): analysis parameter dictionary, which must 
                                contain analysis_dict["figpar"]["mng"] 
                                (see init_param_cont())
        - analyses (str)      : analyses to run (iterated over item by item)

    Optional args:
        - seed (int)     : seed to use
                           default: None
        - parallel (bool): if True, some analyses are parallelized 
                           across CPU cores 
                           default: False
        - plot_tc (bool) : if True, tuning curves are plotted for each ROI  
                           default: True

    Raises:
        - ValueError if a requested analysis has no entry in the analysis 
          function dictionary.
    """

    if len(sessions) == 0:
        logger.warning("No sessions meet these criteria.")
        return

    # nested sessions indicate a comparison
    comp = isinstance(sessions[0], list)

    # changes backend and defaults
    plot_util.manage_mpl(cmap=False, **analysis_dict["figpar"]["mng"])

    fct_dict = get_analysis_fcts()

    # copy, so that the caller's dictionary is left untouched
    args_dict = copy.deepcopy(analysis_dict)
    args_dict.update(seed=seed, parallel=parallel, plot_tc=plot_tc)

    # run through analyses
    for analysis in analyses:
        if analysis not in fct_dict:
            raise ValueError(f"{analysis} analysis not found.")

        fct, comp_req = fct_dict[analysis]
        if comp_req != comp:
            continue

        # retain only the arguments that the analysis function accepts
        fct_arg_names = inspect.getfullargspec(fct).args
        fct_kwargs = gen_util.keep_dict_keys(args_dict, fct_arg_names)
        fct(sessions=sessions, analysis=analysis, **fct_kwargs)

        plot_util.cond_close_figs()
示例#5
0
def plot_from_dict(dict_path,
                   plt_bkend=None,
                   fontdir=None,
                   plot_tc=True,
                   parallel=False,
                   datetime=True,
                   overwrite=False):
    """
    plot_from_dict(dict_path)

    Loads a dictionary of analysis parameters and results, and calls the 
    matching plotting function from gen_plots with modif=True. The plotting 
    function is selected using the analysis code stored under 
    info["extrapar"]["analysis"], and the directory containing the 
    dictionary file is passed to it as the save directory.

    Required args:
        - dict_path (Path): path to dictionary to plot data from
    
    Optional args:
        - plt_bkend (str) : mpl backend to use for plotting (e.g., "agg")
                            default: None
        - fontdir (Path)  : path to directory where additional fonts are stored
                            default: None
        - plot_tc (bool)  : dummy argument (not used in the body of this 
                            function)
                            default: True
        - parallel (bool) : not used in the body of this function
                            default: False
        - datetime (bool) : figpar["save"] datetime parameter (whether to 
                            place figures in a datetime folder)
                            default: True
        - overwrite (bool): figpar["save"] overwrite parameter (whether to 
                            overwrite figures)
                            default: False
    """

    logger.info(
        f"Plotting from dictionary: {dict_path}", extra={"spacing": "\n"}
    )

    figpar = sess_plot_util.init_figpar(
        plt_bkend=plt_bkend,
        fontdir=fontdir,
        datetime=datetime,
        overwrite=overwrite,
    )
    plot_util.manage_mpl(cmap=False, **figpar["mng"])

    # larger titles for the modified plots
    for rc_key in ["figure.titlesize", "axes.titlesize"]:
        plt.rcParams[rc_key] = "xx-large"

    dict_path = Path(dict_path)
    info = file_util.loadfile(dict_path)
    analysis = info["extrapar"]["analysis"]

    # analysis code -> plotting function
    modif_plot_fcts = (
        # 1. average traces by quantile x unexpected for each session
        ("t", gen_plots.plot_traces_by_qu_unexp_sess),
        # 2. average traces by quantile, locked to unexpected for each session
        ("l", gen_plots.plot_traces_by_qu_lock_sess),
    )

    for code, plot_fct in modif_plot_fcts:
        if analysis == code:
            plot_fct(
                figpar=figpar, savedir=dict_path.parent, modif=True, **info
            )
            break
    else:
        # no code matched
        warnings.warn(f"No modified plotting option for analysis {analysis}",
                      category=UserWarning,
                      stacklevel=1)

    plot_util.cond_close_figs()
def plot_from_dicts(direc,
                    source="roi",
                    plt_bkend=None,
                    fontdir=None,
                    plot_tc=True,
                    parallel=False,
                    datetime=True,
                    overwrite=False,
                    pattern="",
                    depth=0):
    """
    plot_from_dicts(direc)

    Identifies the json dictionaries to plot from (either a single json file, 
    or all json files found in a directory, down to a certain depth), and 
    calls the plot_from_dict() function for the requested source on each.

    For the "logreg" source, only files named "hyperparameters.json" are 
    considered, and the directory containing each file (not the file itself) 
    is passed to the plotting function.

    Required args:
        - direc (Path): path to directory in which dictionaries to plot data 
                        from are located or path to a single json file
    
    Optional args:
        - source (str)    : plotting source ("roi", "run", "gen", "pup", 
                            "pupil", "modif", "logreg", "glm", "acr_sess")
                            default: "roi"
        - plt_bkend (str) : mpl backend to use for plotting (e.g., "agg")
                            default: None
        - fontdir (Path)  : directory in which additional fonts are stored
                            default: None
        - plot_tc (bool)  : if True, tuning curves are plotted for each ROI 
                            default: True
        - parallel (bool) : if True, some of the analysis is parallelized 
                            across CPU cores
                            default: False
        - datetime (bool) : figpar["save"] datetime parameter (whether to 
                            place figures in a datetime folder)
                            default: True
        - overwrite (bool): figpar["save"] overwrite parameter (whether to 
                            overwrite figures)
                            default: False
        - pattern (str)   : pattern based on which to include json files in 
                            direc if direc is a directory (substring of the 
                            full path)
                            default: ""
        - depth (int)     : maximum depth at which to check for json files if 
                            direc is a directory
                            default: 0

    Raises:
        - OSError if direc is a directory, and no matching json is found.
        - ValueError if direc is a file, but not a json file, or not a 
          "hyperparameters.json" file for the "logreg" source.
    """

    file_util.checkexists(direc)

    direc = Path(direc)
    is_logreg = (source == "logreg")

    if direc.is_dir():
        targ_file = "hyperparameters.json" if is_logreg else "*.json"

        # collect matching files at each depth, from 0 to depth
        dict_paths = []
        for d in range(depth + 1):
            dict_paths.extend(
                glob.glob(str(Path(direc, *(["*"] * d), targ_file))))

        dict_paths = [dp for dp in dict_paths if pattern in str(dp)]

        # logreg plotting expects the directory containing the hyperparameters
        if is_logreg:
            dict_paths = [Path(dp).parent for dp in dict_paths]

        if len(dict_paths) == 0:
            raise OSError(f"No jsons found in {direc} at "
                          f"depth {depth} with pattern '{pattern}'.")

    elif direc.suffix != ".json":
        # checks the actual file extension, not just for a ".json" substring
        raise ValueError("If providing a file, must be a json file.")

    elif is_logreg:
        if direc.name != "hyperparameters.json":
            raise ValueError("For logreg source, must provide path to "
                             "a hyperparameters json file.")

        # as in the directory case, pass the containing directory, since 
        # logreg plotting loads "hyperparameters.json" from a directory
        dict_paths = [direc.parent]

    else:
        dict_paths = [direc]

    dict_paths = sorted(dict_paths)
    if len(dict_paths) > 1:
        logger.info(f"Plotting from {len(dict_paths)} dictionaries.")

    fontdir = Path(fontdir) if fontdir is not None else fontdir
    args_dict = {
        "plt_bkend": plt_bkend,
        "fontdir": fontdir,
        "plot_tc": plot_tc,
        "datetime": datetime,
        "overwrite": overwrite,
    }

    # logreg and glm plotting functions do not take a parallel argument
    pass_parallel = True
    sources = [
        "roi", "run", "gen", "modif", "pup", "pupil", "logreg", "glm", 
        "acr_sess"
    ]
    if source == "roi":
        fct = roi_plots.plot_from_dict
    elif source in ["run", "gen"]:
        fct = gen_plots.plot_from_dict
    elif source in ["pup", "pupil"]:
        fct = pup_plots.plot_from_dict
    elif source == "modif":
        fct = mod_plots.plot_from_dict
    elif source == "logreg":
        pass_parallel = False
        fct = logreg_plots.plot_from_dict
    elif source == "glm":
        pass_parallel = False
        fct = glm_plots.plot_from_dict
    elif source == "acr_sess":
        fct = acr_sess_plots.plot_from_dict
    else:
        gen_util.accepted_values_error("source", source, sources)

    # retain only the arguments that the plotting function accepts
    args_dict = gen_util.keep_dict_keys(args_dict,
                                        inspect.getfullargspec(fct).args)
    gen_util.parallel_wrap(fct,
                           dict_paths,
                           args_dict=args_dict,
                           parallel=parallel,
                           pass_parallel=pass_parallel)

    plot_util.cond_close_figs()
示例#7
0
def plot_from_dict(dict_path, plt_bkend=None, fontdir=None, parallel=False, 
                   datetime=True, overwrite=False):
    """
    plot_from_dict(dict_path)

    Loads a dictionary of analysis parameters and results, and calls the 
    plotting function matching the analysis code stored under 
    info["extrapar"]["analysis"]. The directory containing the dictionary 
    file is passed to the plotting function as the save directory.

    Required args:
        - dict_path (Path): path to dictionary to plot data from
    
    Optional args:
        - plt_bkend (str) : mpl backend to use for plotting (e.g., "agg")
                            default: None
        - fontdir (Path)  : path to directory where additional fonts are stored
                            default: None
        - parallel (bool) : not used in the body of this function
                            default: False
        - datetime (bool) : figpar["save"] datetime parameter (whether to 
                            place figures in a datetime folder)
                            default: True
        - overwrite (bool): figpar["save"] overwrite parameter (whether to 
                            overwrite figures)
                            default: False
    """

    logger.info(
        f"Plotting from dictionary: {dict_path}", extra={"spacing": "\n"}
    )

    figpar = sess_plot_util.init_figpar(
        plt_bkend=plt_bkend, 
        fontdir=fontdir, 
        datetime=datetime, 
        overwrite=overwrite,
    )
    plot_util.manage_mpl(cmap=False, **figpar["mng"])

    dict_path = Path(dict_path)
    info = file_util.loadfile(dict_path)
    analysis = info["extrapar"]["analysis"]

    # analysis code -> plotting function
    plot_fcts = (
        # 0. full traces for each session
        ("f", plot_full_traces),
        # 1. average traces by quantile x unexpected for each session
        ("t", plot_traces_by_qu_unexp_sess),
        # 2. average traces by quantile, locked to unexpected for each session
        ("l", plot_traces_by_qu_lock_sess),
        # 3. magnitude of change in dF/F area from first to last quantile of 
        # unexpected vs no unexpected sequences, for each session
        ("m", plot_mag_change),
        # 4. autocorrelations
        ("a", plot_autocorr),
    )
    plot_fct = next(
        (fct for code, fct in plot_fcts if analysis == code), None
    )

    if plot_fct is None:
        warnings.warn(f"No plotting function for analysis {analysis}", 
            category=UserWarning, stacklevel=1)
    else:
        plot_fct(figpar=figpar, savedir=dict_path.parent, **info)

    plot_util.cond_close_figs()
def run_regr(args):
    """
    run_regr(args)

    Does runs of a logistic regressions on the specified comparison and range
    of sessions. The parameters are gathered from args (a copy is modified, 
    not the object passed), and logreg.run_regr() is called once for each 
    session identified.
    
    Required args:
        - args (Argument parser): parser with analysis parameters as attributes:
            alg (str)             : algorithm to use ("sklearn" or "pytorch")
            bal (bool)            : if True, classes are balanced
            batchsize (int)       : nbr of samples dataloader will load per 
                                    batch (for "pytorch" alg)
            comp (str)            : type of comparison
            datadir (str)         : data directory (if None, DEFAULT_DATADIR 
                                    is used)
            dend (str)            : type of dendrites to use ("allen" or "dend")
            device (str)          : device name (i.e., "cuda" or "cpu")
            ep_freq (int)         : frequency at which to log loss to 
                                    console
            error (str)           : error to take, i.e., "std" (for std 
                                    or quantiles) or "sem" (for SEM or MAD)
            exp_v_unexp (bool)    : if True, analysis is trained on 
                                    expected and tested on unexpected sequences
            fluor (str)           : fluorescence trace type
            fontdir (str)         : directory in which additional fonts are 
                                    located
            gabfr (int)           : gabor frame of reference if comparison 
                                    is "unexp"
            gabk (int or list)    : gabor kappas to include
            gab_ori (list or str) : gabor orientations to include
            incl (str or list)    : sessions to include ("yes", "no", "all")
            lr (num)              : model learning rate (for "pytorch" alg)
            mouse_n (int)         : mouse number
            n_epochs (int)        : number of epochs
            n_reg (int)           : number of regular runs
            n_shuff (int)         : number of shuffled runs
            no_scale (bool)       : negated to obtain the scale parameter 
                                    (i.e., if True, scale is set to False)
            not_ctrl (bool)       : negated, and passed as the second argument 
                                    to init_logregpar() (presumably the 
                                    control setting - to be confirmed)
            output (str)          : general directory in which to save 
                                    output
            parallel (bool)       : if True, runs are done in parallel
            plt_bkend (str)       : pyplot backend to use
            q1v4 (bool)           : if True, analysis is trained on first and 
                                    tested on last quartiles
            runtype (str)         : type of run ("prod" or "pilot")
            seed (int)            : seed to seed random processes with (if 
                                    None or "None", the reseed parameter is 
                                    set to True)
            sess_n (int)          : session number
            stats (str)           : stats to take, i.e., "mean" or "median"
            stimtype (str)        : stim to analyse ("gabors" or "visflow")
            train_p (list)        : proportion of dataset to allocate to 
                                    training
            uniqueid (str or int) : unique ID for analysis ("datetime" is 
                                    replaced with a time string, and "None" 
                                    or "none" with None)
            visflow_dir (str)     : visual flow direction to analyse
            visflow_pre (float)   : number of seconds to include before visual 
                                    flow segments
            visflow_size (int or list): visual flow square sizes to include
            wd (float)            : weight decay value (for "pytorch" arg)
    """

    # work on a copy, so that the object passed is not modified
    args = copy.deepcopy(args)

    args.datadir = (
        DEFAULT_DATADIR if args.datadir is None else Path(args.datadir)
    )

    if args.uniqueid == "datetime":
        args.uniqueid = gen_util.create_time_str()
    elif args.uniqueid in ["None", "none"]:
        args.uniqueid = None

    # reseed is set if no seed is provided
    reseed = args.seed in [None, "None"]

    # deal with parameters
    extrapar = {"uniqueid": args.uniqueid, "seed": args.seed}

    techpar_attrs = [
        "device", "alg", "parallel", "plt_bkend", "fontdir", "output", 
        "ep_freq", "n_reg", "n_shuff",
    ]
    techpar = {"reseed": reseed}
    for attr in techpar_attrs:
        techpar[attr] = getattr(args, attr)

    mouse_df = DEFAULT_MOUSE_DF_PATH

    stimpar = logreg.get_stimpar(
        args.comp,
        args.stimtype,
        args.visflow_dir,
        args.visflow_size,
        args.gabfr,
        args.gabk,
        gab_ori=args.gab_ori,
        visflow_pre=args.visflow_pre,
    )

    analyspar = sess_ntuple_util.init_analyspar(
        args.fluor,
        stats=args.stats,
        error=args.error,
        scale=not args.no_scale,
        dend=args.dend,
    )

    # first and last of 4 quantiles if q1v4, otherwise a single quantile
    quant_args = (4, [0, -1]) if args.q1v4 else (1, 0)
    quantpar = sess_ntuple_util.init_quantpar(*quant_args)

    logregpar = sess_ntuple_util.init_logregpar(
        args.comp, 
        not args.not_ctrl,
        args.q1v4, 
        args.exp_v_unexp,
        args.n_epochs, 
        args.batchsize,
        args.lr, 
        args.train_p, 
        args.wd,
        args.bal, 
        args.alg,
    )

    omit_sess, omit_mice = sess_gen_util.all_omit(
        stimpar.stimtype,
        args.runtype,
        stimpar.visflow_dir,
        stimpar.visflow_size,
        stimpar.gabk,
    )

    sessids = sess_gen_util.get_sess_vals(
        mouse_df,
        "sessid",
        args.mouse_n,
        args.sess_n,
        args.runtype,
        incl=args.incl,
        omit_sess=omit_sess,
        omit_mice=omit_mice,
    )
    sessids = sorted(sessids)

    if len(sessids) == 0:
        logger.warning(
            f"No sessions found (mouse: {args.mouse_n}, sess: {args.sess_n}, "
            f"runtype: {args.runtype})")

    for sessid in sessids:
        # the first item returned is the session to run
        sess = sess_gen_util.init_sessions(
            sessid,
            args.datadir,
            mouse_df,
            args.runtype,
            full_table=False,
            fluor=analyspar.fluor,
            dend=analyspar.dend,
            temp_log="warning",
        )[0]

        logreg.run_regr(
            sess, analyspar, stimpar, logregpar, quantpar, extrapar, techpar
        )

        plot_util.cond_close_figs()
def plot_summ(output, savename, stimtype="gabors", comp="unexp", ctrl=False, 
              visflow_dir="both", fluor="dff", scale=True, CI=0.95, 
              plt_bkend=None, fontdir=None, modif=False):
    """
    plot_summ(output, savename)

    Plots summary data for a specific comparison, for each datatype 
    ("epoch_n", "test_acc", "test_acc_bal") in a separate figure, and saves 
    figures as svg files in a "figures_{fluor}" directory under output.

    Returns early, with a RuntimeWarning, if the summary file is missing or 
    empty, if no lines match the criteria, or if either the shuffled or the 
    non shuffled data is missing.

    Required args:
        - output (str)  : general directory in which summary dataframe 
                          is saved (runtype, as well as q1v4 and evu values 
                          are inferred from the directory name)
        - savename (str): name of the dataframe containing summary data to plot
            
    Optional args:
        - stimtype (str)   : stimulus type
                             default: "gabors"
        - comp (str)       : type of comparison
                             default: "unexp"
        - ctrl (bool)      : if True, control comparisons are analysed
                             default: False                           
        - visflow_dir (str): visual flow direction (set to "none" if 
                             stimtype is "gabors")
                             default: "both"
        - fluor (str)      : fluorescence trace type
                             default: "dff"
        - scale (bool)     : whether ROIs are scaled
                             default: True
        - CI (num)         : CI for shuffled data
                             default: 0.95
        - plt_bkend (str)  : mpl backend to use for plotting (e.g., "agg")
                             default: None
        - fontdir (str)    : directory in which additional fonts are located
                             default: None
        - modif (bool)     : if True, plots are made in a modified (simplified 
                             way)
                             default: False
    """
    
    plot_util.manage_mpl(plt_bkend, fontdir=fontdir)

    summ_scores_file = Path(output, savename)

    if not summ_scores_file.is_file():
        warnings.warn(f"{summ_scores_file} not found.", 
            category=RuntimeWarning, stacklevel=1)
        return

    summ_scores = file_util.loadfile(summ_scores_file)
    if len(summ_scores) == 0:
        warnings.warn(f"No data in {summ_scores_file}.", 
            category=RuntimeWarning, stacklevel=1)
        return

    # drop NaN lines
    nan_lines = summ_scores["epoch_n_mean"].isna()
    summ_scores = summ_scores.loc[~nan_lines]

    # (data type, title) pairs, one per figure
    data_info = [
        ("epoch_n", "Epoch nbrs"), 
        ("test_acc", "Test accuracy"), 
        ("test_acc_bal", "Test accuracy (balanced)"),
    ]

    stats = ["mean", "sem", "sem"]
    shuff_stats = ["median"] + math_util.get_percentiles(CI)[1]

    # infer analysis type and runtype from the output directory name
    output_str = str(output)
    q1v4 = "_q1v4" in output_str
    evu = (not q1v4) and ("_evu" in output_str)
    runtype = "pilot" if "pilot" in output_str else "prod"

    if stimtype == "gabors":
        visflow_dir = "none"
        stim_str = "gab"
        stim_str_pr = "gabors"
    else:
        visflow_dir = sess_gen_util.get_params(stimtype, visflow_dir)[0]
        if isinstance(visflow_dir, list) and len(visflow_dir) == 2:
            visflow_dir = "both"
        stim_str = sess_str_util.dir_par_str(visflow_dir, str_type="file")
        stim_str_pr = sess_str_util.dir_par_str(visflow_dir, str_type="print")

    scale_str = sess_str_util.scale_par_str(scale, "file")
    scale_str_pr = sess_str_util.scale_par_str(scale, "file").replace("_", " ")
    ctrl_str = sess_str_util.ctrl_par_str(ctrl)
    ctrl_str_pr = sess_str_util.ctrl_par_str(ctrl, str_type="print")
    modif_str = "_modif" if modif else ""

    save_dir = Path(output, f"figures_{fluor}")
    save_dir.mkdir(exist_ok=True)
    
    # retrieve the lines that match the criteria
    cols = ["scale", "fluor", "visflow_dir", "runtype"]
    cri  = [scale, fluor, visflow_dir, runtype]
    plot_lines = gen_util.get_df_vals(summ_scores, cols, cri)
    cri_str = ", ".join([f"{col}: {crit}" for col, crit in zip(cols, cri)])

    if len(plot_lines) == 0: # no data
        warnings.warn(f"No data found for {cri_str}", category=RuntimeWarning, 
            stacklevel=1)
        return

    # shuffle or non shuffle missing
    shuff_vals = plot_lines["shuffle"].tolist()
    missing = [shuff for shuff in [False, True] if shuff not in shuff_vals]
    for shuff in missing:
        warnings.warn(f"No shuffle={shuff} data found for {cri_str}", 
            category=RuntimeWarning, stacklevel=1)
    if len(missing) != 0:
        return

    title_start = f"{stim_str_pr.capitalize()} {comp}{ctrl_str_pr}"
    for data, data_title in data_info:
        if modif:
            title = f"{title_start}\n{data_title}"
        else:
            title = (f"{title_start} - {data_title} for log regr on\n"
                f"{scale_str_pr} {fluor} data ({runtype})")

        title = title.replace("_", " ")

        fig_name = (f"{data}_{stim_str}_{comp}{ctrl_str}{scale_str}"
            f"{modif_str}.svg")
        plot_data_summ(plot_lines, data, stats, shuff_stats, title, 
            Path(save_dir, fig_name), CI, q1v4, evu, comp, modif)

    plot_util.cond_close_figs()
示例#10
0
def main(args):
    """
    main(args)

    Runs analyses and plotting for the requested figure(s) and panel(s), 
    using the parser arguments.

    Errors whose message contains "Cannot plot figure panel" are logged, 
    and the next panel is run. Any other error is re-raised.

    Required args:
        - args (object): 
            parser arguments, accessed as attributes. The following are 
            read: log_level, datadir, overwrite, plot_only, parallel, 
            figure ("all" for all figures), panel ("all" for all panels). 
            The following are set: datadir, mouse_df_path, fontdir, as well 
            as figure and panel (updated for each panel run).
    """

    # set logger to the specified level
    logger_util.set_level(level=args.log_level)

    args.datadir = (
        DEFAULT_DATADIR if args.datadir is None else Path(args.datadir)
    )
    args.mouse_df_path = DEFAULT_MOUSE_DF_PATH

    # Directory with additional fonts
    args.fontdir = DEFAULT_FONTDIR if DEFAULT_FONTDIR.exists() else None

    # warn if parallel is not used
    if args.overwrite and not args.plot_only and not args.parallel:
        warnings.warn(
            "Unless memory demands are too high for the machine being "
            "used, it is strongly recommended that paper analyses be run "
            "with the '--parallel' argument (enables computations to be "
            "distributed across available CPU cores). Otherwise, analyses "
            "may be very slow.",
            category=UserWarning,
            stacklevel=1)
        # pause, leaving time for the warning to be read
        time.sleep(paper_organization.WARNING_SLEEP)

    # run through figure(s) and panel(s)
    figures = (
        paper_organization.get_all_figures() 
        if args.figure == "all" else [args.figure]
    )

    sessions = None
    panel = args.panel  # panel requested, as args.panel is overwritten below
    for figure in figures:
        args.figure = figure
        panels = (
            paper_organization.get_all_panels(args.figure) 
            if panel == "all" else [panel]
        )

        for p, curr_panel in enumerate(panels):
            args.panel = curr_panel
            new_fig = (p == 0)
            try:
                with gen_util.TimeIt():
                    # sessions returned are passed on to the next panel
                    sessions = run_single_panel(
                        args, sessions=sessions, new_fig=new_fig
                    )
            except Exception as err:
                if "Cannot plot figure panel" not in str(err):
                    raise err

                sep = DOUBLE_SEP if new_fig else SEP
                lead = f"{sep}Fig. {args.figure}{args.panel.upper()}"
                logger.info(f"{lead}. {err}")

            plot_util.cond_close_figs()