def reformat_args(args):
    """
    reformat_args(args)

    Returns reformatted args for analyses, specifically 
        - Sets stimulus parameters to "none" if they are irrelevant to the 
          stimtype
        - Changes stimulus parameters from "both" to actual values
        - Sets the seed value, though does not actually seed random processes
        - Modifies analyses (if "all" or "all_" in parameter)

    Adds the following args:
        - dend (str)     : type of dendrites to use ("allen" or "extr")
        - omit_sess (str): sess to omit
        - omit_mice (str): mice to omit

    Required args:
        - args (Argument parser): parser with the following attributes: 
            analyses (str)       : analyses to run, or "all" (with optional 
                                   exclusions, e.g., "all_x")
            plane (str)          : plane ("soma" or "dend")
            runtype (str)        : runtype ("pilot" or "prod")
            seed (int)           : seed value to use (-1 to have one chosen)
            stimtype (str)       : stimulus to analyse (visflow or gabors)
    
    Returns:
        - args (Argument parser): input parser, with the following attributes 
                                  added:
                                      visflow_dir, visflow_size, gabfr, gabk, 
                                      gab_ori, omit_sess, omit_mice, dend, 
                                      analyses, seed
    """

    args = copy.deepcopy(args)

    [args.visflow_dir, args.visflow_size, args.gabfr, args.gabk,
     args.gab_ori] = sess_gen_util.get_params(args.stimtype, "both", 128,
                                              "any", 16, "any")

    if args.plane == "soma":
        args.dend = "allen"

    args.omit_sess, args.omit_mice = sess_gen_util.all_omit(
        args.stimtype, args.runtype, args.visflow_dir, args.visflow_size,
        args.gabk)

    # choose a seed if none is provided (i.e., args.seed=-1), but seed later
    args.seed = rand_util.seed_all(args.seed,
                                   "cpu",
                                   log_seed=False,
                                   seed_now=False)

    # collect analysis letters
    all_analyses = "".join(get_analysis_fcts().keys())
    if "all" in args.analyses:
        if "_" in args.analyses:
            excl = args.analyses.split("_")[1]
            args.analyses, _ = gen_util.remove_lett(all_analyses, excl)
        else:
            args.analyses = all_analyses
    elif "_" in args.analyses:
        raise ValueError("Use '_' in args.analyses only with 'all'.")

    return args
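# Minimal usage sketch for reformat_args, with hypothetical attribute values;
# in this repo, args normally comes from the script's argparse setup (which
# also defines the stimulus and datatype options):
import argparse

example_args = argparse.Namespace(
    plane="soma", runtype="prod", stimtype="gabors", seed=-1, analyses="all")
example_args = reformat_args(example_args)  # adds dend, omit_sess, omit_mice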
    def _set_seed(self):
        """
        self._set_seed()

        Updates attributes related to random process seeding.

        Updates the following attributes:
            - paper_seed (bool): whether the paper seed will be used
            - seed (int): 
                specific seed that will be used (None if self.randomness is 
                False)
        """

        if not hasattr(self, "n_perms"):
            raise RuntimeError("Must run self._set_power() first.")

        if not self.randomness:
            self.seed = None
            self.paper_seed = True
        else:
            if self.seed == "paper":
                self.seed = PAPER_SEED
                self.paper_seed = True
            else:
                self.seed = int(self.seed)
                self.paper_seed = False

            if self.seed == -1:
                # draw seeds until one differs from the paper seed
                self.seed = PAPER_SEED  # ensures the loop runs at least once
                while self.seed == PAPER_SEED:
                    self.seed = rand_util.seed_all(-1,
                                                   "cpu",
                                                   log_seed=False,
                                                   seed_now=False)
            if self.seed != PAPER_SEED:
                self.warnings.append(seed_warning(self.seed))
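# Hedged sketch of the "any seed but the paper seed" draw above, using only
# the standard library (the PAPER_SEED value here is hypothetical):
import random

PAPER_SEED_EXAMPLE = 905
seed = PAPER_SEED_EXAMPLE  # ensures the loop body runs at least once
while seed == PAPER_SEED_EXAMPLE:
    seed = random.randint(0, 2**32 - 1)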
Example #3
def run_mag_change(sessions,
                   analysis,
                   seed,
                   analyspar,
                   sesspar,
                   stimpar,
                   permpar,
                   quantpar,
                   figpar,
                   datatype="roi"):
    """
    run_mag_change(sessions, analysis, seed, analyspar, sesspar, stimpar, 
                   permpar, quantpar, figpar)

    Calculates and plots the magnitude of change in activity of ROIs between 
    the first and last quantile for expected vs unexpected sequences.
    Saves results and parameters relevant to analysis in a dictionary.

    Required args:
        - sessions (list)      : list of Session objects
        - analysis (str)       : analysis type (e.g., "m")
        - seed (int)           : seed value to use. (-1 treated as None) 
        - analyspar (AnalysPar): named tuple containing analysis parameters
        - sesspar (SessPar)    : named tuple containing session parameters
        - stimpar (StimPar)    : named tuple containing stimulus parameters
        - permpar (PermPar)    : named tuple containing permutation parameters
        - quantpar (QuantPar)  : named tuple containing quantile analysis 
                                 parameters
        - figpar (dict)        : dictionary containing figure parameters   

    Optional args:
        - datatype (str): type of data (e.g., "roi", "run") 
    """

    sessstr_pr = sess_str_util.sess_par_str(sesspar.sess_n, stimpar.stimtype,
                                            sesspar.plane, stimpar.visflow_dir,
                                            stimpar.visflow_size, stimpar.gabk,
                                            "print")
    dendstr_pr = sess_str_util.dend_par_str(analyspar.dend, sesspar.plane,
                                            datatype, "print")

    datastr = sess_str_util.datatype_par_str(datatype)

    logger.info(
        f"Calculating and plotting the magnitude changes in {datastr} "
        f"activity across quantiles \n({sessstr_pr}{dendstr_pr}).",
        extra={"spacing": "\n"})

    if permpar.multcomp:
        permpar = sess_ntuple_util.get_modif_ntuple(permpar, "multcomp",
                                                    len(sessions))

    # get full data: session x unexp x quants of interest x [ROI x seq]
    integ_info = quant_analys.trace_stats_by_qu_sess(sessions,
                                                     analyspar,
                                                     stimpar,
                                                     quantpar.n_quants,
                                                     quantpar.qu_idx,
                                                     by_exp=True,
                                                     integ=True,
                                                     ret_arr=True,
                                                     datatype=datatype)
    all_counts = integ_info[-2]
    qu_data = integ_info[-1]

    # extract session info
    mouse_ns = [sess.mouse_n for sess in sessions]
    lines = [sess.line for sess in sessions]

    if analyspar.rem_bad:
        nanpol = None
    else:
        nanpol = "omit"

    seed = rand_util.seed_all(seed, "cpu", log_seed=False)

    mags = quant_analys.qu_mags(qu_data,
                                permpar,
                                mouse_ns,
                                lines,
                                analyspar.stats,
                                analyspar.error,
                                nanpol=nanpol,
                                op_qu="diff",
                                op_unexp="diff")

    # convert mags items to list
    mags = copy.deepcopy(mags)
    mags["all_counts"] = all_counts
    for key in ["mag_st", "L2", "mag_rel_th", "L2_rel_th"]:
        mags[key] = mags[key].tolist()

    sess_info = sess_gen_util.get_sess_info(sessions,
                                            analyspar.fluor,
                                            incl_roi=(datatype == "roi"),
                                            rem_bad=analyspar.rem_bad)

    extrapar = {"analysis": analysis, "datatype": datatype, "seed": seed}

    info = {
        "analyspar": analyspar._asdict(),
        "sesspar": sesspar._asdict(),
        "stimpar": stimpar._asdict(),
        "extrapar": extrapar,
        "permpar": permpar._asdict(),
        "quantpar": quantpar._asdict(),
        "mags": mags,
        "sess_info": sess_info
    }

    fulldir, savename = gen_plots.plot_mag_change(figpar=figpar, **info)

    file_util.saveinfo(info, savename, fulldir, "json")
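# Conceptual sketch (not this repo's API) of the magnitude-of-change measure
# computed above: for each ROI, the absolute change in mean integrated
# activity between the first and last quantile, compared for expected vs
# unexpected sequences. Shapes and values are hypothetical.
import numpy as np

rng = np.random.default_rng(0)
n_rois, n_seqs = 50, 40
for exp in ["exp", "unexp"]:
    # first and last quantile: ROI x sequence integrated activity
    first, last = (rng.normal(size=(n_rois, n_seqs)) for _ in range(2))
    mag = np.abs(last.mean(axis=1) - first.mean(axis=1))  # per-ROI magnitude
    print(exp, mag.mean())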
Example #4
def run_traces_by_qu_lock_sess(sessions,
                               analysis,
                               seed,
                               analyspar,
                               sesspar,
                               stimpar,
                               quantpar,
                               figpar,
                               datatype="roi"):
    """
    run_traces_by_qu_lock_sess(sessions, analysis, seed, analyspar, sesspar, 
                               stimpar, quantpar, figpar)

    Retrieves trace statistics by session x quantile at the transition of
    expected to unexpected sequences (or v.v.) and plots traces across ROIs by 
    quantile with each session in a separate subplot.
    
    Also runs the analysis for one quantile (full data), with different 
    unexpected sequence lengths grouped separately.
    
    Saves results and parameters relevant to analysis in a dictionary.

    Required args:
        - sessions (list)      : list of Session objects
        - analysis (str)       : analysis type (e.g., "l")
        - seed (int)           : seed value to use. (-1 treated as None)
        - analyspar (AnalysPar): named tuple containing analysis parameters
        - sesspar (SessPar)    : named tuple containing session parameters
        - stimpar (StimPar)    : named tuple containing stimulus parameters
        - quantpar (QuantPar)  : named tuple containing quantile analysis 
                                 parameters
        - figpar (dict)        : dictionary containing figure parameters
    
    Optional args:
        - datatype (str): type of data (e.g., "roi", "run")

    """

    sessstr_pr = sess_str_util.sess_par_str(sesspar.sess_n, stimpar.stimtype,
                                            sesspar.plane, stimpar.visflow_dir,
                                            stimpar.visflow_size, stimpar.gabk,
                                            "print")
    dendstr_pr = sess_str_util.dend_par_str(analyspar.dend, sesspar.plane,
                                            datatype, "print")

    datastr = sess_str_util.datatype_par_str(datatype)

    logger.info(
        f"Analysing and plotting unexpected vs expected {datastr} "
        f"traces locked to unexpected onset by quantile ({quantpar.n_quants}) "
        f"\n({sessstr_pr}{dendstr_pr}).",
        extra={"spacing": "\n"})

    seed = rand_util.seed_all(seed, "cpu", log_seed=False)

    # create quantpar versions for one quantile (full data) and for all 
    # quantiles of interest
    quantpar_one = sess_ntuple_util.init_quantpar(1, 0)
    n_quants = quantpar.n_quants
    quantpar_mult = sess_ntuple_util.init_quantpar(n_quants, "all")

    if stimpar.stimtype == "visflow":
        pre_post = [2.0, 6.0]
    elif stimpar.stimtype == "gabors":
        pre_post = [2.0, 8.0]
    else:
        gen_util.accepted_values_error("stimpar.stimtype", stimpar.stimtype,
                                       ["visflow", "gabors"])
    logger.warning("Setting pre to {}s and post to {}s.".format(*pre_post))

    stimpar = sess_ntuple_util.get_modif_ntuple(stimpar, ["pre", "post"],
                                                pre_post)

    figpar = copy.deepcopy(figpar)
    if figpar["save"]["use_dt"] is None:
        figpar["save"]["use_dt"] = gen_util.create_time_str()

    for baseline in [None, stimpar.pre]:
        basestr_pr = sess_str_util.base_par_str(baseline, "print")
        for quantpar in [quantpar_one, quantpar_mult]:
            locks = ["unexp", "exp"]
            if quantpar.n_quants == 1:
                locks.append("unexp_split")
            # get the stats (all) separating by session and quantiles
            for lock in locks:
                logger.info(
                    f"{quantpar.n_quants} quant, {lock} lock{basestr_pr}",
                    extra={"spacing": "\n"})
                if lock == "unexp_split":
                    trace_info = quant_analys.trace_stats_by_exp_len_sess(
                        sessions,
                        analyspar,
                        stimpar,
                        quantpar.n_quants,
                        quantpar.qu_idx,
                        byroi=False,
                        nan_empty=True,
                        baseline=baseline,
                        datatype=datatype)
                else:
                    trace_info = quant_analys.trace_stats_by_qu_sess(
                        sessions,
                        analyspar,
                        stimpar,
                        quantpar.n_quants,
                        quantpar.qu_idx,
                        byroi=False,
                        lock=lock,
                        nan_empty=True,
                        baseline=baseline,
                        datatype=datatype)

                # for comparison, locking to middle of expected sample (1 quant)
                exp_samp = quant_analys.trace_stats_by_qu_sess(
                    sessions,
                    analyspar,
                    stimpar,
                    quantpar_one.n_quants,
                    quantpar_one.qu_idx,
                    byroi=False,
                    lock="exp_samp",
                    nan_empty=True,
                    baseline=baseline,
                    datatype=datatype)

                extrapar = {
                    "analysis": analysis,
                    "datatype": datatype,
                    "seed": seed,
                }

                xrans = [xran.tolist() for xran in trace_info[0]]
                all_stats = [sessst.tolist() for sessst in trace_info[1]]
                exp_stats = [expst.tolist() for expst in exp_samp[1]]
                trace_stats = {
                    "xrans": xrans,
                    "all_stats": all_stats,
                    "all_counts": trace_info[2],
                    "lock": lock,
                    "baseline": baseline,
                    "exp_stats": exp_stats,
                    "exp_counts": exp_samp[2]
                }

                if lock == "unexp_split":
                    trace_stats["unexp_lens"] = trace_info[3]

                sess_info = sess_gen_util.get_sess_info(
                    sessions,
                    analyspar.fluor,
                    incl_roi=(datatype == "roi"),
                    rem_bad=analyspar.rem_bad)

                info = {
                    "analyspar": analyspar._asdict(),
                    "sesspar": sesspar._asdict(),
                    "stimpar": stimpar._asdict(),
                    "quantpar": quantpar._asdict(),
                    "extrapar": extrapar,
                    "sess_info": sess_info,
                    "trace_stats": trace_stats
                }

                fulldir, savename = gen_plots.plot_traces_by_qu_lock_sess(
                    figpar=figpar, **info)
                file_util.saveinfo(info, savename, fulldir, "json")
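# Hedged aside: sess_ntuple_util.get_modif_ntuple (used above) appears to act
# like namedtuple's built-in _replace, returning a copy with the given fields
# updated. A standalone illustration with a simplified StimPar:
from collections import namedtuple

StimParExample = namedtuple("StimParExample", ["stimtype", "pre", "post"])
stimpar_ex = StimParExample("gabors", 0.0, 1.5)
stimpar_ex = stimpar_ex._replace(pre=2.0, post=8.0)  # analogous modification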
Example #5
def run_sess_logreg(sess,
                    analyspar,
                    stimpar,
                    logregpar,
                    n_splits=100,
                    n_shuff_splits=300,
                    seed=None,
                    parallel=False):
    """
    run_sess_logreg(sess, analyspar, stimpar, logregpar)

    Runs logistic regressions on a session (real data and shuffled), and 
    returns statistics dataframes.

    Required args:
        - sess (Session): 
            Session object
        - analyspar (AnalysPar): 
            named tuple containing analysis parameters
        - stimpar (StimPar): 
            named tuple containing stimulus parameters
        - logregpar (LogRegPar): 
            named tuple containing logistic regression parameters

    Optional args:
        - n_splits (int):
            number of data splits to run logistic regressions on
            default: 100
        - n_shuff_splits (int):
            number of shuffled data splits to run logistic regressions on
            default: 300
        - seed (int): 
            seed value to use. (-1 treated as None)
            default: None
        - parallel (bool): 
            if True, some of the analysis is run in parallel across CPU cores 
            default: False

    Returns:
        - data_stats_df (pd.DataFrame):
            dataframe with only one data row containing data stats for each 
            score and data subset.
        - shuffle_df (pd.DataFrame):
            dataframe where each row contains data for different data 
            shuffles, and each column contains data for each score and data 
            subset.
    """

    seed = rand_util.seed_all(seed, log_seed=False, seed_now=False)

    # retrieve data
    input_data, target_data, ctrl_ns = get_decoding_data(sess,
                                                         analyspar,
                                                         stimpar,
                                                         comp=logregpar.comp,
                                                         ctrl=logregpar.ctrl)

    scores_df = misc_analys.get_check_sess_df([sess], None, analyspar)
    common_columns = scores_df.columns.tolist()
    logreg_columns = ["comp", "ctrl", "bal", "shuffle"]

    # do checks
    if logregpar.q1v4 or logregpar.exp_v_unexp:
        raise NotImplementedError("q1v4 and exp_v_unexp are not implemented.")
    if n_splits <= 0 or n_shuff_splits <= 0:
        raise ValueError("n_splits and n_shuff_splits must be greater than 0.")

    set_types = ["train", "test"]
    score_types = ["neg_log_loss", "accuracy", "balanced_accuracy"]
    set_score_types = list(itertools.product(set_types, score_types))

    extrapar = dict()
    for shuffle in [False, True]:
        n_runs = n_shuff_splits if shuffle else n_splits
        extrapar["shuffle"] = shuffle

        temp_dfs = []
        for b, n in enumerate(range(0, n_runs, MAX_SIMULT_RUNS)):
            extrapar["n_runs"] = int(np.min([MAX_SIMULT_RUNS, n_runs - n]))

            with logger_util.TempChangeLogLevel(level="warning"):
                mod_cvs, _, _ = logreg_util.run_logreg_cv_sk(
                    input_data,
                    target_data,
                    logregpar._asdict(),
                    extrapar,
                    analyspar.scale,
                    ctrl_ns,
                    randst=seed + b,
                    parallel=parallel,
                    save_models=False,
                    catch_set_prob=False)

            temp_df = pd.DataFrame()
            for set_type, score_type in set_score_types:
                key = f"{set_type}_{score_type}"
                temp_df[key] = mod_cvs[key]
            temp_dfs.append(temp_df)

        # compile batch scores, and get session stats for non-shuffled data
        temp_df = pd.concat(temp_dfs, ignore_index=True)
        if not shuffle:
            temp_df = get_df_stats(temp_df, analyspar)

        # add columns to df
        score_columns = temp_df.columns.tolist()
        for col in common_columns:
            temp_df[col] = scores_df.loc[0, col]
        for col in logreg_columns:
            if col != "shuffle":
                temp_df[col] = logregpar._asdict()[col]
            else:
                temp_df[col] = shuffle

        # re-sort columns
        temp_df = temp_df.reindex(common_columns + logreg_columns +
                                  score_columns,
                                  axis=1)

        if shuffle:
            shuffle_df = temp_df
        else:
            data_stats_df = temp_df

    return data_stats_df, shuffle_df
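# Hedged sketch of the batching pattern used above: n_runs splits are
# processed in batches of at most MAX_SIMULT_RUNS, each batch getting a
# distinct random state (seed + batch index) so results stay reproducible.
# The cap value here is hypothetical.
MAX_SIMULT_RUNS_EXAMPLE = 32
n_runs_ex, seed_ex = 100, 905
for b, n in enumerate(range(0, n_runs_ex, MAX_SIMULT_RUNS_EXAMPLE)):
    batch_runs = min(MAX_SIMULT_RUNS_EXAMPLE, n_runs_ex - n)
    print(f"batch {b}: {batch_runs} runs, random state {seed_ex + b}")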
Example #6
def reformat_args(args):
    """
    reformat_args(args)

    Returns reformatted args for analyses, specifically 
        - Sets stimulus parameters to "none" if they are irrelevant to the 
          stimtype
        - Changes stimulus parameters from "both" to actual values
        - Sets the seed value, though does not actually seed random processes
        - Modifies analyses (if "all" or "all_" in parameter)

    Adds the following args:
        - omit_sess (str): sess to omit
        - omit_mice (str): mice to omit

    Required args:
        - args (Argument parser): parser with the following attributes: 
            visflow_dir (str)    : visual flow direction values to include
                                   (e.g., "right", "left" or "both")
            visflow_size (int or str): visual flow size values to include
                                   (e.g., 128, 256, "both")
            gabfr (int)          : gabor frame value to start sequences at
                                   (e.g., 0, 1, 2, 3)
            gabk (int or str)    : gabor kappa values to include 
                                   (e.g., 4, 16 or "both")
            gab_ori (int or str) : gabor orientation values to include
                                   (e.g., 0, 45, 90, 135, 180, 225 or "all")
            runtype (str)        : runtype ("pilot" or "prod")
            stimtype (str)       : stimulus to analyse (visflow or gabors)
    
    Returns:
        - args (Argument parser): input parser, with the following attributes
                                  modified: 
                                      visflow_dir, visflow_size, gabfr, gabk, 
                                      gab_ori, analyses, seed
                                  and the following attributes added:
                                      omit_sess, omit_mice
    """
    args = copy.deepcopy(args)

    [args.visflow_dir, args.visflow_size, args.gabfr, args.gabk,
     args.gab_ori] = sess_gen_util.get_params(args.stimtype, args.visflow_dir,
                                              args.visflow_size, args.gabfr,
                                              args.gabk, args.gab_ori)

    args.omit_sess, args.omit_mice = sess_gen_util.all_omit(
        args.stimtype, args.runtype, args.visflow_dir, args.visflow_size,
        args.gabk)

    # choose a seed if none is provided (i.e., args.seed=-1), but seed later
    args.seed = rand_util.seed_all(args.seed,
                                   "cpu",
                                   log_seed=False,
                                   seed_now=False)

    # collect analysis letters
    all_analyses = "".join(get_analysis_fcts().keys())
    if "all" in args.analyses:
        if "_" in args.analyses:
            excl = args.analyses.split("_")[1]
            args.analyses, _ = gen_util.remove_lett(all_analyses, excl)
        else:
            args.analyses = all_analyses
    elif "_" in args.analyses:
        raise ValueError("Use '_' in args.analyses only with 'all'.")

    return args
Example #7
def run_glms(sessions,
             analysis,
             seed,
             analyspar,
             sesspar,
             stimpar,
             glmpar,
             figpar,
             parallel=False):
    """
    run_glms(sessions, analysis, seed, analyspar, sesspar, stimpar, glmpar, 
             figpar)
    """

    seed = rand_util.seed_all(seed, "cpu", log_seed=False)

    sessstr_pr = sess_str_util.sess_par_str(sesspar.sess_n, stimpar.stimtype,
                                            sesspar.plane, stimpar.visflow_dir,
                                            stimpar.visflow_size, stimpar.gabk,
                                            "print")
    dendstr_pr = sess_str_util.dend_par_str(analyspar.dend, sesspar.plane,
                                            "roi", "print")

    logger.info(
        "Analysing and plotting explained variance in ROI activity "
        f"({sessstr_pr}{dendstr_pr}).",
        extra={"spacing": "\n"})

    if glmpar.each_roi:  # must do each session separately
        glm_type = "per_ROI_per_sess"
        sess_batches = sessions
        logger.info("Per ROI, each session separately.")
    else:
        glm_type = "across_sess"
        sess_batches = [sessions]
        logger.info(f"Across ROIs, {len(sessions)} sessions together.")

    # optionally run in parallel here (across sessions, if analysing each 
    # ROI per session), or propagate parallel to the next level
    parallel_here = (parallel and glmpar.each_roi
                     and (len(sess_batches) != 1))
    parallel_after = parallel and not parallel_here

    args_list = [analyspar, sesspar, stimpar, glmpar]
    args_dict = {
        "parallel": parallel_after,  # proactively set next parallel 
        "seed": seed,
    }
    all_expl_var = gen_util.parallel_wrap(run_glm,
                                          sess_batches,
                                          args_list,
                                          args_dict,
                                          parallel=parallel_here)

    if glmpar.each_roi:
        sessions = sess_batches
    else:
        sessions = sess_batches[0]

    sess_info = sess_gen_util.get_sess_info(sessions,
                                            analyspar.fluor,
                                            rem_bad=analyspar.rem_bad)

    extrapar = {
        "analysis": analysis,
        "seed": seed,
        "glm_type": glm_type,
    }

    info = {
        "analyspar": analyspar._asdict(),
        "sesspar": sesspar._asdict(),
        "stimpar": stimpar._asdict(),
        "glmpar": glmpar._asdict(),
        "extrapar": extrapar,
        "all_expl_var": all_expl_var,
        "sess_info": sess_info
    }

    fulldir, savename = glm_plots.plot_glm_expl_var(figpar=figpar, **info)

    file_util.saveinfo(info, savename, fulldir, "json")

    return
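# Hedged sketch of the parallelization choice above, written with joblib
# directly (gen_util.parallel_wrap is this repo's own wrapper, whose
# internals are not shown here; _run_one is a hypothetical placeholder):
from joblib import Parallel, delayed

def _run_one(batch, scale=1, parallel=False):
    # stands in for run_glm on one session batch
    return batch * scale

batches = [1, 2, 3]
out = Parallel(n_jobs=-1)(delayed(_run_one)(b, scale=2) for b in batches)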
def run_sess_lstm(sessid, args):

    if args.parallel and args.plt_bkend is not None:
        plt.switch_backend(args.plt_bkend) # needs to be repeated within joblib

    args.seed = rand_util.seed_all(args.seed, args.device, seed_torch=True)

    train_p = 0.8
    lr = 1. * 10**(-args.lr_ex)
    if args.conv:
        conv_str = "_conv"
        outch_str = f"_{args.out_ch}outch"
    else:
        conv_str = ""
        outch_str = ""

    # Input output parameters
    n_stim_s  = 0.6
    n_roi_s = 0.3

    # Stim/traces for training
    train_gabfr = 0
    train_post = 0.9 # up to C
    roi_train_pre = 0 # from A
    stim_train_pre = 0.3  # from preceding grayscreen

    # Stim/traces for testing (separated for unexp vs exp)
    test_gabfr = 3
    test_post  = 0.6 # up to grayscreen
    roi_test_pre = 0 # from D/U
    stim_test_pre = 0.3  # from preceding C

    sess = sess_gen_util.init_sessions(
        sessid, args.datadir, args.mouse_df, args.runtype, full_table=False, 
        fluor="dff", dend="extr", run=True, temp_log="warning")[0]

    analysdir = sess_gen_util.get_analysdir(
        sess.mouse_n, sess.sess_n, sess.plane, stimtype=args.stimtype, 
        comp=None)
    dirname = Path(args.output, analysdir)
    file_util.createdir(dirname, log_dir=False)

    # ROI and running data must not be scaled beforehand: scaling is done 
    # afterwards, so that only data available to the network is used.

    # seq x frame x gabor x par
    logger.info("Preparing stimulus parameter dataframe", 
        extra={"spacing": "\n"})
    train_stim_wins, run_stats = sess_data_util.get_stim_data(
        sess, args.stimtype, n_stim_s, train_gabfr, stim_train_pre, 
        train_post, gabk=16, run=True)

    logger.info("Adding ROI data")
    xran, train_roi_wins, roi_stats = sess_data_util.get_roi_data(
        sess, args.stimtype, n_roi_s, train_gabfr, roi_train_pre, train_post, 
        gabk=16)

    logger.warning("Preparing windowed datasets (too slow - to be improved)")
    raise NotImplementedError("Not implemented properly - some error leads "
        "to excessive memory requests.")
    test_stim_wins = []
    test_roi_wins  = []
    for unexp in [0, 1]:
        stim_wins = sess_data_util.get_stim_data(
            sess, args.stimtype, n_stim_s, test_gabfr, stim_test_pre, 
            test_post, unexp, gabk=16, run_mean=run_stats[0], 
            run_std=run_stats[1])
        test_stim_wins.append(stim_wins)
        
        roi_wins = sess_data_util.get_roi_data(sess, args.stimtype, n_roi_s,  
                           test_gabfr, roi_test_pre, test_post, unexp, gabk=16, 
                           roi_means=roi_stats[0], roi_stds=roi_stats[1])[1]
        test_roi_wins.append(roi_wins)

    n_pars = train_stim_wins.shape[-1] # n parameters (121)
    n_rois = train_roi_wins.shape[-1] # n ROIs

    hyperstr = (f"{args.hidden_dim}hd_{args.num_layers}hl_{args.lr_ex}lrex_"
                f"{args.batchsize}bs{outch_str}{conv_str}")

    dls = data_util.create_dls(train_stim_wins, train_roi_wins, train_p=train_p, 
                            test_p=0, batchsize=args.batchsize, thresh_cl=0, 
                            train_shuff=True)[0]
    train_dl, val_dl, _ = dls

    test_dls = []
    
    for s in [0, 1]:
        dl = data_util.init_dl(test_stim_wins[s], test_roi_wins[s], 
                            batchsize=args.batchsize)
        test_dls.append(dl)

    logger.info("Running LSTM")
    if args.conv:
        lstm = ConvPredROILSTM(args.hidden_dim, n_rois, out_ch=args.out_ch, 
                            num_layers=args.num_layers, dropout=args.dropout)
    else:
        lstm = PredLSTM(n_pars, args.hidden_dim, n_rois, 
                        num_layers=args.num_layers, dropout=args.dropout)

    lstm = lstm.to(args.device)
    lstm.loss_fn = torch.nn.MSELoss(reduction="sum")  # sum, not mean
    lstm.opt = torch.optim.Adam(lstm.parameters(), lr=lr)

    loss_df = pd.DataFrame(
        np.nan, index=range(args.n_epochs), columns=["train", "val"])
    min_val = np.inf
    for ep in range(args.n_epochs):
        logger.info(f"====> Epoch {ep}", extra={"spacing": "\n"})
        if ep == 0:
            train_loss = run_dl(lstm, train_dl, args.device, train=False)    
        else:
            train_loss = run_dl(lstm, train_dl, args.device, train=True)
        val_loss = run_dl(lstm, val_dl, args.device, train=False)
        loss_df["train"].loc[ep] = train_loss/train_dl.dataset.n_samples
        loss_df["val"].loc[ep] = val_loss/val_dl.dataset.n_samples
        logger.info(f"Training loss  : {loss_df['train'].loc[ep]}")
        logger.info(f"Validation loss: {loss_df['val'].loc[ep]}")

        # record model if the validation loss reaches a new low
        if ep == 0 or val_loss < min_val:
            prev_model = glob.glob(str(Path(dirname, f"{hyperstr}_ep*.pth")))
            prev_df = glob.glob(str(Path(dirname, f"{hyperstr}.csv")))
            min_val = val_loss
            saved_ep = ep
                
            if len(prev_model) == 1 and len(prev_df) == 1:
                Path(prev_model[0]).unlink()
                Path(prev_df[0]).unlink()

            savename = f"{hyperstr}_ep{ep}"
            savefile = Path(dirname, savename)
        
            torch.save({"net": lstm.state_dict(), "opt": lstm.opt.state_dict()},
                        f"{savefile}.pth")
        
            file_util.saveinfo(loss_df, hyperstr, dirname, "csv")

    plot_util.linclab_plt_defaults(font=["Arial", "Liberation Sans"], 
                                   fontdir=DEFAULT_FONTDIR)
    fig, ax = plt.subplots(1)
    for dataset in ["train", "val"]:
        plot_util.plot_traces(ax, range(args.n_epochs), np.asarray(loss_df[dataset]), 
                  label=dataset, title=f"Average loss (MSE) ({n_rois} ROIs)", 
                  xticks="auto")
    fig.savefig(Path(dirname, f"{hyperstr}_loss"))

    savemod = Path(dirname, f"{hyperstr}_ep{saved_ep}.pth")
    checkpoint = torch.load(savemod)
    lstm.load_state_dict(checkpoint["net"]) 

    n_samples = 20
    val_idx = np.random.choice(range(val_dl.dataset.n_samples), n_samples)
    val_samples = val_dl.dataset[val_idx]
    xrans = data_util.get_win_xrans(xran, val_samples[1].shape[1], val_idx.tolist())

    fig, ax = plot_util.init_fig(n_samples, ncols=4, sharex=True, subplot_hei=2, 
                                subplot_wid=5)


    lstm.eval()
    with torch.no_grad():
        batch_len, seq_len, n_items = val_samples[1].shape
        pred_tr = lstm(val_samples[0].transpose(1, 0).to(args.device))
        pred_tr = pred_tr.view([seq_len, batch_len, n_items]).transpose(1, 0)

    for lab, data in zip(["target", "pred"], [val_samples[1], pred_tr]):
        data = data.cpu().numpy()  # predictions may be on a non-CPU device
        for n in range(n_samples):
            roi_n = np.random.choice(range(data.shape[-1]))
            sub_ax = plot_util.get_subax(ax, n)
            plot_util.plot_traces(sub_ax, xrans[n], data[n, :, roi_n], 
                label=lab, xticks="auto")
            plot_util.set_ticks(sub_ax, "x", xran[0], xran[-1], n=7)

    sess_plot_util.plot_labels(ax, train_gabfr, plot_vals="exp", pre=roi_train_pre, 
                            post=train_post)

    fig.suptitle(f"Target vs predicted validation traces ({n_rois} ROIs)")
    fig.savefig(Path(dirname, f"{hyperstr}_traces"))
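# Hedged sketch of the checkpointing pattern used above: keep only the weights
# from the epoch with the lowest validation loss (losses hypothetical).
import numpy as np

min_val_ex, saved_ep_ex = np.inf, 0
for ep, val_loss in enumerate([0.9, 0.7, 0.8, 0.6]):
    if ep == 0 or val_loss < min_val_ex:
        min_val_ex, saved_ep_ex = val_loss, ep
        # a real run would save the model here and delete the previous best
print(f"best epoch: {saved_ep_ex}, val loss: {min_val_ex}")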
def reformat_args(args):
    """
    reformat_args(args)

    Returns reformatted args for analyses, specifically 
        - Sets stimulus parameters to "none" if they are irrelevant to the 
          stimtype
        - Changes stimulus parameters from "both" to actual values
        - Modifies the session number parameter
        - Sets the seed value, though does not actually seed random processes
        - Modifies analyses (if "all" or "all_" in parameter)
        - Sets latency parameters based on lat_method

    Adds the following args:
        - dend (str)     : type of dendrites to use ("allen" or "extr")
        - omit_sess (str): sess to omit
        - omit_mice (str): mice to omit

    Required args:
        - args (Argument parser): parser with the following attributes: 
            visflow_dir (str)        : visual flow direction values to include
                                   (e.g., "right", "left" or "both")
            visflow_size (int or str): visual flow size values to include
                                   (e.g., 128, 256, "both")
            gabfr (int)          : gabor frame value to start sequences at
                                   (e.g., 0, 1, 2, 3)
            gabk (int or str)    : gabor kappa values to include 
                                   (e.g., 4, 16 or "both")
            gab_ori (int or str) : gabor orientation values to include
                                   (e.g., 0, 45, 90, 135, 180, 225 or "all")
            mouse_ns (str)       : mouse numbers or range 
                                   (e.g., 1, "1,3", "1-3", "all")
            runtype (str)        : runtype ("pilot" or "prod")
            sess_n (str)         : session number range (e.g., "1-1", "all")
            stimtype (str)       : stimulus to analyse (visflow or gabors)
    
    Returns:
        - args (Argument parser): input parser, with the following attributes
                                  modified: 
                                      visflow_dir, visflow_size, gabfr, gabk, 
                                      gab_ori, sess_n, mouse_ns, analyses, seed, 
                                      lat_p_val_thr, lat_rel_std
                                  and the following attributes added:
                                      omit_sess, omit_mice, dend
    """

    args = copy.deepcopy(args)

    if args.plane == "soma":
        args.dend = "allen"

    [args.visflow_dir, args.visflow_size, args.gabfr, args.gabk,
     args.gab_ori] = sess_gen_util.get_params(args.stimtype, args.visflow_dir,
                                              args.visflow_size, args.gabfr,
                                              args.gabk, args.gab_ori)

    if args.datatype == "run":
        args.fluor = "n/a"

    args.omit_sess, args.omit_mice = sess_gen_util.all_omit(
        args.stimtype, args.runtype, args.visflow_dir, args.visflow_size,
        args.gabk)

    if "-" in str(args.sess_n):
        vals = str(args.sess_n).split("-")
        if len(vals) != 2:
            raise ValueError(
                "If args.sess_n is a range, must have format 1-3.")
        st = int(vals[0])
        end = int(vals[1]) + 1
        args.sess_n = list(range(st, end))

    if args.lat_method == "ratio":
        args.lat_p_val_thr = None
    elif args.lat_method == "ttest":
        args.lat_rel_std = None

    # choose a seed if none is provided (i.e., args.seed=-1), but seed later
    args.seed = rand_util.seed_all(args.seed,
                                   "cpu",
                                   log_seed=False,
                                   seed_now=False)

    # collect mouse numbers from args.mouse_ns
    if "," in args.mouse_ns:
        args.mouse_ns = [int(n) for n in args.mouse_ns.split(",")]
    elif "-" in args.mouse_ns:
        vals = str(args.mouse_ns).split("-")
        if len(vals) != 2:
            raise ValueError(
                "If args.mouse_ns is a range, must have format 1-3.")
        st = int(vals[0])
        end = int(vals[1]) + 1
        args.mouse_ns = list(range(st, end))
    elif args.mouse_ns not in ["all", "any"]:
        args.mouse_ns = int(args.mouse_ns)

    # collect analysis letters
    all_analyses = "".join(get_analysis_fcts().keys())
    if "all" in args.analyses:
        if "_" in args.analyses:
            excl = args.analyses.split("_")[1]
            args.analyses, _ = gen_util.remove_lett(all_analyses, excl)
        else:
            args.analyses = all_analyses
    elif "_" in args.analyses:
        raise ValueError("Use '_' in args.analyses only with 'all'.")

    return args
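# Standalone illustration of the mouse_ns / sess_n range parsing above
# (helper name hypothetical):
def _parse_ns_example(ns_str):
    # "1,3" -> [1, 3]; "1-3" -> [1, 2, 3]; "all"/"any" kept as is; "2" -> 2
    if "," in ns_str:
        return [int(n) for n in ns_str.split(",")]
    if "-" in ns_str:
        st, end = (int(v) for v in ns_str.split("-"))
        return list(range(st, end + 1))
    if ns_str in ["all", "any"]:
        return ns_str
    return int(ns_str)

assert _parse_ns_example("1-3") == [1, 2, 3]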
Example #10
def plot_violin_data(sub_ax, xs, all_data, palette=None, dashes=None, 
                     seed=None):
    """
    plot_violin_data(sub_ax, xs, all_data)

    Plots violin data for each data group.

    Required args:
        - sub_ax (plt subplot):
            subplot
        - xs (list):
            x value for each data group
        - all_data (list):
            data for each data group

    Optional args:
        - palette (list): 
            colors for each data group
            default: None
        - dashes (list): 
            dash patterns for each data group
            default: None
        - seed (int): 
            seed value to use. (-1 treated as None)
            default: None
    """

    # seed for scatterplot
    rand_util.seed_all(seed, log_seed=False)

    # checks
    if len(xs) != len(all_data):
        raise ValueError("xs must have the same length as all_data.")
    
    if palette is not None and len(xs) != len(palette):
        raise ValueError(
            "palette, if provided, must have the same length as xs."
            )

    if dashes is not None and len(xs) != len(dashes):
        raise ValueError(
            "dashes, if provided, must have the same length as xs."
            )

    xs_arr = np.concatenate([
        np.full_like(data, x) for x, data in zip(xs, all_data)
    ])
    data_arr = np.concatenate(all_data)

    # add violins
    bplot = seaborn.violinplot(
        x=xs_arr, y=data_arr, inner=None, linewidth=3.5, color="white", 
        ax=sub_ax
    )

    # edit contours
    for c, collec in enumerate(bplot.collections):
        collec.set_edgecolor(plot_helper_fcts.NEARBLACK)

        if dashes is not None and dashes[c] is not None:
            collec.set_linestyle(plot_helper_fcts.VDASH)

    # add data dots
    seaborn.stripplot(
        x=xs_arr, y=data_arr, size=9, jitter=0.2, alpha=0.3, palette=palette, 
        ax=sub_ax
        )
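# Minimal usage sketch, with hypothetical data (assumes the module-level
# matplotlib, numpy and seaborn imports used above):
import matplotlib.pyplot as plt
import numpy as np

fig, sub_ax = plt.subplots()
rng = np.random.default_rng(7)
plot_violin_data(sub_ax, xs=[0, 1],
                 all_data=[rng.normal(size=40), rng.normal(1, size=40)])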