class A3CAgent(Agent):
    _agent_name = "A3C"
    _default_config = DEFAULT_CONFIG
    _allow_unknown_subkeys = ["model", "optimizer", "env_config"]

    @classmethod
    def default_resource_request(cls, config):
        cf = dict(cls._default_config, **config)
        return Resources(
            cpu=1,
            gpu=0,
            extra_cpu=cf["num_workers"],
            extra_gpu=cf["use_gpu_for_workers"] and cf["num_workers"] or 0)

    def _init(self):
        self.policy_cls = get_policy_cls(self.config)

        if self.config["use_pytorch"]:
            session_creator = None
        else:
            import tensorflow as tf

            def session_creator():
                return tf.Session(
                    config=tf.ConfigProto(
                        intra_op_parallelism_threads=1,
                        inter_op_parallelism_threads=1,
                        gpu_options=tf.GPUOptions(allow_growth=True)))

        remote_cls = CommonPolicyEvaluator.as_remote(
            num_gpus=1 if self.config["use_gpu_for_workers"] else 0)
        self.local_evaluator = CommonPolicyEvaluator(
            self.env_creator, self.policy_cls,
            batch_steps=self.config["batch_size"],
            batch_mode="truncate_episodes",
            tf_session_creator=session_creator,
            env_config=self.config["env_config"],
            model_config=self.config["model"],
            policy_config=self.config,
            num_envs=self.config["num_envs"])
        self.remote_evaluators = [
            remote_cls.remote(
                self.env_creator, self.policy_cls,
                batch_steps=self.config["batch_size"],
                batch_mode="truncate_episodes",
                sample_async=True,
                tf_session_creator=session_creator,
                env_config=self.config["env_config"],
                model_config=self.config["model"],
                policy_config=self.config,
                num_envs=self.config["num_envs"])
            for i in range(self.config["num_workers"])]

        self.optimizer = AsyncOptimizer(
            self.config["optimizer"], self.local_evaluator,
            self.remote_evaluators)

    def _train(self):
        self.optimizer.step()
        FilterManager.synchronize(
            self.local_evaluator.filters, self.remote_evaluators)
        return collect_metrics(self.local_evaluator, self.remote_evaluators)

    def _stop(self):
        # workaround for https://github.com/ray-project/ray/issues/1516
        for ev in self.remote_evaluators:
            ev.__ray_terminate__.remote()

    def _save(self, checkpoint_dir):
        checkpoint_path = os.path.join(
            checkpoint_dir, "checkpoint-{}".format(self.iteration))
        agent_state = ray.get(
            [a.save.remote() for a in self.remote_evaluators])
        extra_data = {
            "remote_state": agent_state,
            "local_state": self.local_evaluator.save()
        }
        pickle.dump(extra_data, open(checkpoint_path + ".extra_data", "wb"))
        return checkpoint_path

    def _restore(self, checkpoint_path):
        extra_data = pickle.load(open(checkpoint_path + ".extra_data", "rb"))
        ray.get([
            a.restore.remote(o)
            for a, o in zip(self.remote_evaluators, extra_data["remote_state"])
        ])
        self.local_evaluator.restore(extra_data["local_state"])

    def compute_action(self, observation, state=None):
        if state is None:
            state = []
        obs = self.local_evaluator.obs_filter(observation, update=False)
        return self.local_evaluator.for_policy(
            lambda p: p.compute_single_action(
                obs, state, is_training=False)[0])
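
# --- Usage sketch (not part of the original module) --------------------------
# A minimal, illustrative way to drive A3CAgent, assuming `ray` is initialized,
# that the env name and config overrides below are hypothetical, and that the
# Agent base class exposes the usual Trainable-style train()/save()/restore()
# wrappers around _train()/_save()/_restore():
#
#     import ray
#     ray.init()
#     config = dict(DEFAULT_CONFIG, num_workers=2)
#     agent = A3CAgent(config=config, env="CartPole-v0")
#     for _ in range(10):
#         result = agent.train()   # runs one _train() iteration
#     checkpoint = agent.save()    # writes checkpoint + ".extra_data" pickle
#     agent.restore(checkpoint)    # reloads local and remote evaluator state
#     action = agent.compute_action(observation)
# ------------------------------------------------------------------------------
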
class PPOAgent(Agent):
    _agent_name = "PPO"
    _default_config = DEFAULT_CONFIG
    _default_policy_graph = PPOTFPolicyGraph

    @classmethod
    def default_resource_request(cls, config):
        cf = dict(cls._default_config, **config)
        return Resources(
            cpu=1,
            gpu=len([d for d in cf["devices"] if "gpu" in d.lower()]),
            extra_cpu=cf["num_cpus_per_worker"] * cf["num_workers"],
            extra_gpu=cf["num_gpus_per_worker"] * cf["num_workers"])

    def _init(self):
        def session_creator():
            return tf.Session(
                config=tf.ConfigProto(**self.config["tf_session_args"]))

        self.local_evaluator = CommonPolicyEvaluator(
            self.env_creator, self._default_policy_graph,
            tf_session_creator=session_creator,
            batch_mode="complete_episodes",
            observation_filter=self.config["observation_filter"],
            env_config=self.config["env_config"],
            model_config=self.config["model"],
            policy_config=self.config)
        RemoteEvaluator = CommonPolicyEvaluator.as_remote(
            num_cpus=self.config["num_cpus_per_worker"],
            num_gpus=self.config["num_gpus_per_worker"])
        self.remote_evaluators = [
            RemoteEvaluator.remote(
                self.env_creator, self._default_policy_graph,
                batch_mode="complete_episodes",
                observation_filter=self.config["observation_filter"],
                env_config=self.config["env_config"],
                model_config=self.config["model"],
                policy_config=self.config)
            for _ in range(self.config["num_workers"])
        ]

        self.optimizer = LocalMultiGPUOptimizer(
            {
                "sgd_batch_size": self.config["sgd_batchsize"],
                "sgd_stepsize": self.config["sgd_stepsize"],
                "num_sgd_iter": self.config["num_sgd_iter"],
                "timesteps_per_batch": self.config["timesteps_per_batch"]
            },
            self.local_evaluator, self.remote_evaluators)

        # TODO(rliaw): Push into Policy Graph
        with self.local_evaluator.tf_sess.graph.as_default():
            self.saver = tf.train.Saver()

    def _train(self):
        def postprocess_samples(batch):
            # Divide by the maximum of value.std() and 1e-4
            # to guard against the case where all values are equal.
            value = batch["advantages"]
            standardized = (value - value.mean()) / max(1e-4, value.std())
            batch.data["advantages"] = standardized
            batch.shuffle()
            dummy = np.zeros_like(batch["advantages"])
            if not self.config["use_gae"]:
                batch.data["value_targets"] = dummy
                batch.data["vf_preds"] = dummy

        extra_fetches = self.optimizer.step(postprocess_fn=postprocess_samples)
        # Average each metric over the inner axis and report the value from
        # the final SGD iteration.
        kl = np.array(extra_fetches["kl"]).mean(axis=1)[-1]
        total_loss = np.array(extra_fetches["total_loss"]).mean(axis=1)[-1]
        policy_loss = np.array(extra_fetches["policy_loss"]).mean(axis=1)[-1]
        vf_loss = np.array(extra_fetches["vf_loss"]).mean(axis=1)[-1]
        entropy = np.array(extra_fetches["entropy"]).mean(axis=1)[-1]

        newkl = self.local_evaluator.for_policy(lambda pi: pi.update_kl(kl))

        info = {
            "kl_divergence": kl,
            "kl_coefficient": newkl,
            "total_loss": total_loss,
            "policy_loss": policy_loss,
            "vf_loss": vf_loss,
            "entropy": entropy,
        }

        FilterManager.synchronize(
            self.local_evaluator.filters, self.remote_evaluators)
        res = collect_metrics(self.local_evaluator, self.remote_evaluators)
        res = res._replace(info=info)
        return res

    def _stop(self):
        # workaround for https://github.com/ray-project/ray/issues/1516
        for ev in self.remote_evaluators:
            ev.__ray_terminate__.remote()

    def _save(self, checkpoint_dir):
        checkpoint_path = self.saver.save(
            self.local_evaluator.tf_sess,
            os.path.join(checkpoint_dir, "checkpoint"),
            global_step=self.iteration)
        agent_state = ray.get(
            [a.save.remote() for a in self.remote_evaluators])
        extra_data = [self.local_evaluator.save(), agent_state]
        pickle.dump(extra_data, open(checkpoint_path + ".extra_data", "wb"))
        return checkpoint_path

    def _restore(self, checkpoint_path):
        self.saver.restore(self.local_evaluator.tf_sess, checkpoint_path)
        extra_data = pickle.load(open(checkpoint_path + ".extra_data", "rb"))
        self.local_evaluator.restore(extra_data[0])
        ray.get([
            a.restore.remote(o)
            for (a, o) in zip(self.remote_evaluators, extra_data[1])
        ])

    def compute_action(self, observation, state=None):
        if state is None:
            state = []
        obs = self.local_evaluator.filters["default"](
            observation, update=False)
        return self.local_evaluator.for_policy(
            lambda p: p.compute_single_action(
                obs, state, is_training=False)[0])
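
# --- Usage sketch (not part of the original module) --------------------------
# Illustrative only; the env name and override values are hypothetical, and
# the config keys shown are exactly the ones PPOAgent reads above:
#
#     config = dict(
#         DEFAULT_CONFIG,
#         num_workers=4,
#         timesteps_per_batch=4000,
#         num_sgd_iter=20,
#         sgd_batchsize=128,
#     )
#     agent = PPOAgent(config=config, env="CartPole-v0")
#     result = agent.train()
#     # result.info carries kl_divergence, kl_coefficient, total_loss,
#     # policy_loss, vf_loss, and entropy, as set in _train() above.
# ------------------------------------------------------------------------------
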
class DQNAgent(Agent):
    _agent_name = "DQN"
    _default_config = DEFAULT_CONFIG
    _policy_graph = DQNPolicyGraph

    @classmethod
    def default_resource_request(cls, config):
        cf = dict(cls._default_config, **config)
        return Resources(
            cpu=1,
            gpu=cf["gpu"] and 1 or 0,
            extra_cpu=cf["num_cpus_per_worker"] * cf["num_workers"],
            extra_gpu=cf["num_gpus_per_worker"] * cf["num_workers"])

    def _init(self):
        # Extend each sample batch by (n_step - 1) steps so that n-step
        # returns can be computed from the extra transitions.
        adjusted_batch_size = (
            self.config["sample_batch_size"] + self.config["n_step"] - 1)
        self.local_evaluator = CommonPolicyEvaluator(
            self.env_creator,
            self.config["multiagent"]["policy_graphs"] or self._policy_graph,
            policy_mapping_fn=self.config["multiagent"]["policy_mapping_fn"],
            batch_steps=adjusted_batch_size,
            batch_mode="truncate_episodes",
            preprocessor_pref="deepmind",
            compress_observations=True,
            env_config=self.config["env_config"],
            model_config=self.config["model"],
            policy_config=self.config,
            num_envs=self.config["num_envs"])
        remote_cls = CommonPolicyEvaluator.as_remote(
            num_cpus=self.config["num_cpus_per_worker"],
            num_gpus=self.config["num_gpus_per_worker"])
        self.remote_evaluators = [
            remote_cls.remote(
                self.env_creator,
                self._policy_graph,
                batch_steps=adjusted_batch_size,
                batch_mode="truncate_episodes",
                preprocessor_pref="deepmind",
                compress_observations=True,
                env_config=self.config["env_config"],
                model_config=self.config["model"],
                policy_config=self.config,
                num_envs=self.config["num_envs"],
                worker_index=i + 1)
            for i in range(self.config["num_workers"])
        ]

        self.exploration0 = self._make_exploration_schedule(0)
        self.explorations = [
            self._make_exploration_schedule(i)
            for i in range(self.config["num_workers"])
        ]

        for k in OPTIMIZER_SHARED_CONFIGS:
            if k not in self.config["optimizer_config"]:
                self.config["optimizer_config"][k] = self.config[k]

        self.optimizer = getattr(optimizers, self.config["optimizer_class"])(
            self.config["optimizer_config"], self.local_evaluator,
            self.remote_evaluators)

        self.last_target_update_ts = 0
        self.num_target_updates = 0

    def _make_exploration_schedule(self, worker_index):
        # Use either a different constant `eps` per worker, or a linear
        # schedule shared by all workers.
        if self.config["per_worker_exploration"]:
            assert self.config["num_workers"] > 1, \
                "This requires multiple workers"
            # Spread a constant epsilon geometrically across workers,
            # from 0.4 (worker 0) down to 0.4 ** 8 (last worker).
            return ConstantSchedule(
                0.4 ** (
                    1 +
                    worker_index / float(self.config["num_workers"] - 1) * 7))
        return LinearSchedule(
            schedule_timesteps=int(self.config["exploration_fraction"] *
                                   self.config["schedule_max_timesteps"]),
            initial_p=1.0,
            final_p=self.config["exploration_final_eps"])

    @property
    def global_timestep(self):
        return self.optimizer.num_steps_sampled

    def update_target_if_needed(self):
        if self.global_timestep - self.last_target_update_ts > \
                self.config["target_network_update_freq"]:
            self.local_evaluator.foreach_policy(
                lambda p, _: p.update_target())
            self.last_target_update_ts = self.global_timestep
            self.num_target_updates += 1

    def _train(self):
        start_timestep = self.global_timestep

        while (self.global_timestep - start_timestep <
               self.config["timesteps_per_iteration"]):
            self.optimizer.step()
            self.update_target_if_needed()

        exp_vals = [self.exploration0.value(self.global_timestep)]
        self.local_evaluator.foreach_policy(
            lambda p, _: p.set_epsilon(exp_vals[0]))
        for i, e in enumerate(self.remote_evaluators):
            exp_val = self.explorations[i].value(self.global_timestep)
            e.foreach_policy.remote(lambda p, _: p.set_epsilon(exp_val))
            exp_vals.append(exp_val)

        result = collect_metrics(self.local_evaluator, self.remote_evaluators)
        return result._replace(
            info=dict({
                "min_exploration": min(exp_vals),
                "max_exploration": max(exp_vals),
                "num_target_updates": self.num_target_updates,
            }, **self.optimizer.stats()))

    def _stop(self):
        # workaround for https://github.com/ray-project/ray/issues/1516
        for ev in self.remote_evaluators:
            ev.__ray_terminate__.remote()

    def _save(self, checkpoint_dir):
        checkpoint_path = os.path.join(
            checkpoint_dir, "checkpoint-{}".format(self.iteration))
        extra_data = [
            self.local_evaluator.save(),
            ray.get([e.save.remote() for e in self.remote_evaluators]),
            self.optimizer.save(), self.num_target_updates,
            self.last_target_update_ts
        ]
        pickle.dump(extra_data, open(checkpoint_path + ".extra_data", "wb"))
        return checkpoint_path

    def _restore(self, checkpoint_path):
        extra_data = pickle.load(open(checkpoint_path + ".extra_data", "rb"))
        self.local_evaluator.restore(extra_data[0])
        ray.get([
            e.restore.remote(d)
            for (d, e) in zip(extra_data[1], self.remote_evaluators)
        ])
        self.optimizer.restore(extra_data[2])
        self.num_target_updates = extra_data[3]
        self.last_target_update_ts = extra_data[4]

    def compute_action(self, observation, state=None):
        if state is None:
            state = []
        return self.local_evaluator.for_policy(
            lambda p: p.compute_single_action(
                observation, state, is_training=False)[0])