예제 #1
0
    def __init__(self, powers_centers, map_name, **kwargs):
        """ Builds the SCO response
            One parenthesized clause per power (sorted by power name): the power token followed by its centers
            (sorted). A final clause starts with the UNO token and lists every center of the map that no power owns.
            :param powers_centers: A dict of {power_name: centers} objects
            :param map_name: The name of the map
        """
        super(SupplyCenterResponse, self).__init__(**kwargs)
        unowned_centers = list(Map(map_name).scs)
        clauses = []

        for name in sorted(powers_centers):
            owned = sorted(powers_centers[name])
            clause = bytes(parse_string(Power, name))
            for owned_center in owned:
                clause += bytes(parse_string(Province, owned_center))
                # Raises ValueError if a center is owned twice or is not a supply center of the map
                unowned_centers.remove(owned_center)
            clauses.append(clause)

        # Whatever was not claimed by a power is listed under UNO
        uno_clause = bytes(tokens.UNO)
        for free_center in unowned_centers:
            uno_clause += bytes(parse_string(Province, free_center))
        clauses.append(uno_clause)

        self._bytes = bytes(tokens.SCO) + b''.join([add_parentheses(clause) for clause in clauses])
예제 #2
0
def get_adjacency_matrix(map_name='standard'):
    """ Computes the adjacency matrix for map
        :param map_name: The name of the map
        :return: A (nb_nodes) x (nb_nodes) boolean matrix over the sorted locs (coasts included).
                 Entry [i, j] is True if an army or a fleet can move from loc i to loc j, or if one loc
                 is the coast of the other.

        The result is cached per map name in ADJACENCY_MATRIX.
    """
    if map_name in ADJACENCY_MATRIX:
        return ADJACENCY_MATRIX[map_name]

    # Finding list of all locations
    current_map = Map(map_name)
    locs = get_sorted_locs(current_map)
    nb_locs = len(locs)

    # `np.bool` was only a deprecated alias of the builtin `bool` (removed in NumPy 1.24)
    adjacencies = np.zeros((nb_locs, nb_locs), dtype=bool)

    # Building adjacencies between locs
    for i, loc_1 in enumerate(locs):
        for j, loc_2 in enumerate(locs):
            if current_map.abuts('A', loc_1, '-', loc_2) or current_map.abuts('F', loc_1, '-', loc_2):
                adjacencies[i, j] = True
            # Coasts are adjacent to their parent location (loc[:3] is the location without its coast)
            if loc_1 != loc_2 and (loc_1[:3] == loc_2 or loc_1 == loc_2[:3]):
                adjacencies[i, j] = True

    # Storing in cache and returning
    ADJACENCY_MATRIX[map_name] = adjacencies
    return adjacencies
예제 #3
0
def get_stats(games):
    """ Computes stats over a collection of played games
        :param games: An iterable of game dicts. Falsy entries (e.g. None or {}) are skipped.
                      Each game dict is read for 'players_names', 'assigned_powers', 'phases' and 'ranking'.
        :return: A tuple (players_names, GamesStats), where players_names is taken from the first non-empty
                 game (None if there is none). GamesStats holds the number of games where the power at index 0
                 of 'assigned_powers' won / had the most centers / survived / was defeated, how many times each
                 power was at that index, and the per-game assigned powers and rankings.
    """
    # Materializing so that looking up the players names does not consume a game from a generator
    games = list(games)
    map_object = Map()
    nb_sc_to_win = len(map_object.scs) // 2 + 1         # Previously hard-coded as 18 (standard map)

    nb_won, nb_most, nb_survived, nb_defeated = 0, 0, 0, 0
    nb_power_assignations = {power_name: 0 for power_name in map_object.powers}
    players_names = next((game['players_names'] for game in games if game), None)
    assigned_powers, rankings = [], []

    for game in games:
        if not game:
            continue

        game_assigned_powers = game['assigned_powers']
        tracked_power = game_assigned_powers[0]
        final_centers = game['phases'][-1]['state']['centers']
        nb_centers = {power_name: len(final_centers[power_name]) for power_name in game_assigned_powers}
        nb_tracked_centers = nb_centers[tracked_power]

        if nb_tracked_centers >= nb_sc_to_win:
            nb_won += 1
        elif nb_tracked_centers == max(nb_centers.values()):
            nb_most += 1
        elif nb_tracked_centers > 0:
            nb_survived += 1
        else:
            nb_defeated += 1

        nb_power_assignations[tracked_power] += 1
        assigned_powers.append(game_assigned_powers)
        rankings.append(game['ranking'])

    return players_names, GamesStats(nb_won, nb_most, nb_survived, nb_defeated,
                                     nb_power_assignations, assigned_powers, rankings)
예제 #4
0
    def get_reward(self, prev_state_proto, state_proto, power_name,
                   is_terminal_state, done_reason):
        """ Computes the reward for a given power
            The reward is the change in the number of supply centers owned by power_name between the two states,
            divided by the number of centers required to win. It is -1 when the game was thrashed and the power
            still owns centers, and 0 when the power is missing from either state.
            :param prev_state_proto: The `.proto.State` representation of the last state of the game (before .process)
            :param state_proto: The `.proto.State` representation of the state of the game (after .process)
            :param power_name: The name of the power for which to calculate the reward (e.g. 'FRANCE').
            :param is_terminal_state: Boolean flag to indicate we are at a terminal state. (Not used here)
            :param done_reason: An instance of DoneReason indicating why the terminal state was reached.
            :return: The current reward (float) for power_name.
            :type done_reason: diplomacy_research.models.gym.environment.DoneReason | None
        """
        assert done_reason is None or isinstance(done_reason, DoneReason), \
            'done_reason must be a DoneReason object.'

        is_missing = power_name not in state_proto.centers or power_name not in prev_state_proto.centers
        if is_missing:
            if power_name not in ALL_POWERS:
                LOGGER.error('Unknown power %s. Expected powers are: %s', power_name, ALL_POWERS)
            return 0.

        win_threshold = len(Map(state_proto.map).scs) // 2 + 1.
        nb_current = len(set(state_proto.centers[power_name].value))
        nb_prev = len(set(prev_state_proto.centers[power_name].value))

        # Thrashing while still holding at least one center is penalized
        if done_reason == DoneReason.THRASHED and nb_current > 0:
            return -1.

        return (nb_current - nb_prev) / win_threshold
예제 #5
0
def proto_to_heat_map(state_proto):
    """ Computes a heat map representation of the state of the game

        Size: 81 x 7    (NB_LOCS x NB_POWERS)
        > +1 is added for each loc where power has a unit and for each location that unit can reach
        > +0.5 is added for each loc where another power has a unit and for each loc that unit can reach

        :param state_proto: A `.proto.game.State` proto of the state of the game.
        :return: A heat map representation (NB_LOCS, NB_POWERS) of the state of the game

        Per-unit heat maps are cached in CACHE, keyed by map name, power name and unit.
        NOTE(review): if no power has any unit, np.sum([], axis=0) returns the scalar 0.0 instead of an array.
    """
    map_object = None       # Only built on the first cache miss

    # Building heat map
    heat_maps = []
    for power_name in state_proto.units:
        for unit in state_proto.units[power_name].value:
            # The unit heat map is built from the map topology, so the map name must be part of the key
            cache_key = 'heat_map/{}/{}/{}'.format(state_proto.map, power_name, unit)
            cache_value = CACHE.get(cache_key, None)

            # Computing heat map for unit and storing in cache
            if cache_value is None:
                if map_object is None:
                    map_object = Map(state_proto.map)
                cache_value = _build_unit_heat_map(map_object, power_name, unit)
                CACHE[cache_key] = cache_value

            # Adding the unit heat map
            heat_maps.append(cache_value)

    # Summing the heat maps and returning
    return np.sum(heat_maps, axis=0)
예제 #6
0
def root_data_generator(saved_game_proto, is_validation_set):
    """ Converts a dataset game to protocol buffer format
        :param saved_game_proto: A `.proto.game.SavedGame` object from the dataset.
        :param is_validation_set: Boolean that indicates if we are generating the validation set (otw. training set)
        :return: A dictionary with phase_ix as key and a dictionary {power_name: (msg_len, proto)} as value.
                 msg_len is always 0 here, and the last phase always maps to an empty dict.
    """
    map_object = Map(saved_game_proto.map)
    top_victors = get_top_victors(saved_game_proto, map_object)
    all_powers = get_map_powers(map_object)
    nb_phases = len(saved_game_proto.phases)
    proto_results = {phase_ix: {} for phase_ix in range(nb_phases)}

    # The policy is learned from every power for training, and only from the top victors for validation
    policy_powers = top_victors if is_validation_set else all_powers
    policy_data = get_policy_data(saved_game_proto, power_names=all_powers, top_victors=policy_powers)
    value_data = get_value_data(saved_game_proto, all_powers)

    # Validation examples are only generated for the top victors
    if is_validation_set:
        example_powers = [power_name for power_name in all_powers if power_name in top_victors]
    else:
        example_powers = all_powers

    for phase_ix in range(nb_phases - 1):
        for power_name in example_powers:
            phase_policy = policy_data[phase_ix][power_name]
            phase_value = value_data[phase_ix][power_name]
            data = {'request_id': DatasetBuilder.get_request_id(saved_game_proto, phase_ix, power_name,
                                                                 is_validation_set),
                    'player_seed': 0,
                    'decoder_inputs': [GO_ID],
                    'noise': 0.,
                    'temperature': 0.,
                    'dropout_rate': 0.,
                    'current_power': POWER_VOCABULARY_KEY_TO_IX[power_name],
                    'current_season': get_current_season(saved_game_proto.phases[phase_ix].state)}

            # Policy fields, then value fields, override the defaults above on key collisions
            data.update(phase_policy)
            data.update(phase_value)

            example = BaseDatasetBuilder.build_example(data, BaseDatasetBuilder.get_proto_fields())
            proto_results[phase_ix][power_name] = (0, example)

    return proto_results
예제 #7
0
def add_cached_states_to_saved_game(saved_game):
    """ Adds a cached representation of board_state and prev_orders_state to the saved game
        Only games whose map name starts with 'standard' are modified. The dict is updated in place:
            - every phase gets phase['state']['board_state']
            - phases whose name ends with 'M' also get phase['prev_orders_state']
        :param saved_game: A saved game dict (with 'map' and 'phases' keys)
        :return: The same saved game dict
    """
    if not saved_game['map'].startswith('standard'):
        return saved_game

    map_object = Map(saved_game['map'])
    for phase in saved_game['phases']:
        phase_state = phase['state']
        phase_state['board_state'] = dict_to_flatten_board_state(phase_state, map_object)
        is_movement_phase = phase['name'][-1] == 'M'
        if is_movement_phase:
            phase['prev_orders_state'] = dict_to_flatten_prev_orders_state(phase, map_object)
    return saved_game
예제 #8
0
def get_vocabulary():
    """ Returns the list of words in the dictionary
        Order: utility tokens, '<POWER>' tokens, order tokens (excl. '-' and 'R'), '- LOC' tokens,
        then the valid army units and the valid fleet units (locations sorted and upper-cased).
        :return: The list of words in the dictionary
    """
    map_object = Map()
    locs = sorted(loc.upper() for loc in map_object.locs)

    utility_tokens = [PAD_TOKEN, GO_TOKEN, EOS_TOKEN, DRAW_TOKEN]
    power_tokens = ['<' + power_name + '>' for power_name in get_map_powers(map_object)]
    order_tokens = ['B', 'C', 'D', 'H', 'S', 'VIA', 'WAIVE']
    move_tokens = ['- ' + loc for loc in locs]
    army_tokens = [unit for unit in ('A ' + loc for loc in locs) if map_object.is_valid_unit(unit)]
    fleet_tokens = [unit for unit in ('F ' + loc for loc in locs) if map_object.is_valid_unit(unit)]

    return utility_tokens + power_tokens + order_tokens + move_tokens + army_tokens + fleet_tokens
예제 #9
0
def show_opening_moves_data(proto_dataset_path):
    """ Displays a list of opening moves for each power on the standard map
        Only games whose map name starts with 'standard' are considered. For each power, the orders issued in
        the first phase (sorted by location) are counted and printed by decreasing frequency.
        :param proto_dataset_path: The path to the proto dataset
        :return: Nothing
        :raises RuntimeError: If no file exists at proto_dataset_path.
    """
    if not os.path.exists(proto_dataset_path):
        raise RuntimeError('Unable to find Diplomacy dataset at {}'.format(proto_dataset_path))

    # Openings dict - {power_name: {sorted_orders_tuple: count}}
    map_object = Map('standard')
    openings = {power_name: {} for power_name in get_map_powers(map_object)}

    # Loading the phases count dataset to get the number of games (only used to size the progress bar)
    total = None
    if os.path.exists(PHASES_COUNT_DATASET_PATH):
        with open(PHASES_COUNT_DATASET_PATH, 'rb') as file:
            total = len(pickle.load(file))
    progress_bar = tqdm(total=total)

    # Loading dataset and building database
    LOGGER.info('... Building an opening move database.')
    try:
        # Reading the path that was validated above (the argument), not a global default path
        with open(proto_dataset_path, 'rb') as proto_dataset:
            while True:
                saved_game_proto = read_next_proto(SavedGameProto, proto_dataset, compressed=False)
                if saved_game_proto is None:
                    break
                progress_bar.update(1)

                # Only keeping games with the standard map (or its variations)
                if not saved_game_proto.map.startswith('standard'):
                    continue

                initial_phase = saved_game_proto.phases[0]
                for power_name in initial_phase.orders:
                    orders = initial_phase.orders[power_name].value
                    orders = tuple(sorted(orders, key=lambda order: order.split()[1]))      # Sorted by location
                    # A variant may have powers outside the standard list - count them instead of crashing
                    power_openings = openings.setdefault(power_name, {})
                    power_openings[orders] = power_openings.get(orders, 0) + 1
    finally:
        # Always releasing the progress bar, and before printing so the output is not interleaved with it
        progress_bar.close()

    # Printing results
    for power_name in get_map_powers(map_object):
        print('=' * 80)
        print(power_name)
        print()
        for opening, count in sorted(openings[power_name].items(), key=lambda item: item[1], reverse=True):
            print(opening, 'Count:', count)
예제 #10
0
    def get_reward(self, prev_state_proto, state_proto, power_name,
                   is_terminal_state, done_reason):
        """ Computes the reward for a given power
            A share of self.pot_size is only given at a terminal, non-thrashed state:
            - with a solo victor: the victor is credited with exactly the number of centers needed to win, and
              every power gets pot_size * (credited centers / total credited centers)
            - otherwise: the pot is split equally between the powers still owning at least one center
            :param prev_state_proto: The `.proto.State` representation of the last state of the game (before .process)
            :param state_proto: The `.proto.State` representation of the state of the game (after .process)
            :param power_name: The name of the power for which to calculate the reward (e.g. 'FRANCE').
            :param is_terminal_state: Boolean flag to indicate we are at a terminal state.
            :param done_reason: An instance of DoneReason indicating why the terminal state was reached.
            :return: The current reward (float) for power_name.
            :type done_reason: diplomacy_research.models.gym.environment.DoneReason | None
        """
        assert done_reason is None or isinstance(done_reason, DoneReason), \
            'done_reason must be a DoneReason object.'
        if power_name not in state_proto.centers:
            if power_name not in ALL_POWERS:
                LOGGER.error('Unknown power %s. Expected powers are: %s', power_name, ALL_POWERS)
            return 0.

        # Nothing is distributed before the end of the game, nor when the game was thrashed
        if not is_terminal_state or done_reason == DoneReason.THRASHED:
            return 0.

        win_threshold = len(Map(state_proto.map).scs) // 2 + 1.
        nb_scs = {power: len(state_proto.centers[power].value) for power in state_proto.centers}
        winners = [power for power in state_proto.centers if nb_scs[power] >= win_threshold]

        if not winners:
            alive = [power for power in state_proto.centers if nb_scs[power]]
            share = 1. / len(alive) if power_name in alive else 0.
            return self.pot_size * share

        # Centers held by the victor above the winning threshold are not part of the pot
        nb_excess = nb_scs[winners[0]] - win_threshold
        nb_controlled = sum(nb_scs.values()) - nb_excess
        nb_credited = win_threshold if power_name in winners else nb_scs[power_name]
        return self.pot_size * (nb_credited / nb_controlled)
예제 #11
0
def get_alignments_index(map_name='standard'):
    """ Computes a list of nodes index for each possible location
        Locations sharing the same first 3 characters (a province and its coasts) are grouped under that
        3-character key. Keys appear in order of first occurrence; each list of indices is ascending.
        e.g. if the sorted list of locs is ['BRE', 'MAR', 'PAR'] would return {'BRE': [0], 'MAR': [1], 'PAR': [2]}
    """
    sorted_locs = get_sorted_locs(Map(map_name))
    alignments_index = {}
    for node_ix, loc in enumerate(sorted_locs):
        alignments_index.setdefault(loc[:3], []).append(node_ix)
    return alignments_index
예제 #12
0
    def __init__(self, map_name, **kwargs):
        """ Builds the MDF response
            Layout: MDF (powers) (provinces) (adjacencies), where
                - powers:       (power power ...)
                - provinces:    ((supply_centers) (non_supply_centres))
                - adjacencies:  ((prov_adjacencies) (prov_adjacencies) ...)
            :param map_name: The name of the map
        """
        super(MapDefinitionResponse, self).__init__(**kwargs)
        game_map = Map(map_name)

        # Clauses are built in the order in which they appear in the message
        powers = self._build_powers_clause(game_map)
        provinces = self._build_provinces_clause(game_map)
        adjacencies = self._build_adjacencies_clause(game_map)

        header = bytes(tokens.MDF)
        self._bytes = header + powers + provinces + adjacencies
예제 #13
0
    def get_reward(self, prev_state_proto, state_proto, power_name,
                   is_terminal_state, done_reason):
        """ Computes the reward for a given power
            The reward is (nb of supply centers gained) - (nb of supply centers lost) between the two states, where
            in each state a supply center occupied by a non-dislodged unit is counted as owned by the unit's power.
            A thrashed game where the power still owns centers returns -(nb of centers required to win).
            :param prev_state_proto: The `.proto.State` representation of the last state of the game (before .process)
            :param state_proto: The `.proto.State` representation of the state of the game (after .process)
            :param power_name: The name of the power for which to calculate the reward (e.g. 'FRANCE').
            :param is_terminal_state: Boolean flag to indicate we are at a terminal state. (Not used here)
            :param done_reason: An instance of DoneReason indicating why the terminal state was reached.
            :return: The current reward (float) for power_name.
            :type done_reason: diplomacy_research.models.gym.environment.DoneReason | None
        """
        assert done_reason is None or isinstance(
            done_reason,
            DoneReason), 'done_reason must be a DoneReason object.'
        if power_name not in state_proto.centers or power_name not in prev_state_proto.centers:
            if power_name not in ALL_POWERS:
                LOGGER.error('Unknown power %s. Expected powers are: %s',
                             power_name, ALL_POWERS)
            return 0.

        map_object = Map(state_proto.map)
        nb_centers_req_for_win = len(map_object.scs) // 2 + 1.
        current_centers = set(state_proto.centers[power_name].value)
        prev_centers = set(prev_state_proto.centers[power_name].value)
        all_scs = set(map_object.scs)

        # Thrashing is checked on the raw (unadjusted) centers
        if done_reason == DoneReason.THRASHED and current_centers:
            return -1. * nb_centers_req_for_win

        def _adjust_centers_for_units(centers, units_proto):
            """ Updates `centers` (in place) with the supply centers occupied by units in `units_proto`:
                a center occupied by power_name is added, a center occupied by another power is removed.
                Dislodged units (containing '*') don't count for adjustment.
            """
            for unit_power in units_proto:
                is_own_unit = unit_power == power_name
                for unit in units_proto[unit_power].value:
                    if '*' in unit:
                        continue
                    unit_loc = unit[2:5]        # presumably units look like 'A PAR' - loc is chars 2 to 4
                    if unit_loc not in all_scs:
                        continue
                    if is_own_unit:
                        centers.add(unit_loc)
                    else:
                        centers.discard(unit_loc)

        # Adjusting supply centers for the current and the previous phase
        _adjust_centers_for_units(current_centers, state_proto.units)
        _adjust_centers_for_units(prev_centers, prev_state_proto.units)

        # Computing difference
        gained_centers = current_centers - prev_centers
        lost_centers = prev_centers - current_centers

        # Computing reward
        return float(len(gained_centers) - len(lost_centers))
예제 #14
0
    def get_feedable_item(locs, state_proto, power_name, phase_history_proto, possible_orders_proto, **kwargs):
        """ Computes and returns a feedable item (to be fed into the feedable queue)
            :param locs: A list of locations for which we want orders.
                         In adjustment phases, it is replaced by the orderable locs of the power if they differ.
            :param state_proto: A `.proto.game.State` representation of the state of the game.
            :param power_name: The power name for which we want the orders and the state values
            :param phase_history_proto: A list of `.proto.game.PhaseHistory`. This represents prev phases.
            :param possible_orders_proto: A `proto.game.PossibleOrders` object representing possible order for each loc.
                                          If empty, every candidate list is empty.
            :param kwargs: Additional optional kwargs:
                - player_seed: The seed to apply to the player to compute a deterministic mask. (Default to 0)
                - noise: The sigma of the additional noise to apply to the intermediate layers (i.e. sigma * epsilon)
                  (Default to 0.)
                - temperature: The temperature to apply to the logits. (Default to 0. for deterministic/greedy)
                - dropout_rate: The amount of dropout to apply to the inputs/outputs of the decoder. (Default to 0.)
            :return: A feedable item (dict) with feature names as keys. prev_orders_state is a numpy array,
                     decoder_inputs and candidates are lists, the other values come from helpers or kwargs.
        """
        # pylint: disable=too-many-branches
        # Converting to state space
        map_object = Map(state_proto.map)
        board_state = proto_to_board_state(state_proto, map_object)

        # Building the decoder length
        # For adjustment phase, we restrict the number of builds/disbands to what is allowed by the game engine
        # Adjustment phases are the phases whose name ends with 'A'
        in_adjustment_phase = state_proto.name[-1] == 'A'
        nb_builds = state_proto.builds[power_name].count
        nb_homes = len(state_proto.builds[power_name].homes)

        # If we are in adjustment phase, making sure the locs are the orderable locs (and not the policy locs)
        if in_adjustment_phase:
            orderable_locs, _ = get_orderable_locs_for_powers(state_proto, [power_name])
            if sorted(locs) != sorted(orderable_locs):
                if locs:
                    LOGGER.warning('Adj. phase requires orderable locs. Got %s. Expected %s.', locs, orderable_locs)
                locs = orderable_locs

        # WxxxA - We can build units            (nb_builds >= 0 - capped by the number of home centers)
        # WxxxA - We can disband units          (nb_builds < 0 - exactly abs(nb_builds) disbands)
        # Other phase                           (one order per loc)
        if in_adjustment_phase and nb_builds >= 0:
            decoder_length = min(nb_builds, nb_homes)
        elif in_adjustment_phase and nb_builds < 0:
            decoder_length = abs(nb_builds)
        else:
            decoder_length = len(locs)

        # Computing the candidates for the policy
        if possible_orders_proto:

            # Adjustment Phase - Use all possible orders for each location.
            if in_adjustment_phase:

                # Building a list of all orders for all locations
                adj_orders = []
                for loc in locs:
                    adj_orders += possible_orders_proto[loc].value

                # Computing the candidates
                # The same mask (over all orders of all locs) is used for every decoding step
                candidates = [get_order_based_mask(adj_orders)] * decoder_length

            # Regular phase - Compute candidates for each location
            else:
                candidates = []
                for loc in locs:
                    candidates += [get_order_based_mask(possible_orders_proto[loc].value)]

        # We don't have possible orders, so we cannot compute candidates
        # This might be normal if we are only getting the state value or the next message to send
        else:
            candidates = []
            for _ in range(decoder_length):
                candidates.append([])

        # Prev orders state
        # Walks the history from the most recent phase backwards and keeps at most NB_PREV_ORDERS movement
        # phases (names ending with 'M'). Each one is prepended, so the list ends up in chronological order.
        prev_orders_state = []
        for phase_proto in reversed(phase_history_proto):
            if len(prev_orders_state) == NB_PREV_ORDERS:
                break
            if phase_proto.name[-1] == 'M':
                prev_orders_state = [proto_to_prev_orders_state(phase_proto, map_object)] + prev_orders_state
        # Left-padding with all-zero states when fewer than NB_PREV_ORDERS movement phases are available
        for _ in range(NB_PREV_ORDERS - len(prev_orders_state)):
            prev_orders_state = [np.zeros((NB_NODES, NB_ORDERS_FEATURES), dtype=np.uint8)] + prev_orders_state
        prev_orders_state = np.array(prev_orders_state)

        # Building (order) decoder inputs [GO_ID]
        decoder_inputs = [GO_ID]

        # kwargs
        player_seed = kwargs.get('player_seed', 0)
        noise = kwargs.get('noise', 0.)
        temperature = kwargs.get('temperature', 0.)
        dropout_rate = kwargs.get('dropout_rate', 0.)

        # Building feedable data
        item = {
            'player_seed': player_seed,
            'board_state': board_state,
            'board_alignments': get_board_alignments(locs,
                                                     in_adjustment_phase=in_adjustment_phase,
                                                     tokens_per_loc=1,
                                                     decoder_length=decoder_length),
            'prev_orders_state': prev_orders_state,
            'decoder_inputs': decoder_inputs,
            'decoder_lengths': decoder_length,
            'candidates': candidates,
            'noise': noise,
            'temperature': temperature,
            'dropout_rate': dropout_rate,
            'current_power': POWER_VOCABULARY_KEY_TO_IX[power_name],
            'current_season': get_current_season(state_proto)
        }

        # Return
        return item
예제 #15
0
def get_policy_data(saved_game_proto, power_names, top_victors):
    """ Computes the proto to save in tf.train.Example as a training example for the policy network
        :param saved_game_proto: A `.proto.game.SavedGame` object from the dataset.
        :param power_names: The list of powers for which we want the policy data
        :param top_victors: The list of powers that ended with more than 25% of the supply centers
        :return: A dictionary with key: the phase_ix
                              with value: A dict with the power_name as key and a dict with the example fields as value

        Every phase except the last one gets an entry. Powers that are not in `top_victors`, or that are skipped
        for one of the reasons below, only get the "blank" fields: board_state, prev_orders_state and draw_target.
    """
    nb_phases = len(saved_game_proto.phases)
    policy_data = {phase_ix: {} for phase_ix in range(nb_phases - 1)}
    game_id = saved_game_proto.id
    map_object = Map(saved_game_proto.map)

    # Determining if we have a draw
    # i.e. no power reached the solo threshold (half the supply centers + 1) and 2+ powers still own centers
    # NOTE(review): max() raises ValueError if the last phase has no entries in state.centers
    nb_sc_to_win = len(map_object.scs) // 2 + 1
    has_solo_winner = max([len(saved_game_proto.phases[-1].state.centers[power_name].value)
                           for power_name in saved_game_proto.phases[-1].state.centers]) >= nb_sc_to_win
    survivors = [power_name for power_name in saved_game_proto.phases[-1].state.centers
                 if saved_game_proto.phases[-1].state.centers[power_name].value]
    has_draw = not has_solo_winner and len(survivors) >= 2

    # Processing all phases (except the last one)
    # NOTE(review): current_year is incremented below but never read in this function
    current_year = 0
    for phase_ix in range(nb_phases - 1):

        # Building a list of orders of previous phases
        # Starts with NB_PREV_ORDERS zero states, appends the movement phases ('M') found among the last
        # NB_PREV_ORDERS_HISTORY phases in chronological order, then keeps only the NB_PREV_ORDERS most recent
        previous_orders_states = [np.zeros((NB_NODES, NB_ORDERS_FEATURES), dtype=np.uint8)] * NB_PREV_ORDERS
        for phase_proto in saved_game_proto.phases[max(0, phase_ix - NB_PREV_ORDERS_HISTORY):phase_ix]:
            if phase_proto.name[-1] == 'M':
                previous_orders_states += [proto_to_prev_orders_state(phase_proto, map_object)]
        previous_orders_states = previous_orders_states[-NB_PREV_ORDERS:]
        prev_orders_state = np.array(previous_orders_states)

        # Parsing each requested power in the specified phase
        phase_proto = saved_game_proto.phases[phase_ix]
        phase_name = phase_proto.name
        state_proto = phase_proto.state
        phase_board_state = proto_to_board_state(state_proto, map_object)

        # Increasing year for every spring or when the game is completed
        if phase_proto.name == 'COMPLETED' or (phase_proto.name[0] == 'S' and phase_proto.name[-1] == 'M'):
            current_year += 1

        for power_name in power_names:
            phase_issued_orders = get_issued_orders_for_powers(phase_proto, [power_name])
            phase_possible_orders = get_possible_orders_for_powers(phase_proto, [power_name])
            # The draw target is only 1. on the last phase with policy data, for the survivors of a drawn game
            phase_draw_target = 1. if has_draw and phase_ix == (nb_phases - 2) and power_name in survivors else 0.

            # Data to use when not learning a policy
            blank_policy_data = {'board_state': phase_board_state,
                                 'prev_orders_state': prev_orders_state,
                                 'draw_target': phase_draw_target}

            # Power is not a top victor - We don't want to learn a policy from him
            if power_name not in top_victors:
                policy_data[phase_ix][power_name] = blank_policy_data
                continue

            # Finding the orderable locs
            orderable_locations = list(phase_issued_orders[power_name].keys())

            # Skipping power for this phase if we are only issuing Hold
            # The `else` clause only runs when the loop did not `break`, i.e. no order has a second token
            # other than 'H'
            for order_loc, order in phase_issued_orders[power_name].items():
                order_tokens = get_order_tokens(order)
                if len(order_tokens) >= 2 and order_tokens[1] != 'H':
                    break
            else:
                policy_data[phase_ix][power_name] = blank_policy_data
                continue

            # Removing orderable locs where orders are not possible (i.e. NO_CHECK games)
            # NOTE(review): assumes every issued order_loc is a key of phase_possible_orders - confirm
            for order_loc, order in phase_issued_orders[power_name].items():
                if order not in phase_possible_orders[order_loc] and order_loc in orderable_locations:
                    if 'NO_CHECK' not in saved_game_proto.rules:
                        LOGGER.warning('%s not in all possible orders. Phase %s - Game %s.', order, phase_name, game_id)
                    orderable_locations.remove(order_loc)

                # Remove orderable locs where the order is either invalid or not frequent
                if order_to_ix(order) is None and order_loc in orderable_locations:
                    orderable_locations.remove(order_loc)

            # Determining if we are in an adjustment phase
            in_adjustment_phase = state_proto.name[-1] == 'A'
            nb_builds = state_proto.builds[power_name].count
            nb_homes = len(state_proto.builds[power_name].homes)

            # WxxxA - We can build units            (nb_builds >= 0 - capped by the number of home centers)
            # WxxxA - We can disband units          (nb_builds < 0 - exactly abs(nb_builds) disbands)
            # Other phase                           (one order per orderable loc)
            if in_adjustment_phase and nb_builds >= 0:
                decoder_length = min(nb_builds, nb_homes)
            elif in_adjustment_phase and nb_builds < 0:
                decoder_length = abs(nb_builds)
            else:
                decoder_length = len(orderable_locations)

            # Not all units were disbanded - Skipping this power as we can't learn the orders properly
            if in_adjustment_phase and nb_builds < 0 and len(orderable_locations) < abs(nb_builds):
                policy_data[phase_ix][power_name] = blank_policy_data
                continue

            # Not enough orderable locations for this power, skipping
            if not orderable_locations or not decoder_length:
                policy_data[phase_ix][power_name] = blank_policy_data
                continue

            # decoder_inputs [GO, order1, order2, order3]
            # Unused builds are padded with WAIVE orders (nothing is added when the difference is <= 0)
            decoder_inputs = [GO_ID]
            decoder_inputs += [order_to_ix(phase_issued_orders[power_name][loc]) for loc in orderable_locations]
            if in_adjustment_phase and nb_builds > 0:
                decoder_inputs += [order_to_ix('WAIVE')] * (min(nb_builds, nb_homes) - len(orderable_locations))
            # NOTE(review): decoder_length is capped here, but decoder_inputs is not truncated to match
            decoder_length = min(decoder_length, NB_SUPPLY_CENTERS)

            # Adjustment Phase - Use all possible orders for each location.
            if in_adjustment_phase:
                build_disband_locs = list(get_possible_orders_for_powers(phase_proto, [power_name]).keys())
                phase_board_alignments = get_board_alignments(build_disband_locs,
                                                              in_adjustment_phase=in_adjustment_phase,
                                                              tokens_per_loc=1,
                                                              decoder_length=decoder_length)

                # Building a list of all orders for all locations
                adj_orders = []
                for loc in build_disband_locs:
                    adj_orders += phase_possible_orders[loc]

                # Not learning builds for BUILD_ANY
                if nb_builds > 0 and 'BUILD_ANY' in state_proto.rules:
                    adj_orders = []

                # No orders found - Skipping
                if not adj_orders:
                    policy_data[phase_ix][power_name] = blank_policy_data
                    continue

                # Computing the candidates
                # The same mask (over all orders of all build/disband locs) is used for every decoding step
                candidates = [get_order_based_mask(adj_orders)] * decoder_length

            # Regular phase - Compute candidates for each location
            else:
                phase_board_alignments = get_board_alignments(orderable_locations,
                                                              in_adjustment_phase=in_adjustment_phase,
                                                              tokens_per_loc=1,
                                                              decoder_length=decoder_length)
                candidates = []
                for loc in orderable_locations:
                    candidates += [get_order_based_mask(phase_possible_orders[loc])]

            # Saving results
            # No need to return temperature, current_power, current_season
            policy_data[phase_ix][power_name] = {'board_state': phase_board_state,
                                                 'board_alignments': phase_board_alignments,
                                                 'prev_orders_state': prev_orders_state,
                                                 'decoder_inputs': decoder_inputs,
                                                 'decoder_lengths': decoder_length,
                                                 'candidates': candidates,
                                                 'draw_target': phase_draw_target}
    # Returning
    return policy_data
예제 #16
0
def generate_trajectory(players,
                        reward_fn,
                        advantage_fn,
                        env_constructor=None,
                        hparams=None,
                        power_assignments=None,
                        set_player_seed=None,
                        initial_state_bytes=None,
                        update_interval=0,
                        update_queue=None,
                        output_format='proto'):
    """ Generates a single trajectory (Saved Game Proto) for RL (self-play) with the power assignments

        This is a generator-based coroutine: at every phase it yields a list of `get_step_args(...)` requests
        (and, before a partial update of an unfinished game, a list of `get_state_value(...)` requests) and expects
        the driver to send back the list of results. The saved game is the generator's return value.

        :param players: A list of instantiated players
        :param reward_fn: The reward function to use to calculate rewards
        :param advantage_fn: An instance of `.models.self_play.advantages`
        :param env_constructor: A callable to get the OpenAI gym environment (args: players)
        :param hparams: A dictionary of hyper parameters with their values
        :param power_assignments: Optional. The power name we want to play as. (e.g. 'FRANCE') or a list of powers.
        :param set_player_seed: Boolean that indicates that we want to set the player seed on reset().
        :param initial_state_bytes: A `game.State` proto (in bytes format) representing the initial state of the game.
        :param update_interval: Optional. If > 0 (and `update_queue` is set), a partial saved game is put in the
                                update_queue once more than `update_interval` seconds and at least 5 game years
                                have elapsed since the previous partial update, and always when the game completes.
        :param update_queue: Optional. If update interval is set, partial games will be put in this queue
                             as (False, nb_transitions, saved_game_bytes) tuples.
        :param output_format: The output format. One of 'proto', 'bytes', 'zlib'
        :return: A SavedGameProto representing the game played (with policy details and power assignments)
                 Depending on format, the output might be converted to a byte array, or a compressed byte array.
        :type players: List[diplomacy_research.players.player.Player]
        :type reward_fn: diplomacy_research.models.self_play.reward_functions.AbstractRewardFunction
        :type advantage_fn: diplomacy_research.models.self_play.advantages.base_advantage.BaseAdvantage
        :type update_queue: multiprocessing.Queue
    """
    # pylint: disable=too-many-arguments
    assert output_format in ['proto', 'bytes', 'zlib'], 'Format should be "proto", "bytes", "zlib"'
    assert len(players) == NB_POWERS

    # Building the environment
    if env_constructor:
        env = env_constructor(players)
    else:
        env = default_env_constructor(players, hparams, power_assignments, set_player_seed, initial_state_bytes)

    # Making sure we use the SavedGame wrapper to record the game
    # The while/else only wraps `env` when no SaveGame is found before reaching the base DiplomacyEnv
    wrapped_env = env
    while not isinstance(wrapped_env, DiplomacyEnv):
        if isinstance(wrapped_env, SaveGame):
            break
        wrapped_env = wrapped_env.env
    else:
        env = SaveGame(env)

    # Detecting if we have a Auto-Draw wrapper
    has_auto_draw = False
    wrapped_env = env
    while not isinstance(wrapped_env, DiplomacyEnv):
        if isinstance(wrapped_env, AutoDraw):
            has_auto_draw = True
            break
        wrapped_env = wrapped_env.env

    # Resetting env
    env.reset()

    # Timing vars for partial updates
    time_last_update = time.time()
    year_last_update = 0
    start_phase_ix = 0
    current_phase_ix = 0
    nb_transitions = 0

    # Cache Variables
    powers = sorted(get_map_powers(env.game.map))
    assigned_powers = env.get_all_powers_name()
    stored_board_state = OrderedDict()                  # {phase_name: board_state}
    stored_prev_orders_state = OrderedDict()            # {phase_name: prev_orders_state}
    stored_possible_orders = OrderedDict()              # {phase_name: possible_orders}

    power_variables = {power_name: {'orders': [],
                                    'policy_details': [],
                                    'state_values': [],
                                    'rewards': [],
                                    'returns': [],
                                    'last_state_value': 0.}
                       for power_name in powers}

    new_state_proto = None
    phase_history_proto = []
    map_object = Map(name=env.game.map.name)

    # Generating
    while not env.is_done:
        # Re-using the state extracted at the end of the previous iteration (if any)
        state_proto = new_state_proto if new_state_proto is not None else extract_state_proto(env.game)
        possible_orders_proto = extract_possible_orders_proto(env.game)

        # Computing board_state
        board_state = proto_to_board_state(state_proto, map_object).flatten().tolist()
        state_proto.board_state.extend(board_state)

        # Storing board state and possible orders for this phase
        current_phase = env.game.get_current_phase()
        stored_board_state[current_phase] = board_state
        stored_possible_orders[current_phase] = possible_orders_proto

        # Getting orders, policy details, and state value
        # The driver sends back one (orders, policy_details, state_value) tuple per assigned power
        tasks = [(player, state_proto, pow_name, phase_history_proto[-NB_PREV_ORDERS_HISTORY:], possible_orders_proto)
                 for player, pow_name in zip(env.players, assigned_powers)]
        step_args = yield [get_step_args(*args) for args in tasks]

        # Stepping through env
        for power_name, (orders, policy_details, _) in zip(assigned_powers, step_args):
            if orders:
                env.step((power_name, orders))
                nb_transitions += 1
            if has_auto_draw and policy_details is not None:
                env.set_draw_prob(power_name, policy_details['draw_prob'])

        # Processing
        env.process()
        current_phase_ix += 1

        # Retrieving draw action and saving power variables
        for power_name, (orders, policy_details, state_value) in zip(assigned_powers, step_args):
            if has_auto_draw and policy_details is not None:
                policy_details['draw_action'] = env.get_draw_actions()[power_name]
            power_variables[power_name]['orders'] += [orders]
            power_variables[power_name]['policy_details'] += [policy_details]
            power_variables[power_name]['state_values'] += [state_value]

        # Getting new state
        new_state_proto = extract_state_proto(env.game)

        # Storing reward for this transition
        done_reason = DoneReason(env.done_reason) if env.done_reason else None
        for power_name in powers:
            power_variables[power_name]['rewards'] += [reward_fn.get_reward(prev_state_proto=state_proto,
                                                                            state_proto=new_state_proto,
                                                                            power_name=power_name,
                                                                            is_terminal_state=done_reason is not None,
                                                                            done_reason=done_reason)]

        # Computing prev_orders_state for the previous state (only for phases whose name ends with 'M')
        last_phase_proto = extract_phase_history_proto(env.game, nb_previous_phases=1)[-1]
        if last_phase_proto.name[-1] == 'M':
            prev_orders_state = proto_to_prev_orders_state(last_phase_proto, map_object).flatten().tolist()
            stored_prev_orders_state[last_phase_proto.name] = prev_orders_state
            last_phase_proto.prev_orders_state.extend(prev_orders_state)
            phase_history_proto += [last_phase_proto]

        # Sending partial game if:
        # 1) We have update_interval > 0 with an update queue, and
        # 2a) The game is completed, or 2b) the update time has elapsed and at least 5 years have passed
        has_update_interval = update_interval > 0 and update_queue is not None
        game_is_completed = env.is_done
        min_time_has_passed = time.time() - time_last_update > update_interval
        current_phase_name = env.game.get_current_phase()
        current_year = 9999 if current_phase_name == 'COMPLETED' else int(current_phase_name[1:5])
        min_years_have_passed = current_year - year_last_update >= 5

        if has_update_interval and (game_is_completed or (min_time_has_passed and min_years_have_passed)):

            # Game is completed - last state value is 0
            if game_is_completed:
                for power_name in powers:
                    power_variables[power_name]['last_state_value'] = 0.

            # Otherwise - Querying the model for the value of the last state
            else:
                tasks = [(player, new_state_proto, pow_name, phase_history_proto[-NB_PREV_ORDERS_HISTORY:],
                          possible_orders_proto)
                         for player, pow_name in zip(env.players, assigned_powers)]
                last_state_values = yield [get_state_value(*args) for args in tasks]

                for power_name, last_state_value in zip(assigned_powers, last_state_values):
                    power_variables[power_name]['last_state_value'] = last_state_value

            # Getting partial game and sending it on the update_queue
            saved_game_proto = get_saved_game_proto(env=env,
                                                    players=players,
                                                    stored_board_state=stored_board_state,
                                                    stored_prev_orders_state=stored_prev_orders_state,
                                                    stored_possible_orders=stored_possible_orders,
                                                    power_variables=power_variables,
                                                    start_phase_ix=start_phase_ix,
                                                    reward_fn=reward_fn,
                                                    advantage_fn=advantage_fn,
                                                    is_partial_game=True)
            update_queue.put_nowait((False, nb_transitions, proto_to_bytes(saved_game_proto)))

            # Updating stats
            # The timer must be reset here, otherwise once `update_interval` seconds have elapsed since the start,
            # the time condition stays true forever and partial games are sent every 5 years regardless of time.
            time_last_update = time.time()
            start_phase_ix = current_phase_ix
            nb_transitions = 0
            if not env.is_done:
                year_last_update = int(env.game.get_current_phase()[1:5])

    # Since the environment is done (Completed game) - We can leave the last_state_value at 0.
    for power_name in powers:
        power_variables[power_name]['last_state_value'] = 0.

    # Getting completed game
    saved_game_proto = get_saved_game_proto(env=env,
                                            players=players,
                                            stored_board_state=stored_board_state,
                                            stored_prev_orders_state=stored_prev_orders_state,
                                            stored_possible_orders=stored_possible_orders,
                                            power_variables=power_variables,
                                            start_phase_ix=0,
                                            reward_fn=reward_fn,
                                            advantage_fn=advantage_fn,
                                            is_partial_game=False)

    # Converting to correct format
    output = {'proto': lambda proto: proto,
              'zlib': proto_to_zlib,
              'bytes': proto_to_bytes}[output_format](saved_game_proto)

    # Returning
    return output
예제 #17
0
    def _build_from_string(self, order, game=None):
        """ Builds this object from a string

            Parses `order` (e.g. 'A PAR - BUR', 'F NTH C A LON - NWY', 'WAIVE') and sets `self.order_str`
            (the normalized order string) and `self.order_dict` (the order fields, in the format shown in the
            examples below). If the order cannot be parsed, an error is logged and both attributes are left untouched.

            :param order: The order to parse
            :param game: Optional. The game, forwarded to `find_convoy_path` when looking for convoy paths.
            :type order: str
            :type game: diplomacy.Game
        """
        # pylint: disable=too-many-return-statements,too-many-branches,too-many-statements
        # Converting move to retreat during retreat phase
        if self.phase_type == 'R':
            order = order.replace(' - ', ' R ')

        # Splitting into parts
        words = order.split()

        # --- Wait / Waive ---
        # [{"id": "56", "unitID": null, "type": "Wait", "toTerrID": "", "fromTerrID": "", "viaConvoy": ""}]
        if len(words) == 1 and words[0] == 'WAIVE':
            self.order_str = 'WAIVE'
            self.order_dict = {'terrID': None,
                               'unitType': '',
                               'type': 'Wait',
                               'toTerrID': '',
                               'fromTerrID': '',
                               'viaConvoy': ''}
            return

        # Validating
        if len(words) < 3:
            LOGGER.error('Unable to parse the order "%s". Require at least 3 words', order)
            return

        short_unit_type, loc_name, order_type = words[:3]

        # Membership is tested against tuples (not strings) so multi-letter words such as 'AF' or 'H-' are rejected
        if short_unit_type not in ('A', 'F'):
            LOGGER.error('Unable to parse the order "%s". Valid unit types are "A" and "F".', order)
            return
        if order_type not in ('H', '-', 'S', 'C', 'R', 'B', 'D'):
            LOGGER.error('Unable to parse the order "%s". Valid order types are H-SCRBD', order)
            return
        loc_to_ix = CACHE[self.map_name]['loc_to_ix']
        if loc_name not in loc_to_ix:
            LOGGER.error('Received invalid loc "%s" for map "%s".', loc_name, self.map_name)
            return

        # Extracting territories
        unit_type = {'A': 'Army', 'F': 'Fleet'}[short_unit_type]
        terr_id = loc_to_ix[loc_name]

        # --- Hold ---
        # {"id": "76", "unitID": "19", "type": "Hold", "toTerrID": "", "fromTerrID": "", "viaConvoy": ""}
        if order_type == 'H':
            self.order_str = '%s %s H' % (short_unit_type, loc_name)
            self.order_dict = {'terrID': terr_id,
                               'unitType': unit_type,
                               'type': 'Hold',
                               'toTerrID': '',
                               'fromTerrID': '',
                               'viaConvoy': ''}

        # --- Move ---
        # {"id": "73", "unitID": "16", "type": "Move", "toTerrID": "25", "fromTerrID": "", "viaConvoy": "Yes",
        # "convoyPath": ["22", "69"]},
        # {"id": "74", "unitID": "17", "type": "Move", "toTerrID": "69", "fromTerrID": "", "viaConvoy": "No"}
        elif order_type == '-':
            if len(words) < 4:
                LOGGER.error('[Move] Unable to parse the move order "%s". Require at least 4 words', order)
                LOGGER.error(order)
                return

            # Getting destination
            to_loc_name = words[3]
            to_terr_id = loc_to_ix.get(to_loc_name, None)

            # Validating the destination before using it to look for adjacencies / convoy paths
            if to_terr_id is None:
                LOGGER.error('[Move] Received invalid to loc "%s" for map "%s".', to_loc_name, self.map_name)
                LOGGER.error(order)
                return

            # Building map
            map_object = Map(self.map_name)
            convoy_path = []

            # Deciding if this move is doable by convoy or not
            if unit_type != 'Army':
                via_flag = ''
            else:
                # Any plausible convoy path (i.e. where fleets are on water, even though they are not convoying)
                # is valid for the 'convoyPath' argument
                reachable_by_land = map_object.abuts('A', loc_name, '-', to_loc_name)
                via_convoy = bool(words[-1] == 'VIA') or not reachable_by_land
                via_flag = ' VIA' if via_convoy else ''
                convoy_path = find_convoy_path(loc_name, to_loc_name, map_object, game)

            self.order_str = '%s %s - %s%s' % (short_unit_type, loc_name, to_loc_name, via_flag)
            self.order_dict = {'terrID': terr_id,
                               'unitType': unit_type,
                               'type': 'Move',
                               'toTerrID': to_terr_id,
                               'fromTerrID': '',
                               'viaConvoy': 'Yes' if via_flag else 'No'}
            # The last element of the path is not sent (presumably the destination - confirm in find_convoy_path)
            if convoy_path:
                self.order_dict['convoyPath'] = [loc_to_ix[loc] for loc in convoy_path[:-1]]

        # --- Support hold ---
        # {"id": "73", "unitID": "16", "type": "Support hold", "toTerrID": "24", "fromTerrID": "", "viaConvoy": ""}
        elif order_type == 'S' and '-' not in words:
            if len(words) < 5:
                LOGGER.error('[Support H] Unable to parse the support hold order "%s". Require at least 5 words', order)
                LOGGER.error(order)
                return

            # Getting supported unit (without coast)
            to_loc_name = words[4][:3]
            to_terr_id = loc_to_ix.get(to_loc_name, None)

            if to_terr_id is None:
                LOGGER.error('[Support H] Received invalid to loc "%s" for map "%s".', to_loc_name, self.map_name)
                LOGGER.error(order)
                return

            self.order_str = '%s %s S %s' % (short_unit_type, loc_name, to_loc_name)
            self.order_dict = {'terrID': terr_id,
                               'unitType': unit_type,
                               'type': 'Support hold',
                               'toTerrID': to_terr_id,
                               'fromTerrID': '',
                               'viaConvoy': ''}

        # --- Support move ---
        # {"id": "73", "unitID": "16", "type": "Support move", "toTerrID": "24", "fromTerrID": "69", "viaConvoy": ""}
        elif order_type == 'S':
            if len(words) < 6:
                LOGGER.error('Unable to parse the support move order "%s". Require at least 6 words', order)
                return

            # Getting supported unit
            # '-' is guaranteed to be in words by the branch condition, but it might be the last word
            move_index = words.index('-')
            if move_index + 1 >= len(words):
                LOGGER.error('[Support M] Unable to parse the support move order "%s". Missing destination', order)
                LOGGER.error(order)
                return
            to_loc_name = words[move_index + 1][:3]                         # Removing coast from dest
            from_loc_name = words[move_index - 1]
            to_terr_id = loc_to_ix.get(to_loc_name, None)
            from_terr_id = loc_to_ix.get(from_loc_name, None)

            if to_terr_id is None:
                LOGGER.error('[Support M] Received invalid to loc "%s" for map "%s".', to_loc_name, self.map_name)
                LOGGER.error(order)
                return
            if from_terr_id is None:
                LOGGER.error('[Support M] Received invalid from loc "%s" for map "%s".', from_loc_name, self.map_name)
                LOGGER.error(order)
                return

            # Building map
            map_object = Map(self.map_name)
            convoy_path = []

            # Deciding if we are support a move by convoy or not
            # Any plausible convoy path (i.e. where fleets are on water, even though they are not convoying)
            # is valid for the 'convoyPath' argument, only if it does not include the fleet issuing the support
            if words[move_index - 2] != 'F' and map_object.area_type(from_loc_name) == 'COAST':
                convoy_path = find_convoy_path(from_loc_name, to_loc_name, map_object, game, excluding=loc_name)

            self.order_str = '%s %s S %s - %s' % (short_unit_type, loc_name, from_loc_name, to_loc_name)
            self.order_dict = {'terrID': terr_id,
                               'unitType': unit_type,
                               'type': 'Support move',
                               'toTerrID': to_terr_id,
                               'fromTerrID': from_terr_id,
                               'viaConvoy': ''}
            if convoy_path:
                self.order_dict['convoyPath'] = [loc_to_ix[loc] for loc in convoy_path[:-1]]

        # --- Convoy ---
        # {"id": "79", "unitID": "22", "type": "Convoy", "toTerrID": "24", "fromTerrID": "20", "viaConvoy": "",
        # "convoyPath": ["20", "69"]}
        elif order_type == 'C':
            if len(words) < 6:
                LOGGER.error('[Convoy] Unable to parse the convoy order "%s". Require at least 6 words', order)
                LOGGER.error(order)
                return

            # Getting convoyed unit - Requires a '-' followed by a destination
            if '-' not in words or words.index('-') + 1 >= len(words):
                LOGGER.error('[Convoy] Unable to parse the convoy order "%s". Missing "-" or destination', order)
                LOGGER.error(order)
                return
            move_index = words.index('-')
            to_loc_name = words[move_index + 1]
            from_loc_name = words[move_index - 1]
            to_terr_id = loc_to_ix.get(to_loc_name, None)
            from_terr_id = loc_to_ix.get(from_loc_name, None)

            if to_terr_id is None:
                LOGGER.error('[Convoy] Received invalid to loc "%s" for map "%s".', to_loc_name, self.map_name)
                LOGGER.error(order)
                return
            if from_terr_id is None:
                LOGGER.error('[Convoy] Received invalid from loc "%s" for map "%s".', from_loc_name, self.map_name)
                LOGGER.error(order)
                return

            # Building map
            map_object = Map(self.map_name)

            # Finding convoy path
            # Any plausible convoy path (i.e. where fleets are on water, even though they are not convoying)
            # is valid for the 'convoyPath' argument, only if it includes the current fleet issuing the convoy order
            convoy_path = find_convoy_path(from_loc_name, to_loc_name, map_object, game, including=loc_name)

            self.order_str = '%s %s C A %s - %s' % (short_unit_type, loc_name, from_loc_name, to_loc_name)
            self.order_dict = {'terrID': terr_id,
                               'unitType': unit_type,
                               'type': 'Convoy',
                               'toTerrID': to_terr_id,
                               'fromTerrID': from_terr_id,
                               'viaConvoy': ''}
            if convoy_path:
                self.order_dict['convoyPath'] = [loc_to_ix[loc] for loc in convoy_path[:-1]]

        # --- Retreat ---
        # {"id": "152", "unitID": "18", "type": "Retreat", "toTerrID": "75", "fromTerrID": "", "viaConvoy": ""}
        elif order_type == 'R':
            if len(words) < 4:
                LOGGER.error('[Retreat] Unable to parse the move order "%s". Require at least 4 words', order)
                LOGGER.error(order)
                return

            # Getting destination
            to_loc_name = words[3]
            to_terr_id = loc_to_ix.get(to_loc_name, None)

            if to_terr_id is None:
                LOGGER.error('[Retreat] Received invalid to loc "%s" for map "%s".', to_loc_name, self.map_name)
                LOGGER.error(order)
                return

            self.order_str = '%s %s R %s' % (short_unit_type, loc_name, to_loc_name)
            self.order_dict = {'terrID': terr_id,
                               'unitType': unit_type,
                               'type': 'Retreat',
                               'toTerrID': to_terr_id,
                               'fromTerrID': '',
                               'viaConvoy': ''}

        # --- Disband (R phase) ---
        # {"id": "152", "unitID": "18", "type": "Disband", "toTerrID": "", "fromTerrID": "", "viaConvoy": ""}
        elif order_type == 'D' and self.phase_type == 'R':
            # Note: For R phase, we disband with the coast
            self.order_str = '%s %s D' % (short_unit_type, loc_name)
            self.order_dict = {'terrID': terr_id,
                               'unitType': unit_type,
                               'type': 'Disband',
                               'toTerrID': '',
                               'fromTerrID': '',
                               'viaConvoy': ''}

        # --- Build Army ---
        # [{"id": "56", "unitID": null, "type": "Build Army", "toTerrID": "37", "fromTerrID": "", "viaConvoy": ""}]
        elif order_type == 'B' and short_unit_type == 'A':
            self.order_str = 'A %s B' % loc_name
            self.order_dict = {'terrID': terr_id,
                               'unitType': 'Army',
                               'type': 'Build Army',
                               'toTerrID': terr_id,
                               'fromTerrID': '',
                               'viaConvoy': ''}

        # -- Build Fleet ---
        # [{"id": "56", "unitID": null, "type": "Build Fleet", "toTerrID": "37", "fromTerrID": "", "viaConvoy": ""}]
        elif order_type == 'B' and short_unit_type == 'F':
            self.order_str = 'F %s B' % loc_name
            self.order_dict = {'terrID': terr_id,
                               'unitType': 'Fleet',
                               'type': 'Build Fleet',
                               'toTerrID': terr_id,
                               'fromTerrID': '',
                               'viaConvoy': ''}

        # Disband (A phase)
        # {"id": "152", "unitID": null, "type": "Destroy", "toTerrID": "18", "fromTerrID": "", "viaConvoy": ""}
        elif order_type == 'D':
            # For A phase, we disband without the coast
            loc_name = loc_name[:3]
            terr_id = loc_to_ix[loc_name]
            self.order_str = '%s %s D' % (short_unit_type, loc_name)
            self.order_dict = {'terrID': terr_id,
                               'unitType': unit_type,
                               'type': 'Destroy',
                               'toTerrID': terr_id,
                               'fromTerrID': '',
                               'viaConvoy': ''}
예제 #18
0
def get_order_vocabulary():
    """ Computes the list of all valid orders on the standard map
        :return: A list starting with the special tokens (PAD, GO, EOS, DRAW, one '<POWER>' token per power
                 and 'WAIVE'), followed by all valid orders on the standard map, grouped by category (in the
                 order of `categories` below) and sorted within each category by location, then alphabetically.
                 An order produced by several categories only appears once (in its first category).
    """
    # pylint: disable=too-many-nested-blocks,too-many-branches
    categories = ['H', 'D', 'B', '-', 'R', 'SH', 'S-',
                  '-1', 'S1', 'C1',                         # Move, Support, Convoy (using 1 fleet)
                  '-2', 'S2', 'C2',                         # Move, Support, Convoy (using 2 fleets)
                  '-3', 'S3', 'C3',                         # Move, Support, Convoy (using 3 fleets)
                  '-4', 'S4', 'C4']                         # Move, Support, Convoy (using 4 fleets)
    orders = {category: set() for category in categories}
    map_object = Map()
    locs = sorted([loc.upper() for loc in map_object.locs])

    # All holds, builds, and disbands orders
    for loc in locs:
        for unit_type in ['A', 'F']:
            if map_object.is_valid_unit('%s %s' % (unit_type, loc)):
                orders['H'].add('%s %s H' % (unit_type, loc))
                orders['D'].add('%s %s D' % (unit_type, loc))

                # Allowing builds in all SCs (even though only homes will likely be used)
                if loc[:3] in map_object.scs:
                    orders['B'].add('%s %s B' % (unit_type, loc))

    # Moves, Retreats, Support Holds
    for unit_loc in locs:
        for dest in [loc.upper() for loc in map_object.abut_list(unit_loc, incl_no_coast=True)]:
            for unit_type in ['A', 'F']:
                if not map_object.is_valid_unit('%s %s' % (unit_type, unit_loc)):
                    continue

                if map_object.abuts(unit_type, unit_loc, '-', dest):
                    orders['-'].add('%s %s - %s' % (unit_type, unit_loc, dest))
                    orders['R'].add('%s %s R %s' % (unit_type, unit_loc, dest))

                # Making sure we can support destination
                if not (map_object.abuts(unit_type, unit_loc, 'S', dest)
                        or map_object.abuts(unit_type, unit_loc, 'S', dest[:3])):
                    continue

                # Support Hold
                for dest_unit_type in ['A', 'F']:
                    for coast in ['', '/NC', '/SC', '/EC', '/WC']:
                        if map_object.is_valid_unit('%s %s%s' % (dest_unit_type, dest, coast)):
                            orders['SH'].add('%s %s S %s %s%s' % (unit_type, unit_loc, dest_unit_type, dest, coast))

    # Convoys, Move Via
    for nb_fleets in map_object.convoy_paths:

        # Skipping long-term convoys
        if nb_fleets > 4:
            continue

        for start, fleets, dests in map_object.convoy_paths[nb_fleets]:
            for end in dests:
                orders['-%d' % nb_fleets].add('A %s - %s VIA' % (start, end))
                orders['-%d' % nb_fleets].add('A %s - %s VIA' % (end, start))
                for fleet_loc in fleets:
                    orders['C%d' % nb_fleets].add('F %s C A %s - %s' % (fleet_loc, start, end))
                    orders['C%d' % nb_fleets].add('F %s C A %s - %s' % (fleet_loc, end, start))

    # Support Move (Non-Convoyed)
    for start_loc in locs:
        for dest_loc in [loc.upper() for loc in map_object.abut_list(start_loc, incl_no_coast=True)]:
            for support_loc in (map_object.abut_list(dest_loc, incl_no_coast=True)
                                + map_object.abut_list(dest_loc[:3], incl_no_coast=True)):
                support_loc = support_loc.upper()

                # A unit cannot support itself
                if support_loc[:3] == start_loc[:3]:
                    continue

                # Making sure the src unit can move to dest
                # and the support unit can also support to dest
                for src_unit_type in ['A', 'F']:
                    for support_unit_type in ['A', 'F']:
                        if (map_object.abuts(src_unit_type, start_loc, '-', dest_loc)
                                and map_object.abuts(support_unit_type, support_loc, 'S', dest_loc[:3])
                                and map_object.is_valid_unit('%s %s' % (src_unit_type, start_loc))
                                and map_object.is_valid_unit('%s %s' % (support_unit_type, support_loc))):
                            orders['S-'].add('%s %s S %s %s - %s' % (support_unit_type, support_loc, src_unit_type,
                                                                     start_loc, dest_loc[:3]))

    # Support Move (Convoyed)
    for nb_fleets in map_object.convoy_paths:

        # Skipping long-term convoys
        if nb_fleets > 4:
            continue

        for start_loc, fleets, ends in map_object.convoy_paths[nb_fleets]:
            for dest_loc in ends:
                for support_loc in map_object.abut_list(dest_loc, incl_no_coast=True):
                    support_loc = support_loc.upper()

                    # A unit cannot support itself
                    if support_loc[:3] == start_loc[:3]:
                        continue

                    # A fleet cannot support if it convoys
                    if support_loc in fleets:
                        continue

                    # Making sure the support unit can also support to dest
                    # And that the support unit is not convoying
                    for support_unit_type in ['A', 'F']:
                        if (map_object.abuts(support_unit_type, support_loc, 'S', dest_loc)
                                and map_object.is_valid_unit('%s %s' % (support_unit_type, support_loc))):
                            orders['S%d' % nb_fleets].add('%s %s S A %s - %s' % (support_unit_type, support_loc,
                                                                                 start_loc, dest_loc[:3]))

    # Building the list of final orders
    final_orders = [PAD_TOKEN, GO_TOKEN, EOS_TOKEN, DRAW_TOKEN]
    final_orders += ['<%s>' % power_name for power_name in get_map_powers(map_object)]
    final_orders += ['WAIVE']

    # Sorting each category (by loc, then alphabetically)
    # The same order can be produced by several categories (e.g. a convoy order found on paths with a different
    # number of fleets), so only its first occurrence is kept. `seen_orders` mirrors `final_orders` so the
    # membership test is O(1) instead of a linear scan of an ever-growing list.
    seen_orders = set(final_orders)
    for category in categories:
        category_orders = [order for order in orders[category] if order not in seen_orders]
        category_orders.sort(key=lambda value: (value.split()[1], value))
        final_orders += category_orders
        seen_orders.update(category_orders)
    return final_orders
예제 #19
0
def get_power_vocabulary():
    """ Returns the power names of the standard map, in alphabetical order
        :return: A sorted list of power names
    """
    power_names = Map().powers
    return sorted(power_names)
예제 #20
0
def build():
    """ Building the hdf5 dataset

        Steps:
          1) Extracts the zip dataset (ZIP_DATASET_PATH) into a sibling 'zip_dataset' directory,
             unless that directory already exists.
          2) Processes every '*.jsonl' file of that directory in parallel with `process_game`.
             Each game is stored in the HDF5 file (DATASET_PATH), keyed by game id,
             and appended to the proto dataset (PROTO_DATASET_PATH).
          3) Pickles the following to their respective paths:
             the dataset index ({category: set of game ids}), the end-of-game supply-center info,
             the zobrist-hash table, the move frequencies and the number of phases per game.
          4) Deletes the extracted directory.

        :raises RuntimeError: If the zip dataset cannot be found at ZIP_DATASET_PATH.
    """
    if not os.path.exists(ZIP_DATASET_PATH):
        raise RuntimeError('Unable to find the zip dataset at %s' % ZIP_DATASET_PATH)

    # Extracting
    # An existing extraction directory is reused as-is
    # (it is only deleted at the very end of a build that did not raise)
    extract_dir = os.path.join(os.path.dirname(ZIP_DATASET_PATH), 'zip_dataset')
    if not os.path.exists(extract_dir):
        LOGGER.info('... Extracting files from zip dataset.')
        with zipfile.ZipFile(ZIP_DATASET_PATH, 'r') as zip_dataset:
            zip_dataset.extractall(extract_dir)

    # Additional information we also want to store
    map_object = Map()
    all_powers = get_map_powers(map_object)
    sc_to_win = len(map_object.scs) // 2 + 1

    hash_table = {}                                         # zobrist_hash: ['{game_id}/{phase_name}', ...]
    moves = {}                                              # Moves frequency: {move: [nb_no_press, nb_press]}
    nb_phases = OrderedDict()                               # Nb of phases per game
    end_scs = {press_type: {power_name: {nb_sc: [] for nb_sc in range(0, sc_to_win + 1)}
                            for power_name in all_powers}
               for press_type in ('press', 'no_press')}

    # Building
    dataset_index = {}
    LOGGER.info('... Building HDF5 dataset.')
    with multiprocessing.Pool() as pool:
        with h5py.File(DATASET_PATH, 'w') as hdf5_dataset, open(PROTO_DATASET_PATH, 'wb') as proto_dataset:

            # os.path.join / os.path.basename are used instead of '/' concatenation and split('/').
            # glob returns OS-native separators, so this keeps the category equal to the bare file stem
            # on every OS. glob.escape protects against glob metacharacters in the directory name.
            for json_file_path in glob.glob(os.path.join(glob.escape(extract_dir), '*.jsonl')):
                LOGGER.info('... Processing: %s', json_file_path)
                category = os.path.basename(json_file_path).split('.')[0]
                dataset_index[category] = set()

                # Processing file using pool
                with open(json_file_path, 'r') as json_file:
                    lines = json_file.read().splitlines()
                    for game_id, saved_game_zlib in tqdm(pool.imap_unordered(process_game, lines), total=len(lines)):
                        # process_game signals a game to skip by returning a None game id
                        if game_id is None:
                            continue
                        saved_game_proto = zlib_to_proto(saved_game_zlib, SavedGameProto)

                        # Saving to disk
                        hdf5_dataset[game_id] = np.void(saved_game_zlib)
                        write_proto_to_file(proto_dataset, saved_game_proto, compressed=False)
                        dataset_index[category].add(game_id)

                        # Recording additional info
                        get_end_scs_info(saved_game_proto, game_id, all_powers, sc_to_win, end_scs)
                        get_moves_info(saved_game_proto, moves)
                        nb_phases[game_id] = len(saved_game_proto.phases)

                        # Recording hash of each phase
                        for phase in saved_game_proto.phases:
                            hash_table.setdefault(phase.state.zobrist_hash, []).append(
                                '%s/%s' % (game_id, phase.name))

    # Storing info to disk
    for output_path, output_obj in ((DATASET_INDEX_PATH, dataset_index),
                                    (END_SCS_DATASET_PATH, end_scs),
                                    (HASH_DATASET_PATH, hash_table),
                                    (MOVES_COUNT_DATASET_PATH, moves),
                                    (PHASES_COUNT_DATASET_PATH, nb_phases)):
        with open(output_path, 'wb') as output_file:
            pickle.dump(output_obj, output_file, pickle.HIGHEST_PROTOCOL)

    # Deleting extract_dir
    LOGGER.info('... Deleting extracted files.')
    if os.path.exists(extract_dir):
        shutil.rmtree(extract_dir, ignore_errors=True)
    LOGGER.info('... Done building HDF5 dataset.')
예제 #21
0
def generate_daide_game(players, progress_bar, daide_rules):
    """ Generates (plays) a single game, with DAIDE bots and regular players, on a new ServerGame

        This is written as a generator-based coroutine. It yields tornado futures:
        `gen.sleep`, channel calls and the players' `get_orders`.
        It is presumably run under `tornado.gen.coroutine`; confirm with the caller.

        :param players: A list of (player, channel) tuples.
                        - A `DaidePlayerPlaceHolder` gets a launched DAIDE client process.
                        - When channel is set, the player joins the game through that channel.
                        - Otherwise, a `Player` has its orders fetched locally and set on the server game.
        :param progress_bar: A progress bar whose `update()` is called once the game is over.
        :param daide_rules: The rules passed to the `ServerGame`.
        :return: The game in saved game format, with 'players_names', 'assigned_powers' and 'ranking' added,
                 if the game was completed or reached the maximum year. None otherwise.
                 In both cases, the saved game is also written to 'game_{game_id}.json'.
        :raises RuntimeError: If not all powers are controlled after 30 checks spaced 10 seconds apart.
    """
    global _server

    # The game is stopped after max_number_of_year years
    max_number_of_year = 35
    max_year = Map().first_year + max_number_of_year

    # Randomly assigning powers to players (the i-th player receives power_names[i])
    players_ordering = list(range(len(players)))
    shuffle(players_ordering)
    power_names = Map().powers
    power_names = [power_names[idx] for idx in players_ordering]
    clients = {
        power_name: player
        for power_name, player in zip(power_names, players)
    }
    # Counting the DAIDE bots (the element value is irrelevant, only the length is used)
    nb_daide_players = len([
        _ for _, (player, _) in clients.items()
        if isinstance(player, DaidePlayerPlaceHolder)
    ])
    # 1 if at least one power is not a DAIDE bot, 0 otherwise
    nb_regular_players = min(1, len(power_names) - nb_daide_players)

    server_game = ServerGame(n_controls=nb_daide_players + nb_regular_players,
                             rules=daide_rules)
    server_game.server = _server

    _server.add_new_game(server_game)

    # Power (and wrapped client) of the player joining through a channel, if any
    reg_power_name, reg_client = None, None

    for power_name, (player, channel) in clients.items():
        if channel:
            game = yield channel.join_game(game_id=server_game.game_id,
                                           power_name=power_name)
            reg_power_name, reg_client = power_name, ClientWrapper(
                player, game)
            clients[power_name] = reg_client
        elif isinstance(player, Player):
            server_game.get_power(power_name).set_controlled(player.name)

    # Starting a DAIDE server and one DAIDE client process per DAIDE bot
    if nb_daide_players:
        server_port = PORTS_POOL.pop(0)
        OPEN_PORTS.append(server_port)
        _server.start_new_daide_server(server_game.game_id, port=server_port)
        yield gen.sleep(1)

        for power_name, (player, _) in clients.items():
            if isinstance(player, DaidePlayerPlaceHolder):
                process = launch_daide_client(server_port)
                clients[power_name] = DaideWrapper(player, process)

    # Waiting (30 x 10s = 5 minutes) for every power to be controlled
    # NOTE(review): this RuntimeError is raised before the try/finally below.
    # On this path, the DAIDE client processes launched above are not killed
    # and the DAIDE server is not stopped. Confirm whether cleanup is needed here.
    for attempt_idx in range(30):
        if server_game.count_controlled_powers() == len(power_names):
            break
        yield gen.sleep(10)
        LOGGER.info('Waiting for DAIDE to connect. - Attempt %d / %d',
                    attempt_idx + 1, 30)
    else:
        LOGGER.error('DAIDE is not online after 5 minutes. Aborting.')
        raise RuntimeError()

    # Local players' powers are set to DUMMY control; their orders are set directly on server_game below
    for power_name, (player, game) in clients.items():
        if not game and isinstance(player, Player):
            server_game.get_power(power_name).set_controlled(strings.DUMMY)

    if server_game.game_can_start():
        _server.start_game(server_game)

    local_powers = [
        power_name for power_name, (player, game) in clients.items()
        if not game and isinstance(player, Player)
    ]

    # {power_name: order of elimination}. The "_0" sentinel guarantees max() always has a value
    elimination_orders = {"_0": 0}

    yield gen.sleep(get_unsync_wait())

    try:
        # The phase is followed through the channel game if a player joined through a channel,
        # otherwise directly on the server game
        watched_game = reg_client.channel_game if reg_client else server_game
        phase = PhaseSplitter(watched_game.get_current_phase())
        while watched_game.status != strings.COMPLETED and phase.year < max_year:
            print('\n=== NEW PHASE ===\n')
            print(watched_game.get_current_phase())

            if reg_client:
                yield reg_client.channel_game.wait()

            # Getting the orders of all local players
            players_orders = yield [
                player.get_orders(server_game, power_name)
                for power_name, (player, _) in clients.items()
                if power_name in local_powers
            ]

            # Retreat phases use ' R ' instead of ' - '; WAIVE orders are not sent
            for power_name, orders in zip(local_powers, players_orders):
                if phase.type == 'R':
                    orders = [order.replace(' - ', ' R ') for order in orders]
                orders = [order for order in orders if order != 'WAIVE']
                server_game.set_orders(power_name, orders, expand=False)

            # Sending the channel player's orders until the server sees them as set
            while reg_client and not server_game.get_power(
                    reg_power_name).order_is_set:
                orders = yield reg_client.player.get_orders(
                    server_game, reg_power_name)
                print('Sending orders')
                yield reg_client.channel_game.set_orders(orders=orders)

            print('All orders sent')

            if reg_client:
                yield reg_client.channel_game.no_wait()

            # Waiting (120 x 2.5s = 5 minutes) for the watched game to leave the current phase.
            # Every 12 attempts, if the server game has moved on but the watched game has not,
            # the server game is watched instead.
            for attempt_idx in range(120):
                if phase.input_str != watched_game.get_current_phase() or \
                        server_game.status != strings.ACTIVE:
                    break
                if (
                        attempt_idx + 1
                ) % 12 == 0 and phase.input_str != server_game.get_current_phase(
                ):
                    # Watched game is unsynched
                    watched_game = server_game
                    break
                LOGGER.info(
                    'Waiting for the phase to be processed. - Attempt %d / %d',
                    attempt_idx + 1, 120)
                yield gen.sleep(2.5)
            else:
                LOGGER.error('Phase is taking too long to process. Aborting.')
                raise RuntimeError()

            # Powers without any unit are recorded as eliminated during this phase
            # If the channel player was eliminated, the server game is watched instead
            elimination_order = max(elimination_orders.values()) + 1
            for power_name in power_names:
                if power_name not in elimination_orders and not server_game.get_power(
                        power_name).units:
                    elimination_orders[power_name] = elimination_order
                    if power_name == reg_power_name:
                        watched_game = server_game

            if server_game.status != strings.ACTIVE:
                break

            phase = PhaseSplitter(watched_game.get_current_phase())

    # Errors (including the RuntimeError above) are printed and swallowed; the game is still saved below
    except TimeoutError as timeout:
        print('Timeout: ', timeout)
    except Exception as exception:
        print('Exception: ', exception)
    finally:
        # Stopping the DAIDE server and killing the DAIDE client processes
        _server.stop_daide_server(server_game.game_id)
        yield gen.sleep(1)
        for power_name, (player, _) in clients.items():
            if isinstance(player, DaidePlayerPlaceHolder):
                process = clients[power_name].process
                process.kill()
        if reg_client:
            reg_client.channel_game.leave()

    game = None
    saved_game = to_saved_game_format(server_game)

    # The ranking is only computed for games that were completed or reached the maximum year
    if server_game.status == strings.COMPLETED or PhaseSplitter(
            server_game.get_current_phase()).year >= max_year:
        # Powers never eliminated get 0
        elimination_orders = [
            elimination_orders.get(power_name, 0) for power_name in power_names
        ]
        nb_centers = [
            len(server_game.get_power(power_name).centers)
            for power_name in power_names
        ]

        game = saved_game
        game['players_names'] = [player.name for player, _ in players]
        game['assigned_powers'] = power_names
        game['ranking'] = compute_ranking(power_names, nb_centers,
                                          elimination_orders)

    # The saved game is always written to disk, even when the game is not returned
    with open('game_{}.json'.format(saved_game['id']), 'w') as file:
        json.dump(saved_game, file)

    progress_bar.update()

    return game
예제 #22
0
def get_policy_data(saved_game_proto, power_names, top_victors):
    """ Computes the proto to save in tf.train.Example as a training example for the policy network
        :param saved_game_proto: A `.proto.game.SavedGame` object from the dataset.
        :param power_names: The list of powers for which we want the policy data
        :param top_victors: The list of powers that ended with more than 25% of the supply centers
        :return: A dictionary with key: the phase_ix
                              with value: A dict with the power_name as key and a dict with the example fields as value

        Every phase except the last one gets an entry.

        Some entries only contain 'board_state', 'prev_orders_state' and 'draw_target'. This happens when the
        power is not a top victor, or when no usable orders can be learned for it in a given phase
        (only holds, missing disbands, no orderable locations, no adjustment orders).

        'draw_target' is 1. only on the second-to-last phase, and only for the surviving powers of a game
        without a solo winner that has at least 2 powers still owning supply centers.
    """
    # pylint: disable=too-many-branches
    nb_phases = len(saved_game_proto.phases)
    policy_data = {phase_ix: {} for phase_ix in range(nb_phases - 1)}
    game_id = saved_game_proto.id
    map_object = Map(saved_game_proto.map)

    # The last phase is never used as an example, so there is nothing to compute with fewer than 2 phases.
    # This also avoids indexing the last phase of a game without any phase.
    if nb_phases < 2:
        return policy_data

    # Determining if we have a draw
    # A power is considered a solo winner when it holds at least len(scs) // 2 + 1 centers.
    # default=0 avoids a crash on a final state without any center entry.
    nb_sc_to_win = len(map_object.scs) // 2 + 1
    final_centers = saved_game_proto.phases[-1].state.centers
    has_solo_winner = max([len(final_centers[power_name].value) for power_name in final_centers],
                          default=0) >= nb_sc_to_win
    survivors = [power_name for power_name in final_centers if final_centers[power_name].value]
    has_draw = not has_solo_winner and len(survivors) >= 2

    # Processing all phases (except the last one)
    for phase_ix in range(nb_phases - 1):

        # Building a list of orders of previous phases.
        # The list is zero-padded so that it always has NB_PREV_ORDERS entries.
        # Only movement phases within the last NB_PREV_ORDERS_HISTORY phases are used.
        previous_orders_states = [np.zeros((NB_NODES, NB_ORDERS_FEATURES), dtype=np.uint8)] * NB_PREV_ORDERS
        for phase_proto in saved_game_proto.phases[max(0, phase_ix - NB_PREV_ORDERS_HISTORY):phase_ix]:
            if phase_proto.name[-1] == 'M':
                previous_orders_states += [proto_to_prev_orders_state(phase_proto, map_object)]
        previous_orders_states = previous_orders_states[-NB_PREV_ORDERS:]
        prev_orders_state = np.array(previous_orders_states)

        # Parsing each requested power in the specified phase
        phase_proto = saved_game_proto.phases[phase_ix]
        phase_name = phase_proto.name
        state_proto = phase_proto.state
        phase_board_state = proto_to_board_state(state_proto, map_object)

        for power_name in power_names:
            phase_issued_orders = get_issued_orders_for_powers(phase_proto, [power_name])
            phase_possible_orders = get_possible_orders_for_powers(phase_proto, [power_name])
            phase_draw_target = 1. if has_draw and phase_ix == (nb_phases - 2) and power_name in survivors else 0.

            # Data to use when not learning a policy
            blank_policy_data = {'board_state': phase_board_state,
                                 'prev_orders_state': prev_orders_state,
                                 'draw_target': phase_draw_target}

            # Power is not a top victor - We don't want to learn a policy from him
            if power_name not in top_victors:
                policy_data[phase_ix][power_name] = blank_policy_data
                continue

            power_issued_orders = phase_issued_orders[power_name]

            # Skipping power for this phase if we are only issuing Hold (or no order at all)
            has_non_hold_order = any(len(order_tokens) >= 2 and order_tokens[1] != 'H'
                                     for order_tokens in map(get_order_tokens, power_issued_orders.values()))
            if not has_non_hold_order:
                policy_data[phase_ix][power_name] = blank_policy_data
                continue

            # Finding the orderable locs, and removing those where orders are not possible (i.e. NO_CHECK games)
            orderable_locations = list(power_issued_orders.keys())
            for order_loc, order in power_issued_orders.items():
                if order not in phase_possible_orders[order_loc]:
                    if 'NO_CHECK' not in saved_game_proto.rules:
                        LOGGER.warning('%s not in all possible orders. Phase %s - Game %s.', order, phase_name, game_id)
                    orderable_locations.remove(order_loc)

            # Determining if we are in an adjustment phase
            in_adjustment_phase = state_proto.name[-1] == 'A'
            nb_builds = state_proto.builds[power_name].count
            nb_homes = len(state_proto.builds[power_name].homes)

            # Number of decoder tokens (TOKENS_PER_ORDER per order):
            # - Adjustment phase with builds: one order per build, capped by the number of `builds.homes` entries
            # - Adjustment phase with disbands: one order per required disband
            # - Other phases: one order per orderable location
            if in_adjustment_phase and nb_builds >= 0:
                decoder_length = TOKENS_PER_ORDER * min(nb_builds, nb_homes)
            elif in_adjustment_phase:
                decoder_length = TOKENS_PER_ORDER * abs(nb_builds)
            else:
                decoder_length = TOKENS_PER_ORDER * len(orderable_locations)

            # Not all units were disbanded - Skipping this power as we can't learn the orders properly
            if in_adjustment_phase and nb_builds < 0 and len(orderable_locations) < abs(nb_builds):
                policy_data[phase_ix][power_name] = blank_policy_data
                continue

            # Not enough orderable locations for this power, skipping
            if not orderable_locations or not decoder_length:
                policy_data[phase_ix][power_name] = blank_policy_data
                continue

            # Encoding decoder inputs - Each order is followed by EOS and padded to TOKENS_PER_ORDER tokens
            # NOTE(review): decoder_length was computed above, before the locations with unknown tokens are
            # dropped below. It can therefore exceed the number of encoded orders - confirm this is intended.
            decoder_inputs = [GO_ID]
            for loc in orderable_locations[:]:
                order = power_issued_orders[loc]
                try:
                    order_ix_tokens = [token_to_ix(order_token) for order_token in get_order_tokens(order)] + [EOS_ID]
                    order_ix_tokens += [PAD_ID] * (TOKENS_PER_ORDER - len(order_ix_tokens))
                    decoder_inputs += order_ix_tokens
                except KeyError:
                    LOGGER.warning('[data_generator] Order "%s" is not valid. Skipping location.', order)
                    orderable_locations.remove(loc)

            # Adding WAIVE orders for the builds that were not issued
            if in_adjustment_phase and nb_builds > 0:
                waive_tokens = [token_to_ix('WAIVE'), EOS_ID] + [PAD_ID] * (TOKENS_PER_ORDER - 2)
                decoder_inputs += waive_tokens * (min(nb_builds, nb_homes) - len(orderable_locations))
            decoder_length = min(decoder_length, TOKENS_PER_ORDER * NB_SUPPLY_CENTERS)

            # Getting decoder mask
            coords = set()

            # Adjustment phase, we allow all builds / disbands in all positions
            if in_adjustment_phase:
                build_disband_locs = list(phase_possible_orders.keys())
                phase_board_alignments = get_board_alignments(build_disband_locs,
                                                              in_adjustment_phase=in_adjustment_phase,
                                                              tokens_per_loc=TOKENS_PER_ORDER,
                                                              decoder_length=decoder_length)

                # Building a list of all orders for all locations
                adj_orders = []
                for loc in build_disband_locs:
                    adj_orders += phase_possible_orders[loc]

                # Not learning builds for BUILD_ANY
                if nb_builds > 0 and 'BUILD_ANY' in state_proto.rules:
                    adj_orders = []

                # No orders found - Skipping
                if not adj_orders:
                    policy_data[phase_ix][power_name] = blank_policy_data
                    continue

                # Building a list of coordinates for the decoder mask matrix
                # NOTE(review): decoder_length is already a number of tokens (a multiple of TOKENS_PER_ORDER).
                # This loop therefore produces decoder_length offsets, not decoder_length // TOKENS_PER_ORDER.
                # Confirm that this is intended.
                for loc_ix in range(decoder_length):
                    coords = get_token_based_mask(adj_orders, offset=loc_ix * TOKENS_PER_ORDER, coords=coords)

            # Regular phase, we mask for each location
            else:
                phase_board_alignments = get_board_alignments(orderable_locations,
                                                              in_adjustment_phase=in_adjustment_phase,
                                                              tokens_per_loc=TOKENS_PER_ORDER,
                                                              decoder_length=decoder_length)
                for loc_ix, loc in enumerate(orderable_locations):
                    coords = get_token_based_mask(phase_possible_orders[loc] or [''],
                                                  offset=loc_ix * TOKENS_PER_ORDER,
                                                  coords=coords)

            # Saving results
            # No need to return temperature, current_power, current_season
            policy_data[phase_ix][power_name] = {'board_state': phase_board_state,
                                                 'board_alignments': phase_board_alignments,
                                                 'prev_orders_state': prev_orders_state,
                                                 'decoder_inputs': decoder_inputs,
                                                 'decoder_lengths': decoder_length,
                                                 'decoder_mask_indices': list(sorted(coords)),
                                                 'draw_target': phase_draw_target}

    # Returning
    return policy_data