def forward(self, obs, sample=True, return_all_probs=False,
            return_log_pi=False, regularize=False, return_entropy=False):
    """Compute a one-hot discrete action plus optional extra outputs.

    Args:
        obs: observation batch, forwarded to the parent network's ``forward``.
        sample: if True, draw the action from the softmax distribution with
            ``categorical_sample``; otherwise take the greedy action via
            ``onehot_from_logits``.
        return_all_probs: also return the full softmax probabilities.
        return_log_pi: also return the log-probability of the selected action.
        regularize: also return a one-element list holding the mean squared
            network output (a regularisation term).
        return_entropy: also return the mean entropy of the distribution.

    Returns:
        The one-hot action alone when no extra output is requested;
        otherwise the list ``[act, *extras]`` with the extras in the order
        of the flags above.
    """
    out = super(DiscretePolicy, self).forward(obs)
    probs = F.softmax(out, dim=1)
    on_gpu = next(self.parameters()).is_cuda
    if sample:
        int_act, act = categorical_sample(probs, use_cuda=on_gpu)
    else:
        act = onehot_from_logits(probs)
        if return_log_pi:
            # Index of the greedy action. Without this, sample=False combined
            # with return_log_pi=True raised UnboundLocalError below.
            int_act = act.max(dim=1, keepdim=True)[1]
    rets = [act]
    if return_log_pi or return_entropy:
        log_probs = F.log_softmax(out, dim=1)
    if return_all_probs:
        rets.append(probs)
    if return_log_pi:
        # return log probability of selected action
        rets.append(log_probs.gather(1, int_act))
    if regularize:
        rets.append([(out**2).mean()])
    if return_entropy:
        rets.append(-(log_probs * probs).sum(1).mean())
    if len(rets) == 1:
        return rets[0]
    return rets
def step(self, obs, explore=False):
    """
    Take a step forward in environment for a minibatch of observations,
    equivalent to `act` or `compute_actions`.

    The policy's hidden state (``self.policy_hidden_states``) is replaced by
    the one returned from this forward pass.

    Arguments:
        obs: (B,O)
        explore: Whether or not to add exploration noise
    Returns:
        action: dict of actions for this agent, (B,A)
    """
    with torch.no_grad():
        action, hidden_states = self.policy(obs, self.policy_hidden_states)
        # if mlp, still default None
        self.policy_hidden_states = hidden_states

        if self.discrete_action:
            for k in action:
                if explore:
                    action[k] = gumbel_softmax(action[k], hard=True)
                else:
                    action[k] = onehot_from_logits(action[k])
        else:  # continuous action
            # Only draw noise when exploring: calling exploration.noise() may
            # advance a stateful noise process (and consumes randomness),
            # which should not happen when the noise is then thrown away.
            noise = None
            if explore:
                noise = Variable(Tensor(self.exploration.noise()),
                                 requires_grad=False)
            idx = 0
            for k in action:
                if noise is not None:
                    # The flat noise vector is split across action keys in
                    # iteration order, `dim` entries per key.
                    dim = action[k].shape[-1]
                    action[k] += noise[idx:idx + dim]
                    idx += dim
                action[k] = action[k].clamp(-1, 1)
    return action
def step(self, obs, explore=False):
    """
    Compute this agent's action for a minibatch of observations.

    Inputs:
        obs (PyTorch Variable): Observations for this agent
        explore (boolean): Whether or not to add exploration noise
    Outputs:
        action (PyTorch Variable): Actions for this agent
    """
    policy_out = self.policy(obs)

    if not self.discrete_action:
        # Continuous control: optionally perturb in place, then clamp to [-1, 1].
        if explore:
            policy_out += Variable(Tensor(self.exploration.noise()),
                                   requires_grad=False)
        return policy_out.clamp(-1, 1)

    # Discrete control: hard Gumbel-Softmax sample when exploring,
    # greedy one-hot otherwise.
    if explore:
        return gumbel_softmax(policy_out, hard=True)
    return onehot_from_logits(policy_out)
def _soft_act(x):
    """Soften an action if discrete; x: (B,A).

    Neither ``self`` nor ``requires_grad`` is a parameter here: both are read
    from the enclosing scope.
    """
    if self.discrete_action:
        return gumbel_softmax(x, hard=True) if requires_grad else onehot_from_logits(x)
    return x
def _soft_act(self, x, requires_grad=True): """ soften action if discrete, x: (B,A) """ if not self.discrete_action: return x if requires_grad: return gumbel_softmax(x, hard=True) else: return onehot_from_logits(x)
def update(self, sample, agent_i, parallel=False, logger=None):
    """
    Update parameters of agent model based on sample from replay buffer
    Inputs:
        sample: tuple of (observations, actions, rewards, next
                observations, and episode end masks) sampled randomly from
                the replay buffer. Each is a list with entries
                corresponding to each agent
        agent_i (int): index of agent to update
        parallel (bool): If true, will average gradients across threads
        logger (SummaryWriter from Tensorboard-Pytorch):
                If passed in, important quantities will be logged
    """
    obs, acs, rews, next_obs, dones = sample
    curr_agent = self.agents[agent_i]
    # Any alg type other than 'MADDPG' is treated as DDPG.
    centralized = self.alg_types[agent_i] == 'MADDPG'

    # ---- Critic update ----
    curr_agent.critic_optimizer.zero_grad()
    if centralized:
        all_trgt_acs = [pi(nobs) for pi, nobs in zip(self.target_policies, next_obs)]
        if self.discrete_action:  # one-hot encode action
            all_trgt_acs = [onehot_from_logits(ac) for ac in all_trgt_acs]
        trgt_vf_in = torch.cat((*next_obs, *all_trgt_acs), dim=1)
    else:  # DDPG
        trgt_ac = curr_agent.target_policy(next_obs[agent_i])
        if self.discrete_action:
            trgt_ac = onehot_from_logits(trgt_ac)
        trgt_vf_in = torch.cat((next_obs[agent_i], trgt_ac), dim=1)
    target_value = (rews[agent_i].view(-1, 1) + self.gamma *
                    curr_agent.target_critic(trgt_vf_in) *
                    (1 - dones[agent_i].view(-1, 1)))

    if centralized:
        vf_in = torch.cat((*obs, *acs), dim=1)
    else:  # DDPG
        vf_in = torch.cat((obs[agent_i], acs[agent_i]), dim=1)
    actual_value = curr_agent.critic(vf_in)
    vf_loss = MSELoss(actual_value, target_value.detach())
    vf_loss.backward()
    if parallel:
        average_gradients(curr_agent.critic)
    torch.nn.utils.clip_grad_norm_(curr_agent.critic.parameters(), 0.5)
    curr_agent.critic_optimizer.step()

    # ---- Policy update ----
    curr_agent.policy_optimizer.zero_grad()
    curr_pol_out = curr_agent.policy(obs[agent_i])
    if self.discrete_action:
        # Forward pass as if onehot (hard=True) but backprop through a differentiable
        # Gumbel-Softmax sample. The MADDPG paper uses the Gumbel-Softmax trick to backprop
        # through discrete categorical samples, but I'm not sure if that is
        # correct since it removes the assumption of a deterministic policy for
        # DDPG. Regardless, discrete policies don't seem to learn properly without it.
        curr_pol_vf_in = gumbel_softmax(curr_pol_out, hard=True)
    else:
        curr_pol_vf_in = curr_pol_out

    if centralized:
        all_pol_acs = []
        for i, pi, ob in zip(range(self.nagents), self.policies, obs):
            if i == agent_i:
                all_pol_acs.append(curr_pol_vf_in)
            elif self.discrete_action:
                all_pol_acs.append(onehot_from_logits(pi(ob)).detach())
            else:
                # Other agents' actions are treated as constants; detaching
                # keeps this loss from back-propagating into their policies.
                all_pol_acs.append(pi(ob).detach())
        vf_in = torch.cat((*obs, *all_pol_acs), dim=1)
    else:  # DDPG
        vf_in = torch.cat((obs[agent_i], curr_pol_vf_in), dim=1)
    pol_loss = -curr_agent.critic(vf_in).mean()
    # Small L2 penalty on the raw policy outputs.
    pol_loss += (curr_pol_out**2).mean() * 1e-3
    pol_loss.backward()
    if parallel:
        average_gradients(curr_agent.policy)
    torch.nn.utils.clip_grad_norm_(curr_agent.policy.parameters(), 0.5)
    curr_agent.policy_optimizer.step()
    if logger is not None:
        logger.add_scalars('agent%i/losses' % agent_i,
                           {'vf_loss': vf_loss,
                            'pol_loss': pol_loss},
                           self.niter)
def update_agent(self, sample, agent_i):
    """
    Update parameters of agent model based on sample from replay buffer.

    Runs a critic update, then an actor update, then steps the current
    agent's feature extractor ``f_e`` with the gradients accumulated by both.
    The per-agent lists in ``sample`` are never modified.

    Inputs:
        sample: tuple of (observations, actions, rewards, next_observations,
                and episode_end masks) sampled randomly from the replay
                buffer. Each is a list with entries corresponding to each
                agent
        agent_i (int): index of agent to update
    Returns:
        (pol_loss, vf_loss) as numpy arrays.
    Raises:
        NotImplementedError: if ``self.alg_types[agent_i]`` is not one of
            'SharedMADDPG', 'MADDPG', 'SharedDDPG', 'DDPG'.
    """
    # Extract info and agent
    observations, actions, rewards, next_obs, dones = sample
    curr_agent = self.agents[agent_i]
    alg_type = self.alg_types[agent_i]
    if alg_type not in ('SharedMADDPG', 'MADDPG', 'SharedDDPG', 'DDPG'):
        raise NotImplementedError(f"Unsupported algorithm type: {alg_type}")
    centralized = alg_type in ('SharedMADDPG', 'MADDPG')
    shared = alg_type == 'SharedMADDPG'

    def own_first(seq):
        """Return a new list with agent_i's entry moved to index 0.

        The shared critic always takes agent_i's point of view. Building a
        new list (instead of pop/insert on the sample's lists) keeps the
        caller's sample intact for the next agent's update.
        """
        return [seq[agent_i], *seq[:agent_i], *seq[agent_i + 1:]]

    # UPDATES THE CRITIC ---
    curr_agent.critic_optimizer.zero_grad()
    curr_agent.f_e_optimizer.zero_grad()

    # Gets target state-action pair (next_obs, next_actions)
    if centralized:
        all_target_actions = []
        for pi, f_e, nobs in zip(self.target_policies, self.f_es, next_obs):
            target_action = pi(f_e(nobs))
            if self.use_discrete_action:  # one-hot encode action
                target_action = onehot_from_logits(target_action)
            all_target_actions.append(target_action)
        critic_next_obs = next_obs
        if shared:
            critic_next_obs = own_first(next_obs)
            all_target_actions = own_first(all_target_actions)
        if self.critic_concat_all_obs:
            # TODO: WARNING: should probably feed that in feature_extractor
            target_vf_in = torch.cat((*critic_next_obs, *all_target_actions), dim=1)
        else:
            # next_obs itself is never reordered, so this is agent_i's own
            # observation (previously it read a rotated list and picked the
            # wrong agent for agent_i > 0 under SharedMADDPG).
            target_vf_in = torch.cat(
                (curr_agent.f_e(next_obs[agent_i]), *all_target_actions), dim=1)
    else:  # 'SharedDDPG' / 'DDPG'
        next_fe = curr_agent.f_e(next_obs[agent_i])
        next_action = curr_agent.target_policy(next_fe)
        if self.use_discrete_action:
            next_action = onehot_from_logits(next_action)
        target_vf_in = torch.cat((next_fe, next_action), dim=1)

    # Computes target value
    target_value = (rewards[agent_i].view(-1, 1) + self.gamma *
                    curr_agent.target_critic(target_vf_in) *
                    (1 - dones.view(-1, 1)))

    # Computes current state-action value
    # NOTE(review): the target input above may use f_e features
    # (critic_concat_all_obs False, or DDPG) while this input always uses raw
    # observations -- confirm the two critic inputs are meant to agree.
    if centralized:
        critic_obs, critic_acs = observations, actions
        if shared:
            critic_obs, critic_acs = own_first(observations), own_first(actions)
        vf_in = torch.cat((*critic_obs, *critic_acs), dim=1)
    else:
        vf_in = torch.cat((observations[agent_i], actions[agent_i]), dim=1)
    actual_value = curr_agent.critic(vf_in)

    # Backpropagates
    vf_loss = MSELoss(actual_value, target_value.detach())
    # NOTE(review): retain_graph looks unnecessary here (vf_in is built from
    # sample tensors, not reused in the actor loss); kept as in the original.
    vf_loss.backward(retain_graph=True)
    # Clip gradients
    torch.nn.utils.clip_grad_norm_(curr_agent.critic.parameters(), self.grad_clip_value)
    # Apply critic update
    curr_agent.critic_optimizer.step()

    # UPDATES THE ACTOR ---
    # Resets gradient buffer
    curr_agent.policy_optimizer.zero_grad()

    curr_pol_out = curr_agent.policy(curr_agent.f_e(observations[agent_i]))
    if self.use_discrete_action:
        # Forward pass as if onehot (hard=True) but backprop through a differentiable
        # Gumbel-Softmax sample. The MADDPG paper uses the Gumbel-Softmax trick to backprop
        # through discrete categorical samples, but I'm not sure if that is
        # correct since it removes the assumption of a deterministic policy for
        # DDPG. Regardless, discrete policies don't seem to learn properly without it.
        curr_pol_vf_in = gumbel_softmax(curr_pol_out, hard=True)
    else:
        curr_pol_vf_in = curr_pol_out  # No Gumbel-softmax for continuous control

    # Gets state-action pair value given by the critic
    if centralized:
        all_pol_acs = []
        for i, pi, f_e, ob in zip(range(self.nagents), self.policies, self.f_es, observations):
            if i == agent_i:
                all_pol_acs.append(curr_pol_vf_in)
            elif self.use_discrete_action:
                all_pol_acs.append(onehot_from_logits(pi(f_e(ob))).detach())
            else:
                # Policies take f_e features everywhere else in this method;
                # the raw observation used to be passed here.
                all_pol_acs.append(pi(f_e(ob)).detach())
        critic_obs = observations
        if shared:
            # the critic must take the point of view of agent i
            critic_obs, all_pol_acs = own_first(observations), own_first(all_pol_acs)
        vf_in = torch.cat((*critic_obs, *all_pol_acs), dim=1)  # Centralized critic for MADDPG agent
    else:
        vf_in = torch.cat((observations[agent_i], curr_pol_vf_in), dim=1)

    # Computes the loss
    pol_loss = -torch.mean(curr_agent.critic(vf_in))
    # Backpropagates
    pol_loss.backward()
    # Clip gradients
    torch.nn.utils.clip_grad_norm_(curr_agent.policy.parameters(), self.grad_clip_value)
    # Apply actor update
    curr_agent.policy_optimizer.step()

    ## Backpropagates for feature extractor from both backprops
    torch.nn.utils.clip_grad_norm_(curr_agent.f_e.parameters(), self.grad_clip_value)
    curr_agent.f_e_optimizer.step()

    return pol_loss.data.cpu().numpy(), vf_loss.data.cpu().numpy()
def update_agent(self, sample, agent_i):
    """
    Update parameters of agent model based on sample from replay buffer.

    Runs a critic update followed by an actor (policy) update.

    Inputs:
        sample: tuple of (observations, actions, rewards, next_observations,
                and episode_end masks) sampled randomly from the replay
                buffer. Each is a list with entries corresponding to each
                agent
        agent_i (int): index of agent to update
    Returns:
        (pol_loss, vf_loss) as numpy arrays.
    Raises:
        NotImplementedError: if ``self.alg_types[agent_i]`` is not one of
            'MADDPG', 'CoachMADDPG', 'DDPG'.
    """
    # Extract info and agent
    observations, actions, rewards, next_obs, dones = sample
    curr_agent = self.agents[agent_i]
    alg_type = self.alg_types[agent_i]
    if alg_type not in ('MADDPG', 'CoachMADDPG', 'DDPG'):
        # `raise NotImplemented` would itself fail with a TypeError.
        raise NotImplementedError(f"Unsupported algorithm type: {alg_type}")
    centralized = alg_type in ('MADDPG', 'CoachMADDPG')

    # UPDATES THE CRITIC ---
    # Resets gradient buffer
    curr_agent.critic_optimizer.zero_grad()

    # Gets target state-action pair (next_obs, next_actions)
    if centralized:
        all_target_actions = [pi(nobs) for pi, nobs in zip(self.target_policies, next_obs)]
        if self.use_discrete_action:  # one-hot encode action
            all_target_actions = [onehot_from_logits(ac) for ac in all_target_actions]
        target_vf_in = torch.cat((*next_obs, *all_target_actions), dim=1)
    else:  # DDPG
        next_action = curr_agent.target_policy(next_obs[agent_i])
        if self.use_discrete_action:
            next_action = onehot_from_logits(next_action)
        target_vf_in = torch.cat((next_obs[agent_i], next_action), dim=1)

    # Computes target value
    target_value = (rewards[agent_i].view(-1, 1) + self.gamma *
                    curr_agent.target_critic(target_vf_in) *
                    (1 - dones.view(-1, 1)))

    # Computes current state-action value
    if centralized:
        vf_in = torch.cat((*observations, *actions), dim=1)
    else:
        vf_in = torch.cat((observations[agent_i], actions[agent_i]), dim=1)
    actual_value = curr_agent.critic(vf_in)

    # Backpropagates
    vf_loss = MSELoss(actual_value, target_value.detach())
    vf_loss.backward()
    # Clip gradients
    torch.nn.utils.clip_grad_norm_(curr_agent.critic.parameters(), self.grad_clip_value)
    # Apply critic update
    curr_agent.critic_optimizer.step()

    # UPDATES THE ACTOR ---
    # Resets gradient buffer
    curr_agent.policy_optimizer.zero_grad()

    curr_pol_out = curr_agent.policy(observations[agent_i], return_embed_logits=False)
    if self.use_discrete_action:
        # Forward pass as if onehot (hard=True) but backprop through a differentiable
        # Gumbel-Softmax sample. The MADDPG paper uses the Gumbel-Softmax trick to backprop
        # through discrete categorical samples, but I'm not sure if that is
        # correct since it removes the assumption of a deterministic policy for
        # DDPG. Regardless, discrete policies don't seem to learn properly without it.
        curr_pol_vf_in = gumbel_softmax(curr_pol_out, hard=True)
    else:
        curr_pol_vf_in = curr_pol_out  # No Gumbel-softmax for continuous control

    # Gets state-action pair value given by the critic
    if centralized:
        all_pol_acs = []
        for i, pi, ob in zip(range(self.nagents), self.policies, observations):
            if i == agent_i:
                all_pol_acs.append(curr_pol_vf_in)
            elif self.use_discrete_action:
                all_pol_acs.append(onehot_from_logits(pi(ob).detach()))
            else:
                all_pol_acs.append(pi(ob).detach())
        vf_in = torch.cat((*observations, *all_pol_acs), dim=1)  # Centralized critic for MADDPG agent
    else:
        vf_in = torch.cat((observations[agent_i], curr_pol_vf_in), dim=1)

    # Computes the loss
    pol_loss = -torch.mean(curr_agent.critic(vf_in))
    # Backpropagates
    pol_loss.backward()
    # Clip gradients
    torch.nn.utils.clip_grad_norm_(curr_agent.policy.parameters(), self.grad_clip_value)
    # Apply actor update
    curr_agent.policy_optimizer.step()

    return pol_loss.data.cpu().numpy(), vf_loss.data.cpu().numpy()
def update_agent(self, sample, agent_i):
    """
    Update parameters of agent model based on sample from replay buffer.

    Performs, in order:
      1. a critic update;
      2. an actor (policy) update;
      3. for 'TeamMADDPG' only, a Team-Spirit (TS) regularisation step that
         updates every agent's policy, scaling the current agent's gradients
         by ``self.lambdat_1`` and the others' by ``self.lambdat_2``;
    and finally steps the current agent's feature extractor ``f_e`` with the
    gradients accumulated by all of the above.

    Each policy output is reshaped to (-1, self.action_spaces[0],
    self.nagents); slice ``[:, :, i]`` is used as agent i's action.

    Inputs:
        sample: tuple of (observations, actions, rewards, next_observations,
                and episode_end masks) sampled randomly from the replay
                buffer. Each is a list with entries corresponding to each
                agent
        agent_i (int): index of agent to update
    Returns:
        (pol_loss, vf_loss, ts_loss): numpy arrays; ts_loss is None unless
        the algorithm is 'TeamMADDPG'.
    Raises:
        NotImplementedError: if ``self.alg_types[agent_i]`` is not one of
            'TeamMADDPG', 'MADDPG', 'DDPG'.
    """
    # Extract info and agent
    observations, actions, rewards, next_obs, dones = sample
    curr_agent = self.agents[agent_i]
    alg_type = self.alg_types[agent_i]
    if alg_type not in ('TeamMADDPG', 'MADDPG', 'DDPG'):
        # `raise NotImplemented` would itself fail with a TypeError.
        raise NotImplementedError(f"Unsupported algorithm type: {alg_type}")
    centralized = alg_type in ('TeamMADDPG', 'MADDPG')
    head_shape = (-1, self.action_spaces[0], self.nagents)

    def curr_policy_heads():
        """Forward the current agent's policy and reshape to head_shape."""
        return curr_agent.policy(curr_agent.f_e(observations[agent_i])).view(*head_shape)

    # UPDATES THE CRITIC ---
    # Resets gradient buffer
    curr_agent.critic_optimizer.zero_grad()
    curr_agent.f_e_optimizer.zero_grad()

    # Gets target state-action pair (next_obs, next_actions)
    if centralized:
        all_target_actions = []
        for i, (pi, f_e, nobs) in enumerate(zip(self.target_policies, self.f_es, next_obs)):
            target_action = pi(f_e(nobs)).view(*head_shape)[:, :, i]
            if self.use_discrete_action:  # one-hot encode action
                target_action = onehot_from_logits(target_action)
            all_target_actions.append(target_action)
        if self.critic_concat_all_obs:
            # NOTE(review): raw next observations here, while the current-value
            # input below uses f_e features of every agent -- confirm intended.
            target_vf_in = torch.cat((*next_obs, *all_target_actions), dim=1)
        else:
            target_vf_in = torch.cat(
                (curr_agent.f_e(next_obs[agent_i]), *all_target_actions), dim=1)
    else:  # DDPG
        next_fe = curr_agent.f_e(next_obs[agent_i])
        next_action = curr_agent.target_policy(next_fe)
        if self.use_discrete_action:
            next_action = onehot_from_logits(next_action)
        target_vf_in = torch.cat((next_fe, next_action), dim=1)

    # Computes target value
    target_value = (rewards[agent_i].view(-1, 1) + self.gamma *
                    curr_agent.target_critic(target_vf_in) *
                    (1 - dones.view(-1, 1)))

    # Computes current state-action value
    if centralized:
        if self.critic_concat_all_obs:
            critic_obs = [agent.f_e(obs) for obs, agent in zip(observations, self.agents)]
        else:
            critic_obs = [curr_agent.f_e(observations[agent_i])]
        vf_in = torch.cat((*critic_obs, *actions), dim=1)
    else:
        critic_obs = curr_agent.f_e(observations[agent_i])
        vf_in = torch.cat((critic_obs, actions[agent_i]), dim=1)
    actual_value = curr_agent.critic(vf_in)

    # Backpropagates
    vf_loss = MSELoss(actual_value, target_value.detach())
    # Retain the graph: critic_obs (f_e features) is reused in the actor loss.
    vf_loss.backward(retain_graph=True)
    # Clip gradients
    torch.nn.utils.clip_grad_norm_(curr_agent.critic.parameters(), self.grad_clip_value)
    # Apply critic update
    curr_agent.critic_optimizer.step()

    # UPDATES THE ACTOR ---
    # Resets gradient buffer
    curr_agent.policy_optimizer.zero_grad()

    pol_out_all_heads = curr_policy_heads()
    curr_pol_out = pol_out_all_heads[:, :, agent_i]
    if self.use_discrete_action:
        # Forward pass as if onehot (hard=True) but backprop through a differentiable
        # Gumbel-Softmax sample. The MADDPG paper uses the Gumbel-Softmax trick to backprop
        # through discrete categorical samples, but I'm not sure if that is
        # correct since it removes the assumption of a deterministic policy for
        # DDPG. Regardless, discrete policies don't seem to learn properly without it.
        curr_pol_vf_in = gumbel_softmax(curr_pol_out, hard=True)
    else:
        curr_pol_vf_in = curr_pol_out  # No Gumbel-softmax for continuous control

    # Gets state-action pair value given by the critic
    if centralized:
        all_pol_logits = []
        all_pol_acs = []
        for i, pi, f_e, ob in zip(range(self.nagents), self.policies, self.f_es, observations):
            if i == agent_i:
                all_pol_logits.append(curr_pol_out)
                all_pol_acs.append(curr_pol_vf_in)
            elif self.use_discrete_action:
                logits = pi(f_e(ob)).view(*head_shape)[:, :, i]
                all_pol_logits.append(logits)
                all_pol_acs.append(onehot_from_logits(logits).detach())
            else:
                all_pol_logits.append(None)
                all_pol_acs.append(pi(f_e(ob)).detach().view(*head_shape)[:, :, i])
        vf_in = torch.cat((*critic_obs, *all_pol_acs), dim=1)  # Centralized critic for MADDPG agent
    else:
        vf_in = torch.cat((critic_obs, curr_pol_vf_in), dim=1)

    # Computes the loss
    pol_loss = -torch.mean(curr_agent.critic(vf_in))
    # Backpropagates (the TS step below rebuilds what it needs, so the graph
    # does not have to be retained).
    pol_loss.backward()
    # Clip gradients
    torch.nn.utils.clip_grad_norm_(curr_agent.policy.parameters(), self.grad_clip_value)
    # Apply actor update
    curr_agent.policy_optimizer.step()

    # Team Spirit (TS) regularization
    if alg_type == "TeamMADDPG":
        # Re-run the current policy after its optimizer step. The step updates
        # the parameters in place, so back-propagating through the graph
        # recorded before it fails in recent PyTorch versions ("modified by
        # an inplace operation").
        # NOTE(review): assumes other agents' policies are distinct modules;
        # a policy shared with curr_agent would need the same treatment.
        pol_out_all_heads = curr_policy_heads()
        all_pol_logits[agent_i] = pol_out_all_heads[:, :, agent_i]
        if self.use_discrete_action:
            real_action_logits = torch.stack(all_pol_logits, dim=2)
            real_action_probs = F.softmax(real_action_logits, dim=1)
            real_action_log_probs = F.log_softmax(real_action_logits, dim=1)
            predicted_action_log_probs = F.log_softmax(pol_out_all_heads, dim=1)
            # KL-divergence KL(real_action_dist || predicted_action_dist)
            ts_loss = torch.mean(torch.sum(
                real_action_probs * (real_action_log_probs - predicted_action_log_probs),
                dim=1))
        else:
            all_pol_acs[agent_i] = pol_out_all_heads[:, :, agent_i]
            ts_loss = torch.mean(torch.sum(
                (pol_out_all_heads - torch.stack(all_pol_acs, dim=2))**2, dim=1))

        # Resets gradient buffer of all agents
        for agent in self.agents:
            agent.policy_optimizer.zero_grad()

        # Backpropagates through every agent of the team (including curr_agent)
        ts_loss.backward()

        for i, agent in enumerate(self.agents):
            coeff = self.lambdat_1 if i == agent_i else self.lambdat_2
            for p in agent.policy.parameters():
                # Parameters that got no gradient from ts_loss may have grad
                # None (e.g. other agents in the continuous case, whose
                # actions are detached); `None *= coeff` would raise.
                if p.grad is not None:
                    p.grad.mul_(coeff)
            # Apply gradients
            torch.nn.utils.clip_grad_norm_(agent.policy.parameters(), self.grad_clip_value)
            agent.policy_optimizer.step()

        ts_loss = ts_loss.data.cpu().numpy()
    else:
        ts_loss = None

    ## Backpropagates for feature extractor from both backprops
    torch.nn.utils.clip_grad_norm_(curr_agent.f_e.parameters(), self.grad_clip_value)
    curr_agent.f_e_optimizer.step()

    return pol_loss.data.cpu().numpy(), vf_loss.data.cpu().numpy(), ts_loss
def update(self, sample, agent_i, parallel=False, logger=None):
    """
    Update parameters of agent model based on sample from replay buffer
    Inputs:
        sample: tuple of (observations, actions, rewards, next
                observations, and episode end masks) sampled randomly from
                the replay buffer. Each is a list with entries
                corresponding to each agent
        agent_i (int): index of agent to update
        parallel (bool): If true, will average gradients across threads
        logger (SummaryWriter from Tensorboard-Pytorch):
                If passed in, important quantities will be logged
    """
    obs, acs, rews, next_obs, dones = sample
    # Local copy: entries of `acs` may be replaced below (noisy sharing /
    # estimated listener action); that must not alter the caller's sample.
    acs = list(acs)
    curr_agent = self.agents[agent_i]
    centralized = self.alg_types[agent_i] == 'MADDPG'

    curr_agent.critic_optimizer.zero_grad()
    if centralized:
        # a' = mu'(o') for every agent
        all_trgt_acs = [pi(nobs) for pi, nobs in zip(self.target_policies, next_obs)]
        if self.discrete_action:  # one-hot encode action
            all_trgt_acs = [onehot_from_logits(ac) for ac in all_trgt_acs]

        # ========================== Adding noise ==========================
        if self.noisy_sharing:
            acs = list(self.noisy_sharing_discrete(acs, agent_i))
        # ====================== End of Adding noise ======================
        trgt_vf_in = torch.cat((*next_obs, *all_trgt_acs), dim=1)

        # ========================= Differential Obs ======================
        # Dedicated to simple_speaker_listener: an action estimated from the
        # change in the last two entries of obs[0] replaces acs[1].
        if self.game_id == 'simple_speaker_listener' and self.est_ac:
            diff_pos = (next_obs[0] - obs[0])[:, -2:]  # (B, 2)
            moved_pos = diff_pos.ge(torch.max(diff_pos) * 0.8).long()
            moved_neg = diff_pos.le(torch.min(diff_pos) * 0.8).long()
            # Codes: column 0 up -> 1, down -> 2; column 1 up -> 3, down -> 4
            # (presumably x / y). The comparisons return bool tensors, so the
            # previous in-place `* 3` etc. were cast back to bool and the
            # codes collapsed to 0..2; converting to long first keeps them.
            pos_codes = torch.tensor([1, 3], device=diff_pos.device)
            neg_codes = torch.tensor([2, 4], device=diff_pos.device)
            est_action = (moved_pos * pos_codes + moved_neg * neg_codes).sum(dim=1)
            # NOTE(review): if several directions fire for one sample the code
            # can exceed the action dimension and scatter_ will fail.
            est_action = torch.zeros(
                len(est_action), acs[1].shape[1],
                dtype=acs[1].dtype, device=acs[1].device,
            ).scatter_(dim=1, index=est_action.view(-1, 1), value=1)
            acs[1] = est_action
        # ===================== End of differential Obs ===================
    else:  # DDPG: a' = mu(o'), only the current agent's
        trgt_ac = curr_agent.target_policy(next_obs[agent_i])
        if self.discrete_action:
            trgt_ac = onehot_from_logits(trgt_ac)
        trgt_vf_in = torch.cat((next_obs[agent_i], trgt_ac), dim=1)
    # y^j
    target_value = (rews[agent_i].view(-1, 1) + self.gamma *
                    curr_agent.target_critic(trgt_vf_in) *
                    (1 - dones[agent_i].view(-1, 1)))

    if centralized:
        vf_in = torch.cat((*obs, *acs), dim=1)
    else:  # DDPG
        vf_in = torch.cat((obs[agent_i], acs[agent_i]), dim=1)
    actual_value = curr_agent.critic(vf_in)
    vf_loss = MSELoss(actual_value, target_value.detach())
    vf_loss.backward()
    if parallel:
        average_gradients(curr_agent.critic)
    torch.nn.utils.clip_grad_norm_(curr_agent.critic.parameters(), 0.5)
    curr_agent.critic_optimizer.step()

    # ============== Here for policy network training =====================
    curr_agent.policy_optimizer.zero_grad()
    curr_pol_out = curr_agent.policy(obs[agent_i])
    if self.discrete_action:
        # Forward pass as if onehot (hard=True) but backprop through a differentiable
        # Gumbel-Softmax sample. The MADDPG paper uses the Gumbel-Softmax trick to backprop
        # through discrete categorical samples, but I'm not sure if that is
        # correct since it removes the assumption of a deterministic policy for
        # DDPG. Regardless, discrete policies don't seem to learn properly without it.
        curr_pol_vf_in = gumbel_softmax(curr_pol_out, hard=True)
    else:
        curr_pol_vf_in = curr_pol_out

    if centralized:
        all_pol_acs = []
        for i, pi, ob in zip(range(self.nagents), self.policies, obs):
            if i == agent_i:
                all_pol_acs.append(curr_pol_vf_in)
            elif self.discrete_action:
                all_pol_acs.append(onehot_from_logits(pi(ob)).detach())
            else:
                # Other agents' actions are treated as constants; detaching
                # keeps this loss from back-propagating into their policies.
                all_pol_acs.append(pi(ob).detach())
        vf_in = torch.cat((*obs, *all_pol_acs), dim=1)
    else:  # DDPG
        vf_in = torch.cat((obs[agent_i], curr_pol_vf_in), dim=1)
    pol_loss = -curr_agent.critic(vf_in).mean()
    # Small L2 penalty on the raw policy outputs.
    pol_loss += (curr_pol_out**2).mean() * 1e-3
    pol_loss.backward()
    if parallel:
        average_gradients(curr_agent.policy)
    torch.nn.utils.clip_grad_norm_(curr_agent.policy.parameters(), 0.5)  # Constraints on the grad.
    curr_agent.policy_optimizer.step()
    if logger is not None:
        logger.add_scalars('agent%i/losses' % agent_i,
                           {'vf_loss': vf_loss,
                            'pol_loss': pol_loss},
                           self.niter)
def evaluate(config):
    """Roll out a saved model in its environment and store the results.

    Loads the model selected by ``config`` (an incremental checkpoint, the
    last model, or the best model of the chosen seed). Rebuilds the
    environment from the saved world/env parameters, optionally overridden by
    ``config``. Runs ``config.n_episodes`` episodes, optionally rendering and
    saving one gif per episode. Finally pickles the agent/coach embeddings and
    the trajectories.

    Args:
        config: evaluation options (attribute access). If ``config.seed_num``
            is None it is set, as a side effect, to the first seed directory
            found for the experiment.

    Returns:
        tuple: ``(total_reward, embeddings_path)`` -- one total reward per
        episode (from ``EpisodeRecorder.get_total_reward``) and the path of
        the embeddings pickle as a ``str``.
    """
    experiment_dir = (DirectoryManager.root / config.storage_name /
                      f"experiment{config.experiment_num}")
    if config.seed_num is None:
        all_seeds = list(experiment_dir.iterdir())
        seed_stem = all_seeds[0].stem
        # Drop the "seed" prefix of the directory name. str.strip('seed')
        # removes any of the characters 's', 'e', 'd' from both ends, which
        # is not the same as removing the prefix.
        config.seed_num = seed_stem[len('seed'):] if seed_stem.startswith('seed') else seed_stem

    # Creates paths and directories
    seed_path = experiment_dir / f"seed{config.seed_num}"
    dir_manager = DirectoryManager.init_from_seed_path(seed_path)
    if config.incremental is not None:
        model_path = dir_manager.incrementals_dir / f'model_ep{config.incremental}.pt'
    elif config.last_model:
        last_models = [path for path in dir_manager.seed_dir.iterdir()
                       if path.suffix == ".pt" and not path.stem.endswith('best')]
        assert len(last_models) == 1
        model_path = last_models[0]
    else:
        best_models = [path for path in dir_manager.seed_dir.iterdir()
                       if path.suffix == ".pt" and path.stem.endswith('best')]
        assert len(best_models) == 1
        model_path = best_models[0]

    # Retrieves world_params if there were any (see make_world function in multiagent.scenarios)
    if (dir_manager.seed_dir / 'world_params.json').exists():
        world_params = load_dict_from_json(str(dir_manager.seed_dir / 'world_params.json'))
    else:
        world_params = {}

    # Overwrites world_params if specified
    for key in ('shuffle_landmarks', 'color_objects', 'small_agents',
                'individual_reward', 'use_dense_rewards'):
        value = getattr(config, key)
        if value is not None:
            world_params[key] = value

    # Retrieves env_params (see multiagent.environment.MultiAgentEnv)
    if (dir_manager.seed_dir / 'env_params.json').exists():
        env_params = load_dict_from_json(str(dir_manager.seed_dir / 'env_params.json'))
    else:
        env_params = {}
        env_params['use_max_speed'] = False

    # Initializes model and environment
    algorithm = init_from_save(model_path)
    env = make_env(scenario_name=env_params['env_name'],
                   use_discrete_action=algorithm.use_discrete_action,
                   use_max_speed=env_params['use_max_speed'],
                   world_params=world_params)
    if config.render:
        env.render()

    # NOTE(review): runner_prey installs the runner policy on agents with
    # `adversary` set, and rusher_predators on agents without it -- confirm
    # this matches the scenario's convention for prey vs. predators.
    if config.runner_prey:
        # makes sure the environment involves a prey
        assert config.env_name.endswith('tag')
        runner_policy = RunnerPolicy()
        for agent in env.world.agents:
            if agent.adversary:
                agent.action_callback = runner_policy.action

    if config.rusher_predators:
        # makes sure the environment involves predators
        assert config.env_name.endswith('tag')
        rusher_policy = RusherPolicy()
        for agent in env.world.agents:
            if not agent.adversary:
                agent.action_callback = rusher_policy.action

    if config.pendulum_agent is not None:
        # makes sure the agent to be controlled has a valid id
        assert config.pendulum_agent in list(range(len(env.world.agents)))
        pendulum_policy = DoublePendulumPolicy()
        env.world.agents[config.pendulum_agent].action_callback = pendulum_policy.action

    if config.interactive_agent is not None:
        # makes sure the agent to be controlled has a valid id
        assert config.interactive_agent in list(range(len(env.world.agents)))
        interactive_policy = InteractivePolicy(env, viewer_id=0)
        env.world.agents[config.interactive_agent].action_callback = interactive_policy.action

    algorithm.prep_rollouts(device='cpu')
    ifi = 1 / config.fps  # inter-frame interval
    total_reward = []
    all_episodes_agent_embeddings = []
    all_episodes_coach_embeddings = []
    all_trajs = []

    overide_color = None
    color_agents = True
    if env_params['env_name'] == 'bounce':
        env.agents[0].size = 1. * env.agents[0].size
        env.world.overwrite = config.overwrite
    elif env_params['env_name'] == 'spread':
        color_agents = False
    elif env_params['env_name'] == 'compromise':
        env.agents[0].lightness = 0.9
        env.world.landmarks[0].lightness = 0.9
        env.agents[1].lightness = 0.5
        env.world.landmarks[1].lightness = 0.5
        env.world.overwrite = config.overwrite

    # Loop-invariant: palette used to color agents by their embedding argmax.
    set1_colors = cm.get_cmap('Set1').colors

    # EPISODES LOOP
    for ep_i in range(config.n_episodes):
        agent_embeddings = []
        coach_embeddings = []
        traj = []
        ep_recorder = EpisodeRecorder(stuff_to_record=['reward'])

        # Resets the environment
        obs = env.reset()

        if config.save_gifs:
            frames = None
        if config.render:
            env.render('human')

        if not algorithm.soft:
            # Resets exploration noise
            algorithm.scale_noise(config.noise_scale)
            algorithm.reset_noise()

        # STEPS LOOP
        for t_i in range(config.episode_length):
            calc_start = time.time()
            # rearrange observations to be per agent, and convert to torch Variable
            torch_obs = [Variable(torch.Tensor(obs[i]).view(1, -1), requires_grad=False)
                         for i in range(algorithm.nagents)]
            # get actions as torch Variables
            torch_actions, torch_embed = algorithm.select_action(
                torch_obs,
                is_exploring=False if config.noise_scale is None else True,
                return_embed=True)
            torch_total_obs = torch.cat(torch_obs, dim=-1)
            coach_embed = onehot_from_logits(algorithm.coach.model(torch_total_obs))
            coach_embeddings.append(coach_embed.data.numpy().squeeze())
            # convert actions to numpy arrays
            actions = [ac.data.numpy().flatten() for ac in torch_actions]
            embeds = [emb.data.numpy().squeeze() for emb in torch_embed]
            agent_embeddings.append(embeds)
            # steps forward in the environment
            next_obs, rewards, dones, infos = env.step(actions)
            ep_recorder.add_step(None, None, rewards, None)
            traj.append((obs, actions, next_obs, rewards, dones))
            obs = next_obs

            colors = list(set1_colors[:len(embeds[0])])
            if overide_color is not None:
                colors[0] = overide_color[0]
                colors[2] = overide_color[1]
            if color_agents:
                for agent, emb in zip(env.agents, embeds):
                    agent.color = colors[np.argmax(emb)]

            # record frames
            if config.save_gifs:
                frames = [] if frames is None else frames
                frames.append(env.render('rgb_array')[0])

            if config.render or config.save_gifs:
                # Enforces the fps config
                elapsed = time.time() - calc_start
                if elapsed < ifi:
                    time.sleep(ifi - elapsed)
                env.render('human')

            if all(dones) and config.interrupt_episode:
                if config.render:
                    time.sleep(2)
                break

        total_reward.append(ep_recorder.get_total_reward())
        all_episodes_agent_embeddings.append(agent_embeddings)
        all_episodes_coach_embeddings.append(coach_embeddings)
        all_trajs.append(traj)

        # Saves a gif of this episode under the first unused file number
        if config.save_gifs:
            gif_path = dir_manager.storage_dir / 'gifs'
            gif_path.mkdir(exist_ok=True)

            gif_num = 0
            while (gif_path / f"{env_params['env_name']}__experiment{config.experiment_num}_seed{config.seed_num}_{gif_num}.gif").exists():
                gif_num += 1
            imageio.mimsave(
                str(gif_path / f"{env_params['env_name']}__experiment{config.experiment_num}_seed{config.seed_num}_{gif_num}.gif"),
                frames, duration=ifi)

    env.close()

    embeddings = {'agents': all_episodes_agent_embeddings,
                  'coach': all_episodes_coach_embeddings}

    save_folder = dir_manager.experiment_dir if config.save_to_exp_folder else dir_manager.seed_dir
    embeddings_path = U.directory_tree.uniquify(save_folder / f"{config.file_name_to_save}.pkl")
    trajs_path = osp.splitext(embeddings_path)[0] + "_trajs.pkl"

    # `with` already closes the files; no explicit close() needed.
    with open(embeddings_path, 'wb') as fp:
        pickle.dump(embeddings, fp)

    with open(trajs_path, 'wb') as fp:
        pickle.dump(all_trajs, fp)

    return total_reward, str(embeddings_path)
def forward(self, obs, sample=True, return_all_probs=False,
            return_log_pi=False, regularize=False, return_entropy=False):
    """
    Compute a one-hot discrete action and optional auxiliary outputs.

    Arguments:
        obs: observation batch forwarded to the parent network.
        sample (bool): if True, sample the action from the softmax
            distribution (``categorical_sample``); otherwise pick it with
            ``onehot_from_logits`` on the probabilities.
        return_all_probs (bool): also return the full action probabilities.
        return_log_pi (bool): also return the log-probability of the chosen
            action (gathered along dim 1).
        regularize (bool): also return a one-element list holding the mean
            squared network output (logits).
        return_entropy (bool): also return the batch-mean policy entropy.

    Returns:
        The action alone when no extra output is requested; otherwise the
        list ``[act, probs?, log_pi?, [reg]?, entropy?]`` in that order.
    """
    out = super(DiscretePolicy, self).forward(obs)
    probs = F.softmax(out, dim=1)
    if sample:
        on_gpu = next(self.parameters()).is_cuda
        int_act, act = categorical_sample(probs, use_cuda=on_gpu)
    else:
        act = onehot_from_logits(probs)
        # Index of the most probable action, so that return_log_pi also works
        # when not sampling (previously `int_act` was unbound -> NameError).
        int_act = probs.max(1, keepdim=True)[1]
    rets = [act]
    if return_log_pi or return_entropy:
        log_probs = F.log_softmax(out, dim=1)
    if return_all_probs:
        rets.append(probs)
    if return_log_pi:
        # log probability of the selected action
        rets.append(log_probs.gather(1, int_act))
    if regularize:
        # Kept as a list: callers may iterate over several regularizers.
        rets.append([(out ** 2).mean()])
    if return_entropy:
        rets.append(-(log_probs * probs).sum(1).mean())
    if len(rets) == 1:
        return rets[0]
    return rets
def update(self, sample, agent_i, parallel=False, grad_norm=0.5):
    """
    Update the critic and policy of one agent from a replay-buffer sample.

    Inputs:
        sample: EpisodeBatch; ``self.parse_sample`` splits it into per-agent
            lists ``obs, acs, rews, next_obs, dones`` whose entries are
            (B, T, D) tensors (batch, time, features).
        agent_i (int): index of the agent to update
        parallel (bool): if True, average gradients across threads
        grad_norm (float): max gradient norm for clipping; <= 0 disables it
    Outputs:
        results (dict): this agent's critic and policy loss tensors
    """
    obs, acs, rews, next_obs, dones = self.parse_sample(sample)  # [(B,T,D)]*N
    bs, ts, _ = obs[0].shape
    self.init_hidden(bs)  # use pre-defined init hiddens
    curr_agent = self.agents[agent_i]
    is_maddpg = self.alg_types[agent_i] == 'MADDPG'

    def to_onehot(act):
        # onehot_from_logits is applied on the flattened (B*T, A) view, then
        # the result is restored to (B, T, A).
        return onehot_from_logits(act.reshape(bs * ts, -1)).reshape(bs, ts, -1)

    # NOTE: critic update
    curr_agent.critic_optimizer.zero_grad()
    if is_maddpg:
        all_trgt_acs = [
            self.compute_action(i, pi, nobs, bs=bs, ts=ts)  # (B,T,A)
            for i, (pi, nobs) in enumerate(zip(self.target_policies, next_obs))
        ]
        if self.discrete_action:
            all_trgt_acs = [to_onehot(act_i) for act_i in all_trgt_acs]
        # [(B,T,O)_i, ..., (B,T,A)_i, ...] -> (B,T,O*N+A*N)
        trgt_vf_in = torch.cat((*next_obs, *all_trgt_acs), dim=-1)
    else:  # DDPG
        trgt_act = self.compute_action(agent_i, curr_agent.target_policy,
                                       next_obs[agent_i], bs=bs, ts=ts)
        if self.discrete_action:
            trgt_act = to_onehot(trgt_act)
        # (B,T,O) + (B,T,A) -> (B,T,O+A)
        trgt_vf_in = torch.cat((next_obs[agent_i], trgt_act), dim=-1)

    # bellman targets
    target_q = self.compute_q_val(agent_i, curr_agent.target_critic,
                                  trgt_vf_in, bs=bs, ts=ts)  # (B*T,1)
    target_value = (rews[agent_i].view(-1, 1) + self.gamma * target_q *
                    (1 - dones[agent_i].view(-1, 1)))  # (B*T,1)

    # Q func on replayed actions. Concatenate along the feature axis, exactly
    # like the target input above (dim=1 would stack along the time axis).
    if is_maddpg:
        vf_in = torch.cat((*obs, *acs), dim=-1)
    else:  # DDPG
        vf_in = torch.cat((obs[agent_i], acs[agent_i]), dim=-1)
    actual_value = self.compute_q_val(agent_i, curr_agent.critic, vf_in,
                                      bs=bs, ts=ts)  # (B*T,1)

    # bellman errors
    vf_loss = MSELoss(actual_value, target_value.detach())
    vf_loss.backward()
    if parallel:
        average_gradients(curr_agent.critic)
    if grad_norm > 0:
        torch.nn.utils.clip_grad_norm_(curr_agent.critic.parameters(), grad_norm)
    curr_agent.critic_optimizer.step()

    # NOTE: policy update
    curr_agent.policy_optimizer.zero_grad()
    # current agent action (deterministic, softened); the agent index must be
    # agent_i (previously a stale loop variable `i` was passed here).
    curr_pol_out = self.compute_action(agent_i, curr_agent.policy,
                                       obs[agent_i], bs=bs, ts=ts)  # (B,T,A)
    if self.discrete_action:
        # Forward pass as if onehot (hard=True) but backprop through a differentiable
        # Gumbel-Softmax sample. The MADDPG paper uses the Gumbel-Softmax trick to backprop
        # through discrete categorical samples, but I'm not sure if that is
        # correct since it removes the assumption of a deterministic policy for
        # DDPG. Regardless, discrete policies don't seem to learn properly without it.
        curr_pol_vf_in = gumbel_softmax(
            curr_pol_out.reshape(bs * ts, -1), hard=True
        ).reshape(bs, ts, -1)
    else:
        curr_pol_vf_in = curr_pol_out

    if is_maddpg:
        all_pol_acs = []
        for i, pi, ob in zip(range(self.nagents), self.policies, obs):
            if i == agent_i:
                # insert current agent act to q input
                all_pol_acs.append(curr_pol_vf_in)
            else:
                p_act_i = self.compute_action(i, pi, ob, bs=bs, ts=ts)  # (B,T,A)
                if self.discrete_action:
                    p_act_i = to_onehot(p_act_i)
                all_pol_acs.append(p_act_i)
        p_vf_in = torch.cat((*obs, *all_pol_acs), dim=-1)  # (B,T,O*N+A*N)
    else:  # DDPG
        p_vf_in = torch.cat((obs[agent_i], curr_pol_vf_in), dim=-1)  # (B,T,O+A)

    # value function to update current policy
    p_value = self.compute_q_val(agent_i, curr_agent.critic, p_vf_in,
                                 bs=bs, ts=ts)  # (B*T,1)
    pol_loss = -p_value.mean()
    # p regularization, scale down output (gaussian mean,std or logits)
    # reference: https://github.com/openai/maddpg/blob/master/maddpg/trainer/maddpg.py
    pol_loss += (curr_pol_out.reshape(bs * ts, -1) ** 2).mean() * 1e-3
    pol_loss.backward()
    if parallel:
        average_gradients(curr_agent.policy)
    if grad_norm > 0:
        torch.nn.utils.clip_grad_norm_(curr_agent.policy.parameters(), grad_norm)
    curr_agent.policy_optimizer.step()

    # collect training stats
    results = {
        "agent_{}_critic_loss".format(agent_i): vf_loss,
        "agent_{}_policy_loss".format(agent_i): pol_loss
    }
    return results
def update(self, sample, agent_i):
    """
    Update the critic and actor of one agent from a batch of experiences.

    :param sample: tuple ``(obs, acs, rews, next_obs, dones)`` of numpy arrays
        holding all agents side by side. Per-agent observation/action columns
        are selected with ``self.observation_index`` / ``self.action_index``;
        rewards and dones are indexed by agent column. ``dones`` must support
        ``~`` (presumably a boolean array -- confirm against the buffer).
    :param agent_i: the agent to be updated
    """
    obs, acs, rews, next_obs, dones = sample
    tensor_acs = torch.from_numpy(acs).float()
    tensor_obs = torch.from_numpy(obs).float()
    tensor_next_obs = torch.from_numpy(next_obs).float()
    tensor_rews = torch.from_numpy(rews).float()
    # 1.0 while the episode continues, 0.0 once it has terminated.
    tensor_not_dones = torch.from_numpy(~dones).float()

    def obs_of(x, i):
        # Columns of the joint observation tensor `x` that belong to agent i.
        bounds = self.observation_index[i]
        return x[:, bounds[0]:bounds[1]]

    def act_of(x, i):
        # Columns of the joint action tensor `x` that belong to agent i.
        bounds = self.action_index[i]
        return x[:, bounds[0]:bounds[1]]

    def actor_action(pi, x):
        # Output of actor `pi` on `x`, one-hot encoded for discrete actions.
        out = pi(x)
        return onehot_from_logits(out) if self.discrete_action else out

    current_agent = self.agents[agent_i]
    is_maddpg = self.alg_types[agent_i] == 'MADDPG'

    # ---- critic update ----
    current_agent.critic_optimizer.zero_grad()
    if is_maddpg:
        all_target_actors = self.target_actors()
        # Target actions are evaluated on the NEXT observations (previously
        # the current observations were used, corrupting the Bellman target).
        all_target_actions = [
            actor_action(pi, obs_of(tensor_next_obs, agent_index))
            for pi, agent_index in zip(all_target_actors, range(self.num_agent))
        ]
        target_critic_input = torch.cat(
            (tensor_next_obs, torch.cat(all_target_actions, dim=1)), dim=1)
    else:
        own_next_obs = obs_of(tensor_next_obs, agent_i)
        target_critic_input = torch.cat(
            (own_next_obs, actor_action(current_agent.target_actor, own_next_obs)),
            dim=1)

    target_critic_value = current_agent.target_critic(target_critic_input)
    target_value = \
        tensor_rews[:, agent_i].unsqueeze(1) + \
        self.gamma * target_critic_value * tensor_not_dones[:, agent_i].unsqueeze(1)

    if is_maddpg:
        critic_input = torch.cat((tensor_obs, tensor_acs), dim=1)
    else:  # DDPG
        critic_input = torch.cat(
            (obs_of(tensor_obs, agent_i), act_of(tensor_acs, agent_i)), dim=1)
    actual_value = current_agent.critic(critic_input)
    critic_loss = MSELoss(actual_value, target_value.detach())
    critic_loss.backward()
    torch.nn.utils.clip_grad_norm_(current_agent.critic.parameters(), 0.5)
    current_agent.critic_optimizer.step()

    # ---- actor update ----
    current_agent.actor_optimizer.zero_grad()
    current_action_out = current_agent.actor(obs_of(tensor_obs, agent_i))
    if self.discrete_action:
        current_action_input_critic = gumbel_softmax(current_action_out, hard=True)
    else:
        current_action_input_critic = current_action_out

    if is_maddpg:
        all_actor_action = []
        all_target_actors = self.target_actors()
        # NOTE(review): other agents' actions come from their *target* actors
        # here; reference MADDPG code uses the current policies -- confirm.
        for i, pi in zip(range(self.num_agent), all_target_actors):
            if i == agent_i:
                all_actor_action.append(current_action_input_critic)
            else:
                # Only this agent's actor is optimized; don't accumulate
                # gradients into the other agents' target actors.
                with torch.no_grad():
                    all_actor_action.append(actor_action(pi, obs_of(tensor_obs, i)))
        critic_input = torch.cat(
            (tensor_obs, torch.cat(all_actor_action, dim=1)), dim=1)
    else:
        critic_input = torch.cat(
            (obs_of(tensor_obs, agent_i), current_action_input_critic), dim=1)

    actor_loss = -current_agent.critic(critic_input).mean()
    # Small penalty on raw actor outputs (logits / actions).
    actor_loss += (current_action_out ** 2).mean() * 1e-3
    actor_loss.backward()
    torch.nn.utils.clip_grad_norm_(current_agent.actor.parameters(), 0.5)
    current_agent.actor_optimizer.step()
def update(self, sample, agent_i, parallel=False, logger=None, critic_obs_dim=18):
    """
    Update parameters of agent model based on sample from replay buffer
    Inputs:
        sample: tuple of (observations, actions, rewards, next
                observations, and episode end masks) sampled randomly from
                the replay buffer. Each is a list with entries
                corresponding to each agent. Observations and next
                observations carry a history; the leading
                ``critic_obs_dim`` features of each row are used as the
                current observation for the critics.
        agent_i (int): index of agent to update
        parallel (bool): If true, will average gradients across threads
        logger (SummaryWriter from Tensorboard-Pytorch):
                If passed in, important quantities will be logged
        critic_obs_dim (int): number of leading observation features fed to
                the critics (policies still receive the full history).
                Defaults to 18, the value that used to be hard-coded.
    """
    obs, acs, rews, next_obs, dones = sample
    curr_agent = self.agents[agent_i]

    def current_step(batch):
        # The critic value function does not need the history: keep only the
        # leading `critic_obs_dim` features of every agent's observation.
        # Works for any number of agents (previously hard-coded to 3).
        return [ob[:, :critic_obs_dim].float() for ob in batch[:self.nagents]]

    t0_next_obs = current_step(next_obs)
    t0_obs = current_step(obs)

    # ---- critic update ----
    curr_agent.critic_optimizer.zero_grad()
    if self.alg_types[agent_i] == 'MADDPG':
        # Target policies act on the full (history) next observations.
        if self.discrete_action:  # one-hot encode action
            all_trgt_acs = [onehot_from_logits(pi(nobs))
                            for pi, nobs in zip(self.target_policies, next_obs)]
        else:
            all_trgt_acs = [pi(nobs)
                            for pi, nobs in zip(self.target_policies, next_obs)]
        trgt_vf_in = torch.cat((*t0_next_obs, *all_trgt_acs), dim=1)
    else:  # DDPG: only this agent's observation and policy
        trgt_act = curr_agent.target_policy(next_obs[agent_i])
        if self.discrete_action:
            trgt_act = onehot_from_logits(trgt_act)
        # The critic always receives the current-step observation, in both the
        # discrete and continuous cases (the continuous case previously fed
        # the full history, mismatching the critic input below).
        trgt_vf_in = torch.cat((t0_next_obs[agent_i], trgt_act), dim=1)
    target_value = (rews[agent_i].view(-1, 1) + self.gamma *
                    curr_agent.target_critic(trgt_vf_in) *
                    (1 - dones[agent_i].view(-1, 1)))

    if self.alg_types[agent_i] == 'MADDPG':
        vf_in = torch.cat((*t0_obs, *acs), dim=1)
    else:  # DDPG
        vf_in = torch.cat((t0_obs[agent_i], acs[agent_i]), dim=1)
    actual_value = curr_agent.critic(vf_in)
    vf_loss = MSELoss(actual_value, target_value.detach())
    vf_loss.backward()
    if parallel:
        average_gradients(curr_agent.critic)
    torch.nn.utils.clip_grad_norm_(curr_agent.critic.parameters(), 0.5)
    curr_agent.critic_optimizer.step()

    # ---- policy update ----
    curr_agent.policy_optimizer.zero_grad()
    # Policies are forwarded with the full observation (including history).
    curr_pol_out = curr_agent.policy(obs[agent_i])
    if self.discrete_action:
        # Forward pass as if onehot (hard=True) but backprop through a differentiable
        # Gumbel-Softmax sample. The MADDPG paper uses the Gumbel-Softmax trick to backprop
        # through discrete categorical samples, but I'm not sure if that is
        # correct since it removes the assumption of a deterministic policy for
        # DDPG. Regardless, discrete policies don't seem to learn properly without it.
        curr_pol_vf_in = gumbel_softmax(curr_pol_out, hard=True)
    else:
        curr_pol_vf_in = curr_pol_out

    if self.alg_types[agent_i] == 'MADDPG':
        all_pol_acs = []
        for i, pi, ob in zip(range(self.nagents), self.policies, obs):
            if i == agent_i:
                all_pol_acs.append(curr_pol_vf_in)
            elif self.discrete_action:
                all_pol_acs.append(onehot_from_logits(pi(ob)))
            else:
                all_pol_acs.append(pi(ob))
        vf_in = torch.cat((*t0_obs, *all_pol_acs), dim=1)
    else:  # DDPG
        vf_in = torch.cat((t0_obs[agent_i], curr_pol_vf_in), dim=1)
    pol_loss = -curr_agent.critic(vf_in).mean()
    pol_loss += (curr_pol_out ** 2).mean() * 1e-3
    pol_loss.backward()
    if parallel:
        average_gradients(curr_agent.policy)
    torch.nn.utils.clip_grad_norm_(curr_agent.policy.parameters(), 0.5)
    curr_agent.policy_optimizer.step()
    if logger is not None:
        logger.add_scalars('agent%i/losses' % agent_i, {
            'vf_loss': vf_loss,
            'pol_loss': pol_loss
        }, self.niter)