Code example #1
    def adapt_ng(self,
                 episodes,
                 first_order=False,
                 max_kl=1e-3,
                 cg_iters=20,
                 cg_damping=1e-2,
                 ls_max_steps=10,
                 ls_backtrack_ratio=0.5):
        """Adapt the parameters of the policy network to a new task, from
        sampled trajectories `episodes`, with a one-step natural gradient update.
        """
        # Fit the baseline to the training episodes
        self.baseline.fit(episodes)
        # Get the loss on the training episodes
        loss_lvc = self.inner_loss_lvc(episodes)
        # Get the new parameters after a one-step natural gradient update
        grads = torch.autograd.grad(loss_lvc, self.policy.parameters())
        grads = parameters_to_vector(grads)

        # Compute the step direction with Conjugate Gradient
        hessian_vector_product = self.hessian_vector_product_ng(
            episodes, damping=cg_damping)
        stepdir = conjugate_gradient(hessian_vector_product,
                                     grads,
                                     cg_iters=cg_iters).detach()
        if self.verbose:
            # Relative residual of the conjugate gradient solve
            print(
                torch.norm(hessian_vector_product(stepdir) - grads) /
                torch.norm(grads))

        shs = 0.5 * (stepdir.dot(hessian_vector_product(stepdir)))
        lm = torch.sqrt(max_kl / shs)

        if self.verbose:
            print("learning rate {}".format(lm))
        stepdir_named = vector_to_named_parameter_like(
            stepdir, self.policy.named_parameters())
        step_size = lm.detach()
        params = OrderedDict()
        for (name, param) in self.policy.named_parameters():
            params[name] = param - step_size * stepdir_named[name]

        if self.verbose:
            # compute the kl divergence
            with torch.autograd.no_grad():
                pi = self.policy(episodes.observations, params=params)
                pi_old = self.policy(episodes.observations)
                kl = kl_divergence(pi_old, pi).mean()
                print(kl)

        return params, step_size, stepdir
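
The natural-gradient update above relies on a `vector_to_named_parameter_like` helper that is not shown in the snippet. A minimal sketch of such a helper, assuming it simply splits the flat conjugate-gradient solution back into an `OrderedDict` keyed like `named_parameters()` (the project's own version may differ):

from collections import OrderedDict


def vector_to_named_parameter_like(vector, named_parameters):
    """Split a flat vector into per-parameter tensors keyed by name
    (hypothetical helper, mirroring torch.nn.utils.vector_to_parameters)."""
    result = OrderedDict()
    offset = 0
    for name, param in named_parameters:
        numel = param.numel()
        result[name] = vector[offset:offset + numel].view_as(param)
        offset += numel
    return result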
Code example #2
File: metalearner.py  Project: shunzh/pytorch-maml-rl
    def step(self,
             episodes,
             max_kl=1e-3,
             cg_iters=10,
             cg_damping=1e-2,
             ls_max_steps=10,
             ls_backtrack_ratio=0.5):
        """Meta-optimization step (ie. update of the initial parameters), based 
        on Trust Region Policy Optimization (TRPO, [4]).
        """
        old_loss, _, old_pis = self.surrogate_loss(episodes)

        if old_loss is None:
            # nothing needs to be done
            return

        grads = torch.autograd.grad(old_loss, self.policy.parameters())
        grads = parameters_to_vector(grads)

        # Compute the step direction with Conjugate Gradient
        hessian_vector_product = self.hessian_vector_product(
            episodes, damping=cg_damping)
        stepdir = conjugate_gradient(hessian_vector_product,
                                     grads,
                                     cg_iters=cg_iters)

        # Compute the Lagrange multiplier
        shs = 0.5 * torch.dot(stepdir, hessian_vector_product(stepdir))
        lagrange_multiplier = torch.sqrt(shs / max_kl)

        step = stepdir / lagrange_multiplier

        # Save the old parameters
        old_params = parameters_to_vector(self.policy.parameters())

        # Line search
        step_size = 1.0
        for _ in range(ls_max_steps):
            vector_to_parameters(old_params - step_size * step,
                                 self.policy.parameters())
            loss, kl, _ = self.surrogate_loss(episodes, old_pis=old_pis)
            improve = loss - old_loss
            # Accept the step if the loss decreased and the KL divergence is
            # small enough (i.e. the new policy is not too far from the old one)
            if (improve.item() < 0.0) and (kl.item() < max_kl):
                break
            step_size *= ls_backtrack_ratio
        else:
            vector_to_parameters(old_params, self.policy.parameters())
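
Every example in this section delegates the linear solve F·x = g to a `conjugate_gradient` helper, where F is available only through the damped Fisher-vector product closure. A minimal sketch of that helper under the usual signature; the projects' actual implementations may differ in defaults and in how they detach intermediate tensors:

import torch


def conjugate_gradient(f_Ax, b, cg_iters=10, residual_tol=1e-10):
    """Solve A x = b given only the matrix-vector product f_Ax (sketch)."""
    x = torch.zeros_like(b)
    r = b.clone().detach()
    p = b.clone().detach()
    rdotr = torch.dot(r, r)
    for _ in range(cg_iters):
        Avp = f_Ax(p)
        alpha = rdotr / torch.dot(p, Avp)
        x = x + alpha * p
        r = r - alpha * Avp
        new_rdotr = torch.dot(r, r)
        if new_rdotr.item() < residual_tol:
            break
        p = r + (new_rdotr / rdotr) * p
        rdotr = new_rdotr
    return x.detach()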
Code example #3
    def step(self,
             episodes,
             max_kl=1e-3,
             cg_iters=10,
             cg_damping=1e-2,
             ls_max_steps=10,
             ls_backtrack_ratio=0.5):
        """Meta-optimization step (ie. update of the initial parameters), based 
        on Trust Region Policy Optimization (TRPO, [4]).
        """
        old_loss, _, old_pis = self.surrogate_loss(episodes)
        grads = torch.autograd.grad(old_loss, self.policy.parameters())
        grads = parameters_to_vector(grads)

        # Compute the step direction with Conjugate Gradient
        hessian_vector_product = self.hessian_vector_product(
            episodes, damping=cg_damping)
        stepdir = conjugate_gradient(hessian_vector_product,
                                     grads,
                                     cg_iters=cg_iters)

        # Compute the Lagrange multiplier
        shs = 0.5 * torch.dot(stepdir, hessian_vector_product(stepdir))
        lagrange_multiplier = torch.sqrt(shs / max_kl)

        step = stepdir / lagrange_multiplier

        # Save the old parameters
        old_params = parameters_to_vector(self.policy.parameters())

        # Line search
        step_size = 2.0
        for _ in range(ls_max_steps):
            vector_to_parameters(old_params - step_size * step,
                                 self.policy.parameters())
            loss, kl, _ = self.surrogate_loss(episodes, old_pis=old_pis)
            improve = loss - old_loss
            if (improve.item() < 0.0) and (kl.item() < max_kl):
                # if improve.item() < 0.0:
                print("New Actor surrogate_loss: ", loss)
                break
            step_size *= ls_backtrack_ratio
        else:
            print("same actor~~~~")
            vector_to_parameters(old_params, self.policy.parameters())
            if self.policy.paramsFlag == OrderedDict(
                    self.policy.named_parameters()):
                print("really same~~~~~~~~")
Code example #4
    def step(self,
             episodes,
             max_kl=1e-3,
             cg_iters=10,
             cg_damping=1e-2,
             ls_max_steps=10,
             ls_backtrack_ratio=0.5):
        """Meta-optimization step (ie. update of the initial parameters), based 
        on Trust Region Policy Optimization (TRPO, [4]).
        """
        old_loss, _, old_pis = self.surrogate_loss(episodes)
        grads = torch.autograd.grad(old_loss, self.policy.parameters())
        grads = parameters_to_vector(grads)

        # Compute the step direction with Conjugate Gradient
        hessian_vector_product = self.hessian_vector_product(
            episodes, damping=cg_damping)
        stepdir = conjugate_gradient(hessian_vector_product,
                                     grads,
                                     cg_iters=cg_iters)

        # Compute the Lagrange multiplier
        shs = 0.5 * torch.dot(stepdir, hessian_vector_product(stepdir))
        lagrange_multiplier = torch.sqrt(shs / max_kl)

        step = stepdir / lagrange_multiplier

        # Save the old parameters
        old_params = parameters_to_vector(self.policy.parameters())

        loss = None

        # Line search
        step_size = 1.0
        for _ in range(ls_max_steps):
            vector_to_parameters(old_params - step_size * step,
                                 self.policy.parameters())
            loss, kl, _ = self.surrogate_loss(episodes, old_pis=old_pis)
            improve = loss - old_loss
            if (improve.item() < 0.0) and (kl.item() < max_kl):
                break
            step_size *= ls_backtrack_ratio
        else:
            vector_to_parameters(old_params, self.policy.parameters())

        wandb.log({'step_end_loss': loss}, commit=False)
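
Code examples #2 through #4 (and those that follow) share the same backtracking line search: the first scaled step that both lowers the surrogate loss and keeps the KL divergence inside the trust region is accepted, and the old parameters are restored if every candidate fails. In isolation the pattern looks like the sketch below, where `evaluate(params)` is an assumed helper returning `(loss, kl)` for a flat parameter vector:

def backtracking_line_search(evaluate, old_params, full_step, old_loss,
                             max_kl, max_steps=10, backtrack_ratio=0.5):
    """Shrink full_step until the loss improves and the KL constraint holds;
    fall back to old_params otherwise (sketch only)."""
    step_size = 1.0
    for _ in range(max_steps):
        candidate = old_params - step_size * full_step
        loss, kl = evaluate(candidate)
        if (loss - old_loss).item() < 0.0 and kl.item() < max_kl:
            return candidate, step_size
        step_size *= backtrack_ratio
    return old_params, 0.0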
Code example #5
File: metalearner.py  Project: NagisaZj/exp_maml
    def Conjugate_gradient_descent(self,
                                   episodes,
                                   old_loss,
                                   old_pis,
                                   max_kl=1e-3,
                                   cg_iters=10,
                                   cg_damping=1e-2,
                                   ls_max_steps=10,
                                   ls_backtrack_ratio=0.5):
        grads = torch.autograd.grad(old_loss, self.policy.parameters())
        grads = parameters_to_vector(grads)

        # Compute the step direction with Conjugate Gradient
        hessian_vector_product = self.hessian_vector_product(
            episodes, damping=cg_damping)
        stepdir = conjugate_gradient(hessian_vector_product,
                                     grads,
                                     cg_iters=cg_iters)

        # Compute the Lagrange multiplier
        shs = 0.5 * torch.dot(stepdir, hessian_vector_product(stepdir))
        lagrange_multiplier = torch.sqrt(shs / max_kl)

        step = stepdir / lagrange_multiplier

        # Save the old parameters
        old_params = parameters_to_vector(self.policy.parameters())

        # Line search
        step_size = 1.0
        for _ in range(ls_max_steps):
            vector_to_parameters(old_params - step_size * step,
                                 self.policy.parameters())
            loss, kl, _ = self.surrogate_loss(episodes, old_pis=old_pis)
            improve = loss - old_loss
            if (improve.item() < 0.0) and (kl.item() < max_kl):
                break
            step_size *= ls_backtrack_ratio
        else:
            vector_to_parameters(old_params, self.policy.parameters())
Code example #6
    def step(self, episodes, args):
        """Meta-optimization step (ie. update of the initial parameters), based 
        on Trust Region Policy Optimization (TRPO, [4]).
        """
        # Compute initial surrogate loss assuming old_pi and pi are the same
        old_loss, _, old_pis = self.surrogate_loss(episodes)
        grads = torch.autograd.grad(old_loss, self.policy.parameters())
        grads = parameters_to_vector(grads)
        
        if args.first_order:
            raise ValueError("no first order")
            # Unreachable first-order fallback kept from the original code:
            step = grads
        else:
            # Compute the step direction with Conjugate Gradient
            hessian_vector_product = self.hessian_vector_product(episodes, damping=args.cg_damping)
            stepdir = conjugate_gradient(hessian_vector_product, grads, cg_iters=args.cg_iters)

            # Compute the Lagrange multiplier
            shs = 0.5 * torch.dot(stepdir, hessian_vector_product(stepdir))
            lagrange_multiplier = torch.sqrt(shs / args.max_kl)

            step = stepdir / lagrange_multiplier

        # Save the old parameters
        old_params = parameters_to_vector(self.policy.parameters())

        # Line search
        step_size = 1.0
        for _ in range(args.ls_max_steps):
            vector_to_parameters(old_params - step_size * step, self.policy.parameters())
            loss, kl, _ = self.surrogate_loss(episodes, old_pis=old_pis)
            improve = loss - old_loss
            if (improve.item() < 0.0) and (kl.item() < args.max_kl):
                break
            step_size *= args.ls_backtrack_ratio
        else:
            vector_to_parameters(old_params, self.policy.parameters())
Code example #7
File: maml_trpo.py  Project: imhgchoi/MAML-RL
    def step(self,
             train_futures,
             valid_futures,
             max_kl=1e-3,
             cg_iters=10,
             cg_damping=1e-2,
             ls_max_steps=10,
             ls_backtrack_ratio=0.5):
        num_tasks = len(train_futures[0])
        logs = {}
        # Compute the surrogate loss
        old_losses, old_kls, old_pis = self._async_gather([
            self.surrogate_loss(train, valid, old_pi=None)
            for (train, valid) in zip(zip(*train_futures), valid_futures)
        ])

        logs['loss_before'] = to_numpy(old_losses)
        logs['kl_before'] = to_numpy(old_kls)

        old_loss = sum(old_losses) / num_tasks
        grads = torch.autograd.grad(old_loss,
                                    self.policy.parameters(),
                                    retain_graph=True)
        grads = parameters_to_vector(grads)

        # Compute the step direction with Conjugate Gradient
        old_kl = sum(old_kls) / num_tasks
        hessian_vector_product = self.hessian_vector_product(
            old_kl, damping=cg_damping)
        stepdir = conjugate_gradient(hessian_vector_product,
                                     grads,
                                     cg_iters=cg_iters)

        # Compute the Lagrange multiplier
        shs = 0.5 * torch.dot(
            stepdir, hessian_vector_product(stepdir, retain_graph=False))
        lagrange_multiplier = torch.sqrt(shs / max_kl)

        step = stepdir / lagrange_multiplier

        # Save the old parameters
        old_params = parameters_to_vector(self.policy.parameters())

        # Line search
        step_size = 1.0
        for _ in range(ls_max_steps):
            vector_to_parameters(old_params - step_size * step,
                                 self.policy.parameters())

            losses, kls, _ = self._async_gather([
                self.surrogate_loss(train, valid, old_pi=old_pi)
                for (train, valid, old_pi) in zip(zip(
                    *train_futures), valid_futures, old_pis)
            ])

            improve = (sum(losses) / num_tasks) - old_loss
            kl = sum(kls) / num_tasks
            if (improve.item() < 0.0) and (kl.item() < max_kl):
                logs['loss_after'] = to_numpy(losses)
                logs['kl_after'] = to_numpy(kls)
                break
            step_size *= ls_backtrack_ratio
        else:
            vector_to_parameters(old_params, self.policy.parameters())

        return logs
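
Code example #7 passes a pre-computed mean KL term to `self.hessian_vector_product(old_kl, damping=cg_damping)`, while the earlier examples rebuild that KL from `episodes` inside the closure; both yield a damped Fisher-vector product. A minimal sketch of the underlying double-backprop pattern, assuming `kl` is a scalar mean KL divergence and `parameters` are the policy parameters (the bound methods in these projects may differ in detail):

import torch
from torch.nn.utils import parameters_to_vector


def make_hessian_vector_product(kl, parameters, damping=1e-2):
    """Return a closure v -> (H + damping * I) v, where H is the Hessian of
    kl w.r.t. parameters, i.e. the Fisher matrix at the current policy
    (sketch of the double-backprop pattern)."""
    parameters = list(parameters)
    grads = torch.autograd.grad(kl, parameters, create_graph=True)
    flat_grad_kl = parameters_to_vector(grads)

    def _product(vector, retain_graph=True):
        grad_kl_v = torch.dot(flat_grad_kl, vector)
        grad2s = torch.autograd.grad(grad_kl_v, parameters,
                                     retain_graph=retain_graph)
        return parameters_to_vector(grad2s) + damping * vector

    return _product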
Code example #8
    def compute_ng_gradient(self,
                            episodes,
                            max_kl=1e-3,
                            cg_iters=20,
                            cg_damping=1e-2,
                            ls_max_steps=10,
                            ls_backtrack_ratio=0.5):
        ng_grads = []
        for train_episodes, valid_episodes in episodes:
            params, step_size, step = self.adapt(train_episodes)

            # compute $ng_grad_0 = \nabla_x J^{lvc}(x) at x = \theta - \eta U(\theta)
            pi = self.policy(valid_episodes.observations, params=params)
            pi_detach = detach_distribution(pi)

            values = self.baseline(valid_episodes)
            advantages = valid_episodes.gae(values, tau=self.tau)
            advantages = weighted_normalize(advantages,
                                            weights=valid_episodes.mask)

            log_ratio = pi.log_prob(
                valid_episodes.actions) - pi_detach.log_prob(
                    valid_episodes.actions)
            if log_ratio.dim() > 2:
                log_ratio = torch.sum(log_ratio, dim=2)
            ratio = torch.exp(log_ratio)

            loss = -weighted_mean(
                ratio * advantages, dim=0, weights=valid_episodes.mask)

            ng_grad_0 = torch.autograd.grad(
                loss, self.policy.parameters())  # no create graph
            ng_grad_0 = parameters_to_vector(ng_grad_0)
            # compute the inverse of the Fisher matrix at x = \theta times $ng_grad_0 with Conjugate Gradient
            hessian_vector_product = self.hessian_vector_product_ng(
                train_episodes, damping=cg_damping)
            F_inv_grad = conjugate_gradient(hessian_vector_product,
                                            ng_grad_0,
                                            cg_iters=cg_iters)

            # compute $ng_grad_1 = \nabla^2 J^{lvc}(x) at x = \theta times $F_inv_grad
            # create graph for higher-order differentiation
            # self.baseline.fit(train_episodes)
            loss = self.inner_loss(train_episodes)
            grad = torch.autograd.grad(loss,
                                       self.policy.parameters(),
                                       create_graph=True)
            grad = parameters_to_vector(grad)
            grad_F_inv_grad = torch.dot(grad, F_inv_grad.detach())
            ng_grad_1 = torch.autograd.grad(grad_F_inv_grad,
                                            self.policy.parameters())
            ng_grad_1 = parameters_to_vector(ng_grad_1)
            # compute $ng_grad_2 = the Jacobian of {F(x) U(\theta)} at x = \theta times $F_inv_grad
            hessian_vector_product = self.hessian_vector_product_ng(
                train_episodes, damping=cg_damping)
            F_U = hessian_vector_product(step)
            ng_grad_2 = torch.autograd.grad(
                torch.dot(F_U, F_inv_grad.detach()), self.policy.parameters())
            ng_grad_2 = parameters_to_vector(ng_grad_2)
            ng_grad = ng_grad_0 - step_size * (ng_grad_1 + ng_grad_2)

            ng_grad = parameters_to_vector(ng_grad)
            ng_grads.append(ng_grad.view(len(ng_grad), 1))

        return torch.mean(torch.stack(ng_grads, dim=1), dim=[1, 2])
Code example #9
    def step(self,
             episodes,
             max_kl=1e-3,
             cg_iters=10,
             cg_damping=1e-2,
             ls_max_steps=10,
             ls_backtrack_ratio=0.5):
        """Meta-optimization step (ie. update of the initial parameters), based 
        on Trust Region Policy Optimization (TRPO, [4]).
        """
        # if self.tree:
        #     # print(self.tasks)
        #
        #     updated_episodes = []
        #     count = 0
        #     for train_episodes, valid_episodes in episodes:
        #         _, embeddings = self.tree.forward(torch.from_numpy(np.array(self.tasks[count])))
        #         print("e", embeddings)
        #         # print(train_episodes.observations)
        #         print("prev", train_episodes.observations.shape)
        #         teo_list = []
        #         for episode in train_episodes.observations:
        #             # print("episode", episode)
        #             te = torch.t(torch.stack([torch.cat([torch.from_numpy(np.array(teo)), embeddings[0]], 0) for teo in episode], 1))
        #             # print("stacked", te)
        #             # print(te)
        #             # print(te.shape)
        #             teo_list.append(te)
        #         train_episodes._observations = torch.stack(teo_list, 0)
        #         print("augmented", train_episodes.observations.shape)
        #
        #         teo_list = []
        #         for episode in valid_episodes.observations:
        #             # print("episode", episode)
        #             te = torch.t(
        #                 torch.stack([torch.cat([torch.from_numpy(np.array(teo)), embeddings[0]], 0) for teo in episode],
        #                             1))
        #             # print("stacked", te)
        #             # print(te)
        #             # print(te.shape)
        #             teo_list.append(te)
        #         valid_episodes._observations = torch.stack(teo_list, 0)
        #         count += 1
        #         updated_episodes.append((train_episodes, valid_episodes))
        #     episodes = updated_episodes

        old_loss, _, old_pis = self.surrogate_loss(episodes)
        grads = torch.autograd.grad(old_loss, self.policy.parameters())
        grads = parameters_to_vector(grads)

        # Compute the step direction with Conjugate Gradient
        hessian_vector_product = self.hessian_vector_product(
            episodes, damping=cg_damping)
        stepdir = conjugate_gradient(hessian_vector_product,
                                     grads,
                                     cg_iters=cg_iters)

        # Compute the Lagrange multiplier
        shs = 0.5 * torch.dot(stepdir, hessian_vector_product(stepdir))
        lagrange_multiplier = torch.sqrt(shs / max_kl)

        step = stepdir / lagrange_multiplier

        # Save the old parameters
        old_params = parameters_to_vector(self.policy.parameters())

        # Line search
        step_size = 1.0
        for _ in range(ls_max_steps):
            vector_to_parameters(old_params - step_size * step,
                                 self.policy.parameters())
            loss, kl, _ = self.surrogate_loss(episodes, old_pis=old_pis)
            improve = loss - old_loss
            if (improve.item() < 0.0) and (kl.item() < max_kl):
                break
            step_size *= ls_backtrack_ratio
        else:
            vector_to_parameters(old_params, self.policy.parameters())

        print("loss:", loss)