def _calc_critic_loss(self, experience, train_info: SacInfo):
    """Calculate the critic loss summed over all replicas, and the replay
    priority when importance weights are present."""
    critic_info = train_info.critic
    critic_losses = []
    for i, l in enumerate(self._critic_losses):
        critic_losses.append(
            l(experience=experience,
              value=critic_info.critics[:, :, i, ...],
              target_value=critic_info.target_critic).loss)
    critic_loss = math_ops.add_n(critic_losses)

    if (experience.batch_info != ()
            and experience.batch_info.importance_weights != ()):
        valid_masks = (experience.step_type != StepType.LAST).to(
            torch.float32)
        valid_n = torch.clamp(valid_masks.sum(dim=0), min=1.0)
        priority = ((critic_loss * valid_masks).sum(dim=0) / valid_n).sqrt()
    else:
        priority = ()

    return LossInfo(
        loss=critic_loss,
        priority=priority,
        extra=critic_loss / float(self._num_critic_replicas))
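# A standalone sketch (not ALF code; toy shapes and values assumed) of how the
# replay priority above is formed: per-step critic losses of shape [T, B] are
# masked where step_type == LAST, averaged over the remaining steps of each
# sample, and the square root is used as the priority.
import torch

T, B = 3, 2
critic_loss = torch.rand(T, B)                          # per-step TD loss
is_last = torch.tensor([[0., 0.], [0., 1.], [1., 0.]])  # toy step_type == LAST
valid_masks = 1.0 - is_last                             # 0 where LAST
valid_n = torch.clamp(valid_masks.sum(dim=0), min=1.0)
priority = ((critic_loss * valid_masks).sum(dim=0) / valid_n).sqrt()  # [B]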
def calc_loss(self, experience, train_info: DdpgInfo):
    critic_losses = [None] * self._num_critic_replicas
    for i in range(self._num_critic_replicas):
        critic_losses[i] = self._critic_losses[i](
            experience=experience,
            value=train_info.critic.q_values[:, :, i, ...],
            target_value=train_info.critic.target_q_values).loss
    critic_loss = math_ops.add_n(critic_losses)

    if (experience.batch_info != ()
            and experience.batch_info.importance_weights != ()):
        valid_masks = (experience.step_type != StepType.LAST).to(
            torch.float32)
        valid_n = torch.clamp(valid_masks.sum(dim=0), min=1.0)
        priority = ((critic_loss * valid_masks).sum(dim=0) / valid_n).sqrt()
    else:
        priority = ()

    actor_loss = train_info.actor_loss

    return LossInfo(
        loss=critic_loss + actor_loss.loss,
        priority=priority,
        extra=DdpgLossInfo(critic=critic_loss, actor=actor_loss.extra))
def _benchmark(pnet, name):
    t0 = time.time()
    outputs = []
    for _ in range(1000):
        embedding = input_spec.randn(outer_dims=(batch_size, ))
        output, _ = pnet(embedding)
        outputs.append(output)
    o = math_ops.add_n(outputs).sum()
    logging.info("%s time=%s %s" % (name, time.time() - t0, float(o)))
    self.assertEqual(output.shape, (batch_size, replicas, 1))
    self.assertEqual(pnet.output_spec.shape, (replicas, 1))
def _actor_train_step(self, exp: Experience, state, action, critics, log_pi,
                      action_distribution):
    """Compute the actor loss (and entropy terms) for one training step."""
    neg_entropy = sum(nest.flatten(log_pi))

    if self._act_type == ActionType.Discrete:
        # The pure discrete case doesn't need to learn an actor network.
        return (), LossInfo(extra=SacActorInfo(neg_entropy=neg_entropy))

    if self._act_type == ActionType.Continuous:
        critics, critics_state = self._compute_critics(
            self._critic_networks, exp.observation, action, state)
        if critics.ndim == 3:
            # Multidimensional reward: [B, num_critic_replicas, reward_dim]
            if self._reward_weights is None:
                critics = critics.sum(dim=2)
            else:
                critics = torch.tensordot(
                    critics, self._reward_weights, dims=1)
        target_q_value = critics.min(dim=1)[0]
        continuous_log_pi = log_pi
        cont_alpha = torch.exp(self._log_alpha).detach()
    else:
        # For the Mixed type, use the critics computed during action
        # prediction.
        critics_state = ()
        discrete_act_dist = action_distribution[0]
        discrete_entropy = discrete_act_dist.entropy()
        # ``critics`` has already been reduced by min over replicas.
        weighted_q_value = torch.sum(
            discrete_act_dist.probs * critics, dim=-1)
        discrete_alpha = torch.exp(self._log_alpha[0]).detach()
        target_q_value = weighted_q_value + discrete_alpha * discrete_entropy
        action, continuous_log_pi = action[1], log_pi[1]
        cont_alpha = torch.exp(self._log_alpha[1]).detach()

    dqda = nest_utils.grad(action, target_q_value.sum())

    def actor_loss_fn(dqda, action):
        if self._dqda_clipping:
            dqda = torch.clamp(dqda, -self._dqda_clipping,
                               self._dqda_clipping)
        loss = 0.5 * losses.element_wise_squared_loss(
            (dqda + action).detach(), action)
        return loss.sum(list(range(1, loss.ndim)))

    actor_loss = nest.map_structure(actor_loss_fn, dqda, action)
    actor_loss = math_ops.add_n(nest.flatten(actor_loss))
    actor_info = LossInfo(
        loss=actor_loss + cont_alpha * continuous_log_pi,
        extra=SacActorInfo(actor_loss=actor_loss, neg_entropy=neg_entropy))
    return critics_state, actor_info
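# A standalone sketch (plain PyTorch, not part of the algorithm code above) of
# the surrogate actor loss used by ``actor_loss_fn``: minimizing
# 0.5 * ((dqda + action).detach() - action) ** 2 yields
# d(loss)/d(action) = -dqda, i.e. a gradient-ascent step on the critic value
# along dQ/da, optionally with ``dqda`` clipped beforehand.
import torch

action = torch.tensor([0.3, -0.7], requires_grad=True)
dqda = torch.tensor([1.5, -2.0])  # stand-in for dQ/da computed elsewhere

loss = 0.5 * ((dqda + action).detach() - action)**2
loss.sum().backward()
print(action.grad)  # tensor([-1.5, 2.0]) == -dqda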
def decode_step(self, latent_vector, observations):
    """Calculate decoding loss."""
    decoders = flatten(self._decoders)
    observations = flatten(observations)
    decoder_losses = [
        decoder.train_step((latent_vector, obs)).info
        for decoder, obs in zip(decoders, observations)
    ]
    loss = math_ops.add_n(
        [decoder_loss.loss for decoder_loss in decoder_losses])
    decoder_losses = alf.nest.pack_sequence_as(self._decoders,
                                               decoder_losses)
    return LossInfo(loss=loss, extra=decoder_losses)
def _train_step(self, time_step: TimeStep, state: SarsaState):
    """One training step: select an action, compute the actor loss and the
    critic values needed by ``calc_loss``, and build the next ``SarsaState``."""
    not_first_step = time_step.step_type != StepType.FIRST
    prev_critics, critic_states = self._critic_networks(
        (state.prev_observation, time_step.prev_action), state.critics)
    critic_states = common.reset_state_if_necessary(
        state.critics, critic_states, not_first_step)
    action_distribution, action, actor_state, noise_state = self._get_action(
        self._actor_network, time_step, state)
    critics, _ = self._critic_networks((time_step.observation, action),
                                       critic_states)
    critic = critics.min(dim=1)[0]
    dqda = nest_utils.grad(action, critic.sum())

    def actor_loss_fn(dqda, action):
        if self._dqda_clipping:
            dqda = dqda.clamp(-self._dqda_clipping, self._dqda_clipping)
        loss = 0.5 * losses.element_wise_squared_loss(
            (dqda + action).detach(), action)
        loss = loss.sum(list(range(1, loss.ndim)))
        return loss

    actor_loss = nest_map(actor_loss_fn, dqda, action)
    actor_loss = math_ops.add_n(alf.nest.flatten(actor_loss))

    neg_entropy = ()
    if self._log_alpha is not None:
        neg_entropy = dist_utils.compute_log_probability(
            action_distribution, action)

    target_critics, target_critic_states = self._target_critic_networks(
        (time_step.observation, action), state.target_critics)

    info = SarsaInfo(
        action_distribution=action_distribution,
        actor_loss=actor_loss,
        critics=prev_critics,
        neg_entropy=neg_entropy,
        target_critics=target_critics.min(dim=1)[0])

    rl_state = SarsaState(
        noise=noise_state,
        prev_observation=time_step.observation,
        prev_step_type=time_step.step_type,
        actor=actor_state,
        critics=critic_states,
        target_critics=target_critic_states)

    return AlgStep(action, rl_state, info)
def _calc_critic_loss(self, experience, train_info: MdqInfo):
    """Calculate the critic loss against the KL-corrected target, together
    with the distillation loss for the advantage-form critic."""
    critic_info = train_info.critic
    # [t, B, n]
    critic_free_form = critic_info.critic_free_form
    # [t, B, n, action_dim]
    critic_adv_form = critic_info.critic_adv_form
    target_critic_free_form = critic_info.target_critic_free_form
    distill_target = critic_info.distill_target

    num_critic_replicas = critic_free_form.shape[2]
    alpha = torch.exp(self._log_alpha).detach()

    kl_wrt_prior = critic_info.kl_wrt_prior
    # [t, B, n, action_dim] -> [t, B]
    # Note that currently kl_wrt_prior is independent of the ensemble, so we
    # slice over the ensemble dimension by taking the first element; for the
    # action dimension, the first element is the full KL.
    kl_wrt_prior = kl_wrt_prior[..., 0, 0]

    # [t, B, n] -> [t, B]
    target_critic, min_target_ind = torch.min(target_critic_free_form, dim=2)
    # [t, B, n] -> [t, B]
    distill_target, _ = torch.min(distill_target, dim=2)
    target_critic_corrected = target_critic - alpha * kl_wrt_prior

    critic_losses = []
    for j in range(num_critic_replicas):
        critic_losses.append(self._critic_losses[j](
            experience=experience,
            value=critic_free_form[:, :, j],
            target_value=target_critic_corrected).loss)
    critic_loss = math_ops.add_n(critic_losses)

    distill_loss = (
        critic_adv_form[..., -1] - distill_target.unsqueeze(2).detach())**2
    # mean over replicas
    distill_loss = distill_loss.mean(dim=2)

    return LossInfo(
        loss=critic_loss,
        extra=critic_loss / len(critic_losses)), distill_loss
def calc_loss(self, experience, info: SarsaInfo):
    loss = info.actor_loss

    if self._log_alpha is not None:
        alpha = self._log_alpha.exp().detach()
        alpha_loss = self._log_alpha * (
            -info.neg_entropy - self._target_entropy).detach()
        loss = loss + alpha * info.neg_entropy + alpha_loss
    else:
        alpha_loss = ()

    # For SARSA, info.critics is actually the critics for the previous step,
    # and info.target_critics is the critics for the current step. So we
    # need to rearrange ``experience`` to match the requirement of
    # ``OneStepTDLoss``.
    step_type0 = experience.step_type[0]
    step_type0 = torch.where(step_type0 == StepType.LAST,
                             torch.tensor(StepType.MID), step_type0)
    step_type0 = torch.where(step_type0 == StepType.FIRST,
                             torch.tensor(StepType.LAST), step_type0)

    reward = experience.reward
    if self._use_entropy_reward:
        # Avoid modifying ``experience.reward`` in place.
        reward = reward - (self._log_alpha.exp() * info.neg_entropy).detach()

    shifted_experience = experience._replace(
        discount=tensor_utils.tensor_prepend_zero(experience.discount),
        reward=tensor_utils.tensor_prepend_zero(reward),
        step_type=tensor_utils.tensor_prepend(experience.step_type,
                                              step_type0))

    critic_losses = []
    for i in range(self._num_critic_replicas):
        critic = tensor_utils.tensor_extend_zero(info.critics[..., i])
        target_critic = tensor_utils.tensor_prepend_zero(info.target_critics)
        loss_info = self._critic_losses[i](shifted_experience, critic,
                                           target_critic)
        critic_losses.append(nest_map(lambda l: l[:-1], loss_info.loss))
    critic_loss = math_ops.add_n(critic_losses)

    not_first_step = (experience.step_type != StepType.FIRST).to(
        torch.float32)
    critic_loss = critic_loss * not_first_step

    if (experience.batch_info != ()
            and experience.batch_info.importance_weights != ()):
        valid_n = torch.clamp(not_first_step.sum(dim=0), min=1.0)
        priority = (critic_loss.sum(dim=0) / valid_n).sqrt()
    else:
        priority = ()

    # Put critic_loss into scalar_loss because ``loss`` will be masked by
    # ~is_last at train_complete(). The critic_loss here should be masked by
    # ~is_first instead, which is done above.
    scalar_loss = critic_loss.mean()

    if self._debug_summaries and alf.summary.should_record_summaries():
        with alf.summary.scope(self._name):
            if self._log_alpha is not None:
                alf.summary.scalar("alpha", alpha)

    return LossInfo(
        loss=loss,
        scalar_loss=scalar_loss,
        priority=priority,
        extra=SarsaLossInfo(
            actor=info.actor_loss,
            critic=critic_loss,
            alpha=alpha_loss,
            neg_entropy=info.neg_entropy))
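# A standalone sketch (plain PyTorch, not part of the algorithm code above) of
# the entropy-temperature update used in ``calc_loss``. With
# neg_entropy = E[log_pi], minimizing
# log_alpha * (-neg_entropy - target_entropy).detach() gives
# d(loss)/d(log_alpha) = entropy - target_entropy, so alpha grows when the
# policy entropy falls below the target and shrinks when it exceeds it.
import torch

log_alpha = torch.zeros((), requires_grad=True)
neg_entropy = torch.tensor(1.2)       # estimated E[log_pi]
target_entropy = torch.tensor(-2.0)

alpha_loss = log_alpha * (-neg_entropy - target_entropy).detach()
alpha_loss.backward()
print(log_alpha.grad)                 # 0.8: entropy (-1.2) is above the target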