Example #1
0
    def __init__(self, config):
        """Set up the no-limit hold'em environment.

        Args:
            config (dict): environment configuration forwarded to the base class.
        """
        # These attributes are assigned before super().__init__ on purpose —
        # presumably the base initializer reads them (order kept deliberately).
        self.name = 'no-limit-holdem'
        self.default_game_config = DEFAULT_GAME_CONFIG
        self.game = Game()
        super().__init__(config)
        self.actions = Action
        self.state_shape = [54]
        # NOTE: raise amounts used to be enumerated as discrete actions:
        # for raise_amount in range(1, self.game.init_chips+1):
        #     self.actions.append(raise_amount)

        index_path = os.path.join(rlcard.__path__[0], 'games/limitholdem/card2index.json')
        with open(index_path, 'r') as index_file:
            self.card2index = json.load(index_file)
Example #2
0
class NolimitholdemEnv(Env):
    ''' No-limit Texas hold'em environment.
    '''

    def __init__(self, config):
        ''' Initialize the no-limit hold'em environment.

        Args:
            config (dict): environment configuration handed to the base Env
        '''
        self.name = 'no-limit-holdem'
        self.default_game_config = DEFAULT_GAME_CONFIG
        self.game = Game()
        super().__init__(config)
        self.actions = Action
        # Per-player observation: 52 card slots plus the player's own chips
        # and the maximum of all players' chips (see _extract_state).
        self.state_shape = [[54] for _ in range(self.num_players)]
        self.action_shape = [None for _ in range(self.num_players)]
        # Raise amounts used to be enumerated as individual actions:
        # for raise_amount in range(1, self.game.init_chips+1):
        #     self.actions.append(raise_amount)

        with open(os.path.join(rlcard.__path__[0], 'games/limitholdem/card2index.json'), 'r') as file:
            self.card2index = json.load(file)

    def _get_legal_actions(self):
        ''' Get all legal actions.

        Returns:
            encoded_action_list (list): legal actions for the current state
        '''
        return self.game.get_legal_actions()

    def _extract_state(self, state):
        ''' Extract the state representation from the game's state dictionary.

        Note: currently only the hand cards, public cards and chip counts are
        encoded. TODO: encode the rest of the state.

        Args:
            state (dict): original state from the game

        Returns:
            (dict): extracted state with keys 'legal_actions', 'obs',
                'raw_obs', 'raw_legal_actions' and 'action_record'
        '''
        extracted_state = {}

        legal_actions = OrderedDict({action.value: None for action in state['legal_actions']})
        extracted_state['legal_actions'] = legal_actions

        public_cards = state['public_cards']
        hand = state['hand']
        my_chips = state['my_chips']
        all_chips = state['all_chips']
        cards = public_cards + hand
        idx = [self.card2index[card] for card in cards]
        # Multi-hot card encoding in slots 0-51; chip counts in slots 52-53.
        obs = np.zeros(54)
        obs[idx] = 1
        obs[52] = float(my_chips)
        obs[53] = float(max(all_chips))
        extracted_state['obs'] = obs

        extracted_state['raw_obs'] = state
        extracted_state['raw_legal_actions'] = list(state['legal_actions'])
        extracted_state['action_record'] = self.action_recorder

        return extracted_state

    def get_payoffs(self):
        ''' Get the payoffs of the game.

        Returns:
            payoffs (numpy.ndarray): one payoff per player
        '''
        return np.array(self.game.get_payoffs())

    def _decode_action(self, action_id):
        ''' Decode an action id into a game action.

        Falls back to CHECK (or FOLD when CHECK is also illegal) if the
        requested action is not legal in the current state.

        Args:
            action_id (int): action id

        Returns:
            (Action): action to apply to the game
        '''
        legal_actions = self.game.get_legal_actions()
        if self.actions(action_id) not in legal_actions:
            if Action.CHECK in legal_actions:
                return Action.CHECK
            else:
                print("Tried non legal action", action_id, self.actions(action_id), legal_actions)
                return Action.FOLD
        return self.actions(action_id)

    def get_perfect_information(self):
        ''' Get the perfect information of the current state.

        Returns:
            (dict): chips, public cards, every player's hand cards, the
                current player index and the legal actions
        '''
        state = {}
        state['chips'] = [self.game.players[i].in_chips for i in range(self.num_players)]
        state['public_card'] = [c.get_index() for c in self.game.public_cards] if self.game.public_cards else None
        state['hand_cards'] = [[c.get_index() for c in self.game.players[i].hand] for i in range(self.num_players)]
        state['current_player'] = self.game.game_pointer
        state['legal_actions'] = self.game.get_legal_actions()
        return state
Example #3
0
class NolimitholdemEnv(Env):
    ''' No-limit Texas hold'em environment with hand-crafted features.
    '''
    def __init__(self, config):
        ''' Initialize the no-limit hold'em environment.

        Args:
            config (dict): environment configuration handed to the base Env
        '''
        self.game = Game()
        super().__init__(config)
        self.actions = Action
        # Length of the feature vector built in _extract_state.
        self.state_shape = [47]
        # Raise amounts used to be enumerated as individual actions:
        # for raise_amount in range(1, self.game.init_chips+1):
        #     self.actions.append(raise_amount)

        with open(
                os.path.join(rlcard.__path__[0],
                             'games/limitholdem/card2index.json'),
                'r') as file:
            self.card2index = json.load(file)

    def _get_legal_actions(self):
        ''' Get all legal actions.

        Returns:
            encoded_action_list (list): legal actions for the current state
        '''
        return self.game.get_legal_actions()

    def _extract_state(self, state):
        ''' Build the feature vector for the agent from the raw game state.

        Observation layout:
          - multi-hot encodings of the public and private cards
          - call size and all-in size as a fraction of the pot
          - number of others in hand, how many already raised/called, and how
            many still need to call the current raise
          - board position and pot size
          - preflop and postflop effective hand strength (EHS)
          - past and current-street aggression
          - own chips and the maximum of all players' chips

        Args:
            state (dict): original state from the game

        Returns:
            (dict): extracted state with 'legal_actions' and 'obs' (plus
                'raw_obs', 'raw_legal_actions' and 'action_record' when the
                corresponding config flags are enabled)
        '''
        encoded_public_cards = encode_multihot(state['public_cards'])
        encoded_private_cards = encode_multihot(state['hand'])
        call_percent = state['to_call']
        all_in_percent = state['to_allin']
        n_others = state['n_others']
        position = state['position']
        already_called = state['already_called']
        need_to_call = state['need_to_call']
        pot = state['pot']
        EHS_preflop = unified_EHS(to_deuces_intlist(state['hand']), [], en,
                                  hs_model, ehs_model, lookup_table)
        EHS_postflop = unified_EHS(to_deuces_intlist(state['hand']),
                                   to_deuces_intlist(state['public_cards']),
                                   en, hs_model, ehs_model, lookup_table)
        past_aggression = state['past_aggression']
        street_aggression = state['street_aggression']
        my_chips = state['my_chips']
        all_chips = state['all_chips']

        # BUG FIX: the fourth segment previously re-appended
        # encoded_public_cards[1] instead of encoded_private_cards[1], so the
        # second half of the private-card encoding never reached the agent.
        obs = (encoded_public_cards[0].tolist()
               + encoded_public_cards[1].tolist()
               + encoded_private_cards[0].tolist()
               + encoded_private_cards[1].tolist()
               + [
                   call_percent, all_in_percent, n_others, position,
                   already_called, need_to_call, pot, EHS_preflop,
                   EHS_postflop, past_aggression, street_aggression,
                   float(my_chips),
                   float(max(all_chips))
               ])
        obs = np.asarray(obs)

        extracted_state = {}

        legal_actions = [action.value for action in state['legal_actions']]
        extracted_state['legal_actions'] = legal_actions

        extracted_state['obs'] = obs

        if self.allow_raw_data:
            extracted_state['raw_obs'] = state
            extracted_state['raw_legal_actions'] = list(state['legal_actions'])
        if self.record_action:
            extracted_state['action_record'] = self.action_recorder
        return extracted_state

    def get_payoffs(self):
        ''' Get the payoffs of the game.

        Returns:
            payoffs (numpy.ndarray): one payoff per player
        '''
        return np.array(self.game.get_payoffs())

    def _decode_action(self, action_id):
        ''' Decode an action id into a game action.

        Falls back to CHECK (or FOLD when CHECK is also illegal) if the
        requested action is not legal in the current state.

        Args:
            action_id (int): action id

        Returns:
            (Action): action to apply to the game
        '''
        legal_actions = self.game.get_legal_actions()
        if self.actions(action_id) not in legal_actions:
            if Action.CHECK in legal_actions:
                return Action.CHECK
            else:
                print("Tried non legal action", action_id,
                      self.actions(action_id), legal_actions)
                return Action.FOLD
        return self.actions(action_id)

    def get_perfect_information(self):
        ''' Get the perfect information of the current state.

        Returns:
            (dict): chips, public cards, every player's hand cards, the
                current player index and the legal actions
        '''
        state = {}
        state['chips'] = [
            self.game.players[i].in_chips for i in range(self.player_num)
        ]
        state['public_card'] = [c.get_index() for c in self.game.public_cards
                                ] if self.game.public_cards else None
        state['hand_cards'] = [[
            c.get_index() for c in self.game.players[i].hand
        ] for i in range(self.player_num)]
        state['current_player'] = self.game.game_pointer
        state['legal_actions'] = self.game.get_legal_actions()
        return state
Example #4
0
def main():
    """Parse command-line options and run either an MCTS simulation or
    policy distillation.

    In MCTS mode a single move is simulated from a fixed game state (pocket
    aces) and the result is appended to ``Results/<model>.txt``.
    """
    parser = argparse.ArgumentParser(description="P-MCTS")
    parser.add_argument("--model",
                        type=str,
                        default="WU-UCT",
                        help="Base MCTS model WU-UCT/UCT (default: WU-UCT)")

    parser.add_argument("--env-name",
                        type=str,
                        default="AlienNoFrameskip-v0",
                        help="Environment name (default: AlienNoFrameskip-v0)")

    # NOTE: help texts below are kept in sync with the actual defaults
    # (they previously advertised stale values such as 500 steps).
    parser.add_argument("--MCTS-max-steps",
                        type=int,
                        default=128,
                        help="Max simulation step of MCTS (default: 128)")
    parser.add_argument("--MCTS-max-depth",
                        type=int,
                        default=100,
                        help="Max depth of MCTS simulation (default: 100)")
    parser.add_argument("--MCTS-max-width",
                        type=int,
                        default=20,
                        help="Max width of MCTS simulation (default: 20)")

    parser.add_argument("--gamma",
                        type=float,
                        default=0.99,
                        help="Discount factor (default: 0.99)")

    parser.add_argument("--expansion-worker-num",
                        type=int,
                        default=8,
                        help="Number of expansion workers (default: 8)")
    parser.add_argument("--simulation-worker-num",
                        type=int,
                        default=8,
                        help="Number of simulation workers (default: 8)")

    parser.add_argument("--seed",
                        type=int,
                        default=123,
                        help="random seed (default: 123)")

    parser.add_argument("--max-episode-length",
                        type=int,
                        default=100000,
                        help="Maximum episode length (default: 100000)")

    parser.add_argument(
        "--policy",
        type=str,
        default="Random",
        help=
        "Prior prob/simulation policy used in MCTS Random/PPO/DistillPPO (default: Random)"
    )

    parser.add_argument(
        "--device",
        type=str,
        default="cpu",
        help=
        "PyTorch device, if entered 'cuda', use cuda device parallelization (default: cpu)"
    )

    parser.add_argument("--record-video",
                        default=False,
                        action="store_true",
                        help="Record video if supported (default: False)")

    parser.add_argument("--mode",
                        type=str,
                        default="MCTS",
                        help="Mode MCTS/Distill (default: MCTS)")

    args = parser.parse_args()

    env_params = {
        "env_name": args.env_name,
        "max_episode_length": args.max_episode_length
    }

    if args.mode == "MCTS":
        # Model initialization
        if args.model == "WU-UCT":
            MCTStree = WU_UCT(env_params,
                              args.MCTS_max_steps,
                              args.MCTS_max_depth,
                              args.MCTS_max_width,
                              args.gamma,
                              args.expansion_worker_num,
                              args.simulation_worker_num,
                              policy=args.policy,
                              seed=args.seed,
                              device=args.device,
                              record_video=args.record_video)
        elif args.model == "UCT":
            MCTStree = UCT(env_params,
                           args.MCTS_max_steps,
                           args.MCTS_max_depth,
                           args.MCTS_max_width,
                           args.gamma,
                           policy=args.policy,
                           seed=args.seed)
        else:
            raise NotImplementedError()

        # Fixed starting hand: pocket aces for player 0.
        cards = [Card('S', 'A'), Card('D', 'A')]

        g = Game()
        g.init_game()
        g.players[0].hand = cards

        move, node = MCTStree.simulate_single_move(g)
        print(node.children_visit_count)

        # Create output directories up front; makedirs(exist_ok=True) is
        # race-free, unlike the previous exists()+mkdir with a bare except.
        os.makedirs("Results", exist_ok=True)
        os.makedirs("OutLogs", exist_ok=True)

        with open("Results/" + args.model + ".txt", "a+") as f:
            # Trailing newline so successive runs do not concatenate records.
            f.write(
                "Model: {}, env: {}, result: {}, MCTS max steps: {}, policy: {}, worker num: {}\n"
                .format(args.model, args.env_name, move, args.MCTS_max_steps,
                        args.policy, args.simulation_worker_num))

        # sio.savemat("OutLogs/" + args.model + "_" + args.env_name + "_" + str(args.seed) + "_" +
        #             str(args.simulation_worker_num)  + ".mat",
        #             {"rewards": rewards, "times": times})

        MCTStree.close()

    elif args.mode == "Distill":
        train_distillation(args.env_name, args.device)