Пример #1
0
def get_stas(Stim, Robs, dims, valid=None, num_lags=10, plot=True):
    """Compute spike-triggered averages (STAs) for every unit.

    The stimulus is time-embedded over ``num_lags`` lags (restricted to the
    ``valid`` time points) and each unit's mean-subtracted response is
    projected onto it.

    Args:
        Stim: stimulus passed to ``create_time_embedding_valid``
            (presumably (NT, NX*NY) -- TODO confirm).
        Robs: responses, shape (NT, NC).
        dims: ``(NX, NY)`` spatial size of the stimulus.
        valid: indices of usable time points; defaults to all NT samples.
        num_lags: number of time lags in the embedding.
        plot: if True, plot each unit's |STA| time courses together with a
            robust threshold line (6x the median absolute deviation of the
            STA from its mean).

    Returns:
        Array of shape (NX*NY, num_lags, NC): each unit's STA, normalized by
        the number of valid samples.
    """
    NX, NY = dims[0], dims[1]
    NT, NC = Robs.shape

    if valid is None:
        valid = np.arange(0, NT, 1)

    # create time-embedded stimulus restricted to the valid samples
    Xstim, rinds = create_time_embedding_valid(Stim, [num_lags, NX, NY], valid)
    Rvalid = deepcopy(Robs[rinds, :])

    NTv = Rvalid.shape[0]
    print('%d valid samples of %d possible' % (NTv, NT))

    # correlate the stimulus with the mean-subtracted responses
    stas = Xstim.T @ (Rvalid - np.average(Rvalid, axis=0))
    stas = np.reshape(stas, [NX * NY, num_lags, NC]) / NTv

    # The threshold below is only a visual aid, so skip it entirely when
    # not plotting (the original computed and then discarded it).
    if not plot:
        return stas

    plt.figure(figsize=(10, 15))
    sx, sy = U.get_subplot_dims(NC)
    for cc in range(NC):
        sta = stas[:, :, cc]
        plt.subplot(sx, sy, cc + 1)
        plt.plot(np.abs(sta).T, color=[.5, .5, .5])
        # robust level: 6x the median absolute deviation from the STA mean
        tlevel = np.median(np.abs(sta - np.average(sta))) * 6
        plt.axhline(tlevel, color='k')
        plt.title(cc)
    return stas
Пример #2
0
# Eye-position gating: keep only frames where gaze is within valid_eye_rad
# degrees of the reference point (640, 380).
valid_eye_rad = 5.2  # degrees -- use this when looking at eye-calibration (see below)
ppd = 37.50476617061  # presumably pixels per degree -- converts eye offsets below to degrees

# Eye offsets from (640, 380) -- presumably the screen center in pixels
# (TODO confirm) -- divided by ppd.
eyeX = (eyeAtFrame[:,0]-640)/ppd
eyeY = (eyeAtFrame[:,1]-380)/ppd

# True for frames whose radial eye offset is below valid_eye_rad.
eyeCentered = np.hypot(eyeX, eyeY) < valid_eye_rad
# eyeCentered = np.logical_and(eyeX < 0, eyeCentered)
# Valid samples: indices present in valdata that are also eye-centered.
valid_inds = np.intersect1d(valdata, np.where(eyeCentered)[0])

# %% quick check STAS
# stas is indexed [pixel, lag, unit] below; presumably (NX*NY, num_lags, NC)
# as returned by get_stas -- confirm for ne.get_stas.
stas = ne.get_stas(Stim, Robs, [NX,NY], valid=valid_inds, num_lags=10, plot=False)

# One subplot per unit: the spatial STA at the lag with the largest |STA|.
plt.figure(figsize=(10,10))
NC = Robs.shape[1]
sx,sy = U.get_subplot_dims(NC)
sumdensity = 0  # accumulates |STA| at each unit's best lag (becomes an array after the first +=)
for cc in range(NC):
    plt.subplot(sx,sy,cc+1)
    bestlag = np.argmax(np.max(abs(stas[:,:,cc]),axis=0))
    plt.imshow(np.reshape(stas[:,bestlag,cc], (NY,NX)))
    sumdensity += np.abs(stas[:,bestlag,cc])
    plt.title(cc)
    plt.axis("off")
    
# Population map of STA magnitude; presumably inspected to pick the
# hand-set crop window below.
plt.figure()
plt.imshow(np.reshape(sumdensity, (NY, NX)))

#%%
# Cxinds = ne.crop_indx(NX, range(9,24), range(9,24))
# Crop indices for the square window x, y in range(5, 25) on an NX-wide grid.
Cxinds = ne.crop_indx(NX, range(5,25), range(5,25))
Пример #3
0
    # NOTE(review): fragment -- body of a loop over unit index `cc` whose
    # header is not shown here.
    # Direction tuning of unit cc: observed (Robs) vs. predictions yhat0 / yhat1.
    r = get_psth(Robs[:, cc], gratingDir, win)
    rhat0 = get_psth(yhat0[:, cc], gratingDir, win)
    rhat1 = get_psth(yhat1[:, cc], gratingDir, win)
    # Sum the first get_psth output over axis 0 to get one value per grating
    # direction (plotted against `directions` below); assumes that output is
    # (time, direction) -- TODO confirm.
    TC[:, cc] = np.sum(r[0], axis=0)
    TC1[:, cc] = np.sum(rhat0[0], axis=0)
    TC2[:, cc] = np.sum(rhat1[0], axis=0)

    # data: black with markers; yhat0: red; yhat1: blue
    plt.subplot(sx, sy, cc + 1)
    plt.plot(directions, TC[:, cc], 'k-o')
    plt.plot(directions, TC1[:, cc], 'r')
    plt.plot(directions, TC2[:, cc], 'b')
    plt.xticks(np.arange(0, 360, 90))
    sns.despine(offset=0, trim=True)

# Tuning-curve fit quality per unit: base model (TC1, from yhat0) on x,
# direction model (TC2, from yhat1) on y.
plt.figure()
r2Base = U.r_squared(TC, TC1)
r2Dir = U.r_squared(TC, TC2)
plt.plot(r2Base, r2Dir, 'o')
plt.xlim((0, 1))
plt.ylim((0, 1))
plt.plot(plt.xlim(), plt.xlim(), 'k')  # identity line: points above it have higher r2Dir
plt.title("R-squared Tuning Curve")

# Same comparison on the raw responses at indices Xi (presumably the
# held-out samples -- confirm).
r2Base = U.r_squared(Robs[Xi, :], yhat0[Xi, :])
r2Dir = U.r_squared(Robs[Xi, :], yhat1[Xi, :])
plt.figure()
plt.plot(r2Base, r2Dir, 'o')
plt.plot(plt.xlim(), plt.xlim(), 'k')  # identity line
plt.xlabel("Var Explained (Base)")
plt.ylabel("Var Explained (Dir)")
Пример #4
0
def get_stim_model(Stim,
                   Robs,
                   dims,
                   valid=None,
                   num_lags=10,
                   plot=True,
                   XTreg=0.05,
                   L1reg=5e-3,
                   MIreg=0.1,
                   MSCreg=10.0,
                   Greg=0.1,
                   Mreg=1e-4,
                   num_subs=36,
                   num_hid=24,
                   num_tkern=None,
                   Cindx=None,
                   base_mod=None,
                   cids=None,
                   autoencoder=False):
    """Select well-driven units, crop the stimulus around their RFs, and fit a model.

    Steps:
      1. Compute STAs on the full stimulus and keep units whose STA has
         enough large excursions (unless ``cids`` is given).
      2. Choose a square crop around the summed STA energy (unless ``Cindx``
         is given) and rebuild the time-embedded stimulus on that crop.
      3. Either build and fit a new NDN (poisson noise), or copy
         ``base_mod``; then set the final regularization and (re)fit on
         folds from ``NDNutils.generate_xv_folds``.

    Args:
        Stim: stimulus passed to ``create_time_embedding_valid``; its columns
            are cropped with ``Stim[:, Cindx]`` (presumably the NX*NY pixels).
        Robs: responses, shape (NT, NC).
        dims: ``(NX, NY)`` spatial size of the stimulus.
        valid: usable time indices; defaults to all NT samples.
        num_lags: number of time lags in the stimulus embedding.
        plot: if True, show diagnostic STA / crop / log-likelihood plots.
        XTreg: d2t/d2x regularization. A length-2 sequence gives
            ``(d2t, d2x)``; any other sequence uses its first element for
            both; a scalar is used for both.
        L1reg, Greg: 'l1' and 'glocal' regularization on layer 0 for the
            final fit.
        MIreg: 'max' regularization on ffnet 0, layer 1 for the final fit.
        MSCreg: 'max' regularization on ffnet 1, layer 0 for the final fit.
        Mreg: currently unused; kept for backward compatibility (a new model
            uses a fixed 'max' value of 1e-6 on the side network).
        num_subs, num_hid: sizes of the two stimulus-processing layers.
        num_tkern: if given, prepend a conv layer with this many kernels.
        Cindx: precomputed crop indices; computed from the STAs when None.
        base_mod: existing model to copy and refit instead of building one.
        cids: indices of units to model; selected from the STAs when None.
        autoencoder: when building a new model, add an autoencoder network
            that also takes the responses as input.

    Returns:
        (side2b, Xstim, Rvalid, rinds, cids, Cindx): the fitted model, the
        cropped time-embedded stimulus, responses of the selected units
        (shifted by -1 sample with ``shift_mat_zpad``), the valid row
        indices, the unit indices and the crop indices.

    Raises:
        ValueError: if no unit is selected, or the summed STA density is
            flat so no crop window can be chosen.
    """
    NX = dims[0]
    NY = dims[1]

    NT, NC = Robs.shape

    if valid is None:
        valid = np.arange(0, NT, 1)

    # create time-embedded stimulus
    Xstim, rinds = create_time_embedding_valid(Stim, [num_lags, NX, NY], valid)
    Rvalid = deepcopy(Robs[rinds, :])

    NTv = Rvalid.shape[0]
    print('%d valid samples of %d possible' % (NTv, NT))

    # STAs of the mean-subtracted responses, shape (NX*NY, num_lags, NC)
    stas = Xstim.T @ (Rvalid - np.average(Rvalid, axis=0))
    stas = np.reshape(stas, [NX * NY, num_lags, NC]) / NTv

    if plot:
        plt.figure(figsize=(10, 15))
        sx, sy = U.get_subplot_dims(NC)

    # mu[cc]: fraction of |STA| entries above 4x the median absolute
    # deviation from the STA mean -- a robust "has structure" score
    mu = np.zeros(NC)
    for cc in range(NC):
        if plot:
            plt.subplot(sx, sy, cc + 1)
            plt.plot(np.abs(stas[:, :, cc]).T, color=[.5, .5, .5])
        tlevel = np.median(
            np.abs(stas[:, :, cc] - np.average(stas[:, :, cc]))) * 4
        mu[cc] = np.average(np.abs(stas[:, :, cc]) > tlevel)

        if plot:
            plt.axhline(tlevel, color='k')
            plt.title(cc)

    # threshold good STAS
    thresh = 0.01
    if plot:
        plt.figure()
        plt.plot(mu, '-o')
        plt.axhline(thresh, color='k')
        plt.show()

    if cids is None:
        cids = np.where(mu > thresh)[0]  # units to analyze
        print("found %d good STAs" % len(cids))

    if len(cids) == 0:
        # Without this the crop step fails later with an opaque
        # zero-size-array error from np.min.
        raise ValueError("no units to model: no STA passed the threshold")

    if plot:
        plt.figure(figsize=(10, 15))
        for cc in cids:
            plt.subplot(sx, sy, cc + 1)
            bestlag = np.argmax(np.max(abs(stas[:, :, cc]), axis=0))
            plt.imshow(np.reshape(stas[:, bestlag, cc], (NY, NX)))
            plt.title(cc)

    # index into "good" units
    Rvalid = Rvalid[:, cids]
    NC = Rvalid.shape[1]
    stas = stas[:, :, cids]

    if Cindx is None:
        print("Getting Crop Index")
        # Crop stimulus to center around RFs: sum the squared STA of each
        # unit at its best lag (the lag with the largest |STA|)
        sumdensity = np.zeros([NX * NY])
        for cc in range(NC):
            bestlag = np.argmax(np.max(abs(stas[:, :, cc]), axis=0))
            sumdensity += stas[:, bestlag, cc]**2

        if plot:
            plt.figure()
            plt.imshow(np.reshape(sumdensity, [NY, NX]))
            plt.title("Sum Density STA")

        # normalize to [0, 1]; a flat map would otherwise give 0/0 (NaN)
        dmin = np.min(sumdensity)
        dmax = np.max(sumdensity)
        if dmax <= dmin:
            raise ValueError(
                "cannot choose a crop window: summed STA density is flat")
        sumdensity = (sumdensity - dmin) / (dmax - dmin)

        # columns (x) and rows (y) containing any pixel above 30% of peak
        I = np.reshape(sumdensity, [NY, NX]) > .3
        xinds = np.where(np.sum(I, axis=0) > 0)[0]
        yinds = np.where(np.sum(I, axis=1) > 0)[0]

        # Square window covering the larger extent, never larger than the
        # frame, slid back inside when it would run past the right/bottom
        # edge (previously it could index pixels outside the stimulus).
        NX2 = min(max(len(xinds), len(yinds)), NX, NY)
        x0 = min(np.min(xinds), NX - NX2)
        y0 = min(np.min(yinds), NY - NX2)

        xinds = range(x0, x0 + NX2)
        yinds = range(y0, y0 + NX2)

        Cindx = crop_indx(NX, xinds, yinds)

        if plot:
            plt.figure()
            plt.imshow(np.reshape(sumdensity[Cindx], [NX2, NX2]))
            plt.title('Cropped')
            plt.show()

    # side length of the (square) crop
    NX2 = np.sqrt(len(Cindx)).astype(int)

    # make new cropped stimulus
    Xstim, rinds = create_time_embedding_valid(Stim[:, Cindx],
                                               [num_lags, NX2, NX2], valid)

    # index into Robs
    Rvalid = deepcopy(Robs[rinds, :])
    Rvalid = Rvalid[:, cids]
    Rvalid = NDNutils.shift_mat_zpad(Rvalid, -1, dim=0)  # get rid of first lag

    NC = Rvalid.shape[1]  # new number of units
    NT = Rvalid.shape[0]
    print('%d valid samples of %d possible' % (NT, Stim.shape[0]))
    print('%d good units' % NC)

    # double-check STAS work with cropped stimulus
    stas = Xstim.T @ Rvalid
    stas = np.reshape(stas, [NX2 * NX2, num_lags, NC]) / NT

    if plot:
        plt.figure(figsize=(10, 15))
        for cc in range(NC):
            plt.subplot(sx, sy, cc + 1)
            bestlag = np.argmax(np.max(abs(stas[:, :, cc]), axis=0))
            plt.imshow(np.reshape(stas[:, bestlag, cc], (NX2, NX2)))
            plt.title(cc)
        plt.show()

    # train / test indices used for every fit and evaluation below
    Ui, Xi = NDNutils.generate_xv_folds(NT)

    # fit SCAFFOLD MODEL
    # XTreg -> (d2t, d2x): a length-2 sequence supplies both; any other
    # sequence uses its first element for both; a scalar has no len() and
    # is used for both.
    try:
        if len(XTreg) == 2:
            d2t = XTreg[0]
            d2x = XTreg[1]
        else:
            d2t = XTreg[0]
            d2x = deepcopy(d2t)
    except TypeError:
        d2t = deepcopy(XTreg)
        d2x = deepcopy(XTreg)

    # optimizer parameters
    adam_params = U.def_adam_params()

    if base_mod is not None:
        # refit a copy of an existing model with the requested regularization
        side2b = base_mod.copy_model()
        side2b.set_regularization('d2t', d2t, layer_target=0)
        side2b.set_regularization('d2x', d2x, layer_target=0)
        side2b.set_regularization('glocal', Greg, layer_target=0)
        side2b.set_regularization('l1', L1reg, layer_target=0)
        side2b.set_regularization('max', MIreg, ffnet_target=0, layer_target=1)
        side2b.set_regularization('max',
                                  MSCreg,
                                  ffnet_target=1,
                                  layer_target=0)

        if len(side2b.networks) == 4:  # includes autoencoder network
            input_data = [Xstim, Rvalid]
        else:
            input_data = Xstim

    else:
        # Best regularization arrived at
        Greg0 = 1e-1
        Mreg0 = 1e-6
        L1reg0 = 1e-5

        if num_tkern is not None:
            # conv layer of temporal kernels first; its d2t is fixed at 1e-3
            # and the spatial smoothness d2x goes on the subunit layer
            ndn_par = NDNutils.ffnetwork_params(
                input_dims=[1, NX2, NX2, num_lags],
                layer_sizes=[num_tkern, num_subs, num_hid],
                layer_types=['conv', 'normal', 'normal'],
                ei_layers=[None, num_subs // 2, num_hid // 2],
                conv_filter_widths=[1],
                normalization=[1, 1, 1],
                act_funcs=['lin', 'relu', 'relu'],
                verbose=True,
                reg_list={
                    'd2t': [1e-3],
                    # use the parsed d2x (the raw XTreg may be a 2-sequence)
                    'd2x': [None, d2x],
                    'l1': [L1reg0, L1reg0],
                    'glocal': [Greg0, Greg0]
                })
        else:
            ndn_par = NDNutils.ffnetwork_params(
                input_dims=[1, NX2, NX2, num_lags],
                layer_sizes=[num_subs, num_hid],
                layer_types=['normal', 'normal'],
                ei_layers=[num_subs // 2, num_hid // 2],
                normalization=[1, 1],
                act_funcs=['relu', 'relu'],
                verbose=True,
                reg_list={
                    'd2t': [d2t],
                    'd2x': [d2x],
                    'l1': [L1reg0, L1reg0],
                    'glocal': [Greg0]
                })

        side_par = NDNutils.ffnetwork_params(network_type='side',
                                             xstim_n=None,
                                             ffnet_n=0,
                                             layer_sizes=[NC],
                                             layer_types=['normal'],
                                             normalization=[-1],
                                             act_funcs=['softplus'],
                                             verbose=True,
                                             reg_list={'max': [Mreg0]})

        side_par[
            'pos_constraints'] = True  # ensures Exc and Inh mean something

        if autoencoder:  # capture additional variability using autoencoder
            auto_par = NDNutils.ffnetwork_params(
                input_dims=[1, NC, 1],
                xstim_n=[1],
                layer_sizes=[2, 1, NC],
                time_expand=[0, 15, 0],
                layer_types=['normal', 'temporal', 'normal'],
                conv_filter_widths=[None, 1, None],
                act_funcs=['relu', 'lin', 'lin'],
                normalization=[1, 1, 0],
                reg_list={'d2t': [None, 1e-1, None]})

            add_par = NDNutils.ffnetwork_params(xstim_n=None,
                                                ffnet_n=[1, 2],
                                                layer_sizes=[NC],
                                                layer_types=['add'],
                                                act_funcs=['softplus'])

            side2 = NDN.NDN([ndn_par, side_par, auto_par, add_par],
                            ffnet_out=1,
                            noise_dist='poisson')

            # set output regularization on the latent
            side2.batch_size = adam_params['batch_size']
            side2.initialize_output_reg(network_target=2,
                                        layer_target=1,
                                        reg_vals={'d2t': 1e-1})

            input_data = [Xstim, Rvalid]

        else:
            side2 = NDN.NDN([ndn_par, side_par],
                            ffnet_out=1,
                            noise_dist='poisson')

            input_data = Xstim

        # initial fit with the fixed regularization above
        _ = side2.train(input_data=input_data,
                        output_data=Rvalid,
                        train_indxs=Ui,
                        test_indxs=Xi,
                        silent=False,
                        learning_alg='adam',
                        opt_params=adam_params)

        # switch to the caller's regularization for the final fit
        side2.set_regularization('glocal', Greg, layer_target=0)
        side2.set_regularization('l1', L1reg, layer_target=0)
        side2.set_regularization('max', MIreg, ffnet_target=0, layer_target=1)
        side2.set_regularization('max', MSCreg, ffnet_target=1, layer_target=0)

        side2b = side2.copy_model()

    _ = side2b.train(input_data=input_data,
                     output_data=Rvalid,
                     train_indxs=Ui,
                     test_indxs=Xi,
                     silent=False,
                     learning_alg='adam',
                     opt_params=adam_params)

    # null-adjusted log-likelihood per unit on the test indices
    LLs2n = side2b.eval_models(input_data=input_data,
                               output_data=Rvalid,
                               data_indxs=Xi,
                               nulladjusted=True)
    print(np.mean(LLs2n))
    if plot:
        plt.hist(LLs2n)
        plt.xlabel('Nats/Spike')
        plt.show()

    return side2b, Xstim, Rvalid, rinds, cids, Cindx
Пример #5
0
def prep_stim_model(Stim,
                    Robs,
                    dims,
                    valid=None,
                    num_lags=10,
                    plot=True,
                    Cindx=None,
                    cids=None):
    """Select well-driven units and build a time-embedded stimulus cropped to their RFs.

    Performs the same unit selection and cropping as ``get_stim_model``, but
    returns the prepared data instead of fitting a model.

    Args:
        Stim: stimulus passed to ``create_time_embedding_valid``; its columns
            are cropped with ``Stim[:, Cindx]`` (presumably the NX*NY pixels).
        Robs: responses, shape (NT, NC).
        dims: ``(NX, NY)`` spatial size of the stimulus.
        valid: usable time indices; defaults to all NT samples.
        num_lags: number of time lags in the stimulus embedding.
        plot: if True, show diagnostic STA and crop plots.
        Cindx: precomputed crop indices; computed from the STAs when None.
        cids: indices of units to keep; selected from the STAs when None.

    Returns:
        (Xstim, Rvalid, dims, Cindx, cids): the cropped time-embedded
        stimulus, responses of the selected units (shifted by -1 sample with
        ``shift_mat_zpad``), ``(num_lags, NX2, NX2)`` with NX2 the crop side,
        the crop indices and the unit indices.

    Raises:
        ValueError: if no unit is selected, or the summed STA density is
            flat so no crop window can be chosen.
    """
    NX = dims[0]
    NY = dims[1]

    NT, NC = Robs.shape

    if valid is None:
        valid = np.arange(0, NT, 1)

    # create time-embedded stimulus
    Xstim, rinds = create_time_embedding_valid(Stim, [num_lags, NX, NY], valid)
    Rvalid = deepcopy(Robs[rinds, :])

    NTv = Rvalid.shape[0]
    print('%d valid samples of %d possible' % (NTv, NT))

    # STAs of the mean-subtracted responses, shape (NX*NY, num_lags, NC)
    stas = Xstim.T @ (Rvalid - np.average(Rvalid, axis=0))
    stas = np.reshape(stas, [NX * NY, num_lags, NC]) / NTv

    if plot:
        plt.figure(figsize=(10, 15))
        sx, sy = U.get_subplot_dims(NC)

    # mu[cc]: fraction of |STA| entries above 4x the median absolute
    # deviation from the STA mean -- a robust "has structure" score
    mu = np.zeros(NC)
    for cc in range(NC):
        if plot:
            plt.subplot(sx, sy, cc + 1)
            plt.plot(np.abs(stas[:, :, cc]).T, color=[.5, .5, .5])
        tlevel = np.median(
            np.abs(stas[:, :, cc] - np.average(stas[:, :, cc]))) * 4
        mu[cc] = np.average(np.abs(stas[:, :, cc]) > tlevel)

        if plot:
            plt.axhline(tlevel, color='k')
            plt.title(cc)

    # threshold good STAS
    thresh = 0.01
    if plot:
        plt.figure()
        plt.plot(mu, '-o')
        plt.axhline(thresh, color='k')
        plt.show()

    if cids is None:
        cids = np.where(mu > thresh)[0]  # units to analyze
        print("found %d good STAs" % len(cids))

    if len(cids) == 0:
        # Without this the crop step fails later with an opaque
        # zero-size-array error from np.min.
        raise ValueError("no units to model: no STA passed the threshold")

    if plot:
        plt.figure(figsize=(10, 15))
        for cc in cids:
            plt.subplot(sx, sy, cc + 1)
            bestlag = np.argmax(np.max(abs(stas[:, :, cc]), axis=0))
            plt.imshow(np.reshape(stas[:, bestlag, cc], (NY, NX)))
            plt.title(cc)

    # index into "good" units
    Rvalid = Rvalid[:, cids]
    NC = Rvalid.shape[1]
    stas = stas[:, :, cids]

    if Cindx is None:
        print("Getting Crop Index")
        # Crop stimulus to center around RFs: sum the squared STA of each
        # unit at its best lag (the lag with the largest |STA|)
        sumdensity = np.zeros([NX * NY])
        for cc in range(NC):
            bestlag = np.argmax(np.max(abs(stas[:, :, cc]), axis=0))
            sumdensity += stas[:, bestlag, cc]**2

        if plot:
            plt.figure()
            plt.imshow(np.reshape(sumdensity, [NY, NX]))
            plt.title("Sum Density STA")

        # normalize to [0, 1]; a flat map would otherwise give 0/0 (NaN)
        dmin = np.min(sumdensity)
        dmax = np.max(sumdensity)
        if dmax <= dmin:
            raise ValueError(
                "cannot choose a crop window: summed STA density is flat")
        sumdensity = (sumdensity - dmin) / (dmax - dmin)

        # columns (x) and rows (y) containing any pixel above 30% of peak
        I = np.reshape(sumdensity, [NY, NX]) > .3
        xinds = np.where(np.sum(I, axis=0) > 0)[0]
        yinds = np.where(np.sum(I, axis=1) > 0)[0]

        # Square window covering the larger extent, never larger than the
        # frame, slid back inside when it would run past the right/bottom
        # edge (previously it could index pixels outside the stimulus).
        NX2 = min(max(len(xinds), len(yinds)), NX, NY)
        x0 = min(np.min(xinds), NX - NX2)
        y0 = min(np.min(yinds), NY - NX2)

        xinds = range(x0, x0 + NX2)
        yinds = range(y0, y0 + NX2)

        Cindx = crop_indx(NX, xinds, yinds)

        if plot:
            plt.figure()
            plt.imshow(np.reshape(sumdensity[Cindx], [NX2, NX2]))
            plt.title('Cropped')
            plt.show()

    # side length of the (square) crop
    NX2 = np.sqrt(len(Cindx)).astype(int)

    # make new cropped stimulus
    Xstim, rinds = create_time_embedding_valid(Stim[:, Cindx],
                                               [num_lags, NX2, NX2], valid)

    # index into Robs
    Rvalid = deepcopy(Robs[rinds, :])
    Rvalid = Rvalid[:, cids]
    Rvalid = NDNutils.shift_mat_zpad(Rvalid, -1, dim=0)  # get rid of first lag

    NC = Rvalid.shape[1]  # new number of units
    NT = Rvalid.shape[0]
    print('%d valid samples of %d possible' % (NT, Stim.shape[0]))
    print('%d good units' % NC)

    # double-check STAS work with cropped stimulus
    stas = Xstim.T @ Rvalid
    stas = np.reshape(stas, [NX2 * NX2, num_lags, NC]) / NT

    if plot:
        plt.figure(figsize=(10, 15))
        for cc in range(NC):
            plt.subplot(sx, sy, cc + 1)
            bestlag = np.argmax(np.max(abs(stas[:, :, cc]), axis=0))
            plt.imshow(np.reshape(stas[:, bestlag, cc], (NX2, NX2)))
            plt.title(cc)
        plt.show()

    dims = (num_lags, NX2, NX2)
    return Xstim, Rvalid, dims, Cindx, cids
Пример #6
0
#%%

# Show the trial-averaged prediction for the next unit each time this cell
# runs. The modulo wraps back to unit 0 after the last one and also
# recovers if `cc` was left out of range.
ypred = model(xt)
cc = (cc + 1) % gd.NC
a = ypred[:,cc].detach().cpu().numpy()
# reshape prediction and data to (num_repeats, -1), then average over repeats
r0 = np.reshape(a, (gd.opts['num_repeats'],-1))
r = np.reshape(yt[:,cc], (gd.opts['num_repeats'],-1))

r = np.average(r, axis=0)
r0 = np.average(r0, axis=0)
plt.plot(r)
plt.plot(r0)
plt.title("cell %d" %cc)
# value is only displayed when run interactively; it is not stored
U.r_squared(np.reshape(r, (-1,1)), np.reshape(r0, (-1,1)))


# %%
# First-layer weights, one image per filter (rows: lags, columns: space).
w = model.l1.weight.detach().cpu().numpy()

nfilt = w.shape[0]
sx,sy = U.get_subplot_dims(nfilt)
plt.figure(figsize=(10,10))
# dedicated index so the unit counter `cc` used above is not clobbered
for ifilt in range(nfilt):
    plt.subplot(sx,sy,ifilt+1)
    # NOTE(review): assumes each weight row is ordered (num_lags, NX*NY) -- confirm
    wtmp = np.reshape(w[ifilt,:], (gd.num_lags, gd.NX*gd.NY))
    plt.imshow(wtmp, aspect='auto')
Пример #7
0
# get median absolute deviation per unit
# stas is indexed [:, unit] here, with each column reshaped to
# (num_lags, -1) below, i.e. one flattened STA per unit.
med = np.median(stas, axis=0)
devs = np.abs(stas - med)
mad = np.median(devs, axis=0) # median absolute deviation

# Fraction of STA entries deviating by more than 4 MADs; keep units above thresh.
excursions = np.mean(devs > 4*mad, axis=0)
thresh = 0.01
cids = np.where(excursions>thresh)[0]
plt.plot(excursions, '-o')
plt.plot(cids, excursions[cids], 'o')  # highlight the selected units


#%% plot stas
NC = gd.NC
sx,sy = U.get_subplot_dims(NC)
sx*=2  # two panels per unit: spatial map, then temporal profile
plt.figure(figsize=(10,20))
for cc in range(NC):
    plt.subplot(sx,sy,cc*2+1)
    wtmp = stas[:,cc].reshape((gd.num_lags, -1))  # (num_lags, space)
    tpower = np.std(wtmp, axis=1)  # spatial std at each lag
    bestlag = np.argmax(tpower)
    wspace = wtmp[bestlag,:].reshape( (gd.NY, gd.NX))
    plt.imshow(wspace, aspect='auto')
    plt.axis("off")
    if cc in cids:
        plt.title(cc)
    plt.subplot(sx,sy,cc*2+2)
    if cc in cids:
        # time course at the pixel with the largest (signed) STA value at bestlag;
        # argmax on wspace gives a flat index matching wtmp's columns
        plt.plot(wtmp[:,np.argmax(wspace)], '-ob')
Пример #8
0
# NOTE(review): `sample` here presumably holds the training data loaded in an
# earlier cell -- TODO confirm.
# permute((0, 2, 3, 1)) moves axis 1 of the stimulus tensor to the end before
# flattening with `flat`; both arrays are then copied to numpy on the CPU.
Xstim = flat(sample['stim'].permute((0, 2, 3, 1))).detach().cpu().numpy()
Robs = sample['robs'].detach().cpu().numpy()

# test set
sample = gab_shift_test[:]  # load sample

Xstim_test = flat(sample['stim'].permute((0, 2, 3, 1))).detach().cpu().numpy()
Robs_test = sample['robs'].detach().cpu().numpy()

# Cross-validation folds over the training samples; Ui / Xi are presumably
# the train / test indices (as used with train_indxs / test_indxs elsewhere).
NT = Xstim.shape[0]
Ui, Xi = NDNutils.generate_xv_folds(NT)

num_lags = gd_shift.num_lags
NC = Robs.shape[1]

adam_params = U.def_adam_params()
#%% Make model

# Regularization strengths. The "0" variants presumably apply to an initial
# fit and the others to a refit, as in get_stim_model -- TODO confirm.
Greg0 = 1e-1
Greg = 1e-1
Creg0 = 1
Creg = 1e-2
Mreg0 = 1e-3
Mreg = 1e-1
L1reg0 = 1e-5
Xreg = 1e-2

num_tkern = 2
num_subs = 8

# ndn_par = NDNutils.ffnetwork_params(