Example #1
File: utils.py Project: LBHB/NEMS
def get_default_savepath(modelspec):
    if get_setting('USE_NEMS_BAPHY_API'):
        results_dir = 'http://'+get_setting('NEMS_BAPHY_API_HOST')+":"+ \
                      str(get_setting('NEMS_BAPHY_API_PORT')) + '/results'
    else:
        results_dir = get_setting('NEMS_RESULTS_DIR')

    batch = modelspec.meta.get('batch', 0)
    exptid = modelspec.meta.get('exptid', 'DATA')
    siteid = modelspec.meta.get('siteid', exptid)
    cellid = modelspec.meta.get('cellid', siteid)
    cellids = modelspec.meta.get('cellids', siteid)

    if (siteid == 'DATA') and (type(cellids) is list) and len(cellids) > 1:
        if cellid == 'none':
            siteid = 'none'  # special siteid that uses all sites in a single recording
        else:
            siteid = cellids[0].split("-")[0]
        destination = os.path.join(results_dir, str(batch), siteid,
                                   modelspec.get_longname())
    else:
        destination = os.path.join(results_dir, str(batch), cellid,
                                   modelspec.get_longname())
    log.info('model save destination: %s', destination)
    return destination
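
A minimal usage sketch (the batch and cellid values below are hypothetical; assumes a modelspec whose meta was populated during fitting):

modelspec.meta['batch'] = 289              # hypothetical batch number
modelspec.meta['cellid'] = 'TAR010c-18-2'  # hypothetical cellid
destination = get_default_savepath(modelspec)
# -> <NEMS_RESULTS_DIR>/289/TAR010c-18-2/<modelspec longname>
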
Example #2
def weight_channels_heatmap(modelspec,
                            idx=None,
                            ax=None,
                            clim=None,
                            title=None,
                            chan_names=None,
                            wc_idx=0,
                            **options):
    """
    :param modelspec: modelspec object
    :param idx: index into modelspec
    :param ax:
    :param clim:
    :param title:
    :param chan_names: labels for x axis
    :param wc_idx:
    :param options:
    :return:
    """
    if idx is not None:
        # module has been specified
        coefficients = _get_wc_coefficients(modelspec[idx:], idx=0)
    else:
        # legacy behavior: get the wc_idx-th set of coefficients
        coefficients = _get_wc_coefficients(modelspec, idx=wc_idx)

    # normalize per channel:
    #coefficients /= np.std(coefficients, axis=0, keepdims=True)

    # make bigger dimension horizontal
    if coefficients.shape[0] > coefficients.shape[1]:
        ax = plot_heatmap(coefficients.T,
                          xlabel='Channel Out',
                          ylabel='Channel In',
                          ax=ax,
                          clim=clim,
                          title=title,
                          cmap=get_setting('WEIGHTS_CMAP'))
    else:
        ax = plot_heatmap(coefficients,
                          xlabel='Channel In',
                          ylabel='Channel Out',
                          ax=ax,
                          clim=clim,
                          title=title,
                          cmap=get_setting('WEIGHTS_CMAP'))

    if chan_names is None:
        chan_names = []
    elif type(chan_names) is int:
        chan_names = [chan_names]

    for i, c in enumerate(chan_names):
        plt.text(i, 0, c)

    return ax
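
Usage sketch, assuming a fitted modelspec that contains a weight_channels ('wc') module:

import matplotlib.pyplot as plt

fig, ax = plt.subplots()
# with idx=None, the wc_idx-th set of weight-channel coefficients is plotted
weight_channels_heatmap(modelspec, ax=ax, title='wc coefficients', wc_idx=0)
plt.show()
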
Example #3
def load_model_xform(cellid,
                     batch=271,
                     modelname="ozgf100ch18_wcg18x2_fir15x2_lvl1_dexp1_fit01",
                     eval_model=True,
                     only=None):
    '''
    Load a model that was previously fit via fit_model_xforms

    Parameters
    ----------
    cellid : str
        cellid in celldb database
    batch : int
        batch number in celldb database
    modelname : str
        modelname in celldb database
    eval_model : boolean
        If true, the entire xfspec will be re-evaluated after loading.
    only : int
        Index of single xfspec step to evaluate if eval_model is False.
        For example, only=0 will typically just load the recording.

    Returns
    -------
    xfspec, ctx : nested list, dictionary

    '''

    kws = escaped_split(modelname, '_')
    old = False
    if (len(kws) > 3) or ((len(kws) == 3) and kws[1].startswith('stategain')
                          and not kws[1].startswith('stategain.')):
        # Check if modelname uses old format.
        log.info("Using old modelname format ... ")
        old = True

    d = nd.get_results_file(batch, [modelname], [cellid])
    filepath = d['modelpath'][0]
    # TODO: add BAPHY_API support. Not implemented on nems_baphy yet?
    #if get_setting('USE_NEMS_BAPHY_API'):
    #    prefix = '/auto/data/nems_db' # get_setting('NEMS_RESULTS_DIR')
    #    uri = filepath.replace(prefix,
    #                           'http://' + get_setting('NEMS_BAPHY_API_HOST') + ":" + str(get_setting('NEMS_BAPHY_API_PORT')))
    #else:
    #    uri = filepath.replace('/auto/data/nems_db/results', get_setting('NEMS_RESULTS_DIR'))

    # hack: hard-coded assumption that server will use this data root
    uri = filepath.replace('/auto/data/nems_db/results',
                           get_setting('NEMS_RESULTS_DIR'))
    if old:
        raise NotImplementedError("need to use oxf library.")
        # unreachable until oxf support is restored:
        # xfspec, ctx = oxf.load_analysis(uri, eval_model=eval_model)
    else:
        xfspec, ctx = xforms.load_analysis(uri,
                                           eval_model=eval_model,
                                           only=only)
    return xfspec, ctx
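
Typical call (the cellid and modelname below are illustrative; note that old-format modelnames raise NotImplementedError):

xfspec, ctx = load_model_xform('TAR010c-18-2', batch=289,
                               modelname='ozgf.fs100.ch18-ld-sev_wc.18x2.g-fir.2x15-lvl.1-dexp.1_basic')
modelspec = ctx['modelspec']  # fitted modelspec; ctx typically also carries 'rec', 'est', 'val'
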
Example #4
    def load_saved_settings(self):
        """Reads in saved settings.

        The section header for this GUI is "db_browser", with each tab saving it's config
        in variables prefixed with their tab name. Different configs are named by appending
        a colon then the save name (ex: "db_browser:save1").
        """
        self.config_file = Path(get_setting('SAVED_SETTINGS_PATH')) / 'gui.ini'
        self.config_group = 'db_browser'
        self.config = ConfigParser(delimiters=('='))
        self.config.read(self.config_file)
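
For context, a sketch of reading back a saved config group with the same ConfigParser conventions (the group and option names are illustrative):

from configparser import ConfigParser

config = ConfigParser(delimiters=('='))
config.read('gui.ini')
if config.has_section('db_browser:save1'):             # hypothetical saved group
    settings = dict(config.items('db_browser:save1'))  # e.g. {'cellid': ..., 'batch': ...}
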
Example #5
File: utils.py Project: LBHB/NEMS
def adjust_uri_prefix(uri, use_nems_defaults=True):
    """
    if get_setting('USE_NEMS_BAPHY_API') is True: translate file system URI to http --or--
    if get_setting('USE_NEMS_BAPHY_API') is False: translate http URI to file system
    
    Warning! May be too hacky, and unclear where this should be evaluated! Currently run in uri.load_uri, which may be too low-level,
    as there may be situations where you want to hard-code a URI that doesn't match the expectations of the current configuration.
    """
    use_API = get_setting('USE_NEMS_BAPHY_API')
    prefix_is_http = uri.startswith("http")
    rec_prefix = get_setting('NEMS_RECORDINGS_DIR')
    res_prefix = get_setting('NEMS_RESULTS_DIR')
    rec_match = uri.find("/recordings")
    res_match = uri.find("/results")

    if use_API and (not prefix_is_http):
        api_prefix = 'http://' + get_setting(
            'NEMS_BAPHY_API_HOST') + ":" + str(
                get_setting('NEMS_BAPHY_API_PORT'))
        if uri.startswith(rec_prefix):
            new_uri = uri.replace(rec_prefix, api_prefix + "/recordings")
            log.info(f"Adjusting URI from {uri} to {new_uri}")
        elif uri.startswith(res_prefix):
            new_uri = uri.replace(res_prefix, api_prefix + "/results")
            log.info(f"Adjusting URI from {uri} to {new_uri}")
        else:
            new_uri = uri
    elif (not use_API) and prefix_is_http:
        # str.find returns -1 when the substring is absent (and -1 is truthy),
        # so compare explicitly against -1
        if rec_match > -1:
            new_uri = rec_prefix + uri[(rec_match + len("/recordings")):]
            log.info(f"Adjusting URI from {uri} to {new_uri}")
        elif res_match > -1:
            new_uri = res_prefix + uri[(res_match + len("/results")):]
            log.info(f"Adjusting URI from {uri} to {new_uri}")
        else:
            new_uri = uri
    else:
        new_uri = uri

    return new_uri
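
A sketch of the two translation directions (host, port and paths are illustrative):

# USE_NEMS_BAPHY_API == True: a local path is rewritten to an http URI, e.g.
#   <NEMS_RECORDINGS_DIR>/289/rec.tgz
#   -> http://<NEMS_BAPHY_API_HOST>:<NEMS_BAPHY_API_PORT>/recordings/289/rec.tgz
# USE_NEMS_BAPHY_API == False: the reverse mapping is applied
new_uri = adjust_uri_prefix('http://myhost:3003/recordings/289/rec.tgz')
# -> <NEMS_RECORDINGS_DIR>/289/rec.tgz
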
Example #6
def nonparametric_strf(modelspec,
                       idx,
                       ax=None,
                       clim=None,
                       title=None,
                       **kwargs):
    coefficients = modelspec[idx]['phi']['coefficients']
    plot_heatmap(coefficients,
                 xlabel='Time Bin',
                 ylabel='Channel In',
                 ax=ax,
                 clim=clim,
                 cmap=get_setting('FILTER_CMAP'),
                 title=title)
Example #7
def _get_modelspecs(cellids, batch, modelname, multi='mean'):
    filepaths = load_batch_modelpaths(batch,
                                      modelname,
                                      cellids,
                                      eval_model=False)
    speclists = []
    for path in filepaths:
        mspaths = []
        path = path.replace('http://hyrax.ohsu.edu:3003/',
                            '/auto/data/nems_db/')
        if get_setting('NEMS_RESULTS_DIR').startswith("/Volumes"):
            path = path.replace('/auto/', '/Volumes/')
        for file in os.listdir(path):
            if file.startswith("modelspec"):
                mspaths.append(os.path.join(path, file))
        speclists.append([load_resource(p) for p in mspaths])

    modelspecs = []
    for m in speclists:
        if len(m) > 1:
            if multi == 'first':
                this_mspec = m[0]
            elif multi == 'all':
                this_mspec = m
            elif multi == 'mean':
                stats = ms.summary_stats(m)
                temp_spec = copy.deepcopy(m[0])
                # use a distinct loop variable to avoid shadowing the outer `m`
                phis = [mod['phi'] for mod in temp_spec]
                for p in phis:
                    for k in p:
                        for s in stats:
                            if s.endswith('--' + k):
                                p[k] = stats[s]['mean']
                for mod, p in zip(temp_spec, phis):
                    mod['phi'] = p
                this_mspec = temp_spec
            else:
                log.warning(
                    "Couldn't interpret <multi> parameter. Got: %s,\n"
                    "Expected one of: 'mean, first, random, all'.\n"
                    "Using first modelspec instead.", multi)
                this_mspec = m[0]
        else:
            this_mspec = m[0]

        modelspecs.append(ms.ModelSpec([this_mspec]))

    return modelspecs
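
Usage sketch (cellids and modelname illustrative); with multi='mean', each returned ModelSpec carries phi averaged across the fits found on disk:

cellids = ['TAR010c-18-2', 'TAR010c-27-2']  # hypothetical cellids
modelspecs = _get_modelspecs(cellids, batch=289,
                             modelname='ozgf.fs100.ch18-ld-sev_wc.18x2.g-fir.2x15-lvl.1_basic',
                             multi='mean')
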
Example #8
def fir_heatmap(modelspec,
                ax=None,
                clim=None,
                title=None,
                chans=None,
                fir_idx=0,
                **options):
    coefficients = _get_fir_coefficients(modelspec, idx=fir_idx)
    plot_heatmap(coefficients,
                 xlabel='Time Bin',
                 ylabel='Channel In',
                 ax=ax,
                 clim=clim,
                 cmap=get_setting('FILTER_CMAP'),
                 title=title)
    if chans is not None:
        for i, c in enumerate(chans):
            plt.text(-0.4, i, c, verticalalignment='center')
Example #9
def strf_local_lin(rec, modelspec, cursor_time=20, channels=0, **options):
    rec = rec.copy()

    tbin = int(cursor_time * rec['resp'].fs)

    chan_count = rec['stim'].shape[0]
    firmod = find_module('fir', modelspec)
    tbin_count = modelspec.phi[firmod]['coefficients'].shape[1] + 2

    use_dstrf = True
    if use_dstrf:
        index = int(cursor_time * rec['resp'].fs)
        strf = modelspec.get_dstrf(rec,
                                   index=index,
                                   width=20,
                                   out_channel=channels)
    else:
        resp_chan = channels
        d = rec['stim']._data.copy()
        strf = np.zeros((chan_count, tbin_count))
        _p1 = rec['pred']._data[resp_chan, tbin]
        eps = 0.01  # hard-coded step size; overrides the data-dependent np.nanstd(d) / 100
        #print('eps: {}'.format(eps))
        for c in range(chan_count):
            #eps = np.std(d[c, :])/100
            for t in range(tbin_count):

                _d = d.copy()
                _d[c, tbin - t] *= 1 + eps
                rec['stim'] = rec['stim']._modified_copy(data=_d)
                rec = modelspec.evaluate(rec)
                _p2 = rec['pred']._data[resp_chan, tbin]
                strf[c, t] = (_p2 - _p1) / eps
    print('strf min: {} max: {}'.format(np.min(strf), np.max(strf)))
    options['clim'] = np.array([-np.max(np.abs(strf)), np.max(np.abs(strf))])
    plot_heatmap(strf, cmap=get_setting('FILTER_CMAP'), **options)
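
Usage sketch, assuming rec holds 'stim', 'resp' and 'pred' signals and modelspec contains a fir module:

strf_local_lin(rec, modelspec, cursor_time=20, channels=0, title='local STRF')
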
Example #10
import logging

import numpy as np

from nems.registry import KeywordRegistry
from nems.plugins import default_keywords
from nems.utils import find_module
from nems.analysis.api import fit_basic
from nems.fitters.api import scipy_minimize
import nems.priors as priors
import nems.modelspec as ms
import nems.metrics.api as metrics
from nems import get_setting

log = logging.getLogger(__name__)
default_kws = KeywordRegistry()
default_kws.register_module(default_keywords)
default_kws.register_plugins(get_setting('KEYWORD_PLUGINS'))


def from_keywords(keyword_string, registry=None, rec=None, meta={}):
    '''
    Returns a modelspec created by splitting keyword_string on dashes ('-')
    and replacing each keyword with what is found in the nems.keywords.defaults
    registry. You may provide your own keyword registry using the
    registry={...} argument.
    '''
    if registry is None:
        registry = default_kws
    keywords = keyword_string.split('-')

    # Lookup the modelspec fragments in the registry
    modelspec = []
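
Usage sketch, assuming the default keyword registry resolves these standard NEMS keywords:

modelspec = from_keywords('wc.18x2.g-fir.2x15-lvl.1-dexp.1',
                          meta={'cellid': 'TAR010c-18-2'})  # hypothetical meta
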
Example #11
def do_decoding_analysis(lv_model=True, **ctx):
    """
    Meant to replace CRH's personal code for doing decoding analysis.
    Basically the only this that needed to change is to allow the decoding.load_site 
    function to take a ctx (an already loaded recording). So then load_site just 
    reshapes the data accordingly and returns same stuff as always
    """
    # pull site / batch out of context
    site = ctx["meta"]["cellid"]
    batch = ctx["meta"]["batch"]
    if type(batch) is str:
        batch = int(batch)
    modelname = "decoding"

    # ================================ SET RNG STATE ===================================
    np.random.seed(123)

    # ============================== SAVE PARAMETERS ===================================
    # define save directory
    path = nems.get_setting("NEMS_RESULTS_DIR")

    # ========================== Define analysis options ========================
    # these are hard-coded right now, but should be editable in the future
    njacks = 10
    zscore = True
    pup_ops = None
    shuffle_trials = False
    regress_pupil = False
    use_xforms = False

    # old simulation stuff (brute force approach)
    sim1 = False
    sim2 = False
    sim12 = False
    sim_tdr_space = False  # perform the simulation in the TDR space, piggy-backing on the jackknifes
                           # (a new simulation per jackknife). This is kludgy; more iterations may be
                           # needed, which would mean rethinking the modelname / est-val creation, etc.

    pca_ops = None
    do_pls = False
    do_PCA = False
    pca_lv = False
    nc_lv = False
    nc_lv_z = False
    fix_tdr2 = False
    ddr2_method = 'fa'  # use FactorAnalysis to ID the noise axis, by default. Can also do 'pca'
    gain_only = False
    dc_only = False
    est_equal_val = False  # for low-rep sites where cross-validation can't be performed
    n_noise_axes = 0    # how many dimensions beyond TDR = 2; e.g. if this is 2, a 4D TDR space is computed
    loocv = False    # leave-one-out cross validation
    exclude_low_fr = False
    threshold = None
    movement_mask = False
    use_old_cpn = False
    all_pup = False
    fa_model = False # simulate data from FA
    fa_sim = False

    if do_pls:
        log.info("Also running PLS dimensionality reduction for N components. Will be slower")
        raise DeprecationWarning("Updates have been made since this was last used. Make sure behavior is as expected")
    elif do_PCA:
        log.info("Also running trial averaged PCA dimensionality reduction for N components.")
        raise DeprecationWarning("Updates have been made since this was last used. Make sure behavior is as expected")
    else:
        log.info("Only performing TDR dimensionality reduction. No PLS or PCA")

    # ================ load LV information for this site =======================
    # can't compare these axes if we've reduced dimensionality
    if pca_ops is None:
        if pca_lv:
            fn = '/auto/users/hellerc/results/nat_pupil_ms/LV/pca_regression_lvs.pickle'
            # load results from pickle file
            with open(fn, 'rb') as handle:
                lv_results = pickle.load(handle)
            beta1 = lv_results[site+str(batch)]['beta1']
            beta2 = lv_results[site+str(batch)]['beta2']
        elif nc_lv:
            log.info("loading LVs from NC method using raw responses")
            fn = '/auto/users/hellerc/results/nat_pupil_ms/LV/nc_based_lvs.pickle'
            # load results from pickle file
            with open(fn, 'rb') as handle:
                lv_results = pickle.load(handle)
            beta1 = lv_results[site+str(batch)]['beta1']
            beta2 = lv_results[site+str(batch)]['beta2']
        elif nc_lv_z:
            log.info("loading LVs from NC method using z-scored responses")
            fn = '/auto/users/hellerc/results/nat_pupil_ms/LV/nc_zscore_lvs.pickle'
            # load results from pickle file
            with open(fn, 'rb') as handle:
                lv_results = pickle.load(handle)
            beta1 = lv_results[site+str(batch)]['beta1']
            beta2 = lv_results[site+str(batch)]['beta2']
        else:
            beta1=None
            beta2=None
    else:
        beta1 = None
        beta2 = None

    # ================================= load recording ==================================
    X, sp_bins, X_pup, pup_mask, epochs = decoding.load_site(site=site, batch=batch, 
                                        ctx=ctx,
                                        pca_ops=pca_ops,
                                        pup_ops=pup_ops,
                                        regress_pupil=regress_pupil,
                                        gain_only=gain_only,
                                        dc_only=dc_only,
                                        use_xforms=use_xforms,
                                        return_epoch_list=True,
                                        exclude_low_fr=exclude_low_fr,
                                        threshold=threshold,
                                        mask_movement=movement_mask,
                                        use_old_cpn=use_old_cpn)
    ncells = X.shape[0]
    nreps_raw = X.shape[1]
    nstim = X.shape[2]
    nbins = X.shape[3]
    sp_bins = sp_bins.reshape(1, sp_bins.shape[1], nstim * nbins)
    nstim = nstim * nbins

    # =========================== generate a list of stim pairs ==========================
    all_combos = list(combinations(range(nstim), 2))
    spont_bins = np.argwhere(sp_bins[0, 0, :])
    spont_combos = [c for c in all_combos if (c[0] in spont_bins) & (c[1] in spont_bins)]
    ev_ev_combos = [c for c in all_combos if (c[0] not in spont_bins) & (c[1] not in spont_bins)]
    spont_ev_combos = [c for c in all_combos if (c not in ev_ev_combos) & (c not in spont_combos)]

    # get list of epoch combos as a tuple (in the same fashion as above)
    epochs_bins = np.concatenate([[e+'_'+str(k) for k in range(nbins)] for e in epochs])
    epochs_str_combos = list(combinations(epochs_bins, 2))

    # =================================== simulate =======================================
    # update X to simulated data if specified. Else X = X_raw.
    # point of this is so that decoding axes / TDR space doesn't change for simulation (or for xforms predicted data)
    # should make results easier to interpret. CRH 06.04.2020
    X_raw = X.copy()
    pup_mask_raw = pup_mask.copy()
    meta = ctx['rec'].meta
    if (sim1 | sim2 | sim12) & (not sim_tdr_space):
        X, pup_mask = decoding.simulate_response(X, pup_mask, sim_first_order=sim1,
                                                            sim_second_order=sim2,
                                                            sim_all=sim12,
                                                            ntrials=5000)
    elif lv_model:
        # get lv model predictions from context (ctx["val"]["pred"])
        # then evaluate decoding with predictions
        X, _, _, pup_mask, _ = decoding.load_site(site=site, batch=batch, 
                                        ctx=ctx,
                                        use_pred=True,
                                        pca_ops=pca_ops,
                                        pup_ops=pup_ops,
                                        regress_pupil=regress_pupil,
                                        gain_only=gain_only,
                                        dc_only=dc_only,
                                        use_xforms=use_xforms,
                                        return_epoch_list=True,
                                        exclude_low_fr=exclude_low_fr,
                                        threshold=threshold,
                                        mask_movement=movement_mask,
                                        use_old_cpn=use_old_cpn)

    elif sim_tdr_space:
        log.info("Performing simulations within TDR space. Unique simulation per each jackknife")

    elif fa_model:
        log.info(f"Simulate data based on factor analysis model with sim={fa_sim} for this dataset")
        big_psth = np.stack([X.reshape(ncells, nreps_raw, nstim)[:, pup_mask.reshape(1, nreps_raw, nstim)[0, :, s], s].mean(axis=1) for s in range(nstim)]).T.reshape(ncells, X.shape[2], nbins)[:, np.newaxis, :, :]
        small_psth = np.stack([X.reshape(ncells, nreps_raw, nstim)[:, pup_mask.reshape(1, nreps_raw, nstim)[0, :, s]==False, s].mean(axis=1) for s in range(nstim)]).T.reshape(ncells, X.shape[2], nbins)[:, np.newaxis, :, :]
        X, pup_mask = decoding.load_FA_model(site, batch, big_psth, small_psth, sim=fa_sim, nreps=5000)
    
    # another option for surrogates based on pop metrics
    # how to do this? Generate data elsewhere, then load? 
    # that way can compute pop metrics / pairwise metrics on each
    # set and then just load them here for decoding

    else:
        pass

    nreps = X.shape[1]

    # =============================== reshape data ===================================
    # reshape mask to match data
    pup_mask = pup_mask.reshape(1, nreps, nstim)
    pup_mask_raw = pup_mask_raw.reshape(1, nreps_raw, nstim)
    # reshape X (and X_raw)
    X = X.reshape(ncells, nreps, nstim)
    X_raw = X_raw.reshape(ncells, nreps_raw, nstim)
    # reshape X_pup
    X_pup = X_pup.reshape(1, nreps_raw, nstim)

    # ===================== decide if / how we should shuffle data ===================
    if shuffle_trials:
        # shuffle trials per neuron within state, so don't break any overall gain changes between large / small pupil
        if np.any(X!=X_raw):
            raise ValueError("Right now, we're just set up to do this for raw data. Not sure how you'd do it for sim data")
        else:
            log.info("Shuffing trials within pupil condition to break state-dependent changes in correlation")
            X = decoding.shuffle_trials(X, pup_mask)
            X_raw = X.copy()

    # ============================== get pupil variance ==================================
    # figure out pupil variance per stimulus (this always happens on raw data... X_pup and pup_mask_raw)
    pupil_range = nat_preproc.get_pupil_range(X_pup, pup_mask_raw)

    # =========================== generate list of est/val sets ==========================
    # also generate list of est / val sets for the raw data. Because of the random number seed, it's
    # critical that this happens first and doesn't happen twice (if simulation is False)
    log.info("Generate list of {0} est / val sets".format(njacks))

    # generate raw est/val sets
    np.random.seed(123)
    est_raw, val_raw, p_est_raw, p_val_raw = nat_preproc.get_est_val_sets(X_raw, pup_mask=pup_mask_raw, njacks=njacks, est_equal_val=est_equal_val)
    nreps_train_raw = est_raw[0].shape[1]
    nreps_test_raw = val_raw[0].shape[1]

    # check if data was simulated. If so, then generate the est / val sets for this data
    xraw_equals_x = False
    if (X.shape == X_raw.shape):
        if np.all(X_raw == X):
            xraw_equals_x = True
            est = est_raw.copy()
            val = val_raw.copy()
            p_est = p_est_raw.copy()
            p_val = p_val_raw.copy()
        else:
            est, val, p_est, p_val = nat_preproc.get_est_val_sets(X, pup_mask=pup_mask, njacks=njacks)
    else:
        est, val, p_est, p_val = nat_preproc.get_est_val_sets(X, pup_mask=pup_mask, njacks=njacks)

    nreps_train = est[0].shape[1]
    nreps_test = val[0].shape[1]

    # determine number of dim reduction components (bounded by ndim in dataset) 
    # force to less than 10, for speed purposes.
    components = np.min([ncells, nreps_train, 10])

    # ============================ preprocess est / val sets =============================
    if zscore:
        log.info("z-score est / val sets")
        est, val, fullX = nat_preproc.scale_est_val(est, val, full=X)
        est_raw, val_raw, fullX_raw = nat_preproc.scale_est_val(est_raw, val_raw, full=X_raw)
    else:
        # just center data
        log.info("center est / val sets")
        est, val, fullX = nat_preproc.scale_est_val(est, val, full=X, sd=False)
        est_raw, val_raw, fullX_raw = nat_preproc.scale_est_val(est_raw, val_raw, full=X_raw, sd=False)

    # =========================== if fix tdr 2 =======================================
    # calculate first noise PC for each val set, use this to define TDR2, rather
    # than stimulus specific first noise PC (this method seems too noisy). Always
    # use raw data for this.
    if fix_tdr2:
        if ddr2_method=='nclv':
            log.info("Loading cached delta noise correlation axis as ddr2 noise axis")
            lvdict = pickle.load(open('/auto/users/hellerc/results/nat_pupil_ms/LV/nc_zscore_lvs.pickle', 'rb'))
            nc_ax = lvdict[site+str(batch)]['beta2'].T
            tdr2_axes = [nc_ax] * len(val)
        else:
            log.info("Finding first noise dimension for each est set using raw data")
            tdr2_axes = nat_preproc.get_first_pc_per_est(est_raw, method=ddr2_method)
    else:
        tdr2_axes = [None] * len(val)
    # set up data frames to save results (wait to preallocate space on first
    # iteration, because then we'll have the columns)
    temp_pca_results = pd.DataFrame()
    temp_pls_results = pd.DataFrame()
    temp_tdr_results = pd.DataFrame()
    pls_index = range(len(all_combos) * njacks * (components-2))
    pca_index = range(len(all_combos) * njacks)
    tdr_index = range(len(all_combos) * njacks)
    pca_idx = 0
    pls_idx = 0
    tdr_idx = 0
    # ============================== Loop over stim pairs ================================
    for stim_pair_idx, (ecombo, combo) in enumerate(zip(epochs_str_combos, all_combos)):
        # print every 500th pair. Don't want to overwhelm log
        if (stim_pair_idx % 500) == 0:
            log.info("Analyzing stimulus pair {0} / {1}".format(stim_pair_idx, len(all_combos)))
        if combo in spont_combos:
            category = 'spont_spont'
        elif combo in spont_ev_combos:
            category = 'spont_evoked'
        elif combo in ev_ev_combos:
            category = 'evoked_evoked'

        for ev_set in range(njacks):
            X_train = est[ev_set][:, :, [combo[0], combo[1]]] 
            X_test = val[ev_set][:, :, [combo[0], combo[1]]]
            _Xfull = fullX[ev_set][:, :, [combo[0], combo[1]]]
            if all_pup:
                # use all data for big / small pupil
                ptrain_mask = pup_mask.copy()[:, :, [combo[0], combo[1]]]
                ptest_mask = pup_mask.copy()[:, :, [combo[0], combo[1]]]

            else:
                ptrain_mask = p_est[ev_set][:, :, [combo[0], combo[1]]]
                ptest_mask = p_val[ev_set][:, :, [combo[0], combo[1]]]
            tdr2_axis = tdr2_axes[ev_set]

            # if ptrain_mask is all False, just set it to None so
            # that we don't attempt to run the bp/sp analysis. This is a kludgy
            # way to deal with cases where we selected a pupil window for which there
            # was no valid data (or not enough valid data) in this experiment
            if (np.sum(ptrain_mask[:, :, 0]) < 5) | (np.sum(ptrain_mask[:, :, 1]) < 5):
                log.info(f"For combo: {combo}, ev_set: {ev_set} not enough data. stim1: {ptrain_mask[:, :, 0].sum()}, stim2: {ptrain_mask[:, :, 1].sum()}")
                ptrain_mask = None

            xtrain = nat_preproc.flatten_X(X_train[:, :, :, np.newaxis])
            xtest = nat_preproc.flatten_X(X_test[:, :, :, np.newaxis])
            xfull = nat_preproc.flatten_X(_Xfull[:, :, :, np.newaxis])

            # define raw data
            if xraw_equals_x:
                raw_data = None
            else:
                # define data to be used for tdr decomposition
                X_train_raw = est_raw[ev_set][:, :, [combo[0], combo[1]]] 
                xtrain_raw = nat_preproc.flatten_X(X_train_raw[:, :, :, np.newaxis])
                raw_data = (xtrain_raw, nreps_train_raw)

            # ============================== TDR ANALYSIS ==============================
            # custom dim reduction onto plane defined by dU and first PC of noise covariance (+ additional noise axes)
            if sim_tdr_space:
                # simulate data *after* after projecting into TDR space.
                try:
                    if not loocv:
                        _tdr_results = decoding.do_tdr_dprime_analysis(xtrain,
                                                        xtest,
                                                        nreps_train,
                                                        nreps_test,
                                                        tdr_data=raw_data,
                                                        n_additional_axes=n_noise_axes,
                                                        sim1=sim1,
                                                        sim2=sim2,
                                                        sim12=sim12,
                                                        beta1=beta1,
                                                        beta2=beta2,
                                                        tdr2_axis=tdr2_axis,
                                                        fullpup=all_pup,
                                                        fullX=xfull,
                                                        ptrain_mask=ptrain_mask,
                                                        ptest_mask=ptest_mask)
                    else:
                        raise NotImplementedError("WIP -- loocv for simulations")
                except Exception:
                    log.info("Can't perform analysis for stimulus combo: {0}".format(combo))
                    _tdr_results = {}
            else:
                if not loocv:
                    _tdr_results = decoding.do_tdr_dprime_analysis(xtrain,
                                                                xtest,
                                                                nreps_train,
                                                                nreps_test,
                                                                tdr_data=raw_data,
                                                                n_additional_axes=n_noise_axes,
                                                                beta1=beta1,
                                                                beta2=beta2,
                                                                tdr2_axis=tdr2_axis,
                                                                fullpup=all_pup,
                                                                fullX=xfull,                                                            
                                                                ptrain_mask=ptrain_mask,
                                                                ptest_mask=ptest_mask)
                else:
                    # use leave-one-out cross validation
                    _tdr_results = decoding.do_tdr_dprime_analysis_loocv(xtrain,
                                                                        nreps_train,
                                                                        tdr_data=raw_data,
                                                                        n_additional_axes=n_noise_axes,
                                                                        beta1=beta1,
                                                                        beta2=beta2,
                                                                        tdr2_axis=tdr2_axis,
                                                                        pmask=ptrain_mask)
                
            _tdr_results.update({
                'n_components': 2+n_noise_axes,
                'jack_idx': ev_set,
                'combo': combo,
                'e1': ecombo[0],
                'e2': ecombo[1],
                'category': category,
                'site': site
            })
            # preallocate space for subsequent iterations
            if tdr_idx == 0:
                if 'bp_dp' not in _tdr_results.keys():
                    bp_cols = ['bp_dp', 'bp_evals', 'bp_dU_mag', 'bp_dU_dot_evec',
                                        'bp_cos_dU_wopt', 'bp_dU_dot_evec_sq', 'bp_evec_snr', 
                                        'bp_cos_dU_evec']
                    for bp_col in bp_cols:
                        if bp_col in ['bp_dU_dot_evec', 'bp_dU_dot_evec_sq', 'bp_evec_snr', 'bp_cos_dU_evec']:
                            _tdr_results[bp_col] = np.nan * np.ones((1, 2+n_noise_axes))
                        elif bp_col == 'bp_evals':
                            _tdr_results[bp_col] = np.nan * np.ones(2+n_noise_axes)
                        elif bp_col == 'bp_cos_dU_wopt':
                            _tdr_results[bp_col] = np.nan * np.ones((1, 1))
                        else:
                            _tdr_results[bp_col] = np.nan # add place holder in case the first didn't have data for this
                if 'sp_dp' not in _tdr_results.keys():
                    sp_cols = ['sp_dp', 'sp_evals', 'sp_dU_mag', 'sp_dU_dot_evec',
                                        'sp_cos_dU_wopt', 'sp_dU_dot_evec_sq', 'sp_evec_snr', 
                                        'sp_cos_dU_evec']
                    for sp_col in sp_cols:
                        if sp_col in ['sp_dU_dot_evec', 'sp_dU_dot_evec_sq', 'sp_evec_snr', 'sp_cos_dU_evec']:
                            _tdr_results[sp_col] = np.nan * np.ones((1, 2+n_noise_axes))
                        elif sp_col == 'sp_evals':
                            _tdr_results[sp_col] = np.nan * np.ones(2+n_noise_axes)
                        elif sp_col == 'sp_cos_dU_wopt':
                            _tdr_results[sp_col] = np.nan * np.ones((1, 1))
                        else:
                            _tdr_results[sp_col] = np.nan # add place holder in case the first didn't have data for this

                # DataFrame.append was removed in pandas 2.0; build a one-row frame and concat
                t = {k: [v] for k, v in _tdr_results.items()}
                temp_tdr_results = pd.concat([temp_tdr_results, pd.DataFrame(t)], ignore_index=True)
                tdr_results = pd.DataFrame(index=tdr_index, columns=temp_tdr_results.columns)
                tdr_results.loc[tdr_idx] = temp_tdr_results.iloc[0].values
                temp_tdr_results = pd.DataFrame()

            else:
                t = {k: [v] for k, v in _tdr_results.items()}
                temp_tdr_results = pd.concat([temp_tdr_results, pd.DataFrame(t)], ignore_index=True)
                tdr_results.loc[tdr_idx, temp_tdr_results.keys()] = temp_tdr_results.iloc[0].values
                temp_tdr_results = pd.DataFrame()
            tdr_idx += 1

            # ============================== PCA ANALYSIS ===============================
            if do_PCA:
                _pca_results = decoding.do_pca_dprime_analysis(xtrain, 
                                                            xtest, 
                                                            nreps_train,
                                                            nreps_test,
                                                            ptrain_mask=ptrain_mask,
                                                            ptest_mask=ptest_mask)
                _pca_results.update({
                    'n_components': 2,
                    'jack_idx': ev_set,
                    'combo': combo,
                    'e1': ecombo[0],
                    'e2': ecombo[1],
                    'category': category,
                    'site': site
                })
                # preallocate space for subsequent iterations
                if pca_idx == 0:
                    temp_pca_results = pd.concat(
                        [temp_pca_results, pd.DataFrame({k: [v] for k, v in _pca_results.items()})],
                        ignore_index=True)
                    pca_results = pd.DataFrame(index=pca_index, columns=temp_pca_results.columns)
                    pca_results.loc[pca_idx] = temp_pca_results.iloc[0].values
                    temp_pca_results = pd.DataFrame()

                else:
                    temp_pca_results = pd.concat(
                        [temp_pca_results, pd.DataFrame({k: [v] for k, v in _pca_results.items()})],
                        ignore_index=True)
                    pca_results.loc[pca_idx] = temp_pca_results.iloc[0].values
                    temp_pca_results = pd.DataFrame()
                pca_idx += 1

            if do_pls:
                # ============================== PLS ANALYSIS ===============================
                for n_components in range(2, components):

                    _pls_results = decoding.do_pls_dprime_analysis(xtrain, 
                                                                xtest, 
                                                                nreps_train,
                                                                nreps_test,
                                                                ptrain_mask=ptrain_mask,
                                                                ptest_mask=ptest_mask,
                                                                n_components=n_components)
                    _pls_results.update({
                        'n_components': n_components,
                        'jack_idx': ev_set,
                        'combo': combo,
                        'e1': ecombo[0],
                        'e2': ecombo[1],
                        'category': category,
                        'site': site
                    })
                
                    # preallocate space for subsequent iterations
                    if pls_idx == 0:
                        temp_pls_results = pd.concat(
                            [temp_pls_results, pd.DataFrame({k: [v] for k, v in _pls_results.items()})],
                            ignore_index=True)
                        pls_results = pd.DataFrame(index=pls_index, columns=temp_pls_results.columns)
                        pls_results.loc[pls_idx] = temp_pls_results.iloc[0].values
                        temp_pls_results = pd.DataFrame()

                    else:
                        temp_pls_results = pd.concat(
                            [temp_pls_results, pd.DataFrame({k: [v] for k, v in _pls_results.items()})],
                            ignore_index=True)
                        pls_results.loc[pls_idx] = temp_pls_results.iloc[0].values
                        temp_pls_results = pd.DataFrame()

                    pls_idx += 1

    
    # convert columns to str
    tdr_results.loc[:, 'combo'] = ['{0}_{1}'.format(c[0], c[1]) for c in tdr_results.combo.values]
    if do_PCA:
        pca_results.loc[:, 'combo'] = ['{0}_{1}'.format(c[0], c[1]) for c in pca_results.combo.values]
    if do_pls:
        pls_results.loc[:, 'combo'] = ['{0}_{1}'.format(c[0], c[1]) for c in pls_results.combo.values]

    # get mean pupil range for each combo
    log.info('Computing mean pupil range for each pair of stimuli')
    combo_to_tup = lambda x: (int(x.split('_')[0]), int(x.split('_')[1])) 
    combos = pd.Series(tdr_results['combo'].values).apply(combo_to_tup)
    #import pdb; pdb.set_trace()
    pr = pupil_range
    get_mean = lambda x: (pr[pr.stim==x[0]]['range'] + pr[pr.stim==x[1]]['range']) / 2
    pr_range = combos.apply(get_mean)
    tdr_results['mean_pupil_range'] = pr_range.values

    # convert to correct dtypes
    tdr_results = decoding.cast_dtypes(tdr_results)
    if do_PCA:
        pca_results = decoding.cast_dtypes(pca_results)
    if do_pls:
        pls_results = decoding.cast_dtypes(pls_results)

    # collapse over results to save disk space by packing into "DecodingResults object"
    log.info("Compressing results into DecodingResults object... ")
    tdr_results = decoding.DecodingResults(tdr_results, pupil_range=pupil_range)
    if do_PCA:
        pca_results = decoding.DecodingResults(pca_results, pupil_range=pupil_range)
    if do_pls:
        pls_results = decoding.DecodingResults(pls_results, pupil_range=pupil_range)

    if meta is not None:
        if 'mask_bins' in meta.keys():
            tdr_results.meta['mask_bins'] = meta['mask_bins']

    # save results
    modelname = modelname.replace('*', '_')

    results_path = ctx['modelspec'][0]['meta']['modelpath']
    log.info(f"Saving results to {results_path}")

    tdr_results.save_pickle(os.path.join(results_path, modelname+'_TDR.pickle'))

    if do_PCA:
        pca_results.save_pickle(os.path.join(results_path, modelname+'_PCA.pickle'))

    if do_pls:
        pls_results.save_pickle(os.path.join(results_path, modelname+'_PLS.pickle'))

    return 0
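
A usage sketch: the function expects an already-evaluated xforms context whose 'meta' entry carries cellid and batch (as set by nems.xforms.init_context); the modelname is illustrative:

xfspec, ctx = load_model_xform('TAR010c-18-2', batch=289,
                               modelname=lv_modelname)  # hypothetical fitted LV model
do_decoding_analysis(lv_model=True, **ctx)
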
Example #12
# TODO: tests for utility functions in nems/plots/*, like those in nems/plots/assemble.py

import os
import numpy as np
import matplotlib.pyplot as plt

import nems.recording as recording
import nems.plots.api as nplt
import nems.epoch as ep
import nems

signals_dir = nems.get_setting('NEMS_RECORDINGS_DIR')

#uri = signals_dir + "/por074b-c2.tgz"
#uri = signals_dir + "/BRT026c-02-1.tgz"
#cellid = "BRT026c-02-1"
recording_file = "TAR010c.NAT.fs100.ch18.tgz"
uri = os.path.join(signals_dir, recording_file)

cellid = "TAR010c-18-2"


def test_plots():
    recording.get_demo_recordings(name=recording_file)
    rec = recording.load_recording(uri)

    resp = rec['resp'].rasterize()
    stim = rec['stim'].rasterize()

    epoch_regex = "^STIM_"
Example #13
        cellid, batch, modelname))
    #savefile = nw.fit_model_xforms_baphy(cellid, batch, modelname, saveInDB=True)
    savefile = xhelp.fit_model_xform(cellid, batch, modelname, saveInDB=True)

    log.info("Done with fit.")

    # Mark completed in the queue. Note that this should happen last thing!
    # Otherwise the job might still crash after being marked as complete.
    if db_exists and bool(queueid):
        nd.update_job_complete(queueid)

        if 'SLURM_JOB_ID' in os.environ:
            # need to copy the job log over to the queue log dir
            log_file_dir = Path.home() / 'job_history'
            log_file = list(
                log_file_dir.glob(
                    f'*jobid{os.environ["SLURM_JOB_ID"]}_log.out'))
            if len(log_file) == 1:
                log_file = log_file[0]
                log.info(f'Found log file: "{str(log_file)}"')
                log.info('Copying log file to queue log repo.')

                with open(log_file, 'r') as f:
                    log_data = f.read()

                dst_prefix = r'http://' + get_setting(
                    'NEMS_BAPHY_API_HOST') + ":" + str(
                        get_setting('NEMS_BAPHY_API_PORT'))
                dst_loc = dst_prefix + '/queuelog/' + str(queueid)
                save_resource(str(dst_loc), data=log_data)
Example #14
def generate_xforms_spec(recording_uri,
                         modelname,
                         meta={},
                         xforms_kwargs={},
                         kw_kwargs={},
                         autoPred=True,
                         autoStats=True,
                         autoPlot=True):
    """
    TODO: Update this doc

    OUTDATED
    Fits a single NEMS model
    eg, 'ozgf100ch18_wc18x1_lvl1_fir15x1_dexp1_fit01'
    generates modelspec with 'wc18x1_lvl1_fir1x15_dexp1'

    based on fit_model function in nems/scripts/fit_model.py

    example xfspec:
     xfspec = [
        ['nems.xforms.load_recordings', {'recording_uri_list': recordings}],
        ['nems.xforms.add_average_sig', {'signal_to_average': 'resp',
                                         'new_signalname': 'resp',
                                         'epoch_regex': '^STIM_'}],
        ['nems.xforms.split_by_occurrence_counts', {'epoch_regex': '^STIM_'}],
        ['nems.xforms.init_from_keywords', {'keywordstring': modelspecname}],
        ['nems.xforms.set_random_phi',  {}],
        ['nems.xforms.fit_basic',       {}],
        # ['nems.xforms.add_summary_statistics',    {}],
        ['nems.xforms.plot_summary',    {}],
        # ['nems.xforms.save_recordings', {'recordings': ['est', 'val']}],
        ['nems.xforms.fill_in_default_metadata',    {}],
     ]
    """

    log.info('Initializing modelspec(s) for recording/model {0}/{1}...'.format(
        recording_uri, modelname))

    # parse modelname and assemble xfspecs for loader and fitter

    # TODO: naming scheme change: pre_modules, modules, post_modules?
    #       or something along those lines... since they aren't really
    #       just loaders and fitters
    load_keywords, model_keywords, fit_keywords = escaped_split(modelname, '_')

    xforms_lib = KeywordRegistry(recording_uri=recording_uri, **xforms_kwargs)
    xforms_lib.register_modules(
        [default_loaders, default_fitters, default_initializers])
    xforms_lib.register_plugins(get_setting('XFORMS_PLUGINS'))

    keyword_lib = KeywordRegistry(**kw_kwargs)
    keyword_lib.register_module(default_keywords)
    keyword_lib.register_plugins(get_setting('KEYWORD_PLUGINS'))

    # Generate the xfspec, which defines the sequence of events
    # to run through (like a packaged-up script)
    xfspec = []

    # 1) Load the data
    xfspec.extend(_parse_kw_string(load_keywords, xforms_lib))

    # 2) generate a modelspec
    xfspec.append([
        'nems.xforms.init_from_keywords', {
            'keywordstring': model_keywords,
            'meta': meta,
            'registry': keyword_lib
        }
    ])

    # 3) fit the data
    xfspec.extend(_parse_kw_string(fit_keywords, xforms_lib))

    # TODO: need to make this smarter about how to handle the ordering
    #       of pred/stats when only stats is overridden.
    #       For now just have to manually include pred if you want to
    #       do your own stats or plot xform (like using stats.pm)

    # 4) generate a prediction (optional)
    if autoPred:
        if not _xform_exists(xfspec, 'nems.xforms.predict'):
            xfspec.append(['nems.xforms.predict', {}])

    # 5) add some performance statistics (optional)
    if autoStats:
        if not _xform_exists(xfspec, 'nems.xforms.add_summary_statistics'):
            xfspec.append(['nems.xforms.add_summary_statistics', {}])

    # 6) generate plots (optional)
    if autoPlot:
        if not _xform_exists(xfspec, 'nems.xforms.plot_summary'):
            log.info('Adding summary plot to xfspec...')
            xfspec.append(['nems.xforms.plot_summary', {}])

    return xfspec
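
The returned xfspec is then evaluated step by step with nems.xforms.evaluate_step to build a context (see Example #17 for the same loop); a minimal sketch:

xfspec = generate_xforms_spec(recording_uri, modelname,
                              meta={'cellid': cellid, 'batch': batch})  # illustrative meta
ctx = {}
for xfa in xfspec:
    ctx = xforms.evaluate_step(xfa, ctx)
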
Example #15
        trial_count = len(all_trials) + 1
        # zero-pad the trial number to three digits (Trial_001 ... Trial_099, Trial_100, ...)
        all_trials[f'Trial_{trial_count:03d}'] = single_trial
        all_epochs[f'Trial_{trial_count:03d}'] = single_epoch

    return all_trials, all_epochs


BAPHY_ROOT = "/Users/grego/baphy"
RECORDING_PATH = get_setting('NEMS_RECORDINGS_DIR')
recording_file = os.path.join(RECORDING_PATH, "classifier.tgz")
class_labels = ['Ferret']
cats = len(class_labels)

dir1 = os.path.join(BAPHY_ROOT, "SoundObjects", "@FerretVocal", "Sounds_set4",
                    "*.wav")
# dir2 = os.path.join(BAPHY_ROOT, "SoundObjects", "@Speech", "sounds", "*sa1.wav")
set1 = glob.glob(dir1)
# set2 = glob.glob(dir2)
sound_set1 = set1
# sound_set2 = set2

sound_classes = (np.zeros(len(set1)) + 1)
sound_files = set1
# sound_classes = np.concatenate((np.zeros(len(set1)) + 1,
Example #16
import os
import sys

import nems.db as nd
import nems.plots.api as nplt
import nems.xform_helper as xhelp
import nems.epoch as ep
from nems.utils import find_common
import pandas as pd
import scipy.ndimage.filters as sf
import nems.gui.recording_browser as browser
import nems.gui.editors as editor
import nems.gui.model_comparison as comparison
from nems.gui.canvas import NemsCanvas, EpochCanvas, PrettyWidget

from configparser import ConfigParser
import nems

configfile = os.path.join(nems.get_setting('SAVED_SETTINGS_PATH'), 'gui.ini')
nems_root = os.path.abspath(os.path.join(nems.get_setting('SAVED_SETTINGS_PATH'), '..', '..'))

# TEMP ERROR CATCHER
# Back up the reference to the exceptionhook
sys._excepthook = sys.excepthook

def my_exception_hook(exctype, value, traceback):
    # Print the error and traceback
    print(exctype, value, traceback)
    # Call the normal Exception hook after
    sys._excepthook(exctype, value, traceback)
    sys.exit(1)

# Set the exception hook to our wrapping function
sys.excepthook = my_exception_hook
Example #17
def fit_xforms_model(batch, cellid, modelname, save_analysis=False):

    # parse modelname into loaders, modelspecs, and fit keys
    load_keywords, model_keywords, fit_keywords = modelname.split("_")

    # construct the meta data dict
    meta = {
        'batch': batch,
        'cellid': cellid,
        'modelname': modelname,
        'loader': load_keywords,
        'fitkey': fit_keywords,
        'modelspecname': model_keywords,
        'username': '******',
        'labgroup': 'lbhb',
        'public': 1,
        'githash': os.environ.get('CODEHASH', ''),
        'recording': load_keywords
    }

    xforms_kwargs = {}
    xforms_init_context = {'cellid': cellid, 'batch': int(batch)}
    recording_uri = None
    kw_kwargs = {}

    xforms_lib = KeywordRegistry(**xforms_kwargs)

    xforms_lib.register_modules(
        [default_loaders, default_fitters, default_initializers])
    xforms_lib.register_plugins(get_setting('XFORMS_PLUGINS'))

    keyword_lib = KeywordRegistry()
    keyword_lib.register_module(default_keywords)
    keyword_lib.register_plugins(get_setting('KEYWORD_PLUGINS'))

    # Generate the xfspec, which defines the sequence of events
    # to run through (like a packaged-up script)
    xfspec = []

    # 0) set up initial context
    if xforms_init_context is None:
        xforms_init_context = {}
    if kw_kwargs is not None:
        xforms_init_context['kw_kwargs'] = kw_kwargs
    xforms_init_context['keywordstring'] = model_keywords
    xforms_init_context['meta'] = meta
    xfspec.append(['nems.xforms.init_context', xforms_init_context])

    # 1) Load the data
    xfspec.extend(xhelp._parse_kw_string(load_keywords, xforms_lib))

    # 2) generate a modelspec
    xfspec.append(
        ['nems.xforms.init_from_keywords', {
            'registry': keyword_lib
        }])

    # 3) fit the data
    xfspec.extend(xhelp._parse_kw_string(fit_keywords, xforms_lib))

    # Generate a prediction
    xfspec.append(['nems.xforms.predict', {}])

    # 4) add some performance statistics
    xfspec.append(['nems.xforms.add_summary_statistics', {}])

    # 5) plot
    #xfspec.append(['nems_lbhb.lv_helpers.add_summary_statistics', {}])

    # Create a log stream set to the debug level; add it as a root log handler
    log_stream = io.StringIO()
    ch = logging.StreamHandler(log_stream)
    ch.setLevel(logging.DEBUG)
    fmt = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
    formatter = logging.Formatter(fmt)
    ch.setFormatter(formatter)
    rootlogger = logging.getLogger()
    rootlogger.addHandler(ch)

    ctx = {}
    for xfa in xfspec:
        ctx = xforms.evaluate_step(xfa, ctx)

    # Close the log, remove the handler, and add the 'log' string to context
    log.info('Done (re-)evaluating xforms.')
    ch.close()
    rootlogger.removeHandler(ch)

    log_xf = log_stream.getvalue()

    modelspec = ctx['modelspec']
    if save_analysis:
        # save results
        if get_setting('USE_NEMS_BAPHY_API'):
            prefix = 'http://' + get_setting(
                'NEMS_BAPHY_API_HOST') + ":" + str(
                    get_setting('NEMS_BAPHY_API_PORT')) + '/results/'
        else:
            prefix = get_setting('NEMS_RESULTS_DIR')

        if type(cellid) is list:
            cell_name = cellid[0].split("-")[0]
        else:
            cell_name = cellid

        destination = os.path.join(prefix, str(batch), cell_name,
                                   modelspec.get_longname())

        modelspec.meta['modelpath'] = destination
        modelspec.meta.update(meta)

        log.info('Saving modelspec(s) to {0} ...'.format(destination))

        xforms.save_analysis(destination,
                             recording=ctx['rec'],
                             modelspec=modelspec,
                             xfspec=xfspec,
                             figures=[],
                             log=log_xf)

        # save performance and some other metadata in database Results table
        nd.update_results_table(modelspec)

    return xfspec, ctx
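
Typical call (batch, cellid and modelname illustrative):

xfspec, ctx = fit_xforms_model(289, 'TAR010c-18-2',
                               'ozgf.fs100.ch18-ld-sev_wc.18x2.g-fir.2x15-lvl.1_basic',
                               save_analysis=True)
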
Example #18
File: utils.py Project: LBHB/NEMS
                            all_there = False
                            break
                if all_there:
                    filtered_collection.append(c)
            else:
                if (s in c) or (c in s):
                    filtered_collection.append(c)
                    break
                else:
                    pass

    return filtered_collection


default_configfile = os.path.join(
    nems.get_setting('SAVED_SETTINGS_PATH'), 'gui.ini')
nems_root = os.path.abspath(
    os.path.join(nems.get_setting('SAVED_SETTINGS_PATH'), '..', '..'))


def load_settings(config_group="db_browser_last", configfile=None):

    if configfile is None:
        configfile = default_configfile

    config = ConfigParser(delimiters=('='))

    try:
        config.read(configfile)
        settings = dict(config.items(config_group))
        return settings
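
Usage sketch, assuming the truncated function completes by returning the parsed dict:

settings = load_settings()           # reads group "db_browser_last" from the default gui.ini
last_batch = settings.get('batch')   # hypothetical option saved by the browser
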
Example #19
def pop_file(stimfmt='ozgf',
             batch=None,
             cellid=None,
             rasterfs=50,
             chancount=18,
             siteid=None,
             loadkey=None,
             **options):

    siteid = siteid.split("-")[0]
    subsetstr = []
    sitelist = []
    if siteid == 'ALLCELLS':
        if (batch in [322]):
            subsetstr = ["NAT4v2", "NAT3", "NAT1"]
        elif (batch in [323]):
            subsetstr = ["NAT4"]
        elif (batch in [333]):
            #runclass="OLP"
            #sql="SELECT sRunData.cellid,gData.svalue,gData.rawid FROM sRunData INNER JOIN" +\
            #        " sCellFile ON sRunData.cellid=sCellFile.cellid " +\
            #        " INNER JOIN gData ON" + \
            #        " sCellFile.rawid=gData.rawid AND gData.name='Ref_Combos'" +\
            #        " AND gData.svalue='Manual'" +\
            #        " INNER JOIN gRunClass on gRunClass.id=sCellFile.runclassid" +\
            #        f" WHERE sRunData.batch={batch} and gRunClass.name='{runclass}'"
            #d = nd.pd_query(sql)

            #d['siteid'] = d['cellid'].apply(nd.get_siteid)
            #sitelist = d['siteid'].unique()
            modelname_filter = 'ozgf.fs100.ch18-ld-norm.l1-sev_wc.18x4.g-fir.4x25-lvl.1-dexp.1_tfinit.n.lr1e3.et3.rb10.es20-newtf.n.lr1e4'
            sitelist, _ = nd.get_batch_sites(batch,
                                             modelname_filter=modelname_filter)

            #sitelist=sitelist[:4]
            #log.info('limiting sitelist to 4 entries!!!!!!!!!!!!!!!!!')
        else:
            raise ValueError(f'ALLCELLS not supported for batch {batch}')
    elif ((batch == 272) and (siteid == 'none')) or (siteid in [
            'bbl086b', 'TAR009d', 'TAR010c', 'TAR017b'
    ]):
        subsetstr = ["NAT1"]
    elif siteid in [
            'none', 'NAT3', 'AMT003c', 'AMT005c', 'AMT018a', 'AMT020a',
            'AMT023d', 'bbl099g', 'bbl104h', 'BRT026c', 'BRT032e', 'BRT033b',
            'BRT034f', 'BRT037b', 'BRT038b', 'BRT039c', 'AMT031a', 'AMT032a'
    ]:
        # Should use NAT3 as siteid going forward for better readability,
        # but left other options here for backwards compatibility.
        subsetstr = ["NAT3"]
    elif (batch in [322, 323, 333]) or (siteid == 'NAT4'):
        subsetstr = ["NAT4v2"]
    else:
        raise ValueError('site not known for popfile')
    use_API = get_setting('USE_NEMS_BAPHY_API')

    uri_root = '/auto/data/nems_db/recordings/'

    recording_uri_list = []
    #max_sites = 2;
    max_sites = 12
    log.info(f"TRUNCATING MULTI-FILE DATA AT {max_sites} RECORDINGS")
    for s in sitelist[:max_sites]:
        recording_uri = generate_recording_uri(batch=batch,
                                               cellid=s,
                                               stimfmt=stimfmt,
                                               rasterfs=rasterfs,
                                               chancount=chancount,
                                               **options)
        log.info(f'loading {recording_uri}')
        #if use_API:
        #    host = 'http://'+get_setting('NEMS_BAPHY_API_HOST')+":"+str(get_setting('NEMS_BAPHY_API_PORT'))
        #    recording_uri = host + '/recordings/' + str(batch) + '/' + recname + '.tgz'
        #else:
        #    recording_uri = '{}{}/{}.tgz'.format(uri_root, batch, recname)
        recording_uri_list.append(recording_uri)
    for s in subsetstr:
        recname = f"{s}_{stimfmt}.fs{rasterfs}.ch{chancount}"
        log.info(f'loading {recname}')
        #data_file = '{}{}/{}.tgz'.format(uri_root, batch, recname)

        if use_API:
            host = 'http://' + get_setting('NEMS_BAPHY_API_HOST') + ":" + str(
                get_setting('NEMS_BAPHY_API_PORT'))
            recording_uri = host + '/recordings/' + str(
                batch) + '/' + recname + '.tgz'
        else:
            recording_uri = '{}{}/{}.tgz'.format(uri_root, batch, recname)
        recording_uri_list.append(recording_uri)
    if len(subsetstr) == 1:
        # single data subset: return the lone URI rather than a one-element list
        return recording_uri
    else:
        return recording_uri_list
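
A hedged usage sketch for pop_file (the batch/siteid combination is one of the
supported branches above; with a single data subset, a single URI comes back):

# load the NAT4 population recording for batch 323 at 50 Hz, 18 channels
uri = pop_file(stimfmt='ozgf', batch=323, siteid='NAT4',
               rasterfs=50, chancount=18)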
Example #20
0
def generate_xforms_spec(recording_uri=None,
                         modelname=None,
                         meta={},
                         xforms_kwargs={},
                         kw_kwargs={},
                         autoPred=True,
                         autoStats=True,
                         autoPlot=True):
    """
    Generate an xforms spec based on a modelname, which can then be evaluated
    in order to process and fit a model.

    Parameters
    ----------
    recording_uri : str
        Location to load recording from, e.g. a filepath or URL.
    modelname : str
        NEMS-formatted modelname, e.g. 'ld-sev_wc.18x2-fir.2x15-dexp.1_basic'
        The modelname will be parsed into a series of xforms functions using
        xforms and keyword registries.
    meta : dict
        Additional keyword arguments for nems.initializers.init_from_keywords
    xforms_kwargs : dict
        Additional keyword arguments for the xforms registry
    kw_kwargs : dict
        Additional keyword arguments for the keyword registry
    autoPred : boolean
        If true, will automatically append nems.xforms.predict to the xfspec
        if it is not already present.
    autoStats : boolean
        If true, will automatically append nems.xforms.add_summary_statistics
        to the xfspec if it is not already present.
    autoPlot : boolean
        If true, will automatically append nems.xforms.plot_summary to the
        xfspec if it is not already present.

    Returns
    -------
    xfspec : list of 2- or 4-tuples

    """

    log.info('Initializing modelspec(s) for recording/model {0}/{1}...'.format(
        recording_uri, modelname))

    # parse modelname and assemble xfspecs for loader and fitter

    # TODO: naming scheme change: pre_modules, modules, post_modules?
    #       or something along those lines... since they aren't really
    #       just loaders and fitters
    load_keywords, model_keywords, fit_keywords = escaped_split(modelname, '_')
    if recording_uri is not None:
        xforms_lib = KeywordRegistry(recording_uri=recording_uri,
                                     **xforms_kwargs)
    else:
        xforms_lib = KeywordRegistry(**xforms_kwargs)

    xforms_lib.register_modules(
        [default_loaders, default_fitters, default_initializers])
    xforms_lib.register_plugins(get_setting('XFORMS_PLUGINS'))

    keyword_lib = KeywordRegistry(**kw_kwargs)
    keyword_lib.register_module(default_keywords)
    keyword_lib.register_plugins(get_setting('KEYWORD_PLUGINS'))

    # Generate the xfspec, which defines the sequence of events
    # to run through (like a packaged-up script)
    xfspec = []

    # 1) Load the data
    xfspec.extend(_parse_kw_string(load_keywords, xforms_lib))

    # 2) generate a modelspec
    xfspec.append([
        'nems.xforms.init_from_keywords', {
            'keywordstring': model_keywords,
            'meta': meta,
            'registry': keyword_lib
        }
    ])

    # 3) fit the data
    xfspec.extend(_parse_kw_string(fit_keywords, xforms_lib))

    # TODO: need to make this smarter about how to handle the ordering
    #       of pred/stats when only stats is overridden.
    #       For now just have to manually include pred if you want to
    #       do your own stats or plot xform (like using stats.pm)

    # 4) generate a prediction (optional)
    if autoPred:
        if not _xform_exists(xfspec, 'nems.xforms.predict'):
            xfspec.append(['nems.xforms.predict', {}])

    # 5) add some performance statistics (optional)
    if autoStats:
        if not _xform_exists(xfspec, 'nems.xforms.add_summary_statistics'):
            xfspec.append(['nems.xforms.add_summary_statistics', {}])

    # 6) generate plots (optional)
    if autoPlot:
        if not _xform_exists(xfspec, 'nems.xforms.plot_summary'):
            # log.info('Adding summary plot to xfspec...')
            xfspec.append(['nems.xforms.plot_summary', {}])

    return xfspec
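
As a sketch of how the returned spec is consumed (the modelname and recording
URI are placeholders; the evaluate call mirrors the pattern used in the other
examples here):

xfspec = generate_xforms_spec(
    recording_uri='/path/to/TAR010c-18-1.tgz',  # placeholder
    modelname='ld-sev_wc.18x2-fir.2x15-lvl.1-dexp.1_basic',
    meta={'cellid': 'TAR010c-18-1', 'batch': 271})
ctx, log_xf = xforms.evaluate(xfspec)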
Example #21
0
                              xforms_init_context=xforms_init_context,
                              autoPlot=autoPlot)
log.info(xfspec)

# actually do the loading, preprocessing, fit
ctx, log_xf = xforms.evaluate(xfspec)

# save some extra metadata
modelspec = ctx['modelspec']

# this code may not be necessary any more.
#destination = '{0}/{1}/{2}/{3}'.format(
#    get_setting('NEMS_RESULTS_DIR'), batch, cellid, modelspec.get_longname())
if type(cellid) is list:
    destination = os.path.join(
        get_setting('NEMS_RESULTS_DIR'), str(batch),
        cellid[0][:7], modelspec.get_longname())
else:
    destination = os.path.join(
        get_setting('NEMS_RESULTS_DIR'), str(batch),
        cellid, modelspec.get_longname())
modelspec.meta['modelpath'] = destination
modelspec.meta['figurefile'] = os.path.join(destination, 'figure.0000.png')
modelspec.meta.update(meta)

# save results
log.info('Saving modelspec(s) to {0} ...'.format(destination))
if 'figures' in ctx.keys():
    figs = ctx['figures']
else:
    figs = []
Example #22
0
def fit_model_xform(cellid,
                    batch,
                    modelname,
                    autoPlot=True,
                    saveInDB=False,
                    returnModel=False,
                    recording_uri=None,
                    initial_context=None):
    """
    Fit a single NEMS model using data stored in database. First generates an xforms
    script based on modelname parameter and then evaluates it.
    :param cellid: cellid and batch specific dataset in database
    :param batch:
    :param modelname: string specifying model architecture, preprocessing
    and fit method
    :param autoPlot: generate summary plot when complete
    :param saveInDB: save results to Results table
    :param returnModel: boolean (default False). If False, return savepath;
       if True, return (xfspec, ctx) tuple
    :param recording_uri: optional URI of the recording to load (default None)
    :return: savepath = path to saved results or (xfspec, ctx) tuple
    """
    start_time = time.time()
    log.info('Initializing modelspec(s) for cell/batch %s/%d...', cellid,
             int(batch))

    # Segment modelname for meta information
    kws = escaped_split(modelname, '_')

    modelspecname = escaped_join(kws[1:-1], '-')
    loadkey = kws[0]
    fitkey = kws[-1]

    meta = {
        'batch': batch,
        'cellid': cellid,
        'modelname': modelname,
        'loader': loadkey,
        'fitkey': fitkey,
        'modelspecname': modelspecname,
        'username': '******',
        'labgroup': 'lbhb',
        'public': 1,
        'githash': os.environ.get('CODEHASH', ''),
        'recording': loadkey
    }
    if type(cellid) is list:
        meta['siteid'] = cellid[0][:7]

    # registry_args = {'cellid': cellid, 'batch': int(batch)}
    registry_args = {}
    xforms_init_context = {'cellid': cellid, 'batch': int(batch)}
    if initial_context is not None:
        xforms_init_context.update(initial_context)

    log.info("TODO: simplify generate_xforms_spec parameters")
    xfspec = generate_xforms_spec(recording_uri=recording_uri,
                                  modelname=modelname,
                                  meta=meta,
                                  xforms_kwargs=registry_args,
                                  xforms_init_context=xforms_init_context,
                                  autoPlot=autoPlot)
    log.debug(xfspec)

    # actually do the loading, preprocessing, fit
    if initial_context is None:
        initial_context = {}
    # NB: the context kwarg below is commented out, so initial_context is
    # effectively unused here (it was already merged into xforms_init_context)
    ctx, log_xf = xforms.evaluate(xfspec)  #, context=initial_context)

    # save some extra metadata
    modelspec = ctx['modelspec']

    if type(cellid) is list:
        cell_name = cellid[0].split("-")[0]
    else:
        cell_name = cellid

    if 'modelpath' not in modelspec.meta:
        prefix = get_setting('NEMS_RESULTS_DIR')
        destination = os.path.join(prefix, str(batch), cell_name,
                                   modelspec.get_longname())

        log.info(f'Setting modelpath to "{destination}"')
        modelspec.meta['modelpath'] = destination
        modelspec.meta['figurefile'] = os.path.join(destination,
                                                    'figure.0000.png')
    else:
        destination = modelspec.meta['modelpath']

    # figure out URI for location to save results (either file or http, depending on USE_NEMS_BAPHY_API)
    if get_setting('USE_NEMS_BAPHY_API'):
        prefix = 'http://' + get_setting('NEMS_BAPHY_API_HOST') + ":" + str(get_setting('NEMS_BAPHY_API_PORT')) + \
                 '/results'
        save_loc = str(
            batch) + '/' + cell_name + '/' + modelspec.get_longname()
        save_destination = prefix + '/' + save_loc
        # set the modelspec meta save locations to be the filesystem and not baphy
        modelspec.meta['modelpath'] = get_setting(
            'NEMS_RESULTS_DIR') + '/' + save_loc
        modelspec.meta['figurefile'] = modelspec.meta[
            'modelpath'] + '/' + 'figure.0000.png'
    else:
        save_destination = destination

    modelspec.meta['runtime'] = int(time.time() - start_time)
    modelspec.meta.update(meta)

    if returnModel:
        # return fit, skip save!
        return xfspec, ctx

    # save results
    log.info('Saving modelspec(s) to {0} ...'.format(save_destination))
    if 'figures' in ctx.keys():
        figs = ctx['figures']
    else:
        figs = []
    save_data = xforms.save_analysis(save_destination,
                                     recording=ctx.get('rec'),
                                     modelspec=modelspec,
                                     xfspec=xfspec,
                                     figures=figs,
                                     log=log_xf,
                                     update_meta=False)

    # save in database as well
    if saveInDB:
        nd.update_results_table(modelspec)

    return save_data['savepath']
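
A minimal call sketch (cellid, batch, and modelname are illustrative; the
modelname follows the loader_model_fitter convention parsed above):

savepath = fit_model_xform(
    'TAR010c-18-1', 271,
    'ozgf.fs100.ch18-ld-sev_wc.18x2.g-fir.2x15-lvl.1-dexp.1_basic',
    autoPlot=True, saveInDB=False)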
Example #23
0
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import os

import nems_lbhb.strf.strf as strf
from nems_lbhb.strf.torc_subfunctions import strfplot
import nems_lbhb.baphy as nb
from nems.recording import Recording
import nems.db as nd
from nems import get_setting

fs = 1000

# data frame cache
path = get_setting("NEMS_RESULTS_DIR")
df_295_filename = os.path.join(path, str(295), 'd_tuning.csv')
df_307_filename = os.path.join(path, str(307), 'd_tuning.csv')
df_309_filename = os.path.join(path, str(309), 'd_tuning.csv')
df_313_filename = os.path.join(path, str(313), 'd_tuning.csv')

# pdf figure cache
# pdf_path = '/home/charlie/Desktop/lbhb/code/nems_db/nems_lbhb/pupil_behavior_scripts/strf_tuning/'

# ================================= batch 307 ==================================
cells_307 = nd.get_batch_cells(307).cellid
df_307 = pd.DataFrame(index=cells_307,
                      columns=['BF', 'SNR', 'STRF', 'StimParms'])
for cellid in cells_307:
    print('analyzing cell: {0}, batch {1}'.format(cellid, 307))
Example #24
0
import logging  # added here: logging.getLogger is used below
import os  # added here: os.path.join is used below

import nems  # added here: nems.get_setting is used below
import nems.modelspec as ms
import nems.plots.api as nplt
import nems.analysis.api
import nems.utils
import nems.uri
import nems.recording as recording
import nems.xforms as xforms
import nems.tf.cnn as cnn
import nems.tf.cnnlink as cnnlink

log = logging.getLogger(__name__)

# ----------------------------------------------------------------------------
# CONFIGURATION
# figure out data and results paths:
results_dir = nems.get_setting('NEMS_RESULTS_DIR')
signals_dir = nems.get_setting('NEMS_RECORDINGS_DIR')

# ----------------------------------------------------------------------------
# DATA LOADING & PRE-PROCESSING
"""
recording.get_demo_recordings(name="TAR010c-18-1.pkl")

datafile = os.path.join(signals_dir, "TAR010c-18-1.pkl")
load_command = 'nems.demo.loaders.demo_loader'
exptid = "TAR010c"
batch = 271
cellid = "TAR010c-18-1"
"""

datafile = os.path.join(
Example #25
0
def kamiak_to_database(cellids,
                       batch,
                       modelnames,
                       source_path,
                       executable_path=None,
                       script_path=None):

    user = '******'
    linux_user = '******'
    allowqueuemaster = 1
    waitid = 0
    parmstring = ''
    rundataid = 0
    priority = 1
    reserve_gb = 0
    codeHash = 'kamiak'

    if executable_path in [None, 'None', 'NONE', '']:
        executable_path = get_setting('DEFAULT_EXEC_PATH')
    if script_path in [None, 'None', 'NONE', '']:
        script_path = get_setting('DEFAULT_SCRIPT_PATH')

    combined = [(c, b, m)
                for c, b, m in itertools.product(cellids, [batch], modelnames)]
    notes = ['%s/%s/%s' % (c, b, m) for c, b, m in combined]
    commandPrompts = [
        "%s %s %s %s %s" % (executable_path, script_path, c, b, m)
        for c, b, m in combined
    ]

    engine = nd.Engine()
    for (c, b, m), note, commandPrompt in zip(combined, notes, commandPrompts):
        path = os.path.join(source_path, str(batch), c, m)  # str() in case batch is an int
        if not os.path.exists(path):
            log.warning("missing fit for: \n%s\n%s\n%s\n"
                        "using path: %s\n", batch, c, m, path)
            continue
        else:
            xfspec, ctx = xforms.load_analysis(path, eval_model=False)
            preview = ctx['modelspec'].meta.get('figurefile', None)
            if 'log' not in ctx:
                ctx['log'] = 'missing log'
            figures_to_load = ctx['figures_to_load']
            figures = [xforms.load_resource(f) for f in figures_to_load]
            ctx['figures'] = figures
            xforms.save_analysis(None, None, ctx['modelspec'], xfspec,
                                 ctx['figures'], ctx['log'])
            nd.update_results_table(ctx['modelspec'], preview=preview)

        conn = engine.connect()
        # note is interpolated directly into the SQL string; this assumes
        # trusted input (a parameterized query would be safer)
        sql = 'SELECT * FROM tQueue WHERE note="' + note + '"'
        r = conn.execute(sql)
        if r.rowcount > 0:
            # existing job, figure out what to do with it
            x = r.fetchone()
            queueid = x['id']
            complete = x['complete']

            if complete == 1:
                # Do nothing - the queue already shows a complete job
                pass

            elif complete == 2:
                # Change dead to complete
                sql = "UPDATE tQueue SET complete=1, killnow=0 WHERE id={}".format(
                    queueid)
                r = conn.execute(sql)

            else:
                # complete in [-1, 0] -- already running or queued
                # Do nothing
                pass

        else:
            # New job
            sql = "INSERT INTO tQueue (rundataid,progname,priority," +\
                   "reserve_gb,parmstring,allowqueuemaster,user," +\
                   "linux_user,note,waitid,codehash,queuedate,complete) VALUES"+\
                   " ({},'{}',{}," +\
                   "{},'{}',{},'{}'," +\
                   "'{}','{}',{},'{}',NOW(),1)"

            sql = sql.format(rundataid, commandPrompt, priority, reserve_gb,
                             parmstring, allowqueuemaster, user, linux_user,
                             note, waitid, codeHash)
            r = conn.execute(sql)

        conn.close()
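
A hedged usage sketch (all paths and names are placeholders; the function walks
source_path for fits copied back from the cluster, saves them, and updates
tQueue accordingly):

kamiak_to_database(
    cellids=['TAR010c-18-1', 'TAR010c-27-2'],      # placeholder cellids
    batch='271',                                   # used as a path component
    modelnames=['ozgf.fs100.ch18-ld-sev_wc.18x2.g-fir.2x15-lvl.1-dexp.1_basic'],
    source_path='/path/to/kamiak/results')         # placeholder directory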
Example #26
0
meta = {'batch': batch, 'cellid': cellid, 'modelname': modelname,
        'loader': load_keywords, 'fitkey': fit_keywords, 'modelspecname': model_keywords,
        'username': '******', 'labgroup': 'lbhb', 'public': 1,
        'githash': os.environ.get('CODEHASH', ''),
        'recording': load_keywords}

xforms_kwargs = {}
xforms_init_context = {'cellid': cellid, 'batch': int(batch)}
recording_uri = None
kw_kwargs = {}

xforms_lib = KeywordRegistry(**xforms_kwargs)

xforms_lib.register_modules([default_loaders, default_fitters,
                                default_initializers])
xforms_lib.register_plugins(get_setting('XFORMS_PLUGINS'))

keyword_lib = KeywordRegistry()
keyword_lib.register_module(default_keywords)
keyword_lib.register_plugins(get_setting('KEYWORD_PLUGINS'))

# Generate the xfspec, which defines the sequence of events
# to run through (like a packaged-up script)
xfspec = []

# 0) set up initial context
if xforms_init_context is None:
    xforms_init_context = {}
if kw_kwargs is not None:
    xforms_init_context['kw_kwargs'] = kw_kwargs
xforms_init_context['keywordstring'] = model_keywords
Example #27
0
        afl + pxf

    (so we have 6 total models, two batches)

as of the creation of this file (04/14/2020). The cleanest task results have
been observed using stategain per file. However, using perfile seems to
underestimate the influence of pupil, because the passive-file epochs can
account for much of this effect.
"""

import helpers as helper
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from nems import get_setting
path = get_setting('NEMS_RESULTS_DIR')

fig_path = '/auto/users/hellerc/results/pup_beh_ms/'

batch = 309
r0_threshold = 0
octave_cutoff = 0.5
use_sig_from = None  #'d_pup_afl_pxf_stategain.csv'  # model fit to determine significance
first_passive = False  # models fit on data with first passive only + actives
group_files = False

sg_fns = [
    'd_pup_fil_stategain.csv', 'd_pup_afl_stategain.csv',
    'd_pup_afl_pxf_stategain.csv'
]
sd_fns = [
Example #28
0
import logging  # added here: logging.getLogger is used below
import os  # added here: os.path.join is used below

import nems  # added here: nems.get_setting is used below
import nems.uri
import nems.recording as recording
from nems.signal import RasterizedSignal
from nems.fitters.api import scipy_minimize
import nems.db as nd

log = logging.getLogger(__name__)


# ----------------------------------------------------------------------------
# LOAD AND FORMAT RECORDING DATA

# data file and results locations
# defined in nems/nems/configs/settings.py, which will override
# defaults in nems/nems/configs/defaults.py
results_dir = nems.get_setting('NEMS_RESULTS_DIR')
recordings_dir = nems.get_setting('NEMS_RECORDINGS_DIR')

save_results = True
browse_results = True

# 2p data from Polley Lab at EPL
respfile = os.path.join(recordings_dir, 'data_nems_2p/neurons.csv')
stimfile = os.path.join(recordings_dir, 'data_nems_2p/stim_spectrogram.csv')
exptid = "POL001"
cellid = "POL001-080"
batch = 1  # define the group this data belongs to (e.g., 1: A1, 2: AAF, etc.)
load_command = 'nems.demo.loaders.load_polley_data'

# MODEL SPEC
# modelspecname = 'dlog_wcg18x1_stp1_fir1x15_lvl1_dexp1'
Example #29
0
meta = {'batch': batch, 'cellid': cellid, 'modelname': modelname,
        'loader': loadkey, 'fitkey': fitkey, 'modelspecname': modelspecname,
        'username': '******', 'labgroup': 'lbhb', 'public': 1,
        'githash': os.environ.get('CODEHASH', ''),
        'recording': loadkey}

load_keywords, model_keywords, fit_keywords = modelname.split("_")

# xforms_kwargs = {'cellid': cellid, 'batch': int(batch)}
xforms_kwargs = {}
xforms_init_context = {'cellid': cellid, 'batch': int(batch),
                       'meta': meta, 'keywordstring': model_keywords}
xforms_lib = KeywordRegistry(**xforms_kwargs)
xforms_lib.register_modules([default_loaders, default_fitters,
                             default_initializers])
xforms_lib.register_plugins(get_setting('XFORMS_PLUGINS'))

keyword_lib = KeywordRegistry()
keyword_lib.register_module(default_keywords)
keyword_lib.register_plugins(get_setting('KEYWORD_PLUGINS'))

# Generate the xfspec, which defines the sequence of events
# to run through (like a packaged-up script)
xfspec = []

# 0) set up initial context
xfspec.append(['nems.xforms.init_context', xforms_init_context])

# 1) Load the data
xfspec.extend(xhelp._parse_kw_string(load_keywords, xforms_lib))
Example #30
0
def init_pop_pca(est, modelspec, flip_pcs=False, IsReload=False, **context):
    """ fit up through the fir module of a population model using the pca
    signal"""

    if IsReload:
        return {}

    # preserve input modelspec. necessary?
    modelspec = copy.deepcopy(modelspec)

    ifir = find_module('filter_bank', modelspec)
    iwc = find_module('weight_channels', modelspec)

    chan_count = modelspec[ifir]['fn_kwargs']['bank_count']
    chan_per_bank = int(modelspec[iwc]['prior']['mean'][1]['mean'].shape[0] /
                        chan_count)
    rec = est.copy()
    tmodelspec = copy.deepcopy(modelspec)

    kw = [m['id'] for m in modelspec[:iwc]]

    wc = modelspec[iwc]['id'].split(".")
    wcs = wc[1].split("x")
    wcs[1] = str(chan_per_bank)
    wc[1] = "x".join(wcs)
    wc = ".".join(wc)

    fir = modelspec[ifir]['id'].split(".")
    fircore = fir[1].split("x")
    fir[1] = "x".join(fircore[:-1])
    fir = ".".join(fir)

    kw.append(wc)
    kw.append(fir)
    kw.append("lvl.1")
    keywordstring = "-".join(kw)
    keyword_lib = KeywordRegistry()
    keyword_lib.register_module(default_keywords)
    keyword_lib.register_plugins(get_setting('KEYWORD_PLUGINS'))
    if flip_pcs:
        pc_fit_count = int(np.ceil(chan_count / 2))
    else:
        pc_fit_count = chan_count
    tolerance = 1e-4  # hoisted out of the loop; also used by the slice fits below
    for pc_idx in range(pc_fit_count):
        r = rec['pca'].extract_channels([rec['pca'].chans[pc_idx]])
        m = np.nanmean(r.as_continuous())
        d = np.nanstd(r.as_continuous())
        rec['resp'] = r._modified_copy((r._data - m) / d)
        tmodelspec = init.from_keywords(keyword_string=keywordstring,
                                        meta={},
                                        registry=keyword_lib,
                                        rec=rec)
        tmodelspec = init.prefit_LN(rec,
                                    tmodelspec,
                                    tolerance=tolerance,
                                    max_iter=700)

        # save results back into main modelspec
        itfir = find_module('fir', tmodelspec)
        itwc = find_module('weight_channels', tmodelspec)

        if pc_idx == 0:
            for tm, m in zip(tmodelspec[:(iwc + 1)], modelspec[:(iwc + 1)]):
                m['phi'] = tm['phi'].copy()
            modelspec[ifir]['phi'] = tmodelspec[itfir]['phi'].copy()
        else:
            for k, v in tmodelspec[iwc]['phi'].items():
                modelspec[iwc]['phi'][k] = np.concatenate(
                    (modelspec[iwc]['phi'][k], v))
            for k, v in tmodelspec[itfir]['phi'].items():
                #if k=='coefficients':
                #    v/=100 # kludge
                modelspec[ifir]['phi'][k] = np.concatenate(
                    (modelspec[ifir]['phi'][k], v))

        if flip_pcs and (pc_idx * 2 < chan_count):
            # add negative flipped version of fit
            for k, v in tmodelspec[iwc]['phi'].items():
                modelspec[iwc]['phi'][k] = np.concatenate(
                    (modelspec[iwc]['phi'][k], v))
            for k, v in tmodelspec[itfir]['phi'].items():
                #if k=='coefficients':
                #    v/=100 # kludge
                modelspec[ifir]['phi'][k] = np.concatenate(
                    (-modelspec[ifir]['phi'][k], v))

    respcount = est['resp'].shape[0]
    fit_set_all, fit_set_slice = _figure_out_mod_split(modelspec)
    cd_kwargs = {'tolerance': tolerance, 'max_iter': 100, 'step_size': 0.1}

    for s in range(respcount):
        log.info('Pre-fit slice %d', s)
        modelspec = fit_population_slice(est,
                                         modelspec,
                                         slice=s,
                                         fit_set=fit_set_slice,
                                         analysis_function=analysis.fit_basic,
                                         metric=metrics.nmse,
                                         fitter=coordinate_descent,
                                         fit_kwargs=cd_kwargs)

    return {'modelspec': modelspec}
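
A hedged sketch of how this prefit is used (est and modelspec are assumed to
come from an xforms context, as in the fitting examples above):

# run the PCA-based prefit, then continue with the usual population fit
ctx_update = init_pop_pca(est, modelspec, flip_pcs=True)
modelspec = ctx_update['modelspec']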