def ppo_update(episodes, policy, optimizer, baseline, prms):
    # Get values to device
    states, actions, rewards, dones, next_states = get_episode_values(episodes)

    # Update value function & Compute advantages
    returns = ch.td.discount(prms['gamma'], rewards, dones)
    advantages = compute_advantages(baseline, prms['tau'], prms['gamma'],
                                    rewards, dones, states, next_states)
    advantages = ch.normalize(advantages, epsilon=1e-8).detach()

    # Log-probabilities of the taken actions under the pre-update policy (no gradient)
    with torch.no_grad():
        old_log_probs = policy.log_prob(states, actions)

    # Run several PPO epochs on the collected episodes
    av_loss = 0.0
    for ppo_epoch in range(prms['ppo_epochs']):
        new_log_probs = policy.log_prob(states, actions)

        # Compute the policy loss
        policy_loss = ppo.policy_loss(new_log_probs, old_log_probs, advantages,
                                      clip=prms['ppo_clip_ratio'])

        # Adapt model based on the loss
        optimizer.zero_grad()
        policy_loss.backward()
        optimizer.step()

        baseline.fit(states, returns)
        av_loss += policy_loss.item()

    return av_loss / prms['ppo_epochs']
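# `get_episode_values` and `compute_advantages` are helpers defined elsewhere in this
# codebase. As a minimal sketch, `compute_advantages` is assumed to fit the linear
# baseline on discounted returns and then compute GAE with cherry; the exact bootstrap
# handling below is an assumption, not the definitive implementation.
def compute_advantages(baseline, tau, gamma, rewards, dones, states, next_states):
    # Fit the value baseline on discounted returns
    returns = ch.td.discount(gamma, rewards, dones)
    baseline.fit(states, returns)
    # Predicted values for visited and successor states
    values = baseline(states)
    next_values = baseline(next_states)
    # Use the successor-state value at terminal transitions, the state value otherwise
    bootstraps = values * (1.0 - dones) + next_values * dones
    next_value = torch.zeros(1, values.size(1), device=values.device)
    # Generalized Advantage Estimation (GAE)
    return ch.pg.generalized_advantage(tau=tau,
                                       gamma=gamma,
                                       rewards=rewards,
                                       dones=dones,
                                       values=bootstraps,
                                       next_value=next_value)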
def update(replay, optimizer, policy, env, lr_schedule):
    _, next_state_value = policy(replay[-1].next_state())
    advantages = pg.generalized_advantage(GAMMA,
                                          TAU,
                                          replay.reward(),
                                          replay.done(),
                                          replay.value(),
                                          next_state_value)
    advantages = ch.utils.normalize(advantages, epsilon=1e-5).view(-1, 1)
    rewards = [a + v for a, v in zip(advantages, replay.value())]

    for i, sars in enumerate(replay):
        sars.reward = rewards[i].detach()
        sars.advantage = advantages[i].detach()

    # Logging
    policy_losses = []
    entropies = []
    value_losses = []
    mean = lambda a: sum(a) / len(a)

    # Perform some optimization steps
    for step in range(PPO_EPOCHS * PPO_NUM_BATCHES):
        batch = replay.sample(PPO_BSZ)
        masses, values = policy(batch.state())

        # Compute losses
        new_log_probs = masses.log_prob(batch.action()).sum(-1, keepdim=True)
        entropy = masses.entropy().sum(-1).mean()
        policy_loss = ppo.policy_loss(new_log_probs,
                                      batch.log_prob(),
                                      batch.advantage(),
                                      clip=PPO_CLIP)
        value_loss = ppo.state_value_loss(values,
                                          batch.value().detach(),
                                          batch.reward(),
                                          clip=PPO_CLIP)
        loss = policy_loss - ENT_WEIGHT * entropy + V_WEIGHT * value_loss

        # Take optimization step
        optimizer.zero_grad()
        loss.backward()
        th.nn.utils.clip_grad_norm_(policy.parameters(), GRAD_NORM)
        optimizer.step()

        policy_losses.append(policy_loss)
        entropies.append(entropy)
        value_losses.append(value_loss)

    # Log metrics
    if dist.get_rank() == 0:
        env.log('policy loss', mean(policy_losses).item())
        env.log('policy entropy', mean(entropies).item())
        env.log('value loss', mean(value_losses).item())
        ppt.plot(mean(env.all_rewards[-10000:]), 'PPO results')

    # Update the parameters on schedule
    if LINEAR_SCHEDULE:
        lr_schedule.step()
def test_ppo_policy_loss(self):
    for _ in range(10):
        for shape in [(1, BSZ), (BSZ, 1), (BSZ, )]:
            for clip in [0.0, 0.1, 0.2, 1.0]:
                new_log_probs = th.randn(BSZ)
                old_log_probs = th.randn(BSZ)
                advantages = th.randn(BSZ)
                ref = ref_policy_loss(new_log_probs,
                                      old_log_probs,
                                      advantages,
                                      clip=clip)
                loss = ppo.policy_loss(new_log_probs.view(*shape),
                                       old_log_probs.view(*shape),
                                       advantages.view(*shape),
                                       clip=clip)
                self.assertAlmostEqual(loss.item(), ref.item())
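# `ref_policy_loss` is the reference implementation the test compares against; it is
# not shown here. A minimal sketch of what it is assumed to compute: the standard PPO
# clipped-surrogate objective, negated so it can be minimized, matching the convention
# of cherry's `ppo.policy_loss`. Treat the exact form as an assumption.
def ref_policy_loss(new_log_probs, old_log_probs, advantages, clip=0.1):
    # Probability ratio between the new and old policies
    ratios = th.exp(new_log_probs - old_log_probs)
    # Unclipped and clipped surrogate objectives
    surr = ratios * advantages
    surr_clipped = th.clamp(ratios, 1.0 - clip, 1.0 + clip) * advantages
    # Pessimistic (element-wise minimum) objective, negated to obtain a loss
    return -th.min(surr, surr_clipped).mean()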
def single_ppo_update(episodes, learner, baseline, params, anil=False):
    # Get values to device
    states, actions, rewards, dones, next_states = get_episode_values(episodes)

    # Update value function & Compute advantages
    advantages = compute_advantages(baseline, params['tau'], params['gamma'],
                                    rewards, dones, states, next_states)
    advantages = ch.normalize(advantages, epsilon=1e-8).detach()

    # Log-probabilities of the taken actions under the pre-update policy (no gradient)
    with torch.no_grad():
        old_log_probs = learner.log_prob(states, actions)

    # Log-probabilities under the current parameters
    new_log_probs = learner.log_prob(states, actions)

    # Compute the policy loss
    loss = ppo.policy_loss(new_log_probs, old_log_probs, advantages,
                           clip=params['ppo_clip_ratio'])

    # Adapt model based on the loss
    learner.adapt(loss, allow_unused=anil)

    return loss
def fast_adapt_ppo(task, learner, baseline, params, anil=False, render=False):
    # During inner loop adaptation we do not store gradients for the network body
    if anil:
        learner.module.turn_off_body_grads()

    for step in range(params['adapt_steps']):
        # Collect adaptation / support episodes
        support_episodes = task.run(learner, episodes=params['adapt_batch_size'], render=render)

        # Get values to device
        states, actions, rewards, dones, next_states = get_episode_values(support_episodes)

        # Update value function & Compute advantages
        advantages = compute_advantages(baseline, params['tau'], params['gamma'],
                                        rewards, dones, states, next_states)
        advantages = ch.normalize(advantages, epsilon=1e-8).detach()

        # Log-probabilities of the taken actions under the pre-update policy (no gradient)
        with torch.no_grad():
            old_log_probs = learner.log_prob(states, actions)

        # Run several PPO epochs on the support episodes
        av_loss = 0.0
        for ppo_epoch in range(params['ppo_epochs']):
            new_log_probs = learner.log_prob(states, actions)

            # Compute the policy loss
            loss = ppo.policy_loss(new_log_probs, old_log_probs, advantages,
                                   clip=params['ppo_clip_ratio'])

            # Adapt model based on the loss
            learner.adapt(loss, allow_unused=anil)
            av_loss += loss

    # We need to include the body network parameters for the query set
    if anil:
        learner.module.turn_on_body_grads()

    # Collect evaluation / query episodes
    query_episodes = task.run(learner, episodes=params['adapt_batch_size'])

    # Get values to device
    states, actions, rewards, dones, next_states = get_episode_values(query_episodes)

    # Update value function & Compute advantages
    advantages = compute_advantages(baseline, params['tau'], params['gamma'],
                                    rewards, dones, states, next_states)
    advantages = ch.normalize(advantages, epsilon=1e-8).detach()

    # Log-probabilities of the taken actions under the adapted policy
    with torch.no_grad():
        old_log_probs = learner.log_prob(states, actions)
    new_log_probs = learner.log_prob(states, actions)

    # Compute the policy loss on the query episodes
    valid_loss = ppo.policy_loss(new_log_probs, old_log_probs, advantages,
                                 clip=params['ppo_clip_ratio'])

    # Calculate the average reward and success rate of the evaluation episodes
    query_rew = query_episodes.reward().sum().item() / params['adapt_batch_size']
    query_success_rate = get_ep_successes(query_episodes, params['max_path_length']) / params['adapt_batch_size']

    return valid_loss, query_rew, query_success_rate
def main(
        env_name='AntDirection-v1',
        adapt_lr=0.1,
        meta_lr=3e-4,
        adapt_steps=3,
        num_iterations=1000,
        meta_bsz=40,
        adapt_bsz=20,
        ppo_clip=0.3,
        ppo_steps=5,
        tau=1.00,
        gamma=0.99,
        eta=0.0005,
        adaptive_penalty=False,
        kl_target=0.01,
        num_workers=4,
        seed=421,
):
    random.seed(seed)
    np.random.seed(seed)
    th.manual_seed(seed)

    def make_env():
        env = gym.make(env_name)
        env = ch.envs.ActionSpaceScaler(env)
        return env

    env = l2l.gym.AsyncVectorEnv([make_env for _ in range(num_workers)])
    env.seed(seed)
    env = ch.envs.ActionSpaceScaler(env)
    env = ch.envs.Torch(env)
    policy = DiagNormalPolicy(input_size=env.state_size,
                              output_size=env.action_size,
                              hiddens=[64, 64],
                              activation='tanh')
    meta_learner = l2l.algorithms.MAML(policy, lr=meta_lr)
    baseline = LinearValue(env.state_size, env.action_size)
    opt = optim.Adam(meta_learner.parameters(), lr=meta_lr)

    for iteration in range(num_iterations):
        iteration_reward = 0.0
        iteration_replays = []
        iteration_policies = []

        # Sample Trajectories
        for task_config in tqdm(env.sample_tasks(meta_bsz), leave=False, desc='Data'):
            clone = deepcopy(meta_learner)
            env.set_task(task_config)
            env.reset()
            task = ch.envs.Runner(env)
            task_replay = []
            task_policies = []

            # Fast Adapt
            for step in range(adapt_steps):
                for p in clone.parameters():
                    p.detach_().requires_grad_()
                task_policies.append(deepcopy(clone))
                train_episodes = task.run(clone, episodes=adapt_bsz)
                clone = fast_adapt_a2c(clone, train_episodes, adapt_lr,
                                       baseline, gamma, tau, first_order=True)
                task_replay.append(train_episodes)

            # Compute Validation Loss
            for p in clone.parameters():
                p.detach_().requires_grad_()
            task_policies.append(deepcopy(clone))
            valid_episodes = task.run(clone, episodes=adapt_bsz)
            task_replay.append(valid_episodes)
            iteration_reward += valid_episodes.reward().sum().item() / adapt_bsz
            iteration_replays.append(task_replay)
            iteration_policies.append(task_policies)

        # Print statistics
        print('\nIteration', iteration)
        adaptation_reward = iteration_reward / meta_bsz
        print('adaptation_reward', adaptation_reward)

        # ProMP meta-optimization
        for ppo_step in tqdm(range(ppo_steps), leave=False, desc='Optim'):
            promp_loss = 0.0
            kl_total = 0.0
            for task_replays, old_policies in zip(iteration_replays, iteration_policies):
                new_policy = meta_learner.clone()
                states = task_replays[0].state()
                actions = task_replays[0].action()
                rewards = task_replays[0].reward()
                dones = task_replays[0].done()
                next_states = task_replays[0].next_state()
                old_policy = old_policies[0]
                (old_density,
                 new_density,
                 old_log_probs,
                 new_log_probs) = precompute_quantities(states,
                                                        actions,
                                                        old_policy,
                                                        new_policy)
                advantages = compute_advantages(baseline, tau, gamma, rewards,
                                                dones, states, next_states)
                advantages = ch.normalize(advantages).detach()
                for step in range(adapt_steps):
                    # Compute KL penalty
                    kl_pen = kl_divergence(old_density, new_density).mean()
                    kl_total += kl_pen.item()

                    # Update the clone
                    surr_loss = trpo.policy_loss(new_log_probs, old_log_probs, advantages)
                    new_policy.adapt(surr_loss)

                    # Move to next adaptation step
                    states = task_replays[step + 1].state()
                    actions = task_replays[step + 1].action()
                    rewards = task_replays[step + 1].reward()
                    dones = task_replays[step + 1].done()
                    next_states = task_replays[step + 1].next_state()
                    old_policy = old_policies[step + 1]
                    (old_density,
                     new_density,
                     old_log_probs,
                     new_log_probs) = precompute_quantities(states,
                                                            actions,
                                                            old_policy,
                                                            new_policy)

                    # Compute clip loss
                    advantages = compute_advantages(baseline, tau, gamma, rewards,
                                                    dones, states, next_states)
                    advantages = ch.normalize(advantages).detach()
                    clip_loss = ppo.policy_loss(new_log_probs,
                                                old_log_probs,
                                                advantages,
                                                clip=ppo_clip)

                    # Combine into ProMP loss
                    promp_loss += clip_loss + eta * kl_pen

            kl_total /= meta_bsz * adapt_steps
            promp_loss /= meta_bsz * adapt_steps
            opt.zero_grad()
            promp_loss.backward(retain_graph=True)
            opt.step()

            # Adapt KL penalty based on desired target
            if adaptive_penalty:
                if kl_total < kl_target / 1.5:
                    eta /= 2.0
                elif kl_total > kl_target * 1.5:
                    eta *= 2.0
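# `precompute_quantities` is defined elsewhere in this example; a minimal sketch of
# what it is assumed to return, assuming `DiagNormalPolicy` exposes a `density(state)`
# method returning a torch distribution and that log-probabilities are summed over
# action dimensions as in the other snippets here (both points are assumptions).
def precompute_quantities(states, actions, old_policy, new_policy):
    # Action distributions under the pre-update and current policies
    old_density = old_policy.density(states)
    new_density = new_policy.density(states)
    # Log-probabilities of the taken actions; the old ones carry no gradient
    old_log_probs = old_density.log_prob(actions).sum(-1, keepdim=True).detach()
    new_log_probs = new_density.log_prob(actions).sum(-1, keepdim=True)
    return old_density, new_density, old_log_probs, new_log_probs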
def update(replay, optimizer, policy, env, lr_schedule):
    _, next_state_value = policy(replay[-1].next_state)
    # NOTE: Kostrikov uses GAE here.
    advantages = ch.generalized_advantage(GAMMA,
                                          TAU,
                                          replay.reward(),
                                          replay.done(),
                                          replay.value(),
                                          next_state_value)
    advantages = advantages.view(-1, 1)
    # advantages = ch.utils.normalize(advantages, epsilon=1e-5).view(-1, 1)
    # rewards = [a + v for a, v in zip(advantages, replay.value())]
    rewards = advantages + replay.value()
    # rewards = ch.discount(GAMMA,
    #                       replay.reward(),
    #                       replay.done(),
    #                       bootstrap=next_state_value)
    # rewards = rewards.detach()
    # advantages = rewards.detach() - replay.value().detach()
    # advantages = ch.utils.normalize(advantages, epsilon=1e-5).view(-1, 1)

    for i, sars in enumerate(replay):
        sars.reward = rewards[i].detach()
        sars.advantage = advantages[i].detach()

    # Logging
    policy_losses = []
    entropies = []
    value_losses = []
    mean = lambda a: sum(a) / len(a)

    # Perform some optimization steps
    for step in range(PPO_EPOCHS * PPO_NUM_BATCHES):
        batch = replay.sample(PPO_BSZ)
        masses, values = policy(batch.state())

        # Compute losses
        advs = ch.normalize(batch.advantage(), epsilon=1e-8)
        new_log_probs = masses.log_prob(batch.action()).sum(-1, keepdim=True)
        entropy = masses.entropy().sum(-1).mean()
        policy_loss = ppo.policy_loss(new_log_probs,
                                      batch.log_prob(),
                                      # batch.advantage(),
                                      advs,
                                      clip=PPO_CLIP)
        value_loss = ppo.state_value_loss(values,
                                          batch.value().detach(),
                                          batch.reward(),
                                          clip=PPO_CLIP)
        loss = policy_loss - ENT_WEIGHT * entropy + V_WEIGHT * value_loss

        # Take optimization step
        optimizer.zero_grad()
        loss.backward()
        th.nn.utils.clip_grad_norm_(policy.parameters(), GRAD_NORM)
        optimizer.step()

        policy_losses.append(policy_loss.item())
        entropies.append(entropy.item())
        value_losses.append(value_loss.item())

    # Log metrics
    env.log('policy loss', mean(policy_losses))
    env.log('policy entropy', mean(entropies))
    env.log('value loss', mean(value_losses))

    # Update the parameters on schedule
    if LINEAR_SCHEDULE:
        lr_schedule.step()
def main(
        experiment='dev',
        env_name='2DNavigation-v0',
        adapt_lr=0.1,
        meta_lr=0.01,
        adapt_steps=1,
        num_iterations=20,
        meta_bsz=10,
        adapt_bsz=10,
        tau=1.00,
        gamma=0.99,
        num_workers=1,
        seed=42,
):
    random.seed(seed)
    np.random.seed(seed)
    th.manual_seed(seed)

    def make_env():
        return gym.make(env_name)

    env = l2l.gym.AsyncVectorEnv([make_env for _ in range(num_workers)])
    env.seed(seed)
    env = ch.envs.Torch(env)
    policy = DiagNormalPolicy(env.state_size, env.action_size)
    meta_learner = l2l.MAML(policy, lr=meta_lr)
    baseline = LinearValue(env.state_size, env.action_size)
    opt = optim.Adam(meta_learner.parameters(), lr=meta_lr)
    all_rewards = []

    for iteration in range(num_iterations):
        iteration_reward = 0.0
        iteration_replays = []
        iteration_policies = []

        policy.to('cpu')
        baseline.to('cpu')

        # Samples a new config
        for task_config in tqdm(env.sample_tasks(meta_bsz), leave=False, desc='Data'):
            learner = meta_learner.clone()
            env.reset_task(task_config)
            env.reset()
            task = ch.envs.Runner(env)
            task_replay = []

            # Fast Adapt
            for step in range(adapt_steps):
                train_episodes = task.run(learner, episodes=adapt_bsz)
                learner = fast_adapt_a2c(learner, train_episodes, adapt_lr,
                                         baseline, gamma, tau, first_order=True)
                task_replay.append(train_episodes)

            # Compute Validation Loss
            valid_episodes = task.run(learner, episodes=adapt_bsz)
            task_replay.append(valid_episodes)
            iteration_reward += valid_episodes.reward().sum().item() / adapt_bsz
            iteration_replays.append(task_replay)
            iteration_policies.append(learner)

        # Print statistics
        print('\nIteration', iteration)
        adaptation_reward = iteration_reward / meta_bsz
        all_rewards.append(adaptation_reward)
        print('adaptation_reward', adaptation_reward)

        # PPO meta-optimization
        for ppo_step in tqdm(range(10), leave=False, desc='Optim'):
            ppo_loss = 0.0
            for task_replays, old_policy in zip(iteration_replays, iteration_policies):
                train_replays = task_replays[:-1]
                valid_replay = task_replays[-1]

                # Fast adapt new policy, starting from the current init
                new_policy = meta_learner.clone()
                for train_episodes in train_replays:
                    new_policy = fast_adapt_a2c(new_policy, train_episodes,
                                                adapt_lr, baseline, gamma, tau)

                # Compute PPO loss between old and new clones
                states = valid_replay.state()
                actions = valid_replay.action()
                rewards = valid_replay.reward()
                dones = valid_replay.done()
                next_states = valid_replay.next_state()
                old_log_probs = old_policy.log_prob(states, actions).detach()
                new_log_probs = new_policy.log_prob(states, actions)
                advantages = compute_advantages(baseline, tau, gamma, rewards,
                                                dones, states, next_states)
                advantages = ch.normalize(advantages).detach()
                ppo_loss += ppo.policy_loss(new_log_probs, old_log_probs,
                                            advantages, clip=0.1)

            ppo_loss /= meta_bsz
            opt.zero_grad()
            ppo_loss.backward()
            opt.step()