def test_replay_buffer_max_capacity(self):
  """Adding more entries than the capacity leaves only the later ones.

  A capacity-2 buffer receives three entries; afterwards it must hold exactly
  two items, namely "entry2" and "entry3".
  """
  buffer = dqn.ReplayBuffer(replay_buffer_capacity=2)
  for entry in ("entry1", "entry2", "entry3"):
    buffer.add(entry)
  self.assertEqual(len(buffer), 2)
  for survivor in ("entry2", "entry3"):
    self.assertIn(survivor, buffer)
def test_replay_buffer_add(self):
  """Each add grows the buffer by one and makes the entry a member.

  Starts from an empty capacity-10 buffer, checks the length after every
  insertion, then checks that both inserted entries are contained.
  """
  buffer = dqn.ReplayBuffer(replay_buffer_capacity=10)
  self.assertEqual(len(buffer), 0)
  # Length is verified right after each add, in insertion order.
  for expected_len, entry in enumerate(("entry1", "entry2"), start=1):
    buffer.add(entry)
    self.assertEqual(len(buffer), expected_len)
  for entry in ("entry1", "entry2"):
    self.assertIn(entry, buffer)
def test_replay_buffer_sample(self):
  """Sampling as many items as a full buffer holds yields every entry.

  Fills a capacity-3 buffer with three entries, samples three, and checks
  that each stored entry appears in the sample.
  """
  entries = ("entry1", "entry2", "entry3")
  buffer = dqn.ReplayBuffer(replay_buffer_capacity=len(entries))
  for entry in entries:
    buffer.add(entry)
  samples = buffer.sample(len(entries))
  for entry in entries:
    self.assertIn(entry, samples)
def __init__(self,
             game,
             bot,
             model,
             replay_buffer_capacity=int(1e6),
             action_selection_transition=30):
  """AlphaZero constructor.

  Args:
    game: a pyspiel.Game object
    bot: an MCTSBot object.
    model: A Model.
    replay_buffer_capacity: the size of the replay buffer in which the results
      of self-play games are stored.
    action_selection_transition: an integer representing the move number in a
      game of self-play when greedy action selection is used. Before this,
      actions are sampled from the MCTS policy.

  Raises:
    ValueError: if incorrect inputs are supplied, i.e. the game is not a
      2-player, deterministic, perfect-information, sequential (turn-based),
      zero-sum game.
  """
  # The chance mode / information / dynamics / utility properties are fields
  # of the GameType returned by get_type(), not of the Game object itself.
  # Both the checks and the error messages must therefore read them from
  # `game_info`; formatting `game.<field>` would fail while building the
  # message instead of raising the intended ValueError.
  game_info = game.get_type()
  if game.num_players() != 2:
    raise ValueError("Game must be a 2-player game")
  if game_info.chance_mode != pyspiel.GameType.ChanceMode.DETERMINISTIC:
    raise ValueError(
        "The game must be a Deterministic one, not {}".format(
            game_info.chance_mode))
  if (game_info.information !=
      pyspiel.GameType.Information.PERFECT_INFORMATION):
    raise ValueError(
        "The game must be a perfect information one, not {}".format(
            game_info.information))
  if game_info.dynamics != pyspiel.GameType.Dynamics.SEQUENTIAL:
    raise ValueError("The game must be turn-based, not {}".format(
        game_info.dynamics))
  if game_info.utility != pyspiel.GameType.Utility.ZERO_SUM:
    raise ValueError("The game must be 0-sum, not {}".format(
        game_info.utility))
  # (A second, unreachable num_players() check was removed; the first check
  # above already guarantees exactly two players.)

  self.game = game
  self.bot = bot
  self.model = model
  self.replay_buffer = dqn.ReplayBuffer(replay_buffer_capacity)
  self.action_selection_transition = action_selection_transition