Example #1
# Assumed module layout (matching OpenAI's safety-starter-agents):
from safe_rl.pg.agents import TRPOAgent
from safe_rl.pg.run_agent import run_polopt_agent

def trpo_lagrangian(**kwargs):
    # Objective-penalized form of Lagrangian TRPO: the penalty coefficient
    # is learned and applied to the surrogate objective, not to the reward.
    trpo_kwargs = dict(reward_penalized=False,
                       objective_penalized=True,
                       learn_penalty=True,
                       penalty_param_loss=True)
    agent = TRPOAgent(**trpo_kwargs)
    run_polopt_agent(agent=agent, **kwargs)
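
A minimal invocation might look like the following sketch; the environment ID and hyperparameter values are illustrative assumptions, not taken from the snippet (Safety Gym must be installed for the Safexp-* IDs to resolve):

import gym
import safety_gym  # noqa: F401 -- importing registers the Safexp-* environments

# Hypothetical usage: all keyword arguments are forwarded to
# run_polopt_agent (see Example #5 for the full parameter list).
trpo_lagrangian(
    env_fn=lambda: gym.make('Safexp-PointGoal1-v0'),
    cost_lim=25,
    target_kl=0.01,
    seed=0,
)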
Example #2
def ppo(**kwargs):
    # Unconstrained PPO: all penalty machinery disabled.
    ppo_kwargs = dict(
        reward_penalized=False,
        objective_penalized=False,
        learn_penalty=False,
        penalty_param_loss=False  # Irrelevant in unconstrained
    )
    agent = PPOAgent(**ppo_kwargs)
    run_polopt_agent(agent=agent, **kwargs)
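
With every flag disabled, this is plain unconstrained PPO; penalty_param_loss is passed anyway (and marked irrelevant) so that these wrappers differ only in the four boolean settings handed to the agent.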
Example #3
def cpo(**kwargs):
    # Constrained Policy Optimization (CPO).
    cpo_kwargs = dict(
        reward_penalized=False,  # Irrelevant in CPO
        objective_penalized=False,  # Irrelevant in CPO
        learn_penalty=False,  # Irrelevant in CPO
        penalty_param_loss=False  # Irrelevant in CPO
    )
    agent = CPOAgent(**cpo_kwargs)
    run_polopt_agent(agent=agent, **kwargs)
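
CPO enforces the cost constraint directly inside its trust-region update, so it needs neither a reward nor an objective penalty; presumably the four flags are accepted only so that CPOAgent conforms to the shared agent interface, which is why each one is marked irrelevant.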
Example #4
def ppo_lagrangian(pi_iters, **kwargs):
    # Objective-penalized form of Lagrangian PPO.
    ppo_kwargs = dict(
        pi_iters=pi_iters,
        reward_penalized=False,
        objective_penalized=True,
        learn_penalty=True,
        penalty_param_loss=True,
    )
    agent = PPOAgent(**ppo_kwargs)
    run_polopt_agent(agent=agent, **kwargs)
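
This mirrors the Lagrangian TRPO wrapper in Example #1, with one addition: pi_iters (presumably the number of policy gradient steps taken per epoch) is exposed as a top-level argument and threaded into PPOAgent alongside the penalty flags.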
Example #5
def learn(
        self,
        # Experience collection:
        steps_per_epoch=4000,
        epochs=50,
        max_ep_len=1000,
        # Discount factors:
        gamma=0.99,
        lam=0.97,
        cost_gamma=0.99,
        cost_lam=0.97,
        # Policy learning:
        ent_reg=0.,
        # Cost constraints / penalties:
        cost_lim=25,
        penalty_init=1.,
        penalty_lr=5e-2,
        # KL divergence:
        target_kl=0.01,
        # Value learning:
        vf_lr=1e-3,
        vf_iters=80):
    # Delegate training to run_polopt_agent and keep the handles it
    # returns: the TF session, the sampled (pi) and mean (mu) action
    # ops, and the observation placeholder (x_ph).
    self.sess, self.pi, self.mu, self.x_ph = run_polopt_agent(
        env_fn=self.env_fn,
        agent=self.agent,
        steps_per_epoch=steps_per_epoch,
        epochs=epochs,
        max_ep_len=max_ep_len,
        gamma=gamma,
        lam=lam,
        cost_gamma=cost_gamma,
        cost_lam=cost_lam,
        ent_reg=ent_reg,
        cost_lim=cost_lim,
        penalty_init=penalty_init,
        penalty_lr=penalty_lr,
        target_kl=target_kl,
        vf_lr=vf_lr,
        vf_iters=vf_iters,
        actor_critic=self.actor_critic_fn,
        ac_kwargs=self.ac_kwargs,
        seed=self.seed,
        render=self.render,
        logger=self.logger,
        logger_kwargs=self.logger_kwargs,
        save_freq=self.save_freq)
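
Once learn() has stored those handles, they can drive action selection at test time. A minimal sketch, assuming TensorFlow 1.x semantics where x_ph is the observation placeholder and pi/mu are the sampled and mean action ops (the act helper itself is hypothetical, not part of the class above):

import numpy as np

def act(self, obs, deterministic=False):
    # Hypothetical helper: batch a single observation and fetch either
    # the mean action (mu) or a sampled action (pi) from the session.
    obs = np.asarray(obs, dtype=np.float32)
    op = self.mu if deterministic else self.pi
    return self.sess.run(op, feed_dict={self.x_ph: obs[None, :]})[0]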