示例#1
0
文件: test_es.py 项目: rlan/ray
    def test_es_compilation(self):
        """Test whether an ESTrainer can be built on all frameworks.

        For every framework yielded by ``framework_iterator`` and for both
        "CartPole-v0" and "Pendulum-v0", build an ESTrainer, run
        ``num_iterations`` training iterations (with evaluation workers
        enabled, so the "normal" Trainer evaluation WorkerSet is exercised)
        and check ``compute_single_action``. Each trainer is stopped and Ray
        is shut down even if a step fails.
        """
        import copy

        ray.init(num_cpus=4)
        try:
            # Deep copy: ``dict.copy()`` is shallow, so writing into
            # config["model"] below would otherwise mutate the "model" dict
            # owned by es.DEFAULT_CONFIG and leak into other tests.
            config = copy.deepcopy(es.DEFAULT_CONFIG)
            # Keep it simple.
            config["model"]["fcnet_hiddens"] = [10]
            config["model"]["fcnet_activation"] = None
            config["noise_size"] = 2500000
            config["num_workers"] = 1
            config["episodes_per_batch"] = 10
            config["train_batch_size"] = 100
            # Test eval workers ("normal" Trainer eval WorkerSet).
            config["evaluation_interval"] = 1
            config["evaluation_num_workers"] = 2

            num_iterations = 1

            for _ in framework_iterator(config):
                for env in ["CartPole-v0", "Pendulum-v0"]:
                    # Give each trainer its own snapshot of the current
                    # per-framework config.
                    plain_config = copy.deepcopy(config)
                    trainer = es.ESTrainer(config=plain_config, env=env)
                    try:
                        for _ in range(num_iterations):
                            results = trainer.train()
                            print(results)

                        check_compute_single_action(trainer)
                    finally:
                        # Release the trainer's workers even on failure.
                        trainer.stop()
        finally:
            # Always shut Ray down so a failure here does not leave a
            # running Ray session behind for subsequent tests.
            ray.shutdown()
示例#2
0
文件: test_es.py 项目: yuishihara/ray
    def test_es_compilation(self):
        """Test whether an ESTrainer can be built on all frameworks.

        For the "torch" and "tf" frameworks, build an ESTrainer on
        "CartPole-v0" and run ``num_iterations`` training iterations. Each
        trainer is stopped and Ray is shut down even if a step fails, so
        no workers or Ray session leak into later tests.
        """
        import copy

        ray.init()
        try:
            # Deep copy: ``dict.copy()`` is shallow, so writing into
            # config["model"] below would otherwise mutate the "model" dict
            # owned by es.DEFAULT_CONFIG and leak into other tests.
            config = copy.deepcopy(es.DEFAULT_CONFIG)
            # Keep it simple.
            config["model"]["fcnet_hiddens"] = [10]
            config["model"]["fcnet_activation"] = None

            num_iterations = 2

            for _ in framework_iterator(config, ("torch", "tf")):
                # Give each trainer its own snapshot of the current
                # per-framework config.
                plain_config = copy.deepcopy(config)
                trainer = es.ESTrainer(config=plain_config, env="CartPole-v0")
                try:
                    for _ in range(num_iterations):
                        results = trainer.train()
                        print(results)
                finally:
                    # Release the trainer's workers even on failure.
                    trainer.stop()
        finally:
            ray.shutdown()
示例#3
0
文件: test_es.py 项目: zhongchun/ray
    def test_es_compilation(self):
        """Test whether an ESTrainer can be built on all frameworks.

        For every framework yielded by ``framework_iterator``, build an
        ESTrainer on "CartPole-v0", run ``num_iterations`` training
        iterations and check ``compute_single_action``. Each trainer is
        stopped and Ray is shut down even if a step fails.
        """
        import copy

        ray.init()
        try:
            # Deep copy: ``dict.copy()`` is shallow, so writing into
            # config["model"] below would otherwise mutate the "model" dict
            # owned by es.DEFAULT_CONFIG and leak into other tests.
            config = copy.deepcopy(es.DEFAULT_CONFIG)
            # Keep it simple.
            config["model"]["fcnet_hiddens"] = [10]
            config["model"]["fcnet_activation"] = None
            config["noise_size"] = 2500000

            num_iterations = 2

            for _ in framework_iterator(config):
                # Give each trainer its own snapshot of the current
                # per-framework config.
                plain_config = copy.deepcopy(config)
                trainer = es.ESTrainer(config=plain_config, env="CartPole-v0")
                try:
                    for _ in range(num_iterations):
                        results = trainer.train()
                        print(results)

                    check_compute_single_action(trainer)
                finally:
                    # Release the trainer's workers even on failure.
                    trainer.stop()
        finally:
            # Always shut Ray down so a failure here does not leave a
            # running Ray session behind for subsequent tests.
            ray.shutdown()
示例#4
0
def render(checkpoint, home_path, num_steps=10000, step_delay=0.01):
    """Restore a trained RLlib agent from a checkpoint and render it.

    Intended for pybullet and mujoco environments. The algorithm name and
    the gym environment id are parsed from the last component of
    ``home_path``, which must look like ``<ALG>_<ENV>_<...>``.

    Args:
        checkpoint: Checkpoint number; the agent is restored from
            ``<home_path>/checkpoint_<n>/checkpoint-<n>``.
        home_path: Run directory containing ``params.json``, ``params.pkl``
            and the checkpoint sub-directories. A trailing path separator
            is optional.
        num_steps: Number of environment steps to roll out and render
            (default 10000). Also used as the env's episode step limit.
        step_delay: Seconds to sleep after each rendered frame
            (default 0.01).

    Returns:
        Tuple ``(action_hist, obs_hist, reward_hist)`` holding per-step
        copies of the action taken, the resulting observation and the
        reward.

    Raises:
        ValueError: If the algorithm / env id cannot be parsed from
            ``home_path`` or the algorithm is not supported.
    """
    run_name = os.path.basename(os.path.normpath(home_path))
    # Algorithm = text before the first "_"; env id = text between the
    # first and second "_".
    alg_match = re.match(r".+?(?=_)", run_name)
    env_match = re.search(r"(?<=_).*?(?=_)", run_name)
    if alg_match is None or env_match is None:
        raise ValueError(
            f"Cannot parse '<ALG>_<ENV>_...' from run directory {run_name!r}"
        )
    alg = alg_match.group(0)
    current_env = env_match.group(0)

    checkpoint_path = os.path.join(
        home_path, f"checkpoint_{checkpoint}", f"checkpoint-{checkpoint}"
    )
    with open(os.path.join(home_path, "params.json"), encoding="utf-8") as f:
        config = json.load(f)
    # NOTE(review): pickle.load can execute arbitrary code -- only render
    # run directories from trusted sources.
    with open(os.path.join(home_path, "params.pkl"), "rb") as f:
        config_bin = pickle.load(f)

    # Lambdas defer touching each algorithm module until it is selected.
    # PPO is built from the pickled params, all others from the JSON params.
    trainer_factories = {
        "PPO": lambda: ppo.PPOTrainer(config_bin),
        "SAC": lambda: sac.SACTrainer(config),
        "DDPG": lambda: ddpg.DDPGTrainer(config),
        "PG": lambda: pg.PGTrainer(config),
        "A3C": lambda: a3c.A3CTrainer(config),
        "TD3": lambda: td3.TD3Trainer(config),
        "ES": lambda: es.ESTrainer(config),
        "ARS": lambda: ars.ARSTrainer(config),
    }
    if alg not in trainer_factories:
        raise ValueError(
            f"Unsupported algorithm {alg!r}; expected one of "
            f"{sorted(trainer_factories)}"
        )

    ray.shutdown()
    # Imported only for its side effects (registering the pybullet gym
    # environments, e.g. the "*Bullet*" ids handled below).
    import pybullet_envs  # noqa: F401
    ray.init()
    ModelCatalog.register_custom_model("RBF", RBFModel)
    ModelCatalog.register_custom_model("MLP_2_64", MLP)
    ModelCatalog.register_custom_model("linear", Linear)

    trainer = trainer_factories[alg]()
    trainer.restore(checkpoint_path)

    if "Bullet" in current_env:
        # pybullet envs take a `render` kwarg at construction time.
        env = gym.make(current_env, render=True)
    else:
        env = gym.make(current_env)
    # Raise the gym TimeLimit (private attribute) so episodes are not
    # truncated by the env's default limit during the rollout.
    env._max_episode_steps = num_steps

    action_hist = []
    obs_hist = []
    reward_hist = []
    try:
        obs = env.reset()
        for _step in range(num_steps):
            # compute_action returns only the sampled action. Passing
            # full_fetch=True also returns the extra fetches, whose
            # 'behaviour_logits' are the mean and log-std (not a log mean,
            # despite the name).
            actions = trainer.compute_action(obs.flatten())
            obs, reward, done, _info = env.step(np.asarray(actions))
            env.render(mode="human")
            time.sleep(step_delay)

            action_hist.append(np.copy(actions))
            obs_hist.append(np.copy(obs))
            reward_hist.append(np.copy(reward))
            if done:
                obs = env.reset()
    finally:
        # Close the render window / simulator and release RLlib workers.
        env.close()
        trainer.stop()

    return action_hist, obs_hist, reward_hist