def _prepare_training_data(self, samples):
    """Stack sampled (state, pi, w) tuples into network-ready inputs and targets."""
    inputs = []
    targets_w = []
    targets_pi = []
    env = Connect4Env(width=config.Width, height=config.Height)
    for sample in samples:
        inputs.append(utils.format_state(sample[0], env))
        targets_pi.append(sample[1])
        targets_w.append(sample[2])
    return np.vstack(inputs), [np.vstack(targets_w), np.vstack(targets_pi)]
def _symmetrize_steps(self, steps):
    """Double the training data by adding the left-right mirror of every step."""
    env = Connect4Env(width=config.Width, height=config.Height)
    for i in range(len(steps)):
        state = steps[i][0]
        prob = steps[i][1]
        symmetrical_state = env.get_mirror_state(state)
        symmetrical_prob = prob[::-1]  # mirroring the board reverses the column order of PI
        steps.append([
            symmetrical_state, symmetrical_prob, steps[i][2], steps[i][3]
        ])
    return steps
def run_episode(self):
    steps = []
    env = Connect4Env(width=config.Width, height=config.Height)
    mct = MCT(network=self.best_network)
    state = env.get_state()
    reward = 0
    result = 0
    while True:
        # MCTS
        for i in range(config.MCTS_Num):
            mct.search(state=state, reward=reward, result=result, env=env)
        # get PI from the MCT: explore (temperature 1) for the first 10 moves,
        # then play greedily (temperature 0)
        if len(steps) < 10:
            pi = mct.get_actions_probability(state=state,
                                             env=env,
                                             temperature=1)
        else:
            pi = mct.get_actions_probability(state=state,
                                             env=env,
                                             temperature=0)
        # add (state, PI and a placeholder for W) to memory
        steps.append([state, pi, None, env.get_current_player()])
        # choose an action based on PI
        action = np.random.choice(len(pi), p=pi)
        # take the action
        state, reward, result = env.step(action)
        logger.debug(action + 1)  # log the chosen column (1-based)
        logger.debug(env.to_str(state))
        # if the game is finished, back-fill the W placeholders
        if result != 0:
            steps = self._assign_w(steps=steps, winner=result)
            steps = self._symmetrize_steps(steps=steps)
            # logger.info(self.self_play_env.to_str(state))
            break
    for step in steps:
        self.memory.append(step)
        logger.debug('================================')
        logger.debug(env.to_str(step[0]))
        logger.debug('player: {}'.format(step[3]))
        logger.debug('probabilities: {}'.format(step[1]))
        logger.debug('value: {}'.format(step[2]))
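
# run_episode back-fills the W placeholders via self._assign_w, which is not part
# of this excerpt. A minimal sketch of such a back-fill, assuming the result codes
# used in compete_for_best_network below (1/2 = winning player, 3 = draw) and that
# step[3] holds the player to move (this is an illustration, not the original code):
def _assign_w(self, steps, winner):
    for step in steps:
        if winner == 3:           # assumed draw code
            step[2] = 0
        elif step[3] == winner:   # the player to move at this step eventually won
            step[2] = 1
        else:
            step[2] = -1
    return steps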
Example #4
def env(game):
    if game == 'tictactoe':
        return TictactoeEnv()
    elif game == 'connect4':
        return Connect4Env()
    raise ValueError('Unknown game: {}'.format(game))
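
# Usage sketch for the factory above (illustrative only; the variable names are mine):
game_env = env('connect4')     # -> Connect4Env instance
other_env = env('tictactoe')   # -> TictactoeEnv instance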
Example #5
    network = Network(input_dim=(7, 6, 1),  # 7x6 board with a single input plane
                      output_dim=7,         # one policy output per board column
                      layers_metadata=[{
                          'filters': 42,
                          'kernel_size': (4, 4)
                      }, {
                          'filters': 42,
                          'kernel_size': (4, 4)
                      }, {
                          'filters': 42,
                          'kernel_size': (4, 4)
                      }],
                      reg_const=0.6,
                      learning_rate=0.0005,
                      root_path=None)
    env = Connect4Env(width=7, height=6)
    mct = MCT(network=network)

    player = 1
    try:
        human_player = int(
            input(
                'Would you like to be the 1st player or the 2nd player (answer 1 or 2): '
            ))
        if human_player not in (1, 2):
            print("Sorry, I don't understand your answer. I will play with myself.")
    except ValueError:
        print("Sorry, I don't understand your answer. I will play with myself.")
        human_player = 3
    def compete_for_best_network(self, new_network, best_network):
        logger.info('Comparing networks....')
        mct_new = MCT(network=new_network)
        mct_best = MCT(network=best_network)
        players = [[mct_new, 0], [mct_best, 0]]
        env = Connect4Env(width=config.Width, height=config.Height)

        mct_new_wins = 0
        mct_best_wins = 0
        draw_games = 0
        for i in range(config.Compete_Game_Num):
            env.reset()
            state = env.get_state()
            reward = 0
            result = 0
            step = 0

            logger.debug('{} network gets the upper hand for this game.'.format(
                players[step % 2][0].network.name))
            while True:
                for _ in range(config.Test_MCTS_Num):
                    players[step % 2][0].search(state=state,
                                                reward=reward,
                                                result=result,
                                                env=env)
                prob = players[step % 2][0].get_actions_probability(
                    state=state, env=env, temperature=0)
                action = np.random.choice(len(prob), p=prob)
                state, reward, result = env.step(col_idx=action)
                if result == 1:
                    players[0][1] += 1
                    break
                elif result == 2:
                    players[1][1] += 1
                    break
                elif result == 3:
                    draw_games += 1
                    break
                else:
                    step += 1
            logger.debug(env.to_str())
            logger.debug(result)

            if players[0][0] == mct_new:
                mct_new_wins = players[0][1]
                mct_best_wins = players[1][1]
            else:
                mct_new_wins = players[1][1]
                mct_best_wins = players[0][1]

            logger.info(''.join(
                ('O' * mct_new_wins, 'X' * mct_best_wins, '-' * draw_games,
                 '.' * (config.Compete_Game_Num - i - 1))))

            if mct_best_wins / (mct_new_wins + mct_best_wins +
                                (config.Compete_Game_Num - i - 1)) >= (
                                    1 - config.Best_Network_Threshold):
                logger.info(
                    'new network has no hope of winning the comparison, so stopping early.'
                )
                break
            elif mct_new_wins / (mct_new_wins + mct_best_wins +
                                 (config.Compete_Game_Num - i -
                                  1)) > config.Best_Network_Threshold:
                logger.info(
                    'new network has already won the comparison, so stopping early.'
                )
                break
            else:
                players.reverse()

        compete_result = mct_new_wins / (mct_best_wins + mct_new_wins)
        logger.debug(
            'new network won {} games, best network won {} games, draw games are {}'
            .format(mct_new_wins, mct_best_wins, draw_games))
        logger.info('new network winning ratio is {}'.format(compete_result))

        is_update = compete_result > config.Best_Network_Threshold
        if is_update:
            self.best_network.replace_by(new_network)
            logger.info('Updated best network!!!!')
        else:
            # self.current_network.replace_by(self.best_network)
            logger.info('Discarded current network....')
        return is_update
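
# Illustration of the early-exit arithmetic above (made-up numbers, not from config):
# each ratio divides one side's wins by a denominator that hands every remaining game
# to the other side, so once that pessimistic share already clears its threshold the
# final outcome cannot change and the comparison stops.
compete_game_num = 20
best_network_threshold = 0.55
new_wins, best_wins, games_played = 3, 9, 13
remaining = compete_game_num - games_played                              # 7 games left
best_share_lower_bound = best_wins / (new_wins + best_wins + remaining)  # 9/19 ~ 0.47
print(best_share_lower_bound >= 1 - best_network_threshold)              # True -> stop early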
if __name__ == '__main__':

    training_flag = input(
        'Would you like to train the network before testing it (answer Y or N): '
    ).strip().upper() == 'Y'

    best_network = Network('Best')

    if training_flag:
        training = Training(best_network)
        time.sleep(10)
        training.train()
    # ==========================================
    player = 1
    env = Connect4Env(width=config.Width, height=config.Height)
    mct = MCT(network=best_network)
    reward = 0
    result = 0
    try:
        human_player = int(
            input(
                'Would you like to be the 1st player or the 2nd player (answer 1 or 2): '
            ))
        if human_player not in (1, 2):
            print("Sorry, I don't understand your answer. I will play with myself.")
    except ValueError:
        print("Sorry, I don't understand your answer. I will play with myself.")
        human_player = 3
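
    # The excerpt stops before the actual play loop. Purely as an illustration (not
    # the original code), a loop in this style could look like the following, reusing
    # the step/result conventions visible in the examples above:
    state = env.get_state()
    while True:
        if player == human_player:
            action = int(input('Your move (column 1-{}): '.format(config.Width))) - 1
        else:
            for _ in range(config.Test_MCTS_Num):
                mct.search(state=state, reward=reward, result=result, env=env)
            prob = mct.get_actions_probability(state=state, env=env, temperature=0)
            action = np.random.choice(len(prob), p=prob)
        state, reward, result = env.step(action)
        print(env.to_str(state))
        if result != 0:          # 1/2: that player won; 3: draw (same codes as above)
            break
        player = 2 if player == 1 else 1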
Example #8
    for i in range(4):
        for j in range(5):
            env.step(j+1)
    env.render()

def test_rand():
    """play randomly"""
    env.reset()
    while not env.done:
        print(env.legal_moves())
        print(env.board)
        env.step(choice(env.legal_moves()))
    print(env.done)
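
# Hypothetical sketch (not in the original file): a vertical-win test in the same
# style as the helpers called below, assuming env.step alternates players and
# accepts the 1-based column numbers used above.
def test_vert_sketch():
    """drop four pieces in one column and expect the game to end"""
    env.reset()
    for _ in range(3):
        env.step(1)   # player 1 stacks column 1
        env.step(2)   # player 2 answers in column 2
    env.step(1)       # fourth piece in column 1 completes a vertical four
    print(env.done)   # expected: True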

print(chr(27) + "[2J")  # ANSI escape sequence: clear the terminal
env = Connect4Env()
"""
print("\ntesting init")
print(env.board)
print("\ntesting legal moves")
test_legal()
print("\ntesting step")
test_step()
print("\ntesting render")
test_render()
print("\ntesting vertical check")
test_vert()
print("\ntesting horizontal check")
test_horiz()
print("\ntesting diagonal check")
test_diag()