def plot_samples(init_samples, samples, save_to_file=False, epoch=None, img_shape=(28, 28)):
    """Tile a sequence of sampling steps into one uint8 image.

    The first tile row shows ``init_samples``; each following tile row shows
    one step of ``samples``. Within a row, one tile is drawn per chain,
    with a 1-pixel gap between tiles.

    Args:
        init_samples: 2-D gnumpy array of starting states, one chain per
            row. It is reshaped to a leading "step" axis of length 1 and
            prepended to ``samples``.
        samples: 3-D gnumpy array indexed ``[step, chain, ...]``. Presumably
            each entry is a flattened image of ``img_shape`` -- TODO confirm
            against the sampler that produces it.
        save_to_file: if true, also write the image to
            ``"samples-%02i.png" % epoch``.
        epoch: integer used in the output file name. Required when
            ``save_to_file`` is true.
        img_shape: (height, width) of a single sample image. Defaults to the
            previously hard-coded (28, 28).

    Returns:
        The tiled image as a 2-D numpy array of dtype ``uint8``.

    Raises:
        ValueError: if ``save_to_file`` is true but ``epoch`` is None.
    """
    # Fail fast, with a real exception: ``assert`` disappears under
    # ``python -O``, and the old check only ran after all tiling was done.
    if save_to_file and epoch is None:
        raise ValueError("epoch is required when save_to_file is True")

    img_h, img_w = img_shape
    # Each tile occupies its image size plus a 1-pixel spacer
    # (29 for the default 28x28 images).
    row_pitch = img_h + 1
    col_pitch = img_w + 1

    # Prepend the initial states as "step 0".
    first_step = init_samples.reshape((1, init_samples.shape[0], init_samples.shape[1]))
    all_samples = gp.concatenate((first_step, samples))
    n_samples = all_samples.shape[0]
    n_chains = all_samples.shape[1]

    # The width must equal that of each tile row returned below, because the
    # slice assignment copies a full row of tiles into ``img``.
    img = np.zeros((row_pitch * n_samples + 1, col_pitch * n_chains - 1), dtype="uint8")
    for step in range(n_samples):
        v = all_samples[step, :, :]
        A = dlutil.tile_raster_images(
            gp.as_numpy_array(v), img_shape=img_shape, tile_shape=(1, n_chains), tile_spacing=(1, 1)
        )
        img[row_pitch * step : row_pitch * step + img_h, :] = A

    if save_to_file:
        pilimage = pil.fromarray(img)
        pilimage.save("samples-%02i.png" % epoch)
    return img
def plot_pcd_chains(rbm, epoch, img_shape=(28, 28), tile_shape=(8, 8)):
    """Save the RBM's persistent chain states as a tiled RGB PNG.

    Writes ``"pcd-vis-%02i.png" % epoch`` relative to the current working
    directory.

    Args:
        rbm: object exposing ``persistent_vis``, a gnumpy array whose rows
            are tiled one per tile. Presumably one row per persistent
            chain -- TODO confirm.
        epoch: integer used in the output file name.
        img_shape: (height, width) that each row is drawn as. Defaults to
            the previously hard-coded (28, 28).
        tile_shape: (rows, cols) grid of tiles. Defaults to the previously
            hard-coded (8, 8). Pass a different grid when the number of
            persistent chains is not 64.
    """
    v = gp.as_numpy_array(rbm.persistent_vis)
    # The float64 cast is kept from the original. It presumably gives
    # PIL.fromarray a dtype it accepts whatever tile_raster_images returns
    # -- TODO confirm before removing.
    A = dlutil.tile_raster_images(v, img_shape=img_shape, tile_shape=tile_shape).astype("float64")
    pilimage = pil.fromarray(A).convert("RGB")
    pilimage.save("pcd-vis-%02i.png" % epoch)
def plot_weights(rbm, epoch, img_shape=(28, 28), tile_shape=(10, 10)):
    """Save the RBM's weight vectors (filters) as a tiled RGB PNG.

    Writes ``"filters-%02i.png" % epoch`` relative to the current working
    directory.

    Args:
        rbm: object exposing ``weights``, a gnumpy array. Its transpose is
            tiled, so each column of ``rbm.weights`` becomes one tile.
            Presumably the layout is (n_visible, n_hidden) -- TODO confirm.
        epoch: integer used in the output file name.
        img_shape: (height, width) that each filter is drawn as. Defaults to
            the previously hard-coded (28, 28).
        tile_shape: (rows, cols) grid of tiles. Defaults to the previously
            hard-coded (10, 10). Pass a different grid when the number of
            hidden units is not 100.
    """
    W = gp.as_numpy_array(rbm.weights)
    # The float64 cast is kept from the original. It presumably gives
    # PIL.fromarray a dtype it accepts whatever tile_raster_images returns
    # -- TODO confirm before removing.
    A = dlutil.tile_raster_images(W.T, img_shape=img_shape, tile_shape=tile_shape).astype("float64")
    pilimage = pil.fromarray(A).convert("RGB")
    pilimage.save("filters-%02i.png" % epoch)