def init_saem_grad(As, p, n_iter=10, step=0.1, setting="gaussian"):
    """
    Initialization by gradient ascent on X and a Metropolis-Hastings chain on
    lambda, starting from the output of init_saem.
    - setting can be set to "binary" to handle binary networks
    The function returns the estimated parameters, the latent variables
    (Xs, ls) and the running log-likelihood.
    """
    # In the binary setting, rebind the module-level model to model_bin
    global model
    if setting == "binary":
        model = model_bin
    n_samples, n, _ = As.shape
    theta, _, _ = init_saem(As, p)
    F, mu, sigma, sigma_l = theta
    # Override the initial noise levels with neutral values
    sigma = 1
    sigma_l = 1
    mode = st.proj_V(F)
    Xs = np.array([mode.copy() for _ in range(n_samples)])
    ls = mu[None, :] + sigma_l * np.random.randn(n_samples, p)
    lks = []
    it = trange(n_iter)
    prop_l = 1
    current_log_lk = np.array([
        model.log_lk_partial(Xs[i], ls[i], As[i], theta)
        for i in range(n_samples)
    ])
    for t in it:
        mode = st.proj_V(F)
        for _ in range(10):
            for i in range(n_samples):
                if t % 5 == 0:
                    m, s = st.greedy_permutation(mode, Xs[i])
                    Xs[i] = s * Xs[i][:, m]
                    ls[i] = ls[i][m]

                grad_X = model.log_lk_partial_grad_X(Xs[i], ls[i], As[i],
                                                     theta)
                grad_X = grad_X / norm(grad_X)
                Xs[i] = st.proj_V(Xs[i] + step * grad_X)

                # [l] Generate next move
                l2 = ls[i] + prop_l * np.random.randn(p)
                # [l] Compute the acceptance log-probability
                new_log_lk = model.log_lk_partial(Xs[i], l2, As[i], theta)
                log_alpha = new_log_lk - current_log_lk[i]
                # [l] Accept or reject
                if np.log(np.random.rand()) < log_alpha:
                    ls[i] = l2
                    current_log_lk[i] = new_log_lk

        F = vmf.mle(Xs.mean(axis=0))
        mu = ls.mean(axis=0)
        # sigma and sigma_l are standard deviations: take the square root of
        # the mean squared residuals
        sigma = np.sqrt(((As - st.comp_numba_many(Xs, ls))**2).mean())
        sigma_l = np.sqrt(((ls - mu)**2).mean())
        theta = (F, mu, sigma, sigma_l)
        lks.append(model.log_lk(Xs, ls, As, theta, normalized=True))
        it.set_postfix({"lk": lks[-1]})
    return theta, Xs, ls, lks
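

# A minimal usage sketch for init_saem_grad (illustrative only: the random
# symmetric matrices and the _demo_* helper are assumptions, not part of the
# original pipeline).
def _demo_init_saem_grad(n_samples=20, n=15, p=3, seed=0):
    rng = np.random.default_rng(seed)
    B = rng.standard_normal((n_samples, n, n))
    As = (B + B.transpose(0, 2, 1)) / 2  # symmetric "adjacency" matrices
    theta, Xs, ls, lks = init_saem_grad(As, p, n_iter=20, step=0.1)
    return theta, lks
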
def map_mask(A, mask, theta, n_iter):
    """
    Given a set of coefficients of A, finds the MAP estimator of the
    remaining hidden coefficients and the latent variables (X, l).
    - mask is the set of unknown coefficients, given as two arrays of x and
      y indices
    The function returns the final values of A, X and l, together with the
    running log-likelihood.
    """
    F, mu, sigma, sigma_l = theta
    mx, my = mask
    n, p = F.shape

    A = A.copy()
    X = st.proj_V(F)
    l = mu.copy()
    # Posterior standard deviation for lambda:
    posterior_std_l = np.sqrt(1/(1/sigma**2 + 1/sigma_l**2))
    lks = np.zeros(n_iter)
    
    it = range(n_iter)
    for t in it:
        step = 1/(2*t+1)
        
        # [A] Explicit maximum on A
        comp = st.comp_numba_single(X, l)
        for i in range(len(mx)):
            A[mx[i], my[i]] = comp[mx[i],my[i]]
            A[my[i], mx[i]] = comp[mx[i],my[i]]
        
        # [X] Sample on X
        grad_X = model.log_lk_partial_grad_X(X, l, A, theta)
        grad_X = grad_X/norm(grad_X)
        X = st.proj_V(X + step*grad_X)

        # [l] Explicit maximum on lambda
        v = np.diag(X.T@A@X)
        l = (posterior_std_l**2)*(v/sigma**2 + mu/sigma_l**2)
        
        lks[t] = model.log_lk_partial(X, l, A, theta)
        
    return A, X, l, lks
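

# Usage sketch for map_mask (hypothetical: hides a few upper-triangular
# entries; map_mask itself writes both symmetric coefficients back).
def _demo_map_mask(A, theta, n_hidden=5, n_iter=200, seed=0):
    rng = np.random.default_rng(seed)
    iu = np.triu_indices(A.shape[0], k=1)
    pick = rng.choice(len(iu[0]), size=n_hidden, replace=False)
    mask = (iu[0][pick], iu[1][pick])
    A_hat, X_hat, l_hat, lks = map_mask(A, mask, theta, n_iter)
    return A_hat, lks
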
def init_saem_grad_cluster(As, p, K, n_iter=10, step=0.1, setting="gaussian"):
    """
    Clustered variant of init_saem_grad: the networks are first grouped into
    K clusters with K-means, and the parameters of each cluster are
    initialized separately.
    - setting can be set to "binary" to handle binary networks
    The function returns the estimated parameters, the latent variables
    (Xs, ls), the cluster labels zs and the running log-likelihood.
    """
    n_samples, n, _ = As.shape
    kmeans = KMeans(n_clusters=K).fit(As.reshape(n_samples, -1))
    zs = kmeans.labels_

    F = np.zeros((K, n, p))
    mu = np.zeros((K, p))
    sigma = np.zeros(K)
    sigma_l = np.zeros(K)
    pi = np.bincount(zs) / n_samples
    for k in range(K):
        idx = np.where(zs == k)[0]
        (F[k], mu[k], sigma[k], sigma_l[k]), _, _ = init_saem(As[idx], p)

    mode = [st.proj_V(F[k]) for k in range(K)]
    Xs = np.array([mode[zs[i]].copy() for i in range(n_samples)])
    ls = mu[zs]

    lks = []
    prop_l = 1
    it = trange(n_iter)
    mdl = model_bin if setting == "binary" else model
    current_log_lk = np.array([
        mdl.log_lk_partial(
            Xs[i], ls[i], As[i],
            (F[zs[i]], mu[zs[i]], sigma[zs[i]], sigma_l[zs[i]]))
        for i in range(n_samples)
    ])
    for t in it:
        mode = [st.proj_V(F[k]) for k in range(K)]
        for _ in range(10):
            for i in range(n_samples):
                # k must be set before the permutation step uses mode[k]
                k = zs[i]
                theta = (F[k], mu[k], sigma[k], sigma_l[k])
                if t % 5 == 0:
                    m, s = st.greedy_permutation(mode[k], Xs[i])
                    Xs[i] = s * Xs[i][:, m]
                    ls[i] = ls[i][m]

                if setting == "gaussian":
                    grad_X = model.log_lk_partial_grad_X(
                        Xs[i], ls[i], As[i], theta)
                elif setting == "binary":
                    grad_X = model_bin.log_lk_partial_grad_X(
                        Xs[i], ls[i], As[i], theta)
                grad_X = grad_X / norm(grad_X)
                Xs[i] = st.proj_V(Xs[i] + step * grad_X)

                # [l] Generate next move
                l2 = ls[i] + prop_l * np.random.randn(p)
                # [l] Compute the acceptance log-probability
                if setting == "gaussian":
                    new_log_lk = model.log_lk_partial(Xs[i], l2, As[i], theta)
                elif setting == "binary":
                    new_log_lk = model_bin.log_lk_partial(
                        Xs[i], l2, As[i], theta)
                log_alpha = new_log_lk - current_log_lk[i]
                # [l] Accept or reject
                if np.log(np.random.rand()) < log_alpha:
                    ls[i] = l2
                    current_log_lk[i] = new_log_lk

        for k in range(K):
            idx = np.where(zs == k)[0]
            F[k] = vmf.mle(Xs[idx].mean(axis=0))
            mu[k] = ls[idx].mean(axis=0)
            # sigma and sigma_l are standard deviations: take the square
            # root of the mean squared residuals
            sigma[k] = np.sqrt(
                ((As[idx] - st.comp_numba_many(Xs[idx], ls[idx]))**2).mean())
            sigma_l[k] = np.sqrt(((ls[idx] - mu[k])**2).mean())

        if setting == "gaussian":
            lks.append(
                model_cluster.log_lk(Xs,
                                     ls,
                                     zs,
                                     As, (F, mu, sigma, sigma_l, pi),
                                     normalized=True))
        elif setting == "binary":
            lks.append(
                model_cluster.log_lk(Xs,
                                     ls,
                                     zs,
                                     As, (F, mu, sigma, sigma_l, pi),
                                     normalized=True))

        it.set_postfix({"lk": lks[-1]})
    return (F, mu, sigma, sigma_l, pi), Xs, ls, zs, lks
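

# Usage sketch for the clustered initialization (the synthetic data and the
# choice K=2 are illustrative assumptions).
def _demo_init_saem_grad_cluster(n_samples=30, n=15, p=3, K=2, seed=0):
    rng = np.random.default_rng(seed)
    B = rng.standard_normal((n_samples, n, n))
    As = (B + B.transpose(0, 2, 1)) / 2
    res = init_saem_grad_cluster(As, p, K, n_iter=10)
    (F, mu, sigma, sigma_l, pi), Xs, ls, zs, lks = res
    return zs, lks
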
def mh(As, theta, n_iter, init=None, prop_X=0.01, prop_l=0.5, setting="gaussian"):
    """
    Metropolis within Gibbs sampler for the base model.
    - setting can be set to "binary" to handle binary networks
    - prop_X and prop_l are the proposal variances for X and l
    The function returns the final values of X and l, the final total
    log-likelihood, and the mean acceptance rates of the two chains.
    """
    F, mu, sigma, sigma_l = theta
    n_samples = As.shape[0]
    accepts_X = np.zeros((n_iter, n_samples))
    accepts_l = np.zeros((n_iter, n_samples))
    n, p = F.shape[-2:]
    if init is None:
        mode = st.proj_V(F)
        Xs = np.array([mode.copy() for _ in range(n_samples)])
        ls = mu[None, :] + sigma_l*np.random.randn(n_samples, p)
    else:
        Xs, ls = init
        Xs = Xs.copy()
        ls = ls.copy()
    
    if setting=="gaussian":
        current_log_lk = np.array([model.log_lk_partial(Xs[i], ls[i], As[i], theta) for i in range(n_samples)])
    elif setting=="binary":
        current_log_lk = np.array([model_bin.log_lk_partial(Xs[i], ls[i], As[i], theta) for i in range(n_samples)])
    
    for t in range(n_iter):
        for i in range(n_samples):
            # [X] Generate next move
            D = prop_X*np.random.randn(n,p)
            X2 = st.proj_V(Xs[i] + D)
            # [X] Compute the acceptance log-probability
            if setting=="gaussian":
                new_log_lk = model.log_lk_partial(X2, ls[i], As[i], theta)
            elif setting=="binary":
                new_log_lk = model_bin.log_lk_partial(X2, ls[i], As[i], theta)
            log_alpha = new_log_lk - current_log_lk[i]
            # [X] Accept or reject
            if np.log(np.random.rand()) < log_alpha:
                Xs[i] = X2
                current_log_lk[i] = new_log_lk
                accepts_X[t,i] = 1
            else:
                accepts_X[t,i] = 0
            
            # [l] Generate next move
            l2 = ls[i] + prop_l*np.random.randn(p)
            # [l] Compute the acceptance log-probability
            if setting=="gaussian":
                new_log_lk = model.log_lk_partial(Xs[i], l2, As[i], theta)
            elif setting=="binary":
                new_log_lk = model_bin.log_lk_partial(Xs[i], l2, As[i], theta)
            log_alpha = new_log_lk - current_log_lk[i]
            # [l] Accept or reject
            if np.log(np.random.rand()) < log_alpha:
                ls[i] = l2
                current_log_lk[i] = new_log_lk
                accepts_l[t,i] = 1
            else:
                accepts_l[t,i] = 0
            
    return Xs, ls, current_log_lk.sum(), accepts_X.mean(), accepts_l.mean()
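

# Usage sketch for the Metropolis-within-Gibbs sampler, e.g. starting from
# the output of init_saem_grad (the proposal variances are illustrative).
def _demo_mh(As, theta, n_iter=1000):
    Xs, ls, log_lk, rate_X, rate_l = mh(As, theta, n_iter,
                                        prop_X=0.01, prop_l=0.5)
    print("final log-likelihood:", log_lk)
    print("acceptance rates:", rate_X, rate_l)
    return Xs, ls
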
def mh_mask(A, mask, theta, n_iter, init=None, progress=True, prop_X=0.02):
    """
    Given a set of coefficients of A, runs a MCMC chain to sample from the
    remaining hidden coefficients and the latent variables (X, l).
    - mask is the set of unknown coefficients, given as two arrays of x and
      y indices
    - prop_X is the initial proposal variance
    The function returns the arrays of values of A, X and l along the MCMC,
    together with the running log-likelihood.
    """
    A_init = A.copy()
    F, mu, sigma, sigma_l = theta
    mx, my = mask
    accepts_X = np.zeros(n_iter)
    n, p = F.shape
    batch = 50
    optimal_rate = 0.234
    if init is None:
        X = st.proj_V(F)
        l = mu.copy()
    else:
        A, X, l = init
    
    # Posterior standard deviation for lambda:
    posterior_std_l = np.sqrt(1/(1/sigma**2 + 1/sigma_l**2))
    # Per-column norms of F, used to scale the proposal steps on X
    sv_F = np.array([norm(F[:, i]) for i in range(p)])
    lks = np.zeros(n_iter)
    A_mh = np.zeros((n_iter, n, n))
    X_mh = np.zeros((n_iter, n, p))
    l_mh = np.zeros((n_iter, p))
    
    it = trange(n_iter) if progress else range(n_iter)
    for t in it:
        lks[t] = model.log_lk_partial(X, l, A, theta)
        # [A] Sample the hidden coefficients given (X, l)
        A2 = A_init.copy()
        comp = st.comp_numba_single(X, l)
        for i in range(len(mx)):
            eps = sigma*np.sqrt(2)*np.random.randn()
            A2[mx[i], my[i]] = comp[mx[i], my[i]] + eps
        # Symmetrize; if the mask lists both (i, j) and (j, i), averaging two
        # independent draws leaves each hidden entry with standard deviation
        # sigma
        A = (A2+A2.T)/2
        
        # [X] Generate next move
        D = prop_X*np.random.randn(n,p)/sv_F
        X2 = st.proj_V(X + D)
        # [X] Compute the acceptance log-probability
        current_log_lk = model.log_lk_partial(X, l, A, theta)
        new_log_lk = model.log_lk_partial(X2, l, A, theta)
        # The factor 100 sharpens the acceptance ratio, making the chain
        # behave almost like a greedy ascent on X
        log_alpha = (new_log_lk - current_log_lk) * 100
        # [X] Accept or reject
        if np.log(np.random.rand()) < log_alpha:
            X = X2
            current_log_lk = new_log_lk
            accepts_X[t] = 1
        else:
            accepts_X[t] = 0

        # [l] Set lambda to its conditional posterior mean
        v = np.diag(X.T@A@X)
        l = (posterior_std_l**2)*(v/sigma**2 + mu/sigma_l**2)
        
        A_mh[t] = A
        X_mh[t] = X
        l_mh[t] = l
        
        # Adaptively tune the proposal variance towards the optimal
        # acceptance rate
        if t%batch==0 and t>1:
            rate_X = accepts_X[max(0, t-batch):t+1].mean()
            adaptive_X = 2*(rate_X > optimal_rate)-1
            prop_X = np.exp(np.log(prop_X) + 0.5*adaptive_X/np.sqrt(1+n))
        
    return A_mh, X_mh, l_mh, lks
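

# Usage sketch for mh_mask: the hidden coefficients can be estimated by
# averaging the chain after a burn-in period (the 50% burn-in below is an
# arbitrary illustrative choice).
def _demo_mh_mask(A, mask, theta, n_iter=2000):
    A_mh, X_mh, l_mh, lks = mh_mask(A, mask, theta, n_iter)
    A_post = A_mh[n_iter // 2:].mean(axis=0)  # posterior mean after burn-in
    return A_post, lks
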
def mala(As, theta, n_iter, init=None, progress=True,
         prop_X=0.01, prop_l=0.5, setting="gaussian"):
    """
    Metropolis Adjusted Langevin Algorithm sampler for the base model.
    - setting can be set to "binary" to handle binary networks
    - prop_X and prop_l are the proposal variances for X and l
    The function returns the final values of X and l, the final total
    log-likelihood, and the mean acceptance rates of the two chains.
    """
    F, mu, sigma, sigma_l = theta
    n_samples = As.shape[0]
    accepts_X = np.zeros((n_iter, n_samples))
    accepts_l = np.zeros((n_iter, n_samples))
    n, p = F.shape[-2:]
    if init is None:
        mode = st.proj_V(F)
        Xs = np.array([mode.copy() for _ in range(n_samples)])
        ls = mu[None,:] + sigma_l*np.random.randn(n_samples, p)
    else:
        Xs, ls = init
        Xs = Xs.copy()
        ls = ls.copy()
    
    if setting=="gaussian":
        current_log_lk = np.array([model.log_lk_partial(Xs[i], ls[i], As[i], theta) for i in range(n_samples)])
    elif setting=="binary":
        current_log_lk = np.array([model_bin.log_lk_partial(Xs[i], ls[i], As[i], theta) for i in range(n_samples)])
    
    step_X = 0.5*prop_X**2
    step_l = 0.5*prop_l**2
    if setting=="gaussian":
        current_grad_X  = np.array([model.log_lk_partial_grad_X(Xs[i], ls[i], As[i], theta) for i in range(n_samples)])
    elif setting=="binary":
        current_grad_X  = np.array([model_bin.log_lk_partial_grad_X(Xs[i], ls[i], As[i], theta) for i in range(n_samples)])
    current_grad_X = np.array([g/norm(g) for g in current_grad_X])
    current_drift_X = np.array([st.proj_V(Xs[i] + step_X*current_grad_X[i]) for i in range(n_samples)])
    if setting=="gaussian":
        current_grad_lambda  = np.array([model.log_lk_partial_grad_lambda(Xs[i], ls[i], As[i], theta) for i in range(n_samples)])
    elif setting=="binary":
        current_grad_lambda  = np.array([model_bin.log_lk_partial_grad_lambda(Xs[i], ls[i], As[i], theta) for i in range(n_samples)])
    current_grad_lambda = np.array([g/norm(g) for g in current_grad_lambda])
    
    it = trange(n_iter) if progress else range(n_iter)
    for t in it:
        for i in range(n_samples):
            # [X] Generate next move
            D = prop_X*np.random.randn(n,p)
            grad_X = current_grad_X[i]
            drift_X = current_drift_X[i]
            D += step_X * grad_X
            X2 = st.proj_V(Xs[i] + D)
            if setting=="gaussian":
                grad_X2 = model.log_lk_partial_grad_X(X2, ls[i], As[i], theta)
            elif setting=="binary":
                grad_X2 = model_bin.log_lk_partial_grad_X(X2, ls[i], As[i], theta)
            grad_X2 = grad_X2/norm(grad_X2)
            drift_X2 = st.proj_V(X2 + step_X*grad_X2)
            mala_jump = (-st.discr(Xs[i], drift_X2) + st.discr(X2, drift_X)) / (2*prop_X**2)
            # [X] Compute the acceptance log-probability
            if setting=="gaussian":
                new_log_lk = model.log_lk_partial(X2, ls[i], As[i], theta)
            elif setting=="binary":
                new_log_lk = model_bin.log_lk_partial(X2, ls[i], As[i], theta)
            log_alpha = new_log_lk - current_log_lk[i] + mala_jump
            # [X] Accept or reject
            if np.log(np.random.rand()) < log_alpha:
                Xs[i] = X2
                current_log_lk[i] = new_log_lk
                accepts_X[t,i] = 1
                current_grad_X[i] = grad_X2
                current_drift_X[i] = drift_X2
                if setting=="gaussian":
                    g = model.log_lk_partial_grad_lambda(Xs[i], ls[i], As[i], theta)
                elif setting=="binary":
                    g = model_bin.log_lk_partial_grad_lambda(Xs[i], ls[i], As[i], theta)
                current_grad_lambda[i] = g/norm(g)
            else:
                accepts_X[t,i] = 0
            
            # [l] Generate next move
            l2 = ls[i] + prop_l*np.random.randn(p)
            grad_l = current_grad_lambda[i]
            l2 += step_l * grad_l
            if setting=="gaussian":
                grad_l2 = model.log_lk_partial_grad_lambda(Xs[i], l2, As[i], theta)
            elif setting=="binary":
                grad_l2 = model_bin.log_lk_partial_grad_lambda(Xs[i], l2, As[i], theta)
            grad_l2 = grad_l2/norm(grad_l2)
            mala_jump = (-norm(ls[i]-l2-step_l*grad_l2)**2
                         + norm(l2-ls[i]-step_l*grad_l)**2) / (2*prop_l**2)
            # [l] Compute the acceptance log-probability
            if setting=="gaussian":
                new_log_lk = model.log_lk_partial(Xs[i], l2, As[i], theta)
            elif setting=="binary":
                new_log_lk = model_bin.log_lk_partial(Xs[i], l2, As[i], theta)
            log_alpha = new_log_lk - current_log_lk[i] + mala_jump
            # [l] Accept or reject
            if np.log(np.random.rand()) < log_alpha:
                ls[i] = l2
                current_log_lk[i] = new_log_lk
                accepts_l[t,i] = 1
                current_grad_lambda[i] = grad_l2
                if setting=="gaussian":
                    g = model.log_lk_partial_grad_X(Xs[i], ls[i], As[i], theta)
                elif setting=="binary":
                    g = model_bin.log_lk_partial_grad_X(Xs[i], ls[i], As[i], theta)
                current_grad_X[i] = g/norm(g)
                current_drift_X[i] = st.proj_V(Xs[i] + step_X*current_grad_X[i])
            else:
                accepts_l[t,i] = 0
            
            
        if progress: it.set_postfix({"log_lk": current_log_lk.sum()})
    if progress: print("Acceptance rates", accepts_X.mean(), accepts_l.mean())
        
    return Xs, ls, current_log_lk.sum(), accepts_X.mean(), accepts_l.mean()
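

# Usage sketch for the MALA sampler (the proposal variances below are
# illustrative); the Langevin drift usually allows higher acceptance rates
# than mh at a given proposal variance.
def _demo_mala(As, theta, n_iter=500):
    Xs, ls, log_lk, rate_X, rate_l = mala(As, theta, n_iter,
                                          prop_X=0.01, prop_l=0.5)
    print("acceptance rates:", rate_X, rate_l)
    return Xs, ls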