Example #1
0
 def test_log_probability_(self):
     """Log-probability of action 0 under three different Categoricals."""
     # Each row is one distribution over two actions.
     probs = torch.Tensor([[0.25, 0.75],
                           [0.5, 0.5],
                           [0.75, 0.25]])
     dist = td.categorical.Categorical(probs)
     # Action index 0 for every row of the batch.
     acts = torch.Tensor([0, 0, 0])
     # log(0.25), log(0.5), log(0.75)
     want = torch.Tensor([-1.38629436, -0.6931471, -0.287682])
     got = compute_log_probability(dist, acts)
     np.testing.assert_array_almost_equal(want, got)
Example #2
0
 def test_log_probability(self):
     """Log-probability of action 1 under a Categorical expanded to (2, 3)."""
     base = td.categorical.Categorical(torch.Tensor([0.25, 0.75]))
     batched = base.expand([2, 3])
     # Action index 1 at every batch position.
     acts = torch.ones(2, 3)
     # log(0.75) at every position of the (2, 3) batch.
     want = torch.full((2, 3), -0.287682)
     got = compute_log_probability(batched, acts)
     np.testing.assert_array_almost_equal(want, got)
Example #3
0
    def _train_step(self, time_step: TimeStep, state: SarsaState):
        """One SARSA-style training step producing action, new state and info.

        Evaluates the critics on the previous (observation, action) pair,
        computes the actor loss from the gradient of the current critic value
        with respect to the current action, and evaluates the target critics
        on the current (observation, action) pair.

        Args:
            time_step (TimeStep): current time step; ``observation``,
                ``prev_action`` and ``step_type`` are read.
            state (SarsaState): state carried over from the previous step.

        Returns:
            AlgStep: ``(action, SarsaState, SarsaInfo)`` for this step.
        """
        # Mask that is False on the first step of an episode.
        not_first_step = time_step.step_type != StepType.FIRST
        # Critic values for the previous (observation, action) pair; these are
        # what gets reported as ``critics`` in the returned info.
        prev_critics, critic_states = self._critic_networks(
            (state.prev_observation, time_step.prev_action), state.critics)

        # NOTE(review): presumably keeps the old critic state (state.critics)
        # where not_first_step is False, since prev_observation is not a valid
        # predecessor on the first step -- confirm the argument order and
        # semantics of reset_state_if_necessary.
        critic_states = common.reset_state_if_necessary(
            state.critics, critic_states, not_first_step)

        action_distribution, action, actor_state, noise_state = self._get_action(
            self._actor_network, time_step, state)

        # Evaluate the critics on the current (observation, action) and take
        # the minimum over dim 1 (presumably over critic replicas -- TODO
        # confirm).
        critics, _ = self._critic_networks((time_step.observation, action),
                                           critic_states)
        critic = critics.min(dim=1)[0]
        # dQ/da: gradient of the summed critic value w.r.t. the action.
        dqda = nest_utils.grad(action, critic.sum())

        def actor_loss_fn(dqda, action):
            # Optionally clip dQ/da element-wise to
            # [-self._dqda_clipping, self._dqda_clipping].
            if self._dqda_clipping:
                dqda = dqda.clamp(-self._dqda_clipping, self._dqda_clipping)
            # The target (dqda + action) is detached, so this regresses
            # ``action`` toward a point shifted by dqda. Assuming
            # element_wise_squared_loss is the element-wise squared difference,
            # the gradient of this loss w.r.t. action is -dqda.
            loss = 0.5 * losses.element_wise_squared_loss(
                (dqda + action).detach(), action)
            # Sum over every dimension except dim 0 (presumably the batch).
            loss = loss.sum(list(range(1, loss.ndim)))
            return loss

        # Actor loss per action component, then summed across the nest.
        actor_loss = nest_map(actor_loss_fn, dqda, action)
        actor_loss = math_ops.add_n(alf.nest.flatten(actor_loss))

        # Log-probability of the action, used as a negative-entropy estimate;
        # only computed when self._log_alpha is set, otherwise left as ().
        neg_entropy = ()
        if self._log_alpha is not None:
            neg_entropy = dist_utils.compute_log_probability(
                action_distribution, action)

        target_critics, target_critic_states = self._target_critic_networks(
            (time_step.observation, action), state.target_critics)

        info = SarsaInfo(action_distribution=action_distribution,
                         actor_loss=actor_loss,
                         critics=prev_critics,
                         neg_entropy=neg_entropy,
                         target_critics=target_critics.min(dim=1)[0])

        # The current observation becomes ``prev_observation`` for the next
        # step.
        rl_state = SarsaState(noise=noise_state,
                              prev_observation=time_step.observation,
                              prev_step_type=time_step.step_type,
                              actor=actor_state,
                              critics=critic_states,
                              target_critics=target_critic_states)

        return AlgStep(action, rl_state, info)
Example #4
0
    def _predict_skill_loss(self, observation, prev_action, prev_skill, steps,
                            state):
        # steps -> {1,2,3}
        if self._skill_type == "action":
            subtrajectory = (state.first_observation, prev_action)
        elif self._skill_type == "action_difference":
            action_difference = prev_action - state.subtrajectory[:, 1, :]
            subtrajectory = (state.first_observation, action_difference)
        elif self._skill_type == "state_action":
            subtrajectory = (observation, prev_action)
        elif self._skill_type == "state":
            subtrajectory = observation
        elif self._skill_type == "state_difference":
            subtrajectory = observation - state.untrans_observation
        elif "concatenation" in self._skill_type:
            subtrajectory = alf.nest.map_structure(
                lambda traj: traj.reshape(observation.shape[0], -1),
                state.subtrajectory)
            if is_action_skill(self._skill_type):
                subtrajectory = (state.first_observation,
                                 subtrajectory.prev_action)
            else:
                subtrajectory = subtrajectory.observation

        if self._skill_encoder is not None:
            steps = self._num_steps_per_skill - steps
            if not isinstance(subtrajectory, tuple):
                subtrajectory = (subtrajectory, )
            subtrajectory = subtrajectory + (steps, )
            with torch.no_grad():
                prev_skill, _ = self._skill_encoder((prev_skill, steps))

        if isinstance(self._skill_discriminator, EncodingNetwork):
            pred_skill, _ = self._skill_discriminator(subtrajectory)
            loss = torch.sum(losses.element_wise_squared_loss(
                prev_skill, pred_skill),
                             dim=-1)
        else:
            pred_skill_dist, _ = self._skill_discriminator(subtrajectory)
            loss = -dist_utils.compute_log_probability(pred_skill_dist,
                                                       prev_skill)
        return loss
Example #5
0
    def train_step(self, exp: Experience, state: SacState):
        """Compute SAC actor, critic and alpha training terms for one step.

        Args:
            exp (Experience): experience batch; its (detached) observation is
                used to predict the action.
            state (SacState): state from the previous step; ``state.action``,
                ``state.actor`` and ``state.critic`` are read.

        Returns:
            AlgStep: ``(action, SacState, SacInfo)``. The info holds the action
            distribution and the actor / critic / alpha loss information.
        """
        # We detach exp.observation here so that in the case that exp.observation
        # is calculated by some other trainable module, the training of that
        # module will not be affected by the gradient back-propagated from the
        # actor. However, the gradient from critic will still affect the training
        # of that module.
        (action_distribution, action, critics,
         action_state) = self._predict_action(common.detach(exp.observation),
                                              state=state.action)

        # Log-probability of each sampled action component under its own
        # distribution.
        log_pi = nest.map_structure(lambda dist, a: dist.log_prob(a),
                                    action_distribution, action)

        if self._act_type == ActionType.Mixed:
            # For mixed type, add log_pi separately
            # (parts 0 and 1 are each summed on their own and re-wrapped in the
            # action spec's container type).
            log_pi = type(self._action_spec)(
                (sum(nest.flatten(log_pi[0])), sum(nest.flatten(log_pi[1]))))
        else:
            # Otherwise all components collapse into a single log-probability.
            log_pi = sum(nest.flatten(log_pi))

        if self._prior_actor is not None:
            # Subtract the prior's log-probability: log_pi becomes
            # log(pi(a) / prior(a)).
            # NOTE(review): in the Mixed branch log_pi is a container built from
            # type(self._action_spec); confirm that subtraction is supported
            # for that type.
            prior_step = self._prior_actor.train_step(exp, ())
            log_prior = dist_utils.compute_log_probability(
                prior_step.output, action)
            log_pi = log_pi - log_prior

        actor_state, actor_loss = self._actor_train_step(
            exp, state.actor, action, critics, log_pi, action_distribution)
        critic_state, critic_info = self._critic_train_step(
            exp, state.critic, action, log_pi, action_distribution)
        alpha_loss = self._alpha_train_step(log_pi)

        state = SacState(action=action_state,
                         actor=actor_state,
                         critic=critic_state)
        info = SacInfo(action_distribution=action_distribution,
                       actor=actor_loss,
                       critic=critic_info,
                       alpha=alpha_loss)
        return AlgStep(action, state, info)
Example #6
0
 def _pg_loss(self, experience, train_info, advantages):
     """Policy-gradient loss ``-advantages * log pi(action)``.

     Args:
         experience: batch whose ``action`` field holds the taken actions.
         train_info: info whose ``action_distribution`` is the current policy.
         advantages: advantage estimates that weight the log-probability.

     Returns:
         Element-wise loss; it is not reduced here.
     """
     log_prob = dist_utils.compute_log_probability(
         train_info.action_distribution, experience.action)
     return -(advantages * log_prob)
Example #7
0
def action_importance_ratio(action_distribution, collect_action_distribution,
                            action, clipping_mode, scope,
                            importance_ratio_clipping, log_prob_clipping,
                            check_numerics, debug_summaries):
    """Compute importance ratios for PPO and V-trace style losses.

    The ratio is ``exp(log pi(a) - log mu(a))``. Here ``pi`` is
    ``action_distribution`` (target policy) and ``mu`` is
    ``collect_action_distribution`` (behavior policy). The behavior-policy
    log-probability is detached, so no gradient flows through it.

    The caller creates the summary name scope and passes it in as ``scope``.
    It is only entered when summaries are written.

    Args:
        action_distribution (nested Distribution): distribution over actions
            under the target policy.
        collect_action_distribution (nested Distribution): distribution over
            actions from the behavior policy, used to sample actions for the
            rollout.
        action (nested Tensor): possibly batched actions taken during the
            rollout.
        clipping_mode (str): mode for clipping the importance ratio.
            'double_sided': clips the importance ratio into
                ``[1 - importance_ratio_clipping, 1 + importance_ratio_clipping]``,
                which is used by PPOLoss.
            'capping': caps the importance ratio at
                ``min(1 + importance_ratio_clipping, importance_ratio)``,
                which is used by VTraceLoss, where c_bar or rho_bar =
                ``1 + importance_ratio_clipping``.
        scope (context manager): name scope created by the caller; entered
            around summary writing.
        importance_ratio_clipping (float): epsilon in the clipped surrogate
            PPO objective.
        log_prob_clipping (float): if > 0, the target-policy log-probability
            is clamped to ``[-log_prob_clipping, log_prob_clipping]`` to
            prevent inf / NaN values.
        check_numerics (bool): if True, assert that the log-probability and
            the importance ratio are finite. For debugging only.
        debug_summaries (bool): if True, write summary metrics.

    Returns:
        tuple: ``(importance_ratio, importance_ratio_clipped)`` tensors.

    Raises:
        ValueError: if ``clipping_mode`` is neither 'double_sided' nor
            'capping'.
    """
    # Validate up front so a bad configuration fails before any computation.
    if clipping_mode not in ('double_sided', 'capping'):
        raise ValueError('Unsupported clipping mode: %s' % (clipping_mode, ))

    # Behavior-policy log-probs are treated as constants (no gradient).
    sample_action_log_probs = dist_utils.compute_log_probability(
        collect_action_distribution, action).detach()

    action_log_prob = dist_utils.compute_log_probability(
        action_distribution, action)
    if log_prob_clipping > 0.0:
        action_log_prob = action_log_prob.clamp(-log_prob_clipping,
                                                log_prob_clipping)
    if check_numerics:
        assert torch.all(torch.isfinite(action_log_prob))

    # Prepare both clipped and unclipped importance ratios. The ratio
    # pi / mu is computed in log space.
    importance_ratio = (action_log_prob - sample_action_log_probs).exp()
    if check_numerics:
        assert torch.all(torch.isfinite(importance_ratio))

    if clipping_mode == 'double_sided':
        importance_ratio_clipped = importance_ratio.clamp(
            1 - importance_ratio_clipping, 1 + importance_ratio_clipping)
    else:  # 'capping'
        # Element-wise min with a Python scalar; this avoids allocating a new
        # tensor on every call.
        importance_ratio_clipped = importance_ratio.clamp(
            max=1 + importance_ratio_clipping)

    if debug_summaries and alf.summary.should_record_summaries():
        with scope:
            if importance_ratio_clipping > 0.0:
                clip_fraction = (torch.abs(importance_ratio - 1.0) >
                                 importance_ratio_clipping).to(
                                     torch.float32).mean()
                alf.summary.scalar('clip_fraction', clip_fraction)

            alf.summary.histogram('action_log_prob', action_log_prob)
            alf.summary.histogram('action_log_prob_sample',
                                  sample_action_log_probs)
            alf.summary.histogram('importance_ratio', importance_ratio)
            alf.summary.scalar('importance_ratio_mean',
                               importance_ratio.mean())
            alf.summary.histogram('importance_ratio_clipped',
                                  importance_ratio_clipped)

    return importance_ratio, importance_ratio_clipped