def testLowCapacity(self):
  """Undersized buffers must be rejected; a just-large-enough one is accepted."""
  error_pattern = 'There is not enough capacity'

  # Each (stack_size, update_horizon) pair is expected to be refused when the
  # replay capacity is 10.
  for stack_size, update_horizon in ((10, 1), (5, 10)):
    with self.assertRaisesRegexp(ValueError, error_pattern):
      circular_replay_buffer.OutOfGraphReplayBuffer(
          observation_shape=OBSERVATION_SHAPE,
          stack_size=stack_size,
          replay_capacity=10,
          batch_size=BATCH_SIZE,
          update_horizon=update_horizon,
          gamma=1.0)

  # A buffer with just enough room for one transition must construct without
  # raising.
  circular_replay_buffer.OutOfGraphReplayBuffer(
      observation_shape=OBSERVATION_SHAPE,
      stack_size=5,
      replay_capacity=10,
      batch_size=BATCH_SIZE,
      update_horizon=5,
      gamma=1.0)
def testConstructor(self):
  """Checks the stored observation shape and the initial add count."""

  def make_buffer(shape):
    # Only the observation shape varies between the two buffers under test.
    return circular_replay_buffer.OutOfGraphReplayBuffer(
        observation_shape=shape,
        stack_size=STACK_SIZE,
        replay_capacity=5,
        batch_size=BATCH_SIZE)

  default_memory = make_buffer(OBSERVATION_SHAPE)
  self.assertEqual(default_memory._observation_shape, OBSERVATION_SHAPE)

  # Repeat with a non-square observation shape.
  wide_memory = make_buffer((4, 20))
  self.assertEqual(wide_memory._observation_shape, (4, 20))
  self.assertEqual(wide_memory.add_count, 0)
def testGetRangeWithWraparound(self):
  """Checks get_range when the requested span runs past the buffer's end.

  The request below uses start_index=8 and end_index=12 on a capacity-10
  buffer, so the indices have to wrap around.
  """
  buffer_kwargs = dict(
      observation_shape=OBSERVATION_SHAPE,
      stack_size=STACK_SIZE,
      replay_capacity=10,
      batch_size=BATCH_SIZE,
      update_horizon=5,
      gamma=1.0)
  memory = circular_replay_buffer.OutOfGraphReplayBuffer(**buffer_kwargs)
  for _ in range(10):
    blank_observation = np.full(OBSERVATION_SHAPE, 0, dtype=OBS_DTYPE)
    memory.add(blank_observation, 0, 2.0, 0)

  # `array` has ten rows of five identical floats; row k holds the value
  # k + 1, i.e. [[1., 1., 1., 1., 1.], ..., [10., 10., 10., 10., 10.]].
  array = np.arange(10).reshape(10, 1) + np.ones(5)
  sliced_array = memory.get_range(array, 8, 12)

  # With start_index == 8 and replay_capacity == 10 the indices used are
  # expected to be [8, 9, 0, 1]; rolling by two moves rows 8 and 9 to the
  # front, so the first four rolled rows are the expected slice.
  rolled_array = np.roll(array, 2, axis=0)
  self.assertAllEqual(sliced_array, rolled_array[:4])
def testGetStack(self):
  """Checks the shape, zero padding and contents of observation stacks.

  Eleven observations are added, the i-th filled with the value i. The test
  then verifies that every stack from index 3 up to the cursor has the
  expected shape, that the stack at index 3 is all zeros, and that the stack
  at index 6 holds the observations filled with 0, 1, 2 and 3.
  """
  expected_shape = OBSERVATION_SHAPE + (4,)
  zero_stack = np.zeros(expected_shape, dtype=OBS_DTYPE)
  memory = circular_replay_buffer.OutOfGraphReplayBuffer(
      observation_shape=OBSERVATION_SHAPE,
      stack_size=STACK_SIZE,
      replay_capacity=50,
      batch_size=BATCH_SIZE)
  for i in range(11):
    memory.add(np.full(OBSERVATION_SHAPE, i, dtype=OBS_DTYPE), 0, 0, 0)

  # Ensure that the returned shapes are always correct. This must be
  # assertEqual: assertTrue(shape, expected) would treat `expected` as the
  # failure message and pass for any non-empty shape.
  for i in range(3, memory.cursor()):
    self.assertEqual(memory.get_observation_stack(i).shape, expected_shape)

  # Ensure that there is the necessary 0 padding.
  stack = memory.get_observation_stack(3)
  self.assertTrue(np.array_equal(zero_stack, stack))

  # Ensure that after the padding the contents are properly stored.
  stack = memory.get_observation_stack(6)
  for i in range(4):
    self.assertTrue(
        np.array_equal(np.full(OBSERVATION_SHAPE, i), stack[:, :, i]))
def _create_dummy_memory(**kwargs):
  """Builds a small OutOfGraphReplayBuffer with fixed basic settings.

  Args:
    **kwargs: extra keyword arguments forwarded unchanged to the
      OutOfGraphReplayBuffer constructor.

  Returns:
    A buffer created with observation_shape=(2,), stack_size=1,
    replay_capacity=10 and batch_size=2, plus the forwarded kwargs.
  """
  buffer_cls = circular_replay_buffer.OutOfGraphReplayBuffer
  return buffer_cls(
      observation_shape=(2,),
      stack_size=1,
      replay_capacity=10,
      batch_size=2,
      **kwargs)
def testSave(self):
  """Checks that save() writes one file per public attribute and prunes old ones.

  The buffer is saved at a stale iteration and then at the current iteration.
  After the second save, every public attribute must have a current
  checkpoint file and its stale counterpart must be gone.
  """
  memory = circular_replay_buffer.OutOfGraphReplayBuffer(
      observation_shape=OBSERVATION_SHAPE,
      stack_size=STACK_SIZE,
      replay_capacity=5,
      batch_size=BATCH_SIZE)
  memory.observation = self._test_observation
  memory.action = self._test_action
  memory.reward = self._test_reward
  memory.terminal = self._test_terminal
  current_iteration = 5
  stale_iteration = (
      current_iteration - circular_replay_buffer.CHECKPOINT_DURATION)

  def checkpoint_path(attr, iteration):
    # File name layout the test expects save() to use for each attribute.
    return os.path.join(self._test_subdir,
                        '{}_ckpt.{}.gz'.format(attr, iteration))

  memory.save(self._test_subdir, stale_iteration)
  for attr in memory.__dict__:
    if attr.startswith('_'):
      continue
    self.assertTrue(tf.gfile.Exists(checkpoint_path(attr, stale_iteration)))

  memory.save(self._test_subdir, current_iteration)
  for attr in memory.__dict__:
    if attr.startswith('_'):
      continue
    self.assertTrue(tf.gfile.Exists(checkpoint_path(attr, current_iteration)))
    # The stale version file should have been deleted. The stale path is
    # recomputed for each attribute; reusing the variable left over from the
    # first loop would only ever check the last attribute's file.
    self.assertFalse(tf.gfile.Exists(checkpoint_path(attr, stale_iteration)))
def testSaveNonNDArrayAttributes(self):
  """Tests checkpointing an attribute which is not a numpy array.

  The buffer is saved at a stale iteration and then at the current iteration.
  After the second save, every public attribute must have a current
  checkpoint file and its stale counterpart must be gone.
  """
  memory = circular_replay_buffer.OutOfGraphReplayBuffer(
      observation_shape=OBSERVATION_SHAPE,
      stack_size=STACK_SIZE,
      replay_capacity=5,
      batch_size=BATCH_SIZE)

  # Add some non-numpy data: an int, a string, an object.
  memory.dummy_attribute_1 = 4753849
  memory.dummy_attribute_2 = 'String data'
  memory.dummy_attribute_3 = CheckpointableClass()

  current_iteration = 5
  stale_iteration = (
      current_iteration - circular_replay_buffer.CHECKPOINT_DURATION)

  def checkpoint_path(attr, iteration):
    # File name layout the test expects save() to use for each attribute.
    return os.path.join(self._test_subdir,
                        '{}_ckpt.{}.gz'.format(attr, iteration))

  memory.save(self._test_subdir, stale_iteration)
  for attr in memory.__dict__:
    if attr.startswith('_'):
      continue
    self.assertTrue(tf.gfile.Exists(checkpoint_path(attr, stale_iteration)))

  memory.save(self._test_subdir, current_iteration)
  for attr in memory.__dict__:
    if attr.startswith('_'):
      continue
    self.assertTrue(tf.gfile.Exists(checkpoint_path(attr, current_iteration)))
    # The stale version file should have been deleted. The stale path is
    # recomputed for each attribute; reusing the variable left over from the
    # first loop would only ever check the last attribute's file.
    self.assertFalse(tf.gfile.Exists(checkpoint_path(attr, stale_iteration)))
def testWithNontupleObservationShape(self):
  """Passing an int as observation_shape must raise an AssertionError."""
  non_tuple_shape = 84
  with self.assertRaises(AssertionError):
    circular_replay_buffer.OutOfGraphReplayBuffer(
        observation_shape=non_tuple_shape,
        stack_size=STACK_SIZE,
        replay_capacity=5,
        batch_size=BATCH_SIZE)
def testIsTransitionValid(self):
  """Checks is_valid_transition for every index after a short episode.

  Three all-zero observations are added; only the last one is terminal.
  """
  memory = circular_replay_buffer.OutOfGraphReplayBuffer(
      observation_shape=OBSERVATION_SHAPE,
      stack_size=STACK_SIZE,
      replay_capacity=10,
      batch_size=2)

  for terminal in (0, 0, 1):
    observation = np.full(
        (OBSERVATION_SHAPE, OBSERVATION_SHAPE), 0, dtype=OBS_DTYPE)
    memory.add(observation, 0, 0, terminal)

  # These valids account for the automatically applied padding (3 blanks each
  # episode); per the original author's marker, the cursor then sits at
  # index 6.
  correct_valids = [0, 0, 0, 1, 1, 0, 0, 0, 0, 0]
  for index, expected in enumerate(correct_valids):
    self.assertEqual(
        expected,
        memory.is_valid_transition(index),
        'Index %i should be %s' % (index, bool(expected)))
def testSampleTransitionBatchExtra(self):
  """Samples batches from a buffer that carries two extra storage elements.

  Covers the default batch size, an explicit batch size, reverting to the
  default, and sampling at caller-chosen indices.
  """
  replay_capacity = 10
  default_batch_size = 2
  extra_elements = [
      circular_replay_buffer.ReplayElement('extra1', [], np.float32),
      circular_replay_buffer.ReplayElement('extra2', [2], np.int8),
  ]
  memory = circular_replay_buffer.OutOfGraphReplayBuffer(
      observation_shape=OBSERVATION_SHAPE,
      stack_size=1,
      replay_capacity=replay_capacity,
      batch_size=default_batch_size,
      extra_storage_types=extra_elements)

  num_adds = 50  # The number of transitions to add to the memory.
  for step in range(num_adds):
    observation = np.full(
        (OBSERVATION_SHAPE, OBSERVATION_SHAPE), step, dtype=OBS_DTYPE)
    # The terminal flag passed in is `step % 4`.
    memory.add(observation, 0, 0, step % 4, 0, [0, 0])

  # Test sampling with default batch size.
  for _ in range(1000):
    sampled = memory.sample_transition_batch()
    self.assertEqual(sampled[0].shape[0], default_batch_size)

  # Test changing batch sizes.
  for _ in range(1000):
    sampled = memory.sample_transition_batch(BATCH_SIZE)
    self.assertEqual(sampled[0].shape[0], BATCH_SIZE)

  # Verify we revert to default batch size.
  for _ in range(1000):
    sampled = memory.sample_transition_batch()
    self.assertEqual(sampled[0].shape[0], default_batch_size)

  # Verify we can specify what indices to sample.
  indices = [1, 2, 3, 5, 8]
  expected_states = np.array([
      np.full((OBSERVATION_SHAPE, OBSERVATION_SHAPE, 1), index,
              dtype=OBS_DTYPE)
      for index in indices
  ])
  expected_next_states = (expected_states + 1) % replay_capacity

  # Because the replay buffer is circular, we can exactly compute what the
  # states will be at the specified indices by doing a little mod math.
  offset = num_adds - replay_capacity
  expected_states += offset
  expected_next_states += offset

  # This replicates the `step % 4` formula used for the terminal flag when
  # the observations were added.
  expected_terminal = np.array(
      [min((index + offset) % 4, 1) for index in indices])
  expected_extra2 = np.zeros([len(indices), 2])

  batch = memory.sample_transition_batch(
      batch_size=len(indices), indices=indices)
  (states, action, reward, next_states, terminal, indices_batch, extra1,
   extra2) = batch

  self.assertAllEqual(states, expected_states)
  self.assertAllEqual(action, np.zeros(len(indices)))
  self.assertAllEqual(reward, np.zeros(len(indices)))
  self.assertAllEqual(next_states, expected_next_states)
  self.assertAllEqual(terminal, expected_terminal)
  self.assertAllEqual(indices_batch, indices)
  self.assertAllEqual(extra1, np.zeros(len(indices)))
  self.assertAllEqual(extra2, expected_extra2)
def _load_buffer(self, suffix):
  """Loads one OutOfGraphReplayBuffer checkpoint and trims its stored arrays.

  Args:
    suffix: checkpoint suffix; passed to the buffer's `load` method and
      converted with `int()` for logging.

  Returns:
    The loaded buffer, with every array in its `_store` replaced by a copy of
    the slice starting at `self._replay_start_index`, or None if a
    tf.errors.NotFoundError is raised along the way.
  """
  try:
    # pytype: disable=attribute-error
    logging.info('Starting to load from ckpt %d from %s', int(suffix),
                 self._data_dir)
    replay_buffer = circular_replay_buffer.OutOfGraphReplayBuffer(
        *self._args, **self._kwargs)
    replay_buffer.load(self._data_dir, suffix)
    # pylint: disable = protected-access
    replay_capacity = replay_buffer._replay_capacity
    logging.info('Capacity: %d', replay_capacity)
    logging.info('Start index: %d', self._replay_start_index)
    for name, array in replay_buffer._store.items():
      # Keeping only a copy of the slice frees unused RAM if replay_capacity
      # is smaller than 1M.
      first = self._replay_start_index
      last = first + replay_capacity + replay_buffer._stack_size
      replay_buffer._store[name] = array[first:last].copy()
      # Note: the shape logged here is that of the array before slicing.
      logging.info('%s: %s', name, array.shape)
    logging.info('Loaded replay buffer from ckpt %d from %s', int(suffix),
                 self._data_dir)
    # pylint: enable=protected-access
    # pytype: enable=attribute-error
  except tf.errors.NotFoundError:
    return None
  return replay_buffer
def testLoad(self):
  """Writes checkpoint files by hand and checks that load() restores them."""
  memory = circular_replay_buffer.OutOfGraphReplayBuffer(
      observation_shape=OBSERVATION_SHAPE,
      stack_size=STACK_SIZE,
      replay_capacity=5,
      batch_size=BATCH_SIZE)
  store_fixtures = (
      ('observation', self._test_observation),
      ('action', self._test_action),
      ('reward', self._test_reward),
      ('terminal', self._test_terminal),
  )

  # Before loading, the buffer must not already hold the fixture values.
  for key, fixture in store_fixtures:
    self.assertNotEqual(memory._store[key], fixture)
  self.assertNotEqual(memory.add_count, self._test_add_count)
  self.assertNotEqual(memory.invalid_range, self._test_invalid_range)

  # Store elements are written under a prefixed name; the two remaining
  # attributes use their plain names.
  store_prefix = '$store$_'
  checkpoint_items = [(store_prefix + key, fixture)
                      for key, fixture in store_fixtures]
  checkpoint_items.append(('add_count', self._test_add_count))
  checkpoint_items.append(('invalid_range', self._test_invalid_range))
  for attr, value in checkpoint_items:
    filename = os.path.join(self._test_subdir, '{}_ckpt.3.gz'.format(attr))
    with tf.gfile.Open(filename, 'w') as f:
      with gzip.GzipFile(fileobj=f) as outfile:
        np.save(outfile, value, allow_pickle=False)

  memory.load(self._test_subdir, '3')

  for key, fixture in store_fixtures:
    self.assertAllClose(memory._store[key], fixture)
  self.assertEqual(memory.add_count, self._test_add_count)
  self.assertAllClose(memory.invalid_range, self._test_invalid_range)
def testPartialLoadFails(self):
  """load() must raise and leave the buffer untouched when a file is missing."""
  memory = circular_replay_buffer.OutOfGraphReplayBuffer(
      observation_shape=OBSERVATION_SHAPE,
      stack_size=STACK_SIZE,
      replay_capacity=5,
      batch_size=BATCH_SIZE)

  def assert_store_not_loaded():
    # None of the store arrays, nor the add count, may equal the fixtures.
    self.assertNotEqual(memory._store['observation'], self._test_observation)
    self.assertNotEqual(memory._store['action'], self._test_action)
    self.assertNotEqual(memory._store['reward'], self._test_reward)
    self.assertNotEqual(memory._store['terminal'], self._test_terminal)
    self.assertNotEqual(memory.add_count, self._test_add_count)

  assert_store_not_loaded()

  # Write every checkpoint file except the one for the reward.
  checkpoint_items = [
      ('observation', self._test_observation),
      ('action', self._test_action),
      ('terminal', self._test_terminal),
      ('add_count', self._test_add_count),
      ('invalid_range', self._test_invalid_range),
  ]
  for attr, value in checkpoint_items:
    filename = os.path.join(self._test_subdir, '{}_ckpt.3.gz'.format(attr))
    with tf.gfile.Open(filename, 'w') as f:
      with gzip.GzipFile(fileobj=f) as outfile:
        np.save(outfile, value, allow_pickle=False)

  # We are missing the reward file, so a NotFoundError will be raised.
  with self.assertRaises(tf.errors.NotFoundError):
    memory.load(self._test_subdir, '3')

  # Since we are missing the reward file, it should not have loaded any of
  # the other files.
  assert_store_not_loaded()
  self.assertNotEqual(memory.invalid_range, self._test_invalid_range)
def testGetRangeNoWraparound(self):
  """Checks get_range when the requested span does not wrap around.

  Here start_index (2) is smaller than end_index (5), so the result should be
  a plain slice of the input array.
  """
  buffer_kwargs = dict(
      observation_shape=OBSERVATION_SHAPE,
      stack_size=STACK_SIZE,
      replay_capacity=10,
      batch_size=BATCH_SIZE,
      update_horizon=5,
      gamma=1.0)
  memory = circular_replay_buffer.OutOfGraphReplayBuffer(**buffer_kwargs)
  for _ in range(10):
    blank_observation = np.full(OBSERVATION_SHAPE, 0, dtype=OBS_DTYPE)
    memory.add(blank_observation, 0, 2.0, 0)

  # `array` has ten rows of five identical floats; row k holds the value
  # k + 1, i.e. [[1., 1., 1., 1., 1.], ..., [10., 10., 10., 10., 10.]].
  array = np.arange(10).reshape(10, 1) + np.ones(5)
  start_index, end_index = 2, 5
  sliced_array = memory.get_range(array, start_index, end_index)
  self.assertAllEqual(sliced_array, array[start_index:end_index])
def _build_replay_buffer(self):
  """Creates the replay buffer used by the agent.

  Returns:
    An OutOfGraphReplayBuffer configured from this object's observation
    shape, stack size, update horizon, gamma and observation dtype. No other
    constructor arguments are passed.
  """
  buffer_kwargs = dict(
      observation_shape=self.observation_shape,
      stack_size=self.stack_size,
      update_horizon=self.update_horizon,
      gamma=self.gamma,
      observation_dtype=self.observation_dtype)
  return circular_replay_buffer.OutOfGraphReplayBuffer(**buffer_kwargs)
def testAdd(self):
  """Checks how far the cursor advances after the first add()."""
  buffer_kwargs = dict(
      observation_shape=OBSERVATION_SHAPE,
      stack_size=STACK_SIZE,
      replay_capacity=5,
      batch_size=BATCH_SIZE)
  memory = circular_replay_buffer.OutOfGraphReplayBuffer(**buffer_kwargs)
  self.assertEqual(memory.cursor(), 0)

  blank_observation = np.zeros(OBSERVATION_SHAPE)
  memory.add(blank_observation, 0, 0, 0)

  # The cursor should have moved by STACK_SIZE - 1 padding adds plus the one
  # explicit add above.
  self.assertEqual(memory.cursor(), STACK_SIZE)
def testNodeNotAddedToMemory(self):
  """With contiguous trajectories, a non-terminal add leaves the buffer empty.

  After one non-terminal add() both the cursor and the add count are expected
  to remain at zero.
  """
  buffer_kwargs = dict(
      observation_shape=OBSERVATION_SHAPE,
      stack_size=1,
      replay_capacity=5,
      batch_size=BATCH_SIZE,
      use_contiguous_trajectories=True)
  memory = circular_replay_buffer.OutOfGraphReplayBuffer(**buffer_kwargs)
  self.assertEqual(memory.cursor(), 0)

  blank_observation = np.zeros(OBSERVATION_SHAPE)
  memory.add(blank_observation, 0, 0, 0)

  self.assertEqual(memory.cursor(), 0)
  self.assertEqual(memory.add_count, 0)
def _load_buffer(self, suffix):
  """Loads one OutOfGraphReplayBuffer checkpoint.

  Args:
    suffix: checkpoint suffix passed to the buffer's `load` method.

  Returns:
    The loaded buffer, or None if a tf.errors.NotFoundError is raised while
    constructing, loading or logging.
  """
  try:
    # pytype: disable=attribute-error
    loaded = circular_replay_buffer.OutOfGraphReplayBuffer(
        *self._args, **self._kwargs)
    loaded.load(self._data_dir, suffix)
    message = 'Loaded replay buffer ckpt {} from {}'.format(
        suffix, self._data_dir)
    tf.logging.info(message)
    # pytype: enable=attribute-error
  except tf.errors.NotFoundError:
    return None
  return loaded
def testAddTerminalNodeToTrajectoryBuffer(self):
  """A terminal add with contiguous trajectories advances the buffer.

  After the terminal add, the cursor and add count must both equal
  STACK_SIZE and the pending trajectory must be empty.
  """
  buffer_kwargs = dict(
      observation_shape=OBSERVATION_SHAPE,
      stack_size=STACK_SIZE,
      replay_capacity=5,
      batch_size=BATCH_SIZE,
      use_contiguous_trajectories=True)
  memory = circular_replay_buffer.OutOfGraphReplayBuffer(**buffer_kwargs)
  self.assertEqual(memory.cursor(), 0)
  self.assertEqual(len(memory._trajectory), 0)

  blank_observation = np.zeros(OBSERVATION_SHAPE)
  memory.add(blank_observation, 0, 0, 1)

  self.assertEqual(memory.cursor(), STACK_SIZE)
  self.assertEqual(memory.add_count, STACK_SIZE)
  self.assertEqual(len(memory._trajectory), 0)
def testLoadFromNonexistentDirectory(self):
  """load() from a missing directory must raise and load nothing."""
  buffer_kwargs = dict(
      observation_shape=OBSERVATION_SHAPE,
      stack_size=STACK_SIZE,
      replay_capacity=5,
      batch_size=BATCH_SIZE)
  memory = circular_replay_buffer.OutOfGraphReplayBuffer(**buffer_kwargs)

  # The directory does not exist, so a NotFoundError is expected.
  missing_dir = '/does/not/exist'
  with self.assertRaises(tf.errors.NotFoundError):
    memory.load(missing_dir, '3')

  # Nothing in the buffer may match the fixture values afterwards.
  store_fixtures = (
      ('observation', self._test_observation),
      ('action', self._test_action),
      ('reward', self._test_reward),
      ('terminal', self._test_terminal),
  )
  for key, fixture in store_fixtures:
    self.assertNotEqual(memory._store[key], fixture)
  self.assertNotEqual(memory.add_count, self._test_add_count)
  self.assertNotEqual(memory.invalid_range, self._test_invalid_range)
def testCheckAddTypes(self):
  """_check_add_types accepts a full argument list and rejects a short one."""
  extra_elements = [
      circular_replay_buffer.ReplayElement('extra1', [], np.float32),
      circular_replay_buffer.ReplayElement('extra2', [2], np.int8),
  ]
  memory = circular_replay_buffer.OutOfGraphReplayBuffer(
      observation_shape=OBSERVATION_SHAPE,
      stack_size=STACK_SIZE,
      replay_capacity=5,
      batch_size=BATCH_SIZE,
      extra_storage_types=extra_elements)
  blank_observation = np.zeros(OBSERVATION_SHAPE)

  # All six values supplied, including the two extras: must not raise.
  memory._check_add_types(blank_observation, 0, 0, 0, 0, [0, 0])

  # The two extra values are omitted: must be rejected.
  with self.assertRaisesRegexp(ValueError, 'Add expects'):
    memory._check_add_types(blank_observation, 0, 0, 0)
def testAddMultipleThreadsNodeNotAdded(self):
  """Adds from two threads and checks that only the terminal add is counted."""
  buffer_kwargs = dict(
      observation_shape=OBSERVATION_SHAPE,
      stack_size=1,
      replay_capacity=5,
      batch_size=BATCH_SIZE,
      use_contiguous_trajectories=True)
  memory = circular_replay_buffer.OutOfGraphReplayBuffer(**buffer_kwargs)
  self.assertEqual(memory.cursor(), 0)
  self.assertEqual(len(memory._trajectory), 0)

  blank_observation = np.zeros(OBSERVATION_SHAPE)

  # Non-terminal transition added from the main thread.
  memory.add(blank_observation, 0, 0, 0)

  # Terminal transition added under a mocked, differently named thread.
  with test_utils.mock_thread('other-thread'):
    memory.add(blank_observation, 0, 0, 1)

  # The terminal transition should have been added by itself.
  self.assertEqual(memory.add_count, 1)
def _build_memory(self, capacity, batch_size):
  """Creates the replay buffer used by the generators.

  Args:
    capacity: int, maximum capacity of the memory unit.
    batch_size: int, batch size of the batch produced during memory replay.

  Returns:
    An OutOfGraphReplayBuffer object.
  """
  # The buffer is given the numpy form of the stored observation dtype.
  numpy_dtype = self.observation_dtype.as_numpy_dtype
  buffer_cls = circular_replay_buffer.OutOfGraphReplayBuffer
  return buffer_cls(
      self.observation_shape,
      self.stack_size,
      capacity,
      batch_size,
      observation_dtype=numpy_dtype)
def testNSteprewardum(self):
  """Checks the accumulated reward of sampled n-step transitions.

  Every added transition has reward 2.0; with update_horizon=5 and gamma=1.0
  the first sampled reward in each batch is expected to be 10.0.
  """
  buffer_kwargs = dict(
      observation_shape=OBSERVATION_SHAPE,
      stack_size=STACK_SIZE,
      replay_capacity=10,
      batch_size=BATCH_SIZE,
      update_horizon=5,
      gamma=1.0)
  memory = circular_replay_buffer.OutOfGraphReplayBuffer(**buffer_kwargs)

  for step in range(50):
    observation = np.full(OBSERVATION_SHAPE, step, dtype=OBS_DTYPE)
    memory.add(observation, 0, 2.0, 0)

  for _ in range(100):
    sampled = memory.sample_transition_batch()
    # Make sure the total reward is reward per step x update_horizon.
    first_reward = sampled[2][0]
    self.assertEqual(first_reward, 10.0)
def _load_buffer(self, suffix):
  """Loads a replay buffer checkpoint of the configured buffer type.

  Builds an OutOfGraphOffPolicyReplayBuffer (seeded with `suffix`) when
  `use_off_policy_replay_buffer` is truthy, otherwise a plain
  OutOfGraphReplayBuffer, and then loads the checkpoint into it.

  Args:
    suffix: checkpoint suffix passed to the buffer's `load` method, and used
      as `subsample_seed` for the off-policy buffer.

  Returns:
    The loaded buffer, or None if a tf.errors.NotFoundError is raised.
  """
  try:
    if not use_off_policy_replay_buffer:
      # pytype: disable=attribute-error
      loaded = circular_replay_buffer.OutOfGraphReplayBuffer(
          *self._args, **self._kwargs)
    else:
      loaded = off_policy_replay_buffer.OutOfGraphOffPolicyReplayBuffer(
          *self._args, subsample_seed=suffix, **self._kwargs)
    loaded.load(self._data_dir, suffix)
    message = 'Loaded replay buffer ckpt {} from {}'.format(
        suffix, self._data_dir)
    tf.logging.info(message)
    # pytype: enable=attribute-error
  except tf.errors.NotFoundError:
    # A missing checkpoint is reported to the caller as None, not re-raised.
    return None
  return loaded
def testExtraAdd(self):
  """add() with extra storage types needs the extra values to be supplied."""
  extra_elements = [
      circular_replay_buffer.ReplayElement('extra1', [], np.float32),
      circular_replay_buffer.ReplayElement('extra2', [2], np.int8),
  ]
  memory = circular_replay_buffer.OutOfGraphReplayBuffer(
      observation_shape=OBSERVATION_SHAPE,
      stack_size=STACK_SIZE,
      replay_capacity=5,
      batch_size=BATCH_SIZE,
      extra_storage_types=extra_elements)
  self.assertEqual(memory.cursor(), 0)

  blank_observation = np.zeros(OBSERVATION_SHAPE)
  memory.add(blank_observation, 0, 0, 0, 0, [0, 0])

  # Leaving out the two extra values must be rejected.
  with self.assertRaisesRegexp(ValueError, 'Add expects'):
    memory.add(blank_observation, 0, 0, 0)

  # The cursor should have moved by STACK_SIZE - 1 zero adds plus the one
  # successful add above.
  self.assertEqual(memory.cursor(), STACK_SIZE)
def testSamplingWithterminalInTrajectory(self):
  """Samples fixed indices around a terminal state and checks the batch.

  Ten transitions are added; transition i has action 2 * i and reward i, and
  only transition 3 is terminal.
  """
  replay_capacity = 10
  update_horizon = 3
  memory = circular_replay_buffer.OutOfGraphReplayBuffer(
      observation_shape=OBSERVATION_SHAPE,
      stack_size=1,
      replay_capacity=replay_capacity,
      batch_size=2,
      update_horizon=update_horizon,
      gamma=1.0)
  for step in range(replay_capacity):
    observation = np.full(
        (OBSERVATION_SHAPE, OBSERVATION_SHAPE), step, dtype=OBS_DTYPE)
    is_terminal = 1 if step == 3 else 0
    memory.add(observation, step * 2, step, is_terminal)

  indices = [2, 3, 4]
  batch = memory.sample_transition_batch(
      batch_size=len(indices), indices=indices)
  states, action, reward, _, terminal, indices_batch = batch

  expected_states = np.array([
      np.full((OBSERVATION_SHAPE, OBSERVATION_SHAPE, 1), index,
              dtype=OBS_DTYPE)
      for index in indices
  ])
  # The rewards in the replay buffer are (an asterisk marks the terminal
  # state):
  #   [0 1 2 3* 4 5 6 7 8 9]
  # With update_horizon set to 3, the accumulated trajectory reward starting
  # at each replay buffer position is:
  #   [3 6 5 3 15 18 21 24]
  # so for indices = [2, 3, 4] the expected rewards are [5, 3, 15].
  expected_reward = np.array([5, 3, 15])
  # Because update_horizon = 3, both indices 2 and 3 include the terminal.
  expected_terminal = np.array([1, 1, 0])
  expected_action = np.array(indices) * 2

  self.assertAllEqual(states, expected_states)
  self.assertAllEqual(action, expected_action)
  self.assertAllEqual(reward, expected_reward)
  self.assertAllEqual(terminal, expected_terminal)
  self.assertAllEqual(indices_batch, indices)
def testGetRangeInvalidIndexOrder(self):
  """get_range must assert on each kind of invalid index pair."""
  replay_capacity = 10
  buffer_kwargs = dict(
      observation_shape=OBSERVATION_SHAPE,
      stack_size=STACK_SIZE,
      replay_capacity=replay_capacity,
      batch_size=BATCH_SIZE,
      update_horizon=5,
      gamma=1.0)
  memory = circular_replay_buffer.OutOfGraphReplayBuffer(**buffer_kwargs)

  # end_index smaller than start_index.
  order_message = 'end_index must be larger than start_index'
  with self.assertRaisesRegexp(AssertionError, order_message):
    memory.get_range([], 2, 1)

  # Negative end_index.
  with self.assertRaises(AssertionError):
    memory.get_range([], 1, -1)

  # Start index beyond replay capacity.
  with self.assertRaises(AssertionError):
    memory.get_range([], replay_capacity, replay_capacity + 1)

  # Index that has not been written yet (nothing was added to the buffer).
  with self.assertRaisesRegexp(AssertionError, 'Index 1 has not been added.'):
    memory.get_range([], 1, 2)