Beispiel #1
0
 def create_default_memory(self):
     """Builds the prioritized replay buffer used as the default memory.

     The buffer is constructed from the module-level SCREEN_SIZE, STACK_SIZE,
     REPLAY_CAPACITY and BATCH_SIZE constants, passed positionally in that
     order.

     Returns:
       A new `OutOfGraphPrioritizedReplayBuffer` instance.
     """
     positional_args = (SCREEN_SIZE, STACK_SIZE, REPLAY_CAPACITY, BATCH_SIZE)
     # A small cap on sample attempts keeps the tests fast.
     memory = prioritized_replay_buffer.OutOfGraphPrioritizedReplayBuffer(
         *positional_args, max_sample_attempts=10)
     return memory
 def _build_replay_buffer(self):
     """Instantiates the prioritized replay buffer used by the agent.

     The buffer is configured from this object's `observation_shape`,
     `stack_size`, `update_horizon`, `gamma` and `observation_dtype`
     attributes. No other constructor arguments are passed here.

     Returns:
       A new `OutOfGraphPrioritizedReplayBuffer` instance.
     """
     buffer_kwargs = dict(
         observation_shape=self.observation_shape,
         stack_size=self.stack_size,
         update_horizon=self.update_horizon,
         gamma=self.gamma,
         observation_dtype=self.observation_dtype,
     )
     return prioritized_replay_buffer.OutOfGraphPrioritizedReplayBuffer(
         **buffer_kwargs)
Beispiel #3
0
 def testConstructorWithExtraStorageTypes(self):
   """Checks that the buffer can be constructed with extra storage types.

   The test only exercises the constructor; it passes if no exception is
   raised while building the buffer with two additional replay elements.
   """
   # Each element is built from a name, a shape list and a NumPy dtype.
   extra_elements = [
       prioritized_replay_buffer.ReplayElement('extra1', [], np.float32),
       prioritized_replay_buffer.ReplayElement('extra2', [2], np.int8),
   ]
   prioritized_replay_buffer.OutOfGraphPrioritizedReplayBuffer(
       SCREEN_SIZE,
       STACK_SIZE,
       REPLAY_CAPACITY,
       BATCH_SIZE,
       extra_storage_types=extra_elements)
Beispiel #4
0
 def _build_replay_buffer(self):
     """Validates the replay scheme and builds the agent's replay buffer.

     Returns:
       A new `OutOfGraphPrioritizedReplayBuffer` configured from this
       object's `observation_shape`, `stack_size`, `update_horizon`, `gamma`
       and `observation_dtype` attributes.

     Raises:
       ValueError: if `self._replay_scheme` is neither 'uniform' nor
         'prioritized'.
     """
     scheme = self._replay_scheme
     if scheme not in ('uniform', 'prioritized'):
         message = 'Invalid replay scheme: {}'.format(scheme)
         raise ValueError(message)
     replay_buffer = prioritized_replay_buffer.OutOfGraphPrioritizedReplayBuffer(
         observation_shape=self.observation_shape,
         stack_size=self.stack_size,
         update_horizon=self.update_horizon,
         gamma=self.gamma,
         observation_dtype=self.observation_dtype)
     return replay_buffer
Beispiel #5
0
 def _build_replay_buffer(self):
   """Builds the replay buffer used by the agent after checking the scheme.

   Returns:
     A new `OutOfGraphPrioritizedReplayBuffer` configured from this object's
     `observation_shape`, `stack_size`, `update_horizon`, `gamma` and
     `observation_dtype` attributes.

   Raises:
     ValueError: if `self._replay_scheme` is neither 'uniform' nor
       'prioritized'.
   """
   supported_schemes = ['uniform', 'prioritized']
   if self._replay_scheme not in supported_schemes:
     raise ValueError('Invalid replay scheme: {}'.format(self._replay_scheme))
   # The same buffer type backs both schemes; under 'uniform' all priorities
   # are set to one value, which makes the sampling uniform.
   buffer_kwargs = dict(
       observation_shape=self.observation_shape,
       stack_size=self.stack_size,
       update_horizon=self.update_horizon,
       gamma=self.gamma,
       observation_dtype=self.observation_dtype,
   )
   return prioritized_replay_buffer.OutOfGraphPrioritizedReplayBuffer(
       **buffer_kwargs)
Beispiel #6
0
 def testSampleIndexBatch(self):
     """Checks that sampled indices stay inside the expected valid range.

     The buffer is filled with blank transitions until its cursor is at 1,
     then REPLAY_CAPACITY indices are sampled and each one is checked against
     the bounds [STACK_SIZE, REPLAY_CAPACITY - 1].
     """
     memory = prioritized_replay_buffer.OutOfGraphPrioritizedReplayBuffer(
         SCREEN_SIZE,
         STACK_SIZE,
         REPLAY_CAPACITY,
         BATCH_SIZE,
         max_sample_attempts=REPLAY_CAPACITY)
     # Adding this many blank transitions leaves the cursor at 1.
     num_blank_adds = REPLAY_CAPACITY - STACK_SIZE + 2
     for _ in range(num_blank_adds):
         self.add_blank(memory)
     self.assertEqual(memory.cursor(), 1)
     sampled_indices = memory.sample_index_batch(REPLAY_CAPACITY)
     # With cursor == 1, the invalid range set by circular_replay_buffer.py is
     # [0, 1, 2, 3], so every sampled index must lie in
     # [STACK_SIZE, REPLAY_CAPACITY - 1].
     lowest_valid = STACK_SIZE
     highest_valid = REPLAY_CAPACITY - 1
     for index in sampled_indices:
         self.assertGreaterEqual(index, lowest_valid)
         self.assertLessEqual(index, highest_valid)