Example #1
    def train_step(self, exp: Experience, state, trainable=True):
        """This function trains the discriminator or generates intrinsic rewards.

        If ``trainable=True``, then it only generates and returns the pred loss;
        otherwise it only generates rewards with no grad.
        """
        # Either train the discriminator from its own replay buffer, or
        # compute intrinsic rewards for training lower_rl.
        untrans_observation, prev_skill, switch_skill, steps = exp.observation

        observation = self._observation_transformer(untrans_observation)
        loss = self._predict_skill_loss(observation, exp.prev_action,
                                        prev_skill, steps, state)

        first_observation = self._update_state_if_necessary(
            switch_skill, observation, state.first_observation)
        subtrajectory = self._clear_subtrajectory_if_necessary(
            state.subtrajectory, switch_skill)
        new_state = DiscriminatorState(first_observation=first_observation,
                                       untrans_observation=untrans_observation,
                                       subtrajectory=subtrajectory)

        valid_masks = (exp.step_type != StepType.FIRST)
        if self._sparse_reward:
            # Only give intrinsic rewards at the last step of the skill
            valid_masks &= switch_skill
        loss *= valid_masks.to(torch.float32)

        if trainable:
            info = LossInfo(loss=loss, extra=dict(discriminator_loss=loss))
            return AlgStep(state=new_state, info=info)
        else:
            intrinsic_reward = -loss.detach() / self._skill_dim
            return AlgStep(state=common.detach(new_state),
                           info=intrinsic_reward)
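The same per-sample loss serves two roles above: a training signal when ``trainable=True`` and, detached and negated, an intrinsic reward otherwise. Below is a minimal plain-PyTorch sketch of the reward branch, using toy tensors and a hypothetical ``skill_dim`` (not ALF's actual classes):

import torch

# Toy per-sample prediction losses and a hypothetical skill dimension.
loss = torch.tensor([0.3, 1.2, 0.0], requires_grad=True)
skill_dim = 4

# Mask out invalid steps, mirroring ``valid_masks`` in the example above.
valid_masks = torch.tensor([0.0, 1.0, 1.0])
masked_loss = loss * valid_masks

# Reward mode: detach so no gradient flows back into the discriminator.
intrinsic_reward = -masked_loss.detach() / skill_dim
assert not intrinsic_reward.requires_grad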
Example #2
    def _buffer_sampler(self, x, y):
        batch_size = get_nest_batch_size(y)
        if self._y_buffer.current_size >= batch_size:
            y1 = self._y_buffer.get_batch(batch_size)
            self._y_buffer.add_batch(y)
        else:
            self._y_buffer.add_batch(y)
            y1 = self._y_buffer.get_batch(batch_size)
        return x, common.detach(y1)
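The branch order in ``_buffer_sampler`` only differs while the buffer is still warming up: once it already holds at least ``batch_size`` items, old ``y`` samples are drawn before the new batch is added, so ``y1`` and ``y`` come from different iterations, and the drawn batch is detached. A toy sketch with a hypothetical list-backed buffer (not ALF's real buffer API):

import torch

class ToyBuffer:
    """Hypothetical list-backed buffer, only to illustrate the branch order."""

    def __init__(self):
        self._items = []

    @property
    def current_size(self):
        return len(self._items)

    def add_batch(self, batch):
        self._items.extend(batch.unbind(0))

    def get_batch(self, batch_size):
        # Random draw with replacement, mimicking a replay-buffer sample.
        idx = torch.randint(len(self._items), (batch_size,))
        return torch.stack([self._items[i] for i in idx])

buf, y = ToyBuffer(), torch.randn(4, 2)
if buf.current_size >= y.shape[0]:
    y1 = buf.get_batch(y.shape[0])   # draw old samples first
    buf.add_batch(y)
else:
    buf.add_batch(y)                 # warm-up: add first so the draw cannot underflow
    y1 = buf.get_batch(y.shape[0])
y1 = y1.detach()                     # block gradients through the resampled batch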
Example #3
    def rollout_step(self, time_step: TimeStep, state: ActorCriticState):
        """Rollout for one step."""
        value, value_state = self._value_network(time_step.observation,
                                                 state=state.value)

        # We detach time_step.observation here so that, if the observation is
        # computed by some other trainable module, the training of that module
        # will not be affected by the gradient back-propagated from the actor.
        # However, the gradient from the critic will still affect the training
        # of that module.
        action_distribution, actor_state = self._actor_network(
            common.detach(time_step.observation), state=state.actor)

        action = dist_utils.sample_action_distribution(action_distribution)
        return AlgStep(output=action,
                       state=ActorCriticState(actor=actor_state,
                                              value=value_state),
                       info=ActorCriticInfo(
                           value=value,
                           action_distribution=action_distribution))
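The comment above captures the core pattern of this listing: detach the observation on the actor path only, so a trainable upstream module is updated by the critic loss but not by the actor loss. A self-contained sketch with hypothetical ``encoder``/``actor``/``critic`` modules standing in for ALF's networks:

import torch
import torch.nn as nn

# Hypothetical modules: an upstream encoder, the actor and the critic.
encoder = nn.Linear(4, 8)
actor = nn.Linear(8, 2)
critic = nn.Linear(8, 1)

obs = encoder(torch.randn(1, 4))        # produced by a trainable module
actor_loss = actor(obs.detach()).sum()  # actor sees a detached copy
critic_loss = critic(obs).sum()         # critic keeps the graph

(actor_loss + critic_loss).backward()
# encoder.weight.grad now comes only from the critic path; the actor path
# contributed nothing because its input was detached.
assert encoder.weight.grad is not None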
Example #4
    def train_step(self, exp: Experience, state: SacState):
        # We detach exp.observation here so that, if the observation is
        # computed by some other trainable module, the training of that module
        # will not be affected by the gradient back-propagated from the actor.
        # However, the gradient from the critic will still affect the training
        # of that module.
        (action_distribution, action, critics,
         action_state) = self._predict_action(common.detach(exp.observation),
                                              state=state.action)

        log_pi = nest.map_structure(lambda dist, a: dist.log_prob(a),
                                    action_distribution, action)

        if self._act_type == ActionType.Mixed:
            # For mixed type, add log_pi separately
            log_pi = type(self._action_spec)(
                (sum(nest.flatten(log_pi[0])), sum(nest.flatten(log_pi[1]))))
        else:
            log_pi = sum(nest.flatten(log_pi))

        if self._prior_actor is not None:
            prior_step = self._prior_actor.train_step(exp, ())
            log_prior = dist_utils.compute_log_probability(
                prior_step.output, action)
            log_pi = log_pi - log_prior

        actor_state, actor_loss = self._actor_train_step(
            exp, state.actor, action, critics, log_pi, action_distribution)
        critic_state, critic_info = self._critic_train_step(
            exp, state.critic, action, log_pi, action_distribution)
        alpha_loss = self._alpha_train_step(log_pi)

        state = SacState(action=action_state,
                         actor=actor_state,
                         critic=critic_state)
        info = SacInfo(action_distribution=action_distribution,
                       actor=actor_loss,
                       critic=critic_info,
                       alpha=alpha_loss)
        return AlgStep(action, state, info)
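``sum(nest.flatten(log_pi))`` collapses the per-component log-probabilities of a nested action into one joint log-probability per sample. With an ordinary dict standing in for the nest (hypothetical component names), the reduction is just an element-wise sum of the leaves:

import torch

# Hypothetical per-component log-probabilities for a dict action space.
log_pi = {"arm": torch.tensor([-1.1, -0.5]),
          "gripper": torch.tensor([-0.7, -0.2])}

# ``sum(nest.flatten(log_pi))`` sums the leaves element-wise; with a plain
# dict this is simply:
joint_log_pi = sum(log_pi.values())  # tensor([-1.8000, -0.7000])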
Example #5
    def unroll(self, unroll_length):
        r"""Unroll ``unroll_length`` steps using the current policy.

        Because ``self._env`` is a batched environment, the total number of
        environment steps is ``self._env.batch_size * unroll_length``.

        Args:
            unroll_length (int): number of steps to unroll.
        Returns:
            Experience: The stacked experience with shape :math:`[T, B, \ldots]`
            for each of its members.
        """
        if self._current_time_step is None:
            self._current_time_step = common.get_initial_time_step(self._env)
        if self._current_policy_state is None:
            self._current_policy_state = self.get_initial_rollout_state(
                self._env.batch_size)
        if self._current_transform_state is None:
            self._current_transform_state = self.get_initial_transform_state(
                self._env.batch_size)

        time_step = self._current_time_step
        policy_state = self._current_policy_state
        trans_state = self._current_transform_state

        experience_list = []
        initial_state = self.get_initial_rollout_state(self._env.batch_size)

        env_step_time = 0.
        store_exp_time = 0.
        for _ in range(unroll_length):
            policy_state = common.reset_state_if_necessary(
                policy_state, initial_state, time_step.is_first())
            transformed_time_step, trans_state = self.transform_timestep(
                time_step, trans_state)
            # Save the untransformed time step in case sub-algorithms need to
            # store it in replay buffers.
            transformed_time_step = transformed_time_step._replace(
                untransformed=time_step)
            policy_step = self.rollout_step(transformed_time_step,
                                            policy_state)
            # release the reference to ``time_step``
            transformed_time_step = transformed_time_step._replace(
                untransformed=())

            action = common.detach(policy_step.output)

            t0 = time.time()
            next_time_step = self._env.step(action)
            env_step_time += time.time() - t0

            self.observe_for_metrics(time_step.cpu())

            if self._exp_replayer_type == "one_time":
                exp = make_experience(transformed_time_step, policy_step,
                                      policy_state)
            else:
                exp = make_experience(time_step.cpu(), policy_step,
                                      policy_state)

            t0 = time.time()
            self.observe_for_replay(exp)
            store_exp_time += time.time() - t0

            exp_for_training = Experience(
                action=action,
                reward=transformed_time_step.reward,
                discount=transformed_time_step.discount,
                step_type=transformed_time_step.step_type,
                state=policy_state,
                prev_action=transformed_time_step.prev_action,
                observation=transformed_time_step.observation,
                rollout_info=dist_utils.distributions_to_params(
                    policy_step.info),
                env_id=transformed_time_step.env_id)

            experience_list.append(exp_for_training)
            time_step = next_time_step
            policy_state = policy_step.state

        alf.summary.scalar("time/unroll_env_step", env_step_time)
        alf.summary.scalar("time/unroll_store_exp", store_exp_time)
        experience = alf.nest.utils.stack_nests(experience_list)
        experience = experience._replace(
            rollout_info=dist_utils.params_to_distributions(
                experience.rollout_info, self._rollout_info_spec))

        self._current_time_step = time_step
        # Detach so that the graph from this unroll is disconnected from the
        # next unroll; otherwise backward() will raise an error during
        # on-policy training after the next unroll.
        self._current_policy_state = common.detach(policy_state)
        self._current_transform_state = common.detach(trans_state)

        return experience
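The final two ``common.detach`` calls keep the carried-over policy and transform states from dragging the previous unroll's autograd graph into the next one. A minimal sketch of why that matters, with a hypothetical GRU cell standing in for the recurrent policy state (plain PyTorch, not ALF's unroll machinery):

import torch
import torch.nn as nn

# Hypothetical recurrent policy state carried across unrolls.
rnn = nn.GRUCell(4, 8)
state = torch.zeros(1, 8)

for _ in range(2):                       # two consecutive "unrolls"
    loss = torch.zeros(())
    for _ in range(3):                   # steps within one unroll
        state = rnn(torch.randn(1, 4), state)
        loss = loss + state.pow(2).mean()
    loss.backward()                      # graph only covers this unroll
    rnn.zero_grad()
    # Detach so the next unroll's backward() cannot reach this (already freed)
    # graph, which would otherwise raise a "backward through the graph a
    # second time" error.
    state = state.detach()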