def forward(self, obs):
    # state feature
    obs = to_tensor(obs).unsqueeze(0).to(self.device)
    phi = self.phi_body(obs)

    # per-option Gaussian heads and termination probabilities
    mean = []
    std = []
    beta = []
    for option in self.options:
        prediction = option(phi)
        mean.append(prediction['mean'].unsqueeze(1))
        std.append(prediction['std'].unsqueeze(1))
        beta.append(prediction['beta'])
    mean = torch.cat(mean, dim=1)
    std = torch.cat(std, dim=1)
    beta = torch.cat(beta, dim=1)

    # critic network: option values Q(s, o)
    phi_c = self.critic_body(phi)
    q_o = self.fc_q_o(phi_c)

    return {'mean': mean, 'std': std, 'q_o': q_o, 'beta': beta}
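# The per-option modules iterated over in `self.options` are not shown in this
# section. Below is a minimal sketch of what one such head could look like,
# assuming each option owns a Gaussian actor (state-independent log-std) and a
# termination head. The class and attribute names here are illustrative
# assumptions, not the repository's actual code.
class OptionGaussianHead(nn.Module):
    def __init__(self, feature_dim, action_dim, hidden_dim=64):
        super().__init__()
        self.fc_mean = nn.Sequential(nn.Linear(feature_dim, hidden_dim), nn.Tanh(),
                                     nn.Linear(hidden_dim, action_dim))
        self.log_std = nn.Parameter(torch.zeros(1, action_dim))
        self.fc_beta = nn.Linear(feature_dim, 1)

    def forward(self, phi):
        mean = torch.tanh(self.fc_mean(phi))             # action mean in [-1, 1]
        std = F.softplus(self.log_std).expand_as(mean)   # positive std, broadcast over batch
        beta = torch.sigmoid(self.fc_beta(phi))          # termination probability
        return {'mean': mean, 'std': std, 'beta': beta}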
def __init__(self,
             env_fn,
             save_dir,
             tensorboard_logdir=None,
             optimizer_class=RMSprop,
             oc_kwargs=dict(),
             logger_kwargs=dict(),
             eps_start=1.0,
             eps_end=0.1,
             eps_decay=1e4,
             lr=1e-3,
             gamma=0.99,
             rollout_length=2048,
             beta_reg=0.01,
             entropy_weight=0.01,
             gradient_clip=5,
             target_network_update_freq=200,
             max_ep_len=2000,
             save_freq=200,
             seed=0,
             **kwargs):
    # reproducibility
    self.seed = seed
    torch.manual_seed(seed)
    np.random.seed(seed)
    self.device = "cuda" if torch.cuda.is_available() else "cpu"

    # environment, online network, and target network
    self.lr = lr
    self.env_fn = env_fn
    self.env = env_fn()
    self.oc_kwargs = oc_kwargs
    self.network_fn = self.get_network_fn(self.oc_kwargs)
    self.network = self.network_fn().to(self.device)
    self.target_network = self.network_fn().to(self.device)
    self.optimizer_class = optimizer_class
    self.optimizer = optimizer_class(self.network.parameters(), self.lr)
    self.target_network.load_state_dict(self.network.state_dict())

    # epsilon-greedy schedule for option selection
    self.eps_start = eps_start
    self.eps_end = eps_end
    self.eps_decay = eps_decay
    self.eps_schedule = LinearSchedule(eps_start, eps_end, eps_decay)

    # training hyperparameters
    self.gamma = gamma
    self.rollout_length = rollout_length
    self.num_options = oc_kwargs['num_options']
    self.beta_reg = beta_reg
    self.entropy_weight = entropy_weight
    self.gradient_clip = gradient_clip
    self.target_network_update_freq = target_network_update_freq
    self.max_ep_len = max_ep_len
    self.save_freq = save_freq

    # logging
    self.save_dir = save_dir
    self.logger = Logger(**logger_kwargs)
    self.tensorboard_logdir = tensorboard_logdir
    # self.tensorboard_logger = SummaryWriter(log_dir=tensorboard_logdir)

    # option bookkeeping: every trajectory starts in an "initial" state
    self.is_initial_states = to_tensor(np.ones((1))).byte()
    self.prev_options = self.is_initial_states.clone().long().to(self.device)
    self.best_mean_reward = -np.inf
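# `LinearSchedule` (the epsilon annealing helper constructed above) is defined
# elsewhere in the repository. A minimal sketch, under the assumption that each
# call returns the current value and decays it linearly from `start` to `end`
# over `steps` calls:
class LinearSchedule:
    def __init__(self, start, end, steps):
        self.current = start
        self.end = end
        self.inc = (end - start) / float(steps)  # per-call increment (negative when decaying)

    def __call__(self):
        val = self.current
        nxt = self.current + self.inc
        self.current = max(nxt, self.end) if self.inc < 0 else min(nxt, self.end)
        return val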
def forward(self, obs, unsqueeze=True):
    obs = to_tensor(obs).to(self.device)
    if unsqueeze:
        obs = obs.unsqueeze(0)
    phi = self.phi_body(obs)

    # per-option Gaussian heads and termination probabilities
    mean = []
    std = []
    beta = []
    for option in self.options:
        prediction = option(phi)
        mean.append(prediction['mean'].unsqueeze(1))
        std.append(prediction['std'].unsqueeze(1))
        beta.append(prediction['beta'])
    mean = torch.cat(mean, dim=1)
    std = torch.cat(std, dim=1)
    beta = torch.cat(beta, dim=1)

    # master (inter-option) policy
    phi_a = self.actor_body(phi)
    phi_a = self.fc_pi_o(phi_a)
    pi_o = F.softmax(phi_a, dim=-1)
    log_pi_o = F.log_softmax(phi_a, dim=-1)

    # critic: option values Q(s, o)
    phi_c = self.critic_body(phi)
    q_o = self.fc_q_o(phi_c)

    return {
        'mean': mean,
        'std': std,
        'q_o': q_o,
        'inter_pi': pi_o,
        'log_inter_pi': log_pi_o,
        'beta': beta
    }
def forward(self, x):
    phi = self.body(to_tensor(x).to(self.device))
    q = self.fc_q(phi)                           # option values Q(s, o)
    beta = torch.sigmoid(self.fc_beta(phi))      # per-option termination probabilities
    pi = self.fc_pi(phi)
    pi = pi.view(-1, self.num_options, self.action_dim)
    log_pi = F.log_softmax(pi, dim=-1)           # per-option log-policies over actions
    pi = F.softmax(pi, dim=-1)                   # per-option policies over actions
    return {'q': q, 'beta': beta, 'log_pi': log_pi, 'pi': pi}
def compute_adv(self, storage, mdp):
    v = getattr(storage, 'v_%s' % mdp)
    adv = getattr(storage, 'adv_%s' % mdp)
    all_ret = getattr(storage, 'ret_%s' % mdp)

    # bootstrap from the value of the last state in the rollout
    ret = v[-1].detach()
    advantages = to_tensor(np.zeros((1))).to(self.device)
    for i in reversed(range(self.rollout_length)):
        # discounted return, masked at episode boundaries
        ret = storage.r[i] + self.gamma * storage.m[i] * ret
        if not self.use_gae:
            advantages = ret - v[i].detach()
        else:
            # GAE: exponentially weighted sum of one-step TD errors
            td_error = storage.r[i] + self.gamma * storage.m[i] * v[i + 1] - v[i]
            advantages = advantages * self.gae_tau * self.gamma * storage.m[i] + td_error
        adv[i] = advantages.detach()
        all_ret[i] = ret.detach()
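# The loop above interleaves the discounted-return and GAE recursions over the
# stored rollout. A standalone NumPy sketch of the same computation, included
# purely for illustration: `values` has length T + 1 (it includes the bootstrap
# value), while `rewards` and `masks` have length T.
def gae_returns(rewards, masks, values, gamma=0.99, tau=0.95):
    T = len(rewards)
    adv = np.zeros(T)
    ret = np.zeros(T)
    running_ret = values[-1]
    running_adv = 0.0
    for t in reversed(range(T)):
        running_ret = rewards[t] + gamma * masks[t] * running_ret
        td_error = rewards[t] + gamma * masks[t] * values[t + 1] - values[t]
        running_adv = running_adv * gamma * tau * masks[t] + td_error
        ret[t] = running_ret
        adv[t] = running_adv
    return adv, ret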
def test(self, timesteps=None, render=False, record=False):
    '''
    Test the agent in the environment.

    Args:
        timesteps (int): Number of timesteps to run the environment for.
            The default of None runs until the episode terminates.
        render (bool): If True, render the environment in real time.
        record (bool): If True, save the frames to a .gif file at the end of the episode.

    Returns:
        ep_ret (float): Total reward accumulated over the episode.
        ep_len (int): Length of the episode in timesteps.
    '''
    self.env.training = False
    self.network.eval()
    if render:
        self.env.render('human')

    states, terminals, ep_ret, ep_len = self.env.reset(), False, 0, 0
    is_initial_states = to_tensor(np.ones((1))).byte().to(self.device)
    prev_options = to_tensor(np.zeros((1))).long().to(self.device)
    img = []
    if record:
        img.append(self.env.render('rgb_array'))

    if timesteps is not None:
        for i in range(timesteps):
            # select an option from the high-level policy
            prediction = self.network(states)
            pi_hat = self.compute_pi_hat(prediction, prev_options,
                                         is_initial_states)
            dist = torch.distributions.Categorical(probs=pi_hat)
            options = dist.sample()

            # Gaussian intra-option policy; act greedily with the mean
            mean = prediction['mean'][0, options]
            std = prediction['std'][0, options]
            actions = mean

            next_states, rewards, terminals, _ = self.env.step(to_np(actions[0]))
            is_initial_states = to_tensor(terminals).unsqueeze(-1).byte().to(self.device)
            prev_options = options
            states = next_states
            if record:
                img.append(self.env.render('rgb_array'))
            else:
                self.env.render()
            ep_ret += rewards
            ep_len += 1
    else:
        while not (terminals or (ep_len == self.max_ep_len)):
            # select an option from the high-level policy
            prediction = self.network(states)
            pi_hat = self.compute_pi_hat(prediction, prev_options,
                                         is_initial_states)
            dist = torch.distributions.Categorical(probs=pi_hat)
            options = dist.sample()

            # Gaussian intra-option policy; act greedily with the mean
            mean = prediction['mean'][0, options]
            std = prediction['std'][0, options]
            actions = mean

            next_states, rewards, terminals, _ = self.env.step(to_np(actions[0]))
            is_initial_states = to_tensor(terminals).unsqueeze(-1).byte().to(self.device)
            prev_options = options
            states = next_states
            if record:
                img.append(self.env.render('rgb_array'))
            else:
                self.env.render()
            ep_ret += rewards
            ep_len += 1

    if record:
        # keep every other frame to reduce gif size
        imageio.mimsave(os.path.join(self.save_dir, "recording.gif"),
                        [np.array(frame) for i, frame in enumerate(img) if i % 2 == 0],
                        fps=29)

    self.env.training = True
    return ep_ret, ep_len
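# `compute_pi_hat`, used above and throughout the agent, is defined outside this
# section. The sketch below follows the Double Actor-Critic style high-MDP
# policy: with probability 1 - beta(s, prev_option) the previous option is kept,
# with probability beta a new option is drawn from the master policy
# `inter_pi`, and initial states always draw from the master policy. Treat this
# as an assumption about the helper, not the verbatim implementation.
def compute_pi_hat(self, prediction, prev_option, is_initial_states):
    inter_pi = prediction['inter_pi']                        # (batch, num_options)
    mask = torch.zeros_like(inter_pi)
    mask.scatter_(1, prev_option.view(-1, 1), 1.0)           # one-hot of the previous option
    beta_prev = prediction['beta'].gather(1, prev_option.view(-1, 1))
    pi_hat = (1 - beta_prev) * mask + beta_prev * inter_pi
    is_initial_states = is_initial_states.view(-1, 1).float()
    return is_initial_states * inter_pi + (1 - is_initial_states) * pi_hat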
def learn_one_trial(self, num_timesteps, trial_num=1):
    self.states, ep_ret, ep_len = self.env.reset(), 0, 0
    storage = Storage(self.rollout_length,
                      ['adv_bar', 'adv_hat', 'ret_bar', 'ret_hat'])
    states = self.states
    for timestep in tqdm(range(1, num_timesteps + 1)):
        prediction = self.network(states)

        # high MDP: sample an option from pi_hat
        pi_hat = self.compute_pi_hat(prediction, self.prev_options,
                                     self.is_initial_states)
        dist = torch.distributions.Categorical(probs=pi_hat)
        options = dist.sample()

        # low MDP: sample an action from the selected option's Gaussian policy
        mean = prediction['mean'][0, options]
        std = prediction['std'][0, options]
        dist = torch.distributions.Normal(mean, std)
        actions = dist.sample()

        # action probability under the selected option and the two value estimates
        pi_bar = self.compute_pi_bar(options.unsqueeze(-1), actions,
                                     prediction['mean'], prediction['std'])
        v_bar = prediction['q_o'].gather(1, options.unsqueeze(-1))
        v_hat = (prediction['q_o'] * pi_hat).sum(-1).unsqueeze(-1)

        next_states, rewards, terminals, _ = self.env.step(to_np(actions[0]))
        ep_ret += rewards
        ep_len += 1

        # end of episode handling
        if terminals or ep_len == self.max_ep_len:
            next_states = self.env.reset()
            self.record_online_return(ep_ret, timestep, ep_len)
            ep_ret, ep_len = 0, 0

            # retrieve training reward
            x, y = self.logger.load_results(["EpLen", "EpRet"])
            if len(x) > 0:
                # mean training reward over the last 50 episodes
                mean_reward = np.mean(y[-50:])

                # new best model
                if mean_reward > self.best_mean_reward:
                    print("Num timesteps: {}".format(timestep))
                    print("Best mean reward: {:.2f} - Last mean reward per episode: {:.2f}"
                          .format(self.best_mean_reward, mean_reward))
                    self.best_mean_reward = mean_reward
                    self.save_weights(fname=f"best_{trial_num}.pth")

                if self.env.spec.reward_threshold is not None \
                        and self.best_mean_reward >= self.env.spec.reward_threshold:
                    print("Solved Environment, stopping iteration...")
                    return

        storage.add(prediction)
        storage.add({
            'r': to_tensor(rewards).to(self.device).unsqueeze(-1),
            'm': to_tensor(1 - terminals).to(self.device).unsqueeze(-1),
            'a': actions,
            'o': options.unsqueeze(-1),
            'prev_o': self.prev_options.unsqueeze(-1),
            's': to_tensor(states).unsqueeze(0),
            'init': self.is_initial_states.unsqueeze(-1),
            'pi_hat': pi_hat,
            'log_pi_hat': pi_hat[0, options].add(1e-5).log().unsqueeze(-1),
            'log_pi_bar': pi_bar.add(1e-5).log(),
            'v_bar': v_bar,
            'v_hat': v_hat
        })

        self.is_initial_states = to_tensor(terminals).unsqueeze(-1).to(self.device).byte()
        self.prev_options = options
        states = next_states

        if timestep % self.rollout_length == 0:
            self.states = states

            # bootstrap values for the final state of the rollout
            prediction = self.network(states)
            pi_hat = self.compute_pi_hat(prediction, self.prev_options,
                                         self.is_initial_states)
            dist = torch.distributions.Categorical(pi_hat)
            options = dist.sample()
            v_bar = prediction['q_o'].gather(1, options.unsqueeze(-1))
            v_hat = (prediction['q_o'] * pi_hat).sum(-1).unsqueeze(-1)

            storage.add(prediction)
            storage.add({'v_bar': v_bar, 'v_hat': v_hat})
            storage.placeholder()

            # compute advantages/returns for both MDPs, then update them in random order
            self.compute_adv(storage, 'bar')
            self.compute_adv(storage, 'hat')
            mdps = ['hat', 'bar']
            np.random.shuffle(mdps)
            self.update(storage, mdps[0], timestep)
            self.update(storage, mdps[1], timestep)

            storage = Storage(self.rollout_length,
                              ['adv_bar', 'adv_hat', 'ret_bar', 'ret_hat'])

        if self.save_freq > 0 and timestep % self.save_freq == 0:
            self.save_weights(fname=f"latest_{trial_num}.pth")
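# `compute_pi_bar` (the probability of the sampled action under the selected
# option's Gaussian policy) is also defined outside this section. A plausible
# sketch, assuming it gathers the chosen option's mean/std and returns the joint
# density of the action components as a probability (the caller takes the log):
def compute_pi_bar(self, options, actions, mean, std):
    # mean, std: (batch, num_options, action_dim); options: (batch, 1)
    idx = options.unsqueeze(-1).expand(-1, -1, mean.size(-1))
    mean = mean.gather(1, idx).squeeze(1)
    std = std.gather(1, idx).squeeze(1)
    dist = torch.distributions.Normal(mean, std)
    # product of per-dimension densities
    return dist.log_prob(actions).sum(-1).exp().unsqueeze(-1)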
def __init__(self,
             env_fn,
             save_dir,
             tensorboard_logdir=None,
             optimizer_class=Adam,
             weight_decay=0,
             oc_kwargs=dict(),
             logger_kwargs=dict(),
             lr=1e-3,
             optimization_epochs=5,
             mini_batch_size=64,
             ppo_ratio_clip=0.2,
             gamma=0.99,
             rollout_length=2048,
             beta_weight=0,
             entropy_weight=0.01,
             gradient_clip=5,
             gae_tau=0.95,
             max_ep_len=2000,
             save_freq=200,
             seed=0,
             **kwargs):
    # reproducibility
    self.seed = seed
    torch.manual_seed(seed)
    np.random.seed(seed)
    self.device = "cuda" if torch.cuda.is_available() else "cpu"

    # environment and network
    self.lr = lr
    self.env_fn = env_fn
    self.env = env_fn()
    self.oc_kwargs = oc_kwargs
    self.network_fn = self.get_network_fn(self.oc_kwargs)
    self.network = self.network_fn().to(self.device)
    self.optimizer_class = optimizer_class
    self.weight_decay = weight_decay
    self.optimizer = optimizer_class(self.network.parameters(),
                                     self.lr,
                                     weight_decay=self.weight_decay)

    # training hyperparameters
    self.gamma = gamma
    self.rollout_length = rollout_length
    self.num_options = oc_kwargs['num_options']
    self.beta_weight = beta_weight
    self.entropy_weight = entropy_weight
    self.gradient_clip = gradient_clip
    self.max_ep_len = max_ep_len
    self.save_freq = save_freq

    # logging
    self.save_dir = save_dir
    self.logger = Logger(**logger_kwargs)
    self.tensorboard_logdir = tensorboard_logdir
    # self.tensorboard_logger = SummaryWriter(log_dir=tensorboard_logdir)

    # option bookkeeping
    self.is_initial_states = to_tensor(np.ones((1))).byte().to(self.device)
    self.prev_options = to_tensor(np.zeros((1))).long().to(self.device)
    self.best_mean_reward = -np.inf

    # PPO-specific settings
    self.optimization_epochs = optimization_epochs
    self.mini_batch_size = mini_batch_size
    self.ppo_ratio_clip = ppo_ratio_clip
    self.gae_tau = gae_tau
    self.use_gae = self.gae_tau > 0
def update(self, storage, mdp, timestep, freeze_v=False):
    states, actions, options, log_probs_old, returns, advantages, prev_options, inits, pi_hat, mean, std = \
        storage.cat(['s', 'a', 'o', 'log_pi_%s' % mdp, 'ret_%s' % mdp,
                     'adv_%s' % mdp, 'prev_o', 'init', 'pi_hat', 'mean', 'std'])
    actions = actions.detach()
    log_probs_old = log_probs_old.detach()
    pi_hat = pi_hat.detach()
    mean = mean.detach()
    std = std.detach()

    # normalize advantages
    advantages = (advantages - advantages.mean()) / advantages.std()

    for _ in range(self.optimization_epochs):
        sampler = random_sample(np.arange(states.size(0)), self.mini_batch_size)
        for batch_indices in sampler:
            batch_indices = to_tensor(batch_indices).long()
            sampled_pi_hat = pi_hat[batch_indices]
            sampled_mean = mean[batch_indices]
            sampled_std = std[batch_indices]
            sampled_states = states[batch_indices]
            sampled_prev_o = prev_options[batch_indices]
            sampled_init = inits[batch_indices]
            sampled_options = options[batch_indices]
            sampled_actions = actions[batch_indices]
            sampled_log_probs_old = log_probs_old[batch_indices]
            sampled_returns = returns[batch_indices]
            sampled_advantages = advantages[batch_indices]

            prediction = self.network(sampled_states, unsqueeze=False)

            if mdp == 'hat':
                # high-MDP update: master policy and termination functions
                cur_pi_hat = self.compute_pi_hat(prediction,
                                                 sampled_prev_o.view(-1),
                                                 sampled_init.view(-1))
                entropy = -(cur_pi_hat * cur_pi_hat.add(1e-5).log()).sum(-1).mean()
                log_pi_a = self.compute_log_pi_a(sampled_options, cur_pi_hat,
                                                 sampled_actions, sampled_mean,
                                                 sampled_std, mdp)
                beta_loss = prediction['beta'].mean()
            elif mdp == 'bar':
                # low-MDP update: intra-option Gaussian policies
                log_pi_a = self.compute_log_pi_a(sampled_options, sampled_pi_hat,
                                                 sampled_actions,
                                                 prediction['mean'],
                                                 prediction['std'], mdp)
                entropy = 0
                beta_loss = 0
            else:
                raise NotImplementedError

            if mdp == 'bar':
                v = prediction['q_o'].gather(1, sampled_options)
            elif mdp == 'hat':
                v = (prediction['q_o'] * sampled_pi_hat).sum(-1).unsqueeze(-1)
            else:
                raise NotImplementedError

            # clipped PPO surrogate objective
            ratio = (log_pi_a - sampled_log_probs_old).exp()
            obj = ratio * sampled_advantages
            obj_clipped = ratio.clamp(1.0 - self.ppo_ratio_clip,
                                      1.0 + self.ppo_ratio_clip) * sampled_advantages
            policy_loss = -torch.min(obj, obj_clipped).mean() \
                - self.entropy_weight * entropy + self.beta_weight * beta_loss
            value_loss = 0.5 * (sampled_returns - v).pow(2).mean()

            self.tensorboard_logger.add_scalar(f"loss/{mdp}_value_loss",
                                               value_loss.item(), timestep)
            self.tensorboard_logger.add_scalar(f"loss/{mdp}_policy_loss",
                                               policy_loss.item(), timestep)
            self.tensorboard_logger.add_scalar(
                f"loss/{mdp}_beta_loss",
                beta_loss if isinstance(beta_loss, int) else beta_loss.item(),
                timestep)

            if freeze_v:
                value_loss = 0

            self.optimizer.zero_grad()
            (policy_loss + value_loss).backward()
            nn.utils.clip_grad_norm_(self.network.parameters(), self.gradient_clip)
            self.optimizer.step()
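# `random_sample`, used above to draw shuffled mini-batches of indices, is
# defined elsewhere in the repository. A minimal sketch of the assumed behaviour:
def random_sample(indices, batch_size):
    indices = np.random.permutation(indices)
    n_full = len(indices) // batch_size * batch_size
    for batch in indices[:n_full].reshape(-1, batch_size):
        yield batch
    remainder = len(indices) % batch_size
    if remainder:
        yield indices[-remainder:]   # trailing partial batch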
def learn_one_trial(self, num_timesteps, trial_num=1):
    self.states, ep_ret, ep_len = self.env.reset(), 0, 0
    storage = Storage(self.rollout_length,
                      ['beta', 'o', 'beta_adv', 'prev_o', 'init', 'eps'])
    for timestep in tqdm(range(1, num_timesteps + 1)):
        prediction = self.network(self.states)
        epsilon = self.eps_schedule()

        # select option (epsilon-greedy over Q(s, .) with termination), see the sketch below
        options = self.sample_option(prediction, epsilon, self.prev_options,
                                     self.is_initial_states)

        # intra-option policy of the selected option
        prediction['pi'] = prediction['pi'][0, options]
        prediction['log_pi'] = prediction['log_pi'][0, options]
        dist = torch.distributions.Categorical(probs=prediction['pi'])
        actions = dist.sample()
        entropy = dist.entropy()

        next_states, rewards, terminals, _ = self.env.step(to_np(actions))
        ep_ret += rewards
        ep_len += 1

        # end of episode handling
        if terminals or ep_len == self.max_ep_len:
            next_states = self.env.reset()
            self.record_online_return(ep_ret, timestep, ep_len)
            ep_ret, ep_len = 0, 0

            # retrieve training reward
            x, y = self.logger.load_results(["EpLen", "EpRet"])
            if len(x) > 0:
                # mean training reward over the last 50 episodes
                mean_reward = np.mean(y[-50:])

                # new best model
                if mean_reward > self.best_mean_reward:
                    print("Num timesteps: {}".format(timestep))
                    print("Best mean reward: {:.2f} - Last mean reward per episode: {:.2f}"
                          .format(self.best_mean_reward, mean_reward))
                    self.best_mean_reward = mean_reward
                    self.save_weights(fname=f"best_{trial_num}.pth")

                if self.env.spec.reward_threshold is not None \
                        and self.best_mean_reward >= self.env.spec.reward_threshold:
                    print("Solved Environment, stopping iteration...")
                    return

        storage.add(prediction)
        storage.add({
            'r': to_tensor(rewards).to(self.device).unsqueeze(-1),
            'm': to_tensor(1 - terminals).to(self.device).unsqueeze(-1),
            'o': options.unsqueeze(-1),
            'prev_o': self.prev_options.unsqueeze(-1),
            'ent': entropy,
            'a': actions.unsqueeze(-1),
            'init': self.is_initial_states.unsqueeze(-1).to(self.device).float(),
            'eps': epsilon
        })

        self.is_initial_states = to_tensor(terminals).unsqueeze(-1).byte()
        self.prev_options = options
        self.states = next_states

        # periodically sync the target network
        if timestep % self.target_network_update_freq == 0:
            self.target_network.load_state_dict(self.network.state_dict())

        # update at the end of every rollout
        if timestep % self.rollout_length == 0:
            self.update(storage, self.states, timestep)
            storage = Storage(self.rollout_length,
                              ['beta', 'o', 'beta_adv', 'prev_o', 'init', 'eps'])

        if self.save_freq > 0 and timestep % self.save_freq == 0:
            self.save_weights(fname=f"latest_{trial_num}.pth")
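# `sample_option`, used in the loop above, is defined elsewhere. The sketch
# below follows the standard option-critic rule: if the previous option
# terminates (a Bernoulli draw from beta) or the state is an initial state,
# pick a new option epsilon-greedily over Q(s, .); otherwise keep the previous
# option. Names and tensor shapes here are assumptions, not the actual helper.
def sample_option(self, prediction, epsilon, prev_option, is_initial_states):
    q_o = prediction['q_o']                                    # (batch, num_options)
    prev_option = prev_option.view(-1).to(q_o.device)
    is_initial = is_initial_states.view(-1).bool().to(q_o.device)

    # terminate the previous option with probability beta(s, prev_option)
    beta = prediction['beta'].gather(1, prev_option.view(-1, 1)).view(-1)
    terminate = torch.bernoulli(beta).bool()

    # epsilon-greedy over the option values Q(s, .)
    greedy = q_o.argmax(dim=-1)
    random_o = torch.randint(q_o.size(-1), (q_o.size(0),), device=q_o.device)
    explore = torch.rand(q_o.size(0), device=q_o.device) < epsilon
    new_option = torch.where(explore, random_o, greedy)

    # keep the previous option unless it terminated or the episode just started
    return torch.where(terminate | is_initial, new_option, prev_option)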