コード例 #1
0
    def __init__(
            self,
            observation_shape,
            stack_size,
            replay_capacity,
            batch_size,
            update_horizon=1,
            gamma=0.99,
            max_sample_attempts=1000,
            extra_storage_types=None,
            observation_dtype=np.uint8,
            terminal_dtype=np.uint8,
            action_shape=(),
            action_dtype=np.int32,
            reward_shape=(),
            reward_dtype=np.float32,
    ):
        """Initializes PrioritizedReplayBuffer.

        Every argument is forwarded unchanged to the parent class constructor;
        this class additionally creates `self.sum_tree`, a `sum_tree.SumTree`
        whose capacity is `replay_capacity`.

        Args:
          observation_shape: tuple of ints.
          stack_size: int, number of frames to use in state stack.
          replay_capacity: int, number of transitions to keep in memory. Also
            used as the capacity of the sum tree.
          batch_size: int.
          update_horizon: int, length of update ('n' in n-step update).
          gamma: float, the discount factor.
          max_sample_attempts: int, the maximum number of attempts allowed to
            get a sample.
          extra_storage_types: list of ReplayElements defining the type of the
            extra contents that will be stored and returned by
            sample_transition_batch.
          observation_dtype: np.dtype, type of the observations. Defaults to
            np.uint8 for Atari 2600.
          terminal_dtype: np.dtype, type of the terminals. Defaults to np.uint8
            for Atari 2600.
          action_shape: tuple of ints, the shape for the action vector. Empty
            tuple means the action is a scalar.
          action_dtype: np.dtype, type of elements in the action.
          reward_shape: tuple of ints, the shape of the reward vector. Empty
            tuple means the reward is a scalar.
          reward_dtype: np.dtype, type of elements in the reward.
        """
        super(PrioritizedReplayBuffer, self).__init__(
            observation_shape=observation_shape,
            stack_size=stack_size,
            replay_capacity=replay_capacity,
            batch_size=batch_size,
            update_horizon=update_horizon,
            gamma=gamma,
            max_sample_attempts=max_sample_attempts,
            extra_storage_types=extra_storage_types,
            observation_dtype=observation_dtype,
            terminal_dtype=terminal_dtype,
            action_shape=action_shape,
            action_dtype=action_dtype,
            reward_shape=reward_shape,
            reward_dtype=reward_dtype,
        )

        # Sized to match the buffer's replay_capacity; presumably holds one
        # sampling priority per stored transition (verify in sum_tree.SumTree).
        self.sum_tree = sum_tree.SumTree(replay_capacity)
コード例 #2
0
 def testSetValueSmallCapacity(self):
     """A capacity-1 SumTree reads back the value written to index 0."""
     single_slot = sum_tree.SumTree(capacity=1)
     stored = 1.5
     single_slot.set(0, stored)
     self.assertEqual(single_slot.get(0), stored)
コード例 #3
0
 def testNegativeCapacity(self):
     """Constructing a SumTree with capacity -1 raises a descriptive ValueError."""
     # assertRaises(..., msg=...) only uses `msg` as the failure text when no
     # exception is raised; it never inspects the raised exception's message.
     # assertRaisesRegex actually verifies the message (dot escaped so it is
     # matched literally).
     with self.assertRaisesRegex(
         ValueError, r"Sum tree capacity should be positive\. Got: -1"
     ):
         sum_tree.SumTree(capacity=-1)
コード例 #4
0
 def testSmallCapacityConstructor(self):
     """Checks len(tree.nodes) for the two smallest capacities, 1 then 2."""
     # (capacity, expected len(tree.nodes)) pairs, checked in this order.
     for capacity, expected_len in ((1, 1), (2, 2)):
         small_tree = sum_tree.SumTree(capacity=capacity)
         self.assertEqual(len(small_tree.nodes), expected_len)
コード例 #5
0
 def setUp(self):
     """Gives each test a fresh SumTree of capacity 100 in `self._tree`."""
     default_capacity = 100
     self._tree = sum_tree.SumTree(capacity=default_capacity)