Example #1
    def test_conversions(self):
        dists = {
            't': torch.tensor([[1., 2., 4.], [3., 3., 1.]]),
            'd': dist_utils.DiagMultivariateNormal(
                torch.tensor([[1., 2.], [2., 2.]]),
                torch.tensor([[2., 3.], [1., 1.]]))
        }
        params = dist_utils.distributions_to_params(dists)
        dists_spec = dist_utils.extract_spec(dists, from_dim=1)
        self.assertEqual(dists_spec['t'],
                         alf.TensorSpec(shape=(3, ), dtype=torch.float32))
        self.assertEqual(type(dists_spec['d']), dist_utils.DistributionSpec)
        self.assertEqual(len(params), 2)
        self.assertEqual(dists['t'], params['t'])
        self.assertEqual(dists['d'].base_dist.mean, params['d']['loc'])
        self.assertEqual(dists['d'].base_dist.stddev, params['d']['scale'])

        dists1 = dist_utils.params_to_distributions(params, dists_spec)
        self.assertEqual(len(dists1), 2)
        self.assertEqual(dists1['t'], dists['t'])
        self.assertEqual(type(dists1['d']), type(dists['d']))

        params_spec = dist_utils.to_distribution_param_spec(dists_spec)
        alf.nest.assert_same_structure(params_spec, params)
        params1_spec = dist_utils.extract_spec(params)
        self.assertEqual(params_spec, params1_spec)
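
The test above exercises the full conversion round trip. A minimal standalone sketch of the same pattern, assuming ALF is installed (the shapes here are arbitrary):

    import torch
    from alf.utils import dist_utils

    nest = {
        't': torch.zeros(2, 3),
        'd': dist_utils.DiagMultivariateNormal(torch.zeros(2, 4),
                                               torch.ones(2, 4)),
    }
    # from_dim=1 treats dim 0 as the batch dim and leaves it out of the spec.
    spec = dist_utils.extract_spec(nest, from_dim=1)
    # Distributions become their parameter tensors ({'loc': ..., 'scale': ...});
    # plain tensors pass through unchanged.
    params = dist_utils.distributions_to_params(nest)
    # The spec is what later rebuilds the distributions from those parameters.
    rebuilt = dist_utils.params_to_distributions(params, spec)
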
Example #2
    def train_step(self, exp: TimeStep, state):
        # [B, num_unroll_steps + 1]
        info = exp.rollout_info
        targets = common.as_list(info.target)
        batch_size = exp.step_type.shape[0]
        latent, state = self._encoding_net(exp.observation, state)

        sim_latent = self._multi_step_latent_rollout(latent,
                                                     self._num_unroll_steps,
                                                     info.action, state)

        loss = 0
        for i, decoder in enumerate(self._decoders):
            # [(num_unroll_steps + 1)*B, ...]
            train_info = decoder.train_step(sim_latent).info
            train_info_spec = dist_utils.extract_spec(train_info)
            train_info = dist_utils.distributions_to_params(train_info)
            train_info = alf.nest.map_structure(
                lambda x: x.reshape(self._num_unroll_steps + 1, batch_size,
                                    *x.shape[1:]), train_info)
            # [num_unroll_steps + 1, B, ...]
            train_info = dist_utils.params_to_distributions(
                train_info, train_info_spec)
            target = alf.nest.map_structure(lambda x: x.transpose(0, 1),
                                            targets[i])
            loss_info = decoder.calc_loss(target, train_info, info.mask.t())
            loss_info = alf.nest.map_structure(lambda x: x.mean(dim=0),
                                               loss_info)
            loss += loss_info.loss

        loss_info = LossInfo(loss=loss, extra=loss)

        return AlgStep(output=latent, state=state, info=loss_info)
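
The ``distributions_to_params`` / ``params_to_distributions`` pair in this ``train_step`` exists because distribution objects cannot be reshaped directly: the code reshapes their parameter tensors and then rebuilds the distributions from the saved spec. The same trick as a small standalone helper (illustrative only, not part of ALF's API):

    import alf
    from alf.utils import dist_utils

    def reshape_leading_dim(info, T, B):
        """Reshape the leading dim of every leaf of ``info`` from T*B to [T, B],
        even when some leaves are torch distributions."""
        spec = dist_utils.extract_spec(info)
        params = dist_utils.distributions_to_params(info)
        params = alf.nest.map_structure(
            lambda x: x.reshape(T, B, *x.shape[1:]), params)
        return dist_utils.params_to_distributions(params, spec)
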
Example #3
    def train_step(self, exp: Experience, state):
        def _hook(grad, name):
            alf.summary.scalar("MCTS_state_grad_norm/" + name, grad.norm())

        model_output = self._model.initial_inference(exp.observation)
        if alf.summary.should_record_summaries():
            model_output.state.register_hook(partial(_hook, name="s0"))
        model_output_spec = dist_utils.extract_spec(model_output)
        model_outputs = [dist_utils.distributions_to_params(model_output)]
        info = exp.rollout_info

        for i in range(self._num_unroll_steps):
            model_output = self._model.recurrent_inference(
                model_output.state, info.action[:, i, ...])
            if alf.summary.should_record_summaries():
                model_output.state.register_hook(
                    partial(_hook, name="s" + str(i + 1)))
            model_output = model_output._replace(state=scale_gradient(
                model_output.state, self._recurrent_gradient_scaling_factor))
            model_outputs.append(
                dist_utils.distributions_to_params(model_output))

        model_outputs = alf.nest.utils.stack_nests(model_outputs, dim=1)
        model_outputs = dist_utils.params_to_distributions(
            model_outputs, model_output_spec)
        return AlgStep(info=self._model.calc_loss(model_outputs, info.target))
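
``scale_gradient`` above limits how much gradient flows back through the recurrent state at each unroll step (the MuZero paper's pseudocode uses a factor of 0.5). ALF imports its own implementation; a common way such a function is written, shown here only as a sketch:

    import torch

    def scale_gradient(x: torch.Tensor, scale: float) -> torch.Tensor:
        # The forward value is unchanged; only the gradient flowing back
        # through ``x`` is multiplied by ``scale``.
        return x * scale + x.detach() * (1.0 - scale)
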
Example #4
    def _rollout_step(self, time_step: TimeStep, state):
        """A wrapper around the user-defined ``rollout_step``. For every RL
        algorithm, this wrapper ensures that the rollout info spec is computed.
        """
        policy_step = self._original_rollout_step(time_step, state)
        if self._rollout_info_spec is None:
            self._rollout_info_spec = dist_utils.extract_spec(policy_step.info)
        return policy_step
Example #5
    def output_spec(self):
        """Return the spec of the network's encoding output. By default, we use
        ``_test_forward`` to automatically compute the output and get its spec.
        For efficiency, subclasses can override this function if the output spec
        can be obtained easily in other ways.
        """
        if self._output_spec is None:
            training = self.training
            self.eval()
            self._output_spec = extract_spec(self._test_forward()[0],
                                             from_dim=1)
            self.train(training)
        return self._output_spec
Example #6
    def _prepare_reanalyze_data(self, replay_buffer: ReplayBuffer, env_ids,
                                positions, n1, n2):
        """
        Get the n1 + n2 steps of experience indicated by ``positions`` and return
        as the first n1 as ``exp1`` and the next n2 steps as ``exp2``.
        """
        batch_size = env_ids.shape[0]
        n = n1 + n2
        flat_env_ids = env_ids.expand_as(positions).reshape(-1)
        flat_positions = positions.reshape(-1)
        exp = replay_buffer.get_field(None, flat_env_ids, flat_positions)

        if self._data_transformer_ctor is not None:
            if self._data_transformer is None:
                observation_spec = dist_utils.extract_spec(exp.observation)
                self._data_transformer = create_data_transformer(
                    self._data_transformer_ctor, observation_spec)

            # DataTransformer assumes the shape of exp is [B, T, ...]
            # It also needs exp.batch_info and exp.replay_buffer.
            exp = alf.nest.map_structure(lambda x: x.unsqueeze(1), exp)
            exp = exp._replace(batch_info=BatchInfo(flat_env_ids,
                                                    flat_positions),
                               replay_buffer=replay_buffer)
            exp = self._data_transformer.transform_experience(exp)
            exp = exp._replace(batch_info=(), replay_buffer=())
            exp = alf.nest.map_structure(lambda x: x.squeeze(1), exp)

        def _split1(x):
            shape = x.shape[1:]
            x = x.reshape(batch_size, n, *shape)
            return x[:, :n1, ...].reshape(batch_size * n1, *shape)

        def _split2(x):
            shape = x.shape[1:]
            x = x.reshape(batch_size, n, *shape)
            return x[:, n1:, ...].reshape(batch_size * n2, *shape)

        with alf.device(self._device):
            exp = convert_device(exp)
            exp1 = alf.nest.map_structure(_split1, exp)
            exp2 = alf.nest.map_structure(_split2, exp)

        return exp1, exp2
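
``_split1`` and ``_split2`` undo the earlier ``reshape(-1)``: each leaf of ``exp`` has a leading dim of ``batch_size * n``, which is viewed as ``[batch_size, n, ...]`` and then sliced into the first n1 and the remaining n2 steps per environment. A tensor-only illustration with made-up sizes:

    import torch

    batch_size, n1, n2 = 2, 3, 2
    n = n1 + n2
    flat = torch.arange(batch_size * n).reshape(batch_size * n, 1)  # one flat leaf
    grouped = flat.reshape(batch_size, n, 1)
    exp1_leaf = grouped[:, :n1, ...].reshape(batch_size * n1, 1)  # first n1 steps
    exp2_leaf = grouped[:, n1:, ...].reshape(batch_size * n2, 1)  # next n2 steps
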
Example #7
    def _make_policy_step(self, time_step, state, policy_step):
        assert (
            alf.nest.is_namedtuple(policy_step.info)
            and "action_distribution" in policy_step.info._fields), (
                "PolicyStep.info from ac_algorithm.rollout_step() or "
                "ac_algorithm.train_step() should be a namedtuple containing "
                "`action_distribution` in order to use TracAlgorithm.")
        action_distribution = policy_step.info.action_distribution
        if self._action_distribution_spec is None:
            self._action_distribution_spec = dist_utils.extract_spec(
                action_distribution)
        ac_info = policy_step.info._replace(action_distribution=())
        # EntropyTargetAlgorithm needs info.action_distribution
        return policy_step._replace(
            info=TracInfo(
                action_distribution=action_distribution,
                observation=time_step.observation,
                prev_action=time_step.prev_action,
                state=self._ac_algorithm.convert_train_state_to_predict_state(
                    state),
                ac=ac_info))
Example #8
    def train_step(self, exp: TimeStep, state):
        # [B, num_unroll_steps + 1]
        info = exp.rollout_info
        batch_size = exp.step_type.shape[0]
        latent, state = self._encoding_net(exp.observation, state)

        sim_latents = [latent]

        if self._num_unroll_steps > 0:

            if self._latent_to_dstate_fc is not None:
                dstate = self._latent_to_dstate_fc(latent)
                dstate = dstate.split(self._dynamics_state_dims, dim=1)
                dstate = alf.nest.pack_sequence_as(
                    self._dynamics_net.state_spec, dstate)
            else:
                dstate = state

        for i in range(self._num_unroll_steps):
            sim_latent, dstate = self._dynamics_net(info.action[:, i, ...],
                                                    dstate)
            sim_latents.append(sim_latent)

        sim_latent = torch.cat(sim_latents, dim=0)

        # [(num_unroll_steps + 1)*B, ...]
        train_info = self._decoder.train_step(sim_latent).info
        train_info_spec = dist_utils.extract_spec(train_info)
        train_info = dist_utils.distributions_to_params(train_info)
        train_info = alf.nest.map_structure(
            lambda x: x.reshape(self._num_unroll_steps + 1, batch_size,
                                *x.shape[1:]), train_info)
        # [num_unroll_steps + 1, B, ...]
        train_info = dist_utils.params_to_distributions(
            train_info, train_info_spec)
        target = alf.nest.map_structure(lambda x: x.transpose(0, 1),
                                        info.target)
        loss_info = self._decoder.calc_loss(target, train_info, info.mask.t())
        loss_info = alf.nest.map_structure(lambda x: x.mean(dim=0), loss_info)

        return AlgStep(output=latent, state=state, info=loss_info)
Example #9
    def _test_preprocess_experience(self, train_reward_function, td_steps,
                                    reanalyze_ratio, expected):
        """
        The following summarizes how the data is generated:

        .. code-block:: python

            # position:   01234567890123
            step_type0 = 'FMMMLFMMLFMMMM'
            step_type1 = 'FMMMMMLFMMMMLF'
            scale = 1. for current model
                    2. for target model
            observation = [position] * 3
            reward = position if train_reward_function and td_steps!=-1
                     else position * (step_type == LAST)
            value = 0.5 * position * scale
            action_probs = scale * [position, position+1, position] for env 0
                           scale * [position+1, position, position] for env 1
            action = 1 for env 0
                     0 for env 1

        """
        reanalyze_td_steps = 2

        num_unroll_steps = 4
        batch_size = 2
        obs_dim = 3

        observation_spec = alf.TensorSpec([obs_dim])
        action_spec = alf.BoundedTensorSpec((),
                                            minimum=0,
                                            maximum=1,
                                            dtype=torch.int32)
        reward_spec = alf.TensorSpec(())
        time_step_spec = ds.time_step_spec(observation_spec, action_spec,
                                           reward_spec)

        global _mcts_model_id
        _mcts_model_id = 0
        muzero = MuzeroAlgorithm(observation_spec,
                                 action_spec,
                                 model_ctor=_create_mcts_model,
                                 mcts_algorithm_ctor=MockMCTSAlgorithm,
                                 num_unroll_steps=num_unroll_steps,
                                 td_steps=td_steps,
                                 train_game_over_function=True,
                                 train_reward_function=train_reward_function,
                                 reanalyze_ratio=reanalyze_ratio,
                                 reanalyze_td_steps=reanalyze_td_steps,
                                 data_transformer_ctor=partial(FrameStacker,
                                                               stack_size=2))

        data_transformer = FrameStacker(observation_spec, stack_size=2)
        time_step = common.zero_tensor_from_nested_spec(
            time_step_spec, batch_size)
        dt_state = common.zero_tensor_from_nested_spec(
            data_transformer.state_spec, batch_size)
        state = muzero.get_initial_predict_state(batch_size)
        transformed_time_step, dt_state = data_transformer.transform_timestep(
            time_step, dt_state)
        alg_step = muzero.rollout_step(transformed_time_step, state)
        alg_step_spec = dist_utils.extract_spec(alg_step)

        experience = ds.make_experience(time_step, alg_step, state)
        experience_spec = ds.make_experience(time_step_spec, alg_step_spec,
                                             muzero.train_state_spec)
        replay_buffer = ReplayBuffer(data_spec=experience_spec,
                                     num_environments=batch_size,
                                     max_length=16,
                                     keep_episodic_info=True)

        #             01234567890123
        step_type0 = 'FMMMLFMMLFMMMM'
        step_type1 = 'FMMMMMLFMMMMLF'

        dt_state = common.zero_tensor_from_nested_spec(
            data_transformer.state_spec, batch_size)
        for i in range(len(step_type0)):
            step_type = [step_type0[i], step_type1[i]]
            step_type = [
                ds.StepType.MID if c == 'M' else
                (ds.StepType.FIRST if c == 'F' else ds.StepType.LAST)
                for c in step_type
            ]
            step_type = torch.tensor(step_type, dtype=torch.int32)
            reward = torch.full([batch_size], float(i))
            if not train_reward_function or td_steps == -1:
                reward = reward * (step_type == ds.StepType.LAST).to(
                    torch.float32)
            time_step = time_step._replace(
                discount=(step_type != ds.StepType.LAST).to(torch.float32),
                step_type=step_type,
                observation=torch.tensor([[i, i + 1, i], [i + 1, i, i]],
                                         dtype=torch.float32),
                reward=reward,
                env_id=torch.arange(batch_size, dtype=torch.int32))
            transformed_time_step, dt_state = data_transformer.transform_timestep(
                time_step, dt_state)
            alg_step = muzero.rollout_step(transformed_time_step, state)
            experience = ds.make_experience(time_step, alg_step, state)
            replay_buffer.add_batch(experience)
            state = alg_step.state

        env_ids = torch.tensor([0] * 14 + [1] * 14, dtype=torch.int64)
        positions = torch.tensor(list(range(14)) + list(range(14)),
                                 dtype=torch.int64)
        experience = replay_buffer.get_field(None,
                                             env_ids.unsqueeze(-1).cpu(),
                                             positions.unsqueeze(-1).cpu())
        experience = experience._replace(replay_buffer=replay_buffer,
                                         batch_info=BatchInfo(
                                             env_ids=env_ids,
                                             positions=positions),
                                         rollout_info_field='rollout_info')
        processed_experience = muzero.preprocess_experience(experience)
        import pprint
        pprint.pprint(processed_experience.rollout_info)
        alf.nest.map_structure(lambda x, y: self.assertEqual(x, y),
                               processed_experience.rollout_info, expected)
Example #10
    def __init__(self, num_expansions, model_output, known_value_bounds):
        batch_size, branch_factor = model_output.action_probs.shape
        action_spec = dist_utils.extract_spec(model_output.actions, from_dim=2)
        state_spec = dist_utils.extract_spec(model_output.state, from_dim=1)
        if known_value_bounds:
            self.fixed_bounds = True
            self.minimum, self.maximum = known_value_bounds
        else:
            self.fixed_bounds = False
            self.minimum, self.maximum = MAXIMUM_FLOAT_VALUE, -MAXIMUM_FLOAT_VALUE
        self.minimum = torch.full((batch_size, ),
                                  self.minimum,
                                  dtype=torch.float32)
        self.maximum = torch.full((batch_size, ),
                                  self.maximum,
                                  dtype=torch.float32)
        if known_value_bounds:
            self.normalize_scale = 1 / (self.maximum - self.minimum + 1e-30)
            self.normalize_base = self.minimum
        else:
            self.normalize_scale = torch.ones((batch_size, ))
            self.normalize_base = torch.zeros((batch_size, ))

        self.B = torch.arange(batch_size)
        self.root_indices = torch.zeros((batch_size, ), dtype=torch.int64)
        self.branch_factor = branch_factor

        parent_shape = (batch_size, num_expansions)
        children_shape = (batch_size, num_expansions, branch_factor)

        self.visit_count = torch.zeros(parent_shape, dtype=torch.int32)

        # the player who will take action from the current state
        self.to_play = torch.zeros(parent_shape, dtype=torch.int64)
        self.prior = torch.zeros(children_shape)

        # value for player 0
        self.value_sum = torch.zeros(parent_shape)

        # 0 for not expanded, value in range [0, num_expansions)
        self.children_index = torch.zeros(children_shape, dtype=torch.int64)
        self.model_state = common.zero_tensor_from_nested_spec(
            state_spec, parent_shape)

        # reward for player 0
        self.reward = None
        if isinstance(model_output.reward, torch.Tensor):
            self.reward = torch.zeros(parent_shape)

        self.action = None
        if isinstance(model_output.actions, torch.Tensor):
            # candidate actions for this state
            self.action = torch.zeros(
                children_shape + action_spec.shape, dtype=action_spec.dtype)

        self.game_over = None
        if isinstance(model_output.game_over, torch.Tensor):
            self.game_over = torch.zeros(parent_shape, dtype=torch.bool)

        # value in range [0, branch_factor)
        self.best_child_index = torch.zeros(parent_shape, dtype=torch.int64)
        self.ucb_score = torch.zeros(children_shape)
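
Here ``from_dim`` controls how many leading dimensions count as batch dimensions and are therefore left out of the spec: ``model_output.state`` is ``[B, ...]`` (``from_dim=1``), while ``model_output.actions`` holds the per-child candidate actions and is ``[B, branch_factor, ...]`` (``from_dim=2``). A tensor-only illustration (the shapes are made up):

    import torch
    from alf.utils.dist_utils import extract_spec

    actions = torch.zeros(5, 7, 2)        # [B, branch_factor, action_dim]
    extract_spec(actions, from_dim=2)     # -> TensorSpec with shape (2,)
    extract_spec(actions, from_dim=1)     # -> TensorSpec with shape (7, 2)
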
Example #11
    def test_preprocess_experience(self):
        """
        The following summarizes how the data is generated:

        .. code-block:: python

            # position:   01234567890123
            step_type0 = 'FMMMLFMMLFMMMM'
            step_type1 = 'FMMMMMLFMMMMLF'
            reward = position if train_reward_function and td_steps!=-1
                     else position * (step_type == LAST)
            action = t + 1 for env 0
                     t for env 1

        """
        num_unroll_steps = 4
        batch_size = 2
        obs_dim = 3
        observation_spec = alf.TensorSpec([obs_dim])
        action_spec = alf.BoundedTensorSpec((1, ),
                                            minimum=0,
                                            maximum=1,
                                            dtype=torch.float32)
        reward_spec = alf.TensorSpec(())
        time_step_spec = ds.time_step_spec(observation_spec, action_spec,
                                           reward_spec)

        repr_learner = PredictiveRepresentationLearner(
            observation_spec,
            action_spec,
            num_unroll_steps=num_unroll_steps,
            decoder_ctor=partial(SimpleDecoder,
                                 target_field='reward',
                                 decoder_net_ctor=partial(
                                     EncodingNetwork, fc_layer_params=(4, ))),
            encoding_net_ctor=LSTMEncodingNetwork,
            dynamics_net_ctor=LSTMEncodingNetwork)

        time_step = common.zero_tensor_from_nested_spec(
            time_step_spec, batch_size)
        state = repr_learner.get_initial_predict_state(batch_size)
        alg_step = repr_learner.rollout_step(time_step, state)
        alg_step = alg_step._replace(output=torch.tensor([[1.], [0.]]))
        alg_step_spec = dist_utils.extract_spec(alg_step)

        experience = ds.make_experience(time_step, alg_step, state)
        experience_spec = ds.make_experience(time_step_spec, alg_step_spec,
                                             repr_learner.train_state_spec)
        replay_buffer = ReplayBuffer(data_spec=experience_spec,
                                     num_environments=batch_size,
                                     max_length=16,
                                     keep_episodic_info=True)

        #             01234567890123
        step_type0 = 'FMMMLFMMLFMMMM'
        step_type1 = 'FMMMMMLFMMMMLF'

        for i in range(len(step_type0)):
            step_type = [step_type0[i], step_type1[i]]
            step_type = [
                ds.StepType.MID if c == 'M' else
                (ds.StepType.FIRST if c == 'F' else ds.StepType.LAST)
                for c in step_type
            ]
            step_type = torch.tensor(step_type, dtype=torch.int32)
            reward = torch.full([batch_size], float(i))
            time_step = time_step._replace(
                discount=(step_type != ds.StepType.LAST).to(torch.float32),
                step_type=step_type,
                observation=torch.tensor([[i, i + 1, i], [i + 1, i, i]],
                                         dtype=torch.float32),
                reward=reward,
                env_id=torch.arange(batch_size, dtype=torch.int32))
            alg_step = repr_learner.rollout_step(time_step, state)
            alg_step = alg_step._replace(output=i + torch.tensor([[1.], [0.]]))
            experience = ds.make_experience(time_step, alg_step, state)
            replay_buffer.add_batch(experience)
            state = alg_step.state

        env_ids = torch.tensor([0] * 14 + [1] * 14, dtype=torch.int64)
        positions = torch.tensor(list(range(14)) + list(range(14)),
                                 dtype=torch.int64)
        experience = replay_buffer.get_field(None,
                                             env_ids.unsqueeze(-1).cpu(),
                                             positions.unsqueeze(-1).cpu())
        experience = experience._replace(replay_buffer=replay_buffer,
                                         batch_info=BatchInfo(
                                             env_ids=env_ids,
                                             positions=positions),
                                         rollout_info_field='rollout_info')
        processed_experience = repr_learner.preprocess_experience(experience)
        pprint.pprint(processed_experience.rollout_info)

        # yapf: disable
        expected = PredictiveRepresentationLearnerInfo(
            action=torch.tensor(
               [[[ 1.,  2.,  3.,  4.,  5.]],
                [[ 2.,  3.,  4.,  5.,  5.]],
                [[ 3.,  4.,  5.,  5.,  5.]],
                [[ 4.,  5.,  5.,  5.,  5.]],
                [[ 5.,  5.,  5.,  5.,  5.]],
                [[ 6.,  7.,  8.,  9.,  9.]],
                [[ 7.,  8.,  9.,  9.,  9.]],
                [[ 8.,  9.,  9.,  9.,  9.]],
                [[ 9.,  9.,  9.,  9.,  9.]],
                [[10., 11., 12., 13., 14.]],
                [[11., 12., 13., 14., 14.]],
                [[12., 13., 14., 14., 14.]],
                [[13., 14., 14., 14., 14.]],
                [[14., 14., 14., 14., 14.]],
                [[ 0.,  1.,  2.,  3.,  4.]],
                [[ 1.,  2.,  3.,  4.,  5.]],
                [[ 2.,  3.,  4.,  5.,  6.]],
                [[ 3.,  4.,  5.,  6.,  6.]],
                [[ 4.,  5.,  6.,  6.,  6.]],
                [[ 5.,  6.,  6.,  6.,  6.]],
                [[ 6.,  6.,  6.,  6.,  6.]],
                [[ 7.,  8.,  9., 10., 11.]],
                [[ 8.,  9., 10., 11., 12.]],
                [[ 9., 10., 11., 12., 12.]],
                [[10., 11., 12., 12., 12.]],
                [[11., 12., 12., 12., 12.]],
                [[12., 12., 12., 12., 12.]],
                [[13., 13., 13., 13., 13.]]]).unsqueeze(-1),
            mask=torch.tensor(
               [[[ True,  True,  True,  True,  True]],
                [[ True,  True,  True,  True, False]],
                [[ True,  True,  True, False, False]],
                [[ True,  True, False, False, False]],
                [[ True, False, False, False, False]],
                [[ True,  True,  True,  True, False]],
                [[ True,  True,  True, False, False]],
                [[ True,  True, False, False, False]],
                [[ True, False, False, False, False]],
                [[ True,  True,  True,  True,  True]],
                [[ True,  True,  True,  True, False]],
                [[ True,  True,  True, False, False]],
                [[ True,  True, False, False, False]],
                [[ True, False, False, False, False]],
                [[ True,  True,  True,  True,  True]],
                [[ True,  True,  True,  True,  True]],
                [[ True,  True,  True,  True,  True]],
                [[ True,  True,  True,  True, False]],
                [[ True,  True,  True, False, False]],
                [[ True,  True, False, False, False]],
                [[ True, False, False, False, False]],
                [[ True,  True,  True,  True,  True]],
                [[ True,  True,  True,  True,  True]],
                [[ True,  True,  True,  True, False]],
                [[ True,  True,  True, False, False]],
                [[ True,  True, False, False, False]],
                [[ True, False, False, False, False]],
                [[ True, False, False, False, False]]]),
            target=torch.tensor(
               [[[ 0.,  1.,  2.,  3.,  4.]],
                [[ 1.,  2.,  3.,  4.,  4.]],
                [[ 2.,  3.,  4.,  4.,  4.]],
                [[ 3.,  4.,  4.,  4.,  4.]],
                [[ 4.,  4.,  4.,  4.,  4.]],
                [[ 5.,  6.,  7.,  8.,  8.]],
                [[ 6.,  7.,  8.,  8.,  8.]],
                [[ 7.,  8.,  8.,  8.,  8.]],
                [[ 8.,  8.,  8.,  8.,  8.]],
                [[ 9., 10., 11., 12., 13.]],
                [[10., 11., 12., 13., 13.]],
                [[11., 12., 13., 13., 13.]],
                [[12., 13., 13., 13., 13.]],
                [[13., 13., 13., 13., 13.]],
                [[ 0.,  1.,  2.,  3.,  4.]],
                [[ 1.,  2.,  3.,  4.,  5.]],
                [[ 2.,  3.,  4.,  5.,  6.]],
                [[ 3.,  4.,  5.,  6.,  6.]],
                [[ 4.,  5.,  6.,  6.,  6.]],
                [[ 5.,  6.,  6.,  6.,  6.]],
                [[ 6.,  6.,  6.,  6.,  6.]],
                [[ 7.,  8.,  9., 10., 11.]],
                [[ 8.,  9., 10., 11., 12.]],
                [[ 9., 10., 11., 12., 12.]],
                [[10., 11., 12., 12., 12.]],
                [[11., 12., 12., 12., 12.]],
                [[12., 12., 12., 12., 12.]],
                [[13., 13., 13., 13., 13.]]]))
        # yapf: enable

        alf.nest.map_structure(lambda x, y: self.assertEqual(x, y),
                               processed_experience.rollout_info, expected)