def set_batch_canonical_form(batch_states, batch_group_maps, player):
    """
    Put a batch of states (and their group maps) into canonical form, in place.

    Assumes the turn of all states is player.
    Nothing is copied and nothing is returned: the arguments themselves are
    modified. When player is govars.WHITE, the black and white channels of
    every state are swapped, state_utils.batch_set_turn is called on the batch
    and every group map is reversed. For any other player the arguments are
    left untouched.

    :param batch_states: batch of states, indexed as [batch, channel, ...]
    :param batch_group_maps: iterable of group maps (one per state), each
        supporting an in-place reverse()
    :param player: the player whose turn it is in every state
    :return: None
    """
    if player != govars.WHITE:
        # Already seen from black's side; nothing to do
        return

    piece_channels = [govars.BLACK, govars.WHITE]
    # The fancy-indexed right-hand side is a copy, so this is a safe swap
    batch_states[:, piece_channels] = batch_states[:, piece_channels[::-1]]
    state_utils.batch_set_turn(batch_states)
    for per_state_map in batch_group_maps:
        per_state_map.reverse()
def batch_next_states(batch_states, batch_action1d, canonical=False):
    """
    Compute the successor of every state in a batch, one action per state.

    The given batch is not modified; all updates are applied to a copy.

    :param batch_states: array of states indexed as [batch, channel, row, col]
    :param batch_action1d: one flattened action per state. The value
        rows * cols means "pass"; any other value is decoded row-major as
        (action // cols, action % cols)
    :param canonical: if True, the result is passed through
        batch_canonical_form before being returned
    :return: the batch of next states
    :raises AssertionError: if a non-pass action targets a location that is
        set in govars.INVD_CHNL
    """
    # Deep copy the state to modify
    batch_states = np.copy(batch_states)
    # Accept plain sequences as well as arrays
    batch_action1d = np.asarray(batch_action1d)

    # Initialize basic variables
    board_shape = batch_states.shape[2:]
    pass_idx = np.prod(board_shape)
    is_pass = batch_action1d == pass_idx
    batch_pass = np.nonzero(is_pass)[0]
    batch_non_pass = np.nonzero(~is_pass)[0]
    # presumably a boolean vector with one entry per state -- TODO confirm
    batch_prev_passed = batch_prev_player_passed(batch_states)
    batch_game_ended = np.nonzero(batch_prev_passed & is_pass)[0]

    # Row-major decoding: both the row and the column are derived from the
    # number of columns, so that non-square boards decode correctly
    non_pass_actions = batch_action1d[batch_non_pass]
    num_cols = board_shape[1]
    batch_action2d = np.array([
        non_pass_actions // num_cols,
        non_pass_actions % num_cols,
    ]).T

    batch_players = batch_turn(batch_states)
    batch_non_pass_players = batch_players[batch_non_pass]
    # Object array so that each entry is either None or a board location
    batch_ko_protect = np.empty(len(batch_states), dtype=object)

    # Pass moves
    batch_states[batch_pass, govars.PASS_CHNL] = 1
    # Game ended
    batch_states[batch_game_ended, govars.DONE_CHNL] = 1

    # Non-pass moves
    batch_states[batch_non_pass, govars.PASS_CHNL] = 0

    # Assert all non-pass moves are valid
    assert (batch_states[batch_non_pass, govars.INVD_CHNL,
                         batch_action2d[:, 0], batch_action2d[:, 1]] == 0).all(), "Invalid move"

    # Add piece
    batch_states[batch_non_pass, batch_non_pass_players,
                 batch_action2d[:, 0], batch_action2d[:, 1]] = 1

    # Get adjacent location and check whether the piece will be surrounded by opponent's piece
    batch_adj_locs, batch_surrounded = state_utils.batch_adj_data(
        batch_states[batch_non_pass], batch_action2d, batch_non_pass_players)

    # Update pieces
    batch_killed_groups = state_utils.batch_update_pieces(
        batch_non_pass, batch_states, batch_adj_locs, batch_non_pass_players)

    # Ko-protection
    for i, (killed_groups, surrounded) in enumerate(zip(batch_killed_groups, batch_surrounded)):
        # If only killed one group, and that one group was one piece, and piece set is surrounded,
        # activate ko protection
        if len(killed_groups) == 1 and surrounded:
            killed_group = killed_groups[0]
            if len(killed_group) == 1:
                # i indexes the non-pass subset; map it back to the batch index
                batch_ko_protect[batch_non_pass[i]] = killed_group[0]

    # Update invalid moves
    batch_states[:, govars.INVD_CHNL] = state_utils.batch_compute_invalid_moves(
        batch_states, batch_players, batch_ko_protect)

    # Switch turn
    state_utils.batch_set_turn(batch_states)

    if canonical:
        # Set canonical form
        batch_states = batch_canonical_form(batch_states)

    return batch_states
def get_next_states(state, batch_action1d, group_map=None, canonical=False):
    """
    Compute, from one starting state, the next state and group map for each
    action in a batch.

    Does not change the given state: every action is applied to its own tiled
    copy of the state, and groups are copied before they are modified.

    :param state: the starting state, indexed as [channel, row, col]
    :param batch_action1d: flattened actions to try from state. The value
        m * n (board rows times columns) means "pass"; any other value is
        decoded as (action // n, action % n)
    :param group_map: pair of group sets for state, indexed by player; computed
        with state_utils.get_group_map when None
    :param canonical: if True, the results are put into canonical form through
        GoGame.set_batch_canonical_form
    :return: (batch_states, batch_group_maps), one entry per action
    :raises AssertionError: if a non-pass action targets a location set in
        govars.INVD_CHNL, or if the group map is inconsistent with the board
    """
    # Accept plain sequences as well as arrays
    batch_action1d = np.asarray(batch_action1d)

    if group_map is None:
        group_map = state_utils.get_group_map(state)

    m, n = state_utils.get_board_size(state)
    pass_action = m * n
    pass_idcs = np.nonzero(batch_action1d == pass_action)[0]
    non_pass_idcs = np.nonzero(batch_action1d != pass_action)[0]

    batch_size = len(batch_action1d)
    board_shape = state.shape[1:]

    # np.int was removed from NumPy; the builtin int is the drop-in replacement
    batch_action2d = np.empty((batch_size, 2), dtype=int)
    batch_action2d[:, 0] = batch_action1d // n
    batch_action2d[:, 1] = batch_action1d % n
    # Passes have no board location; park them on (0, 0) so indexing stays in range
    batch_action2d[pass_idcs] = 0

    player = state_utils.get_turn(state)
    opponent = 1 - player
    previously_passed = GoGame.get_prev_player_passed(state)

    batch_states = np.tile(state, (batch_size, 1, 1, 1))

    # Check move is valid
    assert (batch_states[non_pass_idcs, govars.INVD_CHNL,
                         batch_action2d[non_pass_idcs, 0],
                         batch_action2d[non_pass_idcs, 1]] == 0).all(), "Invalid move"

    # Each action gets its own (shallow) copy of both players' group sets
    batch_group_maps = [[group_map[0].copy(), group_map[1].copy()] for _ in range(batch_size)]
    batch_single_kill = [None for _ in range(batch_size)]
    batch_killed_groups = [set() for _ in range(batch_size)]

    batch_adj_locs, batch_surrounded = state_utils.get_batch_adj_data(state, batch_action2d)

    if previously_passed:
        batch_states[pass_idcs, govars.DONE_CHNL] = 1
    else:
        batch_states[pass_idcs, govars.PASS_CHNL] = 1

    # Non passes
    batch_states[non_pass_idcs, govars.PASS_CHNL] = 0

    # Add pieces
    batch_states[non_pass_idcs, player,
                 batch_action2d[non_pass_idcs, 0],
                 batch_action2d[non_pass_idcs, 1]] = 1

    batch_data = enumerate(zip(batch_action1d, batch_action2d, batch_states, batch_group_maps,
                               batch_adj_locs, batch_killed_groups))
    # next_state is a view into batch_states, so edits below land in the batch
    for i, (action_1d, action_2d, next_state, next_group_map, adj_locs, killed_groups) in batch_data:
        # if the current player passes
        if action_1d == pass_action:
            continue

        action_2d = tuple(action_2d)

        # Get all adjacent information
        adj_own_groups, adj_opp_groups = state_utils.get_adjacent_groups(next_group_map, adj_locs, player)

        # Go through opponent groups
        for group in adj_opp_groups:
            assert action_2d in group.liberties, (action_2d, player, group, next_state)
            if len(group.liberties) <= 1:
                # Killed group
                killed_groups.add(group)

                # Remove group in board and group map
                for loc in group.locations:
                    next_state[opponent, loc[0], loc[1]] = 0
                next_group_map[opponent].remove(group)

                # Metric for ko-protection
                if len(group.locations) <= 1 and batch_single_kill[i] is None:
                    batch_single_kill[i] = next(iter(group.locations))
        adj_opp_groups.difference_update(killed_groups)

        # Update surviving adjacent opponent groups by removing liberties by the new action
        for opp_group in adj_opp_groups:
            assert action_2d in opp_group.liberties, (action_2d, opp_group, adj_opp_groups)
            # New group copy
            next_group_map[opponent].remove(opp_group)
            opp_group = opp_group.copy()
            next_group_map[opponent].add(opp_group)

            opp_group.liberties.remove(action_2d)

        # Update adjacent own groups that are merged with the action
        if len(adj_own_groups) > 0:
            merged_group = adj_own_groups.pop()
            next_group_map[player].remove(merged_group)
            merged_group = merged_group.copy()
        else:
            merged_group = govars.Group()

        next_group_map[player].add(merged_group)

        # Locations from action and adjacent groups
        merged_group.locations.add(action_2d)

        for own_group in adj_own_groups:
            merged_group.locations.update(own_group.locations)
            merged_group.liberties.update(own_group.liberties)
            next_group_map[player].remove(own_group)

        # Liberties from action
        for adj_loc in adj_locs:
            if np.count_nonzero(next_state[[govars.BLACK, govars.WHITE], adj_loc[0], adj_loc[1]]) == 0:
                merged_group.liberties.add(adj_loc)

        # The played location may have been a liberty of a merged group
        if action_2d in merged_group.liberties:
            merged_group.liberties.remove(action_2d)

        # More work to do if we killed
        if len(killed_groups) > 0:
            killed_map = np.zeros(board_shape)
            for group in killed_groups:
                for loc in group.locations:
                    killed_map[loc] = 1
            # Update own groups adjacent to opponent groups that we just killed
            killed_liberties = ndimage.binary_dilation(killed_map)
            affected_idcs = set(zip(*np.nonzero(next_state[player] * killed_liberties)))
            groups_to_update = set()
            for group in next_group_map[player]:
                if not affected_idcs.isdisjoint(group.locations):
                    groups_to_update.add(group)

            all_pieces = np.sum(next_state[[govars.BLACK, govars.WHITE]], axis=0)
            empties = (1 - all_pieces)
            for group in groups_to_update:
                group_matrix = np.zeros(board_shape)
                for loc in group.locations:
                    group_matrix[loc] = 1

                # Newly freed points: next to the group, empty, and vacated by a kill
                additional_liberties = ndimage.binary_dilation(group_matrix) * empties * killed_map
                additional_liberties = set(zip(*np.where(additional_liberties)))

                next_group_map[player].remove(group)
                group = group.copy()
                next_group_map[player].add(group)

                group.liberties.update(additional_liberties)

    # Update illegal moves
    batch_states[:, govars.INVD_CHNL] = state_utils.get_batch_invalid_moves(batch_states, batch_group_maps, player)

    # If group was one piece, and location is surrounded by opponents,
    # activate ko protection
    for i, (single_kill, killed_groups, surrounded) in enumerate(
            zip(batch_single_kill, batch_killed_groups, batch_surrounded)):
        if single_kill is not None and len(killed_groups) == 1 and surrounded:
            # Mark the ko point on the i-th resulting state (not on whichever
            # state the previous loop happened to finish on)
            batch_states[i, govars.INVD_CHNL, single_kill[0], single_kill[1]] = 1

    # Switch turn
    state_utils.batch_set_turn(batch_states)

    if canonical:
        GoGame.set_batch_canonical_form(batch_states, batch_group_maps, opponent)

    return batch_states, batch_group_maps