Example #1
    def train_step(self, exp: Experience, state):
        new_state = AgentState()
        info = AgentInfo()

        skill_generator_info = exp.rollout_info.skill_generator

        subtrajectory = self._skill_generator.update_disc_subtrajectory(
            exp, state.skill_generator)
        skill_step = self._skill_generator.train_step(
            exp._replace(rollout_info=skill_generator_info),
            state.skill_generator)
        info = info._replace(skill_generator=skill_step.info)
        new_state = new_state._replace(skill_generator=skill_step.state)

        exp = transform_nest(exp, "observation", self._observation_transformer)

        observation = self._make_low_level_observation(
            subtrajectory, skill_step.output,
            skill_generator_info.switch_skill, skill_generator_info.steps,
            skill_step.state.discriminator.first_observation)

        rl_step = self._rl_algorithm.train_step(
            exp._replace(observation=observation,
                         rollout_info=exp.rollout_info.rl), state.rl)

        new_state = new_state._replace(rl=rl_step.state)
        info = info._replace(rl=rl_step.info)

        return AlgStep(output=rl_step.output, state=new_state, info=info)
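The pattern to notice above is that each sub-algorithm only ever sees its own slice of the nested rollout_info (and its own sub-state), and the parent merges the returned info/state back with ``_replace``. Below is a minimal, self-contained sketch of that idiom using plain namedtuples; ``Exp``, ``ParentInfo``, and ``SubInfo`` are hypothetical stand-ins, not ALF classes.

from collections import namedtuple

Exp = namedtuple('Exp', ['observation', 'rollout_info'])
ParentInfo = namedtuple('ParentInfo', ['skill_generator', 'rl'])
SubInfo = namedtuple('SubInfo', ['value'])


def sub_train_step(exp):
    # A sub-algorithm only ever receives its own rollout_info slice.
    assert isinstance(exp.rollout_info, SubInfo)
    return SubInfo(value=exp.rollout_info.value + 1)


exp = Exp(observation=0.0,
          rollout_info=ParentInfo(skill_generator=SubInfo(value=1),
                                  rl=SubInfo(value=2)))

# Hand the sub-algorithm its slice ...
sub_info = sub_train_step(
    exp._replace(rollout_info=exp.rollout_info.skill_generator))
# ... and merge the result back into the parent's nested info.
info = exp.rollout_info._replace(skill_generator=sub_info)
print(info)  # ParentInfo(skill_generator=SubInfo(value=2), rl=SubInfo(value=2))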
Example #2
    def preprocess_experience(self, exp: Experience):
        ac_info = exp.rollout_info.ac._replace(
            action_distribution=exp.rollout_info.action_distribution)
        new_exp = self._ac_algorithm.preprocess_experience(
            exp._replace(rollout_info=ac_info))
        return new_exp._replace(rollout_info=exp.rollout_info._replace(
            ac=new_exp.rollout_info))
Example #3
    def preprocess_experience(self, experience: Experience):
        """Fill experience.rollout_info with PredictiveRepresentationLearnerInfo

        Note that the shape of experience is [B, T, ...].

        The target is a Tensor (or a nest of Tensors) when there is only one
        decoder. When there are multiple decoders, the target is a list, and
        each of its elements is a Tensor (or a nest of Tensors) that serves as
        the target for the corresponding decoder.

        """
        assert experience.batch_info != ()
        batch_info: BatchInfo = experience.batch_info
        replay_buffer: ReplayBuffer = experience.replay_buffer
        mini_batch_length = experience.step_type.shape[1]

        with alf.device(replay_buffer.device):
            # [B, 1]
            positions = convert_device(batch_info.positions).unsqueeze(-1)
            # [B, 1]
            env_ids = convert_device(batch_info.env_ids).unsqueeze(-1)

            # [B, T]
            positions = positions + torch.arange(mini_batch_length)

            # [B, T]
            steps_to_episode_end = replay_buffer.steps_to_episode_end(
                positions, env_ids)
            # [B, T]
            episode_end_positions = positions + steps_to_episode_end

            # [B, T, unroll_steps+1]
            positions = positions.unsqueeze(-1) + torch.arange(
                self._num_unroll_steps + 1)
            # [B, 1, 1]
            env_ids = env_ids.unsqueeze(-1)
            # [B, T, 1]
            episode_end_positions = episode_end_positions.unsqueeze(-1)

            # [B, T, unroll_steps+1]
            mask = positions <= episode_end_positions

            # [B, T, unroll_steps+1]
            positions = torch.min(positions, episode_end_positions)

            # [B, T, unroll_steps+1, ...]
            target = replay_buffer.get_field(self._target_fields, env_ids,
                                             positions)

            # [B, T, unroll_steps+1]
            action = replay_buffer.get_field('action', env_ids, positions)

            rollout_info = PredictiveRepresentationLearnerInfo(action=action,
                                                               mask=mask,
                                                               target=target)

        rollout_info = convert_device(rollout_info)

        return experience._replace(rollout_info=rollout_info)
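To make the shape bookkeeping concrete, here is a self-contained toy version of the position/mask arithmetic with hand-made tensors instead of a real ``ReplayBuffer``; all numbers are invented for illustration.

import torch

B, T, num_unroll_steps = 2, 3, 2
start_positions = torch.tensor([[5], [8]])        # [B, 1] sampled start positions
steps_to_episode_end = torch.tensor([[6, 5, 4],   # [B, T] (made-up values)
                                     [1, 0, 3]])

positions = start_positions + torch.arange(T)      # [B, T]
episode_end = positions + steps_to_episode_end     # [B, T]

# Expand each position into unroll_steps+1 future positions ...
positions = positions.unsqueeze(-1) + torch.arange(num_unroll_steps + 1)  # [B, T, U+1]
episode_end = episode_end.unsqueeze(-1)                                   # [B, T, 1]

# ... mark which of them are still inside the episode ...
mask = positions <= episode_end                    # [B, T, U+1]
# ... and clamp the rest to the episode end so replay-buffer reads stay valid.
positions = torch.min(positions, episode_end)

print(mask[1, 1])       # tensor([ True, False, False]): the episode ends right away
print(positions[1, 1])  # tensor([9, 9, 9]): later steps are clamped to the end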
Example #4
    def train_step(self, exp: Experience, state):
        new_state = AgentState()
        info = AgentInfo()
        observation = exp.observation

        if self._representation_learner is not None:
            repr_step = self._representation_learner.train_step(
                exp._replace(rollout_info=exp.rollout_info.repr), state.repr)
            new_state = new_state._replace(repr=repr_step.state)
            info = info._replace(repr=repr_step.info)
            observation = repr_step.output

        if self._goal_generator is not None:
            goal_step = self._goal_generator.train_step(
                exp._replace(
                    observation=observation,
                    rollout_info=exp.rollout_info.goal_generator),
                state.goal_generator)
            info = info._replace(goal_generator=goal_step.info)
            new_state = new_state._replace(goal_generator=goal_step.state)
            observation = [observation, goal_step.output]

        if self._irm is not None:
            irm_step = self._irm.train_step(
                exp._replace(observation=observation), state=state.irm)
            info = info._replace(irm=irm_step.info)
            new_state = new_state._replace(irm=irm_step.state)

        rl_step = self._rl_algorithm.train_step(
            exp._replace(
                observation=observation, rollout_info=exp.rollout_info.rl),
            state.rl)

        new_state = new_state._replace(rl=rl_step.state)
        info = info._replace(rl=rl_step.info)

        if self._entropy_target_algorithm:
            assert 'action_distribution' in rl_step.info._fields, (
                "PolicyStep from rl_algorithm.train_step() does not contain "
                "`action_distribution`, which is required by "
                "`enforce_entropy_target`")
            et_step = self._entropy_target_algorithm.train_step(
                rl_step.info.action_distribution, step_type=exp.step_type)
            info = info._replace(entropy_target=et_step.info)

        return AlgStep(output=rl_step.output, state=new_state, info=info)
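The structure above is a pipeline of optional modules, where each configured module may replace or extend the observation seen by everything after it. A minimal sketch of that control flow, with hypothetical callables standing in for the ALF sub-algorithms:

def agent_forward(observation, repr_learner=None, goal_generator=None):
    # Each optional module, if configured, transforms what downstream modules see.
    if repr_learner is not None:
        observation = repr_learner(observation)   # learned representation replaces raw obs
    if goal_generator is not None:
        goal = goal_generator(observation)
        observation = [observation, goal]         # downstream RL conditions on [obs, goal]
    return observation


# With both modules configured, the RL algorithm would receive a [repr, goal] pair.
obs = agent_forward(observation=1.0,
                    repr_learner=lambda o: o * 2.0,
                    goal_generator=lambda o: o + 0.5)
print(obs)  # [2.0, 2.5]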
Example #5
    def preprocess_experience(self, exp: Experience):
        self.summarize_reward("training_reward/extrinsic", exp.reward)
        # relabel exp with intrinsic rewards from the skill generator
        skill_rollout_info = exp.rollout_info.skill_generator
        new_exp = self._rl_algorithm.preprocess_experience(
            exp._replace(reward=skill_rollout_info.reward,
                         discount=exp.discount *
                         exp.rollout_info.skill_discount,
                         rollout_info=exp.rollout_info.rl))
        return new_exp._replace(rollout_info=exp.rollout_info._replace(
            rl=new_exp.rollout_info))
Example #6
    def preprocess_experience(self, experience: Experience):
        """Fill experience.rollout_info with PredictiveRepresentationLearnerInfo

        Note that the shape of experience is [B, T, ...]
        """
        assert experience.batch_info != ()
        batch_info: BatchInfo = experience.batch_info
        replay_buffer: ReplayBuffer = experience.replay_buffer
        mini_batch_length = experience.step_type.shape[1]

        with alf.device(replay_buffer.device):
            # [B, 1]
            positions = convert_device(batch_info.positions).unsqueeze(-1)
            # [B, 1]
            env_ids = convert_device(batch_info.env_ids).unsqueeze(-1)

            # [B, T]
            positions = positions + torch.arange(mini_batch_length)

            # [B, T]
            steps_to_episode_end = replay_buffer.steps_to_episode_end(
                positions, env_ids)
            # [B, T]
            episode_end_positions = positions + steps_to_episode_end

            # [B, T, unroll_steps+1]
            positions = positions.unsqueeze(-1) + torch.arange(
                self._num_unroll_steps + 1)
            # [B, 1, 1]
            env_ids = env_ids.unsqueeze(-1)
            # [B, T, 1]
            episode_end_positions = episode_end_positions.unsqueeze(-1)

            # [B, T, unroll_steps+1]
            mask = positions <= episode_end_positions

            # [B, T, unroll_steps+1]
            positions = torch.min(positions, episode_end_positions)

            # [B, T, unroll_steps+1]
            target = replay_buffer.get_field(self._target_fields, env_ids,
                                             positions)

            # [B, T, unroll_steps+1]
            action = replay_buffer.get_field('action', env_ids, positions)

            rollout_info = PredictiveRepresentationLearnerInfo(action=action,
                                                               mask=mask,
                                                               target=target)

        rollout_info = convert_device(rollout_info)

        return experience._replace(rollout_info=rollout_info)
Example #7
    def preprocess_experience(self, exp: Experience):
        """Add intrinsic rewards to extrinsic rewards if there is an intrinsic
        reward module. Also call ``preprocess_experience()`` of the rl
        algorithm.
        """
        reward = self.calc_training_reward(exp.reward, exp.rollout_info)
        exp = exp._replace(reward=reward)

        if self._representation_learner:
            new_exp = self._representation_learner.preprocess_experience(
                exp._replace(rollout_info=exp.rollout_info.repr,
                             rollout_info_field=exp.rollout_info_field +
                             '.repr'))
            exp = new_exp._replace(rollout_info=exp.rollout_info._replace(
                repr=new_exp.rollout_info))

        new_exp = self._rl_algorithm.preprocess_experience(
            exp._replace(rollout_info=exp.rollout_info.rl,
                         rollout_info_field=exp.rollout_info_field + '.rl'))
        return new_exp._replace(
            rollout_info=exp.rollout_info._replace(rl=new_exp.rollout_info),
            rollout_info_field=exp.rollout_info_field)
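Besides swapping the nested rollout_info, the code above also extends ``rollout_info_field`` with ``'.repr'`` / ``'.rl'``. The dotted path lets a sub-algorithm later locate its own info inside whatever was stored for the parent (Example #14 below uses ``rollout_info_field`` exactly this way to build replay-buffer field names). A small sketch, using a hypothetical ``get_by_path`` helper rather than ALF's nest utilities:

from collections import namedtuple

AgentInfo = namedtuple('AgentInfo', ['repr', 'rl'])
Stored = namedtuple('Stored', ['rollout_info'])


def get_by_path(nest, path):
    """Resolve a dotted field path such as 'rollout_info.repr' (hypothetical helper)."""
    for name in path.split('.'):
        nest = getattr(nest, name)
    return nest


stored = Stored(rollout_info=AgentInfo(repr='repr-info', rl='rl-info'))

# The representation learner is handed the path 'rollout_info' + '.repr', so it
# can later find its own slice inside the stored experience.
print(get_by_path(stored, 'rollout_info' + '.repr'))  # 'repr-info'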
Example #8
    def observe(self, exp: Experience):
        """An algorithm can override to manipulate experience.

        Args:
            exp (Experience): The shapes can be either [Q, T, B, ...] or
                [B, ...], where Q is `learn_queue_cap` in `AsyncOffPolicyDriver`,
                T is the sequence length, and B is the batch size of the batched
                environment.
        """
        if not self._use_rollout_state:
            exp = exp._replace(state=())
        exp = nest_utils.distributions_to_params(exp)
        for observer in self._exp_observers:
            observer(exp)
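``observe()`` itself is just a fan-out: anything that wants finished experience (a replay buffer, a metrics tracker) registers a callable, and each experience is pushed to all of them. A minimal sketch with hypothetical registration code, not ALF's actual API:

class ExperienceBroadcaster:
    def __init__(self):
        self._exp_observers = []

    def add_observer(self, observer):
        self._exp_observers.append(observer)

    def observe(self, exp):
        for observer in self._exp_observers:
            observer(exp)


broadcaster = ExperienceBroadcaster()
replay = []
broadcaster.add_observer(replay.append)                    # e.g. a replay buffer
broadcaster.add_observer(lambda exp: print("seen:", exp))  # e.g. a metrics logger
broadcaster.observe({"reward": 1.0})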
Example #9
    def preprocess_experience(self, exp: Experience):
        """Compute advantages and put them into exp.rollout_info."""
        advantages = value_ops.generalized_advantage_estimation(
            rewards=exp.reward,
            values=exp.rollout_info.value,
            step_types=exp.step_type,
            discounts=exp.discount * self._loss._gamma,
            td_lambda=self._loss._lambda,
            time_major=False)
        advantages = torch.cat([
            advantages,
            torch.zeros(*advantages.shape[:-1], 1, dtype=advantages.dtype)
        ],
                               dim=-1)
        returns = exp.rollout_info.value + advantages
        return exp._replace(rollout_info=PPOInfo(
            exp.rollout_info.action_distribution, returns, advantages))
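For reference, the textbook GAE recursion that a call like this computes is sketched below in plain torch. It is not ALF's ``value_ops.generalized_advantage_estimation`` (episode boundaries and step types are ignored, and reward/discount conventions may differ); it only shows the recursion and why a zero is appended so advantages line up with values of shape [B, T].

import torch


def gae(rewards, values, discounts, td_lambda):
    """rewards/values/discounts: [B, T]; returns advantages of shape [B, T-1].

    Assumes rewards[:, t + 1] is the reward for the transition from step t to
    step t + 1 and that discounts already include gamma.
    """
    B, T = rewards.shape
    advantages = torch.zeros(B, T - 1)
    gae_t = torch.zeros(B)
    for t in reversed(range(T - 1)):
        # One-step TD error, bootstrapping from the next value.
        delta = (rewards[:, t + 1] + discounts[:, t + 1] * values[:, t + 1]
                 - values[:, t])
        gae_t = delta + discounts[:, t + 1] * td_lambda * gae_t
        advantages[:, t] = gae_t
    return advantages


adv = gae(rewards=torch.ones(2, 5),
          values=torch.zeros(2, 5),
          discounts=torch.full((2, 5), 0.99),
          td_lambda=0.95)
# As in the code above, pad a zero so advantages align with values ([B, T]).
adv = torch.cat([adv, torch.zeros(2, 1)], dim=-1)
print(adv.shape)  # torch.Size([2, 5])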
Example #10
    def train_step(self, rl_exp: Experience, state: ActionRepeatState):
        """Train the underlying RL algorithm ``self._rl``. Because in
        ``self.rollout_step()`` the replay buffer only stores info related to
        ``self._rl``, here we can directly call ``self._rl.train_step()``.

        Args:
            rl_exp (Experience): experiences that have been transformed to be
                learned by ``self._rl``.
            state (ActionRepeatState):
        """
        repr_state = ()
        if self._repr_learner is not None:
            repr_step = self._repr_learner.train_step(rl_exp, state.repr)
            rl_exp = rl_exp._replace(observation=repr_step.output)
            repr_state = repr_step.state

        rl_step = self._rl.train_step(rl_exp, state.rl)
        new_state = ActionRepeatState(rl=rl_step.state, repr=repr_state)
        return rl_step._replace(state=new_state)
Example #11
    def train_step(self, exp: Experience, state):
        ac_info = exp.rollout_info.ac._replace(
            action_distribution=exp.rollout_info.action_distribution)
        exp = exp._replace(rollout_info=ac_info)
        policy_step = self._ac_algorithm.train_step(exp, state)
        return self._make_policy_step(exp, state, policy_step)
Example #12
    def unroll(self, unroll_length):
        r"""Unroll ``unroll_length`` steps using the current policy.

        Because ``self._env`` is a batched environment, the total number of
        environment steps is ``self._env.batch_size * unroll_length``.

        Args:
            unroll_length (int): number of steps to unroll.
        Returns:
            Experience: The stacked experience with shape :math:`[T, B, \ldots]`
            for each of its members.
        """
        if self._current_time_step is None:
            self._current_time_step = common.get_initial_time_step(self._env)
        if self._current_policy_state is None:
            self._current_policy_state = self.get_initial_rollout_state(
                self._env.batch_size)
        if self._current_transform_state is None:
            self._current_transform_state = self.get_initial_transform_state(
                self._env.batch_size)

        time_step = self._current_time_step
        policy_state = self._current_policy_state
        trans_state = self._current_transform_state

        experience_list = []
        initial_state = self.get_initial_rollout_state(self._env.batch_size)

        env_step_time = 0.
        store_exp_time = 0.
        for _ in range(unroll_length):
            policy_state = common.reset_state_if_necessary(
                policy_state, initial_state, time_step.is_first())
            transformed_time_step, trans_state = self.transform_timestep(
                time_step, trans_state)
            # save the untransformed time step in case sub-algorithms need to
            # store it in replay buffers
            transformed_time_step = transformed_time_step._replace(
                untransformed=time_step)
            policy_step = self.rollout_step(transformed_time_step,
                                            policy_state)
            # release the reference to ``time_step``
            transformed_time_step = transformed_time_step._replace(
                untransformed=())

            action = common.detach(policy_step.output)

            t0 = time.time()
            next_time_step = self._env.step(action)
            env_step_time += time.time() - t0

            self.observe_for_metrics(time_step.cpu())

            if self._exp_replayer_type == "one_time":
                exp = make_experience(transformed_time_step, policy_step,
                                      policy_state)
            else:
                exp = make_experience(time_step.cpu(), policy_step,
                                      policy_state)

            t0 = time.time()
            self.observe_for_replay(exp)
            store_exp_time += time.time() - t0

            exp_for_training = Experience(
                action=action,
                reward=transformed_time_step.reward,
                discount=transformed_time_step.discount,
                step_type=transformed_time_step.step_type,
                state=policy_state,
                prev_action=transformed_time_step.prev_action,
                observation=transformed_time_step.observation,
                rollout_info=dist_utils.distributions_to_params(
                    policy_step.info),
                env_id=transformed_time_step.env_id)

            experience_list.append(exp_for_training)
            time_step = next_time_step
            policy_state = policy_step.state

        alf.summary.scalar("time/unroll_env_step", env_step_time)
        alf.summary.scalar("time/unroll_store_exp", store_exp_time)
        experience = alf.nest.utils.stack_nests(experience_list)
        experience = experience._replace(
            rollout_info=dist_utils.params_to_distributions(
                experience.rollout_info, self._rollout_info_spec))

        self._current_time_step = time_step
        # Need to detach so that the graph from this unroll is disconnected from
        # the next unroll. Otherwise backward() will report an error for
        # on-policy training after the next unroll.
        self._current_policy_state = common.detach(policy_state)
        self._current_transform_state = common.detach(trans_state)

        return experience
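The detach at the end is easy to gloss over, so here is a framework-free illustration of why the carried-over state must be cut from the previous unroll's graph; the single weight and toy "state" are of course made up.

import torch

w = torch.tensor(2.0, requires_grad=True)

state = torch.tensor(1.0)
state = state * w           # "unroll" #1 builds a graph through w
state.sum().backward()      # train on unroll #1; its graph is then freed

# Carrying ``state`` over *without* detaching would make the next backward()
# reach into the freed graph of unroll #1 and raise a RuntimeError.
state = state.detach()      # the same role as common.detach(policy_state) above
state = state * w           # "unroll" #2 starts a fresh graph
state.sum().backward()      # works; gradient flows only through unroll #2
print(w.grad)               # tensor(3.) = 1.0 (unroll #1) + 2.0 (unroll #2)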
Example #13
    def transform_experience(self, experience: Experience):
        if self._stack_size == 1:
            return experience

        assert experience.batch_info != ()
        batch_info: BatchInfo = experience.batch_info
        replay_buffer: ReplayBuffer = experience.replay_buffer

        with alf.device(replay_buffer.device):
            # [B]
            env_ids = convert_device(batch_info.env_ids)
            # [B]
            positions = convert_device(batch_info.positions)

            prev_positions = torch.arange(self._stack_size -
                                          1) - self._stack_size + 1

            # [B, stack_size - 1]
            prev_positions = positions.unsqueeze(
                -1) + prev_positions.unsqueeze(0)
            episode_begin_positions = replay_buffer.get_episode_begin_position(
                positions, env_ids)
            # [B, 1]
            episode_begin_positions = episode_begin_positions.unsqueeze(-1)
            # [B, stack_size - 1]
            prev_positions = torch.max(prev_positions, episode_begin_positions)
            # [B, 1]
            env_ids = env_ids.unsqueeze(-1)
            assert torch.all(
                prev_positions[:, 0] >= replay_buffer.get_earliest_position(
                    env_ids)
            ), ("Some previous posisions are no longer in the replay buffer")

        batch_size, mini_batch_length = experience.step_type.shape

        # [[0, 1, ..., stack_size-1],
        #  [1, 2, ..., stack_size],
        #  ...
        #  [mini_batch_length - 1, ...]]
        #
        # [mini_batch_length, stack_size]
        obs_index = (torch.arange(self._stack_size).unsqueeze(0) +
                     torch.arange(mini_batch_length).unsqueeze(1))
        B = torch.arange(batch_size)
        obs_index = (B.unsqueeze(-1).unsqueeze(-1), obs_index.unsqueeze(0))

        def _stack_frame(obs, i):
            prev_obs = replay_buffer.get_field(self._exp_fields[i], env_ids,
                                               prev_positions)
            prev_obs = convert_device(prev_obs)
            stacked_shape = alf.nest.get_field(
                self._transformed_observation_spec, self._fields[i]).shape
            # [batch_size, mini_batch_length + stack_size - 1, ...]
            stacked_obs = torch.cat((prev_obs, obs), dim=1)
            # [batch_size, mini_batch_length, stack_size, ...]
            stacked_obs = stacked_obs[obs_index]
            if self._stack_axis != 0 and obs.ndim > 3:
                stack_axis = self._stack_axis
                if stack_axis < 0:
                    stack_axis += stacked_obs.ndim
                else:
                    stack_axis += 3
                stacked_obs = stacked_obs.unsqueeze(stack_axis)
                stacked_obs = stacked_obs.transpose(2, stack_axis)
                stacked_obs = stacked_obs.squeeze(2)
            stacked_obs = stacked_obs.reshape(batch_size, mini_batch_length,
                                              *stacked_shape)
            return stacked_obs

        observation = experience.observation
        for i, field in enumerate(self._fields):
            observation = alf.nest.transform_nest(observation, field,
                                                  partial(_stack_frame, i=i))
        return experience._replace(observation=observation)
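The indexing trick above is compact enough to deserve a toy walk-through: concatenate the stack_size-1 previous frames in front of the mini-batch, then gather every sliding window in one advanced-indexing step. Scalar "observations" are used below so the result is easy to read; the numbers are made up.

import torch

batch_size, mini_batch_length, stack_size = 1, 4, 3

prev_obs = torch.tensor([[8., 9.]])             # [B, stack_size - 1] previous frames
obs = torch.tensor([[10., 11., 12., 13.]])      # [B, mini_batch_length] current frames

# [mini_batch_length, stack_size]: row t holds the window [t, t + stack_size - 1]
obs_index = (torch.arange(stack_size).unsqueeze(0) +
             torch.arange(mini_batch_length).unsqueeze(1))
B = torch.arange(batch_size)
obs_index = (B.unsqueeze(-1).unsqueeze(-1), obs_index.unsqueeze(0))

stacked = torch.cat((prev_obs, obs), dim=1)     # [B, mini_batch_length + stack_size - 1]
stacked = stacked[obs_index]                    # [B, mini_batch_length, stack_size]
print(stacked)
# tensor([[[ 8.,  9., 10.],
#          [ 9., 10., 11.],
#          [10., 11., 12.],
#          [11., 12., 13.]]])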
Example #14
    def preprocess_experience(self, experience: Experience):
        """Fill experience.rollout_info with MuzeroInfo

        Note that the shape of experience is [B, T, ...]
        """
        assert experience.batch_info != ()
        batch_info: BatchInfo = experience.batch_info
        replay_buffer: ReplayBuffer = experience.replay_buffer
        info_path: str = experience.rollout_info_field
        mini_batch_length = experience.step_type.shape[1]
        assert mini_batch_length == 1, (
            "Only support TrainerConfig.mini_batch_length=1, got %s" %
            mini_batch_length)

        value_field = info_path + '.value'
        candidate_actions_field = info_path + '.candidate_actions'
        candidate_action_policy_field = (info_path +
                                         '.candidate_action_policy')

        with alf.device(replay_buffer.device):
            positions = convert_device(batch_info.positions)  # [B]
            env_ids = convert_device(batch_info.env_ids)  # [B]

            if self._reanalyze_ratio > 0:
                # Here we assume the state and the info follow the same naming scheme.
                mcts_state_field = 'state' + info_path[len('rollout_info'):]
                r = torch.rand(
                    experience.step_type.shape[0]) < self._reanalyze_ratio
                r_candidate_actions, r_candidate_action_policy, r_values = self._reanalyze(
                    replay_buffer, env_ids[r], positions[r], mcts_state_field)

            # [B]
            steps_to_episode_end = replay_buffer.steps_to_episode_end(
                positions, env_ids)
            # [B]
            episode_end_positions = positions + steps_to_episode_end

            # [B, unroll_steps+1]
            positions = positions.unsqueeze(-1) + torch.arange(
                self._num_unroll_steps + 1)
            # [B, 1]
            env_ids = batch_info.env_ids.unsqueeze(-1)
            # [B, 1]
            episode_end_positions = episode_end_positions.unsqueeze(-1)

            beyond_episode_end = positions > episode_end_positions
            positions = torch.min(positions, episode_end_positions)

            if self._td_steps >= 0:
                values = self._calc_bootstrap_return(replay_buffer, env_ids,
                                                     positions, value_field)
            else:
                values = self._calc_monte_carlo_return(replay_buffer, env_ids,
                                                       positions, value_field)

            candidate_actions = replay_buffer.get_field(
                candidate_actions_field, env_ids, positions)
            candidate_action_policy = replay_buffer.get_field(
                candidate_action_policy_field, env_ids, positions)

            if self._reanalyze_ratio > 0:
                if not _is_empty(candidate_actions):
                    candidate_actions[r] = r_candidate_actions
                candidate_action_policy[r] = r_candidate_action_policy
                values[r] = r_values

            game_overs = ()
            if self._train_game_over_function or self._train_reward_function:
                game_overs = positions == episode_end_positions
                discount = replay_buffer.get_field('discount', env_ids,
                                                   positions)
                # When discount != 0, the game-over flag may be wrong because
                # the episode was truncated by TimeLimit or is the incomplete
                # last episode in the replay buffer; in those cases the true
                # future game-over status cannot be known.
                game_overs = game_overs & (discount == 0.)

            rewards = ()
            if self._train_reward_function:
                rewards = self._get_reward(replay_buffer, env_ids, positions)
                rewards[beyond_episode_end] = 0.
                values[game_overs] = 0.

            if not self._train_game_over_function:
                game_overs = ()

            action = replay_buffer.get_field('action', env_ids,
                                             positions[:, :-1])

            rollout_info = MuzeroInfo(
                action=action,
                value=(),
                target=ModelTarget(reward=rewards,
                                   action=candidate_actions,
                                   action_policy=candidate_action_policy,
                                   value=values,
                                   game_over=game_overs))

        # make the shape [B, T, ...], where T=1
        rollout_info = alf.nest.map_structure(lambda x: x.unsqueeze(1),
                                              rollout_info)
        rollout_info = convert_device(rollout_info)
        rollout_info = rollout_info._replace(
            value=experience.rollout_info.value)

        if self._reward_normalizer:
            experience = experience._replace(
                reward=rollout_info.target.reward[:, :, 0])
        return experience._replace(rollout_info=rollout_info)
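Finally, the reanalyze branch relies on a common boolean-mask pattern: draw a Bernoulli mask over the sampled sequences, recompute targets only for the selected rows, and write them back in place. A tiny sketch with placeholder targets (not actual MuZero quantities):

import torch

reanalyze_ratio = 0.5
values = torch.zeros(8, 3)                         # [B, unroll_steps + 1] stale targets

r = torch.rand(values.shape[0]) < reanalyze_ratio  # [B] rows selected for reanalysis
r_values = torch.ones(int(r.sum()), 3)             # fresher targets for those rows only

values[r] = r_values                               # in-place row replacement
print(r)
print(values[:, 0])                                # 1.0 exactly where r is True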