class TFPrioritizedReplayBuffer(TFReplayBufferAbstract):
    """Replay buffer wrapper that samples experience according to its priority.

    The underlying ``TFReplayBuffer`` is created lazily on the first call to
    ``add_batch``, because its batch size and trajectory spec are only known
    once the first batch of data arrives.

    Per training iteration the methods are expected to be called in this
    order: ``pre_process`` (creates the metric tracker and computes beta),
    ``get_batch`` (samples with that beta), ``post_process`` (updates the
    priorities from the TD loss).
    """

    # Importance-sampling beta for the current iteration; set by pre_process
    # and passed to get_next in get_batch.
    _curr_beta = None

    def __init__(self, collect_data_spec, alpha, beta, training_iterations_num):
        """Store replay buffer params.

        Params:
            collect_data_spec: spec of the data to be added to the buffer.
            alpha: determines how much emphasis is given to the priority;
                forwarded to the underlying TFReplayBuffer.
            beta: base value handed to the buffer's ``compute_beta`` on every
                iteration to derive the beta used by ``get_batch``.
            training_iterations_num: total number of training iterations;
                used to compute beta, to size the metric tracker and to decide
                when summary metrics are logged.
        """
        super().__init__(collect_data_spec)
        self._beta = beta
        self._alpha = alpha
        self._training_iterations_total_num = training_iterations_num
        # Created lazily on the first pre_process call.
        self._metric_tracker = None

    def _init_replay_buffer(self, batch_size, traj_spec):
        """Create the underlying TFReplayBuffer and log its size.

        Params:
            batch_size: number of rows in the buffer.
            traj_spec: spec of the trajectories that will be stored.
        """
        buffer_config = {
            "batch_size": batch_size,
            "data_spec": traj_spec,
            # Each row holds a single item, so the capacity equals batch_size.
            "max_length": 1,
            "alpha": self._alpha
        }
        tf.compat.v2.summary.scalar(name="replay_buffer_size", data=batch_size)
        self._replay_buffer = TFReplayBuffer(**buffer_config)

    def add_batch(self, traj_dict):
        """Add a batch of trajectories to the replay buffer.

        The underlying buffer is created on the first call, with its batch
        size taken from the length of ``traj_dict["observation"]``.

        Params:
            traj_dict (dict[str, numpy]): a dict of arrays representing the
                trajectories to be added to the replay buffer.
        """
        collect_spec_dict = self.collect_data_spec._asdict()
        traj_tf, traj_spec = build_tf_trajectory(traj_dict, collect_spec_dict)
        if self._replay_buffer is None:
            batch_size = len(traj_dict["observation"])
            self._init_replay_buffer(batch_size, traj_spec)
        self._replay_buffer.add_batch(traj_tf)

    def get_batch(self, batch_size):
        """Sample a prioritized batch using the beta computed in pre_process.

        ``pre_process`` must have been called first, because it creates the
        metric tracker used here.

        Params:
            batch_size: number of samples to draw.

        Returns:
            (traj, metadata) as returned by the buffer's ``get_next``; the
            metadata's ``probabilities`` and ``ids`` are also recorded on the
            metric tracker.
        """
        traj, metadata = self._replay_buffer.get_next(sample_batch_size=batch_size,
                                                      beta=self._curr_beta)
        self._metric_tracker.add_batch_weights(metadata.probabilities)
        self._metric_tracker.add_batch_indices(metadata.ids)
        return traj, metadata

    def pre_process(self, curr_iter):
        """Prepare for a training iteration.

        Creates the metric tracker on first use and computes the beta used
        for the importance-sampling weights of this iteration. Requires the
        underlying buffer to exist, i.e. ``add_batch`` was called at least once.

        Params:
            curr_iter: current training iteration.
        """
        if self._metric_tracker is None:
            self._metric_tracker = TrainingMetricTracker(self._training_iterations_total_num)
        # compute the beta that will be used when computing the importance sampling weights
        self._curr_beta = self._replay_buffer.compute_beta(self._beta, curr_iter,
                                                           self._training_iterations_total_num)
        # add important data to the metric tracker
        self._metric_tracker.latest_iteration = curr_iter
        self._metric_tracker.latest_beta = self._curr_beta

    def post_process(self, traj_meta, loss_info, curr_iter):
        """Update the sampled items' priorities with their TD loss and log metrics.

        Params:
            traj_meta: metadata returned by ``get_batch``; its ``ids`` are the
                buffer indices of the sampled experience.
            loss_info: training loss info; ``loss_info[1].td_loss`` is read as
                the per-sample TD loss (stored in DQNLossInfo).
            curr_iter: current training iteration; summary metrics are logged
                when it equals the total number of training iterations.

        Raises:
            ValueError: if the TD loss shape differs from the indices shape.
        """
        indices = traj_meta.ids.numpy()
        # get the loss of every experience used during the training. it is stored in DQNLossInfo
        td_loss = loss_info[1].td_loss.numpy()
        # make sure the td loss array has the same size as the batch
        if td_loss.shape != indices.shape:
            raise ValueError("Expected the shape of the loss '%s' to be the same as the shape of the "
                             "indices '%s'" % (str(td_loss.shape), str(indices.shape)))
        # update the prioritized replay buffer
        self._replay_buffer.update_priorities(indices, td_loss)
        # NOTE(review): partial metrics are logged before latest_loss_info is
        # updated below, so they may reflect the previous iteration's loss —
        # confirm this ordering is intended.
        self._metric_tracker.log_partial_metrics()
        self._metric_tracker.latest_loss_info = loss_info
        if curr_iter == self._training_iterations_total_num:
            self._metric_tracker.log_summary_metrics()
def test_prioritized_replay_buffer_as_dataset(self):
    """Sampling frequencies must follow the priorities assigned to each item.

    The buffer holds the values 0..9 (one per batch row). Phase 1 gives items
    > 5 a priority equal to their value and items <= 5 a tiny one; phase 2
    resets all priorities to 1 and then inverts the pattern. In both phases
    the high-priority items must dominate the samples, ordered by priority.

    All thresholds are fractions of 15000 samples (phase 1 actually draws
    45 * 33 * 10 = 14850, phase 2 draws 300 * 5 * 10 = 15000).
    """
    np.random.seed(123)
    buffer_batch_size = 10
    alpha = 0.6
    spec = specs.TensorSpec([], tf.int32, 'action')
    replay_buffer = TfPrioritizedReplayBuffer(spec, batch_size=buffer_batch_size,
                                              max_length=1, alpha=alpha)

    # Row k of the buffer holds the value k.
    experience = [np.full((1,), k, dtype=np.int32) for k in range(buffer_batch_size)]
    replay_buffer.add_batch(tf.convert_to_tensor(experience))

    sample_batch_size = 10
    beta = 0.4

    # Phase 1: items > 5 get a priority equal to their value,
    # items <= 5 get value / 10 (close to 0).
    sample_frequency = [0] * 10
    for i in range(15 * 3):
        itr = iter(replay_buffer.as_dataset(sample_batch_size=sample_batch_size, beta=beta))
        for _ in range(100 // 3):
            mini_batch, metadata = next(itr)
            indices = metadata.ids.numpy()
            if i % 100 == 0:
                self.validate_data(mini_batch, indices)
            for idx in indices:
                sample_frequency[idx] += 1
            priorities = [idx if idx > 5 else idx / 10 for idx in indices]
            replay_buffer.update_priorities(indices, priorities)

    for i in range(10):
        if i <= 5:
            # low-priority items (<= 5) should be picked at most 5% of the time
            self.assertLessEqual(sample_frequency[i], 15000 * 5 / 100)
        else:
            # high-priority items (> 5) should be picked between 15% and 30% of the time
            self.assertGreaterEqual(sample_frequency[i], 15000 * 15 / 100)
            self.assertLessEqual(sample_frequency[i], 15000 * 30 / 100)
            # a higher priority must yield at least as many picks as the
            # preceding item and at most as many as the following one
            self.assertGreaterEqual(sample_frequency[i], sample_frequency[i - 1])
            if i < 9:
                self.assertLessEqual(sample_frequency[i], sample_frequency[i + 1])

    # Reset every priority to 1 before inverting the pattern.
    replay_buffer.update_priorities(list(range(10)), [1] * 10)
    np.random.seed(12323423)

    # Phase 2: items >= 5 get value / 10 (close to 0),
    # items < 5 get value + 5.
    sample_frequency = [0] * 10
    for i in range(15 * 20):
        itr = iter(replay_buffer.as_dataset(sample_batch_size=sample_batch_size, beta=beta))
        for _ in range(100 // 20):
            mini_batch, metadata = next(itr)
            indices = metadata.ids.numpy()
            if i % 100 == 0:
                self.validate_data(mini_batch, indices)
            for idx in indices:
                sample_frequency[idx] += 1
            priorities = [idx / 10 if idx >= 5 else idx + 5 for idx in indices]
            replay_buffer.update_priorities(indices, priorities)

    for i in range(10):
        if i >= 5:
            # low-priority items (>= 5) should be picked at most 5% of the time
            self.assertLessEqual(sample_frequency[i], 15000 * 5 / 100)
        else:
            # high-priority items (< 5) should be picked between 10% and 25% of the time
            self.assertGreaterEqual(sample_frequency[i], 15000 * 10 / 100)
            self.assertLessEqual(sample_frequency[i], 15000 * 25 / 100)
            # a higher priority must yield at least as many picks as the
            # preceding item and at most as many as the following one;
            # item 0 has no predecessor (i - 1 would wrap around to item 9)
            if i > 0:
                self.assertGreaterEqual(sample_frequency[i], sample_frequency[i - 1])
            if i < 4:
                self.assertLessEqual(sample_frequency[i], sample_frequency[i + 1])