import os

import numpy as np

# NOTE: the module paths below are assumptions about this repository's layout;
# adjust them to wherever HumanoidEnv, Policy, Value, Discriminator,
# GeneratorAgentEgoPure, Logger and GracefulKiller actually live.
from humanoid_env import HumanoidEnv
from policy import Policy
from value_function import Value
from discriminator import Discriminator
from generator_agent import GeneratorAgentEgoPure
from utils import Logger, GracefulKiller


def main():
	# hyperparameters
	max_iteration = 3000        # number of training iterations
	episodes_per_batch = 20     # episodes per batch (currently unused in this script)
	max_kl = 0.01               # KL-divergence constraint for the policy update
	init_logvar = -1            # initial log-variance of the policy's action distribution
	policy_epochs = 5           # epochs per policy update
	value_epochs = 10           # epochs per value-function update
	value_batch_size = 256      # minibatch size for value-function fitting
	gamma = 0.995               # discount factor
	lam = 0.97                  # lambda for GAE-style advantage/return estimation

	exp_info = 'humanoid_ego_pure'
	# initialize environment
	env = HumanoidEnv()
	env.seed(0)
	
	obs_dim = env.observation_space.shape[0]
	ego_dim = env.ego_pure_shape()

	print('obs_dim: ', obs_dim)
	print('ego_dim: ', ego_dim)
	act_dim = env.action_space.shape[0]

	logger = Logger()
	killer = GracefulKiller()
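	# output directories for checkpoints/logs; created here so the saves below do not
	# fail if they are missing (directory names follow the save paths used later)
	os.makedirs('./model_inter_ego_pure', exist_ok=True)
	os.makedirs('./model_ego_pure', exist_ok=True)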

	# init qpos and qvel
	init_qpos = np.load('./mocap_expert_qpos.npy')
	init_qvel = np.load('./mocap_expert_qvel.npy')
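	# expert egocentric ('pure ego') observations, used below as the discriminator's expert data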
	exp_obs = np.load('./mocap_pure_ego.npy')
	print('exp_obs shape: ', exp_obs.shape)

	# policy function
	policy = Policy(obs_dim=obs_dim, act_dim=act_dim, max_kl=max_kl,
					init_logvar=init_logvar, epochs=policy_epochs, 
					logger=logger)

	# value function
	value = Value(obs_dim=obs_dim, act_dim=act_dim, epochs=value_epochs, 
				  batch_size=value_batch_size, logger=logger)

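	# discriminator: classifies egocentric observations as expert vs. generated (GAIL-style)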
	discriminator = Discriminator(obs_dim=ego_dim, act_dim=act_dim, ent_reg_weight=1e-3,
								  epochs=2, input_type='states', loss_type='pure_gail',
								  logger=logger)
	# agent
	agent = GeneratorAgentEgoPure(env=env, policy_function=policy, value_function=value,
								  discriminator=discriminator, gamma=gamma, lam=lam,
								  init_qpos=init_qpos, init_qvel=init_qvel, logger=logger)

	print('policy lr: %f' %policy.lr)
	print('value lr: %f' %value.lr)
	print('disc lr: %f' %discriminator.lr)
	# train for max_iteration iterations
	iteration = 0
	while iteration < max_iteration:
		print('-------- iteration %d --------' %iteration)
		# collect trajectories
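		# obs/acts/advs feed the policy update, tdlams are the value-function targets;
		# uns_obs are presumably the unscaled egocentric observations in the same space
		# as exp_obs, since they are fed to the discriminator below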
		obs, uns_obs, acts, tdlams, advs = agent.collect(timesteps=20000)
		
		# update policy function using ppo
		policy.update(obs, acts, advs)

		# update value function
		value.update(obs, tdlams)

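		# sample a random expert minibatch of the same size as the generated batch,
		# then update the discriminator on expert vs. generated observations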
		idx = np.random.randint(low=0, high=exp_obs.shape[0], size=uns_obs.shape[0])
		expert = exp_obs[idx, :]
		gen_acc, exp_acc, total_acc = discriminator.update(exp_obs=expert, gen_obs=uns_obs)
		print('gen_acc: %f, exp_acc: %f, total_acc: %f' %(gen_acc, exp_acc, total_acc))
		
		if iteration % 50 == 0:
			print('saving...')
			# save the experiment logs
			filename = './model_inter_ego_pure/stats_' + exp_info + '_' + str(iteration)
			logger.dump(filename)

			# save session
			filename = './model_inter_ego_pure/model_' + exp_info + '_' + str(iteration)
			policy.save_session(filename)

		if killer.kill_now:
			break
		# update iteration count
		iteration += 1
		
	# save the experiment logs
	filename = './model_ego_pure/stats_' + exp_info
	logger.dump(filename)

	# save session
	filename = './model_ego_pure/model_' + exp_info
	policy.save_session(filename)

	# close everything
	policy.close_session()
	value.close_session()
	env.close()
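

if __name__ == '__main__':
	main()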