예제 #1
0
def root_data_generator(saved_game_proto, is_validation_set):
    """ Builds one serialized example per (phase, power) for a saved dataset game
        :param saved_game_proto: A `.proto.game.SavedGame` object from the dataset.
        :param is_validation_set: Boolean that indicates if we are generating the validation set (otw. training set)
        :return: A dictionary with phase_ix as key and a dictionary {power_name: (msg_len, proto)} as value.
                 Every phase index gets an entry, but the last phase is never populated, and msg_len is always 0.
    """
    # Map-derived information: top victors of the game and the powers on the map
    game_map = Map(saved_game_proto.map)
    victors = get_top_victors(saved_game_proto, game_map)
    powers = get_map_powers(game_map)
    phase_count = len(saved_game_proto.phases)

    # One (initially empty) bucket per phase
    results_by_phase = {}
    for bucket_ix in range(phase_count):
        results_by_phase[bucket_ix] = {}

    # Policy data is requested for all powers; the `top_victors` argument is
    # the real top victors for the validation set, and all powers for training
    if is_validation_set:
        policy_victors = victors
    else:
        policy_victors = powers
    policy_by_phase = get_policy_data(saved_game_proto,
                                      power_names=powers,
                                      top_victors=policy_victors)
    value_by_phase = get_value_data(saved_game_proto, powers)

    # Building one example per phase (except the last one) and per retained power
    for phase_ix in range(phase_count - 1):
        for power_name in powers:

            # The validation set only keeps the top victors
            is_skipped = is_validation_set and power_name not in victors
            if is_skipped:
                continue

            policy_features = policy_by_phase[phase_ix][power_name]
            value_features = value_by_phase[phase_ix][power_name]
            request_id = DatasetBuilder.get_request_id(saved_game_proto, phase_ix, power_name, is_validation_set)

            # Fixed request fields first, then the policy and value features on top
            features = {
                'request_id': request_id,
                'player_seed': 0,
                'decoder_inputs': [GO_ID],
                'noise': 0.,
                'temperature': 0.,
                'dropout_rate': 0.,
                'current_power': POWER_VOCABULARY_KEY_TO_IX[power_name],
                'current_season': get_current_season(saved_game_proto.phases[phase_ix].state),
            }
            features.update(policy_features)
            features.update(value_features)

            # Serializing the example and storing it under its phase and power
            example = BaseDatasetBuilder.build_example(features, BaseDatasetBuilder.get_proto_fields())
            results_by_phase[phase_ix][power_name] = (0, example)

    # Returning data for buffer
    return results_by_phase
예제 #2
0
    def get_orders(self, game, power_names):
        """
        See diplomacy_research.players.player.Player.get_orders

        Runs the model once per requested power, using the current board state and the most
        recent movement phase found in the last 3 phases of history, and greedily selects the
        argmax order for each row of the model output.

        :param game: Game object
        :param power_names: A list of power names we are playing, or alternatively a single power name
                            (a single name is handled as a one-element list).
        :return: A tuple (orders, order_probs) where
                 - orders is a list with one entry per power, each entry being the list of orders
                   (looked up in INVERSE_ORDER_DICT) selected for that power
                 - order_probs is a list with one entry per power, each entry being the list of
                   model output values at the selected order indices
        """
        # A bare string would otherwise be iterated character by character below
        if isinstance(power_names, str):
            power_names = [power_names]

        # num_dummies/tiling is hacky way to get around TF Strided Slice error
        # that occurs when only passing in one state (e.g. batch size of 1)
        num_dummies = 2
        order_history = extract_phase_history_proto(game, 3)

        # Getting last movement phase (phase name ending in "M"), or None when the history
        # is empty or holds no movement phase
        prev_movement_phase = next(
            (phase for phase in reversed(order_history) if phase.name[-1] == "M"),
            None)

        if prev_movement_phase is None:
            # No previous movement orders available - use an all-zero previous orders state
            prev_orders_state = tf.zeros((1, 81, 40), dtype=tf.float32)
        else:
            prev_orders_state = proto_to_prev_orders_state(
                prev_movement_phase, game.map).flatten().tolist()
            prev_orders_state = tf.reshape(prev_orders_state, (1, 81, 40))
        prev_orders_state_with_dummies = tf.tile(prev_orders_state,
                                                 [num_dummies, 1, 1])

        board_state = dict_to_flatten_board_state(game.get_state(), game.map)
        board_state = tf.reshape(board_state, (1, 81, 35))
        board_state_with_dummies = tf.tile(board_state, [num_dummies, 1, 1])
        season = get_current_season(extract_state_proto(game))
        state = game.get_state()
        year = state["name"]
        board_dict = parse_rl_state(state)

        orders = []
        order_probs = []
        for power in power_names:
            print(power, year)
            power_season = tf.concat([UNIT_POWER[power], INT_SEASON[season]],
                                     axis=0)
            power_season = tf.expand_dims(power_season, axis=0)
            power_season_with_dummies = tf.tile(power_season, [num_dummies, 1])
            probs, position_list = self.call(
                board_state_with_dummies, prev_orders_state_with_dummies,
                power_season_with_dummies, [year for _ in range(num_dummies)],
                [board_dict for _ in range(num_dummies)], power)

            # NOTE(review): the slice keeps index 0 of the second axis - presumably this drops
            # the tiled dummy copy; confirm against the output shape of self.call
            prob_no_dummies = tf.squeeze(probs, axis=1)[:, 0, :]
            order_ix = tf.argmax(prob_no_dummies, axis=1)
            order_indices = order_ix.numpy()
            orders_list = [INVERSE_ORDER_DICT[index] for index in order_indices]
            orders_probs_list = [
                prob_no_dummies[i][index]
                for i, index in enumerate(order_indices)
            ]
            orders.append(orders_list)
            order_probs.append(orders_probs_list)
        return orders, order_probs
예제 #3
0
    def get_feedable_item(locs, state_proto, power_name, phase_history_proto, possible_orders_proto, **kwargs):
        """ Builds the feature dictionary (feedable item) to push into the feedable queue
            :param locs: A list of locations for which we want orders
            :param state_proto: A `.proto.game.State` representation of the state of the game.
            :param power_name: The power name for which we want the orders and the state values
            :param phase_history_proto: A list of `.proto.game.PhaseHistory`. This represents prev phases.
            :param possible_orders_proto: A `proto.game.PossibleOrders` object representing possible order for each loc.
            :param kwargs: Additional optional kwargs:
                - player_seed: The seed to apply to the player to compute a deterministic mask. (Default 0)
                - noise: The sigma of the additional noise to apply to the intermediate layers. (Default 0.)
                - temperature: The temperature to apply to the logits. (Default to 0. for deterministic/greedy)
                - dropout_rate: The amount of dropout to apply to the inputs/outputs of the decoder. (Default 0.)
            :return: A feedable item, with feature names as key and the feature values as values
        """
        # pylint: disable=too-many-branches
        # State space representation of the board
        game_map = Map(state_proto.map)
        board_state = proto_to_board_state(state_proto, game_map)

        # Build information for the power - Used to bound the decoder length in adjustment phases
        is_adjustment = state_proto.name[-1] == 'A'
        power_builds = state_proto.builds[power_name]
        nb_builds = power_builds.count
        nb_homes = len(power_builds.homes)

        # In adjustment phase, the locs must be the orderable locs (and not the policy locs)
        if is_adjustment:
            orderable_locs, _ = get_orderable_locs_for_powers(state_proto, [power_name])
            locs_match = sorted(locs) == sorted(orderable_locs)
            if not locs_match:
                if locs:
                    LOGGER.warning('Adj. phase requires orderable locs. Got %s. Expected %s.', locs, orderable_locs)
                locs = orderable_locs

        # Decoder length
        # Non-adjustment phase: One order per location
        # Adjustment with builds: Limited by the number of builds and of home centers
        # Adjustment with disbands: Number of units to remove
        if not is_adjustment:
            decoder_length = len(locs)
        elif nb_builds >= 0:
            decoder_length = min(nb_builds, nb_homes)
        else:
            decoder_length = abs(nb_builds)

        # Candidates for the policy
        if not possible_orders_proto:
            # Without possible orders, no candidates can be computed
            # This might be normal if we are only getting the state value or the next message to send
            candidates = [[] for _ in range(decoder_length)]
        elif is_adjustment:
            # Adjustment phase - The same mask (built from the orders of all locs) is used at every position
            adj_orders = []
            for loc in locs:
                adj_orders.extend(possible_orders_proto[loc].value)
            shared_mask = get_order_based_mask(adj_orders)
            candidates = [shared_mask] * decoder_length
        else:
            # Regular phase - One mask per location
            candidates = [get_order_based_mask(possible_orders_proto[loc].value) for loc in locs]

        # Prev orders state - Collecting the latest movement phases (most recent first)
        recent_movements = []
        for phase_proto in reversed(phase_history_proto):
            if len(recent_movements) == NB_PREV_ORDERS:
                break
            if phase_proto.name[-1] == 'M':
                recent_movements.append(proto_to_prev_orders_state(phase_proto, game_map))

        # Restoring chronological order and left-padding with zeros up to NB_PREV_ORDERS entries
        nb_missing = NB_PREV_ORDERS - len(recent_movements)
        padding = [np.zeros((NB_NODES, NB_ORDERS_FEATURES), dtype=np.uint8) for _ in range(nb_missing)]
        prev_orders_state = np.array(padding + recent_movements[::-1])

        # (Order) decoder inputs [GO_ID]
        decoder_inputs = [GO_ID]

        # Alignments of the locs on the board
        board_alignments = get_board_alignments(locs,
                                                in_adjustment_phase=is_adjustment,
                                                tokens_per_loc=1,
                                                decoder_length=decoder_length)

        # Feedable data - Optional settings are read from kwargs with their defaults
        return {
            'player_seed': kwargs.get('player_seed', 0),
            'board_state': board_state,
            'board_alignments': board_alignments,
            'prev_orders_state': prev_orders_state,
            'decoder_inputs': decoder_inputs,
            'decoder_lengths': decoder_length,
            'candidates': candidates,
            'noise': kwargs.get('noise', 0.),
            'temperature': kwargs.get('temperature', 0.),
            'dropout_rate': kwargs.get('dropout_rate', 0.),
            'current_power': POWER_VOCABULARY_KEY_TO_IX[power_name],
            'current_season': get_current_season(state_proto)
        }