Example #1
    def initialize_to_background_rate(self):
        # self.w = abs(1e-6 * np.random.randn(*self.w.shape))
        self.w = 1e-6 * np.ones_like(self.w)
        if len(self.data_list) > 0:
            N = 0
            T = 0
            for F,S in self.data_list:
                N += S.sum(axis=0)
                T += S.shape[0] * self.dt

            lambda0 = self.invlink(N / float(T))
            self.w[0] = lambda0
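# A standalone sketch of the same background-rate computation, assuming Poisson
# counts and an exponential link so that invlink is simply np.log (the dt,
# spike-count array, and link below are hypothetical stand-ins, not the class's):
import numpy as np

dt = 0.01                                    # bin width in seconds (assumed)
S = np.random.poisson(0.5, size=(1000, 3))   # spike counts: time bins x neurons
N = S.sum(axis=0)                            # total events per neuron
T = S.shape[0] * dt                          # total observation time
lambda0 = np.log(N / float(T))               # invlink of the empirical rate
# lambda0 is what the method stores in w[0] as the background-rate row.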
    def __call__(self, t, *args, **kwargs):
        mu = self.get_params(**kwargs)['mu']
        return np.ones_like(t) * mu
Example #3
    def __init__(
            self,
            x0,
            max_iters=None,
            min_grad_norm=None,
            a0=None,
            step_method=None,
            v0=None,
            mass=None,
            delta=None,  # momentum, nesterov, relativistic
            avg_sq_grad0=None,
            gamma=None,
            eps=None,  # rmsprop
            b1=None,
            b2=None,
            mean0=None,
            variance0=None,  # adam
            verbose=None,
            verbose_start=None,
            verbose_stride=None):
        super().__init__(x0, max_iters, min_grad_norm, verbose, verbose_start,
                         verbose_stride)

        if a0 is None:
            a0 = 1.0
        self.a0 = a0

        if step_method is None:
            step_method = 'gradient'
        self.step_method = step_method

        if v0 is None:
            v0 = np.zeros_like(x0)
        self.v0 = v0

        if mass is None:
            mass = 0.0
        self.mass = mass

        if delta is None:
            delta = 10.0
        self.delta = delta

        if avg_sq_grad0 is None:
            avg_sq_grad0 = np.ones_like(x0)
        self.avg_sq_grad0 = avg_sq_grad0

        if gamma is None:
            gamma = 0.9
        self.gamma = gamma

        if eps is None:
            eps = 1e-8
        self.eps = eps

        if b1 is None:
            b1 = 0.9
        self.b1 = b1

        if b2 is None:
            b2 = 0.999
        self.b2 = b2

        if mean0 is None:
            mean0 = np.zeros_like(x0)
        self.mean0 = mean0

        if variance0 is None:
            variance0 = np.zeros_like(x0)
        self.variance0 = variance0
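# The Adam-related defaults above (a0, b1, b2, eps, mean0, variance0) are the
# standard Adam hyperparameters. A minimal sketch of how such a step is
# conventionally computed (not necessarily this class's actual step method):
import numpy as np

def adam_step(x, grad, mean, variance, t, a0=1.0, b1=0.9, b2=0.999, eps=1e-8):
    mean = b1 * mean + (1 - b1) * grad              # first-moment estimate
    variance = b2 * variance + (1 - b2) * grad**2   # second-moment estimate
    mhat = mean / (1 - b1**(t + 1))                 # bias correction
    vhat = variance / (1 - b2**(t + 1))
    x = x - a0 * mhat / (np.sqrt(vhat) + eps)       # parameter update
    return x, mean, variance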
plt.ylabel("Log Probability")
plt.legend(loc="lower right")

# In[15]:

# Plot the observation distributions
from hips.plotting.colormaps import white_to_color_cmap
xmins = x.min(axis=0)
xmaxs = x.max(axis=0)
npts = 100
XX, YY = np.meshgrid(np.linspace(xmins[0], xmaxs[0], npts),
                     np.linspace(xmins[1], xmaxs[1], npts))

data = np.column_stack((XX.ravel(), YY.ravel(), np.zeros((npts**2, D - 2))))
input = np.zeros((data.shape[0], 0))
mask = np.ones_like(data, dtype=bool)
tag = None
lls = hmm.observations.log_likelihoods(data, input, mask, tag)

plt.figure(figsize=(6, 6))
for k in range(K):
    plt.contour(XX,
                YY,
                np.exp(lls[:, k]).reshape(XX.shape),
                cmap=white_to_color_cmap(colors[k % len(colors)]))
    plt.plot(x[z == k, 0], x[z == k, 1], 'o', mfc=colors[k], mec='none', ms=4)

plt.plot(x[:, 0], x[:, 1], '-k', lw=2, alpha=.5)
plt.xlabel("$x_1$")
plt.ylabel("$x_2$")
plt.title("Observation Distributions")
Example #5
def factor_analysis_with_imputation(D, datas, masks=None, num_iters=50):
    datas = [datas] if not isinstance(datas, (list, tuple)) else datas
    if masks is not None:
        masks = [masks] if not isinstance(masks, (list, tuple)) else masks
        assert np.all([m.shape == d.shape for d, m in zip(datas, masks)])
    else:
        masks = [np.ones_like(data, dtype=bool) for data in datas]
    N = datas[0].shape[1]

    # Make the factor analysis model
    from pybasicbayes.models import FactorAnalysis
    fa = FactorAnalysis(N, D, alpha_0=1e-3, beta_0=1e-3)
    fa.regression.sigmasq_flat = np.ones(N)
    for data, mask in zip(datas, masks):
        fa.add_data(data, mask=mask)
    fa.set_empirical_mean()

    # Fit with EM
    lls = [fa.log_likelihood()]
    pbar = trange(num_iters)
    pbar.set_description("Itr {} LP: {:.1f}".format(0, lls[-1]))
    for itr in pbar:
        fa.EM_step()
        lls.append(fa.log_likelihood())

        pbar.set_description("Itr {} LP: {:.1f}".format(itr, lls[-1]))
        pbar.update(1)
    lls = np.array(lls)

    # Get the continuous states and their covariances
    E_xs = [states.E_Z for states in fa.data_list]
    E_xxTs = [states.E_ZZT for states in fa.data_list]
    Cov_xs = [
        E_xxT - E_x[:, :, None] * E_x[:, None, :]
        for E_x, E_xxT in zip(E_xs, E_xxTs)
    ]

    # Rotate the states with SVD so that the columns of the
    # emission matrix, C, are orthogonal and sorted in order
    # of decreasing explained variance.
    #
    # Note: the columns of C are not normalized like in PCA!
    # This is because factor analysis assumes the latents are
    # distributed according to a standard normal distribution.
    # The FA latents are only invariant to *orthogonal* transforms.
    # Thus, the scaling must be accounted for in C.
    C, S, VT = np.linalg.svd(fa.W, full_matrices=False)
    xhats = [x.dot(VT.T) for x in E_xs]
    Cov_xhats = [
        np.matmul(np.matmul(VT[None, :, :], Cov_x), VT.T[None, :, :])
        for Cov_x in Cov_xs
    ]

    # Test that we got this right
    for x, xhat in zip(E_xs, xhats):
        y = x.dot(fa.W.T) + fa.mean
        yhat = xhat.dot((C * S).T) + fa.mean
        assert np.allclose(y, yhat)

    # Strip out the data from the factor analysis model,
    # update the emission matrix
    fa.regression.A = C * S
    fa.data_list = []

    return fa, xhats, Cov_xhats, lls
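# Standalone check of the SVD rotation described in the comments above: with
# W = U S V^T, rotating the latents by V and replacing the emission matrix by
# U*S leaves the reconstruction x @ W.T unchanged (illustrative values only).
import numpy as np

rng = np.random.default_rng(0)
W = rng.standard_normal((10, 3))    # emission matrix, N x D
x = rng.standard_normal((50, 3))    # latent states, T x D
U, S, VT = np.linalg.svd(W, full_matrices=False)
xhat = x @ VT.T                     # rotated latents
assert np.allclose(x @ W.T, xhat @ (U * S).T)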
Example #6
    def log_likelihoods(self, data, input, mask, tag):
        mus, sigmas = self.mus, np.exp(self.inv_sigmas)
        mask = np.ones_like(data, dtype=bool) if mask is None else mask
        return -0.5 * np.sum(
            (np.log(2 * np.pi * sigmas) + (data[:, None, :] - mus)**2 / sigmas)
            * mask[:, None, :], axis=2)
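# Quick standalone check of the diagonal-Gaussian log-likelihood above against
# scipy, with an all-ones mask (shapes: data (T, D), mus and sigmas (K, D)):
import numpy as np
from scipy.stats import norm

T, K, D = 5, 3, 2
data = np.random.randn(T, D)
mus = np.random.randn(K, D)
sigmas = np.exp(np.random.randn(K, D))   # per-state variances
lls = -0.5 * np.sum(
    np.log(2 * np.pi * sigmas) + (data[:, None, :] - mus)**2 / sigmas, axis=2)
ref = norm.logpdf(data[:, None, :], mus, np.sqrt(sigmas)).sum(axis=2)
assert np.allclose(lls, ref)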
Example #7
import numpy as np

def normalize(arr):
    # Rescale each row to sum to 1; rows whose sum is not positive fall back
    # to a uniform distribution over the columns.
    arrsum = arr.sum(1)
    well_behaved = (arrsum > 0)[:, np.newaxis]
    # divide by a safe denominator so all-zero rows do not produce 0/0 = NaN
    safe_sum = np.where(well_behaved, arrsum[:, np.newaxis], 1.0)
    arrnorm = well_behaved * arr / safe_sum + (
        ~well_behaved) * np.ones_like(arr) / arr.shape[1]
    return arrnorm
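# Illustrative use of normalize: ordinary rows are rescaled to sum to 1, and
# rows that do not sum to a positive value become uniform over the columns.
arr = np.array([[2., 2., 0.],
                [0., 0., 0.]])
# normalize(arr) ->
# [[0.5,    0.5,    0.   ],
#  [0.3333, 0.3333, 0.3333]]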
Example #8
def fit_weights_and_save(
        weights_file,
        ca_data_file='rs_vm_denoise_200605.npy',
        opto_silencing_data_file='vip_halo_data_for_sim.npy',
        opto_activation_data_file='vip_chrimson_data_for_sim.npy',
        constrain_wts=None,
        allow_var=True,
        fit_s02=True,
        constrain_isn=True,
        tv=False,
        l2_penalty=0.01,
        init_noise=0.1,
        init_W_from_lsq=False,
        init_W_from_lbfgs=False,
        scale_init_by=1,
        init_W_from_file=False,
        init_file=None,
        correct_Eta=False,
        init_Eta_with_s02=False,
        init_Eta12_with_dYY=False,
        use_opto_transforms=False,
        share_residuals=False,
        stimwise=False,
        simulate1=True,
        simulate2=False,
        help_constrain_isn=True,
        ignore_halo_vip=False,
        verbose=True,
        free_amplitude=False,
        norm_opto_transforms=False,
        zero_extra_weights=None,
        allow_s2=True):

    nsize, ncontrast = 6, 6

    npfile = np.load(ca_data_file, allow_pickle=True)[(
    )]  #,{'rs':rs,'rs_denoise':rs_denoise},allow_pickle=True)
    rs = npfile['rs']
    #rs_denoise = npfile['rs_denoise']

    nsize, ncontrast, ndir = 6, 6, 8
    #ori_dirs = [[0,4],[2,6]] #[[0,4],[1,3,5,7],[2,6]]
    ori_dirs = [[0, 1, 2, 3, 4, 5, 6, 7]]
    nT = len(ori_dirs)
    nS = len(rs[0])

    def sum_to_1(r):
        R = r.reshape((r.shape[0], -1))
        #R = R/np.nansum(R[:,~np.isnan(R.sum(0))],axis=1)[:,np.newaxis]
        R = R / np.nansum(R, axis=1)[:, np.newaxis]  # changed 8/28
        return R

    def norm_to_mean(r):
        R = r.reshape((r.shape[0], -1))
        R = R / np.nanmean(R[:, ~np.isnan(R.sum(0))], axis=1)[:, np.newaxis]
        return R

    Rs = [[None, None] for i in range(len(rs))]
    Rso = [[[None for iT in range(nT)] for iS in range(nS)]
           for icelltype in range(len(rs))]
    rso = [[[None for iT in range(nT)] for iS in range(nS)]
           for icelltype in range(len(rs))]

    for iR, r in enumerate(rs):  #rs_denoise):
        #print(iR)
        for ialign in range(nS):
            #Rs[iR][ialign] = r[ialign][:,:nsize,:]
            #sm = np.nanmean(np.nansum(np.nansum(Rs[iR][ialign],1),1))
            #Rs[iR][ialign] = Rs[iR][ialign]/sm
            #print('frac isnan Rs %d,%d: %f'%(iR,ialign,np.isnan(r[ialign]).mean()))
            Rs[iR][ialign] = sum_to_1(r[ialign][:, :nsize, :])
    #         Rs[iR][ialign] = von_mises_denoise(Rs[iR][ialign].reshape((-1,nsize,ncontrast,ndir)))

    kernel = np.ones((1, 2, 2))
    kernel = kernel / kernel.sum()

    for iR, r in enumerate(rs):
        for ialign in range(nS):
            for iori in range(nT):
                #print('this Rs shape: '+str(Rs[iR][ialign].shape))
                #print('this Rs reshaped shape: '+str(Rs[iR][ialign].reshape((-1,nsize,ncontrast,ndir))[:,:,:,ori_dirs[iori]].shape))
                #print('this Rs max percent nan: '+str(np.isnan(Rs[iR][ialign].reshape((-1,nsize,ncontrast,ndir))[:,:,:,ori_dirs[iori]]).mean(-1).max()))
                Rso[iR][ialign][iori] = np.nanmean(
                    Rs[iR][ialign].reshape(
                        (-1, nsize, ncontrast, ndir))[:, :, :, ori_dirs[iori]],
                    -1)
                Rso[iR][ialign][iori][:, :, 0] = np.nanmean(
                    Rso[iR][ialign][iori][:, :, 0],
                    1)[:, np.newaxis]  # average 0 contrast values
                #print('frac isnan pre-conv Rso %d,%d,%d: %f'%(iR,ialign,iori,np.isnan(Rso[iR][ialign][iori]).mean()))
                Rso[iR][ialign][iori][:, 1:, 1:] = ssi.convolve(
                    Rso[iR][ialign][iori], kernel, 'valid')
                Rso[iR][ialign][iori] = Rso[iR][ialign][iori].reshape(
                    Rso[iR][ialign][iori].shape[0], -1)
                #print('frac isnan Rso %d,%d,%d: %f'%(iR,ialign,iori,np.isnan(Rso[iR][ialign][iori]).mean()))
                #print('sum of Rso isnan: '+str(np.isnan(Rso[iR][ialign][iori]).sum(1)))
                #Rso[iR][ialign][iori] = Rso[iR][ialign][iori]/np.nanmean(Rso[iR][ialign][iori],-1)[:,np.newaxis]

    def set_bound(bd, code, val=0):
        # set bounds to 0 where 0s occur in 'code'
        for iitem in range(len(bd)):
            bd[iitem][code[iitem]] = val

    nN = 36
    nS = 2
    nP = 2
    nT = 1
    nQ = 4

    # code for bounds: 0 , constrained to 0
    # +/-1 , constrained to +/-1
    # 1.5, constrained to [0,1]
    # 2 , constrained to [0,inf)
    # -2 , constrained to (-inf,0]
    # 3 , unconstrained

    Wmx_bounds = 3 * np.ones((nP, nQ), dtype=int)
    Wmx_bounds[0, :] = 2  # L4 PCs are excitatory
    Wmx_bounds[0, 1] = 0  # SSTs don't receive L4 input

    if allow_var:
        Wsx_bounds = 3 * np.ones(
            Wmx_bounds.shape)  #Wmx_bounds.copy()*0 #np.zeros_like(Wmx_bounds)
        Wsx_bounds[0, 1] = 0
    else:
        Wsx_bounds = np.zeros(
            Wmx_bounds.shape)  #Wmx_bounds.copy()*0 #np.zeros_like(Wmx_bounds)

    Wmy_bounds = 3 * np.ones((nQ, nQ), dtype=int)
    Wmy_bounds[0, :] = 2  # PCs are excitatory
    Wmy_bounds[1:, :] = -2  # all the cell types except PCs are inhibitory
    Wmy_bounds[1, 1] = 0  # SSTs don't inhibit themselves
    # Wmy_bounds[3,1] = 0 # PVs are allowed to inhibit SSTs, consistent with Hillel's unpublished results, but not consistent with Pfeffer et al.
    Wmy_bounds[2, 0] = 0  # VIPs don't inhibit L2/3 PCs. According to Pfeffer et al., only L5 PCs were found to get VIP inhibition

    if not zero_extra_weights is None:
        Wmx_bounds[zero_extra_weights[0]] = 0
        Wmy_bounds[zero_extra_weights[1]] = 0

    if allow_var:
        Wsy_bounds = 3 * np.ones(
            Wmy_bounds.shape)  #Wmy_bounds.copy()*0 #np.zeros_like(Wmy_bounds)
        Wsy_bounds[1, 1] = 0
        Wsy_bounds[3, 1] = 0
        Wsy_bounds[2, 0] = 0
    else:
        Wsy_bounds = np.zeros(
            Wmy_bounds.shape)  #Wmy_bounds.copy()*0 #np.zeros_like(Wmy_bounds)

    if not constrain_wts is None:
        for wt in constrain_wts:
            Wmy_bounds[wt[0], wt[1]] = 0
            Wsy_bounds[wt[0], wt[1]] = 0

    def tile_nS_nT_nN(kernel):
        row = np.concatenate([kernel for idim in range(nS * nT)],
                             axis=0)[np.newaxis, :]
        tiled = np.concatenate([row for irow in range(nN)], axis=0)
        return tiled

    def set_bounds_by_code(lb, ub, bdlist):
        set_bound(lb, [bd == 0 for bd in bdlist], val=0)
        set_bound(ub, [bd == 0 for bd in bdlist], val=0)

        set_bound(lb, [bd == 2 for bd in bdlist], val=0)

        set_bound(ub, [bd == -2 for bd in bdlist], val=0)

        set_bound(lb, [bd == 1 for bd in bdlist], val=1)
        set_bound(ub, [bd == 1 for bd in bdlist], val=1)

        set_bound(lb, [bd == 1.5 for bd in bdlist], val=0)
        set_bound(ub, [bd == 1.5 for bd in bdlist], val=1)

        set_bound(lb, [bd == -1 for bd in bdlist], val=-1)
        set_bound(ub, [bd == -1 for bd in bdlist], val=-1)

    if allow_s2:
        if fit_s02:
            s02_bounds = 2 * np.ones(
                (nQ, ))  # permitting noise as a free parameter
        else:
            s02_bounds = np.ones((nQ, ))
    else:
        s02_bounds = np.zeros((nQ, ))

    k_bounds = 1.5 * np.ones((nQ * (nS - 1), ))

    #k_bounds[1] = 0 # temporary: spatial kernel constrained to 0 for SST
    #k_bounds[2] = 0 # temporary: spatial kernel constrained to 0 for VIP

    kappa_bounds = np.ones((1, ))
    # kappa_bounds = 2*np.ones((1,))

    T_bounds = 1.5 * np.ones((nQ * (nT - 1), ))

    X_bounds = tile_nS_nT_nN(np.array([2, 1]))
    # X_bounds = np.array([np.array([2,1,2,1])]*nN)

    Xp_bounds = tile_nS_nT_nN(np.array([3, 1]))
    # Xp_bounds = np.array([np.array([3,1,3,1])]*nN)

    # Y_bounds = tile_nS_nT_nN(2*np.ones((nQ,)))
    # # Y_bounds = 2*np.ones((nN,nT*nS*nQ))

    Eta_bounds = tile_nS_nT_nN(3 * np.ones((nQ, )))
    # Eta_bounds = 3*np.ones((nN,nT*nS*nQ))

    if allow_s2:
        if allow_var:
            Xi_bounds = tile_nS_nT_nN(3 * np.ones((nQ, )))
        else:
            Xi_bounds = tile_nS_nT_nN(np.zeros((nQ, )))
    else:
        Xi_bounds = tile_nS_nT_nN(np.zeros((nQ, )))

    # Xi_bounds = 3*np.ones((nN,nT*nS*nQ))

    h1_bounds = -2 * np.ones((1, ))

    h2_bounds = 2 * np.ones((1, ))

    bl_bounds = 3 * np.ones((nQ, ))

    if free_amplitude:
        amp_bounds = 2 * np.ones((nT * nS * nQ, ))
    else:
        amp_bounds = 1 * np.ones((nT * nS * nQ, ))

    # shapes = [(nP,nQ),(nQ,nQ),(nP,nQ),(nQ,nQ),(nQ,),(nQ,),(1,),(nN,nS*nP),(nN,nS*nQ),(nN,nS*nQ),(nN,nS*nQ)]
    shapes1 = [(nP, nQ), (nQ, nQ), (nP, nQ),
               (nQ, nQ), (nQ, ), (nQ * (nS - 1), ), (1, ), (nQ * (nT - 1), ),
               (1, ), (1, ), (nQ, ), (nQ * nS * nT, )]
    shapes2 = [(nN, nT * nS * nP), (nN, nT * nS * nP), (nN, nT * nS * nQ),
               (nN, nT * nS * nQ)]
    #print('size of shapes1: '+str(np.sum([np.prod(shp) for shp in shapes1])))
    #print('size of shapes2: '+str(np.sum([np.prod(shp) for shp in shapes2])))
    #         Wmx,    Wmy,    Wsx,    Wsy,    s02,  k,    kappa,T,   h1, h2
    #XX,            XXp,          Eta,          Xi

    #bdlist = [Wmx_bounds,Wmy_bounds,Wsx_bounds,Wsy_bounds,s02_bounds,k_bounds,kappa_bounds,T_bounds,X_bounds,Xp_bounds,Eta_bounds,Xi_bounds,h1_bounds,h2_bounds]
    bd1list = [
        Wmx_bounds, Wmy_bounds, Wsx_bounds, Wsy_bounds, s02_bounds, k_bounds,
        kappa_bounds, T_bounds, h1_bounds, h2_bounds, bl_bounds, amp_bounds
    ]
    bd2list = [X_bounds, Xp_bounds, Eta_bounds, Xi_bounds]

    lb1, ub1 = [[sgn * np.inf * np.ones(shp) for shp in shapes1]
                for sgn in [-1, 1]]
    set_bounds_by_code(lb1, ub1, bd1list)
    lb2, ub2 = [[sgn * np.inf * np.ones(shp) for shp in shapes2]
                for sgn in [-1, 1]]
    set_bounds_by_code(lb2, ub2, bd2list)

    #set_bound(lb,[bd==0 for bd in bdlist],val=0)
    #set_bound(ub,[bd==0 for bd in bdlist],val=0)
    #
    #set_bound(lb,[bd==2 for bd in bdlist],val=0)
    #
    #set_bound(ub,[bd==-2 for bd in bdlist],val=0)
    #
    #set_bound(lb,[bd==1 for bd in bdlist],val=1)
    #set_bound(ub,[bd==1 for bd in bdlist],val=1)
    #
    #set_bound(lb,[bd==1.5 for bd in bdlist],val=0)
    #set_bound(ub,[bd==1.5 for bd in bdlist],val=1)
    #
    #set_bound(lb,[bd==-1 for bd in bdlist],val=-1)
    #set_bound(ub,[bd==-1 for bd in bdlist],val=-1)

    # for bd in [lb,ub]:
    #     for ind in [2,3]:
    #         bd[ind][:,1] = 0

    # temporary for no variation expt.
    # lb[2] = np.zeros_like(lb[2])
    # lb[3] = np.zeros_like(lb[3])
    # lb[4] = np.ones_like(lb[4])
    # lb[5] = np.zeros_like(lb[5])
    # ub[2] = np.zeros_like(ub[2])
    # ub[3] = np.zeros_like(ub[3])
    # ub[4] = np.ones_like(ub[4])
    # ub[5] = np.ones_like(ub[5])
    # temporary for no variation expt.
    lb1 = np.concatenate([a.flatten() for a in lb1])
    ub1 = np.concatenate([b.flatten() for b in ub1])
    lb2 = np.concatenate([a.flatten() for a in lb2])
    ub2 = np.concatenate([b.flatten() for b in ub2])
    bounds1 = [(a, b) for a, b in zip(lb1, ub1)]
    bounds2 = [(a, b) for a, b in zip(lb2, ub2)]

    nS = 2
    #print('nT: '+str(nT))
    ndims = 5
    ncelltypes = 5
    Yhat = [[None for iT in range(nT)] for iS in range(nS)]
    Xhat = [[None for iT in range(nT)] for iS in range(nS)]
    Ypc_list = [[None for iT in range(nT)] for iS in range(nS)]
    Xpc_list = [[None for iT in range(nT)] for iS in range(nS)]
    mx = [None for iS in range(nS)]
    for iS in range(nS):
        mx[iS] = np.zeros((ncelltypes, ))
        yy = [None for icelltype in range(ncelltypes)]
        for icelltype in range(ncelltypes):
            yy[icelltype] = np.nanmean(Rso[icelltype][iS][0], 0)
            mx[iS][icelltype] = np.nanmax(yy[icelltype])
        for iT in range(nT):
            y = [
                np.nanmean(Rso[icelltype][iS][iT], axis=0)[:, np.newaxis] /
                mx[iS][icelltype] for icelltype in range(1, ncelltypes)
            ]
            Ypc_list[iS][iT] = [None for icelltype in range(1, ncelltypes)]
            for icelltype in range(1, ncelltypes):
                # as currently written, penalties involving (X,Y)pc_list are effectively artificially smaller by
                # a factor of mx[iS][icelltype] compared to what one would expect from the (X,Y)-penalty as defined
                # subsequently.
                rss = Rso[icelltype][iS][iT].copy(
                )  #/mx[iS][icelltype] #.reshape(Rs[icelltype][ialign].shape[0],-1)
                #print('sum of isnan: '+str(np.isnan(rss).sum(1)))
                #rss = Rso[icelltype][iS][iT].copy() #.reshape(Rs[icelltype][ialign].shape[0],-1)
                rss = rss[np.isnan(rss).sum(1) == 0]
                #         print(rss.max())
                #         rss[rss<0] = 0
                #         rss = rss[np.random.randn(rss.shape[0])>0]
                try:
                    u, s, v = np.linalg.svd(rss - np.mean(rss, 0)[np.newaxis])
                    Ypc_list[iS][iT][icelltype - 1] = [
                        (s[idim], v[idim]) for idim in range(ndims)
                    ]
    #                 print('yep on Y')
    #                 print(np.min(np.sum(rs[icelltype][iS][iT],axis=1)))
                except:
                    print('nope on Y')
                    #print('shape of rss: '+str(rss.shape))
                    #print('mean of rss: '+str(np.mean(np.isnan(rss))))
                    #print('min of this rs: '+str(np.min(np.sum(rs[icelltype][iS][iT],axis=1))))
            Yhat[iS][iT] = np.concatenate(y, axis=1)
            #         x = sim_utils.columnize(Rso[0][iS][iT])[:,np.newaxis]
            icelltype = 0
            #x = np.nanmean(Rso[icelltype][iS][iT],0)[:,np.newaxis]#/mx[iS][icelltype]
            x = np.nanmean(Rso[icelltype][iS][iT],
                           0)[:, np.newaxis] / mx[iS][icelltype]
            #         opto_column = np.concatenate((np.zeros((nN,)),np.zeros((nNO/2,)),np.ones((nNO/2,))),axis=0)[:,np.newaxis]
            Xhat[iS][iT] = np.concatenate((x, np.ones_like(x)), axis=1)
            #         Xhat[iS][iT] = np.concatenate((x,np.ones_like(x),opto_column),axis=1)
            icelltype = 0
            #rss = Rso[icelltype][iS][iT].copy()/mx[iS][icelltype]
            rss = Rso[icelltype][iS][iT].copy()
            rss = rss[np.isnan(rss).sum(1) == 0]
            #         try:
            u, s, v = np.linalg.svd(rss - rss.mean(0)[np.newaxis])
            Xpc_list[iS][iT] = [None for iinput in range(2)]
            Xpc_list[iS][iT][0] = [(s[idim], v[idim]) for idim in range(ndims)]
            Xpc_list[iS][iT][1] = [(0, np.zeros((Xhat[0][0].shape[0], )))
                                   for idim in range(ndims)]
    #         except:
    #             print('nope on X')
    #             print(np.mean(np.isnan(rss)))
    #             print(np.min(np.sum(Rso[icelltype][iS][iT],axis=1)))
    nN, nP = Xhat[0][0].shape
    #print('nP: '+str(nP))
    nQ = Yhat[0][0].shape[1]

    import sim_utils

    pop_rate_fn = sim_utils.f_miller_troyer
    pop_deriv_fn = sim_utils.fprime_miller_troyer

    def compute_f_(Eta, Xi, s02):
        return sim_utils.f_miller_troyer(
            Eta, Xi**2 + np.concatenate([s02 for ipixel in range(nS * nT)]))

    def compute_fprime_m_(Eta, Xi, s02):
        return sim_utils.fprime_miller_troyer(
            Eta, Xi**2 + np.concatenate([s02
                                         for ipixel in range(nS * nT)])) * Xi

    def compute_fprime_s_(Eta, Xi, s02):
        s2 = Xi**2 + np.concatenate((s02, s02), axis=0)
        return sim_utils.fprime_s_miller_troyer(Eta, s2) * (Xi / s2)

    def sorted_r_eigs(w):
        drW, prW = np.linalg.eig(w)
        srtinds = np.argsort(drW)
        return drW[srtinds], prW[:, srtinds]

    #         0.Wmx,  1.Wmy,  2.Wsx,  3.Wsy,  4.s02,5.K,  6.kappa,7.T,8.XX,        9.XXp,        10.Eta,       11.Xi,   12.h1,  13.h2

    shapes1 = [(nP, nQ), (nQ, nQ), (nP, nQ),
               (nQ, nQ), (nQ, ), (nQ * (nS - 1), ), (1, ), (nQ * (nT - 1), ),
               (1, ), (1, ), (nQ, ), (nT * nS * nQ, )]
    shapes2 = [(nN, nT * nS * nP), (nN, nT * nS * nP), (nN, nT * nS * nQ),
               (nN, nT * nS * nQ)]
    #print('size of shapes1: '+str(np.sum([np.prod(shp) for shp in shapes1])))
    #print('size of shapes2: '+str(np.sum([np.prod(shp) for shp in shapes2])))

    import calnet.fitting_spatial_feature

    YYhat = calnet.utils.flatten_nested_list_of_2d_arrays(Yhat)
    XXhat = calnet.utils.flatten_nested_list_of_2d_arrays(Xhat)

    opto_dict = np.load(opto_silencing_data_file, allow_pickle=True)[()]

    Yhat_opto = opto_dict['Yhat_opto']
    Yhat_opto = np.nanmean(np.reshape(Yhat_opto, (nN, 2, nS, 2, nQ)),
                           3).reshape((nN * 2, -1))
    Yhat_opto[0::12] = np.nanmean(Yhat_opto[0::12], axis=0)[np.newaxis]
    Yhat_opto[1::12] = np.nanmean(Yhat_opto[1::12], axis=0)[np.newaxis]
    Yhat_opto = Yhat_opto / np.nanmax(Yhat_opto[0::2], 0)[np.newaxis, :]
    #print(Yhat_opto.shape)
    h_opto = opto_dict['h_opto']
    #dYY1 = Yhat_opto[1::2]-Yhat_opto[0::2]

    YYhat_halo = Yhat_opto.reshape((nN, 2, -1))
    opto_transform1 = calnet.utils.fit_opto_transform(
        YYhat_halo, norm01=norm_opto_transforms)

    opto_transform1.res[:, [0, 2, 3, 4, 6, 7]] = 0

    dYY1 = opto_transform1.transform(YYhat) - opto_transform1.preprocess(YYhat)

    #YYhat_halo_sim = calnet.utils.simulate_opto_effect(YYhat,YYhat_halo)
    #dYY1 = YYhat_halo_sim[:,1,:] - YYhat_halo_sim[:,0,:]

    def overwrite_plus_n(arr, to_overwrite, n):
        arr[:, to_overwrite] = arr[:, int(to_overwrite + n)]
        return arr

    for to_overwrite in [1, 2]:
        n = 4
        dYY1,opto_transform1.slope,opto_transform1.intercept,opto_transform1.res \
                = [overwrite_plus_n(x,to_overwrite,n) for x in \
                        [dYY1,opto_transform1.slope,opto_transform1.intercept,opto_transform1.res]]
    for to_overwrite in [7]:
        n = -4
        dYY1,opto_transform1.slope,opto_transform1.intercept,opto_transform1.res \
                = [overwrite_plus_n(x,to_overwrite,n) for x in \
                        [dYY1,opto_transform1.slope,opto_transform1.intercept,opto_transform1.res]]

    if ignore_halo_vip:
        dYY1[:, 2::nQ] = np.nan

    #for to_overwrite in [1,2]:
    #    dYY1[:,to_overwrite] = dYY1[:,to_overwrite+4]
    #for to_overwrite in [7]:
    #    dYY1[:,to_overwrite] = dYY1[:,to_overwrite-4]

    #Yhat_opto = opto_dict['Yhat_opto']
    #for iS in range(nS):
    #    mx = np.zeros((nQ,))
    #    for iQ in range(nQ):
    #        slicer = slice(nQ*nT*iS+iQ,nQ*nT*(1+iS),nQ)
    #        mx[iQ] = np.nanmax(Yhat_opto[0::2][:,slicer])
    #        Yhat_opto[:,slicer] = Yhat_opto[:,slicer]/mx[iQ]
    ##Yhat_opto = Yhat_opto/Yhat_opto[0::2].max(0)[np.newaxis,:]
    #print(Yhat_opto.shape)
    #h_opto = opto_dict['h_opto']
    #dYY1 = Yhat_opto[1::2]-Yhat_opto[0::2]
    #for to_overwrite in [1,2,5,6]: # overwrite sst and vip with off-centered values
    #    dYY1[:,to_overwrite] = dYY1[:,to_overwrite+8]
    #for to_overwrite in [11,15]:
    #    dYY1[:,to_overwrite] = np.nan #dYY1[:,to_overwrite-8]

    opto_dict = np.load(opto_activation_data_file, allow_pickle=True)[()]

    Yhat_opto = opto_dict['Yhat_opto']
    Yhat_opto = np.nanmean(np.reshape(Yhat_opto, (nN, 2, nS, 2, nQ)),
                           3).reshape((nN * 2, -1))
    Yhat_opto[0::12] = np.nanmean(Yhat_opto[0::12], axis=0)[np.newaxis]
    Yhat_opto[1::12] = np.nanmean(Yhat_opto[1::12], axis=0)[np.newaxis]
    Yhat_opto = Yhat_opto / Yhat_opto[0::2].max(0)[np.newaxis, :]
    #print(Yhat_opto.shape)
    h_opto = opto_dict['h_opto']
    #dYY2 = Yhat_opto[1::2]-Yhat_opto[0::2]

    YYhat_chrimson = Yhat_opto.reshape((nN, 2, -1))
    opto_transform2 = calnet.utils.fit_opto_transform(
        YYhat_chrimson, norm01=norm_opto_transforms)
    dYY2 = opto_transform2.transform(YYhat) - opto_transform2.preprocess(YYhat)
    #YYhat_chrimson_sim = calnet.utils.simulate_opto_effect(YYhat,YYhat_chrimson)
    #dYY2 = YYhat_chrimson_sim[:,1,:] - YYhat_chrimson_sim[:,0,:]

    #Yhat_opto = opto_dict['Yhat_opto']
    #for iS in range(nS):
    #    mx = np.zeros((nQ,))
    #    for iQ in range(nQ):
    #        slicer = slice(nQ*nT*iS+iQ,nQ*nT*(1+iS),nQ)
    #        mx[iQ] = np.nanmax(Yhat_opto[0::2][:,slicer])
    #        Yhat_opto[:,slicer] = Yhat_opto[:,slicer]/mx[iQ]
    ##Yhat_opto = Yhat_opto/Yhat_opto[0::2].max(0)[np.newaxis,:]
    #print(Yhat_opto.shape)
    #h_opto = opto_dict['h_opto']
    #dYY2 = Yhat_opto[1::2]-Yhat_opto[0::2]

    #print('dYY1 mean: %03f'%np.nanmean(np.abs(dYY1)))
    #print('dYY2 mean: %03f'%np.nanmean(np.abs(dYY2)))

    dYY = np.concatenate((dYY1, dYY2), axis=0)

    #titles = ['VIP silencing','VIP activation']
    #for itype in [0,1,2,3]:
    #    plt.figure(figsize=(5,2.5))
    #    for iyy,dyy in enumerate([dYY1,dYY2]):
    #        plt.subplot(1,2,iyy+1)
    #        if np.sum(np.isnan(dyy[:,itype]))==0:
    #            sca.scatter_size_contrast(YYhat[:,itype],YYhat[:,itype]+dyy[:,itype],nsize=6,ncontrast=6)#,mn=0)
    #        plt.title(titles[iyy])
    #        plt.xlabel('cell type %d event rate, \n light off'%itype)
    #        plt.ylabel('cell type %d event rate, \n light on'%itype)
    #        ut.erase_top_right()
    #    plt.tight_layout()
    #    ut.mkdir('figures')
    #    plt.savefig('figures/scatter_light_on_light_off_target_celltype_%d.eps'%itype)

    opto_mask = ~np.isnan(dYY)

    #dYY[nN:][~opto_mask[nN:]] = -dYY[:nN][~opto_mask[nN:]]

    #print('mean of opto_mask: '+str(opto_mask.mean()))

    #dYY[~opto_mask] = 0
    def zero_nans(arr):
        arr[np.isnan(arr)] = 0
        return arr

    #dYY,opto_transform1.slope,opto_transform1.intercept,opto_transform1.res,\
    #        opto_transform2.slope,opto_transform2.intercept,opto_transform2.res\
    #        = [zero_nans(x) for x in \
    #                [dYY,opto_transform1.slope,opto_transform1.intercept,opto_transform1.res,\
    #                opto_transform2.slope,opto_transform2.intercept,opto_transform2.res]]
    dYY = zero_nans(dYY)

    to_adjust = np.logical_or(np.isnan(opto_transform2.slope[0]),
                              np.isnan(opto_transform2.intercept[0]))

    opto_transform2.slope[:, to_adjust] = 1 / opto_transform1.slope[:, to_adjust]
    opto_transform2.intercept[:, to_adjust] = (
        -opto_transform1.intercept[:, to_adjust] / opto_transform1.slope[:, to_adjust])
    opto_transform2.res[:, to_adjust] = (
        -opto_transform1.res[:, to_adjust] / opto_transform1.slope[:, to_adjust])

    #np.save('/Users/dan/Documents/notebooks/mossing-PC/shared_data/calnet_data/dYY.npy',dYY)

    from importlib import reload
    reload(calnet)
    #reload(calnet.fitting_2step_spatial_feature_opto_tight_nonlinear)
    reload(sim_utils)
    # reload(calnet.fitting_spatial_feature)
    # W0list = [np.ones(shp) for shp in shapes]
    wt_dict = {}
    wt_dict['X'] = 3  #1
    wt_dict['Y'] = 3
    #wt_dict['Eta'] = 3 # 1 #
    wt_dict['Xi'] = 0.1
    wt_dict['stims'] = np.ones((nN, 1))  #(np.arange(30)/30)[:,np.newaxis]**1 #
    wt_dict['barrier'] = 0.  #30.0 #0.1
    wt_dict['opto'] = 1  #1e1
    wt_dict['isn'] = 0.3
    wt_dict['tv'] = 1
    spont_frac = 0.5
    pc_frac = 0.5
    wt_dict['stimsOpto'] = (1 - spont_frac) * 6 / 5 * np.ones((nN, 1))
    wt_dict['stimsOpto'][0::6] = spont_frac * 6
    wt_dict['celltypesOpto'] = (1 - pc_frac) * 4 / 3 * np.ones(
        (1, nQ * nS * nT))
    wt_dict['celltypesOpto'][0, 0::nQ] = pc_frac * 4
    wt_dict['dirOpto'] = np.array((1, 0.3))
    wt_dict['dYY'] = 10  #10
    wt_dict['coupling'] = 1e-3
    wt_dict['smi'] = 0.1
    wt_dict['smi_halo'] = 30
    wt_dict['smi_chrimson'] = 0.1

    ##temporary no_opto
    wt_dict['opto'] = 0
    wt_dict['dirOpto'] = np.array((1, 1))
    #wt_dict['stimsOpto'] = np.ones((nN,1))
    wt_dict['celltypesOpto'] = np.ones((1, nQ * nS * nT))
    wt_dict['smi'] = 0  #0.01 # 0
    wt_dict['smi_halo'] = 0  #1 # 0
    wt_dict['smi_chrimson'] = 0  #0.01 # 0
    wt_dict['isn'] = 0.1
    wt_dict['tv'] = 0.1
    wt_dict['X'] = 3
    wt_dict['Eta'] = 10  #3 # 1 #

    ## temporary opto from no_opto
    #wt_dict['opto'] = 0.01
    #wt_dict['tv'] = 0.3#0.1

    np.save(
        'XXYYhat.npy', {
            'YYhat': YYhat,
            'XXhat': XXhat,
            'rs': rs,
            'Rs': Rs,
            'Rso': Rso,
            'Ypc_list': Ypc_list,
            'Xpc_list': Xpc_list
        })
    if allow_s2:
        Eta0 = invert_f_mt(YYhat)
    else:
        Eta0 = invert_f_mt(YYhat, s02=0)

    #         Wmx,    Wmy,    Wsx,    Wsy,    s02,  k,    kappa,T,   h1, h2
    #XX,            XXp,          Eta,          Xi

    opt = fmc.gen_opt(nS=nS, nT=nT)
    opt['allow_s02'] = False
    opt['allow_A'] = False
    opt['allow_B'] = True

    ntries = 1
    nhyper = 1
    dt = 1e-1
    niter = int(np.round(10 / dt))  #int(1e4)
    perturbation_size = 5e-2
    # learning_rate = 1e-4 # 1e-5 #np.linspace(3e-4,1e-3,niter+1) # 1e-5
    #l2_penalty = 0.1
    W1t = [[None for itry in range(ntries)] for ihyper in range(nhyper)]
    W2t = [[None for itry in range(ntries)] for ihyper in range(nhyper)]
    loss = np.zeros((nhyper, ntries))
    is_neg = np.array([b[1] for b in bounds1]) == 0
    counter = 0
    negatize = [np.zeros(shp, dtype='bool') for shp in shapes1]
    #print(shapes1)
    for ishp, shp in enumerate(shapes1):
        nel = np.prod(shp)
        negatize[ishp][:][is_neg[counter:counter + nel].reshape(shp)] = True
        counter = counter + nel
    for ihyper in range(nhyper):
        for itry in range(ntries):
            #print((ihyper,itry))
            #[0.(nP,nQ),1.(nQ,nQ),2.(nP,nQ),3.(nQ,nQ),4.(nQ,),5.(nQ*(nS-1),),6.(1,),7.(nQ*(nT-1),),8.(1,),9.(1,),10.(nQ,),11.(nQ*nS*nT,)]
            W10list = [
                init_noise * (ihyper + 1) * np.random.rand(*shp)
                for shp in shapes1
            ]
            W20list = [
                init_noise * (ihyper + 1) * np.random.rand(*shp)
                for shp in shapes2
            ]
            #print('size of shapes1: '+str(np.sum([np.prod(shp) for shp in shapes1])))
            #print('size of w10: '+str(np.sum([np.size(x) for x in W10list])))
            #print('len(W10list) : '+str(len(W10list)))
            counter = 0
            for ishp, shp in enumerate(shapes1):
                W10list[ishp][negatize[ishp]] = -W10list[ishp][negatize[ishp]]
            W10list[4] = np.ones(shapes1[4])  # s02
            W10list[5] = np.ones(shapes1[5])  # K
            W10list[6] = np.ones(shapes1[6])  # kappa
            W10list[7] = np.ones(shapes1[7])  # T
            W10list[8] = np.zeros(shapes1[8])  # h1
            W10list[9] = np.zeros(shapes1[9])  # h2
            W10list[10] = np.zeros(shapes1[10])  # baseline
            W10list[11] = np.ones(shapes1[11])  # amplitude
            W20list[0] = np.concatenate(Xhat, axis=1)  #XX
            W20list[1] = np.zeros_like(W20list[1])  #XXp
            W20list[2] = Eta0.copy()  #np.zeros(shapes[10]) #Eta
            W20list[3] = np.zeros(shapes2[3])  #Xi
            #[Wmx,Wmy,Wsx,Wsy,s02,k,kappa,T,XX,XXp,Eta,Xi]
            if init_W_from_lsq:
                W10list[0], W10list[1] = initialize_W(Xhat,
                                                      Yhat,
                                                      scale_by=scale_init_by,
                                                      allow_s2=allow_s2)
                for ivar in range(0, 2):
                    W10list[
                        ivar] = W10list[ivar] + init_noise * np.random.randn(
                            *W10list[ivar].shape)
            if init_W_from_lbfgs:
                print(opt)
                opt_param, result, _, _, _, _, _, _, _, _, _, _, _ = fmc.initialize_params(
                    XXhat, YYhat, opt, wpcpc=5, wpvpv=-6)
                these_shapes = [(nP, nQ), (nQ, nQ), (nQ, ), (nQ, ), (nQ, ),
                                (nQ, )]
                Wmx0, Wmy0, K0, s020, amplitude0, baseline0 = calnet.utils.parse_thing(
                    opt_param, these_shapes)
                if init_Eta_with_s02:
                    #assert(True==False)
                    Eta0 = invert_f_mt_with_s02(YYhat -
                                                np.tile(baseline0, nS * nT),
                                                s020,
                                                nS=nS,
                                                nT=nT)
                    W20list[2] = Eta0.copy()
                #Wmx0 = opt_param[:nP]
                #Wmy0 = opt_param[nP:nP+nQ]
                #K0 = opt_param[nP+nQ]
                #s020 = opt_param[nP+nQ+1]
                #amplitude0 = opt_param[nP+nQ+2]
                #baseline0 = opt_param[nP+nQ+3]
                print((Wmx0, Wmy0, K0, s020, np.tile(amplitude0,
                                                     2), baseline0))
                W10list[0], W10list[1], W10list[5], W10list[4], W10list[
                    -1], W10list[-2] = Wmx0, Wmy0, K0, s020, np.tile(
                        amplitude0, 2), baseline0
                for ivar in range(0, 2):
                    W10list[
                        ivar] = W10list[ivar] + init_noise * np.random.randn(
                            *W10list[ivar].shape)
            elif constrain_isn:
                W10list[1][0, 0] = 3
                if help_constrain_isn:
                    W10list[1][0, 3] = 5
                    W10list[1][3, 0] = -5
                    W10list[1][3, 3] = -5
                else:
                    W10list[1][0, 1:4] = 5
                    W10list[1][1:4, 0] = -5

            if init_W_from_file:
                npyfile = np.load(init_file, allow_pickle=True)[()]

                #Wmx,Wmy,Wsx,Wsy,s02,K,kappa,T,h1,h2,bl,amp = parse_W1(W1)
                #XX,XXp,Eta,Xi = parse_W2(W2)
                #Wmx,Wmy,Wsx,Wsy,s02,K,kappa,T,XX,XXp,Eta,Xi,h1,h2,bl,amp = parse_W1(W1)
                W10list = [
                    npyfile['as_list'][ivar]
                    for ivar in [0, 1, 2, 3, 4, 5, 6, 7, 12, 13, 14, 15]
                ]
                W20list = [npyfile['as_list'][ivar] for ivar in [8, 9, 10, 11]]
                if W20list[0].size == nN * nS * 2 * nP:
                    #assert(True==False)
                    W10list[7] = np.array(())
                    W10list[1][1, 0] = W10list[1][1, 0]
                    W20list[0] = np.nanmean(
                        W20list[0].reshape((nN, nS, 2, nP)), 2).flatten()  #XX
                    W20list[1] = np.nanmean(
                        W20list[1].reshape((nN, nS, 2, nP)), 2).flatten()  #XXp
                    W20list[2] = np.nanmean(
                        W20list[2].reshape((nN, nS, 2, nQ)), 2).flatten()  #Eta
                    W20list[3] = np.nanmean(
                        W20list[3].reshape((nN, nS, 2, nQ)), 2).flatten()  #Xi
                if correct_Eta:
                    #assert(True==False)
                    W20list[2] = Eta0.copy()
                if len(W10list) < len(shapes1):
                    #assert(True==False)
                    W10list = W10list + [
                        np.array(1),
                        np.zeros((nQ, )),
                        np.zeros((nT * nS * nQ, ))
                    ]  # add h2, bl, amp
                if init_Eta_with_s02:
                    #assert(True==False)
                    s02 = W10list[4].copy()
                    Eta0 = invert_f_mt_with_s02(YYhat, s02, nS=nS, nT=nT)
                    W20list[2] = Eta0.copy()
                #if init_Eta12_with_dYY:
                #    Eta0 = W20list[2].copy().reshape((nN,nQ*nS*nT))
                #    Xi0 = W20list[3].copy().reshape((nN,nQ*nS*nT))
                #    s020 = W10list[4].copy()
                #    YY0s = compute_f_(Eta0,Xi0,s020)
                #titles = ['VIP silencing','VIP activation']
                #for itype in [0,1,2,3]:
                #    plt.figure(figsize=(5,2.5))
                #    for iyy,yy in enumerate([YY10s,YY20s]):
                #        plt.subplot(1,2,iyy+1)
                #        if np.sum(np.isnan(yy[:,itype]))==0:
                #            sca.scatter_size_contrast(YY0s[:,itype],yy[:,itype],nsize=6,ncontrast=6)#,mn=0)
                #        plt.title(titles[iyy])
                #        plt.xlabel('cell type %d event rate, \n light off'%itype)
                #        plt.ylabel('cell type %d event rate, \n light on'%itype)
                #        ut.erase_top_right()
                #    plt.tight_layout()
                #    ut.mkdir('figures')
                #    plt.savefig('figures/scatter_light_on_light_off_init_celltype_%d.eps'%itype)
                for ivar in [0, 1, 4, 5]:  # Wmx, Wmy, s02, k
                    print(init_noise)
                    W10list[
                        ivar] = W10list[ivar] + init_noise * np.random.randn(
                            *W10list[ivar].shape)

            #print('size of bounds1: '+str(np.sum([np.size(x) for x in bd1list])))
            #print('size of w10: '+str(np.sum([np.size(x) for x in W10list])))
            #print('size of shapes1: '+str(np.sum([np.prod(shp) for shp in shapes1])))
            W1t[ihyper][itry], W2t[ihyper][itry], loss[ihyper][
                itry], gr, hess, result = calnet.fitting_2step_spatial_feature_opto_tight_nonlinear_baseline.fit_W_sim(
                    Xhat,
                    Xpc_list,
                    Yhat,
                    Ypc_list,
                    pop_rate_fn=pop_rate_fn,
                    pop_deriv_fn=pop_deriv_fn,
                    W10list=W10list.copy(),
                    W20list=W20list.copy(),
                    bounds1=bounds1,
                    bounds2=bounds2,
                    niter=niter,
                    wt_dict=wt_dict,
                    l2_penalty=l2_penalty,
                    compute_hessian=False,
                    dt=dt,
                    perturbation_size=perturbation_size,
                    dYY=dYY,
                    constrain_isn=constrain_isn,
                    tv=tv,
                    opto_mask=opto_mask,
                    use_opto_transforms=use_opto_transforms,
                    opto_transform1=opto_transform1,
                    opto_transform2=opto_transform2,
                    share_residuals=share_residuals,
                    stimwise=stimwise,
                    simulate1=simulate1,
                    simulate2=simulate2,
                    verbose=verbose)

    #def parse_W(W):
    #    Wmx,Wmy,Wsx,Wsy,s02,K,kappa,T,XX,XXp,Eta,Xi,h1,h2 = W
    #    return Wmx,Wmy,Wsx,Wsy,s02,K,kappa,T,XX,XXp,Eta,Xi,h1,h2
    def parse_W1(W):
        Wmx, Wmy, Wsx, Wsy, s02, K, kappa, T, h1, h2, bl, amp = W
        return Wmx, Wmy, Wsx, Wsy, s02, K, kappa, T, h1, h2, bl, amp

    def parse_W2(W):
        XX, XXp, Eta, Xi = W
        return XX, XXp, Eta, Xi

    itry = 0
    Wmx, Wmy, Wsx, Wsy, s02, K, kappa, T, h1, h2, bl, amp = parse_W1(W1t[0][0])
    XX, XXp, Eta, Xi = parse_W2(W2t[0][0])

    labels1 = [
        'Wmx', 'Wmy', 'Wsx', 'Wsy', 's02', 'K', 'kappa', 'T', 'h1', 'h2', 'bl',
        'amp'
    ]
    labels2 = ['XX', 'XXp', 'Eta', 'Xi']
    Wstar_dict = {}
    for i, label in enumerate(labels1):
        Wstar_dict[label] = W1t[0][0][i]
    for i, label in enumerate(labels2):
        Wstar_dict[label] = W2t[0][0][i]
    Wstar_dict['as_list'] = [
        Wmx, Wmy, Wsx, Wsy, s02, K, kappa, T, XX, XXp, Eta, Xi, h1, h2, bl, amp
    ]
    Wstar_dict['loss'] = loss[0][0]
    Wstar_dict['wt_dict'] = wt_dict
    np.save(weights_file, Wstar_dict, allow_pickle=True)
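# Standalone illustration of the bound codes used in fit_weights_and_save above
# (0, +/-1, 1.5, +/-2, 3): each code is translated into a (lower, upper) pair.
import numpy as np

codes = np.array([0., 1., -1., 1.5, 2., -2., 3.])
lb = np.full_like(codes, -np.inf)
ub = np.full_like(codes, np.inf)
lb[codes == 0], ub[codes == 0] = 0, 0      # pinned to 0
lb[codes == 1], ub[codes == 1] = 1, 1      # pinned to +1
lb[codes == -1], ub[codes == -1] = -1, -1  # pinned to -1
lb[codes == 1.5], ub[codes == 1.5] = 0, 1  # confined to [0, 1]
lb[codes == 2] = 0                         # confined to [0, inf)
ub[codes == -2] = 0                        # confined to (-inf, 0]
# code 3 is left unconstrained at (-inf, inf)
bounds = list(zip(lb, ub))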
def DSoptimizeSingleTraj(ftr,
                         weights,
                         featureList,
                         fixHead=2,
                         divlim=0.5,
                         feedGradient=True):
    """
    ##### arg traj_to_simu:   [n x 2], each point is (s, d)
    arg weights:        The weights of the features, (should consider the normalizer)
    arg featureList:    The features to compute, same order with weights
    arg ftr:            The feature object to compute features, which holds the refline and other information
    arg fixHead:        Constrain the first `fixHead` time steps to match the initial condition
    arg divlim:         The limit of the deviation
    """
    traj_to_simu = ftr.spaceTransFromXY()
    FIXDIV = divlim == 0  # and False
    timescale = len(traj_to_simu)
    len_dsvec = timescale if FIXDIV else timescale * 2
    assert (len(weights) == len(featureList))

    def objective(ds_vec):
        # print(ds_vec)
        if (FIXDIV):
            ds_vec = np.concatenate([ds_vec, np.zeros_like(ds_vec)], axis=0)
        # ds_vec = ds_vec.reshape(-1,2)
        ds_vec = ds_vec.reshape(2, -1).T
        ftr.update(ds_vec)
        obj = np.sum(
            [ftr.featureValue(f) * w for w, f in zip(weights, featureList)])
        return obj

    def jac(ds_vec):
        if (FIXDIV):
            ds_vec = np.concatenate([ds_vec, np.zeros_like(ds_vec)], axis=0)
        # ds_vec = ds_vec.reshape(-1,2)
        ds_vec = ds_vec.reshape(2, -1).T
        ftr.update(ds_vec)
        g = np.sum([
            ftr.featureGradJacobNormalizer(f, False, False)[0][:] * w
            for w, f in zip(weights, featureList)
        ],
                   axis=0)
        if (FIXDIV):
            return g[:len(g) // 2]
        # print(g.shape)
        return g

    def Hes(ds_vec):
        """
        This is not used by SQP
        """
        H = np.sum([
            ftr.featureGradJacobNormalizer(f, False)[1][:, :] * w
            for w, f in zip(weights, featureList)
        ],
                   axis=0)
        if (FIXDIV):
            return H[:timescale, :timescale]
        # print(g.shape)
        return H

    def seq_cons(ds_vec):
        """
        This constraint keeps the s coordinates non-decreasing along the
        trajectory and caps the final s value.
        """
        MaxLengthFactor = 1.5  # the optimized trajectory cannot end more than this factor times the original length
        if (FIXDIV):
            # return ds_vec[1:] - ds_vec[:-1]
            return np.array(
                list(ds_vec[1:] - ds_vec[:-1]) +
                [MaxLengthFactor * traj_to_simu[-1, 0] - ds_vec[-1]])
        ds_vec = ds_vec.reshape(2, -1).T
        return np.array(
            list(ds_vec[1:, 0] - ds_vec[:-1, 0]) +
            [MaxLengthFactor * traj_to_simu[-1, 0] - ds_vec[-1, 0]])

    seq_jac = np.zeros((timescale, len_dsvec))
    seq_jac += np.eye(timescale) * (-1)
    seq_jac[:-1, 1:timescale] += np.eye(timescale - 1) * 1

    def init_cons(ds_vec):
        """
        This constraint pins the first `fixHead` points to the initial trajectory.
        """
        # ds_vec = ds_vec.reshape(-1,2)
        if (FIXDIV):
            return ds_vec[:fixHead] - traj_to_simu[:fixHead, 0]
        ds_vec = ds_vec.reshape(2, -1).T
        return (ds_vec[:fixHead] - traj_to_simu[:fixHead]).T.reshape(-1)

    if (FIXDIV):
        init_jac = np.zeros((fixHead, len_dsvec))
        init_jac[:fixHead, :fixHead] = np.eye(fixHead) * 1
    else:
        init_jac = np.zeros((fixHead * 2, len_dsvec))
        init_jac[:fixHead, :fixHead] = np.eye(fixHead) * 1
        init_jac[fixHead:, timescale:timescale + fixHead] = np.eye(fixHead) * 1

    if (FIXDIV):
        lb, ub = 0, float("inf")
        bounds = np.tile(np.array([lb, ub]), (len(traj_to_simu), 1))
    else:
        lb = (np.array([0, -divlim]) *
              np.ones_like(traj_to_simu)).T.reshape(-1)
        ub = (np.array([float("inf"), divlim]) *
              np.ones_like(traj_to_simu)).T.reshape(-1)
        bounds = np.concatenate([lb[:, None], ub[:, None]], axis=1)
    if (FIXDIV):
        x0 = traj_to_simu[:, 0]
    else:
        x0 = traj_to_simu.T.reshape(-1)

    constraints = [{
        'type': 'ineq',
        'fun': seq_cons,
        'jac': lambda v: seq_jac
    }, {
        'type': 'eq',
        'fun': init_cons,
        'jac': lambda v: init_jac
    }]
    options = {"maxiter": 50000, "disp": 2}
    if (not feedGradient):
        jac = None
    # print(bounds[:,0])
    # print(x0)
    # assert(np.all(bounds[:,0]<=x0) and np.all(x0<=bounds[:,1]) and np.all(seq_cons(x0)>=0))
    res = minimize(objective,
                   x0,
                   bounds=bounds,
                   jac=jac,
                   hess=Hes,
                   constraints=constraints,
                   options=options)
    if (FIXDIV):
        return np.concatenate([res.x, np.zeros_like(res.x)], axis=0)
    # assert(res.success)
    return res.x
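# The packing convention assumed by DSoptimizeSingleTraj above: the (s, d)
# trajectory is flattened with all s values first, then all d values, and
# unpacked with reshape(2, -1).T. A quick round-trip check:
import numpy as np

traj = np.array([[0.0, 0.1],
                 [1.0, 0.2],
                 [2.0, 0.3]])       # n x 2 points, columns are (s, d)
ds_vec = traj.T.reshape(-1)         # [0.0, 1.0, 2.0, 0.1, 0.2, 0.3]
assert np.allclose(ds_vec.reshape(2, -1).T, traj)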
Example #10
# plt.plot(outputs_training, '.')
# misfit.inputs = inputs_training
# outputs_forward = misfit.forward(x)
# # print("outputs_forward = ", outputs_forward)
# plt.plot(outputs_forward, 'r.')
#
# plt.subplot(2, 1, 2)
# plt.plot(outputs_testing, '.')
# misfit.inputs = inputs_testing
# outputs_forward = misfit.forward(x)
# # print("outputs_forward = ", outputs_forward)
# plt.plot(outputs_forward, 'r.')
# plt.show()

# misfit.inputs = inputs_training
sigma = 10 * np.ones_like(x)
prior = FiniteDimensionalPrior(dimension, sigma=sigma, uniform=True)
# prior = FiniteDimensionalPrior(dimension)
prior.CovSqrt = 1 * np.diag(np.ones_like(x))

model = Model(prior, misfit)

if __name__ == "__main__":

    test_time = time.time()
    for i in range(100):
        x0 = np.random.normal(0, 1, dimension)
        prior.cost(x)
    print("average cost time = ", (time.time() - test_time) / 100)

    # check_gradient(x0, misfit)
Example #11
    misfit = Misfit(dimension, beta)
    u = misfit.solution(x)
    loc = np.arange(0, dimension, 5)
    sigma = 0.1
    obs = u[loc] + sigma + 0.*np.random.normal(0, sigma, len(loc))
    noise_covariance = np.diag(sigma**2*np.ones(len(loc)))
    misfit = Misfit(dimension, beta, obs, loc, noise_covariance)

    misfit.x = x
    model = Model(prior, misfit)

    d, U = misfit.eigdecomp(x, k=dimension-1)

    # print("d, U = ", d, U)

    x = np.ones_like(x)
    cost = misfit.cost(x)
    g = np.zeros(dimension)
    gx = misfit.grad(x, g)
    xhat = np.ones(dimension)
    h = np.zeros(dimension)
    misfit.hvp.update_x(x)
    misfit.hvp.mult(xhat, h)
    print("cost, grad = ", cost, g)
    # print("hvp", h)

    mg = np.zeros(dimension)
    model.gradient(x, mg, misfit_only=False)
    # print("model cost, grad = ", model.cost(x), mg)

    comm = MPI.COMM_WORLD
plt.legend(fontsize=10)

plt.tight_layout()
# fig.savefig('{}/fig-single-traj.{}'.format(args.fig_dir, FORMAT))

#%% [markdown]
# ## Plot learnt functions
#%%
fig = plt.figure(figsize=(9.6, 2.5), dpi=DPI)
q = np.linspace(-5.0, 5.0, 40)
q_tensor = torch.tensor(q, dtype=torch.float32).view(40, 1).to(device)

plt.subplot(1, 3, 1)

g_q = symoden_ode_struct_model.g_net(q_tensor)
plt.plot(q, np.ones_like(q), label='Ground Truth', color='k', linewidth=2)
plt.plot(q,
         g_q.detach().cpu().numpy(),
         'b--',
         linewidth=3,
         label=r'SymODEN $g_{\theta_3}(q)$')
plt.xlabel("$q$", fontsize=14)
# plt.ylabel("$g(q)$", rotation=0, fontsize=14)
plt.title("$g(q)$", pad=10, fontsize=14)
plt.xlim(-5, 5)
plt.ylim(0, 4)
plt.legend(fontsize=10)

M_q_inv = symoden_ode_struct_model.M_net(q_tensor)
plt.subplot(1, 3, 2)
plt.plot(q, 3 * np.ones_like(q), label='Ground Truth', color='k', linewidth=2)
Example #13
def fit_weights_and_save(weights_file,
                         ca_data_file='rs_vm_denoise_200605.npy',
                         opto_data_file='vip_halo_data_for_sim.npy',
                         constrain_wts=None,
                         allow_var=True,
                         multiout=True,
                         multiout2=False,
                         fit_s02=True,
                         constrain_isn=True,
                         tv=False,
                         l2_penalty=0.01,
                         init_noise=0.1,
                         init_W_from_lsq=False,
                         scale_init_by=1,
                         init_W_from_file=False,
                         init_file=None):

    nsize, ncontrast = 6, 6

    npfile = np.load(ca_data_file, allow_pickle=True)[(
    )]  #,{'rs':rs},allow_pickle=True) # ,'rs_denoise':rs_denoise
    rs = npfile['rs']

    nsize, ncontrast, ndir = 6, 6, 8
    ori_dirs = [[0, 4], [2, 6]]  #[[0,4],[1,3,5,7],[2,6]]
    nT = len(ori_dirs)
    nS = len(rs[0])

    def sum_to_1(r):
        R = r.reshape((r.shape[0], -1))
        R = R / np.nansum(R, axis=1)[:, np.newaxis]  # changed 8/28
        return R

    def norm_to_mean(r):
        R = r.reshape((r.shape[0], -1))
        R = R / np.nanmean(R[:, ~np.isnan(R.sum(0))], axis=1)[:, np.newaxis]
        return R

    Rs = [[None, None] for i in range(len(rs))]
    Rso = [[[None for iT in range(nT)] for iS in range(nS)]
           for icelltype in range(len(rs))]
    rso = [[[None for iT in range(nT)] for iS in range(nS)]
           for icelltype in range(len(rs))]

    for iR, r in enumerate(rs):  #rs_denoise):
        print(iR)
        for ialign in range(nS):
            Rs[iR][ialign] = sum_to_1(r[ialign][:, :nsize, :])

    kernel = np.ones((1, 2, 2))
    kernel = kernel / kernel.sum()

    for iR, r in enumerate(rs):
        for ialign in range(nS):
            for iori in range(nT):
                Rso[iR][ialign][iori] = np.nanmean(
                    Rs[iR][ialign].reshape(
                        (-1, nsize, ncontrast, ndir))[:, :, :, ori_dirs[iori]],
                    -1)
                Rso[iR][ialign][iori][:, :, 0] = np.nanmean(
                    Rso[iR][ialign][iori][:, :, 0], 1)[:, np.newaxis]

                Rso[iR][ialign][iori][:, 1:, 1:] = ssi.convolve(
                    Rso[iR][ialign][iori], kernel, 'valid')
                Rso[iR][ialign][iori] = Rso[iR][ialign][iori].reshape(
                    Rso[iR][ialign][iori].shape[0], -1)

    def set_bound(bd, code, val=0):
        # set bounds to 0 where 0s occur in 'code'
        for iitem in range(len(bd)):
            bd[iitem][code[iitem]] = val

    nN = 36
    nS = 2
    nP = 2
    nT = 2
    nQ = 4

    # code for bounds: 0 , constrained to 0
    # +/-1 , constrained to +/-1
    # 1.5, constrained to [0,1]
    # 2 , constrained to [0,inf)
    # -2 , constrained to (-inf,0]
    # 3 , unconstrained

    W0x_bounds = 3 * np.ones((nP, nQ), dtype=int)
    W0x_bounds[0, 1] = 0  # SSTs don't receive L4 input

    if allow_var:
        W1x_bounds = 3 * np.ones(
            W0x_bounds.shape)  #W0x_bounds.copy()*0 #np.zeros_like(W0x_bounds)
        W1x_bounds[0, 1] = 0
    else:
        W1x_bounds = np.zeros(
            W0x_bounds.shape)  #W0x_bounds.copy()*0 #np.zeros_like(W0x_bounds)

    W0y_bounds = 3 * np.ones((nQ, nQ), dtype=int)
    W0y_bounds[0, :] = 2  # PCs are excitatory
    W0y_bounds[1:, :] = -2  # all the cell types except PCs are inhibitory
    W0y_bounds[1, 1] = 0  # SSTs don't inhibit themselves
    # W0y_bounds[3,1] = 0 # PVs are allowed to inhibit SSTs, consistent with Hillel's unpublished results, but not consistent with Pfeffer et al.
    W0y_bounds[2, 0] = 0  # VIPs don't inhibit L2/3 PCs. According to Pfeffer et al., only L5 PCs were found to get VIP inhibition

    if constrain_wts is not None:
        for wt in constrain_wts:
            W0y_bounds[wt[0], wt[1]] = 0
            # note: Wsy_bounds is not defined in this function; W1y_bounds
            # (defined below) appears to be the intended target here
            Wsy_bounds[wt[0], wt[1]] = 0

    def tile_nS_nT_nN(kernel):
        row = np.concatenate([kernel for idim in range(nS * nT)],
                             axis=0)[np.newaxis, :]
        tiled = np.concatenate([row for irow in range(nN)], axis=0)
        return tiled

    if fit_s02:
        s02_bounds = 2 * np.ones(
            (nQ, ))  # permitting noise as a free parameter
    else:
        s02_bounds = np.ones((nQ, ))

    k0_bounds = 1.5 * np.ones((nQ, ))

    kappa_bounds = np.ones((1, ))
    # kappa_bounds = 2*np.ones((1,))

    T0_bounds = 1.5 * np.ones((nQ, ))
    #T_bounds[2:4] = 1 # PV and VIP are constrained to have flat ori tuning
    #T0_bounds[1:4] = 1 # SST,VIP, and PV are constrained to have flat ori tuning

    if allow_var:
        W1y_bounds = 3 * np.ones(
            W0y_bounds.shape)  #W0y_bounds.copy()*0 #np.zeros_like(W0y_bounds)
        W1y_bounds[1, 1] = 0
        W1y_bounds[3, 1] = 0
        W1y_bounds[2, 0] = 0
        k1_bounds = 3 * np.ones(
            k0_bounds.shape)  #W0y_bounds.copy()*0 #np.zeros_like(W0y_bounds)
        T1_bounds = 3 * np.ones(
            T0_bounds.shape)  #W0y_bounds.copy()*0 #np.zeros_like(W0y_bounds)
    else:
        W1y_bounds = np.zeros(
            W0y_bounds.shape)  #W0y_bounds.copy()*0 #np.zeros_like(W0y_bounds)
        k1_bounds = 0 * np.ones(
            k0_bounds.shape)  #W0y_bounds.copy()*0 #np.zeros_like(W0y_bounds)
        T1_bounds = 0 * np.ones(
            T0_bounds.shape)  #W0y_bounds.copy()*0 #np.zeros_like(W0y_bounds)

    if multiout:
        W2x_bounds = W1x_bounds.copy()
        W2y_bounds = W1y_bounds.copy()
        if multiout2:
            W3x_bounds = W1x_bounds.copy()
            W3y_bounds = W1y_bounds.copy()
        else:
            W3x_bounds = W1x_bounds.copy() * 0
            W3y_bounds = W1y_bounds.copy() * 0
    else:
        W2x_bounds = W1x_bounds.copy() * 0
        W2y_bounds = W1y_bounds.copy() * 0
        W3x_bounds = W1x_bounds.copy() * 0
        W3y_bounds = W1y_bounds.copy() * 0
    k2_bounds = k1_bounds.copy() * 0
    T2_bounds = T1_bounds.copy() * 0
    k3_bounds = k1_bounds.copy() * 0
    T3_bounds = T1_bounds.copy() * 0

    X_bounds = tile_nS_nT_nN(np.array([2, 1]))
    # X_bounds = np.array([np.array([2,1,2,1])]*nN)

    Xp_bounds = tile_nS_nT_nN(np.array([3, 1]))
    # Xp_bounds = np.array([np.array([3,1,3,1])]*nN)

    # Y_bounds = tile_nS_nT_nN(2*np.ones((nQ,)))
    # # Y_bounds = 2*np.ones((nN,nT*nS*nQ))

    Eta_bounds = tile_nS_nT_nN(3 * np.ones((nQ, )))
    # Eta_bounds = 3*np.ones((nN,nT*nS*nQ))

    if allow_var:
        Xi_bounds = tile_nS_nT_nN(3 * np.ones((nQ, )))
    else:
        Xi_bounds = tile_nS_nT_nN(np.zeros((nQ, )))

    # Xi_bounds = 3*np.ones((nN,nT*nS*nQ))

    h_bounds = -2 * np.ones((1, ))

    # shapes = [(nP,nQ),(nQ,nQ),(nP,nQ),(nQ,nQ),(nQ,),(nQ,),(1,),(nN,nS*nP),(nN,nS*nQ),(nN,nS*nQ),(nN,nS*nQ)]
    shapes = [(nP, nQ), (nQ, nQ), (nP, nQ), (nQ, nQ), (nP, nQ), (nQ, nQ),
              (nP, nQ), (nQ, nQ), (nQ, ), (nQ * (nS - 1), ), (nQ * (nS - 1), ),
              (nQ * (nS - 1), ), (nQ * (nS - 1), ), (1, ), (nQ * (nT - 1), ),
              (nQ * (nT - 1), ), (nQ * (nT - 1), ), (nQ * (nT - 1), ),
              (nN, nT * nS * nP), (nN, nT * nS * nP), (nN, nT * nS * nQ),
              (nN, nT * nS * nQ), (1, )]
    #         W0x,    W0y,    W1x,    W1y,    W2x,    W2y,    W3x,    W3y,    s02,  k,    kappa,T,   XX,            XXp,          Eta,          Xi

    lb = [-np.inf * np.ones(shp) for shp in shapes]
    ub = [np.inf * np.ones(shp) for shp in shapes]
    bdlist = [
        W0x_bounds, W0y_bounds, W1x_bounds, W1y_bounds, W2x_bounds, W2y_bounds,
        W3x_bounds, W3y_bounds, s02_bounds, k0_bounds, k1_bounds, k2_bounds,
        k3_bounds, kappa_bounds, T0_bounds, T1_bounds, T2_bounds, T3_bounds,
        X_bounds, Xp_bounds, Eta_bounds, Xi_bounds, h_bounds
    ]

    #print([b.shape for b in bdlist])
    #print(np.sum([b.size for b in bdlist]))

    set_bound(lb, [bd == 0 for bd in bdlist], val=0)
    set_bound(ub, [bd == 0 for bd in bdlist], val=0)

    set_bound(lb, [bd == 2 for bd in bdlist], val=0)

    set_bound(ub, [bd == -2 for bd in bdlist], val=0)

    set_bound(lb, [bd == 1 for bd in bdlist], val=1)
    set_bound(ub, [bd == 1 for bd in bdlist], val=1)

    set_bound(lb, [bd == 1.5 for bd in bdlist], val=0)
    set_bound(ub, [bd == 1.5 for bd in bdlist], val=1)

    set_bound(lb, [bd == -1 for bd in bdlist], val=-1)
    set_bound(ub, [bd == -1 for bd in bdlist], val=-1)

    # for bd in [lb,ub]:
    #     for ind in [2,3]:
    #         bd[ind][:,1] = 0

    # temporary for no variation expt.
    # lb[2] = np.zeros_like(lb[2])
    # lb[3] = np.zeros_like(lb[3])
    # lb[4] = np.ones_like(lb[4])
    # lb[5] = np.zeros_like(lb[5])
    # ub[2] = np.zeros_like(ub[2])
    # ub[3] = np.zeros_like(ub[3])
    # ub[4] = np.ones_like(ub[4])
    # ub[5] = np.ones_like(ub[5])
    # temporary for no variation expt.
    lb = np.concatenate([a.flatten() for a in lb])
    ub = np.concatenate([b.flatten() for b in ub])
    bounds = [(a, b) for a, b in zip(lb, ub)]

    nS = 2
    ndims = 5
    ncelltypes = 5
    Yhat = [[None for iT in range(nT)] for iS in range(nS)]
    Xhat = [[None for iT in range(nT)] for iS in range(nS)]
    Ypc_list = [[None for iT in range(nT)] for iS in range(nS)]
    Xpc_list = [[None for iT in range(nT)] for iS in range(nS)]
    for iS in range(nS):
        mx = np.zeros((ncelltypes, ))
        yy = [None for icelltype in range(ncelltypes)]
        for icelltype in range(ncelltypes):
            yy[icelltype] = np.nanmean(Rso[icelltype][iS][0], 0)
            mx[icelltype] = np.nanmax(yy[icelltype])
        for iT in range(nT):
            y = [
                np.nanmean(Rso[icelltype][iS][iT], axis=0)[:, np.newaxis] /
                mx[icelltype] for icelltype in range(1, ncelltypes)
            ]
            Ypc_list[iS][iT] = [None for icelltype in range(1, ncelltypes)]
            for icelltype in range(1, ncelltypes):
                rss = Rso[icelltype][iS][iT].copy(
                )  #.reshape(Rs[icelltype][ialign].shape[0],-1)
                rss = rss[np.isnan(rss).sum(1) == 0]
                try:
                    u, s, v = np.linalg.svd(rss - np.mean(rss, 0)[np.newaxis])
                    Ypc_list[iS][iT][icelltype - 1] = [
                        (s[idim], v[idim]) for idim in range(ndims)
                    ]
                except np.linalg.LinAlgError:
                    print('SVD failed for Y, cell type %d' % icelltype)
            Yhat[iS][iT] = np.concatenate(y, axis=1)
            icelltype = 0
            x = np.nanmean(Rso[icelltype][iS][iT],
                           0)[:, np.newaxis] / mx[icelltype]
            Xhat[iS][iT] = np.concatenate((x, np.ones_like(x)), axis=1)
            icelltype = 0
            rss = Rso[icelltype][iS][iT].copy()
            rss = rss[np.isnan(rss).sum(1) == 0]
            u, s, v = np.linalg.svd(rss - rss.mean(0)[np.newaxis])
            Xpc_list[iS][iT] = [None for iinput in range(2)]
            Xpc_list[iS][iT][0] = [(s[idim], v[idim]) for idim in range(ndims)]
            Xpc_list[iS][iT][1] = [(0, np.zeros((Xhat[0][0].shape[0], )))
                                   for idim in range(ndims)]
    nN, nP = Xhat[0][0].shape
    nQ = Yhat[0][0].shape[1]

    def compute_f_(Eta, Xi, s02):
        return sim_utils.f_miller_troyer(
            Eta, Xi**2 + np.concatenate([s02 for ipixel in range(nS * nT)]))

    def compute_fprime_m_(Eta, Xi, s02):
        return sim_utils.fprime_miller_troyer(
            Eta, Xi**2 + np.concatenate([s02
                                         for ipixel in range(nS * nT)])) * Xi

    def compute_fprime_s_(Eta, Xi, s02):
        s2 = Xi**2 + np.concatenate((s02, s02), axis=0)
        return sim_utils.fprime_s_miller_troyer(Eta, s2) * (Xi / s2)
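    # Note (added): these helpers evaluate the Miller-Troyer rate nonlinearity and its
    # derivatives; the effective variance is Xi**2 plus s02 tiled across the nS*nT stimulus
    # blocks, so s02 acts as a per-cell-type noise floor shared across conditions.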

    def sorted_r_eigs(w):
        drW, prW = np.linalg.eig(w)
        srtinds = np.argsort(drW)
        return drW[srtinds], prW[:, srtinds]
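    # Note (added): np.argsort on a complex eigenvalue array sorts by real part (ties broken by
    # imaginary part), so sorted_r_eigs returns eigenpairs in ascending order of the real part.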

    #         0.W0x,  1.W0y,  2.W1x,  3.W1y,  4.W2x,  5.W2y,  6.W3x,  7.W3y,  8.s02,9.K,  10.kappa,11.T,12.XX,        13.XXp,        14.Eta,       15.Xi    16.h

    shapes = [(nP, nQ), (nQ, nQ), (nP, nQ), (nQ, nQ), (nP, nQ), (nQ, nQ),
              (nP, nQ), (nQ, nQ), (nQ, ), (nQ * (nS - 1), ), (nQ * (nS - 1), ),
              (nQ * (nS - 1), ), (nQ * (nS - 1), ), (1, ), (nQ * (nT - 1), ),
              (nQ * (nT - 1), ), (nQ * (nT - 1), ), (nQ * (nT - 1), ),
              (nN, nT * nS * nP), (nN, nT * nS * nP), (nN, nT * nS * nQ),
              (nN, nT * nS * nQ), (1, )]

    import calnet.fitting_spatial_feature_opto_multiout
    import sim_utils

    opto_dict = np.load(opto_data_file, allow_pickle=True)[()]

    Yhat_opto = opto_dict['Yhat_opto']
    Yhat_opto = Yhat_opto / Yhat_opto[0::2].max(0)[np.newaxis, :]
    #print(Yhat_opto.shape)
    h_opto = opto_dict['h_opto']
    dYY = Yhat_opto[1::2] - Yhat_opto[0::2]
    for to_overwrite in [1, 2, 5, 6]:
        dYY[:, to_overwrite] = dYY[:, to_overwrite + 8]
    for to_overwrite in [11, 15]:
        dYY[:, to_overwrite] = dYY[:, to_overwrite - 8]

    from importlib import reload
    reload(calnet)
    reload(calnet.fitting_spatial_feature_opto_multiout)
    reload(sim_utils)
    wt_dict = {}
    wt_dict['X'] = 1
    wt_dict['Y'] = 3
    wt_dict['Eta'] = 1  # 10
    wt_dict['Xi'] = 0.1
    wt_dict['stims'] = np.ones((nN, 1))  #(np.arange(30)/30)[:,np.newaxis]**1 #
    wt_dict['barrier'] = 0.  #30.0 #0.1
    wt_dict['opto'] = 1e0  #1e-1#1e1
    wt_dict['isn'] = 0.1
    wt_dict['tv'] = 1

    YYhat = calnet.utils.flatten_nested_list_of_2d_arrays(Yhat)
    XXhat = calnet.utils.flatten_nested_list_of_2d_arrays(Xhat)
    Eta0 = invert_f_mt(YYhat)

    ntries = 1
    nhyper = 1
    dt = 1e-1
    niter = int(np.round(10 / dt))  #int(1e4)
    perturbation_size = 5e-2
    Wt = [[None for itry in range(ntries)] for ihyper in range(nhyper)]
    loss = np.zeros((nhyper, ntries))
    is_neg = np.array([b[1] for b in bounds]) == 0
    counter = 0
    negatize = [np.zeros(shp, dtype='bool') for shp in shapes]
    for ishp, shp in enumerate(shapes):
        nel = np.prod(shp)
        negatize[ishp][:][is_neg[counter:counter + nel].reshape(shp)] = True
        counter = counter + nel
    for ihyper in range(nhyper):
        for itry in range(ntries):
            print((ihyper, itry))
            W0list = [
                init_noise * (ihyper + 1) * np.random.rand(*shp)
                for shp in shapes
            ]
            counter = 0
            for ishp, shp in enumerate(shapes):
                W0list[ishp][negatize[ishp]] = -W0list[ishp][negatize[ishp]]
            nextraW = 4
            nextraK = nextraW + 3
            nextraT = nextraK + 3
            W0list[nextraW + 4] = np.ones(shapes[nextraW + 4])  # s02
            W0list[nextraW + 5] = np.ones(shapes[nextraW + 5])  # K
            W0list[nextraW + 6] = np.ones(shapes[nextraW + 6])  # K
            W0list[nextraW + 7] = np.ones(shapes[nextraW + 7])  # K
            W0list[nextraW + 8] = np.ones(shapes[nextraW + 8])  # K
            W0list[nextraK + 6] = np.ones(shapes[nextraK + 6])  # kappa
            W0list[nextraK + 7] = np.ones(shapes[nextraK + 7])  # T
            W0list[nextraK + 8] = np.ones(shapes[nextraK + 8])  # T
            W0list[nextraK + 9] = np.ones(shapes[nextraK + 9])  # T
            W0list[nextraK + 10] = np.ones(shapes[nextraK + 10])  # T
            W0list[nextraT + 8] = np.concatenate(Xhat, axis=1)  #XX
            W0list[nextraT + 9] = np.zeros_like(W0list[nextraT + 8])  #XXp
            W0list[nextraT + 10] = Eta0  #np.zeros(shapes[nextraT+10]) #Eta
            W0list[nextraT + 11] = np.zeros(shapes[nextraT + 11])  #Xi
            #[Wmx,Wmy,Wsx,Wsy,s02,k,kappa,T,XX,XXp,Eta,Xi,h]
            if init_W_from_lsq:
                W0list[0], W0list[1] = initialize_W(Xhat,
                                                    Yhat,
                                                    scale_by=scale_init_by)
                for ivar in range(0, 2):
                    W0list[ivar] = W0list[ivar] + init_noise * np.random.randn(
                        *W0list[ivar].shape)
            if constrain_isn:
                W0list[1][0, 0] = 3
                W0list[1][0, 3] = 5
                W0list[1][3, 0] = -5
                W0list[1][3, 3] = -5

            if init_W_from_file:
                npyfile = np.load(init_file, allow_pickle=True)[()]
                W0list = npyfile['as_list']

                extra_Ws = [np.zeros_like(W0list[ivar]) for ivar in range(2)]
                extra_ks = [np.zeros_like(W0list[5]) for ivar in range(3)]
                extra_Ts = [np.zeros_like(W0list[7]) for ivar in range(3)]
                W0list = W0list[:4] + extra_Ws * 2 + W0list[
                    4:6] + extra_ks + W0list[6:8] + extra_Ts + W0list[8:]

                #W0list[7][0] = 0 # T

                # alternative initialization
                #n = 0.5
                #W0list[7][0] = 1/(n+1)*(W0list[7][0] + n*0) # T
                #W0list[7][3] = 1/(n+1)*(W0list[7][3] + n*1) # T
                #W0list[1][1,0] = W0list[1][1,0]

                #[W0x,W0y,W1x,W1y,W2x,W2y,W3x,W3y,s02,k,kappa,T,XX,XXp,Eta,Xi,h]
                for ivar in np.concatenate(
                    (np.arange(13), np.arange(14, 18))):  # Ws, s02, k
                    W0list[ivar] = W0list[ivar] + init_noise * np.random.randn(
                        *W0list[ivar].shape)
                #print([b.shape for b in bdlist])
                #print(np.sum([b.size for b in bdlist]))

            Wt[ihyper][itry], loss[ihyper][
                itry], gr, hess, result = calnet.fitting_spatial_feature_opto_multiout.fit_W_sim(
                    Xhat,
                    Xpc_list,
                    Yhat,
                    Ypc_list,
                    pop_rate_fn=sim_utils.f_miller_troyer,
                    pop_deriv_fn=sim_utils.fprime_miller_troyer,
                    neuron_rate_fn=sim_utils.evaluate_f_mt,
                    W0list=W0list.copy(),
                    bounds=bounds,
                    niter=niter,
                    wt_dict=wt_dict,
                    l2_penalty=l2_penalty,
                    compute_hessian=False,
                    dt=dt,
                    perturbation_size=perturbation_size,
                    dYY=dYY,
                    constrain_isn=constrain_isn,
                    tv=tv)

    def parse_W(W):
        W0x, W0y, W1x, W1y, W2x, W2y, W3x, W3y, s02, k0, k1, k2, k3, kappa, T0, T1, T2, T3, XX, XXp, Eta, Xi, h = W
        return W0x, W0y, W1x, W1y, W2x, W2y, W3x, W3y, s02, k0, k1, k2, k3, kappa, T0, T1, T2, T3, XX, XXp, Eta, Xi, h

    itry = 0
    W0x, W0y, W1x, W1y, W2x, W2y, W3x, W3y, s02, k0, k1, k2, k3, kappa, T0, T1, T2, T3, XX, XXp, Eta, Xi, h = parse_W(
        Wt[0][0])

    labels = [
        'W0x', 'W0y', 'W1x', 'W1y', 'W2x', 'W2y', 'W3x', 'W3y', 's02', 'K0',
        'K1', 'K2', 'K3', 'kappa', 'T0', 'T1', 'T2', 'T3', 'XX', 'XXp', 'Eta',
        'Xi', 'h'
    ]
    Wstar_dict = {}
    for i, label in enumerate(labels):
        Wstar_dict[label] = Wt[0][0][i]
    Wstar_dict['as_list'] = [
        W0x, W0y, W1x, W1y, W2x, W2y, W3x, W3y, s02, k0, k1, k2, k3, kappa, T0,
        T1, T2, T3, XX, XXp, Eta, Xi, h
    ]
    Wstar_dict['loss'] = loss[0][0]
    Wstar_dict['wt_dict'] = wt_dict
    np.save(weights_file, Wstar_dict, allow_pickle=True)
Example #14
 def most_likely_states(self, variational_mean, data, input=None, mask=None, tag=None):
     pi0 = self.init_state_distn.initial_state_distn
     Ps = self.transitions.transition_matrices(variational_mean, input, mask, tag)
     log_likes = self.dynamics.log_likelihoods(variational_mean, input, np.ones_like(variational_mean, dtype=bool), tag)
     log_likes += self.emissions.log_likelihoods(data, input, mask, tag, variational_mean)
     return viterbi(pi0, Ps, log_likes)
Example #15
    (zero mean and unit variance). The same procedure is then applied
    to the test set features.
    """
    train_mean = train.mean(axis=0)
    # +0.1 to avoid division by zero in this specific case
    train_std = train.std(axis=0) + 0.1

    train = (train - train_mean) / train_std
    test = (test - train_mean) / train_std
    return train, test
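# Note (added): the test set is standardized with the *training* mean and std, so no
# information from the test set leaks into the preprocessing.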


answers = mnist.target[:, np.newaxis]

data_and_answers = np.hstack(
    (np.ones_like(answers), mnist.data[:, :], answers))
np.random.shuffle(data_and_answers)

train_X = data_and_answers[:N, :-1]
train_y = data_and_answers[:N, -1].reshape((N, 1))

test_X = data_and_answers[N:, :-1]
test_y = data_and_answers[N:, -1].reshape((T, 1))

train_X, test_X = normalize_features(train_X, test_X)

calc_ksi = lambda W, X, N: (-np.log(
    np.power(np.sum(np.exp(np.dot(X, W)), axis=1), -1))).reshape((N, 1))
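# Note (added): calc_ksi computes log(sum_k exp(X @ W)[:, k]) row-wise -- the softmax
# log-normalizer -- written as -log(1 / sum_k exp(.)), reshaped to (N, 1).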

X_INITIAL = train_X
Y_INITIAL = train_y
Example #16
def fit(request, hyper_params=default_hyper_params, nperiod=288):
    passed_hyper_params = hyper_params
    hyper_params = {}
    hyper_params.update(passed_hyper_params)
    hyper_params.update(request.hyper_params)

    logging.info(f"fitting model with hyper parameters {hyper_params}")

    frame = make_frame(request, hyper_params=hyper_params)

    basal_insulin_curve = expia1(
        np.arange(nperiod),
        request.basal_insulin_parameters.get("delay", 5.0) / 5.0,
        request.basal_insulin_parameters["peak"] / 5.0,
        request.basal_insulin_parameters["duration"] / 5.0,
    )
    # TODO: make this the average carb curve
    default_carb_curve = carb_curve(np.arange(nperiod), 3, 36)

    # Set up parameter schedules.
    #
    # We arrange for each of basal, insulin sensitivity, and carb ratios
    # to have 24 windows in each day.
    #
    # TODO: assign windows for carb ratios based on data density
    #
    # TODO: find a better initialization strategy when no schedules are provided
    #
    # Order is: basals, insulin sensitivities, carb ratios
    if request.insulin_sensitivity_schedule is not None:
        init_insulin_sensitivity_params = attribute_parameters(
            basal_insulin_curve, request.insulin_sensitivity_schedule.index,
            request.insulin_sensitivity_schedule.values)
    else:
        init_insulin_sensitivity_params = 140 * np.ones(24)

    if request.carb_ratio_schedule is not None:
        init_carb_ratio_params = attribute_parameters(
            default_carb_curve, request.carb_ratio_schedule.index,
            request.carb_ratio_schedule.values)
    else:
        init_carb_ratio_params = 15. * np.ones(24)

    if request.basal_rate_schedule is not None:
        init_basal_rate_params = attribute_parameters(
            basal_insulin_curve, request.basal_rate_schedule.index,
            request.basal_rate_schedule.values)
    else:
        init_basal_rate_params = np.zeros(24)

    init_params = np.concatenate([
        init_basal_rate_params, init_insulin_sensitivity_params,
        init_carb_ratio_params
    ])

    def unpack_params(params):
        return params[:24], params[24:48], params[48:72]

    insulin = frame["insulin"].values
    carbs = frame["carb"].values
    deltas = frame["delta"].values

    hour = frame.index.hour
    quantile = hyper_params["quantile_loss_quantile"]

    # Construct bounds based on the allowable tuning limit.
    if request.tuning_limit is not None and request.tuning_limit > 0:
        bounds = list(
            zip(init_params * (1 - request.tuning_limit),
                init_params * (1 + request.tuning_limit)))
    else:
        bounds = None

    # Re-weight entries that have carbohydrate activity so that
    # the model prefers having (much) better carb parameters
    # over slightly worse-fitting sensitivity and basal parameters.
    weights = np.ones_like(deltas)
    weights[frame["carb"] > 0] = (np.sum(frame["carb"] == 0) /
                                  np.sum(frame["carb"] > 0))

    def model(params):
        basals, insulin_sensitivities, carb_ratios = unpack_params(params)
        basal = basals[hour]
        insulin_sensitivity = insulin_sensitivities[hour]
        carb_ratio = carb_ratios[hour]
        return insulin_sensitivity * (carbs / carb_ratio - insulin + basal)

    if bounds is not None:
        lower, upper = zip(*bounds)
        lower, upper = np.array(lower), np.array(upper)
        # This is a hack to get around the fact that basals are summed
        # over multiple hours. Thus this is only an approximate bounds,
        # but it's much simpler than the alternative.
        insulin_duration_hours = request.basal_insulin_parameters[
            "duration"] / 60.
        lower[:24] = lower[:24] / insulin_duration_hours
        upper[:24] = upper[:24] / insulin_duration_hours

    def loss(params, iter):
        preds = model(params)
        penalty = -10.0 * np.sum(np.minimum(params, 0.0))

        # Use a barrier function if bounds are provided.
        if bounds is not None:
            # HACK: simulate a "rectified" barrier function here.
            # Note also that this doesn't work for basals since they
            # are summed up.
            epsilon = 0.00001
            penalty_params = params.copy()
            penalty_params[penalty_params >= upper] = upper[penalty_params >= upper] - epsilon
            penalty_params[penalty_params <= lower] = lower[penalty_params <= lower] + epsilon
            penalty += np.sum(
                np.maximum(0., -0.01 * np.log(upper - penalty_params)))
            penalty += np.sum(
                np.maximum(0., -0.01 * np.log(penalty_params - lower)))

        # Quantile regression: 50 pctile
        error = weights * (deltas - preds)
        return np.mean(np.maximum(quantile * error,
                                  (quantile - 1.0) * error)) + penalty
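    # Note (added): loss() is the pinball (quantile) loss; for quantile = 0.5,
    # np.maximum(quantile * error, (quantile - 1.0) * error) reduces to 0.5 * |error|,
    # i.e. a weighted absolute-error (median) regression, plus the bound penalties above.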

    if hyper_params["optimizer"] == "adam":
        params, training_loss = train.minimize(loss, init_params)
    elif hyper_params["optimizer"] == "scipy.minimize":
        opt = optimize.minimize(loss, init_params, args=(0, ))
        params = opt.x
        training_loss = opt.fun

    # Clip the parameters here in case the loss penalties
    # above were insufficient.
    params = np.maximum(params, 0.0)
    basals, insulin_sensitivities, carb_ratios = unpack_params(params)

    # Now, infer parameter schedules based on the optimized
    # instantaneous parameters. For carbs, we use the average
    # carb curve based on data. We also use the basal insulin
    # parameters for ISF schedules.

    if request.basal_rate_schedule is None:
        # Default: hourly
        basal_rate_index = np.arange(0, 288, 12)
    else:
        basal_rate_index = request.basal_rate_schedule.reindexed(5)
    basal_rate_schedule = (identify_curve(
        basal_insulin_curve, basal_rate_index, np.repeat(basals, 12)) * 12)

    if request.insulin_sensitivity_schedule is None:
        insulin_sensitivity_index = np.arange(0, 288, 12 * 4)
    else:
        insulin_sensitivity_index = request.insulin_sensitivity_schedule.reindexed(
            5)
    insulin_sensitivity_schedule = identify_curve(
        basal_insulin_curve, insulin_sensitivity_index,
        np.repeat(insulin_sensitivities, 12))

    if request.carb_ratio_schedule is None:
        carb_ratio_index = 12 * 6 + np.arange(0, 12 * 12, 4 * 12)
    else:
        carb_ratio_index = request.carb_ratio_schedule.reindexed(5)
    carb_ratio_schedule = identify_curve(default_carb_curve, carb_ratio_index,
                                         np.repeat(carb_ratios, 12))

    # Finally, "quantize" the basal schedule if needed.
    #
    # TODO: Currently this simply tries to match the closest
    # allowable basal rate. We should try to push this up to the
    # model (e.g., the cost function could encourage values close to
    # allowable values), or split the schedule so so that the total
    # amount delivered over the scheduled intervals is equal to the
    # modeled amount, but the rate varies within the intervals.
    #
    # TODO: Another possibility is to perform one model run
    # to fit the basals, then another with the basals "fixed" to the
    # snapped values, allowing the model to adjust the other
    # parameters accordingly.
    #
    # TODO: collapse adjacent entries with the same value.
    if request.allowed_basal_rates is not None:
        allowed = sorted(request.allowed_basal_rates)
        for (i, rate) in enumerate(basal_rate_schedule):
            j = bisect.bisect(allowed, rate)
            # TODO: perhaps be a little more generous here,
            # snapping up when values are (much) closer.
            if j == 0 and rate != allowed[0]:
                basal_rate_schedule[i] = 0.0
            elif j >= len(allowed) or rate != allowed[j]:
                basal_rate_schedule[i] = allowed[j - 1]

    def make_schedule(index, schedule):
        assert len(index) == len(schedule)
        return ((5 * index).tolist(), schedule.tolist())
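    # Note (added): schedule indices are in 5-minute bins (nperiod = 288 bins per day), so
    # multiplying by 5 converts each index to minutes of the day (presumably from midnight).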

    return Model(
        params={
            "insulin_sensitivity_schedule":
            make_schedule(insulin_sensitivity_index,
                          insulin_sensitivity_schedule),
            "carb_ratio_schedule":
            make_schedule(carb_ratio_index, carb_ratio_schedule),
            "basal_rate_schedule":
            make_schedule(basal_rate_index, basal_rate_schedule),
        },
        raw_insulin_sensitivities=insulin_sensitivities,
        raw_carb_ratios=carb_ratios,
        raw_basals=basals,
        training_loss=training_loss,
    )
Example #17
smooth_y = true_lds.smooth(x, y)

plt.ion()
plt.figure()
plt.plot(x)

plt.figure()
for n in range(N):
    plt.plot(y[:, n] + 10 * n, '-k')
    # plt.plot(smooth_y[:, n] + 4 * n, 'r--')

# n=n+1
# plt.figure()
# plt.plot(y[:,n])
u = np.zeros((T,0))
mask = np.ones_like(y)
tag = None
lls = true_lds.emissions.log_likelihoods(y, u, mask, tag, x)

lambdas = true_lds.emissions.mean(true_lds.emissions.forward(x, u, tag))[:,0,:]
# plt.figure()
# for n in range(N):
#     plt.axhline(4*n, color='k')
#     plt.plot(lambdas[:,n] / bin_size + 4 * n, '-b')

# hess = true_lds.emissions.hessian_log_emissions_prob(y, u, mask, tag, x)
# hessian_analytical = block_diag(*hess)

# hess_autograd = hessian(lambda x : true_lds.emissions.log_likelihoods(y, u, mask, tag, x))
# hessian_autograd = hess_autograd(x).reshape((T*D), (T*D))
# print("Norm of difference: ", np.linalg.norm(hessian_analytical - hessian_autograd))
Example #18
    def fit(self,
            frequency,
            recency,
            T,
            weights=None,
            initial_params=None,
            verbose=False,
            tol=1e-7,
            index=None,
            **kwargs):
        """
        Fit a dataset to the BG/NBD model.

        Parameters
        ----------
        frequency: array_like
            the frequency vector of customers' purchases
            (denoted x in literature).
        recency: array_like
            the recency vector of customers' purchases
            (denoted t_x in literature).
        T: array_like
            customers' age (time units since first purchase)
        weights: None or array_like
            Number of customers with given frequency/recency/T,
            defaults to 1 if not specified. Fader and
            Hardie condense the individual RFM matrix into all
            observed combinations of frequency/recency/T. This
            parameter represents the count of customers with a given
            purchase pattern. Instead of calculating individual
            loglikelihood, the loglikelihood is calculated for each
            pattern and multiplied by the number of customers with
            that pattern.
        initial_params: array_like, optional
            set the initial parameters for the fitter.
        verbose : bool, optional
            set to true to print out convergence diagnostics.
        tol : float, optional
            tolerance for termination of the function minimization process.
        index: array_like, optional
            index for resulted DataFrame which is accessible via self.data
        kwargs:
            key word arguments to pass to the scipy.optimize.minimize
            function as options dict

        Returns
        -------
        BetaGeoFitter
            with additional properties like ``params_`` and methods like ``predict``
        """

        frequency = np.asarray(frequency).astype(int)
        recency = np.asarray(recency)
        T = np.asarray(T)
        _check_inputs(frequency, recency, T)

        if weights is None:
            weights = np.ones_like(recency, dtype=int)
        else:
            weights = np.asarray(weights)

        self._scale = _scale_time(T)
        scaled_recency = recency * self._scale
        scaled_T = T * self._scale

        log_params_, self._negative_log_likelihood_, self._hessian_ = self._fit(
            (frequency, scaled_recency, scaled_T, weights,
             self.penalizer_coef), initial_params, 4, verbose, tol, **kwargs)

        self.params_ = pd.Series(np.exp(log_params_),
                                 index=["r", "alpha", "a", "b"])
        self.params_["alpha"] /= self._scale

        self.data = pd.DataFrame(
            {
                "frequency": frequency,
                "recency": recency,
                "T": T,
                "weights": weights
            },
            index=index)

        self.generate_new_data = lambda size=1: beta_geometric_nbd_model(
            T, *self._unload_params("r", "alpha", "a", "b"), size=size)

        self.predict = self.conditional_expected_number_of_purchases_up_to_time

        self.variance_matrix_ = self._compute_variance_matrix()
        self.standard_errors_ = self._compute_standard_errors()
        self.confidence_intervals_ = self._compute_confidence_intervals()

        return self
Example #19
def test_laplace_em_hessian(N=5, K=3, D=2, T=20):
    for transitions in ["standard", "recurrent", "recurrent_only"]:
        for emissions in ["gaussian_orthog", "gaussian"]:
            print("Checking analytical hessian for transitions={},  "
                  "and emissions={}".format(transitions, emissions))
            slds = ssm.SLDS(N,
                            K,
                            D,
                            transitions=transitions,
                            dynamics="gaussian",
                            emissions=emissions)
            z, x, y = slds.sample(T)
            new_slds = ssm.SLDS(N,
                                K,
                                D,
                                transitions="standard",
                                dynamics="gaussian",
                                emissions=emissions)

            inputs = [np.zeros((T, 0))]
            masks = [np.ones_like(y)]
            tags = [None]
            method = "laplace_em"
            datas = [y]
            num_samples = 1

            def neg_expected_log_joint_wrapper(x_vec, T, D):
                x = x_vec.reshape(T, D)
                return new_slds._laplace_neg_expected_log_joint(
                    datas[0], inputs[0], masks[0], tags[0], x, Ez, Ezzp1)

            variational_posterior = new_slds._make_variational_posterior(
                "structured_meanfield", datas, inputs, masks, tags, method)
            new_slds._fit_laplace_em_discrete_state_update(
                variational_posterior, datas, inputs, masks, tags, num_samples)
            Ez, Ezzp1, _ = variational_posterior.discrete_expectations[0]

            x = variational_posterior.mean_continuous_states[0]
            scale = x.size
            J_diag, J_lower_diag = new_slds._laplace_hessian_neg_expected_log_joint(
                datas[0], inputs[0], masks[0], tags[0], x, Ez, Ezzp1)
            dense_hessian = scipy.linalg.block_diag(*[x for x in J_diag])
            dense_hessian[D:, :-D] += scipy.linalg.block_diag(
                *[x for x in J_lower_diag])
            dense_hessian[:-D, D:] += scipy.linalg.block_diag(
                *[x.T for x in J_lower_diag])
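            # Note (added): the dense Hessian assembled here is block-tridiagonal -- J_diag
            # blocks on the diagonal and J_lower_diag blocks (and their transposes) on the
            # sub-/super-diagonals -- reflecting the chain structure of the continuous states.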

            true_hess = hessian(neg_expected_log_joint_wrapper)(x.reshape(-1),
                                                                T, D)
            assert np.allclose(true_hess, dense_hessian)
            print("Hessian passed.")

            # Also check that computation of H works.
            h_dense = dense_hessian @ x.reshape(-1)
            h_dense = h_dense.reshape(T, D)

            J_ini, J_dyn_11, J_dyn_21, J_dyn_22, J_obs = new_slds._laplace_neg_hessian_params(
                datas[0], inputs[0], masks[0], tags[0], x, Ez, Ezzp1)
            h_ini, h_dyn_1, h_dyn_2, h_obs = new_slds._laplace_neg_hessian_params_to_hs(
                x, J_ini, J_dyn_11, J_dyn_21, J_dyn_22, J_obs)

            h = h_obs.copy()
            h[0] += h_ini
            h[:-1] += h_dyn_1
            h[1:] += h_dyn_2

            assert np.allclose(h, h_dense)
Example #20
def fit_weights_and_save(weights_file,ca_data_file='rs_sc_fg_pval_0_05_210410.npy',
        opto_silencing_data_file='vip_halo_data_for_sim.npy',
        opto_activation_data_file='vip_chrimson_data_for_sim.npy',
        constrain_wts=None,allow_var=True,multiout=True,multiout2=False,fit_s02=True,
        constrain_isn=True,tv=False,l2_penalty=0.01,l1_penalty=1.0,init_noise=0.1,
        init_W_from_lsq=False,scale_init_by=1,init_W_from_file=False,init_file=None,
        foldT=False,free_amplitude=False,correct_Eta=False,init_Eta_with_s02=False,
        no_halo_res=False,ignore_halo_vip=False,use_opto_transforms=False,
        norm_opto_transforms=False,nondim=False,fit_running=False,fit_non_running=True,
        fit_sc=True,fit_fg=False):
    
    
    nrun = 2
    nsize,ncontrast,ndir = 6,6,8
    nstim_fg = 5

    fit_both_running = (fit_non_running and fit_running)
    fit_both_stims = (fit_sc and fit_fg)

    if not fit_both_running:
        nrun = 1
        if fit_non_running:
            irun = 0
        elif fit_running:
            irun = 1

    nsc = nrun*nsize*ncontrast*ndir
    nfg = nrun*nstim_fg*ndir

    npfile = np.load(ca_data_file,allow_pickle=True)[()]#,{'rs':rs},allow_pickle=True) # ,'rs_denoise':rs_denoise
    if fit_both_running: 
        Rs_mean = npfile['Rs_mean_run']
        Rs_cov = npfile['Rs_cov_run']
        if not fit_both_stims:
            if fit_sc:
                Rs_mean,Rs_cov = get_Rs_slice(Rs_mean,Rs_cov,slice(None,nsc))
            elif fit_fg:
                Rs_mean,Rs_cov = get_Rs_slice(Rs_mean,Rs_cov,slice(nsc,None))
    else:
        Rs_mean = npfile['Rs_mean'][irun]
        Rs_cov = npfile['Rs_cov'][irun]
        if not fit_both_stims:
            if fit_sc:
                Rs_mean,Rs_cov = get_Rs_slice(Rs_mean,Rs_cov,slice(None,nsc))
            elif fit_fg:
                Rs_mean,Rs_cov = get_Rs_slice(Rs_mean,Rs_cov,slice(nsc,None))
    
    ori_dirs = [[0,4],[2,6]] #[[0,4],[1,3,5,7],[2,6]]
    ndims = 5
    nT = len(ori_dirs)
    nS = len(Rs_mean[0])
    
    def sum_to_1(r):
        R = r.reshape((r.shape[0],-1))
        R = R/np.nansum(R[:,~np.isnan(R.sum(0))],axis=1)[:,np.newaxis] # changed 21/4/10
        return R
    
    def norm_to_mean(r):
        R = r.reshape((r.shape[0],-1))
        R = R/np.nanmean(R[:,~np.isnan(R.sum(0))],axis=1)[:,np.newaxis]
        return R

    def ori_avg(Rs,these_ori_dirs):
        if fit_sc:
            rs_sc = np.nanmean(Rs[:nsc].reshape((nrun,nsize,ncontrast,ndir))[:,:,:,these_ori_dirs],-1)
            rs_sc[:,1:,1:] = ssi.convolve(rs_sc,kernel,'valid')
            rs_sc = rs_sc.reshape((nrun*nsize*ncontrast))
            if fit_fg:
                rs_fg = np.nanmean(Rs[nsc:].reshape((nrun,nstim_fg,ndir))[:,:,these_ori_dirs],-1)
                rs_fg = rs_fg.reshape((nrun*nstim_fg))
            else:
                rs_fg = np.zeros((0,))
        elif fit_fg:
            rs_sc = np.zeros((0,))
            rs_fg = np.nanmean(Rs.reshape((nrun,nstim_fg,ndir))[:,:,these_ori_dirs],-1)
            rs_fg = rs_fg.reshape((nrun*nstim_fg))
        Rso = np.concatenate((rs_sc,rs_fg))
        return Rso
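    # Note (added): ori_avg averages responses over the selected direction indices and, for the
    # size-by-contrast stimuli, box-smooths the (size, contrast) grid (written into [1:,1:])
    # with the 2x2 kernel defined just below.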
    
    Rso_mean = [[[None for iT in range(nT)] for iS in range(nS)] for icelltype in range(len(Rs_mean))]
    Rso_cov = [[[[[None,None] for idim in range(ndims)] for iT in range(nT)] for iS in range(nS)] for icelltype in range(len(Rs_mean))]
    
    kernel = np.ones((1,2,2))
    kernel = kernel/kernel.sum()
    
    for iR,r in enumerate(Rs_mean):
        for ialign in range(nS):
            for iori in range(nT):
                Rso_mean[iR][ialign][iori] = ori_avg(Rs_mean[iR][ialign],ori_dirs[iori])
                for idim in range(ndims):
                    Rso_cov[iR][ialign][iori][idim][0] = Rs_cov[iR][ialign][idim][0]
                    Rso_cov[iR][ialign][iori][idim][1] = ori_avg(Rs_cov[iR][ialign][idim][1],ori_dirs[iori])

    def set_bound(bd,code,val=0):
        # set bounds to 0 where 0s occur in 'code'
        for iitem in range(len(bd)):
            bd[iitem][code[iitem]] = val
    
    nN = (36*fit_sc + 5*fit_fg)*(1 + fit_both_running)
    nS = 2
    nP = 2 + fit_both_running
    nT = 2
    nQ = 4

    ndims = 5
    ncelltypes = 5
    #print('foldT: %d'%foldT)
    if foldT:
        Yhat = [None for iS in range(nS)]
        Xhat = [None for iS in range(nS)]
        Ypc_list = [None for iS in range(nS)]
        Xpc_list = [None for iS in range(nS)]
        raise NotImplementedError('the foldT=True branch has not been written yet')
    else:
        Yhat = [[None for iT in range(nT)] for iS in range(nS)]
        Xhat = [[None for iT in range(nT)] for iS in range(nS)]
        Ypc_list = [[None for iT in range(nT)] for iS in range(nS)]
        Xpc_list = [[None for iT in range(nT)] for iS in range(nS)]
        for iS in range(nS):
            mx = np.zeros((ncelltypes,))
            yy = [None for icelltype in range(ncelltypes)]
            for icelltype in range(ncelltypes):
                yy[icelltype] = np.concatenate(Rso_mean[icelltype][iS])
                mx[icelltype] = np.nanmax(yy[icelltype])
            for iT in range(nT):
                y = [Rso_mean[icelltype][iS][iT][:,np.newaxis]/mx[icelltype] for icelltype in range(1,ncelltypes)]
                Yhat[iS][iT] = np.concatenate(y,axis=1)
                Ypc_list[iS][iT] = [None for icelltype in range(1,ncelltypes)]
                for icelltype in range(1,ncelltypes):
                    Ypc_list[iS][iT][icelltype-1] = [(this_dim[0]/mx[icelltype],this_dim[1]) for this_dim in Rso_cov[icelltype][iS][iT]]
                icelltype = 0
                x = Rso_mean[icelltype][iS][iT][:,np.newaxis]/mx[icelltype]
                if fit_both_running:
                    run_vector = np.zeros_like(x)
                    if fit_both_stims:
                        run_vector[nsize*ncontrast:2*nsize*ncontrast] = 1
                        run_vector[-nstim_fg:] = 1
                    else:
                        run_vector[int(np.round(run_vector.shape[0]/2)):,:] = 1
                else:
                    run_vector = np.zeros((x.shape[0],0))
                Xhat[iS][iT] = np.concatenate((x,np.ones_like(x),run_vector),axis=1)
                Xpc_list[iS][iT] = [None for iinput in range(2+fit_both_running)]
                Xpc_list[iS][iT][0] = [(this_dim[0]/mx[icelltype],this_dim[1]) for this_dim in Rso_cov[icelltype][iS][iT]]
                Xpc_list[iS][iT][1] = [(0,np.zeros((Xhat[0][0].shape[0],))) for idim in range(ndims)]
                if fit_both_running:
                    Xpc_list[iS][iT][2] = [(0,np.zeros((Xhat[0][0].shape[0],))) for idim in range(ndims)]
    nN,nP = Xhat[0][0].shape
    nQ = Yhat[0][0].shape[1]
    
    # code for bounds: 0 , constrained to 0
    # +/-1 , constrained to +/-1
    # 1.5, constrained to [0,1]
    # -1.5, constrained to [-1,1]
    # 2 , constrained to [0,inf)
    # -2 , constrained to (-inf,0]
    # 3 , unconstrained
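    # Illustrative sketch (added; the actual mapping is implemented in
    # calnet.utils.set_bounds_by_code and may differ in detail) -- each bound code above
    # expands to a (lower, upper) pair roughly as:
    # {0: (0, 0), 1: (1, 1), -1: (-1, -1), 1.5: (0, 1), -1.5: (-1, 1),
    #  2: (0, np.inf), -2: (-np.inf, 0), 3: (-np.inf, np.inf)}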
    
    W0x_bounds = 3*np.ones((nP,nQ),dtype=int)
    W0x_bounds[0,:] = 2 # L4 PCs are excitatory
    W0x_bounds[0,1] = 0 # SSTs don't receive L4 input 
    
    if allow_var:
        if nondim:
            W1x_bounds = -1.5*np.ones(W0x_bounds.shape) #W0x_bounds.copy()*0 #np.zeros_like(W0x_bounds)
        else:
            W1x_bounds = 3*np.ones(W0x_bounds.shape) #W0x_bounds.copy()*0 #np.zeros_like(W0x_bounds)
        W1x_bounds[0,1] = 0
    else:
        W1x_bounds = np.zeros(W0x_bounds.shape) #W0x_bounds.copy()*0 #np.zeros_like(W0x_bounds)
    
    W0y_bounds = 3*np.ones((nQ,nQ),dtype=int)
    W0y_bounds[0,:] = 2 # PCs are excitatory
    W0y_bounds[1:,:] = -2 # all the cell types except PCs are inhibitory
    W0y_bounds[1,1] = 0 # SSTs don't inhibit themselves
    # W0y_bounds[3,1] = 0 # PVs are allowed to inhibit SSTs, consistent with Hillel's unpublished results, but not consistent with Pfeffer et al.
    W0y_bounds[2,0] = 0 # VIPs don't inhibit L2/3 PCs. According to Pfeffer et al., only L5 PCs were found to get VIP inhibition
    W0y_bounds[2,2] = 0 # newly added: no VIP-VIP inhibition 


    if constrain_wts is not None:
        for wt in constrain_wts:
            W0y_bounds[wt[0],wt[1]] = 0
            W1y_bounds[wt[0],wt[1]] = 0
    
    def tile_nS_nT_nN(kernel):
        row = np.concatenate([kernel for idim in range(nS*nT)],axis=0)[np.newaxis,:]
        tiled = np.concatenate([row for irow in range(nN)],axis=0)
        return tiled
    
    if fit_s02:
        s02_bounds = 2*np.ones((nQ,)) # permitting noise as a free parameter
    else:
        s02_bounds = np.ones((nQ,))
    
    Kin0_bounds = 1.5*np.ones((nQ,))
    
    kappa_bounds = np.ones((1,))
    # kappa_bounds = 2*np.ones((1,))
    
    Tin0_bounds = 1.5*np.ones((nQ,))
    #T_bounds[2:4] = 1 # PV and VIP are constrained to have flat ori tuning
    #Tin0_bounds[1:4] = 1 # SST,VIP, and PV are constrained to have flat ori tuning

    if nondim:
        kt_factor = -1.5
    else:
        kt_factor = 3

    if allow_var:
        W1y_bounds = kt_factor*np.ones(W0y_bounds.shape) #W0y_bounds.copy()*0 #np.zeros_like(W0y_bounds)
        Kin1_bounds = kt_factor*np.ones(Kin0_bounds.shape) #W0y_bounds.copy()*0 #np.zeros_like(W0y_bounds)
        Tin1_bounds = kt_factor*np.ones(Tin0_bounds.shape) #W0y_bounds.copy()*0 #np.zeros_like(W0y_bounds)
        W1y_bounds[1,1] = 0
        #W1y_bounds[3,1] = 0 
        W1y_bounds[2,0] = 0
        W1y_bounds[2,2] = 0 # newly added: no VIP-VIP inhibition
    else:
        W1y_bounds = np.zeros(W0y_bounds.shape) #W0y_bounds.copy()*0 #np.zeros_like(W0y_bounds)
        Kin1_bounds = 0*np.ones(Kin0_bounds.shape) #W0y_bounds.copy()*0 #np.zeros_like(W0y_bounds)
        Tin1_bounds = 0*np.ones(Tin0_bounds.shape) #W0y_bounds.copy()*0 #np.zeros_like(W0y_bounds)

    if multiout:
        W2x_bounds = W1x_bounds.copy()
        W2y_bounds = W1y_bounds.copy()
        if multiout2:
            W3x_bounds = W1x_bounds.copy()
            W3y_bounds = W1y_bounds.copy()
        else:
            W3x_bounds = W1x_bounds.copy()*0
            W3y_bounds = W1y_bounds.copy()*0
    else:
        W2x_bounds = W1x_bounds.copy()*0
        W2y_bounds = W1y_bounds.copy()*0
        W3x_bounds = W1x_bounds.copy()*0
        W3y_bounds = W1y_bounds.copy()*0

    Kxout0_bounds = np.array((1.5,)+tuple(np.zeros((nP-1,))))
    Txout0_bounds = Kxout0_bounds.copy()
    Kxout1_bounds = np.array((kt_factor,)+tuple(np.zeros((nP-1,))))
    Txout1_bounds = Kxout1_bounds.copy() 

    Kyout0_bounds = Kin0_bounds.copy()
    Tyout0_bounds = Tin0_bounds.copy()
    Kyout1_bounds = Kin1_bounds.copy()
    Tyout1_bounds = Tin1_bounds.copy()
    
    if fit_both_running:
        to_tile = Xhat[0][0][:,1:]
        to_tile = np.concatenate((2*np.ones((to_tile.shape[0],1)),to_tile),axis=1)
        X_bounds = np.tile(to_tile,(1,nS*nT))
    else:
        X_bounds = tile_nS_nT_nN(np.array([2,1]))
    #print(X_bounds.shape)
    # X_bounds = np.array([np.array([2,1,2,1])]*nN)
    
    if fit_both_running:
        Xp_bounds = tile_nS_nT_nN(np.array([3,0,0])) # edited to set XXp to 0 for spont. term
    else:
        Xp_bounds = tile_nS_nT_nN(np.array([3,0])) # edited to set XXp to 0 for spont. term
    # Xp_bounds = np.array([np.array([3,1,3,1])]*nN)
    
    # Y_bounds = tile_nS_nT_nN(2*np.ones((nQ,)))
    # # Y_bounds = 2*np.ones((nN,nT*nS*nQ))
    
    Eta_bounds = tile_nS_nT_nN(3*np.ones((nQ,)))
    # Eta_bounds = 3*np.ones((nN,nT*nS*nQ))
    
    #if allow_var:
    #    Xi_bounds = tile_nS_nT_nN(3*np.ones((nQ,)))
    #else:
    #    Xi_bounds = tile_nS_nT_nN(np.zeros((nQ,)))
    Xi_bounds = tile_nS_nT_nN(3*np.ones((nQ,))) # temporarily allowing Xi even if W1 is not allowed

    # Xi_bounds = 3*np.ones((nN,nT*nS*nQ))
    
    h1_bounds = -2*np.ones((1,))

    h2_bounds = 2*np.ones((1,))

    bl_bounds = 3*np.ones((nQ,))

    if free_amplitude:
        amp_bounds = 2*np.ones((nT*nS*nQ,))
    else:
        amp_bounds = 1*np.ones((nT*nS*nQ,))
    
    # shapes = [(nP,nQ),(nQ,nQ),(nP,nQ),(nQ,nQ),(nQ,),(nQ,),(1,),(nN,nS*nP),(nN,nS*nQ),(nN,nS*nQ),(nN,nS*nQ)]
    #shapes = [(nP,nQ),(nQ,nQ),(nP,nQ),(nQ,nQ),(nP,nQ),(nQ,nQ),(nP,nQ),(nQ,nQ),(nQ,),(nQ*(nS-1),),(nQ*(nS-1),),(nQ*(nS-1),),(nQ*(nS-1),),(1,),(nQ*(nT-1),),(nQ*(nT-1),),(nQ*(nT-1),),(nQ*(nT-1),),(nN,nT*nS*nP),(nN,nT*nS*nP),(nN,nT*nS*nQ),(nN,nT*nS*nQ),(1,)]
    shapes1 = [(nP,nQ),(nQ,nQ),(nP,nQ),(nQ,nQ),(nP,nQ),(nQ,nQ),(nP,nQ),(nQ,nQ),(nQ,),(nQ*(nS-1),),(nQ*(nS-1),),(nP*(nS-1),),(nQ*(nS-1),),(nP*(nS-1),),(nQ*(nS-1),),(1,),(nQ*(nT-1),),(nQ*(nT-1),),(nP*(nT-1),),(nQ*(nT-1),),(nP*(nT-1),),(nQ*(nT-1),),(1,),(1,),(nQ,),(nT*nS*nQ,)]
    shapes2 = [(nN,nT*nS*nP),(nN,nT*nS*nP),(nN,nT*nS*nQ),(nN,nT*nS*nQ)]
    #         W0x,    W0y,    W1x,    W1y,    W2x,    W2y,    W3x,    W3y,    s02,  k,    kappa,T,   XX,            XXp,          Eta,          Xi
    
    #lb = [-np.inf*np.ones(shp) for shp in shapes]
    #ub = [np.inf*np.ones(shp) for shp in shapes]
    #bdlist = [W0x_bounds,W0y_bounds,W1x_bounds,W1y_bounds,W2x_bounds,W2y_bounds,W3x_bounds,W3y_bounds,s02_bounds,k0_bounds,k1_bounds,k2_bounds,k3_bounds,kappa_bounds,Tin0_bounds,Tin1_bounds,Tout0_bounds,Tout1_bounds,X_bounds,Xp_bounds,Eta_bounds,Xi_bounds,h_bounds]
    bd1list = [W0x_bounds,W0y_bounds,W1x_bounds,W1y_bounds,W2x_bounds,W2y_bounds,W3x_bounds,W3y_bounds,s02_bounds,Kin0_bounds,Kin1_bounds,Kxout0_bounds,Kyout0_bounds,Kxout1_bounds,Kyout1_bounds,kappa_bounds,Tin0_bounds,Tin1_bounds,Txout0_bounds,Tyout0_bounds,Txout1_bounds,Tyout1_bounds,h1_bounds,h2_bounds,bl_bounds,amp_bounds]
    bd2list = [X_bounds,Xp_bounds,Eta_bounds,Xi_bounds]

    lb1,ub1 = [[sgn*np.inf*np.ones(shp) for shp in shapes1] for sgn in [-1,1]]
    lb1,ub1 = calnet.utils.set_bounds_by_code(lb1,ub1,bd1list)
    lb2,ub2 = [[sgn*np.inf*np.ones(shp) for shp in shapes2] for sgn in [-1,1]]
    lb2,ub2 = calnet.utils.set_bounds_by_code(lb2,ub2,bd2list)

    lb1 = np.concatenate([a.flatten() for a in lb1])
    ub1 = np.concatenate([b.flatten() for b in ub1])
    lb2 = np.concatenate([a.flatten() for a in lb2])
    ub2 = np.concatenate([b.flatten() for b in ub2])
    bounds1 = [(a,b) for a,b in zip(lb1,ub1)]
    bounds2 = [(a,b) for a,b in zip(lb2,ub2)]
    
    
    def compute_f_(Eta,Xi,s02):
        return sim_utils.f_miller_troyer(Eta,Xi**2+np.concatenate([s02 for ipixel in range(nS*nT)]))
    def compute_fprime_m_(Eta,Xi,s02):
        return sim_utils.fprime_miller_troyer(Eta,Xi**2+np.concatenate([s02 for ipixel in range(nS*nT)]))*Xi
    def compute_fprime_s_(Eta,Xi,s02):
        s2 = Xi**2+np.concatenate((s02,s02),axis=0)
        return sim_utils.fprime_s_miller_troyer(Eta,s2)*(Xi/s2)
    def sorted_r_eigs(w):
        drW,prW = np.linalg.eig(w)
        srtinds = np.argsort(drW)
        return drW[srtinds],prW[:,srtinds]
    
    #0.W0x,1.W0y,2.W1x,3.W1y,4.W2x,5.W2y,6.W3x,7.W3y,8.s02,9.Kin0,10.Kin1,11.Kout0,12.Kout1,13.kappa,14.Tin0,15.Tin1,16.Txout0,Tyout0,17.Txout1,Tyout1,18.h1,19.h2,20.bl,21.amp
    #0.XX,1.XXp,2.Eta,3.Xi
    
    #shapes = [(nP,nQ),(nQ,nQ),(nP,nQ),(nQ,nQ),(nP,nQ),(nQ,nQ),(nP,nQ),(nQ,nQ),(nQ,),(nQ*(nS-1),),(nQ*(nS-1),),(nQ*(nS-1),),(nQ*(nS-1),),(1,),(nQ*(nT-1),),(nQ*(nT-1),),(nQ*(nT-1),),(nQ*(nT-1),),(nN,nT*nS*nP),(nN,nT*nS*nP),(nN,nT*nS*nQ),(nN,nT*nS*nQ),(1,)]
    #shapes1 = [(nP,nQ),(nQ,nQ),(nP,nQ),(nQ,nQ),(nP,nQ),(nQ,nQ),(nP,nQ),(nQ,nQ),(nQ,),(nQ*(nS-1),),(nQ*(nS-1),),(nQ*(nS-1),),(nQ*(nS-1),),(1,),(nQ*(nT-1),),(nQ*(nT-1),),(nQ*(nT-1),),(nQ*(nT-1),),(1,),(1,),(nQ,),(nT*nS*nQ,)]
    #shapes2 = [(nN,nT*nS*nP),(nN,nT*nS*nP),(nN,nT*nS*nQ),(nN,nT*nS*nQ)]
    
    import sim_utils

    YYhat = calnet.utils.flatten_nested_list_of_2d_arrays(Yhat)
    
    opto_dict = np.load(opto_silencing_data_file,allow_pickle=True)[()]

    Yhat_opto = opto_dict['Yhat_opto']
    Yhat_opto = np.ones((nN*2,nQ*nS*nT))
    #Yhat_opto = Yhat_opto.reshape((nN*2,-1))
    Yhat_opto[0::12] = np.nanmean(Yhat_opto[0::12],axis=0)[np.newaxis]
    Yhat_opto[1::12] = np.nanmean(Yhat_opto[1::12],axis=0)[np.newaxis]
    Yhat_opto = Yhat_opto/np.nanmax(Yhat_opto[0::2],0)[np.newaxis,:]
    #print(Yhat_opto.shape)
    h_opto = np.zeros((nN*2,))
    #h_opto = opto_dict['h_opto']
    #dYY1 = Yhat_opto[1::2]-Yhat_opto[0::2]

    YYhat_halo = Yhat_opto.reshape((nN,2,-1))
    opto_transform1 = calnet.utils.fit_opto_transform(YYhat_halo,norm01=norm_opto_transforms)

    if no_halo_res:
        opto_transform1.res[:,[0,2,3,4,6,7]] = 0

    dYY1 = opto_transform1.transform(YYhat) - opto_transform1.preprocess(YYhat)
    #print('delta bias: %f'%dXX1[:,1].mean())
    #YYhat_halo_sim = calnet.utils.simulate_opto_effect(YYhat,YYhat_halo)
    #dYY1 = YYhat_halo_sim[:,1,:] - YYhat_halo_sim[:,0,:]

    def overwrite_plus_n(arr,to_overwrite,n):
        arr[:,to_overwrite] = arr[:,int(to_overwrite+n)]
        return arr

    for to_overwrite in [1,2]:
        n = 4
        dYY1,opto_transform1.slope,opto_transform1.intercept,opto_transform1.res \
                = [overwrite_plus_n(x,to_overwrite,n) for x in \
                        [dYY1,opto_transform1.slope,opto_transform1.intercept,opto_transform1.res]]
    for to_overwrite in [7]:
        n = -4
        dYY1,opto_transform1.slope,opto_transform1.intercept,opto_transform1.res \
                = [overwrite_plus_n(x,to_overwrite,n) for x in \
                        [dYY1,opto_transform1.slope,opto_transform1.intercept,opto_transform1.res]]

    opto_dict = np.load(opto_activation_data_file,allow_pickle=True)[()]

    Yhat_opto = opto_dict['Yhat_opto']
    Yhat_opto = np.ones((nN*2,nQ*nS*nT))
    #Yhat_opto = Yhat_opto.reshape((nN*2,-1))
    Yhat_opto[0::12] = np.nanmean(Yhat_opto[0::12],axis=0)[np.newaxis]
    Yhat_opto[1::12] = np.nanmean(Yhat_opto[1::12],axis=0)[np.newaxis]
    Yhat_opto = Yhat_opto/Yhat_opto[0::2].max(0)[np.newaxis,:]
    #print(Yhat_opto.shape)
    h_opto = np.zeros((nN*2,))
    #h_opto = opto_dict['h_opto']
    #dYY2 = Yhat_opto[1::2]-Yhat_opto[0::2]

    YYhat_chrimson = Yhat_opto.reshape((nN,2,-1))
    opto_transform2 = calnet.utils.fit_opto_transform(YYhat_chrimson,norm01=norm_opto_transforms)

    dYY2 = opto_transform2.transform(YYhat) - opto_transform2.preprocess(YYhat)

    if ignore_halo_vip:
        dYY1[:,2::nQ] = np.nan

    dYY = np.concatenate((dYY1,dYY2),axis=0)
    
    from importlib import reload
    reload(calnet)
    reload(calnet.fitting_2step_spatial_feature_opto_multiout_axon_nonlinear)
    reload(sim_utils)
    wt_dict = {}
    wt_dict['X'] = 1
    wt_dict['Y'] = 3
    wt_dict['Eta'] = 3# 10
    wt_dict['Xi'] = 3
    wt_dict['stims'] = np.ones((nN,1)) #(np.arange(30)/30)[:,np.newaxis]**1 #
    wt_dict['barrier'] = 0. #30.0 #0.1
    wt_dict['opto'] = 0#1e0#1e-1#1e1
    wt_dict['smi'] = 0
    wt_dict['isn'] = 0.1
    wt_dict['tv'] = 1


    YYhat = calnet.utils.flatten_nested_list_of_2d_arrays(Yhat)
    XXhat = calnet.utils.flatten_nested_list_of_2d_arrays(Xhat)
    Eta0 = invert_f_mt(YYhat)
    Xi0 = invert_fprime_mt(Ypc_list,Eta0,nN=nN,nQ=nQ,nS=nS,nT=nT,foldT=foldT)

    ntries = 1
    nhyper = 1
    dt = 1e-1
    niter = int(np.round(10/dt)) #int(1e4)
    perturbation_size = 5e-2
    W1t = [[None for itry in range(ntries)] for ihyper in range(nhyper)]
    W2t = [[None for itry in range(ntries)] for ihyper in range(nhyper)]
    loss = np.zeros((nhyper,ntries))
    is_neg = np.array([b[1] for b in bounds1])==0
    counter = 0
    negatize = [np.zeros(shp,dtype='bool') for shp in shapes1]
    for ishp,shp in enumerate(shapes1):
        nel = np.prod(shp)
        negatize[ishp][:][is_neg[counter:counter+nel].reshape(shp)] = True
        counter = counter + nel
    for ihyper in range(nhyper):
        for itry in range(ntries):
            print((ihyper,itry))
            W10list = [init_noise*(ihyper+1)*np.random.rand(*shp) for shp in shapes1]
            W20list = [init_noise*(ihyper+1)*np.random.rand(*shp) for shp in shapes2]
            counter = 0
            for ishp,shp in enumerate(shapes1):
                W10list[ishp][negatize[ishp]] = -W10list[ishp][negatize[ishp]]
            nextraW = 4
            nextraK = nextraW + 3
            nextraT = nextraK + 3
            #Wstar_dict['as_list'] = [W0x,W0y,W1x,W1y,W2x,W2y,W3x,W3y,s02,Kin0,Kin1,Kxout0,Kyout0,Kxout1,Kyout1,kappa,Tin0,Tin1,Txout0,Tyout0,Txout1,Tyout1,XX,XXp,Eta,Xi,h1,h2,bl,amp]#,h2
            init_val = [1,1,0,0,0,0,0,0,1,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,1]
            W10list = [iv*np.ones(shp) for iv,shp in zip(init_val,shapes1)]
            #W10list[nextraW+4] = np.ones(shapes1[nextraW+4]) # s02
            #W10list[nextraW+5] = np.ones(shapes1[nextraW+5]) # K
            #W10list[nextraW+6] = np.ones(shapes1[nextraW+6]) # K
            #W10list[nextraW+7] = np.zeros(shapes1[nextraW+7]) # K
            #W10list[nextraW+8] = np.zeros(shapes1[nextraW+8]) # K
            #W10list[nextraK+6] = np.ones(shapes1[nextraK+6]) # kappa
            #W10list[nextraK+7] = np.ones(shapes1[nextraK+7]) # T
            #W10list[nextraK+8] = np.ones(shapes1[nextraK+8]) # T
            #W10list[nextraK+9] = np.zeros(shapes1[nextraK+9]) # T
            #W10list[nextraK+10] = np.zeros(shapes1[nextraK+10]) # T
            W20list[0] = XXhat #np.concatenate(Xhat,axis=1) #XX
            W20list[1] = get_pc_dim(Xpc_list,nN=nN,nPQ=nP,nS=nS,nT=nT,idim=0,foldT=foldT) #XXp
            W20list[2] = Eta0 #np.zeros(shapes[nextraT+10]) #Eta
            W20list[3] = Xi0 #Xi
            #print(XXhat.shape)
            isn_init = np.array(((3,5),(-5,-5)))
            freeze_vals = None # defined up front so the np.save call below works even when init_W_from_lsq is False
            if init_W_from_lsq:
                # shapes1
                #0.W0x,1.W0y,2.W1x,3.W1y,4.W2x,5.W2y,6.W3x,7.W3y,8.s02,9.Kin0,10.Kin1,11.Kout0,12.Kout1,13.kappa,14.Tin0,15.Tin1,16.Txout0,Tyout0,17.Txout1,Tyout1,18.h1,19.h2,20.bl,21.amp
                # shapes2
                #0.XX,1.XXp,2.Eta,3.Xi
                #W0x,W0y,W1x,W1y,W2x,W2y,W3x,W3y,Kin0,Kin1,Tin0,Tin1 = initialize_Ws(Xhat,Yhat,Xpc_list,Ypc_list,scale_by=1)
                nvar,nxy = 4,2
                freeze_vals = [[None for _ in range(nxy)] for _ in range(nvar)]
                lams = 1e5*np.array((0,1,1,1,0,1,0,1))
                for ivar in range(nvar):
                    for ixy in range(nxy):
                        iflat = np.ravel_multi_index((ivar,ixy),(nvar,nxy))
                        freeze_vals[ivar][ixy] = np.zeros(bd1list[iflat].shape)
                        freeze_vals[ivar][ixy][bd1list[iflat]==0] = np.nan
                if constrain_isn:
                    freeze_vals[0][1][slice(0,None,3)][:,slice(0,None,3)] = isn_init
                #Wlist = [W0x,W0y,W1x,W1y,W2x,W2y,W3x,W3y,Kin0,Kin1,Kxout0,Kyout0,Kxout1,Kyout1,Tin0,Tin1,Txout0,Tyout0,Txout1,Tyout1]
                # W1list = [W0x,W0y,W1x,W1y,W2x,W2y,W3x,W3y,s02,Kin0,Kin1,Kxout0,Kyout0,Kxout1,Kyout1,kappa,Tin0,Tin1,Txout0,Tyout0,Txout1,Tyout1,h1,h2,bl,amp]#,h2
                # W0,W1,W2,W3,Kin0,Kin1,Tin0,Tin1
                thisWlist = initialize_Ws(Xhat,Yhat,Xpc_list,Ypc_list,scale_by=1,freeze_vals=freeze_vals,lams=lams,foldT=foldT)
                Winds = [0,1,2,3,4,5,6,7,9,10,11,12,13,14,16,17,18,19,20,21]
                for ivar,Wind in enumerate(Winds):
                    W10list[Wind] = thisWlist[ivar]
                #W10list[0],W10list[1] = initialize_W(Xhat,Yhat,scale_by=scale_init_by)
                for Wind in Winds:
                    W10list[Wind] = W10list[Wind] + init_noise*np.random.randn(*W10list[Wind].shape)
            else:
                if constrain_isn:
                    W10list[1][slice(0,None,3)][:,slice(0,None,3)] = isn_init
                    #W10list[1][0,0] = 3 
                    #W10list[1][0,3] = 5 
                    #W10list[1][3,0] = -5
                    #W10list[1][3,3] = -5
            np.save('/home/dan/calnet_data/W0list.npy',{'W10list':W10list,'W20list':W20list,'bd1list':bd1list,'bd2list':bd2list,'freeze_vals':freeze_vals,'bounds1':bounds1,'bounds2':bounds2},allow_pickle=True)

            if init_W_from_file:
                # did not adjust this yet
                npyfile = np.load(init_file,allow_pickle=True)[()]
                print(len(npyfile['as_list']))
                print([w.shape for w in npyfile['as_list']])
                W10list = [npyfile['as_list'][ivar] for ivar in [0,1,2,3,4,5,6,7,12]]
                W20list = [npyfile['as_list'][ivar] for ivar in [8,9,10,11]]
                if correct_Eta:
                    #assert(True==False)
                    W20list[2] = Eta0.copy()
                if len(W10list) < len(shapes1):
                    #assert(True==False)
                    W10list = W10list + [np.array(1),np.zeros((nQ,)),np.zeros((nT*nS*nQ,))] # add bl, amp #np.array(1), #h2, 
                #W10 = unparse_W(W10list)
                #W20 = unparse_W(W20list)
                opt = fmc.gen_opt()
                #resEta0,resXi0 = fmc.compute_res(W10,W20,opt)
                if init_W1xy_with_res:
                    W1x0,W1y0,Kin10,Tin10 = optimize_W1xy(W10list,W20list,opt)
                    W10list[2] = W1x0
                    W10list[3] = W1y0
                    W10list[10] = Kin10
                    W10list[15] = Tin10
                if init_W2xy_with_res:
                    W2x0,W2y0 = optimize_W2xy(W10list,W20list,opt)
                    W10list[4] = W2x0
                    W10list[5] = W2y0
                if init_Eta_with_s02:
                    #assert(True==False)
                    s02 = W10list[4].copy()
                    Eta0 = invert_f_mt_with_s02(YYhat,s02,nS=nS,nT=nT)
                    W20list[2] = Eta0.copy()
                for ivar in [0,1,4,5]: # Wmx, Wmy, s02, k
                    print(init_noise)
                    W10list[ivar] = W10list[ivar] + init_noise*np.random.randn(*W10list[ivar].shape)
                #W0list = npyfile['as_list']

                extra_Ws = [np.zeros_like(W10list[ivar]) for ivar in range(2)]
                extra_ks = [np.zeros_like(W10list[5]) for ivar in range(3)]
                extra_Ts = [np.zeros_like(W10list[7]) for ivar in range(3)]
                W10list = W10list[:4] + extra_Ws*2 + W10list[4:6] + extra_ks + W10list[6:8] + extra_Ts + W10list[8:]

            print(len(W10list))
            W1t[ihyper][itry],W2t[ihyper][itry],loss[ihyper][itry],gr,hess,result = \
                calnet.fitting_2step_spatial_feature_opto_multiout_axon_nonlinear.fit_W_sim(
                    Xhat,Xpc_list,Yhat,Ypc_list,
                    pop_rate_fn=sim_utils.f_miller_troyer,
                    pop_deriv_fn=sim_utils.fprime_miller_troyer,
                    neuron_rate_fn=sim_utils.evaluate_f_mt,
                    W10list=W10list.copy(),W20list=W20list.copy(),
                    bounds1=bounds1,bounds2=bounds2,
                    niter=niter,wt_dict=wt_dict,
                    l2_penalty=l2_penalty,l1_penalty=l1_penalty,
                    compute_hessian=False,dt=dt,
                    perturbation_size=perturbation_size,dYY=dYY,
                    constrain_isn=constrain_isn,tv=tv,foldT=foldT,
                    use_opto_transforms=use_opto_transforms,
                    opto_transform1=opto_transform1,opto_transform2=opto_transform2,
                    nondim=nondim)
    
    #def parse_W(W):
    #    W0x,W0y,W1x,W1y,W2x,W2y,W3x,W3y,s02,Kin0,Kin1,Kout0,Kout1,kappa,Tin0,Tin1,Txout0,Tyout0,Txout1,Tyout1,XX,XXp,Eta,Xi,h = W
    #    return W0x,W0y,W1x,W1y,W2x,W2y,W3x,W3y,s02,Kin0,Kin1,Kout0,Kout1,kappa,Tin0,Tin1,Txout0,Tyout0,Txout1,Tyout1,XX,XXp,Eta,Xi,h
    def parse_W1(W):
        W0x,W0y,W1x,W1y,W2x,W2y,W3x,W3y,s02,Kin0,Kin1,Kxout0,Kyout0,Kxout1,Kyout1,kappa,Tin0,Tin1,Txout0,Tyout0,Txout1,Tyout1,h1,h2,bl,amp = W #h2,
        return W0x,W0y,W1x,W1y,W2x,W2y,W3x,W3y,s02,Kin0,Kin1,Kxout0,Kyout0,Kxout1,Kyout1,kappa,Tin0,Tin1,Txout0,Tyout0,Txout1,Tyout1,h1,h2,bl,amp #h2,
    def parse_W2(W):
        XX,XXp,Eta,Xi = W
        return XX,XXp,Eta,Xi    

    def unparse_W(Ws):
        return np.concatenate([ww.flatten() for ww in Ws])
    
    itry = 0
    W0x,W0y,W1x,W1y,W2x,W2y,W3x,W3y,s02,Kin0,Kin1,Kxout0,Kyout0,Kxout1,Kyout1,kappa,Tin0,Tin1,Txout0,Tyout0,Txout1,Tyout1,h1,h2,bl,amp = parse_W1(W1t[0][0])#h2,
    XX,XXp,Eta,Xi = parse_W2(W2t[0][0])
    
    #labels = ['W0x','W0y','W1x','W1y','W2x','W2y','W3x','W3y','s02','Kin0','Kin1','Kout0','Kout1','kappa','Tin0','Tin1','Tout0','Tout1','XX','XXp','Eta','Xi','h']
    labels1 = ['W0x','W0y','W1x','W1y','W2x','W2y','W3x','W3y','s02','Kin0','Kin1','Kxout0','Kyout0','Kxout1','Kyout1','kappa','Tin0','Tin1','Txout0','Tyout0','Txout1','Tyout1','h1','h2','bl','amp']#,'h2'
    labels2 = ['XX','XXp','Eta','Xi']
    Wstar_dict = {}
    for i,label in enumerate(labels1):
        Wstar_dict[label] = W1t[0][0][i]
    for i,label in enumerate(labels2):
        Wstar_dict[label] = W2t[0][0][i]
    #Wstar_dict = {}
    #for i,label in enumerate(labels):
    #    Wstar_dict[label] = W1t[0][0][i]
    Wstar_dict['as_list'] = [W0x,W0y,W1x,W1y,W2x,W2y,W3x,W3y,s02,Kin0,Kin1,Kxout0,Kyout0,Kxout1,Kyout1,kappa,Tin0,Tin1,Txout0,Tyout0,Txout1,Tyout1,XX,XXp,Eta,Xi,h1,h2,bl,amp]#,h2
    Wstar_dict['loss'] = loss[0][0]
    Wstar_dict['wt_dict'] = wt_dict
    np.save(weights_file,Wstar_dict,allow_pickle=True)
Example #21
0
File: emissions.py Project: zshwuhan/ssm
 def log_likelihoods(self, data, input, mask, tag, x):
     assert data.dtype == int and data.min() >= 0 and data.max() <= 1
     ps = self.mean(self.compute_mus(x))
     mask = np.ones_like(data, dtype=bool) if mask is None else mask
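     # Bernoulli log-likelihood per time bin and latent state; ps broadcasts over states via
     # data[:, None, :], and the mask zeroes out missing observation dims before summing.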
     lls = data[:, None, :] * np.log(ps) + (1 - data[:, None, :]) * np.log(1 - ps)
     return np.sum(lls * mask[:, None, :], axis=2)
Example #22
0
 def fun(x):
     return rewrites.constant_folding_einsum(',i,i,j->', np.sum(x),
                                             np.ones_like(x), x, ones)
Example #23
0
File: emissions.py Project: zshwuhan/ssm
 def log_likelihoods(self, data, input, mask, tag, x):
     assert data.dtype == int
     lambdas = self.mean(self.compute_mus(x))
     mask = np.ones_like(data, dtype=bool) if mask is None else mask
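     # Poisson log-likelihood per time bin and state: -log(y!) - lambda + y*log(lambda),
     # summed over the masked observation dims.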
     lls = -gammaln(data[:,None,:] + 1) -lambdas + data[:,None,:] * np.log(lambdas)
     return np.sum(lls * mask[:, None, :], axis=2)
Example #24
0
def run_fitting(wavfile):
    print('running trace3d fit...')
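    # Pipeline sketch: read the wav, compute a spectrogram, locate the loudest window,
    # then fit per-harmonic Gaussian traces (amplitude, center frequency) to that window
    # by penalized least squares with TV-L1, TV-L2, and L1 regularizers.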
    fs, data = siw.read(wavfile)
    freq, t, spec = ssi.spectrogram(data)
    nt = t.shape[0]
    spec = spec / np.max(np.mean(spec, axis=0))
    minpad = 1000
    maxlen = 2000
    #npad = minpad + np.argmax(ssi.convolve(np.sum(spec[:,minpad:],0),np.ones((maxlen,)),mode='valid'))
    #endat = npad+maxlen#np.minimum(npad+maxlen,nt-npad)
    npad, endat = find_loud_window(spec, minpad=minpad, maxlen=maxlen)
    rg = slice(npad, endat)
    ntpad = endat - npad
    freq_min = 5

    def trace3d_to_spec(freq, t, trace, fsigma, fudge=1e-6):
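        # Each harmonic is modeled as a Gaussian in frequency with time-varying amplitude
        # amp(t) and center fmu(t) and a per-harmonic width fsigma; the harmonics are
        # summed to reconstruct the spectrogram.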
        amp = trace[:, 0, :]
        fmu = trace[:, 1, :]
        #         fsigma = trace[:,2,:]
        dfreq = freq[np.newaxis, :, np.newaxis] - fmu[:, np.newaxis, :]
        fs2 = fsigma[:, np.newaxis, np.newaxis]**2 + fudge
        #         print((dfreq.max(),fmu.mean(),fsigma.mean()))
        return np.sum(amp[:, np.newaxis, :] * np.exp(-0.5 * dfreq**2 / fs2) /
                      np.sqrt(2 * np.pi * fs2),
                      axis=0)

    def parse_trace(trace):
        this_trace = np.reshape(trace[:-nharmonics], (nharmonics, ndim, ntpad))
        fsigma = trace[-nharmonics:]
        return this_trace, fsigma

    def cost_lsq(trace, fsigma):
        spec_modeled = trace3d_to_spec(freq, t[rg], trace, fsigma)
        return np.sum(
            (spec[freq_min:, npad:endat] - spec_modeled[freq_min:])**2)

    def cost_tv_l1(trace):
        return np.sum(np.abs(np.diff(trace, axis=2)), axis=2)

    def cost_tv_l2(trace):
        return np.sum(np.abs(np.diff(trace, axis=2))**2, axis=2)

    def cost_l1(trace):
        return np.sum(np.abs(trace), axis=2)

    # lam_tv_l1 = np.array((1,0,0))
    # lam_tv_l2 = np.array((0,10000,10000))
    # lam_l1 = np.array((0.1,0,0))
    lam_tv_l1 = np.array((1e2, 0))
    lam_tv_l2 = np.array((0, 3e5))
    lam_l1 = np.array((1e2, 0))

    def cost_total(trace):
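        # Total objective: spectrogram reconstruction error plus weighted TV-L1, TV-L2,
        # and L1 penalties on the (amplitude, center-frequency) traces.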
        this_trace, fsigma = parse_trace(trace)
        tv_l1_term = np.sum(lam_tv_l1[np.newaxis, :] * cost_tv_l1(this_trace))
        tv_l2_term = np.sum(lam_tv_l2[np.newaxis, :] * cost_tv_l2(this_trace))
        l1_term = np.sum(lam_l1[np.newaxis, :] * cost_l1(this_trace))
        lsq_term = cost_lsq(this_trace, fsigma)
        #print((lsq_term,tv_l1_term,tv_l2_term,l1_term))
        return lsq_term + tv_l1_term + tv_l2_term + l1_term

    nharmonics = 1
    ndim = 2
    trace0 = np.zeros((nharmonics, ndim, ntpad))
    fsigma0 = 1 * np.mean(np.diff(freq)) * np.ones((nharmonics, ))
    fsigma0[-1] = 2 * fsigma0[-1]
    # trace0[0,2,:] = fsigma0
    trace0[:, 0, :] = (np.std(spec[:, rg], axis=0) /
                       np.max(np.std(spec[:, rg], axis=0)) *
                       np.max(spec[:, rg]))[np.newaxis, :] * np.sqrt(
                           2 * np.pi * fsigma0[:, np.newaxis]**2)
    trace0[:, 1, :] = (np.sum(freq[:, np.newaxis] * spec[:, rg], axis=0) /
                       np.sum(spec[:, rg], axis=0))[np.newaxis]
    trace0 = np.concatenate((trace0.flatten(), fsigma0))
    bds = sop.Bounds(lb=np.zeros_like(trace0),
                     ub=np.inf * np.ones_like(trace0))
    # start_time = timeit.default_timer()
    res = sop.minimize(cost_total, trace0, jac=grad(cost_total), bounds=bds)
    trace, fsigma = parse_trace(res.x)
    return trace, fsigma
Example #25
0
    def fit(self,
            frequency,
            monetary_value,
            weights=None,
            initial_params=None,
            verbose=False,
            tol=1e-7,
            index=None,
            q_constraint=False,
            **kwargs):
        """
        Fit the data to the Gamma/Gamma model.

        Parameters
        ----------
        frequency: array_like
            the frequency vector of customers' purchases
            (denoted x in literature).
        monetary_value: array_like
            the monetary value vector of customers' purchases
            (denoted m in literature).
        weights: None or array_like
            Number of customers with given frequency/monetary_value,
            defaults to 1 if not specified. Fader and
            Hardie condense the individual RFM matrix into all
            observed combinations of frequency/monetary_value. This
            parameter represents the count of customers with a given
            purchase pattern. Instead of calculating individual
            loglikelihood, the loglikelihood is calculated for each
            pattern and multiplied by the number of customers with
            that pattern.
        initial_params: array_like, optional
            set the initial parameters for the fitter.
        verbose : bool, optional
            set to true to print out convergence diagnostics.
        tol : float, optional
            tolerance for termination of the function minimization process.
        index: array_like, optional
            index for the resulting DataFrame, which is accessible via self.data
        q_constraint: bool, optional
            when q < 1, the population mean is negative, leading to negative
            CLV outputs. If True, negative values of q are penalized to avoid this issue.
        kwargs:
            keyword arguments to pass to the scipy.optimize.minimize
            function as an options dict

        Returns
        -------
        GammaGammaFitter
            the fitted model, with parameters estimated

        """
        _check_inputs(frequency, monetary_value=monetary_value)

        frequency = np.asarray(frequency).astype(float)
        monetary_value = np.asarray(monetary_value).astype(float)

        if weights is None:
            weights = np.ones_like(frequency, dtype=int)
        else:
            weights = np.asarray(weights)

        log_params, self._negative_log_likelihood_, self._hessian_ = self._fit(
            (frequency, monetary_value, weights, self.penalizer_coef),
            initial_params,
            3,
            verbose,
            tol=tol,
            bounds=((None, None), (0, None), (None,
                                              None)) if q_constraint else None,
            **kwargs)

        self.data = DataFrame(
            {
                "monetary_value": monetary_value,
                "frequency": frequency,
                "weights": weights
            },
            index=index)

        self.params_ = pd.Series(np.exp(log_params), index=["p", "q", "v"])

        self.variance_matrix_ = self._compute_variance_matrix()
        self.standard_errors_ = self._compute_standard_errors()
        self.confidence_intervals_ = self._compute_confidence_intervals()

        return self
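
# A minimal usage sketch, assuming the GammaGammaFitter class whose fit method is shown
# above is in scope (the summary data and penalizer value below are illustrative, not
# from the source; the constructor argument is inferred from self.penalizer_coef above):
import pandas as pd

summary = pd.DataFrame({
    "frequency": [2, 5, 1, 8],
    "monetary_value": [25.0, 40.5, 12.0, 33.3],
})

ggf = GammaGammaFitter(penalizer_coef=0.01)
ggf.fit(summary["frequency"], summary["monetary_value"], verbose=True)
print(ggf.params_)  # fitted Series with parameters p, q, v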
Example #26
0
def fit_weights_and_save(weights_file,ca_data_file='rs_vm_denoise_200605.npy',opto_silencing_data_file='vip_halo_data_for_sim.npy',opto_activation_data_file='vip_chrimson_data_for_sim.npy',constrain_wts=None,allow_var=True,fit_s02=True,constrain_isn=True,l2_penalty=0.01,init_noise=0.1,init_W_from_lsq=False,scale_init_by=1,init_W_from_file=False,init_file=None,correct_Eta=False):
    
    
    nsize,ncontrast = 6,6
    
    npfile = np.load(ca_data_file,allow_pickle=True)[()]#,{'rs':rs,'rs_denoise':rs_denoise},allow_pickle=True)
    rs = npfile['rs']
    #rs_denoise = npfile['rs_denoise']
    
    nsize,ncontrast,ndir = 6,6,8
    ori_dirs = [[0,4],[2,6]] #[[0,4],[1,3,5,7],[2,6]]
    nT = len(ori_dirs)
    nS = len(rs[0])
    
    def sum_to_1(r):
        R = r.reshape((r.shape[0],-1))
        #R = R/np.nansum(R[:,~np.isnan(R.sum(0))],axis=1)[:,np.newaxis]
        R = R/np.nansum(R,axis=1)[:,np.newaxis] # changed 8/28
        return R
    
    def norm_to_mean(r):
        R = r.reshape((r.shape[0],-1))
        R = R/np.nanmean(R[:,~np.isnan(R.sum(0))],axis=1)[:,np.newaxis]
        return R
    
    Rs = [[None,None] for i in range(len(rs))]
    Rso = [[[None for iT in range(nT)] for iS in range(nS)] for icelltype in range(len(rs))]
    rso = [[[None for iT in range(nT)] for iS in range(nS)] for icelltype in range(len(rs))]
    
    for iR,r in enumerate(rs):#rs_denoise):
        print(iR)
        for ialign in range(nS):
            #Rs[iR][ialign] = r[ialign][:,:nsize,:]
            #sm = np.nanmean(np.nansum(np.nansum(Rs[iR][ialign],1),1))
            #Rs[iR][ialign] = Rs[iR][ialign]/sm
            Rs[iR][ialign] = sum_to_1(r[ialign][:,:nsize,:])
    #         Rs[iR][ialign] = von_mises_denoise(Rs[iR][ialign].reshape((-1,nsize,ncontrast,ndir)))
    
    kernel = np.ones((1,2,2))
    kernel = kernel/kernel.sum()
    
    for iR,r in enumerate(rs):
        for ialign in range(nS):
            for iori in range(nT):
                Rso[iR][ialign][iori] = np.nanmean(Rs[iR][ialign].reshape((-1,nsize,ncontrast,ndir))[:,:,:,ori_dirs[iori]],-1)
                Rso[iR][ialign][iori][:,:,0] = np.nanmean(Rso[iR][ialign][iori][:,:,0],1)[:,np.newaxis] # average 0 contrast values
                Rso[iR][ialign][iori][:,1:,1:] = ssi.convolve(Rso[iR][ialign][iori],kernel,'valid')
                Rso[iR][ialign][iori] = Rso[iR][ialign][iori].reshape(Rso[iR][ialign][iori].shape[0],-1)
                #Rso[iR][ialign][iori] = Rso[iR][ialign][iori]/np.nanmean(Rso[iR][ialign][iori],-1)[:,np.newaxis]
    
    def set_bound(bd,code,val=0):
        # set bounds to 0 where 0s occur in 'code'
        for iitem in range(len(bd)):
            bd[iitem][code[iitem]] = val
    
    nN = 36
    nS = 2
    nP = 2
    nT = 2
    nQ = 4
    
    # code for bounds: 0 , constrained to 0
    # +/-1 , constrained to +/-1
    # 1.5, constrained to [0,1]
    # 2 , constrained to [0,inf)
    # -2 , constrained to (-inf,0]
    # 3 , unconstrained
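    # e.g., a code of 2 maps to (lb, ub) = (0, inf) below, 1.5 maps to (0, 1), and -2 maps to
    # (-inf, 0]; set_bound writes these values into lb/ub wherever bdlist matches the code.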
    
    Wmx_bounds = 3*np.ones((nP,nQ),dtype=int)
    Wmx_bounds[0,1] = 0 # SSTs don't receive L4 input
    
    if allow_var:
        Wsx_bounds = 3*np.ones(Wmx_bounds.shape) #Wmx_bounds.copy()*0 #np.zeros_like(Wmx_bounds)
        Wsx_bounds[0,1] = 0
    else:
        Wsx_bounds = np.zeros(Wmx_bounds.shape) #Wmx_bounds.copy()*0 #np.zeros_like(Wmx_bounds)
    
    Wmy_bounds = 3*np.ones((nQ,nQ),dtype=int)
    Wmy_bounds[0,:] = 2 # PCs are excitatory
    Wmy_bounds[1:,:] = -2 # all the cell types except PCs are inhibitory
    Wmy_bounds[1,1] = 0 # SSTs don't inhibit themselves
    # Wmy_bounds[3,1] = 0 # PVs are allowed to inhibit SSTs, consistent with Hillel's unpublished results, but not consistent with Pfeffer et al.
    Wmy_bounds[2,0] = 0 # VIPs don't inhibit L2/3 PCs. According to Pfeffer et al., only L5 PCs were found to get VIP inhibition

    if allow_var:
        Wsy_bounds = 3*np.ones(Wmy_bounds.shape) #Wmy_bounds.copy()*0 #np.zeros_like(Wmy_bounds)
        Wsy_bounds[1,1] = 0
        Wsy_bounds[3,1] = 0 
        Wsy_bounds[2,0] = 0
    else:
        Wsy_bounds = np.zeros(Wmy_bounds.shape) #Wmy_bounds.copy()*0 #np.zeros_like(Wmy_bounds)

    if constrain_wts is not None:
        for wt in constrain_wts:
            Wmy_bounds[wt[0],wt[1]] = 0
            Wsy_bounds[wt[0],wt[1]] = 0
    
    def tile_nS_nT_nN(kernel):
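        # repeat `kernel` nS*nT times to form one stimulus row, then stack nN identical rows,
        # giving an (nN, nS*nT*len(kernel)) array of per-entry bound codes.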
        row = np.concatenate([kernel for idim in range(nS*nT)],axis=0)[np.newaxis,:]
        tiled = np.concatenate([row for irow in range(nN)],axis=0)
        return tiled
    
    if fit_s02:
        s02_bounds = 2*np.ones((nQ,)) # permitting noise as a free parameter
    else:
        s02_bounds = np.ones((nQ,))
    
    k_bounds = 1.5*np.ones((nQ,))
    
    kappa_bounds = np.ones((1,))
    # kappa_bounds = 2*np.ones((1,))
    
    T_bounds = 1.5*np.ones((nQ,))
    
    X_bounds = tile_nS_nT_nN(np.array([2,1]))
    # X_bounds = np.array([np.array([2,1,2,1])]*nN)
    
    Xp_bounds = tile_nS_nT_nN(np.array([3,1]))
    # Xp_bounds = np.array([np.array([3,1,3,1])]*nN)
    
    # Y_bounds = tile_nS_nT_nN(2*np.ones((nQ,)))
    # # Y_bounds = 2*np.ones((nN,nT*nS*nQ))
    
    Eta_bounds = tile_nS_nT_nN(3*np.ones((nQ,)))
    # Eta_bounds = 3*np.ones((nN,nT*nS*nQ))
    
    if allow_var:
        Xi_bounds = tile_nS_nT_nN(3*np.ones((nQ,)))
    else:
        Xi_bounds = tile_nS_nT_nN(np.zeros((nQ,)))

    # Xi_bounds = 3*np.ones((nN,nT*nS*nQ))
    
    h1_bounds = -2*np.ones((1,))
    
    h2_bounds = 2*np.ones((1,))
    
    # shapes = [(nP,nQ),(nQ,nQ),(nP,nQ),(nQ,nQ),(nQ,),(nQ,),(1,),(nN,nS*nP),(nN,nS*nQ),(nN,nS*nQ),(nN,nS*nQ)]
    shapes = [(nP,nQ),(nQ,nQ),(nP,nQ),(nQ,nQ),(nQ,),(nQ,),(1,),(nQ,),(nN,nT*nS*nP),(nN,nT*nS*nP),(nN,nT*nS*nQ),(nN,nT*nS*nQ),(1,),(1,),(nN,nT*nS*nQ),(nN,nT*nS*nQ)]
    print('size of shapes: '+str(np.sum([np.prod(shp) for shp in shapes])))
    #         Wmx,    Wmy,    Wsx,    Wsy,    s02,  k,    kappa,T,   XX,            XXp,          Eta,          Xi, h1, h2, Eta1,   Eta2
    
    lb = [-np.inf*np.ones(shp) for shp in shapes]
    ub = [np.inf*np.ones(shp) for shp in shapes]
    bdlist = [Wmx_bounds,Wmy_bounds,Wsx_bounds,Wsy_bounds,s02_bounds,k_bounds,kappa_bounds,T_bounds,X_bounds,Xp_bounds,Eta_bounds,Xi_bounds,h1_bounds,h2_bounds,Eta_bounds,Eta_bounds]
    
    set_bound(lb,[bd==0 for bd in bdlist],val=0)
    set_bound(ub,[bd==0 for bd in bdlist],val=0)
    
    set_bound(lb,[bd==2 for bd in bdlist],val=0)
    
    set_bound(ub,[bd==-2 for bd in bdlist],val=0)
    
    set_bound(lb,[bd==1 for bd in bdlist],val=1)
    set_bound(ub,[bd==1 for bd in bdlist],val=1)
    
    set_bound(lb,[bd==1.5 for bd in bdlist],val=0)
    set_bound(ub,[bd==1.5 for bd in bdlist],val=1)
    
    set_bound(lb,[bd==-1 for bd in bdlist],val=-1)
    set_bound(ub,[bd==-1 for bd in bdlist],val=-1)
    
    # for bd in [lb,ub]:
    #     for ind in [2,3]:
    #         bd[ind][:,1] = 0
    
    # temporary for no variation expt.
    # lb[2] = np.zeros_like(lb[2])
    # lb[3] = np.zeros_like(lb[3])
    # lb[4] = np.ones_like(lb[4])
    # lb[5] = np.zeros_like(lb[5])
    # ub[2] = np.zeros_like(ub[2])
    # ub[3] = np.zeros_like(ub[3])
    # ub[4] = np.ones_like(ub[4])
    # ub[5] = np.ones_like(ub[5])
    # temporary for no variation expt.
    lb = np.concatenate([a.flatten() for a in lb])
    ub = np.concatenate([b.flatten() for b in ub])
    bounds = [(a,b) for a,b in zip(lb,ub)]
    
    nS = 2
    ndims = 5
    ncelltypes = 5
    Yhat = [[None for iT in range(nT)] for iS in range(nS)]
    Xhat = [[None for iT in range(nT)] for iS in range(nS)]
    Ypc_list = [[None for iT in range(nT)] for iS in range(nS)]
    Xpc_list = [[None for iT in range(nT)] for iS in range(nS)]
    mx = [None for iS in range(nS)]
    for iS in range(nS):
        mx[iS] = np.zeros((ncelltypes,))
        yy = [None for icelltype in range(ncelltypes)]
        for icelltype in range(ncelltypes):
            yy[icelltype] = np.nanmean(Rso[icelltype][iS][0],0)
            mx[iS][icelltype] = np.nanmax(yy[icelltype])
        for iT in range(nT):
            y = [np.nanmean(Rso[icelltype][iS][iT],axis=0)[:,np.newaxis]/mx[iS][icelltype] for icelltype in range(1,ncelltypes)]
            Ypc_list[iS][iT] = [None for icelltype in range(1,ncelltypes)]
            for icelltype in range(1,ncelltypes):
                rss = Rso[icelltype][iS][iT].copy()#/mx[iS][icelltype] #.reshape(Rs[icelltype][ialign].shape[0],-1)
                #rss = Rso[icelltype][iS][iT].copy() #.reshape(Rs[icelltype][ialign].shape[0],-1)
                rss = rss[np.isnan(rss).sum(1)==0]
        #         print(rss.max())
        #         rss[rss<0] = 0
        #         rss = rss[np.random.randn(rss.shape[0])>0]
                try:
                    u,s,v = np.linalg.svd(rss-np.mean(rss,0)[np.newaxis])
                    Ypc_list[iS][iT][icelltype-1] = [(s[idim],v[idim]) for idim in range(ndims)]
    #                 print('yep on Y')
    #                 print(np.min(np.sum(rs[icelltype][iS][iT],axis=1)))
                except:
    #                 print('nope on Y')
                    print(np.mean(np.isnan(rss)))
                    print(np.min(np.sum(rs[icelltype][iS][iT],axis=1)))
            Yhat[iS][iT] = np.concatenate(y,axis=1)
    #         x = sim_utils.columnize(Rso[0][iS][iT])[:,np.newaxis]
            icelltype = 0
            #x = np.nanmean(Rso[icelltype][iS][iT],0)[:,np.newaxis]#/mx[iS][icelltype]
            x = np.nanmean(Rso[icelltype][iS][iT],0)[:,np.newaxis]/mx[iS][icelltype]
    #         opto_column = np.concatenate((np.zeros((nN,)),np.zeros((nNO/2,)),np.ones((nNO/2,))),axis=0)[:,np.newaxis]
            Xhat[iS][iT] = np.concatenate((x,np.ones_like(x)),axis=1)
    #         Xhat[iS][iT] = np.concatenate((x,np.ones_like(x),opto_column),axis=1)
            icelltype = 0
            #rss = Rso[icelltype][iS][iT].copy()/mx[iS][icelltype]
            rss = Rso[icelltype][iS][iT].copy()
            rss = rss[np.isnan(rss).sum(1)==0]
    #         try:
            u,s,v = np.linalg.svd(rss-rss.mean(0)[np.newaxis])
            Xpc_list[iS][iT] = [None for iinput in range(2)]
            Xpc_list[iS][iT][0] = [(s[idim],v[idim]) for idim in range(ndims)]
            Xpc_list[iS][iT][1] = [(0,np.zeros((Xhat[0][0].shape[0],))) for idim in range(ndims)]
    #         except:
    #             print('nope on X')
    #             print(np.mean(np.isnan(rss)))
    #             print(np.min(np.sum(Rso[icelltype][iS][iT],axis=1)))
    nN,nP = Xhat[0][0].shape
    print('nP: '+str(nP))
    nQ = Yhat[0][0].shape[1]
    
    def compute_f_(Eta,Xi,s02):
        return sim_utils.f_miller_troyer(Eta,Xi**2+np.concatenate([s02 for ipixel in range(nS*nT)]))
    def compute_fprime_m_(Eta,Xi,s02):
        return sim_utils.fprime_miller_troyer(Eta,Xi**2+np.concatenate([s02 for ipixel in range(nS*nT)]))*Xi
    def compute_fprime_s_(Eta,Xi,s02):
        s2 = Xi**2+np.concatenate((s02,s02),axis=0)
        return sim_utils.fprime_s_miller_troyer(Eta,s2)*(Xi/s2)
    def sorted_r_eigs(w):
        drW,prW = np.linalg.eig(w)
        srtinds = np.argsort(drW)
        return drW[srtinds],prW[:,srtinds]
    
    #         0.Wmx,  1.Wmy,  2.Wsx,  3.Wsy,  4.s02,5.K,  6.kappa,7.T,8.XX,        9.XXp,        10.Eta,       11.Xi,   12.h1,  13.h2,  14.Eta1,    15.Eta2
    
    shapes = [(nP,nQ),(nQ,nQ),(nP,nQ),(nQ,nQ),(nQ,),(nQ,),(1,),(nQ,),(nN,nT*nS*nP),(nN,nT*nS*nP),(nN,nT*nS*nQ),(nN,nT*nS*nQ),(1,),(1,),(nN,nT*nS*nQ),(nN,nT*nS*nQ)]
    print('size of shapes: '+str(np.sum([np.prod(shp) for shp in shapes])))
    
    import calnet.fitting_spatial_feature
    import sim_utils
    
    opto_dict = np.load(opto_silencing_data_file,allow_pickle=True)[()]
    
    Yhat_opto = opto_dict['Yhat_opto']
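    # Normalize each cell type's columns by the max over the even-indexed (baseline) rows.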
    for iS in range(nS):
        mx = np.zeros((nQ,))
        for iQ in range(nQ):
            slicer = slice(nQ*nT*iS+iQ,nQ*nT*(1+iS),nQ)
            mx[iQ] = np.nanmax(Yhat_opto[0::2][:,slicer])
            Yhat_opto[:,slicer] = Yhat_opto[:,slicer]/mx[iQ]
    #Yhat_opto = Yhat_opto/Yhat_opto[0::2].max(0)[np.newaxis,:]
    print(Yhat_opto.shape)
    h_opto = opto_dict['h_opto']
    dYY1 = Yhat_opto[1::2]-Yhat_opto[0::2]
    for to_overwrite in [1,2,5,6]: # overwrite sst and vip with off-centered values
        dYY1[:,to_overwrite] = dYY1[:,to_overwrite+8]
    for to_overwrite in [11,15]:
        dYY1[:,to_overwrite] = np.nan #dYY1[:,to_overwrite-8]


    opto_dict = np.load(opto_activation_data_file,allow_pickle=True)[()]

    Yhat_opto = opto_dict['Yhat_opto']
    for iS in range(nS):
        mx = np.zeros((nQ,))
        for iQ in range(nQ):
            slicer = slice(nQ*nT*iS+iQ,nQ*nT*(1+iS),nQ)
            mx[iQ] = np.nanmax(Yhat_opto[0::2][:,slicer])
            Yhat_opto[:,slicer] = Yhat_opto[:,slicer]/mx[iQ]
    #Yhat_opto = Yhat_opto/Yhat_opto[0::2].max(0)[np.newaxis,:]
    print(Yhat_opto.shape)
    h_opto = opto_dict['h_opto']
    dYY2 = Yhat_opto[1::2]-Yhat_opto[0::2]
    
    print('dYY1 mean: %.3f'%np.nanmean(np.abs(dYY1)))
    print('dYY2 mean: %.3f'%np.nanmean(np.abs(dYY2)))

    dYY = np.concatenate((dYY1,dYY2),axis=0)
    
    opto_mask = ~np.isnan(dYY)
    
    dYY[~opto_mask] = 0
    
    np.save('/Users/dan/Documents/notebooks/mossing-PC/shared_data/calnet_data/dYY.npy',dYY)
    
    from importlib import reload
    reload(calnet)
    #reload(calnet.fitting_spatial_feature_opto_nonlinear)
    reload(sim_utils)
    # reload(calnet.fitting_spatial_feature)
    # W0list = [np.ones(shp) for shp in shapes]
    wt_dict = {}
    wt_dict['X'] = 1
    wt_dict['Y'] = 5
    wt_dict['Eta'] = 10 # 1 # 
    wt_dict['Xi'] = 0.1
    wt_dict['stims'] = np.ones((nN,1)) #(np.arange(30)/30)[:,np.newaxis]**1 #
    wt_dict['barrier'] = 0. #30.0 #0.1
    wt_dict['opto'] = 1e-1#1e1
    wt_dict['isn'] = 3
    wt_dict['dYY'] = 300#1000
    wt_dict['Eta12'] = 100
    wt_dict['EtaTV'] = 0.03
    wt_dict['coupling'] = 3

    YYhat = calnet.utils.flatten_nested_list_of_2d_arrays(Yhat)
    XXhat = calnet.utils.flatten_nested_list_of_2d_arrays(Xhat)
    np.save('XXYYhat.npy',{'YYhat':YYhat,'XXhat':XXhat,'rs':rs,'Rs':Rs,'Rso':Rso,'Ypc_list':Ypc_list,'Xpc_list':Xpc_list})
    Eta0 = invert_f_mt(YYhat)

    ntries = 1
    nhyper = 1
    dt = 1e-1
    niter = int(np.round(50/dt)) #int(1e4)
    perturbation_size = 5e-2
    # learning_rate = 1e-4 # 1e-5 #np.linspace(3e-4,1e-3,niter+1) # 1e-5
    #l2_penalty = 0.1
    Wt = [[None for itry in range(ntries)] for ihyper in range(nhyper)]
    loss = np.zeros((nhyper,ntries))
    is_neg = np.array([b[1] for b in bounds])==0
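    # is_neg marks parameters whose upper bound is 0 (constrained nonpositive); their random
    # initial values are sign-flipped below via the negatize masks.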
    counter = 0
    negatize = [np.zeros(shp,dtype='bool') for shp in shapes]
    print(shapes)
    for ishp,shp in enumerate(shapes):
        nel = np.prod(shp)
        negatize[ishp][:][is_neg[counter:counter+nel].reshape(shp)] = True
        counter = counter + nel
    for ihyper in range(nhyper):
        for itry in range(ntries):
            print((ihyper,itry))
            W0list = [init_noise*(ihyper+1)*np.random.rand(*shp) for shp in shapes]
            print('size of shapes: '+str(np.sum([np.prod(shp) for shp in shapes])))
            print('size of w0: '+str(np.sum([np.size(x) for x in W0list])))
            print('len(W0list) : '+str(len(W0list)))
            counter = 0
            for ishp,shp in enumerate(shapes):
                W0list[ishp][negatize[ishp]] = -W0list[ishp][negatize[ishp]]
            W0list[4] = np.ones(shapes[5]) # s02
            W0list[5] = np.ones(shapes[5]) # K
            W0list[6] = np.ones(shapes[6]) # kappa
            W0list[7] = np.ones(shapes[7]) # T
            W0list[8] = np.concatenate(Xhat,axis=1) #XX
            W0list[9] = np.zeros_like(W0list[8]) #XXp
            W0list[10] = Eta0.copy() #np.zeros(shapes[10]) #Eta
            W0list[11] = np.zeros(shapes[11]) #Xi
            W0list[14] = Eta0.copy() # Eta1
            W0list[15] = Eta0.copy() # Eta2
            #[Wmx,Wmy,Wsx,Wsy,s02,k,kappa,T,XX,XXp,Eta,Xi]
    #         W0list = Wstar_dict['as_list'].copy()
    #         W0list[1][1,0] = -1.5
    #         W0list[1][3,0] = -1.5
            if init_W_from_lsq:
                W0list[0],W0list[1] = initialize_W(Xhat,Yhat,scale_by=scale_init_by)
                for ivar in range(0,2):
                    W0list[ivar] = W0list[ivar] + init_noise*np.random.randn(*W0list[ivar].shape)
            if constrain_isn:
                W0list[1][0,0] = 3 
                W0list[1][0,3] = 5 
                W0list[1][3,0] = -5
                W0list[1][3,3] = -5

            #if constrain_isn:
            #    W0list[1][0,0] = 2
            #    W0list[1][0,3] = 2
            #    W0list[1][3,0] = -2
            #    W0list[1][3,3] = -2

            #if wt_dict['coupling'] > 0:
            #    W0list[1][1,0] = -1

            if init_W_from_file:
                npyfile = np.load(init_file,allow_pickle=True)[()]
                W0list = npyfile['as_list']
                if correct_Eta:
                    W0list[10] = Eta0.copy()
                if len(W0list) < len(shapes):
                    W0list = W0list + [np.array(0.7),W0list[10].copy(),W0list[10].copy()] # add h2
                if wt_dict['coupling'] > 0:
                    W0list[1][1,0] = W0list[1][1,0] - 1

            # wt_dict['Xi'] = 10
            # wt_dict['Eta'] = 10
            print('size of bounds: '+str(np.sum([np.size(x) for x in bdlist])))
            print('size of w0: '+str(np.sum([np.size(x) for x in W0list])))
            print('size of shapes: '+str(np.sum([np.prod(shp) for shp in shapes])))
            Wt[ihyper][itry],loss[ihyper][itry],gr,hess,result = calnet.fitting_spatial_feature_opto_nonlinear.fit_W_sim(Xhat,Xpc_list,Yhat,Ypc_list,pop_rate_fn=sim_utils.f_miller_troyer,pop_deriv_fn=sim_utils.fprime_miller_troyer,neuron_rate_fn=sim_utils.evaluate_f_mt,W0list=W0list.copy(),bounds=bounds,niter=niter,wt_dict=wt_dict,l2_penalty=l2_penalty,compute_hessian=False,dt=dt,perturbation_size=perturbation_size,dYY=dYY,constrain_isn=constrain_isn,opto_mask=opto_mask)
    #         Wt[ihyper][itry] = [w[-1] for w in Wt_temp]
    #         loss[ihyper,itry] = loss_temp[-1]
    
    def parse_W(W):
        Wmx,Wmy,Wsx,Wsy,s02,K,kappa,T,XX,XXp,Eta,Xi,h1,h2,Eta1,Eta2 = W
        return Wmx,Wmy,Wsx,Wsy,s02,K,kappa,T,XX,XXp,Eta,Xi,h1,h2,Eta1,Eta2
    
    
    itry = 0
    Wmx,Wmy,Wsx,Wsy,s02,K,kappa,T,XX,XXp,Eta,Xi,h1,h2,Eta1,Eta2 = parse_W(Wt[0][0])

    labels = ['Wmx','Wmy','Wsx','Wsy','s02','K','kappa','T','XX','XXp','Eta','Xi','h1','h2','Eta1','Eta2']
    Wstar_dict = {}
    for i,label in enumerate(labels):
        Wstar_dict[label] = Wt[0][0][i]
    Wstar_dict['as_list'] = [Wmx,Wmy,Wsx,Wsy,s02,K,kappa,T,XX,XXp,Eta,Xi,h1,h2,Eta1,Eta2]
    Wstar_dict['loss'] = loss[0][0]
    Wstar_dict['wt_dict'] = wt_dict
    np.save(weights_file,Wstar_dict,allow_pickle=True)
Example #27
0
File: lds.py Project: yahmadian/ssm
    def _surrogate_elbo(self,
                        variational_posterior,
                        datas,
                        inputs=None,
                        masks=None,
                        tags=None,
                        alpha=0.75,
                        **kwargs):
        """
        Lower bound on the marginal likelihood p(y | gamma)
        using variational posterior q(x; phi) where phi = variational_params
        and gamma = emission parameters.  As part of computing this objective,
        we optimize q(z | x) and take a natural gradient step wrt theta, the
        parameters of the dynamics model.

        Note that the surrogate ELBO is a lower bound on the ELBO above.
           E_p(z | x, y)[log p(z, x, y)]
           = E_p(z | x, y)[log p(z, x, y) - log p(z | x, y) + log p(z | x, y)]
           = E_p(z | x, y)[log p(x, y) + log p(z | x, y)]
           = log p(x, y) + E_p(z | x, y)[log p(z | x, y)]
           = log p(x, y) -H[p(z | x, y)]
          <= log p(x, y)
        with equality only when p(z | x, y) is atomic.  The gap equals the
        entropy of the posterior on z.
        """
        # log p(theta)
        elbo = self.log_prior()

        # Sample x from the variational posterior
        xs = variational_posterior.sample()

        # Inner optimization: find the true posterior p(z | x, y; theta).
        # Then maximize the inner ELBO wrt theta,
        #
        #    E_p(z | x, y; theta_fixed)[log p(z, x, y; theta).
        #
        # This can be seen as a natural gradient step in theta
        # space.  Note: we do not want to compute gradients wrt x or the
        # emissions parameters backward through this optimization step,
        # so we unbox them first.
        xs_unboxed = [getval(x) for x in xs]
        emission_params_boxed = self.emissions.params
        flat_emission_params_boxed, unflatten = flatten(emission_params_boxed)
        self.emissions.params = unflatten(getval(flat_emission_params_boxed))

        # E step: compute the true posterior p(z | x, y, theta_fixed) and
        # the necessary expectations under this posterior.
        expectations = [
            self.expected_states(x, data, input, mask,
                                 tag) for x, data, input, mask, tag in zip(
                                     xs_unboxed, datas, inputs, masks, tags)
        ]

        # M step: maximize expected log joint wrt parameters
        # Note: Only do a partial update toward the M step for this sample of xs
        x_masks = [np.ones_like(x, dtype=bool) for x in xs_unboxed]
        for distn in [self.init_state_distn, self.transitions, self.dynamics]:
            curr_prms = copy.deepcopy(distn.params)
            distn.m_step(expectations, xs_unboxed, inputs, x_masks, tags,
                         **kwargs)
            distn.params = convex_combination(curr_prms, distn.params, alpha)

        # Box up the emission parameters again before computing the ELBO
        self.emissions.params = emission_params_boxed

        # Compute expected log likelihood E_q(z | x, y) [log p(z, x, y; theta)]
        for (Ez, Ezzp1, _), x, x_mask, data, mask, input, tag in \
            zip(expectations, xs, x_masks, datas, masks, inputs, tags):

            # Compute expected log likelihood (inner ELBO)
            log_pi0 = self.init_state_distn.log_initial_state_distn(
                x, input, x_mask, tag)
            log_Ps = self.transitions.log_transition_matrices(
                x, input, x_mask, tag)
            log_likes = self.dynamics.log_likelihoods(x, input, x_mask, tag)
            log_likes += self.emissions.log_likelihoods(
                data, input, mask, tag, x)

            elbo += np.sum(Ez[0] * log_pi0)
            elbo += np.sum(Ezzp1 * log_Ps)
            elbo += np.sum(Ez * log_likes)

        # -log q(x)
        elbo -= variational_posterior.log_density(xs)
        assert np.isfinite(elbo)

        return elbo
 def majorant(self, t, *args, **kwargs):
     kappa = self.get_param('kappa', **kwargs)
     mu = self.get_param('mu', 0.0, **kwargs)
     kappa = np.maximum(kappa, -mu)
     return np.ones_like(t) * (mu + np.amax(kappa))
def get_distances(positions,
                  cell,
                  cutoff_distance,
                  skin=0.01,
                  strain=np.zeros((3, 3))):
  """Get distances to atoms in a periodic unitcell.

    Parameters
    ----------

    positions: atomic positions. array-like (natoms, 3)
    cell: unit cell. array-like (3, 3)
    cutoff_distance: Maximum distance to get neighbor distances for. float
    skin: A tolerance for the cutoff_distance. float
    strain: array-like (3, 3)

    Returns
    -------

    distances : array of dimension (atom_i, atom_j, distance). The shape is
    (natoms, natoms, nunitcells), where nunitcells is the total number of unit
    cells required to tile the space so that all neighbors are found. Entries
    for atoms outside the cutoff distance are zeroed.

    offsets : integer array of unit cell offsets, shape (nunitcells, 3)

    """
  positions = np.array(positions)
  cell = np.array(cell)
  strain_tensor = np.eye(3) + strain
  cell = np.dot(strain_tensor, cell.T).T
  positions = np.dot(strain_tensor, positions.T).T

  inverse_cell = np.linalg.inv(cell)
  num_repeats = cutoff_distance * np.linalg.norm(inverse_cell, axis=0)

  fractional_coords = np.dot(positions, inverse_cell) % 1
  mins = np.min(np.floor(fractional_coords - num_repeats), axis=0)
  maxs = np.max(np.ceil(fractional_coords + num_repeats), axis=0)

  # Now we generate a set of cell offsets
  v0_range = np.arange(mins[0], maxs[0])
  v1_range = np.arange(mins[1], maxs[1])
  v2_range = np.arange(mins[2], maxs[2])

  xhat = np.array([1, 0, 0])
  yhat = np.array([0, 1, 0])
  zhat = np.array([0, 0, 1])

  v0_range = v0_range[:, None] * xhat[None, :]
  v1_range = v1_range[:, None] * yhat[None, :]
  v2_range = v2_range[:, None] * zhat[None, :]

  offsets = (
      v0_range[:, None, None] + v1_range[None, :, None] +
      v2_range[None, None, :])

  offsets = np.int_(offsets.reshape(-1, 3))
  # Now we have a vector of unit cell offsets (offset_index, 3)
  # We convert that to cartesian coordinate offsets
  cart_offsets = np.dot(offsets, cell)

  # we need to offset each coord by each offset.
  # This array is (atom_index, offset, 3)
  shifted_cart_coords = positions[:, None] + cart_offsets[None, :]

  # Next, we subtract each position from the array of positions
  # (atom_i, atom_j, positionvector, 3)
  pv = shifted_cart_coords - positions[:, None, None]

  # This is the distance squared
  # (atom_i, atom_j, distance_ij)
  d2 = np.sum(pv**2, axis=3)

  # The gradient of sqrt is nan at r=0, so we do this round about way to
  # avoid that.
  zeros = np.equal(d2, 0.0)
  adjusted = np.where(zeros, np.ones_like(d2), d2)
  d = np.where(zeros, np.zeros_like(d2), np.sqrt(adjusted))

  distances = np.where(d < (cutoff_distance + skin), d, np.zeros_like(d))
  return distances, offsets
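
# A minimal usage sketch, assuming the get_distances function above (the cell, positions,
# and cutoff below are illustrative):
import numpy as np

cell = 4.0 * np.eye(3)                    # simple cubic cell
positions = np.array([[0.0, 0.0, 0.0],
                      [2.0, 0.0, 0.0]])   # two atoms on the x axis
distances, offsets = get_distances(positions, cell, cutoff_distance=5.0)
# distances has shape (natoms, natoms, nunitcells); zero entries lie beyond the cutoff
print(distances.shape, distances[0, 1][distances[0, 1] > 0].min())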
Example #30
0
def test_gamma_method_irregular():
    N = 20000
    arr = np.random.normal(1, .2, size=N)
    afull = pe.Obs([arr], ['a'])

    configs = np.ones_like(arr)
    for i in np.random.uniform(0, len(arr), size=int(.8 * N)):
        configs[int(i)] = 0
    zero_arr = [arr[i] for i in range(len(arr)) if configs[i] != 0]
    idx = [i + 1 for i in range(len(configs)) if configs[i] == 1]
    a = pe.Obs([zero_arr], ['a'], idl=[idx])

    afull.gamma_method()
    a.gamma_method()
    ad = a.dvalue

    expe = (afull.dvalue * np.sqrt(N / np.sum(configs)))
    assert (a.dvalue - 5 * a.ddvalue < expe
            and expe < a.dvalue + 5 * a.ddvalue)

    afull.gamma_method(fft=False)
    a.gamma_method(fft=False)

    expe = (afull.dvalue * np.sqrt(N / np.sum(configs)))
    assert (a.dvalue - 5 * a.ddvalue < expe
            and expe < a.dvalue + 5 * a.ddvalue)
    assert np.abs(a.dvalue -
                  ad) <= 10 * max(a.dvalue, ad) * np.finfo(np.float64).eps

    afull.gamma_method(tau_exp=.00001)
    a.gamma_method(tau_exp=.00001)

    expe = (afull.dvalue * np.sqrt(N / np.sum(configs)))
    assert (a.dvalue - 5 * a.ddvalue < expe
            and expe < a.dvalue + 5 * a.ddvalue)

    arr2 = np.random.normal(1, .2, size=N)
    afull = pe.Obs([arr, arr2], ['a1', 'a2'])

    configs = np.ones_like(arr2)
    for i in np.random.uniform(0, len(arr2), size=int(.8 * N)):
        configs[int(i)] = 0
    zero_arr2 = [arr2[i] for i in range(len(arr2)) if configs[i] != 0]
    idx2 = [i + 1 for i in range(len(configs)) if configs[i] == 1]
    a = pe.Obs([zero_arr, zero_arr2], ['a1', 'a2'], idl=[idx, idx2])

    afull.gamma_method()
    a.gamma_method()

    expe = (afull.dvalue * np.sqrt(N / np.sum(configs)))
    assert (a.dvalue - 5 * a.ddvalue < expe
            and expe < a.dvalue + 5 * a.ddvalue)

    def gen_autocorrelated_array(inarr, rho):
        outarr = np.copy(inarr)
        for i in range(1, len(outarr)):
            outarr[i] = rho * outarr[i - 1] + np.sqrt(1 - rho**2) * outarr[i]
        return outarr

    arr = np.random.normal(1, .2, size=N)
    carr = gen_autocorrelated_array(arr, .346)
    a = pe.Obs([carr], ['a'])
    a.gamma_method()

    ae = pe.Obs([[carr[i] for i in range(len(carr)) if i % 2 == 0]], ['a'],
                idl=[[i for i in range(len(carr)) if i % 2 == 0]])
    ae.gamma_method()

    ao = pe.Obs([[carr[i] for i in range(len(carr)) if i % 2 == 1]], ['a'],
                idl=[[i for i in range(len(carr)) if i % 2 == 1]])
    ao.gamma_method()

    assert (ae.e_tauint['a'] < a.e_tauint['a'])
    assert ((ae.e_tauint['a'] - 4 * ae.e_dtauint['a'] < ao.e_tauint['a']))
    assert ((ae.e_tauint['a'] + 4 * ae.e_dtauint['a'] > ao.e_tauint['a']))
Example #31
0
def _a_from_x(x):
    y = anp.matmul(anp.transpose(x), x)
    onevec = anp.ones_like(x[0])
    return y + 0.01 * anp.diag(onevec)