# Example #1
def build_game_from_state_proto(state_proto):
    """ Creates a Game object matching the given state proto.
        :param state_proto: A state proto exposing `map`, `rules`, `name` (the phase name) and the
                            `units` / `centers` maps keyed by power name, whose entries expose `.value`.
        :return: The reconstructed game object.
    """
    new_game = Game(map_name=state_proto.map, rules=state_proto.rules)
    new_game.set_current_phase(state_proto.name)

    # Units - clear whatever the game holds, then apply each power's units from the proto
    new_game.clear_units()
    proto_units = state_proto.units
    for power in proto_units:
        new_game.set_units(power, list(proto_units[power].value))

    # Centers - same treatment for supply centers
    new_game.clear_centers()
    proto_centers = state_proto.centers
    for power in proto_centers:
        new_game.set_centers(power, list(proto_centers[power].value))

    return new_game
# Example #2
def _apply_phase_board(game, phase):
    """ Sets the current phase name, units and centers of `game` from a processed phase dict.
        Entries for the 'GLOBAL' pseudo-power are skipped.
        :param game: The diplomacy.Game object to update in place.
        :param phase: A processed phase dict (as built by process_phase_dict) with the keys
                      'name', 'units' and 'centers'.
    """
    game.set_current_phase(phase['name'])

    # Units
    game.clear_units()
    for power_name, power_units in phase['units'].items():
        if power_name != 'GLOBAL':
            game.set_units(power_name, power_units)

    # Centers
    game.clear_centers()
    for power_name, power_centers in phase['centers'].items():
        if power_name != 'GLOBAL':
            game.set_centers(power_name, power_centers)


def _prune_invalid_retreats(game, standoffs, occupied_from, map_id):
    """ Removes disallowed locations from every power's retreat options.
        A retreat location is removed when it is occupied by a unit, when it was the site of a standoff,
        or when it is the location recorded in `occupied_from` for the retreating unit's location.
        :param game: The diplomacy.Game object (in a retreat phase) to update in place.
        :param standoffs: The 'standoffs' entry of the webdiplomacy state dict (may be empty / falsy).
        :param occupied_from: The 'occupiedFrom' entry of the state dict, mapping a loc id to another loc id
                              (presumably where the new occupant came from - may be empty / falsy).
        :param map_id: The webdiplomacy variant id, used to translate location ids to names.
    """
    invalid_retreat_locs = set()
    attack_source = {}

    # Loc is occupied
    for power in game.powers.values():
        for unit in power.units:
            invalid_retreat_locs.add(unit[2:5])

    # Loc was in standoff
    for loc_dict in standoffs or []:
        _, loc = center_dict_to_str(loc_dict, map_id=map_id)
        invalid_retreat_locs.add(loc[:3])

    # Loc was attacked from
    for loc_id, occupied_from_id in (occupied_from or {}).items():
        loc_name = CACHE[map_id]['ix_to_loc'][int(loc_id)][:3]
        from_loc_name = CACHE[map_id]['ix_to_loc'][int(occupied_from_id)][:3]
        attack_source[loc_name] = from_loc_name

    # Removing invalid retreat locs (values are replaced; the keys being iterated are not modified)
    for power in game.powers.values():
        for retreat_unit in power.retreats:
            power.retreats[retreat_unit] = [loc for loc in power.retreats[retreat_unit]
                                            if loc[:3] not in invalid_retreat_locs
                                            and loc[:3] != attack_source.get(retreat_unit[2:5], '')]


def state_dict_to_game_and_power(state_dict, country_id, max_phases=None):
    """ Converts a game state from the dictionary format to an actual diplomacy.Game object with the related power.
        :param state_dict: The game state in dictionary format from webdiplomacy.net
        :param country_id: The country id we want to convert.
        :param max_phases: Optional. If set to a positive int, improve speed by only keeping the last 'x' phases
                           to regenerate the game. Any other value keeps every phase.
        :return: A tuple of
            1) None, None       - on error or if the conversion is not possible
                                  (e.g. state_dict is None, a required field is missing, or there are no phases)
            2) game, power_name - on successful conversion

        NOTE(review): 'gameOver' must be present, but its value is not checked here - confirm whether
        finished / not-started games are filtered out by the caller.
    """
    if state_dict is None:
        return None, None

    req_fields = ('gameID', 'variantID', 'turn', 'phase', 'gameOver', 'phases', 'standoffs', 'occupiedFrom')
    if any(field not in state_dict for field in req_fields):
        LOGGER.error('The required fields for state dict are %s. Cannot translate %s', req_fields, state_dict)
        return None, None

    # Extracting information
    game_id = str(state_dict['gameID'])
    map_id = int(state_dict['variantID'])
    standoffs = state_dict['standoffs']
    occupied_from = state_dict['occupiedFrom']

    # Parsing phases - only a positive limit trims the history
    # ([-0:] would keep everything and a negative bound would drop the *oldest* phases instead)
    state_dict_phases = state_dict.get('phases') or []
    if isinstance(max_phases, int) and max_phases > 0:
        state_dict_phases = state_dict_phases[-max_phases:]
    all_phases = [process_phase_dict(phase_dict, map_id=map_id) for phase_dict in state_dict_phases]

    # The last phase is the current one, so at least one phase is required
    if not all_phases:
        LOGGER.error('No phases found in state dict. Cannot translate %s', state_dict)
        return None, None

    # Building game - Replaying every phase but the last one
    game = Game(game_id=game_id, map_name=CACHE['ix_to_map'][map_id])
    for phase_to_replay in all_phases[:-1]:
        _apply_phase_board(game, phase_to_replay)

        # Orders
        for power_name, power_orders in phase_to_replay['orders'].items():
            if power_name != 'GLOBAL':
                game.set_orders(power_name, power_orders)

        # Processing
        game.process()

    # Setting the current phase (its orders are not submitted)
    current_phase = all_phases[-1]
    _apply_phase_board(game, current_phase)

    # Setting retreat locs
    if current_phase['name'].endswith('R'):
        _prune_invalid_retreats(game, standoffs, occupied_from, map_id)

    # Returning
    power_name = CACHE[map_id]['ix_to_power'][country_id]
    return game, power_name
def test_survivor_win_reward():
    """ Checks SurvivorWinReward: nothing is paid outside terminal states or when thrashed,
        otherwise the pot is shared among the surviving powers. """
    game = Game()
    pot_size = 20
    rew_fn = SurvivorWinReward(pot_size=pot_size)
    prev_state_proto = extract_state_proto(game)
    state_proto = extract_state_proto(game)
    assert rew_fn.name == 'survivor_win_reward'
    powers = ('AUSTRIA', 'ENGLAND', 'FRANCE', 'GERMANY', 'ITALY', 'RUSSIA', 'TURKEY')

    def reward(power_name, is_terminal, done_reason):
        """ Reward of `power_name`, rounded to 8 digits, between the latest prev_state_proto and state_proto """
        raw = rew_fn.get_reward(prev_state_proto, state_proto, power_name,
                                is_terminal_state=is_terminal, done_reason=done_reason)
        return round(raw, 8)

    def check(is_terminal, done_reason, expected=None):
        """ Asserts every power's reward; powers absent from `expected` must receive 0. """
        expected = expected or {}
        for power_name in powers:
            assert reward(power_name, is_terminal, done_reason) == expected.get(power_name, 0.)

    # Initial board - not terminal / terminal (7 survivors share the pot) / thrashing
    check(False, None)
    check(True, DoneReason.GAME_ENGINE, {power_name: round(pot_size / 7., 8) for power_name in powers})
    check(True, DoneReason.THRASHED)

    # Only FRANCE and RUSSIA keep units and supply centers
    prev_state_proto = extract_state_proto(game)
    for power in game.powers.values():
        if power.name not in ('FRANCE', 'RUSSIA'):
            power.clear_units()
            power.centers = []
    state_proto = extract_state_proto(game)

    half_pot = round(pot_size / 2., 8)
    check(True, DoneReason.GAME_ENGINE, {'FRANCE': half_pot, 'RUSSIA': half_pot})
    check(True, DoneReason.THRASHED)

    # Move centers in other countries to FRANCE except ENGLAND
    # Winner: FRANCE
    # Survivor: FRANCE, ENGLAND
    game = Game()
    prev_state_proto = extract_state_proto(game)
    game.clear_centers()
    game.set_centers('FRANCE', [
        'BUD', 'TRI', 'VIE', 'BRE', 'MAR', 'PAR', 'BER', 'KIE', 'MUN', 'NAP',
        'ROM', 'VEN', 'MOS', 'SEV', 'STP', 'WAR', 'ANK', 'CON', 'SMY'
    ])
    game.set_centers('ENGLAND', ['EDI', 'LON', 'LVP'])
    state_proto = extract_state_proto(game)

    # France has 19 SC, 18 to win, 1 excess
    # Nb of controlled SC is 19 + 3 - 1 excess = 21
    # Reward for FRANCE is 18 / 21 * pot
    # Reward for ENGLAND is 3 / 21 * pot
    check(True, DoneReason.GAME_ENGINE, {'ENGLAND': round(pot_size * 3. / 21, 8),
                                         'FRANCE': round(pot_size * 18. / 21, 8)})
    check(True, DoneReason.THRASHED)
# Example #4
def create_mask(board_state, phase, locs, board_dict):
    '''
    Given a board state, produces a mask that only keeps the orders that are valid for each loc,
    based on their positions in get_order_vocabulary in state_space. Assumes playing on standard map.

    Args:
    board_state - 81 x dilbo vector representing current board state as described in Figure 2.
                  Only its length (the number of locations) is used here; units and centers are
                  read from board_dict.
    phase - string indicating phase of game (e.g. 'S1901M')
    locs - list of strings representing locations (e.g. ['PAR', ... ])
    board_dict - mapping from location name (as listed in ORDERING) to a dict that may contain the keys
                 'unit_type', 'unit_power', 'd_unit_type', 'd_unit_power' and 'supply_center_owner'
                 (missing keys are treated as None)

    Returns:
    A numpy array of shape (len(locs), ORDER_VOCABULARY_SIZE); entry [i, ix] is 0 when order ix is a
    possible order for locs[i] and -(10**15) otherwise (integer dtype, inferred from the fill value).
    '''
    # Create a Game object matching board_dict
    game = Game(map_name='standard')
    game.set_current_phase(phase)

    power_units = {}
    power_centers = {}

    for loc_idx in range(len(board_state)):
        loc_name = ORDERING[loc_idx]
        loc_info = board_dict[loc_name]
        unit_type = loc_info.get("unit_type")
        unit_power = loc_info.get("unit_power")
        dislodged_unit_type = loc_info.get("d_unit_type")
        dislodged_unit_power = loc_info.get("d_unit_power")
        supply_center_owner = loc_info.get("supply_center_owner")

        # Unit strings follow the set_units() format; a leading '*' marks a dislodged unit
        if unit_type is not None:
            power_units.setdefault(unit_power, []).append('{} {}'.format(unit_type, loc_name))
        if dislodged_unit_type is not None:
            # Keyed by the dislodged unit's power, so units already collected for that power are kept
            power_units.setdefault(dislodged_unit_power, []).append('*{} {}'.format(dislodged_unit_type, loc_name))
        if supply_center_owner is not None:
            power_centers.setdefault(supply_center_owner, []).append(loc_name)

    # Setting units
    game.clear_units()
    for power_name, units in power_units.items():
        game.set_units(power_name, units)

    # Setting centers
    game.clear_centers()
    for power_name, centers in power_centers.items():
        game.set_centers(power_name, centers)

    possible_orders = game.get_all_possible_orders()

    # Masking a softmax ideally uses negative infinity, but that gives nan in the loss,
    # so a very large negative number is used instead
    masks = np.full((len(locs), ORDER_VOCABULARY_SIZE), -(10**15))
    for row, loc in enumerate(locs):
        for order in possible_orders[loc]:
            masks[row, state_space.order_to_ix(order)] = 0
    return masks
def test_sum_of_squares_reward():
    """ Checks SumOfSquares: nothing is paid outside terminal states or when thrashed, otherwise
        the pot is split in proportion to the squared number of supply centers (solo winner takes all). """
    game = Game()
    pot_size = 20
    rew_fn = SumOfSquares(pot_size=pot_size)
    prev_state_proto = extract_state_proto(game)
    state_proto = extract_state_proto(game)
    assert rew_fn.name == 'sum_of_squares_reward'
    powers = ('AUSTRIA', 'ENGLAND', 'FRANCE', 'GERMANY', 'ITALY', 'RUSSIA', 'TURKEY')

    def reward(power_name, is_terminal, done_reason):
        """ Reward of `power_name`, rounded to 8 digits, between the latest prev_state_proto and state_proto """
        raw = rew_fn.get_reward(prev_state_proto, state_proto, power_name,
                                is_terminal_state=is_terminal, done_reason=done_reason)
        return round(raw, 8)

    def check(is_terminal, done_reason, expected=None):
        """ Asserts every power's reward; powers absent from `expected` must receive 0. """
        expected = expected or {}
        for power_name in powers:
            assert reward(power_name, is_terminal, done_reason) == expected.get(power_name, 0.)

    # Initial board - shares are 3**2 = 9 per power and 4**2 = 16 for RUSSIA, out of 6 * 9 + 16 = 70
    initial_shares = {power_name: round(pot_size * 9 / 70., 8) for power_name in powers}
    initial_shares['RUSSIA'] = round(pot_size * 16 / 70., 8)
    check(False, None)
    check(True, DoneReason.GAME_ENGINE, initial_shares)
    check(True, DoneReason.THRASHED)

    # Only FRANCE (3 SC) and RUSSIA (4 SC) keep units and supply centers - 9 + 16 = 25
    prev_state_proto = extract_state_proto(game)
    for power in game.powers.values():
        if power.name not in ('FRANCE', 'RUSSIA'):
            power.clear_units()
            power.centers = []
    state_proto = extract_state_proto(game)

    check(True, DoneReason.GAME_ENGINE, {'FRANCE': round(pot_size * 9 / 25., 8),
                                         'RUSSIA': round(pot_size * 16 / 25., 8)})
    check(True, DoneReason.THRASHED)

    # Move centers in other countries to FRANCE except ENGLAND
    # Winner: FRANCE
    # Survivor: FRANCE, ENGLAND
    game = Game()
    prev_state_proto = extract_state_proto(game)
    game.clear_centers()
    game.set_centers('FRANCE', [
        'BUD', 'TRI', 'VIE', 'BRE', 'MAR', 'PAR', 'BER', 'KIE', 'MUN', 'NAP',
        'ROM', 'VEN', 'MOS', 'SEV', 'STP', 'WAR', 'ANK', 'CON', 'SMY'
    ])
    game.set_centers('ENGLAND', ['EDI', 'LON', 'LVP'])
    state_proto = extract_state_proto(game)

    # Victory - FRANCE takes the whole pot
    check(True, DoneReason.GAME_ENGINE, {'FRANCE': round(pot_size, 8)})
    check(True, DoneReason.THRASHED)