# Example 1
# Load the full stimulus trace and pass it, with its timestamps, through
# smooth_stimulus (presumably a smoothing filter — see that helper).
stimulus_data = data['full_stimulus']
stimulus_data = smooth_stimulus(stimulus_times, stimulus_data)

# Per-signal times and marks (features) from the recording; ys holds the
# stimulus value that spike_stimulus assigns to each entry of Ts.
# NOTE(review): Ts/Xs/ys appear to be per-signal collections (they are later
# collapsed by multi_to_single_signal) — confirm their exact layout.
Ts = data['signal_times']
Xs = data['signal_marks']
ys = spike_stimulus(Ts, stimulus_times, stimulus_data)

# Remove bad spikes
# NOTE(review): presumably drops spikes that have no cell label in
# data['signal_cellids'], filtering Ts, Xs and ys together; the second return
# value of remove_unlabeled_spikes is discarded here.
Ts, _, (Xs, ys) = remove_unlabeled_spikes(Ts, data['signal_cellids'], Xs, ys)

# Separate signal features
# Note, this function will scale the data by default before adding the constant
Xs = separate_signal_features(Xs)

# Drop to a single signal
# Collapses the per-signal lists into a single time array T with matching
# feature matrix X and stimulus values y (presumably also sorted by time).
T, (X, y) = multi_to_single_signal(Ts, Xs, ys)

# Create a mask for the training subset when the stimulus is moving quickly (running)
# min_g / max_g bound the accepted stimulus gradient; their units depend on
# stimulus_gradient_mask — TODO confirm.
y_train_mask = stimulus_gradient_mask(T, y, min_g=5, max_g=1000)

# Calculate bin edges independent of signal
# The edges are derived from the full smoothed stimulus trace, not from the
# spike-aligned y, so binning does not depend on which spikes were kept.
# STIMULUS_BINS is presumably the number of bins per stimulus dimension.
# NOTE(review): ybin_counts is not used in the visible code.
ybin_edges, ybin_counts = bin_edges_from_data(stimulus_data, STIMULUS_BINS)

# Construct the KDE
# Hyperparameters are hard-coded. tree_backend is 'auto' when GPU is truthy
# and 'ball' otherwise; n_jobs=8 is forwarded to the estimator.
# NOTE(review): bandwidth_X presumably applies to the (scaled) features and
# bandwidth_y to stimulus units — confirm against BivariateKernelDensity.
estimator = BivariateKernelDensity(n_neighbors=30,
                                   bandwidth_X=0.13,
                                   bandwidth_y=12,
                                   ybins=ybin_edges,
                                   tree_backend='auto' if GPU else 'ball',
                                   n_jobs=8)
# Example 2
# Fit the model
pipeline.fit(X_train, y_train)

# Optionally show the learned tuning curves for up to the first 10 features.
if DISPLAY_TUNING_CURVES:
    fig, axes = n_subplot_grid(min(X.shape[1], 10))
    for i, ax in enumerate(axes):
        # NOTE(review): reshaping to (STIMULUS_BINS, STIMULUS_BINS) assumes a
        # two-dimensional stimulus with STIMULUS_BINS bins per axis — confirm.
        ax.imshow(pipeline.tuning_curves[i].reshape(
            STIMULUS_BINS, STIMULUS_BINS))
        ax.set_title('Example TC {}'.format(i))
    fig.show()

# Predict per-bin probabilities for each test sample
y_pred = pipeline.predict_proba(X_test)

# Already single signal but this will sort the arrays quickly
T_test, (y_pred, y_test) = multi_to_single_signal([T_test], [y_pred], [y_test])

# Normalize each row to a probability distribution. NaNs are ignored in the
# row sum; keepdims=True keeps the (n, 1) shape so the division broadcasts
# across columns. A row that sums to 0 becomes NaN (0/0) after dividing.
y_pred /= np.nansum(y_pred, axis=1, keepdims=True)

# Decode: pick the grid value of the most probable bin for each sample.
# NOTE(review): np.argmax returns the index of the first NaN in a row, so a
# NaN row silently decodes to ybin_grid[0]; consider masking such rows.
ybin_grid = pipeline.ybin_grid
y_predicted = ybin_grid[np.argmax(y_pred, axis=1)]

# Plot true vs. decoded stimulus over time, one subplot per stimulus
# dimension (max_horizontal=1 presumably stacks them in a single column).
if DISPLAY_PLOTS:
    fig, axes = n_subplot_grid(y_predicted.shape[1],
                               max_horizontal=1,
                               figsize=(10, 8))
    for dim, ax in enumerate(axes):
        # Ground-truth stimulus for this dimension
        ax.plot(T_test, y_test[:, dim])
        # Decoded stimulus (ybin_grid value of the most probable bin)
        ax.plot(T_test, y_predicted[:, dim])
        # NOTE(review): y_predicted_filt is not defined in the visible code;
        # the filtered-prediction trace below is disabled.
        # ax.plot(T_test, y_predicted_filt[:, dim])