def create_job(kwargs):
    """Train an MBGPS controller on the pendulum task and sample one rollout.

    Builds the ``Pendulum-TO-v0`` environment and configures it:
    ``dt=0.02``, ``umax=2.5``, non-periodic, ``_max_episode_steps=10000``.
    It then optimises an MBGPS policy from the environment's reset state for
    100 iterations. Next it fixes the controller's action covariance to
    ``0.1 * I`` at every time step. Finally it collects a single stochastic
    episode from that same reset state.

    Args:
        kwargs: Unused by this function; kept so the job signature stays
            compatible with callers that pass per-job arguments.

    Returns:
        Tuple ``(obs, act)``: the rollout's state (``data['x']``) and action
        (``data['u']``) arrays, each with the trailing axis squeezed out and
        then transposed.
        NOTE(review): presumably shaped ``(time, dm_state)`` and
        ``(time, dm_act)`` — confirm against ``MBGPS.rollout``.
    """
    import warnings

    # Planning horizon. The solver and the per-step action covariance both
    # use this single value, so they cannot drift apart.
    nb_steps = 300

    # Suppress warnings only while this job runs. The original call
    # permanently mutated the process-wide warning filters.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")

        # pendulum env
        env = gym.make('Pendulum-TO-v0')
        env._max_episode_steps = 10000
        env.unwrapped.dt = 0.02
        env.unwrapped.umax = np.array([2.5])
        env.unwrapped.periodic = False

        dm_state = env.observation_space.shape[0]
        dm_act = env.action_space.shape[0]

        state = env.reset()
        # Initial state distribution: reset state with a tiny isotropic
        # covariance.
        init_state = (state, 1e-4 * np.eye(dm_state))

        solver = MBGPS(env, init_state=init_state,
                       init_action_sigma=25., nb_steps=nb_steps,
                       kl_bound=.1, action_penalty=1e-3,
                       activation={'shift': 250, 'mult': 0.5})
        solver.run(nb_iter=100, verbose=False)

        # Replace the learned action covariance with a fixed 0.1 * I,
        # stacked along a third axis, one slice per time step.
        solver.ctl.sigma = np.dstack([1e-1 * np.eye(dm_act)] * nb_steps)

        data = solver.rollout(nb_episodes=1, stoch=True, init=state)
        obs = np.squeeze(data['x'], axis=-1).T
        act = np.squeeze(data['u'], axis=-1).T

    return obs, act
mbgps = MBGPS(env, nb_steps=100, init_state=env.init(), init_action_sigma=100., kl_bound=5.) mbgps.run(nb_iter=15, verbose=True) riccati = Riccati(env, nb_steps=100, init_state=env.init()) riccati.run() np.random.seed(1337) env.seed(1337) gps_data = mbgps.rollout(250, stoch=False) np.random.seed(1337) env.seed(1337) riccati_data = riccati.rollout(250) print('GPS Cost: ', np.mean(np.sum(gps_data['c'], axis=0)), ', Riccati Cost', np.mean(np.sum(riccati_data['c'], axis=0))) plt.figure(figsize=(6, 12)) plt.suptitle("LQR Mean Traj.: Riccati vs GPS") for i in range(dm_state): plt.subplot(dm_state + dm_act, 1, i + 1) plt.plot(riccati.xref[i, ...], color='k') plt.plot(mbgps.xdist.mu[i, ...],