Code Example #1
from copy import deepcopy

import torch as th
from torch import autograd
from torch.distributions import kl_divergence
from torch.nn.utils import parameters_to_vector, vector_to_parameters

import cherry as ch
from cherry.algorithms import trpo


def trpo_update(replay, policy, baseline):
    gamma = 0.99
    tau = 0.95
    max_kl = 0.01
    ls_max_steps = 15
    backtrack_factor = 0.5
    old_policy = deepcopy(policy)
    for update in range(10):
        states = replay.state()
        actions = replay.action()
        rewards = replay.reward()
        dones = replay.done()
        next_states = replay.next_state()
        returns = ch.td.discount(gamma, rewards, dones)
        baseline.fit(states, returns)
        values = baseline(states)
        next_values = baseline(next_states)

        # Compute KL
        with th.no_grad():
            old_density = old_policy.density(states)
        new_density = policy.density(states)
        kl = kl_divergence(old_density, new_density).mean()

        # Compute surrogate loss
        old_log_probs = old_density.log_prob(actions).mean(dim=1, keepdim=True)
        new_log_probs = new_density.log_prob(actions).mean(dim=1, keepdim=True)
        bootstraps = values * (1.0 - dones) + next_values * dones
        advantages = ch.pg.generalized_advantage(gamma, tau, rewards, dones,
                                                 bootstraps, th.zeros(1))
        advantages = ch.normalize(advantages).detach()
        surr_loss = trpo.policy_loss(new_log_probs, old_log_probs, advantages)

        # Compute the update
        grad = autograd.grad(surr_loss, policy.parameters(), retain_graph=True)
        Fvp = trpo.hessian_vector_product(kl, policy.parameters())
        grad = parameters_to_vector(grad).detach()
        step = trpo.conjugate_gradient(Fvp, grad)
        lagrange_mult = 0.5 * th.dot(step, Fvp(step)) / max_kl
        step = step / lagrange_mult
        step_ = [th.zeros_like(p.data) for p in policy.parameters()]
        vector_to_parameters(step, step_)
        step = step_

        # Backtracking line search
        for ls_step in range(ls_max_steps):
            stepsize = backtrack_factor**ls_step
            clone = deepcopy(policy)
            for c, u in zip(clone.parameters(), step):
                c.data.add_(u.data, alpha=-stepsize)
            new_density = clone.density(states)
            new_kl = kl_divergence(old_density, new_density).mean()
            new_log_probs = new_density.log_prob(actions).mean(dim=1,
                                                               keepdim=True)
            new_loss = trpo.policy_loss(new_log_probs, old_log_probs,
                                        advantages)
            if new_loss < surr_loss and new_kl < max_kl:
                for p, c in zip(policy.parameters(), clone.parameters()):
                    p.data[:] = c.data[:]
                break
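
For orientation, a minimal driver for trpo_update might look like the following sketch. Everything in it beyond the cherry API already used above is an illustrative assumption: the environment name, the TinyGaussianPolicy stand-in (the source projects use a DiagNormalPolicy defined elsewhere), and the episode counts.

import gym
import torch as th
from torch import nn
from torch.distributions import Normal
import cherry as ch
from cherry.models.robotics import LinearValue


class TinyGaussianPolicy(nn.Module):
    # Minimal stand-in for the examples' DiagNormalPolicy: it exposes the
    # .density(state) interface that trpo_update expects and samples an
    # action in forward() so it can be handed directly to the Runner.
    def __init__(self, state_size, action_size):
        super().__init__()
        self.mean = nn.Linear(state_size, action_size)
        self.log_std = nn.Parameter(th.zeros(action_size))

    def density(self, state):
        return Normal(self.mean(state), self.log_std.exp())

    def forward(self, state):
        return self.density(state).sample()


env = ch.envs.Torch(gym.make('Pendulum-v0'))  # any continuous-control env
policy = TinyGaussianPolicy(env.state_size, env.action_size)
baseline = LinearValue(env.state_size, env.action_size)
runner = ch.envs.Runner(env)
for epoch in range(300):
    replay = runner.run(policy, episodes=10)
    trpo_update(replay, policy, baseline)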
Code Example #2
    def test_trpo_policy_loss(self):

        for shape in [(1, BSZ), (BSZ, 1), (BSZ, )]:
            new_log_probs = th.randn(BSZ)
            old_log_probs = th.randn(BSZ)
            advantages = th.randn(BSZ)
            ref = trpo.policy_loss(new_log_probs, old_log_probs, advantages)
            loss = trpo.policy_loss(new_log_probs.view(*shape),
                                    old_log_probs.view(*shape),
                                    advantages.view(*shape))
            self.assertAlmostEqual(loss.item(), ref.item())
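
The test above verifies that trpo.policy_loss is insensitive to how the batch dimension is shaped; BSZ is a module-level constant defined elsewhere in the test file. For intuition, the quantity under test can be reproduced by hand. The snippet below assumes policy_loss computes the standard TRPO surrogate (the negative importance-weighted advantage, mean-reduced over the batch):

import torch as th
from cherry.algorithms import trpo

BSZ = 32  # assumed value; the original constant lives elsewhere in the test module
new_lp, old_lp, adv = th.randn(BSZ), th.randn(BSZ), th.randn(BSZ)
ratio = th.exp(new_lp - old_lp)  # importance weight pi_new / pi_old
manual = -th.mean(ratio * adv)   # standard TRPO surrogate loss
assert th.isclose(trpo.policy_loss(new_lp, old_lp, adv), manual)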
Code Example #3
def meta_surrogate_loss(iteration_replays, iteration_policies, policy,
                        baseline, tau, gamma, adapt_lr):
    mean_loss = 0.0
    mean_kl = 0.0
    for task_replays, old_policy in tqdm(zip(iteration_replays,
                                             iteration_policies),
                                         total=len(iteration_replays),
                                         desc='Surrogate Loss',
                                         leave=False):
        policy.reset_context()
        train_replays = task_replays[:-1]
        valid_episodes = task_replays[-1]
        new_policy = l2l.clone_module(policy)

        # Fast Adapt
        for train_episodes in train_replays:
            new_policy = fast_adapt_a2c(new_policy,
                                        train_episodes,
                                        adapt_lr,
                                        baseline,
                                        gamma,
                                        tau,
                                        first_order=False)

        # Useful values
        states = valid_episodes.state()
        actions = valid_episodes.action()
        next_states = valid_episodes.next_state()
        rewards = valid_episodes.reward()
        dones = valid_episodes.done()

        # Compute KL
        old_densities = old_policy.density(states)
        new_densities = new_policy.density(states)
        kl = kl_divergence(new_densities, old_densities).mean()
        mean_kl += kl

        # Compute Surrogate Loss
        advantages = compute_advantages(baseline, tau, gamma, rewards, dones,
                                        states, next_states)
        advantages = ch.normalize(advantages).detach()
        old_log_probs = old_densities.log_prob(actions).mean(
            dim=1, keepdim=True).detach()
        new_log_probs = new_densities.log_prob(actions).mean(dim=1,
                                                             keepdim=True)
        mean_loss += trpo.policy_loss(new_log_probs, old_log_probs, advantages)
    mean_kl /= len(iteration_replays)
    mean_loss /= len(iteration_replays)
    return mean_loss, mean_kl
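
Examples #3 and #4 both call a compute_advantages helper that is not reproduced in this listing. Only its call signature is visible above; a plausible body, mirroring the baseline fitting and bootstrapping of Example #1, is sketched here as an assumption:

def compute_advantages(baseline, tau, gamma, rewards, dones, states,
                       next_states):
    # Fit the baseline on discounted returns, then bootstrap the value
    # estimates at episode boundaries exactly as Example #1 does.
    returns = ch.td.discount(gamma, rewards, dones)
    baseline.fit(states, returns)
    values = baseline(states)
    next_values = baseline(next_states)
    bootstraps = values * (1.0 - dones) + next_values * dones
    return ch.pg.generalized_advantage(gamma, tau, rewards, dones,
                                       bootstraps, th.zeros(1))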
Code Example #4
def meta_surrogate_loss(iter_replays, iter_policies, policy, baseline, params,
                        anil):
    mean_loss = 0.0
    mean_kl = 0.0
    for task_replays, old_policy in zip(iter_replays, iter_policies):
        train_replays = task_replays[:-1]
        valid_episodes = task_replays[-1]
        new_policy = clone_module(policy)

        # Fast Adapt to the training episodes
        for train_episodes in train_replays:
            new_policy = trpo_update(train_episodes,
                                     new_policy,
                                     baseline,
                                     params['inner_lr'],
                                     params['gamma'],
                                     params['tau'],
                                     anil=anil,
                                     first_order=False)

        # Calculate KL from the validation episodes
        states, actions, rewards, dones, next_states = get_episode_values(
            valid_episodes)

        # Compute KL
        old_densities = old_policy.density(states)
        new_densities = new_policy.density(states)
        kl = kl_divergence(new_densities, old_densities).mean()
        mean_kl += kl

        # Compute Surrogate Loss
        advantages = compute_advantages(baseline, params['tau'],
                                        params['gamma'], rewards, dones,
                                        states, next_states)
        advantages = ch.normalize(advantages).detach()
        old_log_probs = old_densities.log_prob(actions).mean(
            dim=1, keepdim=True).detach()
        new_log_probs = new_densities.log_prob(actions).mean(dim=1,
                                                             keepdim=True)
        mean_loss += trpo.policy_loss(new_log_probs, old_log_probs, advantages)

    mean_kl /= len(iter_replays)
    mean_loss /= len(iter_replays)
    return mean_loss, mean_kl
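
The get_episode_values helper used above is likewise not shown. Given how its five return values are unpacked at the call site, and how the other examples read replays, it is presumably a thin wrapper like the following (an assumption, not the project's code):

def get_episode_values(episodes):
    # Presumed convenience wrapper over the replay accessors, returning the
    # tensors in the order the call site unpacks them.
    return (episodes.state(), episodes.action(), episodes.reward(),
            episodes.done(), episodes.next_state())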
Code Example #5
File: promp.py / Project: zstbackcourt/learn2learn
def main(
    env_name='AntDirection-v1',
    adapt_lr=0.1,
    meta_lr=3e-4,
    adapt_steps=3,
    num_iterations=1000,
    meta_bsz=40,
    adapt_bsz=20,
    ppo_clip=0.3,
    ppo_steps=5,
    tau=1.00,
    gamma=0.99,
    eta=0.0005,
    adaptive_penalty=False,
    kl_target=0.01,
    num_workers=4,
    seed=421,
):
    random.seed(seed)
    np.random.seed(seed)
    th.manual_seed(seed)

    def make_env():
        env = gym.make(env_name)
        env = ch.envs.ActionSpaceScaler(env)
        return env

    env = l2l.gym.AsyncVectorEnv([make_env for _ in range(num_workers)])
    env.seed(seed)
    env = ch.envs.ActionSpaceScaler(env)
    env = ch.envs.Torch(env)
    policy = DiagNormalPolicy(input_size=env.state_size,
                              output_size=env.action_size,
                              hiddens=[64, 64],
                              activation='tanh')
    meta_learner = l2l.algorithms.MAML(policy, lr=meta_lr)
    baseline = LinearValue(env.state_size, env.action_size)
    opt = optim.Adam(meta_learner.parameters(), lr=meta_lr)

    for iteration in range(num_iterations):
        iteration_reward = 0.0
        iteration_replays = []
        iteration_policies = []

        # Sample Trajectories
        for task_config in tqdm(env.sample_tasks(meta_bsz),
                                leave=False,
                                desc='Data'):
            clone = deepcopy(meta_learner)
            env.set_task(task_config)
            env.reset()
            task = ch.envs.Runner(env)
            task_replay = []
            task_policies = []

            # Fast Adapt
            for step in range(adapt_steps):
                for p in clone.parameters():
                    p.detach_().requires_grad_()
                task_policies.append(deepcopy(clone))
                train_episodes = task.run(clone, episodes=adapt_bsz)
                clone = fast_adapt_a2c(clone,
                                       train_episodes,
                                       adapt_lr,
                                       baseline,
                                       gamma,
                                       tau,
                                       first_order=True)
                task_replay.append(train_episodes)

            # Compute Validation Loss
            for p in clone.parameters():
                p.detach_().requires_grad_()
            task_policies.append(deepcopy(clone))
            valid_episodes = task.run(clone, episodes=adapt_bsz)
            task_replay.append(valid_episodes)
            iteration_reward += (valid_episodes.reward().sum().item()
                                 / adapt_bsz)
            iteration_replays.append(task_replay)
            iteration_policies.append(task_policies)

        # Print statistics
        print('\nIteration', iteration)
        adaptation_reward = iteration_reward / meta_bsz
        print('adaptation_reward', adaptation_reward)

        # ProMP meta-optimization
        for ppo_step in tqdm(range(ppo_steps), leave=False, desc='Optim'):
            promp_loss = 0.0
            kl_total = 0.0
            for task_replays, old_policies in zip(iteration_replays,
                                                  iteration_policies):
                new_policy = meta_learner.clone()
                states = task_replays[0].state()
                actions = task_replays[0].action()
                rewards = task_replays[0].reward()
                dones = task_replays[0].done()
                next_states = task_replays[0].next_state()
                old_policy = old_policies[0]
                (old_density, new_density, old_log_probs,
                 new_log_probs) = precompute_quantities(
                     states, actions, old_policy, new_policy)
                advantages = compute_advantages(baseline, tau, gamma, rewards,
                                                dones, states, next_states)
                advantages = ch.normalize(advantages).detach()
                for step in range(adapt_steps):
                    # Compute KL penalty
                    kl_pen = kl_divergence(old_density, new_density).mean()
                    kl_total += kl_pen.item()

                    # Update the clone
                    surr_loss = trpo.policy_loss(new_log_probs, old_log_probs,
                                                 advantages)
                    new_policy.adapt(surr_loss)

                    # Move to next adaptation step
                    states = task_replays[step + 1].state()
                    actions = task_replays[step + 1].action()
                    rewards = task_replays[step + 1].reward()
                    dones = task_replays[step + 1].done()
                    next_states = task_replays[step + 1].next_state()
                    old_policy = old_policies[step + 1]
                    (old_density, new_density, old_log_probs,
                     new_log_probs) = precompute_quantities(
                         states, actions, old_policy, new_policy)

                    # Compute clip loss
                    advantages = compute_advantages(baseline, tau, gamma,
                                                    rewards, dones, states,
                                                    next_states)
                    advantages = ch.normalize(advantages).detach()
                    clip_loss = ppo.policy_loss(new_log_probs,
                                                old_log_probs,
                                                advantages,
                                                clip=ppo_clip)

                    # Combine into ProMP loss
                    promp_loss += clip_loss + eta * kl_pen

            kl_total /= meta_bsz * adapt_steps
            promp_loss /= meta_bsz * adapt_steps
            opt.zero_grad()
            promp_loss.backward(retain_graph=True)
            opt.step()

            # Adapt KL penalty based on desired target
            if adaptive_penalty:
                if kl_total < kl_target / 1.5:
                    eta /= 2.0
                elif kl_total > kl_target * 1.5:
                    eta *= 2.0
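
Finally, Example #5 relies on a precompute_quantities helper that the listing omits. Judging from how its four return values feed kl_divergence and the two policy losses, and from the log-probability reduction used in Examples #1 through #4, a minimal sketch would be (the body is an assumption):

def precompute_quantities(states, actions, old_policy, new_policy):
    # Assumed helper: evaluate both policies on the same states and reduce
    # the per-dimension log-probabilities the way the other examples do.
    old_density = old_policy.density(states)
    new_density = new_policy.density(states)
    old_log_probs = old_density.log_prob(actions).mean(dim=1,
                                                       keepdim=True).detach()
    new_log_probs = new_density.log_prob(actions).mean(dim=1, keepdim=True)
    return old_density, new_density, old_log_probs, new_log_probs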