def update_moa(self, sample, agent_i, parallel=False, grad_norm=0.5):
    """ Update parameters of moa networks based on the latest sample from the replay buffer
    Arguments:
        sample: [(B,D)]*N; obs, next_obs, action can be [dict (B,D)]*N
        agent_i (int): index of agent to update
        parallel (bool): If true, will average gradients across threads
    """
    # [(B,1,D)]*N or [dict (B,1,D)]*N
    interm_sample = self.add_virtual_dim(sample)
    # place the current agent's subsample first in the sample batch
    obs, acs, rews, next_obs, dones = [
        switch_list(s, agent_i) for s in interm_sample
    ]
    bs, ts, _ = obs[0].shape
    curr_agent = self.agents[agent_i]
    curr_agent.init_moa_hidden(bs)  # use pre-defined init hiddens
    results = {}

    # perform update on each moa agent
    for agent_j in range(1, self.nagents):
        # current agent's j-th moa
        pi_j = curr_agent.moa_policies[agent_j]
        curr_agent.moa_optimizers[agent_j].zero_grad()
        log_prob_j, entropy_j = curr_agent.evaluate_moa_action(
            agent_j, self.wrap_action(acs[agent_j]), obs[agent_j]
        )  # (B,T,1)
        log_prob_loss = -log_prob_j.reshape(bs * ts, -1).mean()
        entropy_loss = -entropy_j.reshape(bs * ts, -1).mean()
        moa_loss_j = log_prob_loss + self.moa_entropy_coeff * entropy_loss
        moa_loss_j.backward()
        if parallel:
            average_gradients(pi_j)
        if grad_norm > 0:
            torch.nn.utils.clip_grad_norm_(pi_j.parameters(), grad_norm)
        curr_agent.moa_optimizers[agent_j].step()

        # logging (can be verbose)
        for k, v in zip(
            ["log_prob_loss", "entropy_loss"],
            [log_prob_loss, entropy_loss]
        ):
            key = "agent_{}/moa_{}/{}".format(agent_i, agent_j, k)
            value = v.data.cpu().numpy()
            results[key] = value
            self.agent_losses[key].append(value)
    return results
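
# `switch_list` is called above but not defined in this excerpt. A minimal
# sketch, assuming it rotates the current agent's entry to index 0 while
# preserving the order of the others -- an ordering consistent with the
# `switch_idx` mapping (idx if idx > curr_agent_idx else idx + 1) used in the
# MADDPG update further below:
def switch_list(per_agent, agent_i):
    """Hypothetical helper: move agent_i's entry to the front of a
    per-agent list, keeping the remaining agents in original order."""
    return [per_agent[agent_i]] + per_agent[:agent_i] + per_agent[agent_i + 1:]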
def update(self, sample, agent_i, parallel=False, logger=None):
    """ Update parameters of agent model based on a sample from the replay buffer
    Inputs:
        sample: tuple of (observations, actions, rewards, next observations,
                and episode end masks) sampled randomly from the replay
                buffer. Each is a list with entries corresponding to each agent
        agent_i (int): index of agent to update
        parallel (bool): If true, will average gradients across threads
        logger (SummaryWriter from Tensorboard-Pytorch):
            If passed in, important quantities will be logged
    """
    obs, acs, rews, next_obs, dones = sample
    curr_agent = self.agents[agent_i]

    curr_agent.critic_optimizer.zero_grad()
    if self.alg_types[agent_i] == 'MADDPG':
        if self.discrete_action:  # one-hot encode action
            all_trgt_acs = [
                onehot_from_logits(pi(nobs))
                for pi, nobs in zip(self.target_policies, next_obs)
            ]
        else:
            all_trgt_acs = [
                pi(nobs) for pi, nobs in zip(self.target_policies, next_obs)
            ]
        trgt_vf_in = torch.cat((*next_obs, *all_trgt_acs), dim=1)
    else:  # DDPG
        if self.discrete_action:
            trgt_vf_in = torch.cat(
                (next_obs[agent_i],
                 onehot_from_logits(
                     curr_agent.target_policy(next_obs[agent_i]))),
                dim=1)
        else:
            trgt_vf_in = torch.cat(
                (next_obs[agent_i],
                 curr_agent.target_policy(next_obs[agent_i])),
                dim=1)
    target_value = (rews[agent_i].view(-1, 1) +
                    self.gamma * curr_agent.target_critic(trgt_vf_in) *
                    (1 - dones[agent_i].view(-1, 1)))

    if self.alg_types[agent_i] == 'MADDPG':
        vf_in = torch.cat((*obs, *acs), dim=1)
    else:  # DDPG
        vf_in = torch.cat((obs[agent_i], acs[agent_i]), dim=1)
    actual_value = curr_agent.critic(vf_in)
    vf_loss = MSELoss(actual_value, target_value.detach())
    vf_loss.backward()
    if parallel:
        average_gradients(curr_agent.critic)
    torch.nn.utils.clip_grad_norm_(curr_agent.critic.parameters(), 0.5)
    curr_agent.critic_optimizer.step()

    curr_agent.policy_optimizer.zero_grad()

    if self.discrete_action:
        # Forward pass as if onehot (hard=True) but backprop through a
        # differentiable Gumbel-Softmax sample. The MADDPG paper uses the
        # Gumbel-Softmax trick to backprop through discrete categorical
        # samples, but I'm not sure if that is correct since it removes the
        # assumption of a deterministic policy for DDPG. Regardless, discrete
        # policies don't seem to learn properly without it.
        curr_pol_out = curr_agent.policy(obs[agent_i])
        curr_pol_vf_in = gumbel_softmax(curr_pol_out, hard=True)
    else:
        curr_pol_out = curr_agent.policy(obs[agent_i])
        curr_pol_vf_in = curr_pol_out
    if self.alg_types[agent_i] == 'MADDPG':
        all_pol_acs = []
        for i, pi, ob in zip(range(self.nagents), self.policies, obs):
            if i == agent_i:
                all_pol_acs.append(curr_pol_vf_in)
            elif self.discrete_action:
                all_pol_acs.append(onehot_from_logits(pi(ob)))
            else:
                all_pol_acs.append(pi(ob))
        vf_in = torch.cat((*obs, *all_pol_acs), dim=1)
    else:  # DDPG
        vf_in = torch.cat((obs[agent_i], curr_pol_vf_in), dim=1)
    pol_loss = -curr_agent.critic(vf_in).mean()
    pol_loss += (curr_pol_out ** 2).mean() * 1e-3
    pol_loss.backward()
    if parallel:
        average_gradients(curr_agent.policy)
    torch.nn.utils.clip_grad_norm_(curr_agent.policy.parameters(), 0.5)
    curr_agent.policy_optimizer.step()

    if logger is not None:
        logger.add_scalars('agent%i/losses' % agent_i,
                           {'vf_loss': vf_loss, 'pol_loss': pol_loss},
                           self.niter)
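
# The straight-through Gumbel-Softmax trick discussed in the comment above:
# a minimal sketch of the hard=True behavior the policy update relies on.
# The repo's own gumbel_softmax may differ in detail; this version builds on
# torch.nn.functional.gumbel_softmax, which is part of standard PyTorch.
import torch
import torch.nn.functional as F

def gumbel_softmax_st(logits, temperature=1.0):
    """Hard one-hot in the forward pass, soft Gumbel-Softmax gradient in
    the backward pass (straight-through estimator)."""
    y_soft = F.gumbel_softmax(logits, tau=temperature, hard=False)
    index = y_soft.argmax(dim=-1, keepdim=True)
    y_hard = torch.zeros_like(y_soft).scatter_(-1, index, 1.0)
    # forward: y_hard; backward: gradient flows through y_soft
    return y_hard - y_soft.detach() + y_soft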
def update(self, sample, agent_i, parallel=False, logger=None):
    """ Update parameters of agent model based on a sample from the replay buffer
    Inputs:
        sample: tuple of (observations, actions, rewards, next observations,
                and episode end masks) sampled randomly from the replay
                buffer. Each is a list with entries corresponding to each agent
        agent_i (int): index of agent to update
        parallel (bool): If true, will average gradients across threads
        logger (SummaryWriter from Tensorboard-Pytorch):
            If passed in, important quantities will be logged
    """
    obs, acs, rews, next_obs, dones = sample
    curr_agent = self.agents[agent_i]

    curr_agent.critic_optimizer.zero_grad()
    if self.alg_types[agent_i] == 'MADDPG':
        if self.discrete_action:  # one-hot encode action
            # a' = mu'(o'): uses all agents' target policies
            all_trgt_acs = [
                onehot_from_logits(pi(nobs))
                for pi, nobs in zip(self.target_policies, next_obs)
            ]
        else:
            all_trgt_acs = [
                pi(nobs) for pi, nobs in zip(self.target_policies, next_obs)
            ]

        # ========================== Adding noise ====================
        if self.noisy_sharing:
            # noisy_all_trgt_acs = self.noisy_sharing_discrete(all_trgt_acs, agent_i)
            # all_trgt_acs = noisy_all_trgt_acs
            noisy_acs = self.noisy_sharing_discrete(acs, agent_i)
            acs = noisy_acs
            # print(self.noisy_SNR)
        # ====================== End of Adding noise ==================

        trgt_vf_in = torch.cat((*next_obs, *all_trgt_acs), dim=1)

        # ========================= Differential Obs ========================
        # ============= Dedicated to simple_speaker_listener ================
        # The est_action is used to replace acs[1]
        if self.game_id == 'simple_speaker_listener' and self.est_ac:
            diff_pos = (next_obs[0] - obs[0])[:, -2:]
            tmp_p = torch.transpose(diff_pos.ge(torch.max(diff_pos) * 0.8), 0, 1)
            tmp_p[0] = tmp_p[0] * 1
            tmp_p[1] = tmp_p[1] * 3
            tmp_n = torch.transpose(diff_pos.le(torch.min(diff_pos) * 0.8), 0, 1)
            tmp_n[0] = tmp_n[0] * 2
            tmp_n[1] = tmp_n[1] * 4
            mask = torch.transpose(tmp_p, 0, 1) + torch.transpose(tmp_n, 0, 1)
            est_action = mask.sum(dim=1)
            est_action = torch.zeros(len(est_action), acs[1].shape[1]).scatter_(
                dim=1, index=est_action.view(-1, 1), value=1)
            acs[1] = est_action
        # ===================== End of differential Obs =====================
    else:  # DDPG
        if self.discrete_action:
            trgt_vf_in = torch.cat(
                (next_obs[agent_i],
                 onehot_from_logits(
                     curr_agent.target_policy(next_obs[agent_i]))),
                dim=1)
        else:
            # a' = mu'(o'): uses only the current agent's target policy
            trgt_vf_in = torch.cat(
                (next_obs[agent_i],
                 curr_agent.target_policy(next_obs[agent_i])),
                dim=1)
    target_value = (rews[agent_i].view(-1, 1) +
                    self.gamma * curr_agent.target_critic(trgt_vf_in) *
                    (1 - dones[agent_i].view(-1, 1)))  # y^j

    if self.alg_types[agent_i] == 'MADDPG':
        vf_in = torch.cat((*obs, *acs), dim=1)
    else:  # DDPG
        vf_in = torch.cat((obs[agent_i], acs[agent_i]), dim=1)
    actual_value = curr_agent.critic(vf_in)
    vf_loss = MSELoss(actual_value, target_value.detach())
    vf_loss.backward()
    if parallel:
        average_gradients(curr_agent.critic)
    torch.nn.utils.clip_grad_norm_(curr_agent.critic.parameters(), 0.5)
    curr_agent.critic_optimizer.step()

    # ================== Policy network training ==================
    curr_agent.policy_optimizer.zero_grad()

    if self.discrete_action:
        # Forward pass as if onehot (hard=True) but backprop through a
        # differentiable Gumbel-Softmax sample. The MADDPG paper uses the
        # Gumbel-Softmax trick to backprop through discrete categorical
        # samples, but I'm not sure if that is correct since it removes the
        # assumption of a deterministic policy for DDPG. Regardless, discrete
        # policies don't seem to learn properly without it.
        curr_pol_out = curr_agent.policy(obs[agent_i])
        curr_pol_vf_in = gumbel_softmax(curr_pol_out, hard=True)
    else:
        curr_pol_out = curr_agent.policy(obs[agent_i])
        curr_pol_vf_in = curr_pol_out
    if self.alg_types[agent_i] == 'MADDPG':
        all_pol_acs = []
        for i, pi, ob in zip(range(self.nagents), self.policies, obs):
            # Is it correct to train mu using all other agents' policies?
            if i == agent_i:
                all_pol_acs.append(curr_pol_vf_in)
            elif self.discrete_action:
                all_pol_acs.append(onehot_from_logits(pi(ob)))
            else:
                all_pol_acs.append(pi(ob))
        vf_in = torch.cat((*obs, *all_pol_acs), dim=1)
    else:  # DDPG
        vf_in = torch.cat((obs[agent_i], curr_pol_vf_in), dim=1)
    pol_loss = -curr_agent.critic(vf_in).mean()
    pol_loss += (curr_pol_out ** 2).mean() * 1e-3
    pol_loss.backward()
    if parallel:
        average_gradients(curr_agent.policy)
    # Constrain the gradient norm.
    torch.nn.utils.clip_grad_norm_(curr_agent.policy.parameters(), 0.5)
    curr_agent.policy_optimizer.step()

    if logger is not None:
        logger.add_scalars('agent%i/losses' % agent_i,
                           {'vf_loss': vf_loss, 'pol_loss': pol_loss},
                           self.niter)
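
# The noisy-sharing call above (self.noisy_sharing_discrete) is not shown in
# this excerpt. A hypothetical sketch, assuming additive white Gaussian noise
# at a configured SNR on the actions shared by other agents; the real method
# presumably uses self.noisy_SNR and may corrupt the one-hots differently.
import torch

def noisy_sharing_discrete(acs, agent_i, snr_db=20.0):
    """Sketch: corrupt other agents' shared actions with AWGN at snr_db."""
    noisy = []
    for j, a in enumerate(acs):
        if j == agent_i:
            noisy.append(a)  # the agent observes its own action exactly
        else:
            signal_power = (a ** 2).mean()
            noise_power = signal_power / (10.0 ** (snr_db / 10.0))
            noisy.append(a + torch.randn_like(a) * noise_power.sqrt())
    return noisy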
def update(self, sample, agent_i, parallel=False, grad_norm=0.5,
           norm_rewards=False):
    """ Update parameters of agent model based on a sample from the replay buffer
    Arguments:
        sample: [(B,D)]*N; obs, next_obs, action can be [dict (B,D)]*N
        agent_i (int): index of agent to update
        parallel (bool): If true, will average gradients across threads
    """
    def switch_idx(idx, curr_agent_idx):
        return idx if idx > curr_agent_idx else idx + 1

    # [(B,1,D)]*N or [dict (B,1,D)]*N
    obs, acs, rews, next_obs, dones = self.add_virtual_dim(sample)
    # preprocess rewards to reduce variance
    if norm_rewards:
        rews = self.normalize_rewards(rews)
    bs, ts, _ = obs[agent_i].shape
    curr_agent = self.agents[agent_i]

    # NOTE: Critic update
    curr_agent.critic_optimizer.zero_grad()

    # compute target actions
    if self.alg_types[agent_i] == 'MADDPG':
        all_trgt_acs = []  # [dict (B,1,A)]*N
        for i, nobs in enumerate(next_obs):  # (B,1,O)
            if self.model_of_agents:
                if i == agent_i:
                    # use current agent target
                    act_i = curr_agent.compute_action(
                        nobs, target=True, requires_grad=False)
                else:
                    # use moa agent target
                    agent_j = switch_idx(i, agent_i)
                    act_i = curr_agent.compute_moa_action(
                        agent_j, nobs, target=True, requires_grad=False,
                        return_logits=True)
            else:
                # use each agent's target directly
                act_i = self.agents[i].compute_action(
                    nobs, target=True, requires_grad=False)
            all_trgt_acs.append(act_i)
        # [(B,1,O)_i, ..., (B,1,A)_i, ...] -> (B,1,O*N+A*N)
        trgt_vf_in = torch.cat([
            *self.flatten_obs(next_obs, ma=True),
            *self.flatten_act(all_trgt_acs, ma=True)
        ], dim=-1)
    else:  # DDPG
        act_i = curr_agent.compute_action(next_obs[agent_i], target=True,
                                          requires_grad=False)
        # (B,1,O) + (B,1,A) -> (B,1,O+A)
        trgt_vf_in = torch.cat([
            self.flatten_obs(next_obs[agent_i]),
            self.flatten_act(act_i)
        ], dim=-1)

    # bellman targets, (B*T,1); with the virtual time dim T=1 this is (B,1)
    target_q = curr_agent.compute_value(trgt_vf_in, target=True)
    target_value = (rews[agent_i].view(-1, 1) + self.gamma * target_q *
                    (1 - dones[agent_i].view(-1, 1)))

    # Q func
    if self.alg_types[agent_i] == 'MADDPG':
        vf_in = torch.cat([
            *self.flatten_obs(obs, ma=True),
            *self.flatten_act(acs, ma=True)
        ], dim=-1)
    else:  # DDPG
        vf_in = torch.cat([
            self.flatten_obs(obs[agent_i]),
            self.flatten_act(acs[agent_i])
        ], dim=-1)
    actual_value = curr_agent.compute_value(vf_in, target=False)  # (B*T,1)

    # bellman errors
    vf_loss = MSELoss(actual_value, target_value.detach())
    vf_loss.backward()
    if parallel:
        average_gradients(curr_agent.critic)
    if grad_norm > 0:
        torch.nn.utils.clip_grad_norm_(curr_agent.critic.parameters(),
                                       grad_norm)
    curr_agent.critic_optimizer.step()

    # NOTE: Policy update
    curr_agent.policy_optimizer.zero_grad()

    # current agent action (deterministic, softened), dict (B,T,A)
    curr_pol_out = curr_agent.compute_action(obs[agent_i], target=False,
                                             requires_grad=True)
    if self.alg_types[agent_i] == 'MADDPG':
        all_pol_acs = []
        for i, pi, ob in zip(range(self.nagents), self.policies, obs):
            if i == agent_i:
                # insert current agent act to q input
                all_pol_acs.append(self.flatten_act(curr_pol_out))
                # all_pol_acs.append(curr_pol_out)
            else:
                # p_act_i = self.agents[i].compute_action(ob, target=False,
                #                                         requires_grad=False)
                p_act_i = self.flatten_act(acs[i])
                all_pol_acs.append(p_act_i)
        # (B,T,O*N+A*N)
        p_vf_in = torch.cat([
            *self.flatten_obs(obs, ma=True),
            *self.flatten_act(all_pol_acs, ma=True)
        ], dim=-1)
    else:  # DDPG
        # (B,T,O+A)
        p_vf_in = torch.cat([
            self.flatten_obs(obs[agent_i]),
            self.flatten_act(curr_pol_out)
        ], dim=-1)

    # value function to update current policy
    p_value = curr_agent.compute_value(p_vf_in, target=False)  # (B*T,1)
    pol_loss = -p_value.mean()

    # p regularization, scale down output (gaussian mean,std or logits)
    # reference: https://github.com/openai/maddpg/blob/master/maddpg/trainer/maddpg.py
    pol_reg_loss = torch.tensor(0.0)
    for k, v in curr_pol_out.items():
        pol_reg_loss += ((v.reshape(bs * ts, -1)) ** 2).mean() * 1e-3

    pol_loss_total = pol_loss + pol_reg_loss
    pol_loss_total.backward()
    if parallel:
        average_gradients(curr_agent.policy)
    if grad_norm > 0:
        torch.nn.utils.clip_grad_norm_(curr_agent.policy.parameters(),
                                       grad_norm)
    curr_agent.policy_optimizer.step()

    # NOTE: collect training stats
    results = {}
    for k, v in zip(
        ["critic_loss", "policy_loss", "policy_reg_loss"],
        [vf_loss, pol_loss, pol_reg_loss]
    ):
        key = "agent_{}/{}".format(agent_i, k)
        value = v.data.cpu().numpy()
        results[key] = value
        self.agent_losses[key].append(value)
    return results
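
# The flatten_obs / flatten_act helpers used throughout are not shown. A
# hypothetical sketch of the assumed behavior: dict-valued entries (the
# policies output dicts here) are concatenated along the last dim so they
# can be fed to the critic; ma=True maps over a per-agent list.
import torch

def flatten_act(act, ma=False):
    """Sketch: flatten a dict action (or a list of them) to tensors."""
    if ma:
        return [flatten_act(a) for a in act]
    if isinstance(act, dict):
        return torch.cat(list(act.values()), dim=-1)
    return act  # already a flat tensor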
def update(self, sample, agent_i, parallel=False, grad_norm=0.5):
    """ Update parameters of agent model based on a sample from the replay buffer
    Arguments:
        sample: [(B,T,D)]*N; obs, next_obs, action, logprobs can be [dict (B,T,D)]*N
        agent_i (int): index of agent to update
        parallel (bool): If true, will average gradients across threads
    """
    obs, acs, rews, next_obs, dones, logprobs = sample  # each [(B,T,D)]*N
    bs, ts, _ = obs[agent_i].shape
    self.init_hidden(bs)  # use pre-defined init hiddens
    curr_agent = self.agents[agent_i]
    # entropy temperature param
    alpha = curr_agent.log_alpha.get_alpha().detach()

    # NOTE: Critic update
    curr_agent.critic_optimizer.zero_grad()

    # compute target actions
    if self.alg_types[agent_i] == 'MASAC':
        all_trgt_acs = []       # [dict (B,T,A)]*N
        all_trgt_logprobs = []  # [dict (B,T,1)]*N
        for i, nobs in enumerate(next_obs):  # (B,T,O)
            with torch.no_grad():
                act_i, log_prob_i, _ = self.agents[i].compute_action_logprob(nobs)
            all_trgt_acs.append(act_i)
            all_trgt_logprobs.append(log_prob_i)
        # [(B,T,O)_i, ..., (B,T,A)_i, ...] -> (B,T,O*N+A*N)
        trgt_vf_in = torch.cat([
            *self.flatten_obs(next_obs, ma=True),
            *self.flatten_act(all_trgt_acs, ma=True)
        ], dim=-1)
        # log prob of target action, [(B,T,1)]*N
        target_a_logprob = self.contract_logprob(all_trgt_logprobs, ma=True)
        # [(B,T,1)]*N -> (B,T,N) -> (B,T)
        target_a_logprob = torch.sum(torch.cat(target_a_logprob, -1), -1)
    else:  # SAC
        with torch.no_grad():
            act_i, log_prob_i, _ = curr_agent.compute_action_logprob(
                next_obs[agent_i])
        # (B,T,O) + (B,T,A) -> (B,T,O+A)
        trgt_vf_in = torch.cat(
            [self.flatten_obs(next_obs[agent_i]), self.flatten_act(act_i)],
            dim=-1)
        # log prob of target action, (B,T,1)
        target_a_logprob = self.contract_logprob(log_prob_i)

    # bellman targets
    target_q1, target_q2 = curr_agent.compute_value(trgt_vf_in, target=True)  # (B*T,1)
    target_a_logprob = target_a_logprob.reshape(-1, 1).detach()  # (B*T,1)
    target_q = torch.min(target_q1, target_q2) - alpha * target_a_logprob
    target_value = (rews[agent_i].view(-1, 1) + self.gamma * target_q *
                    (1.0 - dones[agent_i].view(-1, 1)))  # (B*T,1)

    # Q func
    if self.alg_types[agent_i] == 'MASAC':
        vf_in = torch.cat([
            *self.flatten_obs(obs, ma=True),
            *self.flatten_act(acs, ma=True)
        ], dim=-1)
    else:  # SAC
        vf_in = torch.cat([
            self.flatten_obs(obs[agent_i]),
            self.flatten_act(acs[agent_i])
        ], dim=-1)
    q1, q2 = curr_agent.compute_value(vf_in, target=False)  # (B*T,1)

    # bellman errors
    vf_loss1 = MSELoss(q1, target_value.detach())
    vf_loss1.backward()
    vf_loss2 = MSELoss(q2, target_value.detach())
    vf_loss2.backward()
    if parallel:
        average_gradients(curr_agent.critic1)
        average_gradients(curr_agent.critic2)
    if grad_norm > 0:
        torch.nn.utils.clip_grad_norm_(curr_agent.critic1.parameters(), grad_norm)
        torch.nn.utils.clip_grad_norm_(curr_agent.critic2.parameters(), grad_norm)
    curr_agent.critic1_optimizer.step()
    curr_agent.critic2_optimizer.step()

    # NOTE: Policy update
    curr_agent.policy_optimizer.zero_grad()

    # current agent action (stochastic, reparameterized), dict (B,T,A)
    curr_pol_out, curr_log_prob, _ = curr_agent.compute_action_logprob(
        obs[agent_i])
    if self.alg_types[agent_i] == 'MASAC':
        all_pol_acs = []
        all_pol_logprobs = []
        for i, pi, ob in zip(range(self.nagents), self.policies, obs):
            if i == agent_i:
                # insert current agent act to q input
                all_pol_acs.append(self.flatten_act(curr_pol_out))
                # current agent log prob (backprop-able)
                all_pol_logprobs.append(curr_log_prob)
            else:
                # TODO: need other agents' log probs as well
                # p_act_i = self.agents[i].compute_action(ob, target=False,
                #                                         requires_grad=False)
                p_act_i = self.flatten_act(acs[i])
                all_pol_acs.append(p_act_i)
                # other agents' log probs (recorded during sampling)
                all_pol_logprobs.append(logprobs[i])
        # (B,T,O*N+A*N)
        p_vf_in = torch.cat([
            *self.flatten_obs(obs, ma=True),
            *self.flatten_act(all_pol_acs, ma=True)
        ], dim=-1)
        # [dict (B,T,1)]*N -> [(B,T,1)]*N -> (B,T)
        a_log_prob = self.contract_logprob(all_pol_logprobs, ma=True)
        a_log_prob = torch.sum(torch.cat(a_log_prob, -1), -1)
    else:  # SAC
        # (B,T,O+A)
        p_vf_in = torch.cat([
            self.flatten_obs(obs[agent_i]),
            self.flatten_act(curr_pol_out)
        ], dim=-1)
        # dict (B,T,1) -> (B,T,1)
        a_log_prob = self.contract_logprob(curr_log_prob)

    # soft policy loss: alpha-weighted log prob vs. twin-Q value
    p_value1, p_value2 = curr_agent.compute_value(p_vf_in, target=False)  # (B*T,1)
    p_value_target = torch.min(p_value1, p_value2)
    # reduce to a scalar (the elementwise form cannot be backprop'ed directly)
    pol_loss = (alpha * a_log_prob.reshape(-1, 1) - p_value_target).mean()

    # NOTE: this is optional (not in SAC)
    # p regularization, scale down output (gaussian mean,std or logits)
    # reference: https://github.com/openai/maddpg/blob/master/maddpg/trainer/maddpg.py
    pol_reg_loss = torch.tensor(0.0)
    for k, v in curr_pol_out.items():
        pol_reg_loss += ((v.reshape(bs * ts, -1)) ** 2).mean() * 1e-3

    pol_loss_total = pol_loss + pol_reg_loss
    pol_loss_total.backward()
    if parallel:
        average_gradients(curr_agent.policy)
    if grad_norm > 0:
        torch.nn.utils.clip_grad_norm_(curr_agent.policy.parameters(), grad_norm)
    curr_agent.policy_optimizer.step()

    # NOTE: Alpha (entropy) update
    alpha_loss = -(curr_agent.log_alpha() *
                   (a_log_prob.reshape(-1, 1).detach() + self.target_entropy)).mean()
    alpha_loss.backward()
    if parallel:
        average_gradients(curr_agent.log_alpha)
    if grad_norm > 0:
        torch.nn.utils.clip_grad_norm_(curr_agent.log_alpha.parameters(), grad_norm)
    curr_agent.alpha_optimizer.step()

    # NOTE: collect training stats
    results = {}
    for k, v in zip([
        "critic1_loss", "critic2_loss", "policy_loss", "policy_reg_loss",
        "alpha_loss"
    ], [vf_loss1, vf_loss2, pol_loss, pol_reg_loss, alpha_loss]):
        key = "agent_{}/{}".format(agent_i, k)
        value = v.data.cpu().numpy()
        results[key] = value
        self.agent_losses[key].append(value)
    return results
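
# The learnable temperature is accessed as curr_agent.log_alpha() (log value,
# for the alpha loss) and curr_agent.log_alpha.get_alpha() (alpha itself).
# A minimal sketch of a module with that interface, assuming the standard
# SAC parameterization of a single learnable log-temperature:
import torch
import torch.nn as nn

class LogAlpha(nn.Module):
    """Sketch of a learnable SAC entropy temperature."""
    def __init__(self, init_alpha=1.0):
        super().__init__()
        self.log_alpha = nn.Parameter(torch.tensor(float(init_alpha)).log())

    def forward(self):
        return self.log_alpha       # log-temperature, used in alpha_loss

    def get_alpha(self):
        return self.log_alpha.exp()  # temperature alpha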
def update(self, sample, agent_i, parallel=False, grad_norm=0.5,
           contract_keys=None):
    """ Update parameters of agent model based on a sample from the replay buffer
    Arguments:
        sample: [(B,T,D)]*N; obs, next_obs, action can be [dict (B,T,D)]*N
        agent_i (int): index of agent to update
        parallel (bool): If true, will average gradients across threads
    """
    # each is [(B,T,D)]*N
    obs, acs, rews, next_obs, dones, old_logits, advantages, vf_preds = sample
    bs, ts, _ = obs[agent_i].shape
    self.init_hidden(bs)  # use pre-defined init hiddens
    curr_agent = self.agents[agent_i]

    # NOTE: Critic update
    curr_agent.critic_optimizer.zero_grad()

    # value func
    if self.alg_types[agent_i] == 'CCPPO':
        # [(B,T,O)_i, ...] -> (B,T,O*N)
        vf_in = torch.cat(self.flatten_obs(obs, ma=True), dim=-1)
    else:  # PPO
        vf_in = self.flatten_obs(obs[agent_i])  # (B,T,O)
    actual_value = curr_agent.compute_value(vf_in)  # (B,T,1)

    # bellman errors (PPO clipped style); value targets reconstructed via the
    # GAE identity target = advantage + old value prediction
    value_targets = advantages + vf_preds
    vf_loss1 = (actual_value - value_targets) ** 2
    vf_clipped = vf_preds + (actual_value - vf_preds).clamp(
        -self.vf_clip_param, self.vf_clip_param)
    vf_loss2 = (vf_clipped - value_targets) ** 2
    vf_loss = torch.max(vf_loss1, vf_loss2).mean()
    critic_loss = self.vf_loss_coeff * vf_loss
    critic_loss.backward()
    if parallel:
        average_gradients(curr_agent.critic)
    if grad_norm > 0:
        torch.nn.utils.clip_grad_norm_(curr_agent.critic.parameters(), grad_norm)
    curr_agent.critic_optimizer.step()

    # NOTE: Policy update
    curr_agent.policy_optimizer.zero_grad()

    # ppo policy update
    # NOTE: wrapping is needed because `evaluate_action` takes dicts
    # (since the policy outputs dicts)
    act_eval_out = curr_agent.evaluate_action(
        self.wrap_action(old_logits[agent_i]),
        self.wrap_action(acs[agent_i]),
        obs[agent_i],
        contract_keys=contract_keys
    )  # all (B,T,1)
    curr_log_prob, old_log_prob, entropy, kl = act_eval_out

    logp_ratio = torch.exp(curr_log_prob - old_log_prob)
    policy_loss = -torch.min(
        advantages * logp_ratio,
        advantages * logp_ratio.clamp(1 - self.clip_param, 1 + self.clip_param)
    )  # (B,T,1)
    policy_loss = policy_loss.mean()

    # kl loss on current & previous policy outputs
    kl_loss = kl.mean()
    # update kl coefficient per update (with mean/expected kl)
    curr_agent.kl_coeff.update_kl(kl_loss)
    # entropy loss on current policy outputs
    entropy_loss = entropy.mean()

    actor_loss = policy_loss
    actor_loss += curr_agent.kl_coeff() * kl_loss
    actor_loss += self.entropy_coeff * entropy_loss
    actor_loss.backward()
    if parallel:
        average_gradients(curr_agent.policy)
    if grad_norm > 0:
        torch.nn.utils.clip_grad_norm_(curr_agent.policy.parameters(), grad_norm)
    curr_agent.policy_optimizer.step()

    # NOTE: collect training stats
    results = {}
    key_list = [
        "critic_loss", "policy_loss", "kl_loss", "entropy_loss",
        "explained_variance"
    ]
    val_list = [
        vf_loss, policy_loss, kl_loss, entropy_loss,
        explained_variance(value_targets, actual_value)
    ]
    for k, v in zip(key_list, val_list):
        key = "agent_{}/{}".format(agent_i, k)
        value = v.data.cpu().numpy()
        results[key] = value
        self.agent_losses[key].append(value)
    return results
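
# curr_agent.kl_coeff is called both as a getter and via update_kl(). A
# minimal sketch of an adaptive KL coefficient with that interface, assuming
# the usual PPO recipe (as in RLlib): grow the penalty when the measured KL
# overshoots the target, shrink it when it undershoots.
class AdaptiveKLCoeff:
    """Sketch of an adaptive KL-penalty coefficient."""
    def __init__(self, init_coeff=0.2, kl_target=0.01):
        self.coeff = init_coeff
        self.kl_target = kl_target

    def __call__(self):
        return self.coeff

    def update_kl(self, sampled_kl):
        if sampled_kl > 2.0 * self.kl_target:
            self.coeff *= 1.5
        elif sampled_kl < 0.5 * self.kl_target:
            self.coeff *= 0.5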
def update(self, sample, agent_i, parallel=False, grad_norm=0.5):
    """ Update parameters of agent model based on a sample from the replay buffer
    Inputs:
        sample: EpisodeBatch; use sample[key_i] to get a specific
                array of obs, action, etc. for agent i
        agent_i (int): index of agent to update
        parallel (bool): If true, will average gradients across threads
    """
    obs, acs, rews, next_obs, dones = self.parse_sample(sample)  # [(B,T,D)]*N
    bs, ts, _ = obs[0].shape
    self.init_hidden(bs)  # use pre-defined init hiddens
    curr_agent = self.agents[agent_i]

    # NOTE: critic update
    curr_agent.critic_optimizer.zero_grad()

    # compute target actions
    if self.alg_types[agent_i] == 'MADDPG':
        all_trgt_acs = []  # [(B,T,A)]*N
        for i, (pi, nobs) in enumerate(zip(self.target_policies, next_obs)):
            # nobs: (B,T,O)
            act_i = self.compute_action(i, pi, nobs, bs=bs, ts=ts)
            all_trgt_acs.append(act_i)  # [(B,T,A)]
        if self.discrete_action:  # one-hot encode action
            all_trgt_acs = [
                onehot_from_logits(
                    act_i.reshape(bs * ts, -1)
                ).reshape(bs, ts, -1)
                for act_i in all_trgt_acs
            ]
        # critic input, [(B,T,O)_i, ..., (B,T,A)_i, ...] -> (B,T,O*N+A*N)
        trgt_vf_in = torch.cat((*next_obs, *all_trgt_acs), dim=-1)
    else:  # DDPG
        act_i = self.compute_action(agent_i, curr_agent.target_policy,
                                    next_obs[agent_i], bs=bs, ts=ts)
        if self.discrete_action:
            act_i = onehot_from_logits(
                act_i.reshape(bs * ts, -1)
            ).reshape(bs, ts, -1)
        # (B,T,O) + (B,T,A) -> (B,T,O+A)
        trgt_vf_in = torch.cat((next_obs[agent_i], act_i), dim=-1)

    # bellman targets
    target_q = self.compute_q_val(agent_i, curr_agent.target_critic,
                                  trgt_vf_in, bs=bs, ts=ts)  # (B*T,1)
    target_value = (rews[agent_i].view(-1, 1) + self.gamma * target_q *
                    (1 - dones[agent_i].view(-1, 1)))  # (B*T,1)

    # Q func (concatenate along the feature dim of the (B,T,D) tensors)
    if self.alg_types[agent_i] == 'MADDPG':
        vf_in = torch.cat((*obs, *acs), dim=-1)
    else:  # DDPG
        vf_in = torch.cat((obs[agent_i], acs[agent_i]), dim=-1)
    actual_value = self.compute_q_val(agent_i, curr_agent.critic,
                                      vf_in, bs=bs, ts=ts)  # (B*T,1)

    # bellman errors
    vf_loss = MSELoss(actual_value, target_value.detach())
    vf_loss.backward()
    if parallel:
        average_gradients(curr_agent.critic)
    if grad_norm > 0:
        torch.nn.utils.clip_grad_norm_(curr_agent.critic.parameters(), grad_norm)
    curr_agent.critic_optimizer.step()

    # NOTE: policy update
    curr_agent.policy_optimizer.zero_grad()

    # current agent action (deterministic, softened); note agent_i here,
    # not the leftover loop variable i
    curr_pol_out = self.compute_action(agent_i, curr_agent.policy,
                                       obs[agent_i], bs=bs, ts=ts)  # (B,T,A)
    if self.discrete_action:
        # Forward pass as if onehot (hard=True) but backprop through a
        # differentiable Gumbel-Softmax sample. The MADDPG paper uses the
        # Gumbel-Softmax trick to backprop through discrete categorical
        # samples, but I'm not sure if that is correct since it removes the
        # assumption of a deterministic policy for DDPG. Regardless, discrete
        # policies don't seem to learn properly without it.
        curr_pol_vf_in = gumbel_softmax(
            curr_pol_out.reshape(bs * ts, -1), hard=True
        ).reshape(bs, ts, -1)
    else:
        curr_pol_vf_in = curr_pol_out
    if self.alg_types[agent_i] == 'MADDPG':
        all_pol_acs = []
        for i, pi, ob in zip(range(self.nagents), self.policies, obs):
            if i == agent_i:
                # insert current agent act to q input
                all_pol_acs.append(curr_pol_vf_in)
            else:
                p_act_i = self.compute_action(i, pi, ob, bs=bs, ts=ts)  # (B,T,A)
                if self.discrete_action:
                    p_act_i = onehot_from_logits(
                        p_act_i.reshape(bs * ts, -1)
                    ).reshape(bs, ts, -1)
                all_pol_acs.append(p_act_i)
        p_vf_in = torch.cat((*obs, *all_pol_acs), dim=-1)  # (B,T,O*N+A*N)
    else:  # DDPG
        p_vf_in = torch.cat((obs[agent_i], curr_pol_vf_in), dim=-1)  # (B,T,O+A)

    # value function to update current policy
    p_value = self.compute_q_val(agent_i, curr_agent.critic,
                                 p_vf_in, bs=bs, ts=ts)  # (B*T,1)
    pol_loss = -p_value.mean()
    # p regularization, scale down output (gaussian mean,std or logits)
    # reference: https://github.com/openai/maddpg/blob/master/maddpg/trainer/maddpg.py
    pol_loss += ((curr_pol_out.reshape(bs * ts, -1)) ** 2).mean() * 1e-3
    pol_loss.backward()
    if parallel:
        average_gradients(curr_agent.policy)
    if grad_norm > 0:
        torch.nn.utils.clip_grad_norm_(curr_agent.policy.parameters(), grad_norm)
    curr_agent.policy_optimizer.step()

    # collect training stats
    results = {
        "agent_{}_critic_loss".format(agent_i): vf_loss,
        "agent_{}_policy_loss".format(agent_i): pol_loss
    }
    return results
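
# The recurrent helper self.compute_action(i, pi, obs, bs=bs, ts=ts) is not
# shown. A hypothetical sketch, assuming it unrolls the policy over the T
# timesteps of a (B,T,O) batch using the hidden state prepared by
# init_hidden(bs); the repo's real helper may batch time differently (e.g.
# reshaping to (B*T,O) for feed-forward policies).
def compute_action(self, agent_i, policy, obs, bs, ts):
    """Sketch: step a recurrent policy through time over a (B,T,O) batch."""
    outs = []
    for t in range(ts):
        out = policy(obs[:, t])        # (B,A); policy tracks its hidden state
        outs.append(out.unsqueeze(1))  # (B,1,A)
    return torch.cat(outs, dim=1)      # (B,T,A)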
def update(self, sample, agent_i, parallel=False, logger=None):
    """ Update parameters of agent model based on a sample from the replay buffer
    Inputs:
        sample: tuple of (observations, actions, rewards, next observations,
                and episode end masks) sampled randomly from the replay
                buffer. Each is a list with entries corresponding to each agent
        agent_i (int): index of agent to update
        parallel (bool): If true, will average gradients across threads
        logger (SummaryWriter from Tensorboard-Pytorch):
            If passed in, important quantities will be logged
    """
    # For the RNN, obs and next_obs both carry histories
    obs, acs, rews, next_obs, dones = sample
    curr_agent = self.agents[agent_i]

    curr_agent.critic_optimizer.zero_grad()
    if self.alg_types[agent_i] == 'MADDPG':
        if self.discrete_action:  # one-hot encode action
            # Original one-liner ('pi': policy, 'nobs': next observations):
            # all_trgt_acs = [onehot_from_logits(pi(nobs)) for pi, nobs in
            #                 zip(self.target_policies, next_obs)]
            # -------- Expanded out for debugging --------
            all_trgt_acs = []
            for pi, nobs in zip(self.target_policies, next_obs):
                temp = onehot_from_logits(pi(nobs))
                # print(temp)
                all_trgt_acs.append(temp)
            # -------- End debug -------------------------
        else:
            all_trgt_acs = [
                pi(nobs) for pi, nobs in zip(self.target_policies, next_obs)
            ]
        # Keep only the most recent observation from the history for the
        # target value: the critic does not need history.
        # next_obs[0].shape[0] gives the batch size.
        # TODO: make the observation size (18) a parameter
        t0_next_obs = [None] * self.nagents
        for a in range(self.nagents):
            t0_next_obs[a] = torch.zeros(next_obs[0].shape[0], 18,
                                         dtype=torch.float)
        for n in range(self.nagents):  # for each agent
            for b in range(next_obs[0].shape[0]):  # for each batch entry
                t0_next_obs[n][b][:] = next_obs[n][b][0:18]
        # ORIGINAL was:
        # trgt_vf_in = torch.cat((*next_obs, *all_trgt_acs), dim=1)
        trgt_vf_in = torch.cat((*t0_next_obs, *all_trgt_acs), dim=1)
    else:  # DDPG
        # DDPG only knows this agent's observation and policy, whereas
        # MADDPG has access to all other agents' policies.
        # Keep only the current observation for the critic input.
        t0_next_obs = [None] * self.nagents
        for a in range(self.nagents):
            t0_next_obs[a] = torch.zeros(next_obs[0].shape[0], 18,
                                         dtype=torch.float)
        for n in range(self.nagents):  # for each agent
            for b in range(next_obs[0].shape[0]):  # for each batch entry
                t0_next_obs[n][b][:] = next_obs[n][b][0:18]
        # Originally next_obs[agent_i] was fed to the critic here; the
        # target policy still receives the full history.
        if self.discrete_action:
            trgt_vf_in = torch.cat(
                (t0_next_obs[agent_i],
                 onehot_from_logits(
                     curr_agent.target_policy(next_obs[agent_i]))),
                dim=1)
        else:
            trgt_vf_in = torch.cat(
                (t0_next_obs[agent_i],
                 curr_agent.target_policy(next_obs[agent_i])),
                dim=1)
    target_value = (rews[agent_i].view(-1, 1) +
                    self.gamma * curr_agent.target_critic(trgt_vf_in) *
                    (1 - dones[agent_i].view(-1, 1)))

    # Get the current observation (i.e., without history) for the critic;
    # same construction as t0_next_obs, since BOTH obs and next_obs
    # carry histories.
    t0_obs = [None] * self.nagents
    for a in range(self.nagents):
        t0_obs[a] = torch.zeros(obs[0].shape[0], 18, dtype=torch.float)
    for n in range(self.nagents):  # for each agent
        for b in range(obs[0].shape[0]):  # for each batch entry
            t0_obs[n][b][:] = obs[n][b][0:18]

    if self.alg_types[agent_i] == 'MADDPG':
        vf_in = torch.cat((*t0_obs, *acs), dim=1)
    else:  # DDPG
        vf_in = torch.cat((t0_obs[agent_i], acs[agent_i]), dim=1)
    actual_value = curr_agent.critic(vf_in)
    vf_loss = MSELoss(actual_value, target_value.detach())
    vf_loss.backward()
    if parallel:
        average_gradients(curr_agent.critic)
    torch.nn.utils.clip_grad_norm_(curr_agent.critic.parameters(), 0.5)
    curr_agent.critic_optimizer.step()

    curr_agent.policy_optimizer.zero_grad()

    if self.discrete_action:
        # Forward pass as if onehot (hard=True) but backprop through a
        # differentiable Gumbel-Softmax sample. The MADDPG paper uses the
        # Gumbel-Softmax trick to backprop through discrete categorical
        # samples, but I'm not sure if that is correct since it removes the
        # assumption of a deterministic policy for DDPG. Regardless, discrete
        # policies don't seem to learn properly without it.
        # Back to forwarding the policy, so use obs with history.
        curr_pol_out = curr_agent.policy(obs[agent_i])
        curr_pol_vf_in = gumbel_softmax(curr_pol_out, hard=True)
    else:
        curr_pol_out = curr_agent.policy(obs[agent_i])
        curr_pol_vf_in = curr_pol_out
    if self.alg_types[agent_i] == 'MADDPG':
        all_pol_acs = []
        for i, pi, ob in zip(range(self.nagents), self.policies, obs):
            if i == agent_i:
                all_pol_acs.append(curr_pol_vf_in)
            elif self.discrete_action:
                all_pol_acs.append(onehot_from_logits(pi(ob)))
            else:
                all_pol_acs.append(pi(ob))
        # Originally: vf_in = torch.cat((*obs, *all_pol_acs), dim=1)
        vf_in = torch.cat((*t0_obs, *all_pol_acs), dim=1)
    else:  # DDPG
        vf_in = torch.cat((t0_obs[agent_i], curr_pol_vf_in), dim=1)  # TODO: FIX THIS
    pol_loss = -curr_agent.critic(vf_in).mean()
    pol_loss += (curr_pol_out ** 2).mean() * 1e-3
    pol_loss.backward()
    if parallel:
        average_gradients(curr_agent.policy)
    torch.nn.utils.clip_grad_norm_(curr_agent.policy.parameters(), 0.5)
    curr_agent.policy_optimizer.step()

    if logger is not None:
        logger.add_scalars('agent%i/losses' % agent_i,
                           {'vf_loss': vf_loss, 'pol_loss': pol_loss},
                           self.niter)
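
# The element-wise copy loops that strip the history can be collapsed into a
# slice per agent. This is an equivalent sketch, assuming (as the loops above
# do) that the current-step features occupy the first 18 entries of each
# history-augmented observation:
# t0_obs = [ob[:, :18] for ob in obs]
# t0_next_obs = [nobs[:, :18] for nobs in next_obs]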