def main():
    """Train a GAIL-style generator agent on the humanoid ego-pure task.

    Builds the environment, policy, value function, and discriminator,
    then runs up to ``max_iteration`` rounds of: trajectory collection,
    PPO policy update, value-function fit, and discriminator update
    against mocap expert observations. Checkpoints logs and the policy
    session every 50 iterations and saves final artifacts on exit.

    Side effects: reads ``./mocap_*.npy`` fixture files; writes logs and
    model checkpoints under ``./model_inter_ego_pure/`` and
    ``./model_ego_pure/``.
    """
    max_iteration = 3000
    max_kl = 0.01                # KL constraint for the policy update
    init_logvar = -1             # initial log-variance of the stochastic policy
    policy_epochs = 5
    value_epochs = 10
    value_batch_size = 256
    gamma = 0.995                # discount factor
    lam = .97                    # lambda for advantage estimation (presumably GAE)
    exp_info = 'humanoid_ego_pure'

    # initialize environment
    env = HumanoidEnv()
    env.seed(0)
    obs_dim = env.observation_space.shape[0]
    ego_dim = env.ego_pure_shape()
    print('obs_dim: ', obs_dim)
    print('ego_dim: ', ego_dim)
    act_dim = env.action_space.shape[0]
    logger = Logger()
    killer = GracefulKiller()

    # initial qpos/qvel states and expert observations from motion capture
    init_qpos = np.load('./mocap_expert_qpos.npy')
    init_qvel = np.load('./mocap_expert_qvel.npy')
    exp_obs = np.load('./mocap_pure_ego.npy')
    print('exp_obs shape: ', exp_obs.shape)

    # policy function
    policy = Policy(obs_dim=obs_dim, act_dim=act_dim, max_kl=max_kl,
                    init_logvar=init_logvar, epochs=policy_epochs,
                    logger=logger)
    # value function
    value = Value(obs_dim=obs_dim, act_dim=act_dim, epochs=value_epochs,
                  batch_size=value_batch_size, logger=logger)
    # discriminator over ego observations only ('states' input, pure-GAIL loss)
    discriminator = Discriminator(obs_dim=ego_dim, act_dim=act_dim,
                                  ent_reg_weight=1e-3, epochs=2,
                                  input_type='states', loss_type='pure_gail',
                                  logger=logger)
    # agent
    agent = GeneratorAgentEgoPure(env=env, policy_function=policy,
                                  value_function=value,
                                  discriminator=discriminator,
                                  gamma=gamma, lam=lam,
                                  init_qpos=init_qpos, init_qvel=init_qvel,
                                  logger=logger)
    print('policy lr: %f' % policy.lr)
    print('value lr %f' % value.lr)
    print('disc lr: %f' % discriminator.lr)

    # train for max_iteration rounds
    iteration = 0
    while iteration < max_iteration:
        print('-------- iteration %d --------' % iteration)
        # collect trajectories
        obs, uns_obs, acts, tdlams, advs = agent.collect(timesteps=20000)
        # update policy function using ppo
        policy.update(obs, acts, advs)
        # update value function
        value.update(obs, tdlams)
        # sample an expert batch the same size as the generated batch
        idx = np.random.randint(low=0, high=exp_obs.shape[0],
                                size=uns_obs.shape[0])
        expert = exp_obs[idx, :]
        gen_acc, exp_acc, total_acc = discriminator.update(exp_obs=expert,
                                                           gen_obs=uns_obs)
        print('gen_acc: %f, exp_acc: %f, total_acc: %f'
              % (gen_acc, exp_acc, total_acc))

        # periodic checkpoint of logs and the policy session
        if iteration % 50 == 0:
            print('saving...')
            # save the experiment logs
            filename = ('./model_inter_ego_pure/stats_' + exp_info
                        + '_' + str(iteration))
            logger.dump(filename)
            # save session
            filename = ('./model_inter_ego_pure/model_' + exp_info
                        + '_' + str(iteration))
            policy.save_session(filename)

        # allow a graceful Ctrl-C shutdown between iterations
        if killer.kill_now:
            break
        # update episode number
        iteration += 1

    # save the final experiment logs
    filename = './model_ego_pure/stats_' + exp_info
    logger.dump(filename)
    # save final session
    filename = './model_ego_pure/model_' + exp_info
    policy.save_session(filename)

    # close everything
    policy.close_session()
    value.close_session()
    env.close()