Example #1
    def _prepare_reanalyze_data(self, replay_buffer: ReplayBuffer, env_ids,
                                positions, n1, n2):
        """
        Get the n1 + n2 steps of experience indicated by ``positions`` and return
        the first n1 steps as ``exp1`` and the next n2 steps as ``exp2``.
        """
        batch_size = env_ids.shape[0]
        n = n1 + n2
        flat_env_ids = env_ids.expand_as(positions).reshape(-1)
        flat_positions = positions.reshape(-1)
        exp = replay_buffer.get_field(None, flat_env_ids, flat_positions)

        if self._data_transformer_ctor is not None:
            if self._data_transformer is None:
                observation_spec = dist_utils.extract_spec(exp.observation)
                self._data_transformer = create_data_transformer(
                    self._data_transformer_ctor, observation_spec)

            # DataTransformer assumes the shape of exp is [B, T, ...]
            # It also needs exp.batch_info and exp.replay_buffer.
            exp = alf.nest.map_structure(lambda x: x.unsqueeze(1), exp)
            exp = exp._replace(batch_info=BatchInfo(flat_env_ids,
                                                    flat_positions),
                               replay_buffer=replay_buffer)
            exp = self._data_transformer.transform_experience(exp)
            exp = exp._replace(batch_info=(), replay_buffer=())
            exp = alf.nest.map_structure(lambda x: x.squeeze(1), exp)

        def _split1(x):
            shape = x.shape[1:]
            x = x.reshape(batch_size, n, *shape)
            return x[:, :n1, ...].reshape(batch_size * n1, *shape)

        def _split2(x):
            shape = x.shape[1:]
            x = x.reshape(batch_size, n, *shape)
            return x[:, n1:, ...].reshape(batch_size * n2, *shape)

        with alf.device(self._device):
            exp = convert_device(exp)
            exp1 = alf.nest.map_structure(_split1, exp)
            exp2 = alf.nest.map_structure(_split2, exp)

        return exp1, exp2
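
The two helpers ``_split1`` and ``_split2`` above just reshape the flat batch of n1 + n2 consecutive steps back to [batch_size, n1 + n2, ...] and slice it. Below is a minimal standalone sketch of that split on a single plain tensor; the sizes and the ``feature_dim`` leaf are made up for illustration and are not part of ALF.

# A minimal sketch (plain PyTorch, illustrative shapes only) of the split
# performed by _split1/_split2: a flat batch of n1 + n2 consecutive steps per
# environment is reshaped to [batch_size, n1 + n2, ...] and sliced into the
# first n1 steps and the remaining n2 steps.
import torch

batch_size, n1, n2, feature_dim = 4, 6, 2, 3      # hypothetical sizes
n = n1 + n2
flat = torch.randn(batch_size * n, feature_dim)   # stands in for one leaf of ``exp``

x = flat.reshape(batch_size, n, feature_dim)
exp1_leaf = x[:, :n1, ...].reshape(batch_size * n1, feature_dim)
exp2_leaf = x[:, n1:, ...].reshape(batch_size * n2, feature_dim)

assert exp1_leaf.shape == (batch_size * n1, feature_dim)
assert exp2_leaf.shape == (batch_size * n2, feature_dim)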
Example #2
    def test_frame_stacker(self, stack_axis=0):
        data_spec = DataItem(step_type=alf.TensorSpec((), dtype=torch.int32),
                             observation=dict(scalar=alf.TensorSpec(()),
                                              vector=alf.TensorSpec((7, )),
                                              matrix=alf.TensorSpec((5, 6)),
                                              tensor=alf.TensorSpec(
                                                  (2, 3, 4))))
        replay_buffer = ReplayBuffer(data_spec=data_spec,
                                     num_environments=2,
                                     max_length=1024,
                                     num_earliest_frames_ignored=2)
        frame_stacker = FrameStacker(
            data_spec.observation,
            stack_size=3,
            stack_axis=stack_axis,
            fields=['scalar', 'vector', 'matrix', 'tensor'])

        new_spec = frame_stacker.transformed_observation_spec
        self.assertEqual(new_spec['scalar'].shape, (3, ))
        self.assertEqual(new_spec['vector'].shape, (21, ))
        if stack_axis == -1:
            self.assertEqual(new_spec['matrix'].shape, (5, 18))
            self.assertEqual(new_spec['tensor'].shape, (2, 3, 12))
        elif stack_axis == 0:
            self.assertEqual(new_spec['matrix'].shape, (15, 6))
            self.assertEqual(new_spec['tensor'].shape, (6, 3, 4))

        def _step_type(t, period):
            if t % period == 0:
                return StepType.FIRST
            if t % period == period - 1:
                return StepType.LAST
            return StepType.MID

        observation = alf.nest.map_structure(
            lambda spec: spec.randn((1000, 2)), data_spec.observation)
        state = common.zero_tensor_from_nested_spec(frame_stacker.state_spec,
                                                    2)

        def _get_stacked_data(t, b):
            if stack_axis == -1:
                return dict(scalar=observation['scalar'][t, b],
                            vector=observation['vector'][t, b].reshape(-1),
                            matrix=observation['matrix'][t, b].transpose(
                                0, 1).reshape(5, 18),
                            tensor=observation['tensor'][t, b].permute(
                                1, 2, 0, 3).reshape(2, 3, 12))
            elif stack_axis == 0:
                return dict(scalar=observation['scalar'][t, b],
                            vector=observation['vector'][t, b].reshape(-1),
                            matrix=observation['matrix'][t, b].reshape(15, 6),
                            tensor=observation['tensor'][t,
                                                         b].reshape(6, 3, 4))

        def _check_equal(stacked, expected, b):
            self.assertEqual(stacked['scalar'][b], expected['scalar'])
            self.assertEqual(stacked['vector'][b], expected['vector'])
            self.assertEqual(stacked['matrix'][b], expected['matrix'])
            self.assertEqual(stacked['tensor'][b], expected['tensor'])

        for t in range(1000):
            batch = DataItem(
                step_type=torch.tensor([_step_type(t, 17),
                                        _step_type(t, 22)]),
                observation=alf.nest.map_structure(lambda x: x[t],
                                                   observation))
            replay_buffer.add_batch(batch)
            timestep, state = frame_stacker.transform_timestep(batch, state)
            if t == 0:
                for b in (0, 1):
                    expected = _get_stacked_data([0, 0, 0], b)
                    _check_equal(timestep.observation, expected, b)
            if t == 1:
                for b in (0, 1):
                    expected = _get_stacked_data([0, 0, 1], b)
                    _check_equal(timestep.observation, expected, b)
            if t == 2:
                for b in (0, 1):
                    expected = _get_stacked_data([0, 1, 2], b)
                    _check_equal(timestep.observation, expected, b)
            if t == 16:
                for b in (0, 1):
                    expected = _get_stacked_data([14, 15, 16], b)
                    _check_equal(timestep.observation, expected, b)
            if t == 17:
                for b, ts in ((0, [17, 17, 17]), (1, [15, 16, 17])):
                    expected = _get_stacked_data(ts, b)
                    _check_equal(timestep.observation, expected, b)
            if t == 18:
                for b, ts in ((0, [17, 17, 18]), (1, [16, 17, 18])):
                    expected = _get_stacked_data(ts, b)
                    _check_equal(timestep.observation, expected, b)
            if t == 22:
                for b, ts in ((0, [20, 21, 22]), (1, [22, 22, 22])):
                    expected = _get_stacked_data(ts, b)
                    _check_equal(timestep.observation, expected, b)

        batch_info = BatchInfo(env_ids=torch.tensor([0, 1, 0, 1],
                                                    dtype=torch.int64),
                               positions=torch.tensor([0, 1, 18, 22],
                                                      dtype=torch.int64))

        # [4, 2, ...]
        experience = replay_buffer.get_field(
            '', batch_info.env_ids.unsqueeze(-1),
            batch_info.positions.unsqueeze(-1) + torch.arange(2))
        experience = experience._replace(batch_info=batch_info,
                                         replay_buffer=replay_buffer)
        experience = frame_stacker.transform_experience(experience)
        expected = _get_stacked_data([0, 0, 0], 0)
        _check_equal(experience.observation, expected, (0, 0))
        expected = _get_stacked_data([0, 0, 1], 0)
        _check_equal(experience.observation, expected, (0, 1))

        expected = _get_stacked_data([0, 0, 1], 1)
        _check_equal(experience.observation, expected, (1, 0))
        expected = _get_stacked_data([0, 1, 2], 1)
        _check_equal(experience.observation, expected, (1, 1))

        expected = _get_stacked_data([17, 17, 18], 0)
        _check_equal(experience.observation, expected, (2, 0))
        expected = _get_stacked_data([17, 18, 19], 0)
        _check_equal(experience.observation, expected, (2, 1))

        expected = _get_stacked_data([22, 22, 22], 1)
        _check_equal(experience.observation, expected, (3, 0))
        expected = _get_stacked_data([22, 22, 23], 1)
        _check_equal(experience.observation, expected, (3, 1))
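
The shape assertions at the top of this test follow from ``FrameStacker`` keeping the last ``stack_size`` frames along the chosen axis. A minimal plain-PyTorch sketch of the two conventions checked above (this only illustrates the expected shapes; it is not ``FrameStacker`` itself):

# Stacking 3 frames of a (5, 6) "matrix" observation gives shape (15, 6) for
# stack_axis=0 and (5, 18) for stack_axis=-1.
import torch

frames = [torch.randn(5, 6) for _ in range(3)]   # the last 3 observations

stacked_axis0 = torch.cat(frames, dim=0)         # -> (15, 6)
stacked_axis_last = torch.cat(frames, dim=-1)    # -> (5, 18)

assert stacked_axis0.shape == (15, 6)
assert stacked_axis_last.shape == (5, 18)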
Example #3
    def _test_preprocess_experience(self, train_reward_function, td_steps,
                                    reanalyze_ratio, expected):
        """
        The following summarizes how the data is generated:

        .. code-block:: python

            # position:   01234567890123
            step_type0 = 'FMMMLFMMLFMMMM'
            step_type1 = 'FMMMMMLFMMMMLF'
            scale = 1. for current model
                    2. for target model
            observation = [position] * 3
            reward = position if train_reward_function and td_steps!=-1
                     else position * (step_type == LAST)
            value = 0.5 * position * scale
            action_probs = scale * [position, position+1, position] for env 0
                           scale * [position+1, position, position] for env 1
            action = 1 for env 0
                     0 for env 1

        """
        reanalyze_td_steps = 2

        num_unroll_steps = 4
        batch_size = 2
        obs_dim = 3

        observation_spec = alf.TensorSpec([obs_dim])
        action_spec = alf.BoundedTensorSpec((),
                                            minimum=0,
                                            maximum=1,
                                            dtype=torch.int32)
        reward_spec = alf.TensorSpec(())
        time_step_spec = ds.time_step_spec(observation_spec, action_spec,
                                           reward_spec)

        global _mcts_model_id
        _mcts_model_id = 0
        muzero = MuzeroAlgorithm(observation_spec,
                                 action_spec,
                                 model_ctor=_create_mcts_model,
                                 mcts_algorithm_ctor=MockMCTSAlgorithm,
                                 num_unroll_steps=num_unroll_steps,
                                 td_steps=td_steps,
                                 train_game_over_function=True,
                                 train_reward_function=train_reward_function,
                                 reanalyze_ratio=reanalyze_ratio,
                                 reanalyze_td_steps=reanalyze_td_steps,
                                 data_transformer_ctor=partial(FrameStacker,
                                                               stack_size=2))

        data_transformer = FrameStacker(observation_spec, stack_size=2)
        time_step = common.zero_tensor_from_nested_spec(
            time_step_spec, batch_size)
        dt_state = common.zero_tensor_from_nested_spec(
            data_transformer.state_spec, batch_size)
        state = muzero.get_initial_predict_state(batch_size)
        transformed_time_step, dt_state = data_transformer.transform_timestep(
            time_step, dt_state)
        alg_step = muzero.rollout_step(transformed_time_step, state)
        alg_step_spec = dist_utils.extract_spec(alg_step)

        experience = ds.make_experience(time_step, alg_step, state)
        experience_spec = ds.make_experience(time_step_spec, alg_step_spec,
                                             muzero.train_state_spec)
        replay_buffer = ReplayBuffer(data_spec=experience_spec,
                                     num_environments=batch_size,
                                     max_length=16,
                                     keep_episodic_info=True)

        #             01234567890123
        step_type0 = 'FMMMLFMMLFMMMM'
        step_type1 = 'FMMMMMLFMMMMLF'

        dt_state = common.zero_tensor_from_nested_spec(
            data_transformer.state_spec, batch_size)
        for i in range(len(step_type0)):
            step_type = [step_type0[i], step_type1[i]]
            step_type = [
                ds.StepType.MID if c == 'M' else
                (ds.StepType.FIRST if c == 'F' else ds.StepType.LAST)
                for c in step_type
            ]
            step_type = torch.tensor(step_type, dtype=torch.int32)
            reward = torch.full([batch_size], float(i))
            if not train_reward_function or td_steps == -1:
                reward = reward * (step_type == ds.StepType.LAST).to(
                    torch.float32)
            time_step = time_step._replace(
                discount=(step_type != ds.StepType.LAST).to(torch.float32),
                step_type=step_type,
                observation=torch.tensor([[i, i + 1, i], [i + 1, i, i]],
                                         dtype=torch.float32),
                reward=reward,
                env_id=torch.arange(batch_size, dtype=torch.int32))
            transformed_time_step, dt_state = data_transformer.transform_timestep(
                time_step, dt_state)
            alg_step = muzero.rollout_step(transformed_time_step, state)
            experience = ds.make_experience(time_step, alg_step, state)
            replay_buffer.add_batch(experience)
            state = alg_step.state

        env_ids = torch.tensor([0] * 14 + [1] * 14, dtype=torch.int64)
        positions = torch.tensor(list(range(14)) + list(range(14)),
                                 dtype=torch.int64)
        experience = replay_buffer.get_field(None,
                                             env_ids.unsqueeze(-1).cpu(),
                                             positions.unsqueeze(-1).cpu())
        experience = experience._replace(replay_buffer=replay_buffer,
                                         batch_info=BatchInfo(
                                             env_ids=env_ids,
                                             positions=positions),
                                         rollout_info_field='rollout_info')
        processed_experience = muzero.preprocess_experience(experience)
        import pprint
        pprint.pprint(processed_experience.rollout_info)
        alf.nest.map_structure(lambda x, y: self.assertEqual(x, y),
                               processed_experience.rollout_info, expected)
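
This test and the next one drive the replay buffer with step-type strings, one character per position, so the episode boundaries can be read directly from the string (for env 0, 'FMMMLFMMLFMMMM' means episodes at positions 0-4, 5-8, and 9 onward). A standalone sketch of that decoding, using a hypothetical stand-in for ``ds.StepType``:

# Decode a step-type string into a tensor of step types, mirroring the loop
# in the test above. The StepType enum here is an illustrative stand-in.
import enum
import torch

class StepType(enum.IntEnum):   # hypothetical stand-in for ds.StepType
    FIRST = 0
    MID = 1
    LAST = 2

def decode(step_type_str):
    table = {'F': StepType.FIRST, 'M': StepType.MID, 'L': StepType.LAST}
    return torch.tensor([table[c] for c in step_type_str], dtype=torch.int32)

print(decode('FMMMLFMMLFMMMM'))   # env 0: episodes [0-4], [5-8], [9-...]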
Example #4
    def test_preprocess_experience(self):
        """
        The following summarizes how the data is generated:

        .. code-block:: python

            # position:   01234567890123
            step_type0 = 'FMMMLFMMLFMMMM'
            step_type1 = 'FMMMMMLFMMMMLF'
            reward = position if train_reward_function and td_steps!=-1
                     else position * (step_type == LAST)
            action = t + 1 for env 0
                     t for env 1

        """
        num_unroll_steps = 4
        batch_size = 2
        obs_dim = 3
        observation_spec = alf.TensorSpec([obs_dim])
        action_spec = alf.BoundedTensorSpec((1, ),
                                            minimum=0,
                                            maximum=1,
                                            dtype=torch.float32)
        reward_spec = alf.TensorSpec(())
        time_step_spec = ds.time_step_spec(observation_spec, action_spec,
                                           reward_spec)

        repr_learner = PredictiveRepresentationLearner(
            observation_spec,
            action_spec,
            num_unroll_steps=num_unroll_steps,
            decoder_ctor=partial(SimpleDecoder,
                                 target_field='reward',
                                 decoder_net_ctor=partial(
                                     EncodingNetwork, fc_layer_params=(4, ))),
            encoding_net_ctor=LSTMEncodingNetwork,
            dynamics_net_ctor=LSTMEncodingNetwork)

        time_step = common.zero_tensor_from_nested_spec(
            time_step_spec, batch_size)
        state = repr_learner.get_initial_predict_state(batch_size)
        alg_step = repr_learner.rollout_step(time_step, state)
        alg_step = alg_step._replace(output=torch.tensor([[1.], [0.]]))
        alg_step_spec = dist_utils.extract_spec(alg_step)

        experience = ds.make_experience(time_step, alg_step, state)
        experience_spec = ds.make_experience(time_step_spec, alg_step_spec,
                                             repr_learner.train_state_spec)
        replay_buffer = ReplayBuffer(data_spec=experience_spec,
                                     num_environments=batch_size,
                                     max_length=16,
                                     keep_episodic_info=True)

        #             01234567890123
        step_type0 = 'FMMMLFMMLFMMMM'
        step_type1 = 'FMMMMMLFMMMMLF'

        for i in range(len(step_type0)):
            step_type = [step_type0[i], step_type1[i]]
            step_type = [
                ds.StepType.MID if c == 'M' else
                (ds.StepType.FIRST if c == 'F' else ds.StepType.LAST)
                for c in step_type
            ]
            step_type = torch.tensor(step_type, dtype=torch.int32)
            reward = torch.full([batch_size], float(i))
            time_step = time_step._replace(
                discount=(step_type != ds.StepType.LAST).to(torch.float32),
                step_type=step_type,
                observation=torch.tensor([[i, i + 1, i], [i + 1, i, i]],
                                         dtype=torch.float32),
                reward=reward,
                env_id=torch.arange(batch_size, dtype=torch.int32))
            alg_step = repr_learner.rollout_step(time_step, state)
            alg_step = alg_step._replace(output=i + torch.tensor([[1.], [0.]]))
            experience = ds.make_experience(time_step, alg_step, state)
            replay_buffer.add_batch(experience)
            state = alg_step.state

        env_ids = torch.tensor([0] * 14 + [1] * 14, dtype=torch.int64)
        positions = torch.tensor(list(range(14)) + list(range(14)),
                                 dtype=torch.int64)
        experience = replay_buffer.get_field(None,
                                             env_ids.unsqueeze(-1).cpu(),
                                             positions.unsqueeze(-1).cpu())
        experience = experience._replace(replay_buffer=replay_buffer,
                                         batch_info=BatchInfo(
                                             env_ids=env_ids,
                                             positions=positions),
                                         rollout_info_field='rollout_info')
        processed_experience = repr_learner.preprocess_experience(experience)
        pprint.pprint(processed_experience.rollout_info)

        # yapf: disable
        expected = PredictiveRepresentationLearnerInfo(
            action=torch.tensor(
               [[[ 1.,  2.,  3.,  4.,  5.]],
                [[ 2.,  3.,  4.,  5.,  5.]],
                [[ 3.,  4.,  5.,  5.,  5.]],
                [[ 4.,  5.,  5.,  5.,  5.]],
                [[ 5.,  5.,  5.,  5.,  5.]],
                [[ 6.,  7.,  8.,  9.,  9.]],
                [[ 7.,  8.,  9.,  9.,  9.]],
                [[ 8.,  9.,  9.,  9.,  9.]],
                [[ 9.,  9.,  9.,  9.,  9.]],
                [[10., 11., 12., 13., 14.]],
                [[11., 12., 13., 14., 14.]],
                [[12., 13., 14., 14., 14.]],
                [[13., 14., 14., 14., 14.]],
                [[14., 14., 14., 14., 14.]],
                [[ 0.,  1.,  2.,  3.,  4.]],
                [[ 1.,  2.,  3.,  4.,  5.]],
                [[ 2.,  3.,  4.,  5.,  6.]],
                [[ 3.,  4.,  5.,  6.,  6.]],
                [[ 4.,  5.,  6.,  6.,  6.]],
                [[ 5.,  6.,  6.,  6.,  6.]],
                [[ 6.,  6.,  6.,  6.,  6.]],
                [[ 7.,  8.,  9., 10., 11.]],
                [[ 8.,  9., 10., 11., 12.]],
                [[ 9., 10., 11., 12., 12.]],
                [[10., 11., 12., 12., 12.]],
                [[11., 12., 12., 12., 12.]],
                [[12., 12., 12., 12., 12.]],
                [[13., 13., 13., 13., 13.]]]).unsqueeze(-1),
            mask=torch.tensor(
               [[[ True,  True,  True,  True,  True]],
                [[ True,  True,  True,  True, False]],
                [[ True,  True,  True, False, False]],
                [[ True,  True, False, False, False]],
                [[ True, False, False, False, False]],
                [[ True,  True,  True,  True, False]],
                [[ True,  True,  True, False, False]],
                [[ True,  True, False, False, False]],
                [[ True, False, False, False, False]],
                [[ True,  True,  True,  True,  True]],
                [[ True,  True,  True,  True, False]],
                [[ True,  True,  True, False, False]],
                [[ True,  True, False, False, False]],
                [[ True, False, False, False, False]],
                [[ True,  True,  True,  True,  True]],
                [[ True,  True,  True,  True,  True]],
                [[ True,  True,  True,  True,  True]],
                [[ True,  True,  True,  True, False]],
                [[ True,  True,  True, False, False]],
                [[ True,  True, False, False, False]],
                [[ True, False, False, False, False]],
                [[ True,  True,  True,  True,  True]],
                [[ True,  True,  True,  True,  True]],
                [[ True,  True,  True,  True, False]],
                [[ True,  True,  True, False, False]],
                [[ True,  True, False, False, False]],
                [[ True, False, False, False, False]],
                [[ True, False, False, False, False]]]),
            target=torch.tensor(
               [[[ 0.,  1.,  2.,  3.,  4.]],
                [[ 1.,  2.,  3.,  4.,  4.]],
                [[ 2.,  3.,  4.,  4.,  4.]],
                [[ 3.,  4.,  4.,  4.,  4.]],
                [[ 4.,  4.,  4.,  4.,  4.]],
                [[ 5.,  6.,  7.,  8.,  8.]],
                [[ 6.,  7.,  8.,  8.,  8.]],
                [[ 7.,  8.,  8.,  8.,  8.]],
                [[ 8.,  8.,  8.,  8.,  8.]],
                [[ 9., 10., 11., 12., 13.]],
                [[10., 11., 12., 13., 13.]],
                [[11., 12., 13., 13., 13.]],
                [[12., 13., 13., 13., 13.]],
                [[13., 13., 13., 13., 13.]],
                [[ 0.,  1.,  2.,  3.,  4.]],
                [[ 1.,  2.,  3.,  4.,  5.]],
                [[ 2.,  3.,  4.,  5.,  6.]],
                [[ 3.,  4.,  5.,  6.,  6.]],
                [[ 4.,  5.,  6.,  6.,  6.]],
                [[ 5.,  6.,  6.,  6.,  6.]],
                [[ 6.,  6.,  6.,  6.,  6.]],
                [[ 7.,  8.,  9., 10., 11.]],
                [[ 8.,  9., 10., 11., 12.]],
                [[ 9., 10., 11., 12., 12.]],
                [[10., 11., 12., 12., 12.]],
                [[11., 12., 12., 12., 12.]],
                [[12., 12., 12., 12., 12.]],
                [[13., 13., 13., 13., 13.]]]))
        # yapf: enable

        alf.nest.map_structure(lambda x, y: self.assertEqual(x, y),
                               processed_experience.rollout_info, expected)
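
The expected ``action``/``target``/``mask`` tensors all follow one clamping rule: for a start position t, the ``num_unroll_steps + 1`` unroll positions are t, t+1, ... clamped at the episode's LAST step (or at the newest position in the replay buffer), and the mask marks the steps that were not clamped. A standalone sketch of that rule (hypothetical helper, not ALF's implementation):

# Compute the clamped unroll positions and mask for one start position.
import torch

def unroll_positions_and_mask(t, limit, num_unroll_steps=4):
    """limit = min(position of the episode's LAST step, newest position)."""
    steps = t + torch.arange(num_unroll_steps + 1)
    return steps.clamp(max=limit), steps <= limit

# Env 0, episode 'FMMML' occupying positions 0-4 (limit=4), start position 1:
pos, mask = unroll_positions_and_mask(1, limit=4)
print(pos)    # tensor([1, 2, 3, 4, 4])  -> matches the expected target row for position 1
print(mask)   # tensor([ True,  True,  True,  True, False])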
Example #5
    def _reanalyze1(self, replay_buffer: ReplayBuffer, env_ids, positions,
                    mcts_state_field):
        """Reanalyze one batch.

        This means:
        1. Re-plan the policy using MCTS for the first n1 = 1 + num_unroll_steps steps
        to get a fresh policy and value target.
        2. Calculate the value for the following n2 = reanalyze_td_steps steps so that
        we have values for a total of 1 + num_unroll_steps + reanalyze_td_steps steps.
        3. Use these values and the rewards from the replay buffer to calculate the
        n2-step bootstrapped value target for the first n1 steps.

        In order to do 1 and 2, we need to get the observations for n1 + n2 steps
        and process them using data_transformer.
        """
        batch_size = env_ids.shape[0]
        n1 = self._num_unroll_steps + 1
        n2 = self._reanalyze_td_steps
        env_ids, positions = self._next_n_positions(
            replay_buffer, env_ids, positions, self._num_unroll_steps + n2)
        # [B, n1]
        positions1 = positions[:, :n1]
        game_overs = replay_buffer.get_field('discount', env_ids,
                                             positions1) == 0.

        steps_to_episode_end = replay_buffer.steps_to_episode_end(
            positions1, env_ids)
        bootstrap_n = steps_to_episode_end.clamp(max=n2)

        exp1, exp2 = self._prepare_reanalyze_data(replay_buffer, env_ids,
                                                  positions, n1, n2)

        bootstrap_position = positions1 + bootstrap_n
        discount = replay_buffer.get_field('discount', env_ids,
                                           bootstrap_position)
        sum_reward = self._sum_discounted_reward(replay_buffer, env_ids,
                                                 positions1,
                                                 bootstrap_position, n2)

        if not self._train_reward_function:
            rewards = self._get_reward(replay_buffer, env_ids,
                                       bootstrap_position)

        with alf.device(self._device):
            bootstrap_n = convert_device(bootstrap_n)
            discount = convert_device(discount)
            sum_reward = convert_device(sum_reward)
            game_overs = convert_device(game_overs)

            # 1. Reanalyze the first n1 steps to get both the updated value and policy
            self._mcts.set_model(self._target_model)
            mcts_step = self._mcts.predict_step(
                exp1, alf.nest.get_field(exp1, mcts_state_field))
            self._mcts.set_model(self._model)
            candidate_actions = ()
            if not _is_empty(mcts_step.info.candidate_actions):
                candidate_actions = mcts_step.info.candidate_actions
                candidate_actions = candidate_actions.reshape(
                    batch_size, n1, *candidate_actions.shape[1:])
            candidate_action_policy = mcts_step.info.candidate_action_policy
            candidate_action_policy = candidate_action_policy.reshape(
                batch_size, n1, *candidate_action_policy.shape[1:])
            values1 = mcts_step.info.value.reshape(batch_size, n1)

            # 2. Calculate the value of the next n2 steps so that n2-step return
            # can be computed.
            model_output = self._target_model.initial_inference(
                exp2.observation)
            values2 = model_output.value.reshape(batch_size, n2)

            # 3. Calculate n2-step return
            values = torch.cat([values1, values2], dim=1)
            # [B, n1]
            bootstrap_pos = torch.arange(n1).unsqueeze(0) + bootstrap_n
            values = values[torch.arange(batch_size).unsqueeze(-1),
                            bootstrap_pos]
            values = values * discount * (self._discount**bootstrap_n.to(
                torch.float32))
            values = values + sum_reward
            if not self._train_reward_function:
                # For this condition, we need to set the value at and after the
                # last step to be the last reward.
                values = torch.where(game_overs, convert_device(rewards),
                                     values)
            return candidate_actions, candidate_action_policy, values
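
Step 3 above assembles an n2-step bootstrapped value target from the reanalyzed values, the discounted reward sum, and the bootstrap discount. For intuition only, a generic n-step bootstrap target looks like the sketch below; the exact reward indexing and discount handling inside ``_sum_discounted_reward`` may differ from this simplification.

# A generic n-step bootstrapped value target. ``b`` plays the role of
# bootstrap_n (clamped at the episode end), ``gamma`` the role of
# self._discount. This is a sketch, not ALF's implementation.
import torch

def n_step_bootstrap_target(rewards, values, t, b, gamma):
    """rewards[i] and values[i] are the reward/value at position i."""
    discounted_rewards = sum(gamma**i * rewards[t + 1 + i] for i in range(b))
    return discounted_rewards + gamma**b * values[t + b]

rewards = torch.arange(10, dtype=torch.float32)        # toy data
values = 0.5 * torch.arange(10, dtype=torch.float32)
print(n_step_bootstrap_target(rewards, values, t=3, b=2, gamma=0.9))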