Example #1
0
def run_example():
    """Fit a CbRF model with a smoothness prior to toy data and plot it.

    Toy data (``N=2500``) is generated together with its ground-truth kernel
    ``K_true``. A ``SmoothnessPrior`` built from a derivative matrix for a
    kernel of ``K_true``'s shape regularises a CbRF fit whose
    hyperparameters are optimised with the ``'AUC'`` metric. The true and
    estimated kernels are then drawn side by side. Each panel uses a colour
    scale that is symmetric about zero.
    """
    X, Y, K_true = create_toy_data(N=2500)

    # Derivative matrix sized to match the (rows, cols) shape of the kernel.
    n_rows, n_cols = K_true.shape
    smooth_prior = SmoothnessPrior(
        D=create_derivative_matrix(n_rows, n_cols, order='C'))

    estimator = CbRF(optimize=True, metric='AUC', prior=smooth_prior,
                     verbose=True, n_griditer=3, n_jobs=1)
    estimator.fit(X, Y)
    K_est = np.reshape(estimator.get_weights(), K_true.shape)

    _, panels = plt.subplots(nrows=1, ncols=2)
    for ax, title, kernel in zip(panels, ('True', 'CbRF'), (K_true, K_est)):
        ax.set_title(title)
        # vmin = -vmax places zero at the midpoint of the colour map.
        limit = np.max(np.abs(kernel))
        ax.imshow(kernel, interpolation='nearest', vmin=-limit, vmax=limit)

    plt.show()
Example #2
0
def run_example():
    """Estimate a CbRF kernel under a smoothness prior and compare it visually.

    Generates 2500 toy samples plus the true kernel. A CbRF model is fitted
    with a ``SmoothnessPrior`` (derivative matrix in ``'C'`` order) and
    ``'AUC'``-based hyperparameter optimisation. The true and estimated
    kernels are then displayed in two panels with zero-centred colour
    limits.
    """
    X, Y, K_true = create_toy_data(N=2500)

    n_rows, n_cols = K_true.shape
    derivative = create_derivative_matrix(n_rows, n_cols, order='C')
    cbrf = CbRF(optimize=True,
                metric='AUC',
                prior=SmoothnessPrior(D=derivative),
                verbose=True,
                n_griditer=3,
                n_jobs=1)
    cbrf.fit(X, Y)
    K_hat = np.reshape(cbrf.get_weights(), K_true.shape)

    def show(ax, title, kernel):
        # Colour limits at +/- the largest magnitude keep zero at the centre.
        ax.set_title(title)
        bound = np.max(np.abs(kernel))
        ax.imshow(kernel, interpolation='nearest', vmin=-bound, vmax=bound)

    left, right = plt.subplots(nrows=1, ncols=2)[1]
    show(left, 'True', K_true)
    show(right, 'CbRF', K_hat)

    plt.show()
Example #3
0
def run_example():
    """Fit several receptive-field models to recorded data and plot their STRFs.

    Loads ``STRFs_ChirpsBlocks_IC_2012-01-31_50dB_chan01_unit01.h5`` from
    the ``data`` directory next to this file. ``convert_data`` turns it into
    ``X``/``Y``, the receptive-field size, the axes and ``fs``. The script
    then fits STA, GaussianGLM (``optimizer='ridge'``), CbRF and PoissonGLM
    models and shows each model's ``coef_`` as an STRF in one row of panels
    that share their x/y axes.
    """
    data_dir = join(split(__file__)[0], 'data')
    data_file = 'STRFs_ChirpsBlocks_IC_2012-01-31_50dB_chan01_unit01.h5'

    block = load_data(data_dir, data_file)
    X, Y, rfsize, axes, fs = convert_data(block)

    # Add some models
    n_jobs = 3  # number of processes used for hyperparameter optimization
    n_griditer = 3
    models = [
        STA(),
        GaussianGLM(n_griditer=n_griditer,
                    n_jobs=n_jobs,
                    optimizer='ridge'),
        CbRF(n_griditer=n_griditer, n_jobs=n_jobs),
        PoissonGLM(n_griditer=n_griditer, n_jobs=n_jobs),
    ]

    # Fit models. Single-argument print(...) calls produce identical output
    # under Python 2 and Python 3. The old print statements were a
    # SyntaxError on Python 3.
    separator = 50 * '-'
    for model in models:
        print(separator)
        print("Fitting model: %s" % (model.name,))
        print(separator)
        model.fit(X, Y)

    # Plot STRFs. squeeze=False keeps axarr 2-D (1 x n_models), so indexing
    # still works when only a single model is configured.
    fig, axarr = plt.subplots(nrows=1,
                              ncols=len(models),
                              sharex=True,
                              sharey=True,
                              squeeze=False)
    for i, model in enumerate(models):

        ax = axarr[0, i]
        ax.set_title(model.name)

        W = np.reshape(model.coef_, rfsize)
        rf = STRF(W,
                  fs,
                  time=axes[0].values.base,
                  frequency=axes[1].values.base)
        rf.show(ax=ax, show_now=False, colorbar=False)

        # Axes are shared, so only the first panel keeps its y label.
        if i > 0:
            ax.set_ylabel('')

    fig.set_size_inches(8, 2.5)
    fig.tight_layout()
    plt.show()
Example #4
0
def run_example():
    """Compare STA and CbRF kernel estimates against a known toy kernel.

    Generates 25000 toy samples (``nonlin_order=3``) together with the
    ground-truth kernel. The kernel is estimated twice:

    * with STA;
    * with CbRF under a ``GaussianPrior``, with hyperparameters optimised
      by ``'AUC'`` over a 7-point ``alpha`` grid from 2**-20 to 2**0.

    The true, STA and CbRF kernels are then plotted side by side.
    """
    X, Y, K_true = create_toy_data(N=25000, nonlin_order=3)

    # Estimate STA. print("...") with one argument behaves the same on
    # Python 2 and 3; the old print statement was a SyntaxError on Python 3.
    model = STA()
    print("Estimating STA")
    model.fit(X, Y)
    K_sta = np.reshape(model.get_weights(), K_true.shape)

    # Fit CbRF parameters
    prior = GaussianPrior()
    model = CbRF(optimize=True,
                 metric='AUC',
                 prior=prior,
                 verbose=True,
                 n_griditer=3,
                 n_jobs=-1,
                 param_grid={'alpha': 2**np.linspace(-20, 0, 7)})
    print("This may take a couple of minutes. Time to grab a tea or coffee ...")
    model.fit(X, Y)
    K_cbrf = np.reshape(model.get_weights(), K_true.shape)

    _, axarr = plt.subplots(nrows=1, ncols=3)
    panels = [('True', K_true), ('STA', K_sta), ('CbRF', K_cbrf)]
    for ax, (title, kernel) in zip(axarr, panels):
        ax.set_title(title)
        # Symmetric colour limits put zero at the centre of the colour map.
        vmax = np.max(np.abs(kernel))
        ax.imshow(kernel, interpolation='nearest', vmin=-vmax, vmax=vmax)

    plt.show()
Example #5
0
def run_example():
    """Fit CbRF with a spike filter to recorded data and plot the results.

    Loads ``STRFs_ChirpsBlocks_IC_2012-01-31_50dB_chan01_unit01.h5`` from
    the ``data`` directory next to this file. It is converted with
    ``t_spikefilt = 0.05``. The unit is presumably seconds, given the
    ``* 1000`` -> 'Time (ms)' conversion below; confirm against
    ``convert_data``.

    Each model's STRF is drawn in the first column. When ``convert_data``
    reports ``n_spikefilt > 0``, the model's ``get_spikefilt()`` output is
    plotted against lag in a second column.
    """
    t_spikefilt = 0.05

    data_dir = join(split(__file__)[0], 'data')
    data_file = 'STRFs_ChirpsBlocks_IC_2012-01-31_50dB_chan01_unit01.h5'
    # Alternative recording:
    # data_file = 'STRFs_ChirpsBlocks_IC_2012-04-26_Pos_02_30dB_chan01_unit01.h5'

    block = load_data(data_dir, data_file)
    X, Y, rfsize, axes, fs, n_spikefilt = convert_data(block, t_spikefilt)

    # Add some models. Other estimators (e.g. STA, GaussianGLM, PoissonGLM)
    # may be appended to ``models`` for comparison; whether they accept
    # ``n_spikefilt`` is not verified here.
    n_jobs = 3  # number of workers used for hyperparameter optimization
    n_griditer = 3
    models = [CbRF(n_griditer=n_griditer, n_jobs=n_jobs,
                   n_spikefilt=n_spikefilt)]

    # Fit models. Single-argument print(...) calls produce identical output
    # under Python 2 and Python 3. The old print statements were a
    # SyntaxError on Python 3.
    separator = 50 * '-'
    for model in models:
        print(separator)
        print("Fitting model: %s" % (model.name,))
        print(separator)
        model.fit(X, Y)

    # Plot STRFs. squeeze=False always returns a 2-D (n_models x n_columns)
    # array. np.atleast_2d on a squeezed 1-D result would give shape (1, n)
    # and break axarr[i, 0] for several models without a spike filter.
    has_spikefilt = n_spikefilt > 0
    fig, axarr = plt.subplots(nrows=len(models),
                              ncols=1 + int(has_spikefilt),
                              squeeze=False)
    if has_spikefilt:
        # Lag axis in milliseconds; identical for every model, so built once.
        lags_ms = np.linspace(-t_spikefilt, 0, n_spikefilt) * 1000

    for i, model in enumerate(models):

        ax = axarr[i, 0]
        ax.set_title(model.name)

        W = np.reshape(model.get_weights(), rfsize)
        rf = STRF(W, fs, time=axes[0].values.base,
                  frequency=axes[1].values.base)
        rf.show(ax=ax, show_now=False, colorbar=False)

        if has_spikefilt:
            # Post-spike filter
            ax = axarr[i, 1]
            h = model.get_spikefilt()
            ax.plot(lags_ms, h)
            ax.set_xlabel('Time (ms)')
            ax.set_ylabel('Gain (a.u.)')

    fig.set_size_inches(8, 4)
    fig.tight_layout()
    plt.show()