def testLowCapacity(self):
        with self.assertRaisesRegexp(ValueError,
                                     'There is not enough capacity'):
            circular_replay_buffer.OutOfGraphReplayBuffer(
                observation_shape=OBSERVATION_SHAPE,
                stack_size=10,
                replay_capacity=10,
                batch_size=BATCH_SIZE,
                update_horizon=1,
                gamma=1.0)

        with self.assertRaisesRegexp(ValueError,
                                     'There is not enough capacity'):
            circular_replay_buffer.OutOfGraphReplayBuffer(
                observation_shape=OBSERVATION_SHAPE,
                stack_size=5,
                replay_capacity=10,
                batch_size=BATCH_SIZE,
                update_horizon=10,
                gamma=1.0)

        # We should be able to create a buffer that contains just enough for a
        # transition.
        circular_replay_buffer.OutOfGraphReplayBuffer(
            observation_shape=OBSERVATION_SHAPE,
            stack_size=5,
            replay_capacity=10,
            batch_size=BATCH_SIZE,
            update_horizon=5,
            gamma=1.0)
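
A minimal standalone sketch of the capacity rule the three constructor calls above exercise, assuming the check is replay_capacity >= update_horizon + stack_size (the condition implied by the error message):

def has_enough_capacity(replay_capacity, update_horizon, stack_size):
    # Assumed form of the check: one valid transition needs
    # update_horizon + stack_size slots in the buffer.
    return replay_capacity >= update_horizon + stack_size

assert not has_enough_capacity(10, 1, 10)   # first failing case above
assert not has_enough_capacity(10, 10, 5)   # second failing case above
assert has_enough_capacity(10, 5, 5)        # the buffer that just fits
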
 def testConstructor(self):
     memory = circular_replay_buffer.OutOfGraphReplayBuffer(
         observation_shape=OBSERVATION_SHAPE,
         stack_size=STACK_SIZE,
         replay_capacity=5,
         batch_size=BATCH_SIZE)
     self.assertEqual(memory._observation_shape, OBSERVATION_SHAPE)
     # Test with non square observation shape
     memory = circular_replay_buffer.OutOfGraphReplayBuffer(
         observation_shape=(4, 20),
         stack_size=STACK_SIZE,
         replay_capacity=5,
         batch_size=BATCH_SIZE)
     self.assertEqual(memory._observation_shape, (4, 20))
     self.assertEqual(memory.add_count, 0)
 def testGetRangeWithWraparound(self):
     # Test the get_range function when the indices wrap around the circular
     # buffer. In other words, start_index > end_index.
     memory = circular_replay_buffer.OutOfGraphReplayBuffer(
         observation_shape=OBSERVATION_SHAPE,
         stack_size=STACK_SIZE,
         replay_capacity=10,
         batch_size=BATCH_SIZE,
         update_horizon=5,
         gamma=1.0)
     for _ in range(10):
         memory.add(np.full(OBSERVATION_SHAPE, 0, dtype=OBS_DTYPE), 0, 2.0,
                    0)
     # The constructed `array` will be:
     # array([[ 1.,  1.,  1.,  1.,  1.],
     #        [ 2.,  2.,  2.,  2.,  2.],
     #        [ 3.,  3.,  3.,  3.,  3.],
     #        [ 4.,  4.,  4.,  4.,  4.],
     #        [ 5.,  5.,  5.,  5.,  5.],
     #        [ 6.,  6.,  6.,  6.,  6.],
     #        [ 7.,  7.,  7.,  7.,  7.],
     #        [ 8.,  8.,  8.,  8.,  8.],
     #        [ 9.,  9.,  9.,  9.,  9.],
     #        [10., 10., 10., 10., 10.]])
     array = np.arange(10).reshape(10, 1) + np.ones(5)
     sliced_array = memory.get_range(array, 8, 12)
     # We roll by two, since start_index == 8 and replay_capacity == 10, so the
     # resulting indices used will be [8, 9, 0, 1].
     rolled_array = np.roll(array, 2, axis=0)
     self.assertAllEqual(sliced_array, rolled_array[:4])
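
For reference, the same wraparound selection can be reproduced without the replay buffer; this sketch only assumes NumPy and the capacity of 10 used above:

import numpy as np

capacity = 10
array = np.arange(10).reshape(10, 1) + np.ones(5)   # rows of 1s, 2s, ..., 10s
indices = [i % capacity for i in range(8, 12)]      # [8, 9, 0, 1]
# Rolling the array down by two and taking the first four rows selects the
# same elements, which is why the test compares against np.roll.
assert np.array_equal(array[indices], np.roll(array, 2, axis=0)[:4])
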
    def testGetStack(self):
        zero_stack = np.zeros(OBSERVATION_SHAPE + (4, ), dtype=OBS_DTYPE)

        memory = circular_replay_buffer.OutOfGraphReplayBuffer(
            observation_shape=OBSERVATION_SHAPE,
            stack_size=STACK_SIZE,
            replay_capacity=50,
            batch_size=BATCH_SIZE)
        for i in range(11):
            memory.add(np.full(OBSERVATION_SHAPE, i, dtype=OBS_DTYPE), 0, 0, 0)

        # ensure that the returned shapes are always correct
        for i in range(3, memory.cursor()):
            self.assertEqual(
                memory.get_observation_stack(i).shape,
                OBSERVATION_SHAPE + (4, ))

        # ensure that there is the necessary 0 padding
        stack = memory.get_observation_stack(3)
        self.assertTrue(np.array_equal(zero_stack, stack))

        # ensure that after the padding the contents are properly stored
        stack = memory.get_observation_stack(6)
        for i in range(4):
            self.assertTrue(
                np.array_equal(np.full(OBSERVATION_SHAPE, i), stack[:, :, i]))
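
The stacking behaviour checked above can be sketched directly with NumPy; the observation shape below is a hypothetical stand-in for OBSERVATION_SHAPE:

import numpy as np

obs_shape = (2, 2)  # stand-in for OBSERVATION_SHAPE
# Three automatic padding frames followed by observations 0..3, as in the test.
store = [np.zeros(obs_shape)] * 3 + [np.full(obs_shape, v) for v in range(4)]
# get_observation_stack(6) returns the observations at indices 3..6 stacked
# along the last axis, so slice i of the stack holds the frame with value i.
stack = np.stack(store[3:7], axis=-1)
assert stack.shape == obs_shape + (4,)
for i in range(4):
    assert np.array_equal(stack[:, :, i], np.full(obs_shape, i))
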
def _create_dummy_memory(**kwargs):
    return circular_replay_buffer.OutOfGraphReplayBuffer(
        observation_shape=(2, ),
        stack_size=1,
        replay_capacity=10,
        batch_size=2,
        **kwargs)
    def testSave(self):
        memory = circular_replay_buffer.OutOfGraphReplayBuffer(
            observation_shape=OBSERVATION_SHAPE,
            stack_size=STACK_SIZE,
            replay_capacity=5,
            batch_size=BATCH_SIZE)
        memory.observation = self._test_observation
        memory.action = self._test_action
        memory.reward = self._test_reward
        memory.terminal = self._test_terminal
        current_iteration = 5
        stale_iteration = (current_iteration -
                           circular_replay_buffer.CHECKPOINT_DURATION)
        memory.save(self._test_subdir, stale_iteration)
        for attr in memory.__dict__:
            if attr.startswith('_'):
                continue
            stale_filename = os.path.join(
                self._test_subdir,
                '{}_ckpt.{}.gz'.format(attr, stale_iteration))
            self.assertTrue(tf.gfile.Exists(stale_filename))

        memory.save(self._test_subdir, current_iteration)
        for attr in memory.__dict__:
            if attr.startswith('_'):
                continue
            filename = os.path.join(
                self._test_subdir,
                '{}_ckpt.{}.gz'.format(attr, current_iteration))
            self.assertTrue(tf.gfile.Exists(filename))
            # The stale version of each file should have been deleted.
            stale_filename = os.path.join(
                self._test_subdir,
                '{}_ckpt.{}.gz'.format(attr, stale_iteration))
            self.assertFalse(tf.gfile.Exists(stale_filename))
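
The assertions above rely on the checkpoint naming and garbage-collection convention sketched below; CHECKPOINT_DURATION is assumed here to be 4, and the directory is hypothetical:

import os

CHECKPOINT_DURATION = 4  # assumed value; the test only uses the difference

def checkpoint_path(checkpoint_dir, attr, iteration):
    # Each public attribute is written to '<attr>_ckpt.<iteration>.gz'.
    return os.path.join(checkpoint_dir, '{}_ckpt.{}.gz'.format(attr, iteration))

current_iteration = 5
stale_iteration = current_iteration - CHECKPOINT_DURATION
# Saving `current_iteration` is expected to delete the files written for
# `stale_iteration`, which is what the assertFalse checks verify.
print(checkpoint_path('/tmp/replay_ckpt', 'observation', stale_iteration))
print(checkpoint_path('/tmp/replay_ckpt', 'observation', current_iteration))
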
    def testSaveNonNDArrayAttributes(self):
        """Tests checkpointing an attribute which is not a numpy array."""
        memory = circular_replay_buffer.OutOfGraphReplayBuffer(
            observation_shape=OBSERVATION_SHAPE,
            stack_size=STACK_SIZE,
            replay_capacity=5,
            batch_size=BATCH_SIZE)

        # Add some non-numpy data: an int, a string, an object.
        memory.dummy_attribute_1 = 4753849
        memory.dummy_attribute_2 = 'String data'
        memory.dummy_attribute_3 = CheckpointableClass()

        current_iteration = 5
        stale_iteration = (current_iteration -
                           circular_replay_buffer.CHECKPOINT_DURATION)
        memory.save(self._test_subdir, stale_iteration)
        for attr in memory.__dict__:
            if attr.startswith('_'):
                continue
            stale_filename = os.path.join(
                self._test_subdir,
                '{}_ckpt.{}.gz'.format(attr, stale_iteration))
            self.assertTrue(tf.gfile.Exists(stale_filename))

        memory.save(self._test_subdir, current_iteration)
        for attr in memory.__dict__:
            if attr.startswith('_'):
                continue
            filename = os.path.join(
                self._test_subdir,
                '{}_ckpt.{}.gz'.format(attr, current_iteration))
            self.assertTrue(tf.gfile.Exists(filename))
            # The stale version of each file should have been deleted.
            stale_filename = os.path.join(
                self._test_subdir,
                '{}_ckpt.{}.gz'.format(attr, stale_iteration))
            self.assertFalse(tf.gfile.Exists(stale_filename))
 def testWithNontupleObservationShape(self):
     with self.assertRaises(AssertionError):
         _ = circular_replay_buffer.OutOfGraphReplayBuffer(
             observation_shape=84,
             stack_size=STACK_SIZE,
             replay_capacity=5,
             batch_size=BATCH_SIZE)
    def testIsTransitionValid(self):
        memory = circular_replay_buffer.OutOfGraphReplayBuffer(
            observation_shape=OBSERVATION_SHAPE,
            stack_size=STACK_SIZE,
            replay_capacity=10,
            batch_size=2)

        memory.add(
            np.full((OBSERVATION_SHAPE, OBSERVATION_SHAPE), 0,
                    dtype=OBS_DTYPE), 0, 0, 0)
        memory.add(
            np.full((OBSERVATION_SHAPE, OBSERVATION_SHAPE), 0,
                    dtype=OBS_DTYPE), 0, 0, 0)
        memory.add(
            np.full((OBSERVATION_SHAPE, OBSERVATION_SHAPE), 0,
                    dtype=OBS_DTYPE), 0, 0, 1)

        # These validity flags account for the automatically applied padding
        # (3 blank frames at the start of each episode).
        correct_valids = [0, 0, 0, 1, 1, 0, 0, 0, 0, 0]
        # The cursor is:                    ^\
        for i in range(10):
            self.assertEqual(
                correct_valids[i], memory.is_valid_transition(i),
                'Index %i should be %s' % (i, bool(correct_valids[i])))
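
A hedged sketch of the validity rule the expected flags imply (not necessarily the library's exact implementation): an index is valid only if it lies past the padding frames and at least update_horizon steps behind the cursor.

stack_size, update_horizon = 4, 1   # assumed test constants (default horizon)
cursor = 6                          # 3 padding frames + 3 real adds

def looks_valid(index):
    return stack_size - 1 <= index < cursor - update_horizon

assert [int(looks_valid(i)) for i in range(10)] == [0, 0, 0, 1, 1, 0, 0, 0, 0, 0]
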
    def testSampleTransitionBatchExtra(self):
        replay_capacity = 10
        memory = circular_replay_buffer.OutOfGraphReplayBuffer(
            observation_shape=OBSERVATION_SHAPE,
            stack_size=1,
            replay_capacity=replay_capacity,
            batch_size=2,
            extra_storage_types=[
                circular_replay_buffer.ReplayElement('extra1', [], np.float32),
                circular_replay_buffer.ReplayElement('extra2', [2], np.int8)
            ])
        num_adds = 50  # The number of transitions to add to the memory.
        for i in range(num_adds):
            memory.add(
                np.full((OBSERVATION_SHAPE, OBSERVATION_SHAPE),
                        i,
                        dtype=OBS_DTYPE), 0, 0, i % 4, 0,
                [0, 0])  # Terminal is i % 4, so only every 4th add is non-terminal.
        # Test sampling with default batch size.
        for i in range(1000):
            batch = memory.sample_transition_batch()
            self.assertEqual(batch[0].shape[0], 2)
        # Test changing batch sizes.
        for i in range(1000):
            batch = memory.sample_transition_batch(BATCH_SIZE)
            self.assertEqual(batch[0].shape[0], BATCH_SIZE)
        # Verify we revert to default batch size.
        for i in range(1000):
            batch = memory.sample_transition_batch()
            self.assertEqual(batch[0].shape[0], 2)

        # Verify we can specify what indices to sample.
        indices = [1, 2, 3, 5, 8]
        expected_states = np.array([
            np.full((OBSERVATION_SHAPE, OBSERVATION_SHAPE, 1),
                    i,
                    dtype=OBS_DTYPE) for i in indices
        ])
        expected_next_states = (expected_states + 1) % replay_capacity
        # Because the replay buffer is circular, we can exactly compute what the
        # states will be at the specified indices by doing a little mod math:
        expected_states += num_adds - replay_capacity
        expected_next_states += num_adds - replay_capacity
        # This replicates the i % 4 formula used above to decide which added
        # transitions are terminal.
        expected_terminal = np.array(
            [min((x + num_adds - replay_capacity) % 4, 1) for x in indices])
        expected_extra2 = np.zeros([len(indices), 2])
        batch = memory.sample_transition_batch(batch_size=len(indices),
                                               indices=indices)
        (states, action, reward, next_states, terminal, indices_batch, extra1,
         extra2) = batch
        self.assertAllEqual(states, expected_states)
        self.assertAllEqual(action, np.zeros(len(indices)))
        self.assertAllEqual(reward, np.zeros(len(indices)))
        self.assertAllEqual(next_states, expected_next_states)
        self.assertAllEqual(terminal, expected_terminal)
        self.assertAllEqual(indices_batch, indices)
        self.assertAllEqual(extra1, np.zeros(len(indices)))
        self.assertAllEqual(extra2, expected_extra2)
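
The index arithmetic behind the expectations above can be checked by hand; with stack_size 1 there is no padding, so after 50 adds slot i of the capacity-10 buffer holds the observation from add number i + 40:

num_adds, replay_capacity = 50, 10
indices = [1, 2, 3, 5, 8]
stored_values = [i + num_adds - replay_capacity for i in indices]
assert stored_values == [41, 42, 43, 45, 48]
# Terminal was stored as (add number % 4), clipped to {0, 1} when sampled.
assert [min(v % 4, 1) for v in stored_values] == [1, 1, 1, 1, 0]
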
    def _load_buffer(self, suffix):
        """Loads a OutOfGraphReplayBuffer replay buffer."""
        try:
            # pytype: disable=attribute-error
            logging.info('Starting to load from ckpt %d from %s', int(suffix),
                         self._data_dir)

            replay_buffer = circular_replay_buffer.OutOfGraphReplayBuffer(
                *self._args, **self._kwargs)
            replay_buffer.load(self._data_dir, suffix)
            # pylint: disable = protected-access
            replay_capacity = replay_buffer._replay_capacity
            logging.info('Capacity: %d', replay_buffer._replay_capacity)
            logging.info('Start index: %d', self._replay_start_index)
            for name, array in replay_buffer._store.items():
                # This frees unused RAM if replay_capacity is smaller than 1M
                end_index = (self._replay_start_index + replay_capacity +
                             replay_buffer._stack_size)
                replay_buffer._store[name] = array[
                    self._replay_start_index:end_index].copy()
                logging.info('%s: %s', name, array.shape)
            logging.info('Loaded replay buffer from ckpt %d from %s',
                         int(suffix), self._data_dir)
            # pylint: enable=protected-access
            # pytype: enable=attribute-error
            return replay_buffer
        except tf.errors.NotFoundError:
            return None
 def testLoad(self):
   memory = circular_replay_buffer.OutOfGraphReplayBuffer(
       observation_shape=OBSERVATION_SHAPE,
       stack_size=STACK_SIZE,
       replay_capacity=5,
       batch_size=BATCH_SIZE)
   self.assertNotEqual(memory._store['observation'], self._test_observation)
   self.assertNotEqual(memory._store['action'], self._test_action)
   self.assertNotEqual(memory._store['reward'], self._test_reward)
   self.assertNotEqual(memory._store['terminal'], self._test_terminal)
   self.assertNotEqual(memory.add_count, self._test_add_count)
   self.assertNotEqual(memory.invalid_range, self._test_invalid_range)
   store_prefix = '$store$_'
   numpy_arrays = {
       store_prefix + 'observation': self._test_observation,
       store_prefix + 'action': self._test_action,
       store_prefix + 'reward': self._test_reward,
       store_prefix + 'terminal': self._test_terminal,
       'add_count': self._test_add_count,
       'invalid_range': self._test_invalid_range
   }
   for attr in numpy_arrays:
     filename = os.path.join(self._test_subdir, '{}_ckpt.3.gz'.format(attr))
     with tf.gfile.Open(filename, 'w') as f:
       with gzip.GzipFile(fileobj=f) as outfile:
         np.save(outfile, numpy_arrays[attr], allow_pickle=False)
   memory.load(self._test_subdir, '3')
   self.assertAllClose(memory._store['observation'], self._test_observation)
   self.assertAllClose(memory._store['action'], self._test_action)
   self.assertAllClose(memory._store['reward'], self._test_reward)
   self.assertAllClose(memory._store['terminal'], self._test_terminal)
   self.assertEqual(memory.add_count, self._test_add_count)
   self.assertAllClose(memory.invalid_range, self._test_invalid_range)
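
For completeness, reading one of the gzipped .npy checkpoints written above back into memory looks like the sketch below; the path is hypothetical, and tf.gfile.Open could be used in place of the plain open call:

import gzip

import numpy as np

filename = '/tmp/replay_ckpt/$store$_observation_ckpt.3.gz'  # hypothetical path
with open(filename, 'rb') as f:
    with gzip.GzipFile(fileobj=f) as infile:
        observation = np.load(infile, allow_pickle=False)
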
 def testPartialLoadFails(self):
   memory = circular_replay_buffer.OutOfGraphReplayBuffer(
       observation_shape=OBSERVATION_SHAPE,
       stack_size=STACK_SIZE,
       replay_capacity=5,
       batch_size=BATCH_SIZE)
   self.assertNotEqual(memory._store['observation'], self._test_observation)
   self.assertNotEqual(memory._store['action'], self._test_action)
   self.assertNotEqual(memory._store['reward'], self._test_reward)
   self.assertNotEqual(memory._store['terminal'], self._test_terminal)
   self.assertNotEqual(memory.add_count, self._test_add_count)
   numpy_arrays = {
       'observation': self._test_observation,
       'action': self._test_action,
       'terminal': self._test_terminal,
       'add_count': self._test_add_count,
       'invalid_range': self._test_invalid_range
   }
   for attr in numpy_arrays:
     filename = os.path.join(self._test_subdir, '{}_ckpt.3.gz'.format(attr))
     with tf.gfile.Open(filename, 'w') as f:
       with gzip.GzipFile(fileobj=f) as outfile:
         np.save(outfile, numpy_arrays[attr], allow_pickle=False)
    # We are missing the reward file, so a NotFoundError will be raised.
   with self.assertRaises(tf.errors.NotFoundError):
     memory.load(self._test_subdir, '3')
   # Since we are missing the reward file, it should not have loaded any of
   # the other files.
   self.assertNotEqual(memory._store['observation'], self._test_observation)
   self.assertNotEqual(memory._store['action'], self._test_action)
   self.assertNotEqual(memory._store['reward'], self._test_reward)
   self.assertNotEqual(memory._store['terminal'], self._test_terminal)
   self.assertNotEqual(memory.add_count, self._test_add_count)
   self.assertNotEqual(memory.invalid_range, self._test_invalid_range)
 def testGetRangeNoWraparound(self):
     # Test the get_range function when the indices do not wrap around the
     # circular buffer. In other words, start_index < end_index.
     memory = circular_replay_buffer.OutOfGraphReplayBuffer(
         observation_shape=OBSERVATION_SHAPE,
         stack_size=STACK_SIZE,
         replay_capacity=10,
         batch_size=BATCH_SIZE,
         update_horizon=5,
         gamma=1.0)
     for _ in range(10):
         memory.add(np.full(OBSERVATION_SHAPE, 0, dtype=OBS_DTYPE), 0, 2.0,
                    0)
     # The constructed `array` will be:
     # array([[ 1.,  1.,  1.,  1.,  1.],
     #        [ 2.,  2.,  2.,  2.,  2.],
     #        [ 3.,  3.,  3.,  3.,  3.],
     #        [ 4.,  4.,  4.,  4.,  4.],
     #        [ 5.,  5.,  5.,  5.,  5.],
     #        [ 6.,  6.,  6.,  6.,  6.],
     #        [ 7.,  7.,  7.,  7.,  7.],
     #        [ 8.,  8.,  8.,  8.,  8.],
     #        [ 9.,  9.,  9.,  9.,  9.],
     #        [10., 10., 10., 10., 10.]])
     array = np.arange(10).reshape(10, 1) + np.ones(5)
     sliced_array = memory.get_range(array, 2, 5)
     self.assertAllEqual(sliced_array, array[2:5])
 def _build_replay_buffer(self):
   """Creates the replay buffer used by the agent."""
   return circular_replay_buffer.OutOfGraphReplayBuffer(
       observation_shape=self.observation_shape,
       stack_size=self.stack_size,
       update_horizon=self.update_horizon,
       gamma=self.gamma,
       observation_dtype=self.observation_dtype)
 def testAdd(self):
     memory = circular_replay_buffer.OutOfGraphReplayBuffer(
         observation_shape=OBSERVATION_SHAPE,
         stack_size=STACK_SIZE,
         replay_capacity=5,
         batch_size=BATCH_SIZE)
     self.assertEqual(memory.cursor(), 0)
     zeros = np.zeros(OBSERVATION_SHAPE)
     memory.add(zeros, 0, 0, 0)
      # The cursor should have advanced by the STACK_SIZE - 1 automatic padding
      # adds plus the single add above.
     self.assertEqual(memory.cursor(), STACK_SIZE)
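
The cursor arithmetic asserted above, spelled out; STACK_SIZE is assumed to be 4, matching the 4-frame stacks used elsewhere in these tests:

STACK_SIZE = 4                       # assumed test constant
padding_adds = STACK_SIZE - 1        # blank frames inserted before the first add
cursor_after_first_add = padding_adds + 1
assert cursor_after_first_add == STACK_SIZE
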
 def testNodeNotAddedToMemory(self):
     memory = circular_replay_buffer.OutOfGraphReplayBuffer(
         observation_shape=OBSERVATION_SHAPE,
         stack_size=1,
         replay_capacity=5,
         batch_size=BATCH_SIZE,
         use_contiguous_trajectories=True)
     self.assertEqual(memory.cursor(), 0)
     zeros = np.zeros(OBSERVATION_SHAPE)
     memory.add(zeros, 0, 0, 0)
     self.assertEqual(memory.cursor(), 0)
     self.assertEqual(memory.add_count, 0)
 def _load_buffer(self, suffix):
     """Loads a OutOfGraphReplayBuffer replay buffer."""
     try:
         # pytype: disable=attribute-error
         replay_buffer = circular_replay_buffer.OutOfGraphReplayBuffer(
             *self._args, **self._kwargs)
         replay_buffer.load(self._data_dir, suffix)
         tf.logging.info('Loaded replay buffer ckpt {} from {}'.format(
             suffix, self._data_dir))
         # pytype: enable=attribute-error
         return replay_buffer
     except tf.errors.NotFoundError:
         return None
 def testAddTerminalNodeToTrajectoryBuffer(self):
     memory = circular_replay_buffer.OutOfGraphReplayBuffer(
         observation_shape=OBSERVATION_SHAPE,
         stack_size=STACK_SIZE,
         replay_capacity=5,
         batch_size=BATCH_SIZE,
         use_contiguous_trajectories=True)
     self.assertEqual(memory.cursor(), 0)
     self.assertEqual(len(memory._trajectory), 0)
     zeros = np.zeros(OBSERVATION_SHAPE)
     memory.add(zeros, 0, 0, 1)
     self.assertEqual(memory.cursor(), STACK_SIZE)
     self.assertEqual(memory.add_count, STACK_SIZE)
     self.assertEqual(len(memory._trajectory), 0)
 def testLoadFromNonexistentDirectory(self):
   memory = circular_replay_buffer.OutOfGraphReplayBuffer(
       observation_shape=OBSERVATION_SHAPE,
       stack_size=STACK_SIZE,
       replay_capacity=5,
       batch_size=BATCH_SIZE)
   # We are trying to load from a non-existent directory, so a NotFoundError
   # will be raised.
   with self.assertRaises(tf.errors.NotFoundError):
     memory.load('/does/not/exist', '3')
   self.assertNotEqual(memory._store['observation'], self._test_observation)
   self.assertNotEqual(memory._store['action'], self._test_action)
   self.assertNotEqual(memory._store['reward'], self._test_reward)
   self.assertNotEqual(memory._store['terminal'], self._test_terminal)
   self.assertNotEqual(memory.add_count, self._test_add_count)
   self.assertNotEqual(memory.invalid_range, self._test_invalid_range)
    def testCheckAddTypes(self):
        memory = circular_replay_buffer.OutOfGraphReplayBuffer(
            observation_shape=OBSERVATION_SHAPE,
            stack_size=STACK_SIZE,
            replay_capacity=5,
            batch_size=BATCH_SIZE,
            extra_storage_types=[
                circular_replay_buffer.ReplayElement('extra1', [], np.float32),
                circular_replay_buffer.ReplayElement('extra2', [2], np.int8)
            ])
        zeros = np.zeros(OBSERVATION_SHAPE)

        memory._check_add_types(zeros, 0, 0, 0, 0, [0, 0])

        with self.assertRaisesRegexp(ValueError, 'Add expects'):
            memory._check_add_types(zeros, 0, 0, 0)
 def testAddMultipleThreadsNodeNotAdded(self):
     memory = circular_replay_buffer.OutOfGraphReplayBuffer(
         observation_shape=OBSERVATION_SHAPE,
         stack_size=1,
         replay_capacity=5,
         batch_size=BATCH_SIZE,
         use_contiguous_trajectories=True)
     self.assertEqual(memory.cursor(), 0)
     self.assertEqual(len(memory._trajectory), 0)
     zeros = np.zeros(OBSERVATION_SHAPE)
     # Add transition in main thread.
     memory.add(zeros, 0, 0, 0)
     # Add a terminal transition in separate thread.
     with test_utils.mock_thread('other-thread'):
         memory.add(zeros, 0, 0, 1)
     # Check that terminal transition is added by itself.
     self.assertEqual(memory.add_count, 1)
  def _build_memory(self, capacity, batch_size):
    """Creates the replay buffer used by the generators.

    Args:
      capacity: int, maximum capacity of the memory unit.
      batch_size: int, size of the batches produced during memory replay.

    Returns:
      An OutOfGraphReplayBuffer object.
    """
    return circular_replay_buffer.OutOfGraphReplayBuffer(
      self.observation_shape,
      self.stack_size,
      capacity,
      batch_size,
      observation_dtype=self.observation_dtype.as_numpy_dtype,
    )
    def testNStepRewardSum(self):
        memory = circular_replay_buffer.OutOfGraphReplayBuffer(
            observation_shape=OBSERVATION_SHAPE,
            stack_size=STACK_SIZE,
            replay_capacity=10,
            batch_size=BATCH_SIZE,
            update_horizon=5,
            gamma=1.0)

        for i in range(50):
            memory.add(np.full(OBSERVATION_SHAPE, i, dtype=OBS_DTYPE), 0, 2.0,
                       0)

        for i in range(100):
            batch = memory.sample_transition_batch()
            # Make sure the total reward is reward per step x update_horizon.
            self.assertEqual(batch[2][0], 10.0)
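
The expected value in the assertion is simple arithmetic: with gamma 1.0, each of the update_horizon steps contributes its full reward of 2.0.

gamma, update_horizon, reward_per_step = 1.0, 5, 2.0
n_step_return = sum(reward_per_step * gamma**k for k in range(update_horizon))
assert n_step_return == 10.0
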
 def _load_buffer(self, suffix):
     """Loads a OutOfGraphReplayBuffer replay buffer."""
     try:
         if use_off_policy_replay_buffer:
             replay_buffer = off_policy_replay_buffer.OutOfGraphOffPolicyReplayBuffer(
                 *self._args, subsample_seed=suffix, **self._kwargs)
         else:
             # pytype: disable=attribute-error
             replay_buffer = circular_replay_buffer.OutOfGraphReplayBuffer(
                 *self._args, **self._kwargs)
         replay_buffer.load(self._data_dir, suffix)
         tf.logging.info('Loaded replay buffer ckpt {} from {}'.format(
             suffix, self._data_dir))
         # pytype: enable=attribute-error
         return replay_buffer
     except tf.errors.NotFoundError:
         # raise
         return None
    def testExtraAdd(self):
        memory = circular_replay_buffer.OutOfGraphReplayBuffer(
            observation_shape=OBSERVATION_SHAPE,
            stack_size=STACK_SIZE,
            replay_capacity=5,
            batch_size=BATCH_SIZE,
            extra_storage_types=[
                circular_replay_buffer.ReplayElement('extra1', [], np.float32),
                circular_replay_buffer.ReplayElement('extra2', [2], np.int8)
            ])
        self.assertEqual(memory.cursor(), 0)
        zeros = np.zeros(OBSERVATION_SHAPE)
        memory.add(zeros, 0, 0, 0, 0, [0, 0])

        with self.assertRaisesRegexp(ValueError, 'Add expects'):
            memory.add(zeros, 0, 0, 0)
        # The cursor should have advanced by the STACK_SIZE - 1 automatic padding
        # adds plus the single add above.
        self.assertEqual(memory.cursor(), STACK_SIZE)
 def testSamplingWithTerminalInTrajectory(self):
     replay_capacity = 10
     update_horizon = 3
     memory = circular_replay_buffer.OutOfGraphReplayBuffer(
         observation_shape=OBSERVATION_SHAPE,
         stack_size=1,
         replay_capacity=replay_capacity,
         batch_size=2,
         update_horizon=update_horizon,
         gamma=1.0)
     for i in range(replay_capacity):
         memory.add(
             np.full((OBSERVATION_SHAPE, OBSERVATION_SHAPE),
                     i,
                     dtype=OBS_DTYPE),
             i * 2,  # action
             i,  # reward
             1 if i == 3 else 0)  # terminal
     indices = [2, 3, 4]
     batch = memory.sample_transition_batch(batch_size=len(indices),
                                            indices=indices)
     states, action, reward, _, terminal, indices_batch = batch
     expected_states = np.array([
         np.full((OBSERVATION_SHAPE, OBSERVATION_SHAPE, 1),
                 i,
                 dtype=OBS_DTYPE) for i in indices
     ])
     # The reward in the replay buffer will be (an asterisk marks the terminal
     # state):
     #   [0 1 2 3* 4 5 6 7 8 9]
      # Since we're setting the update_horizon to 3, the accumulated trajectory
      # reward starting at each of the first eight replay positions will be:
      #   [3 6 5 3 15 18 21 24]
      # Since indices = [2, 3, 4], our expected rewards are [5, 3, 15].
     expected_reward = np.array([5, 3, 15])
     # Because update_horizon = 3, both indices 2 and 3 include terminal.
     expected_terminal = np.array([1, 1, 0])
     self.assertAllEqual(states, expected_states)
     self.assertAllEqual(action, np.array(indices) * 2)
     self.assertAllEqual(reward, expected_reward)
     self.assertAllEqual(terminal, expected_terminal)
     self.assertAllEqual(indices_batch, indices)
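
The expected rewards and terminal flags above can be recomputed independently; this sketch only mirrors the accumulation the comments describe, stopping when a terminal falls inside the horizon:

rewards = list(range(10))
terminals = [1 if i == 3 else 0 for i in range(10)]
update_horizon = 3

def n_step_reward(start):
    total = 0
    for k in range(update_horizon):
        total += rewards[start + k]
        if terminals[start + k]:
            break
    return total

assert [n_step_reward(i) for i in [2, 3, 4]] == [5, 3, 15]
# Indices 2 and 3 hit the terminal within the horizon, index 4 does not.
assert [int(any(terminals[i:i + update_horizon])) for i in [2, 3, 4]] == [1, 1, 0]
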
 def testGetRangeInvalidIndexOrder(self):
     replay_capacity = 10
     memory = circular_replay_buffer.OutOfGraphReplayBuffer(
         observation_shape=OBSERVATION_SHAPE,
         stack_size=STACK_SIZE,
         replay_capacity=replay_capacity,
         batch_size=BATCH_SIZE,
         update_horizon=5,
         gamma=1.0)
     with self.assertRaisesRegexp(
             AssertionError, 'end_index must be larger than start_index'):
         memory.get_range([], 2, 1)
     with self.assertRaises(AssertionError):
         # Negative end_index.
         memory.get_range([], 1, -1)
     with self.assertRaises(AssertionError):
         # Start index beyond replay capacity.
         memory.get_range([], replay_capacity, replay_capacity + 1)
     with self.assertRaisesRegexp(AssertionError,
                                  'Index 1 has not been added.'):
         memory.get_range([], 1, 2)