Example 1
    def _init(self):
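        """Set up evaluators and the optimizer based on async_updates."""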
        if self.config["async_updates"]:
            self.local_evaluator = DQNEvaluator(self.registry,
                                                self.env_creator, self.config,
                                                self.logdir)
            remote_cls = ray.remote(
                num_cpus=1, num_gpus=self.config["num_gpus_per_worker"])(
                    DQNReplayEvaluator)
            remote_config = dict(self.config, num_workers=1)
            # In async mode, we create N remote evaluators, each with their
            # own replay buffer (i.e. the replay buffer is sharded).
            self.remote_evaluators = [
                remote_cls.remote(self.registry, self.env_creator,
                                  remote_config, self.logdir)
                for _ in range(self.config["num_workers"])
            ]
            optimizer_cls = AsyncOptimizer
        else:
            self.local_evaluator = DQNReplayEvaluator(self.registry,
                                                      self.env_creator,
                                                      self.config, self.logdir)
            # No remote evaluators. If num_workers > 1, the DQNReplayEvaluator
            # will internally create more workers for parallelism. This means
            # there is only one replay buffer regardless of num_workers.
            self.remote_evaluators = []
            if self.config["multi_gpu"]:
                optimizer_cls = LocalMultiGPUOptimizer
            else:
                optimizer_cls = LocalSyncOptimizer

        self.optimizer = optimizer_cls(self.config["optimizer"],
                                       self.local_evaluator,
                                       self.remote_evaluators)
        self.saver = tf.train.Saver(max_to_keep=None)

        self.global_timestep = 0
        self.last_target_update_ts = 0
        self.num_target_updates = 0
Example 2
# NOTE: module-level imports reconstructed for this excerpt. The RLlib-internal
# names used below (Agent, DEFAULT_CONFIG, DQNEvaluator, DQNReplayEvaluator,
# AsyncOptimizer, LocalMultiGPUOptimizer, LocalSyncOptimizer, TrainingResult)
# are assumed to be provided by the surrounding package.
import os
import pickle

import numpy as np
import tensorflow as tf

import ray


class DQNAgent(Agent):
    _agent_name = "DQN"
    _allow_unknown_subkeys = [
        "model", "optimizer", "tf_session_args", "env_config"
    ]
    _default_config = DEFAULT_CONFIG

    def _init(self):
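        """Create the local/remote evaluators and the optimizer.

        With async_updates enabled, each remote evaluator holds its own shard
        of the replay buffer and an AsyncOptimizer drives them; otherwise a
        single local DQNReplayEvaluator owns the one replay buffer and a
        synchronous optimizer is used.
        """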
        if self.config["async_updates"]:
            self.local_evaluator = DQNEvaluator(self.registry,
                                                self.env_creator, self.config,
                                                self.logdir)
            remote_cls = ray.remote(
                num_cpus=1, num_gpus=self.config["num_gpus_per_worker"])(
                    DQNReplayEvaluator)
            remote_config = dict(self.config, num_workers=1)
            # In async mode, we create N remote evaluators, each with their
            # own replay buffer (i.e. the replay buffer is sharded).
            self.remote_evaluators = [
                remote_cls.remote(self.registry, self.env_creator,
                                  remote_config, self.logdir)
                for _ in range(self.config["num_workers"])
            ]
            optimizer_cls = AsyncOptimizer
        else:
            self.local_evaluator = DQNReplayEvaluator(self.registry,
                                                      self.env_creator,
                                                      self.config, self.logdir)
            # No remote evaluators. If num_workers > 1, the DQNReplayEvaluator
            # will internally create more workers for parallelism. This means
            # there is only one replay buffer regardless of num_workers.
            self.remote_evaluators = []
            if self.config["multi_gpu"]:
                optimizer_cls = LocalMultiGPUOptimizer
            else:
                optimizer_cls = LocalSyncOptimizer

        self.optimizer = optimizer_cls(self.config["optimizer"],
                                       self.local_evaluator,
                                       self.remote_evaluators)
        self.saver = tf.train.Saver(max_to_keep=None)

        self.global_timestep = 0
        self.last_target_update_ts = 0
        self.num_target_updates = 0

    def _train(self):
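        """Run one training iteration.

        Samples (before learning_starts) or optimizes until
        timesteps_per_iteration new timesteps have accumulated, refreshing
        the target network every target_network_update_freq timesteps, then
        folds the evaluators' stats into a TrainingResult.
        """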
        start_timestep = self.global_timestep

        while (self.global_timestep - start_timestep <
               self.config["timesteps_per_iteration"]):

            if self.global_timestep < self.config["learning_starts"]:
                self._populate_replay_buffer()
            else:
                self.optimizer.step()

            stats = self._update_global_stats()

            if self.global_timestep - self.last_target_update_ts > \
                    self.config["target_network_update_freq"]:
                self.local_evaluator.update_target()
                self.last_target_update_ts = self.global_timestep
                self.num_target_updates += 1

        mean_100ep_reward = 0.0
        mean_100ep_length = 0.0
        num_episodes = 0
        exploration = -1

        for s in stats:
            mean_100ep_reward += s["mean_100ep_reward"] / len(stats)
            mean_100ep_length += s["mean_100ep_length"] / len(stats)
            num_episodes += s["num_episodes"]
            exploration = s["exploration"]

        result = TrainingResult(
            episode_reward_mean=mean_100ep_reward,
            episode_len_mean=mean_100ep_length,
            episodes_total=num_episodes,
            timesteps_this_iter=self.global_timestep - start_timestep,
            info=dict(
                {
                    "exploration": exploration,
                    "num_target_updates": self.num_target_updates,
                }, **self.optimizer.stats()))

        return result

    def _update_global_stats(self):
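        """Gather stats and advance/broadcast the global timestep."""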
        if self.remote_evaluators:
            stats = ray.get([e.stats.remote() for e in self.remote_evaluators])
        else:
            stats = self.local_evaluator.stats()
            if not isinstance(stats, list):
                stats = [stats]
        new_timestep = sum(s["local_timestep"] for s in stats)
        assert new_timestep > self.global_timestep, new_timestep
        self.global_timestep = new_timestep
        self.local_evaluator.set_global_timestep(self.global_timestep)
        for e in self.remote_evaluators:
            e.set_global_timestep.remote(self.global_timestep)
        return stats

    def _populate_replay_buffer(self):
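        """Sample without optimizing to warm up the replay buffer(s)."""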
        if self.remote_evaluators:
            for e in self.remote_evaluators:
                e.sample.remote(no_replay=True)
        else:
            self.local_evaluator.sample(no_replay=True)

    def _save(self):
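        """Save a TF checkpoint plus agent state in an .extra_data pickle."""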
        checkpoint_path = self.saver.save(self.local_evaluator.sess,
                                          os.path.join(self.logdir,
                                                       "checkpoint"),
                                          global_step=self.iteration)
        extra_data = [
            self.local_evaluator.save(),
            ray.get([e.save.remote() for e in self.remote_evaluators]),
            self.global_timestep, self.num_target_updates,
            self.last_target_update_ts
        ]
        with open(checkpoint_path + ".extra_data", "wb") as f:
            pickle.dump(extra_data, f)
        return checkpoint_path

    def _restore(self, checkpoint_path):
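        """Restore TF variables and the extra state written by _save."""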
        self.saver.restore(self.local_evaluator.sess, checkpoint_path)
        with open(checkpoint_path + ".extra_data", "rb") as f:
            extra_data = pickle.load(f)
        self.local_evaluator.restore(extra_data[0])
        ray.get([
            e.restore.remote(d)
            for (d, e) in zip(extra_data[1], self.remote_evaluators)
        ])
        self.global_timestep = extra_data[2]
        self.num_target_updates = extra_data[3]
        self.last_target_update_ts = extra_data[4]

    def compute_action(self, observation):
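        """Return the greedy (epsilon=0) action for a single observation."""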
        return self.local_evaluator.dqn_graph.act(self.local_evaluator.sess,
                                                  np.array(observation)[None],
                                                  0.0)[0]
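
A minimal usage sketch of the interface shown in Example 2 (hypothetical: construction of the agent is elided because the Agent base-class constructor is not shown above, and `observation` stands in for a single environment observation; only methods defined in the excerpt are called):

# Assumes `agent` is an already-constructed DQNAgent instance (construction elided).
for i in range(5):
    result = agent._train()  # one iteration of timesteps_per_iteration timesteps
    print(i, result.episode_reward_mean, result.timesteps_this_iter)

checkpoint_path = agent._save()   # TF checkpoint plus "<path>.extra_data" pickle
agent._restore(checkpoint_path)   # restores variables, evaluators, and counters

action = agent.compute_action(observation)  # greedy (epsilon=0) action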