Example #1
def acktr(env,
          policy,
          val_fn=None,
          total_steps=TOTAL_STEPS_DEFAULT,
          steps=125,
          n_envs=16,
          gamma=0.99,
          gaelam=0.96,
          val_iters=20,
          pikfac=None,
          vfkfac=None,
          warm_start=None,
          linesearch=True,
          **saver_kwargs):
    # handling default values
    pikfac = pikfac or {}
    vfkfac = vfkfac or {}
    val_fn = val_fn or ValueFunction.from_policy(policy)

    # save config and setup state saving
    logu.save_config(locals())
    saver = SnapshotSaver(logger.get_dir(), locals(), **saver_kwargs)

    # initialize models and optimizer
    vec_env = VecEnvMaker(env)(n_envs)
    policy = policy.pop("class")(vec_env, **policy)
    val_fn = val_fn.pop("class")(vec_env, **val_fn)
    pol_optim = KFACOptimizer(policy, **{**DEFAULT_PIKFAC, **pikfac})
    val_optim = KFACOptimizer(val_fn, **{**DEFAULT_VFKFAC, **vfkfac})
    loss_fn = torch.nn.MSELoss()

    # load state if provided
    if warm_start is not None:
        if ":" in warm_start:
            warm_start, index = warm_start.split(":")
        else:
            index = None
        _, state = SnapshotSaver(warm_start,
                                 latest_only=False).get_state(int(index))
        policy.load_state_dict(state["policy"])
        val_fn.load_state_dict(state["val_fn"])
        if "pol_optim" in state:
            pol_optim.load_state_dict(state["pol_optim"])
        if "val_optim" in state:
            val_optim.load_state_dict(state["val_optim"])

    # Algorithm main loop
    collector = parallel_samples_collector(vec_env, policy, steps)
    beg, end, stp = steps * n_envs, total_steps + steps * n_envs, steps * n_envs
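    # "samples" is the cumulative environment step count after each iteration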
    for samples in trange(beg, end, stp, desc="Training", unit="step"):
        logger.info("Starting iteration {}".format(samples // stp))
        logger.logkv("Iteration", samples // stp)

        logger.info("Start collecting samples")
        trajs = next(collector)

        logger.info("Computing policy gradient variables")
        compute_pg_vars(trajs, val_fn, gamma, gaelam)
        flatten_trajs(trajs)
        all_obs, all_acts, _, _, all_advs, all_vals, all_rets = trajs.values()
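        # drop the trailing observation/value kept per environment for
        # bootstrapping, so lengths match the actions, advantages and returns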
        all_obs, all_vals = all_obs[:-n_envs], all_vals[:-n_envs]

        logger.info("Computing natural gradient using KFAC")
        with pol_optim.record_stats():
            policy.zero_grad()
            all_dists = policy(all_obs)
            all_logp = all_dists.log_prob(all_acts)
            all_logp.mean().backward(retain_graph=True)

        policy.zero_grad()
        old_dists, old_logp = all_dists.detach(), all_logp.detach()
        surr_loss = -((all_logp - old_logp).exp() * all_advs).mean()
        surr_loss.backward()
        pol_grad = [p.grad.clone() for p in policy.parameters()]
        pol_optim.step()

        if linesearch:
            logger.info("Performing line search")
            kl_clip = pol_optim.state["kl_clip"]
            expected_improvement = sum(
                (g * p.grad.data).sum()
                for g, p in zip(pol_grad, policy.parameters())).item()

            def f_barrier(scale):
                for p in policy.parameters():
                    p.data.add_(p.grad.data, alpha=scale)
                new_dists = policy(all_obs)
                for p in policy.parameters():
                    p.data.sub_(p.grad.data, alpha=scale)
                new_logp = new_dists.log_prob(all_acts)
                surr_loss = -((new_logp - old_logp).exp() * all_advs).mean()
                avg_kl = kl(old_dists, new_dists).mean().item()
                return surr_loss.item() if avg_kl < kl_clip else float("inf")

            scale, expected_improvement, improvement = line_search(
                f_barrier,
                x0=1,
                dx=1,
                expected_improvement=expected_improvement,
                y0=surr_loss.item(),
            )
            logger.logkv("ExpectedImprovement", expected_improvement)
            logger.logkv("ActualImprovement", improvement)
            logger.logkv("ImprovementRatio",
                         improvement / expected_improvement)
            for p in policy.parameters():
                p.data.add_(p.grad.data, alpha=scale)

        logger.info("Updating val_fn")
        for _ in range(val_iters):
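            # regress on noisy targets sampled around the current predictions
            # so the recorded statistics approximate the output-distribution
            # Fisher (the Gauss-Newton trick used in ACKTR)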
            with val_optim.record_stats():
                val_fn.zero_grad()
                values = val_fn(all_obs)
                noise = values.detach() + 0.5 * torch.randn_like(values)
                loss_fn(values, noise).backward(retain_graph=True)

            val_fn.zero_grad()
            val_loss = loss_fn(values, all_rets)
            val_loss.backward()
            val_optim.step()

        logger.info("Logging information")
        logger.logkv("TotalNSamples", samples)
        logu.log_reward_statistics(vec_env)
        logu.log_val_fn_statistics(all_vals, all_rets)
        logu.log_action_distribution_statistics(old_dists)
        logu.log_average_kl_divergence(old_dists, policy, all_obs)
        logger.dumpkvs()

        logger.info("Saving snapshot")
        saver.save_state(
            index=samples // stp,
            state=dict(
                alg=dict(last_iter=samples // stp),
                policy=policy.state_dict(),
                val_fn=val_fn.state_dict(),
                pol_optim=pol_optim.state_dict(),
                val_optim=val_optim.state_dict(),
            ),
        )

    vec_env.close()
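
A hedged call sketch (inferred from the dict handling above, not from the repository's documentation): `policy` and `val_fn` are specification dicts whose "class" entry is popped and instantiated with the vectorized environment, and `warm_start` may name a snapshot directory, optionally suffixed with ":index". `MlpPolicy` and the path below are placeholders.

# acktr(
#     env="CartPole-v1",
#     policy={"class": MlpPolicy, "hidden_sizes": (64, 64)},
#     total_steps=10**6,
#     warm_start="path/to/previous/run:42",
# )
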
Example #2
def trpo(env,
         policy,
         val_fn=None,
         total_steps=TOTAL_STEPS_DEFAULT,
         steps=125,
         n_envs=16,
         gamma=0.99,
         gaelam=0.97,
         kl_frac=1.0,
         delta=0.01,
         val_iters=80,
         val_lr=1e-3,
         linesearch=True,
         **saver_kwargs):
    val_fn = val_fn or ValueFunction.from_policy(policy)
    logu.save_config(locals())
    saver = SnapshotSaver(logger.get_dir(), locals(), **saver_kwargs)

    vec_env = VecEnvMaker(env)(n_envs)
    policy = policy.pop("class")(vec_env, **policy)
    val_fn = val_fn.pop("class")(vec_env, **val_fn)
    val_optim = torch.optim.Adam(val_fn.parameters(), lr=val_lr)
    loss_fn = torch.nn.MSELoss()

    # Algorithm main loop
    collector = parallel_samples_collector(vec_env, policy, steps)
    beg, end, stp = steps * n_envs, total_steps + steps * n_envs, steps * n_envs
    for samples in trange(beg, end, stp, desc="Training", unit="step"):
        logger.info("Starting iteration {}".format(samples // stp))
        logger.logkv("Iteration", samples // stp)

        logger.info("Start collecting samples")
        trajs = next(collector)

        logger.info("Computing policy gradient variables")
        compute_pg_vars(trajs, val_fn, gamma, gaelam)
        flatten_trajs(trajs)
        all_obs, all_acts, _, _, all_advs, all_vals, all_rets = trajs.values()
        all_obs, all_vals = all_obs[:-n_envs], all_vals[:-n_envs]

        # subsample for fisher vector product computation
        if kl_frac < 1.0:
            n_samples = int(kl_frac * len(all_obs))
            indexes = torch.randperm(len(all_obs))[:n_samples]
            subsamp_obs = all_obs.index_select(0, indexes)
        else:
            subsamp_obs = all_obs

        logger.info("Computing policy gradient")
        all_dists = policy(all_obs)
        all_logp = all_dists.log_prob(all_acts)
        old_dists = all_dists.detach()
        old_logp = old_dists.log_prob(all_acts)
        surr_loss = -((all_logp - old_logp).exp() * all_advs).mean()
        pol_grad = flat_grad(surr_loss, policy.parameters())

        logger.info("Computing truncated natural gradient")

        def fvp(v):
            return fisher_vec_prod(v, subsamp_obs, policy)

        descent_direction = conjugate_gradient(fvp, pol_grad)
        scale = torch.sqrt(2 * delta /
                           (pol_grad.dot(descent_direction) + 1e-8))
        descent_step = descent_direction * scale

        if linesearch:
            logger.info("Performing line search")
            expected_improvement = pol_grad.dot(descent_step).item()

            def f_barrier(params,
                          all_obs=all_obs,
                          all_acts=all_acts,
                          all_advs=all_advs):
                vector_to_parameters(params, policy.parameters())
                new_dists = policy(all_obs)
                new_logp = new_dists.log_prob(all_acts)
                surr_loss = -((new_logp - old_logp).exp() * all_advs).mean()
                avg_kl = kl(old_dists, new_dists).mean().item()
                return surr_loss.item() if avg_kl < delta else float("inf")

            new_params, expected_improvement, improvement = line_search(
                f_barrier,
                parameters_to_vector(policy.parameters()),
                descent_step,
                expected_improvement,
                y0=surr_loss.item(),
            )
            logger.logkv("ExpectedImprovement", expected_improvement)
            logger.logkv("ActualImprovement", improvement)
            logger.logkv("ImprovementRatio",
                         improvement / expected_improvement)
        else:
            new_params = parameters_to_vector(
                policy.parameters()) - descent_step
        vector_to_parameters(new_params, policy.parameters())

        logger.info("Updating val_fn")
        for _ in range(val_iters):
            val_optim.zero_grad()
            loss_fn(val_fn(all_obs), all_rets).backward()
            val_optim.step()

        logger.info("Logging information")
        logger.logkv("TotalNSamples", samples)
        logu.log_reward_statistics(vec_env)
        logu.log_val_fn_statistics(all_vals, all_rets)
        logu.log_action_distribution_statistics(old_dists)
        logu.log_average_kl_divergence(old_dists, policy, all_obs)
        logger.dumpkvs()

        logger.info("Saving snapshot")
        saver.save_state(
            index=samples // stp,
            state=dict(
                alg=dict(last_iter=samples // stp),
                policy=policy.state_dict(),
                val_fn=val_fn.state_dict(),
                val_optim=val_optim.state_dict(),
            ),
        )
        del all_obs, all_acts, all_advs, all_vals, all_rets, trajs

    vec_env.close()
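
The `fisher_vec_prod` and `conjugate_gradient` helpers used above are not shown. Below is a minimal sketch of the usual Fisher-vector product via a Hessian-vector product of the mean KL, assuming the same `kl` helper and distribution objects as in the code above; the function name and damping constant are illustrative, not the repository's implementation.

import torch


def fisher_vec_prod_sketch(v, obs, policy, damping=1e-3):
    # The Hessian of KL(old || new) at the current parameters equals the
    # Fisher information matrix, so a double backward gives F @ v.
    dists = policy(obs)
    avg_kl = kl(dists.detach(), dists).mean()
    grads = torch.autograd.grad(
        avg_kl, list(policy.parameters()), create_graph=True)
    flat_grad_kl = torch.cat([g.reshape(-1) for g in grads])
    grad_v_prod = (flat_grad_kl * v).sum()
    hvp = torch.autograd.grad(grad_v_prod, list(policy.parameters()))
    return torch.cat([g.reshape(-1) for g in hvp]) + damping * v
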
Example #3
def ppo(env,
        policy,
        val_fn=None,
        total_steps=TOTAL_STEPS_DEFAULT,
        steps=125,
        n_envs=16,
        gamma=0.99,
        gaelam=0.96,
        clip_ratio=0.2,
        pol_iters=80,
        val_iters=80,
        pol_lr=3e-4,
        val_lr=1e-3,
        target_kl=0.01,
        mb_size=100,
        **saver_kwargs):
    val_fn = val_fn or ValueFunction.from_policy(policy)

    logu.save_config(locals())
    saver = SnapshotSaver(logger.get_dir(), locals(), **saver_kwargs)

    vec_env = VecEnvMaker(env)(n_envs)
    policy = policy.pop("class")(vec_env, **policy)
    val_fn = val_fn.pop("class")(vec_env, **val_fn)
    pol_optim = torch.optim.Adam(policy.parameters(), lr=pol_lr)
    val_optim = torch.optim.Adam(val_fn.parameters(), lr=val_lr)
    loss_fn = torch.nn.MSELoss()

    # Algorithm main loop
    collector = parallel_samples_collector(vec_env, policy, steps)
    beg, end, stp = steps * n_envs, total_steps + steps * n_envs, steps * n_envs
    for samples in trange(beg, end, stp, desc="Training", unit="step"):
        logger.info("Starting iteration {}".format(samples // stp))
        logger.logkv("Iteration", samples // stp)

        logger.info("Start collecting samples")
        trajs = next(collector)

        logger.info("Computing policy gradient variables")
        compute_pg_vars(trajs, val_fn, gamma, gaelam)
        flatten_trajs(trajs)
        all_obs, all_acts, _, _, all_advs, all_vals, all_rets = trajs.values()
        all_obs, all_vals = all_obs[:-n_envs], all_vals[:-n_envs]

        logger.info("Minimizing surrogate loss")
        with torch.no_grad():
            old_dists = policy(all_obs)
        old_logp = old_dists.log_prob(all_acts)
        min_advs = torch.where(all_advs > 0, (1 + clip_ratio) * all_advs,
                               (1 - clip_ratio) * all_advs)
        dataset = TensorDataset(all_obs, all_acts, all_advs, min_advs,
                                old_logp)
        dataloader = DataLoader(dataset, batch_size=mb_size, shuffle=True)
        for itr in range(pol_iters):
            for obs, acts, advs, min_adv, logp in dataloader:
                ratios = (policy(obs).log_prob(acts) - logp).exp()
                pol_optim.zero_grad()
                (-torch.min(ratios * advs, min_adv)).mean().backward()
                pol_optim.step()

            with torch.no_grad():
                mean_kl = kl(old_dists, policy(all_obs)).mean().item()
            if mean_kl > 1.5 * target_kl:
                logger.info(
                    "Stopped at step {} due to reaching max kl".format(itr + 1))
                break
        logger.logkv("StopIter", itr + 1)

        logger.info("Updating val_fn")
        for _ in range(val_iters):
            val_optim.zero_grad()
            loss_fn(val_fn(all_obs), all_rets).backward()
            val_optim.step()

        logger.info("Logging information")
        logger.logkv("TotalNSamples", samples)
        logu.log_reward_statistics(vec_env)
        logu.log_val_fn_statistics(all_vals, all_rets)
        logu.log_action_distribution_statistics(old_dists)
        logger.logkv("MeanKL", mean_kl)
        logger.dumpkvs()

        logger.info("Saving snapshot")
        saver.save_state(
            index=samples // stp,
            state=dict(
                alg=dict(last_iter=samples // stp),
                policy=policy.state_dict(),
                val_fn=val_fn.state_dict(),
                pol_optim=pol_optim.state_dict(),
                val_optim=val_optim.state_dict(),
            ),
        )

    vec_env.close()
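
A note on the `min_advs` precomputation above: taking `min(ratio * advs, min_advs)` with `min_advs = where(advs > 0, (1 + clip) * advs, (1 - clip) * advs)` is equivalent to the textbook clipped surrogate `min(ratio * advs, clamp(ratio, 1 - clip, 1 + clip) * advs)`, which follows from a short case analysis on the sign of the advantage. A small self-contained check:

import torch

clip_ratio = 0.2
advs = torch.randn(1000)
ratios = torch.rand(1000) * 2  # importance ratios are always non-negative

min_advs = torch.where(advs > 0, (1 + clip_ratio) * advs,
                       (1 - clip_ratio) * advs)
precomputed = torch.min(ratios * advs, min_advs)
textbook = torch.min(ratios * advs,
                     ratios.clamp(1 - clip_ratio, 1 + clip_ratio) * advs)
assert torch.allclose(precomputed, textbook)
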
Example #4
def natural(env,
            policy,
            val_fn=None,
            total_steps=TOTAL_STEPS_DEFAULT,
            steps=125,
            n_envs=16,
            gamma=0.99,
            gaelam=0.97,
            kl_frac=1.0,
            delta=0.01,
            val_iters=80,
            val_lr=1e-3,
            **saver_kwargs):
    """
    Natural Policy Gradient

    env: instance of proj.common.env_makers.VecEnvMaker
    policy: instance of proj.common.models.Policy
    val_fn (optional): instance of proj.common.models.ValueFunction
    total_steps: total number of environment steps to take
    steps: number of steps to take in each environment per iteration
    n_envs: number of environment copies to run in parallel
    gamma: GAE discount parameter
    gaelam: GAE lambda exponential average parameter
    kl_frac: fraction of samples used for Fisher-vector product computation
    delta: size of the trust region (maximum approximate KL per update)
    val_iters: number of optimization steps to update the critic per iteration
    val_lr: learning rate for critic optimizer
    saver_kwargs: keyword arguments for proj.utils.saver.SnapshotSaver
    """
    val_fn = val_fn or ValueFunction.from_policy(policy)
    logu.save_config(locals())
    saver = SnapshotSaver(logger.get_dir(), locals(), **saver_kwargs)

    vec_env = VecEnvMaker(env)(n_envs)
    policy = policy.pop("class")(vec_env, **policy)
    val_fn = val_fn.pop("class")(vec_env, **val_fn)
    val_optim = torch.optim.Adam(val_fn.parameters(), lr=val_lr)
    loss_fn = torch.nn.MSELoss()

    # Algorithm main loop
    collector = parallel_samples_collector(vec_env, policy, steps)
    beg, end, stp = steps * n_envs, total_steps + steps * n_envs, steps * n_envs
    for samples in trange(beg, end, stp, desc="Training", unit="step"):
        logger.info("Starting iteration {}".format(samples // stp))
        logger.logkv("Iteration", samples // stp)

        logger.info("Start collecting samples")
        trajs = next(collector)

        logger.info("Computing policy gradient variables")
        compute_pg_vars(trajs, val_fn, gamma, gaelam)
        flatten_trajs(trajs)
        all_obs, all_acts, _, _, all_advs, all_vals, all_rets = trajs.values()
        all_obs, all_vals = all_obs[:-n_envs], all_vals[:-n_envs]

        # subsample for fisher vector product computation
        if kl_frac < 1.0:
            n_samples = int(kl_frac * len(all_obs))
            indexes = torch.randperm(len(all_obs))[:n_samples]
            subsamp_obs = all_obs.index_select(0, indexes)
        else:
            subsamp_obs = all_obs

        logger.info("Computing policy gradient")
        all_dists = policy(all_obs)
        old_dists = all_dists.detach()
        pol_loss = torch.mean(all_dists.log_prob(all_acts) * all_advs).neg()
        pol_grad = flat_grad(pol_loss, policy.parameters())

        logger.info("Computing truncated natural gradient")

        def fvp(v):
            return fisher_vec_prod(v, subsamp_obs, policy)

        descent_direction = conjugate_gradient(fvp, pol_grad)
        scale = torch.sqrt(2 * delta /
                           (pol_grad.dot(descent_direction) + 1e-8))
        descent_step = descent_direction * scale
        new_params = parameters_to_vector(policy.parameters()) - descent_step
        vector_to_parameters(new_params, policy.parameters())

        logger.info("Updating val_fn")
        for _ in range(val_iters):
            val_optim.zero_grad()
            loss_fn(val_fn(all_obs), all_rets).backward()
            val_optim.step()

        logger.info("Logging information")
        logger.logkv("TotalNSamples", samples)
        logu.log_reward_statistics(vec_env)
        logu.log_val_fn_statistics(all_vals, all_rets)
        logu.log_action_distribution_statistics(old_dists)
        logu.log_average_kl_divergence(old_dists, policy, all_obs)
        logger.dumpkvs()

        logger.info("Saving snapshot")
        saver.save_state(
            samples // stp,
            dict(
                alg=dict(last_iter=samples // stp),
                policy=policy.state_dict(),
                val_fn=val_fn.state_dict(),
                val_optim=val_optim.state_dict(),
            ),
        )
        del all_obs, all_acts, all_advs, all_vals, all_rets, trajs

    vec_env.close()
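
`conjugate_gradient` above presumably solves F x = g using only Fisher-vector products; the subsequent `scale` then makes the quadratic approximation of the KL divergence, 0.5 * step^T F step, equal to `delta`. A standard solver sketch under that assumption (iteration count and tolerance are illustrative):

import torch


def conjugate_gradient_sketch(fvp, b, iters=10, tol=1e-10):
    # Approximately solve F x = b given only the matrix-vector product fvp(v).
    x = torch.zeros_like(b)
    r = b.clone()  # residual b - F x, with x initialized to zero
    p = b.clone()  # search direction
    rdotr = r.dot(r)
    for _ in range(iters):
        fvp_p = fvp(p)
        alpha = rdotr / p.dot(fvp_p)
        x = x + alpha * p
        r = r - alpha * fvp_p
        new_rdotr = r.dot(r)
        if new_rdotr < tol:
            break
        p = r + (new_rdotr / rdotr) * p
        rdotr = new_rdotr
    return x
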
Example #5
def vanilla(
    env,
    policy,
    val_fn=None,
    total_steps=TOTAL_STEPS_DEFAULT,
    steps=125,
    n_envs=16,
    gamma=0.99,
    gaelam=0.97,
    optimizer=None,
    val_iters=80,
    val_lr=1e-3,
    **saver_kwargs
):
    """
    Vanilla Policy Gradient

    env: instance of proj.common.env_makers.VecEnvMaker
    policy: instance of proj.common.models.Policy
    val_fn (optional): instance of proj.common.models.ValueFunction
    total_steps: total number of environment steps to take
    steps: number of steps to take in each environment per iteration
    n_envs: number of environment copies to run in parallel
    gamma: GAE discount parameter
    gaelam: GAE lambda exponential average parameter
    optimizer (optional): dictionary containing optimizer kwargs and/or class
    val_iters: number of optimization steps to update the critic per iteration
    val_lr: learning rate for critic optimizer
    saver_kwargs: keyword arguments for proj.utils.saver.SnapshotSaver
    """
    optimizer = optimizer or {}
    optimizer = {"class": torch.optim.Adam, **optimizer}
    val_fn = val_fn or ValueFunction.from_policy(policy)

    logu.save_config(locals())
    saver = SnapshotSaver(logger.get_dir(), locals(), **saver_kwargs)

    vec_env = VecEnvMaker(env)(n_envs)
    policy = policy.pop("class")(vec_env, **policy)
    val_fn = val_fn.pop("class")(vec_env, **val_fn)
    pol_optim = optimizer.pop("class")(policy.parameters(), **optimizer)
    val_optim = torch.optim.Adam(val_fn.parameters(), lr=val_lr)
    loss_fn = torch.nn.MSELoss()

    # Algorithm main loop
    collector = parallel_samples_collector(vec_env, policy, steps)
    beg, end, stp = steps * n_envs, total_steps + steps * n_envs, steps * n_envs
    for samples in trange(beg, end, stp, desc="Training", unit="step"):
        logger.info("Starting iteration {}".format(samples // stp))
        logger.logkv("Iteration", samples // stp)

        logger.info("Start collecting samples")
        trajs = next(collector)

        logger.info("Computing policy gradient variables")
        compute_pg_vars(trajs, val_fn, gamma, gaelam)
        flatten_trajs(trajs)
        all_obs, all_acts, _, _, all_advs, all_vals, all_rets = trajs.values()
        all_obs, all_vals = all_obs[:-n_envs], all_vals[:-n_envs]

        logger.info("Applying policy gradient")
        all_dists = policy(all_obs)
        old_dists = all_dists.detach()
        objective = torch.mean(all_dists.log_prob(all_acts) * all_advs)
        pol_optim.zero_grad()
        objective.neg().backward()
        pol_optim.step()

        logger.info("Updating val_fn")
        for _ in range(val_iters):
            val_optim.zero_grad()
            loss_fn(val_fn(all_obs), all_rets).backward()
            val_optim.step()

        logger.info("Logging information")
        logger.logkv("Objective", objective.item())
        logger.logkv("TotalNSamples", samples)
        logu.log_reward_statistics(vec_env)
        logu.log_val_fn_statistics(all_vals, all_rets)
        logu.log_action_distribution_statistics(old_dists)
        logu.log_average_kl_divergence(old_dists, policy, all_obs)
        logger.dumpkvs()

        logger.info("Saving snapshot")
        saver.save_state(
            samples // stp,
            dict(
                alg=dict(last_iter=samples // stp),
                policy=policy.state_dict(),
                val_fn=val_fn.state_dict(),
                pol_optim=pol_optim.state_dict(),
                val_optim=val_optim.state_dict(),
            ),
        )
        del all_obs, all_acts, all_advs, all_vals, all_rets, trajs

    vec_env.close()
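
Per the docstring, `optimizer` may override the default Adam class and/or supply keyword arguments for it; a hedged call sketch follows (`MlpPolicy` is a placeholder name, not a class confirmed to exist in this repository).

# vanilla(
#     env="HalfCheetah-v2",
#     policy={"class": MlpPolicy, "hidden_sizes": (64, 64)},
#     optimizer={"class": torch.optim.RMSprop, "lr": 1e-3},
#     total_steps=500_000,
# )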