def test_iteration(self):
  """Checks iterate_over_paths() as trajectories are added and evicted."""
  lengths = [10, 5, 7]
  # Room for exactly the first two trajectories.
  rb = replay_buffer.ReplayBuffer(lengths[0] + lengths[1])
  trajs = [self.get_random_trajectory(max_time_step=n) for n in lengths]

  def store_and_get_endpoints(traj):
    """Stores `traj`, then returns (first, last) buffer slots of every path."""
    rb.store(traj)
    unrolled = rb.get_unrolled_indices()
    return [(unrolled[path[0]], unrolled[path[1] - 1])
            for path in rb.iterate_over_paths(unrolled)]

  expected_endpoints = [
      # Buffer holds trajs[0] only.
      [(0, 9)],
      # Buffer holds trajs[0] and trajs[1].
      [(0, 9), (10, 14)],
      # trajs[0] is evicted; buffer holds trajs[1] and trajs[2].
      [(10, 14), (0, 6)],
  ]
  for traj, expected in zip(trajs, expected_endpoints):
    self.assertEqual(expected, store_and_get_endpoints(traj))
def test_valid_indices(self):
  """Checks get_valid_indices() after the buffer has wrapped around."""
  lens = [10, 5, 7]
  rb = replay_buffer.ReplayBuffer(lens[0] + lens[1])
  trajs = [self.get_random_trajectory(max_time_step=n) for n in lens]
  for traj in trajs:
    rb.store(traj)
  # trajs[0] has been evicted, so the buffer now looks like
  # [trajs[2] <gap> trajs[1]]: trajs[2] occupies slots 0-6, slots 7-9 hold no
  # path, and trajs[1] occupies slots 10-14.
  self.assertLess(rb.buffer_head, rb.buffer_tail)
  idx, valid_mask, valid_idx = rb.get_valid_indices()
  # Unrolled order: trajs[1] first, then trajs[2].
  np.testing.assert_array_equal(
      idx, np.array([10, 11, 12, 13, 14, 0, 1, 2, 3, 4, 5, 6]))
  # Each row pairs a valid buffer slot with its position in `idx`; the last
  # slot of each path (14 and 6) is absent.
  np.testing.assert_array_equal(
      valid_idx,
      np.array([[10, 0], [11, 1], [12, 2], [13, 3], [0, 5], [1, 6], [2, 7],
                [3, 8], [4, 9], [5, 10]]))
  # False exactly at the final step of each path (positions 4 and 11).
  np.testing.assert_array_equal(
      valid_mask,
      np.array([
          True, True, True, True, False, True, True, True, True, True, True,
          False
      ]))
def __init__(self,
             train_env: env_problem.EnvProblem,
             eval_env: env_problem.EnvProblem,
             td_lambda=0.95,
             gamma=0.99,
             replay_buffer_sample_size=50000,
             num_samples_to_collect=2048,
             temperature=0.05,
             weight_clip=20,
             actor_batch_size=256,
             critic_batch_size=256,
             actor_optimization_steps=1000,
             critic_optimization_steps=500,
             actor_momentum=0.9,
             critic_momentum=0.9,
             actor_learning_rate=5e-5,
             critic_learning_rate=1e-4,
             actor_loss_weight=1.0,
             entropy_bonus=0.01,
             **kwargs):
  """Stores the AWR hyperparameters and creates the replay buffer.

  Each hyperparameter is kept on the instance under the same name with a
  leading underscore (e.g. `gamma` -> `self._gamma`).

  Args:
    train_env: forwarded to the base-class constructor.
    eval_env: forwarded to the base-class constructor.
    td_lambda: stored as `_td_lambda` (presumably the TD(lambda) coefficient).
    gamma: stored as `_gamma` (presumably the discount factor).
    replay_buffer_sample_size: stored, and used as the replay buffer's
      `buffer_size`.
    num_samples_to_collect: stored as `_num_samples_to_collect`.
    temperature: stored as `_temperature`.
    weight_clip: stored as `_weight_clip`.
    actor_batch_size: stored as `_actor_batch_size`.
    critic_batch_size: stored, and reused as the unified-loss batch size.
    actor_optimization_steps: stored as `_actor_optimization_steps`.
    critic_optimization_steps: stored, and reused as the unified-loss step
      count.
    actor_momentum: stored as `_actor_momentum`.
    critic_momentum: stored, and reused as the unified-loss momentum.
    actor_learning_rate: stored as `_actor_learning_rate`.
    critic_learning_rate: stored, and reused as the unified-loss learning
      rate.
    actor_loss_weight: stored as `_actor_loss_weight`.
    entropy_bonus: stored as `_entropy_bonus`.
    **kwargs: forwarded to the base-class constructor.
  """
  super(AwrTrainer, self).__init__(train_env, eval_env, **kwargs)

  # Hyperparameters shared by actor and critic.
  self._td_lambda = td_lambda
  self._gamma = gamma
  self._replay_buffer_sample_size = replay_buffer_sample_size
  self._num_samples_to_collect = num_samples_to_collect
  self._temperature = temperature
  self._weight_clip = weight_clip

  # Actor-specific settings.
  self._actor_batch_size = actor_batch_size
  self._actor_optimization_steps = actor_optimization_steps
  self._actor_momentum = actor_momentum
  self._actor_learning_rate = actor_learning_rate
  self._actor_loss_weight = actor_loss_weight
  self._entropy_bonus = entropy_bonus

  # Critic-specific settings.
  self._critic_batch_size = critic_batch_size
  self._critic_optimization_steps = critic_optimization_steps
  self._critic_momentum = critic_momentum
  self._critic_learning_rate = critic_learning_rate

  # Unified loss: reuses the critic's optimization settings.
  self._optimization_batch_size = self._critic_batch_size
  self._optimization_steps = self._critic_optimization_steps
  self._momentum = self._critic_momentum
  self._learning_rate = self._critic_learning_rate

  self._replay_buffer = replay_buffer.ReplayBuffer(
      buffer_size=self._replay_buffer_sample_size)

  # _action_space and _observation_space are set by the base class.
  action_space = self._action_space
  observation_space = self._observation_space
  self._action_shape = action_space.shape
  self._action_dtype = action_space.dtype
  self._observation_shape = observation_space.shape
  self._observation_dtype = observation_space.dtype

  # TODO(afrozm): Offload all these to `trainer_lib.Trainer`.
  self._total_opt_step = 0
  # TODO(afrozm): Ensure that this is updated.
  self._n_observations_seen = 0
  self._opt_sw = None
def test_replay_buffer_to_padded_observations(self):
  """Checks padding of observations (and their mask) across trajectories.

  The third argument of `replay_buffer_to_padded_observations` appears to be
  the padding boundary: with `None` the time axis is padded to 20 (a multiple
  of 10), and with 6 it is padded to 18 (a multiple of 6).
  """
  traj_lengths = [10, 15, 17]
  obs_shape = (3, 4)
  trajs = [
      self.get_random_trajectory(max_time_step=n, obs_shape=obs_shape)
      for n in traj_lengths
  ]
  rb = replay_buffer.ReplayBuffer(2 * sum(traj_lengths))
  for traj in trajs:
    rb.store(traj)

  def expected_mask(t_final):
    # One row per trajectory: 1 for each real time step, 0 for padding.
    return onp.array(
        [[1.] * n + [0.] * (t_final - n) for n in traj_lengths])

  t_final = 20  # lowest multiple of 10 that is sufficient.
  padded_obs, mask = awr_utils.replay_buffer_to_padded_observations(
      rb, None, None)
  self.assertEqual((len(traj_lengths), t_final) + obs_shape,
                   padded_obs.shape)
  # assert_array_equal checks shape strictly and reports mismatching entries,
  # unlike assertTrue(all(...)).
  onp.testing.assert_array_equal(mask, expected_mask(t_final))

  t_final = 6 * 3  # 18 is enough to cover everything.
  padded_obs, mask = awr_utils.replay_buffer_to_padded_observations(
      rb, None, 6)
  self.assertEqual((len(traj_lengths), t_final) + obs_shape,
                   padded_obs.shape)
  onp.testing.assert_array_equal(mask, expected_mask(t_final))
def test_replay_buffer_to_padded_rewards(self):
  """Checks padding of rewards (and their mask) across trajectories.

  A trajectory with `n` observations contributes `n - 1` rewards, which the
  random-trajectory fixture makes equal to 1, 2, ..., n - 1. Rows are padded
  with zeros to `t_final - 1` columns.
  """
  traj_lengths = [10, 15, 17]
  obs_shape = (3, 4)
  t_final = 20  # lowest multiple of 10 that is sufficient.
  trajs = [
      self.get_random_trajectory(max_time_step=n, obs_shape=obs_shape)
      for n in traj_lengths
  ]
  rb = replay_buffer.ReplayBuffer(2 * sum(traj_lengths))
  for traj in trajs:
    rb.store(traj)
  idx = rb.get_unrolled_indices()
  padded_rewards, mask = awr_utils.replay_buffer_to_padded_rewards(
      rb, idx, t_final - 1)

  # Each row has (n - 1) real entries plus (t_final - n) padding entries,
  # i.e. t_final - 1 columns in total.
  expected_rewards = onp.array(
      [list(range(1, n)) + [0] * (t_final - n) for n in traj_lengths],
      dtype=float)
  expected_mask = onp.array(
      [[1.] * (n - 1) + [0.] * (t_final - n) for n in traj_lengths])
  # assert_array_equal checks shape strictly and reports mismatching entries,
  # unlike assertTrue(all(...)).
  onp.testing.assert_array_equal(padded_rewards, expected_rewards)
  onp.testing.assert_array_equal(mask, expected_mask)
def test_add_three_trajectories(self):
  """Stores three trajectories in a buffer that holds only two at a time."""
  n1 = 10
  t1 = self.get_random_trajectory(max_time_step=n1)
  n2 = 5
  t2 = self.get_random_trajectory(max_time_step=n2)
  # Make a buffer of just sufficient size to hold these two.
  rb = replay_buffer.ReplayBuffer(n1 + n2)

  start_index_t1 = rb.store(t1)
  # Stored at the beginning.
  self.assertEqual(0, start_index_t1)
  # One path stored in total.
  self.assertEqual(1, rb.num_paths)
  # Total number of states stored, ever.
  self.assertEqual(n1, rb.total_count)
  # The current number of states stored.
  self.assertEqual(n1, rb.get_current_size())

  start_index_t2 = rb.store(t2)
  # Stored right afterwards.
  self.assertEqual(n1, start_index_t2)
  self.assertEqual(2, rb.num_paths)
  self.assertEqual(n1 + n2, rb.total_count)
  self.assertEqual(n1 + n2, rb.get_current_size())

  # We now make a path that is smaller than n1.
  # Since there is no more space in the buffer, t1 will be ejected.
  # t2 will remain there.
  n3 = 6
  self.assertLess(n3, n1)  # Precondition of this scenario.
  t3 = self.get_random_trajectory(max_time_step=n3)
  start_index_t3 = rb.store(t3)
  self.assertEqual(0, start_index_t3)
  self.assertEqual(2, rb.num_paths)
  self.assertEqual(n1 + n2 + n3, rb.total_count)
  self.assertEqual(n2 + n3, rb.get_current_size())

  # So the first n3 rb.buffers[replay_buffer.ReplayBuffer.PATH_START_KEY] will
  # be 0, the next n1 - n3 will be -1, and the rest will be start_index_t2.
  path_start_array = ([0] * n3) + ([-1] * (n1 - n3)) + ([start_index_t2] * n2)
  np.testing.assert_array_equal(
      path_start_array, rb.buffers[replay_buffer.ReplayBuffer.PATH_START_KEY])

  # The unrolled indices will be first t2's indices, then t3's.
  unrolled_indices = ([start_index_t2 + x for x in range(n2)] +
                      [start_index_t3 + x for x in range(n3)])
  np.testing.assert_array_equal(unrolled_indices, rb.get_unrolled_indices())

  # The last state of each stored path.
  invalid_indices = [start_index_t3 + n3 - 1, start_index_t2 + n2 - 1]

  # Let's sample a really large sample.
  n = 1000
  sample_valid_indices = rb.sample(n, filter_end=True)
  sample_all_indices = rb.sample(n, filter_end=False)
  self.assertNotIn(invalid_indices[0], sample_valid_indices)
  self.assertNotIn(invalid_indices[1], sample_valid_indices)
  # Holds with high probability: 1000 draws over only n2 + n3 = 11 stored
  # states almost surely hit both end states (assuming roughly uniform
  # sampling).
  self.assertIn(invalid_indices[0], sample_all_indices)
  self.assertIn(invalid_indices[1], sample_all_indices)