Beispiel #1
0
def generate_prediction(est, val, modelspecs):
    """Evaluate estimation and validation data against each modelspec.

    Parameters
    ----------
    est, val : recording or list of recordings
        If ``val`` is a list (the jackknifing case), ``est`` and ``val`` are
        paired element-wise with ``modelspecs``. Otherwise each is a single
        recording that is evaluated against every modelspec.
    modelspecs : list
        Modelspecs passed one at a time to ``ms.evaluate``.

    Returns
    -------
    (new_est, new_val) : tuple of lists
        Evaluated recordings. In the jackknifing case ``new_val`` is a
        one-element list holding the result of
        ``recording.jackknife_inverse_merge``.
    """
    jackknifed = type(val) is list

    if not jackknifed:
        # ms.evaluate only makes a shallow copy of its recording, so
        # evaluating one recording against several modelspecs would yield a
        # list of pointers to the same recording. Hand each modelspec its own
        # copy unless there is only one modelspec.
        if len(modelspecs) == 1:
            est, val = [est], [val]
        else:
            est = [est.copy() for _ in modelspecs]
            val = [val.copy() for _ in modelspecs]

    new_est = []
    for spec, data in zip(modelspecs, est):
        new_est.append(ms.evaluate(data, spec))

    new_val = []
    for spec, data in zip(modelspecs, val):
        new_val.append(ms.evaluate(data, spec))

    if jackknifed:
        new_val = [recording.jackknife_inverse_merge(new_val)]

    return new_est, new_val
Beispiel #2
0
def before_and_after_scatter(rec,
                             modelspec,
                             idx,
                             sig_name='pred',
                             compare='resp',
                             smoothing_bins=False,
                             mod_name='Unknown',
                             xlabel1=None,
                             xlabel2=None,
                             ylabel1=None,
                             ylabel2=None):
    '''
    Prepare two scatter plots comparing a signal with a reference signal
    just before and just after the module at position idx of the modelspec.

    Arguments:
    ----------
    rec : recording object
        The dataset to evaluate. It is not modified by this function.

    modelspec : list of dicts
        The transformations to perform.

    idx : int
        Index of the module of interest within the modelspec.

    sig_name : str
        Name of the signal examined before/after the module.

    compare : str
        Name of the signal in rec that sig_name is plotted against.

    smoothing_bins, xlabel1, xlabel2, ylabel1, ylabel2 :
        Forwarded unchanged to plot_scatter.

    mod_name : str
        Module name used in the plot titles.

    Returns:
    --------
    fn1, fn2 : partial objects wrapping plot_scatter for the "before" and
        "after" plots respectively. Nothing is drawn until they are called.
    '''
    # HACK: shouldn't hardcode 'stim', might be named something else
    #       or not present at all. Need to figure out a better solution
    #       for special case of idx = 0
    if idx == 0:
        # Can't have anything before index 0, so use input stimulus.
        # Work on a copy so that renaming it below does not rename the
        # recording object owned by the caller.
        before = rec.copy()
        before_sig = rec['stim']
        before.name = '**stim'
    else:
        before = ms.evaluate(rec, modelspec, start=None, stop=idx)
        before_sig = before[sig_name]

    # compute correlation for pre-module before it's over-written
    corr1 = nm.corrcoef(before, pred_name=sig_name, resp_name=compare)

    # now evaluate next module step
    after = ms.evaluate(before.copy(), modelspec, start=idx, stop=idx + 1)
    after_sig = after[sig_name]
    corr2 = nm.corrcoef(after, pred_name=sig_name, resp_name=compare)

    compare_to = rec[compare]
    title1 = '{} vs {} before {}'.format(sig_name, compare, mod_name)
    title2 = '{} vs {} after {}'.format(sig_name, compare, mod_name)
    # TODO: These are coming out the same, but that seems unlikely
    text1 = "r = {0:.5f}".format(corr1)
    text2 = "r = {0:.5f}".format(corr2)

    fn1 = partial(plot_scatter,
                  before_sig,
                  compare_to,
                  title=title1,
                  smoothing_bins=smoothing_bins,
                  xlabel=xlabel1,
                  ylabel=ylabel1,
                  text=text1)
    fn2 = partial(plot_scatter,
                  after_sig,
                  compare_to,
                  title=title2,
                  smoothing_bins=smoothing_bins,
                  xlabel=xlabel2,
                  ylabel=ylabel2,
                  text=text2)

    return fn1, fn2
Beispiel #3
0
def generate_prediction_sets(est, val, modelspecs):
    """Evaluate paired est/val recording lists against their modelspecs.

    Parameters
    ----------
    est, val : list of recordings
        Jackknife sets; element ``i`` of each list is evaluated with
        ``modelspecs[i]``.
    modelspecs : list
        Modelspecs passed one at a time to ``ms.evaluate``.

    Returns
    -------
    (new_est, new_val) : tuple of lists
        The evaluated recordings, in the same order as the inputs.

    Raises
    ------
    TypeError
        If ``val`` is not a list. Previously this case only printed a message
        and then failed with an UnboundLocalError on return.
    """
    if not isinstance(val, list):
        raise TypeError('val and est must be lists')

    new_est = [ms.evaluate(d, m) for m, d in zip(modelspecs, est)]
    new_val = [ms.evaluate(d, m) for m, d in zip(modelspecs, val)]

    return new_est, new_val
Beispiel #4
0
def before_and_after(rec,
                     modelspec,
                     sig_name,
                     ax=None,
                     title=None,
                     idx=0,
                     channels=0,
                     xlabel='Time',
                     ylabel='Value',
                     **options):
    '''
    Draws a timeseries of one signal as it looks immediately before and
    immediately after a single step of the modelspec.

    Arguments:
    ----------
    rec : recording object
        The dataset to use. See nems/recording.py.

    modelspec : list of dicts
        The transformations to perform. See nems/modelspec.py.

    sig_name : str
        Name of the signal in the evaluated recording to plot.

    idx : int
        Index of the step of interest in the modelspec. For idx = 0 the
        "before" trace is the input signal named by
        modelspec[0]['fn_kwargs']['i'] instead of sig_name.

    ax, title, channels, xlabel, ylabel, **options :
        Forwarded to timeseries_from_signals.

    Returns:
    --------
    None
    '''
    if idx != 0:
        # Run every step preceding idx on a copy of the recording.
        preceding = ms.evaluate(rec.copy(), modelspec, start=None, stop=idx)
        before = preceding[sig_name]
        before.name += ' before'
    else:
        # Nothing precedes step 0, so show that step's declared input.
        first_input = modelspec[0]['fn_kwargs']['i']
        before = rec[first_input].copy()
        before.name += ' before**'

    # NOTE(review): the single step is applied to `rec` itself, not to the
    # output of the preceding steps -- confirm this is intended for idx > 0.
    stepped = ms.evaluate(rec, modelspec, start=idx, stop=idx + 1)
    after = stepped[sig_name].copy()
    after.name += ' after'

    traces = [before, after]
    timeseries_from_signals(traces,
                            channels=channels,
                            xlabel=xlabel,
                            ylabel=ylabel,
                            ax=ax,
                            title=title,
                            **options)
Beispiel #5
0
def generate_prediction(est, val, modelspecs):
    """Evaluate estimation and validation data against each modelspec.

    Parameters
    ----------
    est, val : recording or list of recordings
        If ``val`` is a list (the jackknifing case), ``est`` and ``val`` are
        paired element-wise with ``modelspecs``. Otherwise each is a single
        recording that is evaluated against every modelspec.
    modelspecs : list
        Modelspecs passed one at a time to ``ms.evaluate``.

    Returns
    -------
    (new_est, new_val) : tuple of lists
        Evaluated recordings. In the jackknifing case ``new_val`` is a
        one-element list holding the result of
        ``recording.jackknife_inverse_merge``.
    """
    if isinstance(val, list):
        # ie, if jackknifing
        new_est = [ms.evaluate(d, m) for m, d in zip(modelspecs, est)]
        new_val = [ms.evaluate(d, m) for m, d in zip(modelspecs, val)]
        new_val = [recording.jackknife_inverse_merge(new_val)]
    else:
        # ms.evaluate reportedly makes only a shallow copy of the recording
        # it is given, so evaluating one recording against several
        # modelspecs can leave every result pointing at the same data.
        # Give each modelspec its own copy; a single modelspec needs none.
        if len(modelspecs) > 1:
            est_inputs = [est.copy() for _ in modelspecs]
            val_inputs = [val.copy() for _ in modelspecs]
        else:
            est_inputs = [est for _ in modelspecs]
            val_inputs = [val for _ in modelspecs]

        new_est = [ms.evaluate(d, m) for m, d in zip(modelspecs, est_inputs)]
        new_val = [ms.evaluate(d, m) for m, d in zip(modelspecs, val_inputs)]

    return new_est, new_val
Beispiel #6
0
def _reduced_param_pred(mspec, rec, idx, param):
    """Return the prediction with the '<param>_mod' value of one module
    replaced by a copy of its plain '<param>' value.

    The substitution is made in ``mspec.copy()`` at module index ``idx``;
    the result of evaluating ``rec`` with that modelspec is returned as the
    transposed continuous 'pred' array.
    """
    # NOTE(review): if mspec.copy() is a shallow copy, the phi dict edited
    # below is shared with the caller's modelspec -- confirm.
    reduced_spec = mspec.copy()
    plain_value = reduced_spec[idx]['phi']['%s' % param].copy()
    reduced_spec[idx]['phi']['%s_mod' % param] = plain_value

    evaluated = ms.evaluate(rec, reduced_spec)
    return evaluated['pred'].as_continuous().T
Beispiel #7
0
def nl_scatter(rec,
               modelspec,
               idx,
               sig_name='pred',
               compare='resp',
               smoothing_bins=False,
               cursor_time=None,
               xlabel1=None,
               ylabel1=None,
               **options):
    '''
    Scatter the input of the module at modelspec[idx] against a comparison
    signal, annotated with their correlation coefficient.

    The recording is masked first when it carries a 'mask' signal, otherwise
    copied. For idx > 0 the modules before idx are evaluated and sig_name is
    read from the result; for idx = 0 the 'stim' signal is used instead.
    When cursor_time is given, a vertical red line is drawn at the x value
    of the corresponding time bin. Returns None.
    '''
    # HACK: shouldn't hardcode 'stim', might be named something else
    #       or not present at all. Need to figure out a better solution
    #       for special case of idx = 0
    has_mask = 'mask' in rec.signals.keys()
    before = rec.apply_mask() if has_mask else rec.copy()

    if idx != 0:
        before = ms.evaluate(before, modelspec, start=None, stop=idx)
        before_sig = before[sig_name]
    else:
        # Can't have anything before index 0, so use input stimulus
        sig_name = 'stim'
        before_sig = before['stim']
        before.name = '**stim'

    # compute correlation for pre-module before it's over-written
    single_channel = before[sig_name].shape[0] == 1
    if single_channel:
        corr1 = nm.corrcoef(before, pred_name=sig_name, resp_name=compare)
    else:
        corr1 = 0
        log.warning('corr coef expects single-dim predictions')

    compare_to = before[compare]

    # Turn the module's function path into a readable title.
    module = modelspec[idx]
    mod_name = module['fn'].replace('nems.modules.', '')
    for separator in ('.', '_'):
        mod_name = mod_name.replace(separator, ' ')
    mod_name = mod_name.title()

    ax = plot_scatter(before_sig,
                      compare_to,
                      title=mod_name,
                      smoothing_bins=smoothing_bins,
                      xlabel=xlabel1,
                      ylabel=ylabel1,
                      text="r = {0:.5f}".format(np.mean(corr1)),
                      module=module,
                      **options)

    if cursor_time is None:
        return

    # Mark the x position of the requested time point.
    tbin = int(cursor_time * rec[sig_name].fs)
    x = before_sig.as_continuous()[0, tbin]
    ylim = ax.get_ylim()
    ax.plot([x, x], ylim, 'r-')
Beispiel #8
0
def model_per_time(ctx):
    """
    Plot a two-row figure for the first validation recording in ctx.

    The masked recording ctx['val'][0] is evaluated with ctx['modelspecs'][0].
    The top panel is drawn by state_vars_timeseries and the bottom panel by
    state_vars_psth_all for the "REFERENCE" epoch. Returns None.
    """
    masked = ctx['val'][0].apply_mask()
    modelspec = ctx['modelspecs'][0]
    rec = ms.evaluate(masked, modelspec)

    plt.figure()

    top_ax = plt.subplot(2, 1, 1)
    state_vars_timeseries(rec, modelspec, ax=top_ax)

    bottom_ax = plt.subplot(2, 1, 2)
    psth_kwargs = {
        'psth_name': 'resp',
        'psth_name2': 'pred',
        'state_sig': 'state_raw',
        'colors': None,
        'channel': None,
        'decimate_by': 1,
        'ax': bottom_ax,
        'files_only': True,
        'modelspec': modelspec,
    }
    state_vars_psth_all(rec, "REFERENCE", **psth_kwargs)
Beispiel #9
0
def stp_sigmoid_pred_matched(cellid, batch, modelname, LN, include_phi=True):
    """Scatter the LN model's pre-nonlinearity prediction against the full
    model's prediction, colored by the relative effect of the stp module.

    Two models are loaded for the same cell and batch: ``modelname`` and
    ``LN``. The color of each point is (after - before) / (after + before),
    computed from the channel-averaged 'pred' signal on either side of the
    'stp' module of ``modelname``. When ``include_phi`` is true the stp
    module's phi values are printed beside the plot.

    Returns the matplotlib figure.
    """
    _, ctx = xhelp.load_model_xform(cellid, batch, modelname)
    _, ln_ctx = xhelp.load_model_xform(cellid, batch, LN)

    modelspec = ctx['modelspec']
    modelspec.recording = ctx['val']
    val = ctx['val'].apply_mask()

    ln_modelspec = ln_ctx['modelspec']
    ln_modelspec.recording = ln_ctx['val']
    ln_val = ln_ctx['val'].apply_mask()

    # y axis: final prediction of the model that includes stp
    pred_after_NL = val['pred'].as_continuous().flatten()
    # x axis: LN model prediction with its last module left out (no stp)
    ln_partial = ms.evaluate(ln_val, ln_modelspec, stop=-1)
    pred_before_NL = ln_partial['pred'].as_continuous().flatten()

    stp_idx = find_module('stp', modelspec)
    rec_before_stp = ms.evaluate(val, modelspec, stop=stp_idx)
    rec_after_stp = ms.evaluate(val, modelspec, stop=stp_idx + 1)

    def _channel_mean(evaluated):
        return evaluated['pred'].as_continuous().mean(axis=0).flatten()

    pred_before_stp = _channel_mean(rec_before_stp)
    pred_after_stp = _channel_mean(rec_after_stp)
    difference = pred_after_stp - pred_before_stp
    total = pred_after_stp + pred_before_stp
    stp_effect = difference / total

    fig = plt.figure()
    plt.scatter(pred_before_NL,
                pred_after_NL,
                c=stp_effect,
                s=2,
                alpha=0.75,
                cmap=plt.get_cmap('plasma'))
    plt.title(cellid)
    plt.xlabel('pred in (no stp)')
    plt.ylabel('pred out (with stp)')

    if include_phi:
        # Leave room on the right for the parameter listing.
        phi_lines = [
            '%s:  %.4E' % (name, value)
            for name, value in modelspec.phi[stp_idx].items()
        ]
        fig.text(0.775, 0.9, '\n'.join(phi_lines), va='top', ha='left')
        plt.subplots_adjust(right=0.775, left=0.075)
    plt.colorbar()

    return fig
Beispiel #10
0
def evaluate_context(ctx,
                     rec_key='val',
                     rec_idx=0,
                     mspec_idx=0,
                     start=None,
                     stop=None):
    """Evaluate one recording from a context with one of its modelspecs.

    Parameters
    ----------
    ctx : dict
        Context holding a list of recordings under ``rec_key`` and a list of
        modelspecs under 'modelspecs'.
    rec_key : str
        Key of the recording list to use.
    rec_idx : int
        Index of the recording within ``ctx[rec_key]``. This was previously
        ignored in favor of a hard-coded 0; the default keeps that behavior.
    mspec_idx : int
        Index of the modelspec within ``ctx['modelspecs']``.
    start, stop :
        Forwarded to ``ms.evaluate``.

    Returns
    -------
    The result of ``ms.evaluate`` for the selected recording and modelspec.
    """
    rec = ctx[rec_key][rec_idx]
    mspec = ctx['modelspecs'][mspec_idx]
    return ms.evaluate(rec, mspec, start=start, stop=stop)
Beispiel #11
0
def reconsitute_rec(batch, cellid_list, modelname):
    '''
    Takes a group of single cell recordings (from cells of a population
    recording), including their model predictions, and builds a recording
    with signals containing the responses and predictions of all the cells in
    the population. This makes the recordings compatible with downstream
    dispersion analysis or any analysis working with signals of neuronal
    populations.

    :param batch: int batch number
    :param cellid_list: [str, str ...] list of cell IDs
    :param modelname: str. model name
    :return: NEMS Recording object
    :raises ValueError: if no result files are found for the given arguments
    '''

    result_paths = _get_result_paths(batch, cellid_list, modelname)

    cell_resp_dict = {}
    cell_pred_dict = {}
    rec = None

    for filepath in result_paths:
        # use modelspecs to predict the response of resp
        _, ctx = xforms.load_analysis(filepath=filepath,
                                      eval_model=False,
                                      only=slice(0, 2, 1))
        modelspecs = ctx['modelspecs'][0]
        cellid = modelspecs[0]['meta']['cellid']
        # recording containing signal for resp and pred
        rec = ms.evaluate(ctx['rec'].copy(), modelspecs)

        # Hold and organize the raw data, keeping track of the cell for later
        # concatenation. In PointProcess signals _data is already a dict,
        # thus the use of update; in Rasterized signals _data is a matrix,
        # thus the need to assign a key.
        cell_resp_dict.update(rec['resp']._data)
        cell_pred_dict[cellid] = rec['pred']._data

    if rec is None:
        # Without at least one loaded recording there is no template to build
        # the population recording from.
        raise ValueError(
            'no result files found for batch {}, modelname {}'.format(
                batch, modelname))

    # Create a new population recording. Pull stim from the last single cell,
    # create signals from the metadata of the last resp/pred signals and the
    # stacked data for all cells, and make the channel metadata consistent
    # with the cells contained.
    resp_chans = list(cell_resp_dict)
    pop_resp = rec['resp']._modified_copy(data=cell_resp_dict,
                                          chans=resp_chans,
                                          nchans=len(resp_chans))

    pred_chans = list(cell_pred_dict)
    stack_data = np.concatenate(list(cell_pred_dict.values()), axis=0)
    pop_pred = rec['pred']._modified_copy(data=stack_data,
                                          chans=pred_chans,
                                          nchans=len(pred_chans))

    reconstituted_recording = rec.copy()

    reconstituted_recording['resp'] = pop_resp
    reconstituted_recording['pred'] = pop_pred
    # State signals belong to the last single cell only; drop them if present.
    for state_name in ('state', 'state_raw'):
        if state_name in reconstituted_recording.signals.keys():
            del reconstituted_recording.signals[state_name]

    return reconstituted_recording
Beispiel #12
0
def dynamic_sigmoid_range(cellid,
                          batch,
                          modelname,
                          plot=True,
                          show_min_max=True):
    """Compute the output curves of the final module of a model at the
    smallest and largest values of its 'ctpred' signal.

    The last module's phi is split into base parameters (keys without
    '_mod') and their '_mod' counterparts. Each parameter is interpolated as
    ``base + (mod - base) * ct`` with ``ct`` set to the minimum and maximum
    of 'ctpred', and ``_double_exponential`` is applied to the prediction
    that precedes the last module.

    Parameters
    ----------
    cellid, batch, modelname :
        Passed to ``xhelp.load_model_xform`` to load the fitted model.
    plot : bool
        If true, scatter both curves (low in blue, high in red) in a new
        figure.
    show_min_max : bool
        If true (and ``plot`` is true), ring the points at which 'ctpred'
        reaches its minimum and maximum.

    Returns
    -------
    (pred_before_dsig, low_out, high_out) : tuple of flattened arrays
    """
    xfspec, ctx = xhelp.load_model_xform(cellid, batch, modelname)
    modelspec = ctx['modelspec']
    modelspec.recording = ctx['val']
    val = ctx['val'].apply_mask()
    ctpred = val['ctpred'].as_continuous().flatten()
    val_before_dsig = ms.evaluate(val, modelspec, stop=-1)
    pred_before_dsig = val_before_dsig['pred'].as_continuous().flatten()
    lows = {k: v for k, v in modelspec[-1]['phi'].items() if '_mod' not in k}
    highs = {k[:-4]: v for k, v in modelspec[-1]['phi'].items() if '_mod' in k}
    # A parameter without a '_mod' counterpart does not vary with ctpred.
    for k in lows:
        if k not in highs:
            highs[k] = lows[k]
    # re-sort keys to make sure they're in the same order
    lows = {k: lows[k] for k in sorted(lows)}
    highs = {k: highs[k] for k in sorted(highs)}

    ctmax_val = np.max(ctpred)
    ctmin_val = np.min(ctpred)
    ctmax_idx = np.argmax(ctpred)
    ctmin_idx = np.argmin(ctpred)

    # Snapshot the base/mod values before overwriting both dicts with the
    # interpolated parameters.
    thetas = list(lows.values())
    theta_mods = list(highs.values())
    for t, t_m, k in zip(thetas, theta_mods, list(lows.keys())):
        lows[k] = t + (t_m - t) * ctmin_val
        highs[k] = t + (t_m - t) * ctmax_val

    low_out = _double_exponential(pred_before_dsig, **lows).flatten()
    high_out = _double_exponential(pred_before_dsig, **highs).flatten()

    if plot:
        plt.figure()
        plt.scatter(pred_before_dsig, low_out, color='blue', s=0.7, alpha=0.6)
        plt.scatter(pred_before_dsig, high_out, color='red', s=0.7, alpha=0.6)

        if show_min_max:
            max_pred = pred_before_dsig[ctmax_idx]
            min_pred = pred_before_dsig[ctmin_idx]
            # facecolors must be the string 'none' to draw hollow rings;
            # None just selects the default fill and hides the points below.
            plt.scatter(min_pred,
                        low_out[ctmin_idx],
                        facecolors='none',
                        edgecolors='blue',
                        s=60)
            plt.scatter(max_pred,
                        high_out[ctmax_idx],
                        facecolors='none',
                        edgecolors='red',
                        s=60)

    return pred_before_dsig, low_out, high_out
Beispiel #13
0
def dynamic_sigmoid_differences(batch,
                                modelname,
                                hist_bins=60,
                                test_limit=None,
                                save_path=None,
                                load_path=None,
                                use_quartiles=False,
                                avg_bin_count=20):
    """Histogram, across the cells of a batch, of the normalized difference
    between model output at low and at high 'ctpred'.

    For each cell the time bins are split by 'ctpred' (around its median, or
    using the 25th/75th percentiles when ``use_quartiles`` is set). The final
    prediction is averaged within bins of the prediction that precedes the
    last module, and the per-cell ratio is the nan-mean of
    (low - high) / (|low| + |high|).

    Ratios are computed for ``cellids[:test_limit]`` and optionally saved to
    ``save_path``; if ``load_path`` is given they are loaded from there
    instead. Returns None.
    """
    if load_path is not None:
        ratios = np.load(load_path)
    else:
        cellids = nd.get_batch_cells(batch, as_list=True)
        per_cell = []
        for cellid in cellids[:test_limit]:
            _, ctx = xhelp.load_model_xform(cellid, batch, modelname)
            val = ctx['val'].apply_mask()
            ctpred = val['ctpred'].as_continuous().flatten()
            pred_after = val['pred'].as_continuous().flatten()
            partial_val = ms.evaluate(val, ctx['modelspec'], stop=-1)
            pred_before = partial_val['pred'].as_continuous().flatten()

            median_ct = np.nanmedian(ctpred)
            if not use_quartiles:
                low_mask = ctpred < median_ct
                high_mask = ctpred >= median_ct
            else:
                lower_q = np.percentile(ctpred, 25)
                upper_q = np.percentile(ctpred, 75)
                low_mask = (ctpred >= lower_q) & (ctpred < median_ct)
                high_mask = ctpred >= upper_q

            # TODO: do some kind of binning here since the two vectors
            # don't actually overlap in x axis
            _, bin_masks = _binned_xvar(pred_before, avg_bin_count)
            low_avg = _binned_yavg(pred_after, low_mask, bin_masks)
            high_avg = _binned_yavg(pred_after, high_mask, bin_masks)

            normalized = (low_avg - high_avg) / (np.abs(low_avg) +
                                                 np.abs(high_avg))
            per_cell.append(np.nanmean(normalized))

        ratios = np.array(per_cell)
        if save_path is not None:
            np.save(save_path, ratios)

    plt.figure()
    plt.hist(ratios,
             bins=hist_bins,
             color=[wsu_gray_light],
             edgecolor='black',
             linewidth=1)
    plt.xlabel('(low - high)/(|low| + |high|)')
    plt.ylabel('cell count')
    plt.title("difference of low-contrast output and high-contrast output\n"
              "positive means low-contrast has higher firing rate on average")
0
def before_and_after_signal(rec, modelspec, idx, sig_name='pred'):
    """Return deep copies of a signal just before and just after the module
    at ``modelspec[idx]``, named 'before' and 'after'.

    For ``idx == 0`` the "before" signal is ``rec['stim']``; otherwise it is
    ``sig_name`` after evaluating the modules preceding ``idx``.
    """
    # HACK: shouldn't hardcode 'stim', might be named something else
    #       or not present at all. Need to figure out a better solution
    #       for special case of idx = 0
    if idx == 0:
        # Can't have anything before index 0, so use input stimulus
        before = rec
        source_sig = rec['stim']
    else:
        before = ms.evaluate(rec, modelspec, start=None, stop=idx)
        source_sig = before[sig_name]

    before_sig = copy.deepcopy(source_sig)
    before_sig.name = 'before'

    # Apply only the module of interest to a copy of the "before" recording.
    stepped = ms.evaluate(before.copy(), modelspec, start=idx, stop=idx+1)
    after_sig = copy.deepcopy(stepped[sig_name])
    after_sig.name = 'after'

    return before_sig, after_sig
Beispiel #15
0
def generate_prediction(est, val, modelspecs):
    """Evaluate est/val data against each modelspec and blank out the
    validation prediction wherever the 'mask' signal is zero.

    Parameters
    ----------
    est, val : recording or list of recordings
        If ``val`` is a list (the jackknifing case), ``est`` and ``val`` are
        paired element-wise with ``modelspecs``. Otherwise each is a single
        recording that is evaluated against every modelspec.
    modelspecs : list
        Modelspecs passed one at a time to ``ms.evaluate``.

    Returns
    -------
    (new_est, new_val) : tuple of lists
        Evaluated recordings. In the jackknifing case ``new_val`` is a
        one-element list holding the result of
        ``recording.jackknife_inverse_merge``.
    """
    jackknifed = type(val) is list

    if not jackknifed:
        # ms.evaluate only makes a shallow copy of its recording, so
        # evaluating one recording against several modelspecs would yield a
        # list of pointers to the same recording. Hand each modelspec its own
        # copy unless there is only one modelspec.
        if len(modelspecs) == 1:
            est, val = [est], [val]
        else:
            est = [est.copy() for _ in modelspecs]
            val = [val.copy() for _ in modelspecs]

    new_est, new_val = [], []
    for spec, est_rec, val_rec in zip(modelspecs, est, val):
        est_out = ms.evaluate(est_rec, spec)
        val_out = ms.evaluate(val_rec, spec)

        if 'mask' in val_out.signals.keys():
            # nan-out periods outside of mask
            mask = val_out['mask'].as_continuous()
            masked_pred = val_out['pred'].as_continuous().copy()
            masked_pred[..., mask[0, :] == 0] = np.nan
            val_out['pred'] = val_out['pred']._modified_copy(masked_pred)

        new_est.append(est_out)
        new_val.append(val_out)

    if jackknifed:
        new_val = [recording.jackknife_inverse_merge(new_val)]

    return new_est, new_val
Beispiel #16
0
def dynamic_sigmoid_distribution(cellid,
                                 batch,
                                 modelname,
                                 sample_every=10,
                                 alpha=0.1):
    """Plot the family of output curves of a model's final module as its
    parameters vary with the 'ctpred' signal.

    The last module's phi is split into base parameters (keys without
    '_mod') and their '_mod' counterparts; each parameter is interpolated as
    ``base + (mod - base) * ct``. One faint black curve is drawn for the
    'ctpred' value at every ``sample_every``-th index, plus the curves for
    the nan-max (red) and nan-min (blue) of 'ctpred'. Returns None.
    """
    _, ctx = xhelp.load_model_xform(cellid, batch, modelname)
    modelspec = ctx['modelspec']
    val = ctx['val'].apply_mask()
    modelspec.recording = val

    partial_val = ms.evaluate(val, modelspec, stop=-1)
    pred_before_dsig = partial_val['pred'].as_continuous().T
    ctpred = partial_val['ctpred'].as_continuous().T

    final_phi = modelspec[-1]['phi']
    lows = {k: v for k, v in final_phi.items() if '_mod' not in k}
    highs = {k[:-4]: v for k, v in final_phi.items() if '_mod' in k}
    # A parameter without a '_mod' counterpart does not vary with ctpred.
    for name in lows:
        if name not in highs:
            highs[name] = lows[name]
    # re-sort keys to make sure they're in the same order
    lows = {name: lows[name] for name in sorted(lows)}
    highs = {name: highs[name] for name in sorted(highs)}
    names = list(lows.keys())
    thetas = list(lows.values())
    theta_mods = list(highs.values())

    plt.figure()
    sample_count = int(len(pred_before_dsig) / sample_every)
    for i in range(sample_count):
        try:
            ts = {
                k: t + (t_m - t) * ctpred[i * sample_every]
                for t, t_m, k in zip(thetas, theta_mods, names)
            }
            y = _double_exponential(pred_before_dsig, **ts)
            plt.scatter(pred_before_dsig,
                        y,
                        color='black',
                        alpha=alpha,
                        s=0.01)
        except IndexError:
            # Will happen on last attempt if array wasn't evenly divisible
            # by sample_every
            pass

    ct_max = np.nanmax(ctpred)
    ct_min = np.nanmin(ctpred)
    t_max = {}
    t_min = {}
    for t, t_m, k in zip(thetas, theta_mods, names):
        t_max[k] = t + (t_m - t) * ct_max
        t_min[k] = t + (t_m - t) * ct_min

    max_out = _double_exponential(pred_before_dsig, **t_max)
    min_out = _double_exponential(pred_before_dsig, **t_min)
    plt.scatter(pred_before_dsig, max_out, color='red', s=0.1)
    plt.scatter(pred_before_dsig, min_out, color='blue', s=0.1)
Beispiel #17
0
def mod_output(rec,
               modelspec,
               sig_name='pred',
               ax=None,
               title=None,
               idx=0,
               channels=0,
               xlabel='Time',
               ylabel='Value',
               **options):
    '''
    Plots a time series of one or more signals as they exist after a given
    step of the modelspec.

    Arguments:
    ----------
    rec : recording object
        The dataset to use. See nems/recording.py.

    modelspec : list of dicts
        The transformations to perform. See nems/modelspec.py.

    sig_name : str or list of strings
        Name(s) of the signal(s) in the evaluated recording to plot.

    idx : int
        An index into the modelspec. The recording is evaluated with
        stop=idx+1, so the signals are plotted as they exist after step idx.

    ax, title, channels, xlabel, ylabel, **options :
        Forwarded to timeseries_from_signals.

    Returns:
    --------
    ax : whatever timeseries_from_signals returns (the axis containing
        the plot)
    '''
    names = [sig_name] if type(sig_name) is str else sig_name

    evaluated = ms.evaluate(rec, modelspec, stop=idx + 1)
    if 'mask' in evaluated.signals.keys():
        evaluated = evaluated.apply_mask()

    return timeseries_from_signals([evaluated[name] for name in names],
                                   channels=channels,
                                   xlabel=xlabel,
                                   ylabel=ylabel,
                                   ax=ax,
                                   title=title,
                                   **options)
Beispiel #18
0
def init_dexp(rec, modelspec):
    """
    Choose initial phi values for the double_exponential module, based on
    the prediction produced by the modules that precede it.

    Parameters
    ----------
    rec : recording object
        Must provide 'resp'; 'pred' is read after evaluating the modules
        before the double_exponential module.
    modelspec : list of dicts
        Searched for the first module whose 'fn' contains
        'double_exponential'. Its 'phi' is replaced IN PLACE.

    Returns
    -------
    modelspec : the same object that was passed in. It is returned unchanged
        if no double_exponential module is found, or if that module is the
        first one (there is then no preceding module to generate a
        prediction from).
    """
    target_module = 'double_exponential'
    target_i = next(
        (i for i, m in enumerate(modelspec) if target_module in m['fn']),
        None)

    # Compare against None explicitly: index 0 is a valid match.
    if target_i is None:
        log.info(
            "target_module: {} not found in modelspec.".format(target_module))
        return modelspec

    log.info("target_module: {0} found at modelspec[{1}].".format(
        target_module, target_i))

    if target_i == 0:
        # As before, nothing is initialized in this case; it is now reported
        # accurately instead of as "not found".
        log.info("no modules precede {}, skipping initialization.".format(
            target_module))
        return modelspec

    fit_portion = modelspec[:target_i]

    # generate prediction from modules preceding dexp
    rec = ms.evaluate(rec, fit_portion)
    resp = rec['resp'].as_continuous()
    pred = rec['pred'].as_continuous()
    # keep only the samples where both signals are finite
    keepidx = np.isfinite(resp) & np.isfinite(pred)
    resp = resp[keepidx]
    pred = pred[keepidx]

    # choose phi s.t. dexp starts as almost a straight line
    # phi=[max_out min_out slope mean_in]
    meanr = np.nanmean(resp)
    stdr = np.nanstd(resp)
    modelspec[target_i]['phi'] = {
        'amplitude': stdr * 8,
        'base': meanr - stdr * 4,
        'kappa': np.log(np.std(pred) / 10),
        'shift': np.mean(pred),
    }
    log.info(modelspec[target_i])

    return modelspec
Beispiel #19
0
def dynamic_sigmoid_pred_matched(cellid, batch, modelname, include_phi=True):
    """Scatter the prediction before the final module against the final
    prediction, colored by the 'ctpred' signal.

    When ``include_phi`` is true, the final module's phi values and the
    differences between consecutive (even, odd) phi entries are printed
    beside the plot. Returns the matplotlib figure.
    """
    _, ctx = xhelp.load_model_xform(cellid, batch, modelname)
    modelspec = ctx['modelspec']
    modelspec.recording = ctx['val']
    val = ctx['val'].apply_mask()

    ctpred = val['ctpred'].as_continuous().flatten()
    # HACK
    # this really shouldn't happen.. but for some reason some  of the
    # batch 263 cells are getting nans, so temporary fix.
    ctpred[np.isnan(ctpred)] = 0

    pred_after_dsig = val['pred'].as_continuous().flatten()
    partial_val = ms.evaluate(val, modelspec, stop=-1)
    pred_before_dsig = partial_val['pred'].as_continuous().flatten()

    fig = plt.figure(figsize=(12, 7))
    plt.scatter(pred_before_dsig,
                pred_after_dsig,
                c=ctpred,
                s=2,
                alpha=0.75,
                cmap=plt.get_cmap('plasma'))
    plt.title(cellid)
    plt.xlabel('pred in')
    plt.ylabel('pred out')

    if include_phi:
        dsig_phi = modelspec.phi[-1]
        phi_lines = ['%s:  %.4E' % (k, v) for k, v in dsig_phi.items()]

        # Pair the keys by position: even positions are taken as the base
        # parameters, odd positions as their modulated counterparts.
        ordered_keys = list(dsig_phi.keys())
        thetas = ordered_keys[0:-1:2]
        mods = ordered_keys[1::2]
        weights = {
            theta: dsig_phi[mod] - dsig_phi[theta]
            for theta, mod in zip(thetas, mods)
        }
        weight_lines = ['%s:  %.4E' % (k, v) for k, v in weights.items()]

        fig.text(0.775, 0.9, '\n'.join(phi_lines), va='top', ha='left')
        fig.text(0.775,
                 0.1,
                 'weights:\n' + '\n'.join(weight_lines),
                 va='bottom',
                 ha='left')
        plt.subplots_adjust(right=0.775, left=0.075)

    plt.colorbar()
    return fig
Beispiel #20
0
def contrast_kernel_output(rec,
                           modelspec,
                           ax=None,
                           title=None,
                           idx=0,
                           channels=0,
                           xlabel='Time',
                           ylabel='Value',
                           **options):
    """Plot the 'ctpred' signal as it exists after step ``idx`` of the
    modelspec.

    The recording is evaluated with ``stop=idx + 1`` and its 'ctpred' signal
    is handed to ``timeseries_from_signals``. Returns the ``ax`` argument
    exactly as it was passed in.
    """
    # NOTE(review): **options is accepted but not forwarded, and the value
    # returned is the caller's `ax` (None by default) rather than anything
    # produced by timeseries_from_signals -- confirm both are intended.
    evaluated = ms.evaluate(rec, modelspec, stop=idx + 1)
    signals = [evaluated['ctpred']]
    timeseries_from_signals(signals,
                            channels=channels,
                            xlabel=xlabel,
                            ylabel=ylabel,
                            ax=ax,
                            title=title)

    return ax
Beispiel #21
0
def single_state_mod_index(rec,
                           modelspec,
                           epoch='REFERENCE',
                           psth_name='pred',
                           state_sig='state',
                           state_chan='pupil'):
    """Compute the state modulation index attributable to one state channel.

    In a deep copy of the modelspec, the 'd' and 'g' columns of the 'state'
    module's phi are zeroed for every state channel except channel 0 and
    ``state_chan``. The recording is then re-evaluated and passed to
    ``state_mod_index``.

    If ``state_chan`` is a list, the function is applied to each entry and a
    list of results is returned; an empty list means all channels of
    ``rec[state_sig]``.

    Raises ValueError if the modelspec has no 'state' module.
    """
    if type(state_chan) is list:
        chans = state_chan if len(state_chan) != 0 else rec[state_sig].chans
        return [
            single_state_mod_index(rec,
                                   modelspec,
                                   epoch=epoch,
                                   psth_name=psth_name,
                                   state_sig=state_sig,
                                   state_chan=chan) for chan in chans
        ]

    sidx = find_module('state', modelspec)
    if sidx is None:
        raise ValueError("no state signal found")

    # Work on a private copy so the caller's parameters are left alone.
    modelspec = copy.deepcopy(modelspec)

    # Boolean selector of the channels to silence: everything except
    # channel 0 and the channel of interest.
    state_chan_idx = rec[state_sig].chans.index(state_chan)
    silence = np.ones(rec[state_sig].shape[0], dtype=bool)
    silence[0] = False
    silence[state_chan_idx] = False
    for term in ('d', 'g'):
        modelspec[sidx]['phi'][term][:, silence] = 0

    newrec = ms.evaluate(rec, modelspec)
    mod_index = state_mod_index(newrec,
                                epoch=epoch,
                                psth_name=psth_name,
                                state_sig=state_sig,
                                state_chan=state_chan)
    return np.array(mod_index)
Beispiel #22
0
# 4. Now add the signal to the recording
# (`rec` and `respavg` are created earlier in the script, outside this view.)
rec.add_signal(respavg)

# Now split into est and val data sets, using occurrence counts of the
# epochs whose names match '^STIM_'
est, val = rec.split_using_epoch_occurrence_counts(epoch_regex='^STIM_')
# est, val = rec.split_at_time(0.8)

# Load some modelspecs and create their predictions
modelspecs = ms.load_modelspecs(modelspecs_dir, 'TAR010c-18-1')#, regex=('^TAR010c-18-1\.{\d+}\.json'))
# Testing summary statistics:
means, stds = ms.summary_stats(modelspecs)
print("means: {}".format(means))
print("stds: {}".format(stds))

# One 'pred' signal per modelspec, each evaluated on the validation set.
# NOTE(review): every evaluation receives the same `val` object; confirm
# ms.evaluate does not share state between the results.
pred = [ms.evaluate(val, m)['pred'] for m in modelspecs]

# Shorthands for unchanging signals
# (`respavg` is rebound here to the validation-set version of the signal.)
stim = val['stim']
resp = val['resp']
respavg = val['respavg']


def plot_layout(plot_fn_struct):
    '''
    Accepts a list of lists of functions of 1 argument (ax).
    Basically a fancy subplot that lets you lay out functions without
    worrying about details. See example below
    '''
    # Count how many plot functions we want
    nrows = len(plot_fn_struct)
Beispiel #23
0
def _get_pred_resp(rec, modelspec, pred_name, resp_name):
    """Evaluate ``rec`` with ``modelspec`` and return the two named signals
    of the result as the list ``[pred, resp]``."""
    evaluated = ms.evaluate(rec, modelspec)
    pred_signal = evaluated[pred_name]
    resp_signal = evaluated[resp_name]
    return [pred_signal, resp_signal]
Beispiel #24
0
def init_logsig(rec, modelspec):
    '''
    Set priors and bounds for the logistic_sigmoid module from the data.

    Based on the process described in the methods of Rabinowitz et al. 2014.
    The modules preceding the sigmoid are evaluated on `rec`, and the
    mean +/- 3 standard deviations of the resulting 'pred' and of 'resp'
    are used to parameterize prior distributions for base, amplitude,
    shift and kappa.

    Returns a deep copy of `modelspec` with the sigmoid module's 'prior'
    entries updated and its 'bounds' replaced. If no logistic_sigmoid
    module is found, the (copied) modelspec is returned unchanged after
    logging a warning.
    '''
    # work on a copy so the caller's modelspec is left alone
    modelspec = copy.deepcopy(modelspec)

    target_i = find_module('logistic_sigmoid', modelspec)
    if target_i is None:
        log.warning("No logsig module was found, can't initialize.")
        return modelspec

    # everything in front of the sigmoid produces its input
    fit_portion = (modelspec if target_i == len(modelspec)
                   else modelspec[:target_i])

    # prediction of the modules that come before the sigmoid
    ms.fit_mode_on(fit_portion)
    rec = ms.evaluate(rec, fit_portion)
    ms.fit_mode_off(fit_portion)

    pred = rec['pred'].as_continuous()
    resp = rec['resp'].as_continuous()

    # input statistics: mean +/- 3 sd, with the lower edge clipped at zero
    pred_center = np.nanmean(pred)
    pred_halfwidth = np.nanstd(pred) * 3
    lo_pred = pred_center - pred_halfwidth
    hi_pred = pred_center + pred_halfwidth
    if lo_pred < 0:
        lo_pred = 0
        pred_center = (lo_pred + hi_pred) / 2
    pred_span = hi_pred - lo_pred

    # output statistics: same recipe, lower edge must be >= 0
    resp_center = np.nanmean(resp)
    resp_halfwidth = np.nanstd(resp) * 3
    lo_resp = max(resp_center - resp_halfwidth, 0)
    hi_resp = resp_center + resp_halfwidth
    resp_span = hi_resp - lo_resp

    # Only the prior distributions are set here; picking an actual initial
    # phi is left to the fitter/analysis.
    base0 = lo_resp + 0.05 * resp_span
    amplitude0 = resp_span
    shift0 = pred_center
    kappa0 = pred_span
    log.info("Initial   base,amplitude,shift,kappa=({}, {}, {}, {})".format(
        base0, amplitude0, shift0, kappa0))

    sigmoid = modelspec[target_i]
    sigmoid['prior'].update({
        'base': ('Exponential', {'beta': base0}),
        'amplitude': ('Exponential', {'beta': amplitude0}),
        'shift': ('Normal', {'mean': shift0, 'sd': pred_span}),
        'kappa': ('Exponential', {'beta': kappa0}),
    })

    tiny = 1e-15
    sigmoid['bounds'] = {
        'base': (tiny, None),
        'amplitude': (tiny, None),
        'shift': (None, None),
        'kappa': (tiny, None),
    }

    return modelspec
Beispiel #25
0
def init_dexp(rec, modelspec):
    """
    Pick initial phi for the double_exponential module.

    The modules in front of the dexp are evaluated first (any of them that
    lacks 'phi' but has a 'prior' gets the prior mean). For every channel of
    the dexp input signal, amplitude/base/kappa/shift are then derived from
    the finite samples of the response and of that input channel, chosen so
    that the dexp starts out as almost a straight line.

    Returns a deep copy of `modelspec` with the dexp module's 'phi'
    replaced. If no double_exponential module is found, the (copied)
    modelspec is returned unchanged after logging a warning.
    """
    # never touch the caller's modelspec
    modelspec = copy.deepcopy(modelspec)

    target_i = find_module('double_exponential', modelspec)
    if target_i is None:
        log.warning("No dexp module was found, can't initialize.")
        return modelspec

    fit_portion = (modelspec if target_i == len(modelspec)
                   else modelspec[:target_i])

    # every upstream module needs a phi before it can be evaluated;
    # fall back to the mean of its prior when phi is missing
    for idx, module in enumerate(fit_portion):
        if 'phi' in module.keys() or 'prior' not in module.keys():
            continue
        log.debug('Phi not found for module, using mean of prior: %s', module)
        fit_portion[idx] = priors.set_mean_phi([module])[0]

    # prediction of the modules that come before the dexp
    ms.fit_mode_on(fit_portion)
    rec = ms.evaluate(rec, fit_portion)
    ms.fit_mode_off(fit_portion)

    in_signal = modelspec[target_i]['fn_kwargs']['i']
    pchans = rec[in_signal].shape[0]
    amp, base, kappa, shift = (np.zeros([pchans, 1]) for _ in range(4))

    for chan in range(pchans):
        row = slice(chan, chan + 1)
        resp_data = rec['resp'].as_continuous()
        pred_data = rec[in_signal].as_continuous()[row, :]
        # use the matching response channel when there is one per input
        if resp_data.shape[0] == pchans:
            resp_data = resp_data[row, :]

        # restrict to samples that are finite in both signals
        finite = np.isfinite(resp_data) * np.isfinite(pred_data)
        resp_data = resp_data[finite]
        pred_data = pred_data[finite]

        # phi=[max_out min_out slope mean_in]
        resp_sd = np.nanstd(resp_data)
        base[chan, 0] = np.min(resp_data)
        amp[chan, 0] = resp_sd * 3
        shift[chan, 0] = np.mean(pred_data)
        kappa[chan, 0] = np.log(
            2 / (np.max(pred_data) - np.min(pred_data) + 1))

    modelspec[target_i]['phi'] = {
        'amplitude': amp,
        'base': base,
        'kappa': kappa,
        'shift': shift
    }
    log.info("Init dexp: %s", modelspec[target_i]['phi'])

    return modelspec
Beispiel #26
0
def fit_population_channel_fast2(rec, modelspec,
                                 fit_set_all, fit_set_slice,
                                 analysis_function=analysis.fit_basic,
                                 metric=metrics.nmse,
                                 fitter=scipy_minimize, fit_kwargs={}):
    """
    Fit the population subspace one dimension (channel) at a time.

    For each dimension d, a reduced modelspec holding only that population
    channel is extracted, fitted against the residual left by the rest of
    the population model, and merged back into `modelspec`.

    rec : recording providing 'resp' and, after evaluation, 'pred'
    modelspec : full population modelspec
    fit_set_all : module indices handed to _extract_pop_channel /
                  _update_pop_channel (helpers not shown here)
    fit_set_slice : module indices of the per-neuron layer; must contain a
                    module whose 'fn' includes 'weight_channels' and one
                    whose 'fn' includes 'levelshift' (IndexError otherwise)
    analysis_function : called as
                  analysis_function(rec, modelspec, fitter=, metric=,
                                    fit_kwargs=) and expected to return the
                  fitted modelspec
    metric : accepted for signature compatibility but NOT used; the locally
             defined my_nmse is passed to analysis_function instead
    fitter, fit_kwargs : forwarded to analysis_function

    Returns the modelspec produced by the last _update_pop_channel call.
    NOTE(review): whether the caller's modelspec is modified in place
    depends on _update_pop_channel -- confirm.
    """

    # guess at number of subspace dimensions
    dim_count = modelspec[fit_set_slice[0]]['phi']['coefficients'].shape[1]
    # locate the weight_channels and levelshift modules of the slice layer
    wi = [i for i in fit_set_slice if 'weight_channels' in modelspec[i]['fn']]
    wi = wi[0]
    li = [i for i in fit_set_slice if 'levelshift' in modelspec[i]['fn']]
    li = li[0]

    for d in range(dim_count):
        # fit each dim separately
        log.info('Updating dim %d/%d', d+1, dim_count)

        # create modelspec with single population subspace filter
        tmodelspec = _extract_pop_channel(modelspec, d, fit_set_all)

        # temp append full-population layer as non-free parameters
        # (tmodelspec2 is only used to compute this channel's contribution
        # p2 to the full prediction; it is never fitted)
        tmodelspec2 = copy.deepcopy(tmodelspec)
        for i in fit_set_slice:
            m = copy.deepcopy(modelspec[i])
            for k, v in m['phi'].items():
                # just applies to wc module?
                # parameters with at least dim_count columns keep column d
                if v.shape[1] >= dim_count:
                    m['phi'][k] = v[:, [d]]
                else:
                    m['phi'][k] = v
            tmodelspec2.append(m)

        # compute residual from prediction by the rest of the pop model
        trec = rec.copy()
        trec = ms.evaluate(trec, modelspec)
        r = trec['resp'].as_continuous()
        # copy taken before trec is re-evaluated below
        p = trec['pred'].as_continuous().copy()
        respstd = np.nanstd(r)  # std of actual response

        # p2: prediction using only population channel d
        trec = ms.evaluate(trec, tmodelspec2)
        p2 = trec['pred'].as_continuous()

        # leave trec holding the single-channel prediction that the fitter
        # starts from
        trec = ms.evaluate(trec, tmodelspec)

        # residual we're trying to predict with tmodelspec
        r = r - p + p2

        # calculate streamlined nMSE function for single pop channel model
        # by inverting neuron-specific gains and level shifts
        a = modelspec[wi]['phi']['coefficients'][:, [d]]
        b = modelspec[li]['phi']['level']

        r -= b  # subtract level shift from residual
        # Coefficients of the expanded squared error, summed over axis 0:
        #   sum((a*X - r)**2) = A1*X**2 - A2*X + A3
        A1 = np.sum(a ** 2)
        A2 = np.sum(2 * a * r, axis=0, keepdims=True)
        A3 = np.sum(r**2, axis=0, keepdims=True)

        def my_nmse(result):
            '''
            hacked from nems.metrics.mse.nmse. optimized nMSE for situation when a single
            population channel is predicting responses with fixed per-neuron gains and levelshifts
            A1, A2, A3, respstd defined outside of function
            result :  recording object updated by fitter, prediction response of single pop channel

            Returns sqrt(mean squared error) / respstd, where the mean is
            taken over r.shape[0] * r.shape[1] samples.
            '''
            X1 = result['pred'].as_continuous()

            squared_errors = A1 * (X1**2) - A2 * X1 + A3
            mean_sq_err = np.sum(squared_errors) / (r.shape[0]*r.shape[1])

            mse = np.sqrt(mean_sq_err)
            return mse / respstd

        # import pdb
        # pdb.set_trace()

        # fit the single-channel model with the streamlined metric
        tmodelspec = analysis_function(trec, tmodelspec, fitter=fitter,
                                       metric=my_nmse, fit_kwargs=fit_kwargs)

        # merge the fitted channel back into the full modelspec
        modelspec = _update_pop_channel(tmodelspec, modelspec, d, fit_set_all)

    return modelspec
Beispiel #27
0
def fit_population_channel_fast(rec, modelspec,
                                fit_set_all, fit_set_slice,
                                analysis_function=analysis.fit_basic,
                                metric=metrics.nmse,
                                fitter=scipy_minimize, fit_kwargs={}):
    """
    Fit the population subspace one dimension (channel) at a time.

    For each dimension d a temporary modelspec is assembled from
      * the `fit_set_all` modules, reduced to the rows belonging to
        channel d (free parameters), followed by
      * the `fit_set_slice` modules with their parameters moved into
        'fn_kwargs' (fixed),
    and fitted against a modified response r - p + p2, where p is the full
    model prediction and p2 the prediction of the temporary modelspec.
    Fitted values are written back into `modelspec`.

    rec : recording providing 'resp' and, after evaluation, 'pred'
    modelspec : full population modelspec
    fit_set_all : indices of the modules shared across the population
    fit_set_slice : indices of the per-neuron layer; the first one must
                    have a 'coefficients' phi whose shape[1] gives the
                    number of dimensions
    analysis_function, metric, fitter, fit_kwargs : the fit is run as
        analysis_function(rec, modelspec, fitter=, metric=, fit_kwargs=)

    Returns `modelspec`.
    NOTE(review): no copy of `modelspec` is made here and phi arrays are
    assigned into directly, so the caller's modelspec appears to be
    modified in place -- confirm.
    """

    # guess at number of subspace dimensions
    dim_count = modelspec[fit_set_slice[0]]['phi']['coefficients'].shape[1]

    for d in range(dim_count):
        # fit each dim separately
        log.info('Updating dim %d/%d', d+1, dim_count)

        # create modelspec with single population subspace filter
        tmodelspec = ms.ModelSpec()
        for i in fit_set_all:
            m = copy.deepcopy(modelspec[i])
            for k, v in m['phi'].items():
                x = v.shape[0]
                if x >= dim_count:
                    # rows are treated as dim_count equal blocks; keep
                    # block d
                    x1 = int(x/dim_count) * d
                    x2 = int(x/dim_count) * (d+1)

                    m['phi'][k] = v[x1:x2]
                    if 'bank_count' in m['fn_kwargs'].keys():
                        m['fn_kwargs']['bank_count'] = 1
                else:
                    # single model-wide parameter, only fit for d==0
                    if d == 0:
                        m['phi'][k] = v
                    else:
                        m['fn_kwargs'][k] = v  # keep fixed for d>0
                        # NOTE(review): 'phi' and 'prior' are deleted while
                        # iterating over the phi items; a module with more
                        # than one phi key would reach m['phi'] again after
                        # the deletion -- confirm this cannot happen.
                        del m['phi']
                        del m['prior']
            tmodelspec.append(m)

        # append full-population layer as non-free parameters
        for i in fit_set_slice:
            m = copy.deepcopy(modelspec[i])
            for k, v in m['phi'].items():
                # just applies to wc module?
                # parameters with at least dim_count columns keep column d
                if v.shape[1] >= dim_count:
                    m['fn_kwargs'][k] = v[:, [d]]
                else:
                    m['fn_kwargs'][k] = v
                # print(k)
                # print(m['fn_kwargs'][k])
            # no phi/prior left -> nothing in this module is fitted
            del m['phi']
            del m['prior']
            tmodelspec.append(m)
            # print(tmodelspec[-1]['fn_kwargs'])

        # compute residual from prediction by the rest of the pop model
        trec = rec.copy()
        trec = ms.evaluate(trec, modelspec)
        r = trec['resp'].as_continuous()
        p = trec['pred'].as_continuous()

        trec = ms.evaluate(trec, tmodelspec)
        p2 = trec['pred'].as_continuous()
        # the temporary model is fitted to response minus full prediction
        # plus this channel's own prediction
        trec['resp'] = trec['resp']._modified_copy(data=r-p+p2)

        # import pdb;
        # pdb.set_trace()
        tmodelspec = analysis_function(trec, tmodelspec, fitter=fitter,
                                       metric=metric, fit_kwargs=fit_kwargs)

        # copy fitted values back into the full modelspec
        # NOTE(review): tmodelspec was built by appending, but is indexed
        # here with the original module indices from fit_set_all; that only
        # lines up if fit_set_all are the leading indices 0..n-1. Also, for
        # d > 0 a module whose 'phi' was deleted above is still read via
        # tmodelspec[i]['phi'] -- confirm.
        for i in fit_set_all:
            for k, v in tmodelspec[i]['phi'].items():
                x = modelspec[i]['phi'][k].shape[0]
                if x >= dim_count:
                    x1 = int(x/dim_count) * d
                    x2 = int(x/dim_count) * (d+1)

                    modelspec[i]['phi'][k][x1:x2] = v
                else:
                    modelspec[i]['phi'][k] = v

        # for i in fit_set_slice:
        #    for k, v in tmodelspec[i]['phi'].items():
        #        if modelspec[i]['phi'][k].shape[0] >= dim_count:
        #            modelspec[i]['phi'][k][d,:] = v

        # print([modelspec.phi[f] for f in fit_set_all])

    return modelspec
Beispiel #28
0
def fit_population_slice(rec, modelspec, slice=0, fit_set=None,
                         analysis_function=analysis.fit_basic,
                         metric=metrics.nmse,
                         fitter=scipy_minimize, fit_kwargs={}):

    """
    Fit one response channel ("slice") of a population model.

    Modified from prefit_mod_subset. The modules selected by `fit_set` are
    reduced to the parameters belonging to channel `slice`; all other
    modules are frozen (their phi is moved into fn_kwargs). The reduced
    model is fitted to the single response channel and, if the error did
    not get worse, the fitted values are written back into the full
    modelspec.

    rec : recording with a 'resp' signal
    modelspec : full population modelspec (not modified; a deep copy is
        fitted and returned)
    slice : int
        response channel to fit, 0 <= slice < rec['resp'].shape[0]
    fit_set : list
        modules to fit, given either as module indices (int) or as
        substrings of module 'fn' names
    analysis_function, metric, fitter, fit_kwargs :
        the fit is run as analysis_function(rec, modelspec, fitter=,
        metric=, fit_kwargs=); `metric` is also used to compare the error
        before and after the fit. `fit_kwargs` is only passed through.

    Returns the (copied) modelspec, updated for this slice when the fit
    helped, otherwise unchanged.

    Raises ValueError if `slice` is out of range, if `fit_set` is None, or
    if a parameter of a selected module cannot be sliced.
    """

    # preserve input modelspec
    modelspec = copy.deepcopy(modelspec)

    slice_count = rec['resp'].shape[0]
    # valid channel indices are 0 .. slice_count - 1
    if slice >= slice_count:
        raise ValueError("Slice %d >= slice_count %d" % (slice, slice_count))

    if fit_set is None:
        raise ValueError("fit_set list of module indices must be specified")

    if len(fit_set) == 0:
        # nothing requested; handled by the "no modules" return below
        fit_idx = []
    elif isinstance(fit_set[0], (int, np.integer)):
        fit_idx = [int(i) for i in fit_set]
    else:
        # match by name; each module is listed once even if several
        # entries of fit_set match it
        fit_idx = [i for i, m in enumerate(modelspec)
                   if any(fn in m['fn'] for fn in fit_set)]

    # build the temp modelspec that will be fit here; for modules being
    # fit, keep only the parameters of the requested slice
    tmodelspec = ms.ModelSpec()
    sliceinfo = {}  # module index -> {phi key: rows taken from full phi}
    for i, m in enumerate(modelspec):
        m = copy.deepcopy(m)
        # need to have phi in place
        if not m.get('phi'):
            log.info('Initializing phi for module %d (%s)', i, m['fn'])
            m = priors.set_mean_phi([m])[0]  # Inits phi

        if i in fit_idx:
            s = {}
            # read the bank count once: it is reset to 1 below and must not
            # influence the slicing of the module's remaining parameters
            has_banks = 'bank_count' in m['fn_kwargs']
            bank_count = m['fn_kwargs']['bank_count'] if has_banks else None
            for key, value in m['phi'].items():
                log.debug('Slicing %d (%s) key %s chan %d for fit',
                          i, m['fn'], key, slice)

                # keep only sliced channel(s)
                if has_banks:
                    filters_per_bank = int(value.shape[0] / bank_count)
                    slices = np.arange(slice*filters_per_bank,
                                       (slice+1)*filters_per_bank)
                    m['phi'][key] = value[slices, ...]
                    s[key] = slices
                    m['fn_kwargs']['bank_count'] = 1
                elif value.shape[0] == slice_count:
                    m['phi'][key] = value[[slice], ...]
                    s[key] = [slice]
                else:
                    raise ValueError("Not sure how to slice %s %s"
                                     % (m['fn'], key))

            # record info about how sliced this module parameter
            sliceinfo[i] = s

        tmodelspec.append(m)

    if len(fit_idx) == 0:
        log.info('No modules matching fit_set for slice fit')
        return modelspec

    # freeze every module that is not being fit
    exclude_idx = np.setdiff1d(np.arange(0, len(modelspec)),
                               np.array(fit_idx)).tolist()
    for i in exclude_idx:
        m = tmodelspec[i]
        log.debug('Freezing phi for module %d (%s)', i, m['fn'])

        m['fn_kwargs'].update(m['phi'])
        m['phi'] = {}
        tmodelspec[i] = m

    # generate temp recording with only responses of interest
    temp_rec = rec.copy()
    slice_chans = [temp_rec['resp'].chans[slice]]
    temp_rec['resp'] = temp_rec['resp'].extract_channels(slice_chans)

    # pre-evaluate the modules in front of the first fitted one, so the
    # fitter only has to re-evaluate the fitted subset
    if fit_idx[0] > 0:
        temp_rec = ms.evaluate(temp_rec, tmodelspec, stop=fit_idx[0])
        tmodelspec = tmodelspec.copy()
        tmodelspec.fast_eval_on(rec=temp_rec, subset=fit_idx)

    # fit the subset of modules
    temp_rec = ms.evaluate(temp_rec, tmodelspec)
    error_before = metric(temp_rec)
    tmodelspec = analysis_function(temp_rec, tmodelspec, fitter=fitter,
                                   metric=metric, fit_kwargs=fit_kwargs)
    tmodelspec.fast_eval_off()
    temp_rec = ms.evaluate(temp_rec, tmodelspec)
    error_after = metric(temp_rec)
    dError = error_before - error_after
    if dError < 0:
        log.info("dError (%.6f - %.6f) = %.6f worse. not updating modelspec",
                 error_before, error_after, dError)
        return modelspec

    log.info("dError (%.6f - %.6f) = %.6f better. updating modelspec",
             error_before, error_after, dError)

    # reassemble the full modelspec with updated phi values from tmodelspec;
    # tmodelspec keeps one entry per module, so indices match modelspec
    for mod_idx in fit_idx:
        m = copy.deepcopy(modelspec[mod_idx])
        # need to have phi in place
        if not m.get('phi'):
            log.info('Intializing phi for module %d (%s)', mod_idx, m['fn'])
            m = priors.set_mean_phi([m])[0]  # Inits phi
        for key, value in tmodelspec[mod_idx]['phi'].items():
            m['phi'][key][sliceinfo[mod_idx][key], :] = value

        modelspec[mod_idx] = m

    return modelspec
Beispiel #29
0
def fit_population_iteratively(
        est, modelspec,
        cost_function=basic_cost,
        fitter=coordinate_descent, evaluator=ms.evaluate,
        segmentor=nems.segmentors.use_all_data,
        mapper=nems.fitters.mappers.simple_vector,
        metric=lambda data: nems.metrics.api.nmse(data, 'pred', 'resp'),
        metaname='fit_basic', fit_kwargs={},
        module_sets=None, invert=False, tolerances=None, tol_iter=50,
        fit_iter=10, IsReload=False, **context
        ):
    '''
    Iteratively fit a population model, alternating between the shared
    population channels and the per-neuron ("slice") modules.

    For each tolerance, the population channels are fit with
    fit_population_channel_fast2 and then every response channel with
    fit_population_slice, repeating until the error reduction drops below
    the tolerance or tol_iter iterations have run. A candidate is only
    accepted when it reduces the error. If the last module's 'fn' contains
    'nonlinearity', that module is removed for the first tolerance pass,
    then restored, initialized and fit.

    Required Arguments:
     est          A recording object
     modelspec     A modelspec object

    Optional Arguments:
    TODO: need to deal with the fact that you can't pass functions in an
          xforms-friendly function
     cost_function, fitter, mapper, segmentor, evaluator, module_sets, invert
                   (CURRENTLY NOT USED) kept for signature compatibility
                   with the other fitters.
     metric        A function of a Recording that returns an error value
                   that is to be minimized.
     metaname      Stored as 'fitter' in the modelspec metadata.
     fit_kwargs    Extra keyword arguments for the underlying fits;
                   'tolerance' and 'max_iter' are overridden per pass.
     tolerances    Sequence of tolerances, one pass each.
                   Default [1e-4, 1e-5].
     tol_iter      Maximum number of iterations per tolerance.
     fit_iter      'max_iter' handed to the underlying fitter.
     IsReload      If True, do nothing and return {}.

    Returns
    A dict {'modelspec': modelspec} holding a copy of the best (last
    accepted) modelspec found by this fitter, or {} when IsReload is set.
    '''

    if IsReload:
        return {}

    modelspec = copy.deepcopy(modelspec)
    data = est.copy()

    fit_set_all, fit_set_slice = _figure_out_mod_split(modelspec)

    if tolerances is None:
        tolerances = [1e-4, 1e-5]

    # apply mask to remove invalid portions of signals and allow fit to
    # only evaluate the model on the valid portion of the signals
    # then delete the mask signal so that it's not reapplied on each fit
    if 'mask' in data.signals.keys():
        log.info("Data len pre-mask: %d", data['mask'].shape[1])
        data = data.apply_mask()
        log.info("Data len post-mask: %d", data['mask'].shape[1])
        del data.signals['mask']

    start_time = time.time()
    ms.fit_mode_on(modelspec, data)

    error = np.inf
    slice_count = data['resp'].shape[0]

    # with an output nonlinearity, run the first tolerance twice: once
    # without the NL and once with it restored
    skip_nl_first = 'nonlinearity' in modelspec[-1]['fn']
    if skip_nl_first:
        tolerances = [tolerances[0]] + list(tolerances)

    for toli, tol in enumerate(tolerances):

        log.info("Fitting subsets with tol: %.2E fit_iter %d tol_iter %d",
                 tol, fit_iter, tol_iter)
        sp_kwargs = fit_kwargs.copy()
        sp_kwargs.update({'tolerance': tol, 'max_iter': fit_iter})

        nl_pass = (toli == 0) and skip_nl_first
        if nl_pass:
            log.info('skipping nl on first tolerance loop')
            saved_modelspec = copy.deepcopy(modelspec)
            saved_fit_set_slice = fit_set_slice.copy()
            modelspec.pop_module()
            fit_set_slice = fit_set_slice[:-1]

        inner_i = 0
        error_reduction = np.inf
        while (error_reduction >= tol) and (inner_i < tol_iter):

            log.info("(%d) Tol %.2e: Loop %d/%d (max)",
                     toli, tol, inner_i, tol_iter)
            improved_modelspec = copy.deepcopy(modelspec)

            for i, m in enumerate(modelspec):
                if i in fit_set_all:
                    log.info("%s: fitting", m['fn'])
                else:
                    log.info("%s: frozen", m['fn'])

            # shared population channels first ...
            improved_modelspec = fit_population_channel_fast2(
                data, improved_modelspec, fit_set_all, fit_set_slice,
                analysis_function=analysis.fit_basic,
                metric=metric,
                fitter=scipy_minimize, fit_kwargs=sp_kwargs)

            # ... then the per-neuron modules, one response channel at a time
            for s in range(slice_count):
                log.info('Slice %d set %s', s, fit_set_slice)
                improved_modelspec = fit_population_slice(
                        data, improved_modelspec, slice=s,
                        fit_set=fit_set_slice,
                        analysis_function=analysis.fit_basic,
                        metric=metric,
                        fitter=scipy_minimize,
                        fit_kwargs=sp_kwargs)

            data = ms.evaluate(data, improved_modelspec)
            new_error = metric(data)
            error_reduction = error - new_error
            error = new_error
            log.info("tol=%.2E, iter=%d/%d: deltaE=%.6E",
                     tol, inner_i, tol_iter, error_reduction)
            inner_i += 1
            # only keep the candidate if it actually lowered the error
            if error_reduction > 0:
                modelspec = improved_modelspec

        log.info("Done with tol %.2E (i=%d, max_error_reduction %.7f)",
                 tol, inner_i, error_reduction)

        if nl_pass:
            log.info('Restoring NL module after first tol loop')
            modelspec.append(saved_modelspec[-1])

            fit_set_slice = saved_fit_set_slice
            nl_fn = saved_modelspec[-1]['fn']
            if 'double_exponential' in nl_fn:
                modelspec = init.init_dexp(data, modelspec)
            elif 'logistic_sigmoid' in nl_fn:
                modelspec = init.init_logsig(data, modelspec)
            elif 'relu' in nl_fn:
                # just keep initialized to zero
                pass
            else:
                raise ValueError("Output NL %s not supported" % nl_fn)

            # fit the per-neuron modules again with the NL back in place
            # NOTE(review): the log line names only the last slice module,
            # but the whole fit_set_slice is passed to the fit -- confirm
            # which one is intended.
            improved_modelspec = copy.deepcopy(modelspec)
            for s in range(slice_count):
                log.info('Slice %d set %s', s, [fit_set_slice[-1]])
                improved_modelspec = fit_population_slice(
                        data, improved_modelspec, slice=s,
                        fit_set=fit_set_slice,
                        analysis_function=analysis.fit_basic,
                        metric=metric,
                        fitter=scipy_minimize,
                        fit_kwargs=sp_kwargs)
            data = ms.evaluate(data, modelspec)
            old_error = metric(data)
            data = ms.evaluate(data, improved_modelspec)
            new_error = metric(data)
            log.info('Init NL fit error change %.5f-%.5f = %.5f',
                     old_error, new_error, old_error-new_error)
            modelspec = improved_modelspec

    elapsed_time = (time.time() - start_time)

    # Finalize the accepted modelspec. `improved_modelspec` may be a
    # candidate that was rejected on the last iteration, and does not exist
    # at all if no tolerance pass ran.
    # TODO: Should this maybe be moved to a higher level
    # so it applies to ALL the fitters?
    ms.fit_mode_off(modelspec)
    ms.set_modelspec_metadata(modelspec, 'fitter', metaname)
    ms.set_modelspec_metadata(modelspec, 'fit_time', elapsed_time)

    return {'modelspec': modelspec.copy()}
Beispiel #30
0
def init_dexp(rec, modelspec):
    """
    Pick initial phi for the double_exponential module.

    The modules in front of the dexp are evaluated on `rec`; for every
    channel of the resulting 'pred' signal, amplitude/base/kappa/shift are
    derived from the finite samples of 'resp' and of that 'pred' channel,
    chosen so that the dexp starts out as almost a straight line.

    Returns a deep copy of `modelspec` with the dexp module's 'phi'
    replaced. If no double_exponential module is found, the (copied)
    modelspec is returned unchanged after logging a warning.
    """
    # never touch the caller's modelspec
    modelspec = copy.deepcopy(modelspec)

    target_i = find_module('double_exponential', modelspec)
    if target_i is None:
        log.warning("No dexp module was found, can't initialize.")
        return modelspec

    fit_portion = (modelspec if target_i == len(modelspec)
                   else modelspec[:target_i])

    # prediction of the modules that come before the dexp
    ms.fit_mode_on(fit_portion)
    rec = ms.evaluate(rec, fit_portion)
    ms.fit_mode_off(fit_portion)

    pchans = rec['pred'].shape[0]
    amp, base, kappa, shift = (np.zeros([pchans, 1]) for _ in range(4))

    for chan in range(pchans):
        resp_data = rec['resp'].as_continuous()
        pred_data = rec['pred'].as_continuous()[chan:(chan + 1), :]

        # restrict to samples that are finite in both signals
        finite = np.isfinite(resp_data) * np.isfinite(pred_data)
        resp_data = resp_data[finite]
        pred_data = pred_data[finite]

        # phi=[max_out min_out slope mean_in]
        resp_sd = np.nanstd(resp_data)
        base[chan, 0] = np.min(resp_data)
        amp[chan, 0] = resp_sd * 3
        shift[chan, 0] = np.mean(pred_data)
        kappa[chan, 0] = np.log(
            2 / (np.max(pred_data) - np.min(pred_data) + 1))

    modelspec[target_i]['phi'] = {
        'amplitude': amp,
        'base': base,
        'kappa': kappa,
        'shift': shift
    }
    log.info("Init dexp: %s", modelspec[target_i]['phi'])

    return modelspec