Example #1
    def test_simple_action_encoder(self):
        action_spec = [
            alf.BoundedTensorSpec((3, )),
            alf.BoundedTensorSpec((), dtype=torch.int32, minimum=0, maximum=3)
        ]
        encoder = SimpleActionEncoder(action_spec)

        # test scalar
        x = [torch.tensor([0.5, 1.5, 2.5]), torch.tensor(3)]
        y = encoder(x)[0]
        self.assertEqual(y, torch.tensor([0.5, 1.5, 2.5, 0, 0, 0, 1]))

        # test batch
        x = [torch.tensor([[0.5, 1.5, 2.5], [1, 2, 3]]), torch.tensor([3, 2])]
        y = encoder(x)[0]
        self.assertEqual(
            y,
            torch.tensor([[0.5, 1.5, 2.5, 0, 0, 0, 1], [1, 2, 3, 0, 0, 1, 0]]))

        # test unsupported spec
        action_spec = [
            alf.BoundedTensorSpec((3, )),
            alf.BoundedTensorSpec((), dtype=torch.int32, minimum=1, maximum=3)
        ]

        self.assertRaises(AssertionError, SimpleActionEncoder, action_spec)
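
The expected outputs above show what the encoder does: the continuous action passes through unchanged, while the bounded discrete action (values 0..3) is one-hot encoded and appended. A minimal sketch of that encoding (illustrative only, not ALF's actual SimpleActionEncoder implementation):

import torch
import torch.nn.functional as F

def encode_actions(continuous, discrete, num_values=4):
    # continuous: [B, 3] float tensor; discrete: [B] integer tensor in [0, 3].
    one_hot = F.one_hot(discrete.long(), num_classes=num_values).to(continuous.dtype)
    return torch.cat([continuous, one_hot], dim=-1)

# encode_actions(torch.tensor([[0.5, 1.5, 2.5]]), torch.tensor([3]))
# -> tensor([[0.5, 1.5, 2.5, 0., 0., 0., 1.]])
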
Example #2
def create_algorithm(env):
    config = TrainerConfig(root_dir="dummy", unroll_length=5)
    obs_spec = alf.TensorSpec((2, ), dtype='float32')
    action_spec = alf.BoundedTensorSpec(
        shape=(), dtype='int32', minimum=0, maximum=2)

    fc_layer_params = (10, 8, 6)

    actor_network = partial(
        ActorDistributionNetwork,
        fc_layer_params=fc_layer_params,
        discrete_projection_net_ctor=alf.networks.CategoricalProjectionNetwork)

    value_network = partial(ValueNetwork, fc_layer_params=(10, 8, 1))

    alg = ActorCriticAlgorithm(
        observation_spec=obs_spec,
        action_spec=action_spec,
        actor_network_ctor=actor_network,
        value_network_ctor=value_network,
        env=env,
        config=config,
        optimizer=alf.optimizers.Adam(lr=1e-2),
        debug_summaries=True,
        name="MyActorCritic")
    return alg
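
A minimal sketch of how such a factory might be driven, assuming a batched toy environment like the one in Example #6 and assuming the algorithm exposes ``train_iter()`` for on-policy training (both are assumptions, not verified against this repository):

env = MyEnv(batch_size=3)      # hypothetical toy env, see Example #6
alg = create_algorithm(env)
for _ in range(5):             # a few on-policy training iterations (assumed API)
    alg.train_iter()
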
Example #3
 def _transform_spec(spec):
     assert isinstance(
         spec,
         alf.TensorSpec), (str(type(spec)) + " is not a TensorSpec")
     assert spec.dtype == torch.uint8, "Image must have dtype uint8!"
     return alf.BoundedTensorSpec(
         spec.shape, dtype=torch.float32, minimum=min, maximum=max)
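
Here ``min`` and ``max`` are presumably captured from the enclosing image-normalization transformer's configuration. A sketch of the tensor transform that would typically accompany this spec transform, assuming uint8 pixels in [0, 255] are scaled into [min, max] (an assumption, not the verified ALF implementation):

def _transform_image(img, min=0., max=1.):
    # Scale uint8 pixels in [0, 255] to float32 values in [min, max].
    return img.to(torch.float32) * ((max - min) / 255.) + min
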
Example #4
 def _make_stacked_spec(self, spec):
     assert isinstance(
         spec, alf.TensorSpec), (str(type(spec)) + " is not a TensorSpec")
     if spec.ndim > 0:
         stacked_shape = list(copy.copy(spec.shape))
         stacked_shape[self._stack_axis] = stacked_shape[
             self._stack_axis] * self._stack_size
         stacked_shape = tuple(stacked_shape)
     else:
         stacked_shape = (self._stack_size, )
     if not spec.is_bounded():
         return alf.TensorSpec(stacked_shape, spec.dtype)
     else:
         if spec.minimum.shape != ():
             assert spec.minimum.shape == spec.shape
             minimum = np.repeat(
                 spec.minimum,
                 repeats=self._stack_size,
                 axis=self._stack_axis)
         else:
             minimum = spec.minimum
         if spec.maximum.shape != ():
             assert spec.maximum.shape == spec.shape
             maximum = np.repeat(
                 spec.maximum,
                 repeats=self._stack_size,
                 axis=self._stack_axis)
         else:
             maximum = spec.maximum
         return alf.BoundedTensorSpec(
             stacked_shape,
             minimum=minimum,
             maximum=maximum,
             dtype=spec.dtype)
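
A concrete illustration with hypothetical values: with ``_stack_axis=0`` and ``_stack_size=4``, a ``(3, 84, 84)`` uint8 spec bounded by [0, 255] is stacked along the channel axis into a ``(12, 84, 84)`` spec; scalar bounds are kept as-is, while per-element bounds would be repeated along the stack axis:

spec = alf.BoundedTensorSpec((3, 84, 84), dtype=torch.uint8, minimum=0, maximum=255)
stacked = frame_stacker._make_stacked_spec(spec)  # hypothetical FrameStacker instance
assert stacked.shape == (12, 84, 84)
assert stacked.minimum == 0 and stacked.maximum == 255
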
Example #5
File: suite_go.py  Project: soychanq/alf
    def __init__(self,
                 batch_size,
                 height=19,
                 width=19,
                 winning_thresh=7.5,
                 allow_suicidal_move=False,
                 reward_shaping=False,
                 human_player=None):
        """
        Args:
            batch_size (int): the number of parallel boards
            height (int): height of each board
            width (int): width of each board
            winning_thresh (float): player 0 wins if area0 - area1 > winning_thresh,
                loses if area0 - area1 < winning_thresh, otherwise it is a draw.
            allow_suicidal_move (bool): whether suicidal moves are allowed.
            reward_shaping (bool): if True, instead of using +1/-1 as the reward,
                use ``alf.math.softsign(area0 - area1 - winning_thresh)`` as the
                reward to encourage capturing more area.
            human_player (int|None): 0, 1 or None
        """
        self._batch_size = batch_size
        self._width = width
        self._height = height
        self._max_num_moves = 2 * height * width
        self._winning_thresh = float(winning_thresh)
        self._allow_suicidal_move = allow_suicidal_move
        self._reward_shaping = reward_shaping
        self._human_player = human_player

        # width*height for pass
        # otherwise it is a move at (y=action // width, x=action % width)
        self._action_spec = alf.BoundedTensorSpec((),
                                                  minimum=0,
                                                  maximum=width * height,
                                                  dtype=torch.int64)
        self._observation_spec = OrderedDict(
            board=alf.TensorSpec((1, height, width), torch.int8),
            prev_action=self._action_spec,
            valid_action_mask=alf.TensorSpec([width * height + 1], torch.bool),
            steps=alf.TensorSpec((), torch.int32),
            to_play=alf.TensorSpec((), torch.int8))

        self._B = torch.arange(self._batch_size)
        self._env_ids = torch.arange(batch_size)
        self._pass_action = width * height
        self._board = GoBoard(batch_size, height, width, self._max_num_moves)
        self._previous_board = self._board.get_board()
        self._num_moves = torch.zeros((batch_size, ), dtype=torch.int32)
        self._game_over = torch.zeros((batch_size, ), dtype=torch.bool)
        self._prev_action = torch.full((batch_size, ),
                                       self._pass_action,
                                       dtype=torch.int64)
        self._surface = None
        if human_player is not None:
            logging.info("Use mouse click to place a stone")
            logging.info("Kayboard control:")
            logging.info("P     : pass")
            logging.info("SPACE : refresh display")
Example #6
 def __init__(self, batch_size, obs_shape=(2, )):
     super().__init__()
     self._batch_size = batch_size
     self._rewards = torch.tensor([0.5, 1.0, -1.])
     self._observation_spec = alf.TensorSpec(obs_shape, dtype='float32')
     self._action_spec = alf.BoundedTensorSpec(
         shape=(), dtype='int64', minimum=0, maximum=2)
     self.reset()
Example #7
 def test_multitask_wrapper(self):
     env = alf_wrappers.MultitaskWrapper.load(
         suite_gym.load, ['CartPole-v0', 'CartPole-v1'])
     self.assertEqual(env.num_tasks, 2)
     self.assertEqual(env.action_spec()['task_id'],
                      alf.BoundedTensorSpec((), maximum=1, dtype='int64'))
     self.assertEqual(env.action_spec()['action'],
                      env._envs[0].action_spec())
     time_step = env.reset()
     time_step = env.step(
         OrderedDict(task_id=1, action=time_step.prev_action['action']))
     self.assertEqual(time_step.prev_action['task_id'], 1)
Example #8
    def action_spec(self):
        """Get the action spec.

        The action is a 4-D vector of [throttle, steer, brake, reverse], where
        throttle is in [-1.0, 1.0] (negative values are treated as zero), steer
        is in [-1.0, 1.0], brake is in [-1.0, 1.0] (negative values are treated
        as zero), and reverse is interpreted as a boolean value, with values
        greater than 0.5 corresponding to True.

        Returns:
            nested BoundedTensorSpec:
        """
        return alf.BoundedTensorSpec([4],
                                     minimum=[-1., -1., -1., 0.],
                                     maximum=[1., 1., 1., 1.])
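
A hedged sketch of decoding an action sampled from this spec into control values, following the clamping and thresholding described in the docstring (illustrative helper, not part of the suite):

def decode_vehicle_action(action):
    throttle, steer, brake, reverse = action.tolist()
    return dict(
        throttle=max(throttle, 0.),  # negative throttle is treated as zero
        steer=steer,                 # already in [-1, 1]
        brake=max(brake, 0.),        # negative brake is treated as zero
        reverse=reverse > 0.5)       # boolean: True if greater than 0.5
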
Example #9
    def __init__(self, envs, task_names, env_id=None):
        """
        Args:
            envs (list[AlfEnvironment]): a list of environments. Each one
                represents a different task.
            task_names (list[str]): the names of each task.
            env_id (int): (optional) ID of the environment.
        """
        assert len(envs) > 0, "`envs` should not be empty"
        assert len(set(task_names)) == len(task_names), (
            "task_names should "
            "not contain duplicated names: %s" % str(task_names))
        self._envs = envs
        self._observation_spec = envs[0].observation_spec()
        self._action_spec = envs[0].action_spec()
        self._reward_spec = envs[0].reward_spec()
        self._env_info_spec = envs[0].env_info_spec()
        self._task_names = task_names
        if env_id is None:
            env_id = 0
        self._env_id = np.int32(env_id)

        def _nested_eq(a, b):
            return all(
                alf.nest.flatten(
                    alf.nest.map_structure(lambda x, y: x == y, a, b)))

        for env in envs:
            assert _nested_eq(
                env.observation_spec(), self._observation_spec), (
                    "All environments should have the same observation spec. "
                    "Got %s vs %s" %
                    (self._observation_spec, env.observation_spec()))
            assert _nested_eq(env.action_spec(), self._action_spec), (
                "All environments should have the same action spec. "
                "Got %s vs %s" % (self._action_spec, env.action_spec()))
            assert _nested_eq(env.reward_spec(), self._reward_spec), (
                "All environments should have the same reward spec. "
                "Got %s vs %s" % (self._reward_spec, env.reward_spec()))
            assert _nested_eq(env.env_info_spec(), self._env_info_spec), (
                "All environments should have the same env_info spec. "
                "Got %s vs %s" % (self._env_info_spec, env.env_info_spec()))
            env.reset()

        self._current_env_id = np.int64(0)
        self._action_spec = OrderedDict(task_id=alf.BoundedTensorSpec(
            (), maximum=len(envs) - 1, dtype='int64'),
                                        action=self._action_spec)
Example #10
    def action_spec(self):
        """Get the action spec.

        The action is a 3-D vector of [speed, direction, reverse]. speed is in
        [-1.0, 1.0], with negative values meaning zero speed and 1.0 corresponding
        to the maximally allowed speed given by the ``max_speed`` argument of
        ``__init__()``. direction is the relative direction that the vehicle is
        facing, with 0 being front, -0.5 being left and 0.5 being right. reverse
        is interpreted as a boolean value, with values greater than 0.5
        corresponding to True to indicate going backward.

        Returns:
            alf.BoundedTensorSpec
        """
        return alf.BoundedTensorSpec([3],
                                     minimum=[-1., -1., 0.],
                                     maximum=[1., 1., 1.])
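
Similarly, a sketch of interpreting this 3-D action according to the docstring, where ``max_speed`` is the ``__init__()`` argument mentioned above (illustrative helper only):

def decode_nav_action(action, max_speed):
    speed, direction, reverse = action.tolist()
    return dict(
        speed=max(speed, 0.) * max_speed,  # negative values mean zero speed
        direction=direction,               # 0 is front, -0.5 left, 0.5 right
        reverse=reverse > 0.5)             # True means going backward
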
Example #11
 def __init__(self, batch_size):
     self._batch_size = batch_size
     self._observation_spec = alf.TensorSpec((3, 3))
     self._action_spec = alf.BoundedTensorSpec((),
                                               minimum=0,
                                               maximum=8,
                                               dtype=torch.int64)
     self._line_x = torch.tensor(
         [[0, 0, 0], [1, 1, 1], [2, 2, 2], [0, 1, 2], [0, 1, 2], [0, 1, 2],
          [0, 1, 2], [0, 1, 2]]).unsqueeze(0)
     self._line_y = torch.tensor(
         [[0, 1, 2], [0, 1, 2], [0, 1, 2], [0, 0, 0], [1, 1, 1], [2, 2, 2],
          [0, 1, 2], [2, 1, 0]]).unsqueeze(0)
     self._B = torch.arange(self._batch_size)
     self._empty_board = self._observation_spec.zeros()
     self._boards = self._observation_spec.zeros((self._batch_size, ))
     self._env_ids = torch.arange(batch_size)
     self._player_0 = torch.tensor(-1.)
     self._player_1 = torch.tensor(1.)
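
The two index tensors above enumerate the eight winning lines of tic-tac-toe (three rows, three columns, two diagonals). A sketch of how they could be used to check whether a player has completed a line on a batch of boards (illustrative only; ``boards`` has shape ``[B, 3, 3]`` and ``player`` is -1. or 1.):

line_x, line_y = self._line_x[0], self._line_y[0]  # drop the leading broadcast dim
lines = boards[:, line_x, line_y]                  # [B, 8, 3]: the 8 lines of each board
won = (lines == player).all(dim=-1).any(dim=-1)    # [B]: True if any line is fully owned
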
Example #12
    def test_mcts_algorithm(self):
        observation_spec = alf.TensorSpec((3, 3))
        action_spec = alf.BoundedTensorSpec((),
                                            dtype=torch.int64,
                                            minimum=0,
                                            maximum=8)
        model = TicTacToeModel()
        time_step = TimeStep(step_type=torch.tensor([StepType.MID]))

        # board situations and expected actions
        # yapf: disable
        cases = [
            ([[1, -1,  1],
              [1, -1, -1],
              [0,  0,  1]], 6),
            ([[0,  0,  0],
              [0, -1, -1],
              [0,  1,  0]], 3),
            ([[ 1, -1, -1],
              [-1, -1,  0],
              [ 0,  1,  1]], 6),
            ([[-1,  0,  1],
              [ 0, -1, -1],
              [ 0,  0,  1]], 3),
            ([[0, 0,  0],
              [0, 0,  0],
              [0, 0, -1]], 4),
            ([[0,  0, 0],
              [0, -1, 0],
              [0,  0, 0]], (0, 2, 6, 8)),
            ([[0,  0,  0],
              [0,  1, -1],
              [1, -1, -1]], 2),
        ]
        # yapf: enable

        def _create_mcts(observation_spec, action_spec, num_simulations):
            return MCTSAlgorithm(
                observation_spec,
                action_spec,
                discount=1.0,
                root_dirichlet_alpha=100.,
                root_exploration_fraction=0.25,
                num_simulations=num_simulations,
                pb_c_init=1.25,
                pb_c_base=19652,
                visit_softmax_temperature_fn=VisitSoftmaxTemperatureByMoves(
                    [(0, 1.0), (10, 0.0001)]),
                known_value_bounds=(-1, 1),
                is_two_player_game=True)

        # test each case serially
        for observation, action in cases:
            observation = torch.tensor([observation], dtype=torch.float32)
            state = MCTSState(steps=(observation != 0).sum(dim=(1, 2)))
            # We use a varying num_simulations instead of a fixed large number
            # such as 2000 to make the test faster.
            num_simulations = int((observation == 0).sum().cpu()) * 200
            mcts = _create_mcts(
                observation_spec, action_spec, num_simulations=num_simulations)
            mcts.set_model(model)
            alg_step = mcts.predict_step(
                time_step._replace(observation=observation), state)
            print(observation, alg_step.output, alg_step.info)
            if type(action) == tuple:
                self.assertTrue(alg_step.output[0] in action)
            else:
                self.assertEqual(alg_step.output[0], action)

        # test batch predict
        observation = torch.tensor([case[0] for case in cases],
                                   dtype=torch.float32)
        state = MCTSState(steps=(observation != 0).sum(dim=(1, 2)))
        mcts = _create_mcts(
            observation_spec, action_spec, num_simulations=2000)
        mcts.set_model(model)
        alg_step = mcts.predict_step(
            time_step._replace(
                step_type=torch.tensor([StepType.MID] * len(cases)),
                observation=observation), state)
        for i, (observation, action) in enumerate(cases):
            if type(action) == tuple:
                self.assertTrue(alg_step.output[i] in action)
            else:
                self.assertEqual(alg_step.output[i], action)
Example #13
    def _test_preprocess_experience(self, train_reward_function, td_steps,
                                    reanalyze_ratio, expected):
        """
        The following summarizes how the data is generated:

        .. code-block:: python

            # position:   01234567890123
            step_type0 = 'FMMMLFMMLFMMMM'
            step_type1 = 'FMMMMMLFMMMMLF'
            scale = 1. for current model
                    2. for target model
            observation = [position] * 3
            reward = position if train_reward_function and td_steps!=-1
                     else position * (step_type == LAST)
            value = 0.5 * position * scale
            action_probs = scale * [position, position+1, position] for env 0
                           scale * [position+1, position, position] for env 1
            action = 1 for env 0
                     0 for env 1

        """
        reanalyze_td_steps = 2

        num_unroll_steps = 4
        batch_size = 2
        obs_dim = 3

        observation_spec = alf.TensorSpec([obs_dim])
        action_spec = alf.BoundedTensorSpec((),
                                            minimum=0,
                                            maximum=1,
                                            dtype=torch.int32)
        reward_spec = alf.TensorSpec(())
        time_step_spec = ds.time_step_spec(observation_spec, action_spec,
                                           reward_spec)

        global _mcts_model_id
        _mcts_model_id = 0
        muzero = MuzeroAlgorithm(observation_spec,
                                 action_spec,
                                 model_ctor=_create_mcts_model,
                                 mcts_algorithm_ctor=MockMCTSAlgorithm,
                                 num_unroll_steps=num_unroll_steps,
                                 td_steps=td_steps,
                                 train_game_over_function=True,
                                 train_reward_function=train_reward_function,
                                 reanalyze_ratio=reanalyze_ratio,
                                 reanalyze_td_steps=reanalyze_td_steps,
                                 data_transformer_ctor=partial(FrameStacker,
                                                               stack_size=2))

        data_transformer = FrameStacker(observation_spec, stack_size=2)
        time_step = common.zero_tensor_from_nested_spec(
            time_step_spec, batch_size)
        dt_state = common.zero_tensor_from_nested_spec(
            data_transformer.state_spec, batch_size)
        state = muzero.get_initial_predict_state(batch_size)
        transformed_time_step, dt_state = data_transformer.transform_timestep(
            time_step, dt_state)
        alg_step = muzero.rollout_step(transformed_time_step, state)
        alg_step_spec = dist_utils.extract_spec(alg_step)

        experience = ds.make_experience(time_step, alg_step, state)
        experience_spec = ds.make_experience(time_step_spec, alg_step_spec,
                                             muzero.train_state_spec)
        replay_buffer = ReplayBuffer(data_spec=experience_spec,
                                     num_environments=batch_size,
                                     max_length=16,
                                     keep_episodic_info=True)

        #             01234567890123
        step_type0 = 'FMMMLFMMLFMMMM'
        step_type1 = 'FMMMMMLFMMMMLF'

        dt_state = common.zero_tensor_from_nested_spec(
            data_transformer.state_spec, batch_size)
        for i in range(len(step_type0)):
            step_type = [step_type0[i], step_type1[i]]
            step_type = [
                ds.StepType.MID if c == 'M' else
                (ds.StepType.FIRST if c == 'F' else ds.StepType.LAST)
                for c in step_type
            ]
            step_type = torch.tensor(step_type, dtype=torch.int32)
            reward = torch.full([batch_size], float(i))
            if not train_reward_function or td_steps == -1:
                reward = reward * (step_type == ds.StepType.LAST).to(
                    torch.float32)
            time_step = time_step._replace(
                discount=(step_type != ds.StepType.LAST).to(torch.float32),
                step_type=step_type,
                observation=torch.tensor([[i, i + 1, i], [i + 1, i, i]],
                                         dtype=torch.float32),
                reward=reward,
                env_id=torch.arange(batch_size, dtype=torch.int32))
            transformed_time_step, dt_state = data_transformer.transform_timestep(
                time_step, dt_state)
            alg_step = muzero.rollout_step(transformed_time_step, state)
            experience = ds.make_experience(time_step, alg_step, state)
            replay_buffer.add_batch(experience)
            state = alg_step.state

        env_ids = torch.tensor([0] * 14 + [1] * 14, dtype=torch.int64)
        positions = torch.tensor(list(range(14)) + list(range(14)),
                                 dtype=torch.int64)
        experience = replay_buffer.get_field(None,
                                             env_ids.unsqueeze(-1).cpu(),
                                             positions.unsqueeze(-1).cpu())
        experience = experience._replace(replay_buffer=replay_buffer,
                                         batch_info=BatchInfo(
                                             env_ids=env_ids,
                                             positions=positions),
                                         rollout_info_field='rollout_info')
        processed_experience = muzero.preprocess_experience(experience)
        import pprint
        pprint.pprint(processed_experience.rollout_info)
        alf.nest.map_structure(lambda x, y: self.assertEqual(x, y),
                               processed_experience.rollout_info, expected)
Example #14
    def test_preprocess_experience(self):
        """
        The following summarizes how the data is generated:

        .. code-block:: python

            # position:   01234567890123
            step_type0 = 'FMMMLFMMLFMMMM'
            step_type1 = 'FMMMMMLFMMMMLF'
            reward = position if train_reward_function and td_steps!=-1
                     else position * (step_type == LAST)
            action = t + 1 for env 0
                     t for env 1

        """
        num_unroll_steps = 4
        batch_size = 2
        obs_dim = 3
        observation_spec = alf.TensorSpec([obs_dim])
        action_spec = alf.BoundedTensorSpec((1, ),
                                            minimum=0,
                                            maximum=1,
                                            dtype=torch.float32)
        reward_spec = alf.TensorSpec(())
        time_step_spec = ds.time_step_spec(observation_spec, action_spec,
                                           reward_spec)

        repr_learner = PredictiveRepresentationLearner(
            observation_spec,
            action_spec,
            num_unroll_steps=num_unroll_steps,
            decoder_ctor=partial(SimpleDecoder,
                                 target_field='reward',
                                 decoder_net_ctor=partial(
                                     EncodingNetwork, fc_layer_params=(4, ))),
            encoding_net_ctor=LSTMEncodingNetwork,
            dynamics_net_ctor=LSTMEncodingNetwork)

        time_step = common.zero_tensor_from_nested_spec(
            time_step_spec, batch_size)
        state = repr_learner.get_initial_predict_state(batch_size)
        alg_step = repr_learner.rollout_step(time_step, state)
        alg_step = alg_step._replace(output=torch.tensor([[1.], [0.]]))
        alg_step_spec = dist_utils.extract_spec(alg_step)

        experience = ds.make_experience(time_step, alg_step, state)
        experience_spec = ds.make_experience(time_step_spec, alg_step_spec,
                                             repr_learner.train_state_spec)
        replay_buffer = ReplayBuffer(data_spec=experience_spec,
                                     num_environments=batch_size,
                                     max_length=16,
                                     keep_episodic_info=True)

        #             01234567890123
        step_type0 = 'FMMMLFMMLFMMMM'
        step_type1 = 'FMMMMMLFMMMMLF'

        for i in range(len(step_type0)):
            step_type = [step_type0[i], step_type1[i]]
            step_type = [
                ds.StepType.MID if c == 'M' else
                (ds.StepType.FIRST if c == 'F' else ds.StepType.LAST)
                for c in step_type
            ]
            step_type = torch.tensor(step_type, dtype=torch.int32)
            reward = torch.full([batch_size], float(i))
            time_step = time_step._replace(
                discount=(step_type != ds.StepType.LAST).to(torch.float32),
                step_type=step_type,
                observation=torch.tensor([[i, i + 1, i], [i + 1, i, i]],
                                         dtype=torch.float32),
                reward=reward,
                env_id=torch.arange(batch_size, dtype=torch.int32))
            alg_step = repr_learner.rollout_step(time_step, state)
            alg_step = alg_step._replace(output=i + torch.tensor([[1.], [0.]]))
            experience = ds.make_experience(time_step, alg_step, state)
            replay_buffer.add_batch(experience)
            state = alg_step.state

        env_ids = torch.tensor([0] * 14 + [1] * 14, dtype=torch.int64)
        positions = torch.tensor(list(range(14)) + list(range(14)),
                                 dtype=torch.int64)
        experience = replay_buffer.get_field(None,
                                             env_ids.unsqueeze(-1).cpu(),
                                             positions.unsqueeze(-1).cpu())
        experience = experience._replace(replay_buffer=replay_buffer,
                                         batch_info=BatchInfo(
                                             env_ids=env_ids,
                                             positions=positions),
                                         rollout_info_field='rollout_info')
        processed_experience = repr_learner.preprocess_experience(experience)
        pprint.pprint(processed_experience.rollout_info)

        # yapf: disable
        expected = PredictiveRepresentationLearnerInfo(
            action=torch.tensor(
               [[[ 1.,  2.,  3.,  4.,  5.]],
                [[ 2.,  3.,  4.,  5.,  5.]],
                [[ 3.,  4.,  5.,  5.,  5.]],
                [[ 4.,  5.,  5.,  5.,  5.]],
                [[ 5.,  5.,  5.,  5.,  5.]],
                [[ 6.,  7.,  8.,  9.,  9.]],
                [[ 7.,  8.,  9.,  9.,  9.]],
                [[ 8.,  9.,  9.,  9.,  9.]],
                [[ 9.,  9.,  9.,  9.,  9.]],
                [[10., 11., 12., 13., 14.]],
                [[11., 12., 13., 14., 14.]],
                [[12., 13., 14., 14., 14.]],
                [[13., 14., 14., 14., 14.]],
                [[14., 14., 14., 14., 14.]],
                [[ 0.,  1.,  2.,  3.,  4.]],
                [[ 1.,  2.,  3.,  4.,  5.]],
                [[ 2.,  3.,  4.,  5.,  6.]],
                [[ 3.,  4.,  5.,  6.,  6.]],
                [[ 4.,  5.,  6.,  6.,  6.]],
                [[ 5.,  6.,  6.,  6.,  6.]],
                [[ 6.,  6.,  6.,  6.,  6.]],
                [[ 7.,  8.,  9., 10., 11.]],
                [[ 8.,  9., 10., 11., 12.]],
                [[ 9., 10., 11., 12., 12.]],
                [[10., 11., 12., 12., 12.]],
                [[11., 12., 12., 12., 12.]],
                [[12., 12., 12., 12., 12.]],
                [[13., 13., 13., 13., 13.]]]).unsqueeze(-1),
            mask=torch.tensor(
               [[[ True,  True,  True,  True,  True]],
                [[ True,  True,  True,  True, False]],
                [[ True,  True,  True, False, False]],
                [[ True,  True, False, False, False]],
                [[ True, False, False, False, False]],
                [[ True,  True,  True,  True, False]],
                [[ True,  True,  True, False, False]],
                [[ True,  True, False, False, False]],
                [[ True, False, False, False, False]],
                [[ True,  True,  True,  True,  True]],
                [[ True,  True,  True,  True, False]],
                [[ True,  True,  True, False, False]],
                [[ True,  True, False, False, False]],
                [[ True, False, False, False, False]],
                [[ True,  True,  True,  True,  True]],
                [[ True,  True,  True,  True,  True]],
                [[ True,  True,  True,  True,  True]],
                [[ True,  True,  True,  True, False]],
                [[ True,  True,  True, False, False]],
                [[ True,  True, False, False, False]],
                [[ True, False, False, False, False]],
                [[ True,  True,  True,  True,  True]],
                [[ True,  True,  True,  True,  True]],
                [[ True,  True,  True,  True, False]],
                [[ True,  True,  True, False, False]],
                [[ True,  True, False, False, False]],
                [[ True, False, False, False, False]],
                [[ True, False, False, False, False]]]),
            target=torch.tensor(
               [[[ 0.,  1.,  2.,  3.,  4.]],
                [[ 1.,  2.,  3.,  4.,  4.]],
                [[ 2.,  3.,  4.,  4.,  4.]],
                [[ 3.,  4.,  4.,  4.,  4.]],
                [[ 4.,  4.,  4.,  4.,  4.]],
                [[ 5.,  6.,  7.,  8.,  8.]],
                [[ 6.,  7.,  8.,  8.,  8.]],
                [[ 7.,  8.,  8.,  8.,  8.]],
                [[ 8.,  8.,  8.,  8.,  8.]],
                [[ 9., 10., 11., 12., 13.]],
                [[10., 11., 12., 13., 13.]],
                [[11., 12., 13., 13., 13.]],
                [[12., 13., 13., 13., 13.]],
                [[13., 13., 13., 13., 13.]],
                [[ 0.,  1.,  2.,  3.,  4.]],
                [[ 1.,  2.,  3.,  4.,  5.]],
                [[ 2.,  3.,  4.,  5.,  6.]],
                [[ 3.,  4.,  5.,  6.,  6.]],
                [[ 4.,  5.,  6.,  6.,  6.]],
                [[ 5.,  6.,  6.,  6.,  6.]],
                [[ 6.,  6.,  6.,  6.,  6.]],
                [[ 7.,  8.,  9., 10., 11.]],
                [[ 8.,  9., 10., 11., 12.]],
                [[ 9., 10., 11., 12., 12.]],
                [[10., 11., 12., 12., 12.]],
                [[11., 12., 12., 12., 12.]],
                [[12., 12., 12., 12., 12.]],
                [[13., 13., 13., 13., 13.]]]))
        # yapf: enable

        alf.nest.map_structure(lambda x, y: self.assertEqual(x, y),
                               processed_experience.rollout_info, expected)