Example #1
0
def create_job(kwargs):
    """Run MBGPS on the 'Pendulum-TO-v0' gym env and sample one rollout.

    ``kwargs`` is accepted but never read inside this function.

    Returns:
        A pair ``(obs, act)`` built from the rollout's ``'x'`` and ``'u'``
        entries, each with its last axis squeezed out and then transposed.
    """
    import warnings

    # Installs a global "ignore" filter, so warnings stay silenced afterwards.
    warnings.filterwarnings("ignore")

    # One value feeds both the solver's nb_steps and the sigma stack depth.
    horizon = 300

    pendulum = gym.make('Pendulum-TO-v0')
    pendulum._max_episode_steps = 10000

    # Override attributes on the unwrapped environment, one at a time.
    overrides = (
        ('dt', 0.02),
        ('umax', np.array([2.5])),
        ('periodic', False),
    )
    for attr_name, attr_value in overrides:
        setattr(pendulum.unwrapped, attr_name, attr_value)

    state_dim = pendulum.observation_space.shape[0]
    action_dim = pendulum.action_space.shape[0]

    # The reset state is paired with a small scaled-identity matrix.
    x0 = pendulum.reset()
    solver = MBGPS(pendulum,
                   init_state=(x0, 1e-4 * np.eye(state_dim)),
                   init_action_sigma=25.,
                   nb_steps=horizon,
                   kl_bound=.1,
                   action_penalty=1e-3,
                   activation={'shift': 250, 'mult': 0.5})
    solver.run(nb_iter=100, verbose=False)

    # Replace the controller's sigma with 0.1 * I, stacked once per step,
    # before drawing the rollout with stoch=True from the same reset state.
    solver.ctl.sigma = np.dstack([1e-1 * np.eye(action_dim)] * horizon)
    rollout = solver.rollout(nb_episodes=1, stoch=True, init=x0)

    obs = np.squeeze(rollout['x'], axis=-1).T
    act = np.squeeze(rollout['u'], axis=-1).T
    return obs, act
Example #2
0
# Shared settings for the comparison below.
_nb_steps = 100
_nb_rollouts = 250
_seed = 1337

# Set up and run the MBGPS solver.
mbgps = MBGPS(env,
              nb_steps=_nb_steps,
              init_state=env.init(),
              init_action_sigma=100.,
              kl_bound=5.)
mbgps.run(nb_iter=15, verbose=True)

# Set up and run the Riccati solver with the same number of steps.
riccati = Riccati(env, nb_steps=_nb_steps, init_state=env.init())
riccati.run()

# Re-seed numpy and the env immediately before each batch of rollouts so
# both solvers start from the same seed.
np.random.seed(_seed)
env.seed(_seed)
gps_data = mbgps.rollout(_nb_rollouts, stoch=False)

np.random.seed(_seed)
env.seed(_seed)
riccati_data = riccati.rollout(_nb_rollouts)

# Sum each 'c' array over axis 0, then average the sums.
_gps_cost = np.mean(np.sum(gps_data['c'], axis=0))
_riccati_cost = np.mean(np.sum(riccati_data['c'], axis=0))
print('GPS Cost: ', _gps_cost, ', Riccati Cost', _riccati_cost)

# Figure that the plotting code below draws into.
plt.figure(figsize=(6, 12))
plt.suptitle("LQR Mean Traj.: Riccati vs GPS")

for i in range(dm_state):
    plt.subplot(dm_state + dm_act, 1, i + 1)
    plt.plot(riccati.xref[i, ...], color='k')
    plt.plot(mbgps.xdist.mu[i, ...],