def ppo_penalized(
        env_fn,
        actor_critic=ActorCritic,
        ac_kwargs=dict(),
        seed=0,
        episodes_per_epoch=40,
        epochs=500,
        gamma=0.99,
        lam=0.98,
        pi_lr=3e-4,
        vf_lr=1e-3,
        train_v_iters=80,
        train_pi_iters=1,  # NOTE: incredibly important that this be low for penalized learning
        max_ep_len=1000,
        logger_kwargs=dict(),
        clip_ratio=0.2,  # tuned????
        # Cost constraints / penalties:
        cost_lim=25,
        penalty_init=1.,
        penalty_lr=5e-3,
        config_name='standard',
        save_freq=10):

    # W&B Logging
    wandb.login()
    composite_name = 'new_ppo_penalized_' + config_name
    wandb.init(project="LearningCurves", group="PPO Expert", name=composite_name)

    # Special function to avoid certain slowdowns from PyTorch + MPI combo.
    setup_pytorch_for_mpi()

    # Set up logger and save configuration
    logger = EpochLogger(**logger_kwargs)
    logger.save_config(locals())

    seed += 10000 * proc_id()
    torch.manual_seed(seed)
    np.random.seed(seed)

    # Instantiate environment
    env = env_fn()
    obs_dim = env.observation_space.shape
    act_dim = env.action_space.shape
    ac_kwargs['action_space'] = env.action_space

    # Models
    # Create actor-critic module and monitor it
    ac = actor_critic(input_dim=obs_dim[0], **ac_kwargs)

    # Set up model saving
    logger.setup_pytorch_saver(ac)

    # Sync params across processes
    sync_params(ac)

    # Buffers
    local_episodes_per_epoch = int(episodes_per_epoch / num_procs())
    buf = BufferActor(obs_dim[0], act_dim[0], local_episodes_per_epoch, max_ep_len)

    # Count variables
    var_counts = tuple(count_vars(module) for module in [ac.pi, ac.v])
    logger.log('\nNumber of parameters: \t pi: %d, \t v: %d\n' % var_counts)

    # Optimizers
    pi_optimizer = torch.optim.Adam(ac.pi.parameters(), lr=pi_lr)
    vf_optimizer = torch.optim.Adam(ac.v.parameters(), lr=vf_lr)
    # pi_optimizer = AdaBelief(ac.pi.parameters(), betas=(0.9, 0.999), eps=1e-8)
    # vf_optimizer = AdaBelief(ac.v.parameters(), betas=(0.9, 0.999), eps=1e-8)

    # # Parameters Sync
    # sync_all_params(ac.parameters())

    # Set up function for computing PPO policy loss
    def compute_loss_pi(obs, act, adv, logp_old):
        # Clipped PPO surrogate; without clipping, loss_pi = -(logp * adv).mean().
        # TODO: Think about removing clipping
        _, logp, _ = ac.pi(obs, act)
        ratio = torch.exp(logp - logp_old)
        clip_adv = torch.clamp(ratio, 1 - clip_ratio, 1 + clip_ratio) * adv
        loss_pi = -(torch.min(ratio * adv, clip_adv)).mean()
        return loss_pi

    def penalty_update(cur_penalty):
        cur_cost = logger.get_stats('EpCost')[0]
        cur_rew = logger.get_stats('EpRet')[0]

        # Penalty update: dual ascent on the multiplier, projected onto [0, inf)
        cur_penalty = max(0, cur_penalty + penalty_lr * (cur_cost - cost_lim))
        return cur_penalty

    def update(e):
        obs, act, adv, ret, logp_old = [torch.Tensor(x) for x in buf.retrieve_all()]

        # Policy
        _, logp, _ = ac.pi(obs, act)
        entropy = (-logp).mean()

        # Train policy with multiple steps of gradient descent
        for _ in range(train_pi_iters):
            pi_optimizer.zero_grad()
            loss_pi = compute_loss_pi(obs, act, adv, logp_old)
            loss_pi.backward()
            # average_gradients(train_pi.param_groups)
            # mpi_avg_grads(pi_optimizer.param_groups)
            mpi_avg_grads(ac.pi)
            pi_optimizer.step()

        # Value function training
        v = ac.v(obs)
        v_l_old = F.mse_loss(v, ret)  # old loss

        for _ in range(train_v_iters):
            v = ac.v(obs)
            v_loss = F.mse_loss(v, ret)  # how well does the value function predict returns?

            # Value function train
            vf_optimizer.zero_grad()
            v_loss.backward()
            # average_gradients(vf_optimizer.param_groups)
            mpi_avg_grads(ac.v)  # average gradients across MPI processes
            vf_optimizer.step()

        # Log the changes
        _, logp, _, v = ac(obs, act)
        # entropy_new = (-logp).mean()
        pi_loss_new = -(logp * adv).mean()
        v_loss_new = F.mse_loss(v, ret)
        kl = (logp_old - logp).mean()
        logger.store(LossPi=loss_pi, LossV=v_l_old,
                     DeltaLossPi=(pi_loss_new - loss_pi),
                     DeltaLossV=(v_loss_new - v_l_old),
                     Entropy=entropy, KL=kl)

    # Prepare for interaction with the environment
    start_time = time.time()
    o, r, d, ep_ret, ep_cost, ep_len = env.reset(), 0, False, 0, 0, 0
    total_t = 0

    # Initialize penalty
    cur_penalty = np.log(max(np.exp(penalty_init) - 1, 1e-8))

    for epoch in range(epochs):
        ac.eval()  # eval mode

        # Policy rollout
        for _ in range(local_episodes_per_epoch):
            for _ in range(max_ep_len):
                obs = o  # keep the observation the action was computed from
                a, _, logp_t, v_t = ac(torch.Tensor(o.reshape(1, -1)))
                logger.store(VVals=v_t)

                o, r, d, info = env.step(a.detach().numpy()[0])
                c = info.get('cost', 0)

                # Include penalty on cost
                r_total = r - cur_penalty * c
                r_total /= (1 + cur_penalty)

                # Store the pre-step observation with the shaped reward
                buf.store(obs, a.detach().numpy(), r_total, v_t.item(),
                          logp_t.detach().numpy())

                ep_ret += r
                ep_cost += c
                ep_len += 1
                total_t += 1

                terminal = d or (ep_len == max_ep_len)
                if terminal:
                    # buf.end_episode()
                    buf.finish_path()
                    logger.store(EpRet=ep_ret, EpCost=ep_cost, EpLen=ep_len)
                    print("end of episode return: ", ep_ret)

                    episode_metrics = {'average ep ret': ep_ret,
                                       'average ep cost': ep_cost}
                    wandb.log(episode_metrics)

                    o, r, d, ep_ret, ep_cost, ep_len = env.reset(), 0, False, 0, 0, 0

        if (epoch % save_freq == 0) or (epoch == epochs - 1):
            # logger._torch_save(ac, fname="expert_torch_save.pt")
            # logger._torch_save(ac, fname="model.pt")
            logger.save_state({'env': env}, None, None)

        # Update
        ac.train()

        # update penalty
        cur_penalty = penalty_update(cur_penalty)

        # update networks
        update(epoch)

        # Log
        logger.log_tabular('Epoch', epoch)
        # logger.log_tabular('EpRet', with_min_and_max=True)
        logger.log_tabular('EpRet', average_only=True)
        logger.log_tabular('EpCost', average_only=True)
        logger.log_tabular('EpLen', average_only=True)
        logger.log_tabular('VVals', average_only=True)
        # logger.log_tabular('VVals', with_min_and_max=True)
        logger.log_tabular('TotalEnvInteracts', total_t)
        logger.log_tabular('LossPi', average_only=True)
        logger.log_tabular('DeltaLossPi', average_only=True)
        logger.log_tabular('LossV', average_only=True)
        logger.log_tabular('DeltaLossV', average_only=True)
        logger.log_tabular('Entropy', average_only=True)
        # logger.log_tabular('KL', average_only=True)
        logger.log_tabular('Time', time.time() - start_time)
        logger.dump_tabular()

    wandb.finish()
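# Illustrative sketch (hypothetical helper, not called by the training code above,
# relying only on built-ins): ppo_penalized's penalty schedule is projected dual
# ascent on a Lagrange multiplier, lambda <- max(0, lambda + penalty_lr * (episode_cost - cost_lim)),
# while the rollout shapes rewards as (r - lambda * c) / (1 + lambda).
def _lagrangian_penalty_demo(episode_costs=(30.0, 28.0, 24.0, 22.0),
                             cost_lim=25.0, penalty_lr=5e-3, penalty=0.0):
    # The multiplier grows while observed cost exceeds the limit and shrinks
    # (never below zero) once the constraint is satisfied.
    for ep_cost in episode_costs:
        penalty = max(0.0, penalty + penalty_lr * (ep_cost - cost_lim))
    return penalty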
def valor(env_fn,
          actor_critic=ActorCritic,
          ac_kwargs=dict(),
          disc=Discriminator,
          dc_kwargs=dict(),
          seed=0,
          episodes_per_epoch=40,
          epochs=50,
          gamma=0.99,
          pi_lr=3e-4,
          vf_lr=1e-3,
          dc_lr=5e-4,
          train_v_iters=80,
          train_dc_iters=10,
          train_dc_interv=10,
          lam=0.97,
          max_ep_len=1000,
          logger_kwargs=dict(),
          con_dim=5,
          save_freq=10,
          k=1):

    logger = EpochLogger(**logger_kwargs)
    logger.save_config(locals())

    seed += 10000 * proc_id()
    torch.manual_seed(seed)
    np.random.seed(seed)

    env = env_fn()
    obs_dim = env.observation_space.shape
    act_dim = env.action_space.shape
    ac_kwargs['action_space'] = env.action_space

    # Model
    ac = actor_critic(input_dim=obs_dim[0] + con_dim, **ac_kwargs)
    disc = disc(input_dim=obs_dim[0], context_dim=con_dim, **dc_kwargs)

    # Set up model saving
    logger.setup_pytorch_saver([ac, disc])

    # Buffer
    local_episodes_per_epoch = int(episodes_per_epoch / num_procs())
    buffer = VALORBuffer(con_dim, obs_dim[0], act_dim[0], local_episodes_per_epoch,
                         max_ep_len, train_dc_interv)

    # Count variables
    var_counts = tuple(count_vars(module)
                       for module in [ac.policy, ac.value_f, disc.policy])
    logger.log('\nNumber of parameters: \t pi: %d, \t v: %d, \t d: %d\n' % var_counts)

    # Optimizers
    train_pi = torch.optim.Adam(ac.policy.parameters(), lr=pi_lr)
    train_v = torch.optim.Adam(ac.value_f.parameters(), lr=vf_lr)
    train_dc = torch.optim.Adam(disc.policy.parameters(), lr=dc_lr)

    # Parameters Sync
    sync_all_params(ac.parameters())
    sync_all_params(disc.parameters())

    def update(e):
        obs, act, adv, pos, ret, logp_old = [torch.Tensor(x)
                                             for x in buffer.retrieve_all()]

        # Policy
        _, logp, _ = ac.policy(obs, act)
        entropy = (-logp).mean()

        # Policy loss
        pi_loss = -(logp * (k * adv + pos)).mean()

        # Train policy
        train_pi.zero_grad()
        pi_loss.backward()
        average_gradients(train_pi.param_groups)
        train_pi.step()

        # Value function
        v = ac.value_f(obs)
        v_l_old = F.mse_loss(v, ret)
        for _ in range(train_v_iters):
            v = ac.value_f(obs)
            v_loss = F.mse_loss(v, ret)

            # Value function train
            train_v.zero_grad()
            v_loss.backward()
            average_gradients(train_v.param_groups)
            train_v.step()

        # Discriminator
        if (e + 1) % train_dc_interv == 0:
            print('Discriminator Update!')
            con, s_diff = [torch.Tensor(x) for x in buffer.retrieve_dc_buff()]
            _, logp_dc, _ = disc(s_diff, con)
            d_l_old = -logp_dc.mean()

            # Discriminator train
            for _ in range(train_dc_iters):
                _, logp_dc, _ = disc(s_diff, con)
                d_loss = -logp_dc.mean()
                train_dc.zero_grad()
                d_loss.backward()
                average_gradients(train_dc.param_groups)
                train_dc.step()

            _, logp_dc, _ = disc(s_diff, con)
            dc_l_new = -logp_dc.mean()
        else:
            d_l_old = 0
            dc_l_new = 0

        # Log the changes
        _, logp, _, v = ac(obs, act)
        pi_l_new = -(logp * (k * adv + pos)).mean()
        v_l_new = F.mse_loss(v, ret)
        kl = (logp_old - logp).mean()
        logger.store(LossPi=pi_loss, LossV=v_l_old, KL=kl, Entropy=entropy,
                     DeltaLossPi=(pi_l_new - pi_loss),
                     DeltaLossV=(v_l_new - v_l_old),
                     LossDC=d_l_old,
                     DeltaLossDC=(dc_l_new - d_l_old))
        # logger.store(Adv=adv.reshape(-1).numpy().tolist(), Pos=pos.reshape(-1).numpy().tolist())

    start_time = time.time()
    o, r, d, ep_ret, ep_len = env.reset(), 0, False, 0, 0
    context_dist = Categorical(logits=torch.Tensor(np.ones(con_dim)))
    total_t = 0

    for epoch in range(epochs):
        ac.eval()
        disc.eval()
        for _ in range(local_episodes_per_epoch):
            c = context_dist.sample()
            c_onehot = F.one_hot(c, con_dim).squeeze().float()

            for _ in range(max_ep_len):
                concat_obs = torch.cat(
                    [torch.Tensor(o.reshape(1, -1)), c_onehot.reshape(1, -1)], 1)
                a, _, logp_t, v_t = ac(concat_obs)

                buffer.store(c, concat_obs.squeeze().detach().numpy(),
                             a.detach().numpy(), r, v_t.item(),
                             logp_t.detach().numpy())
                logger.store(VVals=v_t)

                o, r, d, _ = env.step(a.detach().numpy()[0])
                ep_ret += r
                ep_len += 1
                total_t += 1

                terminal = d or (ep_len == max_ep_len)
                if terminal:
                    dc_diff = torch.Tensor(buffer.calc_diff()).unsqueeze(0)
                    con = torch.Tensor([float(c)]).unsqueeze(0)
                    _, _, log_p = disc(dc_diff, con)
                    buffer.end_episode(log_p.detach().numpy())
                    logger.store(EpRet=ep_ret, EpLen=ep_len)
                    o, r, d, ep_ret, ep_len = env.reset(), 0, False, 0, 0

        if (epoch % save_freq == 0) or (epoch == epochs - 1):
            logger.save_state({'env': env}, [ac, disc], None)

        # Update
        ac.train()
        disc.train()

        update(epoch)

        # Log
        logger.log_tabular('Epoch', epoch)
        logger.log_tabular('EpRet', with_min_and_max=True)
        logger.log_tabular('EpLen', average_only=True)
        logger.log_tabular('VVals', with_min_and_max=True)
        logger.log_tabular('TotalEnvInteracts', total_t)
        logger.log_tabular('LossPi', average_only=True)
        logger.log_tabular('DeltaLossPi', average_only=True)
        logger.log_tabular('LossV', average_only=True)
        logger.log_tabular('DeltaLossV', average_only=True)
        logger.log_tabular('LossDC', average_only=True)
        logger.log_tabular('DeltaLossDC', average_only=True)
        logger.log_tabular('Entropy', average_only=True)
        logger.log_tabular('KL', average_only=True)
        logger.log_tabular('Time', time.time() - start_time)
        logger.dump_tabular()
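# Illustrative sketch (hypothetical helper; dimensions are placeholders, and it
# uses the same module-level torch / numpy / F / Categorical imports as the code
# above): how valor() conditions the policy on a latent context. A context id is
# drawn from a uniform categorical over con_dim options, one-hot encoded, and
# concatenated to the observation before being fed to the actor-critic.
def _valor_context_demo(obs_dim=8, con_dim=5):
    context_dist = Categorical(logits=torch.Tensor(np.ones(con_dim)))
    c = context_dist.sample()                           # scalar context id
    c_onehot = F.one_hot(c, con_dim).squeeze().float()  # shape (con_dim,)
    o = np.zeros(obs_dim)                               # placeholder observation
    concat_obs = torch.cat(
        [torch.Tensor(o.reshape(1, -1)), c_onehot.reshape(1, -1)], 1)
    return concat_obs.shape                             # torch.Size([1, obs_dim + con_dim])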
def policyg(env_fn,
            actor_critic=ActorCritic,
            ac_kwargs=dict(),
            seed=0,
            episodes_per_epoch=40,
            epochs=500,
            gamma=0.99,
            lam=0.97,
            pi_lr=3e-4,
            vf_lr=1e-3,
            train_v_iters=80,
            max_ep_len=1000,
            logger_kwargs=dict(),
            save_freq=10):

    logger = EpochLogger(**logger_kwargs)
    logger.save_config(locals())

    seed += 10000 * proc_id()
    torch.manual_seed(seed)
    np.random.seed(seed)

    env = env_fn()
    obs_dim = env.observation_space.shape
    act_dim = env.action_space.shape
    ac_kwargs['action_space'] = env.action_space

    # Models
    ac = actor_critic(input_dim=obs_dim[0], **ac_kwargs)

    # Set up model saving
    logger.setup_pytorch_saver(ac)

    # Buffers
    local_episodes_per_epoch = int(episodes_per_epoch / num_procs())
    buf = BufferActor(obs_dim[0], act_dim[0], local_episodes_per_epoch, max_ep_len)

    # Count variables
    var_counts = tuple(count_vars(module) for module in [ac.policy, ac.value_f])
    print("POLICY GRADIENT")
    logger.log('\nNumber of parameters: \t pi: %d, \t v: %d\n' % var_counts)

    # Optimizers
    train_pi = torch.optim.Adam(ac.policy.parameters(), lr=pi_lr)
    train_v = torch.optim.Adam(ac.value_f.parameters(), lr=vf_lr)

    # Parameters Sync
    sync_all_params(ac.parameters())

    def update(e):
        obs, act, adv, ret, lgp_old = [torch.Tensor(x) for x in buf.retrieve_all()]

        # Policy
        _, lgp, _ = ac.policy(obs, act)
        entropy = (-lgp).mean()

        # Policy loss (plain policy gradient, no entropy bonus)
        pi_loss = -(lgp * adv).mean()

        # Train policy
        train_pi.zero_grad()
        pi_loss.backward()
        average_gradients(train_pi.param_groups)
        train_pi.step()

        # Value function
        v = ac.value_f(obs)
        v_l_old = F.mse_loss(v, ret)
        for _ in range(train_v_iters):
            v = ac.value_f(obs)
            v_loss = F.mse_loss(v, ret)

            # Value function train
            train_v.zero_grad()
            v_loss.backward()
            average_gradients(train_v.param_groups)
            train_v.step()

        # Log the changes
        _, lgp, _, v = ac(obs, act)
        entropy_new = (-lgp).mean()
        pi_loss_new = -(lgp * adv).mean()
        v_loss_new = F.mse_loss(v, ret)
        kl = (lgp_old - lgp).mean()
        logger.store(LossPi=pi_loss, LossV=v_l_old,
                     DeltaLossPi=(pi_loss_new - pi_loss),
                     DeltaLossV=(v_loss_new - v_l_old),
                     Entropy=entropy, KL=kl)

    start_time = time.time()
    o, r, d, ep_ret, ep_len = env.reset(), 0, False, 0, 0
    total_t = 0

    for epoch in range(epochs):
        ac.eval()

        # Policy rollout
        for _ in range(local_episodes_per_epoch):
            for _ in range(max_ep_len):
                obs = torch.Tensor(o.reshape(1, -1))
                a, _, logp_t, v_t = ac(obs)

                buf.store(o, a.detach().numpy(), r, v_t.item(),
                          logp_t.detach().numpy())
                logger.store(VVals=v_t)

                o, r, d, _ = env.step(a.detach().numpy()[0])
                ep_ret += r
                ep_len += 1
                total_t += 1

                terminal = d or (ep_len == max_ep_len)
                if terminal:
                    buf.end_episode()
                    logger.store(EpRet=ep_ret, EpLen=ep_len)
                    o, r, d, ep_ret, ep_len = env.reset(), 0, False, 0, 0

        if (epoch % save_freq == 0) or (epoch == epochs - 1):
            # logger._torch_save(ac, fname="expert_torch_save.pt")
            # logger._torch_save(ac, fname="model.pt")
            logger.save_state({'env': env}, None, None)

        # Update
        ac.train()

        update(epoch)

        # Log
        logger.log_tabular('Epoch', epoch)
        logger.log_tabular('EpRet', with_min_and_max=True)
        logger.log_tabular('EpLen', average_only=True)
        logger.log_tabular('VVals', with_min_and_max=True)
        logger.log_tabular('TotalEnvInteracts', total_t)
        logger.log_tabular('LossPi', average_only=True)
        logger.log_tabular('DeltaLossPi', average_only=True)
        logger.log_tabular('LossV', average_only=True)
        logger.log_tabular('DeltaLossV', average_only=True)
        logger.log_tabular('Entropy', average_only=True)
        logger.log_tabular('KL', average_only=True)
        logger.log_tabular('Time', time.time() - start_time)
        logger.dump_tabular()
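# Illustrative sketch (hypothetical helper with dummy tensors, not part of the
# training code): the policy update in policyg() is vanilla policy gradient with
# a value baseline; the loss is the negative log-probability of each action
# weighted by its advantage.
def _pg_loss_demo(batch=32):
    logp = torch.randn(batch, requires_grad=True)  # stand-in for log pi(a|s)
    adv = torch.randn(batch)                       # stand-in for GAE advantages
    pi_loss = -(logp * adv).mean()
    pi_loss.backward()                             # gradients flow only through logp
    return pi_loss.item()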
def gail(env_fn,
         actor_critic=ActorCritic,
         ac_kwargs=dict(),
         disc=Discriminator,
         dc_kwargs=dict(),
         seed=0,
         episodes_per_epoch=40,
         epochs=500,
         gamma=0.99,
         lam=0.97,
         pi_lr=3e-3,
         vf_lr=3e-3,
         dc_lr=5e-4,
         train_v_iters=80,
         train_dc_iters=80,
         max_ep_len=1000,
         logger_kwargs=dict(),
         save_freq=10):

    l_lam = 0  # balances the two loss terms

    print("starting now")

    logger = EpochLogger(**logger_kwargs)
    logger.save_config(locals())

    seed += 10000 * proc_id()
    torch.manual_seed(seed)
    np.random.seed(seed)

    env = env_fn()
    obs_dim = env.observation_space.shape
    act_dim = env.action_space.shape
    ac_kwargs['action_space'] = env.action_space

    # Models
    ac = actor_critic(input_dim=obs_dim[0], **ac_kwargs)
    disc = disc(input_dim=obs_dim[0], **dc_kwargs)

    # Set up model saving
    logger.setup_pytorch_saver([ac, disc])

    # TODO: Load expert policy here
    expert = actor_critic(input_dim=obs_dim[0], **ac_kwargs)
    # expert_name = "expert_torch_save.pt"
    expert_name = "model.pt"
    # expert = torch.load(osp.join(logger_kwargs['output_dir'], 'pyt_save', expert_name))
    expert = torch.load(
        '/home/tyna/Documents/openai/research-project/data/anonymous-expert/anonymous-expert_s0/pyt_save/model.pt'
    )

    print('RUNNING GAIL')

    # Buffers
    local_episodes_per_epoch = int(episodes_per_epoch / num_procs())
    buff_s = BufferS(obs_dim[0], act_dim[0], local_episodes_per_epoch, max_ep_len)
    buff_t = BufferT(obs_dim[0], act_dim[0], local_episodes_per_epoch, max_ep_len)

    # Count variables
    var_counts = tuple(count_vars(module)
                       for module in [ac.policy, ac.value_f, disc.policy])
    print("GAIL")
    logger.log('\nNumber of parameters: \t pi: %d, \t v: %d, \t d: %d\n' % var_counts)

    # Optimizers
    train_pi = torch.optim.Adam(ac.policy.parameters(), lr=pi_lr)
    train_v = torch.optim.Adam(ac.value_f.parameters(), lr=vf_lr)
    train_dc = torch.optim.Adam(disc.policy.parameters(), lr=dc_lr)

    # Parameters Sync
    sync_all_params(ac.parameters())
    sync_all_params(disc.parameters())

    def update(e):
        obs_s, act, adv, ret, lgp_old = [torch.Tensor(x)
                                         for x in buff_s.retrieve_all()]
        obs_t, _ = [torch.Tensor(x) for x in buff_t.retrieve_all()]

        # Policy
        _, lgp, _ = ac.policy(obs_s, act)
        entropy = (-lgp).mean()

        # Policy loss: policy gradient term + entropy term
        pi_loss = -(lgp * adv).mean() - l_lam * entropy

        # Train policy
        if e > 10:
            train_pi.zero_grad()
            pi_loss.backward()
            average_gradients(train_pi.param_groups)
            train_pi.step()

        # Value function
        v = ac.value_f(obs_s)
        v_l_old = F.mse_loss(v, ret)
        for _ in range(train_v_iters):
            v = ac.value_f(obs_s)
            v_loss = F.mse_loss(v, ret)

            # Value function train
            train_v.zero_grad()
            v_loss.backward()
            average_gradients(train_v.param_groups)
            train_v.step()

        # Discriminator
        gt1 = torch.ones(obs_s.size()[0], dtype=torch.int)
        gt2 = torch.zeros(obs_t.size()[0], dtype=torch.int)
        _, lgp_s, _ = disc(obs_s, gt=gt1)
        _, lgp_t, _ = disc(obs_t, gt=gt2)
        dc_loss_old = -lgp_s.mean() - lgp_t.mean()

        for _ in range(train_dc_iters):
            _, lgp_s, _ = disc(obs_s, gt=gt1)
            _, lgp_t, _ = disc(obs_t, gt=gt2)
            dc_loss = -lgp_s.mean() - lgp_t.mean()

            # Discriminator train
            train_dc.zero_grad()
            dc_loss.backward()
            average_gradients(train_dc.param_groups)
            train_dc.step()

        _, lgp_s, _ = disc(obs_s, gt=gt1)
        _, lgp_t, _ = disc(obs_t, gt=gt2)
        dc_loss_new = -lgp_s.mean() - lgp_t.mean()

        # Log the changes
        _, lgp, _, v = ac(obs_s, act)
        entropy_new = (-lgp).mean()
        pi_loss_new = -(lgp * adv).mean() - l_lam * entropy
        v_loss_new = F.mse_loss(v, ret)
        kl = (lgp_old - lgp).mean()
        logger.store(LossPi=pi_loss, LossV=v_l_old, LossDC=dc_loss_old,
                     DeltaLossPi=(pi_loss_new - pi_loss),
                     DeltaLossV=(v_loss_new - v_l_old),
                     DeltaLossDC=(dc_loss_new - dc_loss_old),
                     DeltaEnt=(entropy_new - entropy),
                     Entropy=entropy, KL=kl)

    start_time = time.time()
    o, r, sdr, d, ep_ret, ep_sdr, ep_len = env.reset(), 0, 0, False, 0, 0, 0
    total_t = 0
    ep_len_t = 0

    for epoch in range(epochs):
        ac.eval()
        disc.eval()
        # The probability term at index [0] corresponds to the teacher's policy.

        # Student's policy rollout
        for _ in range(local_episodes_per_epoch):
            for _ in range(max_ep_len):
                obs = torch.Tensor(o.reshape(1, -1))
                a, _, logp_t, v_t = ac(obs)

                buff_s.store(o, a.detach().numpy(), r, sdr, v_t.item(),
                             logp_t.detach().numpy())
                logger.store(VVals=v_t)

                o, r, d, _ = env.step(a.detach().numpy()[0])
                _, sdr, _ = disc(torch.Tensor(o.reshape(1, -1)), gt=torch.Tensor([0]))

                if sdr < -4:  # Truncate rewards
                    sdr = -4

                ep_ret += r
                ep_sdr += sdr
                ep_len += 1
                total_t += 1

                terminal = d or (ep_len == max_ep_len)
                if terminal:
                    buff_s.end_episode()
                    logger.store(EpRetS=ep_ret, EpLenS=ep_len, EpSdrS=ep_sdr)
                    print("Student Episode Return: \t", ep_ret)
                    o, r, sdr, d, ep_ret, ep_sdr, ep_len = env.reset(), 0, 0, False, 0, 0, 0

        # Teacher's policy rollout
        for _ in range(local_episodes_per_epoch):
            for _ in range(max_ep_len):
                obs = torch.Tensor(o.reshape(1, -1))
                a, _, _, _ = expert(obs)

                buff_t.store(o, a.detach().numpy(), r)

                o, r, d, _ = env.step(a.detach().numpy()[0])
                ep_ret += r
                ep_len += 1
                total_t += 1

                terminal = d or (ep_len == max_ep_len)
                if terminal:
                    buff_t.end_episode()
                    logger.store(EpRetT=ep_ret, EpLenT=ep_len)
                    print("Teacher Episode Return: \t", ep_ret)
                    o, r, d, ep_ret, ep_len = env.reset(), 0, False, 0, 0

        if (epoch % save_freq == 0) or (epoch == epochs - 1):
            logger.save_state({'env': env}, [ac, disc], None)

        # Update
        ac.train()
        disc.train()

        update(epoch)

        # Log
        logger.log_tabular('Epoch', epoch)
        logger.log_tabular('EpRetS', with_min_and_max=True)
        logger.log_tabular('EpSdrS', with_min_and_max=True)
        logger.log_tabular('EpLenS', average_only=True)
        logger.log_tabular('EpRetT', with_min_and_max=True)
        logger.log_tabular('EpLenT', average_only=True)
        logger.log_tabular('VVals', with_min_and_max=True)
        logger.log_tabular('TotalEnvInteracts', total_t)
        logger.log_tabular('LossPi', average_only=True)
        logger.log_tabular('DeltaLossPi', average_only=True)
        logger.log_tabular('LossV', average_only=True)
        logger.log_tabular('DeltaLossV', average_only=True)
        logger.log_tabular('LossDC', average_only=True)
        logger.log_tabular('DeltaLossDC', average_only=True)
        logger.log_tabular('Entropy', average_only=True)
        logger.log_tabular('DeltaEnt', average_only=True)
        logger.log_tabular('KL', average_only=True)
        logger.log_tabular('Time', time.time() - start_time)
        logger.dump_tabular()
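# Illustrative sketch (hypothetical stand-in module, not the Discriminator class
# used above): the discriminator objective in gail() is the summed negative
# log-likelihood of labelling student states as class 1 and teacher (expert)
# states as class 0, i.e. a binary cross-entropy over the two batches.
def _gail_discriminator_loss_demo(obs_dim=8, batch=16):
    classifier = torch.nn.Linear(obs_dim, 1)  # toy stand-in for disc.policy
    obs_s = torch.randn(batch, obs_dim)       # student observations
    obs_t = torch.randn(batch, obs_dim)       # teacher observations
    logit_s, logit_t = classifier(obs_s), classifier(obs_t)
    dc_loss = (F.binary_cross_entropy_with_logits(logit_s, torch.ones_like(logit_s))
               + F.binary_cross_entropy_with_logits(logit_t, torch.zeros_like(logit_t)))
    return dc_loss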
def gail_penalized(
        env_fn,
        actor_critic=ActorCritic,
        ac_kwargs=dict(),
        disc=Discriminator,
        dc_kwargs=dict(),
        seed=0,
        episodes_per_epoch=40,
        epochs=500,
        gamma=0.99,
        lam=0.97,
        # Cost constraints / penalties:
        cost_lim=25,
        penalty_init=1.,
        penalty_lr=5e-3,
        clip_ratio=0.2,
        pi_lr=3e-3,
        vf_lr=3e-3,
        dc_lr=5e-4,
        train_v_iters=80,
        train_pi_iters=80,
        train_dc_iters=80,
        max_ep_len=1000,
        logger_kwargs=dict(),
        config_name='standard',
        save_freq=10):

    # W&B Logging
    wandb.login()
    composite_name = 'new_gail_penalized_' + config_name
    wandb.init(project="LearningCurves", group="GAIL Clone", name=composite_name)

    # Special function to avoid certain slowdowns from PyTorch + MPI combo.
    setup_pytorch_for_mpi()

    l_lam = 0  # balances the two loss terms

    # Set up logger and save configuration
    logger = EpochLogger(**logger_kwargs)
    logger.save_config(locals())

    seed += 10000 * proc_id()
    torch.manual_seed(seed)
    np.random.seed(seed)

    # Instantiate environment
    env = env_fn()
    obs_dim = env.observation_space.shape
    act_dim = env.action_space.shape
    ac_kwargs['action_space'] = env.action_space

    # Models
    # Create actor-critic and discriminator modules
    ac = actor_critic(input_dim=obs_dim[0], **ac_kwargs)
    discrim = disc(input_dim=obs_dim[0], **dc_kwargs)

    # Set up model saving
    logger.setup_pytorch_saver([ac, discrim])

    # Sync params across processes
    sync_params(ac)
    sync_params(discrim)

    # Load expert policy here
    expert = actor_critic(input_dim=obs_dim[0], **ac_kwargs)
    # expert_name = "expert_torch_save.pt"
    expert_name = "model.pt"
    # expert = torch.load(osp.join(logger_kwargs['output_dir'], 'pyt_save', expert_name))
    # expert = torch.load('/home/tyna/Documents/openai/research-project/data/anonymous-expert/anonymous-expert_s0/pyt_save/model.pt')
    expert = torch.load(
        '/home/tyna/Documents/openai/research-project/data/test-pen-ppo/test-pen-ppo_s0/pyt_save/model.pt'
    )

    print('RUNNING GAIL')

    # Buffers
    local_episodes_per_epoch = int(episodes_per_epoch / num_procs())
    buff_s = BufferStudent(obs_dim[0], act_dim[0], local_episodes_per_epoch, max_ep_len)
    buff_t = BufferTeacher(obs_dim[0], act_dim[0], local_episodes_per_epoch, max_ep_len)

    # Count variables
    var_counts = tuple(count_vars(module) for module in [ac.pi, ac.v, discrim.pi])
    logger.log('\nNumber of parameters: \t pi: %d, \t v: %d, \t d: %d\n' % var_counts)

    # Optimizers
    pi_optimizer = torch.optim.Adam(ac.pi.parameters(), lr=pi_lr)
    vf_optimizer = torch.optim.Adam(ac.v.parameters(), lr=vf_lr)
    discrim_optimizer = torch.optim.Adam(discrim.pi.parameters(), lr=dc_lr)

    # # Parameters Sync
    # sync_all_params(ac.parameters())
    # sync_all_params(disc.parameters())

    # Set up function for computing PPO policy loss
    def compute_loss_pi(obs, act, adv, logp_old):
        # Clipped PPO surrogate; without clipping, loss_pi = -(logp * adv).mean().
        # TODO: Think about removing clipping
        _, logp, _ = ac.pi(obs, act)
        ratio = torch.exp(logp - logp_old)
        clip_adv = torch.clamp(ratio, 1 - clip_ratio, 1 + clip_ratio) * adv
        loss_pi = -(torch.min(ratio * adv, clip_adv)).mean()
        return loss_pi

    def penalty_update(cur_penalty):
        cur_cost = logger.get_stats('EpCostS')[0]
        cur_rew = logger.get_stats('EpRetS')[0]

        # Penalty update: dual ascent on the multiplier, projected onto [0, inf)
        cur_penalty = max(0, cur_penalty + penalty_lr * (cur_cost - cost_lim))
        return cur_penalty

    def update(e):
        obs_s, act, adv, ret, log_pi_old = [torch.Tensor(x)
                                            for x in buff_s.retrieve_all()]
        obs_t, _ = [torch.Tensor(x) for x in buff_t.retrieve_all()]

        # Policy
        _, logp, _ = ac.pi(obs_s, act)
        entropy = (-logp).mean()

        # Unclipped policy loss (policy gradient term + entropy term), for reference:
        # loss_pi = -(logp * adv).mean() - l_lam * entropy

        # Train policy
        if e > 10:
            # Train policy with multiple steps of gradient descent
            for _ in range(train_pi_iters):
                pi_optimizer.zero_grad()
                loss_pi = compute_loss_pi(obs_s, act, adv, log_pi_old)
                loss_pi.backward()
                mpi_avg_grads(ac.pi)
                pi_optimizer.step()

        # Value function
        v = ac.v(obs_s)
        v_l_old = F.mse_loss(v, ret)
        for _ in range(train_v_iters):
            v = ac.v(obs_s)
            v_loss = F.mse_loss(v, ret)

            # Value function train
            vf_optimizer.zero_grad()
            v_loss.backward()
            mpi_avg_grads(ac.v)  # average gradients across MPI processes
            vf_optimizer.step()

        # Discriminator
        gt1 = torch.ones(obs_s.size()[0], dtype=torch.int)
        gt2 = torch.zeros(obs_t.size()[0], dtype=torch.int)
        _, logp_student, _ = discrim(obs_s, gt=gt1)
        _, logp_teacher, _ = discrim(obs_t, gt=gt2)
        discrim_l_old = -logp_student.mean() - logp_teacher.mean()

        for _ in range(train_dc_iters):
            _, logp_student, _ = discrim(obs_s, gt=gt1)
            _, logp_teacher, _ = discrim(obs_t, gt=gt2)
            dc_loss = -logp_student.mean() - logp_teacher.mean()

            # Discriminator train
            discrim_optimizer.zero_grad()
            dc_loss.backward()
            # average_gradients(discrim_optimizer.param_groups)
            mpi_avg_grads(discrim.pi)
            discrim_optimizer.step()

        _, logp_student, _ = discrim(obs_s, gt=gt1)
        _, logp_teacher, _ = discrim(obs_t, gt=gt2)
        dc_loss_new = -logp_student.mean() - logp_teacher.mean()

        # Log the changes
        _, logp, _, v = ac(obs_s, act)
        entropy_new = (-logp).mean()
        pi_loss_new = -(logp * adv).mean() - l_lam * entropy
        v_loss_new = F.mse_loss(v, ret)
        kl = (log_pi_old - logp).mean()
        logger.store(
            # LossPi=loss_pi,
            LossV=v_l_old,
            LossDC=discrim_l_old,
            # DeltaLossPi=(pi_loss_new - loss_pi),
            DeltaLossV=(v_loss_new - v_l_old),
            DeltaLossDC=(dc_loss_new - discrim_l_old),
            DeltaEnt=(entropy_new - entropy),
            Entropy=entropy,
            KL=kl)

    start_time = time.time()
    o, r, sdr, d, ep_ret, ep_cost, ep_sdr, ep_len = env.reset(), 0, 0, False, 0, 0, 0, 0
    total_t = 0
    ep_len_t = 0

    # Initialize penalty
    cur_penalty = np.log(max(np.exp(penalty_init) - 1, 1e-8))

    for epoch in range(epochs):
        ac.eval()
        discrim.eval()
        # The probability term at index [0] corresponds to the teacher's policy.

        # Student's policy rollout
        for _ in range(local_episodes_per_epoch):
            for _ in range(max_ep_len):
                obs = torch.Tensor(o.reshape(1, -1))
                a, _, logp_t, v_t = ac(obs)

                buff_s.store(o, a.detach().numpy(), r, sdr, v_t.item(),
                             logp_t.detach().numpy())
                logger.store(VVals=v_t)

                o, r, d, info = env.step(a.detach().numpy()[0])
                # print("INFO: ", info)
                c = info.get("cost", 0)

                _, sdr, _ = discrim(torch.Tensor(o.reshape(1, -1)), gt=torch.Tensor([0]))

                if sdr < -4:  # Truncate rewards
                    sdr = -4

                ep_ret += r
                ep_cost += c
                ep_sdr += sdr
                ep_len += 1
                total_t += 1

                terminal = d or (ep_len == max_ep_len)
                if terminal:
                    buff_s.end_episode()
                    logger.store(EpRetS=ep_ret, EpCostS=ep_cost, EpLenS=ep_len,
                                 EpSdrS=ep_sdr)
                    print("Student Episode Return: \t", ep_ret)
                    o, r, sdr, d, ep_ret, ep_cost, ep_sdr, ep_len = env.reset(), 0, 0, False, 0, 0, 0, 0

        # Teacher's policy rollout
        for _ in range(local_episodes_per_epoch):
            for _ in range(max_ep_len):
                a, _, _, _ = expert(torch.Tensor(o.reshape(1, -1)))

                buff_t.store(o, a.detach().numpy(), r)

                o, r, d, info = env.step(a.detach().numpy()[0])
                c = info.get("cost", 0)

                ep_ret += r
                ep_cost += c
                ep_len += 1
                total_t += 1

                terminal = d or (ep_len == max_ep_len)
                if terminal:
                    buff_t.end_episode()
                    logger.store(EpRetT=ep_ret, EpCostT=ep_cost, EpLenT=ep_len)
                    print("Teacher Episode Return: \t", ep_ret)
                    o, r, d, ep_ret, ep_cost, ep_len = env.reset(), 0, False, 0, 0, 0

        if (epoch % save_freq == 0) or (epoch == epochs - 1):
            logger.save_state({'env': env}, [ac, discrim], None)

        # Update
        ac.train()
        discrim.train()

        # update penalty
        cur_penalty = penalty_update(cur_penalty)

        # update networks
        update(epoch)

        # Log
        logger.log_tabular('Epoch', epoch)
        logger.log_tabular('EpRetS', average_only=True)
        logger.log_tabular('EpSdrS', average_only=True)
        logger.log_tabular('EpLenS', average_only=True)
        logger.log_tabular('EpRetT', average_only=True)
        logger.log_tabular('EpLenT', average_only=True)
        logger.log_tabular('VVals', with_min_and_max=True)
        logger.log_tabular('TotalEnvInteracts', total_t)
        # logger.log_tabular('LossPi', average_only=True)
        # logger.log_tabular('DeltaLossPi', average_only=True)
        logger.log_tabular('LossV', average_only=True)
        logger.log_tabular('DeltaLossV', average_only=True)
        logger.log_tabular('LossDC', average_only=True)
        logger.log_tabular('DeltaLossDC', average_only=True)
        # logger.log_tabular('Entropy', average_only=True)
        # logger.log_tabular('DeltaEnt', average_only=True)
        # logger.log_tabular('KL', average_only=True)
        logger.log_tabular('Time', time.time() - start_time)
        logger.dump_tabular()
def valor_penalized(
        env_fn,
        actor_critic=ActorCritic,
        ac_kwargs=dict(),
        disc=ValorDiscriminator,
        dc_kwargs=dict(),
        seed=0,
        episodes_per_epoch=40,
        epochs=50,
        gamma=0.99,
        pi_lr=3e-4,
        vf_lr=1e-3,
        dc_lr=5e-4,
        train_pi_iters=1,
        train_v_iters=80,
        train_dc_iters=10,
        train_dc_interv=10,
        lam=0.97,
        # Cost constraints / penalties:
        cost_lim=25,
        penalty_init=1.,
        penalty_lr=5e-3,
        clip_ratio=0.2,
        max_ep_len=1000,
        logger_kwargs=dict(),
        con_dim=5,
        config_name='standard',
        save_freq=10,
        k=1):

    # W&B Logging
    wandb.login()
    composite_name = 'new_valor_penalized_' + config_name
    wandb.init(project="LearningCurves", group="VALOR Expert", name=composite_name)

    # Special function to avoid certain slowdowns from PyTorch + MPI combo.
    setup_pytorch_for_mpi()

    # Set up logger and save configuration
    logger = EpochLogger(**logger_kwargs)
    logger.save_config(locals())

    seed += 10000 * proc_id()
    torch.manual_seed(seed)
    np.random.seed(seed)

    # Instantiate environment
    env = env_fn()
    obs_dim = env.observation_space.shape
    act_dim = env.action_space.shape
    ac_kwargs['action_space'] = env.action_space

    # Model
    # Create actor-critic and discriminator modules and monitor them
    ac = actor_critic(input_dim=obs_dim[0] + con_dim, **ac_kwargs)
    discrim = disc(input_dim=obs_dim[0], context_dim=con_dim, **dc_kwargs)

    # Set up model saving
    logger.setup_pytorch_saver([ac, discrim])

    # Sync params across processes
    sync_params(ac)
    sync_params(discrim)

    # Buffer
    local_episodes_per_epoch = int(episodes_per_epoch / num_procs())
    buffer = VALORBuffer(con_dim, obs_dim[0], act_dim[0], local_episodes_per_epoch,
                         max_ep_len, train_dc_interv)

    # Count variables
    var_counts = tuple(count_vars(module) for module in [ac.pi, ac.v, discrim.pi])
    logger.log('\nNumber of parameters: \t pi: %d, \t v: %d, \t d: %d\n' % var_counts)

    # Optimizers
    pi_optimizer = torch.optim.Adam(ac.pi.parameters(), lr=pi_lr)
    vf_optimizer = torch.optim.Adam(ac.v.parameters(), lr=vf_lr)
    discrim_optimizer = torch.optim.Adam(discrim.pi.parameters(), lr=dc_lr)

    # Set up function for computing PPO policy loss
    def compute_loss_pi(obs, act, adv, logp_old):
        # Clipped PPO surrogate; without clipping, loss_pi = -(logp * adv).mean().
        # TODO: Think about removing clipping
        _, logp, _ = ac.pi(obs, act)
        ratio = torch.exp(logp - logp_old)
        clip_adv = torch.clamp(ratio, 1 - clip_ratio, 1 + clip_ratio) * adv
        loss_pi = -(torch.min(ratio * adv, clip_adv)).mean()
        return loss_pi

    # # Parameters Sync
    # sync_all_params(ac.parameters())
    # sync_all_params(disc.parameters())

    def penalty_update(cur_penalty):
        cur_cost = logger.get_stats('EpCost')[0]
        cur_rew = logger.get_stats('EpRet')[0]

        # Penalty update: dual ascent on the multiplier, projected onto [0, inf)
        cur_penalty = max(0, cur_penalty + penalty_lr * (cur_cost - cost_lim))
        return cur_penalty

    def update(e):
        obs, act, adv, pos, ret, logp_old = [torch.Tensor(x)
                                             for x in buffer.retrieve_all()]

        # Policy
        _, logp, _ = ac.pi(obs, act)
        entropy = (-logp).mean()

        # Train policy with multiple steps of gradient descent
        for _ in range(train_pi_iters):
            pi_optimizer.zero_grad()
            loss_pi = compute_loss_pi(obs, act, adv, logp_old)
            loss_pi.backward()
            mpi_avg_grads(ac.pi)
            pi_optimizer.step()

        # Value function
        v = ac.v(obs)
        v_l_old = F.mse_loss(v, ret)
        for _ in range(train_v_iters):
            v = ac.v(obs)
            v_loss = F.mse_loss(v, ret)

            # Value function train
            vf_optimizer.zero_grad()
            v_loss.backward()
            mpi_avg_grads(ac.v)
            vf_optimizer.step()

        # Discriminator
        if (e + 1) % train_dc_interv == 0:
            print('Discriminator Update!')
            # Remove BiLSTM, take FFNN, take state_diff and predict what the context was
            # Predict the context from the tuple (or just from the current state)
            con, s_diff = [torch.Tensor(x) for x in buffer.retrieve_dc_buff()]
            print("s diff: ", s_diff)

            _, logp_dc, _ = discrim(s_diff, con)
            d_l_old = -logp_dc.mean()

            # Discriminator train
            for _ in range(train_dc_iters):
                _, logp_dc, _ = discrim(s_diff, con)
                d_loss = -logp_dc.mean()
                discrim_optimizer.zero_grad()
                d_loss.backward()
                mpi_avg_grads(discrim.pi)
                discrim_optimizer.step()

            _, logp_dc, _ = discrim(s_diff, con)
            dc_l_new = -logp_dc.mean()
        else:
            d_l_old = 0
            dc_l_new = 0

        # Log the changes
        _, logp, _, v = ac(obs, act)
        pi_l_new = -(logp * (k * adv + pos)).mean()
        v_l_new = F.mse_loss(v, ret)
        kl = (logp_old - logp).mean()
        logger.store(LossPi=loss_pi, LossV=v_l_old, KL=kl, Entropy=entropy,
                     DeltaLossPi=(pi_l_new - loss_pi),
                     DeltaLossV=(v_l_new - v_l_old),
                     LossDC=d_l_old,
                     DeltaLossDC=(dc_l_new - d_l_old))
        # logger.store(Adv=adv.reshape(-1).numpy().tolist(), Pos=pos.reshape(-1).numpy().tolist())

    start_time = time.time()
    o, r, d, ep_ret, ep_cost, ep_len = env.reset(), 0, False, 0, 0, 0
    context_dist = Categorical(logits=torch.Tensor(np.ones(con_dim)))
    print("context distribution:", context_dist)
    total_t = 0

    # Initialize penalty
    cur_penalty = np.log(max(np.exp(penalty_init) - 1, 1e-8))

    for epoch in range(epochs):
        ac.eval()
        discrim.eval()
        for _ in range(local_episodes_per_epoch):
            c = context_dist.sample()
            print("context sample: ", c)
            c_onehot = F.one_hot(c, con_dim).squeeze().float()
            # print("one hot sample: ", c_onehot)

            for _ in range(max_ep_len):
                concat_obs = torch.cat(
                    [torch.Tensor(o.reshape(1, -1)), c_onehot.reshape(1, -1)], 1)
                a, _, logp_t, v_t = ac(concat_obs)

                o, r, d, info = env.step(a.detach().numpy()[0])
                # print("info", info)
                # time.sleep(0.002)
                cost = info.get("cost", 0)  # per-step cost, used for logging and the penalty update
                ep_cost += cost

                ep_ret += r
                ep_len += 1
                total_t += 1

                # Reward shaping with the cost penalty is currently disabled;
                # the buffer stores the raw environment reward.
                # r_total = r - cur_penalty * cost
                # r_total /= (1 + cur_penalty)
                # buffer.store(c, concat_obs.squeeze().detach().numpy(), a.detach().numpy(), r_total,
                #              v_t.item(), logp_t.detach().numpy())

                buffer.store(c, concat_obs.squeeze().detach().numpy(),
                             a.detach().numpy(), r, v_t.item(),
                             logp_t.detach().numpy())
                logger.store(VVals=v_t)

                terminal = d or (ep_len == max_ep_len)
                if terminal:
                    dc_diff = torch.Tensor(buffer.calc_diff()).unsqueeze(0)
                    con = torch.Tensor([float(c)]).unsqueeze(0)
                    _, _, log_p = discrim(dc_diff, con)
                    buffer.finish_path(log_p.detach().numpy())
                    logger.store(EpRet=ep_ret, EpCost=ep_cost, EpLen=ep_len)
                    o, r, d, ep_ret, ep_cost, ep_len = env.reset(), 0, False, 0, 0, 0

        if (epoch % save_freq == 0) or (epoch == epochs - 1):
            logger.save_state({'env': env}, [ac, discrim], None)

        # Update
        ac.train()
        discrim.train()

        # update penalty
        cur_penalty = penalty_update(cur_penalty)

        # update models
        update(epoch)

        # Log
        logger.log_tabular('Epoch', epoch)
        logger.log_tabular('EpRet', with_min_and_max=True)
        logger.log_tabular('EpCost', with_min_and_max=True)
        logger.log_tabular('EpLen', average_only=True)
        logger.log_tabular('VVals', with_min_and_max=True)
        logger.log_tabular('TotalEnvInteracts', total_t)
        logger.log_tabular('LossPi', average_only=True)
        logger.log_tabular('DeltaLossPi', average_only=True)
        logger.log_tabular('LossV', average_only=True)
        logger.log_tabular('DeltaLossV', average_only=True)
        logger.log_tabular('LossDC', average_only=True)
        logger.log_tabular('DeltaLossDC', average_only=True)
        logger.log_tabular('Entropy', average_only=True)
        logger.log_tabular('KL', average_only=True)
        logger.log_tabular('Time', time.time() - start_time)
        logger.dump_tabular()
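# Illustrative sketch (hypothetical stand-in, not the ValorDiscriminator class
# used above): the discriminator objective in valor() / valor_penalized() is the
# negative log-likelihood of recovering the sampled context from a per-episode
# summary of state differences, here stood in by a linear classifier over a vector.
def _valor_discriminator_loss_demo(obs_dim=8, con_dim=5, batch=16):
    classifier = torch.nn.Linear(obs_dim, con_dim)  # toy stand-in for discrim.pi
    s_diff = torch.randn(batch, obs_dim)            # per-episode state differences
    con = torch.randint(0, con_dim, (batch,))       # contexts that generated them
    logp_dc = F.log_softmax(classifier(s_diff), dim=-1).gather(1, con.unsqueeze(1))
    return -logp_dc.mean()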