Example #1
File: pg.py  Project: jamescasbon/ray
class PGAgent(Agent):
    """Simple policy gradient agent.

    This is an example agent to show how to implement algorithms in RLlib.
    In most cases, you will probably want to use the PPO agent instead.
    """

    _agent_name = "PG"
    _default_config = DEFAULT_CONFIG
    _policy_graph = PGPolicyGraph

    @override(Agent)
    def _init(self):
        self.local_evaluator = self.make_local_evaluator(
            self.env_creator, self._policy_graph)
        self.remote_evaluators = self.make_remote_evaluators(
            self.env_creator, self._policy_graph, self.config["num_workers"])
        optimizer_config = dict(
            self.config["optimizer"],
            **{"train_batch_size": self.config["train_batch_size"]})
        self.optimizer = SyncSamplesOptimizer(
            self.local_evaluator, self.remote_evaluators, optimizer_config)

    @override(Agent)
    def _train(self):
        prev_steps = self.optimizer.num_steps_sampled
        self.optimizer.step()
        result = self.optimizer.collect_metrics(
            self.config["collect_metrics_timeout"])
        result.update(timesteps_this_iter=self.optimizer.num_steps_sampled -
                      prev_steps)
        return result
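A minimal usage sketch for an agent like this, assuming the Ray 0.6-era ray.rllib.agents.pg module that exports PGAgent (the import path and config keys below are assumptions, not shown in the snippet):

import ray
from ray.rllib.agents.pg import PGAgent  # assumed module path for this Ray version

ray.init()
agent = PGAgent(env="CartPole-v0", config={"num_workers": 2})
for _ in range(10):
    result = agent.train()  # each call dispatches to _train() shown above
    print(result["episode_reward_mean"], result["timesteps_this_iter"])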
Example #2
    def testTrainMultiCartpoleManyPolicies(self):
        # Build 20 independent PG policies that all share CartPole's
        # observation and action spaces.
        n = 20
        env = gym.make("CartPole-v0")
        act_space = env.action_space
        obs_space = env.observation_space
        policies = {}
        for i in range(20):
            policies["pg_{}".format(i)] = (PGPolicyGraph, obs_space, act_space,
                                           {})
        policy_ids = list(policies.keys())
        # Each agent in the multi-agent env is mapped to a random policy.
        ev = PolicyEvaluator(
            env_creator=lambda _: MultiCartpole(n),
            policy_graph=policies,
            policy_mapping_fn=lambda agent_id: random.choice(policy_ids),
            batch_steps=100)
        optimizer = SyncSamplesOptimizer(ev, [], {})
        for i in range(100):
            optimizer.step()
            result = collect_metrics(ev)
            print("Iteration {}, rew {}".format(
                i, result["policy_reward_mean"]))
            print("Total reward", result["episode_reward_mean"])
            if result["episode_reward_mean"] >= 25 * n:
                return
        raise Exception("failed to improve reward")
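The mapping above assigns a random policy to every agent; a deterministic round-robin variant, assuming MultiCartpole hands out integer agent ids (a property of that test env, not shown here), could look like this:

# Hypothetical alternative mapping: pin each agent id to one policy.
def round_robin_mapping(agent_id):
    return policy_ids[agent_id % len(policy_ids)]

# ev = PolicyEvaluator(..., policy_mapping_fn=round_robin_mapping, ...)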
Example #3
    class trainer_cls(base):
        _name = name
        _default_config = default_config or COMMON_CONFIG
        _policy = default_policy

        def __init__(self, config=None, env=None, logger_creator=None):
            Trainer.__init__(self, config, env, logger_creator)

        def _init(self, config, env_creator):
            if validate_config:
                validate_config(config)
            if get_initial_state:
                self.state = get_initial_state(self)
            else:
                self.state = {}
            if get_policy_class is None:
                policy = default_policy
            else:
                policy = get_policy_class(config)
            if before_init:
                before_init(self)
            if make_workers:
                self.workers = make_workers(self, env_creator, policy, config)
            else:
                self.workers = self._make_workers(env_creator, policy, config,
                                                  self.config["num_workers"])
            if make_policy_optimizer:
                self.optimizer = make_policy_optimizer(self.workers, config)
            else:
                optimizer_config = dict(
                    config["optimizer"],
                    **{"train_batch_size": config["train_batch_size"]})
                self.optimizer = SyncSamplesOptimizer(self.workers,
                                                      **optimizer_config)
            if after_init:
                after_init(self)

        @override(Trainer)
        def _train(self):
            if before_train_step:
                before_train_step(self)
            prev_steps = self.optimizer.num_steps_sampled

            start = time.time()
            while True:
                fetches = self.optimizer.step()
                if after_optimizer_step:
                    after_optimizer_step(self, fetches)
                if (time.time() - start >= self.config["min_iter_time_s"]
                        and self.optimizer.num_steps_sampled - prev_steps >=
                        self.config["timesteps_per_iteration"]):
                    break

            if collect_metrics_fn:
                res = collect_metrics_fn(self)
            else:
                res = self.collect_metrics()
            res.update(
                timesteps_this_iter=self.optimizer.num_steps_sampled -
                prev_steps,
                info=res.get("info", {}))

            if after_train_result:
                after_train_result(self, res)
            return res

        @override(Trainer)
        def _before_evaluate(self):
            if before_evaluate_fn:
                before_evaluate_fn(self)

        def __getstate__(self):
            state = Trainer.__getstate__(self)
            state["trainer_state"] = self.state.copy()
            return state

        def __setstate__(self, state):
            Trainer.__setstate__(self, state)
            self.state = state["trainer_state"].copy()
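The free variables above (name, default_config, default_policy, validate_config, and the various callbacks) are arguments of the surrounding factory function that returns trainer_cls. A minimal sketch of using such a factory, assuming RLlib's build_trainer from ray.rllib.agents.trainer_template and the PG TF policy of that era (both import paths are assumptions):

import ray
from ray import tune
from ray.rllib.agents.trainer_template import build_trainer  # assumed path
from ray.rllib.agents.pg.pg_tf_policy import PGTFPolicy      # assumed path

# build_trainer returns a Trainer subclass like trainer_cls above;
# with no default_config given, it falls back to COMMON_CONFIG.
MyPGTrainer = build_trainer(
    name="MyPG",
    default_policy=PGTFPolicy)

ray.init()
tune.run(
    MyPGTrainer,
    config={"env": "CartPole-v0", "num_workers": 1},
    stop={"training_iteration": 3})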
Example #4
    class trainer_cls(base):
        _name = name
        _default_config = default_config or COMMON_CONFIG
        _policy = default_policy

        def __init__(self, config=None, env=None, logger_creator=None):
            Trainer.__init__(self, config, env, logger_creator)

        def _init(self, config, env_creator):
            if validate_config:
                validate_config(config)

            if get_initial_state:
                self.state = get_initial_state(self)
            else:
                self.state = {}
            if get_policy_class is None:
                self._policy = default_policy
            else:
                self._policy = get_policy_class(config)
            if before_init:
                before_init(self)
            use_exec_api = (execution_plan
                            and (self.config["use_exec_api"]
                                 or "RLLIB_EXEC_API" in os.environ))

            # Creating all workers (excluding evaluation workers).
            if make_workers and not use_exec_api:
                self.workers = make_workers(self, env_creator, self._policy,
                                            config)
            else:
                self.workers = self._make_workers(env_creator, self._policy,
                                                  config,
                                                  self.config["num_workers"])
            self.train_exec_impl = None
            self.optimizer = None
            self.execution_plan = execution_plan

            if use_exec_api:
                logger.warning(
                    "The experimental distributed execution API is enabled "
                    "for this algorithm. Disable this by setting "
                    "'use_exec_api': False.")
                self.train_exec_impl = execution_plan(self.workers, config)
            elif make_policy_optimizer:
                self.optimizer = make_policy_optimizer(self.workers, config)
            else:
                optimizer_config = dict(
                    config["optimizer"],
                    **{"train_batch_size": config["train_batch_size"]})
                self.optimizer = SyncSamplesOptimizer(self.workers,
                                                      **optimizer_config)
            if after_init:
                after_init(self)

        @override(Trainer)
        def _train(self):
            if self.train_exec_impl:
                return self._train_exec_impl()

            if before_train_step:
                before_train_step(self)
            prev_steps = self.optimizer.num_steps_sampled

            start = time.time()
            optimizer_steps_this_iter = 0
            while True:
                fetches = self.optimizer.step()
                optimizer_steps_this_iter += 1
                if after_optimizer_step:
                    after_optimizer_step(self, fetches)
                if (time.time() - start >= self.config["min_iter_time_s"]
                        and self.optimizer.num_steps_sampled - prev_steps >=
                        self.config["timesteps_per_iteration"]):
                    break

            if collect_metrics_fn:
                res = collect_metrics_fn(self)
            else:
                res = self.collect_metrics()
            res.update(
                optimizer_steps_this_iter=optimizer_steps_this_iter,
                timesteps_this_iter=self.optimizer.num_steps_sampled -
                prev_steps,
                info=res.get("info", {}))

            if after_train_result:
                after_train_result(self, res)
            return res

        def _train_exec_impl(self):
            if before_train_step:
                logger.warning("Ignoring before_train_step callback")
            res = next(self.train_exec_impl)
            if after_train_result:
                logger.warning("Ignoring after_train_result callback")
            return res

        @override(Trainer)
        def _before_evaluate(self):
            if before_evaluate_fn:
                before_evaluate_fn(self)

        def __getstate__(self):
            state = Trainer.__getstate__(self)
            state["trainer_state"] = self.state.copy()
            if self.train_exec_impl:
                state["train_exec_impl"] = (
                    self.train_exec_impl.shared_metrics.get().save())
            return state

        def __setstate__(self, state):
            Trainer.__setstate__(self, state)
            self.state = state["trainer_state"].copy()
            if self.train_exec_impl:
                self.train_exec_impl.shared_metrics.get().restore(
                    state["train_exec_impl"])
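Per the checks in _init above, the experimental execution-plan path is taken only when the algorithm defines an execution_plan and either the "use_exec_api" config key is true or the RLLIB_EXEC_API environment variable is set. A sketch of both toggles (the key and variable names are taken directly from the snippet; everything else is illustrative):

import os

# Option 1: via config (only takes effect if the trainer defines an execution_plan).
config = {"env": "CartPole-v0", "num_workers": 2, "use_exec_api": True}

# Option 2: via environment variable; the code only checks for the key's presence.
os.environ["RLLIB_EXEC_API"] = "1"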