def create_default_memory(self):
    """Builds a prioritized replay buffer from the module-level constants.

    Returns:
        An `OutOfGraphPrioritizedReplayBuffer` constructed with `SCREEN_SIZE`,
        `STACK_SIZE`, `REPLAY_CAPACITY` and `BATCH_SIZE`.
    """
    # A small cap on sampling attempts keeps the tests fast.
    memory = prioritized_replay_buffer.OutOfGraphPrioritizedReplayBuffer(
        SCREEN_SIZE,
        STACK_SIZE,
        REPLAY_CAPACITY,
        BATCH_SIZE,
        max_sample_attempts=10,
    )
    return memory
def _build_replay_buffer(self):
    """Creates the prioritized replay buffer used by the agent.

    Returns:
        An `OutOfGraphPrioritizedReplayBuffer` built from this agent's
        observation shape, stack size, update horizon, gamma and observation
        dtype.
    """
    buffer_cls = prioritized_replay_buffer.OutOfGraphPrioritizedReplayBuffer
    # Every setting forwarded to the buffer is read from the agent itself.
    buffer_kwargs = {
        'observation_shape': self.observation_shape,
        'stack_size': self.stack_size,
        'update_horizon': self.update_horizon,
        'gamma': self.gamma,
        'observation_dtype': self.observation_dtype,
    }
    return buffer_cls(**buffer_kwargs)
def testConstructorWithExtraStorageTypes(self):
    """Constructs a buffer with two extra storage elements.

    The test makes no assertions; it succeeds as long as the constructor
    accepts the `extra_storage_types` argument without raising.
    """
    extra_elements = [
        prioritized_replay_buffer.ReplayElement('extra1', [], np.float32),
        prioritized_replay_buffer.ReplayElement('extra2', [2], np.int8),
    ]
    prioritized_replay_buffer.OutOfGraphPrioritizedReplayBuffer(
        SCREEN_SIZE,
        STACK_SIZE,
        REPLAY_CAPACITY,
        BATCH_SIZE,
        extra_storage_types=extra_elements,
    )
def _build_replay_buffer(self):
    """Creates the replay buffer used by the agent.

    Returns:
        An `OutOfGraphPrioritizedReplayBuffer` built from this agent's
        observation shape, stack size, update horizon, gamma and observation
        dtype.

    Raises:
        ValueError: If `self._replay_scheme` is neither 'uniform' nor
            'prioritized'.
    """
    supported_schemes = ('uniform', 'prioritized')
    if self._replay_scheme not in supported_schemes:
        message = 'Invalid replay scheme: {}'.format(self._replay_scheme)
        raise ValueError(message)

    replay_buffer = prioritized_replay_buffer.OutOfGraphPrioritizedReplayBuffer(
        observation_shape=self.observation_shape,
        stack_size=self.stack_size,
        update_horizon=self.update_horizon,
        gamma=self.gamma,
        observation_dtype=self.observation_dtype,
    )
    return replay_buffer
def _build_replay_buffer(self):
    """Creates the replay buffer used by the agent.

    Returns:
        An `OutOfGraphPrioritizedReplayBuffer` built from this agent's
        observation shape, stack size, update horizon, gamma and observation
        dtype.

    Raises:
        ValueError: If `self._replay_scheme` is neither 'uniform' nor
            'prioritized'.
    """
    scheme = self._replay_scheme
    if scheme != 'uniform' and scheme != 'prioritized':
        raise ValueError('Invalid replay scheme: {}'.format(scheme))

    # The same data structure backs both replay schemes; under the 'uniform'
    # scheme all priorities are set to the same value, which yields uniform
    # sampling.
    buffer_cls = prioritized_replay_buffer.OutOfGraphPrioritizedReplayBuffer
    return buffer_cls(
        observation_shape=self.observation_shape,
        stack_size=self.stack_size,
        update_horizon=self.update_horizon,
        gamma=self.gamma,
        observation_dtype=self.observation_dtype,
    )
def testSampleIndexBatch(self):
    """Checks that sampled indices fall inside the expected valid range."""
    memory = prioritized_replay_buffer.OutOfGraphPrioritizedReplayBuffer(
        SCREEN_SIZE,
        STACK_SIZE,
        REPLAY_CAPACITY,
        BATCH_SIZE,
        max_sample_attempts=REPLAY_CAPACITY,
    )

    # Adding this many blank entries ensures we end up with cursor == 1.
    num_blank_adds = REPLAY_CAPACITY - STACK_SIZE + 2
    for _ in range(num_blank_adds):
        self.add_blank(memory)
    self.assertEqual(memory.cursor(), 1)

    # Because cursor == 1, the invalid range as set by
    # circular_replay_buffer.py will be [0, 1, 2, 3], so every sample must lie
    # in [STACK_SIZE, REPLAY_CAPACITY - 1].
    lowest_valid = STACK_SIZE
    highest_valid = REPLAY_CAPACITY - 1
    for index in memory.sample_index_batch(REPLAY_CAPACITY):
        self.assertGreaterEqual(index, lowest_valid)
        self.assertLessEqual(index, highest_valid)