Exemplo n.º 1
0
import tensorflow as tf

from forkan.models import VAE
from forkan.datasets import load_atari_normalized

# Train one VAE per Atari game, all sharing the same hyperparameters.

learning_rate = 1e-4
beta = 5.5  # forwarded to VAE(beta=...)
latents = 20  # forwarded to VAE(latent_dim=...)

for name in ['pong', 'breakout', 'boxing', 'gopher', 'upndown']:
    data = load_atari_normalized(name)
    v = VAE(data.shape[1:], name=name, lr=learning_rate, beta=beta,
            latent_dim=latents)
    v.train(data, num_episodes=50, print_freq=-1)

    # Close the model's session *before* resetting the default graph:
    # TensorFlow documents reset_default_graph() as undefined behaviour
    # while a Session is still active. (v.s is presumably the model's
    # tf.Session -- it exposes close(); TODO confirm in forkan.models.VAE.)
    v.s.close()
    tf.reset_default_graph()
    # Drop references so the dataset and model can be garbage-collected
    # before the next game's data is loaded.
    del data
    del v
Exemplo n.º 2
0
import logging
import struct
import zlib

import numpy as np
import scipy

from forkan import dataset_path
from forkan.common.utils import create_dir
from forkan.datasets import load_atari_normalized

logger = logging.getLogger(__name__)

# Export a handful of fixed Breakout frames as an evaluation set: one
# compressed .npz archive plus one PNG per frame under dataset_path.


def _bytescale(arr):
    """Linearly rescale *arr* to uint8 over its own min/max range.

    Mirrors the default scaling of the removed ``scipy.misc.bytescale``
    (used by ``scipy.misc.imsave``): uint8 input is returned unchanged;
    otherwise the minimum maps to 0, the maximum to 255, and a constant
    array maps to all zeros.
    """
    arr = np.asarray(arr)
    if arr.dtype == np.uint8:
        return arr
    lo = arr.min()
    span = arr.max() - lo
    if span == 0:
        span = 1  # avoid division by zero for constant images
    scaled = (arr - lo) * (255.0 / span)
    return (scaled.clip(0, 255) + 0.5).astype(np.uint8)


def _png_chunk(tag, payload):
    """Return one PNG chunk: length, tag, payload, CRC-32 of tag+payload."""
    crc = zlib.crc32(tag + payload) & 0xFFFFFFFF
    return (struct.pack('>I', len(payload)) + tag + payload
            + struct.pack('>I', crc))


# PNG colour type per channel count: greyscale, RGB, RGBA.
_PNG_COLOUR_TYPES = {1: 0, 3: 2, 4: 6}


def _write_png(path, image):
    """Write *image* to *path* as an 8-bit PNG using only the stdlib.

    Stand-in for ``scipy.misc.imsave``, which was removed in SciPy 1.2.
    *image* must be 2-D (greyscale) or 3-D with 3 or 4 channels on the
    last axis; values are rescaled with ``_bytescale`` first.

    Raises:
        ValueError: if *image* has any other shape.
    """
    pixels = _bytescale(image)
    if pixels.ndim == 2:
        channels = 1
    elif pixels.ndim == 3 and pixels.shape[2] in (3, 4):
        channels = pixels.shape[2]
    else:
        raise ValueError(
            'cannot save image of shape {}'.format(pixels.shape))
    height, width = pixels.shape[:2]
    rows = np.ascontiguousarray(pixels).reshape(height, width * channels)
    # Every scanline is prefixed with filter type 0 ("None").
    scanlines = np.concatenate(
        [np.zeros((height, 1), dtype=np.uint8), rows], axis=1)
    header = struct.pack('>IIBBBBB', width, height, 8,
                         _PNG_COLOUR_TYPES[channels], 0, 0, 0)
    with open(path, 'wb') as fh:
        fh.write(b'\x89PNG\r\n\x1a\n')
        fh.write(_png_chunk(b'IHDR', header))
        fh.write(_png_chunk(b'IDAT', zlib.compress(scanlines.tobytes())))
        fh.write(_png_chunk(b'IEND', b''))


# NOTE(review): no logging configuration is visible in this script; under
# Python's default WARNING threshold these INFO records are dropped unless
# logging is configured elsewhere. The print() calls below always show.
logger.info('loading dataset ...')
data = load_atari_normalized('breakout-small')
logger.info('done loading')

# Seeds NumPy's global RNG. The frame indices below are fixed, so this does
# not influence which frames are exported.
np.random.seed(0)
idxs = [5, 6, 7, 305711, 244444]
rand_frames = data[idxs]

print('dumping file')
np.savez_compressed('{}/breakout-eval.npz'.format(dataset_path),
                    data=rand_frames)

print('storing some pngs')

create_dir('{}/breakout-eval/'.format(dataset_path))

for n, f in enumerate(rand_frames):
    # np.squeeze drops size-1 axes (presumably a trailing channel axis of
    # size 1 -- TODO confirm frame shape) so _write_png sees a 2-D image.
    _write_png('{}/breakout-eval/frame{}.png'.format(dataset_path, n),
               np.squeeze(f))
print('done')
Exemplo n.º 3
0
import tensorflow as tf

from forkan.datasets import load_atari_normalized
from forkan.models import VAE

# Sweep the VAE's beta over several values on Breakout. Every iteration
# reloads the dataset, builds and trains a new model, then tears it down.
betas = [1.28, 2.0, 3.5, 5.0]

for beta in betas:
    data = load_atari_normalized('breakout')

    v = VAE(
        data.shape[1:],
        name='breakout',
        network='atari',
        lr=1e-4,
        latent_dim=20,
        beta=beta,
    )
    v.train(data, batch_size=128, num_episodes=100, print_freq=200)

    # Teardown before the next beta value: close the model's session,
    # clear TensorFlow's default graph, then release data and model.
    v.s.close()
    tf.reset_default_graph()
    del data, v