Example #1
class A3CAgent(Agent):
    """A3C implementations in TensorFlow and PyTorch."""

    _agent_name = "A3C"
    _default_config = DEFAULT_CONFIG

    @classmethod
    def default_resource_request(cls, config):
        cf = dict(cls._default_config, **config)
        # Reserve one extra CPU per rollout worker and, when
        # use_gpu_for_workers is set, one extra GPU per worker as well.
        return Resources(
            cpu=1,
            gpu=0,
            extra_cpu=cf["num_workers"],
            extra_gpu=cf["use_gpu_for_workers"] and cf["num_workers"] or 0)

    def _init(self):
        # Pick the PyTorch or TensorFlow policy graph based on the config.
        if self.config["use_pytorch"]:
            from ray.rllib.agents.a3c.a3c_torch_policy import \
                A3CTorchPolicyGraph
            policy_cls = A3CTorchPolicyGraph
        else:
            from ray.rllib.agents.a3c.a3c_tf_policy import A3CPolicyGraph
            policy_cls = A3CPolicyGraph

        # One local evaluator on the driver plus num_workers remote
        # evaluators, all fed into an asynchronous gradients optimizer.
        self.local_evaluator = self.make_local_evaluator(
            self.env_creator, policy_cls)
        self.remote_evaluators = self.make_remote_evaluators(
            self.env_creator, policy_cls, self.config["num_workers"],
            {"num_gpus": 1 if self.config["use_gpu_for_workers"] else 0})
        self.optimizer = AsyncGradientsOptimizer(self.config["optimizer"],
                                                 self.local_evaluator,
                                                 self.remote_evaluators)

    def _train(self):
        # Run one round of asynchronous gradient updates, sync the
        # observation filters out to the workers, then report metrics.
        self.optimizer.step()
        FilterManager.synchronize(self.local_evaluator.filters,
                                  self.remote_evaluators)
        result = collect_metrics(self.local_evaluator, self.remote_evaluators)
        result = result._replace(info=self.optimizer.stats())
        return result

    def _stop(self):
        # workaround for https://github.com/ray-project/ray/issues/1516
        for ev in self.remote_evaluators:
            ev.__ray_terminate__.remote()

    def _save(self, checkpoint_dir):
        # Persist the remote workers' state and the local evaluator's state
        # next to the regular checkpoint file.
        checkpoint_path = os.path.join(checkpoint_dir,
                                       "checkpoint-{}".format(self.iteration))
        agent_state = ray.get(
            [a.save.remote() for a in self.remote_evaluators])
        extra_data = {
            "remote_state": agent_state,
            "local_state": self.local_evaluator.save()
        }
        pickle.dump(extra_data, open(checkpoint_path + ".extra_data", "wb"))
        return checkpoint_path

    def _restore(self, checkpoint_path):
        # Reload the remote workers' state first, then the local evaluator.
        extra_data = pickle.load(open(checkpoint_path + ".extra_data", "rb"))
        ray.get([
            a.restore.remote(o)
            for a, o in zip(self.remote_evaluators, extra_data["remote_state"])
        ])
        self.local_evaluator.restore(extra_data["local_state"])
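The class above follows Tune's Trainable interface, so a driver script only needs train(), save(), and restore(). Below is a minimal driving-loop sketch, not taken from the project: it assumes A3CAgent and DEFAULT_CONFIG are importable from ray.rllib.agents.a3c as in RLlib releases of this period, and the environment name and worker count are placeholder values.

import ray
from ray.rllib.agents.a3c import A3CAgent, DEFAULT_CONFIG  # assumed import path

ray.init()

# Illustrative settings: two CPU-only rollout workers.
config = dict(DEFAULT_CONFIG, num_workers=2, use_gpu_for_workers=False)
agent = A3CAgent(config=config, env="CartPole-v0")

for _ in range(5):
    result = agent.train()    # drives _train() defined above
    print(result)

checkpoint = agent.save()     # calls _save() to dump the evaluator state
agent.restore(checkpoint)     # calls _restore() to reload it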
Example #2
File: a3c.py Project: velconia/ray
class A3CAgent(Agent):
    _agent_name = "A3C"
    _default_config = DEFAULT_CONFIG

    @classmethod
    def default_resource_request(cls, config):
        cf = dict(cls._default_config, **config)
        return Resources(
            cpu=1,
            gpu=0,
            extra_cpu=cf["num_workers"],
            extra_gpu=cf["use_gpu_for_workers"] and cf["num_workers"] or 0)

    def _init(self):
        if self.config["use_pytorch"]:
            from ray.rllib.a3c.a3c_torch_policy import A3CTorchPolicyGraph
            self.policy_cls = A3CTorchPolicyGraph
        else:
            from ray.rllib.a3c.a3c_tf_policy import A3CPolicyGraph
            self.policy_cls = A3CPolicyGraph

        if self.config["use_pytorch"]:
            session_creator = None
        else:
            import tensorflow as tf

            # Limit each worker's TensorFlow session to single-threaded op
            # execution and allow GPU memory to grow on demand.
            def session_creator():
                return tf.Session(
                    config=tf.ConfigProto(intra_op_parallelism_threads=1,
                                          inter_op_parallelism_threads=1,
                                          gpu_options=tf.GPUOptions(
                                              allow_growth=True)))

        remote_cls = CommonPolicyEvaluator.as_remote(
            num_gpus=1 if self.config["use_gpu_for_workers"] else 0)
        # Configured multiagent policy graphs take precedence over the
        # single default policy class.
        self.local_evaluator = CommonPolicyEvaluator(
            self.env_creator,
            self.config["multiagent"]["policy_graphs"] or self.policy_cls,
            policy_mapping_fn=self.config["multiagent"]["policy_mapping_fn"],
            batch_steps=self.config["batch_size"],
            batch_mode="truncate_episodes",
            tf_session_creator=session_creator,
            env_config=self.config["env_config"],
            model_config=self.config["model"],
            policy_config=self.config,
            num_envs=self.config["num_envs"])
        self.remote_evaluators = [
            remote_cls.remote(
                self.env_creator,
                self.config["multiagent"]["policy_graphs"] or self.policy_cls,
                policy_mapping_fn=(
                    self.config["multiagent"]["policy_mapping_fn"]),
                batch_steps=self.config["batch_size"],
                batch_mode="truncate_episodes",
                sample_async=True,
                tf_session_creator=session_creator,
                env_config=self.config["env_config"],
                model_config=self.config["model"],
                policy_config=self.config,
                num_envs=self.config["num_envs"],
                worker_index=i + 1) for i in range(self.config["num_workers"])
        ]

        self.optimizer = AsyncGradientsOptimizer(self.config["optimizer"],
                                                 self.local_evaluator,
                                                 self.remote_evaluators)

    def _train(self):
        self.optimizer.step()
        FilterManager.synchronize(self.local_evaluator.filters,
                                  self.remote_evaluators)
        result = collect_metrics(self.local_evaluator, self.remote_evaluators)
        result = result._replace(info=self.optimizer.stats())
        return result

    def _stop(self):
        # workaround for https://github.com/ray-project/ray/issues/1516
        for ev in self.remote_evaluators:
            ev.__ray_terminate__.remote()

    def _save(self, checkpoint_dir):
        checkpoint_path = os.path.join(checkpoint_dir,
                                       "checkpoint-{}".format(self.iteration))
        agent_state = ray.get(
            [a.save.remote() for a in self.remote_evaluators])
        extra_data = {
            "remote_state": agent_state,
            "local_state": self.local_evaluator.save()
        }
        pickle.dump(extra_data, open(checkpoint_path + ".extra_data", "wb"))
        return checkpoint_path

    def _restore(self, checkpoint_path):
        extra_data = pickle.load(open(checkpoint_path + ".extra_data", "rb"))
        ray.get([
            a.restore.remote(o)
            for a, o in zip(self.remote_evaluators, extra_data["remote_state"])
        ])
        self.local_evaluator.restore(extra_data["local_state"])

    def compute_action(self, observation, state=None):
        if state is None:
            state = []
        # Filter the observation without updating the filter's running
        # statistics, then ask the local policy for a single action.
        obs = self.local_evaluator.filters["default"](observation,
                                                      update=False)
        return self.local_evaluator.for_policy(
            lambda p: p.compute_single_action(
                obs, state, is_training=False)[0])
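compute_action makes the local evaluator usable for evaluation outside the training loop. A short rollout sketch, assuming agent is an already trained (or restored) A3CAgent and that the gym environment matches the one it was trained on; the environment name is a placeholder.

import gym

env = gym.make("CartPole-v0")  # placeholder environment
obs = env.reset()
done = False
total_reward = 0.0
while not done:
    # Filters the observation and queries the local policy, as defined above.
    action = agent.compute_action(obs)
    obs, reward, done, _ = env.step(action)
    total_reward += reward
print("episode reward:", total_reward)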