class A3CAgent(Agent): """A3C implementations in TensorFlow and PyTorch.""" _agent_name = "A3C" _default_config = DEFAULT_CONFIG @classmethod def default_resource_request(cls, config): cf = dict(cls._default_config, **config) return Resources( cpu=1, gpu=0, extra_cpu=cf["num_workers"], extra_gpu=cf["use_gpu_for_workers"] and cf["num_workers"] or 0) def _init(self): if self.config["use_pytorch"]: from ray.rllib.agents.a3c.a3c_torch_policy import \ A3CTorchPolicyGraph policy_cls = A3CTorchPolicyGraph else: from ray.rllib.agents.a3c.a3c_tf_policy import A3CPolicyGraph policy_cls = A3CPolicyGraph self.local_evaluator = self.make_local_evaluator( self.env_creator, policy_cls) self.remote_evaluators = self.make_remote_evaluators( self.env_creator, policy_cls, self.config["num_workers"], {"num_gpus": 1 if self.config["use_gpu_for_workers"] else 0}) self.optimizer = AsyncGradientsOptimizer(self.config["optimizer"], self.local_evaluator, self.remote_evaluators) def _train(self): self.optimizer.step() FilterManager.synchronize(self.local_evaluator.filters, self.remote_evaluators) result = collect_metrics(self.local_evaluator, self.remote_evaluators) result = result._replace(info=self.optimizer.stats()) return result def _stop(self): # workaround for https://github.com/ray-project/ray/issues/1516 for ev in self.remote_evaluators: ev.__ray_terminate__.remote() def _save(self, checkpoint_dir): checkpoint_path = os.path.join(checkpoint_dir, "checkpoint-{}".format(self.iteration)) agent_state = ray.get( [a.save.remote() for a in self.remote_evaluators]) extra_data = { "remote_state": agent_state, "local_state": self.local_evaluator.save() } pickle.dump(extra_data, open(checkpoint_path + ".extra_data", "wb")) return checkpoint_path def _restore(self, checkpoint_path): extra_data = pickle.load(open(checkpoint_path + ".extra_data", "rb")) ray.get([ a.restore.remote(o) for a, o in zip(self.remote_evaluators, extra_data["remote_state"]) ]) self.local_evaluator.restore(extra_data["local_state"])
# Earlier variant of the same agent from the pre-refactor layout: it builds
# CommonPolicyEvaluator instances by hand instead of using the
# make_local_evaluator/make_remote_evaluators helpers, and adds a
# compute_action() convenience method. It assumes analogous module-level
# imports (os, pickle, ray, Agent, Resources, AsyncGradientsOptimizer,
# FilterManager, collect_metrics) plus CommonPolicyEvaluator from the
# ray.rllib utilities of that snapshot.
class A3CAgent(Agent):
    _agent_name = "A3C"
    _default_config = DEFAULT_CONFIG

    @classmethod
    def default_resource_request(cls, config):
        cf = dict(cls._default_config, **config)
        return Resources(
            cpu=1,
            gpu=0,
            extra_cpu=cf["num_workers"],
            extra_gpu=cf["use_gpu_for_workers"] and cf["num_workers"] or 0)

    def _init(self):
        if self.config["use_pytorch"]:
            from ray.rllib.a3c.a3c_torch_policy import A3CTorchPolicyGraph
            self.policy_cls = A3CTorchPolicyGraph
        else:
            from ray.rllib.a3c.a3c_tf_policy import A3CPolicyGraph
            self.policy_cls = A3CPolicyGraph

        if self.config["use_pytorch"]:
            session_creator = None
        else:
            import tensorflow as tf

            def session_creator():
                # Single-threaded TF sessions with growable GPU memory, so
                # many workers can share one machine.
                return tf.Session(
                    config=tf.ConfigProto(
                        intra_op_parallelism_threads=1,
                        inter_op_parallelism_threads=1,
                        gpu_options=tf.GPUOptions(allow_growth=True)))

        remote_cls = CommonPolicyEvaluator.as_remote(
            num_gpus=1 if self.config["use_gpu_for_workers"] else 0)
        self.local_evaluator = CommonPolicyEvaluator(
            self.env_creator,
            self.config["multiagent"]["policy_graphs"] or self.policy_cls,
            policy_mapping_fn=self.config["multiagent"]["policy_mapping_fn"],
            batch_steps=self.config["batch_size"],
            batch_mode="truncate_episodes",
            tf_session_creator=session_creator,
            env_config=self.config["env_config"],
            model_config=self.config["model"],
            policy_config=self.config,
            num_envs=self.config["num_envs"])
        self.remote_evaluators = [
            remote_cls.remote(
                self.env_creator,
                self.config["multiagent"]["policy_graphs"] or
                self.policy_cls,
                policy_mapping_fn=(
                    self.config["multiagent"]["policy_mapping_fn"]),
                batch_steps=self.config["batch_size"],
                batch_mode="truncate_episodes",
                sample_async=True,
                tf_session_creator=session_creator,
                env_config=self.config["env_config"],
                model_config=self.config["model"],
                policy_config=self.config,
                num_envs=self.config["num_envs"],
                worker_index=i + 1)
            for i in range(self.config["num_workers"])
        ]
        self.optimizer = AsyncGradientsOptimizer(
            self.config["optimizer"], self.local_evaluator,
            self.remote_evaluators)

    def _train(self):
        self.optimizer.step()
        FilterManager.synchronize(self.local_evaluator.filters,
                                  self.remote_evaluators)
        result = collect_metrics(self.local_evaluator, self.remote_evaluators)
        result = result._replace(info=self.optimizer.stats())
        return result

    def _stop(self):
        # workaround for https://github.com/ray-project/ray/issues/1516
        for ev in self.remote_evaluators:
            ev.__ray_terminate__.remote()

    def _save(self, checkpoint_dir):
        checkpoint_path = os.path.join(
            checkpoint_dir, "checkpoint-{}".format(self.iteration))
        agent_state = ray.get(
            [a.save.remote() for a in self.remote_evaluators])
        extra_data = {
            "remote_state": agent_state,
            "local_state": self.local_evaluator.save()
        }
        pickle.dump(extra_data, open(checkpoint_path + ".extra_data", "wb"))
        return checkpoint_path

    def _restore(self, checkpoint_path):
        extra_data = pickle.load(open(checkpoint_path + ".extra_data", "rb"))
        ray.get([
            a.restore.remote(o)
            for a, o in zip(self.remote_evaluators,
                            extra_data["remote_state"])
        ])
        self.local_evaluator.restore(extra_data["local_state"])

    def compute_action(self, observation, state=None):
        if state is None:
            state = []
        # Run the observation through the local filter without updating its
        # running statistics, then query the policy for a single action.
        obs = self.local_evaluator.filters["default"](
            observation, update=False)
        return self.local_evaluator.for_policy(
            lambda p: p.compute_single_action(
                obs, state, is_training=False)[0])
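# --- Inference sketch (illustrative addition, not from the original source) -
# Shows how compute_action() above would be used to roll out a trained
# agent: the raw observation passes through the local evaluator's "default"
# filter and the policy is queried without exploration-statistics updates.
# Assumes a gym-style env matching the one the agent was trained on.
#
#     import gym
#     env = gym.make("CartPole-v0")
#     obs, done = env.reset(), False
#     while not done:
#         action = agent.compute_action(obs)
#         obs, reward, done, _ = env.step(action)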