Exemplo n.º 1
0
def initialize_stim_with_sta(population, data, x0, Ns=None):
    """ Initialize the stimulus response parameters with the STA

    Computes the spike-triggered average with ``sta`` for the neurons in
    ``Ns`` and projects it onto the basis of ``population.glm.bkgd_model``,
    writing the resulting weights into ``x0`` in place:

    * ``SpatiotemporalStimulus``: the STA is factorized into a rank-1
      (temporal x spatial) filter via SVD; the two factors are projected
      onto ``ibasis_t`` / ``ibasis_x`` and stored in
      ``x0['glms'][n]['bkgd']['w_t']`` and ``['w_x']``.
    * ``BasisStimulus``: each stimulus dimension (column of the STA) is
      projected onto ``ibasis`` and the weights are stacked into
      ``x0['glms'][n]['bkgd']['w_stim']``.

    For any other background model the function returns without changes.

    Arguments:
        population -- object exposing ``N`` and ``glm.bkgd_model``
        data       -- dict with at least a ``'stim'`` entry (forwarded to ``sta``)
        x0         -- nested parameter dict, modified in place
        Ns         -- a single neuron index or an iterable of indices;
                      defaults to all ``population.N`` neurons

    TODO: Move this to the bkgd model once we have decided upon the
    correct function signature
    """
    if Ns is None:
        Ns = np.arange(population.N)

    # Accept a lone index, including NumPy integer scalars (e.g. np.int64)
    if isinstance(Ns, (int, np.integer)):
        Ns = [Ns]

    bkgd_model = population.glm.bkgd_model
    temporal = isinstance(bkgd_model, BasisStimulus)
    spatiotemporal = isinstance(bkgd_model, SpatiotemporalStimulus)

    # We're only initializing the basis function stim models now
    if not (temporal or spatiotemporal):
        return

    # Compute the STA
    print("Initializing with the STA")
    # TODO Fix these super hacky calls
    # NOTE(review): the STA length prefers ``ibasis`` when the model is a
    # BasisStimulus, while the loop below prefers the spatiotemporal branch.
    # This only matters if a model can be both; confirm the class hierarchy.
    if temporal:
        sta_len = bkgd_model.ibasis.get_value().shape[0]
    else:
        sta_len = bkgd_model.ibasis_t.get_value().shape[0]
    s = sta(data['stim'], data, sta_len, Ns=Ns)

    # Fetch the bases once instead of calling get_value() per neuron
    # (and, for the temporal model, per stimulus dimension).
    if spatiotemporal:
        ibasis_t = bkgd_model.ibasis_t.get_value()
        ibasis_x = bkgd_model.ibasis_x.get_value()
    else:
        ibasis = bkgd_model.ibasis.get_value()
        B = ibasis.shape[1]

    # Compute the initial weights for each neuron
    for i, n in enumerate(Ns):
        sn = np.squeeze(s[i, :, :])
        if sn.ndim == 1:
            # Keep a 2-D (time, dimension) layout for a single-column STA
            sn = np.reshape(sn, [sn.size, 1])

        bkgd_x0 = x0['glms'][n]['bkgd']
        if spatiotemporal:
            # Factorize the STA into a spatiotemporal filter using SVD.
            # CAUTION! Numpy svd returns V transpose whereas Matlab svd
            # returns V, so row 0 of Vt is the leading right singular vector.
            U, Sig, Vt = np.linalg.svd(sn)
            scale = np.sqrt(Sig[0])
            f_t = U[:, 0] * scale
            f_x = Vt[0, :] * scale

            # Project onto the temporal and spatial bases, flattened to 1-D
            bkgd_x0['w_x'] = np.ravel(project_onto_basis(f_x, ibasis_x))
            bkgd_x0['w_t'] = np.ravel(project_onto_basis(f_t, ibasis_t))
        else:
            # Only using a temporal filter: project each stimulus dimension
            # onto the temporal basis and stack the B weights per dimension.
            D_stim = sn.shape[1]
            w_t = np.zeros((B * D_stim, 1))
            for d in range(D_stim):
                w_t[d * B:(d + 1) * B] = project_onto_basis(sn[:, d], ibasis)
            # Flatten into a 1D vector
            bkgd_x0['w_stim'] = np.ravel(w_t)
Exemplo n.º 2
0
def _local_spatial_window(stim_x, center, gsz, max_rb, max_ub):
    """Return the ``gsz``-shaped window of ``stim_x`` around ``center``.

    The window starts ``(gsz - 1) // 2`` cells before ``center`` along each
    axis. Only indices in ``arange(max(0, start), min(start + gsz, max_*))``
    are copied from ``stim_x``. The remaining cells stay zero, so every
    window has shape ``gsz`` even when ``center`` lies near the border.
    Truncating the window instead would yield filters of different sizes
    that cannot be stacked for clustering.
    """
    gsz = [int(g) for g in gsz]
    row0 = int(center[0]) - (gsz[0] - 1) // 2
    col0 = int(center[1]) - (gsz[1] - 1) // 2

    # In-bounds row/column indices of stim_x covered by the window
    rows = np.arange(max(0, row0), min(row0 + gsz[0], max_rb)).astype(int)
    cols = np.arange(max(0, col0), min(col0 + gsz[1], max_ub)).astype(int)

    window = np.zeros(gsz, dtype=stim_x.dtype)
    # Place the in-bounds part at its offset relative to the window origin
    window[np.ix_(rows - row0, cols - col0)] = stim_x[np.ix_(rows, cols)]
    return window


def convert_stimulus_filters_to_sharedtc(from_popn, from_model, from_vars, to_popn, to_model, to_vars):
    """
    Convert a set of stimulus filters to a shared set of tuning curves.

    For every neuron of ``from_popn``:

    * the peak of ``|stim_response_x|`` is located;
    * a ``spatial_shape`` window around the peak is extracted, zero-filled
      where it leaves the range bounded by the location prior's
      ``max0``/``max1``.

    The windows are then clustered with k-means into ``R`` groups, and
    ``to_vars`` is updated in place with:

    * ``Y``: the cluster labels;
    * ``w_x`` / ``w_t``: the per-cluster mean spatial and temporal filters,
      projected onto ``to_popn``'s spatial and temporal bases;
    * ``L``: the per-neuron peak locations, flattened.

    ``from_model`` is accepted for interface symmetry but is not used.
    """
    from sklearn.cluster import KMeans
    from pyglm.utils.basis import project_onto_basis

    # Get the spatial component of the stimulus filter for each neuron
    N = from_popn.N
    latent = to_model['latent']
    R = latent['sharedtuningcurves']['R']
    gsz = latent['sharedtuningcurves']['spatial_shape']
    max_rb = latent['latent_location']['location_prior']['max0']
    max_ub = latent['latent_location']['location_prior']['max1']
    from_state = from_popn.eval_state(from_vars)

    locs = np.zeros((N, 2))
    local_stim_xs = []
    local_stim_ts = []
    for n in range(N):
        s_bkgd = from_state['glms'][n]['bkgd']
        assert 'stim_response_x' in s_bkgd

        # Get the stimulus responses
        stim_x = s_bkgd['stim_response_x']
        stim_t = s_bkgd['stim_response_t']

        if stim_x.ndim == 2:
            # (row, col) of the peak absolute spatial response
            locs[n] = np.unravel_index(np.argmax(np.abs(stim_x)), stim_x.shape)
            # TODO: Test whether we have an issue with unraveling the data,
            # i.e. whether the location should be stored as (col, row).

        # Local filter in the vicinity of the mode, centered around the max
        local_stim_xs.append(
            _local_spatial_window(stim_x, locs[n], gsz, max_rb, max_ub))
        local_stim_ts.append(stim_t)

    # Cluster the local stimulus filters (list comprehensions rather than
    # map() so np.array gets a real sequence on Python 3 as well)
    flattened_filters_x = np.array([f.ravel() for f in local_stim_xs])
    flattened_filters_t = np.array([f.ravel() for f in local_stim_ts])
    km = KMeans(n_clusters=R)
    km.fit(flattened_filters_x)
    Y = km.labels_
    print('Filter cluster labels from kmeans:  {}'.format(Y))

    # Initialize type based on stimulus filter
    tc_vars = to_vars['latent']['sharedtuningcurve_provider']
    tc_vars['Y'] = Y

    # Initialize shared tuning curves (project onto the bases)
    spatial_basis = to_popn.glm.bkgd_model.spatial_basis
    temporal_basis = to_popn.glm.bkgd_model.temporal_basis
    n_t = temporal_basis.shape[0]
    t_temporal_basis = np.arange(n_t)
    for r in range(R):
        in_r = (Y == r)
        mean_filter_xr = flattened_filters_x[in_r].mean(axis=0)
        mean_filter_tr = flattened_filters_t[in_r].mean(axis=0)

        # TODO: Make sure the filters are being normalized properly!

        # Project the mean spatial filter onto the basis
        tc_vars['w_x'][:, r] = \
            project_onto_basis(mean_filter_xr, spatial_basis).ravel()

        # Temporal part: resample the mean filter onto the basis' time grid
        # before projecting, since the two may have different lengths.
        t_mean_filter_tr = np.linspace(0, n_t - 1, mean_filter_tr.shape[0])
        interp_mean_filter_tr = np.interp(t_temporal_basis, t_mean_filter_tr, mean_filter_tr)
        tc_vars['w_t'][:, r] = \
            project_onto_basis(interp_mean_filter_tr, temporal_basis).ravel()

    # Initialize locations based on stimulus filters
    # (builtin int: the np.int alias was removed in NumPy 1.24)
    to_vars['latent']['location_provider']['L'] = locs.ravel().astype(int)