def testCompleteEpisodesPacking(self):
    ev = CommonPolicyEvaluator(
        env_creator=lambda _: MockEnv(10),
        policy_graph=MockPolicyGraph,
        batch_steps=15,
        batch_mode="complete_episodes")
    batch = ev.sample()
    self.assertEqual(batch.count, 20)

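# The tests in this section rely on a MockEnv helper that is not shown here.
# Below is a minimal sketch of what such an env might look like, inferred
# from the assertions (constant observation, reward 1 per step, done after
# `episode_length` steps; e.g. testMetrics expects episode_reward_mean == 10
# for episode_length=10). The exact spaces and return values are assumptions,
# not the actual helper used by these tests.
import gym


class MockEnv(gym.Env):
    def __init__(self, episode_length):
        self.episode_length = episode_length
        self.i = 0
        self.observation_space = gym.spaces.Discrete(1)
        self.action_space = gym.spaces.Discrete(2)

    def reset(self):
        self.i = 0
        return 0

    def step(self, action):
        self.i += 1
        done = self.i >= self.episode_length
        return 0, 1, done, {}
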
def testMultiAgentSampleRoundRobin(self):
    act_space = gym.spaces.Discrete(2)
    obs_space = gym.spaces.Discrete(2)
    ev = CommonPolicyEvaluator(
        env_creator=lambda _: RoundRobinMultiAgent(5, increment_obs=True),
        policy_graph={
            "p0": (MockPolicyGraph, obs_space, act_space, {}),
        },
        policy_mapping_fn=lambda agent_id: "p0",
        batch_steps=50)
    batch = ev.sample()
    self.assertEqual(batch.count, 50)
    # since agents are introduced into the env in round-robin fashion, some
    # of the env steps don't count as proper transitions
    self.assertEqual(batch.policy_batches["p0"].count, 42)
    self.assertEqual(
        batch.policy_batches["p0"]["obs"].tolist()[:10],
        [0, 1, 2, 3, 4] * 2)
    self.assertEqual(
        batch.policy_batches["p0"]["new_obs"].tolist()[:10],
        [1, 2, 3, 4, 5] * 2)
    self.assertEqual(
        batch.policy_batches["p0"]["rewards"].tolist()[:10],
        [100, 100, 100, 100, 0] * 2)
    self.assertEqual(
        batch.policy_batches["p0"]["dones"].tolist()[:10],
        [False, False, False, False, True] * 2)
    self.assertEqual(
        batch.policy_batches["p0"]["t"].tolist()[:10],
        [4, 9, 14, 19, 24, 5, 10, 15, 20, 25])

def testBasic(self):
    ev = CommonPolicyEvaluator(
        env_creator=lambda _: gym.make("CartPole-v0"),
        policy_graph=MockPolicyGraph)
    batch = ev.sample()
    for key in ["obs", "actions", "rewards", "dones", "advantages"]:
        self.assertIn(key, batch)
    self.assertGreater(batch["advantages"][0], 1)

def testServingEnvBadActions(self):
    ev = CommonPolicyEvaluator(
        env_creator=lambda _: SimpleServing(MockEnv(25)),
        policy_graph=BadPolicyGraph,
        sample_async=True,
        batch_steps=40,
        batch_mode="truncate_episodes")
    self.assertRaises(Exception, lambda: ev.sample())

def testServingEnvHorizonNotSupported(self):
    ev = CommonPolicyEvaluator(
        env_creator=lambda _: SimpleServing(MockEnv(25)),
        policy_graph=MockPolicyGraph,
        episode_horizon=20,
        batch_steps=10,
        batch_mode="complete_episodes")
    ev.sample()
    self.assertRaises(Exception, lambda: ev.sample())

def testPackEpisodes(self):
    for batch_size in [1, 10, 100, 1000]:
        ev = CommonPolicyEvaluator(
            env_creator=lambda _: gym.make("CartPole-v0"),
            policy_graph=MockPolicyGraph,
            batch_steps=batch_size,
            batch_mode="pack_episodes")
        batch = ev.sample()
        self.assertEqual(batch.count, batch_size)

def testServingEnvOffPolicy(self):
    ev = CommonPolicyEvaluator(
        env_creator=lambda _: SimpleOffPolicyServing(MockEnv(25)),
        policy_graph=MockPolicyGraph,
        batch_steps=40,
        batch_mode="complete_episodes")
    for _ in range(3):
        batch = ev.sample()
        self.assertEqual(batch.count, 50)

def testAutoConcat(self):
    ev = CommonPolicyEvaluator(
        env_creator=lambda _: MockEnv(episode_length=40),
        policy_graph=MockPolicyGraph,
        sample_async=True,
        batch_steps=10,
        batch_mode="truncate_episodes",
        observation_filter="ConcurrentMeanStdFilter")
    time.sleep(2)
    batch = ev.sample()
    self.assertEqual(batch.count, 40)  # auto-concat up to 5 episodes

def testCompleteEpisodes(self):
    ev = CommonPolicyEvaluator(
        env_creator=lambda _: gym.make("CartPole-v0"),
        policy_graph=MockPolicyGraph,
        batch_steps=2,
        batch_mode="complete_episodes")
    batch = ev.sample()
    self.assertGreater(batch.count, 2)
    self.assertTrue(batch["dones"][-1])
    batch = ev.sample()
    self.assertGreater(batch.count, 2)
    self.assertTrue(batch["dones"][-1])

def testMetrics(self):
    ev = CommonPolicyEvaluator(
        env_creator=lambda _: MockEnv(episode_length=10),
        policy_graph=MockPolicyGraph,
        batch_mode="complete_episodes")
    remote_ev = CommonPolicyEvaluator.as_remote().remote(
        env_creator=lambda _: MockEnv(episode_length=10),
        policy_graph=MockPolicyGraph,
        batch_mode="complete_episodes")
    ev.sample()
    ray.get(remote_ev.sample.remote())
    result = collect_metrics(ev, [remote_ev])
    self.assertEqual(result.episodes_total, 20)
    self.assertEqual(result.episode_reward_mean, 10)

def testFilterSync(self):
    ev = CommonPolicyEvaluator(
        env_creator=lambda _: gym.make("CartPole-v0"),
        policy_graph=MockPolicyGraph,
        sample_async=True,
        observation_filter="ConcurrentMeanStdFilter")
    time.sleep(2)
    ev.sample()
    filters = ev.get_filters(flush_after=True)
    obs_f = filters["obs_filter"]
    self.assertNotEqual(obs_f.rs.n, 0)
    self.assertNotEqual(obs_f.buffer.n, 0)

def _init(self):
    if self.config["use_pytorch"]:
        from ray.rllib.a3c.a3c_torch_policy import A3CTorchPolicyGraph
        self.policy_cls = A3CTorchPolicyGraph
    else:
        from ray.rllib.a3c.a3c_tf_policy import A3CPolicyGraph
        self.policy_cls = A3CPolicyGraph

    if self.config["use_pytorch"]:
        session_creator = None
    else:
        import tensorflow as tf

        def session_creator():
            return tf.Session(
                config=tf.ConfigProto(
                    intra_op_parallelism_threads=1,
                    inter_op_parallelism_threads=1,
                    gpu_options=tf.GPUOptions(allow_growth=True)))

    remote_cls = CommonPolicyEvaluator.as_remote(
        num_gpus=1 if self.config["use_gpu_for_workers"] else 0)
    self.local_evaluator = CommonPolicyEvaluator(
        self.env_creator,
        self.config["multiagent"]["policy_graphs"] or self.policy_cls,
        policy_mapping_fn=self.config["multiagent"]["policy_mapping_fn"],
        batch_steps=self.config["batch_size"],
        batch_mode="truncate_episodes",
        tf_session_creator=session_creator,
        env_config=self.config["env_config"],
        model_config=self.config["model"],
        policy_config=self.config,
        num_envs=self.config["num_envs"])
    self.remote_evaluators = [
        remote_cls.remote(
            self.env_creator,
            self.config["multiagent"]["policy_graphs"] or self.policy_cls,
            policy_mapping_fn=(
                self.config["multiagent"]["policy_mapping_fn"]),
            batch_steps=self.config["batch_size"],
            batch_mode="truncate_episodes",
            sample_async=True,
            tf_session_creator=session_creator,
            env_config=self.config["env_config"],
            model_config=self.config["model"],
            policy_config=self.config,
            num_envs=self.config["num_envs"],
            worker_index=i + 1)
        for i in range(self.config["num_workers"])
    ]
    self.optimizer = AsyncGradientsOptimizer(
        self.config["optimizer"], self.local_evaluator,
        self.remote_evaluators)

def testBatchesSmallerWhenVectorized(self):
    ev = CommonPolicyEvaluator(
        env_creator=lambda _: MockEnv(episode_length=8),
        policy_graph=MockPolicyGraph,
        batch_mode="truncate_episodes",
        batch_steps=16,
        num_envs=4)
    batch = ev.sample()
    self.assertEqual(batch.count, 16)
    result = collect_metrics(ev, [])
    self.assertEqual(result.episodes_total, 0)
    batch = ev.sample()
    result = collect_metrics(ev, [])
    self.assertEqual(result.episodes_total, 4)

def testGetFilters(self):
    ev = CommonPolicyEvaluator(
        env_creator=lambda _: gym.make("CartPole-v0"),
        policy_graph=MockPolicyGraph,
        sample_async=True,
        observation_filter="ConcurrentMeanStdFilter")
    self.sample_and_flush(ev)
    filters = ev.get_filters(flush_after=False)
    time.sleep(2)
    filters2 = ev.get_filters(flush_after=False)
    obs_f = filters["obs_filter"]
    obs_f2 = filters2["obs_filter"]
    self.assertGreaterEqual(obs_f2.rs.n, obs_f.rs.n)
    self.assertGreaterEqual(obs_f2.buffer.n, obs_f.buffer.n)

def testTruncateEpisodes(self):
    ev = CommonPolicyEvaluator(
        env_creator=lambda _: gym.make("CartPole-v0"),
        policy_graph=MockPolicyGraph,
        batch_steps=2,
        batch_mode="truncate_episodes")
    batch = ev.sample()
    self.assertEqual(batch.count, 2)
    ev = CommonPolicyEvaluator(
        env_creator=lambda _: gym.make("CartPole-v0"),
        policy_graph=MockPolicyGraph,
        batch_steps=1000,
        batch_mode="truncate_episodes")
    batch = ev.sample()
    self.assertLess(batch.count, 200)

def _init(self):
    adjusted_batch_size = (
        self.config["sample_batch_size"] + self.config["n_step"] - 1)
    self.local_evaluator = CommonPolicyEvaluator(
        self.env_creator,
        self.config["multiagent"]["policy_graphs"] or self._policy_graph,
        policy_mapping_fn=self.config["multiagent"]["policy_mapping_fn"],
        batch_steps=adjusted_batch_size,
        batch_mode="truncate_episodes",
        preprocessor_pref="deepmind",
        compress_observations=True,
        env_config=self.config["env_config"],
        model_config=self.config["model"],
        policy_config=self.config,
        num_envs=self.config["num_envs"])
    remote_cls = CommonPolicyEvaluator.as_remote(
        num_cpus=self.config["num_cpus_per_worker"],
        num_gpus=self.config["num_gpus_per_worker"])
    self.remote_evaluators = [
        remote_cls.remote(
            self.env_creator,
            self._policy_graph,
            batch_steps=adjusted_batch_size,
            batch_mode="truncate_episodes",
            preprocessor_pref="deepmind",
            compress_observations=True,
            env_config=self.config["env_config"],
            model_config=self.config["model"],
            policy_config=self.config,
            num_envs=self.config["num_envs"],
            worker_index=i + 1)
        for i in range(self.config["num_workers"])
    ]
    self.exploration0 = self._make_exploration_schedule(0)
    self.explorations = [
        self._make_exploration_schedule(i)
        for i in range(self.config["num_workers"])
    ]

    for k in OPTIMIZER_SHARED_CONFIGS:
        if k not in self.config["optimizer_config"]:
            self.config["optimizer_config"][k] = self.config[k]

    self.optimizer = getattr(optimizers, self.config["optimizer_class"])(
        self.config["optimizer_config"], self.local_evaluator,
        self.remote_evaluators)

    self.last_target_update_ts = 0
    self.num_target_updates = 0

def testVectorEnvSupport(self):
    ev = CommonPolicyEvaluator(
        env_creator=lambda _: MockVectorEnv(episode_length=20, num_envs=8),
        policy_graph=MockPolicyGraph,
        batch_mode="truncate_episodes",
        batch_steps=10)
    for _ in range(8):
        batch = ev.sample()
        self.assertEqual(batch.count, 10)
    result = collect_metrics(ev, [])
    self.assertEqual(result.episodes_total, 0)
    for _ in range(8):
        batch = ev.sample()
        self.assertEqual(batch.count, 10)
    result = collect_metrics(ev, [])
    self.assertEqual(result.episodes_total, 8)

def _init(self):
    self.policy_cls = get_policy_cls(self.config)

    if self.config["use_pytorch"]:
        session_creator = None
    else:
        import tensorflow as tf

        def session_creator():
            return tf.Session(
                config=tf.ConfigProto(
                    intra_op_parallelism_threads=1,
                    inter_op_parallelism_threads=1,
                    gpu_options=tf.GPUOptions(allow_growth=True)))

    remote_cls = CommonPolicyEvaluator.as_remote(
        num_gpus=1 if self.config["use_gpu_for_workers"] else 0)
    self.local_evaluator = CommonPolicyEvaluator(
        self.env_creator,
        self.policy_cls,
        batch_steps=self.config["batch_size"],
        batch_mode="truncate_episodes",
        tf_session_creator=session_creator,
        registry=self.registry,
        env_config=self.config["env_config"],
        model_config=self.config["model"],
        policy_config=self.config,
        num_envs=self.config["num_envs"])
    self.remote_evaluators = [
        remote_cls.remote(
            self.env_creator,
            self.policy_cls,
            batch_steps=self.config["batch_size"],
            batch_mode="truncate_episodes",
            sample_async=True,
            tf_session_creator=session_creator,
            registry=self.registry,
            env_config=self.config["env_config"],
            model_config=self.config["model"],
            policy_config=self.config,
            num_envs=self.config["num_envs"])
        for i in range(self.config["num_workers"])
    ]
    self.optimizer = AsyncOptimizer(
        self.config["optimizer"], self.local_evaluator,
        self.remote_evaluators)

def testBatchDivisibilityCheck(self):
    self.assertRaises(
        ValueError,
        lambda: CommonPolicyEvaluator(
            env_creator=lambda _: MockEnv(episode_length=8),
            policy_graph=MockPolicyGraph,
            batch_mode="truncate_episodes",
            batch_steps=15,
            num_envs=4))

def testMultiAgentSample(self):
    act_space = gym.spaces.Discrete(2)
    obs_space = gym.spaces.Discrete(2)
    ev = CommonPolicyEvaluator(
        env_creator=lambda _: BasicMultiAgent(5),
        policy_graph={
            "p0": (MockPolicyGraph, obs_space, act_space, {}),
            "p1": (MockPolicyGraph, obs_space, act_space, {}),
        },
        policy_mapping_fn=lambda agent_id: "p{}".format(agent_id % 2),
        batch_steps=50)
    batch = ev.sample()
    self.assertEqual(batch.count, 50)
    self.assertEqual(batch.policy_batches["p0"].count, 150)
    self.assertEqual(batch.policy_batches["p1"].count, 100)
    self.assertEqual(
        batch.policy_batches["p0"]["t"].tolist(), list(range(25)) * 6)

def _init(self):
    def session_creator():
        return tf.Session(
            config=tf.ConfigProto(**self.config["tf_session_args"]))

    self.local_evaluator = CommonPolicyEvaluator(
        self.env_creator,
        self._default_policy_graph,
        tf_session_creator=session_creator,
        batch_mode="complete_episodes",
        observation_filter=self.config["observation_filter"],
        env_config=self.config["env_config"],
        model_config=self.config["model"],
        policy_config=self.config)
    RemoteEvaluator = CommonPolicyEvaluator.as_remote(
        num_cpus=self.config["num_cpus_per_worker"],
        num_gpus=self.config["num_gpus_per_worker"])
    self.remote_evaluators = [
        RemoteEvaluator.remote(
            self.env_creator,
            self._default_policy_graph,
            batch_mode="complete_episodes",
            observation_filter=self.config["observation_filter"],
            env_config=self.config["env_config"],
            model_config=self.config["model"],
            policy_config=self.config)
        for _ in range(self.config["num_workers"])
    ]
    self.optimizer = LocalMultiGPUOptimizer(
        {
            "sgd_batch_size": self.config["sgd_batchsize"],
            "sgd_stepsize": self.config["sgd_stepsize"],
            "num_sgd_iter": self.config["num_sgd_iter"],
            "timesteps_per_batch": self.config["timesteps_per_batch"]
        }, self.local_evaluator, self.remote_evaluators)

    # TODO(rliaw): Push into Policy Graph
    with self.local_evaluator.tf_sess.graph.as_default():
        self.saver = tf.train.Saver()

def _testWithOptimizer(self, optimizer_cls):
    n = 3
    env = gym.make("CartPole-v0")
    act_space = env.action_space
    obs_space = env.observation_space
    dqn_config = {"gamma": 0.95, "n_step": 3}
    if optimizer_cls == SyncReplayOptimizer:
        # TODO: support replay with non-DQN graphs. Currently this can't
        # happen since the replay buffer doesn't encode extra fields like
        # "advantages" that PG uses.
        policies = {
            "p1": (DQNPolicyGraph, obs_space, act_space, dqn_config),
            "p2": (DQNPolicyGraph, obs_space, act_space, dqn_config),
        }
    else:
        policies = {
            "p1": (PGPolicyGraph, obs_space, act_space, {}),
            "p2": (DQNPolicyGraph, obs_space, act_space, dqn_config),
        }
    ev = CommonPolicyEvaluator(
        env_creator=lambda _: MultiCartpole(n),
        policy_graph=policies,
        policy_mapping_fn=lambda agent_id: ["p1", "p2"][agent_id % 2],
        batch_steps=50)
    if optimizer_cls == AsyncGradientsOptimizer:
        remote_evs = [
            CommonPolicyEvaluator.as_remote().remote(
                env_creator=lambda _: MultiCartpole(n),
                policy_graph=policies,
                policy_mapping_fn=lambda agent_id: ["p1", "p2"][agent_id % 2],
                batch_steps=50)
        ]
    else:
        remote_evs = []
    optimizer = optimizer_cls({}, ev, remote_evs)
    for i in range(200):
        ev.foreach_policy(
            lambda p, _: p.set_epsilon(max(0.02, 1 - i * .02))
            if isinstance(p, DQNPolicyGraph) else None)
        optimizer.step()
        result = collect_metrics(ev, remote_evs)
        if i % 20 == 0:
            ev.foreach_policy(
                lambda p, _: p.update_target()
                if isinstance(p, DQNPolicyGraph) else None)
            print("Iter {}, rew {}".format(i, result.policy_reward_mean))
            print("Total reward", result.episode_reward_mean)
        if result.episode_reward_mean >= 25 * n:
            return
    print(result)
    raise Exception("failed to improve reward")

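# Hypothetical drivers for the parameterized helper above. The helper is
# clearly meant to be called once per optimizer class, and all three
# optimizer names below appear elsewhere in this section, but the exact test
# method names are assumptions. They would live on the same TestCase as
# _testWithOptimizer.
def testMultiAgentSyncOptimizer(self):
    self._testWithOptimizer(SyncSamplesOptimizer)


def testMultiAgentAsyncGradientsOptimizer(self):
    self._testWithOptimizer(AsyncGradientsOptimizer)


def testMultiAgentReplayOptimizer(self):
    self._testWithOptimizer(SyncReplayOptimizer)
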
def testTrainMultiCartpoleManyPolicies(self):
    n = 20
    env = gym.make("CartPole-v0")
    act_space = env.action_space
    obs_space = env.observation_space
    policies = {}
    for i in range(20):
        policies["pg_{}".format(i)] = (
            PGPolicyGraph, obs_space, act_space, {})
    policy_ids = list(policies.keys())
    ev = CommonPolicyEvaluator(
        env_creator=lambda _: MultiCartpole(n),
        policy_graph=policies,
        policy_mapping_fn=lambda agent_id: random.choice(policy_ids),
        batch_steps=100)
    optimizer = SyncSamplesOptimizer({}, ev, [])
    for i in range(100):
        optimizer.step()
        result = collect_metrics(ev)
        print("Iteration {}, rew {}".format(i, result.policy_reward_mean))
        print("Total reward", result.episode_reward_mean)
        if result.episode_reward_mean >= 25 * n:
            return
    raise Exception("failed to improve reward")

def testSyncFilter(self):
    ev = CommonPolicyEvaluator(
        env_creator=lambda _: gym.make("CartPole-v0"),
        policy_graph=MockPolicyGraph,
        sample_async=True,
        observation_filter="ConcurrentMeanStdFilter")
    obs_f = self.sample_and_flush(ev)

    # Current State
    filters = ev.get_filters(flush_after=False)
    obs_f = filters["obs_filter"]
    self.assertLessEqual(obs_f.buffer.n, 20)

    new_obsf = obs_f.copy()
    new_obsf.rs._n = 100
    ev.sync_filters({"obs_filter": new_obsf})
    filters = ev.get_filters(flush_after=False)
    obs_f = filters["obs_filter"]
    self.assertGreaterEqual(obs_f.rs.n, 100)
    self.assertLessEqual(obs_f.buffer.n, 20)

class PPOAgent(Agent):
    _agent_name = "PPO"
    _default_config = DEFAULT_CONFIG
    _default_policy_graph = PPOTFPolicyGraph

    @classmethod
    def default_resource_request(cls, config):
        cf = dict(cls._default_config, **config)
        return Resources(
            cpu=1,
            gpu=len([d for d in cf["devices"] if "gpu" in d.lower()]),
            extra_cpu=cf["num_cpus_per_worker"] * cf["num_workers"],
            extra_gpu=cf["num_gpus_per_worker"] * cf["num_workers"])

    def _init(self):
        def session_creator():
            return tf.Session(
                config=tf.ConfigProto(**self.config["tf_session_args"]))

        self.local_evaluator = CommonPolicyEvaluator(
            self.env_creator,
            self._default_policy_graph,
            tf_session_creator=session_creator,
            batch_mode="complete_episodes",
            observation_filter=self.config["observation_filter"],
            env_config=self.config["env_config"],
            model_config=self.config["model"],
            policy_config=self.config)
        RemoteEvaluator = CommonPolicyEvaluator.as_remote(
            num_cpus=self.config["num_cpus_per_worker"],
            num_gpus=self.config["num_gpus_per_worker"])
        self.remote_evaluators = [
            RemoteEvaluator.remote(
                self.env_creator,
                self._default_policy_graph,
                batch_mode="complete_episodes",
                observation_filter=self.config["observation_filter"],
                env_config=self.config["env_config"],
                model_config=self.config["model"],
                policy_config=self.config)
            for _ in range(self.config["num_workers"])
        ]
        self.optimizer = LocalMultiGPUOptimizer(
            {
                "sgd_batch_size": self.config["sgd_batchsize"],
                "sgd_stepsize": self.config["sgd_stepsize"],
                "num_sgd_iter": self.config["num_sgd_iter"],
                "timesteps_per_batch": self.config["timesteps_per_batch"]
            }, self.local_evaluator, self.remote_evaluators)

        # TODO(rliaw): Push into Policy Graph
        with self.local_evaluator.tf_sess.graph.as_default():
            self.saver = tf.train.Saver()

    def _train(self):
        def postprocess_samples(batch):
            # Divide by the maximum of value.std() and 1e-4
            # to guard against the case where all values are equal
            value = batch["advantages"]
            standardized = (value - value.mean()) / max(1e-4, value.std())
            batch.data["advantages"] = standardized
            batch.shuffle()
            dummy = np.zeros_like(batch["advantages"])
            if not self.config["use_gae"]:
                batch.data["value_targets"] = dummy
                batch.data["vf_preds"] = dummy

        extra_fetches = self.optimizer.step(
            postprocess_fn=postprocess_samples)
        kl = np.array(extra_fetches["kl"]).mean(axis=1)[-1]
        total_loss = np.array(extra_fetches["total_loss"]).mean(axis=1)[-1]
        policy_loss = np.array(extra_fetches["policy_loss"]).mean(axis=1)[-1]
        vf_loss = np.array(extra_fetches["vf_loss"]).mean(axis=1)[-1]
        entropy = np.array(extra_fetches["entropy"]).mean(axis=1)[-1]

        newkl = self.local_evaluator.for_policy(lambda pi: pi.update_kl(kl))

        info = {
            "kl_divergence": kl,
            "kl_coefficient": newkl,
            "total_loss": total_loss,
            "policy_loss": policy_loss,
            "vf_loss": vf_loss,
            "entropy": entropy,
        }

        FilterManager.synchronize(self.local_evaluator.filters,
                                  self.remote_evaluators)
        res = collect_metrics(self.local_evaluator, self.remote_evaluators)
        res = res._replace(info=info)
        return res

    def _stop(self):
        # workaround for https://github.com/ray-project/ray/issues/1516
        for ev in self.remote_evaluators:
            ev.__ray_terminate__.remote()

    def _save(self, checkpoint_dir):
        checkpoint_path = self.saver.save(
            self.local_evaluator.tf_sess,
            os.path.join(checkpoint_dir, "checkpoint"),
            global_step=self.iteration)
        agent_state = ray.get(
            [a.save.remote() for a in self.remote_evaluators])
        extra_data = [self.local_evaluator.save(), agent_state]
        pickle.dump(extra_data, open(checkpoint_path + ".extra_data", "wb"))
        return checkpoint_path

    def _restore(self, checkpoint_path):
        self.saver.restore(self.local_evaluator.tf_sess, checkpoint_path)
        extra_data = pickle.load(open(checkpoint_path + ".extra_data", "rb"))
        self.local_evaluator.restore(extra_data[0])
        ray.get([
            a.restore.remote(o)
            for (a, o) in zip(self.remote_evaluators, extra_data[1])
        ])

    def compute_action(self, observation, state=None):
        if state is None:
            state = []
        obs = self.local_evaluator.filters["default"](
            observation, update=False)
        return self.local_evaluator.for_policy(
            lambda p: p.compute_single_action(
                obs, state, is_training=False)[0])

class A3CAgent(Agent):
    _agent_name = "A3C"
    _default_config = DEFAULT_CONFIG
    _allow_unknown_subkeys = ["model", "optimizer", "env_config"]

    @classmethod
    def default_resource_request(cls, config):
        cf = dict(cls._default_config, **config)
        return Resources(
            cpu=1,
            gpu=0,
            extra_cpu=cf["num_workers"],
            extra_gpu=cf["use_gpu_for_workers"] and cf["num_workers"] or 0)

    def _init(self):
        self.policy_cls = get_policy_cls(self.config)

        if self.config["use_pytorch"]:
            session_creator = None
        else:
            import tensorflow as tf

            def session_creator():
                return tf.Session(
                    config=tf.ConfigProto(
                        intra_op_parallelism_threads=1,
                        inter_op_parallelism_threads=1,
                        gpu_options=tf.GPUOptions(allow_growth=True)))

        remote_cls = CommonPolicyEvaluator.as_remote(
            num_gpus=1 if self.config["use_gpu_for_workers"] else 0)
        self.local_evaluator = CommonPolicyEvaluator(
            self.env_creator,
            self.policy_cls,
            batch_steps=self.config["batch_size"],
            batch_mode="truncate_episodes",
            tf_session_creator=session_creator,
            env_config=self.config["env_config"],
            model_config=self.config["model"],
            policy_config=self.config,
            num_envs=self.config["num_envs"])
        self.remote_evaluators = [
            remote_cls.remote(
                self.env_creator,
                self.policy_cls,
                batch_steps=self.config["batch_size"],
                batch_mode="truncate_episodes",
                sample_async=True,
                tf_session_creator=session_creator,
                env_config=self.config["env_config"],
                model_config=self.config["model"],
                policy_config=self.config,
                num_envs=self.config["num_envs"])
            for i in range(self.config["num_workers"])
        ]
        self.optimizer = AsyncOptimizer(
            self.config["optimizer"], self.local_evaluator,
            self.remote_evaluators)

    def _train(self):
        self.optimizer.step()
        FilterManager.synchronize(
            self.local_evaluator.filters, self.remote_evaluators)
        return collect_metrics(self.local_evaluator, self.remote_evaluators)

    def _stop(self):
        # workaround for https://github.com/ray-project/ray/issues/1516
        for ev in self.remote_evaluators:
            ev.__ray_terminate__.remote()

    def _save(self, checkpoint_dir):
        checkpoint_path = os.path.join(
            checkpoint_dir, "checkpoint-{}".format(self.iteration))
        agent_state = ray.get(
            [a.save.remote() for a in self.remote_evaluators])
        extra_data = {
            "remote_state": agent_state,
            "local_state": self.local_evaluator.save()
        }
        pickle.dump(extra_data, open(checkpoint_path + ".extra_data", "wb"))
        return checkpoint_path

    def _restore(self, checkpoint_path):
        extra_data = pickle.load(open(checkpoint_path + ".extra_data", "rb"))
        ray.get([
            a.restore.remote(o)
            for a, o in zip(self.remote_evaluators,
                            extra_data["remote_state"])
        ])
        self.local_evaluator.restore(extra_data["local_state"])

    def compute_action(self, observation, state=None):
        if state is None:
            state = []
        obs = self.local_evaluator.obs_filter(observation, update=False)
        return self.local_evaluator.for_policy(
            lambda p: p.compute_single_action(
                obs, state, is_training=False)[0])

class DQNAgent(Agent):
    _agent_name = "DQN"
    _default_config = DEFAULT_CONFIG
    _policy_graph = DQNPolicyGraph

    @classmethod
    def default_resource_request(cls, config):
        cf = dict(cls._default_config, **config)
        return Resources(
            cpu=1,
            gpu=cf["gpu"] and 1 or 0,
            extra_cpu=cf["num_cpus_per_worker"] * cf["num_workers"],
            extra_gpu=cf["num_gpus_per_worker"] * cf["num_workers"])

    def _init(self):
        adjusted_batch_size = (
            self.config["sample_batch_size"] + self.config["n_step"] - 1)
        self.local_evaluator = CommonPolicyEvaluator(
            self.env_creator,
            self.config["multiagent"]["policy_graphs"] or self._policy_graph,
            policy_mapping_fn=self.config["multiagent"]["policy_mapping_fn"],
            batch_steps=adjusted_batch_size,
            batch_mode="truncate_episodes",
            preprocessor_pref="deepmind",
            compress_observations=True,
            env_config=self.config["env_config"],
            model_config=self.config["model"],
            policy_config=self.config,
            num_envs=self.config["num_envs"])
        remote_cls = CommonPolicyEvaluator.as_remote(
            num_cpus=self.config["num_cpus_per_worker"],
            num_gpus=self.config["num_gpus_per_worker"])
        self.remote_evaluators = [
            remote_cls.remote(
                self.env_creator,
                self._policy_graph,
                batch_steps=adjusted_batch_size,
                batch_mode="truncate_episodes",
                preprocessor_pref="deepmind",
                compress_observations=True,
                env_config=self.config["env_config"],
                model_config=self.config["model"],
                policy_config=self.config,
                num_envs=self.config["num_envs"],
                worker_index=i + 1)
            for i in range(self.config["num_workers"])
        ]
        self.exploration0 = self._make_exploration_schedule(0)
        self.explorations = [
            self._make_exploration_schedule(i)
            for i in range(self.config["num_workers"])
        ]

        for k in OPTIMIZER_SHARED_CONFIGS:
            if k not in self.config["optimizer_config"]:
                self.config["optimizer_config"][k] = self.config[k]

        self.optimizer = getattr(optimizers, self.config["optimizer_class"])(
            self.config["optimizer_config"], self.local_evaluator,
            self.remote_evaluators)

        self.last_target_update_ts = 0
        self.num_target_updates = 0

    def _make_exploration_schedule(self, worker_index):
        # Use either a different `eps` per worker, or a linear schedule.
        if self.config["per_worker_exploration"]:
            assert self.config["num_workers"] > 1, \
                "This requires multiple workers"
            return ConstantSchedule(
                0.4 ** (1 + worker_index /
                        float(self.config["num_workers"] - 1) * 7))
        return LinearSchedule(
            schedule_timesteps=int(self.config["exploration_fraction"] *
                                   self.config["schedule_max_timesteps"]),
            initial_p=1.0,
            final_p=self.config["exploration_final_eps"])

    @property
    def global_timestep(self):
        return self.optimizer.num_steps_sampled

    def update_target_if_needed(self):
        if self.global_timestep - self.last_target_update_ts > \
                self.config["target_network_update_freq"]:
            self.local_evaluator.foreach_policy(
                lambda p, _: p.update_target())
            self.last_target_update_ts = self.global_timestep
            self.num_target_updates += 1

    def _train(self):
        start_timestep = self.global_timestep

        while (self.global_timestep - start_timestep <
               self.config["timesteps_per_iteration"]):
            self.optimizer.step()
            self.update_target_if_needed()

        exp_vals = [self.exploration0.value(self.global_timestep)]
        self.local_evaluator.foreach_policy(
            lambda p, _: p.set_epsilon(exp_vals[0]))
        for i, e in enumerate(self.remote_evaluators):
            exp_val = self.explorations[i].value(self.global_timestep)
            e.foreach_policy.remote(lambda p, _: p.set_epsilon(exp_val))
            exp_vals.append(exp_val)

        result = collect_metrics(self.local_evaluator,
                                 self.remote_evaluators)
        return result._replace(
            info=dict({
                "min_exploration": min(exp_vals),
                "max_exploration": max(exp_vals),
                "num_target_updates": self.num_target_updates,
            }, **self.optimizer.stats()))

    def _stop(self):
        # workaround for https://github.com/ray-project/ray/issues/1516
        for ev in self.remote_evaluators:
            ev.__ray_terminate__.remote()

    def _save(self, checkpoint_dir):
        checkpoint_path = os.path.join(
            checkpoint_dir, "checkpoint-{}".format(self.iteration))
        extra_data = [
            self.local_evaluator.save(),
            ray.get([e.save.remote() for e in self.remote_evaluators]),
            self.optimizer.save(), self.num_target_updates,
            self.last_target_update_ts
        ]
        pickle.dump(extra_data, open(checkpoint_path + ".extra_data", "wb"))
        return checkpoint_path

    def _restore(self, checkpoint_path):
        extra_data = pickle.load(open(checkpoint_path + ".extra_data", "rb"))
        self.local_evaluator.restore(extra_data[0])
        ray.get([
            e.restore.remote(d)
            for (d, e) in zip(extra_data[1], self.remote_evaluators)
        ])
        self.optimizer.restore(extra_data[2])
        self.num_target_updates = extra_data[3]
        self.last_target_update_ts = extra_data[4]

    def compute_action(self, observation, state=None):
        if state is None:
            state = []
        return self.local_evaluator.for_policy(
            lambda p: p.compute_single_action(
                observation, state, is_training=False)[0])

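# For context on the per_worker_exploration branch of
# _make_exploration_schedule above: worker epsilons are spread geometrically
# from 0.4 (first worker) down to 0.4 ** 8 (last worker). A small
# illustration of the values the formula yields, assuming a hypothetical
# num_workers of 8:
num_workers = 8  # assumed value for illustration only
for worker_index in range(num_workers):
    eps = 0.4 ** (1 + worker_index / float(num_workers - 1) * 7)
    print(worker_index, round(eps, 5))
# prints 0.4 for worker 0 and ~0.00066 (= 0.4 ** 8) for worker 7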