# Code example #1
def get_time_to_first_contact(env, policy, is_random=False, num_trajs=100):
    """Measure how many environment steps each rollout needs before first contact.

    Rolls out ``num_trajs`` trajectories with ``policy`` (or a fresh uniform
    random policy when ``is_random`` is set), records the step index at which
    ``env_info['contact_reward']`` first becomes positive (or the episode ends),
    saves the collected step counts with ``np.save`` to a user-supplied path,
    and prints summary statistics.

    :param env: environment exposing ``reset()`` / ``step(action)`` whose
        ``env_info`` dict carries a ``'contact_reward'`` entry.
    :param policy: policy exposing ``get_action(obs)``; replaced when
        ``is_random`` is True.
    :param is_random: if True, ignore ``policy`` and sample uniformly at random.
    :param num_trajs: number of trajectories to roll out.
    """
    import itertools
    if is_random:
        from rllab.policies.uniform_control_policy import UniformControlPolicy
        policy = UniformControlPolicy(env.spec)
    print("Using {}".format(policy))

    time_contact = []
    for traj_i in range(num_trajs):
        obs = env.reset()
        print("Start traj {}".format(traj_i))
        # itertools.count gives an unbounded step index; we break on contact/done.
        for step in itertools.count():
            act, _ = policy.get_action(obs)
            obs, reward, done, env_info = env.step(act)
            made_contact = env_info['contact_reward'] > 0
            if made_contact or done:
                time_contact.append(step)
                break

    # Ask interactively where to persist the raw step counts.
    data_path = input("Where do you want to save it? \n")
    np.save(data_path, time_contact)
    print("Data saved")
    print(
        "Mean time to first contact: {}, median:{}, std:{} for {}, ({} trajectories)"
        .format(np.mean(time_contact), np.median(time_contact),
                np.std(time_contact), policy, num_trajs))
# Code example #2
def episode_reward(env, policy, is_random=False, num_trajs=100):
	"""Roll out ``num_trajs`` episodes and report the mean total episode reward.

	Fixes over the previous draft: ``num_trajs`` was referenced but never
	defined (now a parameter with a default), per-step rewards were never
	accumulated, a bare ``rewards`` expression raised NameError, ``plt.his``
	was an incomplete call, and the final print referenced the undefined
	``time_contact`` with a message copied from another function.

	:param env: environment exposing ``reset()`` and ``step(action)`` that
		returns ``(obs, reward, done, env_info)``.
	:param policy: policy exposing ``get_action(obs)``; replaced when
		``is_random`` is True.
	:param is_random: if True, ignore ``policy`` and sample uniformly at random.
	:param num_trajs: number of episodes to roll out.
	:return: mean total reward per episode (``np.mean`` over the episodes).
	"""
	import itertools
	if is_random:
		from rllab.policies.uniform_control_policy import UniformControlPolicy
		policy = UniformControlPolicy(env.spec)
	print ("Using {}".format(policy))

	episode_rewards = []
	for traj_i in range(num_trajs):
		obs = env.reset()
		print ("Start traj {}".format(traj_i))
		total_reward = 0.0
		# itertools.count gives an unbounded step index; we break on done.
		for t in itertools.count():
			action, _ = policy.get_action(obs)
			obs, reward, done, env_info = env.step(action)
			total_reward += reward
			if done:
				break
		episode_rewards.append(total_reward)

	mean_reward = np.mean(episode_rewards)
	print ("Mean episode reward: {} for {}, ({} trajectories)".format(mean_reward, policy, num_trajs))
	return mean_reward
# Code example #3
def test_state_hist(env, save_path='/Users/dianchen/state.npy'):
	"""Collect state samples from a uniform-random rollout until interrupted.

	Runs a uniform random policy forever, restarting episodes on termination,
	and appends every observation to a buffer. On Ctrl-C (KeyboardInterrupt)
	the buffer is converted to an array and written to ``save_path``.

	The previously hard-coded, machine-specific save path is now a parameter
	(default preserved for backward compatibility).

	:param env: environment exposing ``spec``, ``reset()`` and ``step(action)``.
	:param save_path: destination for ``np.save`` of the sampled states.
	"""
	policy = UniformControlPolicy(env.spec)
	_states = []
	o = env.reset()
	try:
		while True:
			_states.append(o)
			a, _ = policy.get_action(o)
			next_o, r, d, env_info = env.step(a)
			# Restart the episode when it terminates; otherwise continue rolling.
			if d:
				o = env.reset()
			else:
				o = next_o
	except KeyboardInterrupt:
		# Ctrl-C is the intended way to stop sampling; persist what we have.
		states = np.asarray(_states)
		np.save(save_path, states)
		print ("State samples saved to {}".format(save_path))