def plot_swap_distr(fit_az, data, ax=None, n_bins=10, **kwargs):
    """Overlay step histograms of observed and predicted normalized centroid distances.

    Parameters
    ----------
    fit_az
        Fitted model passed as the first argument to
        ``swan.get_normalized_centroid_distance`` (presumably an ArviZ
        ``InferenceData`` -- confirm against the caller).
    data
        Data passed as the second argument to
        ``swan.get_normalized_centroid_distance``.
    ax : matplotlib Axes, optional
        Axes to draw on. When None, a new figure with a single subplot is
        created.
    n_bins : int
        Bin specification for the observed histogram. The predicted histogram
        reuses the resulting bin edges so the two curves are directly
        comparable.
    **kwargs
        Forwarded to ``swan.get_normalized_centroid_distance``.

    Returns
    -------
    matplotlib Axes
        The axes that were drawn on.
    """
    if ax is None:
        ax = plt.figure().add_subplot(1, 1, 1)

    # The helper returns a 3-tuple; only the first two entries are plotted.
    observed, predicted, _ = swan.get_normalized_centroid_distance(
        fit_az, data, **kwargs
    )

    step_density = dict(histtype='step', density=True)
    # Take the bin edges from the observed histogram and reuse them below.
    edges = ax.hist(observed, bins=n_bins, **step_density)[1]
    ax.hist(predicted, bins=edges, **step_density)

    # Reference lines at 0 and 1 (presumably the two normalized centroid
    # positions -- confirm against swan).
    gpl.add_vlines([0, 1], ax)
    return ax
def plot_mu_hists(models, mu_key, mu2=None, axs=None, fwid=3):
    """Plot per-element histograms of a posterior variable for several models.

    For every model in ``models`` and every ``(i, j)`` element of the
    variable ``mu_key``, draw a step histogram of its posterior samples on
    ``axs[i, j]``. The samples are pooled over the two leading axes
    (presumably chain and draw -- confirm).

    Parameters
    ----------
    models : dict
        Maps a label to a fitted model exposing ``.posterior[...]``.
        Presumably ArviZ ``InferenceData`` -- confirm. The panel grid shape
        is taken from ``posterior[mu_key].shape[2:]`` of the first model and
        must have exactly two entries.
    mu_key : str
        Name of the posterior variable to histogram.
    mu2 : str, optional
        If given, the element-wise difference ``mu_key - mu2`` is plotted
        instead.
    axs : 2-D array of matplotlib Axes, optional
        Grid of axes indexed as ``axs[i, j]``. When None, a new figure is
        created.
    fwid : float
        Width and height in inches of each panel when creating a figure.

    Returns
    -------
    2-D array of matplotlib Axes
    """
    pshape = list(models.values())[0].posterior[mu_key].shape[2:]
    if axs is None:
        figsize = (fwid * pshape[1], fwid * pshape[0])
        # squeeze=False keeps ``axs`` 2-D even when a dimension of pshape is
        # 1, so the ``axs[i, j]`` indexing below always works.
        _, axs = plt.subplots(*pshape, figsize=figsize, squeeze=False)

    for label, model in models.items():
        mus = model.posterior[mu_key]
        mus2 = None if mu2 is None else model.posterior[mu2]
        for (i, j) in u.make_array_ind_iterator(pshape):
            mu_plot = mus[..., i, j].to_numpy().flatten()
            if mus2 is not None:
                mu_plot = mu_plot - mus2[..., i, j].to_numpy().flatten()
            axs[i, j].hist(mu_plot, label=label, histtype='step')

    # Decorate each panel once, after every model has been drawn, instead of
    # redrawing the zero line and legend once per model.
    for (i, j) in u.make_array_ind_iterator(pshape):
        gpl.add_vlines(0, axs[i, j])
        axs[i, j].legend(frameon=False)
    return axs
def plot_decoding_map(*maps, fwid=5, thresh=True, ts=(5, 15, 25)):
    """Plot one or more decoding maps side by side with a shared colorbar.

    Each map is drawn with ``gpl.pcolormesh`` on a 0..1 color scale. Both
    axes are indexed by ``np.arange(map.shape[0])``, so maps are expected to
    be square. Rows are presumably training time and columns testing time,
    going by the axis labels -- confirm against gpl.pcolormesh.

    Parameters
    ----------
    *maps : array_like
        2-D arrays to plot. They are never modified.
    fwid : float
        Width and height in inches of each panel.
    thresh : bool
        If True, negative values are set to 0 before plotting. This is done
        on a copy; the caller's arrays are left untouched.
    ts : sequence of float
        Positions of the reference lines and ticks on both axes.

    Returns
    -------
    (Figure, 1-D array of Axes)

    Raises
    ------
    ValueError
        If no maps are given.
    """
    if not maps:
        raise ValueError("plot_decoding_map requires at least one map")
    n_plots = len(maps)
    # squeeze=False so a single map still yields an indexable array of Axes.
    f, axs = plt.subplots(1, n_plots, figsize=(fwid * n_plots, fwid),
                          squeeze=False)
    axs = axs[0]
    mesh = None
    for ax, map_i in zip(axs, maps):
        if thresh:
            # Clip a copy rather than assigning in place, which would silently
            # mutate the caller's array.
            map_i = np.where(map_i < 0, 0, map_i)
        ax_ts = np.arange(map_i.shape[0])
        mesh = gpl.pcolormesh(ax_ts, ax_ts, map_i, ax=ax, vmin=0, vmax=1)
        for t in ts:
            gpl.add_hlines(t, ax)
            gpl.add_vlines(t, ax)
        ax.set_xlabel('testing time')
        ax.set_xticks(ts)
        ax.set_yticks(ts)
    axs[0].set_ylabel('training time')
    # All panels share vmin/vmax, so one colorbar (from the last mesh)
    # describes all of them.
    f.colorbar(mesh, ax=axs)
    return f, axs
def plot_posterior_predictive_dims(m, d, dims=5, axs=None, ks=None):
    """Compare observed data with posterior-predictive samples, one dimension per panel.

    For each of the first ``dims`` columns, draw a density histogram of
    ``d[:, i]`` labelled 'observed'. Overlay a dashed step histogram of the
    pooled posterior-predictive samples ``err_hat`` on the same bins,
    labelled 'predictive'.

    Parameters
    ----------
    m
        Fitted model exposing ``m.posterior_predictive.err_hat``. This is
        presumably an ArviZ ``InferenceData`` whose ``err_hat`` has leading
        (chain, draw, observation) axes and a trailing dimension axis --
        confirm.
    d : ndarray
        Observed data, indexed as ``d[:, i]``.
    dims : int
        Number of leading dimensions (columns) to plot.
    axs : sequence of matplotlib Axes, optional
        At least ``dims`` axes indexed as ``axs[i]``. When None, a new figure
        with ``dims`` stacked rows is created.
    ks : sequence, optional
        If given, ``ks[1]`` is used to index rows of ``d``, and vertical
        lines are drawn at those rows' values. This is presumably the
        indices of high Pareto-k observations from ``swan.get_pareto_k`` --
        confirm.

    Returns
    -------
    sequence of matplotlib Axes
    """
    if axs is None:
        # Size the figure from ``dims`` (was hard-coded to 5 rows, which broke
        # for dims > 5). squeeze=False keeps dims == 1 indexable.
        _, axs = plt.subplots(dims, 1, squeeze=False)
        axs = axs[:, 0]

    # Two concatenations along axis 0 merge the two leading axes of err_hat,
    # each time joining its sub-arrays along their first axis.
    total_post = np.concatenate(m.posterior_predictive.err_hat, axis=0)
    total_post = np.concatenate(total_post, axis=0)

    d_ks = None
    if ks is not None:
        d_ks = d[ks[1]]

    for i in range(dims):
        ax = axs[i]
        _, bins, _ = ax.hist(d[:, i], density=True, label='observed')
        # Reuse the observed histogram's edges so the two curves align.
        ax.hist(total_post[:, i], histtype='step', density=True,
                label='predictive', linestyle='dashed', color='k', bins=bins)
        if d_ks is not None:
            gpl.add_vlines(d_ks[:, i], ax)
    return axs
def plot_k_distributions(models, labels, k_thresh=.7, fwid=3, sharex=True,
                         sharey=True, compute_k=True):
    """Scatter per-observation Pareto-k values for every pair of models.

    Each unordered pair ``(ind1, ind2)`` with ``ind1 < ind2`` is drawn on
    ``axs[ind1, ind2 - 1]``, so only the upper triangle (including the
    diagonal) of the ``(n - 1) x (n - 1)`` grid is used. Horizontal and
    vertical lines mark ``k_thresh``.

    Parameters
    ----------
    models : sequence
        Fitted models when ``compute_k`` is True. Otherwise, sequences of
        per-observation k values that support ``len`` and indexing.
    labels : sequence of str
        Axis label for each entry of ``models``.
    k_thresh : float
        Position of the threshold lines.
    fwid : float
        Width and height in inches of each panel.
    sharex, sharey : bool
        Passed to ``plt.subplots``.
    compute_k : bool
        If True, replace each model with ``swan.get_pareto_k(model)[0]``.

    Returns
    -------
    (Figure, 2-D array of Axes)

    Raises
    ------
    ValueError
        If fewer than two models are given.
    """
    if compute_k:
        models = [swan.get_pareto_k(m)[0] for m in models]
    if len(models) < 2:
        raise ValueError("plot_k_distributions needs at least two models")
    side_plots = len(models) - 1
    # squeeze=False: with exactly two models this is a 1x1 grid, and the
    # default squeezing would return a bare Axes that cannot be indexed
    # as axs[0, 0].
    f, axs = plt.subplots(side_plots, side_plots,
                          figsize=(fwid * side_plots, fwid * side_plots),
                          sharex=sharex, sharey=sharey, squeeze=False)
    for ind1, ind2 in it.combinations(range(len(models)), 2):
        ax = axs[ind1, ind2 - 1]
        ax.plot(models[ind1], models[ind2], 'o')
        gpl.add_hlines(k_thresh, ax)
        gpl.add_vlines(k_thresh, ax)
        ax.set_xlabel(labels[ind1])
        ax.set_ylabel(labels[ind2])
    return f, axs
def plot_session_swap_distr_collection(session_dict, axs=None, n_bins=20, fwid=3,
                                       p_ind=1, bin_bounds=None, ret_data=True,
                                       colors=None, **kwargs):
    """Pool normalized centroid distances across sessions and plot one panel per model.

    For every session and every model key, call
    ``swan.get_normalized_centroid_distance``. The observed ("true") and
    predicted arrays are concatenated across sessions per model key. For
    each key, one panel shows a filled density histogram of the observed
    values and a dashed step histogram of the predicted values on the same
    bins.

    Parameters
    ----------
    session_dict : dict
        Maps a session name to a ``(model_dict, data)`` pair. ``model_dict``
        maps a model key to a fitted model.
    axs : sequence of matplotlib Axes, optional
        One axes per model key, indexed in key order. When None, a single
        row is created, sized from the first session's model dict.
    n_bins : int
        Number of histogram bins per panel.
    fwid : float
        Width and height in inches of each panel.
    p_ind : int
        Forwarded to ``swan.get_normalized_centroid_distance``.
    bin_bounds : (float, float), optional
        If given, all panels use ``n_bins`` equal-width bins spanning these
        bounds. Otherwise each panel chooses its own ``n_bins`` bins from its
        observed data.
    ret_data : bool
        If True, also return the pooled observed values per model key.
    colors : dict, optional
        Maps a model key to the fill color of its observed histogram.
    **kwargs
        Forwarded to ``swan.get_normalized_centroid_distance``.

    Returns
    -------
    axs, or (axs, dict mapping model key -> pooled observed ndarray)
    """
    if colors is None:
        colors = {}
    if axs is None:
        n_plots = len(list(session_dict.values())[0][0])
        fsize = (fwid * n_plots, fwid)
        _, axs = plt.subplots(1, n_plots, figsize=fsize,
                              sharex=False, sharey=False)
        if n_plots == 1:
            axs = [axs]

    # Group per-session results by model key, keeping first-seen key order.
    true_d = {}
    pred_d = {}
    for mdict, data in session_dict.values():
        for k, faz in mdict.items():
            true, pred, _ = swan.get_normalized_centroid_distance(
                faz, data, p_ind=p_ind, **kwargs)
            true_d.setdefault(k, []).append(true)
            pred_d.setdefault(k, []).append(pred)

    if bin_bounds is not None:
        # n_bins bins need n_bins + 1 edges. Passing n_bins edges would give
        # one bin fewer than in the integer branch below.
        bins = np.linspace(*bin_bounds, n_bins + 1)
    else:
        bins = n_bins

    out_data = {}
    for i, (k, td) in enumerate(true_d.items()):
        ax = axs[i]
        td_full = np.concatenate(td, axis=0)
        pd_full = np.concatenate(pred_d[k], axis=0)
        # Keep this panel's edges in their own variable. Rebinding ``bins``
        # would leak the first panel's edges into every later panel.
        _, panel_bins, _ = ax.hist(td_full, bins=bins, color=colors.get(k),
                                   density=True, label='observed')
        ax.hist(pd_full, bins=panel_bins, histtype='step', color='k',
                linestyle='dashed', density=True, label='predicted')
        gpl.add_vlines([0, 1], ax)
        ax.set_ylabel(k)
        if ret_data:
            out_data[k] = td_full
        ax.legend(frameon=False)

    if ret_data:
        return axs, out_data
    return axs