def main(data_path, label_path, nb_epoch, save_path, start_path=None, batch_size=2, lr=1e-3, plot_history=True):
    cudnn.benchmark = True
    train = NYUDepth(data_path, label_path, transforms=transforms.ToTensor())

    hourglass = HourGlass()
    hourglass.cuda()
    optimizer = RMSprop(hourglass.parameters(), lr)

    if start_path:
        experiment = torch.load(start_path)
        hourglass.load_state_dict(experiment['model_state'])
        optimizer.load_state_dict(experiment['optimizer_state'])

    criterion = RelativeDepthLoss()

    history = fit(hourglass, train, criterion, optimizer, batch_size, nb_epoch)
    save_checkpoint(hourglass.state_dict(), optimizer.state_dict(), save_path)

    if plot_history:
        plt.plot(history['loss'], label='loss')
        plt.xlabel('epoch')
        plt.ylabel('relative depth loss')
        plt.legend()
        plt.show()
def main(train_data_path, train_label_path, val_data_path, val_label_path,
         nb_epoch, save_path, device, start_path, batch_size, lr):
    cudnn.benchmark = True
    train_data = NYUDepth(train_data_path, train_label_path, transforms=transforms.ToTensor())
    val_data = NYUDepth(val_data_path, val_label_path, transforms=transforms.ToTensor())

    hourglass = HourGlass()
    hourglass = hourglass.cuda()
    optimizer = RMSprop(hourglass.parameters(), lr, weight_decay=1e-5)
    # scheduler = MultiStepLR(optimizer, milestones=[10, 20], gamma=0.1)
    scheduler = None

    if start_path:
        experiment = torch.load(start_path)
        hourglass.load_state_dict(experiment['model_state'])
        optimizer.load_state_dict(experiment['optimizer_state'])

    criterion = RelativeDepthLoss()

    # save path
    t_now = datetime.datetime.now()
    t = t_now.strftime("%Y-%m-%d-%H-%M-%S")
    save_path = os.path.join(save_path, t)
    if not os.path.isdir(save_path):
        os.mkdir(save_path)

    history = fit(hourglass, train_data, val_data, criterion, optimizer, scheduler,
                  save_path, device, batch_size=batch_size, nb_epoch=nb_epoch)

    # save final model
    save_checkpoint(hourglass.state_dict(), optimizer.state_dict(),
                    os.path.join(save_path, "test_result.pth"))
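# A minimal command-line entry point is sketched below for convenience; it is
# not part of the original training script, and every flag name and default
# here is an assumption chosen only to illustrate how main() might be invoked.
import argparse

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Train the hourglass depth model")
    parser.add_argument("--train-data", required=True)
    parser.add_argument("--train-labels", required=True)
    parser.add_argument("--val-data", required=True)
    parser.add_argument("--val-labels", required=True)
    parser.add_argument("--epochs", type=int, default=50)
    parser.add_argument("--save-path", default="checkpoints")
    parser.add_argument("--device", default="cuda:0")
    parser.add_argument("--start-path", default=None)
    parser.add_argument("--batch-size", type=int, default=2)
    parser.add_argument("--lr", type=float, default=1e-3)
    args = parser.parse_args()

    main(args.train_data, args.train_labels, args.val_data, args.val_labels,
         args.epochs, args.save_path, args.device, args.start_path,
         args.batch_size, args.lr)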
class QLearner:
    def __init__(self, mac, scheme, logger, args):
        self.args = args
        self.mac = mac
        self.logger = logger
        self.params = list(mac.parameters())
        self.last_target_update_episode = 0

        self.mixer = None
        if args.mixer == "qtran_base":
            self.mixer = QTranBase(args)
        elif args.mixer == "qtran_alt":
            raise Exception("Not implemented here!")
        self.params += list(self.mixer.parameters())
        self.target_mixer = copy.deepcopy(self.mixer)

        self.optimiser = RMSprop(params=self.params, lr=args.lr, alpha=args.optim_alpha, eps=args.optim_eps)

        # a little wasteful to deepcopy (e.g. duplicates action selector), but should work for any MAC
        self.target_mac = copy.deepcopy(mac)

        self.log_stats_t = -self.args.learner_log_interval - 1

    def train(self, batch: EpisodeBatch, t_env: int, episode_num: int, show_demo=False, save_data=None):
        # Get the relevant quantities
        rewards = batch["reward"][:, :-1]
        actions = batch["actions"][:, :-1]
        terminated = batch["terminated"][:, :-1].float()
        mask = batch["filled"][:, :-1].float()
        mask[:, 1:] = mask[:, 1:] * (1 - terminated[:, :-1])
        avail_actions = batch["avail_actions"]

        # Calculate estimated Q-Values
        mac_out = []
        mac_hidden_states = []
        self.mac.init_hidden(batch.batch_size)
        for t in range(batch.max_seq_length):
            agent_outs = self.mac.forward(batch, t=t)
            mac_out.append(agent_outs)
            mac_hidden_states.append(self.mac.hidden_states)
        mac_out = torch.stack(mac_out, dim=1)  # Concat over time
        mac_hidden_states = torch.stack(mac_hidden_states, dim=1)
        mac_hidden_states = mac_hidden_states.reshape(
            batch.batch_size, self.args.n_agents, batch.max_seq_length, -1).transpose(1, 2)  # btav

        # Pick the Q-Values for the actions taken by each agent
        chosen_action_qvals = torch.gather(mac_out[:, :-1], dim=3, index=actions).squeeze(3)  # Remove the last dim

        x_mac_out = mac_out.clone().detach()
        x_mac_out[avail_actions == 0] = -9999999
        max_action_qvals, max_action_index = x_mac_out[:, :-1].max(dim=3)

        max_action_index = max_action_index.detach().unsqueeze(3)
        is_max_action = (max_action_index == actions).int().float()

        if show_demo:
            q_i_data = chosen_action_qvals.detach().cpu().numpy()
            q_data = (max_action_qvals - chosen_action_qvals).detach().cpu().numpy()

        # Calculate the Q-Values necessary for the target
        target_mac_out = []
        target_mac_hidden_states = []
        self.target_mac.init_hidden(batch.batch_size)
        for t in range(batch.max_seq_length):
            target_agent_outs = self.target_mac.forward(batch, t=t)
            target_mac_out.append(target_agent_outs)
            target_mac_hidden_states.append(self.target_mac.hidden_states)

        # We don't need the first timesteps Q-Value estimate for calculating targets
        target_mac_out = torch.stack(target_mac_out[:], dim=1)  # Concat across time
        target_mac_hidden_states = torch.stack(target_mac_hidden_states, dim=1)
        target_mac_hidden_states = target_mac_hidden_states.reshape(
            batch.batch_size, self.args.n_agents, batch.max_seq_length, -1).transpose(1, 2)  # btav

        # Mask out unavailable actions
        target_mac_out[avail_actions[:, :] == 0] = -9999999  # From OG deepmarl

        mac_out_maxs = mac_out.clone()
        mac_out_maxs[avail_actions == 0] = -9999999

        # Best joint action computed by target agents
        target_max_actions = target_mac_out.max(dim=3, keepdim=True)[1]
        # Best joint-action computed by regular agents
        max_actions_qvals, max_actions_current = mac_out_maxs[:, :].max(dim=3, keepdim=True)

        if self.args.mixer == "qtran_base":
            # -- TD Loss --
            # Joint-action Q-Value estimates
            joint_qs, vs = self.mixer(batch[:, :-1], mac_hidden_states[:, :-1])

            # Need to argmax across the target agents' actions to compute target joint-action Q-Values
            if self.args.double_q:
                max_actions_current_ = torch.zeros(
                    size=(batch.batch_size, batch.max_seq_length, self.args.n_agents, self.args.n_actions),
                    device=batch.device)
                max_actions_current_onehot = max_actions_current_.scatter(3, max_actions_current[:, :], 1)
                max_actions_onehot = max_actions_current_onehot
            else:
                max_actions = torch.zeros(
                    size=(batch.batch_size, batch.max_seq_length, self.args.n_agents, self.args.n_actions),
                    device=batch.device)
                max_actions_onehot = max_actions.scatter(3, target_max_actions[:, :], 1)
            target_joint_qs, target_vs = self.target_mixer(
                batch[:, 1:],
                hidden_states=target_mac_hidden_states[:, 1:],
                actions=max_actions_onehot[:, 1:])

            # Td loss targets
            td_targets = rewards.reshape(-1, 1) + self.args.gamma * (1 - terminated.reshape(-1, 1)) * target_joint_qs
            td_error = (joint_qs - td_targets.detach())
            masked_td_error = td_error * mask.reshape(-1, 1)
            td_loss = (masked_td_error ** 2).sum() / mask.sum()
            # -- TD Loss --

            # -- Opt Loss --
            # Argmax across the current agents' actions
            if not self.args.double_q:  # Already computed if we're doing double Q-Learning
                max_actions_current_ = torch.zeros(
                    size=(batch.batch_size, batch.max_seq_length, self.args.n_agents, self.args.n_actions),
                    device=batch.device)
                max_actions_current_onehot = max_actions_current_.scatter(3, max_actions_current[:, :], 1)
            max_joint_qs, _ = self.mixer(
                batch[:, :-1], mac_hidden_states[:, :-1],
                actions=max_actions_current_onehot[:, :-1])  # Don't use the target network and target agent max actions as per author's email

            # max_actions_qvals = torch.gather(mac_out[:, :-1], dim=3, index=max_actions_current[:, :-1])
            opt_error = max_actions_qvals[:, :-1].sum(dim=2).reshape(-1, 1) - max_joint_qs.detach() + vs
            masked_opt_error = opt_error * mask.reshape(-1, 1)
            opt_loss = (masked_opt_error ** 2).sum() / mask.sum()
            # -- Opt Loss --

            # -- Nopt Loss --
            # target_joint_qs, _ = self.target_mixer(batch[:, :-1])
            nopt_values = chosen_action_qvals.sum(dim=2).reshape(-1, 1) - joint_qs.detach() + vs  # Don't use target networks here either
            nopt_error = nopt_values.clamp(max=0)
            masked_nopt_error = nopt_error * mask.reshape(-1, 1)
            nopt_loss = (masked_nopt_error ** 2).sum() / mask.sum()
            # -- Nopt loss --
        elif self.args.mixer == "qtran_alt":
            raise Exception("Not supported yet.")

        if show_demo:
            tot_q_data = joint_qs.detach().cpu().numpy()
            tot_target = td_targets.detach().cpu().numpy()
            bs = q_data.shape[0]
            tot_q_data = tot_q_data.reshape(bs, -1)
            tot_target = tot_target.reshape(bs, -1)
            print('action_pair_%d_%d' % (save_data[0], save_data[1]),
                  np.squeeze(q_data[:, 0]), np.squeeze(q_i_data[:, 0]),
                  np.squeeze(tot_q_data[:, 0]), np.squeeze(tot_target[:, 0]))
            self.logger.log_stat('action_pair_%d_%d' % (save_data[0], save_data[1]),
                                 np.squeeze(tot_q_data[:, 0]), t_env)
            return

        loss = td_loss + self.args.opt_loss * opt_loss + self.args.nopt_min_loss * nopt_loss

        masked_hit_prob = torch.mean(is_max_action, dim=2) * mask
        hit_prob = masked_hit_prob.sum() / mask.sum()

        # Optimise
        self.optimiser.zero_grad()
        loss.backward()
        grad_norm = torch.nn.utils.clip_grad_norm_(self.params, self.args.grad_norm_clip)
        self.optimiser.step()

        if (episode_num - self.last_target_update_episode) / self.args.target_update_interval >= 1.0:
            self._update_targets()
            self.last_target_update_episode = episode_num

        if t_env - self.log_stats_t >= self.args.learner_log_interval:
            self.logger.log_stat("loss", loss.item(), t_env)
            self.logger.log_stat("hit_prob", hit_prob.item(), t_env)
            self.logger.log_stat("td_loss", td_loss.item(), t_env)
            self.logger.log_stat("opt_loss", opt_loss.item(), t_env)
            self.logger.log_stat("nopt_loss", nopt_loss.item(), t_env)
            self.logger.log_stat("grad_norm", grad_norm, t_env)
            if self.args.mixer == "qtran_base":
                mask_elems = mask.sum().item()
                self.logger.log_stat("td_error_abs", (masked_td_error.abs().sum().item() / mask_elems), t_env)
                self.logger.log_stat("td_targets", ((masked_td_error).sum().item() / mask_elems), t_env)
                self.logger.log_stat("td_chosen_qs", (joint_qs.sum().item() / mask_elems), t_env)
                self.logger.log_stat("v_mean", (vs.sum().item() / mask_elems), t_env)
                self.logger.log_stat("agent_indiv_qs",
                                     ((chosen_action_qvals * mask).sum().item() /
                                      (mask_elems * self.args.n_agents)), t_env)
            self.log_stats_t = t_env

    def _update_targets(self):
        self.target_mac.load_state(self.mac)
        if self.mixer is not None:
            self.target_mixer.load_state_dict(self.mixer.state_dict())
        self.logger.console_logger.info("Updated target network")

    def cuda(self):
        self.mac.cuda()
        self.target_mac.cuda()
        if self.mixer is not None:
            self.mixer.cuda()
            self.target_mixer.cuda()

    def save_models(self, path):
        self.mac.save_models(path)
        if self.mixer is not None:
            torch.save(self.mixer.state_dict(), "{}/mixer.torch".format(path))
        torch.save(self.optimiser.state_dict(), "{}/opt.torch".format(path))

    def load_models(self, path):
        self.mac.load_models(path)
        # Not quite right but I don't want to save target networks
        self.target_mac.load_models(path)
        if self.mixer is not None:
            self.mixer.load_state_dict(
                torch.load("{}/mixer.torch".format(path), map_location=lambda storage, loc: storage))
        self.optimiser.load_state_dict(
            torch.load("{}/opt.torch".format(path), map_location=lambda storage, loc: storage))
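# The learners collected here all optimise the same masked TD objective:
# squared TD error on valid (filled) timesteps only, normalised by the number
# of valid entries so padded episode tails do not dilute the loss. A minimal
# standalone sketch of that pattern follows; the function name is illustrative
# and not part of the original code.
import torch

def masked_td_loss(q_taken, targets, mask):
    """Masked mean-squared TD error; all arguments share one shape, mask in {0, 1}."""
    td_error = q_taken - targets.detach()  # never backpropagate through the targets
    masked_td_error = td_error * mask      # zero out contributions from padded steps
    return (masked_td_error ** 2).sum() / mask.sum()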
class OffPGLearner: def __init__(self, mac, scheme, logger, args): self.args = args self.n_agents = args.n_agents self.n_actions = args.n_actions self.mac = mac self.logger = logger self.last_target_update_step = 0 self.critic_training_steps = 0 self.log_stats_t = -self.args.learner_log_interval - 1 self.critic = OffPGCritic(scheme, args) self.mixer = QMixer(args) self.target_critic = copy.deepcopy(self.critic) self.target_mixer = copy.deepcopy(self.mixer) self.agent_params = list(mac.parameters()) self.critic_params = list(self.critic.parameters()) self.mixer_params = list(self.mixer.parameters()) self.params = self.agent_params + self.critic_params self.c_params = self.critic_params + self.mixer_params self.agent_optimiser = RMSprop(params=self.agent_params, lr=args.lr, alpha=args.optim_alpha, eps=args.optim_eps) self.critic_optimiser = RMSprop(params=self.critic_params, lr=args.critic_lr, alpha=args.optim_alpha, eps=args.optim_eps) self.mixer_optimiser = RMSprop(params=self.mixer_params, lr=args.critic_lr, alpha=args.optim_alpha, eps=args.optim_eps) def train(self, batch: EpisodeBatch, t_env: int, log): # Get the relevant quantities bs = batch.batch_size max_t = batch.max_seq_length actions = batch["actions"][:, :-1] terminated = batch["terminated"][:, :-1].float() avail_actions = batch["avail_actions"][:, :-1] mask = batch["filled"][:, :-1].float() mask[:, 1:] = mask[:, 1:] * (1 - terminated[:, :-1]) mask = mask.repeat(1, 1, self.n_agents).view(-1) states = batch["state"][:, :-1] #build q inputs = self.critic._build_inputs(batch, bs, max_t) q_vals = self.critic.forward(inputs).detach()[:, :-1] mac_out = [] self.mac.init_hidden(batch.batch_size) for t in range(batch.max_seq_length - 1): agent_outs = self.mac.forward(batch, t=t) mac_out.append(agent_outs) mac_out = th.stack(mac_out, dim=1) # Concat over time # Mask out unavailable actions, renormalise (as in action selection) mac_out[avail_actions == 0] = 0 mac_out = mac_out / mac_out.sum(dim=-1, keepdim=True) mac_out[avail_actions == 0] = 0 # Calculated baseline q_taken = th.gather(q_vals, dim=3, index=actions).squeeze(3) pi = mac_out.view(-1, self.n_actions) baseline = th.sum(mac_out * q_vals, dim=-1).view(-1).detach() # Calculate policy grad with mask pi_taken = th.gather(pi, dim=1, index=actions.reshape(-1, 1)).squeeze(1) pi_taken[mask == 0] = 1.0 log_pi_taken = th.log(pi_taken) coe = self.mixer.k(states).view(-1) advantages = (q_taken.view(-1) - baseline).detach() coma_loss = -( (coe * advantages * log_pi_taken) * mask).sum() / mask.sum() # Optimise agents self.agent_optimiser.zero_grad() coma_loss.backward() grad_norm = th.nn.utils.clip_grad_norm_(self.agent_params, self.args.grad_norm_clip) self.agent_optimiser.step() #compute parameters sum for debugging p_sum = 0. 
for p in self.agent_params: p_sum += p.data.abs().sum().item() / 100.0 if t_env - self.log_stats_t >= self.args.learner_log_interval: ts_logged = len(log["critic_loss"]) for key in [ "critic_loss", "critic_grad_norm", "td_error_abs", "q_taken_mean", "target_mean", "q_max_mean", "q_min_mean", "q_max_var", "q_min_var" ]: self.logger.log_stat(key, sum(log[key]) / ts_logged, t_env) self.logger.log_stat("q_max_first", log["q_max_first"], t_env) self.logger.log_stat("q_min_first", log["q_min_first"], t_env) #self.logger.log_stat("advantage_mean", (advantages * mask).sum().item() / mask.sum().item(), t_env) self.logger.log_stat("coma_loss", coma_loss.item(), t_env) self.logger.log_stat("agent_grad_norm", grad_norm, t_env) self.logger.log_stat("pi_max", (pi.max(dim=1)[0] * mask).sum().item() / mask.sum().item(), t_env) self.log_stats_t = t_env def train_critic(self, on_batch, best_batch=None, log=None): bs = on_batch.batch_size max_t = on_batch.max_seq_length rewards = on_batch["reward"][:, :-1] actions = on_batch["actions"][:, :] terminated = on_batch["terminated"][:, :-1].float() mask = on_batch["filled"][:, :-1].float() mask[:, 1:] = mask[:, 1:] * (1 - terminated[:, :-1]) avail_actions = on_batch["avail_actions"][:] states = on_batch["state"] #build_target_q target_inputs = self.target_critic._build_inputs(on_batch, bs, max_t) target_q_vals = self.target_critic.forward(target_inputs).detach() targets_taken = self.target_mixer( th.gather(target_q_vals, dim=3, index=actions).squeeze(3), states) target_q = build_td_lambda_targets(rewards, terminated, mask, targets_taken, self.n_agents, self.args.gamma, self.args.td_lambda).detach() inputs = self.critic._build_inputs(on_batch, bs, max_t) mac_out = [] self.mac.init_hidden(bs) for i in range(max_t): agent_outs = self.mac.forward(on_batch, t=i) mac_out.append(agent_outs) mac_out = th.stack(mac_out, dim=1).detach() # Mask out unavailable actions, renormalise (as in action selection) mac_out[avail_actions == 0] = 0 mac_out = mac_out / mac_out.sum(dim=-1, keepdim=True) mac_out[avail_actions == 0] = 0 if best_batch is not None: best_target_q, best_inputs, best_mask, best_actions, best_mac_out = self.train_critic_best( best_batch) log["best_reward"] = th.mean( best_batch["reward"][:, :-1].squeeze(2).sum(-1), dim=0) target_q = th.cat((target_q, best_target_q), dim=0) inputs = th.cat((inputs, best_inputs), dim=0) mask = th.cat((mask, best_mask), dim=0) actions = th.cat((actions, best_actions), dim=0) states = th.cat((states, best_batch["state"]), dim=0) mac_out = th.cat((mac_out, best_mac_out), dim=0) #train critic mac_out = mac_out.detach() for t in range(max_t - 1): mask_t = mask[:, t:t + 1] if mask_t.sum() < 0.5: continue k = self.mixer.k(states[:, t:t + 1]).unsqueeze(3) #b = self.mixer.b(states[:, t:t+1]) q_vals = self.critic.forward(inputs[:, t:t + 1]) q_ori = q_vals q_vals = th.gather(q_vals, 3, index=actions[:, t:t + 1]).squeeze(3) q_vals = self.mixer.forward(q_vals, states[:, t:t + 1]) target_q_t = target_q[:, t:t + 1].detach() q_err = (q_vals - target_q_t) * mask_t critic_loss = (q_err**2).sum() / mask_t.sum() #Here introduce the loss for Qi v_vals = th.sum(q_ori * mac_out[:, t:t + 1], dim=3, keepdim=True) ad_vals = q_ori - v_vals goal = th.sum(k * v_vals, dim=2, keepdim=True) + k * ad_vals goal_err = (goal - q_ori) * mask_t goal_loss = 0.1 * (goal_err** 2).sum() / mask_t.sum() / self.args.n_actions #critic_loss += goal_loss self.critic_optimiser.zero_grad() self.mixer_optimiser.zero_grad() critic_loss.backward() grad_norm = 
th.nn.utils.clip_grad_norm_(self.c_params, self.args.grad_norm_clip) self.critic_optimiser.step() self.mixer_optimiser.step() self.critic_training_steps += 1 log["critic_loss"].append(critic_loss.item()) log["critic_grad_norm"].append(grad_norm) mask_elems = mask_t.sum().item() log["td_error_abs"].append((q_err.abs().sum().item() / mask_elems)) log["target_mean"].append( (target_q_t * mask_t).sum().item() / mask_elems) log["q_taken_mean"].append( (q_vals * mask_t).sum().item() / mask_elems) log["q_max_mean"].append( (th.mean(q_ori.max(dim=3)[0], dim=2, keepdim=True) * mask_t).sum().item() / mask_elems) log["q_min_mean"].append( (th.mean(q_ori.min(dim=3)[0], dim=2, keepdim=True) * mask_t).sum().item() / mask_elems) log["q_max_var"].append( (th.var(q_ori.max(dim=3)[0], dim=2, keepdim=True) * mask_t).sum().item() / mask_elems) log["q_min_var"].append( (th.var(q_ori.min(dim=3)[0], dim=2, keepdim=True) * mask_t).sum().item() / mask_elems) if (t == 0): log["q_max_first"] = ( th.mean(q_ori.max(dim=3)[0], dim=2, keepdim=True) * mask_t).sum().item() / mask_elems log["q_min_first"] = ( th.mean(q_ori.min(dim=3)[0], dim=2, keepdim=True) * mask_t).sum().item() / mask_elems #update target network if (self.critic_training_steps - self.last_target_update_step ) / self.args.target_update_interval >= 1.0: self._update_targets() self.last_target_update_step = self.critic_training_steps def train_critic_best(self, batch): bs = batch.batch_size max_t = batch.max_seq_length rewards = batch["reward"][:, :-1] actions = batch["actions"][:, :] terminated = batch["terminated"][:, :-1].float() mask = batch["filled"][:, :-1].float() mask[:, 1:] = mask[:, 1:] * (1 - terminated[:, :-1]) avail_actions = batch["avail_actions"][:] states = batch["state"] # pr for all actions of the episode mac_out = [] self.mac.init_hidden(bs) for i in range(max_t): agent_outs = self.mac.forward(batch, t=i) mac_out.append(agent_outs) mac_out = th.stack(mac_out, dim=1).detach() # Mask out unavailable actions, renormalise (as in action selection) mac_out[avail_actions == 0] = 0 mac_out = mac_out / mac_out.sum(dim=-1, keepdim=True) mac_out[avail_actions == 0] = 0 critic_mac = th.gather(mac_out, 3, actions).squeeze(3).prod(dim=2, keepdim=True) #target_q take target_inputs = self.target_critic._build_inputs(batch, bs, max_t) target_q_vals = self.target_critic.forward(target_inputs).detach() targets_taken = self.target_mixer( th.gather(target_q_vals, dim=3, index=actions).squeeze(3), states) #expected q exp_q = self.build_exp_q(target_q_vals, mac_out, states).detach() # td-error targets_taken[:, -1] = targets_taken[:, -1] * (1 - th.sum(terminated, dim=1)) exp_q[:, -1] = exp_q[:, -1] * (1 - th.sum(terminated, dim=1)) targets_taken[:, :-1] = targets_taken[:, :-1] * mask exp_q[:, :-1] = exp_q[:, :-1] * mask td_q = (rewards + self.args.gamma * exp_q[:, 1:] - targets_taken[:, :-1]) * mask #compute target target_q = build_target_q(td_q, targets_taken[:, :-1], critic_mac, mask, self.args.gamma, self.args.tb_lambda, self.args.step).detach() inputs = self.critic._build_inputs(batch, bs, max_t) return target_q, inputs, mask, actions, mac_out def build_exp_q(self, target_q_vals, mac_out, states): target_exp_q_vals = th.sum(target_q_vals * mac_out, dim=3) target_exp_q_vals = self.target_mixer.forward(target_exp_q_vals, states) return target_exp_q_vals def _update_targets(self): self.target_critic.load_state_dict(self.critic.state_dict()) self.target_mixer.load_state_dict(self.mixer.state_dict()) self.logger.console_logger.info("Updated target network") def 
cuda(self): self.mac.cuda() self.critic.cuda() self.mixer.cuda() self.target_critic.cuda() self.target_mixer.cuda() def save_models(self, path): self.mac.save_models(path) th.save(self.critic.state_dict(), "{}/critic.th".format(path)) th.save(self.mixer.state_dict(), "{}/mixer.th".format(path)) th.save(self.agent_optimiser.state_dict(), "{}/agent_opt.th".format(path)) th.save(self.critic_optimiser.state_dict(), "{}/critic_opt.th".format(path)) th.save(self.mixer_optimiser.state_dict(), "{}/mixer_opt.th".format(path)) def load_models(self, path): self.mac.load_models(path) self.critic.load_state_dict( th.load("{}/critic.th".format(path), map_location=lambda storage, loc: storage)) self.mixer.load_state_dict( th.load("{}/mixer.th".format(path), map_location=lambda storage, loc: storage)) # Not quite right but I don't want to save target networks # self.target_critic.load_state_dict(self.critic.agent.state_dict()) self.target_mixer.load_state_dict(self.mixer.state_dict()) self.agent_optimiser.load_state_dict( th.load("{}/agent_opt.th".format(path), map_location=lambda storage, loc: storage)) self.critic_optimiser.load_state_dict( th.load("{}/critic_opt.th".format(path), map_location=lambda storage, loc: storage)) self.mixer_optimiser.load_state_dict( th.load("{}/mixer_opt.th".format(path), map_location=lambda storage, loc: storage))
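# OffPGLearner.train_critic() above builds its critic targets with
# build_td_lambda_targets. Below is a sketch of the standard backward
# lambda-return recursion that helper is assumed to implement (shapes follow
# the usage above: rewards/terminated/mask are [B, T-1, 1], target_qs is
# [B, T, ...]; the n_agents argument passed in the call above is omitted here
# for brevity). This is an illustration, not the repository's own utility.
import torch as th

def td_lambda_targets_sketch(rewards, terminated, mask, target_qs, gamma, td_lambda):
    ret = target_qs.new_zeros(*target_qs.shape)
    # Bootstrap the final step only for episodes that never terminated.
    ret[:, -1] = target_qs[:, -1] * (1 - th.sum(terminated, dim=1))
    # Backward recursion of the forward-view lambda-return.
    for t in range(ret.shape[1] - 2, -1, -1):
        ret[:, t] = td_lambda * gamma * ret[:, t + 1] + mask[:, t] * (
            rewards[:, t]
            + (1 - td_lambda) * gamma * target_qs[:, t + 1] * (1 - terminated[:, t]))
    # Lambda-returns for t = 0 .. T-2.
    return ret[:, :-1]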
class RODELearner: def __init__(self, mac, scheme, logger, args): self.args = args self.mac = mac self.logger = logger self.n_agents = args.n_agents self.params = list(mac.parameters()) self.last_target_update_episode = 0 self.mixer = None if args.mixer is not None: if args.mixer == "vdn": self.mixer = VDNMixer() elif args.mixer == "qmix": self.mixer = QMixer(args) else: raise ValueError("Mixer {} not recognised.".format(args.mixer)) self.params += list(self.mixer.parameters()) self.target_mixer = copy.deepcopy(self.mixer) self.role_mixer = None if args.role_mixer is not None: if args.role_mixer == "vdn": self.role_mixer = VDNMixer() elif args.role_mixer == "qmix": self.role_mixer = QMixer(args) else: raise ValueError("Role Mixer {} not recognised.".format( args.role_mixer)) self.params += list(self.role_mixer.parameters()) self.target_role_mixer = copy.deepcopy(self.role_mixer) self.optimiser = RMSprop(params=self.params, lr=args.lr, alpha=args.optim_alpha, eps=args.optim_eps) # a little wasteful to deepcopy (e.g. duplicates action selector), but should work for any MAC self.target_mac = copy.deepcopy(mac) self.log_stats_t = -self.args.learner_log_interval - 1 self.role_interval = args.role_interval self.device = self.args.device self.role_action_spaces_updated = True # action encoder self.action_encoder_params = list(self.mac.action_encoder_params()) self.action_encoder_optimiser = RMSprop( params=self.action_encoder_params, lr=args.lr, alpha=args.optim_alpha, eps=args.optim_eps) def train(self, batch: EpisodeBatch, t_env: int, episode_num: int): # Get the relevant quantities rewards = batch["reward"][:, :-1] actions = batch["actions"][:, :-1] terminated = batch["terminated"][:, :-1].float() mask = batch["filled"][:, :-1].float() mask[:, 1:] = mask[:, 1:] * (1 - terminated[:, :-1]) avail_actions = batch["avail_actions"] # role_avail_actions = batch["role_avail_actions"] roles_shape_o = batch["roles"][:, :-1].shape role_at = int(np.ceil(roles_shape_o[1] / self.role_interval)) role_t = role_at * self.role_interval roles_shape = list(roles_shape_o) roles_shape[1] = role_t roles = th.zeros(roles_shape).to(self.device) roles[:, :roles_shape_o[1]] = batch["roles"][:, :-1] roles = roles.view(batch.batch_size, role_at, self.role_interval, self.n_agents, -1)[:, :, 0] # Calculate estimated Q-Values mac_out = [] role_out = [] self.mac.init_hidden(batch.batch_size) for t in range(batch.max_seq_length): agent_outs, role_outs = self.mac.forward(batch, t=t) mac_out.append(agent_outs) if t % self.role_interval == 0 and t < batch.max_seq_length - 1: role_out.append(role_outs) mac_out = th.stack(mac_out, dim=1) # Concat over time role_out = th.stack(role_out, dim=1) # Concat over time # Pick the Q-Values for the actions taken by each agent chosen_action_qvals = th.gather(mac_out[:, :-1], dim=3, index=actions).squeeze( 3) # Remove the last dim chosen_role_qvals = th.gather(role_out, dim=3, index=roles.long()).squeeze(3) # Calculate the Q-Values necessary for the target target_mac_out = [] target_role_out = [] self.target_mac.init_hidden(batch.batch_size) for t in range(batch.max_seq_length): target_agent_outs, target_role_outs = self.target_mac.forward( batch, t=t) target_mac_out.append(target_agent_outs) if t % self.role_interval == 0 and t < batch.max_seq_length - 1: target_role_out.append(target_role_outs) target_role_out.append( th.zeros(batch.batch_size, self.n_agents, self.mac.n_roles).to(self.device)) # We don't need the first timesteps Q-Value estimate for calculating targets target_mac_out = 
th.stack(target_mac_out[1:], dim=1) # Concat across time target_role_out = th.stack(target_role_out[1:], dim=1) # Mask out unavailable actions target_mac_out[avail_actions[:, 1:] == 0] = -9999999 # target_mac_out[role_avail_actions[:, 1:] == 0] = -9999999 # Max over target Q-Values if self.args.double_q: # Get actions that maximise live Q (for double q-learning) mac_out_detach = mac_out.clone().detach() mac_out_detach[avail_actions == 0] = -9999999 # mac_out_detach[role_avail_actions == 0] = -9999999 cur_max_actions = mac_out_detach[:, 1:].max(dim=3, keepdim=True)[1] target_max_qvals = th.gather(target_mac_out, 3, cur_max_actions).squeeze(3) role_out_detach = role_out.clone().detach() role_out_detach = th.cat( [role_out_detach[:, 1:], role_out_detach[:, 0:1]], dim=1) cur_max_roles = role_out_detach.max(dim=3, keepdim=True)[1] target_role_max_qvals = th.gather(target_role_out, 3, cur_max_roles).squeeze(3) else: target_max_qvals = target_mac_out.max(dim=3)[0] target_role_max_qvals = target_role_out.max(dim=3)[0] # Mix if self.mixer is not None: chosen_action_qvals = self.mixer(chosen_action_qvals, batch["state"][:, :-1]) target_max_qvals = self.target_mixer(target_max_qvals, batch["state"][:, 1:]) if self.role_mixer is not None: state_shape_o = batch["state"][:, :-1].shape state_shape = list(state_shape_o) state_shape[1] = role_t role_states = th.zeros(state_shape).to(self.device) role_states[:, :state_shape_o[1]] = batch["state"][:, :-1].detach( ).clone() role_states = role_states.view(batch.batch_size, role_at, self.role_interval, -1)[:, :, 0] chosen_role_qvals = self.role_mixer(chosen_role_qvals, role_states) role_states = th.cat([role_states[:, 1:], role_states[:, 0:1]], dim=1) target_role_max_qvals = self.target_role_mixer( target_role_max_qvals, role_states) # Calculate 1-step Q-Learning targets targets = rewards + self.args.gamma * (1 - terminated) * target_max_qvals rewards_shape = list(rewards.shape) rewards_shape[1] = role_t role_rewards = th.zeros(rewards_shape).to(self.device) role_rewards[:, :rewards.shape[1]] = rewards.detach().clone() role_rewards = role_rewards.view(batch.batch_size, role_at, self.role_interval).sum(dim=-1, keepdim=True) # role_terminated terminated_shape_o = terminated.shape terminated_shape = list(terminated_shape_o) terminated_shape[1] = role_t role_terminated = th.zeros(terminated_shape).to(self.device) role_terminated[:, :terminated_shape_o[1]] = terminated.detach().clone( ) role_terminated = role_terminated.view( batch.batch_size, role_at, self.role_interval).sum(dim=-1, keepdim=True) # role_terminated role_targets = role_rewards + self.args.gamma * ( 1 - role_terminated) * target_role_max_qvals # Td-error td_error = (chosen_action_qvals - targets.detach()) role_td_error = (chosen_role_qvals - role_targets.detach()) mask = mask.expand_as(td_error) mask_shape = list(mask.shape) mask_shape[1] = role_t role_mask = th.zeros(mask_shape).to(self.device) role_mask[:, :mask.shape[1]] = mask.detach().clone() role_mask = role_mask.view(batch.batch_size, role_at, self.role_interval, -1)[:, :, 0] # 0-out the targets that came from padded data masked_td_error = td_error * mask masked_role_td_error = role_td_error * role_mask # Normal L2 loss, take mean over actual data loss = (masked_td_error**2).sum() / mask.sum() role_loss = (masked_role_td_error**2).sum() / role_mask.sum() loss += role_loss # Optimise self.optimiser.zero_grad() loss.backward() grad_norm = th.nn.utils.clip_grad_norm_(self.params, self.args.grad_norm_clip) self.optimiser.step() pred_obs_loss = None 
pred_r_loss = None pred_grad_norm = None if self.role_action_spaces_updated: # train action encoder no_pred = [] r_pred = [] for t in range(batch.max_seq_length): no_preds, r_preds = self.mac.action_repr_forward(batch, t=t) no_pred.append(no_preds) r_pred.append(r_preds) no_pred = th.stack(no_pred, dim=1)[:, :-1] # Concat over time r_pred = th.stack(r_pred, dim=1)[:, :-1] no = batch["obs"][:, 1:].detach().clone() repeated_rewards = batch["reward"][:, :-1].detach().clone( ).unsqueeze(2).repeat(1, 1, self.n_agents, 1) pred_obs_loss = th.sqrt(((no_pred - no)**2).sum(dim=-1)).mean() pred_r_loss = ((r_pred - repeated_rewards)**2).mean() pred_loss = pred_obs_loss + 10 * pred_r_loss self.action_encoder_optimiser.zero_grad() pred_loss.backward() pred_grad_norm = th.nn.utils.clip_grad_norm_( self.action_encoder_params, self.args.grad_norm_clip) self.action_encoder_optimiser.step() if t_env > self.args.role_action_spaces_update_start: self.mac.update_role_action_spaces() if 'noar' in self.args.mac: self.target_mac.role_selector.update_roles( self.mac.n_roles) self.role_action_spaces_updated = False self._update_targets() self.last_target_update_episode = episode_num if (episode_num - self.last_target_update_episode ) / self.args.target_update_interval >= 1.0: self._update_targets() self.last_target_update_episode = episode_num if t_env - self.log_stats_t >= self.args.learner_log_interval: self.logger.log_stat("loss", (loss - role_loss).item(), t_env) self.logger.log_stat("role_loss", role_loss.item(), t_env) self.logger.log_stat("grad_norm", grad_norm, t_env) if pred_obs_loss is not None: self.logger.log_stat("pred_obs_loss", pred_obs_loss.item(), t_env) self.logger.log_stat("pred_r_loss", pred_r_loss.item(), t_env) self.logger.log_stat("action_encoder_grad_norm", pred_grad_norm, t_env) mask_elems = mask.sum().item() self.logger.log_stat( "td_error_abs", (masked_td_error.abs().sum().item() / mask_elems), t_env) self.logger.log_stat("q_taken_mean", (chosen_action_qvals * mask).sum().item() / (mask_elems * self.args.n_agents), t_env) self.logger.log_stat("role_q_taken_mean", (chosen_role_qvals * role_mask).sum().item() / (role_mask.sum().item() * self.args.n_agents), t_env) self.logger.log_stat("target_mean", (targets * mask).sum().item() / (mask_elems * self.args.n_agents), t_env) self.log_stats_t = t_env def _update_targets(self): self.target_mac.load_state(self.mac) if self.mixer is not None: self.target_mixer.load_state_dict(self.mixer.state_dict()) if self.role_mixer is not None: self.target_role_mixer.load_state_dict( self.role_mixer.state_dict()) self.target_mac.role_action_spaces_updated = self.role_action_spaces_updated self.logger.console_logger.info("Updated target network") def cuda(self): self.mac.cuda() self.target_mac.cuda() if self.mixer is not None: self.mixer.cuda() self.target_mixer.cuda() if self.role_mixer is not None: self.role_mixer.cuda() self.target_role_mixer.cuda() def save_models(self, path): self.mac.save_models(path) if self.mixer is not None: th.save(self.mixer.state_dict(), "{}/mixer.th".format(path)) if self.role_mixer is not None: th.save(self.role_mixer.state_dict(), "{}/role_mixer.th".format(path)) th.save(self.optimiser.state_dict(), "{}/opt.th".format(path)) th.save(self.action_encoder_optimiser.state_dict(), "{}/action_repr_opt.th".format(path)) def load_models(self, path): self.mac.load_models(path) # Not quite right but I don't want to save target networks self.target_mac.load_models(path) if self.mixer is not None: self.mixer.load_state_dict( 
th.load("{}/mixer.th".format(path), map_location=lambda storage, loc: storage)) if self.role_mixer is not None: self.role_mixer.load_state_dict( th.load("{}/role_mixer.th".format(path), map_location=lambda storage, loc: storage)) self.optimiser.load_state_dict( th.load("{}/opt.th".format(path), map_location=lambda storage, loc: storage)) self.action_encoder_optimiser.load_state_dict( th.load("{}/action_repr_opt.th".format(path), map_location=lambda storage, loc: storage))
class MAACLearner: def __init__(self, mac, scheme, logger, args): self.args = args self.n_agents = args.n_agents self.n_actions = args.n_actions self.mac = mac self.logger = logger self.last_target_update_step = 0 self.critic_training_steps = 0 self.log_stats_t = -self.args.learner_log_interval - 1 self.critic = MAACCritic(scheme, args) self.target_critic = copy.deepcopy(self.critic) self.policies = partial(self.mac.forward, target=False) self.target_policies = partial(self.mac.forward, target=True) self.agent_params = list(mac.parameters()) self.critic_params = list(self.critic.parameters()) self.params = self.agent_params + self.critic_params self.gamma = args.gamma self.tau = args.tau self.reward_scale = args.reward_scale self.soft = args.soft self.agent_optimisers = [] for i in range(self.n_agents): agent_optimiser = RMSprop(params=self.agent_params[i], lr=args.lr, alpha=args.optim_alpha, eps=args.optim_eps) self.agent_optimisers.append(agent_optimiser) self.critic_optimiser = RMSprop(params=self.critic_params, lr=args.critic_lr, alpha=args.optim_alpha, eps=args.optim_eps) def train(self, batch: EpisodeBatch, t_env: int, episode_num: int): # Get the relevant quantities bs = batch.batch_size max_t = batch.max_seq_length terminated = batch["terminated"].float() mask = batch["filled"].float() mask[:, 1:] = mask[:, 1:] * (1 - terminated[:, :-1]) rewards = batch["reward"][:, :-1].unsqueeze(0).expand( self.n_agents, -1, -1, -1).reshape( self.n_agents, batch.batch_size * (batch.max_seq_length-1), -1) terminated = terminated[:, :-1].unsqueeze(0).expand( self.n_agents, -1, -1, -1).reshape( self.n_agents, batch.batch_size * (batch.max_seq_length-1), -1) mask = mask[:, :-1].unsqueeze(0).expand( self.n_agents, -1, -1, -1).reshape( self.n_agents, batch.batch_size * (batch.max_seq_length-1), -1) # avail_actions = batch["avail_actions"][:, :-1] critic_train_stats = self._train_critic(batch, rewards, terminated, mask, bs, max_t) self.mac.init_hidden(batch.batch_size, target=False) samp_acs = [None] * self.n_agents all_probs = [None] * self.n_agents all_log_pis = [None] * self.n_agents all_pol_regs = [None] * self.n_agents all_pol_ents = [None] * self.n_agents # resample the current state's action use the behavior policy # note that we need to remove the resample action which is t = batch.max_seq_length for t in range(max_t): all_agent_rets = self.policies(batch, t=t, return_extras=True, return_all_probs=True, return_log_pi=True, regularize=True, return_entropy=True) for i in range(self.n_agents): curr_ac_t, probs_t, log_pi_t, pol_regs_t, ent_t = all_agent_rets[i] if t > 0: samp_acs[i] = th.cat((samp_acs[i], curr_ac_t), 0) all_probs[i] = th.cat((all_probs[i], probs_t), 0) all_log_pis[i] = th.cat((all_log_pis[i], log_pi_t), 0) # remember that we need to remove the last timestep if t < max_t - 1: all_pol_regs[i].append(pol_regs_t[0]) all_pol_ents[i].append(ent_t) else: samp_acs[i] = curr_ac_t all_probs[i] = probs_t all_log_pis[i] = log_pi_t all_pol_regs[i] = pol_regs_t all_pol_ents[i] = [ent_t] for i in range(self.n_agents): all_pol_regs[i] = [th.stack(all_pol_regs[i]).mean()] mean_policy_entropy = stats.mean( [th.stack(pol_ents).mean().item() for pol_ents in all_pol_ents]) # construct resample batch, i.e. 
replace the actions (and actions_onehot) # with the above resample actions resample_batch = copy.deepcopy(batch) # reshape the next action to match the shape in next batch reshaped_resample_acs = None for i in range(self.n_agents): _reshaped_resample_ac = samp_acs[i].reshape(bs, max_t, -1).unsqueeze(2) if i > 0: reshaped_resample_acs = th.cat( (reshaped_resample_acs, _reshaped_resample_ac), 2) else: reshaped_resample_acs = _reshaped_resample_ac # construct the resample action onehot according to resample action # the shape of resample action onehot also need to match the shape in resample batch reshaped_resample_acs_onehot = resample_batch.data.transition_data[ 'actions_onehot'].clone().fill_(0) reshaped_resample_acs_onehot.scatter_(3, reshaped_resample_acs, 1) reshaped_dict = {'actions': reshaped_resample_acs, 'actions_onehot': reshaped_resample_acs_onehot} for key in resample_batch.scheme.keys(): if key in ('actions', 'actions_onehot'): resample_batch.data.transition_data[key] = reshaped_dict[key][:, 1:] else: resample_batch.data.transition_data[key] = resample_batch[key][:, 1:] resample_batch.max_seq_length -= 1 # construct pre batch all_probs and all_log_pis, i.e. remove the first timestep for i in range(self.n_agents): all_probs[i] = all_probs[i].reshape(bs, max_t, -1)[:, 1:] all_probs[i] = all_probs[i].reshape(bs * (max_t - 1), -1) all_log_pis[i] = all_log_pis[i].reshape(bs, max_t, -1)[:, 1:] all_log_pis[i] = all_log_pis[i].reshape(bs * (max_t - 1), -1) grad_norms = [] pol_losses = [] advantages = [] mask_elems = mask.sum().item() critic_rets = self.critic(resample_batch, return_all_q=True) for a_i, probs, log_pi, pol_regs, (q, all_q) in zip( range(self.n_agents), all_probs, all_log_pis, all_pol_regs, critic_rets): curr_agent = self.mac.agents[a_i] v = (all_q * probs).sum(dim=1, keepdim=True) pol_target = q - v advantages.append((pol_target * mask).sum().item() / mask_elems) if self.soft: pol_loss = ((log_pi * ( log_pi / self.reward_scale - pol_target).detach()) * mask).sum() / mask.sum() else: pol_loss = ((log_pi * (-pol_target).detach()) * mask).sum() / mask.sum() for reg in pol_regs: pol_loss += 1e-3 * reg # policy regularization # don't want critic to accumulate gradients from policy loss disable_gradients(self.critic) pol_loss.backward(retain_graph=True) enable_gradients(self.critic) pol_losses.append(pol_loss.item()) grad_norm = th.nn.utils.clip_grad_norm_( curr_agent.parameters(), 0.5) grad_norms.append(grad_norm) self.agent_optimisers[a_i].step() self.agent_optimisers[a_i].zero_grad() if (self.critic_training_steps - self.last_target_update_step) /\ self.args.target_update_interval >= 1.0: self._update_targets() self.last_target_update_step = self.critic_training_steps if t_env - self.log_stats_t >= self.args.learner_log_interval: ts_logged = len(critic_train_stats["critic_loss"]) for key in ["critic_loss", "critic_grad_norm", \ "td_error_abs", "q_taken_mean", "target_mean"]: self.logger.log_stat(key, sum(critic_train_stats[key])/ts_logged, t_env) self.logger.log_stat("advantage_mean", stats.mean(advantages), t_env) self.logger.log_stat("maac_loss", stats.mean(pol_losses), t_env) self.logger.log_stat("agent_grad_norm", stats.mean(grad_norms), t_env) # self.logger.log_stat("pi_max", (pi.max(dim=1)[0] * mask).sum().item() / mask.sum().item(), t_env) self.logger.log_stat("policy_entropy", mean_policy_entropy, t_env) self.log_stats_t = t_env def _train_critic(self, batch, rewards, terminated, mask, bs, max_t): """ Update central critic for all agents """ self.mac.init_hidden(bs, 
target=True) next_acs = [None] * self.n_agents next_log_pis = [None] * self.n_agents # get the next state's action use the target policy # note that we need to remove the next action which is t = 0 for t in range(max_t): all_agent_rets = self.target_policies( batch, t=t, return_extras=True, return_log_pi=True) for i in range(self.n_agents): curr_next_ac_t, curr_next_log_pi_t = all_agent_rets[i] if t > 0: next_acs[i] = th.cat((next_acs[i], curr_next_ac_t), 0) next_log_pis[i] = th.cat((next_log_pis[i], curr_next_log_pi_t), 0) else: next_acs[i] = curr_next_ac_t next_log_pis[i] = curr_next_log_pi_t # construct next batch, i.e. remove the first timestep in old batch, # and replace the actions (and actions_onehot) with the above next actions next_batch = copy.deepcopy(batch) # reshape the next action to match the shape in next batch reshaped_next_acs = None for i in range(self.n_agents): _reshaped_next_ac = next_acs[i].reshape(bs, max_t, -1).unsqueeze(2) if i > 0: reshaped_next_acs = th.cat((reshaped_next_acs, _reshaped_next_ac), 2) else: reshaped_next_acs = _reshaped_next_ac # construct the next action onehot according to next action # the shape of next action onehot also need to match the shape in next batch reshaped_next_acs_onehot = next_batch.data.transition_data[ 'actions_onehot'].clone().fill_(0) reshaped_next_acs_onehot.scatter_(3, reshaped_next_acs, 1) reshaped_dict = {'actions': reshaped_next_acs, 'actions_onehot': reshaped_next_acs_onehot} for key in next_batch.scheme.keys(): if key in ('actions', 'actions_onehot'): next_batch.data.transition_data[key] = reshaped_dict[key][:, 1:] else: next_batch.data.transition_data[key] = next_batch[key][:, 1:] next_batch.max_seq_length -= 1 # construct pre batch, i.e. remove the last timestep in old batch, pre_batch = copy.deepcopy(batch) for key in pre_batch.scheme.keys(): pre_batch.data.transition_data[key] = pre_batch[key][:, 1:] pre_batch.max_seq_length -= 1 # calculate next_qs next_qs = self.target_critic(next_batch) # calculate current_qs critic_rets = self.critic(pre_batch, regularize=True) # construct next batch next_log_pis, i.e. remove the first timestep for i in range(self.n_agents): next_log_pis[i] = next_log_pis[i].reshape(bs, max_t, -1)[:, 1:] next_log_pis[i] = next_log_pis[i].reshape(bs * (max_t - 1), -1) # construct mask, i.e. 
remove the first timestep running_log = { "critic_loss": [], "critic_grad_norm": [], "td_error_abs": [], "target_mean": [], "q_taken_mean": [], } q_loss = 0 td_error = 0 abs_td_errors = [] q_takens = [] targets = [] mask_elems = mask.sum().item() for a_i, nq, log_pi, (pq, regs) in zip( range(self.n_agents), next_qs, next_log_pis, critic_rets): target_q = (rewards[a_i] + self.gamma * nq * (1 - terminated[a_i])) td_error += pq - target_q.detach() abs_td_errors.append((( pq - target_q.detach()).abs() * mask).sum().item() / mask_elems) q_takens.append((pq * mask).sum().item() / mask_elems) targets.append((target_q * mask).sum().item() / mask_elems) if self.soft: target_q -= log_pi / self.reward_scale masked_td_error = td_error * mask q_loss = (masked_td_error ** 2).sum() / mask.sum() for reg in regs: q_loss += reg # regularizing attention q_loss.backward() self.critic.scale_shared_grads() grad_norm = th.nn.utils.clip_grad_norm_( self.critic.parameters(), 10 * self.n_agents) self.critic_optimiser.step() self.critic_training_steps += (max_t - 1) self.critic_optimiser.zero_grad() running_log["critic_loss"].append(q_loss.item()) running_log["critic_grad_norm"].append(grad_norm) running_log["td_error_abs"].append(stats.mean(abs_td_errors)) running_log["q_taken_mean"].append(stats.mean(q_takens)) running_log["target_mean"].append(stats.mean(targets)) return running_log def _update_targets(self): soft_update(self.target_critic, self.critic, self.tau) for i in range(self.n_agents): soft_update(self.mac.target_agents[i], self.mac.agents[i], self.tau) self.logger.console_logger.info("Updated target network") def cuda(self): self.mac.cuda() self.critic.cuda() self.target_critic.cuda() def save_models(self, path): self.mac.save_models(path) th.save(self.critic.state_dict(), "{}/critic.th".format(path)) for i in range(self.n_agents): th.save(self.agent_optimisers[i].state_dict(), "{}/agent_{}_opt.th".format(path, i)) th.save(self.critic_optimiser.state_dict(), "{}/critic_opt.th".format(path)) def load_models(self, path): self.mac.load_models(path) self.critic.load_state_dict(th.load("{}/critic.th".format(path), map_location=lambda storage, loc: storage)) # Not quite right but I don't want to save target networks self.target_critic.load_state_dict(self.critic.state_dict()) for i in range(self.n_agents): self.agent_optimisers[i].load_state_dict(th.load("{}/agent_{}_opt.th".format(path, i), map_location=lambda storage, loc: storage)) self.critic_optimiser.load_state_dict(th.load("{}/critic_opt.th".format(path), map_location=lambda storage, loc: storage))
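# MAACLearner relies on three small helpers defined elsewhere in the
# repository: soft_update (Polyak-averaged target updates) and
# disable_gradients / enable_gradients (freezing the critic while the policy
# loss is backpropagated). Minimal sketches of what they are assumed to do:
def soft_update(target, source, tau):
    # target <- (1 - tau) * target + tau * source, parameter by parameter
    for target_param, source_param in zip(target.parameters(), source.parameters()):
        target_param.data.copy_(target_param.data * (1.0 - tau) + source_param.data * tau)

def disable_gradients(module):
    for p in module.parameters():
        p.requires_grad = False

def enable_gradients(module):
    for p in module.parameters():
        p.requires_grad = True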
class PairComaLearner: def __init__(self, mac, scheme, logger, args): self.args = args self.n_agents = args.n_agents self.n_actions = args.n_actions self.mac = mac self.logger = logger self.last_target_update_step = 0 self.critic_training_steps = 0 self.log_stats_t = -self.args.learner_log_interval - 1 self.critic = PairComaCritic(scheme, args) self.target_critic = copy.deepcopy(self.critic) self.agent_params = list(mac.parameters()) self.critic_params = list(self.critic.parameters()) self.params = self.agent_params + self.critic_params self.agent_optimiser = RMSprop(params=self.agent_params, lr=args.lr, alpha=args.optim_alpha, eps=args.optim_eps) self.critic_optimiser = RMSprop(params=self.critic_params, lr=args.critic_lr, alpha=args.optim_alpha, eps=args.optim_eps) def train(self, batch: EpisodeBatch, t_env: int, episode_num: int): # Get the relevant quantities bs = batch.batch_size max_t = batch.max_seq_length rewards = batch["reward"][:, :-1] actions = batch["actions"][:, :] terminated = batch["terminated"][:, :-1].float() mask = batch["filled"][:, :-1].float() mask[:, 1:] = mask[:, 1:] * (1 - terminated[:, :-1]) avail_actions = batch["avail_actions"][:, :-1] critic_mask = mask.clone() mask = mask.repeat(1, 1, self.n_agents).view(-1) q_vals, critic_train_stats = self._train_critic( batch, rewards, terminated, actions, critic_mask, bs) actions = actions[:, :-1] mac_out = [] self.mac.init_hidden(batch.batch_size) for t in range(batch.max_seq_length - 1): agent_outs = self.mac.forward(batch, t=t) mac_out.append(agent_outs) mac_out = th.stack(mac_out, dim=1) # Concat over time # Mask out unavailable actions, renormalise (as in action selection) mac_out[avail_actions == 0] = 0 mac_out = mac_out / mac_out.sum(dim=-1, keepdim=True) mac_out[avail_actions == 0] = 0 identity = th.eye( self.n_agents, device=batch.device).unsqueeze(0).unsqueeze(0).unsqueeze(4).expand( bs, q_vals.shape[1], -1, -1, self.n_actions) q_vals = (1 - identity) * q_vals pi = mac_out.view(-1, self.n_actions) # q_vals = q_vals.reshape(-1, self.n_actions) # # baseline = (pi * q_vals).sum(-1).detach() # # # Calculate policy grad with mask # q_taken = th.gather(q_vals, dim=1, index=actions.reshape(-1, 1)).squeeze(1) adv_temp = ( mac_out.unsqueeze(3).expand(-1, -1, -1, self.n_agents, -1) * q_vals).sum(4).sum(3) actions_for_adv = actions.unsqueeze(3).expand(-1, -1, -1, self.n_agents, -1) advantages = ( th.gather(q_vals, dim=4, index=actions_for_adv).squeeze(4).sum(3) - adv_temp).view(-1).detach() pi_taken = th.gather(pi, dim=1, index=actions.reshape(-1, 1)).squeeze(1) pi_taken[mask == 0] = 1.0 log_pi_taken = th.log(pi_taken) # advantages = (q_taken - baseline).detach() coma_loss = -((advantages * log_pi_taken) * mask).sum() / mask.sum() # Optimise agents self.agent_optimiser.zero_grad() coma_loss.backward() grad_norm = th.nn.utils.clip_grad_norm_(self.agent_params, self.args.grad_norm_clip) self.agent_optimiser.step() if (self.critic_training_steps - self.last_target_update_step ) / self.args.target_update_interval >= 1.0: self._update_targets() self.last_target_update_step = self.critic_training_steps if t_env - self.log_stats_t >= self.args.learner_log_interval: ts_logged = len(critic_train_stats["critic_loss"]) for key in [ "critic_loss", "critic_grad_norm", "td_error_abs", "q_taken_mean", "target_mean" ]: self.logger.log_stat(key, sum(critic_train_stats[key]) / ts_logged, t_env) self.logger.log_stat("advantage_mean", (advantages * mask).sum().item() / mask.sum().item(), t_env) self.logger.log_stat("coma_loss", 
coma_loss.item(), t_env) self.logger.log_stat("agent_grad_norm", grad_norm, t_env) self.logger.log_stat("pi_max", (pi.max(dim=1)[0] * mask).sum().item() / mask.sum().item(), t_env) self.log_stats_t = t_env def get_q_taken(self, batch, q, t=None): bs = batch.batch_size max_t = batch.max_seq_length if t is None else 1 ts = slice(None) if t is None else slice(t, t + 1) identity = th.eye( self.n_agents, device=batch.device).unsqueeze(0).unsqueeze(0).unsqueeze(4).expand( bs, max_t, -1, -1, self.n_actions) q = (1 - identity) * q actions = batch["actions"][:, ts].unsqueeze(3).expand( -1, -1, -1, self.n_agents, -1) q_taken = th.gather(q, dim=4, index=actions).squeeze(4) q_tot = q_taken.sum(3).sum(2).unsqueeze(2).expand( -1, -1, self.n_agents) q_part = q_taken.sum(2) q_part_2 = q_taken.sum(3) q_result = q_tot - q_part + q_part_2 return q_tot def _train_critic(self, batch, rewards, terminated, actions, mask, bs): # Optimise critic target_q_vals = self.target_critic(batch) targets_taken = self.get_q_taken(batch, target_q_vals) # Calculate td-lambda targets targets = build_td_lambda_targets(rewards, terminated, mask, targets_taken, self.n_agents, self.args.gamma, self.args.td_lambda) q_vals = th.zeros_like(target_q_vals)[:, :-1] running_log = { "critic_loss": [], "critic_grad_norm": [], "td_error_abs": [], "target_mean": [], "q_taken_mean": [], } for t in reversed(range(rewards.size(1))): mask_t = mask[:, t].expand(-1, self.n_agents) if mask_t.sum() == 0: continue q_t = self.critic(batch, t) q_vals[:, t] = q_t.view(bs, self.n_agents, self.n_agents, self.n_actions) q_taken = self.get_q_taken(batch, q_t, t).squeeze(1) targets_t = targets[:, t] td_error = (q_taken - targets_t.detach()) # 0-out the targets that came from padded data masked_td_error = td_error * mask_t # Normal L2 loss, take mean over actual data loss = (masked_td_error**2).sum() / mask_t.sum() self.critic_optimiser.zero_grad() loss.backward() grad_norm = th.nn.utils.clip_grad_norm_(self.critic_params, self.args.grad_norm_clip) self.critic_optimiser.step() self.critic_training_steps += 1 running_log["critic_loss"].append(loss.item()) running_log["critic_grad_norm"].append(grad_norm) mask_elems = mask_t.sum().item() running_log["td_error_abs"].append( (masked_td_error.abs().sum().item() / mask_elems)) running_log["q_taken_mean"].append( (q_taken * mask_t).sum().item() / mask_elems) running_log["target_mean"].append( (targets_t * mask_t).sum().item() / mask_elems) return q_vals, running_log def _update_targets(self): self.target_critic.load_state_dict(self.critic.state_dict()) self.logger.console_logger.info("Updated target network") def cuda(self): self.mac.cuda() self.critic.cuda() self.target_critic.cuda() def save_models(self, path): self.mac.save_models(path) th.save(self.critic.state_dict(), "{}/critic.th".format(path)) th.save(self.agent_optimiser.state_dict(), "{}/agent_opt.th".format(path)) th.save(self.critic_optimiser.state_dict(), "{}/critic_opt.th".format(path)) def load_models(self, path): self.mac.load_models(path) self.critic.load_state_dict( th.load("{}/critic.th".format(path), map_location=lambda storage, loc: storage)) # Not quite right but I don't want to save target networks self.target_critic.load_state_dict(self.critic.state_dict()) self.agent_optimiser.load_state_dict( th.load("{}/agent_opt.th".format(path), map_location=lambda storage, loc: storage)) self.critic_optimiser.load_state_dict( th.load("{}/critic_opt.th".format(path), map_location=lambda storage, loc: storage))
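# The commented-out lines in PairComaLearner.train() correspond to the plain
# COMA counterfactual baseline (expected Q under the current policy), which
# the pairwise advantage above generalises. A minimal sketch of that baseline
# and advantage, with illustrative names and flattened [N, n_actions] shapes:
import torch as th

def coma_advantage(q_vals, pi, actions, mask):
    """A(s, u) = Q(s, u) - sum_u' pi(u' | s) Q(s, u'), masked over padded steps."""
    baseline = (pi * q_vals).sum(dim=-1).detach()
    q_taken = th.gather(q_vals, dim=1, index=actions).squeeze(1)
    return (q_taken - baseline).detach() * mask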
class CateQLearner: def __init__(self, mac, scheme, logger, args): self.args = args self.mac = mac self.logger = logger self.params = list(mac.parameters()) self.last_target_update_episode = 0 self.mixer = None if args.mixer is not None: if args.mixer == "vdn": self.mixer = VDNMixer() elif args.mixer == "qmix": self.mixer = QMixer(args) else: raise ValueError("Mixer {} not recognised.".format(args.mixer)) self.params += list(self.mixer.parameters()) self.target_mixer = copy.deepcopy(self.mixer) self.optimiser = RMSprop(params=self.params, lr=args.lr, alpha=args.optim_alpha, eps=args.optim_eps) # a little wasteful to deepcopy (e.g. duplicates action selector), but should work for any MAC self.target_mac = copy.deepcopy(mac) self.log_stats_t = -self.args.learner_log_interval - 1 self.s_mu = th.zeros(1) self.s_sigma = th.ones(1) def get_comm_beta(self, t_env): comm_beta = self.args.comm_beta if self.args.is_comm_beta_decay and t_env > self.args.comm_beta_start_decay: comm_beta += 1. * (self.args.comm_beta_target - self.args.comm_beta) / \ (self.args.comm_beta_end_decay - self.args.comm_beta_start_decay) * \ (t_env - self.args.comm_beta_start_decay) return comm_beta def get_comm_entropy_beta(self, t_env): comm_entropy_beta = self.args.comm_entropy_beta if self.args.is_comm_entropy_beta_decay and t_env > self.args.comm_entropy_beta_start_decay: comm_entropy_beta += 1. * (self.args.comm_entropy_beta_target - self.args.comm_entropy_beta) / \ (self.args.comm_entropy_beta_end_decay - self.args.comm_entropy_beta_start_decay) * \ (t_env - self.args.comm_entropy_beta_start_decay) return comm_entropy_beta def train(self, batch: EpisodeBatch, t_env: int, episode_num: int): # Get the relevant quantities rewards = batch["reward"][:, :-1] actions = batch["actions"][:, :-1] terminated = batch["terminated"][:, :-1].float() mask = batch["filled"][:, :-1].float() mask[:, 1:] = mask[:, 1:] * (1 - terminated[:, :-1]) avail_actions = batch["avail_actions"] # Calculate estimated Q-Values # shape = (bs, self.n_agents, -1) mac_out = [] mu_out = [] sigma_out = [] logits_out = [] m_sample_out = [] g_out = [] self.mac.init_hidden(batch.batch_size) for t in range(batch.max_seq_length): if self.args.comm and self.args.use_IB: agent_outs, (mu, sigma), logits, m_sample = self.mac.forward(batch, t=t) mu_out.append(mu) sigma_out.append(sigma) logits_out.append(logits) m_sample_out.append(m_sample) else: agent_outs = self.mac.forward(batch, t=t) mac_out.append(agent_outs) mac_out = th.stack(mac_out, dim=1) # Concat over time if self.args.use_IB: mu_out = th.stack(mu_out, dim=1)[:, :-1] # Concat over time sigma_out = th.stack(sigma_out, dim=1)[:, :-1] # Concat over time logits_out = th.stack(logits_out, dim=1)[:, :-1] m_sample_out = th.stack(m_sample_out, dim=1)[:, :-1] # Pick the Q-Values for the actions taken by each agent chosen_action_qvals = th.gather(mac_out[:, :-1], dim=3, index=actions).squeeze( 3) # Remove the last dim # I believe that code up to here is right... # Q values are right, the main issue is to calculate loss for message... 
# Calculate the Q-Values necessary for the target target_mac_out = [] self.target_mac.init_hidden(batch.batch_size) for t in range(batch.max_seq_length): if self.args.comm and self.args.use_IB: target_agent_outs, (target_mu, target_sigma), target_logits, target_m_sample = \ self.target_mac.forward(batch, t=t) else: target_agent_outs = self.target_mac.forward(batch, t=t) target_mac_out.append(target_agent_outs) # label label_target_max_out = th.stack(target_mac_out[:-1], dim=1) label_target_max_out[avail_actions[:, :-1] == 0] = -9999999 label_target_actions = label_target_max_out.max(dim=3, keepdim=True)[1] # We don't need the first timesteps Q-Value estimate for calculating targets target_mac_out = th.stack(target_mac_out[1:], dim=1) # Concat across time # Mask out unavailable actions target_mac_out[avail_actions[:, 1:] == 0] = -9999999 # Max over target Q-Values if self.args.double_q: # Get actions that maximise live Q (for double q-learning) mac_out[avail_actions == 0] = -9999999 cur_max_actions = mac_out[:, 1:].max(dim=3, keepdim=True)[1] target_max_qvals = th.gather(target_mac_out, 3, cur_max_actions).squeeze(3) else: target_max_qvals = target_mac_out.max(dim=3)[0] # Mix if self.mixer is not None: chosen_action_qvals = self.mixer(chosen_action_qvals, batch["state"][:, :-1]) target_max_qvals = self.target_mixer(target_max_qvals, batch["state"][:, 1:]) # Calculate 1-step Q-Learning targets targets = rewards + self.args.gamma * (1 - terminated) * target_max_qvals # Td-error td_error = (chosen_action_qvals - targets.detach()) mask = mask.expand_as(td_error) # 0-out the targets that came from padded data masked_td_error = td_error * mask # Normal L2 loss, take mean over actual data loss = (masked_td_error**2).sum() / mask.sum() if self.args.only_downstream or not self.args.use_IB: expressiveness_loss = th.Tensor([0.]) compactness_loss = th.Tensor([0.]) entropy_loss = th.Tensor([0.]) comm_loss = th.Tensor([0.]) comm_beta = th.Tensor([0.]) comm_entropy_beta = th.Tensor([0.]) else: # ### Optimize message # Message are controlled only by expressiveness and compactness loss. 
# Compute cross entropy with target q values of the same time step expressiveness_loss = 0 label_prob = th.gather(logits_out, 3, label_target_actions).squeeze(3) expressiveness_loss += ( -th.log(label_prob + 1e-6)).sum() / mask.sum() # Compute KL divergence compactness_loss = D.kl_divergence(D.Normal(mu_out, sigma_out), D.Normal(self.s_mu, self.s_sigma)).sum() / \ mask.sum() # Entropy loss entropy_loss = -D.Normal(self.s_mu, self.s_sigma).log_prob( m_sample_out).sum() / mask.sum() # Gate loss gate_loss = 0 # Total loss comm_beta = self.get_comm_beta(t_env) comm_entropy_beta = self.get_comm_entropy_beta(t_env) comm_loss = expressiveness_loss + comm_beta * compactness_loss + comm_entropy_beta * entropy_loss comm_loss *= self.args.c_beta loss += comm_loss comm_beta = th.Tensor([comm_beta]) comm_entropy_beta = th.Tensor([comm_entropy_beta]) # Optimise self.optimiser.zero_grad() loss.backward() grad_norm = th.nn.utils.clip_grad_norm_(self.params, self.args.grad_norm_clip) self.optimiser.step() # Update target if (episode_num - self.last_target_update_episode ) / self.args.target_update_interval >= 1.0: self._update_targets() self.last_target_update_episode = episode_num if t_env - self.log_stats_t >= self.args.learner_log_interval: self.logger.log_stat("loss", loss.item(), t_env) self.logger.log_stat("comm_loss", comm_loss.item(), t_env) self.logger.log_stat("exp_loss", expressiveness_loss.item(), t_env) self.logger.log_stat("comp_loss", compactness_loss.item(), t_env) self.logger.log_stat("comm_beta", comm_beta.item(), t_env) self.logger.log_stat("entropy_loss", entropy_loss.item(), t_env) self.logger.log_stat("comm_beta", comm_beta.item(), t_env) self.logger.log_stat("comm_entropy_beta", comm_entropy_beta.item(), t_env) self.logger.log_stat("grad_norm", grad_norm, t_env) mask_elems = mask.sum().item() self.logger.log_stat( "td_error_abs", (masked_td_error.abs().sum().item() / mask_elems), t_env) self.logger.log_stat("q_taken_mean", (chosen_action_qvals * mask).sum().item() / (mask_elems * self.args.n_agents), t_env) self.logger.log_stat("target_mean", (targets * mask).sum().item() / (mask_elems * self.args.n_agents), t_env) self.log_stats_t = t_env def _update_targets(self): self.target_mac.load_state(self.mac) if self.mixer is not None: self.target_mixer.load_state_dict(self.mixer.state_dict()) # self.logger.console_logger.info("Updated target network") def cuda(self): self.mac.cuda() self.target_mac.cuda() if self.mixer is not None: self.mixer.cuda() self.target_mixer.cuda() self.s_mu = self.s_mu.cuda() self.s_sigma = self.s_sigma.cuda() def save_models(self, path): self.mac.save_models(path) if self.mixer is not None: th.save(self.mixer.state_dict(), "{}/mixer.th".format(path)) th.save(self.optimiser.state_dict(), "{}/opt.th".format(path)) def load_models(self, path): self.mac.load_models(path) # Not quite right but I don't want to save target networks self.target_mac.load_models(path) if self.mixer is not None: self.mixer.load_state_dict( th.load("{}/mixer.th".format(path), map_location=lambda storage, loc: storage)) self.optimiser.load_state_dict( th.load("{}/opt.th".format(path), map_location=lambda storage, loc: storage))
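# A minimal, self-contained sketch (not part of CateQLearner itself) of how the information-bottleneck
# communication loss above decomposes into expressiveness (messages must predict the target network's
# greedy "label" actions), compactness (KL to a unit-Gaussian prior) and an entropy term.
# Tensor shapes, the default beta values and the helper name `ib_comm_loss` are illustrative assumptions.
import torch as th
import torch.distributions as D

def ib_comm_loss(logits, mu, sigma, messages, label_actions, mask_sum,
                 comm_beta=1e-3, comm_entropy_beta=0.0):
    # logits, mu, sigma, messages: (B, T, n_agents, dim); label_actions: (B, T, n_agents, 1)
    # Expressiveness: negative log-likelihood of the label actions under the message decoder.
    label_prob = th.gather(logits, 3, label_actions).squeeze(3)
    expressiveness = (-th.log(label_prob + 1e-6)).sum() / mask_sum
    # Compactness: KL between the message posterior N(mu, sigma) and a standard normal prior.
    prior = D.Normal(th.zeros_like(mu), th.ones_like(sigma))
    compactness = D.kl_divergence(D.Normal(mu, sigma), prior).sum() / mask_sum
    # Entropy term: negative log-density of the sampled messages under the same prior.
    entropy = -prior.log_prob(messages).sum() / mask_sum
    return expressiveness + comm_beta * compactness + comm_entropy_beta * entropy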
class SHAQLearner: def __init__(self, mac, scheme, logger, args): self.args = args if args.name == "shaq": from modules.mixers.shaq import SHAQMixer else: raise Exception("Please give the correct mixer name!") self.mac = mac self.logger = logger self.params = list(mac.parameters()) self.last_target_update_episode = 0 self.last_mixer_update_episode = 0 self.last_sample_coalition_episode = 0 self.mixer = None if args.mixer is not None: self.mixer = SHAQMixer(args) self.params_mixer = list(self.mixer.parameters()) self.optimiser = RMSprop(params=self.params, lr=args.lr, alpha=args.optim_alpha, eps=args.optim_eps) self.optimiser_mixer = RMSprop(params=self.params_mixer, lr=args.alpha_lr, alpha=args.optim_alpha, eps=args.optim_eps) # a little wasteful to deepcopy (e.g. duplicates action selector), but should work for any MAC self.target_mac = copy.deepcopy(mac) self.log_stats_t = -self.args.learner_log_interval - 1 def train(self, batch: EpisodeBatch, t_env: int, episode_num: int): # Get the relevant quantities rewards = batch["reward"][:, :-1] actions = batch["actions"][:, :-1] one_hot_actions = th.nn.functional.one_hot(actions, num_classes=self.args.n_actions) terminated = batch["terminated"][:, :-1].float() mask = batch["filled"][:, :-1].float() mask[:, 1:] = mask[:, 1:] * (1 - terminated[:, :-1]) avail_actions = batch["avail_actions"] # Calculate estimated Q-Values mac_out = [] self.mac.init_hidden(batch.batch_size) for t in range(batch.max_seq_length): agent_outs = self.mac.forward(batch, t=t) mac_out.append(agent_outs) mac_out = th.stack(mac_out, dim=1) # Concat over time # Pick the Q-Values for the actions taken by each agent chosen_action_qvals = th.gather(mac_out[:, :-1], dim=3, index=actions).squeeze(3) # Remove the last dim # generate a filter for selecting the agents with the max-action _mac_out_detach = mac_out.clone().detach() _mac_out_detach[avail_actions == 0] = -9999999 _cur_max_actions = _mac_out_detach[:, :-1].max(dim=3, keepdim=True)[1].squeeze(3) max_filter = (actions.detach().squeeze(3)==_cur_max_actions).float() # Calculate the Q-Values necessary for the target target_mac_out = [] self.target_mac.init_hidden(batch.batch_size) for t in range(batch.max_seq_length): target_agent_outs = self.target_mac.forward(batch, t=t) target_mac_out.append(target_agent_outs) # We don't need the first timesteps Q-Value estimate for calculating targets target_mac_out = th.stack(target_mac_out[1:], dim=1) # Concat across time # Mask out unavailable actions target_mac_out[avail_actions[:, 1:] == 0] = -9999999 # From OG deepmarl # Max over target Q-Values if self.args.double_q: # Get actions that maximise live Q (for double q-learning) mac_out_detach = mac_out.clone().detach() mac_out_detach[avail_actions == 0] = -9999999 cur_max_actions = mac_out_detach[:, 1:].max(dim=3, keepdim=True)[1] target_max_qvals = th.gather(target_mac_out, 3, cur_max_actions).squeeze(3) else: target_max_qvals = target_mac_out.max(dim=3)[0] # Mix if self.mixer is not None: chosen_action_qvals, w_est = self.mixer(batch["state"][:, :-1], one_hot_actions, chosen_action_qvals, max_filter, target=False, manual_alpha_estimates=self.args.manual_alpha_estimates) target_max_qvals = self.mixer(batch["state"][:, 1:], one_hot_actions, target_max_qvals, max_filter, target=True, manual_alpha_estimates=self.args.manual_alpha_estimates) N = getattr(self.args, "n_step", 1) if N == 1: # Calculate 1-step Q-Learning targets targets = rewards + self.args.gamma * (1 - terminated) * target_max_qvals else: # N step Q-Learning targets 
n_rewards = th.zeros_like(rewards) gamma_tensor = th.tensor([self.args.gamma**i for i in range(N)], dtype=th.float, device=n_rewards.device) steps = mask.flip(1).cumsum(dim=1).flip(1).clamp_max(N).long() for i in range(batch.max_seq_length - 1): n_rewards[:,i,0] = ((rewards * mask)[:,i:i+N,0] * gamma_tensor[:(batch.max_seq_length - 1 - i)]).sum(dim=1) indices = th.linspace(0, batch.max_seq_length-2, steps=batch.max_seq_length-1, device=steps.device).unsqueeze(1).long() n_targets_terminated = th.gather(target_max_qvals*(1-terminated),dim=1,index=steps.long()+indices-1) targets = n_rewards + th.pow(self.args.gamma, steps.float()) * n_targets_terminated # Td-error td_error = (chosen_action_qvals - targets.detach()) mask = mask.expand_as(td_error) # 0-out the targets that came from padded data masked_td_error = td_error * mask # Normal L2 loss, take mean over actual data loss = (masked_td_error ** 2).sum() / mask.sum() # Optimise self.optimiser.zero_grad() self.optimiser_mixer.zero_grad() loss.backward() grad_norm = th.nn.utils.clip_grad_norm_(self.params, self.args.grad_norm_clip) grad_norm_mixer = th.nn.utils.clip_grad_norm_(self.params_mixer, self.args.grad_norm_clip) self.optimiser.step() self.optimiser_mixer.step() # Periodically update target Q-values if (episode_num - self.last_target_update_episode) / self.args.target_update_interval >= 1.0: self._update_targets() self.last_target_update_episode = episode_num # Logging if t_env - self.log_stats_t >= self.args.learner_log_interval: self.logger.log_stat("loss", loss.item(), t_env) self.logger.log_stat("grad_norm", grad_norm, t_env) self.logger.log_stat("grad_norm_mixer", grad_norm_mixer, t_env) mask_elems = mask.sum().item() self.logger.log_stat("td_error_abs", (masked_td_error.abs().sum().item()/mask_elems), t_env) self.logger.log_stat("q_taken_mean", (chosen_action_qvals * mask).sum().item()/(mask_elems * self.args.n_agents), t_env) self.logger.log_stat("target_mean", (targets * mask).sum().item()/(mask_elems * self.args.n_agents), t_env) agent_utils = (th.gather(mac_out[:, :-1], dim=3, index=actions).squeeze(3) * mask).sum().item() / (mask_elems * self.args.n_agents) self.logger.log_stat("agent_utils", agent_utils, t_env) self.logger.log_stat("w_est", ( w_est * (1 - max_filter) * mask.expand_as(w_est) ).sum().item() / ( ( (1 - max_filter) * mask.expand_as(w_est) ).sum().item() ), t_env) self.log_stats_t = t_env def _update_targets(self): self.target_mac.load_state(self.mac) self.logger.console_logger.info("Updated target network") def cuda(self): self.mac.cuda() self.target_mac.cuda() if self.mixer is not None: self.mixer.cuda() def save_models(self, path): self.mac.save_models(path) if self.mixer is not None: th.save(self.mixer.state_dict(), "{}/mixer.th".format(path)) th.save(self.optimiser.state_dict(), "{}/opt.th".format(path)) def load_models(self, path): self.mac.load_models(path) # Not quite right but I don't want to save target networks self.target_mac.load_models(path) if self.mixer is not None: self.mixer.load_state_dict(th.load("{}/mixer.th".format(path), map_location=lambda storage, loc: storage)) self.optimiser.load_state_dict(th.load("{}/opt.th".format(path), map_location=lambda storage, loc: storage))
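# A minimal sketch (assumed shapes: rewards and mask of shape (B, T, 1)) of the N-step discounted
# reward sum that the SHAQ learner above builds when n_step > 1. The helper name and the explicit
# Python loop over time are illustrative assumptions, not the learner's own implementation.
import torch as th

def n_step_rewards(rewards, mask, gamma, n):
    T = rewards.shape[1]
    discounts = th.tensor([gamma ** i for i in range(n)],
                          dtype=rewards.dtype, device=rewards.device)
    out = th.zeros_like(rewards)
    for t in range(T):
        horizon = min(n, T - t)  # truncate the lookahead near the end of the episode
        out[:, t, 0] = ((rewards * mask)[:, t:t + horizon, 0] * discounts[:horizon]).sum(dim=1)
    return out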
class SACQLearner: def __init__(self, mac, scheme, logger, args): self.args = args self.mac = mac self.logger = logger self.mac_params = list(mac.parameters()) self.params = list(self.mac.parameters()) self.last_target_update_episode = 0 self.mixer = None assert args.mixer is not None if args.mixer is not None: if args.mixer == "vdn": self.mixer = VDNMixer() elif args.mixer == "qmix": self.mixer = QMixer(args) else: raise ValueError("Mixer {} not recognised.".format(args.mixer)) self.mixer_params = list(self.mixer.parameters()) self.params += list(self.mixer.parameters()) self.target_mixer = copy.deepcopy(self.mixer) # a little wasteful to deepcopy (e.g. duplicates action selector), but should work for any MAC self.target_mac = copy.deepcopy(mac) # Central Q # TODO: Clean this mess up! self.central_mac = None assert self.args.central_mixer == "ff" self.central_mixer = QMixerCentralFF(args) assert args.central_mac == "basic_central_mac" self.central_mac = mac_REGISTRY[args.central_mac]( scheme, args ) # Groups aren't used in the CentralBasicController. Little hacky self.target_central_mac = copy.deepcopy(self.central_mac) self.params += list(self.central_mac.parameters()) self.params += list(self.central_mixer.parameters()) self.target_central_mixer = copy.deepcopy(self.central_mixer) self.optimiser = RMSprop(params=self.params, lr=args.lr, alpha=args.optim_alpha, eps=args.optim_eps) self.log_stats_t = -self.args.learner_log_interval - 1 def train(self, batch: EpisodeBatch, t_env: int, episode_num: int): # Get the relevant quantities rewards = batch["reward"][:, :-1] actions = batch["actions"][:, :-1] terminated = batch["terminated"][:, :-1].float() mask = batch["filled"][:, :-1].float() mask[:, 1:] = mask[:, 1:] * (1 - terminated[:, :-1]) avail_actions = batch["avail_actions"] # Current policies mac_out = [] self.mac.init_hidden(batch.batch_size) for t in range(batch.max_seq_length): agent_outs = self.mac.forward(batch, t=t) mac_out.append(agent_outs) mac_out = th.stack(mac_out, dim=1) # Concat over time # Mask out unavailable actions, renormalise (as in action selection) mac_out[avail_actions == 0] = 0 mac_out = mac_out / mac_out.sum(dim=-1, keepdim=True) mac_out[avail_actions == 0] = 0 mac_out[(mac_out.sum(dim=-1, keepdim=True) == 0).expand_as( mac_out )] = 1 # Set any all 0 probability vectors to all 1s. They will be masked out later, but still need to be sampled. # Target policies target_mac_out = [] self.target_mac.init_hidden(batch.batch_size) for t in range(batch.max_seq_length): target_agent_outs = self.target_mac.forward(batch, t=t) target_mac_out.append(target_agent_outs) target_mac_out = th.stack(target_mac_out, dim=1) # Concat across time # Mask out unavailable actions, renormalise (as in action selection) target_mac_out[avail_actions == 0] = 0 target_mac_out = target_mac_out / target_mac_out.sum(dim=-1, keepdim=True) target_mac_out[avail_actions == 0] = 0 target_mac_out[( target_mac_out.sum(dim=-1, keepdim=True) == 0 ).expand_as( target_mac_out )] = 1 # Set any all 0 probability vectors to all 1s. They will be masked out later, but still need to be sampled. 
# Sample actions sampled_actions = Categorical(mac_out).sample().long() sampled_target_actions = Categorical(target_mac_out).sample().long() # Central MAC stuff central_mac_out = [] self.central_mac.init_hidden(batch.batch_size) for t in range(batch.max_seq_length): agent_outs = self.central_mac.forward(batch, t=t) central_mac_out.append(agent_outs) central_mac_out = th.stack(central_mac_out, dim=1) # Concat over time # Actions chosen from replay buffer central_chosen_action_qvals_agents = th.gather( central_mac_out[:, :-1], dim=3, index=actions.unsqueeze(4).repeat( 1, 1, 1, 1, self.args.central_action_embed)).squeeze( 3) # Remove the last dim central_target_mac_out = [] self.target_central_mac.init_hidden(batch.batch_size) for t in range(batch.max_seq_length): target_agent_outs = self.target_central_mac.forward(batch, t=t) central_target_mac_out.append(target_agent_outs) central_target_mac_out = th.stack(central_target_mac_out[:], dim=1) # Concat across time central_target_action_qvals_agents = th.gather(central_target_mac_out[:,:], 3, \ sampled_target_actions[:,:].unsqueeze(3).unsqueeze(4)\ .repeat(1,1,1,1,self.args.central_action_embed)).squeeze(3) # --- critic_bootstrap_qvals = self.target_central_mixer( central_target_action_qvals_agents[:, 1:], batch["state"][:, 1:]) target_chosen_action_probs = th.gather( target_mac_out, dim=3, index=sampled_target_actions.unsqueeze(3)).squeeze(dim=3) target_policy_logs = th.log(target_chosen_action_probs).sum( dim=2, keepdim=True) # Sum across agents # Calculate 1-step Q-Learning targets targets = rewards + self.args.gamma * (1 - terminated) * \ (critic_bootstrap_qvals - self.args.entropy_temp * target_policy_logs[:,1:]) # Training Critic central_chosen_action_qvals = self.central_mixer( central_chosen_action_qvals_agents, batch["state"][:, :-1]) central_td_error = (central_chosen_action_qvals - targets.detach()) central_mask = mask.expand_as(central_td_error) central_masked_td_error = central_td_error * central_mask central_loss = (central_masked_td_error**2).sum() / mask.sum() # Actor Loss central_sampled_action_qvals_agents = th.gather(central_mac_out[:, :-1], 3, \ sampled_actions[:, :-1].unsqueeze(3).unsqueeze(4) \ .repeat(1, 1, 1, 1, self.args.central_action_embed)).squeeze(3) central_sampled_action_qvals = self.central_mixer( central_sampled_action_qvals_agents, batch["state"][:, :-1]).repeat(1, 1, self.args.n_agents) sampled_action_probs = th.gather( mac_out, dim=3, index=sampled_actions.unsqueeze(3)).squeeze(3) policy_logs = th.log(sampled_action_probs)[:, :-1] actor_mask = mask.expand_as(policy_logs) actor_loss = ( (policy_logs * (self.args.entropy_temp * (policy_logs + 1) - central_sampled_action_qvals).detach()) * actor_mask).sum() / actor_mask.sum() loss = self.args.actor_loss * actor_loss + self.args.central_loss * central_loss # Optimise self.optimiser.zero_grad() loss.backward() grad_norm = th.nn.utils.clip_grad_norm_(self.params, self.args.grad_norm_clip) self.grad_norm = grad_norm self.optimiser.step() if (episode_num - self.last_target_update_episode ) / self.args.target_update_interval >= 1.0: self._update_targets() self.last_target_update_episode = episode_num if t_env - self.log_stats_t >= self.args.learner_log_interval: self.logger.log_stat("loss", loss.item(), t_env) self.logger.log_stat("actor_loss", actor_loss.item(), t_env) self.logger.log_stat("grad_norm", grad_norm, t_env) mask_elems = mask.sum().item() self.logger.log_stat("target_mean", (targets * mask).sum().item() / (mask_elems * self.args.n_agents), t_env) 
self.logger.log_stat("central_loss", central_loss.item(), t_env) ps = mac_out[:, :-1] * avail_actions[:, :-1] log_ps = th.log(mac_out[:, :-1] + 0.00001) * avail_actions[:, :-1] actor_entropy = -(( (ps * log_ps).sum(dim=3) * mask).sum() / mask.sum()) self.logger.log_stat("actor_entropy", actor_entropy.item(), t_env) self.log_stats_t = t_env def _update_targets(self): self.target_mac.load_state(self.mac) if self.mixer is not None: self.target_mixer.load_state_dict(self.mixer.state_dict()) if self.central_mac is not None: self.target_central_mac.load_state(self.central_mac) self.target_central_mixer.load_state_dict( self.central_mixer.state_dict()) self.logger.console_logger.info("Updated target network") def cuda(self): self.mac.cuda() self.target_mac.cuda() if self.mixer is not None: self.mixer.cuda() self.target_mixer.cuda() if self.central_mac is not None: self.central_mac.cuda() self.target_central_mac.cuda() self.central_mixer.cuda() self.target_central_mixer.cuda() def save_models(self, path): self.mac.save_models(path) if self.mixer is not None: th.save(self.mixer.state_dict(), "{}/mixer.th".format(path)) th.save(self.optimiser.state_dict(), "{}/opt.th".format(path)) def load_models(self, path): self.mac.load_models(path) # Not quite right but I don't want to save target networks self.target_mac.load_models(path) if self.mixer is not None: self.mixer.load_state_dict( th.load("{}/mixer.th".format(path), map_location=lambda storage, loc: storage)) self.optimiser.load_state_dict( th.load("{}/opt.th".format(path), map_location=lambda storage, loc: storage))
class QLearner: def __init__(self, mac, scheme, logger, args): self.args = args self.mac = mac self.logger = logger self.params = list(mac.parameters()) self.last_target_update_episode = 0 self.mixer = None if args.mixer is not None: if args.mixer == "vdn": self.mixer = VDNMixer() elif args.mixer == "qmix": self.mixer = QMixer(args) elif args.mixer == "graphmix": self.mixer = GraphMixer(args) else: raise ValueError("Mixer {} not recognised.".format(args.mixer)) self.params += list(self.mixer.parameters()) self.target_mixer = copy.deepcopy(self.mixer) self.optimiser = RMSprop(params=self.params, lr=args.lr, alpha=args.optim_alpha, eps=args.optim_eps) self.target_mac = copy.deepcopy(mac) self.log_stats_t = -self.args.learner_log_interval - 1 def train(self, batch: EpisodeBatch, t_env: int, episode_num: int): # Get the relevant quantities rewards = batch["reward"][:, :-1] actions = batch["actions"][:, :-1] terminated = batch["terminated"][:, :-1].float() mask = batch["filled"][:, :-1].float() mask[:, 1:] = mask[:, 1:] * (1 - terminated[:, :-1]) avail_actions = batch["avail_actions"] # Calculate each agent's individual Q-values mac_out = [] hidden_states = [] self.mac.init_hidden(batch.batch_size) for t in range(batch.max_seq_length): agent_outs = self.mac.forward(batch, t=t) mac_out.append(agent_outs) hidden_states.append(self.mac.hidden_states.view(batch.batch_size, self.args.n_agents, -1)) mac_out = th.stack(mac_out, dim=1) # Concat over time hidden_states = th.stack(hidden_states, dim=1) # Pick the Q-values for the actions taken by each agent chosen_action_qvals = th.gather(mac_out[:, :-1], dim=3, index=actions).squeeze(3) # Remove the last dim # Calculate the individual Q-values from the agent target network target_mac_out = [] target_hidden_states = [] self.target_mac.init_hidden(batch.batch_size) for t in range(batch.max_seq_length): target_agent_outs = self.target_mac.forward(batch, t=t) target_mac_out.append(target_agent_outs) target_hidden_states.append(self.target_mac.hidden_states.view(batch.batch_size, self.args.n_agents, -1)) # The target network is evaluated on the next state, so store outputs from t=1 onwards target_mac_out = th.stack(target_mac_out[1:], dim=1) # Concat across time target_hidden_states = th.stack(target_hidden_states[1:], dim=1) # Mask out unavailable actions target_mac_out[avail_actions[:, 1:] == 0] = -9999999 if self.args.double_q: # Get actions that maximise live Q (for double q-learning) # Keep a detached copy of the agents' Q-values mac_out_detach = mac_out.clone().detach() mac_out_detach[avail_actions == 0] = -9999999 # Take the actions with the maximum Q-value at the next state cur_max_actions = mac_out_detach[:, 1:].max(dim=3, keepdim=True)[1] # Evaluate these actions with the target network's Q-values
# This can differ from the actual maximum of target_mac_out # In effect it fetches the target Q-value of the action the live agent would pick target_max_qvals = th.gather(target_mac_out, 3, cur_max_actions).squeeze(3) else: target_max_qvals = target_mac_out.max(dim=3)[0] if self.args.mixer == 'graphmix': # Mix chosen_action_qvals_peragent = chosen_action_qvals.clone() target_max_qvals_peragent = target_max_qvals.detach() Q_total, local_rewards, alive_agents_mask = self.mixer(chosen_action_qvals, batch["state"][:, :-1], agent_obs=batch["obs"][:, :-1], team_rewards=rewards, hidden_states=hidden_states[:, :-1] ) target_Q_total = self.target_mixer(target_max_qvals, batch["state"][:, 1:], agent_obs=batch["obs"][:, 1:], hidden_states=target_hidden_states )[0] ## Global loss # Calculate 1-step Q-Learning targets targets = rewards + self.args.gamma * (1 - terminated) * target_Q_total # Td-error td_error = (Q_total - targets.detach()) mask = mask.expand_as(td_error) # 0-out the targets that came from padded data masked_td_error = td_error * mask # Normal L2 loss, take mean over actual data global_loss = (masked_td_error ** 2).sum() / mask.sum() ## Local losses # Calculate 1-step Q-Learning targets local_targets = local_rewards + self.args.gamma * (1 - terminated).repeat(1, 1, self.args.n_agents) \ * target_max_qvals_peragent # Td-error local_td_error = (chosen_action_qvals_peragent - local_targets) local_mask = mask.repeat(1, 1, self.args.n_agents) * alive_agents_mask.float() # 0-out the targets that came from padded data local_masked_td_error = local_td_error * local_mask # Normal L2 loss, take mean over actual data local_loss = (local_masked_td_error ** 2).sum() / mask.sum() # total loss lambda_local = self.args.lambda_local loss = global_loss + lambda_local * local_loss else: # Mix if self.mixer is not None: # Pass the agents' chosen Q-values and the state to compute Q_total and target_Q_total Q_total = self.mixer(chosen_action_qvals, batch["state"][:, :-1]) target_Q_total = self.target_mixer(target_max_qvals, batch["state"][:, 1:]) else: Q_total = chosen_action_qvals target_Q_total = target_max_qvals # Perform 1-step Q-learning targets = rewards + self.args.gamma * (1 - terminated) * target_Q_total # Td-error td_error = (Q_total - targets.detach()) mask = mask.expand_as(td_error) # 0-out the targets that came from padded data masked_td_error = td_error * mask # Normal L2 loss, take mean over actual data loss = (masked_td_error ** 2).sum() / mask.sum() # Optimise self.optimiser.zero_grad() loss.backward() grad_norm = th.nn.utils.clip_grad_norm_(self.params, self.args.grad_norm_clip) self.optimiser.step() if (episode_num - self.last_target_update_episode) / self.args.target_update_interval >= 1.0: self._update_targets() self.last_target_update_episode = episode_num if t_env - self.log_stats_t >= self.args.learner_log_interval: self.logger.log_stat("loss", loss.item(), t_env) if self.args.mixer == 'graphmix': self.logger.log_stat("global_loss", global_loss.item(), t_env) self.logger.log_stat("local_loss", local_loss.item(), t_env) self.logger.log_stat("grad_norm", grad_norm, t_env) mask_elems = mask.sum().item() self.logger.log_stat("td_error_abs", (masked_td_error.abs().sum().item() / mask_elems), t_env) self.logger.log_stat("q_taken_mean", (Q_total * mask).sum().item() / (mask_elems * self.args.n_agents), t_env) self.logger.log_stat("target_mean", (targets * mask).sum().item() / (mask_elems * self.args.n_agents), t_env) self.log_stats_t = t_env def _update_targets(self): self.target_mac.load_state(self.mac) if self.mixer is not None: self.target_mixer.load_state_dict(self.mixer.state_dict())
self.logger.console_logger.info("Updated target network") def cuda(self): self.mac.cuda() self.target_mac.cuda() if self.mixer is not None: self.mixer.cuda() self.target_mixer.cuda() def save_models(self, path): self.mac.save_models(path) if self.mixer is not None: th.save(self.mixer.state_dict(), "{}/mixer.th".format(path)) th.save(self.optimiser.state_dict(), "{}/opt.th".format(path)) def load_models(self, path): self.mac.load_models(path) self.target_mac.load_models(path) if self.mixer is not None: self.mixer.load_state_dict(th.load("{}/mixer.th".format(path), map_location=lambda storage, loc: storage)) self.optimiser.load_state_dict(th.load("{}/opt.th".format(path), map_location=lambda storage, loc: storage))
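# A minimal sketch of the double Q-learning target selection used by the learner above:
# the live (online) network picks the argmax action over available actions, and the target
# network evaluates it. Shapes (B, T, n_agents, n_actions) and the helper name are assumptions.
import torch as th

def double_q_targets(mac_out, target_mac_out, avail_actions):
    # mac_out, avail_actions: (B, T, n_agents, n_actions); target_mac_out: (B, T-1, n_agents, n_actions)
    online = mac_out.clone().detach()
    online[avail_actions == 0] = -9999999                          # never select unavailable actions
    cur_max_actions = online[:, 1:].max(dim=3, keepdim=True)[1]    # argmax under the live network
    return th.gather(target_mac_out, 3, cur_max_actions).squeeze(3)  # evaluate with the target network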
class QDPPQLearner: def __init__(self, mac, scheme, logger, args): self.args = args self.mac = mac self.logger = logger self.params = list(mac.parameters()) self.last_target_update_episode = 0 self.mixer = None if args.mixer is not None: if args.mixer == "qdpp": self.mixer = QDPPMixer(args) else: raise ValueError("Mixer {} not recognised.".format(args.mixer)) self.params += list(self.mixer.parameters()) self.target_mixer = copy.deepcopy(self.mixer) self.mac.mixer = self.mixer if getattr(args, "all_obs", None) is not None: shape = self.args.all_obs.shape if self.args.device == 'cuda': self.logger_batch = { 'obs': th.from_numpy( self.args.all_obs.reshape(shape[0], 1, shape[1], shape[2])).float().cuda(), 'avail_actions': th.ones(shape[0], 1, shape[1], self.args.n_actions).float().cuda() } else: self.logger_batch = { 'obs': th.from_numpy( self.args.all_obs.reshape(shape[0], 1, shape[1], shape[2])).float(), 'avail_actions': th.ones(shape[0], 1, shape[1], self.args.n_actions).float() } self.optimiser = RMSprop(params=self.params, lr=args.lr, alpha=args.optim_alpha, eps=args.optim_eps, weight_decay=args.weight_decay) # a little wasteful to deepcopy (e.g. duplicates action selector), but should work for any MAC self.target_mac = copy.deepcopy(mac) self.log_stats_t = -self.args.learner_log_interval - 1 def train(self, batch: EpisodeBatch, t_env: int, episode_num: int): # Get the relevant quantities rewards = batch["reward"][:, :-1] actions = batch["actions"][:, :-1] terminated = batch["terminated"][:, :-1].float() mask = batch["filled"][:, :-1].float() mask[:, 1:] = mask[:, 1:] * (1 - terminated[:, :-1]) avail_actions = batch["avail_actions"] # Calculate estimated Q-Values time_stamp = time.time() mac_out = [] self.mac.init_hidden(batch.batch_size) for t in range(batch.max_seq_length): agent_outs = self.mac.forward(batch, t=t) mac_out.append(agent_outs) mac_out = th.stack(mac_out, dim=1) # Concat over time # Pick the Q-Values for the actions taken by each agent chosen_action_ind_qvals = th.gather(mac_out[:, :-1], dim=3, index=actions).squeeze( 3) # Remove the last dim chosen_action_ind_qvals = th.clamp(chosen_action_ind_qvals, self.args.q_min, self.args.q_max) # Calculate the Q-Values necessary for the target target_mac_out = [] self.target_mac.init_hidden(batch.batch_size) for t in range(batch.max_seq_length): target_agent_outs = self.target_mac.forward(batch, t=t) target_mac_out.append(target_agent_outs) # We don't need the first timesteps Q-Value estimate for calculating targets target_mac_out = th.stack(target_mac_out[1:], dim=1) # Concat across time target_mac_out = th.clamp(target_mac_out, self.args.q_min, self.args.q_max) # Mask out unavailable actions target_mac_out[avail_actions[:, 1:] == 0] = -9999999 # Max over target Q-Values temperature = self.mac.schedule.eval(t_env) / 2. 
if self.args.double_q: # Get actions that maximise live Q (for double q-learning) mac_out_detach = mac_out.clone().detach() mac_out_detach[avail_actions == 0] = -9999999 cur_max_actions = project_sample(batch["state"].int(), mac_out_detach, self.mixer, temperature=temperature, avail_actions=avail_actions, greedy=True)[:, 1:] target_max_ind_qvals = th.gather(target_mac_out, 3, cur_max_actions).squeeze(3) else: target_states = batch["state"][:, 1:].int() cur_max_actions = project_sample(target_states, target_mac_out, self.target_mixer, temperature=temperature, avail_actions=avail_actions, greedy=True) target_max_ind_qvals = th.gather(target_mac_out, 3, cur_max_actions).squeeze(3) # Mix if self.mixer is not None: chosen_action_qvals = self.mixer(chosen_action_ind_qvals, batch["state"][:, :-1], actions, t_env) # target_max_qvals = self.target_mixer(target_max_ind_qvals, batch["state"][:, 1:], cur_max_actions, t_env) # chosen_action_qvals = th.clamp(chosen_action_qvals, self.args.q_min, self.args.q_max) target_max_qvals = th.clamp(target_max_qvals, self.args.q_min, self.args.q_max) targets = rewards + self.args.gamma * (1 - terminated) * target_max_qvals # Td-error td_error = (chosen_action_qvals - targets.detach()) mask = mask.expand_as(td_error) # 0-out the targets that came from padded data masked_td_error = td_error * mask # Normal L2 loss, take mean over actual data loss = (masked_td_error**2).sum() / mask.sum() if (not getattr(self.args, 'continuous_state', None) or not self.args.continuous_state) \ and self.args.beta_balance: _, sv, _ = th.svd(self.mixer.B.weight) sv_p = [] partition_l = self.args.state_num * self.args.n_actions beta = np.sqrt(self.args.n_agents) for i in range(self.args.n_agents): B_i = self.mixer.B.weight[i * partition_l:(i + 1) * partition_l] _, sv_i, _ = th.svd(B_i) sv_p.append(sv_i) sv_p = th.stack(sv_p, dim=0) sv = sv.repeat(self.args.n_agents).view(self.args.n_agents, -1) raw_beta_balance = (sv / beta - sv_p) beta_balance_mask = raw_beta_balance > 0. 
beta_balance = (raw_beta_balance * beta_balance_mask.float()).sum() loss += self.args.beta_balance_rate * beta_balance # Optimize self.optimiser.zero_grad() loss.backward() grad_norm = th.nn.utils.clip_grad_norm_(self.params, self.args.grad_norm_clip) # grad_norm = 0 # TEST self.optimiser.step() if (episode_num - self.last_target_update_episode ) / self.args.target_update_interval >= 1.0: self._update_targets() self.last_target_update_episode = episode_num if t_env - self.log_stats_t >= self.args.learner_log_interval: self.logger.log_stat("loss", loss.item(), t_env) self.logger.log_stat("grad_norm", grad_norm, t_env) self.logger.log_stat("temperature", temperature, t_env) mask_elems = mask.sum().item() self.logger.log_stat( "td_error_abs", (masked_td_error.abs().sum().item() / mask_elems), t_env) self.logger.log_stat("q_taken_mean", (chosen_action_qvals * mask).sum().item() / (mask_elems * self.args.n_agents), t_env) self.logger.log_stat("target_mean", (targets * mask).sum().item() / (mask_elems * self.args.n_agents), t_env) self.log_stats_t = t_env if self.mixer is not None: BBT_grad = self.mixer.BBT.grad q_grad = self.mixer.q.grad det_grad_raw = ( BBT_grad / (q_grad.reshape(-1, 1, 1) + (1 - mask).reshape(-1, 1, 1) ) # avoid nan in masked places ).reshape(-1, self.args.n_agents * self.args.n_agents) * mask.reshape(-1, 1) det_grad_norm = det_grad_raw.norm(dim=-1).sum() / th.sum( mask == 1) BBT_grad_norm = (BBT_grad.reshape(-1, self.args.n_agents * self.args.n_agents) * mask.reshape(-1, 1))\ .norm(dim=-1).sum() / th.sum(mask == 1) q_grad_norm = ( (q_grad * mask.reshape(-1, )).abs()).sum() / th.sum(mask == 1) det_norm = ( (self.mixer.det_BBT * mask.reshape(-1, )).abs()).sum() / th.sum(mask == 1) self.logger.log_stat("det_grad_norm", float(det_grad_norm), t_env) self.logger.log_stat("det_norm", float(det_norm), t_env) self.logger.log_stat("BBT_grad_norm", float(BBT_grad_norm), t_env) self.logger.log_stat("q_grad_norm", float(q_grad_norm), t_env) self.logger.log_stat( "logdet", float((self.mixer.q - self.mixer.q_sum).mean()), t_env) self.logger.log_stat("q_sum", float(self.mixer.q_sum.mean()), t_env) if getattr(self.args, "all_obs", None) is not None and self.args.log_all_obs: qvals = self.mac.forward(self.logger_batch, t=0) if self.args.device == 'cuda': self.logger.log_stat('qvals', qvals.detach().cpu().numpy(), t_env) else: self.logger.log_stat('qvals', qvals.detach().numpy(), t_env) if self.args.device == 'cuda': self.logger.log_stat('B', self.mixer.B.weight.data.cpu().numpy(), t_env) else: bb = jointq_ind_q = chosen_action_qvals - chosen_action_ind_qvals.sum( dim=-1).unsqueeze(2) target_jointq_ind_q = target_max_qvals - target_max_ind_qvals.sum( dim=-1).unsqueeze(2) chosen_action_ind_qvals_sum = chosen_action_ind_qvals.sum( dim=-1).unsqueeze(2) self.logger.log_stat( 'chosen_action_ind_qvals_sum', chosen_action_ind_qvals_sum.mean().detach().numpy().astype( np.float64).item(), t_env) bb_q = bb / chosen_action_ind_qvals_sum self.logger.log_stat( 'bb_q_mean', bb_q.mean().detach().numpy().astype(np.float64).item(), t_env) self.logger.log_stat( 'bb_q_max', bb_q.max().detach().numpy().astype(np.float64).item(), t_env) self.logger.log_stat( 'bb_q_min', bb_q.max().detach().numpy().astype(np.float64).item(), t_env) self.logger.log_stat( 'jointq-ind_q_mean', jointq_ind_q.mean().detach().numpy().astype( np.float64).item(), t_env) self.logger.log_stat( 'jointq-ind_q_max', jointq_ind_q.max().detach().numpy().astype( np.float64).item(), t_env) self.logger.log_stat( 'target_jointq-ind_q_max', 
target_jointq_ind_q.max().detach().numpy().astype( np.float64).item(), t_env) self.logger.log_stat( 'target_jointq-ind_q_mean', target_jointq_ind_q.mean().detach().numpy().astype( np.float64).item(), t_env) self.logger.log_stat( 'jointq-ind_q_min', jointq_ind_q.min().detach().numpy().astype( np.float64).item(), t_env) self.logger.log_stat( 'target_jointq-ind_q_min', target_jointq_ind_q.min().detach().numpy().astype( np.float64).item(), t_env) self.logger.log_stat('target_q_taken_mean', (target_max_qvals * mask).sum().item() / (mask_elems * self.args.n_agents), t_env) self.logger.log_stat( 'target_ind_qvals_max', target_jointq_ind_q.detach().max().numpy().astype( np.float64).item(), t_env) self.logger.log_stat( 'target_ind_qvals_min', target_jointq_ind_q.detach().min().numpy().astype( np.float64).item(), t_env) self.logger.log_stat( 'target_ind_qvals_mean', target_jointq_ind_q.detach().mean().numpy().astype( np.float64).item(), t_env) self.logger.log_stat( 'chosen_action_ind_qvals', chosen_action_ind_qvals.detach().mean().numpy().astype( np.float64).item(), t_env) self.logger.log_stat('B', self.mixer.B.weight.data.numpy(), t_env) if self.args.all_obs is not None: self.logger.log_stat('states', self.args.all_obs, t_env) def _update_targets(self): self.target_mac.load_state(self.mac) if self.mixer is not None: self.target_mixer.load_state_dict(self.mixer.state_dict()) self.logger.console_logger.info("Updated target network") def cuda(self): self.mac.cuda() self.target_mac.cuda() if self.mixer is not None: self.mixer.cuda() self.target_mixer.cuda() def save_models(self, path): self.mac.save_models(path) if self.mixer is not None: th.save(self.mixer.state_dict(), "{}/mixer.th".format(path)) th.save(self.optimiser.state_dict(), "{}/opt.th".format(path)) def load_models(self, path): self.mac.load_models(path) self.target_mac.load_models(path) if self.mixer is not None: self.mixer.load_state_dict( th.load("{}/mixer.th".format(path), map_location=lambda storage, loc: storage)) self.optimiser.load_state_dict( th.load("{}/opt.th".format(path), map_location=lambda storage, loc: storage))
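# A minimal sketch of the masked TD objective shared by the value-based learners above: the squared
# TD error is averaged only over filled (non-padded) timesteps, with gradients stopped through the
# bootstrapped target. Shapes (B, T, 1) and the helper name are assumptions for illustration.
import torch as th

def masked_td_loss(chosen_qvals, targets, mask):
    td_error = chosen_qvals - targets.detach()   # no gradient through the target
    mask = mask.expand_as(td_error)
    masked_td_error = td_error * mask            # zero out padded timesteps
    return (masked_td_error ** 2).sum() / mask.sum()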
class OffPGLearner: def __init__(self, mac, scheme, logger, args): self.args = args self.n_agents = args.n_agents self.n_actions = args.n_actions self.mac = mac self.logger = logger self.last_target_update_step = 0 self.critic_training_steps = 0 self.log_stats_t = -self.args.learner_log_interval - 1 self.critic = OffPGCritic(scheme, args) self.double_critic = OffPGCritic(scheme, args) self.target_critic = copy.deepcopy(self.critic) self.double_target_critic = copy.deepcopy(self.double_critic) self.agent_params = list(mac.parameters()) self.critic_params = list(self.critic.parameters()) self.double_critic_params = list(self.double_critic.parameters()) self.params = self.agent_params + self.critic_params + self.double_critic_params self.agent_optimiser = RMSprop(params=self.agent_params, lr=args.lr, alpha=args.optim_alpha, eps=args.optim_eps) self.critic_optimiser = RMSprop(params=self.critic_params, lr=args.critic_lr, alpha=args.optim_alpha, eps=args.optim_eps) self.double_critic_optimiser = RMSprop( params=self.double_critic_params, lr=args.critic_lr, alpha=args.optim_alpha, eps=args.optim_eps) def train(self, batch: EpisodeBatch, t_env: int, log): # Get the relevant quantities bs = batch.batch_size max_t = batch.max_seq_length actions = batch["actions"][:, :-1] terminated = batch["terminated"][:, :-1].float() avail_actions = batch["avail_actions"][:, :-1] mask = batch["filled"][:, :-1].float() mask[:, 1:] = mask[:, 1:] * (1 - terminated[:, :-1]) mask = mask.repeat(1, 1, self.n_agents).view(-1) #build q inputs = self.critic._build_inputs(batch, bs, max_t) q_vals = self.critic.forward(inputs).detach()[:, :-1] mac_out = [] self.mac.init_hidden(batch.batch_size) for t in range(batch.max_seq_length - 1): agent_outs, _ = self.mac.forward(batch, t=t) mac_out.append(agent_outs) mac_out = th.stack(mac_out, dim=1) # Concat over time # Mask out unavailable actions, renormalise (as in action selection) mac_out[avail_actions == 0] = 0 mac_out = mac_out / mac_out.sum(dim=-1, keepdim=True) mac_out[avail_actions == 0] = 0 # Calculated baseline q_vals = q_vals.reshape(-1, self.n_actions) pi = mac_out.view(-1, self.n_actions) baseline = (pi * q_vals).sum(-1).detach() # Calculate policy grad with mask q_taken = th.gather(q_vals, dim=1, index=actions.reshape(-1, 1)).squeeze(1) pi_taken = th.gather(pi, dim=1, index=actions.reshape(-1, 1)).squeeze(1) pi_taken[mask == 0] = 1.0 log_pi_taken = th.log(pi_taken) advantages = (q_taken - baseline).detach() coma_loss = -((advantages * log_pi_taken) * mask).sum() / mask.sum() # Optimise agents self.agent_optimiser.zero_grad() coma_loss.backward() grad_norm = th.nn.utils.clip_grad_norm_(self.agent_params, self.args.grad_norm_clip) self.agent_optimiser.step() #compute parameters sum for debugging p_sum = 0. 
for p in self.agent_params: p_sum += p.data.abs().sum().item() / 100.0 if t_env - self.log_stats_t >= self.args.learner_log_interval: ts_logged = len(log["critic_loss"]) for key in [ "critic_loss", "critic_grad_norm", "td_error_abs", "q_taken_mean", "target_mean" ]: self.logger.log_stat(key, sum(log[key]) / ts_logged, t_env) self.logger.log_stat("advantage_mean", (advantages * mask).sum().item() / mask.sum().item(), t_env) self.logger.log_stat("coma_loss", coma_loss.item(), t_env) self.logger.log_stat("agent_grad_norm", grad_norm, t_env) self.logger.log_stat("pi_max", (pi.max(dim=1)[0] * mask).sum().item() / mask.sum().item(), t_env) self.logger.log_stat("alpha", self.mac.agent.comm_fact, t_env) self.logger.log_stat("agent_parameter", p_sum, t_env) self.log_stats_t = t_env def train_critic(self, on_batch, best_batch=None, log=None): bs = on_batch.batch_size max_t = on_batch.max_seq_length rewards = on_batch["reward"][:, :-1] actions = on_batch["actions"][:, :] terminated = on_batch["terminated"][:, :-1].float() mask = on_batch["filled"][:, :-1].float() mask[:, 1:] = mask[:, 1:] * (1 - terminated[:, :-1]) #build_target_q for critic and double_critic target_inputs = self.target_critic._build_inputs(on_batch, bs, max_t) target_q_vals = self.target_critic.forward(target_inputs).detach() double_target_q_vals = self.double_target_critic.forward( target_inputs).detach() targets_taken = th.mean(th.gather(target_q_vals, dim=3, index=actions).squeeze(3), dim=2, keepdim=True) double_targets_taken = th.mean(th.gather(double_target_q_vals, dim=3, index=actions).squeeze(3), dim=2, keepdim=True) target_q = build_td_lambda_targets( rewards, terminated, mask, targets_taken, self.n_agents, self.args.gamma, self.args.td_lambda).repeat(1, 1, self.n_agents) double_target_q = build_td_lambda_targets( rewards, terminated, mask, double_targets_taken, self.n_agents, self.args.gamma, self.args.td_lambda).repeat(1, 1, self.n_agents) inputs = self.critic._build_inputs(on_batch, bs, max_t) if best_batch is not None: best_target_q, best_double_target_q, best_inputs, best_mask, best_actions = self.train_critic_best( best_batch) target_q = th.cat((target_q, best_target_q), dim=0) double_target_q = th.cat((double_target_q, best_double_target_q), dim=0) inputs = th.cat((inputs, best_inputs), dim=0) mask = th.cat((mask, best_mask), dim=0) actions = th.cat((actions, best_actions), dim=0) mask = mask.repeat(1, 1, self.n_agents) #train critic for t in range(max_t - 1): mask_t = mask[:, t:t + 1] if mask_t.sum() < 0.5: continue q_vals = self.critic.forward(inputs[:, t:t + 1]) double_q_vals = self.double_critic.forward(inputs[:, t:t + 1]) q_vals = th.gather(q_vals, 3, index=actions[:, t:t + 1]).squeeze(3) double_q_vals = th.gather(double_q_vals, 3, index=actions[:, t:t + 1]).squeeze(3) target_q_t = target_q[:, t:t + 1] double_target_q_t = double_target_q[:, t:t + 1] # gradient for both critics q_err = (q_vals - double_target_q_t) * mask_t critic_loss = (q_err**2).sum() / mask_t.sum() self.critic_optimiser.zero_grad() critic_loss.backward() grad_norm = th.nn.utils.clip_grad_norm_(self.critic_params, self.args.grad_norm_clip) self.critic_optimiser.step() log["critic_loss"].append(critic_loss.item()) log["critic_grad_norm"].append(grad_norm) mask_elems = mask_t.sum().item() log["td_error_abs"].append((q_err.abs().sum().item() / mask_elems)) log["target_mean"].append( (target_q_t * mask_t).sum().item() / mask_elems) log["q_taken_mean"].append( (q_vals * mask_t).sum().item() / mask_elems) double_q_err = (double_q_vals - target_q_t) * 
mask_t double_critic_loss = (double_q_err**2).sum() / mask_t.sum() self.double_critic_optimiser.zero_grad() double_critic_loss.backward() grad_norm = th.nn.utils.clip_grad_norm_(self.double_critic_params, self.args.grad_norm_clip) self.double_critic_optimiser.step() log["critic_loss"].append(double_critic_loss.item()) log["critic_grad_norm"].append(grad_norm) mask_elems = mask_t.sum().item() log["td_error_abs"].append( (double_q_err.abs().sum().item() / mask_elems)) log["target_mean"].append( (double_target_q_t * mask_t).sum().item() / mask_elems) log["q_taken_mean"].append( (double_q_vals * mask_t).sum().item() / mask_elems) self.critic_training_steps += 1 #update target network if (self.critic_training_steps - self.last_target_update_step ) / self.args.target_update_interval >= 1.0: self._update_targets() self.last_target_update_step = self.critic_training_steps def train_critic_best(self, batch): bs = batch.batch_size max_t = batch.max_seq_length rewards = batch["reward"][:, :-1] actions = batch["actions"][:, :] terminated = batch["terminated"][:, :-1].float() mask = batch["filled"][:, :-1].float() mask[:, 1:] = mask[:, 1:] * (1 - terminated[:, :-1]) avail_actions = batch["avail_actions"][:] # pr for all actions of the episode mac_out = [] self.mac.init_hidden(bs) for i in range(max_t): agent_outs, _ = self.mac.forward(batch, t=i) mac_out.append(agent_outs) mac_out = th.stack(mac_out, dim=1).detach() # Mask out unavailable actions, renormalise (as in action selection) mac_out[avail_actions == 0] = 0 mac_out = mac_out / mac_out.sum(dim=-1, keepdim=True) mac_out[avail_actions == 0] = 0 critic_mac = th.gather(mac_out, 3, actions).squeeze(3).prod(dim=2, keepdim=True) #target_q take target_inputs = self.target_critic._build_inputs(batch, bs, max_t) target_q_vals = self.target_critic.forward(target_inputs).detach() targets_taken = th.mean(th.gather(target_q_vals, dim=3, index=actions).squeeze(3), dim=2, keepdim=True) double_target_q_vals = self.double_target_critic.forward( target_inputs).detach() double_targets_taken = th.mean(th.gather(double_target_q_vals, dim=3, index=actions).squeeze(3), dim=2, keepdim=True) #expected q exp_q, double_exp_q = self.build_exp_q(batch, mac_out, bs, max_t) # td-error targets_taken[:, -1] = targets_taken[:, -1] * (1 - th.sum(terminated, dim=1)) exp_q[:, -1] = exp_q[:, -1] * (1 - th.sum(terminated, dim=1)) targets_taken[:, :-1] = targets_taken[:, :-1] * mask exp_q[:, :-1] = exp_q[:, :-1] * mask td_q = (rewards + self.args.gamma * exp_q[:, 1:] - targets_taken[:, :-1]) * mask double_targets_taken[:, -1] = double_targets_taken[:, -1] * ( 1 - th.sum(terminated, dim=1)) double_exp_q[:, -1] = double_exp_q[:, -1] * (1 - th.sum(terminated, dim=1)) double_targets_taken[:, :-1] = double_targets_taken[:, :-1] * mask double_exp_q[:, :-1] = double_exp_q[:, :-1] * mask double_td_q = (rewards + self.args.gamma * double_exp_q[:, 1:] - double_targets_taken[:, :-1]) * mask target_q = build_target_q(td_q, targets_taken[:, :-1], critic_mac, mask, self.args.gamma, self.args.td_lambda, self.args.step).detach().repeat( 1, 1, self.n_agents) double_target_q = build_target_q( double_td_q, double_targets_taken[:, :-1], critic_mac, mask, self.args.gamma, self.args.td_lambda, self.args.step).detach().repeat(1, 1, self.n_agents) return target_q, double_target_q, target_inputs, mask, actions def build_exp_q(self, batch, mac_out, bs, max_t): # inputs for target net inputs = [] # state, obs, action inputs.append(batch["state"][:].unsqueeze(2).unsqueeze(2).repeat( 1, 1, self.args.n_sum, 
self.n_agents, 1)) inputs.append(batch["obs"][:].unsqueeze(2).repeat( 1, 1, self.args.n_sum, 1, 1)) # Sample n_sum possible joint actions and use importance sampling ac_sampler = Categorical( mac_out.unsqueeze(2).repeat(1, 1, self.args.n_sum, 1, 1) + 1e-10) actions = ac_sampler.sample().long().unsqueeze(4) action_one_hot = mac_out.new_zeros(bs, max_t, self.args.n_sum, self.n_agents, self.n_actions) action_one_hot = action_one_hot.scatter_(-1, actions, 1.0).view( bs, max_t, self.args.n_sum, 1, -1).repeat(1, 1, 1, self.n_agents, 1) agent_mask = (1 - th.eye(self.n_agents, device=batch.device)) agent_mask = agent_mask.view(-1, 1).repeat(1, self.n_actions).view( self.n_agents, -1) inputs.append(action_one_hot * (agent_mask.unsqueeze(0).unsqueeze(0).unsqueeze(0))) # obs last action l_actions = batch["actions_onehot"][:].view(bs, max_t, 1, -1) if self.args.obs_last_action: last_action = [] last_action.append(l_actions[:, 0:1]) last_action.append(l_actions[:, :-1]) last_action = th.cat([x for x in last_action], dim=1) inputs.append( last_action.unsqueeze(2).repeat(1, 1, self.args.n_sum, self.n_agents, 1)) # agent id inputs.append( th.eye(self.n_agents, device=batch.device).unsqueeze(0).unsqueeze( 0).unsqueeze(0).expand(bs, max_t, self.args.n_sum, -1, -1)) inputs = th.cat([x for x in inputs], dim=-1) # E(V(s)) target_exp_q_vals = self.target_critic.forward(inputs).detach() target_exp_q_vals = th.gather(target_exp_q_vals, 4, actions).squeeze(-1).mean(dim=3) double_target_exp_q_vals = self.double_target_critic.forward( inputs).detach() double_target_exp_q_vals = th.gather(double_target_exp_q_vals, 4, actions).squeeze(-1).mean(dim=3) action_mac = mac_out.unsqueeze(2).repeat(1, 1, self.args.n_sum, 1, 1) action_mac = th.gather(action_mac, 4, actions).squeeze(-1) action_mac = th.prod(action_mac, 3) target_exp_q_vals = th.sum( target_exp_q_vals * action_mac, dim=2, keepdim=True) / (th.sum(action_mac, dim=2, keepdim=True) + 1e-10) double_target_exp_q_vals = th.sum( double_target_exp_q_vals * action_mac, dim=2, keepdim=True) / (th.sum(action_mac, dim=2, keepdim=True) + 1e-10) # target_exp_q_vals = th.sum(target_exp_q_vals, dim=2, keepdim=True) / self.args.n_sum return target_exp_q_vals, double_target_exp_q_vals def _update_targets(self): self.target_critic.load_state_dict(self.critic.state_dict()) self.double_target_critic.load_state_dict(self.double_critic.state_dict()) self.logger.console_logger.info("Updated target network") def cuda(self): self.mac.cuda() self.critic.cuda() self.target_critic.cuda() self.double_critic.cuda() self.double_target_critic.cuda() def save_models(self, path): self.mac.save_models(path) th.save(self.critic.state_dict(), "{}/critic.th".format(path)) th.save(self.double_critic.state_dict(), "{}/double_critic.th".format(path)) th.save(self.agent_optimiser.state_dict(), "{}/agent_opt.th".format(path)) th.save(self.critic_optimiser.state_dict(), "{}/critic_opt.th".format(path)) th.save(self.double_critic_optimiser.state_dict(), "{}/double_critic_opt.th".format(path)) def load_models(self, path): self.mac.load_models(path) self.critic.load_state_dict( th.load("{}/critic.th".format(path), map_location=lambda storage, loc: storage)) self.double_critic.load_state_dict( th.load("{}/double_critic.th".format(path), map_location=lambda storage, loc: storage)) # Not quite right but I don't want to save target networks self.target_critic.load_state_dict(self.critic.state_dict()) self.double_target_critic.load_state_dict( self.double_critic.state_dict()) self.agent_optimiser.load_state_dict(
th.load("{}/agent_opt.th".format(path), map_location=lambda storage, loc: storage)) self.critic_optimiser.load_state_dict( th.load("{}/critic_opt.th".format(path), map_location=lambda storage, loc: storage)) self.double_critic_optimiser.load_state_dict( th.load("{}/double_critic_opt.th".format(path), map_location=lambda storage, loc: storage))
class QLearner: def __init__(self, mac, scheme, logger, args): self.args = args self.mac = mac self.logger = logger self.params = list(mac.parameters()) self.last_target_update_episode = 0 self.mixer = None assert args.mixer == "qtran_alt" self.mixer = QTranAlt(args) self.params += list(self.mixer.parameters()) self.target_mixer = copy.deepcopy(self.mixer) self.optimiser = RMSprop(params=self.params, lr=args.lr, alpha=args.optim_alpha, eps=args.optim_eps) # a little wasteful to deepcopy (e.g. duplicates action selector), but should work for any MAC self.target_mac = copy.deepcopy(mac) self.log_stats_t = -self.args.learner_log_interval - 1 def train(self, batch: EpisodeBatch, t_env: int, episode_num: int): # Get the relevant quantities rewards = batch["reward"][:, :-1] actions = batch["actions"][:, :-1] terminated = batch["terminated"][:, :-1].float() mask = batch["filled"][:, :-1].float() mask[:, 1:] = mask[:, 1:] * (1 - terminated[:, :-1]) avail_actions = batch["avail_actions"] # Calculate estimated Q-Values mac_out = [] self.mac.init_hidden(batch.batch_size) for t in range(batch.max_seq_length): agent_outs = self.mac.forward(batch, t=t) mac_out.append(agent_outs) mac_out = th.stack(mac_out, dim=1) # Concat over time # Pick the Q-Values for the actions taken by each agent chosen_action_qvals = th.gather(mac_out[:, :-1], dim=3, index=actions).squeeze( 3) # Remove the last dim # Calculate the Q-Values necessary for the target target_mac_out = [] self.target_mac.init_hidden(batch.batch_size) for t in range(batch.max_seq_length): target_agent_outs = self.target_mac.forward(batch, t=t) target_mac_out.append(target_agent_outs) # We don't need the first timesteps Q-Value estimate for calculating targets target_mac_out = th.stack(target_mac_out[1:], dim=1) # Concat across time # Mask out unavailable actions target_mac_out[avail_actions[:, 1:] == 0] = -9999999 # From OG deepmarl mac_out_maxs = mac_out.clone().detach() mac_out_maxs[avail_actions == 0] = -9999999 # Best joint action computed by target agents cur_max_actions = target_mac_out.max(dim=3, keepdim=True)[1] # Best joint-action computed by regular agents max_actions_qvals, max_actions_current = mac_out_maxs[:, :-1].max( dim=3, keepdim=True) counter_qs, vs = self.mixer(batch[:, :-1]) # Need to argmax across the target agents' actions # Convert cur_max_actions to one hot max_actions = th.zeros(size=(batch.batch_size, batch.max_seq_length - 1, self.args.n_agents, self.args.n_actions), device=batch.device) max_actions_onehot = max_actions.scatter(3, cur_max_actions, 1) max_actions_onehot_repeat = max_actions_onehot.repeat( 1, 1, self.args.n_agents, 1) agent_mask = (1 - th.eye(self.args.n_agents, device=batch.device)) agent_mask = agent_mask.view(-1, 1).repeat( 1, self.args.n_actions) #.view(self.n_agents, -1) masked_actions = max_actions_onehot_repeat * agent_mask.unsqueeze( 0).unsqueeze(0) masked_actions = masked_actions.view( -1, self.args.n_agents * self.args.n_actions) target_counter_qs, target_vs = self.target_mixer( batch[:, 1:], masked_actions) # Td loss td_target_qs = target_counter_qs.gather(1, cur_max_actions.view(-1, 1)) td_chosen_qs = counter_qs.gather(1, actions.contiguous().view(-1, 1)) td_targets = rewards.repeat(1, 1, self.args.n_agents).view( -1, 1) + self.args.gamma * (1 - terminated.repeat( 1, 1, self.args.n_agents).view(-1, 1)) * td_target_qs td_error = (td_chosen_qs - td_targets.detach()) td_mask = mask.repeat(1, 1, self.args.n_agents).view(-1, 1) masked_td_error = td_error * td_mask td_loss = (masked_td_error**2).sum() / 
td_mask.sum() # Opt loss # Computing the targets opt_max_actions = th.zeros( size=(batch.batch_size, batch.max_seq_length - 1, self.args.n_agents, self.args.n_actions), device=batch.device) opt_max_actions_onehot = opt_max_actions.scatter( 3, max_actions_current, 1) opt_max_actions_onehot_repeat = opt_max_actions_onehot.repeat( 1, 1, self.args.n_agents, 1) agent_mask = (1 - th.eye(self.args.n_agents, device=batch.device)) agent_mask = agent_mask.view(-1, 1).repeat(1, self.args.n_actions) opt_masked_actions = opt_max_actions_onehot_repeat * agent_mask.unsqueeze( 0).unsqueeze(0) opt_masked_actions = opt_masked_actions.view( -1, self.args.n_agents * self.args.n_actions) opt_target_qs, opt_vs = self.mixer(batch[:, :-1], opt_masked_actions) opt_error = max_actions_qvals.squeeze(3).sum( dim=2, keepdim=True).repeat(1, 1, self.args.n_agents).view( -1, 1) - opt_target_qs.gather(1, max_actions_current.view( -1, 1)).detach() + opt_vs opt_loss = ((opt_error * td_mask)**2).sum() / td_mask.sum() # NOpt loss qsums = chosen_action_qvals.clone().unsqueeze(2).repeat( 1, 1, self.args.n_agents, 1).view(-1, self.args.n_agents) ids_to_zero = th.tensor([i for i in range(self.args.n_agents)], device=batch.device).repeat( batch.batch_size * (batch.max_seq_length - 1)) qsums.scatter(1, ids_to_zero.unsqueeze(1), 0) nopt_error = mac_out[:, :-1].contiguous().view( -1, self.args.n_actions) + qsums.sum( dim=1, keepdim=True) - counter_qs.detach() + vs min_nopt_error = th.min(nopt_error, dim=1, keepdim=True)[0] nopt_loss = ((min_nopt_error * td_mask)**2).sum() / td_mask.sum() loss = td_loss + self.args.opt_loss * opt_loss + self.args.nopt_min_loss * nopt_loss # Optimise self.optimiser.zero_grad() loss.backward() grad_norm = th.nn.utils.clip_grad_norm_(self.params, self.args.grad_norm_clip) self.optimiser.step() if (episode_num - self.last_target_update_episode ) / self.args.target_update_interval >= 1.0: self._update_targets() self.last_target_update_episode = episode_num if t_env - self.log_stats_t >= self.args.learner_log_interval: self.logger.log_stat("loss", loss.item(), t_env) self.logger.log_stat("grad_norm", grad_norm, t_env) mask_elems = mask.sum().item() self.logger.log_stat( "td_error_abs", (masked_td_error.abs().sum().item() / mask_elems), t_env) self.logger.log_stat("q_taken_mean", (chosen_action_qvals * mask).sum().item() / (mask_elems * self.args.n_agents), t_env) self.logger.log_stat("target_mean", (td_targets * td_mask).sum().item() / td_mask.sum().item(), t_env) self.log_stats_t = t_env def _update_targets(self): self.target_mac.load_state(self.mac) if self.mixer is not None: self.target_mixer.load_state_dict(self.mixer.state_dict()) self.logger.console_logger.info("Updated target network") def cuda(self): self.mac.cuda() self.target_mac.cuda() if self.mixer is not None: self.mixer.cuda() self.target_mixer.cuda() def save_models(self, path): self.mac.save_models(path) if self.mixer is not None: th.save(self.mixer.state_dict(), "{}/mixer.th".format(path)) th.save(self.optimiser.state_dict(), "{}/opt.th".format(path)) def load_models(self, path): self.mac.load_models(path) # Not quite right but I don't want to save target networks self.target_mac.load_models(path) if self.mixer is not None: self.mixer.load_state_dict( th.load("{}/mixer.th".format(path), map_location=lambda storage, loc: storage)) self.optimiser.load_state_dict( th.load("{}/opt.th".format(path), map_location=lambda storage, loc: storage))
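# A minimal sketch of how the QTRAN-alt learner above builds the "other agents' actions" input for
# its counterfactual mixer: one-hot encode the chosen joint action, repeat it once per agent, and
# zero out each agent's own slot with an identity-complement mask. Shapes and the helper name are
# illustrative assumptions, not the mixer's actual interface.
import torch as th

def masked_other_actions(actions, n_agents, n_actions):
    # actions: (B, T, n_agents, 1) integer action indices
    B, T = actions.shape[0], actions.shape[1]
    onehot = th.zeros(B, T, n_agents, n_actions, device=actions.device)
    onehot = onehot.scatter(3, actions, 1.0)
    repeated = onehot.repeat(1, 1, n_agents, 1)                    # (B, T, n_agents*n_agents, n_actions)
    agent_mask = 1 - th.eye(n_agents, device=actions.device)      # zero on the diagonal
    agent_mask = agent_mask.view(-1, 1).repeat(1, n_actions)      # (n_agents*n_agents, n_actions)
    masked = repeated * agent_mask.unsqueeze(0).unsqueeze(0)
    return masked.view(-1, n_agents * n_actions)                  # one row per (batch, t, agent)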
class COMALearner: def __init__(self, param_set, env_info): self.n_action = env_info["n_actions"] self.n_agent = env_info["n_agents"] self.obs_shape = env_info['obs_shape'] self.gamma = param_set['gamma'] self.td_lambda = param_set['td_lambda'] self.learning_rate = param_set['learning_rate'] self.alpha = param_set['alpha'] self.eps = param_set['eps'] self.grad_norm_clip = param_set['grad_norm_clip'] self.obs_last_action = True self.obs_agent_id = True output_shape = self.n_action input_shape = self._get_Q_input_shape() self.Q = RNNAgent(input_shape, output_shape) self.critic = COMACritic(env_info) self.target_critic = copy.deepcopy(self.critic) self.agent_params = list(self.Q.parameters()) self.critic_params = list(self.critic.parameters()) self.params = self.agent_params + self.critic_params self.agent_optimiser = RMSprop(params=self.agent_params, lr=self.learning_rate, alpha=self.alpha, eps=self.eps) self.critic_optimiser = RMSprop(params=self.critic_params, lr=self.learning_rate, alpha=self.alpha, eps=self.eps) self.critic_training_steps = 0 self.target_update_interval = param_set['target_update_interval'] self.last_target_update_step = 0 def _get_Q_input_shape(self): input_shape = self.obs_shape if self.obs_last_action: input_shape += self.n_action if self.obs_agent_id: input_shape += self.n_agent return input_shape def _train_critic(self, batch): reward = batch["reward"] action = batch["action"] done = batch["done"] batch_size = batch['batch_size'] mask = batch['mask'] # Optimise critic target_q_vals = self.target_critic(batch)[:, : -1] #batch:t:agent:action # print('target_q_vals.shape:', target_q_vals.shape, action.unsqueeze(3).shape) targets_taken = th.gather(target_q_vals, dim=3, index=action.unsqueeze(3)).squeeze(3) # Calculate td-lambda targets targets = build_td_lambda_targets(reward, done, mask, targets_taken, self.n_agent, self.gamma, self.td_lambda) q_vals = th.zeros_like(target_q_vals) # print('reward:', reward.shape) # batch:agent:t = mask.shpe # print("targets:", targets.shape) # batch:t:agent # print('q_vals:', q_vals.shape) #batch:t:agent:action for t in reversed(range(reward.size(2))): mask_t = mask[:, :, t] # batch:agent: if mask_t.sum() == 0: continue q_t = self.critic(batch, t) # batch:1:agent:action q_vals[:, t] = q_t.view(batch_size, self.n_agent, self.n_action) # print('q_t.shape ', q_t.shape) # print("action.shape ", (action[:, t]).unsqueeze(1).unsqueeze(3).shape) q_taken = th.gather( q_t, dim=3, index=( action[:, t]).unsqueeze(1).unsqueeze(3)).squeeze(3).squeeze( 1) # batch:agent targets_t = targets[:, t] # batch:agent td_error = (q_taken - targets_t.detach()) # 0-out the targets that came from padded data masked_td_error = td_error * mask_t # Normal L2 loss, take mean over actual data loss = (masked_td_error**2).sum() / mask_t.sum() self.critic_optimiser.zero_grad() loss.backward() grad_norm = th.nn.utils.clip_grad_norm_(self.critic_params, self.grad_norm_clip) self.critic_optimiser.step() self.critic_training_steps += 1 return q_vals def _update_critic_targets(self): self.target_critic.load_state_dict(self.critic.state_dict()) def init_Q_hidden(self, batch_size): self.hidden_states = self.Q.init_hidden().unsqueeze(0).expand( batch_size, self.n_agent, -1) # bav def _build_input_forQ(self, batch, t): batch_size = batch['batch_size'] inputs = [] inputs.append(batch["observation"][:, t]) # b1av if self.obs_last_action: if t == 0: inputs.append( th.zeros((batch_size, self.n_agent, self.n_action))) else: inputs.append(batch["action_onehot"][:, t - 1]) if 
self.obs_agent_id: inputs.append( th.eye(self.n_agent).unsqueeze(0).expand(batch_size, -1, -1)) inputs = th.cat( [x.reshape(batch_size * self.n_agent, -1) for x in inputs], dim=1) return inputs def Q_forward(self, batch, t): inputs = self._build_input_forQ(batch, t) outs, self.hidden_states = self.Q(inputs, self.hidden_states) return outs def train(self, batch, t_env): reward = batch["reward"] action = batch["action"] done = batch["done"] avail_action = batch["available_action"] batch_size = batch['batch_size'] mask = batch['mask'] action = action.transpose(1, 2).reshape(batch_size * self.n_agent, -1).unsqueeze(2) avail_action = avail_action[:, 1:].transpose(1, 2).reshape( batch_size * self.n_agent, -1, self.n_action) done = done.reshape(batch_size * self.n_agent, -1) mask = mask.reshape(batch_size * self.n_agent, -1) reward = reward.reshape(batch_size * self.n_agent, -1) q_val = self._train_critic(batch) self.init_Q_hidden(batch_size) out_batch = [] for t in range(batch['len'] - 1): outs = self.Q_forward(batch, t) out_batch.append(outs) out_batch = th.stack(out_batch, dim=1) out_batch[avail_action == 0] = 0 out_batch = out_batch / out_batch.sum(dim=-1, keepdim=True) out_batch[avail_action == 0] = 0 # chosen_action_qvals = th.gather(out_batch[:, :-1], dim=2, index=action).squeeze(2) # print('q_val:',q_val.shape) # batch:t:agent:action # print('out_batch',out_batch.shape) # batch * agent: t: action # print('action:', action.shape) # #batch * agent :t:1 # print('mask', mask.shape) # #batch * agent :t # Calculated baseline q_val = q_val.transpose(1, 2).reshape(-1, self.n_action) pi = out_batch.view(-1, self.n_action) baseline = (pi * q_val).sum(-1).detach() # Calculate policy grad with mask q_taken = th.gather(q_val, dim=1, index=action.reshape(-1, 1)).squeeze(1) pi_taken = th.gather(pi, dim=1, index=action.reshape(-1, 1)).squeeze(1) mask = mask.view(-1) pi_taken[mask == 0] = 1.0 log_pi_taken = th.log(pi_taken) advantages = (q_taken - baseline).detach() coma_loss = -((advantages * log_pi_taken) * mask).sum() / mask.sum() # Optimise agents self.agent_optimiser.zero_grad() coma_loss.backward() grad_norm = th.nn.utils.clip_grad_norm_(self.agent_params, self.grad_norm_clip) self.agent_optimiser.step() if (self.critic_training_steps - self.last_target_update_step ) / self.target_update_interval >= 1.0: print('Critic updated') self._update_critic_targets() self.last_target_update_step = self.critic_training_steps def approximate_Q(self, batch): self.init_Q_hidden(batch['batch_size']) for t in range(batch['len']): outs = self.Q_forward(batch, t) return outs def save_model(self, path): th.save(self.Q.state_dict(), path + 'Q' + '.pt') th.save(self.critic.state_dict(), "{}critic.th".format(path)) th.save(self.agent_optimiser.state_dict(), "{}agent_opt.th".format(path)) th.save(self.critic_optimiser.state_dict(), "{}critic_opt.th".format(path)) def load_model(self, path): file = path + 'load_Q.pt' file_c = path + 'load_critic.th' file_aopt = path + 'load_agent_opt.th' file_copt = path + 'load_critic_opt.th' if not os.path.isfile(file): file = path + 'best_Q.pt' file_c = path + 'best_critic.th' file_aopt = path + 'best_agent_opt.th' file_copt = path + 'best_critic_opt.th' if not os.path.isfile(file): file = path + 'Q.pt' file_c = path + 'critic.th' file_aopt = path + 'agent_opt.th' file_copt = path + 'critic_opt.th' if not os.path.isfile(file): print("here have not such model") return self.Q.load_state_dict(th.load(file)) self.critic.load_state_dict(th.load(file_c)) 
self.target_critic.load_state_dict(self.critic.state_dict()) self.agent_optimiser.load_state_dict(th.load(file_aopt)) self.critic_optimiser.load_state_dict(th.load(file_copt)) print('successfully loaded the model from ', file) return
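# --- Illustrative aside (not part of COMALearner) ---
# Hedged sketch of the counterfactual baseline used in COMALearner.train: for
# each agent the baseline is the policy-weighted average of the critic's
# Q-values over that agent's own actions, and the advantage is Q(taken) - baseline.
import torch as th


def coma_advantage(q_vals, pi, actions):
    # q_vals, pi: [batch, n_actions]; actions: [batch, 1] (indices of taken actions)
    baseline = (pi * q_vals).sum(dim=-1)                        # E_{a ~ pi}[Q]
    q_taken = th.gather(q_vals, dim=-1, index=actions).squeeze(-1)
    return (q_taken - baseline).detach()


q = th.tensor([[1.0, 2.0, 3.0]])
pi = th.tensor([[0.2, 0.3, 0.5]])
a = th.tensor([[2]])
print(coma_advantage(q, pi, a))   # 3.0 - 2.3 = 0.7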
class LIIRLearner: def __init__(self, mac, scheme, logger, args): self.args = args self.n_agents = args.n_agents self.n_actions = args.n_actions self.mac = mac self.logger = logger self.last_target_update_step = 0 self.critic_training_steps = 0 self.log_stats_t = -self.args.learner_log_interval - 1 self.critic = LIIRCritic(scheme, args) self.target_critic = copy.deepcopy(self.critic) self.policy_new = copy.deepcopy(self.mac) self.policy_old = copy.deepcopy(self.mac) if self.args.use_cuda: # following two lines should be used when use GPU self.policy_old.agent = self.policy_old.agent.to("cuda") self.policy_new.agent = self.policy_new.agent.to("cuda") else: # following lines should be used when use CPU, self.policy_old.agent = self.policy_old.agent.to("cpu") self.policy_new.agent = self.policy_new.agent.to("cpu") self.agent_params = list(mac.parameters()) self.critic_params = list(self.critic.fc1.parameters()) + list( self.critic.fc2.parameters()) + list( self.critic.fc3_v_mix.parameters()) self.intrinsic_params = list(self.critic.fc3_r_in.parameters()) + list( self.critic.fc4.parameters()) # to do self.params = self.agent_params + self.critic_params + self.intrinsic_params self.agent_optimiser = RMSprop(params=self.agent_params, lr=args.lr, alpha=args.optim_alpha, eps=args.optim_eps) self.critic_optimiser = RMSprop(params=self.critic_params, lr=args.critic_lr, alpha=args.optim_alpha, eps=args.optim_eps) self.intrinsic_optimiser = RMSprop( params=self.intrinsic_params, lr=args.critic_lr, alpha=args.optim_alpha, eps=args.optim_eps) # should distinguish them self.update = 0 self.count = 0 def train(self, batch: EpisodeBatch, t_env: int, episode_num: int): # Get the relevant quantities bs = batch.batch_size max_t = batch.max_seq_length rewards = batch["reward"][:, :-1] actions = batch["actions"][:, :] terminated = batch["terminated"][:, :-1].float() mask = batch["filled"][:, :-1].float() mask[:, 1:] = mask[:, 1:] * (1 - terminated[:, :-1]) avail_actions = batch["avail_actions"][:, :-1] critic_mask = mask.clone() mask_long = mask.repeat(1, 1, self.n_agents).view(-1, 1) mask = mask.view(-1, 1) avail_actions1 = avail_actions.reshape(-1, self.n_agents, self.n_actions) # [maskxx,:] mask_alive = 1.0 - avail_actions1[:, :, 0] mask_alive = mask_alive.float() q_vals, critic_train_stats, target_mix, target_ex, v_ex, r_in = self._train_critic( batch, rewards, terminated, actions, avail_actions, critic_mask, bs, max_t) actions = actions[:, :-1] mac_out = [] self.mac.init_hidden(batch.batch_size) for t in range(batch.max_seq_length - 1): agent_outs = self.mac.forward(batch, t=t) mac_out.append(agent_outs) mac_out = th.stack(mac_out, dim=1) # Concat over time # Mask out unavailable actions, renormalise (as in action selection) mac_out[avail_actions == 0] = 0 mac_out = mac_out / mac_out.sum(dim=-1, keepdim=True) mac_out[avail_actions == 0] = 0 # Calculated baseline q_vals = q_vals.reshape(-1, 1) pi = mac_out.view(-1, self.n_actions) # Calculate policy grad with mask pi_taken = th.gather(pi, dim=1, index=actions.reshape(-1, 1)).squeeze(1) pi_taken[mask_long.squeeze(-1) == 0] = 1.0 log_pi_taken = th.log(pi_taken) advantages = (target_mix.reshape(-1, 1) - q_vals).detach() advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8) log_pi_taken = log_pi_taken.reshape(-1, self.n_agents) log_pi_taken = log_pi_taken * mask_alive log_pi_taken = log_pi_taken.reshape(-1, 1) liir_loss = -( (advantages * log_pi_taken) * mask_long).sum() / mask_long.sum() # Optimise agents self.agent_optimiser.zero_grad() 
liir_loss.backward() grad_norm_policy = th.nn.utils.clip_grad_norm_( self.agent_params, self.args.grad_norm_clip) self.agent_optimiser.step() # _________Intrinsic loss optimizer -------------------- # ____value loss v_ex_loss = (((v_ex - target_ex.detach())**2).view(-1, 1) * mask).sum() / mask.sum() # _____pg1____ mac_out_old = [] self.policy_old.init_hidden(batch.batch_size) for t in range(batch.max_seq_length - 1): agent_outs_tmp = self.policy_old.forward(batch, t=t, test_mode=True) mac_out_old.append(agent_outs_tmp) mac_out_old = th.stack(mac_out_old, dim=1) # Concat over time # Mask out unavailable actions, renormalise (as in action selection) mac_out_old[avail_actions == 0] = 0 mac_out_old = mac_out_old / mac_out.sum(dim=-1, keepdim=True) mac_out_old[avail_actions == 0] = 0 pi_old = mac_out_old.view(-1, self.n_actions) # Calculate policy grad with mask pi_taken_old = th.gather(pi_old, dim=1, index=actions.reshape(-1, 1)).squeeze(1) pi_taken_old[mask_long.squeeze(-1) == 0] = 1.0 log_pi_taken_old = th.log(pi_taken_old) log_pi_taken_old = log_pi_taken_old.reshape(-1, self.n_agents) log_pi_taken_old = log_pi_taken_old * mask_alive # ______pg2___new pi theta self._update_policy() # update policy_new to new params mac_out_new = [] self.policy_new.init_hidden(batch.batch_size) for t in range(batch.max_seq_length - 1): agent_outs_tmp = self.policy_new.forward(batch, t=t, test_mode=True) mac_out_new.append(agent_outs_tmp) mac_out_new = th.stack(mac_out_new, dim=1) # Concat over time # Mask out unavailable actions, renormalise (as in action selection) mac_out_new[avail_actions == 0] = 0 mac_out_new = mac_out_new / mac_out.sum(dim=-1, keepdim=True) mac_out_new[avail_actions == 0] = 0 pi_new = mac_out_new.view(-1, self.n_actions) # Calculate policy grad with mask pi_taken_new = th.gather(pi_new, dim=1, index=actions.reshape(-1, 1)).squeeze(1) pi_taken_new[mask_long.squeeze(-1) == 0] = 1.0 log_pi_taken_new = th.log(pi_taken_new) log_pi_taken_new = log_pi_taken_new.reshape(-1, self.n_agents) log_pi_taken_new = log_pi_taken_new * mask_alive neglogpac_new = -log_pi_taken_new.sum(-1) pi2 = log_pi_taken.reshape(-1, self.n_agents).sum(-1).clone() ratio_new = th.exp(-pi2 - neglogpac_new) adv_ex = (target_ex - v_ex.detach()).detach() adv_ex = (adv_ex - adv_ex.mean()) / (adv_ex.std() + 1e-8) # _______ gadient for pg 1 and 2--- mask_tnagt = critic_mask.repeat(1, 1, self.n_agents) pg_loss1 = (log_pi_taken_old.view(-1, 1) * mask_long).sum() / mask_long.sum() pg_loss2 = ((adv_ex.view(-1) * ratio_new) * mask.squeeze(-1)).sum() / mask.sum() self.policy_old.agent.zero_grad() pg_loss1_grad = th.autograd.grad(pg_loss1, self.policy_old.parameters()) self.policy_new.agent.zero_grad() pg_loss2_grad = th.autograd.grad(pg_loss2, self.policy_new.parameters()) grad_total = 0 for grad1, grad2 in zip(pg_loss1_grad, pg_loss2_grad): grad_total += (grad1 * grad2).sum() target_mix = target_mix.reshape(-1, max_t - 1, self.n_agents) pg_ex_loss = ((grad_total.detach() * target_mix) * mask_tnagt).sum() / mask_tnagt.sum() intrinsic_loss = pg_ex_loss + vf_coef * v_ex_loss self.intrinsic_optimiser.zero_grad() intrinsic_loss.backward() self.intrinsic_optimiser.step() self._update_policy_piold() # ______config tensorboard if (self.critic_training_steps - self.last_target_update_step ) / self.args.target_update_interval >= 1.0: self._update_targets() self.last_target_update_step = self.critic_training_steps if t_env - self.log_stats_t >= self.args.learner_log_interval: ts_logged = len(critic_train_stats["critic_loss"]) for key in [ 
"critic_loss", "critic_grad_norm", "td_error_abs", "value_mean", "target_mean" ]: self.logger.log_stat(key, sum(critic_train_stats[key]) / ts_logged, t_env) self.logger.log_stat("advantage_mean", (advantages * mask_long).sum().item() / mask_long.sum().item(), t_env) self.logger.log_stat("liir_loss", liir_loss.item(), t_env) self.logger.log_stat("agent_grad_norm", grad_norm_policy, t_env) self.logger.log_stat( "pi_max", (pi.max(dim=1)[0] * mask_long.squeeze(-1)).sum().item() / mask_long.sum().item(), t_env) reward1 = rewards.reshape(-1, 1) self.logger.log_stat('rewards_mean', (reward1 * mask).sum().item() / mask.sum().item(), t_env) self.log_stats_t = t_env def _train_critic(self, batch, rewards, terminated, actions, avail_actions, mask, bs, max_t): # Optimise critic r_in, target_vals, target_val_ex = self.target_critic(batch) r_in, _, target_val_ex_opt = self.critic(batch) r_in_taken = th.gather(r_in, dim=3, index=actions) r_in = r_in_taken.squeeze(-1) target_vals = target_vals.squeeze(-1) targets_mix, targets_ex = build_td_lambda_targets_v2( rewards, terminated, mask, target_vals, self.n_agents, self.args.gamma, self.args.td_lambda, r_in, target_val_ex) vals_mix = th.zeros_like(target_vals)[:, :-1] vals_ex = target_val_ex_opt[:, :-1] running_log = { "critic_loss": [], "critic_grad_norm": [], "td_error_abs": [], "target_mean": [], "value_mean": [], } for t in reversed(range(rewards.size(1))): mask_t = mask[:, t].expand(-1, self.n_agents) if mask_t.sum() == 0: continue _, q_t, _ = self.critic(batch, t) # 8,1,3,1, vals_mix[:, t] = q_t.view(bs, self.n_agents) targets_t = targets_mix[:, t] td_error = (q_t.view(bs, self.n_agents) - targets_t.detach()) # 0-out the targets that came from padded data masked_td_error = td_error * mask_t # Normal L2 loss, take mean over actual data loss = (masked_td_error**2).sum() / mask_t.sum() self.critic_optimiser.zero_grad() loss.backward() grad_norm = th.nn.utils.clip_grad_norm_(self.critic_params, self.args.grad_norm_clip) self.critic_optimiser.step() self.critic_training_steps += 1 running_log["critic_loss"].append(loss.item()) running_log["critic_grad_norm"].append(grad_norm) mask_elems = mask_t.sum().item() running_log["td_error_abs"].append( (masked_td_error.abs().sum().item() / mask_elems)) running_log["value_mean"].append( (q_t.view(bs, self.n_agents) * mask_t).sum().item() / mask_elems) running_log["target_mean"].append( (targets_t * mask_t).sum().item() / mask_elems) return vals_mix, running_log, targets_mix, targets_ex, vals_ex, r_in def _update_targets(self): self.target_critic.load_state_dict(self.critic.state_dict()) self.logger.console_logger.info("Updated target network") def _update_policy(self): self.policy_new.load_state(self.mac) def _update_policy_piold(self): self.policy_old.load_state(self.mac) def cuda(self): self.mac.cuda() self.critic.cuda() self.target_critic.cuda() def save_models(self, path): self.mac.save_models(path) th.save(self.critic.state_dict(), "{}/critic.th".format(path)) th.save(self.agent_optimiser.state_dict(), "{}/agent_opt.th".format(path)) th.save(self.critic_optimiser.state_dict(), "{}/critic_opt.th".format(path)) def load_models(self, path): self.mac.load_models(path) self.critic.load_state_dict( th.load("{}/critic.th".format(path), map_location=lambda storage, loc: storage)) self.target_critic.load_state_dict(self.critic.state_dict()) self.agent_optimiser.load_state_dict( th.load("{}/agent_opt.th".format(path), map_location=lambda storage, loc: storage)) self.critic_optimiser.load_state_dict( 
th.load("{}/critic_opt.th".format(path), map_location=lambda storage, loc: storage))
class PolicyGradientLearner: def __init__(self, mac, scheme, logger, args): self.args = args self.n_agents = args.n_agents self.n_actions = args.n_actions self.mac = mac self.logger = logger self.last_target_update_step = 0 #self.critic_training_steps = 0 self.log_stats_t = -self.args.learner_log_interval - 1 #if args.critic_fact is not None: # self.critic = FactoredCentralVCritic(scheme, args) #else: # self.critic = CentralVCritic(scheme, args) #self.target_critic = copy.deepcopy(self.critic) self.agent_params = list(mac.parameters()) #self.critic_params = list(self.critic.parameters()) self.params = self.agent_params #+ self.critic_params self.agent_optimiser = RMSprop(params=self.agent_params, lr=args.lr, alpha=args.optim_alpha, eps=args.optim_eps) #self.critic_optimiser = RMSprop(params=self.critic_params, lr=args.critic_lr, alpha=args.optim_alpha, eps=args.optim_eps) def train(self, batch: EpisodeBatch, t_env: int, episode_num: int): # Get the relevant quantities rewards = batch["reward"][:, :-1] actions = batch["actions"][:, :-1] terminated = batch["terminated"][:, :-1].float() mask = batch["filled"][:, :-1].float() mask[:, 1:] = mask[:, 1:] * (1 - terminated[:, :-1]) avail_actions = batch["avail_actions"][:, :-1] mask = mask.repeat(1, 1, self.n_agents) #critic_mask = mask.clone() #get pilogits mac_out = [] self.mac.init_hidden(batch.batch_size) for t in range(batch.max_seq_length - 1): agent_outs = self.mac.forward(batch, t=t) mac_out.append(agent_outs) mac_out = th.stack(mac_out, dim=1) # Concat over time # Mask out unavailable actions, renormalise (as in action selection) mac_out[avail_actions == 0] = 0 mac_out = mac_out/mac_out.sum(dim=-1, keepdim=True) mac_out[avail_actions == 0] = 0 pi = mac_out pi_taken = th.gather(pi, dim=3, index=actions).squeeze(3) pi_taken[mask == 0] = 1.0 log_pi_taken = th.log(pi_taken) #get V-values from Central V critic #q_sa, v_vals, critic_train_stats = self._train_critic(batch, rewards, terminated, critic_mask) #baseline = v_vals q_sa = self.returns(rewards, mask) #use no baseline---just vanilla policy gradient with RNNs #advantages = (q_sa - baseline).detach().squeeze() advantages = q_sa.detach().squeeze() entropy = -(pi * th.log(pi) * mask[:, :, :, None]).sum() / (mask.sum() * pi.shape[-1]) centralV_loss = - ((advantages * log_pi_taken) * mask).sum() / mask.sum() - self.args.entropy_alpha*entropy # Optimise agents self.agent_optimiser.zero_grad() centralV_loss.backward() grad_norm = th.nn.utils.clip_grad_norm_(self.agent_params, self.args.grad_norm_clip) self.agent_optimiser.step() #if (self.critic_training_steps - self.last_target_update_step) / self.args.target_update_interval >= 1.0: # self._update_targets() # self.last_target_update_step = self.critic_training_steps if t_env - self.log_stats_t >= self.args.learner_log_interval: #ts_logged = len(critic_train_stats["critic_loss"]) #for key in ["critic_loss", "critic_grad_norm", "td_error_abs", "q_taken_mean", "target_mean"]: # self.logger.log_stat(key, sum(critic_train_stats[key])/ts_logged, t_env) self.logger.log_stat("advantage_mean", (advantages * mask).sum().item() / mask.sum().item(), t_env) self.logger.log_stat("centralV_loss", centralV_loss.item(), t_env) self.logger.log_stat("agent_grad_norm", grad_norm, t_env) self.logger.log_stat("pi_entropy", entropy.item(), t_env) self.logger.log_stat("pi_max", (pi.max(dim=-1)[0] * mask).sum().item() / mask.sum().item(), t_env) self.log_stats_t = t_env def returns(self, rewards, mask): nstep_values = th.zeros_like(mask) for t_start in 
range(rewards.size(1)): nstep_return_t = th.zeros_like(mask[:, 0]) for step in range(rewards.size(1)): t = t_start + step if t >= rewards.size(1) - 1: break #elif step == nsteps: # nstep_return_t += self.args.gamma ** (step) * values[:, t] * mask[:, t] #elif t == rewards.size(1) - 1: # #nstep_return_t += self.args.gamma ** (step) * values[:, t] * mask[:, t] else: nstep_return_t += self.args.gamma ** (step) * rewards[:, t] * mask[:, t] nstep_values[:, t_start, :] = nstep_return_t return nstep_values def cuda(self): self.mac.cuda() # this learner has no critic (its construction is commented out above) #self.critic.cuda() #self.target_critic.cuda() def save_models(self, path): self.mac.save_models(path) #th.save(self.critic.state_dict(), "{}/critic.th".format(path)) th.save(self.agent_optimiser.state_dict(), "{}/agent_opt.th".format(path)) #th.save(self.critic_optimiser.state_dict(), "{}/critic_opt.th".format(path)) def load_models(self, path): self.mac.load_models(path) #self.critic.load_state_dict(th.load("{}/critic.th".format(path), map_location=lambda storage, loc: storage)) # Not quite right but I don't want to save target networks #self.target_critic.load_state_dict(self.critic.state_dict()) self.agent_optimiser.load_state_dict(th.load("{}/agent_opt.th".format(path), map_location=lambda storage, loc: storage))
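# --- Illustrative aside (not part of PolicyGradientLearner) ---
# Hedged sketch of the masked discounted return that the returns() method
# above accumulates with a forward double loop; this backward-recursion
# variant computes the same quantity (up to the learner's convention of
# stopping one step before the end) in a single pass.
import torch as th


def discounted_returns(rewards, mask, gamma=0.99):
    # rewards, mask: [batch, T]; returns G_t = sum_k gamma^k * r_{t+k} * mask_{t+k}
    T = rewards.size(1)
    returns = th.zeros_like(rewards)
    running = th.zeros(rewards.size(0))
    for t in reversed(range(T)):
        running = rewards[:, t] * mask[:, t] + gamma * running
        returns[:, t] = running
    return returns


r = th.tensor([[1.0, 1.0, 1.0]])
print(discounted_returns(r, th.ones_like(r), gamma=0.5))  # [[1.75, 1.5, 1.0]]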
class PGLearner_v2: def __init__(self, mac, scheme, logger, args): self.args = args self.n_agents = args.n_agents self.n_actions = args.n_actions self.mac = mac self.logger = logger self.last_target_update_step = 0 self.critic_training_steps = 0 self.log_stats_t = -self.args.learner_log_interval - 1 self.target_mac = copy.deepcopy(mac) self.params = list(self.mac.parameters()) if args.mixer is not None: if args.mixer == "vdn": self.mixer = VDNMixer() elif args.mixer == "qmix": self.mixer = QMixer(args) else: raise ValueError("Mixer {} not recognised.".format(args.mixer)) self.params += list(self.mixer.parameters()) if self.args.optim == 'adam': self.optimiser = Adam(params=self.params, lr=args.lr) else: self.optimiser = RMSprop(params=self.params, lr=args.lr) def train(self, batch: EpisodeBatch, t_env: int, episode_num: int): # Get the relevant quantities bs = batch.batch_size max_t = batch.max_seq_length rewards = batch["reward"][:, :-1] actions = batch["actions"][:, :] terminated = batch["terminated"][:, :-1].float() mask = batch["filled"][:, :-1].float() mask[:, 1:] = mask[:, 1:] * (1 - terminated[:, :-1]) avail_actions = batch["avail_actions"][:, :] critic_mask = mask.clone() mask = mask.repeat(1, 1, self.n_agents).view(-1) advantages, td_error, targets_taken, log_pi_taken, entropy = self._calculate_advs(batch, rewards, terminated, actions, avail_actions, critic_mask, bs, max_t) pg_loss = - ((advantages.detach() * log_pi_taken) * mask).sum() / mask.sum() vf_loss = ((td_error ** 2) * mask).sum() / mask.sum() entropy[mask == 0] = 0 entropy_loss = (entropy * mask).sum() / mask.sum() coma_loss = pg_loss + self.args.vf_coef * vf_loss if self.args.ent_coef: coma_loss -= self.args.ent_coef * entropy_loss # Optimise agents self.optimiser.zero_grad() coma_loss.backward() grad_norm = th.nn.utils.clip_grad_norm_(self.params, self.args.grad_norm_clip) self.optimiser.step() if t_env - self.log_stats_t >= self.args.learner_log_interval: self.logger.log_stat("critic_loss", ((td_error ** 2) * mask).sum().item() / mask.sum().item(), t_env) self.logger.log_stat("td_error_abs", (td_error.abs() * mask).sum().item() / mask.sum().item(), t_env) self.logger.log_stat("q_taken_mean", (targets_taken * mask).sum().item() / mask.sum().item(), t_env) self.logger.log_stat("target_mean", ((targets_taken + advantages) * mask).sum().item() / mask.sum().item(), t_env) self.logger.log_stat("pg_loss", - ((advantages.detach() * log_pi_taken) * mask).sum().item() / mask.sum().item(), t_env) self.logger.log_stat("advantage_mean", (advantages * mask).sum().item() / mask.sum().item(), t_env) self.logger.log_stat("coma_loss", coma_loss.item(), t_env) self.logger.log_stat("entropy_loss", entropy_loss.item(), t_env) self.logger.log_stat("agent_grad_norm", grad_norm, t_env) # self.logger.log_stat("pi_max", (pi.max(dim=1)[0] * mask).sum().item() / mask.sum().item(), t_env) self.log_stats_t = t_env def _calculate_advs(self, batch, rewards, terminated, actions, avail_actions, mask, bs, max_t): mac_out = [] q_outs = [] # Roll out experiences self.mac.init_hidden(batch.batch_size) for t in range(batch.max_seq_length): agent_out, q_out = self.mac.forward(batch, t=t) mac_out.append(agent_out) q_outs.append(q_out) mac_out = th.stack(mac_out, dim=1) # Concat over time q_outs = th.stack(q_outs, dim=1) # Concat over time # Mask out unavailable actions, renormalise (as in action selection) # mac_out[avail_actions == 0] = 0 # mac_out = mac_out/(mac_out.sum(dim=-1, keepdim=True) + 1e-5) # Calculated baseline pi = mac_out[:, :-1] #[bs, t, 
n_agents, n_actions] pi_taken = th.gather(pi, dim=-1, index=actions[:, :-1]).squeeze(-1) #[bs, t, n_agents] action_mask = mask.repeat(1, 1, self.n_agents) pi_taken[action_mask == 0] = 1.0 log_pi_taken = th.log(pi_taken).reshape(-1) # Calculate entropy entropy = categorical_entropy(pi).reshape(-1) #[bs, t, n_agents, 1] # Calculate q targets targets_taken = q_outs.squeeze(-1) #[bs, t, n_agents] if self.args.mixer: targets_taken = self.mixer(targets_taken, batch["state"][:, :]) #[bs, t, 1] # Calculate td-lambda targets targets = build_td_lambda_targets(rewards, terminated, mask, targets_taken, self.n_agents, self.args.gamma, self.args.td_lambda) advantages = targets - targets_taken[:, :-1] advantages = advantages.unsqueeze(2).repeat(1, 1, self.n_agents, 1).reshape(-1) td_error = targets_taken[:, :-1] - targets.detach() td_error = td_error.unsqueeze(2).repeat(1, 1, self.n_agents, 1).reshape(-1) return advantages, td_error, targets_taken[:, :-1].unsqueeze(2).repeat(1, 1, self.n_agents, 1).reshape(-1), log_pi_taken, entropy def cuda(self): self.mac.cuda() if self.args.mixer: self.mixer.cuda() def save_models(self, path): self.mac.save_models(path) if self.args.mixer: th.save(self.mixer.state_dict(), "{}/mixer.th".format(path)) th.save(self.optimiser.state_dict(), "{}/opt.th".format(path)) def load_models(self, path): self.mac.load_models(path) if self.args.mixer: self.mixer.load_state_dict(th.load("{}/mixer.th".format(path), map_location=lambda storage, loc: storage)) self.optimiser.load_state_dict(th.load("{}/opt.th".format(path), map_location=lambda storage, loc: storage))
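# --- Illustrative aside (not part of PGLearner_v2) ---
# Simplified sketch of the TD(lambda) targets that build_td_lambda_targets is
# assumed to produce above (termination and mask handling omitted):
# G_t = r_t + gamma * ((1 - lambda) * V(s_{t+1}) + lambda * G_{t+1}).
import torch as th


def td_lambda_targets(rewards, values, gamma=0.99, td_lambda=0.8):
    # rewards: [batch, T]; values: [batch, T + 1] bootstrap values
    T = rewards.size(1)
    targets = th.zeros_like(rewards)
    g = values[:, T]                          # bootstrap from the final value
    for t in reversed(range(T)):
        g = rewards[:, t] + gamma * ((1 - td_lambda) * values[:, t + 1] + td_lambda * g)
        targets[:, t] = g
    return targets


print(td_lambda_targets(th.tensor([[1.0, 0.0]]),
                        th.tensor([[0.5, 0.4, 0.3]]), gamma=0.9, td_lambda=0.5))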
class Solver(object): def __init__(self, args): self.use_cuda = args.cuda and torch.cuda.is_available() self.device = torch.device("cuda:0" if self.use_cuda else "cpu") """ /begin of non-generic part that needs to be modified for each new model """ # model name is used for checkpointing (and here for setting self.net) self.model = args.model self.image_size = args.image_size self.n_latent = args.n_latent self.img_channels = args.img_channels # beta for beta VAE if args.beta_is_normalized: self.beta_norm = args.beta self.beta = beta_from_normalized_beta( self.beta_norm, N=self.image_size * self.image_size * self.img_channels, M=args.n_latent) else: self.beta = args.beta self.beta_norm = normalized_beta_from_beta( self.beta, N=self.image_size * self.image_size * self.img_channels, M=args.n_latent) self.gamma = args.gamma dataloaderparams = { 'batch_size': args.batch_size, 'shuffle': args.shuffle, 'num_workers': args.num_workers } if args.dataset.lower() == 'dsprites_circle': self.train_loader = torch.utils.data.DataLoader( dSpriteBackgroundDatasetTime(transform=transforms.Resize( (self.image_size, self.image_size)), shapetype='circle'), **dataloaderparams) elif args.dataset.lower() == 'dsprites': self.train_loader = torch.utils.data.DataLoader( dSpriteBackgroundDatasetTime(transform=transforms.Resize( (self.image_size, self.image_size)), shapetype='dsprite'), **dataloaderparams) if args.model.lower() == "dynamicvae32": net = dynamicVAE32 self.modeltype = 'dynamicVAE' else: raise Exception('model "%s" unknown' % args.model) self.net = net(n_latent=self.n_latent, img_channels=self.img_channels).to(self.device) self.reconstruction_loss = reconstruction_loss self.kl_divergence = kl_divergence self.prediction_loss = prediction_loss self.loss = loss_function self.lr = args.lr self.optim = RMSprop(self.net.parameters(), lr=self.lr) """ /end of non-generic part that needs to be modified for each new model """ self.max_iter = args.max_iter self.global_iter = 0 # prepare checkpointing if not os.path.isdir(args.ckpt_dir): os.mkdir(args.ckpt_dir) self.ckpt_dir = args.ckpt_dir self.load_last_checkpoint = args.load_last_checkpoint self.ckpt_name = '{}_nlatent={}_betanorm={}_gamma={}_{}_last'.format( self.model.lower(), self.n_latent, self.beta_norm, self.gamma, args.dataset.lower()) self.save_step = args.save_step if self.load_last_checkpoint is not None: self.load_checkpoint(self.ckpt_name) self.display_step = args.display_step # will store training-related information self.trainstats_gather_step = args.trainstats_gather_step self.trainstats_dir = args.trainstats_dir if not os.path.isdir(self.trainstats_dir): os.mkdir(self.trainstats_dir) self.trainstats_fname = '{}_nlatent={}_betanorm={}_gamma={}_{}'.format( self.model.lower(), self.n_latent, self.beta_norm, self.gamma, args.dataset.lower()) self.gather = DataGather( filename=os.path.join(self.trainstats_dir, self.trainstats_fname)) def train(self, plotmode=False): pbar = tqdm(total=self.max_iter) pbar.update(self.global_iter) out = False running_loss_terminal_display = 0.0 # running loss for the trainstats (gathered and pickeled) running_loss_trainstats = 0.0 # running loss for the trainstats (gathered and pickeled) """ /begin of non-generic part (might need to be adapted for different models / data)""" running_recon_loss_trainstats = 0.0 running_pred_loss_trainstats = 0.0 running_total_kld = 0.0 running_dim_wise_kld = 0.0 running_mean_kld = 0.0 plot_total_loss = [] plot_kld = [] plot_recon_loss = [] plot_pred_loss = [] while not out: for [ samples, 
latents ] in self.train_loader: # not sure how long the train_loader spits out data (possibly infinite?) self.global_iter += 1 if not plotmode: pbar.update(1) self.net.zero_grad() # get current batch and push to device img_batch, _ = samples.to(self.device), latents.to(self.device) # in VAE, input = output/target if self.modeltype == 'dynamicVAE': input_batch = img_batch output_batch = img_batch predicted_batch, mu, logvar, mu_pred = self.net(input_batch) recon_loss = self.reconstruction_loss(x=output_batch, x_recon=predicted_batch) total_kld, dimension_wise_kld, mean_kld = self.kl_divergence( mu, logvar) pred_loss = self.prediction_loss(mu, mu_pred) actLoss = self.loss(recon_loss=recon_loss, total_kld=total_kld, pred_loss=pred_loss, beta=self.beta, gamma=self.gamma) actLoss.backward() self.optim.step() running_loss_terminal_display += actLoss.item() running_loss_trainstats += actLoss.item() running_recon_loss_trainstats += recon_loss.item() running_pred_loss_trainstats += pred_loss.item() running_total_kld += total_kld.item() running_dim_wise_kld += dimension_wise_kld.cpu().detach( ).numpy() running_mean_kld += mean_kld.item() # update gather object with training information if self.global_iter % self.trainstats_gather_step == 0: running_loss_trainstats = running_loss_trainstats / self.trainstats_gather_step running_recon_loss_trainstats = running_recon_loss_trainstats / self.trainstats_gather_step running_pred_loss_trainstats = running_pred_loss_trainstats / self.trainstats_gather_step running_total_kld = running_total_kld / self.trainstats_gather_step running_dim_wise_kld = running_dim_wise_kld / self.trainstats_gather_step running_mean_kld = running_mean_kld / self.trainstats_gather_step self.gather.insert( iter=self.global_iter, total_loss=running_loss_trainstats, target=output_batch[0].detach().cpu().numpy(), reconstructed=predicted_batch[0].detach().cpu().numpy( ), recon_loss=running_recon_loss_trainstats, pred_loss=running_pred_loss_trainstats, total_kld=running_total_kld, dim_wise_kld=running_dim_wise_kld, mean_kld=running_mean_kld, ) running_loss_trainstats = 0.0 running_recon_loss_trainstats = 0.0 running_pred_loss_trainstats = 0.0 running_total_kld = 0.0 running_dim_wise_kld = 0.0 running_mean_kld = 0.0 if plotmode: # plot mini-batches plot_kld.append(total_kld.detach().cpu().numpy()) plot_total_loss.append(actLoss.item()) plot_recon_loss.append(recon_loss.item()) plot_pred_loss.append(pred_loss.item()) # PLOT! 
clear_output(wait=True) fig = plt.figure(figsize=(10, 8)) plt.subplot(4, 4, 1) plt.plot(plot_total_loss) plt.xlabel('minibatches') plt.title('Total loss') plt.subplot(4, 4, 2) plt.plot(plot_kld) plt.xlabel('minibatches') plt.title('Total KL-divergence') plt.subplot(4, 4, 3) plt.plot(plot_recon_loss) plt.xlabel('minibatches') plt.title('Reconstruction training loss') plt.subplot(4, 4, 4) plt.plot(plot_pred_loss) plt.xlabel('minibatches') plt.title('Prediction training loss') # import ipdb; ipdb.set_trace() plt.subplot(4, 4, 5) plt.imshow(input_batch[0][2][0].detach().cpu().numpy()) plt.set_cmap('gray') plt.subplot(4, 4, 6) plt.imshow(input_batch[0][4][0].detach().cpu().numpy()) plt.set_cmap('gray') plt.subplot(4, 4, 7) plt.imshow(input_batch[0][6][0].detach().cpu().numpy()) plt.set_cmap('gray') plt.subplot(4, 4, 8) plt.imshow(input_batch[0][8][0].detach().cpu().numpy()) plt.set_cmap('gray') plt.subplot(4, 4, 9) plt.imshow( predicted_batch[0][2][0].detach().cpu().numpy()) plt.set_cmap('gray') plt.subplot(4, 4, 10) plt.imshow( predicted_batch[0][4][0].detach().cpu().numpy()) plt.set_cmap('gray') plt.subplot(4, 4, 11) plt.imshow( predicted_batch[0][6][0].detach().cpu().numpy()) plt.set_cmap('gray') plt.subplot(4, 4, 12) plt.imshow( predicted_batch[0][8][0].detach().cpu().numpy()) plt.set_cmap('gray') plt.subplot(4, 4, 13) img = plt.imshow( (input_batch[0][2][0] - predicted_batch[0][2][0]).detach().cpu().numpy()) plt.set_cmap('bwr') colorAxisNormalize(fig.colorbar(img)) plt.subplot(4, 4, 14) img = plt.imshow( (input_batch[0][4][0] - predicted_batch[0][4][0]).detach().cpu().numpy()) plt.set_cmap('bwr') colorAxisNormalize(fig.colorbar(img)) plt.subplot(4, 4, 15) img = plt.imshow( (input_batch[0][6][0] - predicted_batch[0][6][0]).detach().cpu().numpy()) plt.set_cmap('bwr') colorAxisNormalize(fig.colorbar(img)) plt.subplot(4, 4, 16) img = plt.imshow( (input_batch[0][8][0] - predicted_batch[0][8][0]).detach().cpu().numpy()) plt.set_cmap('bwr') colorAxisNormalize(fig.colorbar(img)) plt.tight_layout() plt.show() """ /end of non-generic part""" if not plotmode and self.global_iter % self.display_step == 0: pbar.write('iter:{}, loss:{:.3e}'.format( self.global_iter, running_loss_terminal_display / self.display_step)) running_loss_terminal_display = 0.0 if self.global_iter % self.save_step == 0: self.save_checkpoint(self.ckpt_name) pbar.write('Saved checkpoint(iter:{})'.format( self.global_iter)) self.gather.save_data_dict() if self.global_iter >= self.max_iter: out = True break pbar.write("[Training Finished]") pbar.close() def save_checkpoint(self, filename, silent=True): """ saves model and optimizer state as checkpoint modified from https://github.com/1Konny/Beta-VAE/blob/master/solver.py """ model_states = { 'net': self.net.state_dict(), } optim_states = { 'optim': self.optim.state_dict(), } states = { 'iter': self.global_iter, 'model_states': model_states, 'optim_states': optim_states } file_path = os.path.join(self.ckpt_dir, filename) with open(file_path, mode='wb+') as f: torch.save(states, f) if not silent: print("=> saved checkpoint '{}' (iter {})".format( file_path, self.global_iter)) def load_checkpoint(self, filename): """ loads model and optimizer state from checkpoint modified from https://github.com/1Konny/Beta-VAE/blob/master/solver.py """ file_path = os.path.join(self.ckpt_dir, filename) if os.path.isfile(file_path): checkpoint = torch.load(file_path) self.global_iter = checkpoint['iter'] self.net.load_state_dict(checkpoint['model_states']['net']) 
self.optim.load_state_dict(checkpoint['optim_states']['optim']) print("=> loaded checkpoint '{} (iter {})'".format( file_path, self.global_iter)) else: print("=> no checkpoint found at '{}'".format(file_path))
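# --- Illustrative aside (not part of Solver) ---
# Hedged sketch of the loss composition in Solver.train and of the
# normalized-beta convention it appears to follow (beta_norm = beta * M / N,
# with N the input dimensionality and M the number of latents); the actual
# beta_from_normalized_beta / loss_function helpers are imported elsewhere and
# may differ in detail.
def beta_from_normalized_beta_sketch(beta_norm, N, M):
    return beta_norm * N / M


def total_loss_sketch(recon_loss, total_kld, pred_loss, beta, gamma):
    return recon_loss + beta * total_kld + gamma * pred_loss


print(beta_from_normalized_beta_sketch(beta_norm=5e-4, N=64 * 64 * 1, M=10))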
class MAXQLearner: def __init__(self, mac, scheme, logger, args): self.args = args self.mac = mac self.logger = logger self.mac_params = list(mac.parameters()) self.params = list(self.mac.parameters()) self.last_target_update_episode = 0 self.mixer = None assert args.mixer is not None if args.mixer is not None: if args.mixer == "vdn": self.mixer = VDNMixer() elif args.mixer == "qmix": self.mixer = QMixer(args) else: raise ValueError("Mixer {} not recognised.".format(args.mixer)) self.mixer_params = list(self.mixer.parameters()) self.params += list(self.mixer.parameters()) self.target_mixer = copy.deepcopy(self.mixer) # a little wasteful to deepcopy (e.g. duplicates action selector), but should work for any MAC self.target_mac = copy.deepcopy(mac) # Central Q # TODO: Clean this mess up! self.central_mac = None if self.args.central_mixer in ["ff", "atten"]: if self.args.central_loss == 0: self.central_mixer = self.mixer self.central_mac = self.mac self.target_central_mac = self.target_mac else: if self.args.central_mixer == "ff": self.central_mixer = QMixerCentralFF( args ) # Feedforward network that takes state and agent utils as input elif self.args.central_mixer == "atten": self.central_mixer = QMixerCentralAtten(args) else: raise Exception("Error with central_mixer") assert args.central_mac == "basic_central_mac" self.central_mac = mac_REGISTRY[args.central_mac]( scheme, args ) # Groups aren't used in the CentralBasicController. Little hacky self.target_central_mac = copy.deepcopy(self.central_mac) self.params += list(self.central_mac.parameters()) else: raise Exception("Error with qCentral") self.params += list(self.central_mixer.parameters()) self.target_central_mixer = copy.deepcopy(self.central_mixer) self.optimiser = RMSprop(params=self.params, lr=args.lr, alpha=args.optim_alpha, eps=args.optim_eps) self.log_stats_t = -self.args.learner_log_interval - 1 self.grad_norm = 1 self.mixer_norm = 1 self.mixer_norms = deque([1], maxlen=100) def train(self, batch: EpisodeBatch, t_env: int, episode_num: int): # Get the relevant quantities rewards = batch["reward"][:, :-1] actions = batch["actions"][:, :-1] terminated = batch["terminated"][:, :-1].float() mask = batch["filled"][:, :-1].float() mask[:, 1:] = mask[:, 1:] * (1 - terminated[:, :-1]) avail_actions = batch["avail_actions"] # Calculate estimated Q-Values mac_out = [] self.mac.init_hidden(batch.batch_size) for t in range(batch.max_seq_length): agent_outs = self.mac.forward(batch, t=t) mac_out.append(agent_outs) mac_out = th.stack(mac_out, dim=1) # Concat over time # Pick the Q-Values for the actions taken by each agent chosen_action_qvals_agents = th.gather(mac_out[:, :-1], dim=3, index=actions).squeeze( 3) # Remove the last dim chosen_action_qvals = chosen_action_qvals_agents # Calculate the Q-Values necessary for the target target_mac_out = [] self.target_mac.init_hidden(batch.batch_size) for t in range(batch.max_seq_length): target_agent_outs = self.target_mac.forward(batch, t=t) target_mac_out.append(target_agent_outs) # We don't need the first timesteps Q-Value estimate for calculating targets target_mac_out = th.stack(target_mac_out[:], dim=1) # Concat across time # Mask out unavailable actions target_mac_out[avail_actions[:, :] == 0] = -9999999 # From OG deepmarl # Max over target Q-Values if self.args.double_q: # Get actions that maximise live Q (for double q-learning) mac_out_detach = mac_out.clone().detach() mac_out_detach[avail_actions == 0] = -9999999 cur_max_action_targets, cur_max_actions = mac_out_detach[:, :].max( 
dim=3, keepdim=True) target_max_agent_qvals = th.gather( target_mac_out[:, :], 3, cur_max_actions[:, :]).squeeze(3) else: raise Exception("Use double q") # Central MAC stuff central_mac_out = [] self.central_mac.init_hidden(batch.batch_size) for t in range(batch.max_seq_length): agent_outs = self.central_mac.forward(batch, t=t) central_mac_out.append(agent_outs) central_mac_out = th.stack(central_mac_out, dim=1) # Concat over time central_chosen_action_qvals_agents = th.gather( central_mac_out[:, :-1], dim=3, index=actions.unsqueeze(4).repeat( 1, 1, 1, 1, self.args.central_action_embed)).squeeze( 3) # Remove the last dim central_target_mac_out = [] self.target_central_mac.init_hidden(batch.batch_size) for t in range(batch.max_seq_length): target_agent_outs = self.target_central_mac.forward(batch, t=t) central_target_mac_out.append(target_agent_outs) central_target_mac_out = th.stack(central_target_mac_out[:], dim=1) # Concat across time # Mask out unavailable actions central_target_mac_out[avail_actions[:, :] == 0] = -9999999 # From OG deepmarl # Use the Qmix max actions central_target_max_agent_qvals = th.gather( central_target_mac_out[:, :], 3, cur_max_actions[:, :].unsqueeze(4).repeat( 1, 1, 1, 1, self.args.central_action_embed)).squeeze(3) # --- # Mix chosen_action_qvals = self.mixer(chosen_action_qvals, batch["state"][:, :-1]) target_max_qvals = self.target_central_mixer( central_target_max_agent_qvals[:, 1:], batch["state"][:, 1:]) # Calculate 1-step Q-Learning targets targets = rewards + self.args.gamma * (1 - terminated) * target_max_qvals # Td-error td_error = (chosen_action_qvals - (targets.detach())) mask = mask.expand_as(td_error) # 0-out the targets that came from padded data masked_td_error = td_error * mask # Training central Q central_chosen_action_qvals = self.central_mixer( central_chosen_action_qvals_agents, batch["state"][:, :-1]) central_td_error = (central_chosen_action_qvals - targets.detach()) central_mask = mask.expand_as(central_td_error) central_masked_td_error = central_td_error * central_mask central_loss = (central_masked_td_error**2).sum() / mask.sum() # QMIX loss with weighting ws = th.ones_like(td_error) * self.args.w if self.args.hysteretic_qmix: # OW-QMIX ws = th.where(td_error < 0, th.ones_like(td_error) * 1, ws) # Target is greater than current max w_to_use = ws.mean().item() # For logging else: # CW-QMIX is_max_action = (actions == cur_max_actions[:, :-1]).min(dim=2)[0] max_action_qtot = self.target_central_mixer( central_target_max_agent_qvals[:, :-1], batch["state"][:, :-1]) qtot_larger = targets > max_action_qtot ws = th.where(is_max_action | qtot_larger, th.ones_like(td_error) * 1, ws) # Target is greater than current max w_to_use = ws.mean().item() # Average of ws for logging qmix_loss = (ws.detach() * (masked_td_error**2)).sum() / mask.sum() # The weightings for the different losses aren't used (they are always set to 1) loss = self.args.qmix_loss * qmix_loss + self.args.central_loss * central_loss # Optimise self.optimiser.zero_grad() loss.backward() # Logging agent_norm = 0 for p in self.mac_params: param_norm = p.grad.data.norm(2) agent_norm += param_norm.item()**2 agent_norm = agent_norm**(1. / 2) mixer_norm = 0 for p in self.mixer_params: param_norm = p.grad.data.norm(2) mixer_norm += param_norm.item()**2 mixer_norm = mixer_norm**(1. 
/ 2) self.mixer_norm = mixer_norm self.mixer_norms.append(mixer_norm) grad_norm = th.nn.utils.clip_grad_norm_(self.params, self.args.grad_norm_clip) self.grad_norm = grad_norm self.optimiser.step() if (episode_num - self.last_target_update_episode ) / self.args.target_update_interval >= 1.0: self._update_targets() self.last_target_update_episode = episode_num if t_env - self.log_stats_t >= self.args.learner_log_interval: self.logger.log_stat("loss", loss.item(), t_env) self.logger.log_stat("qmix_loss", qmix_loss.item(), t_env) self.logger.log_stat("grad_norm", grad_norm, t_env) self.logger.log_stat("mixer_norm", mixer_norm, t_env) self.logger.log_stat("agent_norm", agent_norm, t_env) mask_elems = mask.sum().item() self.logger.log_stat( "td_error_abs", (masked_td_error.abs().sum().item() / mask_elems), t_env) self.logger.log_stat("q_taken_mean", (chosen_action_qvals * mask).sum().item() / (mask_elems * self.args.n_agents), t_env) self.logger.log_stat("target_mean", (targets * mask).sum().item() / (mask_elems * self.args.n_agents), t_env) self.logger.log_stat("central_loss", central_loss.item(), t_env) self.logger.log_stat("w_to_use", w_to_use, t_env) self.log_stats_t = t_env def _update_targets(self): self.target_mac.load_state(self.mac) if self.mixer is not None: self.target_mixer.load_state_dict(self.mixer.state_dict()) if self.central_mac is not None: self.target_central_mac.load_state(self.central_mac) self.target_central_mixer.load_state_dict( self.central_mixer.state_dict()) self.logger.console_logger.info("Updated target network") def cuda(self): self.mac.cuda() self.target_mac.cuda() if self.mixer is not None: self.mixer.cuda() self.target_mixer.cuda() if self.central_mac is not None: self.central_mac.cuda() self.target_central_mac.cuda() self.central_mixer.cuda() self.target_central_mixer.cuda() # TODO: Model saving/loading is out of date! def save_models(self, path): self.mac.save_models(path) if self.mixer is not None: th.save(self.mixer.state_dict(), "{}/mixer.th".format(path)) th.save(self.optimiser.state_dict(), "{}/opt.th".format(path)) def load_models(self, path): self.mac.load_models(path) # Not quite right but I don't want to save target networks self.target_mac.load_models(path) if self.mixer is not None: self.mixer.load_state_dict( th.load("{}/mixer.th".format(path), map_location=lambda storage, loc: storage)) self.optimiser.load_state_dict( th.load("{}/opt.th".format(path), map_location=lambda storage, loc: storage))
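# --- Illustrative aside (not part of MAXQLearner) ---
# Hedged sketch of the OW-QMIX weighting used above: samples whose target
# exceeds the current estimate (td_error < 0) keep weight 1, all others are
# down-weighted to w (args.w in the learner).
import torch as th


def ow_weights(td_error, w=0.5):
    return th.where(td_error < 0,
                    th.ones_like(td_error),
                    th.full_like(td_error, w))


print(ow_weights(th.tensor([-0.2, 0.3, -1.0])))   # tensor([1.0000, 0.5000, 1.0000])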
class QLearner_3s_vs_4z: def __init__(self, mac, scheme, logger, args): self.args = args self.mac = mac self.logger = logger self.params = list(mac.parameters()) self.last_target_update_episode = 0 self.mixer = None if args.mixer is not None: if args.mixer == "vdn": self.mixer = VDNMixer() elif args.mixer == "qmix": self.mixer = QMixer(args) else: raise ValueError("Mixer {} not recognised.".format(args.mixer)) self.params += list(self.mixer.parameters()) self.target_mixer = copy.deepcopy(self.mixer) self.params += list(self.mac.msg_rnn.parameters()) self.optimiser = RMSprop(params=self.params, lr=args.lr, alpha=args.optim_alpha, eps=args.optim_eps) self.target_mac = copy.deepcopy(mac) self.log_stats_t = -self.args.learner_log_interval - 1 self.loss_weight = [0.5, 1, 1.5] # its the beta in the Algorithm 1 def train(self, batch: EpisodeBatch, t_env: int, episode_num: int): # Get the relevant quantities rewards = batch["reward"][:, :-1] actions = batch["actions"][:, :-1] terminated = batch["terminated"][:, :-1].float() mask = batch["filled"][:, :-1].float() mask[:, 1:] = mask[:, 1:] * (1 - terminated[:, :-1]) avail_actions = batch["avail_actions"] # Calculate estimated Q-Values mac_out = [] previous_msg_list = [] smooth_loss_list = [] regulariation_smooth = 1.5 regulariation_robust = 0.3 self.mac.init_hidden(batch.batch_size) smooth_loss = th.zeros((32 * 3)).cuda() for t in range(batch.max_seq_length): agent_local_outputs, input_hidden_states, vi = self.mac.forward( batch, t=t) input_hidden_states = input_hidden_states.view(-1, 64) self.mac.hidden_states_msg, dummy = self.mac.msg_rnn( self.mac.hidden_states_msg, input_hidden_states) ss = min(len(previous_msg_list), 3) #record the l2 difference in the window for i in range(ss): smooth_loss += self.loss_weight[i] * (( (dummy - previous_msg_list[i])**2).sum(dim=1)) / ( (ss * 32 * 3 * 10 * (dummy**2)).sum(dim=1)) previous_msg_list.append(dummy) if (len(previous_msg_list) > 3): previous_msg_list.pop(0) smooth_loss_reshape = smooth_loss.reshape(32, 3, 1).sum(1) #(32,1) smooth_loss_list.append(smooth_loss_reshape) # generate the message dummy_final = dummy.reshape(32, 3, 10) dummy0 = dummy_final[:, 0, :] dummy1 = dummy_final[:, 1, :] dummy2 = dummy_final[:, 2, :] agent0 = (dummy1 + dummy2) / 2.0 agent1 = (dummy0 + dummy2) / 2.0 agent2 = (dummy0 + dummy1) / 2.0 agent_global_outputs = th.cat((agent0.view( (32, 1, 10)), agent1.view((32, 1, 10)), agent2.view( (32, 1, 10))), 1) agent_outs = agent_local_outputs + agent_global_outputs mac_out.append(agent_outs) mac_out = th.stack(mac_out, dim=1) # Concat over time ############compute the robustness loss################## robust_loss = th.topk(mac_out, 2)[0][:, :, :, 0] - th.topk( mac_out, 2)[0][:, :, :, 1] robust_loss = th.exp(-25.0 * robust_loss).sum( dim=2)[:, :-1].unsqueeze(2) / (32 * 6) #(32,38) # Pick the Q-Values for the actions taken by each agent chosen_action_qvals = th.gather(mac_out[:, :-1], dim=3, index=actions).squeeze( 3) # Remove the last dim # Calculate the Q-Values necessary for the target target_mac_out = [] self.target_mac.init_hidden(batch.batch_size) for t in range(batch.max_seq_length): target_agent_local_outputs, target_input_hidden_states, tvi = self.target_mac.forward( batch, t=t) target_input_hidden_states = target_input_hidden_states.view( -1, 64) self.target_mac.hidden_states_msg, target_dummy = self.target_mac.msg_rnn( self.target_mac.hidden_states_msg, target_input_hidden_states) target_dummy_final = target_dummy.reshape(32, 3, 10) dummy0 = target_dummy_final[:, 0, :] dummy1 = 
target_dummy_final[:, 1, :] dummy2 = target_dummy_final[:, 2, :] target_agent0 = (dummy1 + dummy2) / 2.0 target_agent1 = (dummy0 + dummy2) / 2.0 target_agent2 = (dummy0 + dummy1) / 2.0 target_agent_global_outputs = th.cat((target_agent0.view( (32, 1, 10)), target_agent1.view( (32, 1, 10)), target_agent2.view((32, 1, 10))), 1) target_agent_outs = target_agent_local_outputs + target_agent_global_outputs target_mac_out.append(target_agent_outs) # We don't need the first timesteps Q-Value estimate for calculating targets target_mac_out = th.stack(target_mac_out[1:], dim=1) # Concat across time # Mask out unavailable actions target_mac_out[avail_actions[:, 1:] == 0] = -9999999 # Max over target Q-Values if self.args.double_q: # Get actions that maximise live Q (for double q-learning) mac_out[avail_actions == 0] = -9999999 cur_max_actions = mac_out[:, 1:].max(dim=3, keepdim=True)[1] target_max_qvals = th.gather(target_mac_out, 3, cur_max_actions).squeeze(3) else: target_max_qvals = target_mac_out.max(dim=3)[0] # Mix if self.mixer is not None: chosen_action_qvals = self.mixer(chosen_action_qvals, batch["state"][:, :-1]) target_max_qvals = self.target_mixer(target_max_qvals, batch["state"][:, 1:]) # Calculate 1-step Q-Learning targets targets = rewards + self.args.gamma * (1 - terminated) * target_max_qvals # Td-error td_error = (chosen_action_qvals - targets.detach()) mask = mask.expand_as(td_error) # 0-out the targets that came from padded data masked_td_error = td_error * mask ######compute the smooth_loss and robust_loss######### smooth_loss = th.stack(smooth_loss_list[0:-1], dim=1) smooth_loss = (smooth_loss * mask).sum() / mask.sum() robust_loss = (robust_loss * mask).sum() / mask.sum() # Normal L2 loss, take mean over actual data loss = (masked_td_error**2).sum() / mask.sum( ) + regulariation_smooth * smooth_loss + regulariation_robust * robust_loss # Optimise self.optimiser.zero_grad() loss.backward() grad_norm = th.nn.utils.clip_grad_norm_(self.params, self.args.grad_norm_clip) self.optimiser.step() if (episode_num - self.last_target_update_episode ) / self.args.target_update_interval >= 1.0: self._update_targets() self.last_target_update_episode = episode_num if t_env - self.log_stats_t >= self.args.learner_log_interval: self.logger.log_stat("loss", loss.item(), t_env) self.logger.log_stat("grad_norm", grad_norm, t_env) mask_elems = mask.sum().item() self.logger.log_stat( "td_error_abs", (masked_td_error.abs().sum().item() / mask_elems), t_env) self.logger.log_stat("q_taken_mean", (chosen_action_qvals * mask).sum().item() / (mask_elems * self.args.n_agents), t_env) self.logger.log_stat("target_mean", (targets * mask).sum().item() / (mask_elems * self.args.n_agents), t_env) self.log_stats_t = t_env def _update_targets(self): self.target_mac.load_state(self.mac) if self.mixer is not None: self.target_mixer.load_state_dict(self.mixer.state_dict()) self.logger.console_logger.info("Updated target network") def cuda(self): self.mac.cuda() self.target_mac.cuda() if self.mixer is not None: self.mixer.cuda() self.target_mixer.cuda() def save_models(self, path): self.mac.save_models(path) if self.mixer is not None: th.save(self.mixer.state_dict(), "{}/mixer.th".format(path)) th.save(self.optimiser.state_dict(), "{}/opt.th".format(path)) def load_models(self, path): self.mac.load_models(path) # Not quite right but I don't want to save target networks self.target_mac.load_models(path) if self.mixer is not None: self.mixer.load_state_dict( th.load("{}/mixer.th".format(path), map_location=lambda 
storage, loc: storage)) self.optimiser.load_state_dict( th.load("{}/opt.th".format(path), map_location=lambda storage, loc: storage))
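# --- Illustrative aside (not part of QLearner_3s_vs_4z) ---
# Hedged sketch of the two regularisers above: a temporal-smoothness penalty
# on successive messages and a robustness penalty that is large when the gap
# between the top-2 Q-values is small (exp(-k * gap)); the normalisation
# constants hard-coded in the learner (32, 3, 10, 6) are omitted here.
import torch as th


def smoothness_penalty(msg_t, msg_prev, eps=1e-8):
    return ((msg_t - msg_prev) ** 2).sum(dim=-1) / ((msg_t ** 2).sum(dim=-1) + eps)


def robustness_penalty(q_values, k=25.0):
    top2 = th.topk(q_values, 2, dim=-1)[0]
    gap = top2[..., 0] - top2[..., 1]
    return th.exp(-k * gap)


print(smoothness_penalty(th.tensor([[1.0, 0.0]]), th.tensor([[0.5, 0.5]])))  # ~0.5
print(robustness_penalty(th.tensor([[1.0, 0.9, 0.1]])))                      # ~exp(-2.5)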
class NQLearner: def __init__(self, mac, scheme, logger, args): self.args = args self.mac = mac self.logger = logger self.last_target_update_episode = 0 self.device = th.device('cuda' if args.use_cuda else 'cpu') self.params = list(mac.parameters()) if args.mixer == "qatten": self.mixer = QattenMixer(args) elif args.mixer == "vdn": self.mixer = VDNMixer(args) elif args.mixer == "qmix": self.mixer = Mixer(args) else: raise "mixer error" self.target_mixer = copy.deepcopy(self.mixer) self.params += list(self.mixer.parameters()) print('Mixer Size: ') print(get_parameters_num(self.mixer.parameters())) if self.args.optimizer == 'adam': self.optimiser = Adam(params=self.params, lr=args.lr) else: self.optimiser = RMSprop(params=self.params, lr=args.lr, alpha=args.optim_alpha, eps=args.optim_eps) # a little wasteful to deepcopy (e.g. duplicates action selector), but should work for any MAC self.target_mac = copy.deepcopy(mac) self.log_stats_t = -self.args.learner_log_interval - 1 self.train_t = 0 # th.autograd.set_detect_anomaly(True) def train(self, batch: EpisodeBatch, t_env: int, episode_num: int): # Get the relevant quantities rewards = batch["reward"][:, :-1] actions = batch["actions"][:, :-1] terminated = batch["terminated"][:, :-1].float() mask = batch["filled"][:, :-1].float() mask[:, 1:] = mask[:, 1:] * (1 - terminated[:, :-1]) avail_actions = batch["avail_actions"] # Calculate estimated Q-Values mac_out = [] self.mac.init_hidden(batch.batch_size) for t in range(batch.max_seq_length): agent_outs = self.mac.forward(batch, t=t) mac_out.append(agent_outs) mac_out = th.stack(mac_out, dim=1) # Concat over time # Pick the Q-Values for the actions taken by each agent chosen_action_qvals = th.gather(mac_out[:, :-1], dim=3, index=actions).squeeze( 3) # Remove the last dim chosen_action_qvals_ = chosen_action_qvals # Calculate the Q-Values necessary for the target with th.no_grad(): target_mac_out = [] self.target_mac.init_hidden(batch.batch_size) for t in range(batch.max_seq_length): target_agent_outs = self.target_mac.forward(batch, t=t) target_mac_out.append(target_agent_outs) # We don't need the first timesteps Q-Value estimate for calculating targets target_mac_out = th.stack(target_mac_out, dim=1) # Concat across time # Max over target Q-Values/ Double q learning mac_out_detach = mac_out.clone().detach() mac_out_detach[avail_actions == 0] = -9999999 cur_max_actions = mac_out_detach.max(dim=3, keepdim=True)[1] target_max_qvals = th.gather(target_mac_out, 3, cur_max_actions).squeeze(3) # Calculate n-step Q-Learning targets target_max_qvals = self.target_mixer(target_max_qvals, batch["state"]) if getattr(self.args, 'q_lambda', False): qvals = th.gather(target_mac_out, 3, batch["actions"]).squeeze(3) qvals = self.target_mixer(qvals, batch["state"]) targets = build_q_lambda_targets(rewards, terminated, mask, target_max_qvals, qvals, self.args.gamma, self.args.td_lambda) else: targets = build_td_lambda_targets(rewards, terminated, mask, target_max_qvals, self.args.n_agents, self.args.gamma, self.args.td_lambda) # Mixer chosen_action_qvals = self.mixer(chosen_action_qvals, batch["state"][:, :-1]) td_error = (chosen_action_qvals - targets.detach()) td_error = 0.5 * td_error.pow(2) mask = mask.expand_as(td_error) masked_td_error = td_error * mask loss = L_td = masked_td_error.sum() / mask.sum() # Optimise self.optimiser.zero_grad() loss.backward() grad_norm = th.nn.utils.clip_grad_norm_(self.params, self.args.grad_norm_clip) self.optimiser.step() if (episode_num - self.last_target_update_episode ) / 
self.args.target_update_interval >= 1.0: self._update_targets() self.last_target_update_episode = episode_num if t_env - self.log_stats_t >= self.args.learner_log_interval: self.logger.log_stat("loss_td", L_td.item(), t_env) self.logger.log_stat("grad_norm", grad_norm, t_env) mask_elems = mask.sum().item() self.logger.log_stat( "td_error_abs", (masked_td_error.abs().sum().item() / mask_elems), t_env) self.logger.log_stat("q_taken_mean", (chosen_action_qvals * mask).sum().item() / (mask_elems * self.args.n_agents), t_env) self.logger.log_stat("target_mean", (targets * mask).sum().item() / (mask_elems * self.args.n_agents), t_env) self.log_stats_t = t_env # print estimated matrix if self.args.env == "one_step_matrix_game": print_matrix_status(batch, self.mixer, mac_out) def _update_targets(self): self.target_mac.load_state(self.mac) if self.mixer is not None: self.target_mixer.load_state_dict(self.mixer.state_dict()) self.logger.console_logger.info("Updated target network") def cuda(self): self.mac.cuda() self.target_mac.cuda() if self.mixer is not None: self.mixer.cuda() self.target_mixer.cuda() def save_models(self, path): self.mac.save_models(path) if self.mixer is not None: th.save(self.mixer.state_dict(), "{}/mixer.th".format(path)) th.save(self.optimiser.state_dict(), "{}/opt.th".format(path)) def load_models(self, path): self.mac.load_models(path) # Not quite right but I don't want to save target networks self.target_mac.load_models(path) if self.mixer is not None: self.mixer.load_state_dict( th.load("{}/mixer.th".format(path), map_location=lambda storage, loc: storage)) self.optimiser.load_state_dict( th.load("{}/opt.th".format(path), map_location=lambda storage, loc: storage))
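# --- Illustrative aside (not part of NQLearner) ---
# Minimal sketch of the double-Q target used above: the online network picks
# the greedy (available) action and the target network evaluates it.
import torch as th


def double_q_target(online_q, target_q, avail_actions):
    # online_q, target_q: [batch, n_actions]; avail_actions: 0/1 availability mask
    online_q = online_q.clone()
    online_q[avail_actions == 0] = -9999999
    greedy = online_q.max(dim=1, keepdim=True)[1]
    return th.gather(target_q, 1, greedy).squeeze(1)


oq = th.tensor([[1.0, 5.0, 2.0]])
tq = th.tensor([[0.3, 0.1, 0.7]])
print(double_q_target(oq, tq, th.tensor([[1, 1, 1]])))   # tensor([0.1000])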
class QLearner: def __init__(self, mac, scheme, logger, args): self.args = args self.mac = mac self.logger = logger self.params = list(mac.parameters()) self.last_target_update_episode = 0 self.mixer = None if args.mixer is not None: if args.mixer == "vdn": self.mixer = VDNMixer() elif args.mixer == "qmix": self.mixer = NoiseQMixer(args) else: raise ValueError("Mixer {} not recognised.".format(args.mixer)) self.params += list(self.mixer.parameters()) self.target_mixer = copy.deepcopy(self.mixer) discrim_input = np.prod( self.args.state_shape) + self.args.n_agents * self.args.n_actions if self.args.rnn_discrim: self.rnn_agg = RNNAggregator(discrim_input, args) self.discrim = Discrim(args.rnn_agg_size, self.args.noise_dim, args) self.params += list(self.discrim.parameters()) self.params += list(self.rnn_agg.parameters()) else: self.discrim = Discrim(discrim_input, self.args.noise_dim, args) self.params += list(self.discrim.parameters()) self.discrim_loss = th.nn.CrossEntropyLoss(reduction="none") self.optimiser = RMSprop(params=self.params, lr=args.lr, alpha=args.optim_alpha, eps=args.optim_eps) self.target_mac = copy.deepcopy(mac) self.log_stats_t = -self.args.learner_log_interval - 1 def train(self, batch: EpisodeBatch, t_env: int, episode_num: int): # Get the relevant quantities rewards = batch["reward"][:, :-1] actions = batch["actions"][:, :-1] terminated = batch["terminated"][:, :-1].float() mask = batch["filled"][:, :-1].float() mask[:, 1:] = mask[:, 1:] * (1 - terminated[:, :-1]) avail_actions = batch["avail_actions"] noise = batch["noise"][:, 0].unsqueeze(1).repeat(1, rewards.shape[1], 1) # Calculate estimated Q-Values mac_out = [] self.mac.init_hidden(batch.batch_size) for t in range(batch.max_seq_length): agent_outs = self.mac.forward(batch, t=t) mac_out.append(agent_outs) mac_out = th.stack(mac_out, dim=1) # Concat over time # Pick the Q-Values for the actions taken by each agent chosen_action_qvals = th.gather(mac_out[:, :-1], dim=3, index=actions).squeeze( 3) # Remove the last dim # Calculate the Q-Values necessary for the target target_mac_out = [] self.target_mac.init_hidden(batch.batch_size) for t in range(batch.max_seq_length): target_agent_outs = self.target_mac.forward(batch, t=t) target_mac_out.append(target_agent_outs) # We don't need the first timesteps Q-Value estimate for calculating targets target_mac_out = th.stack(target_mac_out[1:], dim=1) # Concat across time # Mask out unavailable actions #target_mac_out[avail_actions[:, 1:] == 0] = -9999999 # From OG deepmarl # Max over target Q-Values if self.args.double_q: # Get actions that maximise live Q (for double q-learning) #mac_out[avail_actions == 0] = -9999999 cur_max_actions = mac_out[:, 1:].max(dim=3, keepdim=True)[1] target_max_qvals = th.gather(target_mac_out, 3, cur_max_actions).squeeze(3) else: target_max_qvals = target_mac_out.max(dim=3)[0] # Mix if self.mixer is not None: chosen_action_qvals = self.mixer(chosen_action_qvals, batch["state"][:, :-1], noise) target_max_qvals = self.target_mixer(target_max_qvals, batch["state"][:, 1:], noise) # Discriminator #mac_out[avail_actions == 0] = -9999999 q_softmax_actions = th.nn.functional.softmax(mac_out[:, :-1], dim=3) if self.args.hard_qs: maxs = th.max(mac_out[:, :-1], dim=3, keepdim=True)[1] zeros = th.zeros_like(q_softmax_actions) zeros.scatter_(dim=3, index=maxs, value=1) q_softmax_actions = zeros q_softmax_agents = q_softmax_actions.reshape( q_softmax_actions.shape[0], q_softmax_actions.shape[1], -1) states = batch["state"][:, :-1] state_and_softactions = 
th.cat([q_softmax_agents, states], dim=2) if self.args.rnn_discrim: h_to_use = th.zeros(size=(batch.batch_size, self.args.rnn_agg_size)).to( states.device) hs = th.ones_like(h_to_use) for t in range(batch.max_seq_length - 1): hs = self.rnn_agg(state_and_softactions[:, t], hs) for b in range(batch.batch_size): if t == batch.max_seq_length - 2 or (mask[b, t] == 1 and mask[b, t + 1] == 0): # This is the last timestep of the sequence h_to_use[b] = hs[b] s_and_softa_reshaped = h_to_use else: s_and_softa_reshaped = state_and_softactions.reshape( -1, state_and_softactions.shape[-1]) if self.args.mi_intrinsic: s_and_softa_reshaped = s_and_softa_reshaped.detach() discrim_prediction = self.discrim(s_and_softa_reshaped) # Cross-Entropy target_repeats = 1 if not self.args.rnn_discrim: target_repeats = q_softmax_actions.shape[1] discrim_target = batch["noise"][:, 0].long().detach().max( dim=1)[1].unsqueeze(1).repeat(1, target_repeats).reshape(-1) discrim_loss = self.discrim_loss(discrim_prediction, discrim_target) if self.args.rnn_discrim: averaged_discrim_loss = discrim_loss.mean() else: masked_discrim_loss = discrim_loss * mask.reshape(-1) averaged_discrim_loss = masked_discrim_loss.sum() / mask.sum() self.logger.log_stat("discrim_loss", averaged_discrim_loss.item(), t_env) # Calculate 1-step Q-Learning targets targets = rewards + self.args.gamma * (1 - terminated) * target_max_qvals if self.args.mi_intrinsic: assert self.args.rnn_discrim is False targets = targets + self.args.mi_scaler * discrim_loss.view_as( rewards) # Td-error td_error = (chosen_action_qvals - targets.detach()) mask = mask.expand_as(td_error) # 0-out the targets that came from padded data masked_td_error = td_error * mask # Normal L2 loss, take mean over actual data loss = (masked_td_error**2).sum() / mask.sum() loss = loss + self.args.mi_loss * averaged_discrim_loss # Optimise self.optimiser.zero_grad() loss.backward() grad_norm = th.nn.utils.clip_grad_norm_(self.params, self.args.grad_norm_clip) self.optimiser.step() if (episode_num - self.last_target_update_episode ) / self.args.target_update_interval >= 1.0: self._update_targets() self.last_target_update_episode = episode_num if t_env - self.log_stats_t >= self.args.learner_log_interval: self.logger.log_stat("loss", loss.item(), t_env) self.logger.log_stat("grad_norm", grad_norm, t_env) mask_elems = mask.sum().item() self.logger.log_stat( "td_error_abs", (masked_td_error.abs().sum().item() / mask_elems), t_env) self.logger.log_stat("q_taken_mean", (chosen_action_qvals * mask).sum().item() / (mask_elems * self.args.n_agents), t_env) self.logger.log_stat("target_mean", (targets * mask).sum().item() / (mask_elems * self.args.n_agents), t_env) self.log_stats_t = t_env def _update_targets(self): self.target_mac.load_state(self.mac) if self.mixer is not None: self.target_mixer.load_state_dict(self.mixer.state_dict()) self.logger.console_logger.info("Updated target network") def cuda(self): self.mac.cuda() self.target_mac.cuda() self.discrim.cuda() if self.args.rnn_discrim: self.rnn_agg.cuda() if self.mixer is not None: self.mixer.cuda() self.target_mixer.cuda() def save_models(self, path): self.mac.save_models(path) if self.mixer is not None: th.save(self.mixer.state_dict(), "{}/mixer.th".format(path)) th.save(self.optimiser.state_dict(), "{}/opt.th".format(path)) def load_models(self, path): self.mac.load_models(path) self.target_mac.load_models(path) if self.mixer is not None: self.mixer.load_state_dict( th.load("{}/mixer.th".format(path), map_location=lambda storage, loc: 
storage)) self.optimiser.load_state_dict( th.load("{}/opt.th".format(path), map_location=lambda storage, loc: storage))
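# --- Hedged illustration (not from the original code) -------------------------------
# A standalone sketch of the discriminator objective used above: the discriminator
# receives (state, softened-action) features and must recover which noise index drove
# the episode; the per-timestep cross-entropy is then masked and averaged. TinyDiscrim
# and every dimension below are illustrative stand-ins, not the original Discrim.
import torch
import torch.nn as nn


class TinyDiscrim(nn.Module):
    def __init__(self, input_dim, noise_dim):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(input_dim, 64), nn.ReLU(), nn.Linear(64, noise_dim))

    def forward(self, x):
        return self.net(x)


batch, timesteps, input_dim, noise_dim = 8, 20, 32, 4
discrim = TinyDiscrim(input_dim, noise_dim)
ce = nn.CrossEntropyLoss(reduction="none")

inputs = torch.randn(batch * timesteps, input_dim)                    # flattened features
noise = torch.eye(noise_dim)[torch.randint(noise_dim, (batch,))]      # one-hot noise per episode
target = noise.argmax(dim=1).unsqueeze(1).repeat(1, timesteps).reshape(-1)
mask = torch.ones(batch * timesteps)                                  # 1 = real step, 0 = padding

per_step_loss = ce(discrim(inputs), target)
masked_loss = (per_step_loss * mask).sum() / mask.sum()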
class QLearner: def __init__(self, mac, args): self.args = args self.method = args.method if "aiqmix" in self.method: self.imaginary_lambda = args.imaginary_lambda self.mac = mac self.mixer = Mixer(args) # target networks self.target_mac = copy.deepcopy(mac) self.target_mixer = copy.deepcopy(self.mixer) self.disable_gradient(self.target_mac) self.disable_gradient(self.target_mixer) self.modules = [ self.mac, self.mixer, self.target_mac, self.target_mixer ] self.params = list(self.mac.parameters()) + list( self.mixer.parameters()) self.optimizer = RMSprop(params=self.params, lr=args.lr, alpha=args.optim_alpha, eps=args.optim_eps) self.n_params = sum(p.numel() for p in self.mac.parameters() if p.requires_grad) + \ sum(p.numel() for p in self.mixer.parameters() if p.requires_grad) if args.has_coach: self.coach = Coach(args) self.target_coach = copy.deepcopy(self.coach) self.disable_gradient(self.target_coach) self.modules.append(self.coach) self.modules.append(self.target_coach) self.n_params += sum(p.numel() for p in self.coach.parameters() if p.requires_grad) coach_params = list(self.coach.parameters()) if "vi1" in self.method: self.vi1 = VI1(args) self.modules.append(self.vi1) coach_params += list(self.vi1.parameters()) if "vi2" in self.method: self.vi2 = VI2(args) self.modules.append(self.vi2) coach_params += list(self.vi2.parameters()) self.coach_params = coach_params self.coach_optimizer = RMSprop(coach_params, lr=args.lr, alpha=args.optim_alpha, eps=args.optim_eps) print(f"[info] Total number of params: {self.n_params}") self.buffer = ReplayBuffer(args.buffer_size) self.t = 0 def disable_gradient(self, module): module.eval() for p in module.parameters(): p.requires_grad = False def tensorize(self, args): o, e, c, m, ms, a, r = args device = self.args.device o = torch.Tensor(o).to(device) # [batch, t, n_agents, observation_dim] e = torch.Tensor(e).to(device) # [batch, t, n_others, entity_dim] c = torch.Tensor(c).to(device) # [batch, t, n_agents, attribute_dim] m = torch.Tensor(m).to(device) # [batch, t, n_agents, n_all] ms = torch.Tensor(ms).to( device) # [batch, t, n_agents, n_all] full observation a = torch.LongTensor(a).to(device) # [batch, t, n_agents] r = torch.Tensor(r).to(device) # [batch, t,] mask = ms.sum(-1, keepdims=True).gt(0).float() o = mask * o a = mask.long().squeeze(-1) * a c = mask * (c - 0.5) return o, e, c, m, ms, a, r def update(self, logger, step): if len(self.buffer) < self.args.batch_size: return self.t += 1 o, e, c, m, ms, a, r = self.tensorize( self.buffer.sample(self.args.batch_size)) T = o.shape[1] - 1 # since we have T+1 steps 0, 1, ..., T if self.args.has_coach: # get the z_team_t0 training_team_strategy = self.mac.z_team.clone( ) # save previous team strategy z_t0, mu_t0, logvar_t0 = self.coach(o[:, 0], e[:, 0], c[:, 0], ms[:, 0]) z_t0_target, _, _ = self.target_coach(o[:, 0], e[:, 0], c[:, 0], ms[:, 0]) z_T_target, _, _ = self.target_coach(o[:, T], e[:, T], c[:, T], ms[:, T]) self.mac.set_team_strategy(z_t0) self.target_mac.set_team_strategy(z_t0_target) rnn_hidden = self.mac.init_hidden(o.shape[0], o.shape[2]) # [batch, n_agents, dh] Q = [] H_mixer = [] for t in range(T): prev_a = torch.zeros_like(a[:, 0]) if t == 0 else a[:, t - 1] qa, h, h_full, rnn_hidden = self.mac(o[:, t], e[:, t], c[:, t], m[:, t], ms[:, t], rnn_hidden, prev_a, a[:, t]) if self.args.has_coach: coach_h = self.coach.encode(o[:, t], e[:, t], c[:, t], ms[:, t]) q = self.mixer.coach_forward(coach_h, qa, ms[:, t]) else: q = self.mixer(h_full, qa, ms[:, t]) H_mixer.append(h_full) 
Q.append(q.unsqueeze(-1)) Q = torch.cat(Q, -1) # [batch, T] with torch.no_grad(): NQ = [] NQ_ = [] rnn_hidden = self.mac.init_hidden( o.shape[0], o.shape[2]) # [batch, n_agents, dh] for t in range(T + 1): if t == T and self.args.has_coach: # update strategy for last step self.target_mac.set_team_strategy(z_T_target) prev_a = torch.zeros_like(a[:, 0]) if t == 0 else a[:, t - 1] qa, h, h_full, rnn_hidden = self.target_mac( o[:, t], e[:, t], c[:, t], m[:, t], ms[:, t], rnn_hidden, prev_a) qa = qa.max(-1)[0] if self.args.has_coach: coach_h = self.target_coach.encode(o[:, t], e[:, t], c[:, t], ms[:, t]) nq = self.target_mixer.coach_forward(coach_h, qa, ms[:, t]) else: nq = self.target_mixer(h_full, qa, ms[:, t]) NQ.append(nq.unsqueeze(-1)) NQ = torch.cat(NQ, -1)[:, 1:] # [batch, T] #if self.args.has_coach: # NQ_ = torch.cat(NQ_, -1)[:,1:] # [batch, T] ###################################################################### # 1a. Bellman error ###################################################################### td_target = r[:, :-1] + self.args.gamma * NQ td_error = F.mse_loss(Q, td_target) #if self.args.has_coach: # td_error = td_error * 0.5 + \ # 0.5 * F.mse_loss(Q_, r[:,:-1] + self.args.gamma * NQ_) ###################################################################### # 1b. Imaginary Bellman error ###################################################################### if "aiqmix" in self.method: rnn_hidden = self.mac.init_hidden(o.shape[0] * 2, o.shape[2]) im_Q = [] for t in range(T): prev_a = torch.zeros_like(a[:, 0]) if t == 0 else a[:, t - 1] im_qa, im_h, im_h_full, rnn_hidden = self.mac.im_forward( o[:, t], e[:, t], c[:, t], m[:, t], ms[:, t], rnn_hidden, prev_a, a[:, t]) h_mixer = im_h_full im_qa = self.mixer.im_forward(h_mixer, H_mixer[t], im_qa, ms[:, t]) im_Q.append(im_qa.unsqueeze(-1)) im_Q = torch.cat(im_Q, -1) im_td_error = F.mse_loss(im_Q, td_target) td_error = (1-self.imaginary_lambda) * td_error + \ self.imaginary_lambda * im_td_error ###################################################################### # 2. ELBO ###################################################################### elbo = 0. 
if self.args.has_coach: if "vi1" in self.method: vi1_loss = self.vi1(o[:, 0], c[:, 0], ms[:, 0], z_t0) elbo += vi1_loss * self.args.lambda1 if "vi2" in self.method: vi2_loss = self.vi2(o, e, c, m, ms[:, 0], a, z_t0) p_ = D.normal.Normal(mu_t0, (0.5 * logvar_t0).exp()) entropy = p_.entropy().clamp_(0, 10).mean() elbo += vi2_loss * self.args.lambda2 - entropy * self.args.lambda2 / 10 if "vi3" in self.method: p_ = D.normal.Normal(mu_t0, (0.5 * logvar_t0).exp()) q_ = D.normal.Normal(torch.zeros_like(mu_t0), torch.ones_like(logvar_t0)) vi3_loss = D.kl_divergence(p_, q_).clamp_(0, 10).mean() elbo += vi3_loss * self.args.lambda3 #print(f"td {td_error.item():.4f} l2 {vi2_loss.item():.4f}") #print(f"td {td_error.item():.4f} ent {entropy.item():.4f} l2 {vi2_loss.item():.4f}") self.optimizer.zero_grad() if self.args.has_coach: self.coach_optimizer.zero_grad() (td_error + elbo).backward() grad_norm = torch.nn.utils.clip_grad_norm_(self.params, self.args.grad_norm_clip) if self.args.has_coach: coach_grad_norm = torch.nn.utils.clip_grad_norm_( self.coach_params, self.args.grad_norm_clip) self.optimizer.step() if self.args.has_coach: self.coach_optimizer.step() # set back team strategy for rollout self.mac.set_team_strategy(training_team_strategy) # update target once in a while if self.t % self.args.update_target_every == 0: self._update_targets() if "aiqmix" in self.method: logger.add_scalar("im_q_loss", im_td_error.cpu().item(), step) if "vi1" in self.method: logger.add_scalar("vi1", vi1_loss.item(), step) if "vi2" in self.method: logger.add_scalar("vi2", vi2_loss.item(), step) logger.add_scalar("q_loss", td_error.cpu().item(), step) logger.add_scalar("grad_norm", grad_norm.item(), step) def save_models(self, path): torch.save(self.mac.state_dict(), "{}/mac.th".format(path)) torch.save(self.mixer.state_dict(), "{}/mixer.th".format(path)) torch.save(self.optimizer.state_dict(), "{}/opt.th".format(path)) if self.args.has_coach: torch.save(self.coach.state_dict(), "{}/coach.th".format(path)) torch.save(self.coach_optimizer.state_dict(), "{}/coach_opt.th".format(path)) if "vi1" in self.method: torch.save(self.vi1.state_dict(), "{}/vi1.th".format(path)) if "vi2" in self.method: torch.save(self.vi2.state_dict(), "{}/vi2.th".format(path)) def load_models(self, path): self.mac.load_state_dict(torch.load("{}/mac.th".format(path))) self.mixer.load_state_dict(torch.load("{}/mixer.th".format(path))) self.optimizer.load_state_dict(torch.load("{}/opt.th".format(path))) if self.args.has_coach: self.coach.load_state_dict(torch.load("{}/coach.th".format(path))) self.coach_optimizer.load_state_dict( torch.load("{}/coach_opt.th".format(path))) if "vi1" in self.method: self.vi1.load_state_dict(torch.load("{}/vi1.th".format(path))) if "vi2" in self.method: self.vi2.load_state_dict(torch.load("{}/vi2.th".format(path))) self.target_mac = copy.deepcopy(self.mac) self.target_mixer = copy.deepcopy(self.mixer) self.disable_gradient(self.target_mac) self.disable_gradient(self.target_mixer) if self.args.has_coach: self.target_coach = copy.deepcopy(self.coach) self.disable_gradient(self.target_coach) def _update_targets(self): self.target_mac.load_state_dict(self.mac.state_dict()) self.target_mixer.load_state_dict(self.mixer.state_dict()) if self.args.has_coach: self.target_coach.load_state_dict(self.coach.state_dict()) return def cuda(self): for m in self.modules: m.cuda() def cpu(self): for m in self.modules: m.cpu()
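# --- Hedged illustration (not from the original code) -------------------------------
# The hard target-update pattern used by the learner above, shown on a throwaway
# module: the target copy is frozen (eval mode, no gradients) and is refreshed by
# copying the online parameters every fixed number of training steps.
import copy

import torch.nn as nn

online = nn.Linear(8, 8)
target = copy.deepcopy(online)
target.eval()
for p in target.parameters():
    p.requires_grad = False          # the target network is never optimised directly

# ... after `update_target_every` optimisation steps on `online` ...
target.load_state_dict(online.state_dict())   # periodic hard update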
class SLearner: def __init__(self, mac, scheme, logger, args): self.args = args self.mac = mac self.logger = logger self.n_actions_levin = args.n_actions self.params = list(mac.parameters()) self.last_target_update_episode = 0 self.mixer = None if args.mixer is not None: if args.mixer == "vdn": self.mixer = VDNMixer() elif args.mixer == "qmix": self.mixer = QMixer(args) else: raise ValueError("Mixer {} not recognised.".format(args.mixer)) self.params += list(self.mixer.parameters()) if not args.SubAVG_Mixer_flag: self.target_mixer = copy.deepcopy(self.mixer) elif args.mixer == "qmix": self.target_mixer_list = [] for i in range(self.args.SubAVG_Mixer_K): self.target_mixer_list.append(copy.deepcopy(self.mixer)) self.levin_iter_target_mixer_update = 0 self.optimiser = RMSprop(params=self.params, lr=args.lr, alpha=args.optim_alpha, eps=args.optim_eps) # a little wasteful to deepcopy (e.g. duplicates action selector), but should work for any MAC if not self.args.SubAVG_Agent_flag: self.target_mac = copy.deepcopy(mac) else: self.target_mac_list = [] for i in range(self.args.SubAVG_Agent_K): self.target_mac_list.append(copy.deepcopy(mac)) self.levin_iter_target_update = 0 self.log_stats_t = -self.args.learner_log_interval - 1 # ====== levin ===== self.number = 0 def train(self, batch: EpisodeBatch, t_env: int, episode_num: int, epsilon_levin=None): # Get the relevant quantities rewards = batch["reward"][:, :-1] actions = batch["actions"][:, :-1] terminated = batch["terminated"][:, :-1].float() mask = batch["filled"][:, :-1].float() mask[:, 1:] = mask[:, 1:] * (1 - terminated[:, :-1]) avail_actions = batch["avail_actions"] # Calculate estimated Q-Values mac_out = [] self.mac.init_hidden(batch.batch_size) for t in range(batch.max_seq_length): agent_outs = self.mac.forward(batch, t=t) mac_out.append(agent_outs) mac_out = th.stack(mac_out, dim=1) # Concat over time # Pick the Q-Values for the actions taken by each agent chosen_action_qvals = th.gather(mac_out[:, :-1], dim=3, index=actions).squeeze(3) # Calculate the Q-Values necessary for the target target_mac_out = [] if not self.args.SubAVG_Agent_flag: self.target_mac.init_hidden(batch.batch_size) else: for i in range(self.args.SubAVG_Agent_K): self.target_mac_list[i].init_hidden(batch.batch_size) for t in range(batch.max_seq_length): if not self.args.SubAVG_Agent_flag: target_agent_outs = self.target_mac.forward(batch, t=t) # exp: use the average-DQN style target_mac else: target_agent_outs = 0 self.target_agent_out_list = [] for i in range(self.args.SubAVG_Agent_K): target_agent_out = self.target_mac_list[i].forward(batch, t=t) target_agent_outs = target_agent_outs + target_agent_out if self.args.SubAVG_Agent_flag_select: self.target_agent_out_list.append(target_agent_out) target_agent_outs = target_agent_outs / self.args.SubAVG_Agent_K if self.args.SubAVG_Agent_flag_select: if self.args.SubAVG_Agent_name_select_replacement == 'mean': target_out_select_sum = 0 for i in range(self.args.SubAVG_Agent_K): if self.args.SubAVG_Agent_flag_select > 0: target_out_select = th.where( self.target_agent_out_list[i] < target_agent_outs, target_agent_outs, self.target_agent_out_list[i]) else: target_out_select = th.where( self.target_agent_out_list[i] > target_agent_outs, target_agent_outs, self.target_agent_out_list[i]) target_out_select_sum = target_out_select_sum + target_out_select target_agent_outs = target_out_select_sum / self.args.SubAVG_Agent_K elif self.args.SubAVG_Agent_name_select_replacement == 'zero': target_out_select_sum = 0 target_select_bool_sum = 0 for i 
in range(self.args.SubAVG_Agent_K): if self.args.SubAVG_Agent_flag_select > 0: target_select_bool = ( self.target_agent_out_list[i] > target_agent_outs).float() target_out_select = th.where( self.target_agent_out_list[i] > target_agent_outs, self.target_agent_out_list[i], th.full_like(target_agent_outs, 0)) else: target_select_bool = ( self.target_agent_out_list[i] < target_agent_outs).float() target_out_select = th.where( self.target_agent_out_list[i] < target_agent_outs, self.target_agent_out_list[i], th.full_like(target_agent_outs, 0)) target_select_bool_sum = target_select_bool_sum + target_select_bool target_out_select_sum = target_out_select_sum + target_out_select if self.levin_iter_target_update < 2: pass # print("using average directly") else: target_agent_outs = target_out_select_sum / target_select_bool_sum target_mac_out.append(target_agent_outs) # We don't need the first timesteps Q-Value estimate for calculating targets target_mac_out = th.stack(target_mac_out, dim=1) # Concat across time # Mask out unavailable actions target_chosen_action_qvals = th.gather(target_mac_out, 3, batch['actions']).squeeze(-1) # Mix if self.mixer is None: target_qvals = target_chosen_action_qvals else: chosen_action_qvals = self.mixer(chosen_action_qvals, batch["state"][:, :-1]) if not self.args.SubAVG_Mixer_flag: target_qvals = self.target_mixer(target_chosen_action_qvals, batch['state']) elif self.args.mixer == "qmix": target_max_qvals_sum = 0 self.target_mixer_out_list = [] for i in range(self.args.SubAVG_Mixer_K): targe_mixer_out = self.target_mixer_list[i]( target_chosen_action_qvals, batch['state']) target_max_qvals_sum = target_max_qvals_sum + targe_mixer_out if self.args.SubAVG_Mixer_flag_select: self.target_mixer_out_list.append(targe_mixer_out) target_max_qvals = target_max_qvals_sum / self.args.SubAVG_Mixer_K # levin: mixer select if self.args.SubAVG_Mixer_flag_select: if self.args.SubAVG_Mixer_name_select_replacement == 'mean': target_mixer_select_sum = 0 for i in range(self.args.SubAVG_Mixer_K): if self.args.SubAVG_Mixer_flag_select > 0: target_mixer_select = th.where( self.target_mixer_out_list[i] < target_max_qvals, target_max_qvals, self.target_mixer_out_list[i]) else: target_mixer_select = th.where( self.target_mixer_out_list[i] > target_max_qvals, target_max_qvals, self.target_mixer_out_list[i]) target_mixer_select_sum = target_mixer_select_sum + target_mixer_select target_max_qvals = target_mixer_select_sum / self.args.SubAVG_Mixer_K elif self.args.SubAVG_Mixer_name_select_replacement == 'zero': target_mixer_select_sum = 0 target_mixer_select_bool_sum = 0 for i in range(self.args.SubAVG_Mixer_K): if self.args.SubAVG_Mixer_flag_select > 0: target_mixer_select_bool = ( self.target_mixer_out_list[i] > target_max_qvals).float() target_mixer_select = th.where( self.target_mixer_out_list[i] > target_max_qvals, self.target_mixer_out_list[i], th.full_like(target_max_qvals, 0)) else: target_mixer_select_bool = ( self.target_mixer_out_list[i] < target_max_qvals).float() target_mixer_select = th.where( self.target_mixer_out_list[i] < target_max_qvals, self.target_mixer_out_list[i], th.full_like(target_max_qvals, 0)) target_mixer_select_bool_sum = target_mixer_select_bool_sum + target_mixer_select_bool target_mixer_select_sum = target_mixer_select_sum + target_mixer_select if self.levin_iter_target_mixer_update < 2: pass # print("using average-mix directly") else: target_max_qvals = target_mixer_select_sum / target_mixer_select_bool_sum target_qvals = target_max_qvals if self.args.td_lambda 
<= 1 and self.args.td_lambda > 0: targets = build_td_lambda_targets(rewards, terminated, mask, target_qvals, self.args.n_agents, self.args.gamma, self.args.td_lambda) else: if self.args.td_lambda == 0: n = 1 # 1-step TD else: n = self.args.td_lambda targets = th.zeros_like(batch['reward']) targets += batch['reward'] for i in range(1, n): targets[:, :-i] += (self.args.gamma**i) * ( 1 - terminated[:, i - 1:]) * batch['reward'][:, i:] targets[:, :-n] += (self.args.gamma**n) * ( 1 - terminated[:, n - 1:]) * target_qvals[:, n:] targets = targets[:, :-1] # Td-error td_error = (chosen_action_qvals - targets.detach()) mask = mask.expand_as(td_error) # 0-out the targets that came from padded data masked_td_error = td_error * mask # Normal L2 loss, take mean over actual data loss = (masked_td_error**2).sum() / mask.sum() * 2 # Optimise self.optimiser.zero_grad() loss.backward() grad_norm = th.nn.utils.clip_grad_norm_(self.params, self.args.grad_norm_clip) self.optimiser.step() if (episode_num - self.last_target_update_episode ) / self.args.target_update_interval >= 1.0: self._update_targets() self.last_target_update_episode = episode_num if t_env - self.log_stats_t >= self.args.learner_log_interval: self.logger.log_stat("loss", loss.item(), t_env) # self.logger.log_stat("loss_levin", loss_levin.item(), t_env) self.logger.log_stat("grad_norm", grad_norm, t_env) mask_elems = mask.sum().item() self.logger.log_stat( "td_error_abs", (masked_td_error.abs().sum().item() / mask_elems), t_env) self.logger.log_stat("q_taken_mean", (chosen_action_qvals * mask).sum().item() / (mask_elems * self.args.n_agents), t_env) self.logger.log_stat("target_mean", (targets * mask).sum().item() / (mask_elems * self.args.n_agents), t_env) self.log_stats_t = t_env def _update_targets(self): if not self.args.SubAVG_Agent_flag: self.target_mac.load_state(self.mac) else: self.number = self.levin_iter_target_update % self.args.SubAVG_Agent_K self.target_mac_list[self.number].load_state(self.mac) self.levin_iter_target_update = self.levin_iter_target_update + 1 if self.mixer is not None: if not self.args.SubAVG_Mixer_flag: self.target_mixer.load_state_dict(self.mixer.state_dict()) elif self.args.mixer == "qmix": mixer_number = self.levin_iter_target_mixer_update % self.args.SubAVG_Mixer_K self.target_mixer_list[mixer_number].load_state_dict( self.mixer.state_dict()) self.levin_iter_target_mixer_update = self.levin_iter_target_mixer_update + 1 self.logger.console_logger.info("Updated target network") def cuda(self): self.mac.cuda() if not self.args.SubAVG_Agent_flag: self.target_mac.cuda() else: for i in range(self.args.SubAVG_Agent_K): self.target_mac_list[i].cuda() if self.mixer is not None: self.mixer.cuda() if not self.args.SubAVG_Mixer_flag: self.target_mixer.cuda() elif self.args.mixer == "qmix": for i in range(self.args.SubAVG_Mixer_K): self.target_mixer_list[i].cuda() def save_models(self, path): self.mac.save_models(path) if self.mixer is not None: th.save(self.mixer.state_dict(), "{}/mixer.th".format(path)) th.save(self.optimiser.state_dict(), "{}/opt.th".format(path)) def load_models(self, path): self.mac.load_models(path) # Not quite right but I don't want to save target networks if not self.args.SubAVG_Agent_flag: self.target_mac.load_models(path) else: for i in range(self.args.SubAVG_Agent_K): self.target_mac_list[i].load_models(path) if self.mixer is not None: self.mixer.load_state_dict( th.load("{}/mixer.th".format(path), map_location=lambda storage, loc: storage)) self.optimiser.load_state_dict( 
th.load("{}/opt.th".format(path), map_location=lambda storage, loc: storage))
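# --- Hedged illustration (not from the original code) -------------------------------
# A simplified version of the SLearner's K-target scheme above, without the
# 'mean'/'zero' select-and-replace refinements: keep K frozen copies of the online
# network, average their outputs for the target, and refresh one copy per target
# update in round-robin order. The linear model and sizes are illustrative stand-ins.
import copy

import torch
import torch.nn as nn

K = 3
online = nn.Linear(4, 2)
target_list = [copy.deepcopy(online) for _ in range(K)]

x = torch.randn(5, 4)
with torch.no_grad():
    avg_target_q = torch.stack([t(x) for t in target_list], dim=0).mean(dim=0)

update_counter = 7                                    # e.g. number of target updates so far
target_list[update_counter % K].load_state_dict(online.state_dict())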
class Solver(object): def __init__(self, args): self.use_cuda = args.cuda and torch.cuda.is_available() self.device = torch.device("cuda:0" if self.use_cuda else "cpu") """ /begin of non-generic part that needs to be modified for each new model """ # model name is used for checkpointing (and here for setting self.net) self.model = args.model self.proportion_train_partition = args.proportion_train_partition self.dimensionwise_partition = args.dimensionwise_partition dataloaderparams = { 'batch_size': args.batch_size, 'shuffle': args.shuffle, 'num_workers': args.num_workers } if args.dataset.lower() == 'dsprites_circle': partition = partition_init( self.proportion_train_partition, dimensionwise_partition=self.dimensionwise_partition, shapetype='circle') self.train_loader = torch.utils.data.DataLoader( dSpriteBackgroundDataset(partition['train'], transform=transforms.Resize((32, 32)), shapetype='circle'), **dataloaderparams) elif args.dataset.lower() == 'dsprites': partition = partition_init( self.proportion_train_partition, dimensionwise_partition=self.dimensionwise_partition, shapetype='dsprites') self.train_loader = torch.utils.data.DataLoader( dSpriteBackgroundDataset(partition['train'], transform=transforms.Resize((32, 32)), shapetype='dsprite'), **dataloaderparams) if args.model.lower() == "encoderbvae_like": net = encoderBVAE_like self.modeltype = 'encoder' elif args.model.lower() == "decoderbvae_like": net = decoderBVAE_like self.modeltype = 'decoder' elif args.model.lower() == "decoderbvae_like_welu": net = decoderBVAE_like_wElu self.modeltype = 'decoder' elif args.model.lower() == "decoderbvae_like_welu_sigmoidoutput": net = decoderBVAE_like_wElu_SigmoidOutput self.modeltype = 'decoder' else: raise Exception('model "%s" unknown' % args.model) self.net = net(n_latent=args.n_latent, img_channels=args.img_channels).to(self.device) self.loss = MSELoss() self.lr = args.lr self.optim = RMSprop(self.net.parameters(), lr=self.lr) """ /end of non-generic part that needs to be modified for each new model """ self.max_iter = args.max_iter self.global_iter = 0 filepartPartitioningScheme = '' if self.dimensionwise_partition: filepartPartitioningScheme = '_dimwise' # prepare checkpointing if not os.path.isdir(args.ckpt_dir): os.mkdir(args.ckpt_dir) self.ckpt_dir = args.ckpt_dir self.load_last_checkpoint = args.load_last_checkpoint self.ckpt_name = '{}_{}_trainPartitionProportion={}{}_last'.format( self.model.lower(), args.dataset.lower(), self.proportion_train_partition, filepartPartitioningScheme) self.save_step = args.save_step if self.load_last_checkpoint is not None: self.load_checkpoint(self.ckpt_name) self.display_step = args.display_step # will store training-related information self.trainstats_gather_step = args.trainstats_gather_step self.trainstats_dir = args.trainstats_dir if not os.path.isdir(self.trainstats_dir): os.mkdir(self.trainstats_dir) self.trainstats_fname = '{}_{}_trainPartitionProportion={}{}'.format( self.model.lower(), args.dataset.lower(), self.proportion_train_partition, filepartPartitioningScheme) self.gather = DataGather( filename=os.path.join(self.trainstats_dir, self.trainstats_fname)) # store the partition of the data into train and validation (for later evaluations) self.partition_save_filename = '{}_{}_trainPartitionProportion={}{}_partition'.format( self.model.lower(), args.dataset.lower(), self.proportion_train_partition, filepartPartitioningScheme) file_path = os.path.join(self.ckpt_dir, self.partition_save_filename) with open(file_path, mode='wb+') as f: 
torch.save(partition, f) print("=> saved partition '{}'".format(file_path)) def train(self): pbar = tqdm(total=self.max_iter) pbar.update(self.global_iter) out = False running_loss_trainstats = 0.0 # running loss for the trainstats (gathered and pickeled) running_loss_terminal_display = 0.0 # running loss for the trainstats (gathered and pickeled) while not out: for [ samples, latents ] in self.train_loader: # not sure how long the train_loader spits out data (possibly infinite?) self.global_iter += 1 pbar.update(1) self.net.zero_grad() """ /begin of non-generic part (might need to be adapted for different models / data)""" # get current batch and push to device img_batch, code_batch = samples.to(self.device), latents.to( self.device) # depending on encoder/decoder image/label are input/target (or vice versa) if self.modeltype == 'encoder': input_batch = img_batch output_batch = code_batch elif self.modeltype == 'decoder': input_batch = code_batch output_batch = img_batch predicted_batch = self.net(input_batch) actLoss = self.loss(predicted_batch, output_batch) actLoss.backward() self.optim.step() running_loss_trainstats += actLoss.item() running_loss_terminal_display += actLoss.item() # update gather object with training information if self.global_iter % self.trainstats_gather_step == 0: running_loss_trainstats = running_loss_trainstats / self.trainstats_gather_step self.gather.insert( iter=self.global_iter, recon_loss=running_loss_trainstats, target=output_batch[0].detach().cpu().numpy(), reconstructed=predicted_batch[0].detach().cpu().numpy( ), ) running_loss_trainstats = 0.0 """ /end of non-generic part""" if self.global_iter % self.display_step == 0: pbar.write('iter:{}, loss:{:.3e}'.format( self.global_iter, running_loss_terminal_display / self.display_step)) running_loss_terminal_display = 0.0 if self.global_iter % self.save_step == 0: self.save_checkpoint(self.ckpt_name) pbar.write('Saved checkpoint(iter:{})'.format( self.global_iter)) self.gather.save_data_dict() if self.global_iter >= self.max_iter: out = True break pbar.write("[Training Finished]") pbar.close() def save_checkpoint(self, filename, silent=True): """ saves model and optimizer state as checkpoint modified from https://github.com/1Konny/Beta-VAE/blob/master/solver.py """ model_states = { 'net': self.net.state_dict(), } optim_states = { 'optim': self.optim.state_dict(), } states = { 'iter': self.global_iter, 'model_states': model_states, 'optim_states': optim_states } file_path = os.path.join(self.ckpt_dir, filename) with open(file_path, mode='wb+') as f: torch.save(states, f) if not silent: print("=> saved checkpoint '{}' (iter {})".format( file_path, self.global_iter)) def load_checkpoint(self, filename): """ loads model and optimizer state from checkpoint modified from https://github.com/1Konny/Beta-VAE/blob/master/solver.py """ file_path = os.path.join(self.ckpt_dir, filename) if os.path.isfile(file_path): checkpoint = torch.load(file_path) self.global_iter = checkpoint['iter'] self.net.load_state_dict(checkpoint['model_states']['net']) self.optim.load_state_dict(checkpoint['optim_states']['optim']) print("=> loaded checkpoint '{} (iter {})'".format( file_path, self.global_iter)) else: print("=> no checkpoint found at '{}'".format(file_path))
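# --- Hedged illustration (not from the original code) -------------------------------
# The checkpoint layout written by Solver.save_checkpoint above, exercised on a
# throwaway model so the save/load round trip can be tested in isolation. The file
# name is an illustrative placeholder.
import os

import torch
import torch.nn as nn
from torch.optim import RMSprop

net = nn.Linear(10, 10)
optim = RMSprop(net.parameters(), lr=1e-3)
states = {
    'iter': 123,
    'model_states': {'net': net.state_dict()},
    'optim_states': {'optim': optim.state_dict()},
}

file_path = os.path.join('.', 'example_checkpoint_last')
with open(file_path, mode='wb+') as f:
    torch.save(states, f)

checkpoint = torch.load(file_path)
net.load_state_dict(checkpoint['model_states']['net'])
optim.load_state_dict(checkpoint['optim_states']['optim'])
print("=> resumed at iter {}".format(checkpoint['iter']))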
print("-- model using stereonet --") if args.cuda: model = nn.DataParallel(model) model.cuda() # optimizer = optim.Adam(model.parameters(), lr=0.001, betas=(0.9, 0.999)) optimizer = RMSprop(model.parameters(), lr=1e-3, weight_decay=0.0001) epoch_start = 0 total_train_loss_save = 0 checkpoint_path = root_path + "/checkpoints/checkpoint_sceneflow.tar" if args.loadmodel is not None and os.path.exists(checkpoint_path): state_dict = torch.load(checkpoint_path) model.load_state_dict(state_dict['state_dict']) optimizer.load_state_dict(state_dict['optimizer_state_dict']) epoch_start = state_dict['epoch'] total_train_loss_save = state_dict['total_train_loss'] print("-- checkpoint loaded --") scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9, last_epoch=epoch_start) else: scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9) print("-- no checkpoint --") print('Number of model parameters: {}'.format(sum([p.data.nelement() for p in model.parameters()]))) def train(imgL,imgR, disp_L): model.train() imgL = Variable(torch.FloatTensor(imgL)) imgR = Variable(torch.FloatTensor(imgR))
class COMALearner: def __init__(self, mac, scheme, logger, args): self.args = args self.n_agents = args.n_agents self.n_actions = args.n_actions self.mac = mac self.logger = logger self.Mode = str(self.args.running_mode) self.last_target_update_step = 0 self.critic_training_steps = 0 self.log_stats_t = -self.args.learner_log_interval - 1 self.critic = COMACritic(scheme, args) self.target_critic = copy.deepcopy(self.critic) self.agent_params = list(mac.parameters()) #print("self.agent_params=",self.agent_params) self.critic_params = list(self.critic.parameters()) self.params = self.agent_params + self.critic_params self.agent_optimiser = RMSprop(params=self.agent_params, lr=args.lr, alpha=args.optim_alpha, eps=args.optim_eps) self.critic_optimiser = RMSprop(params=self.critic_params, lr=args.critic_lr, alpha=args.optim_alpha, eps=args.optim_eps) def train(self, batch: EpisodeBatch, t_env: int, episode_num: int): # Get the relevant quantities bs = batch.batch_size #print("episode batch=",EpisodeBatch) #print("batch=",batch,"--------------------------------------------------------------------------") #print("batch[intrinsic_reward]=",batch["intrinsic_reward"],"--------------------------------------------------------------------------") #print("batch[reward]=",batch["reward"],"--------------------------------------------------------------------------") #print("shape of batch[reward]=",batch["actions"].shape,"--------------------------------------------------------------------------") max_t = batch.max_seq_length rewards = batch["reward"][:, :-1] #print("rewards =",rewards.shape) #print("len rewards =",len(rewards)) actions = batch["actions"][:, :] #print("actions =",actions.shape) terminated = batch["terminated"][:, :-1].float() mask = batch["filled"][:, :-1].float() mask[:, 1:] = mask[:, 1:] * (1 - terminated[:, :-1]) #print("mask =",mask.shape) #print("len mask =",len(mask)) avail_actions = batch["avail_actions"][:, :-1] critic_mask = mask.clone() mask = mask.repeat(1, 1, self.n_agents).view(-1) #print("mask2 =",mask.shape) q_vals, critic_train_stats = self._train_critic( batch, rewards, terminated, actions, avail_actions, critic_mask, bs, max_t) #print("q_vals =",q_vals.shape) actions = actions[:, :-1] #print("actions2 =",actions.shape) mac_out = [] self.mac.init_hidden(batch.batch_size) for t in range(batch.max_seq_length - 1): agent_outs = self.mac.forward(batch, t=t) #print("t=",t,"agent_outs=",agent_outs) mac_out.append(agent_outs) mac_out = th.stack(mac_out, dim=1) # Concat over time #print("mac_out=",mac_out.shape) #print("mac_out shape =",mac_out.size()) # Mask out unavailable actions, renormalise (as in action selection) mac_out[avail_actions == 0] = 0 mac_out = mac_out / mac_out.sum(dim=-1, keepdim=True) mac_out[avail_actions == 0] = 0 #print("mac_out2=",mac_out.shape) #print("mac_out shape2 =",mac_out.size()) # Calculated baseline q_vals = q_vals.reshape(-1, self.n_actions) pi = mac_out.view(-1, self.n_actions) baseline = (pi * q_vals).sum(-1).detach() #print("baseline=",baseline.shape) # Calculate policy grad with mask q_taken = th.gather(q_vals, dim=1, index=actions.reshape(-1, 1)).squeeze(1) pi_taken = th.gather(pi, dim=1, index=actions.reshape(-1, 1)).squeeze(1) pi_taken[mask == 0] = 1.0 log_pi_taken = th.log(pi_taken) advantages = th.FloatTensor([0.0]) #torch.clamp(a, min=-0.5, max=0.5) advantages = (q_taken - baseline).detach() #print("advantages",advantages) ##################################################### individual Intrinsic Reward advantages = advantages.reshape(-1) 
if self.Mode == "2": int_adv = batch["intrinsic_reward"][:, :-1, :].reshape(-1) #print("int_adv",int_adv) clip_ratio = 2 for t in range(len(advantages)): #print("adv shape =",advantages[t]) #print("int_adv shape =",int_adv[t]) int_adv_clipped = th.clamp(int_adv[t], min=clip_ratio * -advantages[t], max=clip_ratio * advantages[t]) advantages[t] = advantages[t] + int_adv_clipped #print("advantages after",advantages) ##################################################### Combined Intrinsic Reward #print("batchzzzz = ",batch["intrinsic_reward"][:, :-1, 3]) elif self.Mode == "5": #print("batch all =", th.cat((batch["intrinsic_reward"][:, :-1, :],batch["intrinsic_reward"][:, :-1, :],batch["intrinsic_reward"][:, :-1, :]),0).reshape(-1).shape) #print("batch soze =", batch["intrinsic_reward"][:, :-1, :].shape) #print("advantages =", advantages.shape) #temp = [] int_adv = batch["intrinsic_reward"][:, :-1, :] for p in range(self.n_agents - 1): int_adv = th.cat( (int_adv, batch["intrinsic_reward"][:, :-1, :]), 0) int_adv = int_adv.view(-1) #int_adv = th.cat((batch["intrinsic_reward"][:, :-1, :],batch["intrinsic_reward"][:, :-1, :],batch["intrinsic_reward"][:, :-1, :]),1).reshape(-1) clip_ratio = 2 for t in range(len(advantages)): #print("adv shape =",len(advantages)) #print("int_adv shape =",len(int_adv)) int_adv_clipped = th.clamp(int_adv[t], min=clip_ratio * -advantages[t], max=clip_ratio * advantages[t]) advantages[t] = advantages[t] + int_adv_clipped else: pass #print("advantages after",advantages) ################################################################################### #print("int_adv",int_adv.shape) #print("batch[intrinsic_reward]",batch["intrinsic_reward"].shape) #print("batch[reward]",batch["reward"].shape) print("log_pi_taken", log_pi_taken.shape) print("advantages", advantages.shape) coma_loss = -((advantages * log_pi_taken) * mask).sum() / mask.sum() #print("self.agent_optimiser=",self.agent_optimiser) # Optimise agents #print(self.critic.parameters()) #print(self.agent_optimiser.parameters()) self.agent_optimiser.zero_grad() coma_loss.backward() grad_norm = th.nn.utils.clip_grad_norm_(self.agent_params, self.args.grad_norm_clip) self.agent_optimiser.step() if (self.critic_training_steps - self.last_target_update_step ) / self.args.target_update_interval >= 1.0: self._update_targets() self.last_target_update_step = self.critic_training_steps if t_env - self.log_stats_t >= self.args.learner_log_interval: ts_logged = len(critic_train_stats["critic_loss"]) for key in [ "critic_loss", "critic_grad_norm", "td_error_abs", "q_taken_mean", "target_mean" ]: self.logger.log_stat(key, sum(critic_train_stats[key]) / ts_logged, t_env) self.logger.log_stat("advantage_mean", (advantages * mask).sum().item() / mask.sum().item(), t_env) self.logger.log_stat("coma_loss", coma_loss.item(), t_env) self.logger.log_stat("agent_grad_norm", grad_norm, t_env) self.logger.log_stat("pi_max", (pi.max(dim=1)[0] * mask).sum().item() / mask.sum().item(), t_env) self.log_stats_t = t_env def _train_critic(self, batch, rewards, terminated, actions, avail_actions, mask, bs, max_t): # Optimise critic #print("batch obs =",batch["obs"][0][0]) target_q_vals = self.target_critic(batch)[:, :] #print("target_q_vals=",target_q_vals) #print("shape target_q_vals=",target_q_vals.shape) #print("batch obs =",batch["obs"]) #print("size batch obs =",batch["obs"].size()) #print("rewards", rewards) #print("size of rewards", rewards.shape) targets_taken = th.gather(target_q_vals, dim=3, index=actions).squeeze(3) # Calculate td-lambda 
targets targets = build_td_lambda_targets(rewards, terminated, mask, targets_taken, self.n_agents, self.args.gamma, self.args.td_lambda) #print("targets=",targets) q_vals = th.zeros_like(target_q_vals)[:, :-1] running_log = { "critic_loss": [], "critic_grad_norm": [], "td_error_abs": [], "target_mean": [], "q_taken_mean": [], } for t in reversed(range(rewards.size(1))): #print("mask_t before=",mask[:, t]) mask_t = mask[:, t].expand(-1, self.n_agents) #print("mask_t after=",mask_t) if mask_t.sum() == 0: continue q_t = self.critic(batch, t) # may be implement in here #print("batch check what inside =",batch) #print("q_t=",q_t) q_vals[:, t] = q_t.view(bs, self.n_agents, self.n_actions) #print("q_vals=",q_vals) #print("q_vals shpae=",q_vals.shape) q_taken = th.gather(q_t, dim=3, index=actions[:, t:t + 1]).squeeze(3).squeeze(1) #print("q_taken=",q_taken) targets_t = targets[:, t] #print("targets_t=",targets_t) td_error = (q_taken - targets_t.detach()) # 0-out the targets that came from padded data masked_td_error = td_error * mask_t # Normal L2 loss, take mean over actual data loss = (masked_td_error**2).sum() / mask_t.sum() self.critic_optimiser.zero_grad() loss.backward() grad_norm = th.nn.utils.clip_grad_norm_(self.critic_params, self.args.grad_norm_clip) self.critic_optimiser.step() self.critic_training_steps += 1 running_log["critic_loss"].append(loss.item()) running_log["critic_grad_norm"].append(grad_norm) mask_elems = mask_t.sum().item() running_log["td_error_abs"].append( (masked_td_error.abs().sum().item() / mask_elems)) running_log["q_taken_mean"].append( (q_taken * mask_t).sum().item() / mask_elems) running_log["target_mean"].append( (targets_t * mask_t).sum().item() / mask_elems) return q_vals, running_log def _update_targets(self): self.target_critic.load_state_dict(self.critic.state_dict()) self.logger.console_logger.info("Updated target network") def cuda(self): self.mac.cuda() self.critic.cuda() self.target_critic.cuda() def save_models(self, path): self.mac.save_models(path) th.save(self.critic.state_dict(), "{}/critic.th".format(path)) th.save(self.agent_optimiser.state_dict(), "{}/agent_opt.th".format(path)) th.save(self.critic_optimiser.state_dict(), "{}/critic_opt.th".format(path)) def load_models(self, path): self.mac.load_models(path) self.critic.load_state_dict( th.load("{}/critic.th".format(path), map_location=lambda storage, loc: storage)) # Not quite right but I don't want to save target networks self.target_critic.load_state_dict(self.critic.state_dict()) self.agent_optimiser.load_state_dict( th.load("{}/agent_opt.th".format(path), map_location=lambda storage, loc: storage)) self.critic_optimiser.load_state_dict( th.load("{}/critic_opt.th".format(path), map_location=lambda storage, loc: storage))
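# --- Hedged illustration (not from the original code) -------------------------------
# The counterfactual baseline and advantage computed in COMALearner.train above,
# shown on random tensors. The flattened [batch * time * agents, n_actions] shape
# mirrors the reshapes in the learner; all numbers are illustrative.
import torch

n_actions = 5
q_vals = torch.randn(64, n_actions)                       # per (batch, t, agent) critic Q-values
pi = torch.softmax(torch.randn(64, n_actions), dim=-1)    # current policy probabilities
actions = torch.randint(n_actions, (64, 1))

baseline = (pi * q_vals).sum(-1).detach()                 # expected Q under the current policy
q_taken = torch.gather(q_vals, dim=1, index=actions).squeeze(1)
advantages = (q_taken - baseline).detach()                # counterfactual advantage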
class QLearner: def __init__(self, mac, scheme, logger, args): self.args = args self.mac = mac self.logger = logger self.params = list(mac.parameters()) self.last_target_update_episode = 0 self.mixer = None if args.mixer is not None: if args.mixer == "vdn": self.mixer = VDNMixer() elif args.mixer == "qmix": self.mixer = QMixer(args) else: raise ValueError("Mixer {} not recognised.".format(args.mixer)) self.params += list(self.mixer.parameters()) self.target_mixer = copy.deepcopy(self.mixer) self.optimiser = RMSprop(params=self.params, lr=args.lr, alpha=args.optim_alpha, eps=args.optim_eps) # a little wasteful to deepcopy (e.g. duplicates action selector), but should work for any MAC self.target_mac = copy.deepcopy(mac) self.log_stats_t = -self.args.learner_log_interval - 1 def train(self, batch: EpisodeBatch, t_env: int, episode_num: int): # Get the relevant quantities rewards = batch["reward"][:, :-1] actions = batch["actions"][:, :-1] terminated = batch["terminated"][:, :-1].float() mask = batch["filled"][:, :-1].float() mask[:, 1:] = mask[:, 1:] * (1 - terminated[:, :-1]) avail_actions = batch["avail_actions"] # Calculate estimated Q-Values mac_out = [] self.mac.init_hidden(batch.batch_size) for t in range(batch.max_seq_length): agent_outs = self.mac.forward(batch, t=t) mac_out.append(agent_outs) mac_out = th.stack(mac_out, dim=1) # Concat over time # Pick the Q-Values for the actions taken by each agent chosen_action_qvals = th.gather(mac_out[:, :-1], dim=3, index=actions).squeeze(3) # Remove the last dim # Calculate the Q-Values necessary for the target target_mac_out = [] self.target_mac.init_hidden(batch.batch_size) for t in range(batch.max_seq_length): target_agent_outs = self.target_mac.forward(batch, t=t) target_mac_out.append(target_agent_outs) # We don't need the first timesteps Q-Value estimate for calculating targets target_mac_out = th.stack(target_mac_out[1:], dim=1) # Concat across time # Mask out unavailable actions target_mac_out[avail_actions[:, 1:] == 0] = -9999999 # Max over target Q-Values if self.args.double_q: # Get actions that maximise live Q (for double q-learning) mac_out_detach = mac_out.clone().detach() mac_out_detach[avail_actions == 0] = -9999999 cur_max_actions = mac_out_detach[:, 1:].max(dim=3, keepdim=True)[1] target_max_qvals = th.gather(target_mac_out, 3, cur_max_actions).squeeze(3) else: target_max_qvals = target_mac_out.max(dim=3)[0] # Mix if self.mixer is not None: chosen_action_qvals = self.mixer(chosen_action_qvals, batch["state"][:, :-1]) target_max_qvals = self.target_mixer(target_max_qvals, batch["state"][:, 1:]) # Calculate 1-step Q-Learning targets targets = rewards + self.args.gamma * (1 - terminated) * target_max_qvals # Td-error td_error = (chosen_action_qvals - targets.detach()) mask = mask.expand_as(td_error) # 0-out the targets that came from padded data masked_td_error = td_error * mask # Normal L2 loss, take mean over actual data loss = (masked_td_error ** 2).sum() / mask.sum() # Optimise self.optimiser.zero_grad() loss.backward() grad_norm = th.nn.utils.clip_grad_norm_(self.params, self.args.grad_norm_clip) self.optimiser.step() if (episode_num - self.last_target_update_episode) / self.args.target_update_interval >= 1.0: self._update_targets() self.last_target_update_episode = episode_num if t_env - self.log_stats_t >= self.args.learner_log_interval: self.logger.log_stat("loss", loss.item(), t_env) self.logger.log_stat("grad_norm", grad_norm, t_env) mask_elems = mask.sum().item() self.logger.log_stat("td_error_abs", 
(masked_td_error.abs().sum().item()/mask_elems), t_env) self.logger.log_stat("q_taken_mean", (chosen_action_qvals * mask).sum().item()/(mask_elems * self.args.n_agents), t_env) self.logger.log_stat("target_mean", (targets * mask).sum().item()/(mask_elems * self.args.n_agents), t_env) self.log_stats_t = t_env def _update_targets(self): self.target_mac.load_state(self.mac) if self.mixer is not None: self.target_mixer.load_state_dict(self.mixer.state_dict()) self.logger.console_logger.info("Updated target network") def cuda(self): self.mac.cuda() self.target_mac.cuda() if self.mixer is not None: self.mixer.cuda() self.target_mixer.cuda() def save_models(self, path): self.mac.save_models(path) if self.mixer is not None: th.save(self.mixer.state_dict(), "{}/mixer.th".format(path)) th.save(self.optimiser.state_dict(), "{}/opt.th".format(path)) def load_models(self, path): self.mac.load_models(path) # Not quite right but I don't want to save target networks self.target_mac.load_models(path) if self.mixer is not None: self.mixer.load_state_dict(th.load("{}/mixer.th".format(path), map_location=lambda storage, loc: storage)) self.optimiser.load_state_dict(th.load("{}/opt.th".format(path), map_location=lambda storage, loc: storage))
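# --- Hedged illustration (not from the original code) -------------------------------
# The double-Q target selection used in the learner above, on random tensors:
# greedy actions are picked with the online network, then evaluated with the target
# network. Shapes follow [batch, time, n_agents, n_actions]; values are illustrative.
import torch

mac_out = torch.randn(4, 11, 3, 6)          # online Q-values for t = 0..T
target_mac_out = torch.randn(4, 10, 3, 6)   # target Q-values for t = 1..T
avail_actions = torch.ones(4, 11, 3, 6)

mac_out_detach = mac_out.clone().detach()
mac_out_detach[avail_actions == 0] = -9999999                         # never pick unavailable actions
cur_max_actions = mac_out_detach[:, 1:].max(dim=3, keepdim=True)[1]   # argmax from online net
target_max_qvals = torch.gather(target_mac_out, 3, cur_max_actions).squeeze(3)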
class LatentQLearner(QLearner): def __init__(self, mac, scheme, logger, args): super(LatentQLearner, self).__init__(mac, scheme, logger, args) self.args = args self.mac = mac self.logger = logger self.params = list(mac.parameters()) self.last_target_update_episode = 0 self.mixer = None if args.mixer is not None: if args.mixer == "vdn": self.mixer = VDNMixer() elif args.mixer == "qmix": self.mixer = QMixer(args) else: raise ValueError("Mixer {} not recognised.".format(args.mixer)) self.params += list(self.mixer.parameters()) self.target_mixer = copy.deepcopy(self.mixer) self.optimiser = RMSprop(params=self.params, lr=args.lr, alpha=args.optim_alpha, eps=args.optim_eps) # a little wasteful to deepcopy (e.g. duplicates action selector), but should work for any MAC self.target_mac = copy.deepcopy(mac) self.log_stats_t = -self.args.learner_log_interval - 1 self.role_save = 0 self.role_save_interval = 10 def train(self, batch: EpisodeBatch, t_env: int, episode_num: int): # Get the relevant quantities rewards = batch["reward"][:, :-1] actions = batch["actions"][:, :-1] terminated = batch["terminated"][:, :-1].float() mask = batch["filled"][:, :-1].float() mask[:, 1:] = mask[:, 1:] * (1 - terminated[:, :-1]) avail_actions = batch["avail_actions"] # Calculate estimated Q-Values mac_out = [] self.mac.init_hidden(batch.batch_size) indicator, latent, latent_vae = self.mac.init_latent(batch.batch_size) reg_loss = 0 dis_loss = 0 ce_loss = 0 for t in range(batch.max_seq_length): agent_outs, loss_, dis_loss_, ce_loss_ = self.mac.forward( batch, t=t, t_glob=t_env, train_mode=True) # (bs,n,n_actions),(bs,n,latent_dim) reg_loss += loss_ dis_loss += dis_loss_ ce_loss += ce_loss_ # loss_cs=self.args.gamma*loss_cs + _loss mac_out.append(agent_outs) # [t,(bs,n,n_actions)] # mac_out_latent.append((agent_outs_latent)) #[t,(bs,n,latent_dim)] reg_loss /= batch.max_seq_length dis_loss /= batch.max_seq_length ce_loss /= batch.max_seq_length mac_out = th.stack(mac_out, dim=1) # Concat over time # (bs,t,n,n_actions), Q values of n_actions # mac_out_latent=th.stack(mac_out_latent,dim=1) # (bs,t,n,latent_dim) # mac_out_latent=mac_out_latent.reshape(-1,self.args.latent_dim) # Pick the Q-Values for the actions taken by each agent chosen_action_qvals = th.gather(mac_out[:, :-1], dim=3, index=actions).squeeze( 3) # Remove the last dim # (bs,t,n) Q value of an action # Calculate the Q-Values necessary for the target target_mac_out = [] self.target_mac.init_hidden(batch.batch_size) # (bs,n,hidden_size) self.target_mac.init_latent(batch.batch_size) # (bs,n,latent_size) for t in range(batch.max_seq_length): target_agent_outs, loss_cs_target, _, _ = self.target_mac.forward( batch, t=t) # (bs,n,n_actions), (bs,n,latent_dim) target_mac_out.append(target_agent_outs) # [t,(bs,n,n_actions)] # We don't need the first timesteps Q-Value estimate for calculating targets target_mac_out = th.stack( target_mac_out[1:], dim=1) # Concat across time, dim=1 is time index # (bs,t,n,n_actions) # Mask out unavailable actions target_mac_out[avail_actions[:, 1:] == 0] = -9999999 # Q values # Max over target Q-Values if self.args.double_q: # True for QMix # Get actions that maximise live Q (for double q-learning) mac_out_detach = mac_out.clone().detach( ) # return a new Tensor, detached from the current graph mac_out_detach[avail_actions == 0] = -9999999 # (bs,t,n,n_actions), discard t=0 cur_max_actions = mac_out_detach[:, 1:].max( dim=3, keepdim=True)[1] # indices instead of values # (bs,t,n,1) target_max_qvals = th.gather(target_mac_out, 3, 
cur_max_actions).squeeze(3) # (bs,t,n,n_actions) ==> (bs,t,n,1) ==> (bs,t,n) max target-Q else: target_max_qvals = target_mac_out.max(dim=3)[0] # Mix if self.mixer is not None: chosen_action_qvals = self.mixer(chosen_action_qvals, batch["state"][:, :-1]) target_max_qvals = self.target_mixer(target_max_qvals, batch["state"][:, 1:]) # (bs,t,1) # Calculate 1-step Q-Learning targets targets = rewards + self.args.gamma * (1 - terminated) * target_max_qvals # Td-error td_error = (chosen_action_qvals - targets.detach() ) # no gradient through target net # (bs,t,1) mask = mask.expand_as(td_error) # 0-out the targets that came from padded data masked_td_error = td_error * mask # Normal L2 loss, take mean over actual data loss = (masked_td_error**2).sum() / mask.sum() # entropy loss # mac_out_latent_norm=th.sqrt(th.sum(mac_out_latent*mac_out_latent,dim=1)) # mac_out_latent=mac_out_latent/mac_out_latent_norm[:,None] # loss+=(th.norm(mac_out_latent)/mac_out_latent.size(0))*self.args.entropy_loss_weight loss += reg_loss # Optimise self.optimiser.zero_grad() loss.backward() grad_norm = th.nn.utils.clip_grad_norm_( self.params, self.args.grad_norm_clip) # max_norm self.optimiser.step() if (episode_num - self.last_target_update_episode ) / self.args.target_update_interval >= 1.0: self._update_targets() self.last_target_update_episode = episode_num if t_env - self.log_stats_t >= self.args.learner_log_interval: # if self.role_save % self.role_save_interval == 0: # self.role_save = 0 # if self.args.latent_dim in [2, 3]: # fig = plt.figure() # ax = fig.add_subplot(111, projection='3d') # print(self.mac.agent.latent[:, :self.args.latent_dim], # self.mac.agent.latent[:, -self.args.latent_dim:]) # self.role_save += 1 self.logger.log_stat("loss", loss.item(), t_env) self.logger.log_stat("loss_reg", reg_loss.item(), t_env) self.logger.log_stat("loss_dis", dis_loss.item(), t_env) self.logger.log_stat("loss_ce", ce_loss.item(), t_env) #indicator=[var_mean,mi.max(),mi.min(),mi.mean(),mi.std(),di.max(),di.min(),di.mean(),di.std()] self.logger.log_stat("var_mean", indicator[0].item(), t_env) self.logger.log_stat("mi_max", indicator[1].item(), t_env) self.logger.log_stat("mi_min", indicator[2].item(), t_env) self.logger.log_stat("mi_mean", indicator[3].item(), t_env) self.logger.log_stat("mi_std", indicator[4].item(), t_env) self.logger.log_stat("di_max", indicator[5].item(), t_env) self.logger.log_stat("di_min", indicator[6].item(), t_env) self.logger.log_stat("di_mean", indicator[7].item(), t_env) self.logger.log_stat("di_std", indicator[8].item(), t_env) self.logger.log_stat("grad_norm", grad_norm, t_env) mask_elems = mask.sum().item() self.logger.log_stat( "td_error_abs", (masked_td_error.abs().sum().item() / mask_elems), t_env) self.logger.log_stat("q_taken_mean", (chosen_action_qvals * mask).sum().item() / (mask_elems * self.args.n_agents), t_env) self.logger.log_stat("target_mean", (targets * mask).sum().item() / (mask_elems * self.args.n_agents), t_env) if self.args.use_tensorboard: # log_vec(self,mat,metadata,label_img,global_step,tag) self.logger.log_vec(latent, list(range(self.args.n_agents)), t_env, "latent") self.logger.log_vec(latent_vae, list(range(self.args.n_agents)), t_env, "latent-VAE") self.log_stats_t = t_env def _update_targets(self): self.target_mac.load_state(self.mac) if self.mixer is not None: self.target_mixer.load_state_dict(self.mixer.state_dict()) self.logger.console_logger.info("Updated target network") def cuda(self): self.mac.cuda() self.target_mac.cuda() if self.mixer is not None: 
self.mixer.cuda() self.target_mixer.cuda() def save_models(self, path): self.mac.save_models(path) if self.mixer is not None: th.save(self.mixer.state_dict(), "{}/mixer.th".format(path)) th.save(self.optimiser.state_dict(), "{}/opt.th".format(path)) def load_models(self, path): self.mac.load_models(path) # Not quite right but I don't want to save target networks self.target_mac.load_models(path) if self.mixer is not None: self.mixer.load_state_dict( th.load("{}/mixer.th".format(path), map_location=lambda storage, loc: storage)) self.optimiser.load_state_dict( th.load("{}/opt.th".format(path), map_location=lambda storage, loc: storage))
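# --- Hedged illustration (not from the original code) -------------------------------
# The optimisation step shared by every learner in this file: zero grads, backprop the
# (possibly composite) loss, clip the global gradient norm, then step. The model,
# learning rate, and clip value below are illustrative stand-ins for args.lr,
# args.optim_alpha, args.optim_eps, and args.grad_norm_clip.
import torch
import torch.nn as nn
from torch.optim import RMSprop

net = nn.Linear(4, 1)
params = list(net.parameters())
optimiser = RMSprop(params=params, lr=5e-4, alpha=0.99, eps=1e-5)

loss = (net(torch.randn(8, 4)) ** 2).mean()
optimiser.zero_grad()
loss.backward()
grad_norm = torch.nn.utils.clip_grad_norm_(params, 10.0)
optimiser.step()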