def compute(self, rewards, next_obs):
    with torch.no_grad():
        next_obs = [Variable(torch.Tensor(np.vstack(next_obs[:, i])), requires_grad=False)
                    for i in range(rewards.shape[1])]
        acs_src = []
        prob_src = []
        for no, source in zip(next_obs, self.source):
            acs_src.append(gumbel_softmax(source(no), device=self.source_dev, hard=True))
            prob_src.append(gumbel_softmax(source(no), device=self.source_dev, hard=False))
        trans_in = torch.cat((*next_obs, *acs_src), dim=1)
        trans_out = self.transition(trans_in)
        prob_plan = []
        start = 0
        for i, (no, planning) in enumerate(zip(next_obs, self.planning)):
            length = no.shape[1]
            nno = trans_out[:, start:start + length]
            acs_ = []
            for j, ac in enumerate(acs_src):
                if j != i:
                    acs_.append(ac)
            acs_ = torch.cat(acs_, dim=1)
            plan_in = torch.cat((no, nno, acs_), dim=1)
            prob_plan.append(gumbel_softmax(planning(plan_in), device=self.plan_dev, hard=False))
            start += length  # advance to the next agent's slice of the predicted next obs
        prob_plan = torch.cat(prob_plan, dim=1)
        prob_src = torch.cat(prob_src, dim=1)
        acs_src = torch.cat(acs_src, dim=1)
        E = acs_src * prob_plan - acs_src * prob_src
        i_rews = E.mean() * torch.ones((1, rewards.shape[1]))
        return i_rews.numpy()
def update(self, sample, logger):
    obs, acs, rews, emps, next_obs, dones = sample

    self.transition_optimizer.zero_grad()
    trans_in = torch.cat((*obs, *acs), dim=1)
    next_obs_pred = self.transition(trans_in)
    trans_loss = MSELoss(next_obs_pred, torch.cat(next_obs, dim=1))
    trans_loss.backward()
    self.transition_optimizer.step()

    self.source_optimizer.zero_grad()
    acs_src = []
    prob_src = []
    for no, source in zip(next_obs, self.source):
        acs_src.append(gumbel_softmax(source(no), device=self.source_dev, hard=True))
        prob_src.append(gumbel_softmax(source(no), device=self.source_dev, hard=False))
    with torch.no_grad():
        trans_in = torch.cat((*next_obs, *acs_src), dim=1)
        trans_out = self.transition(trans_in)
    prob_plan = []
    start = 0
    for i, (no, planning) in enumerate(zip(next_obs, self.planning)):
        length = no.shape[1]
        nno = trans_out[:, start:start + length]
        acs_pi = gumbel_softmax(self.agents[i].policy(nno), device=self.source_dev, hard=True)
        plan_in = torch.cat((no, nno, acs_pi), dim=1)
        prob_plan.append(gumbel_softmax(planning(plan_in), device=self.plan_dev, hard=False))
        start += length
    prob_plan = torch.cat(prob_plan, dim=1)
    prob_src = torch.cat(prob_src, dim=1)
    acs_src = torch.cat(acs_src, dim=1)
    E = acs_src * prob_plan - acs_src * prob_src
    i_rews = -E.mean()
    i_rews.backward()
    self.source_optimizer.step()

    if logger is not None:
        logger.add_scalars('empowerment/losses',
                           {'trans_loss': trans_loss.detach(),
                            'i_rews': i_rews.detach()},
                           self.niter)
    self.niter += 1
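# Hedged sketch: because `acs_src` is one-hot, the term
#   E = acs_src * prob_plan - acs_src * prob_src
# used by `compute`/`update` above just reads off, per sample, the planning-head and
# source-head probabilities assigned to the action that was actually sampled. NumPy-only
# illustration; the array values below are made up for demonstration.
import numpy as np

acs_src = np.array([[0., 1., 0.]])        # one-hot sampled action (B=1, A=3)
prob_plan = np.array([[0.2, 0.7, 0.1]])   # planning head probabilities
prob_src = np.array([[0.5, 0.3, 0.2]])    # source head probabilities

E = acs_src * prob_plan - acs_src * prob_src
print(E)         # [[0.  0.4 0. ]] -> only the chosen action's probability gap survives
print(E.mean())  # the scalar that `compute` broadcasts as the intrinsic reward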
def step(self, obs, explore=False):
    """
    Take a step forward in environment for a minibatch of observations;
    equivalent to `act` or `compute_actions`.
    Arguments:
        obs: (B,O)
        explore: Whether or not to add exploration noise
    Returns:
        action: dict of actions for this agent, (B,A)
    """
    with torch.no_grad():
        action, hidden_states = self.policy(obs, self.policy_hidden_states)
    self.policy_hidden_states = hidden_states  # if mlp, still defaults to None
    if self.discrete_action:
        for k in action:
            if explore:
                action[k] = gumbel_softmax(action[k], hard=True)
            else:
                action[k] = onehot_from_logits(action[k])
    else:  # continuous action
        idx = 0
        noise = Variable(Tensor(self.exploration.noise()), requires_grad=False)
        for k in action:
            if explore:
                dim = action[k].shape[-1]
                action[k] += noise[idx:idx + dim]
                idx += dim
            action[k] = action[k].clamp(-1, 1)
    return action
def step(self, obs, explore=False):
    """
    Take a step forward in environment for a minibatch of observations
    Inputs:
        obs (PyTorch Variable): Observations for this agent
        explore (boolean): Whether or not to add exploration noise
    Outputs:
        action (PyTorch Variable): Actions for this agent
    """
    action = self.policy(obs)
    if self.discrete_action:
        if explore:
            action = gumbel_softmax(action, hard=True)
        else:
            action = onehot_from_logits(action)
    else:  # continuous action
        if explore:
            action += Variable(Tensor(self.exploration.noise()), requires_grad=False)
        action = action.clamp(-1, 1)
    return action
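# Hedged sketch: `onehot_from_logits` is a repo utility imported by the snippets above;
# a minimal stand-in is given here, assuming it simply one-hot encodes the argmax of the
# logits (some implementations also add optional epsilon-greedy behaviour, omitted here).
import torch

def onehot_from_logits_minimal(logits):
    """Return a one-hot tensor selecting the argmax of `logits` along the last dim."""
    argmax = logits.argmax(dim=-1, keepdim=True)
    return torch.zeros_like(logits).scatter_(-1, argmax, 1.0)

# Example: logits of shape (B, A)
logits = torch.tensor([[0.1, 2.0, -1.0]])
print(onehot_from_logits_minimal(logits))  # tensor([[0., 1., 0.]])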
def _soft_act(x):  # x: (B,A)
    # Presumably a nested helper: `self` and `requires_grad` are captured from the
    # enclosing method's scope (compare the bound-method variant below).
    if not self.discrete_action:
        return x
    if requires_grad:
        return gumbel_softmax(x, hard=True)
    else:
        return onehot_from_logits(x)
def _soft_act(self, x, requires_grad=True):
    """ soften action if discrete, x: (B,A) """
    if not self.discrete_action:
        return x
    if requires_grad:
        return gumbel_softmax(x, hard=True)
    else:
        return onehot_from_logits(x)
def compute(self, next_obs):
    acs_src = []
    prob_src = []
    for no, source in zip(next_obs, self.source):
        acs_src.append(gumbel_softmax(source(no), device=self.device.get_device(), hard=True))
        prob_src.append(gumbel_softmax(source(no), device=self.device.get_device(), hard=False))
    trans_in = torch.cat((*next_obs, *acs_src), dim=1)
    trans_out = self.transition(trans_in)
    prob_plan = []
    end_idx = [0] + np.cumsum([ne_ob.shape[1] for ne_ob in next_obs]).tolist()
    start_end = [(start, end) for start, end in zip(end_idx, end_idx[1:])]
    for i, (no, planning) in enumerate(zip(next_obs, self.planning)):
        nno = trans_out[:, start_end[i][0]:start_end[i][1]]
        acs_ = []
        for j, ac in enumerate(acs_src):
            if j == i:
                continue
            nno_other = trans_out[:, start_end[j][0]:start_end[j][1]]
            acs_.append(gumbel_softmax(self.agents[j].policy(nno_other),
                                       device=self.device.get_device(), hard=True))
        acs_ = torch.cat(acs_, dim=1)
        plan_in = torch.cat((no, nno, acs_), dim=1)
        prob_plan.append(gumbel_softmax(planning(plan_in),
                                        device=self.device.get_device(), hard=False))
    prob_plan = torch.cat(prob_plan, dim=1)
    prob_src = torch.cat(prob_src, dim=1)
    acs_src = torch.cat(acs_src, dim=1)
    return acs_src * prob_plan - acs_src * prob_src
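# Hedged sketch of the `end_idx` / `start_end` bookkeeping used by `compute` above:
# cumulative per-agent observation widths give the column ranges of each agent inside
# the concatenated transition output. Standalone illustration with made-up widths.
import numpy as np

obs_widths = [4, 6, 3]  # e.g. per-agent observation sizes
end_idx = [0] + np.cumsum(obs_widths).tolist()
start_end = [(start, end) for start, end in zip(end_idx, end_idx[1:])]
print(start_end)  # [(0, 4), (4, 10), (10, 13)]
# trans_out[:, start_end[i][0]:start_end[i][1]] then recovers agent i's predicted next obs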
def compute(self, rewards, next_obs):
    with torch.no_grad():
        next_obs = [Variable(torch.Tensor(np.vstack(next_obs[:, i])), requires_grad=False)
                    for i in range(rewards.shape[1])]
        acs_src = []
        prob_src = []
        for no in next_obs:
            acs_src.append(gumbel_softmax(self.source(no), device=self.device.get_device(), hard=True))
            prob_src.append(gumbel_softmax(self.source(no), device=self.device.get_device(), hard=False))
        trans_in = torch.cat((*next_obs, *acs_src), dim=1)
        trans_out = self.transition(trans_in)
        prob_plan = []
        n_obs = len(next_obs[0][0])
        for i, no in enumerate(next_obs):
            nno = trans_out[:, i * n_obs:(i + 1) * n_obs]
            plan_in = torch.cat((no, nno), dim=1)
            prob_plan.append(gumbel_softmax(self.planning(plan_in),
                                            device=self.device.get_device(), hard=False))
        prob_plan = torch.cat(prob_plan, dim=1)
        prob_src = torch.cat(prob_src, dim=1)
        acs_src = torch.cat(acs_src, dim=1)
        E = acs_src * prob_plan - acs_src * prob_src
        i_rews = E.mean() * torch.ones((1, rewards.shape[1]))
        return i_rews.numpy()
def get_actions(self, obs, noise=True, batch=False, hard=False):
    acts = self.ac_update.get_action(obs, batch)
    if self.args.discrete_action:
        if noise:
            assert batch is False
            acts = gumbel_softmax(acts, hard).cpu().detach().numpy().squeeze()
        else:
            acts = onehot(acts)
    else:
        acts = acts.cpu().detach().numpy().squeeze()
        if noise:
            assert batch is False
            acts = self.noise.get_action(acts)
    return acts
def cast_embedding(self, emb, explore=True):
    # The terminology here is a bit misleading: explore==True is used for roll-outs
    # (exploring the embedding space with Boltzmann exploration) and for selecting an
    # unbiased, backpropagable action (in contrast to a backpropagable argmax, which
    # would be biased because it is always the mode of the distribution rather than the
    # mean, unlike a Gaussian whose mode equals its mean). explore==False is only used
    # at evaluation. We could imagine having three cases instead (sketched below):
    #   1. epsilon-greedy exploration (or a tunable temperature) for roll-outs
    #   2. gumbel_softmax for backprop
    #   3. argmax for evaluation
    if explore:
        emb = gumbel_softmax(emb, hard=True)
    else:
        emb = differentiable_onehot_from_logits(emb)
    return emb
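# Hedged sketch: the comment above imagines splitting `cast_embedding` into three modes
# (tunable-temperature sampling for roll-outs, Gumbel-Softmax for backprop, argmax for
# evaluation). One possible shape for that, using PyTorch's built-in
# torch.nn.functional.gumbel_softmax as a stand-in for the repo's own helper; the
# function and mode names are illustrative, not the repo's API.
import torch
import torch.nn.functional as F

def cast_embedding_three_modes(emb_logits, mode="rollout", temperature=1.0):
    if mode == "rollout":      # Boltzmann exploration at a tunable temperature
        probs = F.softmax(emb_logits / temperature, dim=-1)
        idx = torch.multinomial(probs, num_samples=1)
        return torch.zeros_like(emb_logits).scatter_(-1, idx, 1.0)
    elif mode == "backprop":   # differentiable straight-through sample
        return F.gumbel_softmax(emb_logits, tau=temperature, hard=True)
    elif mode == "eval":       # deterministic argmax
        idx = emb_logits.argmax(dim=-1, keepdim=True)
        return torch.zeros_like(emb_logits).scatter_(-1, idx, 1.0)
    raise ValueError(mode)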
def update(self, sample, agent_i, parallel=False, logger=None):
    """
    Update parameters of agent model based on sample from replay buffer
    Inputs:
        sample: tuple of (observations, actions, rewards, next observations, and episode
                end masks) sampled randomly from the replay buffer. Each is a list with
                entries corresponding to each agent
        agent_i (int): index of agent to update
        parallel (bool): If true, will average gradients across threads
        logger (SummaryWriter from Tensorboard-Pytorch):
            If passed in, important quantities will be logged
    """
    obs, acs, rews, next_obs, dones = sample
    curr_agent = self.agents[agent_i]

    curr_agent.critic_optimizer.zero_grad()
    if self.alg_types[agent_i] == 'MADDPG':
        if self.discrete_action:  # one-hot encode action
            all_trgt_acs = [onehot_from_logits(pi(nobs)) for pi, nobs in
                            zip(self.target_policies, next_obs)]  # a' = mu'(o') from all agents
        else:
            all_trgt_acs = [pi(nobs) for pi, nobs in
                            zip(self.target_policies, next_obs)]
        # ========================== Adding noise ==========================
        if self.noisy_sharing == True:
            # noisy_all_trgt_acs = self.noisy_sharing_discrete(all_trgt_acs, agent_i)
            # all_trgt_acs = noisy_all_trgt_acs
            noisy_acs = self.noisy_sharing_discrete(acs, agent_i)
            acs = noisy_acs
        # ====================== End of Adding noise =======================
        trgt_vf_in = torch.cat((*next_obs, *all_trgt_acs), dim=1)
        # ========================= Differential Obs =======================
        # ============== Dedicated for simple_speaker_listener =============
        # The est_action is used to replace acs[1]
        if self.game_id == 'simple_speaker_listener' and self.est_ac == True:
            diff_pos = (next_obs[0] - obs[0])[:, -2:]
            tmp_p = torch.transpose(diff_pos.ge(torch.max(diff_pos) * 0.8), 0, 1)
            tmp_p[0] = tmp_p[0] * 1
            tmp_p[1] = tmp_p[1] * 3
            tmp_n = torch.transpose(diff_pos.le(torch.min(diff_pos) * 0.8), 0, 1)
            tmp_n[0] = tmp_n[0] * 2
            tmp_n[1] = tmp_n[1] * 4
            mask = torch.transpose(tmp_p, 0, 1) + torch.transpose(tmp_n, 0, 1)
            est_action = mask.sum(dim=1)
            est_action = torch.zeros(len(est_action), acs[1].shape[1]).scatter_(
                dim=1, index=est_action.view(-1, 1), value=1)
            acs[1] = est_action
        # ====================== End of differential Obs ===================
    else:  # DDPG
        if self.discrete_action:
            trgt_vf_in = torch.cat(
                (next_obs[agent_i],
                 onehot_from_logits(curr_agent.target_policy(next_obs[agent_i]))), dim=1)
        else:  # a' = mu'(o'), only the current agent's action
            trgt_vf_in = torch.cat(
                (next_obs[agent_i], curr_agent.target_policy(next_obs[agent_i])), dim=1)
    target_value = (rews[agent_i].view(-1, 1) +
                    self.gamma * curr_agent.target_critic(trgt_vf_in) *
                    (1 - dones[agent_i].view(-1, 1)))  # y^j

    if self.alg_types[agent_i] == 'MADDPG':
        vf_in = torch.cat((*obs, *acs), dim=1)
    else:  # DDPG
        vf_in = torch.cat((obs[agent_i], acs[agent_i]), dim=1)
    actual_value = curr_agent.critic(vf_in)
    vf_loss = MSELoss(actual_value, target_value.detach())
    vf_loss.backward()
    if parallel:
        average_gradients(curr_agent.critic)
    torch.nn.utils.clip_grad_norm_(curr_agent.critic.parameters(), 0.5)
    curr_agent.critic_optimizer.step()

    # ================= Here for policy network training ===================
    curr_agent.policy_optimizer.zero_grad()
    if self.discrete_action:
        # Forward pass as if onehot (hard=True) but backprop through a differentiable
        # Gumbel-Softmax sample. The MADDPG paper uses the Gumbel-Softmax trick to backprop
        # through discrete categorical samples, but I'm not sure if that is
        # correct since it removes the assumption of a deterministic policy for
        # DDPG. Regardless, discrete policies don't seem to learn properly without it.
        curr_pol_out = curr_agent.policy(obs[agent_i])
        curr_pol_vf_in = gumbel_softmax(curr_pol_out, hard=True)
    else:
        curr_pol_out = curr_agent.policy(obs[agent_i])
        curr_pol_vf_in = curr_pol_out
    if self.alg_types[agent_i] == 'MADDPG':
        all_pol_acs = []
        for i, pi, ob in zip(range(self.nagents), self.policies, obs):
            # Is it correct to train mu using all others' policies???
            if i == agent_i:
                all_pol_acs.append(curr_pol_vf_in)
            elif self.discrete_action:
                all_pol_acs.append(onehot_from_logits(pi(ob)))
            else:
                all_pol_acs.append(pi(ob))
        vf_in = torch.cat((*obs, *all_pol_acs), dim=1)
    else:  # DDPG
        vf_in = torch.cat((obs[agent_i], curr_pol_vf_in), dim=1)
    pol_loss = -curr_agent.critic(vf_in).mean()
    pol_loss += (curr_pol_out ** 2).mean() * 1e-3
    pol_loss.backward()
    if parallel:
        average_gradients(curr_agent.policy)
    torch.nn.utils.clip_grad_norm_(curr_agent.policy.parameters(), 0.5)  # constrain the gradient
    curr_agent.policy_optimizer.step()

    if logger is not None:
        logger.add_scalars('agent%i/losses' % agent_i,
                           {'vf_loss': vf_loss, 'pol_loss': pol_loss},
                           self.niter)
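# Hedged sketch of the straight-through behaviour the comment above relies on: with
# hard=True the forward value is one-hot, but gradients flow through the underlying soft
# sample. PyTorch's built-in F.gumbel_softmax is used here as a stand-in for the repo's
# own `gumbel_softmax` helper.
import torch
import torch.nn.functional as F

logits = torch.randn(4, 5, requires_grad=True)
hard_sample = F.gumbel_softmax(logits, tau=1.0, hard=True)
print(hard_sample.sum(dim=-1))       # ones: each row is exactly one-hot in the forward pass

weights = torch.arange(5, dtype=torch.float32)
loss = (hard_sample * weights).sum()
loss.backward()
print(logits.grad.abs().sum() > 0)   # True: gradients reach the logits via the soft relaxation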
def update(self, sample, agent_i):
    """
    TODO: to make the sample into valid inputs
    :param sample: the batch of experiences
    :param agent_i: the agent to be updated
    """
    obs, acs, rews, next_obs, dones = sample
    tensor_acs = torch.from_numpy(acs).float()
    tensor_obs = torch.from_numpy(obs).float()
    tensor_next_obs = torch.from_numpy(next_obs).float()
    tensor_rews = torch.from_numpy(rews).float()
    tensor_dones = torch.from_numpy(~dones).float()

    current_agent = self.agents[agent_i]
    current_agent.critic_optimizer.zero_grad()
    if self.alg_types[agent_i] == 'MADDPG':
        all_target_actors = self.target_actors()
        if self.discrete_action:
            all_target_actions = [
                onehot_from_logits(pi(tensor_next_obs[:, self.observation_index[agent_index][0]:
                                                         self.observation_index[agent_index][1]]))
                for pi, agent_index in zip(all_target_actors, range(self.num_agent))]
        else:
            all_target_actions = [
                pi(tensor_next_obs[:, self.observation_index[agent_index][0]:
                                      self.observation_index[agent_index][1]])
                for pi, agent_index in zip(all_target_actors, range(self.num_agent))]
        target_critic_input = torch.cat(
            (tensor_next_obs, torch.cat(all_target_actions, dim=1)), dim=1)
    else:
        if self.discrete_action:
            target_critic_input = torch.cat(
                (tensor_next_obs[:, self.observation_index[agent_i][0]:
                                    self.observation_index[agent_i][1]],
                 onehot_from_logits(current_agent.target_actor(
                     tensor_next_obs[:, self.observation_index[agent_i][0]:
                                        self.observation_index[agent_i][1]]))), dim=1)
        else:
            target_critic_input = torch.cat(
                (tensor_next_obs[:, self.observation_index[agent_i][0]:
                                    self.observation_index[agent_i][1]],
                 current_agent.target_actor(
                     tensor_next_obs[:, self.observation_index[agent_i][0]:
                                        self.observation_index[agent_i][1]])), dim=1)
    target_critic_value = current_agent.target_critic(target_critic_input)
    target_value = (tensor_rews[:, agent_i].unsqueeze(1) +
                    self.gamma * target_critic_value *
                    tensor_dones[:, agent_i].unsqueeze(1))

    if self.alg_types[agent_i] == 'MADDPG':
        critic_input = torch.cat((tensor_obs, tensor_acs), dim=1)
    else:  # DDPG
        critic_input = torch.cat(
            (tensor_obs[:, self.observation_index[agent_i][0]:
                           self.observation_index[agent_i][1]],
             tensor_acs[:, self.action_index[agent_i][0]:
                           self.action_index[agent_i][1]]), dim=1)
    actual_value = current_agent.critic(critic_input)
    critic_loss = MSELoss(actual_value, target_value.detach())
    critic_loss.backward()
    torch.nn.utils.clip_grad_norm_(current_agent.critic.parameters(), 0.5)
    current_agent.critic_optimizer.step()

    current_agent.actor_optimizer.zero_grad()
    if self.discrete_action:
        current_action_out = current_agent.actor(
            tensor_obs[:, self.observation_index[agent_i][0]:
                          self.observation_index[agent_i][1]])
        current_action_input_critic = gumbel_softmax(current_action_out, hard=True)
    else:
        current_action_out = current_agent.actor(
            tensor_obs[:, self.observation_index[agent_i][0]:
                          self.observation_index[agent_i][1]])
        current_action_input_critic = current_action_out

    if self.alg_types[agent_i] == 'MADDPG':
        all_actor_action = []
        all_target_actors = self.target_actors()
        for i, pi in zip(range(self.num_agent), all_target_actors):
            if i == agent_i:
                all_actor_action.append(current_action_input_critic)
            else:
                if self.discrete_action:
                    all_actor_action.append(
                        onehot_from_logits(all_target_actors[i](
                            tensor_obs[:, self.observation_index[i][0]:
                                          self.observation_index[i][1]])))
                else:
                    all_actor_action.append(all_target_actors[i](
                        tensor_obs[:, self.observation_index[i][0]:
                                      self.observation_index[i][1]]))
        critic_input = torch.cat(
            (tensor_obs, torch.cat(all_actor_action, dim=1)), dim=1)
    else:
        critic_input = torch.cat(
            (tensor_obs[:, self.observation_index[agent_i][0]:
                           self.observation_index[agent_i][1]],
             current_action_input_critic), dim=1)
    actor_loss = -current_agent.critic(critic_input).mean()
    actor_loss += (current_action_out ** 2).mean() * 1e-3
    actor_loss.backward()
    torch.nn.utils.clip_grad_norm_(current_agent.actor.parameters(), 0.5)
    current_agent.actor_optimizer.step()
def update(self, sample, logger):
    obs, acs, rews, emps, next_obs, dones = sample

    self.transition_optimizer.zero_grad()
    trans_in = torch.cat((*obs, *acs), dim=1)
    next_obs_pred = self.transition(trans_in)
    trans_loss = MSELoss(next_obs_pred, torch.cat(next_obs, dim=1))
    trans_loss.backward()
    self.transition_optimizer.step()

    self.planning_optimizer.zero_grad()
    acs_plan = []
    for o, no in zip(obs, next_obs):
        plan_in = torch.cat((o, no), dim=1)
        acs_plan.append(gumbel_softmax(self.planning(plan_in),
                                       device=self.device.get_device(), hard=True))
    acs_plan = torch.cat(acs_plan, dim=1)
    acs_torch = torch.cat(acs, dim=1)
    plan_loss = MSELoss(acs_plan, acs_torch)
    plan_loss.backward()
    self.planning_optimizer.step()

    self.source_optimizer.zero_grad()
    acs_src = []
    prob_src = []
    for no in next_obs:
        acs_src.append(gumbel_softmax(self.source(no), device=self.device.get_device(), hard=True))
        prob_src.append(gumbel_softmax(self.source(no), device=self.device.get_device(), hard=False))
    with torch.no_grad():
        trans_in = torch.cat((*next_obs, *acs_src), dim=1)
        trans_out = self.transition(trans_in)
    prob_plan = []
    n_obs = len(next_obs[0][0])
    for i, no in enumerate(next_obs):
        nno = trans_out[:, i * n_obs:(i + 1) * n_obs]
        plan_in = torch.cat((no, nno), dim=1)
        prob_plan.append(gumbel_softmax(self.planning(plan_in),
                                        device=self.device.get_device(), hard=False))
    prob_plan = torch.cat(prob_plan, dim=1)
    prob_src = torch.cat(prob_src, dim=1)
    acs_src = torch.cat(acs_src, dim=1)
    E = acs_src * prob_plan - acs_src * prob_src
    i_rews = -E.mean()
    i_rews.backward()
    self.source_optimizer.step()

    if logger is not None:
        logger.add_scalars('empowerment/losses',
                           {'trans_loss': trans_loss.detach(),
                            'plan_loss': plan_loss.detach(),
                            'i_rews': i_rews.detach()},
                           self.niter)
    self.niter += 1
def cast_embedding(self, emb, explore=True):
    if explore:
        emb = gumbel_softmax(emb, hard=True)
    else:
        emb = differentiable_onehot_from_logits(emb)
    return emb
def update(self, sample, agent_i, parallel=False, grad_norm=0.5):
    """
    Update parameters of agent model based on sample from replay buffer
    Inputs:
        sample: EpisodeBatch, use sample[key_i] to get a specific
                array of obs, action, etc for agent i
        agent_i (int): index of agent to update
        parallel (bool): If true, will average gradients across threads
    """
    obs, acs, rews, next_obs, dones = self.parse_sample(sample)  # [(B,T,D)]*N
    bs, ts, _ = obs[0].shape
    self.init_hidden(bs)  # use pre-defined init hiddens
    curr_agent = self.agents[agent_i]

    # NOTE: critic update
    curr_agent.critic_optimizer.zero_grad()

    # compute target actions
    if self.alg_types[agent_i] == 'MADDPG':
        all_trgt_acs = []  # [(B,T,A)]*N
        for i, (pi, nobs) in enumerate(zip(self.target_policies, next_obs)):
            # nobs: (B,T,O)
            act_i = self.compute_action(i, pi, nobs, bs=bs, ts=ts)
            all_trgt_acs.append(act_i)  # [(B,T,A)]
        if self.discrete_action:  # one-hot encode action
            all_trgt_acs = [onehot_from_logits(act_i.reshape(bs * ts, -1)).reshape(bs, ts, -1)
                            for act_i in all_trgt_acs]
        # critic input, [(B,T,O)_i, ..., (B,T,A)_i, ...] -> (B,T,O*N+A*N)
        trgt_vf_in = torch.cat((*next_obs, *all_trgt_acs), dim=-1)
    else:  # DDPG
        act_i = self.compute_action(agent_i, curr_agent.target_policy,
                                    next_obs[agent_i], bs=bs, ts=ts)
        if self.discrete_action:
            act_i = onehot_from_logits(act_i.reshape(bs * ts, -1)).reshape(bs, ts, -1)
        # (B,T,O) + (B,T,A) -> (B,T,O+A)
        trgt_vf_in = torch.cat((next_obs[agent_i], act_i), dim=-1)

    # bellman targets
    target_q = self.compute_q_val(agent_i, curr_agent.target_critic,
                                  trgt_vf_in, bs=bs, ts=ts)  # (B*T,1)
    target_value = (rews[agent_i].view(-1, 1) + self.gamma * target_q *
                    (1 - dones[agent_i].view(-1, 1)))  # (B*T,1)

    # Q func
    if self.alg_types[agent_i] == 'MADDPG':
        vf_in = torch.cat((*obs, *acs), dim=1)
    else:  # DDPG
        vf_in = torch.cat((obs[agent_i], acs[agent_i]), dim=1)
    actual_value = self.compute_q_val(agent_i, curr_agent.critic,
                                      vf_in, bs=bs, ts=ts)  # (B*T,1)

    # bellman errors
    vf_loss = MSELoss(actual_value, target_value.detach())
    vf_loss.backward()
    if parallel:
        average_gradients(curr_agent.critic)
    if grad_norm > 0:
        torch.nn.utils.clip_grad_norm(curr_agent.critic.parameters(), grad_norm)
    curr_agent.critic_optimizer.step()

    # NOTE: policy update
    curr_agent.policy_optimizer.zero_grad()

    # current agent action (deterministic, softened)
    curr_pol_out = self.compute_action(agent_i, curr_agent.policy,
                                       obs[agent_i], bs=bs, ts=ts)  # (B,T,A)
    if self.discrete_action:
        # Forward pass as if onehot (hard=True) but backprop through a differentiable
        # Gumbel-Softmax sample. The MADDPG paper uses the Gumbel-Softmax trick to backprop
        # through discrete categorical samples, but I'm not sure if that is
        # correct since it removes the assumption of a deterministic policy for
        # DDPG. Regardless, discrete policies don't seem to learn properly without it.
        curr_pol_vf_in = gumbel_softmax(curr_pol_out.reshape(bs * ts, -1),
                                        hard=True).reshape(bs, ts, -1)
    else:
        curr_pol_vf_in = curr_pol_out

    if self.alg_types[agent_i] == 'MADDPG':
        all_pol_acs = []
        for i, pi, ob in zip(range(self.nagents), self.policies, obs):
            if i == agent_i:
                # insert current agent act to q input
                all_pol_acs.append(curr_pol_vf_in)
            else:
                p_act_i = self.compute_action(i, pi, ob, bs=bs, ts=ts)  # (B,T,A)
                if self.discrete_action:
                    p_act_i = onehot_from_logits(p_act_i.reshape(bs * ts, -1)).reshape(bs, ts, -1)
                all_pol_acs.append(p_act_i)
        p_vf_in = torch.cat((*obs, *all_pol_acs), dim=-1)  # (B,T,O*N+A*N)
    else:  # DDPG
        p_vf_in = torch.cat((obs[agent_i], curr_pol_vf_in), dim=-1)  # (B,T,O+A)

    # value function to update current policy
    p_value = self.compute_q_val(agent_i, curr_agent.critic,
                                 p_vf_in, bs=bs, ts=ts)  # (B*T,1)
    pol_loss = -p_value.mean()
    # p regularization, scale down output (gaussian mean,std or logits)
    # reference: https://github.com/openai/maddpg/blob/master/maddpg/trainer/maddpg.py
    pol_loss += ((curr_pol_out.reshape(bs * ts, -1)) ** 2).mean() * 1e-3
    pol_loss.backward()
    if parallel:
        average_gradients(curr_agent.policy)
    if grad_norm > 0:
        torch.nn.utils.clip_grad_norm(curr_agent.policy.parameters(), grad_norm)
    curr_agent.policy_optimizer.step()

    # collect training stats
    results = {
        "agent_{}_critic_loss".format(agent_i): vf_loss,
        "agent_{}_policy_loss".format(agent_i): pol_loss
    }
    return results
def update_coach(self, sample):
    if any(['Coach' in alg for alg in self.alg_types]):
        observations, actions, rewards, next_obs, dones = sample

        # Computes coach embedding
        coach_embed_logits = self.coach.model(torch.cat((*observations, ), dim=1))
        coach_embed = gumbel_softmax(coach_embed_logits, hard=True)

        ## EMBEDDING MATCHING REGULARIZATION
        # Computes agents embeddings regularization
        J_E = 0
        for i, pi, ob in zip(range(self.nagents), self.policies, observations):
            if "Coach" in self.alg_types[i]:
                _, agent_embed_logits = pi(ob, return_embed_logits=True)
                J_E += self.coach.get_regularization_loss(coach_embed_logits,
                                                          agent_embed_logits)
        J_E = J_E / self.nagents

        ## POLICY GRADIENT WITH EMBEDDING REGULARIZATION
        # Gets actions of all agents when computed from the coach-embedding (coordinated actions)
        all_pol_acs = []
        for i, pi, ob in zip(range(self.nagents), self.policies, observations):
            if "Coach" in self.alg_types[i]:
                if self.use_discrete_action:
                    # we need this trick to be able to differentiate
                    all_pol_acs.append(differentiable_onehot_from_logits(
                        pi.partial_forward(ob, coach_embed)))
                else:
                    all_pol_acs.append(pi.partial_forward(ob, coach_embed))

        # Gets evaluations from all critics
        vf_in = torch.cat((*observations, *all_pol_acs), dim=1)
        all_critics_eval = []
        for i, critic in enumerate(self.critics):
            if "Coach" in self.alg_types[i]:
                all_critics_eval.append(critic(vf_in))
        J_PGE = -torch.mean(torch.stack(all_critics_eval).squeeze())

        ## BACKPROP: we backprop in two steps because the agents and the coach
        ## do not have the same weighting
        i = 0
        for loss, lam in zip([J_E, J_PGE], [self.lambdac_1, self.lambdac_2]):
            # Resets gradient buffers
            self.coach.optimizer.zero_grad()
            for agent in self.agents:
                agent.policy.zero_grad()
            loss.backward(retain_graph=i == 0)

            # Apply coach update to coach
            if i == 0:
                multiply_gradient(self.coach.model.parameters(), self.lambdac_3)
                torch.nn.utils.clip_grad_norm_(self.coach.model.parameters(),
                                               self.grad_clip_value)
                self.coach.optimizer.step()

            # Apply coach update to all agents
            for i, agent in enumerate(self.agents):
                if "Coach" in self.alg_types[i]:
                    multiply_gradient(agent.policy.parameters(), lam * self.nagents)
                    torch.nn.utils.clip_grad_norm_(agent.policy.parameters(),
                                                   self.grad_clip_value)
                    agent.policy_optimizer.step()
            i += 1

        return J_E.data.cpu().numpy(), J_PGE.data.cpu().numpy()
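# Hedged sketch: `multiply_gradient` is a repo helper used by `update_coach` above; from
# how it is called here (and from the manual `p.grad *= coeff` loop in the TeamMADDPG
# update further down) it presumably just rescales accumulated gradients in place.
# A minimal stand-in, not the repo's implementation:
import torch

def multiply_gradient_minimal(parameters, coeff):
    """Scale the .grad of every parameter in place (skips params with no grad)."""
    for p in parameters:
        if p.grad is not None:
            p.grad.mul_(coeff)

# Usage sketch: after loss.backward(), rescale before optimizer.step()
layer = torch.nn.Linear(3, 2)
layer(torch.randn(4, 3)).sum().backward()
multiply_gradient_minimal(layer.parameters(), 0.5)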
def update_agent(self, sample, agent_i):
    """
    Update parameters of agent model based on sample from replay buffer
    Inputs:
        sample: tuple of (observations, actions, rewards, next_observations, and
                episode_end masks) sampled randomly from the replay buffer.
                Each is a list with entries corresponding to each agent
        agent_i (int): index of agent to update
        parallel (bool): If true, will average gradients across threads
    """
    # Extract info and agent
    observations, actions, rewards, next_obs, dones = sample
    curr_agent = self.agents[agent_i]

    # UPDATES THE CRITIC ---

    # Resets gradient buffer
    curr_agent.critic_optimizer.zero_grad()
    curr_agent.f_e_optimizer.zero_grad()

    # Gets target state-action pair (next_obs, next_actions)
    if self.alg_types[agent_i] in ['TeamMADDPG', 'MADDPG']:
        if self.use_discrete_action:  # one-hot encode action
            all_target_actions = [
                onehot_from_logits(pi(f_e(nobs)).view(-1, self.action_spaces[0],
                                                      self.nagents)[:, :, i])
                for i, (pi, f_e, nobs) in enumerate(zip(self.target_policies,
                                                        self.f_es, next_obs))]
        else:
            all_target_actions = [
                pi(f_e(nobs)).view(-1, self.action_spaces[0], self.nagents)[:, :, i]
                for i, (pi, f_e, nobs) in enumerate(zip(self.target_policies,
                                                        self.f_es, next_obs))]
        if self.critic_concat_all_obs:
            target_vf_in = torch.cat((*next_obs, *all_target_actions), dim=1)
        else:
            target_vf_in = torch.cat((curr_agent.f_e(next_obs[agent_i]),
                                      *all_target_actions), dim=1)
    elif self.alg_types[agent_i] == 'DDPG':
        next_fe = curr_agent.f_e(next_obs[agent_i])
        if self.use_discrete_action:
            target_vf_in = torch.cat(
                (next_fe, onehot_from_logits(curr_agent.target_policy(next_fe))), dim=1)
        else:
            target_vf_in = torch.cat((next_fe, curr_agent.target_policy(next_fe)), dim=1)
    else:
        raise NotImplementedError

    # Computes target value
    target_value = (rewards[agent_i].view(-1, 1) +
                    self.gamma * curr_agent.target_critic(target_vf_in) *
                    (1 - dones.view(-1, 1)))

    # Computes current state-action value
    if self.alg_types[agent_i] in ['TeamMADDPG', 'MADDPG']:
        if self.critic_concat_all_obs:
            critic_obs = [agent.f_e(obs) for obs, agent in zip(observations, self.agents)]
        else:
            critic_obs = [curr_agent.f_e(observations[agent_i])]
        vf_in = torch.cat((*critic_obs, *actions), dim=1)
    elif self.alg_types[agent_i] == 'DDPG':
        critic_obs = curr_agent.f_e(observations[agent_i])
        vf_in = torch.cat((critic_obs, actions[agent_i]), dim=1)
    else:
        raise NotImplementedError
    actual_value = curr_agent.critic(vf_in)

    # Backpropagates
    vf_loss = MSELoss(actual_value, target_value.detach())
    # we have to retain the graph because we reuse the critic_obs for the following actor loss
    vf_loss.backward(retain_graph=True)  # TODO: make sure there is no leakage between the two losses

    # Clip gradients
    torch.nn.utils.clip_grad_norm_(curr_agent.critic.parameters(), self.grad_clip_value)

    # Apply critic update
    curr_agent.critic_optimizer.step()

    # UPDATES THE ACTOR ---

    # Resets gradient buffer
    curr_agent.policy_optimizer.zero_grad()

    if self.use_discrete_action:
        # Forward pass as if onehot (hard=True) but backprop through a differentiable
        # Gumbel-Softmax sample. The MADDPG paper uses the Gumbel-Softmax trick to backprop
        # through discrete categorical samples, but I'm not sure if that is
        # correct since it removes the assumption of a deterministic policy for
        # DDPG. Regardless, discrete policies don't seem to learn properly without it.
        pol_out_all_heads = curr_agent.policy(
            curr_agent.f_e(observations[agent_i])).view(-1, self.action_spaces[0], self.nagents)
        curr_pol_out = pol_out_all_heads[:, :, agent_i]
        curr_pol_vf_in = gumbel_softmax(curr_pol_out, hard=True)
    else:
        pol_out_all_heads = curr_agent.policy(
            curr_agent.f_e(observations[agent_i])).view(-1, self.action_spaces[0], self.nagents)
        curr_pol_out = pol_out_all_heads[:, :, agent_i]
        curr_pol_vf_in = curr_pol_out  # No Gumbel-softmax for continuous control

    # Gets state-action pair value given by the critic
    if self.alg_types[agent_i] in ['TeamMADDPG', 'MADDPG']:
        all_pol_logits = []
        all_pol_acs = []
        for i, pi, f_e, ob in zip(range(self.nagents), self.policies, self.f_es, observations):
            if i == agent_i:
                all_pol_logits.append(curr_pol_out)
                all_pol_acs.append(curr_pol_vf_in)
            elif self.use_discrete_action:
                logits = pi(f_e(ob)).view(-1, self.action_spaces[0], self.nagents)[:, :, i]
                all_pol_logits.append(logits)
                all_pol_acs.append(onehot_from_logits(logits).detach())
            else:
                all_pol_logits.append(None)
                all_pol_acs.append(pi(f_e(ob)).detach().view(-1, self.action_spaces[0],
                                                             self.nagents)[:, :, i])
        vf_in = torch.cat((*critic_obs, *all_pol_acs), dim=1)  # Centralized critic for MADDPG agent
    elif self.alg_types[agent_i] == 'DDPG':
        vf_in = torch.cat((critic_obs, curr_pol_vf_in), dim=1)
    else:
        raise NotImplementedError

    # Computes the loss
    pol_loss = -torch.mean(curr_agent.critic(vf_in))

    # Backpropagates
    pol_loss.backward(retain_graph=True if self.alg_types[agent_i] == "TeamMADDPG" else False)

    # Clip gradients
    torch.nn.utils.clip_grad_norm_(curr_agent.policy.parameters(), self.grad_clip_value)

    # Apply actor update
    curr_agent.policy_optimizer.step()

    # Team Spirit (TS) regularization
    if self.alg_types[agent_i] == "TeamMADDPG":
        if self.use_discrete_action:
            real_action_logits = torch.stack(all_pol_logits, dim=2)
            real_action_probs = F.softmax(real_action_logits, dim=1)
            real_action_log_probs = F.log_softmax(real_action_logits, dim=1)
            predicted_action_log_probs = F.log_softmax(pol_out_all_heads, dim=1)
            # KL-divergence KL(real_action_dist || predicted_action_dist)
            ts_loss = torch.mean(torch.sum(
                real_action_probs * (real_action_log_probs - predicted_action_log_probs),
                dim=1))
        else:
            ts_loss = torch.mean(torch.sum(
                (pol_out_all_heads - torch.stack(all_pol_acs, dim=2)) ** 2, dim=1))

        # Resets gradient buffer of all agents
        for agent in self.agents:
            agent.policy_optimizer.zero_grad()

        # Backpropagates through every agent of the team (including curr_agent)
        ts_loss.backward()
        for i, agent in enumerate(self.agents):
            if i == agent_i:
                coeff = self.lambdat_1
            else:
                coeff = self.lambdat_2
            for p in agent.policy.parameters():
                p.grad *= coeff
            # Apply gradients
            torch.nn.utils.clip_grad_norm_(agent.policy.parameters(), self.grad_clip_value)
            agent.policy_optimizer.step()
        ts_loss = ts_loss.data.cpu().numpy()
    else:
        ts_loss = None

    ## Backpropagates for feature extractor from both backprops
    torch.nn.utils.clip_grad_norm_(curr_agent.f_e.parameters(), self.grad_clip_value)
    curr_agent.f_e_optimizer.step()

    return pol_loss.data.cpu().numpy(), vf_loss.data.cpu().numpy(), ts_loss
def update_agent(self, sample, agent_i):
    """
    Update parameters of agent model based on sample from replay buffer
    Inputs:
        sample: tuple of (observations, actions, rewards, next_observations, and
                episode_end masks) sampled randomly from the replay buffer.
                Each is a list with entries corresponding to each agent
        agent_i (int): index of agent to update
        parallel (bool): If true, will average gradients across threads
    """
    # Extract info and agent
    observations, actions, rewards, next_obs, dones = sample
    curr_agent = self.agents[agent_i]

    # UPDATES THE CRITIC ---

    # Resets gradient buffer
    curr_agent.critic_optimizer.zero_grad()
    curr_agent.f_e_optimizer.zero_grad()

    # Gets target state-action pair (next_obs, next_actions)
    if self.alg_types[agent_i] in ['SharedMADDPG', 'MADDPG']:
        if self.use_discrete_action:  # one-hot encode action
            all_target_actions = [onehot_from_logits(pi(f_e(nobs))) for pi, f_e, nobs in
                                  zip(self.target_policies, self.f_es, next_obs)]
        else:
            all_target_actions = [pi(f_e(nobs)) for pi, f_e, nobs in
                                  zip(self.target_policies, self.f_es, next_obs)]
        if self.alg_types[agent_i] == 'SharedMADDPG':
            next_obs.insert(0, next_obs.pop(agent_i))
            all_target_actions.insert(0, all_target_actions.pop(agent_i))
        if self.critic_concat_all_obs:
            # TODO: WARNING: should probably feed that into the feature extractor
            target_vf_in = torch.cat((*next_obs, *all_target_actions), dim=1)
        else:
            target_vf_in = torch.cat((curr_agent.f_e(next_obs[agent_i]),
                                      *all_target_actions), dim=1)
    elif self.alg_types[agent_i] in ['SharedDDPG', 'DDPG']:
        next_fe = curr_agent.f_e(next_obs[agent_i])
        if self.use_discrete_action:
            target_vf_in = torch.cat(
                (next_fe, onehot_from_logits(curr_agent.target_policy(next_fe))), dim=1)
        else:
            target_vf_in = torch.cat((next_fe, curr_agent.target_policy(next_fe)), dim=1)
    else:
        raise NotImplementedError

    # Computes target value
    target_value = (rewards[agent_i].view(-1, 1) +
                    self.gamma * curr_agent.target_critic(target_vf_in) *
                    (1 - dones.view(-1, 1)))

    # Computes current state-action value
    if self.alg_types[agent_i] in ['SharedMADDPG', 'MADDPG']:
        if self.alg_types[agent_i] == 'SharedMADDPG':
            observations.insert(0, observations.pop(agent_i))
            actions.insert(0, actions.pop(agent_i))
        vf_in = torch.cat((*observations, *actions), dim=1)
    elif self.alg_types[agent_i] in ['SharedDDPG', 'DDPG']:
        vf_in = torch.cat((observations[agent_i], actions[agent_i]), dim=1)
    else:
        raise NotImplementedError
    actual_value = curr_agent.critic(vf_in)

    # Backpropagates
    vf_loss = MSELoss(actual_value, target_value.detach())
    # we have to retain the graph because we reuse the critic_obs for the following actor loss
    vf_loss.backward(retain_graph=True)  # TODO: make sure there is no leakage between the two losses

    # Clip gradients
    torch.nn.utils.clip_grad_norm_(curr_agent.critic.parameters(), self.grad_clip_value)

    # Apply critic update
    curr_agent.critic_optimizer.step()

    # UPDATES THE ACTOR ---

    # Resets gradient buffer
    curr_agent.policy_optimizer.zero_grad()

    # We put experience data back in the general point of view
    if self.alg_types[agent_i] == 'SharedMADDPG':
        observations.insert(agent_i, observations.pop(0))

    if self.use_discrete_action:
        # Forward pass as if onehot (hard=True) but backprop through a differentiable
        # Gumbel-Softmax sample. The MADDPG paper uses the Gumbel-Softmax trick to backprop
        # through discrete categorical samples, but I'm not sure if that is
        # correct since it removes the assumption of a deterministic policy for
        # DDPG. Regardless, discrete policies don't seem to learn properly without it.
        curr_pol_out = curr_agent.policy(curr_agent.f_e(observations[agent_i]))
        curr_pol_vf_in = gumbel_softmax(curr_pol_out, hard=True)
    else:
        curr_pol_out = curr_agent.policy(curr_agent.f_e(observations[agent_i]))
        curr_pol_vf_in = curr_pol_out  # No Gumbel-softmax for continuous control

    # Gets state-action pair value given by the critic
    if self.alg_types[agent_i] in ['SharedMADDPG', 'MADDPG']:
        all_pol_acs = []
        for i, pi, f_e, ob in zip(range(self.nagents), self.policies, self.f_es, observations):
            if i == agent_i:
                all_pol_acs.append(curr_pol_vf_in)
            elif self.use_discrete_action:
                all_pol_acs.append(onehot_from_logits(pi(f_e(ob))).detach())
            else:
                all_pol_acs.append(pi(ob).detach())
        if self.alg_types[agent_i] == 'SharedMADDPG':
            # the critic must take the point of view of agent i
            observations.insert(0, observations.pop(agent_i))
            all_pol_acs.insert(0, all_pol_acs.pop(agent_i))
        vf_in = torch.cat((*observations, *all_pol_acs), dim=1)  # Centralized critic for MADDPG agent
    elif self.alg_types[agent_i] in ['SharedDDPG', 'DDPG']:
        vf_in = torch.cat((observations[agent_i], curr_pol_vf_in), dim=1)
    else:
        raise NotImplementedError

    # Computes the loss
    pol_loss = -torch.mean(curr_agent.critic(vf_in))

    # Backpropagates
    pol_loss.backward()

    # Clip gradients
    torch.nn.utils.clip_grad_norm_(curr_agent.policy.parameters(), self.grad_clip_value)

    # Apply actor update
    curr_agent.policy_optimizer.step()

    ## Backpropagates for feature extractor from both backprops
    torch.nn.utils.clip_grad_norm_(curr_agent.f_e.parameters(), self.grad_clip_value)
    curr_agent.f_e_optimizer.step()

    return pol_loss.data.cpu().numpy(), vf_loss.data.cpu().numpy()
def update_agent(self, sample, agent_i):
    """
    Update parameters of agent model based on sample from replay buffer
    Inputs:
        sample: tuple of (observations, actions, rewards, next_observations, and
                episode_end masks) sampled randomly from the replay buffer.
                Each is a list with entries corresponding to each agent
        agent_i (int): index of agent to update
        parallel (bool): If true, will average gradients across threads
    """
    # Extract info and agent
    observations, actions, rewards, next_obs, dones = sample
    curr_agent = self.agents[agent_i]

    # UPDATES THE CRITIC ---

    # Resets gradient buffer
    curr_agent.critic_optimizer.zero_grad()

    # Gets target state-action pair (next_obs, next_actions)
    if self.alg_types[agent_i] in ['MADDPG', 'CoachMADDPG']:
        if self.use_discrete_action:  # one-hot encode action
            all_target_actions = [onehot_from_logits(pi(nobs)) for pi, nobs in
                                  zip(self.target_policies, next_obs)]
        else:
            all_target_actions = [pi(nobs) for pi, nobs in
                                  zip(self.target_policies, next_obs)]
        target_vf_in = torch.cat((*next_obs, *all_target_actions), dim=1)
    elif self.alg_types[agent_i] == 'DDPG':
        if self.use_discrete_action:
            target_vf_in = torch.cat(
                (next_obs[agent_i],
                 onehot_from_logits(curr_agent.target_policy(next_obs[agent_i]))), dim=1)
        else:
            target_vf_in = torch.cat(
                (next_obs[agent_i], curr_agent.target_policy(next_obs[agent_i])), dim=1)
    else:
        raise NotImplementedError

    # Computes target value
    target_value = (rewards[agent_i].view(-1, 1) +
                    self.gamma * curr_agent.target_critic(target_vf_in) *
                    (1 - dones.view(-1, 1)))

    # Computes current state-action value
    if self.alg_types[agent_i] in ['MADDPG', 'CoachMADDPG']:
        vf_in = torch.cat((*observations, *actions), dim=1)
    elif self.alg_types[agent_i] == 'DDPG':
        vf_in = torch.cat((observations[agent_i], actions[agent_i]), dim=1)
    else:
        raise NotImplementedError
    actual_value = curr_agent.critic(vf_in)

    # Backpropagates
    vf_loss = MSELoss(actual_value, target_value.detach())
    vf_loss.backward()

    # Clip gradients
    torch.nn.utils.clip_grad_norm_(curr_agent.critic.parameters(), self.grad_clip_value)

    # Apply critic update
    curr_agent.critic_optimizer.step()

    # UPDATES THE ACTOR ---

    # Resets gradient buffer
    curr_agent.policy_optimizer.zero_grad()

    if self.use_discrete_action:
        # Forward pass as if onehot (hard=True) but backprop through a differentiable
        # Gumbel-Softmax sample. The MADDPG paper uses the Gumbel-Softmax trick to backprop
        # through discrete categorical samples, but I'm not sure if that is
        # correct since it removes the assumption of a deterministic policy for
        # DDPG. Regardless, discrete policies don't seem to learn properly without it.
        curr_pol_out = curr_agent.policy(observations[agent_i], return_embed_logits=False)
        curr_pol_vf_in = gumbel_softmax(curr_pol_out, hard=True)
    else:
        curr_pol_out = curr_agent.policy(observations[agent_i], return_embed_logits=False)
        curr_pol_vf_in = curr_pol_out  # No Gumbel-softmax for continuous control

    # Gets state-action pair value given by the critic
    if self.alg_types[agent_i] in ['MADDPG', 'CoachMADDPG']:
        all_pol_acs = []
        for i, pi, ob in zip(range(self.nagents), self.policies, observations):
            if i == agent_i:
                all_pol_acs.append(curr_pol_vf_in)
            elif self.use_discrete_action:
                all_pol_acs.append(onehot_from_logits(pi(ob).detach()))
            else:
                all_pol_acs.append(pi(ob).detach())
        vf_in = torch.cat((*observations, *all_pol_acs), dim=1)  # Centralized critic for MADDPG agent
    elif self.alg_types[agent_i] == 'DDPG':
        vf_in = torch.cat((observations[agent_i], curr_pol_vf_in), dim=1)
    else:
        raise NotImplementedError

    # Computes the loss
    J_PG = -torch.mean(curr_agent.critic(vf_in))
    pol_loss = J_PG

    # Backpropagates
    pol_loss.backward()

    # Update actors
    # Clip gradients
    torch.nn.utils.clip_grad_norm_(curr_agent.policy.parameters(), self.grad_clip_value)

    # Apply actor update
    curr_agent.policy_optimizer.step()

    return pol_loss.data.cpu().numpy(), vf_loss.data.cpu().numpy()
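# Hedged numeric sketch of the Bellman target used throughout these updates:
# target = r + gamma * Q_target(o', a') * (1 - done). Toy tensors only; the values
# below are made up for illustration.
import torch

rewards = torch.tensor([[1.0], [0.5]])
q_next = torch.tensor([[2.0], [3.0]])   # stand-in for curr_agent.target_critic(target_vf_in)
dones = torch.tensor([[0.0], [1.0]])    # terminal transitions contribute no bootstrap term
gamma = 0.95
target_value = rewards + gamma * q_next * (1 - dones)
print(target_value)  # tensor([[2.9000], [0.5000]])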
def compute(self, next_obs):
    # acs_pi_k = []
    # prob_pi_k = []
    # for no, pi in zip(next_obs, self.agents):
    #     acs_pi_k.append(gumbel_softmax(pi.policy(no), device=self.device.get_device(), hard=True))
    #     prob_pi_k.append(F.softmax(pi.policy(no), dim=1).unsqueeze(1))  # for stacking later, dim = [B, 1, A]
    #
    # final_obs = self.transition(torch.cat((*next_obs, *acs_pi_k), dim=1))
    #
    # end_idx = [0] + np.cumsum([ne_ob.shape[1] for ne_ob in next_obs]).tolist()
    # start_end = [(start, end) for start, end in zip(end_idx, end_idx[1:])]
    #
    # # P(action distribution of agent j | k takes action taken)
    # action_dist = []  # action_dist dim = [num agents, batch_size, num agents - 1, action_dim]
    # for k, no in enumerate(next_obs):
    #     prob_pi_j = []
    #     for j, pi_j in enumerate(self.agents):
    #         if j == k: continue  # computing effect on other agents
    #         final_obs_j = final_obs[:, start_end[j][0]:start_end[j][1]]
    #         prob_pi_j.append(F.softmax(pi_j.policy(final_obs_j), dim=1).unsqueeze(1))
    #     prob_pi_j = torch.cat(prob_pi_j, dim=1)
    #     action_dist.append(prob_pi_j)
    #
    # # [P(k takes action 0) * (action distribution of agent j | k takes action 0) + ...]
    # marginal_action_dists = []
    # for k, (no, pi) in enumerate(zip(next_obs, self.agents)):
    #     batch_size, action_dim = acs_pi_k[k].shape
    #     all_acs_pi_k = torch.nn.functional.one_hot(torch.arange(action_dim)).float()
    #     for one_hot_ac in all_acs_pi_k:
    #         # replace k's action inside the original acs_pi_k
    #         acs_pi_k_modified = acs_pi_k
    #         acs_pi_k_modified[k] = one_hot_ac.unsqueeze(0).repeat(batch_size, 1)
    #         tilde_final_obs = self.transition(torch.cat((*next_obs, *acs_pi_k_modified), dim=1))
    #
    #         for j, pi_j in enumerate(self.agents):
    #             if j == k: continue  # computing effect on other agent
    #             tilde_final_obs_j = tilde_final_obs[:, start_end[j][0]:start_end[j][1]]
    #             mrgn_dist_acs_j = F.softmax(pi_j.policy(tilde_final_obs_j), dim=1).unsqueeze(1)
    #             marginal_action_dists.append()

    acs_src = []
    prob_src = []
    for no, source in zip(next_obs, self.agents):
        acs_src.append(gumbel_softmax(source.policy(no), device=self.device.get_device(), hard=True))
        prob_src.append(gumbel_softmax(source.policy(no), device=self.device.get_device(), hard=False))
    trans_in = torch.cat((*next_obs, *acs_src), dim=1)
    trans_out = self.transition(trans_in)
    prob_plan = []
    end_idx = [0] + np.cumsum([ne_ob.shape[1] for ne_ob in next_obs]).tolist()
    start_end = [(start, end) for start, end in zip(end_idx, end_idx[1:])]
    for i, (no, planning) in enumerate(zip(next_obs, self.planning)):
        nno = trans_out[:, start_end[i][0]:start_end[i][1]]
        acs_ = []
        for j, ac in enumerate(acs_src):
            if j == i:
                continue
            nno_other = trans_out[:, start_end[j][0]:start_end[j][1]]
            acs_.append(gumbel_softmax(self.agents[j].policy(nno_other),
                                       device=self.device.get_device(), hard=True))
        acs_ = torch.cat(acs_, dim=1)
        plan_in = torch.cat((no, nno, acs_), dim=1)
        prob_plan.append(gumbel_softmax(planning(plan_in),
                                        device=self.device.get_device(), hard=False))
    prob_plan = torch.cat(prob_plan, dim=1)
    prob_src = torch.cat(prob_src, dim=1)
    # for returning the si for individual agents
    end_idx = [0] + np.cumsum([ne_ac.shape[1] for ne_ac in acs_src]).tolist()
    start_end = [(start, end) for start, end in zip(end_idx, end_idx[1:])]
    acs_src = torch.cat(acs_src, dim=1)
    si = acs_src * prob_plan - acs_src * prob_src
    result = torch.cat([si[:, start:end] for (start, end) in start_end], dim=0)
    return result
def update(self, sample, agent_i, parallel=False, logger=None):
    """
    Update parameters of agent model based on sample from replay buffer
    Inputs:
        sample: tuple of (observations, actions, rewards, next observations, and episode
                end masks) sampled randomly from the replay buffer. Each is a list with
                entries corresponding to each agent
        agent_i (int): index of agent to update
        parallel (bool): If true, will average gradients across threads
        logger (SummaryWriter from Tensorboard-Pytorch):
            If passed in, important quantities will be logged
    """
    obs, acs, rews, next_obs, dones = sample
    curr_agent = self.agents[agent_i]

    curr_agent.critic_optimizer.zero_grad()
    if self.alg_types[agent_i] == 'MADDPG':
        if self.discrete_action:  # one-hot encode action
            all_trgt_acs = [onehot_from_logits(pi(nobs)) for pi, nobs in
                            zip(self.target_policies, next_obs)]
        else:
            all_trgt_acs = [pi(nobs) for pi, nobs in
                            zip(self.target_policies, next_obs)]
        trgt_vf_in = torch.cat((*next_obs, *all_trgt_acs), dim=1)
    else:  # DDPG
        if self.discrete_action:
            trgt_vf_in = torch.cat(
                (next_obs[agent_i],
                 onehot_from_logits(curr_agent.target_policy(next_obs[agent_i]))), dim=1)
        else:
            trgt_vf_in = torch.cat(
                (next_obs[agent_i], curr_agent.target_policy(next_obs[agent_i])), dim=1)
    target_value = (rews[agent_i].view(-1, 1) +
                    self.gamma * curr_agent.target_critic(trgt_vf_in) *
                    (1 - dones[agent_i].view(-1, 1)))

    if self.alg_types[agent_i] == 'MADDPG':
        vf_in = torch.cat((*obs, *acs), dim=1)
    else:  # DDPG
        vf_in = torch.cat((obs[agent_i], acs[agent_i]), dim=1)
    actual_value = curr_agent.critic(vf_in)
    vf_loss = MSELoss(actual_value, target_value.detach())
    vf_loss.backward()
    if parallel:
        average_gradients(curr_agent.critic)
    torch.nn.utils.clip_grad_norm(curr_agent.critic.parameters(), 0.5)
    curr_agent.critic_optimizer.step()

    curr_agent.policy_optimizer.zero_grad()
    if self.discrete_action:
        # Forward pass as if onehot (hard=True) but backprop through a differentiable
        # Gumbel-Softmax sample. The MADDPG paper uses the Gumbel-Softmax trick to backprop
        # through discrete categorical samples, but I'm not sure if that is
        # correct since it removes the assumption of a deterministic policy for
        # DDPG. Regardless, discrete policies don't seem to learn properly without it.
        curr_pol_out = curr_agent.policy(obs[agent_i])
        curr_pol_vf_in = gumbel_softmax(curr_pol_out, hard=True)
    else:
        curr_pol_out = curr_agent.policy(obs[agent_i])
        curr_pol_vf_in = curr_pol_out
    if self.alg_types[agent_i] == 'MADDPG':
        all_pol_acs = []
        for i, pi, ob in zip(range(self.nagents), self.policies, obs):
            if i == agent_i:
                all_pol_acs.append(curr_pol_vf_in)
            elif self.discrete_action:
                all_pol_acs.append(onehot_from_logits(pi(ob)))
            else:
                all_pol_acs.append(pi(ob))
        vf_in = torch.cat((*obs, *all_pol_acs), dim=1)
    else:  # DDPG
        vf_in = torch.cat((obs[agent_i], curr_pol_vf_in), dim=1)
    pol_loss = -curr_agent.critic(vf_in).mean()
    pol_loss += (curr_pol_out ** 2).mean() * 1e-3
    pol_loss.backward()
    if parallel:
        average_gradients(curr_agent.policy)
    torch.nn.utils.clip_grad_norm(curr_agent.policy.parameters(), 0.5)
    curr_agent.policy_optimizer.step()

    if logger is not None:
        logger.add_scalars('agent%i/losses' % agent_i,
                           {'vf_loss': vf_loss, 'pol_loss': pol_loss},
                           self.niter)
def update(self, sample, agent_i, parallel=False, logger=None):
    """
    Update parameters of agent model based on sample from replay buffer
    Inputs:
        sample: tuple of (observations, actions, rewards, next observations, and episode
                end masks) sampled randomly from the replay buffer. Each is a list with
                entries corresponding to each agent
        agent_i (int): index of agent to update
        parallel (bool): If true, will average gradients across threads
        logger (SummaryWriter from Tensorboard-Pytorch):
            If passed in, important quantities will be logged
    """
    # For RNN, the obs and next_obs both have histories
    obs, acs, rews, next_obs, dones = sample
    curr_agent = self.agents[agent_i]

    curr_agent.critic_optimizer.zero_grad()
    if self.alg_types[agent_i] == 'MADDPG':
        if self.discrete_action:  # one-hot encode action
            # This is the original version, 'pi': policy, 'nobs': next observations
            # all_trgt_acs = [onehot_from_logits(pi(nobs)) for pi, nobs in
            #                 zip(self.target_policies, next_obs)]
            # -------- Expanded out for debugging --------#
            all_trgt_acs = []
            for pi, nobs in zip(self.target_policies, next_obs):
                temp = onehot_from_logits(pi(nobs))
                all_trgt_acs.append(temp)
            # -------- End debug -------------------------#
        else:
            all_trgt_acs = [pi(nobs) for pi, nobs in
                            zip(self.target_policies, next_obs)]
        # Get the most current observation from the history to calculate the target value
        t0_next_obs = [[], [], []]
        for a in range(self.nagents):
            # next_obs[0].shape[0] gives the batch size
            # TODO: change the hard-coded obs size (18) to a parameter
            t0_next_obs[a] = torch.tensor(np.zeros((next_obs[0].shape[0], 18)),
                                          dtype=torch.float)
        # Only keep the current obs for the critic VF
        for n in range(self.nagents):  # for each agent
            for b in range(next_obs[0].shape[0]):  # for each sample in the batch
                t0_next_obs[n][b][:] = next_obs[n][b][0:18]
        # ORIGINAL was:
        # trgt_vf_in = torch.cat((*next_obs, *all_trgt_acs), dim=1)
        trgt_vf_in = torch.cat((*t0_next_obs, *all_trgt_acs), dim=1)
        # It is working till here. Only kept the current obs for the critic VF.
    else:  # DDPG
        # DDPG only knows the particular agent's observation and policy,
        # whereas MADDPG has access to all other agents' policies.
        # TODO: grab only the current observation to send to the critic
        t0_next_obs = [[], [], []]
        for a in range(self.nagents):
            t0_next_obs[a] = torch.tensor(np.zeros((next_obs[0].shape[0], 18)),
                                          dtype=torch.float)
        for n in range(self.nagents):  # for each agent
            for b in range(next_obs[0].shape[0]):  # for each sample in the batch
                t0_next_obs[n][b][:] = next_obs[n][b][0:18]
        # Originally it would be next_obs[agent_i] instead of t0_next_obs[agent_i]
        if self.discrete_action:
            trgt_vf_in = torch.cat(
                (t0_next_obs[agent_i],
                 onehot_from_logits(curr_agent.target_policy(next_obs[agent_i]))), dim=1)
        else:
            trgt_vf_in = torch.cat(
                (next_obs[agent_i], curr_agent.target_policy(next_obs[agent_i])), dim=1)
    target_value = (rews[agent_i].view(-1, 1) +
                    self.gamma * curr_agent.target_critic(trgt_vf_in) *
                    (1 - dones[agent_i].view(-1, 1)))

    # Just get the current observation (i.e., without history).
    # Reason: the critic VF does not need history.
    # Same construction as t0_next_obs, since BOTH obs and next_obs have histories.
    t0_obs = [[], [], []]
    for a in range(self.nagents):
        t0_obs[a] = torch.tensor(np.zeros((obs[0].shape[0], 18)), dtype=torch.float)
    for n in range(self.nagents):  # for each agent
        for b in range(obs[0].shape[0]):  # for each sample in the batch
            t0_obs[n][b][:] = obs[n][b][0:18]

    if self.alg_types[agent_i] == 'MADDPG':
        vf_in = torch.cat((*t0_obs, *acs), dim=1)
    else:  # DDPG
        # TODO: below, might have to change obs to t0_obs when using DDPG
        vf_in = torch.cat((t0_obs[agent_i], acs[agent_i]), dim=1)
    actual_value = curr_agent.critic(vf_in)
    vf_loss = MSELoss(actual_value, target_value.detach())
    vf_loss.backward()
    if parallel:
        average_gradients(curr_agent.critic)
    torch.nn.utils.clip_grad_norm(curr_agent.critic.parameters(), 0.5)
    curr_agent.critic_optimizer.step()

    curr_agent.policy_optimizer.zero_grad()
    if self.discrete_action:
        # Forward pass as if onehot (hard=True) but backprop through a differentiable
        # Gumbel-Softmax sample. The MADDPG paper uses the Gumbel-Softmax trick to backprop
        # through discrete categorical samples, but I'm not sure if that is
        # correct since it removes the assumption of a deterministic policy for
        # DDPG. Regardless, discrete policies don't seem to learn properly without it.
        # Now we are back to forwarding the policy, so we need to use obs with history.
        curr_pol_out = curr_agent.policy(obs[agent_i])
        curr_pol_vf_in = gumbel_softmax(curr_pol_out, hard=True)
    else:
        curr_pol_out = curr_agent.policy(obs[agent_i])
        curr_pol_vf_in = curr_pol_out
    if self.alg_types[agent_i] == 'MADDPG':
        all_pol_acs = []
        for i, pi, ob in zip(range(self.nagents), self.policies, obs):
            if i == agent_i:
                all_pol_acs.append(curr_pol_vf_in)
            elif self.discrete_action:
                all_pol_acs.append(onehot_from_logits(pi(ob)))
            else:
                all_pol_acs.append(pi(ob))
        # Originally: vf_in = torch.cat((*obs, *all_pol_acs), dim=1)
        vf_in = torch.cat((*t0_obs, *all_pol_acs), dim=1)
    else:  # DDPG
        vf_in = torch.cat((t0_obs[agent_i], curr_pol_vf_in), dim=1)  # TODO: FIX THIS
    pol_loss = -curr_agent.critic(vf_in).mean()
    pol_loss += (curr_pol_out ** 2).mean() * 1e-3
    pol_loss.backward()
    if parallel:
        average_gradients(curr_agent.policy)
    torch.nn.utils.clip_grad_norm(curr_agent.policy.parameters(), 0.5)
    curr_agent.policy_optimizer.step()

    if logger is not None:
        logger.add_scalars('agent%i/losses' % agent_i,
                           {'vf_loss': vf_loss, 'pol_loss': pol_loss},
                           self.niter)