コード例 #1
0
    def get_b(self, pub_obses, range_idxs, legal_actions_lists, to_np=False):
        """
        Run the network on a batch and zero the outputs that the legal-action mask excludes.

        Args:
            pub_obses:              batch of public observations, forwarded unchanged to ``self._net``
            range_idxs:             one range index per observation; converted to a ``torch.long``
                                    tensor on ``self.device``
            legal_actions_lists:    one list of legal action ints per observation
            to_np (bool):           if True, move the result to the CPU and return it as a numpy array

        Returns:
            The network output multiplied element-wise by the float32 legal-action mask, either as a
            torch tensor or, when ``to_np`` is True, as a numpy array.
        """
        with torch.no_grad():
            idx_tensor = torch.tensor(range_idxs, dtype=torch.long, device=self.device)
            legal_mask = rl_util.batch_get_legal_action_mask_torch(
                n_actions=self._env_bldr.N_ACTIONS,
                legal_actions_lists=legal_actions_lists,
                device=self.device,
                dtype=torch.float32,
            )

            # Switch to evaluation mode before the forward pass.
            self.eval()
            out = self._net(pub_obses=pub_obses, range_idxs=idx_tensor, legal_action_masks=legal_mask)
            # In-place multiply; presumably the mask is 0 for illegal actions, zeroing their entries.
            out *= legal_mask

            return out.cpu().numpy() if to_np else out
コード例 #2
0
    def get_a_probs(self, pub_obses, range_idxs, legal_actions_lists):
        """
        Build a legal-action mask for the batch and delegate to ``get_a_probs2``.

        Args:
            pub_obses (list):               list of np arrays, each of shape [history_len, n_features]
            range_idxs (np.ndarray):        one range index per pub_obs, e.g. [2, 421, 58, 912, ...]
            legal_actions_lists (list):     list of lists; each inner list contains the ints of the
                                            legal actions for the corresponding pub_obs

        Returns:
            Whatever ``self.get_a_probs2`` returns for these inputs (action probabilities, judging
            by the name).
        """
        with torch.no_grad():
            # No dtype is passed here, so the helper's default mask dtype is used.
            legal_mask = rl_util.batch_get_legal_action_mask_torch(
                n_actions=self._env_bldr.N_ACTIONS,
                legal_actions_lists=legal_actions_lists,
                device=self.device,
            )
            result = self.get_a_probs2(
                pub_obses=pub_obses,
                range_idxs=range_idxs,
                legal_action_masks=legal_mask,
            )
        return result
コード例 #3
0
    def get_a_probs(self, pub_obses, range_idxs, legal_actions_lists, to_np=True):
        """
        Build a float32 legal-action mask for the batch and delegate to ``get_a_probs2``.

        Args:
            pub_obses (list):               batch (list) of np arrays of shape [history_len, n_features]
            range_idxs (list):              batch (list) of range_idxs (one per pub_obs), e.g. [2, 421, 58, 912, ...]
            legal_actions_lists (list):     batch (list) of lists of integers that represent legal actions
            to_np (bool):                   forwarded unchanged to ``get_a_probs2``

        Returns:
            Whatever ``self.get_a_probs2`` returns for these inputs.
        """
        with torch.no_grad():
            legal_mask = rl_util.batch_get_legal_action_mask_torch(
                n_actions=self._env_bldr.N_ACTIONS,
                legal_actions_lists=legal_actions_lists,
                device=self._device,
                dtype=torch.float32,
            )
            probs = self.get_a_probs2(
                pub_obses=pub_obses,
                range_idxs=range_idxs,
                legal_action_masks=legal_mask,
                to_np=to_np,
            )
        return probs
コード例 #4
0
    def get_a_probs_tensor(self, pub_obses, range_idxs, legal_actions_lists,
                           hand_info):
        """
        Gradient-carrying variant of the action-probability query.

        Deliberately NOT wrapped in ``torch.no_grad()``: this function is used in training, so the
        forward pass through ``self._net`` keeps its autograd graph.

        Args:
            pub_obses:                  forwarded unchanged to ``self._net``
            range_idxs (np.ndarray):    numpy array of range indices; converted with ``torch.from_numpy``
                                        to a ``torch.long`` tensor on ``self.device``
            legal_actions_lists:        turned into a mask by ``rl_util.batch_get_legal_action_mask_torch``
            hand_info:                  forwarded unchanged to ``self._net``

        Returns:
            torch.Tensor: softmax over the last dimension of the network output, moved to the CPU
            and kept in tensor form (no ``.numpy()`` conversion).
        """
        # NOTE(review): view(1, -1) flattens the entire mask into one row, which appears to assume
        # a single sample per call -- confirm with callers.
        flat_mask = rl_util.batch_get_legal_action_mask_torch(
            n_actions=self._env_bldr.N_ACTIONS,
            legal_actions_lists=legal_actions_lists,
            device=self.device,
        ).view(1, -1)
        idx_tensor = torch.from_numpy(range_idxs).to(dtype=torch.long, device=self.device)

        net_out = self._net(
            pub_obses=pub_obses,
            range_idxs=idx_tensor,
            legal_action_masks=flat_mask,
            hand_info=hand_info,
        )
        probs = nnf.softmax(net_out, dim=-1)
        return probs.cpu()  # intentionally stays a Tensor (no .numpy())
コード例 #5
0
    def select_br_a(self, pub_obses, range_idxs, legal_actions_lists, explore=False):
        """
        Choose one action per batch entry: greedily from the network output, or at random when exploring.

        Args:
            pub_obses:              batch of public observations, forwarded unchanged to ``self._net``
            range_idxs:             one range index per observation; converted to a ``torch.long``
                                    tensor on ``self.device``
            legal_actions_lists:    one list of legal action ints per observation (each must be
                                    non-empty when the random branch is taken)
            explore (bool):         if True, with probability ``self._eps`` the whole batch takes
                                    uniformly random legal actions instead of greedy ones

        Returns:
            np.ndarray: one selected action per batch entry.
        """
        # A single draw decides for the whole batch: either every entry explores or none does.
        if explore and (np.random.random() < self._eps):
            return np.array(
                [legal_actions[np.random.randint(len(legal_actions))] for legal_actions in legal_actions_lists]
            )

        with torch.no_grad():
            self.eval()
            range_idxs = torch.tensor(range_idxs, dtype=torch.long, device=self.device)
            masks = rl_util.batch_get_legal_action_mask_torch(
                n_actions=self._env_bldr.N_ACTIONS,
                legal_actions_lists=legal_actions_lists,
                device=self.device,
                dtype=torch.float32)
            q = self._net(pub_obses=pub_obses, range_idxs=range_idxs,
                          legal_action_masks=masks).cpu().numpy()

            # Push every action that is not legal for a row far below any real value so argmax
            # can never pick it. A per-row set makes each membership test O(1) instead of a
            # linear scan of the legal-action list.
            for b in range(q.shape[0]):
                legal = set(legal_actions_lists[b])
                illegal_actions = [a for a in self._n_actions_arranged if a not in legal]
                if illegal_actions:
                    q[b, np.array(illegal_actions)] = -1e20

            return np.argmax(q, axis=1)