Beispiel #1
0
def fit_population_channel(rec,
                           modelspec,
                           fit_set_all,
                           fit_set_slice,
                           analysis_function=analysis.fit_basic,
                           metric=metrics.nmse,
                           fitter=scipy_minimize,
                           fit_kwargs=None):
    """
    Fit the modules listed in ``fit_set_all`` against the recording returned
    by ``_invert_slice`` for the modules listed in ``fit_set_slice``.

    :param rec: recording to fit
    :param modelspec: full modelspec; phi of the ``fit_set_all`` modules is
        updated in place and the same object is returned
    :param fit_set_all: indices (into ``modelspec``) of the modules to fit
    :param fit_set_slice: indices of the modules passed to ``_invert_slice``
    :param analysis_function: fit routine, called as
        ``analysis_function(rec, modelspec, fitter=..., metric=...,
        fit_kwargs=...)``
    :param metric: error metric forwarded to ``analysis_function``
    :param fitter: optimizer forwarded to ``analysis_function``
    :param fit_kwargs: extra fitter options (default: new empty dict)
    :return: ``modelspec`` with updated phi
    """
    # avoid a shared mutable default that the fitter could modify
    if fit_kwargs is None:
        fit_kwargs = {}

    # invert cell-specific modules
    trec = _invert_slice(rec, modelspec, fit_set_slice)

    # sub-model holding only the modules being fit: position j of
    # tmodelspec corresponds to modelspec index fit_set_all[j]
    tmodelspec = ms.ModelSpec()
    for i in fit_set_all:
        tmodelspec.append(modelspec[i].copy())

    tmodelspec = analysis_function(trec,
                                   tmodelspec,
                                   fitter=fitter,
                                   metric=metric,
                                   fit_kwargs=fit_kwargs)

    # copy fitted phi back, mapping sub-model position -> original index
    for j, i in enumerate(fit_set_all):
        for k, v in tmodelspec[j]['phi'].items():
            modelspec[i]['phi'][k] = v

    return modelspec
Beispiel #2
0
def _prefit_dsig_only(est,
                      modelspec,
                      analysis_function,
                      fitter,
                      metric=None,
                      fit_kwargs=None):
    '''
    Perform a rough fit that only allows dynamic_sigmoid parameters to vary.

    Steps:
      1. Each dynamic_sigmoid "*_mod" parameter that has a prior is frozen:
         its prior entry is removed and fn_kwargs[param] is set to NaN.
      2. Modules whose id contains 'ct' are excluded; the remaining modules
         are fit with ``prefit_mod_subset(..., fit_set=['dynamic_sigmoid'])``.
      3. The fitted modules are written back at their original positions.
      4. Frozen parameters get their prior back, are removed from fn_kwargs,
         and their phi is set to the prior's mean.

    :param est: estimation recording
    :param modelspec: modelspec to fit; modified in place and returned
    :param analysis_function: forwarded to ``prefit_mod_subset``
    :param fitter: forwarded to ``prefit_mod_subset``
    :param metric: forwarded to ``prefit_mod_subset``
    :param fit_kwargs: forwarded to ``prefit_mod_subset``
        (default: new empty dict)
    :return: modelspec
    '''
    # avoid a shared mutable default that callees could modify
    if fit_kwargs is None:
        fit_kwargs = {}

    dsig_idx = find_module('dynamic_sigmoid', modelspec)

    # freeze all non-static dynamic sigmoid parameters; dynamic_phi maps each
    # parameter to its removed prior (False if it had none)
    dynamic_phi = {
        'amplitude_mod': False,
        'base_mod': False,
        'kappa_mod': False,
        'shift_mod': False
    }
    for p in dynamic_phi:
        v = modelspec[dsig_idx]['prior'].pop(p, False)
        if v:
            modelspec[dsig_idx]['fn_kwargs'][p] = np.nan
            dynamic_phi[p] = v

    # Remove ctwc, ctfir, and ctlvl if they exist.
    # NOTE(review): this is a substring test, so any module id containing
    # 'ct' anywhere is excluded, not only ids starting with 'ct' -- confirm.
    temp = [m for m in modelspec.modules if 'ct' not in m['id']]
    temp = ms.ModelSpec(raw=[temp])
    temp = prefit_mod_subset(est,
                             temp,
                             analysis_function,
                             fit_set=['dynamic_sigmoid'],
                             fitter=fitter,
                             metric=metric,
                             fit_kwargs=fit_kwargs)

    # Put the fitted modules back at their original positions; the 'ct'
    # modules were never removed from modelspec, so they stay untouched
    j = 0
    for i, m in enumerate(modelspec.modules):
        if 'ct' not in m['id']:
            modelspec[i] = temp[j]
            j += 1

    # reset dynamic sigmoid parameters if they were frozen
    for p, v in dynamic_phi.items():
        if v:
            prior = priors._tuples_to_distributions({p: v})[p]
            modelspec[dsig_idx]['fn_kwargs'].pop(p, None)
            modelspec[dsig_idx]['prior'][p] = v
            modelspec[dsig_idx]['phi'][p] = prior.mean()

    return modelspec
Beispiel #3
0
def fit_population_channel_fast_old(rec, modelspec,
                                    fit_set_all, fit_set_slice,
                                    analysis_function=analysis.fit_basic,
                                    metric=metrics.nmse,
                                    fitter=scipy_minimize, fit_kwargs=None):
    """
    Fit the ``fit_set_all`` modules one population channel at a time, each
    against the recording returned by ``_invert_slice`` for that channel
    (older variant of ``fit_population_channel_fast``).

    The channel count is read from axis 1 of the 'coefficients' phi of
    module ``fit_set_slice[0]``. For channel d, every phi array in the
    ``fit_set_all`` modules whose first dimension is >= the channel count is
    sliced to its d-th block of ``len // dim_count`` rows (and 'bank_count'
    is set to 1). Smaller, model-wide parameters are fit only for d == 0 and
    are moved to fn_kwargs (held fixed) for d > 0.

    :param rec: recording to fit
    :param modelspec: full modelspec; phi is updated in place
    :param fit_set_all: indices of the modules to fit
    :param fit_set_slice: indices of the modules passed to ``_invert_slice``
    :param analysis_function: fit routine (see ``fit_population_channel``)
    :param metric: error metric forwarded to ``analysis_function``
    :param fitter: optimizer forwarded to ``analysis_function``
    :param fit_kwargs: extra fitter options (default: new empty dict)
    :return: ``modelspec`` with updated phi
    """
    if fit_kwargs is None:
        fit_kwargs = {}

    # guess at number of subspace dimensions
    dim_count = modelspec[fit_set_slice[0]]['phi']['coefficients'].shape[1]

    for d in range(dim_count):
        # fit each dim separately
        log.info('Updating dim %d/%d', d+1, dim_count)

        # invert cell-specific modules
        trec = _invert_slice(rec, modelspec, fit_set_slice,
                             population_channel=d)

        # position j of tmodelspec corresponds to modelspec[fit_set_all[j]]
        tmodelspec = ms.ModelSpec()
        for i in fit_set_all:
            m = copy.deepcopy(modelspec[i])
            froze = False
            # iterate over a snapshot so parameters can be removed from phi
            for k, v in list(m['phi'].items()):
                x = v.shape[0]
                if x >= dim_count:
                    chunk = x // dim_count
                    m['phi'][k] = v[chunk * d:chunk * (d + 1)]
                    if 'bank_count' in m['fn_kwargs']:
                        m['fn_kwargs']['bank_count'] = 1
                elif d > 0:
                    # single model-wide parameter, only fit for d == 0.
                    # It must leave phi to actually be held fixed; otherwise
                    # it would still be fit (and written back) for d > 0.
                    m['fn_kwargs'][k] = m['phi'].pop(k)
                    if 'prior' in m:
                        m['prior'].pop(k, None)
                    froze = True
            if froze and not m['phi']:
                # every parameter is fixed: module carries no free params
                del m['phi']
                m.pop('prior', None)

            tmodelspec.append(m)

        tmodelspec = analysis_function(trec, tmodelspec, fitter=fitter,
                                       metric=metric, fit_kwargs=fit_kwargs)

        # write the fitted slice back into the full parameter arrays
        for j, i in enumerate(fit_set_all):
            for k, v in tmodelspec[j].get('phi', {}).items():
                x = modelspec[i]['phi'][k].shape[0]
                if x >= dim_count:
                    chunk = x // dim_count
                    modelspec[i]['phi'][k][chunk * d:chunk * (d + 1)] = v
                else:
                    modelspec[i]['phi'][k] = v

    return modelspec
Beispiel #4
0
def _get_modelspecs(cellids, batch, modelname, multi='mean'):
    """
    Load the saved modelspec file(s) of ``modelname`` in ``batch`` for the
    given cells.

    Model paths come from ``load_batch_modelpaths``; the hyrax URL prefix is
    mapped to '/auto/data/nems_db/' (and '/auto/' to '/Volumes/' when the
    NEMS_RESULTS_DIR setting starts with "/Volumes"). Every file in a model
    directory whose name starts with "modelspec" is loaded.

    :param cellids: cell ids passed to ``load_batch_modelpaths``
    :param batch: batch id
    :param modelname: model name
    :param multi: how to handle several modelspec files for one model:
        'first' - the first file in sorted filename order,
        'all'   - keep the whole list,
        'mean'  - first modelspec with each phi value replaced by the
                  matching 'mean' from ``ms.summary_stats``.
        Anything else logs a warning and uses 'first'.
    :return: list with one ``ms.ModelSpec`` per model path
    """
    filepaths = load_batch_modelpaths(batch,
                                      modelname,
                                      cellids,
                                      eval_model=False)
    speclists = []
    for path in filepaths:
        path = path.replace('http://hyrax.ohsu.edu:3003/',
                            '/auto/data/nems_db/')
        if get_setting('NEMS_RESULTS_DIR').startswith("/Volumes"):
            path = path.replace('/auto/', '/Volumes/')
        # sorted: os.listdir order is arbitrary, which made 'first'
        # pick an arbitrary file
        mspaths = [os.path.join(path, fname)
                   for fname in sorted(os.listdir(path))
                   if fname.startswith("modelspec")]
        speclists.append([load_resource(p) for p in mspaths])

    modelspecs = []
    for specs in speclists:
        if len(specs) > 1:
            if multi == 'first':
                this_mspec = specs[0]
            elif multi == 'all':
                this_mspec = specs
            elif multi == 'mean':
                stats = ms.summary_stats(specs)
                this_mspec = copy.deepcopy(specs[0])
                # phi dicts are edited in place inside the copied modules.
                # NOTE(review): stats keys are matched only by the suffix
                # '--<param>', so same-named params in different modules may
                # all receive the last matching mean -- confirm key format.
                for module in this_mspec:
                    phi = module['phi']
                    for k in phi:
                        for s in stats:
                            if s.endswith('--' + k):
                                phi[k] = stats[s]['mean']
            else:
                log.warning(
                    "Couldn't interpret <multi> parameter. Got: %s,\n"
                    "Expected one of: 'mean, first, all'.\n"
                    "Using first modelspec instead.", multi)
                this_mspec = specs[0]
        else:
            this_mspec = specs[0]

        modelspecs.append(ms.ModelSpec([this_mspec]))

    return modelspecs
Beispiel #5
0
def _extract_pop_channel(modelspec, d, fit_set_all, freeze_idx=None):
    """
    Extract a mini model from modelspec, just for population channel d,
    over the modules indexed by fit_set_all.

    The channel count is read from axis 1 of the 'coefficients' phi of the
    module immediately after the last index in ``fit_set_all``. Each phi
    array whose first dimension is >= that count is sliced to its d-th block
    of ``len // dim_count`` rows. Its bounds (if present) are sliced the same
    way, and 'bank_count' is set to 1. Smaller, model-wide parameters stay
    in phi for d == 0. For d > 0 they are moved to fn_kwargs and removed
    from phi/prior so they are held fixed. A module whose parameters are all
    frozen this way ends up with no 'phi' or 'prior' key.

    :param modelspec: source modelspec (not modified; modules are deep-copied)
    :param d: population channel index
    :param fit_set_all: indices of the modules to include
    :param freeze_idx: list of module indices to freeze (move phi to
        fn_kwargs). NOTE: currently not used by this implementation.
    :return: tmodelspec - subspace model
    """
    # create modelspec with single population subspace filter
    dim_count = modelspec[fit_set_all[-1]+1]['phi']['coefficients'].shape[1]
    tmodelspec = ms.ModelSpec()
    for i in fit_set_all:
        m = copy.deepcopy(modelspec[i])
        log.debug('extracting channel %d from module %s', d, m['fn'])
        froze = False
        # iterate over a snapshot so parameters can be removed from phi
        for k, v in list(m['phi'].items()):
            x = v.shape[0]
            if x >= dim_count:
                chunk = x // dim_count
                x1 = chunk * d
                x2 = chunk * (d + 1)

                m['phi'][k] = v[x1:x2]
                if 'bounds' in m:
                    m['bounds'][k] = (m['bounds'][k][0][x1:x2],
                                      m['bounds'][k][1][x1:x2])
                if 'bank_count' in m['fn_kwargs']:
                    m['fn_kwargs']['bank_count'] = 1
            elif d > 0:
                # single model-wide parameter, only fit for d==0; keep it
                # fixed for d>0. Removed per parameter: deleting the whole
                # phi here would drop sliced params and crash on the next
                # model-wide param.
                m['fn_kwargs'][k] = m['phi'].pop(k)
                if 'prior' in m:
                    m['prior'].pop(k, None)
                froze = True
        if froze and not m['phi']:
            # every parameter is fixed: module carries no free params
            del m['phi']
            m.pop('prior', None)
        tmodelspec.append(m)

    return tmodelspec
Beispiel #6
0
def fit_population_channel(rec, modelspec,
                           fit_set_all, fit_set_slice,
                           analysis_function=analysis.fit_basic,
                           metric=metrics.nmse,
                           fitter=scipy_minimize, fit_kwargs=None):
    """
    DEPRECATED?
    Fit all the population channels, but only trying to predict the
    responses inverted through the weight_channels of layer 2
    (``_invert_slice`` applied to the ``fit_set_slice`` modules).

    :param rec: recording to fit
    :param modelspec: full modelspec; phi of the ``fit_set_all`` modules is
        updated in place and the same object is returned
    :param fit_set_all: indices (into ``modelspec``) of the modules to fit
    :param fit_set_slice: indices of the modules passed to ``_invert_slice``
    :param analysis_function: fit routine, called as
        ``analysis_function(rec, modelspec, fitter=..., metric=...,
        fit_kwargs=...)``
    :param metric: error metric forwarded to ``analysis_function``
    :param fitter: optimizer forwarded to ``analysis_function``
    :param fit_kwargs: extra fitter options (default: new empty dict)
    :return: ``modelspec`` with updated phi
    """
    # avoid a shared mutable default that the fitter could modify
    if fit_kwargs is None:
        fit_kwargs = {}

    # invert cell-specific modules
    trec = _invert_slice(rec, modelspec, fit_set_slice)

    # sub-model holding only the modules being fit: position j of
    # tmodelspec corresponds to modelspec index fit_set_all[j]
    tmodelspec = ms.ModelSpec()
    for i in fit_set_all:
        tmodelspec.append(modelspec[i].copy())

    tmodelspec = analysis_function(trec, tmodelspec, fitter=fitter,
                                   metric=metric, fit_kwargs=fit_kwargs)

    # copy fitted phi back, mapping sub-model position -> original index
    for j, i in enumerate(fit_set_all):
        for k, v in tmodelspec[j]['phi'].items():
            modelspec[i]['phi'][k] = v

    return modelspec
Beispiel #7
0
def init_relsat(rec, modelspec):
    """
    Initialize priors of the 'saturated_rectifier' module from the output of
    the modules that precede it.

    Works on a deep copy, so the input modelspec is not modified. If no
    saturated_rectifier module is found, a warning is logged and the copy is
    returned unchanged.

    :param rec: recording to evaluate the preceding modules on
    :param modelspec: modelspec containing a 'saturated_rectifier' module
    :return: copy of modelspec whose saturated_rectifier module has 'prior'
        entries for base, amplitude, shift and kappa
    """
    modelspec = copy.deepcopy(modelspec)

    target_i = find_module('saturated_rectifier', modelspec)
    if target_i is None:
        log.warning("No relsat module was found, can't initialize.")
        return modelspec

    if target_i == len(modelspec):
        fit_portion = modelspec.modules
    else:
        fit_portion = modelspec.modules[:target_i]

    # generate prediction from the modules preceding relsat
    rec = ms.ModelSpec(fit_portion).evaluate(rec).apply_mask()

    pred = rec['pred'].as_continuous().flatten()
    resp = rec['resp'].as_continuous().flatten()
    stdr = np.nanstd(resp)

    # nan-aware statistics, consistent with np.nanstd above
    base = np.nanmin(resp)
    amplitude = min(np.nanmean(resp) + stdr * 3, np.nanmax(resp))
    shift = np.nanmean(pred) - 1.5 * np.nanstd(pred)
    kappa = 1

    base_prior = ('Exponential', {'beta': base})
    amplitude_prior = ('Exponential', {'beta': amplitude})
    # a Normal's sd must be non-negative, while shift itself may be negative
    shift_prior = ('Normal', {'mean': shift, 'sd': np.abs(shift)})
    kappa_prior = ('Exponential', {'beta': kappa})

    # priors belong on the target module (previously they were assigned to
    # modelspec['prior'], i.e. not to any module)
    modelspec[target_i].setdefault('prior', {}).update({
        'base': base_prior,
        'amplitude': amplitude_prior,
        'shift': shift_prior,
        'kappa': kappa_prior
    })

    return modelspec
Beispiel #8
0
def fit_population_channel_fast(rec, modelspec,
                                fit_set_all, fit_set_slice,
                                analysis_function=analysis.fit_basic,
                                metric=metrics.nmse,
                                fitter=scipy_minimize, fit_kwargs=None):
    """
    Fit the ``fit_set_all`` modules one population channel at a time.

    The channel count is read from axis 1 of the 'coefficients' phi of
    module ``fit_set_slice[0]``. For each channel d a sub-model is built
    from two parts:
      * the ``fit_set_all`` modules. Each phi array whose first dimension
        is >= the channel count is sliced to its d-th block of
        ``len // dim_count`` rows. Smaller, model-wide parameters are fit
        only for d == 0 and are moved to fn_kwargs (held fixed) for d > 0.
      * the ``fit_set_slice`` modules, with all phi moved to fn_kwargs as
        fixed values. Column d is taken from arrays whose second dimension
        is >= the channel count.
    The sub-model is fit to a residual target: resp - pred(full model) +
    pred(sub-model). The fitted values are written back into ``modelspec``.

    :param rec: recording to fit
    :param modelspec: full modelspec; phi is updated in place
    :param fit_set_all: indices of the modules to fit
    :param fit_set_slice: indices of the population-layer modules
    :param analysis_function: fit routine, called as
        ``analysis_function(rec, modelspec, fitter=..., metric=...,
        fit_kwargs=...)``
    :param metric: error metric forwarded to ``analysis_function``
    :param fitter: optimizer forwarded to ``analysis_function``
    :param fit_kwargs: extra fitter options (default: new empty dict)
    :return: ``modelspec`` with updated phi
    """
    if fit_kwargs is None:
        fit_kwargs = {}

    # guess at number of subspace dimensions
    dim_count = modelspec[fit_set_slice[0]]['phi']['coefficients'].shape[1]

    for d in range(dim_count):
        # fit each dim separately
        log.info('Updating dim %d/%d', d+1, dim_count)

        # create modelspec with single population subspace filter;
        # position j of tmodelspec corresponds to modelspec[fit_set_all[j]]
        tmodelspec = ms.ModelSpec()
        for i in fit_set_all:
            m = copy.deepcopy(modelspec[i])
            froze = False
            # iterate over a snapshot so parameters can be removed from phi
            for k, v in list(m['phi'].items()):
                x = v.shape[0]
                if x >= dim_count:
                    chunk = x // dim_count
                    m['phi'][k] = v[chunk * d:chunk * (d + 1)]
                    if 'bank_count' in m['fn_kwargs']:
                        m['fn_kwargs']['bank_count'] = 1
                elif d > 0:
                    # single model-wide parameter, only fit for d==0; keep it
                    # fixed for d>0. Removed per parameter: deleting the whole
                    # phi here would drop sliced params and crash on the next
                    # model-wide param.
                    m['fn_kwargs'][k] = m['phi'].pop(k)
                    if 'prior' in m:
                        m['prior'].pop(k, None)
                    froze = True
            if froze and not m['phi']:
                # every parameter is fixed: module carries no free params
                del m['phi']
                m.pop('prior', None)
            tmodelspec.append(m)

        # append full-population layer as non-free parameters
        for i in fit_set_slice:
            m = copy.deepcopy(modelspec[i])
            for k, v in m['phi'].items():
                # just applies to wc module?
                if v.shape[1] >= dim_count:
                    m['fn_kwargs'][k] = v[:, [d]]
                else:
                    m['fn_kwargs'][k] = v
            del m['phi']
            del m['prior']
            tmodelspec.append(m)

        # compute residual from prediction by the rest of the pop model
        trec = rec.copy()
        trec = ms.evaluate(trec, modelspec)
        r = trec['resp'].as_continuous()
        p = trec['pred'].as_continuous()

        trec = ms.evaluate(trec, tmodelspec)
        p2 = trec['pred'].as_continuous()
        trec['resp'] = trec['resp']._modified_copy(data=r-p+p2)

        tmodelspec = analysis_function(trec, tmodelspec, fitter=fitter,
                                       metric=metric, fit_kwargs=fit_kwargs)

        # write the fitted slice back into the full parameter arrays;
        # fully-frozen modules have no phi and contribute nothing
        for j, i in enumerate(fit_set_all):
            for k, v in tmodelspec[j].get('phi', {}).items():
                x = modelspec[i]['phi'][k].shape[0]
                if x >= dim_count:
                    chunk = x // dim_count
                    modelspec[i]['phi'][k][chunk * d:chunk * (d + 1)] = v
                else:
                    modelspec[i]['phi'][k] = v

    return modelspec
Beispiel #9
0
def init_logsig(rec, modelspec):
    '''
    Initialization of priors for logistic_sigmoid,
    based on process described in methods of Rabinowitz et al. 2014.

    Works on a deep copy of ``modelspec``. The modules preceding the
    logistic_sigmoid module are evaluated on ``rec``. The spread of the
    resulting 'pred' and of 'resp' (mean +/- 3 std, lower ends floored at 0)
    gives initial values for base, amplitude, shift and kappa. These values
    are:
      * written into the module's phi (only if it already has phi);
      * used as parameters of the module's priors.
    The module's bounds are then replaced.

    :param rec: recording to evaluate the preceding modules on
    :param modelspec: modelspec containing a 'logistic_sigmoid' module
    :return: the updated copy; if no logistic_sigmoid module is found, a
        warning is logged and the unmodified copy is returned
    '''
    # never touch the caller's modelspec
    spec = copy.deepcopy(modelspec)

    idx = find_module('logistic_sigmoid', spec)
    if idx is None:
        log.warning("No logsig module was found, can't initialize.")
        return spec

    preceding = spec.modules if idx == len(spec) else spec.modules[:idx]

    # prediction produced by the modules in front of the sigmoid
    evaluated = ms.ModelSpec(preceding).evaluate(rec)
    pred = evaluated['pred'].as_continuous()
    resp = evaluated['resp'].as_continuous()

    pred_mu = np.nanmean(pred)
    pred_sd = np.nanstd(pred)
    lo_pred = pred_mu - pred_sd * 3
    hi_pred = pred_mu + pred_sd * 3
    center = pred_mu
    if lo_pred < 0:
        lo_pred = 0
        center = (lo_pred + hi_pred) / 2
    pred_span = hi_pred - lo_pred

    resp_mu = np.nanmean(resp)
    resp_sd = np.nanstd(resp)
    lo_resp = max(resp_mu - resp_sd * 3, 0)  # must be >= 0
    hi_resp = resp_mu + resp_sd * 3
    resp_span = hi_resp - lo_resp

    # Initial values seed the priors; the fitter/analysis decides how
    # to use them.
    base0 = lo_resp + 0.05 * (resp_span)
    amplitude0 = resp_span
    shift0 = center
    kappa0 = pred_span
    log.info("Initial   base,amplitude,shift,kappa=({}, {}, {}, {})".format(
        base0, amplitude0, shift0, kappa0))

    if 'phi' in spec[idx]:
        spec[idx]['phi'].update({
            'base': base0,
            'amplitude': amplitude0,
            'shift': shift0,
            'kappa': kappa0
        })

    spec[idx]['prior'].update({
        'base': ('Exponential', {'beta': base0}),
        'amplitude': ('Exponential', {'beta': amplitude0}),
        'shift': ('Normal', {'mean': shift0, 'sd': pred_span}),
        'kappa': ('Exponential', {'beta': kappa0})
    })

    spec[idx]['bounds'] = {
        'base': (1e-15, None),
        'amplitude': (1e-15, None),
        'shift': (None, None),
        'kappa': (1e-15, None)
    }

    return spec
Beispiel #10
0
def init_dexp(rec, modelspec, nl_mode=2, override_target_i=None):
    """
    Choose initial phi for a double_exponential (dexp) module from the output
    of the modules that precede it (e.g. after the preceding fir is
    initialized).

    Works on a deep copy of ``modelspec``. Preceding modules that have a
    prior but no phi get phi set from the prior mean. The preceding portion
    is then evaluated on ``rec``. For each channel of the target's input
    signal (fn_kwargs['i']), parameters are computed from the finite samples
    of that channel ("pred") and of the output signal ("resp",
    ``modelspec.meta['output_name']``, default 'resp'; sliced per channel
    when it has the same channel count).

    :param rec: recording
    :param modelspec: modelspec containing the nonlinearity module
    :param nl_mode: initialization rule, one of 1-4 (default 2; pre 11/29/18
        models were fit with 1). base = min(resp) unless stated otherwise:
        1: amp = std(resp) * 3, kappa = log(2 / (max(pred) - min(pred) + 1)),
           shift = mean(pred)
        2: amp = mean(resp) where pred > 90th percentile of pred (percentile
           lowered until at least 1% of samples qualify),
           kappa = log(2 / (std(pred) * 3)) (0 if not finite),
           shift = mean(pred); a channel whose pred has zero std falls back
           to mode 1
        3: base = min(resp) - std(resp), amp = std(resp) * 4,
           kappa = log(1 / (std(pred) * 3)) (0 if not finite),
           shift = mean(pred)
        4: base = mean(resp) - std(resp), amp = std(resp) * 4,
           kappa = 0, shift = 0
    :param override_target_i: integer index into the modelspec. This
        replaces the normal behavior of the function which would look up
        the index of the 'double_exponential' module. Use this if you want
        to use dexp's initialization procedure for a similar nonlinearity
        module.
    :return: the updated copy; if no dexp module is found (and no override
        is given) a warning is logged and the unmodified copy is returned
    :raises ValueError: if nl_mode is not one of 1-4
    """
    # preserve input modelspec
    modelspec = copy.deepcopy(modelspec)

    if override_target_i is None:
        target_i = find_module('double_exponential', modelspec)
        if target_i is None:
            log.warning("No dexp module was found, can't initialize.")
            return modelspec
    else:
        target_i = override_target_i

    if target_i == len(modelspec):
        fit_portion = modelspec.modules
    else:
        fit_portion = modelspec.modules[:target_i]

    # ensures all previous modules have their phi initialized
    # choose prior mean if not found
    for i, m in enumerate(fit_portion):
        if ('phi' not in m.keys()) and ('prior' in m.keys()):
            log.debug('Phi not found for module, using mean of prior: %s', m)
            m = priors.set_mean_phi([m])[0]  # Inits phi for 1 module
            fit_portion[i] = m

    # generate prediction from the modules preceding dexp
    rec = ms.ModelSpec(fit_portion).evaluate(rec)

    in_signal = modelspec[target_i]['fn_kwargs']['i']
    pchans = rec[in_signal].shape[0]
    amp = np.zeros([pchans, 1])
    base = np.zeros([pchans, 1])
    kappa = np.zeros([pchans, 1])
    shift = np.zeros([pchans, 1])
    out_signal = modelspec.meta.get('output_name', 'resp')

    # the signal data do not change between channels: fetch them once
    all_resp = rec[out_signal].as_continuous()
    all_pred = rec[in_signal].as_continuous()

    for i in range(pchans):
        pred = all_pred[i:(i + 1), :]
        if all_resp.shape[0] == pchans:
            resp = all_resp[i:(i + 1), :]
        else:
            resp = all_resp

        keepidx = np.isfinite(resp) * np.isfinite(pred)
        resp = resp[keepidx]
        pred = pred[keepidx]

        # choose phi s.t. dexp starts as almost a straight line
        # phi=[max_out min_out slope mean_in]
        stdr = np.nanstd(resp)

        base[i, 0] = np.min(resp)

        if nl_mode == 1:
            amp[i, 0] = stdr * 3
            predrange = 2 / (np.max(pred) - np.min(pred) + 1)
            shift[i, 0] = np.mean(pred)
            kappa[i, 0] = np.log(predrange)
        elif (nl_mode == 2) and (np.std(pred) == 0):
            log.warning(
                'Init dexp: channel %d prediction has zero std, reverting to nl_mode==1',
                i)
            amp[i, 0] = stdr * 3
            predrange = 2 / (np.max(pred) - np.min(pred) + 1)
            shift[i, 0] = np.mean(pred)
            kappa[i, 0] = np.log(predrange)
        elif nl_mode == 2:
            # lower the percentile threshold until >= 1% of samples pass
            mask = np.zeros_like(pred, dtype=bool)
            pct = 91
            while (np.sum(mask) < .01 * pred.shape[0]) and (pct > 1):
                pct -= 1
                mask = pred > np.percentile(pred, pct)
            if np.sum(mask) == 0:
                mask = np.ones_like(pred, dtype=bool)

            if pct != 90:
                log.warning(
                    'Init dexp: Default for init mode 2 is to find mean '
                    'of responses for times where pred>pctile(pred,90). '
                    '\nNo times were found so this was lowered to '
                    'pred>pctile(pred,%d).', pct)
            amp[i, 0] = resp[mask].mean()
            predrange = 2 / (np.std(pred) * 3)
            if not np.isfinite(predrange):
                predrange = 1
            shift[i, 0] = np.mean(pred)
            kappa[i, 0] = np.log(predrange)
        elif nl_mode == 3:
            base[i, 0] = np.min(resp) - stdr
            amp[i, 0] = stdr * 4
            predrange = 1 / (np.std(pred) * 3)
            # same guard as modes 2/4: zero-std pred would give kappa=inf
            if not np.isfinite(predrange):
                predrange = 1
            shift[i, 0] = np.mean(pred)
            kappa[i, 0] = np.log(predrange)
        elif nl_mode == 4:
            base[i, 0] = np.mean(resp) - stdr * 1
            amp[i, 0] = stdr * 4
            kappa[i, 0] = 0
            shift[i, 0] = 0
        else:
            raise ValueError('nl mode = {} not valid'.format(nl_mode))

    modelspec[target_i]['phi'] = {
        'amplitude': amp,
        'base': base,
        'kappa': kappa,
        'shift': shift
    }
    log.info("Init dexp: %s", modelspec[target_i]['phi'])

    return modelspec
Beispiel #11
0
def from_keywords(keyword_string,
                  registry=None,
                  rec=None,
                  meta=None,
                  init_phi_to_mean_prior=True,
                  input_name='stim',
                  output_name='resp'):
    '''
    Returns a modelspec created by splitting keyword_string on hyphens ('-')
    and replacing each keyword with what is found in the keyword registry
    (``nems.xforms.keyword_lib`` by default). You may provide your own
    keyword registry using the registry={...} argument.

    When ``rec`` is given, channel-count placeholders in a keyword are
    substituted before lookup:
      'fir.Nx', 'wc.Nx', '.N' (suffix), '.cN' (suffix), 'xN'
          -> rec[input_name].nchans
      'stategain.N', '.S', 'xS' -> rec['state'].nchans
      '.R', 'xR' -> rec[output_name].nchans

    :param keyword_string: '-'-separated keywords
    :param registry: keyword registry (default: nems.xforms.keyword_lib)
    :param rec: optional recording used for placeholder substitution and
        for guessing 'cellids'
    :param meta: metadata merged into modelspec.meta. A dict passed by the
        caller is updated in place ('input_name', 'output_name' and possibly
        'cellids' are added). Default: a new empty dict per call.
    :param init_phi_to_mean_prior: if True, set each module's phi to the
        mean of its prior
    :param input_name: signal fed to the first module whose input is 'pred'
    :param output_name: stored in meta; also used for '.R'/'xR' substitution
    :return: modelspec
    :raises ValueError: if a keyword's head is not in the registry
    '''
    # a fresh dict per call: the old `meta={}` default was mutated below and
    # so leaked metadata (e.g. cellids) between calls
    if meta is None:
        meta = {}

    if registry is None:
        from nems.xforms import keyword_lib
        registry = keyword_lib
    keywords = keyword_string.split('-')

    # Lookup the modelspec fragments in the registry
    modelspec = ms.ModelSpec()
    for kw in keywords:
        if (kw.startswith("fir.Nx") or kw.startswith("wc.Nx")) and \
                (rec is not None):
            N = rec[input_name].nchans
            kw_old = kw
            kw = kw.replace(".N", ".{}".format(N))
            log.info("kw: dynamically subbing %s with %s", kw_old, kw)
        elif kw.startswith("stategain.N") and (rec is not None):
            N = rec['state'].nchans
            kw_old = kw
            kw = kw.replace("stategain.N", "stategain.{}".format(N))
            log.info("kw: dynamically subbing %s with %s", kw_old, kw)
        elif (kw.endswith(".N")) and (rec is not None):
            N = rec[input_name].nchans
            kw_old = kw
            kw = kw.replace(".N", ".{}".format(N))
            log.info("kw: dynamically subbing %s with %s", kw_old, kw)
        elif (kw.endswith(".cN")) and (rec is not None):
            N = rec[input_name].nchans
            kw_old = kw
            kw = kw.replace(".cN", ".c{}".format(N))
            log.info("kw: dynamically subbing %s with %s", kw_old, kw)
        elif ("xN" in kw) and (rec is not None):
            # covers keywords ending in 'xN' as well as 'xN' mid-keyword
            N = rec[input_name].nchans
            kw_old = kw
            kw = kw.replace("xN", "x{}".format(N))
            log.info("kw: dynamically subbing %s with %s", kw_old, kw)

        # '.Sx' contains '.S', so one test covers both forms
        if (".S" in kw) and (rec is not None):
            S = rec['state'].nchans
            kw_old = kw
            kw = kw.replace(".S", ".{}".format(S))
            log.info("kw: dynamically subbing %s with %s", kw_old, kw)

        if ("xS" in kw) and (rec is not None):
            S = rec['state'].nchans
            kw_old = kw
            kw = kw.replace("xS", "x{}".format(S))
            log.info("kw: dynamically subbing %s with %s", kw_old, kw)

        if (".R" in kw) and (rec is not None):
            R = rec[output_name].nchans
            kw_old = kw
            kw = kw.replace(".R", ".{}".format(R))
            log.info("kw: dynamically subbing %s with %s", kw_old, kw)

        elif ("xR" in kw) and (rec is not None):
            R = rec[output_name].nchans
            kw_old = kw
            kw = kw.replace("xR", "x{}".format(R))
            log.info("kw: dynamically subbing %s with %s", kw_old, kw)

        else:
            log.info('kw: %s', kw)

        if registry.kw_head(kw) not in registry:
            raise ValueError("unknown keyword: {}".format(kw))

        templates = copy.deepcopy(registry[kw])
        if not isinstance(templates, list):
            templates = [templates]
        for d in templates:
            d['id'] = kw
            if init_phi_to_mean_prior:
                d = priors.set_mean_phi([d])[0]  # Inits phi for 1 module
            modelspec.append(d)

    # first module that takes input='pred' should take ctx['input_name']
    # instead. can't hard-code in keywords, since we don't know which
    # keyword will be first. and can't assume that it will be module[0]
    # because those might be state manipulations
    for i in range(len(modelspec)):
        if modelspec[i]['fn_kwargs'].get('i') == 'pred':
            log.info("Setting modelspec[%d] input to %s", i, input_name)
            modelspec[i]['fn_kwargs']['i'] = input_name
            break

    # insert metadata, if provided
    if rec is not None:
        if 'cellids' in meta:
            # cellids list already exists. keep it.
            pass
        elif 'cellid' in meta:
            meta['cellids'] = [meta['cellid']]
        elif ((rec['resp'].shape[0] > 1)
              and (type(rec.meta['cellid']) is list)):
            # guess cellids list from rec.meta
            meta['cellids'] = rec.meta['cellid']

    meta['input_name'] = input_name
    meta['output_name'] = output_name

    # for modelspec object, we know that meta must exist, so just update
    modelspec.meta.update(meta)

    if modelspec.meta.get('modelpath') is None:
        destination = get_default_savepath(modelspec)
        modelspec.meta['modelpath'] = destination
        modelspec.meta['figurefile'] = os.path.join(destination,
                                                    'figure.0000.png')

    return modelspec
Beispiel #12
0
def pick_best_phi(modelspec=None,
                  est=None,
                  val=None,
                  est_list=None,
                  val_list=None,
                  criterion='mse_fit',
                  metric_fn='nems.metrics.mse.nmse',
                  jackknifed_fit=False,
                  keep_n=1,
                  IsReload=False,
                  **context):
    """
    For models with multiple fits (eg, based on multiple initial conditions),
     find the best prediction for the recording provided (presumably est data,
     though possibly something held out)

    For jackknifed fits, pick the best fit for each jackknife set. so a F x J modelspec
     is reduced in size to 1 x J. Models are tested with est data for that jackknife.
     This has only been tested with est recording, which is likely should be used.

    The per-fit score used for ranking (lower is better) is, in order of
    preference:
      * modelspec.meta['loss'] if present (one value per fit index);
      * for the default metric_fn with criterion 'mse_fit', the sum over
        output channels of ``standard_correlation(...).meta[criterion]``;
      * otherwise the metric function named by ``metric_fn`` on est.
    Scores are summed over cells in ``est_list`` and divided by the count
    of contributing values.

    :param modelspec: should have fit_count>0
    :param est: view_count should match fit_count, ie,
                after generate_prediction is called
    :param val: validation recording (used by the 'mse_fit' criterion)
    :param est_list: per-cell est recordings (default: [est])
    :param val_list: per-cell val recordings (default: [val])
    :param keep_n: number of best fits to keep per jackknife
    :param context: extra context stuff for xforms compatibility.
    :return: {} if IsReload or fit_count <= 1, else a dict with the
        reduced 'modelspec' (fit_count == keep_n) and 'best_random_idx'
    :raises RuntimeError: if any averaged score is exactly 0
    """
    if IsReload:
        return {}
    if modelspec.fit_count <= 1:
        return {}

    if est_list is None:
        est_list = [est]
        val_list = [val]

    jack_count = modelspec.jack_count
    fit_count = modelspec.fit_count
    best_idx = np.zeros(jack_count, dtype=int)
    modelspec.set_cell(0)

    # for each jackknife set, figure out best fit. save it to new_raw
    new_raw = np.zeros((modelspec.cell_count, keep_n, jack_count), dtype='O')
    for j in range(jack_count):
        x = None
        n = 0

        # support for multi-cell, len(est_list)>1 fits
        for cell_idx, this_est in enumerate(est_list):
            modelspec.cell_index = cell_idx
            modelspec.jack_index = j

            if 'loss' in modelspec.meta:
                # quick ranking: use a vector of loss values, one per fitidx stored in meta
                if x is None:
                    x = modelspec.meta['loss']
                else:
                    x = x + modelspec.meta['loss']
                n = n + 1
                if j > 1:
                    # previously dropped into pdb here, which hangs or
                    # aborts non-interactive jobs.
                    # NOTE(review): j == 1 is not flagged -- confirm intent.
                    log.warning(
                        'Not supported yet! jackknife + multifit using tf loss to select'
                    )

            elif (metric_fn == 'nems.metrics.mse.nmse') and (criterion
                                                             == 'mse_fit'):
                # for backwards compatibility, run the below code to compute
                # the metric specified by criterion.
                this_val = val_list[cell_idx]
                tx = np.zeros(fit_count)
                for fitidx in range(fit_count):
                    traw = modelspec.raw[cell_idx, fitidx, j]
                    tmodelspec = ms.ModelSpec(traw)
                    fit_est = tmodelspec.evaluate(this_est)
                    fit_val = tmodelspec.evaluate(this_val)
                    tmodelspec = standard_correlation(est=fit_est,
                                                      val=fit_val,
                                                      modelspec=tmodelspec)
                    # sum performance across output channels (if more than
                    # one output); divided by n below
                    tx[fitidx] = tmodelspec.meta[criterion].sum(axis=0)
                n += tmodelspec.meta[criterion].shape[0]
                # accumulate across cells like the other branches
                x = tx if x is None else x + tx

            else:
                # unclear if this works
                fn = nems.utils.lookup_fn_at(metric_fn)
                tx = []
                for fitidx in range(fit_count):
                    traw = modelspec.raw[cell_idx, fitidx, j]
                    tmodelspec = ms.ModelSpec(traw)
                    fit_est = tmodelspec.evaluate(this_est)
                    tx.append(fn(fit_est, **context))
                n = n + tx[0].shape[0]
                tx = np.concatenate(tx, axis=1).mean(axis=0)
                if x is None:
                    x = tx
                else:
                    x = x + tx

        x = x / n

        if np.any(x == 0):
            raise RuntimeError('Error: At least one elements of modelspec.meta[\'loss\'] is 0. '\
                    'This will cause this fn to pick that phi, but it\'s probably not. Bug in the code?')
        tx = x.copy()
        for k in range(keep_n):
            best_idx[j] = int(np.nanargmin(tx))
            new_raw[:, k, j] = modelspec.raw[:, best_idx[j], j]

            log.info(
                'jack %d: %d/%d best phi (fit_idx=%d) has fit_metric=%.5f', j,
                k + 1, keep_n, best_idx[j], tx[best_idx[j]])
            # push the chosen fit to the worst score so the next pass
            # picks the next-best one
            tx[best_idx[j]] = np.nanmax(tx)

    for cell_index in range(new_raw.shape[0]):
        new_raw[cell_index, 0,
                0][0]['meta'] = modelspec.raw[cell_index, 0,
                                              0][0]['meta'].copy()
    new_modelspec = ms.ModelSpec(new_raw)
    new_modelspec.set_cell(0)
    new_modelspec.set_jack(0)
    new_modelspec.meta['rand_' + criterion] = x

    return {'modelspec': new_modelspec, 'best_random_idx': best_idx}
Beispiel #13
0
def predict_and_summarize_for_all_modelspec(
        modelspec=None,
        est=None,
        val=None,
        est_list=None,
        val_list=None,
        criterion=('r_test', 'se_test', 'mse_test', 'll_test'),
        metric_fn='nems.metrics.mse.nmse',
        jackknifed_fit=False,
        keep_n=1,
        IsReload=False,
        **context):
    """
    Summarize prediction performance of every fit in a multi-fit modelspec.

    For models with multiple fits (eg, based on multiple initial conditions),
    each fit is evaluated on the est/val recordings, ``standard_correlation``
    computes the performance metrics, and the metrics named in ``criterion``
    are stored in ``modelspec.meta`` so the fits can be compared. Unlike
    ``pick_best_phi``, no fit is selected or discarded: ``fit_count`` is left
    unchanged.

    :param modelspec: ModelSpec with fit_count > 1. Only jackknife index 0 is
        read, so jackknifed modelspecs (jack_count > 1) are rejected.
    :param est: estimation recording, used when ``est_list`` is None.
    :param val: validation recording, used when ``val_list`` is None.
    :param est_list: one est recording per cell (site) index.
    :param val_list: one val recording per cell (site) index; must have the
        same length as ``est_list``.
    :param criterion: name (or sequence of names) of ``modelspec.meta``
        entries produced by ``standard_correlation`` to summarize.
    :param metric_fn: accepted for signature compatibility with
        ``pick_best_phi``; not used here.
    :param jackknifed_fit: not used here (signature compatibility).
    :param keep_n: not used here (signature compatibility).
    :param IsReload: if True, do nothing (xforms reload).
    :param context: extra context stuff for xforms compatibility.
    :return: {} if IsReload or fit_count <= 1. Otherwise {'modelspec': modelspec}.
        For each criterion ``c``, ``modelspec.meta['rand_' + c]`` holds an
        array of length fit_count: the sum of ``meta[c]`` over output
        channels of all sites, divided by the total channel count. For
        single-site models, ``modelspec.meta['rand_' + c + '_all']``
        additionally holds column 0 of ``meta[c]`` for each fit, shape
        (n_resp_channels, fit_count).
    :raises NotImplementedError: if modelspec.jack_count > 1.
    :raises ValueError: if est_list and val_list differ in length.
    """
    if IsReload:
        return {}
    if modelspec.fit_count <= 1:
        return {}

    if modelspec.jack_count > 1:
        # Only jackknife index 0 is evaluated below, so the other jackknife
        # sets would be silently ignored. (This used to drop into pdb.)
        raise NotImplementedError(
            'Not supported yet! jackknife + multifit using tf loss to select')

    if isinstance(criterion, str):
        criterion = [criterion]

    # Default est_list / val_list independently so a caller-supplied list is
    # never overwritten.
    if est_list is None:
        est_list = [est]
    if val_list is None:
        val_list = [val]
    if len(est_list) != len(val_list):
        # zip() below would otherwise silently drop the extra recordings
        raise ValueError(
            f'est_list ({len(est_list)}) and val_list ({len(val_list)}) '
            'must have the same length')

    if modelspec.cell_count > 1:
        log.warning(
            'Not tested for mega models. I think it will work but confirm. LAS'
        )

    fit_count = modelspec.fit_count
    single_cell = modelspec.cell_count == 1
    modelspec.set_cell(0)  # cell is a misnomer here, means site
    n_chans = val_list[0]['resp'].shape[0]
    means = [np.zeros(fit_count) for _ in criterion]
    if single_cell:
        indiv_cells = [np.zeros((n_chans, fit_count)) for _ in criterion]

    n = 0  # total number of output channels accumulated into ``means``
    for cell_idx, (this_est, this_val) in enumerate(zip(est_list, val_list)):
        # loop through each "cell_idx" (site index for megamodels)
        for fitidx in range(fit_count):
            # jackknife index is always 0 (jack_count > 1 rejected above)
            tmodelspec = ms.ModelSpec(modelspec.raw[cell_idx, fitidx, 0])
            this_est = tmodelspec.evaluate(this_est)
            this_val = tmodelspec.evaluate(this_val)
            tmodelspec = standard_correlation(est=this_est,
                                              val=this_val,
                                              modelspec=tmodelspec)
            for i, criterion_ in enumerate(criterion):
                # sum over output channels; divided by n after the loop
                means[i][fitidx] += tmodelspec.meta[criterion_].sum(axis=0)
                if single_cell:
                    indiv_cells[i][:, fitidx] = \
                        tmodelspec.meta[criterion_][:, 0]
        n += tmodelspec.meta[criterion[0]].shape[0]

    for i, criterion_ in enumerate(criterion):
        means[i] = means[i] / n
        modelspec.meta['rand_' + criterion_] = means[i]
        if single_cell:
            modelspec.meta['rand_' + criterion_ + '_all'] = indiv_cells[i]

    return {'modelspec': modelspec}
Beispiel #14
0
def fit_tf_init(modelspec,
                est: recording.Recording,
                nl_init: str = 'tf',
                IsReload: bool = False,
                **kwargs) -> dict:
    """Inits a model using tf.

    Two-stage initialization:

    1. Build a temporary model from the modules up to and including the later
       of the last ``relu`` and the last ``levelshift`` module. rdt_gain,
       state_dc_gain and state_gain modules are excluded; stp modules are
       kept but passed to ``fit_tf`` as frozen layers. If the last module of
       the temporary model is a levelshift whose size matches the response,
       its ``level`` is set to the mean of the response signal. The temporary
       model is fit with ``fit_tf`` and its modules are copied back into
       ``modelspec``.
    2. Depending on ``nl_init``:
       - 'skip': return after stage 1.
       - 'scipy': run ``init_static_nl`` on the full model.
       - anything else: find the first static nonlinearity (double_exponential,
         relu, logistic_sigmoid, saturated_rectifier) among the last two
         modules, initialize it and fit it with ``fit_tf`` while every other
         module is frozen.

    :param modelspec: ModelSpec to initialize; its modules are replaced in
        place by the stage-1 fit results.
    :param est: estimation recording.
    :param nl_init: 'tf' (default), 'scipy' or 'skip'.
    :param IsReload: if True, do nothing (xforms reload).
    :param kwargs: passed through to ``fit_tf``.
    :return: {} on reload; otherwise a dict whose 'modelspec' entry is the
        initialized modelspec (the stage-2 ``fit_tf`` result dict when a
        static NL was fit with tf).
    """
    if IsReload:
        return {}

    def first_substring_index(strings, substring):
        """Return the index of the first string containing ``substring``, or None."""
        return next(
            (i for i, string in enumerate(strings) if substring in string),
            None)

    # Module dicts are named ``mod`` (not ``ms``) so they do not shadow the
    # file-level ``ms`` module alias.
    fn_names = [mod['fn'] for mod in modelspec]
    n_mods = len(modelspec)

    # Positions counted from the END of the model (0 == last module), so the
    # smallest one belongs to whichever of last relu / last levelshift is later.
    relu_from_end = first_substring_index(reversed(fn_names), 'relu')
    lvl_from_end = first_substring_index(reversed(fn_names), 'levelshift')
    # The n_mods - 1 entry keeps the list non-empty; it maps back to module 0.
    # NOTE(review): so a model with no relu/levelshift only fits module 0 in
    # stage 1 -- confirm this fallback is intended.
    from_end = [i for i in (relu_from_end, lvl_from_end, n_mods - 1)
                if i is not None]
    up_to_idx = n_mods - 1 - min(from_end)
    log.info('up_to_idx=%d (%s)', up_to_idx, modelspec[up_to_idx]['fn'])

    # make up_to_idx an exclusive bound for slicing
    up_to_idx += 1
    # exclude the following from the init
    exclude = ['rdt_gain', 'state_dc_gain', 'state_gain']
    # keep these in the init model, but pass them to fit_tf as frozen layers
    freeze = ['stp']
    init_idxes = [idx for idx, fn in enumerate(fn_names[:up_to_idx])
                  if not any(sub in fn for sub in exclude)]

    freeze_idxes = []
    # make a temp modelspec
    temp_ms = mslib.ModelSpec()
    log.info('Creating temporary model for init with:')
    for idx in init_idxes:
        # TODO: handle 'merge_channels'
        mod = copy.deepcopy(modelspec[idx])
        log.info(f'{mod["fn"]}')

        # set a trailing levelshift to the mean response
        if idx == init_idxes[-1] and 'levelshift' in mod['fn']:
            output_name = modelspec.meta.get('output_name', 'resp')
            try:
                mean_resp = np.nanmean(est[output_name].as_continuous(),
                                       axis=1,
                                       keepdims=True)
            except NotImplementedError:
                # as_continuous only available for RasterizedSignal
                mean_resp = np.nanmean(
                    est[output_name].rasterize().as_continuous(),
                    axis=1,
                    keepdims=True)
            if len(mod['phi']['level'][:]) == len(mean_resp):
                log.info(
                    f'Fixing "{mod["fn"]}" to: {mean_resp.flatten()[0]:.3f}')
                mod['phi']['level'][:] = mean_resp

        temp_ms.append(mod)
        if any(fr in mod['fn'] for fr in freeze):
            freeze_idxes.append(len(temp_ms) - 1)

    log.info(
        'Running first init fit: model up to first lvl/relu without stp/gain.')
    log.debug('freeze_idxes: %s', freeze_idxes)
    filepath = Path(modelspec.meta['modelpath']) / 'init_part1'
    temp_ms = fit_tf(temp_ms,
                     est,
                     freeze_layers=freeze_idxes,
                     filepath=filepath,
                     **kwargs)['modelspec']

    # put back into original modelspec
    for ms_idx, temp_mod in zip(init_idxes, temp_ms):
        modelspec[ms_idx] = temp_mod

    if nl_init == 'skip':
        return {'modelspec': modelspec}

    if nl_init == 'scipy':
        # pre-fit static NL if it exists
        modelspec = init_static_nl(est=est, modelspec=modelspec)['modelspec']
        # TODO : Initialize relu in some intelligent way?
        log.info('finished fit_tf_init, fit_idx=%d/%d',
                 modelspec.fit_index + 1, modelspec.fit_count)
        return {'modelspec': modelspec}

    # init the static nl
    init_static_nl_mapping = {
        'double_exponential': initializers.init_dexp,
        'relu': None,  # relu has a custom init below
        'logistic_sigmoid': initializers.init_logsig,
        'saturated_rectifier': initializers.init_relsat,
    }
    # Find the first static nl in the last two layers; if present, init it,
    # fit it with every other layer frozen, and return. Clamp the start index
    # so a one-module model is not given layer index -1.
    first_nl_idx = max(len(modelspec) - 2, 0)
    for idx, mod in enumerate(modelspec[-2:], first_nl_idx):
        for init_static_layer, init_fn in init_static_nl_mapping.items():
            if init_static_layer not in mod['fn']:
                continue
            log.info(f'Initializing static nl "{mod["fn"]}" at layer #{idx}')
            if init_static_layer == 'relu':
                mod['phi']['offset'][:] = -0.1
            else:
                modelspec = init_fn(est, modelspec, nl_mode=4)

            static_nl_idx_not = [i for i in range(len(modelspec)) if i != idx]
            log.info('Running second init fit: all frozen but static nl.')

            # don't overwrite the phis in the modelspec
            kwargs['use_modelspec_init'] = True
            filepath = Path(modelspec.meta['modelpath']) / 'init_part2'
            return fit_tf(modelspec,
                          est,
                          freeze_layers=static_nl_idx_not,
                          filepath=filepath,
                          **kwargs)

    # no static nl to init
    return {'modelspec': modelspec}
Beispiel #15
0
def fit_tf_init(modelspec,
                est: recording.Recording,
                est_list: typing.Union[None, list] = None,
                nl_init: str = 'tf',
                IsReload: bool = False,
                isolate_NL: bool = False,
                skip_init: bool = False,
                up_to_idx=None,
                **kwargs) -> dict:
    """Inits a model using tf.

    Two-stage initialization that supports multi-cell (megamodel) modelspecs,
    with one est recording per cell index:

    1. For every cell index, build a temporary model from the modules up to
       and including ``up_to_idx`` (by default the later of the last ``relu``
       and the last ``levelshift`` module). rdt_gain modules are excluded and
       stp modules are passed to the fitter as frozen layers. The mean
       response is written into at most one module, whichever matches first
       in module order: column 0 of a state_dc module's ``d``, or the
       ``level`` of a levelshift that is the last module of the temporary
       model (each only when its channel count matches). The temporary model
       is fit with ``fit_tf`` (``fit_tf_iterate`` when cell_count > 1).
    2. For every cell index the fitted modules are copied back into
       ``modelspec``; then, depending on ``nl_init``:
       - 'skip': nothing more.
       - 'scipy': run ``init_static_nl``.
       - anything else: initialize the first static nonlinearity
         (double_exponential, relu, logistic_sigmoid, saturated_rectifier)
         among the last two modules and refit with ``fit_tf``. The frozen
         layers are: ``freeze_layers`` from kwargs if given; else all but the
         NL when ``isolate_NL``; else the first ``shared_count`` layers for
         multi-cell models; else none.

    :param modelspec: ModelSpec to initialize; modified in place.
    :param est: estimation recording, used when ``est_list`` is None.
    :param est_list: one est recording per cell index.
    :param nl_init: 'tf' (default), 'scipy' or 'skip'.
    :param IsReload: if True, do nothing (xforms reload).
    :param isolate_NL: in stage 2, freeze every layer except the static NL.
    :param skip_init: if True, do nothing.
    :param up_to_idx: index of the last module considered for stage 1;
        computed as described above when None.
    :param kwargs: passed through to ``fit_tf`` / ``fit_tf_iterate``. A
        ``freeze_layers`` entry is removed and merged with the stp indices.
    :return: {} if IsReload or skip_init, otherwise {'modelspec': modelspec}.
    """

    if IsReload or skip_init:
        return {}

    def first_substring_index(strings, substring):
        """Return the index of the first string containing ``substring``, or None."""
        return next(
            (i for i, string in enumerate(strings) if substring in string),
            None)

    def find_static_nl(mspec):
        """Return (layer_idx, module, nl_name, init_fn) for the first static
        nonlinearity among the last two modules of ``mspec``, or None."""
        init_static_nl_mapping = {
            'double_exponential': initializers.init_dexp,
            'relu': None,  # relu has a custom init
            'logistic_sigmoid': initializers.init_logsig,
            'saturated_rectifier': initializers.init_relsat,
        }
        # clamp so a one-module model is not given layer index -1
        first_idx = max(len(mspec) - 2, 0)
        for layer_idx, mod in enumerate(mspec[-2:], first_idx):
            for nl_name, init_fn in init_static_nl_mapping.items():
                if nl_name in mod['fn']:
                    return layer_idx, mod, nl_name, init_fn
        return None

    modelspec.cell_index = 0
    if est_list is None:
        est_list = [est]

    # Module dicts are named ``mod`` (not ``ms``) so they do not shadow the
    # file-level ``ms`` module alias.
    fn_names = [mod['fn'] for mod in modelspec]
    if up_to_idx is None:
        # Positions counted from the END of the model (0 == last module); the
        # smallest belongs to whichever of last relu / last levelshift is later.
        relu_from_end = first_substring_index(reversed(fn_names), 'relu')
        lvl_from_end = first_substring_index(reversed(fn_names), 'levelshift')
        # NOTE(review): the len - 1 entry keeps the list non-empty but maps
        # back to module 0, so a model with no relu/levelshift only fits
        # module 0 in stage 1 -- confirm this fallback is intended.
        from_end = [
            i for i in (relu_from_end, lvl_from_end, len(modelspec) - 1)
            if i is not None
        ]
        up_to_idx = len(modelspec) - 1 - min(from_end)

    log.info('up_to_idx=%d (%s)', up_to_idx, modelspec[up_to_idx]['fn'])

    # make up_to_idx an exclusive bound for slicing
    up_to_idx += 1
    # exclude the following from the init
    exclude = ['rdt_gain']  # , 'state_dc_gain', 'state_gain', 'sdexp']
    freeze = ['stp']

    init_idxes = [idx for idx, fn in enumerate(fn_names[:up_to_idx])
                  if not any(sub in fn for sub in exclude)]
    freeze_idxes = []

    # make a temp modelspec
    log.info('Creating temporary model for init')
    temp_ms = mslib.ModelSpec(cell_count=modelspec.cell_count)
    temp_ms.shared_count = modelspec.shared_count
    for cell_idx in range(modelspec.cell_count):
        modelspec.cell_index = cell_idx
        temp_ms.cell_index = cell_idx
        est = est_list[cell_idx]

        output_name = modelspec.meta.get('output_name', 'resp')
        try:
            mean_resp = np.nanmean(est[output_name].as_continuous(),
                                   axis=1,
                                   keepdims=True)
        except NotImplementedError:
            # as_continuous only available for RasterizedSignal
            mean_resp = np.nanmean(
                est[output_name].rasterize().as_continuous(),
                axis=1,
                keepdims=True)
        mean_added = False  # mean_resp goes into at most one module per cell

        for idx in init_idxes:
            # TODO: handle 'merge_channels'
            mod = copy.deepcopy(modelspec[idx])
            log.info(f'{mod["fn"]}')

            # fix dc to mean_resp if present
            if not mean_added:
                if (idx == init_idxes[-1] and 'levelshift' in mod['fn']
                        and len(mod['phi']['level']) == len(mean_resp)):
                    log.info(
                        f'Fixing "{mod["fn"]}" to: {mean_resp.flatten()[0]:.3f}')
                    mod['phi']['level'][:] = mean_resp
                    mean_added = True
                elif ('state_dc' in mod['fn']
                      and mod['phi']['d'].shape[0] == len(mean_resp)):
                    log.info(f'Fixing "{mod["fn"]}[d][:,0]" to mean_resp')
                    mod['phi']['d'][:, [0]] = mean_resp
                    mean_added = True

            temp_ms.append(mod)
            if any(fr in mod['fn'] for fr in freeze):
                freeze_idxes.append(len(temp_ms) - 1)

    # important to reset cell_index to zero???
    est = est_list[0]
    temp_ms.cell_index = 0

    log.info(
        'Running first init fit: model up to last lvl/relu without stp/gain.')
    log.debug('freeze_idxes: %s', freeze_idxes)
    filepath = Path(modelspec.meta['modelpath']) / 'init_part1'

    # freeze_layers can't be passed to the fitter twice: pop it from kwargs
    # and take the union with the stp indices.
    force_freeze = kwargs.pop('freeze_layers', None)
    if force_freeze is not None:
        freeze_idxes = list(set(force_freeze) | set(freeze_idxes))

    fit_fn = fit_tf_iterate if modelspec.cell_count > 1 else fit_tf
    temp_ms = fit_fn(temp_ms,
                     est,
                     est_list=est_list,
                     freeze_layers=freeze_idxes,
                     filepath=filepath,
                     **kwargs)['modelspec']

    for cell_idx in range(temp_ms.cell_count):
        log.info(
            "***********************************************************************************"
        )
        log.info(
            f"****   fit_tf_init, fit_index={modelspec.fit_index} cell_index={cell_idx} fitting output NL    ****"
        )
        modelspec.cell_index = cell_idx
        temp_ms.cell_index = cell_idx
        est = est_list[cell_idx]

        # Put back into original modelspec. meta is saved before and restored
        # after the copy, presumably because a replaced module carries it.
        meta_save = modelspec.meta.copy()
        for ms_idx, temp_mod in zip(init_idxes, temp_ms):
            modelspec[ms_idx] = temp_mod

        if modelspec.fit_count > 1:
            # per-fit stage-1 results; the arrays are created on fit 0
            if modelspec.fit_index == 0:
                meta_save['n_epochs'] = np.zeros(modelspec.fit_count)
                meta_save['loss'] = np.zeros(modelspec.fit_count)
            meta_save['n_epochs'][
                modelspec.fit_index] = temp_ms.meta['n_epochs']
            meta_save['loss'][modelspec.fit_index] = temp_ms.meta['loss']

        modelspec.meta.update(meta_save)

        if nl_init == 'skip':
            # don't return here: every cell index must be processed
            continue

        if nl_init == 'scipy':
            # pre-fit static NL if it exists
            modelspec = init_static_nl(est=est,
                                       modelspec=modelspec)['modelspec']
            # TODO : Initialize relu in some intelligent way?
            log.info('finished fit_tf_init, fit_idx=%d/%d',
                     modelspec.fit_index + 1, modelspec.fit_count)
            continue

        found = find_static_nl(modelspec)
        if found is None:
            continue
        idx, mod, nl_name, init_fn = found

        log.info(f'Initializing static nl "{mod["fn"]}" at layer #{idx}')
        if nl_name == 'relu':
            mod['phi']['offset'][:] = -0.1
        else:
            modelspec = init_fn(est, modelspec, nl_mode=4)

        if force_freeze is not None:
            log.info(f'Running second init fit: force_freeze: {force_freeze}.')
            static_nl_idx_not = force_freeze
        elif isolate_NL:
            log.info('Running second init fit: all frozen but static nl.')
            static_nl_idx_not = [i for i in range(len(modelspec)) if i != idx]
        elif modelspec.cell_count > 1:
            log.info(
                f'Titan model: freezing first {modelspec.shared_count} layers.'
            )
            static_nl_idx_not = list(range(modelspec.shared_count))
        else:
            log.info(
                'Running second init fit: not frozen but coarser tolerance.')
            static_nl_idx_not = []

        # don't overwrite the phis in the modelspec
        kwargs['use_modelspec_init'] = True
        filepath = Path(modelspec.meta['modelpath']) / 'init_part2'
        modelspec = fit_tf(modelspec,
                           est,
                           freeze_layers=static_nl_idx_not,
                           filepath=filepath,
                           **kwargs)['modelspec']

    modelspec.cell_index = 0
    return {'modelspec': modelspec}
Beispiel #16
0
def _prefit_to_target(rec,
                      modelspec,
                      analysis_function,
                      target_module,
                      extra_exclude=None,
                      fitter=scipy_minimize,
                      metric=None,
                      fit_kwargs=None):
    """Rough-fit the front of a modelspec, up to a target module.

    Removes all modules from the modelspec that come after the last
    occurrence of the target module (merge_channels modules are always kept),
    performs a rough fit on the shortened modelspec, then copies the fitted
    modules back into a copy of the full modelspec and returns it.

    Modules whose ``fn`` contains any string in ``extra_exclude`` are held
    fixed during the fit: their phi (set to the prior mean if missing) is
    moved into ``fn_kwargs``, and they are not copied back afterwards.
    Non-excluded levelshift modules get ``level`` initialized to the mean of
    the output signal (``modelspec.meta['output_name']``, default 'resp').

    :param rec: recording to fit.
    :param modelspec: ModelSpec; not modified (a deep copy is used).
    :param analysis_function: called as ``analysis_function(rec, tmodelspec,
        fitter=..., fit_kwargs=...[, metric=...])``.
    :param target_module: substring, or list/tuple of substrings, matched
        against each module's ``fn``.
    :param extra_exclude: substrings of ``fn`` for modules to hold fixed.
    :param fitter: passed to ``analysis_function``.
    :param metric: passed to ``analysis_function`` only when not None.
    :param fit_kwargs: passed to ``analysis_function`` (default: {}).
    :return: the updated copy of ``modelspec``; returned unfitted when no
        module matches ``target_module``.
    """
    # None defaults avoid sharing one mutable object across calls
    if extra_exclude is None:
        extra_exclude = []
    if fit_kwargs is None:
        fit_kwargs = {}

    # preserve input modelspec
    modelspec = copy.deepcopy(modelspec)

    if not isinstance(target_module, (list, tuple)):
        target_module = [target_module]
    # figure out last modelspec module to fit (exclusive bound). Don't break:
    # use the last occurrence of the target module.
    target_i = None
    for i, m in enumerate(modelspec.modules):
        if any(t in m['fn'] for t in target_module):
            target_i = i + 1

    if target_i is None:
        log.info(
            "target_module: {} not found in modelspec.".format(target_module))
        return modelspec
    log.info("target_module: {0} found at modelspec[{1}].".format(
        target_module, target_i - 1))

    # identify any excluded modules and take them out of temp modelspec
    # that will be fit here
    exclude_idx = []
    tmodelspec = ms.ModelSpec()
    for i in range(len(modelspec)):
        m = copy.deepcopy(modelspec[i])

        # First matching pattern only: excluding a module twice would re-init
        # phi from an already-deleted prior and duplicate its index.
        fn = next((f for f in extra_exclude if f in m['fn']), None)
        excluded = fn is not None
        if excluded:
            if m.get('phi') is None:
                m = priors.set_mean_phi([m])[0]  # Inits phi
                log.info('Mod %d (%s) fixing phi to prior mean', i, fn)
            else:
                log.info('Mod %d (%s) fixing phi', i, fn)

            # hold the module fixed by moving its parameters into fn_kwargs
            m['fn_kwargs'].update(m['phi'])
            del m['phi']
            m.pop('prior', None)
            exclude_idx.append(i)

        if 'relu' in m['fn']:
            log.info('found relu')

        elif 'levelshift' in m['fn'] and not excluded:
            # an excluded levelshift has no phi any more; keep its fixed level
            output_name = modelspec.meta.get('output_name', 'resp')
            try:
                mean_resp = np.nanmean(rec[output_name].as_continuous(),
                                       axis=1,
                                       keepdims=True)
            except NotImplementedError:
                # as_continuous only available for RasterizedSignal
                mean_resp = np.nanmean(
                    rec[output_name].rasterize().as_continuous(),
                    axis=1,
                    keepdims=True)
            log.info('Mod %d (%s) initializing level to %s mean %.3f', i,
                     m['fn'], output_name, mean_resp.flat[0])
            log.info('resp has %d channels', len(mean_resp))
            m['phi']['level'][:] = mean_resp

        if i < target_i or 'merge_channels' in m['fn']:
            tmodelspec.append(m)

    # fit the subset of modules; metric is only forwarded when given
    fit_args = {'fitter': fitter, 'fit_kwargs': fit_kwargs}
    if metric is not None:
        fit_args['metric'] = metric
    tmodelspec = analysis_function(rec, tmodelspec, **fit_args)
    # exact type check on purpose: only unwrap a plain list
    if type(tmodelspec) is list:
        # backward compatibility
        tmodelspec = tmodelspec[0]

    # reassemble the full modelspec with updated phi values from tmodelspec
    excluded_set = set(exclude_idx)
    for i in range(target_i):
        if i not in excluded_set:
            modelspec[i] = tmodelspec[i]

    return modelspec
Beispiel #17
0
def fit_population_slice(rec, modelspec, slice=0, fit_set=None,
                         analysis_function=analysis.fit_basic,
                         metric=metrics.nmse,
                         fitter=scipy_minimize, fit_kwargs=None):

    """
    Fit one response channel ("slice") of a population model.

    Modified from prefit_mod_subset. The modules selected by ``fit_set`` are
    cut down to the parameters belonging to response channel ``slice``; all
    other modules are frozen (phi moved into fn_kwargs). The reduced model is
    fit to that single response channel. The fitted values are written back
    into a copy of ``modelspec`` only if ``metric`` did not get worse.

    Rows kept for each phi entry of a fitted module:
    - if the module has a ``bank_count`` fn_kwarg: the ``slice``-th block of
      ``shape[0] / bank_count`` rows (bank_count is then set to 1);
    - else, if ``shape[0]`` equals the number of response channels: row
      ``slice``.

    slice: int
        response channel to fit (must be < number of 'resp' channels)
    fit_set: list
        list of module indices, or list of substrings matched against module
        ``fn`` names, selecting the modules to fit
    fit_kwargs: dict
        passed to analysis_function (default: {})

    Returns the (possibly updated) copy of ``modelspec``.
    Raises ValueError if slice is out of range, fit_set is None, or a
    parameter of a fitted module cannot be sliced.
    """
    if fit_kwargs is None:
        fit_kwargs = {}

    # preserve input modelspec
    modelspec = copy.deepcopy(modelspec)

    slice_count = rec['resp'].shape[0]
    if slice >= slice_count:
        raise ValueError(f"Slice {slice} >= slice_count {slice_count}")

    if fit_set is None:
        raise ValueError("fit_set list of module indices must be specified")

    if fit_set and isinstance(fit_set[0], (int, np.integer)):
        fit_idx = list(fit_set)
    else:
        fit_idx = [i for i, m in enumerate(modelspec)
                   for fn in fit_set if fn in m['fn']]
    # Sorted and de-duplicated: the fast-eval cut (first_idx) below assumes
    # model order, and each fitted module must be sliced exactly once.
    fit_idx = sorted({int(i) for i in fit_idx})
    if not fit_idx:
        log.info('No modules matching fit_set for slice fit')
        return modelspec
    fit_lookup = set(fit_idx)

    # build temp modelspec with the fitted modules sliced to one channel
    tmodelspec = ms.ModelSpec()
    sliceinfo = {}  # module index -> {phi key: rows kept for the fit}
    for i, m in enumerate(modelspec):
        m = copy.deepcopy(m)
        # need to have phi in place
        if not m.get('phi'):
            log.info('Initializing phi for module %d (%s)', i, m['fn'])
            m = priors.set_mean_phi([m])[0]  # Inits phi

        if i in fit_lookup:
            # Read bank_count once: it is reset to 1 below, and every phi key
            # must be sliced with the module's original bank_count.
            has_banks = 'bank_count' in m.get('fn_kwargs', {})
            bank_count = m['fn_kwargs']['bank_count'] if has_banks else None
            s = {}
            for key, value in m['phi'].items():
                log.debug('Slicing %d (%s) key %s chan %d for fit',
                          i, m['fn'], key, slice)

                # keep only sliced channel(s)
                if has_banks:
                    filters_per_bank = int(value.shape[0] / bank_count)
                    rows = np.arange(slice * filters_per_bank,
                                     (slice + 1) * filters_per_bank)
                elif value.shape[0] == slice_count:
                    rows = [slice]
                else:
                    raise ValueError(f"Not sure how to slice {m['fn']} {key}")
                m['phi'][key] = value[rows, ...]
                s[key] = rows
            if has_banks and s:
                m['fn_kwargs']['bank_count'] = 1

            # record info about how sliced this module parameter
            sliceinfo[i] = s

        tmodelspec.append(m)

    # freeze every module that is not being fit
    for i in range(len(modelspec)):
        if i in fit_lookup:
            continue
        m = tmodelspec[i]
        log.debug('Freezing phi for module %d (%s)', i, m['fn'])

        m['fn_kwargs'].update(m['phi'])
        m['phi'] = {}
        tmodelspec[i] = m

    # generate temp recording with only responses of interest
    temp_rec = rec.copy()
    slice_chans = [temp_rec['resp'].chans[slice]]
    temp_rec['resp'] = temp_rec['resp'].extract_channels(slice_chans)

    # Evaluate the frozen modules before the first fitted one once, then
    # enable fast_eval (presumably so they are not re-run for every fit step).
    first_idx = fit_idx[0]
    if first_idx > 0:
        temp_rec = ms.evaluate(temp_rec, tmodelspec, stop=first_idx)
        tmodelspec = tmodelspec.copy()
        tmodelspec.fast_eval_on(rec=temp_rec, subset=fit_idx)

    # fit the subset of modules
    temp_rec = ms.evaluate(temp_rec, tmodelspec)
    error_before = metric(temp_rec)
    tmodelspec = analysis_function(temp_rec, tmodelspec, fitter=fitter,
                                   metric=metric, fit_kwargs=fit_kwargs)
    tmodelspec.fast_eval_off()
    temp_rec = ms.evaluate(temp_rec, tmodelspec)
    error_after = metric(temp_rec)
    d_error = error_before - error_after
    if d_error < 0:
        log.info("dError (%.6f - %.6f) = %.6f worse. not updating modelspec",
                 error_before, error_after, d_error)
        return modelspec

    log.info("dError (%.6f - %.6f) = %.6f better. updating modelspec",
             error_before, error_after, d_error)

    # Reassemble the full modelspec with updated phi values. tmodelspec keeps
    # every module, so its indices line up with modelspec.
    for mod_idx in fit_idx:
        m = copy.deepcopy(modelspec[mod_idx])
        # need to have phi in place
        if not m.get('phi'):
            log.info('Intializing phi for module %d (%s)', mod_idx, m['fn'])
            m = priors.set_mean_phi([m])[0]  # Inits phi
        for key, value in tmodelspec[mod_idx]['phi'].items():
            # write the fitted rows back into the full-size parameter;
            # Ellipsis (not ':') so 1-D parameters work too
            m['phi'][key][sliceinfo[mod_idx][key], ...] = value

        modelspec[mod_idx] = m

    return modelspec
Beispiel #18
0
def _rank_fits(metric_values, keep_n):
    """
    Rank candidate fits by a metric where lower is better.

    :param metric_values: list or ndarray of per-fit metric values. It is
        copied into a flat float array, so the caller's object is never
        modified.
    :param keep_n: number of fits to return.
    :return: list of ``keep_n`` ``(fit_idx, metric_value)`` tuples, best first.
        ``metric_value`` is the original (unmodified) value for that fit.
        NaN metrics rank after every finite value. Each fit is picked at most
        once while unpicked fits remain; if ``keep_n`` exceeds the number of
        fits, the surplus entries are ``(0, inf)``.
    """
    raw = np.array(metric_values, dtype=float).ravel()
    # NaN means the metric could not be computed; never prefer such a fit
    # (np.argmin would otherwise return the first NaN).
    scores = np.where(np.isnan(raw), np.inf, raw)
    ranked = []
    for _ in range(keep_n):
        idx = int(np.argmin(scores))
        ranked.append((idx, float(raw[idx]) if np.isfinite(scores[idx]) or
                       np.isnan(raw[idx]) else float(scores[idx])))
        # exclude this fit from later picks (robust to ties, unlike max())
        scores[idx] = np.inf
    return ranked


def pick_best_phi(modelspec=None,
                  est=None,
                  val=None,
                  criterion='mse_fit',
                  metric_fn='nems.metrics.mse.nmse',
                  jackknifed_fit=False,
                  keep_n=1,
                  IsReload=False,
                  **context):
    """
    For models with multiple fits (eg, based on multiple initial conditions),
     find the best prediction for the recording provided (presumably est data,
     though possibly something held out).

    For jackknifed fits, pick the best fit(s) for each jackknife set, so an
     F x J modelspec is reduced in size to keep_n x J. Models are tested with
     the est data for that jackknife. This has only been tested with the est
     recording, which is likely what should be used.

    :param modelspec: should have fit_count>0
    :param est: view_count should match fit_count, ie,
                after generate_prediction is called
    :param val: validation recording, passed to generate_prediction and (for
                the default metric) standard_correlation.
    :param criterion: key in modelspec.meta used for ranking when metric_fn is
                the default nmse; also names the stored 'rand_<criterion>' meta.
    :param metric_fn: dotted path of a metric function, looked up with
                nems.utils.lookup_fn_at and called as fn(view, **context) for
                each est view when not using the default nmse/'mse_fit' path.
                Lower values are treated as better.
    :param jackknifed_fit: forwarded to generate_prediction.
    :param keep_n: number of best fits to keep per jackknife.
    :param IsReload: if True, do nothing and return {}.
    :param context: extra context stuff for xforms compatibility.
    :return: dict with 'modelspec' (fit_count==keep_n) and 'best_random_idx'
             (per-jackknife index of the single best fit).
    """
    if IsReload:
        return {}

    # generate prediction for each jack and fit
    new_est, new_val = generate_prediction(est,
                                           val,
                                           modelspec,
                                           jackknifed_fit=jackknifed_fit)

    jack_count = modelspec.jack_count
    fit_count = modelspec.fit_count
    best_idx = np.zeros(jack_count, dtype=int)
    new_raw = np.zeros((1, keep_n, jack_count), dtype='O')

    # for each jackknife set, figure out the best fit(s)
    for j in range(jack_count):
        # views i * jack_count + j (i = fit index) belong to jackknife j
        view_range = [i * jack_count + j for i in range(fit_count)]
        this_est = new_est.view_subset(view_range)
        this_modelspec = modelspec.copy(jack_index=j)

        if metric_fn == 'nems.metrics.mse.nmse' and criterion == 'mse_fit':
            # for backwards compatibility, run the below code to compute metric
            # specified by criterion.
            scored_modelspec = standard_correlation(est=this_est,
                                                    val=new_val,
                                                    modelspec=this_modelspec)
            # average performance across output channels (if more than one output)
            x = np.mean(scored_modelspec.meta[criterion], axis=0)
        else:
            fn = nems.utils.lookup_fn_at(metric_fn)
            x = [fn(e, **context) for e in this_est.views()]

        ranked = _rank_fits(x, keep_n)
        # report the single best fit, not the last of the keep_n picked
        best_idx[j] = ranked[0][0]
        for n, (fit_idx, fit_metric) in enumerate(ranked):
            new_raw[0, n, j] = modelspec.raw[0, fit_idx, j]
            log.info(
                'jack %d: %d/%d best phi (fit_idx=%d) has fit_metric=%.5f', j,
                n + 1, keep_n, fit_idx, fit_metric)

    new_raw[0, 0, 0][0]['meta'] = modelspec.meta.copy()
    new_modelspec = ms.ModelSpec(new_raw)
    # NOTE(review): x holds only the last jackknife's per-fit metrics.
    new_modelspec.meta['rand_' + criterion] = x

    return {'modelspec': new_modelspec, 'best_random_idx': best_idx}