def update():
    obs, act, adv, ret, logp_old = [torch.Tensor(x) for x in buf.get()]

    # Main outputs from computation graph
    _, logp, _ = actor_critic.policy(obs, act)
    ent = (-logp).mean()  # a sample estimate for entropy

    # VPG policy objective
    pi_loss = -(logp * adv).mean()

    # Policy gradient step
    train_pi.zero_grad()
    pi_loss.backward()
    average_gradients(train_pi.param_groups)
    train_pi.step()

    # Value function learning
    v = actor_critic.value_function(obs)
    v_l_old = F.mse_loss(v, ret)
    for _ in range(train_v_iters):
        # Output from value function graph
        v = actor_critic.value_function(obs)
        # VPG value objective
        v_loss = F.mse_loss(v, ret)

        # Value function gradient step
        train_v.zero_grad()
        v_loss.backward()
        average_gradients(train_v.param_groups)
        train_v.step()

    # Log changes from update
    _, logp, _, v = actor_critic(obs, act)
    pi_l_new = -(logp * adv).mean()
    v_l_new = F.mse_loss(v, ret)
    kl = (logp_old - logp).mean()  # a sample estimate for KL-divergence
    logger.store(
        LossPi=pi_loss,
        LossV=v_l_old,
        KL=kl,
        Entropy=ent,
        DeltaLossPi=(pi_l_new - pi_loss),
        DeltaLossV=(v_l_new - v_l_old),
    )
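# The VPG update above relies on an `average_gradients` helper to keep parallel
# workers synchronized before each optimizer step. The sketch below is an
# illustrative assumption (an mpi4py-based all-reduce over CPU gradients), not
# necessarily this repo's actual implementation.
from mpi4py import MPI
import numpy as np

def average_gradients(param_groups):
    """Average, in place, the gradient of every parameter across MPI processes."""
    comm = MPI.COMM_WORLD
    if comm.Get_size() == 1:
        return
    for group in param_groups:
        for p in group["params"]:
            if p.grad is None:
                continue
            grad = p.grad.numpy()            # shares memory with p.grad (CPU tensor)
            avg = np.zeros_like(grad)
            comm.Allreduce(grad, avg, op=MPI.SUM)
            np.copyto(grad, avg / comm.Get_size())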
def update():
    inputs = [torch.Tensor(x) for x in buf.get()]
    obs, act, adv, ret, logp_old = inputs[:-len(buf.sorted_info_keys)]
    policy_args = dict(
        zip(buf.sorted_info_keys, inputs[-len(buf.sorted_info_keys):]))

    # Main outputs from computation graph
    _, logp, _, _, d_kl, v = actor_critic(obs, act, **policy_args)

    # Prepare hessian func, gradient eval
    ratio = (logp - logp_old).exp()  # pi(a|s) / pi_old(a|s)
    pi_l_old = -(ratio * adv).mean()
    v_l_old = F.mse_loss(v, ret)

    g = core.flat_grad(pi_l_old, actor_critic.policy.parameters(),
                       retain_graph=True)
    g = torch.from_numpy(mpi_avg(g.numpy()))
    pi_l_old = mpi_avg(pi_l_old.item())

    def Hx(x):
        hvp = core.hessian_vector_product(d_kl, actor_critic.policy, x)
        if damping_coeff > 0:
            hvp += damping_coeff * x
        return torch.from_numpy(mpi_avg(hvp.numpy()))

    # Core calculations for TRPO or NPG
    x = cg(Hx, g)
    alpha = torch.sqrt(2 * delta / (torch.dot(x, Hx(x)) + EPS))
    old_params = parameters_to_vector(actor_critic.policy.parameters())

    def set_and_eval(step):
        vector_to_parameters(old_params - alpha * x * step,
                             actor_critic.policy.parameters())
        _, logp, _, _, d_kl = actor_critic.policy(obs, act, **policy_args)
        ratio = (logp - logp_old).exp()
        pi_loss = -(ratio * adv).mean()
        return mpi_avg(d_kl.item()), mpi_avg(pi_loss.item())

    if algo == "npg":
        kl, pi_l_new = set_and_eval(step=1.0)
    elif algo == "trpo":
        for j in range(backtrack_iters):
            kl, pi_l_new = set_and_eval(step=backtrack_coeff**j)
            if kl <= delta and pi_l_new <= pi_l_old:
                logger.log(
                    "Accepting new params at step %d of line search." % j)
                logger.store(BacktrackIters=j)
                break

            if j == backtrack_iters - 1:
                logger.log("Line search failed! Keeping old params.")
                logger.store(BacktrackIters=j)
                kl, pi_l_new = set_and_eval(step=0.0)

    # Value function updates
    for _ in range(train_v_iters):
        v = actor_critic.value_function(obs)
        v_loss = F.mse_loss(v, ret)

        # Value function gradient step
        train_vf.zero_grad()
        v_loss.backward()
        average_gradients(train_vf.param_groups)
        train_vf.step()

    v = actor_critic.value_function(obs)
    v_l_new = F.mse_loss(v, ret)

    # Log changes from update
    logger.store(
        LossPi=pi_l_old,
        LossV=v_l_old,
        KL=kl,
        DeltaLossPi=(pi_l_new - pi_l_old),
        DeltaLossV=(v_l_new - v_l_old),
    )
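# The TRPO/NPG update above calls cg(Hx, g) to solve H x = g for the search
# direction using only Hessian-vector products, so H is never formed explicitly.
# Below is a minimal conjugate-gradient sketch of such a solver; cg_iters and
# residual_tol are illustrative hyperparameters, not necessarily the repo's.
def cg(Hx, g, cg_iters=10, residual_tol=1e-10):
    x = torch.zeros_like(g)
    r = g.clone()                      # residual r = g - H x, with x = 0 initially
    p = g.clone()                      # current search direction
    r_dot_old = torch.dot(r, r)
    for _ in range(cg_iters):
        z = Hx(p)
        alpha = r_dot_old / (torch.dot(p, z) + EPS)
        x += alpha * p
        r -= alpha * z
        r_dot_new = torch.dot(r, r)
        if r_dot_new < residual_tol:
            break
        p = r + (r_dot_new / r_dot_old) * p
        r_dot_old = r_dot_new
    return x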
def update():
    temp_get = buf.get()
    obs, act, adv, ret, logp_old = [
        torch.Tensor(x).to(device) for x in temp_get
    ]

    # Training policy
    _, logp, _ = actor_critic.policy(obs, act)
    ratio = (logp - logp_old).exp()
    min_adv = torch.where(adv > 0, (1 + clip_ratio) * adv,
                          (1 - clip_ratio) * adv)
    pi_l_old = -(torch.min(ratio * adv, min_adv)).mean()
    ent = (-logp).mean()  # a sample estimate for entropy

    for i in range(train_pi_iters):
        # Output from policy function graph
        _, logp, _ = actor_critic.policy(obs, act)
        # PPO policy objective
        ratio = (logp - logp_old).exp()
        min_adv = torch.where(adv > 0, (1 + clip_ratio) * adv,
                              (1 - clip_ratio) * adv)
        pi_loss = -(torch.min(ratio * adv, min_adv)).mean()

        # Policy gradient step
        train_pi.zero_grad()
        pi_loss.backward()
        average_gradients(train_pi.param_groups)
        train_pi.step()

        _, logp, _ = actor_critic.policy(obs, act)
        kl = (logp_old - logp).mean()
        kl = mpi_avg(kl.item())
        if kl > 1.5 * target_kl:
            logger.log(
                'Early stopping at step %d due to reaching max kl.' % i)
            break
    logger.store(StopIter=i)

    # Training value function
    v = actor_critic.value_function(obs)
    v_l_old = F.mse_loss(v, ret)
    for _ in range(train_v_iters):
        # Output from value function graph
        v = actor_critic.value_function(obs)
        # PPO value function objective
        v_loss = F.mse_loss(v, ret)

        # Value function gradient step
        train_v.zero_grad()
        v_loss.backward()
        average_gradients(train_v.param_groups)
        train_v.step()

    # Log changes from update
    _, logp, _, v = actor_critic(obs, act)
    ratio = (logp - logp_old).exp()
    min_adv = torch.where(adv > 0, (1 + clip_ratio) * adv,
                          (1 - clip_ratio) * adv)
    pi_l_new = -(torch.min(ratio * adv, min_adv)).mean()
    v_l_new = F.mse_loss(v, ret)
    kl = (logp_old - logp).mean()  # a sample estimate for KL-divergence
    clipped = (ratio > (1 + clip_ratio)) | (ratio < (1 - clip_ratio))
    cf = (clipped.float()).mean()
    logger.store(LossPi=pi_l_old, LossV=v_l_old,
                 KL=kl, Entropy=ent, ClipFrac=cf,
                 DeltaLossPi=(pi_l_new - pi_l_old),
                 DeltaLossV=(v_l_new - v_l_old))
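# The PPO update above builds the clipped surrogate via torch.where; the same
# objective is often written with torch.clamp. The helper below is illustrative
# only (not part of the repo) and evaluates to the same loss value as the
# min_adv formulation above, which some readers find easier to follow.
def ppo_clip_loss(logp, logp_old, adv, clip_ratio):
    ratio = (logp - logp_old).exp()                        # pi(a|s) / pi_old(a|s)
    clipped = torch.clamp(ratio, 1 - clip_ratio, 1 + clip_ratio) * adv
    return -(torch.min(ratio * adv, clipped)).mean()       # negate to minimize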