Example no. 1
    def _initialize(self):
        """Initialize non-common things."""
        if not self.is_test:
            # replay memory for a single step
            self.memory = ReplayBuffer(
                self.hyper_params.buffer_size,
                self.hyper_params.batch_size,
            )
            self.memory = PrioritizedBufferWrapper(
                self.memory, alpha=self.hyper_params.per_alpha)

            # replay memory for multi-steps
            if self.use_n_step:
                self.memory_n = ReplayBuffer(
                    self.hyper_params.buffer_size,
                    self.hyper_params.batch_size,
                    n_step=self.hyper_params.n_step,
                    gamma=self.hyper_params.gamma,
                )

        build_args = dict(
            hyper_params=self.hyper_params,
            log_cfg=self.log_cfg,
            env_name=self.env_info.name,
            state_size=self.env_info.observation_space.shape,
            output_size=self.env_info.action_space.n,
            is_test=self.is_test,
            load_from=self.load_from,
        )
        self.learner = build_learner(self.learner_cfg, build_args)
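
In this and the later examples, the plain ReplayBuffer is wrapped by PrioritizedBufferWrapper, which samples transitions with probability proportional to priority^alpha; Example 6 below shows the wrapper exposes a sum_tree for this. A minimal illustrative sum tree under the standard proportional-PER scheme (a sketch, not the repo's implementation):

import random

class SumTree:
    """Binary tree whose internal nodes store the sum of their children.

    Leaves hold per-transition priorities; sampling walks down from the
    root, so a draw costs O(log capacity).
    """

    def __init__(self, capacity: int):
        self.capacity = capacity
        self.tree = [0.0] * (2 * capacity)  # leaves occupy [capacity, 2*capacity)

    def __setitem__(self, idx: int, priority: float):
        pos = idx + self.capacity
        self.tree[pos] = priority
        pos //= 2
        while pos >= 1:  # propagate the new sum up to the root
            self.tree[pos] = self.tree[2 * pos] + self.tree[2 * pos + 1]
            pos //= 2

    def sample_idx(self) -> int:
        """Draw a leaf index with probability proportional to its priority."""
        mass = random.uniform(0.0, self.tree[1])  # tree[1] holds the total priority
        pos = 1
        while pos < self.capacity:
            left = 2 * pos
            if mass <= self.tree[left]:
                pos = left
            else:
                mass -= self.tree[left]
                pos = left + 1
        return pos - self.capacity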
Example no. 2
    def _initialize(self):
        """Initialize non-common things."""
        self.per_beta = self.hyper_params.per_beta

        self.use_n_step = self.hyper_params.n_step > 1

        if not self.args.test:
            # load demo replay memory
            with open(self.args.demo_path, "rb") as f:
                demos = pickle.load(f)

            if self.use_n_step:
                demos, demos_n_step = common_utils.get_n_step_info_from_demo(
                    demos, self.hyper_params.n_step, self.hyper_params.gamma)

                # replay memory for multi-steps
                self.memory_n = ReplayBuffer(
                    max_len=self.hyper_params.buffer_size,
                    batch_size=self.hyper_params.batch_size,
                    n_step=self.hyper_params.n_step,
                    gamma=self.hyper_params.gamma,
                    demo=demos_n_step,
                )

            # replay memory for a single step
            self.memory = ReplayBuffer(
                self.hyper_params.buffer_size,
                self.hyper_params.batch_size,
            )
            self.memory = PrioritizedBufferWrapper(
                self.memory, alpha=self.hyper_params.per_alpha)

        self.learner_cfg.type = "DDPGfDLearner"
        self.learner = build_learner(self.learner_cfg)
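
Here get_n_step_info_from_demo folds raw demo transitions into n-step ones so that memory_n can store n-step returns alongside the single-step buffer. A sketch of the core fold, assuming (state, action, reward, next_state, done) transitions (a hypothetical helper in the spirit of the repo's utility, not its exact code):

from collections import deque
from typing import Deque, Tuple

def get_n_step_info(n_step_buffer: Deque, gamma: float) -> Tuple[float, object, bool]:
    """Fold the n most recent transitions into one n-step transition.

    Walks backwards, discounting the accumulated reward and cutting the
    trajectory short at the first terminal transition encountered.
    """
    reward, next_state, done = n_step_buffer[-1][-3:]
    for transition in reversed(list(n_step_buffer)[:-1]):
        r, n_s, d = transition[-3:]
        reward = r + gamma * reward * (1 - d)
        next_state, done = (n_s, d) if d else (next_state, done)
    return reward, next_state, done

# e.g. two non-terminal transitions with gamma = 0.9:
buf = deque([(0, 0, 1.0, 1, False), (1, 1, 2.0, 2, False)], maxlen=2)
assert get_n_step_info(buf, 0.9) == (1.0 + 0.9 * 2.0, 2, False)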
Example no. 3
    def _spawn(self):
        """Intialize distributed worker, learner and centralized replay buffer."""
        replay_buffer = ReplayBuffer(
            self.hyper_params.buffer_size,
            self.hyper_params.batch_size,
        )
        per_buffer = PrioritizedBufferWrapper(
            replay_buffer, alpha=self.hyper_params.per_alpha)
        self.global_buffer = ApeXBufferWrapper.remote(per_buffer, self.args,
                                                      self.hyper_params,
                                                      self.comm_cfg)

        learner = build_learner(self.learner_cfg)
        self.learner = ApeXLearnerWrapper.remote(learner, self.comm_cfg)

        state_dict = learner.get_state_dict()
        worker_build_args = dict(args=self.args, state_dict=state_dict)

        self.workers = []
        self.num_workers = self.hyper_params.num_workers
        for rank in range(self.num_workers):
            worker_build_args["rank"] = rank
            worker = build_worker(self.worker_cfg,
                                  build_args=worker_build_args)
            apex_worker = ApeXWorkerWrapper.remote(worker, self.args,
                                                   self.comm_cfg)
            self.workers.append(apex_worker)

        self.logger = build_logger(self.logger_cfg)

        self.processes = self.workers + [
            self.learner, self.global_buffer, self.logger
        ]
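
The .remote calls above are Ray's actor API: each wrapper class runs in its own process, and method calls on it return futures. A minimal self-contained sketch of that pattern (assumes the ray package is installed; the Counter class is illustrative):

import ray

ray.init()

@ray.remote
class Counter:
    """Toy actor: its state lives in a separate process."""

    def __init__(self):
        self.value = 0

    def increment(self) -> int:
        self.value += 1
        return self.value

counter = Counter.remote()           # spawns the actor process
future = counter.increment.remote()  # returns immediately with a future
assert ray.get(future) == 1          # block until the result is ready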
Example no. 4
    def _initialize(self):
        """Initialize non-common things."""
        if not self.args.test:
            # load demo replay memory
            demos = self._load_demos()

            if self.use_n_step:
                demos, demos_n_step = common_utils.get_n_step_info_from_demo(
                    demos, self.hyper_params.n_step, self.hyper_params.gamma)

                self.memory_n = ReplayBuffer(
                    max_len=self.hyper_params.buffer_size,
                    batch_size=self.hyper_params.batch_size,
                    n_step=self.hyper_params.n_step,
                    gamma=self.hyper_params.gamma,
                    demo=demos_n_step,
                )

            # replay memory
            self.memory = ReplayBuffer(
                self.hyper_params.buffer_size,
                self.hyper_params.batch_size,
                demo=demos,
            )
            self.memory = PrioritizedBufferWrapper(
                self.memory,
                alpha=self.hyper_params.per_alpha,
                epsilon_d=self.hyper_params.per_eps_demo,
            )

        self.learner_cfg.type = "DQfDLearner"
        self.learner = build_learner(self.learner_cfg)
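
This example (and Examples 7 and 8) passes epsilon_d=per_eps_demo to the wrapper; in DQfD-style training this is an extra priority bonus that keeps demonstration transitions being replayed even after their TD error shrinks. An assumed form of that rule, following the DQfD paper rather than the repo's exact code:

def compute_priority(td_error: float, is_demo: bool,
                     eps: float = 1e-6, eps_demo: float = 1.0,
                     alpha: float = 0.6) -> float:
    """Proportional priority with a constant bonus for demo transitions."""
    bonus = eps_demo if is_demo else 0.0
    return (abs(td_error) + eps + bonus) ** alpha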
Example no. 5
    def _initialize(self):
        """Initialize non-common things."""
        if not self.args.test:

            self.memory = RecurrentReplayBuffer(
                self.hyper_params.buffer_size,
                self.hyper_params.batch_size,
                self.hyper_params.sequence_size,
                self.hyper_params.overlap_size,
                n_step=self.hyper_params.n_step,
                gamma=self.hyper_params.gamma,
            )
            self.memory = PrioritizedBufferWrapper(
                self.memory, alpha=self.hyper_params.per_alpha)

            # replay memory for multi-steps
            if self.use_n_step:
                self.memory_n = RecurrentReplayBuffer(
                    self.hyper_params.buffer_size,
                    self.hyper_params.batch_size,
                    self.hyper_params.sequence_size,
                    self.hyper_params.overlap_size,
                    n_step=self.hyper_params.n_step,
                    gamma=self.hyper_params.gamma,
                )

        self.learner = build_learner(self.learner_cfg)
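
RecurrentReplayBuffer additionally takes sequence_size and overlap_size, which suggests R2D2-style storage of fixed-length, partially overlapping sequences for recurrent networks. A sketch of that slicing under the assumed semantics (illustrative, not the repo's code):

from typing import List

def slice_episode(episode: List, sequence_size: int, overlap_size: int) -> List[List]:
    """Cut an episode into fixed-length sequences overlapping by overlap_size."""
    stride = sequence_size - overlap_size
    last_start = max(len(episode) - sequence_size, 0)
    return [episode[i:i + sequence_size] for i in range(0, last_start + 1, stride)]

chunks = slice_episode(list(range(10)), sequence_size=4, overlap_size=2)
assert chunks == [[0, 1, 2, 3], [2, 3, 4, 5], [4, 5, 6, 7], [6, 7, 8, 9]]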
Example no. 6
def generate_prioritized_buffer(
    buffer_length: int, batch_size: int, idx_lst=None, prior_lst=None
) -> Tuple[PrioritizedBufferWrapper, List]:
    """Generate Prioritized Replay Buffer with random Prior."""
    buffer = ReplayBuffer(max_len=buffer_length, batch_size=batch_size)
    prioritized_buffer = PrioritizedBufferWrapper(buffer)
    priority = np.random.randint(10, size=buffer_length)
    for i, j in enumerate(priority):
        prioritized_buffer.sum_tree[i] = j
    if idx_lst:
        for i, j in zip(idx_lst, prior_lst):
            priority[i] = j
            prioritized_buffer.sum_tree[i] = j

    prop_lst = [i / sum(priority) for i in priority]

    return prioritized_buffer, prop_lst
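
This helper reads like a test fixture: it fills the wrapper's sum tree with random priorities and returns the expected sampling proportions for comparison. Hypothetical usage in a test:

buffer, prop_lst = generate_prioritized_buffer(buffer_length=32, batch_size=8)
assert abs(sum(prop_lst) - 1.0) < 1e-6  # proportions form a distribution
# One could now sample repeatedly and compare empirical frequencies to prop_lst.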
Example no. 7
    def _initialize(self):
        """Initialize non-common things."""
        self.per_beta = self.hyper_params.per_beta
        self.use_n_step = self.hyper_params.n_step > 1

        if not self.is_test:
            # load demo replay memory
            with open(self.hyper_params.demo_path, "rb") as f:
                demos = pickle.load(f)

            if self.use_n_step:
                demos, demos_n_step = common_utils.get_n_step_info_from_demo(
                    demos, self.hyper_params.n_step, self.hyper_params.gamma)

                # replay memory for multi-steps
                self.memory_n = ReplayBuffer(
                    max_len=self.hyper_params.buffer_size,
                    batch_size=self.hyper_params.batch_size,
                    n_step=self.hyper_params.n_step,
                    gamma=self.hyper_params.gamma,
                    demo=demos_n_step,
                )

            # replay memory for a single step
            self.memory = ReplayBuffer(
                self.hyper_params.buffer_size,
                self.hyper_params.batch_size,
                demo=demos,
            )
            self.memory = PrioritizedBufferWrapper(
                self.memory,
                alpha=self.hyper_params.per_alpha,
                epsilon_d=self.hyper_params.per_eps_demo,
            )

        build_args = dict(
            hyper_params=self.hyper_params,
            log_cfg=self.log_cfg,
            noise_cfg=self.noise_cfg,
            env_name=self.env_info.name,
            state_size=self.env_info.observation_space.shape,
            output_size=self.env_info.action_space.shape[0],
            is_test=self.is_test,
            load_from=self.load_from,
        )
        self.learner = build_learner(self.learner_cfg, build_args)
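
Note the structural difference from Example 1: output_size is action_space.shape[0] rather than action_space.n, i.e. this agent acts in a continuous action space (and receives a noise_cfg for exploration noise). With gym spaces the distinction looks like this (assumes the gym package; environment ids depend on the installed version):

import gym

discrete = gym.make("CartPole-v1")
continuous = gym.make("Pendulum-v1")

assert discrete.action_space.n == 2           # number of discrete actions
assert continuous.action_space.shape[0] == 1  # dimension of the action vector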
Example no. 8
    def _initialize(self):
        """Initialize non-common things."""
        if not self.is_test:
            # load demo replay memory
            demos = self._load_demos()

            if self.use_n_step:
                demos, demos_n_step = common_utils.get_n_step_info_from_demo(
                    demos, self.hyper_params.n_step, self.hyper_params.gamma)

                self.memory_n = ReplayBuffer(
                    max_len=self.hyper_params.buffer_size,
                    batch_size=self.hyper_params.batch_size,
                    n_step=self.hyper_params.n_step,
                    gamma=self.hyper_params.gamma,
                    demo=demos_n_step,
                )

            # replay memory
            self.memory = ReplayBuffer(
                self.hyper_params.buffer_size,
                self.hyper_params.batch_size,
                demo=demos,
            )
            self.memory = PrioritizedBufferWrapper(
                self.memory,
                alpha=self.hyper_params.per_alpha,
                epsilon_d=self.hyper_params.per_eps_demo,
            )

        build_args = dict(
            hyper_params=self.hyper_params,
            log_cfg=self.log_cfg,
            env_name=self.env_info.name,
            state_size=self.env_info.observation_space.shape,
            output_size=self.env_info.action_space.n,
            is_test=self.is_test,
            load_from=self.load_from,
        )
        self.learner_cfg.type = "DQfDLearner"
        self.learner = build_learner(self.learner_cfg, build_args)
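
Overwriting learner_cfg.type before calling build_learner implies a config-driven registry: the string names the learner class to instantiate. A hypothetical sketch of that pattern (the repo's actual registry machinery will differ in detail):

from typing import Any, Callable, Dict

LEARNERS: Dict[str, Callable] = {}

def register(name: str) -> Callable:
    """Class decorator that records a learner class under a config name."""
    def deco(cls):
        LEARNERS[name] = cls
        return cls
    return deco

@register("DQfDLearner")
class DQfDLearner:
    def __init__(self, **build_args: Any):
        self.build_args = build_args

def build_learner(cfg: Dict, build_args: Dict):
    return LEARNERS[cfg["type"]](**build_args)

learner = build_learner({"type": "DQfDLearner"}, {"device": "cpu"})
assert isinstance(learner, DQfDLearner)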
Example no. 9
    def _spawn(self):
        """Intialize distributed worker, learner and centralized replay buffer."""
        replay_buffer = ReplayBuffer(
            self.hyper_params.buffer_size,
            self.hyper_params.batch_size,
        )
        per_buffer = PrioritizedBufferWrapper(
            replay_buffer, alpha=self.hyper_params.per_alpha)
        self.global_buffer = ApeXBufferWrapper.remote(per_buffer,
                                                      self.hyper_params,
                                                      self.comm_cfg)

        # Build learner
        learner_build_args = dict(
            hyper_params=self.hyper_params,
            log_cfg=self.log_cfg,
            env_name=self.env_info.name,
            state_size=self.env_info.observation_space.shape,
            output_size=self.env_info.action_space.n,
            is_test=self.is_test,
            load_from=self.load_from,
        )
        learner = build_learner(self.learner_cfg, learner_build_args)
        self.learner = ApeXLearnerWrapper.remote(learner, self.comm_cfg)

        # Build workers
        state_dict = learner.get_state_dict()
        worker_build_args = dict(
            hyper_params=self.hyper_params,
            backbone=self.learner_cfg.backbone,
            head=self.learner_cfg.head,
            loss_type=self.learner_cfg.loss_type,
            state_dict=state_dict,
            env_name=self.env_info.name,
            state_size=self.env_info.observation_space.shape,
            output_size=self.env_info.action_space.n,
            is_atari=self.env_info.is_atari,
            max_episode_steps=self.max_episode_steps,
        )
        self.workers = []
        self.num_workers = self.hyper_params.num_workers
        for rank in range(self.num_workers):
            worker_build_args["rank"] = rank
            worker = build_worker(self.worker_cfg,
                                  build_args=worker_build_args)
            apex_worker = ApeXWorkerWrapper.remote(worker, self.comm_cfg)
            self.workers.append(apex_worker)

        # Build logger
        logger_build_args = dict(
            log_cfg=self.log_cfg,
            comm_cfg=self.comm_cfg,
            backbone=self.learner_cfg.backbone,
            head=self.learner_cfg.head,
            env_name=self.env_info.name,
            is_atari=self.env_info.is_atari,
            state_size=self.env_info.observation_space.shape,
            output_size=self.env_info.action_space.n,
            max_update_step=self.hyper_params.max_update_step,
            episode_num=self.episode_num,
            max_episode_steps=self.max_episode_steps,
            is_log=self.is_log,
            is_render=self.is_render,
            interim_test_num=self.interim_test_num,
        )

        self.logger = build_logger(self.logger_cfg, logger_build_args)

        self.processes = self.workers + [
            self.learner, self.global_buffer, self.logger
        ]