def trpo_update(replay, policy, baseline):
    gamma = 0.99
    tau = 0.95
    max_kl = 0.01
    ls_max_steps = 15
    backtrack_factor = 0.5
    old_policy = deepcopy(policy)
    for step in range(10):
        states = replay.state()
        actions = replay.action()
        rewards = replay.reward()
        dones = replay.done()
        next_states = replay.next_state()
        returns = ch.td.discount(gamma, rewards, dones)
        baseline.fit(states, returns)
        values = baseline(states)
        next_values = baseline(next_states)

        # Compute KL
        with th.no_grad():
            old_density = old_policy.density(states)
        new_density = policy.density(states)
        kl = kl_divergence(old_density, new_density).mean()

        # Compute surrogate loss
        old_log_probs = old_density.log_prob(actions).mean(dim=1, keepdim=True)
        new_log_probs = new_density.log_prob(actions).mean(dim=1, keepdim=True)
        bootstraps = values * (1.0 - dones) + next_values * dones
        advantages = ch.pg.generalized_advantage(gamma, tau, rewards, dones,
                                                 bootstraps, th.zeros(1))
        advantages = ch.normalize(advantages).detach()
        surr_loss = trpo.policy_loss(new_log_probs, old_log_probs, advantages)

        # Compute the update
        grad = autograd.grad(surr_loss, policy.parameters(), retain_graph=True)
        Fvp = trpo.hessian_vector_product(kl, policy.parameters())
        grad = parameters_to_vector(grad).detach()
        step = trpo.conjugate_gradient(Fvp, grad)
        lagrange_mult = 0.5 * th.dot(step, Fvp(step)) / max_kl
        step = step / lagrange_mult
        step_ = [th.zeros_like(p.data) for p in policy.parameters()]
        vector_to_parameters(step, step_)
        step = step_

        # Line-search
        for ls_step in range(ls_max_steps):
            stepsize = backtrack_factor**ls_step
            clone = deepcopy(policy)
            for c, u in zip(clone.parameters(), step):
                c.data.add_(-stepsize, u.data)
            new_density = clone.density(states)
            new_kl = kl_divergence(old_density, new_density).mean()
            new_log_probs = new_density.log_prob(actions).mean(dim=1, keepdim=True)
            new_loss = trpo.policy_loss(new_log_probs, old_log_probs, advantages)
            if new_loss < surr_loss and new_kl < max_kl:
                for p, c in zip(policy.parameters(), clone.parameters()):
                    p.data[:] = c.data[:]
                break
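# --- Usage sketch (not part of the original example) ------------------------
# A minimal, hedged illustration of how trpo_update above might be driven:
# collect a fresh on-policy replay with cherry's Runner, then apply one
# trust-region update to the policy. The environment name, step budget, and
# number of updates are illustrative choices; DiagNormalPolicy and LinearValue
# are assumed to be the same classes used by main() further below.
def run_trpo_sketch(env_name='Pendulum-v0', steps_per_update=2048, n_updates=100):
    env = ch.envs.Torch(gym.make(env_name))
    policy = DiagNormalPolicy(env.state_size, env.action_size)
    baseline = LinearValue(env.state_size, env.action_size)
    runner = ch.envs.Runner(env)
    for _ in range(n_updates):
        replay = runner.run(policy, steps=steps_per_update)  # on-policy rollout
        trpo_update(replay, policy, baseline)                # one TRPO update on it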
def test_trpo_policy_loss(self):
    for shape in [(1, BSZ), (BSZ, 1), (BSZ, )]:
        new_log_probs = th.randn(BSZ)
        old_log_probs = th.randn(BSZ)
        advantages = th.randn(BSZ)
        ref = trpo.policy_loss(new_log_probs, old_log_probs, advantages)
        loss = trpo.policy_loss(new_log_probs.view(*shape),
                                old_log_probs.view(*shape),
                                advantages.view(*shape))
        self.assertAlmostEqual(loss.item(), ref.item())
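# --- Reference formula sketch (not from the original test suite) ------------
# The quantity exercised by test_trpo_policy_loss above is, conceptually, the
# negative importance-weighted advantage. A hedged, self-contained reference
# version (the actual cherry implementation may differ in details):
def reference_trpo_policy_loss(new_log_probs, old_log_probs, advantages):
    ratio = th.exp(new_log_probs - old_log_probs)  # importance weight pi_new / pi_old
    return -th.mean(ratio * advantages)            # negated so minimizing it maximizes the surrogate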
def meta_surrogate_loss(iteration_replays, iteration_policies, policy,
                        baseline, tau, gamma, adapt_lr):
    mean_loss = 0.0
    mean_kl = 0.0
    for task_replays, old_policy in tqdm(zip(iteration_replays, iteration_policies),
                                         total=len(iteration_replays),
                                         desc='Surrogate Loss',
                                         leave=False):
        policy.reset_context()
        train_replays = task_replays[:-1]
        valid_episodes = task_replays[-1]
        new_policy = l2l.clone_module(policy)

        # Fast Adapt
        for train_episodes in train_replays:
            new_policy = fast_adapt_a2c(new_policy, train_episodes, adapt_lr,
                                        baseline, gamma, tau, first_order=False)

        # Useful values
        states = valid_episodes.state()
        actions = valid_episodes.action()
        next_states = valid_episodes.next_state()
        rewards = valid_episodes.reward()
        dones = valid_episodes.done()

        # Compute KL
        old_densities = old_policy.density(states)
        new_densities = new_policy.density(states)
        kl = kl_divergence(new_densities, old_densities).mean()
        mean_kl += kl

        # Compute Surrogate Loss
        advantages = compute_advantages(baseline, tau, gamma, rewards,
                                        dones, states, next_states)
        advantages = ch.normalize(advantages).detach()
        old_log_probs = old_densities.log_prob(actions).mean(dim=1, keepdim=True).detach()
        new_log_probs = new_densities.log_prob(actions).mean(dim=1, keepdim=True)
        mean_loss += trpo.policy_loss(new_log_probs, old_log_probs, advantages)
    mean_kl /= len(iteration_replays)
    mean_loss /= len(iteration_replays)
    return mean_loss, mean_kl
def meta_surrogate_loss(iter_replays, iter_policies, policy, baseline, params, anil):
    mean_loss = 0.0
    mean_kl = 0.0
    for task_replays, old_policy in zip(iter_replays, iter_policies):
        train_replays = task_replays[:-1]
        valid_episodes = task_replays[-1]
        new_policy = clone_module(policy)

        # Fast Adapt to the training episodes
        for train_episodes in train_replays:
            new_policy = trpo_update(train_episodes, new_policy, baseline,
                                     params['inner_lr'], params['gamma'], params['tau'],
                                     anil=anil, first_order=False)

        # Useful values from the validation episodes
        states, actions, rewards, dones, next_states = get_episode_values(valid_episodes)

        # Compute KL
        old_densities = old_policy.density(states)
        new_densities = new_policy.density(states)
        kl = kl_divergence(new_densities, old_densities).mean()
        mean_kl += kl

        # Compute Surrogate Loss
        advantages = compute_advantages(baseline, params['tau'], params['gamma'],
                                        rewards, dones, states, next_states)
        advantages = ch.normalize(advantages).detach()
        old_log_probs = old_densities.log_prob(actions).mean(dim=1, keepdim=True).detach()
        new_log_probs = new_densities.log_prob(actions).mean(dim=1, keepdim=True)
        mean_loss += trpo.policy_loss(new_log_probs, old_log_probs, advantages)
    mean_kl /= len(iter_replays)
    mean_loss /= len(iter_replays)
    return mean_loss, mean_kl
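# --- Meta-step sketch (not part of the original examples) -------------------
# A hedged illustration of how the (loss, KL) pair returned by either
# meta_surrogate_loss variant above could drive one natural-gradient
# meta-update, reusing the same TRPO machinery as trpo_update. The backtracking
# line search over step sizes is omitted for brevity; max_kl and the
# replay/policy lists are assumed to come from an outer sampling loop
# such as the one in main() below.
def meta_step_sketch(iteration_replays, iteration_policies, policy, baseline,
                     tau, gamma, adapt_lr, max_kl=0.01):
    loss, kl = meta_surrogate_loss(iteration_replays, iteration_policies,
                                   policy, baseline, tau, gamma, adapt_lr)
    grad = autograd.grad(loss, policy.parameters(), retain_graph=True)
    grad = parameters_to_vector(grad).detach()
    Fvp = trpo.hessian_vector_product(kl, policy.parameters())
    step = trpo.conjugate_gradient(Fvp, grad)
    step = step / (0.5 * th.dot(step, Fvp(step)) / max_kl)  # same trust-region scaling as trpo_update
    step_ = [th.zeros_like(p.data) for p in policy.parameters()]
    vector_to_parameters(step, step_)
    for p, u in zip(policy.parameters(), step_):
        p.data.add_(-u.data)  # apply the (unchecked) natural-gradient step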
def main(
        env_name='AntDirection-v1',
        adapt_lr=0.1,
        meta_lr=3e-4,
        adapt_steps=3,
        num_iterations=1000,
        meta_bsz=40,
        adapt_bsz=20,
        ppo_clip=0.3,
        ppo_steps=5,
        tau=1.00,
        gamma=0.99,
        eta=0.0005,
        adaptive_penalty=False,
        kl_target=0.01,
        num_workers=4,
        seed=421,
):
    random.seed(seed)
    np.random.seed(seed)
    th.manual_seed(seed)

    def make_env():
        env = gym.make(env_name)
        env = ch.envs.ActionSpaceScaler(env)
        return env

    env = l2l.gym.AsyncVectorEnv([make_env for _ in range(num_workers)])
    env.seed(seed)
    env = ch.envs.ActionSpaceScaler(env)
    env = ch.envs.Torch(env)
    policy = DiagNormalPolicy(input_size=env.state_size,
                              output_size=env.action_size,
                              hiddens=[64, 64],
                              activation='tanh')
    meta_learner = l2l.algorithms.MAML(policy, lr=meta_lr)
    baseline = LinearValue(env.state_size, env.action_size)
    opt = optim.Adam(meta_learner.parameters(), lr=meta_lr)

    for iteration in range(num_iterations):
        iteration_reward = 0.0
        iteration_replays = []
        iteration_policies = []

        # Sample Trajectories
        for task_config in tqdm(env.sample_tasks(meta_bsz), leave=False, desc='Data'):
            clone = deepcopy(meta_learner)
            env.set_task(task_config)
            env.reset()
            task = ch.envs.Runner(env)
            task_replay = []
            task_policies = []

            # Fast Adapt
            for step in range(adapt_steps):
                for p in clone.parameters():
                    p.detach_().requires_grad_()
                task_policies.append(deepcopy(clone))
                train_episodes = task.run(clone, episodes=adapt_bsz)
                clone = fast_adapt_a2c(clone, train_episodes, adapt_lr,
                                       baseline, gamma, tau, first_order=True)
                task_replay.append(train_episodes)

            # Compute Validation Loss
            for p in clone.parameters():
                p.detach_().requires_grad_()
            task_policies.append(deepcopy(clone))
            valid_episodes = task.run(clone, episodes=adapt_bsz)
            task_replay.append(valid_episodes)
            iteration_reward += valid_episodes.reward().sum().item() / adapt_bsz
            iteration_replays.append(task_replay)
            iteration_policies.append(task_policies)

        # Print statistics
        print('\nIteration', iteration)
        adaptation_reward = iteration_reward / meta_bsz
        print('adaptation_reward', adaptation_reward)

        # ProMP meta-optimization
        for ppo_step in tqdm(range(ppo_steps), leave=False, desc='Optim'):
            promp_loss = 0.0
            kl_total = 0.0
            for task_replays, old_policies in zip(iteration_replays, iteration_policies):
                new_policy = meta_learner.clone()
                states = task_replays[0].state()
                actions = task_replays[0].action()
                rewards = task_replays[0].reward()
                dones = task_replays[0].done()
                next_states = task_replays[0].next_state()
                old_policy = old_policies[0]
                (old_density,
                 new_density,
                 old_log_probs,
                 new_log_probs) = precompute_quantities(states, actions,
                                                        old_policy, new_policy)
                advantages = compute_advantages(baseline, tau, gamma, rewards,
                                                dones, states, next_states)
                advantages = ch.normalize(advantages).detach()
                for step in range(adapt_steps):
                    # Compute KL penalty
                    kl_pen = kl_divergence(old_density, new_density).mean()
                    kl_total += kl_pen.item()

                    # Update the clone
                    surr_loss = trpo.policy_loss(new_log_probs, old_log_probs, advantages)
                    new_policy.adapt(surr_loss)

                    # Move to next adaptation step
                    states = task_replays[step + 1].state()
                    actions = task_replays[step + 1].action()
                    rewards = task_replays[step + 1].reward()
                    dones = task_replays[step + 1].done()
                    next_states = task_replays[step + 1].next_state()
                    old_policy = old_policies[step + 1]
                    (old_density,
                     new_density,
                     old_log_probs,
                     new_log_probs) = precompute_quantities(states, actions,
                                                            old_policy, new_policy)

                    # Compute clip loss
                    advantages = compute_advantages(baseline, tau, gamma, rewards,
                                                    dones, states, next_states)
                    advantages = ch.normalize(advantages).detach()
                    clip_loss = ppo.policy_loss(new_log_probs, old_log_probs,
                                                advantages, clip=ppo_clip)

                    # Combine into ProMP loss
                    promp_loss += clip_loss + eta * kl_pen

            kl_total /= meta_bsz * adapt_steps
            promp_loss /= meta_bsz * adapt_steps
            opt.zero_grad()
            promp_loss.backward(retain_graph=True)
            opt.step()

            # Adapt KL penalty based on desired target
            if adaptive_penalty:
                if kl_total < kl_target / 1.5:
                    eta /= 2.0
                elif kl_total > kl_target * 1.5:
                    eta *= 2.0
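# --- Entry point sketch (assumed; not shown in the excerpt above) -----------
if __name__ == '__main__':
    main()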