Example #2
    def update_policy(self, observations, actions, reward_advs,
                      constraint_advs, J_c):
        self.policy.train()

        observations = observations.to(self.device)
        actions = actions.to(self.device)
        reward_advs = reward_advs.to(self.device)
        constraint_advs = constraint_advs.to(self.device)

        action_dists = self.policy(observations)
        log_action_probs = action_dists.log_prob(actions)

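        # Importance sampling ratio: numerically equal to 1 at the current
        # parameters, but written this way so its policy gradient is preserved.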
        imp_sampling = torch.exp(log_action_probs - log_action_probs.detach())
        # Change to torch.matmul
        reward_loss = -torch.mean(imp_sampling * reward_advs)
        reward_grad = flat_grad(reward_loss,
                                self.policy.parameters(),
                                retain_graph=True)
        # Change to torch.matmul
        constraint_loss = torch.sum(
            imp_sampling * constraint_advs) / self.simulator.n_trajectories
        constraint_grad = flat_grad(constraint_loss,
                                    self.policy.parameters(),
                                    retain_graph=True)

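        # Fisher-vector products: the Fisher information matrix is applied
        # implicitly as the Hessian of the mean KL divergence of the current
        # policy with itself (the first distribution is held fixed).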
        mean_kl = mean_kl_first_fixed(action_dists, action_dists)
        Fvp_fun = get_Hvp_fun(mean_kl, self.policy.parameters())

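        # Approximately solve F x = reward_grad and F x = constraint_grad
        # with the conjugate gradient method.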
        F_inv_g = cg_solver(Fvp_fun, reward_grad, self.device)
        F_inv_b = cg_solver(Fvp_fun, constraint_grad, self.device)

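        # Scalars of the CPO dual problem:
        #   q = g^T F^{-1} g,  r = g^T F^{-1} b,  s = b^T F^{-1} b,
        # where g and b are the reward and constraint gradients, and c measures
        # how far the constraint value J_c sits above its limit.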
        q = torch.matmul(reward_grad, F_inv_g)
        r = torch.matmul(reward_grad, F_inv_b)
        s = torch.matmul(constraint_grad, F_inv_b)
        c = (J_c - self.max_constraint_val).to(self.device)

        is_feasible = not (c > 0 and c**2 / s - 2 * self.max_kl > 0)

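        # Feasible case: take the step given by the dual solution of the
        # trust-region subproblem. Infeasible case: take a recovery step that
        # purely decreases the constraint within the KL trust region.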
        if is_feasible:
            lam, nu = self.calc_dual_vars(q, r, s, c)
            search_dir = -lam**-1 * (F_inv_g + nu * F_inv_b)
        else:
            search_dir = -torch.sqrt(2 * self.max_kl / s) * F_inv_b

        # First-order expected change in the surrogate loss per unit step along
        # the search direction; used to normalize the line search acceptance ratio
        exp_loss_improv = torch.matmul(reward_grad, search_dir)
        current_policy = get_flat_params(self.policy)

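        # Acceptance test for the backtracking line search: the candidate step
        # must sufficiently improve the surrogate loss (when feasible), keep
        # the linearized constraint within its slack, and respect the KL limit.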
        def line_search_criterion(search_dir, step_len):
            test_policy = current_policy + step_len * search_dir
            set_params(self.policy, test_policy)

            with torch.no_grad():
                # Test if conditions are satisfied
                test_dists = self.policy(observations)
                test_probs = test_dists.log_prob(actions)

                imp_sampling = torch.exp(test_probs -
                                         log_action_probs.detach())

                test_loss = -torch.mean(imp_sampling * reward_advs)
                test_cost = torch.sum(
                    imp_sampling *
                    constraint_advs) / self.simulator.n_trajectories
                test_kl = mean_kl_first_fixed(action_dists, test_dists)

                loss_improv_cond = (test_loss - reward_loss) / (
                    step_len *
                    exp_loss_improv) >= self.line_search_accept_ratio
                cost_cond = step_len * torch.matmul(
                    constraint_grad, search_dir) <= max(-c, 0.0)
                kl_cond = test_kl <= self.max_kl

            set_params(self.policy, current_policy)

            if is_feasible:
                return loss_improv_cond and cost_cond and kl_cond

            return cost_cond and kl_cond

        step_len = line_search(search_dir, 1.0, line_search_criterion,
                               self.line_search_coef)
        print('Step Len.:', step_len)
        new_policy = current_policy + step_len * search_dir
        set_params(self.policy, new_policy)
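
The cg_solver and line_search helpers used above are not shown in this example. The sketches below are minimal, assumed implementations consistent with how they are called here: a conjugate-gradient solve that only needs Fisher-vector products, and a backtracking search that shrinks the step by self.line_search_coef until the acceptance criterion passes. The max_iter and residual_tol parameters are assumptions, not part of the original code.

import torch


def cg_solver(Avp_fun, b, device, max_iter=10, residual_tol=1e-10):
    # Conjugate gradient: approximately solve A x = b using only the
    # matrix-vector products A @ v provided by Avp_fun.
    x = torch.zeros_like(b).to(device)
    r = b.clone()
    p = b.clone()
    rs_old = torch.dot(r, r)
    for _ in range(max_iter):
        Avp = Avp_fun(p)
        alpha = rs_old / torch.dot(p, Avp)
        x += alpha * p
        r -= alpha * Avp
        rs_new = torch.dot(r, r)
        if rs_new < residual_tol:
            break
        p = r + (rs_new / rs_old) * p
        rs_old = rs_new
    return x


def line_search(search_dir, init_step_len, criterion, coef, max_iter=10):
    # Backtracking line search: start from the full step and shrink it by
    # coef until the caller's acceptance criterion is satisfied.
    step_len = init_step_len
    for _ in range(max_iter):
        if criterion(search_dir, step_len):
            return step_len
        step_len *= coef
    # No acceptable step found; a zero step leaves the policy unchanged.
    return 0.0

Implementations of get_Hvp_fun in TRPO/CPO code commonly add a small damping term to the Hessian-vector product so that the conjugate gradient solve stays well conditioned.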