def testLowCapacity(self):
    """Check the capacity validation done by the ReplayBuffer constructor.

    With replay_capacity=10 the test expects a ValueError for
    (stack_size=10, update_horizon=1) and for (stack_size=5,
    update_horizon=10), and expects (stack_size=5, update_horizon=5) to be
    accepted.
    """
    # assertRaisesRegex is used instead of the deprecated assertRaisesRegexp
    # alias, which no longer exists in unittest as of Python 3.12.
    with self.assertRaisesRegex(ValueError, "There is not enough capacity"):
        circular_replay_buffer.ReplayBuffer(
            observation_shape=OBSERVATION_SHAPE,
            stack_size=10,
            replay_capacity=10,
            batch_size=BATCH_SIZE,
            update_horizon=1,
            gamma=1.0,
        )

    with self.assertRaisesRegex(ValueError, "There is not enough capacity"):
        circular_replay_buffer.ReplayBuffer(
            observation_shape=OBSERVATION_SHAPE,
            stack_size=5,
            replay_capacity=10,
            batch_size=BATCH_SIZE,
            update_horizon=10,
            gamma=1.0,
        )

    # We should be able to create a buffer that contains just enough for a
    # transition; this call must not raise.
    circular_replay_buffer.ReplayBuffer(
        observation_shape=OBSERVATION_SHAPE,
        stack_size=5,
        replay_capacity=10,
        batch_size=BATCH_SIZE,
        update_horizon=5,
        gamma=1.0,
    )
def testConstructor(self):
    """Check that constructor arguments are recorded on the buffer.

    Covers the default observation shape, a non-square observation shape
    (together with the initial add_count), and a custom terminal dtype.
    """
    # Keyword arguments shared by every buffer built in this test.
    shared_kwargs = dict(
        stack_size=STACK_SIZE,
        replay_capacity=5,
        batch_size=BATCH_SIZE,
    )

    square_buffer = circular_replay_buffer.ReplayBuffer(
        observation_shape=OBSERVATION_SHAPE, **shared_kwargs
    )
    self.assertEqual(square_buffer._observation_shape, OBSERVATION_SHAPE)

    # Non-square observation shape.
    rect_shape = (4, 20)
    rect_buffer = circular_replay_buffer.ReplayBuffer(
        observation_shape=rect_shape, **shared_kwargs
    )
    self.assertEqual(rect_buffer._observation_shape, (4, 20))
    self.assertEqual(rect_buffer.add_count, 0)

    # Terminal datatype of np.int32.
    int_terminal_buffer = circular_replay_buffer.ReplayBuffer(
        observation_shape=OBSERVATION_SHAPE,
        terminal_dtype=np.int32,
        **shared_kwargs
    )
    self.assertEqual(int_terminal_buffer._terminal_dtype, np.int32)
def testIsTransitionValid(self):
    """Check is_valid_transition() for every index after a short episode.

    Three frames are added, the last one flagged terminal, and each of the
    ten buffer indices is compared against a hand-written validity table.

    NOTE(review): another method named testIsTransitionValid appears later
    in this file with a different table; if both live in the same class the
    later definition shadows this one -- confirm which is intended.
    """
    memory = circular_replay_buffer.ReplayBuffer(
        observation_shape=OBSERVATION_SHAPE,
        stack_size=STACK_SIZE,
        replay_capacity=10,
        batch_size=2,
    )
    # Two non-terminal frames followed by one terminal frame.
    for terminal_flag in (0, 0, 1):
        memory.add(
            np.full(OBSERVATION_SHAPE, 0, dtype=OBS_DTYPE), 0, 0, terminal_flag
        )

    # The table accounts for the automatically applied padding (3 blanks at
    # the start of each episode).
    # The original Dopamine buffer would expect
    #   [0, 0, 0, 1, 1, 0, 0, 0, 0, 0]
    # because it does not account for terminal frames within the
    # update_horizon frames before the cursor. Here the frame right before
    # the cursor is terminal, so even though it lies within
    # [c - update_horizon, c] it should still be valid for sampling, as the
    # next state does not matter.
    correct_valids = [0, 0, 0, 1, 1, 1, 0, 0, 0, 0]
    # The cursor is:                   ^
    for index, expected in enumerate(correct_valids):
        self.assertEqual(
            expected,
            memory.is_valid_transition(index),
            "Index %i should be %s" % (index, bool(expected)),
        )
def testGetStack(self):
    """Check the shape, zero padding and contents of observation stacks.

    Eleven frames are added, frame i being filled with the value i. The
    test then checks that every stack from index 3 up to the cursor has the
    expected shape, that the stack at index 3 is all zeros, and that the
    stack at index 6 holds the values 0..3 along its last axis.
    """
    expected_shape = OBSERVATION_SHAPE + (4,)
    zero_stack = np.zeros(expected_shape, dtype=OBS_DTYPE)
    memory = circular_replay_buffer.ReplayBuffer(
        observation_shape=OBSERVATION_SHAPE,
        stack_size=STACK_SIZE,
        replay_capacity=50,
        batch_size=BATCH_SIZE,
    )
    for i in range(11):
        memory.add(np.full(OBSERVATION_SHAPE, i, dtype=OBS_DTYPE), 0, 0, 0)

    # Ensure that the returned shapes are always correct. This must be
    # assertEqual: assertTrue(shape, expected) would treat the expected shape
    # as the failure message and pass for any non-empty shape.
    for i in range(3, memory.cursor()):
        self.assertEqual(memory.get_observation_stack(i).shape, expected_shape)

    # Ensure that there is the necessary 0 padding.
    stack = memory.get_observation_stack(3)
    self.assertTrue(np.array_equal(zero_stack, stack))

    # Ensure that after the padding the contents are properly stored.
    stack = memory.get_observation_stack(6)
    for i in range(4):
        self.assertTrue(
            np.array_equal(np.full(OBSERVATION_SHAPE, i), stack[:, :, i])
        )
def testSaveNonNDArrayAttributes(self):
    """Tests checkpointing an attribute which is not a numpy array.

    Saves the buffer at a stale iteration and at the current iteration
    (CHECKPOINT_DURATION apart) and checks, for every public attribute in
    memory.__dict__, that the expected checkpoint files exist and that the
    stale ones are gone after the second save.
    """

    def checkpoint_path(attr, iteration):
        # File name the test expects save() to produce for one attribute.
        return os.path.join(
            self._test_subdir, "{}_ckpt.{}.gz".format(attr, iteration)
        )

    memory = circular_replay_buffer.ReplayBuffer(
        observation_shape=OBSERVATION_SHAPE,
        stack_size=STACK_SIZE,
        replay_capacity=5,
        batch_size=BATCH_SIZE,
    )

    # Add some non-numpy data: an int, a string, an object.
    memory.dummy_attribute_1 = 4753849
    memory.dummy_attribute_2 = "String data"
    memory.dummy_attribute_3 = CheckpointableClass()

    current_iteration = 5
    stale_iteration = current_iteration - circular_replay_buffer.CHECKPOINT_DURATION

    memory.save(self._test_subdir, stale_iteration)
    for attr in memory.__dict__:
        if attr.startswith("_"):
            continue
        self.assertTrue(os.path.exists(checkpoint_path(attr, stale_iteration)))

    memory.save(self._test_subdir, current_iteration)
    for attr in memory.__dict__:
        if attr.startswith("_"):
            continue
        self.assertTrue(os.path.exists(checkpoint_path(attr, current_iteration)))
        # The stale version file should have been deleted. The path is
        # rebuilt for each attribute so that every stale file is checked,
        # not only the one left over from the last iteration of the first
        # loop.
        self.assertFalse(os.path.exists(checkpoint_path(attr, stale_iteration)))
def testGetRangeWithWraparound(self):
    """Check get_range() when the requested range wraps around the buffer.

    The range [8, 12) on a buffer of capacity 10 is expected to select rows
    8, 9, 0 and 1 of the array passed in.
    """
    memory = circular_replay_buffer.ReplayBuffer(
        observation_shape=OBSERVATION_SHAPE,
        stack_size=STACK_SIZE,
        replay_capacity=10,
        batch_size=BATCH_SIZE,
        update_horizon=5,
        gamma=1.0,
    )
    for _ in range(10):
        memory.add(np.full(OBSERVATION_SHAPE, 0, dtype=OBS_DTYPE), 0, 2.0, 0)

    # `array` has ten rows of five columns; row k is filled with k + 1:
    #   [[ 1.,  1.,  1.,  1.,  1.],
    #    [ 2.,  2.,  2.,  2.,  2.],
    #    ...
    #    [10., 10., 10., 10., 10.]]
    array = np.arange(10).reshape(10, 1) + np.ones(5)
    start_index, end_index = 8, 12
    sliced_array = memory.get_range(array, start_index, end_index)

    # Rolling by two along axis 0 puts rows 8 and 9 in front of rows 0 and 1,
    # so the first four rolled rows are the rows at indices [8, 9, 0, 1].
    expected_rows = np.roll(array, 2, axis=0)[:4]
    npt.assert_array_equal(sliced_array, expected_rows)
def testGetRangeNoWraparound(self):
    """Check get_range() when the requested range does not wrap around.

    With start_index < end_index inside the buffer, the result is expected
    to equal the plain slice array[start_index:end_index].
    """
    memory = circular_replay_buffer.ReplayBuffer(
        observation_shape=OBSERVATION_SHAPE,
        stack_size=STACK_SIZE,
        replay_capacity=10,
        batch_size=BATCH_SIZE,
        update_horizon=5,
        gamma=1.0,
    )
    for _ in range(10):
        memory.add(np.full(OBSERVATION_SHAPE, 0, dtype=OBS_DTYPE), 0, 2.0, 0)

    # `array` has ten rows of five columns; row k is filled with k + 1:
    #   [[ 1.,  1.,  1.,  1.,  1.],
    #    [ 2.,  2.,  2.,  2.,  2.],
    #    ...
    #    [10., 10., 10., 10., 10.]]
    array = np.arange(10).reshape(10, 1) + np.ones(5)
    expected_rows = array[2:5]
    npt.assert_array_equal(memory.get_range(array, 2, 5), expected_rows)
def testPartialLoadFails(self):
    """Check that load() rejects a checkpoint that is missing one file.

    Every checkpoint file except the reward one is written to the test
    directory; load() is then expected to raise FileNotFoundError and to
    leave the buffer contents untouched.

    NOTE(review): the store files are written with the same "$store$_"
    prefix that testLoad uses for a loadable checkpoint, so that the reward
    file is the only one missing. Without the prefix every store file would
    be missing and the test could not exercise the partial case. Confirm the
    naming against ReplayBuffer.load.
    """
    memory = circular_replay_buffer.ReplayBuffer(
        observation_shape=OBSERVATION_SHAPE,
        stack_size=STACK_SIZE,
        replay_capacity=5,
        batch_size=BATCH_SIZE,
    )
    self.assertNotEqual(memory._store["observation"], self._test_observation)
    self.assertNotEqual(memory._store["action"], self._test_action)
    self.assertNotEqual(memory._store["reward"], self._test_reward)
    self.assertNotEqual(memory._store["terminal"], self._test_terminal)
    self.assertNotEqual(memory.add_count, self._test_add_count)
    self.assertNotEqual(memory.invalid_range, self._test_invalid_range)

    store_prefix = "$store$_"
    # Deliberately omits the reward entry.
    numpy_arrays = {
        store_prefix + "observation": self._test_observation,
        store_prefix + "action": self._test_action,
        store_prefix + "terminal": self._test_terminal,
        "add_count": self._test_add_count,
        "invalid_range": self._test_invalid_range,
    }
    for attr in numpy_arrays:
        filename = os.path.join(self._test_subdir, "{}_ckpt.3.gz".format(attr))
        with open(filename, "wb") as f:
            with gzip.GzipFile(fileobj=f) as outfile:
                np.save(outfile, numpy_arrays[attr], allow_pickle=False)

    # We are missing the reward file, so a FileNotFoundError will be raised.
    with self.assertRaises(FileNotFoundError):
        memory.load(self._test_subdir, "3")

    # Since we are missing the reward file, it should not have loaded any of
    # the other files.
    self.assertNotEqual(memory._store["observation"], self._test_observation)
    self.assertNotEqual(memory._store["action"], self._test_action)
    self.assertNotEqual(memory._store["reward"], self._test_reward)
    self.assertNotEqual(memory._store["terminal"], self._test_terminal)
    self.assertNotEqual(memory.add_count, self._test_add_count)
    self.assertNotEqual(memory.invalid_range, self._test_invalid_range)
def testLoad(self):
    """Check that load() restores every element from checkpoint files.

    The test writes gzip-compressed .npy files for the four store elements
    (named with the "$store$_" prefix), add_count and invalid_range, loads
    them with suffix "3", and compares the buffer against the fixtures.
    """
    memory = circular_replay_buffer.ReplayBuffer(
        observation_shape=OBSERVATION_SHAPE,
        stack_size=STACK_SIZE,
        replay_capacity=5,
        batch_size=BATCH_SIZE,
    )
    # Fixture values for the store elements, in the order they are checked.
    store_fixtures = (
        ("observation", self._test_observation),
        ("action", self._test_action),
        ("reward", self._test_reward),
        ("terminal", self._test_terminal),
    )

    # A fresh buffer must differ from the fixtures before loading.
    for key, fixture in store_fixtures:
        self.assertNotEqual(memory._store[key], fixture)
    self.assertNotEqual(memory.add_count, self._test_add_count)
    self.assertNotEqual(memory.invalid_range, self._test_invalid_range)

    store_prefix = "$store$_"
    numpy_arrays = {
        store_prefix + "observation": self._test_observation,
        store_prefix + "action": self._test_action,
        store_prefix + "reward": self._test_reward,
        store_prefix + "terminal": self._test_terminal,
        "add_count": self._test_add_count,
        "invalid_range": self._test_invalid_range,
    }
    for attr, value in numpy_arrays.items():
        filename = os.path.join(self._test_subdir, "{}_ckpt.3.gz".format(attr))
        with open(filename, "wb") as f, gzip.GzipFile(fileobj=f) as outfile:
            np.save(outfile, value, allow_pickle=False)

    memory.load(self._test_subdir, "3")

    for key, fixture in store_fixtures:
        npt.assert_allclose(memory._store[key], fixture)
    self.assertEqual(memory.add_count, self._test_add_count)
    npt.assert_allclose(memory.invalid_range, self._test_invalid_range)
def testSave(self):
    """Check that save() writes checkpoint files and removes stale ones.

    Saves the buffer at a stale iteration and at the current iteration
    (CHECKPOINT_DURATION apart) and checks, for every public attribute in
    memory.__dict__, that the expected checkpoint files exist and that the
    stale ones are gone after the second save.
    """

    def checkpoint_path(attr, iteration):
        # File name the test expects save() to produce for one attribute.
        return os.path.join(
            self._test_subdir, "{}_ckpt.{}.gz".format(attr, iteration)
        )

    memory = circular_replay_buffer.ReplayBuffer(
        observation_shape=OBSERVATION_SHAPE,
        stack_size=STACK_SIZE,
        replay_capacity=5,
        batch_size=BATCH_SIZE,
    )
    # Attach the fixtures as public attributes of the buffer.
    memory.observation = self._test_observation
    memory.action = self._test_action
    memory.reward = self._test_reward
    memory.terminal = self._test_terminal

    current_iteration = 5
    stale_iteration = current_iteration - circular_replay_buffer.CHECKPOINT_DURATION

    memory.save(self._test_subdir, stale_iteration)
    for attr in memory.__dict__:
        if attr.startswith("_"):
            continue
        self.assertTrue(os.path.exists(checkpoint_path(attr, stale_iteration)))

    memory.save(self._test_subdir, current_iteration)
    for attr in memory.__dict__:
        if attr.startswith("_"):
            continue
        self.assertTrue(os.path.exists(checkpoint_path(attr, current_iteration)))
        # The stale version file should have been deleted. The path is
        # rebuilt for each attribute so that every stale file is checked,
        # not only the one left over from the last iteration of the first
        # loop.
        self.assertFalse(os.path.exists(checkpoint_path(attr, stale_iteration)))
def testSampleTransitionBatch(self):
    """Check batch sizes and contents returned by sample_transition_batch().

    First verifies the leading dimension of sampled batches for the default
    and an explicit batch size, then samples fixed indices and compares all
    eight returned elements against values derived from how the buffer was
    filled.
    """
    replay_capacity = 10
    default_batch_size = 2
    memory = circular_replay_buffer.ReplayBuffer(
        observation_shape=OBSERVATION_SHAPE,
        stack_size=1,
        replay_capacity=replay_capacity,
        batch_size=default_batch_size,
    )
    num_adds = 50  # The number of transitions to add to the memory.
    for step in range(num_adds):
        # The terminal argument is step % 4, so it is zero on every fourth
        # add and non-zero otherwise.
        memory.add(np.full(OBSERVATION_SHAPE, step, dtype=OBS_DTYPE), 0, 0, step % 4)

    # Sampling with the default batch size.
    for _ in range(1000):
        sampled = memory.sample_transition_batch()
        self.assertEqual(sampled[0].shape[0], 2)
    # Changing the batch size per call.
    for _ in range(1000):
        sampled = memory.sample_transition_batch(BATCH_SIZE)
        self.assertEqual(sampled[0].shape[0], BATCH_SIZE)
    # The default batch size is used again afterwards.
    for _ in range(1000):
        sampled = memory.sample_transition_batch()
        self.assertEqual(sampled[0].shape[0], 2)

    # Verify we can specify what indices to sample.
    indices = [1, 2, 3, 5, 8]
    num_indices = len(indices)
    # Observation value offset between a buffer index and the frame stored
    # there after the buffer has been overwritten num_adds times.
    offset = num_adds - replay_capacity

    expected_states = np.array(
        [np.full(OBSERVATION_SHAPE + (1,), idx, dtype=OBS_DTYPE) for idx in indices]
    )
    expected_next_states = (expected_states + 1) % replay_capacity
    # Because the replay buffer is circular, the states at the specified
    # indices can be computed exactly with a little mod math.
    expected_states += offset
    expected_next_states += offset
    # Replicates the formula used above for the terminal flag (step % 4),
    # clipped to {0, 1}.
    expected_terminal = np.array(
        [int((idx + offset) % 4 != 0) for idx in indices]
    )

    batch = memory.sample_transition_batch(batch_size=num_indices, indices=indices)
    (
        states,
        action,
        reward,
        next_states,
        next_action,
        next_reward,
        terminal,
        indices_batch,
    ) = batch
    all_zero = np.zeros(num_indices)
    npt.assert_array_equal(states, expected_states)
    npt.assert_array_equal(action, all_zero)
    npt.assert_array_equal(reward, all_zero)
    npt.assert_array_equal(next_action, all_zero)
    npt.assert_array_equal(next_reward, all_zero)
    npt.assert_array_equal(next_states, expected_next_states)
    npt.assert_array_equal(terminal, expected_terminal)
    npt.assert_array_equal(indices_batch, indices)
def testWithNontupleObservationShape(self):
    """Check that a non-tuple observation_shape is rejected.

    Passing the integer 84 as observation_shape is expected to raise an
    AssertionError from the constructor.
    """
    self.assertRaises(
        AssertionError,
        circular_replay_buffer.ReplayBuffer,
        observation_shape=84,
        stack_size=STACK_SIZE,
        replay_capacity=5,
        batch_size=BATCH_SIZE,
    )
def testAdd(self):
    """Check how far the cursor advances on the first add().

    The cursor is expected to start at 0 and to equal STACK_SIZE after one
    add: STACK_SIZE - 1 padding adds plus the explicit one.
    """
    buffer_kwargs = dict(
        observation_shape=OBSERVATION_SHAPE,
        stack_size=STACK_SIZE,
        replay_capacity=5,
        batch_size=BATCH_SIZE,
    )
    memory = circular_replay_buffer.ReplayBuffer(**buffer_kwargs)
    self.assertEqual(memory.cursor(), 0)

    blank_observation = np.zeros(OBSERVATION_SHAPE)
    memory.add(blank_observation, 0, 0, 0)
    # Cursor moved by the STACK_SIZE - 1 padding adds + 1 (the add above).
    self.assertEqual(memory.cursor(), STACK_SIZE)
def testCheckAddTypes(self):
    """Check _check_add_types() on a buffer with extra storage elements.

    The buffer declares two extra elements ("extra1" and "extra2"). A call
    that supplies both extras is expected to pass; a call that omits them
    is expected to raise a ValueError matching "Add expects".
    """
    memory = circular_replay_buffer.ReplayBuffer(
        observation_shape=OBSERVATION_SHAPE,
        stack_size=STACK_SIZE,
        replay_capacity=5,
        batch_size=BATCH_SIZE,
        extra_storage_types=[
            circular_replay_buffer.ReplayElement("extra1", [], np.float32),
            circular_replay_buffer.ReplayElement("extra2", [2], np.int8),
        ],
    )
    zeros = np.zeros(OBSERVATION_SHAPE)

    memory._check_add_types(zeros, 0, 0, 0, 0, [0, 0])

    # assertRaisesRegex is used instead of the deprecated assertRaisesRegexp
    # alias, which no longer exists in unittest as of Python 3.12.
    with self.assertRaisesRegex(ValueError, "Add expects"):
        memory._check_add_types(zeros, 0, 0, 0)
def testLoadFromNonexistentDirectory(self):
    """Check that load() from a missing directory fails without side effects.

    load() is expected to raise FileNotFoundError, after which the buffer
    must still differ from every fixture value.
    """
    memory = circular_replay_buffer.ReplayBuffer(
        observation_shape=OBSERVATION_SHAPE,
        stack_size=STACK_SIZE,
        replay_capacity=5,
        batch_size=BATCH_SIZE,
    )
    # Loading from a non-existent directory must raise FileNotFoundError.
    missing_dir = "/does/not/exist"
    with self.assertRaises(FileNotFoundError):
        memory.load(missing_dir, "3")

    # Nothing should have been loaded into the buffer.
    store_fixtures = (
        ("observation", self._test_observation),
        ("action", self._test_action),
        ("reward", self._test_reward),
        ("terminal", self._test_terminal),
    )
    for key, fixture in store_fixtures:
        self.assertNotEqual(memory._store[key], fixture)
    self.assertNotEqual(memory.add_count, self._test_add_count)
    self.assertNotEqual(memory.invalid_range, self._test_invalid_range)
def testNSteprewardum(self):
    """Check the accumulated n-step reward returned by sampling.

    Every transition is added with reward 2.0; with update_horizon=5 and
    gamma=1.0 the first reward of each sampled batch is expected to be 10.0.
    """
    reward_per_step = 2.0
    memory = circular_replay_buffer.ReplayBuffer(
        observation_shape=OBSERVATION_SHAPE,
        stack_size=STACK_SIZE,
        replay_capacity=10,
        batch_size=BATCH_SIZE,
        update_horizon=5,
        gamma=1.0,
    )
    for step in range(50):
        observation = np.full(OBSERVATION_SHAPE, step, dtype=OBS_DTYPE)
        memory.add(observation, 0, reward_per_step, 0)

    for _ in range(100):
        sampled = memory.sample_transition_batch()
        reward_batch = sampled[2]
        # The total reward is reward per step x update_horizon.
        self.assertEqual(reward_batch[0], 10.0)
def testExtraAdd(self):
    """Check add() on a buffer that declares extra storage elements.

    An add that supplies both extras is expected to succeed; an add that
    omits them is expected to raise a ValueError matching "Add expects".
    Afterwards the cursor is expected to equal STACK_SIZE.
    """
    memory = circular_replay_buffer.ReplayBuffer(
        observation_shape=OBSERVATION_SHAPE,
        stack_size=STACK_SIZE,
        replay_capacity=5,
        batch_size=BATCH_SIZE,
        extra_storage_types=[
            circular_replay_buffer.ReplayElement("extra1", [], np.float32),
            circular_replay_buffer.ReplayElement("extra2", [2], np.int8),
        ],
    )
    self.assertEqual(memory.cursor(), 0)
    zeros = np.zeros(OBSERVATION_SHAPE)

    memory.add(zeros, 0, 0, 0, 0, [0, 0])

    # assertRaisesRegex is used instead of the deprecated assertRaisesRegexp
    # alias, which no longer exists in unittest as of Python 3.12.
    with self.assertRaisesRegex(ValueError, "Add expects"):
        memory.add(zeros, 0, 0, 0)

    # Check if the cursor moved STACK_SIZE - 1 zeros adds + 1 (the successful
    # add above).
    self.assertEqual(memory.cursor(), STACK_SIZE)
def testSamplingWithterminalInTrajectory(self):
    """Check sampled rewards and terminals when a trajectory hits a terminal.

    Ten transitions are added with observation i, action 2 * i and reward
    i; only the transition with i == 3 is terminal. Indices 2, 3 and 4 are
    then sampled with update_horizon=3 and gamma=1.0.
    """
    replay_capacity = 10
    update_horizon = 3
    terminal_step = 3
    memory = circular_replay_buffer.ReplayBuffer(
        observation_shape=OBSERVATION_SHAPE,
        stack_size=1,
        replay_capacity=replay_capacity,
        batch_size=2,
        update_horizon=update_horizon,
        gamma=1.0,
    )
    for step in range(replay_capacity):
        observation = np.full(OBSERVATION_SHAPE, step, dtype=OBS_DTYPE)
        step_action = step * 2
        step_reward = step
        is_terminal = 1 if step == terminal_step else 0
        memory.add(observation, step_action, step_reward, is_terminal)

    indices = [2, 3, 4]
    batch = memory.sample_transition_batch(
        batch_size=len(indices), indices=np.array(indices)
    )
    states, action, reward, _, _, _, terminal, indices_batch = batch

    expected_states = np.array(
        [np.full(OBSERVATION_SHAPE + (1,), idx, dtype=OBS_DTYPE) for idx in indices]
    )
    # The rewards stored in the replay buffer are (an asterisk marks the
    # terminal state):
    #   [0 1 2 3* 4 5 6 7 8 9]
    # With update_horizon set to 3, the accumulated trajectory reward
    # starting at each replay buffer position is:
    #   [3 6 5 3 15 18 21 24]
    # so for indices [2, 3, 4] the expected rewards are [5, 3, 15].
    expected_reward = np.array([5, 3, 15])
    # Because update_horizon = 3, both indices 2 and 3 include the terminal.
    expected_terminal = np.array([1, 1, 0])
    expected_action = np.array(indices) * 2

    npt.assert_array_equal(states, expected_states)
    npt.assert_array_equal(action, expected_action)
    npt.assert_array_equal(reward, expected_reward)
    npt.assert_array_equal(terminal, expected_terminal)
    npt.assert_array_equal(indices_batch, indices)
def testGetRangeInvalidIndexOrder(self):
    """Check that get_range() rejects invalid index arguments.

    On an empty buffer of capacity 10, the test expects an AssertionError
    for: end_index smaller than start_index, a negative end_index, a
    start_index equal to the replay capacity, and a range whose start index
    has not been added yet.
    """
    replay_capacity = 10
    memory = circular_replay_buffer.ReplayBuffer(
        observation_shape=OBSERVATION_SHAPE,
        stack_size=STACK_SIZE,
        replay_capacity=replay_capacity,
        batch_size=BATCH_SIZE,
        update_horizon=5,
        gamma=1.0,
    )
    # assertRaisesRegex is used instead of the deprecated assertRaisesRegexp
    # alias, which no longer exists in unittest as of Python 3.12.
    with self.assertRaisesRegex(
        AssertionError, "end_index must be larger than start_index"
    ):
        memory.get_range([], 2, 1)

    with self.assertRaises(AssertionError):
        # Negative end_index.
        memory.get_range([], 1, -1)

    with self.assertRaises(AssertionError):
        # Start index beyond replay capacity.
        memory.get_range([], replay_capacity, replay_capacity + 1)

    with self.assertRaisesRegex(AssertionError, "Index 1 has not been added."):
        memory.get_range([], 1, 2)
def testIsTransitionValid(self):
    """Check is_valid_transition() for every index after a short episode.

    Three frames are added, the last one flagged terminal, and each of the
    ten buffer indices is compared against a hand-written validity table.

    NOTE(review): another method named testIsTransitionValid appears earlier
    in this file with a different table; if both live in the same class this
    later definition shadows the earlier one -- confirm which is intended.
    """
    memory = circular_replay_buffer.ReplayBuffer(
        observation_shape=OBSERVATION_SHAPE,
        stack_size=STACK_SIZE,
        replay_capacity=10,
        batch_size=2,
    )
    # Two non-terminal frames followed by one terminal frame.
    for terminal_flag in (0, 0, 1):
        memory.add(
            np.full(OBSERVATION_SHAPE, 0, dtype=OBS_DTYPE), 0, 0, terminal_flag
        )

    # The table accounts for the automatically applied padding (3 blanks at
    # the start of each episode).
    correct_valids = [0, 0, 0, 1, 1, 0, 0, 0, 0, 0]
    # The cursor is:                   ^
    for index, expected in enumerate(correct_valids):
        self.assertEqual(
            expected,
            memory.is_valid_transition(index),
            "Index %i should be %s" % (index, bool(expected)),
        )