Example #1
0
 def test_invalid_aoh_construction_raises_error(self):
     """Malformed action-observation histories must raise RuntimeError.

     Each candidate history below is passed, for player 0, to the
     ActionObservationHistory constructor, which is expected to reject it.
     """
     malformed_histories = (
         ["not tuple"],
         [("tuple", "too", "long")],
         [("not int", "obs")],
         [(1, 1)],  # obs not string
     )
     for history in malformed_histories:
         with self.assertRaises(RuntimeError):
             pyspiel.ActionObservationHistory(0, history)
Example #2
0
    def test_observation_consistency(self):
        """Check consistency between the observation-like strings of a game.

        Builds one `(player, state) -> str` callback per representation the
        game provides (always including the move number), collects the
        expected relations between pairs of representations that are both
        available, and hands both to `self._check_partition_consistency`.
        Non-sequential games and the games listed in `broken_games` are
        skipped.
        """
        if self.game_type.dynamics != pyspiel.GameType.Dynamics.SEQUENTIAL:
            return  # TODO(author13): support also other types of games.

        # TODO(author13): Following games need to be fixed -- they currently
        # do not pass the observation consistency test.
        broken_games = (
            "first_sealed_auction", "hearts", "kuhn_poker", "leduc_poker",
            "lewis_signaling", "liars_dice", "pentago", "phantom_ttt",
            "tiny_bridge_2p", "tiny_bridge_4p", "tiny_hanabi",
            "universal_poker",
        )
        if self.game_name in broken_games:
            return

        callbacks = {
            "move_number": lambda _, state: str(state.move_number()),
        }
        # pylint: disable=g-long-lambda
        if self.game_type.provides_information_state_string:
            callbacks["information_state_string"] = (
                lambda player, state: state.information_state_string(player))
        if self.game_type.provides_observation_string:
            callbacks["action_observation_history"] = (
                lambda player, state: str(
                    pyspiel.ActionObservationHistory(player, state)))
        if self.game_type.provides_factored_observation_string:
            callbacks["public_observation_history"] = (
                lambda player, state: str(
                    pyspiel.PublicObservationHistory(state)))

        # TODO(author13): Add testing of tensor variants.

        candidate_relations = [
            # All observations should be subsets of move_number.
            ("information_state_string", Relation.SUBSET_OR_EQUALS,
             "move_number"),
            ("public_observation_history", Relation.SUBSET_OR_EQUALS,
             "move_number"),
            ("action_observation_history", Relation.SUBSET_OR_EQUALS,
             "move_number"),
            # Other relations:
            ("information_state_string", Relation.EQUALS,
             "action_observation_history"),
            ("information_state_string", Relation.SUBSET_OR_EQUALS,
             "public_observation_history"),
            ("action_observation_history", Relation.SUBSET_OR_EQUALS,
             "public_observation_history"),
        ]
        # A relation can only be checked when the game provides both sides.
        relations = [
            (lhs, relation, rhs)
            for lhs, relation, rhs in candidate_relations
            if lhs in callbacks and rhs in callbacks
        ]

        self._check_partition_consistency(callbacks, relations)
Example #3
0
        def advance_and_test(state):
            nonlocal seq_idx
            action = action_sequence[seq_idx]

            parent_state = state.clone()
            state.apply_action(action)

            # Both PublicObservationHistory and ActionObservationHistory need to pass
            # the same correspondence / prefix / extension tests, for both
            # construction from state (and player for AOH) or just passing state (and
            # player) without history construction. So this tests them together.
            # These are passed as constructors (lefts) and targets (rights).
            constructors_with_targets = [
                # (p)layer, (s)tate
                (pyspiel.ActionObservationHistory, [
                    lambda p, s: [pyspiel.ActionObservationHistory(p, s)],
                    lambda p, s: [p, s]
                ]),
                # Since PublicObservationHistory does not need a player it is ignored.
                (lambda _, s: pyspiel.PublicObservationHistory(s), [
                    lambda _, s: [pyspiel.PublicObservationHistory(s)],
                    lambda _, s: [s]
                ])
            ]

            for (left, rights) in constructors_with_targets:
                for right in rights:
                    for player in range(2):
                        # Shortcuts for conciseness: the most important things
                        # relevant for the tests have long names.
                        # In other words, when reading the code just skip over
                        # the short vars to gain understanding.
                        l, r, p = left, right, player

                        self.assertTrue(
                            l(p, state).corresponds_to(*r(p, state)))
                        self.assertFalse(
                            l(p, parent_state).corresponds_to(*r(p, state)))
                        if state.is_terminal():
                            self.assertTrue(
                                l(p, terminal).corresponds_to(*r(p, state)))
                        else:
                            self.assertFalse(
                                l(p, terminal).corresponds_to(*r(p, state)))
                        if state.is_initial_state():
                            self.assertTrue(
                                l(p, root_state).corresponds_to(*r(p, state)))
                        else:
                            self.assertFalse(
                                l(p, root_state).corresponds_to(*r(p, state)))

                        self.assertTrue(
                            l(p, parent_state).is_prefix_of(*r(p, state)))
                        self.assertFalse(
                            l(p, state).is_prefix_of(*r(p, parent_state)))
                        self.assertTrue(
                            l(p, root_state).is_prefix_of(*r(p, state)))
                        if state.is_terminal():
                            self.assertTrue(
                                l(p, terminal).is_prefix_of(*r(p, state)))
                        else:
                            self.assertFalse(
                                l(p, terminal).is_prefix_of(*r(p, state)))

                        self.assertFalse(
                            l(p, parent_state).is_extension_of(*r(p, state)))
                        self.assertTrue(
                            l(p, state).is_extension_of(*r(p, parent_state)))
                        self.assertTrue(
                            l(p, terminal).is_extension_of(*r(p, state)))
                        if state.is_initial_state():
                            self.assertTrue(
                                l(p, root_state).is_extension_of(*r(p, state)))
                        else:
                            self.assertFalse(
                                l(p, root_state).is_extension_of(*r(p, state)))

            seq_idx += 1
            return state
Example #4
0
    def test_kuhn_rollout_history_relations(self):
        """Roll out a fixed Kuhn poker deal and check both kinds of history.

        After every action this checks (a) the exact public-observation
        history and both players' action-observation histories, and (b) the
        `corresponds_to` / `is_prefix_of` / `is_extension_of` relations
        between histories of the current state, its parent, the root state
        and the terminal state of the rollout.

        This method is deliberately not named `test_kuhn_rollout`: a later
        method with that name in the same class would rebind the name, and
        this test would silently never run.
        """
        game = pyspiel.load_game("kuhn_poker")
        num_players = game.num_players()

        # Test on this specific sequence.
        action_sequence = [2, 1, 0, 1, 1]

        root_state = game.new_initial_state()
        terminal = state_from_sequence(game, action_sequence)
        self.assertTrue(terminal.is_terminal())

        # Both PublicObservationHistory and ActionObservationHistory must pass
        # the same correspondence / prefix / extension tests, both when the
        # compared-to history is constructed and when the (player and) state
        # are passed directly. Each entry pairs a history constructor with the
        # builders of the arguments it is compared against; every callable
        # takes (p)layer, (s)tate. Built once: it does not depend on the step.
        constructors_with_targets = [
            (pyspiel.ActionObservationHistory, [
                lambda p, s: [pyspiel.ActionObservationHistory(p, s)],
                lambda p, s: [p, s],
            ]),
            # PublicObservationHistory does not need a player; it is ignored.
            (lambda _, s: pyspiel.PublicObservationHistory(s), [
                lambda _, s: [pyspiel.PublicObservationHistory(s)],
                lambda _, s: [s],
            ]),
        ]

        def check(expected, value):
            """assertTrue(value) if `expected` holds, else assertFalse(value)."""
            (self.assertTrue if expected else self.assertFalse)(value)

        remaining_actions = iter(action_sequence)

        def advance_and_test(state):
            """Apply the next scripted action to `state`; check relations."""
            parent_state = state.clone()
            state.apply_action(next(remaining_actions))
            # Histories of the root / terminal state correspond to those of
            # `state` only when `state` is the root / terminal itself.
            at_root = state.is_initial_state()
            at_terminal = state.is_terminal()

            for hist, arg_builders in constructors_with_targets:
                for args in arg_builders:
                    for p in range(num_players):
                        # Correspondence.
                        self.assertTrue(
                            hist(p, state).corresponds_to(*args(p, state)))
                        self.assertFalse(
                            hist(p, parent_state).corresponds_to(
                                *args(p, state)))
                        check(at_terminal,
                              hist(p, terminal).corresponds_to(
                                  *args(p, state)))
                        check(at_root,
                              hist(p, root_state).corresponds_to(
                                  *args(p, state)))

                        # Prefixes.
                        self.assertTrue(
                            hist(p, parent_state).is_prefix_of(
                                *args(p, state)))
                        self.assertFalse(
                            hist(p, state).is_prefix_of(
                                *args(p, parent_state)))
                        self.assertTrue(
                            hist(p, root_state).is_prefix_of(
                                *args(p, state)))
                        check(at_terminal,
                              hist(p, terminal).is_prefix_of(
                                  *args(p, state)))

                        # Extensions.
                        self.assertFalse(
                            hist(p, parent_state).is_extension_of(
                                *args(p, state)))
                        self.assertTrue(
                            hist(p, state).is_extension_of(
                                *args(p, parent_state)))
                        self.assertTrue(
                            hist(p, terminal).is_extension_of(
                                *args(p, state)))
                        check(at_root,
                              hist(p, root_state).is_extension_of(
                                  *args(p, state)))

        def assert_histories(state, public, aoh_0, aoh_1):
            """Check the exact public history and the AOHs of players 0, 1."""
            self.assertEqual(pyspiel.PublicObservationHistory(state),
                             pyspiel.PublicObservationHistory(public))
            self.assertEqual(pyspiel.ActionObservationHistory(0, state),
                             pyspiel.ActionObservationHistory(0, aoh_0))
            self.assertEqual(pyspiel.ActionObservationHistory(1, state),
                             pyspiel.ActionObservationHistory(1, aoh_1))

        start = pyspiel.PublicObservation.START_GAME
        nothing = pyspiel.PrivateObservation.NOTHING

        state = game.new_initial_state()
        self.assertTrue(state.is_chance_node())
        assert_histories(state, [start], [(None, nothing)], [(None, nothing)])

        advance_and_test(state)
        self.assertTrue(state.is_chance_node())
        assert_histories(
            state, [start, "Deal to player 0"],
            [(None, nothing), (None, "211")],
            [(None, nothing), (None, nothing)])

        advance_and_test(state)
        self.assertTrue(state.is_player_node())
        assert_histories(
            state, [start, "Deal to player 0", "Deal to player 1"],
            [(None, nothing), (None, "211"), (None, "211")],
            [(None, nothing), (None, nothing), (None, "111")])

        advance_and_test(state)
        self.assertTrue(state.is_player_node())
        assert_histories(
            state, [start, "Deal to player 0", "Deal to player 1", "Pass"],
            [(None, nothing), (None, "211"), (None, "211"), (0, "211")],
            [(None, nothing), (None, nothing), (None, "111"), (None, "111")])

        advance_and_test(state)
        self.assertTrue(state.is_player_node())
        assert_histories(
            state,
            [start, "Deal to player 0", "Deal to player 1", "Pass", "Bet"],
            [(None, nothing), (None, "211"), (None, "211"), (0, "211"),
             (None, "212")],
            [(None, nothing), (None, nothing), (None, "111"), (None, "111"),
             (1, "112")])

        advance_and_test(state)
        self.assertTrue(state.is_terminal())
        assert_histories(
            state,
            [start, "Deal to player 0", "Deal to player 1", "Pass", "Bet",
             "Bet"],
            [(None, nothing), (None, "211"), (None, "211"), (0, "211"),
             (None, "212"), (1, "222")],
            [(None, nothing), (None, nothing), (None, "111"), (None, "111"),
             (1, "112"), (None, "122")])
    def test_kuhn_rollout(self):
        """Roll out a fixed Kuhn poker deal and check every history exactly.

        Each applied action appends exactly one entry to the public-observation
        history and one entry to each player's action-observation history.
        AOH entries use the `(action, observation)` tuple form, with `None` as
        the action when the player did not act; a flat list of bare actions
        and observations is not a valid ActionObservationHistory input.
        """
        game = pyspiel.load_game("kuhn_poker")
        nothing = pyspiel.PrivateObservation.NOTHING

        # Per step: the action applied, the state predicate that must hold
        # afterwards, the new public observation, and the new AOH entries of
        # players 0 and 1 (in player order).
        steps = [
            (2, "is_chance_node", "Deal to player 0",
             [(None, "211"), (None, nothing)]),
            (1, "is_player_node", "Deal to player 1",
             [(None, "211"), (None, "111")]),
            (0, "is_player_node", "Pass",
             [(0, "211"), (None, "111")]),
            (1, "is_player_node", "Bet",
             [(None, "212"), (1, "112")]),
            (1, "is_terminal", "Bet",
             [(1, "222"), (None, "122")]),
        ]

        def assert_histories(state, public, aohs):
            """Check the exact public history and every listed player's AOH."""
            self.assertEqual(pyspiel.PublicObservationHistory(state),
                             pyspiel.PublicObservationHistory(public))
            for player, aoh in enumerate(aohs):
                self.assertEqual(
                    pyspiel.ActionObservationHistory(player, state),
                    pyspiel.ActionObservationHistory(player, aoh))

        state = game.new_initial_state()
        self.assertTrue(state.is_chance_node())
        public = [pyspiel.PublicObservation.START_GAME]
        aohs = [[(None, nothing)], [(None, nothing)]]
        assert_histories(state, public, aohs)

        for step, (action, predicate, public_obs, entries) in enumerate(steps):
            # subTest makes a failure report which step of the rollout broke.
            with self.subTest(step=step, action=action):
                state.apply_action(action)
                self.assertTrue(getattr(state, predicate)())
                public.append(public_obs)
                for aoh, entry in zip(aohs, entries):
                    aoh.append(entry)
                assert_histories(state, public, aohs)