def learn(self, n_epochs, verbose=True):
    proc_list = []
    master_q_list = []
    worker_q_list = []
    learn_start_idx = copy.copy(self.total_epochs)

    if self.step_schedule:
        step_lookup = make_schedule(self.step_schedule, n_epochs)
    if self.exp_schedule:
        exp_lookup = make_schedule(self.exp_schedule, n_epochs)

    # Spin up one rollout worker per process, each with its own pair of queues.
    for i in range(self.n_workers):
        master_q = Queue()
        worker_q = Queue()
        proc = Process(target=worker_fn,
                       args=(worker_q, master_q, self.model_list[0], self.env_name,
                             self.env_config, self.postprocessor, self.seed))
        proc.start()
        proc_list.append(proc)
        master_q_list.append(master_q)
        worker_q_list.append(worker_q)

    n_param = self.W_flat_list[0].shape[0]
    rng = default_rng()

    for epoch in range(n_epochs):
        if self.step_schedule:
            self.step_size = step_lookup(epoch)
        if self.exp_schedule:
            self.exp_noise = exp_lookup(epoch)

        if len(self.raw_rew_hist) > 2 and self.reward_stop:
            if self.raw_rew_hist[-1] >= self.reward_stop and self.raw_rew_hist[-2] >= self.reward_stop:
                early_stop = True
                break

        seeds = rng.integers(2**32, size=self.n_delta)
        delta_list = []
        top_returns_list = []
        m_returns_list = []
        p_returns_list = []
        states_list = []

        # Rollouts: evaluate antithetic perturbations of every model's weights.
        with torch.no_grad():
            for model_i, W_flat in enumerate(self.W_flat_list):
                deltas = rng.standard_normal((self.n_delta, n_param))
                delta_list.append(deltas)

                W_plus_delta = np.concatenate((W_flat + (deltas * self.exp_noise),
                                               W_flat - (deltas * self.exp_noise)))
                seeds = np.concatenate([seeds, seeds])

                start = time.time()
                for i, Ws in enumerate(W_plus_delta):
                    master_q_list[i % self.n_workers].put(
                        (Ws, self.state_mean_list[model_i], self.state_std_list[model_i], seeds[i]))

                results = []
                for i, _ in enumerate(W_plus_delta):
                    results.append(worker_q_list[i % self.n_workers].get())
                end = time.time()
                t = (end - start)

                states = []
                p_returns = []
                m_returns = []
                top_returns = []
                for p_result, m_result in zip(results[:self.n_delta], results[self.n_delta:]):
                    ps, pr, plr = p_result
                    ms, mr, mlr = m_result

                    p_returns.append(pr)
                    m_returns.append(mr)
                    top_returns.append(max([pr, mr]))

                    # Keep the states from whichever direction (+delta or -delta) did better.
                    top_states = [ps, ms][np.argmax([pr, mr]).item()]
                    states.append(top_states)

                states_list.append(states)
                top_returns_list.append(top_returns)
                m_returns_list.append(m_returns)
                p_returns_list.append(p_returns)

                concat_states = np.concatenate(states)
                self.state_mean_list[model_i] = update_mean(concat_states,
                                                            self.state_mean_list[model_i],
                                                            self.total_steps_list[model_i])
                self.state_std_list[model_i] = update_std(concat_states,
                                                          self.state_std_list[model_i],
                                                          self.total_steps_list[model_i])
                ep_steps = concat_states.shape[0]
                self.total_steps_list[model_i] += ep_steps

        # Classifier Update ======================================================================================
        # For each perturbation index, label the visited states with the model that achieved the higher return.
        T = np.array(top_returns_list)
        Y = np.argmax(T, axis=0)

        Ytrain = []
        Xtrain = []
        for i, y in enumerate(Y):
            # print(f"states_lists[0][{i}][0] = {states_list[0][i][0]}")
            # print(f"states_lists[1][{i}][0] = {states_list[1][i][0]}")
            for x in states_list[y][i]:
                Xtrain.append(x)
                Ytrain.append(y)

        Xtrain = np.array(Xtrain, dtype=np.float32)
        Ytrain = np.array(Ytrain)
        print(Xtrain.shape)
        print(Ytrain.shape)

        loss_hist = fit_model(self.classifier, Xtrain, Ytrain, 5,
                              batch_size=64, loss_fn=torch.nn.CrossEntropyLoss())

        # ARS Update ============================================================================================
        with torch.no_grad():
            train_return_list = [[] for _ in range(T.shape[0])]
            train_m_list = [[] for _ in range(T.shape[0])]
            train_p_list = [[] for _ in range(T.shape[0])]
            train_delta_list = [[] for _ in range(T.shape[0])]

            # Each perturbation only contributes to the update of the model that won it.
            # Note: plus returns go in the plus list and minus returns in the minus list,
            # so the (p_returns - m_returns) term below keeps the correct sign.
            for i, y in enumerate(Y):
                train_return_list[y].append(top_returns_list[y][i])
                train_p_list[y].append(p_returns_list[y][i])
                train_m_list[y].append(m_returns_list[y][i])
                train_delta_list[y].append(delta_list[y][i])

            for i, _ in enumerate(self.W_flat_list):
                top_returns = train_return_list[i]
                p_returns = train_p_list[i]
                m_returns = train_m_list[i]
                deltas = np.array(train_delta_list[i])

                if len(top_returns) == 0:
                    continue

                top_idx = sorted(range(len(top_returns)), key=lambda k: top_returns[k], reverse=True)[:self.n_top]
                p_returns = np.stack(p_returns)[top_idx]
                m_returns = np.stack(m_returns)[top_idx]

                # print(f"{i} : {self.model_list[i].policy.state_dict()}")
                # print(f" {i} : {self.W_flat_list[i]}")
                self.W_flat_list[i] = self.W_flat_list[i] + (
                    self.step_size / (self.n_delta * np.concatenate((p_returns, m_returns)).std() + 1e-6)
                ) * np.sum((p_returns - m_returns) * deltas[top_idx].T, axis=1)
                # print(f"{i} : {self.model_list[i].policy.state_dict()}")
                print(f" {i} : {self.W_flat_list[i]}")

        # if verbose and epoch % 10 == 0:
        #     print(f"{epoch} : mean return: {l_returns.mean()}, top_return: {l_returns.max()}, fps:{states.shape[0]/t}")
        # self.raw_rew_hist.append(np.stack(top_returns)[top_idx].mean())
        # self.r_hist.append((p_returns.mean() + m_returns.mean())/2)
        self.total_epochs += 1

    for q in master_q_list:
        q.put("STOP")
    for proc in proc_list:
        proc.join()

    # print(f" model 0 state dict before: {self.model_list[0].policy.state_dict()}")
    for i, _ in enumerate(self.model_list):
        print(f" model {i} w_flat: {self.W_flat_list[i]}")
        torch.nn.utils.vector_to_parameters(torch.tensor(self.W_flat_list[i]),
                                            self.model_list[i].policy.parameters())

    self.model = ARSSwitchingModel(self.model_list, self.classifier)
    # self.model.policy.state_means = torch.from_numpy(self.state_mean)
    # self.model.policy.state_std = torch.from_numpy(self.state_std)
    # torch.set_grad_enabled(True)

    return self.model, self.raw_rew_hist[learn_start_idx:], locals()
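# Illustrative sketch (standalone, not part of the original class): how the
# "Classifier Update" above turns per-perturbation winners into supervised training
# data for the switching classifier. Function name and shapes are assumptions for
# the example only.
import numpy as np

def build_classifier_data(top_returns_list, states_list):
    """For each perturbation index, label every visited state with the model that won it."""
    T = np.array(top_returns_list)   # (n_models, n_delta): best return per model and perturbation
    Y = np.argmax(T, axis=0)         # winning model index for each perturbation
    Xtrain, Ytrain = [], []
    for i, y in enumerate(Y):
        for x in states_list[y][i]:  # states collected by the winning model on perturbation i
            Xtrain.append(x)
            Ytrain.append(y)
    return np.array(Xtrain, dtype=np.float32), np.array(Ytrain)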
def learn(self, train_steps):
    """
    runs sac for train_steps

    Returns:
        model: trained model
        avg_reward_hist: list with the average reward per episode at each epoch
        var_dict: dictionary with all locals, for logging/debugging purposes
    """
    torch.set_num_threads(1)  # performance issue with data loader

    env = gym.make(self.env_name, **self.env_config)
    if isinstance(env.action_space, gym.spaces.Box):
        act_size = env.action_space.shape[0]
        act_dtype = env.action_space.sample().dtype
    else:
        raise NotImplementedError("trying to use unsupported action space", env.action_space)

    obs_size = env.observation_space.shape[0]

    random_model = RandModel(self.model.act_limit, act_size)
    self.replay_buf = ReplayBuffer(obs_size, act_size, self.replay_buf_size)
    self.target_value_fn = copy.deepcopy(self.model.value_fn)

    pol_opt = torch.optim.Adam(self.model.policy.parameters(), lr=self.sgd_lr)
    val_opt = torch.optim.Adam(self.model.value_fn.parameters(), lr=self.sgd_lr)
    q1_opt = torch.optim.Adam(self.model.q1_fn.parameters(), lr=self.sgd_lr)
    q2_opt = torch.optim.Adam(self.model.q2_fn.parameters(), lr=self.sgd_lr)

    if self.sgd_lr_sched:
        sgd_lookup = make_schedule(self.sgd_lr_sched, train_steps)
    else:
        sgd_lookup = None

    # seed all our RNGs
    env.seed(self.seed)
    torch.manual_seed(self.seed)
    np.random.seed(self.seed)

    # set defaults, and decide if we are using a GPU or not
    use_cuda = torch.cuda.is_available() and self.use_gpu
    device = torch.device("cuda:0" if use_cuda else "cpu")

    self.raw_rew_hist = []
    self.val_loss_hist = []
    self.pol_loss_hist = []
    self.q1_loss_hist = []
    self.q2_loss_hist = []

    progress_bar = tqdm.tqdm(total=train_steps + self.normalize_steps)
    cur_total_steps = 0
    progress_bar.update(0)
    early_stop = False

    # Gather observations with a random policy to initialize state normalization.
    norm_obs1 = torch.empty(0)
    while cur_total_steps < self.normalize_steps:
        ep_obs1, ep_obs2, ep_acts, ep_rews, ep_done = do_rollout(env, random_model, self.env_max_steps)
        norm_obs1 = torch.cat((norm_obs1, ep_obs1))

        ep_steps = ep_rews.shape[0]
        cur_total_steps += ep_steps
        progress_bar.update(ep_steps)

    if self.normalize_steps > 0:
        obs_mean = norm_obs1.mean(axis=0)
        obs_std = norm_obs1.std(axis=0)
        obs_std[torch.isinf(1 / obs_std)] = 1

        self.model.policy.state_means = obs_mean
        self.model.policy.state_std = obs_std
        self.model.value_fn.state_means = obs_mean
        self.model.value_fn.state_std = obs_std
        self.target_value_fn.state_means = obs_mean
        self.target_value_fn.state_std = obs_std

        self.model.q1_fn.state_means = torch.cat((obs_mean, torch.zeros(act_size)))
        self.model.q1_fn.state_std = torch.cat((obs_std, torch.ones(act_size)))
        self.model.q2_fn.state_means = self.model.q1_fn.state_means
        self.model.q2_fn.state_std = self.model.q1_fn.state_std

    # Fill the replay buffer with actions taken from the random model.
    while cur_total_steps < self.exploration_steps:
        ep_obs1, ep_obs2, ep_acts, ep_rews, ep_done = do_rollout(env, random_model, self.env_max_steps)
        self.replay_buf.store(ep_obs1, ep_obs2, ep_acts, ep_rews, ep_done)

        ep_steps = ep_rews.shape[0]
        cur_total_steps += ep_steps
        progress_bar.update(ep_steps)

    while cur_total_steps < train_steps:
        cur_batch_steps = 0

        # Bail out if we have met our reward threshold
        if len(self.raw_rew_hist) > 2 and self.reward_stop:
            if self.raw_rew_hist[-1] >= self.reward_stop and self.raw_rew_hist[-2] >= self.reward_stop:
                early_stop = True
                break

        # collect data with the current policy
        # ========================================================================
        while cur_batch_steps < self.min_steps_per_update:
            ep_obs1, ep_obs2, ep_acts, ep_rews, ep_done = do_rollout(env, self.model, self.env_max_steps)
            self.replay_buf.store(ep_obs1, ep_obs2, ep_acts, ep_rews, ep_done)

            ep_steps = ep_rews.shape[0]
            cur_batch_steps += ep_steps
            cur_total_steps += ep_steps

            self.raw_rew_hist.append(torch.sum(ep_rews))
            # print(self.raw_rew_hist[-1])

        progress_bar.update(cur_batch_steps)

        for _ in range(min(int(ep_steps), self.iters_per_update)):
            torch.autograd.set_grad_enabled(False)
            # compute targets for Q and V
            # ========================================================================
            replay_obs1, replay_obs2, replay_acts, replay_rews, replay_done = self.replay_buf.sample_batch(self.replay_batch_size)

            q_targ = replay_rews + self.gamma * (1 - replay_done) * self.target_value_fn(replay_obs2)

            noise = torch.randn(self.replay_batch_size, act_size)
            sample_acts, sample_logp = self.model.select_action(replay_obs1, noise)

            q_in = torch.cat((replay_obs1, sample_acts), dim=1)
            q_preds = torch.cat((self.model.q1_fn(q_in), self.model.q2_fn(q_in)), dim=1)
            q_min, q_min_idx = torch.min(q_preds, dim=1)
            q_min = q_min.reshape(-1, 1)

            v_targ = q_min - self.alpha * sample_logp

            torch.autograd.set_grad_enabled(True)

            # q_fn update
            # ========================================================================
            num_mbatch = int(self.replay_batch_size / self.sgd_batch_size)

            for i in range(num_mbatch):
                cur_sample = i * self.sgd_batch_size

                q_in = torch.cat((replay_obs1[cur_sample:cur_sample + self.sgd_batch_size],
                                  replay_acts[cur_sample:cur_sample + self.sgd_batch_size]), dim=1)
                q1_preds = self.model.q1_fn(q_in)
                q2_preds = self.model.q2_fn(q_in)
                q1_loss = torch.pow(q1_preds - q_targ[cur_sample:cur_sample + self.sgd_batch_size], 2).mean()
                q2_loss = torch.pow(q2_preds - q_targ[cur_sample:cur_sample + self.sgd_batch_size], 2).mean()
                q_loss = q1_loss + q2_loss

                q1_opt.zero_grad()
                q2_opt.zero_grad()
                q_loss.backward()
                q1_opt.step()
                q2_opt.step()

            # val_fn update
            # ========================================================================
            for i in range(num_mbatch):
                cur_sample = i * self.sgd_batch_size

                # predict and calculate loss for the batch
                val_preds = self.model.value_fn(replay_obs1[cur_sample:cur_sample + self.sgd_batch_size])
                val_loss = torch.pow(val_preds - v_targ[cur_sample:cur_sample + self.sgd_batch_size], 2).mean()

                # do the normal pytorch update
                val_opt.zero_grad()
                val_loss.backward()
                val_opt.step()

            # policy_fn update
            # ========================================================================
            for param in self.model.q1_fn.parameters():
                param.requires_grad = False

            for i in range(num_mbatch):
                cur_sample = i * self.sgd_batch_size

                noise = torch.randn(replay_obs1[cur_sample:cur_sample + self.sgd_batch_size].shape[0], act_size)
                local_acts, local_logp = self.model.select_action(replay_obs1[cur_sample:cur_sample + self.sgd_batch_size], noise)

                q_in = torch.cat((replay_obs1[cur_sample:cur_sample + self.sgd_batch_size], local_acts), dim=1)
                pol_loss = (self.alpha * local_logp - self.model.q1_fn(q_in)).mean()

                pol_opt.zero_grad()
                pol_loss.backward()
                pol_opt.step()

            for param in self.model.q1_fn.parameters():
                param.requires_grad = True

            # Update target value fn with polyak average
            # ========================================================================
            self.val_loss_hist.append(val_loss.item())
            self.pol_loss_hist.append(pol_loss.item())
            self.q1_loss_hist.append(q1_loss.item())
            self.q2_loss_hist.append(q2_loss.item())

            val_sd = self.model.value_fn.state_dict()
            tar_sd = self.target_value_fn.state_dict()
            for layer in tar_sd:
                tar_sd[layer] = self.polyak * tar_sd[layer] + (1 - self.polyak) * val_sd[layer]
            self.target_value_fn.load_state_dict(tar_sd)

            # Update LRs
            if sgd_lookup:
                pol_opt.param_groups[0]['lr'] = sgd_lookup(cur_total_steps)
                val_opt.param_groups[0]['lr'] = sgd_lookup(cur_total_steps)
                q1_opt.param_groups[0]['lr'] = sgd_lookup(cur_total_steps)
                q2_opt.param_groups[0]['lr'] = sgd_lookup(cur_total_steps)

    return self.model, self.raw_rew_hist, locals()
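# Illustrative sketch (standalone helper, not part of the original module): the soft
# "polyak" target update performed inline above, factored out for clarity. With
# polyak close to 1 the target network trails the learned value network slowly.
import torch

def polyak_update(src: torch.nn.Module, target: torch.nn.Module, polyak: float) -> None:
    """In-place soft update: target <- polyak * target + (1 - polyak) * src."""
    src_sd = src.state_dict()
    tar_sd = target.state_dict()
    for layer in tar_sd:
        tar_sd[layer] = polyak * tar_sd[layer] + (1 - polyak) * src_sd[layer]
    target.load_state_dict(tar_sd)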
def learn(self, n_epochs, verbose=True):
    torch.set_grad_enabled(False)

    proc_list = []
    master_q_list = []
    worker_q_list = []
    learn_start_idx = copy.copy(self.total_epochs)

    if self.step_schedule:
        step_lookup = make_schedule(self.step_schedule, n_epochs)
    if self.exp_schedule:
        exp_lookup = make_schedule(self.exp_schedule, n_epochs)

    for i in range(self.n_workers):
        master_q = Queue()
        worker_q = Queue()
        proc = Process(target=worker_fn,
                       args=(worker_q, master_q, self.algo, self.env_name,
                             self.postprocessor, self.get_trainable, self.seed))
        proc.start()
        proc_list.append(proc)
        master_q_list.append(master_q)
        worker_q_list.append(worker_q)

    n_param = self.W_flat.shape[0]
    rng = default_rng()

    for epoch in range(n_epochs):
        if self.step_schedule:
            self.step_size = step_lookup(epoch)
        if self.exp_schedule:
            self.exp_noise = exp_lookup(epoch)

        if len(self.raw_rew_hist) > 2 and self.reward_stop:
            if self.raw_rew_hist[-1] >= self.reward_stop and self.raw_rew_hist[-2] >= self.reward_stop:
                early_stop = True
                break

        deltas = rng.standard_normal((self.n_delta, n_param), dtype=np.float32)
        # import ipdb; ipdb.set_trace()
        pm_W = np.concatenate((self.W_flat + (deltas * self.exp_noise),
                               self.W_flat - (deltas * self.exp_noise)))

        start = time.time()
        seeds = np.random.randint(1, 2**32 - 1, self.n_delta)
        for i, Ws in enumerate(pm_W):
            # if self.epoch_seed:
            #     epoch_seed = i % self.n_delta
            # else:
            #     epoch_seed = None
            epoch_seed = int(seeds[i % self.n_delta])
            # epoch_seed = None
            master_q_list[i % self.n_workers].put((Ws, False, epoch_seed))

        results = []
        for i, _ in enumerate(pm_W):
            results.append(worker_q_list[i % self.n_workers].get())
        end = time.time()
        t = (end - start)

        p_returns = []
        m_returns = []
        l_returns = []
        top_returns = []
        for p_result, m_result in zip(results[:self.n_delta], results[self.n_delta:]):
            pr, plr = p_result
            mr, mlr = m_result

            p_returns.append(pr)
            m_returns.append(mr)
            l_returns.append(plr)
            l_returns.append(mlr)
            top_returns.append(max(pr, mr))

        top_idx = sorted(range(len(top_returns)), key=lambda k: top_returns[k], reverse=True)[:self.n_top]
        p_returns = np.stack(p_returns).astype(np.float32)[top_idx]
        m_returns = np.stack(m_returns).astype(np.float32)[top_idx]
        l_returns = np.stack(l_returns).astype(np.float32)[top_idx]

        self.raw_rew_hist.append(np.stack(top_returns)[top_idx].mean())
        self.r_hist.append((p_returns.mean() + m_returns.mean()) / 2)

        if verbose and epoch % 10 == 0:
            from seagul.zoo3_utils import do_rollout_stable

            env, model = load_zoo_agent(self.env_name, self.algo)
            torch.nn.utils.vector_to_parameters(torch.tensor(self.W_flat, requires_grad=False),
                                                self.get_trainable(self.model))
            o, a, r, info = do_rollout_stable(env, self.model)

            if type(o[0]) == collections.OrderedDict:
                o, _, _ = dict_to_array(o)

            # o_mdim = o[200:]
            o_mdim = o
            try:
                mdim, cdim, _, _ = mesh_dim(o_mdim)
            except:
                mdim = np.nan
                cdim = np.nan

            print(f"{epoch} : mean return: {self.raw_rew_hist[-1]}, top_return: {np.stack(top_returns)[top_idx][0]}, mdim: {mdim}, cdim: {cdim}, eps:{self.n_delta*2/t}")

        self.total_epochs += 1
        self.W_flat = self.W_flat + (
            self.step_size / (self.n_delta * np.concatenate((p_returns, m_returns)).std() + 1e-6)
        ) * np.sum((p_returns - m_returns) * deltas[top_idx].T, axis=1)

    for q in master_q_list:
        q.put((None, True, None))
    for proc in proc_list:
        proc.join()

    torch.nn.utils.vector_to_parameters(torch.tensor(self.W_flat, requires_grad=False),
                                        self.get_trainable(self.model))

    torch.set_grad_enabled(True)
    return self.model, self.raw_rew_hist[learn_start_idx:], locals()
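# Illustrative sketch (standalone, not part of the original module): the core ARS
# weight update performed at the end of each epoch above. Given the returns of the
# +delta and -delta rollouts, keep only the n_top best directions and step along the
# return-weighted sum of those deltas, scaled by the standard deviation of the
# surviving returns. The function name and argument layout are assumptions.
import numpy as np

def ars_update(W_flat, deltas, p_returns, m_returns, n_top, step_size):
    """One ARS step: deltas is (n_delta, n_param), p/m_returns are length-n_delta sequences."""
    top_idx = sorted(range(len(p_returns)),
                     key=lambda k: max(p_returns[k], m_returns[k]),
                     reverse=True)[:n_top]
    p = np.asarray(p_returns)[top_idx]
    m = np.asarray(m_returns)[top_idx]
    scale = step_size / (len(deltas) * np.concatenate((p, m)).std() + 1e-6)
    return W_flat + scale * np.sum((p - m) * deltas[top_idx].T, axis=1)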
def td3(
    env_name,
    train_steps,
    model,
    env_max_steps=0,
    min_steps_per_update=1,
    iters_per_update=200,
    replay_batch_size=64,
    seed=0,
    act_std_schedule=(.1,),
    gamma=0.95,
    polyak=0.995,
    sgd_batch_size=64,
    sgd_lr=3e-4,
    exploration_steps=1000,
    replay_buf_size=int(100000),
    reward_stop=None,
    env_config=None
):
    # Initialize env, and other globals
    # ========================================================================
    if env_config is None:
        env_config = {}

    env = gym.make(env_name, **env_config)
    if isinstance(env.action_space, gym.spaces.Box):
        act_size = env.action_space.shape[0]
        act_dtype = env.action_space.sample().dtype
    else:
        raise NotImplementedError("trying to use unsupported action space", env.action_space)

    obs_size = env.observation_space.shape[0]

    # seed all our RNGs
    env.seed(seed)
    torch.manual_seed(seed)
    np.random.seed(seed)

    random_model = RandModel(model.act_limit, act_size)
    replay_buf = ReplayBuffer(obs_size, act_size, replay_buf_size)

    target_q1_fn = dill.loads(dill.dumps(model.q1_fn))
    target_q2_fn = dill.loads(dill.dumps(model.q2_fn))
    target_policy = dill.loads(dill.dumps(model.policy))

    for param in target_q1_fn.parameters():
        param.requires_grad = False
    for param in target_q2_fn.parameters():
        param.requires_grad = False
    for param in target_policy.parameters():
        param.requires_grad = False

    act_std_lookup = make_schedule(act_std_schedule, train_steps)
    act_std = act_std_lookup(0)

    pol_opt = torch.optim.Adam(model.policy.parameters(), lr=sgd_lr)
    q1_opt = torch.optim.Adam(model.q1_fn.parameters(), lr=sgd_lr)
    q2_opt = torch.optim.Adam(model.q2_fn.parameters(), lr=sgd_lr)

    progress_bar = tqdm.tqdm(total=train_steps)
    cur_total_steps = 0
    progress_bar.update(0)
    early_stop = False

    raw_rew_hist = []
    pol_loss_hist = []
    q1_loss_hist = []
    q2_loss_hist = []

    # Fill the replay buffer with actions taken from a random model
    # ========================================================================
    while cur_total_steps < exploration_steps:
        ep_obs1, ep_obs2, ep_acts, ep_rews, ep_done = do_rollout(env, random_model, env_max_steps, act_std)
        replay_buf.store(ep_obs1, ep_obs2, ep_acts, ep_rews, ep_done)

        ep_steps = ep_rews.shape[0]
        cur_total_steps += ep_steps
        progress_bar.update(ep_steps)

    # Keep training until we take train_steps environment steps
    # ========================================================================
    while cur_total_steps < train_steps:
        cur_batch_steps = 0

        # Bail out if we have met our reward threshold
        if len(raw_rew_hist) > 2 and reward_stop:
            print(raw_rew_hist[-1])
            if raw_rew_hist[-1] >= reward_stop and raw_rew_hist[-2] >= reward_stop:
                early_stop = True
                break

        # collect data with the current policy
        # ========================================================================
        while cur_batch_steps < min_steps_per_update:
            ep_obs1, ep_obs2, ep_acts, ep_rews, ep_done = do_rollout(env, model, env_max_steps, act_std)
            replay_buf.store(ep_obs1, ep_obs2, ep_acts, ep_rews, ep_done)

            ep_steps = ep_rews.shape[0]
            cur_batch_steps += ep_steps
            cur_total_steps += ep_steps

            raw_rew_hist.append(torch.sum(ep_rews))

        progress_bar.update(cur_batch_steps)

        # Do the update
        # ========================================================================
        for _ in range(min(int(ep_steps), iters_per_update)):
            # Compute target Q
            replay_obs1, replay_obs2, replay_acts, replay_rews, replay_done = replay_buf.sample_batch(replay_batch_size)

            with torch.no_grad():
                acts_from_target = target_policy(replay_obs2)
                q_in = torch.cat((replay_obs2, acts_from_target), dim=1)
                q_targ = replay_rews + gamma * (1 - replay_done) * target_q1_fn(q_in)

            num_mbatch = int(replay_batch_size / sgd_batch_size)

            # q_fn update
            # ========================================================================
            for i in range(num_mbatch):
                cur_sample = i * sgd_batch_size

                q_in_local = torch.cat((replay_obs1[cur_sample:cur_sample + sgd_batch_size],
                                        replay_acts[cur_sample:cur_sample + sgd_batch_size]), dim=1)
                local_qtarg = q_targ[cur_sample:cur_sample + sgd_batch_size]

                q1_loss = ((model.q1_fn(q_in_local) - local_qtarg)**2).mean()
                # q2_preds = model.q2_fn(q_in_local)
                # q2_loss = ((q2_preds - local_qtarg)**2).mean()
                q_loss = q1_loss  # + q2_loss

                q1_opt.zero_grad()
                # q2_opt.zero_grad()
                q_loss.backward()
                q1_opt.step()
                # q2_opt.step()

            # policy_fn update
            # ========================================================================
            for param in model.q1_fn.parameters():
                param.requires_grad = False

            for i in range(num_mbatch):
                cur_sample = i * sgd_batch_size

                local_obs = replay_obs1[cur_sample:cur_sample + sgd_batch_size]
                local_acts = model.policy(local_obs)

                q_in = torch.cat((local_obs, local_acts), dim=1)
                pol_loss = -(model.q1_fn(q_in).mean())

                pol_opt.zero_grad()
                pol_loss.backward()
                pol_opt.step()

            for param in model.q1_fn.parameters():
                param.requires_grad = True

            # Update target networks with polyak average
            # ========================================================================
            pol_loss_hist.append(pol_loss.item())
            q1_loss_hist.append(q1_loss.item())
            # q2_loss_hist.append(q2_loss.item())

            target_q1_fn = update_target_fn(model.q1_fn, target_q1_fn, polyak)
            target_q2_fn = update_target_fn(model.q2_fn, target_q2_fn, polyak)
            target_policy = update_target_fn(model.policy, target_policy, polyak)

            act_std = act_std_lookup(cur_total_steps)

    return model, raw_rew_hist, locals()
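# The helper update_target_fn is not defined in this excerpt. Below is a minimal
# sketch of an implementation consistent with how it is called above (it must return
# the updated target network); the real helper may differ.
import torch

def update_target_fn(src: torch.nn.Module, target: torch.nn.Module, polyak: float) -> torch.nn.Module:
    """Return the target network after a soft update toward src."""
    with torch.no_grad():
        for t_param, s_param in zip(target.parameters(), src.parameters()):
            t_param.copy_(polyak * t_param + (1 - polyak) * s_param)
    return target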
def learn(self, n_epochs, verbose=True):
    torch.set_grad_enabled(False)

    proc_list = []
    master_q_list = []
    worker_q_list = []
    learn_start_idx = copy.copy(self.total_epochs)

    if self.step_schedule:
        step_lookup = make_schedule(self.step_schedule, n_epochs)
    if self.exp_schedule:
        exp_lookup = make_schedule(self.exp_schedule, n_epochs)

    for i in range(self.n_workers):
        master_q = Queue()
        worker_q = Queue()
        proc = Process(target=worker_fn,
                       args=(worker_q, master_q, self.model, self.env_name,
                             self.env_config, self.postprocessor, self.seed))
        proc.start()
        proc_list.append(proc)
        master_q_list.append(master_q)
        worker_q_list.append(worker_q)

    n_param = self.W_flat.shape[0]
    rng = default_rng()

    for epoch in range(n_epochs):
        if self.step_schedule:
            self.step_size = step_lookup(epoch)
        if self.exp_schedule:
            self.exp_noise = exp_lookup(epoch)

        if len(self.raw_rew_hist) > 2 and self.reward_stop:
            if self.raw_rew_hist[-1] >= self.reward_stop and self.raw_rew_hist[-2] >= self.reward_stop:
                early_stop = True
                break

        deltas = rng.standard_normal((self.n_delta, n_param))
        # import ipdb; ipdb.set_trace()
        pm_W = np.concatenate((self.W_flat + (deltas * self.exp_noise),
                               self.W_flat - (deltas * self.exp_noise)))

        start = time.time()
        for i, Ws in enumerate(pm_W):
            master_q_list[i % self.n_workers].put((Ws, self.state_mean, self.state_std))

        results = []
        for i, _ in enumerate(pm_W):
            results.append(worker_q_list[i % self.n_workers].get())
        end = time.time()
        t = (end - start)

        states = np.array([]).reshape(0, self.obs_size)
        p_returns = []
        m_returns = []
        l_returns = []
        top_returns = []
        for p_result, m_result in zip(results[:self.n_delta], results[self.n_delta:]):
            ps, pr, plr = p_result
            ms, mr, mlr = m_result

            states = np.concatenate((states, ms, ps), axis=0)
            p_returns.append(pr)
            m_returns.append(mr)
            l_returns.append(plr)
            l_returns.append(mlr)
            top_returns.append(max(pr, mr))

        top_idx = sorted(range(len(top_returns)), key=lambda k: top_returns[k], reverse=True)[:self.n_top]
        p_returns = np.stack(p_returns)[top_idx]
        m_returns = np.stack(m_returns)[top_idx]
        l_returns = np.stack(l_returns)[top_idx]

        if verbose and epoch % 10 == 0:
            print(f"{epoch} : mean return: {l_returns.mean()}, top_return: {l_returns.max()}, fps:{states.shape[0]/t}")

        self.raw_rew_hist.append(np.stack(top_returns)[top_idx].mean())
        self.r_hist.append((p_returns.mean() + m_returns.mean()) / 2)

        ep_steps = states.shape[0]
        self.state_mean = update_mean(states, self.state_mean, self.total_steps)
        self.state_std = update_std(states, self.state_std, self.total_steps)
        self.total_steps += ep_steps
        self.total_epochs += 1

        self.W_flat = self.W_flat + (
            self.step_size / (self.n_delta * np.concatenate((p_returns, m_returns)).std() + 1e-6)
        ) * np.sum((p_returns - m_returns) * deltas[top_idx].T, axis=1)

    for q in master_q_list:
        q.put("STOP")
    for proc in proc_list:
        proc.join()

    torch.nn.utils.vector_to_parameters(torch.tensor(self.W_flat), self.model.policy.parameters())
    self.model.policy.state_means = torch.from_numpy(self.state_mean)
    self.model.policy.state_std = torch.from_numpy(self.state_std)

    torch.set_grad_enabled(True)
    return self.model, self.raw_rew_hist[learn_start_idx:], locals()
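# Illustrative sketch (standalone, not part of the original module): the scatter/gather
# pattern the learn() above uses to farm rollouts out to worker processes. Each worker
# owns an input and output queue and jobs are dispatched round-robin; echo_worker is a
# stand-in for the real rollout worker (worker_fn), which is not shown in this excerpt.
from multiprocessing import Process, Queue

def echo_worker(in_q, out_q):
    """Toy worker: doubles each job it receives until it sees the STOP sentinel."""
    while True:
        msg = in_q.get()
        if msg == "STOP":
            break
        out_q.put(msg * 2)  # placeholder for "run a rollout and return its results"

if __name__ == "__main__":
    n_workers, jobs = 4, list(range(10))
    in_qs, out_qs, procs = [], [], []
    for _ in range(n_workers):
        in_q, out_q = Queue(), Queue()
        p = Process(target=echo_worker, args=(in_q, out_q))
        p.start()
        in_qs.append(in_q)
        out_qs.append(out_q)
        procs.append(p)

    # Scatter jobs round-robin, then gather results in the same order.
    for i, job in enumerate(jobs):
        in_qs[i % n_workers].put(job)
    results = [out_qs[i % n_workers].get() for i in range(len(jobs))]

    for q in in_qs:
        q.put("STOP")
    for p in procs:
        p.join()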
def learn(self, total_steps):
    """
    The actual training loop

    Returns:
        model: trained model
        avg_reward_hist: list with the average reward per episode at each epoch
        var_dict: dictionary with all locals, for logging/debugging purposes
    """
    # init everything
    # ==============================================================================
    # seed all our RNGs
    env = gym.make(self.env_name, **self.env_config)
    cur_total_steps = 0
    env.seed(self.seed)
    torch.manual_seed(self.seed)
    np.random.seed(self.seed)

    progress_bar = tqdm.tqdm(total=total_steps)
    lr_lookup = make_schedule(self.lr_schedule, total_steps)
    self.sgd_lr = lr_lookup(0)

    progress_bar.update(0)
    early_stop = False

    self.pol_opt = torch.optim.RMSprop(self.model.policy.parameters(), lr=lr_lookup(cur_total_steps))
    self.val_opt = torch.optim.RMSprop(self.model.value_fn.parameters(), lr=lr_lookup(cur_total_steps))

    # Train until we hit our total steps or reach our reward threshold
    # ==============================================================================
    while cur_total_steps < total_steps:
        batch_obs = torch.empty(0)
        batch_act = torch.empty(0)
        batch_adv = torch.empty(0)
        batch_discrew = torch.empty(0)
        cur_batch_steps = 0

        # Bail out if we have met our reward threshold
        if len(self.raw_rew_hist) > 2 and self.reward_stop:
            if self.raw_rew_hist[-1] >= self.reward_stop and self.raw_rew_hist[-2] >= self.reward_stop:
                early_stop = True
                break

        # construct batch data from rollouts
        # ==============================================================================
        while cur_batch_steps < self.epoch_batch_size:
            ep_obs, ep_act, ep_rew, ep_steps, ep_term = do_rollout(env, self.model, self.env_no_term_steps)

            cur_batch_steps += ep_steps
            cur_total_steps += ep_steps

            # print(sum(ep_rew).item())
            self.raw_rew_hist.append(sum(ep_rew).item())
            # print("Rew:", sum(ep_rew).item())

            batch_obs = torch.cat((batch_obs, ep_obs.clone()))
            batch_act = torch.cat((batch_act, ep_act.clone()))

            if self.normalize_return:
                self.rew_std = update_std(ep_rew, self.rew_std, cur_total_steps)
                ep_rew = ep_rew / (self.rew_std + 1e-6)

            # Bootstrap the final reward with the value function unless the episode terminated on its own.
            if ep_term:
                ep_rew = torch.cat((ep_rew, torch.zeros(1, 1)))
            else:
                ep_rew = torch.cat((ep_rew, self.model.value_fn(ep_obs[-1]).detach().reshape(1, 1).clone()))

            ep_discrew = discount_cumsum(ep_rew, self.gamma)[:-1]
            batch_discrew = torch.cat((batch_discrew, ep_discrew.clone()))

            with torch.no_grad():
                ep_val = torch.cat((self.model.value_fn(ep_obs), ep_rew[-1].reshape(1, 1).clone()))
                deltas = ep_rew[:-1] + self.gamma * ep_val[1:] - ep_val[:-1]
                ep_adv = discount_cumsum(deltas, self.gamma * self.lam)

            batch_adv = torch.cat((batch_adv, ep_adv.clone()))

        # PostProcess epoch and update weights
        # ==============================================================================
        # make sure our advantages are zero mean and unit variance
        if self.normalize_adv:
            # adv_mean = update_mean(batch_adv, adv_mean, cur_total_steps)
            # adv_var = update_std(batch_adv, adv_var, cur_total_steps)
            batch_adv = (batch_adv - batch_adv.mean()) / (batch_adv.std() + 1e-6)

        # Update the policy using the PPO loss
        for pol_epoch in range(self.sgd_epochs):
            pol_loss, approx_kl = self.policy_update(batch_act, batch_obs, batch_adv)
            if approx_kl > self.target_kl:
                print("KL Stop")
                break

        for val_epoch in range(self.sgd_epochs):
            val_loss = self.value_update(batch_obs, batch_discrew)

        # update observation mean and variance
        if self.normalize_obs:
            self.obs_mean = update_mean(batch_obs, self.obs_mean, cur_total_steps)
            self.obs_std = update_std(batch_obs, self.obs_std, cur_total_steps)
            self.model.policy.state_means = self.obs_mean
            self.model.value_fn.state_means = self.obs_mean
            self.model.policy.state_std = self.obs_std
            self.model.value_fn.state_std = self.obs_std

        sgd_lr = lr_lookup(cur_total_steps)

        self.old_model = copy.deepcopy(self.model)
        self.val_loss_hist.append(val_loss.detach())
        self.pol_loss_hist.append(pol_loss.detach())
        self.lrp_hist.append(self.pol_opt.state_dict()['param_groups'][0]['lr'])
        self.lrv_hist.append(self.val_opt.state_dict()['param_groups'][0]['lr'])
        self.kl_hist.append(approx_kl.detach())
        self.entropy_hist.append(self.model.policy.logstds.detach())

        progress_bar.update(cur_batch_steps)

    progress_bar.close()
    return self.model, self.raw_rew_hist, locals()
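# Illustrative sketch (standalone, not part of the original module): the advantage
# estimation used above, written against plain numpy. The source's discount_cumsum
# helper is assumed to behave like the function below.
import numpy as np

def discount_cumsum_np(x, discount):
    """Reverse discounted cumulative sum: out[t] = sum_k discount**k * x[t + k]."""
    out = np.zeros_like(x, dtype=np.float64)
    running = 0.0
    for t in reversed(range(len(x))):
        running = x[t] + discount * running
        out[t] = running
    return out

def gae_advantages(rews, vals, gamma=0.99, lam=0.95):
    """Generalized Advantage Estimation; vals has one more entry than rews (the bootstrap value)."""
    deltas = rews + gamma * vals[1:] - vals[:-1]
    return discount_cumsum_np(deltas, gamma * lam)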
def ppo_dim(
    env_name,
    total_steps,
    model,
    transient_length=50,
    act_std_schedule=(0.7,),
    epoch_batch_size=2048,
    gamma=0.99,
    lam=0.95,
    eps=0.2,
    seed=0,
    entropy_coef=0.0,
    sgd_batch_size=1024,
    lr_schedule=(3e-4,),
    sgd_epochs=10,
    target_kl=float('inf'),
    val_coef=.5,
    clip_val=True,
    env_no_term_steps=0,
    use_gpu=False,
    reward_stop=None,
    normalize_return=True,
    normalize_obs=True,
    normalize_adv=True,
    env_config={}
):
    """
    Implements proximal policy optimization with clipping

    Args:
        env_name: name of the openAI gym environment to solve
        total_steps: number of timesteps to run the PPO for
        model: model from seagul.rl.models. Contains policy and value fn
        transient_length: number of initial observations to drop before computing the dimension used to scale rewards
        act_std_schedule: schedule to set the variance of the policy. Will linearly interpolate values
        epoch_batch_size: number of environment steps to take per batch, total steps will be num_epochs*epoch_batch_size
        gamma: discount applied to future rewards, usually close to 1
        lam: lambda for the Advantage estimation, usually close to 1
        eps: epsilon for the clipping, usually .1 or .2
        seed: seed for all the rngs
        entropy_coef: coefficient for the entropy bonus in the policy loss
        sgd_batch_size: minibatch size for the policy and value function updates
        lr_schedule: learning rate schedule for both optimizers
        sgd_epochs: how many epochs to use for each policy and value update
        target_kl: max KL before breaking
        val_coef: coefficient for the value function loss
        clip_val: whether to clip the value function loss
        env_no_term_steps: step count at which do_rollout cuts an episode off
        use_gpu: want to use the GPU? set to true
        reward_stop: reward value to stop if we achieve
        normalize_return: should we normalize the return?
        normalize_obs: should we normalize the observations?
        normalize_adv: should we normalize the advantages?
        env_config: dictionary containing kwargs to pass to the environment

    Returns:
        model: trained model
        avg_reward_hist: list with the average reward per episode at each epoch
        var_dict: dictionary with all locals, for logging/debugging purposes

    Example:
        from seagul.rl.algos import ppo
        from seagul.nn import MLP
        from seagul.rl.models import PPOModel
        import torch

        input_size = 3
        output_size = 1
        layer_size = 64
        num_layers = 2

        policy = MLP(input_size, output_size, num_layers, layer_size)
        value_fn = MLP(input_size, 1, num_layers, layer_size)
        model = PPOModel(policy, value_fn)

        model, rews, var_dict = ppo("Pendulum-v0", 10000, model)
    """
    # init everything
    # ==============================================================================
    torch.set_num_threads(1)

    env = gym.make(env_name, **env_config)
    if isinstance(env.action_space, gym.spaces.Box):
        act_size = env.action_space.shape[0]
        act_dtype = torch.double
    else:
        raise NotImplementedError("trying to use unsupported action space", env.action_space)

    actstd_lookup = make_schedule(act_std_schedule, total_steps)
    lr_lookup = make_schedule(lr_schedule, total_steps)

    model.action_var = actstd_lookup(0)
    sgd_lr = lr_lookup(0)

    obs_size = env.observation_space.shape[0]
    obs_mean = torch.zeros(obs_size)
    obs_std = torch.ones(obs_size)
    rew_mean = torch.zeros(1)
    rew_std = torch.ones(1)

    # copy.deepcopy broke for me with older versions of torch. Using pickle for this is weird but works fine.
    old_model = pickle.loads(pickle.dumps(model))

    # seed all our RNGs
    env.seed(seed)
    torch.manual_seed(seed)
    np.random.seed(seed)

    # set defaults, and decide if we are using a GPU or not
    use_cuda = torch.cuda.is_available() and use_gpu
    device = torch.device("cuda:0" if use_cuda else "cpu")

    # init logging stuff
    raw_rew_hist = []
    val_loss_hist = []
    pol_loss_hist = []
    progress_bar = tqdm.tqdm(total=total_steps)
    cur_total_steps = 0
    progress_bar.update(0)
    early_stop = False

    # Train until we hit our total steps or reach our reward threshold
    # ==============================================================================
    while cur_total_steps < total_steps:
        pol_opt = torch.optim.Adam(model.policy.parameters(), lr=sgd_lr)
        val_opt = torch.optim.Adam(model.value_fn.parameters(), lr=sgd_lr)

        batch_obs = torch.empty(0)
        batch_act = torch.empty(0)
        batch_adv = torch.empty(0)
        batch_discrew = torch.empty(0)
        cur_batch_steps = 0

        # Bail out if we have met our reward threshold
        if len(raw_rew_hist) > 2 and reward_stop:
            if raw_rew_hist[-1] >= reward_stop and raw_rew_hist[-2] >= reward_stop:
                early_stop = True
                break

        # construct batch data from rollouts
        # ==============================================================================
        while cur_batch_steps < epoch_batch_size:
            ep_obs, ep_act, ep_rew, ep_steps, ep_term = do_rollout(env, model, env_no_term_steps)

            # Scale rewards by the measured dimension of the observed trajectory (the "dim" in ppo_dim).
            ep_rew /= var_dim(ep_obs[transient_length:], order=1)

            raw_rew_hist.append(sum(ep_rew).item())

            batch_obs = torch.cat((batch_obs, ep_obs[:-1]))
            batch_act = torch.cat((batch_act, ep_act[:-1]))

            if not ep_term:
                ep_rew[-1] = model.value_fn(ep_obs[-1]).detach()

            ep_discrew = discount_cumsum(ep_rew, gamma)

            if normalize_return:
                rew_mean = update_mean(batch_discrew, rew_mean, cur_total_steps)
                rew_std = update_std(ep_discrew, rew_std, cur_total_steps)
                ep_discrew = ep_discrew / (rew_std + 1e-6)

            batch_discrew = torch.cat((batch_discrew, ep_discrew[:-1]))

            ep_val = model.value_fn(ep_obs)
            deltas = ep_rew[:-1] + gamma * ep_val[1:] - ep_val[:-1]
            ep_adv = discount_cumsum(deltas.detach(), gamma * lam)
            batch_adv = torch.cat((batch_adv, ep_adv))

            cur_batch_steps += ep_steps
            cur_total_steps += ep_steps

        # make sure our advantages are zero mean and unit variance
        if normalize_adv:
            # adv_mean = update_mean(batch_adv, adv_mean, cur_total_steps)
            # adv_var = update_std(batch_adv, adv_var, cur_total_steps)
            batch_adv = (batch_adv - batch_adv.mean()) / (batch_adv.std() + 1e-6)

        num_mbatch = int(batch_obs.shape[0] / sgd_batch_size)

        # Update the policy using the PPO loss
        for pol_epoch in range(sgd_epochs):
            for i in range(num_mbatch):
                # policy update
                # ========================================================================
                cur_sample = i * sgd_batch_size

                # Transfer to GPU (if GPU is enabled, else this does nothing)
                local_obs = batch_obs[cur_sample:cur_sample + sgd_batch_size]
                local_act = batch_act[cur_sample:cur_sample + sgd_batch_size]
                local_adv = batch_adv[cur_sample:cur_sample + sgd_batch_size]
                local_val = batch_discrew[cur_sample:cur_sample + sgd_batch_size]

                # Compute the loss
                logp = model.get_logp(local_obs, local_act).reshape(-1, act_size)
                old_logp = old_model.get_logp(local_obs, local_act).reshape(-1, act_size)
                mean_entropy = -(logp * torch.exp(logp)).mean()

                r = torch.exp(logp - old_logp)
                clip_r = torch.clamp(r, 1 - eps, 1 + eps)
                pol_loss = -torch.min(r * local_adv, clip_r * local_adv).mean() - entropy_coef * mean_entropy

                approx_kl = ((logp - old_logp)**2).mean()
                if approx_kl > target_kl:
                    break

                pol_opt.zero_grad()
                pol_loss.backward()
                pol_opt.step()

                # value_fn update
                # ========================================================================
                val_preds = model.value_fn(local_obs)
                if clip_val:
                    old_val_preds = old_model.value_fn(local_obs)
                    val_preds_clipped = old_val_preds + torch.clamp(val_preds - old_val_preds, -eps, eps)
                    val_loss1 = (val_preds_clipped - local_val)**2
                    val_loss2 = (val_preds - local_val)**2
                    val_loss = val_coef * torch.max(val_loss1, val_loss2).mean()
                else:
                    val_loss = val_coef * ((val_preds - local_val)**2).mean()

                val_opt.zero_grad()
                val_loss.backward()
                val_opt.step()

        # update observation mean and variance
        if normalize_obs:
            obs_mean = update_mean(batch_obs, obs_mean, cur_total_steps)
            obs_std = update_std(batch_obs, obs_std, cur_total_steps)
            model.policy.state_means = obs_mean
            model.value_fn.state_means = obs_mean
            model.policy.state_std = obs_std
            model.value_fn.state_std = obs_std

        model.action_std = actstd_lookup(cur_total_steps)
        sgd_lr = lr_lookup(cur_total_steps)

        old_model = pickle.loads(pickle.dumps(model))
        val_loss_hist.append(val_loss)
        pol_loss_hist.append(pol_loss)

        progress_bar.update(cur_batch_steps)

    progress_bar.close()
    return model, raw_rew_hist, locals()
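# Illustrative sketch (standalone, not part of the original module): the clipped PPO
# surrogate used in the policy update above, for a single minibatch of precomputed
# log-probabilities and advantages.
import torch

def ppo_clip_loss(logp, old_logp, adv, eps=0.2, entropy_coef=0.0):
    """Negative clipped surrogate objective plus an optional entropy bonus."""
    ratio = torch.exp(logp - old_logp)
    clipped = torch.clamp(ratio, 1 - eps, 1 + eps)
    surrogate = torch.min(ratio * adv, clipped * adv).mean()
    entropy = -(logp * torch.exp(logp)).mean()
    return -surrogate - entropy_coef * entropy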