コード例 #1
0
ファイル: alphazul.py プロジェクト: pine73/alphazul
def worker_routine(game, w2s_conn, public_q):
    """Play one self-play game with MCTS and push the labelled samples to a queue.

    Each move is picked by ``MCTSearch.start_search(300)``, which returns the
    chosen command plus one training sample. Within a turn the existing search
    is kept via ``change_root()``. After each completed turn a fresh search is
    built.

    The game stops when ``game.is_terminal`` is set after ``turn_end``. It also
    stops when ``game.turn`` has reached 11 after ``start_turn``. In both cases
    ``[True] * 3`` is sent over ``w2s_conn``. This presumably tells the
    inference side that this worker is done (confirm with the receiver).

    NOTE(review): ``game`` is used as-is; this function never calls
    ``game.start()``. The caller is presumably expected to have done so.

    Args:
        game: Game object driven by ``take_command`` / ``turn_end`` /
            ``start_turn``.
        w2s_conn: Connection handed to ``InfHelper``; also receives the
            end-of-game ``[True] * 3`` message.
        public_q: Queue that receives one
            ``(states, actions, values, masks)`` tuple of lists.
    """
    all_commands = np.argwhere(np.ones((6, 5, 6)) == 1)
    helper = InfHelper(w2s_conn)
    tree = mcts.MCTSearch(game, helper, all_commands)

    samples = []
    winner = None
    while True:
        chosen, sample = tree.start_search(300)
        samples.append(sample)

        if not game.take_command(chosen):
            # Same player turn continues: keep working from the current tree.
            tree.change_root()
            continue

        game.turn_end(verbose=False)
        if game.is_terminal:
            game.final_score(verbose=True)
            w2s_conn.send([True] * 3)
            winner = game.leading_player_num
            print('end in', game.turn)
            break

        game.start_turn()
        if game.turn >= 11:
            # Hard cap on game length.
            w2s_conn.send([True] * 3)
            game.final_score()
            winner = game.leading_player_num
            print('exceeding turn 10')
            break
        tree = mcts.MCTSearch(game, helper, all_commands)

    # Every sample is (state, action_index, player, mask).
    # Value target: +1 for the eventual leader's samples, -1 for everyone else.
    state_data = [state for state, _, _, _ in samples]
    action_data = [action for _, action, _, _ in samples]
    value_data = [1. if who == winner else -1. for _, _, who, _ in samples]
    mask_data = [mask for _, _, _, mask in samples]

    public_q.put((state_data, action_data, value_data, mask_data))
コード例 #2
0
ファイル: alphazul.py プロジェクト: pine73/alphazul
def _action_index_to_code(action_index):
    """Render a flat command index as a 3-character string of grid digits.

    The index is decoded as ``i * 30 + j * 6 + k`` and returned as
    ``'<i><j><k>'``. For example, 47 becomes ``'125'``. This matches the
    row-major order of ``np.argwhere(np.ones((6, 5, 6)) == 1)``, assuming
    ``action_index`` is a row index into that commands array.
    """
    outer, rest = divmod(action_index, 30)
    middle, inner = divmod(rest, 6)
    return str(outer) + str(middle) + str(inner)


def debug():
    """Play one 2-player game locally with ``InfHelperS`` and return its data.

    Unlike ``worker_routine``, this loop builds a brand-new ``MCTSearch``
    after every move instead of calling ``change_root()``. Every search object
    is kept so it can be inspected afterwards.

    Returns:
        tuple ``(state_data, action_data, value_data, searches)``:
            state_data: the state from each training sample.
            action_data: each sample's action index rendered by
                ``_action_index_to_code`` (e.g. ``'125'``).
            value_data: ``1.`` for samples whose player equals the final
                ``game.leading_player_num``, else ``-1.``.
            searches: every ``MCTSearch`` created during the game, in
                creation order.
    """
    game = azul.Azul(2)
    game.start()
    commands = np.argwhere(np.ones((6, 5, 6)) == 1)
    inf_helper = InfHelperS()

    search = mcts.MCTSearch(game, inf_helper, commands)
    searches = [search]
    accumulated_data = []
    winner = None
    while True:
        action_command, training_data = search.start_search(300)
        accumulated_data.append(training_data)
        is_turn_end = game.take_command(action_command)
        if is_turn_end:
            game.turn_end(verbose=False)
            if game.is_terminal:
                game.final_score()
                winner = game.leading_player_num
                break
            game.start_turn()
        # Fresh search after every move, including the first move of a new
        # turn. Record it too; previously the turn-start searches were
        # silently missing from ``searches``.
        search = mcts.MCTSearch(game, inf_helper, commands)
        searches.append(search)

    state_data, action_data, value_data = [], [], []
    # ``worker_routine`` receives 4-tuples (state, action, player, mask) from
    # the same ``start_search``. ``*_`` accepts that shape as well as the
    # older 3-tuple shape, instead of raising ValueError.
    for state, action_index, player, *_ in accumulated_data:
        state_data.append(state)
        action_data.append(_action_index_to_code(action_index))
        value_data.append(1. if player == winner else -1.)
    return state_data, action_data, value_data, searches
コード例 #3
0
ファイル: test.py プロジェクト: pine73/alphazul
import numpy as np

import azul
import policy
import mcts

if __name__ == '__main__':
    # Smoke test: play a single turn of a 2-player Azul game. Each move is
    # chosen by an MCTS that evaluates positions with ``policy.rollout``.
    # Requires ``import numpy as np`` at the top of the file; np was
    # previously used here without being imported, which raised NameError.
    game = azul.Azul(2)
    game.start()

    # One row per candidate command: every (i, j, k) cell of a 6x5x6 grid,
    # in row-major order.
    commands = np.argwhere(np.ones((6, 5, 6)) == 1)
    search = mcts.MCTSearch(game, policy.rollout, commands=commands)

    while True:
        action, _ = search.start_search(100)
        is_turn_end = game.take_command(action)
        if is_turn_end:
            game.turn_end(verbose=False)
            break
        # Turn continues. This presumably re-roots the existing search tree
        # at the move just played; see mcts.MCTSearch.change_root.
        search.change_root()
コード例 #4
0
def mcts_roolout(game):
    """Return the command chosen for ``game`` by an MCTS over the full command grid.

    The search is guided by ``rollout`` and runs
    ``start_search_deterministic(300)``. A new search is built on every call.
    """
    all_commands = np.argwhere(np.ones((6, 5, 6)) == 1)
    searcher = mcts.MCTSearch(game, rollout, all_commands)
    return searcher.start_search_deterministic(300)
コード例 #5
0
ファイル: policy.py プロジェクト: pine73/alphazul
 def __call__(self, game):
     """Return the command chosen for ``game`` by an MCTS using ``self._infhelper``.

     A fresh search over the full 6x5x6 command grid is built on every call,
     and ``start_search_deterministic(300)`` is run on it.
     """
     every_command = np.argwhere(np.ones((6, 5, 6)) == 1)
     searcher = mcts.MCTSearch(game, self._infhelper, every_command)
     return searcher.start_search_deterministic(300)