def learn(self, total_steps):
    """
    The actual training loop.

    Returns:
        model: trained model
        avg_reward_hist: list with the average reward per episode at each epoch
        var_dict: dictionary with all locals, for logging/debugging purposes
    """
    # init everything
    # ==============================================================================
    # seed all our RNGs
    env = gym.make(self.env_name, **self.env_config)
    cur_total_steps = 0
    env.seed(self.seed)
    torch.manual_seed(self.seed)
    np.random.seed(self.seed)

    progress_bar = tqdm.tqdm(total=total_steps)
    lr_lookup = make_schedule(self.lr_schedule, total_steps)
    self.sgd_lr = lr_lookup(0)

    progress_bar.update(0)
    early_stop = False

    self.pol_opt = torch.optim.RMSprop(self.model.policy.parameters(),
                                       lr=lr_lookup(cur_total_steps))
    self.val_opt = torch.optim.RMSprop(self.model.value_fn.parameters(),
                                       lr=lr_lookup(cur_total_steps))

    # Train until we hit our total steps or reach our reward threshold
    # ==============================================================================
    while cur_total_steps < total_steps:
        batch_obs = torch.empty(0)
        batch_act = torch.empty(0)
        batch_adv = torch.empty(0)
        batch_discrew = torch.empty(0)
        cur_batch_steps = 0

        # Bail out if we have met our reward threshold
        if len(self.raw_rew_hist) > 2 and self.reward_stop:
            if (self.raw_rew_hist[-1] >= self.reward_stop
                    and self.raw_rew_hist[-2] >= self.reward_stop):
                early_stop = True
                break

        # construct batch data from rollouts
        # ==============================================================================
        while cur_batch_steps < self.epoch_batch_size:
            ep_obs, ep_act, ep_rew, ep_steps, ep_term = do_rollout(
                env, self.model, self.env_no_term_steps)

            cur_batch_steps += ep_steps
            cur_total_steps += ep_steps

            self.raw_rew_hist.append(sum(ep_rew).item())

            batch_obs = torch.cat((batch_obs, ep_obs.clone()))
            batch_act = torch.cat((batch_act, ep_act.clone()))

            if self.normalize_return:
                self.rew_std = update_std(ep_rew, self.rew_std, cur_total_steps)
                ep_rew = ep_rew / (self.rew_std + 1e-6)

            # bootstrap the final reward with the value function unless the
            # episode actually terminated
            if ep_term:
                ep_rew = torch.cat((ep_rew, torch.zeros(1, 1)))
            else:
                ep_rew = torch.cat((ep_rew, self.model.value_fn(
                    ep_obs[-1]).detach().reshape(1, 1).clone()))

            ep_discrew = discount_cumsum(ep_rew, self.gamma)[:-1]
            batch_discrew = torch.cat((batch_discrew, ep_discrew.clone()))

            # generalized advantage estimation
            with torch.no_grad():
                ep_val = torch.cat((self.model.value_fn(ep_obs),
                                    ep_rew[-1].reshape(1, 1).clone()))
                deltas = ep_rew[:-1] + self.gamma * ep_val[1:] - ep_val[:-1]

            ep_adv = discount_cumsum(deltas, self.gamma * self.lam)
            batch_adv = torch.cat((batch_adv, ep_adv.clone()))

        # PostProcess epoch and update weights
        # ==============================================================================
        if self.normalize_adv:
            # make sure our advantages are zero mean and unit variance
            batch_adv = (batch_adv - batch_adv.mean()) / (batch_adv.std() + 1e-6)

        # Update the policy using the PPO loss
        for pol_epoch in range(self.sgd_epochs):
            pol_loss, approx_kl = self.policy_update(batch_act, batch_obs, batch_adv)
            if approx_kl > self.target_kl:
                print("KL Stop")
                break

        # Update the value function with the standard L2 loss
        for val_epoch in range(self.sgd_epochs):
            val_loss = self.value_update(batch_obs, batch_discrew)

        # update observation mean and variance
        if self.normalize_obs:
            self.obs_mean = update_mean(batch_obs, self.obs_mean, cur_total_steps)
            self.obs_std = update_std(batch_obs, self.obs_std, cur_total_steps)
            self.model.policy.state_means = self.obs_mean
            self.model.value_fn.state_means = self.obs_mean
            self.model.policy.state_std = self.obs_std
            self.model.value_fn.state_std = self.obs_std

        self.sgd_lr = lr_lookup(cur_total_steps)
        self.old_model = copy.deepcopy(self.model)

        # logging
        self.val_loss_hist.append(val_loss.detach())
        self.pol_loss_hist.append(pol_loss.detach())
        self.lrp_hist.append(self.pol_opt.state_dict()['param_groups'][0]['lr'])
        self.lrv_hist.append(self.val_opt.state_dict()['param_groups'][0]['lr'])
        self.kl_hist.append(approx_kl.detach())
        self.entropy_hist.append(self.model.policy.logstds.detach())

        progress_bar.update(cur_batch_steps)

    progress_bar.close()
    return self.model, self.raw_rew_hist, locals()
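# `policy_update` and `value_update` are methods of the surrounding agent class and are not
# shown in this section. Based on how they are called above (returning `(pol_loss, approx_kl)`
# and `val_loss`), they presumably compute the standard PPO clipped surrogate plus an approximate
# KL used for early stopping, and an L2 value loss. A minimal sketch of those losses follows;
# the names `ppo_clip_loss` and `value_l2_loss` are hypothetical, not part of the library.
import torch

def ppo_clip_loss(logp, old_logp, adv, eps=0.2):
    # ratio of new to old action probabilities, clipped PPO surrogate
    ratio = torch.exp(logp - old_logp)
    clipped = torch.clamp(ratio, 1 - eps, 1 + eps)
    pol_loss = -torch.min(ratio * adv, clipped * adv).mean()
    # crude first-order estimate of the KL between old and new policies
    approx_kl = (old_logp - logp).mean()
    return pol_loss, approx_kl

def value_l2_loss(value_preds, disc_returns):
    # mean squared error against the discounted returns
    return ((value_preds - disc_returns) ** 2).mean()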
def ars(env_name, n_epochs, env_config, step_size, n_delta, n_top, exp_noise,
        n_workers, policy, seed):
    torch.autograd.set_grad_enabled(False)  # Gradient free baby!

    pool = Pool(processes=n_workers)

    W = torch.nn.utils.parameters_to_vector(policy.parameters())
    n_param = W.shape[0]

    if env_config is None:
        env_config = {}

    env = gym.make(env_name, **env_config)
    env.seed(seed)
    torch.manual_seed(seed)
    np.random.seed(seed)

    total_steps = 0
    r_hist = []

    exp_dist = torch.distributions.Normal(torch.zeros(n_delta, n_param),
                                          torch.ones(n_delta, n_param))
    do_rollout_partial = partial(do_rollout_train, env_name, policy)

    for _ in range(n_epochs):
        deltas = exp_dist.sample()

        pm_W = torch.cat((W + (deltas * exp_noise), W - (deltas * exp_noise)))
        results = pool.map(do_rollout_partial, pm_W)

        states = torch.empty(0)
        p_returns = []
        m_returns = []
        l_returns = []
        top_returns = []

        for p_result, m_result in zip(results[:n_delta], results[n_delta:]):
            ps, pr, plr = p_result
            ms, mr, mlr = m_result

            states = torch.cat((states, ms, ps), dim=0)
            p_returns.append(pr)
            m_returns.append(mr)
            l_returns.append(plr)
            l_returns.append(mlr)
            top_returns.append(max(pr, mr))

        top_idx = sorted(range(len(top_returns)),
                         key=lambda k: top_returns[k], reverse=True)[:n_top]
        p_returns = torch.stack(p_returns)[top_idx]
        m_returns = torch.stack(m_returns)[top_idx]
        l_returns = torch.stack(l_returns)[top_idx]
        r_hist.append(l_returns.mean())

        W = W + (step_size / (n_delta * torch.cat((p_returns, m_returns)).std() + 1e-6)) \
            * torch.sum((p_returns - m_returns) * deltas[top_idx].T, dim=1)

        ep_steps = states.shape[0]
        policy.state_means = update_mean(states, policy.state_means, total_steps)
        policy.state_std = update_std(states, policy.state_std, total_steps)
        do_rollout_partial = partial(do_rollout_train, env_name, policy)
        total_steps += ep_steps

    # shut down the worker pool before returning (matches the pool-based variant below)
    pool.terminate()
    torch.nn.utils.vector_to_parameters(W, policy.parameters())
    return policy, r_hist
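# `do_rollout_train` is imported from elsewhere in seagul and is not shown in this section.
# Judging from how it is used above (mapped over perturbed weight vectors and unpacked as
# (states, return, logged_return)), it is assumed to behave roughly like the sketch below.
# This only illustrates the expected contract, not the library's actual code; the variants
# further down also pass a `postprocess` function as an extra leading argument.
import copy

import gym
import torch

def do_rollout_train(env_name, policy, W):
    """Assumed contract: run one episode with flat weights W and return
    (observations, return used for the ARS update, return used for logging)."""
    torch.autograd.set_grad_enabled(False)
    env = gym.make(env_name)
    pol = copy.deepcopy(policy)
    torch.nn.utils.vector_to_parameters(W, pol.parameters())

    obs_list, rews = [], []
    obs, done = env.reset(), False
    while not done:
        obs = torch.as_tensor(obs, dtype=torch.float32)
        obs_list.append(obs)
        obs, rew, done, _ = env.step(pol(obs).numpy())
        rews.append(rew)

    states = torch.stack(obs_list)
    ep_return = torch.as_tensor(sum(rews), dtype=torch.float32)
    return states, ep_return, ep_return  # no reward post-processing in this sketch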
def ppo_dim(
        env_name,
        total_steps,
        model,
        transient_length=50,
        act_var_schedule=[0.7],
        epoch_batch_size=2048,
        gamma=0.99,
        lam=0.99,
        eps=0.2,
        seed=0,
        pol_batch_size=1024,
        val_batch_size=1024,
        pol_lr=1e-4,
        val_lr=1e-4,
        pol_epochs=10,
        val_epochs=10,
        target_kl=.01,
        use_gpu=False,
        reward_stop=None,
        normalize_return=True,
        env_config={}
):
    """
    Implements proximal policy optimization with clipping.

    Args:
        env_name: name of the openAI gym environment to solve
        total_steps: number of timesteps to run the PPO for
        model: model from seagul.rl.models. Contains policy and value fn
        transient_length: number of initial steps of each episode to discard before computing var_dim
        act_var_schedule: schedule to set the variance of the policy. Will linearly interpolate values
        epoch_batch_size: number of environment steps to take per batch, total steps will be num_epochs*epoch_batch_size
        seed: seed for all the rngs
        gamma: discount applied to future rewards, usually close to 1
        lam: lambda for the Advantage estimation, usually close to 1
        eps: epsilon for the clipping, usually .1 or .2
        pol_batch_size: batch size for policy updates
        val_batch_size: batch size for value function updates
        pol_lr: learning rate for the policy optimizer
        val_lr: learning rate for the value function optimizer
        pol_epochs: how many epochs to use for each policy update
        val_epochs: how many epochs to use for each value update
        target_kl: max KL before breaking
        use_gpu: want to use the GPU? set to true
        reward_stop: reward value to stop if we achieve
        normalize_return: should we normalize the return?
        env_config: dictionary containing kwargs to pass to the environment

    Returns:
        model: trained model
        avg_reward_hist: list with the average reward per episode at each epoch
        var_dict: dictionary with all locals, for logging/debugging purposes

    Example:
        from seagul.rl.algos import ppo
        from seagul.nn import MLP
        from seagul.rl.models import PPOModel
        import torch

        input_size = 3
        output_size = 1
        layer_size = 64
        num_layers = 2

        policy = MLP(input_size, output_size, num_layers, layer_size)
        value_fn = MLP(input_size, 1, num_layers, layer_size)
        model = PPOModel(policy, value_fn)

        model, rews, var_dict = ppo("Pendulum-v0", 10000, model)
    """

    # init everything
    # ==============================================================================
    torch.set_num_threads(1)

    env = gym.make(env_name, **env_config)
    if isinstance(env.action_space, gym.spaces.Box):
        act_size = env.action_space.shape[0]
        act_dtype = torch.double
    else:
        raise NotImplementedError("trying to use unsupported action space", env.action_space)

    actvar_lookup = make_variance_schedule(act_var_schedule, model, total_steps)
    model.action_var = actvar_lookup(0)

    obs_size = env.observation_space.shape[0]
    obs_mean = torch.zeros(obs_size)
    obs_var = torch.ones(obs_size)
    adv_mean = torch.zeros(1)
    adv_var = torch.ones(1)
    rew_mean = torch.zeros(1)
    rew_var = torch.ones(1)

    # copy.deepcopy broke for me with older versions of torch.
    # Using pickle for this is weird but works fine.
    old_model = pickle.loads(pickle.dumps(model))

    pol_opt = torch.optim.Adam(model.policy.parameters(), lr=pol_lr)
    val_opt = torch.optim.Adam(model.value_fn.parameters(), lr=val_lr)

    # seed all our RNGs
    env.seed(seed)
    torch.manual_seed(seed)
    np.random.seed(seed)

    # set defaults, and decide if we are using a GPU or not
    use_cuda = torch.cuda.is_available() and use_gpu
    device = torch.device("cuda:0" if use_cuda else "cpu")

    # init logging stuff
    raw_rew_hist = []
    val_loss_hist = []
    pol_loss_hist = []
    progress_bar = tqdm.tqdm(total=total_steps)
    cur_total_steps = 0
    progress_bar.update(0)
    early_stop = False

    # Train until we hit our total steps or reach our reward threshold
    # ==============================================================================
    while cur_total_steps < total_steps:
        batch_obs = torch.empty(0)
        batch_act = torch.empty(0)
        batch_adv = torch.empty(0)
        batch_discrew = torch.empty(0)
        cur_batch_steps = 0

        # Bail out if we have met our reward threshold
        if len(raw_rew_hist) > 2 and reward_stop:
            if raw_rew_hist[-1] >= reward_stop and raw_rew_hist[-2] >= reward_stop:
                early_stop = True
                break

        # construct batch data from rollouts
        # ==============================================================================
        while cur_batch_steps < epoch_batch_size:
            ep_obs, ep_act, ep_rew, ep_steps = do_rollout(env, model)

            # scale the rewards by the variation dimension of the trajectory,
            # ignoring the initial transient
            ep_rew /= var_dim(ep_obs[transient_length:], order=1)

            raw_rew_hist.append(sum(ep_rew))
            ep_rew = (ep_rew - ep_rew.mean()) / (ep_rew.std() + 1e-6)

            batch_obs = torch.cat((batch_obs, ep_obs[:-1]))
            batch_act = torch.cat((batch_act, ep_act[:-1]))

            ep_discrew = discount_cumsum(ep_rew, gamma)
            batch_discrew = torch.cat((batch_discrew, ep_discrew[:-1]))

            if normalize_return:
                rew_mean = update_mean(batch_discrew, rew_mean, cur_total_steps)
                rew_var = update_std(batch_discrew, rew_var, cur_total_steps)
                batch_discrew = (batch_discrew - rew_mean) / (rew_var + 1e-6)

            # calculate this episode's advantages
            last_val = model.value_fn(ep_obs[-1]).reshape(-1, 1)
            ep_val = model.value_fn(ep_obs)
            ep_val[-1] = last_val

            deltas = ep_rew[:-1] + gamma * ep_val[1:] - ep_val[:-1]
            ep_adv = discount_cumsum(deltas.detach(), gamma * lam)
            batch_adv = torch.cat((batch_adv, ep_adv))

            cur_batch_steps += ep_steps
            cur_total_steps += ep_steps

        # make sure our advantages are zero mean and unit variance
        adv_mean = update_mean(batch_adv, adv_mean, cur_total_steps)
        adv_var = update_std(batch_adv, adv_var, cur_total_steps)
        batch_adv = (batch_adv - adv_mean) / (adv_var + 1e-6)

        # policy update
        # ========================================================================
        num_mbatch = int(batch_obs.shape[0] / pol_batch_size)

        # Update the policy using the PPO loss
        for pol_epoch in range(pol_epochs):
            for i in range(num_mbatch):
                cur_sample = i * pol_batch_size

                logp = model.get_logp(
                    batch_obs[cur_sample:cur_sample + pol_batch_size],
                    batch_act[cur_sample:cur_sample + pol_batch_size]).reshape(-1, act_size)
                old_logp = old_model.get_logp(
                    batch_obs[cur_sample:cur_sample + pol_batch_size],
                    batch_act[cur_sample:cur_sample + pol_batch_size]).reshape(-1, act_size)

                r = torch.exp(logp - old_logp)
                clip_r = torch.clamp(r, 1 - eps, 1 + eps)
                pol_loss = -torch.min(
                    r * batch_adv[cur_sample:cur_sample + pol_batch_size],
                    clip_r * batch_adv[cur_sample:cur_sample + pol_batch_size]).mean()

                approx_kl = (logp - old_logp).mean()
                if approx_kl > target_kl:
                    break

                pol_opt.zero_grad()
                pol_loss.backward()
                pol_opt.step()

        # value_fn update
        # ========================================================================
        num_mbatch = int(batch_obs.shape[0] / val_batch_size)

        # Update value function with the standard L2 Loss
        for val_epoch in range(val_epochs):
            for i in range(num_mbatch):
                cur_sample = i * val_batch_size

                # predict and calculate loss for the batch
                val_preds = model.value_fn(batch_obs[cur_sample:cur_sample + val_batch_size])
                val_loss = ((val_preds - batch_discrew[cur_sample:cur_sample + val_batch_size]) ** 2).mean()

                # do the normal pytorch update
                val_opt.zero_grad()
                val_loss.backward()
                val_opt.step()

        # update observation mean and variance
        obs_mean = update_mean(batch_obs, obs_mean, cur_total_steps)
        obs_var = update_std(batch_obs, obs_var, cur_total_steps)
        model.policy.state_means = obs_mean
        model.value_fn.state_means = obs_mean
        model.policy.state_std = obs_var
        model.value_fn.state_std = obs_var
        model.action_var = actvar_lookup(cur_total_steps)

        old_model = pickle.loads(pickle.dumps(model))
        val_loss_hist.append(val_loss)
        pol_loss_hist.append(pol_loss)

        progress_bar.update(cur_batch_steps)

    progress_bar.close()
    return model, raw_rew_hist, locals()
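# These training loops all lean on a few helpers (`discount_cumsum`, `update_mean`,
# `update_std`, `make_schedule`) that are defined elsewhere in seagul and not shown in this
# section. The sketches below show what they are assumed to compute, based purely on how
# they are called here; the actual library implementations may differ.
import torch

def discount_cumsum(rews, gamma):
    """Assumed: discounted cumulative sum over time,
    out[t] = rews[t] + gamma*rews[t+1] + gamma^2*rews[t+2] + ..."""
    out = torch.zeros_like(rews)
    running = torch.zeros_like(rews[0])
    for t in reversed(range(rews.shape[0])):
        running = rews[t] + gamma * running
        out[t] = running
    return out

def update_mean(batch, old_mean, old_count):
    """Assumed: running mean over all samples seen so far."""
    n_new = batch.shape[0]
    batch_mean = batch.mean(dim=0)
    return (old_count * old_mean + n_new * batch_mean) / (old_count + n_new)

def update_std(batch, old_std, old_count):
    """Assumed: running standard deviation, blended per dimension."""
    n_new = batch.shape[0]
    batch_var = batch.var(dim=0, unbiased=False)
    blended_var = (old_count * old_std ** 2 + n_new * batch_var) / (old_count + n_new)
    return torch.sqrt(blended_var)

def make_schedule(values, total_steps):
    """Assumed: returns a lookup fn that linearly interpolates `values` over total_steps."""
    values = torch.as_tensor(values, dtype=torch.float32)

    def lookup(step):
        frac = min(max(step / total_steps, 0.0), 1.0) * (len(values) - 1)
        lo = int(frac)
        hi = min(lo + 1, len(values) - 1)
        return float(values[lo] + (frac - lo) * (values[hi] - values[lo]))

    return lookup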
def ars(env_name,
        policy,
        n_epochs,
        n_workers=8,
        step_size=.02,
        n_delta=32,
        n_top=16,
        exp_noise=0.03,
        zero_policy=True,
        postprocess=postprocess_default):
    """
    Augmented Random Search
    https://arxiv.org/pdf/1803.07055

    Args:
        env_name: name of the openAI gym environment to solve
        policy: policy to train, its weights are flattened into a single vector
        n_epochs: how many ARS epochs to run
        n_workers: how many rollout workers to use
        step_size: ARS step size
        n_delta: number of random directions sampled per epoch
        n_top: number of top-performing directions used for the update
        exp_noise: scale of the exploration noise added to the weights
        zero_policy: start from all-zero weights if True
        postprocess: function applied to rollout rewards inside the workers

    Returns:
        policy: the trained policy
        r_hist: list of per-epoch mean returns

    Example:
    """
    torch.autograd.set_grad_enabled(False)

    pool = Pool(processes=n_workers)
    env = gym.make(env_name)

    W = torch.nn.utils.parameters_to_vector(policy.parameters())
    n_param = W.shape[0]

    if zero_policy:
        W = torch.zeros_like(W)

    r_hist = []
    s_mean = torch.zeros(env.observation_space.shape[0])
    s_stdv = torch.ones(env.observation_space.shape[0])

    total_steps = 0
    exp_dist = torch.distributions.Normal(torch.zeros(n_delta, n_param),
                                          torch.ones(n_delta, n_param))
    do_rollout_partial = partial(do_rollout_train, env_name, policy, postprocess)

    for _ in range(n_epochs):
        deltas = exp_dist.sample()
        pm_W = torch.cat((W + (deltas * exp_noise), W - (deltas * exp_noise)))

        results = pool.map(do_rollout_partial, pm_W)

        states = torch.empty(0)
        p_returns = []
        m_returns = []
        l_returns = []
        top_returns = []

        for p_result, m_result in zip(results[:n_delta], results[n_delta:]):
            ps, pr, plr = p_result
            ms, mr, mlr = m_result

            states = torch.cat((states, ms, ps), dim=0)
            p_returns.append(pr)
            m_returns.append(mr)
            l_returns.append(plr)
            l_returns.append(mlr)
            top_returns.append(max(pr, mr))

        top_idx = sorted(range(len(top_returns)),
                         key=lambda k: top_returns[k], reverse=True)[:n_top]
        p_returns = torch.stack(p_returns)[top_idx]
        m_returns = torch.stack(m_returns)[top_idx]
        l_returns = torch.stack(l_returns)[top_idx]
        r_hist.append(l_returns.mean())

        ep_steps = states.shape[0]
        s_mean = update_mean(states, s_mean, total_steps)
        s_stdv = update_std(states, s_stdv, total_steps)
        total_steps += ep_steps

        policy.state_means = s_mean
        policy.state_std = s_stdv
        do_rollout_partial = partial(do_rollout_train, env_name, policy, postprocess)

        W = W + (step_size / (n_delta * torch.cat((p_returns, m_returns)).std() + 1e-6)) \
            * torch.sum((p_returns - m_returns) * deltas[top_idx].T, dim=1)

    pool.terminate()
    torch.nn.utils.vector_to_parameters(W, policy.parameters())
    return policy, r_hist
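# The weight update `W = W + (step_size / (n_delta * std(...) + 1e-6)) * sum(...)` used in both
# `ars` implementations above (and again in the pipe-based variants below) corresponds to the
# ARS-v2t rule of Mania et al. (https://arxiv.org/pdf/1803.07055):
#
#     W \leftarrow W + \frac{\alpha}{b \, \sigma_R} \sum_{k=1}^{b} \bigl(r^{+}_{(k)} - r^{-}_{(k)}\bigr) \, \delta_{(k)}
#
# with step size \alpha, the b = n_top best-performing directions \delta_{(k)}, their paired
# returns r^{\pm}_{(k)}, and \sigma_R the standard deviation of those 2b returns. Note that the
# code divides by n_delta rather than by b = n_top, so whenever n_top < n_delta the effective
# step size is smaller than the paper's.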
def ars(env_name,
        policy,
        n_epochs,
        env_config={},
        n_workers=8,
        step_size=.02,
        n_delta=32,
        n_top=16,
        exp_noise=0.03,
        zero_policy=True,
        learn_means=True,
        postprocess=postprocess_default):
    """
    Augmented Random Search
    https://arxiv.org/pdf/1803.07055

    Args:
        env_name: name of the openAI gym environment to solve
        policy: policy to train, its weights are flattened into a single vector
        n_epochs: how many ARS epochs to run
        env_config: dictionary containing kwargs to pass to the environment
        n_workers: how many rollout workers to spawn
        step_size: ARS step size
        n_delta: number of random directions sampled per epoch
        n_top: number of top-performing directions used for the update
        exp_noise: scale of the exploration noise added to the weights
        zero_policy: start from all-zero weights if True
        learn_means: currently unused
        postprocess: function applied to rollout rewards inside the workers

    Returns:
        policy: the trained policy
        r_hist: per-epoch mean of the post-processed returns from the top directions
        lr_hist: per-epoch mean of the raw returns from the top directions

    Example:
    """
    torch.autograd.set_grad_enabled(False)

    # spin up the rollout workers
    proc_list = []
    master_pipe_list = []
    for i in range(n_workers):
        master_con, worker_con = Pipe()
        proc = Process(target=worker_fn,
                       args=(worker_con, env_name, env_config, policy, postprocess))
        proc.start()
        proc_list.append(proc)
        master_pipe_list.append(master_con)

    W = torch.nn.utils.parameters_to_vector(policy.parameters())
    n_param = W.shape[0]

    if zero_policy:
        W = torch.zeros_like(W)

    env = gym.make(env_name, **env_config)
    s_mean = policy.state_means
    s_std = policy.state_std
    total_steps = 0
    env.close()

    r_hist = []
    lr_hist = []

    exp_dist = torch.distributions.Normal(torch.zeros(n_delta, n_param),
                                          torch.ones(n_delta, n_param))

    for epoch in range(n_epochs):
        deltas = exp_dist.sample()
        pm_W = torch.cat((W + (deltas * exp_noise), W - (deltas * exp_noise)))

        # farm the rollouts out to the workers, round robin
        for i, Ws in enumerate(pm_W):
            master_pipe_list[i % n_workers].send((Ws, s_mean, s_std))

        results = []
        for i, _ in enumerate(pm_W):
            results.append(master_pipe_list[i % n_workers].recv())

        states = torch.empty(0)
        p_returns = []
        m_returns = []
        l_returns = []
        top_returns = []

        for p_result, m_result in zip(results[:n_delta], results[n_delta:]):
            ps, pr, plr = p_result
            ms, mr, mlr = m_result

            states = torch.cat((states, ms, ps), dim=0)
            p_returns.append(pr)
            m_returns.append(mr)
            l_returns.append(plr)
            l_returns.append(mlr)
            top_returns.append(max(pr, mr))

        top_idx = sorted(range(len(top_returns)),
                         key=lambda k: top_returns[k], reverse=True)[:n_top]
        p_returns = torch.stack(p_returns)[top_idx]
        m_returns = torch.stack(m_returns)[top_idx]
        l_returns = torch.stack(l_returns)[top_idx]
        lr_hist.append(l_returns.mean())
        r_hist.append((p_returns.mean() + m_returns.mean()) / 2)

        ep_steps = states.shape[0]
        s_mean = update_mean(states, s_mean, total_steps)
        s_std = update_std(states, s_std, total_steps)
        total_steps += ep_steps

        if epoch % 5 == 0:
            print(f"epoch: {epoch}, reward: {lr_hist[-1].item()}, "
                  f"processed reward: {r_hist[-1].item()}")

        W = W + (step_size / (n_delta * torch.cat((p_returns, m_returns)).std() + 1e-6)) \
            * torch.sum((p_returns - m_returns) * deltas[top_idx].T, dim=1)

    # shut down the workers
    for pipe in master_pipe_list:
        pipe.send("STOP")

    policy.state_means = s_mean
    policy.state_std = s_std
    torch.nn.utils.vector_to_parameters(W, policy.parameters())
    return policy, r_hist, lr_hist
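# `worker_fn` is spawned above but not defined in this section. From the pipe protocol used by
# the master (send `(W, state_means, state_std)`, receive `(states, return, logged_return)`,
# send the string "STOP" to shut down), it is assumed to look roughly like the sketch below.
# This is an illustration only; the real seagul worker may differ in details such as seeding
# and reward post-processing, and the class-based version below also passes a seed argument.
import copy

import gym
import torch

def worker_fn(worker_con, env_name, env_config, policy, postprocess):
    """Assumed rollout worker: loop until the master sends "STOP"."""
    torch.autograd.set_grad_enabled(False)
    env = gym.make(env_name, **env_config)
    pol = copy.deepcopy(policy)

    while True:
        msg = worker_con.recv()
        if msg == "STOP":
            env.close()
            return

        # load the requested weights and normalization statistics
        W, s_mean, s_std = msg
        torch.nn.utils.vector_to_parameters(W, pol.parameters())
        pol.state_means = s_mean
        pol.state_std = s_std

        # run a single episode
        obs_list, rews = [], []
        obs, done = env.reset(), False
        while not done:
            obs = torch.as_tensor(obs, dtype=torch.float32)
            obs_list.append(obs)
            obs, rew, done, _ = env.step(pol(obs).numpy())
            rews.append(rew)

        states = torch.stack(obs_list)
        raw_return = torch.as_tensor(sum(rews), dtype=torch.float32)
        processed_return = postprocess(torch.as_tensor(rews, dtype=torch.float32)).sum()
        worker_con.send((states, processed_return, raw_return))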
def learn(self, n_epochs):
    torch.autograd.set_grad_enabled(False)

    proc_list = []
    master_pipe_list = []
    learn_start_idx = copy.copy(self.total_epochs)

    for i in range(self.n_workers):
        master_con, worker_con = Pipe()
        proc = Process(target=worker_fn,
                       args=(worker_con, self.env_name, self.env_config,
                             self.policy, self.postprocessor, self.seed))
        proc.start()
        proc_list.append(proc)
        master_pipe_list.append(master_con)

    W = torch.nn.utils.parameters_to_vector(self.policy.parameters())
    n_param = W.shape[0]

    torch.manual_seed(self.seed)
    exp_dist = torch.distributions.Normal(torch.zeros(self.n_delta, n_param),
                                          torch.ones(self.n_delta, n_param))

    for _ in range(n_epochs):
        deltas = exp_dist.sample()
        pm_W = torch.cat((W + (deltas * self.exp_noise),
                          W - (deltas * self.exp_noise)))

        for i, Ws in enumerate(pm_W):
            master_pipe_list[i % self.n_workers].send(
                (Ws, self.policy.state_means, self.policy.state_std))

        results = []
        for i, _ in enumerate(pm_W):
            results.append(master_pipe_list[i % self.n_workers].recv())

        states = torch.empty(0)
        p_returns = []
        m_returns = []
        l_returns = []
        top_returns = []

        for p_result, m_result in zip(results[:self.n_delta],
                                      results[self.n_delta:]):
            ps, pr, plr = p_result
            ms, mr, mlr = m_result

            states = torch.cat((states, ms, ps), dim=0)
            p_returns.append(pr)
            m_returns.append(mr)
            l_returns.append(plr)
            l_returns.append(mlr)
            top_returns.append(max(pr, mr))

        top_idx = sorted(range(len(top_returns)),
                         key=lambda k: top_returns[k],
                         reverse=True)[:self.n_top]
        p_returns = torch.stack(p_returns)[top_idx]
        m_returns = torch.stack(m_returns)[top_idx]
        l_returns = torch.stack(l_returns)[top_idx]
        self.lr_hist.append(l_returns.mean())
        self.r_hist.append((p_returns.mean() + m_returns.mean()) / 2)

        ep_steps = states.shape[0]
        self.policy.state_means = update_mean(states, self.policy.state_means,
                                              self.total_steps)
        self.policy.state_std = update_std(states, self.policy.state_std,
                                           self.total_steps)
        self.total_steps += ep_steps
        self.total_epochs += 1

        W = W + (self.step_size / (self.n_delta * torch.cat(
            (p_returns, m_returns)).std() + 1e-6)) * torch.sum(
            (p_returns - m_returns) * deltas[top_idx].T, dim=1)

    for pipe in master_pipe_list:
        pipe.send("STOP")
    for proc in proc_list:
        proc.join()

    torch.nn.utils.vector_to_parameters(W, self.policy.parameters())
    return self.lr_hist[learn_start_idx:]
def ppo_visit(
        env_name,
        total_steps,
        model,
        vc=.01,
        replay_buf_size=int(5e4),
        act_std_schedule=(0.7,),
        epoch_batch_size=2048,
        gamma=0.99,
        lam=0.95,
        eps=0.2,
        seed=0,
        entropy_coef=0.0,
        sgd_batch_size=1024,
        lr_schedule=(3e-4,),
        sgd_epochs=10,
        target_kl=float('inf'),
        val_coef=.5,
        clip_val=True,
        env_no_term_steps=0,
        use_gpu=False,
        reward_stop=None,
        normalize_return=True,
        normalize_obs=True,
        normalize_adv=True,
        env_config={}
):
    """
    Implements proximal policy optimization with clipping.

    Args:
        env_name: name of the openAI gym environment to solve
        total_steps: number of timesteps to run the PPO for
        model: model from seagul.rl.models. Contains policy and value fn
        vc: coefficient on the visitation-based reward shaping term
        replay_buf_size: number of visited states kept in the replay buffer
        act_std_schedule: schedule to set the variance of the policy. Will linearly interpolate values
        epoch_batch_size: number of environment steps to take per batch, total steps will be num_epochs*epoch_batch_size
        seed: seed for all the rngs
        gamma: discount applied to future rewards, usually close to 1
        lam: lambda for the Advantage estimation, usually close to 1
        eps: epsilon for the clipping, usually .1 or .2
        entropy_coef: coefficient on the entropy bonus in the policy loss
        sgd_batch_size: minibatch size for the policy and value function updates
        lr_schedule: learning rate schedule for the policy and value optimizers
        sgd_epochs: how many epochs to use for each policy/value update
        target_kl: max KL before breaking
        val_coef: coefficient on the value function loss
        clip_val: if True, clip the value function update as well
        use_gpu: want to use the GPU? set to true
        reward_stop: reward value to stop if we achieve
        normalize_return: should we normalize the return?
        normalize_obs: should we normalize the observations?
        normalize_adv: should we normalize the advantages?
        env_config: dictionary containing kwargs to pass to the environment

    Returns:
        model: trained model
        avg_reward_hist: list with the average reward per episode at each epoch
        var_dict: dictionary with all locals, for logging/debugging purposes

    Example:
        from seagul.rl.algos import ppo
        from seagul.nn import MLP
        from seagul.rl.models import PPOModel
        import torch

        input_size = 3
        output_size = 1
        layer_size = 64
        num_layers = 2

        policy = MLP(input_size, output_size, num_layers, layer_size)
        value_fn = MLP(input_size, 1, num_layers, layer_size)
        model = PPOModel(policy, value_fn)

        model, rews, var_dict = ppo("Pendulum-v0", 10000, model)
    """

    # init everything
    # ==============================================================================
    torch.set_num_threads(1)

    env = gym.make(env_name, **env_config)
    if isinstance(env.action_space, gym.spaces.Box):
        act_size = env.action_space.shape[0]
        act_dtype = torch.double
    else:
        raise NotImplementedError("trying to use unsupported action space", env.action_space)

    replay_buf = ReplayBuffer(env.observation_space.shape[0], act_size, replay_buf_size)

    actstd_lookup = make_schedule(act_std_schedule, total_steps)
    lr_lookup = make_schedule(lr_schedule, total_steps)

    model.action_var = actstd_lookup(0)
    sgd_lr = lr_lookup(0)

    obs_size = env.observation_space.shape[0]
    obs_mean = torch.zeros(obs_size)
    obs_std = torch.ones(obs_size)
    rew_mean = torch.zeros(1)
    rew_std = torch.ones(1)

    # copy.deepcopy broke for me with older versions of torch.
    # Using pickle for this is weird but works fine.
    old_model = pickle.loads(pickle.dumps(model))

    # seed all our RNGs
    env.seed(seed)
    torch.manual_seed(seed)
    np.random.seed(seed)

    # set defaults, and decide if we are using a GPU or not
    use_cuda = torch.cuda.is_available() and use_gpu
    device = torch.device("cuda:0" if use_cuda else "cpu")

    # init logging stuff
    raw_rew_hist = []
    val_loss_hist = []
    pol_loss_hist = []
    progress_bar = tqdm.tqdm(total=total_steps)
    cur_total_steps = 0
    progress_bar.update(0)
    early_stop = False

    # Train until we hit our total steps or reach our reward threshold
    # ==============================================================================
    while cur_total_steps < total_steps:
        pol_opt = torch.optim.Adam(model.policy.parameters(), lr=sgd_lr)
        val_opt = torch.optim.Adam(model.value_fn.parameters(), lr=sgd_lr)

        batch_obs = torch.empty(0)
        batch_act = torch.empty(0)
        batch_adv = torch.empty(0)
        batch_discrew = torch.empty(0)
        cur_batch_steps = 0

        # Bail out if we have met our reward threshold
        if len(raw_rew_hist) > 2 and reward_stop:
            if raw_rew_hist[-1] >= reward_stop and raw_rew_hist[-2] >= reward_stop:
                early_stop = True
                break

        # construct batch data from rollouts
        # ==============================================================================
        while cur_batch_steps < epoch_batch_size:
            ep_obs, ep_act, ep_rew, ep_steps, ep_term = do_rollout(env, model, env_no_term_steps)
            raw_rew_hist.append(sum(ep_rew).item())

            # shape the rewards with the distance to previously visited states, scaled by vc
            for i, obs in enumerate(ep_obs):
                ep_rew[i] -= np.min(np.linalg.norm(obs - replay_buf.obs1_buf, axis=1)) * vc

            replay_buf.store(ep_obs, ep_obs, ep_act, ep_rew, ep_rew)

            batch_obs = torch.cat((batch_obs, ep_obs[:-1]))
            batch_act = torch.cat((batch_act, ep_act[:-1]))

            # bootstrap the final reward with the value function unless the episode terminated
            if not ep_term:
                ep_rew[-1] = model.value_fn(ep_obs[-1]).detach()

            ep_discrew = discount_cumsum(ep_rew, gamma)

            if normalize_return:
                rew_mean = update_mean(batch_discrew, rew_mean, cur_total_steps)
                rew_std = update_std(ep_discrew, rew_std, cur_total_steps)
                ep_discrew = ep_discrew / (rew_std + 1e-6)

            batch_discrew = torch.cat((batch_discrew, ep_discrew[:-1]))

            # generalized advantage estimation
            ep_val = model.value_fn(ep_obs)
            deltas = ep_rew[:-1] + gamma * ep_val[1:] - ep_val[:-1]
            ep_adv = discount_cumsum(deltas.detach(), gamma * lam)
            batch_adv = torch.cat((batch_adv, ep_adv))

            cur_batch_steps += ep_steps
            cur_total_steps += ep_steps

        # make sure our advantages are zero mean and unit variance
        if normalize_adv:
            batch_adv = (batch_adv - batch_adv.mean()) / (batch_adv.std() + 1e-6)

        num_mbatch = int(batch_obs.shape[0] / sgd_batch_size)

        # Update the policy using the PPO loss
        for pol_epoch in range(sgd_epochs):
            for i in range(num_mbatch):
                # policy update
                # ========================================================================
                cur_sample = i * sgd_batch_size

                # slice out the current minibatch
                local_obs = batch_obs[cur_sample:cur_sample + sgd_batch_size]
                local_act = batch_act[cur_sample:cur_sample + sgd_batch_size]
                local_adv = batch_adv[cur_sample:cur_sample + sgd_batch_size]
                local_val = batch_discrew[cur_sample:cur_sample + sgd_batch_size]

                # Compute the loss
                logp = model.get_logp(local_obs, local_act).reshape(-1, act_size)
                old_logp = old_model.get_logp(local_obs, local_act).reshape(-1, act_size)
                mean_entropy = -(logp * torch.exp(logp)).mean()

                r = torch.exp(logp - old_logp)
                clip_r = torch.clamp(r, 1 - eps, 1 + eps)
                pol_loss = (-torch.min(r * local_adv, clip_r * local_adv).mean()
                            - entropy_coef * mean_entropy)

                approx_kl = ((logp - old_logp) ** 2).mean()
                if approx_kl > target_kl:
                    break

                pol_opt.zero_grad()
                pol_loss.backward()
                pol_opt.step()

                # value_fn update
                # ========================================================================
                val_preds = model.value_fn(local_obs)
                if clip_val:
                    old_val_preds = old_model.value_fn(local_obs)
                    val_preds_clipped = old_val_preds + torch.clamp(val_preds - old_val_preds, -eps, eps)
                    val_loss1 = (val_preds_clipped - local_val) ** 2
                    val_loss2 = (val_preds - local_val) ** 2
                    val_loss = val_coef * torch.max(val_loss1, val_loss2).mean()
                else:
                    val_loss = val_coef * ((val_preds - local_val) ** 2).mean()

                val_opt.zero_grad()
                val_loss.backward()
                val_opt.step()

        # update observation mean and variance
        if normalize_obs:
            obs_mean = update_mean(batch_obs, obs_mean, cur_total_steps)
            obs_std = update_std(batch_obs, obs_std, cur_total_steps)
            model.policy.state_means = obs_mean
            model.value_fn.state_means = obs_mean
            model.policy.state_std = obs_std
            model.value_fn.state_std = obs_std

        model.action_std = actstd_lookup(cur_total_steps)
        sgd_lr = lr_lookup(cur_total_steps)

        old_model = pickle.loads(pickle.dumps(model))
        val_loss_hist.append(val_loss)
        pol_loss_hist.append(pol_loss)

        progress_bar.update(cur_batch_steps)

    progress_bar.close()
    return model, raw_rew_hist, locals()
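# `ReplayBuffer` comes from elsewhere in seagul and is not shown here. Based on how it is
# constructed and used in `ppo_visit` above (an `obs1_buf` array of past observations and a
# `store(obs1, obs2, acts, rews, done)` call over whole episodes), it is assumed to behave
# roughly like the sketch below. The field layout is an assumption; the real class may differ.
import numpy as np

class ReplayBuffer:
    """Assumed minimal FIFO buffer of visited states (plus actions/rewards)."""

    def __init__(self, obs_dim, act_dim, size):
        self.obs1_buf = np.zeros((size, obs_dim), dtype=np.float32)
        self.obs2_buf = np.zeros((size, obs_dim), dtype=np.float32)
        self.act_buf = np.zeros((size, act_dim), dtype=np.float32)
        self.rew_buf = np.zeros(size, dtype=np.float32)
        self.done_buf = np.zeros(size, dtype=np.float32)
        self.ptr, self.size, self.max_size = 0, 0, size

    def store(self, obs1, obs2, acts, rews, done):
        # store a whole episode at once, wrapping around when the buffer is full
        for o1, o2, a, r, d in zip(obs1, obs2, acts, rews, done):
            self.obs1_buf[self.ptr] = o1
            self.obs2_buf[self.ptr] = o2
            self.act_buf[self.ptr] = a
            self.rew_buf[self.ptr] = float(r)
            self.done_buf[self.ptr] = float(d)
            self.ptr = (self.ptr + 1) % self.max_size
            self.size = min(self.size + 1, self.max_size)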