Code Example #1
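# Assumed imports for these excerpts; the exact module path depends on which
# version of the NDN toolbox the project uses (e.g. NDN3):
#   import numpy as np
#   import NDN3.NDN as NDN
#   import NDN3.NDNutils as NDNutils
# Xstim, NX, NY and num_lags are defined earlier in the source file.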
adam_params = NDN.NDN.optimizer_defaults(opt_params={'use_gpu': True},
                                         learning_alg='adam')

lbfgs_params = NDN.NDN.optimizer_defaults(opt_params={
    'use_gpu': True,
    'display': True
},
                                          learning_alg='lbfgs')
lbfgs_params['maxiter'] = 1000

# set up training / test indices
NT = Xstim.shape[0]
valdata = np.arange(0, NT, 1)  # treat all time points as valid

NC = 1  # number of cells (single-neuron GLM here)
Ui, Xi = NDNutils.generate_xv_folds(NT, num_blocks=2)

# GLM

# NDN parameters for processing the stimulus
par = NDNutils.ffnetwork_params(input_dims=[1, NX, NY, num_lags],
                                layer_sizes=[NC],
                                layer_types=['normal'],
                                normalization=[0],
                                act_funcs=['softplus'],
                                verbose=True,
                                reg_list={
                                    'd2x': [0.01],
                                    'glocal': [0.1]
                                })
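# The excerpt stops after building `par`. A minimal sketch of how the GLM would
# typically be constructed and fit with these parameters, mirroring the NDN
# calls shown in Code Example #2; `Robs` (the observed spike counts) is assumed
# to be defined elsewhere:
glm = NDN.NDN([par], noise_dist='poisson')
_ = glm.train(input_data=Xstim,
              output_data=Robs,
              train_indxs=Ui,
              test_indxs=Xi,
              learning_alg='adam',
              opt_params=adam_params)
LLx = glm.eval_models(input_data=Xstim,
                      output_data=Robs,
                      data_indxs=Xi,
                      nulladjusted=True)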
Code Example #2
File: neureye.py  Project: jcbyts/V1FreeViewingCode
def get_stim_model(Stim,
                   Robs,
                   dims,
                   valid=None,
                   num_lags=10,
                   plot=True,
                   XTreg=0.05,
                   L1reg=5e-3,
                   MIreg=0.1,
                   MSCreg=10.0,
                   Greg=0.1,
                   Mreg=1e-4,
                   num_subs=36,
                   num_hid=24,
                   num_tkern=None,
                   Cindx=None,
                   base_mod=None,
                   cids=None,
                   autoencoder=False):

    NX = dims[0]
    NY = dims[1]

    NT, NC = Robs.shape

    if valid is None:
        valid = np.arange(0, NT, 1)

    # create time-embedded stimulus
    Xstim, rinds = create_time_embedding_valid(Stim, [num_lags, NX, NY], valid)
    Rvalid = deepcopy(Robs[rinds, :])

    NTv = Rvalid.shape[0]
    print('%d valid samples of %d possible' % (NTv, NT))

    stas = Xstim.T @ (Rvalid - np.average(Rvalid, axis=0))
    stas = np.reshape(stas, [NX * NY, num_lags, NC]) / NTv

    if plot:
        plt.figure(figsize=(10, 15))
        sx, sy = U.get_subplot_dims(NC)

    mu = np.zeros(NC)
    for cc in range(NC):
        if plot:
            plt.subplot(sx, sy, cc + 1)
            plt.plot(np.abs(stas[:, :, cc]).T, color=[.5, .5, .5])
        tlevel = np.median(
            np.abs(stas[:, :, cc] - np.average(stas[:, :, cc]))) * 4
        mu[cc] = np.average(np.abs(stas[:, :, cc]) > tlevel)

        if plot:
            plt.axhline(tlevel, color='k')
            plt.title(cc)

    # threshold for good STAs
    thresh = 0.01
    if plot:
        plt.figure()
        plt.plot(mu, '-o')
        plt.axhline(thresh, color='k')
        plt.show()

    if cids is None:
        cids = np.where(mu > thresh)[0]  # units to analyze
        print("found %d good STAs" % len(cids))

    if plot:
        plt.figure(figsize=(10, 15))
        for cc in cids:
            plt.subplot(sx, sy, cc + 1)
            bestlag = np.argmax(np.max(abs(stas[:, :, cc]), axis=0))
            plt.imshow(np.reshape(stas[:, bestlag, cc], (NY, NX)))
            plt.title(cc)

    # index into "good" units
    Rvalid = Rvalid[:, cids]
    NC = Rvalid.shape[1]
    stas = stas[:, :, cids]

    if Cindx is None:
        print("Getting Crop Index")
        # Crop stimulus to center around RFs
        sumdensity = np.zeros([NX * NY])
        for cc in range(NC):
            bestlag = np.argmax(np.max(abs(stas[:, :, cc]), axis=0))
            sumdensity += stas[:, bestlag, cc]**2

        if plot:
            plt.figure()
            plt.imshow(np.reshape(sumdensity, [NY, NX]))
            plt.title("Sum Density STA")

        # get Crop indices (TODO: debug)
        sumdensity = (sumdensity - np.min(sumdensity)) / (np.max(sumdensity) -
                                                          np.min(sumdensity))
        I = np.reshape(sumdensity, [NY, NX]) > .3
        xinds = np.where(np.sum(I, axis=0) > 0)[0]
        yinds = np.where(np.sum(I, axis=1) > 0)[0]

        NX2 = np.maximum(len(xinds), len(yinds))
        x0 = np.min(xinds)
        y0 = np.min(yinds)

        xinds = range(x0, x0 + NX2)
        yinds = range(y0, y0 + NX2)

        Cindx = crop_indx(NX, xinds, yinds)

        if plot:
            plt.figure()
            plt.imshow(np.reshape(sumdensity[Cindx], [NX2, NX2]))
            plt.title('Cropped')
            plt.show()

    NX2 = np.sqrt(len(Cindx)).astype(int)

    # make new cropped stimulus
    Xstim, rinds = create_time_embedding_valid(Stim[:, Cindx],
                                               [num_lags, NX2, NX2], valid)

    # index into Robs
    Rvalid = deepcopy(Robs[rinds, :])
    Rvalid = Rvalid[:, cids]
    Rvalid = NDNutils.shift_mat_zpad(Rvalid, -1, dim=0)  # get rid of first lag

    NC = Rvalid.shape[1]  # new number of units
    NT = Rvalid.shape[0]
    print('%d valid samples of %d possible' % (NT, Stim.shape[0]))
    print('%d good units' % NC)

    # double-check STAs work with the cropped stimulus
    stas = Xstim.T @ Rvalid
    stas = np.reshape(stas, [NX2 * NX2, num_lags, NC]) / NT

    if plot:
        plt.figure(figsize=(10, 15))
        for cc in range(NC):
            plt.subplot(sx, sy, cc + 1)
            bestlag = np.argmax(np.max(abs(stas[:, :, cc]), axis=0))
            plt.imshow(np.reshape(stas[:, bestlag, cc], (NX2, NX2)))
            plt.title(cc)
        plt.show()

    Ui, Xi = NDNutils.generate_xv_folds(NT)

    # fit SCAFFOLD MODEL
    try:
        if len(XTreg) == 2:
            d2t = XTreg[0]
            d2x = XTreg[1]
        else:
            d2t = XTreg[0]
            d2x = deepcopy(d2t)
    except TypeError:
        d2t = deepcopy(XTreg)
        d2x = deepcopy(XTreg)

    # optimizer parameters
    adam_params = U.def_adam_params()

    if base_mod is not None:
        side2b = base_mod.copy_model()
        side2b.set_regularization('d2t', d2t, layer_target=0)
        side2b.set_regularization('d2x', d2x, layer_target=0)
        side2b.set_regularization('glocal', Greg, layer_target=0)
        side2b.set_regularization('l1', L1reg, layer_target=0)
        side2b.set_regularization('max', MIreg, ffnet_target=0, layer_target=1)
        side2b.set_regularization('max',
                                  MSCreg,
                                  ffnet_target=1,
                                  layer_target=0)

        if len(side2b.networks) == 4:  # includes autoencoder network
            input_data = [Xstim, Rvalid]
        else:
            input_data = Xstim

    else:
        # Best regularization arrived at
        Greg0 = 1e-1
        Mreg0 = 1e-6
        L1reg0 = 1e-5

        if num_tkern is not None:
            ndn_par = NDNutils.ffnetwork_params(
                input_dims=[1, NX2, NX2, num_lags],
                layer_sizes=[num_tkern, num_subs, num_hid],
                layer_types=['conv', 'normal', 'normal'],
                ei_layers=[None, num_subs // 2, num_hid // 2],
                conv_filter_widths=[1],
                normalization=[1, 1, 1],
                act_funcs=['lin', 'relu', 'relu'],
                verbose=True,
                reg_list={
                    'd2t': [1e-3],
                    'd2x': [None, XTreg],
                    'l1': [L1reg0, L1reg0],
                    'glocal': [Greg0, Greg0]
                })
        else:
            ndn_par = NDNutils.ffnetwork_params(
                input_dims=[1, NX2, NX2, num_lags],
                layer_sizes=[num_subs, num_hid],
                layer_types=['normal', 'normal'],
                ei_layers=[num_subs // 2, num_hid // 2],
                normalization=[1, 1],
                act_funcs=['relu', 'relu'],
                verbose=True,
                reg_list={
                    'd2t': [d2t],
                    'd2x': [d2x],
                    'l1': [L1reg0, L1reg0],
                    'glocal': [Greg0]
                })

        side_par = NDNutils.ffnetwork_params(network_type='side',
                                             xstim_n=None,
                                             ffnet_n=0,
                                             layer_sizes=[NC],
                                             layer_types=['normal'],
                                             normalization=[-1],
                                             act_funcs=['softplus'],
                                             verbose=True,
                                             reg_list={'max': [Mreg0]})

        side_par['pos_constraints'] = True  # ensures Exc and Inh mean something

        if autoencoder:  # capture additional variability using an autoencoder
            auto_par = NDNutils.ffnetwork_params(
                input_dims=[1, NC, 1],
                xstim_n=[1],
                layer_sizes=[2, 1, NC],
                time_expand=[0, 15, 0],
                layer_types=['normal', 'temporal', 'normal'],
                conv_filter_widths=[None, 1, None],
                act_funcs=['relu', 'lin', 'lin'],
                normalization=[1, 1, 0],
                reg_list={'d2t': [None, 1e-1, None]})

            add_par = NDNutils.ffnetwork_params(xstim_n=None,
                                                ffnet_n=[1, 2],
                                                layer_sizes=[NC],
                                                layer_types=['add'],
                                                act_funcs=['softplus'])

            side2 = NDN.NDN([ndn_par, side_par, auto_par, add_par],
                            ffnet_out=1,
                            noise_dist='poisson')

            # set output regularization on the latent
            side2.batch_size = adam_params['batch_size']
            side2.initialize_output_reg(network_target=2,
                                        layer_target=1,
                                        reg_vals={'d2t': 1e-1})

            input_data = [Xstim, Rvalid]

        else:
            side2 = NDN.NDN([ndn_par, side_par],
                            ffnet_out=1,
                            noise_dist='poisson')

            input_data = Xstim

        _ = side2.train(input_data=input_data,
                        output_data=Rvalid,
                        train_indxs=Ui,
                        test_indxs=Xi,
                        silent=False,
                        learning_alg='adam',
                        opt_params=adam_params)

        side2.set_regularization('glocal', Greg, layer_target=0)
        side2.set_regularization('l1', L1reg, layer_target=0)
        side2.set_regularization('max', MIreg, ffnet_target=0, layer_target=1)
        side2.set_regularization('max', MSCreg, ffnet_target=1, layer_target=0)

        side2b = side2.copy_model()

    _ = side2b.train(input_data=input_data,
                     output_data=Rvalid,
                     train_indxs=Ui,
                     test_indxs=Xi,
                     silent=False,
                     learning_alg='adam',
                     opt_params=adam_params)

    LLs2n = side2b.eval_models(input_data=input_data,
                               output_data=Rvalid,
                               data_indxs=Xi,
                               nulladjusted=True)
    print(np.mean(LLs2n))
    if plot:
        plt.hist(LLs2n)
        plt.xlabel('Nats/Spike')
        plt.show()

    return side2b, Xstim, Rvalid, rinds, cids, Cindx
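# A hypothetical call showing how get_stim_model is typically invoked; Stim,
# Robs, NX, NY and valid_inds are assumed to be prepared elsewhere (compare the
# similar ne.prep_stim_model call in Code Example #3):
side2b, Xstim, Rvalid, rinds, cids, Cindx = get_stim_model(
    Stim, Robs, [NX, NY],
    valid=valid_inds,
    num_lags=10,
    plot=True)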
Code Example #3
    plt.subplot(sx,sy,cc+1)
    bestlag = np.argmax(np.max(abs(stas[:,:,cc]),axis=0))
    sta = stas[:,bestlag,cc]
    plt.imshow(np.reshape(sta[Cxinds], [NX2,NX2]))

# %% prepare stimulus for model
Xstim, Rvalid, dims, CXinds, cids = ne.prep_stim_model(Stim, Robs, [NX, NY],
                                                       valid=valid_inds,
                                                       num_lags=10,
                                                       plot=True,
                                                       Cindx=Cxinds)

#%% fit GQM
NT = Xstim.shape[0]
Ui, Xi = NDNutils.generate_xv_folds(NT)

# optimizer parameters
adam_params = U.def_adam_params()

# d2ts = 1e-4*10**np.arange(0, 5)

d2xs = 1e-2*10**np.arange(0, 5)
gqms = []
LLxs = []
for step in range(len(d2xs)):

    d2t = .05
    d2x = d2xs[step]
    loc = 1e-5
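    # The excerpt ends mid-loop. A hedged sketch of how such a regularization
    # sweep typically continues, reusing only NDN calls shown in Code Example #2;
    # `gqm_par` stands in for the GQM network parameters, which this excerpt
    # does not show:
    # gqm = NDN.NDN([gqm_par], noise_dist='poisson')
    # gqm.set_regularization('d2t', d2t, layer_target=0)
    # gqm.set_regularization('d2x', d2x, layer_target=0)
    # gqm.set_regularization('glocal', loc, layer_target=0)  # assumption: loc is a locality penalty
    # _ = gqm.train(input_data=Xstim, output_data=Rvalid, train_indxs=Ui,
    #               test_indxs=Xi, learning_alg='adam', opt_params=adam_params)
    # LLxs.append(gqm.eval_models(input_data=Xstim, output_data=Rvalid,
    #                             data_indxs=Xi, nulladjusted=True))
    # gqms.append(gqm)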
Code Example #4
stas = Xstim.T@(Rvalid - np.mean(Rvalid))
stas = np.reshape(stas, [NX2*NX2, num_lags, NC])/NT
    
plt.figure(figsize=(10,15))
for cc in range(NC):
    plt.subplot(sx,sy,cc+1)
    bestlag = np.argmax(np.max(abs(stas[:,:,cc]),axis=0))
    # plt.plot(stas[:,:,cc].T)
    plt.imshow(np.reshape(stas[:,bestlag,cc], (NX2,NX2)))
    plt.title(cc)
plt.show()

#%%
Rvalid = deepcopy(Robs[:,cids])
# train, test indices    
Ui, Xi = NDNutils.generate_xv_folds(NT)
    
# fit SCAFFOLD MODEL
    
# Best regularization arrived at
Greg0 = 1e-1
Mreg0 = 1e-6
L1reg0 = 1e-5

XTreg = 0.05
L1reg = 5e-3
MIreg = 0.1
MSCreg = 10.0
Greg = 0.1
Mreg = 1e-4
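# These values match the regularization defaults of get_stim_model in Code
# Example #2. A minimal sketch of how they would be applied to a fitted
# scaffold model (here called `side2`, as in that example) before a final
# refit, mirroring the set_regularization calls shown there:
side2.set_regularization('d2t', XTreg, layer_target=0)
side2.set_regularization('d2x', XTreg, layer_target=0)
side2.set_regularization('glocal', Greg, layer_target=0)
side2.set_regularization('l1', L1reg, layer_target=0)
side2.set_regularization('max', MIreg, ffnet_target=0, layer_target=1)
side2.set_regularization('max', MSCreg, ffnet_target=1, layer_target=0)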