def main():
    """Run GCL-style IRL training on the PointMazeLeft task.

    Loads the latest recorded expert demonstrations from ``data/point``,
    fits a discriminator-based reward model, optimizes a Gaussian MLP
    policy with IRL-TRPO (environment reward zeroed out), then evaluates
    the trained policy on the maze.
    """
    env = TfEnv(CustomGymEnv('PointMazeLeft-v0'))
    demos = load_latest_experts('data/point', n=50)
    discriminator = GCLDiscrim(env_spec=env.spec, expert_trajs=demos)
    pi = GaussianMLPPolicy(name='policy', env_spec=env.spec,
                           hidden_sizes=(32, 32))
    baseline = LinearFeatureBaseline(env_spec=env.spec)
    trainer = IRLTRPO(
        env=env,
        policy=pi,
        irl_model=discriminator,
        n_itr=200,
        batch_size=2000,
        max_path_length=100,
        discount=0.99,
        store_paths=True,
        discrim_train_itrs=50,
        irl_model_wt=1.0,
        entropy_weight=0.1,  # this should be 1.0 but 0.1 seems to work better
        zero_environment_reward=True,
        baseline=baseline)
    with rllab_logdir(algo=trainer, dirname='data/point_gcl'):
        with tf.Session() as sess:
            trainer.train()
            # NOTE(review): sess.run on a policy object is unusual —
            # confirm this is the argument test_pointmaze expects.
            test_pointmaze(sess.run(pi))
def main(num_examples=50, discount=0.99):
    """Train GCL-style IRL on Pendulum-v0 from recorded expert rollouts.

    Args:
        num_examples: how many of the most recent expert trajectories
            under ``data/pendulum`` to load as demonstrations.
        discount: reward discount factor passed through to IRL-TRPO.
    """
    env = TfEnv(GymEnv('Pendulum-v0', record_video=False, record_log=False))
    demos = load_latest_experts('data/pendulum', n=num_examples)
    discriminator = GCLDiscrim(env_spec=env.spec, expert_trajs=demos)
    pi = GaussianMLPPolicy(name='policy', env_spec=env.spec,
                           hidden_sizes=(32, 32))
    baseline = LinearFeatureBaseline(env_spec=env.spec)
    trainer = IRLTRPO(
        env=env,
        policy=pi,
        irl_model=discriminator,
        n_itr=200,
        batch_size=2000,
        max_path_length=100,
        discount=discount,
        store_paths=True,
        discrim_train_itrs=50,
        irl_model_wt=1.0,
        entropy_weight=0.1,  # this should be 1.0 but 0.1 seems to work better
        zero_environment_reward=True,
        baseline=baseline,
    )
    with rllab_logdir(algo=trainer, dirname='data/pendulum_gcl'), tf.Session():
        trainer.train()
def main():
    """Run AIRL training on Ant-v1 with a disentangled discriminator.

    Loads the 50 latest expert trajectories from ``data/ant``, builds a
    GCL discriminator using the ``disentangled_net`` architecture, and
    optimizes a Gaussian MLP policy with IRL-TRPO at Ant-scale settings
    (long horizons, large batches), logging to ``data/ant_airl``.
    """
    env = TfEnv(GymEnv('Ant-v1', record_video=False, record_log=False))
    demos = load_latest_experts('data/ant', n=50)
    discriminator = GCLDiscrim(
        env_spec=env.spec,
        expert_trajs=demos,
        discrim_arch=disentangled_net)
    pi = GaussianMLPPolicy(name='policy', env_spec=env.spec,
                           hidden_sizes=(32, 32))
    baseline = LinearFeatureBaseline(env_spec=env.spec)
    trainer = IRLTRPO(
        env=env,
        policy=pi,
        irl_model=discriminator,
        n_itr=2000,
        batch_size=10000,
        max_path_length=1000,
        discount=0.995,
        store_paths=True,
        discrim_train_itrs=50,
        irl_model_wt=1.0,
        entropy_weight=0.1,
        zero_environment_reward=True,
        baseline=baseline,
    )
    with rllab_logdir(algo=trainer, dirname='data/ant_airl'), tf.Session():
        trainer.train()