def posterior_mean_and_cov(plot_data, curve_xs):
    """Condition a zero-mean, squared-exponential GP on the seen data.

    The prior mean is a VentureFunction that always returns 0, and the
    prior covariance is built by ``gpexample_plugin.squared_exponential``
    from ``plot_data.sigma`` and ``plot_data.l``.

    Args:
        plot_data: object providing ``sigma``, ``l``, ``Xseen`` and
            ``Yseen`` attributes.
        curve_xs: the inputs at which the conditional is evaluated.

    Returns:
        Whatever ``gp_conditional.conditional_mean_and_cov`` returns for
        these arguments (unpacked elsewhere in this file as a
        ``(mean, cov)`` pair).
    """
    prior_mean = VentureFunction(
        lambda x: 0, sp_type=sp.SPType([], t.NumberType()))
    prior_cov = gpexample_plugin.squared_exponential(
        plot_data.sigma, plot_data.l)
    return gp_conditional.conditional_mean_and_cov(
        curve_xs, prior_mean, prior_cov, plot_data.Xseen, plot_data.Yseen)
def __call__(self, inferrer, f_, curve_xs_, true_ys_, scatter_xs_,
             scatter_ys_, circles_, user_prefix_):
    """Plot samples from the conditioned GP and log the scatter data.

    Only element ``[0]`` of each trailing-underscore argument is read.

    Args:
        inferrer: not used by this method.
        f_: ``f_[0]`` is a dict; ``f_[0]["value"]`` supplies the prior via
            its ``mean`` and ``covariance`` attributes, and
            ``f_[0]["aux"]`` is a stack dict holding an array of
            (x, y) observation pairs.
        curve_xs_: stack-dict array of x positions at which curves are
            drawn; also fixes the x-axis limits, so it must be non-empty.
        true_ys_: stack-dict array of reference y values drawn in blue
            against ``curve_xs``; may be empty, in which case no
            reference curve is drawn.
        scatter_xs_: stack-dict array of x coordinates of scatter points.
        scatter_ys_: stack-dict array of y coordinates of scatter points.
        circles_: stack-dict array of arrays ``[x, y, color, xsize,
            ysize]`` describing ellipses to overlay.  ``color`` is a
            number interpreted as a six-digit hex RGB value and defaults
            to ``"green"``; the size pair defaults to ``(0.4, 0.33)``
            unless both ``xsize`` and ``ysize`` are given.
        user_prefix_: stack-dict string used as the output file prefix.

    Side effects:
        Creates the ``draw_gp_curves_callback`` directory if needed,
        pickles the scatter data to ``<prefix>_<timestamp>_scatter.pkl``,
        saves the figure to ``<prefix>_<timestamp>.png`` in that
        directory, and prints progress messages.
    """
    # Some of this is hackish

    def numbers_from(operand):
        # Unwrap a stack-dict array operand into a plain list of numbers.
        return [getNumber(v) for v in fromStackDict(operand[0]).getArray()]

    def to_hex_color(n):
        # Zero-pad the hex digits of n to six places so that small values
        # still form a valid "#rrggbb" string.
        digits = hex(int(n))[2:].zfill(6)
        return clr.hex2color("#" + digits)

    observations = [[getNumber(v) for v in pair.getArray()]
                    for pair in fromStackDict(f_[0]["aux"]).getArray()]
    gp = f_[0]["value"]
    prior_mean = gp.mean
    prior_cov = gp.covariance
    if observations:
        Xseen, Yseen = zip(*observations)
    else:
        Xseen, Yseen = [], []

    curve_xs = numbers_from(curve_xs_)
    true_ys = numbers_from(true_ys_)
    user_prefix = fromStackDict(user_prefix_[0]).getString()
    scatter_xs = numbers_from(scatter_xs_)
    scatter_ys = numbers_from(scatter_ys_)
    circles = [[getNumber(v) for v in getArray(entry)]
               for entry in fromStackDict(circles_[0]).getArray()]

    mean, cov = gp_conditional.conditional_mean_and_cov(
        curve_xs, prior_mean, prior_cov, Xseen, Yseen)

    num_samples = 100
    default_circle_size = (0.4, 0.33)
    date_fmt = "%Y%m%d_%H%M%S"
    directory = "draw_gp_curves_callback"

    def j(fname):
        return os.path.join(directory, fname)

    fig, ax = plt.subplots(1)
    try:
        ax.set_xlabel("a")
        ax.set_ylabel("r(a)").set_rotation(0)
        ax.scatter(scatter_xs, scatter_ys, color="k", s=15)
        ax.set_xlim(min(curve_xs), max(curve_xs))

        for _ in range(num_samples):
            ys = np.random.multivariate_normal(mean, cov)
            ax.plot(curve_xs, ys, c="red", alpha=0.2, linewidth=2)
        if true_ys:
            ax.plot(curve_xs, true_ys, c="blue")

        for ccoords in circles:
            x, y = ccoords[0:2]
            if len(ccoords) > 2:
                color = to_hex_color(ccoords[2])
            else:
                color = "green"
            # Both sizes must be present; a lone xsize previously made the
            # two-value unpack raise ValueError, so fall back to defaults.
            if len(ccoords) >= 5:
                xsize, ysize = ccoords[3:5]
            else:
                xsize, ysize = default_circle_size
            circle = Ellipse([x, y], xsize, ysize, color=color,
                             linewidth=5, fill=False)
            circle.set_zorder(10)
            ax.add_artist(circle)

        # The output directory is not guaranteed to exist; without this the
        # open() below fails on a fresh checkout.
        if not os.path.isdir(directory):
            os.makedirs(directory)

        output_prefix = "%s_%s" % (user_prefix,
                                   datetime.now().strftime(date_fmt))
        scatterpath = j("%s_scatter.pkl" % (output_prefix,))
        # Single-argument parenthesized print behaves the same whether
        # print is a statement or a function.
        print("Logging scatter data to %s" % (scatterpath,))
        scatter_data = {"scatter_xs": scatter_xs, "scatter_ys": scatter_ys}
        with open(scatterpath, "wb") as scatter_file:
            pickle.dump(scatter_data, scatter_file)

        print("Outputting to %s.png" % (j(output_prefix),))
        fig.savefig("%s.png" % (j(output_prefix),), dpi=fig.dpi,
                    bbox_inches="tight")
    finally:
        # pyplot keeps every figure alive until it is closed explicitly;
        # release it so repeated callback invocations do not accumulate
        # open figures.
        plt.close(fig)
    print("Done.")