class DQNAgent(Agent):
    """DQN agent driving one local evaluator plus optional remote evaluators.

    Sampling and gradient updates are delegated to the policy optimizer named
    by ``config["optimizer_class"]``. This class handles target-network
    updates, aggregation of evaluator statistics, and checkpointing.
    """

    _agent_name = "DQN"
    _allow_unknown_subkeys = [
        "model", "optimizer", "tf_session_args", "env_config"]
    _default_config = DEFAULT_CONFIG

    def _init(self):
        """Create the evaluators, the policy optimizer and the TF saver."""
        self.local_evaluator = DQNEvaluator(
            self.registry, self.env_creator, self.config, self.logdir, 0)
        remote_cls = ray.remote(
            num_cpus=1, num_gpus=self.config["num_gpus_per_worker"])(
            DQNEvaluator)
        self.remote_evaluators = [
            remote_cls.remote(
                self.registry, self.env_creator, self.config, self.logdir,
                i)
            for i in range(self.config["num_workers"])]

        if self.config["force_evaluators_remote"]:
            self.remote_evaluators = drop_colocated(self.remote_evaluators)

        # Copy shared top-level settings into the optimizer config, unless the
        # user already set them explicitly in optimizer_config.
        for k in OPTIMIZER_SHARED_CONFIGS:
            if k not in self.config["optimizer_config"]:
                self.config["optimizer_config"][k] = self.config[k]

        self.optimizer = getattr(optimizers, self.config["optimizer_class"])(
            self.config["optimizer_config"], self.local_evaluator,
            self.remote_evaluators)

        # max_to_keep=None: never delete older checkpoints automatically.
        self.saver = tf.train.Saver(max_to_keep=None)
        self.last_target_update_ts = 0
        self.num_target_updates = 0

    @property
    def global_timestep(self):
        """Total number of environment steps sampled by the optimizer."""
        return self.optimizer.num_steps_sampled

    def update_target_if_needed(self):
        """Update the local evaluator's target network when due.

        An update fires once more than ``target_network_update_freq`` steps
        have been sampled since the previous update. The update counter and
        timestamp are recorded for stats and checkpoints.
        """
        if self.global_timestep - self.last_target_update_ts > \
                self.config["target_network_update_freq"]:
            self.local_evaluator.update_target()
            self.last_target_update_ts = self.global_timestep
            self.num_target_updates += 1

    def _train(self):
        """Run one training iteration and return its TrainingResult.

        Optimizer steps run until at least ``timesteps_per_iteration`` new
        steps have been sampled. The resulting global timestep is then pushed
        to every evaluator.
        """
        start_timestep = self.global_timestep

        while (self.global_timestep - start_timestep <
               self.config["timesteps_per_iteration"]):
            self.optimizer.step()
            self.update_target_if_needed()

        self.local_evaluator.set_global_timestep(self.global_timestep)
        for e in self.remote_evaluators:
            # Fire-and-forget: the result is not awaited.
            e.set_global_timestep.remote(self.global_timestep)

        return self._train_stats(start_timestep)

    def _train_stats(self, start_timestep):
        """Aggregate evaluator and optimizer stats into a TrainingResult.

        Args:
            start_timestep: value of ``global_timestep`` at the start of the
                iteration; used to compute ``timesteps_this_iter``.
        """
        if self.remote_evaluators:
            stats = ray.get([
                e.stats.remote() for e in self.remote_evaluators])
        else:
            stats = self.local_evaluator.stats()
            if not isinstance(stats, list):
                stats = [stats]

        mean_100ep_reward = 0.0
        mean_100ep_length = 0.0
        num_episodes = 0
        explorations = []

        if self.config["per_worker_exploration"]:
            # Return stats from workers with the lowest 20% of exploration
            # (assumes later workers explore less -- TODO confirm ordering).
            test_stats = stats[-int(max(1, len(stats) * 0.2)):]
        else:
            test_stats = stats

        for s in test_stats:
            mean_100ep_reward += s["mean_100ep_reward"] / len(test_stats)
            mean_100ep_length += s["mean_100ep_length"] / len(test_stats)

        # Episode counts and exploration ranges cover every worker.
        for s in stats:
            num_episodes += s["num_episodes"]
            explorations.append(s["exploration"])

        opt_stats = self.optimizer.stats()

        result = TrainingResult(
            episode_reward_mean=mean_100ep_reward,
            episode_len_mean=mean_100ep_length,
            episodes_total=num_episodes,
            timesteps_this_iter=self.global_timestep - start_timestep,
            info=dict({
                "min_exploration": min(explorations),
                "max_exploration": max(explorations),
                "num_target_updates": self.num_target_updates,
            }, **opt_stats))

        return result

    def _populate_replay_buffer(self):
        """Request one ``sample(no_replay=True)`` from each evaluator.

        Remote calls are issued without waiting for their results.
        """
        if self.remote_evaluators:
            for e in self.remote_evaluators:
                e.sample.remote(no_replay=True)
        else:
            self.local_evaluator.sample(no_replay=True)

    def _stop(self):
        """Terminate all remote evaluator actors."""
        # workaround for https://github.com/ray-project/ray/issues/1516
        for ev in self.remote_evaluators:
            ev.__ray_terminate__.remote(ev._ray_actor_id.id())

    def _save(self, checkpoint_dir):
        """Write a checkpoint into ``checkpoint_dir`` and return its path.

        The TF session of the local evaluator is saved by ``tf.train.Saver``.
        Evaluator state, optimizer state and target-update counters are
        pickled into a ``<checkpoint_path>.extra_data`` sidecar file.
        """
        checkpoint_path = self.saver.save(
            self.local_evaluator.sess,
            os.path.join(checkpoint_dir, "checkpoint"),
            global_step=self.iteration)
        extra_data = [
            self.local_evaluator.save(),
            ray.get([e.save.remote() for e in self.remote_evaluators]),
            self.optimizer.save(),
            self.num_target_updates,
            self.last_target_update_ts]
        # Close (and flush) the sidecar deterministically instead of relying
        # on garbage collection of an anonymous file object.
        with open(checkpoint_path + ".extra_data", "wb") as f:
            pickle.dump(extra_data, f)
        return checkpoint_path

    def _restore(self, checkpoint_path):
        """Restore state written by ``_save`` from ``checkpoint_path``."""
        self.saver.restore(self.local_evaluator.sess, checkpoint_path)
        # NOTE: pickle.load can execute arbitrary code; only restore
        # checkpoints from trusted locations.
        with open(checkpoint_path + ".extra_data", "rb") as f:
            extra_data = pickle.load(f)
        self.local_evaluator.restore(extra_data[0])
        # zip() pairs saved states with current evaluators positionally.
        ray.get([
            e.restore.remote(d) for (d, e)
            in zip(extra_data[1], self.remote_evaluators)])
        self.optimizer.restore(extra_data[2])
        self.num_target_updates = extra_data[3]
        self.last_target_update_ts = extra_data[4]

    def compute_action(self, observation):
        """Return the local policy's action for a single observation.

        The observation is given a leading batch axis via ``[None]`` and the
        first element of the result is returned. The third argument 0.0 to
        ``act`` presumably disables exploration (epsilon) -- confirm against
        ``dqn_graph.act``.
        """
        return self.local_evaluator.dqn_graph.act(
            self.local_evaluator.sess, np.array(observation)[None], 0.0)[0]
class DQNAgent(Agent):
    """DQN agent driving one local evaluator plus optional remote evaluators.

    Sampling and gradient updates are delegated to the policy optimizer named
    by ``config["optimizer_class"]``. This class handles resource requests,
    target-network updates, aggregation of evaluator statistics, and
    checkpointing.
    """

    _agent_name = "DQN"
    _allow_unknown_subkeys = [
        "model", "optimizer", "tf_session_args", "env_config"]
    _default_config = DEFAULT_CONFIG

    @classmethod
    def default_resource_request(cls, config):
        """Return the Resources this agent needs for ``config``.

        ``config`` is overlaid on ``_default_config``. The driver takes one
        CPU, plus one GPU when ``gpu`` is truthy. Each of ``num_workers``
        workers adds ``num_cpus_per_worker`` CPUs and ``num_gpus_per_worker``
        GPUs as extra resources.
        """
        cf = dict(cls._default_config, **config)
        return Resources(
            cpu=1,
            gpu=1 if cf["gpu"] else 0,
            extra_cpu=cf["num_cpus_per_worker"] * cf["num_workers"],
            extra_gpu=cf["num_gpus_per_worker"] * cf["num_workers"])

    def _init(self):
        """Create the evaluators, the policy optimizer and the TF saver."""
        self.local_evaluator = DQNEvaluator(
            self.registry, self.env_creator, self.config, self.logdir, 0)
        remote_cls = ray.remote(
            num_cpus=self.config["num_cpus_per_worker"],
            num_gpus=self.config["num_gpus_per_worker"])(
            DQNEvaluator)
        self.remote_evaluators = [
            remote_cls.remote(
                self.registry, self.env_creator, self.config, self.logdir,
                i)
            for i in range(self.config["num_workers"])]

        # Copy shared top-level settings into the optimizer config, unless the
        # user already set them explicitly in optimizer_config.
        for k in OPTIMIZER_SHARED_CONFIGS:
            if k not in self.config["optimizer_config"]:
                self.config["optimizer_config"][k] = self.config[k]

        self.optimizer = getattr(optimizers, self.config["optimizer_class"])(
            self.config["optimizer_config"], self.local_evaluator,
            self.remote_evaluators)

        # max_to_keep=None: never delete older checkpoints automatically.
        self.saver = tf.train.Saver(max_to_keep=None)
        self.last_target_update_ts = 0
        self.num_target_updates = 0

    @property
    def global_timestep(self):
        """Total number of environment steps sampled by the optimizer."""
        return self.optimizer.num_steps_sampled

    def update_target_if_needed(self):
        """Update the local evaluator's target network when due.

        An update fires once more than ``target_network_update_freq`` steps
        have been sampled since the previous update. The update counter and
        timestamp are recorded for stats and checkpoints.
        """
        if self.global_timestep - self.last_target_update_ts > \
                self.config["target_network_update_freq"]:
            self.local_evaluator.update_target()
            self.last_target_update_ts = self.global_timestep
            self.num_target_updates += 1

    def _train(self):
        """Run one training iteration and return its TrainingResult.

        Optimizer steps run until at least ``timesteps_per_iteration`` new
        steps have been sampled. The resulting global timestep is then pushed
        to every evaluator.
        """
        start_timestep = self.global_timestep

        while (self.global_timestep - start_timestep <
               self.config["timesteps_per_iteration"]):
            self.optimizer.step()
            self.update_target_if_needed()

        self.local_evaluator.set_global_timestep(self.global_timestep)
        for e in self.remote_evaluators:
            # Fire-and-forget: the result is not awaited.
            e.set_global_timestep.remote(self.global_timestep)

        return self._train_stats(start_timestep)

    def _train_stats(self, start_timestep):
        """Aggregate evaluator and optimizer stats into a TrainingResult.

        Args:
            start_timestep: value of ``global_timestep`` at the start of the
                iteration; used to compute ``timesteps_this_iter``.
        """
        if self.remote_evaluators:
            stats = ray.get([
                e.stats.remote() for e in self.remote_evaluators])
        else:
            stats = self.local_evaluator.stats()
            if not isinstance(stats, list):
                stats = [stats]

        mean_100ep_reward = 0.0
        mean_100ep_length = 0.0
        num_episodes = 0
        explorations = []

        if self.config["per_worker_exploration"]:
            # Return stats from workers with the lowest 20% of exploration
            # (assumes later workers explore less -- TODO confirm ordering).
            test_stats = stats[-int(max(1, len(stats) * 0.2)):]
        else:
            test_stats = stats

        for s in test_stats:
            mean_100ep_reward += s["mean_100ep_reward"] / len(test_stats)
            mean_100ep_length += s["mean_100ep_length"] / len(test_stats)

        # Episode counts and exploration ranges cover every worker.
        for s in stats:
            num_episodes += s["num_episodes"]
            explorations.append(s["exploration"])

        opt_stats = self.optimizer.stats()

        result = TrainingResult(
            episode_reward_mean=mean_100ep_reward,
            episode_len_mean=mean_100ep_length,
            episodes_total=num_episodes,
            timesteps_this_iter=self.global_timestep - start_timestep,
            info=dict({
                "min_exploration": min(explorations),
                "max_exploration": max(explorations),
                "num_target_updates": self.num_target_updates,
            }, **opt_stats))

        return result

    def _stop(self):
        """Terminate all remote evaluator actors."""
        # workaround for https://github.com/ray-project/ray/issues/1516
        for ev in self.remote_evaluators:
            ev.__ray_terminate__.remote()

    def _save(self, checkpoint_dir):
        """Write a checkpoint into ``checkpoint_dir`` and return its path.

        The TF session of the local evaluator is saved by ``tf.train.Saver``.
        Evaluator state, optimizer state and target-update counters are
        pickled into a ``<checkpoint_path>.extra_data`` sidecar file.
        """
        checkpoint_path = self.saver.save(
            self.local_evaluator.sess,
            os.path.join(checkpoint_dir, "checkpoint"),
            global_step=self.iteration)
        extra_data = [
            self.local_evaluator.save(),
            ray.get([e.save.remote() for e in self.remote_evaluators]),
            self.optimizer.save(),
            self.num_target_updates,
            self.last_target_update_ts]
        # Close (and flush) the sidecar deterministically instead of relying
        # on garbage collection of an anonymous file object.
        with open(checkpoint_path + ".extra_data", "wb") as f:
            pickle.dump(extra_data, f)
        return checkpoint_path

    def _restore(self, checkpoint_path):
        """Restore state written by ``_save`` from ``checkpoint_path``."""
        self.saver.restore(self.local_evaluator.sess, checkpoint_path)
        # NOTE: pickle.load can execute arbitrary code; only restore
        # checkpoints from trusted locations.
        with open(checkpoint_path + ".extra_data", "rb") as f:
            extra_data = pickle.load(f)
        self.local_evaluator.restore(extra_data[0])
        # zip() pairs saved states with current evaluators positionally.
        ray.get([
            e.restore.remote(d) for (d, e)
            in zip(extra_data[1], self.remote_evaluators)])
        self.optimizer.restore(extra_data[2])
        self.num_target_updates = extra_data[3]
        self.last_target_update_ts = extra_data[4]

    def compute_action(self, observation):
        """Return the local policy's action for a single observation.

        The observation is given a leading batch axis via ``[None]`` and the
        first element of the result is returned. The third argument 0.0 to
        ``act`` presumably disables exploration (epsilon) -- confirm against
        ``dqn_graph.act``.
        """
        return self.local_evaluator.dqn_graph.act(
            self.local_evaluator.sess, np.array(observation)[None], 0.0)[0]
def stats(self):
    """Return evaluator statistics.

    When ``self.workers`` is non-empty, ``stats`` is called remotely on every
    worker and the list of fetched results is returned. Otherwise this
    falls back to ``DQNEvaluator.stats`` applied to ``self``.
    """
    if not self.workers:
        return DQNEvaluator.stats(self)
    pending = [worker.stats.remote() for worker in self.workers]
    return ray.get(pending)