Example 1
def main(args):
    """
    main(args)

    Runs analyses with parser arguments.

    Required args:
        - args (Argument parser): parser with analysis parameters as attributes
    """

    # set logger to the specified level
    logger_util.set_level(level=args.log_level)

    # args.device = gen_util.get_device(args.cuda)
    args.fontdir = DEFAULT_FONTDIR if DEFAULT_FONTDIR.is_dir() else None

    if args.dict_path != "":
        plot_dicts.plot_from_dicts(Path(args.dict_path),
                                   source="glm",
                                   plt_bkend=args.plt_bkend,
                                   fontdir=args.fontdir,
                                   parallel=args.parallel,
                                   datetime=not (args.no_datetime),
                                   overwrite=args.overwrite)
    else:
        if args.datadir is None:
            args.datadir = DEFAULT_DATADIR
        else:
            args.datadir = Path(args.datadir)
        mouse_df = DEFAULT_MOUSE_DF_PATH

        args = reformat_args(args)

        # get numbers of sessions to analyse
        if args.sess_n == "all":
            all_sess_ns = sess_gen_util.get_sess_vals(mouse_df,
                                                      "sess_n",
                                                      runtype=args.runtype,
                                                      plane=args.plane,
                                                      line=args.line,
                                                      min_rois=args.min_rois,
                                                      pass_fail=args.pass_fail,
                                                      incl=args.incl,
                                                      omit_sess=args.omit_sess,
                                                      omit_mice=args.omit_mice)
        else:
            all_sess_ns = gen_util.list_if_not(args.sess_n)

        for sess_n in all_sess_ns:
            analys_pars = prep_analyses(sess_n, args, mouse_df)

            analyses_parallel = bool(args.parallel and not args.debug)
            run_analyses(*analys_pars,
                         analyses=args.analyses,
                         parallel=analyses_parallel)
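# A hypothetical launch pattern for main() (a sketch only: the parser-building
# helper named below is an assumption, not part of this module; the real script
# defines the argparse arguments that provide the attributes used above):
#
#     if __name__ == "__main__":
#         parser = create_parser()  # hypothetical argparse setup
#         main(parser.parse_args())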
def main(args):
    """
    main(args)

    Runs analyses with parser arguments.

    Required args:
        - args (Argument parser): parser with analysis parameters as attributes
    """

    # set logger to the specified level
    logger_util.set_level(level=args.log_level)

    args.device = "cpu"

    if args.datadir is None: 
        args.datadir = DEFAULT_DATADIR
    else:
        args.datadir = Path(args.datadir)
    args.mouse_df = DEFAULT_MOUSE_DF_PATH
    args.runtype = "prod"
    args.plane = "soma"
    args.stimtype = "gabors"

    args.omit_sess, args.omit_mice = sess_gen_util.all_omit(
        args.stimtype, args.runtype)

    all_sessids = sess_gen_util.get_sess_vals(
        args.mouse_df, "sessid", runtype=args.runtype, sess_n=[1, 2, 3], 
        plane=args.plane, min_rois=1, pass_fail="P", omit_sess=args.omit_sess, 
        omit_mice=args.omit_mice)

    # bsizes =[1, 15, 30] #3
    # outchs = [18, 9, 3]
    # hiddims = [100, 35, 5]
    # numlays = [3, 2, 1]
    # lr_exs = [4, 3, 5]
    # convs = [True, False]
    # args.n_epochs = 0

    gen_util.parallel_wrap(
        run_sess_lstm, all_sessids, args_list=[args], parallel=args.parallel)
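# NOTE: gen_util.parallel_wrap() is a project-specific helper. The stand-in
# below is an illustrative sketch of a wrapper with a similar calling pattern,
# inferred from the calls in this file; it is not the project's implementation.
from concurrent.futures import ProcessPoolExecutor


def simple_parallel_wrap(fct, loop_args, args_list=None, args_dict=None,
                         parallel=True):
    """
    simple_parallel_wrap(fct, loop_args)

    Calls fct(item, *args_list, **args_dict) for each item in loop_args,
    sequentially or across processes, and returns the list of results.
    """

    args_list = list(args_list) if args_list is not None else []
    args_dict = dict(args_dict) if args_dict is not None else {}

    if not parallel:
        return [fct(item, *args_list, **args_dict) for item in loop_args]

    with ProcessPoolExecutor() as executor:
        futures = [executor.submit(fct, item, *args_list, **args_dict)
                   for item in loop_args]
        return [future.result() for future in futures]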
def download_dandiset_assets(dandiset_id="000037",
                             version="draft",
                             output=".",
                             incl_stim_templates=False,
                             incl_full_stacks=False,
                             sess_ns="all",
                             mouse_ns="all",
                             excluded_sess=True,
                             mouse_df=DEFAULT_MOUSE_DF_PATH,
                             log_level="info"):

    logger_util.format_all(level=log_level)

    dandiset_id = f"{int(dandiset_id):06}"  # ensure correct ID formatting

    asset_sessids = "all"
    if sess_ns not in ["all", "any"] or mouse_ns not in ["all", "any"]:
        if dandiset_id != "000037":
            raise NotImplementedError(
                "Selecting assets based on session and mouse numbers is only "
                "implemented for dandiset 000037.")
        sess_ns = reformat_n(sess_ns)
        mouse_ns = reformat_n(mouse_ns)
        pass_fail = "all" if excluded_sess else "P"
        asset_sessids = sess_gen_util.get_sess_vals(mouse_df,
                                                    "dandi_session_id",
                                                    mouse_n=mouse_ns,
                                                    sess_n=sess_ns,
                                                    runtype="prod",
                                                    pass_fail=pass_fail,
                                                    incl="all",
                                                    sort=True)

    logger.info("Identifying the URLs of dandi assets to download...")
    dandiset_urls = get_dandiset_asset_urls(
        dandiset_id,
        version=version,
        asset_sessids=asset_sessids,
        incl_stim_templates=incl_stim_templates,
        incl_full_stacks=incl_full_stacks)

    logger.info(f"Downloading {len(dandiset_urls)} assets from "
                f"dandiset {dandiset_id}...")

    for dandiset_url in dandiset_urls:
        dandi_download.download(dandiset_url, output, existing="refresh")
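# A hypothetical call with illustrative values only: download the session 1-3
# assets of mice 1 and 2 from dandiset 000037 into a local "data" directory,
# keeping only sessions that passed (pass_fail == "P"):
#
#     download_dandiset_assets(
#         dandiset_id="000037", version="draft", output="data",
#         sess_ns=[1, 2, 3], mouse_ns=[1, 2], excluded_sess=False)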
Example 4
def main(args):
    """
    main(args)

    Runs analyses with parser arguments.

    Required args:
        - args (Argument parser): parser with analysis parameters as attributes
    """

    # set logger to the specified level
    logger_util.set_level(level=args.log_level)

    args.fontdir = DEFAULT_FONTDIR if DEFAULT_FONTDIR.is_dir() else None

    if args.dict_path is not None:
        source = "modif" if args.modif else "run"
        plot_dicts.plot_from_dicts(Path(args.dict_path),
                                   source=source,
                                   plt_bkend=args.plt_bkend,
                                   fontdir=args.fontdir,
                                   parallel=args.parallel,
                                   datetime=not (args.no_datetime),
                                   overwrite=args.overwrite)
    else:
        args = reformat_args(args)
        if args.datadir is None:
            args.datadir = DEFAULT_DATADIR
        else:
            args.datadir = Path(args.datadir)
        mouse_df = DEFAULT_MOUSE_DF_PATH

        # get numbers of sessions to analyse
        if args.sess_n == "all":
            all_sess_ns = sess_gen_util.get_sess_vals(mouse_df,
                                                      "sess_n",
                                                      runtype=args.runtype,
                                                      plane=args.plane,
                                                      line=args.line,
                                                      min_rois=args.min_rois,
                                                      pass_fail=args.pass_fail,
                                                      incl=args.incl,
                                                      omit_sess=args.omit_sess,
                                                      omit_mice=args.omit_mice,
                                                      sort=True)
        else:
            all_sess_ns = gen_util.list_if_not(args.sess_n)

        # get analysis parameters for each session number
        all_analys_pars = gen_util.parallel_wrap(prep_analyses,
                                                 all_sess_ns,
                                                 args_list=[args, mouse_df],
                                                 parallel=args.parallel)

        # split parallel from sequential analyses
        if args.parallel:
            run_seq = ""  # should be run parallel within analysis
            all_analyses = gen_util.remove_lett(args.analyses, run_seq)
            sess_parallels = [True, False]
            analyses_parallels = [False, True]
        else:
            all_analyses = [args.analyses]
            sess_parallels, analyses_parallels = [False], [False]

        for analyses, sess_parallel, analyses_parallel in zip(
                all_analyses, sess_parallels, analyses_parallels):
            if len(analyses) == 0:
                continue
            args_dict = {
                "analyses": analyses,
                "seed": args.seed,
                "parallel": analyses_parallel,
            }

            # run analyses for each parameter set
            gen_util.parallel_wrap(run_analyses,
                                   all_analys_pars,
                                   args_dict=args_dict,
                                   parallel=sess_parallel,
                                   mult_loop=True)
def run_regr(args):
    """
    run_regr(args)

    Runs logistic regressions on the specified comparison and range of
    sessions.

    Required args:
        - args (Argument parser): parser with analysis parameters as attributes:
            alg (str)             : algorithm to use ("sklearn" or "pytorch")
            bal (bool)            : if True, classes are balanced
            batchsize (int)       : number of samples the dataloader loads 
                                    per batch (for "pytorch" alg)
            visflow_dir (str)     : visual flow direction to analyse
            visflow_pre (float)   : number of seconds to include before visual 
                                    flow segments
            visflow_size (int or list): visual flow square sizes to include
            comp (str)            : type of comparison
            datadir (str)         : data directory
            dend (str)            : type of dendrites to use ("allen" or "dend")
            device (str)          : device name (i.e., "cuda" or "cpu")
            ep_freq (int)         : frequency at which to log loss to 
                                    console
            error (str)           : error to take, i.e., "std" (for std 
                                    or quantiles) or "sem" (for SEM or MAD)
            fluor (str)           : fluorescence trace type
            fontdir (str)         : directory in which additional fonts are 
                                    located
            gabfr (int)           : gabor frame of reference if comparison 
                                    is "unexp"
            gabk (int or list)    : gabor kappas to include
            gab_ori (list or str) : gabor orientations to include
            incl (str or list)    : sessions to include ("yes", "no", "all")
            lr (num)              : model learning rate (for "pytorch" alg)
            mouse_n (int)         : mouse number
            n_epochs (int)        : number of epochs
            n_reg (int)           : number of regular runs
            n_shuff (int)         : number of shuffled runs
            no_scale (bool)       : if True, ROIs are not scaled
            output (str)          : general directory in which to save 
                                    output
            parallel (bool)       : if True, runs are done in parallel
            plt_bkend (str)       : pyplot backend to use
            q1v4 (bool)           : if True, analysis is trained on first and 
                                    tested on last quartiles
            exp_v_unexp (bool)    : if True, analysis is trained on 
                                    expected and tested on unexpected sequences
            runtype (str)         : type of run ("prod" or "pilot")
            seed (int)            : seed to seed random processes with
            sess_n (int)          : session number
            stats (str)           : stats to take, i.e., "mean" or "median"
            stimtype (str)        : stim to analyse ("gabors" or "visflow")
            train_p (list)        : proportion of dataset to allocate to 
                                    training
            uniqueid (str or int) : unique ID for analysis
            wd (float)            : weight decay value (for "pytorch" alg)
    """

    args = copy.deepcopy(args)

    if args.datadir is None:
        args.datadir = DEFAULT_DATADIR
    else:
        args.datadir = Path(args.datadir)

    if args.uniqueid == "datetime":
        args.uniqueid = gen_util.create_time_str()
    elif args.uniqueid in ["None", "none"]:
        args.uniqueid = None

    reseed = False
    if args.seed in [None, "None"]:
        reseed = True

    # deal with parameters
    extrapar = {"uniqueid": args.uniqueid, "seed": args.seed}

    techpar = {
        "reseed": reseed,
        "device": args.device,
        "alg": args.alg,
        "parallel": args.parallel,
        "plt_bkend": args.plt_bkend,
        "fontdir": args.fontdir,
        "output": args.output,
        "ep_freq": args.ep_freq,
        "n_reg": args.n_reg,
        "n_shuff": args.n_shuff,
    }

    mouse_df = DEFAULT_MOUSE_DF_PATH

    stimpar = logreg.get_stimpar(args.comp,
                                 args.stimtype,
                                 args.visflow_dir,
                                 args.visflow_size,
                                 args.gabfr,
                                 args.gabk,
                                 gab_ori=args.gab_ori,
                                 visflow_pre=args.visflow_pre)

    analyspar = sess_ntuple_util.init_analyspar(args.fluor,
                                                stats=args.stats,
                                                error=args.error,
                                                scale=not (args.no_scale),
                                                dend=args.dend)

    if args.q1v4:
        quantpar = sess_ntuple_util.init_quantpar(4, [0, -1])
    else:
        quantpar = sess_ntuple_util.init_quantpar(1, 0)

    logregpar = sess_ntuple_util.init_logregpar(args.comp, not (args.not_ctrl),
                                                args.q1v4, args.exp_v_unexp,
                                                args.n_epochs, args.batchsize,
                                                args.lr, args.train_p, args.wd,
                                                args.bal, args.alg)

    omit_sess, omit_mice = sess_gen_util.all_omit(stimpar.stimtype,
                                                  args.runtype,
                                                  stimpar.visflow_dir,
                                                  stimpar.visflow_size,
                                                  stimpar.gabk)

    sessids = sorted(
        sess_gen_util.get_sess_vals(mouse_df,
                                    "sessid",
                                    args.mouse_n,
                                    args.sess_n,
                                    args.runtype,
                                    incl=args.incl,
                                    omit_sess=omit_sess,
                                    omit_mice=omit_mice))

    if len(sessids) == 0:
        logger.warning(
            f"No sessions found (mouse: {args.mouse_n}, sess: {args.sess_n}, "
            f"runtype: {args.runtype})")

    for sessid in sessids:
        sess = sess_gen_util.init_sessions(sessid,
                                           args.datadir,
                                           mouse_df,
                                           args.runtype,
                                           full_table=False,
                                           fluor=analyspar.fluor,
                                           dend=analyspar.dend,
                                           temp_log="warning")[0]
        logreg.run_regr(sess, analyspar, stimpar, logregpar, quantpar,
                        extrapar, techpar)

        plot_util.cond_close_figs()
Example 6
def init_sessions(analyspar,
                  sesspar,
                  mouse_df,
                  datadir,
                  sessions=None,
                  roi=True,
                  run=False,
                  pupil=False,
                  parallel=False):
    """
    init_sessions(analyspar, sesspar, mouse_df, datadir)

    Initializes sessions.

    Required args:
        - analyspar (AnalysPar): 
            named tuple containing analysis parameters
        - sesspar (SessPar): 
            named tuple containing session parameters
        - mouse_df (pandas df): 
            path name of dataframe containing information on each session
        - datadir (Path): 
            path to data directory
    
    Optional args:
        - sessions (list): 
            preloaded sessions
            default: None
        - roi (bool): 
            if True, ROI data is loaded
            default: True
        - run (bool): 
            if True, running data is loaded
            default: False
        - pupil (bool): 
            if True, pupil data is loaded
            default: False
        - parallel (bool): 
            if True, sessions are loaded in parallel across CPU cores
            default: False

    Returns:
        - sessions (list): 
            Session objects 
    """

    sesspar_dict = sesspar._asdict()
    sesspar_dict.pop("closest")

    # identify sessions needed
    sessids = sorted(
        sess_gen_util.get_sess_vals(mouse_df, "sessid", **sesspar_dict))

    if len(sessids) == 0:
        raise ValueError("No sessions meet the criteria.")

    # check for preloaded sessions, and only load new ones
    if sessions is not None:
        loaded_sessids = [session.sessid for session in sessions]
        ext_str = " additional"
    else:
        sessions = []
        loaded_sessids = []
        ext_str = ""

    # identify new sessions to load
    load_sessids = list(
        filter(lambda sessid: sessid not in loaded_sessids, sessids))

    # remove sessions that are not needed
    if len(sessions):
        sessions = [
            session for session in sessions if session.sessid in sessids
        ]

        # check that previously loaded sessions have roi/run/pupil data loaded
        args_list = [roi, run, pupil, analyspar.fluor, analyspar.dend]
        with logger_util.TempChangeLogLevel(level="warning"):
            sessions = gen_util.parallel_wrap(sess_gen_util.check_session,
                                              sessions,
                                              args_list=args_list,
                                              parallel=parallel)

    # load new sessions
    if len(load_sessids):
        logger.info(f"Loading {len(load_sessids)}{ext_str} session(s)...",
                    extra={"spacing": "\n"})

        args_dict = {
            "datadir": datadir,
            "mouse_df": mouse_df,
            "runtype": sesspar.runtype,
            "full_table": False,
            "fluor": analyspar.fluor,
            "dend": analyspar.dend,
            "roi": roi,
            "run": run,
            "pupil": pupil,
            "temp_log": "critical"  # suppress almost all logs 
        }

        new_sessions = gen_util.parallel_wrap(sess_gen_util.init_sessions,
                                              load_sessids,
                                              args_dict=args_dict,
                                              parallel=parallel,
                                              use_tqdm=True)

        # flatten list of new sessions, and add to full sessions list
        new_sessions = [sess for singles in new_sessions for sess in singles]
        sessions = sessions + new_sessions

    # combine session lists, and sort to match the order of sessids
    sessions = sorted(
        sessions, key=lambda session: sessids.index(session.sessid))

    # update ROI tracking parameters
    for sess in sessions:
        sess.set_only_tracked_rois(analyspar.tracked)

    return sessions
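# A hypothetical call, assuming analyspar and sesspar were built with the
# project's sess_ntuple_util helpers and that the default paths exist:
#
#     sessions = init_sessions(analyspar, sesspar, DEFAULT_MOUSE_DF_PATH,
#                              DEFAULT_DATADIR, roi=True, run=True,
#                              parallel=True)
#
# Passing the returned list back in via sessions=sessions on a later call
# reuses the already loaded Session objects and only loads any that are
# missing.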
def prep_analyses(sess_n, args, mouse_df, parallel=False):
    """
    prep_analyses(sess_n, args, mouse_df)

    Prepares named tuples and sessions for which to run analyses, based on the 
    arguments passed.

    Required args:
        - sess_n (int)          : session number to run analyses on, or 
                                  combination of session numbers to compare, 
                                  e.g. "1v2"
        - args (Argument parser): parser containing all parameters
        - mouse_df (pandas df)  : path name of dataframe containing information 
                                  on each session

    Optional args:
        - parallel (bool): if True, sessions are initialized in parallel 
                           across CPU cores 
                           default: False

    Returns:
        - sessions (list)      : list of sessions, or nested list per mouse 
                                 if sess_n is a combination
        - analysis_dict (dict): dictionary of analysis parameters 
                                (see init_param_cont())
    """

    args = copy.deepcopy(args)

    args.sess_n = sess_n

    analysis_dict = init_param_cont(args)
    analyspar, sesspar, stimpar = [
        analysis_dict[key] for key in ["analyspar", "sesspar", "stimpar"]
    ]

    roi = (args.datatype == "roi")
    run = (args.datatype == "run")

    sesspar_dict = sesspar._asdict()
    _ = sesspar_dict.pop("closest")

    [all_mouse_ns,
     all_sess_ns] = sess_gen_util.get_sess_vals(mouse_df,
                                                ["mouse_n", "sess_n"],
                                                omit_sess=args.omit_sess,
                                                omit_mice=args.omit_mice,
                                                unique=False,
                                                **sesspar_dict)

    if args.sess_n in ["any", "all"]:
        all_sess_ns = [n + 1 for n in range(max(all_sess_ns))]
    else:
        all_sess_ns = args.sess_n

    # get session IDs and create Sessions
    all_mouse_ns = sorted(set(all_mouse_ns))

    logger.info(f"Loading sessions for {len(all_mouse_ns)} mice...",
                extra={"spacing": "\n"})
    args_list = [
        all_sess_ns, sesspar, mouse_df, args.datadir, args.omit_sess,
        analyspar.fluor, analyspar.dend, roi, run
    ]
    sessions = gen_util.parallel_wrap(init_mouse_sess,
                                      all_mouse_ns,
                                      args_list=args_list,
                                      parallel=parallel,
                                      use_tqdm=True)

    check_all = set([sess for m_sess in sessions for sess in m_sess])
    if len(sessions) == 0 or check_all == {None}:
        raise RuntimeError("No sessions meet the criteria.")

    runtype_str = ""
    if sesspar.runtype != "prod":
        runtype_str = f" ({sesspar.runtype} data)"

    stim_str = stimpar.stimtype
    if stimpar.stimtype == "gabors":
        stim_str = "gabor"
    elif stimpar.stimtype == "visflow":
        stim_str = "visual flow"

    logger.info(
        f"Analysis of {sesspar.plane} responses to {stim_str} "
        f"stimuli{runtype_str}.\nSession {sesspar.sess_n}",
        extra={"spacing": "\n"})

    return sessions, analysis_dict
def init_mouse_sess(mouse_n,
                    all_sess_ns,
                    sesspar,
                    mouse_df,
                    datadir,
                    omit_sess=[],
                    fluor="dff",
                    dend="extr",
                    roi=True,
                    run=False,
                    pupil=False):
    """
    init_mouse_sess(mouse_n, all_sess_ns, sesspar, mouse_df, datadir)

    Initializes the sessions for the specified mouse.

    Required args:
        - mouse_n (int)       : mouse number
        - all_sess_ns (list)  : list of all sessions to include
        - sesspar (SessPar)   : named tuple containing session parameters
        - mouse_df (pandas df): path name of dataframe containing information 
                                  on each session
        - datadir (str)       : path to data directory
    
    Optional args:
        - omit_sess (list): list of sessions to omit
                            default: []
        - fluor (str)     : if "raw", raw ROI traces are used. If 
                            "dff", dF/F ROI traces are used.
                            default: "dff"
        - dend (str)      : type of dendrites to use ("allen" or "dend")
                            default: "extr"
        - roi (bool)      : if True, ROI data is loaded
                            default: True
        - run (bool)      : if True, running data is loaded
                            default: False
        - pupil (bool)    : if True, pupil data is loaded
                            default: False

    Returns:
        - mouse_sesses (list): list of Session objects for the specified mouse, 
                               with None in the position of missing sessions 
    """

    sesspar_dict = sesspar._asdict()
    sesspar_dict.pop("closest")

    mouse_sesses = []
    for sess_n in all_sess_ns:
        sesspar_dict["sess_n"] = sess_n
        sesspar_dict["mouse_n"] = mouse_n
        sessid = sess_gen_util.get_sess_vals(mouse_df,
                                             "sessid",
                                             omit_sess=omit_sess,
                                             **sesspar_dict)
        if len(sessid) == 0:
            sess = [None]
        elif len(sessid) > 1:
            raise RuntimeError(
                "Expected no more than 1 session per mouse/session number.")
        else:
            sess = sess_gen_util.init_sessions(sessid[0],
                                               datadir,
                                               mouse_df,
                                               sesspar.runtype,
                                               full_table=False,
                                               fluor=fluor,
                                               dend=dend,
                                               omit=roi,
                                               roi=roi,
                                               run=run,
                                               pupil=pupil,
                                               temp_log="warning")
            if len(sess) == 0:
                sess = [None]
        mouse_sesses.append(sess[0])

    return mouse_sesses
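# For example, with all_sess_ns = [1, 2, 3] and a mouse that is missing
# session 2, init_mouse_sess() would return [Session, None, Session]
# (hypothetical values), keeping session positions aligned across mice.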
Example 9
def prep_analyses(sess_n, args, mouse_df):
    """
    prep_analyses(sess_n, args, mouse_df)

    Prepares named tuples and sessions for which to run analyses, based on the 
    arguments passed.

    Required args:
        - sess_n (int)          : session number to run analyses on
        - args (Argument parser): parser containing all parameters
        - mouse_df (pandas df)  : path name of dataframe containing information 
                                  on each session

    Returns:
        - sessions (list)     : list of sessions, or nested list per mouse 
                                if sess_n is a combination to compare
        - analysis_dict (dict): dictionary of analysis parameters 
                                (see init_param_cont())
    """

    args = copy.deepcopy(args)

    args.sess_n = sess_n

    analysis_dict = init_param_cont(args)
    analyspar, sesspar, stimpar = [
        analysis_dict[key] for key in ["analyspar", "sesspar", "stimpar"]
    ]

    # get session IDs and create Sessions
    sesspar_dict = sesspar._asdict()
    _ = sesspar_dict.pop("closest")
    sessids = sess_gen_util.get_sess_vals(mouse_df,
                                          "sessid",
                                          omit_sess=args.omit_sess,
                                          omit_mice=args.omit_mice,
                                          **sesspar_dict)

    logger.info(f"Loading {len(sessids)} session(s)...",
                extra={"spacing": "\n"})

    args_dict = {
        "datadir": args.datadir,
        "mouse_df": mouse_df,
        "runtype": sesspar.runtype,
        "full_table": False,
        "fluor": analyspar.fluor,
        "dend": analyspar.dend,
        "run": True,
        "pupil": (sesspar.runtype != "pilot"),
        "temp_log": "warning",
    }

    sessions = gen_util.parallel_wrap(sess_gen_util.init_sessions,
                                      sessids,
                                      args_dict=args_dict,
                                      parallel=args.parallel,
                                      use_tqdm=True)

    # flatten list of sessions
    sessions = [sess for singles in sessions for sess in singles]

    # sort by mouse number
    sort_order = np.argsort([sess.mouse_n for sess in sessions])
    sessions = [sessions[s] for s in sort_order]

    runtype_str = ""
    if sesspar.runtype != "prod":
        runtype_str = f" ({sesspar.runtype} data)"

    stim_str = stimpar.stimtype
    if stimpar.stimtype == "gabors":
        stim_str = "gabor"
    elif stimpar.stimtype == "visflow":
        stim_str = "visual flow"

    logger.info(
        f"Analysis of {sesspar.plane} responses to {stim_str} "
        f"stimuli{runtype_str}.\nSession {sesspar.sess_n}",
        extra={"spacing": "\n"})

    return sessions, analysis_dict