def __init__(self, powers_centers, map_name, **kwargs):
    """ Builds the SCO (supply centre ownership) response
        Layout: SCO (power center center ...) (power center ...) ... (UNO center center ...)
        :param powers_centers: A dict of {power_name: centers} objects
        :param map_name: The name of the map
    """
    super(SupplyCenterResponse, self).__init__(**kwargs)
    unowned_centers = Map(map_name).scs[:]          # Copy - owned centers are removed from it below
    clauses = []

    # One clause per power (sorted by name): the power token, followed by its sorted centers
    # Each owned center is removed from the unowned list (list.remove raises ValueError if it is not a map SC)
    for power_name in sorted(powers_centers):
        owned_centers = sorted(powers_centers[power_name])
        clause = bytes(parse_string(Power, power_name))
        for center in owned_centers:
            clause += bytes(parse_string(Province, center))
            unowned_centers.remove(center)
        clauses.append(clause)

    # Last clause: the UNO token, followed by every center that no power owns
    uno_clause = bytes(tokens.UNO)
    for center in unowned_centers:
        uno_clause += bytes(parse_string(Province, center))
    clauses.append(uno_clause)

    # Storing full response - every clause is wrapped in parentheses
    self._bytes = bytes(tokens.SCO) + b''.join(add_parentheses(clause) for clause in clauses)
def get_adjacency_matrix(map_name='standard'):
    """ Computes (and caches) the adjacency matrix for a map
        Location i is adjacent to location j if an army or a fleet can move from i to j.
        A coasted location (e.g. 'STP/NC') is also adjacent to its parent location ('STP') and vice-versa.
        :param map_name: The name of the map
        :return: A boolean (nb_nodes) x (nb_nodes) matrix.
                 The cached array itself is returned, so callers should not modify it in place.
    """
    if map_name in ADJACENCY_MATRIX:
        return ADJACENCY_MATRIX[map_name]

    # Finding list of all locations
    current_map = Map(map_name)
    locs = get_sorted_locs(current_map)

    # `np.bool` was only a deprecated alias of the builtin `bool` and was removed in NumPy 1.24
    adjacencies = np.zeros((len(locs), len(locs)), dtype=bool)

    # Building adjacencies between locs
    for i, loc_1 in enumerate(locs):
        for j, loc_2 in enumerate(locs):
            if current_map.abuts('A', loc_1, '-', loc_2) or current_map.abuts('F', loc_1, '-', loc_2):
                adjacencies[i, j] = 1

            # Coasts are adjacent to their parent location (without coasts)
            if loc_1 != loc_2 and (loc_1[:3] == loc_2 or loc_1 == loc_2[:3]):
                adjacencies[i, j] = 1

    # Storing in cache and returning
    ADJACENCY_MATRIX[map_name] = adjacencies
    return adjacencies
def get_stats(games):
    """ Computes aggregated stats over a list of saved games
        Each game is evaluated from the point of view of its first assigned power (game['assigned_powers'][0]).
        :param games: An iterable (list or generator) of saved games as dicts. Falsy entries (e.g. None) are skipped.
        :return: A tuple (players_names, GamesStats)
                 players_names is read from the first non-empty game (None if there is no non-empty game)
    """
    map_object = Map()
    nb_centers_req_for_win = len(map_object.scs) // 2 + 1        # 18 on the standard map
    nb_won, nb_most, nb_survived, nb_defeated = 0, 0, 0, 0
    nb_power_assignations = {power_name: 0 for power_name in map_object.powers}
    assigned_powers, rankings = [], []

    # Filtering once - also makes a one-shot iterator safe (the first game is no longer consumed and skipped)
    games = [game for game in games if game]
    players_names = games[0]['players_names'] if games else None

    for game in games:
        game_assigned_powers = game['assigned_powers']
        own_power = game_assigned_powers[0]
        final_centers = game['phases'][-1]['state']['centers']
        nb_centers = {power_name: len(final_centers[power_name]) for power_name in game_assigned_powers}

        # Won > had the most centers > survived > defeated
        if nb_centers[own_power] >= nb_centers_req_for_win:
            nb_won += 1
        elif nb_centers[own_power] == max(nb_centers.values()):
            nb_most += 1
        elif nb_centers[own_power] > 0:
            nb_survived += 1
        else:
            nb_defeated += 1

        nb_power_assignations[own_power] += 1
        assigned_powers.append(game_assigned_powers)
        rankings.append(game['ranking'])

    return players_names, GamesStats(nb_won, nb_most, nb_survived, nb_defeated,
                                     nb_power_assignations, assigned_powers, rankings)
def get_reward(self, prev_state_proto, state_proto, power_name, is_terminal_state, done_reason):
    """ Computes the reward for a given power
        The reward is the change in the power's number of supply centers during the transition, divided by
        the number of centers required to win. If the game was thrashed, every power still owning at least
        one center receives -1.
        :param prev_state_proto: The `.proto.State` representation of the last state of the game (before .process)
        :param state_proto: The `.proto.State` representation of the state of the game (after .process)
        :param power_name: The name of the power for which to calculate the reward (e.g. 'FRANCE').
        :param is_terminal_state: Boolean flag to indicate we are at a terminal state. (Unused here)
        :param done_reason: An instance of DoneReason indicating why the terminal state was reached.
        :return: The current reward (float) for power_name.
        :type done_reason: diplomacy_research.models.gym.environment.DoneReason | None
    """
    assert done_reason is None or isinstance(done_reason, DoneReason), 'done_reason must be a DoneReason object.'

    # Power missing from either state - no reward (and an error is logged if the name is not a known power)
    missing_power = power_name not in state_proto.centers or power_name not in prev_state_proto.centers
    if missing_power:
        if power_name not in ALL_POWERS:
            LOGGER.error('Unknown power %s. Expected powers are: %s', power_name, ALL_POWERS)
        return 0.

    win_threshold = len(Map(state_proto.map).scs) // 2 + 1.
    nb_current = len(set(state_proto.centers[power_name].value))
    nb_previous = len(set(prev_state_proto.centers[power_name].value))

    if done_reason == DoneReason.THRASHED and nb_current:
        return -1.
    return (nb_current - nb_previous) / win_threshold
def proto_to_heat_map(state_proto):
    """ Computes a heat map representation of the state of the game
        Size: 81 x 7 (NB_LOCS x NB_POWERS)
        > +1 is added for each locs where power has a unit and for each location that unit can reach
        > +0.5 is added for each loc where another power has a unit and for each loc that unit can reach

        Per-unit heat maps are cached in CACHE. The cache key includes the map name, because the same unit
        (e.g. 'A PAR') can reach different locations on different maps.

        :param state_proto: A `.proto.game.State` proto of the state of the game.
        :return: A heat map representation (NB_LOCS, NB_POWERS) of the state of the game
                 NOTE(review): with no units at all, np.sum([], axis=0) returns the scalar 0.0 rather than an
                 array - presumably callers always have units; confirm.
    """
    map_object = None           # Only built on a cache miss

    # Building heat map
    heat_maps = []
    for power_name in state_proto.units:
        for unit in state_proto.units[power_name].value:
            cache_key = 'heat_map/{}/{}/{}'.format(state_proto.map, power_name, unit)
            cache_value = CACHE.get(cache_key, None)

            # Computing heat map for unit and storing in cache
            if cache_value is None:
                map_object = map_object or Map(state_proto.map)
                cache_value = _build_unit_heat_map(map_object, power_name, unit)
                CACHE[cache_key] = cache_value

            # Adding the unit heat map
            heat_maps += [cache_value]

    # Summing the heat maps and returning
    return np.sum(heat_maps, axis=0)
def root_data_generator(saved_game_proto, is_validation_set):
    """ Converts a dataset game to protocol buffer format
        :param saved_game_proto: A `.proto.game.SavedGame` object from the dataset.
        :param is_validation_set: Boolean that indicates if we are generating the validation set (otw. training set)
        :return: A dictionary with phase_ix as key and a dictionary {power_name: (msg_len, proto)} as value
                 (msg_len is always 0; the entry for the last phase is left empty)
    """
    map_object = Map(saved_game_proto.map)
    top_victors = get_top_victors(saved_game_proto, map_object)
    all_powers = get_map_powers(map_object)
    nb_phases = len(saved_game_proto.phases)
    proto_results = {phase_ix: {} for phase_ix in range(nb_phases)}

    # Policy is learned from the top victors for evaluation, and from all powers for training
    learned_powers = top_victors if is_validation_set else all_powers
    policy_data = get_policy_data(saved_game_proto, power_names=all_powers, top_victors=learned_powers)
    value_data = get_value_data(saved_game_proto, all_powers)

    # The validation set only contains examples for the top victors
    exported_powers = [power_name for power_name in all_powers
                       if not is_validation_set or power_name in top_victors]

    for phase_ix in range(nb_phases - 1):
        phase_state = saved_game_proto.phases[phase_ix].state
        for power_name in exported_powers:
            phase_policy = policy_data[phase_ix][power_name]
            phase_value = value_data[phase_ix][power_name]
            request_id = DatasetBuilder.get_request_id(saved_game_proto, phase_ix, power_name, is_validation_set)

            example = {'request_id': request_id,
                       'player_seed': 0,
                       'decoder_inputs': [GO_ID],
                       'noise': 0.,
                       'temperature': 0.,
                       'dropout_rate': 0.,
                       'current_power': POWER_VOCABULARY_KEY_TO_IX[power_name],
                       'current_season': get_current_season(phase_state)}
            example.update(phase_policy)
            example.update(phase_value)

            serialized = BaseDatasetBuilder.build_example(example, BaseDatasetBuilder.get_proto_fields())
            proto_results[phase_ix][power_name] = (0, serialized)

    return proto_results
def add_cached_states_to_saved_game(saved_game):
    """ Adds a cached (flattened) representation of board_state and prev_orders_state to the saved game
        Only games on the standard map (or its variants, i.e. map names starting with 'standard') are modified:
        - every phase state gets a 'board_state' entry
        - every movement phase (name ending with 'M') gets a 'prev_orders_state' entry
        :param saved_game: A saved game (as a dictionary). Modified in place.
        :return: The same saved game object
    """
    if not saved_game['map'].startswith('standard'):
        return saved_game

    map_object = Map(saved_game['map'])
    for phase in saved_game['phases']:
        phase_state = phase['state']
        phase_state['board_state'] = dict_to_flatten_board_state(phase_state, map_object)
        if phase['name'][-1] == 'M':
            phase['prev_orders_state'] = dict_to_flatten_prev_orders_state(phase, map_object)
    return saved_game
def get_vocabulary():
    """ Returns the list of words in the dictionary
        Order: utility tokens, '<POWER>' tokens, order tokens, '- LOC' tokens, valid army units, valid fleet units
        :return: The list of words in the dictionary
    """
    map_object = Map()
    locs = sorted(loc.upper() for loc in map_object.locs)

    utility_tokens = [PAD_TOKEN, GO_TOKEN, EOS_TOKEN, DRAW_TOKEN]
    power_tokens = ['<%s>' % power_name for power_name in get_map_powers(map_object)]
    order_tokens = ['B', 'C', 'D', 'H', 'S', 'VIA', 'WAIVE']            # '-' and 'R' are excluded
    destination_tokens = ['- %s' % loc for loc in locs]

    # All valid army units first, then all valid fleet units
    unit_tokens = []
    for unit_type in ('A', 'F'):
        for loc in locs:
            unit = '%s %s' % (unit_type, loc)
            if map_object.is_valid_unit(unit):
                unit_tokens.append(unit)

    return utility_tokens + power_tokens + order_tokens + destination_tokens + unit_tokens
def show_opening_moves_data(proto_dataset_path):
    """ Displays a list of opening moves for each power on the standard map
        :param proto_dataset_path: The path to the proto dataset
        :return: Nothing
        :raises RuntimeError: If the dataset cannot be found at proto_dataset_path
    """
    if not os.path.exists(proto_dataset_path):
        raise RuntimeError('Unable to find Diplomacy dataset at {}'.format(proto_dataset_path))

    # Openings dict - {power_name: {orders_sorted_by_location: count}}
    map_object = Map('standard')
    openings = {power_name: {} for power_name in get_map_powers(map_object)}

    # Loading the phases count dataset to get the number of games (only used as the progress bar total)
    total = None
    if os.path.exists(PHASES_COUNT_DATASET_PATH):
        with open(PHASES_COUNT_DATASET_PATH, 'rb') as file:
            total = len(pickle.load(file))
    progress_bar = tqdm(total=total)

    # Loading dataset and building database
    LOGGER.info('... Building an opening move database.')
    try:
        # Reading the dataset that was validated above (previously, the global PROTO_DATASET_PATH was opened instead)
        with open(proto_dataset_path, 'rb') as proto_dataset:
            while True:
                saved_game_proto = read_next_proto(SavedGameProto, proto_dataset, compressed=False)
                if saved_game_proto is None:
                    break
                progress_bar.update(1)

                # Only keeping games with the standard map (or its variations)
                if not saved_game_proto.map.startswith('standard'):
                    continue

                initial_phase = saved_game_proto.phases[0]
                for power_name in initial_phase.orders:
                    orders = initial_phase.orders[power_name].value
                    orders = tuple(sorted(orders, key=lambda order: order.split()[1]))      # Sorted by location
                    openings[power_name][orders] = openings[power_name].get(orders, 0) + 1

        # Printing results (most frequent openings first)
        for power_name in get_map_powers(map_object):
            print('=' * 80)
            print(power_name)
            print()
            for opening, count in sorted(openings[power_name].items(), key=lambda item: item[1], reverse=True):
                print(opening, 'Count:', count)

    # Closing the progress bar, even if reading the dataset failed
    finally:
        progress_bar.close()
def get_reward(self, prev_state_proto, state_proto, power_name, is_terminal_state, done_reason):
    """ Computes the reward for a given power
        Only non-thrashed terminal states are rewarded. The pot (self.pot_size) is then split:
        - if a power owns at least the number of centers required to win: proportionally to the number of
          centers, the winner's count being capped at the number of centers required to win;
        - otherwise: equally between all the powers that still own at least one center.
        :param prev_state_proto: The `.proto.State` representation of the last state of the game (before .process)
        :param state_proto: The `.proto.State` representation of the state of the game (after .process)
        :param power_name: The name of the power for which to calculate the reward (e.g. 'FRANCE').
        :param is_terminal_state: Boolean flag to indicate we are at a terminal state.
        :param done_reason: An instance of DoneReason indicating why the terminal state was reached.
        :return: The current reward (float) for power_name.
        :type done_reason: diplomacy_research.models.gym.environment.DoneReason | None
    """
    assert done_reason is None or isinstance(done_reason, DoneReason), 'done_reason must be a DoneReason object.'

    if power_name not in state_proto.centers:
        if power_name not in ALL_POWERS:
            LOGGER.error('Unknown power %s. Expected powers are: %s', power_name, ALL_POWERS)
        return 0.

    # Intermediate states and thrashed games are not rewarded
    if not is_terminal_state or done_reason == DoneReason.THRASHED:
        return 0.

    nb_centers_req_for_win = len(Map(state_proto.map).scs) // 2 + 1.
    nb_scs = {power: len(state_proto.centers[power].value) for power in state_proto.centers}
    victors = [power for power, nb_sc in nb_scs.items() if nb_sc >= nb_centers_req_for_win]

    # No solo winner - Splitting the pot equally between survivors
    if not victors:
        survivors = [power for power, nb_sc in nb_scs.items() if nb_sc]
        return self.pot_size * (1. / len(survivors) if power_name in survivors else 0.)

    # Solo winner - Centers above the winning threshold are not counted
    nb_excess_sc = nb_scs[victors[0]] - nb_centers_req_for_win
    nb_controlled_sc = sum(nb_scs.values()) - nb_excess_sc
    nb_rewarded_sc = nb_centers_req_for_win if power_name in victors else nb_scs[power_name]
    return self.pot_size * (nb_rewarded_sc / nb_controlled_sc)
def get_alignments_index(map_name='standard'):
    """ Computes a list of nodes index for each possible location
        e.g. if the sorted list of locs is ['BRE', 'MAR', 'PAR'] would return {'BRE': [0], 'MAR': [1], 'PAR': [2]}
        Locations are grouped by their first 3 characters, so a coasted loc (e.g. 'SPA/NC') shares the
        entry of its parent loc ('SPA').
        :param map_name: The name of the map
        :return: A dict {loc[:3]: [node indexes, in increasing order]}
    """
    alignments_index = {}
    for node_ix, loc in enumerate(get_sorted_locs(Map(map_name))):
        alignments_index.setdefault(loc[:3], []).append(node_ix)
    return alignments_index
def __init__(self, map_name, **kwargs):
    """ Builds the MDF (map definition) response
        Layout: MDF (powers) (provinces) (adjacencies)
        - (Powers): (power power ...)
        - (Provinces): ((supply_centers) (non_supply_centres))
        - (Adjacencies): ((prov_adjacencies) (prov_adjacencies) ...)
        :param map_name: The name of the map
    """
    super(MapDefinitionResponse, self).__init__(**kwargs)
    game_map = Map(map_name)

    powers, provinces, adjacencies = (self._build_powers_clause(game_map),
                                      self._build_provinces_clause(game_map),
                                      self._build_adjacencies_clause(game_map))
    self._bytes = bytes(tokens.MDF) + powers + provinces + adjacencies
def get_reward(self, prev_state_proto, state_proto, power_name, is_terminal_state, done_reason):
    """ Computes the reward for a given power
        The reward is (number of supply centers gained) - (number of supply centers lost) over the transition.
        Before comparing, each state's centers are adjusted with the unit positions: a supply center occupied by
        a non-dislodged unit of the power counts as owned, and one occupied by another power's non-dislodged unit
        does not. If the game was thrashed, a power still owning centers gets -(number of centers to win).
        :param prev_state_proto: The `.proto.State` representation of the last state of the game (before .process)
        :param state_proto: The `.proto.State` representation of the state of the game (after .process)
        :param power_name: The name of the power for which to calculate the reward (e.g. 'FRANCE').
        :param is_terminal_state: Boolean flag to indicate we are at a terminal state. (Unused here)
        :param done_reason: An instance of DoneReason indicating why the terminal state was reached.
        :return: The current reward (float) for power_name.
        :type done_reason: diplomacy_research.models.gym.environment.DoneReason | None
    """
    assert done_reason is None or isinstance(done_reason, DoneReason), 'done_reason must be a DoneReason object.'

    if power_name not in state_proto.centers or power_name not in prev_state_proto.centers:
        if power_name not in ALL_POWERS:
            LOGGER.error('Unknown power %s. Expected powers are: %s', power_name, ALL_POWERS)
        return 0.

    map_object = Map(state_proto.map)
    nb_centers_req_for_win = len(map_object.scs) // 2 + 1.
    current_centers = set(state_proto.centers[power_name].value)
    prev_centers = set(prev_state_proto.centers[power_name].value)
    all_scs = map_object.scs

    if done_reason == DoneReason.THRASHED and current_centers:
        return -1. * nb_centers_req_for_win

    def _adjust_with_units(some_state_proto, centers):
        """ Updates `centers` in place with the occupation of supply centers in `some_state_proto`
            Dislodged units (containing '*') are ignored. The loc is read as unit[2:5] (e.g. 'A PAR' -> 'PAR').
        """
        for unit_power in some_state_proto.units:
            for unit in some_state_proto.units[unit_power].value:
                if '*' in unit:
                    continue
                unit_loc = unit[2:5]
                if unit_loc not in all_scs:
                    continue
                if unit_power == power_name:
                    centers.add(unit_loc)
                else:
                    centers.discard(unit_loc)

    # Adjusting supply centers for the current phase, then for the previous phase
    _adjust_with_units(state_proto, current_centers)
    _adjust_with_units(prev_state_proto, prev_centers)

    nb_gained = len(current_centers - prev_centers)
    nb_lost = len(prev_centers - current_centers)
    return float(nb_gained - nb_lost)
def get_feedable_item(locs, state_proto, power_name, phase_history_proto, possible_orders_proto, **kwargs):
    """ Computes and return a feedable item (to be fed into the feedable queue)
        :param locs: A list of locations for which we want orders
        :param state_proto: A `.proto.game.State` representation of the state of the game.
        :param power_name: The power name for which we want the orders and the state values
        :param phase_history_proto: A list of `.proto.game.PhaseHistory`. This represents prev phases.
        :param possible_orders_proto: A `proto.game.PossibleOrders` object representing possible order for each loc.
        :param kwargs: Additional optional kwargs:
            - player_seed: The seed to apply to the player to compute a deterministic mask.
            - noise: The sigma of the additional noise to apply to the intermediate layers (i.e. sigma * epsilon)
            - temperature: The temperature to apply to the logits. (Default to 0. for deterministic/greedy)
            - dropout_rate: The amount of dropout to apply to the inputs/outputs of the decoder.
        :return: A feedable item, with feature names as key and numpy arrays as values
    """
    # pylint: disable=too-many-branches
    map_object = Map(state_proto.map)
    board_state = proto_to_board_state(state_proto, map_object)

    # For adjustment phases, the number of builds/disbands is restricted to what the game engine allows
    in_adjustment_phase = state_proto.name[-1] == 'A'
    nb_builds = state_proto.builds[power_name].count
    nb_homes = len(state_proto.builds[power_name].homes)

    # Adjustment phase - The locs must be the orderable locs (and not the policy locs)
    if in_adjustment_phase:
        orderable_locs, _ = get_orderable_locs_for_powers(state_proto, [power_name])
        if sorted(locs) != sorted(orderable_locs):
            if locs:
                LOGGER.warning('Adj. phase requires orderable locs. Got %s. Expected %s.', locs, orderable_locs)
            locs = orderable_locs

    # Decoder length
    # - Other phases: one order per loc
    # - WxxxA with builds: at most one build per available home
    # - WxxxA with disbands: the number of units to disband
    if not in_adjustment_phase:
        decoder_length = len(locs)
    elif nb_builds >= 0:
        decoder_length = min(nb_builds, nb_homes)
    else:
        decoder_length = abs(nb_builds)

    # Candidates for the policy (one mask per decoding step)
    if not possible_orders_proto:
        # No possible orders - Expected when only querying the state value or the next message to send
        candidates = [[] for _ in range(decoder_length)]
    elif in_adjustment_phase:
        # Adjustment phase - Every step may use any possible order of any loc
        adj_orders = []
        for loc in locs:
            adj_orders += possible_orders_proto[loc].value
        candidates = [get_order_based_mask(adj_orders)] * decoder_length
    else:
        # Regular phase - One mask per loc
        candidates = [get_order_based_mask(possible_orders_proto[loc].value) for loc in locs]

    # Prev orders state - Last NB_PREV_ORDERS movement phases (oldest first), left-padded with zeros
    recent_orders_states = []
    for past_phase_proto in reversed(phase_history_proto):
        if len(recent_orders_states) == NB_PREV_ORDERS:
            break
        if past_phase_proto.name[-1] == 'M':
            recent_orders_states.append(proto_to_prev_orders_state(past_phase_proto, map_object))
    recent_orders_states.reverse()
    padding = [np.zeros((NB_NODES, NB_ORDERS_FEATURES), dtype=np.uint8)
               for _ in range(NB_PREV_ORDERS - len(recent_orders_states))]
    prev_orders_state = np.array(padding + recent_orders_states)

    # Building (order) decoder inputs [GO_ID]
    decoder_inputs = [GO_ID]

    # kwargs
    player_seed = kwargs.get('player_seed', 0)
    noise = kwargs.get('noise', 0.)
    temperature = kwargs.get('temperature', 0.)
    dropout_rate = kwargs.get('dropout_rate', 0.)

    board_alignments = get_board_alignments(locs,
                                            in_adjustment_phase=in_adjustment_phase,
                                            tokens_per_loc=1,
                                            decoder_length=decoder_length)
    return {'player_seed': player_seed,
            'board_state': board_state,
            'board_alignments': board_alignments,
            'prev_orders_state': prev_orders_state,
            'decoder_inputs': decoder_inputs,
            'decoder_lengths': decoder_length,
            'candidates': candidates,
            'noise': noise,
            'temperature': temperature,
            'dropout_rate': dropout_rate,
            'current_power': POWER_VOCABULARY_KEY_TO_IX[power_name],
            'current_season': get_current_season(state_proto)}
def get_policy_data(saved_game_proto, power_names, top_victors):
    """ Computes the proto to save in tf.train.Example as a training example for the policy network
        :param saved_game_proto: A `.proto.game.SavedGame` object from the dataset.
        :param power_names: The list of powers for which we want the policy data
        :param top_victors: The list of powers that ended with more than 25% of the supply centers
        :return: A dictionary with key: the phase_ix
                 with value: A dict with the power_name as key and a dict with the example fields as value
                 Powers from which no policy is learned for a phase only get the 'board_state',
                 'prev_orders_state' and 'draw_target' fields.
    """
    nb_phases = len(saved_game_proto.phases)
    policy_data = {phase_ix: {} for phase_ix in range(nb_phases - 1)}
    game_id = saved_game_proto.id
    map_object = Map(saved_game_proto.map)

    # Determining if we have a draw (no solo winner and at least 2 survivors at the end of the game)
    # The empty-centers case no longer crashes on max() of an empty list
    nb_sc_to_win = len(map_object.scs) // 2 + 1
    final_centers = saved_game_proto.phases[-1].state.centers
    nb_final_centers = [len(final_centers[power_name].value) for power_name in final_centers]
    has_solo_winner = bool(nb_final_centers) and max(nb_final_centers) >= nb_sc_to_win
    survivors = [power_name for power_name in final_centers if final_centers[power_name].value]
    has_draw = not has_solo_winner and len(survivors) >= 2

    # Processing all phases (except the last one)
    for phase_ix in range(nb_phases - 1):

        # Building a list of orders of previous phases (zero-padded, keeping the last NB_PREV_ORDERS)
        previous_orders_states = [np.zeros((NB_NODES, NB_ORDERS_FEATURES), dtype=np.uint8)] * NB_PREV_ORDERS
        for prev_phase_proto in saved_game_proto.phases[max(0, phase_ix - NB_PREV_ORDERS_HISTORY):phase_ix]:
            if prev_phase_proto.name[-1] == 'M':
                previous_orders_states += [proto_to_prev_orders_state(prev_phase_proto, map_object)]
        previous_orders_states = previous_orders_states[-NB_PREV_ORDERS:]
        prev_orders_state = np.array(previous_orders_states)

        # Parsing each requested power in the specified phase
        phase_proto = saved_game_proto.phases[phase_ix]
        phase_name = phase_proto.name
        state_proto = phase_proto.state
        phase_board_state = proto_to_board_state(state_proto, map_object)

        for power_name in power_names:
            phase_issued_orders = get_issued_orders_for_powers(phase_proto, [power_name])
            phase_possible_orders = get_possible_orders_for_powers(phase_proto, [power_name])
            phase_draw_target = 1. if has_draw and phase_ix == (nb_phases - 2) and power_name in survivors else 0.

            # Data to use when not learning a policy
            blank_policy_data = {'board_state': phase_board_state,
                                 'prev_orders_state': prev_orders_state,
                                 'draw_target': phase_draw_target}

            # Power is not a top victor - We don't want to learn a policy from him
            if power_name not in top_victors:
                policy_data[phase_ix][power_name] = blank_policy_data
                continue

            # Finding the orderable locs
            orderable_locations = list(phase_issued_orders[power_name].keys())

            # Skipping power for this phase if we are only issuing Hold
            for order in phase_issued_orders[power_name].values():
                order_tokens = get_order_tokens(order)
                if len(order_tokens) >= 2 and order_tokens[1] != 'H':
                    break
            else:
                policy_data[phase_ix][power_name] = blank_policy_data
                continue

            # Removing orderable locs where orders are not possible (i.e. NO_CHECK games)
            for order_loc, order in phase_issued_orders[power_name].items():
                if order not in phase_possible_orders[order_loc] and order_loc in orderable_locations:
                    if 'NO_CHECK' not in saved_game_proto.rules:
                        LOGGER.warning('%s not in all possible orders. Phase %s - Game %s.', order, phase_name, game_id)
                    orderable_locations.remove(order_loc)

                # Remove orderable locs where the order is either invalid or not frequent
                if order_to_ix(order) is None and order_loc in orderable_locations:
                    orderable_locations.remove(order_loc)

            # Determining if we are in an adjustment phase
            in_adjustment_phase = state_proto.name[-1] == 'A'
            nb_builds = state_proto.builds[power_name].count
            nb_homes = len(state_proto.builds[power_name].homes)

            # WxxxA - We can build units
            # WxxxA - We can disband units
            # Other phase
            if in_adjustment_phase and nb_builds >= 0:
                decoder_length = min(nb_builds, nb_homes)
            elif in_adjustment_phase and nb_builds < 0:
                decoder_length = abs(nb_builds)
            else:
                decoder_length = len(orderable_locations)

            # Not all units were disbanded - Skipping this power as we can't learn the orders properly
            if in_adjustment_phase and nb_builds < 0 and len(orderable_locations) < abs(nb_builds):
                policy_data[phase_ix][power_name] = blank_policy_data
                continue

            # Not enough orderable locations for this power, skipping
            if not orderable_locations or not decoder_length:
                policy_data[phase_ix][power_name] = blank_policy_data
                continue

            # decoder_inputs [GO, order1, order2, order3]
            # Missing builds are padded with WAIVE
            decoder_inputs = [GO_ID]
            decoder_inputs += [order_to_ix(phase_issued_orders[power_name][loc]) for loc in orderable_locations]
            if in_adjustment_phase and nb_builds > 0:
                decoder_inputs += [order_to_ix('WAIVE')] * (min(nb_builds, nb_homes) - len(orderable_locations))
            decoder_length = min(decoder_length, NB_SUPPLY_CENTERS)

            # Adjustment Phase - Use all possible orders for each location.
            if in_adjustment_phase:
                build_disband_locs = list(get_possible_orders_for_powers(phase_proto, [power_name]).keys())
                phase_board_alignments = get_board_alignments(build_disband_locs,
                                                              in_adjustment_phase=in_adjustment_phase,
                                                              tokens_per_loc=1,
                                                              decoder_length=decoder_length)

                # Building a list of all orders for all locations
                adj_orders = []
                for loc in build_disband_locs:
                    adj_orders += phase_possible_orders[loc]

                # Not learning builds for BUILD_ANY
                if nb_builds > 0 and 'BUILD_ANY' in state_proto.rules:
                    adj_orders = []

                # No orders found - Skipping
                if not adj_orders:
                    policy_data[phase_ix][power_name] = blank_policy_data
                    continue

                # Computing the candidates
                candidates = [get_order_based_mask(adj_orders)] * decoder_length

            # Regular phase - Compute candidates for each location
            else:
                phase_board_alignments = get_board_alignments(orderable_locations,
                                                              in_adjustment_phase=in_adjustment_phase,
                                                              tokens_per_loc=1,
                                                              decoder_length=decoder_length)
                candidates = []
                for loc in orderable_locations:
                    candidates += [get_order_based_mask(phase_possible_orders[loc])]

            # Saving results
            # No need to return temperature, current_power, current_season
            policy_data[phase_ix][power_name] = {'board_state': phase_board_state,
                                                 'board_alignments': phase_board_alignments,
                                                 'prev_orders_state': prev_orders_state,
                                                 'decoder_inputs': decoder_inputs,
                                                 'decoder_lengths': decoder_length,
                                                 'candidates': candidates,
                                                 'draw_target': phase_draw_target}

    # Returning
    return policy_data
def generate_trajectory(players, reward_fn, advantage_fn, env_constructor=None, hparams=None, power_assignments=None,
                        set_player_seed=None, initial_state_bytes=None, update_interval=0, update_queue=None,
                        output_format='proto'):
    """ Generates a single trajectory (Saved Game Proto) for RL (self-play) with the power assignments

        This is a generator function: at every phase it yields the list [get_step_args(*task) for each power] and
        expects the caller to send back the resolved list of (orders, policy_details, state_value) tuples.
        When partial updates are enabled, it may also yield a list of get_state_value(*task) results and expects
        the list of state values back. The final output is the generator's return value.
        NOTE(review): presumably driven by a coroutine runner that resolves the yielded lists - confirm.

        :param players: A list of instantiated players (one per power, see the NB_POWERS assertion)
        :param reward_fn: The reward function to use to calculate rewards
        :param advantage_fn: An instance of `.models.self_play.advantages`
        :param env_constructor: A callable to get the OpenAI gym environment (args: players)
        :param hparams: A dictionary of hyper parameters with their values
        :param power_assignments: Optional. The power name we want to play as. (e.g. 'FRANCE') or a list of powers.
        :param set_player_seed: Boolean that indicates that we want to set the player seed on reset().
        :param initial_state_bytes: A `game.State` proto (in bytes format) representing the initial state of the game.
        :param update_interval: Optional. If set (> 0) with an update_queue, a partial saved game is put in the
                                update_queue when at least this many seconds and 5 game years have passed
                                (and once more when the game is completed).
        :param update_queue: Optional. If update interval is set, partial games will be put in this queue
        :param output_format: The output format. One of 'proto', 'bytes', 'zlib'
        :return: A SavedGameProto representing the game played (with policy details and power assignments)
                 Depending on format, the output might be converted to a byte array, or a compressed byte array.
        :type players: List[diplomacy_research.players.player.Player]
        :type reward_fn: diplomacy_research.models.self_play.reward_functions.AbstractRewardFunction
        :type advantage_fn: diplomacy_research.models.self_play.advantages.base_advantage.BaseAdvantage
        :type update_queue: multiprocessing.Queue
    """
    # pylint: disable=too-many-arguments
    assert output_format in ['proto', 'bytes', 'zlib'], 'Format should be "proto", "bytes", "zlib"'
    assert len(players) == NB_POWERS

    # Making sure we use the SavedGame wrapper to record the game
    # (walking down the .env chain; while/else wraps the env only if no SaveGame wrapper was found)
    if env_constructor:
        env = env_constructor(players)
    else:
        env = default_env_constructor(players, hparams, power_assignments, set_player_seed, initial_state_bytes)
    wrapped_env = env
    while not isinstance(wrapped_env, DiplomacyEnv):
        if isinstance(wrapped_env, SaveGame):
            break
        wrapped_env = wrapped_env.env
    else:
        env = SaveGame(env)

    # Detecting if we have an Auto-Draw wrapper (draw probabilities / actions are then exchanged with the env)
    has_auto_draw = False
    wrapped_env = env
    while not isinstance(wrapped_env, DiplomacyEnv):
        if isinstance(wrapped_env, AutoDraw):
            has_auto_draw = True
            break
        wrapped_env = wrapped_env.env

    # Resetting env
    env.reset()

    # Timing vars for partial updates
    # NOTE(review): time_last_update is never refreshed after an update is sent, so once update_interval seconds
    # have elapsed, only the 5-year condition gates subsequent partial updates - confirm this is intended.
    time_last_update = time.time()
    year_last_update = 0
    start_phase_ix = 0
    current_phase_ix = 0
    nb_transitions = 0

    # Cache Variables
    powers = sorted([power_name for power_name in get_map_powers(env.game.map)])
    assigned_powers = env.get_all_powers_name()
    stored_board_state = OrderedDict()                      # {phase_name: board_state}
    stored_prev_orders_state = OrderedDict()                # {phase_name: prev_orders_state}
    stored_possible_orders = OrderedDict()                  # {phase_name: possible_orders}

    # Per-power trajectory data (one entry per processed phase in each list)
    power_variables = {power_name: {'orders': [],
                                    'policy_details': [],
                                    'state_values': [],
                                    'rewards': [],
                                    'returns': [],
                                    'last_state_value': 0.} for power_name in powers}

    new_state_proto = None
    phase_history_proto = []
    map_object = Map(name=env.game.map.name)

    # Generating
    while not env.is_done:
        # Reusing the state extracted at the end of the previous iteration (if any)
        state_proto = new_state_proto if new_state_proto is not None else extract_state_proto(env.game)
        possible_orders_proto = extract_possible_orders_proto(env.game)

        # Computing board_state
        board_state = proto_to_board_state(state_proto, map_object).flatten().tolist()
        state_proto.board_state.extend(board_state)

        # Storing board state and possible orders for this phase
        current_phase = env.game.get_current_phase()
        stored_board_state[current_phase] = board_state
        stored_possible_orders[current_phase] = possible_orders_proto

        # Getting orders, policy details, and state value (the caller resolves the yielded list and sends it back)
        tasks = [(player, state_proto, pow_name, phase_history_proto[-NB_PREV_ORDERS_HISTORY:], possible_orders_proto)
                 for player, pow_name in zip(env.players, assigned_powers)]
        step_args = yield [get_step_args(*args) for args in tasks]

        # Stepping through env, storing power variables
        # nb_transitions only counts powers that submitted orders
        for power_name, (orders, policy_details, state_value) in zip(assigned_powers, step_args):
            if orders:
                env.step((power_name, orders))
                nb_transitions += 1
            if has_auto_draw and policy_details is not None:
                env.set_draw_prob(power_name, policy_details['draw_prob'])

        # Processing
        env.process()
        current_phase_ix += 1

        # Retrieving draw action and saving power variables
        for power_name, (orders, policy_details, state_value) in zip(assigned_powers, step_args):
            if has_auto_draw and policy_details is not None:
                policy_details['draw_action'] = env.get_draw_actions()[power_name]
            power_variables[power_name]['orders'] += [orders]
            power_variables[power_name]['policy_details'] += [policy_details]
            power_variables[power_name]['state_values'] += [state_value]

        # Getting new state
        new_state_proto = extract_state_proto(env.game)

        # Storing reward for this transition (every power on the map, not only the assigned ones)
        done_reason = DoneReason(env.done_reason) if env.done_reason else None
        for power_name in powers:
            power_variables[power_name]['rewards'] += [reward_fn.get_reward(prev_state_proto=state_proto,
                                                                              state_proto=new_state_proto,
                                                                              power_name=power_name,
                                                                              is_terminal_state=done_reason is not None,
                                                                              done_reason=done_reason)]

        # Computing prev_orders_state for the previous state (only for movement phases)
        last_phase_proto = extract_phase_history_proto(env.game, nb_previous_phases=1)[-1]
        if last_phase_proto.name[-1] == 'M':
            prev_orders_state = proto_to_prev_orders_state(last_phase_proto, map_object).flatten().tolist()
            stored_prev_orders_state[last_phase_proto.name] = prev_orders_state
            last_phase_proto.prev_orders_state.extend(prev_orders_state)
        phase_history_proto += [last_phase_proto]

        # Sending partial game if:
        # 1) We have update_interval > 0 with an update queue, and
        # 2a) The game is completed, or 2b) the update time has elapsed and at least 5 years have passed
        has_update_interval = update_interval > 0 and update_queue is not None
        game_is_completed = env.is_done
        min_time_has_passed = time.time() - time_last_update > update_interval
        current_year = 9999 if env.game.get_current_phase() == 'COMPLETED' else int(env.game.get_current_phase()[1:5])
        min_years_have_passed = current_year - year_last_update >= 5

        if (has_update_interval
                and (game_is_completed
                     or (min_time_has_passed and min_years_have_passed))):

            # Game is completed - last state value is 0
            if game_is_completed:
                for power_name in powers:
                    power_variables[power_name]['last_state_value'] = 0.

            # Otherwise - Querying the model for the value of the last state
            # NOTE(review): possible_orders_proto was extracted before env.process(), so it belongs to the previous
            # phase, while new_state_proto is the current one - confirm get_state_value tolerates this.
            else:
                tasks = [(player, new_state_proto, pow_name, phase_history_proto[-NB_PREV_ORDERS_HISTORY:],
                          possible_orders_proto)
                         for player, pow_name in zip(env.players, assigned_powers)]
                last_state_values = yield [get_state_value(*args) for args in tasks]

                for power_name, last_state_value in zip(assigned_powers, last_state_values):
                    power_variables[power_name]['last_state_value'] = last_state_value

            # Getting partial game and sending it on the update_queue
            # (only the phases since start_phase_ix are part of this partial game)
            saved_game_proto = get_saved_game_proto(env=env,
                                                    players=players,
                                                    stored_board_state=stored_board_state,
                                                    stored_prev_orders_state=stored_prev_orders_state,
                                                    stored_possible_orders=stored_possible_orders,
                                                    power_variables=power_variables,
                                                    start_phase_ix=start_phase_ix,
                                                    reward_fn=reward_fn,
                                                    advantage_fn=advantage_fn,
                                                    is_partial_game=True)
            update_queue.put_nowait((False, nb_transitions, proto_to_bytes(saved_game_proto)))

            # Updating stats
            start_phase_ix = current_phase_ix
            nb_transitions = 0
            if not env.is_done:
                year_last_update = int(env.game.get_current_phase()[1:5])

    # Since the environment is done (Completed game) - We can leave the last_state_value at 0.
    for power_name in powers:
        power_variables[power_name]['last_state_value'] = 0.

    # Getting completed game (all phases, starting at phase 0)
    saved_game_proto = get_saved_game_proto(env=env,
                                            players=players,
                                            stored_board_state=stored_board_state,
                                            stored_prev_orders_state=stored_prev_orders_state,
                                            stored_possible_orders=stored_possible_orders,
                                            power_variables=power_variables,
                                            start_phase_ix=0,
                                            reward_fn=reward_fn,
                                            advantage_fn=advantage_fn,
                                            is_partial_game=False)

    # Converting to correct format
    output = {'proto': lambda proto: proto,
              'zlib': proto_to_zlib,
              'bytes': proto_to_bytes}[output_format](saved_game_proto)

    # Returning
    return output
def _build_from_string(self, order, game=None):
    """ Builds this object from a string

        Parses `order` and, on success, sets `self.order_str` (the normalized order string) and
        `self.order_dict` (a dict with the keys 'terrID', 'unitType', 'type', 'toTerrID', 'fromTerrID',
        'viaConvoy' and, when a convoy path is found, 'convoyPath').
        When the order cannot be parsed, an error is logged and neither attribute is set by this call.

        During a retreat phase (phase_type 'R'), moves ('X - Y') are interpreted as retreats ('X R Y').

        :param order: The order to parse (e.g. 'A PAR - BUR', 'F BRE S A PAR - PIC', 'WAIVE')
        :param game: Optional game object, forwarded to `find_convoy_path` when computing convoy paths.
        :type order: str
        :type game: diplomacy.Game
    """
    # pylint: disable=too-many-return-statements,too-many-branches,too-many-statements

    # Converting move to retreat during retreat phase
    if self.phase_type == 'R':
        order = order.replace(' - ', ' R ')

    # Splitting into parts
    words = order.split()

    # --- Wait / Waive ---
    # [{"id": "56", "unitID": null, "type": "Wait", "toTerrID": "", "fromTerrID": "", "viaConvoy": ""}]
    if len(words) == 1 and words[0] == 'WAIVE':
        self.order_str = 'WAIVE'
        self.order_dict = {'terrID': None,
                           'unitType': '',
                           'type': 'Wait',
                           'toTerrID': '',
                           'fromTerrID': '',
                           'viaConvoy': ''}
        return

    # Validating
    if len(words) < 3:
        LOGGER.error('Unable to parse the order "%s". Require at least 3 words', order)
        return

    short_unit_type, loc_name, order_type = words[:3]

    # Exact membership tests - a substring test such as `in 'AF'` would also accept tokens like 'AF' or 'H-S'
    if short_unit_type not in ('A', 'F'):
        LOGGER.error('Unable to parse the order "%s". Valid unit types are "A" and "F".', order)
        return
    if order_type not in ('H', '-', 'S', 'C', 'R', 'B', 'D'):
        LOGGER.error('Unable to parse the order "%s". Valid order types are H-SCRBD', order)
        return

    loc_to_ix = CACHE[self.map_name]['loc_to_ix']
    if loc_name not in loc_to_ix:
        LOGGER.error('Received invalid loc "%s" for map "%s".', loc_name, self.map_name)
        return

    # Extracting territories
    unit_type = {'A': 'Army', 'F': 'Fleet'}[short_unit_type]
    terr_id = loc_to_ix[loc_name]

    # --- Hold ---
    # {"id": "76", "unitID": "19", "type": "Hold", "toTerrID": "", "fromTerrID": "", "viaConvoy": ""}
    if order_type == 'H':
        self.order_str = '%s %s H' % (short_unit_type, loc_name)
        self.order_dict = {'terrID': terr_id,
                           'unitType': unit_type,
                           'type': 'Hold',
                           'toTerrID': '',
                           'fromTerrID': '',
                           'viaConvoy': ''}

    # --- Move ---
    # {"id": "73", "unitID": "16", "type": "Move", "toTerrID": "25", "fromTerrID": "", "viaConvoy": "Yes",
    #   "convoyPath": ["22", "69"]},
    # {"id": "74", "unitID": "17", "type": "Move", "toTerrID": "69", "fromTerrID": "", "viaConvoy": "No"}
    elif order_type == '-':
        if len(words) < 4:
            LOGGER.error('[Move] Unable to parse the move order "%s". Require at least 4 words', order)
            LOGGER.error(order)
            return

        # Getting destination - validated before it is used for adjacency / convoy path computations
        to_loc_name = words[3]
        to_terr_id = loc_to_ix.get(to_loc_name, None)
        if to_terr_id is None:
            LOGGER.error('[Move] Received invalid to loc "%s" for map "%s".', to_loc_name, self.map_name)
            LOGGER.error(order)
            return

        # Building map
        map_object = Map(self.map_name)
        convoy_path = []

        # Deciding if this move is doable by convoy or not
        if unit_type != 'Army':
            via_flag = ''
        else:
            # Any plausible convoy path (i.e. where fleets are on water, even though they are not convoying)
            # is valid for the 'convoyPath' argument
            reachable_by_land = map_object.abuts('A', loc_name, '-', to_loc_name)
            via_convoy = words[-1] == 'VIA' or not reachable_by_land
            via_flag = ' VIA' if via_convoy else ''
            convoy_path = find_convoy_path(loc_name, to_loc_name, map_object, game)

        self.order_str = '%s %s - %s%s' % (short_unit_type, loc_name, to_loc_name, via_flag)
        self.order_dict = {'terrID': terr_id,
                           'unitType': unit_type,
                           'type': 'Move',
                           'toTerrID': to_terr_id,
                           'fromTerrID': '',
                           'viaConvoy': 'Yes' if via_flag else 'No'}
        if convoy_path:
            self.order_dict['convoyPath'] = [loc_to_ix[loc] for loc in convoy_path[:-1]]

    # --- Support hold ---
    # {"id": "73", "unitID": "16", "type": "Support hold", "toTerrID": "24", "fromTerrID": "", "viaConvoy": ""}
    elif order_type == 'S' and '-' not in words:
        if len(words) < 5:
            LOGGER.error('[Support H] Unable to parse the support hold order "%s". Require at least 5 words', order)
            LOGGER.error(order)
            return

        # Getting supported unit (without coast)
        to_loc_name = words[4][:3]
        to_terr_id = loc_to_ix.get(to_loc_name, None)
        if to_terr_id is None:
            LOGGER.error('[Support H] Received invalid to loc "%s" for map "%s".', to_loc_name, self.map_name)
            LOGGER.error(order)
            return

        self.order_str = '%s %s S %s' % (short_unit_type, loc_name, to_loc_name)
        self.order_dict = {'terrID': terr_id,
                           'unitType': unit_type,
                           'type': 'Support hold',
                           'toTerrID': to_terr_id,
                           'fromTerrID': '',
                           'viaConvoy': ''}

    # --- Support move ---
    # {"id": "73", "unitID": "16", "type": "Support move", "toTerrID": "24", "fromTerrID": "69", "viaConvoy": ""}
    elif order_type == 'S':
        if len(words) < 6:
            LOGGER.error('Unable to parse the support move order "%s". Require at least 6 words', order)
            return

        # Getting supported unit - '-' is known to be in words here; it must be followed by a destination
        move_index = words.index('-')
        if move_index + 1 >= len(words):
            LOGGER.error('[Support M] Unable to parse the support move order "%s". Missing destination', order)
            return
        to_loc_name = words[move_index + 1][:3]                         # Removing coast from dest
        from_loc_name = words[move_index - 1]
        to_terr_id = loc_to_ix.get(to_loc_name, None)
        from_terr_id = loc_to_ix.get(from_loc_name, None)

        if to_terr_id is None:
            LOGGER.error('[Support M] Received invalid to loc "%s" for map "%s".', to_loc_name, self.map_name)
            LOGGER.error(order)
            return
        if from_terr_id is None:
            LOGGER.error('[Support M] Received invalid from loc "%s" for map "%s".', from_loc_name, self.map_name)
            LOGGER.error(order)
            return

        # Building map
        map_object = Map(self.map_name)
        convoy_path = []

        # Deciding if we are support a move by convoy or not
        # Any plausible convoy path (i.e. where fleets are on water, even though they are not convoying)
        # is valid for the 'convoyPath' argument, only if it does not include the fleet issuing the support
        if words[move_index - 2] != 'F' and map_object.area_type(from_loc_name) == 'COAST':
            convoy_path = find_convoy_path(from_loc_name, to_loc_name, map_object, game, excluding=loc_name)

        self.order_str = '%s %s S %s - %s' % (short_unit_type, loc_name, from_loc_name, to_loc_name)
        self.order_dict = {'terrID': terr_id,
                           'unitType': unit_type,
                           'type': 'Support move',
                           'toTerrID': to_terr_id,
                           'fromTerrID': from_terr_id,
                           'viaConvoy': ''}
        if convoy_path:
            self.order_dict['convoyPath'] = [loc_to_ix[loc] for loc in convoy_path[:-1]]

    # --- Convoy ---
    # {"id": "79", "unitID": "22", "type": "Convoy", "toTerrID": "24", "fromTerrID": "20", "viaConvoy": "",
    #   "convoyPath": ["20", "69"]}
    elif order_type == 'C':
        if len(words) < 6:
            LOGGER.error('[Convoy] Unable to parse the convoy order "%s". Require at least 6 words', order)
            LOGGER.error(order)
            return

        # Getting convoyed unit - '-' must be present and followed by a destination
        move_index = words.index('-') if '-' in words else len(words)
        if move_index + 1 >= len(words):
            LOGGER.error('[Convoy] Unable to parse the convoy order "%s". Expected "<src> - <dest>"', order)
            return
        to_loc_name = words[move_index + 1]
        from_loc_name = words[move_index - 1]
        to_terr_id = loc_to_ix.get(to_loc_name, None)
        from_terr_id = loc_to_ix.get(from_loc_name, None)

        if to_terr_id is None:
            LOGGER.error('[Convoy] Received invalid to loc "%s" for map "%s".', to_loc_name, self.map_name)
            LOGGER.error(order)
            return
        if from_terr_id is None:
            LOGGER.error('[Convoy] Received invalid from loc "%s" for map "%s".', from_loc_name, self.map_name)
            LOGGER.error(order)
            return

        # Building map
        map_object = Map(self.map_name)

        # Finding convoy path
        # Any plausible convoy path (i.e. where fleets are on water, even though they are not convoying)
        # is valid for the 'convoyPath' argument, only if it includes the current fleet issuing the convoy order
        convoy_path = find_convoy_path(from_loc_name, to_loc_name, map_object, game, including=loc_name)

        self.order_str = '%s %s C A %s - %s' % (short_unit_type, loc_name, from_loc_name, to_loc_name)
        self.order_dict = {'terrID': terr_id,
                           'unitType': unit_type,
                           'type': 'Convoy',
                           'toTerrID': to_terr_id,
                           'fromTerrID': from_terr_id,
                           'viaConvoy': ''}
        if convoy_path:
            self.order_dict['convoyPath'] = [loc_to_ix[loc] for loc in convoy_path[:-1]]

    # --- Retreat ---
    # {"id": "152", "unitID": "18", "type": "Retreat", "toTerrID": "75", "fromTerrID": "", "viaConvoy": ""}
    elif order_type == 'R':
        if len(words) < 4:
            LOGGER.error('[Retreat] Unable to parse the move order "%s". Require at least 4 words', order)
            LOGGER.error(order)
            return

        # Getting destination
        to_loc_name = words[3]
        to_terr_id = loc_to_ix.get(to_loc_name, None)
        if to_terr_id is None:
            LOGGER.error('[Retreat] Received invalid to loc "%s" for map "%s".', to_loc_name, self.map_name)
            return

        self.order_str = '%s %s R %s' % (short_unit_type, loc_name, to_loc_name)
        self.order_dict = {'terrID': terr_id,
                           'unitType': unit_type,
                           'type': 'Retreat',
                           'toTerrID': to_terr_id,
                           'fromTerrID': '',
                           'viaConvoy': ''}

    # --- Disband (R phase) ---
    # {"id": "152", "unitID": "18", "type": "Disband", "toTerrID": "", "fromTerrID": "", "viaConvoy": ""}
    elif order_type == 'D' and self.phase_type == 'R':
        # Note: For R phase, we disband with the coast
        self.order_str = '%s %s D' % (short_unit_type, loc_name)
        self.order_dict = {'terrID': terr_id,
                           'unitType': unit_type,
                           'type': 'Disband',
                           'toTerrID': '',
                           'fromTerrID': '',
                           'viaConvoy': ''}

    # --- Build Army ---
    # [{"id": "56", "unitID": null, "type": "Build Army", "toTerrID": "37", "fromTerrID": "", "viaConvoy": ""}]
    elif order_type == 'B' and short_unit_type == 'A':
        self.order_str = 'A %s B' % loc_name
        self.order_dict = {'terrID': terr_id,
                           'unitType': 'Army',
                           'type': 'Build Army',
                           'toTerrID': terr_id,
                           'fromTerrID': '',
                           'viaConvoy': ''}

    # -- Build Fleet ---
    # [{"id": "56", "unitID": null, "type": "Build Fleet", "toTerrID": "37", "fromTerrID": "", "viaConvoy": ""}]
    elif order_type == 'B' and short_unit_type == 'F':
        self.order_str = 'F %s B' % loc_name
        self.order_dict = {'terrID': terr_id,
                           'unitType': 'Fleet',
                           'type': 'Build Fleet',
                           'toTerrID': terr_id,
                           'fromTerrID': '',
                           'viaConvoy': ''}

    # Disband (A phase)
    # {"id": "152", "unitID": null, "type": "Destroy", "toTerrID": "18", "fromTerrID": "", "viaConvoy": ""}
    elif order_type == 'D':
        # For A phase, we disband without the coast
        loc_name = loc_name[:3]
        terr_id = loc_to_ix[loc_name]
        self.order_str = '%s %s D' % (short_unit_type, loc_name)
        self.order_dict = {'terrID': terr_id,
                           'unitType': unit_type,
                           'type': 'Destroy',
                           'toTerrID': terr_id,
                           'fromTerrID': '',
                           'viaConvoy': ''}
def get_order_vocabulary():
    """ Computes the list of all valid orders on the standard map

        The list starts with the special tokens (PAD, GO, EOS, DRAW), one '<POWER>' token per power and 'WAIVE',
        followed by the orders of each category (see `categories`), each category sorted by location and then
        alphabetically. An order appears only once (at its first occurrence).

        :return: A sorted list of all valid orders on the standard map
    """
    # pylint: disable=too-many-nested-blocks,too-many-branches
    categories = ['H', 'D', 'B', '-', 'R', 'SH', 'S-',
                  '-1', 'S1', 'C1',                 # Move, Support, Convoy (using 1 fleet)
                  '-2', 'S2', 'C2',                 # Move, Support, Convoy (using 2 fleets)
                  '-3', 'S3', 'C3',                 # Move, Support, Convoy (using 3 fleets)
                  '-4', 'S4', 'C4']                 # Move, Support, Convoy (using 4 fleets)

    orders = {category: set() for category in categories}
    map_object = Map()
    locs = sorted([loc.upper() for loc in map_object.locs])

    # All holds, builds, and disbands orders
    for loc in locs:
        for unit_type in ['A', 'F']:
            if map_object.is_valid_unit('%s %s' % (unit_type, loc)):
                orders['H'].add('%s %s H' % (unit_type, loc))
                orders['D'].add('%s %s D' % (unit_type, loc))

                # Allowing builds in all SCs (even though only homes will likely be used)
                if loc[:3] in map_object.scs:
                    orders['B'].add('%s %s B' % (unit_type, loc))

    # Moves, Retreats, Support Holds
    for unit_loc in locs:
        for dest in [loc.upper() for loc in map_object.abut_list(unit_loc, incl_no_coast=True)]:
            for unit_type in ['A', 'F']:
                if not map_object.is_valid_unit('%s %s' % (unit_type, unit_loc)):
                    continue

                if map_object.abuts(unit_type, unit_loc, '-', dest):
                    orders['-'].add('%s %s - %s' % (unit_type, unit_loc, dest))
                    orders['R'].add('%s %s R %s' % (unit_type, unit_loc, dest))

                # Making sure we can support destination
                if not (map_object.abuts(unit_type, unit_loc, 'S', dest)
                        or map_object.abuts(unit_type, unit_loc, 'S', dest[:3])):
                    continue

                # Support Hold
                for dest_unit_type in ['A', 'F']:
                    for coast in ['', '/NC', '/SC', '/EC', '/WC']:
                        if map_object.is_valid_unit('%s %s%s' % (dest_unit_type, dest, coast)):
                            orders['SH'].add('%s %s S %s %s%s' % (unit_type, unit_loc, dest_unit_type, dest, coast))

    # Convoys, Move Via
    for nb_fleets in map_object.convoy_paths:

        # Skipping long-term convoys
        if nb_fleets > 4:
            continue

        for start, fleets, dests in map_object.convoy_paths[nb_fleets]:
            for end in dests:
                orders['-%d' % nb_fleets].add('A %s - %s VIA' % (start, end))
                orders['-%d' % nb_fleets].add('A %s - %s VIA' % (end, start))
                for fleet_loc in fleets:
                    orders['C%d' % nb_fleets].add('F %s C A %s - %s' % (fleet_loc, start, end))
                    orders['C%d' % nb_fleets].add('F %s C A %s - %s' % (fleet_loc, end, start))

    # Support Move (Non-Convoyed)
    for start_loc in locs:
        for dest_loc in [loc.upper() for loc in map_object.abut_list(start_loc, incl_no_coast=True)]:
            for support_loc in (map_object.abut_list(dest_loc, incl_no_coast=True)
                                + map_object.abut_list(dest_loc[:3], incl_no_coast=True)):
                support_loc = support_loc.upper()

                # A unit cannot support itself
                if support_loc[:3] == start_loc[:3]:
                    continue

                # Making sure the src unit can move to dest
                # and the support unit can also support to dest
                for src_unit_type in ['A', 'F']:
                    for support_unit_type in ['A', 'F']:
                        if (map_object.abuts(src_unit_type, start_loc, '-', dest_loc)
                                and map_object.abuts(support_unit_type, support_loc, 'S', dest_loc[:3])
                                and map_object.is_valid_unit('%s %s' % (src_unit_type, start_loc))
                                and map_object.is_valid_unit('%s %s' % (support_unit_type, support_loc))):
                            orders['S-'].add('%s %s S %s %s - %s' % (support_unit_type,
                                                                     support_loc,
                                                                     src_unit_type,
                                                                     start_loc,
                                                                     dest_loc[:3]))

    # Support Move (Convoyed)
    for nb_fleets in map_object.convoy_paths:

        # Skipping long-term convoys
        if nb_fleets > 4:
            continue

        for start_loc, fleets, ends in map_object.convoy_paths[nb_fleets]:
            for dest_loc in ends:
                for support_loc in map_object.abut_list(dest_loc, incl_no_coast=True):
                    support_loc = support_loc.upper()

                    # A unit cannot support itself
                    if support_loc[:3] == start_loc[:3]:
                        continue

                    # A fleet cannot support if it convoys
                    if support_loc in fleets:
                        continue

                    # Making sure the support unit can also support to dest
                    # And that the support unit is not convoying
                    for support_unit_type in ['A', 'F']:
                        if (map_object.abuts(support_unit_type, support_loc, 'S', dest_loc)
                                and map_object.is_valid_unit('%s %s' % (support_unit_type, support_loc))):
                            orders['S%d' % nb_fleets].add('%s %s S A %s - %s' % (support_unit_type,
                                                                                 support_loc,
                                                                                 start_loc,
                                                                                 dest_loc[:3]))

    # Building the list of final orders
    final_orders = [PAD_TOKEN, GO_TOKEN, EOS_TOKEN, DRAW_TOKEN]
    final_orders += ['<%s>' % power_name for power_name in get_map_powers(map_object)]
    final_orders += ['WAIVE']

    # Sorting each category
    # `already_added` mirrors `final_orders` as a set, so the duplicate check is O(1)
    # instead of a linear scan of a list that grows to the full vocabulary size.
    already_added = set(final_orders)
    for category in categories:
        category_orders = [order for order in orders[category] if order not in already_added]
        final_orders += sorted(category_orders,
                               key=lambda value: (value.split()[1],         # Sorting by loc
                                                  value))                   # Then alphabetically
        already_added.update(category_orders)
    return final_orders
def get_power_vocabulary():
    """ Computes a sorted list of powers in the standard map
        :return: A list of the power names of the standard map, sorted alphabetically
    """
    standard_map = Map()
    # sorted() already returns a new list - no intermediate comprehension needed
    return sorted(standard_map.powers)
def build():
    """ Building the hdf5 dataset

        Extracts the zip dataset (unless the extraction directory already exists), converts the games of every
        '*.jsonl' file with `process_game` in a multiprocessing pool, and writes:
          - each game (zlib-compressed proto) to the HDF5 file at DATASET_PATH, keyed by game id
          - each game proto to PROTO_DATASET_PATH
          - pickled side tables: dataset index (category -> set of game ids), end supply centers,
            zobrist hash table, move counts, and number of phases per game
        The extraction directory is deleted at the end.

        :raises RuntimeError: If the zip dataset cannot be found at ZIP_DATASET_PATH.
    """
    if not os.path.exists(ZIP_DATASET_PATH):
        raise RuntimeError('Unable to find the zip dataset at %s' % ZIP_DATASET_PATH)

    # Extracting
    extract_dir = os.path.join(os.path.dirname(ZIP_DATASET_PATH), 'zip_dataset')
    if not os.path.exists(extract_dir):
        LOGGER.info('... Extracting files from zip dataset.')
        with zipfile.ZipFile(ZIP_DATASET_PATH, 'r') as zip_dataset:
            zip_dataset.extractall(extract_dir)

    # Additional information we also want to store
    map_object = Map()
    all_powers = get_map_powers(map_object)
    sc_to_win = len(map_object.scs) // 2 + 1

    hash_table = {}                 # zobrist_hash: [{game_id}/{phase_name}]
    moves = {}                      # Moves frequency: {move: [nb_no_press, nb_press]}
    nb_phases = OrderedDict()       # Nb of phases per game
    end_scs = {'press': {power_name: {nb_sc: [] for nb_sc in range(0, sc_to_win + 1)} for power_name in all_powers},
               'no_press': {power_name: {nb_sc: [] for nb_sc in range(0, sc_to_win + 1)} for power_name in all_powers}}

    # Building
    dataset_index = {}
    LOGGER.info('... Building HDF5 dataset.')
    with multiprocessing.Pool() as pool:
        with h5py.File(DATASET_PATH, 'w') as hdf5_dataset, open(PROTO_DATASET_PATH, 'wb') as proto_dataset:

            # os.path.join / os.path.basename keep path handling portable (e.g. '\\' separators on Windows)
            for json_file_path in glob.glob(os.path.join(extract_dir, '*.jsonl')):
                LOGGER.info('... Processing: %s', json_file_path)
                category = os.path.basename(json_file_path).split('.')[0]
                dataset_index[category] = set()

                # Reading the whole file - JSON text is UTF-8, so don't rely on the platform default encoding
                with open(json_file_path, 'r', encoding='utf-8') as json_file:
                    lines = json_file.read().splitlines()

                # Processing file using pool
                for game_id, saved_game_zlib in tqdm(pool.imap_unordered(process_game, lines), total=len(lines)):
                    if game_id is None:
                        continue
                    saved_game_proto = zlib_to_proto(saved_game_zlib, SavedGameProto)

                    # Saving to disk
                    hdf5_dataset[game_id] = np.void(saved_game_zlib)
                    write_proto_to_file(proto_dataset, saved_game_proto, compressed=False)
                    dataset_index[category].add(game_id)

                    # Recording additional info
                    get_end_scs_info(saved_game_proto, game_id, all_powers, sc_to_win, end_scs)
                    get_moves_info(saved_game_proto, moves)
                    nb_phases[game_id] = len(saved_game_proto.phases)

                    # Recording hash of each phase
                    for phase in saved_game_proto.phases:
                        hash_table.setdefault(phase.state.zobrist_hash, []).append('%s/%s' % (game_id, phase.name))

    # Storing info to disk
    for output_path, data in ((DATASET_INDEX_PATH, dataset_index),
                              (END_SCS_DATASET_PATH, end_scs),
                              (HASH_DATASET_PATH, hash_table),
                              (MOVES_COUNT_DATASET_PATH, moves),
                              (PHASES_COUNT_DATASET_PATH, nb_phases)):
        with open(output_path, 'wb') as file:
            pickle.dump(data, file, pickle.HIGHEST_PROTOCOL)

    # Deleting extract_dir
    LOGGER.info('... Deleting extracted files.')
    if os.path.exists(extract_dir):
        shutil.rmtree(extract_dir, ignore_errors=True)
    LOGGER.info('... Done building HDF5 dataset.')
def generate_daide_game(players, progress_bar, daide_rules):
    """ Generate a game

        Plays a single game between `players` on the module-level `_server`, then writes the saved game
        to 'game_{id}.json' in the current working directory.

        Players are assigned to powers in a random order. For each (player, channel) pair:
          - if `channel` is truthy, the game is joined through it and the player is wrapped in a `ClientWrapper`
            (this becomes the "regular" client whose channel game is watched);
          - if `player` is a `DaidePlayerPlaceHolder`, a DAIDE client process is launched on a port taken from
            `PORTS_POOL` and wrapped in a `DaideWrapper`;
          - otherwise, if `player` is a `Player`, it is a local player: its orders are requested here and set
            directly on the server game.

        This function uses `yield`; presumably it is run as a tornado coroutine (decorator not visible
        here - confirm).

        :param players: A list of (player, channel) pairs - presumably at most one per power, since
                        `power_names[idx]` is indexed with `range(len(players))`.
        :param progress_bar: An object whose `update()` method is called once when the game is done.
        :param daide_rules: The rules passed to the `ServerGame`.
        :return: The saved game dict (from `to_saved_game_format`) augmented with 'players_names',
                 'assigned_powers' and 'ranking' if the game is completed or reached `max_year`, otherwise None.
        :raises RuntimeError: If the DAIDE clients are not all connected after 30 attempts (10s apart).
                              Errors raised inside the game loop are printed and swallowed (see below).
    """
    global _server

    # Games are stopped once this year is reached
    max_number_of_year = 35
    max_year = Map().first_year + max_number_of_year

    # Randomly assigning powers to players
    players_ordering = list(range(len(players)))
    shuffle(players_ordering)
    power_names = Map().powers
    power_names = [power_names[idx] for idx in players_ordering]

    clients = {power_name: player for power_name, player in zip(power_names, players)}

    nb_daide_players = len([_ for _, (player, _) in clients.items()
                            if isinstance(player, DaidePlayerPlaceHolder)])
    # At most one non-DAIDE controller slot; presumably n_controls is the number of controllers
    # the server game waits for before starting - confirm
    nb_regular_players = min(1, len(power_names) - nb_daide_players)

    server_game = ServerGame(n_controls=nb_daide_players + nb_regular_players, rules=daide_rules)
    server_game.server = _server

    _server.add_new_game(server_game)

    reg_power_name, reg_client = None, None

    # Connecting the channel client (if any) and marking local players as controllers
    for power_name, (player, channel) in clients.items():
        if channel:
            game = yield channel.join_game(game_id=server_game.game_id, power_name=power_name)
            reg_power_name, reg_client = power_name, ClientWrapper(player, game)
            clients[power_name] = reg_client
        elif isinstance(player, Player):
            server_game.get_power(power_name).set_controlled(player.name)

    # Starting the DAIDE server and one DAIDE client process per DAIDE player
    if nb_daide_players:
        # NOTE(review): the port is moved from PORTS_POOL to OPEN_PORTS and not returned to the pool
        # in this function - confirm it is released elsewhere.
        server_port = PORTS_POOL.pop(0)
        OPEN_PORTS.append(server_port)
        _server.start_new_daide_server(server_game.game_id, port=server_port)
        yield gen.sleep(1)

        for power_name, (player, _) in clients.items():
            if isinstance(player, DaidePlayerPlaceHolder):
                process = launch_daide_client(server_port)
                clients[power_name] = DaideWrapper(player, process)

        # Waiting up to 30 x 10s (5 minutes) for every power to be controlled
        for attempt_idx in range(30):
            if server_game.count_controlled_powers() == len(power_names):
                break
            yield gen.sleep(10)
            LOGGER.info('Waiting for DAIDE to connect. - Attempt %d / %d', attempt_idx + 1, 30)
        else:
            LOGGER.error('DAIDE is not online after 5 minutes. Aborting.')
            raise RuntimeError()

    # Local players are switched to DUMMY control; their orders are set directly on server_game below
    for power_name, (player, game) in clients.items():
        if not game and isinstance(player, Player):
            server_game.get_power(power_name).set_controlled(strings.DUMMY)

    if server_game.game_can_start():
        _server.start_game(server_game)

    local_powers = [power_name for power_name, (player, game) in clients.items()
                    if not game and isinstance(player, Player)]

    # Sentinel entry so that max(elimination_orders.values()) works before any power is eliminated
    elimination_orders = {"_0": 0}

    yield gen.sleep(get_unsync_wait())

    try:
        # The regular client's channel game is watched when there is one, otherwise the server game
        watched_game = reg_client.channel_game if reg_client else server_game
        phase = PhaseSplitter(watched_game.get_current_phase())
        while watched_game.status != strings.COMPLETED and phase.year < max_year:
            print('\n=== NEW PHASE ===\n')
            print(watched_game.get_current_phase())

            if reg_client:
                yield reg_client.channel_game.wait()

            # Getting orders of all local players
            # zip() below relies on `clients` iterating in the same order as when `local_powers` was built
            players_orders = yield [player.get_orders(server_game, power_name)
                                    for power_name, (player, _) in clients.items() if power_name in local_powers]

            for power_name, orders in zip(local_powers, players_orders):
                if phase.type == 'R':
                    orders = [order.replace(' - ', ' R ') for order in orders]
                orders = [order for order in orders if order != 'WAIVE']
                server_game.set_orders(power_name, orders, expand=False)

            # Sending the regular client's orders through its channel until the server sees them as set
            while reg_client and not server_game.get_power(reg_power_name).order_is_set:
                orders = yield reg_client.player.get_orders(server_game, reg_power_name)
                print('Sending orders')
                yield reg_client.channel_game.set_orders(orders=orders)
                print('All orders sent')

            if reg_client:
                yield reg_client.channel_game.no_wait()

            # Waiting up to 120 x 2.5s for the phase to be processed
            for attempt_idx in range(120):
                if phase.input_str != watched_game.get_current_phase() or server_game.status != strings.ACTIVE:
                    break
                # Every 12 attempts, fall back to watching the server game if only the server has moved on
                if (attempt_idx + 1) % 12 == 0 and phase.input_str != server_game.get_current_phase():
                    # Watched game is unsynched
                    watched_game = server_game
                    break
                LOGGER.info('Waiting for the phase to be processed. - Attempt %d / %d', attempt_idx + 1, 120)
                yield gen.sleep(2.5)
            else:
                LOGGER.error('Phase is taking too long to process. Aborting.')
                raise RuntimeError()

            # Recording the elimination order of powers that no longer have units
            elimination_order = max(elimination_orders.values()) + 1
            for power_name in power_names:
                if power_name not in elimination_orders and not server_game.get_power(power_name).units:
                    elimination_orders[power_name] = elimination_order
                    # The regular client was eliminated - watching the server game instead
                    if power_name == reg_power_name:
                        watched_game = server_game

            if server_game.status != strings.ACTIVE:
                break
            phase = PhaseSplitter(watched_game.get_current_phase())

    # Errors in the game loop are printed and swallowed - the (possibly partial) game is still saved below
    except TimeoutError as timeout:
        print('Timeout: ', timeout)
    except Exception as exception:
        print('Exception: ', exception)
    finally:
        # Shutting down the DAIDE server and killing the DAIDE client processes
        _server.stop_daide_server(server_game.game_id)
        yield gen.sleep(1)
        for power_name, (player, _) in clients.items():
            if isinstance(player, DaidePlayerPlaceHolder):
                process = clients[power_name].process
                process.kill()

        if reg_client:
            reg_client.channel_game.leave()

    game = None
    saved_game = to_saved_game_format(server_game)

    # Only games that are completed (or reached max_year) are returned with their ranking
    if server_game.status == strings.COMPLETED or PhaseSplitter(server_game.get_current_phase()).year >= max_year:
        elimination_orders = [elimination_orders.get(power_name, 0) for power_name in power_names]
        nb_centers = [len(server_game.get_power(power_name).centers) for power_name in power_names]
        game = saved_game
        game['players_names'] = [player.name for player, _ in players]
        game['assigned_powers'] = power_names
        game['ranking'] = compute_ranking(power_names, nb_centers, elimination_orders)

    with open('game_{}.json'.format(saved_game['id']), 'w') as file:
        json.dump(saved_game, file)

    progress_bar.update()

    return game
def get_policy_data(saved_game_proto, power_names, top_victors):
    """ Computes the proto to save in tf.train.Example as a training example for the policy network

        For every phase except the last one and every requested power, produces either:
          - a "blank" example ('board_state', 'prev_orders_state', 'draw_target') when no policy should be
            learned (power not a top victor, only Hold orders, no usable orders, ...), or
          - a full example that also contains 'board_alignments', 'decoder_inputs', 'decoder_lengths'
            and 'decoder_mask_indices'.
        'draw_target' is 1. only on the second-to-last phase, for surviving powers of a game that ended
        without a solo winner and with at least 2 survivors.

        :param saved_game_proto: A `.proto.game.SavedGame` object from the dataset.
        :param power_names: The list of powers for which we want the policy data
        :param top_victors: The list of powers that ended with more than 25% of the supply centers
        :return: A dictionary with key: the phase_ix
                 with value: A dict with the power_name as key and a dict with the example fields as value
    """
    # pylint: disable=too-many-branches
    nb_phases = len(saved_game_proto.phases)
    policy_data = {phase_ix: {} for phase_ix in range(nb_phases - 1)}
    game_id = saved_game_proto.id
    map_object = Map(saved_game_proto.map)

    # Determining if we have a draw
    # (no power owns a majority of the supply centers in the last phase, and at least 2 powers still own centers)
    nb_sc_to_win = len(map_object.scs) // 2 + 1
    has_solo_winner = max([len(saved_game_proto.phases[-1].state.centers[power_name].value)
                           for power_name in saved_game_proto.phases[-1].state.centers]) >= nb_sc_to_win
    survivors = [power_name for power_name in saved_game_proto.phases[-1].state.centers
                 if saved_game_proto.phases[-1].state.centers[power_name].value]
    has_draw = not has_solo_winner and len(survivors) >= 2

    # Processing all phases (except the last one)
    # NOTE(review): current_year is incremented below but never read in this function.
    current_year = 0
    for phase_ix in range(nb_phases - 1):

        # Building a list of orders of previous phases
        # Left-padded with zero arrays, then only the last NB_PREV_ORDERS movement ('M') phases are kept
        previous_orders_states = [np.zeros((NB_NODES, NB_ORDERS_FEATURES), dtype=np.uint8)] * NB_PREV_ORDERS
        for phase_proto in saved_game_proto.phases[max(0, phase_ix - NB_PREV_ORDERS_HISTORY):phase_ix]:
            if phase_proto.name[-1] == 'M':
                previous_orders_states += [proto_to_prev_orders_state(phase_proto, map_object)]
        previous_orders_states = previous_orders_states[-NB_PREV_ORDERS:]
        prev_orders_state = np.array(previous_orders_states)

        # Parsing each requested power in the specified phase
        phase_proto = saved_game_proto.phases[phase_ix]
        phase_name = phase_proto.name
        state_proto = phase_proto.state
        phase_board_state = proto_to_board_state(state_proto, map_object)

        # Increasing year for every spring or when the game is completed
        if phase_proto.name == 'COMPLETED' or (phase_proto.name[0] == 'S' and phase_proto.name[-1] == 'M'):
            current_year += 1

        for power_name in power_names:
            phase_issued_orders = get_issued_orders_for_powers(phase_proto, [power_name])
            phase_possible_orders = get_possible_orders_for_powers(phase_proto, [power_name])
            phase_draw_target = 1. if has_draw and phase_ix == (nb_phases - 2) and power_name in survivors else 0.

            # Data to use when not learning a policy
            blank_policy_data = {'board_state': phase_board_state,
                                 'prev_orders_state': prev_orders_state,
                                 'draw_target': phase_draw_target}

            # Power is not a top victor - We don't want to learn a policy from him
            if power_name not in top_victors:
                policy_data[phase_ix][power_name] = blank_policy_data
                continue

            # Finding the orderable locs
            orderable_locations = list(phase_issued_orders[power_name].keys())

            # Skipping power for this phase if we are only issuing Hold
            # (the for/else branch runs only when no order has a second token other than 'H')
            for order_loc, order in phase_issued_orders[power_name].items():
                order_tokens = get_order_tokens(order)
                if len(order_tokens) >= 2 and order_tokens[1] != 'H':
                    break
            else:
                policy_data[phase_ix][power_name] = blank_policy_data
                continue

            # Removing orderable locs where orders are not possible (i.e. NO_CHECK games)
            # A warning is only logged for games without the NO_CHECK rule
            for order_loc, order in phase_issued_orders[power_name].items():
                if order not in phase_possible_orders[order_loc]:
                    if 'NO_CHECK' not in saved_game_proto.rules:
                        LOGGER.warning('%s not in all possible orders. Phase %s - Game %s.', order, phase_name, game_id)
                    orderable_locations.remove(order_loc)

            # Determining if we are in an adjustment phase
            in_adjustment_phase = state_proto.name[-1] == 'A'
            nb_builds = state_proto.builds[power_name].count
            nb_homes = len(state_proto.builds[power_name].homes)

            # WxxxA - We can build units
            # WxxxA - We can disband units
            # Other phase
            if in_adjustment_phase and nb_builds >= 0:
                decoder_length = TOKENS_PER_ORDER * min(nb_builds, nb_homes)
            elif in_adjustment_phase and nb_builds < 0:
                decoder_length = TOKENS_PER_ORDER * abs(nb_builds)
            else:
                decoder_length = TOKENS_PER_ORDER * len(orderable_locations)

            # Not all units were disbanded - Skipping this power as we can't learn the orders properly
            if in_adjustment_phase and nb_builds < 0 and len(orderable_locations) < abs(nb_builds):
                policy_data[phase_ix][power_name] = blank_policy_data
                continue

            # Not enough orderable locations for this power, skipping
            if not orderable_locations or not decoder_length:
                policy_data[phase_ix][power_name] = blank_policy_data
                continue

            # Encoding decoder inputs - Padding each order to 6 tokens
            # The decoder length should be a multiple of 6, since each order is padded to 6 tokens
            # Each order is encoded as its token ids + EOS_ID, right-padded with PAD_ID to TOKENS_PER_ORDER
            decoder_inputs = [GO_ID]
            for loc in orderable_locations[:]:
                order = phase_issued_orders[power_name][loc]
                try:
                    tokens = [token_to_ix(order_token) for order_token in get_order_tokens(order)] + [EOS_ID]
                    tokens += [PAD_ID] * (TOKENS_PER_ORDER - len(tokens))
                    decoder_inputs += tokens
                except KeyError:
                    # NOTE(review): decoder_length was computed before this loop and is not recomputed when a
                    # location is removed here (only the WAIVE padding below compensates, for builds) - confirm.
                    LOGGER.warning('[data_generator] Order "%s" is not valid. Skipping location.', order)
                    orderable_locations.remove(loc)

            # Adding WAIVE orders
            # (fills the remaining build slots when fewer builds were issued than min(nb_builds, nb_homes))
            if in_adjustment_phase and nb_builds > 0:
                waive_tokens = [token_to_ix('WAIVE'), EOS_ID] + [PAD_ID] * (TOKENS_PER_ORDER - 2)
                decoder_inputs += waive_tokens * (min(nb_builds, nb_homes) - len(orderable_locations))
            decoder_length = min(decoder_length, TOKENS_PER_ORDER * NB_SUPPLY_CENTERS)

            # Getting decoder mask
            coords = set()

            # Adjustment phase, we allow all builds / disbands in all positions
            if in_adjustment_phase:
                build_disband_locs = list(get_possible_orders_for_powers(phase_proto, [power_name]).keys())
                phase_board_alignments = get_board_alignments(build_disband_locs,
                                                              in_adjustment_phase=in_adjustment_phase,
                                                              tokens_per_loc=TOKENS_PER_ORDER,
                                                              decoder_length=decoder_length)

                # Building a list of all orders for all locations
                adj_orders = []
                for loc in build_disband_locs:
                    adj_orders += phase_possible_orders[loc]

                # Not learning builds for BUILD_ANY
                if nb_builds > 0 and 'BUILD_ANY' in state_proto.rules:
                    adj_orders = []

                # No orders found - Skipping
                if not adj_orders:
                    policy_data[phase_ix][power_name] = blank_policy_data
                    continue

                # Building a list of coordinates for the decoder mask matrix
                # NOTE(review): decoder_length is a token count (TOKENS_PER_ORDER * nb_orders), so this produces
                # TOKENS_PER_ORDER times more order slots than orders - confirm extra mask entries are ignored.
                for loc_ix in range(decoder_length):
                    coords = get_token_based_mask(adj_orders, offset=loc_ix * TOKENS_PER_ORDER, coords=coords)

            # Regular phase, we mask for each location
            else:
                phase_board_alignments = get_board_alignments(orderable_locations,
                                                              in_adjustment_phase=in_adjustment_phase,
                                                              tokens_per_loc=TOKENS_PER_ORDER,
                                                              decoder_length=decoder_length)
                for loc_ix, loc in enumerate(orderable_locations):
                    coords = get_token_based_mask(phase_possible_orders[loc] or [''],
                                                  offset=loc_ix * TOKENS_PER_ORDER,
                                                  coords=coords)

            # Saving results
            # No need to return temperature, current_power, current_season
            policy_data[phase_ix][power_name] = {'board_state': phase_board_state,
                                                 'board_alignments': phase_board_alignments,
                                                 'prev_orders_state': prev_orders_state,
                                                 'decoder_inputs': decoder_inputs,
                                                 'decoder_lengths': decoder_length,
                                                 'decoder_mask_indices': list(sorted(coords)),
                                                 'draw_target': phase_draw_target}

    # Returning
    return policy_data