Example #1
    def act(self, step_wrappers):
        # """""""""""""""""""""
        # Act
        # """""""""""""""""""""
        SeatActorBase.act_mixed(step_wrappers=step_wrappers,
                                owner=self.seat_id,
                                br_learner=self.br_learner,
                                avg_learner=self.avg_learner,
                                current_policy_tags=self._current_policy_tags,
                                random_prob=self.br_learner.eps)

        # """""""""""""""""""""
        # Add to memories
        # """""""""""""""""""""
        for sw in step_wrappers:
            e_i = sw.env_idx
            if (self._current_policy_tags[e_i] == SeatActorBase.BR) and (
                    self._t_prof.add_random_actions_to_buffer
                    or (not sw.action_was_random)):
                self._avg_buf_savers[e_i].add_step(
                    pub_obs=sw.obs,
                    a=sw.action,
                    legal_actions_mask=rl_util.get_legal_action_mask_np(
                        n_actions=self._env_bldr.N_ACTIONS,
                        legal_actions_list=sw.legal_actions_list))
            self._br_memory_savers[e_i].add_experience(
                obs_t_before_acted=sw.obs,
                a_selected_t=sw.action,
                legal_actions_list_t=sw.legal_actions_list)
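
Across these examples, SeatActorBase.act_mixed routes every pending environment either to the best-response (BR) network or to the average-policy network according to the per-env policy tag, optionally mixing in random exploration via random_prob. The snippet below is only a minimal, self-contained sketch of that dispatch idea; q_values, avg_probs and the integer tag constants are assumptions for illustration, not the library's internals.

import numpy as np

BR, AVG = 0, 1  # stand-ins for SeatActorBase.BR / SeatActorBase.AVG (assumed values)

def act_mixed_sketch(step_wrappers, current_policy_tags, q_values, avg_probs, random_prob, rng=np.random):
    """Toy dispatch: BR-tagged envs act (eps-)greedily on Q-values, AVG-tagged envs sample the average policy."""
    for sw in step_wrappers:
        legal = sw.legal_actions_list
        if current_policy_tags[sw.env_idx] == BR:
            if rng.random() < random_prob:
                sw.action = int(rng.choice(legal))                    # exploratory random action
                sw.action_was_random = True
            else:
                q = q_values[sw.env_idx]
                sw.action = int(max(legal, key=lambda a: q[a]))       # greedy over legal actions
                sw.action_was_random = False
        else:
            p = np.asarray(avg_probs[sw.env_idx])[legal]              # average policy restricted to legal actions
            sw.action = int(rng.choice(legal, p=p / p.sum()))
            sw.action_was_random = False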
Example #2
    def act_for_avg_opp(self, step_wrappers):
        """
        Purely random because that's how it should be for correct reach
        """
        SeatActorBase.act_avg(
            step_wrappers=step_wrappers,
            owner=self.owner,
            avg_learner=self.avg_learner,
        )
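
The docstring's concern is that the opponent seat must not bias the traverser's reach probabilities while average-policy data is collected; some variants therefore act for the opponent with the average policy, others purely at random. Purely as an illustration of the "purely random" case, a uniform draw over the legal actions could look like this (hypothetical helper, not part of SeatActorBase):

import numpy as np

def act_uniform_random_sketch(step_wrappers, rng=np.random):
    """Illustration only: give every pending env a uniformly random legal action."""
    for sw in step_wrappers:
        sw.action = int(rng.choice(sw.legal_actions_list))
        sw.action_was_random = True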
Example #3
    def act_for_br_opp(self, step_wrappers):
        """ Anticipatory; greedy BR + AVG """
        SeatActorBase.act_mixed(
            step_wrappers=step_wrappers,
            br_learner=self.br_learner,
            owner=self.owner,
            avg_learner=self.avg_learner,
            current_policy_tags=self._current_policy_tags_OPP_BR,
            random_prob=0)
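
pick_training_policy(br_prob=self.sampler.antic) shows up in most of these snippets: NFSP's anticipatory parameter (antic, often written as eta) is the probability of playing the best response rather than the average policy for the upcoming hand. A minimal sketch of that draw, assuming the tags are plain integer constants:

import numpy as np

BR, AVG = 0, 1  # assumed tag constants

def pick_training_policy_sketch(br_prob, rng=np.random):
    """Bernoulli draw: with probability br_prob use the BR network, otherwise the average policy."""
    return BR if rng.random() < br_prob else AVG

In the examples above the result is stored per env in the _current_policy_tags_* arrays and re-drawn whenever a hand terminates.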
Example #4
    def act_for_avg_opp(self, step_wrappers):
        """
        Purely random because that's how it should be for correct reach
        """
        SeatActorBase.act_mixed(
            step_wrappers=step_wrappers,
            br_learner=self.br_learner,
            owner=self.owner,
            avg_learner=self.avg_learner,
            current_policy_tags=self._current_policy_tags_O_AVG,
            explore=True)
Example #5
    def act_for_br_trav(self, step_wrappers):
        # Act
        SeatActorBase.act_eps_greedy(step_wrappers=step_wrappers, br_learner=self.br_learner, owner=self.owner,
                                     random_prob=self._constant_eps)

        # Add to memories
        for sw in step_wrappers:
            e_i = sw.env_idx
            self._br_memory_savers[e_i].add_experience(obs_t_before_acted=sw.obs,
                                                       a_selected_t=sw.action,
                                                       legal_actions_list_t=sw.legal_actions_list)
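
act_eps_greedy with random_prob=self._constant_eps is ordinary epsilon-greedy selection over the BR network's Q-values. A self-contained sketch of the rule (array and argument names are assumptions, not the library's internals):

import numpy as np

def eps_greedy_sketch(q_values, legal_actions_list, eps, rng=np.random):
    """With probability eps pick a random legal action, otherwise the legal action with the highest Q-value."""
    if rng.random() < eps:
        return int(rng.choice(legal_actions_list)), True    # (action, action_was_random)
    best = max(legal_actions_list, key=lambda a: q_values[a])
    return int(best), False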
Example #6
    def update_if_terminal_for_br(self, step_wrappers, is_traverser=False):
        for sw in step_wrappers:
            if sw.TERMINAL:
                if is_traverser:
                    # Flush the finished hand to the BR replay buffer with its terminal reward,
                    # then reset the saver for the next hand.
                    self._br_memory_savers[sw.env_idx].add_to_buf(
                        reward_p=sw.term_rew_all[self.owner],
                        terminal_obs=sw.term_obs,
                    )
                    self._br_memory_savers[sw.env_idx].reset(
                        range_idx=sw.range_idxs[self.owner])
                    # Re-draw whether this env plays BR or AVG in the next hand.
                    self._current_policy_tags_T_BR[
                        sw.env_idx] = SeatActorBase.pick_training_policy(
                            br_prob=self.sampler.antic)
                else:
                    self._current_policy_tags_OPP_BR[
                        sw.env_idx] = SeatActorBase.pick_training_policy(
                            br_prob=self.sampler.antic)
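
The _br_memory_savers in these snippets follow a simple per-env episode protocol: add_experience accumulates the steps of the current hand, add_to_buf flushes them to the replay buffer together with the terminal reward once the hand ends, and reset starts the next hand under a new range index. A toy version of that protocol (not PokerRL's actual class, just the shape of the interface used above) might look like:

class ToyEpisodeSaver:
    """Illustrative stand-in for a per-env BR memory saver."""

    def __init__(self, replay_buffer):
        self._buf = replay_buffer        # anything with an .add(transition) method
        self._steps = []
        self._range_idx = None

    def reset(self, range_idx):
        self._steps = []
        self._range_idx = range_idx      # private-hand / range index of the new episode

    def add_experience(self, obs_t_before_acted, a_selected_t, legal_actions_list_t):
        self._steps.append((obs_t_before_acted, a_selected_t, legal_actions_list_t))

    def add_to_buf(self, reward_p, terminal_obs):
        # The reward is only known at the end of the hand, so it is attached here.
        for obs, a, legal in self._steps:
            self._buf.add((self._range_idx, obs, a, legal, reward_p, terminal_obs))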
Example #7
    def act_for_br_trav(self, step_wrappers):
        # Act
        SeatActorBase.act_mixed(
            step_wrappers=step_wrappers,
            br_learner=self.br_learner,
            owner=self.owner,
            avg_learner=self.avg_learner,
            current_policy_tags=self._current_policy_tags_T_BR,
            explore=True)

        # Add to memories
        for sw in step_wrappers:
            e_i = sw.env_idx
            self._br_memory_savers[e_i].add_experience(
                obs_t_before_acted=sw.obs,
                a_selected_t=sw.action,
                legal_actions_list_t=sw.legal_actions_list)
Example #8
    def init(self, sws_br, sws_avg, nfsp_iter):
        self._current_policy_tags_T_BR = np.empty(
            shape=self._t_prof.n_steps_br_per_iter_per_la, dtype=np.int32)
        self._current_policy_tags_OPP_BR = np.empty(
            shape=self._t_prof.n_steps_br_per_iter_per_la, dtype=np.int32)
        for sw in sws_br:
            self._current_policy_tags_OPP_BR[
                sw.env_idx] = SeatActorBase.pick_training_policy(
                    br_prob=self.sampler.antic)
            self._current_policy_tags_T_BR[
                sw.env_idx] = SeatActorBase.pick_training_policy(
                    br_prob=self.sampler.antic)
            self._br_memory_savers[sw.env_idx].reset(
                range_idx=sw.range_idxs[self.owner])

        for sw in sws_avg:
            self._avg_memory_savers[sw.env_idx].reset(
                range_idx=sw.range_idxs[self.owner],
                sample_weight=nfsp_iter if self._t_prof.linear else 1)
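
sample_weight=nfsp_iter if self._t_prof.linear else 1 is the hook for linearly weighted averaging: data written in iteration t counts t times as much as data from iteration 1, so the stored average policy leans toward later, stronger iterations. A sketch of drawing a batch from such a weighted buffer (all names are assumptions):

import numpy as np

def sample_weighted_batch_sketch(entries, sample_weights, batch_size, rng=np.random):
    """Draw a training batch in which each stored entry is picked proportionally to its sample weight."""
    w = np.asarray(sample_weights, dtype=np.float64)
    idxs = rng.choice(len(entries), size=batch_size, p=w / w.sum())
    return [entries[i] for i in idxs]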
Example #9
    def update_if_terminal_for_avg(self, step_wrappers, nfsp_iter, is_traverser=False):
        for sw in step_wrappers:
            if sw.TERMINAL:
                if is_traverser:
                    # Start a new hand; with linear averaging, later iterations get a larger sample weight.
                    self._avg_memory_savers[sw.env_idx].reset(
                        range_idx=sw.range_idxs[self.owner],
                        sample_weight=nfsp_iter if self._t_prof.linear else 1)
                else:
                    # Opponent seat: only re-draw the BR/AVG policy tag.
                    self._current_policy_tags_O_AVG[
                        sw.env_idx] = SeatActorBase.pick_training_policy(
                            br_prob=self.sampler.antic)
Example #10
    def update_if_terminal(self, step_wrappers, nfsp_iter):
        for sw in step_wrappers:
            if sw.TERMINAL:
                # Flush the finished hand to the BR replay buffer with its terminal reward.
                self._br_memory_savers[sw.env_idx].add_to_buf(
                    reward_p=sw.term_rew_all[self.seat_id],
                    terminal_obs=sw.term_obs,
                )
                # Reset both savers for the next hand; linear averaging weights later iterations higher.
                self._br_memory_savers[sw.env_idx].reset(
                    range_idx=sw.range_idxs[self.seat_id])
                self._avg_buf_savers[sw.env_idx].reset(
                    range_idx=sw.range_idxs[self.seat_id],
                    sample_weight=nfsp_iter if self._t_prof.linear else 1)

                # Re-draw whether this env plays BR or AVG in the next hand.
                self._current_policy_tags[
                    sw.env_idx] = SeatActorBase.pick_training_policy(
                        br_prob=self.sampler.antic)
Example #11
    def act_for_avg_trav(self, step_wrappers):
        """ BR greedy """
        with torch.no_grad():
            if len(step_wrappers) > 0:
                # Act greedily with the BR net (random_prob=0) and store every step as a
                # supervised target for the average-policy network.
                actions, _ = SeatActorBase.choose_a_br(step_wrappers=step_wrappers, owner=self.owner,
                                                       br_learner=self.br_learner, random_prob=0)
                for a, sw in zip(actions, step_wrappers):
                    a = a.item()
                    sw.action = a
                    sw.action_was_random = False
                    self._avg_memory_savers[sw.env_idx].add_step(
                        pub_obs=sw.obs,
                        a=a,
                        legal_actions_mask=rl_util.get_legal_action_mask_np(
                            n_actions=self._env_bldr.N_ACTIONS,
                            legal_actions_list=sw.legal_actions_list))
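
rl_util.get_legal_action_mask_np converts the list of currently legal action indices into a fixed-length binary mask over all N_ACTIONS, which downstream training can use to zero out illegal actions. Functionally it amounts to roughly the following (a simplified sketch, not the PokerRL source):

import numpy as np

def legal_action_mask_sketch(n_actions, legal_actions_list, dtype=np.uint8):
    """Binary mask of length n_actions with ones at the legal action indices."""
    mask = np.zeros(shape=n_actions, dtype=dtype)
    mask[legal_actions_list] = 1
    return mask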