class Hp():
    def __init__(self):
        self.lr_out = 0.2
        self.lr_in = 0.3
        self.lr_v = 0.1
        self.gamma = 0.96
        self.n_update = 200
        self.len_rollout = 150
        self.batch_size = 128
        self.use_baseline = True
        self.seed = 42

hp = Hp()
ipd = IPD(hp.len_rollout, hp.batch_size)

def magic_box(x):
    # DiCE "magic box" operator: equals 1 in the forward pass, while its
    # derivative w.r.t. x equals its (unit) forward value, so it injects
    # score-function gradients without changing the objective's value.
    return torch.exp(x - x.detach())

class Memory():
    def __init__(self):
        self.self_logprobs = []
        self.other_logprobs = []
        self.values = []
        self.rewards = []

    def add(self, lp, other_lp, v, r):
        self.self_logprobs.append(lp)
        self.other_logprobs.append(other_lp)
        self.values.append(v)
        self.rewards.append(r)
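# Quick sanity check of magic_box (an illustrative demo, not part of the
# training code): the forward value is exactly 1, while the gradient w.r.t.
# x is the forward value itself, i.e. also 1 here.
def _demo_magic_box():
    x = torch.tensor(2.0, requires_grad=True)
    y = magic_box(x)
    assert torch.allclose(y, torch.tensor(1.0))       # forward pass: exp(0) = 1
    y.backward()
    assert torch.allclose(x.grad, torch.tensor(1.0))  # d/dx exp(x - x.detach()) = 1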
import argparse

from envs import IPD
from ipd_DiCE import Agent, play

def objective(args):
    """Run a DiCE self-play match and return the negated best score,
    so a hyperparameter search can minimize it."""
    ipd = IPD(args.len_rollout, args.batch_size)
    scores = play(
        Agent(args.lr_out, args.lr_v, args.gamma, args.use_baseline,
              args.len_rollout),
        Agent(args.lr_out, args.lr_v, args.gamma, args.use_baseline,
              args.len_rollout),
        args.order, ipd, args.n_update, args.lr_in, args.len_rollout)
    return -max(scores)

parser = argparse.ArgumentParser()
parser.add_argument("--lr-out", type=float, default=0.2)
parser.add_argument("--lr-in", type=float, default=0.3)
parser.add_argument("--lr-v", type=float, default=0.3)
parser.add_argument("--gamma", type=float, default=0.96)
parser.add_argument("--n-update", type=int, default=200)
parser.add_argument("--len-rollout", type=int, default=150)
parser.add_argument("--batch-size", type=int, default=128)
parser.add_argument("--use-baseline", action="store_true")
parser.add_argument("--order", type=int, default=0)
parser.add_argument("--seed", type=int, default=42)
args = parser.parse_args()

print(objective(args))
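# Hypothetical helper (not in the original script): since objective() returns
# the negated best score, a toy random search over the learning rates can
# minimize it. The sampling ranges below are illustrative assumptions, not
# tuned values.
def random_search(base_args, n_trials=10):
    import random
    best_loss, best_cfg = float("inf"), None
    for _ in range(n_trials):
        trial = argparse.Namespace(**vars(base_args))
        trial.lr_out = random.uniform(0.05, 0.5)
        trial.lr_in = random.uniform(0.05, 0.5)
        loss = objective(trial)
        if loss < best_loss:
            best_loss, best_cfg = loss, (trial.lr_out, trial.lr_in)
    return best_loss, best_cfg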
class Polen():
    def __init__(self, args):
        """
        1. Initialize agent embeddings z, policy theta & PEN psi
        """
        self.args = args
        self.agent1 = Agent(self.args)
        self.agent2 = Agent(self.args)
        self.policy = SteerablePolicy(self.args)
        self.pen = PolicyEvaluationNetwork(self.args)
        self.pen_optimizer = torch.optim.Adam(params=self.pen.parameters(),
                                              lr=args.lr_pen)
        self.ipd2 = IPD(max_steps=args.len_rollout,
                        batch_size=args.nsamples_bin)

        if os.path.exists(args.logdir):
            rmtree(args.logdir)
        self.writer = SummaryWriter(args.logdir)

    def rollout(self, nsteps):
        # just to evaluate progress:
        (s1, s2), _ = ipd.reset()
        score1 = 0
        score2 = 0
        for t in range(nsteps):
            a1, lp1 = self.policy.act(s1, self.agent1.z)
            a2, lp2 = self.policy.act(s2, self.agent2.z)
            (s1, s2), (r1, r2), _, _ = ipd.step((a1, a2))
            # accumulate scores
            score1 += np.mean(r1) / float(self.args.len_rollout)
            score2 += np.mean(r2) / float(self.args.len_rollout)
        return (score1, score2)

    def rollout_binning(self, nsteps, z1, z2):
        # Estimate the per-agent return distributions for embeddings (z1, z2).
        (s1, s2), _ = self.ipd2.reset()
        score1 = torch.zeros(self.args.nsamples_bin, dtype=torch.float)
        score2 = torch.zeros(self.args.nsamples_bin, dtype=torch.float)
        for t in range(nsteps):
            a1, lp1 = self.policy.act(s1, z1)
            a2, lp2 = self.policy.act(s2, z2)
            (s1, s2), (r1, r2), _, _ = self.ipd2.step((a1, a2))
            # accumulate scores
            score1 += r1
            score2 += r2
        score1 = score1 / nsteps
        score2 = score2 / nsteps
        hist1 = torch.histc(score1, bins=self.args.nbins, min=-3, max=0)
        hist2 = torch.histc(score2, bins=self.args.nbins, min=-3, max=0)
        # normalize counts to probabilities so they can serve as KL targets
        hist1 = hist1 / hist1.sum()
        hist2 = hist2 / hist2.sum()
        return hist1, hist2

    def train(self):
        print("start iterations with", self.args.lookaheads, "lookaheads:")
        joint_scores = []

        for update in range(self.args.n_iter):
            # 1a. For fixed z1 & z2, learn the steerable policy theta by
            # maximizing rollouts.
            for t in range(self.args.n_policy):
                # Update steerable policy parameters (the true objective is
                # available in closed form for IPD).
                policy_loss = self.policy_update_true()  # self.policy_update_pg()
                self.writer.add_scalar('PolicyObjective V1 plus V2',
                                       -policy_loss,
                                       update * self.args.n_policy + t)

            # 1b. Train the PEN
            # TODO: Convert this to a parallel version so only one call to the PEN is required
            for t in range(self.args.n_pen):
                # Randomly generate z1, z2. Generation centered on the current
                # z1, z2 might be better.
                z1 = torch.randn(self.args.embedding_size)
                z2 = torch.randn(self.args.embedding_size)
                # Experiment with shorter rollouts for estimation
                hist1, hist2 = self.rollout_binning(self.args.len_rollout, z1, z2)

                # Compute the KL divergence between the PEN prediction for
                # (z1, z2) and the sampled return histograms
                w1, w2 = self.pen.forward(z1.unsqueeze(0), z2.unsqueeze(0))
                w1 = F.softmax(w1.squeeze(), dim=0)
                w2 = F.softmax(w2.squeeze(), dim=0)

                self.pen_optimizer.zero_grad()
                # pen_loss = (hist1 * (hist1 / w1).log()).sum() + (hist2 * (hist2 / w2).log()).sum()
                # F.kl_div expects log-probabilities as input and probabilities
                # as target: F.kl_div(q.log(), p) computes KL(p || q).
                pen_loss = F.kl_div(w1.log(), hist1) + F.kl_div(w2.log(), hist2)
                pen_loss.backward()
                self.pen_optimizer.step()
                self.writer.add_scalar('PEN Loss: KL1 plus KL2', pen_loss,
                                       update * self.args.n_pen + t)

            # 2. Do LOLA updates
            self.lola_update_exact()

            # Evaluate:
            score = self.rollout(self.args.len_rollout)
            avg_score = 0.5 * (score[0] + score[1])
            self.writer.add_scalar('Avg Score of Agent', avg_score, update)
            joint_scores.append(avg_score)

            # Logging
            if update % 10 == 0:
                print('After update', update, '------------')
                p0 = [p.item() for p in torch.sigmoid(self.policy.theta)]
                p1 = [p.item() for p in
                      torch.sigmoid(self.policy.theta + self.agent1.z)]
                p2 = [p.item() for p in
                      torch.sigmoid(self.policy.theta + self.agent2.z)]
                print(
                    'score (%.3f,%.3f)\n' % (score[0], score[1]),
                    'Default = {S: %.3f, DD: %.3f, DC: %.3f, CD: %.3f, CC: %.3f}\n'
                    % (p0[0], p0[1], p0[2], p0[3], p0[4]),
                    '(agent1) = {S: %.3f, DD: %.3f, DC: %.3f, CD: %.3f, CC: %.3f}\n'
                    % (p1[0], p1[1], p1[2], p1[3], p1[4]),
                    '(agent2) = {S: %.3f, DD: %.3f, DC: %.3f, CD: %.3f, CC: %.3f}'
                    % (p2[0], p2[1], p2[2], p2[3], p2[4]))
                # print('theta: ', self.policy.theta, '\n', 'z1: ', self.agent1.z, '\n', 'z2: ', self.agent2.z)

        return joint_scores

    def policy_update_true(self):
        # Batching not needed here: the exact IPD objective is differentiable.
        self.policy.theta_optimizer.zero_grad()
        theta1 = self.agent1.z + self.policy.theta
        theta2 = self.agent2.z + self.policy.theta
        objective = (true_objective(theta1, theta2, ipd) +
                     true_objective(theta2, theta1, ipd))
        objective.backward()
        self.policy.theta_optimizer.step()
        return objective

    def policy_update_pg(self):
        """ TODO: Will need batching """
        pass

    def lola_update_exact(self):
        """ Do LOLA updates """
        # copy the other agent's parameters:
        z1_ = self.agent1.z.clone().detach().requires_grad_(True)
        z2_ = self.agent2.z.clone().detach().requires_grad_(True)

        for k in range(self.args.lookaheads):
            # estimate the other agent's gradients from in_lookahead:
            grad2 = self.agent1.in_lookahead(self.policy.theta, z2_)
            grad1 = self.agent2.in_lookahead(self.policy.theta, z1_)
            # update the other agent's z
            z2_ = z2_ - self.args.lr_in * grad2
            z1_ = z1_ - self.args.lr_in * grad1

        # update own parameters from out_lookahead:
        self.agent1.out_lookahead(self.policy.theta, z2_)
        self.agent2.out_lookahead(self.policy.theta, z1_)
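# A standalone check of the KL convention used in the PEN loss above
# (illustrative numbers only): F.kl_div expects log-probabilities as input
# and probabilities as target, so F.kl_div(q.log(), p, reduction='sum')
# equals KL(p || q) = sum_i p_i * log(p_i / q_i).
def _demo_kl_convention():
    p = F.softmax(torch.randn(5), dim=0)  # target, e.g. a normalized return histogram
    q = F.softmax(torch.randn(5), dim=0)  # model output, e.g. softmaxed PEN logits
    kl = F.kl_div(q.log(), p, reduction='sum')
    manual = (p * (p / q).log()).sum()
    assert torch.allclose(kl, manual)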
""" Do Lola Updates """ # copy other's parameters: z1_ = self.agent1.z.clone().detach().requires_grad_(True) z2_ = self.agent2.z.clone().detach().requires_grad_(True) for k in range(self.args.lookaheads): # estimate other's gradients from in_lookahead: grad2 = self.agent1.in_lookahead(self.policy.theta, z2_) grad1 = self.agent2.in_lookahead(self.policy.theta, z1_) # update other's theta z2_ = z2_ - self.args.lr_in * grad2 z1_ = z1_ - self.args.lr_in * grad1 # update own parameters from out_lookahead: self.agent1.out_lookahead(self.policy.theta, z2_) self.agent2.out_lookahead(self.policy.theta, z1_) if __name__ == "__main__": args = get_args() global ipd ipd = IPD(max_steps=args.len_rollout, batch_size=args.batch_size) torch.manual_seed(args.seed) polen = Polen(args) scores = polen.train() polen.writer.close() if args.plot: plot(scores, args)
class Polen():
    def __init__(self, args):
        """
        1. Initialize agents with embeddings z, shared policy theta & PEN
        """
        self.args = args
        self.device = torch.device("cuda:0" if (
            torch.cuda.is_available() and self.args.gpu) else "cpu")
        print('Using Device: ', self.device)

        # The default ipd object has batch_size=args.pen_batch_size. For any
        # other batch size, pass it to the reset and step functions.
        self.ipd = IPD(max_steps=args.len_rollout,
                       batch_size=args.pen_batch_size)

        self.agent1 = Agent(self.args, self.device, self.ipd)
        if self.args.fix_z2:
            self.agent2 = Agent(self.args, self.device, self.ipd,
                                val=torch.tensor(
                                    [-20.0, 20.0, -20.0, 20.0, -20.0]))
        else:
            self.agent2 = Agent(self.args, self.device, self.ipd)

        if self.args.not_use_policy_net:
            self.policy = SteerablePolicy(self.args, self.device)
            self.policy_optimizer = torch.optim.Adam(
                params=(self.policy.theta, ), lr=args.lr_theta)
        else:
            self.policy = SteerablePolicyNet(self.args, self.device)
            self.policy.to(self.device)
            self.policy_optimizer = torch.optim.Adam(
                params=self.policy.parameters(), lr=args.lr_theta)

        self.buffer = ReplayBuffer(size=self.args.buffer_size,
                                   dim=self.args.embedding_size)

        if self.args.use_kl:
            self.pen = PolicyEvaluationNetwork(self.args, self.device,
                                               no_avg=True)
        else:
            self.pen = PolicyEvaluationNetwork_2(self.args, self.device)
        self.pen.to(self.device)
        self.pen_optimizer = torch.optim.Adam(params=self.pen.parameters(),
                                              lr=args.lr_pen)
        # self.pen_scheduler = torch.optim.lr_scheduler.StepLR(self.pen_optimizer, step_size=10, gamma=0.9)

        self.pth_1 = '2l_tanh/' + ('true_vf/' if self.args.pen_true_vf else
                                   ('sampled/num_' + str(self.args.nsamples_return)))
        self.pth_2 = ('_KL_' if self.args.use_kl else '') + '_policy_h_' + str(
            self.args.policy_hidden) + '_h_' + str(
                self.args.pen_hidden) + '_train_' + str(
                    self.args.pen_train_size) + '_bsz_' + str(
                        self.args.pen_batch_size) + '_ep_' + str(
                            self.args.pen_epochs) + '_lr_' + str(
                                self.args.lr_pen) + '_seed_' + str(
                                    self.args.seed) + ('_gpu_' if self.args.gpu else '') + '.pth'
        self.args.logdir = args.logdir + self.pth_1 + 'look_' + str(
            args.lookaheads) + '/n_policy_' + str(args.n_policy) + '_n_pen' + str(
                args.n_pen) + '_lr_theta_' + str(args.lr_theta) + '_lr_in_' + str(
                    args.lr_in) + '_lr_out_' + str(args.lr_out) + '_emb_' + str(
                        args.embedding_size) + self.pth_2

        if os.path.exists(self.args.logdir):
            print('Removing existing')
            rmtree(self.args.logdir)
        self.writer = SummaryWriter(self.args.logdir)

    def train(self):
        """
        Main function to
        1. Initialize the PEN
        2. Do LOLA updates while refining the PEN.
""" print(f"start iterations with {self.args.lookaheads} lookaheads") joint_scores = [] PATH = self.args.savedir + self.pth_1 + self.pth_2 # PATH = 'saved_models/2l_tanh/true_vf/_h_128_train_50000_bsz_128_ep_3_lr_0.002_seed_911.pth' if os.path.exists(PATH): print('Loading Saved Model from : ', PATH) self.pen.load_state_dict(torch.load(PATH)) elif self.args.grads == 'pen': print('Training PEN from Sratch------------') pen_train = PenDataset(self.args.pen_train_size, self.args.embedding_size) pen_test = PenDataset(self.args.pen_test_size, self.args.embedding_size) dloader_train = DataLoader(pen_train, batch_size=self.args.pen_batch_size) pen_steps = 0 for epoch in range(1, self.args.pen_epochs + 1): for t, sampled_batch in enumerate(dloader_train): z1s, z2s = sampled_batch pen_train_loss = self.pen_update( z1s, z2s, use_true_vf=self.args.pen_true_vf) self.writer.add_scalar( 'PEN_LOSS/Train', pen_train_loss, pen_steps * self.args.pen_batch_size) if t % 100 == 0: #Compute test stats pen_test_loss = self.pen_update( pen_test.z1s, pen_test.z2s, eval=True, use_true_vf=self.args.pen_true_vf) self.writer.add_scalar( 'PEN_LOSS/Test', pen_test_loss, pen_steps * self.args.pen_batch_size) print( 'Epoch: {}, After {} iter, Train Loss: {} , Test Loss: {}' .format(epoch, t, pen_train_loss, pen_test_loss)) pen_steps += 1 # self.writer.add_scalar('MSE V1 and R1', mse_1, t) # self.writer.add_scalar('MSE V2 and R2', mse_2, t) print('Saving Model to : ', PATH) torch.save(self.pen.state_dict(), PATH) # Start main algo iterations for update in range(self.args.n_iter): # 1a. For fixed z1 & z2, Learn steerable policy theta & PEN by maximizing rollouts if self.args.n_policy and update % self.args.n_policy == 0: # for t in range(self.args.n_policy): # Update steerable policy parameters. True possible for IPD if not self.args.use_pg: policy_loss = self.policy_update_true(step=update) else: policy_loss = self.policy_update_pg() self.writer.add_scalar('PolicyObjective V1 plus V2', -policy_loss, update) # self.writer.add_scalar('PolicyObjective V1 plus V2', -policy_loss, update*self.args.n_policy + t ) if self.args.grads == 'pen': self.buffer.add_surround(self.agent1.z.data, self.agent2.z.data, num_samples=self.args.pen_batch_size) # Refine the PEN if update % 1 == 0: for t in range(self.args.n_pen): z1s, z2s = self.buffer.sample_batch( batch_size=self.args.pen_batch_size) # # Generate z1s & z2s aroud z1, z2 sampled from the buffer # z1s = (torch.rand((self.args.pen_batch_size,self.args.embedding_size)).to(self.device)) + z1_sample # z2s = (torch.rand((self.args.pen_batch_size,self.args.embedding_size)).to(self.device)) + z2_sample pen_train_loss = self.pen_update( z1s.to(self.device), z2s.to(self.device), use_true_vf=self.args.pen_true_vf) self.writer.add_scalar('PEN_Refine/Train', pen_train_loss, update * self.args.n_pen + t) # self.pen_scheduler.step() # print('New PEN lr rate ', self.pen_optimizer.param_groups[0]['lr']) # 2. 
            if self.args.grads == 'pen':
                in_grads, out_grads = self.lola_update_pen(
                    fix_z2=self.args.fix_z2)
                self.writer.add_scalar('LOLA_Grads/MSE', in_grads[0], update)
                self.writer.add_scalar('LOLA_Grads/COS_Similarity_1',
                                       in_grads[1], update)
                self.writer.add_scalar('LOLA_Grads/COS_Similarity_2',
                                       in_grads[2], update)
                self.writer.add_scalar('LOLA_Grads/Norm_PEN', in_grads[3],
                                       update)
                self.writer.add_scalar('Outer_LOLA_Grads/MSE', out_grads[0],
                                       update)
                self.writer.add_scalar('Outer_LOLA_Grads/COS_Similarity_1',
                                       out_grads[1], update)
                self.writer.add_scalar('Outer_LOLA_Grads/COS_Similarity_2',
                                       out_grads[2], update)
                self.writer.add_scalar('Outer_LOLA_Grads/Norm_PEN',
                                       out_grads[3], update)
            elif self.args.grads == 'dice':
                self.lola_update_dice()
            elif self.args.grads == 'true':
                self.lola_update_exact()
            else:
                raise NotImplementedError

            # Evaluate:
            score = self.rollout(self.args.len_rollout)
            avg_score = 0.5 * (score[0] + score[1])
            self.writer.add_scalar('Avg Score of Agent', avg_score, update)
            joint_scores.append(avg_score)

            # Logging
            if update % 10 == 0:
                # import ipdb; ipdb.set_trace()
                p0 = [p.item() for p in torch.sigmoid(
                    self.policy.fwd(torch.zeros_like(self.agent1.z)))]
                p1 = [p.item() for p in
                      torch.sigmoid(self.policy.fwd(self.agent1.z))]
                p2 = [p.item() for p in
                      torch.sigmoid(self.policy.fwd(self.agent2.z))]
                self.writer.add_image('Defect Probs',
                                      plot_scatter(p0, p1, p2), update)
                print('After update', update, '------------')

                v_env_1, v_env_2 = score[0], score[1]
                v_true_1 = -true_objective(self.policy.fwd(self.agent1.z),
                                           self.policy.fwd(self.agent2.z),
                                           self.ipd).data
                v_true_2 = -true_objective(self.policy.fwd(self.agent2.z),
                                           self.policy.fwd(self.agent1.z),
                                           self.ipd).data
                v_pen_1 = -self.pen.predict(self.agent1.z, self.agent2.z).data
                v_pen_2 = -self.pen.predict(self.agent2.z, self.agent1.z).data
                print('Sampled Score in Env (%.3f,%.3f)' % (v_env_1, v_env_2))
                print('True Value ({}, {})'.format(v_true_1, v_true_2))
                print('PEN Value ({}, {})'.format(v_pen_1, v_pen_2))
                print('Default = {S: %.3f, DD: %.3f, DC: %.3f, CD: %.3f, CC: %.3f}\n'
                      % (p0[0], p0[1], p0[2], p0[3], p0[4]),
                      '(agent1) = {S: %.3f, DD: %.3f, DC: %.3f, CD: %.3f, CC: %.3f}\n'
                      % (p1[0], p1[1], p1[2], p1[3], p1[4]),
                      '(agent2) = {S: %.3f, DD: %.3f, DC: %.3f, CD: %.3f, CC: %.3f}'
                      % (p2[0], p2[1], p2[2], p2[3], p2[4]))

                self.writer.add_scalar('Value_Agent_1/Sampled in env', v_env_1, update)
                self.writer.add_scalar('Value_Agent_1/True Value', v_true_1, update)
                self.writer.add_scalar('Value_Agent_1/PEN Value', v_pen_1, update)
                self.writer.add_scalar('Value_Agent_2/Sampled in env', v_env_2, update)
                self.writer.add_scalar('Value_Agent_2/True Value', v_true_2, update)
                self.writer.add_scalar('Value_Agent_2/PEN Value', v_pen_2, update)

        return joint_scores, p0, p1, p2

    def pen_update(self, z1s, z2s, eval=False, use_true_vf=True):
        """
        Update the PEN with an MSE loss.

        Args:
            z1s (bsz x dim):
            z2s (bsz x dim):
            eval (bool, optional): Run in eval mode; no gradient steps on the
                PEN. Defaults to False.
            use_true_vf (bool, optional): Whether to use the true value
                function as target, or returns sampled from the environment.
                Defaults to True.
        Returns:
            pen_loss: The PEN MSE loss computed for the batch
        """
        if use_true_vf:
            target_1, target_2 = [], []
            # TODO: See if this can be parallelized
            for i in range(z1s.shape[0]):
                target_1.append(-true_objective(self.policy.fwd(z1s[i]),
                                                self.policy.fwd(z2s[i]),
                                                self.ipd))
                target_2.append(-true_objective(self.policy.fwd(z2s[i]),
                                                self.policy.fwd(z1s[i]),
                                                self.ipd))
            target_1 = torch.tensor(target_1).to(self.device)
            target_2 = torch.tensor(target_2).to(self.device)
        else:
            target_1, target_2 = torch.zeros(z1s.shape[0]), torch.zeros(
                z2s.shape[0])
            for _ in range(self.args.nsamples_return):
                t_1, t_2 = self.rollout_binning_batch(self.args.len_rollout,
                                                      z1s, z2s)
                target_1 += t_1
                target_2 += t_2
            target_1 = target_1 / self.args.nsamples_return
            target_2 = target_2 / self.args.nsamples_return

        # Compute the PEN values
        w1 = self.pen.forward(z1s, z2s)
        w2 = self.pen.forward(z2s, z1s)
        pen_loss = ((w1.squeeze(1) - target_1)**2 +
                    (w2.squeeze(1) - target_2)**2).mean()
        if not eval:
            self.pen_optimizer.zero_grad()
            pen_loss.backward()
            self.pen_optimizer.step()
        return pen_loss

    def lola_update_pen(self, fix_z2=False):
        """
        Do LOLA updates on agent parameters using gradients from the PEN.

        Args:
            fix_z2 (bool, optional): Fix agent 2's parameters and learn only
                for agent 1, making the environment stationary for debugging.
                Defaults to False.

        Returns:
            in_grads: Gradient statistics from the inner LOLA loop
            out_grads: Gradient statistics from the outer LOLA loop
        """
        # copy the other agent's parameters:
        z1_ = self.agent1.z.clone().detach().requires_grad_(True)
        z2_ = self.agent2.z.clone().detach().requires_grad_(True)

        grad_diff = 0
        grad_cos_1 = 0
        grad_cos_2 = 0
        grads = 0
        out_grad_cos_1 = 0
        out_grad_cos_2 = 0

        # Update clones using lookaheads
        for k in range(self.args.lookaheads):
            # estimate the other agent's gradients from in_lookahead:
            grad2 = self.agent1.in_lookahead_pen(self.pen, z2_)
            grad1 = self.agent2.in_lookahead_pen(self.pen, z1_)
            grad2_ex = self.agent1.in_lookahead_exact(self.policy, z2_)
            grad1_ex = self.agent2.in_lookahead_exact(self.policy, z1_)
            # update the other agent's z
            z2_ = z2_ - self.args.lr_in * grad2
            z1_ = z1_ - self.args.lr_in * grad1

            grads = grads + torch.norm(grad1)
            grad_diff += F.mse_loss(grad1, grad1_ex) + F.mse_loss(grad2, grad2_ex)
            grad_cos_1 += F.cosine_similarity(grad1.unsqueeze(0),
                                              grad1_ex.unsqueeze(0))
            grad_cos_2 += F.cosine_similarity(grad2.unsqueeze(0),
                                              grad2_ex.unsqueeze(0))

        # update own parameters from out_lookahead:
        # Need to clone because grad is just a reference
        grad_z1_pen = self.agent1.out_lookahead_pen(self.pen, z2_).clone()
        grad_z1_exact = self.agent1.out_lookahead_exact(self.policy, z2_,
                                                        eval=True)
        # if we fix z2, this amounts to just evaluating it
        grad_z2_pen = self.agent2.out_lookahead_pen(self.pen, z1_,
                                                    eval=fix_z2).clone()
        grad_z2_exact = self.agent2.out_lookahead_exact(self.policy, z1_,
                                                        eval=True)

        out_grad_cos_1 = F.cosine_similarity(grad_z1_pen.unsqueeze(0),
                                             grad_z1_exact.unsqueeze(0))
        out_grad_cos_2 = F.cosine_similarity(grad_z2_pen.unsqueeze(0),
                                             grad_z2_exact.unsqueeze(0))
        out_grad_diff = F.mse_loss(grad_z1_pen, grad_z1_exact) + F.mse_loss(
            grad_z2_pen, grad_z2_exact)
        out_grad_norm = torch.norm(grad_z1_pen) + torch.norm(grad_z2_pen)
        out_grads = [out_grad_diff, out_grad_cos_1, out_grad_cos_2,
                     out_grad_norm]

        if self.args.lookaheads > 0:
            in_grads = [grad_diff / self.args.lookaheads,
                        grad_cos_1 / self.args.lookaheads,
                        grad_cos_2 / self.args.lookaheads,
                        grads / self.args.lookaheads]
        else:
            in_grads = [grad_diff, grad_cos_1, grad_cos_2, grads]
        return in_grads, out_grads

    def pen_update_kl(self):
""" Method to update PEN with KL Dive loss. DEPRECATED CURRENTLY Args: z1s (bsz x dimension): z2s (bsz x dimension): eval (bool, optional): Run in eval mode. No gradient steps on PEN. Defaults to False. use_true_vf (bool, optional): Whether to use True Value Function as target or use returns computed from the environment. Defaults to True. Returns: pen_loss: The PEN MSE loss computed for the batch """ # randomly generate z1, z2. Maybe generation centered on z0, z1 would be better. z1 = sigmoid_inv(torch.rand(self.args.embedding_size)) z2 = sigmoid_inv(torch.rand(self.args.embedding_size)) # Experiment with smaller length of rollouts for estimation hist1, hist2, avg_return_1, avg_return_2 = self.rollout_binning( self.args.len_rollout, z1, z2) # Compute the KL Div w1 = self.pen.forward(self.agent1.z.unsqueeze(0), self.agent2.z.unsqueeze(0)) w2 = self.pen.forward(self.agent2.z.unsqueeze(0), self.agent1.z.unsqueeze(0)) v1 = self.pen.predict(self.agent1.z, self.agent2.z) v2 = self.pen.predict(self.agent2.z, self.agent1.z) w1 = F.softmax(w1.squeeze(), dim=0) w2 = F.softmax(w2.squeeze(), dim=0) mse_1 = (-v1 - avg_return_1 / (1 - self.args.gamma))**2 mse_2 = (-v2 - avg_return_2 / (1 - self.args.gamma))**2 self.pen_optimizer.zero_grad() # Note: F.kl_div(Q.log(), P) computes KL(P || Q) pen_loss = F.kl_div(w1.squeeze().log(), hist1) + F.kl_div( w2.squeeze().log(), hist2) pen_loss.backward() self.pen_optimizer.step() return pen_loss, mse_1, mse_2 def policy_update_true(self, step): # Batching not needed here. self.policy_optimizer.zero_grad() # if self.args.not_use_policy_net: # theta_ = self.policy.theta.clone().detach() # objective = true_objective(self.agent1.z + self.policy.theta, self.agent2.z + theta_, self.ipd)\ # + true_objective(self.agent2.z + self.policy.theta, self.agent1.z + theta_, self.ipd) # else: # objective = true_objective(self.policy.fwd(self.agent1.z), self.policy.fwd(self.agent2.z), self.ipd)\ # + true_objective(self.policy.fwd(self.agent2.z), self.policy.fwd(self.agent1.z), self.ipd) objective = true_objective(self.policy.fwd(self.agent1.z), self.policy.fwd(self.agent2.z).detach(), self.ipd)\ + true_objective(self.policy.fwd(self.agent2.z), self.policy.fwd(self.agent1.z).detach(), self.ipd) # TODO: See what this retain_graph does. Some multiprocessing stuff. objective.backward(retain_graph=True) self.policy_optimizer.step() # grad_norm = self.policy.linear1.weight.grad.norm()+ self.policy.linear2.weight.grad.norm() # self.writer.add_scalar('Policy/Grad_Norms', grad_norm, step) return objective def policy_update_pg(self): """ Policy Gradient for theta using Dice Objective. 
        Treat theta as shared parameters.
        """
        # First using Agent 1
        (s1, s2), _ = self.ipd.reset(self.args.ipd_batch_size)
        memory1 = Memory(self.args)
        memory2 = Memory(self.args)
        for t in range(self.args.len_rollout):
            a1, lp1, v1 = self.policy.act(s1, self.agent1.z,
                                          self.agent1.values)
            a2, lp2, v2 = self.policy.act(s2, self.agent2.z,
                                          self.agent2.values)
            (s1, s2), (r1, r2), _, _ = self.ipd.step((a1, a2),
                                                     self.args.ipd_batch_size)
            memory1.add(lp1, lp2, v1, torch.from_numpy(r1).float())
            memory2.add(lp2, lp1, v2, torch.from_numpy(r2).float())

        # update the shared theta
        objective = memory1.dice_objective(dice_both=self.args.dice_both) + \
            memory2.dice_objective(dice_both=self.args.dice_both)
        # self.policy.policy_update(objective)
        self.policy_optimizer.zero_grad()
        objective.backward(retain_graph=True)
        self.policy_optimizer.step()

        # update own values:
        self.agent1.value_update(memory1.value_loss())
        self.agent2.value_update(memory2.value_loss())
        return objective

    def lola_update_exact(self):
        """ Do LOLA updates using the true value function """
        # copy the other agent's parameters:
        z1_ = self.agent1.z.clone().detach().requires_grad_(True)
        z2_ = self.agent2.z.clone().detach().requires_grad_(True)

        for k in range(self.args.lookaheads):
            # estimate the other agent's gradients from in_lookahead:
            grad2 = self.agent1.in_lookahead_exact(self.policy, z2_)
            grad1 = self.agent2.in_lookahead_exact(self.policy, z1_)
            # update the other agent's z
            z2_ = z2_ - self.args.lr_in * grad2
            z1_ = z1_ - self.args.lr_in * grad1

        # update own parameters from out_lookahead:
        self.agent1.out_lookahead_exact(self.policy, z2_)
        self.agent2.out_lookahead_exact(self.policy, z1_)

    def lola_update_dice(self):
        """ Do LOLA-DiCE updates """
        # copy the other agent's parameters:
        z1_ = self.agent1.z.clone().detach().requires_grad_(True)
        z2_ = self.agent2.z.clone().detach().requires_grad_(True)
        values1_ = self.agent1.values.clone().detach().requires_grad_(True)
        values2_ = self.agent2.values.clone().detach().requires_grad_(True)

        for k in range(self.args.lookaheads):
            # estimate the other agent's gradients from in_lookahead:
            grad2 = self.agent1.in_lookahead(z2_, values2_, self.policy)
            grad1 = self.agent2.in_lookahead(z1_, values1_, self.policy)
            # update the other agent's z
            z2_ = z2_ - self.args.lr_in * grad2
            z1_ = z1_ - self.args.lr_in * grad1

        # update own parameters from out_lookahead:
        self.agent1.out_lookahead(z2_, values2_, self.policy)
        self.agent2.out_lookahead(z1_, values1_, self.policy)

    def rollout(self, nsteps):
        # just to evaluate progress:
        (s1, s2), _ = self.ipd.reset(self.args.ipd_batch_size)
        score1 = 0
        score2 = 0
        for t in range(nsteps):
            a1, lp1, v1 = self.policy.act(s1, self.agent1.z,
                                          self.agent1.values)
            a2, lp2, v2 = self.policy.act(s2, self.agent2.z,
                                          self.agent2.values)
            (s1, s2), (r1, r2), _, _ = self.ipd.step((a1, a2),
                                                     self.args.ipd_batch_size)
            # accumulate scores
            score1 += np.mean(r1) / float(self.args.len_rollout)
            score2 += np.mean(r2) / float(self.args.len_rollout)
        return (score1, score2)

    def rollout_binning(self, nsteps, z1, z2):
        # just to evaluate progress:
        (s1, s2), _ = self.ipd.reset(self.args.nsamples_return)
        score1 = np.zeros(self.args.nsamples_return)
        score2 = np.zeros(self.args.nsamples_return)
        gamma_t = 1
        for t in range(nsteps):
            a1, lp1 = self.policy.act(s1, z1)
            a2, lp2 = self.policy.act(s2, z2)
            (s1, s2), (r1, r2), _, _ = self.ipd.step(
                (a1, a2), self.args.nsamples_return)
            # accumulate discounted scores
            score1 += gamma_t * r1
            score2 += gamma_t * r2
            gamma_t = gamma_t * self.args.gamma
        # score1 = torch.tensor(score1 * (1 - self.args.gamma), dtype=torch.float32)
        # score2 = torch.tensor(score2 * (1 - self.args.gamma), dtype=torch.float32)
        # hist1 = torch.histc(score1, bins=self.args.nbins, min=-3.0, max=0.0)
        # hist2 = torch.histc(score2, bins=self.args.nbins, min=-3.0, max=0.0)
        score1 = torch.tensor(score1, dtype=torch.float32)
        score2 = torch.tensor(score2, dtype=torch.float32)
        hist1 = torch.histc(score1, bins=self.args.nbins, min=-75.0, max=0.0)
        hist2 = torch.histc(score2, bins=self.args.nbins, min=-75.0, max=0.0)
        return hist1, hist2, score1.mean(), score2.mean()

    def rollout_binning_batch(self, nsteps, z1s, z2s):
        """
        Compute the value targets by sampling in the environment, using the
        steerable policy conditioned on the strategy vectors.

        Args:
            nsteps (int): Number of time-steps in the environment
            z1s (bsz x dim): Batch of z1s
            z2s (bsz x dim): Batch of z2s

        Returns:
            score1 (bsz): Value targets for agent 1
            score2 (bsz): Value targets for agent 2
        """
        bsz = z1s.shape[0]
        (s1, s2), _ = self.ipd.reset(batch_size=bsz)
        score1 = np.zeros(bsz)
        score2 = np.zeros(bsz)
        gamma_t = 1
        for t in range(nsteps):
            a1, lp1 = self.policy.act_parallel(s1, z1s)
            a2, lp2 = self.policy.act_parallel(s2, z2s)
            (s1, s2), (r1, r2), _, _ = self.ipd.step((a1, a2), batch_size=bsz)
            # accrue discounted scores
            score1 += gamma_t * r1
            score2 += gamma_t * r2
            gamma_t = gamma_t * self.args.gamma

        score1 = torch.tensor(score1, dtype=torch.float32)
        score2 = torch.tensor(score2, dtype=torch.float32)
        return score1, score2
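# For reference, a minimal sketch of what the PenDataset used in train()
# might look like (an assumption: the real implementation may sample
# differently). It yields random (z1, z2) embedding pairs for PEN
# pretraining, and exposes .z1s/.z2s so a whole split can be fed to
# pen_update at once, as done for pen_test above.
class PenDatasetSketch(torch.utils.data.Dataset):
    def __init__(self, size, embedding_size):
        self.z1s = torch.randn(size, embedding_size)
        self.z2s = torch.randn(size, embedding_size)

    def __len__(self):
        return self.z1s.shape[0]

    def __getitem__(self, idx):
        return self.z1s[idx], self.z2s[idx]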