Exemplo n.º 1
0
def plot_samples(init_samples, samples, save_to_file=False, epoch=None, img_shape=(28, 28)):
    """Render sampling chains as one grayscale mosaic, one band per step.

    ``init_samples`` is prepended as step 0. Each band holds one tile per
    chain, produced by ``dlutil.tile_raster_images`` with 1-pixel spacing.

    Parameters
    ----------
    init_samples : 2-D array, (n_chains, n_pixels)
        Starting states of the chains (gnumpy array, presumably).
    samples : 3-D array, (n_steps, n_chains, n_pixels)
        Subsequent chain states.
    save_to_file : bool
        If true, also write the mosaic to ``samples-<epoch>.png`` (epoch
        zero-padded to two digits) in the current working directory.
    epoch : int or None
        Required when ``save_to_file`` is true; only used in the filename.
    img_shape : (int, int)
        Height and width of one tile; defaults to the original 28x28.

    Returns
    -------
    numpy.ndarray of uint8
        The mosaic image.

    Raises
    ------
    ValueError
        If ``save_to_file`` is true but ``epoch`` is None.
    """
    # Validate up front (and with a real exception, not an ``assert`` that
    # ``python -O`` would strip) before doing any tiling work.
    if save_to_file and epoch is None:
        raise ValueError("epoch is required when save_to_file is True")

    tile_h, tile_w = img_shape
    # Each tile plus its 1-pixel separator.
    pitch_h, pitch_w = tile_h + 1, tile_w + 1

    first = init_samples.reshape((1, init_samples.shape[0], init_samples.shape[1]))
    all_samples = gp.concatenate((first, samples))
    n_samples = all_samples.shape[0]
    n_chains = all_samples.shape[1]

    # Width is tight: n_chains tiles with n_chains - 1 separators. This assumes
    # tile_raster_images returns that width for tile_spacing=(1, 1) (confirm in
    # dlutil). The height keeps the original "+1", which leaves two blank rows
    # below the last band; it is kept so the returned shape does not change.
    img = np.zeros((pitch_h * n_samples + 1, pitch_w * n_chains - 1), dtype="uint8")

    for step in range(n_samples):
        band = dlutil.tile_raster_images(
            gp.as_numpy_array(all_samples[step, :, :]),
            img_shape=img_shape,
            tile_shape=(1, n_chains),
            tile_spacing=(1, 1),
        )
        img[pitch_h * step : pitch_h * step + tile_h, :] = band

    if save_to_file:
        pilimage = pil.fromarray(img)
        pilimage.save("samples-%02i.png" % epoch)
    return img
Exemplo n.º 2
0
def plot_pcd_chains(rbm, epoch):
    """Save the persistent-chain visible states of ``rbm`` as an RGB image.

    ``rbm.persistent_vis`` is converted with ``gp.as_numpy_array``, tiled as
    28x28 images on an 8x8 grid by ``dlutil.tile_raster_images``, and written
    to ``pcd-vis-<epoch>.png`` (epoch zero-padded to two digits) in the
    current working directory.

    Parameters
    ----------
    rbm : object
        Must expose ``persistent_vis``. Presumably one chain per row; confirm
        against the RBM class.
    epoch : int
        Only used to build the output filename.
    """
    host_states = gp.as_numpy_array(rbm.persistent_vis)
    mosaic = dlutil.tile_raster_images(host_states, img_shape=(28, 28), tile_shape=(8, 8))
    # Cast to float64 (PIL mode "F") before the RGB conversion, as before.
    as_float = mosaic.astype("float64")
    rgb_image = pil.fromarray(as_float).convert("RGB")
    rgb_image.save("pcd-vis-%02i.png" % epoch)
Exemplo n.º 3
0
def plot_weights(rbm, epoch):
    """Save the weight matrix of ``rbm`` as a tiled RGB filter image.

    The transpose of ``rbm.weights`` is tiled as 28x28 images on a 10x10 grid
    by ``dlutil.tile_raster_images``. The result is written to
    ``filters-<epoch>.png`` (epoch zero-padded to two digits) in the current
    working directory.

    Parameters
    ----------
    rbm : object
        Must expose ``weights``. Presumably shaped (n_visible, n_hidden), so
        each row of the transpose is one hidden unit's filter; confirm.
    epoch : int
        Only used to build the output filename.
    """
    filters = gp.as_numpy_array(rbm.weights).T
    mosaic = dlutil.tile_raster_images(filters, img_shape=(28, 28), tile_shape=(10, 10))
    # Cast to float64 (PIL mode "F") before the RGB conversion, as before.
    as_float = mosaic.astype("float64")
    rgb_image = pil.fromarray(as_float).convert("RGB")
    rgb_image.save("filters-%02i.png" % epoch)