Example #1
    def _actor_train_step(self, exp: Experience, state, action, critics,
                          log_pi, action_distribution):
        neg_entropy = sum(nest.flatten(log_pi))

        if self._act_type == ActionType.Discrete:
            # Pure discrete case doesn't need to learn an actor network
            return (), LossInfo(extra=SacActorInfo(neg_entropy=neg_entropy))

        if self._act_type == ActionType.Continuous:
            critics, critics_state = self._compute_critics(
                self._critic_networks, exp.observation, action, state)
            if critics.ndim == 3:
                # Multidimensional reward: [B, num_critic_replicas, reward_dim]
                if self._reward_weights is None:
                    critics = critics.sum(dim=2)
                else:
                    critics = torch.tensordot(critics,
                                              self._reward_weights,
                                              dims=1)

            target_q_value = critics.min(dim=1)[0]
            continuous_log_pi = log_pi
            cont_alpha = torch.exp(self._log_alpha).detach()
        else:
            # use the critics computed during action prediction for Mixed type
            critics_state = ()
            discrete_act_dist = action_distribution[0]
            discrete_entropy = discrete_act_dist.entropy()
            # critics is already after min over replicas
            weighted_q_value = torch.sum(discrete_act_dist.probs * critics,
                                         dim=-1)
            discrete_alpha = torch.exp(self._log_alpha[0]).detach()
            target_q_value = weighted_q_value + discrete_alpha * discrete_entropy
            action, continuous_log_pi = action[1], log_pi[1]
            cont_alpha = torch.exp(self._log_alpha[1]).detach()

        dqda = nest_utils.grad(action, target_q_value.sum())

        def actor_loss_fn(dqda, action):
            if self._dqda_clipping:
                dqda = torch.clamp(dqda, -self._dqda_clipping,
                                   self._dqda_clipping)
            loss = 0.5 * losses.element_wise_squared_loss(
                (dqda + action).detach(), action)
            return loss.sum(list(range(1, loss.ndim)))

        actor_loss = nest.map_structure(actor_loss_fn, dqda, action)
        actor_loss = math_ops.add_n(nest.flatten(actor_loss))
        actor_info = LossInfo(loss=actor_loss + cont_alpha * continuous_log_pi,
                              extra=SacActorInfo(actor_loss=actor_loss,
                                                 neg_entropy=neg_entropy))
        return critics_state, actor_info
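The actor update above relies on a surrogate-loss trick: the gradient of 0.5 * ((dqda + action).detach() - action)**2 with respect to the action is exactly -dqda, so minimizing this loss performs gradient ascent on Q through the actor network. Below is a minimal standalone sketch of that identity in plain PyTorch (not part of the library; the quadratic critic is a toy stand-in):

# Sketch of the surrogate-loss trick: action.grad ends up equal to -dqda,
# so a gradient-descent step on the surrogate moves the action along dQ/da.
import torch

action = torch.tensor([0.3, -0.7], requires_grad=True)  # stands in for the actor output
q_value = -(action ** 2).sum()                           # toy critic: Q is maximized at a = 0
dqda = torch.autograd.grad(q_value, action)[0]           # dQ/da = -2 * action

surrogate = 0.5 * ((dqda + action).detach() - action) ** 2
surrogate.sum().backward()
print(action.grad)   # equals -dqda: stepping against this gradient increases Q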
Example #2
    def _train_step(self, time_step: TimeStep, state: SarsaState):
        not_first_step = time_step.step_type != StepType.FIRST
        prev_critics, critic_states = self._critic_networks(
            (state.prev_observation, time_step.prev_action), state.critics)

        critic_states = common.reset_state_if_necessary(
            state.critics, critic_states, not_first_step)

        action_distribution, action, actor_state, noise_state = self._get_action(
            self._actor_network, time_step, state)

        critics, _ = self._critic_networks((time_step.observation, action),
                                           critic_states)
        critic = critics.min(dim=1)[0]
        dqda = nest_utils.grad(action, critic.sum())

        def actor_loss_fn(dqda, action):
            if self._dqda_clipping:
                dqda = dqda.clamp(-self._dqda_clipping, self._dqda_clipping)
            loss = 0.5 * losses.element_wise_squared_loss(
                (dqda + action).detach(), action)
            loss = loss.sum(list(range(1, loss.ndim)))
            return loss

        actor_loss = nest_map(actor_loss_fn, dqda, action)
        actor_loss = math_ops.add_n(alf.nest.flatten(actor_loss))

        neg_entropy = ()
        if self._log_alpha is not None:
            neg_entropy = dist_utils.compute_log_probability(
                action_distribution, action)

        target_critics, target_critic_states = self._target_critic_networks(
            (time_step.observation, action), state.target_critics)

        info = SarsaInfo(action_distribution=action_distribution,
                         actor_loss=actor_loss,
                         critics=prev_critics,
                         neg_entropy=neg_entropy,
                         target_critics=target_critics.min(dim=1)[0])

        rl_state = SarsaState(noise=noise_state,
                              prev_observation=time_step.observation,
                              prev_step_type=time_step.step_type,
                              actor=actor_state,
                              critics=critic_states,
                              target_critics=target_critic_states)

        return AlgStep(action, rl_state, info)
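Both the online and the target critic outputs above are reduced with an element-wise minimum over replicas (critics.min(dim=1)[0]) to get a conservative value estimate, in the spirit of clipped double-Q learning. A minimal sketch of that reduction with made-up values, independent of the library:

# "Min over critic replicas": each replica gives its own Q estimate; stacking
# them yields shape [B, num_critic_replicas], and the element-wise minimum is
# used as the conservative Q value.
import torch

critics = torch.tensor([[1.0, 2.0],
                        [3.0, 0.5]])        # [batch=2, num_critic_replicas=2]
conservative_q = critics.min(dim=1)[0]      # tensor([1.0000, 0.5000]), shape [2]
print(conservative_q)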
Example #3
    def _actor_train_step(self, exp: Experience, state: DdpgActorState):
        action, actor_state = self._actor_network(exp.observation,
                                                  state=state.actor)

        q_values, critic_states = self._critic_networks(
            (exp.observation, action), state=state.critics)
        if q_values.ndim == 3:
            # Multidimensional reward: [B, num_critic_replicas, reward_dim]
            if self._reward_weights is None:
                q_values = q_values.sum(dim=2)
            else:
                q_values = torch.tensordot(q_values,
                                           self._reward_weights,
                                           dims=1)

        if self._num_critic_replicas > 1:
            q_value = q_values.min(dim=1)[0]
        else:
            q_value = q_values.squeeze(dim=1)

        dqda = nest_utils.grad(action, q_value.sum())

        def actor_loss_fn(dqda, action):
            if self._dqda_clipping:
                dqda = torch.clamp(dqda, -self._dqda_clipping,
                                   self._dqda_clipping)
            loss = 0.5 * losses.element_wise_squared_loss(
                (dqda + action).detach(), action)
            if self._action_l2 > 0:
                assert action.requires_grad
                loss += self._action_l2 * (action**2)
            loss = loss.sum(list(range(1, loss.ndim)))
            return loss

        actor_loss = nest.map_structure(actor_loss_fn, dqda, action)
        state = DdpgActorState(actor=actor_state, critics=critic_states)
        info = LossInfo(loss=sum(nest.flatten(actor_loss)), extra=actor_loss)
        return AlgStep(output=action, state=state, info=info)
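For a multidimensional reward, the critics in the step above return a tensor of shape [B, num_critic_replicas, reward_dim], and torch.tensordot with a weight vector of length reward_dim contracts the last axis, leaving one scalar Q value per replica. A minimal sketch with hypothetical weights, independent of the library:

# Reward weighting via tensordot: contracts the reward dimension, producing
# a [B, num_critic_replicas] tensor of scalar Q values.
import torch

B, replicas, reward_dim = 4, 2, 3
q_values = torch.randn(B, replicas, reward_dim)
reward_weights = torch.tensor([1.0, 0.5, 0.1])   # hypothetical per-dimension weights

weighted = torch.tensordot(q_values, reward_weights, dims=1)
print(weighted.shape)    # torch.Size([4, 2])
assert torch.allclose(weighted, (q_values * reward_weights).sum(dim=-1))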