def __init__(
    self,
    observation_shape,
    stack_size,
    replay_capacity,
    batch_size,
    update_horizon=1,
    gamma=0.99,
    max_sample_attempts=1000,
    extra_storage_types=None,
    observation_dtype=np.uint8,
    terminal_dtype=np.uint8,
    action_shape=(),
    action_dtype=np.int32,
    reward_shape=(),
    reward_dtype=np.float32,
):
    """Initializes PrioritizedReplayBuffer.

    All arguments are forwarded unchanged to the parent replay buffer's
    constructor; afterwards a `SumTree` sized to `replay_capacity` is created.

    Args:
      observation_shape: tuple of ints.
      stack_size: int, number of frames to use in state stack.
      replay_capacity: int, number of transitions to keep in memory.
      batch_size: int.
      update_horizon: int, length of update ('n' in n-step update).
      gamma: float, the discount factor.
      max_sample_attempts: int, the maximum number of attempts allowed to
        get a sample.
      extra_storage_types: list of ReplayElements defining the type of the
        extra contents that will be stored and returned by
        sample_transition_batch.
      observation_dtype: np.dtype, type of the observations. Defaults to
        np.uint8 for Atari 2600.
      terminal_dtype: np.dtype, type of the terminals. Defaults to np.uint8
        for Atari 2600.
      action_shape: tuple of ints, the shape for the action vector. Empty
        tuple means the action is a scalar.
      action_dtype: np.dtype, type of elements in the action.
      reward_shape: tuple of ints, the shape of the reward vector. Empty
        tuple means the reward is a scalar.
      reward_dtype: np.dtype, type of elements in the reward.
    """
    super(PrioritizedReplayBuffer, self).__init__(
        observation_shape=observation_shape,
        stack_size=stack_size,
        replay_capacity=replay_capacity,
        batch_size=batch_size,
        update_horizon=update_horizon,
        gamma=gamma,
        max_sample_attempts=max_sample_attempts,
        extra_storage_types=extra_storage_types,
        observation_dtype=observation_dtype,
        terminal_dtype=terminal_dtype,
        action_shape=action_shape,
        action_dtype=action_dtype,
        reward_shape=reward_shape,
        reward_dtype=reward_dtype,
    )
    # Sized to replay_capacity, presumably one leaf per stored transition so
    # sampling priorities can be kept per slot -- confirm against SumTree.
    self.sum_tree = sum_tree.SumTree(replay_capacity)
def testSetValueSmallCapacity(self):
    """Checks that a capacity-1 tree returns the value written at index 0."""
    value = 1.5
    single_leaf_tree = sum_tree.SumTree(capacity=1)
    single_leaf_tree.set(0, value)
    self.assertEqual(single_leaf_tree.get(0), value)
def testNegativeCapacity(self):
    """Checks that a negative capacity raises ValueError with a clear message.

    Note: `assertRaises(..., msg=...)` only customizes the failure text shown
    when no exception is raised; it never inspects the raised exception.
    `assertRaisesRegex` is used so the error message is actually verified.
    """
    with self.assertRaisesRegex(
        ValueError, r"Sum tree capacity should be positive\. Got: -1"
    ):
        sum_tree.SumTree(capacity=-1)
def testSmallCapacityConstructor(self):
    """Checks len(tree.nodes) for the two smallest capacities (1 and 2)."""
    for capacity, expected_len in ((1, 1), (2, 2)):
        tree = sum_tree.SumTree(capacity=capacity)
        self.assertEqual(len(tree.nodes), expected_len)
def setUp(self):
    """Builds a fresh capacity-100 SumTree and stores it on self._tree."""
    capacity = 100
    self._tree = sum_tree.SumTree(capacity=capacity)