# Evaluate every state's observation log-likelihood on an npts x npts grid
# spanning [xmins, xmaxs] in the first two coordinates (the remaining D - 2
# coordinates are held at zero), then draw exp(log-likelihood) as contours and
# overlay the points of x grouped by z == k.
npts = 100
XX, YY = np.meshgrid(np.linspace(xmins[0], xmaxs[0], npts),
                     np.linspace(xmins[1], xmaxs[1], npts))
data = np.column_stack((XX.ravel(), YY.ravel(), np.zeros((npts**2, D - 2))))
# Zero-width input matrix and an all-True mask, one row per grid point.
# NOTE(review): `input` shadows the builtin; the name is kept because other
# cells of this notebook may rely on it.
input = np.zeros((data.shape[0], 0))
mask = np.ones_like(data, dtype=bool)
tag = None
lls = hmm.observations.log_likelihoods(data, input, mask, tag)

plt.figure(figsize=(6, 6))
for k in range(K):
    # Wrap the palette index for BOTH the contour and the scatter so they
    # always share a colour and K larger than the palette cannot raise
    # IndexError.
    plt.contour(XX, YY, np.exp(lls[:, k]).reshape(XX.shape),
                cmap=white_to_color_cmap(colors[k % len(colors)]))
    plt.plot(x[z == k, 0], x[z == k, 1], 'o',
             mfc=colors[k % len(colors)], mec='none', ms=4)
# Full trajectory as a translucent black line.
plt.plot(x[:, 0], x[:, 1], '-k', lw=2, alpha=.5)
plt.xlabel("$x_1$")
plt.ylabel("$x_2$")
plt.title("Observation Distributions")
plt.tight_layout()
if save_figures:
    plt.savefig("lds_3.pdf")

# In[16]:

# Simulate from the HMM fit
def plot_place_fields(results, pos, center, radius, data, figdir='.'):
    """
    Compare model-inferred and empirical place fields and save as figure 8.

    Panel A: for cells 0, 13, 28 and 38, the inferred place field (top row)
    above the place field fit directly to the positions at which the cell
    spiked (bottom row).  Panel B: histogram, over all ``model.N`` cells, of
    the total-variation distance between the two, as computed by
    ``compute_place_field_TV``.  The TVs of the example cells are printed.

    Parameters
    ----------
    results : fit results; ``results.samples`` is the model (on which
        ``relabel_by_usage()`` is called before anything is read) and
        ``results.N_used[-1]`` is the number of states kept.
    pos : positions with x in column 0 and y in column 1, row-aligned with
        ``data`` and with the model's first state sequence.
    center, radius : arena geometry passed to ``CircularDistribution``.
    data : matrix whose column ``c`` belongs to cell ``c``; rows with a
        positive entry are treated as spikes.
    figdir : directory in which ``figure8.png`` is written.
    """
    model = results.samples
    model.relabel_by_usage()
    N_used = results.N_used[-1]
    lmbdas = model.rates[:N_used, :]
    stateseq = model.stateseqs[0]
    occupancy = model.state_usages

    # Cells are coloured by index, cycling through a 9-colour Set1 palette.
    N_colors = 9
    colors = brewer2mpl.get_map('Set1', 'qualitative', N_colors).mpl_colors

    # Location distribution of each used state, fit to the positions visited
    # while the state sequence is in that state.
    dists = []
    for s in range(N_used):
        cd = CircularDistribution(center, radius)
        cd.fit_xy(pos[stateseq == s, 0], pos[stateseq == s, 1])
        dists.append(cd)

    # Panel A: inferred (top) vs. empirical (bottom) place fields of a few
    # example cells.
    fig = create_figure((5, 4))
    plt.figtext(0.05 / 5.0, 3.8 / 4.0, "A")
    toplot = [0, 13, 28, 38]
    for i, c in enumerate(toplot):
        left = 1.25 * i + 0.05
        # Two spaces, exactly as Python 2's `print "Plotting cell ", c` emitted.
        print("Plotting cell  %d" % c)
        color = colors[np.mod(c, N_colors)]
        cmap = white_to_color_cmap(color)

        inf_place_field = _inferred_place_field(dists, lmbdas, occupancy, c)
        true_place_field = _true_place_field(data, pos, c, center, radius)

        ax = create_axis_at_location(fig, left, 2.65, 1.15, 1.15, transparent=True)
        remove_plot_labels(ax)
        inf_place_field.plot(ax=ax, cmap=cmap, plot_data=True, plot_colorbar=False)
        ax.set_title('Inf. Place Field %d' % (c + 1), fontdict={'fontsize': 9})

        ax = create_axis_at_location(fig, left, 1.25, 1.15, 1.15, transparent=True)
        remove_plot_labels(ax)
        true_place_field.plot(ax=ax, cmap=cmap, plot_data=True, plot_colorbar=False)
        ax.set_title('True Place Field %d' % (c + 1), fontdict={'fontsize': 9})

    # Panel B: histogram of TV(inferred, empirical) over every cell.
    tvs = np.zeros(model.N)
    for c in range(model.N):
        tvs[c] = compute_place_field_TV(
            _inferred_place_field(dists, lmbdas, occupancy, c),
            _true_place_field(data, pos, c, center, radius))

    # Fixed bins of width 0.001 centred on 0.006 ... 0.014.
    # NOTE: with explicit bin edges, TVs outside [0.0055, 0.0145] are not counted.
    bin_centers = np.arange(0.006, 0.0141, 0.001)
    bin_width = 0.001
    bin_edges = np.concatenate(
        (bin_centers - bin_width / 2.0, [bin_centers[-1] + bin_width / 2.0]))
    ax = create_axis_at_location(fig, 0.5, 0.5, 4., .5, transparent=True)
    ax.hist(tvs, bins=bin_edges, facecolor=allcolors[1])
    ax.set_xlim(0.005, 0.015)
    ax.set_xticks(bin_centers)
    # Label every other bin centre to keep the axis readable.
    ax.set_xticklabels([
        "{0:.3f}".format(bc) if i % 2 == 0 else ""
        for i, bc in enumerate(bin_centers)
    ])
    ax.set_xlabel("$TV(p_{inf}, p_{true})$")
    ax.set_yticks(np.arange(17, step=4))
    ax.set_ylabel("Count")
    plt.figtext(0.05 / 5.0, 1.1 / 4.0, "B")

    print("TVs of plotted cells: ")
    print(tvs[toplot])

    fig.savefig(os.path.join(figdir, 'figure8.png'))
    plt.show()


def _inferred_place_field(dists, lmbdas, occupancy, c):
    """Return sum_s dists[s] * lmbdas[s, c] * occupancy[s] over all states in ``dists``."""
    field = dists[0] * lmbdas[0, c] * occupancy[0]
    for s in range(1, len(dists)):
        field += dists[s] * lmbdas[s, c] * occupancy[s]
    return field


def _true_place_field(data, pos, c, center, radius):
    """Fit a CircularDistribution to the positions of rows where ``data[:, c] > 0``."""
    spks = np.array(data[:, c] > 0).ravel()
    field = CircularDistribution(center, radius)
    field.fit_xy(pos[spks, 0], pos[spks, 1])
    return field
def make_figure(results, S_train, pos_train, S_test, pos_test, center, radius, figdir="."):
    """
    Build figure 5: state location summaries and decoding of held-out positions.

    Panel A draws the arena (a circle of ``radius`` about the origin) with a
    black dot at the fitted mean location of every used state, its size
    proportional to the state's usage relative to the most used state.
    Panel B shows the empirical location distribution of the first (up to)
    three states.  Panel C plots the held-out x(t) and y(t) (black) against
    the mean of the position distribution returned by ``estimate_pos``, and
    the mean squared error of that mean is printed.  The figure is written to
    ``figure5.pdf`` and ``figure5.png`` in ``figdir`` and then shown.

    Parameters
    ----------
    results : fit results; ``results.samples`` is the model (on which
        ``relabel_by_usage()`` is called) and ``results.N_used[-1]`` is the
        number of states kept.
    S_train, pos_train : training data and positions (x in column 0, y in
        column 1); ``pos_train`` is row-aligned with the model's first state
        sequence.
    S_test, pos_test : held-out data and positions, one row per time step.
    center, radius : arena geometry passed to ``CircularDistribution``.
    figdir : output directory.
    """
    bin_size = 0.25  # seconds between consecutive test rows on the time axis

    model = results.samples
    model.relabel_by_usage()
    N_used = results.N_used[-1]
    stateseq = model.stateseqs[0]
    occupancy = model.state_usages
    T_test = S_test.shape[0]
    t_test = np.arange(T_test) * bin_size

    fig = create_figure(figsize=(5, 3))

    # Panel A: arena outline plus the mean location of every used state.
    ax = create_axis_at_location(fig, .05, 1.55, 1.15, 1.15, transparent=True)
    plt.figtext(0.05 / 5, 2.8 / 3, "A")
    remove_plot_labels(ax)
    circle = matplotlib.patches.Circle(xy=[0, 0], radius=radius, linewidth=1,
                                       edgecolor="k", facecolor="white")
    ax.add_patch(circle)
    plt.figtext(1.2 / 5, 2.8 / 3, "B")
    for k in range(N_used):
        # float() rather than the np.float alias, which NumPy 1.24 removed.
        relocc = occupancy[k] / float(np.amax(occupancy))
        cd = CircularDistribution(center, radius)
        cd.fit_xy(pos_train[stateseq == k, 0], pos_train[stateseq == k, 1])
        rm, thm = cd.mean
        # The arena is drawn about the origin; presumably cd.mean is relative
        # to ``center``, so converting about [0, 0] lands in that frame.
        xm, ym = convert_polar_to_xy(np.array([[rm, thm]]), [0, 0])
        ax.plot(xm, ym, 'o', markersize=relocc * 6, markerfacecolor='k',
                markeredgecolor='k', markeredgewidth=1)
    ax.set_xlim(-radius, radius)
    ax.set_ylim(-radius, radius)
    ax.set_title('All states', fontdict={'fontsize': 9})

    # Panel B: empirical location distribution of the first three states
    # (presumably the most used, after relabel_by_usage).  Never index past
    # the number of states actually used.
    for k in range(min(3, N_used)):
        left = 1.25 * (k + 1) + 0.05
        color = allcolors[k]
        cmap = white_to_color_cmap(color)
        ax = create_axis_at_location(fig, left, 1.55, 1.15, 1.15, transparent=True)
        remove_plot_labels(ax)
        cd = CircularDistribution(center, radius)
        cd.fit_xy(pos_train[stateseq == k, 0], pos_train[stateseq == k, 1])
        cd.plot(ax=ax, cmap=cmap, plot_data=True, plot_colorbar=False)
        ax.set_title('State %d (%.1f%%)' % (k + 1, 100. * occupancy[k]),
                     fontdict={'fontsize': 9})

    # Panel C: decode held-out positions and compare with the truth.
    plt.figtext(0.05 / 5, 1.55 / 3, "C")
    epdf = estimate_pos(model, S_train, pos_train, S_test, pos_test, center, radius)

    # Mean of the decoded distribution at each time step, in x/y coordinates.
    # Force a float buffer so an integer-typed pos_test cannot truncate it.
    mean_location = np.zeros_like(pos_test, dtype=float)
    for t in range(T_test):
        cd = CircularDistribution(center, radius, pdf=epdf[t, :])
        mean_location[t, :] = convert_polar_to_xy(np.atleast_2d(cd.mean), center)

    # Per-time-step squared error averaged over x and y.  No sqrt here:
    # sqrt(d**2) is just |d|, which would report a mean absolute error under
    # the "MSE" label.
    sqerr = ((mean_location - pos_test) ** 2).mean(axis=1)
    mse = sqerr.mean(axis=0)
    stdse = sqerr.std(axis=0)
    # Raw string: "\p" is not a valid escape sequence.
    print(r"MSE: %f \pm %f" % (mse, stdse))

    ax_y = create_axis_at_location(fig, 0.6, 0.4, 3.8, 0.5, box=True, ticks=True)
    ax_y.plot(t_test, pos_test[:, 1] - center[1], '-k', lw=1)
    ax_y.plot(t_test, mean_location[:, 1] - center[1], '-', color=allcolors[1])
    ax_y.set_ylabel('$y(t)$ [cm]', fontsize=9)
    ax_y.set_ylim([-radius, radius])
    ax_y.set_xlabel('$t$ [s]', fontsize=9)
    ax_y.set_xlim(0, T_test * bin_size)
    ax_y.tick_params(axis='both', which='major', labelsize=9)

    ax_x = create_axis_at_location(fig, 0.6, 1., 3.8, 0.5, box=True, ticks=True)
    ax_x.plot(t_test, pos_test[:, 0] - center[0], '-k', lw=1)
    ax_x.plot(t_test, mean_location[:, 0] - center[0], '-', color=allcolors[1])
    ax_x.set_ylabel('$x(t)$ [cm]', fontsize=9)
    ax_x.set_ylim([-radius, radius])
    # Share the y-panel's time ticks but hide their labels on the upper panel.
    ax_x.set_xticks(ax_y.get_xticks())
    ax_x.set_xticklabels([])
    ax_x.tick_params(axis='both', which='major', labelsize=9)
    ax_x.set_xlim(0, T_test * bin_size)

    fig.savefig(os.path.join(figdir, 'figure5.pdf'))
    fig.savefig(os.path.join(figdir, 'figure5.png'))
    plt.show()
# Evaluate each state's observation log-likelihood under true_hmm on a
# 100 x 100 grid covering +/- 85% of the largest |y|.
lim = .85 * abs(y).max()
XX, YY = np.meshgrid(np.linspace(-lim, lim, 100), np.linspace(-lim, lim, 100))
data = np.column_stack((XX.ravel(), YY.ravel()))
# Zero-width input matrix and an all-True mask, one row per grid point.
# NOTE(review): `input` shadows the builtin; the name is kept because other
# cells of this notebook may rely on it.
input = np.zeros((data.shape[0], 0))
mask = np.ones_like(data, dtype=bool)
tag = None
lls = true_hmm.observations.log_likelihoods(data, input, mask, tag)

# In[36]:

plt.figure(figsize=(6, 6))
for k in range(K):
    # Wrap the palette index so K larger than the palette cannot raise
    # IndexError; contour and scatter share the same colour.
    plt.contour(XX, YY, np.exp(lls[:, k]).reshape(XX.shape),
                cmap=white_to_color_cmap(colors[k % len(colors)]))
    plt.plot(y[z == k, 0], y[z == k, 1], 'o',
             mfc=colors[k % len(colors)], mec='none', ms=4)
# Full data trace as a faint black line.
plt.plot(y[:, 0], y[:, 1], '-k', lw=1, alpha=.25)
plt.xlabel("$y_1$")
plt.ylabel("$y_2$")
plt.title("Observation Distributions")
if save_figures:
    plt.savefig("hmm_1.pdf")

# In[35]:

# Plot the data and the smoothed data
lim = 1.05 * abs(y).max()
plt.figure(figsize=(8, 6))