Code Example #1
File: initializers.py Project: nadoss/nems_db
def init_dexp(rec, modelspec):
    """
    choose initial values for dexp applied after the preceding fir is
    initialized
    """
    # preserve input modelspec
    modelspec = copy.deepcopy(modelspec)

    target_i = find_module('double_exponential', modelspec)
    if target_i is None:
        log.warning("No dexp module was found, can't initialize.")
        return modelspec

    if target_i == len(modelspec):
        fit_portion = modelspec
    else:
        fit_portion = modelspec[:target_i]

    # generate prediction from module preceding dexp
    ms.fit_mode_on(fit_portion)
    rec = ms.evaluate(rec, fit_portion)
    ms.fit_mode_off(fit_portion)

    pchans = rec['pred'].shape[0]
    amp = np.zeros([pchans, 1])
    base = np.zeros([pchans, 1])
    kappa = np.zeros([pchans, 1])
    shift = np.zeros([pchans, 1])

    for i in range(pchans):
        resp = rec['resp'].as_continuous()
        pred = rec['pred'].as_continuous()[i:(i + 1), :]

        keepidx = np.isfinite(resp) * np.isfinite(pred)
        resp = resp[keepidx]
        pred = pred[keepidx]

        # choose phi s.t. dexp starts as almost a straight line
        # phi=[max_out min_out slope mean_in]
        # meanr = np.nanmean(resp)
        stdr = np.nanstd(resp)

        # base = np.max(np.array([meanr - stdr * 4, 0]))
        base[i, 0] = np.min(resp)
        # base = meanr - stdr * 3

        # amp = np.max(resp) - np.min(resp)
        amp[i, 0] = stdr * 3

        shift[i, 0] = np.mean(pred)
        # shift = (np.max(pred) + np.min(pred)) / 2

        predrange = 2 / (np.max(pred) - np.min(pred) + 1)
        kappa[i, 0] = np.log(predrange)

    modelspec[target_i]['phi'] = {
        'amplitude': amp,
        'base': base,
        'kappa': kappa,
        'shift': shift
    }
    log.info("Init dexp: %s", modelspec[target_i]['phi'])

    return modelspec
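
To make the heuristic concrete, the per-channel phi computation above can be run standalone on toy NumPy data. The sketch below is illustrative only: init_dexp_phi is a hypothetical helper name, and the comment states the conventional double-exponential form this initialization targets (nearly linear at the start of fitting).

import numpy as np

def init_dexp_phi(pred, resp):
    # Hypothetical helper mirroring the loop body above for one channel.
    # Goal: pick phi so dexp(x) = base + amp * exp(-exp(-exp(kappa) * (x - shift)))
    # starts out nearly linear over the observed range of pred.
    keep = np.isfinite(pred) & np.isfinite(resp)
    pred, resp = pred[keep], resp[keep]
    base = np.min(resp)                        # output floor
    amplitude = np.nanstd(resp) * 3            # output range ~ 3 SD of resp
    shift = np.mean(pred)                      # center the sigmoid on the input
    predrange = 2 / (np.max(pred) - np.min(pred) + 1)
    kappa = np.log(predrange)                  # log-slope: small => near-linear
    return {'base': base, 'amplitude': amplitude, 'shift': shift, 'kappa': kappa}

rng = np.random.default_rng(0)
pred = rng.normal(0.0, 1.0, 2000)
resp = 0.5 * pred + rng.normal(0.0, 0.1, 2000)
print(init_dexp_phi(pred, resp))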
Code Example #2
def fit_basic(data,
              modelspec,
              fitter=scipy_minimize,
              cost_function=None,
              segmentor=nems.segmentors.use_all_data,
              mapper=nems.fitters.mappers.simple_vector,
              metric=lambda data: metrics.nmse(data, 'pred', 'resp'),
              metaname='fit_basic',
              fit_kwargs={},
              require_phi=True):
    '''
    Required Arguments:
     data          A recording object
     modelspec     A modelspec object

    Optional Arguments:
     fitter        A function of (sigma, costfn) that tests various points
                   in fitspace (i.e. sigmas) using the cost function costfn,
                   and hopefully returns a better sigma after some time.
     mapper        A class that has two methods, pack and unpack, which define
                   the mapping between modelspecs and a fitter's fitspace.
     segmentor     A function that selects a subset of the data during the
                   fitting process. This is NOT the same as est/val data splits
     metric        A function of a Recording that returns an error value
                   that is to be minimized.

    Returns
    A list containing a single modelspec, which has the best parameters found
    by this fitter.
    '''
    start_time = time.time()

    modelspec = copy.deepcopy(modelspec)

    if cost_function is None:
        # Use the cost function defined in this module by default
        cost_function = basic_cost

    if require_phi:
        # Ensure that phi exists for all modules;
        # choose prior mean if not found
        for i, m in enumerate(modelspec):
            if ('phi' not in m.keys()) and ('prior' in m.keys()):
                log.debug('Phi not found for module, using mean of prior: %s',
                          m)
                m = nems.priors.set_mean_phi([m])[0]  # Inits phi for 1 module
                modelspec[i] = m

    # apply mask to remove invalid portions of signals and allow fit to
    # only evaluate the model on the valid portion of the signals
    if 'mask' in data.signals.keys():
        log.info("Data len pre-mask: %d", data['mask'].shape[1])
        data = data.apply_mask()
        log.info("Data len post-mask: %d", data['mask'].shape[1])

    # turn on "fit mode". currently this serves one purpose, for normalization
    # parameters to be re-fit for the output of each module that uses
    # normalization. does nothing if normalization is not being used.
    ms.fit_mode_on(modelspec)

    # Create the mapper functions that translate to and from modelspecs.
    # It has three functions that, when defined as mathematical functions, are:
    #    .pack(modelspec) -> fitspace_point
    #    .unpack(fitspace_point) -> modelspec
    #    .bounds(modelspec) -> fitspace_bounds
    packer, unpacker, pack_bounds = mapper(modelspec)

    # A function to evaluate the modelspec on the data
    evaluator = nems.modelspec.evaluate

    my_cost_function = cost_function
    my_cost_function.counter = 0

    # Freeze everything but sigma, since that's all the fitter should be
    # updating.
    cost_fn = partial(my_cost_function,
                      unpacker=unpacker,
                      modelspec=modelspec,
                      data=data,
                      segmentor=segmentor,
                      evaluator=evaluator,
                      metric=metric)

    # get initial sigma value representing some point in the fit space,
    # and corresponding bounds for each value
    sigma = packer(modelspec)
    bounds = pack_bounds(modelspec)

    # Results should be a list of modelspecs
    # (might only be one in list, but still should be packaged as a list)
    improved_sigma = fitter(sigma, cost_fn, bounds=bounds, **fit_kwargs)
    improved_modelspec = unpacker(improved_sigma)

    elapsed_time = (time.time() - start_time)

    # TODO: Should this maybe be moved to a higher level
    # so it applies to ALL the fittters?
    ms.fit_mode_off(improved_modelspec)
    ms.set_modelspec_metadata(improved_modelspec, 'fitter', metaname)
    ms.set_modelspec_metadata(improved_modelspec, 'fit_time', elapsed_time)
    ms.set_modelspec_metadata(improved_modelspec, 'n_parms',
                              len(improved_sigma))

    results = [copy.deepcopy(improved_modelspec)]
    return results
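
The heart of fit_basic is the packer/unpacker/cost_fn triple. Below is a minimal self-contained sketch of that pattern, with scipy.optimize.minimize standing in for the fitter and a toy "modelspec" of plain dicts in place of NEMS objects; only the structure, not the API, is taken from the code above.

import numpy as np
from scipy.optimize import minimize

# Toy "modelspec": a list of modules, each with a 'phi' dict of arrays.
modelspec = [{'phi': {'w': np.zeros(2)}}, {'phi': {'b': np.zeros(1)}}]

def packer(mspec):
    # modelspec -> flat sigma vector (a point in fitspace)
    return np.concatenate([v.ravel() for m in mspec for v in m['phi'].values()])

def unpacker(sigma):
    # flat sigma vector -> modelspec with the original shapes
    out, i = [], 0
    for m in modelspec:
        phi = {}
        for k, v in m['phi'].items():
            phi[k] = sigma[i:i + v.size].reshape(v.shape)
            i += v.size
        out.append({'phi': phi})
    return out

rng = np.random.default_rng(0)
x = rng.normal(size=(200, 2))
y = x @ np.array([1.5, -2.0]) + 0.3

def cost_fn(sigma):
    # unpack -> evaluate -> metric, as in basic_cost
    ms_ = unpacker(sigma)
    pred = x @ ms_[0]['phi']['w'] + ms_[1]['phi']['b'][0]
    return float(np.mean((pred - y) ** 2))

improved_sigma = minimize(cost_fn, packer(modelspec)).x
print(unpacker(improved_sigma))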
Code Example #3
def fit_population_iteratively(
        est, modelspec,
        cost_function=basic_cost,
        fitter=coordinate_descent, evaluator=ms.evaluate,
        segmentor=nems.segmentors.use_all_data,
        mapper=nems.fitters.mappers.simple_vector,
        metric=lambda data: nems.metrics.api.nmse(data, 'pred', 'resp'),
        metaname='fit_basic', fit_kwargs={},
        module_sets=None, invert=False, tolerances=None, tol_iter=50,
        fit_iter=10, IsReload=False, **context
        ):
    '''
    Required Arguments:
     est          A recording object
     modelspec     A modelspec object

    Optional Arguments:
    TODO: need to deal with the fact that you can't pass functions in an xforms-friendly function
     fitter        (CURRENTLY NOT USED?)
                   A function of (sigma, costfn) that tests various points
                   in fitspace (i.e. sigmas) using the cost function costfn,
                   and hopefully returns a better sigma after some time.
     mapper        (CURRENTLY NOT USED?)
                   A class that has two methods, pack and unpack, which define
                   the mapping between modelspecs and a fitter's fitspace.
     segmentor     (CURRENTLY NOT USED?)
                   A function that selects a subset of the data during the
                   fitting process. This is NOT the same as est/val data splits
     metric        A function of a Recording that returns an error value
                   that is to be minimized.

     module_sets   (CURRENTLY NOT USED?)
                   A nested list specifying which model indices should be fit.
                   Overall iteration will occur len(module_sets) many times.
                   ex: [[0], [1, 3], [0, 1, 2, 3]]

     invert        (CURRENTLY NOT USED?)
                   Boolean. Causes module_sets to specify the model indices
                   that should *not* be fit.


    Returns
    A list containing a single modelspec, which has the best parameters found
    by this fitter.
    '''

    if IsReload:
        return {}

    modelspec = copy.deepcopy(modelspec)
    data = est.copy()

    fit_set_all, fit_set_slice = _figure_out_mod_split(modelspec)

    if tolerances is None:
        tolerances = [1e-4, 1e-5]

    # apply mask to remove invalid portions of signals and allow fit to
    # only evaluate the model on the valid portion of the signals
    # then delete the mask signal so that it's not reapplied on each fit
    if 'mask' in data.signals.keys():
        log.info("Data len pre-mask: %d", data['mask'].shape[1])
        data = data.apply_mask()
        log.info("Data len post-mask: %d", data['mask'].shape[1])
        del data.signals['mask']

    start_time = time.time()
    ms.fit_mode_on(modelspec, data)

    # modelspec = init_pop_pca(data, modelspec)
    # print(modelspec)

    # Ensure that phi exists for all modules; choose prior mean if not found
    # for i, m in enumerate(modelspec):
    #    if ('phi' not in m.keys()) and ('prior' in m.keys()):
    #        m = nems.priors.set_mean_phi([m])[0]  # Inits phi for 1 module
    #        log.debug('Phi not found for module, using mean of prior: {}'
    #                  .format(m))
    #        modelspec[i] = m

    error = np.inf

    slice_count = data['resp'].shape[0]
    step_size = 0.1
    if 'nonlinearity' in modelspec[-1]['fn']:
        skip_nl_first = True
        tolerances = [tolerances[0]] + tolerances
    else:
        skip_nl_first = False

    for toli, tol in enumerate(tolerances):

        log.info("Fitting subsets with tol: %.2E fit_iter %d tol_iter %d",
                 tol, fit_iter, tol_iter)
        cd_kwargs = fit_kwargs.copy()
        cd_kwargs.update({'tolerance': tol, 'max_iter': fit_iter,
                          'step_size': step_size})
        sp_kwargs = fit_kwargs.copy()
        sp_kwargs.update({'tolerance': tol, 'max_iter': fit_iter})

        if (toli == 0) and skip_nl_first:
            log.info('skipping nl on first tolerance loop')
            saved_modelspec = copy.deepcopy(modelspec)
            saved_fit_set_slice = fit_set_slice.copy()
            # import pdb;
            # pdb.set_trace()
            modelspec.pop_module()
            fit_set_slice = fit_set_slice[:-1]

        inner_i = 0
        error_reduction = np.inf
        # big_slice = 0
        # big_n = data['resp'].ntimes
        # big_step = int(big_n/10)
        # big_slice_size = int(big_n/2)
        while (error_reduction >= tol) and (inner_i < tol_iter):

            log.info("(%d) Tol %.2e: Loop %d/%d (max)",
                     toli, tol, inner_i, tol_iter)
            improved_modelspec = copy.deepcopy(modelspec)
            cc = 0
            slist = list(range(slice_count))
            # random.shuffle(slist)

            for i, m in enumerate(modelspec):
                if i in fit_set_all:
                    log.info(m['fn'] + ": fitting")
                else:
                    log.info(m['fn'] + ": frozen")

            # partially implemented: select temporal subset of data for fitting
            # on current loop.
            # data2 = data.copy()
            # big_slice += 1
            # sl = np.zeros(big_n, dtype=bool)
            # sl[:big_slice_size]=True
            # sl = np.roll(sl, big_step*big_slice)
            # log.info('Sampling temporal subset %d (size=%d/%d)', big_step, big_slice_size, big_n)
            # for s in data2.signals.values():
            #    e = s._modified_copy(s._data[:,sl])
            #    data2[e.name] = e

            # improved_modelspec = init.prefit_mod_subset(
            #        data, improved_modelspec, analysis.fit_basic,
            #        metric=metric,
            #        fit_set=fit_set_all,
            #        fit_kwargs=sp_kwargs)
            improved_modelspec = fit_population_channel_fast2(
                data, improved_modelspec, fit_set_all, fit_set_slice,
                analysis_function=analysis.fit_basic,
                metric=metric,
                fitter=scipy_minimize, fit_kwargs=sp_kwargs)

            for s in slist:
                log.info('Slice %d set %s' % (s, fit_set_slice))
                improved_modelspec = fit_population_slice(
                        data, improved_modelspec, slice=s,
                        fit_set=fit_set_slice,
                        analysis_function=analysis.fit_basic,
                        metric=metric,
                        fitter=scipy_minimize,
                        fit_kwargs=sp_kwargs)
                # fitter = coordinate_descent,
                # fit_kwargs = cd_kwargs)

                cc += 1
                # if (cc % 8 == 0) or (cc == slice_count):

            data = ms.evaluate(data, improved_modelspec)
            new_error = metric(data)
            error_reduction = error - new_error
            error = new_error
            log.info("tol=%.2E, iter=%d/%d: deltaE=%.6E",
                     tol, inner_i, tol_iter, error_reduction)
            inner_i += 1
            if error_reduction > 0:
                modelspec = improved_modelspec

        log.info("Done with tol %.2E (i=%d, max_error_reduction %.7f)",
                 tol, inner_i, error_reduction)

        if (toli == 0) and skip_nl_first:
            log.info('Restoring NL module after first tol loop')
            modelspec.append(saved_modelspec[-1])

            fit_set_slice = saved_fit_set_slice
            if 'double_exponential' in saved_modelspec[-1]['fn']:
                modelspec = init.init_dexp(data, modelspec)
            elif 'logistic_sigmoid' in saved_modelspec[-1]['fn']:
                modelspec = init.init_logsig(data, modelspec)
            elif 'relu' in saved_modelspec[-1]['fn']:
                # just keep initialized to zero
                pass
            else:
                raise ValueError("Output NL %s not supported",
                                 saved_modelspec[-1]['fn'])
            # just fit the NL
            improved_modelspec = copy.deepcopy(modelspec)

            kwa = cd_kwargs.copy()
            kwa['max_iter'] *= 2
            for s in range(slice_count):
                log.info('Slice %d set %s' % (s, [fit_set_slice[-1]]))
                improved_modelspec = fit_population_slice(
                        data, improved_modelspec, slice=s,
                        fit_set=fit_set_slice,
                        analysis_function=analysis.fit_basic,
                        metric=metric,
                        fitter=scipy_minimize,
                        fit_kwargs=sp_kwargs)
                # fitter = coordinate_descent,
                # fit_kwargs = cd_kwargs)
            data = ms.evaluate(data, modelspec)
            old_error = metric(data)
            data = ms.evaluate(data, improved_modelspec)
            new_error = metric(data)
            log.info('Init NL fit error change %.5f-%.5f = %.5f',
                     old_error, new_error, old_error-new_error)
            modelspec = improved_modelspec

        else:
            step_size *= 0.25

    elapsed_time = (time.time() - start_time)

    # TODO: Should this maybe be moved to a higher level
    # so it applies to ALL the fittters?
    ms.fit_mode_off(improved_modelspec)
    ms.set_modelspec_metadata(improved_modelspec, 'fitter', metaname)
    ms.set_modelspec_metadata(improved_modelspec, 'fit_time', elapsed_time)

    return {'modelspec': improved_modelspec.copy()}
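
The main loop above alternates a joint fit of the shared parameters (fit_set_all) with per-neuron slice fits (fit_set_slice), repeating until the error reduction drops below each tolerance. A compressed standalone sketch of that alternation on a toy factorized model (plain scipy, not NEMS code):

import numpy as np
from scipy.optimize import minimize

rng = np.random.default_rng(0)
x = rng.normal(size=(300, 3))                    # input
W_true = rng.normal(size=(3, 4))                 # shared weights ("fit_set_all")
g_true = np.array([1.0, 0.5, 2.0, 1.5])          # per-slice gains ("fit_set_slice")
y = (x @ W_true) * g_true                        # 4 output "slices"

W, g = np.zeros((3, 4)), np.ones(4)
mse = lambda: float(np.mean(((x @ W) * g - y) ** 2))

error, tol = np.inf, 1e-8
for _ in range(50):                              # tol_iter analogue
    # joint fit of the shared weights across all slices
    res = minimize(lambda w: np.mean(((x @ w.reshape(3, 4)) * g - y) ** 2),
                   W.ravel(), options={'maxiter': 10})   # fit_iter analogue
    W = res.x.reshape(3, 4)
    # per-slice fit, one output channel at a time
    for s in range(4):
        res = minimize(lambda gs, s=s: np.mean((x @ W[:, s] * gs[0] - y[:, s]) ** 2),
                       [g[s]], options={'maxiter': 10})
        g[s] = res.x[0]
    new_error = mse()
    if error - new_error < tol:                  # error_reduction check
        break
    error = new_error
print(round(error, 8))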
Code Example #4
File: fit_iteratively.py Project: nadoss/NEMS
def fit_iteratively(
    data,
    modelspec,
    cost_function=basic_cost,
    fitter=coordinate_descent,
    evaluator=ms.evaluate,
    segmentor=nems.segmentors.use_all_data,
    mapper=nems.fitters.mappers.simple_vector,
    metric=lambda data: nems.metrics.api.nmse(data, 'pred', 'resp'),
    metaname='fit_basic',
    fit_kwargs={},
    module_sets=None,
    invert=False,
    tolerances=None,
    tol_iter=50,
    fit_iter=10,
):
    '''
    Required Arguments:
     data          A recording object
     modelspec     A modelspec object

    Optional Arguments:
     fitter        A function of (sigma, costfn) that tests various points
                   in fitspace (i.e. sigmas) using the cost function costfn,
                   and hopefully returns a better sigma after some time.
     mapper        A class that has two methods, pack and unpack, which define
                   the mapping between modelspecs and a fitter's fitspace.
     segmentor     A function that selects a subset of the data during the
                   fitting process. This is NOT the same as est/val data splits
     metric        A function of a Recording that returns an error value
                   that is to be minimized.

     module_sets   A nested list specifying which model indices should be fit.
                   Overall iteration will occur len(module_sets) many times.
                   ex: [[0], [1, 3], [0, 1, 2, 3]]

     invert        Boolean. Causes module_sets to specify the model indices
                   that should *not* be fit.


    Returns
    A list containing a single modelspec, which has the best parameters found
    by this fitter.
    '''
    if module_sets is None:
        module_sets = []
        for i, m in enumerate(modelspec):
            if 'prior' in m.keys():
                if 'levelshift' in m['fn'] and 'fir' in modelspec[i - 1]['fn']:
                    # group levelshift with preceding fir filter by default
                    module_sets[-1].append(i)
                else:
                    # otherwise just fit each module separately
                    module_sets.append([i])
        log.info('Fit sets: %s', module_sets)

    if tolerances is None:
        tolerances = [1e-6]

    # apply mask to remove invalid portions of signals and allow fit to
    # only evaluate the model on the valid portion of the signals
    if 'mask' in data.signals.keys():
        log.info("Data len pre-mask: %d", data['mask'].shape[1])
        data = data.apply_mask()
        log.info("Data len post-mask: %d", data['mask'].shape[1])

    start_time = time.time()
    ms.fit_mode_on(modelspec)
    # Ensure that phi exists for all modules; choose prior mean if not found
    for i, m in enumerate(modelspec):
        if ('phi' not in m.keys()) and ('prior' in m.keys()):
            m = nems.priors.set_mean_phi([m])[0]  # Inits phi for 1 module
            log.debug(
                'Phi not found for module, using mean of prior: {}'.format(m))
            modelspec[i] = m

    error = np.inf
    for tol in tolerances:
        log.info("Fitting subsets with tol: %.2E fit_iter %d tol_iter %d", tol,
                 fit_iter, tol_iter)
        # rebuild rather than update, to avoid mutating the shared default dict
        fit_kwargs = dict(fit_kwargs, tolerance=tol, max_iter=fit_iter)
        max_error_reduction = np.inf
        i = 0

        while (max_error_reduction >= tol) and (i < tol_iter):
            max_error_reduction = 0
            j = 0
            for subset in module_sets:
                improved_modelspec = _module_set_loop(subset, data, modelspec,
                                                      cost_function, fitter,
                                                      mapper, segmentor,
                                                      evaluator, metric,
                                                      fit_kwargs)
                new_error = cost_function.error
                error_reduction = error - new_error
                error = new_error
                j += 1
                if error_reduction > max_error_reduction:
                    max_error_reduction = error_reduction
            log.info("tol=%.2E, iter=%d/%d: max deltaE=%.6E", tol, i, tol_iter,
                     max_error_reduction)
            i += 1
        log.info("Done with tol %.2E (i=%d, max_error_reduction %.7f)", tol, i,
                 error_reduction)

    elapsed_time = (time.time() - start_time)

    # TODO: Should this maybe be moved to a higher level
    # so it applies to ALL the fittters?
    ms.fit_mode_off(improved_modelspec)
    ms.set_modelspec_metadata(improved_modelspec, 'fitter', metaname)
    ms.set_modelspec_metadata(improved_modelspec, 'fit_time', elapsed_time)
    results = [copy.deepcopy(improved_modelspec)]

    return results
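
In miniature, the tolerance schedule works like this: for each tol, cycle through the module subsets, refit each with a capped iteration count, and move on once the best single-pass error reduction falls below tol. A toy sketch of that loop (plain scipy on a quadratic, not NEMS code):

import numpy as np
from scipy.optimize import minimize

target = np.array([1.0, -2.0, 0.5, 3.0])
cost = lambda p: float(np.sum((p - target) ** 2))
params = np.zeros(4)
module_sets = [[0], [1, 3], [2]]                  # e.g. levelshift grouped with fir

error = cost(params)
for tol in [1e-2, 1e-6]:                          # tolerances schedule
    max_reduction, i = np.inf, 0
    while max_reduction >= tol and i < 50:        # tol_iter analogue
        max_reduction = 0.0
        for subset in module_sets:
            def sub_cost(p_sub, subset=subset):
                p = params.copy()
                p[subset] = p_sub
                return cost(p)
            res = minimize(sub_cost, params[subset],
                           options={'maxiter': 10})    # fit_iter analogue
            params[subset] = res.x
            max_reduction = max(max_reduction, error - res.fun)
            error = res.fun
        i += 1
print(params)                                     # -> approximately target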
Code Example #5
File: fit_iteratively.py Project: nadoss/NEMS
def fit_module_sets(
        data,
        modelspec,
        cost_function=basic_cost,
        evaluator=ms.evaluate,
        segmentor=nems.segmentors.use_all_data,
        mapper=nems.fitters.mappers.simple_vector,
        metric=lambda data: nems.metrics.api.nmse(data, 'pred', 'resp'),
        fitter=coordinate_descent,
        fit_kwargs={},
        metaname='fit_module_sets',
        module_sets=None,
        invert=False,
        tolerance=1e-4,
        max_iter=1000):
    '''
    Required Arguments:
     data          A recording object
     modelspec     A modelspec object

    Optional Arguments:
     fitter        A function of (sigma, costfn) that tests various points
                   in fitspace (i.e. sigmas) using the cost function costfn,
                   and hopefully returns a better sigma after some time.

     module_sets   A nested list specifying which model indices should be fit.
                   Overall iteration will occur len(module_sets) many times.
                   ex: [[0], [1, 3], [0, 1, 2, 3]]

     invert        Boolean. Causes module_sets to specify the model indices
                   that should *not* be fit.


    Returns
    A list containing a single modelspec, which has the best parameters found
    by this fitter.
    '''
    if module_sets is None:
        module_sets = [[i] for i in range(len(modelspec))]
    # rebuild rather than update, to avoid mutating the shared default dict
    fit_kwargs = dict(fit_kwargs, tolerance=tolerance, max_iter=max_iter)

    # Ensure that phi exists for all modules; choose prior mean if not found
    for i, m in enumerate(modelspec):
        if ('phi' not in m.keys()) and ('prior' in m.keys()):
            m = nems.priors.set_mean_phi([m])[0]  # Inits phi for 1 module
            log.debug(
                'Phi not found for module, using mean of prior: {}'.format(m))
            modelspec[i] = m

    if invert:
        module_sets = _invert_subsets(modelspec, module_sets)

    ms.fit_mode_on(modelspec)
    start_time = time.time()

    log.info("Fitting all subsets with tolerance: %.2E", tolerance)
    for subset in module_sets:
        improved_modelspec = _module_set_loop(subset, data, modelspec,
                                              cost_function, fitter, mapper,
                                              segmentor, evaluator, metric,
                                              fit_kwargs)

    elapsed_time = (time.time() - start_time)

    # TODO: Should this maybe be moved to a higher level
    # so it applies to ALL the fittters?
    ms.fit_mode_off(improved_modelspec)
    ms.set_modelspec_metadata(improved_modelspec, 'fitter', metaname)
    ms.set_modelspec_metadata(improved_modelspec, 'fit_time', elapsed_time)
    results = [copy.deepcopy(improved_modelspec)]

    return results
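
With invert=True the subsets are passed through _invert_subsets, whose source is not shown here. A plausible reconstruction, sketched as a hypothetical helper, complements each subset so that the listed indices become the ones held fixed:

def invert_subsets(n_modules, module_sets):
    # Hypothetical reconstruction: complement each subset over all module
    # indices, so the listed indices are the ones *not* fit.
    return [[i for i in range(n_modules) if i not in subset]
            for subset in module_sets]

print(invert_subsets(4, [[0], [1, 3]]))  # -> [[1, 2, 3], [0, 2]]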
Code Example #6
File: fit_ccnorm.py Project: LBHB/NEMS
def fit_pcnorm(modelspec,
               est: recording.Recording,
               metric=None,
               use_modelspec_init: bool = True,
               optimizer: str = 'adam',
               max_iter: int = 10000,
               early_stopping_steps: int = 5,
               tolerance: float = 5e-4,
               learning_rate: float = 1e-4,
               batch_size: typing.Union[None, int] = None,
               seed: int = 0,
               initializer: str = 'random_normal',
               freeze_layers: typing.Union[None, list] = None,
               epoch_name: str = "REFERENCE",
               n_pcs=2,
               **context):
    '''
    Required Arguments:
     est          A recording object
     modelspec     A modelspec object

    Optional Arguments:
     (copied from fit_tf for now)

    Returns
     dictionary: {'modelspec': updated_modelspec}
    '''

    # Hard-coded
    cost_function = basic_cost
    fitter = scipy_minimize
    segmentor = nems.segmentors.use_all_data
    mapper = nems.fitters.mappers.simple_vector
    fit_kwargs = {'tolerance': tolerance, 'max_iter': max_iter}

    start_time = time.time()

    modelspec = copy.deepcopy(modelspec)

    # apply mask to remove invalid portions of signals and allow fit to
    # only evaluate the model on the valid portion of the signals
    if 'mask' in est.signals.keys():
        log.info("Data len pre-mask: %d", est['mask'].shape[1])
        est = est.apply_mask()
        log.info("Data len post-mask: %d", est['mask'].shape[1])

    conditions = [
        "_".join(k.split("_")[1:]) for k in est.signals.keys()
        if k.startswith("mask_")
    ]
    if (len(conditions) > 2) and any(
        [c.split("_")[-1] == 'lg' for c in conditions]):
        conditions.remove("small")
        conditions.remove("large")
    #conditions = conditions[0:2]
    #conditions = ['large','small']

    group_idx = [est['mask_' + c].as_continuous()[0, :] for c in conditions]
    cg_filtered = [(c, g) for c, g in zip(conditions, group_idx)
                   if g.sum() > 0]
    conditions, group_idx = zip(*cg_filtered)

    for c, g in zip(conditions, group_idx):
        log.info(f"Data subset for {c} len {g.sum()}")

    resp = est['resp'].as_continuous()
    pred0 = est['pred0'].as_continuous()
    residual = resp - pred0

    pca = PCA(n_components=n_pcs)
    pca.fit(residual.T)
    pc_axes = pca.components_

    pcproj = residual.T.dot(pc_axes.T).T

    group_pc = [pcproj[:, idx].std(axis=1) for idx in group_idx]
    resp_std = resp.std(axis=1)
    #import pdb; pdb.set_trace()

    if metric is None:
        metric = lambda d: pc_err(d,
                                  pred_name='pred',
                                  pred0_name='pred0',
                                  group_idx=group_idx,
                                  group_pc=group_pc,
                                  pc_axes=pc_axes,
                                  resp_std=resp_std)

    # turn on "fit mode". currently this serves one purpose, for normalization
    # parameters to be re-fit for the output of each module that uses
    # normalization. does nothing if normalization is not being used.
    ms.fit_mode_on(modelspec, est)

    # Create the mapper functions that translate to and from modelspecs.
    # It has three functions that, when defined as mathematical functions, are:
    #    .pack(modelspec) -> fitspace_point
    #    .unpack(fitspace_point) -> modelspec
    #    .bounds(modelspec) -> fitspace_bounds
    packer, unpacker, pack_bounds = mapper(modelspec)

    # A function to evaluate the modelspec on the data
    evaluator = nems.modelspec.evaluate

    my_cost_function = cost_function
    my_cost_function.counter = 0

    # Freeze everything but sigma, since that's all the fitter should be
    # updating.
    cost_fn = partial(my_cost_function,
                      unpacker=unpacker,
                      modelspec=modelspec,
                      data=est,
                      segmentor=segmentor,
                      evaluator=evaluator,
                      metric=metric,
                      display_N=1000)

    # get initial sigma value representing some point in the fit space,
    # and corresponding bounds for each value
    sigma = packer(modelspec)
    bounds = pack_bounds(modelspec)

    # Results should be a list of modelspecs
    # (might only be one in list, but still should be packaged as a list)
    improved_sigma = fitter(sigma, cost_fn, bounds=bounds, **fit_kwargs)
    improved_modelspec = unpacker(improved_sigma)
    elapsed_time = (time.time() - start_time)

    start_err = cost_fn(sigma)
    final_err = cost_fn(improved_sigma)
    log.info("Delta error: %.06f - %.06f = %e", start_err, final_err,
             final_err - start_err)

    # TODO: Should this maybe be moved to a higher level
    # so it applies to ALL the fittters?
    ms.fit_mode_off(improved_modelspec)
    ms.set_modelspec_metadata(improved_modelspec, 'fitter', 'ccnorm')
    ms.set_modelspec_metadata(improved_modelspec, 'fit_time', elapsed_time)
    ms.set_modelspec_metadata(improved_modelspec, 'n_parms',
                              len(improved_sigma))

    return {'modelspec': improved_modelspec.copy(), 'save_context': True}
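
The default metric compares the variance of the noise residual projected onto its principal components, per condition. The statistics feeding pc_err can be sketched standalone with toy arrays (shapes and masks below are stand-ins; pc_err itself is a NEMS helper not reproduced here):

import numpy as np
from sklearn.decomposition import PCA

rng = np.random.default_rng(0)
n_cells, n_time, n_pcs = 20, 4000, 2
resp = rng.normal(size=(n_cells, n_time))        # stand-in for est['resp']
pred0 = np.zeros_like(resp)                      # stand-in for est['pred0']
residual = resp - pred0

pca = PCA(n_components=n_pcs)
pca.fit(residual.T)                              # samples are time bins
pc_axes = pca.components_                        # (n_pcs, n_cells)

pcproj = residual.T.dot(pc_axes.T).T             # (n_pcs, n_time) projections
mask_a = np.arange(n_time) < n_time // 2         # stand-ins for mask_<condition>
group_idx = [mask_a, ~mask_a]
group_pc = [pcproj[:, idx].std(axis=1) for idx in group_idx]
resp_std = resp.std(axis=1)
print(group_pc)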
Code Example #7
File: fit_ccnorm.py Project: LBHB/NEMS
def fit_ccnorm(modelspec,
               est: recording.Recording,
               metric=None,
               use_modelspec_init: bool = True,
               optimizer: str = 'adam',
               max_iter: int = 10000,
               early_stopping_steps: int = 5,
               tolerance: float = 5e-4,
               learning_rate: float = 1e-4,
               batch_size: typing.Union[None, int] = None,
               seed: int = 0,
               initializer: str = 'random_normal',
               freeze_layers: typing.Union[None, list] = None,
               epoch_name: str = "REFERENCE",
               shrink_cc: float = 0,
               noise_pcs: int = 0,
               shared_pcs: int = 0,
               also_fit_resp: bool = False,
               force_psth: bool = False,
               use_metric: typing.Union[None, str] = None,
               alpha: float = 0.1,
               beta: float = 1,
               exclude_idx=None,
               exclude_after=None,
               freeze_idx=None,
               freeze_after=None,
               **context):
    '''
    Required Arguments:
     est          A recording object
     modelspec     A modelspec object

    Optional Arguments:
     (copied from fit_tf for now)

    Returns
     dictionary: {'modelspec': updated_modelspec}
    '''

    # Hard-coded
    cost_function = basic_cost
    fitter = scipy_minimize
    segmentor = nems.segmentors.use_all_data
    mapper = nems.fitters.mappers.simple_vector
    fit_kwargs = {'tolerance': tolerance, 'max_iter': max_iter}

    start_time = time.time()

    fit_index = modelspec.fit_index
    if (exclude_idx is not None) | (freeze_idx is not None) | \
            (exclude_after is not None) | (freeze_after is not None):
        modelspec0 = modelspec.copy()
        modelspec, include_set = modelspec_freeze_layers(
            modelspec,
            include_idx=None,
            exclude_idx=exclude_idx,
            exclude_after=exclude_after,
            freeze_idx=freeze_idx,
            freeze_after=freeze_after)
        modelspec0.set_fit(fit_index)
        modelspec.set_fit(fit_index)
    else:
        include_set = None

    # Computing PCs before masking out unwanted stimuli in order to
    # preserve match with epochs
    epoch_regex = "^STIM_"
    stims = (est.epochs['name'].value_counts() >= 8)
    stims = [
        stims.index[i] for i, s in enumerate(stims)
        if bool(re.search(epoch_regex, stims.index[i])) and s
    ]

    Rall_u = est.apply_mask()['psth'].as_continuous().T
    # can't simply extract evoked for refs because they can be longer/shorter
    # if they came after a target and/or were the last stim, so masking
    # prestim / poststim doesn't work. Do it manually.
    #d = est['resp'].extract_epochs(stims, mask=est['mask'])

    #R = [v.mean(axis=0) for (k, v) in d.items()]
    #R = [np.reshape(np.transpose(v,[1,0,2]),[v.shape[1],-1]) for (k, v) in d.items()]
    #Rall_u = np.hstack(R).T

    pca = PCA(n_components=2)
    pca.fit(Rall_u)
    pc_axes = pca.components_

    # apply mask to remove invalid portions of signals and allow fit to
    # only evaluate the model on the valid portion of the signals
    if 'mask_small' in est.signals.keys():
        log.info('resetting mask with mask_small+mask_large subset')
        est['mask'] = est['mask']._modified_copy(data=est['mask_small']._data +
                                                 est['mask_large']._data)

    if 'mask' in est.signals.keys():
        log.info("Data len pre-mask: %d", est['mask'].shape[1])
        est = est.apply_mask()
        log.info("Data len post-mask: %d", est['mask'].shape[1])

    # if we want to fit to first-order cc error.
    #uncomment this and make sure sdexp is generating a pred0 signal
    est = modelspec.evaluate(est, stop=2)
    if ('pred0' in est.signals.keys()) & (not force_psth):
        input_name = 'pred0'
        log.info('Found pred0 for fitting CC')
    else:
        input_name = 'psth'
        log.info('No pred0, using psth for fitting CC')

    conditions = [
        "_".join(k.split("_")[1:]) for k in est.signals.keys()
        if k.startswith("mask_")
    ]
    if (len(conditions) > 2) and any(
        [c.split("_")[-1] == 'lg' for c in conditions]):
        conditions.remove("small")
        conditions.remove("large")
    #conditions = conditions[0:2]
    #conditions = ['large','small']

    group_idx = [est['mask_' + c].as_continuous()[0, :] for c in conditions]
    cg_filtered = [(c, g) for c, g in zip(conditions, group_idx)
                   if g.sum() > 0]
    conditions, group_idx = zip(*cg_filtered)

    for c, g in zip(conditions, group_idx):
        log.info(f"cc data for {c} len {g.sum()}")

    resp = est['resp'].as_continuous()
    pred0 = est[input_name].as_continuous()
    #import pdb; pdb.set_trace()
    if shrink_cc > 0:
        log.info(f'cc approx: shrink_cc={shrink_cc}')
        group_cc = [
            cc_shrink(resp[:, idx] - pred0[:, idx], sigrat=shrink_cc)
            for idx in group_idx
        ]
    elif shared_pcs > 0:
        log.info(f'cc approx: shared_pcs={shared_pcs}')
        cc = np.cov(resp - pred0)
        u, s, vh = np.linalg.svd(cc)
        U = u[:, :shared_pcs] @ u[:, :shared_pcs].T

        group_cc = [
            cc_shared_space(resp[:, idx] - pred0[:, idx], U)
            for idx in group_idx
        ]
    elif noise_pcs > 0:
        log.info(f'cc approx: noise_pcs={noise_pcs}')
        group_cc = [
            cc_lowrank(resp[:, idx] - pred0[:, idx], n_pcs=noise_pcs)
            for idx in group_idx
        ]
    else:
        group_cc = [np.cov(resp[:, idx] - pred0[:, idx]) for idx in group_idx]
    group_cc_raw = [np.cov(resp[:, idx] - pred0[:, idx]) for idx in group_idx]

    # variance of projection onto PCs (PCs computed above before masking)
    pcproj0 = (resp - pred0).T.dot(pc_axes.T).T
    pcproj_std = pcproj0.std(axis=1)

    if (use_metric == 'cc_err_w'):

        def metric(d, verbose=False):
            return metrics.cc_err_w(d,
                                    pred_name='pred',
                                    pred0_name=input_name,
                                    group_idx=group_idx,
                                    group_cc=group_cc,
                                    alpha=alpha,
                                    pcproj_std=None,
                                    pc_axes=None,
                                    verbose=verbose)

        log.info(f"fit_ccnorm metric: cc_err_w (alpha={alpha})")

    elif (metric is None) and also_fit_resp:
        log.info(f"resp_cc_err: pred0_name: {input_name} beta: {beta}")
        metric = lambda d: metrics.resp_cc_err(d,
                                               pred_name='pred',
                                               pred0_name=input_name,
                                               group_idx=group_idx,
                                               group_cc=group_cc,
                                               beta=beta,
                                               pcproj_std=None,
                                               pc_axes=None)

    elif (use_metric == 'cc_err_md'):

        def metric(d, verbose=False):
            return metrics.cc_err_md(d,
                                     pred_name='pred',
                                     pred0_name=input_name,
                                     group_idx=group_idx,
                                     group_cc=group_cc,
                                     pcproj_std=None,
                                     pc_axes=None)

        log.info(f"fit_ccnorm metric: cc_err_md")

    elif (metric is None):
        #def cc_err(result, pred_name='pred_lv', resp_name='resp', pred0_name='pred',
        #   group_idx=None, group_cc=None, pcproj_std=None, pc_axes=None):
        # current implementation of cc_err
        metric = lambda d: metrics.cc_err(d,
                                          pred_name='pred',
                                          pred0_name=input_name,
                                          group_idx=group_idx,
                                          group_cc=group_cc,
                                          pcproj_std=None,
                                          pc_axes=None)
        log.info(f"fit_ccnorm metric: cc_err")

    # turn on "fit mode". currently this serves one purpose, for normalization
    # parameters to be re-fit for the output of each module that uses
    # normalization. does nothing if normalization is not being used.
    ms.fit_mode_on(modelspec, est)

    # Create the mapper functions that translate to and from modelspecs.
    # It has three functions that, when defined as mathematical functions, are:
    #    .pack(modelspec) -> fitspace_point
    #    .unpack(fitspace_point) -> modelspec
    #    .bounds(modelspec) -> fitspace_bounds
    packer, unpacker, pack_bounds = mapper(modelspec)

    # A function to evaluate the modelspec on the data
    evaluator = nems.modelspec.evaluate

    my_cost_function = cost_function
    my_cost_function.counter = 0

    # Freeze everything but sigma, since that's all the fitter should be
    # updating.
    cost_fn = partial(my_cost_function,
                      unpacker=unpacker,
                      modelspec=modelspec,
                      data=est,
                      segmentor=segmentor,
                      evaluator=evaluator,
                      metric=metric,
                      display_N=1000)

    # get initial sigma value representing some point in the fit space,
    # and corresponding bounds for each value
    sigma = packer(modelspec)
    bounds = pack_bounds(modelspec)

    # Results should be a list of modelspecs
    # (might only be one in list, but still should be packaged as a list)
    improved_sigma = fitter(sigma, cost_fn, bounds=bounds, **fit_kwargs)
    improved_modelspec = unpacker(improved_sigma)
    elapsed_time = (time.time() - start_time)

    start_err = cost_fn(sigma)
    final_err = cost_fn(improved_sigma)
    log.info("Delta error: %.06f - %.06f = %e", start_err, final_err,
             final_err - start_err)

    # TODO: Should this maybe be moved to a higher level
    # so it applies to ALL the fittters?
    ms.fit_mode_off(improved_modelspec)

    if include_set is not None:
        # pull out updated phi values from improved_modelspec, include_set only
        improved_modelspec = \
            modelspec_unfreeze_layers(improved_modelspec, modelspec0, include_set)
        improved_modelspec.set_fit(fit_index)

    log.info(
        f"Updating improved modelspec with fit_idx={improved_modelspec.fit_index}"
    )
    improved_modelspec.meta['fitter'] = 'ccnorm'
    improved_modelspec.meta['n_parms'] = len(improved_sigma)
    if modelspec.fit_count == 1:
        improved_modelspec.meta['fit_time'] = elapsed_time
        improved_modelspec.meta['loss'] = final_err
    else:
        if modelspec.fit_index == 0:
            improved_modelspec.meta['fit_time'] = np.zeros(
                improved_modelspec.fit_count)
            improved_modelspec.meta['loss'] = np.zeros(
                improved_modelspec.fit_count)
        improved_modelspec.meta['fit_time'][fit_index] = elapsed_time
        improved_modelspec.meta['loss'][fit_index] = final_err

    return {'modelspec': improved_modelspec}
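
The three cc approximations above differ only in how the per-condition noise covariance is regularized. For the shared_pcs branch, the sketch below illustrates the projector construction; note that cc_shrink, cc_lowrank, and cc_shared_space are NEMS helpers whose definitions are not shown, so the last line is one plausible reading of the idea, not their code.

import numpy as np

rng = np.random.default_rng(1)
residual = rng.normal(size=(20, 1000))       # stand-in for resp - pred0
shared_pcs = 3

cc = np.cov(residual)                        # full noise covariance
u, s, vh = np.linalg.svd(cc)
U = u[:, :shared_pcs] @ u[:, :shared_pcs].T  # projector onto the top-k axes

# Constrain the covariance to the shared subspace (assumed analogue of
# cc_shared_space):
cc_shared = U @ cc @ U
print(np.linalg.matrix_rank(cc_shared))      # -> shared_pcs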
Code Example #8
File: contrast_helpers.py Project: nadoss/nems_db
def _init_double_exponential(rec, modelspec, target_i):

    if target_i == len(modelspec):
        fit_portion = modelspec.modules
    else:
        fit_portion = modelspec.modules[:target_i]

    # generate prediction from modules preceding dsig

    # HACK
    for i, m in enumerate(fit_portion):
        if not m.get('phi', None):
            old = m.get('prior', {})
            m = priors.set_mean_phi([m])[0]
            m['prior'] = old
            fit_portion[i] = m

    ms.fit_mode_on(fit_portion)
    rec = ms.evaluate(rec, fit_portion)
    ms.fit_mode_off(fit_portion)

    in_signal = modelspec[target_i]['fn_kwargs']['i']
    pchans = rec[in_signal].shape[0]
    amp = np.zeros([pchans, 1])
    base = np.zeros([pchans, 1])
    kappa = np.zeros([pchans, 1])
    shift = np.zeros([pchans, 1])

    for i in range(pchans):
        resp = rec['resp'].as_continuous()
        pred = rec[in_signal].as_continuous()[i:(i + 1), :]
        if resp.shape[0] == pchans:
            resp = resp[i:(i + 1), :]

        keepidx = np.isfinite(resp) * np.isfinite(pred)
        resp = resp[keepidx]
        pred = pred[keepidx]

        # choose phi s.t. dexp starts as almost a straight line
        # phi=[max_out min_out slope mean_in]
        # meanr = np.nanmean(resp)
        stdr = np.nanstd(resp)

        # base = np.max(np.array([meanr - stdr * 4, 0]))
        base[i, 0] = np.min(resp)
        # base = meanr - stdr * 3

        # amp = np.max(resp) - np.min(resp)
        amp[i, 0] = stdr * 3

        shift[i, 0] = np.mean(pred)
        # shift = (np.max(pred) + np.min(pred)) / 2

        predrange = 2 / (np.max(pred) - np.min(pred) + 1)
        kappa[i, 0] = np.log(predrange)

    amp = ('Normal', {'mean': amp, 'sd': 1.0})
    base = ('Normal', {'mean': base, 'sd': 1.0})
    kappa = ('Normal', {'mean': kappa, 'sd': 1.0})
    shift = ('Normal', {'mean': shift, 'sd': 1.0})

    modelspec[target_i]['prior'].update({
        'base': base,
        'amplitude': amp,
        'shift': shift,
        'kappa': kappa,
    })

    return modelspec
Code Example #9
File: contrast_helpers.py Project: nadoss/nems_db
def _init_logistic_sigmoid(rec, modelspec, dsig_idx):

    if dsig_idx == len(modelspec):
        fit_portion = modelspec.modules
    else:
        fit_portion = modelspec.modules[:dsig_idx]

    # generate prediction from modules preceding dsig

    # HACK to get phi for ctwc, ctfir, ctlvl which have not been prefit yet
    for i, m in enumerate(fit_portion):
        if not m.get('phi', None):
            if any(k in m['id'] for k in ['ctwc', 'ctfir', 'ctlvl']):
                old = m.get('prior', {})
                m = priors.set_mean_phi([m])[0]
                m['prior'] = old
                fit_portion[i] = m
            else:
                log.warning(
                    "unexpected module missing phi during init step:\n"
                    "%s, #%d", m['id'], i)

    ms.fit_mode_on(fit_portion)
    rec = ms.evaluate(rec, fit_portion)
    ms.fit_mode_off(fit_portion)

    pred = rec['pred'].as_continuous()
    resp = rec['resp'].as_continuous()

    mean_pred = np.nanmean(pred)
    min_pred = np.nanmean(pred) - np.nanstd(pred) * 3
    max_pred = np.nanmean(pred) + np.nanstd(pred) * 3
    if min_pred < 0:
        min_pred = 0
        mean_pred = (min_pred + max_pred) / 2

    pred_range = max_pred - min_pred
    min_resp = max(np.nanmean(resp) - np.nanstd(resp) * 3, 0)  # must be >= 0
    max_resp = np.nanmean(resp) + np.nanstd(resp) * 3
    resp_range = max_resp - min_resp

    # Rather than setting a hard value for initial phi,
    # set the prior distributions and let the fitter/analysis
    # decide how to use it.
    base0 = min_resp + 0.05 * (resp_range)
    amplitude0 = resp_range
    shift0 = mean_pred
    kappa0 = pred_range
    log.info("Initial   base,amplitude,shift,kappa=({}, {}, {}, {})".format(
        base0, amplitude0, shift0, kappa0))

    base = ('Exponential', {'beta': base0})
    amplitude = ('Exponential', {'beta': amplitude0})
    shift = ('Normal', {'mean': shift0, 'sd': pred_range**2})
    kappa = ('Exponential', {'beta': kappa0})

    modelspec[dsig_idx]['prior'].update({
        'base': base,
        'amplitude': amplitude,
        'shift': shift,
        'kappa': kappa,
        'base_mod': base,
        'amplitude_mod': amplitude,
        'shift_mod': shift,
        'kappa_mod': kappa
    })

    for kw in modelspec[dsig_idx]['fn_kwargs']:
        if kw in ['base_mod', 'amplitude_mod', 'shift_mod', 'kappa_mod']:
            modelspec[dsig_idx]['prior'].pop(kw)

    modelspec[dsig_idx]['bounds'] = {
        'base': (1e-15, None),
        'amplitude': (1e-15, None),
        'shift': (None, None),
        'kappa': (1e-15, None),
    }

    return modelspec
Code Example #10
def init_logsig(rec, modelspec):
    '''
    Initialization of priors for logistic_sigmoid,
    based on process described in methods of Rabinowitz et al. 2014.
    '''
    # preserve input modelspec
    modelspec = copy.deepcopy(modelspec)

    target_i = find_module('logistic_sigmoid', modelspec)
    if target_i is None:
        log.warning("No logsig module was found, can't initialize.")
        return modelspec

    if target_i == len(modelspec):
        fit_portion = modelspec
    else:
        fit_portion = modelspec[:target_i]

    # generate prediction from modules preceding logsig
    ms.fit_mode_on(fit_portion)
    rec = ms.evaluate(rec, fit_portion)
    ms.fit_mode_off(fit_portion)

    pred = rec['pred'].as_continuous()
    resp = rec['resp'].as_continuous()

    mean_pred = np.nanmean(pred)
    min_pred = np.nanmean(pred) - np.nanstd(pred) * 3
    max_pred = np.nanmean(pred) + np.nanstd(pred) * 3
    pred_range = max_pred - min_pred
    min_resp = max(np.nanmean(resp) - np.nanstd(resp) * 3, 0)  # must be >= 0

    max_resp = np.nanmean(resp) + np.nanstd(resp) * 3
    resp_range = max_resp - min_resp

    # Rather than setting a hard value for initial phi,
    # set the prior distributions and let the fitter/analysis
    # decide how to use it.
    base0 = min_resp + 0.05 * (resp_range)
    amplitude0 = resp_range
    shift0 = mean_pred
    kappa0 = pred_range
    log.info("Initial   base,amplitude,shift,kappa=({}, {}, {}, {})".format(
        base0, amplitude0, shift0, kappa0))

    base = ('Exponential', {'beta': base0})
    amplitude = ('Exponential', {'beta': amplitude0})
    shift = ('Normal', {'mean': shift0, 'sd': pred_range})
    kappa = ('Exponential', {'beta': kappa0})

    modelspec[target_i]['prior'] = {
        'base': base,
        'amplitude': amplitude,
        'shift': shift,
        'kappa': kappa
    }

    modelspec[target_i]['bounds'] = {
        'base': (1e-15, None),
        'amplitude': (1e-15, None),
        'shift': (None, None),
        'kappa': (1e-15, None)
    }

    return modelspec
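
Since _init_logistic_sigmoid and init_logsig share the same Rabinowitz-style heuristic, the arithmetic can be checked standalone on toy data (toy arrays, not NEMS signals; the prior tuples mirror the format used above):

import numpy as np

rng = np.random.default_rng(0)
pred = rng.normal(1.0, 0.5, 2000)                     # stand-in for rec['pred']
resp = np.clip(rng.normal(2.0, 1.0, 2000), 0, None)   # stand-in for rec['resp']

mean_pred = np.nanmean(pred)
min_pred = mean_pred - np.nanstd(pred) * 3
max_pred = mean_pred + np.nanstd(pred) * 3
pred_range = max_pred - min_pred                      # ~ +/- 3 SD window
min_resp = max(np.nanmean(resp) - np.nanstd(resp) * 3, 0)  # firing rates >= 0
max_resp = np.nanmean(resp) + np.nanstd(resp) * 3
resp_range = max_resp - min_resp

priors = {
    'base': ('Exponential', {'beta': min_resp + 0.05 * resp_range}),
    'amplitude': ('Exponential', {'beta': resp_range}),
    'shift': ('Normal', {'mean': mean_pred, 'sd': pred_range}),
    'kappa': ('Exponential', {'beta': pred_range}),
}
print(priors)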