def acktr(env, policy, val_fn=None, total_steps=TOTAL_STEPS_DEFAULT,
          steps=125, n_envs=16, gamma=0.99, gaelam=0.96, val_iters=20,
          pikfac=None, vfkfac=None, warm_start=None, linesearch=True,
          **saver_kwargs):
    # handle default values
    pikfac = pikfac or {}
    vfkfac = vfkfac or {}
    val_fn = val_fn or ValueFunction.from_policy(policy)

    # save config and setup state saving
    logu.save_config(locals())
    saver = SnapshotSaver(logger.get_dir(), locals(), **saver_kwargs)

    # initialize models and optimizer
    vec_env = VecEnvMaker(env)(n_envs)
    policy = policy.pop("class")(vec_env, **policy)
    val_fn = val_fn.pop("class")(vec_env, **val_fn)
    pol_optim = KFACOptimizer(policy, **{**DEFAULT_PIKFAC, **pikfac})
    val_optim = KFACOptimizer(val_fn, **{**DEFAULT_VFKFAC, **vfkfac})
    loss_fn = torch.nn.MSELoss()

    # load state if provided
    if warm_start is not None:
        if ":" in warm_start:
            warm_start, index = warm_start.split(":")
            index = int(index)
        else:
            index = None
        _, state = SnapshotSaver(warm_start, latest_only=False).get_state(index)
        policy.load_state_dict(state["policy"])
        val_fn.load_state_dict(state["val_fn"])
        if "pol_optim" in state:
            pol_optim.load_state_dict(state["pol_optim"])
        if "val_optim" in state:
            val_optim.load_state_dict(state["val_optim"])

    # Algorithm main loop
    collector = parallel_samples_collector(vec_env, policy, steps)
    beg, end, stp = steps * n_envs, total_steps + steps * n_envs, steps * n_envs
    for samples in trange(beg, end, stp, desc="Training", unit="step"):
        logger.info("Starting iteration {}".format(samples // stp))
        logger.logkv("Iteration", samples // stp)

        logger.info("Start collecting samples")
        trajs = next(collector)

        logger.info("Computing policy gradient variables")
        compute_pg_vars(trajs, val_fn, gamma, gaelam)
        flatten_trajs(trajs)
        all_obs, all_acts, _, _, all_advs, all_vals, all_rets = trajs.values()
        # drop the extra bootstrap entry collected for each environment
        all_obs, all_vals = all_obs[:-n_envs], all_vals[:-n_envs]

        logger.info("Computing natural gradient using KFAC")
        with pol_optim.record_stats():
            policy.zero_grad()
            all_dists = policy(all_obs)
            all_logp = all_dists.log_prob(all_acts)
            all_logp.mean().backward(retain_graph=True)

        policy.zero_grad()
        old_dists, old_logp = all_dists.detach(), all_logp.detach()
        # importance-sampled surrogate; the ratio is 1 at the current
        # parameters, so its gradient equals the policy gradient
        surr_loss = -((all_logp - old_logp).exp() * all_advs).mean()
        surr_loss.backward()
        pol_grad = [p.grad.clone() for p in policy.parameters()]
        pol_optim.step()

        if linesearch:
            logger.info("Performing line search")
            kl_clip = pol_optim.state["kl_clip"]
            expected_improvement = sum(
                (g * p.grad.data).sum()
                for g, p in zip(pol_grad, policy.parameters())).item()

            def f_barrier(scale):
                for p in policy.parameters():
                    p.data.add_(p.grad.data, alpha=scale)
                new_dists = policy(all_obs)
                for p in policy.parameters():
                    p.data.sub_(p.grad.data, alpha=scale)
                new_logp = new_dists.log_prob(all_acts)
                surr_loss = -((new_logp - old_logp).exp() * all_advs).mean()
                avg_kl = kl(old_dists, new_dists).mean().item()
                return surr_loss.item() if avg_kl < kl_clip else float("inf")

            scale, expected_improvement, improvement = line_search(
                f_barrier,
                x0=1,
                dx=1,
                expected_improvement=expected_improvement,
                y0=surr_loss.item(),
            )
            logger.logkv("ExpectedImprovement", expected_improvement)
            logger.logkv("ActualImprovement", improvement)
            logger.logkv("ImprovementRatio", improvement / expected_improvement)
            for p in policy.parameters():
                p.data.add_(p.grad.data, alpha=scale)

        logger.info("Updating val_fn")
        for _ in range(val_iters):
            with val_optim.record_stats():
                val_fn.zero_grad()
                values = val_fn(all_obs)
                # fit against noised targets around the current predictions
                # to build K-FAC curvature statistics
                noise = values.detach() + 0.5 * torch.randn_like(values)
                loss_fn(values, noise).backward(retain_graph=True)

            val_fn.zero_grad()
            val_loss = loss_fn(values, all_rets)
            val_loss.backward()
            val_optim.step()

        logger.info("Logging information")
        logger.logkv("TotalNSamples", samples)
        logu.log_reward_statistics(vec_env)
        logu.log_val_fn_statistics(all_vals, all_rets)
        logu.log_action_distribution_statistics(old_dists)
        logu.log_average_kl_divergence(old_dists, policy, all_obs)
        logger.dumpkvs()

        logger.info("Saving snapshot")
        saver.save_state(
            index=samples // stp,
            state=dict(
                alg=dict(last_iter=samples // stp),
                policy=policy.state_dict(),
                val_fn=val_fn.state_dict(),
                pol_optim=pol_optim.state_dict(),
                val_optim=val_optim.state_dict(),
            ),
        )

    vec_env.close()
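
# All algorithms above and below rely on `compute_pg_vars` to fill in
# advantages. A minimal single-trajectory sketch of GAE(gamma, lambda), which
# is presumably what it computes (the name and the convention that `values`
# carries one extra bootstrap entry are assumptions, not this project's exact
# implementation):
def _gae_advantages_sketch(rewards, values, gamma, gaelam):
    """Illustrative GAE: `values` has len(rewards) + 1 entries (bootstrap)."""
    # TD(0) residuals: delta_t = r_t + gamma * V(s_{t+1}) - V(s_t)
    deltas = rewards + gamma * values[1:] - values[:-1]
    advantages = torch.zeros_like(rewards)
    running = 0.0
    # A_t = sum_l (gamma * lambda)^l * delta_{t+l}, computed backwards
    for t in reversed(range(len(rewards))):
        running = deltas[t] + gamma * gaelam * running
        advantages[t] = running
    return advantages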
def trpo(env, policy, val_fn=None, total_steps=TOTAL_STEPS_DEFAULT, steps=125,
         n_envs=16, gamma=0.99, gaelam=0.97, kl_frac=1.0, delta=0.01,
         val_iters=80, val_lr=1e-3, linesearch=True, **saver_kwargs):
    val_fn = val_fn or ValueFunction.from_policy(policy)
    logu.save_config(locals())
    saver = SnapshotSaver(logger.get_dir(), locals(), **saver_kwargs)

    vec_env = VecEnvMaker(env)(n_envs)
    policy = policy.pop("class")(vec_env, **policy)
    val_fn = val_fn.pop("class")(vec_env, **val_fn)
    val_optim = torch.optim.Adam(val_fn.parameters(), lr=val_lr)
    loss_fn = torch.nn.MSELoss()

    # Algorithm main loop
    collector = parallel_samples_collector(vec_env, policy, steps)
    beg, end, stp = steps * n_envs, total_steps + steps * n_envs, steps * n_envs
    for samples in trange(beg, end, stp, desc="Training", unit="step"):
        logger.info("Starting iteration {}".format(samples // stp))
        logger.logkv("Iteration", samples // stp)

        logger.info("Start collecting samples")
        trajs = next(collector)

        logger.info("Computing policy gradient variables")
        compute_pg_vars(trajs, val_fn, gamma, gaelam)
        flatten_trajs(trajs)
        all_obs, all_acts, _, _, all_advs, all_vals, all_rets = trajs.values()
        all_obs, all_vals = all_obs[:-n_envs], all_vals[:-n_envs]

        # subsample for fisher vector product computation
        if kl_frac < 1.0:
            n_samples = int(kl_frac * len(all_obs))
            indexes = torch.randperm(len(all_obs))[:n_samples]
            subsamp_obs = all_obs.index_select(0, indexes)
        else:
            subsamp_obs = all_obs

        logger.info("Computing policy gradient")
        all_dists = policy(all_obs)
        all_logp = all_dists.log_prob(all_acts)
        old_dists = all_dists.detach()
        old_logp = old_dists.log_prob(all_acts)
        surr_loss = -((all_logp - old_logp).exp() * all_advs).mean()
        pol_grad = flat_grad(surr_loss, policy.parameters())

        logger.info("Computing truncated natural gradient")

        def fvp(v):
            return fisher_vec_prod(v, subsamp_obs, policy)

        descent_direction = conjugate_gradient(fvp, pol_grad)
        scale = torch.sqrt(
            2 * delta / (pol_grad.dot(descent_direction) + 1e-8))
        descent_step = descent_direction * scale

        if linesearch:
            logger.info("Performing line search")
            expected_improvement = pol_grad.dot(descent_step).item()

            def f_barrier(params, all_obs=all_obs, all_acts=all_acts,
                          all_advs=all_advs):
                vector_to_parameters(params, policy.parameters())
                new_dists = policy(all_obs)
                new_logp = new_dists.log_prob(all_acts)
                surr_loss = -((new_logp - old_logp).exp() * all_advs).mean()
                avg_kl = kl(old_dists, new_dists).mean().item()
                return surr_loss.item() if avg_kl < delta else float("inf")

            new_params, expected_improvement, improvement = line_search(
                f_barrier,
                parameters_to_vector(policy.parameters()),
                descent_step,
                expected_improvement,
                y0=surr_loss.item(),
            )
            logger.logkv("ExpectedImprovement", expected_improvement)
            logger.logkv("ActualImprovement", improvement)
            logger.logkv("ImprovementRatio", improvement / expected_improvement)
        else:
            new_params = parameters_to_vector(
                policy.parameters()) - descent_step
        vector_to_parameters(new_params, policy.parameters())

        logger.info("Updating val_fn")
        for _ in range(val_iters):
            val_optim.zero_grad()
            loss_fn(val_fn(all_obs), all_rets).backward()
            val_optim.step()

        logger.info("Logging information")
        logger.logkv("TotalNSamples", samples)
        logu.log_reward_statistics(vec_env)
        logu.log_val_fn_statistics(all_vals, all_rets)
        logu.log_action_distribution_statistics(old_dists)
        logu.log_average_kl_divergence(old_dists, policy, all_obs)
        logger.dumpkvs()

        logger.info("Saving snapshot")
        saver.save_state(
            index=samples // stp,
            state=dict(
                alg=dict(last_iter=samples // stp),
                policy=policy.state_dict(),
                val_fn=val_fn.state_dict(),
                val_optim=val_optim.state_dict(),
            ),
        )
        del all_obs, all_acts, all_advs, all_vals, all_rets, trajs

    vec_env.close()
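
# `fisher_vec_prod` is used as a black box by trpo above (and by natural
# below). The standard way to get Fisher-vector products without materializing
# the Fisher matrix is the double-backward (Pearlmutter) trick on the mean KL
# from a detached copy of the current policy, whose Hessian at the current
# parameters is the Fisher. A minimal sketch, assuming distribution objects
# support .detach() as used above (the damping constant is an assumption):
def _fisher_vec_prod_sketch(v, obs, policy, damping=1e-3):
    dists = policy(obs)
    # KL(pi_old || pi) is zero and has zero gradient at the current
    # parameters, but its Hessian is the Fisher information matrix
    avg_kl = kl(dists.detach(), dists).mean()
    grads = torch.autograd.grad(avg_kl, policy.parameters(), create_graph=True)
    flat_grad_kl = torch.cat([g.reshape(-1) for g in grads])
    # differentiate (grad KL . v) to obtain the Hessian-vector product
    grad_vec_dot = flat_grad_kl.dot(v)
    hvp = torch.autograd.grad(grad_vec_dot, policy.parameters())
    return torch.cat([h.reshape(-1) for h in hvp]) + damping * v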
def ppo(env, policy, val_fn=None, total_steps=TOTAL_STEPS_DEFAULT, steps=125,
        n_envs=16, gamma=0.99, gaelam=0.96, clip_ratio=0.2, pol_iters=80,
        val_iters=80, pol_lr=3e-4, val_lr=1e-3, target_kl=0.01, mb_size=100,
        **saver_kwargs):
    val_fn = val_fn or ValueFunction.from_policy(policy)
    logu.save_config(locals())
    saver = SnapshotSaver(logger.get_dir(), locals(), **saver_kwargs)

    vec_env = VecEnvMaker(env)(n_envs)
    policy = policy.pop("class")(vec_env, **policy)
    val_fn = val_fn.pop("class")(vec_env, **val_fn)
    pol_optim = torch.optim.Adam(policy.parameters(), lr=pol_lr)
    val_optim = torch.optim.Adam(val_fn.parameters(), lr=val_lr)
    loss_fn = torch.nn.MSELoss()

    # Algorithm main loop
    collector = parallel_samples_collector(vec_env, policy, steps)
    beg, end, stp = steps * n_envs, total_steps + steps * n_envs, steps * n_envs
    for samples in trange(beg, end, stp, desc="Training", unit="step"):
        logger.info("Starting iteration {}".format(samples // stp))
        logger.logkv("Iteration", samples // stp)

        logger.info("Start collecting samples")
        trajs = next(collector)

        logger.info("Computing policy gradient variables")
        compute_pg_vars(trajs, val_fn, gamma, gaelam)
        flatten_trajs(trajs)
        all_obs, all_acts, _, _, all_advs, all_vals, all_rets = trajs.values()
        all_obs, all_vals = all_obs[:-n_envs], all_vals[:-n_envs]

        logger.info("Minimizing surrogate loss")
        with torch.no_grad():
            old_dists = policy(all_obs)
            old_logp = old_dists.log_prob(all_acts)
        min_advs = torch.where(all_advs > 0,
                               (1 + clip_ratio) * all_advs,
                               (1 - clip_ratio) * all_advs)
        dataset = TensorDataset(all_obs, all_acts, all_advs, min_advs, old_logp)
        dataloader = DataLoader(dataset, batch_size=mb_size, shuffle=True)
        for itr in range(pol_iters):
            for obs, acts, advs, min_adv, logp in dataloader:
                ratios = (policy(obs).log_prob(acts) - logp).exp()
                pol_optim.zero_grad()
                (-torch.min(ratios * advs, min_adv)).mean().backward()
                pol_optim.step()
            with torch.no_grad():
                mean_kl = kl(old_dists, policy(all_obs)).mean().item()
            if mean_kl > 1.5 * target_kl:
                logger.info(
                    "Stopped at step {} due to reaching max kl".format(itr + 1))
                break
        logger.logkv("StopIter", itr + 1)

        logger.info("Updating val_fn")
        for _ in range(val_iters):
            val_optim.zero_grad()
            loss_fn(val_fn(all_obs), all_rets).backward()
            val_optim.step()

        logger.info("Logging information")
        logger.logkv("TotalNSamples", samples)
        logu.log_reward_statistics(vec_env)
        logu.log_val_fn_statistics(all_vals, all_rets)
        logu.log_action_distribution_statistics(old_dists)
        logger.logkv("MeanKL", mean_kl)
        logger.dumpkvs()

        logger.info("Saving snapshot")
        saver.save_state(
            index=samples // stp,
            state=dict(
                alg=dict(last_iter=samples // stp),
                policy=policy.state_dict(),
                val_fn=val_fn.state_dict(),
                pol_optim=pol_optim.state_dict(),
                val_optim=val_optim.state_dict(),
            ),
        )

    vec_env.close()
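
# The `min_advs` precomputation in ppo is equivalent to the usual clipped
# surrogate term clamp(ratio, 1 - eps, 1 + eps) * adv once the elementwise min
# with ratio * adv is taken, since only the binding side of the clip ever
# matters. A quick self-contained check of that identity (callable by hand,
# not executed on import):
def _check_clipped_objective_identity(clip_ratio=0.2, n=1000):
    advs = torch.randn(n)
    ratios = torch.rand(n) * 2          # probability ratios are nonnegative
    min_advs = torch.where(advs > 0, (1 + clip_ratio) * advs,
                           (1 - clip_ratio) * advs)
    lhs = torch.min(ratios * advs, min_advs)
    rhs = torch.min(ratios * advs,
                    ratios.clamp(1 - clip_ratio, 1 + clip_ratio) * advs)
    assert torch.allclose(lhs, rhs)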
def natural(env, policy, val_fn=None, total_steps=TOTAL_STEPS_DEFAULT,
            steps=125, n_envs=16, gamma=0.99, gaelam=0.97, kl_frac=1.0,
            delta=0.01, val_iters=80, val_lr=1e-3, **saver_kwargs):
    """
    Natural Policy Gradient

    env: instance of proj.common.env_makers.VecEnvMaker
    policy: instance of proj.common.models.Policy
    val_fn (optional): instance of proj.common.models.ValueFunction
    total_steps: total number of environment steps to take
    steps: number of steps to take in each environment per iteration
    n_envs: number of environment copies to run in parallel
    gamma: GAE discount parameter
    gaelam: GAE lambda exponential average parameter
    kl_frac: fraction of the collected samples used to compute
        Fisher-vector products
    delta: trust region size, i.e. bound on the average KL divergence
        between successive policies
    val_iters: number of optimization steps to update the critic per iteration
    val_lr: learning rate for critic optimizer
    saver_kwargs: keyword arguments for proj.utils.saver.SnapshotSaver
    """
    val_fn = val_fn or ValueFunction.from_policy(policy)
    logu.save_config(locals())
    saver = SnapshotSaver(logger.get_dir(), locals(), **saver_kwargs)

    vec_env = VecEnvMaker(env)(n_envs)
    policy = policy.pop("class")(vec_env, **policy)
    val_fn = val_fn.pop("class")(vec_env, **val_fn)
    val_optim = torch.optim.Adam(val_fn.parameters(), lr=val_lr)
    loss_fn = torch.nn.MSELoss()

    # Algorithm main loop
    collector = parallel_samples_collector(vec_env, policy, steps)
    beg, end, stp = steps * n_envs, total_steps + steps * n_envs, steps * n_envs
    for samples in trange(beg, end, stp, desc="Training", unit="step"):
        logger.info("Starting iteration {}".format(samples // stp))
        logger.logkv("Iteration", samples // stp)

        logger.info("Start collecting samples")
        trajs = next(collector)

        logger.info("Computing policy gradient variables")
        compute_pg_vars(trajs, val_fn, gamma, gaelam)
        flatten_trajs(trajs)
        all_obs, all_acts, _, _, all_advs, all_vals, all_rets = trajs.values()
        all_obs, all_vals = all_obs[:-n_envs], all_vals[:-n_envs]

        # subsample for fisher vector product computation
        if kl_frac < 1.0:
            n_samples = int(kl_frac * len(all_obs))
            indexes = torch.randperm(len(all_obs))[:n_samples]
            subsamp_obs = all_obs.index_select(0, indexes)
        else:
            subsamp_obs = all_obs

        logger.info("Computing policy gradient")
        all_dists = policy(all_obs)
        old_dists = all_dists.detach()
        pol_loss = torch.mean(all_dists.log_prob(all_acts) * all_advs).neg()
        pol_grad = flat_grad(pol_loss, policy.parameters())

        logger.info("Computing truncated natural gradient")

        def fvp(v):
            return fisher_vec_prod(v, subsamp_obs, policy)

        descent_direction = conjugate_gradient(fvp, pol_grad)
        scale = torch.sqrt(
            2 * delta / (pol_grad.dot(descent_direction) + 1e-8))
        descent_step = descent_direction * scale
        new_params = parameters_to_vector(policy.parameters()) - descent_step
        vector_to_parameters(new_params, policy.parameters())

        logger.info("Updating val_fn")
        for _ in range(val_iters):
            val_optim.zero_grad()
            loss_fn(val_fn(all_obs), all_rets).backward()
            val_optim.step()

        logger.info("Logging information")
        logger.logkv("TotalNSamples", samples)
        logu.log_reward_statistics(vec_env)
        logu.log_val_fn_statistics(all_vals, all_rets)
        logu.log_action_distribution_statistics(old_dists)
        logu.log_average_kl_divergence(old_dists, policy, all_obs)
        logger.dumpkvs()

        logger.info("Saving snapshot")
        saver.save_state(
            samples // stp,
            dict(
                alg=dict(last_iter=samples // stp),
                policy=policy.state_dict(),
                val_fn=val_fn.state_dict(),
                val_optim=val_optim.state_dict(),
            ),
        )
        del all_obs, all_acts, all_advs, all_vals, all_rets, trajs

    vec_env.close()
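
# `conjugate_gradient`, used by trpo and natural above, solves F x = g for the
# descent direction using only Fisher-vector products, never the matrix
# itself. A textbook sketch (the iteration count and tolerance defaults are
# assumptions):
def _conjugate_gradient_sketch(f_Ax, b, cg_iters=10, residual_tol=1e-10):
    x = torch.zeros_like(b)
    r = b.clone()          # residual b - A x, with x = 0 initially
    p = b.clone()          # search direction
    rdotr = r.dot(r)
    for _ in range(cg_iters):
        Ap = f_Ax(p)
        alpha = rdotr / p.dot(Ap)
        x = x + alpha * p
        r = r - alpha * Ap
        new_rdotr = r.dot(r)
        if new_rdotr < residual_tol:
            break
        p = r + (new_rdotr / rdotr) * p
        rdotr = new_rdotr
    return x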
def vanilla(
    env,
    policy,
    val_fn=None,
    total_steps=TOTAL_STEPS_DEFAULT,
    steps=125,
    n_envs=16,
    gamma=0.99,
    gaelam=0.97,
    optimizer=None,
    val_iters=80,
    val_lr=1e-3,
    **saver_kwargs
):
    """
    Vanilla Policy Gradient

    env: instance of proj.common.env_makers.VecEnvMaker
    policy: instance of proj.common.models.Policy
    val_fn (optional): instance of proj.common.models.ValueFunction
    total_steps: total number of environment steps to take
    steps: number of steps to take in each environment per iteration
    n_envs: number of environment copies to run in parallel
    gamma: GAE discount parameter
    gaelam: GAE lambda exponential average parameter
    optimizer (optional): dictionary containing optimizer kwargs and/or class
    val_iters: number of optimization steps to update the critic per iteration
    val_lr: learning rate for critic optimizer
    saver_kwargs: keyword arguments for proj.utils.saver.SnapshotSaver
    """
    optimizer = optimizer or {}
    optimizer = {"class": torch.optim.Adam, **optimizer}
    val_fn = val_fn or ValueFunction.from_policy(policy)
    logu.save_config(locals())
    saver = SnapshotSaver(logger.get_dir(), locals(), **saver_kwargs)

    vec_env = VecEnvMaker(env)(n_envs)
    policy = policy.pop("class")(vec_env, **policy)
    val_fn = val_fn.pop("class")(vec_env, **val_fn)
    pol_optim = optimizer.pop("class")(policy.parameters(), **optimizer)
    val_optim = torch.optim.Adam(val_fn.parameters(), lr=val_lr)
    loss_fn = torch.nn.MSELoss()

    # Algorithm main loop
    collector = parallel_samples_collector(vec_env, policy, steps)
    beg, end, stp = steps * n_envs, total_steps + steps * n_envs, steps * n_envs
    for samples in trange(beg, end, stp, desc="Training", unit="step"):
        logger.info("Starting iteration {}".format(samples // stp))
        logger.logkv("Iteration", samples // stp)

        logger.info("Start collecting samples")
        trajs = next(collector)

        logger.info("Computing policy gradient variables")
        compute_pg_vars(trajs, val_fn, gamma, gaelam)
        flatten_trajs(trajs)
        all_obs, all_acts, _, _, all_advs, all_vals, all_rets = trajs.values()
        all_obs, all_vals = all_obs[:-n_envs], all_vals[:-n_envs]

        logger.info("Applying policy gradient")
        all_dists = policy(all_obs)
        old_dists = all_dists.detach()
        objective = torch.mean(all_dists.log_prob(all_acts) * all_advs)
        pol_optim.zero_grad()
        objective.neg().backward()
        pol_optim.step()

        logger.info("Updating val_fn")
        for _ in range(val_iters):
            val_optim.zero_grad()
            loss_fn(val_fn(all_obs), all_rets).backward()
            val_optim.step()

        logger.info("Logging information")
        logger.logkv("Objective", objective.item())
        logger.logkv("TotalNSamples", samples)
        logu.log_reward_statistics(vec_env)
        logu.log_val_fn_statistics(all_vals, all_rets)
        logu.log_action_distribution_statistics(old_dists)
        logu.log_average_kl_divergence(old_dists, policy, all_obs)
        logger.dumpkvs()

        logger.info("Saving snapshot")
        saver.save_state(
            samples // stp,
            dict(
                alg=dict(last_iter=samples // stp),
                policy=policy.state_dict(),
                val_fn=val_fn.state_dict(),
                pol_optim=pol_optim.state_dict(),
                val_optim=val_optim.state_dict(),
            ),
        )
        del all_obs, all_acts, all_advs, all_vals, all_rets, trajs

    vec_env.close()
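
# Illustrative calling convention shared by every algorithm in this module
# (not executed): `env` is forwarded to VecEnvMaker, while `policy` and
# `val_fn` are dicts whose "class" entry is popped and instantiated against
# the vectorized env. The class name and kwargs below are hypothetical
# placeholders for any proj.common.models.Policy subclass:
#
#     vanilla(
#         "CartPole-v1",
#         policy={"class": SomePolicy, "hidden_sizes": (64, 64)},
#         total_steps=int(1e6),
#     )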