def fit_W_sim(Xhat,
              Xpc_list,
              Yhat,
              Ypc_list,
              dYY,
              pop_rate_fn=None,
              pop_deriv_fn=None,
              neuron_rate_fn=None,
              W0list=None,
              bounds=None,
              dt=1e-1,
              perturbation_size=5e-2,
              niter=1,
              wt_dict=None,
              eta=0.1,
              compute_hessian=False,
              l2_penalty=1.0,
              constrain_isn=False,
              opto_mask=None,
              nsize=6,
              ncontrast=6,
              coupling_constraints=[(1, 0, -1)],
              tv=False,
              topo_stims=np.arange(36),
              topo_shape=(6, 6),
              use_opto_transforms=False,
              opto_transform1=None,
              opto_transform2=None):
    # coupling constraints: (i,j,sgn) --> coupling term i->j is constrained to be > 0 (sgn=1) or < 0 (sgn=-1)

    fudge = 1e-4
    noise = 1
    big_val = 1e6

    fprime_m = pop_deriv_fn  #utils.fprime_miller_troyer #egrad(pop_rate_fn,0)

    YYhat = utils.flatten_nested_list_of_2d_arrays(Yhat)
    XXhat = utils.flatten_nested_list_of_2d_arrays(Xhat)

    nS = len(Yhat)
    nT = len(Yhat[0])
    assert (nS == len(Xhat))
    assert (nT == len(Xhat[0]))
    nN, nP = Xhat[0][0].shape
    nQ = Yhat[0][0].shape[1]
    assert (nN == Yhat[0][0].shape[0])

    def add_key_val(d, key, val):
        """Insert d[key] = val only when key is absent (keep caller's value)."""
        d.setdefault(key, val)

    if wt_dict is None:
        wt_dict = {}
    # Fill in defaults for any cost weights the caller did not supply.
    add_key_val(wt_dict, 'celltypes', np.ones((1, nT * nS * nQ)))
    add_key_val(
        wt_dict, 'inputs',
        np.concatenate([np.array((1, 0)) for i in range(nT * nS)],
                       axis=0)[np.newaxis, :])
    add_key_val(wt_dict, 'stims', np.ones((nN, 1)))
    add_key_val(wt_dict, 'X', 1)
    add_key_val(wt_dict, 'Y', 1)
    add_key_val(wt_dict, 'Eta', 1)
    add_key_val(wt_dict, 'Xi', 1)
    add_key_val(wt_dict, 'barrier', 1)
    add_key_val(wt_dict, 'opto', 100.)
    add_key_val(wt_dict, 'isn', 0.01)
    add_key_val(wt_dict, 'tv', 0.01)
    add_key_val(wt_dict, 'celltypesOpto', np.ones((1, nT * nS * nQ)))
    add_key_val(wt_dict, 'stimsOpto', np.ones((nN, 1)))
    add_key_val(wt_dict, 'dirOpto', np.ones((2, )))

    # Unpack the weights into locals used throughout the cost function.
    wtCell = wt_dict['celltypes']
    wtInp = wt_dict['inputs']
    wtStim = wt_dict['stims']
    wtX = wt_dict['X']
    wtY = wt_dict['Y']
    wtEta = wt_dict['Eta']
    wtXi = wt_dict['Xi']
    barrier_wt = wt_dict['barrier']
    wtOpto = wt_dict['opto']
    wtISN = wt_dict['isn']
    # NOTE(review): 'dYY', 'Eta12' and 'coupling' get no defaults above, so
    # the caller must always supply them or these lookups raise KeyError.
    wtdYY = wt_dict['dYY']
    wtEta12 = wt_dict['Eta12']
    #wtEtaTV = wt_dict['EtaTV']
    wtTV = wt_dict['tv']
    wtCoupling = wt_dict['coupling']
    wtCellOpto = wt_dict['celltypesOpto']
    wtStimOpto = wt_dict['stimsOpto']
    wtDirOpto = wt_dict['dirOpto']

    #if wtEtaTV > 0:
    #    assert(nsize*ncontrast==nN)

    # Coupling-sign constraints only enforced when their weight is positive.
    if wtCoupling > 0:
        assert (not coupling_constraints is None)
        constrain_coupling = True
    else:
        constrain_coupling = False

    # Wmx,Wmy,Wsx,Wsy,s02,K,kappa,T,XX,XXp,YY,Eta,Xi,h1,h2,Eta1,Eta2
    # NOTE(review): the comment above lists 17 names, but `shapes` has 16
    # entries and parse_W unpacks 16 (there is no separate YY slot).
    shapes = [(nP, nQ), (nQ, nQ), (nP, nQ), (nQ, nQ),
              (nQ, ), (nQ * (nS - 1), ), (1, ), (nQ * (nT - 1), ),
              (nN, nT * nS * nP), (nN, nT * nS * nP), (nN, nT * nS * nQ),
              (nN, nT * nS * nQ), (1, ), (1, ), (nN, nT * nS * nQ),
              (nN, nT * nS * nQ)]

    first = True  # appears unused in this function -- TODO confirm before removing

    # Yhat is all measured tuning curves, Y is the averages of the model tuning curves
    def parse_W(W):
        # Split the flat parameter vector W into the 16 arrays listed in
        # `shapes` (Wmx, Wmy, Wsx, Wsy, s02, K, kappa, T, XX, XXp, Eta, Xi,
        # h1, h2, Eta1, Eta2).
        Ws = utils.parse_thing(W, shapes)
        return Ws

    def unparse_W(*Ws):
        """Inverse of parse_W: pack arbitrarily-shaped arrays into one flat vector."""
        flat_pieces = [np.ravel(piece) for piece in Ws]
        return np.concatenate(flat_pieces)

    def normalize(arr):
        """Row-normalize arr so each row sums to 1.

        Rows whose sum is not positive are replaced with a uniform row
        (1/ncols), as the original masking clearly intended.

        BUG FIX: the previous expression computed
        `well_behaved * arr / arrsum + ...`, which evaluates 0/0 for
        zero-sum rows and returned NaN instead of the uniform fallback.
        """
        row_sums = arr.sum(1)[:, np.newaxis]
        well_behaved = row_sums > 0
        uniform = np.ones_like(arr, dtype=float) / arr.shape[1]
        with np.errstate(divide='ignore', invalid='ignore'):
            scaled = arr / row_sums
        return np.where(well_behaved, scaled, uniform)

    def gen_Weight(W, K, kappa, T):
        # Expand a base (cell-type x cell-type) weight matrix into the full
        # nS*nT-replicated connectivity via the shared helper; K, kappa and T
        # presumably scale the cross-state/cross-timepoint blocks -- see
        # utils.gen_Weight_k_kappa_t for the exact layout.
        return utils.gen_Weight_k_kappa_t(W, K, kappa, T, nS=nS, nT=nT)

    def compute_kl_divergence(stim_deriv, noise, mu_data, mu_model, pc_list):
        # Thin wrapper fixing nS/nT for the shared KL-divergence helper,
        # which compares data vs. model tuning along the stored PCs.
        return utils.compute_kl_divergence(stim_deriv,
                                           noise,
                                           mu_data,
                                           mu_model,
                                           pc_list,
                                           nS=nS,
                                           nT=nT)

    def compute_var(Xi, s02):
        """Total per-unit variance: fudge + Xi**2 plus the baseline s02
        repeated once per (state, timepoint) pixel."""
        tiled_baseline = np.concatenate([s02] * (nS * nT), axis=0)
        return fudge + Xi**2 + tiled_baseline

    def optimize(W0, compute_hessian=False):
        def compute_fprime_(Eta, Xi, s02):
            # Derivative of the population rate w.r.t. mean input, scaled by
            # Xi -- TODO confirm the Xi factor against fprime_m's definition.
            return fprime_m(Eta, compute_var(Xi, s02)) * Xi

        def compute_f_(Eta, Xi, s02):
            # Population rate given mean input Eta and total variance.
            return pop_rate_fn(Eta, compute_var(Xi, s02))

        def compute_f_fprime_t_(W,
                                perturbation,
                                max_dist=1):  # max dist added 10/14/20
            """Simulate the rate dynamics from the stored fixed point (plus
            `perturbation`) for `niter` Euler steps of size `dt`, returning
            the final rates YY and rate-derivatives YYp (no averaging).

            Integration stops updating (and warns every 500 steps) once the
            trajectory drifts farther than `max_dist` from the fixed point.
            """
            Wmx, Wmy, Wsx, Wsy, s02, k, kappa, T, XX, XXp, Eta, Xi, h1, h2, Eta1, Eta2 = parse_W(
                W)
            fval = compute_f_(Eta, Xi, s02)
            fprimeval = compute_fprime_(Eta, Xi, s02)
            # Residuals make the stored (Eta, Xi) an exact fixed point of the
            # simulated dynamics below.
            resEta = Eta - u_fn(XX, fval, Wmx, Wmy, k, kappa, T)
            # BUG FIX: the T argument was missing here, so any call raised
            # TypeError (u_fn takes 7 arguments); now matches
            # compute_f_fprime_t_avg_.
            resXi = Xi - u_fn(XX, fval, Wsx, Wsy, k, kappa, T)
            YY = fval + perturbation
            YYp = fprimeval

            def dYYdt(YY, Eta1, Xi1):
                return -YY + compute_f_(Eta1, Xi1, s02)

            def dYYpdt(YYp, Eta1, Xi1):
                return -YYp + compute_fprime_(Eta1, Xi1, s02)

            for t in range(niter):
                if np.mean(np.abs(YY - fval)) < max_dist:
                    Eta1 = resEta + u_fn(XX, YY, Wmx, Wmy, k, kappa, T)
                    # BUG FIX: Xi must be driven through the variability
                    # weights (Wsx, Wsy), not (Wmx, Wmy); now matches
                    # compute_f_fprime_t_avg_.
                    Xi1 = resXi + u_fn(XX, YY, Wsx, Wsy, k, kappa, T)
                    YY = YY + dt * dYYdt(YY, Eta1, Xi1)
                    YYp = YYp + dt * dYYpdt(YYp, Eta1, Xi1)
                elif np.remainder(t, 500) == 0:
                    print('unstable fixed point?')

            #YYp = compute_fprime_(Eta1,Xi1,s02)

            return YY, YYp

        def compute_f_fprime_t_avg_(W, perturbation, burn_in=0.5, max_dist=1):
            """Simulate the rate dynamics from the stored fixed point (plus
            `perturbation`) for `niter` Euler steps of size `dt`, and return
            the time-averaged rates and rate-derivatives accumulated over the
            post-burn-in portion of the trajectory."""
            Wmx, Wmy, Wsx, Wsy, s02, K, kappa, T, XX, XXp, Eta, Xi, h1, h2, Eta1, Eta2 = parse_W(
                W)
            fval = compute_f_(Eta, Xi, s02)
            fprimeval = compute_fprime_(Eta, Xi, s02)
            # Residuals make the stored (Eta, Xi) an exact fixed point of the
            # simulated dynamics below.
            resEta = Eta - u_fn(XX, fval, Wmx, Wmy, K, kappa, T)
            resXi = Xi - u_fn(XX, fval, Wsx, Wsy, K, kappa, T)
            YY = fval + perturbation
            YYp = fprimeval
            YYmean = np.zeros_like(Eta)
            YYprimemean = np.zeros_like(Eta)

            def dYYdt(YY, Eta1, Xi1):
                return -YY + compute_f_(Eta1, Xi1, s02)

            def dYYpdt(YYp, Eta1, Xi1):
                return -YYp + compute_fprime_(Eta1, Xi1, s02)

            for t in range(niter):
                if np.mean(np.abs(YY - fval)) < max_dist:
                    Eta1 = resEta + u_fn(XX, YY, Wmx, Wmy, K, kappa, T)
                    Xi1 = resXi + u_fn(XX, YY, Wsx, Wsy, K, kappa, T)
                    YY = YY + dt * dYYdt(YY, Eta1, Xi1)
                    YYp = YYp + dt * dYYpdt(YYp, Eta1, Xi1)
                else:
                    # NOTE(review): unlike compute_f_fprime_t_, this prints on
                    # every divergent step (no %500 throttle) -- confirm intended.
                    print('unstable fixed point?')
                #Eta1 = resEta + u_fn(XX,YY,Wmx,Wmy,K,kappa,T)
                #Xi1 = resXi + u_fn(XX,YY,Wsx,Wsy,K,kappa,T)
                #YY = YY + dt*dYYdt(YY,Eta1,Xi1)
                if t > niter * burn_in:
                    #YYp = compute_fprime_(Eta1,Xi1,s02)
                    # Running average weighted by 1/(niter*burn_in); with the
                    # default burn_in=0.5 the kept-step weights sum to ~1.
                    YYmean = YYmean + 1 / niter / burn_in * YY
                    YYprimemean = YYprimemean + 1 / niter / burn_in * YYp

            return YYmean, YYprimemean

        def u_fn(XX, YY, Wx, Wy, K, kappa, T):
            """Total synaptic input: feedforward drive XX @ WWx plus
            recurrent drive YY @ WWy, with both base weight matrices expanded
            by gen_Weight."""
            WWx = gen_Weight(Wx, K, kappa, T)
            WWy = gen_Weight(Wy, K, kappa, T)
            return XX @ WWx + YY @ WWy

        def minusLW(W):
            """Scalar cost for the flat parameter vector W.

            Sums weighted KL terms for the input (X) and output (Y) layers,
            squared fixed-point-consistency terms for Eta/Xi, the
            optogenetic-perturbation terms, and the optional ISN,
            coupling-sign and total-variation penalties.
            """

            def compute_sq_error(a, b, wt):
                # weighted sum of squared differences
                return np.sum(wt * (a - b)**2)

            def compute_kl_error(mu_data, pc_list, mu_model, fprimeval, wt):
                # how to model variability in X?
                kl = compute_kl_divergence(fprimeval, noise, mu_data, mu_model,
                                           pc_list)
                return kl  #wt*kl
                # principled way would be to use 1/wt for noise term. Should add later.

            def compute_opto_error_nonlinear(W, wt=None):
                # Mismatch between modeled and measured opto-induced rate
                # changes (dYY), plus self-consistency of Eta1/Eta2.
                if wt is None:
                    wt = np.ones((2 * nN, nQ * nS * nT))
                Wmx, Wmy, Wsx, Wsy, s02, K, kappa, T, XX, XXp, Eta, Xi, h1, h2, Eta1, Eta2 = parse_W(
                    W)
                Eta12 = np.concatenate((Eta1, Eta2), axis=0)
                Xi12 = np.concatenate((Xi, Xi), axis=0)
                XX12 = np.concatenate((XX, XX), axis=0)
                fval12 = compute_f_(Eta12, Xi12, s02)

                fval = compute_f_(Eta, Xi, s02)
                dYY12 = fval12 - np.concatenate((fval, fval), axis=0)
                dYYterm = np.sum(wt[opto_mask] *
                                 (dYY12[opto_mask] - dYY[opto_mask])**2)

                # dHH: one-hot pattern selecting every nQ-th column starting
                # at index 2, scaled by the fitted opto drives h1/h2.
                dHH = np.zeros((nN, nQ * nS * nT))
                dHH[:, np.arange(2, nQ * nS * nT, nQ)] = 1
                dHH = np.concatenate((dHH * h1, dHH * h2), axis=0)
                Eta12perf = u_fn(XX12, fval12, Wmx, Wmy, K, kappa, T) + dHH
                Eta12term = np.sum(wt * (Eta12perf - Eta12)**2)

                #cost = wtdYY*dYYterm + wtEta12*Eta12term
                return dYYterm, Eta12term  #cost

            def compute_opto_error_nonlinear_transform(W, wt=None):
                # As above, but the target opto response comes from applying
                # the fitted opto_transform1/2 to the baseline rates.
                if wt is None:
                    wt = np.ones((2 * nN, nQ * nS * nT))
                Wmx, Wmy, Wsx, Wsy, s02, K, kappa, T, XX, XXp, Eta, Xi, h1, h2, Eta1, Eta2 = parse_W(
                    W)
                Eta12 = np.concatenate((Eta1, Eta2), axis=0)
                Xi12 = np.concatenate((Xi, Xi), axis=0)
                XX12 = np.concatenate((XX, XX), axis=0)
                fval12 = compute_f_(Eta12, Xi12, s02)

                fval = compute_f_(Eta, Xi, s02)
                #fvalrep = np.concatenate((fval,fval),axis=0)
                #dYY12 = fval12 - fvalrep
                fval12target = np.concatenate(
                    (opto_transform1.transform(fval),
                     opto_transform2.transform(fval)),
                    axis=0)
                #this_dYY = fval12target - fvalrep

                dYYterm = np.sum(
                    wt[opto_mask] *
                    (fval12[opto_mask] - fval12target[opto_mask])**2)

                dHH = np.zeros((nN, nQ * nS * nT))
                dHH[:, np.arange(2, nQ * nS * nT, nQ)] = 1
                dHH = np.concatenate((dHH * h1, dHH * h2), axis=0)
                Eta12perf = u_fn(XX12, fval12, Wmx, Wmy, K, kappa, T) + dHH
                Eta12term = np.sum(wt * (Eta12perf - Eta12)**2)

                #cost = wtdYY*dYYterm + wtEta12*Eta12term
                return dYYterm, Eta12term  #cost

            def compute_coupling(W):
                # Linear-response coupling per stimulus:
                # Phi (I - WWy Phi)^-1, with Phi the diagonal gain matrix.
                Wmx, Wmy, Wsx, Wsy, s02, K, kappa, T, XX, XXp, Eta, Xi, h1, h2, Eta1, Eta2 = parse_W(
                    W)
                WWy = gen_Weight(Wmy, K, kappa, T)
                Phi = fprime_m(Eta, compute_var(Xi, s02))
                Phi = np.concatenate((Phi, Phi), axis=0)
                Phi1 = np.array([np.diag(phi) for phi in Phi])
                coupling = np.array([
                    phi1 @ np.linalg.inv(np.eye(nQ * nS * nT) - WWy @ phi1)
                    for phi1 in Phi1
                ])
                return coupling

            def compute_coupling_error(W, i, j, sgn=-1):
                # constrain coupling term i,j to have a specified sign,
                # -1 for negative or +1 for positive
                coupling = compute_coupling(W)
                log_arg = sgn * coupling[:, i, j]
                cost = utils.minus_sum_log_ceil(log_arg, big_val / nN)
                return cost

            #def compute_eta_tv(this_Eta):
            #    Etar = this_Eta.reshape((nsize,ncontrast,nQ*nS*nT))
            #    diff_size = np.sum(np.abs(np.diff(Etar,axis=0)))
            #    diff_contrast = np.sum(np.abs(np.diff(Etar,axis=1)))
            #    return diff_size + diff_contrast

            def compute_isn_error(W):
                # Log-barrier keeping Phi_E * WEE - 1 positive (ISN regime).
                Wmx, Wmy, Wsx, Wsy, s02, K, kappa, T, XX, XXp, Eta, Xi, h1, h2, Eta1, Eta2 = parse_W(
                    W)
                Phi = fprime_m(Eta, compute_var(Xi, s02))
                #print('min Eta: %f'%np.min(Eta[:,0]))
                #print('WEE: %f'%Wmy[0,0])
                #print('min phiE*WEE: %f'%np.min(Phi[:,0]*Wmy[0,0]))
                log_arg = Phi[:, 0] * Wmy[0, 0] - 1
                cost = utils.minus_sum_log_ceil(log_arg, big_val / nN)
                #print('ISN cost: %f'%cost)
                return cost

            def compute_tv_error(W):
                # sq l2 norm for tv error, over the stimulus topography
                Wmx, Wmy, Wsx, Wsy, s02, K, kappa, T, XX, XXp, Eta, Xi, h1, h2, Eta1, Eta2 = parse_W(
                    W)
                topo_var_list = [arr.reshape(topo_shape+(-1,)) for arr in \
                        [XX,XXp,Eta,Xi,Eta1,Eta2]]
                sqdiffy = [
                    np.sum(np.abs(np.diff(top, axis=0))**2)
                    for top in topo_var_list
                ]
                sqdiffx = [
                    np.sum(np.abs(np.diff(top, axis=1))**2)
                    for top in topo_var_list
                ]
                cost = np.sum(sqdiffy + sqdiffx)
                return cost

            Wmx, Wmy, Wsx, Wsy, s02, K, kappa, T, XX, XXp, Eta, Xi, h1, h2, Eta1, Eta2 = parse_W(
                W)

            #utils.print_labeled('T',T)
            #utils.print_labeled('K',K)
            #utils.print_labeled('Wmy',Wmy)

            perturbation = perturbation_size * np.random.randn(*Eta.shape)

            #             fval,fprimeval = compute_f_fprime_t_(W,perturbation) # Eta the mean input per cell, Xi the stdev. input per cell, s02 the baseline variability in input
            fval, fprimeval = compute_f_fprime_t_avg_(
                W, perturbation
            )  # Eta the mean input per cell, Xi the stdev. input per cell, s02 the baseline variability in input
            #utils.print_labeled('fval',fval)

            Xterm = compute_kl_error(XXhat, Xpc_list, XX, XXp, wtStim *
                                     wtInp)  # XX the modeled input layer (L4)
            Yterm = compute_kl_error(
                YYhat, Ypc_list, fval, fprimeval,
                wtStim * wtCell)  # fval the modeled output layer (L2/3)
            Etaterm = compute_sq_error(
                Eta, u_fn(XX, fval, Wmx, Wmy, K, kappa, T),
                wtStim * wtCell)  # magnitude of fudge factor in mean input
            Xiterm = compute_sq_error(
                Xi, u_fn(XX, fval, Wsx, Wsy, K, kappa, T), wtStim *
                wtCell)  # magnitude of fudge factor in input variability
            # returns value float
            #Optoterm = compute_opto_error_nonlinear(W) #testing out 8/20/20
            opto_wt = np.concatenate(
                [wtStimOpto * wtCellOpto * w for w in wtDirOpto], axis=0)
            if use_opto_transforms:
                dYYterm, Eta12term = compute_opto_error_nonlinear_transform(
                    W, opto_wt)
            else:
                dYYterm, Eta12term = compute_opto_error_nonlinear(W, opto_wt)
            Optoterm = wtdYY * dYYterm + wtEta12 * Eta12term
            #EtaTVterm = 0
            #for this_Eta in [Eta,Eta1,Eta2]:
            #    EtaTVterm = EtaTVterm + compute_eta_tv(this_Eta)
            cost = wtX * Xterm + wtY * Yterm + wtEta * Etaterm + wtXi * Xiterm + wtOpto * Optoterm  # + wtEtaTV*EtaTVterm
            if constrain_isn:
                ISNterm = compute_isn_error(W)
                cost = cost + wtISN * ISNterm
            if constrain_coupling:
                Couplingterm = 0
                for el in coupling_constraints:
                    i, j, sgn = el
                    Couplingterm = Couplingterm + compute_coupling_error(
                        W, i, j, sgn)
                cost = cost + wtCoupling * Couplingterm
            if tv:
                TVterm = compute_tv_error(W)
                cost = cost + wtTV * TVterm

            # Diagnostic printout (only on plain-float evaluations, i.e. not
            # while autograd is tracing the cost).
            if isinstance(Xterm, float):
                print('X:%f' % (wtX * Xterm))
                print('Y:%f' % (wtY * Yterm))
                print('Eta:%f' % (wtEta * Etaterm))
                print('Xi:%f' % (wtXi * Xiterm))
                print('Opto dYY:%f' % (wtOpto * wtdYY * dYYterm))
                print('Opto Eta:%f' % (wtOpto * wtEta12 * Eta12term))
                #print('TV:%f'%(wtEtaTV*EtaTVterm))
                # BUG FIX: TVterm is only bound when tv is True; printing it
                # unconditionally raised NameError whenever tv was False.
                if tv:
                    print('TV:%f' % (wtTV * TVterm))
                if constrain_isn:
                    print('ISN:%f' % (wtISN * ISNterm))
                if constrain_coupling:
                    print('coupling:%f' % (wtCoupling * Couplingterm))

            #lbls = ['Yterm']
            #vars = [Yterm]
            lbls = ['cost']
            vals = [cost]  # renamed from `vars` to avoid shadowing the builtin
            for lbl, var in zip(lbls, vals):
                utils.print_labeled(lbl, var)
            return cost

        def minusdLdW(W):
            # Gradient of the scalar cost via reverse-mode autodiff (`grad`).
            # returns value (R,)
            # sum in first dimension: (N,1) times (N,1) times (N,P)
            #             return jacobian(minusLW)(W)
            return grad(minusLW)(W)

        def fix_violations(w, bounds):
            """Clip w in place to the box given by `bounds` (list of
            (lower, upper) pairs, one per entry of w).

            Returns (w, below_mask, above_mask); callers rely on the
            in-place mutation of w.
            """
            lower = np.array([pair[0] for pair in bounds])
            upper = np.array([pair[1] for pair in bounds])
            below = w < lower
            above = w > upper
            np.copyto(w, lower, where=below)
            np.copyto(w, upper, where=above)
            return w, below, above

        def sorted_r_eigs(w):
            """Eigendecomposition of w with eigenvalues (and the matching
            right-eigenvector columns) returned in ascending order."""
            eigvals, eigvecs = np.linalg.eig(w)
            order = np.argsort(eigvals)
            return eigvals[order], eigvecs[:, order]

        def compute_eig_penalty_(Wmy, K0, kappa, T0):
            # still need to finish! Hopefully won't need
            # need to fix this to reflect addition of kappa argument
            # NOTE(review): unfinished/dead code -- the list comprehensions
            # below reference `W0my` and `k0`, which are not defined in this
            # scope, so calling this function raises NameError.
            Wsquig = gen_Weight(Wmy, K0, kappa, T0)
            drW, prW = sorted_r_eigs(Wsquig - np.eye(nQ * nS * nT))
            plW = np.linalg.inv(prW)
            eig_outer_all = [
                np.real(np.outer(plW[:, k], prW[k, :]))
                for k in range(nS * nQ * nT)
            ]
            eig_penalty_size_all = [
                barrier_wt / np.abs(np.real(drW[k]))
                for k in range(nS * nQ * nT)
            ]
            eig_penalty_dir_w = [
                eig_penalty_size *
                ((eig_outer[:nQ, :nQ] + eig_outer[nQ:, nQ:]) +
                 K0[np.newaxis, :] *
                 (eig_outer[:nQ, nQ:] + kappa * eig_outer[nQ:, :nQ]))
                for eig_outer, eig_penalty_size in zip(eig_outer_all,
                                                       eig_penalty_size_all)
            ]
            eig_penalty_dir_k = [
                eig_penalty_size *
                ((eig_outer[:nQ, nQ:] + eig_outer[nQ:, :nQ] * kappa) *
                 W0my).sum(0) for eig_outer, eig_penalty_size in zip(
                     eig_outer_all, eig_penalty_size_all)
            ]
            eig_penalty_dir_kappa = [
                eig_penalty_size *
                (eig_outer[nQ:, :nQ] * k0[np.newaxis, :] * W0my).sum().reshape(
                    (1, )) for eig_outer, eig_penalty_size in zip(
                        eig_outer_all, eig_penalty_size_all)
            ]
            eig_penalty_dir_w = np.array(eig_penalty_dir_w).sum(0)
            eig_penalty_dir_k = np.array(eig_penalty_dir_k).sum(0)
            eig_penalty_dir_kappa = np.array(eig_penalty_dir_kappa).sum(0)
            return eig_penalty_dir_w, eig_penalty_dir_k, eig_penalty_dir_kappa

        def compute_eig_penalty(W):
            # still need to finish! Hopefully won't need
            # NOTE(review): dead/unfinished code. The call below passes 3
            # arguments to compute_eig_penalty_ (which takes 4) and uses the
            # undefined name `k0` (the parsed value is `K0`); unparse_W is
            # also given 11 arrays while `shapes` has 16 entries. Calling
            # this function would raise.
            W0mx, W0my, W0sx, W0sy, s020, K0, kappa0, T0, XX0, XXp0, Eta0, Xi0, h10, h20, Eta10, Eta20 = parse_W(
                W)
            eig_penalty_dir_w, eig_penalty_dir_k, eig_penalty_dir_kappa = compute_eig_penalty_(
                W0my, k0, kappa0)
            eig_penalty_W = unparse_W(np.zeros_like(W0mx), eig_penalty_dir_w,
                                      np.zeros_like(W0sx), np.zeros_like(W0sy),
                                      np.zeros_like(s020),
                                      eig_penalty_dir_k, eig_penalty_dir_kappa,
                                      np.zeros_like(XX0), np.zeros_like(XXp0),
                                      np.zeros_like(Eta0), np.zeros_like(Xi0))
            #             assert(True==False)
            return eig_penalty_W

        # L2-regularize only the leading Wmx (nP*nQ) and Wmy (nQ**2) entries
        # of the flattened parameter vector.
        allhot = np.zeros(W0.shape)
        allhot[:nP * nQ + nQ**2] = 1
        W_l2_reg = lambda W: np.sum((W * allhot)**2)
        f = lambda W: minusLW(W) + l2_penalty * W_l2_reg(W)
        fprime = lambda W: minusdLdW(W) + 2 * l2_penalty * W * allhot

        # Clip the initial guess into the box constraints (mutates W0 in place).
        fix_violations(W0, bounds)

        W1, loss, result = sop.fmin_l_bfgs_b(f,
                                             W0,
                                             fprime=fprime,
                                             bounds=bounds,
                                             factr=1e2,
                                             maxiter=int(1e3),
                                             maxls=40)
        if compute_hessian:
            # Gradient and Hessian of the unregularized cost at the optimum
            # (the Hessian is expensive for this parameter count).
            gr = grad(minusLW)(W1)
            hess = hessian(minusLW)(W1)
        else:
            gr = None
            hess = None

#         W0mx,W0my,W0sx,W0sy,s020,k0,kappa0,XX0,XXp0,Eta0,Xi0 = parse_W(W1)

        return W1, loss, gr, hess, result

    # Pack the initial parameter arrays into one flat vector and optimize.
    W0 = unparse_W(*W0list)

    W1, loss, gr, hess, result = optimize(W0, compute_hessian=compute_hessian)

    # Return the fitted parameters parsed back into their array forms.
    Wt = parse_W(W1)  #[Wmx,Wmy,Wsx,Wsy,s02,k,kappa,XX,XXp,Eta,Xi]

    return Wt, loss, gr, hess, result
def fit_W_sim(Xhat,Xpc_list,Yhat,Ypc_list,dYY,pop_rate_fn=None,pop_deriv_fn=None,W10list=None,W20list=None,bounds1=None,bounds2=None,dt=1e-1,perturbation_size=5e-2,niter=1,wt_dict=None,eta=0.1,compute_hessian=False,l2_penalty=1.0,constrain_isn=False,opto_mask=None,nsize=6,ncontrast=6,coupling_constraints=[(1,0,-1)],tv=False,topo_stims=np.arange(36),topo_shape=(6,6),use_opto_transforms=False,opto_transform1=None,opto_transform2=None,share_residuals=False,stimwise=False,simulate1=True,simulate2=False,verbose=True):
    # coupling constraints: (i,j,sgn) --> coupling term i->j is constrained to be > 0 (sgn=1) or < 0 (sgn=-1)
    
    fudge = 1e-4
    noise = 1
    big_val = 1e5
    
    fprime_m = pop_deriv_fn #utils.fprime_miller_troyer #egrad(pop_rate_fn,0)
    
    YYhat = utils.flatten_nested_list_of_2d_arrays(Yhat)
    XXhat = utils.flatten_nested_list_of_2d_arrays(Xhat)
    
    nS = len(Yhat)
    nT = len(Yhat[0])
    assert(nS==len(Xhat))
    assert(nT==len(Xhat[0]))
    nN,nP = Xhat[0][0].shape
    nQ = Yhat[0][0].shape[1]
    assert(nN==Yhat[0][0].shape[0])
    
    def add_key_val(d,key,val):
        if not key in d:
            d[key] = val
    
    if wt_dict is None:
        wt_dict = {}
    add_key_val(wt_dict,'celltypes',np.ones((1,nT*nS*nQ)))
    add_key_val(wt_dict,'inputs',np.concatenate([np.array((1,0)) for i in range(nT*nS)],axis=0)[np.newaxis,:])
    add_key_val(wt_dict,'stims',np.ones((nN,1)))
    add_key_val(wt_dict,'X',1)
    add_key_val(wt_dict,'Y',1)
    add_key_val(wt_dict,'Eta',1)
    add_key_val(wt_dict,'Xi',1)
    add_key_val(wt_dict,'barrier',1)
    add_key_val(wt_dict,'opto',100.)
    add_key_val(wt_dict,'isn',0.01)
    add_key_val(wt_dict,'tv',0.01)
    add_key_val(wt_dict,'celltypesOpto',np.ones((1,nT*nS*nQ)))
    add_key_val(wt_dict,'stimsOpto',np.ones((nN,1)))
    add_key_val(wt_dict,'dirOpto',np.ones((2,)))
    add_key_val(wt_dict,'smi',1)
    add_key_val(wt_dict,'smi_halo',0.5)
    add_key_val(wt_dict,'smi_chrimson',0.5)
    
    wtCell = wt_dict['celltypes']
    wtInp = wt_dict['inputs']
    wtStim = wt_dict['stims']
    wtX = wt_dict['X']
    wtY = wt_dict['Y']
    wtEta = wt_dict['Eta']
    wtXi = wt_dict['Xi']
    barrier_wt = wt_dict['barrier']
    wtOpto = wt_dict['opto']
    wtISN = wt_dict['isn']
    wtdYY = wt_dict['dYY']
    #wtEta12 = wt_dict['Eta12']
    #wtEtaTV = wt_dict['EtaTV']
    wtTV = wt_dict['tv']
    wtCoupling = wt_dict['coupling']
    wtCellOpto = wt_dict['celltypesOpto']
    wtStimOpto = wt_dict['stimsOpto']
    wtDirOpto = wt_dict['dirOpto']
    wtSMI = wt_dict['smi']
    wtSMIhalo = wt_dict['smi_halo']
    wtSMIchrimson = wt_dict['smi_chrimson']

    #if wtEtaTV > 0:
    #    assert(nsize*ncontrast==nN)

    if wtCoupling > 0:
        assert(not coupling_constraints is None)
        constrain_coupling = True
    else:
        constrain_coupling = False

    # Wmx,Wmy,Wsx,Wsy,s02,K,kappa,T,XX,XXp,YY,Eta,Xi,h1,h2
    #shapes = [(nP,nQ),(nQ,nQ),(nP,nQ),(nQ,nQ),(nQ,),(nQ*(nS-1),),(1,),(nQ*(nT-1),),(nN,nT*nS*nP),(nN,nT*nS*nP),(nN,nT*nS*nQ),(nN,nT*nS*nQ),(1,),(1,)]
    shapes1 = [(nP,nQ),(nQ,nQ),(nP,nQ),(nQ,nQ),(nQ,),(nQ*(nS-1),),(1,),(nQ*(nT-1),),(1,),(1,),(nQ,),(nT*nS*nQ,)]
    shapes2 = [(nN,nT*nS*nP),(nN,nT*nS*nP),(nN,nT*nS*nQ),(nN,nT*nS*nQ)]

    first = True
        
    # Yhat is all measured tuning curves, Y is the averages of the model tuning curves
    def parse_W1(W):
        Ws = utils.parse_thing(W,shapes1)
        return Ws

    def parse_W2(W):
        Ws = utils.parse_thing(W,shapes2)
        return Ws
    
    def unparse_W(*Ws):
        return np.concatenate([ww.flatten() for ww in Ws])
    
    def normalize(arr):
        arrsum = arr.sum(1)
        well_behaved = (arrsum>0)[:,np.newaxis]
        arrnorm = well_behaved*arr/arrsum[:,np.newaxis] + (~well_behaved)*np.ones_like(arr)/arr.shape[1]
        return arrnorm
    
    def gen_Weight(W,K,kappa,T):
        return utils.gen_Weight_k_kappa_t(W,K,kappa,T,nS=nS,nT=nT) 
        
    def compute_kl_divergence(stim_deriv,noise,mu_data,mu_model,pc_list):
        return utils.compute_kl_divergence(stim_deriv,noise,mu_data,mu_model,pc_list,nS=nS,nT=nT)

    def compute_var(Xi,s02):
        return fudge+Xi**2+np.concatenate([s02 for ipixel in range(nS*nT)],axis=0)

    def compute_fprime_(Eta,Xi,s02):
        return fprime_m(Eta,compute_var(Xi,s02))*Xi

    def compute_f_(Eta,Xi,s02):
        return pop_rate_fn(Eta,compute_var(Xi,s02))

    def compute_f_fprime_(W1,W2):
        Wmx,Wmy,Wsx,Wsy,s02,K,kappa,T,h1,h2,bl,amp = parse_W1(W1)
        XX,XXp,Eta,Xi = parse_W2(W2)
        return compute_f_(Eta,Xi,s02),compute_fprime_(Eta,Xi,s02)

    def compute_f_fprime_t_(W1,W2,perturbation,max_dist=1): # max dist added 10/14/20
        #Wmx,Wmy,Wsx,Wsy,s02,K,kappa,T,XX,XXp,Eta,Xi,h1,h2 = parse_W(W)
        Wmx,Wmy,Wsx,Wsy,s02,K,kappa,T,h1,h2,bl,amp = parse_W1(W1)
        XX,XXp,Eta,Xi = parse_W2(W2)
        fval = compute_f_(Eta,Xi,s02)
        fprimeval = compute_fprime_(Eta,Xi,s02)
        resEta = Eta - u_fn(XX,fval,Wmx,Wmy,K,kappa,T)
        resXi  = Xi - u_fn(XX,fval,Wsx,Wsy,K,kappa,T)
        YY = fval + perturbation
        YYp = fprimeval
        def dYYdt(YY,Eta1,Xi1):
            return -YY + compute_f_(Eta1,Xi1,s02)
        def dYYpdt(YYp,Eta1,Xi1):
            return -YYp +compute_fprime_(Eta1,Xi1,s02)
        for t in range(niter):
            if np.mean(np.abs(YY-fval)) < max_dist:
                Eta1 = resEta + u_fn(XX,YY,Wmx,Wmy,k,kappa,T)
                Xi1 = resXi + u_fn(XX,YY,Wsx,Wsy,k,kappa,T)
                YY = YY + dt*dYYdt(YY,Eta1,Xi1)
                YYp = YYp + dt*dYYpdt(YYp,Eta1,Xi1)
            elif np.remainder(t,500)==0:
                print('unstable fixed point?')

        #YY = YY + np.tile(bl,nS*nT)[np.newaxis,:]
            
        #YYp = compute_fprime_(Eta1,Xi1,s02)
        
        return YY,YYp

    def compute_f_fprime_t_12_(W1,W2,perturbation,max_dist=1): # max dist added 10/14/20
        #Wmx,Wmy,Wsx,Wsy,s02,K,kappa,T,XX,XXp,Eta,Xi,h1,h2 = parse_W(W)
        Wmx,Wmy,Wsx,Wsy,s02,K,kappa,T,h1,h2,bl,amp = parse_W1(W1)
        XX,XXp,Eta,Xi = parse_W2(W2)
        fval = compute_f_(Eta,Xi,s02)
        fprimeval = compute_fprime_(Eta,Xi,s02)
        if share_residuals:
            resEta = Eta - u_fn(XX,fval,Wmx,Wmy,K,kappa,T)
            resXi  = Xi - u_fn(XX,fval,Wsx,Wsy,K,kappa,T)
            resEta12 = np.concatenate((resEta,resEta),axis=0)
            resXi12 = np.concatenate((resXi,resXi),axis=0)
        else:
            resEta12 = 0
            resXi12 = 0
        dHH = np.zeros((nN,nQ*nS*nT))
        dHH[:,np.arange(2,nQ*nS*nT,nQ)] = 1
        dHH = np.concatenate((dHH*h1,dHH*h2),axis=0)
        YY = fval + perturbation
        YYp = fprimeval
        XX12 = np.concatenate((XX,XX),axis=0)
        YY12 = np.concatenate((YY,YY),axis=0)
        YYp12 = np.concatenate((YYp,YYp),axis=0)
        def dYYdt(YY,Eta1,Xi1):
            return -YY + compute_f_(Eta1,Xi1,s02)
        def dYYpdt(YYp,Eta1,Xi1):
            return -YYp + compute_fprime_(Eta1,Xi1,s02)
        for t in range(niter):
            if np.mean(np.abs(YY-fval)) < max_dist:
                Eta121 = resEta12 + u_fn(XX12,YY12,Wmx,Wmy,K,kappa,T) + dHH
                Xi121 = resXi12 + u_fn(XX12,YY12,Wsx,Wsy,K,kappa,T)
                YY12 = YY12 + dt*dYYdt(YY12,Eta121,Xi121)
                YYp12 = YYp12 + dt*dYYpdt(YYp12,Eta121,Xi121)
            elif np.remainder(t,500)==0:
                print('unstable fixed point?')

        #YY12 = YY12 + np.tile(bl,nS*nT)[np.newaxis,:]
        
        return YY12,YYp12

    def compute_f_fprime_t_avg_(W1,W2,perturbation,burn_in=0.5,max_dist=1):
        #Wmx,Wmy,Wsx,Wsy,s02,K,kappa,T,XX,XXp,Eta,Xi,h1,h2 = parse_W(W)
        Wmx,Wmy,Wsx,Wsy,s02,K,kappa,T,h1,h2,bl,amp = parse_W1(W1)
        XX,XXp,Eta,Xi = parse_W2(W2)
        fval = compute_f_(Eta,Xi,s02)
        fprimeval = compute_fprime_(Eta,Xi,s02)
        resEta = Eta - u_fn(XX,fval,Wmx,Wmy,K,kappa,T)
        resXi  = Xi - u_fn(XX,fval,Wsx,Wsy,K,kappa,T)
        YY = fval + perturbation
        YYp = fprimeval
        YYmean = np.zeros_like(Eta)
        YYprimemean = np.zeros_like(Eta)
        def dYYdt(YY,Eta1,Xi1):
            return -YY + compute_f_(Eta1,Xi1,s02)
        def dYYpdt(YYp,Eta1,Xi1):
            return -YYp + compute_fprime_(Eta1,Xi1,s02)
        for t in range(niter):
            if np.mean(np.abs(YY-fval)) < max_dist:
                Eta1 = resEta + u_fn(XX,YY,Wmx,Wmy,K,kappa,T)
                Xi1 = resXi + u_fn(XX,YY,Wsx,Wsy,K,kappa,T)
                YY = YY + dt*dYYdt(YY,Eta1,Xi1)
                YYp = YYp + dt*dYYpdt(YYp,Eta1,Xi1)
            else:
                print('unstable fixed point?')
            #Eta1 = resEta + u_fn(XX,YY,Wmx,Wmy,K,kappa,T)
            #Xi1 = resXi + u_fn(XX,YY,Wsx,Wsy,K,kappa,T)
            #YY = YY + dt*dYYdt(YY,Eta1,Xi1)
            if t>niter*burn_in:
                #YYp = compute_fprime_(Eta1,Xi1,s02)
                YYmean = YYmean + 1/niter/burn_in*YY
                YYprimemean = YYprimemean + 1/niter/burn_in*YYp

        #YYmean = YYmean + np.tile(bl,nS*nT)[np.newaxis,:]
            
        return YYmean,YYprimemean

    def compute_f_fprime_t_avg_12_(W1,W2,perturbation,max_dist=1,burn_in=0.5): # max dist added 10/14/20
        """Simulate the perturbed network to steady state under the two opto conditions.

        The stimulus/state arrays are stacked twice along axis 0 (rows 0..nN-1:
        condition 1 driven by h1; rows nN..2nN-1: condition 2 driven by h2) and
        integrated with Euler steps of size dt for niter iterations. Returns the
        time-averaged rates and rate-derivatives (YYmean, YYprimemean), averaging
        only after the burn-in fraction of iterations.
        """
        #Wmx,Wmy,Wsx,Wsy,s02,K,kappa,T,XX,XXp,Eta,Xi,h1,h2 = parse_W(W)
        Wmx,Wmy,Wsx,Wsy,s02,K,kappa,T,h1,h2,bl,amp = parse_W1(W1)
        XX,XXp,Eta,Xi = parse_W2(W2)
        fval = compute_f_(Eta,Xi,s02)
        fprimeval = compute_fprime_(Eta,Xi,s02)
        if share_residuals:
            # residual inputs not explained by the weight model, reused for
            # both opto conditions
            resEta = Eta - u_fn(XX,fval,Wmx,Wmy,K,kappa,T)
            resXi  = Xi - u_fn(XX,fval,Wsx,Wsy,K,kappa,T)
            resEta12 = np.concatenate((resEta,resEta),axis=0)
            resXi12 = np.concatenate((resXi,resXi),axis=0)
        else:
            resEta12 = 0
            resXi12 = 0
        # dHH: optogenetic drive targeting cell-type index 2 within each (S,T)
        # group; scaled by h1 for condition 1 and h2 for condition 2
        dHH = np.zeros((nN,nQ*nS*nT))
        dHH[:,np.arange(2,nQ*nS*nT,nQ)] = 1
        dHH = np.concatenate((dHH*h1,dHH*h2),axis=0)
        YY = fval + perturbation
        YYp = fprimeval
        XX12 = np.concatenate((XX,XX),axis=0)
        YY12 = np.concatenate((YY,YY),axis=0)
        YYp12 = np.concatenate((YYp,YYp),axis=0)
        YYmean = np.zeros_like(YY12)
        YYprimemean = np.zeros_like(YY12)
        def dYYdt(YY,Eta1,Xi1):
            return -YY + compute_f_(Eta1,Xi1,s02)
        def dYYpdt(YYp,Eta1,Xi1):
            return -YYp + compute_fprime_(Eta1,Xi1,s02)
        for t in range(niter):
            # NOTE(review): YY is never updated inside this loop, so this
            # stagnation check is constant (mean |perturbation|); the analogous
            # single-condition routine updates YY each step -- TODO confirm.
            if np.mean(np.abs(YY-fval)) < max_dist:
                Eta121 = resEta12 + u_fn(XX12,YY12,Wmx,Wmy,K,kappa,T) + dHH
                # BUG FIX: the variability input Xi must use the (Wsx,Wsy)
                # weights, consistent with resXi above and with the
                # single-condition simulation; previously this reused (Wmx,Wmy).
                Xi121 = resXi12 + u_fn(XX12,YY12,Wsx,Wsy,K,kappa,T)
                YY12 = YY12 + dt*dYYdt(YY12,Eta121,Xi121)
                YYp12 = YYp12 + dt*dYYpdt(YYp12,Eta121,Xi121)
            elif np.remainder(t,500)==0:
                print('unstable fixed point?')
            if t>niter*burn_in:
                # running average over the post-burn-in iterations; the
                # 1/niter/burn_in weights sum to ~1 only for burn_in=0.5
                YYmean = YYmean + 1/niter/burn_in*YY12
                YYprimemean = YYprimemean + 1/niter/burn_in*YYp12

        #YYmean = YYmean + np.tile(bl,nS*nT)[np.newaxis,:]
        
        return YYmean,YYprimemean

    def u_fn(XX,YY,Wx,Wy,K,kappa,T):
        # Net input to each unit: feedforward drive (XX through Wx) plus
        # recurrent drive (YY through Wy), after expanding each bare weight
        # matrix to the full (S,T)-tiled weight via gen_Weight.
        WWx = gen_Weight(Wx,K,kappa,T)
        WWy = gen_Weight(Wy,K,kappa,T)
        return XX @ WWx + YY @ WWy
                    
    def minusLW(W1,W2,simulate=True,verbose=True):
        
        def compute_sq_error(a,b,wt):
            return np.sum(wt*(a-b)**2)
        
        def compute_kl_error(mu_data,pc_list,mu_model,fprimeval,wt):
            # how to model variability in X?
            kl = compute_kl_divergence(fprimeval,noise,mu_data,mu_model,pc_list)
            return kl #wt*kl
            # principled way would be to use 1/wt for noise term. Should add later.

        def compute_opto_error_nonlinear(fval,fval12,wt=None):
            if wt is None:
                wt = np.ones((2*nN,nQ*nS*nT))
            fval_both = np.concatenate((np.concatenate((fval,fval),axis=0)[:,np.newaxis,:],\
                    fval12[:,np.newaxis,:]),axis=1)
            this_fval12 = opto_transform1.preprocess(fval_both)
            dYY12 = this_fval12[:,1,:] - this_fval12[:,0,:]
            dYYterm = np.sum(wt[opto_mask]*(dYY12[opto_mask] - dYY[opto_mask])**2)
            return dYYterm

        def compute_opto_error_nonlinear_transform(fval,fval12,wt=None):
            if wt is None:
                wt = np.ones((2*nN,nQ*nS*nT))
            fval_both = np.concatenate((np.concatenate((fval,fval),axis=0)[:,np.newaxis,:],\
                    fval12[:,np.newaxis,:]),axis=1)
            this_fval12 = opto_transform1.preprocess(fval_both)[:,1,:]
            fval12target = np.concatenate((opto_transform1.transform(fval),opto_transform2.transform(fval)),axis=0)
            dYYterm = np.sum(wt[opto_mask]*(this_fval12[opto_mask] - fval12target[opto_mask])**2)
            return dYYterm

        def compute_coupling(W1,W2):
            #Wmx,Wmy,Wsx,Wsy,s02,K,kappa,T,XX,XXp,Eta,Xi,h1,h2 = parse_W(W)
            Wmx,Wmy,Wsx,Wsy,s02,K,kappa,T,h1,h2,bl,amp = parse_W1(W1)
            XX,XXp,Eta,Xi = parse_W2(W2)
            WWy = gen_Weight(Wmy,K,kappa,T)
            Phi = fprime_m(Eta,compute_var(Xi,s02))
            Phi = np.concatenate((Phi,Phi),axis=0)
            Phi1 = np.array([np.diag(phi) for phi in Phi])
            coupling = np.array([phi1 @ np.linalg.pinv(np.eye(nQ*nS*nT) - WWy @ phi1) for phi1 in Phi1])
            return coupling

        def compute_coupling_error(W1,W2,i,j,sgn=-1):
            # constrain coupling term i,j to have a specified sign, 
            # -1 for negative or +1 for positive
            coupling = compute_coupling(W1,W2)
            log_arg = sgn*coupling[:,i,j]
            cost = utils.minus_sum_log_slope(log_arg,big_val/nN)
            return cost

        #def compute_eta_tv(this_Eta):
        #    Etar = this_Eta.reshape((nsize,ncontrast,nQ*nS*nT))
        #    diff_size = np.sum(np.abs(np.diff(Etar,axis=0)))
        #    diff_contrast = np.sum(np.abs(np.diff(Etar,axis=1)))
        #    return diff_size + diff_contrast

        def compute_isn_error(W1,W2):
            #Wmx,Wmy,Wsx,Wsy,s02,K,kappa,T,XX,XXp,Eta,Xi,h1,h2 = parse_W(W)
            Wmx,Wmy,Wsx,Wsy,s02,K,kappa,T,h1,h2,bl,amp = parse_W1(W1)
            XX,XXp,Eta,Xi = parse_W2(W2)
            Phi = fprime_m(Eta,compute_var(Xi,s02))
            #print('min Eta: %f'%np.min(Eta[:,0]))
            #print('WEE: %f'%Wmy[0,0])
            #print('min phiE*WEE: %f'%np.min(Phi[:,0]*Wmy[0,0]))
            if K.size:
                k = K[0]
            else:
                k = 0
            if T.size:
                t = T[0]
            else:
                t = 0
            log_arg = Phi[:,0]*Wmy[0,0]*(1+k)*(1+t) - 1
            cost = utils.minus_sum_log_slope(log_arg,big_val/nN)
            #print('ISN cost: %f'%cost)
            return cost
        
        def compute_tv_error(W1,W2):
            # sq l2 norm for tv error
            #Wmx,Wmy,Wsx,Wsy,s02,K,kappa,T,XX,XXp,Eta,Xi,h1,h2 = parse_W(W)
            Wmx,Wmy,Wsx,Wsy,s02,K,kappa,T,h1,h2,bl,amp = parse_W1(W1)
            XX,XXp,Eta,Xi = parse_W2(W2)
            topo_var_list = [arr.reshape(topo_shape+(-1,)) for arr in \
                    [XX,XXp,Eta,Xi]]
            sqdiffy = [np.sum(np.abs(np.diff(top,axis=0))**2) for top in topo_var_list]
            sqdiffx = [np.sum(np.abs(np.diff(top,axis=1))**2) for top in topo_var_list]
            cost = np.sum(sqdiffy+sqdiffx)
            return cost

        def compute_smi_error(fval,fval12,halo_mult=1,chrimson_mult=1):
            fval = compute_f_(Eta,Xi,s02)
            ipc = 0
            def compute_dsmi(fval):
                fpc = fval[:,ipc].reshape(topo_shape)
                smi = fpc[-1,:]/np.max(fpc,0)
                dsmi = smi[1] - smi[5]
                return dsmi
            dsmis = [compute_dsmi(f) for f in [fval,fval12[:nN],fval12[nN:]]]
            smi_halo_error = halo_mult*(dsmis[1] - dsmis[0])**2
            smi_chrimson_error = chrimson_mult*utils.minus_sum_log_slope(dsmis[2] - dsmis[0],big_val)
            smi_baseline_error = 1*utils.minus_sum_log_slope(dsmis[0],big_val)
            return smi_halo_error,smi_chrimson_error,smi_baseline_error

        #Wmx,Wmy,Wsx,Wsy,s02,K,kappa,T,XX,XXp,Eta,Xi,h1,h2 = parse_W(W)
        Wmx,Wmy,Wsx,Wsy,s02,K,kappa,T,h1,h2,bl,amp = parse_W1(W1)
        XX,XXp,Eta,Xi = parse_W2(W2)

        #utils.print_labeled('T',T)
        #utils.print_labeled('K',K)
        #utils.print_labeled('Wmy',Wmy)
        
        perturbation = perturbation_size*np.random.randn(*Eta.shape)
        
#         fval,fprimeval = compute_f_fprime_t_(W1,W2,perturbation) # Eta the mean input per cell, Xi the stdev. input per cell, s02 the baseline variability in input
        #print('simulate: '+str(simulate))
        if simulate:
            fval,fprimeval = compute_f_fprime_t_avg_(W1,W2,perturbation) # Eta the mean input per cell, Xi the stdev. input per cell, s02 the baseline variability in input
        else:
            fval,fprimeval = compute_f_fprime_(W1,W2) # Eta the mean input per cell, Xi the stdev. input per cell, s02 the baseline variability in input
        fval12,fprimeval12 = compute_f_fprime_t_avg_12_(W1,W2,perturbation) # Eta the mean input per cell, Xi the stdev. input per cell, s02 the baseline variability in input
        #utils.print_labeled('fval',fval)

        bltile = np.tile(bl,nS*nT)[np.newaxis,:]
        
        Xterm = compute_kl_error(XXhat,Xpc_list,XX,XXp,wtStim*wtInp) # XX the modeled input layer (L4)
        Yterm = compute_kl_error(YYhat,Ypc_list,amp*fval+bltile,amp*fprimeval,wtStim*wtCell) # fval the modeled output layer (L2/3)

        Etaterm = compute_sq_error(Eta,u_fn(XX,fval,Wmx,Wmy,K,kappa,T),wtStim*wtCell) # magnitude of fudge factor in mean input
        Xiterm = compute_sq_error(Xi,u_fn(XX,fval,Wsx,Wsy,K,kappa,T),wtStim*wtCell) # magnitude of fudge factor in input variability
        # returns value float
        #Optoterm = compute_opto_error_nonlinear(W) #testing out 8/20/20
        opto_wt = np.concatenate([wtStimOpto*wtCellOpto*w for w in wtDirOpto],axis=0)
        if wtSMI != 0:
            SMIhaloterm,SMIchrimsonterm,SMIbaselineterm = compute_smi_error(fval,fval12,halo_mult=1,chrimson_mult=1)
        else:
            SMIhaloterm,SMIchrimsonterm,SMIbaselineterm = 0,0,0
        if wtdYY != 0:
            if use_opto_transforms:
                dYYterm = compute_opto_error_nonlinear_transform(amp*fval+bltile,amp*fval12+bltile,opto_wt)
            else:
                dYYterm = compute_opto_error_nonlinear(amp*fval+bltile,amp*fval12+bltile,opto_wt)
            Optoterm = wtdYY*dYYterm
        else:
            Optoterm = 0
        cost = wtX*Xterm + wtY*Yterm + wtEta*Etaterm + wtXi*Xiterm + wtOpto*Optoterm + wtSMIhalo*SMIhaloterm + wtSMIchrimson*SMIchrimsonterm + wtSMI*SMIbaselineterm# + wtEtaTV*EtaTVterm 
        if constrain_isn:
            ISNterm = compute_isn_error(W1,W2)
            cost = cost + wtISN*ISNterm
        if constrain_coupling:
            Couplingterm = 0
            for el in coupling_constraints:
                i,j,sgn = el
                Couplingterm = Couplingterm + compute_coupling_error(W1,W2,i,j,sgn)
            cost = cost + wtCoupling*Couplingterm
        if tv:
            TVterm = compute_tv_error(W1,W2)
            cost = cost + wtTV*TVterm

        #print('Yterm as float: '+str(float(Yterm)))
        #print('Yterm as float: '+str(Yterm.astype('float')))
            
        if isinstance(Yterm,float) and verbose:
            print('X:%f'%(wtX*Xterm))
            print('Y:%f'%(wtY*Yterm.astype('float')))
            print('Eta:%f'%(wtEta*Etaterm))
            print('Xi:%f'%(wtXi*Xiterm))
            print('Opto dYY:%f'%(wtOpto*wtdYY*dYYterm))
            #print('Opto Eta:%f'%(wtOpto*wtEta12*Eta12term))
            #print('TV:%f'%(wtEtaTV*EtaTVterm))
            print('TV:%f'%(wtTV*TVterm))
            print('SMI halo:%f'%(wtSMIhalo*SMIhaloterm))
            print('SMI chrimson:%f'%(wtSMIchrimson*SMIchrimsonterm))
            print('SMI baseline:%f'%(wtSMI*SMIbaselineterm))
            if constrain_isn:
                print('ISN:%f'%(wtISN*ISNterm))
            if constrain_coupling:
                print('coupling:%f'%(wtCoupling*Couplingterm))

        #lbls = ['Yterm']
        #vars = [Yterm]
        lbls = ['cost']
        vars = [cost]
        if verbose:
            for lbl,var in zip(lbls,vars):
                utils.print_labeled(lbl,var)
        return cost

    def minusdLdW1(W1,W2,simulate=True,verbose=True): 
        # Gradient of the scalar cost with respect to the W1 parameter vector,
        # holding W2 fixed (autograd reverse-mode differentiation).
        def cost_of_W1(W1_):
            return minusLW(W1_,W2,simulate=simulate,verbose=verbose)
        return grad(cost_of_W1)(W1)
        
    def minusdLdW2(W1,W2,simulate=True,verbose=True): 
        # Gradient of the scalar cost with respect to the W2 parameter vector,
        # holding W1 fixed (autograd reverse-mode differentiation).
        def cost_of_W2(W2_):
            return minusLW(W1,W2_,simulate=simulate,verbose=verbose)
        return grad(cost_of_W2)(W2)

    def fix_violations(w,bounds):
        # Clamp w in place to the box constraints given as (lower, upper)
        # pairs. Returns the clamped vector (same array object) together with
        # boolean masks of which entries violated the lower / upper bound.
        lower, upper = (np.array(side) for side in zip(*bounds))
        lb_violation = w < lower
        ub_violation = w > upper
        w[lb_violation] = lower[lb_violation]
        w[ub_violation] = upper[ub_violation]
        return w,lb_violation,ub_violation
    
    def sorted_r_eigs(w):
        # Eigendecomposition of w with the eigenvalues (and the matching
        # eigenvector columns) arranged in the ascending order produced by
        # np.argsort.
        eigvals, eigvecs = np.linalg.eig(w)
        order = np.argsort(eigvals)
        return eigvals[order], eigvecs[:,order]
    
    def compute_eig_penalty_(Wmy,K0,kappa,T0):
        # still need to finish! Hopefully won't need
        # need to fix this to reflect addition of kappa argument
        # NOTE(review): this function references names not defined in this
        # scope (barrier_wt, W0my, k0); calling it as written raises NameError.
        # It appears to be unfinished/dead code, per the author's comments.
        Wsquig = gen_Weight(Wmy,K0,kappa,T0)
        # eigendecomposition of (W~ - I): eigenvalues near zero indicate the
        # linearized dynamics are close to instability
        drW,prW = sorted_r_eigs(Wsquig - np.eye(nQ*nS*nT))
        plW = np.linalg.pinv(prW)
        # outer products of left/right eigenvectors give each eigenvalue's
        # gradient direction with respect to the weight matrix
        eig_outer_all = [np.real(np.outer(plW[:,k],prW[k,:])) for k in range(nS*nQ*nT)]
        # barrier_wt is undefined here -- TODO confirm intended source
        eig_penalty_size_all = [barrier_wt/np.abs(np.real(drW[k])) for k in range(nS*nQ*nT)]
        eig_penalty_dir_w = [eig_penalty_size*((eig_outer[:nQ,:nQ] + eig_outer[nQ:,nQ:]) + K0[np.newaxis,:]*(eig_outer[:nQ,nQ:] + kappa*eig_outer[nQ:,:nQ])) for eig_outer,eig_penalty_size in zip(eig_outer_all,eig_penalty_size_all)]
        # W0my below is undefined here (parameter is Wmy) -- TODO confirm
        eig_penalty_dir_k = [eig_penalty_size*((eig_outer[:nQ,nQ:] + eig_outer[nQ:,:nQ]*kappa)*W0my).sum(0) for eig_outer,eig_penalty_size in zip(eig_outer_all,eig_penalty_size_all)]
        # k0 below is undefined here (parameter is K0) -- TODO confirm
        eig_penalty_dir_kappa = [eig_penalty_size*(eig_outer[nQ:,:nQ]*k0[np.newaxis,:]*W0my).sum().reshape((1,)) for eig_outer,eig_penalty_size in zip(eig_outer_all,eig_penalty_size_all)]
        eig_penalty_dir_w = np.array(eig_penalty_dir_w).sum(0)
        eig_penalty_dir_k = np.array(eig_penalty_dir_k).sum(0)
        eig_penalty_dir_kappa = np.array(eig_penalty_dir_kappa).sum(0)
        return eig_penalty_dir_w,eig_penalty_dir_k,eig_penalty_dir_kappa
    
    def compute_eig_penalty(W):
        # still need to finish! Hopefully won't need
        # NOTE(review): unfinished/dead code -- parse_W is unpacked into 16
        # values here (elsewhere it is 14), compute_eig_penalty_ takes 4
        # arguments but is called with 3, k0 is undefined, and unparse_W is
        # called with a different arity than elsewhere; calling this will fail.
        W0mx,W0my,W0sx,W0sy,s020,K0,kappa0,T0,XX0,XXp0,Eta0,Xi0,h10,h20,Eta10,Eta20 = parse_W(W)
        eig_penalty_dir_w,eig_penalty_dir_k,eig_penalty_dir_kappa = compute_eig_penalty_(W0my,k0,kappa0)
        eig_penalty_W = unparse_W(np.zeros_like(W0mx),eig_penalty_dir_w,np.zeros_like(W0sx),np.zeros_like(W0sy),np.zeros_like(s020),eig_penalty_dir_k,eig_penalty_dir_kappa,np.zeros_like(XX0),np.zeros_like(XXp0),np.zeros_like(Eta0),np.zeros_like(Xi0))
#         assert(True==False)
        return eig_penalty_W
    
    def optimize1(W10,W20,compute_hessian=False,simulate=True,verbose=True):
        """L-BFGS-B optimization of the weight/gain vector W1 with W2 held fixed.

        Returns (W11, loss, gr, hess, result): the optimized vector, final
        cost, optional gradient/Hessian at the optimum, and the scipy result.
        """
        # l2-regularize only the leading Wmx/Wmy entries of the parameter
        # vector (the first nP*nQ + nQ**2 slots)
        allhot = np.zeros(W10.shape)
        allhot[:nP*nQ+nQ**2] = 1
        W_l2_reg = lambda W: np.sum((W*allhot)**2)
        f = lambda W: minusLW(W,W20,simulate=simulate,verbose=verbose) + l2_penalty*W_l2_reg(W)
        fprime = lambda W: minusdLdW1(W,W20,simulate=simulate,verbose=verbose) + 2*l2_penalty*W*allhot

        # clamp the starting point into its bounds (mutates W10 in place)
        fix_violations(W10,bounds1)
        
        #W11,loss,result = sop.fmin_l_bfgs_b(f,W10,fprime=fprime,bounds=bounds1,factr=1e2,maxiter=int(1e3),maxls=40)
        options = {}
        # BUG FIX: sop.minimize's L-BFGS-B backend takes 'ftol', not 'factr'
        # (the legacy fmin_l_bfgs_b name); the old 'factr' key was silently
        # ignored. ftol = factr * machine epsilon.
        options['ftol'] = 1e2*np.finfo(float).eps
        options['maxiter'] = int(1e3)
        options['maxls'] = 40
        result = sop.minimize(f,W10,jac=fprime,bounds=bounds1,options=options,method='L-BFGS-B')
        W11 = result.x
        loss = result.fun
        if compute_hessian:
            # BUG FIX: previously referenced undefined names W1/W2 (NameError);
            # evaluate at the optimized W11 with the fixed W20.
            gr = grad(lambda W1: minusLW(W1,W20,simulate=simulate,verbose=verbose))(W11)
            hess = hessian(lambda W1: minusLW(W1,W20,simulate=simulate,verbose=verbose))(W11)
        else:
            gr = None
            hess = None
        
        return W11,loss,gr,hess,result
    
    def optimize2(W10,W20,compute_hessian=False,simulate=False,verbose=True): 
        """Conjugate-gradient optimization of the latent-input vector W2 with W1 fixed.

        Entries whose bounds pin them to 0 or 1 are frozen at those values and
        excluded from the optimization; only the remaining ("relevant") entries
        are passed to the CG solver. Returns (W21, loss, gr, hess, result).
        """
        # entries pinned by their bounds to exactly 0 or exactly 1
        to_zero = np.array([(b[0]==0)&(b[1]==0) for b in bounds2])
        to_one = np.array([(b[0]==1)&(b[1]==1) for b in bounds2])
        relevant = ~to_zero  & ~to_one

        W21 = W20.copy()#np.zeros_like(W20)
        W21[to_zero] = 0
        W21[to_one] = 1

        def f(w):
            # rebuild the full W2 vector around the free entries w
            W = np.zeros_like(W20)
            W[to_one] = 1
            W[relevant] = w
            return minusLW(W10,W,simulate=simulate,verbose=verbose)

        def fprime(w):
            W = np.zeros_like(W20)
            W[to_one] = 1
            W[relevant] = w
            return minusdLdW2(W10,W,simulate=simulate,verbose=verbose)[relevant]

        w20 = W20[relevant]
        #w21,loss,result = sop.fmin_cg(f,w20,fprime=fprime)
        options = {}
        options['gtol'] = 1e-1
        result = sop.minimize(f,w20,jac=fprime,options=options,method='CG')
        w21 = result.x
        loss = result.fun

        W21[relevant] = w21

        if compute_hessian:
            # BUG FIX: previously referenced undefined names W1/W2 (NameError);
            # evaluate at the optimized W21 with the fixed W10.
            gr = grad(lambda W2: minusLW(W10,W2,simulate=simulate,verbose=verbose))(W21)
            hess = hessian(lambda W2: minusLW(W10,W2,simulate=simulate,verbose=verbose))(W21)
        else:
            gr = None
            hess = None
        
        return W21,loss,gr,hess,result

    def optimize2_stimwise(W10,W20,compute_hessian=False,simulate=False,verbose=True): 
        """Like optimize2, but optimizes the latent inputs one stimulus at a time.

        For each of the nN stimuli, only that stimulus's rows of the W2
        variables are freed (minus entries pinned to 0/1 by their bounds) and
        optimized by CG while everything else is held at its current value.
        Returns (W21, loss, gr, hess, result); loss/result reflect the last
        stimulus's sub-problem only.
        """
        # entries pinned by their bounds to exactly 0 or exactly 1
        to_zero = np.array([(b[0]==0)&(b[1]==0) for b in bounds2])
        to_one = np.array([(b[0]==1)&(b[1]==1) for b in bounds2])

        W21 = W20.copy()#np.zeros_like(W20)
        W21[to_zero] = 0
        W21[to_one] = 1
        for istim in range(nN):
            print('on stimulus #%d'%istim)
            # boolean mask selecting this stimulus's rows across all W2 variables
            in_this_stim_list = [np.zeros(shp,dtype='bool') for shp in shapes2]
            for ivar in range(len(shapes2)):
                in_this_stim_list[ivar][istim] = True
            in_this_stim = unparse_W(*in_this_stim_list)
            relevant = ~to_zero  & ~to_one & in_this_stim

            def f(w):
                # rebuild the full W2 vector: frozen entries from W21, free
                # entries from the optimizer's w
                W = np.zeros_like(W20)
                W[~relevant] = W21[~relevant]
                W[relevant] = w
                return minusLW(W10,W,simulate=simulate,verbose=verbose)

            def fprime(w):
                W = np.zeros_like(W20)
                W[~relevant] = W21[~relevant]
                W[relevant] = w
                return minusdLdW2(W10,W,simulate=simulate,verbose=verbose)[relevant]

            w20 = W20[relevant]
            #w21,loss,result = sop.fmin_cg(f,w20,fprime=fprime)
            options = {}
            options['gtol'] = 1e-2
            result = sop.minimize(f,w20,jac=fprime,options=options,method='CG')
            w21 = result.x
            loss = result.fun
            print('sum of relevant: '+str(relevant.sum()))
            W21[relevant] = w21

        if compute_hessian:
            # BUG FIX: the lambdas previously captured an undefined name W1
            # (NameError); the fixed weight vector here is W10.
            gr = grad(lambda W2: minusLW(W10,W2,simulate=simulate,verbose=verbose))(W21)
            hess = hessian(lambda W2: minusLW(W10,W2,simulate=simulate,verbose=verbose))(W21)
        else:
            gr = None
            hess = None
        
        return W21,loss,gr,hess,result
    
    # --- main optimization driver -----------------------------------------
    # Flatten the initial parameter lists into single vectors.
    W10 = unparse_W(*W10list)
    W20 = unparse_W(*W20list)

    #simulate1,simulate2 = True,False
    # NOTE(review): verbose, stimwise, simulate1 and simulate2 are assumed to
    # be defined earlier in this function (not visible here) -- TODO confirm.
    verbose1,verbose2 = verbose,verbose

    old_loss = np.inf
    # Initial pass: optimize the latent inputs (W2) given the initial weights,
    # then the weights/gains (W1) given the updated latents.
    if stimwise:
        W21,loss,gr,hess,result = optimize2_stimwise(W10,W20,compute_hessian=compute_hessian,simulate=simulate2,verbose=verbose2)
    else:
        W21,loss,gr,hess,result = optimize2(W10,W20,compute_hessian=compute_hessian,simulate=simulate2,verbose=verbose2)
    W11,loss,gr,hess,result = optimize1(W10,W21,compute_hessian=compute_hessian,simulate=simulate1,verbose=verbose1)

    #W11,loss,gr,hess,result = optimize1(W10,W20,compute_hessian=compute_hessian,simulate=simulate1,verbose=verbose1)
    #if stimwise:
    #    W21,loss,gr,hess,result = optimize2_stimwise(W11,W20,compute_hessian=compute_hessian,simulate=simulate2,verbose=verbose2)
    #else:
    #    W21,loss,gr,hess,result = optimize2(W11,W20,compute_hessian=compute_hessian,simulate=simulate2,verbose=verbose2)

    # Alternate W2- and W1-optimization passes until the loss improves by
    # less than 0.1 between rounds.
    delta = old_loss - loss
    while delta > 0.1:
        old_loss = loss
        #W11,loss,gr,hess,result = optimize1(W10,W20,compute_hessian=compute_hessian,simulate=True)
        #W21,loss,gr,hess,result = optimize2(W11,W20,compute_hessian=compute_hessian,simulate=False)

        if stimwise:
            W21,loss,gr,hess,result = optimize2_stimwise(W11,W21,compute_hessian=compute_hessian,simulate=simulate2,verbose=verbose2)
        else:
            W21,loss,gr,hess,result = optimize2(W11,W21,compute_hessian=compute_hessian,simulate=simulate2,verbose=verbose2)
        W11,loss,gr,hess,result = optimize1(W11,W21,compute_hessian=compute_hessian,simulate=simulate1,verbose=verbose1)

        #W11,loss,gr,hess,result = optimize1(W11,W21,compute_hessian=compute_hessian,simulate=simulate1,verbose=verbose1)
        #if stimwise:
        #    W21,loss,gr,hess,result = optimize2_stimwise(W11,W21,compute_hessian=compute_hessian,simulate=simulate2,verbose=verbose2)
        #else:
        #    W21,loss,gr,hess,result = optimize2(W11,W21,compute_hessian=compute_hessian,simulate=simulate2,verbose=verbose2)

        delta = old_loss - loss

        
    # Unflatten the optimized vectors back into their per-variable lists.
    W1t = parse_W1(W11) #[Wmx,Wmy,Wsx,Wsy,s02,K,kappa,T,h1,h2,bl,amp]
    W2t = parse_W2(W21) #[XX,XXp,Eta,Xi]
    
    return W1t,W2t,loss,gr,hess,result