Пример #1
0
        def build_pend_env(args, **kwargs):
            """Create the classic-control vector env and wrap it in a VecVAEStack.

            Args:
                args: Parsed run arguments; reads ``alg``, ``env``, ``seed``,
                    ``num_env`` and ``reward_scale``.
                **kwargs: Accepted for call-site compatibility; not used.

            Returns:
                The ``VecVAEStack`` wrapping the env built by ``make_vec_env``.

            ``k`` and ``vae_name`` are not parameters: they are free names
            resolved in the scope that surrounds this definition.
            """
            # Dict observations stay un-flattened only when the algorithm is HER.
            keep_flat = args.alg not in {'her'}
            vec_env = make_vec_env(
                args.env,
                'classic_control',
                args.num_env or 1,  # a missing/zero env count falls back to 1
                args.seed,
                reward_scale=args.reward_scale,
                flatten_dict_observations=keep_flat,
            )
            # The VAE checkpoint name is passed with every '/' replaced by ':'.
            return VecVAEStack(vec_env, k=k, load_from=vae_name.replace('/', ':'))
Пример #2
0
        def build_pend_env(args, **kwargs):
            """Create an 'atari'-type vector env and wrap it in an atari VecVAEStack.

            Args:
                args: Parsed run arguments; reads ``alg``, ``env``, ``seed``,
                    ``num_env`` and ``reward_scale``.
                **kwargs: Accepted for call-site compatibility; not used.

            Returns:
                The ``VecVAEStack`` (``vae_network='atari'``) wrapping the env
                built by ``make_vec_env``.

            ``k`` and ``vae_name`` are free names resolved in the scope that
            surrounds this definition.
            """
            # Only HER keeps dict observations un-flattened.
            wants_flat_obs = args.alg not in {'her'}
            vec_env = make_vec_env(
                args.env,
                'atari',
                args.num_env or 1,  # falsy env count -> a single env
                args.seed,
                reward_scale=args.reward_scale,
                flatten_dict_observations=wants_flat_obs,
            )
            # NOTE(review): norm_fac=1/255 presumably rescales 0-255 pixel values
            # before encoding - confirm against VecVAEStack.
            return VecVAEStack(
                vec_env,
                k=k,
                load_from=vae_name.replace('/', ':'),
                vae_network='atari',
                norm_fac=(1 / 255),
            )
Пример #3
0
def build_pend_env(args, **kwargs):
    """Create an 'atari'-type vector env and wrap it in a 3-deep VecVAEStack.

    Args:
        args: Parsed run arguments; reads ``alg``, ``env``, ``seed``,
            ``num_env`` and ``reward_scale``.
        **kwargs: Accepted for call-site compatibility; not used.

    Returns:
        The ``VecVAEStack`` wrapping the env built by ``make_vec_env``.
    """
    # HER is the only algorithm that keeps dict observations un-flattened.
    flatten_obs = args.alg not in {'her'}
    atari_env = make_vec_env(args.env, 'atari', args.num_env or 1, args.seed,
                             reward_scale=args.reward_scale,
                             flatten_dict_observations=flatten_obs)
    # vae_name is a module-level name and is passed through unchanged
    # (no '/' -> ':' substitution happens here).
    return VecVAEStack(atari_env, k=3, load_from=vae_name)
Пример #4
0
    def build_pend_env(args, **kwargs):
        """Create the classic-control vector env and wrap it in a VecVAEStack.

        The VAE checkpoint is hard-coded to the pendulum run below.

        Args:
            args: Parsed run arguments; reads ``alg``, ``env``, ``seed``,
                ``num_env`` and ``reward_scale``.
            **kwargs: Accepted for call-site compatibility; not used.

        Returns:
            The ``VecVAEStack`` (``k=3``) wrapping the env built by
            ``make_vec_env``.
        """
        # Hard-coded pretrained VAE run; every '/' in its name becomes ':'.
        checkpoint = ('pendvisualuniform-b77.5-lat5-lr0.001-2019-03-21T00/13'
                      .replace('/', ':'))
        # Dict observations stay un-flattened only for HER.
        is_flat = args.alg not in {'her'}
        pend_env = make_vec_env(
            args.env,
            'classic_control',
            args.num_env or 1,  # falsy env count -> a single env
            args.seed,
            reward_scale=args.reward_scale,
            flatten_dict_observations=is_flat,
        )
        return VecVAEStack(pend_env, k=3, load_from=checkpoint)
Пример #5
0
import numpy as np
from baselines.common.cmd_util import make_vec_env

from gym.envs.classic_control.pendulum_test import PendulumTestEnv
import matplotlib.pyplot as plt
import seaborn as sns

# Pretrained VAE checkpoint; every '/' in the run name is replaced with ':'.
vae_name = ('pendvisualuniform-b77.5-lat5-lr0.001-2019-03-21T00/13'
            .replace('/', ':'))

# Two PendulumTest-v0 sub-envs (num_env=2), seed 0, with dict observations
# left unflattened.
env = make_vec_env('PendulumTest-v0', 'classic_control', 2, 0,
                   flatten_dict_observations=False)
# NOTE(review): VecVAEStack is not among this script's imports, so this line
# raises NameError unless the name is provided some other way - confirm.
venv = VecVAEStack(env, k=3, load_from=vae_name)

# 20 evenly spaced angles over [0, 2*pi], both endpoints included.
thetas = np.linspace(0, 2 * np.pi, 20)

# Accumulator lists (not used in the visible part of this script).
dings1, dings2 = [], []

# Loop state: nothing is "done" yet; ``o`` is the initial observation batch.
d = False
o = venv.reset()

# Keep looping while no entry of ``d`` is truthy. ``d`` presumably becomes the
# per-env done flags later in the loop - confirm.
while not np.any(d):
    print(o)
    # NOTE(review): unconditional debug exit. The script stops right after
    # printing the first observation, so nothing after this line in the loop
    # ever runs. If this stays, prefer sys.exit(0); the bare ``exit`` is the
    # ``site``-module helper intended for the interactive shell.
    exit(0)
    # when using env obs, we can plot to sanity check theta
    # (NOTE(review): ``th`` below is not defined in the visible script)
    # plt.imshow(np.squeeze(o), cmap='Greys_r', label='theta {}'.format(th))
    # plt.title('{}'.format(th))