Esempio n. 1
0
  def test_replay_buffer_max_capacity(self):
    """A buffer of capacity 2 holds only the two newest of three entries."""
    buffer = dqn.ReplayBuffer(replay_buffer_capacity=2)
    for entry in ("entry1", "entry2", "entry3"):
      buffer.add(entry)
    self.assertEqual(len(buffer), 2)

    # Only the two most recently added entries are expected to remain.
    for survivor in ("entry2", "entry3"):
      self.assertIn(survivor, buffer)
Esempio n. 2
0
  def test_replay_buffer_add(self):
    """Each add grows the buffer by one and the added entries are retrievable."""
    buffer = dqn.ReplayBuffer(replay_buffer_capacity=10)
    self.assertEqual(len(buffer), 0)

    # The length is checked after every insertion, starting from 1.
    for expected_len, entry in enumerate(("entry1", "entry2"), start=1):
      buffer.add(entry)
      self.assertEqual(len(buffer), expected_len)

    for entry in ("entry1", "entry2"):
      self.assertIn(entry, buffer)
Esempio n. 3
0
  def test_replay_buffer_sample(self):
    """Sampling as many items as the buffer holds returns every stored entry."""
    entries = ("entry1", "entry2", "entry3")
    buffer = dqn.ReplayBuffer(replay_buffer_capacity=len(entries))
    for entry in entries:
      buffer.add(entry)

    samples = buffer.sample(len(entries))

    for entry in entries:
      self.assertIn(entry, samples)
Esempio n. 4
0
    def __init__(self,
                 game,
                 bot,
                 model,
                 replay_buffer_capacity=int(1e6),
                 action_selection_transition=30):
        """AlphaZero constructor.

        Checks that `game` is a 2-player, deterministic, perfect-information,
        sequential, zero-sum game, then stores the components and creates the
        replay buffer.

        Args:
          game: a pyspiel.Game object
          bot: an MCTSBot object.
          model: A Model.
          replay_buffer_capacity: the size of the replay buffer in which the
            results of self-play games are stored.
          action_selection_transition: an integer representing the move number
            in a game of self-play when greedy action selection is used. Before
            this, actions are sampled from the MCTS policy.

        Raises:
          ValueError: if `game` does not have exactly 2 players, or is not
            deterministic, perfect-information, sequential and zero-sum.
        """

        game_info = game.get_type()
        if game.num_players() != 2:
            raise ValueError("Game must be a 2-player game")
        # Each property below is read from `game_info` (the object returned by
        # get_type()); the error messages report the value from that same
        # object rather than from `game`.
        if game_info.chance_mode != pyspiel.GameType.ChanceMode.DETERMINISTIC:
            raise ValueError(
                "The game must be a Deterministic one, not {}".format(
                    game_info.chance_mode))
        if (game_info.information !=
                pyspiel.GameType.Information.PERFECT_INFORMATION):
            raise ValueError(
                "The game must be a perfect information one, not {}".format(
                    game_info.information))
        if game_info.dynamics != pyspiel.GameType.Dynamics.SEQUENTIAL:
            raise ValueError("The game must be turn-based, not {}".format(
                game_info.dynamics))
        if game_info.utility != pyspiel.GameType.Utility.ZERO_SUM:
            raise ValueError("The game must be 0-sum, not {}".format(
                game_info.utility))

        self.game = game
        self.bot = bot
        self.model = model
        # Holds the results of self-play games, up to replay_buffer_capacity.
        self.replay_buffer = dqn.ReplayBuffer(replay_buffer_capacity)
        # Move number at which self-play switches from sampling the MCTS
        # policy to greedy action selection.
        self.action_selection_transition = action_selection_transition