def adapt_ng(self, episodes, first_order=False, max_kl=1e-3, cg_iters=20,
             cg_damping=1e-2, ls_max_steps=10, ls_backtrack_ratio=0.5):
    """Adapt the parameters of the policy network to a new task, from sampled
    trajectories `episodes`, with a one-step natural gradient update.
    """
    # Fit the baseline to the training episodes
    self.baseline.fit(episodes)
    # Get the loss on the training episodes
    loss_lvc = self.inner_loss_lvc(episodes)

    # Get the new parameters after a one-step natural gradient update
    grads = torch.autograd.grad(loss_lvc, self.policy.parameters())
    grads = parameters_to_vector(grads)

    # Compute the step direction with Conjugate Gradient
    hessian_vector_product = self.hessian_vector_product_ng(episodes,
                                                            damping=cg_damping)
    stepdir = conjugate_gradient(hessian_vector_product, grads,
                                 cg_iters=cg_iters).detach()
    if self.verbose:
        # Relative residual of the conjugate gradient solve
        print(torch.norm(hessian_vector_product(stepdir) - grads)
              / torch.norm(grads))

    # Step size from the KL constraint (Lagrange multiplier)
    shs = 0.5 * stepdir.dot(hessian_vector_product(stepdir))
    lm = torch.sqrt(max_kl / shs)
    if self.verbose:
        print("learning rate {}".format(lm))

    stepdir_named = vector_to_named_parameter_like(
        stepdir, self.policy.named_parameters())
    step_size = lm.detach()

    params = OrderedDict()
    for (name, param) in self.policy.named_parameters():
        params[name] = param - step_size * stepdir_named[name]

    if self.verbose:
        # Compute the KL divergence between the old and the adapted policy
        with torch.autograd.no_grad():
            pi = self.policy(episodes.observations, params=params)
            pi_old = self.policy(episodes.observations)
            kl = kl_divergence(pi_old, pi).mean()
            print(kl)

    return params, step_size, stepdir
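# The `conjugate_gradient` helper called above (and in the `step` variants below) is not
# shown in this section. The sketch below is a standard conjugate-gradient solver with a
# matching call signature, solving F x = g given only the matrix-vector product `f_Ax`;
# the `residual_tol` early-exit threshold is an assumption, not taken from the original code.
import torch

def conjugate_gradient(f_Ax, b, cg_iters=10, residual_tol=1e-10):
    x = torch.zeros_like(b)
    r = b.clone().detach()   # residual of F x = b at x = 0
    p = r.clone()            # initial search direction
    rdotr = torch.dot(r, r)
    for _ in range(cg_iters):
        Ap = f_Ax(p).detach()
        alpha = rdotr / torch.dot(p, Ap)
        x = x + alpha * p
        r = r - alpha * Ap
        new_rdotr = torch.dot(r, r)
        if new_rdotr < residual_tol:
            break
        p = r + (new_rdotr / rdotr) * p
        rdotr = new_rdotr
    return x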
def step(self, episodes, max_kl=1e-3, cg_iters=10, cg_damping=1e-2,
         ls_max_steps=10, ls_backtrack_ratio=0.5):
    """Meta-optimization step (i.e. update of the initial parameters), based
    on Trust Region Policy Optimization (TRPO, [4]).
    """
    old_loss, _, old_pis = self.surrogate_loss(episodes)
    if old_loss is None:
        # Nothing needs to be done
        return
    grads = torch.autograd.grad(old_loss, self.policy.parameters())
    grads = parameters_to_vector(grads)

    # Compute the step direction with Conjugate Gradient
    hessian_vector_product = self.hessian_vector_product(episodes,
                                                         damping=cg_damping)
    stepdir = conjugate_gradient(hessian_vector_product, grads,
                                 cg_iters=cg_iters)

    # Compute the Lagrange multiplier
    shs = 0.5 * torch.dot(stepdir, hessian_vector_product(stepdir))
    lagrange_multiplier = torch.sqrt(shs / max_kl)

    step = stepdir / lagrange_multiplier

    # Save the old parameters
    old_params = parameters_to_vector(self.policy.parameters())

    # Line search
    step_size = 1.0
    for _ in range(ls_max_steps):
        vector_to_parameters(old_params - step_size * step,
                             self.policy.parameters())
        loss, kl, _ = self.surrogate_loss(episodes, old_pis=old_pis)
        improve = loss - old_loss
        # Accept the step if the new loss is smaller and the KL divergence is
        # small enough (so the new policy is not too far from the old one)
        if (improve.item() < 0.0) and (kl.item() < max_kl):
            break
        step_size *= ls_backtrack_ratio
    else:
        vector_to_parameters(old_params, self.policy.parameters())
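# `self.hessian_vector_product(episodes, damping=...)` above returns a closure computing
# (F + damping * I) v, where F is the Fisher information matrix approximated by the Hessian
# of the mean KL divergence between the current policy and a detached copy of itself.
# The sketch below follows the standard TRPO construction; `self.kl_divergence(episodes)`
# is assumed to return that mean KL and is a hypothetical name here.
def hessian_vector_product(self, episodes, damping=1e-2):
    def _product(vector, retain_graph=True):
        kl = self.kl_divergence(episodes)
        # First backward pass: gradient of the KL w.r.t. the policy parameters
        grads = torch.autograd.grad(kl, self.policy.parameters(),
                                    create_graph=True)
        flat_grad_kl = parameters_to_vector(grads)
        # Second backward pass on the inner product gives the Hessian-vector product
        grad_kl_v = torch.dot(flat_grad_kl, vector)
        grad2s = torch.autograd.grad(grad_kl_v, self.policy.parameters(),
                                     retain_graph=retain_graph)
        flat_grad2_kl = parameters_to_vector(grad2s)
        return flat_grad2_kl + damping * vector
    return _product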
def step(self, episodes, max_kl=1e-3, cg_iters=10, cg_damping=1e-2,
         ls_max_steps=10, ls_backtrack_ratio=0.5):
    """Meta-optimization step (i.e. update of the initial parameters), based
    on Trust Region Policy Optimization (TRPO, [4]).
    """
    old_loss, _, old_pis = self.surrogate_loss(episodes)
    grads = torch.autograd.grad(old_loss, self.policy.parameters())
    grads = parameters_to_vector(grads)

    # Compute the step direction with Conjugate Gradient
    hessian_vector_product = self.hessian_vector_product(episodes,
                                                         damping=cg_damping)
    stepdir = conjugate_gradient(hessian_vector_product, grads,
                                 cg_iters=cg_iters)

    # Compute the Lagrange multiplier
    shs = 0.5 * torch.dot(stepdir, hessian_vector_product(stepdir))
    lagrange_multiplier = torch.sqrt(shs / max_kl)

    step = stepdir / lagrange_multiplier

    # Save the old parameters
    old_params = parameters_to_vector(self.policy.parameters())

    # Line search (note the initial step size of 2.0, twice the 1.0 used in the
    # other variants in this section)
    step_size = 2.0
    for _ in range(ls_max_steps):
        vector_to_parameters(old_params - step_size * step,
                             self.policy.parameters())
        loss, kl, _ = self.surrogate_loss(episodes, old_pis=old_pis)
        improve = loss - old_loss
        if (improve.item() < 0.0) and (kl.item() < max_kl):
            # if improve.item() < 0.0:
            print("New actor surrogate loss:", loss)
            break
        step_size *= ls_backtrack_ratio
    else:
        print("Line search failed: keeping the old actor parameters")
        vector_to_parameters(old_params, self.policy.parameters())
        if self.policy.paramsFlag == OrderedDict(self.policy.named_parameters()):
            print("Restored parameters match the saved paramsFlag snapshot")
def step(self, episodes, max_kl=1e-3, cg_iters=10, cg_damping=1e-2,
         ls_max_steps=10, ls_backtrack_ratio=0.5):
    """Meta-optimization step (i.e. update of the initial parameters), based
    on Trust Region Policy Optimization (TRPO, [4]).
    """
    old_loss, _, old_pis = self.surrogate_loss(episodes)
    grads = torch.autograd.grad(old_loss, self.policy.parameters())
    grads = parameters_to_vector(grads)

    # Compute the step direction with Conjugate Gradient
    hessian_vector_product = self.hessian_vector_product(episodes,
                                                         damping=cg_damping)
    stepdir = conjugate_gradient(hessian_vector_product, grads,
                                 cg_iters=cg_iters)

    # Compute the Lagrange multiplier
    shs = 0.5 * torch.dot(stepdir, hessian_vector_product(stepdir))
    lagrange_multiplier = torch.sqrt(shs / max_kl)

    step = stepdir / lagrange_multiplier

    # Save the old parameters
    old_params = parameters_to_vector(self.policy.parameters())
    loss = None

    # Line search
    step_size = 1.0
    for _ in range(ls_max_steps):
        vector_to_parameters(old_params - step_size * step,
                             self.policy.parameters())
        loss, kl, _ = self.surrogate_loss(episodes, old_pis=old_pis)
        improve = loss - old_loss
        if (improve.item() < 0.0) and (kl.item() < max_kl):
            break
        step_size *= ls_backtrack_ratio
    else:
        vector_to_parameters(old_params, self.policy.parameters())

    # Log the surrogate loss at the end of the line search
    wandb.log({'step_end_loss': loss}, commit=False)
def Conjugate_gradient_descent(self, episodes, old_loss, old_pis, max_kl=1e-3,
                               cg_iters=10, cg_damping=1e-2, ls_max_steps=10,
                               ls_backtrack_ratio=0.5):
    """TRPO-style update given a precomputed surrogate loss `old_loss`: compute
    the step direction with conjugate gradient, then backtrack with a line search.
    """
    grads = torch.autograd.grad(old_loss, self.policy.parameters())
    grads = parameters_to_vector(grads)

    # Compute the step direction with Conjugate Gradient
    hessian_vector_product = self.hessian_vector_product(episodes,
                                                         damping=cg_damping)
    stepdir = conjugate_gradient(hessian_vector_product, grads,
                                 cg_iters=cg_iters)

    # Compute the Lagrange multiplier
    shs = 0.5 * torch.dot(stepdir, hessian_vector_product(stepdir))
    lagrange_multiplier = torch.sqrt(shs / max_kl)

    step = stepdir / lagrange_multiplier

    # Save the old parameters
    old_params = parameters_to_vector(self.policy.parameters())

    # Line search
    step_size = 1.0
    for _ in range(ls_max_steps):
        vector_to_parameters(old_params - step_size * step,
                             self.policy.parameters())
        loss, kl, _ = self.surrogate_loss(episodes, old_pis=old_pis)
        improve = loss - old_loss
        if (improve.item() < 0.0) and (kl.item() < max_kl):
            break
        step_size *= ls_backtrack_ratio
    else:
        vector_to_parameters(old_params, self.policy.parameters())
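# Hypothetical usage of `Conjugate_gradient_descent`: unlike the `step` variants, the caller
# computes the surrogate loss (and the corresponding policy distributions) once and passes
# them in. `learner` is a placeholder for the object exposing these methods.
old_loss, _, old_pis = learner.surrogate_loss(episodes)
learner.Conjugate_gradient_descent(episodes, old_loss, old_pis,
                                   max_kl=1e-3, cg_iters=10, cg_damping=1e-2)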
def step(self, episodes, args):
    """Meta-optimization step (i.e. update of the initial parameters), based
    on Trust Region Policy Optimization (TRPO, [4]).
    """
    # Compute the initial surrogate loss, assuming old_pi and pi are the same
    old_loss, _, old_pis = self.surrogate_loss(episodes)
    grads = torch.autograd.grad(old_loss, self.policy.parameters())
    grads = parameters_to_vector(grads)

    if args.first_order:
        raise ValueError("no first order")
        step = grads  # unreachable: left over from the original first-order branch
    else:
        # Compute the step direction with Conjugate Gradient
        hessian_vector_product = self.hessian_vector_product(
            episodes, damping=args.cg_damping)
        stepdir = conjugate_gradient(hessian_vector_product, grads,
                                     cg_iters=args.cg_iters)

        # Compute the Lagrange multiplier
        shs = 0.5 * torch.dot(stepdir, hessian_vector_product(stepdir))
        lagrange_multiplier = torch.sqrt(shs / args.max_kl)

        step = stepdir / lagrange_multiplier

    # Save the old parameters
    old_params = parameters_to_vector(self.policy.parameters())

    # Line search
    step_size = 1.0
    for _ in range(args.ls_max_steps):
        vector_to_parameters(old_params - step_size * step,
                             self.policy.parameters())
        loss, kl, _ = self.surrogate_loss(episodes, old_pis=old_pis)
        improve = loss - old_loss
        if (improve.item() < 0.0) and (kl.item() < args.max_kl):
            break
        step_size *= args.ls_backtrack_ratio
    else:
        vector_to_parameters(old_params, self.policy.parameters())
def step(self, train_futures, valid_futures, max_kl=1e-3, cg_iters=10,
         cg_damping=1e-2, ls_max_steps=10, ls_backtrack_ratio=0.5):
    """Meta-optimization step over a batch of tasks, with per-task training
    and validation episodes provided as futures.
    """
    num_tasks = len(train_futures[0])
    logs = {}

    # Compute the surrogate loss
    old_losses, old_kls, old_pis = self._async_gather([
        self.surrogate_loss(train, valid, old_pi=None)
        for (train, valid) in zip(zip(*train_futures), valid_futures)
    ])

    logs['loss_before'] = to_numpy(old_losses)
    logs['kl_before'] = to_numpy(old_kls)

    old_loss = sum(old_losses) / num_tasks
    grads = torch.autograd.grad(old_loss, self.policy.parameters(),
                                retain_graph=True)
    grads = parameters_to_vector(grads)

    # Compute the step direction with Conjugate Gradient
    old_kl = sum(old_kls) / num_tasks
    hessian_vector_product = self.hessian_vector_product(old_kl,
                                                         damping=cg_damping)
    stepdir = conjugate_gradient(hessian_vector_product, grads,
                                 cg_iters=cg_iters)

    # Compute the Lagrange multiplier
    shs = 0.5 * torch.dot(stepdir,
                          hessian_vector_product(stepdir, retain_graph=False))
    lagrange_multiplier = torch.sqrt(shs / max_kl)

    step = stepdir / lagrange_multiplier

    # Save the old parameters
    old_params = parameters_to_vector(self.policy.parameters())

    # Line search
    step_size = 1.0
    for _ in range(ls_max_steps):
        vector_to_parameters(old_params - step_size * step,
                             self.policy.parameters())

        losses, kls, _ = self._async_gather([
            self.surrogate_loss(train, valid, old_pi=old_pi)
            for (train, valid, old_pi)
            in zip(zip(*train_futures), valid_futures, old_pis)
        ])

        improve = (sum(losses) / num_tasks) - old_loss
        kl = sum(kls) / num_tasks
        if (improve.item() < 0.0) and (kl.item() < max_kl):
            logs['loss_after'] = to_numpy(losses)
            logs['kl_after'] = to_numpy(kls)
            break
        step_size *= ls_backtrack_ratio
    else:
        vector_to_parameters(old_params, self.policy.parameters())

    return logs
def compute_ng_gradient(self, episodes, max_kl=1e-3, cg_iters=20,
                        cg_damping=1e-2, ls_max_steps=10,
                        ls_backtrack_ratio=0.5):
    ng_grads = []
    for train_episodes, valid_episodes in episodes:
        params, step_size, step = self.adapt(train_episodes)

        # Compute grad = \nabla_x J^{lvc}(x) at x = \theta - \eta U(\theta),
        # i.e. the gradient of the validation loss at the adapted parameters
        pi = self.policy(valid_episodes.observations, params=params)
        pi_detach = detach_distribution(pi)

        values = self.baseline(valid_episodes)
        advantages = valid_episodes.gae(values, tau=self.tau)
        advantages = weighted_normalize(advantages,
                                        weights=valid_episodes.mask)

        log_ratio = (pi.log_prob(valid_episodes.actions)
                     - pi_detach.log_prob(valid_episodes.actions))
        if log_ratio.dim() > 2:
            log_ratio = torch.sum(log_ratio, dim=2)
        ratio = torch.exp(log_ratio)

        loss = -weighted_mean(ratio * advantages, dim=0,
                              weights=valid_episodes.mask)
        ng_grad_0 = torch.autograd.grad(loss, self.policy.parameters())  # no create_graph
        ng_grad_0 = parameters_to_vector(ng_grad_0)

        # Compute the inverse of the Fisher matrix at x = \theta times grad,
        # with Conjugate Gradient
        hessian_vector_product = self.hessian_vector_product_ng(
            train_episodes, damping=cg_damping)
        F_inv_grad = conjugate_gradient(hessian_vector_product, ng_grad_0,
                                        cg_iters=cg_iters)

        # Compute ng_grad_1 = \nabla^2 J^{lvc}(x) at x = \theta times F_inv_grad,
        # creating the graph for the second differentiation
        # self.baseline.fit(train_episodes)
        loss = self.inner_loss(train_episodes)
        grad = torch.autograd.grad(loss, self.policy.parameters(),
                                   create_graph=True)
        grad = parameters_to_vector(grad)
        grad_F_inv_grad = torch.dot(grad, F_inv_grad.detach())
        ng_grad_1 = torch.autograd.grad(grad_F_inv_grad,
                                        self.policy.parameters())
        ng_grad_1 = parameters_to_vector(ng_grad_1)

        # Compute ng_grad_2 = the Jacobian of {F(x) U(\theta)} at x = \theta
        # times F_inv_grad
        hessian_vector_product = self.hessian_vector_product_ng(
            train_episodes, damping=cg_damping)
        F_U = hessian_vector_product(step)
        ng_grad_2 = torch.autograd.grad(torch.dot(F_U, F_inv_grad.detach()),
                                        self.policy.parameters())
        ng_grad_2 = parameters_to_vector(ng_grad_2)

        ng_grad = ng_grad_0 - step_size * (ng_grad_1 + ng_grad_2)
        ng_grad = parameters_to_vector(ng_grad)
        ng_grads.append(ng_grad.view(len(ng_grad), 1))

    return torch.mean(torch.stack(ng_grads, dim=1), dim=[1, 2])
def step(self, episodes, max_kl=1e-3, cg_iters=10, cg_damping=1e-2,
         ls_max_steps=10, ls_backtrack_ratio=0.5):
    """Meta-optimization step (i.e. update of the initial parameters), based
    on Trust Region Policy Optimization (TRPO, [4]).
    """
    # if self.tree:
    #     # Augment observations with task embeddings from the tree
    #     updated_episodes = []
    #     count = 0
    #     for train_episodes, valid_episodes in episodes:
    #         _, embeddings = self.tree.forward(
    #             torch.from_numpy(np.array(self.tasks[count])))
    #         print("e", embeddings)
    #         print("prev", train_episodes.observations.shape)
    #         teo_list = []
    #         for episode in train_episodes.observations:
    #             te = torch.t(torch.stack(
    #                 [torch.cat([torch.from_numpy(np.array(teo)), embeddings[0]], 0)
    #                  for teo in episode], 1))
    #             teo_list.append(te)
    #         train_episodes._observations = torch.stack(teo_list, 0)
    #         print("augmented", train_episodes.observations.shape)
    #
    #         teo_list = []
    #         for episode in valid_episodes.observations:
    #             te = torch.t(torch.stack(
    #                 [torch.cat([torch.from_numpy(np.array(teo)), embeddings[0]], 0)
    #                  for teo in episode], 1))
    #             teo_list.append(te)
    #         valid_episodes._observations = torch.stack(teo_list, 0)
    #         count += 1
    #         updated_episodes.append((train_episodes, valid_episodes))
    #     episodes = updated_episodes

    old_loss, _, old_pis = self.surrogate_loss(episodes)
    grads = torch.autograd.grad(old_loss, self.policy.parameters())
    grads = parameters_to_vector(grads)

    # Compute the step direction with Conjugate Gradient
    hessian_vector_product = self.hessian_vector_product(episodes,
                                                         damping=cg_damping)
    stepdir = conjugate_gradient(hessian_vector_product, grads,
                                 cg_iters=cg_iters)

    # Compute the Lagrange multiplier
    shs = 0.5 * torch.dot(stepdir, hessian_vector_product(stepdir))
    lagrange_multiplier = torch.sqrt(shs / max_kl)

    step = stepdir / lagrange_multiplier

    # Save the old parameters
    old_params = parameters_to_vector(self.policy.parameters())

    # Line search
    step_size = 1.0
    for _ in range(ls_max_steps):
        vector_to_parameters(old_params - step_size * step,
                             self.policy.parameters())
        loss, kl, _ = self.surrogate_loss(episodes, old_pis=old_pis)
        improve = loss - old_loss
        if (improve.item() < 0.0) and (kl.item() < max_kl):
            break
        step_size *= ls_backtrack_ratio
    else:
        vector_to_parameters(old_params, self.policy.parameters())

    print("loss:", loss)
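# A hedged sketch of the outer meta-training loop that drives these `step` variants: sample a
# batch of tasks, collect per-task (train, valid) episodes, then apply one TRPO meta-update.
# The `sampler` / `metalearner` objects and their `sample_tasks` / `sample` methods are
# assumptions for illustration; they do not appear in this section.
for batch in range(num_batches):
    tasks = sampler.sample_tasks(num_tasks=meta_batch_size)
    episodes = metalearner.sample(tasks)   # list of (train_episodes, valid_episodes) pairs
    metalearner.step(episodes, max_kl=1e-3, cg_iters=10, cg_damping=1e-2,
                     ls_max_steps=10, ls_backtrack_ratio=0.5)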