def __init__(self, input_dim, n_nodes, node_dim):
    super(GraphVAE, self).__init__()
    # store parameters
    self.input_dim = input_dim
    self.n_nodes = n_nodes
    self.node_dim = node_dim
    # encoder: x -> h_x
    self.encoder = nn.Sequential(nn.Linear(input_dim, 512), nn.BatchNorm1d(512), nn.ELU(),
                                 nn.Linear(512, 512), nn.BatchNorm1d(512), nn.ELU(),
                                 nn.Linear(512, 256), nn.BatchNorm1d(256), nn.ELU(),
                                 nn.Linear(256, 128))
    # bottom-up inference: predicts parameters of P(z_i | x)
    self.bottom_up = nn.ModuleList([
        nn.Sequential(
            nn.Linear(128, 128),
            nn.BatchNorm1d(128),
            nn.ELU(),
            nn.Linear(128, node_dim),
            nn.Linear(node_dim, 2 * node_dim)  # split into mu and logvar
        ) for _ in range(n_nodes - 1)
    ])  # ignore z_n
    # top-down inference: predicts parameters of P(z_i | Pa(z_i))
    self.top_down = nn.ModuleList([
        nn.Sequential(
            nn.Linear((n_nodes - i - 1) * node_dim, 128),  # parents of z_i are z_{i+1} ... z_N
            nn.BatchNorm1d(128),
            nn.ELU(),
            nn.Linear(128, node_dim),
            nn.Linear(node_dim, 2 * node_dim)  # split into mu and logvar
        ) for i in range(n_nodes - 1)
    ])  # ignore z_n
    # decoder: (z_1, z_2 ... z_n) -> parameters of P(x)
    self.decoder = nn.Sequential(nn.Linear(node_dim * n_nodes, 256), nn.BatchNorm1d(256), nn.ELU(),
                                 nn.Linear(256, 512), nn.BatchNorm1d(512), nn.ELU(),
                                 nn.Linear(512, 512), nn.BatchNorm1d(512), nn.ELU(),
                                 nn.Linear(512, input_dim))
    # mean of Bernoulli variables c_{i,j} representing edges
    self.gating_params = nn.ParameterList([
        nn.Parameter(torch.empty(n_nodes - i - 1, 1, 1).fill_(0.5), requires_grad=True)
        for i in range(n_nodes - 1)
    ])  # ignore z_n
    # distributions for sampling
    self.unit_normal = D.Normal(torch.zeros(self.node_dim), torch.ones(self.node_dim))
    self.gumbel = D.Gumbel(0., 1.)
    # other parameters / distributions
    self.tau = 1.0
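# Minimal sketch (not part of the original class) of how the 2 * node_dim output of each
# bottom-up / top-down head could be split and reparameterised. Only the mu/logvar split is
# implied by the comments above; the helper name `reparameterize` is an assumption.
def reparameterize(params):
    mu, logvar = params.chunk(2, dim=-1)      # split into mu and logvar
    eps = torch.randn_like(mu)                # unit-normal noise
    return mu + (0.5 * logvar).exp() * eps    # z = mu + sigma * eps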
def _train(self, BATCH):
    if self.is_continuous:
        action_target = self.actor.t(BATCH.obs_, begin_mask=BATCH.begin_mask)  # [T, B, A]
        if self.use_target_action_noise:
            action_target = self.target_noised_action(action_target)  # [T, B, A]
    else:
        target_logits = self.actor.t(BATCH.obs_, begin_mask=BATCH.begin_mask)  # [T, B, A]
        target_cate_dist = td.Categorical(logits=target_logits)
        target_pi = target_cate_dist.sample()  # [T, B]
        action_target = F.one_hot(target_pi, self.a_dim).float()  # [T, B, A]
    q = self.critic(BATCH.obs, BATCH.action, begin_mask=BATCH.begin_mask)  # [T, B, 1]
    q_target = self.critic.t(BATCH.obs_, action_target, begin_mask=BATCH.begin_mask)  # [T, B, 1]
    dc_r = n_step_return(BATCH.reward,
                         self.gamma,
                         BATCH.done,
                         q_target,
                         BATCH.begin_mask).detach()  # [T, B, 1]
    td_error = dc_r - q  # [T, B, 1]
    q_loss = (td_error.square() * BATCH.get('isw', 1.0)).mean()  # 1
    self.critic_oplr.optimize(q_loss)

    if self.is_continuous:
        mu = self.actor(BATCH.obs, begin_mask=BATCH.begin_mask)  # [T, B, A]
    else:
        logits = self.actor(BATCH.obs, begin_mask=BATCH.begin_mask)  # [T, B, A]
        logp_all = logits.log_softmax(-1)  # [T, B, A]
        gumbel_noise = td.Gumbel(0, 1).sample(logp_all.shape)  # [T, B, A]
        _pi = ((logp_all + gumbel_noise) / self.discrete_tau).softmax(-1)  # [T, B, A]
        _pi_true_one_hot = F.one_hot(_pi.argmax(-1), self.a_dim).float()  # [T, B, A]
        _pi_diff = (_pi_true_one_hot - _pi).detach()  # [T, B, A]
        mu = _pi_diff + _pi  # [T, B, A]
    q_actor = self.critic(BATCH.obs, mu, begin_mask=BATCH.begin_mask)  # [T, B, 1]
    actor_loss = -q_actor.mean()  # 1
    self.actor_oplr.optimize(actor_loss)

    return td_error, {
        'LEARNING_RATE/actor_lr': self.actor_oplr.lr,
        'LEARNING_RATE/critic_lr': self.critic_oplr.lr,
        'LOSS/actor_loss': actor_loss,
        'LOSS/critic_loss': q_loss,
        'Statistics/q_min': q.min(),
        'Statistics/q_mean': q.mean(),
        'Statistics/q_max': q.max()
    }
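# Hedged sketch: the discrete branch above is a straight-through Gumbel-Softmax estimator
# (hard one-hot forward, soft gradient backward). A standalone version for reference; the
# function name `gumbel_softmax_st` and the explicit `tau` argument are illustrative only.
def gumbel_softmax_st(logits, tau=1.0):
    logp_all = logits.log_softmax(-1)
    gumbel_noise = td.Gumbel(0., 1.).sample(logp_all.shape)
    soft = ((logp_all + gumbel_noise) / tau).softmax(-1)        # differentiable relaxed sample
    hard = F.one_hot(soft.argmax(-1), soft.shape[-1]).float()   # discrete one-hot sample
    return (hard - soft).detach() + soft                        # forward: hard, backward: soft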
def sample(self, sample_shape=torch.Size()):
    if sample_shape is not None:
        sample_shape = torch.Size(sample_shape)
    # In comments, I use S as an indication of dimension(s) related to sample_shape
    # and B as an indication of dimension(s) related to batch_shape
    with torch.no_grad():
        batch_shape, K = self.batch_shape, self._K
        # This will store the sequence of labels from k=0 to K+1
        # [S, B, K+2]
        L = torch.zeros(sample_shape + batch_shape + (K + 2,),
                        device=self._scores.device).long()
        # [S, B, K+2, 3]
        eps = td.Gumbel(
            loc=torch.zeros(L.shape + (3,), device=self._scores.device),
            scale=torch.ones(L.shape + (3,), device=self._scores.device)
        ).sample()
        # [..., K+1, 3, 3]
        W = self._arc_weight
        # [..., K+2, 3]
        V = self._state_value
        for k in torch.arange(K + 1, device=self._scores.device):
            # weights of arcs leaving this coordinate
            # [B, 3, 3]
            W_k = W[..., k, :, :]
            # reshape to introduce sample_shape dimensions
            # [S, B, 3, 3]
            W_k = W_k.view((1,) * len(sample_shape) + W_k.shape).expand(
                sample_shape + (-1,) * len(W_k.shape))
            # origin state for coordinate k
            # [S, B]
            L_k = L[..., k]
            # reshape to a 3-dimensional one-hot encoding of the label
            # [S, B, 3, 1]
            L_k = torch.nn.functional.one_hot(L_k, 3).unsqueeze(-1)
            # select the weights for destination (zeroing out the rest)
            # [S, B, 3, 3]
            logits_k = torch.where(L_k == 1, W_k, torch.zeros_like(W_k))
            # sum 0s out and incorporate value of destination
            # [S, B, 3]
            logits_k = logits_k.sum(-2) + V[..., k + 1, :]
            # Categorical sampling via Gumbel-argmax
            # (possibly more efficient than td.Categorical(logits=logits_k).sample().long())
            L[..., k + 1] = torch.argmax(logits_k + eps[..., k + 1, :], -1).long()
        assert (L[..., -1] == 1).all(), "Not every sample reached the final state"
        L = L[..., 1:-1]  # discard the initial (k=0) and final (k=K+1) states
        # map to boolean and then float (in torch discrete samples are float)
        return (L == 2).float()
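# Hedged illustration of the Gumbel-argmax identity used above:
# argmax(logits + Gumbel(0, 1) noise) is distributed as Categorical(logits=logits).
# The tensor names below are illustrative only.
logits = torch.tensor([1.0, 0.0, -1.0])
noise = td.Gumbel(torch.zeros(10000, 3), torch.ones(10000, 3)).sample()
empirical = torch.bincount((logits + noise).argmax(-1), minlength=3) / 10000.
# empirical should be close to logits.softmax(-1)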
def _train(self, BATCH_DICT):
    """
    TODO: Annotation
    """
    summaries = defaultdict(dict)
    target_actions = {}
    for aid, mid in zip(self.agent_ids, self.model_ids):
        if self.is_continuouss[aid]:
            target_actions[aid] = self.actors[mid].t(
                BATCH_DICT[aid].obs_,
                begin_mask=BATCH_DICT['global'].begin_mask)  # [T, B, A]
        else:
            target_logits = self.actors[mid].t(
                BATCH_DICT[aid].obs_,
                begin_mask=BATCH_DICT['global'].begin_mask)  # [T, B, A]
            target_cate_dist = td.Categorical(logits=target_logits)
            target_pi = target_cate_dist.sample()  # [T, B]
            action_target = F.one_hot(target_pi, self.a_dims[aid]).float()  # [T, B, A]
            target_actions[aid] = action_target  # [T, B, A]
    target_actions = th.cat(list(target_actions.values()), -1)  # [T, B, N*A]

    qs, q_targets = {}, {}
    for mid in self.model_ids:
        qs[mid] = self.critics[mid](
            [BATCH_DICT[id].obs for id in self.agent_ids],
            th.cat([BATCH_DICT[id].action for id in self.agent_ids], -1))  # [T, B, 1]
        q_targets[mid] = self.critics[mid].t(
            [BATCH_DICT[id].obs_ for id in self.agent_ids],
            target_actions)  # [T, B, 1]

    q_loss = {}
    td_errors = 0.
    for aid, mid in zip(self.agent_ids, self.model_ids):
        dc_r = n_step_return(BATCH_DICT[aid].reward,
                             self.gamma,
                             BATCH_DICT[aid].done,
                             q_targets[mid],
                             BATCH_DICT['global'].begin_mask).detach()  # [T, B, 1]
        td_error = dc_r - qs[mid]  # [T, B, 1]
        td_errors += td_error
        q_loss[aid] = 0.5 * td_error.square().mean()  # 1
        summaries[aid].update({
            'Statistics/q_min': qs[mid].min(),
            'Statistics/q_mean': qs[mid].mean(),
            'Statistics/q_max': qs[mid].max()
        })
    self.critic_oplr.optimize(sum(q_loss.values()))

    actor_loss = {}
    for aid, mid in zip(self.agent_ids, self.model_ids):
        if self.is_continuouss[aid]:
            mu = self.actors[mid](
                BATCH_DICT[aid].obs,
                begin_mask=BATCH_DICT['global'].begin_mask)  # [T, B, A]
        else:
            logits = self.actors[mid](
                BATCH_DICT[aid].obs,
                begin_mask=BATCH_DICT['global'].begin_mask)  # [T, B, A]
            logp_all = logits.log_softmax(-1)  # [T, B, A]
            gumbel_noise = td.Gumbel(0, 1).sample(logp_all.shape)  # [T, B, A]
            _pi = ((logp_all + gumbel_noise) / self.discrete_tau).softmax(-1)  # [T, B, A]
            _pi_true_one_hot = F.one_hot(_pi.argmax(-1), self.a_dims[aid]).float()  # [T, B, A]
            _pi_diff = (_pi_true_one_hot - _pi).detach()  # [T, B, A]
            mu = _pi_diff + _pi  # [T, B, A]
        all_actions = {id: BATCH_DICT[id].action for id in self.agent_ids}
        all_actions[aid] = mu
        q_actor = self.critics[mid](
            [BATCH_DICT[id].obs for id in self.agent_ids],
            th.cat(list(all_actions.values()), -1),
            begin_mask=BATCH_DICT['global'].begin_mask)  # [T, B, 1]
        actor_loss[aid] = -q_actor.mean()  # 1
    self.actor_oplr.optimize(sum(actor_loss.values()))

    for aid in self.agent_ids:
        summaries[aid].update({
            'LOSS/actor_loss': actor_loss[aid],
            'LOSS/critic_loss': q_loss[aid]
        })
    summaries['model'].update({
        'LOSS/actor_loss': sum(actor_loss.values()),
        'LOSS/critic_loss': sum(q_loss.values())
    })
    return td_errors / self.n_agents_percopy, summaries
def _train(self, BATCH):
    if self.is_continuous:
        target_mu, target_log_std = self.actor(BATCH.obs_, begin_mask=BATCH.begin_mask)  # [T, B, A]
        dist = td.Independent(td.Normal(target_mu, target_log_std.exp()), 1)
        target_pi = dist.sample()  # [T, B, A]
        target_pi, target_log_pi = squash_action(
            target_pi, dist.log_prob(target_pi).unsqueeze(-1), is_independent=False)  # [T, B, A]
        target_log_pi = tsallis_entropy_log_q(target_log_pi, self.entropic_index)  # [T, B, 1]
    else:
        target_logits = self.actor(BATCH.obs_, begin_mask=BATCH.begin_mask)  # [T, B, A]
        target_cate_dist = td.Categorical(logits=target_logits)
        target_pi = target_cate_dist.sample()  # [T, B]
        target_log_pi = target_cate_dist.log_prob(target_pi).unsqueeze(-1)  # [T, B, 1]
        target_pi = F.one_hot(target_pi, self.a_dim).float()  # [T, B, A]
    q1 = self.critic(BATCH.obs, BATCH.action, begin_mask=BATCH.begin_mask)  # [T, B, 1]
    q2 = self.critic2(BATCH.obs, BATCH.action, begin_mask=BATCH.begin_mask)  # [T, B, 1]
    q1_target = self.critic.t(BATCH.obs_, target_pi, begin_mask=BATCH.begin_mask)  # [T, B, 1]
    q2_target = self.critic2.t(BATCH.obs_, target_pi, begin_mask=BATCH.begin_mask)  # [T, B, 1]
    q_target = th.minimum(q1_target, q2_target)  # [T, B, 1]
    dc_r = n_step_return(BATCH.reward,
                         self.gamma,
                         BATCH.done,
                         (q_target - self.alpha * target_log_pi),
                         BATCH.begin_mask).detach()  # [T, B, 1]
    td_error1 = q1 - dc_r  # [T, B, 1]
    td_error2 = q2 - dc_r  # [T, B, 1]
    q1_loss = (td_error1.square() * BATCH.get('isw', 1.0)).mean()  # 1
    q2_loss = (td_error2.square() * BATCH.get('isw', 1.0)).mean()  # 1
    critic_loss = 0.5 * q1_loss + 0.5 * q2_loss
    self.critic_oplr.optimize(critic_loss)

    if self.is_continuous:
        mu, log_std = self.actor(BATCH.obs, begin_mask=BATCH.begin_mask)  # [T, B, A]
        dist = td.Independent(td.Normal(mu, log_std.exp()), 1)
        pi = dist.rsample()  # [T, B, A]
        pi, log_pi = squash_action(pi, dist.log_prob(pi).unsqueeze(-1), is_independent=False)  # [T, B, A]
        log_pi = tsallis_entropy_log_q(log_pi, self.entropic_index)  # [T, B, 1]
        entropy = dist.entropy().mean()  # 1
    else:
        logits = self.actor(BATCH.obs, begin_mask=BATCH.begin_mask)  # [T, B, A]
        logp_all = logits.log_softmax(-1)  # [T, B, A]
        gumbel_noise = td.Gumbel(0, 1).sample(logp_all.shape)  # [T, B, A]
        _pi = ((logp_all + gumbel_noise) / self.discrete_tau).softmax(-1)  # [T, B, A]
        _pi_true_one_hot = F.one_hot(_pi.argmax(-1), self.a_dim).float()  # [T, B, A]
        _pi_diff = (_pi_true_one_hot - _pi).detach()  # [T, B, A]
        pi = _pi_diff + _pi  # [T, B, A]
        log_pi = (logp_all * pi).sum(-1, keepdim=True)  # [T, B, 1]
        entropy = -(logp_all.exp() * logp_all).sum(-1).mean()  # 1
    q_s_pi = th.minimum(self.critic(BATCH.obs, pi, begin_mask=BATCH.begin_mask),
                        self.critic2(BATCH.obs, pi, begin_mask=BATCH.begin_mask))  # [T, B, 1]
    actor_loss = -(q_s_pi - self.alpha * log_pi).mean()  # 1
    self.actor_oplr.optimize(actor_loss)

    summaries = {
        'LEARNING_RATE/actor_lr': self.actor_oplr.lr,
        'LEARNING_RATE/critic_lr': self.critic_oplr.lr,
        'LOSS/actor_loss': actor_loss,
        'LOSS/q1_loss': q1_loss,
        'LOSS/q2_loss': q2_loss,
        'LOSS/critic_loss': critic_loss,
        'Statistics/log_alpha': self.log_alpha,
        'Statistics/alpha': self.alpha,
        'Statistics/entropy': entropy,
        'Statistics/q_min': th.minimum(q1, q2).min(),
        'Statistics/q_mean': th.minimum(q1, q2).mean(),
        'Statistics/q_max': th.maximum(q1, q2).max()
    }
    if self.auto_adaption:
        alpha_loss = -(self.alpha * (log_pi + self.target_entropy).detach()).mean()  # 1
        self.alpha_oplr.optimize(alpha_loss)
        summaries.update({
            'LOSS/alpha_loss': alpha_loss,
            'LEARNING_RATE/alpha_lr': self.alpha_oplr.lr
        })
    return (td_error1 + td_error2) / 2, summaries
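# Hedged sketch of the q-logarithm that `tsallis_entropy_log_q` is presumably built on
# (the Tsallis generalisation of the natural log, recovering log(x) as q -> 1). The actual
# helper in this code base takes a log-probability and an entropic index and may differ;
# the function name `log_q` below is illustrative only.
def log_q(x, q):
    if q == 1.0:
        return x.log()
    return (x.pow(1.0 - q) - 1.0) / (1.0 - q)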
def _train(self, BATCH):
    for _ in range(self.delay_num):
        if self.is_continuous:
            action_target = self.target_noised_action(
                self.actor.t(BATCH.obs_, begin_mask=BATCH.begin_mask))  # [T, B, A]
        else:
            target_logits = self.actor.t(BATCH.obs_, begin_mask=BATCH.begin_mask)  # [T, B, A]
            target_cate_dist = td.Categorical(logits=target_logits)
            target_pi = target_cate_dist.sample()  # [T, B]
            action_target = F.one_hot(target_pi, self.a_dim).float()  # [T, B, A]
        q1 = self.critic(BATCH.obs, BATCH.action, begin_mask=BATCH.begin_mask)  # [T, B, 1]
        q2 = self.critic2(BATCH.obs, BATCH.action, begin_mask=BATCH.begin_mask)  # [T, B, 1]
        q_target = th.minimum(
            self.critic.t(BATCH.obs_, action_target, begin_mask=BATCH.begin_mask),
            self.critic2.t(BATCH.obs_, action_target, begin_mask=BATCH.begin_mask))  # [T, B, 1]
        dc_r = n_step_return(BATCH.reward,
                             self.gamma,
                             BATCH.done,
                             q_target,
                             BATCH.begin_mask).detach()  # [T, B, 1]
        td_error1 = q1 - dc_r  # [T, B, 1]
        td_error2 = q2 - dc_r  # [T, B, 1]
        q1_loss = (td_error1.square() * BATCH.get('isw', 1.0)).mean()  # 1
        q2_loss = (td_error2.square() * BATCH.get('isw', 1.0)).mean()  # 1
        critic_loss = 0.5 * (q1_loss + q2_loss)
        self.critic_oplr.optimize(critic_loss)

    if self.is_continuous:
        mu = self.actor(BATCH.obs, begin_mask=BATCH.begin_mask)  # [T, B, A]
    else:
        logits = self.actor(BATCH.obs, begin_mask=BATCH.begin_mask)  # [T, B, A]
        logp_all = logits.log_softmax(-1)  # [T, B, A]
        gumbel_noise = td.Gumbel(0, 1).sample(logp_all.shape)  # [T, B, A]
        _pi = ((logp_all + gumbel_noise) / self.discrete_tau).softmax(-1)  # [T, B, A]
        _pi_true_one_hot = F.one_hot(_pi.argmax(-1), self.a_dim).float()  # [T, B, A]
        _pi_diff = (_pi_true_one_hot - _pi).detach()  # [T, B, A]
        mu = _pi_diff + _pi  # [T, B, A]
    q1_actor = self.critic(BATCH.obs, mu, begin_mask=BATCH.begin_mask)  # [T, B, 1]
    actor_loss = -q1_actor.mean()  # 1
    self.actor_oplr.optimize(actor_loss)

    return (td_error1 + td_error2) / 2, {
        'LEARNING_RATE/actor_lr': self.actor_oplr.lr,
        'LEARNING_RATE/critic_lr': self.critic_oplr.lr,
        'LOSS/actor_loss': actor_loss,
        'LOSS/critic_loss': critic_loss,
        'Statistics/q_min': th.minimum(q1, q2).min(),
        'Statistics/q_mean': th.minimum(q1, q2).mean(),
        'Statistics/q_max': th.maximum(q1, q2).max()
    }
def _train_continuous(self, BATCH):
    v = self.v_net(BATCH.obs, begin_mask=BATCH.begin_mask)  # [T, B, 1]
    v_target = self.v_net.t(BATCH.obs_, begin_mask=BATCH.begin_mask)  # [T, B, 1]
    if self.is_continuous:
        mu, log_std = self.actor(BATCH.obs, begin_mask=BATCH.begin_mask)  # [T, B, A]
        dist = td.Independent(td.Normal(mu, log_std.exp()), 1)
        pi = dist.rsample()  # [T, B, A]
        pi, log_pi = squash_action(pi, dist.log_prob(pi).unsqueeze(-1))  # [T, B, A], [T, B, 1]
    else:
        logits = self.actor(BATCH.obs, begin_mask=BATCH.begin_mask)  # [T, B, A]
        logp_all = logits.log_softmax(-1)  # [T, B, A]
        gumbel_noise = td.Gumbel(0, 1).sample(logp_all.shape)  # [T, B, A]
        _pi = ((logp_all + gumbel_noise) / self.discrete_tau).softmax(-1)  # [T, B, A]
        _pi_true_one_hot = F.one_hot(_pi.argmax(-1), self.a_dim).float()  # [T, B, A]
        _pi_diff = (_pi_true_one_hot - _pi).detach()  # [T, B, A]
        pi = _pi_diff + _pi  # [T, B, A]
        log_pi = (logp_all * pi).sum(-1, keepdim=True)  # [T, B, 1]
    q1 = self.q_net(BATCH.obs, BATCH.action, begin_mask=BATCH.begin_mask)  # [T, B, 1]
    q2 = self.q_net2(BATCH.obs, BATCH.action, begin_mask=BATCH.begin_mask)  # [T, B, 1]
    q1_pi = self.q_net(BATCH.obs, pi, begin_mask=BATCH.begin_mask)  # [T, B, 1]
    q2_pi = self.q_net2(BATCH.obs, pi, begin_mask=BATCH.begin_mask)  # [T, B, 1]
    dc_r = n_step_return(BATCH.reward,
                         self.gamma,
                         BATCH.done,
                         v_target,
                         BATCH.begin_mask).detach()  # [T, B, 1]
    v_from_q_stop = (th.minimum(q1_pi, q2_pi) - self.alpha * log_pi).detach()  # [T, B, 1]
    td_v = v - v_from_q_stop  # [T, B, 1]
    td_error1 = q1 - dc_r  # [T, B, 1]
    td_error2 = q2 - dc_r  # [T, B, 1]
    q1_loss = (td_error1.square() * BATCH.get('isw', 1.0)).mean()  # 1
    q2_loss = (td_error2.square() * BATCH.get('isw', 1.0)).mean()  # 1
    v_loss_stop = (td_v.square() * BATCH.get('isw', 1.0)).mean()  # 1
    critic_loss = 0.5 * q1_loss + 0.5 * q2_loss + 0.5 * v_loss_stop
    self.critic_oplr.optimize(critic_loss)

    if self.is_continuous:
        mu, log_std = self.actor(BATCH.obs, begin_mask=BATCH.begin_mask)  # [T, B, A]
        dist = td.Independent(td.Normal(mu, log_std.exp()), 1)
        pi = dist.rsample()  # [T, B, A]
        pi, log_pi = squash_action(pi, dist.log_prob(pi).unsqueeze(-1))  # [T, B, A], [T, B, 1]
        entropy = dist.entropy().mean()  # 1
    else:
        logits = self.actor(BATCH.obs, begin_mask=BATCH.begin_mask)  # [T, B, A]
        logp_all = logits.log_softmax(-1)  # [T, B, A]
        gumbel_noise = td.Gumbel(0, 1).sample(logp_all.shape)  # [T, B, A]
        _pi = ((logp_all + gumbel_noise) / self.discrete_tau).softmax(-1)  # [T, B, A]
        _pi_true_one_hot = F.one_hot(_pi.argmax(-1), self.a_dim).float()  # [T, B, A]
        _pi_diff = (_pi_true_one_hot - _pi).detach()  # [T, B, A]
        pi = _pi_diff + _pi  # [T, B, A]
        log_pi = (logp_all * pi).sum(-1, keepdim=True)  # [T, B, 1]
        entropy = -(logp_all.exp() * logp_all).sum(-1).mean()  # 1
    q1_pi = self.q_net(BATCH.obs, pi, begin_mask=BATCH.begin_mask)  # [T, B, 1]
    actor_loss = -(q1_pi - self.alpha * log_pi).mean()  # 1
    self.actor_oplr.optimize(actor_loss)

    summaries = {
        'LEARNING_RATE/actor_lr': self.actor_oplr.lr,
        'LEARNING_RATE/critic_lr': self.critic_oplr.lr,
        'LOSS/actor_loss': actor_loss,
        'LOSS/q1_loss': q1_loss,
        'LOSS/q2_loss': q2_loss,
        'LOSS/v_loss': v_loss_stop,
        'LOSS/critic_loss': critic_loss,
        'Statistics/log_alpha': self.log_alpha,
        'Statistics/alpha': self.alpha,
        'Statistics/entropy': entropy,
        'Statistics/q_min': th.minimum(q1, q2).min(),
        'Statistics/q_mean': th.minimum(q1, q2).mean(),
        'Statistics/q_max': th.maximum(q1, q2).max(),
        'Statistics/v_mean': v.mean()
    }
    if self.auto_adaption:
        alpha_loss = -(self.alpha * (log_pi.detach() + self.target_entropy)).mean()
        self.alpha_oplr.optimize(alpha_loss)
        summaries.update({
            'LOSS/alpha_loss': alpha_loss,
            'LEARNING_RATE/alpha_lr': self.alpha_oplr.lr
        })
    return (td_error1 + td_error2) / 2, summaries
def __init__(self, u=0, b=1, t=.1, dim=-1):
    self.gumbel = distributions.Gumbel(loc=u, scale=b)
    self.temperature = t
    self.dim = dim
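# Hedged sketch of a companion sampling method for the class above, showing how the stored
# Gumbel distribution, temperature, and dim would typically be combined in a Gumbel-Softmax
# relaxation. The method name `sample` and its signature are assumptions, not the original code.
def sample(self, logits):
    noise = self.gumbel.sample(logits.shape)                       # Gumbel(u, b) noise
    return ((logits + noise) / self.temperature).softmax(self.dim)  # relaxed one-hot sample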
def _train(self, BATCH_DICT):
    """
    TODO: Annotation
    """
    summaries = defaultdict(dict)
    target_actions = {}
    target_log_pis = 1.
    for aid, mid in zip(self.agent_ids, self.model_ids):
        if self.is_continuouss[aid]:
            target_mu, target_log_std = self.actors[mid](
                BATCH_DICT[aid].obs_,
                begin_mask=BATCH_DICT['global'].begin_mask)  # [T, B, A]
            dist = td.Independent(td.Normal(target_mu, target_log_std.exp()), 1)
            target_pi = dist.sample()  # [T, B, A]
            target_pi, target_log_pi = squash_action(
                target_pi, dist.log_prob(target_pi).unsqueeze(-1))  # [T, B, A], [T, B, 1]
        else:
            target_logits = self.actors[mid](
                BATCH_DICT[aid].obs_,
                begin_mask=BATCH_DICT['global'].begin_mask)  # [T, B, A]
            target_cate_dist = td.Categorical(logits=target_logits)
            target_pi = target_cate_dist.sample()  # [T, B]
            target_log_pi = target_cate_dist.log_prob(target_pi).unsqueeze(-1)  # [T, B, 1]
            target_pi = F.one_hot(target_pi, self.a_dims[aid]).float()  # [T, B, A]
        target_actions[aid] = target_pi
        target_log_pis *= target_log_pi
        target_log_pis += th.finfo().eps
    target_actions = th.cat(list(target_actions.values()), -1)  # [T, B, N*A]

    qs1, qs2, q_targets1, q_targets2 = {}, {}, {}, {}
    for mid in self.model_ids:
        qs1[mid] = self.critics[mid](
            [BATCH_DICT[id].obs for id in self.agent_ids],
            th.cat([BATCH_DICT[id].action for id in self.agent_ids], -1))  # [T, B, 1]
        qs2[mid] = self.critics2[mid](
            [BATCH_DICT[id].obs for id in self.agent_ids],
            th.cat([BATCH_DICT[id].action for id in self.agent_ids], -1))  # [T, B, 1]
        q_targets1[mid] = self.critics[mid].t(
            [BATCH_DICT[id].obs_ for id in self.agent_ids],
            target_actions)  # [T, B, 1]
        q_targets2[mid] = self.critics2[mid].t(
            [BATCH_DICT[id].obs_ for id in self.agent_ids],
            target_actions)  # [T, B, 1]

    q_loss = {}
    td_errors = 0.
    for aid, mid in zip(self.agent_ids, self.model_ids):
        q_target = th.minimum(q_targets1[mid], q_targets2[mid])  # [T, B, 1]
        dc_r = n_step_return(BATCH_DICT[aid].reward,
                             self.gamma,
                             BATCH_DICT[aid].done,
                             q_target - self.alpha * target_log_pis,
                             BATCH_DICT['global'].begin_mask).detach()  # [T, B, 1]
        td_error1 = qs1[mid] - dc_r  # [T, B, 1]
        td_error2 = qs2[mid] - dc_r  # [T, B, 1]
        td_errors += (td_error1 + td_error2) / 2
        q1_loss = td_error1.square().mean()  # 1
        q2_loss = td_error2.square().mean()  # 1
        q_loss[aid] = 0.5 * q1_loss + 0.5 * q2_loss
        summaries[aid].update({
            'Statistics/q_min': qs1[mid].min(),
            'Statistics/q_mean': qs1[mid].mean(),
            'Statistics/q_max': qs1[mid].max()
        })
    self.critic_oplr.optimize(sum(q_loss.values()))

    log_pi_actions = {}
    log_pis = {}
    sample_pis = {}
    for aid, mid in zip(self.agent_ids, self.model_ids):
        if self.is_continuouss[aid]:
            mu, log_std = self.actors[mid](
                BATCH_DICT[aid].obs,
                begin_mask=BATCH_DICT['global'].begin_mask)  # [T, B, A]
            dist = td.Independent(td.Normal(mu, log_std.exp()), 1)
            pi = dist.rsample()  # [T, B, A]
            pi, log_pi = squash_action(pi, dist.log_prob(pi).unsqueeze(-1))  # [T, B, A], [T, B, 1]
            pi_action = BATCH_DICT[aid].action.arctanh()
            _, log_pi_action = squash_action(
                pi_action, dist.log_prob(pi_action).unsqueeze(-1))  # [T, B, A], [T, B, 1]
        else:
            logits = self.actors[mid](
                BATCH_DICT[aid].obs,
                begin_mask=BATCH_DICT['global'].begin_mask)  # [T, B, A]
            logp_all = logits.log_softmax(-1)  # [T, B, A]
            gumbel_noise = td.Gumbel(0, 1).sample(logp_all.shape)  # [T, B, A]
            _pi = ((logp_all + gumbel_noise) / self.discrete_tau).softmax(-1)  # [T, B, A]
            _pi_true_one_hot = F.one_hot(_pi.argmax(-1), self.a_dims[aid]).float()  # [T, B, A]
            _pi_diff = (_pi_true_one_hot - _pi).detach()  # [T, B, A]
            pi = _pi_diff + _pi  # [T, B, A]
            log_pi = (logp_all * pi).sum(-1, keepdim=True)  # [T, B, 1]
            log_pi_action = (logp_all * BATCH_DICT[aid].action).sum(-1, keepdim=True)  # [T, B, 1]
        log_pi_actions[aid] = log_pi_action
        log_pis[aid] = log_pi
        sample_pis[aid] = pi

    actor_loss = {}
    for aid, mid in zip(self.agent_ids, self.model_ids):
        all_actions = {id: BATCH_DICT[id].action for id in self.agent_ids}
        all_actions[aid] = sample_pis[aid]
        all_log_pis = {id: log_pi_actions[id] for id in self.agent_ids}
        all_log_pis[aid] = log_pis[aid]
        q_s_pi = th.minimum(
            self.critics[mid](
                [BATCH_DICT[id].obs for id in self.agent_ids],
                th.cat(list(all_actions.values()), -1),
                begin_mask=BATCH_DICT['global'].begin_mask),
            self.critics2[mid](
                [BATCH_DICT[id].obs for id in self.agent_ids],
                th.cat(list(all_actions.values()), -1),
                begin_mask=BATCH_DICT['global'].begin_mask))  # [T, B, 1]
        _log_pis = 1.
        for _log_pi in all_log_pis.values():
            _log_pis *= _log_pi
            _log_pis += th.finfo().eps
        actor_loss[aid] = -(q_s_pi - self.alpha * _log_pis).mean()  # 1
    self.actor_oplr.optimize(sum(actor_loss.values()))

    for aid in self.agent_ids:
        summaries[aid].update({
            'LOSS/actor_loss': actor_loss[aid],
            'LOSS/critic_loss': q_loss[aid]
        })
    summaries['model'].update({
        'LOSS/actor_loss': sum(actor_loss.values()),
        'LOSS/critic_loss': sum(q_loss.values())
    })
    if self.auto_adaption:
        _log_pis = 1.
        for _log_pi in log_pis.values():
            _log_pis *= _log_pi
            _log_pis += th.finfo().eps
        alpha_loss = -(self.alpha * (_log_pis + self.target_entropy).detach()).mean()  # 1
        self.alpha_oplr.optimize(alpha_loss)
        summaries['model'].update({
            'LOSS/alpha_loss': alpha_loss,
            'LEARNING_RATE/alpha_lr': self.alpha_oplr.lr
        })
    return td_errors / self.n_agents_percopy, summaries
def _train(self, BATCH):
    obs = get_first_vector(BATCH.obs)  # [T, B, S]
    obs_ = get_first_vector(BATCH.obs_)  # [T, B, S]
    _timestep = obs.shape[0]
    _batchsize = obs.shape[1]

    # train the world model (dynamics, reward, and done predictors)
    predicted_obs_ = self._forward_dynamic_model(obs, BATCH.action)  # [T, B, S]
    predicted_reward = self._reward_model(obs, BATCH.action)  # [T, B, 1]
    predicted_done_dist = self._done_model(obs, BATCH.action)  # [T, B, 1]
    _obs_loss = F.mse_loss(obs_, predicted_obs_)  # todo
    _reward_loss = F.mse_loss(BATCH.reward, predicted_reward)
    _done_loss = -predicted_done_dist.log_prob(BATCH.done).mean()
    wm_loss = _obs_loss + _reward_loss + _done_loss
    self._wm_oplr.optimize(wm_loss)

    # flatten time and batch dimensions before rolling out the learned model
    obs = th.reshape(obs, (_timestep * _batchsize, -1))  # [T*B, S]
    obs_ = th.reshape(obs_, (_timestep * _batchsize, -1))  # [T*B, S]
    actions = th.reshape(BATCH.action, (_timestep * _batchsize, -1))  # [T*B, A]
    rewards = th.reshape(BATCH.reward, (_timestep * _batchsize, -1))  # [T*B, 1]
    dones = th.reshape(BATCH.done, (_timestep * _batchsize, -1))  # [T*B, 1]

    rollout_rewards = [rewards]
    rollout_dones = [dones]
    r_obs_ = obs_
    _r_obs = deepcopy(BATCH.obs_)
    r_done = (1. - dones)
    for _ in range(self._roll_out_horizon):
        r_obs = r_obs_
        _r_obs.vector.vector_0 = r_obs
        if self.is_continuous:
            r_action = self.actor.t(_r_obs)  # [T*B, A]
            if self.use_target_action_noise:
                r_action = self.target_noised_action(r_action)  # [T*B, A]
        else:
            target_logits = self.actor.t(_r_obs)  # [T*B, A]
            target_cate_dist = td.Categorical(logits=target_logits)
            target_pi = target_cate_dist.sample()  # [T*B,]
            r_action = F.one_hot(target_pi, self.a_dim).float()  # [T*B, A]
        r_obs_ = self._forward_dynamic_model(r_obs, r_action)  # [T*B, S]
        r_reward = self._reward_model(r_obs, r_action)  # [T*B, 1]
        r_done = r_done * (1. - self._done_model(r_obs, r_action).sample())  # [T*B, 1]
        rollout_rewards.append(r_reward)  # [H+1, T*B, 1]
        rollout_dones.append(r_done)  # [H+1, T*B, 1]

    # train critic on the model-expanded return
    _r_obs.vector.vector_0 = obs
    q = self.critic(_r_obs, actions)  # [T*B, 1]
    _r_obs.vector.vector_0 = r_obs_
    q_target = self.critic.t(_r_obs, r_action)  # [T*B, 1]
    dc_r = rewards
    for t in range(1, self._roll_out_horizon):
        dc_r += (self.gamma ** t) * (rollout_rewards[t] * rollout_dones[t])
    dc_r += (self.gamma ** self._roll_out_horizon) * rollout_dones[self._roll_out_horizon] * q_target  # [T*B, 1]
    td_error = dc_r - q  # [T*B, 1]
    q_loss = td_error.square().mean()  # 1
    self.critic_oplr.optimize(q_loss)

    # train actor
    if self.is_continuous:
        mu = self.actor(BATCH.obs, begin_mask=BATCH.begin_mask)  # [T, B, A]
    else:
        logits = self.actor(BATCH.obs, begin_mask=BATCH.begin_mask)  # [T, B, A]
        logp_all = logits.log_softmax(-1)  # [T, B, A]
        gumbel_noise = td.Gumbel(0, 1).sample(logp_all.shape)  # [T, B, A]
        _pi = ((logp_all + gumbel_noise) / self.discrete_tau).softmax(-1)  # [T, B, A]
        _pi_true_one_hot = F.one_hot(_pi.argmax(-1), self.a_dim).float()  # [T, B, A]
        _pi_diff = (_pi_true_one_hot - _pi).detach()  # [T, B, A]
        mu = _pi_diff + _pi  # [T, B, A]
    q_actor = self.critic(BATCH.obs, mu, begin_mask=BATCH.begin_mask)  # [T, B, 1]
    actor_loss = -q_actor.mean()  # 1
    self.actor_oplr.optimize(actor_loss)

    return th.ones_like(BATCH.reward), {
        'LEARNING_RATE/wm_lr': self._wm_oplr.lr,
        'LEARNING_RATE/actor_lr': self.actor_oplr.lr,
        'LEARNING_RATE/critic_lr': self.critic_oplr.lr,
        'LOSS/wm_loss': wm_loss,
        'LOSS/actor_loss': actor_loss,
        'LOSS/critic_loss': q_loss,
        'Statistics/q_min': q.min(),
        'Statistics/q_mean': q.mean(),
        'Statistics/q_max': q.max()
    }