Example no. 1
def ppo_update(episodes, policy, optimizer, baseline, prms):

    # Get values to device
    states, actions, rewards, dones, next_states = get_episode_values(episodes)

    # Update value function & Compute advantages
    returns = ch.td.discount(prms['gamma'], rewards, dones)
    advantages = compute_advantages(baseline, prms['tau'], prms['gamma'], rewards, dones, states, next_states)
    advantages = ch.normalize(advantages, epsilon=1e-8).detach()
    # Log-probabilities of the taken actions under the pre-update policy
    with torch.no_grad():
        old_log_probs = policy.log_prob(states, actions)

    # Run several PPO epochs on the collected episodes; track the average loss
    av_loss = 0.0

    for ppo_epoch in range(prms['ppo_epochs']):
        new_log_probs = policy.log_prob(states, actions)

        # Compute the policy loss
        policy_loss = ppo.policy_loss(new_log_probs, old_log_probs, advantages, clip=prms['ppo_clip_ratio'])

        # Adapt model based on the loss
        optimizer.zero_grad()
        policy_loss.backward()
        optimizer.step()

        baseline.fit(states, returns)
        av_loss += policy_loss.item()

    return av_loss / prms['ppo_epochs']
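For context, the clipped surrogate that `ppo.policy_loss` is expected to compute above is the standard PPO objective. The standalone sketch below is an assumption about that behaviour rather than cherry's actual implementation; it returns the negated clipped surrogate so it can be minimized directly:

import torch


def clipped_surrogate_loss(new_log_probs, old_log_probs, advantages, clip=0.2):
    # Importance ratio between the updated policy and the behaviour policy.
    ratio = torch.exp(new_log_probs - old_log_probs)
    # Unclipped and clipped surrogate terms.
    surr1 = ratio * advantages
    surr2 = torch.clamp(ratio, 1.0 - clip, 1.0 + clip) * advantages
    # PPO maximizes the minimum of the two terms; negate to obtain a loss.
    return -torch.min(surr1, surr2).mean()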
Example no. 2
def update(replay, optimizer, policy, env, lr_schedule):
    _, next_state_value = policy(replay[-1].next_state())
    advantages = pg.generalized_advantage(GAMMA, TAU, replay.reward(),
                                          replay.done(), replay.value(),
                                          next_state_value)

    advantages = ch.utils.normalize(advantages, epsilon=1e-5).view(-1, 1)
    rewards = [a + v for a, v in zip(advantages, replay.value())]

    for i, sars in enumerate(replay):
        sars.reward = rewards[i].detach()
        sars.advantage = advantages[i].detach()

    # Logging
    policy_losses = []
    entropies = []
    value_losses = []
    mean = lambda a: sum(a) / len(a)

    # Perform some optimization steps
    for step in range(PPO_EPOCHS * PPO_NUM_BATCHES):
        batch = replay.sample(PPO_BSZ)
        masses, values = policy(batch.state())

        # Compute losses
        new_log_probs = masses.log_prob(batch.action()).sum(-1, keepdim=True)
        entropy = masses.entropy().sum(-1).mean()
        policy_loss = ppo.policy_loss(new_log_probs,
                                      batch.log_prob(),
                                      batch.advantage(),
                                      clip=PPO_CLIP)
        value_loss = ppo.state_value_loss(values,
                                          batch.value().detach(),
                                          batch.reward(),
                                          clip=PPO_CLIP)
        loss = policy_loss - ENT_WEIGHT * entropy + V_WEIGHT * value_loss

        # Take optimization step
        optimizer.zero_grad()
        loss.backward()
        th.nn.utils.clip_grad_norm_(policy.parameters(), GRAD_NORM)
        optimizer.step()

        policy_losses.append(policy_loss.item())
        entropies.append(entropy.item())
        value_losses.append(value_loss.item())

    # Log metrics
    if dist.get_rank() == 0:
        env.log('policy loss', mean(policy_losses))
        env.log('policy entropy', mean(entropies))
        env.log('value loss', mean(value_losses))
        ppt.plot(mean(env.all_rewards[-10000:]), 'PPO results')

    # Update the parameters on schedule
    if LINEAR_SCHEDULE:
        lr_schedule.step()
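The `generalized_advantage` call above implements GAE(gamma, tau). As a point of reference, the recurrence it computes can be written as the minimal sketch below; this is a generic implementation, not cherry's, and it assumes 1-D tensors of equal length with `next_value` bootstrapping the final transition:

import torch


def gae(gamma, tau, rewards, dones, values, next_value):
    # delta_t = r_t + gamma * V(s_{t+1}) * (1 - done_t) - V(s_t)
    # A_t     = delta_t + gamma * tau * (1 - done_t) * A_{t+1}
    advantages = torch.zeros_like(rewards)
    running_gae = torch.zeros(1)
    next_v = next_value
    for t in reversed(range(len(rewards))):
        mask = 1.0 - dones[t].float()
        delta = rewards[t] + gamma * next_v * mask - values[t]
        running_gae = delta + gamma * tau * mask * running_gae
        advantages[t] = running_gae
        next_v = values[t]
    return advantages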
Example no. 3
def test_ppo_policy_loss(self):
    for _ in range(10):
        for shape in [(1, BSZ), (BSZ, 1), (BSZ, )]:
            for clip in [0.0, 0.1, 0.2, 1.0]:
                new_log_probs = th.randn(BSZ)
                old_log_probs = th.randn(BSZ)
                advantages = th.randn(BSZ)
                ref = ref_policy_loss(new_log_probs,
                                      old_log_probs,
                                      advantages,
                                      clip=clip)
                loss = ppo.policy_loss(new_log_probs.view(*shape),
                                       old_log_probs.view(*shape),
                                       advantages.view(*shape),
                                       clip=clip)
                self.assertAlmostEqual(loss.item(), ref.item())
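`ref_policy_loss` is not shown in this example. A plausible reference implementation, consistent with the clipped surrogate sketched earlier, could look like the following; the flattening step (so that (1, BSZ), (BSZ, 1) and (BSZ,) inputs agree) is an assumption about what the test expects:

import torch as th


def ref_policy_loss(new_log_probs, old_log_probs, advantages, clip=0.2):
    # Flatten so the reference is insensitive to the broadcasting shape.
    new_log_probs = new_log_probs.reshape(-1)
    old_log_probs = old_log_probs.reshape(-1)
    advantages = advantages.reshape(-1)
    ratio = th.exp(new_log_probs - old_log_probs)
    clipped = th.clamp(ratio, 1.0 - clip, 1.0 + clip)
    return -th.min(ratio * advantages, clipped * advantages).mean()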
Example no. 4
def single_ppo_update(episodes, learner, baseline, params, anil=False):
    # Get values to device
    states, actions, rewards, dones, next_states = get_episode_values(episodes)

    # Update value function & Compute advantages
    advantages = compute_advantages(baseline, params['tau'], params['gamma'],
                                    rewards, dones, states, next_states)
    advantages = ch.normalize(advantages, epsilon=1e-8).detach()
    # Log-probabilities of the taken actions under the pre-update policy
    with torch.no_grad():
        old_log_probs = learner.log_prob(states, actions)

    # Recompute log-probabilities under the current policy for a single PPO step
    new_log_probs = learner.log_prob(states, actions)
    # Compute the policy loss
    loss = ppo.policy_loss(new_log_probs,
                           old_log_probs,
                           advantages,
                           clip=params['ppo_clip_ratio'])
    # Adapt model based on the loss
    learner.adapt(loss, allow_unused=anil)
    return loss
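`compute_advantages` is used throughout these examples but never defined here. A plausible sketch, assuming it evaluates the baseline and then runs cherry's GAE with a bootstrap from the last next-state value, is shown below; the exact shape handling and the omission of a baseline-fitting step are assumptions (some variants fit the baseline inside this helper as well):

import cherry as ch


def compute_advantages(baseline, tau, gamma, rewards, dones, states, next_states):
    # Value estimates for the visited states and their successors.
    values = baseline(states)
    next_values = baseline(next_states)
    # GAE(gamma, tau), bootstrapped with the value of the final successor state.
    return ch.pg.generalized_advantage(gamma, tau, rewards, dones,
                                       values, next_values[-1])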
Example no. 5
def fast_adapt_ppo(task, learner, baseline, params, anil=False, render=False):
    # During inner loop adaptation we do not store gradients for the network body
    if anil:
        learner.module.turn_off_body_grads()

    for step in range(params['adapt_steps']):
        # Collect adaptation / support episodes
        support_episodes = task.run(learner,
                                    episodes=params['adapt_batch_size'],
                                    render=render)

        # Get values to device
        states, actions, rewards, dones, next_states = get_episode_values(
            support_episodes)

        # Update value function & Compute advantages
        advantages = compute_advantages(baseline, params['tau'],
                                        params['gamma'], rewards, dones,
                                        states, next_states)
        advantages = ch.normalize(advantages, epsilon=1e-8).detach()
        # Log-probabilities of the taken actions under the pre-update policy
        with torch.no_grad():
            old_log_probs = learner.log_prob(states, actions)

        # Run several PPO epochs on the support episodes; track the average loss
        av_loss = 0.0
        for ppo_epoch in range(params['ppo_epochs']):
            new_log_probs = learner.log_prob(states, actions)
            # Compute the policy loss
            loss = ppo.policy_loss(new_log_probs,
                                   old_log_probs,
                                   advantages,
                                   clip=params['ppo_clip_ratio'])
            # Adapt model based on the loss
            learner.adapt(loss, allow_unused=anil)
            av_loss += loss

    # We need to include the body network parameters for the query set
    if anil:
        learner.module.turn_on_body_grads()

    # Collect evaluation / query episodes
    query_episodes = task.run(learner, episodes=params['adapt_batch_size'])
    # Get values to device
    states, actions, rewards, dones, next_states = get_episode_values(
        query_episodes)
    # Update value function & Compute advantages
    advantages = compute_advantages(baseline, params['tau'], params['gamma'],
                                    rewards, dones, states, next_states)
    advantages = ch.normalize(advantages, epsilon=1e-8).detach()
    # Log-probabilities of the query actions (detached copy for the PPO ratio)
    with torch.no_grad():
        old_log_probs = learner.log_prob(states, actions)

    new_log_probs = learner.log_prob(states, actions)
    # Compute the policy loss
    valid_loss = ppo.policy_loss(new_log_probs,
                                 old_log_probs,
                                 advantages,
                                 clip=params['ppo_clip_ratio'])

    # Calculate the average reward of the evaluation episodes
    query_rew = (query_episodes.reward().sum().item()
                 / params['adapt_batch_size'])
    query_success_rate = get_ep_successes(
        query_episodes, params['max_path_length']) / params['adapt_batch_size']

    return valid_loss, query_rew, query_success_rate
Example no. 6
def main(
    env_name='AntDirection-v1',
    adapt_lr=0.1,
    meta_lr=3e-4,
    adapt_steps=3,
    num_iterations=1000,
    meta_bsz=40,
    adapt_bsz=20,
    ppo_clip=0.3,
    ppo_steps=5,
    tau=1.00,
    gamma=0.99,
    eta=0.0005,
    adaptive_penalty=False,
    kl_target=0.01,
    num_workers=4,
    seed=421,
):
    random.seed(seed)
    np.random.seed(seed)
    th.manual_seed(seed)

    def make_env():
        env = gym.make(env_name)
        env = ch.envs.ActionSpaceScaler(env)
        return env

    env = l2l.gym.AsyncVectorEnv([make_env for _ in range(num_workers)])
    env.seed(seed)
    env = ch.envs.ActionSpaceScaler(env)
    env = ch.envs.Torch(env)
    policy = DiagNormalPolicy(input_size=env.state_size,
                              output_size=env.action_size,
                              hiddens=[64, 64],
                              activation='tanh')
    meta_learner = l2l.algorithms.MAML(policy, lr=meta_lr)
    baseline = LinearValue(env.state_size, env.action_size)
    opt = optim.Adam(meta_learner.parameters(), lr=meta_lr)

    for iteration in range(num_iterations):
        iteration_reward = 0.0
        iteration_replays = []
        iteration_policies = []

        # Sample Trajectories
        for task_config in tqdm(env.sample_tasks(meta_bsz),
                                leave=False,
                                desc='Data'):
            clone = deepcopy(meta_learner)
            env.set_task(task_config)
            env.reset()
            task = ch.envs.Runner(env)
            task_replay = []
            task_policies = []

            # Fast Adapt
            for step in range(adapt_steps):
                for p in clone.parameters():
                    p.detach_().requires_grad_()
                task_policies.append(deepcopy(clone))
                train_episodes = task.run(clone, episodes=adapt_bsz)
                clone = fast_adapt_a2c(clone,
                                       train_episodes,
                                       adapt_lr,
                                       baseline,
                                       gamma,
                                       tau,
                                       first_order=True)
                task_replay.append(train_episodes)

            # Compute Validation Loss
            for p in clone.parameters():
                p.detach_().requires_grad_()
            task_policies.append(deepcopy(clone))
            valid_episodes = task.run(clone, episodes=adapt_bsz)
            task_replay.append(valid_episodes)
            iteration_reward += (valid_episodes.reward().sum().item()
                                 / adapt_bsz)
            iteration_replays.append(task_replay)
            iteration_policies.append(task_policies)

        # Print statistics
        print('\nIteration', iteration)
        adaptation_reward = iteration_reward / meta_bsz
        print('adaptation_reward', adaptation_reward)

        # ProMP meta-optimization
        for ppo_step in tqdm(range(ppo_steps), leave=False, desc='Optim'):
            promp_loss = 0.0
            kl_total = 0.0
            for task_replays, old_policies in zip(iteration_replays,
                                                  iteration_policies):
                new_policy = meta_learner.clone()
                states = task_replays[0].state()
                actions = task_replays[0].action()
                rewards = task_replays[0].reward()
                dones = task_replays[0].done()
                next_states = task_replays[0].next_state()
                old_policy = old_policies[0]
                (old_density, new_density, old_log_probs,
                 new_log_probs) = precompute_quantities(
                     states, actions, old_policy, new_policy)
                advantages = compute_advantages(baseline, tau, gamma, rewards,
                                                dones, states, next_states)
                advantages = ch.normalize(advantages).detach()
                for step in range(adapt_steps):
                    # Compute KL penalty
                    kl_pen = kl_divergence(old_density, new_density).mean()
                    kl_total += kl_pen.item()

                    # Update the clone
                    surr_loss = trpo.policy_loss(new_log_probs, old_log_probs,
                                                 advantages)
                    new_policy.adapt(surr_loss)

                    # Move to next adaptation step
                    states = task_replays[step + 1].state()
                    actions = task_replays[step + 1].action()
                    rewards = task_replays[step + 1].reward()
                    dones = task_replays[step + 1].done()
                    next_states = task_replays[step + 1].next_state()
                    old_policy = old_policies[step + 1]
                    (old_density, new_density, old_log_probs,
                     new_log_probs) = precompute_quantities(
                         states, actions, old_policy, new_policy)

                    # Compute clip loss
                    advantages = compute_advantages(baseline, tau, gamma,
                                                    rewards, dones, states,
                                                    next_states)
                    advantages = ch.normalize(advantages).detach()
                    clip_loss = ppo.policy_loss(new_log_probs,
                                                old_log_probs,
                                                advantages,
                                                clip=ppo_clip)

                    # Combine into ProMP loss
                    promp_loss += clip_loss + eta * kl_pen

            kl_total /= meta_bsz * adapt_steps
            promp_loss /= meta_bsz * adapt_steps
            opt.zero_grad()
            promp_loss.backward(retain_graph=True)
            opt.step()

            # Adapt KL penalty based on desired target
            if adaptive_penalty:
                if kl_total < kl_target / 1.5:
                    eta /= 2.0
                elif kl_total > kl_target * 1.5:
                    eta *= 2.0
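`precompute_quantities` is also undefined in this listing. A plausible sketch, assuming the policy exposes a `density(states)` method returning a torch distribution, is given below; the per-dimension reduction of the log-probabilities is an assumption. The `kl_divergence` call in the loop above is presumably `torch.distributions.kl.kl_divergence` applied to these two densities.

def precompute_quantities(states, actions, old_policy, new_policy):
    # Action distributions of the pre- and post-adaptation policies.
    old_density = old_policy.density(states)
    new_density = new_policy.density(states)
    # Log-probabilities of the taken actions under both policies.
    old_log_probs = old_density.log_prob(actions).mean(dim=1, keepdim=True).detach()
    new_log_probs = new_density.log_prob(actions).mean(dim=1, keepdim=True)
    return old_density, new_density, old_log_probs, new_log_probs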
Example no. 7
def update(replay, optimizer, policy, env, lr_schedule):
    _, next_state_value = policy(replay[-1].next_state)
    # NOTE: Kostrikov uses GAE here.
    advantages = ch.generalized_advantage(GAMMA, TAU, replay.reward(),
                                          replay.done(), replay.value(),
                                          next_state_value)

    advantages = advantages.view(-1, 1)
    #    advantages = ch.utils.normalize(advantages, epsilon=1e-5).view(-1, 1)
    #    rewards = [a + v for a, v in zip(advantages, replay.value())]
    rewards = advantages + replay.value()

    #    rewards = ch.discount(GAMMA,
    #                          replay.reward(),
    #                          replay.done(),
    #                          bootstrap=next_state_value)
    #    rewards = rewards.detach()
    #    advantages = rewards.detach() - replay.value().detach()
    #    advantages = ch.utils.normalize(advantages, epsilon=1e-5).view(-1, 1)

    for i, sars in enumerate(replay):
        sars.reward = rewards[i].detach()
        sars.advantage = advantages[i].detach()

    # Logging
    policy_losses = []
    entropies = []
    value_losses = []
    mean = lambda a: sum(a) / len(a)

    # Perform some optimization steps
    for step in range(PPO_EPOCHS * PPO_NUM_BATCHES):
        batch = replay.sample(PPO_BSZ)
        masses, values = policy(batch.state())

        # Compute losses
        advs = ch.normalize(batch.advantage(), epsilon=1e-8)
        new_log_probs = masses.log_prob(batch.action()).sum(-1, keepdim=True)
        entropy = masses.entropy().sum(-1).mean()
        policy_loss = ppo.policy_loss(
            new_log_probs,
            batch.log_prob(),
            #                                      batch.advantage(),
            advs,
            clip=PPO_CLIP)
        value_loss = ppo.state_value_loss(values,
                                          batch.value().detach(),
                                          batch.reward(),
                                          clip=PPO_CLIP)
        loss = policy_loss - ENT_WEIGHT * entropy + V_WEIGHT * value_loss

        # Take optimization step
        optimizer.zero_grad()
        loss.backward()
        th.nn.utils.clip_grad_norm_(policy.parameters(), GRAD_NORM)
        optimizer.step()

        policy_losses.append(policy_loss.item())
        entropies.append(entropy.item())
        value_losses.append(value_loss.item())

    # Log metrics
    env.log('policy loss', mean(policy_losses))
    env.log('policy entropy', mean(entropies))
    env.log('value loss', mean(value_losses))

    # Update the parameters on schedule
    if LINEAR_SCHEDULE:
        lr_schedule.step()
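`ppo.state_value_loss`, used in Examples no. 2 and no. 7, appears to be a PPO2-style clipped value loss. The generic sketch below is not cherry's actual implementation; the 0.5 factor and the mean reduction are assumptions:

import torch


def clipped_value_loss(new_values, old_values, returns, clip=0.2):
    # Clip the value update around the old value estimates.
    clipped_values = old_values + torch.clamp(new_values - old_values, -clip, clip)
    # Penalize the worse of the clipped and unclipped squared errors.
    unclipped_loss = (new_values - returns) ** 2
    clipped_loss = (clipped_values - returns) ** 2
    return 0.5 * torch.max(unclipped_loss, clipped_loss).mean()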
Example no. 8
def main(
    experiment='dev',
    env_name='2DNavigation-v0',
    adapt_lr=0.1,
    meta_lr=0.01,
    adapt_steps=1,
    num_iterations=20,
    meta_bsz=10,
    adapt_bsz=10,
    tau=1.00,
    gamma=0.99,
    num_workers=1,
    seed=42,
):
    random.seed(seed)
    np.random.seed(seed)
    th.manual_seed(seed)

    def make_env():
        return gym.make(env_name)

    env = l2l.gym.AsyncVectorEnv([make_env for _ in range(num_workers)])
    env.seed(seed)
    env = ch.envs.Torch(env)
    policy = DiagNormalPolicy(env.state_size, env.action_size)
    meta_learner = l2l.MAML(policy, lr=meta_lr)
    baseline = LinearValue(env.state_size, env.action_size)
    opt = optim.Adam(meta_learner.parameters(), lr=meta_lr)

    all_rewards = []
    for iteration in range(num_iterations):
        iteration_reward = 0.0
        iteration_replays = []
        iteration_policies = []
        policy.to('cpu')
        baseline.to('cpu')

        for task_config in tqdm(env.sample_tasks(meta_bsz),
                                leave=False,
                                desc='Data'):  # Samples a new config
            learner = meta_learner.clone()
            env.reset_task(task_config)
            env.reset()
            task = ch.envs.Runner(env)
            task_replay = []

            # Fast Adapt
            for step in range(adapt_steps):
                train_episodes = task.run(learner, episodes=adapt_bsz)
                learner = fast_adapt_a2c(learner,
                                         train_episodes,
                                         adapt_lr,
                                         baseline,
                                         gamma,
                                         tau,
                                         first_order=True)
                task_replay.append(train_episodes)

            # Compute Validation Loss
            valid_episodes = task.run(learner, episodes=adapt_bsz)
            task_replay.append(valid_episodes)
            iteration_reward += (valid_episodes.reward().sum().item()
                                 / adapt_bsz)
            iteration_replays.append(task_replay)
            iteration_policies.append(learner)

        # Print statistics
        print('\nIteration', iteration)
        adaptation_reward = iteration_reward / meta_bsz
        all_rewards.append(adaptation_reward)
        print('adaptation_reward', adaptation_reward)

        # PPO meta-optimization
        for ppo_step in tqdm(range(10), leave=False, desc='Optim'):
            ppo_loss = 0.0
            for task_replays, old_policy in zip(iteration_replays,
                                                iteration_policies):
                train_replays = task_replays[:-1]
                valid_replay = task_replays[-1]

                # Fast adapt new policy, starting from the current init
                new_policy = meta_learner.clone()
                for train_episodes in train_replays:
                    new_policy = fast_adapt_a2c(new_policy, train_episodes,
                                                adapt_lr, baseline, gamma, tau)

                # Compute PPO loss between old and new clones
                states = valid_replay.state()
                actions = valid_replay.action()
                rewards = valid_replay.reward()
                dones = valid_replay.done()
                next_states = valid_replay.next_state()
                old_log_probs = old_policy.log_prob(states, actions).detach()
                new_log_probs = new_policy.log_prob(states, actions)
                advantages = compute_advantages(baseline, tau, gamma, rewards,
                                                dones, states, next_states)
                advantages = ch.normalize(advantages).detach()
                ppo_loss += ppo.policy_loss(new_log_probs,
                                            old_log_probs,
                                            advantages,
                                            clip=0.1)

            ppo_loss /= meta_bsz
            opt.zero_grad()
            ppo_loss.backward()
            opt.step()
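`fast_adapt_a2c`, used by Examples no. 6 and no. 8 for the inner-loop adaptation, is not part of this listing either. A sketch of what such a step could look like is below; it reuses the `compute_advantages` helper sketched earlier, and the use of cherry's `a2c.policy_loss` together with learn2learn's `maml_update` is an assumption rather than the original author's code:

import torch.autograd as autograd
import cherry as ch
import learn2learn as l2l
from cherry.algorithms import a2c


def fast_adapt_a2c(clone, train_episodes, adapt_lr, baseline, gamma, tau,
                   first_order=False):
    # A2C policy-gradient loss on the support episodes.
    states = train_episodes.state()
    actions = train_episodes.action()
    advantages = compute_advantages(baseline, tau, gamma,
                                    train_episodes.reward(),
                                    train_episodes.done(),
                                    states, train_episodes.next_state())
    advantages = ch.normalize(advantages).detach()
    loss = a2c.policy_loss(clone.log_prob(states, actions), advantages)

    # One (optionally differentiable) gradient step on the clone, MAML-style.
    second_order = not first_order
    gradients = autograd.grad(loss,
                              clone.parameters(),
                              retain_graph=second_order,
                              create_graph=second_order)
    return l2l.algorithms.maml.maml_update(clone, adapt_lr, gradients)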