def test_sample_single_episode(self):
    """Samples sub-sequences from a buffer that holds a single episode.

    One episode of `sequence_length` steps is inserted, then batches of
    `num_steps`-long windows are drawn. Every element must report episode
    id 0 and every row must contain consecutive step counters.
    """
    episode_count = 1
    sequence_length = 100
    batch_size = 10
    num_steps = 5

    # Write exactly one full-length episode into the replay buffer.
    counting_env = test_envs.EpisodeCountingEnv(
        steps_per_episode=sequence_length)
    self._insert_random_data(
        counting_env,
        num_steps=episode_count * sequence_length,
        sequence_length=sequence_length)

    replay = reverb_replay_buffer.ReverbReplayBuffer(
        self._data_spec,
        self._table_name,
        sequence_length=sequence_length,
        local_server=self._server)
    dataset = replay.as_dataset(batch_size, num_steps=num_steps)

    samples_seen = 0
    for sample, _ in dataset.take(10):
        samples_seen += 1
        episode, step = sample.observation

        # Only one episode was inserted, so every entry must carry id 0.
        expected_episode = tf.constant(
            0, dtype=episode.dtype, shape=episode.shape)
        self.assertAllEqual(expected_episode, episode)

        # Within each row the step counter must advance by one per column.
        first_step = step[:, 0]
        for offset in range(num_steps):
            self.assertAllEqual(first_step + offset, step[:, offset])

    # Confirm the dataset really yielded all 10 requested samples.
    self.assertEqual(10, samples_seen)
def setUp(self):
    """Builds the trajectory specs and starts a Reverb server for the test.

    Sets `_env`, `_data_spec`, `_array_data_spec`, `_table_name`, `_server`
    and `_py_client` on the test instance.
    """
    super(ReverbReplayBufferTest, self).setUp()

    # Environment and the tensor specs derived from its time step / action.
    self._env = test_envs.EpisodeCountingEnv(steps_per_episode=3)
    time_step_spec = tf.nest.map_structure(
        tensor_spec.from_spec, self._env.time_step_spec())
    action_spec = tensor_spec.from_spec(self._env.action_spec())

    self._data_spec = trajectory.Trajectory(
        step_type=time_step_spec.step_type,
        observation=time_step_spec.observation,
        action=action_spec,
        policy_info=(),
        next_step_type=time_step_spec.step_type,
        reward=time_step_spec.reward,
        discount=time_step_spec.discount,
    )
    self._array_data_spec = tensor_spec.to_nest_array_spec(self._data_spec)

    def _with_leading_unknown_dim(spec):
        # The table signature prepends an unknown-size dimension to each
        # field of the trajectory spec.
        return tf.TensorSpec(dtype=spec.dtype, shape=(None,) + spec.shape)

    table_signature = tf.nest.map_structure(
        _with_leading_unknown_dim, self._data_spec)

    # Reverb table: uniform sampling, FIFO removal, sampling allowed once
    # at least one item is present.
    self._table_name = 'test_table'
    table = reverb.Table(
        self._table_name,
        max_size=100,
        sampler=reverb.selectors.Uniform(),
        remover=reverb.selectors.Fifo(),
        rate_limiter=reverb.rate_limiters.MinSize(1),
        signature=table_signature,
    )

    # Start the server and connect a Python client to its port.
    self._server = reverb.Server([table])
    server_address = 'localhost:{}'.format(self._server.port)
    self._py_client = reverb.Client(server_address)
def test_sequential_ordering(self, num_steps):
    """Checks that sampled windows stay within one episode and are ordered.

    Args:
      num_steps: Window length passed to `as_dataset`. When falsy, the
        expected window length falls back to `sequence_length`.
    """
    sequence_length = 10
    batch_size = 5

    counting_env = test_envs.EpisodeCountingEnv(
        steps_per_episode=sequence_length)
    self._insert_random_data(
        counting_env,
        num_steps=batch_size * sequence_length,
        sequence_length=sequence_length)

    replay = reverb_replay_buffer.ReverbReplayBuffer(
        self._data_spec,
        self._table_name,
        sequence_length=sequence_length,
        local_server=self._server)
    dataset = replay.as_dataset(batch_size, num_steps=num_steps)

    # A falsy `num_steps` means whole sequences are expected.
    steps_per_sample = num_steps or sequence_length
    expected_shape = (batch_size, steps_per_sample)

    for sample, _ in dataset.take(10):
        episode, step = sample.observation
        self.assertEqual(expected_shape, episode.shape)
        self.assertEqual(expected_shape, step.shape)

        first_episode = episode[:, 0]
        first_step = step[:, 0]
        for offset in range(steps_per_sample):
            # Each row must stay within a single episode ...
            self.assertAllEqual(first_episode, episode[:, offset])
            # ... and its step counter must advance by one per column.
            self.assertAllEqual(first_step + offset, step[:, offset])
def test_batched_episodes_dataset(self, sequence_length):
    """Samples batches of whole episodes from a buffer built without a fixed length.

    Args:
      sequence_length: Number of steps per episode, used both for the
        environment and for the inserted items.
    """
    batch_size = 3

    # Insert batch_size * sequence_length steps so that at least 3 episodes
    # are available for sampling.
    counting_env = test_envs.EpisodeCountingEnv(
        steps_per_episode=sequence_length)
    self._insert_random_data(
        counting_env,
        num_steps=batch_size * sequence_length,
        sequence_length=sequence_length)

    replay = reverb_replay_buffer.ReverbReplayBuffer(
        self._data_spec,
        self._table_name,
        sequence_length=None,
        local_server=self._server)
    dataset = replay.as_dataset(batch_size)

    expected_shape = (batch_size, sequence_length)
    for sample, _ in dataset.take(5):
        episode, step = sample.observation
        self.assertEqual(expected_shape, episode.shape)
        self.assertEqual(expected_shape, step.shape)

        first_episode = episode[:, 0]
        first_step = step[:, 0]
        for offset in range(sequence_length):
            # Every column of a row must report the same episode id.
            self.assertAllEqual(first_episode, episode[:, offset])
            # Step counters within a row must be consecutive.
            self.assertAllEqual(first_step + offset, step[:, offset])
def test_variable_length_episodes_dataset(self):
    """Samples whole episodes when the buffer holds episodes of differing lengths.

    One episode is inserted for each length in 1..9, each from a fresh
    environment. Sampled episodes must have one of those lengths, episode
    id 0 throughout, and step ids counting up from 0.
    """
    valid_lengths = range(1, 10)

    # Insert one episode per length, each produced by a new environment.
    for sequence_length in valid_lengths:
        counting_env = test_envs.EpisodeCountingEnv(
            steps_per_episode=sequence_length)
        self._insert_random_data(
            counting_env,
            num_steps=sequence_length,
            sequence_length=sequence_length)

    replay = reverb_replay_buffer.ReverbReplayBuffer(
        self._data_spec,
        self._table_name,
        sequence_length=None,
        local_server=self._server)

    # NOTE: the original author remarks that observations are off by one
    # because the env observations count transitions.
    dataset = replay.as_dataset(sample_batch_size=1)

    for sample, _ in dataset.take(5):
        episode, step = sample.observation
        self.assertIn(episode.shape[1], valid_lengths)
        self.assertIn(step.shape[1], valid_lengths)

        length = episode.shape[1]
        # Every element is expected to carry episode id 0.
        expected_episode = [[0] * length]
        self.assertAllEqual(expected_episode, episode)
        # Step ids run 0, 1, ..., length - 1.
        expected_steps = [list(range(length))]
        self.assertAllEqual(expected_steps, step)
def test_sequential(self):
    """Steps through several episodes and checks the (episode, step) observation."""
    num_episodes = 3
    steps_per_episode = 4
    env = test_envs.EpisodeCountingEnv(steps_per_episode=steps_per_episode)

    for episode_index in range(num_episodes):
        time_step = env.reset()
        step_index = 0
        # Right after reset the observation is (current episode, 0).
        self.assertAllEqual((episode_index, step_index), time_step.observation)

        # Advance until the environment reports the last step, checking the
        # counters after every transition.
        while not time_step.is_last():
            step_index += 1
            time_step = env.step(0)
            self.assertAllEqual(
                (episode_index, step_index), time_step.observation)

        # The final observation must report exactly `steps_per_episode` steps.
        self.assertAllEqual(
            (episode_index, steps_per_episode), time_step.observation)
def test_validate_specs(self):
    """Runs the generic py-environment validator against EpisodeCountingEnv."""
    steps_per_episode = 15
    episodes_to_validate = 10
    counting_env = test_envs.EpisodeCountingEnv(
        steps_per_episode=steps_per_episode)
    env_utils.validate_py_environment(
        counting_env, episodes=episodes_to_validate)