def thunk_plus():
    # `seed`, `log_dir`, `format_strs`, `exp_name`, `thunk` and `kwargs` are
    # free variables captured from the enclosing experiment-launcher scope.
    import torch
    import random
    import numpy as np
    from baselines import logger
    from proj.utils.tqdm_util import tqdm_out

    np.random.seed(seed)
    random.seed(seed)
    torch.manual_seed(seed)
    torch.set_num_threads(4)

    with tqdm_out(), logger.scoped_configure(log_dir, format_strs):
        from proj.common.log_utils import save_config

        logger.set_level(logger.WARN)
        save_config({"exp_name": exp_name, "alg": thunk})
        thunk(**kwargs)
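# As an illustration only: `thunk` would be one of the algorithm functions
# defined below and `kwargs` its keyword arguments, with models passed as
# dicts holding a "class" key plus constructor kwargs (this is how the
# algorithms consume them via `policy.pop("class")(vec_env, **policy)`).
# The concrete class name and env id here are hypothetical, not taken from
# this file:
#
#     from proj.common.models import MlpPolicy  # hypothetical class name
#
#     thunk = ppo
#     kwargs = dict(
#         env="HalfCheetah-v2",          # whatever VecEnvMaker accepts
#         policy={"class": MlpPolicy},
#         total_steps=int(1e6),
#     )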
def acktr(env,
          policy,
          val_fn=None,
          total_steps=TOTAL_STEPS_DEFAULT,
          steps=125,
          n_envs=16,
          gamma=0.99,
          gaelam=0.96,
          val_iters=20,
          pikfac=None,
          vfkfac=None,
          warm_start=None,
          linesearch=True,
          **saver_kwargs):
    # handling default values
    pikfac = pikfac or {}
    vfkfac = vfkfac or {}
    val_fn = val_fn or ValueFunction.from_policy(policy)

    # save config and setup state saving
    logu.save_config(locals())
    saver = SnapshotSaver(logger.get_dir(), locals(), **saver_kwargs)

    # initialize models and optimizer
    vec_env = VecEnvMaker(env)(n_envs)
    policy = policy.pop("class")(vec_env, **policy)
    val_fn = val_fn.pop("class")(vec_env, **val_fn)
    pol_optim = KFACOptimizer(policy, **{**DEFAULT_PIKFAC, **pikfac})
    val_optim = KFACOptimizer(val_fn, **{**DEFAULT_VFKFAC, **vfkfac})
    loss_fn = torch.nn.MSELoss()

    # load state if provided
    if warm_start is not None:
        if ":" in warm_start:
            warm_start, index = warm_start.split(":")
        else:
            index = None
        _, state = SnapshotSaver(warm_start,
                                 latest_only=False).get_state(int(index))
        policy.load_state_dict(state["policy"])
        val_fn.load_state_dict(state["val_fn"])
        if "pol_optim" in state:
            pol_optim.load_state_dict(state["pol_optim"])
        if "val_optim" in state:
            val_optim.load_state_dict(state["val_optim"])

    # Algorithm main loop
    collector = parallel_samples_collector(vec_env, policy, steps)
    beg, end, stp = steps * n_envs, total_steps + steps * n_envs, steps * n_envs
    for samples in trange(beg, end, stp, desc="Training", unit="step"):
        logger.info("Starting iteration {}".format(samples // stp))
        logger.logkv("Iteration", samples // stp)

        logger.info("Start collecting samples")
        trajs = next(collector)

        logger.info("Computing policy gradient variables")
        compute_pg_vars(trajs, val_fn, gamma, gaelam)
        flatten_trajs(trajs)
        all_obs, all_acts, _, _, all_advs, all_vals, all_rets = trajs.values()
        all_obs, all_vals = all_obs[:-n_envs], all_vals[:-n_envs]

        logger.info("Computing natural gradient using KFAC")
        with pol_optim.record_stats():
            policy.zero_grad()
            all_dists = policy(all_obs)
            all_logp = all_dists.log_prob(all_acts)
            all_logp.mean().backward(retain_graph=True)

        policy.zero_grad()
        old_dists, old_logp = all_dists.detach(), all_logp.detach()
        surr_loss = -((all_logp - old_logp).exp() * all_advs).mean()
        surr_loss.backward()
        pol_grad = [p.grad.clone() for p in policy.parameters()]
        pol_optim.step()

        if linesearch:
            logger.info("Performing line search")
            kl_clip = pol_optim.state["kl_clip"]
            expected_improvement = sum(
                (g * p.grad.data).sum()
                for g, p in zip(pol_grad, policy.parameters())).item()

            def f_barrier(scale):
                for p in policy.parameters():
                    p.data.add_(scale, p.grad.data)
                new_dists = policy(all_obs)
                for p in policy.parameters():
                    p.data.sub_(scale, p.grad.data)
                new_logp = new_dists.log_prob(all_acts)
                surr_loss = -((new_logp - old_logp).exp() * all_advs).mean()
                avg_kl = kl(old_dists, new_dists).mean().item()
                return surr_loss.item() if avg_kl < kl_clip else float("inf")

            scale, expected_improvement, improvement = line_search(
                f_barrier,
                x0=1,
                dx=1,
                expected_improvement=expected_improvement,
                y0=surr_loss.item(),
            )
            logger.logkv("ExpectedImprovement", expected_improvement)
            logger.logkv("ActualImprovement", improvement)
            logger.logkv("ImprovementRatio",
                         improvement / expected_improvement)
            # apply the step accepted by the line search along p.grad
            for p in policy.parameters():
                p.data.add_(scale, p.grad.data)

        logger.info("Updating val_fn")
        for _ in range(val_iters):
            with val_optim.record_stats():
                val_fn.zero_grad()
                values = val_fn(all_obs)
                # regress on noisy targets sampled around the current value
                # predictions while recording curvature statistics
                noise = values.detach() + 0.5 * torch.randn_like(values)
                loss_fn(values, noise).backward(retain_graph=True)

            val_fn.zero_grad()
            val_loss = loss_fn(values, all_rets)
            val_loss.backward()
            val_optim.step()

        logger.info("Logging information")
        logger.logkv("TotalNSamples", samples)
        logu.log_reward_statistics(vec_env)
        logu.log_val_fn_statistics(all_vals, all_rets)
        logu.log_action_distribution_statistics(old_dists)
        logu.log_average_kl_divergence(old_dists, policy, all_obs)
        logger.dumpkvs()

        logger.info("Saving snapshot")
        saver.save_state(
            index=samples // stp,
            state=dict(
                alg=dict(last_iter=samples // stp),
                policy=policy.state_dict(),
                val_fn=val_fn.state_dict(),
                pol_optim=pol_optim.state_dict(),
                val_optim=val_optim.state_dict(),
            ),
        )

    vec_env.close()
def sac(env,
        policy,
        q_func=None,
        val_fn=None,
        total_steps=TOTAL_STEPS_DEFAULT,
        gamma=0.99,
        replay_size=REPLAY_SIZE_DEFAULT,
        polyak=0.995,
        start_steps=10000,
        epoch=5000,
        mb_size=100,
        lr=1e-3,
        alpha=0.2,
        target_entropy=None,
        reward_scale=1.0,
        updates_per_step=1.0,
        max_ep_length=1000,
        **saver_kwargs):
    # Set and save experiment hyperparameters
    q_func = q_func or ContinuousQFunction.from_policy(policy)
    val_fn = val_fn or ValueFunction.from_policy(policy)
    save_config(locals())
    saver = SnapshotSaver(logger.get_dir(), locals(), **saver_kwargs)

    # Initialize environments, models and replay buffer
    vec_env = VecEnvMaker(env)()
    test_env = VecEnvMaker(env)(train=False)
    ob_space, ac_space = vec_env.observation_space, vec_env.action_space
    pi_class, pi_args = policy.pop("class"), policy
    qf_class, qf_args = q_func.pop("class"), q_func
    vf_class, vf_args = val_fn.pop("class"), val_fn
    policy = pi_class(vec_env, **pi_args)
    q1func = qf_class(vec_env, **qf_args)
    q2func = qf_class(vec_env, **qf_args)
    val_fn = vf_class(vec_env, **vf_args)
    replay = ReplayBuffer(replay_size, ob_space, ac_space)
    if target_entropy is not None:
        log_alpha = torch.nn.Parameter(torch.zeros([]))
        if target_entropy == "auto":
            target_entropy = -np.prod(ac_space.shape)

    # Initialize optimizers and target networks
    loss_fn = torch.nn.MSELoss()
    pi_optim = torch.optim.Adam(policy.parameters(), lr=lr)
    qf_optim = torch.optim.Adam(
        list(q1func.parameters()) + list(q2func.parameters()), lr=lr)
    vf_optim = torch.optim.Adam(val_fn.parameters(), lr=lr)
    vf_targ = vf_class(vec_env, **vf_args)
    vf_targ.load_state_dict(val_fn.state_dict())
    if target_entropy is not None:
        al_optim = torch.optim.Adam([log_alpha], lr=lr)

    # Save initial state
    state = dict(
        alg=dict(samples=0),
        policy=policy.state_dict(),
        q1func=q1func.state_dict(),
        q2func=q2func.state_dict(),
        val_fn=val_fn.state_dict(),
        pi_optim=pi_optim.state_dict(),
        qf_optim=qf_optim.state_dict(),
        vf_optim=vf_optim.state_dict(),
        vf_targ=vf_targ.state_dict(),
    )
    if target_entropy is not None:
        state["log_alpha"] = log_alpha
        state["al_optim"] = al_optim.state_dict()
    saver.save_state(index=0, state=state)

    # Setup and run policy tests
    ob, don = test_env.reset(), False

    @torch.no_grad()
    def test_policy():
        nonlocal ob, don
        policy.eval()
        for _ in range(10):
            while not don:
                act = policy.actions(torch.from_numpy(ob))
                ob, _, don, _ = test_env.step(act.numpy())
            don = False
        policy.train()
        log_reward_statistics(test_env, num_last_eps=10, prefix="Test")

    test_policy()
    logger.logkv("Epoch", 0)
    logger.logkv("TotalNSamples", 0)
    logger.dumpkvs()

    # Set action sampling strategies
    def rand_uniform_actions(_):
        return np.stack([ac_space.sample() for _ in range(vec_env.num_envs)])

    @torch.no_grad()
    def stoch_policy_actions(obs):
        return policy.actions(torch.from_numpy(obs)).numpy()

    # Algorithm main loop
    obs1, ep_length = vec_env.reset(), 0
    for samples in trange(1, total_steps + 1, desc="Training", unit="step"):
        if samples <= start_steps:
            actions = rand_uniform_actions
        else:
            actions = stoch_policy_actions

        acts = actions(obs1)
        obs2, rews, dones, _ = vec_env.step(acts)
        ep_length += 1
        dones[0] = False if ep_length == max_ep_length else dones[0]

        as_tensors = map(torch.from_numpy,
                         (obs1, acts, rews, obs2, dones.astype("f")))
        for ob1, act, rew, ob2, done in zip(*as_tensors):
            replay.store(ob1, act, rew, ob2, done)
        obs1 = obs2

        if (dones[0] or ep_length == max_ep_length) and replay.size >= mb_size:
            for _ in range(int(ep_length * updates_per_step)):
                ob_1, act_, rew_, ob_2, done_ = replay.sample(mb_size)
                dist = policy(ob_1)
                pi_a = dist.rsample()
                logp = dist.log_prob(pi_a)

                if target_entropy is not None:
                    al_optim.zero_grad()
                    alpha_loss = torch.mean(
                        log_alpha * (logp.detach() + target_entropy)).neg()
                    alpha_loss.backward()
                    al_optim.step()
                    logger.logkv_mean("AlphaLoss", alpha_loss.item())
                    alpha = log_alpha.exp().item()

                with torch.no_grad():
                    y_qf = reward_scale * rew_ + gamma * (
                        1 - done_) * vf_targ(ob_2)
                    y_vf = (torch.min(q1func(ob_1, pi_a), q2func(ob_1, pi_a))
                            - alpha * logp)

                qf_optim.zero_grad()
                q1_val = q1func(ob_1, act_)
                q2_val = q2func(ob_1, act_)
                q1_loss = loss_fn(q1_val, y_qf).div(2)
                q2_loss = loss_fn(q2_val, y_qf).div(2)
                q1_loss.add(q2_loss).backward()
                qf_optim.step()

                vf_optim.zero_grad()
                vf_val = val_fn(ob_1)
                vf_loss = loss_fn(vf_val, y_vf).div(2)
                vf_loss.backward()
                vf_optim.step()

                pi_optim.zero_grad()
                qpi_val = q1func(ob_1, pi_a)
                # qpi_val = torch.min(q1func(ob_1, pi_a), q2func(ob_1, pi_a))
                pi_loss = qpi_val.sub(logp, alpha=alpha).mean().neg()
                pi_loss.backward()
                pi_optim.step()

                update_polyak(val_fn, vf_targ, polyak)

                logger.logkv_mean("Entropy", logp.mean().neg().item())
                logger.logkv_mean("Q1Val", q1_val.mean().item())
                logger.logkv_mean("Q2Val", q2_val.mean().item())
                logger.logkv_mean("VFVal", vf_val.mean().item())
                logger.logkv_mean("QPiVal", qpi_val.mean().item())
                logger.logkv_mean("Q1Loss", q1_loss.item())
                logger.logkv_mean("Q2Loss", q2_loss.item())
                logger.logkv_mean("VFLoss", vf_loss.item())
                logger.logkv_mean("PiLoss", pi_loss.item())
                logger.logkv_mean("Alpha", alpha)
            ep_length = 0

        if samples % epoch == 0:
            test_policy()
            logger.logkv("Epoch", samples // epoch)
            logger.logkv("TotalNSamples", samples)
            log_reward_statistics(vec_env)
            logger.dumpkvs()

            state = dict(
                alg=dict(samples=samples),
                policy=policy.state_dict(),
                q1func=q1func.state_dict(),
                q2func=q2func.state_dict(),
                val_fn=val_fn.state_dict(),
                pi_optim=pi_optim.state_dict(),
                qf_optim=qf_optim.state_dict(),
                vf_optim=vf_optim.state_dict(),
                vf_targ=vf_targ.state_dict(),
            )
            if target_entropy is not None:
                state["log_alpha"] = log_alpha
                state["al_optim"] = al_optim.state_dict()
            saver.save_state(index=samples // epoch, state=state)
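# `update_polyak` is defined elsewhere in the project; only its call pattern
# (online net, target net, polyak coefficient ~0.995) is visible here. A
# minimal sketch of a soft target-network update with that interface follows;
# the name, exact signature and averaging convention are assumptions, not the
# project's implementation.
def _update_polyak_sketch(module, target, polyak):
    """Soft update: targ <- polyak * targ + (1 - polyak) * online."""
    for param, targ_param in zip(module.parameters(), target.parameters()):
        targ_param.data.mul_(polyak)
        targ_param.data.add_((1 - polyak) * param.data)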
def a2c(env,
        policy,
        val_fn=None,
        total_steps=TOTAL_STEPS_DEFAULT,
        steps=20,
        n_envs=16,
        gamma=0.99,
        optimizer=None,
        max_grad_norm=0.5,
        ent_coeff=0.01,
        vf_loss_coeff=0.5,
        log_interval=100,
        **saver_kwargs):
    assert val_fn is None or not issubclass(
        policy["class"], WeightSharingAC
    ), "Choose between a weight sharing model or separate policy and val_fn"

    optimizer = optimizer or {}
    optimizer = {
        "class": torch.optim.RMSprop,
        "lr": 1e-3,
        "eps": 1e-5,
        "alpha": 0.99,
        **optimizer,
    }
    if val_fn is None and not issubclass(policy["class"], WeightSharingAC):
        val_fn = ValueFunction.from_policy(policy)

    logu.save_config(locals())
    saver = SnapshotSaver(logger.get_dir(), locals(), **saver_kwargs)

    vec_env = VecEnvMaker(env)(n_envs)
    policy = policy.pop("class")(vec_env, **policy)
    param_list = torch.nn.ParameterList(policy.parameters())
    if val_fn is not None:
        val_fn = val_fn.pop("class")(vec_env, **val_fn)
        param_list.extend(val_fn.parameters())
    optimizer = optimizer.pop("class")(param_list.parameters(), **optimizer)
    loss_fn = torch.nn.MSELoss()

    # Algorithm main loop
    if val_fn is None:
        compute_dists_vals = policy
    else:

        def compute_dists_vals(obs):
            return policy(obs), val_fn(obs)

    generator = samples_generator(vec_env, policy, steps, compute_dists_vals)
    logger.info("Starting epoch {}".format(1))
    beg, end, stp = steps * n_envs, total_steps + steps * n_envs, steps * n_envs
    for samples in trange(beg, end, stp, desc="Training", unit="step"):
        all_acts, all_rews, all_dones, all_dists, all_vals, next_vals = next(
            generator)

        # Compute returns and advantages
        all_rets = all_rews.clone()
        all_rets[-1] += gamma * (1 - all_dones[-1]) * next_vals
        for i in reversed(range(steps - 1)):
            all_rets[i] += gamma * (1 - all_dones[i]) * all_rets[i + 1]
        all_advs = all_rets - all_vals.detach()

        # Compute loss
        log_li = all_dists.log_prob(all_acts.reshape(stp, -1).squeeze())
        pi_loss = -torch.mean(log_li * all_advs.flatten())
        vf_loss = loss_fn(all_vals.flatten(), all_rets.flatten())
        entropy = all_dists.entropy().mean()
        total_loss = pi_loss - ent_coeff * entropy + vf_loss_coeff * vf_loss

        optimizer.zero_grad()
        total_loss.backward()
        torch.nn.utils.clip_grad_norm_(param_list.parameters(), max_grad_norm)
        optimizer.step()

        updates = samples // stp
        if updates == 1 or updates % log_interval == 0:
            logger.logkv("Epoch", updates // log_interval + 1)
            logger.logkv("TotalNSamples", samples)
            logu.log_reward_statistics(vec_env)
            logu.log_val_fn_statistics(all_vals.flatten(), all_rets.flatten())
            logu.log_action_distribution_statistics(all_dists)
            logger.dumpkvs()
            logger.info(
                "Starting epoch {}".format(updates // log_interval + 2))
            saver.save_state(
                index=updates,
                state=dict(
                    alg=dict(last_updt=updates),
                    policy=policy.state_dict(),
                    val_fn=None if val_fn is None else val_fn.state_dict(),
                    optimizer=optimizer.state_dict(),
                ),
            )

    vec_env.close()
def td3(env,
        policy,
        q_func=None,
        total_steps=TOTAL_STEPS_DEFAULT,
        gamma=0.99,
        replay_size=REPLAY_SIZE_DEFAULT,
        polyak=0.995,
        start_steps=10000,
        epoch=5000,
        pi_lr=1e-3,
        qf_lr=1e-3,
        mb_size=100,
        act_noise=0.1,
        max_ep_length=1000,
        target_noise=0.2,
        noise_clip=0.5,
        policy_delay=2,
        updates_per_step=1.0,
        **saver_kwargs):
    # Set and save experiment hyperparameters
    q_func = q_func or ContinuousQFunction.from_policy(policy)
    save_config(locals())
    saver = SnapshotSaver(logger.get_dir(), locals(), **saver_kwargs)

    # Initialize environments, models and replay buffer
    vec_env = VecEnvMaker(env)()
    test_env = VecEnvMaker(env)(train=False)
    ob_space, ac_space = vec_env.observation_space, vec_env.action_space
    pi_class, pi_args = policy.pop("class"), policy
    qf_class, qf_args = q_func.pop("class"), q_func
    policy = pi_class(vec_env, **pi_args)
    q1func = qf_class(vec_env, **qf_args)
    q2func = qf_class(vec_env, **qf_args)
    replay = ReplayBuffer(replay_size, ob_space, ac_space)

    # Initialize optimizers and target networks
    loss_fn = torch.nn.MSELoss()
    pi_optim = torch.optim.Adam(policy.parameters(), lr=pi_lr)
    qf_optim = torch.optim.Adam(
        chain(q1func.parameters(), q2func.parameters()), lr=qf_lr)
    pi_targ = pi_class(vec_env, **pi_args)
    q1_targ = qf_class(vec_env, **qf_args)
    q2_targ = qf_class(vec_env, **qf_args)
    pi_targ.load_state_dict(policy.state_dict())
    q1_targ.load_state_dict(q1func.state_dict())
    q2_targ.load_state_dict(q2func.state_dict())

    # Save initial state
    saver.save_state(
        index=0,
        state=dict(
            alg=dict(samples=0),
            policy=policy.state_dict(),
            q1func=q1func.state_dict(),
            q2func=q2func.state_dict(),
            pi_optim=pi_optim.state_dict(),
            qf_optim=qf_optim.state_dict(),
            pi_targ=pi_targ.state_dict(),
            q1_targ=q1_targ.state_dict(),
            q2_targ=q2_targ.state_dict(),
        ),
    )

    # Setup and run policy tests
    ob, don = test_env.reset(), False

    @torch.no_grad()
    def test_policy():
        nonlocal ob, don
        for _ in range(10):
            while not don:
                act = policy.actions(torch.from_numpy(ob))
                ob, _, don, _ = test_env.step(act.numpy())
            don = False
        log_reward_statistics(test_env, num_last_eps=10, prefix="Test")

    test_policy()
    logger.logkv("Epoch", 0)
    logger.logkv("TotalNSamples", 0)
    logger.dumpkvs()

    # Set action sampling strategies
    def rand_uniform_actions(_):
        return np.stack([ac_space.sample() for _ in range(vec_env.num_envs)])

    act_low, act_high = map(torch.Tensor, (ac_space.low, ac_space.high))

    @torch.no_grad()
    def noisy_policy_actions(obs):
        acts = policy.actions(torch.from_numpy(obs))
        acts += act_noise * torch.randn_like(acts)
        return np.clip(acts.numpy(), ac_space.low, ac_space.high)

    # Algorithm main loop
    obs1, ep_length, critic_updates = vec_env.reset(), 0, 0
    for samples in trange(1, total_steps + 1, desc="Training", unit="step"):
        if samples <= start_steps:
            actions = rand_uniform_actions
        else:
            actions = noisy_policy_actions

        acts = actions(obs1)
        obs2, rews, dones, _ = vec_env.step(acts)
        ep_length += 1
        dones[0] = False if ep_length == max_ep_length else dones[0]

        as_tensors = map(torch.from_numpy,
                         (obs1, acts, rews, obs2, dones.astype("f")))
        for ob1, act, rew, ob2, done in zip(*as_tensors):
            replay.store(ob1, act, rew, ob2, done)
        obs1 = obs2

        if (dones[0] or ep_length == max_ep_length) and replay.size >= mb_size:
            for _ in range(int(ep_length * updates_per_step)):
                ob_1, act_, rew_, ob_2, done_ = replay.sample(mb_size)

                # Clipped double-Q targets with target policy smoothing
                with torch.no_grad():
                    atarg = pi_targ(ob_2)
                    atarg += torch.clamp(
                        target_noise * torch.randn_like(atarg),
                        -noise_clip, noise_clip)
                    atarg = torch.max(torch.min(atarg, act_high), act_low)
                    targs = rew_ + gamma * (1 - done_) * torch.min(
                        q1_targ(ob_2, atarg), q2_targ(ob_2, atarg))

                qf_optim.zero_grad()
                q1_val = q1func(ob_1, act_)
                q2_val = q2func(ob_1, act_)
                q1_loss = loss_fn(q1_val, targs).div(2)
                q2_loss = loss_fn(q2_val, targs).div(2)
                q1_loss.add(q2_loss).backward()
                qf_optim.step()
                critic_updates += 1

                # Delayed policy and target network updates
                if critic_updates % policy_delay == 0:
                    pi_optim.zero_grad()
                    qpi_val = q1func(ob_1, policy(ob_1))
                    pi_loss = qpi_val.mean().neg()
                    pi_loss.backward()
                    pi_optim.step()

                    update_polyak(policy, pi_targ, polyak)
                    update_polyak(q1func, q1_targ, polyak)
                    update_polyak(q2func, q2_targ, polyak)

                    logger.logkv_mean("QPiVal", qpi_val.mean().item())
                    logger.logkv_mean("PiLoss", pi_loss.item())

                logger.logkv_mean("Q1Val", q1_val.mean().item())
                logger.logkv_mean("Q2Val", q2_val.mean().item())
                logger.logkv_mean("Q1Loss", q1_loss.item())
                logger.logkv_mean("Q2Loss", q2_loss.item())
            ep_length = 0

        if samples % epoch == 0:
            test_policy()
            logger.logkv("Epoch", samples // epoch)
            logger.logkv("TotalNSamples", samples)
            log_reward_statistics(vec_env)
            logger.dumpkvs()

            saver.save_state(
                index=samples // epoch,
                state=dict(
                    alg=dict(samples=samples),
                    policy=policy.state_dict(),
                    q1func=q1func.state_dict(),
                    q2func=q2func.state_dict(),
                    pi_optim=pi_optim.state_dict(),
                    qf_optim=qf_optim.state_dict(),
                    pi_targ=pi_targ.state_dict(),
                    q1_targ=q1_targ.state_dict(),
                    q2_targ=q2_targ.state_dict(),
                ),
            )
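# `ReplayBuffer` is provided by the project; only its interface is visible
# here (constructor `(size, ob_space, ac_space)`, `.store(ob1, act, rew, ob2,
# done)`, `.sample(batch_size)` and a `.size` attribute). The sketch below is
# purely illustrative of a ring buffer with that interface and is not the
# project's implementation.
class _ReplayBufferSketch:
    def __init__(self, size, ob_space, ac_space):
        self.obs1 = torch.zeros((size,) + ob_space.shape)
        self.acts = torch.zeros((size,) + ac_space.shape)
        self.rews = torch.zeros(size)
        self.obs2 = torch.zeros((size,) + ob_space.shape)
        self.dones = torch.zeros(size)
        self.max_size, self.ptr, self.size = size, 0, 0

    def store(self, ob1, act, rew, ob2, done):
        i = self.ptr
        self.obs1[i], self.acts[i], self.rews[i] = ob1, act, rew
        self.obs2[i], self.dones[i] = ob2, done
        self.ptr = (self.ptr + 1) % self.max_size
        self.size = min(self.size + 1, self.max_size)

    def sample(self, batch_size):
        idxs = torch.randint(self.size, (batch_size,))
        return (self.obs1[idxs], self.acts[idxs], self.rews[idxs],
                self.obs2[idxs], self.dones[idxs])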
def trpo(env,
         policy,
         val_fn=None,
         total_steps=TOTAL_STEPS_DEFAULT,
         steps=125,
         n_envs=16,
         gamma=0.99,
         gaelam=0.97,
         kl_frac=1.0,
         delta=0.01,
         val_iters=80,
         val_lr=1e-3,
         linesearch=True,
         **saver_kwargs):
    val_fn = val_fn or ValueFunction.from_policy(policy)
    logu.save_config(locals())
    saver = SnapshotSaver(logger.get_dir(), locals(), **saver_kwargs)

    vec_env = VecEnvMaker(env)(n_envs)
    policy = policy.pop("class")(vec_env, **policy)
    val_fn = val_fn.pop("class")(vec_env, **val_fn)
    val_optim = torch.optim.Adam(val_fn.parameters(), lr=val_lr)
    loss_fn = torch.nn.MSELoss()

    # Algorithm main loop
    collector = parallel_samples_collector(vec_env, policy, steps)
    beg, end, stp = steps * n_envs, total_steps + steps * n_envs, steps * n_envs
    for samples in trange(beg, end, stp, desc="Training", unit="step"):
        logger.info("Starting iteration {}".format(samples // stp))
        logger.logkv("Iteration", samples // stp)

        logger.info("Start collecting samples")
        trajs = next(collector)

        logger.info("Computing policy gradient variables")
        compute_pg_vars(trajs, val_fn, gamma, gaelam)
        flatten_trajs(trajs)
        all_obs, all_acts, _, _, all_advs, all_vals, all_rets = trajs.values()
        all_obs, all_vals = all_obs[:-n_envs], all_vals[:-n_envs]

        # subsample for fisher vector product computation
        if kl_frac < 1.0:
            n_samples = int(kl_frac * len(all_obs))
            indexes = torch.randperm(len(all_obs))[:n_samples]
            subsamp_obs = all_obs.index_select(0, indexes)
        else:
            subsamp_obs = all_obs

        logger.info("Computing policy gradient")
        all_dists = policy(all_obs)
        all_logp = all_dists.log_prob(all_acts)
        old_dists = all_dists.detach()
        old_logp = old_dists.log_prob(all_acts)
        surr_loss = -((all_logp - old_logp).exp() * all_advs).mean()
        pol_grad = flat_grad(surr_loss, policy.parameters())

        logger.info("Computing truncated natural gradient")

        def fvp(v):
            return fisher_vec_prod(v, subsamp_obs, policy)

        descent_direction = conjugate_gradient(fvp, pol_grad)
        scale = torch.sqrt(
            2 * delta / (pol_grad.dot(descent_direction) + 1e-8))
        descent_step = descent_direction * scale

        if linesearch:
            logger.info("Performing line search")
            expected_improvement = pol_grad.dot(descent_step).item()

            def f_barrier(params,
                          all_obs=all_obs,
                          all_acts=all_acts,
                          all_advs=all_advs):
                vector_to_parameters(params, policy.parameters())
                new_dists = policy(all_obs)
                new_logp = new_dists.log_prob(all_acts)
                surr_loss = -((new_logp - old_logp).exp() * all_advs).mean()
                avg_kl = kl(old_dists, new_dists).mean().item()
                return surr_loss.item() if avg_kl < delta else float("inf")

            new_params, expected_improvement, improvement = line_search(
                f_barrier,
                parameters_to_vector(policy.parameters()),
                descent_step,
                expected_improvement,
                y0=surr_loss.item(),
            )
            logger.logkv("ExpectedImprovement", expected_improvement)
            logger.logkv("ActualImprovement", improvement)
            logger.logkv("ImprovementRatio",
                         improvement / expected_improvement)
        else:
            new_params = parameters_to_vector(
                policy.parameters()) - descent_step

        vector_to_parameters(new_params, policy.parameters())

        logger.info("Updating val_fn")
        for _ in range(val_iters):
            val_optim.zero_grad()
            loss_fn(val_fn(all_obs), all_rets).backward()
            val_optim.step()

        logger.info("Logging information")
        logger.logkv("TotalNSamples", samples)
        logu.log_reward_statistics(vec_env)
        logu.log_val_fn_statistics(all_vals, all_rets)
        logu.log_action_distribution_statistics(old_dists)
        logu.log_average_kl_divergence(old_dists, policy, all_obs)
        logger.dumpkvs()

        logger.info("Saving snapshot")
        saver.save_state(
            index=samples // stp,
            state=dict(
                alg=dict(last_iter=samples // stp),
                policy=policy.state_dict(),
                val_fn=val_fn.state_dict(),
                val_optim=val_optim.state_dict(),
            ),
        )
        del all_obs, all_acts, all_advs, all_vals, all_rets, trajs

    vec_env.close()
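# `line_search` (used by trpo above and by acktr) is defined elsewhere. Its
# call sites suggest a backtracking search: it receives a barrier function
# that returns the surrogate loss (or +inf when the KL constraint is
# violated), a starting point `x0`, a step `dx`, the linearized
# `expected_improvement`, and the initial loss `y0`, and returns the accepted
# point together with the scaled expected and actual improvements. The sketch
# below is a generic backtracking line search under those assumptions; the
# accept ratio, backtrack factor and sign convention are guesses, not the
# project's actual values.
def _line_search_sketch(f, x0, dx, expected_improvement, y0=None,
                        backtrack_ratio=0.8, max_backtracks=15,
                        accept_ratio=0.1):
    if y0 is None:
        y0 = f(x0)
    for k in range(max_backtracks):
        scale = backtrack_ratio ** k
        x_new = x0 - scale * dx          # try progressively smaller steps
        improvement = y0 - f(x_new)      # decrease in the surrogate loss
        if improvement >= accept_ratio * scale * expected_improvement:
            return x_new, scale * expected_improvement, improvement
    return x0, expected_improvement, 0.0  # no acceptable step found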
def natural(env,
            policy,
            val_fn=None,
            total_steps=TOTAL_STEPS_DEFAULT,
            steps=125,
            n_envs=16,
            gamma=0.99,
            gaelam=0.97,
            kl_frac=1.0,
            delta=0.01,
            val_iters=80,
            val_lr=1e-3,
            **saver_kwargs):
    """
    Natural Policy Gradient

    env: instance of proj.common.env_makers.VecEnvMaker
    policy: instance of proj.common.models.Policy
    val_fn (optional): instance of proj.common.models.ValueFunction
    total_steps: total number of environment steps to take
    steps: number of steps to take in each environment per iteration
    n_envs: number of environment copies to run in parallel
    gamma: GAE discount parameter
    gaelam: GAE lambda exponential average parameter
    kl_frac: fraction of the collected samples used to estimate
        Fisher-vector products
    delta: bound on the average KL divergence used to scale the natural
        gradient step
    val_iters: number of optimization steps to update the critic per iteration
    val_lr: learning rate for critic optimizer
    saver_kwargs: keyword arguments for proj.utils.saver.SnapshotSaver
    """
    val_fn = val_fn or ValueFunction.from_policy(policy)
    logu.save_config(locals())
    saver = SnapshotSaver(logger.get_dir(), locals(), **saver_kwargs)

    vec_env = VecEnvMaker(env)(n_envs)
    policy = policy.pop("class")(vec_env, **policy)
    val_fn = val_fn.pop("class")(vec_env, **val_fn)
    val_optim = torch.optim.Adam(val_fn.parameters(), lr=val_lr)
    loss_fn = torch.nn.MSELoss()

    # Algorithm main loop
    collector = parallel_samples_collector(vec_env, policy, steps)
    beg, end, stp = steps * n_envs, total_steps + steps * n_envs, steps * n_envs
    for samples in trange(beg, end, stp, desc="Training", unit="step"):
        logger.info("Starting iteration {}".format(samples // stp))
        logger.logkv("Iteration", samples // stp)

        logger.info("Start collecting samples")
        trajs = next(collector)

        logger.info("Computing policy gradient variables")
        compute_pg_vars(trajs, val_fn, gamma, gaelam)
        flatten_trajs(trajs)
        all_obs, all_acts, _, _, all_advs, all_vals, all_rets = trajs.values()
        all_obs, all_vals = all_obs[:-n_envs], all_vals[:-n_envs]

        # subsample for fisher vector product computation
        if kl_frac < 1.0:
            n_samples = int(kl_frac * len(all_obs))
            indexes = torch.randperm(len(all_obs))[:n_samples]
            subsamp_obs = all_obs.index_select(0, indexes)
        else:
            subsamp_obs = all_obs

        logger.info("Computing policy gradient")
        all_dists = policy(all_obs)
        old_dists = all_dists.detach()
        pol_loss = torch.mean(all_dists.log_prob(all_acts) * all_advs).neg()
        pol_grad = flat_grad(pol_loss, policy.parameters())

        logger.info("Computing truncated natural gradient")

        def fvp(v):
            return fisher_vec_prod(v, subsamp_obs, policy)

        descent_direction = conjugate_gradient(fvp, pol_grad)
        scale = torch.sqrt(
            2 * delta / (pol_grad.dot(descent_direction) + 1e-8))
        descent_step = descent_direction * scale
        new_params = parameters_to_vector(policy.parameters()) - descent_step
        vector_to_parameters(new_params, policy.parameters())

        logger.info("Updating val_fn")
        for _ in range(val_iters):
            val_optim.zero_grad()
            loss_fn(val_fn(all_obs), all_rets).backward()
            val_optim.step()

        logger.info("Logging information")
        logger.logkv("TotalNSamples", samples)
        logu.log_reward_statistics(vec_env)
        logu.log_val_fn_statistics(all_vals, all_rets)
        logu.log_action_distribution_statistics(old_dists)
        logu.log_average_kl_divergence(old_dists, policy, all_obs)
        logger.dumpkvs()

        logger.info("Saving snapshot")
        saver.save_state(
            samples // stp,
            dict(
                alg=dict(last_iter=samples // stp),
                policy=policy.state_dict(),
                val_fn=val_fn.state_dict(),
                val_optim=val_optim.state_dict(),
            ),
        )
        del all_obs, all_acts, all_advs, all_vals, all_rets, trajs

    vec_env.close()
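# `fisher_vec_prod` and `conjugate_gradient` are project helpers used by
# natural and trpo above. The standard constructions they presumably follow
# are (i) Fisher-vector products via the Hessian of the average KL between
# the current policy and a detached copy of itself, and (ii) plain conjugate
# gradient for solving F x = g. The sketches below follow that textbook
# recipe and are not the project's code; the damping constant and iteration
# count are assumptions.
def _fisher_vec_prod_sketch(v, obs, policy, damping=1e-3):
    dists = policy(obs)
    avg_kl = kl(dists.detach(), dists).mean()
    grads = torch.autograd.grad(avg_kl, list(policy.parameters()),
                                create_graph=True)
    flat = torch.cat([g.reshape(-1) for g in grads])
    gvp = torch.autograd.grad(flat.dot(v), list(policy.parameters()))
    return torch.cat([g.reshape(-1) for g in gvp]) + damping * v


def _conjugate_gradient_sketch(fvp, b, n_iters=10, tol=1e-10):
    x = torch.zeros_like(b)
    r, p = b.clone(), b.clone()
    rdotr = r.dot(r)
    for _ in range(n_iters):
        Ap = fvp(p)
        alpha = rdotr / p.dot(Ap)
        x += alpha * p
        r -= alpha * Ap
        new_rdotr = r.dot(r)
        if new_rdotr < tol:
            break
        p = r + (new_rdotr / rdotr) * p
        rdotr = new_rdotr
    return x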
def a2c_kfac(env,
             policy,
             val_fn=None,
             total_steps=TOTAL_STEPS_DEFAULT,
             steps=20,
             n_envs=16,
             kfac=None,
             ent_coeff=0.01,
             vf_loss_coeff=0.5,
             gamma=0.99,
             log_interval=100,
             warm_start=None,
             **saver_kwargs):
    assert val_fn is None or not issubclass(
        policy["class"], WeightSharingAC
    ), "Choose between a weight sharing model or separate policy and val_fn"

    # handle default values
    kfac = kfac or {}
    kfac = {
        "eps": 1e-3,
        "pi": True,
        "alpha": 0.95,
        "kl_clip": 1e-3,
        "eta": 1.0,
        **kfac
    }
    if val_fn is None and not issubclass(policy["class"], WeightSharingAC):
        val_fn = ValueFunction.from_policy(policy)

    # save config and setup state saving
    logu.save_config(locals())
    saver = SnapshotSaver(logger.get_dir(), locals(), **saver_kwargs)

    # initialize models and optimizer
    vec_env = VecEnvMaker(env)(n_envs)
    policy = policy.pop("class")(vec_env, **policy)
    module_list = torch.nn.ModuleList(policy.modules())
    if val_fn is not None:
        val_fn = val_fn.pop("class")(vec_env, **val_fn)
        module_list.extend(val_fn.modules())
    optimizer = KFACOptimizer(module_list, **kfac)
    # scheduler = LinearLR(optimizer, total_steps // (steps*n_envs))
    loss_fn = torch.nn.MSELoss()

    # load state if provided
    updates = 0
    if warm_start is not None:
        if ":" in warm_start:
            warm_start, index = warm_start.split(":")
        else:
            index = None
        config, state = SnapshotSaver(warm_start,
                                      latest_only=False).get_state(int(index))
        policy.load_state_dict(state["policy"])
        if "optimizer" in state:
            optimizer.load_state_dict(state["optimizer"])
        updates = state["alg"]["last_updt"]

    # Algorithm main loop
    if val_fn is None:
        compute_dists_vals = policy
    else:

        def compute_dists_vals(obs):
            return policy(obs), val_fn(obs)

    ob_space, ac_space = vec_env.observation_space, vec_env.action_space
    obs = torch.from_numpy(vec_env.reset())
    with torch.no_grad():
        acts = policy.actions(obs)
    logger.info("Starting epoch {}".format(1))
    beg, end, stp = steps * n_envs, total_steps + steps * n_envs, steps * n_envs
    total_updates = total_steps // stp
    for samples in trange(beg, end, stp, desc="Training", unit="step"):
        all_obs = torch.empty((steps, n_envs) + ob_space.shape,
                              dtype=_NP_TO_PT[ob_space.dtype.type])
        all_acts = torch.empty((steps, n_envs) + ac_space.shape,
                               dtype=_NP_TO_PT[ac_space.dtype.type])
        all_rews = torch.empty((steps, n_envs))
        all_dones = torch.empty((steps, n_envs))
        with torch.no_grad():
            for i in range(steps):
                next_obs, rews, dones, _ = vec_env.step(acts.numpy())
                all_obs[i] = obs
                all_acts[i] = acts
                all_rews[i] = torch.from_numpy(rews)
                all_dones[i] = torch.from_numpy(dones.astype("f"))
                obs = torch.from_numpy(next_obs)
                acts = policy.actions(obs)
        all_obs = all_obs.reshape(stp, -1).squeeze()
        all_acts = all_acts.reshape(stp, -1).squeeze()

        # Sample Fisher curvature matrix
        with optimizer.record_stats():
            optimizer.zero_grad()
            all_dists, all_vals = compute_dists_vals(all_obs)
            logp = all_dists.log_prob(all_acts)
            noise = all_vals.detach() + 0.5 * torch.randn_like(all_vals)
            (logp.mean() +
             loss_fn(all_vals, noise)).backward(retain_graph=True)

        # Compute returns and advantages
        with torch.no_grad():
            _, next_vals = compute_dists_vals(obs)
        all_rets = all_rews.clone()
        all_rets[-1] += gamma * (1 - all_dones[-1]) * next_vals
        for i in reversed(range(steps - 1)):
            all_rets[i] += gamma * (1 - all_dones[i]) * all_rets[i + 1]
        all_rets = all_rets.flatten()
        all_advs = all_rets - all_vals.detach()

        # Compute loss
        updates += 1
        # ent_coeff = ent_coeff*0.99 if updates % 10 == 0 else ent_coeff
        pi_loss = -torch.mean(logp * all_advs)
        vf_loss = loss_fn(all_vals, all_rets)
        entropy = all_dists.entropy().mean()
        total_loss = pi_loss - ent_coeff * entropy + vf_loss_coeff * vf_loss

        # scheduler.step()
        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()

        if updates == 1 or updates % log_interval == 0:
            logger.logkv("Epoch", updates // log_interval + 1)
            logger.logkv("TotalNSamples", samples)
            logu.log_reward_statistics(vec_env)
            logu.log_val_fn_statistics(all_vals, all_rets)
            logu.log_action_distribution_statistics(all_dists)
            logger.dumpkvs()
            logger.info(
                "Starting epoch {}".format(updates // log_interval + 2))
            saver.save_state(
                index=updates,
                state=dict(
                    alg=dict(last_updt=updates),
                    policy=policy.state_dict(),
                    val_fn=None if val_fn is None else val_fn.state_dict(),
                    optimizer=optimizer.state_dict(),
                ),
            )

    vec_env.close()
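# `_NP_TO_PT` maps NumPy scalar types to torch dtypes so the rollout buffers
# in a2c_kfac can be allocated with the right dtype. The project's actual
# mapping is not shown in this file; a plausible minimal version would be:
_NP_TO_PT_SKETCH = {
    np.float64: torch.float64,
    np.float32: torch.float32,
    np.int64: torch.int64,
    np.int32: torch.int32,
    np.uint8: torch.uint8,
}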
def ppo(env,
        policy,
        val_fn=None,
        total_steps=TOTAL_STEPS_DEFAULT,
        steps=125,
        n_envs=16,
        gamma=0.99,
        gaelam=0.96,
        clip_ratio=0.2,
        pol_iters=80,
        val_iters=80,
        pol_lr=3e-4,
        val_lr=1e-3,
        target_kl=0.01,
        mb_size=100,
        **saver_kwargs):
    val_fn = val_fn or ValueFunction.from_policy(policy)
    logu.save_config(locals())
    saver = SnapshotSaver(logger.get_dir(), locals(), **saver_kwargs)

    vec_env = VecEnvMaker(env)(n_envs)
    policy = policy.pop("class")(vec_env, **policy)
    val_fn = val_fn.pop("class")(vec_env, **val_fn)
    pol_optim = torch.optim.Adam(policy.parameters(), lr=pol_lr)
    val_optim = torch.optim.Adam(val_fn.parameters(), lr=val_lr)
    loss_fn = torch.nn.MSELoss()

    # Algorithm main loop
    collector = parallel_samples_collector(vec_env, policy, steps)
    beg, end, stp = steps * n_envs, total_steps + steps * n_envs, steps * n_envs
    for samples in trange(beg, end, stp, desc="Training", unit="step"):
        logger.info("Starting iteration {}".format(samples // stp))
        logger.logkv("Iteration", samples // stp)

        logger.info("Start collecting samples")
        trajs = next(collector)

        logger.info("Computing policy gradient variables")
        compute_pg_vars(trajs, val_fn, gamma, gaelam)
        flatten_trajs(trajs)
        all_obs, all_acts, _, _, all_advs, all_vals, all_rets = trajs.values()
        all_obs, all_vals = all_obs[:-n_envs], all_vals[:-n_envs]

        logger.info("Minimizing surrogate loss")
        with torch.no_grad():
            old_dists = policy(all_obs)
            old_logp = old_dists.log_prob(all_acts)
        # min_advs precomputes (1 ± clip_ratio) * advs so that
        # min(ratios * advs, min_advs) equals the clipped PPO surrogate
        min_advs = torch.where(all_advs > 0, (1 + clip_ratio) * all_advs,
                               (1 - clip_ratio) * all_advs)
        dataset = TensorDataset(all_obs, all_acts, all_advs, min_advs,
                                old_logp)
        dataloader = DataLoader(dataset, batch_size=mb_size, shuffle=True)
        for itr in range(pol_iters):
            for obs, acts, advs, min_adv, logp in dataloader:
                ratios = (policy(obs).log_prob(acts) - logp).exp()
                pol_optim.zero_grad()
                (-torch.min(ratios * advs, min_adv)).mean().backward()
                pol_optim.step()

            with torch.no_grad():
                mean_kl = kl(old_dists, policy(all_obs)).mean().item()
            if mean_kl > 1.5 * target_kl:
                logger.info("Stopped at step {} due to reaching max kl".format(
                    itr + 1))
                break
        logger.logkv("StopIter", itr + 1)

        logger.info("Updating val_fn")
        for _ in range(val_iters):
            val_optim.zero_grad()
            loss_fn(val_fn(all_obs), all_rets).backward()
            val_optim.step()

        logger.info("Logging information")
        logger.logkv("TotalNSamples", samples)
        logu.log_reward_statistics(vec_env)
        logu.log_val_fn_statistics(all_vals, all_rets)
        logu.log_action_distribution_statistics(old_dists)
        logger.logkv("MeanKL", mean_kl)
        logger.dumpkvs()

        logger.info("Saving snapshot")
        saver.save_state(
            index=samples // stp,
            state=dict(
                alg=dict(last_iter=samples // stp),
                policy=policy.state_dict(),
                val_fn=val_fn.state_dict(),
                pol_optim=pol_optim.state_dict(),
                val_optim=val_optim.state_dict(),
            ),
        )

    vec_env.close()
def vanilla(
    env,
    policy,
    val_fn=None,
    total_steps=TOTAL_STEPS_DEFAULT,
    steps=125,
    n_envs=16,
    gamma=0.99,
    gaelam=0.97,
    optimizer=None,
    val_iters=80,
    val_lr=1e-3,
    **saver_kwargs
):
    """
    Vanilla Policy Gradient

    env: instance of proj.common.env_makers.VecEnvMaker
    policy: instance of proj.common.models.Policy
    val_fn (optional): instance of proj.common.models.ValueFunction
    total_steps: total number of environment steps to take
    steps: number of steps to take in each environment per iteration
    n_envs: number of environment copies to run in parallel
    gamma: GAE discount parameter
    gaelam: GAE lambda exponential average parameter
    optimizer (optional): dictionary containing optimizer kwargs and/or class
    val_iters: number of optimization steps to update the critic per iteration
    val_lr: learning rate for critic optimizer
    saver_kwargs: keyword arguments for proj.utils.saver.SnapshotSaver
    """
    optimizer = optimizer or {}
    optimizer = {"class": torch.optim.Adam, **optimizer}
    val_fn = val_fn or ValueFunction.from_policy(policy)
    logu.save_config(locals())
    saver = SnapshotSaver(logger.get_dir(), locals(), **saver_kwargs)

    vec_env = VecEnvMaker(env)(n_envs)
    policy = policy.pop("class")(vec_env, **policy)
    val_fn = val_fn.pop("class")(vec_env, **val_fn)
    pol_optim = optimizer.pop("class")(policy.parameters(), **optimizer)
    val_optim = torch.optim.Adam(val_fn.parameters(), lr=val_lr)
    loss_fn = torch.nn.MSELoss()

    # Algorithm main loop
    collector = parallel_samples_collector(vec_env, policy, steps)
    beg, end, stp = steps * n_envs, total_steps + steps * n_envs, steps * n_envs
    for samples in trange(beg, end, stp, desc="Training", unit="step"):
        logger.info("Starting iteration {}".format(samples // stp))
        logger.logkv("Iteration", samples // stp)

        logger.info("Start collecting samples")
        trajs = next(collector)

        logger.info("Computing policy gradient variables")
        compute_pg_vars(trajs, val_fn, gamma, gaelam)
        flatten_trajs(trajs)
        all_obs, all_acts, _, _, all_advs, all_vals, all_rets = trajs.values()
        all_obs, all_vals = all_obs[:-n_envs], all_vals[:-n_envs]

        logger.info("Applying policy gradient")
        all_dists = policy(all_obs)
        old_dists = all_dists.detach()
        objective = torch.mean(all_dists.log_prob(all_acts) * all_advs)
        pol_optim.zero_grad()
        objective.neg().backward()
        pol_optim.step()

        logger.info("Updating val_fn")
        for _ in range(val_iters):
            val_optim.zero_grad()
            loss_fn(val_fn(all_obs), all_rets).backward()
            val_optim.step()

        logger.info("Logging information")
        logger.logkv("Objective", objective.item())
        logger.logkv("TotalNSamples", samples)
        logu.log_reward_statistics(vec_env)
        logu.log_val_fn_statistics(all_vals, all_rets)
        logu.log_action_distribution_statistics(old_dists)
        logu.log_average_kl_divergence(old_dists, policy, all_obs)
        logger.dumpkvs()

        logger.info("Saving snapshot")
        saver.save_state(
            samples // stp,
            dict(
                alg=dict(last_iter=samples // stp),
                policy=policy.state_dict(),
                val_fn=val_fn.state_dict(),
                pol_optim=pol_optim.state_dict(),
                val_optim=val_optim.state_dict(),
            ),
        )
        del all_obs, all_acts, all_advs, all_vals, all_rets, trajs

    vec_env.close()
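# `compute_pg_vars` fills the trajectory dict with advantages and returns
# using GAE(gamma, lambda), per the gamma/gaelam docstrings above. The sketch
# below shows that computation for a single trajectory; it is illustrative
# only (the project's version operates on the trajs dict in place, handles
# episode termination masking and batching across environments, which are
# omitted here).
def _gae_sketch(rewards, values, last_val, gamma=0.99, gaelam=0.97):
    """rewards: (T,), values: (T,), last_val: bootstrap value after step T."""
    T = len(rewards)
    advantages = torch.zeros(T)
    lastgaelam = 0.0
    for t in reversed(range(T)):
        next_value = last_val if t == T - 1 else values[t + 1]
        # one-step TD residual, then exponentially weighted accumulation
        delta = rewards[t] + gamma * next_value - values[t]
        lastgaelam = delta + gamma * gaelam * lastgaelam
        advantages[t] = lastgaelam
    returns = advantages + values
    return advantages, returns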