def fit_shifter(
        name='20200304_kilowf',
        path='/home/jake/Data/Datasets/MitchellV1FreeViewing/stim_movies/',
        tdownsample=2,
        numlags=12,
        lengthscale=1,
        stimlist=["Gabor", "Dots", "BackImage", "Grating", "FixRsvpStim"]):
    """
    fit_shifter

    Main routine. Fits the shifter network simultaneously with a
    convolutional divisive-normalization layer, then pickles the best
    shifter (plus the shifter of every trained version) to
    <save_dir>/<name>/best_shifter.p.

    Inputs:
        name        <str>   session id (default '20200304_kilowf')
        path        <str>   directory containing the stimulus-movie files
        tdownsample <int>   temporal downsampling factor (default 2)
        numlags     <int>   number of stimulus lags (default 12)
        lengthscale <float> lengthscale hyperparameter for the shifter
        stimlist    <list>  stimulus protocols to load for training
                            (read-only here, so the mutable default is safe)

    Output:
        None. Side effects: saves a raw-STA figure to figDir, trains a
        model per version under save_dir, and pickles the fitted shifters.
    """
    num_lags = numlags
    sessid = name  # '20200304_kilowf'

    save_dir = '../../checkpoints/v1calibration_ls{}'.format(lengthscale)
    figDir = "/home/jake/Data/Repos/V1FreeViewingCode/Figures/2021_pytorchmodeling"

    # skip entirely if this session was already fit
    outfile = Path(save_dir) / sessid / 'best_shifter.p'
    if outfile.exists():
        print("fit_shifter: shifter model already fit for this session")
        return

    #%% LOAD ALL DATASETS

    # build tent basis for saccades: n timepoints covered by num_basis
    # triangular (tent) bumps, clipped at zero
    n = 40
    num_basis = 15
    B = np.maximum(
        1 - np.abs(
            np.expand_dims(np.asarray(np.arange(0, n)), axis=1) -
            np.arange(0, n, n / num_basis)) / n * num_basis, 0)
    t_downsample = tdownsample

    gd = dd.PixelDataset(sessid,
                         stims=stimlist,
                         stimset="Train",
                         num_lags=num_lags,
                         downsample_t=t_downsample,
                         downsample_s=1,
                         valid_eye_rad=5.2,
                         dirname=path,
                         include_frametime={
                             'num_basis': 40,
                             'full_experiment': False
                         },
                         include_saccades=[{
                             'name': 'sacon',
                             'basis': B,
                             'offset': -20
                         }, {
                             'name': 'sacoff',
                             'basis': B,
                             'offset': 0
                         }],
                         include_eyepos=True,
                         preload=True)

    sample = gd[:]

    #%% compute STAS
    """
    Compute STAS using einstein summation through pytorch
    """
    # use first stim in list (automatically goes in Gabor, Dots order)
    index = np.where(gd.stim_indices == 0)[0]
    sample = gd[index]  # load sample

    # spike-triggered averages: stimulus weighted by mean-subtracted robs
    stas = torch.einsum('nlwh,nc->lwhc', sample['stim'],
                        sample['robs'] - sample['robs'].mean(dim=0))
    sta = stas.detach().cpu().numpy()

    # plot STAs / get RF centers
    """
    Plot space/time STAs 
    """
    NC = sta.shape[3]
    mu = np.zeros((NC, 2))
    # subplot grid: two panels (space + time) per cell; cast to int because
    # np.ceil/np.round return floats and plt.subplot requires integers
    sx = int(np.ceil(np.sqrt(NC * 2)))
    sy = int(np.round(np.sqrt(NC * 2)))

    # nudge the grid so sy is even (panels come in pairs)
    mod2 = sy % 2
    sy += mod2
    sx -= mod2

    tdiff = np.zeros((num_lags, NC))
    blag = np.zeros(NC)

    plt.figure(figsize=(sx * 2, sy * 2))
    for cc in range(NC):
        w = sta[:, :, :, cc]

        wt = np.std(w, axis=0)
        wt /= np.max(np.abs(wt))  # normalize for numerical stability
        # softmax-like sharpening before taking the spatial center of mass
        wt = wt**10
        wt /= np.sum(wt)
        sz = wt.shape
        xx, yy = np.meshgrid(np.linspace(-1, 1, sz[1]),
                             np.linspace(1, -1, sz[0]))

        # center of mass after softmax, clipped to the central +/- 0.5
        mu[cc, 0] = np.minimum(np.maximum(np.sum(xx * wt), -.5), .5)
        mu[cc, 1] = np.minimum(np.maximum(np.sum(yy * wt), -.5), .5)

        # z-score the STA for display
        w = (w - np.mean(w)) / np.std(w)

        # "best" lag = lag with the largest spatial variance
        bestlag = np.argmax(np.std(w.reshape((gd.num_lags, -1)), axis=1))
        blag[cc] = bestlag
        plt.subplot(sx, sy, cc * 2 + 1)
        v = np.max(np.abs(w))
        plt.imshow(w[bestlag, :, :],
                   aspect='auto',
                   interpolation=None,
                   vmin=-v,
                   vmax=v,
                   cmap="coolwarm",
                   extent=(-1, 1, -1, 1))
        plt.plot(mu[cc, 0], mu[cc, 1], '.b')
        plt.title(cc)
        plt.subplot(sx, sy, cc * 2 + 2)
        # temporal kernels at the spatial peak (blue) and trough (red) pixels
        i, j = np.where(w[bestlag, :, :] == np.max(w[bestlag, :, :]))
        t1 = w[:, i[0], j[0]]
        plt.plot(t1, '-ob')
        i, j = np.where(w[bestlag, :, :] == np.min(w[bestlag, :, :]))
        t2 = w[:, i[0], j[0]]
        plt.plot(t2, '-or')
        tdiff[:, cc] = t1 - t2

    plt.savefig(figDir + "/rawstas_" + gd.id + ".pdf", bbox_inches='tight')

    #%% Refit DivNorm Model on all datasets
    """
    Fit single layer DivNorm model with modulation
    """
    for version in range(1):  # range of version numbers (one model per version)

        #% Model: convolutional core
        input_channels = gd.num_lags
        hidden_channels = 16
        input_kern = 19
        hidden_kern = 5
        core = cores.Stacked2dDivNorm(
            input_channels,
            hidden_channels,
            input_kern,
            hidden_kern,
            layers=1,
            gamma_hidden=1e-6,  # group sparsity
            gamma_input=1,
            gamma_center=0,
            skip=0,
            final_nonlinearity=True,
            bias=False,
            pad_input=True,
            hidden_padding=hidden_kern // 2,
            group_norm=True,
            num_groups=4,
            weight_norm=True,
            hidden_dilation=1,
            input_regularizer="RegMats",
            input_reg_types=["d2x", "center", "d2t"],
            input_reg_amt=[.000005, .01, 0.00001],
            hidden_reg_types=["d2x", "center"],
            hidden_reg_amt=[.000005, .01],
            stack=None,
            use_avg_reg=True)

        # initialize input layer to be centered: taper the input-conv
        # weights with a spatial gaussian window
        regw = regularizers.gaussian2d(input_kern, sigma=input_kern // 4)
        core.features[0].conv.weight.data = torch.einsum(
            'ijkm,km->ijkm', core.features[0].conv.weight.data,
            torch.tensor(regw))

        # Readout: gaussian point readout with the shifter network attached
        in_shape = [core.outchannels, gd.NY, gd.NX]
        bias = True
        readout = readouts.Point2DGaussian(in_shape,
                                           gd.NC,
                                           bias,
                                           init_mu_range=0.1,
                                           init_sigma=1,
                                           batch_sample=True,
                                           gamma_l1=0,
                                           gamma_l2=0.00001,
                                           align_corners=True,
                                           gauss_type='uncorrelated',
                                           constrain_positive=False,
                                           shifter={
                                               'hidden_features': 20,
                                               'hidden_layers': 1,
                                               'final_tanh': False,
                                               'activation': "softplus",
                                               'lengthscale': lengthscale
                                           })

        # gain/offset modifiers driven by frametime and saccade-offset
        # covariates, applied at the readout stage
        modifiers = {
            'stimlist': ['frametime', 'sacoff'],
            'gain': [sample['frametime'].shape[1], sample['sacoff'].shape[1]],
            'offset':
            [sample['frametime'].shape[1], sample['sacoff'].shape[1]],
            'stage': "readout",
            'outdims': gd.NC
        }

        # combine core and readout into model
        model = encoders.EncoderMod(
            core,
            readout,
            modifiers=modifiers,
            gamma_mod=0,
            weight_decay=.001,
            optimizer='AdamW',
            learning_rate=
            .01,  # high initial learning rate because we decay on plateau
            betas=[.9, .999],
            amsgrad=False)

        # initialize readout based on spike rate and STA centers
        model.readout.bias.data = sample['robs'].mean(
            dim=0)  # initialize readout bias helps
        model.readout._mu.data[0, :, 0, :] = torch.tensor(
            mu.astype('float32'))  # initialize mus from STA centers of mass

        #% Train (skipped if this version was already trained)
        trainer, train_dl, valid_dl = ut.get_trainer(gd,
                                                     version=version,
                                                     save_dir=save_dir,
                                                     name=gd.id,
                                                     auto_lr=False,
                                                     batchsize=1000,
                                                     num_workers=64,
                                                     earlystopping=True)

        trainpath = Path(save_dir) / gd.id / "version_{}".format(version)
        if not trainpath.exists():
            trainer.fit(model, train_dl, valid_dl)

    # Load best version: scan metrics.csv of every trained version
    pth = Path(save_dir) / gd.id

    valloss = np.array([])
    ind = np.array([])
    for v in pth.rglob('version*'):

        try:
            df = pd.read_csv(str(v / "metrics.csv"))
            ind = np.append(ind, int(v.name.split('_')[1]))
            valloss = np.append(valloss, np.nanmin(df.val_loss.to_numpy()))
        except Exception:
            # best-effort: skip versions with missing or unreadable metrics
            continue

    if ind.size == 0:
        raise RuntimeError("fit_shifter: no trained versions found in " +
                           str(pth))

    sortind = np.argsort(ind)
    ind = ind[sortind]
    valloss = valloss[sortind]

    # version with the lowest validation loss
    vernum = ind[np.argmin(valloss)]

    # load all shifters and save them
    shifters = nn.ModuleList()
    # BUGFIX: the loop variable used to be `vernum`, which clobbered the
    # best-version index above, so the *last* version was saved as "best"
    for vn in ind:
        # load this model version's best checkpoint and keep its shifter
        ver = "version_" + str(int(vn))
        chkpath = pth / ver / 'checkpoints'
        best_epoch = ut.find_best_epoch(chkpath)
        model2 = encoders.EncoderMod.load_from_checkpoint(
            str(chkpath / best_epoch))
        shifters.append(model2.readout.shifter.cpu())

    # load best model version
    ver = "version_" + str(int(vernum))
    chkpath = pth / ver / 'checkpoints'
    best_epoch = ut.find_best_epoch(chkpath)
    model2 = encoders.EncoderMod.load_from_checkpoint(str(chkpath /
                                                          best_epoch))
    shifter = model2.readout.shifter
    shifter.cpu()

    print("saving file")
    outdict = {
        'cids': gd.cids,
        'shifter': shifter,
        'shifters': shifters,
        'vernum': ind,
        'vallos': valloss,
        'numlags': num_lags,
        'tdownsample': tdownsample,
        'lengthscale': lengthscale
    }
    pickle.dump(outdict, open(str(outfile), "wb"))
# NOTE(review): orphaned script fragment from a scraped concatenation — it
# relies on names (gd_test, model2, get_null_adjusted_ll, sessid, num_lags,
# t_downsample, plt) defined in a portion of the original script not included
# here.
sample = gd_test[:]  # full test-set sample — gd_test defined upstream
l2 = get_null_adjusted_ll(model2, sample)  # per-cell null-adjusted log-likelihood

plt.figure()
plt.plot(l2, '-o')
plt.axhline(0, color='k')  # cells above 0 beat the null model

#%% Reload dataset in restricted position
import V1FreeViewingCode.models.datasets as dd
import importlib
importlib.reload(dd)

# restricted dataset: Gabor only, eye position within 2 deg of center
gd_restricted = dd.PixelDataset(sessid, stims=["Gabor"],
    stimset="Train", num_lags=num_lags,
    downsample_t=t_downsample,
    downsample_s=1,
    valid_eye_rad=2,
    valid_eye_ctr=(0.0,0.0),
    include_eyepos=True,
    preload=True)
#%% retrain

# don't train first layer (freeze input conv + its group norm)
model2.core.features.layer0.conv.weight.requires_grad = False
model2.core.features.layer0.norm.weight.requires_grad = False
model2.core.features.layer0.norm.bias.requires_grad = False

# train pooling layer
model2.core.features.layer1.conv.weight.requires_grad = True

# do train readout
model2.readout.features.requires_grad = True
Beispiel #3
0
    valloss = tmp['vallos']
    num_lags = tmp['numlags']
    t_downsample = tmp['tdownsample']
else:
    print("foveal_stas: need to run fit_shifter for %s" %sessid)


#%% LOAD DATASETS with shifter
# NOTE(review): orphaned script fragment — depends on names (sessid, num_lags,
# t_downsample, shifter, torch, plt) defined outside this excerpt.

import V1FreeViewingCode.models.datasets as dd
import importlib
importlib.reload(dd)
# dataset with the fitted shifter applied, so 'stim' comes out shift-corrected
gd = dd.PixelDataset(sessid, stims=["Dots", "Gabor"], #, "BackImage", "Grating", "FixRsvpStim"],
    stimset="Train", num_lags=num_lags,
    downsample_t=t_downsample,
    downsample_s=1,
    valid_eye_rad=5.2,
    shifter=shifter,
    include_eyepos=True,
    preload=True)


#%% get sample
sample = gd[:]
im = sample['stim'].detach().clone()

#%% new STAs on shifted stimulus
# crop window in pixels — presumably [x0, x1, y0, y1]; TODO confirm ordering
crop = [20, 50, 0, 30]
stas = torch.einsum('nlwh,nc->lwhc', im[:,:,crop[2]:crop[3], crop[0]:crop[1]], sample['robs']-sample['robs'].mean(dim=0))
sta = stas.detach().cpu().numpy()

#%%
def make_figures(
        name='20200304_kilowf',
        path='/home/jake/Data/Datasets/MitchellV1FreeViewing/stim_movies/',
        tdownsample=2,
        numlags=12,
        lengthscale=1,
        nspace=100):
    """
    make_figures
    Make STA figures and dump analyses for Figure 5 of Yates et al., 2021

    Inputs:
        name        <str>   session id (fit_shifter must have run for it)
        path        <str>   stimulus directory (currently unused; kept for
                            signature parity with fit_shifter)
        tdownsample <int>   temporal downsample; NOTE: overridden below by
                            the value saved in the fit_shifter pickle
        numlags     <int>   number of lags; also overridden by the pickle
        lengthscale <float> lengthscale used to locate the checkpoint dir
        nspace      <int>   number of spatial positions in shifter plots
                            (default = 100)

    Output:
        None
    Saves a mat file with all analyses and dumps a pdf for each cell
    """
    print("Running Make Figures")

    import scipy.io as sio  # for saving matlab files

    figDir = "/home/jake/Data/Repos/V1FreeViewingCode/Figures/2021_pytorchmodeling"

    save_dir = '../../checkpoints/v1calibration_ls{}'.format(lengthscale)

    # load the saved shifter fit; bail if fit_shifter has not been run
    outfile = Path(save_dir) / name / 'best_shifter.p'
    if outfile.exists():
        print("fit_shifter was already run. Loading [%s]" % name)
        tmp = pickle.load(open(str(outfile), "rb"))
        cids = tmp['cids']
        shifter = tmp['shifter']      # best shifter module
        shifters = tmp['shifters']    # one shifter per trained version
        vernum = tmp['vernum']        # version numbers (array)
        valloss = tmp['vallos']       # validation loss per version
        num_lags = tmp['numlags']
        t_downsample = tmp['tdownsample']
    else:
        print("v1_tracker_calibration: need to run fit_shifter for %s" % name)
        return

    # tent basis for saccade covariates (same construction as fit_shifter)
    n = 40
    num_basis = 15
    B = np.maximum(
        1 - np.abs(
            np.expand_dims(np.asarray(np.arange(0, n)), axis=1) -
            np.arange(0, n, n / num_basis)) / n * num_basis, 0)
    gd = dd.PixelDataset(
        name,
        stims=["Dots", "Gabor"],  #, "BackImage", "Grating", "FixRsvpStim"],
        stimset="Train",
        num_lags=num_lags,
        downsample_t=t_downsample,
        downsample_s=1,
        valid_eye_rad=5.2,
        include_frametime={
            'num_basis': 40,
            'full_experiment': False
        },
        include_saccades=[{
            'name': 'sacon',
            'basis': B,
            'offset': -20
        }, {
            'name': 'sacoff',
            'basis': B,
            'offset': 0
        }],
        include_eyepos=True,
        preload=True)

    sample = gd[:]

    # --- Plot all shifters

    # build inputs for shifter plotting: an nspace x nspace grid of eye
    # positions spanning the valid eye radius
    xax = np.linspace(-gd.valid_eye_rad, gd.valid_eye_rad, nspace)
    xx, yy = np.meshgrid(xax, xax)
    xgrid = torch.tensor(xx.astype('float32').reshape((-1, 1)))
    ygrid = torch.tensor(yy.astype('float32').reshape((-1, 1)))

    inputs = torch.cat((xgrid, ygrid), dim=1)

    nshifters = len(shifters)

    print("Start Plot")
    fig = plt.figure(figsize=(7, nshifters * 2))
    shiftX = []
    shiftY = []

    for i in range(nshifters):
        # evaluate shifter i over the whole eye-position grid
        y = shifters[i](inputs)

        y2 = y.detach().cpu().numpy()
        y2 /= gd.valid_eye_rad / 60  # convert to arcmin
        vmin = np.min(y2)
        vmax = np.max(y2)

        plt.subplot(nshifters, 2, i * 2 + 1)
        im = plt.contourf(xax,
                          xax,
                          y2[:, 0].reshape((nspace, nspace)),
                          vmin=vmin,
                          vmax=vmax)

        if i == 0:
            plt.title("Horizontal")

        plt.colorbar(im)

        plt.subplot(nshifters, 2, i * 2 + 2)
        im = plt.contourf(
            xax, xax, y2[:, 1].reshape((nspace, nspace)), vmin=vmin, vmax=vmax
        )  #, extent=(-gd.valid_eye_rad,gd.valid_eye_rad,-gd.valid_eye_rad,gd.valid_eye_rad), interpolation=None, vmin=vmin, vmax=vmax)
        if i == 0:
            plt.title("Vertical")

        plt.colorbar(im)

        shiftX.append(y2[:, 0].reshape((nspace, nspace)))
        shiftY.append(y2[:, 1].reshape((nspace, nspace)))

    print("Save fig")
    plt.savefig(figDir + "/shifters_" + gd.id + ".pdf", bbox_inches='tight')
    print("Success")
    # calculate mean and standard deviation across shifters
    Xarray = np.asarray(shiftX)
    Yarray = np.asarray(shiftY)
    mux = Xarray.mean(axis=0)
    sdx = Xarray.std(axis=0)

    muy = Yarray.mean(axis=0)
    sdy = Yarray.std(axis=0)

    print("Select Stim")
    # select stimulus: prefer Dots, fall back to Gabor
    if 'Dots' in gd.stims:
        stimuse = gd.stims.index('Dots')
    elif 'Gabor' in gd.stims:
        stimuse = gd.stims.index('Gabor')
    else:
        # previously stimuse was silently left undefined in this case
        raise ValueError("make_figures: expected 'Dots' or 'Gabor' in stims")

    index = np.where(gd.stim_indices == stimuse)[0]
    sample = gd[index]  # load sample

    # use best model (lowest validation loss)
    bestver = np.argmin(valloss)
    shift = shifters[bestver](sample['eyepos']).detach()
    y = shifters[bestver](inputs)
    y2 = y.detach().cpu().numpy()

    print("Shift Stim")
    # shift stimulus
    im = sample['stim']  # original
    im2 = shift_stim(im, shift, gd)  # shifted

    print("Get STAs")
    # compute STAs on original and shifted stimulus
    stas = torch.einsum('nlwh,nc->lwhc', im,
                        sample['robs'] - sample['robs'].mean(dim=0))
    sta = stas.detach().cpu().numpy()
    stas = torch.einsum('nlwh,nc->lwhc', im2,
                        sample['robs'] - sample['robs'].mean(dim=0))
    sta2 = stas.detach().cpu().numpy()

    print("Save Mat File")
    # Save output for matlab
    fname = figDir + "/rfs_" + gd.id + ".mat"
    mdict = {
        'cids': cids,
        'xspace': xx,
        'yspace': yy,
        # BUGFIX: was hard-coded reshape((100, 100)), which broke for any
        # nspace != 100
        'shiftx': y2[:, 0].reshape((nspace, nspace)),
        'shifty': y2[:, 1].reshape((nspace, nspace)),
        'stas_pre': sta,
        'stas_post': sta2,
        'valloss': valloss,
        'mushiftx': mux,
        'mushifty': muy,
        'sdshiftx': sdx,
        'sdshifty': sdy
    }

    sio.savemat(fname, mdict)

    # compute plot dimensions in arcmin from the stimulus rect and
    # pixels-per-degree; reorder to matplotlib's (left, right, bottom, top)
    extent = np.round(gd.rect / gd.ppd * 60)
    extent = np.asarray([extent[i] for i in [0, 2, 3, 1]])

    print("Plot STAs")
    # Plot STAs Before and After shift correction
    NC = sta.shape[3]

    # Loop over cells: one 4-panel pdf per cell
    for cc in range(NC):
        plt.figure(figsize=(8, 2))

        w = sta[:, :, :, cc]    # before
        w2 = sta2[:, :, :, cc]  # after

        # best lag determined from the shifted (after) STA
        bestlag = np.argmax(np.std(w2.reshape((gd.num_lags, -1)), axis=1))

        w = (w - np.mean(w[bestlag, :, :])) / np.std(w)  # before
        w2 = (w2 - np.mean(w2[bestlag, :, :])) / np.std(w2)  # after

        plt.subplot(1, 4, 3)  # After Space
        v = np.max(np.abs(w2))
        plt.imshow(w2[bestlag, :, :],
                   aspect='auto',
                   interpolation=None,
                   vmin=-v,
                   vmax=v,
                   cmap="coolwarm",
                   extent=extent)
        plt.title("After")
        plt.xlabel("arcmin")
        plt.ylabel("arcmin")

        plt.subplot(1, 4, 4)  # After Time
        i, j = np.where(w2[bestlag, :, :] == np.max(w2[bestlag, :, :]))
        plt.plot(w2[:, i[0], j[0]], '-ob')
        i, j = np.where(w2[bestlag, :, :] == np.min(w2[bestlag, :, :]))
        plt.plot(w2[:, i[0], j[0]], '-or')
        yd = plt.ylim()  # reuse these limits for the "before" time panel
        plt.xlabel("Lag (frame=8ms)")

        plt.subplot(1, 4, 1)  # Before Space
        plt.imshow(w[bestlag, :, :],
                   aspect='auto',
                   interpolation=None,
                   vmin=-v,
                   vmax=v,
                   cmap="coolwarm",
                   extent=extent)
        plt.title("Before")
        plt.xlabel("arcmin")
        plt.ylabel("arcmin")

        plt.subplot(1, 4, 2)  # Before Time
        i, j = np.where(w[bestlag, :, :] == np.max(w[bestlag, :, :]))
        plt.plot(w[:, i[0], j[0]], '-ob')
        i, j = np.where(w[bestlag, :, :] == np.min(w[bestlag, :, :]))
        plt.plot(w[:, i[0], j[0]], '-or')
        plt.axvline(bestlag, color='k', linestyle='--')
        plt.ylim(yd)

        plt.xlabel("Lag (frame=8ms)")

        plt.savefig(figDir + "/sta_shift" + gd.id + "_" + str(cc) + ".pdf",
                    bbox_inches='tight')
        plt.close('all')

    print("Save summary figures")
    # save all STAs as one figure
    plot_sta_fig(sta, gd)  # Before
    plt.savefig(figDir + "/rawstas_" + gd.id + ".pdf", bbox_inches='tight')

    plot_sta_fig(sta2, gd)  # After
    plt.savefig(figDir + "/shiftstas_" + gd.id + ".pdf", bbox_inches='tight')
    plt.close('all')
Beispiel #5
0
        np.expand_dims(np.asarray(np.arange(0, n)), axis=1) -
        np.arange(0, n, n / num_basis)) / n * num_basis, 0)
# NOTE(review): orphaned fragment duplicating the dataset-loading section of
# fit_shifter — depends on tdownsample, dd, sessid, stimlist, num_lags, path,
# and B defined outside this excerpt.
t_downsample = tdownsample

gd = dd.PixelDataset(sessid,
                     stims=stimlist,
                     stimset="Train",
                     num_lags=num_lags,
                     downsample_t=t_downsample,
                     downsample_s=1,
                     valid_eye_rad=5.2,
                     dirname=path,
                     include_frametime={
                         'num_basis': 40,
                         'full_experiment': False
                     },
                     include_saccades=[{
                         'name': 'sacon',
                         'basis': B,
                         'offset': -20
                     }, {
                         'name': 'sacoff',
                         'basis': B,
                         'offset': 0
                     }],
                     include_eyepos=True,
                     preload=True)

sample = gd[:]  # load every sample into one dict of tensors

#%% compute STAS
Beispiel #6
0
# NOTE(review): orphaned fragment — depends on dd, importlib, cropidxs, sessid,
# num_lags, t_downsample, shifters, valloss, np defined outside this excerpt.
importlib.reload(dd)

# per-session crop window — presumably precomputed elsewhere; TODO confirm
cropidx = cropidxs[sessid]

# tent basis for saccade covariates (same construction as fit_shifter)
n = 40
num_basis = 15
B = np.maximum(1 - np.abs(np.expand_dims(np.asarray(np.arange(0,n)), axis=1) - np.arange(0,n,n/num_basis))/n*num_basis, 0)


# Load Training / Evaluation (shift-corrected with the best shifter,
# cropped, and blurred with a gaussian point-spread function)
gd = dd.PixelDataset(sessid, stims=["Dots", "Gabor", "BackImage", "Grating", "FixRsvpStim"],
    stimset="Train", num_lags=num_lags,
    downsample_t=t_downsample,
    downsample_s=1,
    valid_eye_rad=5.2,
    shifter=shifters[np.argmin(valloss)],
    cropidx=cropidx,
    include_frametime={'num_basis': 40, 'full_experiment': False},
    include_saccades=[{'name':'sacon', 'basis':B, 'offset':-20}, {'name':'sacoff', 'basis':B, 'offset':0}],
    include_eyepos=True,
    optics={'type': 'gausspsf', 'sigma': (0.7, 0.7, 0.0)},
    preload=True)

#%% Load test set
# print("Loading Test set")
# gd_test = dd.PixelDataset(sessid, stims=["Dots", "Gabor", "BackImage"], #, "Grating", "FixRsvpStim"],
#     stimset="Test", num_lags=num_lags,
#     downsample_t=t_downsample,
#     downsample_s=1,
#     valid_eye_rad=5.2,
#     cropidx=cropidx,
#     shifter=shifter,
Beispiel #7
0
#%% LOAD ALL DATASETS
# NOTE(review): orphaned fragment — depends on np, sessid, num_lags, and
# t_downsample defined outside this excerpt.

# build tent basis for saccades
n = 40
num_basis = 15
B = np.maximum(1 - np.abs(np.expand_dims(np.asarray(np.arange(0,n)), axis=1) - np.arange(0,n,n/num_basis))/n*num_basis, 0)

import V1FreeViewingCode.models.datasets as dd
import importlib
importlib.reload(dd)
gd = dd.PixelDataset(sessid, stims=["Dots", "Gabor"], #, "BackImage", "Grating", "FixRsvpStim"],
    stimset="Train", num_lags=num_lags,
    downsample_t=t_downsample,
    downsample_s=1,
    valid_eye_rad=5.2,
    include_frametime={'num_basis': 40, 'full_experiment': False},
    include_saccades=[{'name':'sacon', 'basis':B, 'offset':-20}, {'name':'sacoff', 'basis':B, 'offset':0}],
    include_eyepos=True,
    preload=True)

sample = gd[:]  # load every sample into one dict of tensors

# from V1FreeViewingCode.Analysis.manuscript_freeviewingmethods.v1_tracker_calibration import make_figures



# %% Plot all shifters

# build inputs for shifter plotting
nspace = 100  # grid resolution for shifter evaluation
# NOTE(review): orphaned fragment — depends on get_null_adjusted_ll, model2,
# sample, plt, sessid, num_lags, t_downsample, and torch from outside this
# excerpt.
l2 = get_null_adjusted_ll(model2, sample)  # per-cell null-adjusted log-likelihood

plt.figure()
plt.plot(l2, '-o')
plt.axhline(0, color='k')  # cells above 0 beat the null model

#%% Reload dataset in restricted position
import V1FreeViewingCode.models.datasets as dd
import importlib
importlib.reload(dd)

# shift-corrected, cropped Gabor-only dataset
gd_shift = dd.PixelDataset(sessid, stims=["Gabor"],
    stimset="Train", num_lags=num_lags,
    downsample_t=t_downsample,
    downsample_s=1,
    valid_eye_rad=5.2,
    valid_eye_ctr=(0.0,0.0),
    include_eyepos=True,
    cropidx=[(15,50),(20,50)],
    shifter=model2.readout.shifter,
    preload=True)
# %% reload sample and compute STAs
sample = gd_shift[:] # load sample

# STA with mean-subtracted robs, additionally normalized by each cell's
# total spike count
stas = torch.einsum('nlwh,nc->lwhc', sample['stim'], (sample['robs']-sample['robs'].mean(dim=0))/sample['robs'].sum(dim=0))
sta = stas.detach().cpu().numpy()

#%% plot STAs / get RF centers
"""
Plot space/time STAs 
"""
NC = sta.shape[3]  # number of cells (last axis of the STA array)
Beispiel #9
0
# NOTE(review): orphaned fragment — depends on sessid, win, cids, and shifter
# defined outside this excerpt.
import V1FreeViewingCode.models.datasets as dd
import importlib

importlib.reload(dd)
# higher temporal resolution for this analysis: no temporal downsampling,
# more lags
t_downsample = 1
num_lags = 20

# shift-corrected dataset restricted to the fitted window and cell ids
gd_shift = dd.PixelDataset(
    sessid,
    stims=["Gabor", "Grating", "BackImage"],
    stimset="Train",
    num_lags=num_lags,
    downsample_t=t_downsample,
    downsample_s=1,
    valid_eye_rad=5,
    valid_eye_ctr=(0.0, 0.0),
    include_eyepos=True,
    cropidx=win,
    cids=cids,
    shifter=shifter,  #model2.readout.shifter,
    preload=True,
    temporal=True)

#%%  test set
gab_shift_test = dd.PixelDataset(sessid,
                                 stims=["Gabor"],
                                 stimset="Test",
                                 num_lags=num_lags,
                                 downsample_t=t_downsample,
                                 downsample_s=1,