Beispiel #1
0
    def set_batch_canonical_form(batch_states, batch_group_maps, player):
        """
        Put a batch of states and their group maps into canonical form, in place.

        Every state in the batch is assumed to have ``player`` to move. When
        ``player`` is ``govars.WHITE``:
        - the ``govars.BLACK`` and ``govars.WHITE`` channels are swapped;
        - ``state_utils.batch_set_turn`` is applied to the batch;
        - every group map in ``batch_group_maps`` has its entries reversed.
        For any other ``player`` nothing is touched.

        No copies are made: ``batch_states`` and the group-map lists are
        modified directly, and the method returns None.

        :param batch_states: numpy array of states indexed as
            ``batch_states[i, channel, ...]``.
        :param batch_group_maps: one reversible (list) group map per state.
        :param player: the player whose turn it is in every state.
        :return: None
        """
        if player != govars.WHITE:
            return

        # Advanced indexing on the right-hand side yields a copy, so the
        # assignment below swaps the two channels without aliasing issues.
        swapped_channels = batch_states[:, [govars.WHITE, govars.BLACK]]
        batch_states[:, [govars.BLACK, govars.WHITE]] = swapped_channels
        state_utils.batch_set_turn(batch_states)
        for per_state_maps in batch_group_maps:
            per_state_maps.reverse()
Beispiel #2
0
def batch_next_states(batch_states, batch_action1d, canonical=False):
    """
    Apply one action to every state of a batch and return the resulting states.

    The input array is not modified; it is copied with ``np.copy`` first.

    :param batch_states: numpy array of states indexed as
        ``batch_states[i, channel, row, col]``.
    :param batch_action1d: numpy integer array with one flattened action per
        state: ``row * n_cols + col`` for a board move, or
        ``n_rows * n_cols`` for a pass.
    :param canonical: if True, the result is passed through
        ``batch_canonical_form`` before being returned.
    :return: the new batch of states.
    :raises AssertionError: if a non-pass action targets a location whose
        ``govars.INVD_CHNL`` entry is non-zero.
    """
    # Deep copy the state to modify
    batch_states = np.copy(batch_states)

    # Initialize basic variables
    board_shape = batch_states.shape[2:]
    n_cols = board_shape[1]
    pass_idx = np.prod(board_shape)
    is_pass = batch_action1d == pass_idx
    # np.nonzero returns a tuple of index arrays; take the single (batch) axis
    # so every index below is a flat 1-D array of batch positions.
    batch_pass = np.nonzero(is_pass)[0]
    batch_non_pass = np.nonzero(batch_action1d != pass_idx)[0]
    batch_prev_passed = batch_prev_player_passed(batch_states)
    batch_game_ended = np.nonzero(batch_prev_passed & is_pass)[0]

    # Row-major unflattening: row = a // n_cols, col = a % n_cols.
    # Dividing by the number of rows would only be correct on square boards.
    non_pass_actions = batch_action1d[batch_non_pass]
    batch_action2d = np.array([
        non_pass_actions // n_cols,
        non_pass_actions % n_cols,
    ]).T

    batch_players = batch_turn(batch_states)
    batch_non_pass_players = batch_players[batch_non_pass]
    batch_ko_protect = np.empty(len(batch_states), dtype=object)

    # Pass moves
    batch_states[batch_pass, govars.PASS_CHNL] = 1
    # Game ended: a pass right after the previous player passed
    batch_states[batch_game_ended, govars.DONE_CHNL] = 1

    # Non-pass moves
    batch_states[batch_non_pass, govars.PASS_CHNL] = 0

    # Assert all non-pass moves are valid
    assert (batch_states[batch_non_pass, govars.INVD_CHNL,
                         batch_action2d[:, 0], batch_action2d[:, 1]] == 0).all()

    # Add piece
    batch_states[batch_non_pass, batch_non_pass_players, batch_action2d[:, 0],
                 batch_action2d[:, 1]] = 1

    # Get adjacent location and check whether the piece will be surrounded by opponent's piece
    batch_adj_locs, batch_surrounded = state_utils.batch_adj_data(
        batch_states[batch_non_pass], batch_action2d, batch_non_pass_players)

    # Update pieces
    batch_killed_groups = state_utils.batch_update_pieces(
        batch_non_pass, batch_states, batch_adj_locs, batch_non_pass_players)

    # Ko-protection. batch_killed_groups / batch_surrounded follow the order of
    # batch_non_pass, so position i maps back to batch index batch_non_pass[i].
    for i, (killed_groups, surrounded) in enumerate(zip(batch_killed_groups,
                                                        batch_surrounded)):
        # If only killed one group, and that one group was one piece, and piece set is surrounded,
        # activate ko protection
        if len(killed_groups) == 1 and surrounded:
            killed_group = killed_groups[0]
            if len(killed_group) == 1:
                batch_ko_protect[batch_non_pass[i]] = killed_group[0]

    # Update invalid moves
    batch_states[:, govars.INVD_CHNL] = state_utils.batch_compute_invalid_moves(
        batch_states, batch_players, batch_ko_protect)

    # Switch turn
    state_utils.batch_set_turn(batch_states)

    if canonical:
        # Set canonical form
        batch_states = batch_canonical_form(batch_states)

    return batch_states
Beispiel #3
0
    def get_next_states(state, batch_action1d, group_map=None, canonical=False):
        """
        Compute the successor of ``state`` for every action in ``batch_action1d``.

        Does not change the given state: it is tiled into a fresh batch with
        ``np.tile``. Each successor gets its own copies of the two group-map
        containers. Groups that must change are replaced by ``.copy()``
        versions before they are mutated in this method.

        :param state: a single game state indexed as ``state[channel, row, col]``.
        :param batch_action1d: array-like of integer actions, one per
            successor: ``row * n + col`` for a board move, ``m * n`` for a pass.
        :param group_map: two-element group map for ``state``; computed with
            ``state_utils.get_group_map`` when None.
        :param canonical: if True, ``GoGame.set_batch_canonical_form`` is
            applied to the successors and their group maps.
        :return: ``(batch_states, batch_group_maps)``.
        :raises AssertionError: if a non-pass action is marked invalid in ``state``.
        """
        if group_map is None:
            group_map = state_utils.get_group_map(state)

        # Accept any array-like of actions (e.g. a plain list)
        batch_action1d = np.asarray(batch_action1d)

        m, n = state_utils.get_board_size(state)
        pass_action = m * n
        is_pass = batch_action1d == pass_action
        # np.where returns a tuple of index arrays; keep the single (batch) axis
        pass_idcs = np.where(is_pass)[0]
        non_pass_idcs = np.where(~is_pass)[0]

        batch_size = len(batch_action1d)
        board_shape = state.shape[1:]
        # np.int was only an alias of the builtin int and was removed in NumPy 1.24
        batch_action2d = np.empty((batch_size, 2), dtype=int)
        batch_action2d[:, 0] = batch_action1d // n
        batch_action2d[:, 1] = batch_action1d % n
        # Passes have no board location; park them at (0, 0) so they index in bounds
        batch_action2d[pass_idcs] = 0

        player = state_utils.get_turn(state)
        opponent = 1 - player
        previously_passed = GoGame.get_prev_player_passed(state)

        # Independent copy of the state for every action
        batch_states = np.tile(state, (batch_size, 1, 1, 1))

        # Check move is valid
        assert (batch_states[non_pass_idcs, govars.INVD_CHNL, batch_action2d[non_pass_idcs, 0],
                             batch_action2d[non_pass_idcs, 1]] == 0).all(), "Invalid move"

        # Per-successor group maps: the containers are copied, group objects are shared until modified
        batch_group_maps = [[group_map[0].copy(), group_map[1].copy()] for _ in range(batch_size)]
        batch_single_kill = [None] * batch_size
        batch_killed_groups = [set() for _ in range(batch_size)]

        batch_adj_locs, batch_surrounded = state_utils.get_batch_adj_data(state, batch_action2d)

        if previously_passed:
            # A pass directly after the previous player's pass sets the done channel
            batch_states[pass_idcs, govars.DONE_CHNL] = 1
        else:
            batch_states[pass_idcs, govars.PASS_CHNL] = 1

        # Non passes
        batch_states[non_pass_idcs, govars.PASS_CHNL] = 0

        # Add pieces
        batch_states[non_pass_idcs, player, batch_action2d[non_pass_idcs, 0], batch_action2d[non_pass_idcs, 1]] = 1

        batch_data = enumerate(zip(batch_action1d, batch_action2d, batch_states, batch_group_maps, batch_adj_locs,
                                   batch_killed_groups))

        # next_state is a view of batch_states[i], so edits below land in the batch.
        # Distinct names avoid shadowing the `state` / `group_map` parameters.
        for i, (action_1d, action_2d, next_state, next_group_map, adj_locs, killed_groups) in batch_data:
            # if the current player passes
            if action_1d == pass_action:
                continue

            action_2d = tuple(action_2d)

            # Get all adjacent information
            adj_own_groups, adj_opp_groups = state_utils.get_adjacent_groups(next_group_map, adj_locs, player)

            # Go through opponent groups
            for group in adj_opp_groups:
                assert action_2d in group.liberties, (action_2d, player, group, next_state)
                if len(group.liberties) <= 1:
                    # Killed group
                    killed_groups.add(group)

                    # Remove group in board and group map
                    for loc in group.locations:
                        next_state[opponent, loc[0], loc[1]] = 0
                    next_group_map[opponent].remove(group)

                    # Metric for ko-protection
                    if len(group.locations) <= 1 and batch_single_kill[i] is None:
                        batch_single_kill[i] = next(iter(group.locations))

            adj_opp_groups.difference_update(killed_groups)

            # Update surviving adjacent opponent groups by removing liberties by the new action
            for opp_group in adj_opp_groups:
                assert action_2d in opp_group.liberties, (action_2d, opp_group, adj_opp_groups)

                # New group copy (groups may be shared with other successors)
                next_group_map[opponent].remove(opp_group)
                opp_group = opp_group.copy()
                next_group_map[opponent].add(opp_group)

                opp_group.liberties.remove(action_2d)

            # Update adjacent own groups that are merged with the action
            if len(adj_own_groups) > 0:
                merged_group = adj_own_groups.pop()
                next_group_map[player].remove(merged_group)
                merged_group = merged_group.copy()
            else:
                merged_group = govars.Group()

            next_group_map[player].add(merged_group)

            # Locations from action and adjacent groups
            merged_group.locations.add(action_2d)

            for own_group in adj_own_groups:
                merged_group.locations.update(own_group.locations)
                merged_group.liberties.update(own_group.liberties)
                next_group_map[player].remove(own_group)

            # Liberties from action: adjacent points with no piece of either colour
            for adj_loc in adj_locs:
                if np.count_nonzero(next_state[[govars.BLACK, govars.WHITE], adj_loc[0], adj_loc[1]]) == 0:
                    merged_group.liberties.add(adj_loc)

            if action_2d in merged_group.liberties:
                merged_group.liberties.remove(action_2d)

            # More work to do if we killed
            if len(killed_groups) > 0:
                killed_map = np.zeros(board_shape)
                for group in killed_groups:
                    for loc in group.locations:
                        killed_map[loc] = 1
                # Update own groups adjacent to opponent groups that we just killed
                killed_liberties = ndimage.binary_dilation(killed_map)
                affected_idcs = set(zip(*np.nonzero(next_state[player] * killed_liberties)))
                groups_to_update = set()
                for group in next_group_map[player]:
                    if not affected_idcs.isdisjoint(group.locations):
                        groups_to_update.add(group)

                all_pieces = np.sum(next_state[[govars.BLACK, govars.WHITE]], axis=0)
                empties = (1 - all_pieces)
                for group in groups_to_update:
                    group_matrix = np.zeros(board_shape)
                    for loc in group.locations:
                        group_matrix[loc] = 1

                    additional_liberties = ndimage.binary_dilation(group_matrix) * empties * killed_map
                    additional_liberties = set(zip(*np.where(additional_liberties)))

                    next_group_map[player].remove(group)
                    group = group.copy()
                    next_group_map[player].add(group)

                    group.liberties.update(additional_liberties)

        # Update illegal moves
        batch_states[:, govars.INVD_CHNL] = state_utils.get_batch_invalid_moves(batch_states, batch_group_maps, player)

        # If group was one piece, and location is surrounded by opponents,
        # activate ko protection
        for i, (single_kill, killed_groups, surrounded) in enumerate(zip(batch_single_kill, batch_killed_groups,
                                                                         batch_surrounded)):
            if single_kill is not None and len(killed_groups) == 1 and surrounded:
                # Index by i so the ko point is marked on this successor, not on
                # whichever state a leftover loop variable happens to reference.
                batch_states[i, govars.INVD_CHNL, single_kill[0], single_kill[1]] = 1

        # Switch turn
        state_utils.batch_set_turn(batch_states)

        if canonical:
            GoGame.set_batch_canonical_form(batch_states, batch_group_maps, opponent)

        return batch_states, batch_group_maps