class SACAgent:
    def __init__(
        self,
        env_name,
        model,
        env_max_steps=0,
        min_steps_per_update=1,
        iters_per_update=100,
        replay_batch_size=64,
        seed=0,
        gamma=0.95,
        polyak=0.995,
        alpha=0.2,
        sgd_batch_size=64,
        sgd_lr=1e-3,
        exploration_steps=100,
        replay_buf_size=int(100000),
        normalize_steps=1000,
        use_gpu=False,
        reward_stop=None,
        env_config={},
        sgd_lr_sched=None,
    ):
        """
        Implements soft actor critic

        Args:
            env_name: name of the openAI gym environment to solve
            model: model from seagul.rl.models. Contains policy, value fn, q1_fn, q2_fn
            min_steps_per_update: minimum number of steps to take before running updates, will finish episodes before updating
            env_max_steps: number of steps the environment takes before finishing; if the environment emits a done signal before this we consider it a failure.
            iters_per_update: how many update steps to make every time we update
            replay_batch_size: how big a batch to pull from the replay buffer for each update
            seed: random seed for all rngs
            gamma: discount applied to future rewards, usually close to 1
            polyak: term determining how fast the target network is copied from the value function
            alpha: weighting term for the entropy. 0 corresponds to no penalty for a deterministic policy
            sgd_batch_size: minibatch size for the policy, value, and q function updates
            sgd_lr: initial learning rate for the policy, value, and q function optimizers
            exploration_steps: initial number of random actions to take, aids exploration
            replay_buf_size: how big of a replay buffer to use
            normalize_steps: number of steps to take with a random policy to estimate observation statistics
            use_gpu: determines if we try to use a GPU or not
            reward_stop: reward value to bail at
            env_config: dictionary containing kwargs to pass to the environment
            sgd_lr_sched: list of sgd_lrs to interpolate between as training goes on
        """
        self.env_name = env_name
        self.model = model
        self.env_max_steps = env_max_steps
        self.min_steps_per_update = min_steps_per_update
        self.iters_per_update = iters_per_update
        self.replay_batch_size = replay_batch_size
        self.seed = seed
        self.gamma = gamma
        self.polyak = polyak
        self.alpha = alpha
        self.sgd_batch_size = sgd_batch_size
        self.sgd_lr = sgd_lr
        self.exploration_steps = exploration_steps
        self.replay_buf_size = replay_buf_size
        self.normalize_steps = normalize_steps
        self.use_gpu = use_gpu
        self.reward_stop = reward_stop
        self.env_config = env_config
        self.sgd_lr_sched = sgd_lr_sched

    def learn(self, train_steps):
        """
        Runs SAC for train_steps

        Returns:
            model: trained model
            avg_reward_hist: list with the average reward per episode at each epoch
            var_dict: dictionary with all locals, for logging/debugging purposes
        """
        torch.set_num_threads(1)  # performance issue with data loader

        env = gym.make(self.env_name, **self.env_config)
        if isinstance(env.action_space, gym.spaces.Box):
            act_size = env.action_space.shape[0]
            act_dtype = env.action_space.sample().dtype
        else:
            raise NotImplementedError("trying to use unsupported action space", env.action_space)

        obs_size = env.observation_space.shape[0]

        random_model = RandModel(self.model.act_limit, act_size)
        self.replay_buf = ReplayBuffer(obs_size, act_size, self.replay_buf_size)
        self.target_value_fn = copy.deepcopy(self.model.value_fn)

        pol_opt = torch.optim.Adam(self.model.policy.parameters(), lr=self.sgd_lr)
        val_opt = torch.optim.Adam(self.model.value_fn.parameters(), lr=self.sgd_lr)
        q1_opt = torch.optim.Adam(self.model.q1_fn.parameters(), lr=self.sgd_lr)
        q2_opt = torch.optim.Adam(self.model.q2_fn.parameters(), lr=self.sgd_lr)

        if self.sgd_lr_sched:
            sgd_lookup = make_schedule(self.sgd_lr_sched, train_steps)
        else:
            sgd_lookup = None

        # seed all our RNGs
        env.seed(self.seed)
        torch.manual_seed(self.seed)
        np.random.seed(self.seed)

        # set defaults, and decide if we are using a GPU or not
        use_cuda = torch.cuda.is_available() and self.use_gpu
        device = torch.device("cuda:0" if use_cuda else "cpu")

        self.raw_rew_hist = []
        self.val_loss_hist = []
        self.pol_loss_hist = []
        self.q1_loss_hist = []
        self.q2_loss_hist = []

        progress_bar = tqdm.tqdm(total=train_steps + self.normalize_steps)
        cur_total_steps = 0
        progress_bar.update(0)
        early_stop = False
        norm_obs1 = torch.empty(0)

        while cur_total_steps < self.normalize_steps:
            ep_obs1, ep_obs2, ep_acts, ep_rews, ep_done = do_rollout(env, random_model, self.env_max_steps)
            norm_obs1 = torch.cat((norm_obs1, ep_obs1))

            ep_steps = ep_rews.shape[0]
            cur_total_steps += ep_steps

            progress_bar.update(ep_steps)

        if self.normalize_steps > 0:
            obs_mean = norm_obs1.mean(axis=0)
            obs_std = norm_obs1.std(axis=0)
            obs_std[torch.isinf(1 / obs_std)] = 1

            self.model.policy.state_means = obs_mean
            self.model.policy.state_std = obs_std
            self.model.value_fn.state_means = obs_mean
            self.model.value_fn.state_std = obs_std
            self.target_value_fn.state_means = obs_mean
            self.target_value_fn.state_std = obs_std

            self.model.q1_fn.state_means = torch.cat((obs_mean, torch.zeros(act_size)))
            self.model.q1_fn.state_std = torch.cat((obs_std, torch.ones(act_size)))
            self.model.q2_fn.state_means = self.model.q1_fn.state_means
            self.model.q2_fn.state_std = self.model.q1_fn.state_std

        while cur_total_steps < self.exploration_steps:
            ep_obs1, ep_obs2, ep_acts, ep_rews, ep_done = do_rollout(env, random_model, self.env_max_steps)
            self.replay_buf.store(ep_obs1, ep_obs2, ep_acts, ep_rews, ep_done)

            ep_steps = ep_rews.shape[0]
            cur_total_steps += ep_steps

            progress_bar.update(ep_steps)

        while cur_total_steps < train_steps:
            cur_batch_steps = 0

            # Bail out if we have met our reward threshold
            if len(self.raw_rew_hist) > 2 and self.reward_stop:
                if self.raw_rew_hist[-1] >= self.reward_stop and self.raw_rew_hist[-2] >= self.reward_stop:
                    early_stop = True
                    break

            # collect data with the current policy
            # ========================================================================
            while cur_batch_steps < self.min_steps_per_update:
                ep_obs1, ep_obs2, ep_acts, ep_rews, ep_done = do_rollout(env, self.model, self.env_max_steps)
                self.replay_buf.store(ep_obs1, ep_obs2, ep_acts, ep_rews, ep_done)

                ep_steps = ep_rews.shape[0]
                cur_batch_steps += ep_steps
                cur_total_steps += ep_steps

                self.raw_rew_hist.append(torch.sum(ep_rews))
                # print(self.raw_rew_hist[-1])

            progress_bar.update(cur_batch_steps)

            for _ in range(min(int(ep_steps), self.iters_per_update)):
                torch.autograd.set_grad_enabled(False)

                # compute targets for Q and V
                # ========================================================================
                replay_obs1, replay_obs2, replay_acts, replay_rews, replay_done = self.replay_buf.sample_batch(self.replay_batch_size)

                q_targ = replay_rews + self.gamma * (1 - replay_done) * self.target_value_fn(replay_obs2)

                noise = torch.randn(self.replay_batch_size, act_size)
                sample_acts, sample_logp = self.model.select_action(replay_obs1, noise)

                q_in = torch.cat((replay_obs1, sample_acts), dim=1)
                q_preds = torch.cat((self.model.q1_fn(q_in), self.model.q2_fn(q_in)), dim=1)
                q_min, q_min_idx = torch.min(q_preds, dim=1)
                q_min = q_min.reshape(-1, 1)

                v_targ = q_min - self.alpha * sample_logp

                torch.autograd.set_grad_enabled(True)

                # q_fn update
                # ========================================================================
                num_mbatch = int(self.replay_batch_size / self.sgd_batch_size)

                for i in range(num_mbatch):
                    cur_sample = i * self.sgd_batch_size

                    q_in = torch.cat((replay_obs1[cur_sample:cur_sample + self.sgd_batch_size], replay_acts[cur_sample:cur_sample + self.sgd_batch_size]), dim=1)
                    q1_preds = self.model.q1_fn(q_in)
                    q2_preds = self.model.q2_fn(q_in)
                    q1_loss = torch.pow(q1_preds - q_targ[cur_sample:cur_sample + self.sgd_batch_size], 2).mean()
                    q2_loss = torch.pow(q2_preds - q_targ[cur_sample:cur_sample + self.sgd_batch_size], 2).mean()
                    q_loss = q1_loss + q2_loss

                    q1_opt.zero_grad()
                    q2_opt.zero_grad()
                    q_loss.backward()
                    q1_opt.step()
                    q2_opt.step()

                # val_fn update
                # ========================================================================
                for i in range(num_mbatch):
                    cur_sample = i * self.sgd_batch_size

                    # predict and calculate loss for the batch
                    val_preds = self.model.value_fn(replay_obs1[cur_sample:cur_sample + self.sgd_batch_size])
                    val_loss = torch.pow(val_preds - v_targ[cur_sample:cur_sample + self.sgd_batch_size], 2).mean()

                    # do the normal pytorch update
                    val_opt.zero_grad()
                    val_loss.backward()
                    val_opt.step()

                # policy_fn update
                # ========================================================================
                for param in self.model.q1_fn.parameters():
                    param.requires_grad = False

                for i in range(num_mbatch):
                    cur_sample = i * self.sgd_batch_size

                    noise = torch.randn(replay_obs1[cur_sample:cur_sample + self.sgd_batch_size].shape[0], act_size)
                    local_acts, local_logp = self.model.select_action(replay_obs1[cur_sample:cur_sample + self.sgd_batch_size], noise)

                    q_in = torch.cat((replay_obs1[cur_sample:cur_sample + self.sgd_batch_size], local_acts), dim=1)
                    pol_loss = (self.alpha * local_logp - self.model.q1_fn(q_in)).mean()

                    pol_opt.zero_grad()
                    pol_loss.backward()
                    pol_opt.step()

                for param in self.model.q1_fn.parameters():
                    param.requires_grad = True

                # Update target value fn with polyak average
                # ========================================================================
                self.val_loss_hist.append(val_loss.item())
                self.pol_loss_hist.append(pol_loss.item())
                self.q1_loss_hist.append(q1_loss.item())
                self.q2_loss_hist.append(q2_loss.item())

                val_sd = self.model.value_fn.state_dict()
                tar_sd = self.target_value_fn.state_dict()
                for layer in tar_sd:
                    tar_sd[layer] = self.polyak * tar_sd[layer] + (1 - self.polyak) * val_sd[layer]
                self.target_value_fn.load_state_dict(tar_sd)

            # Update LRs
            if sgd_lookup:
                pol_opt.param_groups[0]['lr'] = sgd_lookup(cur_total_steps)
                val_opt.param_groups[0]['lr'] = sgd_lookup(cur_total_steps)
                q1_opt.param_groups[0]['lr'] = sgd_lookup(cur_total_steps)
                q2_opt.param_groups[0]['lr'] = sgd_lookup(cur_total_steps)

        return self.model, self.raw_rew_hist, locals()
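

# ============================================================================
# Usage sketch for SACAgent (illustration only, not part of the original file).
# It reuses the MLP/SACModel construction shown in the sac() docstring below;
# adjust the constructors if the real seagul classes differ.
# ============================================================================
def _sac_agent_usage_example():
    import torch.nn as nn
    from seagul.nn import MLP
    from seagul.rl.models import SACModel

    input_size = 3    # observation size for Pendulum-v0
    output_size = 1   # action size for Pendulum-v0
    layer_size = 64
    num_layers = 2
    activation = nn.ReLU

    policy = MLP(input_size, output_size * 2, num_layers, layer_size, activation)
    value_fn = MLP(input_size, 1, num_layers, layer_size, activation)
    q1_fn = MLP(input_size + output_size, 1, num_layers, layer_size, activation)
    q2_fn = MLP(input_size + output_size, 1, num_layers, layer_size, activation)
    model = SACModel(policy, value_fn, q1_fn, q2_fn, 1)

    agent = SACAgent("Pendulum-v0", model)
    trained_model, rew_hist, var_dict = agent.learn(10000)
    return trained_model, rew_hist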


def td3(
    env_name,
    train_steps,
    model,
    env_max_steps=0,
    min_steps_per_update=1,
    iters_per_update=200,
    replay_batch_size=64,
    seed=0,
    act_std_schedule=(.1,),
    gamma=0.95,
    polyak=0.995,
    sgd_batch_size=64,
    sgd_lr=3e-4,
    exploration_steps=1000,
    replay_buf_size=int(100000),
    reward_stop=None,
    env_config=None
):
    """
    TD3-style training loop (currently updates a single Q function; the q2 updates are commented out below).

    Arguments mirror sac() below: env_name/train_steps/model specify the task and the networks
    (policy, q1_fn, q2_fn), act_std_schedule sets the exploration noise std over training, and the
    remaining args control batch sizes, the learning rate, and the replay buffer.

    Returns:
        model: trained model
        raw_rew_hist: list with the total reward per episode
        var_dict: dictionary with all locals, for logging/debugging purposes
    """
    # Initialize env, and other globals
    # ========================================================================
    if env_config is None:
        env_config = {}

    env = gym.make(env_name, **env_config)
    if isinstance(env.action_space, gym.spaces.Box):
        act_size = env.action_space.shape[0]
        act_dtype = env.action_space.sample().dtype
    else:
        raise NotImplementedError("trying to use unsupported action space", env.action_space)

    obs_size = env.observation_space.shape[0]

    # seed all our RNGs
    env.seed(seed)
    torch.manual_seed(seed)
    np.random.seed(seed)

    random_model = RandModel(model.act_limit, act_size)
    replay_buf = ReplayBuffer(obs_size, act_size, replay_buf_size)

    target_q1_fn = dill.loads(dill.dumps(model.q1_fn))
    target_q2_fn = dill.loads(dill.dumps(model.q2_fn))
    target_policy = dill.loads(dill.dumps(model.policy))

    for param in target_q1_fn.parameters():
        param.requires_grad = False
    for param in target_q2_fn.parameters():
        param.requires_grad = False
    for param in target_policy.parameters():
        param.requires_grad = False

    act_std_lookup = make_schedule(act_std_schedule, train_steps)
    act_std = act_std_lookup(0)

    pol_opt = torch.optim.Adam(model.policy.parameters(), lr=sgd_lr)
    q1_opt = torch.optim.Adam(model.q1_fn.parameters(), lr=sgd_lr)
    q2_opt = torch.optim.Adam(model.q2_fn.parameters(), lr=sgd_lr)

    progress_bar = tqdm.tqdm(total=train_steps)
    cur_total_steps = 0
    progress_bar.update(0)
    early_stop = False

    raw_rew_hist = []
    pol_loss_hist = []
    q1_loss_hist = []
    q2_loss_hist = []

    # Fill the replay buffer with actions taken from a random model
    # ========================================================================
    while cur_total_steps < exploration_steps:
        ep_obs1, ep_obs2, ep_acts, ep_rews, ep_done = do_rollout(env, random_model, env_max_steps, act_std)
        replay_buf.store(ep_obs1, ep_obs2, ep_acts, ep_rews, ep_done)

        ep_steps = ep_rews.shape[0]
        cur_total_steps += ep_steps

        progress_bar.update(ep_steps)

    # Keep training until we take train_steps environment steps
    # ========================================================================
    while cur_total_steps < train_steps:
        cur_batch_steps = 0

        # Bail out if we have met our reward threshold
        if len(raw_rew_hist) > 2 and reward_stop:
            print(raw_rew_hist[-1])
            if raw_rew_hist[-1] >= reward_stop and raw_rew_hist[-2] >= reward_stop:
                early_stop = True
                break

        # collect data with the current policy
        # ========================================================================
        while cur_batch_steps < min_steps_per_update:
            ep_obs1, ep_obs2, ep_acts, ep_rews, ep_done = do_rollout(env, model, env_max_steps, act_std)
            replay_buf.store(ep_obs1, ep_obs2, ep_acts, ep_rews, ep_done)

            ep_steps = ep_rews.shape[0]
            cur_batch_steps += ep_steps
            cur_total_steps += ep_steps

            raw_rew_hist.append(torch.sum(ep_rews))

        progress_bar.update(cur_batch_steps)

        # Do the update
        # ========================================================================
        for _ in range(min(int(ep_steps), iters_per_update)):
            # Compute target Q
            replay_obs1, replay_obs2, replay_acts, replay_rews, replay_done = replay_buf.sample_batch(replay_batch_size)

            with torch.no_grad():
                acts_from_target = target_policy(replay_obs2)
                q_in = torch.cat((replay_obs2, acts_from_target), dim=1)
                q_targ = replay_rews + gamma * (1 - replay_done) * target_q1_fn(q_in)

            num_mbatch = int(replay_batch_size / sgd_batch_size)

            # q_fn update
            # ========================================================================
            for i in range(num_mbatch):
                cur_sample = i * sgd_batch_size

                q_in_local = torch.cat((replay_obs1[cur_sample:cur_sample + sgd_batch_size], replay_acts[cur_sample:cur_sample + sgd_batch_size]), dim=1)
                local_qtarg = q_targ[cur_sample:cur_sample + sgd_batch_size]

                q1_loss = ((model.q1_fn(q_in_local) - local_qtarg)**2).mean()
                # q2_preds = model.q2_fn(q_in)
                # q2_loss = (q2_preds - q_targ[cur_sample:cur_sample + sgd_batch_size]**2).mean()
                q_loss = q1_loss  # + q2_loss

                q1_opt.zero_grad()
                # q2_opt.zero_grad()
                q_loss.backward()
                q1_opt.step()
                # q2_opt.step()

            # policy_fn update
            # ========================================================================
            for param in model.q1_fn.parameters():
                param.requires_grad = False

            for i in range(num_mbatch):
                cur_sample = i * sgd_batch_size

                local_obs = replay_obs1[cur_sample:cur_sample + sgd_batch_size]
                local_acts = model.policy(local_obs)

                q_in = torch.cat((local_obs, local_acts), dim=1)
                pol_loss = -(model.q1_fn(q_in).mean())

                pol_opt.zero_grad()
                pol_loss.backward()
                pol_opt.step()

            for param in model.q1_fn.parameters():
                param.requires_grad = True

            # Update target networks with polyak average
            # ========================================================================
            pol_loss_hist.append(pol_loss.item())
            q1_loss_hist.append(q1_loss.item())
            # q2_loss_hist.append(q2_loss.item())

            target_q1_fn = update_target_fn(model.q1_fn, target_q1_fn, polyak)
            target_q2_fn = update_target_fn(model.q2_fn, target_q2_fn, polyak)
            target_policy = update_target_fn(model.policy, target_policy, polyak)

        act_std = act_std_lookup(cur_total_steps)

    return model, raw_rew_hist, locals()
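

# ============================================================================
# Illustration only (not part of the original file): the polyak averaging that the SAC
# training loops in this file apply explicitly to their target value function, written
# as a standalone helper over state_dicts. td3() above delegates this kind of soft
# update to update_target_fn; this sketch is not that function's actual implementation.
# ============================================================================
def _polyak_average_example(net, target_net, polyak=0.995):
    import torch

    with torch.no_grad():
        src_sd = net.state_dict()
        tar_sd = target_net.state_dict()
        for layer in tar_sd:
            # keep `polyak` of the old target weights, blend in the rest from the live network
            tar_sd[layer] = polyak * tar_sd[layer] + (1 - polyak) * src_sd[layer]
        target_net.load_state_dict(tar_sd)
    return target_net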


def sac_sym(
    env_name,
    total_steps,
    model,
    env_steps=0,
    min_steps_per_update=1,
    iters_per_update=100,
    replay_batch_size=64,
    seed=0,
    gamma=0.95,
    polyak=0.995,
    alpha=0.2,
    pol_batch_size=64,
    val_batch_size=64,
    q_batch_size=64,
    pol_lr=1e-3,
    val_lr=1e-3,
    q_lr=1e-3,
    exploration_steps=100,
    replay_buf_size=int(50000),
    use_gpu=False,
    reward_stop=None,
):
    """
    Implements soft actor critic, storing a mirrored copy of every transition
    (via mirror_obs/mirror_act) alongside the original to exploit environment symmetry.

    Args:
        env_name: name of the openAI gym environment to solve
        total_steps: number of timesteps to run SAC for
        model: model from seagul.rl.models. Contains policy, value fn, q1_fn, q2_fn
        env_steps: number of steps the environment takes before finishing; if the environment emits a done signal before this we consider it a failure.
        min_steps_per_update: minimum number of steps to take before running updates, will finish episodes before updating
        iters_per_update: how many update steps to make every time we update
        replay_batch_size: how big a batch to pull from the replay buffer for each update
        seed: random seed for all rngs
        gamma: discount applied to future rewards, usually close to 1
        polyak: term determining how fast the target network is copied from the value function
        alpha: weighting term for the entropy. 0 corresponds to no penalty for a deterministic policy
        pol_batch_size: minibatch size for policy updates
        val_batch_size: minibatch size for value fn updates
        q_batch_size: minibatch size for q fn updates
        pol_lr: initial learning rate for policy optimizer
        val_lr: initial learning rate for value optimizer
        q_lr: initial learning rate for q fn optimizer
        exploration_steps: initial number of random actions to take, aids exploration
        replay_buf_size: how big of a replay buffer to use
        use_gpu: determines if we try to use a GPU or not
        reward_stop: reward value to bail at

    Returns:
        model: trained model
        avg_reward_hist: list with the average reward per episode at each epoch
        var_dict: dictionary with all locals, for logging/debugging purposes

    Example:
        from seagul.rl.algos.sac import sac
        import torch.nn as nn
        from seagul.nn import MLP
        from seagul.rl.models import SACModel

        input_size = 3
        output_size = 1
        layer_size = 64
        num_layers = 2
        activation = nn.ReLU

        policy = MLP(input_size, output_size*2, num_layers, layer_size, activation)
        value_fn = MLP(input_size, 1, num_layers, layer_size, activation)
        q1_fn = MLP(input_size + output_size, 1, num_layers, layer_size, activation)
        q2_fn = MLP(input_size + output_size, 1, num_layers, layer_size, activation)
        model = SACModel(policy, value_fn, q1_fn, q2_fn, 1)

        model, rews, var_dict = sac("Pendulum-v0", 10000, model)
    """
    env = gym.make(env_name)
    if isinstance(env.action_space, gym.spaces.Box):
        act_size = env.action_space.shape[0]
        act_dtype = env.action_space.sample().dtype
    else:
        raise NotImplementedError("trying to use unsupported action space", env.action_space)

    obs_size = env.observation_space.shape[0]

    random_model = RandModel(model.act_limit, act_size)
    replay_buf = ReplayBuffer(obs_size, act_size, replay_buf_size)
    target_value_fn = dill.loads(dill.dumps(model.value_fn))

    pol_opt = torch.optim.Adam(model.policy.parameters(), lr=pol_lr)
    val_opt = torch.optim.Adam(model.value_fn.parameters(), lr=val_lr)
    q1_opt = torch.optim.Adam(model.q1_fn.parameters(), lr=q_lr)
    q2_opt = torch.optim.Adam(model.q2_fn.parameters(), lr=q_lr)

    # seed all our RNGs
    env.seed(seed)
    torch.manual_seed(seed)
    np.random.seed(seed)

    # set defaults, and decide if we are using a GPU or not
    use_cuda = torch.cuda.is_available() and use_gpu
    device = torch.device("cuda:0" if use_cuda else "cpu")

    raw_rew_hist = []
    val_loss_hist = []
    pol_loss_hist = []
    q1_loss_hist = []
    q2_loss_hist = []

    # progress_bar = tqdm.tqdm(total=total_steps)
    cur_total_steps = 0
    # progress_bar.update(0)
    early_stop = False

    while cur_total_steps < exploration_steps:
        ep_obs1, ep_obs2, ep_acts, ep_rews, ep_done = do_rollout(env, random_model, env_steps)

        # can def be made more efficient if found to be a bottleneck
        for obs1, obs2, acts, rews, done in zip(ep_obs1, ep_obs2, ep_acts, ep_rews, ep_done):
            replay_buf.store(obs1, obs2, acts, rews, done)
            replay_buf.store(mirror_obs(obs1), mirror_obs(obs2), mirror_act(acts), rews, done)

        ep_steps = ep_rews.shape[0] * 2
        cur_total_steps += ep_steps

        # progress_bar.update(cur_total_steps)

    while cur_total_steps < total_steps:
        cur_batch_steps = 0

        # Bail out if we have met our reward threshold
        if len(raw_rew_hist) > 2 and reward_stop:
            if raw_rew_hist[-1] >= reward_stop and raw_rew_hist[-2] >= reward_stop:
                early_stop = True
                break

        # collect data with the current policy
        # ========================================================================
        while cur_batch_steps < min_steps_per_update:
            ep_obs1, ep_obs2, ep_acts, ep_rews, ep_done = do_rollout(env, model, env_steps)

            # can def be made more efficient if found to be a bottleneck
            for obs1, obs2, acts, rews, done in zip(ep_obs1, ep_obs2, ep_acts, ep_rews, ep_done):
                replay_buf.store(obs1, obs2, acts, rews, done)
                replay_buf.store(mirror_obs(obs1), mirror_obs(obs2), mirror_act(acts), rews, done)

            ep_steps = ep_rews.shape[0] * 2
            cur_batch_steps += ep_steps
            cur_total_steps += ep_steps

            raw_rew_hist.append(torch.sum(ep_rews))

        # progress_bar.update(cur_batch_steps)

        for _ in range(ep_steps):
            # compute targets for Q and V
            # ========================================================================
            replay_obs1, replay_obs2, replay_acts, replay_rews, replay_done = replay_buf.sample_batch(replay_batch_size)

            q_targ = replay_rews + gamma * (1 - replay_done) * target_value_fn(replay_obs2)
            q_targ = q_targ.detach()

            noise = torch.randn(replay_batch_size, act_size)
            sample_acts, sample_logp = model.select_action(replay_obs1, noise)

            q_in = torch.cat((replay_obs1, sample_acts), dim=1)
            q_preds = torch.cat((model.q1_fn(q_in), model.q2_fn(q_in)), dim=1)
            q_min, q_min_idx = torch.min(q_preds, dim=1)
            q_min = q_min.reshape(-1, 1)

            v_targ = q_min - alpha * sample_logp
            v_targ = v_targ.detach()

            # q_fn update
            # ========================================================================
            training_data = data.TensorDataset(replay_obs1, replay_acts, q_targ)
            training_generator = data.DataLoader(training_data, batch_size=q_batch_size, shuffle=False)

            for local_obs, local_acts, local_qtarg in training_generator:
                # Transfer to GPU (if GPU is enabled, else this does nothing)
                local_obs, local_acts, local_qtarg = (
                    local_obs.to(device),
                    local_acts.to(device),
                    local_qtarg.to(device),
                )

                q_in = torch.cat((local_obs, local_acts), dim=1)
                q1_preds = model.q1_fn(q_in)
                q2_preds = model.q2_fn(q_in)
                q1_loss = torch.pow(q1_preds - local_qtarg, 2).mean()
                q2_loss = torch.pow(q2_preds - local_qtarg, 2).mean()
                q_loss = q1_loss + q2_loss

                q1_opt.zero_grad()
                q2_opt.zero_grad()
                q_loss.backward()
                q1_opt.step()
                q2_opt.step()

            # val_fn update
            # ========================================================================
            training_data = data.TensorDataset(replay_obs1, v_targ)
            training_generator = data.DataLoader(training_data, batch_size=val_batch_size, shuffle=False)

            for local_obs, local_vtarg in training_generator:
                # Transfer to GPU (if GPU is enabled, else this does nothing)
                local_obs, local_vtarg = (local_obs.to(device), local_vtarg.to(device))

                # predict and calculate loss for the batch
                val_preds = model.value_fn(local_obs)
                val_loss = torch.sum(torch.pow(val_preds - local_vtarg, 2)) / replay_batch_size

                # do the normal pytorch update
                val_opt.zero_grad()
                val_loss.backward()
                val_opt.step()

            # policy_fn update
            # ========================================================================
            training_data = data.TensorDataset(replay_obs1)
            training_generator = data.DataLoader(training_data, batch_size=pol_batch_size, shuffle=False)

            for local_obs in training_generator:
                # Transfer to GPU (if GPU is enabled, else this does nothing)
                local_obs = local_obs[0].to(device)

                noise = torch.randn(pol_batch_size, act_size)
                local_acts, local_logp = model.select_action(local_obs, noise)

                q_in = torch.cat((local_obs, local_acts), dim=1)
                pol_loss = torch.sum(alpha * local_logp - model.q1_fn(q_in)) / replay_batch_size

                # do the normal pytorch update
                pol_opt.zero_grad()
                pol_loss.backward()
                pol_opt.step()

            # Update target value fn with polyak average
            # ========================================================================
            val_loss_hist.append(val_loss.item())
            pol_loss_hist.append(pol_loss.item())
            q1_loss_hist.append(q1_loss.item())
            q2_loss_hist.append(q2_loss.item())

            # model.policy.state_means = update_mean(replay_obs1, model.policy.state_means, cur_total_steps)
            # model.policy.state_var = update_var(replay_obs1, model.policy.state_var, cur_total_steps)
            # model.value_fn.state_means = model.policy.state_means
            # model.value_fn.state_var = model.policy.state_var
            #
            # model.q1_fn.state_means = update_mean(torch.cat((replay_obs1, replay_acts.detach()), dim=1), model.q1_fn.state_means, cur_total_steps)
            # model.q1_fn.state_var = update_var(torch.cat((replay_obs1, replay_acts.detach()), dim=1), model.q1_fn.state_var, cur_total_steps)
            # model.q2_fn.state_means = model.q1_fn.state_means
            # model.q2_fn.state_var = model.q1_fn.state_var

            val_sd = model.value_fn.state_dict()
            tar_sd = target_value_fn.state_dict()
            for layer in tar_sd:
                tar_sd[layer] = polyak * tar_sd[layer] + (1 - polyak) * val_sd[layer]
            target_value_fn.load_state_dict(tar_sd)

    return (model, raw_rew_hist, locals())


def sac_switched(
    env_name,
    total_steps,
    model,
    env_steps=0,
    min_steps_per_update=1,
    iters_per_update=100,
    replay_batch_size=64,
    seed=0,
    gamma=0.95,
    polyak=0.995,
    alpha=0.2,
    sgd_batch_size=64,
    sgd_lr=1e-3,
    exploration_steps=100,
    replay_buf_size=int(100000),
    use_gpu=False,
    reward_stop=None,
    goal_state=np.array([np.pi / 2, 0, 0, 0]),
    goal_lookback=10,
    goal_thresh=1,
    needle_lookup_prob=.5,
    gate_update_freq=500,
    gate_x=None,
    gate_y=None,
    gate_lr=1e-5,
    gate_w=1e-2,
    gate_epochs=1,
    env_config={},
):
    """
    Implements soft actor critic with a switched (gated) model: episodes that end near goal_state
    are stored in a separate "needle" replay buffer, and a gate function (model.gate_fn) is
    periodically refit to predict which states lead to the goal.

    Args:
        env_name: name of the openAI gym environment to solve
        total_steps: number of timesteps to run SAC for
        model: model from seagul.rl.models. Contains policy, value fn, q1_fn, q2_fn
        env_steps: number of steps the environment takes before finishing; if the environment emits a done signal before this we consider it a failure.
        min_steps_per_update: minimum number of steps to take before running updates, will finish episodes before updating
        iters_per_update: how many update steps to make every time we update
        replay_batch_size: how big a batch to pull from the replay buffer for each update
        seed: random seed for all rngs
        gamma: discount applied to future rewards, usually close to 1
        polyak: term determining how fast the target network is copied from the value function
        alpha: weighting term for the entropy. 0 corresponds to no penalty for a deterministic policy
        sgd_batch_size: minibatch size for the policy, value, and q function updates
        sgd_lr: initial learning rate for the policy, value, and q function optimizers
        exploration_steps: initial number of random actions to take, aids exploration
        replay_buf_size: how big of a replay buffer to use
        use_gpu: determines if we try to use a GPU or not
        reward_stop: reward value to bail at
        goal_state: target state; episodes whose last goal_lookback observations are all within goal_thresh of it are also stored in the needle buffer
        goal_lookback: how many of the final observations must be near goal_state for an episode to count as reaching the goal
        goal_thresh: distance threshold used in the goal check
        needle_lookup_prob: controls how often update batches are drawn from the needle buffer instead of the main replay buffer
        gate_update_freq: number of environment steps between refits of model.gate_fn
        gate_x: initial tensor of observations used to train the gate function
        gate_y: initial tensor of labels used to train the gate function
        gate_lr: learning rate for the gate function fit
        gate_w: scaling applied to the positive class weight in the gate loss
        gate_epochs: number of epochs for each gate function fit
        env_config: dictionary containing kwargs to pass to the environment

    Returns:
        model: trained model
        avg_reward_hist: list with the average reward per episode at each epoch
        var_dict: dictionary with all locals, for logging/debugging purposes

    Example:
        from seagul.rl.algos.sac import sac
        import torch.nn as nn
        from seagul.nn import MLP
        from seagul.rl.models import SACModel

        input_size = 3
        output_size = 1
        layer_size = 64
        num_layers = 2
        activation = nn.ReLU

        policy = MLP(input_size, output_size*2, num_layers, layer_size, activation)
        value_fn = MLP(input_size, 1, num_layers, layer_size, activation)
        q1_fn = MLP(input_size + output_size, 1, num_layers, layer_size, activation)
        q2_fn = MLP(input_size + output_size, 1, num_layers, layer_size, activation)
        model = SACModel(policy, value_fn, q1_fn, q2_fn, 1)

        model, rews, var_dict = sac("Pendulum-v0", 10000, model)
    """
    args = locals()
    torch.set_num_threads(1)

    env = gym.make(env_name, **env_config)
    if isinstance(env.action_space, gym.spaces.Box):
        act_size = env.action_space.shape[0]
        act_dtype = env.action_space.sample().dtype
    else:
        raise NotImplementedError("trying to use unsupported action space", env.action_space)

    obs_size = env.observation_space.shape[0]

    random_model = dill.loads(dill.dumps(model))
    random_model.swingup_controller = lambda x: torch.rand(model.num_acts) * 2 * model.act_limit - model.act_limit

    replay_buf = ReplayBuffer(obs_size, act_size, replay_buf_size)
    needle_buf = ReplayBuffer(obs_size, act_size, replay_buf_size)
    target_value_fn = dill.loads(dill.dumps(model.value_fn))

    pol_opt = torch.optim.Adam(model.policy.parameters(), lr=sgd_lr)
    val_opt = torch.optim.Adam(model.value_fn.parameters(), lr=sgd_lr)
    q1_opt = torch.optim.Adam(model.q1_fn.parameters(), lr=sgd_lr)
    q2_opt = torch.optim.Adam(model.q2_fn.parameters(), lr=sgd_lr)

    # seed all our RNGs
    env.seed(seed)
    torch.manual_seed(seed)
    np.random.seed(seed)

    # set defaults, and decide if we are using a GPU or not
    use_cuda = torch.cuda.is_available() and use_gpu
    device = torch.device("cuda:0" if use_cuda else "cpu")

    raw_rew_hist = []
    val_loss_hist = []
    pol_loss_hist = []
    q1_loss_hist = []
    q2_loss_hist = []

    progress_bar = tqdm.tqdm(total=total_steps)
    cur_total_steps = 0
    gate_update_counter = 0
    progress_bar.update(0)
    early_stop = False

    needle_count = 0
    not_needle_count = 0

    while cur_total_steps < exploration_steps:
        ep_obs1, ep_obs2, ep_acts, ep_rews, ep_done, ep_path = do_rollout(env, random_model, env_steps)

        in_goal = torch.sum(torch.sqrt((ep_obs2[-goal_lookback:] - goal_state)**2), axis=1) < goal_thresh
        if in_goal.all():
            needle_buf.store(ep_obs1, ep_obs2, ep_acts, ep_rews, ep_done)

        replay_buf.store(ep_obs1, ep_obs2, ep_acts, ep_rews, ep_done)

        ep_steps = ep_rews.shape[0]
        cur_total_steps += ep_steps

        progress_bar.update(cur_total_steps)

    while cur_total_steps < total_steps:
        cur_batch_steps = 0

        # Bail out if we have met our reward threshold
        if len(raw_rew_hist) > 2 and reward_stop:
            if raw_rew_hist[-1] >= reward_stop and raw_rew_hist[-2] >= reward_stop:
                early_stop = True
                break

        # Transfer back to CPU, which is faster for rollouts
        model = model.to('cpu')
        target_value_fn = target_value_fn.to('cpu')

        # collect data with the current policy
        # ========================================================================
        while cur_batch_steps < min_steps_per_update:
            ep_obs1, ep_obs2, ep_acts, ep_rews, ep_done, ep_path = do_rollout(env, model, env_steps)

            in_goal = torch.sum(torch.sqrt((ep_obs2[-goal_lookback:] - goal_state)**2), axis=1) < goal_thresh
            if in_goal.all():
                needle_count += 1
                needle_buf.store(ep_obs1, ep_obs2, ep_acts, ep_rews, ep_done)
            else:
                not_needle_count += 1

            replay_buf.store(ep_obs1, ep_obs2, ep_acts, ep_rews, ep_done)

            if ep_path.sum() != 0:
                reverse_obs = np.flip(ep_obs1.numpy(), 0).copy()
                reverse_obs = torch.from_numpy(reverse_obs)
                reverse_path = np.flip(ep_path.numpy(), 0).copy()
                reverse_path = torch.from_numpy(reverse_path)

                if in_goal.all():
                    for path, obs in zip(reverse_path, reverse_obs):
                        if not path:
                            break
                        else:
                            gate_x = torch.cat((gate_x, obs.reshape(1, -1)))
                            gate_y = torch.cat((gate_y, torch.ones((1, 1), dtype=torch.float32)))
                else:
                    for path, obs in zip(reverse_path, reverse_obs):
                        if not path:
                            pass
                        else:
                            gate_x = torch.cat((gate_x, obs.reshape(1, -1)))
                            gate_y = torch.cat((gate_y, torch.zeros((1, 1), dtype=torch.float32)))

            ep_steps = ep_rews.shape[0]
            cur_batch_steps += ep_steps
            cur_total_steps += ep_steps
            gate_update_counter += ep_steps

            raw_rew_hist.append(torch.sum(ep_rews))
            progress_bar.update(ep_steps)

        print("needle/normal: ", str(needle_buf.size), str(replay_buf.size))

        if gate_update_counter > gate_update_freq:
            # For training, transfer model to GPU
            model = model.to('cuda:0')
            target_value_fn = target_value_fn.to('cuda:0')

            class_weight = (gate_y.shape[0] / sum(gate_y) * gate_w).to('cuda:0')
            gate_loss = fit_model(
                model.gate_fn,
                gate_x,
                gate_y,
                gate_epochs,
                use_tqdm=False,
                use_cuda=True,
                batch_size=8192,
                loss_fn=torch.nn.BCEWithLogitsLoss(pos_weight=class_weight),
                learning_rate=gate_lr)

            print("gate updated: " + str(gate_y.shape[0]) + " " + str(sum(gate_y)))

            model = model.to('cpu')
            target_value_fn = target_value_fn.to('cpu')
            gate_update_counter = 0

        for _ in range(min(int(ep_steps), iters_per_update)):
            # compute targets for Q and V
            # ========================================================================
            p = np.random.random_sample(1)
            if p > needle_lookup_prob and needle_buf.size > 0:
                replay_obs1, replay_obs2, replay_acts, replay_rews, replay_done = needle_buf.sample_batch(replay_batch_size)
            else:
                replay_obs1, replay_obs2, replay_acts, replay_rews, replay_done = replay_buf.sample_batch(replay_batch_size)

            replay_obs1, replay_obs2, replay_acts, replay_rews, replay_done = \
                [replay_obs1.to(device), replay_obs2.to(device), replay_acts.to(device), replay_rews.to(device), replay_done.to(device)]

            q_targ = replay_rews + gamma * (1 - replay_done) * target_value_fn(replay_obs2)
            q_targ = q_targ.detach()

            noise = torch.randn(replay_batch_size, act_size).to(device)
            sample_acts, sample_logp = model.select_action_parallel(replay_obs1, noise)

            q_in = torch.cat((replay_obs1, sample_acts), dim=1)
            q_preds = torch.cat((model.q1_fn(q_in), model.q2_fn(q_in)), dim=1)
            q_min, q_min_idx = torch.min(q_preds, dim=1)
            q_min = q_min.reshape(-1, 1)

            v_targ = q_min - alpha * sample_logp
            v_targ = v_targ.detach()

            # q_fn update
            # ========================================================================
            training_data = data.TensorDataset(replay_obs1, replay_acts, q_targ)
            training_generator = data.DataLoader(training_data, batch_size=sgd_batch_size, shuffle=True, num_workers=0, pin_memory=False)

            for local_obs, local_acts, local_qtarg in training_generator:
                # Transfer to GPU (if GPU is enabled, else this does nothing)
                local_obs, local_acts, local_qtarg = (
                    local_obs.to(device),
                    local_acts.to(device),
                    local_qtarg.to(device),
                )

                q_in = torch.cat((local_obs, local_acts), dim=1)
                q1_preds = model.q1_fn(q_in)
                q2_preds = model.q2_fn(q_in)
                q1_loss = torch.pow(q1_preds - local_qtarg, 2).mean()
                q2_loss = torch.pow(q2_preds - local_qtarg, 2).mean()
                q_loss = q1_loss + q2_loss

                q1_opt.zero_grad()
                q2_opt.zero_grad()
                q_loss.backward()
                q1_opt.step()
                q2_opt.step()

            # val_fn update
            # ========================================================================
            training_data = data.TensorDataset(replay_obs1, v_targ)
            training_generator = data.DataLoader(training_data, batch_size=sgd_batch_size, shuffle=True, num_workers=0, pin_memory=False)

            for local_obs, local_vtarg in training_generator:
                # Transfer to GPU (if GPU is enabled, else this does nothing)
                local_obs, local_vtarg = (local_obs.to(device), local_vtarg.to(device))

                # predict and calculate loss for the batch
                val_preds = model.value_fn(local_obs)
                val_loss = torch.sum(torch.pow(val_preds - local_vtarg, 2)) / replay_batch_size

                # do the normal pytorch update
                val_opt.zero_grad()
                val_loss.backward()
                val_opt.step()

            # policy_fn update
            # ========================================================================
            training_data = data.TensorDataset(replay_obs1)
            training_generator = data.DataLoader(training_data, batch_size=sgd_batch_size, shuffle=True, num_workers=0, pin_memory=False)

            for local_obs in training_generator:
                # Transfer to GPU (if GPU is enabled, else this does nothing)
                local_obs = local_obs[0].to(device)

                noise = torch.randn(local_obs.shape[0], act_size).to(device)
                local_acts, local_logp = model.select_action_parallel(local_obs, noise)

                q_in = torch.cat((local_obs, local_acts), dim=1)
                pol_loss = torch.sum(alpha * local_logp - model.q1_fn(q_in)) / replay_batch_size

                # do the normal pytorch update
                pol_opt.zero_grad()
                pol_loss.backward()
                pol_opt.step()

            # Update target networks
            # ========================================================================
            val_sd = model.value_fn.state_dict()
            tar_sd = target_value_fn.state_dict()
            for layer in tar_sd:
                tar_sd[layer] = polyak * tar_sd[layer] + (1 - polyak) * val_sd[layer]
            target_value_fn.load_state_dict(tar_sd)

    return model, raw_rew_hist, locals()


def sac(
    env_name,
    train_steps,
    model,
    env_max_steps=0,
    min_steps_per_update=1,
    iters_per_update=100,
    replay_batch_size=64,
    seed=0,
    gamma=0.95,
    polyak=0.995,
    alpha=0.2,
    sgd_batch_size=64,
    sgd_lr=1e-3,
    exploration_steps=100,
    replay_buf_size=int(100000),
    normalize_steps=1000,
    use_gpu=False,
    reward_stop=None,
    env_config={},
):
    """
    Implements soft actor critic

    Args:
        env_name: name of the openAI gym environment to solve
        train_steps: number of timesteps to run SAC for
        model: model from seagul.rl.models. Contains policy, value fn, q1_fn, q2_fn
        min_steps_per_update: minimum number of steps to take before running updates, will finish episodes before updating
        env_max_steps: number of steps the environment takes before finishing; if the environment emits a done signal before this we consider it a failure.
        iters_per_update: how many update steps to make every time we update
        replay_batch_size: how big a batch to pull from the replay buffer for each update
        seed: random seed for all rngs
        gamma: discount applied to future rewards, usually close to 1
        polyak: term determining how fast the target network is copied from the value function
        alpha: weighting term for the entropy. 0 corresponds to no penalty for a deterministic policy
        sgd_batch_size: minibatch size for the policy, value, and q function updates
        sgd_lr: initial learning rate for the policy, value, and q function optimizers
        exploration_steps: initial number of random actions to take, aids exploration
        replay_buf_size: how big of a replay buffer to use
        normalize_steps: number of steps to take with a random policy to estimate observation statistics
        use_gpu: determines if we try to use a GPU or not
        reward_stop: reward value to bail at
        env_config: dictionary containing kwargs to pass to the environment

    Returns:
        model: trained model
        avg_reward_hist: list with the average reward per episode at each epoch
        var_dict: dictionary with all locals, for logging/debugging purposes

    Example:
        from seagul.rl.algos.sac import sac
        import torch.nn as nn
        from seagul.nn import MLP
        from seagul.rl.models import SACModel

        input_size = 3
        output_size = 1
        layer_size = 64
        num_layers = 2
        activation = nn.ReLU

        policy = MLP(input_size, output_size*2, num_layers, layer_size, activation)
        value_fn = MLP(input_size, 1, num_layers, layer_size, activation)
        q1_fn = MLP(input_size + output_size, 1, num_layers, layer_size, activation)
        q2_fn = MLP(input_size + output_size, 1, num_layers, layer_size, activation)
        model = SACModel(policy, value_fn, q1_fn, q2_fn, 1)

        model, rews, var_dict = sac("Pendulum-v0", 10000, model)
    """
    torch.set_num_threads(1)

    env = gym.make(env_name, **env_config)
    if isinstance(env.action_space, gym.spaces.Box):
        act_size = env.action_space.shape[0]
        act_dtype = env.action_space.sample().dtype
    else:
        raise NotImplementedError("trying to use unsupported action space", env.action_space)

    obs_size = env.observation_space.shape[0]

    random_model = RandModel(model.act_limit, act_size)
    replay_buf = ReplayBuffer(obs_size, act_size, replay_buf_size)
    target_value_fn = dill.loads(dill.dumps(model.value_fn))

    pol_opt = torch.optim.Adam(model.policy.parameters(), lr=sgd_lr)
    val_opt = torch.optim.Adam(model.value_fn.parameters(), lr=sgd_lr)
    q1_opt = torch.optim.Adam(model.q1_fn.parameters(), lr=sgd_lr)
    q2_opt = torch.optim.Adam(model.q2_fn.parameters(), lr=sgd_lr)

    # seed all our RNGs
    env.seed(seed)
    torch.manual_seed(seed)
    np.random.seed(seed)

    # set defaults, and decide if we are using a GPU or not
    use_cuda = torch.cuda.is_available() and use_gpu
    device = torch.device("cuda:0" if use_cuda else "cpu")

    raw_rew_hist = []
    val_loss_hist = []
    pol_loss_hist = []
    q1_loss_hist = []
    q2_loss_hist = []

    progress_bar = tqdm.tqdm(total=train_steps + normalize_steps)
    cur_total_steps = 0
    progress_bar.update(0)
    early_stop = False
    norm_obs1 = torch.empty(0)

    while cur_total_steps < normalize_steps:
        ep_obs1, ep_obs2, ep_acts, ep_rews, ep_done = do_rollout(env, random_model, env_max_steps)
        norm_obs1 = torch.cat((norm_obs1, ep_obs1))

        ep_steps = ep_rews.shape[0]
        cur_total_steps += ep_steps

        progress_bar.update(ep_steps)

    if normalize_steps > 0:
        obs_mean = norm_obs1.mean(axis=0)
        obs_std = norm_obs1.std(axis=0)
        obs_std[torch.isinf(1 / obs_std)] = 1

        model.policy.state_means = obs_mean
        model.policy.state_std = obs_std
        model.value_fn.state_means = obs_mean
        model.value_fn.state_std = obs_std
        target_value_fn.state_means = obs_mean
        target_value_fn.state_std = obs_std

        model.q1_fn.state_means = torch.cat((obs_mean, torch.zeros(act_size)))
        model.q1_fn.state_std = torch.cat((obs_std, torch.ones(act_size)))
        model.q2_fn.state_means = model.q1_fn.state_means
        model.q2_fn.state_std = model.q1_fn.state_std

    while cur_total_steps < exploration_steps:
        ep_obs1, ep_obs2, ep_acts, ep_rews, ep_done = do_rollout(env, random_model, env_max_steps)
        replay_buf.store(ep_obs1, ep_obs2, ep_acts, ep_rews, ep_done)

        ep_steps = ep_rews.shape[0]
        cur_total_steps += ep_steps

        progress_bar.update(ep_steps)

    while cur_total_steps < train_steps:
        cur_batch_steps = 0

        # Bail out if we have met our reward threshold
        if len(raw_rew_hist) > 2 and reward_stop:
            if raw_rew_hist[-1] >= reward_stop and raw_rew_hist[-2] >= reward_stop:
                early_stop = True
                break

        # collect data with the current policy
        # ========================================================================
        while cur_batch_steps < min_steps_per_update:
            ep_obs1, ep_obs2, ep_acts, ep_rews, ep_done = do_rollout(env, model, env_max_steps)
            replay_buf.store(ep_obs1, ep_obs2, ep_acts, ep_rews, ep_done)

            ep_steps = ep_rews.shape[0]
            cur_batch_steps += ep_steps
            cur_total_steps += ep_steps

            raw_rew_hist.append(torch.sum(ep_rews))
            print(raw_rew_hist[-1])

        progress_bar.update(cur_batch_steps)

        for _ in range(min(int(ep_steps), iters_per_update)):
            # compute targets for Q and V
            # ========================================================================
            replay_obs1, replay_obs2, replay_acts, replay_rews, replay_done = replay_buf.sample_batch(replay_batch_size)

            q_targ = replay_rews + gamma * (1 - replay_done) * target_value_fn(replay_obs2)
            q_targ = q_targ.detach()

            noise = torch.randn(replay_batch_size, act_size)
            sample_acts, sample_logp = model.select_action(replay_obs1, noise)

            q_in = torch.cat((replay_obs1, sample_acts), dim=1)
            q_preds = torch.cat((model.q1_fn(q_in), model.q2_fn(q_in)), dim=1)
            q_min, q_min_idx = torch.min(q_preds, dim=1)
            q_min = q_min.reshape(-1, 1)

            v_targ = q_min - alpha * sample_logp
            v_targ = v_targ.detach()

            # q_fn update
            # ========================================================================
            num_mbatch = int(replay_batch_size / sgd_batch_size)

            for i in range(num_mbatch):
                cur_sample = i * sgd_batch_size

                q_in = torch.cat((replay_obs1[cur_sample:cur_sample + sgd_batch_size], replay_acts[cur_sample:cur_sample + sgd_batch_size]), dim=1)
                q1_preds = model.q1_fn(q_in)
                q2_preds = model.q2_fn(q_in)
                q1_loss = torch.pow(q1_preds - q_targ[cur_sample:cur_sample + sgd_batch_size], 2).mean()
                q2_loss = torch.pow(q2_preds - q_targ[cur_sample:cur_sample + sgd_batch_size], 2).mean()
                q_loss = q1_loss + q2_loss

                q1_opt.zero_grad()
                q2_opt.zero_grad()
                q_loss.backward()
                q1_opt.step()
                q2_opt.step()

            # val_fn update
            # ========================================================================
            for i in range(num_mbatch):
                cur_sample = i * sgd_batch_size

                # predict and calculate loss for the batch
                val_preds = model.value_fn(replay_obs1[cur_sample:cur_sample + sgd_batch_size])
                val_loss = torch.sum(torch.pow(val_preds - v_targ[cur_sample:cur_sample + sgd_batch_size], 2)) / replay_batch_size

                # do the normal pytorch update
                val_opt.zero_grad()
                val_loss.backward()
                val_opt.step()

            # policy_fn update
            # ========================================================================
            for param in model.q1_fn.parameters():
                param.requires_grad = False

            for i in range(num_mbatch):
                cur_sample = i * sgd_batch_size

                noise = torch.randn(replay_obs1[cur_sample:cur_sample + sgd_batch_size].shape[0], act_size)
                local_acts, local_logp = model.select_action(replay_obs1[cur_sample:cur_sample + sgd_batch_size], noise)

                q_in = torch.cat((replay_obs1[cur_sample:cur_sample + sgd_batch_size], local_acts), dim=1)
                pol_loss = torch.sum(alpha * local_logp - model.q1_fn(q_in)) / replay_batch_size

                pol_opt.zero_grad()
                pol_loss.backward()
                pol_opt.step()

            for param in model.q1_fn.parameters():
                param.requires_grad = True

            # Update target value fn with polyak average
            # ========================================================================
            val_loss_hist.append(val_loss.item())
            pol_loss_hist.append(pol_loss.item())
            q1_loss_hist.append(q1_loss.item())
            q2_loss_hist.append(q2_loss.item())

            val_sd = model.value_fn.state_dict()
            tar_sd = target_value_fn.state_dict()
            for layer in tar_sd:
                tar_sd[layer] = polyak * tar_sd[layer] + (1 - polyak) * val_sd[layer]
            target_value_fn.load_state_dict(tar_sd)

    return model, raw_rew_hist, locals()


def ppo_visit(
    env_name,
    total_steps,
    model,
    vc=.01,
    replay_buf_size=int(5e4),
    act_std_schedule=(0.7,),
    epoch_batch_size=2048,
    gamma=0.99,
    lam=0.95,
    eps=0.2,
    seed=0,
    entropy_coef=0.0,
    sgd_batch_size=1024,
    lr_schedule=(3e-4,),
    sgd_epochs=10,
    target_kl=float('inf'),
    val_coef=.5,
    clip_val=True,
    env_no_term_steps=0,
    use_gpu=False,
    reward_stop=None,
    normalize_return=True,
    normalize_obs=True,
    normalize_adv=True,
    env_config={}
):
    """
    Implements proximal policy optimization with clipping, plus a visitation penalty: each reward is
    reduced by vc times the distance from the observation to its nearest neighbour already stored in
    the replay buffer.

    Args:
        env_name: name of the openAI gym environment to solve
        total_steps: number of timesteps to run the PPO for
        model: model from seagul.rl.models. Contains policy and value fn
        vc: coefficient for the visitation penalty applied to each reward
        replay_buf_size: how many observations to keep for the visitation penalty
        act_std_schedule: schedule to set the variance of the policy. Will linearly interpolate values
        epoch_batch_size: number of environment steps to take per batch, total steps will be num_epochs*epoch_batch_size
        seed: seed for all the rngs
        gamma: discount applied to future rewards, usually close to 1
        lam: lambda for the Advantage estimation, usually close to 1
        eps: epsilon for the clipping, usually .1 or .2
        entropy_coef: coefficient for the entropy bonus added to the policy loss
        sgd_batch_size: batch size for policy and value function updates
        lr_schedule: learning rate schedule for the policy and value optimizers
        sgd_epochs: how many epochs to use for each policy update
        target_kl: max KL before breaking
        val_coef: coefficient applied to the value function loss
        clip_val: whether to clip the value function loss
        use_gpu: want to use the GPU? set to true
        reward_stop: reward value to stop if we achieve
        normalize_return: should we normalize the return?
        normalize_obs: should we normalize the observations?
        normalize_adv: should we normalize the advantages?
        env_config: dictionary containing kwargs to pass to the environment

    Returns:
        model: trained model
        avg_reward_hist: list with the average reward per episode at each epoch
        var_dict: dictionary with all locals, for logging/debugging purposes

    Example:
        from seagul.rl.algos import ppo
        from seagul.nn import MLP
        from seagul.rl.models import PPOModel
        import torch

        input_size = 3
        output_size = 1
        layer_size = 64
        num_layers = 2

        policy = MLP(input_size, output_size, num_layers, layer_size)
        value_fn = MLP(input_size, 1, num_layers, layer_size)
        model = PPOModel(policy, value_fn)

        model, rews, var_dict = ppo("Pendulum-v0", 10000, model)
    """
    # init everything
    # ==============================================================================
    torch.set_num_threads(1)

    env = gym.make(env_name, **env_config)
    if isinstance(env.action_space, gym.spaces.Box):
        act_size = env.action_space.shape[0]
        act_dtype = torch.double
    else:
        raise NotImplementedError("trying to use unsupported action space", env.action_space)

    replay_buf = ReplayBuffer(env.observation_space.shape[0], act_size, replay_buf_size)

    actstd_lookup = make_schedule(act_std_schedule, total_steps)
    lr_lookup = make_schedule(lr_schedule, total_steps)

    model.action_var = actstd_lookup(0)
    sgd_lr = lr_lookup(0)

    obs_size = env.observation_space.shape[0]
    obs_mean = torch.zeros(obs_size)
    obs_std = torch.ones(obs_size)
    rew_mean = torch.zeros(1)
    rew_std = torch.ones(1)

    # copy.deepcopy broke for me with older versions of torch. Using pickle for this is weird but works fine
    old_model = pickle.loads(pickle.dumps(model))

    # seed all our RNGs
    env.seed(seed)
    torch.manual_seed(seed)
    np.random.seed(seed)

    # set defaults, and decide if we are using a GPU or not
    use_cuda = torch.cuda.is_available() and use_gpu
    device = torch.device("cuda:0" if use_cuda else "cpu")

    # init logging stuff
    raw_rew_hist = []
    val_loss_hist = []
    pol_loss_hist = []
    progress_bar = tqdm.tqdm(total=total_steps)
    cur_total_steps = 0
    progress_bar.update(0)
    early_stop = False

    # Train until we hit our total steps or reach our reward threshold
    # ==============================================================================
    while cur_total_steps < total_steps:
        pol_opt = torch.optim.Adam(model.policy.parameters(), lr=sgd_lr)
        val_opt = torch.optim.Adam(model.value_fn.parameters(), lr=sgd_lr)

        batch_obs = torch.empty(0)
        batch_act = torch.empty(0)
        batch_adv = torch.empty(0)
        batch_discrew = torch.empty(0)
        cur_batch_steps = 0

        # Bail out if we have met our reward threshold
        if len(raw_rew_hist) > 2 and reward_stop:
            if raw_rew_hist[-1] >= reward_stop and raw_rew_hist[-2] >= reward_stop:
                early_stop = True
                break

        # construct batch data from rollouts
        # ==============================================================================
        while cur_batch_steps < epoch_batch_size:
            ep_obs, ep_act, ep_rew, ep_steps, ep_term = do_rollout(env, model, env_no_term_steps)
            raw_rew_hist.append(sum(ep_rew).item())

            # penalize each reward by vc times the distance to the nearest observation already in the buffer
            for i, obs in enumerate(ep_obs):
                ep_rew[i] -= (np.min(np.linalg.norm(obs - replay_buf.obs1_buf, axis=1))) * vc

            replay_buf.store(ep_obs, ep_obs, ep_act, ep_rew, ep_rew)

            batch_obs = torch.cat((batch_obs, ep_obs[:-1]))
            batch_act = torch.cat((batch_act, ep_act[:-1]))

            if not ep_term:
                ep_rew[-1] = model.value_fn(ep_obs[-1]).detach()

            ep_discrew = discount_cumsum(ep_rew, gamma)

            if normalize_return:
                rew_mean = update_mean(batch_discrew, rew_mean, cur_total_steps)
                rew_std = update_std(ep_discrew, rew_std, cur_total_steps)
                ep_discrew = ep_discrew / (rew_std + 1e-6)

            batch_discrew = torch.cat((batch_discrew, ep_discrew[:-1]))

            ep_val = model.value_fn(ep_obs)

            deltas = ep_rew[:-1] + gamma * ep_val[1:] - ep_val[:-1]
            ep_adv = discount_cumsum(deltas.detach(), gamma * lam)
            batch_adv = torch.cat((batch_adv, ep_adv))

            cur_batch_steps += ep_steps
            cur_total_steps += ep_steps

        # make sure our advantages are zero mean and unit variance
        if normalize_adv:
            # adv_mean = update_mean(batch_adv, adv_mean, cur_total_steps)
            # adv_var = update_std(batch_adv, adv_var, cur_total_steps)
            batch_adv = (batch_adv - batch_adv.mean()) / (batch_adv.std() + 1e-6)

        num_mbatch = int(batch_obs.shape[0] / sgd_batch_size)

        # Update the policy using the PPO loss
        for pol_epoch in range(sgd_epochs):
            for i in range(num_mbatch):
                # policy update
                # ========================================================================
                cur_sample = i * sgd_batch_size

                # Transfer to GPU (if GPU is enabled, else this does nothing)
                local_obs = batch_obs[cur_sample:cur_sample + sgd_batch_size]
                local_act = batch_act[cur_sample:cur_sample + sgd_batch_size]
                local_adv = batch_adv[cur_sample:cur_sample + sgd_batch_size]
                local_val = batch_discrew[cur_sample:cur_sample + sgd_batch_size]

                # Compute the loss
                logp = model.get_logp(local_obs, local_act).reshape(-1, act_size)
                old_logp = old_model.get_logp(local_obs, local_act).reshape(-1, act_size)
                mean_entropy = -(logp * torch.exp(logp)).mean()

                r = torch.exp(logp - old_logp)
                clip_r = torch.clamp(r, 1 - eps, 1 + eps)
                pol_loss = -torch.min(r * local_adv, clip_r * local_adv).mean() - entropy_coef * mean_entropy

                approx_kl = ((logp - old_logp)**2).mean()
                if approx_kl > target_kl:
                    break

                pol_opt.zero_grad()
                pol_loss.backward()
                pol_opt.step()

                # value_fn update
                # ========================================================================
                val_preds = model.value_fn(local_obs)
                if clip_val:
                    old_val_preds = old_model.value_fn(local_obs)
                    val_preds_clipped = old_val_preds + torch.clamp(val_preds - old_val_preds, -eps, eps)
                    val_loss1 = (val_preds_clipped - local_val)**2
                    val_loss2 = (val_preds - local_val)**2
                    val_loss = val_coef * torch.max(val_loss1, val_loss2).mean()
                else:
                    val_loss = val_coef * ((val_preds - local_val) ** 2).mean()

                val_opt.zero_grad()
                val_loss.backward()
                val_opt.step()

        # update observation mean and variance
        if normalize_obs:
            obs_mean = update_mean(batch_obs, obs_mean, cur_total_steps)
            obs_std = update_std(batch_obs, obs_std, cur_total_steps)
            model.policy.state_means = obs_mean
            model.value_fn.state_means = obs_mean
            model.policy.state_std = obs_std
            model.value_fn.state_std = obs_std

        model.action_std = actstd_lookup(cur_total_steps)
        sgd_lr = lr_lookup(cur_total_steps)

        old_model = pickle.loads(pickle.dumps(model))
        val_loss_hist.append(val_loss)
        pol_loss_hist.append(pol_loss)

        progress_bar.update(cur_batch_steps)

    progress_bar.close()
    return model, raw_rew_hist, locals()
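

# ============================================================================
# Illustration only (not part of the original file): the visitation penalty that
# ppo_visit() applies above, shown in isolation with stand-in arrays. Each reward is
# reduced by vc times the distance from the observation to its nearest neighbour
# already stored in the replay buffer, so rewards shrink least in frequently visited
# regions of state space.
# ============================================================================
def _visit_penalty_example(vc=0.01):
    import numpy as np

    buf_obs = np.random.randn(1000, 3)  # stand-in for replay_buf.obs1_buf
    ep_obs = np.random.randn(5, 3)      # observations from one rollout
    ep_rew = np.ones(5)                 # rewards from the same rollout

    for i, obs in enumerate(ep_obs):
        # distance to the nearest previously visited observation scales the penalty
        ep_rew[i] -= np.min(np.linalg.norm(obs - buf_obs, axis=1)) * vc

    return ep_rew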