Example #1
0
    def thunk_plus():
        """Seed RNGs, configure logging, then run ``thunk(**kwargs)``.

        All inputs (seed, log_dir, format_strs, exp_name, thunk, kwargs) are
        taken from the enclosing scope.
        """
        # Imports are deferred until the thunk is actually invoked.
        import random
        import numpy as np
        import torch
        from baselines import logger
        from proj.utils.tqdm_util import tqdm_out

        # Seed numpy, stdlib random and torch (in that order) with one value.
        for seed_rng in (np.random.seed, random.seed, torch.manual_seed):
            seed_rng(seed)

        # Fixed intra-op thread count for torch.
        torch.set_num_threads(4)

        with tqdm_out():
            with logger.scoped_configure(log_dir, format_strs):
                from proj.common.log_utils import save_config

                logger.set_level(logger.WARN)
                run_config = {"exp_name": exp_name, "alg": thunk}
                save_config(run_config)
                thunk(**kwargs)
Example #2
0
def acktr(env,
          policy,
          val_fn=None,
          total_steps=TOTAL_STEPS_DEFAULT,
          steps=125,
          n_envs=16,
          gamma=0.99,
          gaelam=0.96,
          val_iters=20,
          pikfac=None,
          vfkfac=None,
          warm_start=None,
          linesearch=True,
          **saver_kwargs):
    """
    ACKTR: actor-critic with KFAC-preconditioned policy and critic updates.

    env: environment specification passed to proj.common.env_makers.VecEnvMaker
    policy: dict with a "class" entry plus constructor kwargs for the policy
    val_fn (optional): same format for the critic; defaults to
        ValueFunction.from_policy(policy)
    total_steps: total number of environment steps to take
    steps: number of steps to take in each environment per iteration
    n_envs: number of environment copies to run in parallel
    gamma: GAE discount parameter
    gaelam: GAE lambda exponential average parameter
    val_iters: number of KFAC steps on the critic per iteration
    pikfac: overrides merged over DEFAULT_PIKFAC for the policy optimizer
    vfkfac: overrides merged over DEFAULT_VFKFAC for the critic optimizer
    warm_start (optional): snapshot directory to resume from, optionally
        suffixed with ":<index>" to select a specific snapshot
    linesearch: whether to line-search the policy step under a KL barrier
    saver_kwargs: keyword arguments for proj.utils.saver.SnapshotSaver
    """
    # handling default values
    pikfac = pikfac or {}
    vfkfac = vfkfac or {}
    val_fn = val_fn or ValueFunction.from_policy(policy)

    # save config and setup state saving (locals() is the recorded config, so
    # no extra locals may be introduced above this point)
    logu.save_config(locals())
    saver = SnapshotSaver(logger.get_dir(), locals(), **saver_kwargs)

    # initialize models and optimizer
    vec_env = VecEnvMaker(env)(n_envs)
    policy = policy.pop("class")(vec_env, **policy)
    val_fn = val_fn.pop("class")(vec_env, **val_fn)
    pol_optim = KFACOptimizer(policy, **{**DEFAULT_PIKFAC, **pikfac})
    val_optim = KFACOptimizer(val_fn, **{**DEFAULT_VFKFAC, **vfkfac})
    loss_fn = torch.nn.MSELoss()

    # load state if provided
    if warm_start is not None:
        if ":" in warm_start:
            # rsplit so a path that itself contains ":" still unpacks
            warm_start, index = warm_start.rsplit(":", 1)
            index = int(index)
        else:
            # No explicit index: pass None through instead of int(None).
            # NOTE(review): assumes get_state(None) picks the default
            # (latest) snapshot -- confirm against SnapshotSaver.
            index = None
        _, state = SnapshotSaver(warm_start,
                                 latest_only=False).get_state(index)
        policy.load_state_dict(state["policy"])
        val_fn.load_state_dict(state["val_fn"])
        if "pol_optim" in state:
            pol_optim.load_state_dict(state["pol_optim"])
        if "val_optim" in state:
            val_optim.load_state_dict(state["val_optim"])

    # Algorithm main loop
    collector = parallel_samples_collector(vec_env, policy, steps)
    beg, end, stp = steps * n_envs, total_steps + steps * n_envs, steps * n_envs
    for samples in trange(beg, end, stp, desc="Training", unit="step"):
        logger.info("Starting iteration {}".format(samples // stp))
        logger.logkv("Iteration", samples // stp)

        logger.info("Start collecting samples")
        trajs = next(collector)

        logger.info("Computing policy gradient variables")
        compute_pg_vars(trajs, val_fn, gamma, gaelam)
        flatten_trajs(trajs)
        all_obs, all_acts, _, _, all_advs, all_vals, all_rets = trajs.values()
        # drop the trailing bootstrap entries (one per env)
        all_obs, all_vals = all_obs[:-n_envs], all_vals[:-n_envs]

        logger.info("Computing natural gradient using KFAC")
        # KFAC statistics are recorded from the log-likelihood backward pass
        with pol_optim.record_stats():
            policy.zero_grad()
            all_dists = policy(all_obs)
            all_logp = all_dists.log_prob(all_acts)
            all_logp.mean().backward(retain_graph=True)

        policy.zero_grad()
        old_dists, old_logp = all_dists.detach(), all_logp.detach()
        surr_loss = -((all_logp - old_logp).exp() * all_advs).mean()
        surr_loss.backward()
        pol_grad = [p.grad.clone() for p in policy.parameters()]
        # NOTE(review): with linesearch=True the (post-step) p.grad is also
        # added to the parameters below, whereas with linesearch=False only
        # pol_optim.step() runs -- confirm this matches KFACOptimizer.step().
        pol_optim.step()

        if linesearch:
            logger.info("Performing line search")
            kl_clip = pol_optim.state["kl_clip"]
            params = list(policy.parameters())
            expected_improvement = sum(
                (g * p.grad.data).sum()
                for g, p in zip(pol_grad, params)).item()

            @torch.no_grad()
            def f_barrier(scale):
                # Surrogate loss at theta + scale * p.grad, or +inf when the
                # mean KL to the old policy reaches kl_clip. Parameters are
                # restored exactly from a backup (add/sub round trips drift),
                # and only after everything depending on them is evaluated.
                backup = [p.data.clone() for p in params]
                try:
                    for p in params:
                        p.data.add_(p.grad.data, alpha=scale)
                    new_dists = policy(all_obs)
                    new_logp = new_dists.log_prob(all_acts)
                    new_loss = -((new_logp - old_logp).exp() *
                                 all_advs).mean().item()
                    avg_kl = kl(old_dists, new_dists).mean().item()
                finally:
                    for p, saved in zip(params, backup):
                        p.data.copy_(saved)
                return new_loss if avg_kl < kl_clip else float("inf")

            scale, expected_improvement, improvement = line_search(
                f_barrier,
                x0=1,
                dx=1,
                expected_improvement=expected_improvement,
                y0=surr_loss.item(),
            )
            logger.logkv("ExpectedImprovement", expected_improvement)
            logger.logkv("ActualImprovement", improvement)
            logger.logkv("ImprovementRatio",
                         improvement / expected_improvement)
            for p in params:
                p.data.add_(p.grad.data, alpha=scale)

        logger.info("Updating val_fn")
        for _ in range(val_iters):
            # KFAC statistics for the critic come from a noisy-target loss
            with val_optim.record_stats():
                val_fn.zero_grad()
                values = val_fn(all_obs)
                noise = values.detach() + 0.5 * torch.randn_like(values)
                loss_fn(values, noise).backward(retain_graph=True)

            val_fn.zero_grad()
            val_loss = loss_fn(values, all_rets)
            val_loss.backward()
            val_optim.step()

        logger.info("Logging information")
        logger.logkv("TotalNSamples", samples)
        logu.log_reward_statistics(vec_env)
        logu.log_val_fn_statistics(all_vals, all_rets)
        logu.log_action_distribution_statistics(old_dists)
        logu.log_average_kl_divergence(old_dists, policy, all_obs)
        logger.dumpkvs()

        logger.info("Saving snapshot")
        saver.save_state(
            index=samples // stp,
            state=dict(
                alg=dict(last_iter=samples // stp),
                policy=policy.state_dict(),
                val_fn=val_fn.state_dict(),
                pol_optim=pol_optim.state_dict(),
                val_optim=val_optim.state_dict(),
            ),
        )

    vec_env.close()
Example #3
0
def sac(env,
        policy,
        q_func=None,
        val_fn=None,
        total_steps=TOTAL_STEPS_DEFAULT,
        gamma=0.99,
        replay_size=REPLAY_SIZE_DEFAULT,
        polyak=0.995,
        start_steps=10000,
        epoch=5000,
        mb_size=100,
        lr=1e-3,
        alpha=0.2,
        target_entropy=None,
        reward_scale=1.0,
        updates_per_step=1.0,
        max_ep_length=1000,
        **saver_kwargs):
    """
    Soft Actor-Critic with twin Q-functions and a state-value network.

    env: environment specification passed to proj.common.env_makers.VecEnvMaker
    policy: dict with a "class" entry plus constructor kwargs for the policy
    q_func (optional): same format for the Q-functions; defaults to
        ContinuousQFunction.from_policy(policy)
    val_fn (optional): same format for the state-value function; defaults to
        ValueFunction.from_policy(policy)
    total_steps: total number of environment steps to take
    gamma: discount factor used in the Q-function targets
    replay_size: capacity of the replay buffer
    polyak: coefficient passed to update_polyak for the value target network
    start_steps: number of initial steps that use uniformly random actions
    epoch: number of steps between test runs, logging and snapshots
    mb_size: minibatch size sampled from the replay buffer
    lr: learning rate for every Adam optimizer
    alpha: entropy temperature; replaced by exp(log_alpha) when
        target_entropy is set
    target_entropy (optional): if given, alpha is learned against this
        entropy target; "auto" uses -prod(action_space.shape)
    reward_scale: multiplier applied to rewards in the Q-function targets
    updates_per_step: gradient updates per environment step, performed as a
        batch at the end of each episode
    max_ep_length: step count at which done is forced to False (time-limit
        truncation) and an update round is triggered
    saver_kwargs: keyword arguments for proj.utils.saver.SnapshotSaver
    """
    # Set and save experiment hyperparameters (locals() is the recorded
    # config, so no extra locals may be introduced above this point)
    q_func = q_func or ContinuousQFunction.from_policy(policy)
    val_fn = val_fn or ValueFunction.from_policy(policy)
    save_config(locals())
    saver = SnapshotSaver(logger.get_dir(), locals(), **saver_kwargs)

    # Initialize environments, models and replay buffer
    vec_env = VecEnvMaker(env)()
    test_env = VecEnvMaker(env)(train=False)
    ob_space, ac_space = vec_env.observation_space, vec_env.action_space
    pi_class, pi_args = policy.pop("class"), policy
    qf_class, qf_args = q_func.pop("class"), q_func
    vf_class, vf_args = val_fn.pop("class"), val_fn
    policy = pi_class(vec_env, **pi_args)
    q1func = qf_class(vec_env, **qf_args)
    q2func = qf_class(vec_env, **qf_args)
    val_fn = vf_class(vec_env, **vf_args)
    replay = ReplayBuffer(replay_size, ob_space, ac_space)
    if target_entropy is not None:
        log_alpha = torch.nn.Parameter(torch.zeros([]))
        if target_entropy == "auto":
            target_entropy = -np.prod(ac_space.shape)

    # Initialize optimizers and target networks
    loss_fn = torch.nn.MSELoss()
    pi_optim = torch.optim.Adam(policy.parameters(), lr=lr)
    qf_optim = torch.optim.Adam(list(q1func.parameters()) +
                                list(q2func.parameters()),
                                lr=lr)
    vf_optim = torch.optim.Adam(val_fn.parameters(), lr=lr)
    vf_targ = vf_class(vec_env, **vf_args)
    vf_targ.load_state_dict(val_fn.state_dict())
    if target_entropy is not None:
        al_optim = torch.optim.Adam([log_alpha], lr=lr)

    def snapshot_state(n_samples):
        # Everything needed to resume training after n_samples env steps.
        state = dict(
            alg=dict(samples=n_samples),
            policy=policy.state_dict(),
            q1func=q1func.state_dict(),
            q2func=q2func.state_dict(),
            val_fn=val_fn.state_dict(),
            pi_optim=pi_optim.state_dict(),
            qf_optim=qf_optim.state_dict(),
            vf_optim=vf_optim.state_dict(),
            vf_targ=vf_targ.state_dict(),
        )
        if target_entropy is not None:
            state["log_alpha"] = log_alpha
            state["al_optim"] = al_optim.state_dict()
        return state

    # Save initial state
    saver.save_state(index=0, state=snapshot_state(0))

    # Setup and run policy tests
    ob, don = test_env.reset(), False

    @torch.no_grad()
    def test_policy():
        # Roll out 10 test episodes in eval mode, then log their rewards.
        nonlocal ob, don
        policy.eval()
        for _ in range(10):
            while not don:
                act = policy.actions(torch.from_numpy(ob))
                ob, _, don, _ = test_env.step(act.numpy())
            don = False
        policy.train()
        log_reward_statistics(test_env, num_last_eps=10, prefix="Test")

    test_policy()
    logger.logkv("Epoch", 0)
    logger.logkv("TotalNSamples", 0)
    logger.dumpkvs()

    # Set action sampling strategies
    def rand_uniform_actions(_):
        return np.stack([ac_space.sample() for _ in range(vec_env.num_envs)])

    @torch.no_grad()
    def stoch_policy_actions(obs):
        return policy.actions(torch.from_numpy(obs)).numpy()

    # Algorithm main loop
    obs1, ep_length = vec_env.reset(), 0
    for samples in trange(1, total_steps + 1, desc="Training", unit="step"):
        if samples <= start_steps:
            actions = rand_uniform_actions
        else:
            actions = stoch_policy_actions

        acts = actions(obs1)
        obs2, rews, dones, _ = vec_env.step(acts)
        ep_length += 1
        # time-limit truncation is not a true terminal: don't mask bootstrap
        dones[0] = False if ep_length == max_ep_length else dones[0]

        as_tensors = map(torch.from_numpy,
                         (obs1, acts, rews, obs2, dones.astype("f")))
        for ob1, act, rew, ob2, done in zip(*as_tensors):
            replay.store(ob1, act, rew, ob2, done)
        obs1 = obs2

        episode_over = dones[0] or ep_length == max_ep_length
        if episode_over and replay.size >= mb_size:
            for _ in range(int(ep_length * updates_per_step)):
                ob_1, act_, rew_, ob_2, done_ = replay.sample(mb_size)
                dist = policy(ob_1)
                pi_a = dist.rsample()
                logp = dist.log_prob(pi_a)
                if target_entropy is not None:
                    al_optim.zero_grad()
                    alpha_loss = torch.mean(
                        log_alpha * (logp.detach() + target_entropy)).neg()
                    alpha_loss.backward()
                    al_optim.step()
                    logger.logkv_mean("AlphaLoss", alpha_loss.item())
                    alpha = log_alpha.exp().item()

                with torch.no_grad():
                    y_qf = reward_scale * rew_ + gamma * (
                        1 - done_) * vf_targ(ob_2)
                    y_vf = (torch.min(q1func(ob_1, pi_a), q2func(ob_1, pi_a)) -
                            alpha * logp)

                qf_optim.zero_grad()
                q1_val = q1func(ob_1, act_)
                q2_val = q2func(ob_1, act_)
                q1_loss = loss_fn(q1_val, y_qf).div(2)
                q2_loss = loss_fn(q2_val, y_qf).div(2)
                q1_loss.add(q2_loss).backward()
                qf_optim.step()

                vf_optim.zero_grad()
                vf_val = val_fn(ob_1)
                vf_loss = loss_fn(vf_val, y_vf).div(2)
                vf_loss.backward()
                vf_optim.step()

                pi_optim.zero_grad()
                qpi_val = q1func(ob_1, pi_a)
                # qpi_val = torch.min(q1func(ob_1, pi_a), q2func(ob_1, pi_a))
                pi_loss = qpi_val.sub(logp, alpha=alpha).mean().neg()
                pi_loss.backward()
                pi_optim.step()

                update_polyak(val_fn, vf_targ, polyak)

                logger.logkv_mean("Entropy", logp.mean().neg().item())
                logger.logkv_mean("Q1Val", q1_val.mean().item())
                logger.logkv_mean("Q2Val", q2_val.mean().item())
                logger.logkv_mean("VFVal", vf_val.mean().item())
                logger.logkv_mean("QPiVal", qpi_val.mean().item())
                logger.logkv_mean("Q1Loss", q1_loss.item())
                logger.logkv_mean("Q2Loss", q2_loss.item())
                logger.logkv_mean("VFLoss", vf_loss.item())
                logger.logkv_mean("PiLoss", pi_loss.item())
                logger.logkv_mean("Alpha", alpha)

        # Reset the counter whenever an episode ends, even while the buffer
        # is still too small to train; otherwise the count leaks into the
        # next episode and max_ep_length masking fires at the wrong step.
        if episode_over:
            ep_length = 0

        if samples % epoch == 0:
            test_policy()
            logger.logkv("Epoch", samples // epoch)
            logger.logkv("TotalNSamples", samples)
            log_reward_statistics(vec_env)
            logger.dumpkvs()

            saver.save_state(index=samples // epoch,
                             state=snapshot_state(samples))

    vec_env.close()
    test_env.close()
Example #4
0
def a2c(env,
        policy,
        val_fn=None,
        total_steps=TOTAL_STEPS_DEFAULT,
        steps=20,
        n_envs=16,
        gamma=0.99,
        optimizer=None,
        max_grad_norm=0.5,
        ent_coeff=0.01,
        vf_loss_coeff=0.5,
        log_interval=100,
        **saver_kwargs):
    """
    Advantage Actor-Critic (synchronous, n-step returns).

    env: environment specification passed to proj.common.env_makers.VecEnvMaker
    policy: dict with a "class" entry plus constructor kwargs for the policy;
        a WeightSharingAC class provides both policy and value estimates
    val_fn (optional): same format for a separate critic; must be None when
        policy is a WeightSharingAC, otherwise defaults to
        ValueFunction.from_policy(policy)
    total_steps: total number of environment steps to take
    steps: number of steps to take in each environment per update
    n_envs: number of environment copies to run in parallel
    gamma: discount factor for the n-step returns
    optimizer (optional): dict with an optional "class" entry plus kwargs,
        merged over RMSprop(lr=1e-3, eps=1e-5, alpha=0.99)
    max_grad_norm: gradient-norm clipping threshold
    ent_coeff: weight of the entropy bonus in the total loss
    vf_loss_coeff: weight of the critic loss in the total loss
    log_interval: number of updates between logging rounds
    saver_kwargs: keyword arguments for proj.utils.saver.SnapshotSaver
    """
    assert val_fn is None or not issubclass(
        policy["class"], WeightSharingAC
    ), "Choose between a weight sharing model or separate policy and val_fn"

    optimizer = optimizer or {}
    optimizer = {
        "class": torch.optim.RMSprop,
        "lr": 1e-3,
        "eps": 1e-5,
        "alpha": 0.99,
        **optimizer,
    }
    if val_fn is None and not issubclass(policy["class"], WeightSharingAC):
        val_fn = ValueFunction.from_policy(policy)

    # locals() is the recorded config: no extra locals above this point
    logu.save_config(locals())
    saver = SnapshotSaver(logger.get_dir(), locals(), **saver_kwargs)

    vec_env = VecEnvMaker(env)(n_envs)
    policy = policy.pop("class")(vec_env, **policy)
    param_list = torch.nn.ParameterList(policy.parameters())
    if val_fn is not None:
        val_fn = val_fn.pop("class")(vec_env, **val_fn)
        param_list.extend(val_fn.parameters())
    optimizer = optimizer.pop("class")(param_list.parameters(), **optimizer)
    loss_fn = torch.nn.MSELoss()

    # Algorithm main loop
    if val_fn is None:
        compute_dists_vals = policy
    else:

        def compute_dists_vals(obs):
            return policy(obs), val_fn(obs)

    generator = samples_generator(vec_env, policy, steps, compute_dists_vals)
    logger.info("Starting epoch {}".format(1))
    beg, end, stp = steps * n_envs, total_steps + steps * n_envs, steps * n_envs
    for samples in trange(beg, end, stp, desc="Training", unit="step"):
        all_acts, all_rews, all_dones, all_dists, all_vals, next_vals = next(
            generator)

        # Compute returns and advantages. They are regression/advantage
        # targets, so no gradient may flow through them (in particular not
        # through the bootstrapped next_vals); no_grad works whether or not
        # the generator already detached next_vals.
        with torch.no_grad():
            all_rets = all_rews.clone()
            all_rets[-1] += gamma * (1 - all_dones[-1]) * next_vals
            for i in reversed(range(steps - 1)):
                all_rets[i] += gamma * (1 - all_dones[i]) * all_rets[i + 1]
            all_advs = all_rets - all_vals.detach()

        # Compute loss
        log_li = all_dists.log_prob(all_acts.reshape(stp, -1).squeeze())
        pi_loss = -torch.mean(log_li * all_advs.flatten())
        vf_loss = loss_fn(all_vals.flatten(), all_rets.flatten())
        entropy = all_dists.entropy().mean()
        total_loss = pi_loss - ent_coeff * entropy + vf_loss_coeff * vf_loss

        optimizer.zero_grad()
        total_loss.backward()
        torch.nn.utils.clip_grad_norm_(param_list.parameters(), max_grad_norm)
        optimizer.step()

        updates = samples // stp
        if updates == 1 or updates % log_interval == 0:
            logger.logkv("Epoch", updates // log_interval + 1)
            logger.logkv("TotalNSamples", samples)
            logu.log_reward_statistics(vec_env)
            logu.log_val_fn_statistics(all_vals.flatten(), all_rets.flatten())
            logu.log_action_distribution_statistics(all_dists)
            logger.dumpkvs()
            logger.info("Starting epoch {}".format(updates // log_interval +
                                                   2))

        saver.save_state(
            index=updates,
            state=dict(
                alg=dict(last_updt=updates),
                policy=policy.state_dict(),
                val_fn=None if val_fn is None else val_fn.state_dict(),
                optimizer=optimizer.state_dict(),
            ),
        )

    vec_env.close()
Example #5
0
def td3(env,
        policy,
        q_func=None,
        total_steps=TOTAL_STEPS_DEFAULT,
        gamma=0.99,
        replay_size=REPLAY_SIZE_DEFAULT,
        polyak=0.995,
        start_steps=10000,
        epoch=5000,
        pi_lr=1e-3,
        qf_lr=1e-3,
        mb_size=100,
        act_noise=0.1,
        max_ep_length=1000,
        target_noise=0.2,
        noise_clip=0.5,
        policy_delay=2,
        updates_per_step=1.0,
        **saver_kwargs):
    """
    Twin Delayed DDPG (TD3).

    env: environment specification passed to proj.common.env_makers.VecEnvMaker
    policy: dict with a "class" entry plus constructor kwargs for the policy
    q_func (optional): same format for the twin Q-functions; defaults to
        ContinuousQFunction.from_policy(policy)
    total_steps: total number of environment steps to take
    gamma: discount factor used in the Q-function targets
    replay_size: capacity of the replay buffer
    polyak: coefficient passed to update_polyak for all target networks
    start_steps: number of initial steps that use uniformly random actions
    epoch: number of steps between test runs, logging and snapshots
    pi_lr: policy learning rate (Adam)
    qf_lr: Q-function learning rate (Adam)
    mb_size: minibatch size sampled from the replay buffer
    act_noise: std of the Gaussian exploration noise added to policy actions
    max_ep_length: step count at which done is forced to False (time-limit
        truncation) and an update round is triggered
    target_noise: std of the smoothing noise added to target-policy actions
    noise_clip: absolute clip applied to the target smoothing noise
    policy_delay: number of critic updates per policy/target update
    updates_per_step: gradient updates per environment step, performed as a
        batch at the end of each episode
    saver_kwargs: keyword arguments for proj.utils.saver.SnapshotSaver
    """
    # Set and save experiment hyperparameters (locals() is the recorded
    # config, so no extra locals may be introduced above this point)
    q_func = q_func or ContinuousQFunction.from_policy(policy)
    save_config(locals())
    saver = SnapshotSaver(logger.get_dir(), locals(), **saver_kwargs)

    # Initialize environments, models and replay buffer
    vec_env = VecEnvMaker(env)()
    test_env = VecEnvMaker(env)(train=False)
    ob_space, ac_space = vec_env.observation_space, vec_env.action_space
    pi_class, pi_args = policy.pop("class"), policy
    qf_class, qf_args = q_func.pop("class"), q_func
    policy = pi_class(vec_env, **pi_args)
    q1func = qf_class(vec_env, **qf_args)
    q2func = qf_class(vec_env, **qf_args)
    replay = ReplayBuffer(replay_size, ob_space, ac_space)

    # Initialize optimizers and target networks
    loss_fn = torch.nn.MSELoss()
    pi_optim = torch.optim.Adam(policy.parameters(), lr=pi_lr)
    qf_optim = torch.optim.Adam(chain(q1func.parameters(),
                                      q2func.parameters()),
                                lr=qf_lr)
    pi_targ = pi_class(vec_env, **pi_args)
    q1_targ = qf_class(vec_env, **qf_args)
    q2_targ = qf_class(vec_env, **qf_args)
    pi_targ.load_state_dict(policy.state_dict())
    q1_targ.load_state_dict(q1func.state_dict())
    q2_targ.load_state_dict(q2func.state_dict())

    def snapshot_state(n_samples):
        # Everything needed to resume training after n_samples env steps.
        return dict(
            alg=dict(samples=n_samples),
            policy=policy.state_dict(),
            q1func=q1func.state_dict(),
            q2func=q2func.state_dict(),
            pi_optim=pi_optim.state_dict(),
            qf_optim=qf_optim.state_dict(),
            pi_targ=pi_targ.state_dict(),
            q1_targ=q1_targ.state_dict(),
            q2_targ=q2_targ.state_dict(),
        )

    # Save initial state
    saver.save_state(index=0, state=snapshot_state(0))

    # Setup and run policy tests
    ob, don = test_env.reset(), False

    @torch.no_grad()
    def test_policy():
        # Roll out 10 test episodes, then log their rewards.
        nonlocal ob, don
        for _ in range(10):
            while not don:
                act = policy.actions(torch.from_numpy(ob))
                ob, _, don, _ = test_env.step(act.numpy())
            don = False
        log_reward_statistics(test_env, num_last_eps=10, prefix="Test")

    test_policy()
    logger.logkv("Epoch", 0)
    logger.logkv("TotalNSamples", 0)
    logger.dumpkvs()

    # Set action sampling strategies
    def rand_uniform_actions(_):
        return np.stack([ac_space.sample() for _ in range(vec_env.num_envs)])

    act_low, act_high = map(torch.Tensor, (ac_space.low, ac_space.high))

    @torch.no_grad()
    def noisy_policy_actions(obs):
        acts = policy.actions(torch.from_numpy(obs))
        acts += act_noise * torch.randn_like(acts)
        return np.clip(acts.numpy(), ac_space.low, ac_space.high)

    # Algorithm main loop
    obs1, ep_length, critic_updates = vec_env.reset(), 0, 0
    for samples in trange(1, total_steps + 1, desc="Training", unit="step"):
        if samples <= start_steps:
            actions = rand_uniform_actions
        else:
            actions = noisy_policy_actions

        acts = actions(obs1)
        obs2, rews, dones, _ = vec_env.step(acts)
        ep_length += 1
        # time-limit truncation is not a true terminal: don't mask bootstrap
        dones[0] = False if ep_length == max_ep_length else dones[0]

        as_tensors = map(torch.from_numpy,
                         (obs1, acts, rews, obs2, dones.astype("f")))
        for ob1, act, rew, ob2, done in zip(*as_tensors):
            replay.store(ob1, act, rew, ob2, done)
        obs1 = obs2

        episode_over = dones[0] or ep_length == max_ep_length
        if episode_over and replay.size >= mb_size:
            for _ in range(int(ep_length * updates_per_step)):
                ob_1, act_, rew_, ob_2, done_ = replay.sample(mb_size)
                with torch.no_grad():
                    # target policy smoothing: clipped noise, then clamp to
                    # the action bounds
                    atarg = pi_targ(ob_2)
                    atarg += torch.clamp(
                        target_noise * torch.randn_like(atarg), -noise_clip,
                        noise_clip)
                    atarg = torch.max(torch.min(atarg, act_high), act_low)
                    targs = rew_ + gamma * (1 - done_) * torch.min(
                        q1_targ(ob_2, atarg), q2_targ(ob_2, atarg))

                qf_optim.zero_grad()
                q1_val = q1func(ob_1, act_)
                q2_val = q2func(ob_1, act_)
                q1_loss = loss_fn(q1_val, targs).div(2)
                q2_loss = loss_fn(q2_val, targs).div(2)
                q1_loss.add(q2_loss).backward()
                qf_optim.step()

                critic_updates += 1
                if critic_updates % policy_delay == 0:
                    pi_optim.zero_grad()
                    qpi_val = q1func(ob_1, policy(ob_1))
                    pi_loss = qpi_val.mean().neg()
                    pi_loss.backward()
                    pi_optim.step()

                    update_polyak(policy, pi_targ, polyak)
                    update_polyak(q1func, q1_targ, polyak)
                    update_polyak(q2func, q2_targ, polyak)

                    logger.logkv_mean("QPiVal", qpi_val.mean().item())
                    logger.logkv_mean("PiLoss", pi_loss.item())

                logger.logkv_mean("Q1Val", q1_val.mean().item())
                logger.logkv_mean("Q2Val", q2_val.mean().item())
                logger.logkv_mean("Q1Loss", q1_loss.item())
                logger.logkv_mean("Q2Loss", q2_loss.item())

        # Reset the counter whenever an episode ends, even while the buffer
        # is still too small to train; otherwise the count leaks into the
        # next episode and max_ep_length masking fires at the wrong step.
        if episode_over:
            ep_length = 0

        if samples % epoch == 0:
            test_policy()
            logger.logkv("Epoch", samples // epoch)
            logger.logkv("TotalNSamples", samples)
            log_reward_statistics(vec_env)
            logger.dumpkvs()

            saver.save_state(index=samples // epoch,
                             state=snapshot_state(samples))

    vec_env.close()
    test_env.close()
Example #6
0
def trpo(env,
         policy,
         val_fn=None,
         total_steps=TOTAL_STEPS_DEFAULT,
         steps=125,
         n_envs=16,
         gamma=0.99,
         gaelam=0.97,
         kl_frac=1.0,
         delta=0.01,
         val_iters=80,
         val_lr=1e-3,
         linesearch=True,
         **saver_kwargs):
    """
    Trust Region Policy Optimization

    env: environment specification passed to proj.common.env_makers.VecEnvMaker
    policy: dict with a "class" entry plus constructor kwargs for the policy
    val_fn (optional): same format for the critic; defaults to
        ValueFunction.from_policy(policy)
    total_steps: total number of environment steps to take
    steps: number of steps to take in each environment per iteration
    n_envs: number of environment copies to run in parallel
    gamma: GAE discount parameter
    gaelam: GAE lambda exponential average parameter
    kl_frac: fraction of the batch's observations used for Fisher-vector
        products (1.0 uses all of them)
    delta: trust-region size; sets the natural-gradient step length and is
        the mean-KL barrier used by the line search
    val_iters: number of optimization steps to update the critic per iteration
    val_lr: learning rate for critic optimizer
    linesearch: whether to line-search the natural-gradient step
    saver_kwargs: keyword arguments for proj.utils.saver.SnapshotSaver
    """
    val_fn = val_fn or ValueFunction.from_policy(policy)
    # locals() is the recorded config: no extra locals above this point
    logu.save_config(locals())
    saver = SnapshotSaver(logger.get_dir(), locals(), **saver_kwargs)

    vec_env = VecEnvMaker(env)(n_envs)
    policy = policy.pop("class")(vec_env, **policy)
    val_fn = val_fn.pop("class")(vec_env, **val_fn)
    val_optim = torch.optim.Adam(val_fn.parameters(), lr=val_lr)
    loss_fn = torch.nn.MSELoss()

    # Algorithm main loop
    collector = parallel_samples_collector(vec_env, policy, steps)
    beg, end, stp = steps * n_envs, total_steps + steps * n_envs, steps * n_envs
    for samples in trange(beg, end, stp, desc="Training", unit="step"):
        logger.info("Starting iteration {}".format(samples // stp))
        logger.logkv("Iteration", samples // stp)

        logger.info("Start collecting samples")
        trajs = next(collector)

        logger.info("Computing policy gradient variables")
        compute_pg_vars(trajs, val_fn, gamma, gaelam)
        flatten_trajs(trajs)
        all_obs, all_acts, _, _, all_advs, all_vals, all_rets = trajs.values()
        # drop the trailing bootstrap entries (one per env)
        all_obs, all_vals = all_obs[:-n_envs], all_vals[:-n_envs]

        # subsample for fisher vector product computation
        if kl_frac < 1.0:
            # keep at least one observation so the Fisher product is defined
            n_samples = max(1, int(kl_frac * len(all_obs)))
            indexes = torch.randperm(len(all_obs))[:n_samples]
            subsamp_obs = all_obs.index_select(0, indexes)
        else:
            subsamp_obs = all_obs

        logger.info("Computing policy gradient")
        all_dists = policy(all_obs)
        all_logp = all_dists.log_prob(all_acts)
        old_dists = all_dists.detach()
        old_logp = old_dists.log_prob(all_acts)
        surr_loss = -((all_logp - old_logp).exp() * all_advs).mean()
        pol_grad = flat_grad(surr_loss, policy.parameters())

        logger.info("Computing truncated natural gradient")

        def fvp(v):
            return fisher_vec_prod(v, subsamp_obs, policy)

        descent_direction = conjugate_gradient(fvp, pol_grad)
        scale = torch.sqrt(2 * delta /
                           (pol_grad.dot(descent_direction) + 1e-8))
        descent_step = descent_direction * scale

        if linesearch:
            logger.info("Performing line search")
            expected_improvement = pol_grad.dot(descent_step).item()

            # Evaluation only: no autograd graph is needed for candidates.
            @torch.no_grad()
            def f_barrier(params,
                          all_obs=all_obs,
                          all_acts=all_acts,
                          all_advs=all_advs):
                # Surrogate loss at `params`, or +inf once mean KL >= delta.
                vector_to_parameters(params, policy.parameters())
                new_dists = policy(all_obs)
                new_logp = new_dists.log_prob(all_acts)
                surr_loss = -((new_logp - old_logp).exp() * all_advs).mean()
                avg_kl = kl(old_dists, new_dists).mean().item()
                return surr_loss.item() if avg_kl < delta else float("inf")

            new_params, expected_improvement, improvement = line_search(
                f_barrier,
                parameters_to_vector(policy.parameters()),
                descent_step,
                expected_improvement,
                y0=surr_loss.item(),
            )
            logger.logkv("ExpectedImprovement", expected_improvement)
            logger.logkv("ActualImprovement", improvement)
            logger.logkv("ImprovementRatio",
                         improvement / expected_improvement)
        else:
            new_params = parameters_to_vector(
                policy.parameters()) - descent_step
        vector_to_parameters(new_params, policy.parameters())

        logger.info("Updating val_fn")
        for _ in range(val_iters):
            val_optim.zero_grad()
            loss_fn(val_fn(all_obs), all_rets).backward()
            val_optim.step()

        logger.info("Logging information")
        logger.logkv("TotalNSamples", samples)
        logu.log_reward_statistics(vec_env)
        logu.log_val_fn_statistics(all_vals, all_rets)
        logu.log_action_distribution_statistics(old_dists)
        logu.log_average_kl_divergence(old_dists, policy, all_obs)
        logger.dumpkvs()

        logger.info("Saving snapshot")
        saver.save_state(
            index=samples // stp,
            state=dict(
                alg=dict(last_iter=samples // stp),
                policy=policy.state_dict(),
                val_fn=val_fn.state_dict(),
                val_optim=val_optim.state_dict(),
            ),
        )
        # release this iteration's batch before collecting the next one
        del all_obs, all_acts, all_advs, all_vals, all_rets, trajs

    vec_env.close()
Example #7
0
def natural(env,
            policy,
            val_fn=None,
            total_steps=TOTAL_STEPS_DEFAULT,
            steps=125,
            n_envs=16,
            gamma=0.99,
            gaelam=0.97,
            kl_frac=1.0,
            delta=0.01,
            val_iters=80,
            val_lr=1e-3,
            **saver_kwargs):
    """
    Natural Policy Gradient

    env: environment specification passed to proj.common.env_makers.VecEnvMaker
    policy: dict with a "class" entry (a proj.common.models.Policy subclass,
        presumably) plus its constructor kwargs
    val_fn (optional): same format for the critic; defaults to
        ValueFunction.from_policy(policy)
    total_steps: total number of environment steps to take
    steps: number of steps to take in each environment per iteration
    n_envs: number of environment copies to run in parallel
    gamma: GAE discount parameter
    gaelam: GAE lambda exponential average parameter
    kl_frac: fraction of the batch's observations used for Fisher-vector
        products (1.0 uses all of them)
    delta: step size; the natural-gradient direction d is scaled by
        sqrt(2 * delta / g.d), i.e. presumably the target KL under the
        quadratic approximation (assuming conjugate_gradient solves F d = g)
    val_iters: number of optimization steps to update the critic per iteration
    val_lr: learning rate for critic optimizer
    saver_kwargs: keyword arguments for proj.utils.saver.SnapshotSaver
    """
    val_fn = val_fn or ValueFunction.from_policy(policy)
    # locals() is the recorded config: no extra locals above this point
    logu.save_config(locals())
    saver = SnapshotSaver(logger.get_dir(), locals(), **saver_kwargs)

    vec_env = VecEnvMaker(env)(n_envs)
    policy = policy.pop("class")(vec_env, **policy)
    val_fn = val_fn.pop("class")(vec_env, **val_fn)
    val_optim = torch.optim.Adam(val_fn.parameters(), lr=val_lr)
    loss_fn = torch.nn.MSELoss()

    # Algorithm main loop
    collector = parallel_samples_collector(vec_env, policy, steps)
    beg, end, stp = steps * n_envs, total_steps + steps * n_envs, steps * n_envs
    for samples in trange(beg, end, stp, desc="Training", unit="step"):
        logger.info("Starting iteration {}".format(samples // stp))
        logger.logkv("Iteration", samples // stp)

        logger.info("Start collecting samples")
        trajs = next(collector)

        logger.info("Computing policy gradient variables")
        compute_pg_vars(trajs, val_fn, gamma, gaelam)
        flatten_trajs(trajs)
        all_obs, all_acts, _, _, all_advs, all_vals, all_rets = trajs.values()
        # drop the trailing bootstrap entries (one per env)
        all_obs, all_vals = all_obs[:-n_envs], all_vals[:-n_envs]

        # subsample for fisher vector product computation
        if kl_frac < 1.0:
            # keep at least one observation so the Fisher product is defined
            n_samples = max(1, int(kl_frac * len(all_obs)))
            indexes = torch.randperm(len(all_obs))[:n_samples]
            subsamp_obs = all_obs.index_select(0, indexes)
        else:
            subsamp_obs = all_obs

        logger.info("Computing policy gradient")
        all_dists = policy(all_obs)
        old_dists = all_dists.detach()
        pol_loss = torch.mean(all_dists.log_prob(all_acts) * all_advs).neg()
        pol_grad = flat_grad(pol_loss, policy.parameters())

        logger.info("Computing truncated natural gradient")

        def fvp(v):
            return fisher_vec_prod(v, subsamp_obs, policy)

        descent_direction = conjugate_gradient(fvp, pol_grad)
        scale = torch.sqrt(2 * delta /
                           (pol_grad.dot(descent_direction) + 1e-8))
        descent_step = descent_direction * scale
        new_params = parameters_to_vector(policy.parameters()) - descent_step
        vector_to_parameters(new_params, policy.parameters())

        logger.info("Updating val_fn")
        for _ in range(val_iters):
            val_optim.zero_grad()
            loss_fn(val_fn(all_obs), all_rets).backward()
            val_optim.step()

        logger.info("Logging information")
        logger.logkv("TotalNSamples", samples)
        logu.log_reward_statistics(vec_env)
        logu.log_val_fn_statistics(all_vals, all_rets)
        logu.log_action_distribution_statistics(old_dists)
        logu.log_average_kl_divergence(old_dists, policy, all_obs)
        logger.dumpkvs()

        logger.info("Saving snapshot")
        saver.save_state(
            index=samples // stp,
            state=dict(
                alg=dict(last_iter=samples // stp),
                policy=policy.state_dict(),
                val_fn=val_fn.state_dict(),
                val_optim=val_optim.state_dict(),
            ),
        )
        # release this iteration's batch before collecting the next one
        del all_obs, all_acts, all_advs, all_vals, all_rets, trajs

    vec_env.close()
Example #8
0
def a2c_kfac(env,
             policy,
             val_fn=None,
             total_steps=TOTAL_STEPS_DEFAULT,
             steps=20,
             n_envs=16,
             kfac=None,
             ent_coeff=0.01,
             vf_loss_coeff=0.5,
             gamma=0.99,
             log_interval=100,
             warm_start=None,
             **saver_kwargs):
    """Synchronous advantage actor-critic whose parameters are updated by a
    ``KFACOptimizer``.

    env: environment specification, passed to ``VecEnvMaker(env)`` to build
        the vectorized environment
    policy: dict whose ``"class"`` entry is the policy model class (possibly a
        ``WeightSharingAC`` that also returns state values) and whose remaining
        entries are keyword arguments for its constructor
    val_fn (optional): dict with a ``"class"`` entry plus constructor kwargs
        for a separate value function. Must be None when the policy is a
        ``WeightSharingAC``; otherwise defaults to
        ``ValueFunction.from_policy(policy)``
    total_steps: total number of environment steps to take
    steps: number of steps to take in each environment per update
    n_envs: number of environment copies to run in parallel
    kfac (optional): keyword arguments overriding the default
        ``KFACOptimizer`` settings
    ent_coeff: weight of the entropy bonus in the total loss
    vf_loss_coeff: weight of the value-function loss in the total loss
    gamma: discount factor for the bootstrapped n-step returns
    log_interval: number of updates between logging dumps
    warm_start (optional): snapshot directory to resume from, optionally
        suffixed with ``":<index>"`` to select a specific snapshot
    saver_kwargs: keyword arguments for ``SnapshotSaver``
    """
    assert val_fn is None or not issubclass(
        policy["class"], WeightSharingAC
    ), "Choose between a weight sharing model or separate policy and val_fn"

    # handle default values (no extra locals: locals() is saved as config below)
    kfac = {
        "eps": 1e-3,
        "pi": True,
        "alpha": 0.95,
        "kl_clip": 1e-3,
        "eta": 1.0,
        **(kfac or {}),
    }
    if val_fn is None and not issubclass(policy["class"], WeightSharingAC):
        val_fn = ValueFunction.from_policy(policy)

    # save config and setup state saving
    logu.save_config(locals())
    saver = SnapshotSaver(logger.get_dir(), locals(), **saver_kwargs)

    # initialize models and optimizer. The spec dicts are copied before
    # popping "class" so neither the caller's dicts nor the dicts referenced
    # by the config captured above are mutated.
    vec_env = VecEnvMaker(env)(n_envs)
    pol_kwargs = dict(policy)
    policy = pol_kwargs.pop("class")(vec_env, **pol_kwargs)
    module_list = torch.nn.ModuleList(policy.modules())
    if val_fn is not None:
        vf_kwargs = dict(val_fn)
        val_fn = vf_kwargs.pop("class")(vec_env, **vf_kwargs)
        module_list.extend(val_fn.modules())
    optimizer = KFACOptimizer(module_list, **kfac)
    loss_fn = torch.nn.MSELoss()

    # load state if provided
    updates = 0
    if warm_start is not None:
        if ":" in warm_start:
            warm_start, index = warm_start.split(":")
            index = int(index)
        else:
            # NOTE(review): assumes SnapshotSaver.get_state(None) returns the
            # latest snapshot — confirm against SnapshotSaver
            index = None
        _, state = SnapshotSaver(warm_start,
                                 latest_only=False).get_state(index)
        policy.load_state_dict(state["policy"])
        # the separate value function is saved in every snapshot (see below),
        # so restore it as well
        if val_fn is not None and state.get("val_fn") is not None:
            val_fn.load_state_dict(state["val_fn"])
        if "optimizer" in state:
            optimizer.load_state_dict(state["optimizer"])
        updates = state["alg"]["last_updt"]

    # Algorithm main loop
    if val_fn is None:
        # weight-sharing model: calling it yields (dists, vals)
        compute_dists_vals = policy
    else:

        def compute_dists_vals(obs):
            return policy(obs), val_fn(obs)

    ob_space, ac_space = vec_env.observation_space, vec_env.action_space
    obs = torch.from_numpy(vec_env.reset())
    with torch.no_grad():
        acts = policy.actions(obs)
    logger.info("Starting epoch {}".format(1))
    beg, end, stp = steps * n_envs, total_steps + steps * n_envs, steps * n_envs
    for samples in trange(beg, end, stp, desc="Training", unit="step"):
        all_obs = torch.empty((steps, n_envs) + ob_space.shape,
                              dtype=_NP_TO_PT[ob_space.dtype.type])
        all_acts = torch.empty((steps, n_envs) + ac_space.shape,
                               dtype=_NP_TO_PT[ac_space.dtype.type])
        all_rews = torch.empty((steps, n_envs))
        all_dones = torch.empty((steps, n_envs))

        with torch.no_grad():
            for i in range(steps):
                next_obs, rews, dones, _ = vec_env.step(acts.numpy())
                all_obs[i] = obs
                all_acts[i] = acts
                all_rews[i] = torch.from_numpy(rews)
                all_dones[i] = torch.from_numpy(dones.astype("f"))
                obs = torch.from_numpy(next_obs)

                acts = policy.actions(obs)

        # Merge the leading (steps, n_envs) dims, keeping each sample's own
        # shape. The models are fed (n_envs,) + ob_space.shape during the
        # rollout, so multi-dimensional observations must not be flattened and
        # size-1 dims (e.g. a 1-D action space) must not be squeezed away.
        all_obs = all_obs.reshape((stp,) + ob_space.shape)
        all_acts = all_acts.reshape((stp,) + ac_space.shape)

        # Sample Fisher curvature matrix
        with optimizer.record_stats():
            optimizer.zero_grad()
            all_dists, all_vals = compute_dists_vals(all_obs)
            logp = all_dists.log_prob(all_acts)
            noise = all_vals.detach() + 0.5 * torch.randn_like(all_vals)
            # keep the graph: logp / all_vals are reused by the actual loss
            (logp.mean() +
             loss_fn(all_vals, noise)).backward(retain_graph=True)

        # Compute returns (bootstrapped from the value of the last observation)
        # and advantages
        with torch.no_grad():
            _, next_vals = compute_dists_vals(obs)
        all_rets = all_rews.clone()
        all_rets[-1] += gamma * (1 - all_dones[-1]) * next_vals
        for i in reversed(range(steps - 1)):
            all_rets[i] += gamma * (1 - all_dones[i]) * all_rets[i + 1]
        # step-major flattening, matching the reshape of all_obs above
        all_rets = all_rets.flatten()
        all_advs = all_rets - all_vals.detach()

        # Compute loss
        updates += 1
        pi_loss = -torch.mean(logp * all_advs)
        vf_loss = loss_fn(all_vals, all_rets)
        entropy = all_dists.entropy().mean()
        total_loss = pi_loss - ent_coeff * entropy + vf_loss_coeff * vf_loss

        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()

        if updates == 1 or updates % log_interval == 0:
            logger.logkv("Epoch", updates // log_interval + 1)
            logger.logkv("TotalNSamples", samples)
            logu.log_reward_statistics(vec_env)
            logu.log_val_fn_statistics(all_vals, all_rets)
            logu.log_action_distribution_statistics(all_dists)
            logger.dumpkvs()
            logger.info("Starting epoch {}".format(updates // log_interval +
                                                   2))

        saver.save_state(
            index=updates,
            state=dict(
                alg=dict(last_updt=updates),
                policy=policy.state_dict(),
                val_fn=None if val_fn is None else val_fn.state_dict(),
                optimizer=optimizer.state_dict(),
            ),
        )

    vec_env.close()
Example #9
0
def ppo(env,
        policy,
        val_fn=None,
        total_steps=TOTAL_STEPS_DEFAULT,
        steps=125,
        n_envs=16,
        gamma=0.99,
        gaelam=0.96,
        clip_ratio=0.2,
        pol_iters=80,
        val_iters=80,
        pol_lr=3e-4,
        val_lr=1e-3,
        target_kl=0.01,
        mb_size=100,
        **saver_kwargs):
    """Proximal Policy Optimization with a clipped surrogate objective.

    env: environment specification, passed to ``VecEnvMaker(env)`` to build
        the vectorized environment
    policy: dict whose ``"class"`` entry is the policy model class and whose
        remaining entries are keyword arguments for its constructor
    val_fn (optional): dict with a ``"class"`` entry plus constructor kwargs;
        defaults to ``ValueFunction.from_policy(policy)``
    total_steps: total number of environment steps to take
    steps: number of steps to take in each environment per iteration
    n_envs: number of environment copies to run in parallel
    gamma: GAE discount parameter
    gaelam: GAE lambda exponential average parameter
    clip_ratio: probability-ratio clipping range of the surrogate objective
    pol_iters: maximum number of passes over the batch for the policy
    val_iters: number of optimization steps to update the critic per iteration
    pol_lr: learning rate for the policy Adam optimizer
    val_lr: learning rate for the critic Adam optimizer
    target_kl: policy updates stop early once the mean KL divergence from the
        pre-update policy exceeds ``1.5 * target_kl``
    mb_size: minibatch size for the policy updates
    saver_kwargs: keyword arguments for ``SnapshotSaver``
    """
    val_fn = val_fn or ValueFunction.from_policy(policy)

    logu.save_config(locals())
    saver = SnapshotSaver(logger.get_dir(), locals(), **saver_kwargs)

    vec_env = VecEnvMaker(env)(n_envs)
    # copy the spec dicts before popping "class" so the caller's dicts (and the
    # ones referenced by the config captured above) are not mutated
    pol_kwargs, vf_kwargs = dict(policy), dict(val_fn)
    policy = pol_kwargs.pop("class")(vec_env, **pol_kwargs)
    val_fn = vf_kwargs.pop("class")(vec_env, **vf_kwargs)
    pol_optim = torch.optim.Adam(policy.parameters(), lr=pol_lr)
    val_optim = torch.optim.Adam(val_fn.parameters(), lr=val_lr)
    loss_fn = torch.nn.MSELoss()

    # Algorithm main loop
    collector = parallel_samples_collector(vec_env, policy, steps)
    beg, end, stp = steps * n_envs, total_steps + steps * n_envs, steps * n_envs
    for samples in trange(beg, end, stp, desc="Training", unit="step"):
        logger.info("Starting iteration {}".format(samples // stp))
        logger.logkv("Iteration", samples // stp)

        logger.info("Start collecting samples")
        trajs = next(collector)

        logger.info("Computing policy gradient variables")
        compute_pg_vars(trajs, val_fn, gamma, gaelam)
        flatten_trajs(trajs)
        all_obs, all_acts, _, _, all_advs, all_vals, all_rets = trajs.values()
        all_obs, all_vals = all_obs[:-n_envs], all_vals[:-n_envs]

        logger.info("Minimizing surrogate loss")
        with torch.no_grad():
            old_dists = policy(all_obs)
        old_logp = old_dists.log_prob(all_acts)
        # clipped advantage term: (1 + eps) * A if A > 0 else (1 - eps) * A
        min_advs = torch.where(all_advs > 0, (1 + clip_ratio) * all_advs,
                               (1 - clip_ratio) * all_advs)
        dataset = TensorDataset(all_obs, all_acts, all_advs, min_advs,
                                old_logp)
        dataloader = DataLoader(dataset, batch_size=mb_size, shuffle=True)
        # defaults so logging still works when no policy pass runs
        # (pol_iters <= 0); the policy is then unchanged, hence zero KL
        n_pol_iters, mean_kl = 0, 0.0
        for itr in range(pol_iters):
            for obs, acts, advs, min_adv, logp in dataloader:
                ratios = (policy(obs).log_prob(acts) - logp).exp()
                pol_optim.zero_grad()
                (-torch.min(ratios * advs, min_adv)).mean().backward()
                pol_optim.step()

            n_pol_iters = itr + 1
            with torch.no_grad():
                mean_kl = kl(old_dists, policy(all_obs)).mean().item()
            if mean_kl > 1.5 * target_kl:
                logger.info(
                    "Stopped at step {} due to reaching max kl".format(
                        n_pol_iters))
                break
        logger.logkv("StopIter", n_pol_iters)

        logger.info("Updating val_fn")
        for _ in range(val_iters):
            val_optim.zero_grad()
            loss_fn(val_fn(all_obs), all_rets).backward()
            val_optim.step()

        logger.info("Logging information")
        logger.logkv("TotalNSamples", samples)
        logu.log_reward_statistics(vec_env)
        logu.log_val_fn_statistics(all_vals, all_rets)
        logu.log_action_distribution_statistics(old_dists)
        logger.logkv("MeanKL", mean_kl)
        logger.dumpkvs()

        logger.info("Saving snapshot")
        saver.save_state(
            index=samples // stp,
            state=dict(
                alg=dict(last_iter=samples // stp),
                policy=policy.state_dict(),
                val_fn=val_fn.state_dict(),
                pol_optim=pol_optim.state_dict(),
                val_optim=val_optim.state_dict(),
            ),
        )

    vec_env.close()
Example #10
0
def vanilla(
    env,
    policy,
    val_fn=None,
    total_steps=TOTAL_STEPS_DEFAULT,
    steps=125,
    n_envs=16,
    gamma=0.99,
    gaelam=0.97,
    optimizer=None,
    val_iters=80,
    val_lr=1e-3,
    **saver_kwargs
):
    """
    Vanilla Policy Gradient

    env: environment specification, passed to
        ``proj.common.env_makers.VecEnvMaker(env)`` to build the vectorized
        environment
    policy: dict whose ``"class"`` entry is the policy model class (e.g. from
        proj.common.models) and whose remaining entries are keyword arguments
        for its constructor
    val_fn (optional): dict with a ``"class"`` entry plus constructor kwargs
        for the critic; defaults to ``ValueFunction.from_policy(policy)``
    total_steps: total number of environment steps to take
    steps: number of steps to take in each environment per iteration
    n_envs: number of environment copies to run in parallel
    gamma: GAE discount parameter
    gaelam: GAE lambda exponential average parameter
    optimizer (optional): dictionary of policy optimizer kwargs, optionally
        with a ``"class"`` entry (defaults to ``torch.optim.Adam``)
    val_iters: number of optimization steps to update the critic per iteration
    val_lr: learning rate for critic optimizer
    saver_kwargs: keyword arguments for proj.utils.saver.SnapshotSaver
    """
    # no extra locals before save_config: locals() is saved as the run config
    optimizer = {"class": torch.optim.Adam, **(optimizer or {})}
    val_fn = val_fn or ValueFunction.from_policy(policy)

    logu.save_config(locals())
    saver = SnapshotSaver(logger.get_dir(), locals(), **saver_kwargs)

    vec_env = VecEnvMaker(env)(n_envs)
    # copy the spec dicts before popping "class" so the caller's dicts (and the
    # ones referenced by the config captured above) are not mutated
    pol_kwargs, vf_kwargs, opt_kwargs = dict(policy), dict(val_fn), dict(optimizer)
    policy = pol_kwargs.pop("class")(vec_env, **pol_kwargs)
    val_fn = vf_kwargs.pop("class")(vec_env, **vf_kwargs)
    pol_optim = opt_kwargs.pop("class")(policy.parameters(), **opt_kwargs)
    val_optim = torch.optim.Adam(val_fn.parameters(), lr=val_lr)
    loss_fn = torch.nn.MSELoss()

    # Algorithm main loop
    collector = parallel_samples_collector(vec_env, policy, steps)
    beg, end, stp = steps * n_envs, total_steps + steps * n_envs, steps * n_envs
    for samples in trange(beg, end, stp, desc="Training", unit="step"):
        logger.info("Starting iteration {}".format(samples // stp))
        logger.logkv("Iteration", samples // stp)

        logger.info("Start collecting samples")
        trajs = next(collector)

        logger.info("Computing policy gradient variables")
        compute_pg_vars(trajs, val_fn, gamma, gaelam)
        flatten_trajs(trajs)
        all_obs, all_acts, _, _, all_advs, all_vals, all_rets = trajs.values()
        all_obs, all_vals = all_obs[:-n_envs], all_vals[:-n_envs]

        logger.info("Applying policy gradient")
        all_dists = policy(all_obs)
        old_dists = all_dists.detach()
        # single gradient ascent step on E[log pi(a|s) * A]
        objective = torch.mean(all_dists.log_prob(all_acts) * all_advs)
        pol_optim.zero_grad()
        objective.neg().backward()
        pol_optim.step()

        logger.info("Updating val_fn")
        for _ in range(val_iters):
            val_optim.zero_grad()
            loss_fn(val_fn(all_obs), all_rets).backward()
            val_optim.step()

        logger.info("Logging information")
        logger.logkv("Objective", objective.item())
        logger.logkv("TotalNSamples", samples)
        logu.log_reward_statistics(vec_env)
        logu.log_val_fn_statistics(all_vals, all_rets)
        logu.log_action_distribution_statistics(old_dists)
        logu.log_average_kl_divergence(old_dists, policy, all_obs)
        logger.dumpkvs()

        logger.info("Saving snapshot")
        saver.save_state(
            samples // stp,
            dict(
                alg=dict(last_iter=samples // stp),
                policy=policy.state_dict(),
                val_fn=val_fn.state_dict(),
                pol_optim=pol_optim.state_dict(),
                val_optim=val_optim.state_dict(),
            ),
        )
        del all_obs, all_acts, all_advs, all_vals, all_rets, trajs

    vec_env.close()