def testIsTransitionValid(self):
    """Check is_valid_transition() for every slot after one short episode.

    Three zero-valued frames are added, the last of which is terminal, and
    the validity reported for each of the ten slots is compared against a
    hand-written table.
    """
    memory = circular_replay_buffer.ReplayBuffer(
        stack_size=STACK_SIZE, replay_capacity=10, batch_size=2
    )

    # Two ordinary frames followed by a terminal one.
    for terminal_flag in (0, 0, 1):
        memory.add(
            observation=np.full(OBSERVATION_SHAPE, 0, dtype=OBS_DTYPE),
            action=0,
            reward=0,
            terminal=terminal_flag,
        )

    # These valids account for the automatically applied padding (3 blanks
    # each episode).
    #
    # The original Dopamine buffer would expect
    #     [0, 0, 0, 1, 1, 0, 0, 0, 0, 0]
    # because it does not account for terminal frames within the
    # update_horizon frames before the cursor. Here the frame right before
    # the cursor is terminal, so even though it lies within
    # [c - update_horizon, c] it should still be valid for sampling: its
    # next state does not matter.
    correct_valids = [0, 0, 0, 1, 1, 1, 0, 0, 0, 0]

    for index, expected in enumerate(correct_valids):
        self.assertEqual(
            expected,
            memory.is_valid_transition(index),
            "Index %i should be %s" % (index, bool(expected)),
        )
def testSamplingWithterminalInTrajectory(self):
    """Sample fixed indices from a buffer with a terminal at step 3.

    With update_horizon=3 and gamma=1.0, the sampled rewards and terminal
    flags must reflect that the trajectory is cut at the terminal step.
    """
    replay_capacity = 10
    update_horizon = 3
    terminal_step = 3
    memory = circular_replay_buffer.ReplayBuffer(
        stack_size=1,
        replay_capacity=replay_capacity,
        batch_size=2,
        update_horizon=update_horizon,
        gamma=1.0,
    )
    for step in range(replay_capacity):
        memory.add(
            observation=np.full(OBSERVATION_SHAPE, step, dtype=OBS_DTYPE),
            action=step * 2,
            reward=step,
            terminal=int(step == terminal_step),
        )

    indices = [2, 3, 4]
    batch = memory.sample_transition_batch(
        batch_size=len(indices), indices=torch.tensor(indices)
    )

    # In common shape, state is 2-D unless stack_size > 1.
    expected_states = np.array(
        [np.full(OBSERVATION_SHAPE, idx, dtype=OBS_DTYPE) for idx in indices]
    )
    # Actions were stored as twice the step number.
    expected_action = np.expand_dims(np.array(indices) * 2, axis=1)
    # The rewards in the replay buffer are (an asterisk marks the terminal
    # state):
    #   [0 1 2 3* 4 5 6 7 8 9]
    # With update_horizon = 3, the accumulated trajectory reward starting at
    # each replay buffer position is:
    #   [3 6 5 3 15 18 21 24]
    # so for indices = [2, 3, 4] the expected rewards are [5, 3, 15].
    expected_reward = np.array([[5], [3], [15]])
    # Because update_horizon = 3, both indices 2 and 3 include the terminal.
    expected_terminal = np.array([[1], [1], [0]]).astype(bool)
    expected_indices = np.expand_dims(np.array(indices), 1)

    npt.assert_array_equal(batch.state, expected_states)
    npt.assert_array_equal(batch.action, expected_action)
    npt.assert_array_equal(batch.reward, expected_reward)
    npt.assert_array_equal(batch.terminal, expected_terminal)
    npt.assert_array_equal(batch.indices, expected_indices)
def testGetRangeInvalidIndexOrder(self):
    """Check that get_range() asserts on malformed index ranges.

    Nothing is ever added to the buffer; each call below is expected to
    trip an AssertionError.
    """
    replay_capacity = 10
    memory = circular_replay_buffer.ReplayBuffer(
        observation_shape=OBSERVATION_SHAPE,
        stack_size=STACK_SIZE,
        replay_capacity=replay_capacity,
        batch_size=BATCH_SIZE,
        update_horizon=5,
        gamma=1.0,
    )
    empty_array = []

    # end_index smaller than start_index.
    order_message = "end_index must be larger than start_index"
    with self.assertRaisesRegex(AssertionError, order_message):
        memory.get_range(empty_array, 2, 1)

    # Negative end_index.
    with self.assertRaises(AssertionError):
        memory.get_range(empty_array, 1, -1)

    # Start index beyond replay capacity.
    with self.assertRaises(AssertionError):
        memory.get_range(empty_array, replay_capacity, replay_capacity + 1)

    # Well-formed range, but the slot has never been written.
    not_added_message = "Index 1 has not been added."
    with self.assertRaisesRegex(AssertionError, not_added_message):
        memory.get_range(empty_array, 1, 2)
def testPartialLoadFails(self):
    """Check that load() rejects a checkpoint whose reward file is missing.

    Every checkpoint file except the reward one is written to the test
    directory. load() must raise FileNotFoundError and leave the buffer's
    contents untouched.

    The store arrays are written under the "$store$_" prefix, the same
    naming testLoad uses for files that load() reads back successfully.
    Without the prefix none of the store files would exist under the
    expected names, so the test would not isolate the missing reward file.
    """
    memory = circular_replay_buffer.ReplayBuffer(
        observation_shape=OBSERVATION_SHAPE,
        stack_size=STACK_SIZE,
        replay_capacity=5,
        batch_size=BATCH_SIZE,
    )
    # Precondition: the fresh buffer differs from the fixture values.
    self.assertNotEqual(memory._store["observation"], self._test_observation)
    self.assertNotEqual(memory._store["action"], self._test_action)
    self.assertNotEqual(memory._store["reward"], self._test_reward)
    self.assertNotEqual(memory._store["terminal"], self._test_terminal)
    self.assertNotEqual(memory.add_count, self._test_add_count)
    self.assertNotEqual(memory.invalid_range, self._test_invalid_range)

    store_prefix = "$store$_"
    # Deliberately no entry for the reward array.
    numpy_arrays = {
        store_prefix + "observation": self._test_observation,
        store_prefix + "action": self._test_action,
        store_prefix + "terminal": self._test_terminal,
        "add_count": self._test_add_count,
        "invalid_range": self._test_invalid_range,
    }
    for attr, array in numpy_arrays.items():
        filename = os.path.join(self._test_subdir, "{}_ckpt.3.gz".format(attr))
        with open(filename, "wb") as f, gzip.GzipFile(fileobj=f) as outfile:
            np.save(outfile, array, allow_pickle=False)

    # Guard against a stale reward checkpoint left behind by another test,
    # in case the test directory is shared between tests.
    reward_filename = os.path.join(
        self._test_subdir, "{}_ckpt.3.gz".format(store_prefix + "reward")
    )
    if os.path.exists(reward_filename):
        os.remove(reward_filename)

    # We are missing the reward file, so a FileNotFoundError will be raised.
    with self.assertRaises(FileNotFoundError):
        memory.load(self._test_subdir, "3")

    # Since we are missing the reward file, it should not have loaded any of
    # the other files.
    self.assertNotEqual(memory._store["observation"], self._test_observation)
    self.assertNotEqual(memory._store["action"], self._test_action)
    self.assertNotEqual(memory._store["reward"], self._test_reward)
    self.assertNotEqual(memory._store["terminal"], self._test_terminal)
    self.assertNotEqual(memory.add_count, self._test_add_count)
    self.assertNotEqual(memory.invalid_range, self._test_invalid_range)
def testLoad(self):
    """Write a complete checkpoint to disk and check load() restores it.

    The four store arrays are saved under the "$store$_" prefix, alongside
    add_count and invalid_range, as gzip-compressed .npy files.
    """
    memory = circular_replay_buffer.ReplayBuffer(
        observation_shape=OBSERVATION_SHAPE,
        stack_size=STACK_SIZE,
        replay_capacity=5,
        batch_size=BATCH_SIZE,
    )
    # Precondition: the fresh buffer differs from the fixture values.
    self.assertNotEqual(memory._store["observation"], self._test_observation)
    self.assertNotEqual(memory._store["action"], self._test_action)
    self.assertNotEqual(memory._store["reward"], self._test_reward)
    self.assertNotEqual(memory._store["terminal"], self._test_terminal)
    self.assertNotEqual(memory.add_count, self._test_add_count)
    self.assertNotEqual(memory.invalid_range, self._test_invalid_range)

    store_prefix = "$store$_"
    numpy_arrays = {
        store_prefix + "observation": self._test_observation,
        store_prefix + "action": self._test_action,
        store_prefix + "reward": self._test_reward,
        store_prefix + "terminal": self._test_terminal,
        "add_count": self._test_add_count,
        "invalid_range": self._test_invalid_range,
    }
    for element_name, payload in numpy_arrays.items():
        checkpoint_path = os.path.join(
            self._test_subdir, "{}_ckpt.3.gz".format(element_name)
        )
        with open(checkpoint_path, "wb") as raw, gzip.GzipFile(
            fileobj=raw
        ) as compressed:
            np.save(compressed, payload, allow_pickle=False)

    memory.load(self._test_subdir, "3")

    # Every element must now match what was written to disk.
    npt.assert_allclose(memory._store["observation"], self._test_observation)
    npt.assert_allclose(memory._store["action"], self._test_action)
    npt.assert_allclose(memory._store["reward"], self._test_reward)
    npt.assert_allclose(memory._store["terminal"], self._test_terminal)
    self.assertEqual(memory.add_count, self._test_add_count)
    npt.assert_allclose(memory.invalid_range, self._test_invalid_range)
def testSampleTransitionBatchExtra(self):
    """Sample batches from a buffer that stores two extra elements.

    Covers the default batch size, an overridden batch size, the return to
    the default, and sampling at explicitly given indices where every
    returned field is compared with a value computed by hand.
    """
    replay_capacity = 10
    memory = circular_replay_buffer.ReplayBuffer(
        observation_shape=OBSERVATION_SHAPE,
        stack_size=1,
        replay_capacity=replay_capacity,
        batch_size=2,
        extra_storage_types=[
            circular_replay_buffer.ReplayElement("extra1", [], np.float32),
            circular_replay_buffer.ReplayElement("extra2", [2], np.int8),
        ],
    )

    num_adds = 50  # The number of transitions to add to the memory.
    for step in range(num_adds):
        parity = step % 2
        # The terminal argument is step % 4, i.e. zero only on every fourth
        # transition.
        memory.add(
            np.full(OBSERVATION_SHAPE, step, dtype=OBS_DTYPE),
            0,
            0,
            step % 4,
            parity,
            [parity, 0],
        )

    # Test sampling with default batch size.
    for _ in range(1000):
        batch = memory.sample_transition_batch()
        self.assertEqual(batch[0].shape[0], 2)

    # Test changing batch sizes.
    for _ in range(1000):
        batch = memory.sample_transition_batch(BATCH_SIZE)
        self.assertEqual(batch[0].shape[0], BATCH_SIZE)

    # Verify we revert to default batch size.
    for _ in range(1000):
        batch = memory.sample_transition_batch()
        self.assertEqual(batch[0].shape[0], 2)

    # Verify we can specify what indices to sample.
    indices = [1, 2, 3, 5, 8]
    num_indices = len(indices)
    # Because the replay buffer is circular, the step stored at each slot can
    # be computed exactly with a little mod math.
    offset = num_adds - replay_capacity

    expected_states = np.array(
        [
            np.full(OBSERVATION_SHAPE + (1,), idx, dtype=OBS_DTYPE)
            for idx in indices
        ]
    )
    # The successor slot wraps around before the offset is applied.
    expected_next_states = (expected_states + 1) % replay_capacity
    expected_states += offset
    expected_next_states += offset

    # Step number held at each sampled slot.
    stored_steps = [idx + offset for idx in indices]

    # This replicates the formula used above to decide which transitions are
    # terminal when adding observations (step % 4), clipped to 0/1.
    expected_terminal = np.array([min(step % 4, 1) for step in stored_steps])
    expected_extra1 = np.array([step % 2 for step in stored_steps])
    expected_next_extra1 = np.array([(step + 1) % 2 for step in stored_steps])
    expected_extra2 = np.stack(
        [
            [step % 2 for step in stored_steps],
            np.zeros((num_indices,)),
        ],
        axis=1,
    )
    expected_next_extra2 = np.stack(
        [
            [(step + 1) % 2 for step in stored_steps],
            np.zeros((num_indices,)),
        ],
        axis=1,
    )

    batch = memory.sample_transition_batch(
        batch_size=num_indices, indices=np.array(indices)
    )
    (
        states,
        action,
        reward,
        next_states,
        next_action,
        next_reward,
        terminal,
        indices_batch,
        extra1,
        next_extra1,
        extra2,
        next_extra2,
    ) = batch

    npt.assert_array_equal(states, expected_states)
    npt.assert_array_equal(action, np.zeros(num_indices))
    npt.assert_array_equal(reward, np.zeros(num_indices))
    npt.assert_array_equal(next_action, np.zeros(num_indices))
    npt.assert_array_equal(next_reward, np.zeros(num_indices))
    npt.assert_array_equal(next_states, expected_next_states)
    npt.assert_array_equal(terminal, expected_terminal)
    npt.assert_array_equal(indices_batch, indices)
    npt.assert_array_equal(extra1, expected_extra1)
    npt.assert_array_equal(next_extra1, expected_next_extra1)
    npt.assert_array_equal(extra2, expected_extra2)
    npt.assert_array_equal(next_extra2, expected_next_extra2)
def testConstructor(self):
    """Check that a newly constructed buffer reports an add_count of zero."""
    fresh_buffer = circular_replay_buffer.ReplayBuffer(
        stack_size=STACK_SIZE,
        replay_capacity=5,
        batch_size=BATCH_SIZE,
    )
    self.assertEqual(fresh_buffer.add_count, 0)