Exemplo n.º 1
0
    def test_get_state_value(self):
        """ Verifies that .get_state_value yields a non-zero value, queried directly and via prefetching """
        game = Game()
        state_proto = extract_state_proto(game)
        phase_history_proto = extract_phase_history_proto(game)
        model_kwargs = dict(player_seed=0, noise=0., temperature=0., dropout_rate=0.)
        call_args = (state_proto, 'FRANCE', phase_history_proto)

        # Pass 1: direct query. Pass 2: prefetch -> process the fetches -> query with the processed fetches.
        for use_prefetching in (False, True):
            if use_prefetching:
                fetches = yield self.adapter.get_state_value(*call_args, prefetch=True, **model_kwargs)
                fetches = yield process_fetches_dict(self.queue_dataset, fetches)
                results = yield self.adapter.get_state_value(*call_args, fetches=fetches, **model_kwargs)
            else:
                results = yield self.adapter.get_state_value(*call_args, **model_kwargs)
            assert results != 0.
Exemplo n.º 2
0
def test_board_state():
    """ Checks that proto_to_board_state is stable across a game -> proto -> game -> proto round trip """
    original_game = Game()
    game_map = original_game.map

    # Round trip: original game -> proto -> rebuilt game -> proto
    first_proto = state_space.extract_state_proto(original_game)
    rebuilt_game = state_space.build_game_from_state_proto(first_proto)
    second_proto = state_space.extract_state_proto(rebuilt_game)

    # Both protos are encoded against the original game's map
    first_board, second_board = [state_space.proto_to_board_state(proto, game_map)
                                 for proto in (first_proto, second_proto)]

    assert np.allclose(first_board, second_board)
    assert first_board.shape == (state_space.NB_NODES, state_space.NB_FEATURES)
    assert original_game.get_hash() == rebuilt_game.get_hash()
Exemplo n.º 3
0
    def get_opening_orders(self):
        """ Computes the opening orders this player would submit for every power
            :return: A dictionary {power_name: [orders]} built from a fresh Game
        """
        game = Game()
        state_proto = extract_state_proto(game)
        phase_history_proto = extract_phase_history_proto(game)
        possible_orders_proto = extract_possible_orders_proto(game)
        power_names = list(game.powers)

        # The default player_seed, noise, temperature, and dropout_rate are used on purpose.
        # Each result is a tuple (orders, policy_details); only the orders are kept.
        pending = []
        for power_name in power_names:
            orderable_locs = self.get_orderable_locations(state_proto, power_name)
            pending.append(self.policy_adapter.get_orders(orderable_locs,
                                                          state_proto,
                                                          power_name,
                                                          phase_history_proto,
                                                          possible_orders_proto,
                                                          retry_on_failure=False))
        results = yield pending
        return {power_name: result[0] for power_name, result in zip(power_names, results)}
def test_int_norm_centers_reward():
    """ Tests for IntNormNbCentersReward (interim normalized nb of centers reward) """
    game = Game()
    rew_fn = IntNormNbCentersReward()

    # PAR is transferred from FRANCE to GERMANY between the two snapshots
    prev_state_proto = extract_state_proto(game)
    for power in game.powers.values():
        if power.name == 'FRANCE':
            power.centers.remove('PAR')
        if power.name == 'GERMANY':
            power.centers.append('PAR')
    state_proto = extract_state_proto(game)
    assert rew_fn.name == 'int_norm_nb_centers_reward'

    def get_reward(power_name, is_terminal, done_reason):
        """ Reward of power_name for the prev_state_proto -> state_proto transition """
        return rew_fn.get_reward(prev_state_proto,
                                 state_proto,
                                 power_name,
                                 is_terminal_state=is_terminal,
                                 done_reason=done_reason)

    power_names = ('AUSTRIA', 'ENGLAND', 'FRANCE', 'GERMANY', 'ITALY', 'RUSSIA', 'TURKEY')
    center_delta_rewards = {'FRANCE': -1. / 18, 'GERMANY': 1. / 18}

    # --- Not in terminal state, then in terminal state: only FRANCE and GERMANY are affected
    for is_terminal, done_reason in ((False, None), (True, DoneReason.GAME_ENGINE)):
        for power_name in power_names:
            assert get_reward(power_name, is_terminal, done_reason) == center_delta_rewards.get(power_name, 0.)

    # --- Thrashing: every power is penalized
    for power_name in power_names:
        assert get_reward(power_name, True, DoneReason.THRASHED) == -1.
Exemplo n.º 5
0
def test_get_nb_centers():
    """ Checks that FakePlayer.get_nb_centers matches the game's center count for every power """
    game = Game()
    player = FakePlayer()
    state_proto = extract_state_proto(game)

    for power_name in list(game.powers):
        nb_centers = player.get_nb_centers(state_proto, power_name)
        assert nb_centers == len(game.get_power(power_name).centers)
Exemplo n.º 6
0
def test_get_orderable_locations():
    """ Checks that FakePlayer.get_orderable_locations returns the unit locations plus the build homes """
    game = Game()
    player = FakePlayer()
    state_proto = extract_state_proto(game)

    for power_name in list(game.powers):
        # Presumably unit strings look like 'A PAR' (optionally '*'-prefixed); chars [2:5] hold the location
        unit_locs = [unit.replace('*', '')[2:5] for unit in state_proto.units[power_name].value]
        home_locs = list(state_proto.builds[power_name].homes)
        orderable_locs = player.get_orderable_locations(state_proto, power_name)
        assert sorted(orderable_locs) == sorted(unit_locs + home_locs)
def test_norm_centers_reward():
    """ Tests for NormNbCentersReward on an unchanged starting board """
    game = Game()
    rew_fn = NormNbCentersReward()
    prev_state_proto = extract_state_proto(game)
    state_proto = extract_state_proto(game)
    assert rew_fn.name == 'norm_nb_centers_reward'

    def get_reward(power_name, is_terminal, done_reason):
        """ Reward of power_name for the prev_state_proto -> state_proto transition """
        return rew_fn.get_reward(prev_state_proto,
                                 state_proto,
                                 power_name,
                                 is_terminal_state=is_terminal,
                                 done_reason=done_reason)

    # Expected rewards when the game ends through the game engine
    terminal_rewards = {'AUSTRIA': 3. / 18,
                        'ENGLAND': 3. / 18,
                        'FRANCE': 3. / 18,
                        'GERMANY': 3. / 18,
                        'ITALY': 3. / 18,
                        'RUSSIA': 4. / 18,
                        'TURKEY': 3. / 18}

    # --- Not in terminal state
    for power_name in terminal_rewards:
        assert get_reward(power_name, False, None) == 0.

    # --- In terminal state
    for power_name, expected_reward in terminal_rewards.items():
        assert get_reward(power_name, True, DoneReason.GAME_ENGINE) == expected_reward

    # --- Thrashing
    for power_name in terminal_rewards:
        assert get_reward(power_name, True, DoneReason.THRASHED) == 0.
Exemplo n.º 8
0
    def get_orders(self,
                   game,
                   power_names,
                   *,
                   retry_on_failure=True,
                   **kwargs):
        """ Gets the orders the power(s) should play.
            :param game: The game object
            :param power_names: A list (or tuple) of power names we are playing, or alternatively a single power name.
            :param retry_on_failure: Boolean that indicates to retry querying from the model if an error is encountered.
            :param kwargs: Additional optional kwargs:
                - player_seed: If set. Override the player_seed to use for the model based player.
                - noise: If set. Override the noise to use for the model based player.
                - temperature: If set. Override the temperature to use for the model based player.
                - dropout_rate: If set. Override the dropout_rate to use for the model based player.
                - with_draw: If set, also returns whether to accept a draw or not
            :return: One of the following:
                1) If power_names is a single power name and with_draw == False (or is not set):
                    - A list of orders the power should play
                2) If power_names is a list/tuple and with_draw == False (or is not set):
                    - A list of list, which contains orders for each power
                3) If power_names is a single power name and with_draw == True:
                    - A tuple of 1) the list of orders for the power, 2) a boolean to accept a draw or not
                4) If power_names is a list/tuple and with_draw == True:
                    - A list of tuples, each tuple having the list of orders and the draw boolean
            :type game: diplomacy.Game
        """
        state_proto = extract_state_proto(game)
        phase_history_proto = extract_phase_history_proto(game)
        possible_orders_proto = extract_possible_orders_proto(game)

        # A list or tuple holds several power names; any other value is a single power name.
        # Tuples are normalized to a list so they are not mistaken for one power name.
        is_single_power = not isinstance(power_names, (list, tuple))
        power_names = [power_names] if is_single_power else list(power_names)

        # Getting orders (and optional draw) for each requested power
        orders_with_maybe_draw = yield [
            self.get_orders_with_proto(state_proto,
                                       power_name,
                                       phase_history_proto,
                                       possible_orders_proto,
                                       retry_on_failure=retry_on_failure,
                                       **kwargs) for power_name in power_names
        ]

        # Returning a single instance, or a list
        return orders_with_maybe_draw[0] if is_single_power else orders_with_maybe_draw
Exemplo n.º 9
0
 def test_get_feedable_item(self):
     """ Verifies that .get_feedable_item returns a truthy item for FRANCE's PAR, MAR and BUR """
     game = Game()
     state_proto = extract_state_proto(game)
     phase_history_proto = extract_phase_history_proto(game)
     possible_orders_proto = extract_possible_orders_proto(game)
     locs = ['PAR', 'MAR', 'BUR']

     # Fixed seed, no noise, temperature 0 and no dropout
     model_kwargs = dict(player_seed=0, noise=0., temperature=0., dropout_rate=0.)

     feedable_item = self.dataset_builder.get_feedable_item(locs,
                                                            state_proto,
                                                            'FRANCE',
                                                            phase_history_proto,
                                                            possible_orders_proto,
                                                            **model_kwargs)
     assert feedable_item
Exemplo n.º 10
0
    def get_policy_details(self,
                           game,
                           power_names,
                           *,
                           retry_on_failure=True,
                           **kwargs):
        """ Gets the details of the current policy
            :param game: The game object
            :param power_names: A list (or tuple) of power names we are playing, or alternatively a single power name.
            :param retry_on_failure: Boolean that indicates to retry querying from the model if an error is encountered.
            :param kwargs: Additional optional kwargs:
                - player_seed: If set. Override the player_seed to use for the model based player.
                - noise: If set. Override the noise to use for the model based player.
                - temperature: If set. Override the temperature to use for the model based player.
                - dropout_rate: If set. Override the dropout_rate to use for the model based player.
            :return: 1) If power_names is a single power name, the policy details
                        ==> {'locs', 'tokens', 'log_probs', 'draw_action', 'draw_prob'}
                     2) If power_names is a list/tuple, a list of policy details, one for each power.
            :type game: diplomacy.Game
        """
        state_proto = extract_state_proto(game)
        phase_history_proto = extract_phase_history_proto(game)
        possible_orders_proto = extract_possible_orders_proto(game)

        # A list or tuple holds several power names; any other value is a single power name.
        # Tuples are normalized to a list so they are not mistaken for one power name.
        is_single_power = not isinstance(power_names, (list, tuple))
        power_names = [power_names] if is_single_power else list(power_names)

        # Getting policy details for each requested power
        policy_details = yield [
            self.get_policy_details_with_proto(
                state_proto,
                power_name,
                phase_history_proto,
                possible_orders_proto,
                retry_on_failure=retry_on_failure,
                **kwargs) for power_name in power_names
        ]

        # Returning a single instance, or a list
        return policy_details[0] if is_single_power else policy_details
Exemplo n.º 11
0
    def get_beam_orders(self,
                        game,
                        power_names,
                        *,
                        retry_on_failure=True,
                        **kwargs):
        """ Finds all the beams with their probabilities returned by the diverse beam search for the selected power(s)
            Beams are ordered by score (highest first).
            :param game: The game object
            :param power_names: A list (or tuple) of power names we are playing, or alternatively a single power name.
            :param retry_on_failure: Boolean that indicates to retry querying from the model if an error is encountered.
            :param kwargs: Additional optional kwargs:
                - player_seed: The seed to apply to the player to compute a deterministic mask.
                - noise: The sigma of the additional noise to apply to the intermediate layers (i.e. sigma * epsilon)
                - temperature: The temperature to apply to the logits. (Default to 0. for deterministic/greedy)
                - dropout_rate: The amount of dropout to apply to the inputs/outputs of the decoder.
            :return: 1) If power_names is a single power name, a tuple of beam orders, and of beam probabilities
                     2) If power_names is a list/tuple, a list of list which contains beam orders and beam probabilities
            :type game: diplomacy.Game
        """
        state_proto = extract_state_proto(game)
        phase_history_proto = extract_phase_history_proto(game)
        possible_orders_proto = extract_possible_orders_proto(game)

        # A list or tuple holds several power names; any other value is a single power name.
        # Tuples are normalized to a list so they are not mistaken for one power name.
        is_single_power = not isinstance(power_names, (list, tuple))
        power_names = [power_names] if is_single_power else list(power_names)

        # Getting beam orders for each requested power
        beam_orders_probs = yield [
            self.get_beam_orders_with_proto(state_proto,
                                            power_name,
                                            phase_history_proto,
                                            possible_orders_proto,
                                            retry_on_failure=retry_on_failure,
                                            **kwargs)
            for power_name in power_names
        ]

        # Returning a single instance, or a list
        return beam_orders_probs[0] if is_single_power else beam_orders_probs
Exemplo n.º 12
0
    def get_state_value(self,
                        game,
                        power_names,
                        *,
                        retry_on_failure=True,
                        **kwargs):
        """ Calculates the player's value of the state of the game for the given power(s)
            :param game: A game object
            :param power_names: A list (or tuple) of power names for which we want the value, or alternatively a
                                single power name.
            :param retry_on_failure: Boolean that indicates to retry querying from the model if an error is encountered.
            :param kwargs: Additional optional kwargs:
                - player_seed: If set. Override the player_seed to use for the model based player.
                - noise: If set. Override the noise to use for the model based player.
                - temperature: If set. Override the temperature to use for the model based player.
                - dropout_rate: If set. Override the dropout_rate to use for the model based player.
            :return: 1) If power_names is a single power name, a single float representing the value of the state
                        for the power
                     2) If power_names is a list/tuple, a list of floats representing the value for each power.
            :type game: diplomacy.Game
        """
        state_proto = extract_state_proto(game)
        phase_history_proto = extract_phase_history_proto(game)
        possible_orders_proto = extract_possible_orders_proto(game)

        # A list or tuple holds several power names; any other value is a single power name.
        # Tuples are normalized to a list so they are not mistaken for one power name.
        is_single_power = not isinstance(power_names, (list, tuple))
        power_names = [power_names] if is_single_power else list(power_names)

        # Getting state value for each requested power
        state_value = yield [
            self.get_state_value_with_proto(state_proto,
                                            power_name,
                                            phase_history_proto,
                                            possible_orders_proto,
                                            retry_on_failure=retry_on_failure,
                                            **kwargs)
            for power_name in power_names
        ]

        # Returning a single instance, or a list
        return state_value[0] if is_single_power else state_value
Exemplo n.º 13
0
    def test_get_draw_prob(self):
        """ Verifies that .get_orders reports a boolean draw action and a draw probability strictly in (0, 1) """
        game = Game()
        state_proto = extract_state_proto(game)
        phase_history_proto = extract_phase_history_proto(game)
        possible_orders_proto = extract_possible_orders_proto(game)
        locs = ['PAR', 'MAR', 'BUR']
        model_kwargs = dict(player_seed=0, noise=0., temperature=1., dropout_rate=0.)
        order_args = (locs, state_proto, 'FRANCE', phase_history_proto, possible_orders_proto)

        # Temperature == 1., queried directly first, then through prefetch -> process -> fetches
        for use_prefetching in (False, True):
            if use_prefetching:
                fetches = yield self.adapter.get_orders(*order_args, prefetch=True, **model_kwargs)
                fetches = yield process_fetches_dict(self.queue_dataset, fetches)
                _, policy_details = yield self.adapter.get_orders(*order_args, fetches=fetches, **model_kwargs)
            else:
                _, policy_details = yield self.adapter.get_orders(*order_args, **model_kwargs)

            assert policy_details['draw_action'] in (True, False)
            assert 0. < policy_details['draw_prob'] < 1.
Exemplo n.º 14
0
    def get_opening_orders(self):
        """ Computes the opening orders this player would submit for every power
            :return: A dictionary {power_name: orders} built from a fresh Game
        """
        game = Game()
        state_proto = extract_state_proto(game)
        phase_history_proto = extract_phase_history_proto(game)
        possible_orders_proto = extract_possible_orders_proto(game)
        power_names = list(game.powers)

        # No kwargs are forwarded, so the default player_seed, noise, temperature, and dropout_rate apply
        pending = [self.get_orders_with_proto(state_proto,
                                              power_name,
                                              phase_history_proto,
                                              possible_orders_proto,
                                              retry_on_failure=False)
                   for power_name in power_names]
        power_orders = yield pending
        return dict(zip(power_names, power_orders))
Exemplo n.º 15
0
def test_sum_of_squares_reward():
    """ Tests the SumOfSquares reward in three scenarios: the starting board, a board where only FRANCE and
        RUSSIA keep their centers, and a board where FRANCE wins.
    """
    game = Game()
    pot_size = 20
    rew_fn = SumOfSquares(pot_size=pot_size)
    prev_state_proto = extract_state_proto(game)
    state_proto = extract_state_proto(game)
    assert rew_fn.name == 'sum_of_squares_reward'
    all_powers = ('AUSTRIA', 'ENGLAND', 'FRANCE', 'GERMANY', 'ITALY', 'RUSSIA', 'TURKEY')

    def check_rewards(prev_proto, curr_proto, is_terminal, done_reason, expected):
        """ Asserts every power's reward (rounded to 8 decimals); powers missing from `expected` must get 0. """
        for power_name in all_powers:
            reward = rew_fn.get_reward(prev_proto,
                                       curr_proto,
                                       power_name,
                                       is_terminal_state=is_terminal,
                                       done_reason=done_reason)
            assert round(reward, 8) == expected.get(power_name, 0.)

    # Starting board: 3 centers per power except RUSSIA with 4 -> sum of squares = 6 * 9 + 16 = 70
    start_rewards = dict.fromkeys(all_powers, round(pot_size * 9 / 70., 8))
    start_rewards['RUSSIA'] = round(pot_size * 16 / 70., 8)
    check_rewards(prev_state_proto, state_proto, False, None, {})
    check_rewards(prev_state_proto, state_proto, True, DoneReason.GAME_ENGINE, start_rewards)
    check_rewards(prev_state_proto, state_proto, True, DoneReason.THRASHED, {})

    # Only FRANCE (3 centers) and RUSSIA (4 centers) keep their units and centers -> 9 + 16 = 25
    prev_state_proto = extract_state_proto(game)
    for power in game.powers.values():
        if power.name not in ('FRANCE', 'RUSSIA'):
            power.clear_units()
            power.centers = []
    state_proto = extract_state_proto(game)
    check_rewards(prev_state_proto, state_proto, True, DoneReason.GAME_ENGINE,
                  {'FRANCE': round(pot_size * 9 / 25., 8), 'RUSSIA': round(pot_size * 16 / 25., 8)})
    check_rewards(prev_state_proto, state_proto, True, DoneReason.THRASHED, {})

    # FRANCE owns every center except ENGLAND's three.
    # Winner: FRANCE (gets the whole pot). Survivors: FRANCE, ENGLAND.
    game = Game()
    prev_state_proto = extract_state_proto(game)
    game.clear_centers()
    game.set_centers('FRANCE', [
        'BUD', 'TRI', 'VIE', 'BRE', 'MAR', 'PAR', 'BER', 'KIE', 'MUN', 'NAP',
        'ROM', 'VEN', 'MOS', 'SEV', 'STP', 'WAR', 'ANK', 'CON', 'SMY'
    ])
    game.set_centers('ENGLAND', ['EDI', 'LON', 'LVP'])
    state_proto = extract_state_proto(game)
    check_rewards(prev_state_proto, state_proto, True, DoneReason.GAME_ENGINE,
                  {'FRANCE': round(pot_size, 8)})
    check_rewards(prev_state_proto, state_proto, True, DoneReason.THRASHED, {})
Exemplo n.º 16
0
def test_plus_minus_one_reward():
    """ Tests for PlusOneMinusOneReward on the starting board and on a board where only FRANCE survives """
    game = Game()
    rew_fn = PlusOneMinusOneReward()
    prev_state_proto = extract_state_proto(game)
    state_proto = extract_state_proto(game)
    assert rew_fn.name == 'plus_one_minus_one_reward'
    all_powers = ('AUSTRIA', 'ENGLAND', 'FRANCE', 'GERMANY', 'ITALY', 'RUSSIA', 'TURKEY')

    def check_rewards(prev_proto, curr_proto, is_terminal, done_reason, expected):
        """ Asserts the reward of every power against `expected` (power name -> reward) """
        for power_name in all_powers:
            reward = rew_fn.get_reward(prev_proto,
                                       curr_proto,
                                       power_name,
                                       is_terminal_state=is_terminal,
                                       done_reason=done_reason)
            assert reward == expected[power_name]

    # Starting board: 0 while playing, +1 for everyone at the end, -1 for everyone when thrashed
    check_rewards(prev_state_proto, state_proto, False, None, dict.fromkeys(all_powers, 0.))
    check_rewards(prev_state_proto, state_proto, True, DoneReason.GAME_ENGINE, dict.fromkeys(all_powers, 1.))
    check_rewards(prev_state_proto, state_proto, True, DoneReason.THRASHED, dict.fromkeys(all_powers, -1.))

    # Only FRANCE keeps its units and centers
    prev_state_proto = extract_state_proto(game)
    for power in game.powers.values():
        if power.name != 'FRANCE':
            power.clear_units()
            power.centers = []
    state_proto = extract_state_proto(game)

    survivor_rewards = dict.fromkeys(all_powers, -1.)
    survivor_rewards['FRANCE'] = 1.
    check_rewards(prev_state_proto, state_proto, True, DoneReason.GAME_ENGINE, survivor_rewards)
    check_rewards(prev_state_proto, state_proto, True, DoneReason.THRASHED, dict.fromkeys(all_powers, -1.))
Exemplo n.º 17
0
def test_custom_int_unit_reward():
    """ Tests for CustomInterimUnitReward

        Plays three consecutive phases. After each one, the reward of every power is checked for
        1) a non-terminal step, 2) a terminal step ended by the game engine, and
        3) a terminal step ended by thrashing (which is always -18 for every power).
    """
    all_powers = ('AUSTRIA', 'ENGLAND', 'FRANCE', 'GERMANY', 'ITALY', 'RUSSIA', 'TURKEY')
    game = Game()
    rew_fn = CustomIntUnitReward()

    def advance(orders_per_power):
        """ Submits the (power_name, orders) pairs in order, processes the phase, returns the new state proto """
        for power_name, power_orders in orders_per_power:
            game.set_orders(power_name, power_orders)
        game.process()
        return extract_state_proto(game)

    def check_rewards(prev_proto, cur_proto, expected):
        """ Asserts the reward of every power for the transition prev_proto -> cur_proto.
            :param expected: The non-zero rewards expected for the non-thrashed cases (missing powers -> 0.)
        """
        def reward(power_name, is_terminal, done_reason):
            return rew_fn.get_reward(prev_proto,
                                     cur_proto,
                                     power_name,
                                     is_terminal_state=is_terminal,
                                     done_reason=done_reason)

        # Non-terminal and engine-terminated games share the same expected rewards
        for is_terminal, done_reason in ((False, None), (True, DoneReason.GAME_ENGINE)):
            for power_name in all_powers:
                assert reward(power_name, is_terminal, done_reason) == expected.get(power_name, 0.)

        # Thrashing penalizes everyone identically
        for power_name in all_powers:
            assert reward(power_name, True, DoneReason.THRASHED) == -18.

    # Phase 1 -- +1 for FRANCE for conquering SPA
    prev_state_proto = extract_state_proto(game)
    state_proto = advance([('FRANCE', ['A MAR - SPA', 'A PAR - PIC']),
                           ('AUSTRIA', ['A VIE - TYR'])])
    assert game.get_current_phase() == 'F1901M'
    check_rewards(prev_state_proto, state_proto, {'FRANCE': 1.})

    # Phase 2
    # FRANCE: +1 for conquering POR, -1 for losing SPA, +1 for conquering BEL
    # AUSTRIA: +1 for conquering VEN
    # ITALY: -1 for losing VEN
    prev_state_proto = state_proto
    state_proto = advance([('FRANCE', ['A PIC - BEL', 'A SPA - POR']),
                           ('AUSTRIA', ['F TRI - VEN', 'A TYR S F TRI - VEN'])])
    check_rewards(prev_state_proto, state_proto, {'AUSTRIA': 1., 'FRANCE': 1., 'ITALY': -1.})

    # Phase 3 -- +0, no new SCs
    prev_state_proto = state_proto
    state_proto = advance([('FRANCE', ['A PIC - BEL', 'A SPA - POR']),
                           ('AUSTRIA', ['F TRI - VEN', 'A TYR S F TRI - VEN'])])
    check_rewards(prev_state_proto, state_proto, {})
Exemplo n.º 18
0
def generate_trajectory(players,
                        reward_fn,
                        advantage_fn,
                        env_constructor=None,
                        hparams=None,
                        power_assignments=None,
                        set_player_seed=None,
                        initial_state_bytes=None,
                        update_interval=0,
                        update_queue=None,
                        output_format='proto'):
    """ Generates a single trajectory (Saved Game Proto) for RL (self-play) with the power assignments
        :param players: A list of instantiated players
        :param reward_fn: The reward function to use to calculate rewards
        :param advantage_fn: An instance of `.models.self_play.advantages`
        :param env_constructor: A callable to get the OpenAI gym environment (args: players)
        :param hparams: A dictionary of hyper parameters with their values
        :param power_assignments: Optional. The power name we want to play as. (e.g. 'FRANCE') or a list of powers.
        :param set_player_seed: Boolean that indicates that we want to set the player seed on reset().
        :param initial_state_bytes: A `game.State` proto (in bytes format) representing the initial state of the game.
        :param update_interval: Optional. If > 0 (and an update_queue is given), a partial saved game is put in the
                                update_queue when the game is completed, or when more than `update_interval`
                                seconds AND at least 5 game years have passed since the last partial update.
        :param update_queue: Optional. If update interval is set, partial games will be put in this queue
                             as (False, nb_transitions, saved_game_bytes) tuples.
        :param output_format: The output format. One of 'proto', 'bytes', 'zlib'
        :return: A SavedGameProto representing the game played (with policy details and power assignments)
                 Depending on format, the output might be converted to a byte array, or a compressed byte array.
        :type players: List[diplomacy_research.players.player.Player]
        :type reward_fn: diplomacy_research.models.self_play.reward_functions.AbstractRewardFunction
        :type advantage_fn: diplomacy_research.models.self_play.advantages.base_advantage.BaseAdvantage
        :type update_queue: multiprocessing.Queue

        NOTE: This is a generator: it yields lists of `get_step_args` / `get_state_value` results and expects the
              resolved values to be sent back (presumably driven by a coroutine decorator defined by the caller).
    """
    # pylint: disable=too-many-arguments
    assert output_format in ['proto', 'bytes', 'zlib'], 'Format should be "proto", "bytes", "zlib"'
    assert len(players) == NB_POWERS

    # Making sure we use the SavedGame wrapper to record the game
    if env_constructor:
        env = env_constructor(players)
    else:
        env = default_env_constructor(players, hparams, power_assignments,
                                      set_player_seed, initial_state_bytes)

    # The `else` clause only runs when the base DiplomacyEnv is reached without finding a SaveGame wrapper
    wrapped_env = env
    while not isinstance(wrapped_env, DiplomacyEnv):
        if isinstance(wrapped_env, SaveGame):
            break
        wrapped_env = wrapped_env.env
    else:
        env = SaveGame(env)

    # Detecting if we have a Auto-Draw wrapper
    has_auto_draw = False
    wrapped_env = env
    while not isinstance(wrapped_env, DiplomacyEnv):
        if isinstance(wrapped_env, AutoDraw):
            has_auto_draw = True
            break
        wrapped_env = wrapped_env.env

    # Resetting env
    env.reset()

    # Timing vars for partial updates
    time_last_update = time.time()
    year_last_update = 0
    start_phase_ix = 0
    current_phase_ix = 0
    nb_transitions = 0

    # Cache Variables
    powers = sorted(get_map_powers(env.game.map))
    assigned_powers = env.get_all_powers_name()
    stored_board_state = OrderedDict()  # {phase_name: board_state}
    stored_prev_orders_state = OrderedDict()  # {phase_name: prev_orders_state}
    stored_possible_orders = OrderedDict()  # {phase_name: possible_orders}

    power_variables = {
        power_name: {
            'orders': [],
            'policy_details': [],
            'state_values': [],
            'rewards': [],
            'returns': [],
            'last_state_value': 0.
        }
        for power_name in powers
    }

    new_state_proto = None
    phase_history_proto = []
    map_object = Map(name=env.game.map.name)

    # Generating
    while not env.is_done:
        # Re-using the state extracted at the end of the previous iteration (if any)
        state_proto = new_state_proto if new_state_proto is not None else extract_state_proto(env.game)
        possible_orders_proto = extract_possible_orders_proto(env.game)

        # Computing board_state
        board_state = proto_to_board_state(state_proto, map_object).flatten().tolist()
        state_proto.board_state.extend(board_state)

        # Storing possible orders for this phase
        current_phase = env.game.get_current_phase()
        stored_board_state[current_phase] = board_state
        stored_possible_orders[current_phase] = possible_orders_proto

        # Getting orders, policy details, and state value
        tasks = [(player, state_proto, pow_name,
                  phase_history_proto[-NB_PREV_ORDERS_HISTORY:],
                  possible_orders_proto)
                 for player, pow_name in zip(env.players, assigned_powers)]
        step_args = yield [get_step_args(*args) for args in tasks]

        # Stepping through env, storing power variables
        for power_name, (orders, policy_details, state_value) in zip(assigned_powers, step_args):
            if orders:
                env.step((power_name, orders))
                nb_transitions += 1
            if has_auto_draw and policy_details is not None:
                env.set_draw_prob(power_name, policy_details['draw_prob'])

        # Processing
        env.process()
        current_phase_ix += 1

        # Retrieving draw action and saving power variables
        for power_name, (orders, policy_details, state_value) in zip(assigned_powers, step_args):
            if has_auto_draw and policy_details is not None:
                policy_details['draw_action'] = env.get_draw_actions()[power_name]
            power_variables[power_name]['orders'] += [orders]
            power_variables[power_name]['policy_details'] += [policy_details]
            power_variables[power_name]['state_values'] += [state_value]

        # Getting new state
        new_state_proto = extract_state_proto(env.game)

        # Storing reward for this transition
        done_reason = DoneReason(env.done_reason) if env.done_reason else None
        for power_name in powers:
            power_variables[power_name]['rewards'] += [
                reward_fn.get_reward(prev_state_proto=state_proto,
                                     state_proto=new_state_proto,
                                     power_name=power_name,
                                     is_terminal_state=done_reason is not None,
                                     done_reason=done_reason)
            ]

        # Computing prev_orders_state for the previous state (only movement phases are kept in the history)
        last_phase_proto = extract_phase_history_proto(env.game, nb_previous_phases=1)[-1]
        if last_phase_proto.name[-1] == 'M':
            prev_orders_state = proto_to_prev_orders_state(last_phase_proto, map_object).flatten().tolist()
            stored_prev_orders_state[last_phase_proto.name] = prev_orders_state
            last_phase_proto.prev_orders_state.extend(prev_orders_state)
            phase_history_proto += [last_phase_proto]

        # Sending partial game if:
        # 1) We have update_interval > 0 with an update queue, and
        # 2a) The game is completed, or 2b) the update time has elapsed and at least 5 years have passed
        has_update_interval = update_interval > 0 and update_queue is not None
        game_is_completed = env.is_done
        min_time_has_passed = time.time() - time_last_update > update_interval
        new_phase_name = env.game.get_current_phase()
        current_year = 9999 if new_phase_name == 'COMPLETED' else int(new_phase_name[1:5])
        min_years_have_passed = current_year - year_last_update >= 5

        if has_update_interval and (game_is_completed or (min_time_has_passed and min_years_have_passed)):

            # Game is completed - last state value is 0
            if game_is_completed:
                for power_name in powers:
                    power_variables[power_name]['last_state_value'] = 0.

            # Otherwise - Querying the model for the value of the last state
            else:
                # The possible orders must describe the new (not yet played) phase, matching new_state_proto,
                # rather than the phase that was just processed
                new_possible_orders_proto = extract_possible_orders_proto(env.game)
                tasks = [(player, new_state_proto, pow_name,
                          phase_history_proto[-NB_PREV_ORDERS_HISTORY:],
                          new_possible_orders_proto)
                         for player, pow_name in zip(env.players, assigned_powers)]
                last_state_values = yield [get_state_value(*args) for args in tasks]

                for power_name, last_state_value in zip(assigned_powers, last_state_values):
                    power_variables[power_name]['last_state_value'] = last_state_value

            # Getting partial game and sending it on the update_queue
            saved_game_proto = get_saved_game_proto(
                env=env,
                players=players,
                stored_board_state=stored_board_state,
                stored_prev_orders_state=stored_prev_orders_state,
                stored_possible_orders=stored_possible_orders,
                power_variables=power_variables,
                start_phase_ix=start_phase_ix,
                reward_fn=reward_fn,
                advantage_fn=advantage_fn,
                is_partial_game=True)
            update_queue.put_nowait((False, nb_transitions, proto_to_bytes(saved_game_proto)))

            # Updating stats
            # Resetting the timer; otherwise, once the first interval has elapsed, partial updates would be
            # sent every 5 years regardless of update_interval
            time_last_update = time.time()
            start_phase_ix = current_phase_ix
            nb_transitions = 0
            if not env.is_done:
                year_last_update = int(env.game.get_current_phase()[1:5])

    # Since the environment is done (Completed game) - We can leave the last_state_value at 0.
    for power_name in powers:
        power_variables[power_name]['last_state_value'] = 0.

    # Getting completed game
    saved_game_proto = get_saved_game_proto(
        env=env,
        players=players,
        stored_board_state=stored_board_state,
        stored_prev_orders_state=stored_prev_orders_state,
        stored_possible_orders=stored_possible_orders,
        power_variables=power_variables,
        start_phase_ix=0,
        reward_fn=reward_fn,
        advantage_fn=advantage_fn,
        is_partial_game=False)

    # Converting to correct format
    output = {
        'proto': lambda proto: proto,
        'zlib': proto_to_zlib,
        'bytes': proto_to_bytes
    }[output_format](saved_game_proto)

    # Returning
    return output
Exemplo n.º 19
0
    def get_orders(self, game, power_names):
        """
        Computes the orders of each requested power by taking the most likely order from the model (argmax).
        See diplomacy_research.players.player.Player.get_orders
        :param game: Game object
        :param power_names: A list of power names we are playing, or alternatively a single power name
                            (a single name is treated as a one-element list).
        :return: A tuple (orders, order_probs) where, for the i-th requested power, orders[i] is the list of
                 decoded orders (looked up in INVERSE_ORDER_DICT) and order_probs[i] the matching probabilities.
        """
        # A bare string would otherwise be iterated character by character
        if isinstance(power_names, str):
            power_names = [power_names]

        # num_dummies/tiling is hacky way to get around TF Strided Slice error
        # that occurs when only passing in one state (e.g. batch size of 1)
        num_dummies = 2

        # Encoding the orders of the most recent movement phase among the last 3 phases.
        # Falls back to an all-zero encoding when there is no history or no movement phase in it.
        order_history = extract_phase_history_proto(game, 3)
        prev_movement_phase = next(
            (phase for phase in reversed(order_history) if phase.name[-1] == 'M'), None)
        if prev_movement_phase is None:
            prev_orders_state = tf.zeros((1, 81, 40), dtype=tf.float32)
        else:
            prev_orders_state = proto_to_prev_orders_state(prev_movement_phase, game.map).flatten().tolist()
            prev_orders_state = tf.reshape(prev_orders_state, (1, 81, 40))
        prev_orders_state_with_dummies = tf.tile(prev_orders_state, [num_dummies, 1, 1])

        # Encoding the current board as (1, 81, 35), then tiling it like the previous orders
        board_state = dict_to_flatten_board_state(game.get_state(), game.map)
        board_state = tf.reshape(board_state, (1, 81, 35))
        board_state_with_dummies = tf.tile(board_state, [num_dummies, 1, 1])

        season = get_current_season(extract_state_proto(game))
        state = game.get_state()
        # NOTE(review): state["name"] is presumably the phase name (e.g. 'S1901M'), not only the year -- confirm
        year = state["name"]
        board_dict = parse_rl_state(state)

        orders = []
        order_probs = []
        for power in power_names:
            power_season = tf.concat([UNIT_POWER[power], INT_SEASON[season]], axis=0)
            power_season = tf.expand_dims(power_season, axis=0)
            power_season_with_dummies = tf.tile(power_season, [num_dummies, 1])
            probs, _ = self.call(
                board_state_with_dummies, prev_orders_state_with_dummies,
                power_season_with_dummies, [year] * num_dummies,
                [board_dict] * num_dummies, power)

            # Greedy decoding: argmax along axis 1 of the sliced probability tensor
            prob_no_dummies = tf.squeeze(probs, axis=1)[:, 0, :]
            order_ix = tf.argmax(prob_no_dummies, axis=1).numpy()
            orders.append([INVERSE_ORDER_DICT[index] for index in order_ix])
            order_probs.append([prob_no_dummies[i][index] for i, index in enumerate(order_ix)])
        return orders, order_probs
Exemplo n.º 20
0
def test_survivor_win_reward():
    """ Test survivor win reward function

        Checks SurvivorWinReward (pot_size=20) on three boards:
        1) the initial board, where every power gets an equal share of the pot (pot_size / 7),
        2) a board where only FRANCE and RUSSIA keep units and centers, where each gets half the pot,
        3) a board where FRANCE controls 19 centers and ENGLAND 3, where the pot is split 18/21 and 3/21.
        In every case, a game ended by thrashing yields 0 for every power.
    """
    game = Game()
    pot_size = 20
    rew_fn = SurvivorWinReward(pot_size=pot_size)
    prev_state_proto = extract_state_proto(game)
    state_proto = extract_state_proto(game)
    assert rew_fn.name == 'survivor_win_reward'
    # Rewards are rounded to 8 decimals so they can be compared exactly with the (rounded) expected fractions
    get_reward = lambda power_name, is_terminal, done_reason: round(
        rew_fn.get_reward(prev_state_proto,
                          state_proto,
                          power_name,
                          is_terminal_state=is_terminal,
                          done_reason=done_reason), 8)

    # --- Not in terminal state
    assert get_reward('AUSTRIA', False, None) == 0.
    assert get_reward('ENGLAND', False, None) == 0.
    assert get_reward('FRANCE', False, None) == 0.
    assert get_reward('GERMANY', False, None) == 0.
    assert get_reward('ITALY', False, None) == 0.
    assert get_reward('RUSSIA', False, None) == 0.
    assert get_reward('TURKEY', False, None) == 0.

    # --- In terminal state
    # On the initial board, the pot is split equally between the 7 powers
    assert get_reward('AUSTRIA', True,
                      DoneReason.GAME_ENGINE) == round(pot_size / 7., 8)
    assert get_reward('ENGLAND', True,
                      DoneReason.GAME_ENGINE) == round(pot_size / 7., 8)
    assert get_reward('FRANCE', True,
                      DoneReason.GAME_ENGINE) == round(pot_size / 7., 8)
    assert get_reward('GERMANY', True,
                      DoneReason.GAME_ENGINE) == round(pot_size / 7., 8)
    assert get_reward('ITALY', True,
                      DoneReason.GAME_ENGINE) == round(pot_size / 7., 8)
    assert get_reward('RUSSIA', True,
                      DoneReason.GAME_ENGINE) == round(pot_size / 7., 8)
    assert get_reward('TURKEY', True,
                      DoneReason.GAME_ENGINE) == round(pot_size / 7., 8)

    # --- Thrashing
    assert get_reward('AUSTRIA', True, DoneReason.THRASHED) == 0.
    assert get_reward('ENGLAND', True, DoneReason.THRASHED) == 0.
    assert get_reward('FRANCE', True, DoneReason.THRASHED) == 0.
    assert get_reward('GERMANY', True, DoneReason.THRASHED) == 0.
    assert get_reward('ITALY', True, DoneReason.THRASHED) == 0.
    assert get_reward('RUSSIA', True, DoneReason.THRASHED) == 0.
    assert get_reward('TURKEY', True, DoneReason.THRASHED) == 0.

    # --- Clearing supply centers
    # Every power except FRANCE and RUSSIA loses all its units and centers
    prev_state_proto = extract_state_proto(game)
    for power in game.powers.values():
        if power.name != 'FRANCE' and power.name != 'RUSSIA':
            power.clear_units()
            power.centers = []
    state_proto = extract_state_proto(game)
    # Re-binding the helper to the new (prev_state_proto, state_proto) pair
    get_reward = lambda power_name, is_terminal, done_reason: round(
        rew_fn.get_reward(prev_state_proto,
                          state_proto,
                          power_name,
                          is_terminal_state=is_terminal,
                          done_reason=done_reason), 8)

    # --- In terminal state
    # Only the two survivors (FRANCE, RUSSIA) are rewarded, each with half of the pot
    assert get_reward('AUSTRIA', True, DoneReason.GAME_ENGINE) == 0.
    assert get_reward('ENGLAND', True, DoneReason.GAME_ENGINE) == 0.
    assert get_reward('FRANCE', True,
                      DoneReason.GAME_ENGINE) == round(pot_size / 2., 8)
    assert get_reward('GERMANY', True, DoneReason.GAME_ENGINE) == 0.
    assert get_reward('ITALY', True, DoneReason.GAME_ENGINE) == 0.
    assert get_reward('RUSSIA', True,
                      DoneReason.GAME_ENGINE) == round(pot_size / 2., 8)
    assert get_reward('TURKEY', True, DoneReason.GAME_ENGINE) == 0.

    # --- Thrashing
    assert get_reward('AUSTRIA', True, DoneReason.THRASHED) == 0.
    assert get_reward('ENGLAND', True, DoneReason.THRASHED) == 0.
    assert get_reward('FRANCE', True, DoneReason.THRASHED) == 0.
    assert get_reward('GERMANY', True, DoneReason.THRASHED) == 0.
    assert get_reward('ITALY', True, DoneReason.THRASHED) == 0.
    assert get_reward('RUSSIA', True, DoneReason.THRASHED) == 0.
    assert get_reward('TURKEY', True, DoneReason.THRASHED) == 0.

    # Move centers in other countries to FRANCE except ENGLAND
    # Winner: FRANCE
    # Survivor: FRANCE, ENGLAND
    game = Game()
    prev_state_proto = extract_state_proto(game)
    game.clear_centers()
    game.set_centers('FRANCE', [
        'BUD', 'TRI', 'VIE', 'BRE', 'MAR', 'PAR', 'BER', 'KIE', 'MUN', 'NAP',
        'ROM', 'VEN', 'MOS', 'SEV', 'STP', 'WAR', 'ANK', 'CON', 'SMY'
    ])
    game.set_centers('ENGLAND', ['EDI', 'LON', 'LVP'])
    state_proto = extract_state_proto(game)
    # Re-binding the helper to the new (prev_state_proto, state_proto) pair
    get_reward = lambda power_name, is_terminal, done_reason: round(
        rew_fn.get_reward(prev_state_proto,
                          state_proto,
                          power_name,
                          is_terminal_state=is_terminal,
                          done_reason=done_reason), 8)

    # France has 19 SC, 18 to win, 1 excess
    # Nb of controlled SC is 19 + 3 - 1 excess = 21
    # Reward for FRANCE is 18 / 21 * pot
    # Reward for ENGLAND is 3 / 21 * pot

    # --- In terminal state -- Victory
    assert get_reward('AUSTRIA', True, DoneReason.GAME_ENGINE) == 0.
    assert get_reward('ENGLAND', True,
                      DoneReason.GAME_ENGINE) == round(pot_size * 3. / 21, 8)
    assert get_reward('FRANCE', True,
                      DoneReason.GAME_ENGINE) == round(pot_size * 18. / 21, 8)
    assert get_reward('GERMANY', True, DoneReason.GAME_ENGINE) == 0.
    assert get_reward('ITALY', True, DoneReason.GAME_ENGINE) == 0.
    assert get_reward('RUSSIA', True, DoneReason.GAME_ENGINE) == 0.
    assert get_reward('TURKEY', True, DoneReason.GAME_ENGINE) == 0.

    # --- Thrashing
    assert get_reward('AUSTRIA', True, DoneReason.THRASHED) == 0.
    assert get_reward('ENGLAND', True, DoneReason.THRASHED) == 0.
    assert get_reward('FRANCE', True, DoneReason.THRASHED) == 0.
    assert get_reward('GERMANY', True, DoneReason.THRASHED) == 0.
    assert get_reward('ITALY', True, DoneReason.THRASHED) == 0.
    assert get_reward('RUSSIA', True, DoneReason.THRASHED) == 0.
    assert get_reward('TURKEY', True, DoneReason.THRASHED) == 0.