def perturb_actor_parameters(self, param_noise):
    """Apply parameter noise to the actor model, for exploration."""
    hard_update(self.actor_perturbed, self.actor)
    params = self.actor_perturbed.state_dict()
    for name in params:
        if 'ln' in name:
            # Skip layer-norm parameters; noise is only added to the remaining weights.
            # (The original used `pass` here, which did not actually skip them.)
            continue
        param = params[name]
        param += torch.randn(param.shape) * param_noise.current_stddev
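# The target-network helpers used throughout this class (`hard_update`,
# `soft_update`) are assumed to be the standard DDPG-style utilities sketched
# below, matching the call signatures used here: hard_update(target, source)
# and soft_update(target, source, tau). This is a minimal sketch; the repo's
# own definitions may live in a separate utils module.
def hard_update(target, source):
    """Copy the source network's parameters into the target verbatim."""
    for target_param, param in zip(target.parameters(), source.parameters()):
        target_param.data.copy_(param.data)


def soft_update(target, source, tau):
    """Polyak averaging: target <- tau * source + (1 - tau) * target."""
    for target_param, param in zip(target.parameters(), source.parameters()):
        target_param.data.copy_(tau * param.data + (1.0 - tau) * target_param.data)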
def update_critic_parameters(self, batch, agent_id, shuffle=None, eval=False):
    state_batch = Variable(torch.cat(batch.state)).to(self.device)
    action_batch = Variable(torch.cat(batch.action)).to(self.device)
    reward_batch = Variable(torch.cat(batch.reward)).to(self.device)
    mask_batch = Variable(torch.cat(batch.mask)).to(self.device)
    next_state_batch = torch.cat(batch.next_state).to(self.device)

    if shuffle == 'shuffle':
        # Randomly permute the agent dimension so the critic is trained on
        # permutation-shuffled joint observations and actions.
        rand_idx = np.random.permutation(self.n_agent)
        new_state_batch = state_batch.view(-1, self.n_agent, self.obs_dim)
        state_batch = new_state_batch[:, rand_idx, :].view(
            -1, self.obs_dim * self.n_agent)
        new_next_state_batch = next_state_batch.view(
            -1, self.n_agent, self.obs_dim)
        next_state_batch = new_next_state_batch[:, rand_idx, :].view(
            -1, self.obs_dim * self.n_agent)
        new_action_batch = action_batch.view(-1, self.n_agent, self.n_action)
        action_batch = new_action_batch[:, rand_idx, :].view(
            -1, self.n_action * self.n_agent)

    # Target actions from the target policy, then the TD target:
    # y = r + gamma * mask * Q_target(s', a').
    next_action_batch = self.select_action(
        next_state_batch.view(-1, self.obs_dim), action_noise=self.train_noise)
    next_action_batch = next_action_batch.view(
        -1, self.n_action * self.n_agent)
    next_state_action_values = self.critic_target(
        next_state_batch, next_action_batch)

    reward_batch = reward_batch[:, agent_id].unsqueeze(1)
    mask_batch = mask_batch[:, agent_id].unsqueeze(1)
    expected_state_action_batch = reward_batch + (
        self.gamma * mask_batch * next_state_action_values)

    self.critic_optim.zero_grad()
    state_action_batch = self.critic(state_batch, action_batch)
    perturb_out = 0
    value_loss = ((state_action_batch - expected_state_action_batch)**2).mean()
    if eval:
        return value_loss.item(), perturb_out
    value_loss.backward()
    unclipped_norm = clip_grad_norm_(self.critic_params, 0.5)
    self.critic_optim.step()

    if self.target_update_mode == 'soft':
        soft_update(self.critic_target, self.critic, self.tau)
    elif self.target_update_mode == 'hard':
        hard_update(self.critic_target, self.critic)
    return value_loss.item(), perturb_out, unclipped_norm
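# A minimal sketch of the replay structures this update assumes: a `Transition`
# namedtuple with the field names accessed above (`state`, `action`, `mask`,
# `next_state`, `reward`) and a simple uniform-sampling ring buffer. The repo's
# actual replay memory may differ in details; this is shown for context only.
import random
from collections import namedtuple

Transition = namedtuple(
    'Transition', ('state', 'action', 'mask', 'next_state', 'reward'))


class ReplayMemory:
    def __init__(self, capacity):
        self.capacity = capacity
        self.memory = []
        self.position = 0

    def push(self, *args):
        # Overwrite the oldest entries once the buffer is full (ring buffer).
        if len(self.memory) < self.capacity:
            self.memory.append(None)
        self.memory[self.position] = Transition(*args)
        self.position = (self.position + 1) % self.capacity

    def sample(self, batch_size):
        return random.sample(self.memory, batch_size)

    def __len__(self):
        return len(self.memory)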
def __init__(self, gamma, tau, hidden_size, obs_dim, n_action, n_agent,
             obs_dims, agent_id, actor_lr, critic_lr, fixed_lr, critic_type,
             train_noise, num_episodes, num_steps, critic_dec_cen,
             target_update_mode='soft', device='cpu', groups=None):
    self.group_dim_id = [obs_dims[g] for g in groups]
    self.group_cum_id = np.cumsum([0] + groups)
    self.n_group = len(groups)
    self.device = device
    self.obs_dim = obs_dim
    self.n_agent = n_agent
    self.n_action = n_action

    # One actor (and target actor) per group; a single centralized critic.
    self.actors = [
        Actor(hidden_size, self.group_dim_id[i], n_action).to(self.device)
        for i in range(len(groups))
    ]
    self.actor_targets = [
        Actor(hidden_size, self.group_dim_id[i], n_action).to(self.device)
        for i in range(len(groups))
    ]
    self.actor_optims = [
        Adam(self.actors[i].parameters(), lr=actor_lr, weight_decay=0)
        for i in range(len(groups))
    ]

    self.critic = Critic(hidden_size, np.sum(obs_dims), n_action * n_agent,
                         n_agent, critic_type, agent_id, groups).to(self.device)
    self.critic_target = Critic(hidden_size, np.sum(obs_dims),
                                n_action * n_agent, n_agent, critic_type,
                                agent_id, groups).to(self.device)
    critic_n_params = sum(p.numel() for p in self.critic.parameters())
    print('# of critic params', critic_n_params)
    self.critic_optim = Adam(self.critic.parameters(), lr=critic_lr)

    self.fixed_lr = fixed_lr
    self.init_act_lr = actor_lr
    self.init_critic_lr = critic_lr
    self.num_episodes = num_episodes
    self.start_episode = 0
    self.num_steps = num_steps
    self.gamma = gamma
    self.tau = tau
    self.train_noise = train_noise
    self.obs_dims_cumsum = np.cumsum(obs_dims)
    self.critic_dec_cen = critic_dec_cen
    self.agent_id = agent_id
    self.debug = False
    self.target_update_mode = target_update_mode
    self.actors_params = [
        self.actors[i].parameters() for i in range(self.n_group)
    ]
    self.critic_params = self.critic.parameters()

    # Make sure the targets start with the same weights as the online networks.
    for i in range(self.n_group):
        hard_update(self.actor_targets[i], self.actors[i])
    hard_update(self.critic_target, self.critic)
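# The training loop below calls `agent.adjust_lr(i_episode)` when `fixed_lr`
# is off. The method is not shown in this excerpt; the sketch below assumes a
# linear anneal of both learning rates toward zero over `num_episodes`, using
# only attributes defined in __init__. The repo's actual schedule may differ.
def adjust_lr(self, i_episode):
    # Fraction of training remaining, clamped at zero.
    frac = max(0.0, 1.0 - i_episode / float(self.num_episodes))
    for i in range(self.n_group):
        for param_group in self.actor_optims[i].param_groups:
            param_group['lr'] = self.init_act_lr * frac
    for param_group in self.critic_optim.param_groups:
        param_group['lr'] = self.init_critic_lr * frac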
            print('episode {}, p loss {}, p_lr {}'.format(
                i_episode, policy_loss, agent.actor_lr))
        if total_numsteps % args.steps_per_critic_update == 0:
            value_losses = []
            for _ in range(args.critic_updates_per_step):
                transitions = memory.sample(args.batch_size)
                batch = Transition(*zip(*transitions))
                # Keep only the critic loss (first return value) for logging.
                value_losses.append(
                    agent.update_critic_parameters(batch, i, args.shuffle)[0])
                updates += 1
            value_loss = np.mean(value_losses)
            print('episode {}, q loss {}, q_lr {}'.format(
                i_episode, value_loss,
                agent.critic_optim.param_groups[0]['lr']))
            if args.target_update_mode == 'episodic':
                hard_update(agent.critic_target, agent.critic)

        if done_n[0] or terminal:
            print('train episode reward', episode_reward)
            episode_step = 0
            break

    if not args.fixed_lr:
        agent.adjust_lr(i_episode)
    # writer.add_scalar('reward/train', episode_reward, i_episode)
    rewards.append(episode_reward)

    # if (i_episode + 1) % 1000 == 0 or ((i_episode + 1) >= args.num_episodes - 50 and (i_episode + 1) % 4 == 0):
    if (i_episode + 1) % args.eval_freq == 0:
        tr_log = {
            'num_adversary': 0,
            'best_good_eval_reward': best_good_eval_reward,
            'best_adversary_eval_reward': best_adversary_eval_reward,