Example #1
def test_recording_loading():
    '''
    Test the loading and saving of files to various HTTP/S3/File routes.
    '''
    # NOTE: FOR THIS TEST TO SUCCEED, /auto/data/tmp/recordings/ must not have
    # blah.tar.gz or TAR010c-18-1.tar.gz in it.

    # Local filesystem
    # rec0 = Recording.load("/home/ivar/git/nems/signals/TAR010c-18-1.tar.gz")
    rec_path = join(RECORDING_DIR, "eno052d-a1.tgz")
    rec0 = Recording.load(rec_path)
    rec2 = Recording.load("file://%s" % rec_path)

    # HTTP
    # rec3 = Recording.load("http://hyrax.ohsu.edu:3000/recordings/eno052d-a1.tgz")
    # rec4 = Recording.load("http://hyrax.ohsu.edu:3000/baphy/294/eno052d-a1?stim=0&pupil=0")

    # S3
    # Direct access (would need AWS CLI lib? Maybe not best idea!)
    # TODO: Requires s3 credentials in environment. Probably best if on
    #       server only; put in nems_db?
    #rec5 = Recording.load('s3://mybucket/myfile.tar.gz')

    # Indirect access via http:
    rec6 = Recording.load("https://s3-us-west-2.amazonaws.com/nemspublic/"
                          "sample_data/eno052d-a1.tgz")

    # Save to a specific tar.gz file
    rec0.save('/tmp/tmp.tar.gz')
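# Hypothetical helper (not part of the original test): clear the cached files
# named in the NOTE above so the test can run; assumes the cache directory
# exists and is writable on the test machine.
from pathlib import Path

def _clear_cached_recordings(cache_dir="/auto/data/tmp/recordings"):
    for fname in ("blah.tar.gz", "TAR010c-18-1.tar.gz"):
        cached = Path(cache_dir) / fname
        if cached.exists():
            cached.unlink()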
Example #2
def test_recording_from_arrays():
    # need a list of array-like data structures
    x = np.random.rand(3, 200)
    y = np.random.rand(1, 200)
    z = np.random.rand(5, 200)
    arrays = [x, y, z]
    # a name for the recording that will hold the signals
    rec_name = 'testing123'
    # the sampling rate for the signals, or a list of
    # individual sampling rates (if different)
    fs = [100, 100, 200]
    # a list of signal names (optional, but preferred)
    names = ['stim', 'resp', 'reference']
    # a list of keyword arguments for each signal,
    # such as channel names or epochs (also optional)
    kwargs = [{
        'chans': ['2kHz', '4kHz', '8kHz']
    }, {
        'chans': ['spike_rate']
    }, {
        'meta': {
            'experiment': 'oddball_2'
        },
        'chans': ['One', 'Two', 'Three', 'Four', 'Five']
    }]
    rec = Recording.load_from_arrays(arrays,
                                     rec_name,
                                     fs,
                                     sig_names=names,
                                     signal_kwargs=kwargs)
    # should also work with integer fs instead of list
    rec = Recording.load_from_arrays(arrays,
                                     rec_name,
                                     100,
                                     sig_names=names,
                                     signal_kwargs=kwargs)

    # All signal names should be present in recording signals dict
    contains = [(n in rec.signals.keys()) for n in names]
    assert not (False in contains)

    bad_names = ['stim']
    # should get an error now since len(names)
    # doesn't match up with len(arrays)
    with pytest.raises(ValueError):
        rec = Recording.load_from_arrays(arrays,
                                         rec_name,
                                         fs,
                                         sig_names=bad_names,
                                         signal_kwargs=kwargs)
Example #3
def simple_recording():
    stim = np.random.rand(18, 200)
    resp = np.random.rand(1, 200)
    return Recording.load_from_arrays([stim, resp],
                                      'simple_recording',
                                      100,
                                      sig_names=['stim', 'resp'])
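# Hypothetical sketch of a test consuming the fixture above (assuming
# simple_recording is registered as a pytest fixture, as the bare signature suggests):
def test_simple_recording_signals(simple_recording):
    assert set(simple_recording.signals.keys()) == {'stim', 'resp'}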
Example #4
File: io.py  Project: nadoss/nems_db
def remove_nan(rec):
    # keep only the time bins where the 'fg' signal contains no NaNs
    keep = ~np.isnan(rec['fg'].as_continuous()).any(axis=0)
    new_signals = {}
    for name, signal in rec.items():
        new_data = signal.as_continuous()[:, keep]
        new_signals[name] = signal._modified_copy(new_data)
    return Recording(new_signals)
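# A minimal standalone illustration (not from the original project) of the
# column-mask pattern used above, with plain numpy arrays:
import numpy as np

data = np.array([[1.0, np.nan, 3.0],
                 [4.0, 5.0, 6.0]])
keep = ~np.isnan(data).any(axis=0)   # array([ True, False,  True])
print(data[:, keep])                 # [[1. 3.]
                                     #  [4. 6.]]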
Example #5
def split_recording(recording):
    '''
    Split a recording into independent recordings for CPP and CPN, applying the split to all
    component signals.
    :param recording: a nems.Recording object
    :return: dict mapping signal type to the corresponding sub-Recording
    '''

    sub_recordings = col.defaultdict(dict)
    metas = dict()
    for signame, signal in recording.signals.items():

        sub_signals = _split_signal(signal)

        for sig_type, sub_signal in sub_signals.items():
            sub_recordings[sig_type][signame] = sub_signal
            metas[sig_type] = sub_signal.meta


    sub_recordings = {sig_type: Recording(signals, meta=metas[sig_type]) for sig_type, signals in
                      sub_recordings.items()}

    return sub_recordings
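# Hypothetical usage sketch; the dictionary keys come from _split_signal and are
# assumed here to be 'CPP' and 'CPN', matching the docstring above.
sub_recs = split_recording(rec)
cpp_rec, cpn_rec = sub_recs['CPP'], sub_recs['CPN']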
Example #6
# A demonstration of a minimalist model fitting system

from nems.recording import Recording

# ----------------------------------------------------------------------------
# DATA FETCHING

# GOAL: Get your data loaded into memory

# Method #1: Load the data from a local directory
rec = Recording.load('signals/gus027b13_p_PPS/')

# alternative ways to define the data object that could be saved as a 
# short(!) string in the dataspec
rec = my_create_recording_fun('signals/gus027b13_p_PPS/')
rec = Recording.load_standard_nems_format('signals/gus027b13_p_PPS/')


# Method #2: Load the data from baphy using the (incomplete, TODO) HTTP API:
# URL = "neuralprediction.org:3003/signals?batch=273&cellid=gus027b13-a1"
# rec = fetch_signals_over_http(URL)

# Method #3: Load the data from S3:
# stimfile="https://s3-us-west-2.amazonaws.com/nemspublic/sample_data/"+cellid+"_NAT_stim_ozgf_c18_fs100.mat"
# respfile="https://s3-us-west-2.amazonaws.com/nemspublic/sample_data/"+cellid+"_NAT_resp_fs100.mat"
# rec = fetch_signals_over_http(stimfile, respfile)

# Method #4: Load the data from a jerb (TODO)

# Method #5: Create a Recording object from a matrix, manually (TODO)
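# A possible sketch for Method #5, reusing the Recording.load_from_arrays API
# shown in Example #2 (array shapes and names here are illustrative):
import numpy as np
from nems.recording import Recording

stim = np.random.rand(18, 200)   # (channels, time)
resp = np.random.rand(1, 200)
rec = Recording.load_from_arrays([stim, resp], 'manual_recording', 100,
                                 sig_names=['stim', 'resp'])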
Example #7
File: ptd_plot.py  Project: LBHB/nems_db
def plot_collapsed_ref_tar(animal, site, cellids=None):
    site += "%"

    sql = "SELECT DISTINCT cellid, rawid, respfile FROM sCellFile WHERE cellid like %s AND runclassid=%s"
    d = nd.pd_query(sql, params=(
        site,
        42,
    ))

    mfile = []
    for f in np.unique(d.respfile):
        f_ = f.split('.')[0]
        mfile.append('/auto/data/daq/{0}/{1}/{2}'.format(
            animal, site[:-2], f_))

    if cellids is None:
        cellid = np.unique(d.cellid).tolist()
    else:
        cellid = cellids
    options = {
        "siteid": site[:-1],
        'cellid': cellids,
        "mfilename": mfile,
        'stim': False,
        'pupil': True,
        'rasterfs': 1000
    }

    uri = nb.baphy_load_recording_uri(**options)
    rec = Recording.load(uri)
    all_pupil = rec['pupil']._data
    ncols = len(mfile)
    if cellids is None:
        cellids = rec['resp'].chans

    for c in cellids:
        f, ax = plt.subplots(1, ncols, sharey=True, figsize=(12, 5))
        ref_base = 0
        for i, mf in enumerate(mfile):
            fn = mf.split('/')[-1]
            ep_mask = [ep for ep in rec.epochs.name if fn in ep]
            R = rec.copy()
            R['resp'] = R['resp'].rasterize(fs=20)
            R['resp'].fs = 20
            R = R.and_mask(ep_mask).apply_mask(reset_epochs=True)
            if '_a_' in fn:
                R = R.and_mask(['HIT_TRIAL']).apply_mask(reset_epochs=True)

            resp = R['resp'].extract_channels([c])

            tar_reps = resp.extract_epoch('TARGET').shape[0]
            tar_m = np.nanmean(resp.extract_epoch('TARGET'),
                               0).squeeze() * R['resp'].fs
            tar_sem = R['resp'].fs * np.nanstd(resp.extract_epoch('TARGET'),
                                               0).squeeze() / np.sqrt(tar_reps)

            ref_reps = resp.extract_epoch('REFERENCE').shape[0]
            ref_m = np.nanmean(resp.extract_epoch('REFERENCE'),
                               0).squeeze() * R['resp'].fs
            ref_sem = R['resp'].fs * np.nanstd(resp.extract_epoch('REFERENCE'),
                                               0).squeeze() / np.sqrt(ref_reps)

            # plot psth's
            time = np.linspace(0, len(tar_m) / R['resp'].fs, len(tar_m))
            ax[i].plot(time, tar_m, color='red', lw=2)
            ax[i].fill_between(time,
                               tar_m + tar_sem,
                               tar_m - tar_sem,
                               color='coral')
            time = np.linspace(0, len(ref_m) / R['resp'].fs, len(ref_m))
            ax[i].plot(time, ref_m, color='blue', lw=2)
            ax[i].fill_between(time,
                               ref_m + ref_sem,
                               ref_m - ref_sem,
                               color='lightblue')

            # set title
            ax[i].set_title(fn, fontsize=8)
            # set labels
            ax[i].set_xlabel('Time (s)')
            ax[i].set_ylabel('Spk / sec')

            # get raster plot baseline
            base = np.max(np.concatenate((tar_m + tar_sem, ref_m + ref_sem)))
            if base > ref_base:
                ref_base = base

        for i, mf in enumerate(mfile):
            # plot the rasters
            fn = mf.split('/')[-1]
            ep_mask = [ep for ep in rec.epochs.name if fn in ep]
            rast_rec = rec.and_mask(ep_mask).apply_mask(reset_epochs=True)
            if '_a_' in fn:
                rast_rec = rast_rec.and_mask(['HIT_TRIAL'
                                              ]).apply_mask(reset_epochs=True)
            rast = rast_rec['resp'].extract_channels([c])

            ref_times = np.where(rast.extract_epoch('REFERENCE').squeeze())
            base = ref_base
            ref_pupil = np.nanmean(
                rast_rec['pupil'].extract_epoch('REFERENCE'), -1)
            xoffset = rast.extract_epoch(
                'TARGET').shape[-1] / rec['resp'].fs + 0.01
            ax[i].plot(ref_pupil / np.max(all_pupil) + xoffset,
                       np.linspace(base, int(base * 2), len(ref_pupil)),
                       color='k')
            ax[i].axvline(xoffset + np.median(all_pupil / np.max(all_pupil)),
                          linestyle='--',
                          color='lightgray')
            #import pdb; pdb.set_trace()
            if ref_times[0].size != 0:
                max_rep = ref_pupil.shape[0] - 1
                ref_locs = ref_times[0] * (base / max_rep)
                ref_locs = ref_locs + base
                ax[i].plot(ref_times[1] / rec['resp'].fs,
                           ref_locs,
                           '|',
                           color='blue',
                           markersize=1)

            tar_times = np.where(rast.extract_epoch('TARGET').squeeze())
            tar_pupil = np.nanmean(rast_rec['pupil'].extract_epoch('TARGET'),
                                   -1)
            tar_base = np.max(ref_locs) + 1
            ax[i].plot(tar_pupil / np.max(all_pupil) + xoffset,
                       np.linspace(tar_base, tar_base + base, len(tar_pupil)),
                       color='k')
            if tar_times[0].size != 0:
                max_rep = tar_pupil.shape[0] - 1
                tar_locs = tar_times[0] * (base / max_rep)
                tar_locs = tar_locs + tar_base
                ax[i].plot(tar_times[1] / rec['resp'].fs,
                           tar_locs,
                           '|',
                           color='red',
                           markersize=1)

            # set ylim
            #ax[i].set_ylim((0, top))

            # set plot aspect
            #asp = np.diff(ax[i].get_xlim())[0] / np.diff(ax[i].get_ylim())[0]
            #ax[i].set_aspect(asp / 2)

        f.suptitle(c, fontsize=8)
        f.tight_layout()
Example #8
def fit_tf(modelspec,
           est: recording.Recording,
           use_modelspec_init: bool = True,
           optimizer: str = 'adam',
           max_iter: int = 10000,
           cost_function: str = 'squared_error',
           early_stopping_steps: int = 5,
           early_stopping_tolerance: float = 5e-4,
           early_stopping_val_split: float = 0,
           learning_rate: float = 1e-4,
           variable_learning_rate: bool = False,
           batch_size: typing.Union[None, int] = None,
           seed: int = 0,
           initializer: str = 'random_normal',
           filepath: typing.Union[str, Path] = None,
           freeze_layers: typing.Union[None, list] = None,
           IsReload: bool = False,
           epoch_name: str = "REFERENCE",
           use_tensorboard: bool = False,
           kernel_regularizer: str = None,
           **context) -> dict:
    """TODO

    :param est:
    :param modelspec:
    :param use_modelspec_init:
    :param optimizer:
    :param max_iter:
    :param cost_function:
    :param early_stopping_steps:
    :param early_stopping_tolerance:
    :param learning_rate:
    :param batch_size:
    :param seed:
    :param filepath:
    :param freeze_layers: Indexes of layers to freeze prior to training. Indexes are modelspec indexes, so are offset
      from model layer indexes.
    :param IsReload:
    :param epoch_name:
    :param context:

    :return: dict {'modelspec': modelspec}
    """

    if IsReload:
        return {}

    tf.random.set_seed(seed)
    np.random.seed(seed)
    #os.environ['TF_DETERMINISTIC_OPS'] = '1'   # makes output deterministic, but reduces prediction accuracy

    log.info('Building tensorflow keras model from modelspec.')
    nems.utils.progress_fun()

    # figure out where to save model checkpoints
    job_id = os.environ.get('SLURM_JOBID', None)
    if job_id is not None:
        # if job is running on slurm, need to change model checkpoint dir
        # keep a record of the job id
        modelspec.meta['slurm_jobid'] = job_id

        log_dir_root = Path('/mnt/scratch')
        assert log_dir_root.exists()
        log_dir_base = log_dir_root / Path('SLURM_JOBID' + job_id)
        log_dir_sub = Path(str(modelspec.meta['batch'])) \
                / modelspec.meta.get('cellid', "NOCELL") \
                / modelspec.get_longname()
        filepath = log_dir_base / log_dir_sub
        tbroot = filepath / 'logs'
    elif filepath is None:
        filepath = modelspec.meta['modelpath']
        tbroot = Path(f'/auto/data/tmp/tensorboard/')
    else:
        tbroot = Path(f'/auto/data/tmp/tensorboard/')

    filepath = Path(filepath)
    if not filepath.exists():
        filepath.mkdir(exist_ok=True, parents=True)
    cellid = modelspec.meta.get('cellid', 'CELL')
    tbpath = tbroot / (str(modelspec.meta['batch']) + '_' + cellid + '_' +
                       modelspec.meta['modelname'])
    # TODO: should this code just be deleted then?
    if 0 & use_tensorboard:
        # disabled, this is dumb. it deletes the previous round of fitting (eg, tfinit)
        fileList = glob.glob(str(tbpath / '*' / '*'))
        for filePath in fileList:
            try:
                os.remove(filePath)
            except:
                print("Error while deleting file : ", filePath)

    checkpoint_filepath = filepath / 'weights.hdf5'
    tensorboard_filepath = tbpath
    gradient_filepath = filepath / 'gradients'

    # update seed based on fit index
    seed += modelspec.fit_index

    if (freeze_layers
            is not None) and len(freeze_layers) and (len(freeze_layers)
                                                     == freeze_layers[-1] + 1):
        truncate_model = True
        modelspec_trunc, est_trunc = \
            initializers.modelspec_remove_input_layers(modelspec, est, remove_count=len(freeze_layers))
        modelspec_original = modelspec
        est_original = est
        modelspec = modelspec_trunc
        est = est_trunc
        freeze_layers = None
        log.info(
            f"Special case of freezing: truncating model. fit_index={modelspec.fit_index} cell_index={modelspec.cell_index}"
        )
    else:
        truncate_model = False

    input_name = modelspec.meta.get('input_name', 'stim')
    output_name = modelspec.meta.get('output_name', 'resp')

    # also grab the fs
    fs = est[input_name].fs

    if (epoch_name is not None) and (epoch_name != ""):
        # extract out the raw data, and reshape to (batch, time, channel)
        stim_train = np.transpose(
            est[input_name].extract_epoch(epoch=epoch_name, mask=est['mask']),
            [0, 2, 1])
        resp_train = np.transpose(
            est[output_name].extract_epoch(epoch=epoch_name, mask=est['mask']),
            [0, 2, 1])
    else:
        # extract data as a single batch size (1, time, channel)
        stim_train = np.transpose(
            est.apply_mask()[input_name].as_continuous()[np.newaxis, ...],
            [0, 2, 1])
        resp_train = np.transpose(
            est.apply_mask()[output_name].as_continuous()[np.newaxis, ...],
            [0, 2, 1])

    log.info(
        f'Feature dimensions: {stim_train.shape}; Data dimensions: {resp_train.shape}.'
    )

    if True:
        log.info("adding a tiny bit of noise to resp_train")
        resp_train = resp_train + np.random.randn(*resp_train.shape) / 10000
    # get state if present, and setup training data
    if 'state' in est.signals:
        if (epoch_name is not None) and (epoch_name != ""):
            state_train = np.transpose(
                est['state'].extract_epoch(epoch=epoch_name, mask=est['mask']),
                [0, 2, 1])
        else:
            state_train = np.transpose(
                est.apply_mask()['state'].as_continuous()[np.newaxis, ...],
                [0, 2, 1])
        state_shape = state_train.shape
        log.info(f'State dimensions: {state_shape}')
        train_data = [stim_train, state_train]
    else:
        state_train, state_shape = None, None
        train_data = stim_train

    # get the layers and build the model
    cost_fn = loss_functions.get_loss_fn(cost_function)
    #model_layers = modelspec.modelspec2tf2(
    #    use_modelspec_init=use_modelspec_init, seed=seed, fs=fs,
    #    initializer=initializer, freeze_layers=freeze_layers,
    #    kernel_regularizer=kernel_regularizer)
    model_layers = modelbuilder.modelspec2tf(
        modelspec,
        use_modelspec_init=use_modelspec_init,
        seed=seed,
        fs=fs,
        initializer=initializer,
        freeze_layers=freeze_layers,
        kernel_regularizer=kernel_regularizer)

    if np.any([isinstance(layer, Conv2D_NEMS) for layer in model_layers]):
        # need a "channel" dimension for Conv2D (like rgb channels, not frequency). Only 1 channel for our data.
        stim_train = stim_train[..., np.newaxis]
        train_data = train_data[..., np.newaxis]

    # do some batch sizing logic
    batch_size = stim_train.shape[0] if batch_size == 0 else batch_size

    if variable_learning_rate:
        # TODO: allow other schedule options instead of hard-coding exp decay?
        # TODO: expose exp decay kwargs as kw options? not clear how to choose these parameters
        learning_rate = tf.keras.optimizers.schedules.ExponentialDecay(
            initial_learning_rate=learning_rate,
            decay_steps=10000,
            decay_rate=0.9)

    from nems.tf.loss_functions import pearson
    model = modelbuilder.ModelBuilder(
        name='Test-model',
        layers=model_layers,
        learning_rate=learning_rate,
        loss_fn=cost_fn,
        optimizer=optimizer,
        metrics=[pearson],
    ).build_model(input_shape=stim_train.shape,
                  state_shape=state_shape,
                  batch_size=batch_size)

    if freeze_layers is not None:
        for freeze_index in freeze_layers:
            log.info(
                f'TF layer #{freeze_index}: "{model.layers[freeze_index + 1].name}" is not trainable.'
            )

    # tracking early termination
    model.early_terminated = False

    # create the callbacks
    early_stopping = callbacks.DelayedStopper(
        monitor='val_loss',
        patience=30 * early_stopping_steps,
        min_delta=early_stopping_tolerance,
        verbose=1,
        restore_best_weights=True)
    regular_stopping = callbacks.DelayedStopper(
        monitor='loss',
        patience=30 * early_stopping_steps,
        min_delta=early_stopping_tolerance,
        verbose=1,
        restore_best_weights=True)
    checkpoint = tf.keras.callbacks.ModelCheckpoint(
        filepath=str(checkpoint_filepath),
        save_best_only=False,
        save_weights_only=True,
        save_freq=100 * stim_train.shape[0],
        monitor='loss',
        verbose=0)
    sparse_logger = callbacks.SparseProgbarLogger(n_iters=50)
    nan_terminate = tf.keras.callbacks.TerminateOnNaN()
    nan_weight_terminate = callbacks.TerminateOnNaNWeights()
    tensorboard = tf.keras.callbacks.TensorBoard(
        log_dir=str(tensorboard_filepath),  # TODO: generic tensorboard dir?
        histogram_freq=0,  # record the distribution of the weights
        write_graph=False,
        update_freq='epoch',
        profile_batch=0)
    # gradient_logger = callbacks.GradientLogger(filepath=str(gradient_filepath),
    #                                            train_input=stim_train,
    #                                            model=model)

    # save an initial set of weights before freezing, in case of termination before any checkpoints
    #log.info('saving weights to : %s', str(checkpoint_filepath) )
    model.save_weights(str(checkpoint_filepath), overwrite=True)

    if version.parse(tf.__version__) >= version.parse("2.2.0"):
        callback0 = [sparse_logger]
        verbose = 0
    else:
        callback0 = []
        verbose = 2
    # enable the below to log tracked parameters to tensorboard
    if use_tensorboard:
        callback0.append(tensorboard)
        log.info(f'Enabling tensorboard, log: {str(tensorboard_filepath)}')
        # enable the below to record gradients to visualize in tensorboard; this is very slow,
        # and loading all this into tensorboard can use A LOT of memory
        # callback0.append(gradient_logger)

    if early_stopping_val_split > 0:
        callback0.append(early_stopping)
        log.info(
            f'Enabling early stopping, val split: {str(early_stopping_val_split)}'
        )
    else:
        callback0.append(regular_stopping)
        log.info(f'Stop tolerance: min_delta={early_stopping_tolerance}')

    log.info(f'Fitting model (batch_size={batch_size})...')
    history = model.fit(train_data,
                        resp_train,
                        validation_split=early_stopping_val_split,
                        verbose=verbose,
                        epochs=max_iter,
                        callbacks=callback0 + [
                            nan_terminate,
                            nan_weight_terminate,
                            checkpoint,
                        ],
                        batch_size=batch_size)

    # did we terminate on a nan loss or weights? Load checkpoint if so
    if np.all(np.isnan(model.predict(train_data))
              ) or model.early_terminated:  # TODO: should this be np.any()?
        log.warning(
            'Model terminated on nan loss or weights, restoring saved weights.'
        )
        try:
            # this can fail if it nans out before a single checkpoint gets saved, either because no saved weights
            # exist, or it tries to load a different model from the init
            model.load_weights(str(checkpoint_filepath))
            log.warning('Reloaded previous saved weights after nan loss.')
        except (tf.errors.NotFoundError, ValueError):
            pass

    modelspec = tf2modelspec(model, modelspec)

    if truncate_model:
        log.info("Special case of freezing: restoring truncated model!!!")
        #modelspec_restored, rec_restored = modelspec_restore_input_layers(modelspec_trunc, rec_trunc, modelspec_original)
        modelspec_restored, est_restored = initializers.modelspec_restore_input_layers(
            modelspec, est, modelspec_original)
        est = est_original
        modelspec = modelspec_restored

    # debug: dump modelspec parameters
    #for i in range(len(modelspec)):
    #    log.info(modelspec.phi[i])

    contains_tf_only_layers = np.any(
        ['tf_only' in m['fn'] for m in modelspec.modules])
    if not contains_tf_only_layers:
        # compare the predictions from the model and modelspec
        error = compare_ms_tf(modelspec, model, est, train_data)
        if error > 1e-5:
            log.warning(
                f'Mean difference between NEMS and TF model prediction: {error}'
            )
        else:
            log.info(
                f'Mean difference between NEMS and TF model prediction: {error}'
            )
    else:
        # nothing to compare, ms evaluation is not implemented for this type of model
        pass

    # add in some relevant meta information
    modelspec.meta['n_parms'] = len(modelspec.phi_vector)
    try:
        n_epochs = len(history.history['loss'])
        if 'val_loss' in history.history.keys():
            #val_stop = np.argmin(history.history['val_loss'])
            #loss = history.history['loss'][val_stop]
            loss = np.nanmin(history.history['val_loss'])
        else:
            loss = np.nanmin(history.history['loss'])

    except KeyError:
        n_epochs = 0
        loss = 0
    if modelspec.fit_count == 1:
        modelspec.meta['n_epochs'] = n_epochs
        modelspec.meta['loss'] = loss
    else:
        if modelspec.fit_index == 0:
            modelspec.meta['n_epochs'] = np.zeros(modelspec.fit_count)
            modelspec.meta['loss'] = np.zeros(modelspec.fit_count)
        modelspec.meta['n_epochs'][modelspec.fit_index] = n_epochs
        modelspec.meta['loss'][modelspec.fit_index] = loss

    try:
        max_iter = modelspec.meta['extra_results']
        modelspec.meta['extra_results'] = max(max_iter, n_epochs)
    except KeyError:
        modelspec.meta['extra_results'] = n_epochs

    nems.utils.progress_fun()

    # clean up temp files
    if job_id is not None:
        log.info('removing temporary weights file(s)')
        shutil.rmtree(log_dir_base)

    return {'modelspec': modelspec}
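# Hypothetical sketch of calling fit_tf directly; it follows the xforms-step
# convention (accepts **context, returns a dict), and 'modelspec' and 'est' are
# assumed to already exist in the calling context:
ctx = fit_tf(modelspec, est,
             max_iter=2000,
             early_stopping_tolerance=1e-4,
             use_tensorboard=False)
modelspec = ctx['modelspec']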
Example #9
site = 'TAR010c'
nPCs = 4
batch = 307
fs = 20
rawid = pu.which_rawids(site)
ops = {
    'batch': batch,
    'pupil': 1,
    'rasterfs': fs,
    'siteid': site,
    'stim': 0,
    'rawid': rawid
}
uri = nb.baphy_load_recording_uri(**ops)
rec = Recording.load(uri)
rec['resp'] = rec['resp'].rasterize()
rec = rec.and_mask(['HIT_TRIAL', 'MISS_TRIAL', 'PASSIVE_EXPERIMENT'])

rec = rec.and_mask(['PreStimSilence', 'PostStimSilence'], invert=True)
rec = rec.apply_mask(reset_epochs=True)

# create four state masks
rec = preproc.create_ptd_masks(rec)

# create the appropriate dictionaries of responses
# don't *think* it's so important to balance reps for
# this analysis since it's parametric (as opposed to the
# discrimination calc, which depends greatly on being balanced)

# get list of unique targets present in both a/p. This is for
Example #10
def recording(signal1, signal2):
    signals = {signal1.name: signal1, signal2.name: signal2}
    return Recording(signals)
Example #11
File: fit_ccnorm.py  Project: LBHB/NEMS
def fit_ccnorm(modelspec,
               est: recording.Recording,
               metric=None,
               use_modelspec_init: bool = True,
               optimizer: str = 'adam',
               max_iter: int = 10000,
               early_stopping_steps: int = 5,
               tolerance: float = 5e-4,
               learning_rate: float = 1e-4,
               batch_size: typing.Union[None, int] = None,
               seed: int = 0,
               initializer: str = 'random_normal',
               freeze_layers: typing.Union[None, list] = None,
               epoch_name: str = "REFERENCE",
               shrink_cc: float = 0,
               noise_pcs: int = 0,
               shared_pcs: int = 0,
               also_fit_resp: bool = False,
               force_psth: bool = False,
               use_metric: typing.Union[None, str] = None,
               alpha: float = 0.1,
               beta: float = 1,
               exclude_idx=None,
               exclude_after=None,
               freeze_idx=None,
               freeze_after=None,
               **context):
    '''
    Required Arguments:
     est          A recording object
     modelspec     A modelspec object

    Optional Arguments:
     (copied from fit_tf for now)

    Returns
     dictionary: {'modelspec': updated_modelspec}
    '''

    # Hard-coded
    cost_function = basic_cost
    fitter = scipy_minimize
    segmentor = nems.segmentors.use_all_data
    mapper = nems.fitters.mappers.simple_vector
    fit_kwargs = {'tolerance': tolerance, 'max_iter': max_iter}

    start_time = time.time()

    fit_index = modelspec.fit_index
    if (exclude_idx is not None) | (freeze_idx is not None) | \
            (exclude_after is not None) | (freeze_after is not None):
        modelspec0 = modelspec.copy()
        modelspec, include_set = modelspec_freeze_layers(
            modelspec,
            include_idx=None,
            exclude_idx=exclude_idx,
            exclude_after=exclude_after,
            freeze_idx=freeze_idx,
            freeze_after=freeze_after)
        modelspec0.set_fit(fit_index)
        modelspec.set_fit(fit_index)
    else:
        include_set = None

    # Computing PCs before masking out unwanted stimuli in order to
    # preserve match with epochs
    epoch_regex = "^STIM_"
    stims = (est.epochs['name'].value_counts() >= 8)
    stims = [
        stims.index[i] for i, s in enumerate(stims)
        if bool(re.search(epoch_regex, stims.index[i])) and s == True
    ]

    Rall_u = est.apply_mask()['psth'].as_continuous().T
    # can't simply extract evoked periods for refs because they can be longer/shorter if they came after a target
    # and/or if it was the last stim. So masking prestim/poststim doesn't work. Do it manually
    #d = est['resp'].extract_epochs(stims, mask=est['mask'])

    #R = [v.mean(axis=0) for (k, v) in d.items()]
    #R = [np.reshape(np.transpose(v,[1,0,2]),[v.shape[1],-1]) for (k, v) in d.items()]
    #Rall_u = np.hstack(R).T

    pca = PCA(n_components=2)
    pca.fit(Rall_u)
    pc_axes = pca.components_

    # apply mask to remove invalid portions of signals and allow fit to
    # only evaluate the model on the valid portion of the signals
    if 'mask_small' in est.signals.keys():
        log.info('resetting mask with mask_small+mask_large subset')
        est['mask'] = est['mask']._modified_copy(data=est['mask_small']._data +
                                                 est['mask_large']._data)

    if 'mask' in est.signals.keys():
        log.info("Data len pre-mask: %d", est['mask'].shape[1])
        est = est.apply_mask()
        log.info("Data len post-mask: %d", est['mask'].shape[1])

    # if we want to fit to first-order cc error.
    #uncomment this and make sure sdexp is generating a pred0 signal
    est = modelspec.evaluate(est, stop=2)
    if ('pred0' in est.signals.keys()) & (not force_psth):
        input_name = 'pred0'
        log.info('Found pred0 for fitting CC')
    else:
        input_name = 'psth'
        log.info('No pred0, using psth for fitting CC')

    conditions = [
        "_".join(k.split("_")[1:]) for k in est.signals.keys()
        if k.startswith("mask_")
    ]
    if (len(conditions) > 2) and any(
        [c.split("_")[-1] == 'lg' for c in conditions]):
        conditions.remove("small")
        conditions.remove("large")
    #conditions = conditions[0:2]
    #conditions = ['large','small']

    group_idx = [est['mask_' + c].as_continuous()[0, :] for c in conditions]
    cg_filtered = [(c, g) for c, g in zip(conditions, group_idx)
                   if g.sum() > 0]
    conditions, group_idx = zip(*cg_filtered)

    for c, g in zip(conditions, group_idx):
        log.info(f"cc data for {c} len {g.sum()}")

    resp = est['resp'].as_continuous()
    pred0 = est[input_name].as_continuous()
    #import pdb; pdb.set_trace()
    if shrink_cc > 0:
        log.info(f'cc approx: shrink_cc={shrink_cc}')
        group_cc = [
            cc_shrink(resp[:, idx] - pred0[:, idx], sigrat=shrink_cc)
            for idx in group_idx
        ]
    elif shared_pcs > 0:
        log.info(f'cc approx: shared_pcs={shared_pcs}')
        cc = np.cov(resp - pred0)
        u, s, vh = np.linalg.svd(cc)
        U = u[:, :shared_pcs] @ u[:, :shared_pcs].T

        group_cc = [
            cc_shared_space(resp[:, idx] - pred0[:, idx], U)
            for idx in group_idx
        ]
    elif noise_pcs > 0:
        log.info(f'cc approx: noise_pcs={noise_pcs}')
        group_cc = [
            cc_lowrank(resp[:, idx] - pred0[:, idx], n_pcs=noise_pcs)
            for idx in group_idx
        ]
    else:
        group_cc = [np.cov(resp[:, idx] - pred0[:, idx]) for idx in group_idx]
    group_cc_raw = [np.cov(resp[:, idx] - pred0[:, idx]) for idx in group_idx]

    # variance of projection onto PCs (PCs computed above before masking)
    pcproj0 = (resp - pred0).T.dot(pc_axes.T).T
    pcproj_std = pcproj0.std(axis=1)

    if (use_metric == 'cc_err_w'):

        def metric(d, verbose=False):
            return metrics.cc_err_w(d,
                                    pred_name='pred',
                                    pred0_name=input_name,
                                    group_idx=group_idx,
                                    group_cc=group_cc,
                                    alpha=alpha,
                                    pcproj_std=None,
                                    pc_axes=None,
                                    verbose=verbose)

        log.info(f"fit_ccnorm metric: cc_err_w (alpha={alpha})")

    elif (metric is None) and also_fit_resp:
        log.info(f"resp_cc_err: pred0_name: {input_name} beta: {beta}")
        metric = lambda d: metrics.resp_cc_err(d,
                                               pred_name='pred',
                                               pred0_name=input_name,
                                               group_idx=group_idx,
                                               group_cc=group_cc,
                                               beta=beta,
                                               pcproj_std=None,
                                               pc_axes=None)

    elif (use_metric == 'cc_err_md'):

        def metric(d, verbose=False):
            return metrics.cc_err_md(d,
                                     pred_name='pred',
                                     pred0_name=input_name,
                                     group_idx=group_idx,
                                     group_cc=group_cc,
                                     pcproj_std=None,
                                     pc_axes=None)

        log.info(f"fit_ccnorm metric: cc_err_md")

    elif (metric is None):
        #def cc_err(result, pred_name='pred_lv', resp_name='resp', pred0_name='pred',
        #   group_idx=None, group_cc=None, pcproj_std=None, pc_axes=None):
        # current implementation of cc_err
        metric = lambda d: metrics.cc_err(d,
                                          pred_name='pred',
                                          pred0_name=input_name,
                                          group_idx=group_idx,
                                          group_cc=group_cc,
                                          pcproj_std=None,
                                          pc_axes=None)
        log.info(f"fit_ccnorm metric: cc_err")

    # turn on "fit mode". currently this serves one purpose, for normalization
    # parameters to be re-fit for the output of each module that uses
    # normalization. does nothing if normalization is not being used.
    ms.fit_mode_on(modelspec, est)

    # Create the mapper functions that translates to and from modelspecs.
    # It has three functions that, when defined as mathematical functions, are:
    #    .pack(modelspec) -> fitspace_point
    #    .unpack(fitspace_point) -> modelspec
    #    .bounds(modelspec) -> fitspace_bounds
    packer, unpacker, pack_bounds = mapper(modelspec)

    # A function to evaluate the modelspec on the data
    evaluator = nems.modelspec.evaluate

    my_cost_function = cost_function
    my_cost_function.counter = 0

    # Freeze everything but sigma, since that's all the fitter should be
    # updating.
    cost_fn = partial(my_cost_function,
                      unpacker=unpacker,
                      modelspec=modelspec,
                      data=est,
                      segmentor=segmentor,
                      evaluator=evaluator,
                      metric=metric,
                      display_N=1000)

    # get initial sigma value representing some point in the fit space,
    # and corresponding bounds for each value
    sigma = packer(modelspec)
    bounds = pack_bounds(modelspec)

    # Results should be a list of modelspecs
    # (might only be one in list, but still should be packaged as a list)
    improved_sigma = fitter(sigma, cost_fn, bounds=bounds, **fit_kwargs)
    improved_modelspec = unpacker(improved_sigma)
    elapsed_time = (time.time() - start_time)

    start_err = cost_fn(sigma)
    final_err = cost_fn(improved_sigma)
    log.info("Delta error: %.06f - %.06f = %e", start_err, final_err,
             final_err - start_err)

    # TODO: Should this maybe be moved to a higher level
    # so it applies to ALL the fitters?
    ms.fit_mode_off(improved_modelspec)

    if include_set is not None:
        # pull out updated phi values from improved_modelspec, include_set only
        improved_modelspec = \
            modelspec_unfreeze_layers(improved_modelspec, modelspec0, include_set)
        improved_modelspec.set_fit(fit_index)

    log.info(
        f"Updating improved modelspec with fit_idx={improved_modelspec.fit_index}"
    )
    improved_modelspec.meta['fitter'] = 'ccnorm'
    improved_modelspec.meta['n_parms'] = len(improved_sigma)
    if modelspec.fit_count == 1:
        improved_modelspec.meta['fit_time'] = elapsed_time
        improved_modelspec.meta['loss'] = final_err
    else:
        if modelspec.fit_index == 0:
            improved_modelspec.meta['fit_time'] = np.zeros(
                improved_modelspec.fit_count)
            improved_modelspec.meta['loss'] = np.zeros(
                improved_modelspec.fit_count)
        improved_modelspec.meta['fit_time'][fit_index] = elapsed_time
        improved_modelspec.meta['loss'][fit_index] = final_err

    return {'modelspec': improved_modelspec}
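# Hypothetical sketch: the covariance-approximation options above are applied in
# priority order (shrink_cc, then shared_pcs, then noise_pcs); e.g. a low-rank
# noise approximation could be requested as:
ctx = fit_ccnorm(modelspec, est, noise_pcs=3, tolerance=1e-4)
modelspec = ctx['modelspec']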
Example #12
def fit_tf(modelspec,
           est: recording.Recording,
           use_modelspec_init: bool = True,
           optimizer: str = 'adam',
           max_iter: int = 10000,
           cost_function: str = 'squared_error',
           early_stopping_steps: int = 5,
           early_stopping_tolerance: float = 5e-4,
           learning_rate: float = 1e-4,
           batch_size: typing.Union[None, int] = None,
           seed: int = 0,
           initializer: str = 'random_normal',
           filepath: typing.Union[str, Path] = None,
           freeze_layers: typing.Union[None, list] = None,
           IsReload: bool = False,
           epoch_name: str = "REFERENCE",
           **context) -> dict:
    """TODO

    :param est:
    :param modelspec:
    :param use_modelspec_init:
    :param optimizer:
    :param max_iter:
    :param cost_function:
    :param early_stopping_steps:
    :param early_stopping_tolerance:
    :param learning_rate:
    :param batch_size:
    :param seed:
    :param filepath:
    :param freeze_layers: Indexes of layers to freeze prior to training. Indexes are modelspec indexes, so are offset
      from model layer indexes.
    :param IsReload:
    :param epoch_name:
    :param context:

    :return: dict {'modelspec': modelspec}
    """

    if IsReload:
        return {}

    tf.random.set_seed(seed)
    np.random.seed(seed)
    #os.environ['TF_DETERMINISTIC_OPS'] = '1'   # makes output deterministic, but reduces prediction accuracy

    log.info('Building tensorflow keras model from modelspec.')
    nems.utils.progress_fun()

    # figure out where to save model checkpoints
    if filepath is None:
        filepath = modelspec.meta['modelpath']

    # if job is running on slurm, need to change model checkpoint dir
    job_id = os.environ.get('SLURM_JOBID', None)
    if job_id is not None:
        # keep a record of the job id
        modelspec.meta['slurm_jobid'] = job_id

        log_dir_root = Path('/mnt/scratch')
        assert log_dir_root.exists()
        log_dir_sub = Path('SLURM_JOBID' + job_id) / str(modelspec.meta['batch'])\
                      / modelspec.meta.get('cellid', "NOCELL")\
                      / modelspec.meta['modelname']
        filepath = log_dir_root / log_dir_sub

    filepath = Path(filepath)
    if not filepath.exists():
        filepath.mkdir(exist_ok=True, parents=True)

    checkpoint_filepath = filepath / 'weights.hdf5'
    tensorboard_filepath = filepath / 'logs'
    gradient_filepath = filepath / 'gradients'

    # update seed based on fit index
    seed += modelspec.fit_index

    # need to get duration of stims in order to reshape data
    #epoch_name = 'REFERENCE'  # TODO: this should not be hardcoded
    # moved to input parameter

    input_name = modelspec.meta.get('input_name', 'stim')
    output_name = modelspec.meta.get('output_name', 'resp')

    # also grab the fs
    fs = est[input_name].fs

    if (epoch_name is not None) and (epoch_name != ""):
        # extract out the raw data, and reshape to (batch, time, channel)
        stim_train = np.transpose(
            est[input_name].extract_epoch(epoch=epoch_name, mask=est['mask']),
            [0, 2, 1])
        resp_train = np.transpose(
            est[output_name].extract_epoch(epoch=epoch_name, mask=est['mask']),
            [0, 2, 1])
    else:
        # extract data as a single batch size (1, time, channel)
        stim_train = np.transpose(
            est.apply_mask()[input_name].as_continuous()[np.newaxis, ...],
            [0, 2, 1])
        resp_train = np.transpose(
            est.apply_mask()[output_name].as_continuous()[np.newaxis, ...],
            [0, 2, 1])

    log.info(
        f'Feature dimensions: {stim_train.shape}; Data dimensions: {resp_train.shape}.'
    )

    # get state if present, and setup training data
    if 'state' in est.signals:
        if (epoch_name is not None) and (epoch_name != ""):
            state_train = np.transpose(
                est['state'].extract_epoch(epoch=epoch_name, mask=est['mask']),
                [0, 2, 1])
        else:
            state_train = np.transpose(
                est.apply_mask()['state'].as_continuous()[np.newaxis, ...],
                [0, 2, 1])
        state_shape = state_train.shape
        log.info(f'State dimensions: {state_shape}')
        train_data = [stim_train, state_train]
    else:
        state_train, state_shape = None, None
        train_data = stim_train

    # correlation for monitoring
    # TODO: tf.utils?
    def pearson(y_true, y_pred):
        return tfp.stats.correlation(y_true,
                                     y_pred,
                                     event_axis=None,
                                     sample_axis=None)

    # get the layers and build the model
    cost_fn = loss_functions.get_loss_fn(cost_function)
    model_layers = modelspec.modelspec2tf2(
        use_modelspec_init=use_modelspec_init,
        seed=seed,
        fs=fs,
        initializer=initializer)
    if np.any([isinstance(layer, Conv2D_NEMS) for layer in model_layers]):
        # need a "channel" dimension for Conv2D (like rgb channels, not frequency). Only 1 channel for our data.
        stim_train = stim_train[..., np.newaxis]
        train_data = train_data[..., np.newaxis]

    # do some batch sizing logic
    batch_size = stim_train.shape[0] if batch_size == 0 else batch_size

    model = modelbuilder.ModelBuilder(
        name='Test-model',
        layers=model_layers,
        learning_rate=learning_rate,
        loss_fn=cost_fn,
        optimizer=optimizer,
        metrics=[pearson],
    ).build_model(input_shape=stim_train.shape,
                  state_shape=state_shape,
                  batch_size=batch_size)

    # tracking early termination
    model.early_terminated = False

    # create the callbacks
    early_stopping = callbacks.DelayedStopper(
        monitor='loss',
        patience=30 * early_stopping_steps,
        min_delta=early_stopping_tolerance,
        verbose=1,
        restore_best_weights=False)
    checkpoint = tf.keras.callbacks.ModelCheckpoint(
        filepath=str(checkpoint_filepath),
        save_best_only=False,
        save_weights_only=True,
        save_freq=100 * stim_train.shape[0],
        monitor='loss',
        verbose=0)
    sparse_logger = callbacks.SparseProgbarLogger(n_iters=10)
    nan_terminate = tf.keras.callbacks.TerminateOnNaN()
    nan_weight_terminate = callbacks.TerminateOnNaNWeights()
    tensorboard = tf.keras.callbacks.TensorBoard(
        log_dir=str(tensorboard_filepath),  # TODO: generic tensorboard dir?
        histogram_freq=0,  # record the distribution of the weights
        write_graph=False,
        update_freq='epoch',
        profile_batch=0)
    # gradient_logger = callbacks.GradientLogger(filepath=str(gradient_filepath),
    #                                            train_input=stim_train,
    #                                            model=model)

    # freeze layers
    if freeze_layers is not None:
        for freeze_index in freeze_layers:
            log.info(
                f'Freezing layer #{freeze_index}: "{model.layers[freeze_index + 1].name}".'
            )
            model.layers[freeze_index + 1].trainable = False

    # save an initial set of weights before freezing, in case of termination before any checkpoints
    #log.info('saving weights to : %s', str(checkpoint_filepath) )
    model.save_weights(str(checkpoint_filepath), overwrite=True)

    if version.parse(tf.__version__) >= version.parse("2.2.0"):
        callback0 = [sparse_logger]
        verbose = 0
    else:
        callback0 = []
        verbose = 2

    log.info(f'Fitting model (batch_size={batch_size})...')
    history = model.fit(
        train_data,
        resp_train,
        # validation_split=0.2,
        verbose=verbose,
        epochs=max_iter,
        batch_size=batch_size,
        callbacks=callback0 + [
            nan_terminate,
            nan_weight_terminate,
            early_stopping,
            checkpoint,
            # enable the below to log tracked parameters to tensorboard
            # tensorboard,
            # enable the below to record gradients to visualize in tensorboard; this is very slow,
            # and loading all this into tensorboard can use A LOT of memory
            # gradient_logger,
        ])

    # did we terminate on a nan loss or weights? Load checkpoint if so
    if np.all(np.isnan(model.predict(train_data))
              ) or model.early_terminated:  # TODO: should this be np.any()?
        log.warning(
            'Model terminated on nan loss or weights, restoring saved weights.'
        )
        try:
            # this can fail if it nans out before a single checkpoint gets saved, either because no saved weights
            # exist, or it tries to load a different model from the init
            model.load_weights(str(checkpoint_filepath))
            log.warning('Reloaded previous saved weights after nan loss.')
        except (tf.errors.NotFoundError, ValueError):
            pass

    modelspec = tf2modelspec(model, modelspec)

    contains_tf_only_layers = np.any(
        ['tf_only' in m['fn'] for m in modelspec.modules])
    if not contains_tf_only_layers:
        # compare the predictions from the model and modelspec
        error = compare_ms_tf(modelspec, model, est, train_data)
        if error > 1e-5:
            log.warning(
                f'Mean difference between NEMS and TF model prediction: {error}'
            )
        else:
            log.info(
                f'Mean difference between NEMS and TF model prediction: {error}'
            )
    else:
        # nothing to compare, ms evaluation is not implemented for this type of model
        pass

    # add in some relevant meta information
    modelspec.meta['n_parms'] = len(modelspec.phi_vector)
    try:
        n_epochs = len(history.history['loss'])
    except KeyError:
        n_epochs = 0
    try:
        max_iter = modelspec.meta['extra_results']
        modelspec.meta['extra_results'] = max(max_iter, n_epochs)
    except KeyError:
        modelspec.meta['extra_results'] = n_epochs

    nems.utils.progress_fun()

    return {'modelspec': modelspec}
Example #13
relative_signals_dir = '../recordings'
relative_modelspecs_dir = '../modelspecs'

# Convert to absolute paths so they can be passed to functions in
# other directories
signals_dir = os.path.abspath(relative_signals_dir)
modelspecs_dir = os.path.abspath(relative_modelspecs_dir)

# ----------------------------------------------------------------------------
# DATA LOADING

# GOAL: Get your data loaded into memory as a Recording object
logging.info('Loading data...')

# Method #1: Load the data from a local directory
rec = Recording.load(os.path.join(signals_dir, 'eno052d-a1.tgz'))

# Method #2: Load the data from baphy using the (incomplete, TODO) HTTP API:
#URL = "http://potoroo:3004/baphy/271/bbl086b-11-1?rasterfs=200"
#rec = Recording.load_url(URL)

logging.info('Generating state signal...')

rec = preproc.make_state_signal(rec, ['pupil'], [''], 'state')

# ----------------------------------------------------------------------------
# INITIALIZE MODELSPEC

# GOAL: Define the model that you wish to test

logging.info('Initializing modelspec(s)...')
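# A hedged sketch of what the demo scripts typically do next (the keyword string
# and the nems.initializers.from_keywords call are assumptions, not taken from
# this snippet):
import nems.initializers

modelspec_name = 'wc.18x1.g-fir.1x15-lvl.1-dexp.1'
modelspec = nems.initializers.from_keywords(modelspec_name)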
Example #14
relative_signals_dir = '../recordings'
relative_modelspecs_dir = '../modelspecs'

# Convert to absolute paths so they can be passed to functions in
# other directories
signals_dir = os.path.abspath(relative_signals_dir)
modelspecs_dir = os.path.abspath(relative_modelspecs_dir)

# ----------------------------------------------------------------------------
# DATA LOADING

# GOAL: Get your data loaded into memory as a Recording object
logging.info('Loading data...')

# Method #1: Load the data from a local directory
rec = Recording.load(os.path.join(signals_dir, 'TAR010c.NAT.fs50.tgz'))

cellid = 'TAR010c-16-2'
rec['resp'] = rec['resp'].extract_channels([cellid])

# Method #2: Load the data from baphy using the (incomplete, TODO) HTTP API:
#URL = "http://potoroo:3004/baphy/271/bbl086b-11-1?rasterfs=200"
#rec = Recording.load_url(URL)

logging.info('Generating state signal...')

rec = preproc.make_state_signal(rec, ['pupil'], [''], 'state')

# ----------------------------------------------------------------------------
# DATA WITHHOLDING
Example #15
sites = np.unique([c[:7] for c in cellids])

# loop over sites (to speed loading)
for site in sites:
    # load recording(s)
    rasterfs = 100
    ops = {
        'batch': batch,
        'siteid': site,
        'rasterfs': rasterfs,
        'pupil': 1,
        'stim': 0,
        'recache': False
    }
    uri = nb.baphy_load_recording_uri(**ops)
    rec100 = Recording.load(uri)
    rec100['resp'] = rec100['resp'].rasterize()
    rec100 = rec100.and_mask(
        ['HIT_TRIAL', 'CORRECT_REJECT_TRIAL', 'PASSIVE_EXPERIMENT'])

    rasterfs = 20
    ops = {
        'batch': batch,
        'siteid': site,
        'rasterfs': rasterfs,
        'pupil': 1,
        'stim': 0,
        'recache': False
    }
    uri = nb.baphy_load_recording_uri(**ops)
    rec20 = Recording.load(uri)
Example #16
File: fit_ccnorm.py  Project: LBHB/NEMS
def fit_pcnorm(modelspec,
               est: recording.Recording,
               metric=None,
               use_modelspec_init: bool = True,
               optimizer: str = 'adam',
               max_iter: int = 10000,
               early_stopping_steps: int = 5,
               tolerance: float = 5e-4,
               learning_rate: float = 1e-4,
               batch_size: typing.Union[None, int] = None,
               seed: int = 0,
               initializer: str = 'random_normal',
               freeze_layers: typing.Union[None, list] = None,
               epoch_name: str = "REFERENCE",
               n_pcs=2,
               **context):
    '''
    Required Arguments:
     est          A recording object
     modelspec     A modelspec object

    Optional Arguments:
     (copied from fit_tf for now)

    Returns
     dictionary: {'modelspec': updated_modelspec}
    '''

    # Hard-coded
    cost_function = basic_cost
    fitter = scipy_minimize
    segmentor = nems.segmentors.use_all_data
    mapper = nems.fitters.mappers.simple_vector
    fit_kwargs = {'tolerance': tolerance, 'max_iter': max_iter}

    start_time = time.time()

    modelspec = copy.deepcopy(modelspec)

    # apply mask to remove invalid portions of signals and allow fit to
    # only evaluate the model on the valid portion of the signals
    if 'mask' in est.signals.keys():
        log.info("Data len pre-mask: %d", est['mask'].shape[1])
        est = est.apply_mask()
        log.info("Data len post-mask: %d", est['mask'].shape[1])

    conditions = [
        "_".join(k.split("_")[1:]) for k in est.signals.keys()
        if k.startswith("mask_")
    ]
    if (len(conditions) > 2) and any(
        [c.split("_")[-1] == 'lg' for c in conditions]):
        conditions.remove("small")
        conditions.remove("large")
    #conditions = conditions[0:2]
    #conditions = ['large','small']

    group_idx = [est['mask_' + c].as_continuous()[0, :] for c in conditions]
    cg_filtered = [(c, g) for c, g in zip(conditions, group_idx)
                   if g.sum() > 0]
    conditions, group_idx = zip(*cg_filtered)

    for c, g in zip(conditions, group_idx):
        log.info(f"Data subset for {c} len {g.sum()}")

    resp = est['resp'].as_continuous()
    pred0 = est['pred0'].as_continuous()
    residual = resp - pred0

    pca = PCA(n_components=n_pcs)
    pca.fit(residual.T)
    pc_axes = pca.components_

    pcproj = residual.T.dot(pc_axes.T).T

    group_pc = [pcproj[:, idx].std(axis=1) for idx in group_idx]
    resp_std = resp.std(axis=1)
    #import pdb; pdb.set_trace()

    if metric is None:
        metric = lambda d: pc_err(d,
                                  pred_name='pred',
                                  pred0_name='pred0',
                                  group_idx=group_idx,
                                  group_pc=group_pc,
                                  pc_axes=pc_axes,
                                  resp_std=resp_std)

    # turn on "fit mode". currently this serves one purpose, for normalization
    # parameters to be re-fit for the output of each module that uses
    # normalization. does nothing if normalization is not being used.
    ms.fit_mode_on(modelspec, est)

    # Create the mapper functions that translates to and from modelspecs.
    # It has three functions that, when defined as mathematical functions, are:
    #    .pack(modelspec) -> fitspace_point
    #    .unpack(fitspace_point) -> modelspec
    #    .bounds(modelspec) -> fitspace_bounds
    packer, unpacker, pack_bounds = mapper(modelspec)

    # A function to evaluate the modelspec on the data
    evaluator = nems.modelspec.evaluate

    my_cost_function = cost_function
    my_cost_function.counter = 0

    # Freeze everything but sigma, since that's all the fitter should be
    # updating.
    cost_fn = partial(my_cost_function,
                      unpacker=unpacker,
                      modelspec=modelspec,
                      data=est,
                      segmentor=segmentor,
                      evaluator=evaluator,
                      metric=metric,
                      display_N=1000)

    # get initial sigma value representing some point in the fit space,
    # and corresponding bounds for each value
    sigma = packer(modelspec)
    bounds = pack_bounds(modelspec)

    # Results should be a list of modelspecs
    # (might only be one in list, but still should be packaged as a list)
    improved_sigma = fitter(sigma, cost_fn, bounds=bounds, **fit_kwargs)
    improved_modelspec = unpacker(improved_sigma)
    elapsed_time = (time.time() - start_time)

    start_err = cost_fn(sigma)
    final_err = cost_fn(improved_sigma)
    log.info("Delta error: %.06f - %.06f = %e", start_err, final_err,
             final_err - start_err)

    # TODO: Should this maybe be moved to a higher level
    # so it applies to ALL the fitters?
    ms.fit_mode_off(improved_modelspec)
    ms.set_modelspec_metadata(improved_modelspec, 'fitter', 'ccnorm')
    ms.set_modelspec_metadata(improved_modelspec, 'fit_time', elapsed_time)
    ms.set_modelspec_metadata(improved_modelspec, 'n_parms',
                              len(improved_sigma))

    return {'modelspec': improved_modelspec.copy(), 'save_context': True}
Example #17
def from_nwb_pupil(nwb_file, nwb_format, fs=20, with_pupil=False, running_speed=False, as_dict=True):
#def from_nwb(cls, nwb_file, nwb_format,with_pupil=False,fs=20):
    """
    The NWB (Neurodata Without Borders) format is a unified data format developed by the Allen Brain Institute.
    Data is stored as an HDF5 file, with the format varying depending how the data was saved.
    
    References:
      - https://nwb.org
      - https://pynwb.readthedocs.io/en/latest/index.html
    :param nwb_file: path to the nwb file
    :param nwb_format: specifier for how the data is saved in the container
    :param int fs: will match for all signals
    :param bool with_pupil, running speed: whether to return pupil, speed signals in recording
    :param bool as_dict: return a dictionary of recording objects, each corresponding to a single unit/neuron
                         else a single recording object w/ each unit corresponding to a channel in pointprocess signal
    :return: a recording object
    """
    #log.info(f'Loading NWB file with format "{nwb_format}" from "{nwb_file}".')

    # add in supported nwb formats here
    assert nwb_format in ['neuropixel'], f'"{nwb_format}" not a supported NWB file format.'

    nwb_filepath = Path(nwb_file)
    if not nwb_filepath.exists():
        raise FileNotFoundError(f'"{nwb_file}" could not be found.')

    if nwb_format == 'neuropixel':
        """
        In neuropixel ecephys nwb files, data is stored in several attributes of the container: 
          - units: individual cell metadata, a dataframe
          - epochs: timing of the stimuli, series of arrays
          - lab_meta_data: metadata about the experiment, such as specimen details
          
        Spike times are saved as arrays in the 'spike_times' column of the units dataframe as xarrays. 
        The frequency defaults to match pupil - if no pupil data retrieved, set to chosen value (previous default 1250).
          
        Refs:
          - https://allensdk.readthedocs.io/en/latest/visual_coding_neuropixels.html
          - https://allensdk.readthedocs.io/en/latest/_static/examples/nb/ecephys_quickstart.html
          - https://allensdk.readthedocs.io/en/latest/_static/examples/nb/ecephys_data_access.html
        """
        try:
            from pynwb import NWBHDF5IO
            from allensdk.brain_observatory.ecephys import nwb  # needed for ecephys format compat
        except ImportError:
            m = 'The "allensdk" library is required to work with neuropixel nwb formats, available on PyPI.'
            #log.error(m)
            raise ImportError(m)

        session_name = nwb_filepath.stem
        with NWBHDF5IO(str(nwb_filepath), 'r') as nwb_io:
            nwbfile = nwb_io.read()

            units = nwbfile.units
            epochs = nwbfile.epochs
           
            spike_times = dict(zip(units.id[:].astype(str), units['spike_times'][:]))

            # extract the metadata and convert to dict
            metadata = nwbfile.lab_meta_data['metadata'].to_dict()
            metadata['uri'] = str(nwb_filepath)  # add in uri
            # add invalid-times data to meta as a DataFrame if present (includes times and probe id)
            if nwbfile.invalid_times is not None:
                invalid_times = nwbfile.invalid_times
                invalid_times = np.array([invalid_times[col][:] for col in invalid_times.colnames])
                metadata['invalid_times'] = pd.DataFrame(invalid_times.transpose(),
                                                         columns=['start_time', 'stop_time', 'tags'])
                
            # build the units metadata
            units_data = {
                col.name: col.data for col in units.columns
                if col.name not in ['spike_times', 'spike_times_index', 'spike_amplitudes',
                                    'spike_amplitudes_index', 'waveform_mean', 'waveform_mean_index']
            }

            units_meta = pd.DataFrame(units_data, index=units.id[:])
            # add electrode info to units meta
            electrodes = nwbfile.electrodes
            e_data = {col.name: col.data for col in electrodes.columns}
            e_meta = pd.DataFrame(e_data, index=electrodes.id[:])
            # merge unit and electrode metadata on peak_channel_id; downstream code needs a dict
            units_meta = pd.merge(units_meta, e_meta,
                                  left_on=units_meta.peak_channel_id, right_index=True,
                                  suffixes=('_unit', '_channel')).drop(['key_0', 'group'], axis=1).to_dict('index')

            # build the epoch dataframe
            epoch_data = {
                col.name: col.data for col in epochs.columns
                if col.name not in ['tags', 'timeseries', 'tags_index', 'timeseries_index']
            }

            epoch_df = pd.DataFrame(epoch_data, index=epochs.id[:]).rename({
                'start_time': 'start',
                'stop_time': 'end',
                'stimulus_name': 'name'
            }, axis='columns')

 
            # rename epochs to correspond to different natural scene/movie frames
            has_frame = epoch_df['frame'].notna()
            epoch_df.loc[has_frame, 'name'] = (epoch_df.loc[has_frame, 'name'] + '_' +
                                               epoch_df.loc[has_frame, 'frame'].astype(int).astype(str))
            
            
            metadata['epochs'] = epoch_df  # save extra stim info to meta
            # drop extra columns, keeping only start/end/name
            epoch_df = epoch_df.drop([col for col in epoch_df.columns
                                      if col not in ['start', 'end', 'name']], axis=1)

            # rename natural scene epochs to work with the demo
            df_copy = epoch_df[epoch_df.name.str.contains('natural_scene')].copy()
            df_copy.loc[:, 'name'] = 'REFERENCE'

            # DataFrame.append was deprecated (and later removed) in pandas; use pd.concat instead
            epoch_df = pd.concat([epoch_df, df_copy], ignore_index=True)
            # expand epoch bounds so epochs overlap, to test evoked potentials
#            to_adjust=epoch_df.loc[:,['start','end']].to_numpy()
#            epoch_df.loc[:,['start','end']] = nems.epoch.adjust_epoch_bounds(to_adjust,-0.1,0.1)
            
            
            # save the spike times as a point-process signal; frequency set to match the other signals
            pp = PointProcess(fs, spike_times, name='resp', recording=session_name, epochs=epoch_df,
                              chans=[str(c) for c in nwbfile.units.id[:]], meta=units_meta)
            #dict to pass to recording
            #signal_dict = {pp.name: pp}

          #  log.info('Successfully loaded nwb file.')
            from scipy.interpolate import interp1d
            # save pupil data as a rasterized signal
            if with_pupil:
                try:
                    pupil = nwbfile.modules['eye_tracking'].data_interfaces['pupil_ellipse_fits']
                    t = pupil['timestamps'][:]
                    pupil = pupil['width'][:].reshape(1, -1)  # only 1 dimension - or get 'height'

                    # interpolate to the requested sampling rate; pupil data starts at
                    # timepoint 0.0, with nan fill before the first sample
                    f = interp1d(t, pupil, bounds_error=False, fill_value=np.nan)
                    new_t = np.arange(0.0, (t.max() + 1 / fs), 1 / fs)
                    pupil = f(new_t)

                    pupil_signal = RasterizedSignal(fs=fs, data=pupil, recording=session_name, name='pupil',
                                                    epochs=epoch_df, chans=['pupil'])  # for all data: list(pupil_data.colnames[0:5])

                # if no pupil data for this session, still return the spike data
                except KeyError:
                    print(session_name + ' has no pupil data.')
                    with_pupil = False  # avoid referencing an undefined pupil_signal below

            
            if running_speed:
                running = nwbfile.modules['running'].data_interfaces['running_speed']
                t = running.timestamps[:][1]  # data has start and end timestamps; only the end is used here
                running = running.data[:].reshape(1, -1)

                f = interp1d(t, running)
                # new_t = np.arange(np.min(t), np.max(t), 1/fs)
                new_t = np.arange(epoch_df.start.min(), epoch_df.end.max(), 1 / fs)
                running = f(new_t)
                running = RasterizedSignal(fs=fs, data=running, name='running', recording=session_name, epochs=epoch_df)



            if as_dict:
                # each unit gets a separate recording in the dict
                rec_dict = {}
                for c in pp.chans:
                    unit_signal = pp.extract_channels([c])
                    rec = Recording({'resp': unit_signal}, meta=metadata)
                    if with_pupil:
                        rec.add_signal(pupil_signal)
                    if running_speed:
                        rec.add_signal(running)
                    rec_dict[c] = rec
                return rec_dict

            else:
                rec = Recording({'resp': pp}, meta=metadata)
                if with_pupil:
                    rec.add_signal(pupil_signal)
                if running_speed:
                    rec.add_signal(running)
                return rec
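
A usage sketch for the loader above; the NWB path below is a hypothetical placeholder, and the rasterize() call on the per-unit point-process signal is assumed from the NEMS signal API.

# hypothetical path to an Allen ecephys NWB session file
nwb_path = '/path/to/ecephys_session.nwb'

# one Recording per unit, with pupil and running-speed signals attached
recs = from_nwb_pupil(nwb_path, 'neuropixel', fs=20,
                      with_pupil=True, running_speed=True, as_dict=True)

for unit_id, unit_rec in recs.items():
    resp = unit_rec['resp'].rasterize()  # PointProcess -> RasterizedSignal
    print(unit_id, resp.as_continuous().shape)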
Example #18
File: test_plots.py  Project: nadoss/NEMS
import os
import random
import numpy as np
from functools import partial
import matplotlib.pyplot as plt
import nems.epoch as ep
import nems.modelspec as ms
import nems.plots.api as nplt
from nems.recording import Recording

# specify directories for loading data and fitted modelspec
signals_dir = '../signals'
#signals_dir = '/home/jacob/auto/data/batch271_fs100_ozgf18/'
modelspecs_dir = '../modelspecs'

# load the data
rec = Recording.load(os.path.join(signals_dir, 'TAR010c-18-1'))

# Add a new signal, respavg, to the recording, in 4 steps

# 1. Fold matrix over all stimuli, returning a dictionary where keys are stimuli
#    and each value in the dictionary is (reps X cell X bins)
epochs_to_extract = ep.epoch_names_matching(rec.epochs, '^STIM_')
folded_matrix = rec['resp'].extract_epochs(epochs_to_extract)

# 2. Average over all reps of each stim and save into dict called psth.
per_stim_psth = dict()
for k in folded_matrix.keys():
    per_stim_psth[k] = np.nanmean(folded_matrix[k], axis=0)

# 3. Invert the folding to unwrap the psth back out into a predicted spike_dict by
# simply replacing all epochs in the signal with their psth
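# 3. (assumed completion; the original example is truncated here) NEMS demo
#    scripts typically do this with the signal's replace_epochs method:
respavg = rec['resp'].rasterize().replace_epochs(per_stim_psth)
respavg.name = 'respavg'
rec.add_signal(respavg)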
Example #19
def generate_state_corrected_psth(batch=None, modelname=None, cellids=None, siteid=None, movement_mask=False,
                                        gain_only=False, dc_only=False, cache_path=None, recache=False):
    """
    Modifies the existing recording so that the psth signal is the prediction specified
    by the modelname. Designed with stategain models in mind. CRH.

    If the model doesn't exist already in /auto/users/hellerc/results/, this
    will go ahead and fit the model and save it in /auto/users/hellerc/results.

    If the fit dir (from xforms) exists, simply reload the result and call this psth.
    """
    if siteid is None:
        raise ValueError("must specify siteid!")
    if cache_path is not None:
        fn = cache_path + siteid + '_{}.tgz'.format(modelname.split('.')[1])
        if gain_only:
            fn = fn.replace('.tgz', '_gonly.tgz')
        if 'mvm' in modelname:
            fn = fn.replace('.tgz', '_mvm.tgz')
        if os.path.isfile(fn) and not recache:
            rec = Recording.load(fn)
            return rec
        else:
            # do the rest of the code
            pass

    if batch is None or modelname is None:
        raise ValueError('Must specify batch and modelname!')
    results_table = nd.get_results_file(batch, modelnames=[modelname])
    preds = []
    ms = []
    for cell in cellids:
        log.info(cell)
        try:
            p = results_table[results_table['cellid']==cell]['modelpath'].values[0]
            if os.path.isdir(p):
                xfspec, ctx = xforms.load_analysis(p)
                preds.append(ctx['val'])
                ms.append(ctx['modelspec'])
            else:
                sys.exit('Fit for {0} does not exist'.format(cell))
        except:
            log.info("WARNING: fit doesn't exist for cell {0}".format(cell))

    # Need to add a check to make sure that the preds are the same length (if
    # multiple cellids). This could be violated if one cell for example existed
    # in a prepassive run but the other didn't and so they were fit differently
    file_epochs = []

    for pr in preds:
        file_epochs += [ep for ep in pr.epochs.name if ep.startswith('FILE')]

    unique_files = np.unique(file_epochs)
    shared_files = []
    for f in unique_files:
        if np.sum([1 for file in file_epochs if file == f]) == len(preds):
            shared_files.append(str(f))
        else:
            # this rawid didn't span all cells at the requested site
            pass

    # mask all file epochs for all preds with the shared file epochs
    # and adjust epochs
    if int(batch) in [307, 294]:
        for i, p in enumerate(preds):
            preds[i] = p.and_mask(shared_files)
            preds[i] = preds[i].apply_mask(reset_epochs=True)

    sigs = {}
    for i, p in enumerate(preds):
        if gain_only:
            # update phi: keep only the first element of each non-gain parameter and
            # zero the rest, so the evaluated prediction reflects gain modulation only
            mspec = ms[i]
            not_gain_keys = [k for k in mspec[0]['phi'].keys() if '_g' not in k]
            for k in not_gain_keys:
                mspec[0]['phi'][k] = np.append(mspec[0]['phi'][k][0, 0],
                                               np.zeros(mspec[0]['phi'][k].shape[-1] - 1))[np.newaxis, :]
            pred = mspec.evaluate(p)['pred']
        elif dc_only:
            # update phi: keep only the first element of each non-DC parameter and
            # zero the rest, so the evaluated prediction reflects DC (offset) modulation only
            mspec = ms[i]
            not_dc_keys = [k for k in mspec[0]['phi'].keys() if '_d' not in k]
            for k in not_dc_keys:
                mspec[0]['phi'][k] = np.append(mspec[0]['phi'][k][0, 0],
                                               np.zeros(mspec[0]['phi'][k].shape[-1] - 1))[np.newaxis, :]
            pred = mspec.evaluate(p)['pred']
        else:
            pred = p['pred'] 
        if i == 0:           
            new_psth_sp = p['psth_sp']
            new_psth = pred
            new_resp = p['resp'].rasterize()

        else:
            try:
                new_psth_sp = new_psth_sp.concatenate_channels([new_psth_sp, p['psth_sp']])
                new_psth = new_psth.concatenate_channels([new_psth, pred])
                new_resp = new_resp.concatenate_channels([new_resp, p['resp'].rasterize()])
            except ValueError:
                import pdb; pdb.set_trace()

    new_pup = preds[0]['pupil']
    sigs['pupil'] = new_pup

    if 'pupil_raw' in preds[0].signals.keys():
        sigs['pupil_raw'] = preds[0]['pupil_raw']

    if 'mask' in preds[0].signals:
        new_mask = preds[0]['mask']
        sigs['mask'] = new_mask
    else:
        mask_rec = preds[0].create_mask(True)
        new_mask = mask_rec['mask']
        sigs['mask'] = new_mask

    if 'rem' in preds[0].signals.keys():
        rem = preds[0]['rem']
        sigs['rem'] = rem

    if 'pupil_eyespeed' in preds[0].signals.keys():
        new_eyespeed = preds[0]['pupil_eyespeed']
        sigs['pupil_eyespeed'] = new_eyespeed

    new_psth_sp.name = 'psth_sp'
    new_psth.name = 'psth'
    new_resp.name = 'resp'
    sigs['psth_sp'] = new_psth_sp
    sigs['psth'] = new_psth
    sigs['resp'] = new_resp

    new_rec = Recording(sigs, meta=preds[0].meta)

    # make sure mask is cast to bool
    new_rec['mask'] = new_rec['mask']._modified_copy(new_rec['mask']._data.astype(bool))

    if cache_path is not None:
        log.info('caching {}'.format(fn))
        new_rec.save_targz(fn)

    return new_rec
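
A hedged call sketch for the function above; the batch, modelname, cellids, siteid, and cache path are placeholders, and the call only works in an environment where the corresponding fits exist in the results database.

# placeholder identifiers, for illustration only
rec = generate_state_corrected_psth(batch=294,
                                    modelname='ns.fs20.pup-ld-st.pup-psthfr_sdexp.SxR_basic',
                                    cellids=['SITE001a-01-1', 'SITE001a-02-1'],
                                    siteid='SITE001a',
                                    cache_path='/tmp/state_psth_cache/',
                                    recache=False)
print(list(rec.signals.keys()))  # expect resp, psth, psth_sp, pupil, mask, ...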
Example #20
# NOTE: fragment from a larger simulation script; resp, psth, pupil, and T (the
# number of time bins) are assumed to be defined above, and `glm` is a
# project-specific module providing the objective function.
import numpy as np
import scipy.optimize as opt
from tqdm import tqdm
from nems.signal import RasterizedSignal
from nems.recording import Recording

# get rid of any "cells" that never fired
idx = (resp==0).sum(axis=-1) == T
resp = resp[~idx, :]
psth = psth[~idx, :]
nCells = resp.shape[0]

# pack into nems recording
resp_sig = RasterizedSignal(fs=4, data=resp, name='resp', recording='simulation')
psth_sig = RasterizedSignal(fs=4, data=psth, name='psth', recording='simulation')
pupil_sig = RasterizedSignal(fs=4, data=pupil, name='pupil', recording='simulation')
bm = pupil > pupil.mean()
big_mask = RasterizedSignal(fs=4, data=bm, name='big_mask', recording='simulation')
sm = pupil < pupil.mean()
small_mask = RasterizedSignal(fs=4, data=sm, name='small_mask', recording='simulation')

rec = Recording({'resp': resp_sig, 'psth': psth_sig, 'pupil': pupil_sig})

# fit the GLM for different hyperparameters
x0 = np.zeros(3 * nCells)
nLV = 1
alpha1 = np.arange(0, 0.5, 0.05)
results = dict.fromkeys(alpha1)
for i, a2 in tqdm(enumerate(alpha1)):
    model_output = opt.minimize(glm.gain_only_objective, x0, (rec, big_mask, small_mask, a2, nLV), options={'gtol':1e-6, 'disp': True})
    weights = model_output.x

    # get model output
    w1 = weights.reshape((2+nLV), nCells)[0, :]
    w2 = weights.reshape((2+nLV), nCells)[1:-1, :]
    b = weights.reshape((2+nLV), nCells)[-1, :]
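    # (assumed continuation; the original fragment ends here) store the unpacked
    # weights and the optimizer's final objective value for this alpha
    results[a2] = {'w1': w1, 'w2': w2, 'b': b, 'fun': model_output.fun}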
Example #21
            keyname = 'data_stream'
            #f_in = io.BytesIO(t.extractfile(member).read())

            # Current non-optimal solution: extract the hdf5 file to disk, then load it from there
            t.extract(member, tpath)
            f = tpath + '/' + member.name

        elif basename.endswith('.json'):
            keyname = 'json_stream'
            f = io.StringIO(t.extractfile(member).read().decode('utf-8'))

        else:
            m = 'Unexpected file found in tar.gz: {} (size={})'.format(
                member.name, member.size)
            raise ValueError(m)
        # Ensure that we can doubly nest the streams dict
        if signame not in streams:
            streams[signame] = {}
        # Read out a stringIO object for each file now while it's open
        #f = io.StringIO(t.extractfile(member).read().decode('utf-8'))
        streams[signame][keyname] = f

# Now that the streams are organized, convert them into signals
# log.debug({k: streams[k].keys() for k in streams})
signals = [load_signal_from_streams(**sg) for sg in streams.values()]
signals_dict = {s.name: s for s in signals}

rec = Recording(signals=signals_dict)

shutil.rmtree(tpath)  # clean up