Example #1
    def get_a_probs_for_each_hand(self, pub_obs, legal_actions_list):
        with torch.no_grad():
            mask = rl_util.get_legal_action_mask_torch(
                n_actions=self._env_bldr.N_ACTIONS,
                legal_actions_list=legal_actions_list,
                device=self.device,
                dtype=torch.uint8)
            mask = mask.unsqueeze(0).expand(self._env_bldr.rules.RANGE_SIZE,
                                            -1)

            pred = self._net(pub_obses=[pub_obs] *
                             self._env_bldr.rules.RANGE_SIZE,
                             range_idxs=self._all_range_idxs,
                             legal_action_masks=mask)

            return nnf.softmax(pred, dim=1).cpu().numpy()
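
Example #1 queries the network once per private hand: the same public observation is repeated RANGE_SIZE times, a single legal-action mask is expanded over all hands, and the raw outputs are turned into per-hand action probabilities with a softmax. Below is a minimal standalone sketch of the masking-plus-softmax idea; the function name and shapes are illustrative assumptions, since the real `_net` receives the mask as an input and presumably applies it internally.

import torch
import torch.nn.functional as nnf

def softmax_over_legal(logits, legal_actions_list):
    # logits: (n_hands, n_actions) raw scores, one row per private hand
    n_actions = logits.shape[1]
    mask = torch.zeros(n_actions, dtype=torch.bool)
    mask[legal_actions_list] = True
    # illegal actions get -inf so softmax assigns them exactly zero probability
    masked = logits.masked_fill(~mask.unsqueeze(0), float("-inf"))
    return nnf.softmax(masked, dim=1)

probs = softmax_over_legal(torch.randn(6, 4), legal_actions_list=[0, 2, 3])
assert torch.allclose(probs.sum(dim=1), torch.ones(6))
assert (probs[:, 1] == 0).all()  # action 1 is illegal here
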
Example #2
    def _get_a_probs_of_hands(self, pub_obs, range_idxs_tensor, legal_actions_list):
        with torch.no_grad():
            n_hands = range_idxs_tensor.size(0)

            if self._adv_net is None:  # at iteration 0
                uniform_even_legal = torch.zeros((self._env_bldr.N_ACTIONS,), dtype=torch.float32, device=self._device)
                uniform_even_legal[legal_actions_list] = 1.0 / len(legal_actions_list)  # always >0
                uniform_even_legal = uniform_even_legal.unsqueeze(0).expand(n_hands, self._env_bldr.N_ACTIONS)
                return uniform_even_legal.cpu().numpy()

            else:
                legal_action_masks = rl_util.get_legal_action_mask_torch(n_actions=self._env_bldr.N_ACTIONS,
                                                                         legal_actions_list=legal_actions_list,
                                                                         device=self._device, dtype=torch.float32)
                legal_action_masks = legal_action_masks.unsqueeze(0).expand(n_hands, -1)

                advantages = self._adv_net(pub_obses=[pub_obs] * n_hands,
                                           range_idxs=range_idxs_tensor,
                                           legal_action_masks=legal_action_masks)

                # """"""""""""""""""""
                relu_advantages = F.relu(advantages, inplace=False)  # because only the sum of *positive* regrets matters in CFR
                sum_pos_adv_expanded = relu_advantages.sum(1).unsqueeze(-1).expand_as(relu_advantages)

                # """"""""""""""""""""
                # In case all negative
                # """"""""""""""""""""
                best_legal_deterministic = torch.zeros((n_hands, self._env_bldr.N_ACTIONS,), dtype=torch.float32,
                                                       device=self._device)
                bests = torch.argmax(
                    torch.where(legal_action_masks.bool(), advantages, torch.full_like(advantages, fill_value=-10e20)),
                    dim=1
                )

                _batch_arranged = torch.arange(n_hands, device=self._device, dtype=torch.long)
                best_legal_deterministic[_batch_arranged, bests] = 1

                # """"""""""""""""""""
                # Strategy
                # """"""""""""""""""""
                strategy = torch.where(
                    sum_pos_adv_expanded > 0,
                    relu_advantages / sum_pos_adv_expanded,
                    best_legal_deterministic,
                )

                return strategy.cpu().numpy()
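
The `else` branch above is regret matching: the predicted advantages are clipped at zero and normalized, because only the sum of positive regrets matters in CFR; if no legal action has positive advantage, all probability mass goes on the best legal action. A small self-contained NumPy sketch of that rule (illustrative names, a single hand instead of a batch):

import numpy as np

def regret_matching(advantages, legal_mask):
    # advantages, legal_mask: shape (n_actions,); legal_mask is 1.0 for legal actions
    pos = np.maximum(advantages, 0.0) * legal_mask
    s = pos.sum()
    if s > 0:
        return pos / s  # probability proportional to positive regret
    # all regrets non-positive: play the best legal action deterministically
    best = np.argmax(np.where(legal_mask > 0, advantages, -np.inf))
    strat = np.zeros_like(advantages)
    strat[best] = 1.0
    return strat

print(regret_matching(np.array([-0.2, 0.5, 0.3, 0.1]), np.array([1.0, 1.0, 0.0, 1.0])))
# -> [0.         0.83333333 0.         0.16666667]
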
Example #3
    def _traverser_act(self, start_state_dict, traverser, trav_depth,
                       plyrs_range_idxs, iteration_strats, sample_reach,
                       cfr_iter):
        """
        Last state values are the average, not the sum, of all samples of that state, since each sampled
        action a contributes v~(I|a) * p(a) * |A(I)|. Since we sample multiple actions at each traverser
        node, we have to average over their returns: v~(I) = (1/N) * Sum_a (v~(I|a) * p(a) * |A(I)|).
        """
        self.total_node_count_traversed += 1
        self._env_wrapper.load_state_dict(start_state_dict)
        legal_actions_list = self._env_wrapper.env.get_legal_actions()
        legal_action_mask = rl_util.get_legal_action_mask_torch(
            n_actions=self._env_bldr.N_ACTIONS,
            legal_actions_list=legal_actions_list,
            device=self._adv_buffers[traverser].device,
            dtype=torch.float32)
        current_pub_obs = self._env_wrapper.get_current_obs()
        round_t = self._env_wrapper.env.current_round
        traverser_range_idx = plyrs_range_idxs[traverser]

        # """""""""""""""""""""""""
        # Strategy
        # """""""""""""""""""""""""
        strat_i = iteration_strats[traverser].get_a_probs(
            pub_obses=[current_pub_obs],
            range_idxs=[traverser_range_idx],
            legal_actions_lists=[legal_actions_list],
            to_np=False,
        )[0]

        # """""""""""""""""""""""""
        # Sample action
        # """""""""""""""""""""""""
        n_legal_actions = len(legal_actions_list)
        sample_strat = (1 - self._eps) * strat_i + self._eps * (
            legal_action_mask.cpu() / n_legal_actions)
        a = np.random.choice(self._actions_arranged, p=sample_strat.numpy())

        # Step
        pub_obs_tp1, rew_for_all, done, _info = self._env_wrapper.step(a)
        round_tp1 = self._env_wrapper.env.current_round

        # """""""""""""""""""""""""
        # Utility
        # """""""""""""""""""""""""
        utility = self._get_utility(
            traverser=traverser,
            acting_player=traverser,
            u_bootstrap=rew_for_all[traverser]
            if done else self._recursive_traversal(
                start_state_dict=self._env_wrapper.state_dict(),
                traverser=traverser,
                trav_depth=trav_depth + 1,
                plyrs_range_idxs=plyrs_range_idxs,
                iteration_strats=iteration_strats,
                cfr_iter=cfr_iter,
                sample_reach=sample_reach * sample_strat[a] *
                n_legal_actions),  # recursion
            pub_obs=current_pub_obs,
            range_idx_trav=plyrs_range_idxs[traverser],
            range_idx_opp=plyrs_range_idxs[1 - traverser],
            legal_actions_list=legal_actions_list,
            legal_action_mask=legal_action_mask,
            a=a,
            round_t=round_t,
            round_tp1=round_tp1,
            sample_strat=sample_strat,
            pub_obs_tp1=pub_obs_tp1,
        )

        # Regret
        aprx_imm_reg = torch.full(size=(self._env_bldr.N_ACTIONS, ),
                                  fill_value=-(utility * strat_i).sum(),
                                  dtype=torch.float32,
                                  device=self._adv_buffers[traverser].device)
        aprx_imm_reg += utility
        aprx_imm_reg *= legal_action_mask

        # add current datapoint to ADVBuf
        self._adv_buffers[traverser].add(
            pub_obs=current_pub_obs,
            range_idx=traverser_range_idx,
            legal_action_mask=legal_action_mask,
            adv=aprx_imm_reg,
            iteration=(cfr_iter + 1) / sample_reach,
        )

        # if trav_depth == 0 and traverser == 0:
        #     self._reg_buf[traverser_range_idx].append(aprx_imm_reg.clone().cpu().numpy())
        return (utility * strat_i).sum()
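
The regret block above computes, for every legal action, the sampled utility of that action minus the expected utility under the current strategy, i.e. reg(a) = mask(a) * (u(a) - sum_a' strat(a') * u(a')); the method's return value is exactly that expectation. A tiny torch sketch of the same arithmetic (hypothetical standalone function, toy numbers):

import torch

def immediate_regret(utility, strat, legal_mask):
    # utility, strat, legal_mask: shape (n_actions,)
    ev = (utility * strat).sum()        # node value under the current strategy
    return (utility - ev) * legal_mask  # per-action regret, zeroed for illegal actions

u = torch.tensor([1.0, -0.5, 0.0])
s = torch.tensor([0.5, 0.5, 0.0])
m = torch.tensor([1.0, 1.0, 0.0])
print(immediate_regret(u, s, m))  # regrets: 0.75, -0.75, 0 (illegal action zeroed)
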
Example #4
    def _any_non_traverser_act(self, start_state_dict, traverser,
                               plyrs_range_idxs, trav_depth, iteration_strats,
                               sample_reach, cfr_iter):
        self.total_node_count_traversed += 1
        self._env_wrapper.load_state_dict(start_state_dict)
        p_id_acting = self._env_wrapper.env.current_player.seat_id

        current_pub_obs = self._env_wrapper.get_current_obs()
        range_idx = plyrs_range_idxs[p_id_acting]
        legal_actions_list = self._env_wrapper.env.get_legal_actions()
        legal_action_mask = rl_util.get_legal_action_mask_torch(
            n_actions=self._env_bldr.N_ACTIONS,
            legal_actions_list=legal_actions_list,
            device=self._adv_buffers[traverser].device,
            dtype=torch.float32)
        round_t = self._env_wrapper.env.current_round

        # """""""""""""""""""""""""
        # The player's strategy
        # """""""""""""""""""""""""
        strat_opp = iteration_strats[p_id_acting].get_a_probs(
            pub_obses=[current_pub_obs],
            range_idxs=[range_idx],
            legal_actions_lists=[legal_actions_list],
            to_np=False)[0]

        # """""""""""""""""""""""""
        # Adds to opponent's
        # average buffer if
        # applicable
        # """""""""""""""""""""""""
        if self._avrg_buffers is not None:
            self._avrg_buffers[p_id_acting].add(
                pub_obs=current_pub_obs,
                range_idx=range_idx,
                legal_actions_list=legal_actions_list,
                a_probs=strat_opp.to(
                    self._avrg_buffers[p_id_acting].device).squeeze(),
                iteration=cfr_iter + 1)

        # """""""""""""""""""""""""
        # Execute action from strat
        # """""""""""""""""""""""""
        a = torch.multinomial(strat_opp.cpu(), num_samples=1).item()
        pub_obs_tp1, rew_for_all, done, _info = self._env_wrapper.step(a)
        round_tp1 = self._env_wrapper.env.current_round

        # """""""""""""""""""""""""
        # Utility
        # """""""""""""""""""""""""
        utility = self._get_utility(
            traverser=traverser,
            acting_player=1 - traverser,
            u_bootstrap=rew_for_all[traverser]
            if done else self._recursive_traversal(
                start_state_dict=self._env_wrapper.state_dict(),
                traverser=traverser,
                trav_depth=trav_depth,
                plyrs_range_idxs=plyrs_range_idxs,
                iteration_strats=iteration_strats,
                cfr_iter=cfr_iter,
                sample_reach=sample_reach,
            ),
            pub_obs=current_pub_obs,
            range_idx_trav=plyrs_range_idxs[traverser],
            range_idx_opp=plyrs_range_idxs[1 - traverser],
            legal_actions_list=legal_actions_list,
            legal_action_mask=legal_action_mask,
            a=a,
            round_t=round_t,
            round_tp1=round_tp1,
            sample_strat=strat_opp,
            pub_obs_tp1=pub_obs_tp1,
        )

        return (utility * strat_opp).sum()
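
At non-traverser nodes the acting player's mixed strategy is stored in the average-strategy buffer and one action is sampled from it with `torch.multinomial`; the value returned to the parent is the expectation of the per-action utilities under that strategy. A toy illustration of those two steps (made-up numbers):

import torch

utility = torch.tensor([2.0, -1.0, 0.0])   # per-action utility for the traverser
strat_opp = torch.tensor([0.2, 0.5, 0.3])  # acting player's action probabilities

a = torch.multinomial(strat_opp, num_samples=1).item()  # sampled action index
node_value = (utility * strat_opp).sum()                # expectation over the strategy
print(a, round(node_value.item(), 3))  # e.g. 1 -0.1
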
Example #5
    def _traverser_act(self, start_state_dict, traverser, trav_depth,
                       plyrs_range_idxs, iteration_strats, cfr_iter):
        """
        Last state values are the average, not the sum, of all samples of that state, since each sampled
        action a contributes v~(I|a) * p(a) * |A(I)|. Since we sample multiple actions at each traverser
        node, we have to average over their returns: v~(I) = (1/N) * Sum_a (v~(I|a) * p(a) * |A(I)|).
        """
        self._env_wrapper.load_state_dict(start_state_dict)
        legal_actions_list = self._env_wrapper.env.get_legal_actions()
        legal_action_mask = rl_util.get_legal_action_mask_torch(
            n_actions=self._env_bldr.N_ACTIONS,
            legal_actions_list=legal_actions_list,
            device=self._adv_buffers[traverser].device,
            dtype=torch.float32)
        current_pub_obs = self._env_wrapper.get_current_obs()

        traverser_range_idx = plyrs_range_idxs[traverser]

        # """""""""""""""""""""""""
        # Sample actions
        # """""""""""""""""""""""""
        n_legal_actions = len(legal_actions_list)
        n_actions_to_smpl = self._get_n_a_to_sample(
            trav_depth=trav_depth, n_legal_actions=n_legal_actions)
        _idxs = np.arange(n_legal_actions)
        np.random.shuffle(_idxs)
        _idxs = _idxs[:n_actions_to_smpl]
        actions = [legal_actions_list[i] for i in _idxs]

        strat_i = iteration_strats[traverser].get_a_probs(
            pub_obses=[current_pub_obs],
            range_idxs=[traverser_range_idx],
            legal_actions_lists=[legal_actions_list],
            to_np=True)[0]

        cumm_rew = 0.0
        aprx_imm_reg = torch.zeros(size=(self._env_bldr.N_ACTIONS, ),
                                   dtype=torch.float32,
                                   device=self._adv_buffers[traverser].device)

        # """""""""""""""""""""""""
        # Create next states
        # """""""""""""""""""""""""
        for _c, a in enumerate(actions):
            strat_i_a = strat_i[a]

            # Re-initialize the environment to the current state (with the remaining deck reshuffled) after each action branch
            if _c > 0:
                self._env_wrapper.load_state_dict(start_state_dict)
                self._env_wrapper.env.reshuffle_remaining_deck()

            _obs, _rew_for_all, _done, _info = self._env_wrapper.step(a)
            _cfv_traverser_a = _rew_for_all[traverser]

            # Recursion over sub-trees
            if not _done:
                _cfv_traverser_a += self._recursive_traversal(
                    start_state_dict=self._env_wrapper.state_dict(),
                    traverser=traverser,
                    trav_depth=trav_depth + 1,
                    plyrs_range_idxs=plyrs_range_idxs,
                    iteration_strats=iteration_strats,
                    cfr_iter=cfr_iter)

            # accumulate reward for backward-pass on tree
            cumm_rew += strat_i_a * _cfv_traverser_a

            # """"""""""""""""""""""""
            # Compute the approximate
            # immediate regret
            # """"""""""""""""""""""""
            aprx_imm_reg -= strat_i_a * _cfv_traverser_a  # This is for all actions =/= a

            # add regret for a and undo the change made to a's regret in the line above.
            aprx_imm_reg[a] += _cfv_traverser_a

        aprx_imm_reg *= legal_action_mask / n_actions_to_smpl  # mean over all legal actions sampled

        # add current datapoint to ADVBuf
        self._adv_buffers[traverser].add(
            pub_obs=current_pub_obs,
            range_idx=traverser_range_idx,
            legal_action_mask=legal_action_mask,
            adv=aprx_imm_reg,
            iteration=cfr_iter + 1,
        )
        # increase total entries generated counter to control break point
        self._generated_entries_adv += 1

        # *n_legal_actions    because we multiply by strat.
        # /n_actions_to_smpl  because we summed that many returns and want their mean
        return cumm_rew * n_legal_actions / n_actions_to_smpl
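
The return value above, cumm_rew * n_legal_actions / n_actions_to_smpl, is an unbiased estimate of the node value sum_a p(a) * v(I|a) when N of the |A(I)| legal actions are drawn uniformly without replacement: each action is included with probability N / |A(I)|, so scaling by |A(I)| / N cancels the sampling bias. A quick NumPy check of that claim with made-up numbers:

import numpy as np

rng = np.random.default_rng(0)
p = np.array([0.5, 0.3, 0.2])    # strategy over 3 legal actions
v = np.array([1.0, -2.0, 4.0])   # per-action values
exact = float(p.dot(v))          # 0.7

n_legal, n_smpl = 3, 2
est = []
for _ in range(100_000):
    idx = rng.choice(n_legal, size=n_smpl, replace=False)  # uniform subset of actions
    est.append(p[idx].dot(v[idx]) * n_legal / n_smpl)
print(exact, np.mean(est))  # both close to 0.7
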
Example #6
    def _traverser_act(self, start_state_dict, traverser, trav_depth,
                       sample_reach, plyrs_range_idxs, iteration_strats,
                       cfr_iter):
        """
        Last state values are the average, not the sum, of all samples of that state, since each sampled
        action a contributes v~(I|a) * p(a) * |A(I)|. Since we sample multiple actions at each traverser
        node, we have to average over their returns: v~(I) = (1/N) * Sum_a (v~(I|a) * p(a) * |A(I)|).
        """
        self.total_node_count_traversed += 1
        self._env_wrapper.load_state_dict(start_state_dict)
        legal_actions_list = self._env_wrapper.env.get_legal_actions()
        legal_action_mask = rl_util.get_legal_action_mask_torch(
            n_actions=self._env_bldr.N_ACTIONS,
            legal_actions_list=legal_actions_list,
            device=self._adv_buffers[traverser].device,
            dtype=torch.float32)
        current_pub_obs = self._env_wrapper.get_current_obs()

        traverser_range_idx = plyrs_range_idxs[traverser]

        strat_i = iteration_strats[traverser].get_a_probs(
            pub_obses=[current_pub_obs],
            range_idxs=[traverser_range_idx],
            legal_actions_lists=[legal_actions_list],
            to_np=True)[0]

        n_legal_actions = len(legal_actions_list)
        n_actions_to_smpl = self._get_n_a_to_sample(
            trav_depth=trav_depth, n_legal_actions=n_legal_actions)
        sample_strat = (1 - self._eps) * strat_i + self._eps * (
            legal_action_mask.cpu().numpy() / n_legal_actions)

        # """""""""""""""""""""""""
        # Create next states
        # """""""""""""""""""""""""
        cumm_rew = 0.0
        aprx_imm_reg = torch.zeros(size=(self._env_bldr.N_ACTIONS, ),
                                   dtype=torch.float32,
                                   device=self._adv_buffers[traverser].device)
        _1_over_sample_reach_sum = 0
        _sample_reach_sum = 0
        for _c, a in enumerate(
                np.random.choice(self._actions_arranged,
                                 p=sample_strat,
                                 size=n_actions_to_smpl)):
            # Re-initialize the environment to the current state (with the remaining deck reshuffled) after each action branch
            if _c > 0:
                self._env_wrapper.load_state_dict(start_state_dict)
                self._env_wrapper.env.reshuffle_remaining_deck()

            _obs, _rew_for_all, _done, _info = self._env_wrapper.step(a)
            _u_a = _rew_for_all[traverser]

            _sample_reach = sample_reach * sample_strat[a] * n_legal_actions
            _sample_reach_sum += _sample_reach
            _1_over_sample_reach_sum += 1 / _sample_reach
            # Recursion over sub-trees
            if not _done:
                _u_a += self._recursive_traversal(
                    start_state_dict=self._env_wrapper.state_dict(),
                    traverser=traverser,
                    trav_depth=trav_depth + 1,
                    plyrs_range_idxs=plyrs_range_idxs,
                    iteration_strats=iteration_strats,
                    sample_reach=_sample_reach,
                    cfr_iter=cfr_iter)

            # accumulate reward for backward-pass on tree
            cumm_rew += strat_i[a] / sample_strat[a] * _u_a

            # """"""""""""""""""""""""
            # Compute the approximate
            # immediate regret
            # """"""""""""""""""""""""

            _aprx_imm_reg = torch.full_like(
                aprx_imm_reg,
                fill_value=-strat_i[a] * _u_a)  # This is for all actions != a
            _aprx_imm_reg[
                a] += _u_a  # add regret for a and undo the change made to a's regret in the line above.
            aprx_imm_reg += _aprx_imm_reg / _sample_reach

        aprx_imm_reg *= legal_action_mask / _1_over_sample_reach_sum  # mean over all legal actions sampled

        # add current datapoint to ADVBuf
        self._adv_buffers[traverser].add(
            pub_obs=current_pub_obs,
            range_idx=traverser_range_idx,
            legal_action_mask=legal_action_mask,
            adv=aprx_imm_reg,
            iteration=(cfr_iter + 1) / _sample_reach_sum,
        )

        # /n_actions_to_smpl  because we summed that many returns and want their mean
        return cumm_rew / n_actions_to_smpl
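
This variant samples actions from the exploration-mixed distribution sample_strat but keeps the estimates unbiased with importance weights: returns are scaled by strat_i[a] / sample_strat[a] and regret contributions by 1 / _sample_reach. A minimal NumPy sketch of the return-side correction (toy numbers, single node):

import numpy as np

rng = np.random.default_rng(1)
strat = np.array([0.7, 0.2, 0.1])             # play distribution
eps = 0.5
sample_strat = (1 - eps) * strat + eps / 3.0  # epsilon-mixed sampling distribution
v = np.array([1.0, 0.0, -3.0])
exact = float(strat.dot(v))                   # 0.4

est = []
for _ in range(100_000):
    a = rng.choice(3, p=sample_strat)
    est.append(strat[a] / sample_strat[a] * v[a])  # importance-weighted return
print(exact, np.mean(est))  # both close to 0.4
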
Example #7
    def _get_mask(self, legal_actions_list):
        return rl_util.get_legal_action_mask_torch(
            n_actions=self._env_bldr.N_ACTIONS,
            legal_actions_list=legal_actions_list,
            device=self.device,
            dtype=torch.float32)
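
`_get_mask` is a thin wrapper around `rl_util.get_legal_action_mask_torch`. The library function itself isn't shown here; a plausible reimplementation, written as an assumption rather than the actual library code, is simply an indicator vector over the action space:

import torch

def get_legal_action_mask_torch(n_actions, legal_actions_list, device=None, dtype=torch.float32):
    # 1 at every legal action index, 0 everywhere else
    mask = torch.zeros(n_actions, dtype=dtype, device=device)
    mask[legal_actions_list] = 1
    return mask

print(get_legal_action_mask_torch(5, [0, 2, 4]))  # tensor([1., 0., 1., 0., 1.])
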
Example #8
    def _any_non_traverser_act(self, start_state_dict, traverser,
                               plyrs_range_idxs, trav_depth, iteration_strats,
                               sample_reach, cfr_iter):
        self.total_node_count_traversed += 1
        self._env_wrapper.load_state_dict(start_state_dict)
        p_id_acting = self._env_wrapper.env.current_player.seat_id

        current_pub_obs = self._env_wrapper.get_current_obs()
        range_idx = plyrs_range_idxs[p_id_acting]
        legal_actions_list = self._env_wrapper.env.get_legal_actions()
        legal_action_mask = rl_util.get_legal_action_mask_torch(
            n_actions=self._env_bldr.N_ACTIONS,
            legal_actions_list=legal_actions_list,
            device=self._adv_buffers[traverser].device,
            dtype=torch.float32)
        # """""""""""""""""""""""""
        # The player's strategy
        # """""""""""""""""""""""""
        strat_opp = iteration_strats[p_id_acting].get_a_probs(
            pub_obses=[current_pub_obs],
            range_idxs=[range_idx],
            legal_actions_lists=[legal_actions_list],
            to_np=False)[0]

        # """""""""""""""""""""""""
        # Execute action from strat
        # """""""""""""""""""""""""
        a = torch.multinomial(strat_opp.cpu(), num_samples=1).item()
        pub_obs_tp1, rew_for_all, done, _info = self._env_wrapper.step(a)
        legal_action_mask_tp1 = rl_util.get_legal_action_mask_torch(
            n_actions=self._env_bldr.N_ACTIONS,
            legal_actions_list=self._env_wrapper.env.get_legal_actions(),
            device=self._adv_buffers[traverser].device,
            dtype=torch.float32)

        # """""""""""""""""""""""""
        # Adds to opponent's
        # average buffer if
        # applicable
        # """""""""""""""""""""""""
        if self._avrg_buffers is not None:
            self._avrg_buffers[p_id_acting].add(
                pub_obs=current_pub_obs,
                range_idx=range_idx,
                legal_actions_list=legal_actions_list,
                a_probs=strat_opp.to(
                    self._avrg_buffers[p_id_acting].device).squeeze(),
                iteration=(cfr_iter + 1) / sample_reach)

        # """""""""""""""""""""""""
        # Recursion
        # """""""""""""""""""""""""
        if done:
            strat_tp1 = torch.zeros_like(strat_opp)
            self.total_node_count_traversed += 1
        else:
            u_bootstrap, strat_tp1 = self._recursive_traversal(
                start_state_dict=self._env_wrapper.state_dict(),
                traverser=traverser,
                trav_depth=trav_depth + 1,
                plyrs_range_idxs=plyrs_range_idxs,
                iteration_strats=iteration_strats,
                cfr_iter=cfr_iter,
                sample_reach=sample_reach)

        # """""""""""""""""""""""""
        # Utility
        # """""""""""""""""""""""""
        utility = self._get_utility(
            traverser=traverser,
            u_bootstrap=rew_for_all[traverser] if done else u_bootstrap,
            pub_obs=current_pub_obs,
            range_idx_crazy_embedded=_crazy_embed(
                plyrs_range_idxs=plyrs_range_idxs),
            legal_actions_list=legal_actions_list,
            legal_action_mask=legal_action_mask,
            a=a,
            sample_strat=strat_opp,
        )

        # add datapoint to baseline net
        self._baseline_buf.add(
            pub_obs=current_pub_obs,
            range_idx_crazy_embedded=_crazy_embed(
                plyrs_range_idxs=plyrs_range_idxs),
            legal_action_mask=legal_action_mask,
            r=rew_for_all[0],  # seat 0's reward; seat 1's is its negation (zero-sum)
            a=a,
            done=done,
            pub_obs_tp1=pub_obs_tp1,
            strat_tp1=strat_tp1,
            legal_action_mask_tp1=legal_action_mask_tp1,
        )

        return (utility * strat_opp).sum(), strat_opp
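
The baseline buffer stores only seat 0's reward (r=rew_for_all[0]) because the game is two-player zero-sum, so seat 1's reward is just the negation and never needs to be stored. A trivial sketch of that mirroring (hypothetical helper):

def rewards_from_seat0(r_seat0):
    # two-player zero-sum: the seats' rewards always sum to zero
    return {0: r_seat0, 1: -r_seat0}

print(rewards_from_seat0(3.5))  # {0: 3.5, 1: -3.5}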