def _actor_train_step(self, exp: Experience, state, action, critics,
                      log_pi, action_distribution):
    neg_entropy = sum(nest.flatten(log_pi))

    if self._act_type == ActionType.Discrete:
        # Pure discrete case doesn't need to learn an actor network
        return (), LossInfo(extra=SacActorInfo(neg_entropy=neg_entropy))

    if self._act_type == ActionType.Continuous:
        critics, critics_state = self._compute_critics(
            self._critic_networks, exp.observation, action, state)
        if critics.ndim == 3:
            # Multidimensional reward: [B, num_critic_replicas, reward_dim]
            if self._reward_weights is None:
                critics = critics.sum(dim=2)
            else:
                critics = torch.tensordot(
                    critics, self._reward_weights, dims=1)
        target_q_value = critics.min(dim=1)[0]
        continuous_log_pi = log_pi
        cont_alpha = torch.exp(self._log_alpha).detach()
    else:
        # Mixed type: reuse the critics computed during action prediction
        critics_state = ()
        discrete_act_dist = action_distribution[0]
        discrete_entropy = discrete_act_dist.entropy()
        # ``critics`` is already the min over critic replicas
        weighted_q_value = torch.sum(
            discrete_act_dist.probs * critics, dim=-1)
        discrete_alpha = torch.exp(self._log_alpha[0]).detach()
        target_q_value = weighted_q_value + discrete_alpha * discrete_entropy
        action, continuous_log_pi = action[1], log_pi[1]
        cont_alpha = torch.exp(self._log_alpha[1]).detach()

    dqda = nest_utils.grad(action, target_q_value.sum())

    def actor_loss_fn(dqda, action):
        if self._dqda_clipping:
            dqda = torch.clamp(dqda, -self._dqda_clipping,
                               self._dqda_clipping)
        loss = 0.5 * losses.element_wise_squared_loss(
            (dqda + action).detach(), action)
        return loss.sum(list(range(1, loss.ndim)))

    actor_loss = nest.map_structure(actor_loss_fn, dqda, action)
    actor_loss = math_ops.add_n(nest.flatten(actor_loss))
    actor_info = LossInfo(
        loss=actor_loss + cont_alpha * continuous_log_pi,
        extra=SacActorInfo(actor_loss=actor_loss, neg_entropy=neg_entropy))
    return critics_state, actor_info
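# A minimal, self-contained sketch (assumptions: plain PyTorch only; the
# stand-in critic ``q`` and tensor names below are hypothetical, not part of
# the framework) of why the squared loss in ``actor_loss_fn`` reproduces the
# deterministic policy gradient: for
# loss = 0.5 * ((dqda + a).detach() - a)**2, the gradient w.r.t. ``a`` is
# -dqda, so minimizing the loss moves ``a`` along dQ/da, i.e. it ascends Q.
import torch

a = torch.tensor([0.3, -0.7], requires_grad=True)
q = (a * torch.tensor([1.0, 2.0])).sum()     # stand-in critic: Q(a) = w . a
dqda = torch.autograd.grad(q, a)[0]          # here dqda == w
loss = 0.5 * ((dqda + a).detach() - a).pow(2).sum()
loss.backward()
# Gradient descent on ``loss`` is gradient ascent on Q.
assert torch.allclose(a.grad, -dqda)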
def _train_step(self, time_step: TimeStep, state: SarsaState):
    not_first_step = time_step.step_type != StepType.FIRST

    # Critics for the previous (observation, action) pair; their state is
    # reset at episode boundaries.
    prev_critics, critic_states = self._critic_networks(
        (state.prev_observation, time_step.prev_action), state.critics)
    critic_states = common.reset_state_if_necessary(
        state.critics, critic_states, not_first_step)

    action_distribution, action, actor_state, noise_state = self._get_action(
        self._actor_network, time_step, state)

    critics, _ = self._critic_networks((time_step.observation, action),
                                       critic_states)
    # Take the min over critic replicas (clipped double-Q).
    critic = critics.min(dim=1)[0]
    dqda = nest_utils.grad(action, critic.sum())

    def actor_loss_fn(dqda, action):
        if self._dqda_clipping:
            dqda = dqda.clamp(-self._dqda_clipping, self._dqda_clipping)
        loss = 0.5 * losses.element_wise_squared_loss(
            (dqda + action).detach(), action)
        loss = loss.sum(list(range(1, loss.ndim)))
        return loss

    actor_loss = nest_map(actor_loss_fn, dqda, action)
    actor_loss = math_ops.add_n(alf.nest.flatten(actor_loss))

    neg_entropy = ()
    if self._log_alpha is not None:
        neg_entropy = dist_utils.compute_log_probability(
            action_distribution, action)

    target_critics, target_critic_states = self._target_critic_networks(
        (time_step.observation, action), state.target_critics)

    info = SarsaInfo(
        action_distribution=action_distribution,
        actor_loss=actor_loss,
        critics=prev_critics,
        neg_entropy=neg_entropy,
        target_critics=target_critics.min(dim=1)[0])

    rl_state = SarsaState(
        noise=noise_state,
        prev_observation=time_step.observation,
        prev_step_type=time_step.step_type,
        actor=actor_state,
        critics=critic_states,
        target_critics=target_critic_states)

    return AlgStep(action, rl_state, info)
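# A minimal sketch (plain PyTorch; the shapes and values are assumptions) of
# the min-over-replicas reduction used above, the TD3/SAC-style "clipped
# double Q" trick: with critics of shape [B, num_critic_replicas],
# ``min(dim=1)[0]`` keeps the most pessimistic estimate per batch element,
# which counteracts the overestimation bias of bootstrapped Q-targets.
import torch

critics = torch.tensor([[1.0, 0.5],
                        [2.0, 3.0]])   # [B=2, num_critic_replicas=2]
critic = critics.min(dim=1)[0]         # -> tensor([0.5, 2.0])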
def _actor_train_step(self, exp: Experience, state: DdpgActorState):
    action, actor_state = self._actor_network(
        exp.observation, state=state.actor)
    q_values, critic_states = self._critic_networks(
        (exp.observation, action), state=state.critics)
    if q_values.ndim == 3:
        # Multidimensional reward: [B, num_critic_replicas, reward_dim]
        if self._reward_weights is None:
            q_values = q_values.sum(dim=2)
        else:
            q_values = torch.tensordot(q_values, self._reward_weights,
                                       dims=1)
    if self._num_critic_replicas > 1:
        q_value = q_values.min(dim=1)[0]
    else:
        q_value = q_values.squeeze(dim=1)

    dqda = nest_utils.grad(action, q_value.sum())

    def actor_loss_fn(dqda, action):
        if self._dqda_clipping:
            dqda = torch.clamp(dqda, -self._dqda_clipping,
                               self._dqda_clipping)
        loss = 0.5 * losses.element_wise_squared_loss(
            (dqda + action).detach(), action)
        if self._action_l2 > 0:
            assert action.requires_grad
            loss += self._action_l2 * (action**2)
        loss = loss.sum(list(range(1, loss.ndim)))
        return loss

    actor_loss = nest.map_structure(actor_loss_fn, dqda, action)
    state = DdpgActorState(actor=actor_state, critics=critic_states)
    info = LossInfo(loss=sum(nest.flatten(actor_loss)), extra=actor_loss)
    return AlgStep(output=action, state=state, info=info)
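# A minimal sketch (plain PyTorch; the shapes and weight values are
# assumptions) of the multidimensional-reward reduction above:
# ``torch.tensordot(q_values, reward_weights, dims=1)`` contracts the last
# (reward) dimension, turning [B, num_critic_replicas, reward_dim] Q-values
# into a scalar-reward view of shape [B, num_critic_replicas].
import torch

B, replicas, reward_dim = 4, 2, 3
q_values = torch.randn(B, replicas, reward_dim)
reward_weights = torch.tensor([1.0, 0.5, 0.1])   # one weight per reward dim
scalar_q = torch.tensordot(q_values, reward_weights, dims=1)
assert scalar_q.shape == (B, replicas)
# Equivalent to (q_values * reward_weights).sum(dim=-1).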