def _train(self, BATCH):
    q = self.q_net(BATCH.obs, begin_mask=BATCH.begin_mask)  # [T, B, A]
    q_next = self.q_net.t(BATCH.obs_, begin_mask=BATCH.begin_mask)  # [T, B, A]
    q_eval = (q * BATCH.action).sum(-1, keepdim=True)  # [T, B, 1]
    q_target = n_step_return(BATCH.reward,
                             self.gamma,
                             BATCH.done,
                             q_next.max(-1, keepdim=True)[0],
                             BATCH.begin_mask,
                             nstep=self._n_step_value).detach()  # [T, B, 1]
    td_error = q_target - q_eval  # [T, B, 1]
    q_loss = (td_error.square() * BATCH.get('isw', 1.0)).mean()  # 1
    # CQL(H) regularizer: push down logsumexp_a Q(s, a) relative to Q(s, a_data)
    cql1_loss = (th.logsumexp(q, dim=-1, keepdim=True) - q_eval).mean()  # 1
    loss = q_loss + self._cql_weight * cql1_loss
    self.oplr.optimize(loss)
    return td_error, {
        'LEARNING_RATE/lr': self.oplr.lr,
        'LOSS/q_loss': q_loss,
        'LOSS/cql1_loss': cql1_loss,
        'LOSS/loss': loss,
        'Statistics/q_max': q_eval.max(),
        'Statistics/q_min': q_eval.min(),
        'Statistics/q_mean': q_eval.mean()
    }
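# The helpers used throughout this section (n_step_return, the `.t(...)` call on a
# network's target copy, `oplr.optimize`, the BATCH containers, etc.) are defined
# elsewhere. As a minimal sketch of what n_step_return is assumed to compute --
# an n-step bootstrapped TD target over [T, B, 1] tensors, ignoring the
# begin_mask handling of the real helper -- something along these lines:
import torch as th


def n_step_return_sketch(reward, gamma, done, next_value, nstep=1):
    """reward/done/next_value: [T, B, 1]; next_value[t] bootstraps V(s_{t+1}).
    Windows that run past the end of the trajectory are truncated."""
    T = reward.shape[0]
    target = th.empty_like(reward)
    for t in range(T):
        end = min(t + nstep, T) - 1          # last transition inside the window
        g = next_value[end]                  # bootstrap from V(s_{end+1})
        for k in range(end, t - 1, -1):      # fold rewards backwards towards t
            g = reward[k] + gamma * (1. - done[k]) * g
        target[t] = g
    return target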
def _train(self, BATCH):
    q = self.q_net(BATCH.obs,
                   begin_mask=BATCH.begin_mask).mean(0)  # [H, T, B, A] => [T, B, A]
    q_next = self.q_net.t(BATCH.obs_,
                          begin_mask=BATCH.begin_mask).mean(0)  # [H, T, B, A] => [T, B, A]
    # [T, B, A] * [T, B, A] => [T, B, 1]
    q_eval = (q * BATCH.action).sum(-1, keepdim=True)
    q_target = n_step_return(BATCH.reward,
                             self.gamma,
                             BATCH.done,
                             # [T, B, A] => [T, B, 1]
                             q_next.max(-1, keepdim=True)[0],
                             BATCH.begin_mask).detach()  # [T, B, 1]
    td_error = q_target - q_eval  # [T, B, 1]
    q_loss = (td_error.square() * BATCH.get('isw', 1.0)).mean()  # 1

    # mask_dist = td.Bernoulli(probs=self._probs)  # TODO:
    # mask = mask_dist.sample([batch_size]).T  # [H, B]

    self.oplr.optimize(q_loss)
    return td_error, {
        'LEARNING_RATE/lr': self.oplr.lr,
        'LOSS/loss': q_loss,
        'Statistics/q_max': q_eval.max(),
        'Statistics/q_min': q_eval.min(),
        'Statistics/q_mean': q_eval.mean()
    }
def _train(self, BATCH):
    q = self.q_net(BATCH.obs, begin_mask=BATCH.begin_mask)  # [T, B, A]
    next_q = self.q_net(BATCH.obs_, begin_mask=BATCH.begin_mask)  # [T, B, A]
    q_target = self.q_net.t(BATCH.obs_, begin_mask=BATCH.begin_mask)  # [T, B, A]
    q_eval = (q * BATCH.action).sum(-1, keepdim=True)  # [T, B, 1]
    next_max_action = next_q.argmax(-1)  # [T, B]
    next_max_action_one_hot = F.one_hot(next_max_action.squeeze(),
                                        self.a_dim).float()  # [T, B, A]
    q_target_next_max = (q_target * next_max_action_one_hot).sum(
        -1, keepdim=True)  # [T, B, 1]
    q_target = n_step_return(BATCH.reward,
                             self.gamma,
                             BATCH.done,
                             q_target_next_max,
                             BATCH.begin_mask).detach()  # [T, B, 1]
    td_error = q_target - q_eval  # [T, B, 1]
    q_loss = (td_error.square() * BATCH.get('isw', 1.0)).mean()  # 1
    self.oplr.optimize(q_loss)
    return td_error, {
        'LEARNING_RATE/lr': self.oplr.lr,
        'LOSS/loss': q_loss,
        'Statistics/q_max': q_eval.max(),
        'Statistics/q_min': q_eval.min(),
        'Statistics/q_mean': q_eval.mean()
    }
def _train(self, BATCH):
    if self.is_continuous:
        action_target = self.actor.t(BATCH.obs_,
                                     begin_mask=BATCH.begin_mask)  # [T, B, A]
        if self.use_target_action_noise:
            action_target = self.target_noised_action(action_target)  # [T, B, A]
    else:
        target_logits = self.actor.t(BATCH.obs_,
                                     begin_mask=BATCH.begin_mask)  # [T, B, A]
        target_cate_dist = td.Categorical(logits=target_logits)
        target_pi = target_cate_dist.sample()  # [T, B]
        action_target = F.one_hot(target_pi, self.a_dim).float()  # [T, B, A]
    q = self.critic(BATCH.obs, BATCH.action,
                    begin_mask=BATCH.begin_mask)  # [T, B, 1]
    q_target = self.critic.t(BATCH.obs_, action_target,
                             begin_mask=BATCH.begin_mask)  # [T, B, 1]
    dc_r = n_step_return(BATCH.reward,
                         self.gamma,
                         BATCH.done,
                         q_target,
                         BATCH.begin_mask).detach()  # [T, B, 1]
    td_error = dc_r - q  # [T, B, 1]
    q_loss = (td_error.square() * BATCH.get('isw', 1.0)).mean()  # 1
    self.critic_oplr.optimize(q_loss)

    if self.is_continuous:
        mu = self.actor(BATCH.obs, begin_mask=BATCH.begin_mask)  # [T, B, A]
    else:
        logits = self.actor(BATCH.obs, begin_mask=BATCH.begin_mask)  # [T, B, A]
        logp_all = logits.log_softmax(-1)  # [T, B, A]
        gumbel_noise = td.Gumbel(0, 1).sample(logp_all.shape)  # [T, B, A]
        _pi = ((logp_all + gumbel_noise) / self.discrete_tau).softmax(-1)  # [T, B, A]
        _pi_true_one_hot = F.one_hot(_pi.argmax(-1), self.a_dim).float()  # [T, B, A]
        _pi_diff = (_pi_true_one_hot - _pi).detach()  # [T, B, A]
        mu = _pi_diff + _pi  # [T, B, A]
    q_actor = self.critic(BATCH.obs, mu, begin_mask=BATCH.begin_mask)  # [T, B, 1]
    actor_loss = -q_actor.mean()  # 1
    self.actor_oplr.optimize(actor_loss)
    return td_error, {
        'LEARNING_RATE/actor_lr': self.actor_oplr.lr,
        'LEARNING_RATE/critic_lr': self.critic_oplr.lr,
        'LOSS/actor_loss': actor_loss,
        'LOSS/critic_loss': q_loss,
        'Statistics/q_min': q.min(),
        'Statistics/q_mean': q.mean(),
        'Statistics/q_max': q.max()
    }
def _train(self, BATCH):
    q = self.critic(BATCH.obs, BATCH.action,
                    begin_mask=BATCH.begin_mask)  # [T, B, 1]
    if self.is_continuous:
        next_mu, _ = self.actor(BATCH.obs_, begin_mask=BATCH.begin_mask)  # [T, B, *]
        max_q_next = self.critic(BATCH.obs_, next_mu,
                                 begin_mask=BATCH.begin_mask).detach()  # [T, B, 1]
    else:
        logits = self.actor(BATCH.obs_, begin_mask=BATCH.begin_mask)  # [T, B, *]
        max_a = logits.argmax(-1)  # [T, B]
        max_a_one_hot = F.one_hot(max_a, self.a_dim).float()  # [T, B, N]
        max_q_next = self.critic(BATCH.obs_, max_a_one_hot).detach()  # [T, B, 1]
    td_error = q - n_step_return(BATCH.reward,
                                 self.gamma,
                                 BATCH.done,
                                 max_q_next,
                                 BATCH.begin_mask).detach()  # [T, B, 1]
    critic_loss = (td_error.square() * BATCH.get('isw', 1.0)).mean()  # 1
    self.critic_oplr.optimize(critic_loss)

    if self.is_continuous:
        mu, log_std = self.actor(BATCH.obs, begin_mask=BATCH.begin_mask)  # [T, B, *]
        dist = td.Independent(td.Normal(mu, log_std.exp()), 1)
        log_prob = dist.log_prob(BATCH.action)  # [T, B]
        entropy = dist.entropy().mean()  # 1
    else:
        logits = self.actor(BATCH.obs, begin_mask=BATCH.begin_mask)  # [T, B, *]
        logp_all = logits.log_softmax(-1)  # [T, B, *]
        log_prob = (logp_all * BATCH.action).sum(-1)  # [T, B]
        entropy = -(logp_all.exp() * logp_all).sum(-1).mean()  # 1
    ratio = (log_prob - BATCH.log_prob).exp().detach()  # [T, B]
    actor_loss = -(ratio * log_prob * q.squeeze(-1).detach()).mean()  # [T, B] => 1
    self.actor_oplr.optimize(actor_loss)
    return td_error, {
        'LEARNING_RATE/actor_lr': self.actor_oplr.lr,
        'LEARNING_RATE/critic_lr': self.critic_oplr.lr,
        'LOSS/actor_loss': actor_loss,
        'LOSS/critic_loss': critic_loss,
        'Statistics/q_max': q.max(),
        'Statistics/q_min': q.min(),
        'Statistics/q_mean': q.mean(),
        'Statistics/ratio': ratio.mean(),
        'Statistics/entropy': entropy
    }
def _train(self, BATCH):
    q1 = self.critic(BATCH.obs, begin_mask=BATCH.begin_mask)  # [T, B, A]
    q2 = self.critic2(BATCH.obs, begin_mask=BATCH.begin_mask)  # [T, B, A]
    q1_eval = (q1 * BATCH.action).sum(-1, keepdim=True)  # [T, B, 1]
    q2_eval = (q2 * BATCH.action).sum(-1, keepdim=True)  # [T, B, 1]

    q1_log_probs = (q1 / (self.alpha + th.finfo().eps)).log_softmax(-1)  # [T, B, A]
    q1_entropy = -(q1_log_probs.exp() * q1_log_probs).sum(-1, keepdim=True).mean()  # 1

    q1_target = self.critic.t(BATCH.obs_, begin_mask=BATCH.begin_mask)  # [T, B, A]
    q2_target = self.critic2.t(BATCH.obs_, begin_mask=BATCH.begin_mask)  # [T, B, A]
    q1_target_max = q1_target.max(-1, keepdim=True)[0]  # [T, B, 1]
    q1_target_log_probs = (q1_target /
                           (self.alpha + th.finfo().eps)).log_softmax(-1)  # [T, B, A]
    q1_target_entropy = -(q1_target_log_probs.exp() *
                          q1_target_log_probs).sum(-1, keepdim=True)  # [T, B, 1]

    q2_target_max = q2_target.max(-1, keepdim=True)[0]  # [T, B, 1]
    # q2_target_log_probs = q2_target.log_softmax(-1)
    # q2_target_log_max = q2_target_log_probs.max(1, keepdim=True)[0]

    q_target = th.minimum(q1_target_max,
                          q2_target_max) + self.alpha * q1_target_entropy  # [T, B, 1]
    dc_r = n_step_return(BATCH.reward,
                         self.gamma,
                         BATCH.done,
                         q_target,
                         BATCH.begin_mask).detach()  # [T, B, 1]
    td_error1 = q1_eval - dc_r  # [T, B, 1]
    td_error2 = q2_eval - dc_r  # [T, B, 1]
    q1_loss = (td_error1.square() * BATCH.get('isw', 1.0)).mean()  # 1
    q2_loss = (td_error2.square() * BATCH.get('isw', 1.0)).mean()  # 1
    loss = 0.5 * (q1_loss + q2_loss)
    self.critic_oplr.optimize(loss)
    summaries = {
        'LEARNING_RATE/critic_lr': self.critic_oplr.lr,
        'LOSS/loss': loss,
        'Statistics/log_alpha': self.log_alpha,
        'Statistics/alpha': self.alpha,
        'Statistics/q1_entropy': q1_entropy,
        'Statistics/q_min': th.minimum(q1, q2).mean(),
        'Statistics/q_mean': q1.mean(),
        'Statistics/q_max': th.maximum(q1, q2).mean()
    }
    if self.auto_adaption:
        alpha_loss = -(self.alpha *
                       (self.target_entropy - q1_entropy).detach()).mean()
        self.alpha_oplr.optimize(alpha_loss)
        summaries.update({
            'LOSS/alpha_loss': alpha_loss,
            'LEARNING_RATE/alpha_lr': self.alpha_oplr.lr
        })
    return (td_error1 + td_error2) / 2, summaries
def _train(self, BATCH):
    q_dist = self.q_net(BATCH.obs, begin_mask=BATCH.begin_mask)  # [T, B, A, N]
    # [T, B, A, N] * [T, B, A, 1] => [T, B, A, N] => [T, B, N]
    q_dist = (q_dist * BATCH.action.unsqueeze(-1)).sum(-2)
    q_eval = (q_dist * self._z).sum(-1)  # [T, B, N] * [N,] => [T, B]

    target_q_dist = self.q_net.t(BATCH.obs_,
                                 begin_mask=BATCH.begin_mask)  # [T, B, A, N]
    # [T, B, A, N] * [1, N] => [T, B, A]
    target_q = (target_q_dist * self._z).sum(-1)
    a_ = target_q.argmax(-1)  # [T, B]
    a_onehot = F.one_hot(a_, self.a_dim).float()  # [T, B, A]
    # [T, B, A, N] * [T, B, A, 1] => [T, B, A, N] => [T, B, N]
    target_q_dist = (target_q_dist * a_onehot.unsqueeze(-1)).sum(-2)

    target = n_step_return(BATCH.reward.repeat(1, 1, self._atoms),
                           self.gamma,
                           BATCH.done.repeat(1, 1, self._atoms),
                           target_q_dist,
                           BATCH.begin_mask.repeat(1, 1, self._atoms)).detach()  # [T, B, N]
    target = target.clamp(self._v_min, self._v_max)  # [T, B, N]
    # An amazing trick for calculating the projection gracefully.
    # ref: https://github.com/ShangtongZhang/DeepRL
    target_dist = (1 - (target.unsqueeze(-2) - self._z.view(1, 1, -1, 1)).abs() /
                   self._delta_z).clamp(0, 1) * target_q_dist.unsqueeze(-2)  # [T, B, N, N]
    target_dist = target_dist.sum(-1)  # [T, B, N]
    _cross_entropy = -(target_dist * th.log(q_dist + th.finfo().eps)).sum(
        -1, keepdim=True)  # [T, B, 1]
    loss = (_cross_entropy * BATCH.get('isw', 1.0)).mean()  # 1
    self.oplr.optimize(loss)
    return _cross_entropy, {
        'LEARNING_RATE/lr': self.oplr.lr,
        'LOSS/loss': loss,
        'Statistics/q_max': q_eval.max(),
        'Statistics/q_min': q_eval.min(),
        'Statistics/q_mean': q_eval.mean()
    }
def _train(self, BATCH):
    q_dist = self.q_net(BATCH.obs, begin_mask=BATCH.begin_mask)  # [T, B, A, N]
    q_dist = (q_dist * BATCH.action.unsqueeze(-1)).sum(-2)  # [T, B, A, N] => [T, B, N]

    target_q_dist = self.q_net.t(BATCH.obs_,
                                 begin_mask=BATCH.begin_mask)  # [T, B, A, N]
    target_q = target_q_dist.mean(-1)  # [T, B, A, N] => [T, B, A]
    _a = target_q.argmax(-1)  # [T, B]
    next_max_action = F.one_hot(_a, self.a_dim).float().unsqueeze(-1)  # [T, B, A, 1]
    # [T, B, A, N] => [T, B, N]
    target_q_dist = (target_q_dist * next_max_action).sum(-2)

    target = n_step_return(BATCH.reward.repeat(1, 1, self.nums),
                           self.gamma,
                           BATCH.done.repeat(1, 1, self.nums),
                           target_q_dist,
                           BATCH.begin_mask.repeat(1, 1, self.nums)).detach()  # [T, B, N]

    q_eval = q_dist.mean(-1, keepdim=True)  # [T, B, 1]
    q_target = target.mean(-1, keepdim=True)  # [T, B, 1]
    td_error = q_target - q_eval  # [T, B, 1], used for PER

    target = target.unsqueeze(-2)  # [T, B, 1, N]
    q_dist = q_dist.unsqueeze(-1)  # [T, B, N, 1]
    # [T, B, 1, N] - [T, B, N, 1] => [T, B, N, N]
    quantile_error = target - q_dist
    huber = F.huber_loss(target, q_dist, reduction="none",
                         delta=self.huber_delta)  # [T, B, N, N]
    # [N,] - [T, B, N, N] => [T, B, N, N]
    huber_abs = (self.quantiles - quantile_error.detach().le(0.).float()).abs()
    loss = (huber_abs * huber).mean(-1)  # [T, B, N, N] => [T, B, N]
    loss = loss.sum(-1, keepdim=True)  # [T, B, N] => [T, B, 1]
    loss = (loss * BATCH.get('isw', 1.0)).mean()  # 1
    self.oplr.optimize(loss)
    return td_error, {
        'LEARNING_RATE/lr': self.oplr.lr,
        'LOSS/loss': loss,
        'Statistics/q_max': q_eval.max(),
        'Statistics/q_min': q_eval.min(),
        'Statistics/q_mean': q_eval.mean()
    }
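# For reference, the quantile-regression Huber loss used above, pulled out as a
# standalone sketch. `taus` stands in for the fixed quantile fractions that the
# method reads from `self.quantiles`; shapes follow the same [T, B, ...] layout.
import torch as th
import torch.nn.functional as F


def quantile_huber_loss_sketch(pred, target, taus, delta=1.0):
    """pred: [T, B, N] predicted quantile values, target: [T, B, N'] target
    samples, taus: [N] quantile fractions in (0, 1). Returns a scalar loss."""
    target = target.unsqueeze(-2)                      # [T, B, 1, N']
    pred = pred.unsqueeze(-1)                          # [T, B, N, 1]
    diff = target - pred                               # [T, B, N, N']
    huber = F.huber_loss(pred.expand_as(diff), target.expand_as(diff),
                         reduction='none', delta=delta)
    # asymmetric weight |tau - 1{target <= pred}| applied to the Huber term
    weight = (taus.view(1, 1, -1, 1) - diff.detach().le(0.).float()).abs()
    return (weight * huber).mean(-1).sum(-1).mean()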
def _train(self, BATCH):
    q = self.q_net(BATCH.obs, begin_mask=BATCH.begin_mask)  # [T, B, A]
    q_next = self.q_net.t(BATCH.obs_, begin_mask=BATCH.begin_mask)  # [T, B, A]
    v_next = self._get_v(q_next)  # [T, B, 1]
    q_eval = (q * BATCH.action).sum(-1, keepdim=True)  # [T, B, 1]
    q_target = n_step_return(BATCH.reward,
                             self.gamma,
                             BATCH.done,
                             v_next,
                             BATCH.begin_mask).detach()  # [T, B, 1]
    td_error = q_target - q_eval  # [T, B, 1]
    q_loss = (td_error.square() * BATCH.get('isw', 1.0)).mean()  # 1
    self.oplr.optimize(q_loss)
    return td_error, {
        'LEARNING_RATE/lr': self.oplr.lr,
        'LOSS/loss': q_loss,
        'Statistics/q_max': q_eval.max(),
        'Statistics/q_min': q_eval.min(),
        'Statistics/q_mean': q_eval.mean()
    }
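# _get_v above is not shown in this section; for a soft-Q-style update it is
# assumed to collapse the [T, B, A] action values into a soft state value,
# e.g. (a sketch only, with `self.alpha` as a hypothetical temperature attribute):
#
# def _get_v(self, q):
#     return self.alpha * th.logsumexp(q / self.alpha, dim=-1, keepdim=True)  # [T, B, 1]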
def _train(self, BATCH):
    q = self.q_net(BATCH.obs, begin_mask=BATCH.begin_mask)  # [T, B, P]
    q_next = self.q_net.t(BATCH.obs_, begin_mask=BATCH.begin_mask)  # [T, B, P]
    beta_next = self.termination_net(BATCH.obs_,
                                     begin_mask=BATCH.begin_mask)  # [T, B, P]
    qu_eval = (q * BATCH.options).sum(-1, keepdim=True)  # [T, B, 1]
    beta_s_ = (beta_next * BATCH.options).sum(-1, keepdim=True)  # [T, B, 1]
    q_s_ = (q_next * BATCH.options).sum(-1, keepdim=True)  # [T, B, 1]
    if self.double_q:
        q_ = self.q_net(BATCH.obs_, begin_mask=BATCH.begin_mask)  # [T, B, P]
        max_a_idx = F.one_hot(q_.argmax(-1), self.options_num).float()  # [T, B, P]
        q_s_max = (q_next * max_a_idx).sum(-1, keepdim=True)  # [T, B, 1]
    else:
        q_s_max = q_next.max(-1, keepdim=True)[0]  # [T, B, 1]
    u_target = (1 - beta_s_) * q_s_ + beta_s_ * q_s_max  # [T, B, 1]
    qu_target = n_step_return(BATCH.reward,
                              self.gamma,
                              BATCH.done,
                              u_target,
                              BATCH.begin_mask).detach()  # [T, B, 1]
    td_error = qu_target - qu_eval  # [T, B, 1] gradient : q
    q_loss = (td_error.square() * BATCH.get('isw', 1.0)).mean()  # 1
    self.q_oplr.optimize(q_loss)

    q_s = qu_eval.detach()  # [T, B, 1]
    pi = self.intra_option_net(BATCH.obs, begin_mask=BATCH.begin_mask)  # [T, B, P, A]
    if self.use_baseline:
        adv = (qu_target - q_s).detach()  # [T, B, 1]
    else:
        adv = qu_target.detach()  # [T, B, 1]
    # [T, B, P] => [T, B, P, 1]
    options_onehot_expanded = BATCH.options.unsqueeze(-1)
    # [T, B, P, A] => [T, B, A]
    pi = (pi * options_onehot_expanded).sum(-2)
    if self.is_continuous:
        mu = pi.tanh()  # [T, B, A]
        log_std = self.log_std[BATCH.options.argmax(-1)]  # [T, B, A]
        dist = td.Independent(td.Normal(mu, log_std.exp()), 1)
        log_p = dist.log_prob(BATCH.action).unsqueeze(-1)  # [T, B, 1]
        entropy = dist.entropy().unsqueeze(-1)  # [T, B, 1]
    else:
        pi = pi / self.boltzmann_temperature  # [T, B, A]
        log_pi = pi.log_softmax(-1)  # [T, B, A]
        entropy = -(log_pi.exp() * log_pi).sum(-1, keepdim=True)  # [T, B, 1]
        log_p = (log_pi * BATCH.action).sum(-1, keepdim=True)  # [T, B, 1]
    pi_loss = -(log_p * adv + self.ent_coff * entropy).mean()  # 1
    self.intra_option_oplr.optimize(pi_loss)

    beta = self.termination_net(BATCH.obs, begin_mask=BATCH.begin_mask)  # [T, B, P]
    beta_s = (beta * BATCH.last_options).sum(-1, keepdim=True)  # [T, B, 1]
    interests = self.interest_net(BATCH.obs, begin_mask=BATCH.begin_mask)  # [T, B, P]
    # [T, B, P] or q.softmax(-1)
    pi_op = (interests * q.detach()).softmax(-1)
    interest_loss = -(beta_s.detach() *
                      (pi_op * BATCH.options).sum(-1, keepdim=True) *
                      q_s).mean()  # 1
    self.interest_oplr.optimize(interest_loss)

    v_s = (q * pi_op).sum(-1, keepdim=True)  # [T, B, 1]
    beta_loss = beta_s * (q_s - v_s).detach()  # [T, B, 1]
    if self.terminal_mask:
        beta_loss *= (1 - BATCH.done)  # [T, B, 1]
    beta_loss = beta_loss.mean()  # 1
    self.termination_oplr.optimize(beta_loss)

    return td_error, {
        'LEARNING_RATE/q_lr': self.q_oplr.lr,
        'LEARNING_RATE/intra_option_lr': self.intra_option_oplr.lr,
        'LEARNING_RATE/termination_lr': self.termination_oplr.lr,
        # 'Statistics/option': self.options[0],
        'LOSS/q_loss': q_loss,
        'LOSS/pi_loss': pi_loss,
        'LOSS/beta_loss': beta_loss,
        'LOSS/interest_loss': interest_loss,
        'Statistics/q_option_max': q_s.max(),
        'Statistics/q_option_min': q_s.min(),
        'Statistics/q_option_mean': q_s.mean()
    }
def _train(self, BATCH_DICT):
    """
    TODO: Annotation
    """
    summaries = defaultdict(dict)
    target_actions = {}
    for aid, mid in zip(self.agent_ids, self.model_ids):
        if self.is_continuouss[aid]:
            target_actions[aid] = self.actors[mid].t(
                BATCH_DICT[aid].obs_,
                begin_mask=BATCH_DICT['global'].begin_mask)  # [T, B, A]
        else:
            target_logits = self.actors[mid].t(
                BATCH_DICT[aid].obs_,
                begin_mask=BATCH_DICT['global'].begin_mask)  # [T, B, A]
            target_cate_dist = td.Categorical(logits=target_logits)
            target_pi = target_cate_dist.sample()  # [T, B]
            action_target = F.one_hot(target_pi,
                                      self.a_dims[aid]).float()  # [T, B, A]
            target_actions[aid] = action_target  # [T, B, A]
    target_actions = th.cat(list(target_actions.values()), -1)  # [T, B, N*A]

    qs, q_targets = {}, {}
    for mid in self.model_ids:
        qs[mid] = self.critics[mid](
            [BATCH_DICT[id].obs for id in self.agent_ids],
            th.cat([BATCH_DICT[id].action for id in self.agent_ids], -1))  # [T, B, 1]
        q_targets[mid] = self.critics[mid].t(
            [BATCH_DICT[id].obs_ for id in self.agent_ids],
            target_actions)  # [T, B, 1]

    q_loss = {}
    td_errors = 0.
    for aid, mid in zip(self.agent_ids, self.model_ids):
        dc_r = n_step_return(BATCH_DICT[aid].reward,
                             self.gamma,
                             BATCH_DICT[aid].done,
                             q_targets[mid],
                             BATCH_DICT['global'].begin_mask).detach()  # [T, B, 1]
        td_error = dc_r - qs[mid]  # [T, B, 1]
        td_errors += td_error
        q_loss[aid] = 0.5 * td_error.square().mean()  # 1
        summaries[aid].update({
            'Statistics/q_min': qs[mid].min(),
            'Statistics/q_mean': qs[mid].mean(),
            'Statistics/q_max': qs[mid].max()
        })
    self.critic_oplr.optimize(sum(q_loss.values()))

    actor_loss = {}
    for aid, mid in zip(self.agent_ids, self.model_ids):
        if self.is_continuouss[aid]:
            mu = self.actors[mid](
                BATCH_DICT[aid].obs,
                begin_mask=BATCH_DICT['global'].begin_mask)  # [T, B, A]
        else:
            logits = self.actors[mid](
                BATCH_DICT[aid].obs,
                begin_mask=BATCH_DICT['global'].begin_mask)  # [T, B, A]
            logp_all = logits.log_softmax(-1)  # [T, B, A]
            gumbel_noise = td.Gumbel(0, 1).sample(logp_all.shape)  # [T, B, A]
            _pi = ((logp_all + gumbel_noise) / self.discrete_tau).softmax(-1)  # [T, B, A]
            _pi_true_one_hot = F.one_hot(_pi.argmax(-1),
                                         self.a_dims[aid]).float()  # [T, B, A]
            _pi_diff = (_pi_true_one_hot - _pi).detach()  # [T, B, A]
            mu = _pi_diff + _pi  # [T, B, A]
        all_actions = {id: BATCH_DICT[id].action for id in self.agent_ids}
        all_actions[aid] = mu
        q_actor = self.critics[mid](
            [BATCH_DICT[id].obs for id in self.agent_ids],
            th.cat(list(all_actions.values()), -1),
            begin_mask=BATCH_DICT['global'].begin_mask)  # [T, B, 1]
        actor_loss[aid] = -q_actor.mean()  # 1
    self.actor_oplr.optimize(sum(actor_loss.values()))

    for aid in self.agent_ids:
        summaries[aid].update({
            'LOSS/actor_loss': actor_loss[aid],
            'LOSS/critic_loss': q_loss[aid]
        })
    summaries['model'].update({
        'LOSS/actor_loss': sum(actor_loss.values()),
        'LOSS/critic_loss': sum(q_loss.values())
    })
    return td_errors / self.n_agents_percopy, summaries
def _train(self, BATCH_DICT):
    summaries = {}
    reward = BATCH_DICT[self.agent_ids[0]].reward  # [T, B, 1]
    done = 0.
    q_evals = []
    q_rnncs_s = []
    q_actions = []
    q_maxs = []
    q_max_actions = []
    q_target_next_choose_maxs = []
    q_target_rnncs_s = []
    q_target_actions = []
    for aid, mid in zip(self.agent_ids, self.model_ids):
        done += BATCH_DICT[aid].done  # [T, B, 1]
        q = self.q_nets[mid](
            BATCH_DICT[aid].obs,
            begin_mask=BATCH_DICT['global'].begin_mask)  # [T, B, A]
        q_rnncs = self.q_nets[mid].get_rnncs()  # [T, B, *]
        q_eval = (q * BATCH_DICT[aid].action).sum(-1, keepdim=True)  # [T, B, 1]
        q_evals.append(q_eval)  # N * [T, B, 1]
        q_rnncs_s.append(q_rnncs)  # N * [T, B, *]
        q_actions.append(BATCH_DICT[aid].action)  # N * [T, B, A]
        q_maxs.append(q.max(-1, keepdim=True)[0])  # [T, B, 1]
        q_max_actions.append(
            F.one_hot(q.argmax(-1), self.a_dims[aid]).float())  # [T, B, A]

        q_target = self.q_nets[mid].t(
            BATCH_DICT[aid].obs_,
            begin_mask=BATCH_DICT['global'].begin_mask)  # [T, B, A]
        # [T, B, *]
        q_target_rnncs = self.q_nets[mid].target.get_rnncs()
        if self._use_double:
            next_q = self.q_nets[mid](
                BATCH_DICT[aid].obs_,
                begin_mask=BATCH_DICT['global'].begin_mask)  # [T, B, A]
            next_max_action = next_q.argmax(-1)  # [T, B]
            next_max_action_one_hot = F.one_hot(
                next_max_action, self.a_dims[aid]).float()  # [T, B, A]
            q_target_next_max = (q_target * next_max_action_one_hot).sum(
                -1, keepdim=True)  # [T, B, 1]
        else:
            next_max_action = q_target.argmax(-1)  # [T, B]
            next_max_action_one_hot = F.one_hot(
                next_max_action, self.a_dims[aid]).float()  # [T, B, A]
            # [T, B, 1]
            q_target_next_max = q_target.max(-1, keepdim=True)[0]
        q_target_next_choose_maxs.append(q_target_next_max)  # N * [T, B, 1]
        q_target_rnncs_s.append(q_target_rnncs)  # N * [T, B, *]
        q_target_actions.append(next_max_action_one_hot)  # N * [T, B, A]

    joint_qs, vs = self.mixer(
        BATCH_DICT['global'].obs,
        q_rnncs_s,
        q_actions,
        begin_mask=BATCH_DICT['global'].begin_mask)  # [T, B, 1]
    target_joint_qs, target_vs = self.mixer.t(
        BATCH_DICT['global'].obs_,
        q_target_rnncs_s,
        q_target_actions,
        begin_mask=BATCH_DICT['global'].begin_mask)  # [T, B, 1]
    q_target_tot = n_step_return(reward,
                                 self.gamma,
                                 (done > 0.).float(),
                                 target_joint_qs,
                                 BATCH_DICT['global'].begin_mask).detach()  # [T, B, 1]
    td_error = q_target_tot - joint_qs  # [T, B, 1]
    td_loss = td_error.square().mean()  # 1

    # opt loss
    max_joint_qs, _ = self.mixer(
        BATCH_DICT['global'].obs,
        q_rnncs_s,
        q_max_actions,
        begin_mask=BATCH_DICT['global'].begin_mask)  # [T, B, 1]
    max_actions_qvals = sum(q_maxs)  # [T, B, 1]
    opt_loss = (max_actions_qvals - max_joint_qs.detach() + vs).square().mean()  # 1

    # nopt loss
    nopt_error = sum(q_evals) - joint_qs.detach() + vs  # [T, B, 1]
    nopt_error = nopt_error.clamp(max=0)  # [T, B, 1]
    nopt_loss = nopt_error.square().mean()  # 1

    loss = td_loss + self.opt_loss * opt_loss + self.nopt_min_loss * nopt_loss
    self.oplr.optimize(loss)
    summaries['model'] = {
        'LOSS/q_loss': td_loss,
        'LOSS/loss': loss,
        'Statistics/q_max': joint_qs.max(),
        'Statistics/q_min': joint_qs.min(),
        'Statistics/q_mean': joint_qs.mean()
    }
    return td_error, summaries
def _train(self, BATCH):
    q = self.q_net(BATCH.obs, begin_mask=BATCH.begin_mask)  # [T, B, P]
    q_next = self.q_net.t(BATCH.obs_, begin_mask=BATCH.begin_mask)  # [T, B, P]
    beta_next = self.termination_net(BATCH.obs_,
                                     begin_mask=BATCH.begin_mask)  # [T, B, P]
    qu_eval = (q * BATCH.options).sum(-1, keepdim=True)  # [T, B, 1]
    beta_s_ = (beta_next * BATCH.options).sum(-1, keepdim=True)  # [T, B, 1]
    q_s_ = (q_next * BATCH.options).sum(-1, keepdim=True)  # [T, B, 1]
    # https://github.com/jeanharb/option_critic/blob/5d6c81a650a8f452bc8ad3250f1f211d317fde8c/neural_net.py#L94
    if self.double_q:
        q_ = self.q_net(BATCH.obs_, begin_mask=BATCH.begin_mask)  # [T, B, P]
        # [T, B, P] => [T, B] => [T, B, P]
        max_a_idx = F.one_hot(q_.argmax(-1), self.options_num).float()
        q_s_max = (q_next * max_a_idx).sum(-1, keepdim=True)  # [T, B, 1]
    else:
        q_s_max = q_next.max(-1, keepdim=True)[0]  # [T, B, 1]
    u_target = (1 - beta_s_) * q_s_ + beta_s_ * q_s_max  # [T, B, 1]
    qu_target = n_step_return(BATCH.reward,
                              self.gamma,
                              BATCH.done,
                              u_target,
                              BATCH.begin_mask).detach()  # [T, B, 1]
    td_error = qu_target - qu_eval  # gradient : q  [T, B, 1]
    q_loss = (td_error.square() * BATCH.get('isw', 1.0)).mean()  # [T, B, 1] => 1
    self.q_oplr.optimize(q_loss)

    q_s = qu_eval.detach()  # [T, B, 1]
    # https://github.com/jeanharb/option_critic/blob/5d6c81a650a8f452bc8ad3250f1f211d317fde8c/neural_net.py#L130
    if self.use_baseline:
        adv = (qu_target - q_s).detach()  # [T, B, 1]
    else:
        adv = qu_target.detach()  # [T, B, 1]
    # [T, B, P] => [T, B, P, 1]
    options_onehot_expanded = BATCH.options.unsqueeze(-1)
    pi = self.intra_option_net(BATCH.obs, begin_mask=BATCH.begin_mask)  # [T, B, P, A]
    # [T, B, P, A] => [T, B, A]
    pi = (pi * options_onehot_expanded).sum(-2)
    if self.is_continuous:
        mu = pi.tanh()  # [T, B, A]
        log_std = self.log_std[BATCH.options.argmax(-1)]  # [T, B, A]
        dist = td.Independent(td.Normal(mu, log_std.exp()), 1)
        log_p = dist.log_prob(BATCH.action).unsqueeze(-1)  # [T, B, 1]
        entropy = dist.entropy().unsqueeze(-1)  # [T, B, 1]
    else:
        pi = pi / self.boltzmann_temperature  # [T, B, A]
        log_pi = pi.log_softmax(-1)  # [T, B, A]
        entropy = -(log_pi.exp() * log_pi).sum(-1, keepdim=True)  # [T, B, 1]
        log_p = (BATCH.action * log_pi).sum(-1, keepdim=True)  # [T, B, 1]
    pi_loss = -(log_p * adv + self.ent_coff * entropy).mean()  # 1

    beta = self.termination_net(BATCH.obs, begin_mask=BATCH.begin_mask)  # [T, B, P]
    beta_s = (beta * BATCH.last_options).sum(-1, keepdim=True)  # [T, B, 1]
    if self.use_eps_greedy:
        v_s = q.max(-1, keepdim=True)[0] - self.termination_regularizer  # [T, B, 1]
    else:
        v_s = (1 - beta_s) * q_s + beta_s * q.max(-1, keepdim=True)[0]  # [T, B, 1]
        # v_s = q.mean(-1, keepdim=True)  # [T, B, 1]
    beta_loss = beta_s * (q_s - v_s).detach()  # [T, B, 1]
    # https://github.com/lweitkamp/option-critic-pytorch/blob/0c57da7686f8903ed2d8dded3fae832ee9defd1a/option_critic.py#L238
    if self.terminal_mask:
        beta_loss *= (1 - BATCH.done)  # [T, B, 1]
    beta_loss = beta_loss.mean()  # 1

    self.intra_option_oplr.optimize(pi_loss)
    self.termination_oplr.optimize(beta_loss)

    return td_error, {
        'LEARNING_RATE/q_lr': self.q_oplr.lr,
        'LEARNING_RATE/intra_option_lr': self.intra_option_oplr.lr,
        'LEARNING_RATE/termination_lr': self.termination_oplr.lr,
        # 'Statistics/option': self.options[0],
        'LOSS/q_loss': q_loss,
        'LOSS/pi_loss': pi_loss,
        'LOSS/beta_loss': beta_loss,
        'Statistics/q_option_max': q_s.max(),
        'Statistics/q_option_min': q_s.min(),
        'Statistics/q_option_mean': q_s.mean()
    }
def _train(self, BATCH):
    if self.is_continuous:
        target_mu, target_log_std = self.actor(BATCH.obs_,
                                               begin_mask=BATCH.begin_mask)  # [T, B, A]
        dist = td.Independent(td.Normal(target_mu, target_log_std.exp()), 1)
        target_pi = dist.sample()  # [T, B, A]
        target_pi, target_log_pi = squash_action(
            target_pi,
            dist.log_prob(target_pi).unsqueeze(-1),
            is_independent=False)  # [T, B, A]
        target_log_pi = tsallis_entropy_log_q(target_log_pi,
                                              self.entropic_index)  # [T, B, 1]
    else:
        target_logits = self.actor(BATCH.obs_, begin_mask=BATCH.begin_mask)  # [T, B, A]
        target_cate_dist = td.Categorical(logits=target_logits)
        target_pi = target_cate_dist.sample()  # [T, B]
        target_log_pi = target_cate_dist.log_prob(target_pi).unsqueeze(-1)  # [T, B, 1]
        target_pi = F.one_hot(target_pi, self.a_dim).float()  # [T, B, A]
    q1 = self.critic(BATCH.obs, BATCH.action, begin_mask=BATCH.begin_mask)  # [T, B, 1]
    q2 = self.critic2(BATCH.obs, BATCH.action, begin_mask=BATCH.begin_mask)  # [T, B, 1]
    q1_target = self.critic.t(BATCH.obs_, target_pi,
                              begin_mask=BATCH.begin_mask)  # [T, B, 1]
    q2_target = self.critic2.t(BATCH.obs_, target_pi,
                               begin_mask=BATCH.begin_mask)  # [T, B, 1]
    q_target = th.minimum(q1_target, q2_target)  # [T, B, 1]
    dc_r = n_step_return(BATCH.reward,
                         self.gamma,
                         BATCH.done,
                         (q_target - self.alpha * target_log_pi),
                         BATCH.begin_mask).detach()  # [T, B, 1]
    td_error1 = q1 - dc_r  # [T, B, 1]
    td_error2 = q2 - dc_r  # [T, B, 1]
    q1_loss = (td_error1.square() * BATCH.get('isw', 1.0)).mean()  # 1
    q2_loss = (td_error2.square() * BATCH.get('isw', 1.0)).mean()  # 1
    critic_loss = 0.5 * q1_loss + 0.5 * q2_loss
    self.critic_oplr.optimize(critic_loss)

    if self.is_continuous:
        mu, log_std = self.actor(BATCH.obs, begin_mask=BATCH.begin_mask)  # [T, B, A]
        dist = td.Independent(td.Normal(mu, log_std.exp()), 1)
        pi = dist.rsample()  # [T, B, A]
        pi, log_pi = squash_action(pi,
                                   dist.log_prob(pi).unsqueeze(-1),
                                   is_independent=False)  # [T, B, A]
        log_pi = tsallis_entropy_log_q(log_pi, self.entropic_index)  # [T, B, 1]
        entropy = dist.entropy().mean()  # 1
    else:
        logits = self.actor(BATCH.obs, begin_mask=BATCH.begin_mask)  # [T, B, A]
        logp_all = logits.log_softmax(-1)  # [T, B, A]
        gumbel_noise = td.Gumbel(0, 1).sample(logp_all.shape)  # [T, B, A]
        _pi = ((logp_all + gumbel_noise) / self.discrete_tau).softmax(-1)  # [T, B, A]
        _pi_true_one_hot = F.one_hot(_pi.argmax(-1), self.a_dim).float()  # [T, B, A]
        _pi_diff = (_pi_true_one_hot - _pi).detach()  # [T, B, A]
        pi = _pi_diff + _pi  # [T, B, A]
        log_pi = (logp_all * pi).sum(-1, keepdim=True)  # [T, B, 1]
        entropy = -(logp_all.exp() * logp_all).sum(-1).mean()  # 1
    q_s_pi = th.minimum(
        self.critic(BATCH.obs, pi, begin_mask=BATCH.begin_mask),
        self.critic2(BATCH.obs, pi, begin_mask=BATCH.begin_mask))  # [T, B, 1]
    actor_loss = -(q_s_pi - self.alpha * log_pi).mean()  # 1
    self.actor_oplr.optimize(actor_loss)

    summaries = {
        'LEARNING_RATE/actor_lr': self.actor_oplr.lr,
        'LEARNING_RATE/critic_lr': self.critic_oplr.lr,
        'LOSS/actor_loss': actor_loss,
        'LOSS/q1_loss': q1_loss,
        'LOSS/q2_loss': q2_loss,
        'LOSS/critic_loss': critic_loss,
        'Statistics/log_alpha': self.log_alpha,
        'Statistics/alpha': self.alpha,
        'Statistics/entropy': entropy,
        'Statistics/q_min': th.minimum(q1, q2).min(),
        'Statistics/q_mean': th.minimum(q1, q2).mean(),
        'Statistics/q_max': th.maximum(q1, q2).max()
    }
    if self.auto_adaption:
        alpha_loss = -(self.alpha *
                       (log_pi + self.target_entropy).detach()).mean()  # 1
        self.alpha_oplr.optimize(alpha_loss)
        summaries.update({
            'LOSS/alpha_loss': alpha_loss,
            'LEARNING_RATE/alpha_lr': self.alpha_oplr.lr
        })
    return (td_error1 + td_error2) / 2, summaries
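# squash_action (used here and in the SAC-style updates below) is assumed to be
# the usual tanh squashing with a log-density correction; a minimal sketch,
# ignoring the is_independent flag of the real helper:
import torch as th


def squash_action_sketch(pi, log_pi, eps=1e-6):
    """pi: [T, B, A] pre-squash sample, log_pi: [T, B, 1] its log-density.
    Returns the tanh-squashed action and the corrected log-density."""
    squashed = pi.tanh()
    # change of variables: subtract sum_i log(1 - tanh(u_i)^2)
    correction = th.log(1. - squashed.pow(2) + eps).sum(-1, keepdim=True)
    return squashed, log_pi - correction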
def _train(self, BATCH):
    time_step = BATCH.reward.shape[0]
    batch_size = BATCH.reward.shape[1]

    quantiles, quantiles_tiled = self._generate_quantiles(  # [T*B, N, 1], [N*T*B, X]
        batch_size=time_step * batch_size,
        quantiles_num=self.online_quantiles)
    # [T*B, N, 1] => [T, B, N, 1]
    quantiles = quantiles.view(time_step, batch_size, -1, 1)
    quantiles_tiled = quantiles_tiled.view(
        time_step, -1, self.quantiles_idx)  # [N*T*B, X] => [T, N*B, X]

    quantiles_value = self.q_net(BATCH.obs, quantiles_tiled,
                                 begin_mask=BATCH.begin_mask)  # [T, N, B, A]
    # [T, N, B, A] => [N, T, B, A] * [T, B, A] => [N, T, B, 1]
    quantiles_value = (quantiles_value.swapaxes(0, 1) * BATCH.action).sum(-1, keepdim=True)
    q_eval = quantiles_value.mean(0)  # [N, T, B, 1] => [T, B, 1]

    _, select_quantiles_tiled = self._generate_quantiles(  # [N*T*B, X]
        batch_size=time_step * batch_size,
        quantiles_num=self.select_quantiles)
    select_quantiles_tiled = select_quantiles_tiled.view(
        time_step, -1, self.quantiles_idx)  # [N*T*B, X] => [T, N*B, X]
    q_values = self.q_net(BATCH.obs_, select_quantiles_tiled,
                          begin_mask=BATCH.begin_mask)  # [T, N, B, A]
    q_values = q_values.mean(1)  # [T, N, B, A] => [T, B, A]
    next_max_action = q_values.argmax(-1)  # [T, B]
    next_max_action = F.one_hot(next_max_action, self.a_dim).float()  # [T, B, A]

    _, target_quantiles_tiled = self._generate_quantiles(  # [N'*T*B, X]
        batch_size=time_step * batch_size,
        quantiles_num=self.target_quantiles)
    target_quantiles_tiled = target_quantiles_tiled.view(
        time_step, -1, self.quantiles_idx)  # [N'*T*B, X] => [T, N'*B, X]
    target_quantiles_value = self.q_net.t(BATCH.obs_, target_quantiles_tiled,
                                          begin_mask=BATCH.begin_mask)  # [T, N', B, A]
    target_quantiles_value = target_quantiles_value.swapaxes(0, 1)  # [T, N', B, A] => [N', T, B, A]
    target_quantiles_value = (target_quantiles_value * next_max_action).sum(
        -1, keepdim=True)  # [N', T, B, 1]
    target_q = target_quantiles_value.mean(0)  # [T, B, 1]
    q_target = n_step_return(BATCH.reward,  # [T, B, 1]
                             self.gamma,
                             BATCH.done,  # [T, B, 1]
                             target_q,  # [T, B, 1]
                             BATCH.begin_mask).detach()  # [T, B, 1]
    td_error = q_target - q_eval  # [T, B, 1]

    # [N', T, B, 1] => [N', T, B]
    target_quantiles_value = target_quantiles_value.squeeze(-1)
    target_quantiles_value = target_quantiles_value.permute(1, 2, 0)  # [N', T, B] => [T, B, N']
    quantiles_value_target = n_step_return(
        BATCH.reward.repeat(1, 1, self.target_quantiles),
        self.gamma,
        BATCH.done.repeat(1, 1, self.target_quantiles),
        target_quantiles_value,
        BATCH.begin_mask.repeat(1, 1, self.target_quantiles)).detach()  # [T, B, N']
    # [T, B, N'] => [T, B, 1, N']
    quantiles_value_target = quantiles_value_target.unsqueeze(-2)
    quantiles_value_online = quantiles_value.permute(1, 2, 0, 3)  # [N, T, B, 1] => [T, B, N, 1]
    # [T, B, N, 1] - [T, B, 1, N'] => [T, B, N, N']
    quantile_error = quantiles_value_online - quantiles_value_target
    huber = F.huber_loss(quantiles_value_online, quantiles_value_target,
                         reduction="none", delta=self.huber_delta)  # [T, B, N, N']
    # [T, B, N, 1] - [T, B, N, N'] => [T, B, N, N']
    huber_abs = (quantiles - quantile_error.detach().le(0.).float()).abs()
    loss = (huber_abs * huber).mean(-1)  # [T, B, N, N'] => [T, B, N]
    loss = loss.sum(-1, keepdim=True)  # [T, B, N] => [T, B, 1]
    loss = (loss * BATCH.get('isw', 1.0)).mean()  # 1
    self.oplr.optimize(loss)
    return td_error, {
        'LEARNING_RATE/lr': self.oplr.lr,
        'LOSS/loss': loss,
        'Statistics/q_max': q_eval.max(),
        'Statistics/q_min': q_eval.min(),
        'Statistics/q_mean': q_eval.mean()
    }
def _train(self, BATCH):
    for _ in range(self.delay_num):
        if self.is_continuous:
            action_target = self.target_noised_action(
                self.actor.t(BATCH.obs_, begin_mask=BATCH.begin_mask))  # [T, B, A]
        else:
            target_logits = self.actor.t(BATCH.obs_,
                                         begin_mask=BATCH.begin_mask)  # [T, B, A]
            target_cate_dist = td.Categorical(logits=target_logits)
            target_pi = target_cate_dist.sample()  # [T, B]
            action_target = F.one_hot(target_pi, self.a_dim).float()  # [T, B, A]
        q1 = self.critic(BATCH.obs, BATCH.action,
                         begin_mask=BATCH.begin_mask)  # [T, B, 1]
        q2 = self.critic2(BATCH.obs, BATCH.action,
                          begin_mask=BATCH.begin_mask)  # [T, B, 1]
        q_target = th.minimum(
            self.critic.t(BATCH.obs_, action_target, begin_mask=BATCH.begin_mask),
            self.critic2.t(BATCH.obs_, action_target, begin_mask=BATCH.begin_mask))  # [T, B, 1]
        dc_r = n_step_return(BATCH.reward,
                             self.gamma,
                             BATCH.done,
                             q_target,
                             BATCH.begin_mask).detach()  # [T, B, 1]
        td_error1 = q1 - dc_r  # [T, B, 1]
        td_error2 = q2 - dc_r  # [T, B, 1]
        q1_loss = (td_error1.square() * BATCH.get('isw', 1.0)).mean()  # 1
        q2_loss = (td_error2.square() * BATCH.get('isw', 1.0)).mean()  # 1
        critic_loss = 0.5 * (q1_loss + q2_loss)
        self.critic_oplr.optimize(critic_loss)

    if self.is_continuous:
        mu = self.actor(BATCH.obs, begin_mask=BATCH.begin_mask)  # [T, B, A]
    else:
        logits = self.actor(BATCH.obs, begin_mask=BATCH.begin_mask)  # [T, B, A]
        logp_all = logits.log_softmax(-1)  # [T, B, A]
        gumbel_noise = td.Gumbel(0, 1).sample(logp_all.shape)  # [T, B, A]
        _pi = ((logp_all + gumbel_noise) / self.discrete_tau).softmax(-1)  # [T, B, A]
        _pi_true_one_hot = F.one_hot(_pi.argmax(-1), self.a_dim).float()  # [T, B, A]
        _pi_diff = (_pi_true_one_hot - _pi).detach()  # [T, B, A]
        mu = _pi_diff + _pi  # [T, B, A]
    q1_actor = self.critic(BATCH.obs, mu, begin_mask=BATCH.begin_mask)  # [T, B, 1]
    actor_loss = -q1_actor.mean()  # 1
    self.actor_oplr.optimize(actor_loss)
    return (td_error1 + td_error2) / 2, {
        'LEARNING_RATE/actor_lr': self.actor_oplr.lr,
        'LEARNING_RATE/critic_lr': self.critic_oplr.lr,
        'LOSS/actor_loss': actor_loss,
        'LOSS/critic_loss': critic_loss,
        'Statistics/q_min': th.minimum(q1, q2).min(),
        'Statistics/q_mean': th.minimum(q1, q2).mean(),
        'Statistics/q_max': th.maximum(q1, q2).max()
    }
def _train(self, BATCH_DICT):
    summaries = {}
    reward = BATCH_DICT[self.agent_ids[0]].reward  # [T, B, 1]
    done = 0.
    q_evals = []
    q_actions = []
    q_maxs = []
    q_target_next_choose_maxs = []
    q_target_actions = []
    q_target_next_maxs = []
    for aid, mid in zip(self.agent_ids, self.model_ids):
        done += BATCH_DICT[aid].done  # [T, B, 1]
        q = self.q_nets[mid](
            BATCH_DICT[aid].obs,
            begin_mask=BATCH_DICT['global'].begin_mask)  # [T, B, A]
        q_eval = (q * BATCH_DICT[aid].action).sum(-1, keepdim=True)  # [T, B, 1]
        q_evals.append(q_eval)  # N * [T, B, 1]
        q_actions.append(BATCH_DICT[aid].action)  # N * [T, B, A]
        q_maxs.append(q.max(-1, keepdim=True)[0])  # [T, B, 1]

        q_target = self.q_nets[mid].t(
            BATCH_DICT[aid].obs_,
            begin_mask=BATCH_DICT['global'].begin_mask)  # [T, B, A]
        # use double
        next_q = self.q_nets[mid](
            BATCH_DICT[aid].obs_,
            begin_mask=BATCH_DICT['global'].begin_mask)  # [T, B, A]
        next_max_action = next_q.argmax(-1)  # [T, B]
        next_max_action_one_hot = F.one_hot(next_max_action,
                                            self.a_dims[aid]).float()  # [T, B, A]
        q_target_next_max = (q_target * next_max_action_one_hot).sum(
            -1, keepdim=True)  # [T, B, 1]
        q_target_next_choose_maxs.append(q_target_next_max)  # N * [T, B, 1]
        q_target_actions.append(next_max_action_one_hot)  # N * [T, B, A]
        q_target_next_maxs.append(q_target.max(-1, keepdim=True)[0])  # N * [T, B, 1]

    q_evals = th.stack(q_evals, -1)  # [T, B, 1, N]
    q_maxs = th.stack(q_maxs, -1)  # [T, B, 1, N]
    q_target_next_choose_maxs = th.stack(q_target_next_choose_maxs, -1)  # [T, B, 1, N]
    q_target_next_maxs = th.stack(q_target_next_maxs, -1)  # [T, B, 1, N]
    q_eval_tot = self.mixer(
        BATCH_DICT['global'].obs,
        q_evals,
        q_actions,
        q_maxs,
        begin_mask=BATCH_DICT['global'].begin_mask)  # [T, B, 1]
    q_target_next_max_tot = self.mixer.t(
        BATCH_DICT['global'].obs_,
        q_target_next_choose_maxs,
        q_target_actions,
        q_target_next_maxs,
        begin_mask=BATCH_DICT['global'].begin_mask)  # [T, B, 1]
    q_target_tot = n_step_return(
        reward,
        self.gamma,
        (done > 0.).float(),
        q_target_next_max_tot,
        BATCH_DICT['global'].begin_mask).detach()  # [T, B, 1]
    td_error = q_target_tot - q_eval_tot  # [T, B, 1]
    q_loss = td_error.square().mean()  # 1
    self.oplr.optimize(q_loss)
    summaries['model'] = {
        'LOSS/q_loss': q_loss,
        'Statistics/q_max': q_eval_tot.max(),
        'Statistics/q_min': q_eval_tot.min(),
        'Statistics/q_mean': q_eval_tot.mean()
    }
    return td_error, summaries
def _train(self, BATCH_DICT):
    """
    TODO: Annotation
    """
    summaries = defaultdict(dict)
    target_actions = {}
    target_log_pis = 1.
    for aid, mid in zip(self.agent_ids, self.model_ids):
        if self.is_continuouss[aid]:
            target_mu, target_log_std = self.actors[mid](
                BATCH_DICT[aid].obs_,
                begin_mask=BATCH_DICT['global'].begin_mask)  # [T, B, A]
            dist = td.Independent(td.Normal(target_mu, target_log_std.exp()), 1)
            target_pi = dist.sample()  # [T, B, A]
            target_pi, target_log_pi = squash_action(
                target_pi,
                dist.log_prob(target_pi).unsqueeze(-1))  # [T, B, A], [T, B, 1]
        else:
            target_logits = self.actors[mid](
                BATCH_DICT[aid].obs_,
                begin_mask=BATCH_DICT['global'].begin_mask)  # [T, B, A]
            target_cate_dist = td.Categorical(logits=target_logits)
            target_pi = target_cate_dist.sample()  # [T, B]
            target_log_pi = target_cate_dist.log_prob(target_pi).unsqueeze(-1)  # [T, B, 1]
            target_pi = F.one_hot(target_pi, self.a_dims[aid]).float()  # [T, B, A]
        target_actions[aid] = target_pi
        target_log_pis *= target_log_pi
    target_log_pis += th.finfo().eps
    target_actions = th.cat(list(target_actions.values()), -1)  # [T, B, N*A]

    qs1, qs2, q_targets1, q_targets2 = {}, {}, {}, {}
    for mid in self.model_ids:
        qs1[mid] = self.critics[mid](
            [BATCH_DICT[id].obs for id in self.agent_ids],
            th.cat([BATCH_DICT[id].action for id in self.agent_ids], -1))  # [T, B, 1]
        qs2[mid] = self.critics2[mid](
            [BATCH_DICT[id].obs for id in self.agent_ids],
            th.cat([BATCH_DICT[id].action for id in self.agent_ids], -1))  # [T, B, 1]
        q_targets1[mid] = self.critics[mid].t(
            [BATCH_DICT[id].obs_ for id in self.agent_ids],
            target_actions)  # [T, B, 1]
        q_targets2[mid] = self.critics2[mid].t(
            [BATCH_DICT[id].obs_ for id in self.agent_ids],
            target_actions)  # [T, B, 1]

    q_loss = {}
    td_errors = 0.
    for aid, mid in zip(self.agent_ids, self.model_ids):
        q_target = th.minimum(q_targets1[mid], q_targets2[mid])  # [T, B, 1]
        dc_r = n_step_return(
            BATCH_DICT[aid].reward,
            self.gamma,
            BATCH_DICT[aid].done,
            q_target - self.alpha * target_log_pis,
            BATCH_DICT['global'].begin_mask).detach()  # [T, B, 1]
        td_error1 = qs1[mid] - dc_r  # [T, B, 1]
        td_error2 = qs2[mid] - dc_r  # [T, B, 1]
        td_errors += (td_error1 + td_error2) / 2
        q1_loss = td_error1.square().mean()  # 1
        q2_loss = td_error2.square().mean()  # 1
        q_loss[aid] = 0.5 * q1_loss + 0.5 * q2_loss
        summaries[aid].update({
            'Statistics/q_min': qs1[mid].min(),
            'Statistics/q_mean': qs1[mid].mean(),
            'Statistics/q_max': qs1[mid].max()
        })
    self.critic_oplr.optimize(sum(q_loss.values()))

    log_pi_actions = {}
    log_pis = {}
    sample_pis = {}
    for aid, mid in zip(self.agent_ids, self.model_ids):
        if self.is_continuouss[aid]:
            mu, log_std = self.actors[mid](
                BATCH_DICT[aid].obs,
                begin_mask=BATCH_DICT['global'].begin_mask)  # [T, B, A]
            dist = td.Independent(td.Normal(mu, log_std.exp()), 1)
            pi = dist.rsample()  # [T, B, A]
            pi, log_pi = squash_action(
                pi, dist.log_prob(pi).unsqueeze(-1))  # [T, B, A], [T, B, 1]
            pi_action = BATCH_DICT[aid].action.arctanh()
            _, log_pi_action = squash_action(
                pi_action,
                dist.log_prob(pi_action).unsqueeze(-1))  # [T, B, A], [T, B, 1]
        else:
            logits = self.actors[mid](
                BATCH_DICT[aid].obs,
                begin_mask=BATCH_DICT['global'].begin_mask)  # [T, B, A]
            logp_all = logits.log_softmax(-1)  # [T, B, A]
            gumbel_noise = td.Gumbel(0, 1).sample(logp_all.shape)  # [T, B, A]
            _pi = ((logp_all + gumbel_noise) / self.discrete_tau).softmax(-1)  # [T, B, A]
            _pi_true_one_hot = F.one_hot(_pi.argmax(-1),
                                         self.a_dims[aid]).float()  # [T, B, A]
            _pi_diff = (_pi_true_one_hot - _pi).detach()  # [T, B, A]
            pi = _pi_diff + _pi  # [T, B, A]
            log_pi = (logp_all * pi).sum(-1, keepdim=True)  # [T, B, 1]
            log_pi_action = (logp_all * BATCH_DICT[aid].action).sum(
                -1, keepdim=True)  # [T, B, 1]
        log_pi_actions[aid] = log_pi_action
        log_pis[aid] = log_pi
        sample_pis[aid] = pi

    actor_loss = {}
    for aid, mid in zip(self.agent_ids, self.model_ids):
        all_actions = {id: BATCH_DICT[id].action for id in self.agent_ids}
        all_actions[aid] = sample_pis[aid]
        all_log_pis = {id: log_pi_actions[id] for id in self.agent_ids}
        all_log_pis[aid] = log_pis[aid]
        q_s_pi = th.minimum(
            self.critics[mid](
                [BATCH_DICT[id].obs for id in self.agent_ids],
                th.cat(list(all_actions.values()), -1),
                begin_mask=BATCH_DICT['global'].begin_mask),
            self.critics2[mid](
                [BATCH_DICT[id].obs for id in self.agent_ids],
                th.cat(list(all_actions.values()), -1),
                begin_mask=BATCH_DICT['global'].begin_mask))  # [T, B, 1]
        _log_pis = 1.
        for _log_pi in all_log_pis.values():
            _log_pis *= _log_pi
        _log_pis += th.finfo().eps
        actor_loss[aid] = -(q_s_pi - self.alpha * _log_pis).mean()  # 1
    self.actor_oplr.optimize(sum(actor_loss.values()))

    for aid in self.agent_ids:
        summaries[aid].update({
            'LOSS/actor_loss': actor_loss[aid],
            'LOSS/critic_loss': q_loss[aid]
        })
    summaries['model'].update({
        'LOSS/actor_loss': sum(actor_loss.values()),
        'LOSS/critic_loss': sum(q_loss.values())
    })
    if self.auto_adaption:
        _log_pis = 1.
        for _log_pi in log_pis.values():
            _log_pis *= _log_pi
        _log_pis += th.finfo().eps
        alpha_loss = -(self.alpha *
                       (_log_pis + self.target_entropy).detach()).mean()  # 1
        self.alpha_oplr.optimize(alpha_loss)
        summaries['model'].update({
            'LOSS/alpha_loss': alpha_loss,
            'LEARNING_RATE/alpha_lr': self.alpha_oplr.lr
        })
    return td_errors / self.n_agents_percopy, summaries
def _train_discrete(self, BATCH):
    v = self.v_net(BATCH.obs, begin_mask=BATCH.begin_mask)  # [T, B, 1]
    v_target = self.v_net.t(BATCH.obs_, begin_mask=BATCH.begin_mask)  # [T, B, 1]

    q1_all = self.q_net(BATCH.obs, begin_mask=BATCH.begin_mask)  # [T, B, A]
    q2_all = self.q_net2(BATCH.obs, begin_mask=BATCH.begin_mask)  # [T, B, A]
    q1 = (q1_all * BATCH.action).sum(-1, keepdim=True)  # [T, B, 1]
    q2 = (q2_all * BATCH.action).sum(-1, keepdim=True)  # [T, B, 1]
    logits = self.actor(BATCH.obs, begin_mask=BATCH.begin_mask)  # [T, B, A]
    logp_all = logits.log_softmax(-1)  # [T, B, A]

    dc_r = n_step_return(BATCH.reward,
                         self.gamma,
                         BATCH.done,
                         v_target,
                         BATCH.begin_mask).detach()  # [T, B, 1]
    td_v = v - (th.minimum(
        (logp_all.exp() * q1_all).sum(-1, keepdim=True),
        (logp_all.exp() * q2_all).sum(-1, keepdim=True))).detach()  # [T, B, 1]
    td_error1 = q1 - dc_r  # [T, B, 1]
    td_error2 = q2 - dc_r  # [T, B, 1]
    q1_loss = (td_error1.square() * BATCH.get('isw', 1.0)).mean()  # 1
    q2_loss = (td_error2.square() * BATCH.get('isw', 1.0)).mean()  # 1
    v_loss_stop = (td_v.square() * BATCH.get('isw', 1.0)).mean()  # 1
    critic_loss = 0.5 * q1_loss + 0.5 * q2_loss + 0.5 * v_loss_stop
    self.critic_oplr.optimize(critic_loss)

    q1_all = self.q_net(BATCH.obs, begin_mask=BATCH.begin_mask)  # [T, B, A]
    q2_all = self.q_net2(BATCH.obs, begin_mask=BATCH.begin_mask)  # [T, B, A]
    logits = self.actor(BATCH.obs, begin_mask=BATCH.begin_mask)  # [T, B, A]
    logp_all = logits.log_softmax(-1)  # [T, B, A]
    entropy = -(logp_all.exp() * logp_all).sum(-1, keepdim=True)  # [T, B, 1]
    q_all = th.minimum(q1_all, q2_all)  # [T, B, A]
    actor_loss = -((q_all - self.alpha * logp_all) * logp_all.exp()).sum(-1)  # [T, B, A] => [T, B]
    actor_loss = actor_loss.mean()  # 1
    self.actor_oplr.optimize(actor_loss)

    summaries = {
        'LEARNING_RATE/actor_lr': self.actor_oplr.lr,
        'LEARNING_RATE/critic_lr': self.critic_oplr.lr,
        'LOSS/actor_loss': actor_loss,
        'LOSS/q1_loss': q1_loss,
        'LOSS/q2_loss': q2_loss,
        'LOSS/v_loss': v_loss_stop,
        'LOSS/critic_loss': critic_loss,
        'Statistics/log_alpha': self.log_alpha,
        'Statistics/alpha': self.alpha,
        'Statistics/entropy': entropy.mean(),
        'Statistics/v_mean': v.mean()
    }
    if self.auto_adaption:
        corr = (self.target_entropy - entropy).detach()  # [T, B, 1]
        # corr = ((logp_all - self.a_dim) * logp_all.exp()).sum(-1).detach()
        alpha_loss = -(self.alpha * corr)  # [T, B, 1]
        alpha_loss = alpha_loss.mean()  # 1
        self.alpha_oplr.optimize(alpha_loss)
        summaries.update({
            'LOSS/alpha_loss': alpha_loss,
            'LEARNING_RATE/alpha_lr': self.alpha_oplr.lr
        })
    return (td_error1 + td_error2) / 2, summaries
def _train_continuous(self, BATCH):
    v = self.v_net(BATCH.obs, begin_mask=BATCH.begin_mask)  # [T, B, 1]
    v_target = self.v_net.t(BATCH.obs_, begin_mask=BATCH.begin_mask)  # [T, B, 1]

    if self.is_continuous:
        mu, log_std = self.actor(BATCH.obs, begin_mask=BATCH.begin_mask)  # [T, B, A]
        dist = td.Independent(td.Normal(mu, log_std.exp()), 1)
        pi = dist.rsample()  # [T, B, A]
        pi, log_pi = squash_action(
            pi, dist.log_prob(pi).unsqueeze(-1))  # [T, B, A], [T, B, 1]
    else:
        logits = self.actor(BATCH.obs, begin_mask=BATCH.begin_mask)  # [T, B, A]
        logp_all = logits.log_softmax(-1)  # [T, B, A]
        gumbel_noise = td.Gumbel(0, 1).sample(logp_all.shape)  # [T, B, A]
        _pi = ((logp_all + gumbel_noise) / self.discrete_tau).softmax(-1)  # [T, B, A]
        _pi_true_one_hot = F.one_hot(_pi.argmax(-1), self.a_dim).float()  # [T, B, A]
        _pi_diff = (_pi_true_one_hot - _pi).detach()  # [T, B, A]
        pi = _pi_diff + _pi  # [T, B, A]
        log_pi = (logp_all * pi).sum(-1, keepdim=True)  # [T, B, 1]
    q1 = self.q_net(BATCH.obs, BATCH.action, begin_mask=BATCH.begin_mask)  # [T, B, 1]
    q2 = self.q_net2(BATCH.obs, BATCH.action, begin_mask=BATCH.begin_mask)  # [T, B, 1]
    q1_pi = self.q_net(BATCH.obs, pi, begin_mask=BATCH.begin_mask)  # [T, B, 1]
    q2_pi = self.q_net2(BATCH.obs, pi, begin_mask=BATCH.begin_mask)  # [T, B, 1]
    dc_r = n_step_return(BATCH.reward,
                         self.gamma,
                         BATCH.done,
                         v_target,
                         BATCH.begin_mask).detach()  # [T, B, 1]
    v_from_q_stop = (th.minimum(q1_pi, q2_pi) - self.alpha * log_pi).detach()  # [T, B, 1]
    td_v = v - v_from_q_stop  # [T, B, 1]
    td_error1 = q1 - dc_r  # [T, B, 1]
    td_error2 = q2 - dc_r  # [T, B, 1]
    q1_loss = (td_error1.square() * BATCH.get('isw', 1.0)).mean()  # 1
    q2_loss = (td_error2.square() * BATCH.get('isw', 1.0)).mean()  # 1
    v_loss_stop = (td_v.square() * BATCH.get('isw', 1.0)).mean()  # 1
    critic_loss = 0.5 * q1_loss + 0.5 * q2_loss + 0.5 * v_loss_stop
    self.critic_oplr.optimize(critic_loss)

    if self.is_continuous:
        mu, log_std = self.actor(BATCH.obs, begin_mask=BATCH.begin_mask)  # [T, B, A]
        dist = td.Independent(td.Normal(mu, log_std.exp()), 1)
        pi = dist.rsample()  # [T, B, A]
        pi, log_pi = squash_action(
            pi, dist.log_prob(pi).unsqueeze(-1))  # [T, B, A], [T, B, 1]
        entropy = dist.entropy().mean()  # 1
    else:
        logits = self.actor(BATCH.obs, begin_mask=BATCH.begin_mask)  # [T, B, A]
        logp_all = logits.log_softmax(-1)  # [T, B, A]
        gumbel_noise = td.Gumbel(0, 1).sample(logp_all.shape)  # [T, B, A]
        _pi = ((logp_all + gumbel_noise) / self.discrete_tau).softmax(-1)  # [T, B, A]
        _pi_true_one_hot = F.one_hot(_pi.argmax(-1), self.a_dim).float()  # [T, B, A]
        _pi_diff = (_pi_true_one_hot - _pi).detach()  # [T, B, A]
        pi = _pi_diff + _pi  # [T, B, A]
        log_pi = (logp_all * pi).sum(-1, keepdim=True)  # [T, B, 1]
        entropy = -(logp_all.exp() * logp_all).sum(-1).mean()  # 1
    q1_pi = self.q_net(BATCH.obs, pi, begin_mask=BATCH.begin_mask)  # [T, B, 1]
    actor_loss = -(q1_pi - self.alpha * log_pi).mean()  # 1
    self.actor_oplr.optimize(actor_loss)

    summaries = {
        'LEARNING_RATE/actor_lr': self.actor_oplr.lr,
        'LEARNING_RATE/critic_lr': self.critic_oplr.lr,
        'LOSS/actor_loss': actor_loss,
        'LOSS/q1_loss': q1_loss,
        'LOSS/q2_loss': q2_loss,
        'LOSS/v_loss': v_loss_stop,
        'LOSS/critic_loss': critic_loss,
        'Statistics/log_alpha': self.log_alpha,
        'Statistics/alpha': self.alpha,
        'Statistics/entropy': entropy,
        'Statistics/q_min': th.minimum(q1, q2).min(),
        'Statistics/q_mean': th.minimum(q1, q2).mean(),
        'Statistics/q_max': th.maximum(q1, q2).max(),
        'Statistics/v_mean': v.mean()
    }
    if self.auto_adaption:
        alpha_loss = -(self.alpha *
                       (log_pi.detach() + self.target_entropy)).mean()
        self.alpha_oplr.optimize(alpha_loss)
        summaries.update({
            'LOSS/alpha_loss': alpha_loss,
            'LEARNING_RATE/alpha_lr': self.alpha_oplr.lr
        })
    return (td_error1 + td_error2) / 2, summaries
def _train_discrete(self, BATCH):
    q1_all = self.critic(BATCH.obs, begin_mask=BATCH.begin_mask)  # [T, B, A]
    q2_all = self.critic2(BATCH.obs, begin_mask=BATCH.begin_mask)  # [T, B, A]
    q1 = (q1_all * BATCH.action).sum(-1, keepdim=True)  # [T, B, 1]
    q2 = (q2_all * BATCH.action).sum(-1, keepdim=True)  # [T, B, 1]

    target_logits = self.actor(BATCH.obs_, begin_mask=BATCH.begin_mask)  # [T, B, A]
    target_log_probs = target_logits.log_softmax(-1)  # [T, B, A]
    q1_target = self.critic.t(BATCH.obs_, begin_mask=BATCH.begin_mask)  # [T, B, A]
    q2_target = self.critic2.t(BATCH.obs_, begin_mask=BATCH.begin_mask)  # [T, B, A]

    def v_target_function(x):
        return (target_log_probs.exp() *
                (x - self.alpha * target_log_probs)).sum(-1, keepdim=True)  # [T, B, 1]

    v1_target = v_target_function(q1_target)  # [T, B, 1]
    v2_target = v_target_function(q2_target)  # [T, B, 1]
    v_target = th.minimum(v1_target, v2_target)  # [T, B, 1]
    dc_r = n_step_return(BATCH.reward,
                         self.gamma,
                         BATCH.done,
                         v_target,
                         BATCH.begin_mask).detach()  # [T, B, 1]
    td_error1 = q1 - dc_r  # [T, B, 1]
    td_error2 = q2 - dc_r  # [T, B, 1]
    q1_loss = (td_error1.square() * BATCH.get('isw', 1.0)).mean()  # 1
    q2_loss = (td_error2.square() * BATCH.get('isw', 1.0)).mean()  # 1
    critic_loss = 0.5 * q1_loss + 0.5 * q2_loss
    self.critic_oplr.optimize(critic_loss)

    q1_all = self.critic(BATCH.obs, begin_mask=BATCH.begin_mask)  # [T, B, A]
    q2_all = self.critic2(BATCH.obs, begin_mask=BATCH.begin_mask)  # [T, B, A]
    logits = self.actor(BATCH.obs, begin_mask=BATCH.begin_mask)  # [T, B, A]
    logp_all = logits.log_softmax(-1)  # [T, B, A]
    entropy = -(logp_all.exp() * logp_all).sum(-1, keepdim=True)  # [T, B, 1]
    q_all = th.minimum(q1_all, q2_all)  # [T, B, A]
    actor_loss = -((q_all - self.alpha * logp_all) * logp_all.exp()).sum(-1)  # [T, B, A] => [T, B]
    actor_loss = actor_loss.mean()  # 1
    # actor_loss = -(q_all + self.alpha * entropy).mean()
    self.actor_oplr.optimize(actor_loss)

    summaries = {
        'LEARNING_RATE/actor_lr': self.actor_oplr.lr,
        'LEARNING_RATE/critic_lr': self.critic_oplr.lr,
        'LOSS/actor_loss': actor_loss,
        'LOSS/q1_loss': q1_loss,
        'LOSS/q2_loss': q2_loss,
        'LOSS/critic_loss': critic_loss,
        'Statistics/log_alpha': self.log_alpha,
        'Statistics/alpha': self.alpha,
        'Statistics/entropy': entropy.mean()
    }
    if self.auto_adaption:
        corr = (self.target_entropy - entropy).detach()  # [T, B, 1]
        # corr = ((logp_all - self.a_dim) * logp_all.exp()).sum(-1).detach()  # [B, A] => [B,]
        # J(\alpha) = \pi_t(s_t)^T [-\alpha (\log \pi_t(s_t) + \bar{H})]
        # \bar{H} is negative
        alpha_loss = -(self.alpha * corr)  # [T, B, 1]
        alpha_loss = alpha_loss.mean()  # 1
        self.alpha_oplr.optimize(alpha_loss)
        summaries.update({
            'LOSS/alpha_loss': alpha_loss,
            'LEARNING_RATE/alpha_lr': self.alpha_oplr.lr
        })
    return (td_error1 + td_error2) / 2, summaries
def _train(self, BATCH):
    if self.is_continuous:
        # Variational Auto-Encoder Training
        recon, mean, std = self.vae(BATCH.obs, BATCH.action,
                                    begin_mask=BATCH.begin_mask)
        recon_loss = F.mse_loss(recon, BATCH.action)
        KL_loss = -0.5 * (1 + th.log(std.pow(2)) - mean.pow(2) - std.pow(2)).mean()
        vae_loss = recon_loss + 0.5 * KL_loss
        self.vae_oplr.optimize(vae_loss)

        target_Qs = []
        for _ in range(self._train_samples):
            # Compute value of perturbed actions sampled from the VAE
            _vae_actions = self.vae.decode(BATCH.obs_, begin_mask=BATCH.begin_mask)
            _actor_actions = self.actor.t(BATCH.obs_, _vae_actions,
                                          begin_mask=BATCH.begin_mask)
            target_Q1, target_Q2 = self.critic.t(BATCH.obs_, _actor_actions,
                                                 begin_mask=BATCH.begin_mask)
            # Soft Clipped Double Q-learning
            target_Q = self._lmbda * th.min(target_Q1, target_Q2) + \
                (1. - self._lmbda) * th.max(target_Q1, target_Q2)
            target_Qs.append(target_Q)
        target_Qs = th.stack(target_Qs, dim=0)  # [N, T, B, 1]
        # Take max over each BATCH.action sampled from the VAE
        target_Q = target_Qs.max(dim=0)[0]  # [T, B, 1]
        target_Q = n_step_return(BATCH.reward,
                                 self.gamma,
                                 BATCH.done,
                                 target_Q,
                                 BATCH.begin_mask).detach()  # [T, B, 1]

        current_Q1, current_Q2 = self.critic(BATCH.obs, BATCH.action,
                                             begin_mask=BATCH.begin_mask)
        td_error = ((current_Q1 - target_Q) + (current_Q2 - target_Q)) / 2
        critic_loss = F.mse_loss(current_Q1, target_Q) + F.mse_loss(current_Q2, target_Q)
        self.critic_oplr.optimize(critic_loss)

        # Pertubation Model / Action Training
        sampled_actions = self.vae.decode(BATCH.obs, begin_mask=BATCH.begin_mask)
        perturbed_actions = self.actor(BATCH.obs, sampled_actions,
                                       begin_mask=BATCH.begin_mask)
        # Update through DPG
        q1, _ = self.critic(BATCH.obs, perturbed_actions,
                            begin_mask=BATCH.begin_mask)
        actor_loss = -q1.mean()
        self.actor_oplr.optimize(actor_loss)
        return td_error, {
            'LEARNING_RATE/actor_lr': self.actor_oplr.lr,
            'LEARNING_RATE/critic_lr': self.critic_oplr.lr,
            'LEARNING_RATE/vae_lr': self.vae_oplr.lr,
            'LOSS/actor_loss': actor_loss,
            'LOSS/critic_loss': critic_loss,
            'LOSS/vae_loss': vae_loss,
            'Statistics/q_min': q1.min(),
            'Statistics/q_mean': q1.mean(),
            'Statistics/q_max': q1.max()
        }
    else:
        q_next, i_next = self.q_net(BATCH.obs_, begin_mask=BATCH.begin_mask)  # [T, B, A]
        q_next = q_next - q_next.min(dim=-1, keepdim=True)[0]  # [B, *]
        i_next = F.log_softmax(i_next, dim=-1)  # [T, B, A]
        i_next = i_next.exp()  # [T, B, A]
        i_next = (i_next / i_next.max(-1, keepdim=True)[0] > self._threshold).float()  # [T, B, A]
        q_next = i_next * q_next  # [T, B, A]
        next_max_action = q_next.argmax(-1)  # [T, B]
        next_max_action_one_hot = F.one_hot(next_max_action.squeeze(),
                                            self.a_dim).float()  # [T, B, A]

        q_target_next, _ = self.q_net.t(BATCH.obs_, begin_mask=BATCH.begin_mask)  # [T, B, A]
        q_target_next_max = (q_target_next * next_max_action_one_hot).sum(
            -1, keepdim=True)  # [T, B, 1]
        q_target = n_step_return(BATCH.reward,
                                 self.gamma,
                                 BATCH.done,
                                 q_target_next_max,
                                 BATCH.begin_mask).detach()  # [T, B, 1]

        q, i = self.q_net(BATCH.obs, begin_mask=BATCH.begin_mask)  # [T, B, A]
        q_eval = (q * BATCH.action).sum(-1, keepdim=True)  # [T, B, 1]
        td_error = q_target - q_eval  # [T, B, 1]
        q_loss = (td_error.square() * BATCH.get('isw', 1.0)).mean()  # 1

        imt = F.log_softmax(i, dim=-1)  # [T, B, A]
        imt = imt.reshape(-1, self.a_dim)  # [T*B, A]
        action = BATCH.action.reshape(-1, self.a_dim)  # [T*B, A]
        i_loss = F.nll_loss(imt, action.argmax(-1))  # 1
        loss = q_loss + i_loss + 1e-2 * i.pow(2).mean()
        self.oplr.optimize(loss)
        return td_error, {
            'LEARNING_RATE/lr': self.oplr.lr,
            'LOSS/q_loss': q_loss,
            'LOSS/i_loss': i_loss,
            'LOSS/loss': loss,
            'Statistics/q_max': q_eval.max(),
            'Statistics/q_min': q_eval.min(),
            'Statistics/q_mean': q_eval.mean()
        }