Example 1
def get_weighted_legal(self, state, action_probs):
    """Sample a legal action index, weighted by the policy's probabilities."""
    failed_attempts = 0
    index = -1
    legal = False
    while not legal:
        if index != -1:
            # The last sampled action was illegal.  Zero it out so it can't be
            # picked again and re-normalise the remaining probabilities.
            failed_attempts += 1
            action_probs[index] = 0
            total_action_probs = action_probs.sum()
            if (total_action_probs <= 0 or np.isnan(total_action_probs)
                    or failed_attempts > 192):
                # No probability mass left (or too many retries).  Log the
                # diagnostics and fall back to a linear scan for a legal move.
                orig_action_probs = self.get_action_probs(state)
                orig_best = np.argmax(orig_action_probs)
                log('Oh dear - problem getting legal action in state...')
                print(state)
                log('Action probabilities (after %d adjustment(s)) are...'
                    % (failed_attempts))
                log(action_probs)
                log('Action probabilities (before adjustment) were...')
                log(orig_action_probs)
                log('Best original action was %s (%d)' % (lg.encode_move(
                    bt.convert_index_to_move(orig_best,
                                             state.player)), orig_best))
                log('Last cleared action was %s (%d)' % (lg.encode_move(
                    bt.convert_index_to_move(index, state.player)), index))
                log('Total = %f' % (total_action_probs))
                return self._first_legal_action(state)
            action_probs /= total_action_probs
        index = np.random.choice(bt.ACTIONS, p=action_probs)
        legal = state.is_legal(
            bt.convert_index_to_move(index, state.player))
    return index
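The mask-and-renormalise loop above is easy to exercise in isolation. Below is a minimal standalone sketch of the same technique, where a toy is_legal predicate stands in for the game's legality test; the names here are illustrative, not the project's API.

import numpy as np

def sample_legal(probs, is_legal, max_attempts=192):
    # Sample an index weighted by probs, zeroing out and renormalising
    # after each illegal pick, as get_weighted_legal does.
    probs = np.array(probs, dtype=float)
    for _ in range(max_attempts):
        total = probs.sum()
        if total <= 0 or np.isnan(total):
            break  # probability mass exhausted - the caller must fall back
        index = np.random.choice(len(probs), p=probs / total)
        if is_legal(index):
            return index
        probs[index] = 0
    return None

# Toy usage: only even indices are "legal".
print(sample_legal([0.1, 0.4, 0.2, 0.3], lambda i: i % 2 == 0))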
Example 2
def _first_legal_action(self, state):
    """Fallback: scan all actions in index order and return the first legal one."""
    for index in range(bt.ACTIONS):
        if state.is_legal(bt.convert_index_to_move(index, state.player)):
            return index
    log('No legal action at all in state...')
    print(state)
    return -1
Example 3
def get_best_legal(state, policy):
    """Return the most probable legal action.  Assumes at least one exists."""
    index = -1
    action_probs = policy.get_action_probs(state)
    legal = False
    while not legal:
        if index != -1:
            # The last candidate was illegal - mask it out and take the next best.
            action_probs[index] = 0
        index = np.argmax(action_probs)
        legal = state.is_legal(bt.convert_index_to_move(index, state.player))
    return index
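For comparison, the greedy selection can avoid the mutate-and-retry loop entirely by walking the actions in descending probability order, which also terminates cleanly when nothing is legal. A standalone sketch under the same toy assumptions as above:

import numpy as np

def best_legal(probs, is_legal):
    # Visit actions from most to least probable; return the first legal one.
    for index in np.argsort(probs)[::-1]:
        if is_legal(index):
            return int(index)
    return None  # no legal action at all

# Toy usage: the top-scoring action (index 2) is illegal, so index 0 is chosen.
print(best_legal([0.3, 0.1, 0.6], lambda i: i != 2))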
Example 4
def compare_policies_in_parallel(our_policy, their_policy, num_matches=100):
    states = [bt.Breakthrough() for _ in range(num_matches)]

    # We play first in all the even-numbered games; they play first in all the
    # odd ones.  Advance the odd-numbered games by a turn so that it's our turn
    # in every game.
    for state in states[1::2]:
        index = their_policy.get_action_index(state)
        state.apply(bt.convert_index_to_move(index, state.player))

    # Roll all the games out to completion.
    current_policy = our_policy
    other_policy = their_policy

    move_made = True
    while move_made:
        # Compute the next move for every unfinished game in a single batched call.
        move_made = False
        for (state, action) in zip(states,
                                   current_policy.get_action_indicies(states)):
            if not state.terminated:
                state.apply(bt.convert_index_to_move(action, state.player))
                move_made = True

        # Now it's the other player's turn, so swap policies.
        current_policy, other_policy = other_policy, current_policy

    # We were player 0 in the even-numbered games and player 1 in the odd ones.
    wins = 0
    p0_wins = 0
    p1_wins = 0
    for state in states[0::2]:
        if state.is_win_for(0):
            wins += 1
            p0_wins += 1
    for state in states[1::2]:
        if state.is_win_for(1):
            wins += 1
            p1_wins += 1

    log('Wins, Wins as p0, Wins as p1 = %d, %d, %d' % (wins, p0_wins, p1_wins))
    return wins / num_matches
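A plausible way to use this for checkpoint gating is sketched below. CNPolicy(checkpoint=...) is the constructor seen later in Example 6, but the candidate checkpoint name and the promotion threshold are invented for illustration.

candidate = CNPolicy(checkpoint='candidate.hdf5')  # hypothetical checkpoint
champion = CNPolicy(checkpoint=PRIMARY_CHECKPOINT)
win_rate = compare_policies_in_parallel(candidate, champion, num_matches=200)
if win_rate > 0.55:  # arbitrary gating threshold, not from the source
    log('Candidate wins %.0f%% of matches - promote it' % (win_rate * 100))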
Example 5
def evaluate_for(initial_state, policy, player, num_rollouts=1000):
    """Estimate player's win rate from initial_state via batched policy rollouts."""
    states = [bt.Breakthrough(initial_state) for _ in range(num_rollouts)]

    move_made = True
    while move_made:
        move_made = False
        for (state, action) in zip(states, policy.get_action_indicies(states)):
            if not state.terminated:
                state.apply(bt.convert_index_to_move(action, state.player))
                move_made = True

    wins = 0
    for state in states:
        if state.is_win_for(player):
            wins += 1
    return wins / num_rollouts
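With a finite number of rollouts the returned estimate is noisy, and the binomial standard error gives a quick feel for how noisy. A small standalone helper, not part of the source:

from math import sqrt

def win_rate_stderr(win_rate, num_rollouts):
    # Standard error of a Bernoulli mean estimated from num_rollouts samples.
    return sqrt(win_rate * (1.0 - win_rate) / num_rollouts)

# A 60% win rate over 1000 rollouts is only pinned down to about +/-1.5%.
print('%.3f' % win_rate_stderr(0.6, 1000))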
Example 6
def predict():
    # Load the trained policy
    policy = CNPolicy(checkpoint=PRIMARY_CHECKPOINT)

    # Advance the game to the desired state
    history = input('Input game history: ')
    state = bt.Breakthrough()
    for part in history.split(' '):
        if len(part) == 5:
            # Moves are encoded as 5-character strings; skip anything else.
            state = bt.Breakthrough(state, lg.decode_move(part))
    print(state)

    # Rollout rewards are from player 0's perspective, so the reward we want
    # depends on whose turn it is.
    desired_reward = 1 if state.player == 0 else -1

    # Evaluate the five most probable next moves
    prediction = policy.get_action_probs(state)
    sorted_indices = np.argsort(prediction)[::-1][0:5]
    for index in sorted_indices:
        trial_state = bt.Breakthrough(
            state, bt.convert_index_to_move(index, state.player))
        greedy_win = rollout(policy, trial_state,
                             greedy=True) == desired_reward
        win_rate = evaluate_for(trial_state, policy, state.player)
        state_value = policy.get_state_value(trial_state)
        log("Play %s with probability %f (%s) for win rate %d%% and state-value %d%%"
            % (convert_index_to_move(index, state.player), prediction[index],
               '*' if greedy_win else '!', int(win_rate * 100),
               int(state_value * 50) + 50))  # Scale from [-1,+1] to [0,100]

    #log('MCTS evaluation...')
    #tree = mcts.MCTSTrainer(policy)
    #tree.prepare_for_eval(state)
    #tree.iterate(state=state, num_iterations=50000)

    _ = input('Press enter to play on')
    rollout(policy, state, greedy=True, show=True)
Example 7
def convert_index_to_move(index, player):
    """Convert an action index into a human-readable move string."""
    move = bt.convert_index_to_move(index, player)
    return lg.encode_move(move)
Example 8
    def do_POST(self):
        # Handle a message from a GGP-style match server (INFO / START / PLAY).
        global state
        global role
        global policy

        print('POST to %s' % self.path)
        length = self.headers.get('content-length')
        request = self.rfile.read(int(length)).decode('ascii')
        print('Request was %s' % request)

        response = ''
        if request.startswith('( INFO )'):
            response = '( (name Mimic) (status available) )'
        elif request.startswith('( START '):
            state = bt.Breakthrough()
            if request.split(sep=' ')[3] == 'white':
                print('We are white')
                role = 0
            else:
                print('We are black')
                role = 1
            response = 'ready'
        elif request.startswith('( PLAY '):
            parsed_play_req = re.match(r'\( PLAY [^ ]* (.*)', request)
            # Parse out the last move (remembering it's NIL at the start of the game)
            move = parsed_play_req.group(1)
            move = move.replace('noop', '').replace('move', '').replace(
                '(', '').replace(')', '').replace(' ', '')
            if move != 'NIL':
                print('Raw move was %s' % move)
                move = list(move)
                # Convert the server's 1-based, mirrored coordinates into our
                # 0-based (row, col) representation.
                src_col = 8 - int(move[0])
                src_row = int(move[1]) - 1
                dst_col = 8 - int(move[2])
                dst_row = int(move[3]) - 1
                move = (src_row, src_col, dst_row, dst_col)
                print('Move was %s' % lg.encode_move(move))
                state.apply(move)
            else:
                print('First move')

            if state.player == role:
                # Our turn: play the policy's most probable move.
                prediction = policy.get_action_probs(state)
                index = np.argsort(prediction)[::-1][0]
                move = bt.convert_index_to_move(index, state.player)
                print('Playing %s' % (lg.encode_move(move)))
                # Convert back to the server's 1-based, mirrored coordinates.
                src_row = move[0] + 1
                src_col = 8 - move[1]
                dst_row = move[2] + 1
                dst_col = 8 - move[3]
                response = '( move %d %d %d %d )' % (src_col, src_row, dst_col,
                                                     dst_row)
            else:
                print('Not our turn - no-op')
                response = 'noop'

        print('Responding with %s' % response)
        response_bytes = response.encode('ascii')
        self.send_response(200)
        self.send_header('Content-Length', len(response_bytes))
        self.end_headers()
        self.wfile.write(response_bytes)
        self.wfile.flush()
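do_POST is a method of an http.server request handler, but the example doesn't show how the server itself is constructed. A plausible wiring is sketched below; the class name is an assumption, and 9147 is simply the port conventionally used by GGP players.

from http.server import BaseHTTPRequestHandler, HTTPServer

class MimicHandler(BaseHTTPRequestHandler):
    def do_POST(self):
        ...  # body exactly as in the example above

if __name__ == '__main__':
    HTTPServer(('', 9147), MimicHandler).serve_forever()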