Example #1
def a2c(env,
        policy,
        val_fn=None,
        total_steps=TOTAL_STEPS_DEFAULT,
        steps=20,
        n_envs=16,
        gamma=0.99,
        optimizer=None,
        max_grad_norm=0.5,
        ent_coeff=0.01,
        vf_loss_coeff=0.5,
        log_interval=100,
        **saver_kwargs):
    assert val_fn is None or not issubclass(
        policy["class"], WeightSharingAC
    ), "Choose between a weight sharing model or separate policy and val_fn"

    optimizer = optimizer or {}
    optimizer = {
        "class": torch.optim.RMSprop,
        "lr": 1e-3,
        "eps": 1e-5,
        "alpha": 0.99,
        **optimizer,
    }
    if val_fn is None and not issubclass(policy["class"], WeightSharingAC):
        val_fn = ValueFunction.from_policy(policy)

    logu.save_config(locals())
    saver = SnapshotSaver(logger.get_dir(), locals(), **saver_kwargs)

    vec_env = VecEnvMaker(env)(n_envs)
    policy = policy.pop("class")(vec_env, **policy)
    param_list = torch.nn.ParameterList(policy.parameters())
    if val_fn is not None:
        val_fn = val_fn.pop("class")(vec_env, **val_fn)
        param_list.extend(val_fn.parameters())
    optimizer = optimizer.pop("class")(param_list.parameters(), **optimizer)
    loss_fn = torch.nn.MSELoss()

    # Algorithm main loop
    if val_fn is None:
        compute_dists_vals = policy
    else:

        def compute_dists_vals(obs):
            return policy(obs), val_fn(obs)

    generator = samples_generator(vec_env, policy, steps, compute_dists_vals)
    logger.info("Starting epoch {}".format(1))
    beg, end, stp = steps * n_envs, total_steps + steps * n_envs, steps * n_envs
    for samples in trange(beg, end, stp, desc="Training", unit="step"):
        all_acts, all_rews, all_dones, all_dists, all_vals, next_vals = next(
            generator)

        # Compute returns and advantages
        all_rets = all_rews.clone()
        all_rets[-1] += gamma * (1 - all_dones[-1]) * next_vals
        for i in reversed(range(steps - 1)):
            all_rets[i] += gamma * (1 - all_dones[i]) * all_rets[i + 1]
        all_advs = all_rets - all_vals.detach()

        # Compute loss
        log_li = all_dists.log_prob(all_acts.reshape(stp, -1).squeeze())
        pi_loss = -torch.mean(log_li * all_advs.flatten())
        vf_loss = loss_fn(all_vals.flatten(), all_rets.flatten())
        entropy = all_dists.entropy().mean()
        total_loss = pi_loss - ent_coeff * entropy + vf_loss_coeff * vf_loss

        optimizer.zero_grad()
        total_loss.backward()
        torch.nn.utils.clip_grad_norm_(param_list.parameters(), max_grad_norm)
        optimizer.step()

        updates = samples // stp
        if updates == 1 or updates % log_interval == 0:
            logger.logkv("Epoch", updates // log_interval + 1)
            logger.logkv("TotalNSamples", samples)
            logu.log_reward_statistics(vec_env)
            logu.log_val_fn_statistics(all_vals.flatten(), all_rets.flatten())
            logu.log_action_distribution_statistics(all_dists)
            logger.dumpkvs()
            logger.info("Starting epoch {}".format(updates // log_interval +
                                                   2))

        saver.save_state(
            index=updates,
            state=dict(
                alg=dict(last_updt=updates),
                policy=policy.state_dict(),
                val_fn=None if val_fn is None else val_fn.state_dict(),
                optimizer=optimizer.state_dict(),
            ),
        )

    vec_env.close()
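The return computation in the loop above is a plain n-step bootstrapped backup. Below is a self-contained toy sketch of that same recursion with invented reward, done and bootstrap-value tensors in the (steps, n_envs) layout; it is an illustration, not part of the example.

import torch

# Toy rollout: 4 steps, 2 parallel envs, reward 1 everywhere, no episode ends.
steps, n_envs, gamma = 4, 2, 0.99
rews = torch.ones(steps, n_envs)
dones = torch.zeros(steps, n_envs)
next_vals = torch.full((n_envs,), 10.0)  # critic's value of the final observation

# Same backward recursion as in the a2c loop above.
rets = rews.clone()
rets[-1] += gamma * (1 - dones[-1]) * next_vals
for i in reversed(range(steps - 1)):
    rets[i] += gamma * (1 - dones[i]) * rets[i + 1]

# Each entry of rets[0] equals 1 + 0.99 + 0.99**2 + 0.99**3 + 0.99**4 * 10.
print(rets[0])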
Example #2
def acktr(env,
          policy,
          val_fn=None,
          total_steps=TOTAL_STEPS_DEFAULT,
          steps=125,
          n_envs=16,
          gamma=0.99,
          gaelam=0.96,
          val_iters=20,
          pikfac=None,
          vfkfac=None,
          warm_start=None,
          linesearch=True,
          **saver_kwargs):
    # handling default values
    pikfac = pikfac or {}
    vfkfac = vfkfac or {}
    val_fn = val_fn or ValueFunction.from_policy(policy)

    # save config and setup state saving
    logu.save_config(locals())
    saver = SnapshotSaver(logger.get_dir(), locals(), **saver_kwargs)

    # initialize models and optimizer
    vec_env = VecEnvMaker(env)(n_envs)
    policy = policy.pop("class")(vec_env, **policy)
    val_fn = val_fn.pop("class")(vec_env, **val_fn)
    pol_optim = KFACOptimizer(policy, **{**DEFAULT_PIKFAC, **pikfac})
    val_optim = KFACOptimizer(val_fn, **{**DEFAULT_VFKFAC, **vfkfac})
    loss_fn = torch.nn.MSELoss()

    # load state if provided
    if warm_start is not None:
        if ":" in warm_start:
            warm_start, index = warm_start.split(":")
        else:
            index = None
        _, state = SnapshotSaver(warm_start,
                                 latest_only=False).get_state(int(index))
        policy.load_state_dict(state["policy"])
        val_fn.load_state_dict(state["val_fn"])
        if "pol_optim" in state:
            pol_optim.load_state_dict(state["pol_optim"])
        if "val_optim" in state:
            val_optim.load_state_dict(state["val_optim"])

    # Algorithm main loop
    collector = parallel_samples_collector(vec_env, policy, steps)
    beg, end, stp = steps * n_envs, total_steps + steps * n_envs, steps * n_envs
    for samples in trange(beg, end, stp, desc="Training", unit="step"):
        logger.info("Starting iteration {}".format(samples // stp))
        logger.logkv("Iteration", samples // stp)

        logger.info("Start collecting samples")
        trajs = next(collector)

        logger.info("Computing policy gradient variables")
        compute_pg_vars(trajs, val_fn, gamma, gaelam)
        flatten_trajs(trajs)
        all_obs, all_acts, _, _, all_advs, all_vals, all_rets = trajs.values()
        all_obs, all_vals = all_obs[:-n_envs], all_vals[:-n_envs]

        logger.info("Computing natural gradient using KFAC")
        with pol_optim.record_stats():
            policy.zero_grad()
            all_dists = policy(all_obs)
            all_logp = all_dists.log_prob(all_acts)
            all_logp.mean().backward(retain_graph=True)

        policy.zero_grad()
        old_dists, old_logp = all_dists.detach(), all_logp.detach()
        surr_loss = -((all_logp - old_logp).exp() * all_advs).mean()
        surr_loss.backward()
        pol_grad = [p.grad.clone() for p in policy.parameters()]
        pol_optim.step()

        if linesearch:
            logger.info("Performing line search")
            kl_clip = pol_optim.state["kl_clip"]
            expected_improvement = sum(
                (g * p.grad.data).sum()
                for g, p in zip(pol_grad, policy.parameters())).item()

            def f_barrier(scale):
                for p in policy.parameters():
                    p.data.add_(p.grad.data, alpha=scale)
                new_dists = policy(all_obs)
                for p in policy.parameters():
                    p.data.sub_(p.grad.data, alpha=scale)
                new_logp = new_dists.log_prob(all_acts)
                surr_loss = -((new_logp - old_logp).exp() * all_advs).mean()
                avg_kl = kl(old_dists, new_dists).mean().item()
                return surr_loss.item() if avg_kl < kl_clip else float("inf")

            scale, expected_improvement, improvement = line_search(
                f_barrier,
                x0=1,
                dx=1,
                expected_improvement=expected_improvement,
                y0=surr_loss.item(),
            )
            logger.logkv("ExpectedImprovement", expected_improvement)
            logger.logkv("ActualImprovement", improvement)
            logger.logkv("ImprovementRatio",
                         improvement / expected_improvement)
            for p in policy.parameters():
                p.data.add_(p.grad.data, alpha=scale)

        logger.info("Updating val_fn")
        for _ in range(val_iters):
            with val_optim.record_stats():
                val_fn.zero_grad()
                values = val_fn(all_obs)
                noise = values.detach() + 0.5 * torch.randn_like(values)
                loss_fn(values, noise).backward(retain_graph=True)

            val_fn.zero_grad()
            val_loss = loss_fn(values, all_rets)
            val_loss.backward()
            val_optim.step()

        logger.info("Logging information")
        logger.logkv("TotalNSamples", samples)
        logu.log_reward_statistics(vec_env)
        logu.log_val_fn_statistics(all_vals, all_rets)
        logu.log_action_distribution_statistics(old_dists)
        logu.log_average_kl_divergence(old_dists, policy, all_obs)
        logger.dumpkvs()

        logger.info("Saving snapshot")
        saver.save_state(
            index=samples // stp,
            state=dict(
                alg=dict(last_iter=samples // stp),
                policy=policy.state_dict(),
                val_fn=val_fn.state_dict(),
                pol_optim=pol_optim.state_dict(),
                val_optim=val_optim.state_dict(),
            ),
        )

    vec_env.close()
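The surrogate loss used above, -((all_logp - old_logp).exp() * all_advs).mean(), has the same gradient at the old parameters as the plain policy-gradient loss; KFAC and the line search only change how that gradient is turned into a step. A minimal, self-contained check of this property on a toy categorical policy (nothing below comes from the project code):

import torch

# Three toy "states" sharing the same logits, three actions.
logits = torch.zeros(3, requires_grad=True)
acts = torch.tensor([0, 1, 2])
advs = torch.tensor([1.0, -0.5, 2.0])

dist = torch.distributions.Categorical(logits=logits.expand(3, 3))
logp = dist.log_prob(acts)
old_logp = logp.detach()

# Importance-ratio surrogate, as in the acktr example.
surr_loss = -((logp - old_logp).exp() * advs).mean()
grad_surr = torch.autograd.grad(surr_loss, logits, retain_graph=True)[0]

# Plain policy-gradient loss.
pg_loss = -(logp * advs).mean()
grad_pg = torch.autograd.grad(pg_loss, logits)[0]

print(torch.allclose(grad_surr, grad_pg))  # True: gradients coincide at the old parameters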
Example #3
def trpo(env,
         policy,
         val_fn=None,
         total_steps=TOTAL_STEPS_DEFAULT,
         steps=125,
         n_envs=16,
         gamma=0.99,
         gaelam=0.97,
         kl_frac=1.0,
         delta=0.01,
         val_iters=80,
         val_lr=1e-3,
         linesearch=True,
         **saver_kwargs):
    val_fn = val_fn or ValueFunction.from_policy(policy)
    logu.save_config(locals())
    saver = SnapshotSaver(logger.get_dir(), locals(), **saver_kwargs)

    vec_env = VecEnvMaker(env)(n_envs)
    policy = policy.pop("class")(vec_env, **policy)
    val_fn = val_fn.pop("class")(vec_env, **val_fn)
    val_optim = torch.optim.Adam(val_fn.parameters(), lr=val_lr)
    loss_fn = torch.nn.MSELoss()

    # Algorithm main loop
    collector = parallel_samples_collector(vec_env, policy, steps)
    beg, end, stp = steps * n_envs, total_steps + steps * n_envs, steps * n_envs
    for samples in trange(beg, end, stp, desc="Training", unit="step"):
        logger.info("Starting iteration {}".format(samples // stp))
        logger.logkv("Iteration", samples // stp)

        logger.info("Start collecting samples")
        trajs = next(collector)

        logger.info("Computing policy gradient variables")
        compute_pg_vars(trajs, val_fn, gamma, gaelam)
        flatten_trajs(trajs)
        all_obs, all_acts, _, _, all_advs, all_vals, all_rets = trajs.values()
        all_obs, all_vals = all_obs[:-n_envs], all_vals[:-n_envs]

        # subsample for fisher vector product computation
        if kl_frac < 1.0:
            n_samples = int(kl_frac * len(all_obs))
            indexes = torch.randperm(len(all_obs))[:n_samples]
            subsamp_obs = all_obs.index_select(0, indexes)
        else:
            subsamp_obs = all_obs

        logger.info("Computing policy gradient")
        all_dists = policy(all_obs)
        all_logp = all_dists.log_prob(all_acts)
        old_dists = all_dists.detach()
        old_logp = old_dists.log_prob(all_acts)
        surr_loss = -((all_logp - old_logp).exp() * all_advs).mean()
        pol_grad = flat_grad(surr_loss, policy.parameters())

        logger.info("Computing truncated natural gradient")

        def fvp(v):
            return fisher_vec_prod(v, subsamp_obs, policy)

        descent_direction = conjugate_gradient(fvp, pol_grad)
        scale = torch.sqrt(2 * delta /
                           (pol_grad.dot(descent_direction) + 1e-8))
        descent_step = descent_direction * scale

        if linesearch:
            logger.info("Performing line search")
            expected_improvement = pol_grad.dot(descent_step).item()

            def f_barrier(params,
                          all_obs=all_obs,
                          all_acts=all_acts,
                          all_advs=all_advs):
                vector_to_parameters(params, policy.parameters())
                new_dists = policy(all_obs)
                new_logp = new_dists.log_prob(all_acts)
                surr_loss = -((new_logp - old_logp).exp() * all_advs).mean()
                avg_kl = kl(old_dists, new_dists).mean().item()
                return surr_loss.item() if avg_kl < delta else float("inf")

            new_params, expected_improvement, improvement = line_search(
                f_barrier,
                parameters_to_vector(policy.parameters()),
                descent_step,
                expected_improvement,
                y0=surr_loss.item(),
            )
            logger.logkv("ExpectedImprovement", expected_improvement)
            logger.logkv("ActualImprovement", improvement)
            logger.logkv("ImprovementRatio",
                         improvement / expected_improvement)
        else:
            new_params = parameters_to_vector(
                policy.parameters()) - descent_step
        vector_to_parameters(new_params, policy.parameters())

        logger.info("Updating val_fn")
        for _ in range(val_iters):
            val_optim.zero_grad()
            loss_fn(val_fn(all_obs), all_rets).backward()
            val_optim.step()

        logger.info("Logging information")
        logger.logkv("TotalNSamples", samples)
        logu.log_reward_statistics(vec_env)
        logu.log_val_fn_statistics(all_vals, all_rets)
        logu.log_action_distribution_statistics(old_dists)
        logu.log_average_kl_divergence(old_dists, policy, all_obs)
        logger.dumpkvs()

        logger.info("Saving snapshot")
        saver.save_state(
            index=samples // stp,
            state=dict(
                alg=dict(last_iter=samples // stp),
                policy=policy.state_dict(),
                val_fn=val_fn.state_dict(),
                val_optim=val_optim.state_dict(),
            ),
        )
        del all_obs, all_acts, all_advs, all_vals, all_rets, trajs

    vec_env.close()
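conjugate_gradient(fvp, pol_grad) above approximately solves F x = g using only Fisher-vector products, never the Fisher matrix itself. The sketch below is my own minimal version of that idea, exercised on a small symmetric positive-definite matrix instead of fvp; it is not the project's conjugate_gradient.

import torch

def conjugate_gradient_sketch(avp, b, iters=10, tol=1e-10):
    """Solve A x = b given only a function avp(v) = A @ v."""
    x = torch.zeros_like(b)
    r = b.clone()  # residual b - A @ x, with x = 0 initially
    p = b.clone()
    rdotr = r.dot(r)
    for _ in range(iters):
        z = avp(p)
        alpha = rdotr / p.dot(z)
        x = x + alpha * p
        r = r - alpha * z
        new_rdotr = r.dot(r)
        if new_rdotr < tol:
            break
        p = r + (new_rdotr / rdotr) * p
        rdotr = new_rdotr
    return x

A = torch.tensor([[4.0, 1.0], [1.0, 3.0]])
b = torch.tensor([1.0, 2.0])
x = conjugate_gradient_sketch(lambda v: A @ v, b)
print(x, A @ x)  # A @ x is (numerically) equal to b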
Example #4
def natural(env,
            policy,
            val_fn=None,
            total_steps=TOTAL_STEPS_DEFAULT,
            steps=125,
            n_envs=16,
            gamma=0.99,
            gaelam=0.97,
            kl_frac=1.0,
            delta=0.01,
            val_iters=80,
            val_lr=1e-3,
            **saver_kwargs):
    """
    Natural Policy Gradient

    env: instance of proj.common.env_makers.VecEnvMaker
    policy: instance of proj.common.models.Policy
    val_fn (optional): instance of proj.common.models.ValueFunction
    total_steps: total number of environment steps to take
    steps: number of steps to take in each environment per iteration
    n_envs: number of environment copies to run in parallel
    gamma: GAE discount parameter
    gaelam: GAE lambda exponential average parameter
    kl_frac: fraction of samples used to estimate Fisher-vector products
    delta: target KL divergence per policy update (trust region size)
    val_iters: number of optimization steps to update the critic per iteration
    val_lr: learning rate for critic optimizer
    saver_kwargs: keyword arguments for proj.utils.saver.SnapshotSaver
    """
    val_fn = val_fn or ValueFunction.from_policy(policy)
    logu.save_config(locals())
    saver = SnapshotSaver(logger.get_dir(), locals(), **saver_kwargs)

    vec_env = VecEnvMaker(env)(n_envs)
    policy = policy.pop("class")(vec_env, **policy)
    val_fn = val_fn.pop("class")(vec_env, **val_fn)
    val_optim = torch.optim.Adam(val_fn.parameters(), lr=val_lr)
    loss_fn = torch.nn.MSELoss()

    # Algorithm main loop
    collector = parallel_samples_collector(vec_env, policy, steps)
    beg, end, stp = steps * n_envs, total_steps + steps * n_envs, steps * n_envs
    for samples in trange(beg, end, stp, desc="Training", unit="step"):
        logger.info("Starting iteration {}".format(samples // stp))
        logger.logkv("Iteration", samples // stp)

        logger.info("Start collecting samples")
        trajs = next(collector)

        logger.info("Computing policy gradient variables")
        compute_pg_vars(trajs, val_fn, gamma, gaelam)
        flatten_trajs(trajs)
        all_obs, all_acts, _, _, all_advs, all_vals, all_rets = trajs.values()
        all_obs, all_vals = all_obs[:-n_envs], all_vals[:-n_envs]

        # subsample for fisher vector product computation
        if kl_frac < 1.0:
            n_samples = int(kl_frac * len(all_obs))
            indexes = torch.randperm(len(all_obs))[:n_samples]
            subsamp_obs = all_obs.index_select(0, indexes)
        else:
            subsamp_obs = all_obs

        logger.info("Computing policy gradient")
        all_dists = policy(all_obs)
        old_dists = all_dists.detach()
        pol_loss = torch.mean(all_dists.log_prob(all_acts) * all_advs).neg()
        pol_grad = flat_grad(pol_loss, policy.parameters())

        logger.info("Computing truncated natural gradient")

        def fvp(v):
            return fisher_vec_prod(v, subsamp_obs, policy)

        descent_direction = conjugate_gradient(fvp, pol_grad)
        scale = torch.sqrt(2 * delta /
                           (pol_grad.dot(descent_direction) + 1e-8))
        descent_step = descent_direction * scale
        new_params = parameters_to_vector(policy.parameters()) - descent_step
        vector_to_parameters(new_params, policy.parameters())

        logger.info("Updating val_fn")
        for _ in range(val_iters):
            val_optim.zero_grad()
            loss_fn(val_fn(all_obs), all_rets).backward()
            val_optim.step()

        logger.info("Logging information")
        logger.logkv("TotalNSamples", samples)
        logu.log_reward_statistics(vec_env)
        logu.log_val_fn_statistics(all_vals, all_rets)
        logu.log_action_distribution_statistics(old_dists)
        logu.log_average_kl_divergence(old_dists, policy, all_obs)
        logger.dumpkvs()

        logger.info("Saving snapshot")
        saver.save_state(
            samples // stp,
            dict(
                alg=dict(last_iter=samples // stp),
                policy=policy.state_dict(),
                val_fn=val_fn.state_dict(),
                val_optim=val_optim.state_dict(),
            ),
        )
        del all_obs, all_acts, all_advs, all_vals, all_rets, trajs

    vec_env.close()
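compute_pg_vars is external to this example, but the gamma/gaelam pair in the docstring points at Generalized Advantage Estimation. Below is a minimal single-environment GAE(lambda) sketch under that assumption; the real helper also has to handle episode terminations and the extra bootstrap observations that the [:-n_envs] slicing above appears to discard.

import torch

def gae_advantages(rews, vals, last_val, gamma=0.99, lam=0.97):
    """Toy GAE(lambda) for a single environment, ignoring episode ends."""
    vals = torch.cat([vals, last_val.view(1)])  # append bootstrap value
    advs = torch.zeros_like(rews)
    gae = 0.0
    for t in reversed(range(len(rews))):
        delta = rews[t] + gamma * vals[t + 1] - vals[t]  # TD residual
        gae = delta + gamma * lam * gae
        advs[t] = gae
    return advs

rews = torch.ones(5)
vals = torch.zeros(5)
print(gae_advantages(rews, vals, last_val=torch.tensor(0.0)))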
Example #5
def a2c_kfac(env,
             policy,
             val_fn=None,
             total_steps=TOTAL_STEPS_DEFAULT,
             steps=20,
             n_envs=16,
             kfac=None,
             ent_coeff=0.01,
             vf_loss_coeff=0.5,
             gamma=0.99,
             log_interval=100,
             warm_start=None,
             **saver_kwargs):
    assert val_fn is None or not issubclass(
        policy["class"], WeightSharingAC
    ), "Choose between a weight sharing model or separate policy and val_fn"

    # handle default values
    kfac = kfac or {}
    kfac = {
        "eps": 1e-3,
        "pi": True,
        "alpha": 0.95,
        "kl_clip": 1e-3,
        "eta": 1.0,
        **kfac
    }
    if val_fn is None and not issubclass(policy["class"], WeightSharingAC):
        val_fn = ValueFunction.from_policy(policy)

    # save config and setup state saving
    logu.save_config(locals())
    saver = SnapshotSaver(logger.get_dir(), locals(), **saver_kwargs)

    # initialize models and optimizer
    vec_env = VecEnvMaker(env)(n_envs)
    policy = policy.pop("class")(vec_env, **policy)
    module_list = torch.nn.ModuleList(policy.modules())
    if val_fn is not None:
        val_fn = val_fn.pop("class")(vec_env, **val_fn)
        module_list.extend(val_fn.modules())
    optimizer = KFACOptimizer(module_list, **kfac)
    # scheduler = LinearLR(optimizer, total_steps // (steps*n_envs))
    loss_fn = torch.nn.MSELoss()

    # load state if provided
    updates = 0
    if warm_start is not None:
        if ":" in warm_start:
            warm_start, index = warm_start.split(":")
        else:
            index = None
        config, state = SnapshotSaver(warm_start,
                                      latest_only=False).get_state(int(index))
        policy.load_state_dict(state["policy"])
        if "optimizer" in state:
            optimizer.load_state_dict(state["optimizer"])
        updates = state["alg"]["last_updt"]

    # Algorithm main loop
    if val_fn is None:
        compute_dists_vals = policy
    else:

        def compute_dists_vals(obs):
            return policy(obs), val_fn(obs)

    ob_space, ac_space = vec_env.observation_space, vec_env.action_space
    obs = torch.from_numpy(vec_env.reset())
    with torch.no_grad():
        acts = policy.actions(obs)
    logger.info("Starting epoch {}".format(1))
    beg, end, stp = steps * n_envs, total_steps + steps * n_envs, steps * n_envs
    total_updates = total_steps // stp
    for samples in trange(beg, end, stp, desc="Training", unit="step"):
        all_obs = torch.empty((steps, n_envs) + ob_space.shape,
                              dtype=_NP_TO_PT[ob_space.dtype.type])
        all_acts = torch.empty((steps, n_envs) + ac_space.shape,
                               dtype=_NP_TO_PT[ac_space.dtype.type])
        all_rews = torch.empty((steps, n_envs))
        all_dones = torch.empty((steps, n_envs))

        with torch.no_grad():
            for i in range(steps):
                next_obs, rews, dones, _ = vec_env.step(acts.numpy())
                all_obs[i] = obs
                all_acts[i] = acts
                all_rews[i] = torch.from_numpy(rews)
                all_dones[i] = torch.from_numpy(dones.astype("f"))
                obs = torch.from_numpy(next_obs)

                acts = policy.actions(obs)

        all_obs = all_obs.reshape(stp, -1).squeeze()
        all_acts = all_acts.reshape(stp, -1).squeeze()

        # Sample Fisher curvature matrix
        with optimizer.record_stats():
            optimizer.zero_grad()
            all_dists, all_vals = compute_dists_vals(all_obs)
            logp = all_dists.log_prob(all_acts)
            noise = all_vals.detach() + 0.5 * torch.randn_like(all_vals)
            (logp.mean() +
             loss_fn(all_vals, noise)).backward(retain_graph=True)

        # Compute returns and advantages
        with torch.no_grad():
            _, next_vals = compute_dists_vals(obs)
        all_rets = all_rews.clone()
        all_rets[-1] += gamma * (1 - all_dones[-1]) * next_vals
        for i in reversed(range(steps - 1)):
            all_rets[i] += gamma * (1 - all_dones[i]) * all_rets[i + 1]
        all_rets = all_rets.flatten()
        all_advs = all_rets - all_vals.detach()

        # Compute loss
        updates += 1
        # ent_coeff = ent_coeff*0.99 if updates % 10 == 0 else ent_coeff
        pi_loss = -torch.mean(logp * all_advs)
        vf_loss = loss_fn(all_vals, all_rets)
        entropy = all_dists.entropy().mean()
        total_loss = pi_loss - ent_coeff * entropy + vf_loss_coeff * vf_loss

        # scheduler.step()
        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()

        if updates == 1 or updates % log_interval == 0:
            logger.logkv("Epoch", updates // log_interval + 1)
            logger.logkv("TotalNSamples", samples)
            logu.log_reward_statistics(vec_env)
            logu.log_val_fn_statistics(all_vals, all_rets)
            logu.log_action_distribution_statistics(all_dists)
            logger.dumpkvs()
            logger.info("Starting epoch {}".format(updates // log_interval +
                                                   2))

        saver.save_state(
            index=updates,
            state=dict(
                alg=dict(last_updt=updates),
                policy=policy.state_dict(),
                val_fn=None if val_fn is None else val_fn.state_dict(),
                optimizer=optimizer.state_dict(),
            ),
        )

    vec_env.close()
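The rollout buffers above pick their dtypes through _NP_TO_PT, which is not shown in the example. A plausible stand-in, assuming it simply maps NumPy scalar types (such as observation_space.dtype.type) to matching torch dtypes:

import numpy as np
import torch

# Assumed contents of the _NP_TO_PT lookup; not the project's actual table.
_NP_TO_PT_SKETCH = {
    np.float32: torch.float32,
    np.float64: torch.float64,
    np.int64: torch.int64,
    np.uint8: torch.uint8,
}

# Allocating a (steps, n_envs, obs_dim) buffer the same way the loop does.
buf = torch.empty((20, 16, 4), dtype=_NP_TO_PT_SKETCH[np.float32])
print(buf.shape, buf.dtype)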
Example #6
def ppo(env,
        policy,
        val_fn=None,
        total_steps=TOTAL_STEPS_DEFAULT,
        steps=125,
        n_envs=16,
        gamma=0.99,
        gaelam=0.96,
        clip_ratio=0.2,
        pol_iters=80,
        val_iters=80,
        pol_lr=3e-4,
        val_lr=1e-3,
        target_kl=0.01,
        mb_size=100,
        **saver_kwargs):
    val_fn = val_fn or ValueFunction.from_policy(policy)

    logu.save_config(locals())
    saver = SnapshotSaver(logger.get_dir(), locals(), **saver_kwargs)

    vec_env = VecEnvMaker(env)(n_envs)
    policy = policy.pop("class")(vec_env, **policy)
    val_fn = val_fn.pop("class")(vec_env, **val_fn)
    pol_optim = torch.optim.Adam(policy.parameters(), lr=pol_lr)
    val_optim = torch.optim.Adam(val_fn.parameters(), lr=val_lr)
    loss_fn = torch.nn.MSELoss()

    # Algorithm main loop
    collector = parallel_samples_collector(vec_env, policy, steps)
    beg, end, stp = steps * n_envs, total_steps + steps * n_envs, steps * n_envs
    for samples in trange(beg, end, stp, desc="Training", unit="step"):
        logger.info("Starting iteration {}".format(samples // stp))
        logger.logkv("Iteration", samples // stp)

        logger.info("Start collecting samples")
        trajs = next(collector)

        logger.info("Computing policy gradient variables")
        compute_pg_vars(trajs, val_fn, gamma, gaelam)
        flatten_trajs(trajs)
        all_obs, all_acts, _, _, all_advs, all_vals, all_rets = trajs.values()
        all_obs, all_vals = all_obs[:-n_envs], all_vals[:-n_envs]

        logger.info("Minimizing surrogate loss")
        with torch.no_grad():
            old_dists = policy(all_obs)
        old_logp = old_dists.log_prob(all_acts)
        min_advs = torch.where(all_advs > 0, (1 + clip_ratio) * all_advs,
                               (1 - clip_ratio) * all_advs)
        dataset = TensorDataset(all_obs, all_acts, all_advs, min_advs,
                                old_logp)
        dataloader = DataLoader(dataset, batch_size=mb_size, shuffle=True)
        for itr in range(pol_iters):
            for obs, acts, advs, min_adv, logp in dataloader:
                ratios = (policy(obs).log_prob(acts) - logp).exp()
                pol_optim.zero_grad()
                (-torch.min(ratios * advs, min_adv)).mean().backward()
                pol_optim.step()

            with torch.no_grad():
                mean_kl = kl(old_dists, policy(all_obs)).mean().item()
            if mean_kl > 1.5 * target_kl:
                logger.info(
                    "Stopped at step {} due to reaching max kl".format(itr +
                                                                       1))
                break
        logger.logkv("StopIter", itr + 1)

        logger.info("Updating val_fn")
        for _ in range(val_iters):
            val_optim.zero_grad()
            loss_fn(val_fn(all_obs), all_rets).backward()
            val_optim.step()

        logger.info("Logging information")
        logger.logkv("TotalNSamples", samples)
        logu.log_reward_statistics(vec_env)
        logu.log_val_fn_statistics(all_vals, all_rets)
        logu.log_action_distribution_statistics(old_dists)
        logger.logkv("MeanKL", mean_kl)
        logger.dumpkvs()

        logger.info("Saving snapshot")
        saver.save_state(
            index=samples // stp,
            state=dict(
                alg=dict(last_iter=samples // stp),
                policy=policy.state_dict(),
                val_fn=val_fn.state_dict(),
                pol_optim=pol_optim.state_dict(),
                val_optim=val_optim.state_dict(),
            ),
        )

    vec_env.close()
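The min_advs tensor precomputed above is the clipped branch of the PPO objective; taking the elementwise min with ratios * advs reproduces the usual clamp(ratio, 1 - eps, 1 + eps) formulation. A toy check with invented numbers:

import torch

clip_ratio = 0.2
ratios = torch.tensor([0.7, 1.0, 1.3])
advs = torch.tensor([1.0, -2.0, 0.5])

# Precomputed clipped branch, as in the ppo example.
min_advs = torch.where(advs > 0, (1 + clip_ratio) * advs, (1 - clip_ratio) * advs)
loss_precomputed = -torch.min(ratios * advs, min_advs).mean()

# Textbook clamp formulation.
loss_clamped = -torch.min(ratios * advs,
                          torch.clamp(ratios, 1 - clip_ratio, 1 + clip_ratio) * advs).mean()

print(torch.allclose(loss_precomputed, loss_clamped))  # True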
Example #7
def vanilla(
    env,
    policy,
    val_fn=None,
    total_steps=TOTAL_STEPS_DEFAULT,
    steps=125,
    n_envs=16,
    gamma=0.99,
    gaelam=0.97,
    optimizer=None,
    val_iters=80,
    val_lr=1e-3,
    **saver_kwargs
):
    """
    Vanilla Policy Gradient

    env: instance of proj.common.env_makers.VecEnvMaker
    policy: instance of proj.common.models.Policy
    val_fn (optional): instance of proj.common.models.ValueFunction
    total_steps: total number of environment steps to take
    steps: number of steps to take in each environment per iteration
    n_envs: number of environment copies to run in parallel
    gamma: GAE discount parameter
    gaelam: GAE lambda exponential average parameter
    optimizer (optional): dictionary containing optimizer kwargs and/or class
    val_iters: number of optimization steps to update the critic per iteration
    val_lr: learning rate for critic optimizer
    saver_kwargs: keyword arguments for proj.utils.saver.SnapshotSaver
    """
    optimizer = optimizer or {}
    optimizer = {"class": torch.optim.Adam, **optimizer}
    val_fn = val_fn or ValueFunction.from_policy(policy)

    logu.save_config(locals())
    saver = SnapshotSaver(logger.get_dir(), locals(), **saver_kwargs)

    vec_env = VecEnvMaker(env)(n_envs)
    policy = policy.pop("class")(vec_env, **policy)
    val_fn = val_fn.pop("class")(vec_env, **val_fn)
    pol_optim = optimizer.pop("class")(policy.parameters(), **optimizer)
    val_optim = torch.optim.Adam(val_fn.parameters(), lr=val_lr)
    loss_fn = torch.nn.MSELoss()

    # Algorithm main loop
    collector = parallel_samples_collector(vec_env, policy, steps)
    beg, end, stp = steps * n_envs, total_steps + steps * n_envs, steps * n_envs
    for samples in trange(beg, end, stp, desc="Training", unit="step"):
        logger.info("Starting iteration {}".format(samples // stp))
        logger.logkv("Iteration", samples // stp)

        logger.info("Start collecting samples")
        trajs = next(collector)

        logger.info("Computing policy gradient variables")
        compute_pg_vars(trajs, val_fn, gamma, gaelam)
        flatten_trajs(trajs)
        all_obs, all_acts, _, _, all_advs, all_vals, all_rets = trajs.values()
        all_obs, all_vals = all_obs[:-n_envs], all_vals[:-n_envs]

        logger.info("Applying policy gradient")
        all_dists = policy(all_obs)
        old_dists = all_dists.detach()
        objective = torch.mean(all_dists.log_prob(all_acts) * all_advs)
        pol_optim.zero_grad()
        objective.neg().backward()
        pol_optim.step()

        logger.info("Updating val_fn")
        for _ in range(val_iters):
            val_optim.zero_grad()
            loss_fn(val_fn(all_obs), all_rets).backward()
            val_optim.step()

        logger.info("Logging information")
        logger.logkv("Objective", objective.item())
        logger.logkv("TotalNSamples", samples)
        logu.log_reward_statistics(vec_env)
        logu.log_val_fn_statistics(all_vals, all_rets)
        logu.log_action_distribution_statistics(old_dists)
        logu.log_average_kl_divergence(old_dists, policy, all_obs)
        logger.dumpkvs()

        logger.info("Saving snapshot")
        saver.save_state(
            samples // stp,
            dict(
                alg=dict(last_iter=samples // stp),
                policy=policy.state_dict(),
                val_fn=val_fn.state_dict(),
                pol_optim=pol_optim.state_dict(),
                val_optim=val_optim.state_dict(),
            ),
        )
        del all_obs, all_acts, all_advs, all_vals, all_rets, trajs

    vec_env.close()
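All of these examples build models and optimizers from dicts that bundle a "class" entry with constructor kwargs, as the docstring above notes for optimizer. A runnable toy illustration of that convention, using only torch.optim.Adam and a dummy parameter:

import torch

# Spec dict: "class" is popped, everything else becomes constructor kwargs,
# mirroring policy.pop("class")(vec_env, **policy) in the examples above.
optimizer_spec = {"class": torch.optim.Adam, "lr": 1e-3}
params = [torch.nn.Parameter(torch.zeros(4))]
optimizer = optimizer_spec.pop("class")(params, **optimizer_spec)
print(type(optimizer).__name__, optimizer.defaults["lr"])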