Example #1
import numpy as np
import tensorflow as tf
from tf_agents import specs

# Project-local imports, module paths assumed from context:
# TFReplayBufferAbstract, TfPrioritizedReplayBuffer, build_tf_trajectory,
# and TrainingMetricTracker come from this codebase's own modules.


class TFPrioritizedReplayBuffer(TFReplayBufferAbstract):
    _curr_beta = None

    def __init__(self, collect_data_spec, alpha, beta, training_iterations_num):
        """
        Store replay buffer params

        Params:
            collect_data_spec: spec of the data to be added to the buffer
            alpha: exponent that controls how strongly sampling favors
                high-priority experiences (0 = uniform, 1 = fully proportional)
            beta: initial exponent of the importance-sampling weight correction,
                annealed toward 1.0 over the course of training
            training_iterations_num: total number of training iterations, used
                to anneal beta
        """
        super().__init__(collect_data_spec)
        self._beta = beta
        self._alpha = alpha
        self._training_iterations_total_num = training_iterations_num
        self._metric_tracker = None

    def _init_replay_buffer(self, batch_size, traj_spec):

        buffer_config = {
            "batch_size": batch_size,
            "data_spec": traj_spec,
            "max_length": 1,
            "alpha": self._alpha
        }
        tf.compat.v2.summary.scalar(name="replay_buffer_size", data=batch_size)
        self._replay_buffer = TfPrioritizedReplayBuffer(**buffer_config)

    def add_batch(self, traj_dict):
        """
        add a trajectory to the replay buffer

        Params
            traj (dict[dim]:numpy): a dict of tensors representing the trajectory to be added it to the replay buffer
        """
        collect_spec_dict = self.collect_data_spec._asdict()
        traj_tf, traj_spec = build_tf_trajectory(traj_dict, collect_spec_dict)

        if not self._replay_buffer:
            batch_size = len(traj_dict["observation"])
            self._init_replay_buffer(batch_size, traj_spec)

        self._replay_buffer.add_batch(traj_tf)

    def get_batch(self, batch_size):
        traj, metadata = self._replay_buffer.get_next(sample_batch_size=batch_size, beta=self._curr_beta)
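        # metadata.ids are the sampled slot indices (needed later to update priorities);
        # metadata.probabilities are the sampling probabilities behind the IS weights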

        self._metric_tracker.add_batch_weights(metadata.probabilities)
        self._metric_tracker.add_batch_indices(metadata.ids)

        return traj, metadata

    def pre_process(self, curr_iter):

        if not self._metric_tracker:
            self._metric_tracker = TrainingMetricTracker(self._training_iterations_total_num)

        # compute the beta that will be used when computing the importance sampling weights
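        # (a common PER schedule, assumed here, anneals beta linearly toward 1.0:
        # beta_t = beta_0 + (1.0 - beta_0) * t / T; the exact formula lives in compute_beta)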
        self._curr_beta = self._replay_buffer.compute_beta(self._beta, curr_iter, self._training_iterations_total_num)
        # add important data to the metric tracker
        self._metric_tracker.latest_iteration = curr_iter
        self._metric_tracker.latest_beta = self._curr_beta

    def post_process(self, traj_meta, loss_info, curr_iter):
        indices = traj_meta.ids.numpy()
        # get the loss of every experience used during training; it is stored in DqnLossInfo
        td_loss = loss_info[1].td_loss.numpy()

        # make sure the td loss array has the same shape as the batch of sampled indices
        if td_loss.shape != indices.shape:
            raise ValueError("Expected the shape of the loss '%s' to be the same as the "
                             "shape of the indices '%s'" % (str(td_loss.shape), str(indices.shape)))

        # update the prioritized replay buffer
        self._replay_buffer.update_priorities(indices, td_loss)

        self._metric_tracker.log_partial_metrics()
        self._metric_tracker.latest_loss_info = loss_info

        if curr_iter == self._training_iterations_total_num:
            self._metric_tracker.log_summary_metrics()
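
A minimal usage sketch (an addition for illustration, not from the original source)
showing how the wrapper's hooks are meant to interleave with a training loop;
`agent` and `collect_experience` are hypothetical stand-ins:

    buffer = TFPrioritizedReplayBuffer(agent.collect_data_spec,
                                       alpha=0.6, beta=0.4,
                                       training_iterations_num=1000)
    for curr_iter in range(1, 1001):
        traj_dict = collect_experience()   # hypothetical helper: dict of numpy arrays
        buffer.add_batch(traj_dict)
        buffer.pre_process(curr_iter)      # anneals beta, prepares the metric tracker
        traj, metadata = buffer.get_batch(batch_size=64)
        loss_info = agent.train(traj)      # TF-Agents agents return LossInfo(loss, extra)
        buffer.post_process(metadata, loss_info, curr_iter)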

Example #2

    def test_prioritized_replay_buffer_as_dataset(self):
        np.random.seed(123)

        buffer_batch_size = 10
        alpha = 0.6
        spec = specs.TensorSpec([], tf.int32, 'action')
        replay_buffer = TfPrioritizedReplayBuffer(spec,
                                                  batch_size=buffer_batch_size,
                                                  max_length=1,
                                                  alpha=alpha)

        # newly added experiences are expected to get the default priority of 1.0
        expected_priority = np.ones((buffer_batch_size,), dtype=np.float32)

        experience = []
        experience_shape = (1, )
        for k in range(buffer_batch_size):
            experience.append(np.full(experience_shape, k, dtype=np.int32))

        tf_experience = tf.convert_to_tensor(experience)
        replay_buffer.add_batch(tf_experience)
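        # newly added experiences are expected to receive the default priority of 1.0
        # (see expected_priority above), so each gets sampled before being down-weighted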

        sample_batch_size = 10
        beta = 0.4

        sample_frequency = [0 for _ in range(10)]
        for i in range(15 * 3):
            ds = replay_buffer.as_dataset(sample_batch_size=sample_batch_size,
                                          beta=beta)
            itr = iter(ds)
            for j in range(int(100 / 3)):
                mini_batch, metadata = next(itr)
                indices_tf = metadata.ids
                indices = indices_tf.numpy()
                if i % 100 == 0:
                    self.validate_data(mini_batch, indices)

                for idx in indices:
                    sample_frequency[idx] += 1

                # set the loss of numbers larger than 5 equal to the number itself;
                # set the loss of numbers smaller than or equal to 5 close to 0
                priorities = [n if n > 5 else n / 10 for n in indices]

                replay_buffer.update_priorities(indices, priorities)

        # with proportional sampling, higher-priority items should be drawn more
        # often, so the observed frequencies should increase with the index
        for i in range(10):
            if i <= 5:
                # numbers smaller than or equal to 5 should each be picked at most
                # 5% of the time
                self.assertLessEqual(sample_frequency[i], 15000 * 5 / 100)
            else:
                # numbers larger than 5 should each be picked between 15% and 30%
                # of the time
                self.assertGreaterEqual(sample_frequency[i], 15000 * 15 / 100)
                self.assertLessEqual(sample_frequency[i], 15000 * 30 / 100)

                # every number larger than 5 should be selected more often than the
                # number that precedes it and less often than the number that follows it
                self.assertGreaterEqual(sample_frequency[i],
                                        sample_frequency[i - 1])
                if i < 9:
                    self.assertLessEqual(sample_frequency[i],
                                         sample_frequency[i + 1])

        indices = [i for i in range(10)]
        priorities = [1 for _ in range(10)]

        replay_buffer.update_priorities(indices, priorities)
        np.random.seed(12323423)
        # set the loss of numbers larger than or equal to 5 close to 0;
        # set the loss of numbers smaller than 5 to the number itself + 5
        sample_frequency = [0 for _ in range(10)]
        for i in range(15 * 20):
            ds = replay_buffer.as_dataset(sample_batch_size=sample_batch_size,
                                          beta=beta)
            itr = iter(ds)
            for j in range(int(100 / 20)):
                mini_batch, metadata = next(itr)
                indices_tf = metadata.ids

                indices = indices_tf.numpy()
                if i % 100 == 0:
                    self.validate_data(mini_batch, indices)

                for idx in indices:
                    sample_frequency[idx] += 1

                # set the loss of numbers larger than or equal to 5 close to 0;
                # set the loss of numbers smaller than 5 to the number itself + 5
                priorities = [n / 10 if n >= 5 else n + 5 for n in indices]
                replay_buffer.update_priorities(indices, priorities)

        for i in range(10):
            if i >= 5:
                # numbers larger than or equal to 5 should each be picked at most
                # 5% of the time
                self.assertLessEqual(sample_frequency[i], 15000 * 5 / 100)
            else:
                # numbers smaller than 5 should each be picked between 10% and 25%
                # of the time
                self.assertGreaterEqual(sample_frequency[i], 15000 * 10 / 100)
                self.assertLessEqual(sample_frequency[i], 15000 * 25 / 100)

                # every number smaller than 5 should be selected more often than the
                # number that precedes it and less often than the number that follows it
                self.assertGreaterEqual(sample_frequency[i],
                                        sample_frequency[i - 1])
                if i < 4:
                    self.assertLessEqual(sample_frequency[i],
                                         sample_frequency[i + 1])
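
For intuition, the asserted bounds match what proportional prioritization predicts:
an item with priority p_i is sampled with probability p_i ** alpha / sum_j (p_j ** alpha).
A quick sanity check against the first phase's final priorities (assuming, for
illustration, that the buffer stores them exactly as passed to update_priorities):

    alpha = 0.6
    priorities = [n / 10 if n <= 5 else n for n in range(10)]  # 0, .1, ..., .5, 6, 7, 8, 9
    scaled = [p ** alpha for p in priorities]
    total = sum(scaled)
    for n, s in enumerate(scaled):
        print(n, round(s / total, 3))
    # indices 6..9 land around 0.19-0.24, inside the asserted 15%-30% band,
    # while indices 0..5 stay at or near zero.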