Ejemplo n.º 1
0
class SyncUniformExperienceReplayer(ExperienceReplayer):
    """Experience replayer backed by a single ``ReplayBuffer``.

    Meant for synchronous off-policy training, e.g. DDPG or SAC.
    """
    def __init__(self, experience_spec, batch_size):
        """
        Args:
            experience_spec (nested TensorSpec): spec of one experience item;
                ``observe()`` also uses it to infer the outer rank of ``exp``.
            batch_size (int): handed to ``ReplayBuffer`` as its second
                argument; reported back by the ``batch_size`` property as the
                buffer's ``num_environments``.
        """
        self._experience_spec = experience_spec
        self._buffer = ReplayBuffer(experience_spec, batch_size)
        self._data_iter = None

    @tf.function
    def observe(self, exp):
        """Add experience to the replay buffer.

        The layout of ``exp`` is judged by its outer rank relative to the
        experience spec:

        - rank 1, ``[env_batch_size, ...]``: the sync-driver case
          (``num_envs == 1`` and ``unroll_length == 1``); added directly.
        - rank 3, ``[learn_queue_cap, unroll_length, env_batch_size, ...]``:
          every ``[q, t]`` slice is added one after another.

        Args:
            exp (Experience): experience to store; ``exp.env_id`` picks the
                environment each row belongs to.

        Raises:
            ValueError: if the outer rank is neither 1 nor 3.
        """
        rank = get_outer_rank(exp, self._experience_spec)

        if rank == 1:
            self._buffer.add_batch(exp, exp.env_id)
        elif rank == 3:
            # Dynamic shape: the leading two dims may be unknown at trace time.
            outer_shape = tf.shape(exp.step_type)
            for queue_idx in tf.range(outer_shape[0]):
                for step_idx in tf.range(outer_shape[1]):
                    step_exp = tf.nest.map_structure(
                        lambda tensor: tensor[queue_idx, step_idx, ...], exp)
                    self._buffer.add_batch(step_exp, step_exp.env_id)
        else:
            raise ValueError("Unsupported outer rank %s of `exp`" % rank)

    def replay(self, sample_batch_size, mini_batch_length):
        """Sample a random batch of sequences from the buffer.

        Args:
            sample_batch_size (int): how many sequences to draw.
            mini_batch_length (int): length of every drawn sequence.
        Returns:
            Experience: batch-major experience of shape (B, T, ...).
        """
        buffer = self._buffer
        return buffer.get_batch(sample_batch_size, mini_batch_length)

    def replay_all(self):
        """Return whatever ``ReplayBuffer.gather_all()`` returns."""
        buffer = self._buffer
        return buffer.gather_all()

    def clear(self):
        """Empty the underlying replay buffer."""
        buffer = self._buffer
        buffer.clear()

    @property
    def batch_size(self):
        """Number of environments of the underlying replay buffer."""
        buffer = self._buffer
        return buffer.num_environments
Ejemplo n.º 2
0
    def test_replay_buffer(self, allow_multiprocess, with_replacement):
        """Exercise ReplayBuffer add/sample/gather_all/clear bookkeeping.

        Checks per-environment ``_current_size``/``_current_pos`` tracking,
        sampling failures when too little data is stored, per-env content of
        sampled batches, time-consistency of length-2 sequences, and
        ``gather_all()``/``clear()``.

        NOTE(review): ``with_replacement`` is never used in this body (it is
        not forwarded to ``ReplayBuffer``), so the with/without-replacement
        parameterizations run identically. Confirm whether it should be
        forwarded; the env-ordered indexing assertions below may then need
        adjusting.
        """
        replay_buffer = ReplayBuffer(data_spec=self.data_spec,
                                     num_environments=self.num_envs,
                                     max_length=self.max_length,
                                     allow_multiprocess=allow_multiprocess)

        batch1 = get_batch([0, 4, 7], self.dim, t=0, x=0.1)
        replay_buffer.add_batch(batch1, batch1.env_id)
        self.assertEqual(replay_buffer._current_size,
                         torch.tensor([1, 0, 0, 0, 1, 0, 0, 1]))
        self.assertEqual(replay_buffer._current_pos,
                         torch.tensor([1, 0, 0, 0, 1, 0, 0, 1]))
        # Only envs 0, 4 and 7 hold data so far; sampling is expected to fail.
        self.assertRaises(AssertionError, replay_buffer.get_batch, 8, 1)

        batch2 = get_batch([1, 2, 3, 5, 6], self.dim, t=0, x=0.2)
        replay_buffer.add_batch(batch2, batch2.env_id)
        self.assertEqual(replay_buffer._current_size,
                         torch.tensor([1, 1, 1, 1, 1, 1, 1, 1]))
        self.assertEqual(replay_buffer._current_pos,
                         torch.tensor([1, 1, 1, 1, 1, 1, 1, 1]))

        batch = replay_buffer.gather_all()
        self.assertEqual(list(batch.t.shape), [8, 1])
        # test that RingBuffer detaches gradients of inputs
        self.assertFalse(batch.x.requires_grad)

        # Every env holds a single step, so length-2 sequences cannot exist.
        self.assertRaises(AssertionError, replay_buffer.get_batch, 8, 2)
        # Result discarded: only checks that asking for 13 length-1
        # sequences does not raise.
        replay_buffer.get_batch(13, 1)[0]

        # The env_id indexing below assumes row i of the sampled batch comes
        # from env i -- TODO confirm this ordering guarantee of get_batch.
        batch = replay_buffer.get_batch(8, 1)[0]
        # squeeze the time dimension
        batch = alf.nest.map_structure(lambda bat: bat.squeeze(1), batch)
        bat1 = alf.nest.map_structure(lambda bat: bat[batch1.env_id], batch)
        bat2 = alf.nest.map_structure(lambda bat: bat[batch2.env_id], batch)
        self.assertEqual(bat1.env_id, batch1.env_id)
        self.assertEqual(bat1.x, batch1.x)
        self.assertEqual(bat1.t, batch1.t)
        self.assertEqual(bat2.env_id, batch2.env_id)
        self.assertEqual(bat2.x, batch2.x)
        self.assertEqual(bat2.t, batch2.t)

        # Only envs 0, 4, 7 advance: their size saturates at max_length while
        # _current_pos keeps counting up (it is not wrapped in this version).
        for t in range(1, 10):
            batch3 = get_batch([0, 4, 7], self.dim, t=t, x=0.3)
            j = t + 1
            s = min(t + 1, self.max_length)
            replay_buffer.add_batch(batch3, batch3.env_id)
            self.assertEqual(replay_buffer._current_size,
                             torch.tensor([s, 1, 1, 1, s, 1, 1, s]))
            self.assertEqual(replay_buffer._current_pos,
                             torch.tensor([j, 1, 1, 1, j, 1, 1, j]))

        batch2 = get_batch([1, 2, 3, 5, 6], self.dim, t=1, x=0.2)
        replay_buffer.add_batch(batch2, batch2.env_id)
        batch = replay_buffer.get_batch(8, 1)[0]
        # squeeze the time dimension
        batch = alf.nest.map_structure(lambda bat: bat.squeeze(1), batch)
        bat3 = alf.nest.map_structure(lambda bat: bat[batch3.env_id], batch)
        bat2 = alf.nest.map_structure(lambda bat: bat[batch2.env_id], batch)
        self.assertEqual(bat3.env_id, batch3.env_id)
        self.assertEqual(bat3.x, batch3.x)
        self.assertEqual(bat2.env_id, batch2.env_id)
        self.assertEqual(bat2.x, batch2.x)

        # Every env now holds at least 2 steps, so length-2 sampling works.
        batch = replay_buffer.get_batch(8, 2)[0]
        t2 = []
        t3 = []
        for t in range(2):
            batch_t = alf.nest.map_structure(lambda b: b[:, t], batch)
            bat3 = alf.nest.map_structure(lambda bat: bat[batch3.env_id],
                                          batch_t)
            bat2 = alf.nest.map_structure(lambda bat: bat[batch2.env_id],
                                          batch_t)
            t2.append(bat2.t)
            self.assertEqual(bat3.env_id, batch3.env_id)
            self.assertEqual(bat3.x, batch3.x)
            self.assertEqual(bat2.env_id, batch2.env_id)
            self.assertEqual(bat2.x, batch2.x)
            t3.append(bat3.t)

        # Test time consistency
        self.assertEqual(t2[0] + 1, t2[1])
        self.assertEqual(t3[0] + 1, t3[1])

        # Sampled sequences must be time-contiguous for several batch sizes.
        batch = replay_buffer.get_batch(128, 2)[0]
        self.assertEqual(batch.t[:, 0] + 1, batch.t[:, 1])
        self.assertEqual(list(batch.t.shape), [128, 2])

        batch = replay_buffer.get_batch(10, 2)[0]
        self.assertEqual(batch.t[:, 0] + 1, batch.t[:, 1])
        self.assertEqual(list(batch.t.shape), [10, 2])

        batch = replay_buffer.get_batch(4, 2)[0]
        self.assertEqual(batch.t[:, 0] + 1, batch.t[:, 1])
        self.assertEqual(list(batch.t.shape), [4, 2])

        # Test gather_all()
        # Exception because the size of all the environments are not same
        self.assertRaises(AssertionError, replay_buffer.gather_all)

        # Fill envs 1, 2, 3, 5, 6 until every env holds the same number of
        # steps, after which gather_all() succeeds.
        for t in range(2, 10):
            batch4 = get_batch([1, 2, 3, 5, 6], self.dim, t=t, x=0.4)
            replay_buffer.add_batch(batch4, batch4.env_id)
        batch = replay_buffer.gather_all()
        self.assertEqual(list(batch.t.shape), [8, 4])

        # Test clear()
        replay_buffer.clear()
        self.assertEqual(replay_buffer.total_size, 0)
Ejemplo n.º 3
0
class SyncExperienceReplayer(ExperienceReplayer):
    """Experience replayer for synchronous off-policy training.

    Thin wrapper around a ``ReplayBuffer`` holding ``batch_size``
    environments. Example algorithms: DDPG, SAC.
    """
    def __init__(self,
                 experience_spec,
                 batch_size,
                 max_length,
                 num_earliest_frames_ignored=0,
                 prioritized_sampling=False,
                 name="SyncExperienceReplayer"):
        """Create the replayer and its underlying ``ReplayBuffer``.

        Args:
            experience_spec (nested TensorSpec): spec describing a single
                item that can be stored in the replayer. ``observe()`` also
                uses it to infer the outer rank of incoming experience.
            batch_size (int): number of environments.
            max_length (int): the maximum number of items that can be stored
                for a single environment.
            num_earliest_frames_ignored (int): ignore the earliest so many
                frames when sampling from the buffer. This is typically
                required when FrameStack is used.
            prioritized_sampling (bool): use prioritized sampling if True.
            name (str): name passed to the underlying ``ReplayBuffer``.
        """
        super().__init__()
        self._experience_spec = experience_spec
        self._buffer = ReplayBuffer(
            experience_spec,
            batch_size,
            max_length=max_length,
            prioritized_sampling=prioritized_sampling,
            num_earliest_frames_ignored=num_earliest_frames_ignored,
            name=name)
        self._data_iter = None

    def observe(self, exp):
        """Store experience into the replay buffer.

        Args:
            exp (nested Tensor): experience to store; ``exp.env_id`` gives the
                environment of each row. The layout is judged by the outer
                rank of ``exp`` relative to the experience spec:

                - rank 1, ``[env_batch_size, ...]``: the sync-driver case
                  (``num_envs == 1`` and ``unroll_length == 1``); stored as is.
                - rank 3, ``[learn_queue_cap, unroll_length, env_batch_size,
                  ...]``: every ``[q, t]`` slice is stored in order.

        Raises:
            ValueError: if the outer rank is neither 1 nor 3.
        """
        outer_rank = alf.nest.utils.get_outer_rank(exp, self._experience_spec)

        if outer_rank == 1:
            self._buffer.add_batch(exp, exp.env_id)
        elif outer_rank == 3:
            # The shape is [learn_queue_cap, unroll_length, env_batch_size, ...]
            num_queued, unroll_length = exp.step_type.shape[:2]
            for q in range(num_queued):
                for t in range(unroll_length):
                    # The lambda is applied immediately, so binding the loop
                    # variables by closure is safe here.
                    bat = alf.nest.map_structure(lambda x: x[q, t, ...], exp)
                    self._buffer.add_batch(bat, bat.env_id)
        else:
            raise ValueError("Unsupported outer rank %s of `exp`" % outer_rank)

    def replay(self, sample_batch_size, mini_batch_length):
        """Get a random batch.

        Args:
            sample_batch_size (int): number of sequences
            mini_batch_length (int): the length of each sequence
        Returns:
            tuple:
                - nested Tensors: The samples. Its shapes are [batch_size, batch_length, ...]
                - BatchInfo: Information about the batch. Its shapes are [batch_size].
                    - env_ids: environment id for each sequence
                    - positions: starting position in the replay buffer for each sequence.
                    - importance_weights: importance weight divided by the average of
                        all non-zero importance weights in the buffer.

        """
        return self._buffer.get_batch(sample_batch_size, mini_batch_length)

    def replay_all(self):
        """Return whatever ``ReplayBuffer.gather_all()`` returns."""
        return self._buffer.gather_all()

    def clear(self):
        """Empty the underlying replay buffer."""
        self._buffer.clear()

    def update_priority(self, env_ids, positions, priorities):
        """Update the priorities for the given experiences.

        Args:
            env_ids (Tensor): 1-D int64 Tensor.
            positions (Tensor): 1-D int64 Tensor with same shape as ``env_ids``.
                These positions should be obtained from the BatchInfo returned
                by ``replay()`` (i.e. ``ReplayBuffer.get_batch()``).
            priorities (Tensor): new priority for each ``(env_id, position)``
                pair; presumably 1-D with the same shape as ``env_ids`` --
                confirm against ``ReplayBuffer.update_priority``.
        """
        self._buffer.update_priority(env_ids, positions, priorities)

    @property
    def batch_size(self):
        """Number of environments of the underlying replay buffer."""
        return self._buffer.num_environments

    @property
    def total_size(self):
        """``total_size`` of the underlying replay buffer."""
        return self._buffer.total_size

    @property
    def replay_buffer(self):
        """The underlying ``ReplayBuffer`` instance."""
        return self._buffer
Ejemplo n.º 4
0
    def test_replay_buffer(self):
        """Exercise ReplayBuffer bookkeeping, sampling and gather_all().

        Covers per-environment ``_current_size``/``_current_pos`` tracking
        (positions wrap modulo ``max_length``), sampling failures when the
        requested length exceeds what is stored, per-env content and
        time-consistency of sampled sequences, and ``gather_all()`` on both
        unequally filled and cyclically wrapped buffers.
        """
        dim = 20
        max_length = 4
        num_envs = 8
        data_spec = DataItem(
            env_id=tf.TensorSpec(shape=(), dtype=tf.int32),
            x=tf.TensorSpec(shape=(dim, ), dtype=tf.float32),
            t=tf.TensorSpec(shape=(), dtype=tf.int32))

        replay_buffer = ReplayBuffer(
            data_spec=data_spec,
            num_environments=num_envs,
            max_length=max_length)

        def _get_batch(env_ids, t, x):
            # Build a DataItem for ``env_ids`` at time step ``t``; row i of
            # ``x`` is ``x * i * [0, 1, ..., dim - 1]``.
            batch_size = len(env_ids)
            x = (x * tf.expand_dims(tf.range(batch_size, dtype=tf.float32), 1)
                 * tf.expand_dims(tf.range(dim, dtype=tf.float32), 0))
            return DataItem(
                env_id=tf.constant(env_ids),
                x=x,
                t=t * tf.ones((batch_size, ), tf.int32))

        batch1 = _get_batch([0, 4, 7], t=0, x=0.1)
        replay_buffer.add_batch(batch1, batch1.env_id)
        self.assertArrayEqual(replay_buffer._current_size,
                              [1, 0, 0, 0, 1, 0, 0, 1])
        self.assertArrayEqual(replay_buffer._current_pos,
                              [1, 0, 0, 0, 1, 0, 0, 1])
        # Only envs 0, 4 and 7 hold data so far; sampling is expected to fail.
        with self.assertRaises(tf.errors.InvalidArgumentError):
            replay_buffer.get_batch(8, 1)

        batch2 = _get_batch([1, 2, 3, 5, 6], t=0, x=0.2)
        replay_buffer.add_batch(batch2, batch2.env_id)
        self.assertArrayEqual(replay_buffer._current_size,
                              [1, 1, 1, 1, 1, 1, 1, 1])
        self.assertArrayEqual(replay_buffer._current_pos,
                              [1, 1, 1, 1, 1, 1, 1, 1])

        batch = replay_buffer.gather_all()
        self.assertEqual(batch.t.shape, [8, 1])

        # Every env holds a single step, so length-2 sequences cannot exist.
        with self.assertRaises(tf.errors.InvalidArgumentError):
            replay_buffer.get_batch(8, 2)
        # Must not share the assertRaises block above: the first raising call
        # would make it unreachable. Asking for more length-1 sequences than
        # stored items is expected to succeed.
        replay_buffer.get_batch(13, 1)

        batch = replay_buffer.get_batch(8, 1)
        # squeeze the time dimension
        batch = tf.nest.map_structure(lambda bat: tf.squeeze(bat, axis=1),
                                      batch)
        bat1 = tf.nest.map_structure(
            lambda bat: tf.gather(bat, batch1.env_id, axis=0), batch)
        bat2 = tf.nest.map_structure(
            lambda bat: tf.gather(bat, batch2.env_id, axis=0), batch)
        self.assertArrayEqual(bat1.env_id, batch1.env_id)
        self.assertArrayEqual(bat1.x, batch1.x)
        self.assertArrayEqual(bat1.t, batch1.t)
        self.assertArrayEqual(bat2.env_id, batch2.env_id)
        self.assertArrayEqual(bat2.x, batch2.x)
        self.assertArrayEqual(bat2.t, batch2.t)

        # Only envs 0, 4, 7 advance: size saturates at max_length and the
        # write position wraps around modulo max_length.
        for t in range(1, 10):
            batch3 = _get_batch([0, 4, 7], t=t, x=0.3)
            j = (t + 1) % max_length
            s = min(t + 1, max_length)
            replay_buffer.add_batch(batch3, batch3.env_id)
            self.assertArrayEqual(replay_buffer._current_size,
                                  [s, 1, 1, 1, s, 1, 1, s])
            self.assertArrayEqual(replay_buffer._current_pos,
                                  [j, 1, 1, 1, j, 1, 1, j])

        batch2 = _get_batch([1, 2, 3, 5, 6], t=1, x=0.2)
        replay_buffer.add_batch(batch2, batch2.env_id)
        batch = replay_buffer.get_batch(8, 1)
        # squeeze the time dimension
        batch = tf.nest.map_structure(lambda bat: tf.squeeze(bat, axis=1),
                                      batch)
        bat3 = tf.nest.map_structure(
            lambda bat: tf.gather(bat, batch3.env_id, axis=0), batch)
        bat2 = tf.nest.map_structure(
            lambda bat: tf.gather(bat, batch2.env_id, axis=0), batch)
        self.assertArrayEqual(bat3.env_id, batch3.env_id)
        self.assertArrayEqual(bat3.x, batch3.x)
        self.assertArrayEqual(bat2.env_id, batch2.env_id)
        self.assertArrayEqual(bat2.x, batch2.x)

        # Every env now holds at least 2 steps, so length-2 sampling works.
        batch = replay_buffer.get_batch(8, 2)
        t2 = []
        t3 = []
        for t in range(2):
            batch_t = tf.nest.map_structure(lambda b: b[:, t], batch)
            bat3 = tf.nest.map_structure(
                lambda bat: tf.gather(bat, batch3.env_id, axis=0), batch_t)
            bat2 = tf.nest.map_structure(
                lambda bat: tf.gather(bat, batch2.env_id, axis=0), batch_t)
            t2.append(bat2.t)
            self.assertArrayEqual(bat3.env_id, batch3.env_id)
            self.assertArrayEqual(bat3.x, batch3.x)
            self.assertArrayEqual(bat2.env_id, batch2.env_id)
            self.assertArrayEqual(bat2.x, batch2.x)
            t3.append(bat3.t)

        # Test time consistency
        self.assertArrayEqual(t2[0] + 1, t2[1])
        self.assertArrayEqual(t3[0] + 1, t3[1])

        # Sampled sequences must be time-contiguous for several batch sizes.
        batch = replay_buffer.get_batch(128, 2)
        self.assertArrayEqual(batch.t[:, 0] + 1, batch.t[:, 1])
        self.assertEqual(batch.t.shape, [128, 2])

        batch = replay_buffer.get_batch(10, 2)
        self.assertArrayEqual(batch.t[:, 0] + 1, batch.t[:, 1])
        self.assertEqual(batch.t.shape, [10, 2])

        batch = replay_buffer.get_batch(4, 2)
        self.assertArrayEqual(batch.t[:, 0] + 1, batch.t[:, 1])
        self.assertEqual(batch.t.shape, [4, 2])

        # Test gather_all(): envs currently hold different numbers of steps,
        # so gathering everything must fail.
        with self.assertRaises(tf.errors.InvalidArgumentError):
            replay_buffer.gather_all()

        # Fill envs 1, 2, 3, 5, 6 until every env is at max_length.
        for t in range(2, 10):
            batch4 = _get_batch([1, 2, 3, 5, 6], t=t, x=0.4)
            replay_buffer.add_batch(batch4, batch4.env_id)
        batch = replay_buffer.gather_all()
        self.assertEqual(batch.t.shape, [8, 4])

        # Test cyclic gather_all():
        replay_buffer.clear()

        # regular slice
        for t in range(2, 4):
            batch = _get_batch([0, 1, 2, 3, 4, 5, 6, 7], t=t, x=0.4)
            replay_buffer.add_batch(batch, batch.env_id)
        batch = replay_buffer.gather_all()
        self.assertEqual(batch.t.shape, [8, 2])
        self.assertArrayEqual(batch.t, tf.constant([[2, 3]] * 8))

        # slice that includes everything in the buffer
        for t in range(4, 6):
            batch = _get_batch([0, 1, 2, 3, 4, 5, 6, 7], t=t, x=0.4)
            replay_buffer.add_batch(batch, batch.env_id)
        batch = replay_buffer.gather_all()
        self.assertEqual(batch.t.shape, [8, 4])
        self.assertArrayEqual(batch.t, tf.constant([[2, 3, 4, 5]] * 8))

        # slice that starts from the middle and includes everything
        for t in range(6, 8):
            batch = _get_batch([0, 1, 2, 3, 4, 5, 6, 7], t=t, x=0.4)
            replay_buffer.add_batch(batch, batch.env_id)
        batch = replay_buffer.gather_all()
        self.assertEqual(batch.t.shape, [8, 4])
        self.assertArrayEqual(batch.t, tf.constant([[4, 5, 6, 7]] * 8))

        # slice that starts from the first and includes everything
        for t in range(8, 10):
            batch = _get_batch([0, 1, 2, 3, 4, 5, 6, 7], t=t, x=0.4)
            replay_buffer.add_batch(batch, batch.env_id)
        batch = replay_buffer.gather_all()
        self.assertEqual(batch.t.shape, [8, 4])
        self.assertArrayEqual(batch.t, tf.constant([[6, 7, 8, 9]] * 8))