class PGAgent(Agent):
    """Simple policy gradient agent.

    This is an example agent to show how to implement algorithms in RLlib.
    In most cases, you will probably want to use the PPO agent instead.
    """

    _agent_name = "PG"
    _default_config = DEFAULT_CONFIG
    _policy_graph = PGPolicyGraph

    @override(Agent)
    def _init(self):
        self.local_evaluator = self.make_local_evaluator(
            self.env_creator, self._policy_graph)
        self.remote_evaluators = self.make_remote_evaluators(
            self.env_creator, self._policy_graph, self.config["num_workers"])
        optimizer_config = dict(
            self.config["optimizer"],
            **{"train_batch_size": self.config["train_batch_size"]})
        self.optimizer = SyncSamplesOptimizer(
            self.local_evaluator, self.remote_evaluators, optimizer_config)

    @override(Agent)
    def _train(self):
        prev_steps = self.optimizer.num_steps_sampled
        self.optimizer.step()
        result = self.optimizer.collect_metrics(
            self.config["collect_metrics_timeout"])
        result.update(timesteps_this_iter=self.optimizer.num_steps_sampled -
                      prev_steps)
        return result
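An agent like this is driven through the standard Agent API: construct it with an env and config, then call train() in a loop, where each call runs _train() once. A minimal usage sketch; the import path and config values are assumptions based on RLlib releases of this era, not taken from the excerpt above:

import ray
from ray.rllib.agents.pg import PGAgent  # import path is an assumption

ray.init()
agent = PGAgent(env="CartPole-v0", config={"num_workers": 2})
for _ in range(10):
    # Each train() call samples one batch via the optimizer and
    # returns a result dict including the metrics set in _train().
    result = agent.train()
    print(result["episode_reward_mean"], result["timesteps_this_iter"])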
def testTrainMultiCartpoleManyPolicies(self):
    n = 20
    env = gym.make("CartPole-v0")
    act_space = env.action_space
    obs_space = env.observation_space
    policies = {}
    for i in range(20):
        policies["pg_{}".format(i)] = (PGPolicyGraph, obs_space, act_space,
                                       {})
    policy_ids = list(policies.keys())
    ev = PolicyEvaluator(
        env_creator=lambda _: MultiCartpole(n),
        policy_graph=policies,
        policy_mapping_fn=lambda agent_id: random.choice(policy_ids),
        batch_steps=100)
    optimizer = SyncSamplesOptimizer(ev, [], {})
    for i in range(100):
        optimizer.step()
        result = collect_metrics(ev)
        print("Iteration {}, rew {}".format(i, result["policy_reward_mean"]))
        print("Total reward", result["episode_reward_mean"])
        if result["episode_reward_mean"] >= 25 * n:
            return
    raise Exception("failed to improve reward")
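The test relies on MultiCartpole, an n-agent environment following RLlib's MultiAgentEnv contract: observations, rewards, and dones are dicts keyed by agent id, with a special "__all__" done flag. A sketch of that contract, not RLlib's actual MultiCartpole implementation:

import gym
from ray.rllib.env.multi_agent_env import MultiAgentEnv

class MultiCartpoleSketch(MultiAgentEnv):
    """Illustrative n-agent cartpole: each sub-env is an independent
    CartPole-v0, and all per-agent values travel in dicts."""

    def __init__(self, num):
        self.agents = [gym.make("CartPole-v0") for _ in range(num)]
        self.dones = set()

    def reset(self):
        self.dones = set()
        return {i: env.reset() for i, env in enumerate(self.agents)}

    def step(self, action_dict):
        obs, rew, done = {}, {}, {}
        for i, action in action_dict.items():
            obs[i], rew[i], done[i], _ = self.agents[i].step(action)
            if done[i]:
                self.dones.add(i)
        # The episode ends only once every sub-env has terminated.
        done["__all__"] = len(self.dones) == len(self.agents)
        return obs, rew, done, {}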
class trainer_cls(base):
    _name = name
    _default_config = default_config or COMMON_CONFIG
    _policy = default_policy

    def __init__(self, config=None, env=None, logger_creator=None):
        Trainer.__init__(self, config, env, logger_creator)

    def _init(self, config, env_creator):
        if validate_config:
            validate_config(config)
        if get_initial_state:
            self.state = get_initial_state(self)
        else:
            self.state = {}
        if get_policy_class is None:
            policy = default_policy
        else:
            policy = get_policy_class(config)
        if before_init:
            before_init(self)
        if make_workers:
            self.workers = make_workers(self, env_creator, policy, config)
        else:
            self.workers = self._make_workers(env_creator, policy, config,
                                              self.config["num_workers"])
        if make_policy_optimizer:
            self.optimizer = make_policy_optimizer(self.workers, config)
        else:
            optimizer_config = dict(
                config["optimizer"],
                **{"train_batch_size": config["train_batch_size"]})
            self.optimizer = SyncSamplesOptimizer(self.workers,
                                                  **optimizer_config)
        if after_init:
            after_init(self)

    @override(Trainer)
    def _train(self):
        if before_train_step:
            before_train_step(self)
        prev_steps = self.optimizer.num_steps_sampled
        start = time.time()
        while True:
            fetches = self.optimizer.step()
            if after_optimizer_step:
                after_optimizer_step(self, fetches)
            if (time.time() - start >= self.config["min_iter_time_s"]
                    and self.optimizer.num_steps_sampled - prev_steps >=
                    self.config["timesteps_per_iteration"]):
                break

        if collect_metrics_fn:
            res = collect_metrics_fn(self)
        else:
            res = self.collect_metrics()
        res.update(
            timesteps_this_iter=self.optimizer.num_steps_sampled -
            prev_steps,
            info=res.get("info", {}))
        if after_train_result:
            after_train_result(self, res)
        return res

    @override(Trainer)
    def _before_evaluate(self):
        if before_evaluate_fn:
            before_evaluate_fn(self)

    def __getstate__(self):
        state = Trainer.__getstate__(self)
        state["trainer_state"] = self.state.copy()
        return state

    def __setstate__(self, state):
        Trainer.__setstate__(self, state)
        self.state = state["trainer_state"].copy()
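This class is the body of RLlib's build_trainer template: all of the lower-case names (name, default_policy, validate_config, before_init, and so on) are arguments captured from the enclosing build_trainer call, and any hook left as None is skipped. A hedged sketch of how a custom algorithm would be assembled from it; the policy import path and config values are assumptions, and only a subset of the hooks is shown:

from ray.rllib.agents.trainer_template import build_trainer
from ray.rllib.agents.pg.pg_tf_policy import PGTFPolicy  # path varies by version

def validate_config(config):
    # Illustrative hook: reject configs the algorithm cannot run with.
    if config["train_batch_size"] <= 0:
        raise ValueError("train_batch_size must be positive")

# default_config is omitted, so the template falls back to COMMON_CONFIG.
MyTrainer = build_trainer(
    name="MyPG",
    default_policy=PGTFPolicy,
    validate_config=validate_config)

trainer = MyTrainer(env="CartPole-v0", config={"num_workers": 0})
print(trainer.train()["episode_reward_mean"])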
class trainer_cls(base):
    _name = name
    _default_config = default_config or COMMON_CONFIG
    _policy = default_policy

    def __init__(self, config=None, env=None, logger_creator=None):
        Trainer.__init__(self, config, env, logger_creator)

    def _init(self, config, env_creator):
        if validate_config:
            validate_config(config)
        if get_initial_state:
            self.state = get_initial_state(self)
        else:
            self.state = {}
        if get_policy_class is None:
            self._policy = default_policy
        else:
            self._policy = get_policy_class(config)
        if before_init:
            before_init(self)
        use_exec_api = (execution_plan
                        and (self.config["use_exec_api"]
                             or "RLLIB_EXEC_API" in os.environ))

        # Creating all workers (excluding evaluation workers).
        if make_workers and not use_exec_api:
            self.workers = make_workers(self, env_creator, self._policy,
                                        config)
        else:
            self.workers = self._make_workers(env_creator, self._policy,
                                              config,
                                              self.config["num_workers"])
        self.train_exec_impl = None
        self.optimizer = None
        self.execution_plan = execution_plan

        if use_exec_api:
            logger.warning(
                "The experimental distributed execution API is enabled "
                "for this algorithm. Disable this by setting "
                "'use_exec_api': False.")
            self.train_exec_impl = execution_plan(self.workers, config)
        elif make_policy_optimizer:
            self.optimizer = make_policy_optimizer(self.workers, config)
        else:
            optimizer_config = dict(
                config["optimizer"],
                **{"train_batch_size": config["train_batch_size"]})
            self.optimizer = SyncSamplesOptimizer(self.workers,
                                                  **optimizer_config)
        if after_init:
            after_init(self)

    @override(Trainer)
    def _train(self):
        if self.train_exec_impl:
            return self._train_exec_impl()
        if before_train_step:
            before_train_step(self)
        prev_steps = self.optimizer.num_steps_sampled
        start = time.time()
        optimizer_steps_this_iter = 0
        while True:
            fetches = self.optimizer.step()
            optimizer_steps_this_iter += 1
            if after_optimizer_step:
                after_optimizer_step(self, fetches)
            if (time.time() - start >= self.config["min_iter_time_s"]
                    and self.optimizer.num_steps_sampled - prev_steps >=
                    self.config["timesteps_per_iteration"]):
                break

        if collect_metrics_fn:
            res = collect_metrics_fn(self)
        else:
            res = self.collect_metrics()
        res.update(
            optimizer_steps_this_iter=optimizer_steps_this_iter,
            timesteps_this_iter=self.optimizer.num_steps_sampled -
            prev_steps,
            info=res.get("info", {}))
        if after_train_result:
            after_train_result(self, res)
        return res

    def _train_exec_impl(self):
        if before_train_step:
            logger.warning("Ignoring before_train_step callback")
        res = next(self.train_exec_impl)
        if after_train_result:
            logger.warning("Ignoring after_train_result callback")
        return res

    @override(Trainer)
    def _before_evaluate(self):
        if before_evaluate_fn:
            before_evaluate_fn(self)

    def __getstate__(self):
        state = Trainer.__getstate__(self)
        state["trainer_state"] = self.state.copy()
        if self.train_exec_impl:
            state["train_exec_impl"] = (
                self.train_exec_impl.shared_metrics.get().save())
        return state

    def __setstate__(self, state):
        Trainer.__setstate__(self, state)
        self.state = state["trainer_state"].copy()
        if self.train_exec_impl:
            self.train_exec_impl.shared_metrics.get().restore(
                state["train_exec_impl"])
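When the execution API is enabled, the template hands control to an execution_plan function instead of building a policy optimizer: the plan returns an iterator over result dicts, which _train_exec_impl pulls with next(). A sketch of the kind of plan the template expects, modeled on RLlib's simple synchronous plans; the operators are from the ray.rllib.execution package, but treat the exact composition as illustrative:

from ray.rllib.execution.rollout_ops import ParallelRollouts
from ray.rllib.execution.train_ops import TrainOneStep
from ray.rllib.execution.metric_ops import StandardMetricsReporting

def execution_plan(workers, config):
    # Collect experience batches from all rollout workers in
    # bulk-synchronous mode.
    rollouts = ParallelRollouts(workers, mode="bulk_sync")
    # Apply one gradient update per collected batch.
    train_op = rollouts.for_each(TrainOneStep(workers))
    # Wrap the op so that iterating it yields training result dicts,
    # matching what _train_exec_impl expects from next().
    return StandardMetricsReporting(train_op, workers, config)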