Beispiel #1
0
def plot_autocorr(analyspar, sesspar, stimpar, extrapar, autocorrpar, 
                  sess_info, autocorr_data, figpar=None, savedir=None):
    """
    plot_autocorr(analyspar, sesspar, stimpar, extrapar, autocorrpar, 
                  sess_info, autocorr_data)

    From dictionaries, plots autocorrelation during stimulus blocks, with each 
    session in a separate subplot, and saves the figure.

    Required args:
        - analyspar (dict)    : dictionary with keys of AnalysPar namedtuple
        - sesspar (dict)      : dictionary with keys of SessPar namedtuple 
        - stimpar (dict)      : dictionary with keys of StimPar namedtuple
        - extrapar (dict)     : dictionary containing additional analysis 
                                parameters
            ["analysis"] (str): analysis type (e.g., "a")
            ["datatype"] (str): datatype (e.g., "run", "roi")
        - autocorrpar (dict)  : dictionary with keys of AutocorrPar namedtuple
        - sess_info (dict)    : dictionary containing information from each
                                session 
            ["mouse_ns"] (list)   : mouse numbers
            ["sess_ns"] (list)    : session numbers  
            ["lines"] (list)      : mouse lines
            ["planes"] (list)     : imaging planes
            ["nrois"] (list)      : number of ROIs in session

        - autocorr_data (dict): dictionary containing data to plot:
            ["xrans"] (list): list of lag values in seconds for each session
            ["stats"] (list): list of 3D arrays (or nested lists) of
                              autocorrelation statistics, structured as:
                                     sessions stats (me, err) 
                                     x ROI or 1x and 10x lag 
                                     x lag
    
    Optional args:
        - figpar (dict): dictionary containing the following figure parameter 
                         dictionaries
                         default: None
            ["init"] (dict): dictionary with figure initialization parameters
            ["save"] (dict): dictionary with figure saving parameters
            ["dirs"] (dict): dictionary with additional figure parameters
        - savedir (str): path of directory in which to save plots.
                         default: None
    
    Returns:
        - fulldir (str) : final path of the directory in which the figure is 
                          saved (may differ from input savedir, if datetime 
                          subfolder is added.)
        - savename (str): name under which the figure is saved

    Raises:
        - ValueError: if extrapar["datatype"] is neither "roi" nor "run".
    """

    # strings for titles ("print" versions)
    statstr_pr = sess_str_util.stat_par_str(
        analyspar["stats"], analyspar["error"], "print")
    stimstr_pr = sess_str_util.stim_par_str(
        stimpar["stimtype"], stimpar["visflow_dir"], stimpar["visflow_size"], 
        stimpar["gabk"], "print")
    dendstr_pr = sess_str_util.dend_par_str(
        analyspar["dend"], sesspar["plane"], extrapar["datatype"], "print")

    # strings for the save name
    sessstr = sess_str_util.sess_par_str(
        sesspar["sess_n"], stimpar["stimtype"], sesspar["plane"], 
        stimpar["visflow_dir"], stimpar["visflow_size"], stimpar["gabk"]) 
    dendstr = sess_str_util.dend_par_str(
        analyspar["dend"], sesspar["plane"], extrapar["datatype"])
     
    datatype = extrapar["datatype"]
    if datatype == "roi":
        fluorstr_pr = sess_str_util.fluor_par_str(
            analyspar["fluor"], str_type="print")
        if autocorrpar["byitem"]:
            title_str = u"{}\nautocorrelation".format(fluorstr_pr)
        else:
            # the fluorescence string is included here as well (the 
            # placeholder was previously missing, so it was silently dropped)
            title_str = u"{}\nautocorr. acr. ROIs".format(fluorstr_pr)

    elif datatype == "run":
        datastr = sess_str_util.datatype_par_str(datatype)
        title_str = u"\n{} autocorrelation".format(datastr)

    else:
        # fail early and explicitly, instead of with a NameError on title_str
        raise ValueError(
            f"datatype must be 'roi' or 'run', but found '{datatype}'.")

    if stimpar["stimtype"] == "gabors":
        seq_bars = [-1.5, 1.5] # light lines
    else:
        seq_bars = [-1.0, 1.0] # light lines

    # extract some info from sess_info
    keys = ["mouse_ns", "sess_ns", "lines", "planes"]
    [mouse_ns, sess_ns, lines, planes] = [sess_info[key] for key in keys]
    nroi_strs = sess_str_util.get_nroi_strs(sess_info, empty=(datatype!="roi")) 

    n_sess = len(mouse_ns)

    xrans = autocorr_data["xrans"]
    stats = [np.asarray(stat) for stat in autocorr_data["stats"]]

    lag_s = autocorrpar["lag_s"]
    # np.linspace() requires an integer number of samples, but lag_s may be 
    # stored as a float
    xticks = np.linspace(-lag_s, lag_s, int(lag_s * 2 + 1))
    yticks = np.linspace(0, 1, 6)

    if figpar is None:
        figpar = sess_plot_util.init_figpar()

    byitemstr = ""
    if autocorrpar["byitem"]:
        byitemstr = "_byroi"

    fig, ax = plot_util.init_fig(n_sess, **figpar["init"])
    for i in range(n_sess):
        sub_ax = plot_util.get_subax(ax, i)
        title = (f"Mouse {mouse_ns[i]} - {stimstr_pr}, "
            u"{} ".format(statstr_pr) + f"{title_str} (sess "
            f"{sess_ns[i]}, {lines[i]} {planes[i]}{dendstr_pr}{nroi_strs[i]})")
        # transpose to ROI/lag x stats x series
        sess_stats = stats[i].transpose(1, 0, 2) 
        for s, sub_stats in enumerate(sess_stats):
            lab = None
            if not autocorrpar["byitem"]:
                lab = ["actual lag", "10x lag"][s]

            plot_util.plot_traces(
                sub_ax, xrans[i], sub_stats[0], sub_stats[1:], xticks=xticks, 
                yticks=yticks, alpha=0.2, label=lab)

        plot_util.add_bars(sub_ax, hbars=seq_bars)
        sub_ax.set_ylim([0, 1])
        sub_ax.set_title(title, y=1.02)

        # Axes.is_last_row() is not available in recent matplotlib releases, 
        # where the subplot spec must be queried instead
        if hasattr(sub_ax, "is_last_row"):
            subplot_pos = sub_ax
        else:
            subplot_pos = sub_ax.get_subplotspec()
        if subplot_pos.is_last_row():
            sub_ax.set_xlabel("Lag (s)")

    plot_util.turn_off_extra(ax, n_sess)

    if savedir is None:
        savedir = Path(
            figpar["dirs"][datatype], 
            figpar["dirs"]["autocorr"])

    savename = (f"{datatype}_autocorr{byitemstr}_{sessstr}{dendstr}")

    fulldir = plot_util.savefig(fig, savename, savedir, **figpar["save"])

    return fulldir, savename
Beispiel #2
0
def plot_traces_by_qu_lock_sess(analyspar, sesspar, stimpar, extrapar, 
                                quantpar, sess_info, trace_stats, 
                                figpar=None, savedir=None, modif=False):
    """
    plot_traces_by_qu_lock_sess(analyspar, sesspar, stimpar, extrapar, 
                                quantpar, sess_info, trace_stats)

    From dictionaries, plots traces by quantile, locked to transitions from 
    unexpected to expected or v.v. with each session in a separate subplot.
    
    Returns figure name and save directory path.

    Note that analyspar["dend"] is overridden with None within this function 
    (on a copy: the input dictionary is left untouched).
    
    Required args:
        - analyspar (dict)  : dictionary with keys of AnalysPar namedtuple
        - sesspar (dict)    : dictionary with keys of SessPar namedtuple
        - stimpar (dict)    : dictionary with keys of StimPar namedtuple
        - extrapar (dict)   : dictionary containing additional analysis 
                              parameters
            ["analysis"] (str): analysis type (e.g., "l")
            ["datatype"] (str): datatype (e.g., "run", "roi")
        - quantpar (dict)   : dictionary with keys of QuantPar namedtuple
        - sess_info (dict)  : dictionary containing information from each
                              session 
            ["mouse_ns"] (list)   : mouse numbers
            ["sess_ns"] (list)    : session numbers  
            ["lines"] (list)      : mouse lines
            ["planes"] (list)     : imaging planes
            if extrapar["datatype"] == "roi":
            ["nrois"] (list)      : number of ROIs in session

        - trace_stats (dict): dictionary containing trace stats information
            ["xrans"] (list)           : time values for the 2p frames for each 
                                         session
            ["all_stats"] (list)       : list of 4D arrays or lists of trace 
                                         data statistics across ROIs for each 
                                         session, structured as:
                                            (unexp_len x) quantiles x
                                            stats (me, err) x frames
            ["all_counts"] (array-like): number of sequences, structured as:
                                                sess x (unexp_len x) quantiles
            ["lock"] (str)             : value to which segments are locked:
                                         "unexp", "exp" or "unexp_split"
            ["baseline"] (num)         : number of seconds used for baseline
            ["exp_stats"] (list)       : list of 3D arrays or lists of trace 
                                         data statistics across ROIs for
                                         expected sampled sequences, 
                                         structured as:
                                            quantiles (1) x stats (me, err) 
                                            x frames
            ["exp_counts"] (array-like): number of sequences corresponding to
                                         exp_stats, structured as:
                                            sess x quantiles (1)
            
            if data is by unexp_len:
            ["unexp_lens"] (list)       : number of consecutive segments for
                                         each unexp_len, structured by session
                
    Optional args:
        - figpar (dict): dictionary containing the following figure parameter 
                         dictionaries
                         default: None
            ["init"] (dict): dictionary with figure initialization parameters
            ["save"] (dict): dictionary with figure saving parameters
            ["dirs"] (dict): dictionary with additional figure parameters
        - savedir (str): path of directory in which to save plots.
                         default: None   
        - modif (bool) : if True, modified (slimmed-down) plots are created
                         instead
                         default: False
    
    Returns:
        - fulldir (str) : final path of the directory in which the figure is 
                          saved (may differ from input savedir, if datetime 
                          subfolder is added.)
        - savename (str): name under which the figure is saved

    Raises:
        - ValueError: if exp_stats for a session does not comprise exactly one 
                      quantile.
    """

    # override the dendrite parameter on a shallow copy, so that the caller's 
    # dictionary is not modified as a side effect
    analyspar = dict(analyspar)
    analyspar["dend"] = None

    # strings for titles ("print" versions)
    stimstr_pr = sess_str_util.stim_par_str(
        stimpar["stimtype"], stimpar["visflow_dir"], stimpar["visflow_size"],
        stimpar["gabk"], "print")
    statstr_pr = sess_str_util.stat_par_str(
        analyspar["stats"], analyspar["error"], "print")
    dendstr_pr = sess_str_util.dend_par_str(
        analyspar["dend"], sesspar["plane"], extrapar["datatype"], "print")
        
    # strings for the save name and directory
    sessstr = sess_str_util.sess_par_str(
        sesspar["sess_n"], stimpar["stimtype"], sesspar["plane"], 
        stimpar["visflow_dir"], stimpar["visflow_size"], stimpar["gabk"])
    dendstr = sess_str_util.dend_par_str(
        analyspar["dend"], sesspar["plane"], extrapar["datatype"])
     
    basestr = sess_str_util.base_par_str(trace_stats["baseline"])
    basestr_pr = sess_str_util.base_par_str(trace_stats["baseline"], "print")

    datatype = extrapar["datatype"]
    dimstr = sess_str_util.datatype_dim_str(datatype)

    # extract some info from sess_info
    keys = ["mouse_ns", "sess_ns", "lines", "planes"]
    [mouse_ns, sess_ns, lines, planes] = [sess_info[key] for key in keys]
    nroi_strs = sess_str_util.get_nroi_strs(sess_info, empty=(datatype!="roi")) 

    n_sess = len(mouse_ns)

    xrans      = [np.asarray(xran) for xran in trace_stats["xrans"]]
    all_stats  = [np.asarray(sessst) for sessst in trace_stats["all_stats"]]
    exp_stats  = [np.asarray(expst) for expst in trace_stats["exp_stats"]]
    all_counts = trace_stats["all_counts"]
    exp_counts = trace_stats["exp_counts"]

    # any lock value containing "unexp" (e.g., "unexp_split") is plotted as 
    # "unexp", using the second set of quantile colors
    lock  = trace_stats["lock"]
    col_idx = 0
    if "unexp" in lock:
        lock = "unexp"
        col_idx = 1
    
    # default unexp_lens values to plot, and Gabor frame offset. 
    # The offset default is set BEFORE the stimulus type check, so that the 
    # value retrieved for Gabors is not overwritten.
    # NOTE(review): assumes gabfr_nbrs() returns a single number here (the 
    # offset is multiplied by 0.3 below) - confirm against sess_str_util.
    offset = 0
    if stimpar["stimtype"] == "gabors":
        default_unexp_lens = [3.0, 4.5, 6.0]
        if stimpar["gabfr"] not in ["any", "all"]:
            offset = sess_str_util.gabfr_nbrs(stimpar["gabfr"])
    else:
        default_unexp_lens = [2.0, 3.0, 4.0]
    
    len_ext = ""
    unexp_lens = [[None]] * n_sess
    unexp_len_default = True
    if "unexp_lens" in trace_stats:
        unexp_len_default = False
        unexp_lens = trace_stats["unexp_lens"]
        len_ext = "_bylen"
        
        if stimpar["stimtype"] == "gabors":
            # convert lengths using the same 1.5/5 and 0.3 factors as 
            # originally, removing the Gabor frame offset
            unexp_lens = [
                [sl * 1.5/5 - 0.3 * offset for sl in sls] for sls in unexp_lens
                ]

    # direction in which length markers extend from the lock point (0)
    inv = 1 if lock == "unexp" else -1
    # RANGE TO PLOT
    if modif:
        st_val = -2.0
        end_val  = 6.0
        n_ticks = int((end_val - st_val) // 2 + 1)
    else:
        n_ticks = 21

    if figpar is None:
        figpar = sess_plot_util.init_figpar()
    # copied, as the subplot width is modified below
    figpar = copy.deepcopy(figpar)
    if modif:
        figpar["init"]["subplot_wid"] = 6.5
    else:
        figpar["init"]["subplot_wid"] *= 2

    fig, ax = plot_util.init_fig(n_sess, **figpar["init"])
    exp_min, exp_max = np.inf, -np.inf
    for i, (stats, counts) in enumerate(zip(all_stats, all_counts)):
        sub_ax = plot_util.get_subax(ax, i)

        # plot expected data
        if exp_stats[i].shape[0] != 1:
            raise ValueError("Expected only one quantile for exp_stats.")

        n_lines = quantpar["n_quants"] * len(unexp_lens[i])
        cols = sess_plot_util.get_quant_cols(n_lines)[0][col_idx]
        if len(cols) < n_lines:
            # not enough preset colors: fall back on the default color cycle
            cols = [None] * n_lines

        if modif:
            line = "2/3" if "23" in lines[i] else "5"
            plane = "somata" if "soma" in planes[i] else "dendrites"
            title = f"M{mouse_ns[i]} - layer {line} {plane}{dendstr_pr}"
            lab = "exp" if i == 0 else None
            y_ax = None if i == 0 else ""

            # restrict the plotted range to approx. [st_val, end_val], 
            # keeping the closest frame at or beyond each bound, if any
            st, end = 0, len(xrans[i])
            st_vals = np.flatnonzero(xrans[i] <= st_val)
            end_vals = np.flatnonzero(xrans[i] >= end_val)
            if len(st_vals) != 0:
                st = st_vals[-1]
            if len(end_vals) != 0:
                end = end_vals[0] + 1
            time_slice = slice(st, end)

        else:
            title = (f"Mouse {mouse_ns[i]} - {stimstr_pr}, "
                u"{} ".format(statstr_pr) + f"{lock} locked across {dimstr}"
                f"{basestr_pr}\n(sess {sess_ns[i]}, {lines[i]} {planes[i]}"
                f"{dendstr_pr}{nroi_strs[i]})")
            lab = f"exp (no lock) ({exp_counts[i][0]})"
            y_ax = None
            st = 0
            end = len(xrans[i])
            time_slice = slice(None) # use all

        # add length markers
        use_unexp_lens = unexp_lens[i]
        if unexp_len_default:
            use_unexp_lens = default_unexp_lens
        leng_col = sess_plot_util.get_quant_cols(1)[0][col_idx][0]
        for leng in use_unexp_lens:
            if leng is None:
                continue
            # confine the shaded area to the plotted time range
            edge = leng * inv
            if edge < 0:
                edge = np.max([xrans[i][st], edge])
            elif edge > 0:
                edge = np.min([xrans[i][end - 1], edge])
            plot_util.add_vshade(
                sub_ax, 0, edge, color=leng_col, alpha=0.1)

        sess_plot_util.add_axislabels(
            sub_ax, fluor=analyspar["fluor"], datatype=datatype, y_ax=y_ax
            )
        plot_util.add_bars(sub_ax, hbars=0)
        alpha = np.min([0.4, 0.8 / n_lines])

        if stimpar["stimtype"] == "gabors":
            sess_plot_util.plot_gabfr_pattern(
                sub_ax, xrans[i], offset=offset, bars_omit=[0] + unexp_lens[i]
                )

        plot_util.plot_traces(
            sub_ax, xrans[i][time_slice], exp_stats[i][0][0, time_slice], 
            exp_stats[i][0][1:, time_slice], n_xticks=n_ticks,
            alpha=alpha, label=lab, alpha_line=0.8, color="darkgray", 
            xticks="auto")

        # get expected data range to adjust y lims
        exp_min = np.min([exp_min, np.nanmin(exp_stats[i][0][0])])
        exp_max = np.max([exp_max, np.nanmax(exp_stats[i][0][0])])

        n = 0 # count lines plotted
        for s, unexp_len in enumerate(unexp_lens[i]):
            if unexp_len is not None:
                # data is split by unexp_len: select the current length
                counts, stats = all_counts[i][s], all_stats[i][s]       
                # add offset back in for the label
                unexp_lab = f"unexp len {unexp_len + 0.3 * offset}"
            else:
                unexp_lab = "unexp" if modif else f"{lock} lock"
            for q, qu_idx in enumerate(quantpar["qu_idx"]):
                qu_lab = ""
                if quantpar["n_quants"] > 1:
                    qu_lab = "{} ".format(sess_str_util.quantile_str(
                        qu_idx, quantpar["n_quants"], str_type="print"
                        ))
                lab = f"{qu_lab}{unexp_lab}"
                if modif:
                    lab = lab if i == 0 else None
                else:
                    lab = f"{lab} ({counts[q]})"
                if n == 2 and cols[n] is None:
                    sub_ax.plot([], []) # to advance the color cycle (past gray)
                plot_util.plot_traces(sub_ax, xrans[i][time_slice], 
                    stats[q][0, time_slice], stats[q][1:, time_slice], title, 
                    alpha=alpha, label=lab, n_xticks=n_ticks, alpha_line=0.8, 
                    color=cols[n], xticks="auto")
                n += 1
            if unexp_len is not None:
                # mark the length, in the color of the last trace plotted
                plot_util.add_bars(
                    sub_ax, hbars=unexp_len, color=sub_ax.lines[-1].get_color(), 
                    alpha=1)
    
    plot_util.turn_off_extra(ax, n_sess)

    if savedir is None:
        savedir = Path(
            figpar["dirs"][datatype], 
            figpar["dirs"]["unexp_qu"], 
            f"{lock}_lock", basestr.replace("_", ""))

    if not modif:
        if stimpar["stimtype"] == "visflow":
            # applied to the last subplot only, as originally
            plot_util.rel_confine_ylims(sub_ax, [exp_min, exp_max], 5)

    qu_str = f"_{quantpar['n_quants']}q"
    if quantpar["n_quants"] == 1:
        qu_str = ""
 
    savename = (f"{datatype}_av_{lock}_lock{len_ext}{basestr}_{sessstr}"
        f"{dendstr}{qu_str}")
    fulldir = plot_util.savefig(fig, savename, savedir, **figpar["save"])

    return fulldir, savename
Beispiel #3
0
def plot_traces_by_qu_unexp_sess(analyspar, sesspar, stimpar, extrapar, 
                                quantpar, sess_info, trace_stats, figpar=None, 
                                savedir=None, modif=False):
    """
    plot_traces_by_qu_unexp_sess(analyspar, sesspar, stimpar, extrapar, 
                                quantpar, sess_info, trace_stats)

    From dictionaries, plots expected and unexpected traces for each quantile, 
    with one subplot per session, and saves the figure.
    
    Returns figure name and save directory path.
    
    Required args:
        - analyspar (dict)  : dictionary with keys of AnalysPar namedtuple
        - sesspar (dict)    : dictionary with keys of SessPar namedtuple
        - stimpar (dict)    : dictionary with keys of StimPar namedtuple
        - extrapar (dict)   : dictionary containing additional analysis 
                              parameters
            ["analysis"] (str): analysis type (e.g., "t")
            ["datatype"] (str): datatype (e.g., "run", "roi")
        - quantpar (dict)   : dictionary with keys of QuantPar namedtuple
        - sess_info (dict)  : dictionary containing information from each
                              session 
            ["mouse_ns"] (list)   : mouse numbers
            ["sess_ns"] (list)    : session numbers  
            ["lines"] (list)      : mouse lines
            ["planes"] (list)     : imaging planes
            if extrapar["datatype"] == "roi":
            ["nrois"] (list)      : number of ROIs in session

        - trace_stats (dict): dictionary containing trace stats information
            ["xrans"] (list)           : time values for the frames, for each 
                                         session
            ["all_stats"] (list)       : list of 4D arrays or lists of trace 
                                         data statistics across ROIs for each
                                         session, structured as:
                                            sess x unexp x quantiles x
                                            stats (me, err) x frames
            ["all_counts"] (array-like): number of sequences, structured as:
                                                sess x unexp x quantiles
                
    Optional args:
        - figpar (dict): dictionary containing the following figure parameter 
                         dictionaries
                         default: None
            ["init"] (dict): dictionary with figure initialization parameters
            ["save"] (dict): dictionary with figure saving parameters
            ["dirs"] (dict): dictionary with additional figure parameters
        - savedir (str): path of directory in which to save plots.
                         default: None    
        - modif (bool) : if True, modified (slimmed-down) plots are created
                         instead
                         default: False
    
    Returns:
        - fulldir (str) : final path of the directory in which the figure is 
                          saved (may differ from input savedir, if datetime 
                          subfolder is added.)
        - savename (str): name under which the figure is saved
    """
 
    # title strings
    stimstr_pr = sess_str_util.stim_par_str(
        stimpar["stimtype"], stimpar["visflow_dir"], stimpar["visflow_size"],
        stimpar["gabk"], "print")
    statstr_pr = sess_str_util.stat_par_str(
        analyspar["stats"], analyspar["error"], "print")
    dendstr_pr = sess_str_util.dend_par_str(
        analyspar["dend"], sesspar["plane"], extrapar["datatype"], "print")
        
    # save name strings
    sessstr = sess_str_util.sess_par_str(
        sesspar["sess_n"], stimpar["stimtype"], sesspar["plane"], 
        stimpar["visflow_dir"], stimpar["visflow_size"], stimpar["gabk"])
    dendstr = sess_str_util.dend_par_str(
        analyspar["dend"], sesspar["plane"], extrapar["datatype"])
     
    datatype = extrapar["datatype"]
    dimstr = sess_str_util.datatype_dim_str(datatype)

    # session information
    mouse_ns = sess_info["mouse_ns"]
    sess_ns = sess_info["sess_ns"]
    lines = sess_info["lines"]
    planes = sess_info["planes"]
    nroi_strs = sess_str_util.get_nroi_strs(
        sess_info, empty=(datatype != "roi")) 

    n_sess = len(mouse_ns)

    # trace data, converted to arrays for indexing
    xrans = list(map(np.asarray, trace_stats["xrans"]))
    all_stats = list(map(np.asarray, trace_stats["all_stats"]))
    all_counts = trace_stats["all_counts"]

    n_quants = quantpar["n_quants"]
    cols, lab_cols = sess_plot_util.get_quant_cols(n_quants)
    alpha = np.min([0.4, 0.8 / n_quants])

    unexp_names = ["exp", "unexp"]
    n_xticks = 6
    if figpar is None:
        figpar = sess_plot_util.init_figpar()
    
    fig, ax = plot_util.init_fig(n_sess, **figpar["init"])
    for sess_idx in range(n_sess):
        sub_ax = plot_util.get_subax(ax, sess_idx)
        is_first = (sess_idx == 0)
        for unexp_idx, (unexp_cols, unexp_name) in enumerate(
            zip(cols, unexp_names)
            ):
            for quant_idx, qu_idx in enumerate(quantpar["qu_idx"]):
                # quantile label is only included if there are several
                if n_quants > 1:
                    quant_str = sess_str_util.quantile_str(
                        qu_idx, n_quants, str_type="print")
                    qu_lab = f"{quant_str} "
                else:
                    qu_lab = ""

                if not modif:
                    # full title, and legend with sequence counts
                    title=(f"Mouse {mouse_ns[sess_idx]} - {stimstr_pr}, " 
                        u"{}\n".format(statstr_pr) + f"across {dimstr} (sess "
                        f"{sess_ns[sess_idx]}, {lines[sess_idx]} "
                        f"{planes[sess_idx]}{dendstr_pr}"
                        f"{nroi_strs[sess_idx]})")
                    count = all_counts[sess_idx][unexp_idx][quant_idx]
                    leg = f"{qu_lab}{unexp_name} ({count})"
                    y_ax = None
                else:
                    # slimmed-down title, with legend and y axis label on the 
                    # first subplot only
                    if "23" in lines[sess_idx]:
                        line = "2/3"
                    else:
                        line = "5"
                    if "soma" in planes[sess_idx]:
                        plane = "somata"
                    else:
                        plane = "dendrites"
                    title = (f"M{mouse_ns[sess_idx]} - layer {line} "
                        f"{plane}{dendstr_pr}")
                    if is_first:
                        leg = f"{qu_lab}{unexp_name}"
                        y_ax = None
                    else:
                        leg = None
                        y_ax = ""

                sess_stats = all_stats[sess_idx]
                plot_util.plot_traces(
                    sub_ax, xrans[sess_idx], 
                    sess_stats[unexp_idx, quant_idx, 0], 
                    sess_stats[unexp_idx, quant_idx, 1:], title, 
                    color=unexp_cols[quant_idx], alpha=alpha, label=leg, 
                    n_xticks=n_xticks, xticks="auto")
                sess_plot_util.add_axislabels(
                    sub_ax, fluor=analyspar["fluor"], datatype=datatype, 
                    y_ax=y_ax)
    
    plot_util.turn_off_extra(ax, n_sess)

    # add Gabor frame labels
    if stimpar["stimtype"] == "gabors": 
        sess_plot_util.plot_labels(
            ax, stimpar["gabfr"], "both", pre=stimpar["pre"], 
            post=stimpar["post"], cols=lab_cols, 
            sharey=figpar["init"]["sharey"])
    
    if savedir is None:
        savedir = Path(
            figpar["dirs"][datatype], 
            figpar["dirs"]["unexp_qu"])

    # quantile string is omitted from the save name for a single quantile
    if quantpar["n_quants"] == 1:
        qu_str = ""
    else:
        qu_str = f"_{quantpar['n_quants']}q"

    savename = f"{datatype}_av_{sessstr}{dendstr}{qu_str}"
    fulldir = plot_util.savefig(fig, savename, savedir, **figpar["save"])

    return fulldir, savename
def plot_pup_roi_stim_corr(analyspar, sesspar, stimpar, extrapar, 
                           sess_info, corr_data, figpar=None, savedir=None):
    """
    plot_pup_roi_stim_corr(analyspar, sesspar, stimpar, extrapar, 
                           sess_info, corr_data)

    From dictionaries, plots correlation between unexpected-locked changes in 
    pupil diameter and each ROI, for gabors versus visual flow responses for 
    each session, and saves the figure.
    
    Required args:
        - analyspar (dict)    : dictionary with keys of AnalysPar namedtuple
        - sesspar (dict)      : dictionary with keys of SessPar namedtuple 
        - stimpar (dict)      : dictionary with keys of StimPar namedtuple
        - extrapar (dict)     : dictionary containing additional analysis 
                                parameters
            ["analysis"] (str): analysis type (e.g., "r")
            ["datatype"] (str): datatype (e.g., "roi")
        - sess_info (dict)    : dictionary containing information from each
                                session 
            ["mouse_ns"] (list)   : mouse numbers
            ["sess_ns"] (list)    : session numbers  
            ["lines"] (list)      : mouse lines
            ["planes"] (list)     : imaging planes
            ["nrois"] (list)      : number of ROIs in session
        - corr_data (dict)    : dictionary containing data to plot:
            ["stim_order"] (list): ordered list of stimtypes
            ["roi_corrs"] (list) : nested list of correlations between pupil 
                                   and ROI responses changes locked to 
                                   unexpected, structured as 
                                       session x stimtype x ROI
            ["corrs"] (list)     : list of correlation between stimtype
                                   correlations for each session
    
    Optional args:
        - figpar (dict) : dictionary containing the following figure parameter 
                          dictionaries
                          default: None
            ["init"] (dict): dictionary with figure initialization parameters
            ["save"] (dict): dictionary with figure saving parameters
            ["dirs"] (dict): dictionary with additional figure parameters
        - savedir (Path): path of directory in which to save plots.
                          default: None
    
    Returns:
        - fulldir (Path): final path of the directory in which the figure is 
                          saved (may differ from input savedir, if datetime 
                          subfolder is added.)
        - savename (str): name under which the figure is saved
    """

    # stimulus strings for the axis labels (first stimtype on the x axis, 
    # second on the y axis)
    stimstr_prs = []
    for stimtype in corr_data["stim_order"]:
        stimstr_pr = sess_str_util.stim_par_str(
            stimtype, stimpar["visflow_dir"], stimpar["visflow_size"], 
            stimpar["gabk"], "print")
        # drop a trailing "s", if any (endswith() also handles empty strings)
        if stimstr_pr.endswith("s"):
            stimstr_pr = stimstr_pr[:-1]
        stimstr_prs.append(stimstr_pr)
        
    dendstr_pr = sess_str_util.dend_par_str(
        analyspar["dend"], sesspar["plane"], extrapar["datatype"], "print")

    sessstr = f"sess{sesspar['sess_n']}_{sesspar['plane']}" 
    dendstr = sess_str_util.dend_par_str(
        analyspar["dend"], sesspar["plane"], extrapar["datatype"])

    label_str = sess_str_util.fluor_par_str(
        analyspar["fluor"], str_type="print")
    lab_app = (f" ({analyspar['stats']} over "
        f"{stimpar['pre']}/{stimpar['post']} sec)")

    logger.info("Plotting pupil-ROI difference correlations for "
        "{} vs {}.".format(*corr_data["stim_order"]))

    # extract some info from sess_info
    keys = ["mouse_ns", "sess_ns", "lines", "planes"]
    [mouse_ns, sess_ns, lines, planes] = [sess_info[key] for key in keys]

    n_sess = len(mouse_ns)
    nroi_strs = sess_str_util.get_nroi_strs(sess_info, style="comma")

    if figpar is None:
        figpar = sess_plot_util.init_figpar()

    # copied, as some figure parameters are overridden below
    figpar = copy.deepcopy(figpar)
    if figpar["save"]["use_dt"] is None:
        figpar["save"]["use_dt"] = gen_util.create_time_str()
    figpar["init"]["sharex"] = True
    figpar["init"]["sharey"] = True
    
    fig, ax = plot_util.init_fig(n_sess, **figpar["init"])
    suptitle = (u"Relationship between pupil diam. and {} changes, locked to "
        "unexpected events\n{} for each ROI ({} vs {})".format(
            label_str, lab_app, *corr_data["stim_order"]))
    
    for i, sess_roi_corrs in enumerate(corr_data["roi_corrs"]):
        sub_ax = plot_util.get_subax(ax, i)
        title = (f"Mouse {mouse_ns[i]} (sess {sess_ns[i]}, {lines[i]} "
            f"{planes[i]}{dendstr_pr}{nroi_strs[i]})")
        
        # plot ROI correlations for one stimtype against the other, with the 
        # overall correlation reported in the legend
        corr = f"Corr = {corr_data['corrs'][i]:.2f}"
        sub_ax.plot(
            sess_roi_corrs[0], sess_roi_corrs[1], marker=".", linestyle="None", 
            label=corr)
        sub_ax.set_title(title, y=1.01)

        # Axes.is_last_row() and Axes.is_first_col() are not available in 
        # recent matplotlib releases, where the subplot spec must be queried 
        # instead
        if hasattr(sub_ax, "is_last_row"):
            subplot_pos = sub_ax
        else:
            subplot_pos = sub_ax.get_subplotspec()
        if subplot_pos.is_last_row():
            sub_ax.set_xlabel(f"{stimstr_prs[0].capitalize()} correlations")
        if subplot_pos.is_first_col():
            sub_ax.set_ylabel(f"{stimstr_prs[1].capitalize()} correlations")
        sub_ax.legend()

    plot_util.turn_off_extra(ax, n_sess)

    fig.suptitle(suptitle, fontsize="xx-large", y=1)

    if savedir is None:
        savedir = Path(
            figpar["dirs"]["roi"],
            figpar["dirs"]["pupil"])

    savename = f"roi_diff_corrbyroi_{sessstr}{dendstr}"

    fulldir = plot_util.savefig(fig, savename, savedir, **figpar["save"])

    return fulldir, savename
def plot_data_summ(plot_lines, data, stats, shuff_stats, title, savename, 
                   CI=0.95, q1v4=False, evu=False, comp="unexp", modif=False, 
                   no_save=False):
    """
    plot_data_summ(plot_lines, data, stats, shuff_stats, title, savename)

    Plots summary data for a specific comparison, for each line and plane and 
    saves figure.

    Required args:
        - plot_lines (pd DataFrame): DataFrame containing scores summary
                                     for specific comparison and criteria
        - data (str)               : label of type of data to plot,
                                     e.g., "epoch_n" or "test_acc_bal" 
        - stats (list)             : list of stats to use for non shuffled 
                                     data, e.g., ["mean", "sem", "sem"]
        - shuff_stats (list)       : list of stats to use for shuffled 
                                     data, e.g., ["median", "p2p5", "p97p5"]
        - title (str)              : general plot titles (must contain "data")
        - savename (str)           : plot save path
        
    Optional args:
        - CI (num)      : CI for shuffled data (e.g., 0.95)
                          default: 0.95
        - q1v4 (bool)   : if True, analysis is separated across first and 
                          last quartiles
                          default: False
        - evu (bool)    : if True, the first dataset will include expected 
                          sequences and the second will include unexpected 
                          sequences
                          default: False
        - comp (str)    : type of comparison
                          default: "unexp"
        - modif (bool)  : if True, plots are made in a modified (simplified 
                          way)
                          default: False
        - no_save (bool): if True, figure is not saved
                          default: False

    Returns:
        - fig (plt Figure or None): figure, or None if data is "test_acc_bal" 
                                    and no column of plot_lines contains it 
                                    (a RuntimeWarning is issued)
    """
    
    # Check for missing balanced accuracy data BEFORE initializing the figure, 
    # so that no unused figure is left behind on early exit.
    if data == "test_acc_bal":
        if not any(data in key for key in plot_lines.keys()):
            warnings.warn("test_acc_bal was not recorded", 
                category=RuntimeWarning, stacklevel=1)
            return None

    celltypes = [[x, y] for y in ["dend", "soma"] for x in ["L23", "L5"]]

    max_sess = max(plot_lines["sess_n"].tolist())

    fig, ax = init_res_fig(len(celltypes), max_sess, modif)
    
    if modif:
        fig.suptitle(title, y=1.0, weight="bold")

    n_vals = 5 # (mean/med, sem/2.5p, sem/97.5p, n_rois, n_runs)
    
    split_oris = sess_str_util.get_split_oris(comp)

    data_types = gen_util.list_if_not(data)

    if (q1v4 or evu or split_oris) and "test_acc" in data:
        ext_test = sess_str_util.ext_test_str(q1v4, evu, comp)
        n_vals = 8 # (extra mean/med, sem/2.5p, sem/97.5p for Q4) 
        if data == "test_acc": 
            data_types = ["test_acc", f"{ext_test}_acc"]
        elif data == "test_acc_bal":   
            data_types = ["test_acc_bal", f"{ext_test}_acc_bal"]
        else:
            gen_util.accepted_values_error("data", data, 
                ["test_acc", "test_acc_bal"])

    for i, [line, plane] in enumerate(celltypes):
        sub_ax = plot_util.get_subax(ax, i)
        if not modif:
            sub_ax.set_xlim(-0.5, max_sess - 0.5)

        # get the right rows in dataframe
        cols       = ["plane"]
        cri        = [plane]
        curr_lines = gen_util.get_df_vals(plot_lines.loc[
            plot_lines["line"].str.contains(line)], cols, cri)
        cri_str = ", ".join([f"{col}: {crit}" for col, crit in zip(cols, cri)])
        if len(curr_lines) == 0: # no data
            warnings.warn(f"No data found for {line} {plane}, {cri_str}", 
                category=RuntimeWarning, stacklevel=1)
            continue

        # skip the subplot if shuffle or non shuffle data is missing
        shuff_vals = curr_lines["shuffle"].tolist()
        skip = False
        for shuff in [False, True]:
            if shuff not in shuff_vals:
                warnings.warn(f"No shuffle={shuff} data found for {line} "
                    f"{plane}, {cri_str}", category=RuntimeWarning, 
                    stacklevel=1)
                skip = True
        if skip:
            continue

        sess_ns = gen_util.get_df_vals(curr_lines, label="sess_n", dtype=int)
        mouse_ns = gen_util.get_df_vals(curr_lines, label="mouse_n", 
            dtype=int)

        # mouse number -1 marks the data shuffled across mice
        if -1 not in mouse_ns:
            raise RuntimeError("Shuffle data across mice is missing.")
        mouse_ns = gen_util.remove_if(mouse_ns, -1)

        # mouse x sess x n_vals (NaN where no data is recorded)
        data_arr = np.full((len(mouse_ns), int(max_sess), n_vals), np.nan)
        # shuffled data has no n_rois value, hence n_vals - 1
        shuff_arr = np.full((1, int(max_sess), n_vals - 1), np.nan)

        for sess_n in sess_ns:
            sess_idx = int(sess_n - 1) # session numbers are 1-indexed
            sess_mice = gen_util.get_df_vals(
                curr_lines, "sess_n", sess_n, "mouse_n", dtype=int)
            for m, mouse_n in enumerate(mouse_ns + [-1]):
                if mouse_n not in sess_mice:
                    continue
                is_shuff = (mouse_n == -1)
                if is_shuff:
                    stat_types = shuff_stats
                    arr = shuff_arr
                    row = 0 # single row in the shuffle array
                else:
                    stat_types = stats
                    arr = data_arr
                    row = m
                curr_line = gen_util.get_df_vals(curr_lines, 
                    ["sess_n", "mouse_n", "shuffle"], 
                    [sess_n, mouse_n, is_shuff])
                if len(curr_line) > 1:
                    raise RuntimeError("Several lines correspond to criteria.")
                elif len(curr_line) == 0:
                    continue
                # stats are stored in groups of 3 for each data type
                for st, stat in enumerate(stat_types):
                    for d, dat in enumerate(data_types):
                        val_idx = d * 3 + st
                        arr[row, sess_idx, val_idx] = \
                            curr_line[f"{dat}_{stat}"]
                if not is_shuff:
                    arr[row, sess_idx, -2] = curr_line["n_rois"]
                # number of runs that did not yield NaNs
                arr[row, sess_idx, -1] = curr_line["runs_total"] - \
                    curr_line["runs_nan"]

        summ_subplot(sub_ax, data_arr, shuff_arr, title, mouse_ns, sess_ns, 
            line, plane, stats[0], stats[1], CI, q1v4, evu, split_oris, modif)

    if modif:
        n_sess_keep = 3
        ylab = ax[0, 0].get_ylabel()
        ax[0, 0].set_ylabel("")
        sess_plot_util.format_linpla_subaxes(ax, ylab=ylab, 
            xticks=np.arange(1, n_sess_keep + 1))

        yticks = ax[0, 0].get_yticks()
        # always set ticks (even again) before setting labels
        ax[1, 0].set_yticks(yticks)
        ax[1, 0].set_yticklabels([int(v) for v in yticks], 
            fontdict={"weight": "bold"})

    if not no_save:
        fig.savefig(savename)

    return fig
def run_sess_lstm(sessid, args):
    """
    run_sess_lstm(sessid, args)

    Prepares windowed stimulus and ROI data for a session, in order to train 
    an LSTM to predict ROI traces from stimulus data, record the losses and 
    plot target vs predicted validation traces.

    NOTE: This function currently raises a NotImplementedError once the 
    training windows have been prepared. All code following that point is 
    unreachable until the error is removed.

    Required args:
        - sessid (int)          : ID of the session to initialize 
                                  (type assumed - passed to 
                                  sess_gen_util.init_sessions())
        - args (Argument parser): parser with the following attributes:
            batchsize (int)  : batch size for the dataloaders
            conv (bool)      : if True, a ConvPredROILSTM is used, instead of 
                               a PredLSTM
            datadir (Path)   : data directory, passed to 
                               sess_gen_util.init_sessions()
            device (str)     : device on which to run the model
            dropout (num)    : dropout, passed to the model
            hidden_dim (int) : hidden dimension, passed to the model
            lr_ex (num)      : learning rate exponent (the learning rate 
                               is 10**(-lr_ex))
            mouse_df (Path)  : mouse dataframe, passed to 
                               sess_gen_util.init_sessions()
            n_epochs (int)   : number of epochs to run
            num_layers (int) : number of layers, passed to the model
            out_ch (int)     : number of output channels (used if conv)
            output (Path)    : main directory in which to save output
            parallel (bool)  : if True and plt_bkend is not None, the pyplot 
                               backend is switched to plt_bkend
            plt_bkend (str)  : pyplot backend to use
            runtype (str)    : runtype, passed to 
                               sess_gen_util.init_sessions()
            seed (int)       : seed, passed to rand_util.seed_all(), and 
                               replaced with the returned value
            stimtype (str)   : stimulus type for which to retrieve data

    Raises:
        - NotImplementedError: always, currently (see note above)
    """

    if args.parallel and args.plt_bkend is not None:
        plt.switch_backend(args.plt_bkend) # needs to be repeated within joblib

    args.seed = rand_util.seed_all(args.seed, args.device, seed_torch=True)

    train_p = 0.8
    lr = 1. * 10**(-args.lr_ex)
    if args.conv:
        conv_str = "_conv"
        outch_str = f"_{args.out_ch}outch"
    else:
        conv_str = ""
        outch_str = ""

    # Input output parameters
    n_stim_s  = 0.6
    n_roi_s = 0.3

    # Stim/traces for training
    train_gabfr = 0
    train_post = 0.9 # up to C
    roi_train_pre = 0 # from A
    stim_train_pre   = 0.3 # from preceding grayscreen

    # Stim/traces for testing (separated for unexp vs exp)
    test_gabfr = 3
    test_post  = 0.6 # up to grayscreen
    roi_test_pre = 0 # from D/U
    stim_test_pre   = 0.3 # from preceding C

    sess = sess_gen_util.init_sessions(
        sessid, args.datadir, args.mouse_df, args.runtype, full_table=False, 
        fluor="dff", dend="extr", run=True, temp_log="warning")[0]

    analysdir = sess_gen_util.get_analysdir(
        sess.mouse_n, sess.sess_n, sess.plane, stimtype=args.stimtype, 
        comp=None)
    dirname = Path(args.output, analysdir)
    file_util.createdir(dirname, log_dir=False)

    # Must not scale ROIs or running BEFOREHAND. Must do after to use only 
    # network available data.

    # seq x frame x gabor x par
    logger.info("Preparing stimulus parameter dataframe", 
        extra={"spacing": "\n"})
    train_stim_wins, run_stats = sess_data_util.get_stim_data(
        sess, args.stimtype, n_stim_s, train_gabfr, stim_train_pre, 
        train_post, gabk=16, run=True)

    logger.info("Adding ROI data")
    xran, train_roi_wins, roi_stats = sess_data_util.get_roi_data(
        sess, args.stimtype, n_roi_s, train_gabfr, roi_train_pre, train_post, 
        gabk=16)

    logger.warning("Preparing windowed datasets (too slow - to be improved)")
    # Deliberate guard: everything below is unreachable until this is removed.
    raise NotImplementedError("Not implemented properly - some error leads "
        "to excessive memory requests.")
    test_stim_wins = []
    test_roi_wins  = []
    for unexp in [0, 1]:
        # test data is passed the statistics obtained from the training data
        stim_wins = sess_data_util.get_stim_data(
            sess, args.stimtype, n_stim_s, test_gabfr, stim_test_pre, 
            test_post, unexp, gabk=16, run_mean=run_stats[0], 
            run_std=run_stats[1])
        test_stim_wins.append(stim_wins)
        
        roi_wins = sess_data_util.get_roi_data(
            sess, args.stimtype, n_roi_s, test_gabfr, roi_test_pre, 
            test_post, unexp, gabk=16, roi_means=roi_stats[0], 
            roi_stds=roi_stats[1])[1]
        test_roi_wins.append(roi_wins)

    n_pars = train_stim_wins.shape[-1] # n parameters (121)
    n_rois = train_roi_wins.shape[-1] # n ROIs

    hyperstr = (f"{args.hidden_dim}hd_{args.num_layers}hl_{args.lr_ex}lrex_"
                f"{args.batchsize}bs{outch_str}{conv_str}")

    dls = data_util.create_dls(
        train_stim_wins, train_roi_wins, train_p=train_p, test_p=0, 
        batchsize=args.batchsize, thresh_cl=0, train_shuff=True)[0]
    train_dl, val_dl, _ = dls

    test_dls = []
    
    for s in [0, 1]:
        dl = data_util.init_dl(
            test_stim_wins[s], test_roi_wins[s], batchsize=args.batchsize)
        test_dls.append(dl)

    logger.info("Running LSTM")
    if args.conv:
        lstm = ConvPredROILSTM(
            args.hidden_dim, n_rois, out_ch=args.out_ch, 
            num_layers=args.num_layers, dropout=args.dropout)
    else:
        lstm = PredLSTM(
            n_pars, args.hidden_dim, n_rois, num_layers=args.num_layers, 
            dropout=args.dropout)

    lstm = lstm.to(args.device)
    # summed loss (replaces the deprecated size_average=False), which is 
    # divided by the number of samples below
    lstm.loss_fn = torch.nn.MSELoss(reduction="sum")
    lstm.opt = torch.optim.Adam(lstm.parameters(), lr=lr)

    loss_df = pd.DataFrame(
        np.nan, index=range(args.n_epochs), columns=["train", "val"])
    min_val = np.inf
    for ep in range(args.n_epochs):
        logger.info(f"====> Epoch {ep}", extra={"spacing": "\n"})
        # the first epoch is run with train=False for the training set
        train_loss = run_dl(lstm, train_dl, args.device, train=(ep != 0))
        val_loss = run_dl(lstm, val_dl, args.device, train=False)

        # single .loc call, as chained assignment may write to a copy
        loss_df.loc[ep, "train"] = train_loss / train_dl.dataset.n_samples
        loss_df.loc[ep, "val"] = val_loss / val_dl.dataset.n_samples
        logger.info(f"Training loss  : {loss_df.loc[ep, 'train']}")
        logger.info(f"Validation loss: {loss_df.loc[ep, 'val']}")

        # record model if first epoch, or if val reaches a new low
        if ep == 0 or val_loss < min_val:
            prev_model = glob.glob(str(Path(dirname, f"{hyperstr}_ep*.pth")))
            prev_df = glob.glob(str(Path(dirname, f"{hyperstr}.csv")))
            min_val = val_loss
            saved_ep = ep
                
            # remove the previous record, if exactly one of each is found
            if len(prev_model) == 1 and len(prev_df) == 1:
                Path(prev_model[0]).unlink()
                Path(prev_df[0]).unlink()

            savename = f"{hyperstr}_ep{ep}"
            savefile = Path(dirname, savename)
        
            torch.save(
                {"net": lstm.state_dict(), "opt": lstm.opt.state_dict()},
                f"{savefile}.pth")
        
            file_util.saveinfo(loss_df, hyperstr, dirname, "csv")

    plot_util.linclab_plt_defaults(
        font=["Arial", "Liberation Sans"], fontdir=DEFAULT_FONTDIR)
    fig, ax = plt.subplots(1)
    for dataset in ["train", "val"]:
        plot_util.plot_traces(
            ax, range(args.n_epochs), np.asarray(loss_df[dataset]), 
            label=dataset, title=f"Average loss (MSE) ({n_rois} ROIs)", 
            xticks="auto")
    fig.savefig(Path(dirname, f"{hyperstr}_loss"))

    # reload the last recorded model
    savemod = Path(dirname, f"{hyperstr}_ep{saved_ep}.pth")
    checkpoint = torch.load(savemod)
    lstm.load_state_dict(checkpoint["net"]) 

    n_samples = 20
    val_idx = np.random.choice(range(val_dl.dataset.n_samples), n_samples)
    val_samples = val_dl.dataset[val_idx]
    xrans = data_util.get_win_xrans(
        xran, val_samples[1].shape[1], val_idx.tolist())

    fig, ax = plot_util.init_fig(
        n_samples, ncols=4, sharex=True, subplot_hei=2, subplot_wid=5)

    lstm.eval()
    with torch.no_grad():
        batch_len, seq_len, n_items = val_samples[1].shape
        pred_tr = lstm(val_samples[0].transpose(1, 0).to(args.device))
        pred_tr = pred_tr.view([seq_len, batch_len, n_items]).transpose(1, 0)

    for lab, data in zip(["target", "pred"], [val_samples[1], pred_tr]):
        # predictions are on args.device: move to CPU before converting
        data = data.detach().cpu().numpy()
        for n in range(n_samples):
            roi_n = np.random.choice(range(data.shape[-1]))
            sub_ax = plot_util.get_subax(ax, n)
            plot_util.plot_traces(sub_ax, xrans[n], data[n, :, roi_n], 
                label=lab, xticks="auto")
            plot_util.set_ticks(sub_ax, "x", xran[0], xran[-1], n=7)

    sess_plot_util.plot_labels(
        ax, train_gabfr, plot_vals="exp", pre=roi_train_pre, post=train_post)

    fig.suptitle(f"Target vs predicted validation traces ({n_rois} ROIs)")
    fig.savefig(Path(dirname, f"{hyperstr}_traces"))