def __init__(self, config):
    ''' Initialize the Limitholdem environment

    Args:
        config (dict): configuration forwarded to the base Env
    '''
    self.name = 'no-limit-holdem'
    self.default_game_config = DEFAULT_GAME_CONFIG
    self.game = Game()
    super().__init__(config)
    self.actions = Action
    # 54-dim observation: 52 card slots plus two chip-count features.
    self.state_shape = [54]
    # Raise amounts are not enumerated as discrete actions here:
    # for raise_amount in range(1, self.game.init_chips+1):
    #     self.actions.append(raise_amount)
    card2index_path = os.path.join(rlcard.__path__[0],
                                   'games/limitholdem/card2index.json')
    with open(card2index_path, 'r') as mapping_file:
        self.card2index = json.load(mapping_file)
class NolimitholdemEnv(Env):
    ''' No-limit hold'em environment (54-dim card/chip observation). '''

    def __init__(self, config):
        ''' Initialize the Limitholdem environment

        Args:
            config (dict): configuration forwarded to the base Env
        '''
        self.name = 'no-limit-holdem'
        self.default_game_config = DEFAULT_GAME_CONFIG
        self.game = Game()
        super().__init__(config)
        self.actions = Action
        # One 54-dim observation and one unconstrained action shape per seat.
        self.state_shape = [[54] for _ in range(self.num_players)]
        self.action_shape = [None for _ in range(self.num_players)]
        # Raise amounts are not enumerated as discrete actions here:
        # for raise_amount in range(1, self.game.init_chips+1):
        #     self.actions.append(raise_amount)
        card2index_path = os.path.join(rlcard.__path__[0],
                                       'games/limitholdem/card2index.json')
        with open(card2index_path, 'r') as mapping_file:
            self.card2index = json.load(mapping_file)

    def _get_legal_actions(self):
        ''' Get all legal actions

        Returns:
            encoded_action_list (list): encoded legal action list (from str to int)
        '''
        return self.game.get_legal_actions()

    def _extract_state(self, state):
        ''' Extract the state representation from state dictionary for agent

        Note: currently encodes only the hand cards, the public cards and
        two chip counts. TODO: richer state encoding.

        Args:
            state (dict): Original state from the game

        Returns:
            (dict): extracted state with 'obs', 'legal_actions' and raw fields
        '''
        # Multi-hot over the 52 cards, plus my chips and the largest bet.
        observation = np.zeros(54)
        visible_cards = state['public_cards'] + state['hand']
        card_slots = [self.card2index[card] for card in visible_cards]
        observation[card_slots] = 1
        observation[52] = float(state['my_chips'])
        observation[53] = float(max(state['all_chips']))

        extracted_state = {
            'legal_actions': OrderedDict(
                (action.value, None) for action in state['legal_actions']),
            'obs': observation,
            'raw_obs': state,
            'raw_legal_actions': list(state['legal_actions']),
            'action_record': self.action_recorder,
        }
        return extracted_state

    def get_payoffs(self):
        ''' Get the payoff of a game

        Returns:
            payoffs (list): list of payoffs
        '''
        return np.array(self.game.get_payoffs())

    def _decode_action(self, action_id):
        ''' Decode the action for applying to the game

        Args:
            action_id (int): action id

        Returns:
            action (Action): action for the game
        '''
        legal_actions = self.game.get_legal_actions()
        decoded = self.actions(action_id)
        if decoded in legal_actions:
            return decoded
        # Requested action is illegal: degrade to CHECK when possible,
        # otherwise report and FOLD.
        if Action.CHECK in legal_actions:
            return Action.CHECK
        print("Tried non legal action", action_id, decoded, legal_actions)
        return Action.FOLD

    def get_perfect_information(self):
        ''' Get the perfect information of the current state

        Returns:
            (dict): A dictionary of all the perfect information of the current state
        '''
        seats = range(self.num_players)
        state = {
            'chips': [self.game.players[i].in_chips for i in seats],
            'public_card': ([c.get_index() for c in self.game.public_cards]
                            if self.game.public_cards else None),
            'hand_cards': [[c.get_index() for c in self.game.players[i].hand]
                           for i in seats],
            'current_player': self.game.game_pointer,
            'legal_actions': self.game.get_legal_actions(),
        }
        return state
class NolimitholdemEnv(Env):
    ''' No-limit hold'em environment with a 47-dim engineered feature vector. '''

    def __init__(self, config):
        ''' Initialize the Limitholdem environment

        Args:
            config (dict): configuration forwarded to the base Env
        '''
        self.game = Game()
        super().__init__(config)
        self.actions = Action
        # 47 features: two multi-hot card encodings (public + private,
        # presumably rank/suit pairs from encode_multihot — TODO confirm)
        # plus 13 scalar game/history features built in _extract_state.
        self.state_shape = [47]
        # Raise amounts are not enumerated as discrete actions here:
        # for raise_amount in range(1, self.game.init_chips+1):
        #     self.actions.append(raise_amount)
        with open(
                os.path.join(rlcard.__path__[0],
                             'games/limitholdem/card2index.json'), 'r') as file:
            self.card2index = json.load(file)

    def _get_legal_actions(self):
        ''' Get all legal actions

        Returns:
            encoded_action_list (list): encoded legal action list (from str to int)
        '''
        return self.game.get_legal_actions()

    def _extract_state(self, state):
        ''' Build the 47-dim observation vector from the raw game state.

        Args:
            state (dict): Original state from the game

        Returns:
            (dict): extracted state with 'obs', 'legal_actions' and,
            depending on env flags, raw state and action record
        '''
        # Stuff we would like in our observation:
        # General features
        # - call size as a %pot DONE
        # - all in size as a %pot DONE WITH ISSUES
        # - number of others in hand - DONE
        # - number of others in hand who have raised or called before in hand - DONE
        # - number of others in hand who need to call the current raise - DONE
        # - street number - DONE
        # - board position - DONE
        # Hand features
        # Better encoded board and player cards DONE
        # preflop EHS of card pairs vs all (and vs premium) - DONE
        # postflop (showdown) EHS vs all and vs premium - DONE
        #
        # History features
        # aggression on each street - DONE WITH PAST + CURRENT
        encoded_public_cards = encode_multihot(state['public_cards'])
        encoded_private_cards = encode_multihot(state['hand'])
        call_percent = state['to_call']
        all_in_percent = state['to_allin']
        n_others = state['n_others']
        position = state['position']
        already_called = state['already_called']
        need_to_call = state['need_to_call']
        pot = state['pot']
        # Effective hand strength before and after the board is revealed.
        EHS_preflop = unified_EHS(to_deuces_intlist(state['hand']), [], en,
                                  hs_model, ehs_model, lookup_table)
        EHS_postflop = unified_EHS(to_deuces_intlist(state['hand']),
                                   to_deuces_intlist(state['public_cards']),
                                   en, hs_model, ehs_model, lookup_table)
        past_aggression = state['past_aggression']
        street_aggression = state['street_aggression']
        my_chips = state['my_chips']
        all_chips = state['all_chips']

        # BUG FIX: the fourth segment previously repeated
        # encoded_public_cards[1] instead of using encoded_private_cards[1].
        # Both segments have the same length, so the vector still had 47
        # entries and the error was silent — the private cards' second
        # encoding simply never reached the agent.
        obs = (encoded_public_cards[0].tolist()
               + encoded_public_cards[1].tolist()
               + encoded_private_cards[0].tolist()
               + encoded_private_cards[1].tolist()
               + [call_percent, all_in_percent, n_others, position,
                  already_called, need_to_call, pot, EHS_preflop,
                  EHS_postflop, past_aggression, street_aggression,
                  float(my_chips), float(max(all_chips))])
        obs = np.asarray(obs)

        extracted_state = {}
        legal_actions = [action.value for action in state['legal_actions']]
        extracted_state['legal_actions'] = legal_actions
        # Previous 54-dim card-index encoding kept for reference:
        # public_cards = state['public_cards']
        # hand = state['hand']
        # cards = public_cards + hand
        # idx = [self.card2index[card] for card in cards]
        # obs = np.zeros(self.state_shape[0])
        # obs[idx] = 1
        # obs[52] = float(my_chips)
        # obs[53] = float(max(all_chips))
        extracted_state['obs'] = obs
        if self.allow_raw_data:
            extracted_state['raw_obs'] = state
            extracted_state['raw_legal_actions'] = [
                a for a in state['legal_actions']
            ]
        if self.record_action:
            extracted_state['action_record'] = self.action_recorder
        return extracted_state

    def get_payoffs(self):
        ''' Get the payoff of a game

        Returns:
            payoffs (list): list of payoffs
        '''
        return np.array(self.game.get_payoffs())

    def _decode_action(self, action_id):
        ''' Decode the action for applying to the game

        Args:
            action_id (int): action id

        Returns:
            action (Action): action for the game
        '''
        legal_actions = self.game.get_legal_actions()
        if self.actions(action_id) not in legal_actions:
            # Illegal request: degrade to CHECK when possible, else FOLD.
            if Action.CHECK in legal_actions:
                return Action.CHECK
            else:
                print("Tried non legal action", action_id,
                      self.actions(action_id), legal_actions)
                return Action.FOLD
        return self.actions(action_id)

    def get_perfect_information(self):
        ''' Get the perfect information of the current state

        Returns:
            (dict): A dictionary of all the perfect information of the current state
        '''
        state = {}
        state['chips'] = [
            self.game.players[i].in_chips for i in range(self.player_num)
        ]
        state['public_card'] = [
            c.get_index() for c in self.game.public_cards
        ] if self.game.public_cards else None
        state['hand_cards'] = [[
            c.get_index() for c in self.game.players[i].hand
        ] for i in range(self.player_num)]
        state['current_player'] = self.game.game_pointer
        state['legal_actions'] = self.game.get_legal_actions()
        return state
def main():
    ''' Entry point: parse CLI arguments, then either run a single P-MCTS
    move probe on a no-limit hold'em game (mode "MCTS") or train a
    distilled policy (mode "Distill").
    '''
    parser = argparse.ArgumentParser(description="P-MCTS")
    # NOTE: help strings below are kept consistent with the actual defaults
    # (several previously advertised stale values).
    parser.add_argument("--model", type=str, default="WU-UCT",
                        help="Base MCTS model WU-UCT/UCT (default: WU-UCT)")
    parser.add_argument("--env-name", type=str, default="AlienNoFrameskip-v0",
                        help="Environment name (default: AlienNoFrameskip-v0)")
    parser.add_argument("--MCTS-max-steps", type=int, default=128,
                        help="Max simulation step of MCTS (default: 128)")
    parser.add_argument("--MCTS-max-depth", type=int, default=100,
                        help="Max depth of MCTS simulation (default: 100)")
    parser.add_argument("--MCTS-max-width", type=int, default=20,
                        help="Max width of MCTS simulation (default: 20)")
    parser.add_argument("--gamma", type=float, default=0.99,
                        help="Discount factor (default: 0.99)")
    parser.add_argument("--expansion-worker-num", type=int, default=8,
                        help="Number of expansion workers (default: 8)")
    parser.add_argument("--simulation-worker-num", type=int, default=8,
                        help="Number of simulation workers (default: 8)")
    parser.add_argument("--seed", type=int, default=123,
                        help="random seed (default: 123)")
    parser.add_argument("--max-episode-length", type=int, default=100000,
                        help="Maximum episode length (default: 100000)")
    parser.add_argument(
        "--policy", type=str, default="Random",
        help="Prior prob/simulation policy used in MCTS "
             "Random/PPO/DistillPPO (default: Random)")
    parser.add_argument(
        "--device", type=str, default="cpu",
        help="PyTorch device, if entered 'cuda', use cuda device "
             "parallelization (default: cpu)")
    parser.add_argument("--record-video", default=False, action="store_true",
                        help="Record video if supported (default: False)")
    parser.add_argument("--mode", type=str, default="MCTS",
                        help="Mode MCTS/Distill (default: MCTS)")
    args = parser.parse_args()

    env_params = {
        "env_name": args.env_name,
        "max_episode_length": args.max_episode_length
    }

    if args.mode == "MCTS":
        # Model initialization
        if args.model == "WU-UCT":
            MCTStree = WU_UCT(env_params, args.MCTS_max_steps,
                              args.MCTS_max_depth, args.MCTS_max_width,
                              args.gamma, args.expansion_worker_num,
                              args.simulation_worker_num,
                              policy=args.policy, seed=args.seed,
                              device=args.device,
                              record_video=args.record_video)
        elif args.model == "UCT":
            MCTStree = UCT(env_params, args.MCTS_max_steps,
                           args.MCTS_max_depth, args.MCTS_max_width,
                           args.gamma, policy=args.policy, seed=args.seed)
        else:
            raise NotImplementedError()

        # Probe a single decision: fixed pocket aces for player 0.
        cards = [Card('S', 'A'), Card('D', 'A')]
        g = Game()
        g.init_game()
        g.players[0].hand = cards
        move, node = MCTStree.simulate_single_move(g)
        print(node.children_visit_count)

        # FIX: ensure the results directory exists (open() in append mode
        # does not create parent dirs) and terminate each record with a
        # newline so repeated appends don't run together on one line.
        os.makedirs("Results", exist_ok=True)
        with open("Results/" + args.model + ".txt", "a+") as f:
            f.write(
                "Model: {}, env: {}, result: {}, MCTS max steps: {}, "
                "policy: {}, worker num: {}\n".format(
                    args.model, args.env_name, move, args.MCTS_max_steps,
                    args.policy, args.simulation_worker_num))

        # Best-effort directory creation; only swallow filesystem errors
        # (the original bare `except:` hid every exception type).
        try:
            os.makedirs("OutLogs/", exist_ok=True)
        except OSError:
            pass
        # sio.savemat("OutLogs/" + args.model + "_" + args.env_name + "_" + str(args.seed) + "_" +
        #             str(args.simulation_worker_num) + ".mat",
        #             {"rewards": rewards, "times": times})

        MCTStree.close()
    elif args.mode == "Distill":
        train_distillation(args.env_name, args.device)