Example #1
class SyncUniformExperienceReplayer(ExperienceReplayer):
    """
    For synchronous off-policy training.

    Example algorithms: DDPG, SAC
    """
    def __init__(self, experience_spec, batch_size):
        self._experience_spec = experience_spec
        self._buffer = ReplayBuffer(experience_spec, batch_size)
        self._data_iter = None

    @tf.function
    def observe(self, exp):
        """Store one batch of experience into replay buffer.

        Args:
            exp (Experience): input experience to be stored.

        For the sync driver, `exp` has the shape (`env_batch_size`, ...)
        with `num_envs`==1 and `unroll_length`==1.
        """
        outer_rank = get_outer_rank(exp, self._experience_spec)

        if outer_rank == 1:
            self._buffer.add_batch(exp, exp.env_id)
        elif outer_rank == 3:
            # The shape is [learn_queue_cap, unroll_length, env_batch_size, ...]
            for q in tf.range(tf.shape(exp.step_type)[0]):
                for t in tf.range(tf.shape(exp.step_type)[1]):
                    bat = tf.nest.map_structure(lambda x: x[q, t, ...], exp)
                    self._buffer.add_batch(bat, bat.env_id)
        else:
            raise ValueError("Unsupported outer rank %s of `exp`" % outer_rank)

    def replay(self, sample_batch_size, mini_batch_length):
        """Get a random batch.

        Args:
            sample_batch_size (int): number of sequences
            mini_batch_length (int): the length of each sequence
        Returns:
            Experience: experience batch in batch major (B, T, ...)
        """
        return self._buffer.get_batch(sample_batch_size, mini_batch_length)

    def replay_all(self):
        return self._buffer.gather_all()

    def clear(self):
        self._buffer.clear()

    @property
    def batch_size(self):
        return self._buffer.num_environments
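The `outer_rank == 3` branch of `observe` above flattens a `[learn_queue_cap, unroll_length, env_batch_size, ...]` experience into per-step batches before storing them. Below is a minimal standalone sketch of that slicing with a hypothetical toy nest (a plain dict, not the ALF `Experience` type), just to show what each `add_batch` call receives:

import tensorflow as tf

# Toy nest shaped [learn_queue_cap=2, unroll_length=3, env_batch_size=4, ...].
exp = dict(step_type=tf.zeros([2, 3, 4], tf.int32),
           reward=tf.zeros([2, 3, 4, 1]))

slices = []
for q in tf.range(tf.shape(exp['step_type'])[0]):
    for t in tf.range(tf.shape(exp['step_type'])[1]):
        # Each slice keeps only the env_batch_size dimension (outer rank 1),
        # which is the shape `add_batch` expects.
        slices.append(tf.nest.map_structure(lambda x: x[q, t, ...], exp))

print(len(slices))                 # 6 per-step batches
print(slices[0]['reward'].shape)   # (4, 1)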
Example #2
    def test_recent_data_and_without_replacement(self):
        num_envs = 4
        max_length = 100
        replay_buffer = ReplayBuffer(data_spec=self.data_spec,
                                     num_environments=num_envs,
                                     max_length=max_length,
                                     with_replacement=False,
                                     recent_data_ratio=0.5,
                                     recent_data_steps=4)
        replay_buffer.add_batch(get_batch([0, 1, 2, 3], self.dim, t=0, x=0.))
        batch, info = replay_buffer.get_batch(4, 1)
        self.assertEqual(info.env_ids, torch.tensor([0, 1, 2, 3]))

        replay_buffer.add_batch(get_batch([0, 1, 2, 3], self.dim, t=1, x=1.0))
        batch, info = replay_buffer.get_batch(8, 1)
        self.assertEqual(info.env_ids, torch.tensor([0, 1, 2, 3] * 2))

        for t in range(2, 32):
            replay_buffer.add_batch(get_batch([0, 1, 2, 3], self.dim, t=t,
                                              x=t))
        batch, info = replay_buffer.get_batch(32, 1)
        self.assertEqual(info.env_ids[16:], torch.tensor([0, 1, 2, 3] * 4))
        # The first half is from recent data
        self.assertEqual(info.env_ids[:16], torch.tensor([0, 1, 2, 3] * 4))
        self.assertEqual(
            info.positions[:16],
            torch.tensor([28] * 4 + [29] * 4 + [30] * 4 + [31] * 4))
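A quick back-of-the-envelope check of the assertions above, assuming `recent_data_ratio=0.5` and `recent_data_steps=4` mean that half of each batch is drawn, without replacement, from the last 4 positions of every environment:

batch_size = 32
recent_fraction = 0.5
recent_steps, num_envs = 4, 4

recent_samples = int(batch_size * recent_fraction)   # 16 samples reserved for recent data
recent_pool = recent_steps * num_envs                # last 4 positions x 4 envs = 16 items
assert recent_samples == recent_pool
# Sampling 16 items from a pool of 16 without replacement returns every recent item,
# which is why positions[:16] are exactly [28]*4 + [29]*4 + [30]*4 + [31]*4.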
Example #3
    def test_num_earliest_frames_ignored_prioritized(self):
        replay_buffer = ReplayBuffer(data_spec=self.data_spec,
                                     num_environments=self.num_envs,
                                     max_length=self.max_length,
                                     num_earliest_frames_ignored=2,
                                     keep_episodic_info=False,
                                     prioritized_sampling=True)

        batch1 = get_batch([1], self.dim, x=0.25, t=0)
        replay_buffer.add_batch(batch1, batch1.env_id)
        # not enough data
        self.assertRaises(AssertionError, replay_buffer.get_batch, 1, 1)

        batch2 = get_batch([1], self.dim, x=0.25, t=1)
        replay_buffer.add_batch(batch2, batch2.env_id)
        # not enough data
        self.assertRaises(AssertionError, replay_buffer.get_batch, 1, 1)

        batch3 = get_batch([1], self.dim, x=0.25, t=2)
        replay_buffer.add_batch(batch3, batch3.env_id)
        for _ in range(10):
            batch, batch_info = replay_buffer.get_batch(1, 1)
            self.assertEqual(batch_info.env_ids,
                             torch.tensor([1], dtype=torch.int64))
            self.assertEqual(batch_info.importance_weights, 1.)
            self.assertEqual(batch_info.importance_weights, torch.tensor([1.]))
            self.assertEqual(batch.t, torch.tensor([[2]]))
Example #4
    def test_num_earliest_frames_ignored_uniform(self):
        num_envs = 4
        max_length = 100
        replay_buffer = ReplayBuffer(data_spec=self.data_spec,
                                     num_environments=num_envs,
                                     max_length=max_length,
                                     keep_episodic_info=False,
                                     num_earliest_frames_ignored=2)

        replay_buffer.add_batch(get_batch([0, 1, 2, 3], self.dim, t=0, x=0.))
        # not enough data
        self.assertRaises(AssertionError, replay_buffer.get_batch, 1, 1)

        replay_buffer.add_batch(get_batch([0, 1, 2, 3], self.dim, t=1, x=0.))
        # not enough data
        self.assertRaises(AssertionError, replay_buffer.get_batch, 1, 1)

        replay_buffer.add_batch(get_batch([0, 1, 2, 3], self.dim, t=2, x=0.))
        for _ in range(10):
            batch, batch_info = replay_buffer.get_batch(1, 1)
            self.assertEqual(batch.t, torch.tensor([[2]]))
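A short note on the arithmetic this test exercises, assuming `num_earliest_frames_ignored` simply removes the oldest positions from the sampling range:

stored_steps = 3                    # t = 0, 1, 2 have been added per environment
num_earliest_frames_ignored = 2
eligible = stored_steps - num_earliest_frames_ignored   # only one eligible position per env
assert eligible == 1                # so every sampled batch has t == 2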
Example #5
    def test_frame_stacker(self, stack_axis=0):
        data_spec = DataItem(step_type=alf.TensorSpec((), dtype=torch.int32),
                             observation=dict(scalar=alf.TensorSpec(()),
                                              vector=alf.TensorSpec((7, )),
                                              matrix=alf.TensorSpec((5, 6)),
                                              tensor=alf.TensorSpec(
                                                  (2, 3, 4))))
        replay_buffer = ReplayBuffer(data_spec=data_spec,
                                     num_environments=2,
                                     max_length=1024,
                                     num_earliest_frames_ignored=2)
        frame_stacker = FrameStacker(
            data_spec.observation,
            stack_size=3,
            stack_axis=stack_axis,
            fields=['scalar', 'vector', 'matrix', 'tensor'])

        new_spec = frame_stacker.transformed_observation_spec
        self.assertEqual(new_spec['scalar'].shape, (3, ))
        self.assertEqual(new_spec['vector'].shape, (21, ))
        if stack_axis == -1:
            self.assertEqual(new_spec['matrix'].shape, (5, 18))
            self.assertEqual(new_spec['tensor'].shape, (2, 3, 12))
        elif stack_axis == 0:
            self.assertEqual(new_spec['matrix'].shape, (15, 6))
            self.assertEqual(new_spec['tensor'].shape, (6, 3, 4))

        def _step_type(t, period):
            if t % period == 0:
                return StepType.FIRST
            if t % period == period - 1:
                return StepType.LAST
            return StepType.MID

        observation = alf.nest.map_structure(
            lambda spec: spec.randn((1000, 2)), data_spec.observation)
        state = common.zero_tensor_from_nested_spec(frame_stacker.state_spec,
                                                    2)

        def _get_stacked_data(t, b):
            if stack_axis == -1:
                return dict(scalar=observation['scalar'][t, b],
                            vector=observation['vector'][t, b].reshape(-1),
                            matrix=observation['matrix'][t, b].transpose(
                                0, 1).reshape(5, 18),
                            tensor=observation['tensor'][t, b].permute(
                                1, 2, 0, 3).reshape(2, 3, 12))
            elif stack_axis == 0:
                return dict(scalar=observation['scalar'][t, b],
                            vector=observation['vector'][t, b].reshape(-1),
                            matrix=observation['matrix'][t, b].reshape(15, 6),
                            tensor=observation['tensor'][t,
                                                         b].reshape(6, 3, 4))

        def _check_equal(stacked, expected, b):
            self.assertEqual(stacked['scalar'][b], expected['scalar'])
            self.assertEqual(stacked['vector'][b], expected['vector'])
            self.assertEqual(stacked['matrix'][b], expected['matrix'])
            self.assertEqual(stacked['tensor'][b], expected['tensor'])

        for t in range(1000):
            batch = DataItem(
                step_type=torch.tensor([_step_type(t, 17),
                                        _step_type(t, 22)]),
                observation=alf.nest.map_structure(lambda x: x[t],
                                                   observation))
            replay_buffer.add_batch(batch)
            timestep, state = frame_stacker.transform_timestep(batch, state)
            if t == 0:
                for b in (0, 1):
                    expected = _get_stacked_data([0, 0, 0], b)
                    _check_equal(timestep.observation, expected, b)
            if t == 1:
                for b in (0, 1):
                    expected = _get_stacked_data([0, 0, 1], b)
                    _check_equal(timestep.observation, expected, b)
            if t == 2:
                for b in (0, 1):
                    expected = _get_stacked_data([0, 1, 2], b)
                    _check_equal(timestep.observation, expected, b)
            if t == 16:
                for b in (0, 1):
                    expected = _get_stacked_data([14, 15, 16], b)
                    _check_equal(timestep.observation, expected, b)
            if t == 17:
                for b, ts in ((0, [17, 17, 17]), (1, [15, 16, 17])):
                    expected = _get_stacked_data(ts, b)
                    _check_equal(timestep.observation, expected, b)
            if t == 18:
                for b, ts in ((0, [17, 17, 18]), (1, [16, 17, 18])):
                    expected = _get_stacked_data(ts, b)
                    _check_equal(timestep.observation, expected, b)
            if t == 22:
                for b, ts in ((0, [20, 21, 22]), (1, [22, 22, 22])):
                    expected = _get_stacked_data(ts, b)
                    _check_equal(timestep.observation, expected, b)

        batch_info = BatchInfo(env_ids=torch.tensor([0, 1, 0, 1],
                                                    dtype=torch.int64),
                               positions=torch.tensor([0, 1, 18, 22],
                                                      dtype=torch.int64))

        # [4, 2, ...]
        experience = replay_buffer.get_field(
            '', batch_info.env_ids.unsqueeze(-1),
            batch_info.positions.unsqueeze(-1) + torch.arange(2))
        experience = experience._replace(batch_info=batch_info,
                                         replay_buffer=replay_buffer)
        experience = frame_stacker.transform_experience(experience)
        expected = _get_stacked_data([0, 0, 0], 0)
        _check_equal(experience.observation, expected, (0, 0))
        expected = _get_stacked_data([0, 0, 1], 0)
        _check_equal(experience.observation, expected, (0, 1))

        expected = _get_stacked_data([0, 0, 1], 1)
        _check_equal(experience.observation, expected, (1, 0))
        expected = _get_stacked_data([0, 1, 2], 1)
        _check_equal(experience.observation, expected, (1, 1))

        expected = _get_stacked_data([17, 17, 18], 0)
        _check_equal(experience.observation, expected, (2, 0))
        expected = _get_stacked_data([17, 18, 19], 0)
        _check_equal(experience.observation, expected, (2, 1))

        expected = _get_stacked_data([22, 22, 22], 1)
        _check_equal(experience.observation, expected, (3, 0))
        expected = _get_stacked_data([22, 22, 23], 1)
        _check_equal(experience.observation, expected, (3, 1))
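The shape checks at the top of this test follow from how `stack_size=3` frames are concatenated along `stack_axis`. A standalone sketch with plain torch (toy tensors, not the `FrameStacker` API) reproduces the two conventions:

import torch

frames = [torch.randn(5, 6) for _ in range(3)]   # stack_size = 3 copies of the "matrix" field

stacked_axis0 = torch.cat(frames, dim=0)       # stack_axis == 0  -> shape (15, 6)
stacked_axis_last = torch.cat(frames, dim=-1)  # stack_axis == -1 -> shape (5, 18)

assert stacked_axis0.shape == (15, 6)
assert stacked_axis_last.shape == (5, 18)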
Example #6
class SyncExperienceReplayer(ExperienceReplayer):
    """
    For synchronous off-policy training.

    Example algorithms: DDPG, SAC
    """
    def __init__(self,
                 experience_spec,
                 batch_size,
                 max_length,
                 num_earliest_frames_ignored=0,
                 prioritized_sampling=False,
                 name="SyncExperienceReplayer"):
        """Create a ReplayBuffer.

        Args:
            experience_spec (nested TensorSpec): spec describing a single item
                that can be stored in the replayer.
            batch_size (int): number of environments.
            max_length (int): The maximum number of items that can be stored
                for a single environment.
            num_earliest_frames_ignored (int): ignore this many of the earliest
                frames when sampling from the buffer. This is typically required
                when ``FrameStacker`` is used.
            prioritized_sampling (bool): Use prioritized sampling if this is True.
        """
        super().__init__()
        self._experience_spec = experience_spec
        self._buffer = ReplayBuffer(
            experience_spec,
            batch_size,
            max_length=max_length,
            prioritized_sampling=prioritized_sampling,
            num_earliest_frames_ignored=num_earliest_frames_ignored,
            name=name)
        self._data_iter = None

    def observe(self, exp):
        """
        For the sync driver, `exp` has the shape (`env_batch_size`, ...)
        with `num_envs`==1 and `unroll_length`==1.
        """
        outer_rank = alf.nest.utils.get_outer_rank(exp, self._experience_spec)

        if outer_rank == 1:
            self._buffer.add_batch(exp, exp.env_id)
        elif outer_rank == 3:
            # The shape is [learn_queue_cap, unroll_length, env_batch_size, ...]
            for q in range(exp.step_type.shape[0]):
                for t in range(exp.step_type.shape[1]):
                    bat = alf.nest.map_structure(lambda x: x[q, t, ...], exp)
                    self._buffer.add_batch(bat, bat.env_id)
        else:
            raise ValueError("Unsupported outer rank %s of `exp`" % outer_rank)

    def replay(self, sample_batch_size, mini_batch_length):
        """Get a random batch.

        Args:
            sample_batch_size (int): number of sequences
            mini_batch_length (int): the length of each sequence
        Returns:
            tuple:
                - nested Tensors: The samples. Its shapes are [batch_size, batch_length, ...]
                - BatchInfo: Information about the batch. Its shapes are [batch_size].
                    - env_ids: environment id for each sequence
                    - positions: starting position in the replay buffer for each sequence.
                    - importance_weights: importance weight divided by the average of
                        all non-zero importance weights in the buffer.

        """
        return self._buffer.get_batch(sample_batch_size, mini_batch_length)

    def replay_all(self):
        return self._buffer.gather_all()

    def clear(self):
        self._buffer.clear()

    def update_priority(self, env_ids, positions, priorities):
        """Update the priorities for the given experiences.

        Args:
            env_ids (Tensor): 1-D int64 Tensor.
            positions (Tensor): 1-D int64 Tensor with the same shape as ``env_ids``.
                These positions should be obtained from the ``BatchInfo`` returned
                by ``get_batch()``.
            priorities (Tensor): 1-D Tensor with the same shape as ``env_ids``,
                giving the new priority for each corresponding experience.
        """
        self._buffer.update_priority(env_ids, positions, priorities)

    @property
    def batch_size(self):
        return self._buffer.num_environments

    @property
    def total_size(self):
        return self._buffer.total_size

    @property
    def replay_buffer(self):
        return self._buffer
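A minimal usage sketch of the prioritized path, assuming `replayer` is an already constructed `SyncExperienceReplayer` with `prioritized_sampling=True` and `compute_td_error` is a hypothetical function returning one TD error per sampled sequence:

# Sample a mini-batch together with its BatchInfo.
samples, batch_info = replayer.replay(sample_batch_size=64, mini_batch_length=2)

# Train on `samples`, then feed the new priorities back, e.g. from absolute TD errors.
td_error = compute_td_error(samples)                        # hypothetical, shape [64]
replayer.update_priority(env_ids=batch_info.env_ids,
                         positions=batch_info.positions,
                         priorities=td_error.abs() + 1e-6)  # keep priorities strictly positive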
Example #7
    def test_replay_buffer(self):
        dim = 20
        max_length = 4
        num_envs = 8
        data_spec = DataItem(
            env_id=tf.TensorSpec(shape=(), dtype=tf.int32),
            x=tf.TensorSpec(shape=(dim, ), dtype=tf.float32),
            t=tf.TensorSpec(shape=(), dtype=tf.int32))

        replay_buffer = ReplayBuffer(
            data_spec=data_spec,
            num_environments=num_envs,
            max_length=max_length)

        def _get_batch(env_ids, t, x):
            batch_size = len(env_ids)
            x = (x * tf.expand_dims(tf.range(batch_size, dtype=tf.float32), 1)
                 * tf.expand_dims(tf.range(dim, dtype=tf.float32), 0))
            return DataItem(
                env_id=tf.constant(env_ids),
                x=x,
                t=t * tf.ones((batch_size, ), tf.int32))

        batch1 = _get_batch([0, 4, 7], t=0, x=0.1)
        replay_buffer.add_batch(batch1, batch1.env_id)
        self.assertArrayEqual(replay_buffer._current_size,
                              [1, 0, 0, 0, 1, 0, 0, 1])
        self.assertArrayEqual(replay_buffer._current_pos,
                              [1, 0, 0, 0, 1, 0, 0, 1])
        with self.assertRaises(tf.errors.InvalidArgumentError):
            replay_buffer.get_batch(8, 1)

        batch2 = _get_batch([1, 2, 3, 5, 6], t=0, x=0.2)
        replay_buffer.add_batch(batch2, batch2.env_id)
        self.assertArrayEqual(replay_buffer._current_size,
                              [1, 1, 1, 1, 1, 1, 1, 1])
        self.assertArrayEqual(replay_buffer._current_pos,
                              [1, 1, 1, 1, 1, 1, 1, 1])

        batch = replay_buffer.gather_all()
        self.assertEqual(batch.t.shape, [8, 1])

        with self.assertRaises(tf.errors.InvalidArgumentError):
            replay_buffer.get_batch(8, 2)
            replay_buffer.get_batch(13, 1)
        batch = replay_buffer.get_batch(8, 1)
        # squeeze the time dimension
        batch = tf.nest.map_structure(lambda bat: tf.squeeze(bat, axis=1),
                                      batch)
        bat1 = tf.nest.map_structure(
            lambda bat: tf.gather(bat, batch1.env_id, axis=0), batch)
        bat2 = tf.nest.map_structure(
            lambda bat: tf.gather(bat, batch2.env_id, axis=0), batch)
        self.assertArrayEqual(bat1.env_id, batch1.env_id)
        self.assertArrayEqual(bat1.x, batch1.x)
        self.assertArrayEqual(bat1.t, batch1.t)
        self.assertArrayEqual(bat2.env_id, batch2.env_id)
        self.assertArrayEqual(bat2.x, batch2.x)
        self.assertArrayEqual(bat2.t, batch2.t)

        for t in range(1, 10):
            batch3 = _get_batch([0, 4, 7], t=t, x=0.3)
            j = (t + 1) % max_length
            s = min(t + 1, max_length)
            replay_buffer.add_batch(batch3, batch3.env_id)
            self.assertArrayEqual(replay_buffer._current_size,
                                  [s, 1, 1, 1, s, 1, 1, s])
            self.assertArrayEqual(replay_buffer._current_pos,
                                  [j, 1, 1, 1, j, 1, 1, j])

        batch2 = _get_batch([1, 2, 3, 5, 6], t=1, x=0.2)
        replay_buffer.add_batch(batch2, batch2.env_id)
        batch = replay_buffer.get_batch(8, 1)
        # squeeze the time dimension
        batch = tf.nest.map_structure(lambda bat: tf.squeeze(bat, axis=1),
                                      batch)
        bat3 = tf.nest.map_structure(
            lambda bat: tf.gather(bat, batch3.env_id, axis=0), batch)
        bat2 = tf.nest.map_structure(
            lambda bat: tf.gather(bat, batch2.env_id, axis=0), batch)
        self.assertArrayEqual(bat3.env_id, batch3.env_id)
        self.assertArrayEqual(bat3.x, batch3.x)
        self.assertArrayEqual(bat2.env_id, batch2.env_id)
        self.assertArrayEqual(bat2.x, batch2.x)

        batch = replay_buffer.get_batch(8, 2)
        t2 = []
        t3 = []
        for t in range(2):
            batch_t = tf.nest.map_structure(lambda b: b[:, t], batch)
            bat3 = tf.nest.map_structure(
                lambda bat: tf.gather(bat, batch3.env_id, axis=0), batch_t)
            bat2 = tf.nest.map_structure(
                lambda bat: tf.gather(bat, batch2.env_id, axis=0), batch_t)
            t2.append(bat2.t)
            self.assertArrayEqual(bat3.env_id, batch3.env_id)
            self.assertArrayEqual(bat3.x, batch3.x)
            self.assertArrayEqual(bat2.env_id, batch2.env_id)
            self.assertArrayEqual(bat2.x, batch2.x)
            t3.append(bat3.t)

        # Test time consistency
        self.assertArrayEqual(t2[0] + 1, t2[1])
        self.assertArrayEqual(t3[0] + 1, t3[1])

        batch = replay_buffer.get_batch(128, 2)
        self.assertArrayEqual(batch.t[:, 0] + 1, batch.t[:, 1])
        self.assertEqual(batch.t.shape, [128, 2])

        batch = replay_buffer.get_batch(10, 2)
        self.assertArrayEqual(batch.t[:, 0] + 1, batch.t[:, 1])
        self.assertEqual(batch.t.shape, [10, 2])

        batch = replay_buffer.get_batch(4, 2)
        self.assertArrayEqual(batch.t[:, 0] + 1, batch.t[:, 1])
        self.assertEqual(batch.t.shape, [4, 2])

        # Test gather_all()
        with self.assertRaises(tf.errors.InvalidArgumentError):
            replay_buffer.gather_all()

        for t in range(2, 10):
            batch4 = _get_batch([1, 2, 3, 5, 6], t=t, x=0.4)
            replay_buffer.add_batch(batch4, batch4.env_id)
        batch = replay_buffer.gather_all()
        self.assertEqual(batch.t.shape, [8, 4])
Example #8
    def _test_preprocess_experience(self, train_reward_function, td_steps,
                                    reanalyze_ratio, expected):
        """
        The following summarizes how the data is generated:

        .. code-block:: python

            # position:   01234567890123
            step_type0 = 'FMMMLFMMLFMMMM'
            step_type1 = 'FMMMMMLFMMMMLF'
            scale = 1. for current model
                    2. for target model
            observation = [position] * 3
            reward = position if train_reward_function and td_steps!=-1
                     else position * (step_type == LAST)
            value = 0.5 * position * scale
            action_probs = scale * [position, position+1, position] for env 0
                           scale * [position+1, position, position] for env 1
            action = 1 for env 0
                     0 for env 1

        """
        reanalyze_td_steps = 2

        num_unroll_steps = 4
        batch_size = 2
        obs_dim = 3

        observation_spec = alf.TensorSpec([obs_dim])
        action_spec = alf.BoundedTensorSpec((),
                                            minimum=0,
                                            maximum=1,
                                            dtype=torch.int32)
        reward_spec = alf.TensorSpec(())
        time_step_spec = ds.time_step_spec(observation_spec, action_spec,
                                           reward_spec)

        global _mcts_model_id
        _mcts_model_id = 0
        muzero = MuzeroAlgorithm(observation_spec,
                                 action_spec,
                                 model_ctor=_create_mcts_model,
                                 mcts_algorithm_ctor=MockMCTSAlgorithm,
                                 num_unroll_steps=num_unroll_steps,
                                 td_steps=td_steps,
                                 train_game_over_function=True,
                                 train_reward_function=train_reward_function,
                                 reanalyze_ratio=reanalyze_ratio,
                                 reanalyze_td_steps=reanalyze_td_steps,
                                 data_transformer_ctor=partial(FrameStacker,
                                                               stack_size=2))

        data_transformer = FrameStacker(observation_spec, stack_size=2)
        time_step = common.zero_tensor_from_nested_spec(
            time_step_spec, batch_size)
        dt_state = common.zero_tensor_from_nested_spec(
            data_transformer.state_spec, batch_size)
        state = muzero.get_initial_predict_state(batch_size)
        transformed_time_step, dt_state = data_transformer.transform_timestep(
            time_step, dt_state)
        alg_step = muzero.rollout_step(transformed_time_step, state)
        alg_step_spec = dist_utils.extract_spec(alg_step)

        experience = ds.make_experience(time_step, alg_step, state)
        experience_spec = ds.make_experience(time_step_spec, alg_step_spec,
                                             muzero.train_state_spec)
        replay_buffer = ReplayBuffer(data_spec=experience_spec,
                                     num_environments=batch_size,
                                     max_length=16,
                                     keep_episodic_info=True)

        #             01234567890123
        step_type0 = 'FMMMLFMMLFMMMM'
        step_type1 = 'FMMMMMLFMMMMLF'

        dt_state = common.zero_tensor_from_nested_spec(
            data_transformer.state_spec, batch_size)
        for i in range(len(step_type0)):
            step_type = [step_type0[i], step_type1[i]]
            step_type = [
                ds.StepType.MID if c == 'M' else
                (ds.StepType.FIRST if c == 'F' else ds.StepType.LAST)
                for c in step_type
            ]
            step_type = torch.tensor(step_type, dtype=torch.int32)
            reward = torch.full([batch_size], float(i))
            if not train_reward_function or td_steps == -1:
                reward = reward * (step_type == ds.StepType.LAST).to(
                    torch.float32)
            time_step = time_step._replace(
                discount=(step_type != ds.StepType.LAST).to(torch.float32),
                step_type=step_type,
                observation=torch.tensor([[i, i + 1, i], [i + 1, i, i]],
                                         dtype=torch.float32),
                reward=reward,
                env_id=torch.arange(batch_size, dtype=torch.int32))
            transformed_time_step, dt_state = data_transformer.transform_timestep(
                time_step, dt_state)
            alg_step = muzero.rollout_step(transformed_time_step, state)
            experience = ds.make_experience(time_step, alg_step, state)
            replay_buffer.add_batch(experience)
            state = alg_step.state

        env_ids = torch.tensor([0] * 14 + [1] * 14, dtype=torch.int64)
        positions = torch.tensor(list(range(14)) + list(range(14)),
                                 dtype=torch.int64)
        experience = replay_buffer.get_field(None,
                                             env_ids.unsqueeze(-1).cpu(),
                                             positions.unsqueeze(-1).cpu())
        experience = experience._replace(replay_buffer=replay_buffer,
                                         batch_info=BatchInfo(
                                             env_ids=env_ids,
                                             positions=positions),
                                         rollout_info_field='rollout_info')
        processed_experience = muzero.preprocess_experience(experience)
        import pprint
        pprint.pprint(processed_experience.rollout_info)
        alf.nest.map_structure(lambda x, y: self.assertEqual(x, y),
                               processed_experience.rollout_info, expected)
Example #9
    def test_prioritized_replay(self):
        replay_buffer = ReplayBuffer(data_spec=self.data_spec,
                                     num_environments=self.num_envs,
                                     max_length=self.max_length,
                                     prioritized_sampling=True)
        self.assertRaises(AssertionError, replay_buffer.get_batch, 1, 1)

        batch1 = get_batch([1], self.dim, x=0.25, t=0)
        replay_buffer.add_batch(batch1, batch1.env_id)

        batch, batch_info = replay_buffer.get_batch(1, 1)
        self.assertEqual(batch_info.env_ids,
                         torch.tensor([1], dtype=torch.int64))
        self.assertEqual(batch_info.importance_weights, 1.)
        self.assertEqual(batch_info.importance_weights, torch.tensor([1.]))
        self.assertRaises(AssertionError, replay_buffer.get_batch, 1, 2)

        batch2 = get_batch([1], self.dim, x=0.5, t=1)
        replay_buffer.add_batch(batch2, batch2.env_id)

        batch, batch_info = replay_buffer.get_batch(4, 2)
        self.assertEqual(batch_info.env_ids,
                         torch.tensor([1], dtype=torch.int64))
        self.assertEqual(batch_info.importance_weights, torch.tensor([1.]))
        self.assertEqual(batch_info.importance_weights, torch.tensor([1.] * 4))

        batch, batch_info = replay_buffer.get_batch(1000, 1)
        n0 = (replay_buffer.circular(batch_info.positions) == 0).sum()
        n1 = (replay_buffer.circular(batch_info.positions) == 1).sum()
        self.assertEqual(n0, 500)
        self.assertEqual(n1, 500)
        replay_buffer.update_priority(env_ids=torch.tensor([1, 1],
                                                           dtype=torch.int64),
                                      positions=torch.tensor(
                                          [0, 1], dtype=torch.int64),
                                      priorities=torch.tensor([0.5, 1.5]))
        batch, batch_info = replay_buffer.get_batch(1000, 1)
        n0 = (replay_buffer.circular(batch_info.positions) == 0).sum()
        n1 = (replay_buffer.circular(batch_info.positions) == 1).sum()
        self.assertEqual(n0, 250)
        self.assertEqual(n1, 750)

        batch2 = get_batch([0, 2], self.dim, x=0.5, t=1)
        replay_buffer.add_batch(batch2, batch2.env_id)
        batch, batch_info = replay_buffer.get_batch(1000, 1)

        def _get(env_id, pos):
            flag = ((batch_info.env_ids == env_id) *
                    (batch_info.positions == replay_buffer._pad(pos, env_id)))
            w = batch_info.importance_weights[torch.nonzero(flag,
                                                            as_tuple=True)[0]]
            return flag.sum(), w

        n0, w0 = _get(0, 0)
        n1, w1 = _get(1, 0)
        n2, w2 = _get(1, 1)
        n3, w3 = _get(2, 0)
        self.assertEqual(n0, 300)
        self.assertEqual(n1, 100)
        self.assertEqual(n2, 300)
        self.assertEqual(n3, 300)
        self.assertTrue(torch.all(w0 == 1.2))
        self.assertTrue(torch.all(w1 == 0.4))
        self.assertTrue(torch.all(w2 == 1.2))
        self.assertTrue(torch.all(w3 == 1.2))

        replay_buffer.update_priority(env_ids=torch.tensor([1, 2],
                                                           dtype=torch.int64),
                                      positions=torch.tensor(
                                          [1, 0], dtype=torch.int64),
                                      priorities=torch.tensor([1.0, 1.0]))
        batch, batch_info = replay_buffer.get_batch(1000, 1)

        n0, w0 = _get(0, 0)
        n1, w1 = _get(1, 0)
        n2, w2 = _get(1, 1)
        n3, w3 = _get(2, 0)
        self.assertEqual(n0, 375)
        self.assertEqual(n1, 125)
        self.assertEqual(n2, 250)
        self.assertEqual(n3, 250)
        self.assertTrue(torch.all(w0 == 1.5))
        self.assertTrue(torch.all(w1 == 0.5))
        self.assertTrue(torch.all(w2 == 1.0))
        self.assertTrue(torch.all(w3 == 1.0))
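The expected weights in the two blocks above can be reproduced by hand, assuming each importance weight is the entry's priority divided by the average of all (non-zero) priorities in the buffer, and that newly added entries receive the current maximum priority:

# After the first update_priority([0.5, 1.5]) and add_batch for envs 0 and 2:
priorities = {(0, 0): 1.5, (1, 0): 0.5, (1, 1): 1.5, (2, 0): 1.5}   # (env_id, pos) -> priority
avg = sum(priorities.values()) / len(priorities)                    # 1.25
print({k: p / avg for k, p in priorities.items()})                  # 1.2, 0.4, 1.2, 1.2

# After the second update_priority setting (1, 1) and (2, 0) to 1.0:
priorities.update({(1, 1): 1.0, (2, 0): 1.0})
avg = sum(priorities.values()) / len(priorities)                    # 1.0
print({k: p / avg for k, p in priorities.items()})                  # 1.5, 0.5, 1.0, 1.0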
Example #10
    def test_replay_with_hindsight_relabel(self):
        self.max_length = 8
        torch.manual_seed(0)
        configs = [
            "hindsight_relabel_fn.her_proportion=0.8",
            'hindsight_relabel_fn.achieved_goal_field="o.a"',
            'hindsight_relabel_fn.desired_goal_field="o.g"',
            "ReplayBuffer.postprocess_exp_fn=@hindsight_relabel_fn",
        ]
        gin.parse_config_files_and_bindings("", configs)

        replay_buffer = ReplayBuffer(data_spec=self.data_spec,
                                     num_environments=2,
                                     max_length=self.max_length,
                                     keep_episodic_info=True,
                                     step_type_field="t",
                                     with_replacement=True)

        steps = [
            [
                ds.StepType.FIRST,  # will be overwritten
                ds.StepType.MID,  # idx == 1 in buffer
                ds.StepType.LAST,
                ds.StepType.FIRST,
                ds.StepType.MID,
                ds.StepType.MID,
                ds.StepType.LAST,
                ds.StepType.FIRST,
                ds.StepType.MID  # idx == 0
            ],
            [
                ds.StepType.FIRST,  # will be overwritten in RingBuffer
                ds.StepType.LAST,  # idx == 1 in RingBuffer
                ds.StepType.FIRST,
                ds.StepType.MID,
                ds.StepType.MID,
                ds.StepType.LAST,
                ds.StepType.FIRST,
                ds.StepType.MID,
                ds.StepType.MID  # idx == 0
            ]
        ]
        # insert data that will be overwritten later
        for b, t in list(itertools.product(range(2), range(8))):
            batch = get_batch([b], self.dim, t=steps[b][t], x=0.1 * t + b)
            replay_buffer.add_batch(batch, batch.env_id)
        # insert data
        for b, t in list(itertools.product(range(2), range(9))):
            batch = get_batch([b], self.dim, t=steps[b][t], x=0.1 * t + b)
            replay_buffer.add_batch(batch, batch.env_id)

        # Test padding
        idx = torch.tensor([[7, 0, 0, 6, 3, 3, 3, 0],
                            [6, 0, 5, 2, 2, 2, 0, 6]])
        pos = replay_buffer._pad(idx, torch.tensor([[0] * 8, [1] * 8]))
        self.assertTrue(
            torch.equal(
                pos,
                torch.tensor([[15, 16, 16, 14, 11, 11, 11, 16],
                              [14, 16, 13, 10, 10, 10, 16, 14]])))

        # Verify _index is built correctly.
        # Note, the _index_pos 8 represents headless timesteps, which are
        # outdated and not the same as the result of padding: 16.
        pos = torch.tensor([[15, 8, 8, 14, 11, 11, 11, 16],
                            [14, 8, 13, 10, 10, 10, 16, 14]])

        self.assertTrue(torch.equal(replay_buffer._indexed_pos, pos))
        self.assertTrue(
            torch.equal(replay_buffer._headless_indexed_pos,
                        torch.tensor([10, 9])))

        # Save original exp for later testing.
        g_orig = replay_buffer._buffer.o["g"].clone()
        r_orig = replay_buffer._buffer.reward.clone()

        # HER selects indices [0, 2, 3, 4] to relabel, from all 5:
        # env_ids: [[0, 0], [1, 1], [0, 0], [1, 1], [0, 0]]
        # pos:     [[6, 7], [1, 2], [1, 2], [3, 4], [5, 6]] + 8
        # selected:    x               x       x       x
        # future:  [   7       2       2       4       6  ] + 8
        # g        [[.7,.7],[0, 0], [.2,.2],[1.4,1.4],[.6,.6]]  # 0.1 * t + b with default 0
        # reward:  [[-1,0], [-1,-1],[-1,0], [-1,0], [-1,0]]  # recomputed with default -1
        env_ids = torch.tensor([0, 0, 1, 0])
        dist = replay_buffer.steps_to_episode_end(
            replay_buffer._pad(torch.tensor([7, 2, 4, 6]), env_ids), env_ids)
        self.assertEqual(list(dist), [1, 0, 1, 0])

        # Test HER relabeled experiences
        res = replay_buffer.get_batch(5, 2)[0]

        self.assertEqual(list(res.o["g"].shape), [5, 2])

        # Test relabeling doesn't change original experience
        self.assertTrue(torch.allclose(r_orig, replay_buffer._buffer.reward))
        self.assertTrue(torch.allclose(g_orig, replay_buffer._buffer.o["g"]))

        # test relabeled goals
        g = torch.tensor([0.7, 0., .2, 1.4, .6]).unsqueeze(1).expand(5, 2)
        self.assertTrue(torch.allclose(res.o["g"], g))

        # test relabeled rewards
        r = torch.tensor([[-1., 0.], [-1., -1.], [-1., 0.], [-1., 0.],
                          [-1., 0.]])
        self.assertTrue(torch.allclose(res.reward, r))
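A toy illustration of the relabeling convention that the expected `g` and `reward` values above encode (assumed semantics for illustration only, not the actual `hindsight_relabel_fn`): the desired goal is replaced by the achieved goal at a sampled future step of the same episode, and the reward is re-computed as 0 where the relabeled goal is reached and -1 otherwise:

import torch

achieved = torch.tensor([0.6, 0.7])    # o["a"] for env 0 at positions [6, 7] (x = 0.1 * t)
relabeled_goal = achieved[-1]          # future step 7 -> new desired goal 0.7
relabeled_reward = torch.where(achieved == relabeled_goal,
                               torch.tensor(0.), torch.tensor(-1.))
print(relabeled_goal, relabeled_reward)   # 0.7, tensor([-1., 0.]) -- the first row above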
Example #11
    def test_replay_buffer(self, allow_multiprocess, with_replacement):
        replay_buffer = ReplayBuffer(data_spec=self.data_spec,
                                     num_environments=self.num_envs,
                                     max_length=self.max_length,
                                     allow_multiprocess=allow_multiprocess)

        batch1 = get_batch([0, 4, 7], self.dim, t=0, x=0.1)
        replay_buffer.add_batch(batch1, batch1.env_id)
        self.assertEqual(replay_buffer._current_size,
                         torch.tensor([1, 0, 0, 0, 1, 0, 0, 1]))
        self.assertEqual(replay_buffer._current_pos,
                         torch.tensor([1, 0, 0, 0, 1, 0, 0, 1]))
        self.assertRaises(AssertionError, replay_buffer.get_batch, 8, 1)

        batch2 = get_batch([1, 2, 3, 5, 6], self.dim, t=0, x=0.2)
        replay_buffer.add_batch(batch2, batch2.env_id)
        self.assertEqual(replay_buffer._current_size,
                         torch.tensor([1, 1, 1, 1, 1, 1, 1, 1]))
        self.assertEqual(replay_buffer._current_pos,
                         torch.tensor([1, 1, 1, 1, 1, 1, 1, 1]))

        batch = replay_buffer.gather_all()
        self.assertEqual(list(batch.t.shape), [8, 1])
        # test that RingBuffer detaches gradients of inputs
        self.assertFalse(batch.x.requires_grad)

        self.assertRaises(AssertionError, replay_buffer.get_batch, 8, 2)
        replay_buffer.get_batch(13, 1)[0]

        batch = replay_buffer.get_batch(8, 1)[0]
        # squeeze the time dimension
        batch = alf.nest.map_structure(lambda bat: bat.squeeze(1), batch)
        bat1 = alf.nest.map_structure(lambda bat: bat[batch1.env_id], batch)
        bat2 = alf.nest.map_structure(lambda bat: bat[batch2.env_id], batch)
        self.assertEqual(bat1.env_id, batch1.env_id)
        self.assertEqual(bat1.x, batch1.x)
        self.assertEqual(bat1.t, batch1.t)
        self.assertEqual(bat2.env_id, batch2.env_id)
        self.assertEqual(bat2.x, batch2.x)
        self.assertEqual(bat2.t, batch2.t)

        for t in range(1, 10):
            batch3 = get_batch([0, 4, 7], self.dim, t=t, x=0.3)
            j = t + 1
            s = min(t + 1, self.max_length)
            replay_buffer.add_batch(batch3, batch3.env_id)
            self.assertEqual(replay_buffer._current_size,
                             torch.tensor([s, 1, 1, 1, s, 1, 1, s]))
            self.assertEqual(replay_buffer._current_pos,
                             torch.tensor([j, 1, 1, 1, j, 1, 1, j]))

        batch2 = get_batch([1, 2, 3, 5, 6], self.dim, t=1, x=0.2)
        replay_buffer.add_batch(batch2, batch2.env_id)
        batch = replay_buffer.get_batch(8, 1)[0]
        # squeeze the time dimension
        batch = alf.nest.map_structure(lambda bat: bat.squeeze(1), batch)
        bat3 = alf.nest.map_structure(lambda bat: bat[batch3.env_id], batch)
        bat2 = alf.nest.map_structure(lambda bat: bat[batch2.env_id], batch)
        self.assertEqual(bat3.env_id, batch3.env_id)
        self.assertEqual(bat3.x, batch3.x)
        self.assertEqual(bat2.env_id, batch2.env_id)
        self.assertEqual(bat2.x, batch2.x)

        batch = replay_buffer.get_batch(8, 2)[0]
        t2 = []
        t3 = []
        for t in range(2):
            batch_t = alf.nest.map_structure(lambda b: b[:, t], batch)
            bat3 = alf.nest.map_structure(lambda bat: bat[batch3.env_id],
                                          batch_t)
            bat2 = alf.nest.map_structure(lambda bat: bat[batch2.env_id],
                                          batch_t)
            t2.append(bat2.t)
            self.assertEqual(bat3.env_id, batch3.env_id)
            self.assertEqual(bat3.x, batch3.x)
            self.assertEqual(bat2.env_id, batch2.env_id)
            self.assertEqual(bat2.x, batch2.x)
            t3.append(bat3.t)

        # Test time consistency
        self.assertEqual(t2[0] + 1, t2[1])
        self.assertEqual(t3[0] + 1, t3[1])

        batch = replay_buffer.get_batch(128, 2)[0]
        self.assertEqual(batch.t[:, 0] + 1, batch.t[:, 1])
        self.assertEqual(list(batch.t.shape), [128, 2])

        batch = replay_buffer.get_batch(10, 2)[0]
        self.assertEqual(batch.t[:, 0] + 1, batch.t[:, 1])
        self.assertEqual(list(batch.t.shape), [10, 2])

        batch = replay_buffer.get_batch(4, 2)[0]
        self.assertEqual(batch.t[:, 0] + 1, batch.t[:, 1])
        self.assertEqual(list(batch.t.shape), [4, 2])

        # Test gather_all()
        # Exception because the size of all the environments are not same
        self.assertRaises(AssertionError, replay_buffer.gather_all)

        for t in range(2, 10):
            batch4 = get_batch([1, 2, 3, 5, 6], self.dim, t=t, x=0.4)
            replay_buffer.add_batch(batch4, batch4.env_id)
        batch = replay_buffer.gather_all()
        self.assertEqual(list(batch.t.shape), [8, 4])

        # Test clear()
        replay_buffer.clear()
        self.assertEqual(replay_buffer.total_size, 0)
Example #12
    def test_compute_her_future_step_distance(self, end_prob):
        num_envs = 2
        max_length = 100
        torch.manual_seed(0)
        configs = [
            "hindsight_relabel_fn.her_proportion=0.8",
            'hindsight_relabel_fn.achieved_goal_field="o.a"',
            'hindsight_relabel_fn.desired_goal_field="o.g"',
            "ReplayBuffer.postprocess_exp_fn=@hindsight_relabel_fn",
        ]
        gin.parse_config_files_and_bindings("", configs)

        replay_buffer = ReplayBuffer(data_spec=self.data_spec,
                                     num_environments=num_envs,
                                     max_length=max_length,
                                     keep_episodic_info=True,
                                     step_type_field="t")
        # insert data
        max_steps = 1000
        # generate step_types with certain density of episode ends
        steps = self.generate_step_types(num_envs,
                                         max_steps,
                                         end_prob=end_prob)
        for t in range(max_steps):
            for b in range(num_envs):
                batch = get_batch([b],
                                  self.dim,
                                  t=steps[b * max_steps + t],
                                  x=1. / max_steps * t + b)
                replay_buffer.add_batch(batch, batch.env_id)
            if t > 1:
                sample_steps = min(t, max_length)
                env_ids = torch.tensor([0] * sample_steps + [1] * sample_steps)
                idx = torch.tensor(
                    list(range(sample_steps)) + list(range(sample_steps)))
                gd = self.steps_to_episode_end(replay_buffer, env_ids, idx)
                idx_orig = replay_buffer._indexed_pos.clone()
                idx_headless_orig = replay_buffer._headless_indexed_pos.clone()
                d = replay_buffer.steps_to_episode_end(
                    replay_buffer._pad(idx, env_ids), env_ids)
                # Test distance to end computation
                if not torch.equal(gd, d):
                    outs = [
                        "t: ", t, "\nenvids:\n", env_ids, "\nidx:\n", idx,
                        "\npos:\n",
                        replay_buffer._pad(idx, env_ids), "\nNot Equal: a:\n",
                        gd, "\nb:\n", d, "\nsteps:\n", replay_buffer._buffer.t,
                        "\nindexed_pos:\n", replay_buffer._indexed_pos,
                        "\nheadless_indexed_pos:\n",
                        replay_buffer._headless_indexed_pos
                    ]
                    outs = [str(out) for out in outs]
                    assert False, "".join(outs)

                # Save original exp for later testing.
                g_orig = replay_buffer._buffer.o["g"].clone()
                r_orig = replay_buffer._buffer.reward.clone()

                # HER relabel experience
                res = replay_buffer.get_batch(sample_steps, 2)[0]

                self.assertEqual(list(res.o["g"].shape), [sample_steps, 2])

                # Test relabeling doesn't change original experience
                self.assertTrue(
                    torch.allclose(r_orig, replay_buffer._buffer.reward))
                self.assertTrue(
                    torch.allclose(g_orig, replay_buffer._buffer.o["g"]))
                self.assertTrue(
                    torch.all(idx_orig == replay_buffer._indexed_pos))
                self.assertTrue(
                    torch.all(idx_headless_orig ==
                              replay_buffer._headless_indexed_pos))
Example #13
    def test_preprocess_experience(self):
        """
        The following summarizes how the data is generated:

        .. code-block:: python

            # position:   01234567890123
            step_type0 = 'FMMMLFMMLFMMMM'
            step_type1 = 'FMMMMMLFMMMMLF'
            reward = position if train_reward_function and td_steps!=-1
                     else position * (step_type == LAST)
            action = t + 1 for env 0
                     t for env 1

        """
        num_unroll_steps = 4
        batch_size = 2
        obs_dim = 3
        observation_spec = alf.TensorSpec([obs_dim])
        action_spec = alf.BoundedTensorSpec((1, ),
                                            minimum=0,
                                            maximum=1,
                                            dtype=torch.float32)
        reward_spec = alf.TensorSpec(())
        time_step_spec = ds.time_step_spec(observation_spec, action_spec,
                                           reward_spec)

        repr_learner = PredictiveRepresentationLearner(
            observation_spec,
            action_spec,
            num_unroll_steps=num_unroll_steps,
            decoder_ctor=partial(SimpleDecoder,
                                 target_field='reward',
                                 decoder_net_ctor=partial(
                                     EncodingNetwork, fc_layer_params=(4, ))),
            encoding_net_ctor=LSTMEncodingNetwork,
            dynamics_net_ctor=LSTMEncodingNetwork)

        time_step = common.zero_tensor_from_nested_spec(
            time_step_spec, batch_size)
        state = repr_learner.get_initial_predict_state(batch_size)
        alg_step = repr_learner.rollout_step(time_step, state)
        alg_step = alg_step._replace(output=torch.tensor([[1.], [0.]]))
        alg_step_spec = dist_utils.extract_spec(alg_step)

        experience = ds.make_experience(time_step, alg_step, state)
        experience_spec = ds.make_experience(time_step_spec, alg_step_spec,
                                             repr_learner.train_state_spec)
        replay_buffer = ReplayBuffer(data_spec=experience_spec,
                                     num_environments=batch_size,
                                     max_length=16,
                                     keep_episodic_info=True)

        #             01234567890123
        step_type0 = 'FMMMLFMMLFMMMM'
        step_type1 = 'FMMMMMLFMMMMLF'

        for i in range(len(step_type0)):
            step_type = [step_type0[i], step_type1[i]]
            step_type = [
                ds.StepType.MID if c == 'M' else
                (ds.StepType.FIRST if c == 'F' else ds.StepType.LAST)
                for c in step_type
            ]
            step_type = torch.tensor(step_type, dtype=torch.int32)
            reward = torch.full([batch_size], float(i))
            time_step = time_step._replace(
                discount=(step_type != ds.StepType.LAST).to(torch.float32),
                step_type=step_type,
                observation=torch.tensor([[i, i + 1, i], [i + 1, i, i]],
                                         dtype=torch.float32),
                reward=reward,
                env_id=torch.arange(batch_size, dtype=torch.int32))
            alg_step = repr_learner.rollout_step(time_step, state)
            alg_step = alg_step._replace(output=i + torch.tensor([[1.], [0.]]))
            experience = ds.make_experience(time_step, alg_step, state)
            replay_buffer.add_batch(experience)
            state = alg_step.state

        env_ids = torch.tensor([0] * 14 + [1] * 14, dtype=torch.int64)
        positions = torch.tensor(list(range(14)) + list(range(14)),
                                 dtype=torch.int64)
        experience = replay_buffer.get_field(None,
                                             env_ids.unsqueeze(-1).cpu(),
                                             positions.unsqueeze(-1).cpu())
        experience = experience._replace(replay_buffer=replay_buffer,
                                         batch_info=BatchInfo(
                                             env_ids=env_ids,
                                             positions=positions),
                                         rollout_info_field='rollout_info')
        processed_experience = repr_learner.preprocess_experience(experience)
        pprint.pprint(processed_experience.rollout_info)

        # yapf: disable
        expected = PredictiveRepresentationLearnerInfo(
            action=torch.tensor(
               [[[ 1.,  2.,  3.,  4.,  5.]],
                [[ 2.,  3.,  4.,  5.,  5.]],
                [[ 3.,  4.,  5.,  5.,  5.]],
                [[ 4.,  5.,  5.,  5.,  5.]],
                [[ 5.,  5.,  5.,  5.,  5.]],
                [[ 6.,  7.,  8.,  9.,  9.]],
                [[ 7.,  8.,  9.,  9.,  9.]],
                [[ 8.,  9.,  9.,  9.,  9.]],
                [[ 9.,  9.,  9.,  9.,  9.]],
                [[10., 11., 12., 13., 14.]],
                [[11., 12., 13., 14., 14.]],
                [[12., 13., 14., 14., 14.]],
                [[13., 14., 14., 14., 14.]],
                [[14., 14., 14., 14., 14.]],
                [[ 0.,  1.,  2.,  3.,  4.]],
                [[ 1.,  2.,  3.,  4.,  5.]],
                [[ 2.,  3.,  4.,  5.,  6.]],
                [[ 3.,  4.,  5.,  6.,  6.]],
                [[ 4.,  5.,  6.,  6.,  6.]],
                [[ 5.,  6.,  6.,  6.,  6.]],
                [[ 6.,  6.,  6.,  6.,  6.]],
                [[ 7.,  8.,  9., 10., 11.]],
                [[ 8.,  9., 10., 11., 12.]],
                [[ 9., 10., 11., 12., 12.]],
                [[10., 11., 12., 12., 12.]],
                [[11., 12., 12., 12., 12.]],
                [[12., 12., 12., 12., 12.]],
                [[13., 13., 13., 13., 13.]]]).unsqueeze(-1),
            mask=torch.tensor(
               [[[ True,  True,  True,  True,  True]],
                [[ True,  True,  True,  True, False]],
                [[ True,  True,  True, False, False]],
                [[ True,  True, False, False, False]],
                [[ True, False, False, False, False]],
                [[ True,  True,  True,  True, False]],
                [[ True,  True,  True, False, False]],
                [[ True,  True, False, False, False]],
                [[ True, False, False, False, False]],
                [[ True,  True,  True,  True,  True]],
                [[ True,  True,  True,  True, False]],
                [[ True,  True,  True, False, False]],
                [[ True,  True, False, False, False]],
                [[ True, False, False, False, False]],
                [[ True,  True,  True,  True,  True]],
                [[ True,  True,  True,  True,  True]],
                [[ True,  True,  True,  True,  True]],
                [[ True,  True,  True,  True, False]],
                [[ True,  True,  True, False, False]],
                [[ True,  True, False, False, False]],
                [[ True, False, False, False, False]],
                [[ True,  True,  True,  True,  True]],
                [[ True,  True,  True,  True,  True]],
                [[ True,  True,  True,  True, False]],
                [[ True,  True,  True, False, False]],
                [[ True,  True, False, False, False]],
                [[ True, False, False, False, False]],
                [[ True, False, False, False, False]]]),
            target=torch.tensor(
               [[[ 0.,  1.,  2.,  3.,  4.]],
                [[ 1.,  2.,  3.,  4.,  4.]],
                [[ 2.,  3.,  4.,  4.,  4.]],
                [[ 3.,  4.,  4.,  4.,  4.]],
                [[ 4.,  4.,  4.,  4.,  4.]],
                [[ 5.,  6.,  7.,  8.,  8.]],
                [[ 6.,  7.,  8.,  8.,  8.]],
                [[ 7.,  8.,  8.,  8.,  8.]],
                [[ 8.,  8.,  8.,  8.,  8.]],
                [[ 9., 10., 11., 12., 13.]],
                [[10., 11., 12., 13., 13.]],
                [[11., 12., 13., 13., 13.]],
                [[12., 13., 13., 13., 13.]],
                [[13., 13., 13., 13., 13.]],
                [[ 0.,  1.,  2.,  3.,  4.]],
                [[ 1.,  2.,  3.,  4.,  5.]],
                [[ 2.,  3.,  4.,  5.,  6.]],
                [[ 3.,  4.,  5.,  6.,  6.]],
                [[ 4.,  5.,  6.,  6.,  6.]],
                [[ 5.,  6.,  6.,  6.,  6.]],
                [[ 6.,  6.,  6.,  6.,  6.]],
                [[ 7.,  8.,  9., 10., 11.]],
                [[ 8.,  9., 10., 11., 12.]],
                [[ 9., 10., 11., 12., 12.]],
                [[10., 11., 12., 12., 12.]],
                [[11., 12., 12., 12., 12.]],
                [[12., 12., 12., 12., 12.]],
                [[13., 13., 13., 13., 13.]]]))
        # yapf: enable

        alf.nest.map_structure(lambda x, y: self.assertEqual(x, y),
                               processed_experience.rollout_info, expected)
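
A standalone sketch of how one row of the expected `action` tensor above is assembled, assuming that each starting position is unrolled `num_unroll_steps` steps forward, clipped at the end of its episode, and the last valid action is repeated (with `mask=False`) past the boundary:

import torch

actions_env0 = torch.arange(1., 15.)   # action = t + 1 for env 0, t = 0..13
episode_end = 4                        # first episode of env 0: positions 0..4 ('FMMML')
pos, num_unroll_steps = 1, 4

row = torch.stack([actions_env0[min(pos + k, episode_end)]
                   for k in range(num_unroll_steps + 1)])
mask = torch.tensor([pos + k <= episode_end for k in range(num_unroll_steps + 1)])
print(row)    # tensor([2., 3., 4., 5., 5.]) -- second row of `action` above
print(mask)   # tensor([ True,  True,  True,  True, False]) -- second row of `mask`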