def get_a_probs_for_each_hand(self, pub_obs, legal_actions_list):
    with torch.no_grad():
        mask = rl_util.get_legal_action_mask_torch(n_actions=self._env_bldr.N_ACTIONS,
                                                   legal_actions_list=legal_actions_list,
                                                   device=self.device, dtype=torch.uint8)
        mask = mask.unsqueeze(0).expand(self._env_bldr.rules.RANGE_SIZE, -1)

        pred = self._net(pub_obses=[pub_obs] * self._env_bldr.rules.RANGE_SIZE,
                         range_idxs=self._all_range_idxs,
                         legal_action_masks=mask)

        return nnf.softmax(pred, dim=1).cpu().numpy()
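# --- Illustrative sketch (not part of the class above) ---
# A minimal, self-contained demo of the masked-softmax pattern behind
# get_a_probs_for_each_hand, with hypothetical sizes RANGE_SIZE=3 and
# N_ACTIONS=4. Whether self._net masks illegal logits internally depends on
# the network; here the mask is applied explicitly before the softmax so that
# illegal actions receive exactly zero probability.
import torch
from torch.nn import functional as nnf

RANGE_SIZE, N_ACTIONS = 3, 4
logits = torch.randn(RANGE_SIZE, N_ACTIONS)          # stand-in for self._net(...)
mask = torch.tensor([1, 0, 1, 1], dtype=torch.bool)  # actions 0, 2, 3 are legal
mask = mask.unsqueeze(0).expand(RANGE_SIZE, -1)      # same mask for every hand in the range
masked_logits = logits.masked_fill(~mask, float("-inf"))
probs = nnf.softmax(masked_logits, dim=1)            # rows sum to 1; illegal columns are 0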
def _get_a_probs_of_hands(self, pub_obs, range_idxs_tensor, legal_actions_list):
    with torch.no_grad():
        n_hands = range_idxs_tensor.size(0)

        if self._adv_net is None:  # at iteration 0, no advantage net exists yet -> uniform over legal actions
            uniform_even_legal = torch.zeros((self._env_bldr.N_ACTIONS,), dtype=torch.float32,
                                             device=self._device)
            uniform_even_legal[legal_actions_list] = 1.0 / len(legal_actions_list)  # always >0
            uniform_even_legal = uniform_even_legal.unsqueeze(0).expand(n_hands, self._env_bldr.N_ACTIONS)
            return uniform_even_legal.cpu().numpy()

        legal_action_masks = rl_util.get_legal_action_mask_torch(n_actions=self._env_bldr.N_ACTIONS,
                                                                 legal_actions_list=legal_actions_list,
                                                                 device=self._device, dtype=torch.float32)
        legal_action_masks = legal_action_masks.unsqueeze(0).expand(n_hands, -1)

        advantages = self._adv_net(pub_obses=[pub_obs] * n_hands,
                                   range_idxs=range_idxs_tensor,
                                   legal_action_masks=legal_action_masks)

        # Only the sum of *positive* regrets matters for regret matching in CFR.
        relu_advantages = F.relu(advantages, inplace=False)
        sum_pos_adv_expanded = relu_advantages.sum(1).unsqueeze(-1).expand_as(relu_advantages)

        # Fallback in case all advantages are non-positive: play the best legal action deterministically.
        best_legal_deterministic = torch.zeros((n_hands, self._env_bldr.N_ACTIONS,), dtype=torch.float32,
                                               device=self._device)
        bests = torch.argmax(
            torch.where(legal_action_masks.bool(), advantages, torch.full_like(advantages, fill_value=-10e20)),
            dim=1
        )
        _batch_arranged = torch.arange(n_hands, device=self._device, dtype=torch.long)
        best_legal_deterministic[_batch_arranged, bests] = 1

        # Regret-matching strategy: proportional to positive advantages where any exist.
        strategy = torch.where(
            sum_pos_adv_expanded > 0,
            relu_advantages / sum_pos_adv_expanded,
            best_legal_deterministic,
        )

        return strategy.cpu().numpy()
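# --- Illustrative numeric check (hypothetical values) ---
# The regret-matching rule implemented in _get_a_probs_of_hands: play
# proportionally to the clipped-positive advantages; if no advantage is
# positive, fall back to the best legal action deterministically. The
# 0/0 -> nan in the unselected branch of torch.where is discarded.
import torch
from torch.nn import functional as F

adv = torch.tensor([[2.0, -1.0, 1.0],     # positive mass exists -> proportional
                    [-3.0, -1.0, -2.0]])  # all negative -> deterministic argmax
mask = torch.ones_like(adv)               # all three actions legal for both hands
pos = F.relu(adv)
denom = pos.sum(1, keepdim=True).expand_as(pos)
best = torch.zeros_like(adv)
bests = torch.argmax(torch.where(mask.bool(), adv, torch.full_like(adv, -10e20)), dim=1)
best[torch.arange(adv.size(0)), bests] = 1.0
strategy = torch.where(denom > 0, pos / denom, best)
# strategy[0] == [2/3, 0, 1/3]; strategy[1] == [0, 1, 0] (action 1 is least negative)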
def _traverser_act(self, start_state_dict, traverser, trav_depth, plyrs_range_idxs, iteration_strats,
                   sample_reach, cfr_iter):
    """
    Last state values are the average, not the sum, of all samples of that state, since each sampled
    return is weighted by p(a) * |A(I)|. Because we sample multiple actions at each traverser node,
    we average over their returns:

        v~(I) = (1 / N) * Sum_{a sampled} v~(I|a) * p(a) * |A(I)|
    """
    self.total_node_count_traversed += 1

    self._env_wrapper.load_state_dict(start_state_dict)
    legal_actions_list = self._env_wrapper.env.get_legal_actions()
    legal_action_mask = rl_util.get_legal_action_mask_torch(n_actions=self._env_bldr.N_ACTIONS,
                                                            legal_actions_list=legal_actions_list,
                                                            device=self._adv_buffers[traverser].device,
                                                            dtype=torch.float32)
    current_pub_obs = self._env_wrapper.get_current_obs()
    round_t = self._env_wrapper.env.current_round
    traverser_range_idx = plyrs_range_idxs[traverser]

    # Strategy
    strat_i = iteration_strats[traverser].get_a_probs(
        pub_obses=[current_pub_obs],
        range_idxs=[traverser_range_idx],
        legal_actions_lists=[legal_actions_list],
        to_np=False,
    )[0]

    # Sample action from an epsilon mix of the current strategy and uniform-over-legal.
    n_legal_actions = len(legal_actions_list)
    sample_strat = (1 - self._eps) * strat_i + self._eps * (legal_action_mask.cpu() / n_legal_actions)
    a = np.random.choice(self._actions_arranged, p=sample_strat.numpy())

    # Step
    pub_obs_tp1, rew_for_all, done, _info = self._env_wrapper.step(a)
    round_tp1 = self._env_wrapper.env.current_round

    # Utility
    utility = self._get_utility(
        traverser=traverser,
        acting_player=traverser,
        u_bootstrap=rew_for_all[traverser] if done else self._recursive_traversal(
            start_state_dict=self._env_wrapper.state_dict(),
            traverser=traverser,
            trav_depth=trav_depth + 1,
            plyrs_range_idxs=plyrs_range_idxs,
            iteration_strats=iteration_strats,
            cfr_iter=cfr_iter,
            sample_reach=sample_reach * sample_strat[a] * n_legal_actions),  # recursion
        pub_obs=current_pub_obs,
        range_idx_trav=plyrs_range_idxs[traverser],
        range_idx_opp=plyrs_range_idxs[1 - traverser],
        legal_actions_list=legal_actions_list,
        legal_action_mask=legal_action_mask,
        a=a,
        round_t=round_t,
        round_tp1=round_tp1,
        sample_strat=sample_strat,
        pub_obs_tp1=pub_obs_tp1,
    )

    # Regret: r(a) = v(I, a) - Sum_a' sigma(a') * v(I, a'), masked to legal actions.
    aprx_imm_reg = torch.full(size=(self._env_bldr.N_ACTIONS,),
                              fill_value=-(utility * strat_i).sum(),
                              dtype=torch.float32,
                              device=self._adv_buffers[traverser].device)
    aprx_imm_reg += utility
    aprx_imm_reg *= legal_action_mask

    # Add the current datapoint to the advantage buffer.
    self._adv_buffers[traverser].add(pub_obs=current_pub_obs,
                                     range_idx=traverser_range_idx,
                                     legal_action_mask=legal_action_mask,
                                     adv=aprx_imm_reg,
                                     iteration=(cfr_iter + 1) / sample_reach,
                                     )

    # if trav_depth == 0 and traverser == 0:
    #     self._reg_buf[traverser_range_idx].append(aprx_imm_reg.clone().cpu().numpy())

    return (utility * strat_i).sum()
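# --- Illustrative arithmetic (hypothetical numbers) ---
# The regret update at the end of _traverser_act computes, per legal action a,
#     r(a) = v(I, a) - Sum_a' sigma(a') * v(I, a')
# i.e. each action's counterfactual value minus the node's expected value.
# The torch.full / += / *= sequence above is equivalent to:
import torch

utility = torch.tensor([1.0, 0.0, -2.0])   # per-action values v(I, a) from _get_utility
strat_i = torch.tensor([0.5, 0.25, 0.25])  # current strategy sigma at this infoset
mask = torch.ones(3)                       # all actions legal
node_ev = (utility * strat_i).sum()        # 0.5 + 0.0 - 0.5 = 0.0
aprx_imm_reg = (utility - node_ev) * mask  # [1.0, 0.0, -2.0]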
def _any_non_traverser_act(self, start_state_dict, traverser, plyrs_range_idxs, trav_depth, iteration_strats,
                           sample_reach, cfr_iter):
    self.total_node_count_traversed += 1
    self._env_wrapper.load_state_dict(start_state_dict)

    p_id_acting = self._env_wrapper.env.current_player.seat_id

    current_pub_obs = self._env_wrapper.get_current_obs()
    range_idx = plyrs_range_idxs[p_id_acting]
    legal_actions_list = self._env_wrapper.env.get_legal_actions()
    legal_action_mask = rl_util.get_legal_action_mask_torch(n_actions=self._env_bldr.N_ACTIONS,
                                                            legal_actions_list=legal_actions_list,
                                                            device=self._adv_buffers[traverser].device,
                                                            dtype=torch.float32)
    round_t = self._env_wrapper.env.current_round

    # The acting player's strategy
    strat_opp = iteration_strats[p_id_acting].get_a_probs(pub_obses=[current_pub_obs],
                                                          range_idxs=[range_idx],
                                                          legal_actions_lists=[legal_actions_list],
                                                          to_np=False)[0]

    # Add to the opponent's average-strategy buffer, if applicable.
    if self._avrg_buffers is not None:
        self._avrg_buffers[p_id_acting].add(pub_obs=current_pub_obs,
                                            range_idx=range_idx,
                                            legal_actions_list=legal_actions_list,
                                            a_probs=strat_opp.to(self._avrg_buffers[p_id_acting].device).squeeze(),
                                            iteration=cfr_iter + 1)

    # Sample and execute an action from the strategy (on-policy, so no importance correction is needed).
    a = torch.multinomial(strat_opp.cpu(), num_samples=1).item()
    pub_obs_tp1, rew_for_all, done, _info = self._env_wrapper.step(a)
    round_tp1 = self._env_wrapper.env.current_round

    # Utility
    utility = self._get_utility(
        traverser=traverser,
        acting_player=1 - traverser,
        u_bootstrap=rew_for_all[traverser] if done else self._recursive_traversal(
            start_state_dict=self._env_wrapper.state_dict(),
            traverser=traverser,
            trav_depth=trav_depth,
            plyrs_range_idxs=plyrs_range_idxs,
            iteration_strats=iteration_strats,
            cfr_iter=cfr_iter,
            sample_reach=sample_reach,
        ),
        pub_obs=current_pub_obs,
        range_idx_trav=plyrs_range_idxs[traverser],
        range_idx_opp=plyrs_range_idxs[1 - traverser],
        legal_actions_list=legal_actions_list,
        legal_action_mask=legal_action_mask,
        a=a,
        round_t=round_t,
        round_tp1=round_tp1,
        sample_strat=strat_opp,
        pub_obs_tp1=pub_obs_tp1,
    )

    return (utility * strat_opp).sum()
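# --- Illustrative sketch (hypothetical values) ---
# At non-traverser nodes the action is drawn on-policy from the acting
# player's strategy via torch.multinomial; sampling frequencies recover the
# strategy's probabilities, which is why the function can return
# (utility * strat_opp).sum() as the node's expected value.
import torch

strat_opp = torch.tensor([0.1, 0.6, 0.3])
a = torch.multinomial(strat_opp, num_samples=1).item()  # an int in {0, 1, 2}
counts = torch.bincount(torch.multinomial(strat_opp, num_samples=10000, replacement=True),
                        minlength=3)
# counts / 10000 is approximately [0.1, 0.6, 0.3]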
def _traverser_act(self, start_state_dict, traverser, trav_depth, plyrs_range_idxs, iteration_strats, cfr_iter):
    """
    Last state values are the average, not the sum, of all samples of that state, since each sampled
    return is weighted by p(a) * |A(I)|. Because we sample multiple actions at each traverser node,
    we average over their returns:

        v~(I) = (1 / N) * Sum_{a sampled} v~(I|a) * p(a) * |A(I)|
    """
    self._env_wrapper.load_state_dict(start_state_dict)
    legal_actions_list = self._env_wrapper.env.get_legal_actions()
    legal_action_mask = rl_util.get_legal_action_mask_torch(n_actions=self._env_bldr.N_ACTIONS,
                                                            legal_actions_list=legal_actions_list,
                                                            device=self._adv_buffers[traverser].device,
                                                            dtype=torch.float32)
    current_pub_obs = self._env_wrapper.get_current_obs()
    traverser_range_idx = plyrs_range_idxs[traverser]

    # Sample a subset of the legal actions uniformly at random, without replacement.
    n_legal_actions = len(legal_actions_list)
    n_actions_to_smpl = self._get_n_a_to_sample(trav_depth=trav_depth, n_legal_actions=n_legal_actions)
    _idxs = np.arange(n_legal_actions)
    np.random.shuffle(_idxs)
    _idxs = _idxs[:n_actions_to_smpl]
    actions = [legal_actions_list[i] for i in _idxs]

    strat_i = iteration_strats[traverser].get_a_probs(pub_obses=[current_pub_obs],
                                                      range_idxs=[traverser_range_idx],
                                                      legal_actions_lists=[legal_actions_list],
                                                      to_np=True)[0]

    cumm_rew = 0.0
    aprx_imm_reg = torch.zeros(size=(self._env_bldr.N_ACTIONS,),
                               dtype=torch.float32,
                               device=self._adv_buffers[traverser].device)

    # Create next states
    for _c, a in enumerate(actions):
        strat_i_a = strat_i[a]

        # Re-initialize the environment after each action branch with the current state and a random future.
        if _c > 0:
            self._env_wrapper.load_state_dict(start_state_dict)
            self._env_wrapper.env.reshuffle_remaining_deck()

        _obs, _rew_for_all, _done, _info = self._env_wrapper.step(a)
        _cfv_traverser_a = _rew_for_all[traverser]

        # Recursion over sub-trees
        if not _done:
            _cfv_traverser_a += self._recursive_traversal(start_state_dict=self._env_wrapper.state_dict(),
                                                          traverser=traverser,
                                                          trav_depth=trav_depth + 1,
                                                          plyrs_range_idxs=plyrs_range_idxs,
                                                          iteration_strats=iteration_strats,
                                                          cfr_iter=cfr_iter)

        # Accumulate the sigma-weighted return for the backward pass on the tree.
        cumm_rew += strat_i_a * _cfv_traverser_a

        # Approximate immediate regret: this subtracts the node-EV contribution for all actions != a...
        aprx_imm_reg -= strat_i_a * _cfv_traverser_a
        # ...then adds the regret for a, undoing the change made to a's regret in the line above.
        aprx_imm_reg[a] += _cfv_traverser_a

    aprx_imm_reg *= legal_action_mask / n_actions_to_smpl  # mean over all sampled legal actions

    # Add the current datapoint to the advantage buffer.
    self._adv_buffers[traverser].add(pub_obs=current_pub_obs,
                                     range_idx=traverser_range_idx,
                                     legal_action_mask=legal_action_mask,
                                     adv=aprx_imm_reg,
                                     iteration=cfr_iter + 1,
                                     )

    # Increase the generated-entries counter that controls the break point.
    self._generated_entries_adv += 1

    # * n_legal_actions because we multiplied by the strategy;
    # / n_actions_to_smpl because we summed that many returns and want their mean.
    return cumm_rew * n_legal_actions / n_actions_to_smpl
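# --- Illustrative numeric check (hypothetical values) ---
# Why the |A| / N scaling in the return of _traverser_act is sound: when N of
# the |A| legal actions are sampled uniformly without replacement, scaling the
# sigma-weighted sum of their returns by |A| / N is unbiased for
# Sum_a sigma(a) * v(a). Averaging over all possible subsets verifies this.
import itertools
import numpy as np

sigma = np.array([0.5, 0.3, 0.2])   # strategy over |A| = 3 legal actions
v = np.array([1.0, -1.0, 4.0])      # hypothetical per-action values
true_ev = float((sigma * v).sum())  # 1.0

n_a, n_smpl = 3, 2
estimates = [sum(sigma[a] * v[a] for a in subset) * n_a / n_smpl
             for subset in itertools.combinations(range(n_a), n_smpl)]
assert np.isclose(np.mean(estimates), true_ev)  # average over all subsets matches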
def _traverser_act(self, start_state_dict, traverser, trav_depth, sample_reach, plyrs_range_idxs,
                   iteration_strats, cfr_iter):
    """
    Last state values are the average, not the sum, of all samples of that state, since each sampled
    return is weighted by p(a) * |A(I)|. Because we sample multiple actions at each traverser node,
    we average over their returns:

        v~(I) = (1 / N) * Sum_{a sampled} v~(I|a) * p(a) * |A(I)|
    """
    self.total_node_count_traversed += 1

    self._env_wrapper.load_state_dict(start_state_dict)
    legal_actions_list = self._env_wrapper.env.get_legal_actions()
    legal_action_mask = rl_util.get_legal_action_mask_torch(n_actions=self._env_bldr.N_ACTIONS,
                                                            legal_actions_list=legal_actions_list,
                                                            device=self._adv_buffers[traverser].device,
                                                            dtype=torch.float32)
    current_pub_obs = self._env_wrapper.get_current_obs()
    traverser_range_idx = plyrs_range_idxs[traverser]

    strat_i = iteration_strats[traverser].get_a_probs(pub_obses=[current_pub_obs],
                                                      range_idxs=[traverser_range_idx],
                                                      legal_actions_lists=[legal_actions_list],
                                                      to_np=True)[0]

    n_legal_actions = len(legal_actions_list)
    n_actions_to_smpl = self._get_n_a_to_sample(trav_depth=trav_depth, n_legal_actions=n_legal_actions)
    sample_strat = (1 - self._eps) * strat_i + self._eps * (legal_action_mask.cpu().numpy() / n_legal_actions)

    # Create next states
    cumm_rew = 0.0
    aprx_imm_reg = torch.zeros(size=(self._env_bldr.N_ACTIONS,),
                               dtype=torch.float32,
                               device=self._adv_buffers[traverser].device)
    _1_over_sample_reach_sum = 0
    _sample_reach_sum = 0

    for _c, a in enumerate(np.random.choice(self._actions_arranged, p=sample_strat, size=n_actions_to_smpl)):
        # Re-initialize the environment after each action branch with the current state and a random future.
        if _c > 0:
            self._env_wrapper.load_state_dict(start_state_dict)
            self._env_wrapper.env.reshuffle_remaining_deck()

        _obs, _rew_for_all, _done, _info = self._env_wrapper.step(a)
        _u_a = _rew_for_all[traverser]

        _sample_reach = sample_reach * sample_strat[a] * n_legal_actions
        _sample_reach_sum += _sample_reach
        _1_over_sample_reach_sum += 1 / _sample_reach

        # Recursion over sub-trees
        if not _done:
            _u_a += self._recursive_traversal(start_state_dict=self._env_wrapper.state_dict(),
                                              traverser=traverser,
                                              trav_depth=trav_depth + 1,
                                              plyrs_range_idxs=plyrs_range_idxs,
                                              iteration_strats=iteration_strats,
                                              sample_reach=_sample_reach,
                                              cfr_iter=cfr_iter)

        # Accumulate the importance-weighted return for the backward pass on the tree.
        cumm_rew += strat_i[a] / sample_strat[a] * _u_a

        # Approximate immediate regret: -sigma(a) * u for all actions != a...
        _aprx_imm_reg = torch.full_like(aprx_imm_reg, fill_value=-strat_i[a] * _u_a)
        # ...then add the regret for a, undoing the change made to a's regret in the line above.
        _aprx_imm_reg[a] += _u_a
        aprx_imm_reg += _aprx_imm_reg / _sample_reach

    aprx_imm_reg *= legal_action_mask / _1_over_sample_reach_sum  # weighted mean over sampled legal actions

    # Add the current datapoint to the advantage buffer.
    self._adv_buffers[traverser].add(pub_obs=current_pub_obs,
                                     range_idx=traverser_range_idx,
                                     legal_action_mask=legal_action_mask,
                                     adv=aprx_imm_reg,
                                     iteration=(cfr_iter + 1) / _sample_reach_sum,
                                     )

    # / n_actions_to_smpl because we summed that many returns and want their mean.
    return cumm_rew / n_actions_to_smpl
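# --- Illustrative sketch (hypothetical values) ---
# The epsilon-exploration mix used above: sample from
# q = (1 - eps) * sigma + eps * uniform(legal), then importance-correct
# returns by sigma(a) / q(a), as done in
# `cumm_rew += strat_i[a] / sample_strat[a] * _u_a`. The epsilon term keeps
# q(a) > 0 for every legal action, and the correction makes the estimate
# unbiased: E_q[sigma(a)/q(a) * v(a)] = E_sigma[v(a)].
import numpy as np

eps = 0.5
strat_i = np.array([0.9, 0.1, 0.0])
mask = np.ones(3)                                             # all actions legal
sample_strat = (1 - eps) * strat_i + eps * mask / mask.sum()  # q; strictly positive
v = np.array([2.0, -1.0, 0.5])                                # hypothetical per-action returns
est = sum(sample_strat[a] * (strat_i[a] / sample_strat[a]) * v[a] for a in range(3))
assert np.isclose(est, (strat_i * v).sum())                   # importance weighting is unbiased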
def _get_mask(self, legal_actions_list):
    return rl_util.get_legal_action_mask_torch(n_actions=self._env_bldr.N_ACTIONS,
                                               legal_actions_list=legal_actions_list,
                                               device=self.device,
                                               dtype=torch.float32)
def _any_non_traverser_act(self, start_state_dict, traverser, plyrs_range_idxs, trav_depth, iteration_strats,
                           sample_reach, cfr_iter):
    self.total_node_count_traversed += 1
    self._env_wrapper.load_state_dict(start_state_dict)

    p_id_acting = self._env_wrapper.env.current_player.seat_id

    current_pub_obs = self._env_wrapper.get_current_obs()
    range_idx = plyrs_range_idxs[p_id_acting]
    legal_actions_list = self._env_wrapper.env.get_legal_actions()
    legal_action_mask = rl_util.get_legal_action_mask_torch(n_actions=self._env_bldr.N_ACTIONS,
                                                            legal_actions_list=legal_actions_list,
                                                            device=self._adv_buffers[traverser].device,
                                                            dtype=torch.float32)

    # The acting player's strategy
    strat_opp = iteration_strats[p_id_acting].get_a_probs(pub_obses=[current_pub_obs],
                                                          range_idxs=[range_idx],
                                                          legal_actions_lists=[legal_actions_list],
                                                          to_np=False)[0]

    # Sample and execute an action from the strategy.
    a = torch.multinomial(strat_opp.cpu(), num_samples=1).item()
    pub_obs_tp1, rew_for_all, done, _info = self._env_wrapper.step(a)
    legal_action_mask_tp1 = rl_util.get_legal_action_mask_torch(
        n_actions=self._env_bldr.N_ACTIONS,
        legal_actions_list=self._env_wrapper.env.get_legal_actions(),
        device=self._adv_buffers[traverser].device,
        dtype=torch.float32)

    # Add to the opponent's average-strategy buffer, if applicable.
    if self._avrg_buffers is not None:
        self._avrg_buffers[p_id_acting].add(pub_obs=current_pub_obs,
                                            range_idx=range_idx,
                                            legal_actions_list=legal_actions_list,
                                            a_probs=strat_opp.to(self._avrg_buffers[p_id_acting].device).squeeze(),
                                            iteration=(cfr_iter + 1) / sample_reach)

    # Recursion
    if done:
        strat_tp1 = torch.zeros_like(strat_opp)
        self.total_node_count_traversed += 1
    else:
        u_bootstrap, strat_tp1 = self._recursive_traversal(start_state_dict=self._env_wrapper.state_dict(),
                                                           traverser=traverser,
                                                           trav_depth=trav_depth + 1,
                                                           plyrs_range_idxs=plyrs_range_idxs,
                                                           iteration_strats=iteration_strats,
                                                           cfr_iter=cfr_iter,
                                                           sample_reach=sample_reach)

    # Utility
    utility = self._get_utility(
        traverser=traverser,
        u_bootstrap=rew_for_all[traverser] if done else u_bootstrap,
        pub_obs=current_pub_obs,
        range_idx_crazy_embedded=_crazy_embed(plyrs_range_idxs=plyrs_range_idxs),
        legal_actions_list=legal_actions_list,
        legal_action_mask=legal_action_mask,
        a=a,
        sample_strat=strat_opp,
    )

    # Add a datapoint to the baseline net's buffer.
    self._baseline_buf.add(pub_obs=current_pub_obs,
                           range_idx_crazy_embedded=_crazy_embed(plyrs_range_idxs=plyrs_range_idxs),
                           legal_action_mask=legal_action_mask,
                           r=rew_for_all[0],  # 0 because we mirror for player 1... zero-sum
                           a=a,
                           done=done,
                           pub_obs_tp1=pub_obs_tp1,
                           strat_tp1=strat_tp1,
                           legal_action_mask_tp1=legal_action_mask_tp1,
                           )

    return (utility * strat_opp).sum(), strat_opp
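# --- Illustrative numeric check (hypothetical values) ---
# The baseline buffer above feeds a learned baseline b(s, a) used as a control
# variate. The exact correction lives in self._get_utility, which is not shown
# here; this sketch follows the VR-MCCFR-style form the buffer fields suggest:
#     sampled action a:      u_hat(a)  = b(a) + (u(a) - b(a)) / q(a)
#     every other action a': u_hat(a') = b(a')
# which is unbiased under the sampling distribution q while shrinking variance
# when b is close to u. Enumerating the sampling outcomes verifies this.
import numpy as np

q = np.array([0.5, 0.3, 0.2])    # sampling distribution over actions
u = np.array([1.0, -2.0, 0.5])   # true per-action utilities
b = np.array([0.8, -1.5, 0.0])   # baseline estimates

expected = np.zeros(3)
for a in range(3):               # enumerate which action gets sampled
    u_hat = b.copy()
    u_hat[a] = b[a] + (u[a] - b[a]) / q[a]
    expected += q[a] * u_hat
assert np.allclose(expected, u)  # baseline-corrected estimate is unbiased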