Example 1
def test_recording_loading():
    '''
    Test the loading and saving of files to various HTTP/S3/File routes.
    '''
    # NOTE: FOR THIS TEST TO SUCCEED, /auto/data/tmp/recordings/ must not have
    # blah.tar.gz or TAR010c-18-1.tar.gz in it.

    # Local filesystem
    # rec0 = Recording.load("/home/ivar/git/nems/signals/TAR010c-18-1.tar.gz")
    rec_path = join(RECORDING_DIR, "eno052d-a1.tgz")
    rec0 = Recording.load(rec_path)
    rec2 = Recording.load("file://%s" % rec_path)

    # HTTP
    # rec3 = Recording.load("http://hyrax.ohsu.edu:3000/recordings/eno052d-a1.tgz")
    # rec4 = Recording.load("http://hyrax.ohsu.edu:3000/baphy/294/eno052d-a1?stim=0&pupil=0")

    # S3
    # Direct access (would need AWS CLI lib? Maybe not best idea!)
    # TODO: Requires s3 credentials in environment. Probably best if on
    #       server only; put in nems_db?
    #rec5 = Recording.load('s3://mybucket/myfile.tar.gz')

    # Indirect access via http:
    rec6 = Recording.load("https://s3-us-west-2.amazonaws.com/nemspublic/"
                          "sample_data/eno052d-a1.tgz")

    # Save to a specific tar.gz file
    rec0.save('/tmp/tmp.tar.gz')
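    # Hedged follow-up sketch (not part of the original test): reload the archive
    # we just wrote to confirm the save/load round trip. Assumes Recording.load
    # accepts a plain filesystem path, as it does for rec0 above.
    rec7 = Recording.load('/tmp/tmp.tar.gz')
    assert sorted(rec7.signals.keys()) == sorted(rec0.signals.keys())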
Example 2
def test_recording_from_arrays():
    # need a list of array-like data structures
    x = np.random.rand(3, 200)
    y = np.random.rand(1, 200)
    z = np.random.rand(5, 200)
    arrays = [x, y, z]
    # a name for the recording that will hold the signals
    rec_name = 'testing123'
    # the sampling rate for the signals, or a list of
    # individual sampling rates (if different)
    fs = [100, 100, 200]
    # a list of signal names (optional, but preferred)
    names = ['stim', 'resp', 'reference']
    # a list of keyword arguments for each signal,
    # such as channel names or epochs (also optional)
    kwargs = [{
        'chans': ['2kHz', '4kHz', '8kHz']
    }, {
        'chans': ['spike_rate']
    }, {
        'meta': {
            'experiment': 'oddball_2'
        },
        'chans': ['One', 'Two', 'Three', 'Four', 'Five']
    }]
    rec = Recording.load_from_arrays(arrays,
                                     rec_name,
                                     fs,
                                     sig_names=names,
                                     signal_kwargs=kwargs)
    # should also work with integer fs instead of list
    rec = Recording.load_from_arrays(arrays,
                                     rec_name,
                                     100,
                                     sig_names=names,
                                     signal_kwargs=kwargs)

    # All signal names should be present in recording signals dict
    contains = [(n in rec.signals.keys()) for n in names]
    assert not (False in contains)

    bad_names = ['stim']
    # should get an error now since len(names)
    # doesn't match up with len(arrays)
    with pytest.raises(ValueError):
        rec = Recording.load_from_arrays(arrays,
                                         rec_name,
                                         fs,
                                         sig_names=bad_names,
                                         signal_kwargs=kwargs)
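
    # Hedged follow-up check (not in the original test): the per-signal kwargs
    # above should be reflected on the loaded signals, e.g. as channel names.
    # `rec` still refers to the successful load_from_arrays call above, since
    # the call inside pytest.raises never completes.
    assert rec['stim'].chans == ['2kHz', '4kHz', '8kHz']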
Example 3
def simple_recording():
    stim = np.random.rand(18, 200)
    resp = np.random.rand(1, 200)
    return Recording.load_from_arrays([stim, resp],
                                      'simple_recording',
                                      100,
                                      sig_names=['stim', 'resp'])
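
# Hedged usage sketch (not in the original source), treating simple_recording as
# a plain helper (any fixture decorator is not shown in this excerpt): signals
# can be pulled out of the returned Recording by name, and the shared sampling
# rate comes straight from the load_from_arrays call.
rec = simple_recording()
assert set(rec.signals.keys()) == {'stim', 'resp'}
assert rec['resp'].fs == 100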
Example 4
File: io.py Project: nadoss/nems_db
def remove_nan(rec):
    i = np.isnan(rec['fg'].as_continuous())
    new_signals = {}
    for name, signal in rec.items():
        new_data = signal.as_continuous()[:, i]
        new_signals[name] = signal._modified_copy(new_data)
    return Recording(new_signals)
Example 5
def split_recording(recording):
    '''
    Split a recording into independent recordings for CPP and CPN, applying the
    split to every signal it contains.
    :param recording: a nems.Recording object
    :return: dict mapping signal type to the corresponding Recording
    '''

    sub_recordings = col.defaultdict(dict)
    metas = dict()
    for signame, signal in recording.signals.items():

        sub_signals = _split_signal(signal)

        for sig_type, sub_signal in sub_signals.items():
            sub_recordings[sig_type][signame] = sub_signal
            metas[sig_type] = sub_signal.meta


    sub_recordings = {sig_type: Recording(signals, meta=metas[sig_type]) for sig_type, signals in
                      sub_recordings.items()}

    return sub_recordings
Example 6
# A demonstration of a minimalist model fitting system

from nems.recording import Recording

# ----------------------------------------------------------------------------
# DATA FETCHING

# GOAL: Get your data loaded into memory

# Method #1: Load the data from a local directory
rec = Recording.load('signals/gus027b13_p_PPS/')

# alternative ways to define the data object that could be saved as a 
# short(!) string in the dataspec
rec = my_create_recording_fun('signals/gus027b13_p_PPS/')
rec = Recording.load_standard_nems_format('signals/gus027b13_p_PPS/')


# Method #2: Load the data from baphy using the (incomplete, TODO) HTTP API:
# URL = "neuralprediction.org:3003/signals?batch=273&cellid=gus027b13-a1"
# rec = fetch_signals_over_http(URL)

# Method #3: Load the data from S3:
# stimfile="https://s3-us-west-2.amazonaws.com/nemspublic/sample_data/"+cellid+"_NAT_stim_ozgf_c18_fs100.mat"
# respfile="https://s3-us-west-2.amazonaws.com/nemspublic/sample_data/"+cellid+"_NAT_resp_fs100.mat"
# rec = fetch_signals_over_http(stimfile, respfile)

# Method #4: Load the data from a jerb (TODO)

# Method #5: Create a Recording object from a matrix, manually (TODO)
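
# Method #5 sketch (an assumption mirroring Example 2 in this listing, not part
# of the original demo): build a Recording directly from in-memory numpy arrays.
import numpy as np

stim = np.random.rand(18, 200)   # channels x time
resp = np.random.rand(1, 200)
rec = Recording.load_from_arrays([stim, resp], 'manual_recording', 100,
                                 sig_names=['stim', 'resp'])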
Example 7
def plot_collapsed_ref_tar(animal, site, cellids=None):
    site += "%"

    sql = "SELECT DISTINCT cellid, rawid, respfile FROM sCellFile WHERE cellid like %s AND runclassid=%s"
    d = nd.pd_query(sql, params=(
        site,
        42,
    ))

    mfile = []
    for f in np.unique(d.respfile):
        f_ = f.split('.')[0]
        mfile.append('/auto/data/daq/{0}/{1}/{2}'.format(
            animal, site[:-2], f_))

    if cellids is None:
        cellid = np.unique(d.cellid).tolist()
    else:
        cellid = cellids
    options = {
        "siteid": site[:-1],
        'cellid': cellids,
        "mfilename": mfile,
        'stim': False,
        'pupil': True,
        'rasterfs': 1000
    }

    uri = nb.baphy_load_recording_uri(**options)
    rec = Recording.load(uri)
    all_pupil = rec['pupil']._data
    ncols = len(mfile)
    if cellids is None:
        cellids = rec['resp'].chans

    for c in cellids:
        f, ax = plt.subplots(1, ncols, sharey=True, figsize=(12, 5))
        ref_base = 0
        for i, mf in enumerate(mfile):
            fn = mf.split('/')[-1]
            ep_mask = [ep for ep in rec.epochs.name if fn in ep]
            R = rec.copy()
            R['resp'] = R['resp'].rasterize(fs=20)
            R['resp'].fs = 20
            R = R.and_mask(ep_mask).apply_mask(reset_epochs=True)
            if '_a_' in fn:
                R = R.and_mask(['HIT_TRIAL']).apply_mask(reset_epochs=True)

            resp = R['resp'].extract_channels([c])

            tar_reps = resp.extract_epoch('TARGET').shape[0]
            tar_m = np.nanmean(resp.extract_epoch('TARGET'),
                               0).squeeze() * R['resp'].fs
            tar_sem = R['resp'].fs * np.nanstd(resp.extract_epoch('TARGET'),
                                               0).squeeze() / np.sqrt(tar_reps)

            ref_reps = resp.extract_epoch('REFERENCE').shape[0]
            ref_m = np.nanmean(resp.extract_epoch('REFERENCE'),
                               0).squeeze() * R['resp'].fs
            ref_sem = R['resp'].fs * np.nanstd(resp.extract_epoch('REFERENCE'),
                                               0).squeeze() / np.sqrt(ref_reps)

            # plot psth's
            time = np.linspace(0, len(tar_m) / R['resp'].fs, len(tar_m))
            ax[i].plot(time, tar_m, color='red', lw=2)
            ax[i].fill_between(time,
                               tar_m + tar_sem,
                               tar_m - tar_sem,
                               color='coral')
            time = np.linspace(0, len(ref_m) / R['resp'].fs, len(ref_m))
            ax[i].plot(time, ref_m, color='blue', lw=2)
            ax[i].fill_between(time,
                               ref_m + ref_sem,
                               ref_m - ref_sem,
                               color='lightblue')

            # set title
            ax[i].set_title(fn, fontsize=8)
            # set labels
            ax[i].set_xlabel('Time (s)')
            ax[i].set_ylabel('Spk / sec')

            # get raster plot baseline
            base = np.max(np.concatenate((tar_m + tar_sem, ref_m + ref_sem)))
            if base > ref_base:
                ref_base = base

        for i, mf in enumerate(mfile):
            # plot the rasters
            fn = mf.split('/')[-1]
            ep_mask = [ep for ep in rec.epochs.name if fn in ep]
            rast_rec = rec.and_mask(ep_mask).apply_mask(reset_epochs=True)
            if '_a_' in fn:
                rast_rec = rast_rec.and_mask(['HIT_TRIAL'
                                              ]).apply_mask(reset_epochs=True)
            rast = rast_rec['resp'].extract_channels([c])

            ref_times = np.where(rast.extract_epoch('REFERENCE').squeeze())
            base = ref_base
            ref_pupil = np.nanmean(
                rast_rec['pupil'].extract_epoch('REFERENCE'), -1)
            xoffset = rast.extract_epoch(
                'TARGET').shape[-1] / rec['resp'].fs + 0.01
            ax[i].plot(ref_pupil / np.max(all_pupil) + xoffset,
                       np.linspace(base, int(base * 2), len(ref_pupil)),
                       color='k')
            ax[i].axvline(xoffset + np.median(all_pupil / np.max(all_pupil)),
                          linestyle='--',
                          color='lightgray')
            #import pdb; pdb.set_trace()
            if ref_times[0].size != 0:
                max_rep = ref_pupil.shape[0] - 1
                ref_locs = ref_times[0] * (base / max_rep)
                ref_locs = ref_locs + base
                ax[i].plot(ref_times[1] / rec['resp'].fs,
                           ref_locs,
                           '|',
                           color='blue',
                           markersize=1)

            tar_times = np.where(rast.extract_epoch('TARGET').squeeze())
            tar_pupil = np.nanmean(rast_rec['pupil'].extract_epoch('TARGET'),
                                   -1)
            tar_base = np.max(ref_locs) + 1
            ax[i].plot(tar_pupil / np.max(all_pupil) + xoffset,
                       np.linspace(tar_base, tar_base + base, len(tar_pupil)),
                       color='k')
            if tar_times[0].size != 0:
                max_rep = tar_pupil.shape[0] - 1
                tar_locs = tar_times[0] * (base / max_rep)
                tar_locs = tar_locs + tar_base
                ax[i].plot(tar_times[1] / rec['resp'].fs,
                           tar_locs,
                           '|',
                           color='red',
                           markersize=1)

            # set ylim
            #ax[i].set_ylim((0, top))

            # set plot aspect
            #asp = np.diff(ax[i].get_xlim())[0] / np.diff(ax[i].get_ylim())[0]
            #ax[i].set_aspect(asp / 2)

        f.suptitle(c, fontsize=8)
        f.tight_layout()
Example 8
def fit_tf(modelspec,
           est: recording.Recording,
           use_modelspec_init: bool = True,
           optimizer: str = 'adam',
           max_iter: int = 10000,
           cost_function: str = 'squared_error',
           early_stopping_steps: int = 5,
           early_stopping_tolerance: float = 5e-4,
           early_stopping_val_split: float = 0,
           learning_rate: float = 1e-4,
           variable_learning_rate: bool = False,
           batch_size: typing.Union[None, int] = None,
           seed: int = 0,
           initializer: str = 'random_normal',
           filepath: typing.Union[str, Path] = None,
           freeze_layers: typing.Union[None, list] = None,
           IsReload: bool = False,
           epoch_name: str = "REFERENCE",
           use_tensorboard: bool = False,
           kernel_regularizer: str = None,
           **context) -> dict:
    """TODO

    :param est:
    :param modelspec:
    :param use_modelspec_init:
    :param optimizer:
    :param max_iter:
    :param cost_function:
    :param early_stopping_steps:
    :param early_stopping_tolerance:
    :param learning_rate:
    :param batch_size:
    :param seed:
    :param filepath:
    :param freeze_layers: Indexes of layers to freeze prior to training. Indexes are modelspec indexes, so are offset
      from model layer indexes.
    :param IsReload:
    :param epoch_name:
    :param context:

    :return: dict {'modelspec': modelspec}
    """

    if IsReload:
        return {}

    tf.random.set_seed(seed)
    np.random.seed(seed)
    #os.environ['TF_DETERMINISTIC_OPS'] = '1'   # makes output deterministic, but reduces prediction accuracy

    log.info('Building tensorflow keras model from modelspec.')
    nems.utils.progress_fun()

    # figure out where to save model checkpoints
    job_id = os.environ.get('SLURM_JOBID', None)
    if job_id is not None:
        # if job is running on slurm, need to change model checkpoint dir
        # keep a record of the job id
        modelspec.meta['slurm_jobid'] = job_id

        log_dir_root = Path('/mnt/scratch')
        assert log_dir_root.exists()
        log_dir_base = log_dir_root / Path('SLURM_JOBID' + job_id)
        log_dir_sub = Path(str(modelspec.meta['batch'])) \
                / modelspec.meta.get('cellid', "NOCELL") \
                / modelspec.get_longname()
        filepath = log_dir_base / log_dir_sub
        tbroot = filepath / 'logs'
    elif filepath is None:
        filepath = modelspec.meta['modelpath']
        tbroot = Path(f'/auto/data/tmp/tensorboard/')
    else:
        tbroot = Path(f'/auto/data/tmp/tensorboard/')

    filepath = Path(filepath)
    if not filepath.exists():
        filepath.mkdir(exist_ok=True, parents=True)
    cellid = modelspec.meta.get('cellid', 'CELL')
    tbpath = tbroot / (str(modelspec.meta['batch']) + '_' + cellid + '_' +
                       modelspec.meta['modelname'])
    # TODO: should this code just be deleted then?
    if 0 & use_tensorboard:
        # disabled, this is dumb. it deletes the previous round of fitting (eg, tfinit)
        fileList = glob.glob(str(tbpath / '*' / '*'))
        for filePath in fileList:
            try:
                os.remove(filePath)
            except:
                print("Error while deleting file : ", filePath)

    checkpoint_filepath = filepath / 'weights.hdf5'
    tensorboard_filepath = tbpath
    gradient_filepath = filepath / 'gradients'

    # update seed based on fit index
    seed += modelspec.fit_index

    if (freeze_layers
            is not None) and len(freeze_layers) and (len(freeze_layers)
                                                     == freeze_layers[-1] + 1):
        truncate_model = True
        modelspec_trunc, est_trunc = \
            initializers.modelspec_remove_input_layers(modelspec, est, remove_count=len(freeze_layers))
        modelspec_original = modelspec
        est_original = est
        modelspec = modelspec_trunc
        est = est_trunc
        freeze_layers = None
        log.info(
            f"Special case of freezing: truncating model. fit_index={modelspec.fit_index} cell_index={modelspec.cell_index}"
        )
    else:
        truncate_model = False

    input_name = modelspec.meta.get('input_name', 'stim')
    output_name = modelspec.meta.get('output_name', 'resp')

    # also grab the fs
    fs = est[input_name].fs

    if (epoch_name is not None) and (epoch_name != ""):
        # extract out the raw data, and reshape to (batch, time, channel)
        stim_train = np.transpose(
            est[input_name].extract_epoch(epoch=epoch_name, mask=est['mask']),
            [0, 2, 1])
        resp_train = np.transpose(
            est[output_name].extract_epoch(epoch=epoch_name, mask=est['mask']),
            [0, 2, 1])
    else:
        # extract data as a single batch size (1, time, channel)
        stim_train = np.transpose(
            est.apply_mask()[input_name].as_continuous()[np.newaxis, ...],
            [0, 2, 1])
        resp_train = np.transpose(
            est.apply_mask()[output_name].as_continuous()[np.newaxis, ...],
            [0, 2, 1])

    log.info(
        f'Feature dimensions: {stim_train.shape}; Data dimensions: {resp_train.shape}.'
    )

    if True:
        log.info("adding a tiny bit of noise to resp_train")
        resp_train = resp_train + np.random.randn(*resp_train.shape) / 10000
    # get state if present, and setup training data
    if 'state' in est.signals:
        if (epoch_name is not None) and (epoch_name != ""):
            state_train = np.transpose(
                est['state'].extract_epoch(epoch=epoch_name, mask=est['mask']),
                [0, 2, 1])
        else:
            state_train = np.transpose(
                est.apply_mask()['state'].as_continuous()[np.newaxis, ...],
                [0, 2, 1])
        state_shape = state_train.shape
        log.info(f'State dimensions: {state_shape}')
        train_data = [stim_train, state_train]
    else:
        state_train, state_shape = None, None
        train_data = stim_train

    # get the layers and build the model
    cost_fn = loss_functions.get_loss_fn(cost_function)
    #model_layers = modelspec.modelspec2tf2(
    #    use_modelspec_init=use_modelspec_init, seed=seed, fs=fs,
    #    initializer=initializer, freeze_layers=freeze_layers,
    #    kernel_regularizer=kernel_regularizer)
    model_layers = modelbuilder.modelspec2tf(
        modelspec,
        use_modelspec_init=use_modelspec_init,
        seed=seed,
        fs=fs,
        initializer=initializer,
        freeze_layers=freeze_layers,
        kernel_regularizer=kernel_regularizer)

    if np.any([isinstance(layer, Conv2D_NEMS) for layer in model_layers]):
        # need a "channel" dimension for Conv2D (like rgb channels, not frequency). Only 1 channel for our data.
        stim_train = stim_train[..., np.newaxis]
        train_data = train_data[..., np.newaxis]

    # do some batch sizing logic
    batch_size = stim_train.shape[0] if batch_size == 0 else batch_size

    if variable_learning_rate:
        # TODO: allow other schedule options instead of hard-coding exp decay?
        # TODO: expose exp decay kwargs as kw options? not clear how to choose these parameters
        learning_rate = tf.keras.optimizers.schedules.ExponentialDecay(
            initial_learning_rate=learning_rate,
            decay_steps=10000,
            decay_rate=0.9)

    from nems.tf.loss_functions import pearson
    model = modelbuilder.ModelBuilder(
        name='Test-model',
        layers=model_layers,
        learning_rate=learning_rate,
        loss_fn=cost_fn,
        optimizer=optimizer,
        metrics=[pearson],
    ).build_model(input_shape=stim_train.shape,
                  state_shape=state_shape,
                  batch_size=batch_size)

    if freeze_layers is not None:
        for freeze_index in freeze_layers:
            log.info(
                f'TF layer #{freeze_index}: "{model.layers[freeze_index + 1].name}" is not trainable.'
            )

    # tracking early termination
    model.early_terminated = False

    # create the callbacks
    early_stopping = callbacks.DelayedStopper(
        monitor='val_loss',
        patience=30 * early_stopping_steps,
        min_delta=early_stopping_tolerance,
        verbose=1,
        restore_best_weights=True)
    regular_stopping = callbacks.DelayedStopper(
        monitor='loss',
        patience=30 * early_stopping_steps,
        min_delta=early_stopping_tolerance,
        verbose=1,
        restore_best_weights=True)
    checkpoint = tf.keras.callbacks.ModelCheckpoint(
        filepath=str(checkpoint_filepath),
        save_best_only=False,
        save_weights_only=True,
        save_freq=100 * stim_train.shape[0],
        monitor='loss',
        verbose=0)
    sparse_logger = callbacks.SparseProgbarLogger(n_iters=50)
    nan_terminate = tf.keras.callbacks.TerminateOnNaN()
    nan_weight_terminate = callbacks.TerminateOnNaNWeights()
    tensorboard = tf.keras.callbacks.TensorBoard(
        log_dir=str(tensorboard_filepath),  # TODO: generic tensorboard dir?
        histogram_freq=0,  # record the distribution of the weights
        write_graph=False,
        update_freq='epoch',
        profile_batch=0)
    # gradient_logger = callbacks.GradientLogger(filepath=str(gradient_filepath),
    #                                            train_input=stim_train,
    #                                            model=model)

    # save an initial set of weights before freezing, in case of termination before any checkpoints
    #log.info('saving weights to : %s', str(checkpoint_filepath) )
    model.save_weights(str(checkpoint_filepath), overwrite=True)

    if version.parse(tf.__version__) >= version.parse("2.2.0"):
        callback0 = [sparse_logger]
        verbose = 0
    else:
        callback0 = []
        verbose = 2
    # enable the below to log tracked parameters to tensorboard
    if use_tensorboard:
        callback0.append(tensorboard)
        log.info(f'Enabling tensorboard, log: {str(tensorboard_filepath)}')
        # enable the below to record gradients to visualize in tensorboard; this is very slow,
        # and loading all this into tensorboard can use A LOT of memory
        # callback0.append(gradient_logger)

    if early_stopping_val_split > 0:
        callback0.append(early_stopping)
        log.info(
            f'Enabling early stopping, val split: {str(early_stopping_val_split)}'
        )
    else:
        callback0.append(regular_stopping)
        log.info(f'Stop tolerance: min_delta={early_stopping_tolerance}')

    log.info(f'Fitting model (batch_size={batch_size})...')
    history = model.fit(train_data,
                        resp_train,
                        validation_split=early_stopping_val_split,
                        verbose=verbose,
                        epochs=max_iter,
                        callbacks=callback0 + [
                            nan_terminate,
                            nan_weight_terminate,
                            checkpoint,
                        ],
                        batch_size=batch_size)

    # did we terminate on a nan loss or weights? Load checkpoint if so
    if np.all(np.isnan(model.predict(train_data))
              ) or model.early_terminated:  # TODO: should this be np.any()?
        log.warning(
            'Model terminated on nan loss or weights, restoring saved weights.'
        )
        try:
            # this can fail if it nans out before a single checkpoint gets saved, either because no saved weights
            # exist, or it tries to load a different model from the init
            model.load_weights(str(checkpoint_filepath))
            log.warning('Reloaded previous saved weights after nan loss.')
        except (tf.errors.NotFoundError, ValueError):
            pass

    modelspec = tf2modelspec(model, modelspec)

    if truncate_model:
        log.info("Special case of freezing: restoring truncated model!!!")
        #modelspec_restored, rec_restored = modelspec_restore_input_layers(modelspec_trunc, rec_trunc, modelspec_original)
        modelspec_restored, est_restored = initializers.modelspec_restore_input_layers(
            modelspec, est, modelspec_original)
        est = est_original
        modelspec = modelspec_restored

    # debug: dump modelspec parameters
    #for i in range(len(modelspec)):
    #    log.info(modelspec.phi[i])

    contains_tf_only_layers = np.any(
        ['tf_only' in m['fn'] for m in modelspec.modules])
    if not contains_tf_only_layers:
        # compare the predictions from the model and modelspec
        error = compare_ms_tf(modelspec, model, est, train_data)
        if error > 1e-5:
            log.warning(
                f'Mean difference between NEMS and TF model prediction: {error}'
            )
        else:
            log.info(
                f'Mean difference between NEMS and TF model prediction: {error}'
            )
    else:
        # nothing to compare, ms evaluation is not implemented for this type of model
        pass

    # add in some relevant meta information
    modelspec.meta['n_parms'] = len(modelspec.phi_vector)
    try:
        n_epochs = len(history.history['loss'])
        if 'val_loss' in history.history.keys():
            #val_stop = np.argmin(history.history['val_loss'])
            #loss = history.history['loss'][val_stop]
            loss = np.nanmin(history.history['val_loss'])
        else:
            loss = np.nanmin(history.history['loss'])

    except KeyError:
        n_epochs = 0
        loss = 0
    if modelspec.fit_count == 1:
        modelspec.meta['n_epochs'] = n_epochs
        modelspec.meta['loss'] = loss
    else:
        if modelspec.fit_index == 0:
            modelspec.meta['n_epochs'] = np.zeros(modelspec.fit_count)
            modelspec.meta['loss'] = np.zeros(modelspec.fit_count)
        modelspec.meta['n_epochs'][modelspec.fit_index] = n_epochs
        modelspec.meta['loss'][modelspec.fit_index] = loss

    try:
        max_iter = modelspec.meta['extra_results']
        modelspec.meta['extra_results'] = max(max_iter, n_epochs)
    except KeyError:
        modelspec.meta['extra_results'] = n_epochs

    nems.utils.progress_fun()

    # clean up temp files
    if job_id is not None:
        log.info('removing temporary weights file(s)')
        shutil.rmtree(log_dir_base)

    return {'modelspec': modelspec}
Example 9
site = 'TAR010c'
nPCs = 4
batch = 307
fs = 20
rawid = pu.which_rawids(site)
ops = {
    'batch': batch,
    'pupil': 1,
    'rasterfs': fs,
    'siteid': site,
    'stim': 0,
    'rawid': rawid
}
uri = nb.baphy_load_recording_uri(**ops)
rec = Recording.load(uri)
rec['resp'] = rec['resp'].rasterize()
rec = rec.and_mask(['HIT_TRIAL', 'MISS_TRIAL', 'PASSIVE_EXPERIMENT'])

rec = rec.and_mask(['PreStimSilence', 'PostStimSilence'], invert=True)
rec = rec.apply_mask(reset_epochs=True)

# create four state masks
rec = preproc.create_ptd_masks(rec)

# create the appropriate dictionaries of responses
# don't *think* it's so important to balance reps for
# this analysis since it's parametric (as opposed to the
# discrimination calc which depends greatly on being balanced)

# get list of unique targets present in both a/p. This is for
Example 10
def recording(signal1, signal2):
    signals = {signal1.name: signal1, signal2.name: signal2}
    return Recording(signals)
Example 11
def fit_ccnorm(modelspec,
               est: recording.Recording,
               metric=None,
               use_modelspec_init: bool = True,
               optimizer: str = 'adam',
               max_iter: int = 10000,
               early_stopping_steps: int = 5,
               tolerance: float = 5e-4,
               learning_rate: float = 1e-4,
               batch_size: typing.Union[None, int] = None,
               seed: int = 0,
               initializer: str = 'random_normal',
               freeze_layers: typing.Union[None, list] = None,
               epoch_name: str = "REFERENCE",
               shrink_cc: float = 0,
               noise_pcs: int = 0,
               shared_pcs: int = 0,
               also_fit_resp: bool = False,
               force_psth: bool = False,
               use_metric: typing.Union[None, str] = None,
               alpha: float = 0.1,
               beta: float = 1,
               exclude_idx=None,
               exclude_after=None,
               freeze_idx=None,
               freeze_after=None,
               **context):
    '''
    Required Arguments:
     est          A recording object
     modelspec     A modelspec object

    Optional Arguments:
     (copied from fit_tf for now)

    Returns
     dictionary: {'modelspec': updated_modelspec}
    '''

    # Hard-coded
    cost_function = basic_cost
    fitter = scipy_minimize
    segmentor = nems.segmentors.use_all_data
    mapper = nems.fitters.mappers.simple_vector
    fit_kwargs = {'tolerance': tolerance, 'max_iter': max_iter}

    start_time = time.time()

    fit_index = modelspec.fit_index
    if (exclude_idx is not None) | (freeze_idx is not None) | \
            (exclude_after is not None) | (freeze_after is not None):
        modelspec0 = modelspec.copy()
        modelspec, include_set = modelspec_freeze_layers(
            modelspec,
            include_idx=None,
            exclude_idx=exclude_idx,
            exclude_after=exclude_after,
            freeze_idx=freeze_idx,
            freeze_after=freeze_after)
        modelspec0.set_fit(fit_index)
        modelspec.set_fit(fit_index)
    else:
        include_set = None

    # Computing PCs before masking out unwanted stimuli in order to
    # preserve match with epochs
    epoch_regex = "^STIM_"
    stims = (est.epochs['name'].value_counts() >= 8)
    stims = [
        stims.index[i] for i, s in enumerate(stims)
        if bool(re.search(epoch_regex, stims.index[i])) and s == True
    ]

    Rall_u = est.apply_mask()['psth'].as_continuous().T
    # can't simply extract evoked responses for refs because they can be longer/shorter
    # if they came after a target and/or were the last stim, so masking prestim/poststim
    # doesn't work. Do it manually.
    #d = est['resp'].extract_epochs(stims, mask=est['mask'])

    #R = [v.mean(axis=0) for (k, v) in d.items()]
    #R = [np.reshape(np.transpose(v,[1,0,2]),[v.shape[1],-1]) for (k, v) in d.items()]
    #Rall_u = np.hstack(R).T

    pca = PCA(n_components=2)
    pca.fit(Rall_u)
    pc_axes = pca.components_

    # apply mask to remove invalid portions of signals and allow fit to
    # only evaluate the model on the valid portion of the signals
    if 'mask_small' in est.signals.keys():
        log.info('resetting mask with mask_small+mask_large subset')
        est['mask'] = est['mask']._modified_copy(data=est['mask_small']._data +
                                                 est['mask_large']._data)

    if 'mask' in est.signals.keys():
        log.info("Data len pre-mask: %d", est['mask'].shape[1])
        est = est.apply_mask()
        log.info("Data len post-mask: %d", est['mask'].shape[1])

    # if we want to fit to first-order cc error.
    #uncomment this and make sure sdexp is generating a pred0 signal
    est = modelspec.evaluate(est, stop=2)
    if ('pred0' in est.signals.keys()) & (not force_psth):
        input_name = 'pred0'
        log.info('Found pred0 for fitting CC')
    else:
        input_name = 'psth'
        log.info('No pred0, using psth for fitting CC')

    conditions = [
        "_".join(k.split("_")[1:]) for k in est.signals.keys()
        if k.startswith("mask_")
    ]
    if (len(conditions) > 2) and any(
        [c.split("_")[-1] == 'lg' for c in conditions]):
        conditions.remove("small")
        conditions.remove("large")
    #conditions = conditions[0:2]
    #conditions = ['large','small']

    group_idx = [est['mask_' + c].as_continuous()[0, :] for c in conditions]
    cg_filtered = [(c, g) for c, g in zip(conditions, group_idx)
                   if g.sum() > 0]
    conditions, group_idx = zip(*cg_filtered)

    for c, g in zip(conditions, group_idx):
        log.info(f"cc data for {c} len {g.sum()}")

    resp = est['resp'].as_continuous()
    pred0 = est[input_name].as_continuous()
    #import pdb; pdb.set_trace()
    if shrink_cc > 0:
        log.info(f'cc approx: shrink_cc={shrink_cc}')
        group_cc = [
            cc_shrink(resp[:, idx] - pred0[:, idx], sigrat=shrink_cc)
            for idx in group_idx
        ]
    elif shared_pcs > 0:
        log.info(f'cc approx: shared_pcs={shared_pcs}')
        cc = np.cov(resp - pred0)
        u, s, vh = np.linalg.svd(cc)
        U = u[:, :shared_pcs] @ u[:, :shared_pcs].T

        group_cc = [
            cc_shared_space(resp[:, idx] - pred0[:, idx], U)
            for idx in group_idx
        ]
    elif noise_pcs > 0:
        log.info(f'cc approx: noise_pcs={noise_pcs}')
        group_cc = [
            cc_lowrank(resp[:, idx] - pred0[:, idx], n_pcs=noise_pcs)
            for idx in group_idx
        ]
    else:
        group_cc = [np.cov(resp[:, idx] - pred0[:, idx]) for idx in group_idx]
    group_cc_raw = [np.cov(resp[:, idx] - pred0[:, idx]) for idx in group_idx]

    # variance of projection onto PCs (PCs computed above before masking)
    pcproj0 = (resp - pred0).T.dot(pc_axes.T).T
    pcproj_std = pcproj0.std(axis=1)

    if (use_metric == 'cc_err_w'):

        def metric(d, verbose=False):
            return metrics.cc_err_w(d,
                                    pred_name='pred',
                                    pred0_name=input_name,
                                    group_idx=group_idx,
                                    group_cc=group_cc,
                                    alpha=alpha,
                                    pcproj_std=None,
                                    pc_axes=None,
                                    verbose=verbose)

        log.info(f"fit_ccnorm metric: cc_err_w (alpha={alpha})")

    elif (metric is None) and also_fit_resp:
        log.info(f"resp_cc_err: pred0_name: {input_name} beta: {beta}")
        metric = lambda d: metrics.resp_cc_err(d,
                                               pred_name='pred',
                                               pred0_name=input_name,
                                               group_idx=group_idx,
                                               group_cc=group_cc,
                                               beta=beta,
                                               pcproj_std=None,
                                               pc_axes=None)

    elif (use_metric == 'cc_err_md'):

        def metric(d, verbose=False):
            return metrics.cc_err_md(d,
                                     pred_name='pred',
                                     pred0_name=input_name,
                                     group_idx=group_idx,
                                     group_cc=group_cc,
                                     pcproj_std=None,
                                     pc_axes=None)

        log.info(f"fit_ccnorm metric: cc_err_md")

    elif (metric is None):
        #def cc_err(result, pred_name='pred_lv', resp_name='resp', pred0_name='pred',
        #   group_idx=None, group_cc=None, pcproj_std=None, pc_axes=None):
        # current implementation of cc_err
        metric = lambda d: metrics.cc_err(d,
                                          pred_name='pred',
                                          pred0_name=input_name,
                                          group_idx=group_idx,
                                          group_cc=group_cc,
                                          pcproj_std=None,
                                          pc_axes=None)
        log.info(f"fit_ccnorm metric: cc_err")

    # turn on "fit mode". currently this serves one purpose, for normalization
    # parameters to be re-fit for the output of each module that uses
    # normalization. does nothing if normalization is not being used.
    ms.fit_mode_on(modelspec, est)

    # Create the mapper functions that translates to and from modelspecs.
    # It has three functions that, when defined as mathematical functions, are:
    #    .pack(modelspec) -> fitspace_point
    #    .unpack(fitspace_point) -> modelspec
    #    .bounds(modelspec) -> fitspace_bounds
    packer, unpacker, pack_bounds = mapper(modelspec)

    # A function to evaluate the modelspec on the data
    evaluator = nems.modelspec.evaluate

    my_cost_function = cost_function
    my_cost_function.counter = 0

    # Freeze everything but sigma, since that's all the fitter should be
    # updating.
    cost_fn = partial(my_cost_function,
                      unpacker=unpacker,
                      modelspec=modelspec,
                      data=est,
                      segmentor=segmentor,
                      evaluator=evaluator,
                      metric=metric,
                      display_N=1000)

    # get initial sigma value representing some point in the fit space,
    # and corresponding bounds for each value
    sigma = packer(modelspec)
    bounds = pack_bounds(modelspec)

    # Results should be a list of modelspecs
    # (might only be one in list, but still should be packaged as a list)
    improved_sigma = fitter(sigma, cost_fn, bounds=bounds, **fit_kwargs)
    improved_modelspec = unpacker(improved_sigma)
    elapsed_time = (time.time() - start_time)

    start_err = cost_fn(sigma)
    final_err = cost_fn(improved_sigma)
    log.info("Delta error: %.06f - %.06f = %e", start_err, final_err,
             final_err - start_err)

    # TODO: Should this maybe be moved to a higher level
    # so it applies to ALL the fitters?
    ms.fit_mode_off(improved_modelspec)

    if include_set is not None:
        # pull out updated phi values from improved_modelspec, include_set only
        improved_modelspec = \
            modelspec_unfreeze_layers(improved_modelspec, modelspec0, include_set)
        improved_modelspec.set_fit(fit_index)

    log.info(
        f"Updating improved modelspec with fit_idx={improved_modelspec.fit_index}"
    )
    improved_modelspec.meta['fitter'] = 'ccnorm'
    improved_modelspec.meta['n_parms'] = len(improved_sigma)
    if modelspec.fit_count == 1:
        improved_modelspec.meta['fit_time'] = elapsed_time
        improved_modelspec.meta['loss'] = final_err
    else:
        if modelspec.fit_index == 0:
            improved_modelspec.meta['fit_time'] = np.zeros(
                improved_modelspec.fit_count)
            improved_modelspec.meta['loss'] = np.zeros(
                improved_modelspec.fit_count)
        improved_modelspec.meta['fit_time'][fit_index] = elapsed_time
        improved_modelspec.meta['loss'][fit_index] = final_err

    return {'modelspec': improved_modelspec}
Example 12
def fit_tf(modelspec,
           est: recording.Recording,
           use_modelspec_init: bool = True,
           optimizer: str = 'adam',
           max_iter: int = 10000,
           cost_function: str = 'squared_error',
           early_stopping_steps: int = 5,
           early_stopping_tolerance: float = 5e-4,
           learning_rate: float = 1e-4,
           batch_size: typing.Union[None, int] = None,
           seed: int = 0,
           initializer: str = 'random_normal',
           filepath: typing.Union[str, Path] = None,
           freeze_layers: typing.Union[None, list] = None,
           IsReload: bool = False,
           epoch_name: str = "REFERENCE",
           **context) -> dict:
    """TODO

    :param est:
    :param modelspec:
    :param use_modelspec_init:
    :param optimizer:
    :param max_iter:
    :param cost_function:
    :param early_stopping_steps:
    :param early_stopping_tolerance:
    :param learning_rate:
    :param batch_size:
    :param seed:
    :param filepath:
    :param freeze_layers: Indexes of layers to freeze prior to training. Indexes are modelspec indexes, so are offset
      from model layer indexes.
    :param IsReload:
    :param epoch_name:
    :param context:

    :return: dict {'modelspec': modelspec}
    """

    if IsReload:
        return {}

    tf.random.set_seed(seed)
    np.random.seed(seed)
    #os.environ['TF_DETERMINISTIC_OPS'] = '1'   # makes output deterministic, but reduces prediction accuracy

    log.info('Building tensorflow keras model from modelspec.')
    nems.utils.progress_fun()

    # figure out where to save model checkpoints
    if filepath is None:
        filepath = modelspec.meta['modelpath']

    # if job is running on slurm, need to change model checkpoint dir
    job_id = os.environ.get('SLURM_JOBID', None)
    if job_id is not None:
        # keep a record of the job id
        modelspec.meta['slurm_jobid'] = job_id

        log_dir_root = Path('/mnt/scratch')
        assert log_dir_root.exists()
        log_dir_sub = Path('SLURM_JOBID' + job_id) / str(modelspec.meta['batch'])\
                      / modelspec.meta.get('cellid', "NOCELL")\
                      / modelspec.meta['modelname']
        filepath = log_dir_root / log_dir_sub

    filepath = Path(filepath)
    if not filepath.exists():
        filepath.mkdir(exist_ok=True, parents=True)

    checkpoint_filepath = filepath / 'weights.hdf5'
    tensorboard_filepath = filepath / 'logs'
    gradient_filepath = filepath / 'gradients'

    # update seed based on fit index
    seed += modelspec.fit_index

    # need to get duration of stims in order to reshape data
    #epoch_name = 'REFERENCE'  # TODO: this should not be hardcoded
    # moved to input parameter

    input_name = modelspec.meta.get('input_name', 'stim')
    output_name = modelspec.meta.get('output_name', 'resp')

    # also grab the fs
    fs = est[input_name].fs

    if (epoch_name is not None) and (epoch_name != ""):
        # extract out the raw data, and reshape to (batch, time, channel)
        stim_train = np.transpose(
            est[input_name].extract_epoch(epoch=epoch_name, mask=est['mask']),
            [0, 2, 1])
        resp_train = np.transpose(
            est[output_name].extract_epoch(epoch=epoch_name, mask=est['mask']),
            [0, 2, 1])
    else:
        # extract data as a single batch size (1, time, channel)
        stim_train = np.transpose(
            est.apply_mask()[input_name].as_continuous()[np.newaxis, ...],
            [0, 2, 1])
        resp_train = np.transpose(
            est.apply_mask()[output_name].as_continuous()[np.newaxis, ...],
            [0, 2, 1])

    log.info(
        f'Feature dimensions: {stim_train.shape}; Data dimensions: {resp_train.shape}.'
    )

    # get state if present, and setup training data
    if 'state' in est.signals:
        if (epoch_name is not None) and (epoch_name != ""):
            state_train = np.transpose(
                est['state'].extract_epoch(epoch=epoch_name, mask=est['mask']),
                [0, 2, 1])
        else:
            state_train = np.transpose(
                est.apply_mask()['state'].as_continuous()[np.newaxis, ...],
                [0, 2, 1])
        state_shape = state_train.shape
        log.info(f'State dimensions: {state_shape}')
        train_data = [stim_train, state_train]
    else:
        state_train, state_shape = None, None
        train_data = stim_train

    # correlation for monitoring
    # TODO: tf.utils?
    def pearson(y_true, y_pred):
        return tfp.stats.correlation(y_true,
                                     y_pred,
                                     event_axis=None,
                                     sample_axis=None)

    # get the layers and build the model
    cost_fn = loss_functions.get_loss_fn(cost_function)
    model_layers = modelspec.modelspec2tf2(
        use_modelspec_init=use_modelspec_init,
        seed=seed,
        fs=fs,
        initializer=initializer)
    if np.any([isinstance(layer, Conv2D_NEMS) for layer in model_layers]):
        # need a "channel" dimension for Conv2D (like rgb channels, not frequency). Only 1 channel for our data.
        stim_train = stim_train[..., np.newaxis]
        train_data = train_data[..., np.newaxis]

    # do some batch sizing logic
    batch_size = stim_train.shape[0] if batch_size == 0 else batch_size

    model = modelbuilder.ModelBuilder(
        name='Test-model',
        layers=model_layers,
        learning_rate=learning_rate,
        loss_fn=cost_fn,
        optimizer=optimizer,
        metrics=[pearson],
    ).build_model(input_shape=stim_train.shape,
                  state_shape=state_shape,
                  batch_size=batch_size)

    # tracking early termination
    model.early_terminated = False

    # create the callbacks
    early_stopping = callbacks.DelayedStopper(
        monitor='loss',
        patience=30 * early_stopping_steps,
        min_delta=early_stopping_tolerance,
        verbose=1,
        restore_best_weights=False)
    checkpoint = tf.keras.callbacks.ModelCheckpoint(
        filepath=str(checkpoint_filepath),
        save_best_only=False,
        save_weights_only=True,
        save_freq=100 * stim_train.shape[0],
        monitor='loss',
        verbose=0)
    sparse_logger = callbacks.SparseProgbarLogger(n_iters=10)
    nan_terminate = tf.keras.callbacks.TerminateOnNaN()
    nan_weight_terminate = callbacks.TerminateOnNaNWeights()
    tensorboard = tf.keras.callbacks.TensorBoard(
        log_dir=str(tensorboard_filepath),  # TODO: generic tensorboard dir?
        histogram_freq=0,  # record the distribution of the weights
        write_graph=False,
        update_freq='epoch',
        profile_batch=0)
    # gradient_logger = callbacks.GradientLogger(filepath=str(gradient_filepath),
    #                                            train_input=stim_train,
    #                                            model=model)

    # freeze layers
    if freeze_layers is not None:
        for freeze_index in freeze_layers:
            log.info(
                f'Freezing layer #{freeze_index}: "{model.layers[freeze_index + 1].name}".'
            )
            model.layers[freeze_index + 1].trainable = False

    # save an initial set of weights before freezing, in case of termination before any checkpoints
    #log.info('saving weights to : %s', str(checkpoint_filepath) )
    model.save_weights(str(checkpoint_filepath), overwrite=True)

    if version.parse(tf.__version__) >= version.parse("2.2.0"):
        callback0 = [sparse_logger]
        verbose = 0
    else:
        callback0 = []
        verbose = 2

    log.info(f'Fitting model (batch_size={batch_size})...')
    history = model.fit(
        train_data,
        resp_train,
        # validation_split=0.2,
        verbose=verbose,
        epochs=max_iter,
        batch_size=batch_size,
        callbacks=callback0 + [
            nan_terminate,
            nan_weight_terminate,
            early_stopping,
            checkpoint,
            # enable the below to log tracked parameters to tensorboard
            # tensorboard,
            # enable the below to record gradients to visualize in tensorboard; this is very slow,
            # and loading all this into tensorboard can use A LOT of memory
            # gradient_logger,
        ])

    # did we terminate on a nan loss or weights? Load checkpoint if so
    if np.all(np.isnan(model.predict(train_data))
              ) or model.early_terminated:  # TODO: should this be np.any()?
        log.warning(
            'Model terminated on nan loss or weights, restoring saved weights.'
        )
        try:
            # this can fail if it nans out before a single checkpoint gets saved, either because no saved weights
            # exist, or it tries to load a different model from the init
            model.load_weights(str(checkpoint_filepath))
            log.warning('Reloaded previous saved weights after nan loss.')
        except (tf.errors.NotFoundError, ValueError):
            pass

    modelspec = tf2modelspec(model, modelspec)

    contains_tf_only_layers = np.any(
        ['tf_only' in m['fn'] for m in modelspec.modules])
    if not contains_tf_only_layers:
        # compare the predictions from the model and modelspec
        error = compare_ms_tf(modelspec, model, est, train_data)
        if error > 1e-5:
            log.warning(
                f'Mean difference between NEMS and TF model prediction: {error}'
            )
        else:
            log.info(
                f'Mean difference between NEMS and TF model prediction: {error}'
            )
    else:
        # nothing to compare, ms evaluation is not implemented for this type of model
        pass

    # add in some relevant meta information
    modelspec.meta['n_parms'] = len(modelspec.phi_vector)
    try:
        n_epochs = len(history.history['loss'])
    except KeyError:
        n_epochs = 0
    try:
        max_iter = modelspec.meta['extra_results']
        modelspec.meta['extra_results'] = max(max_iter, n_epochs)
    except KeyError:
        modelspec.meta['extra_results'] = n_epochs

    nems.utils.progress_fun()

    return {'modelspec': modelspec}
Example 13
relative_signals_dir = '../recordings'
relative_modelspecs_dir = '../modelspecs'

# Convert to absolute paths so they can be passed to functions in
# other directories
signals_dir = os.path.abspath(relative_signals_dir)
modelspecs_dir = os.path.abspath(relative_modelspecs_dir)

# ----------------------------------------------------------------------------
# DATA LOADING

# GOAL: Get your data loaded into memory as a Recording object
logging.info('Loading data...')

# Method #1: Load the data from a local directory
rec = Recording.load(os.path.join(signals_dir, 'eno052d-a1.tgz'))

# Method #2: Load the data from baphy using the (incomplete, TODO) HTTP API:
#URL = "http://potoroo:3004/baphy/271/bbl086b-11-1?rasterfs=200"
#rec = Recording.load_url(URL)

logging.info('Generating state signal...')

rec = preproc.make_state_signal(rec, ['pupil'], [''], 'state')

# ----------------------------------------------------------------------------
# INITIALIZE MODELSPEC

# GOAL: Define the model that you wish to test

logging.info('Initializing modelspec(s)...')
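
# Hedged sketch of what typically follows (an assumption, not part of this
# excerpt): build the modelspec from a NEMS keyword string. The specific
# keyword string below is illustrative only.
import nems.initializers

modelspec = nems.initializers.from_keywords('wc.18x1.g-fir.1x15-lvl.1')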
Example 14
relative_signals_dir = '../recordings'
relative_modelspecs_dir = '../modelspecs'

# Convert to absolute paths so they can be passed to functions in
# other directories
signals_dir = os.path.abspath(relative_signals_dir)
modelspecs_dir = os.path.abspath(relative_modelspecs_dir)

# ----------------------------------------------------------------------------
# DATA LOADING

# GOAL: Get your data loaded into memory as a Recording object
logging.info('Loading data...')

# Method #1: Load the data from a local directory
rec = Recording.load(os.path.join(signals_dir, 'TAR010c.NAT.fs50.tgz'))

cellid = 'TAR010c-16-2'
rec['resp'] = rec['resp'].extract_channels([cellid])

# Method #2: Load the data from baphy using the (incomplete, TODO) HTTP API:
#URL = "http://potoroo:3004/baphy/271/bbl086b-11-1?rasterfs=200"
#rec = Recording.load_url(URL)

logging.info('Generating state signal...')

rec = preproc.make_state_signal(rec, ['pupil'], [''], 'state')

# ----------------------------------------------------------------------------
# DATA WITHHOLDING
Example 15
sites = np.unique([c[:7] for c in cellids])

# loop over sites (to speed loading)
for site in sites:
    # load recording(s)
    rasterfs = 100
    ops = {
        'batch': batch,
        'siteid': site,
        'rasterfs': rasterfs,
        'pupil': 1,
        'stim': 0,
        'recache': False
    }
    uri = nb.baphy_load_recording_uri(**ops)
    rec100 = Recording.load(uri)
    rec100['resp'] = rec100['resp'].rasterize()
    rec100 = rec100.and_mask(
        ['HIT_TRIAL', 'CORRECT_REJECT_TRIAL', 'PASSIVE_EXPERIMENT'])

    rasterfs = 20
    ops = {
        'batch': batch,
        'siteid': site,
        'rasterfs': rasterfs,
        'pupil': 1,
        'stim': 0,
        'recache': False
    }
    uri = nb.baphy_load_recording_uri(**ops)
    rec20 = Recording.load(uri)
Example 16
def fit_pcnorm(modelspec,
               est: recording.Recording,
               metric=None,
               use_modelspec_init: bool = True,
               optimizer: str = 'adam',
               max_iter: int = 10000,
               early_stopping_steps: int = 5,
               tolerance: float = 5e-4,
               learning_rate: float = 1e-4,
               batch_size: typing.Union[None, int] = None,
               seed: int = 0,
               initializer: str = 'random_normal',
               freeze_layers: typing.Union[None, list] = None,
               epoch_name: str = "REFERENCE",
               n_pcs=2,
               **context):
    '''
    Required Arguments:
     est          A recording object
     modelspec     A modelspec object

    Optional Arguments:
     (copied from fit_tf for now)

    Returns
     dictionary: {'modelspec': updated_modelspec}
    '''

    # Hard-coded
    cost_function = basic_cost
    fitter = scipy_minimize
    segmentor = nems.segmentors.use_all_data
    mapper = nems.fitters.mappers.simple_vector
    fit_kwargs = {'tolerance': tolerance, 'max_iter': max_iter}

    start_time = time.time()

    modelspec = copy.deepcopy(modelspec)

    # apply mask to remove invalid portions of signals and allow fit to
    # only evaluate the model on the valid portion of the signals
    if 'mask' in est.signals.keys():
        log.info("Data len pre-mask: %d", est['mask'].shape[1])
        est = est.apply_mask()
        log.info("Data len post-mask: %d", est['mask'].shape[1])

    conditions = [
        "_".join(k.split("_")[1:]) for k in est.signals.keys()
        if k.startswith("mask_")
    ]
    if (len(conditions) > 2) and any(
        [c.split("_")[-1] == 'lg' for c in conditions]):
        conditions.remove("small")
        conditions.remove("large")
    #conditions = conditions[0:2]
    #conditions = ['large','small']

    group_idx = [est['mask_' + c].as_continuous()[0, :] for c in conditions]
    cg_filtered = [(c, g) for c, g in zip(conditions, group_idx)
                   if g.sum() > 0]
    conditions, group_idx = zip(*cg_filtered)

    for c, g in zip(conditions, group_idx):
        log.info(f"Data subset for {c} len {g.sum()}")

    resp = est['resp'].as_continuous()
    pred0 = est['pred0'].as_continuous()
    residual = resp - pred0

    pca = PCA(n_components=n_pcs)
    pca.fit(residual.T)
    pc_axes = pca.components_

    pcproj = residual.T.dot(pc_axes.T).T

    group_pc = [pcproj[:, idx].std(axis=1) for idx in group_idx]
    resp_std = resp.std(axis=1)
    #import pdb; pdb.set_trace()

    if metric is None:
        metric = lambda d: pc_err(d,
                                  pred_name='pred',
                                  pred0_name='pred0',
                                  group_idx=group_idx,
                                  group_pc=group_pc,
                                  pc_axes=pc_axes,
                                  resp_std=resp_std)

    # turn on "fit mode". currently this serves one purpose, for normalization
    # parameters to be re-fit for the output of each module that uses
    # normalization. does nothing if normalization is not being used.
    ms.fit_mode_on(modelspec, est)

    # Create the mapper functions that translates to and from modelspecs.
    # It has three functions that, when defined as mathematical functions, are:
    #    .pack(modelspec) -> fitspace_point
    #    .unpack(fitspace_point) -> modelspec
    #    .bounds(modelspec) -> fitspace_bounds
    packer, unpacker, pack_bounds = mapper(modelspec)

    # A function to evaluate the modelspec on the data
    evaluator = nems.modelspec.evaluate

    my_cost_function = cost_function
    my_cost_function.counter = 0

    # Freeze everything but sigma, since that's all the fitter should be
    # updating.
    cost_fn = partial(my_cost_function,
                      unpacker=unpacker,
                      modelspec=modelspec,
                      data=est,
                      segmentor=segmentor,
                      evaluator=evaluator,
                      metric=metric,
                      display_N=1000)

    # get initial sigma value representing some point in the fit space,
    # and corresponding bounds for each value
    sigma = packer(modelspec)
    bounds = pack_bounds(modelspec)

    # Results should be a list of modelspecs
    # (might only be one in list, but still should be packaged as a list)
    improved_sigma = fitter(sigma, cost_fn, bounds=bounds, **fit_kwargs)
    improved_modelspec = unpacker(improved_sigma)
    elapsed_time = (time.time() - start_time)

    start_err = cost_fn(sigma)
    final_err = cost_fn(improved_sigma)
    log.info("Delta error: %.06f - %.06f = %e", start_err, final_err,
             final_err - start_err)

    # TODO: Should this maybe be moved to a higher level
    # so it applies to ALL the fitters?
    ms.fit_mode_off(improved_modelspec)
    ms.set_modelspec_metadata(improved_modelspec, 'fitter', 'ccnorm')
    ms.set_modelspec_metadata(improved_modelspec, 'fit_time', elapsed_time)
    ms.set_modelspec_metadata(improved_modelspec, 'n_parms',
                              len(improved_sigma))

    return {'modelspec': improved_modelspec.copy(), 'save_context': True}
Example 17
def from_nwb_pupil(nwb_file, nwb_format,fs=20,with_pupil=False,running_speed=False,as_dict=True):
#def from_nwb(cls, nwb_file, nwb_format,with_pupil=False,fs=20):
    """
    The NWB (Neurodata Without Borders) format is a unified data format for neurophysiology data,
    developed with major contributions from the Allen Institute for Brain Science.
    Data is stored as an HDF5 file, with the internal layout varying depending on how the data was saved.
    
    References:
      - https://nwb.org
      - https://pynwb.readthedocs.io/en/latest/index.html
    :param nwb_file: path to the nwb file
    :param nwb_format: specifier for how the data is saved in the container
    :param int fs: sampling rate, applied to all signals
    :param bool with_pupil: whether to include a pupil signal in the recording
    :param bool running_speed: whether to include a running-speed signal in the recording
    :param bool as_dict: if True, return a dict of recording objects, one per unit/neuron;
                         otherwise return a single recording with each unit as a channel of the point-process signal
    :return: a recording object, or a dict of recording objects if as_dict is True
    """
    #log.info(f'Loading NWB file with format "{nwb_format}" from "{nwb_file}".')

    # add in supported nwb formats here
    assert nwb_format in ['neuropixel'], f'"{nwb_format}" not a supported NWB file format.'

    nwb_filepath = Path(nwb_file)
    if not nwb_filepath.exists():
        raise FileNotFoundError(f'"{nwb_file}" could not be found.')

    if nwb_format == 'neuropixel':
        """
        In neuropixel ecephys nwb files, data is stored in several attributes of the container: 
          - units: individual cell metadata, a dataframe
          - epochs: timing of the stimuli, series of arrays
          - lab_meta_data: metadata about the experiment, such as specimen details
          
        Spike times are saved as arrays in the 'spike_times' column of the units dataframe.
        The sampling frequency defaults to match the pupil signal; if no pupil data is retrieved,
        it is set to the chosen value (the previous default was 1250).
          
        Refs:
          - https://allensdk.readthedocs.io/en/latest/visual_coding_neuropixels.html
          - https://allensdk.readthedocs.io/en/latest/_static/examples/nb/ecephys_quickstart.html
          - https://allensdk.readthedocs.io/en/latest/_static/examples/nb/ecephys_data_access.html
        """
        try:
            from pynwb import NWBHDF5IO
            from allensdk.brain_observatory.ecephys import nwb  # needed for ecephys format compat
        except ImportError:
            m = 'The "allensdk" library is required to work with neuropixel nwb formats, available on PyPI.'
            #log.error(m)
            raise ImportError(m)

        session_name = nwb_filepath.stem
        with NWBHDF5IO(str(nwb_filepath), 'r') as nwb_io:
            nwbfile = nwb_io.read()

            units = nwbfile.units
            epochs = nwbfile.epochs
           
            spike_times = dict(zip(units.id[:].astype(str), units['spike_times'][:]))

            # extract the metadata and convert to dict
            metadata = nwbfile.lab_meta_data['metadata'].to_dict()
            metadata['uri'] = str(nwb_filepath)  # add in uri
            # add invalid-times data (start/stop times and tags) to metadata as a DataFrame, if present
            if nwbfile.invalid_times is not None:
                invalid_times = nwbfile.invalid_times
                invalid_times = np.array([invalid_times[col][:] for col in invalid_times.colnames])
                metadata['invalid_times'] = pd.DataFrame(invalid_times.transpose(),
                                                         columns=['start_time', 'stop_time', 'tags'])
                
            # build the units metadata
            units_data = {
                col.name: col.data for col in units.columns
                if col.name not in ['spike_times', 'spike_times_index', 'spike_amplitudes',
                                    'spike_amplitudes_index', 'waveform_mean', 'waveform_mean_index']
            }

            units_meta = pd.DataFrame(units_data, index=units.id[:])

            # add electrode info to the units metadata, then convert to a dict
            # (the PointProcess constructor below expects meta as a dict)
            electrodes = nwbfile.electrodes
            e_data = {col.name: col.data for col in electrodes.columns}
            e_meta = pd.DataFrame(e_data, index=electrodes.id[:])
            units_meta = pd.merge(units_meta, e_meta,
                                  left_on=units_meta.peak_channel_id,
                                  right_index=True,
                                  suffixes=('_unit', '_channel'))
            units_meta = units_meta.drop(['key_0', 'group'], axis=1).to_dict('index')

            # build the epoch dataframe
            epoch_data = {
                col.name: col.data for col in epochs.columns
                if col.name not in ['tags', 'timeseries', 'tags_index', 'timeseries_index']
            }

            epoch_df = pd.DataFrame(epoch_data, index=epochs.id[:]).rename({
                'start_time': 'start',
                'stop_time': 'end',
                'stimulus_name': 'name'
            }, axis='columns')

 
            # rename epochs to distinguish different natural scene / movie frames
            has_frame = epoch_df['frame'].notna()
            epoch_df.loc[has_frame, 'name'] = (epoch_df.loc[has_frame, 'name'] + '_' +
                                               epoch_df.loc[has_frame, 'frame'].astype(int).astype(str))

            # save the full stimulus info to meta, then drop extra columns from the epoch table
            metadata['epochs'] = epoch_df
            epoch_df = epoch_df.drop([col for col in epoch_df.columns
                                      if col not in ['start', 'end', 'name']], axis=1)

            # duplicate natural-scene epochs under the name 'REFERENCE' (to work with the demo)
            df_copy = epoch_df[epoch_df.name.str.contains('natural_scene')].copy()
            df_copy.loc[:, 'name'] = 'REFERENCE'
            # DataFrame.append is deprecated in recent pandas; pd.concat is equivalent here
            epoch_df = pd.concat([epoch_df, df_copy], ignore_index=True)

            # (optional) expand epoch bounds so epochs overlap, to inspect evoked activity:
            # to_adjust = epoch_df.loc[:, ['start', 'end']].to_numpy()
            # epoch_df.loc[:, ['start', 'end']] = nems.epoch.adjust_epoch_bounds(to_adjust, -0.1, 0.1)
            
            # save the spike times as a point-process signal, with frequency set to match the other signals
            pp = PointProcess(fs, spike_times, name='resp', recording=session_name, epochs=epoch_df,
                              chans=[str(c) for c in nwbfile.units.id[:]], meta=units_meta)
            # dict to pass to recording
            # signal_dict = {pp.name: pp}

            # log.info('Successfully loaded nwb file.')
            from scipy.interpolate import interp1d

            # save pupil data as a rasterized signal
            if with_pupil:
                try:
                    pupil = nwbfile.modules['eye_tracking'].data_interfaces['pupil_ellipse_fits']
                    t = pupil['timestamps'][:]
                    pupil = pupil['width'][:].reshape(1, -1)  # only 1 dimension - or use 'height'

                    # interpolate to the requested sampling rate; pupil data starts at
                    # timepoint 0.0, with NaN filler outside the recorded range
                    f = interp1d(t, pupil, bounds_error=False, fill_value=np.nan)
                    new_t = np.arange(0.0, (t.max() + 1/fs), 1/fs)
                    pupil = f(new_t)

                    pupil_signal = RasterizedSignal(fs=fs, data=pupil, recording=session_name, name='pupil',
                                                    epochs=epoch_df, chans=['pupil'])
                    # (alternative: keep all ellipse-fit columns as channels)

                # if the session has no pupil data, still return the spike data
                except KeyError:
                    print(session_name + ' has no pupil data.')
                    with_pupil = False  # avoid referencing an undefined pupil_signal below

            
            if running_speed:
                running = nwbfile.modules['running'].data_interfaces['running_speed']
                t = running.timestamps[:][1]  # data has start and end timestamps; only the end is used
                running = running.data[:].reshape(1, -1)

                f = interp1d(t, running)
                # new_t = np.arange(np.min(t), np.max(t), 1/fs)
                new_t = np.arange(epoch_df.start.min(), epoch_df.end.max(), 1/fs)
                running = f(new_t)
                running = RasterizedSignal(fs=fs, data=running, name='running',
                                           recording=session_name, epochs=epoch_df)



            if as_dict:
                # each unit gets a separate recording in the dict
                rec_dict = {}
                for c in pp.chans:
                    unit_signal = pp.extract_channels([c])
                    rec = Recording({'resp': unit_signal}, meta=metadata)
                    if with_pupil:
                        rec.add_signal(pupil_signal)
                    if running_speed:
                        rec.add_signal(running)
                    rec_dict[c] = rec
                return rec_dict

            else:
                rec = Recording({'resp': pp}, meta=metadata)
                if with_pupil:
                    rec.add_signal(pupil_signal)
                if running_speed:
                    rec.add_signal(running)
                return rec
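
# A usage sketch for the loader above. Illustrative only: the NWB file name is a
# hypothetical placeholder, and it assumes an Allen ecephys session file saved
# locally; rasterize() and signals are used as elsewhere in these examples.
recs = from_nwb_pupil('session_715093703.nwb', 'neuropixel',
                      fs=20, with_pupil=True, running_speed=False, as_dict=True)
unit_id, unit_rec = next(iter(recs.items()))
resp = unit_rec['resp'].rasterize()   # PointProcess -> RasterizedSignal at fs
print(unit_id, resp.shape, list(unit_rec.signals.keys()))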
Esempio n. 18
0
import os
import random
import numpy as np
from functools import partial
import matplotlib.pyplot as plt
import nems.epoch as ep
import nems.modelspec as ms
import nems.plots.api as nplt
from nems.recording import Recording

# specify directories for loading data and fitted modelspec
signals_dir = '../signals'
#signals_dir = '/home/jacob/auto/data/batch271_fs100_ozgf18/'
modelspecs_dir = '../modelspecs'

# load the data
rec = Recording.load(os.path.join(signals_dir, 'TAR010c-18-1'))

# Add a new signal, respavg, to the recording, in 4 steps

# 1. Fold matrix over all stimuli, returning a dictionary where keys are stimuli
#    and each value in the dictionary is (reps X cell X bins)
epochs_to_extract = ep.epoch_names_matching(rec.epochs, '^STIM_')
folded_matrix = rec['resp'].extract_epochs(epochs_to_extract)

# 2. Average over all reps of each stim and save into dict called psth.
per_stim_psth = dict()
for k in folded_matrix.keys():
    per_stim_psth[k] = np.nanmean(folded_matrix[k], axis=0)

# 3. Invert the folding to unwrap the psth back out into a predicted spike_dict by
# simply replacing all epochs in the signal with their psth
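# (The original snippet is truncated here. A plausible completion of steps 3-4,
#  a sketch assuming the NEMS signal class provides replace_epochs, as in the
#  demo scripts:)
respavg = rec['resp'].replace_epochs(per_stim_psth)
respavg.name = 'respavg'

# 4. Add the new respavg signal to the recording
rec.add_signal(respavg)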
Esempio n. 19
0
def generate_state_corrected_psth(batch=None, modelname=None, cellids=None, siteid=None,
                                  movement_mask=False, gain_only=False, dc_only=False,
                                  cache_path=None, recache=False):
    """
    Modifies the existing recording so that the psth signal is the prediction specified
    by the modelname. Designed with stategain models in mind. CRH.

    If the model doesn't already exist in /auto/users/hellerc/results/, this
    will fit the model and save it in /auto/users/hellerc/results.

    If the fit directory (from xforms) exists, simply reload the result and use it as the psth.
    """
    if siteid is None:
        raise ValueError("must specify siteid!")
    if cache_path is not None:
        fn = cache_path + siteid + '_{}.tgz'.format(modelname.split('.')[1])
        if gain_only:
            fn = fn.replace('.tgz', '_gonly.tgz')
        if 'mvm' in modelname:
            fn = fn.replace('.tgz', '_mvm.tgz')
        if os.path.isfile(fn) and not recache:
            rec = Recording.load(fn)
            return rec
        else:
            # no cached file (or recache requested): rebuild below
            pass

    if batch is None or modelname is None:
        raise ValueError('Must specify batch and modelname!')
    results_table = nd.get_results_file(batch, modelnames=[modelname])
    preds = []
    ms = []
    for cell in cellids:
        log.info(cell)
        try:
            p = results_table[results_table['cellid']==cell]['modelpath'].values[0]
            if os.path.isdir(p):
                xfspec, ctx = xforms.load_analysis(p)
                preds.append(ctx['val'])
                ms.append(ctx['modelspec'])
            else:
                sys.exit('Fit for {0} does not exist'.format(cell))
        except Exception:
            log.warning("fit doesn't exist for cell {0}".format(cell))

    # TODO: add a check that the preds are the same length (when there are multiple
    # cellids). This could be violated if, for example, one cell existed in a
    # prepassive run but another didn't, so the two were fit on different data.
    file_epochs = []

    for pr in preds:
        file_epochs += [ep for ep in pr.epochs.name if ep.startswith('FILE')]

    unique_files = np.unique(file_epochs)
    shared_files = []
    for f in unique_files:
        if np.sum([1 for file in file_epochs if file == f]) == len(preds):
            shared_files.append(str(f))
        else:
            # this rawid didn't span all cells at the requested site
            pass

    # mask all file epochs for all preds with the shared file epochs
    # and adjust epochs
    if int(batch) in [307, 294]:
        for i, p in enumerate(preds):
            preds[i] = p.and_mask(shared_files)
            preds[i] = preds[i].apply_mask(reset_epochs=True)

    sigs = {}
    for i, p in enumerate(preds):
        if gain_only:
            # zero all but the first column of the non-gain parameters
            # (removes their state dependence, keeping only the baseline term)
            mspec = ms[i]
            not_gain_keys = [k for k in mspec[0]['phi'].keys() if '_g' not in k]
            for k in not_gain_keys:
                phi = mspec[0]['phi'][k]
                mspec[0]['phi'][k] = np.append(phi[0, 0], np.zeros(phi.shape[-1] - 1))[np.newaxis, :]
            pred = mspec.evaluate(p)['pred']
        elif dc_only:
            # zero all but the first column of the non-dc parameters
            mspec = ms[i]
            not_dc_keys = [k for k in mspec[0]['phi'].keys() if '_d' not in k]
            for k in not_dc_keys:
                phi = mspec[0]['phi'][k]
                mspec[0]['phi'][k] = np.append(phi[0, 0], np.zeros(phi.shape[-1] - 1))[np.newaxis, :]
            pred = mspec.evaluate(p)['pred']
        else:
            pred = p['pred'] 
        if i == 0:           
            new_psth_sp = p['psth_sp']
            new_psth = pred
            new_resp = p['resp'].rasterize()

        else:
            try:
                new_psth_sp = new_psth_sp.concatenate_channels([new_psth_sp, p['psth_sp']])
                new_psth = new_psth.concatenate_channels([new_psth, pred])
                new_resp = new_resp.concatenate_channels([new_resp, p['resp'].rasterize()])
            except ValueError:
                import pdb; pdb.set_trace()

    new_pup = preds[0]['pupil']
    sigs['pupil'] = new_pup

    if 'pupil_raw' in preds[0].signals.keys():
        sigs['pupil_raw'] = preds[0]['pupil_raw']

    if 'mask' in preds[0].signals:
        new_mask = preds[0]['mask']
        sigs['mask'] = new_mask
    else:
        mask_rec = preds[0].create_mask(True)
        new_mask = mask_rec['mask']
        sigs['mask'] = new_mask

    if 'rem' in preds[0].signals.keys():
        rem = preds[0]['rem']
        sigs['rem'] = rem

    if 'pupil_eyespeed' in preds[0].signals.keys():
        new_eyespeed = preds[0]['pupil_eyespeed']
        sigs['pupil_eyespeed'] = new_eyespeed

    new_psth_sp.name = 'psth_sp'
    new_psth.name = 'psth'
    new_resp.name = 'resp'
    sigs['psth_sp'] = new_psth_sp
    sigs['psth'] = new_psth
    sigs['resp'] = new_resp

    new_rec = Recording(sigs, meta=preds[0].meta)

    # make sure mask is cast to bool
    new_rec['mask'] = new_rec['mask']._modified_copy(new_rec['mask']._data.astype(bool))

    if cache_path is not None:
        log.info('caching {}'.format(fn))
        new_rec.save_targz(fn)

    return new_rec
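
# A usage sketch for the function above. The batch, modelname, cellids, siteid and
# cache path below are illustrative placeholders, not values from the original code.
cellids = ['TAR010c-18-1', 'TAR010c-18-2']
rec = generate_state_corrected_psth(batch=289,
                                    modelname='ns.fs20.pup-ld-st.pup_sdexp.SxR_basic',
                                    cellids=cellids, siteid='TAR010c',
                                    cache_path='/tmp/state_psth/', recache=False)
print(list(rec.signals.keys()))   # expect resp, psth, psth_sp, pupil, mask, ...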
Esempio n. 20
0
# NOTE: this snippet assumes resp, psth, pupil, and T were defined by earlier
# simulation code, and that `glm` (project-local objective functions) is importable.
import numpy as np
import scipy.optimize as opt
from tqdm import tqdm
from nems.signal import RasterizedSignal
from nems.recording import Recording

# get rid of any "cells" that never fired
idx = (resp==0).sum(axis=-1) == T
resp = resp[~idx, :]
psth = psth[~idx, :]
nCells = resp.shape[0]

# pack into nems recording
resp_sig = RasterizedSignal(fs=4, data=resp, name='resp', recording='simulation')
psth_sig = RasterizedSignal(fs=4, data=psth, name='psth', recording='simulation')
pupil_sig = RasterizedSignal(fs=4, data=pupil, name='pupil', recording='simulation')
bm = pupil > pupil.mean()
big_mask = RasterizedSignal(fs=4, data=bm, name='big_mask', recording='simulation')
sm = pupil < pupil.mean()
small_mask = RasterizedSignal(fs=4, data=sm, name='small_mask', recording='simulation')

rec = Recording({'resp': resp_sig, 'psth': psth_sig, 'pupil': pupil_sig})

# fit the GLM for different hyperparameters
x0 = np.zeros(3 * nCells)
nLV = 1
alpha1 = np.arange(0, 0.5, 0.05)
results = dict.fromkeys(alpha1)
for i, a2 in tqdm(enumerate(alpha1)):
    model_output = opt.minimize(glm.gain_only_objective, x0, (rec, big_mask, small_mask, a2, nLV), options={'gtol':1e-6, 'disp': True})
    weights = model_output.x

    # get model output: unpack the flat weight vector into (2+nLV) x nCells rows
    W = weights.reshape((2 + nLV), nCells)
    w1 = W[0, :]
    w2 = W[1:-1, :]
    b = W[-1, :]
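
    # The original snippet is truncated here. A plausible continuation (an
    # assumption, not from the source) would store this hyperparameter's solution:
    results[a2] = {'w1': w1, 'w2': w2, 'b': b, 'cost': model_output.fun}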
Esempio n. 21
0
            keyname = 'data_stream'
            #f_in = io.BytesIO(t.extractfile(member).read())

            # current non-optimal solution: extract the hdf5 file to disk, then load it from there
            t.extract(member, tpath)
            f = tpath + '/' + member.name

        elif basename.endswith('.json'):
            keyname = 'json_stream'
            f = io.StringIO(t.extractfile(member).read().decode('utf-8'))

        else:
            m = 'Unexpected file found in tar.gz: {} (size={})'.format(
                member.name, member.size)
            raise ValueError(m)
        # Ensure that we can doubly nest the streams dict
        if signame not in streams:
            streams[signame] = {}
        # Read out a stringIO object for each file now while it's open
        #f = io.StringIO(t.extractfile(member).read().decode('utf-8'))
        streams[signame][keyname] = f

# Now that the streams are organized, convert them into signals
# log.debug({k: streams[k].keys() for k in streams})
signals = [load_signal_from_streams(**sg) for sg in streams.values()]
signals_dict = {s.name: s for s in signals}

rec = Recording(signals=signals_dict)

shutil.rmtree(tpath)  # clean up
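
# For reference, a sketch of the doubly nested streams dict that the loop above
# builds before the signals are instantiated (signal names and values below are
# illustrative placeholders, not actual contents):
streams_example = {
    'resp': {'json_stream': 'StringIO with the signal JSON metadata',
             'data_stream': 'CSV StringIO, or path to the extracted .h5 file'},
    'stim': {'json_stream': '...',
             'data_stream': '...'},
}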