def linesearch(model, f, x, fullstep, expected_improve_rate, max_backtracks=10, accept_ratio=.5):
    fval = f(model).data
    #print("\tfval before", fval.item())
    for (_n_backtracks, stepfrac) in enumerate(.5**np.arange(max_backtracks)):
        xnew = x + stepfrac * fullstep
        utils.set_flat_params_to(model, xnew)
        newfval = f(model).data
        actual_improve = fval - newfval
        expected_improve = expected_improve_rate * stepfrac
        ratio = actual_improve / expected_improve
        #print("\ta : %6.4e /e : %6.4e /r : %6.4e " % (actual_improve.item(), expected_improve.item(), ratio.item()))
        if ratio.item() > accept_ratio and actual_improve.item() > 0:
            #print("\tfval after", newfval.item())
            #print("\tlog(std): %f" % xnew[0])
            print('update model')
            return True, xnew
    return False, x

# the reference below is for computing the minibatch loss:
# https://discuss.pytorch.org/t/why-do-we-need-to-set-the-gradients-manually-to-zero-in-pytorch/4903/20
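# Note: the snippets in this collection rely on flat-parameter helpers such as
# utils.set_flat_params_to / get_flat_params_from. A minimal sketch of such helpers is shown
# below (the placement here and the exact names are assumptions, not the original modules):
import numpy as np
import torch

def get_flat_params_from(model):
    # Concatenate all model parameters into a single flat 1-D tensor.
    return torch.cat([param.data.view(-1) for param in model.parameters()])

def set_flat_params_to(model, flat_params):
    # Copy slices of a flat 1-D tensor back into the model's parameters, in order.
    prev_ind = 0
    for param in model.parameters():
        flat_size = int(np.prod(list(param.size())))
        param.data.copy_(flat_params[prev_ind:prev_ind + flat_size].view(param.size()))
        prev_ind += flat_size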
def _update_policy_network(self, state_batch_tensor, advantages, action_batch_tensor):
    action_mean_old, action_std_old = self.policy_network(state_batch_tensor)
    old_normal_dist = Normal(action_mean_old, action_std_old)
    # get the corresponding log-probability from the normal distribution...
    old_action_prob = old_normal_dist.log_prob(action_batch_tensor).sum(dim=1, keepdim=True)
    old_action_prob = old_action_prob.detach()
    action_mean_old = action_mean_old.detach()
    action_std_old = action_std_old.detach()
    # calculate the surrogate objective
    surrogate_loss = self._get_surrogate_loss(state_batch_tensor, advantages, action_batch_tensor, old_action_prob)
    # compute the surrogate gradient g, then solve Ax = g, where A is the Fisher information matrix...
    surrogate_grad = torch.autograd.grad(surrogate_loss, self.policy_network.parameters())
    flat_surrogate_grad = torch.cat([grad.view(-1) for grad in surrogate_grad]).data
    # use the conjugate gradient method to calculate the scaled direction (natural gradient)
    natural_grad = conjugated_gradient(self._fisher_vector_product, -flat_surrogate_grad, 10,
                                       state_batch_tensor, action_mean_old, action_std_old)
    # calculate the scale ratio...
    non_scale_kl = 0.5 * (natural_grad * self._fisher_vector_product(natural_grad, state_batch_tensor,
                                                                     action_mean_old, action_std_old)).sum(0, keepdim=True)
    scale_ratio = torch.sqrt(non_scale_kl / self.args.max_kl)
    final_natural_grad = natural_grad / scale_ratio[0]
    # calculate the expected improvement rate...
    expected_improvement = (-flat_surrogate_grad * natural_grad).sum(0, keepdim=True) / scale_ratio[0]
    # get the flat params...
    prev_params = torch.cat([param.data.view(-1) for param in self.policy_network.parameters()])
    # start to do the line search...
    success, new_params = line_search(self.policy_network, self._get_surrogate_loss, prev_params,
                                      final_natural_grad, expected_improvement, state_batch_tensor,
                                      advantages, action_batch_tensor, old_action_prob)
    # set the params to the model...
    set_flat_params_to(self.policy_network, new_params)
    return surrogate_loss.item()
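# The body of self._get_surrogate_loss is not shown here. A minimal sketch of what such a method
# could compute, based on the standard TRPO importance-sampled objective, is given below; the
# implementation details are an assumption, only the signature is taken from the call above:
def _get_surrogate_loss(self, state_batch_tensor, advantages, action_batch_tensor, old_action_prob):
    action_mean, action_std = self.policy_network(state_batch_tensor)
    normal_dist = Normal(action_mean, action_std)
    action_prob = normal_dist.log_prob(action_batch_tensor).sum(dim=1, keepdim=True)
    # TRPO surrogate objective to minimize: -E[ exp(log pi_new - log pi_old) * A ]
    return -(advantages * torch.exp(action_prob - old_action_prob)).mean()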
def _actor_update(self, S_t, A_t_hat, U_t):
    lg = None
    with torch.enable_grad():
        # Get loss gradient
        lf_actor = - torch.mean(U_t.view(-1) * self._policy.log_likelihood(A_t_hat, S_t.detach()))
        # import pdb; pdb.set_trace()
        grads = torch.autograd.grad(lf_actor, self._policy.parameters(ordered=True))
        lg = torch.cat([grad.view(-1) for grad in grads]).data
        lg = lg if self._args['backend'] == 'pytorch' else lg.numpy()
    # Get natural gradient direction
    # Disable grad to make sure no graph is made
    with torch.no_grad():
        fi = self._policy.fisher_information(S_t, backend=self._args['backend'])
        FVP = lambda v: self._policy.fisher_vector_product(*fi, v, backend=self._args['backend'])
        stepdir = opt.conjugate_gradients(FVP, lg, 10, self._args['damping'], grad=False, backend=self._args['backend'])
        if np.isnan(stepdir).any() and self._args['debug']:
            import pdb; pdb.set_trace()
        stepdir = stepdir if isinstance(stepdir, torch.Tensor) else torch.from_numpy(stepdir)
        prev_params = ut.get_flat_params_from(self._actor, ordered=True)
        # # weight decay
        # l2_pen = 0.01
        # stepdir = stepdir + l2_pen * prev_params
        new_params = prev_params - self._args['lr_actor'] * stepdir
        ut.set_flat_params_to(self._actor, new_params, ordered=True)
def backtracking_ls_ratio(model, f, x, fullstep, expected_improve_rate, max_backtracks=10, accept_ratio=.1):
    fval = f(True).data
    logger.debug('fval before: %0.5f' % fval.item())
    for (_n_backtracks, stepfrac) in enumerate(.5**np.arange(max_backtracks)):
        xnew = x + stepfrac * fullstep
        ut.set_flat_params_to(model, xnew)
        newfval = f(True).data
        actual_improve = fval - newfval
        expected_improve = expected_improve_rate * stepfrac
        ratio = actual_improve / expected_improve
        logger.debug('Backtrack iter %d: a/e/r: %0.5f, %0.5f, %0.5f' %
                     (_n_backtracks, actual_improve.item(), expected_improve.item(), ratio.item()))
        if ratio.item() > accept_ratio and actual_improve.item() > 0:
            logger.debug('Backtrack iter %d: Done -- fval after: %0.5f' % (_n_backtracks, newfval.item()))
            return True, xnew
    return False, x
def backtracking_ls(model, f, x, fullstep, expected_improve, get_constraint=lambda: -1, constraint_max=0, max_backtracks=10):
    fval = f(True).data
    logger.debug('fval before: %0.5f' % fval.item())
    for (_n_backtracks, stepfrac) in enumerate(.5**np.arange(max_backtracks)):
        xnew = x + stepfrac * fullstep
        ut.set_flat_params_to(model, xnew)
        newfval = f(True).data
        actual_improve = fval - newfval
        cv = get_constraint() > constraint_max
        logger.debug('Backtrack iter %d: a/e/cv: %0.5f, %0.5f, %0.5f' %
                     (_n_backtracks, actual_improve.item(), expected_improve.item(), cv))
        if cv:
            # import pdb; pdb.set_trace()
            logger.debug('Backtrack iter %d: Constraint Violated %0.5f' % (_n_backtracks, get_constraint()))
        elif actual_improve.item() < 0:
            logger.debug('Backtrack iter %d: No Improvement' % _n_backtracks)
        else:
            logger.debug('Backtrack iter %d: Done -- fval after: %0.5f' % (_n_backtracks, newfval.item()))
            return True, xnew
    return False, x
def _actor_update(self, S_t, A_t_hat, U_t):
    lg = None
    with torch.enable_grad():
        # Get loss gradient
        lf_actor = - torch.mean(U_t.view(-1) * self._policy.log_likelihood(A_t_hat, S_t.detach()))
        grads = torch.autograd.grad(lf_actor, self._policy.parameters(ordered=True))
        lg_t = torch.cat([grad.view(-1) for grad in grads]).data
        lg = lg_t if self._args['backend'] == 'pytorch' else lg_t.numpy()
    # Get natural gradient direction
    # Disable grad to make sure no graph is made
    with torch.no_grad():
        fi = self._policy.fisher_information(S_t, backend=self._args['backend'])
        FVP = lambda v: self._policy.fisher_vector_product(*fi, v, backend=self._args['backend'])
        stepdir = opt.conjugate_gradients(FVP, lg, 10, self._args['damping'], grad=False, backend=self._args['backend'])
        if np.isnan(stepdir).any() and self._args['debug']:
            import pdb; pdb.set_trace()
        stepdir = stepdir if isinstance(stepdir, torch.Tensor) else torch.from_numpy(stepdir)
        prev_params = ut.get_flat_params_from(self._actor, ordered=True)
        # Compute max step size
        natural_norm = torch.sqrt((stepdir * (FVP(stepdir) + self._args['damping'] * stepdir)).sum(0)).item()
        max_step_size = self._args['lr_actor'] / natural_norm
        # # weight decay
        # l2_pen = 0.01
        # stepdir = stepdir + l2_pen * prev_params
        logger.debug('Max Step Size %5.3g, Loss Grad Norm: %5.3g, Natural Norm: %5.3g' %
                     (max_step_size, lg_t.norm().item(), natural_norm))
        new_params = prev_params - max_step_size * stepdir
        ut.set_flat_params_to(self._actor, new_params, ordered=True)
def trpo_step(model, get_loss, get_kl, max_kl, damping):
    loss = get_loss()
    grads = torch.autograd.grad(loss, model.parameters())
    loss_grad = torch.cat([grad.view(-1) for grad in grads]).data

    def Fvp(v):
        kl = get_kl()
        kl = kl.mean()
        grads = torch.autograd.grad(kl, model.parameters(), create_graph=True)
        flat_grad_kl = torch.cat([grad.view(-1) for grad in grads])
        kl_v = (flat_grad_kl * Variable(v)).sum()
        grads = torch.autograd.grad(kl_v, model.parameters())
        flat_grad_grad_kl = torch.cat([grad.contiguous().view(-1) for grad in grads]).data
        return flat_grad_grad_kl + v * damping

    stepdir = conjugate_gradients(Fvp, -loss_grad, 10)
    shs = 0.5 * (stepdir * Fvp(stepdir)).sum(0, keepdim=True)
    lm = torch.sqrt(shs / max_kl)
    fullstep = stepdir / lm[0]
    neggdotstepdir = (-loss_grad * stepdir).sum(0, keepdim=True)
    print(("lagrange multiplier:", lm[0], "grad_norm:", loss_grad.norm()))
    prev_params = get_flat_params_from(model)
    success, new_params = linesearch(model, get_loss, prev_params, fullstep, neggdotstepdir / lm[0])
    set_flat_params_to(model, new_params)
    return loss
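# A minimal sketch of the conjugate-gradient solver these snippets rely on. It approximately
# solves Ax = b, where A is applied implicitly through the Fisher/KL Hessian-vector product
# (Fvp above). Some variants in this collection also accept extra arguments such as damping or
# a backend flag; the residual tolerance default here is an assumption.
def conjugate_gradients(Avp, b, nsteps, residual_tol=1e-10):
    x = torch.zeros(b.size())
    r = b.clone()
    p = b.clone()
    rdotr = torch.dot(r, r)
    for _ in range(nsteps):
        Ap = Avp(p)
        alpha = rdotr / torch.dot(p, Ap)
        x += alpha * p
        r -= alpha * Ap
        new_rdotr = torch.dot(r, r)
        p = r + (new_rdotr / rdotr) * p
        rdotr = new_rdotr
        if rdotr < residual_tol:
            break
    return x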
def _update_network(self, mb_obs, mb_actions, mb_returns, mb_advs):
    mb_obs_tensor = torch.tensor(mb_obs, dtype=torch.float32)
    mb_actions_tensor = torch.tensor(mb_actions, dtype=torch.float32)
    mb_returns_tensor = torch.tensor(mb_returns, dtype=torch.float32).unsqueeze(1)
    mb_advs_tensor = torch.tensor(mb_advs, dtype=torch.float32).unsqueeze(1)
    # get the old policy and the current policy
    values, _ = self.net(mb_obs_tensor)
    with torch.no_grad():
        _, pi_old = self.old_net(mb_obs_tensor)
    # get the surrogate loss
    surr_loss = self._get_surrogate_loss(mb_obs_tensor, mb_advs_tensor, mb_actions_tensor, pi_old)
    # compute the surrogate gradient g, then solve Ax = g, where A is the Fisher information matrix
    surr_grad = torch.autograd.grad(surr_loss, self.net.actor.parameters())
    flat_surr_grad = torch.cat([grad.view(-1) for grad in surr_grad]).data
    # use the conjugate gradient method to calculate the scaled direction vector (natural gradient)
    nature_grad = conjugated_gradient(self._fisher_vector_product, -flat_surr_grad, 10, mb_obs_tensor, pi_old)
    # calculate the scaling ratio
    non_scale_kl = 0.5 * (nature_grad * self._fisher_vector_product(nature_grad, mb_obs_tensor, pi_old)).sum(0, keepdim=True)
    scale_ratio = torch.sqrt(non_scale_kl / self.args.max_kl)
    final_nature_grad = nature_grad / scale_ratio[0]
    # calculate the expected improvement rate...
    expected_improve = (-flat_surr_grad * nature_grad).sum(0, keepdim=True) / scale_ratio[0]
    # get the flat params...
    prev_params = torch.cat([param.data.view(-1) for param in self.net.actor.parameters()])
    # start to do the line search
    success, new_params = line_search(self.net.actor, self._get_surrogate_loss, prev_params, final_nature_grad,
                                      expected_improve, mb_obs_tensor, mb_advs_tensor, mb_actions_tensor, pi_old)
    set_flat_params_to(self.net.actor, new_params)
    # then update the critic network
    inds = np.arange(mb_obs.shape[0])
    for _ in range(self.args.vf_itrs):
        np.random.shuffle(inds)
        for start in range(0, mb_obs.shape[0], self.args.batch_size):
            end = start + self.args.batch_size
            mbinds = inds[start:end]
            mini_obs = mb_obs[mbinds]
            mini_returns = mb_returns[mbinds]
            # put things into tensors
            mini_obs = torch.tensor(mini_obs, dtype=torch.float32)
            mini_returns = torch.tensor(mini_returns, dtype=torch.float32).unsqueeze(1)
            values, _ = self.net(mini_obs)
            v_loss = (mini_returns - values).pow(2).mean()
            self.optimizer.zero_grad()
            v_loss.backward()
            self.optimizer.step()
    return surr_loss.item(), v_loss.item()
def l_bfgs(fn, model, *args, maxiter=25):
    """ Run the L-BFGS algorithm. args is anything required for the loss function """
    _loss_eval_grad = lambda pv: function_eval_grad(fn, model, *args, params=pv)
    params_0 = ut.get_flat_params_from(model).double().numpy()
    params_T, _, opt_info = scipy.optimize.fmin_l_bfgs_b(_loss_eval_grad, params_0, maxiter=maxiter)
    ut.set_flat_params_to(model, torch.from_numpy(params_T))
def function_eval_grad(loss_function, model, *args, params=None):
    # check whether the model params need to be set first
    if params is not None:
        ut.set_flat_params_to(model, torch.from_numpy(params))
    for param in model.parameters():
        if param.grad is not None:
            param.grad.data.fill_(0)
    # Compute the loss and extract the gradient and loss value
    loss = loss_function(model, *args)
    loss.backward()
    loss_val = loss.detach().double().numpy()
    loss_grad = ut.get_flat_grad_from(model).detach().double().numpy()
    return loss_val, loss_grad
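# ut.get_flat_grad_from used above is assumed to mirror get_flat_params_from, but over the
# parameter gradients; a minimal sketch:
def get_flat_grad_from(model):
    # Concatenate all parameter gradients into a single flat 1-D tensor.
    return torch.cat([param.grad.view(-1) for param in model.parameters()])

# Hypothetical usage of l_bfgs with a mean-squared value loss (value_net, states and targets
# are placeholder names, not from the original code):
# def value_loss_fn(model, states, targets):
#     return (model(states) - targets).pow(2).mean()
# l_bfgs(value_loss_fn, value_net, states, targets, maxiter=25)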
def get_value_loss(flat_params):
    set_flat_params_to(value_net, torch.Tensor(flat_params))
    for param in value_net.parameters():
        if param.grad is not None:
            param.grad.data.fill_(0)
    values_ = value_net(Variable(states))
    value_loss = (values_ - targets).pow(2).mean()
    # weight decay
    for param in value_net.parameters():
        value_loss += param.pow(2).sum() * args.l2_reg
    value_loss.backward()
    return (value_loss.data.double().numpy(),
            get_flat_grad_from(value_net).data.double().numpy())
def linesearch(model, f, x, fullstep, expected_improve_rate, max_backtracks=10, accept_ratio=.1):
    fval = f(True).data
    print("fval before", fval.item())
    for (_n_backtracks, stepfrac) in enumerate(.5**np.arange(max_backtracks)):
        xnew = x + stepfrac * fullstep
        set_flat_params_to(model, xnew)
        newfval = f(True).data
        actual_improve = fval - newfval
        expected_improve = expected_improve_rate * stepfrac
        ratio = actual_improve / expected_improve
        print("a/e/r", actual_improve.item(), expected_improve.item(), ratio.item())
        if ratio.item() > accept_ratio and actual_improve.item() > 0:
            print("fval after", newfval.item())
            return True, xnew
    return False, x
def _trpo_step(self, get_loss, get_kl, max_kl, damping):
    model = self._actor
    loss = get_loss()
    grads = torch.autograd.grad(loss, model.parameters())
    loss_grad = torch.cat([grad.view(-1) for grad in grads]).data
    Fvp = lambda v: fisher_vector_product(get_kl, v, model)
    stepdir = opt.conjugate_gradients(Fvp, -loss_grad, 10, damping)
    shs = 0.5 * (stepdir * (Fvp(stepdir) + damping * stepdir)).sum(0)
    lm = torch.sqrt(shs / max_kl).item()
    fullstep = stepdir / lm
    neggdotstepdir = (-loss_grad * stepdir).sum(0, keepdim=True)
    # import pdb; pdb.set_trace()
    logger.debug('lagrange multiplier %s, grad norm: %s' % (str(lm), str(loss_grad.norm())))
    prev_params = ut.get_flat_params_from(model)
    success, new_params = opt.backtracking_ls_ratio(model, get_loss, prev_params, fullstep, neggdotstepdir / lm)
    ut.set_flat_params_to(model, new_params)
    return loss
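# The fisher_vector_product(get_kl, v, model) helper used above is assumed to compute the
# Hessian-vector product of the mean KL divergence with respect to the model parameters
# (damping is added separately by the callers here). A minimal sketch:
def fisher_vector_product(get_kl, v, model):
    kl = get_kl().mean()
    grads = torch.autograd.grad(kl, model.parameters(), create_graph=True)
    flat_grad_kl = torch.cat([grad.view(-1) for grad in grads])
    kl_v = (flat_grad_kl * v.detach()).sum()
    grads = torch.autograd.grad(kl_v, model.parameters())
    return torch.cat([grad.contiguous().view(-1) for grad in grads]).data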
def _trpo_step(self, get_loss, get_kl, max_kl, damping):
    model = self._actor
    loss = get_loss()
    grads = torch.autograd.grad(loss, model.parameters())
    loss_grad = torch.cat([grad.view(-1) for grad in grads]).data
    Fvp = lambda v: fisher_vector_product(get_kl, v, model)
    stepdir = opt.conjugate_gradients(Fvp, -loss_grad, 10, damping)
    shs = 0.5 * (stepdir * (Fvp(stepdir) + damping * stepdir)).sum(0)
    lm = torch.sqrt(shs / max_kl).item()
    fullstep = stepdir / lm
    expected_improve = (-loss_grad * stepdir).sum(0, keepdim=True)
    logger.debug('lagrange multiplier %5.3g, grad norm: %5.3g' % (lm, loss_grad.norm().item()))
    prev_params = ut.get_flat_params_from(model)
    kl_constraint_eval = lambda: get_kl().mean().item()
    success, new_params = opt.backtracking_ls(model, get_loss, prev_params, fullstep,
                                              expected_improve, kl_constraint_eval, 1.5 * max_kl)
    ut.set_flat_params_to(model, new_params)
    return loss
def trpo_learn(replay_buffer, replay_buffer_reward, env, model, cov_matrix):
    np.random.seed(0)
    current_best_reward = float('-inf')
    global_iteration_counter = 0
    optimization_history_list = []
    while True:
        new_sample_reward = []
        replay_buffer = []
        replay_buffer_reward = []
        # The first step is to add more simulation results to the replay buffer
        for episode_counter in range(0, C.max_new_episode):
            observation_list, action_list, log_prob_action_list, reward_list = \
                roll_out_once.roll_out_once(env, model, cov_matrix)
            # drop old simulation experience
            if len(replay_buffer) > C.replay_buffer_size:
                drop_index = np.argmin(replay_buffer_reward)
                replay_buffer.pop(drop_index)
                replay_buffer_reward.pop(drop_index)
            # add the new simulation result to the replay buffer
            total_reward = np.sum(reward_list)
            replay_buffer_reward.append(total_reward)
            temp_dict = {}
            temp_dict['observation_list'] = observation_list
            temp_dict['action_list'] = action_list
            temp_dict['log_prob_action_list'] = log_prob_action_list
            temp_dict['reward_list'] = reward_list
            replay_buffer.append(temp_dict)
            new_sample_reward.append(np.sum(reward_list))
        global_iteration_counter += 1
        print('this is global iteration ', global_iteration_counter)
        print('the current reward is', np.mean(new_sample_reward))
        # record the optimization process
        optimization_history_list.append(np.mean(new_sample_reward))
        optimization_history = {}
        optimization_history['objective_history'] = optimization_history_list
        cwd = os.getcwd()
        # cwd = os.path.join(cwd, 'data_folder')
        parameter_file = 'optimization_history.json'
        cwd = os.path.join(cwd, parameter_file)
        with open(cwd, 'w') as statusFile:
            statusFile.write(jsonpickle.encode(optimization_history))
        if np.mean(new_sample_reward) > current_best_reward:
            current_best_reward = np.mean(new_sample_reward)
            # save the neural network model
            cwd = os.getcwd()
            parameter_file = 'pendulum_nn_trained_model.pt'
            cwd = os.path.join(cwd, parameter_file)
            torch.save(model.state_dict(), cwd)
        # we can update the model more than once because we are using off-line data
        for update_iteration in range(0, C.max_offline_training):
            # sample experience from the replay buffer for training
            # new_replay_buffer_reward = []
            # for entry in replay_buffer_reward:
            #     new_replay_buffer_reward.append(np.log(entry*-1)*-1)  # because the reward is negative here
            # sample_probability = (np.exp(new_replay_buffer_reward))/np.sum(np.exp(new_replay_buffer_reward))  # apply softmax to the total_reward list
            sampled_off_line_data = []
            for sample_counter in range(0, C.training_batch_size):
                sampled_index = random.randint(0, len(replay_buffer) - 1)
                # sampled_index = np.random.choice(np.arange(0, len(replay_buffer)), p=sample_probability.tolist())
                sampled_off_line_data.append(replay_buffer[sampled_index])
            # concatenate the sampled experience into one long experience
            total_sampled_observation = torch.empty(size=(0, ))
            total_sampled_action = torch.empty(size=(0, ))
            total_sampled_log_prob_action_state = torch.empty(size=(0, ))
            total_sampled_reward = np.zeros((0))
            baseline_reward = 0
            for sample_index in range(0, len(sampled_off_line_data)):
                off_line_data = sampled_off_line_data[sample_index]
                baseline_reward += np.sum(off_line_data['reward_list'])
            baseline_reward = baseline_reward / len(sampled_off_line_data)
            for sample_index in range(0, len(sampled_off_line_data)):
                off_line_data = sampled_off_line_data[sample_index]
                total_sampled_observation = torch.cat((total_sampled_observation, off_line_data['observation_list']), dim=0)
                total_sampled_action = torch.cat((total_sampled_action, off_line_data['action_list']), dim=0)
                total_sampled_log_prob_action_state = torch.cat((total_sampled_log_prob_action_state, off_line_data['log_prob_action_list']), dim=0)
                total_sampled_reward = np.concatenate((total_sampled_reward, np.asarray(off_line_data['reward_list']) - baseline_reward))
            total_sampled_reward = torch.tensor(total_sampled_reward)
            # compute the loss and update the model with TRPO;
            # this get_loss function will also be used for the line search later on
            get_loss = lambda x: getSurrogateloss(model, total_sampled_observation, total_sampled_action,
                                                  cov_matrix, total_sampled_log_prob_action_state, total_sampled_reward)
            loss = get_loss(model)
            # print('the loss of the model is', loss)
            grads = torch.autograd.grad(loss, model.parameters())
            loss_grad = torch.cat([grad.view(-1) for grad in grads])
            # compute the direction for updating the model parameters
            Fvp = lambda v: FisherVectorProduct(v, model, total_sampled_observation, total_sampled_action,
                                                total_sampled_log_prob_action_state, C.damping, cov_matrix)
            stepdir = conjugate_gradients(Fvp, -loss_grad, 20)
            # print('the step direction is', stepdir)
            # now perform a line search to find how large a step can be taken in the direction computed above
            shs = 0.5 * (stepdir * Fvp(stepdir)).sum(0, keepdim=True)
            if shs > 0:  # this should be positive
                # print('shs is', shs)
                lm = torch.sqrt(shs / C.max_kl)
                # print('lm is', lm)
                fullstep = stepdir / lm[0]
                # print('fullstep is', fullstep)
                neggdotstepdir = (-loss_grad * stepdir).sum(0, keepdim=True)
                prev_params = utils.get_flat_params_from(model)
                success, new_params = linesearch(model, get_loss, prev_params, fullstep, neggdotstepdir / lm[0])
                # print('new params is', new_params)
                model = utils.set_flat_params_to(model, new_params)
def update_params(batch):
    rewards = torch.Tensor(batch.reward)
    masks = torch.Tensor(batch.mask)
    actions = torch.Tensor(np.concatenate(batch.action, 0))
    states = torch.Tensor(batch.state)
    values = value_net(Variable(states))

    returns = torch.Tensor(actions.size(0), 1)
    deltas = torch.Tensor(actions.size(0), 1)
    advantages = torch.Tensor(actions.size(0), 1)

    prev_return = 0
    prev_value = 0
    prev_advantage = 0
    for i in reversed(range(rewards.size(0))):
        returns[i] = rewards[i] + args.gamma * prev_return * masks[i]
        deltas[i] = rewards[i] + args.gamma * prev_value * masks[i] - values.data[i]
        advantages[i] = deltas[i] + args.gamma * args.tau * prev_advantage * masks[i]
        prev_return = returns[i, 0]
        prev_value = values.data[i, 0]
        prev_advantage = advantages[i, 0]
    targets = Variable(returns)

    # Original code uses the same L-BFGS routine to optimize the value loss
    def get_value_loss(flat_params):
        set_flat_params_to(value_net, torch.Tensor(flat_params))
        for param in value_net.parameters():
            if param.grad is not None:
                param.grad.data.fill_(0)
        values_ = value_net(Variable(states))
        value_loss = (values_ - targets).pow(2).mean()
        # weight decay
        for param in value_net.parameters():
            value_loss += param.pow(2).sum() * args.l2_reg
        value_loss.backward()
        return (value_loss.data.double().numpy(),
                get_flat_grad_from(value_net).data.double().numpy())

    flat_params, _, opt_info = scipy.optimize.fmin_l_bfgs_b(
        get_value_loss, get_flat_params_from(value_net).double().numpy(), maxiter=25)
    set_flat_params_to(value_net, torch.Tensor(flat_params))

    advantages = (advantages - advantages.mean()) / advantages.std()

    action_means, action_log_stds, action_stds = policy_net(Variable(states))
    fixed_log_prob = normal_log_density(Variable(actions), action_means, action_log_stds, action_stds).data.clone()

    def get_loss(volatile=False):
        if volatile:
            with torch.no_grad():
                action_means, action_log_stds, action_stds = policy_net(Variable(states))
        else:
            action_means, action_log_stds, action_stds = policy_net(Variable(states))
        log_prob = normal_log_density(Variable(actions), action_means, action_log_stds, action_stds)
        action_loss = -Variable(advantages) * torch.exp(log_prob - Variable(fixed_log_prob))
        return action_loss.mean()

    def get_kl():
        mean1, log_std1, std1 = policy_net(Variable(states))
        mean0 = Variable(mean1.data)
        log_std0 = Variable(log_std1.data)
        std0 = Variable(std1.data)
        kl = log_std1 - log_std0 + (std0.pow(2) + (mean0 - mean1).pow(2)) / (2.0 * std1.pow(2)) - 0.5
        return kl.sum(1, keepdim=True)

    if args.agent == 'trpo':
        trpo_step(policy_net, get_loss, get_kl, args.max_kl, args.damping)
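# normal_log_density used above is assumed to be the usual diagonal-Gaussian log-density helper
# from pytorch-trpo style codebases; a minimal sketch:
import math

def normal_log_density(x, mean, log_std, std):
    # log N(x; mean, std^2), summed over the action dimensions.
    var = std.pow(2)
    log_density = -(x - mean).pow(2) / (2 * var) - 0.5 * math.log(2 * math.pi) - log_std
    return log_density.sum(1, keepdim=True)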