コード例 #1
0
    def train_arnet(self, start_params):
        """Fit an ensemble of AR-net models and store ensemble-averaged results.

        For each of ``self.num_ensemble`` members, ``arnet_cost_function`` is
        minimised with L-BFGS-B (gradient from ``grad``) with every parameter
        bounded to [0, 1]. The starting vector is ``start_params`` followed by
        ``self.num_src`` copies of one random draw.

        Args:
            start_params (list): initial values for the leading parameters. It
                must be a Python list because it is concatenated with a list.

        Side effects:
            Sets ``self.pred_train_output``, ``self.pred_test_output`` and
            ``self.link_weights`` (NaN-ignoring means over the ensemble) and
            ``self.network_ratio`` (plain mean over the ensemble).
        """
        arnet_autograd_func = grad(arnet_cost_function)
        arnet_bounds = [(0, 1)] * len(start_params) + [(0, 1)] * self.num_src

        iter_cnt = 0
        # Builtin ``float`` instead of ``np.float``: that alias was removed in
        # NumPy 1.24 and raises AttributeError there (same float64 dtype).
        pred_train_output_mat = np.empty((0, len(self.true_train_output)), float)
        pred_test_output_mat = np.empty((0, self.num_output), float)
        link_weights_list = np.empty((0, self.num_src), float)
        network_ratio_list = []
        while iter_cnt < self.num_ensemble:
            # NOTE(review): a single random draw is repeated num_src times, so
            # all link weights start equal -- confirm whether independent draws
            # (e.g. np.random.random(self.num_src)) were intended.
            arnet_init_values = np.array(start_params + [np.random.random()] * self.num_src)
            arnet_optimizer = optimize.minimize(arnet_cost_function, arnet_init_values, jac=arnet_autograd_func,
                                                method='L-BFGS-B',
                                                args=(self.train_input, self.true_train_output),
                                                bounds=arnet_bounds,
                                                options={'maxiter': 100, 'disp': False})
            arnet_fitted_params = arnet_optimizer.x
            # everything after the first num_input parameters is a link weight
            arnet_link_weights = arnet_fitted_params[self.num_input:]

            arnet_pred_train, arnet_latent_train = arnet_predict(arnet_fitted_params, self.train_input, mode='train')
            arnet_pred_test, _ = arnet_predict(arnet_fitted_params, self.test_input, mode='test')

            # one row per ensemble member
            pred_train_output_mat = np.vstack((pred_train_output_mat, arnet_pred_train))
            pred_test_output_mat = np.vstack((pred_test_output_mat, arnet_pred_test))
            link_weights_list = np.vstack((link_weights_list, arnet_link_weights))
            network_ratio_list.append(1 - sum(arnet_latent_train) / sum(arnet_pred_train))
            iter_cnt += 1

        self.pred_train_output = np.nanmean(pred_train_output_mat, axis=0)
        self.pred_test_output = np.nanmean(pred_test_output_mat, axis=0)
        self.link_weights = np.nanmean(link_weights_list, axis=0)
        self.network_ratio = np.mean(network_ratio_list)
コード例 #2
0
    def standard_normalizer(self, x):
        """Build a row-wise standard normalizer and its inverse for ``x``.

        Means and standard deviations are computed per row (``axis=1``) while
        ignoring NaNs. As a side effect, NaN entries of ``x`` are overwritten
        in place with their row mean.

        Args:
            x: 2-D float array; each row is normalized independently.

        Returns:
            (normalizer, inverse_normalizer): callables computing
            ``(data - mean) / std`` and ``data * std + mean`` with the
            statistics captured here.
        """
        # per-row statistics as (n_rows, 1) columns so they broadcast over rows
        x_means = np.nanmean(x, axis=1)[:, np.newaxis]
        x_stds = np.nanstd(x, axis=1)[:, np.newaxis]

        # Rows whose std is below 1e-2 would blow up the division; adding 1.0
        # to their std makes the divisor approximately 1 for those rows.
        x_stds[x_stds < 10**(-2)] += 1.0

        # fill in any NaN values with their row mean (modifies x in place)
        nan_rows, nan_cols = np.nonzero(np.isnan(x))
        x[nan_rows, nan_cols] = x_means[nan_rows, 0]

        def normalizer(data):
            return (data - x_means) / x_stds

        def inverse_normalizer(data):
            return data * x_stds + x_means

        return normalizer, inverse_normalizer
コード例 #3
0
 def ori_avg(Rs,these_ori_dirs):
     """Average ``Rs`` over the directions ``these_ori_dirs`` and flatten.

     The size/contrast part (first ``nsc`` entries, present when ``fit_sc``)
     and the foreground part (remaining entries, present when ``fit_fg``) are
     each nan-averaged over the selected directions, flattened, and
     concatenated. A disabled part contributes an empty array.

     Relies on names from the enclosing scope: fit_sc, fit_fg, nsc, nrun,
     nsize, ncontrast, ndir, nstim_fg, kernel and ssi.
     """
     # Defaults for disabled parts; also avoids an UnboundLocalError when
     # neither fit_sc nor fit_fg is set.
     rs_sc = np.zeros((0,))
     rs_fg = np.zeros((0,))
     if fit_sc:
         rs_sc = np.nanmean(Rs[:nsc].reshape((nrun,nsize,ncontrast,ndir))[:,:,:,these_ori_dirs],-1)
         # 'valid' convolution with `kernel` overwrites the interior entries
         # (presumably a small normalized smoothing kernel -- confirm)
         rs_sc[:,1:,1:] = ssi.convolve(rs_sc,kernel,'valid')
         rs_sc = rs_sc.reshape((nrun*nsize*ncontrast))
     if fit_fg:
         # foreground entries follow the size/contrast entries when both are fit
         fg_start = nsc if fit_sc else 0
         rs_fg = np.nanmean(Rs[fg_start:].reshape((nrun,nstim_fg,ndir))[:,:,these_ori_dirs],-1)
         rs_fg = rs_fg.reshape((nrun*nstim_fg))
     Rso = np.concatenate((rs_sc,rs_fg))
     return Rso
コード例 #4
0
ファイル: functions.py プロジェクト: kunalghosh/viabel
def compute_R_hat(chains, warmup=500):
    """Compute the split-R-hat of several MCMC chains.

    Parameters
    ----------
    chains : ndarray, shape (n_chains, n_iters, n_dims)
        First axis indexes chains (realisations), second axis iterations.
    warmup : int
        Number of leading iterations dropped from every chain.

    Returns
    -------
    var_hat : ndarray, shape (n_dims,)
        ``(n - 1) / n + B / (n * W)`` with ``n`` the half-chain length
        (i.e. ``R_hat ** 2``).
    R_hat : ndarray, shape (n_dims,)
        Square root of ``var_hat``.
    """
    jitter = 1e-8  # keeps W strictly positive so B / W is defined
    chains = chains[:, warmup:, :]
    n_chains, n_iters, K = chains.shape
    if n_iters % 2 == 1:
        # Drop exactly one iteration so each chain splits into equal halves.
        # (Slicing to n_iters - 1 after decrementing dropped two and made the
        # reshape below fail.)
        n_iters -= 1
        chains = chains[:, :n_iters, :]

    n_iters = n_iters // 2
    n_chains2 = n_chains * 2
    # C-order reshape: each chain becomes its first half followed by its second half
    psi = np.reshape(chains, (n_chains2, n_iters, K))
    psi_dot_j = np.mean(psi, axis=1)
    psi_dot_dot = np.mean(psi_dot_j, axis=0)
    s_j_2 = np.sum(
        (psi - np.expand_dims(psi_dot_j, axis=1))**2, axis=1) / (n_iters - 1)
    B = n_iters * np.sum(
        (psi_dot_j - psi_dot_dot)**2, axis=0) / (n_chains2 - 1)
    W = np.nanmean(s_j_2, axis=0)
    W = W + jitter
    var_hat = (n_iters - 1) / n_iters + (B / (n_iters * W))
    R_hat = np.sqrt(var_hat)
    return var_hat, R_hat
コード例 #5
0
ファイル: utilities.py プロジェクト: abhiagwl/vistan
def advi_callback(params, t, g, results, delta_results, model, eval_function,
                  hparams):
    """Record the current objective value and test for early convergence.

    ``eval_function(params)`` is appended to ``results`` on every call. Every
    ``hparams['advi_callback_iteration']`` steps, the relative difference
    between the latest value and the one recorded that many calls earlier (or
    0.0 if there is none yet) is appended to ``delta_results``. If either the
    nan-median or the nan-mean of ``delta_results`` is at or below
    ``hparams['advi_convergence_threshold']``, a summary is written via
    ``tqdm.write`` and ``"exit"`` is returned; otherwise ``None`` is returned.
    ``g`` and ``model`` are accepted but unused.
    """
    results.append(eval_function(params))

    window = hparams['advi_callback_iteration']
    if (t + 1) % window != 0:
        return None

    previous_elbo = results[-(window + 1)] if len(results) > window else 0.0
    current_elbo = results[-1]
    delta_results.append(relative_difference(previous_elbo, current_elbo))
    delta_elbo_mean = np.nanmean(delta_results)
    delta_elbo_median = np.nanmedian(delta_results)

    threshold = hparams['advi_convergence_threshold']
    converged = (delta_elbo_median <= threshold) | (delta_elbo_mean <= threshold)
    if not converged:
        return None

    tqdm.write(f"Converged early according to ADVI "
               f"metrics for Median/Mean")
    tqdm.write(f"Iteration {t+1}")
    tqdm.write(f"Rel. tolerance Δ threshold: "
               f"{hparams['advi_convergence_threshold']}")
    tqdm.write(f"Rel. tolerance Δ mean: {delta_elbo_mean:.5f}")
    tqdm.write(f"Rel. tolerance Δ median: {delta_elbo_median:.5f}")
    return "exit"
コード例 #6
0
ファイル: preprocessing.py プロジェクト: zhoupc/ssm
def standardize(data, mask):
    """Z-score ``data`` along axis 0 using only the entries where ``mask`` is True.

    Parameters
    ----------
    data : array_like
        Values to standardize; statistics are taken along axis 0.
    mask : array_like of bool, same shape as ``data``
        True marks observed entries.

    Returns
    -------
    y : ndarray of float
        Observed entries standardized with the mean/std of the observed
        entries in their column. Masked-out entries are set to 0 (the
        standardized mean), so columns with no observed entries are all 0.
        ``data`` itself is not modified.
    """
    mask = np.asarray(mask, dtype=bool)
    # Float copy: leaves the caller's array untouched (the old version wrote
    # NaNs into it) and makes NaN representable for integer input.
    values = np.array(data, dtype=float)
    observed = values.copy()
    observed[~mask] = np.nan
    m = np.nanmean(observed, axis=0)
    s = np.nanstd(observed, axis=0)
    # columns with no observed entries: avoid dividing by a NaN std
    s[~np.any(mask, axis=0)] = 1
    y = (values - m) / s
    # Masked entries would otherwise be NaN and trip the finiteness check.
    y[~mask] = 0
    assert np.all(np.isfinite(y))
    return y
コード例 #7
0
def optimize_kappa(data, x0=None, cost_fn=svm_cost, args=()):
    """Fit ``(kappa, amplitude)`` by minimising ``cost_fn`` on ``data``.

    ``cost_fn(x, data, *args)`` is minimised over ``x = (kappa, amplitude)``
    with L-BFGS-B, both components bounded to ``[0, inf)``, using the
    gradient from ``grad``. When ``x0`` is None the search starts at
    ``(1, nanmean(data))``.

    Returns
    -------
    kappa, amplitude : the two fitted values (both are returned, not only kappa).
    """
    if x0 is None:
        x0 = np.array((1, np.nanmean(data)))

    def objective(x):
        return cost_fn(x, data, *args)

    nonnegative = (0, np.inf)
    fit = sop.minimize(objective, x0, bounds=[nonnegative, nonnegative],
                       method='L-BFGS-B', jac=grad(objective))
    kappa, amplitude = fit.x
    return kappa, amplitude
コード例 #8
0
    def train_lstm(self):
        """Train ``self.model`` repeatedly and store ensemble-averaged predictions.

        Each round fits the model (early stopping on ``val_loss``). It then
        rebuilds a full training-period prediction by placing each sequence's
        post-processed window at offset ``i`` and nan-averaging overlapping
        windows. Rounds whose training sMAPE is below 150 are accepted; their
        train and test predictions become one ensemble row. The loop stops after
        ``self.num_ensemble`` accepted rounds.

        Side effects:
            Sets ``self.history`` (last fit), ``self.pred_train_output`` and
            ``self.pred_test_output`` (nan-means over accepted rounds).
        """
        num_epochs = 100
        callbacks = [EarlyStopping(monitor='val_loss', patience=10)]
        iter_cnt = 0
        # Builtin ``float`` instead of ``np.float`` (alias removed in NumPy 1.24).
        pred_train_output_mat = np.empty((0, self.len_train_output), float)
        pred_test_output_mat = np.empty((0, self.num_output), float)
        # NOTE(review): loops forever if no round ever reaches sMAPE < 150.
        while iter_cnt < self.num_ensemble:
            self.history = self.model.fit(self.train_input, self.train_output, validation_split=0.15, shuffle=False,
                                          batch_size=1, epochs=num_epochs, callbacks=callbacks, verbose=0)

            # get the predicted train output
            pred_train_output = self.model.predict(self.train_input)
            # Per-round matrix of overlapping windowed predictions. It must not
            # reuse the ensemble accumulator's name; previously that discarded
            # every earlier accepted round.
            window_pred_mat = np.full((self.num_output, self.len_train_output), np.nan)
            for i in range(self.num_sequence):
                seq_pred_train_output = post_process_results(pred_train_output[i],
                                                             denom=self.train_denom_list[i],
                                                             ts_seasonality_in=self.ts_seasonality_in,
                                                             shift=i,
                                                             freq=self.freq).ravel()
                window_pred_mat[i % self.num_output, i: i + self.num_output] = seq_pred_train_output
            iter_pred_train_output = np.nanmean(window_pred_mat, axis=0)
            iter_train_smape, _ = smape(self.true_train_output, iter_pred_train_output)
            if iter_train_smape < 150:
                # get the predicted test output
                iter_pred_test_output = self.model.predict(self.test_input).ravel()
                iter_pred_test_output = post_process_results(iter_pred_test_output,
                                                             denom=self.test_denom,
                                                             ts_seasonality_in=self.ts_seasonality_in,
                                                             shift=self.len_train_output,
                                                             freq=self.freq).ravel()

                pred_train_output_mat = np.vstack((pred_train_output_mat, iter_pred_train_output))
                pred_test_output_mat = np.vstack((pred_test_output_mat, iter_pred_test_output))
                iter_cnt += 1

        self.pred_train_output = np.nanmean(pred_train_output_mat, axis=0)
        self.pred_test_output = np.nanmean(pred_test_output_mat, axis=0)
コード例 #9
0
ファイル: fitstar.py プロジェクト: jradavenport/flareninja
def sigma_clip(t, y, yerr, mask=None):
    """Iteratively flag points of ``y`` that lie less than 3 sigma above the mean.

    The clip is one-sided: a point is kept while ``y - mu < 3 * sigma``,
    where ``mu``/``sigma`` are the nan-mean/nan-std of the currently kept
    points. Iteration stops when the kept set no longer changes.

    ``mask`` (copied, never modified) only seeds the first estimate; later
    iterations re-evaluate every point. ``t`` is used only for its length
    when ``mask`` is None, and ``yerr`` is unused.

    Returns
    -------
    Boolean array, True for points that are kept.
    """
    keep = np.ones(len(t), dtype=bool) if mask is None else np.copy(mask)
    while True:
        centre = np.nanmean(y[keep])
        spread = np.nanstd(y[keep])
        candidate = y - centre < 3 * spread
        if np.all(candidate == keep):
            return keep
        keep = candidate
コード例 #10
0
def compute_R_hat(chains, warmup=0.5):
    """
    Compute the split-R-hat for multiple chains.

    The first ``warmup`` iterates are removed from every chain. Each remaining
    chain is then split into two halves, and R-hat is computed over the
    resulting ``2 * n_chains`` half-chains.

    Parameters
    ----------
    chains : multi-dimensional array, shape=(n_chains, n_iters, n_var_params)
    warmup : float or int, optional
        Values below 1 are a fraction of ``n_iters``; values of 1 or more are
        an iterate count (floats are truncated to int). If an odd number of
        iterates remains, one more is discarded so the chains split evenly.

    Returns
    -------
    var_hat : ndarray, shape=(n_var_params,)
        ``(n - 1) / n + B / (n * W)`` with ``n`` the half-chain length,
        i.e. the squared R-hat.
    R_hat : ndarray, shape=(n_var_params,)
        The split R-hat for multiple chains.

    Raises
    ------
    ValueError
        If the warmup leaves fewer than 2 iterates.
    """

    jitter = 1e-8
    n_iters = chains.shape[1]
    n_chains = chains.shape[0]
    if warmup < 1:
        warmup = int(warmup * n_iters)

    if warmup > n_iters - 2:
        raise ValueError('Warmup should be less than number of iterates ..')

    # A float count such as 2.0 cannot be a slice bound; truncate it to int.
    warmup = int(warmup)

    if (n_iters - warmup) % 2:
        warmup = warmup + 1

    chains = chains[:, warmup:, :]

    K = chains.shape[2]
    n_iters = chains.shape[1] // 2
    # C-order reshape: each chain becomes its first half followed by its second half
    psi = np.reshape(chains, (n_chains * 2, n_iters, K))
    n_chains2 = n_chains * 2
    psi_dot_j = np.mean(psi, axis=1)
    psi_dot_dot = np.mean(psi_dot_j, axis=0)
    s_j_2 = np.sum(
        (psi - np.expand_dims(psi_dot_j, axis=1))**2, axis=1) / (n_iters - 1)
    B = n_iters * np.sum(
        (psi_dot_j - psi_dot_dot)**2, axis=0) / (n_chains2 - 1)
    W = np.nanmean(s_j_2, axis=0)
    W = W + jitter  # keeps W strictly positive so B / W is defined
    var_hat = (n_iters - 1) / n_iters + (B / (n_iters * W))
    R_hat = np.sqrt(var_hat)
    return var_hat, R_hat
コード例 #11
0
ファイル: synth.py プロジェクト: aasensio/DNHazel
    def compute_rotated_map(self, rotation):
        """
        Project the stellar maps on the plane of the sky for one rotation of the star.

        Args:
            rotation (float) : rotation of the star given as [longitude, latitude] in degrees

        Returns:
            pixel_unique : unique values of the projected pixel map (healpix indices)
            pixel_map : projected healpix pixel indices, restricted to the first int(npix/2) columns
            mu_pixel (float): plane-of-sky map holding, for each healpix pixel, the nan-mean of
                self.mu_angle over its projected area; 0 where that area has non-finite mu
            T_pixel (float): plane-of-sky map of self.temperature_map values, 0 where mu is non-finite
        """
        mu_pixel = np.zeros_like(self.mu_angle)
        T_pixel = np.zeros_like(self.mu_angle)

        half_width = int(self.npix / 2)
        pixel_map = self.projector.projmap(self.indices, self.f_vec2pix, rot=rotation)[:, 0:half_width]
        pixel_unique = np.unique(pixel_map)

        for pixel in pixel_unique:
            rows, cols = np.where(pixel_map == pixel)
            mu_values = self.mu_angle[rows, cols]

            if not np.all(np.isfinite(mu_values)):
                # any non-finite mu blanks the whole projected area of this pixel
                mu_pixel[rows, cols] = 0.0
                T_pixel[rows, cols] = 0.0
            elif mu_values.size > 0:
                # an empty match (e.g. a NaN entry of pixel_unique) leaves the maps untouched
                mu_pixel[rows, cols] = np.nanmean(mu_values)
                T_pixel[rows, cols] = self.temperature_map[int(pixel)]

        return pixel_unique, pixel_map, mu_pixel, T_pixel
コード例 #12
0
def compute_R_hat(chains, warmup=0, jitter=1e-8):
    """
    Compute split-R-hat values for a single chain.

    After dropping ``warmup`` leading rows (and one trailing row if an odd
    number remains), the chain is cut into two halves that are treated as two
    chains.

    Parameters
    ----------
    chains : `numpy_ndarray(n_iters, dimensions)`
        Samples of the parameters from one chain.
    warmup : `int`, optional
        Number of leading iterations discarded. The default is 0.
    jitter : `float`, optional
        Added to the within-chain variance to avoid division by zero.
        The default is 1e-8.

    Returns
    -------
    R_hat : `numpy_ndarray(dimensions,)`
        Split R-hat value for each parameter.
    """
    kept = chains[warmup:, :]
    total, n_dims = kept.shape
    total -= total % 2  # drop the final draw when the count is odd
    kept = kept[:total, :]

    half = total // 2
    n_halves = 2
    halves = np.reshape(kept, (n_halves, half, n_dims))
    half_means = np.mean(halves, axis=1)
    grand_mean = np.mean(half_means, axis=0)

    within = np.sum((halves - half_means[:, np.newaxis, :]) ** 2, axis=1) / (half - 1)
    between = half * np.sum((half_means - grand_mean) ** 2, axis=0) / (n_halves - 1)
    mean_within = np.nanmean(within, axis=0) + jitter

    return np.sqrt((half - 1) / half + (between / (half * mean_within)))
コード例 #13
0
ファイル: sim_utils.py プロジェクト: dmossing/analysis
def average_time(arr, nbefore=8, nafter=8):
    """Nan-average ``arr`` over its last axis, skipping edge samples.

    The first ``nbefore`` and the last ``nafter`` entries along the last axis
    (time, per the original usage) are excluded before averaging; all other
    axes are kept.

    Parameters
    ----------
    arr : ndarray
        Array whose last axis is averaged.
    nbefore, nafter : int
        Number of leading / trailing samples to exclude. ``nafter=0`` keeps
        everything up to the end.

    Returns
    -------
    ndarray with the last axis of ``arr`` removed.
    """
    # Ellipsis indexing instead of a list of slices (rejected by NumPy >= 1.23).
    # The explicit stop also makes nafter=0 mean "to the end" rather than the
    # empty slice that `-0` produced.
    stop = arr.shape[-1] - nafter
    return np.nanmean(arr[..., nbefore:stop], -1)
コード例 #14
0
ファイル: sim_utils.py プロジェクト: dmossing/analysis
def compute_tuning(dsfile,
                   datafield='decon',
                   running=True,
                   expttype='size_contrast_0',
                   running_pct_cutoff=default_running_pct_cutoff,
                   fill_nans_under_cutoff=False,
                   run_speed_cutoff=default_run_speed_cutoff):
    """Convert an HDF5 dataset into per-session tuning arrays.

    Each top-level key of ``dsfile`` is treated as one imaging session. For
    sessions containing ``expttype`` with a ``datafield`` entry:

    - trials are selected by mean running speed between ``nbefore`` and
      ``-nafter`` (above ``run_speed_cutoff`` when ``running``, below it
      otherwise);
    - ``tuning`` is ``ut.compute_tuning(...)`` on those trials when the
      selected fraction exceeds ``running_pct_cutoff``, or an all-NaN array of
      the same shape when ``fill_nans_under_cutoff`` is set. For size-contrast
      stimuli the original author describes its dimensions as ROI x size x
      contrast x direction x time;
    - ``uparam`` lists the arrays named by ``stimulus_parameters``;
    - if ``rf_displacement_deg`` is present, ``pval`` holds the RF mapping
      p-values and ``displacement`` a linear-regression fit of RF displacement
      on cell position, trained on cells passing as many RF quality criteria
      as possible while keeping at least 5 cells.

    Returns
    -------
    tuning, uparam, displacement, pval : lists indexed like the file's keys;
        entries stay None when the corresponding data is not computed.
    """
    with h5py.File(dsfile, mode='r') as f:
        session_keys = list(f.keys())
        n_sessions = len(session_keys)
        tuning = [None] * n_sessions
        uparam = [None] * n_sessions
        displacement = [None] * n_sessions
        pval = [None] * n_sessions
        for ikey, key in enumerate(session_keys):
            session = f[key]
            print(session)
            if not (expttype in session and datafield in session[expttype]):
                continue

            sc0 = session[expttype]
            print(datafield)
            data = sc0[datafield][:]
            stim_id = sc0['stimulus_id'][:]
            nbefore = sc0['nbefore'][()]
            nafter = sc0['nafter'][()]

            mean_speed = sc0['running_speed_cm_s'][:, nbefore:-nafter].mean(-1)
            if running:
                trialrun = mean_speed > run_speed_cutoff
            else:
                trialrun = mean_speed < run_speed_cutoff
            frac_selected = np.nanmean(trialrun)
            print(frac_selected)

            if frac_selected > running_pct_cutoff:
                tuning[ikey] = ut.compute_tuning(
                    data, stim_id, trial_criteria=trialrun)[:]
            elif fill_nans_under_cutoff:
                template = ut.compute_tuning(
                    data, stim_id, trial_criteria=trialrun)[:]
                tuning[ikey] = np.nan * np.ones_like(template)

            uparam[ikey] = [sc0[param][:] for param in sc0['stimulus_parameters']]

            if 'rf_displacement_deg' not in sc0:
                continue

            pval[ikey] = sc0['rf_mapping_pval'][:]
            retinotopy = session['retinotopy_0']
            sqerror = retinotopy['rf_sq_error'][:]
            sigma = retinotopy['rf_sigma'][:]
            X = session['cell_center'][:]
            y = sc0['rf_displacement_deg'][:].T
            has_center = ~np.isnan(X[:, 0])
            rf_conditions = [
                ut.k_and(has_center, ~np.isnan(y[:, 0])),
                sqerror < 0.75, sigma > 3.3, pval[ikey] < 0.1
            ]
            # apply criteria in order, skipping any that would leave < 5 cells
            selected = np.ones((X.shape[0], ), dtype='bool')
            for cond in rf_conditions:
                candidate = selected & cond
                if candidate.sum() >= 5:
                    selected = candidate.copy()
            linreg = sklearn.linear_model.LinearRegression().fit(
                X[selected], y[selected])
            displacement[ikey] = np.zeros_like(y)
            displacement[ikey][has_center] = linreg.predict(X[has_center])
    return tuning, uparam, displacement, pval
コード例 #15
0
            np.squeeze(csd_fast_phase[ci, :, :]).T,
            np.squeeze(csd_fast_phase[cj, :, :]).T)
        plv_csd[ci, cj, :] = res_csd
        plv_csd[cj, ci, :] = res_csd
        res_lfp = plv(
            np.squeeze(lfp_fast_phase[ci, :, :]).T,
            np.squeeze(lfp_fast_phase[cj, :, :]).T)
        plv_lfp[ci, cj, :] = res_lfp
        plv_lfp[cj, ci, :] = res_lfp

# %% Visualize PLV briefly
# Each panel overlays every channel-pair PLV trace, with their nan-mean in
# black. The (channel, channel, time) arrays are flattened to (pairs, time)
# using their own shapes instead of a hard-coded 24 x 24 channel grid
# (identical result for 24 channels; also works for other probe sizes).
plv_csd_pairs = plv_csd.reshape((-1, plv_csd.shape[-1]))
plv_lfp_pairs = plv_lfp.reshape((-1, plv_lfp.shape[-1]))
t_pred = gpcsd_model.t_pred.squeeze()

plt.figure(figsize=(12, 6))
for panel, pairs in ((121, plv_csd_pairs), (122, plv_lfp_pairs)):
    plt.subplot(panel)
    plt.plot(t_pred, pairs.T)
    plt.plot(t_pred, np.nanmean(pairs, 0), 'k', linewidth=3)

# %% Save phases at time point index 350 for Matlab analysis
save_time_idx = 350
scipy.io.savemat(
    '%s/results/csd_lfp_filt_phases_%s.mat' % (root_path, probe_name), {
        'csd': csd_fast_phase[:, save_time_idx, :],
        'lfp': lfp_fast_phase[:, save_time_idx, :]
    })
コード例 #16
0
ファイル: synth.py プロジェクト: aasensio/DNHazel
    def precompute_rotation_maps(self, rotations=None):
        """
        Precompute, for every rotation phase, the visible healpix pixels and their averaged quantities.

        Args:
            rotations (float) : [N_phases x 2] array giving [longitude, latitude] in degrees for each phase

        Returns:
            None. When rotations is None a message is printed and nothing is computed.
            Otherwise per-phase lists are stored on the instance:
            n_phases, n_pixel_unique, pixel_unique (int), n_pixels (projected area in
            sky pixels), avg_mu (nan-mean of mu_angle if self.clv else 1.0; 0 for pixels
            with non-finite mu), avg_v (nan-mean of vel_projection; 0 likewise) and
            velocity = avg_mu * avg_v.
        """
        if rotations is None:
            print("Use some angles for the rotations")
            return

        n_phases = rotations.shape[0]
        self.n_phases = n_phases

        self.avg_mu = [None] * n_phases
        self.avg_v = [None] * n_phases
        self.velocity = [None] * n_phases
        self.n_pixel_unique = [None] * n_phases
        self.n_pixels = [None] * n_phases
        self.pixel_unique = [None] * n_phases

        half_width = int(self.npix / 2)
        for phase in range(n_phases):
            mu_pixel = np.zeros_like(self.mu_angle)
            v_pixel = np.zeros_like(self.vel_projection)

            pixel_map = self.projector.projmap(self.indices, self.f_vec2pix, rot=rotations[phase, :])[:, 0:half_width]
            pixel_unique = np.unique(pixel_map[np.isfinite(pixel_map)])
            n_unique = len(pixel_unique)

            avg_mu = np.zeros(n_unique)
            avg_v = np.zeros(n_unique)
            velocity = np.zeros(n_unique)
            n_pixels = np.zeros(n_unique, dtype='int')

            # Each healpix pixel covers a disjoint set of sky pixels, so its
            # maps can be filled and summarised in the same pass.
            for i, pixel in enumerate(pixel_unique):
                rows, cols = np.where(pixel_map == pixel)
                mu_values = self.mu_angle[rows, cols]

                if np.all(np.isfinite(mu_values)) and mu_values.size > 0:
                    mu_pixel[rows, cols] = np.nanmean(mu_values) if self.clv else 1.0
                    v_pixel[rows, cols] = np.nanmean(self.vel_projection[rows, cols])
                else:
                    mu_pixel[rows, cols] = 0.0
                    v_pixel[rows, cols] = 0.0

                n_pixels[i] = len(rows)
                avg_mu[i] = np.unique(mu_pixel[rows, cols])
                avg_v[i] = np.unique(v_pixel[rows, cols])
                velocity[i] = avg_mu[i] * avg_v[i]

            self.n_pixel_unique[phase] = n_unique
            self.avg_mu[phase] = avg_mu
            self.avg_v[phase] = avg_v
            self.velocity[phase] = velocity
            self.n_pixels[phase] = n_pixels
            self.pixel_unique[phase] = pixel_unique.astype('int')
コード例 #17
0
            print(log_iw.shape)

            psis_lw, K_hat_stan = psislw(log_iw.T)
            K_hat_stan_advi_list[j, n] = K_hat_stan
            print(psis_lw.shape)
            print('K hat statistic for Stan ADVI:')
            print(K_hat_stan)

    ###################### Plotting L2 norm here #################################

plt.figure()
plt.plot(stan_vb_w[:, 0], stan_vb_w[:, 1], 'mo', label='STAN-ADVI')
plt.savefig('vb_w_samples_mf.pdf')

# Common file-name prefix of both saved K-hat arrays.
k_hat_file = 'K_hat_linear_' + datatype + '_' + algo_name + '_' + str(N) + 'N'
np.save(k_hat_file, K_hat_stan_advi_list)

# K-hat versus dimension: nan-mean with nan-min / nan-max envelope over axis 1.
plt.figure()
plt.plot(K_list, np.nanmean(K_hat_stan_advi_list, axis=1), 'r-', alpha=1,
         label='mean')
plt.plot(K_list, np.nanmin(K_hat_stan_advi_list, axis=1), 'r-', alpha=0.5,
         label='min')
plt.plot(K_list, np.nanmax(K_hat_stan_advi_list, axis=1), 'r-', alpha=0.5,
         label='max')
plt.xlabel('Dimensions')
plt.ylabel('K-hat')

np.save(k_hat_file + '_samples_' + str(gradsamples), K_hat_stan_advi_list)
# The curves above carry labels so the legend has entries; without them
# matplotlib draws nothing and warns "No artists with labels found".
plt.legend()
plt.savefig('Linear_Regression_K_hat_vs_D_' + datatype + '_' + algo_name +
            '_' + str(N) + 'N.pdf')
コード例 #18
0
def fit_weights_and_save(weights_file,ca_data_file='rs_vm_denoise_200605.npy',opto_silencing_data_file='vip_halo_data_for_sim.npy',opto_activation_data_file='vip_chrimson_data_for_sim.npy',constrain_wts=None,allow_var=True,fit_s02=True,constrain_isn=True,tv=False,l2_penalty=0.01,init_noise=0.1,init_W_from_lsq=False,scale_init_by=1,init_W_from_file=False,init_file=None,correct_Eta=False,init_Eta_with_s02=False,init_Eta12_with_dYY=False,use_opto_transforms=False):
    
    nsize,ncontrast = 6,6
    
    npfile = np.load(ca_data_file,allow_pickle=True)[()]#,{'rs':rs,'rs_denoise':rs_denoise},allow_pickle=True)
    rs = npfile['rs']
    #rs_denoise = npfile['rs_denoise']
    
    nsize,ncontrast,ndir = 6,6,8
    #ori_dirs = [[0,4],[2,6]] #[[0,4],[1,3,5,7],[2,6]]
    ori_dirs = [[0,1,2,3,4,5,6,7]]
    nT = len(ori_dirs)
    nS = len(rs[0])
    
    def sum_to_1(r):
        R = r.reshape((r.shape[0],-1))
        #R = R/np.nansum(R[:,~np.isnan(R.sum(0))],axis=1)[:,np.newaxis]
        R = R/np.nansum(R,axis=1)[:,np.newaxis] # changed 8/28
        return R
    
    def norm_to_mean(r):
        R = r.reshape((r.shape[0],-1))
        R = R/np.nanmean(R[:,~np.isnan(R.sum(0))],axis=1)[:,np.newaxis]
        return R
    
    Rs = [[None,None] for i in range(len(rs))]
    Rso = [[[None for iT in range(nT)] for iS in range(nS)] for icelltype in range(len(rs))]
    rso = [[[None for iT in range(nT)] for iS in range(nS)] for icelltype in range(len(rs))]
    
    for iR,r in enumerate(rs):#rs_denoise):
        print(iR)
        for ialign in range(nS):
            #Rs[iR][ialign] = r[ialign][:,:nsize,:]
            #sm = np.nanmean(np.nansum(np.nansum(Rs[iR][ialign],1),1))
            #Rs[iR][ialign] = Rs[iR][ialign]/sm
            Rs[iR][ialign] = sum_to_1(r[ialign][:,:nsize,:])
    #         Rs[iR][ialign] = von_mises_denoise(Rs[iR][ialign].reshape((-1,nsize,ncontrast,ndir)))
    
    kernel = np.ones((1,2,2))
    kernel = kernel/kernel.sum()
    
    for iR,r in enumerate(rs):
        for ialign in range(nS):
            for iori in range(nT):
                Rso[iR][ialign][iori] = np.nanmean(Rs[iR][ialign].reshape((-1,nsize,ncontrast,ndir))[:,:,:,ori_dirs[iori]],-1)
                Rso[iR][ialign][iori][:,:,0] = np.nanmean(Rso[iR][ialign][iori][:,:,0],1)[:,np.newaxis] # average 0 contrast values
                Rso[iR][ialign][iori][:,1:,1:] = ssi.convolve(Rso[iR][ialign][iori],kernel,'valid')
                Rso[iR][ialign][iori] = Rso[iR][ialign][iori].reshape(Rso[iR][ialign][iori].shape[0],-1)
                #Rso[iR][ialign][iori] = Rso[iR][ialign][iori]/np.nanmean(Rso[iR][ialign][iori],-1)[:,np.newaxis]
    
    def set_bound(bd,code,val=0):
        # set bounds to 0 where 0s occur in 'code'
        for iitem in range(len(bd)):
            bd[iitem][code[iitem]] = val
    
    nN = 36
    nS = 2
    nP = 2
    nT = 1
    nQ = 4
    
    # code for bounds: 0 , constrained to 0
    # +/-1 , constrained to +/-1
    # 1.5, constrained to [0,1]
    # 2 , constrained to [0,inf)
    # -2 , constrained to (-inf,0]
    # 3 , unconstrained
    
    Wmx_bounds = 3*np.ones((nP,nQ),dtype=int)
    Wmx_bounds[0,1] = 0 # SSTs don't receive L4 input
    
    if allow_var:
        Wsx_bounds = 3*np.ones(Wmx_bounds.shape) #Wmx_bounds.copy()*0 #np.zeros_like(Wmx_bounds)
        Wsx_bounds[0,1] = 0
    else:
        Wsx_bounds = np.zeros(Wmx_bounds.shape) #Wmx_bounds.copy()*0 #np.zeros_like(Wmx_bounds)
    
    Wmy_bounds = 3*np.ones((nQ,nQ),dtype=int)
    Wmy_bounds[0,:] = 2 # PCs are excitatory
    Wmy_bounds[1:,:] = -2 # all the cell types except PCs are inhibitory
    Wmy_bounds[1,1] = 0 # SSTs don't inhibit themselves
    # Wmy_bounds[3,1] = 0 # PVs are allowed to inhibit SSTs, consistent with Hillel's unpublished results, but not consistent with Pfeffer et al.
    Wmy_bounds[2,0] = 0 # VIPs don't inhibit L2/3 PCs. According to Pfeffer et al., only L5 PCs were found to get VIP inhibition

    if allow_var:
        Wsy_bounds = 3*np.ones(Wmy_bounds.shape) #Wmy_bounds.copy()*0 #np.zeros_like(Wmy_bounds)
        Wsy_bounds[1,1] = 0
        Wsy_bounds[3,1] = 0 
        Wsy_bounds[2,0] = 0
    else:
        Wsy_bounds = np.zeros(Wmy_bounds.shape) #Wmy_bounds.copy()*0 #np.zeros_like(Wmy_bounds)

    if not constrain_wts is None:
        for wt in constrain_wts:
            Wmy_bounds[wt[0],wt[1]] = 0
            Wsy_bounds[wt[0],wt[1]] = 0
    
    def tile_nS_nT_nN(kernel):
        row = np.concatenate([kernel for idim in range(nS*nT)],axis=0)[np.newaxis,:]
        tiled = np.concatenate([row for irow in range(nN)],axis=0)
        return tiled
    
    if fit_s02:
        s02_bounds = 2*np.ones((nQ,)) # permitting noise as a free parameter
    else:
        s02_bounds = np.ones((nQ,))
    
    k_bounds = 1.5*np.ones((nQ*(nS-1),))
    
    kappa_bounds = np.ones((1,))
    # kappa_bounds = 2*np.ones((1,))
    
    T_bounds = 1.5*np.ones((nQ*(nT-1),))
    
    X_bounds = tile_nS_nT_nN(np.array([2,1]))
    # X_bounds = np.array([np.array([2,1,2,1])]*nN)
    
    Xp_bounds = tile_nS_nT_nN(np.array([3,1]))
    # Xp_bounds = np.array([np.array([3,1,3,1])]*nN)
    
    # Y_bounds = tile_nS_nT_nN(2*np.ones((nQ,)))
    # # Y_bounds = 2*np.ones((nN,nT*nS*nQ))
    
    Eta_bounds = tile_nS_nT_nN(3*np.ones((nQ,)))
    # Eta_bounds = 3*np.ones((nN,nT*nS*nQ))
    
    if allow_var:
        Xi_bounds = tile_nS_nT_nN(3*np.ones((nQ,)))
    else:
        Xi_bounds = tile_nS_nT_nN(np.zeros((nQ,)))

    # Xi_bounds = 3*np.ones((nN,nT*nS*nQ))
    
    h1_bounds = -2*np.ones((1,))
    
    h2_bounds = 2*np.ones((1,))
    
    
    # In[8]:
    
    
    # shapes = [(nP,nQ),(nQ,nQ),(nP,nQ),(nQ,nQ),(nQ,),(nQ,),(1,),(nN,nS*nP),(nN,nS*nQ),(nN,nS*nQ),(nN,nS*nQ)]
    shapes = [(nP,nQ),(nQ,nQ),(nP,nQ),(nQ,nQ),(nQ,),(nQ*(nS-1),),(1,),(nQ*(nT-1),),(nN,nT*nS*nP),(nN,nT*nS*nP),(nN,nT*nS*nQ),(nN,nT*nS*nQ),(1,),(1,),(nN,nT*nS*nQ),(nN,nT*nS*nQ)]
    print('size of shapes: '+str(np.sum([np.prod(shp) for shp in shapes])))
    #         Wmx,    Wmy,    Wsx,    Wsy,    s02,  k,    kappa,T,   XX,            XXp,          Eta,          Xi, h1, h2, Eta1,   Eta2
    
    lb = [-np.inf*np.ones(shp) for shp in shapes]
    ub = [np.inf*np.ones(shp) for shp in shapes]
    bdlist = [Wmx_bounds,Wmy_bounds,Wsx_bounds,Wsy_bounds,s02_bounds,k_bounds,kappa_bounds,T_bounds,X_bounds,Xp_bounds,Eta_bounds,Xi_bounds,h1_bounds,h2_bounds,Eta_bounds,Eta_bounds]
    
    set_bound(lb,[bd==0 for bd in bdlist],val=0)
    set_bound(ub,[bd==0 for bd in bdlist],val=0)
    
    set_bound(lb,[bd==2 for bd in bdlist],val=0)
    
    set_bound(ub,[bd==-2 for bd in bdlist],val=0)
    
    set_bound(lb,[bd==1 for bd in bdlist],val=1)
    set_bound(ub,[bd==1 for bd in bdlist],val=1)
    
    set_bound(lb,[bd==1.5 for bd in bdlist],val=0)
    set_bound(ub,[bd==1.5 for bd in bdlist],val=1)
    
    set_bound(lb,[bd==-1 for bd in bdlist],val=-1)
    set_bound(ub,[bd==-1 for bd in bdlist],val=-1)
    
    # for bd in [lb,ub]:
    #     for ind in [2,3]:
    #         bd[ind][:,1] = 0
    
    # temporary for no variation expt.
    # lb[2] = np.zeros_like(lb[2])
    # lb[3] = np.zeros_like(lb[3])
    # lb[4] = np.ones_like(lb[4])
    # lb[5] = np.zeros_like(lb[5])
    # ub[2] = np.zeros_like(ub[2])
    # ub[3] = np.zeros_like(ub[3])
    # ub[4] = np.ones_like(ub[4])
    # ub[5] = np.ones_like(ub[5])
    # temporary for no variation expt.
    lb = np.concatenate([a.flatten() for a in lb])
    ub = np.concatenate([b.flatten() for b in ub])
    bounds = [(a,b) for a,b in zip(lb,ub)]
    
    
    # In[10]:
    
    
    nS = 2
    print('nT: '+str(nT))
    ndims = 5
    ncelltypes = 5
    Yhat = [[None for iT in range(nT)] for iS in range(nS)]
    Xhat = [[None for iT in range(nT)] for iS in range(nS)]
    Ypc_list = [[None for iT in range(nT)] for iS in range(nS)]
    Xpc_list = [[None for iT in range(nT)] for iS in range(nS)]
    mx = [None for iS in range(nS)]
    for iS in range(nS):
        mx[iS] = np.zeros((ncelltypes,))
        yy = [None for icelltype in range(ncelltypes)]
        for icelltype in range(ncelltypes):
            yy[icelltype] = np.nanmean(Rso[icelltype][iS][0],0)
            mx[iS][icelltype] = np.nanmax(yy[icelltype])
        for iT in range(nT):
            y = [np.nanmean(Rso[icelltype][iS][iT],axis=0)[:,np.newaxis]/mx[iS][icelltype] for icelltype in range(1,ncelltypes)]
            Ypc_list[iS][iT] = [None for icelltype in range(1,ncelltypes)]
            for icelltype in range(1,ncelltypes):
                rss = Rso[icelltype][iS][iT].copy()#/mx[iS][icelltype] #.reshape(Rs[icelltype][ialign].shape[0],-1)
                #rss = Rso[icelltype][iS][iT].copy() #.reshape(Rs[icelltype][ialign].shape[0],-1)
                rss = rss[np.isnan(rss).sum(1)==0]
        #         print(rss.max())
        #         rss[rss<0] = 0
        #         rss = rss[np.random.randn(rss.shape[0])>0]
                try:
                    u,s,v = np.linalg.svd(rss-np.mean(rss,0)[np.newaxis])
                    Ypc_list[iS][iT][icelltype-1] = [(s[idim],v[idim]) for idim in range(ndims)]
    #                 print('yep on Y')
    #                 print(np.min(np.sum(rs[icelltype][iS][iT],axis=1)))
                except:
    #                 print('nope on Y')
                    print(np.mean(np.isnan(rss)))
                    print(np.min(np.sum(rs[icelltype][iS][iT],axis=1)))
            Yhat[iS][iT] = np.concatenate(y,axis=1)
    #         x = sim_utils.columnize(Rso[0][iS][iT])[:,np.newaxis]
            icelltype = 0
            #x = np.nanmean(Rso[icelltype][iS][iT],0)[:,np.newaxis]#/mx[iS][icelltype]
            x = np.nanmean(Rso[icelltype][iS][iT],0)[:,np.newaxis]/mx[iS][icelltype]
    #         opto_column = np.concatenate((np.zeros((nN,)),np.zeros((nNO/2,)),np.ones((nNO/2,))),axis=0)[:,np.newaxis]
            Xhat[iS][iT] = np.concatenate((x,np.ones_like(x)),axis=1)
    #         Xhat[iS][iT] = np.concatenate((x,np.ones_like(x),opto_column),axis=1)
            icelltype = 0
            #rss = Rso[icelltype][iS][iT].copy()/mx[iS][icelltype]
            rss = Rso[icelltype][iS][iT].copy()
            rss = rss[np.isnan(rss).sum(1)==0]
    #         try:
            u,s,v = np.linalg.svd(rss-rss.mean(0)[np.newaxis])
            Xpc_list[iS][iT] = [None for iinput in range(2)]
            Xpc_list[iS][iT][0] = [(s[idim],v[idim]) for idim in range(ndims)]
            Xpc_list[iS][iT][1] = [(0,np.zeros((Xhat[0][0].shape[0],))) for idim in range(ndims)]
    #         except:
    #             print('nope on X')
    #             print(np.mean(np.isnan(rss)))
    #             print(np.min(np.sum(Rso[icelltype][iS][iT],axis=1)))
    nN,nP = Xhat[0][0].shape
    print('nP: '+str(nP))
    nQ = Yhat[0][0].shape[1]
    
    
    # In[11]:
    
    
    def compute_f_(Eta,Xi,s02):
        """Evaluate sim_utils.f_miller_troyer at (Eta, Xi**2 + tiled s02).

        `s02` is concatenated nS*nT times (nS, nT read from the enclosing
        scope at call time) so that it lines up with the last axis of Xi.
        """
        total_s2 = Xi**2 + np.concatenate([s02]*(nS*nT))
        return sim_utils.f_miller_troyer(Eta, total_s2)
    def compute_fprime_m_(Eta,Xi,s02):
        """Return sim_utils.fprime_miller_troyer(Eta, Xi**2 + tiled s02) scaled by Xi.

        `s02` is concatenated nS*nT times (nS, nT from the enclosing scope)
        to match the last axis of Xi.
        """
        total_s2 = Xi**2 + np.concatenate([s02]*(nS*nT))
        deriv = sim_utils.fprime_miller_troyer(Eta, total_s2)
        return deriv*Xi
    def compute_fprime_s_(Eta,Xi,s02):
        """Return sim_utils.fprime_s_miller_troyer(Eta, s2) * (Xi / s2).

        s2 = Xi**2 + s02 tiled nS*nT times (nS, nT from the enclosing scope).
        The tiling now matches compute_f_ and compute_fprime_m_. The previous
        version concatenated only two copies of s02, which cannot broadcast
        against Xi (last axis nT*nS*nQ) unless nS*nT == 2.
        """
        s2 = Xi**2+np.concatenate([s02 for ipixel in range(nS*nT)])
        return sim_utils.fprime_s_miller_troyer(Eta,s2)*(Xi/s2)
    def sorted_r_eigs(w):
        """Eigen-decompose `w` and return (eigenvalues, right eigenvectors) sorted by eigenvalue.

        Eigenvectors are the columns of the second return value, permuted
        consistently with the ascending order of the eigenvalues.
        """
        evals, evecs = np.linalg.eig(w)
        order = evals.argsort()
        return evals[order], evecs[:, order]
    
    
    # In[12]:
    
    
    #         0.Wmx,  1.Wmy,  2.Wsx,  3.Wsy,  4.s02,5.K,  6.kappa,7.T,8.XX,        9.XXp,        10.Eta,       11.Xi,   12.h1,  13.h2,  14.Eta1,    15.Eta2
    
    shapes = [(nP,nQ),(nQ,nQ),(nP,nQ),(nQ,nQ),(nQ,),(nQ*(nS-1),),(1,),(nQ*(nT-1),),(nN,nT*nS*nP),(nN,nT*nS*nP),(nN,nT*nS*nQ),(nN,nT*nS*nQ),(1,),(1,),(nN,nT*nS*nQ),(nN,nT*nS*nQ)]
    print('size of shapes: '+str(np.sum([np.prod(shp) for shp in shapes])))
    
    
    import calnet.fitting_spatial_feature
    import sim_utils

    YYhat = calnet.utils.flatten_nested_list_of_2d_arrays(Yhat)
    XXhat = calnet.utils.flatten_nested_list_of_2d_arrays(Xhat)
    
    opto_dict = np.load(opto_silencing_data_file,allow_pickle=True)[()]

    Yhat_opto = opto_dict['Yhat_opto']
    Yhat_opto = np.nanmean(np.reshape(Yhat_opto,(nN,2,nS,2,nQ)),3).reshape((nN*2,-1))
    Yhat_opto = Yhat_opto/Yhat_opto[0::2].max(0)[np.newaxis,:]
    print(Yhat_opto.shape)
    h_opto = opto_dict['h_opto']
    #dYY1 = Yhat_opto[1::2]-Yhat_opto[0::2]
        
    YYhat_halo = Yhat_opto.reshape((nN,2,-1))
    opto_transform1 = calnet.utils.fit_opto_transform(YYhat_halo)

    opto_transform1.res[:,[0,2,3,4,6,7]] = 0

    dYY1 = opto_transform1.transform(YYhat) - YYhat
    #YYhat_halo_sim = calnet.utils.simulate_opto_effect(YYhat,YYhat_halo)
    #dYY1 = YYhat_halo_sim[:,1,:] - YYhat_halo_sim[:,0,:]

    def overwrite_plus_n(arr,to_overwrite,n):
        """Copy column `to_overwrite + n` of `arr` into column `to_overwrite`, in place.

        Returns the same (mutated) array object so the call can be used
        inside a list comprehension.
        """
        source_col = int(to_overwrite+n)
        arr[:, to_overwrite] = arr[:, source_col]
        return arr

    for to_overwrite in [1,2]:
        n = 4
        dYY1,opto_transform1.slope,opto_transform1.intercept,opto_transform1.res \
                = [overwrite_plus_n(x,to_overwrite,n) for x in \
                        [dYY1,opto_transform1.slope,opto_transform1.intercept,opto_transform1.res]]
    for to_overwrite in [7]:
        n = -4
        dYY1,opto_transform1.slope,opto_transform1.intercept,opto_transform1.res \
                = [overwrite_plus_n(x,to_overwrite,n) for x in \
                        [dYY1,opto_transform1.slope,opto_transform1.intercept,opto_transform1.res]]
    #for to_overwrite in [1,2]:
    #    dYY1[:,to_overwrite] = dYY1[:,to_overwrite+4]
    #for to_overwrite in [7]:
    #    dYY1[:,to_overwrite] = dYY1[:,to_overwrite-4]
    
    #Yhat_opto = opto_dict['Yhat_opto']
    #for iS in range(nS):
    #    mx = np.zeros((nQ,))
    #    for iQ in range(nQ):
    #        slicer = slice(nQ*nT*iS+iQ,nQ*nT*(1+iS),nQ)
    #        mx[iQ] = np.nanmax(Yhat_opto[0::2][:,slicer])
    #        Yhat_opto[:,slicer] = Yhat_opto[:,slicer]/mx[iQ]
    ##Yhat_opto = Yhat_opto/Yhat_opto[0::2].max(0)[np.newaxis,:]
    #print(Yhat_opto.shape)
    #h_opto = opto_dict['h_opto']
    #dYY1 = Yhat_opto[1::2]-Yhat_opto[0::2]
    #for to_overwrite in [1,2,5,6]: # overwrite sst and vip with off-centered values
    #    dYY1[:,to_overwrite] = dYY1[:,to_overwrite+8]
    #for to_overwrite in [11,15]:
    #    dYY1[:,to_overwrite] = np.nan #dYY1[:,to_overwrite-8]


    opto_dict = np.load(opto_activation_data_file,allow_pickle=True)[()]

    Yhat_opto = opto_dict['Yhat_opto']
    Yhat_opto = np.nanmean(np.reshape(Yhat_opto,(nN,2,nS,2,nQ)),3).reshape((nN*2,-1))
    Yhat_opto = Yhat_opto/Yhat_opto[0::2].max(0)[np.newaxis,:]
    print(Yhat_opto.shape)
    h_opto = opto_dict['h_opto']
    #dYY2 = Yhat_opto[1::2]-Yhat_opto[0::2]

    YYhat_chrimson = Yhat_opto.reshape((nN,2,-1))
    opto_transform2 = calnet.utils.fit_opto_transform(YYhat_chrimson)
    dYY2 = opto_transform2.transform(YYhat) - YYhat
    #YYhat_chrimson_sim = calnet.utils.simulate_opto_effect(YYhat,YYhat_chrimson)
    #dYY2 = YYhat_chrimson_sim[:,1,:] - YYhat_chrimson_sim[:,0,:]

    #Yhat_opto = opto_dict['Yhat_opto']
    #for iS in range(nS):
    #    mx = np.zeros((nQ,))
    #    for iQ in range(nQ):
    #        slicer = slice(nQ*nT*iS+iQ,nQ*nT*(1+iS),nQ)
    #        mx[iQ] = np.nanmax(Yhat_opto[0::2][:,slicer])
    #        Yhat_opto[:,slicer] = Yhat_opto[:,slicer]/mx[iQ]
    ##Yhat_opto = Yhat_opto/Yhat_opto[0::2].max(0)[np.newaxis,:]
    #print(Yhat_opto.shape)
    #h_opto = opto_dict['h_opto']
    #dYY2 = Yhat_opto[1::2]-Yhat_opto[0::2]
    
    print('dYY1 mean: %03f'%np.nanmean(np.abs(dYY1)))
    print('dYY2 mean: %03f'%np.nanmean(np.abs(dYY2)))

    dYY = np.concatenate((dYY1,dYY2),axis=0)
    
    titles = ['VIP silencing','VIP activation']
    for itype in [0,1,2,3]:
        plt.figure(figsize=(5,2.5))
        for iyy,dyy in enumerate([dYY1,dYY2]):
            plt.subplot(1,2,iyy+1)
            if np.sum(np.isnan(dyy[:,itype]))==0:
                sca.scatter_size_contrast(YYhat[:,itype],YYhat[:,itype]+dyy[:,itype],nsize=6,ncontrast=6)#,mn=0)
            plt.title(titles[iyy])
            plt.xlabel('cell type %d event rate, \n light off'%itype)
            plt.ylabel('cell type %d event rate, \n light on'%itype)
            ut.erase_top_right()
        plt.tight_layout()
        ut.mkdir('figures')
        plt.savefig('figures/scatter_light_on_light_off_target_celltype_%d.eps'%itype)
    
    opto_mask = ~np.isnan(dYY)
    #dYY[nN:][~opto_mask[nN:]] = -dYY[:nN][~opto_mask[nN:]]

    print('mean of opto_mask: '+str(opto_mask.mean()))
    
    #dYY[~opto_mask] = 0
    def zero_nans(arr):
        """Overwrite every NaN entry of `arr` with 0 (in place) and return `arr`."""
        nan_positions = np.isnan(arr)
        arr[nan_positions] = 0
        return arr
    #dYY,opto_transform1.slope,opto_transform1.intercept,opto_transform1.res,\
    #        opto_transform2.slope,opto_transform2.intercept,opto_transform2.res\
    #        = [zero_nans(x) for x in \
    #                [dYY,opto_transform1.slope,opto_transform1.intercept,opto_transform1.res,\
    #                opto_transform2.slope,opto_transform2.intercept,opto_transform2.res]]
    dYY = zero_nans(dYY)

    to_adjust = np.logical_or(np.isnan(opto_transform2.slope[0]),np.isnan(opto_transform2.intercept[0]))

    opto_transform2.slope[:,to_adjust] = 1/opto_transform1.slope[:,to_adjust]
    opto_transform2.intercept[:,to_adjust] = -opto_transform1.intercept[:,to_adjust]/opto_transform1.slope[:,to_adjust]
    opto_transform2.res[:,to_adjust] = -opto_transform1.res[:,to_adjust]/opto_transform1.slope[:,to_adjust]
    
    np.save('/Users/dan/Documents/notebooks/mossing-PC/shared_data/calnet_data/dYY.npy',dYY)
    
    from importlib import reload
    reload(calnet)
    #reload(calnet.fitting_spatial_feature_opto_nonlinear)
    reload(sim_utils)
    # reload(calnet.fitting_spatial_feature)
    # W0list = [np.ones(shp) for shp in shapes]
    wt_dict = {}
    wt_dict['X'] = 1
    wt_dict['Y'] = 15
    wt_dict['Eta'] = 10 # 1 # 
    wt_dict['Xi'] = 0.1
    wt_dict['stims'] = np.ones((nN,1)) #(np.arange(30)/30)[:,np.newaxis]**1 #
    wt_dict['barrier'] = 0. #30.0 #0.1
    wt_dict['opto'] = 1e-1#1e1
    wt_dict['isn'] = 3
    wt_dict['tv'] = 1
    wt_dict['stimsOpto'] = 0.6*np.ones((nN,1))
    wt_dict['stimsOpto'][0::6] = 3
    wt_dict['celltypesOpto'] = 0.67*np.ones((1,nQ*nS*nT))
    wt_dict['celltypesOpto'][0,0::nQ] = 2
    wt_dict['dirOpto'] = np.array((1,0.5))
    wt_dict['dYY'] = 1#1000
    wt_dict['Eta12'] = 1
    wt_dict['coupling'] = 1

    np.save('XXYYhat.npy',{'YYhat':YYhat,'XXhat':XXhat,'rs':rs,'Rs':Rs,'Rso':Rso,'Ypc_list':Ypc_list,'Xpc_list':Xpc_list})
    Eta0 = invert_f_mt(YYhat)
    Eta10 = invert_f_mt(YYhat + dYY[:nN])
    Eta20 = invert_f_mt(YYhat + dYY[nN:])
    print('mean Eta1 diff: '+str(np.mean(np.abs(Eta0-Eta10))))
    print('mean Eta2 diff: '+str(np.mean(np.abs(Eta0-Eta20))))


    ntries = 1
    nhyper = 1
    dt = 1e-1
    niter = int(np.round(10/dt)) #int(1e4)
    perturbation_size = 5e-2
    # learning_rate = 1e-4 # 1e-5 #np.linspace(3e-4,1e-3,niter+1) # 1e-5
    #l2_penalty = 0.1
    Wt = [[None for itry in range(ntries)] for ihyper in range(nhyper)]
    loss = np.zeros((nhyper,ntries))
    is_neg = np.array([b[1] for b in bounds])==0
    counter = 0
    negatize = [np.zeros(shp,dtype='bool') for shp in shapes]
    print(shapes)
    for ishp,shp in enumerate(shapes):
        nel = np.prod(shp)
        negatize[ishp][:][is_neg[counter:counter+nel].reshape(shp)] = True
        counter = counter + nel
    for ihyper in range(nhyper):
        for itry in range(ntries):
            print((ihyper,itry))
            W0list = [init_noise*(ihyper+1)*np.random.rand(*shp) for shp in shapes]
            print('size of shapes: '+str(np.sum([np.prod(shp) for shp in shapes])))
            print('size of w0: '+str(np.sum([np.size(x) for x in W0list])))
            print('len(W0list) : '+str(len(W0list)))
            counter = 0
            for ishp,shp in enumerate(shapes):
                W0list[ishp][negatize[ishp]] = -W0list[ishp][negatize[ishp]]
            W0list[4] = np.ones(shapes[5]) # s02
            W0list[5] = np.ones(shapes[5]) # K
            W0list[6] = np.ones(shapes[6]) # kappa
            W0list[7] = np.ones(shapes[7]) # T
            W0list[8] = np.concatenate(Xhat,axis=1) #XX
            W0list[9] = np.zeros_like(W0list[8]) #XXp
            W0list[10] = Eta0.copy() #np.zeros(shapes[10]) #Eta
            W0list[11] = np.zeros(shapes[11]) #Xi
            W0list[14] = Eta10.copy() # Eta1
            W0list[15] = Eta20.copy() # Eta2
            #[Wmx,Wmy,Wsx,Wsy,s02,k,kappa,T,XX,XXp,Eta,Xi]
    #         W0list = Wstar_dict['as_list'].copy()
    #         W0list[1][1,0] = -1.5
    #         W0list[1][3,0] = -1.5
            if init_W_from_lsq:
                W0list[0],W0list[1] = initialize_W(Xhat,Yhat,scale_by=scale_init_by)
                for ivar in range(0,2):
                    W0list[ivar] = W0list[ivar] + init_noise*np.random.randn(*W0list[ivar].shape)
            if constrain_isn:
                W0list[1][0,0] = 3 
                W0list[1][0,3] = 5 
                W0list[1][3,0] = -5
                W0list[1][3,3] = -5

            #if constrain_isn:
            #    W0list[1][0,0] = 2
            #    W0list[1][0,3] = 2
            #    W0list[1][3,0] = -2
            #    W0list[1][3,3] = -2

            #if wt_dict['coupling'] > 0:
            #    W0list[1][1,0] = -1

            if init_W_from_file:
                npyfile = np.load(init_file,allow_pickle=True)[()]
                W0list = npyfile['as_list']
                if W0list[8].size == nN*nS*2*nP:
                    W0list[7] = np.array(())
                    W0list[1][1,0] = W0list[1][1,0]
                    W0list[8] = np.nanmean(W0list[8].reshape((nN,nS,2,nP)),2).flatten() #XX
                    W0list[9] = np.nanmean(W0list[9].reshape((nN,nS,2,nP)),2).flatten() #XXp
                    W0list[10] = np.nanmean(W0list[10].reshape((nN,nS,2,nQ)),2).flatten() #Eta
                    W0list[11] = np.nanmean(W0list[11].reshape((nN,nS,2,nQ)),2).flatten() #Xi
                if correct_Eta:
                    W0list[10] = Eta0.copy()
                if len(W0list) < len(shapes):
                    W0list = W0list[:-1] + [np.array(-0.5),np.array(1),Eta10.copy(),Eta20.copy()] # add h1,h2,Eta1,Eta2
                if init_Eta_with_s02:
                    s02 = W0list[4].copy()
                    Eta0 = invert_f_mt_with_s02(YYhat,s02,nS=nS,nT=nT)
                    Eta10 = invert_f_mt_with_s02(YYhat+dYY[:nN],s02,nS=nS,nT=nT)
                    Eta20 = invert_f_mt_with_s02(YYhat+dYY[nN:],s02,nS=nS,nT=nT)
                    W0list[10] = Eta0.copy()
                    W0list[14] = Eta10.copy()
                    W0list[15] = Eta20.copy()
                if init_Eta12_with_dYY:
                    Eta0 = W0list[10].copy().reshape((nN,nQ*nS*nT))
                    Xi0 = W0list[11].copy().reshape((nN,nQ*nS*nT))
                    s020 = W0list[4].copy()
                    YY0s = compute_f_(Eta0,Xi0,s020)
                    this_YY1 = opto_transform1.transform(YY0s)
                    this_YY2 = opto_transform2.transform(YY0s)
                    Eta10 = invert_f_mt_with_s02(this_YY1,s020,nS=nS,nT=nT)
                    Eta20 = invert_f_mt_with_s02(this_YY2,s020,nS=nS,nT=nT)
                    W0list[14] = Eta10.copy()
                    W0list[15] = Eta20.copy()

                    YY10s = compute_f_(Eta10,Xi0,s020)
                    YY20s = compute_f_(Eta20,Xi0,s020)
                    titles = ['VIP silencing','VIP activation']
                    for itype in [0,1,2,3]:
                        plt.figure(figsize=(5,2.5))
                        for iyy,yy in enumerate([YY10s,YY20s]):
                            plt.subplot(1,2,iyy+1)
                            if np.sum(np.isnan(yy[:,itype]))==0:
                                sca.scatter_size_contrast(YY0s[:,itype],yy[:,itype],nsize=6,ncontrast=6)#,mn=0)
                            plt.title(titles[iyy])
                            plt.xlabel('cell type %d event rate, \n light off'%itype)
                            plt.ylabel('cell type %d event rate, \n light on'%itype)
                            ut.erase_top_right()
                        plt.tight_layout()
                        ut.mkdir('figures')
                        plt.savefig('figures/scatter_light_on_light_off_init_celltype_%d.eps'%itype)
                #if wt_dict['coupling'] > 0:
                #    W0list[1][1,0] = W0list[1][1,0] - 1
                for ivar in [0,1,4,5]: # Wmx, Wmy, s02, k
                    W0list[ivar] = W0list[ivar] + init_noise*np.random.randn(*W0list[ivar].shape)

            # wt_dict['Xi'] = 10
            # wt_dict['Eta'] = 10
            print('size of bounds: '+str(np.sum([np.size(x) for x in bdlist])))
            print('size of w0: '+str(np.sum([np.size(x) for x in W0list])))
            print('size of shapes: '+str(np.sum([np.prod(shp) for shp in shapes])))
            Wt[ihyper][itry],loss[ihyper][itry],gr,hess,result = calnet.fitting_spatial_feature_opto_nonlinear.fit_W_sim(Xhat,Xpc_list,Yhat,Ypc_list,pop_rate_fn=sim_utils.f_miller_troyer,pop_deriv_fn=sim_utils.fprime_miller_troyer,neuron_rate_fn=sim_utils.evaluate_f_mt,W0list=W0list.copy(),bounds=bounds,niter=niter,wt_dict=wt_dict,l2_penalty=l2_penalty,compute_hessian=False,dt=dt,perturbation_size=perturbation_size,dYY=dYY,constrain_isn=constrain_isn,tv=tv,opto_mask=opto_mask,use_opto_transforms=use_opto_transforms,opto_transform1=opto_transform1,opto_transform2=opto_transform2)
    #         Wt[ihyper][itry] = [w[-1] for w in Wt_temp]
    #         loss[ihyper,itry] = loss_temp[-1]
    
    
    def parse_W(W):
        """Unpack the 16-entry parameter sequence `W` and return it as a tuple.

        Entry order: Wmx, Wmy, Wsx, Wsy, s02, K, kappa, T, XX, XXp, Eta, Xi,
        h1, h2, Eta1, Eta2. A sequence of any other length raises ValueError
        from the unpacking.
        """
        (wmx, wmy, wsx, wsy, s02_, k_, kappa_, t_,
         xx, xxp, eta, xi, h1_, h2_, eta1, eta2) = W
        return (wmx, wmy, wsx, wsy, s02_, k_, kappa_, t_,
                xx, xxp, eta, xi, h1_, h2_, eta1, eta2)
    
    
    itry = 0
    Wmx,Wmy,Wsx,Wsy,s02,K,kappa,T,XX,XXp,Eta,Xi,h1,h2,Eta1,Eta2 = parse_W(Wt[0][0])
    
    labels = ['Wmx','Wmy','Wsx','Wsy','s02','K','kappa','T','XX','XXp','Eta','Xi','h1','h2','Eta1','Eta2']
    Wstar_dict = {}
    for i,label in enumerate(labels):
        Wstar_dict[label] = Wt[0][0][i]
    Wstar_dict['as_list'] = [Wmx,Wmy,Wsx,Wsy,s02,K,kappa,T,XX,XXp,Eta,Xi,h1,h2,Eta1,Eta2]
    Wstar_dict['loss'] = loss[0][0]
    Wstar_dict['wt_dict'] = wt_dict
    np.save(weights_file,Wstar_dict,allow_pickle=True)
コード例 #19
0
def fit_weights_and_save(weights_file,ca_data_file='rs_sc_fg_pval_0_05_210410.npy',opto_silencing_data_file='vip_halo_data_for_sim.npy',opto_activation_data_file='vip_chrimson_data_for_sim.npy',constrain_wts=None,allow_var=True,multiout=True,multiout2=False,fit_s02=True,constrain_isn=True,tv=False,l2_penalty=0.01,l1_penalty=1.0,init_noise=0.1,init_W_from_lsq=False,scale_init_by=1,init_W_from_file=False,init_file=None,foldT=False,free_amplitude=False,correct_Eta=False,init_Eta_with_s02=False,no_halo_res=False,ignore_halo_vip=False,use_opto_transforms=False,norm_opto_transforms=False,nondim=False,fit_running=False,fit_non_running=True,fit_sc=True,fit_fg=False):
    
    
    nsize,ncontrast = 6,6
    
    nrun = 2
    nsize,ncontrast,ndir = 6,6,8
    nstim_fg = 5

    fit_both_running = (fit_non_running and fit_running)
    fit_both_stims = (fit_sc and fit_fg)

    if not fit_both_running:
        nrun = 1
        if fit_non_running:
            irun = 0
        elif fit_running:
            irun = 1

    nsc = nrun*nsize*ncontrast*ndir
    nfg = nrun*nstim_fg*ndir

    npfile = np.load(ca_data_file,allow_pickle=True)[()]#,{'rs':rs},allow_pickle=True) # ,'rs_denoise':rs_denoise
    if fit_both_running: 
        Rs_mean = npfile['Rs_mean_run']
        Rs_cov = npfile['Rs_cov_run']
        if not fit_both_stims:
            if fit_sc:
                Rs_mean,Rs_cov = get_Rs_slice(Rs_mean,Rs_cov,slice(None,nsc))
            elif fit_fg:
                Rs_mean,Rs_cov = get_Rs_slice(Rs_mean,Rs_cov,slice(nsc,None))
    else:
        Rs_mean = npfile['Rs_mean'][irun]
        Rs_cov = npfile['Rs_cov'][irun]
        if not fit_both_stims:
            if fit_sc:
                Rs_mean,Rs_cov = get_Rs_slice(Rs_mean,Rs_cov,slice(None,nsc))
            elif fit_fg:
                Rs_mean,Rs_cov = get_Rs_slice(Rs_mean,Rs_cov,slice(nsc,None))
    
    ori_dirs = [[0,4],[2,6]] #[[0,4],[1,3,5,7],[2,6]]
    ndims = 5
    nT = len(ori_dirs)
    nS = len(Rs_mean[0])
    
    def sum_to_1(r):
        """Flatten `r` to 2-D (rows = first axis) and divide each row by its total.

        The row total is taken only over columns that contain no NaN in any
        row. NaN entries in the other columns remain NaN after division.
        """
        flat = r.reshape((r.shape[0], -1))
        complete_cols = ~np.isnan(flat.sum(axis=0))
        row_totals = np.nansum(flat[:, complete_cols], axis=1)
        return flat / row_totals[:, np.newaxis]
    
    def norm_to_mean(r):
        """Flatten `r` to 2-D (rows = first axis) and divide each row by its mean.

        The row mean is computed only over columns that contain no NaN in
        any row.
        """
        flat = r.reshape((r.shape[0], -1))
        complete_cols = ~np.isnan(flat.sum(axis=0))
        row_means = np.nanmean(flat[:, complete_cols], axis=1)
        return flat / row_means[:, np.newaxis]

    def ori_avg(Rs,these_ori_dirs):
        """Average a response vector over the direction indices `these_ori_dirs`.

        Reads the layout sizes (nrun, nsize, ncontrast, ndir, nstim_fg, nsc),
        the fit_sc / fit_fg flags, `ssi`, and the smoothing `kernel` from the
        enclosing scope at call time.

        Returns a 1-D array made of two parts. The size/contrast part
        (nrun*nsize*ncontrast entries) is present only when fit_sc is set.
        The fg part (nrun*nstim_fg entries) follows it and is present only
        when fit_fg is set.

        Raises ValueError if neither fit_sc nor fit_fg is set. Previously
        this case failed with an UnboundLocalError on `rs_sc`.
        """
        if not (fit_sc or fit_fg):
            raise ValueError('ori_avg requires fit_sc and/or fit_fg to be True')
        if fit_sc:
            rs_sc = np.nanmean(Rs[:nsc].reshape((nrun,nsize,ncontrast,ndir))[:,:,:,these_ori_dirs],-1)
            # 'valid' convolution with the enclosing kernel shrinks the last two
            # axes by one, so only entries [:,1:,1:] are replaced. The first
            # size index and first contrast index keep their unsmoothed values.
            rs_sc[:,1:,1:] = ssi.convolve(rs_sc,kernel,'valid')
            rs_sc = rs_sc.reshape((nrun*nsize*ncontrast))
            if fit_fg:
                rs_fg = np.nanmean(Rs[nsc:].reshape((nrun,nstim_fg,ndir))[:,:,these_ori_dirs],-1)
                rs_fg = rs_fg.reshape((nrun*nstim_fg))
            else:
                rs_fg = np.zeros((0,))
        else:
            # fg-only: the enclosing code has already sliced Rs down to the fg
            # part (slice(nsc, None)), so the whole of Rs is reshaped here.
            rs_sc = np.zeros((0,))
            rs_fg = np.nanmean(Rs.reshape((nrun,nstim_fg,ndir))[:,:,these_ori_dirs],-1)
            rs_fg = rs_fg.reshape((nrun*nstim_fg))
        Rso = np.concatenate((rs_sc,rs_fg))
        return Rso
    
    Rso_mean = [[[None for iT in range(nT)] for iS in range(nS)] for icelltype in range(len(Rs_mean))]
    Rso_cov = [[[[[None,None] for idim in range(ndims)] for iT in range(nT)] for iS in range(nS)] for icelltype in range(len(Rs_mean))]
    
    kernel = np.ones((1,2,2))
    kernel = kernel/kernel.sum()
    
    for iR,r in enumerate(Rs_mean):
        for ialign in range(nS):
            for iori in range(nT):
                Rso_mean[iR][ialign][iori] = ori_avg(Rs_mean[iR][ialign],ori_dirs[iori])
                for idim in range(ndims):
                    Rso_cov[iR][ialign][iori][idim][0] = Rs_cov[iR][ialign][idim][0]
                    Rso_cov[iR][ialign][iori][idim][1] = ori_avg(Rs_cov[iR][ialign][idim][1],ori_dirs[iori])

    def set_bound(bd,code,val=0):
        """Write `val` into `bd[i]` wherever the boolean mask `code[i]` is True.

        Each array in `bd` is modified in place; nothing is returned.
        `code` must have at least as many masks as `bd` has arrays.
        """
        for idx, bound_arr in enumerate(bd):
            mask = code[idx]
            bound_arr[mask] = val
    
    nN = (36*fit_sc + 5*fit_fg)*(1 + fit_both_running)
    nS = 2
    nP = 2 + fit_both_running
    nT = 2
    nQ = 4

    ndims = 5
    ncelltypes = 5
    #print('foldT: %d'%foldT)
    if foldT:
        Yhat = [None for iS in range(nS)]
        Xhat = [None for iS in range(nS)]
        Ypc_list = [None for iS in range(nS)]
        Xpc_list = [None for iS in range(nS)]
        print('have not written this yet')
        assert(True==False)
    else:
        Yhat = [[None for iT in range(nT)] for iS in range(nS)]
        Xhat = [[None for iT in range(nT)] for iS in range(nS)]
        Ypc_list = [[None for iT in range(nT)] for iS in range(nS)]
        Xpc_list = [[None for iT in range(nT)] for iS in range(nS)]
        for iS in range(nS):
            mx = np.zeros((ncelltypes,))
            yy = [None for icelltype in range(ncelltypes)]
            for icelltype in range(ncelltypes):
                yy[icelltype] = np.concatenate(Rso_mean[icelltype][iS])
                mx[icelltype] = np.nanmax(yy[icelltype])
            for iT in range(nT):
                y = [Rso_mean[icelltype][iS][iT][:,np.newaxis]/mx[icelltype] for icelltype in range(1,ncelltypes)]
                Yhat[iS][iT] = np.concatenate(y,axis=1)
                Ypc_list[iS][iT] = [None for icelltype in range(1,ncelltypes)]
                for icelltype in range(1,ncelltypes):
                    Ypc_list[iS][iT][icelltype-1] = [(this_dim[0]/mx[icelltype],this_dim[1]) for this_dim in Rso_cov[icelltype][iS][iT]]
                icelltype = 0
                x = Rso_mean[icelltype][iS][iT][:,np.newaxis]/mx[icelltype]
                if fit_both_running:
                    run_vector = np.zeros_like(x)
                    if fit_both_stims:
                        run_vector[nsize*ncontrast:2*nsize*ncontrast] = 1
                        run_vector[-nstim_fg:] = 1
                    else:
                        run_vector[int(np.round(run_vector.shape[0]/2)):,:] = 1
                else:
                    run_vector = np.zeros((x.shape[0],0))
                Xhat[iS][iT] = np.concatenate((x,np.ones_like(x),run_vector),axis=1)
                Xpc_list[iS][iT] = [None for iinput in range(2+fit_both_running)]
                Xpc_list[iS][iT][0] = [(this_dim[0]/mx[icelltype],this_dim[1]) for this_dim in Rso_cov[icelltype][iS][iT]]
                Xpc_list[iS][iT][1] = [(0,np.zeros((Xhat[0][0].shape[0],))) for idim in range(ndims)]
                if fit_both_running:
                    Xpc_list[iS][iT][2] = [(0,np.zeros((Xhat[0][0].shape[0],))) for idim in range(ndims)]
    nN,nP = Xhat[0][0].shape
    nQ = Yhat[0][0].shape[1]
    
    # code for bounds: 0 , constrained to 0
    # +/-1 , constrained to +/-1
    # 1.5, constrained to [0,1]
    # -1.5, constrained to [-1,1]
    # 2 , constrained to [0,inf)
    # -2 , constrained to (-inf,0]
    # 3 , unconstrained
    
    W0x_bounds = 3*np.ones((nP,nQ),dtype=int)
    W0x_bounds[0,:] = 2 # L4 PCs are excitatory
    W0x_bounds[0,1] = 0 # SSTs don't receive L4 input 
    
    if allow_var:
        if nondim:
            W1x_bounds = -1.5*np.ones(W0x_bounds.shape) #W0x_bounds.copy()*0 #np.zeros_like(W0x_bounds)
        else:
            W1x_bounds = 3*np.ones(W0x_bounds.shape) #W0x_bounds.copy()*0 #np.zeros_like(W0x_bounds)
        W1x_bounds[0,1] = 0
    else:
        W1x_bounds = np.zeros(W0x_bounds.shape) #W0x_bounds.copy()*0 #np.zeros_like(W0x_bounds)
    
    W0y_bounds = 3*np.ones((nQ,nQ),dtype=int)
    W0y_bounds[0,:] = 2 # PCs are excitatory
    W0y_bounds[1:,:] = -2 # all the cell types except PCs are inhibitory
    W0y_bounds[1,1] = 0 # SSTs don't inhibit themselves
    # W0y_bounds[3,1] = 0 # PVs are allowed to inhibit SSTs, consistent with Hillel's unpublished results, but not consistent with Pfeffer et al.
    W0y_bounds[2,0] = 0 # VIPs don't inhibit L2/3 PCs. According to Pfeffer et al., only L5 PCs were found to get VIP inhibition
    W0y_bounds[2,2] = 0 # newly added: no VIP-VIP inhibition 


    if not constrain_wts is None:
        for wt in constrain_wts:
            W0y_bounds[wt[0],wt[1]] = 0
            W1y_bounds[wt[0],wt[1]] = 0
    
    def tile_nS_nT_nN(kernel):
        """Repeat `kernel` nS*nT times along axis 0, then stack that row nN times.

        For a 1-D kernel of length k the result has shape (nN, nS*nT*k).
        nS, nT and nN are read from the enclosing scope at call time.
        """
        one_row = np.concatenate([kernel]*(nS*nT), axis=0)
        return np.concatenate([one_row[np.newaxis, :]]*nN, axis=0)
    
    if fit_s02:
        s02_bounds = 2*np.ones((nQ,)) # permitting noise as a free parameter
    else:
        s02_bounds = np.ones((nQ,))
    
    Kin0_bounds = 1.5*np.ones((nQ,))
    
    kappa_bounds = np.ones((1,))
    # kappa_bounds = 2*np.ones((1,))
    
    Tin0_bounds = 1.5*np.ones((nQ,))
    #T_bounds[2:4] = 1 # PV and VIP are constrained to have flat ori tuning
    #Tin0_bounds[1:4] = 1 # SST,VIP, and PV are constrained to have flat ori tuning

    if nondim:
        kt_factor = -1.5
    else:
        kt_factor = 3

    if allow_var:
        W1y_bounds = kt_factor*np.ones(W0y_bounds.shape) #W0y_bounds.copy()*0 #np.zeros_like(W0y_bounds)
        Kin1_bounds = kt_factor*np.ones(Kin0_bounds.shape) #W0y_bounds.copy()*0 #np.zeros_like(W0y_bounds)
        Tin1_bounds = kt_factor*np.ones(Tin0_bounds.shape) #W0y_bounds.copy()*0 #np.zeros_like(W0y_bounds)
        W1y_bounds[1,1] = 0
        #W1y_bounds[3,1] = 0 
        W1y_bounds[2,0] = 0
        W1y_bounds[2,2] = 0 # newly added: no VIP-VIP inhibition
    else:
        W1y_bounds = np.zeros(W0y_bounds.shape) #W0y_bounds.copy()*0 #np.zeros_like(W0y_bounds)
        Kin1_bounds = 0*np.ones(Kin0_bounds.shape) #W0y_bounds.copy()*0 #np.zeros_like(W0y_bounds)
        Tin1_bounds = 0*np.ones(Tin0_bounds.shape) #W0y_bounds.copy()*0 #np.zeros_like(W0y_bounds)

    if multiout:
        W2x_bounds = W1x_bounds.copy()
        W2y_bounds = W1y_bounds.copy()
        if multiout2:
            W3x_bounds = W1x_bounds.copy()
            W3y_bounds = W1y_bounds.copy()
        else:
            W3x_bounds = W1x_bounds.copy()*0
            W3y_bounds = W1y_bounds.copy()*0
    else:
        W2x_bounds = W1x_bounds.copy()*0
        W2y_bounds = W1y_bounds.copy()*0
        W3x_bounds = W1x_bounds.copy()*0
        W3y_bounds = W1y_bounds.copy()*0

    Kxout0_bounds = np.array((1.5,)+tuple(np.zeros((nP-1,))))
    Txout0_bounds = Kxout0_bounds.copy()
    Kxout1_bounds = np.array((kt_factor,)+tuple(np.zeros((nP-1,))))
    Txout1_bounds = Kxout1_bounds.copy() 

    Kyout0_bounds = Kin0_bounds.copy()
    Tyout0_bounds = Tin0_bounds.copy()
    Kyout1_bounds = Kin1_bounds.copy()
    Tyout1_bounds = Tin1_bounds.copy()
    
    if fit_both_running:
        to_tile = Xhat[0][0][:,1:]
        to_tile = np.concatenate((2*np.ones((to_tile.shape[0],1)),to_tile),axis=1)
        X_bounds = np.tile(to_tile,(1,nS*nT))
    else:
        X_bounds = tile_nS_nT_nN(np.array([2,1]))
    #print(X_bounds.shape)
    # X_bounds = np.array([np.array([2,1,2,1])]*nN)
    
    if fit_both_running:
        Xp_bounds = tile_nS_nT_nN(np.array([3,0,0])) # edited to set XXp to 0 for spont. term
    else:
        Xp_bounds = tile_nS_nT_nN(np.array([3,0])) # edited to set XXp to 0 for spont. term
    # Xp_bounds = np.array([np.array([3,1,3,1])]*nN)
    
    # Y_bounds = tile_nS_nT_nN(2*np.ones((nQ,)))
    # # Y_bounds = 2*np.ones((nN,nT*nS*nQ))
    
    Eta_bounds = tile_nS_nT_nN(3*np.ones((nQ,)))
    # Eta_bounds = 3*np.ones((nN,nT*nS*nQ))
    
    #if allow_var:
    #    Xi_bounds = tile_nS_nT_nN(3*np.ones((nQ,)))
    #else:
    #    Xi_bounds = tile_nS_nT_nN(np.zeros((nQ,)))
    Xi_bounds = tile_nS_nT_nN(3*np.ones((nQ,))) # temporarily allowing Xi even if W1 is not allowed

    # Xi_bounds = 3*np.ones((nN,nT*nS*nQ))
    
    h1_bounds = -2*np.ones((1,))

    h2_bounds = 2*np.ones((1,))

    bl_bounds = 3*np.ones((nQ,))

    if free_amplitude:
        amp_bounds = 2*np.ones((nT*nS*nQ,))
    else:
        amp_bounds = 1*np.ones((nT*nS*nQ,))
    
    # shapes = [(nP,nQ),(nQ,nQ),(nP,nQ),(nQ,nQ),(nQ,),(nQ,),(1,),(nN,nS*nP),(nN,nS*nQ),(nN,nS*nQ),(nN,nS*nQ)]
    #shapes = [(nP,nQ),(nQ,nQ),(nP,nQ),(nQ,nQ),(nP,nQ),(nQ,nQ),(nP,nQ),(nQ,nQ),(nQ,),(nQ*(nS-1),),(nQ*(nS-1),),(nQ*(nS-1),),(nQ*(nS-1),),(1,),(nQ*(nT-1),),(nQ*(nT-1),),(nQ*(nT-1),),(nQ*(nT-1),),(nN,nT*nS*nP),(nN,nT*nS*nP),(nN,nT*nS*nQ),(nN,nT*nS*nQ),(1,)]
    shapes1 = [(nP,nQ),(nQ,nQ),(nP,nQ),(nQ,nQ),(nP,nQ),(nQ,nQ),(nP,nQ),(nQ,nQ),(nQ,),(nQ*(nS-1),),(nQ*(nS-1),),(nP*(nS-1),),(nQ*(nS-1),),(nP*(nS-1),),(nQ*(nS-1),),(1,),(nQ*(nT-1),),(nQ*(nT-1),),(nP*(nT-1),),(nQ*(nT-1),),(nP*(nT-1),),(nQ*(nT-1),),(1,),(1,),(nQ,),(nT*nS*nQ,)]
    shapes2 = [(nN,nT*nS*nP),(nN,nT*nS*nP),(nN,nT*nS*nQ),(nN,nT*nS*nQ)]
    #         W0x,    W0y,    W1x,    W1y,    W2x,    W2y,    W3x,    W3y,    s02,  k,    kappa,T,   XX,            XXp,          Eta,          Xi
    
    #lb = [-np.inf*np.ones(shp) for shp in shapes]
    #ub = [np.inf*np.ones(shp) for shp in shapes]
    #bdlist = [W0x_bounds,W0y_bounds,W1x_bounds,W1y_bounds,W2x_bounds,W2y_bounds,W3x_bounds,W3y_bounds,s02_bounds,k0_bounds,k1_bounds,k2_bounds,k3_bounds,kappa_bounds,Tin0_bounds,Tin1_bounds,Tout0_bounds,Tout1_bounds,X_bounds,Xp_bounds,Eta_bounds,Xi_bounds,h_bounds]
    bd1list = [W0x_bounds,W0y_bounds,W1x_bounds,W1y_bounds,W2x_bounds,W2y_bounds,W3x_bounds,W3y_bounds,s02_bounds,Kin0_bounds,Kin1_bounds,Kxout0_bounds,Kyout0_bounds,Kxout1_bounds,Kyout1_bounds,kappa_bounds,Tin0_bounds,Tin1_bounds,Txout0_bounds,Tyout0_bounds,Txout1_bounds,Tyout1_bounds,h1_bounds,h2_bounds,bl_bounds,amp_bounds]
    bd2list = [X_bounds,Xp_bounds,Eta_bounds,Xi_bounds]

    lb1,ub1 = [[sgn*np.inf*np.ones(shp) for shp in shapes1] for sgn in [-1,1]]
    lb1,ub1 = calnet.utils.set_bounds_by_code(lb1,ub1,bd1list)
    lb2,ub2 = [[sgn*np.inf*np.ones(shp) for shp in shapes2] for sgn in [-1,1]]
    lb2,ub2 = calnet.utils.set_bounds_by_code(lb2,ub2,bd2list)

    lb1 = np.concatenate([a.flatten() for a in lb1])
    ub1 = np.concatenate([b.flatten() for b in ub1])
    lb2 = np.concatenate([a.flatten() for a in lb2])
    ub2 = np.concatenate([b.flatten() for b in ub2])
    bounds1 = [(a,b) for a,b in zip(lb1,ub1)]
    bounds2 = [(a,b) for a,b in zip(lb2,ub2)]
    
    
    def compute_f_(Eta,Xi,s02):
        return sim_utils.f_miller_troyer(Eta,Xi**2+np.concatenate([s02 for ipixel in range(nS*nT)]))
    def compute_fprime_m_(Eta,Xi,s02):
        return sim_utils.fprime_miller_troyer(Eta,Xi**2+np.concatenate([s02 for ipixel in range(nS*nT)]))*Xi
    def compute_fprime_s_(Eta,Xi,s02):
        s2 = Xi**2+np.concatenate((s02,s02),axis=0)
        return sim_utils.fprime_s_miller_troyer(Eta,s2)*(Xi/s2)
    def sorted_r_eigs(w):
        drW,prW = np.linalg.eig(w)
        srtinds = np.argsort(drW)
        return drW[srtinds],prW[:,srtinds]
    
    #0.W0x,1.W0y,2.W1x,3.W1y,4.W2x,5.W2y,6.W3x,7.W3y,8.s02,9.Kin0,10.Kin1,11.Kout0,12.Kout1,13.kappa,14.Tin0,15.Tin1,16.Txout0,Tyout0,17.Txout1,Tyout1,18.h1,19.h2,20.bl,21.amp
    #0.XX,1.XXp,2.Eta,3.Xi
    
    #shapes = [(nP,nQ),(nQ,nQ),(nP,nQ),(nQ,nQ),(nP,nQ),(nQ,nQ),(nP,nQ),(nQ,nQ),(nQ,),(nQ*(nS-1),),(nQ*(nS-1),),(nQ*(nS-1),),(nQ*(nS-1),),(1,),(nQ*(nT-1),),(nQ*(nT-1),),(nQ*(nT-1),),(nQ*(nT-1),),(nN,nT*nS*nP),(nN,nT*nS*nP),(nN,nT*nS*nQ),(nN,nT*nS*nQ),(1,)]
    #shapes1 = [(nP,nQ),(nQ,nQ),(nP,nQ),(nQ,nQ),(nP,nQ),(nQ,nQ),(nP,nQ),(nQ,nQ),(nQ,),(nQ*(nS-1),),(nQ*(nS-1),),(nQ*(nS-1),),(nQ*(nS-1),),(1,),(nQ*(nT-1),),(nQ*(nT-1),),(nQ*(nT-1),),(nQ*(nT-1),),(1,),(1,),(nQ,),(nT*nS*nQ,)]
    #shapes2 = [(nN,nT*nS*nP),(nN,nT*nS*nP),(nN,nT*nS*nQ),(nN,nT*nS*nQ)]
    
    import sim_utils

    YYhat = calnet.utils.flatten_nested_list_of_2d_arrays(Yhat)
    
    opto_dict = np.load(opto_silencing_data_file,allow_pickle=True)[()]

    Yhat_opto = opto_dict['Yhat_opto']
    Yhat_opto = np.ones((nN*2,nQ*nS*nT))
    #Yhat_opto = Yhat_opto.reshape((nN*2,-1))
    Yhat_opto[0::12] = np.nanmean(Yhat_opto[0::12],axis=0)[np.newaxis]
    Yhat_opto[1::12] = np.nanmean(Yhat_opto[1::12],axis=0)[np.newaxis]
    Yhat_opto = Yhat_opto/np.nanmax(Yhat_opto[0::2],0)[np.newaxis,:]
    #print(Yhat_opto.shape)
    h_opto = np.zeros((nN*2,))
    #h_opto = opto_dict['h_opto']
    #dYY1 = Yhat_opto[1::2]-Yhat_opto[0::2]

    YYhat_halo = Yhat_opto.reshape((nN,2,-1))
    opto_transform1 = calnet.utils.fit_opto_transform(YYhat_halo,norm01=norm_opto_transforms)

    if no_halo_res:
        opto_transform1.res[:,[0,2,3,4,6,7]] = 0

    dYY1 = opto_transform1.transform(YYhat) - opto_transform1.preprocess(YYhat)
    #print('delta bias: %f'%dXX1[:,1].mean())
    #YYhat_halo_sim = calnet.utils.simulate_opto_effect(YYhat,YYhat_halo)
    #dYY1 = YYhat_halo_sim[:,1,:] - YYhat_halo_sim[:,0,:]

    def overwrite_plus_n(arr,to_overwrite,n):
        arr[:,to_overwrite] = arr[:,int(to_overwrite+n)]
        return arr

    for to_overwrite in [1,2]:
        n = 4
        dYY1,opto_transform1.slope,opto_transform1.intercept,opto_transform1.res \
                = [overwrite_plus_n(x,to_overwrite,n) for x in \
                        [dYY1,opto_transform1.slope,opto_transform1.intercept,opto_transform1.res]]
    for to_overwrite in [7]:
        n = -4
        dYY1,opto_transform1.slope,opto_transform1.intercept,opto_transform1.res \
                = [overwrite_plus_n(x,to_overwrite,n) for x in \
                        [dYY1,opto_transform1.slope,opto_transform1.intercept,opto_transform1.res]]

    opto_dict = np.load(opto_activation_data_file,allow_pickle=True)[()]

    Yhat_opto = opto_dict['Yhat_opto']
    Yhat_opto = np.ones((nN*2,nQ*nS*nT))
    #Yhat_opto = Yhat_opto.reshape((nN*2,-1))
    Yhat_opto[0::12] = np.nanmean(Yhat_opto[0::12],axis=0)[np.newaxis]
    Yhat_opto[1::12] = np.nanmean(Yhat_opto[1::12],axis=0)[np.newaxis]
    Yhat_opto = Yhat_opto/Yhat_opto[0::2].max(0)[np.newaxis,:]
    #print(Yhat_opto.shape)
    h_opto = np.zeros((nN*2,))
    #h_opto = opto_dict['h_opto']
    #dYY2 = Yhat_opto[1::2]-Yhat_opto[0::2]

    YYhat_chrimson = Yhat_opto.reshape((nN,2,-1))
    opto_transform2 = calnet.utils.fit_opto_transform(YYhat_chrimson,norm01=norm_opto_transforms)

    dYY2 = opto_transform2.transform(YYhat) - opto_transform2.preprocess(YYhat)

    dYY = np.concatenate((dYY1,dYY2),axis=0)

    if ignore_halo_vip:
        dYY1[:,2::nQ] = np.nan
    
    from importlib import reload
    reload(calnet)
    reload(calnet.fitting_2step_spatial_feature_opto_multiout_axon_nonlinear)
    reload(sim_utils)
    wt_dict = {}
    wt_dict['X'] = 1
    wt_dict['Y'] = 3
    wt_dict['Eta'] = 3# 10
    wt_dict['Xi'] = 3
    wt_dict['stims'] = np.ones((nN,1)) #(np.arange(30)/30)[:,np.newaxis]**1 #
    wt_dict['barrier'] = 0. #30.0 #0.1
    wt_dict['opto'] = 0#1e0#1e-1#1e1
    wt_dict['smi'] = 0
    wt_dict['isn'] = 0.1
    wt_dict['tv'] = 1


    YYhat = calnet.utils.flatten_nested_list_of_2d_arrays(Yhat)
    XXhat = calnet.utils.flatten_nested_list_of_2d_arrays(Xhat)
    Eta0 = invert_f_mt(YYhat)
    Xi0 = invert_fprime_mt(Ypc_list,Eta0,nN=nN,nQ=nQ,nS=nS,nT=nT,foldT=foldT)

    ntries = 1
    nhyper = 1
    dt = 1e-1
    niter = int(np.round(10/dt)) #int(1e4)
    perturbation_size = 5e-2
    W1t = [[None for itry in range(ntries)] for ihyper in range(nhyper)]
    W2t = [[None for itry in range(ntries)] for ihyper in range(nhyper)]
    loss = np.zeros((nhyper,ntries))
    is_neg = np.array([b[1] for b in bounds1])==0
    counter = 0
    negatize = [np.zeros(shp,dtype='bool') for shp in shapes1]
    for ishp,shp in enumerate(shapes1):
        nel = np.prod(shp)
        negatize[ishp][:][is_neg[counter:counter+nel].reshape(shp)] = True
        counter = counter + nel
    for ihyper in range(nhyper):
        for itry in range(ntries):
            print((ihyper,itry))
            W10list = [init_noise*(ihyper+1)*np.random.rand(*shp) for shp in shapes1]
            W20list = [init_noise*(ihyper+1)*np.random.rand(*shp) for shp in shapes2]
            counter = 0
            for ishp,shp in enumerate(shapes1):
                W10list[ishp][negatize[ishp]] = -W10list[ishp][negatize[ishp]]
            nextraW = 4
            nextraK = nextraW + 3
            nextraT = nextraK + 3
            #Wstar_dict['as_list'] = [W0x,W0y,W1x,W1y,W2x,W2y,W3x,W3y,s02,Kin0,Kin1,Kxout0,Kyout0,Kxout1,Kyout1,kappa,Tin0,Tin1,Txout0,Tyout0,Txout1,Tyout1,XX,XXp,Eta,Xi,h1,h2,bl,amp]#,h2
            init_val = [1,1,0,0,0,0,0,0,1,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,1]
            W10list = [iv*np.ones(shp) for iv,shp in zip(init_val,shapes1)]
            #W10list[nextraW+4] = np.ones(shapes1[nextraW+4]) # s02
            #W10list[nextraW+5] = np.ones(shapes1[nextraW+5]) # K
            #W10list[nextraW+6] = np.ones(shapes1[nextraW+6]) # K
            #W10list[nextraW+7] = np.zeros(shapes1[nextraW+7]) # K
            #W10list[nextraW+8] = np.zeros(shapes1[nextraW+8]) # K
            #W10list[nextraK+6] = np.ones(shapes1[nextraK+6]) # kappa
            #W10list[nextraK+7] = np.ones(shapes1[nextraK+7]) # T
            #W10list[nextraK+8] = np.ones(shapes1[nextraK+8]) # T
            #W10list[nextraK+9] = np.zeros(shapes1[nextraK+9]) # T
            #W10list[nextraK+10] = np.zeros(shapes1[nextraK+10]) # T
            W20list[0] = XXhat #np.concatenate(Xhat,axis=1) #XX
            W20list[1] = get_pc_dim(Xpc_list,nN=nN,nPQ=nP,nS=nS,nT=nT,idim=0,foldT=foldT) #XXp
            W20list[2] = Eta0 #np.zeros(shapes[nextraT+10]) #Eta
            W20list[3] = Xi0 #Xi
            #print(XXhat.shape)
            isn_init = np.array(((3,5),(-5,-5)))
            if init_W_from_lsq:
                # shapes1
                #0.W0x,1.W0y,2.W1x,3.W1y,4.W2x,5.W2y,6.W3x,7.W3y,8.s02,9.Kin0,10.Kin1,11.Kout0,12.Kout1,13.kappa,14.Tin0,15.Tin1,16.Txout0,Tyout0,17.Txout1,Tyout1,18.h1,19.h2,20.bl,21.amp
                # shapes2
                #0.XX,1.XXp,2.Eta,3.Xi
                #W0x,W0y,W1x,W1y,W2x,W2y,W3x,W3y,Kin0,Kin1,Tin0,Tin1 = initialize_Ws(Xhat,Yhat,Xpc_list,Ypc_list,scale_by=1)
                nvar,nxy = 4,2
                freeze_vals = [[None for _ in range(nxy)] for _ in range(nvar)]
                lams = 1e5*np.array((0,1,1,1,0,1,0,1))
                for ivar in range(nvar):
                    for ixy in range(nxy):
                        iflat = np.ravel_multi_index((ivar,ixy),(nvar,nxy))
                        freeze_vals[ivar][ixy] = np.zeros(bd1list[iflat].shape)
                        freeze_vals[ivar][ixy][bd1list[iflat]==0] = np.nan
                if constrain_isn:
                    freeze_vals[0][1][slice(0,None,3)][:,slice(0,None,3)] = isn_init
                #Wlist = [W0x,W0y,W1x,W1y,W2x,W2y,W3x,W3y,Kin0,Kin1,Kxout0,Kyout0,Kxout1,Kyout1,Tin0,Tin1,Txout0,Tyout0,Txout1,Tyout1]
                # W1list = [W0x,W0y,W1x,W1y,W2x,W2y,W3x,W3y,s02,Kin0,Kin1,Kxout0,Kyout0,Kxout1,Kyout1,kappa,Tin0,Tin1,Txout0,Tyout0,Txout1,Tyout1,h1,h2,bl,amp]#,h2
                # W0,W1,W2,W3,Kin0,Kin1,Tin0,Tin1
                thisWlist = initialize_Ws(Xhat,Yhat,Xpc_list,Ypc_list,scale_by=1,freeze_vals=freeze_vals,lams=lams,foldT=foldT)
                Winds = [0,1,2,3,4,5,6,7,9,10,11,12,13,14,16,17,18,19,20,21]
                for ivar,Wind in enumerate(Winds):
                    W10list[Wind] = thisWlist[ivar]
                #W10list[0],W10list[1] = initialize_W(Xhat,Yhat,scale_by=scale_init_by)
                for Wind in Winds:
                    W10list[Wind] = W10list[Wind] + init_noise*np.random.randn(*W10list[Wind].shape)
            else:
                if constrain_isn:
                    W10list[1][slice(0,None,3)][:,slice(0,None,3)] = isn_init
                    #W10list[1][0,0] = 3 
                    #W10list[1][0,3] = 5 
                    #W10list[1][3,0] = -5
                    #W10list[1][3,3] = -5
            np.save('/home/dan/calnet_data/W0list.npy',{'W10list':W10list,'W20list':W20list,'bd1list':bd1list,'bd2list':bd2list,'freeze_vals':freeze_vals,'bounds1':bounds1,'bounds2':bounds2},allow_pickle=True)

            if init_W_from_file:
                # did not adjust this yet
                npyfile = np.load(init_file,allow_pickle=True)[()]
                print(len(npyfile['as_list']))
                print([w.shape for w in npyfile['as_list']])
                W10list = [npyfile['as_list'][ivar] for ivar in [0,1,2,3,4,5,6,7,12]]
                W20list = [npyfile['as_list'][ivar] for ivar in [8,9,10,11]]
                if correct_Eta:
                    #assert(True==False)
                    W20list[2] = Eta0.copy()
                if len(W10list) < len(shapes1):
                    #assert(True==False)
                    W10list = W10list + [np.array(1),np.zeros((nQ,)),np.zeros((nT*nS*nQ,))] # add bl, amp #np.array(1), #h2, 
                #W10 = unparse_W(W10list)
                #W20 = unparse_W(W20list)
                opt = fmc.gen_opt()
                #resEta0,resXi0 = fmc.compute_res(W10,W20,opt)
                if init_W1xy_with_res:
                    W1x0,W1y0,Kin10,Tin10 = optimize_W1xy(W10list,W20list,opt)
                    W0list[2] = W1x0
                    W0list[3] = W1y0
                    W0list[10] = Kin10
                    W0list[15] = Tin10
                if init_W2xy_with_res:
                    W2x0,W2y0 = optimize_W2xy(W10list,W20list,opt)
                    W0list[4] = W2x0
                    W0list[5] = W2y0
                if init_Eta_with_s02:
                    #assert(True==False)
                    s02 = W10list[4].copy()
                    Eta0 = invert_f_mt_with_s02(YYhat,s02,nS=nS,nT=nT)
                    W20list[2] = Eta0.copy()
                for ivar in [0,1,4,5]: # Wmx, Wmy, s02, k
                    print(init_noise)
                    W10list[ivar] = W10list[ivar] + init_noise*np.random.randn(*W10list[ivar].shape)
                #W0list = npyfile['as_list']

                extra_Ws = [np.zeros_like(W10list[ivar]) for ivar in range(2)]
                extra_ks = [np.zeros_like(W10list[5]) for ivar in range(3)]
                extra_Ts = [np.zeros_like(W10list[7]) for ivar in range(3)]
                W10list = W10list[:4] + extra_Ws*2 + W10list[4:6] + extra_ks + W10list[6:8] + extra_Ts + W10list[8:]

            print(len(W10list))
            W1t[ihyper][itry],W2t[ihyper][itry],loss[ihyper][itry],gr,hess,result = calnet.fitting_2step_spatial_feature_opto_multiout_axon_nonlinear.fit_W_sim(Xhat,Xpc_list,Yhat,Ypc_list,pop_rate_fn=sim_utils.f_miller_troyer,pop_deriv_fn=sim_utils.fprime_miller_troyer,neuron_rate_fn=sim_utils.evaluate_f_mt,W10list=W10list.copy(),W20list=W20list.copy(),bounds1=bounds1,bounds2=bounds2,niter=niter,wt_dict=wt_dict,l2_penalty=l2_penalty,l1_penalty=l1_penalty,compute_hessian=False,dt=dt,perturbation_size=perturbation_size,dYY=dYY,constrain_isn=constrain_isn,tv=tv,foldT=foldT,use_opto_transforms=use_opto_transforms,opto_transform1=opto_transform1,opto_transform2=opto_transform2,nondim=nondim)
    
    #def parse_W(W):
    #    W0x,W0y,W1x,W1y,W2x,W2y,W3x,W3y,s02,Kin0,Kin1,Kout0,Kout1,kappa,Tin0,Tin1,Txout0,Tyout0,Txout1,Tyout1,XX,XXp,Eta,Xi,h = W
    #    return W0x,W0y,W1x,W1y,W2x,W2y,W3x,W3y,s02,Kin0,Kin1,Kout0,Kout1,kappa,Tin0,Tin1,Txout0,Tyout0,Txout1,Tyout1,XX,XXp,Eta,Xi,h
    def parse_W1(W):
        W0x,W0y,W1x,W1y,W2x,W2y,W3x,W3y,s02,Kin0,Kin1,Kxout0,Kyout0,Kxout1,Kyout1,kappa,Tin0,Tin1,Txout0,Tyout0,Txout1,Tyout1,h1,h2,bl,amp = W #h2,
        return W0x,W0y,W1x,W1y,W2x,W2y,W3x,W3y,s02,Kin0,Kin1,Kxout0,Kyout0,Kxout1,Kyout1,kappa,Tin0,Tin1,Txout0,Tyout0,Txout1,Tyout1,h1,h2,bl,amp #h2,
    def parse_W2(W):
        XX,XXp,Eta,Xi = W
        return XX,XXp,Eta,Xi    

    def unparse_W(Ws):
        return np.concatenate([ww.flatten() for ww in Ws])
    
    itry = 0
    W0x,W0y,W1x,W1y,W2x,W2y,W3x,W3y,s02,Kin0,Kin1,Kxout0,Kyout0,Kxout1,Kyout1,kappa,Tin0,Tin1,Txout0,Tyout0,Txout1,Tyout1,h1,h2,bl,amp = parse_W1(W1t[0][0])#h2,
    XX,XXp,Eta,Xi = parse_W2(W2t[0][0])
    
    #labels = ['W0x','W0y','W1x','W1y','W2x','W2y','W3x','W3y','s02','Kin0','Kin1','Kout0','Kout1','kappa','Tin0','Tin1','Tout0','Tout1','XX','XXp','Eta','Xi','h']
    labels1 = ['W0x','W0y','W1x','W1y','W2x','W2y','W3x','W3y','s02','Kin0','Kin1','Kxout0','Kyout0','Kxout1','Kyout1','kappa','Tin0','Tin1','Txout0','Tyout0','Txout1','Tyout1','h1','h2','bl','amp']#,'h2'
    labels2 = ['XX','XXp','Eta','Xi']
    Wstar_dict = {}
    for i,label in enumerate(labels1):
        Wstar_dict[label] = W1t[0][0][i]
    for i,label in enumerate(labels2):
        Wstar_dict[label] = W2t[0][0][i]
    #Wstar_dict = {}
    #for i,label in enumerate(labels):
    #    Wstar_dict[label] = W1t[0][0][i]
    Wstar_dict['as_list'] = [W0x,W0y,W1x,W1y,W2x,W2y,W3x,W3y,s02,Kin0,Kin1,Kxout0,Kyout0,Kxout1,Kyout1,kappa,Tin0,Tin1,Txout0,Tyout0,Txout1,Tyout1,XX,XXp,Eta,Xi,h1,h2,bl,amp]#,h2
    Wstar_dict['loss'] = loss[0][0]
    Wstar_dict['wt_dict'] = wt_dict
    np.save(weights_file,Wstar_dict,allow_pickle=True)
コード例 #20
0
ファイル: sim_from_gp_1D.py プロジェクト: natalieklein/gpcsd
# kCSD panel: normalized kcsd_values at index 0 of the last axis, with the
# first and last entries of the first axis dropped (matching the [1:-1]
# trimming applied to the tCSD / GPCSD estimates below); colour scale fixed
# to [-1, 1].
plt.imshow(normalize(kcsd_values[1:-1, :, 0]),
           vmin=-1,
           vmax=1,
           cmap='bwr',
           aspect='auto')
plt.title('kCSD')
plt.xlabel('Time')
# Fourth panel of the 1x4 grid: LFP (not normalized) at index 50 of the last
# axis.
# NOTE(review): the estimates are compared below against
# csd_interior_electrodes[..., 50:], so index 50 here presumably corresponds
# to index 0 of kcsd_values above -- confirm.
plt.subplot(144)
plt.imshow(lfp[:, :, 50], vmin=-1, vmax=1, cmap='bwr', aspect='auto')
plt.title('LFP')
plt.xlabel('Time')
plt.tight_layout()

# %% Compute MSE -- mean squared error across space/time
# Each *_meansqerr is the NaN-ignoring mean of the squared difference between
# a normalized estimate and the normalized ground truth, averaged over the
# first two axes -- i.e. one value per index of the last axis. The ground
# truth uses indices 50 onwards of its last axis.
tcsd_meansqerr = np.nanmean(np.square(
    normalize(tcsd_pred[1:-1, :, :]) -
    normalize(csd_interior_electrodes[1:-1, :, 50:])),
                            axis=(0, 1))
gpcsd_meansqerr = np.nanmean(np.square(
    normalize(gpcsd_model.csd_pred[1:-1, :, :]) -
    normalize(csd_interior_electrodes[1:-1, :, 50:])),
                             axis=(0, 1))
# NOTE(review): kCSD trims two entries from each end of the first axis here
# (vs. one for the other estimates and for the kCSD plot above); presumably
# kcsd_values carries one extra entry at each end -- confirm the alignment.
kcsd_meansqerr = np.nanmean(np.square(
    normalize(kcsd_values[2:-2, :, :]) -
    normalize(csd_interior_electrodes[1:-1, :, 50:])),
                            axis=(0, 1))

tcsd_rsq = 1 - np.sum(np.square(
    normalize(tcsd_pred[1:-1, :, :]) -
    normalize(csd_interior_electrodes[1:-1, :, 50:])),
                      axis=(0, 1)) / np.sum(np.square(
                          normalize(csd_interior_electrodes[1:-1, :, 50:])),
コード例 #21
0
           vmin=-1,
           vmax=1,
           cmap='bwr',
           aspect='auto')
plt.title('Trad CSD prediction')
plt.xlabel('Time (ms)')
# Fourth panel of the 1x4 grid: normalized lfp3 at index 0 of the last axis,
# colour scale fixed to [-1, 1].
plt.subplot(144)
plt.imshow(normalize(lfp3[:, :, 0]),
           vmin=-1,
           vmax=1,
           cmap='bwr',
           aspect='auto')
plt.title('LFP')
plt.xlabel('Time (ms)')
plt.tight_layout()

# %%
# Per-index (last axis) MSE of the normalized GPCSD predictions of
# gpcsd_model2 / gpcsd_model3 against the normalized
# csd2_ / csd3_interior_electrodes ground truth. NaNs are ignored, the
# average runs over the first two axes, and the first/last entries of the
# first axis are dropped. Unlike the earlier cell, the whole last axis of the
# ground truth is used.
gpcsd2_meansqerr = np.nanmean(np.square(
    normalize(gpcsd_model2.csd_pred[1:-1, :, :]) -
    normalize(csd2_interior_electrodes[1:-1, :, :])),
                              axis=(0, 1))
gpcsd3_meansqerr = np.nanmean(np.square(
    normalize(gpcsd_model3.csd_pred[1:-1, :, :]) -
    normalize(csd3_interior_electrodes[1:-1, :, :])),
                              axis=(0, 1))

# Report the mean of the per-index errors (the messages call them trials).
print('GPCSD2 average MSE across trials: %0.3g' % np.mean(gpcsd2_meansqerr))
print('GPCSD3 average MSE across trials: %0.3g' % np.mean(gpcsd3_meansqerr))

# %%
コード例 #22
0
                   jac=True,
                   options={
                       'maxiter': 5000,
                       'disp': True,
                       'ftol': 0
                   },
                   callback=callback0)
    PATH = ROOT_PATH + "/MMR_IVs/results/" + sname + "/"
    os.makedirs(PATH, exist_ok=True)
    np.save(PATH + 'LMO_errs_{}_nystr.npy'.format(seed),
            [opt_params, prev_norm, opt_test_err])


if __name__ == '__main__':
    # Run every (dataset, seed) experiment, then summarize the saved test
    # errors of each dataset (mean and std after outlier removal).
    for sname in ('mnist_z', 'mnist_x', 'mnist_xz'):
        for seed in range(100):
            experiment(sname, seed)

        results_dir = ROOT_PATH + "/MMR_IVs/results/" + sname + "/"
        test_errs = []
        for seed in range(100):
            result_file = results_dir + 'LMO_errs_{}_nystr.npy'.format(seed)
            # Seeds that produced no result file are skipped.
            if not os.path.exists(result_file):
                continue
            saved = np.load(result_file, allow_pickle=True)
            # experiment() saves [opt_params, prev_norm, opt_test_err]; keep
            # only runs whose final entry (the test error) is set.
            if saved[-1] is None:
                continue
            test_errs.append(saved[-1])
        test_errs = remove_outliers(np.array(test_errs))
        print(np.nanmean(test_errs), np.nanstd(test_errs))
コード例 #23
0
def adagrad_workflow_optimize(n_iters,
                              objective_and_grad,
                              init_param,
                              K,
                              has_log_norm=False,
                              window=10,
                              learning_rate=.01,
                              learning_rate_end=None,
                              epsilon=.1,
                              tolerance=0.05,
                              eval_elbo=100,
                              stopping_rule=1,
                              n_optimizers=1,
                              r_mean_threshold=1.20,
                              r_sigma_threshold=1.20,
                              tail_avg_iters=200,
                              plotting=False,
                              model_name=None):
    """Run windowed Adagrad with convergence diagnostics and iterate averaging.

    Stopping rule 1 is the traditional rule based on the relative change of
    the objective; stopping rule 2 is the MCSE rule, which additionally uses
    R-hat, autocorrelation and khat diagnostics of the iterates.

    :param n_iters: number of iterations requested from
        ``learning_rate_schedule`` (defined elsewhere in this module).
    :param objective_and_grad: callable mapping the parameter vector to
        ``(value, grad)``, or to ``(value, grad, log_norm)`` when
        ``has_log_norm`` is set.
    :param init_param: initial parameter array (copied, never mutated).
    :param K: split point of the parameter vector for the R-hat check:
        entries ``[:K]`` are treated as means, ``[K:]`` as sigmas.
    :param has_log_norm: whether ``objective_and_grad`` also returns a log
        norm used to rescale the windowed gradient history.
    :param window: number of most recent gradients accumulated in the
        Adagrad denominator.
    :param learning_rate: step-size argument passed to
        ``learning_rate_schedule``.
    :param learning_rate_end: passed to ``learning_rate_schedule``
        (presumably the final step size).
    :param epsilon: constant added to the accumulated squared gradients; also
        the median-MCSE threshold of stopping rule 2.
    :param tolerance: threshold on the median / mean relative objective
        change for stopping rule 1.
    :param eval_elbo: interval (in iterations) between diagnostic checks.
    :param stopping_rule: 1 (relative change) or 2 (MCSE).
    :param n_optimizers: number of runs. Each run reseeds NumPy's global RNG
        with its index; run ``o >= 1`` starts from ``init_param`` plus
        standard-normal noise scaled by ``0.1 * (o + 1)``.
    :param r_mean_threshold: R-hat threshold for the mean entries.
    :param r_sigma_threshold: R-hat threshold for the sigma entries.
    :param tail_avg_iters: for stopping rule 1, number of final iterates
        averaged into the returned parameters (all of them if fewer exist).
    :param plotting: save autocorrelation plots as PDF files; only possible
        once the post-convergence diagnostics have been computed.
    :param model_name: currently unused.
    :return: tuple ``(smoothed_opt_param, variational_param_history,
        value_history, log_norm_history, optimisation_log)``. The histories
        belong to the last run; ``optimisation_log`` holds whichever
        convergence diagnostics were computed.
    """

    log_norm_history = []
    variational_param = init_param.copy()
    prev_elbo = 0.
    pmz_size = init_param.size
    optimisation_log = {}
    variational_param_history_list = []
    variational_param_post_conv_history_list = []
    # Post-convergence diagnostics; None until they have been computed.
    Neff = None
    autocov = None
    rhot_array = None
    khat_combined = None
    # index for iters
    t = 0
    # index for iters after convergence
    j = 0
    sto_process_convergence = False
    sto_process_sigma_conv = False
    sto_process_mean_conv = False
    for o in range(n_optimizers):
        # Iteration cap for this run (lowered to t + 100 once stopping rule 1
        # fires). Reset per run together with ``stop`` so that an early stop
        # in one run does not truncate the next one.
        N_overall = 50000
        local_log_norm_history = []
        local_grad_history = []
        log_norm_history = []
        value_history = []
        elbo_diff_rel_med = 10.
        elbo_diff_rel_avg = 10.
        elbo_diff_rel_list = []
        np.random.seed(seed=o)
        if o >= 1:
            variational_param = init_param + stats.norm.rvs(
                size=len(init_param)) * (o + 1) * 0.1
        schedule = learning_rate_schedule(n_iters, learning_rate,
                                          learning_rate_end)
        t = 0
        variational_param_history = []
        variational_param_post_conv_history = []
        mcse_all = np.zeros((pmz_size, 1))
        stop = False
        for curr_learning_rate in schedule:
            if t == N_overall:
                break

            if sto_process_convergence:
                j = j + 1

            if has_log_norm == 1:
                obj_val, obj_grad, log_norm = objective_and_grad(
                    variational_param)
            else:
                obj_val, obj_grad = objective_and_grad(variational_param)
                log_norm = 0.

            if stopping_rule == 1 and t > 1000 and t % eval_elbo == 0:
                # Relative change measured against |previous value|: with a
                # signed denominator a negative objective yields a negative
                # ratio that passes ``<= tolerance`` immediately.
                elbo_diff_rel = np.abs(obj_val - prev_elbo) / (
                    np.abs(prev_elbo) + 1e-8)
                elbo_diff_rel_list.append(elbo_diff_rel)
                elbo_diff_rel_med = np.nanmedian(elbo_diff_rel_list)
                elbo_diff_rel_avg = np.nanmean(elbo_diff_rel_list)

            prev_elbo = obj_val
            # NOTE(review): reset on every iteration, so a single converged
            # R-hat diagnostic only moves start_stats until the next
            # iteration; kept as-is to preserve the existing averaging window.
            start_stats = 1500
            if stopping_rule == 2 and t > 1500 and t % eval_elbo == 0:
                mcse_se_combined_list = monte_carlo_se(
                    np.array(variational_param_history)[None, :], 0)
                mcse_all = np.hstack((mcse_all,
                                      mcse_se_combined_list[:, None]))

            value_history.append(obj_val)
            local_grad_history.append(obj_grad)
            local_log_norm_history.append(log_norm)
            log_norm_history.append(log_norm)
            # Keep only the last ``window`` gradients for the Adagrad sum.
            if len(local_grad_history) > window:
                local_grad_history.pop(0)
                local_log_norm_history.pop(0)

            # Rescale each stored gradient by exp(min log_norm - its log_norm)
            # before accumulating the squared gradients.
            grad_scale = np.exp(
                np.min(local_log_norm_history) -
                np.array(local_log_norm_history))
            scaled_grads = grad_scale[:, np.newaxis] * np.array(
                local_grad_history)
            accum_sum = np.sum(scaled_grads**2, axis=0)
            variational_param = variational_param - curr_learning_rate * obj_grad / np.sqrt(
                epsilon + accum_sum)
            variational_param_history.append(variational_param.copy())

            t = t + 1
            # Stopping rule 1: once the median or mean relative change falls
            # below tolerance, run 100 more iterations and stop.
            if stopping_rule == 1 and not stop and elbo_diff_rel_med <= tolerance:
                N_overall = t + 100
                stop = True
            if stopping_rule == 1 and not stop and elbo_diff_rel_avg <= tolerance:
                N_overall = t + 100
                stop = True

            if stopping_rule == 2 and not stop and sto_process_convergence and t > 1500 and \
                    t % eval_elbo == 0 and (np.nanmedian(mcse_all[:, -1]) <= epsilon) and j > 300:
                print('Optimization stopping reliably!')
                stop = True
                break

            if stopping_rule == 2 and t % eval_elbo == 0 and t > 800 and not sto_process_convergence:
                # Materialize the iterate history only when R-hat needs it
                # (it used to be rebuilt on every iteration: O(t^2) copying).
                variational_param_history_array = np.array(
                    variational_param_history)
                variational_param_history_list.append(
                    variational_param_history_array)
                variational_param_history_chains = np.stack(
                    variational_param_history_list, axis=0)
                variational_param_history_list.pop(0)
                rhats_halfway_last = compute_R_hat(
                    variational_param_history_chains, warmup=0.5)[1]
                rhat_mean_halfway = rhats_halfway_last[:K]
                rhat_sigma_halfway = rhats_halfway_last[K:]

                if (rhat_mean_halfway < r_mean_threshold).all() and not sto_process_mean_conv:
                    start_swa_m_iters = t
                    print('Rhat- All means converged ...')
                    sto_process_mean_conv = True
                    start_stats = start_swa_m_iters

                if (rhat_sigma_halfway < r_sigma_threshold).all() and not sto_process_sigma_conv:
                    start_swa_s_iters = t
                    print('Rhat- All sigmas converged ...')
                    sto_process_sigma_conv = True
                    start_stats = start_swa_s_iters

            if sto_process_mean_conv and sto_process_sigma_conv:
                sto_process_convergence = True
                start_stats = np.maximum(start_swa_m_iters, start_swa_s_iters)

            if sto_process_convergence:
                variational_param_post_conv_history.append(variational_param)

            if sto_process_convergence and j > 100 and t % eval_elbo == 0:
                variational_param_post_conv_history_array = np.array(
                    variational_param_post_conv_history)
                variational_param_post_conv_history_list.append(
                    variational_param_post_conv_history_array)
                variational_param_post_conv_history_chains = np.stack(
                    variational_param_post_conv_history_list, axis=0)
                variational_param_post_conv_history_list.pop(0)
                pmz_size = variational_param_post_conv_history_chains.shape[2]
                Neff = np.zeros(pmz_size)
                Rhot = []
                khat_iterates = []
                khat_iterates2 = []
                # Per-parameter effective sample size, autocorrelation and
                # khat of the post-convergence iterates.
                for k in range(pmz_size):
                    neff, rho_t_sum, autocov, rho_t = autocorrelation(
                        variational_param_post_conv_history_chains, 0, k)
                    Neff[k] = neff
                    Rhot.append(rho_t)
                    khat_i = compute_khat_iterates(
                        variational_param_post_conv_history_chains,
                        0,
                        k,
                        increasing=True)
                    khat_iterates.append(khat_i)
                    khat_i2 = compute_khat_iterates(
                        variational_param_post_conv_history_chains,
                        0,
                        k,
                        increasing=False)
                    khat_iterates2.append(khat_i2)

                rhot_array = np.stack(Rhot, axis=0)
                khat_combined = np.maximum(khat_iterates, khat_iterates2)

    # Iterate history of the last run, built once after optimisation.
    variational_param_history_array = np.array(variational_param_history)

    if sto_process_convergence:
        optimisation_log['start_avg_mean_iters'] = start_swa_m_iters
        optimisation_log['start_avg_sigma_iters'] = start_swa_s_iters
        optimisation_log['r_hat_mean_halfway'] = rhat_mean_halfway
        optimisation_log['r_hat_sigma_halfway'] = rhat_sigma_halfway

        if Neff is not None:
            optimisation_log['neff'] = Neff
            optimisation_log['autocov'] = autocov
            optimisation_log['rhot'] = rhot_array
            optimisation_log['start_stats'] = start_stats
            optimisation_log['khat_iterates_comb'] = khat_combined

    if stopping_rule == 1:
        # Average the last ``tail_avg_iters`` iterates; clamp at 0 so a short
        # run averages every iterate instead of wrapping to a negative index.
        start_stats = max(0, t - tail_avg_iters)
        smoothed_opt_param = np.mean(
            variational_param_history_array[start_stats:, :], axis=0)
    elif stopping_rule == 2 and sto_process_convergence:
        smoothed_opt_param = np.mean(
            variational_param_post_conv_history_chains[0, :], axis=0)
    elif stopping_rule == 2:
        smoothed_opt_param = np.mean(
            variational_param_history_array[start_stats:, :], axis=0)

    if plotting:
        if rhot_array is None:
            # The autocorrelation diagnostics only exist once the
            # post-convergence analysis (stopping rule 2) has run.
            print('Autocorrelation diagnostics unavailable; skipping plots.')
        else:
            fig = plt.figure(figsize=(4.2, 2.5))
            ax = fig.add_subplot(1, 1, 1)
            ax.plot(rhot_array[0, :100], label='loc-1')
            ax.plot(rhot_array[1, :100], label='loc-2')
            plt.xlabel('Lags')
            plt.ylabel('autocorrelation')
            plt.legend()
            plt.savefig('autocor_model_adagrad_mean_mf.pdf')

            fig = plt.figure(figsize=(4.2, 2.5))
            ax = fig.add_subplot(1, 1, 1)
            ax.plot(rhot_array[K, :100], label='sigma-1')
            ax.plot(rhot_array[K + 1, :100], label='sigma-2')
            plt.xlabel('Lags')
            plt.ylabel('autocorrelation')
            plt.legend()
            plt.savefig('autocor_model_adagrad_sigma_mf.pdf')

    return (smoothed_opt_param, variational_param_history,
            np.array(value_history), np.array(log_norm_history),
            optimisation_log)
コード例 #24
0
def adam_workflow_optimize(n_iters,
                           objective_and_grad,
                           init_param,
                           K,
                           has_log_norm=False,
                           window=100,
                           learning_rate=.01,
                           learning_rate_end=None,
                           epsilon=.02,
                           averaging=True,
                           n_optimisers=1,
                           r_mean_threshold=1.20,
                           r_sigma_threshold=1.20,
                           tail_avg_iters=200,
                           eval_elbo=100,
                           tolerance=0.01,
                           stopping_rule=1,
                           plotting=True,
                           model_name=None):
    """Windowed ADAM optimiser with convergence diagnostics and iterate averaging.

    Two stopping rules are supported:

    * ``stopping_rule=1`` -- traditional ELBO rule.  After iteration 1000,
      every ``eval_elbo`` iterations the relative change of the objective is
      recorded.  Once the median or the mean of those changes is
      ``<= epsilon``, the run continues for 100 more iterations and stops.
      The result is the average of the last ``tail_avg_iters`` stored
      iterates.
    * ``stopping_rule=2`` -- MCSE rule.  The iterate history is checked with
      split R-hat (``compute_R_hat``).  The first ``K`` entries are checked
      against ``r_mean_threshold`` and the remaining entries against
      ``r_sigma_threshold``.  Once both pass, post-convergence iterates are
      collected.  The run stops when the median Monte Carlo standard error
      is ``<= epsilon``.  The result is the mean of the post-convergence
      iterates.

    :param n_iters: maximum number of iterations (also capped at 50000).
    :param objective_and_grad: callable ``f(param)`` returning
        ``(value, grad)``, or ``(value, grad, log_norm)`` when
        ``has_log_norm`` is true.
    :param init_param: initial parameter vector (numpy array).
    :param K: number of leading entries treated as "mean" parameters; the
        rest are treated as "sigma" parameters by the diagnostics.
    :param has_log_norm: whether ``objective_and_grad`` also returns a log norm.
    :param window: the stored iterate history is capped at ``100 * window``.
    :param learning_rate: initial learning rate for ``learning_rate_schedule``.
    :param learning_rate_end: final learning rate for ``learning_rate_schedule``.
    :param epsilon: used both as the ADAM denominator stabiliser and as the
        stopping threshold of both rules.
    :param averaging: if True, a running average of iterates is accumulated
        after ``n_iters // 1.3`` iterations (not returned).
    :param n_optimisers: number of runs.  Each run after the first starts
        from a jittered ``init_param``.  Only the last run's histories are
        returned.
    :param r_mean_threshold: R-hat threshold for the first ``K`` entries.
    :param r_sigma_threshold: R-hat threshold for the remaining entries.
    :param tail_avg_iters: number of most recent stored iterates averaged when
        the R-hat/MCSE post-convergence average is not used.
    :param eval_elbo: evaluation period (in iterations) for the diagnostics.
    :param tolerance: unused.
    :param stopping_rule: 1 (ELBO rule) or 2 (MCSE rule), see above.
    :param plotting: save autocorrelation plots when they were computed
        (stopping rule 2 after convergence).
    :param model_name: unused.
    :return: ``(variational_param, variational_param_history_chains,
        averaged_variational_mean_list, averaged_variational_sigmas_list,
        value_history, log_norm_history, optimisation_log)``.
    """

    optimisation_log = {}
    variational_param_post_conv_history_list = []
    # index for iters
    t = 0
    # index for iters after convergence ..
    j = 0
    # Iteration cap; the ELBO rule lowers it to ``i + 100`` once it fires.
    N_overall = 50000
    sto_process_convergence = False
    sto_process_sigma_conv = False
    sto_process_mean_conv = False

    value_history = []
    log_norm_history = []
    averaged_variational_param_history = []
    start_avg_iter = n_iters // 1.3
    variational_param_history_list = []
    averaged_variational_mean_list = []
    averaged_variational_sigmas_list = []
    grad_val = 0.
    grad_squared = 0
    beta1 = 0.9
    beta2 = 0.999
    prev_elbo = 0.
    pmz_size = init_param.size
    # Diagnostics that only some code paths produce.  ``None`` means "never
    # computed", so the post-processing below can skip them safely.
    Neff = None
    rhot_array = None
    variational_param_post_conv_history_chains = None

    for o in range(n_optimisers):
        np.random.seed(seed=o)
        # Start from init_param, and jitter the start for every optimiser
        # after the first.  The reset must precede the jitter; otherwise the
        # jitter is discarded and all optimisers start from the same point.
        variational_param = init_param.copy()
        if o >= 1:
            variational_param = init_param + stats.norm.rvs(
                size=len(init_param)) * (o + 1) * 0.5
        elbo_diff_rel_med = 10.
        elbo_diff_rel_avg = 10.
        value_history = []
        log_norm_history = []
        averaged_variational_mean_list = []
        averaged_variational_sigmas_list = []
        elbo_diff_rel_list = []
        t = 0
        variational_param_history = []
        variational_param_post_conv_history = []
        mcse_all = np.zeros((pmz_size, 1))
        stop = False

        with tqdm.trange(n_iters) as progress:
            try:
                schedule = learning_rate_schedule(n_iters, learning_rate,
                                                  learning_rate_end)
                for i, curr_learning_rate in zip(progress, schedule):
                    if i == N_overall:
                        break

                    if sto_process_convergence:
                        j = j + 1
                    if has_log_norm == 1:
                        obj_val, obj_grad, log_norm = objective_and_grad(
                            variational_param)
                    else:
                        obj_val, obj_grad = objective_and_grad(
                            variational_param)
                        log_norm = 0

                    if stopping_rule == 1 and i > 1000 and i % eval_elbo == 0:
                        # Normalise by |prev_elbo|.  With a negative objective,
                        # a signed denominator makes the ratio negative, and the
                        # ELBO rule would fire on its first evaluation.
                        elbo_diff_rel = np.abs(obj_val - prev_elbo) / (
                            np.abs(prev_elbo) + 1e-8)
                        elbo_diff_rel_list.append(elbo_diff_rel)
                        elbo_diff_rel_med = np.nanmedian(elbo_diff_rel_list)
                        elbo_diff_rel_avg = np.nanmean(elbo_diff_rel_list)

                    prev_elbo = obj_val
                    start_stats = 1000
                    if stopping_rule == 2 and i > 1000 and i % eval_elbo == 0:
                        # append one MCSE column per evaluation
                        mcse_se_combined_list = monte_carlo_se(
                            np.array(variational_param_history)[None, :], 0)
                        mcse_all = np.hstack(
                            (mcse_all, mcse_se_combined_list[:, None]))

                    value_history.append(obj_val)
                    log_norm_history.append(log_norm)

                    # ADAM moment estimates (the first step is seeded with a
                    # 0.9 factor, and bias correction uses i + 2).
                    if i == 0:
                        grad_squared = 0.9 * obj_grad**2
                        grad_val = 0.9 * obj_grad
                    else:
                        grad_squared = grad_squared * beta2 + (
                            1. - beta2) * obj_grad**2
                        grad_val = grad_val * beta1 + (1. - beta1) * obj_grad
                    old_variational_param = variational_param.copy()
                    m_hat = grad_val / (1 - np.power(beta1, i + 2))
                    v_hat = grad_squared / (1 - np.power(beta2, i + 2))
                    variational_param = variational_param - curr_learning_rate * m_hat / np.sqrt(
                        epsilon + v_hat)
                    if averaging is True and i > start_avg_iter:
                        # NOTE(review): this history is never read or returned.
                        averaged_variational_param = (
                            variational_param + old_variational_param *
                            (i - start_avg_iter)) / (i - start_avg_iter + 1)
                        averaged_variational_param_history.append(
                            averaged_variational_param)

                    if i > 100:
                        variational_param_history.append(old_variational_param)

                    # keep only the most recent 100 * window iterates
                    if len(variational_param_history) > 100 * window:
                        variational_param_history.pop(0)
                    if i % 100 == 0:
                        avg_loss = np.mean(value_history[max(0, i - 1000):i +
                                                         1])
                        progress.set_description(
                            'Average Loss = {:,.6g}'.format(avg_loss))

                    t = t + 1
                    if stopping_rule == 1 and not stop and elbo_diff_rel_med <= epsilon:
                        print('Convergence achieved due to ELBO median')
                        N_overall = i + 100
                        stop = True
                    if stopping_rule == 1 and not stop and elbo_diff_rel_avg <= epsilon:
                        print('Convergence achieved due to ELBO mean')
                        N_overall = i + 100
                        stop = True

                    if stopping_rule == 2 and not stop and sto_process_convergence and i > 1500 and \
                            t % eval_elbo == 0 and (np.nanmedian(mcse_all[:, -1]) <= epsilon) and j > 500:
                        print('Optimization stopping reliably!')
                        stop = True
                        break

                    if stopping_rule == 2 and t % eval_elbo == 0 and t > 1000 and not sto_process_convergence:
                        # R-hat check on the stored iterate history
                        variational_param_history_array = np.array(
                            variational_param_history)
                        variational_param_history_list.append(
                            variational_param_history_array)
                        variational_param_history_chains = np.stack(
                            variational_param_history_list, axis=0)
                        variational_param_history_list.pop(0)
                        rhats_halfway_last = compute_R_hat(
                            variational_param_history_chains, warmup=0.5)[1]
                        rhat_mean_halfway, rhat_sigma_halfway = rhats_halfway_last[:K], rhats_halfway_last[
                            K:]
                        if (rhat_mean_halfway < r_mean_threshold
                            ).all() and not sto_process_mean_conv:
                            start_swa_m_iters = i
                            print('Rhat- All mean converged ...')
                            sto_process_mean_conv = True
                            start_stats = start_swa_m_iters

                        if (rhat_sigma_halfway < r_sigma_threshold
                            ).all() and not sto_process_sigma_conv:
                            start_swa_s_iters = i
                            print('Rhat- All sigmas converged ...')
                            sto_process_sigma_conv = True
                            start_stats = start_swa_s_iters

                    if sto_process_mean_conv and sto_process_sigma_conv:
                        sto_process_convergence = True
                        start_stats = np.maximum(start_swa_m_iters,
                                                 start_swa_s_iters)

                    if sto_process_convergence:
                        variational_param_post_conv_history.append(
                            variational_param)

                    if sto_process_convergence and j > 200 and t % eval_elbo == 0:
                        variational_param_post_conv_history_array = np.array(
                            variational_param_post_conv_history)
                        variational_param_post_conv_history_list.append(
                            variational_param_post_conv_history_array)
                        variational_param_post_conv_history_chains = np.stack(
                            variational_param_post_conv_history_list, axis=0)
                        variational_param_post_conv_history_list.pop(0)
                        pmz_size = variational_param_post_conv_history_chains.shape[
                            2]
                        Neff = np.zeros(pmz_size)
                        Rhot = []
                        khat_iterates = []
                        khat_iterates2 = []
                        # per-parameter effective sample size, autocorrelation
                        # and khat of the post-convergence iterates
                        for z in range(pmz_size):
                            neff, rho_t_sum, autocov, rho_t = autocorrelation(
                                variational_param_post_conv_history_chains, 0,
                                z)
                            Neff[z] = neff
                            Rhot.append(rho_t)
                            khat_i = compute_khat_iterates(
                                variational_param_post_conv_history_chains,
                                0,
                                z,
                                increasing=True)
                            khat_iterates.append(khat_i)
                            khat_i2 = compute_khat_iterates(
                                variational_param_post_conv_history_chains,
                                0,
                                z,
                                increasing=False)
                            khat_iterates2.append(khat_i2)

                        rhot_array = np.stack(Rhot, axis=0)
                        khat_combined = np.maximum(khat_iterates,
                                                   khat_iterates2)

            except (KeyboardInterrupt, StopIteration):
                # Ctrl-C ends this run early; the iterates collected so far
                # are still post-processed below.
                pass
            finally:
                progress.close()

    # Built once here instead of on every iteration.  The per-iteration
    # conversion of up to 100 * window iterates was quadratic overall.
    variational_param_history_array = np.array(variational_param_history)

    if sto_process_convergence:
        optimisation_log['start_avg_mean_iters'] = start_swa_m_iters
        optimisation_log['start_avg_sigma_iters'] = start_swa_s_iters
        optimisation_log['r_hat_mean_halfway'] = rhat_mean_halfway
        optimisation_log['r_hat_sigma_halfway'] = rhat_sigma_halfway
        if Neff is not None:
            optimisation_log['neff'] = Neff
            optimisation_log['autocov'] = autocov
            optimisation_log['rhot'] = rhot_array
            optimisation_log['start_stats'] = start_stats
            optimisation_log['khat_iterates_comb'] = khat_combined

    # First index of the tail average: the last ``tail_avg_iters`` stored
    # iterates.  The stored history starts at iteration 101 and is capped at
    # 100 * window entries, so an iteration count minus tail_avg_iters is not
    # a valid index into it and can even produce an empty slice (NaN mean).
    tail_start = max(len(variational_param_history_array) - tail_avg_iters, 0)

    if stopping_rule == 2 and sto_process_convergence:
        if variational_param_post_conv_history_chains is not None:
            smoothed_opt_param = np.mean(
                variational_param_post_conv_history_chains[0, :, :], axis=0)
        else:
            # Convergence was declared too late for any post-convergence
            # chain to be built, so fall back to the tail average.
            smoothed_opt_param = np.mean(
                variational_param_history_array[tail_start:, :], axis=0)
        averaged_variational_mean_list.append(smoothed_opt_param[:K])
        averaged_variational_sigmas_list.append(smoothed_opt_param[K:])
    elif stopping_rule in (1, 2):
        variational_param_history_list.append(variational_param_history_array)
        variational_param_history_chains = np.stack(
            variational_param_history_list, axis=0)
        smoothed_opt_param = np.mean(
            variational_param_history_array[tail_start:, :], axis=0)
        averaged_variational_mean_list.append(smoothed_opt_param[:K])
        averaged_variational_sigmas_list.append(smoothed_opt_param[K:])

    # Autocorrelations only exist for stopping rule 2 after convergence;
    # otherwise there is nothing to plot.
    if plotting and rhot_array is not None:
        fig = plt.figure(figsize=(4.2, 2.5))
        ax = fig.add_subplot(1, 1, 1)
        ax.plot(rhot_array[0, :100], label='loc-1')
        ax.plot(rhot_array[1, :100], label='loc-2')
        plt.xlabel('Lags')
        plt.ylabel('autocorrelation')
        plt.legend()
        plt.savefig('autocor_model_adam_mean_mf.pdf')

        fig = plt.figure(figsize=(4.2, 2.5))
        ax = fig.add_subplot(1, 1, 1)
        ax.plot(rhot_array[K, :100], label='sigma-1')
        ax.plot(rhot_array[K + 1, :100], label='sigma-2')
        plt.xlabel('Lags')
        plt.ylabel('autocorrelation')
        plt.legend()
        plt.savefig('autocor_model_adam_sigma_mf.pdf')

    return (variational_param, variational_param_history_chains,
            averaged_variational_mean_list, averaged_variational_sigmas_list,
            np.array(value_history), np.array(log_norm_history),
            optimisation_log)
コード例 #25
0
ファイル: sim_utils.py プロジェクト: dmossing/analysis
def columnize(arr):
    """Average *arr* over axis 0, flatten, and scale so the maximum is 1.

    NaNs are ignored both in the average and when finding the maximum.  An
    all-NaN column therefore yields NaN in that position only, instead of
    turning the whole output into NaN.

    :param arr: array-like with at least one dimension.
    :return: 1-D float array.
    """
    output = np.nanmean(arr, 0).flatten()
    # nanmax rather than max: nanmean yields NaN for all-NaN columns, and a
    # plain max would propagate that NaN to every element.
    output = output / np.nanmax(output)
    return output
コード例 #26
0
ファイル: sim_utils.py プロジェクト: dmossing/analysis
 def combiner(sem, axis=1):
     """Pool standard errors along *axis* as a NaN-ignoring root-mean-square.

     :param sem: numpy array of standard errors.
     :param axis: axis to pool over (default 1).
     :return: ``sqrt(nanmean(sem ** 2, axis))``.
     """
     mean_square = np.nanmean(sem ** 2, axis=axis)
     return np.sqrt(mean_square)
コード例 #27
0
def fit_weights_and_save(weights_file,
                         ca_data_file='rs_vm_denoise_200605.npy',
                         opto_data_file='vip_halo_data_for_sim.npy',
                         constrain_wts=None,
                         allow_var=True,
                         fit_s02=True,
                         constrain_isn=True,
                         tv=False,
                         l2_penalty=0.01,
                         init_noise=0.1,
                         init_W_from_lsq=False,
                         scale_init_by=1,
                         init_W_from_file=False,
                         init_file=None):
    """Fit circuit weights to imaging and optogenetic data and save them.

    The function:

    1. Loads the per-cell-type responses ``rs`` from ``ca_data_file``.
    2. Builds orientation-pooled, smoothed, row-normalised size x contrast
       surfaces.
    3. Sets up box bounds for every parameter group.
    4. Runs ``calnet.fitting_spatial_feature_opto.fit_W_sim`` and saves the
       result to ``weights_file`` with ``np.save``.

    :param weights_file: output path.  It receives a dict with one entry per
        parameter label plus ``'as_list'``, ``'loss'`` and ``'wt_dict'``.
    :param ca_data_file: ``.npy`` file holding a pickled dict with key
        ``'rs'``.  Cell type 0 is used as the input ``X``; cell types 1..4
        are used as the outputs ``Y``.
    :param opto_data_file: ``.npy`` file holding a pickled dict with keys
        ``'Yhat_opto'`` and ``'h_opto'``.
    :param constrain_wts: optional iterable of ``(row, col)`` pairs of
        ``Wmy``/``Wsy`` entries that are fixed at 0.
    :param allow_var: if False, the ``Wsx``/``Wsy``/``Xi`` groups are fixed at 0.
    :param fit_s02: if True, ``s02`` is fit in ``[0, inf)``; otherwise it is
        fixed at 1.
    :param constrain_isn: forwarded to ``fit_W_sim``.  It also seeds four
        ``Wmy`` entries with fixed initial values.
    :param tv: forwarded to ``fit_W_sim``.
    :param l2_penalty: forwarded to ``fit_W_sim``.
    :param init_noise: scale of the random initialisation.
    :param init_W_from_lsq: initialise ``Wmx``/``Wmy`` via ``initialize_W``.
    :param scale_init_by: forwarded to ``initialize_W``.
    :param init_W_from_file: initialise from ``init_file['as_list']``.
    :param init_file: ``.npy`` file used when ``init_W_from_file`` is True.

    NOTE: every ``np.load`` here uses ``allow_pickle=True``, which unpickles
    arbitrary objects.  Only pass trusted files.
    """
    npfile = np.load(ca_data_file, allow_pickle=True)[()]
    rs = npfile['rs']

    nsize, ncontrast, ndir = 6, 6, 8
    # direction indices averaged together for each of the nT orientations
    ori_dirs = [[0, 4], [2, 6]]
    nT = len(ori_dirs)
    nS = len(rs[0])

    def sum_to_1(r):
        # flatten all but the first axis and normalise each row to sum to 1
        R = r.reshape((r.shape[0], -1))
        R = R / np.nansum(R, axis=1)[:, np.newaxis]
        return R

    # Rs[celltype][alignment]: row-normalised responses
    Rs = [[None for iS in range(nS)] for icelltype in range(len(rs))]
    # Rso[celltype][alignment][orientation]: flattened size x contrast surface
    Rso = [[[None for iT in range(nT)] for iS in range(nS)]
           for icelltype in range(len(rs))]

    for iR, r in enumerate(rs):
        print(iR)
        for ialign in range(nS):
            Rs[iR][ialign] = sum_to_1(r[ialign][:, :nsize, :])

    # 2x2 box filter over the last two (size, contrast) axes
    kernel = np.ones((1, 2, 2))
    kernel = kernel / kernel.sum()

    for iR in range(len(rs)):
        for ialign in range(nS):
            for iori in range(nT):
                surface = np.nanmean(
                    Rs[iR][ialign].reshape((-1, nsize, ncontrast, ndir))[:, :, :, ori_dirs[iori]], -1)
                # Replace the first contrast column (presumably zero contrast)
                # by its mean across sizes.
                surface[:, :, 0] = np.nanmean(surface[:, :, 0], 1)[:, np.newaxis]
                surface[:, 1:, 1:] = ssi.convolve(surface, kernel, 'valid')
                Rso[iR][ialign][iori] = surface.reshape(surface.shape[0], -1)

    def set_bound(bd, code, val=0):
        # for each parameter group i, set bd[i] to val where the mask code[i] is True
        for bd_item, mask in zip(bd, code):
            bd_item[mask] = val

    nN = 36
    nS = 2
    nP = 2
    nT = 2
    nQ = 4

    # code for bounds: 0 , constrained to 0
    # +/-1 , constrained to +/-1
    # 1.5, constrained to [0,1]
    # 2 , constrained to [0,inf)
    # -2 , constrained to (-inf,0]
    # 3 , unconstrained

    Wmx_bounds = 3 * np.ones((nP, nQ), dtype=int)
    Wmx_bounds[0, 1] = 0  # SSTs don't receive L4 input

    if allow_var:
        Wsx_bounds = 3 * np.ones(Wmx_bounds.shape)
        Wsx_bounds[0, 1] = 0
    else:
        Wsx_bounds = np.zeros(Wmx_bounds.shape)

    Wmy_bounds = 3 * np.ones((nQ, nQ), dtype=int)
    Wmy_bounds[0, :] = 2  # PCs are excitatory
    Wmy_bounds[1:, :] = -2  # all the cell types except PCs are inhibitory
    Wmy_bounds[1, 1] = 0  # SSTs don't inhibit themselves
    # Wmy_bounds[3,1] is left free: PVs are allowed to inhibit SSTs, consistent with Hillel's unpublished results, but not consistent with Pfeffer et al.
    Wmy_bounds[2, 0] = 0  # VIPs don't inhibit L2/3 PCs. According to Pfeffer et al., only L5 PCs were found to get VIP inhibition

    if allow_var:
        Wsy_bounds = 3 * np.ones(Wmy_bounds.shape)
        Wsy_bounds[1, 1] = 0
        Wsy_bounds[3, 1] = 0
        Wsy_bounds[2, 0] = 0
    else:
        Wsy_bounds = np.zeros(Wmy_bounds.shape)

    if constrain_wts is not None:
        for wt in constrain_wts:
            Wmy_bounds[wt[0], wt[1]] = 0
            Wsy_bounds[wt[0], wt[1]] = 0

    def tile_nS_nT_nN(pattern):
        # repeat pattern nS*nT times along a row, then stack nN such rows
        row = np.concatenate([pattern for idim in range(nS * nT)], axis=0)[np.newaxis, :]
        tiled = np.concatenate([row for irow in range(nN)], axis=0)
        return tiled

    if fit_s02:
        s02_bounds = 2 * np.ones((nQ,))  # permitting noise as a free parameter
    else:
        s02_bounds = np.ones((nQ,))

    k_bounds = 1.5 * np.ones((nQ,))

    kappa_bounds = np.ones((1,))

    T_bounds = 1.5 * np.ones((nQ,))
    T_bounds[1:4] = 1  # SST,VIP, and PV are constrained to have flat ori tuning

    X_bounds = tile_nS_nT_nN(np.array([2, 1]))

    Xp_bounds = tile_nS_nT_nN(np.array([3, 1]))

    Eta_bounds = tile_nS_nT_nN(3 * np.ones((nQ,)))

    if allow_var:
        Xi_bounds = tile_nS_nT_nN(3 * np.ones((nQ,)))
    else:
        Xi_bounds = tile_nS_nT_nN(np.zeros((nQ,)))

    h_bounds = -2 * np.ones((1,))

    #         Wmx,    Wmy,    Wsx,    Wsy,    s02,  k,    kappa,T,    XX,            XXp,           Eta,           Xi,            h
    shapes = [(nP, nQ), (nQ, nQ), (nP, nQ), (nQ, nQ), (nQ,), (nQ,), (1,), (nQ,), (nN, nT * nS * nP), (nN, nT * nS * nP), (nN, nT * nS * nQ), (nN, nT * nS * nQ), (1,)]

    lb = [-np.inf * np.ones(shp) for shp in shapes]
    ub = [np.inf * np.ones(shp) for shp in shapes]
    bdlist = [Wmx_bounds, Wmy_bounds, Wsx_bounds, Wsy_bounds, s02_bounds, k_bounds, kappa_bounds, T_bounds, X_bounds, Xp_bounds, Eta_bounds, Xi_bounds, h_bounds]

    set_bound(lb, [bd == 0 for bd in bdlist], val=0)
    set_bound(ub, [bd == 0 for bd in bdlist], val=0)

    set_bound(lb, [bd == 2 for bd in bdlist], val=0)

    set_bound(ub, [bd == -2 for bd in bdlist], val=0)

    set_bound(lb, [bd == 1 for bd in bdlist], val=1)
    set_bound(ub, [bd == 1 for bd in bdlist], val=1)

    set_bound(lb, [bd == 1.5 for bd in bdlist], val=0)
    set_bound(ub, [bd == 1.5 for bd in bdlist], val=1)

    set_bound(lb, [bd == -1 for bd in bdlist], val=-1)
    set_bound(ub, [bd == -1 for bd in bdlist], val=-1)

    lb = np.concatenate([a.flatten() for a in lb])
    ub = np.concatenate([b.flatten() for b in ub])
    bounds = [(a, b) for a, b in zip(lb, ub)]

    nS = 2
    ndims = 5
    ncelltypes = 5
    Yhat = [[None for iT in range(nT)] for iS in range(nS)]
    Xhat = [[None for iT in range(nT)] for iS in range(nS)]
    Ypc_list = [[None for iT in range(nT)] for iS in range(nS)]
    Xpc_list = [[None for iT in range(nT)] for iS in range(nS)]
    for iS in range(nS):
        # per-cell-type normaliser: max of the mean surface at orientation 0
        mx = np.zeros((ncelltypes,))
        for icelltype in range(ncelltypes):
            mx[icelltype] = np.nanmax(np.nanmean(Rso[icelltype][iS][0], 0))
        for iT in range(nT):
            y = [np.nanmean(Rso[icelltype][iS][iT], axis=0)[:, np.newaxis] / mx[icelltype] for icelltype in range(1, ncelltypes)]
            Ypc_list[iS][iT] = [None for icelltype in range(1, ncelltypes)]
            for icelltype in range(1, ncelltypes):
                rss = Rso[icelltype][iS][iT].copy()
                rss = rss[np.isnan(rss).sum(1) == 0]  # keep complete rows only
                try:
                    u, s, v = np.linalg.svd(rss - np.mean(rss, 0)[np.newaxis])
                    Ypc_list[iS][iT][icelltype - 1] = [(s[idim], v[idim]) for idim in range(ndims)]
                except Exception:
                    # best effort: this PC entry stays None
                    print('nope on Y')
            Yhat[iS][iT] = np.concatenate(y, axis=1)
            icelltype = 0
            x = np.nanmean(Rso[icelltype][iS][iT], 0)[:, np.newaxis] / mx[icelltype]
            # second input column is a constant 1
            Xhat[iS][iT] = np.concatenate((x, np.ones_like(x)), axis=1)
            rss = Rso[icelltype][iS][iT].copy()
            rss = rss[np.isnan(rss).sum(1) == 0]
            u, s, v = np.linalg.svd(rss - rss.mean(0)[np.newaxis])
            Xpc_list[iS][iT] = [None for iinput in range(2)]
            Xpc_list[iS][iT][0] = [(s[idim], v[idim]) for idim in range(ndims)]
            Xpc_list[iS][iT][1] = [(0, np.zeros((Xhat[0][0].shape[0],))) for idim in range(ndims)]
    nN, nP = Xhat[0][0].shape
    nQ = Yhat[0][0].shape[1]

    #         0.Wmx,  1.Wmy,  2.Wsx,  3.Wsy,  4.s02,5.K,  6.kappa,7.T,8.XX,        9.XXp,        10.Eta,       11.Xi,  12.h

    shapes = [(nP, nQ), (nQ, nQ), (nP, nQ), (nQ, nQ), (nQ,), (nQ,), (1,), (nQ,), (nN, nT * nS * nP), (nN, nT * nS * nP), (nN, nT * nS * nQ), (nN, nT * nS * nQ), (1,)]

    import calnet.fitting_spatial_feature
    # Both submodules are used below, so import them explicitly instead of
    # relying on some earlier import having loaded them.
    import calnet.fitting_spatial_feature_opto
    import calnet.utils
    import sim_utils

    opto_dict = np.load(opto_data_file, allow_pickle=True)[()]

    Yhat_opto = opto_dict['Yhat_opto']
    # scale each column by its max over the even rows
    Yhat_opto = Yhat_opto / Yhat_opto[0::2].max(0)[np.newaxis, :]
    print(Yhat_opto.shape)
    h_opto = opto_dict['h_opto']  # NOTE(review): read but currently unused
    # odd rows minus even rows (presumably light-on minus light-off)
    dYY = Yhat_opto[1::2] - Yhat_opto[0::2]
    # copy selected columns from their counterparts 8 columns away
    for to_overwrite in [1, 2, 5, 6]:
        dYY[:, to_overwrite] = dYY[:, to_overwrite + 8]
    for to_overwrite in [11, 15]:
        dYY[:, to_overwrite] = dYY[:, to_overwrite - 8]

    from importlib import reload
    # presumably so code edits are picked up during interactive development
    reload(calnet)
    reload(calnet.fitting_spatial_feature_opto)
    reload(sim_utils)
    wt_dict = {}
    wt_dict['X'] = 1
    wt_dict['Y'] = 3
    wt_dict['Eta'] = 1
    wt_dict['Xi'] = 0.1
    wt_dict['stims'] = np.ones((nN, 1))
    wt_dict['barrier'] = 0.
    wt_dict['opto'] = 1e0
    wt_dict['isn'] = 0.1
    wt_dict['tv'] = 0.

    YYhat = calnet.utils.flatten_nested_list_of_2d_arrays(Yhat)
    Eta0 = invert_f_mt(YYhat)

    ntries = 1
    nhyper = 1
    dt = 1e-1
    niter = int(np.round(10 / dt))
    perturbation_size = 5e-2
    Wt = [[None for itry in range(ntries)] for ihyper in range(nhyper)]
    loss = np.zeros((nhyper, ntries))
    # Entries whose upper bound is exactly 0 get a negated random initial value.
    is_neg = np.array([b[1] for b in bounds]) == 0
    negatize = []
    counter = 0
    for shp in shapes:
        nel = np.prod(shp)
        negatize.append(is_neg[counter:counter + nel].reshape(shp))
        counter = counter + nel
    for ihyper in range(nhyper):
        for itry in range(ntries):
            print((ihyper, itry))
            W0list = [init_noise * (ihyper + 1) * np.random.rand(*shp) for shp in shapes]
            for ishp in range(len(shapes)):
                W0list[ishp][negatize[ishp]] = -W0list[ishp][negatize[ishp]]
            W0list[4] = np.ones(shapes[4])  # s02
            W0list[5] = np.ones(shapes[5])  # K
            W0list[6] = np.ones(shapes[6])  # kappa
            W0list[7] = np.ones(shapes[7])  # T
            W0list[8] = np.concatenate(Xhat, axis=1)  # XX
            W0list[9] = np.zeros_like(W0list[8])  # XXp
            W0list[10] = Eta0  # Eta
            W0list[11] = np.zeros(shapes[11])  # Xi
            if init_W_from_lsq:
                W0list[0], W0list[1] = initialize_W(Xhat, Yhat, scale_by=scale_init_by)
                for ivar in range(0, 2):
                    W0list[ivar] = W0list[ivar] + init_noise * np.random.randn(*W0list[ivar].shape)
            if constrain_isn:
                W0list[1][0, 0] = 3
                W0list[1][0, 3] = 5
                W0list[1][3, 0] = -5
                W0list[1][3, 3] = -5

            if init_W_from_file:
                npyfile = np.load(init_file, allow_pickle=True)[()]
                W0list = npyfile['as_list']

                #[Wmx,Wmy,Wsx,Wsy,s02,k,kappa,T,XX,XXp,Eta,Xi]
                for ivar in [0, 1, 4, 5]:  # Wmx, Wmy, s02, k
                    W0list[ivar] = W0list[ivar] + init_noise * np.random.randn(*W0list[ivar].shape)

            Wt[ihyper][itry], loss[ihyper][itry], gr, hess, result = calnet.fitting_spatial_feature_opto.fit_W_sim(Xhat, Xpc_list, Yhat, Ypc_list, pop_rate_fn=sim_utils.f_miller_troyer, pop_deriv_fn=sim_utils.fprime_miller_troyer, neuron_rate_fn=sim_utils.evaluate_f_mt, W0list=W0list.copy(), bounds=bounds, niter=niter, wt_dict=wt_dict, l2_penalty=l2_penalty, compute_hessian=False, dt=dt, perturbation_size=perturbation_size, dYY=dYY, constrain_isn=constrain_isn, tv=tv)

    # unpacking also checks that exactly 13 parameter groups came back
    Wmx, Wmy, Wsx, Wsy, s02, K, kappa, T, XX, XXp, Eta, Xi, h = Wt[0][0]

    labels = ['Wmx', 'Wmy', 'Wsx', 'Wsy', 's02', 'K', 'kappa', 'T', 'XX', 'XXp', 'Eta', 'Xi', 'h']
    Wstar_dict = {}
    for i, label in enumerate(labels):
        Wstar_dict[label] = Wt[0][0][i]
    Wstar_dict['as_list'] = [Wmx, Wmy, Wsx, Wsy, s02, K, kappa, T, XX, XXp, Eta, Xi, h]
    Wstar_dict['loss'] = loss[0][0]
    Wstar_dict['wt_dict'] = wt_dict
    np.save(weights_file, Wstar_dict, allow_pickle=True)
コード例 #28
0
 def norm_to_mean(r):
     """Flatten all but the first axis and divide each row by its mean.

     The mean is taken only over columns that contain no NaN in any row.

     :param r: numpy array with at least one dimension.
     :return: 2-D array of shape ``(r.shape[0], -1)``.
     """
     flat = r.reshape((r.shape[0], -1))
     complete_cols = np.logical_not(np.isnan(flat.sum(0)))
     row_means = np.nanmean(flat[:, complete_cols], axis=1)
     return flat / row_means[:, None]
コード例 #29
0
ファイル: DDP.py プロジェクト: nik31096/control_algorithms
def boxQP(H, g, lower, upper, x0=None):
    """Minimize ``0.5 * x' H x + x' g`` subject to ``lower <= x <= upper``.

    Projected-Newton solver for box-constrained quadratic programs (a port of
    Tassa's ``boxQP.m`` as used in iLQG/DDP). Each iteration clamps the
    dimensions that sit on a bound with the gradient pointing out of the box,
    takes a Newton step in the remaining (free) dimensions, and runs a
    projected Armijo backtracking line search along that step.

    Args:
        H: (n, n) Hessian. Only its symmetric part enters the objective, so
            it is symmetrized internally; it must be positive definite on the
            free subspace.
        g: (n,) linear term.
        lower: lower bounds, broadcastable to (n,); ``-inf`` means unbounded.
        upper: upper bounds, broadcastable to (n,); ``inf`` means unbounded.
        x0: optional warm start with n entries. If omitted or of another
            size, the midpoint of the finite bounds is used (the single
            finite bound if only one exists, 0 if neither does).

    Returns:
        x: (n,) final iterate.
        result: termination code --
            -1 Hessian not positive definite on the free subspace,
             0 no descent direction found (should not happen),
             1 maximum number of iterations reached,
             2 line search reached the minimum step size,
             4 relative improvement smaller than tolerance,
             5 gradient norm smaller than tolerance,
             6 all dimensions are clamped.
        Hfree: lower-triangular Cholesky factor ``L`` from the most recent
            successful factorization, ``L @ L.T == H[free][:, free]``
            (``np.zeros(n)`` if no factorization succeeded).
        free: (n,) boolean mask of the dimensions that were not clamped.
    """
    H = np.asarray(H, dtype=float)
    n = H.shape[0]
    # The quadratic form only sees the symmetric part of H. Symmetrizing keeps
    # the gradient g + H x and the Cholesky factorization (which reads a single
    # triangle) consistent with the objective being minimized.
    H = 0.5 * (H + H.T)
    g = np.asarray(g, dtype=float).ravel()
    lower = np.broadcast_to(np.asarray(lower, dtype=float), (n,))
    upper = np.broadcast_to(np.asarray(upper, dtype=float), (n,))

    maxIter = 100
    minRelImprove = 1e-8
    minGrad = 1e-8
    stepDec = 0.6
    minStep = 1e-22
    Armijo = 0.1

    def clamp(value):
        """Project ``value`` onto the box [lower, upper]."""
        return np.maximum(lower, np.minimum(upper, value))

    def objective(z):
        """QP objective 0.5 z'Hz + z'g."""
        return np.dot(z, g) + 0.5 * np.dot(z, np.dot(H, z))

    # Initial point: the clamped warm start, or the mean of the finite bounds.
    if x0 is not None and np.size(x0) == n:
        x = clamp(np.asarray(x0, dtype=float).ravel())
    else:
        lu = np.vstack((lower, upper))
        finite = np.isfinite(lu)
        n_finite = finite.sum(axis=0)
        finite_sum = np.where(finite, lu, 0.0).sum(axis=0)
        x = np.divide(finite_sum, n_finite, out=np.zeros(n), where=n_finite > 0)
    x[~np.isfinite(x)] = 0.0

    value = objective(x)
    oldvalue = 0.0
    result = 0
    Hfree = np.zeros(n)
    clamped = np.zeros(n, dtype=bool)
    free = np.ones(n, dtype=bool)

    for iteration in range(maxIter):
        if result != 0:
            break

        # Stop when the previous step barely improved the objective.
        if iteration > 0 and (oldvalue - value) < minRelImprove * abs(oldvalue):
            result = 4
            logging.info("[QP info] Improvement smaller than tolerance")
            break

        oldvalue = value

        grad = g + np.dot(H, x)

        # A dimension is clamped when it sits on a bound and the gradient
        # would push it further out of the box.
        old_clamped = clamped
        clamped = ((x == lower) & (grad > 0)) | ((x == upper) & (grad < 0))
        free = ~clamped

        if np.all(clamped):
            result = 6
            logging.info("[QP info] All dimensions are clamped")
            break

        # Refactorize only when the active set changed.
        if iteration == 0 or np.any(old_clamped != clamped):
            Hff = H[np.ix_(free, free)]
            try:
                Hfree = np.linalg.cholesky(Hff)  # lower L, Hff = L @ L.T
            except np.linalg.LinAlgError:
                result = -1
                logging.info("[QP info] Hessian is not positive definite")
                logging.debug("[QP info] Eigenvalues of free Hessian block: %s",
                              np.linalg.eigvalsh(Hff))
                break

        gnorm = np.linalg.norm(grad[free])
        if gnorm < minGrad:
            result = 5
            logging.info("[QP info] Gradient norm smaller than tolerance")
            break

        # Newton step in the free dimensions with the clamped ones held fixed:
        # x_f* = -Hff^{-1} (g_f + H_fc x_c). With Hff = L L', apply L^{-1}
        # first and then L'^{-1}.
        grad_clamped = g + np.dot(H, np.where(clamped, x, 0.0))
        search = np.zeros(n)
        y = np.linalg.solve(Hfree, grad_clamped[free])
        search[free] = -np.linalg.solve(Hfree.T, y) - x[free]
        sdotg = np.dot(search, grad)
        if sdotg >= 0:
            logging.info("[QP info] No descent direction found. Should not happen. Grad is %s", grad)
            break

        # Projected Armijo backtracking line search.
        step = 1.0
        xc = clamp(x + step * search)
        vc = objective(xc)
        while (vc - oldvalue) / (step * sdotg) < Armijo:
            step *= stepDec
            xc = clamp(x + step * search)
            vc = objective(xc)
            if step < minStep:
                result = 2
                break

        # Accept the candidate.
        x = xc
        value = vc
    else:
        # Every iteration ran without a stopping criterion firing.
        if result == 0:
            result = 1

    return x, result, Hfree, free
コード例 #30
0
def fit_weights_and_save(
        weights_file,
        ca_data_file='rs_vm_denoise_200605.npy',
        opto_silencing_data_file='vip_halo_data_for_sim.npy',
        opto_activation_data_file='vip_chrimson_data_for_sim.npy',
        constrain_wts=None,
        allow_var=True,
        fit_s02=True,
        constrain_isn=True,
        tv=False,
        l2_penalty=0.01,
        init_noise=0.1,
        init_W_from_lsq=False,
        init_W_from_lbfgs=False,
        scale_init_by=1,
        init_W_from_file=False,
        init_file=None,
        correct_Eta=False,
        init_Eta_with_s02=False,
        init_Eta12_with_dYY=False,
        use_opto_transforms=False,
        share_residuals=False,
        stimwise=False,
        simulate1=True,
        simulate2=False,
        help_constrain_isn=True,
        ignore_halo_vip=False,
        verbose=True,
        free_amplitude=False,
        norm_opto_transforms=False,
        zero_extra_weights=None,
        allow_s2=True):

    nsize, ncontrast = 6, 6

    npfile = np.load(ca_data_file, allow_pickle=True)[(
    )]  #,{'rs':rs,'rs_denoise':rs_denoise},allow_pickle=True)
    rs = npfile['rs']
    #rs_denoise = npfile['rs_denoise']

    nsize, ncontrast, ndir = 6, 6, 8
    #ori_dirs = [[0,4],[2,6]] #[[0,4],[1,3,5,7],[2,6]]
    ori_dirs = [[0, 1, 2, 3, 4, 5, 6, 7]]
    nT = len(ori_dirs)
    nS = len(rs[0])

    def sum_to_1(r):
        R = r.reshape((r.shape[0], -1))
        #R = R/np.nansum(R[:,~np.isnan(R.sum(0))],axis=1)[:,np.newaxis]
        R = R / np.nansum(R, axis=1)[:, np.newaxis]  # changed 8/28
        return R

    def norm_to_mean(r):
        R = r.reshape((r.shape[0], -1))
        R = R / np.nanmean(R[:, ~np.isnan(R.sum(0))], axis=1)[:, np.newaxis]
        return R

    Rs = [[None, None] for i in range(len(rs))]
    Rso = [[[None for iT in range(nT)] for iS in range(nS)]
           for icelltype in range(len(rs))]
    rso = [[[None for iT in range(nT)] for iS in range(nS)]
           for icelltype in range(len(rs))]

    for iR, r in enumerate(rs):  #rs_denoise):
        #print(iR)
        for ialign in range(nS):
            #Rs[iR][ialign] = r[ialign][:,:nsize,:]
            #sm = np.nanmean(np.nansum(np.nansum(Rs[iR][ialign],1),1))
            #Rs[iR][ialign] = Rs[iR][ialign]/sm
            #print('frac isnan Rs %d,%d: %f'%(iR,ialign,np.isnan(r[ialign]).mean()))
            Rs[iR][ialign] = sum_to_1(r[ialign][:, :nsize, :])
    #         Rs[iR][ialign] = von_mises_denoise(Rs[iR][ialign].reshape((-1,nsize,ncontrast,ndir)))

    kernel = np.ones((1, 2, 2))
    kernel = kernel / kernel.sum()

    for iR, r in enumerate(rs):
        for ialign in range(nS):
            for iori in range(nT):
                #print('this Rs shape: '+str(Rs[iR][ialign].shape))
                #print('this Rs reshaped shape: '+str(Rs[iR][ialign].reshape((-1,nsize,ncontrast,ndir))[:,:,:,ori_dirs[iori]].shape))
                #print('this Rs max percent nan: '+str(np.isnan(Rs[iR][ialign].reshape((-1,nsize,ncontrast,ndir))[:,:,:,ori_dirs[iori]]).mean(-1).max()))
                Rso[iR][ialign][iori] = np.nanmean(
                    Rs[iR][ialign].reshape(
                        (-1, nsize, ncontrast, ndir))[:, :, :, ori_dirs[iori]],
                    -1)
                Rso[iR][ialign][iori][:, :, 0] = np.nanmean(
                    Rso[iR][ialign][iori][:, :, 0],
                    1)[:, np.newaxis]  # average 0 contrast values
                #print('frac isnan pre-conv Rso %d,%d,%d: %f'%(iR,ialign,iori,np.isnan(Rso[iR][ialign][iori]).mean()))
                Rso[iR][ialign][iori][:, 1:, 1:] = ssi.convolve(
                    Rso[iR][ialign][iori], kernel, 'valid')
                Rso[iR][ialign][iori] = Rso[iR][ialign][iori].reshape(
                    Rso[iR][ialign][iori].shape[0], -1)
                #print('frac isnan Rso %d,%d,%d: %f'%(iR,ialign,iori,np.isnan(Rso[iR][ialign][iori]).mean()))
                #print('sum of Rso isnan: '+str(np.isnan(Rso[iR][ialign][iori]).sum(1)))
                #Rso[iR][ialign][iori] = Rso[iR][ialign][iori]/np.nanmean(Rso[iR][ialign][iori],-1)[:,np.newaxis]

    def set_bound(bd, code, val=0):
        # set bounds to 0 where 0s occur in 'code'
        for iitem in range(len(bd)):
            bd[iitem][code[iitem]] = val

    nN = 36
    nS = 2
    nP = 2
    nT = 1
    nQ = 4

    # code for bounds: 0 , constrained to 0
    # +/-1 , constrained to +/-1
    # 1.5, constrained to [0,1]
    # 2 , constrained to [0,inf)
    # -2 , constrained to (-inf,0]
    # 3 , unconstrained

    Wmx_bounds = 3 * np.ones((nP, nQ), dtype=int)
    Wmx_bounds[0, :] = 2  # L4 PCs are excitatory
    Wmx_bounds[0, 1] = 0  # SSTs don't receive L4 input

    if allow_var:
        Wsx_bounds = 3 * np.ones(
            Wmx_bounds.shape)  #Wmx_bounds.copy()*0 #np.zeros_like(Wmx_bounds)
        Wsx_bounds[0, 1] = 0
    else:
        Wsx_bounds = np.zeros(
            Wmx_bounds.shape)  #Wmx_bounds.copy()*0 #np.zeros_like(Wmx_bounds)

    Wmy_bounds = 3 * np.ones((nQ, nQ), dtype=int)
    Wmy_bounds[0, :] = 2  # PCs are excitatory
    Wmy_bounds[1:, :] = -2  # all the cell types except PCs are inhibitory
    Wmy_bounds[1, 1] = 0  # SSTs don't inhibit themselves
    # Wmy_bounds[3,1] = 0 # PVs are allowed to inhibit SSTs, consistent with Hillel's unpublished results, but not consistent with Pfeffer et al.
    Wmy_bounds[
        2,
        0] = 0  # VIPs don't inhibit L2/3 PCs. According to Pfeffer et al., only L5 PCs were found to get VIP inhibition

    if not zero_extra_weights is None:
        Wmx_bounds[zero_extra_weights[0]] = 0
        Wmy_bounds[zero_extra_weights[1]] = 0

    if allow_var:
        Wsy_bounds = 3 * np.ones(
            Wmy_bounds.shape)  #Wmy_bounds.copy()*0 #np.zeros_like(Wmy_bounds)
        Wsy_bounds[1, 1] = 0
        Wsy_bounds[3, 1] = 0
        Wsy_bounds[2, 0] = 0
    else:
        Wsy_bounds = np.zeros(
            Wmy_bounds.shape)  #Wmy_bounds.copy()*0 #np.zeros_like(Wmy_bounds)

    if not constrain_wts is None:
        for wt in constrain_wts:
            Wmy_bounds[wt[0], wt[1]] = 0
            Wsy_bounds[wt[0], wt[1]] = 0

    def tile_nS_nT_nN(kernel):
        row = np.concatenate([kernel for idim in range(nS * nT)],
                             axis=0)[np.newaxis, :]
        tiled = np.concatenate([row for irow in range(nN)], axis=0)
        return tiled

    def set_bounds_by_code(lb, ub, bdlist):
        set_bound(lb, [bd == 0 for bd in bdlist], val=0)
        set_bound(ub, [bd == 0 for bd in bdlist], val=0)

        set_bound(lb, [bd == 2 for bd in bdlist], val=0)

        set_bound(ub, [bd == -2 for bd in bdlist], val=0)

        set_bound(lb, [bd == 1 for bd in bdlist], val=1)
        set_bound(ub, [bd == 1 for bd in bdlist], val=1)

        set_bound(lb, [bd == 1.5 for bd in bdlist], val=0)
        set_bound(ub, [bd == 1.5 for bd in bdlist], val=1)

        set_bound(lb, [bd == -1 for bd in bdlist], val=-1)
        set_bound(ub, [bd == -1 for bd in bdlist], val=-1)

    if allow_s2:
        if fit_s02:
            s02_bounds = 2 * np.ones(
                (nQ, ))  # permitting noise as a free parameter
        else:
            s02_bounds = np.ones((nQ, ))
    else:
        s02_bounds = np.zeros((nQ, ))

    k_bounds = 1.5 * np.ones((nQ * (nS - 1), ))

    #k_bounds[1] = 0 # temporary: spatial kernel constrained to 0 for SST
    #k_bounds[2] = 0 # temporary: spatial kernel constrained to 0 for VIP

    kappa_bounds = np.ones((1, ))
    # kappa_bounds = 2*np.ones((1,))

    T_bounds = 1.5 * np.ones((nQ * (nT - 1), ))

    X_bounds = tile_nS_nT_nN(np.array([2, 1]))
    # X_bounds = np.array([np.array([2,1,2,1])]*nN)

    Xp_bounds = tile_nS_nT_nN(np.array([3, 1]))
    # Xp_bounds = np.array([np.array([3,1,3,1])]*nN)

    # Y_bounds = tile_nS_nT_nN(2*np.ones((nQ,)))
    # # Y_bounds = 2*np.ones((nN,nT*nS*nQ))

    Eta_bounds = tile_nS_nT_nN(3 * np.ones((nQ, )))
    # Eta_bounds = 3*np.ones((nN,nT*nS*nQ))

    if allow_s2:
        if allow_var:
            Xi_bounds = tile_nS_nT_nN(3 * np.ones((nQ, )))
        else:
            Xi_bounds = tile_nS_nT_nN(np.zeros((nQ, )))
    else:
        Xi_bounds = tile_nS_nT_nN(np.zeros((nQ, )))

    # Xi_bounds = 3*np.ones((nN,nT*nS*nQ))

    h1_bounds = -2 * np.ones((1, ))

    h2_bounds = 2 * np.ones((1, ))

    bl_bounds = 3 * np.ones((nQ, ))

    if free_amplitude:
        amp_bounds = 2 * np.ones((nT * nS * nQ, ))
    else:
        amp_bounds = 1 * np.ones((nT * nS * nQ, ))

    # shapes = [(nP,nQ),(nQ,nQ),(nP,nQ),(nQ,nQ),(nQ,),(nQ,),(1,),(nN,nS*nP),(nN,nS*nQ),(nN,nS*nQ),(nN,nS*nQ)]
    shapes1 = [(nP, nQ), (nQ, nQ), (nP, nQ),
               (nQ, nQ), (nQ, ), (nQ * (nS - 1), ), (1, ), (nQ * (nT - 1), ),
               (1, ), (1, ), (nQ, ), (nQ * nS * nT, )]
    shapes2 = [(nN, nT * nS * nP), (nN, nT * nS * nP), (nN, nT * nS * nQ),
               (nN, nT * nS * nQ)]
    #print('size of shapes1: '+str(np.sum([np.prod(shp) for shp in shapes1])))
    #print('size of shapes2: '+str(np.sum([np.prod(shp) for shp in shapes2])))
    #         Wmx,    Wmy,    Wsx,    Wsy,    s02,  k,    kappa,T,   h1, h2
    #XX,            XXp,          Eta,          Xi

    #bdlist = [Wmx_bounds,Wmy_bounds,Wsx_bounds,Wsy_bounds,s02_bounds,k_bounds,kappa_bounds,T_bounds,X_bounds,Xp_bounds,Eta_bounds,Xi_bounds,h1_bounds,h2_bounds]
    bd1list = [
        Wmx_bounds, Wmy_bounds, Wsx_bounds, Wsy_bounds, s02_bounds, k_bounds,
        kappa_bounds, T_bounds, h1_bounds, h2_bounds, bl_bounds, amp_bounds
    ]
    bd2list = [X_bounds, Xp_bounds, Eta_bounds, Xi_bounds]

    lb1, ub1 = [[sgn * np.inf * np.ones(shp) for shp in shapes1]
                for sgn in [-1, 1]]
    set_bounds_by_code(lb1, ub1, bd1list)
    lb2, ub2 = [[sgn * np.inf * np.ones(shp) for shp in shapes2]
                for sgn in [-1, 1]]
    set_bounds_by_code(lb2, ub2, bd2list)

    #set_bound(lb,[bd==0 for bd in bdlist],val=0)
    #set_bound(ub,[bd==0 for bd in bdlist],val=0)
    #
    #set_bound(lb,[bd==2 for bd in bdlist],val=0)
    #
    #set_bound(ub,[bd==-2 for bd in bdlist],val=0)
    #
    #set_bound(lb,[bd==1 for bd in bdlist],val=1)
    #set_bound(ub,[bd==1 for bd in bdlist],val=1)
    #
    #set_bound(lb,[bd==1.5 for bd in bdlist],val=0)
    #set_bound(ub,[bd==1.5 for bd in bdlist],val=1)
    #
    #set_bound(lb,[bd==-1 for bd in bdlist],val=-1)
    #set_bound(ub,[bd==-1 for bd in bdlist],val=-1)

    # for bd in [lb,ub]:
    #     for ind in [2,3]:
    #         bd[ind][:,1] = 0

    # temporary for no variation expt.
    # lb[2] = np.zeros_like(lb[2])
    # lb[3] = np.zeros_like(lb[3])
    # lb[4] = np.ones_like(lb[4])
    # lb[5] = np.zeros_like(lb[5])
    # ub[2] = np.zeros_like(ub[2])
    # ub[3] = np.zeros_like(ub[3])
    # ub[4] = np.ones_like(ub[4])
    # ub[5] = np.ones_like(ub[5])
    # temporary for no variation expt.
    lb1 = np.concatenate([a.flatten() for a in lb1])
    ub1 = np.concatenate([b.flatten() for b in ub1])
    lb2 = np.concatenate([a.flatten() for a in lb2])
    ub2 = np.concatenate([b.flatten() for b in ub2])
    bounds1 = [(a, b) for a, b in zip(lb1, ub1)]
    bounds2 = [(a, b) for a, b in zip(lb2, ub2)]

    nS = 2
    #print('nT: '+str(nT))
    ndims = 5
    ncelltypes = 5
    Yhat = [[None for iT in range(nT)] for iS in range(nS)]
    Xhat = [[None for iT in range(nT)] for iS in range(nS)]
    Ypc_list = [[None for iT in range(nT)] for iS in range(nS)]
    Xpc_list = [[None for iT in range(nT)] for iS in range(nS)]
    mx = [None for iS in range(nS)]
    for iS in range(nS):
        mx[iS] = np.zeros((ncelltypes, ))
        yy = [None for icelltype in range(ncelltypes)]
        for icelltype in range(ncelltypes):
            yy[icelltype] = np.nanmean(Rso[icelltype][iS][0], 0)
            mx[iS][icelltype] = np.nanmax(yy[icelltype])
        for iT in range(nT):
            y = [
                np.nanmean(Rso[icelltype][iS][iT], axis=0)[:, np.newaxis] /
                mx[iS][icelltype] for icelltype in range(1, ncelltypes)
            ]
            Ypc_list[iS][iT] = [None for icelltype in range(1, ncelltypes)]
            for icelltype in range(1, ncelltypes):
                # as currently written, penalties involving (X,Y)pc_list are effectively artificially smaller by
                # a factor of mx[iS][icelltype] compared to what one would expect from the (X,Y)-penalty as defined
                # subsequently.
                rss = Rso[icelltype][iS][iT].copy(
                )  #/mx[iS][icelltype] #.reshape(Rs[icelltype][ialign].shape[0],-1)
                #print('sum of isnan: '+str(np.isnan(rss).sum(1)))
                #rss = Rso[icelltype][iS][iT].copy() #.reshape(Rs[icelltype][ialign].shape[0],-1)
                rss = rss[np.isnan(rss).sum(1) == 0]
                #         print(rss.max())
                #         rss[rss<0] = 0
                #         rss = rss[np.random.randn(rss.shape[0])>0]
                try:
                    u, s, v = np.linalg.svd(rss - np.mean(rss, 0)[np.newaxis])
                    Ypc_list[iS][iT][icelltype - 1] = [
                        (s[idim], v[idim]) for idim in range(ndims)
                    ]
    #                 print('yep on Y')
    #                 print(np.min(np.sum(rs[icelltype][iS][iT],axis=1)))
                except:
                    print('nope on Y')
                    #print('shape of rss: '+str(rss.shape))
                    #print('mean of rss: '+str(np.mean(np.isnan(rss))))
                    #print('min of this rs: '+str(np.min(np.sum(rs[icelltype][iS][iT],axis=1))))
            Yhat[iS][iT] = np.concatenate(y, axis=1)
            #         x = sim_utils.columnize(Rso[0][iS][iT])[:,np.newaxis]
            icelltype = 0
            #x = np.nanmean(Rso[icelltype][iS][iT],0)[:,np.newaxis]#/mx[iS][icelltype]
            x = np.nanmean(Rso[icelltype][iS][iT],
                           0)[:, np.newaxis] / mx[iS][icelltype]
            #         opto_column = np.concatenate((np.zeros((nN,)),np.zeros((nNO/2,)),np.ones((nNO/2,))),axis=0)[:,np.newaxis]
            Xhat[iS][iT] = np.concatenate((x, np.ones_like(x)), axis=1)
            #         Xhat[iS][iT] = np.concatenate((x,np.ones_like(x),opto_column),axis=1)
            icelltype = 0
            #rss = Rso[icelltype][iS][iT].copy()/mx[iS][icelltype]
            rss = Rso[icelltype][iS][iT].copy()
            rss = rss[np.isnan(rss).sum(1) == 0]
            #         try:
            u, s, v = np.linalg.svd(rss - rss.mean(0)[np.newaxis])
            Xpc_list[iS][iT] = [None for iinput in range(2)]
            Xpc_list[iS][iT][0] = [(s[idim], v[idim]) for idim in range(ndims)]
            Xpc_list[iS][iT][1] = [(0, np.zeros((Xhat[0][0].shape[0], )))
                                   for idim in range(ndims)]
    #         except:
    #             print('nope on X')
    #             print(np.mean(np.isnan(rss)))
    #             print(np.min(np.sum(Rso[icelltype][iS][iT],axis=1)))
    nN, nP = Xhat[0][0].shape
    #print('nP: '+str(nP))
    nQ = Yhat[0][0].shape[1]

    import sim_utils

    pop_rate_fn = sim_utils.f_miller_troyer
    pop_deriv_fn = sim_utils.fprime_miller_troyer

    def compute_f_(Eta, Xi, s02):
        return sim_utils.f_miller_troyer(
            Eta, Xi**2 + np.concatenate([s02 for ipixel in range(nS * nT)]))

    def compute_fprime_m_(Eta, Xi, s02):
        return sim_utils.fprime_miller_troyer(
            Eta, Xi**2 + np.concatenate([s02
                                         for ipixel in range(nS * nT)])) * Xi

    def compute_fprime_s_(Eta, Xi, s02):
        s2 = Xi**2 + np.concatenate((s02, s02), axis=0)
        return sim_utils.fprime_s_miller_troyer(Eta, s2) * (Xi / s2)

    def sorted_r_eigs(w):
        drW, prW = np.linalg.eig(w)
        srtinds = np.argsort(drW)
        return drW[srtinds], prW[:, srtinds]

    #         0.Wmx,  1.Wmy,  2.Wsx,  3.Wsy,  4.s02,5.K,  6.kappa,7.T,8.XX,        9.XXp,        10.Eta,       11.Xi,   12.h1,  13.h2

    shapes1 = [(nP, nQ), (nQ, nQ), (nP, nQ),
               (nQ, nQ), (nQ, ), (nQ * (nS - 1), ), (1, ), (nQ * (nT - 1), ),
               (1, ), (1, ), (nQ, ), (nT * nS * nQ, )]
    shapes2 = [(nN, nT * nS * nP), (nN, nT * nS * nP), (nN, nT * nS * nQ),
               (nN, nT * nS * nQ)]
    #print('size of shapes1: '+str(np.sum([np.prod(shp) for shp in shapes1])))
    #print('size of shapes2: '+str(np.sum([np.prod(shp) for shp in shapes2])))

    import calnet.fitting_spatial_feature

    YYhat = calnet.utils.flatten_nested_list_of_2d_arrays(Yhat)
    XXhat = calnet.utils.flatten_nested_list_of_2d_arrays(Xhat)

    opto_dict = np.load(opto_silencing_data_file, allow_pickle=True)[()]

    Yhat_opto = opto_dict['Yhat_opto']
    Yhat_opto = np.nanmean(np.reshape(Yhat_opto, (nN, 2, nS, 2, nQ)),
                           3).reshape((nN * 2, -1))
    Yhat_opto[0::12] = np.nanmean(Yhat_opto[0::12], axis=0)[np.newaxis]
    Yhat_opto[1::12] = np.nanmean(Yhat_opto[1::12], axis=0)[np.newaxis]
    Yhat_opto = Yhat_opto / np.nanmax(Yhat_opto[0::2], 0)[np.newaxis, :]
    #print(Yhat_opto.shape)
    h_opto = opto_dict['h_opto']
    #dYY1 = Yhat_opto[1::2]-Yhat_opto[0::2]

    YYhat_halo = Yhat_opto.reshape((nN, 2, -1))
    opto_transform1 = calnet.utils.fit_opto_transform(
        YYhat_halo, norm01=norm_opto_transforms)

    opto_transform1.res[:, [0, 2, 3, 4, 6, 7]] = 0

    dYY1 = opto_transform1.transform(YYhat) - opto_transform1.preprocess(YYhat)

    #YYhat_halo_sim = calnet.utils.simulate_opto_effect(YYhat,YYhat_halo)
    #dYY1 = YYhat_halo_sim[:,1,:] - YYhat_halo_sim[:,0,:]

    def overwrite_plus_n(arr, to_overwrite, n):
        arr[:, to_overwrite] = arr[:, int(to_overwrite + n)]
        return arr

    for to_overwrite in [1, 2]:
        n = 4
        dYY1,opto_transform1.slope,opto_transform1.intercept,opto_transform1.res \
                = [overwrite_plus_n(x,to_overwrite,n) for x in \
                        [dYY1,opto_transform1.slope,opto_transform1.intercept,opto_transform1.res]]
    for to_overwrite in [7]:
        n = -4
        dYY1,opto_transform1.slope,opto_transform1.intercept,opto_transform1.res \
                = [overwrite_plus_n(x,to_overwrite,n) for x in \
                        [dYY1,opto_transform1.slope,opto_transform1.intercept,opto_transform1.res]]

    if ignore_halo_vip:
        dYY1[:, 2::nQ] = np.nan

    #for to_overwrite in [1,2]:
    #    dYY1[:,to_overwrite] = dYY1[:,to_overwrite+4]
    #for to_overwrite in [7]:
    #    dYY1[:,to_overwrite] = dYY1[:,to_overwrite-4]

    #Yhat_opto = opto_dict['Yhat_opto']
    #for iS in range(nS):
    #    mx = np.zeros((nQ,))
    #    for iQ in range(nQ):
    #        slicer = slice(nQ*nT*iS+iQ,nQ*nT*(1+iS),nQ)
    #        mx[iQ] = np.nanmax(Yhat_opto[0::2][:,slicer])
    #        Yhat_opto[:,slicer] = Yhat_opto[:,slicer]/mx[iQ]
    ##Yhat_opto = Yhat_opto/Yhat_opto[0::2].max(0)[np.newaxis,:]
    #print(Yhat_opto.shape)
    #h_opto = opto_dict['h_opto']
    #dYY1 = Yhat_opto[1::2]-Yhat_opto[0::2]
    #for to_overwrite in [1,2,5,6]: # overwrite sst and vip with off-centered values
    #    dYY1[:,to_overwrite] = dYY1[:,to_overwrite+8]
    #for to_overwrite in [11,15]:
    #    dYY1[:,to_overwrite] = np.nan #dYY1[:,to_overwrite-8]

    opto_dict = np.load(opto_activation_data_file, allow_pickle=True)[()]

    Yhat_opto = opto_dict['Yhat_opto']
    Yhat_opto = np.nanmean(np.reshape(Yhat_opto, (nN, 2, nS, 2, nQ)),
                           3).reshape((nN * 2, -1))
    Yhat_opto[0::12] = np.nanmean(Yhat_opto[0::12], axis=0)[np.newaxis]
    Yhat_opto[1::12] = np.nanmean(Yhat_opto[1::12], axis=0)[np.newaxis]
    Yhat_opto = Yhat_opto / Yhat_opto[0::2].max(0)[np.newaxis, :]
    #print(Yhat_opto.shape)
    h_opto = opto_dict['h_opto']
    #dYY2 = Yhat_opto[1::2]-Yhat_opto[0::2]

    YYhat_chrimson = Yhat_opto.reshape((nN, 2, -1))
    opto_transform2 = calnet.utils.fit_opto_transform(
        YYhat_chrimson, norm01=norm_opto_transforms)
    dYY2 = opto_transform2.transform(YYhat) - opto_transform2.preprocess(YYhat)
    #YYhat_chrimson_sim = calnet.utils.simulate_opto_effect(YYhat,YYhat_chrimson)
    #dYY2 = YYhat_chrimson_sim[:,1,:] - YYhat_chrimson_sim[:,0,:]

    #Yhat_opto = opto_dict['Yhat_opto']
    #for iS in range(nS):
    #    mx = np.zeros((nQ,))
    #    for iQ in range(nQ):
    #        slicer = slice(nQ*nT*iS+iQ,nQ*nT*(1+iS),nQ)
    #        mx[iQ] = np.nanmax(Yhat_opto[0::2][:,slicer])
    #        Yhat_opto[:,slicer] = Yhat_opto[:,slicer]/mx[iQ]
    ##Yhat_opto = Yhat_opto/Yhat_opto[0::2].max(0)[np.newaxis,:]
    #print(Yhat_opto.shape)
    #h_opto = opto_dict['h_opto']
    #dYY2 = Yhat_opto[1::2]-Yhat_opto[0::2]

    #print('dYY1 mean: %03f'%np.nanmean(np.abs(dYY1)))
    #print('dYY2 mean: %03f'%np.nanmean(np.abs(dYY2)))

    dYY = np.concatenate((dYY1, dYY2), axis=0)

    #titles = ['VIP silencing','VIP activation']
    #for itype in [0,1,2,3]:
    #    plt.figure(figsize=(5,2.5))
    #    for iyy,dyy in enumerate([dYY1,dYY2]):
    #        plt.subplot(1,2,iyy+1)
    #        if np.sum(np.isnan(dyy[:,itype]))==0:
    #            sca.scatter_size_contrast(YYhat[:,itype],YYhat[:,itype]+dyy[:,itype],nsize=6,ncontrast=6)#,mn=0)
    #        plt.title(titles[iyy])
    #        plt.xlabel('cell type %d event rate, \n light off'%itype)
    #        plt.ylabel('cell type %d event rate, \n light on'%itype)
    #        ut.erase_top_right()
    #    plt.tight_layout()
    #    ut.mkdir('figures')
    #    plt.savefig('figures/scatter_light_on_light_off_target_celltype_%d.eps'%itype)

    opto_mask = ~np.isnan(dYY)

    #dYY[nN:][~opto_mask[nN:]] = -dYY[:nN][~opto_mask[nN:]]

    #print('mean of opto_mask: '+str(opto_mask.mean()))

    #dYY[~opto_mask] = 0
    def zero_nans(arr):
        arr[np.isnan(arr)] = 0
        return arr

    #dYY,opto_transform1.slope,opto_transform1.intercept,opto_transform1.res,\
    #        opto_transform2.slope,opto_transform2.intercept,opto_transform2.res\
    #        = [zero_nans(x) for x in \
    #                [dYY,opto_transform1.slope,opto_transform1.intercept,opto_transform1.res,\
    #                opto_transform2.slope,opto_transform2.intercept,opto_transform2.res]]
    dYY = zero_nans(dYY)

    to_adjust = np.logical_or(np.isnan(opto_transform2.slope[0]),
                              np.isnan(opto_transform2.intercept[0]))

    opto_transform2.slope[:,
                          to_adjust] = 1 / opto_transform1.slope[:, to_adjust]
    opto_transform2.intercept[:,
                              to_adjust] = -opto_transform1.intercept[:,
                                                                      to_adjust] / opto_transform1.slope[:,
                                                                                                         to_adjust]
    opto_transform2.res[:,
                        to_adjust] = -opto_transform1.res[:,
                                                          to_adjust] / opto_transform1.slope[:,
                                                                                             to_adjust]

    #np.save('/Users/dan/Documents/notebooks/mossing-PC/shared_data/calnet_data/dYY.npy',dYY)

    from importlib import reload
    reload(calnet)
    #reload(calnet.fitting_2step_spatial_feature_opto_tight_nonlinear)
    reload(sim_utils)
    # reload(calnet.fitting_spatial_feature)
    # W0list = [np.ones(shp) for shp in shapes]
    wt_dict = {}
    wt_dict['X'] = 3  #1
    wt_dict['Y'] = 3
    #wt_dict['Eta'] = 3 # 1 #
    wt_dict['Xi'] = 0.1
    wt_dict['stims'] = np.ones((nN, 1))  #(np.arange(30)/30)[:,np.newaxis]**1 #
    wt_dict['barrier'] = 0.  #30.0 #0.1
    wt_dict['opto'] = 1  #1e1
    wt_dict['isn'] = 0.3
    wt_dict['tv'] = 1
    spont_frac = 0.5
    pc_frac = 0.5
    wt_dict['stimsOpto'] = (1 - spont_frac) * 6 / 5 * np.ones((nN, 1))
    wt_dict['stimsOpto'][0::6] = spont_frac * 6
    wt_dict['celltypesOpto'] = (1 - pc_frac) * 4 / 3 * np.ones(
        (1, nQ * nS * nT))
    wt_dict['celltypesOpto'][0, 0::nQ] = pc_frac * 4
    wt_dict['dirOpto'] = np.array((1, 0.3))
    wt_dict['dYY'] = 10  #10
    wt_dict['coupling'] = 1e-3
    wt_dict['smi'] = 0.1
    wt_dict['smi_halo'] = 30
    wt_dict['smi_chrimson'] = 0.1

    ##temporary no_opto
    wt_dict['opto'] = 0
    wt_dict['dirOpto'] = np.array((1, 1))
    #wt_dict['stimsOpto'] = np.ones((nN,1))
    wt_dict['celltypesOpto'] = np.ones((1, nQ * nS * nT))
    wt_dict['smi'] = 0  #0.01 # 0
    wt_dict['smi_halo'] = 0  #1 # 0
    wt_dict['smi_chrimson'] = 0  #0.01 # 0
    wt_dict['isn'] = 0.1
    wt_dict['tv'] = 0.1
    wt_dict['X'] = 3
    wt_dict['Eta'] = 10  #3 # 1 #

    ## temporary opto from no_opto
    #wt_dict['opto'] = 0.01
    #wt_dict['tv'] = 0.3#0.1

    np.save(
        'XXYYhat.npy', {
            'YYhat': YYhat,
            'XXhat': XXhat,
            'rs': rs,
            'Rs': Rs,
            'Rso': Rso,
            'Ypc_list': Ypc_list,
            'Xpc_list': Xpc_list
        })
    if allow_s2:
        Eta0 = invert_f_mt(YYhat)
    else:
        Eta0 = invert_f_mt(YYhat, s02=0)

    #         Wmx,    Wmy,    Wsx,    Wsy,    s02,  k,    kappa,T,   h1, h2
    #XX,            XXp,          Eta,          Xi

    opt = fmc.gen_opt(nS=nS, nT=nT)
    opt['allow_s02'] = False
    opt['allow_A'] = False
    opt['allow_B'] = True

    ntries = 1
    nhyper = 1
    dt = 1e-1
    niter = int(np.round(10 / dt))  #int(1e4)
    perturbation_size = 5e-2
    # learning_rate = 1e-4 # 1e-5 #np.linspace(3e-4,1e-3,niter+1) # 1e-5
    #l2_penalty = 0.1
    W1t = [[None for itry in range(ntries)] for ihyper in range(nhyper)]
    W2t = [[None for itry in range(ntries)] for ihyper in range(nhyper)]
    loss = np.zeros((nhyper, ntries))
    is_neg = np.array([b[1] for b in bounds1]) == 0
    counter = 0
    negatize = [np.zeros(shp, dtype='bool') for shp in shapes1]
    #print(shapes1)
    for ishp, shp in enumerate(shapes1):
        nel = np.prod(shp)
        negatize[ishp][:][is_neg[counter:counter + nel].reshape(shp)] = True
        counter = counter + nel
    for ihyper in range(nhyper):
        for itry in range(ntries):
            #print((ihyper,itry))
            #[0.(nP,nQ),1.(nQ,nQ),2.(nP,nQ),3.(nQ,nQ),4.(nQ,),5.(nQ*(nS-1),),6.(1,),7.(nQ*(nT-1),),8.(1,),9.(1,),10.(nQ,),11.(nQ*nS*nT,)]
            W10list = [
                init_noise * (ihyper + 1) * np.random.rand(*shp)
                for shp in shapes1
            ]
            W20list = [
                init_noise * (ihyper + 1) * np.random.rand(*shp)
                for shp in shapes2
            ]
            #print('size of shapes1: '+str(np.sum([np.prod(shp) for shp in shapes1])))
            #print('size of w10: '+str(np.sum([np.size(x) for x in W10list])))
            #print('len(W10list) : '+str(len(W10list)))
            counter = 0
            for ishp, shp in enumerate(shapes1):
                W10list[ishp][negatize[ishp]] = -W10list[ishp][negatize[ishp]]
            W10list[4] = np.ones(shapes1[4])  # s02
            W10list[5] = np.ones(shapes1[5])  # K
            W10list[6] = np.ones(shapes1[6])  # kappa
            W10list[7] = np.ones(shapes1[7])  # T
            W10list[8] = np.zeros(shapes1[8])  # h1
            W10list[9] = np.zeros(shapes1[9])  # h2
            W10list[10] = np.zeros(shapes1[10])  # baseline
            W10list[11] = np.ones(shapes1[11])  # amplitude
            W20list[0] = np.concatenate(Xhat, axis=1)  #XX
            W20list[1] = np.zeros_like(W20list[1])  #XXp
            W20list[2] = Eta0.copy()  #np.zeros(shapes[10]) #Eta
            W20list[3] = np.zeros(shapes2[3])  #Xi
            #[Wmx,Wmy,Wsx,Wsy,s02,k,kappa,T,XX,XXp,Eta,Xi]
            if init_W_from_lsq:
                W10list[0], W10list[1] = initialize_W(Xhat,
                                                      Yhat,
                                                      scale_by=scale_init_by,
                                                      allow_s2=allow_s2)
                for ivar in range(0, 2):
                    W10list[
                        ivar] = W10list[ivar] + init_noise * np.random.randn(
                            *W10list[ivar].shape)
            if init_W_from_lbfgs:
                print(opt)
                opt_param, result, _, _, _, _, _, _, _, _, _, _, _ = fmc.initialize_params(
                    XXhat, YYhat, opt, wpcpc=5, wpvpv=-6)
                these_shapes = [(nP, nQ), (nQ, nQ), (nQ, ), (nQ, ), (nQ, ),
                                (nQ, )]
                Wmx0, Wmy0, K0, s020, amplitude0, baseline0 = calnet.utils.parse_thing(
                    opt_param, these_shapes)
                if init_Eta_with_s02:
                    #assert(True==False)
                    Eta0 = invert_f_mt_with_s02(YYhat -
                                                np.tile(baseline0, nS * nT),
                                                s020,
                                                nS=nS,
                                                nT=nT)
                    W20list[2] = Eta0.copy()
                #Wmx0 = opt_param[:nP]
                #Wmy0 = opt_param[nP:nP+nQ]
                #K0 = opt_param[nP+nQ]
                #s020 = opt_param[nP+nQ+1]
                #amplitude0 = opt_param[nP+nQ+2]
                #baseline0 = opt_param[nP+nQ+3]
                print((Wmx0, Wmy0, K0, s020, np.tile(amplitude0,
                                                     2), baseline0))
                W10list[0], W10list[1], W10list[5], W10list[4], W10list[
                    -1], W10list[-2] = Wmx0, Wmy0, K0, s020, np.tile(
                        amplitude0, 2), baseline0
                for ivar in range(0, 2):
                    W10list[
                        ivar] = W10list[ivar] + init_noise * np.random.randn(
                            *W10list[ivar].shape)
            elif constrain_isn:
                W10list[1][0, 0] = 3
                if help_constrain_isn:
                    W10list[1][0, 3] = 5
                    W10list[1][3, 0] = -5
                    W10list[1][3, 3] = -5
                else:
                    W10list[1][0, 1:4] = 5
                    W10list[1][1:4, 0] = -5

            if init_W_from_file:
                npyfile = np.load(init_file, allow_pickle=True)[()]

                #Wmx,Wmy,Wsx,Wsy,s02,K,kappa,T,h1,h2,bl,amp = parse_W1(W1)
                #XX,XXp,Eta,Xi = parse_W2(W2)
                #Wmx,Wmy,Wsx,Wsy,s02,K,kappa,T,XX,XXp,Eta,Xi,h1,h2,bl,amp = parse_W1(W1)
                W10list = [
                    npyfile['as_list'][ivar]
                    for ivar in [0, 1, 2, 3, 4, 5, 6, 7, 12, 13, 14, 15]
                ]
                W20list = [npyfile['as_list'][ivar] for ivar in [8, 9, 10, 11]]
                if W20list[0].size == nN * nS * 2 * nP:
                    #assert(True==False)
                    W10list[7] = np.array(())
                    W10list[1][1, 0] = W10list[1][1, 0]
                    W20list[0] = np.nanmean(
                        W20list[0].reshape((nN, nS, 2, nP)), 2).flatten()  #XX
                    W20list[1] = np.nanmean(
                        W20list[1].reshape((nN, nS, 2, nP)), 2).flatten()  #XXp
                    W20list[2] = np.nanmean(
                        W20list[2].reshape((nN, nS, 2, nQ)), 2).flatten()  #Eta
                    W20list[3] = np.nanmean(
                        W20list[3].reshape((nN, nS, 2, nQ)), 2).flatten()  #Xi
                if correct_Eta:
                    #assert(True==False)
                    W20list[2] = Eta0.copy()
                if len(W10list) < len(shapes1):
                    #assert(True==False)
                    W10list = W10list + [
                        np.array(1),
                        np.zeros((nQ, )),
                        np.zeros((nT * nS * nQ, ))
                    ]  # add h2, bl, amp
                if init_Eta_with_s02:
                    #assert(True==False)
                    s02 = W10list[4].copy()
                    Eta0 = invert_f_mt_with_s02(YYhat, s02, nS=nS, nT=nT)
                    W20list[2] = Eta0.copy()
                #if init_Eta12_with_dYY:
                #    Eta0 = W20list[2].copy().reshape((nN,nQ*nS*nT))
                #    Xi0 = W20list[3].copy().reshape((nN,nQ*nS*nT))
                #    s020 = W10list[4].copy()
                #    YY0s = compute_f_(Eta0,Xi0,s020)
                #titles = ['VIP silencing','VIP activation']
                #for itype in [0,1,2,3]:
                #    plt.figure(figsize=(5,2.5))
                #    for iyy,yy in enumerate([YY10s,YY20s]):
                #        plt.subplot(1,2,iyy+1)
                #        if np.sum(np.isnan(yy[:,itype]))==0:
                #            sca.scatter_size_contrast(YY0s[:,itype],yy[:,itype],nsize=6,ncontrast=6)#,mn=0)
                #        plt.title(titles[iyy])
                #        plt.xlabel('cell type %d event rate, \n light off'%itype)
                #        plt.ylabel('cell type %d event rate, \n light on'%itype)
                #        ut.erase_top_right()
                #    plt.tight_layout()
                #    ut.mkdir('figures')
                #    plt.savefig('figures/scatter_light_on_light_off_init_celltype_%d.eps'%itype)
                for ivar in [0, 1, 4, 5]:  # Wmx, Wmy, s02, k
                    print(init_noise)
                    W10list[
                        ivar] = W10list[ivar] + init_noise * np.random.randn(
                            *W10list[ivar].shape)

            #print('size of bounds1: '+str(np.sum([np.size(x) for x in bd1list])))
            #print('size of w10: '+str(np.sum([np.size(x) for x in W10list])))
            #print('size of shapes1: '+str(np.sum([np.prod(shp) for shp in shapes1])))
            W1t[ihyper][itry], W2t[ihyper][itry], loss[ihyper][
                itry], gr, hess, result = calnet.fitting_2step_spatial_feature_opto_tight_nonlinear_baseline.fit_W_sim(
                    Xhat,
                    Xpc_list,
                    Yhat,
                    Ypc_list,
                    pop_rate_fn=pop_rate_fn,
                    pop_deriv_fn=pop_deriv_fn,
                    W10list=W10list.copy(),
                    W20list=W20list.copy(),
                    bounds1=bounds1,
                    bounds2=bounds2,
                    niter=niter,
                    wt_dict=wt_dict,
                    l2_penalty=l2_penalty,
                    compute_hessian=False,
                    dt=dt,
                    perturbation_size=perturbation_size,
                    dYY=dYY,
                    constrain_isn=constrain_isn,
                    tv=tv,
                    opto_mask=opto_mask,
                    use_opto_transforms=use_opto_transforms,
                    opto_transform1=opto_transform1,
                    opto_transform2=opto_transform2,
                    share_residuals=share_residuals,
                    stimwise=stimwise,
                    simulate1=simulate1,
                    simulate2=simulate2,
                    verbose=verbose)

    #def parse_W(W):
    #    Wmx,Wmy,Wsx,Wsy,s02,K,kappa,T,XX,XXp,Eta,Xi,h1,h2 = W
    #    return Wmx,Wmy,Wsx,Wsy,s02,K,kappa,T,XX,XXp,Eta,Xi,h1,h2
    def parse_W1(W):
        """Unpack the first-stage parameter list ``W`` into its 12 named parts.

        The order is: Wmx, Wmy, Wsx, Wsy, s02, K, kappa, T, h1, h2, bl
        (baseline), amp (amplitude). It matches the ``labels1`` list used
        when saving the fit.

        Returns:
            A 12-tuple with the entries of ``W`` in that order.

        Raises:
            ValueError: from tuple unpacking, if ``W`` does not contain
            exactly 12 entries.
        """
        (w_mx, w_my, w_sx, w_sy, w_s02, w_k,
         w_kappa, w_t, w_h1, w_h2, w_bl, w_amp) = W
        return (w_mx, w_my, w_sx, w_sy, w_s02, w_k,
                w_kappa, w_t, w_h1, w_h2, w_bl, w_amp)

    def parse_W2(W):
        """Unpack the second-stage parameter list ``W`` into its 4 named parts.

        The order is: XX, XXp, Eta, Xi. It matches the ``labels2`` list used
        when saving the fit.

        Returns:
            A 4-tuple with the entries of ``W`` in that order.

        Raises:
            ValueError: from tuple unpacking, if ``W`` does not contain
            exactly 4 entries.
        """
        (w_xx, w_xxp, w_eta, w_xi) = W
        return (w_xx, w_xxp, w_eta, w_xi)

    itry = 0
    Wmx, Wmy, Wsx, Wsy, s02, K, kappa, T, h1, h2, bl, amp = parse_W1(W1t[0][0])
    XX, XXp, Eta, Xi = parse_W2(W2t[0][0])

    labels1 = [
        'Wmx', 'Wmy', 'Wsx', 'Wsy', 's02', 'K', 'kappa', 'T', 'h1', 'h2', 'bl',
        'amp'
    ]
    labels2 = ['XX', 'XXp', 'Eta', 'Xi']
    Wstar_dict = {}
    for i, label in enumerate(labels1):
        Wstar_dict[label] = W1t[0][0][i]
    for i, label in enumerate(labels2):
        Wstar_dict[label] = W2t[0][0][i]
    Wstar_dict['as_list'] = [
        Wmx, Wmy, Wsx, Wsy, s02, K, kappa, T, XX, XXp, Eta, Xi, h1, h2, bl, amp
    ]
    Wstar_dict['loss'] = loss[0][0]
    Wstar_dict['wt_dict'] = wt_dict
    np.save(weights_file, Wstar_dict, allow_pickle=True)
コード例 #31
0
def get_rotation_gp(t, y, yerr, period, min_period, max_period):
    """Build a celerite GP from a basic kernel plus a rotation kernel.

    The kernel is the sum of ``get_basic_kernel(t, y, yerr, period)`` and
    ``get_rotation_kernel(t, y, yerr, period, min_period, max_period)``.
    The GP mean is the NaN-ignoring mean of ``y``.

    Args:
        t: Sample times, forwarded to both kernel builders and to
            ``GP.compute``.
        y: Observed values, forwarded to the kernel builders and used for
            the mean.
        yerr: Uncertainties, forwarded only to the kernel builders.
        period, min_period, max_period: Forwarded to the kernel builders
            (only ``period`` goes to the basic kernel).

    Returns:
        A ``celerite.GP`` on which ``compute(t)`` has already been called.
    """
    # The basic term is built before the rotation term, preserving call order.
    basic_term = get_basic_kernel(t, y, yerr, period)
    rotation_term = get_rotation_kernel(t, y, yerr, period, min_period, max_period)
    combined_kernel = basic_term + rotation_term

    mean_level = np.nanmean(y)
    process = celerite.GP(kernel=combined_kernel, mean=mean_level)
    # NOTE(review): yerr is not passed to compute(). This presumably relies on
    # the kernel (e.g. a jitter term from get_basic_kernel) to model the
    # measurement noise. Confirm against get_basic_kernel.
    process.compute(t)
    return process