Ejemplo n.º 1
0
def main():
    """Command-line entry point: fit a single model for one cell of a batch.

    Reads three positional arguments (batch, cell, model), runs
    ``nems.main.fit_single_model`` with ``autoplot=True`` and then calls
    ``pylab.show()`` so any figures it produced are displayed.
    """
    from nems.main import fit_single_model
    import argparse
    import pylab as pl

    # (name, converter, help text) for each positional CLI argument, in order.
    positional = (
        ('batch', int, 'Batch to use'),
        ('cell', str, 'Cell from batch to fit'),
        ('model', str, 'Model to fit'),
    )
    cli = argparse.ArgumentParser(description='Fit single model')
    for arg_name, arg_type, arg_help in positional:
        cli.add_argument(arg_name, type=arg_type, help=arg_help)
    opts = cli.parse_args()

    # fit_single_model expects (cellid, batch, modelname): the cell comes first.
    fit_single_model(opts.cell, opts.batch, opts.model, autoplot=True)
    pl.show()
Ejemplo n.º 2
0
@author: svd

This is meant to illustrate the range of use cases, both those already in place
and those that would be nice to support.
"""

import nems.main as nems
import nems.utilities as nu

# CASE 1
# Vanilla NAT model on a "nice" linear cell.
# saveInDB=False: the fit result is not written to the database.
cellid = "chn029d-a1"
batch = 271  #A1
modelname = "fb18ch100_wcg01_fir15_fit01"

stack = nems.fit_single_model(cellid, batch, modelname, saveInDB=False)

# CASE 2
# Pupil-dependent gain: stretch out all trials (rather than averaging across
# stimulus repetitions) so that gain can be adjusted according to pupil at
# each point in time.
cellid = "BOL006b-48-1"
batch = 293  # tone pip (PPS) + pupil
modelname = "parm50_wc01_fir10_pupgain_fit01_nested5"

stack = nems.fit_single_model(cellid, batch, modelname, saveInDB=False)

# CASE 3
# Pupil-dependent AND task-dependent gain. As in CASE 2, trials are stretched
# out in time; gain on the output of the STRF is now a function of both pupil
# and task condition. Typically you would want to add "_nested10" or similar
# to the end of the model name.
Ejemplo n.º 3
0
    if queueid:
        print("Starting QUEUEID={}".format(queueid))
        nd.update_job_start(queueid)

    if len(sys.argv) < 4:
        print('syntax: nems_fit_single cellid batch modelname')
        exit(-1)

    cellid = sys.argv[1]
    batch = sys.argv[2]
    modelname = sys.argv[3]

    print("Running fit_single_model({0},{1},{2})".format(
        cellid, batch, modelname))
    stack = nems.fit_single_model(cellid, batch, modelname, autoplot=False)

    print("Done with fit.")

    # Edit: added code to save preview image. -Jacob 7/6/2017
    preview_file = stack.quick_plot_save(mode="png")
    print("Preview saved to: {0}".format(preview_file))

    if db_exists:
        if queueid:
            pass
        else:
            queueid = None
        r_id = nd.save_results(stack, preview_file, queueid=queueid)
        print("Fit results saved to NarfResults, id={0}".format(r_id))
Ejemplo n.º 4
0
def _collect_stacks(modelnames, cellids, build_stack, fail_msg):
    """Build one stack per (model, cell) pair, recording the cells that fail.

    Args:
        modelnames: iterable of model names (iterated once).
        cellids: list of cell ids; iterated once per model.
        build_stack: callable ``(cellid, modelname) -> stack``.
        fail_msg: format string with one ``{}`` slot for the cell id, printed
            whenever ``build_stack`` raises.

    Returns:
        ``(stacks, problem_cells)``: two dicts keyed by model name, holding
        the successfully built stacks and the failed cell ids respectively.
    """
    stacks = {}
    problem_cells = {}
    for modelname in modelnames:
        built, failed = [], []
        for cellid in cellids:
            try:
                built.append(build_stack(cellid, modelname))
            except Exception:
                # Deliberately not a bare ``except``: KeyboardInterrupt and
                # SystemExit must still be able to stop a long batch run.
                print(fail_msg.format(cellid))
                failed.append(cellid)
        stacks[modelname] = built
        problem_cells[modelname] = failed
    return stacks, problem_cells


def _reload_ssa_stack(cellid, batch, modelname):
    """Reload a saved model and recompute its SSA index; refit on failure.

    The stack's ``metrics.ssa_index`` module is replaced in place by a fresh
    ``nmet.ssa_index`` configured with ``z_score='bootstrap'`` and
    ``significant_bins='window'``. If any reload step raises (including no
    ssa_index module being found), the model is fitted from scratch with
    ``nm.fit_single_model(..., autoplot=False)``. Errors from that fallback
    fit propagate to the caller.
    """
    print('############\n reloading {} \n########## \n '.format(cellid))
    try:
        stack = nu.io.load_single_model(cellid, batch, modelname)
        del_idx = nu.utils.find_modules(
            stack, mod_name='metrics.ssa_index')[0]
        # Drop the old module and its data, then re-insert at the same index.
        stack.remove(del_idx)
        del stack.data[del_idx]
        stack.insert(nmet.ssa_index,
                     idx=del_idx,
                     z_score='bootstrap',
                     significant_bins='window')
    except Exception:
        stack = nm.fit_single_model(cellid, batch, modelname, autoplot=False)
    return stack


def get_stacks(cell_ids='all',
               method='load',
               from_file='default',
               modelnames=('env100e_fir20_fit01_ssa',
                           'env100e_stp1pc_fir20_fit01_ssa'),
               return_problem_cells=False):
    """Collect model stacks for the SSA batch (296), keyed by model name.

    Args:
        cell_ids: ``'all'`` to use every cell of batch 296 (queried through
            ``ndb.get_batch_cells``), or an explicit list of cell ids.
        method: ``'load'`` reloads saved fits (recomputing the SSA index,
            refitting on failure); ``'fit'`` fits each cell locally via
            ``fit_single_cell``; ``'joblib'`` loads previously dumped stacks
            from ``filename``. ``'load'`` and ``'fit'`` dump their result to
            ``filename``, overwriting it.
        from_file: ``'default'`` or a file name inside the SSA_batch_296
            directory.
        modelnames: iterable of model names to process.
        return_problem_cells: when True, also return a dict mapping each
            model name to the list of cell ids that could not be processed
            (empty for ``'joblib'``).

    Returns:
        ``all_stacks`` (model name -> list of stacks), or
        ``(all_stacks, problem_cells)`` when ``return_problem_cells`` is True.

    Raises:
        ValueError: on an invalid ``from_file``, ``cell_ids`` or ``method``.
    """
    # Path the stacks are loaded from or saved to.
    if from_file == 'default':
        filename = '/home/mateo/nems/SSA_batch_296/171109_refreshed_full_batch_stacks'
    elif isinstance(from_file, str):
        filename = '/home/mateo/nems/SSA_batch_296/{}'.format(from_file)
    else:
        raise ValueError(
            'invalid from_file value, chose either "default" or a path')

    # Validate every argument before doing any DB or file work.
    if not (cell_ids == 'all' or isinstance(cell_ids, list)):
        raise ValueError('cell_ids has to be "all" or ar list of cell ids')
    if method not in ('load', 'fit', 'joblib'):
        raise ValueError(
            'method {} not suported, options are: "load", "joblib", "fit"'.
            format(method))

    batch = 296

    if method == 'joblib':
        # Cached result: no cell list needed, so skip the DB query entirely.
        # jl is presumably joblib -- TODO confirm against file imports.
        all_stacks = jl.load(filename)
        problem_cells = {}
    else:
        if cell_ids == 'all':
            input_cells = ndb.get_batch_cells(batch=batch)['cellid'].tolist()
        else:
            input_cells = cell_ids

        if method == 'load':
            def build(cellid, modelname):
                return _reload_ssa_stack(cellid, batch, modelname)

            fail_msg = 'reloading of {} failed, skipping to next cell'
        else:
            def build(cellid, modelname):
                print('############\n locally fitting {} \n########## \n '.
                      format(cellid))
                return fit_single_cell(cellid, batch, modelname)

            fail_msg = 'fitting of {} failed, skipping to next cell'

        all_stacks, problem_cells = _collect_stacks(
            modelnames, input_cells, build, fail_msg)
        jl.dump(all_stacks, filename)

    if return_problem_cells:
        return all_stacks, problem_cells
    return all_stacks