def plot_LatentDistanceModel(W, L, N, L_true=None, ax=None):
    """
    Plot a 2-D latent embedding of ``N`` nodes and the weighted edges between them.

    Only the first two columns of ``L`` are drawn, so this is intended for
    D == 2 embeddings. Every ordered pair ``(n1, n2)`` with ``n1 != n2`` gets a
    line segment colored by ``W[n1, n2]`` on a diverging "RdBu" scale that is
    symmetric around 0. Both directions of a pair are therefore drawn on top of
    each other.

    :param W:       2-D weight array indexed as ``W[n1, n2]``; used for edge colors.
    :param L:       2-D array of node locations; columns 0 and 1 are plotted.
    :param N:       number of nodes (rows of ``L`` / ``W``) to draw.
    :param L_true:  If given, rotate ``L`` (via ``compute_optimal_rotation``)
                    to best match ``L_true`` before plotting.
    :param ax:      Axes to draw into; a new equal-aspect figure is created if None.
    :return:        the Axes that was drawn into.
    """
    # Color the edges by weight. ``plt.get_cmap`` works on all matplotlib
    # versions, unlike ``matplotlib.cm.get_cmap`` (removed in 3.9).
    cmap = plt.get_cmap("RdBu")

    # Map weights into [0, 1] symmetrically so W == 0 sits at the neutral
    # middle of the diverging colormap.
    W_lim = np.abs(W).max()
    if W_lim == 0:
        # All-zero weights would otherwise divide by zero and yield NaN colors.
        W_lim = 1.0
    W_rel = (W + W_lim) / (2.0 * W_lim)

    if ax is None:
        fig = plt.figure()
        ax = fig.add_subplot(111, aspect="equal")

    # If true locations are given, rotate L to match L_true
    if L_true is not None:
        R = compute_optimal_rotation(L, L_true)
        L = L.dot(R)

    # Plot the edges between nodes first so the node markers end up on top.
    for n1 in range(N):
        for n2 in range(N):
            if n1 == n2:
                continue  # a self-pair is a zero-length segment; nothing to draw
            ax.plot([L[n1, 0], L[n2, 0]], [L[n1, 1], L[n2, 1]],
                    '-',
                    color=cmap(W_rel[n1, n2]),
                    lw=1.0)

    # Scatter plot the node embeddings
    ax.plot(L[:, 0],
            L[:, 1],
            's',
            color='k',
            markerfacecolor='k',
            markeredgecolor='k')

    # Symmetric axis bound: largest coordinate magnitude plus half a std of padding.
    b = np.amax(np.abs(L)) + L.std() / 2.0
    if b == 0:
        # Degenerate embedding (all zeros): avoid identical lower/upper limits.
        b = 1.0

    # Plot grids for origin
    ax.plot([0, 0], [-b, b], ':k', lw=0.5)
    ax.plot([-b, b], [0, 0], ':k', lw=0.5)

    # Set the limits
    ax.set_xlim([-b, b])
    ax.set_ylim([-b, b])

    # Labels
    ax.set_xlabel('Latent Dimension 1')
    ax.set_ylabel('Latent Dimension 2')
    plt.show()

    return ax
示例#2
0
# Sample a Bernoulli(P[i0]) adjacency matrix and mask the weights with it.
A = 1 * (np.random.rand(N, N) < P[i0])

W = A * W

new_data = dict(N=N, W=W, A=A)
fit = sm.sampling(data=new_data, iter=1000, chains=4)

samples = fit.extract(permuted=True)
L_estimate_all = samples['l']
p_estimate_all = samples['p']
eta_estimate_all = samples['eta']
rho_estimate_all = samples['rho']

# Rotate every posterior draw onto the true embedding L so draws share a
# common frame before averaging. Iterate over the actual number of draws
# rather than a hard-coded 2000. That constant presumably matched
# 4 chains x 500 post-warmup draws, and it silently breaks if iter/chains change.
for i in range(L_estimate_all.shape[0]):
    R = compute_optimal_rotation(L_estimate_all[i, :, :], L)
    L_estimate_all[i, :, :] = np.dot(L_estimate_all[i, :, :], R)

L_estimate = np.mean(L_estimate_all, 0)
sns.heatmap(W, ax=axs[i0, 0])
sns.heatmap(A, ax=axs[i0, 1])
sns.kdeplot(samples['p'], ax=axs[i0, 2])
# Mark the true edge probability with a line that always spans the full
# height of the axes. The old vlines(..., 0, 10) used an arbitrary height and
# forced the y-limit up to at least 10.
axs[i0, 2].axvline(P[i0], color="r", linestyle="dashed")
axs[i0, 3].scatter(L[:, 0], L[:, 1])

from hips.plotting.colormaps import harvard_colors

color = harvard_colors()[0:10]
# NOTE(review): this copy is truncated mid-statement. The scatter(...) call
# below is never closed, and the following `W[n, m] = ...` line appears to
# belong to a different snippet whose enclosing loops are not shown here.
# Presumably this loop plots the last 50 aligned posterior draws of each
# node i; the remaining arguments (possibly using `color`) are missing.
# TODO restore from the original source.
for i in range(N):
    axs[i0, 4].scatter(L_estimate_all[-50:, i, 0],
                       L_estimate_all[-50:, i, 1],
        W[n, m] = npr.multivariate_normal(Mu[n, m], Sig[n, m])

# Initial value for sigma, passed into each sampler run via `init`.
aa = 1.0
bb = 1.0
cc = 1.0

# Load the pre-compiled model. Use a context manager so the file handle is
# closed. NOTE(review): pickle.load can execute arbitrary code; only load
# model files from a trusted location.
with open('/Users/pillowlab/Dropbox/pyglm-master/Practices/model.pkl', 'rb') as model_file:
    sm = pickle.load(model_file)

new_data = dict(N=N, W=W, B=dim)

# Repeatedly run short single-chain fits, re-initialising each run from the
# previous run's posterior means. NOTE(review): L1 (initial latent locations)
# must already be defined before this loop; it is set outside this view.
for i in range(100):
    fit = sm.sampling(data=new_data,
                      iter=100,
                      warmup=50,
                      chains=1,
                      init=[dict(l=L1, sigma=aa)],
                      control=dict(stepsize=0.001))

    samples = fit.extract(permuted=True)
    aa = np.mean(samples['sigma'])
    #aa = samples['sigma'][-1]
    #bb = np.mean(samples['eta'])
    #cc = np.mean(samples['rho'])
    L1 = np.mean(samples['l'], 0)
    #L1 = samples['l'][-1]
    # Align the running estimate to the true locations L before the next run.
    R = compute_optimal_rotation(L1, L)
    L1 = np.dot(L1, R)

plt.scatter(L1[:, 0], L1[:, 1])
plt.scatter(L[:, 0], L[:, 1])
示例#4
0
    # NOTE(review): fragment. The enclosing loop header (presumably iterating
    # the sample index `s`) is outside this view.
    # Gradient of the log-probability `lp`. It is presumably rebuilt each pass
    # because `lp` may change between iterations (e.g. after sigma is
    # resampled) — confirm.
    dlp = grad(lp)
    # NOTE(review): stepsz and accept_rate are reset to these constants on
    # every pass. The adapted values unpacked from hmc() below are therefore
    # discarded before the next call. If step-size adaptation across
    # iterations is intended, initialise these once before the loop.
    stepsz = 0.005
    nsteps = 10
    accept_rate = 0.9
    # One HMC update starting from the previous sample smpls[s-1]. The result
    # is unpacked as (new sample, step size, average acceptance rate).
    smpls[s], stepsz, accept_rate= \
        hmc(lp, dlp, stepsz, nsteps, smpls[s-1], negative_log_prob=False, avg_accept_rate=accept_rate,
                adaptive_step_sz=True)

    # Record the log-probability of the new sample.
    lp1[s] = lp(smpls[s])
    # Resample sigma given the current latent locations and store its trace.
    sigma = _resample_sigma(smpls[s])
    a[s] = sigma
    # NOTE(review): indexed with s - 1 while the other traces use s; verify
    # that this offset is intended.
    W_all[s - 1] = W1
    print(sigma)

# Bring every HMC sample into the frame of the true embedding L (in place),
# so samples can be averaged and overlaid in common coordinates.
for s in range(N_samples):
    R = compute_optimal_rotation(smpls[s], L)
    smpls[s] = smpls[s].dot(R)

# Posterior-mean embedding over the second half of the samples; the first
# half is discarded as burn-in.
burn_in = N_samples // 2
L_estimate = np.mean(smpls[burn_in:], axis=0)

# Debugging note: the two directed weights of each node pair are drawn on top
# of each other with different strengths, so the edge plots are disabled:
#plot_LatentDistanceModel(W, L_estimate, N, L_true=L)
#plot_LatentDistanceModel(W, L, N)

# Figure 1: the last 100 aligned samples for all nodes, with true locations in red.
recent = smpls[-100:]
plt.figure(1)
plt.scatter(recent[:, :, 0], recent[:, :, 1])
plt.scatter(L[:, 0], L[:, 1], color='r')

# Figure 2: trace of the recorded log-probabilities.
plt.figure(2)
plt.plot(lp1)

# Figure 3 is opened here; its contents follow beyond this view.
plt.figure(3)