# Load saved checkpoints for evaluation.  For every starting epoch, a
# DQN_limit and an MLP_limit network are restored for each of the
# ``num_NFSP`` agents, then three experiment rounds are started.
# NOTE(review): the loop nesting below was reconstructed from a
# whitespace-collapsed source -- confirm it against the original script.
checkpoint_dir = '64units-intended/2players_64hid_checkpoints'
for epoch in range(20):
    # Per-agent containers, keyed by agent index.
    policy_net = {}
    sl_net = {}
    rl_dir = {}
    sl_dir = {}
    for i in range(num_NFSP):
        # Agent i reads the checkpoint numbered ``epoch + 1 + i``; the RL and
        # SL checkpoints share the same file name in sibling directories.
        checkpoint_name = 'checkpoint_' + str(epoch + 1 + i) + '000.pt'
        rl_dir[i] = checkpoint_dir + '/rl_checkpoints/' + checkpoint_name
        sl_dir[i] = checkpoint_dir + '/sl_checkpoints/' + checkpoint_name

        policy_net[i] = DQN_limit(num_player=num_player, big=False,
                                  num_action=3, hidden_units=hid).to(device)
        sl_net[i] = MLP_limit(num_player=num_player, big=False,
                              num_action=3, hidden_units=hid).to(device)

        # Checkpoints are deserialized onto the CPU; ``load_state_dict``
        # then copies the weights into the modules already moved to
        # ``device``.  Both networks are switched to evaluation mode.
        checkpoint = torch.load(rl_dir[i], map_location='cpu')
        policy_net[i].load_state_dict(checkpoint['model'])
        policy_net[i].eval()

        checkpoint = torch.load(sl_dir[i], map_location='cpu')
        sl_net[i].load_state_dict(checkpoint['model'])
        sl_net[i].eval()

    results = []
    # The loop variable keeps its original spelling because code outside
    # this view may refer to it.
    for expriment in range(3):
        game_board = {}
        sum_board = {}
# Build the game environment and restore the networks of every saved
# checkpoint version.
# NOTE(review): the loop nesting below was reconstructed from a
# whitespace-collapsed source -- confirm it against the original script.
env = LimitTexasHoldemEnv(num_player, max_limit=1e9, debug=False)

# Seat four players (seats 0-3), each with a stack of 20000.
# NOTE(review): the seat count is hard-coded to 4 while the environment is
# created with ``num_player`` -- confirm that the two are meant to agree.
for seat in range(4):
    env.add_player(seat, stack=20000)

card_dictionary = get_card_dict()

for epoch in range(num_versions):
    # Version ``epoch`` is stored as ``checkpoint_<epoch + 1>000.pt`` in both
    # the rl_checkpoints and sl_checkpoints directories.
    rl_dir = checkpoint_dir + '/rl_checkpoints/checkpoint_' + str(epoch + 1) + '000.pt'
    sl_dir = checkpoint_dir + '/sl_checkpoints/checkpoint_' + str(epoch + 1) + '000.pt'

    # Checkpoints are deserialized onto the CPU; ``load_state_dict`` then
    # copies the weights into the module already moved to ``device``.
    policy_net = DQN_limit(num_player=num_player, big=big, num_action=3,
                           hidden_units=hid).to(device)
    checkpoint = torch.load(rl_dir, map_location='cpu')
    policy_net.load_state_dict(checkpoint['model'])
    policy_net.eval()

    sl_net = MLP_limit(num_player=num_player, big=big, num_action=3,
                       hidden_units=hid).to(device)
    checkpoint = torch.load(sl_dir, map_location='cpu')
    sl_net.load_state_dict(checkpoint['model'])
    sl_net.eval()
# players players = {} for i in range(num_player): players[i] = NFSPLimit(card_dict=card_dictionary, device=device) # memory M_sl = SLMemory(num_SL_memory) M_rl = RLMemory(num_RL_memory) # networks # networks policy_net = DQN_limit(num_player=num_player, big=big, res_net=use_res_net, num_layer=num_layer, num_action=3, hidden_units=num_hid).to(device) target_net = DQN_limit(num_player=num_player, big=big, res_net=use_res_net, num_layer=num_layer, num_action=3, hidden_units=num_hid).to(device) target_net.load_state_dict(policy_net.state_dict()) target_net.eval() sl_net = MLP_limit(num_player=num_player, big=big, res_net=use_res_net, num_layer=num_layer, num_action=3,