def _spawn(self):
    """Initialize distributed workers, learner, and centralized replay buffer."""
    replay_buffer = ReplayBuffer(
        self.hyper_params.buffer_size,
        self.hyper_params.batch_size,
    )
    per_buffer = PrioritizedBufferWrapper(
        replay_buffer, alpha=self.hyper_params.per_alpha
    )
    self.global_buffer = ApeXBufferWrapper.remote(
        per_buffer, self.args, self.hyper_params, self.comm_cfg
    )

    learner = build_learner(self.learner_cfg)
    self.learner = ApeXLearnerWrapper.remote(learner, self.comm_cfg)

    # Workers start from the learner's current parameters.
    state_dict = learner.get_state_dict()
    worker_build_args = dict(args=self.args, state_dict=state_dict)

    self.workers = []
    self.num_workers = self.hyper_params.num_workers
    for rank in range(self.num_workers):
        worker_build_args["rank"] = rank
        worker = build_worker(self.worker_cfg, build_args=worker_build_args)
        apex_worker = ApeXWorkerWrapper.remote(worker, self.args, self.comm_cfg)
        self.workers.append(apex_worker)

    self.logger = build_logger(self.logger_cfg)

    self.processes = self.workers + [self.learner, self.global_buffer, self.logger]
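The `.remote(...)` calls above are Ray's actor API: each wrapper class (presumably decorated with `@ray.remote`) is instantiated as a long-lived actor in its own process, and method calls on the handle run asynchronously. A minimal, self-contained illustration of that mechanism, using a toy `Counter` class that is not part of this codebase:

import ray

ray.init()

@ray.remote
class Counter:
    """Toy actor: its state lives in a dedicated worker process."""
    def __init__(self):
        self.n = 0

    def increment(self):
        self.n += 1
        return self.n

counter = Counter.remote()                              # spawn the actor
refs = [counter.increment.remote() for _ in range(3)]   # async method calls
print(ray.get(refs))                                    # [1, 2, 3]

Method calls on a single actor execute serially in submission order, which is why the Ape-X wrappers can safely hold mutable state such as the shared replay buffer.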
def test(self):
    """Load model from checkpoint and run logger for testing."""
    # NOTE: You could also load the Ape-X-trained model into the single-agent DQN.
    self.logger = build_logger(self.logger_cfg)
    self.logger.load_params.remote(self.args.load_from)
    ray.get([self.logger.test.remote(update_step=0, interim_test=False)])
    print("Exiting testing...")
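Regarding the NOTE: the learner's weights are ordinary PyTorch state dicts, so an Ape-X checkpoint can in principle be loaded into a plain single-process DQN. A hedged sketch, where the checkpoint path, key name, and stand-in network are all assumptions rather than details from this codebase:

import torch
import torch.nn as nn

# Stand-in network; the real architecture is whatever the configured
# backbone/head build. The path and key name below are hypothetical.
single_agent_dqn = nn.Sequential(nn.Linear(4, 64), nn.ReLU(), nn.Linear(64, 2))

checkpoint = torch.load("checkpoint/apex_dqn.pt", map_location="cpu")
single_agent_dqn.load_state_dict(checkpoint["dqn_state_dict"])
single_agent_dqn.eval()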
def _spawn(self):
    """Initialize distributed workers, learner, and centralized replay buffer."""
    replay_buffer = ReplayBuffer(
        self.hyper_params.buffer_size,
        self.hyper_params.batch_size,
    )
    per_buffer = PrioritizedBufferWrapper(
        replay_buffer, alpha=self.hyper_params.per_alpha
    )
    self.global_buffer = ApeXBufferWrapper.remote(
        per_buffer, self.hyper_params, self.comm_cfg
    )

    # Build learner
    learner_build_args = dict(
        hyper_params=self.hyper_params,
        log_cfg=self.log_cfg,
        env_name=self.env_info.name,
        state_size=self.env_info.observation_space.shape,
        output_size=self.env_info.action_space.n,
        is_test=self.is_test,
        load_from=self.load_from,
    )
    learner = build_learner(self.learner_cfg, learner_build_args)
    self.learner = ApeXLearnerWrapper.remote(learner, self.comm_cfg)

    # Build workers, seeded with the learner's current parameters
    state_dict = learner.get_state_dict()
    worker_build_args = dict(
        hyper_params=self.hyper_params,
        backbone=self.learner_cfg.backbone,
        head=self.learner_cfg.head,
        loss_type=self.learner_cfg.loss_type,
        state_dict=state_dict,
        env_name=self.env_info.name,
        state_size=self.env_info.observation_space.shape,
        output_size=self.env_info.action_space.n,
        is_atari=self.env_info.is_atari,
        max_episode_steps=self.max_episode_steps,
    )
    self.workers = []
    self.num_workers = self.hyper_params.num_workers
    for rank in range(self.num_workers):
        worker_build_args["rank"] = rank
        worker = build_worker(self.worker_cfg, build_args=worker_build_args)
        apex_worker = ApeXWorkerWrapper.remote(worker, self.comm_cfg)
        self.workers.append(apex_worker)

    # Build logger
    logger_build_args = dict(
        log_cfg=self.log_cfg,
        comm_cfg=self.comm_cfg,
        backbone=self.learner_cfg.backbone,
        head=self.learner_cfg.head,
        env_name=self.env_info.name,
        is_atari=self.env_info.is_atari,
        state_size=self.env_info.observation_space.shape,
        output_size=self.env_info.action_space.n,
        max_update_step=self.hyper_params.max_update_step,
        episode_num=self.episode_num,
        max_episode_steps=self.max_episode_steps,
        is_log=self.is_log,
        is_render=self.is_render,
        interim_test_num=self.interim_test_num,
    )
    self.logger = build_logger(self.logger_cfg, logger_build_args)

    self.processes = self.workers + [self.learner, self.global_buffer, self.logger]
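`_spawn()` only constructs the actors; presumably a driver then launches every process's run loop and blocks until they finish. A sketch of such a driver, under the assumption that each wrapper exposes a remote `run()` entry point (the method name and interface are not confirmed by this excerpt):

import ray

def train(self):
    """Assumed driver: spawn all actors, then launch and wait on their run() loops."""
    self._spawn()
    futures = [process.run.remote() for process in self.processes]
    ray.wait(futures, num_returns=len(futures))  # block until every loop exits

Because workers, learner, buffer, and logger all live in `self.processes`, a single list comprehension is enough to start the whole Ape-X pipeline concurrently.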