예제 #1
0
    def testIsTransitionValid(self):
        """Checks is_valid_transition for every slot after a 3-step episode."""
        memory = circular_replay_buffer.ReplayBuffer(
            stack_size=STACK_SIZE, replay_capacity=10, batch_size=2
        )

        # Add three all-zero transitions; only the third one ends the episode.
        for terminal_flag in (0, 0, 1):
            memory.add(
                observation=np.full(OBSERVATION_SHAPE, 0, dtype=OBS_DTYPE),
                action=0,
                reward=0,
                terminal=terminal_flag,
            )

        # The expected pattern accounts for the blank frames the buffer pads in
        # front of each episode (the leading zeros below).
        # The original Dopamine buffer would expect
        #     [0, 0, 0, 1, 1, 0, 0, 0, 0, 0]
        # because it rejects every index in [cursor - update_horizon, cursor]
        # without looking for terminal frames. Here the frame right before the
        # cursor is terminal. Its next state is never used, so it is still a
        # valid transition to sample.
        correct_valids = [0, 0, 0, 1, 1, 1, 0, 0, 0, 0]
        # The cursor is:                    ^
        for index, expected in enumerate(correct_valids):
            self.assertEqual(
                expected,
                memory.is_valid_transition(index),
                "Index %i should be %s" % (index, bool(expected)),
            )
 def testSamplingWithterminalInTrajectory(self):
     """n-step sampling truncates reward and sets terminal at an episode end."""
     replay_capacity = 10
     update_horizon = 3
     memory = circular_replay_buffer.ReplayBuffer(
         stack_size=1,
         replay_capacity=replay_capacity,
         batch_size=2,
         update_horizon=update_horizon,
         gamma=1.0,
     )
     # Slot `step` gets observation `step`, action 2 * step and reward `step`;
     # only step 3 is terminal.
     for step in range(replay_capacity):
         memory.add(
             observation=np.full(OBSERVATION_SHAPE, step, dtype=OBS_DTYPE),
             action=step * 2,
             reward=step,
             terminal=1 if step == 3 else 0,
         )
     indices = [2, 3, 4]
     index_array = np.array(indices)
     batch = memory.sample_transition_batch(
         batch_size=len(indices), indices=torch.tensor(indices)
     )
     # In the common shape the state is 2-D unless stack_size > 1, so no extra
     # stack axis is expected here.
     expected_states = np.array(
         [np.full(OBSERVATION_SHAPE, idx, dtype=OBS_DTYPE) for idx in indices]
     )
     # Rewards stored in the buffer (an asterisk marks the terminal state):
     #   [0 1 2 3* 4 5 6 7 8 9]
     # With update_horizon = 3 and gamma = 1.0, the accumulated reward of the
     # trajectory starting at each position (cut off at the terminal) is:
     #   [3 6 5 3 15 18 21 24]
     # so for indices [2, 3, 4] the expected rewards are [5, 3, 15].
     expected_reward = np.array([[5], [3], [15]])
     # With update_horizon = 3, the trajectories from indices 2 and 3 both
     # reach the terminal state; the one from index 4 does not.
     expected_terminal = np.array([[1], [1], [0]]).astype(bool)
     expected_action = np.expand_dims(index_array * 2, axis=1)
     expected_indices = np.expand_dims(index_array, 1)
     npt.assert_array_equal(batch.state, expected_states)
     npt.assert_array_equal(batch.action, expected_action)
     npt.assert_array_equal(batch.reward, expected_reward)
     npt.assert_array_equal(batch.terminal, expected_terminal)
     npt.assert_array_equal(batch.indices, expected_indices)
 def testGetRangeInvalidIndexOrder(self):
     """get_range rejects invalid index ranges with an AssertionError.

     Each case is chosen so that only the condition named in its comment is
     violated, so each check is exercised on its own.
     """
     replay_capacity = 10
     memory = circular_replay_buffer.ReplayBuffer(
         observation_shape=OBSERVATION_SHAPE,
         stack_size=STACK_SIZE,
         replay_capacity=replay_capacity,
         batch_size=BATCH_SIZE,
         update_horizon=5,
         gamma=1.0,
     )
     with self.assertRaisesRegex(
             AssertionError, "end_index must be larger than start_index"):
         # Reversed range.
         memory.get_range([], 2, 1)
     with self.assertRaises(AssertionError):
         # Negative end_index. start < end here, so the ordering check above
         # cannot be the one that fires (it would for e.g. (1, -1)).
         memory.get_range([], -2, -1)
     with self.assertRaises(AssertionError):
         # Start index beyond replay capacity (the range is ordered and
         # non-negative).
         memory.get_range([], replay_capacity, replay_capacity + 1)
     with self.assertRaisesRegex(AssertionError,
                                 "Index 1 has not been added."):
         # Nothing has been added to this buffer yet.
         memory.get_range([], 1, 2)
 def testPartialLoadFails(self):
     """Loading with one checkpoint file missing raises and loads nothing.

     Every checkpoint file except the reward store is written. ``load`` must
     raise ``FileNotFoundError`` and must not load any of the files that are
     present.
     """
     memory = circular_replay_buffer.ReplayBuffer(
         observation_shape=OBSERVATION_SHAPE,
         stack_size=STACK_SIZE,
         replay_capacity=5,
         batch_size=BATCH_SIZE,
     )

     def assert_arrays_differ(actual, unexpected):
         # assertNotEqual evaluates `actual != unexpected`, which is
         # element-wise for arrays. It fails on the ambiguous truth value of a
         # multi-element result, and recent NumPy raises on unbroadcastable
         # shapes. np.array_equal yields one bool (False on shape mismatch).
         self.assertFalse(np.array_equal(actual, unexpected))

     def assert_not_loaded():
         # None of the test values may be present in the buffer.
         assert_arrays_differ(
             memory._store["observation"], self._test_observation)
         assert_arrays_differ(memory._store["action"], self._test_action)
         assert_arrays_differ(memory._store["reward"], self._test_reward)
         assert_arrays_differ(memory._store["terminal"], self._test_terminal)
         self.assertNotEqual(memory.add_count, self._test_add_count)
         assert_arrays_differ(memory.invalid_range, self._test_invalid_range)

     assert_not_loaded()
     # Store-backed arrays are checkpointed under a "$store$_" prefix. Without
     # it, load would already fail on the first store file and the "only the
     # reward file is missing" scenario would never be exercised.
     store_prefix = "$store$_"
     numpy_arrays = {
         store_prefix + "observation": self._test_observation,
         store_prefix + "action": self._test_action,
         store_prefix + "terminal": self._test_terminal,
         "add_count": self._test_add_count,
         "invalid_range": self._test_invalid_range,
     }
     for attr, array in numpy_arrays.items():
         filename = os.path.join(self._test_subdir,
                                 "{}_ckpt.3.gz".format(attr))
         with open(filename, "wb") as f:
             with gzip.GzipFile(fileobj=f) as outfile:
                 np.save(outfile, array, allow_pickle=False)
     # The reward file is missing, so a FileNotFoundError will be raised.
     with self.assertRaises(FileNotFoundError):
         memory.load(self._test_subdir, "3")
     # Since the reward file is missing, none of the other files should have
     # been loaded either.
     assert_not_loaded()
 def testLoad(self):
     """Checkpoint files written to disk are restored into the buffer by load."""
     memory = circular_replay_buffer.ReplayBuffer(
         observation_shape=OBSERVATION_SHAPE,
         stack_size=STACK_SIZE,
         replay_capacity=5,
         batch_size=BATCH_SIZE,
     )

     def assert_arrays_differ(actual, unexpected):
         # assertNotEqual evaluates `actual != unexpected`, which is
         # element-wise for arrays. It fails on the ambiguous truth value of a
         # multi-element result, and recent NumPy raises on unbroadcastable
         # shapes. np.array_equal yields one bool (False on shape mismatch).
         self.assertFalse(np.array_equal(actual, unexpected))

     # A fresh buffer must not already hold the test values; otherwise the
     # post-load checks below would pass trivially.
     assert_arrays_differ(memory._store["observation"], self._test_observation)
     assert_arrays_differ(memory._store["action"], self._test_action)
     assert_arrays_differ(memory._store["reward"], self._test_reward)
     assert_arrays_differ(memory._store["terminal"], self._test_terminal)
     self.assertNotEqual(memory.add_count, self._test_add_count)
     assert_arrays_differ(memory.invalid_range, self._test_invalid_range)
     # Store-backed arrays use a "$store$_" file-name prefix; the remaining
     # attributes are saved under their plain names.
     store_prefix = "$store$_"
     numpy_arrays = {
         store_prefix + "observation": self._test_observation,
         store_prefix + "action": self._test_action,
         store_prefix + "reward": self._test_reward,
         store_prefix + "terminal": self._test_terminal,
         "add_count": self._test_add_count,
         "invalid_range": self._test_invalid_range,
     }
     for attr, array in numpy_arrays.items():
         filename = os.path.join(self._test_subdir,
                                 "{}_ckpt.3.gz".format(attr))
         with open(filename, "wb") as f:
             with gzip.GzipFile(fileobj=f) as outfile:
                 np.save(outfile, array, allow_pickle=False)
     memory.load(self._test_subdir, "3")
     npt.assert_allclose(memory._store["observation"],
                         self._test_observation)
     npt.assert_allclose(memory._store["action"], self._test_action)
     npt.assert_allclose(memory._store["reward"], self._test_reward)
     npt.assert_allclose(memory._store["terminal"], self._test_terminal)
     self.assertEqual(memory.add_count, self._test_add_count)
     npt.assert_allclose(memory.invalid_range, self._test_invalid_range)
    def testSampleTransitionBatchExtra(self):
        """Sampling returns extra storage elements and their next_ values."""
        replay_capacity = 10
        memory = circular_replay_buffer.ReplayBuffer(
            observation_shape=OBSERVATION_SHAPE,
            stack_size=1,
            replay_capacity=replay_capacity,
            batch_size=2,
            extra_storage_types=[
                circular_replay_buffer.ReplayElement("extra1", [], np.float32),
                circular_replay_buffer.ReplayElement("extra2", [2], np.int8),
            ],
        )
        num_adds = 50  # The number of transitions to add to the memory.
        for i in range(num_adds):
            # Arguments: observation, action, reward, terminal, extra1, extra2.
            # The terminal flag is i % 4. It is 0 (non-terminal) only when i is
            # a multiple of 4, so three of every four transitions are terminal.
            memory.add(
                np.full(OBSERVATION_SHAPE, i, dtype=OBS_DTYPE),
                0,
                0,
                i % 4,
                i % 2,
                [i % 2, 0],
            )
        # Test sampling with default batch size.
        for _i in range(1000):
            batch = memory.sample_transition_batch()
            self.assertEqual(batch[0].shape[0], 2)
        # Test changing batch sizes.
        for _i in range(1000):
            batch = memory.sample_transition_batch(BATCH_SIZE)
            self.assertEqual(batch[0].shape[0], BATCH_SIZE)
        # Verify we revert to default batch size.
        for _i in range(1000):
            batch = memory.sample_transition_batch()
            self.assertEqual(batch[0].shape[0], 2)

        # Verify we can specify what indices to sample.
        indices = [1, 2, 3, 5, 8]
        # The buffer is circular and num_adds is a multiple of replay_capacity,
        # so slot x holds the transition that was added at step x + offset.
        offset = num_adds - replay_capacity
        expected_states = np.array([
            np.full(OBSERVATION_SHAPE + (1, ), i, dtype=OBS_DTYPE)
            for i in indices
        ])
        expected_next_states = (expected_states + 1) % replay_capacity
        expected_states += offset
        expected_next_states += offset
        # Replicates the terminal flag used when adding (i % 4), clamped to 1.
        expected_terminal = np.array(
            [min((x + offset) % 4, 1) for x in indices])
        expected_extra1 = np.array([(x + offset) % 2 for x in indices])
        expected_next_extra1 = np.array([(x + 1 + offset) % 2 for x in indices])
        expected_extra2 = np.stack(
            [
                [(x + offset) % 2 for x in indices],
                np.zeros((len(indices), )),
            ],
            axis=1,
        )
        expected_next_extra2 = np.stack(
            [
                [(x + 1 + offset) % 2 for x in indices],
                np.zeros((len(indices), )),
            ],
            axis=1,
        )
        batch = memory.sample_transition_batch(batch_size=len(indices),
                                               indices=np.array(indices))
        (
            states,
            action,
            reward,
            next_states,
            next_action,
            next_reward,
            terminal,
            indices_batch,
            extra1,
            next_extra1,
            extra2,
            next_extra2,
        ) = batch
        npt.assert_array_equal(states, expected_states)
        npt.assert_array_equal(action, np.zeros(len(indices)))
        npt.assert_array_equal(reward, np.zeros(len(indices)))
        npt.assert_array_equal(next_action, np.zeros(len(indices)))
        npt.assert_array_equal(next_reward, np.zeros(len(indices)))
        npt.assert_array_equal(next_states, expected_next_states)
        npt.assert_array_equal(terminal, expected_terminal)
        npt.assert_array_equal(indices_batch, indices)
        npt.assert_array_equal(extra1, expected_extra1)
        npt.assert_array_equal(next_extra1, expected_next_extra1)
        npt.assert_array_equal(extra2, expected_extra2)
        npt.assert_array_equal(next_extra2, expected_next_extra2)
예제 #7
0
 def testConstructor(self):
     """Checks the initial state of a freshly constructed ReplayBuffer.

     Among other things, no transitions have been added yet, so add_count is 0.
     """
     memory = circular_replay_buffer.ReplayBuffer(
         stack_size=STACK_SIZE, replay_capacity=5, batch_size=BATCH_SIZE
     )
     self.assertEqual(memory.add_count, 0)