def update_policy(self, observations, actions, reward_advs, constraint_advs, J_c):
    self.policy.train()

    observations = observations.to(self.device)
    actions = actions.to(self.device)
    reward_advs = reward_advs.to(self.device)
    constraint_advs = constraint_advs.to(self.device)

    action_dists = self.policy(observations)
    log_action_probs = action_dists.log_prob(actions)

    # Importance-sampling ratios are 1 at the current policy; detach() keeps the
    # old log-probs fixed so gradients flow only through the new policy.
    imp_sampling = torch.exp(log_action_probs - log_action_probs.detach())

    # Surrogate reward objective and its gradient g
    # Change to torch.matmul
    reward_loss = -torch.mean(imp_sampling * reward_advs)
    reward_grad = flat_grad(reward_loss, self.policy.parameters(),
                            retain_graph=True)

    # Surrogate constraint cost and its gradient b
    # Change to torch.matmul
    constraint_loss = torch.sum(
        imp_sampling * constraint_advs) / self.simulator.n_trajectories
    constraint_grad = flat_grad(constraint_loss, self.policy.parameters(),
                                retain_graph=True)

    # Fisher-vector product from the Hessian of the mean KL (first argument fixed)
    mean_kl = mean_kl_first_fixed(action_dists, action_dists)
    Fvp_fun = get_Hvp_fun(mean_kl, self.policy.parameters())

    # Conjugate-gradient solves for F^-1 g and F^-1 b
    F_inv_g = cg_solver(Fvp_fun, reward_grad, self.device)
    F_inv_b = cg_solver(Fvp_fun, constraint_grad, self.device)

    # CPO scalars: q = g^T F^-1 g, r = g^T F^-1 b, s = b^T F^-1 b, c = J_c - d
    q = torch.matmul(reward_grad, F_inv_g)
    r = torch.matmul(reward_grad, F_inv_b)
    s = torch.matmul(constraint_grad, F_inv_b)
    c = (J_c - self.max_constraint_val).to(self.device)

    # Infeasible when the constraint is violated and the trust region cannot
    # recover it; in that case take a pure recovery step on the constraint.
    is_feasible = not (c > 0 and c**2 / s - 2 * self.max_kl > 0)

    if is_feasible:
        lam, nu = self.calc_dual_vars(q, r, s, c)
        search_dir = -lam**-1 * (F_inv_g + nu * F_inv_b)
    else:
        search_dir = -torch.sqrt(2 * self.max_kl / s) * F_inv_b

    # Expected change of the surrogate loss along the search direction
    exp_loss_improv = torch.matmul(reward_grad, search_dir)  # Should be positive
    current_policy = get_flat_params(self.policy)

    def line_search_criterion(search_dir, step_len):
        test_policy = current_policy + step_len * search_dir
        set_params(self.policy, test_policy)

        with torch.no_grad():
            # Test whether the acceptance conditions are satisfied at this step
            test_dists = self.policy(observations)
            test_probs = test_dists.log_prob(actions)

            imp_sampling = torch.exp(test_probs - log_action_probs.detach())
            test_loss = -torch.mean(imp_sampling * reward_advs)
            # Sampled cost at the test policy (currently unused; the linearized
            # cost condition below is checked instead)
            test_cost = torch.sum(
                imp_sampling * constraint_advs) / self.simulator.n_trajectories
            test_kl = mean_kl_first_fixed(action_dists, test_dists)

            # 1) Sufficient improvement of the surrogate loss
            loss_improv_cond = (test_loss - reward_loss) / (
                step_len * exp_loss_improv) >= self.line_search_accept_ratio
            # 2) Linearized constraint satisfied (or not worsened if already violated)
            cost_cond = step_len * torch.matmul(
                constraint_grad, search_dir) <= max(-c, 0.0)
            # 3) KL divergence within the trust region
            kl_cond = test_kl <= self.max_kl

        set_params(self.policy, current_policy)

        if is_feasible:
            return loss_improv_cond and cost_cond and kl_cond
        return cost_cond and kl_cond

    step_len = line_search(search_dir, 1.0, line_search_criterion,
                           self.line_search_coef)
    print('Step Len.:', step_len)

    new_policy = current_policy + step_len * search_dir
    set_params(self.policy, new_policy)
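# --- Illustrative sketches of assumed helpers ---
# `cg_solver` and `line_search` are defined elsewhere in the repo. The sketches
# below are minimal implementations matching the call signatures used above
# (cg_solver(Fvp_fun, b, device) and line_search(search_dir, max_step_len,
# criterion, coef)); they show the assumed behavior, not the actual code.

import torch


def cg_solver(Fvp_fun, b, device, max_iter=10, tol=1e-10):
    # Conjugate gradient: approximately solve F x = b, where Fvp_fun(v) returns
    # the Fisher-vector product F @ v. Starts from x = 0, so the initial
    # residual equals b.
    x = torch.zeros_like(b, device=device)
    r = b.clone()
    p = b.clone()
    rs_old = torch.matmul(r, r)
    for _ in range(max_iter):
        Fp = Fvp_fun(p)
        alpha = rs_old / torch.matmul(p, Fp)
        x = x + alpha * p
        r = r - alpha * Fp
        rs_new = torch.matmul(r, r)
        if rs_new < tol:
            break
        p = r + (rs_new / rs_old) * p
        rs_old = rs_new
    return x


def line_search(search_dir, max_step_len, criterion, coef, max_iter=10):
    # Backtracking line search: shrink the step by `coef` until
    # criterion(search_dir, step_len) accepts it. Returning 0.0 on failure
    # (an assumption) leaves the policy unchanged in update_policy above.
    step_len = max_step_len
    for _ in range(max_iter):
        if criterion(search_dir, step_len):
            return step_len
        step_len *= coef
    return 0.0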