Example #1
    def test_evaluation_wo_evaluation_worker_set(self):
        config = a3c.DEFAULT_CONFIG.copy()
        config.update({
            "env": "CartPole-v0",
            # Switch off evaluation (this should already be the default).
            "evaluation_interval": None,
        })
        for _ in framework_iterator(frameworks=("tf", "torch")):
            # Setup trainer w/o evaluation worker set and still call
            # evaluate() -> Expect error.
            trainer_wo_env_on_driver = a3c.A3CTrainer(config=config)
            self.assertRaisesRegexp(
                ValueError, "Cannot evaluate w/o an evaluation worker set",
                trainer_wo_env_on_driver.evaluate)
            trainer_wo_env_on_driver.stop()

            # Try again using `create_env_on_driver=True`.
            # This force-adds the env on the local-worker, so this Trainer
            # can `evaluate` even though it doesn't have an evaluation-worker
            # set.
            config["create_env_on_driver"] = True
            trainer_w_env_on_driver = a3c.A3CTrainer(config=config)
            results = trainer_w_env_on_driver.evaluate()
            assert "evaluation" in results
            assert "episode_reward_mean" in results["evaluation"]
            trainer_w_env_on_driver.stop()
            config["create_env_on_driver"] = False
Example #2
 def check_learned(self):
     """
     check the learned agent
     """
     ray.init(local_mode=True)
     if self.algorithm == 'PPO':
         agent = ppo.PPOTrainer(config=self.ray_config,
                                env=self.env.__class__)
     elif self.algorithm == 'A3C':
         agent = a3c.A3CTrainer(config=self.ray_config,
                                env=self.env.__class__)
     elif self.algorithm == 'PG':
         agent = pg.PGTrainer(config=self.ray_config,
                              env=self.env.__class__)
     agent.restore(self.checkpoint_path)
     # run until episode ends
     episode_reward = 0
     done = False
     obs = self.env.reset()
     while not done:
         self.env.render()
         action = agent.compute_action(obs)
         obs, reward, done, info = self.env.step(action)
         # print(f"obs:\n{obs}")
         print(f"reward:\n{reward}")
         print(f"info:\n{info}")
         episode_reward += reward
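     # Hedged addition, not part of the original snippet: once the episode
     # finishes, release the environment and shut down the local Ray instance
     # started by ray.init(local_mode=True) above.
     self.env.close()
     ray.shutdown()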
Example #3
    def test(self, algo, path, lr, fc_hid, fc_act):
        """Test trained agent for a single episode. Return the episode reward"""
        # instantiate env class
        unused_shared = []
        unused_own = []
        unsatisfied_shared = []
        unsatisfied_own = []

        episode_reward = 0

        #self.config["num_workers"] = 0
        self.config["lr"] = lr
        self.config['model']["fcnet_hiddens"] = fc_hid
        self.config['model']["fcnet_activation"] = fc_act

        if algo == "ppo":
            self.agent = ppo.PPOTrainer(config=self.config)
        if algo == "ddpg":
            self.agent = ddpg.DDPGTrainer(config=self.config)
        if algo == "a3c":
            self.agent = a3c.A3CTrainer(config=self.config)
        if algo == "impala":
            self.agent = impala.ImpalaTrainer(config=self.config)
        if algo == "appo":
            self.agent = ppo.APPOTrainer(config=self.config)
        if algo == "td3":
            self.agent = ddpg.TD3Trainer(config=self.config)

        self.agent.restore(path)

        env = caching_vM(config=self.config)

        obs = env.reset()
        done = False

        action = {}
        for agent_id, agent_obs in obs.items():
            policy_id = self.config['multiagent']['policy_mapping_fn'](
                agent_id)
            action[agent_id] = self.agent.compute_action(agent_obs,
                                                         policy_id=policy_id)
        obs, reward, done, info = env.step(action)
        done = done['__all__']

        for x in range(len(info)):
            res = ast.literal_eval(info[x])
            unused_shared.append(res[0])
            unused_own.append(res[1])
            unsatisfied_shared.append(res[2])
            unsatisfied_own.append(res[3])

        print("reward == ", reward)
        # sum up reward for all agents
        episode_reward += sum(reward.values())

        return episode_reward, unused_shared, unused_own, unsatisfied_shared, unsatisfied_own
Example #4
    def test(self, algo, path, lr, fc_hid, fc_act):
        """Test trained agent for a single episode. Return the episode reward"""
        # instantiate env class
        unused_shared = []
        unused_own = []
        unsatisfied_shared = []
        unsatisfied_own = []

        episode_reward = 0
        self.config_test["num_workers"] = 0
        self.config_test["lr"] = lr
        self.config_test['model']["fcnet_hiddens"] = fc_hid
        self.config_test['model']["fcnet_activation"] = fc_act

        if algo == "ppo":
            self.agent = ppo.PPOTrainer(config=self.config_test)
        if algo == "ddpg":
            self.agent = ddpg.DDPGTrainer(config=self.config_test)
        if algo == "a3c":
            self.agent = a3c.A3CTrainer(config=self.config_test)
        if algo == "impala":
            self.agent = impala.ImpalaTrainer(config=self.config_test)
        if algo == "appo":
            self.agent = ppo.APPOTrainer(config=self.config_test)
        if algo == "td3":
            self.agent = ddpg.TD3Trainer(config=self.config_test)

        self.agent.restore(path)

        # The original snippet had several commented-out attempts at rebuilding
        # the environment here (local_worker().env, self.env_class(...), ContentCaching, ...).
        # Reset the env held on this object, which is what the loop below steps
        # (this assumes self.env is set up elsewhere in the class).
        obs = self.env.reset()
        done = False

        while not done:
            action = self.agent.compute_action(obs)
            obs, reward, done, info = self.env.step(action)
            episode_reward += reward

            unused_shared.append(info["unused_shared"])
            unused_own.append(info["unused_own"])
            unsatisfied_shared.append(info["unsatisfied_shared"])
            unsatisfied_own.append(info["unsatisfied_own"])
        
        return episode_reward, unused_shared, unused_own, unsatisfied_shared, unsatisfied_own
Example #5
File: test_a3c.py  Project: zcm-006/ray
    def test_a3c_compilation(self):
        """Test whether an A3CTrainer can be built with both frameworks."""
        config = a3c.DEFAULT_CONFIG.copy()
        config["num_workers"] = 2
        config["num_envs_per_worker"] = 2

        num_iterations = 1

        # Test against all frameworks.
        for fw in framework_iterator(config, ("tf", "torch")):
            config["sample_async"] = fw == "tf"
            for env in ["CartPole-v0", "Pendulum-v0", "PongDeterministic-v0"]:
                trainer = a3c.A3CTrainer(config=config, env=env)
                for i in range(num_iterations):
                    results = trainer.train()
                    print(results)
                check_compute_action(trainer)
Example #6
File: test_a3c.py  Project: tuyulers5/jav44
    def test_a3c_compilation(self):
        """Test whether an A3CTrainer can be built with both frameworks."""
        config = a3c.DEFAULT_CONFIG.copy()
        config["num_workers"] = 2
        config["num_envs_per_worker"] = 2

        num_iterations = 1

        # Test against all frameworks.
        for _ in framework_iterator(config):
            for env in ["CartPole-v0", "Pendulum-v0", "PongDeterministic-v0"]:
                print("env={}".format(env))
                trainer = a3c.A3CTrainer(config=config, env=env)
                for i in range(num_iterations):
                    results = trainer.train()
                    print(results)
                check_compute_single_action(trainer)
                trainer.stop()
Example #7
    def test_a3c_entropy_coeff_schedule(self):
        """Test A3CTrainer entropy coeff schedule support."""
        config = a3c.DEFAULT_CONFIG.copy()
        config["num_workers"] = 1
        config["num_envs_per_worker"] = 1
        config["train_batch_size"] = 20
        config["batch_mode"] = "truncate_episodes"
        config["rollout_fragment_length"] = 10
        config["timesteps_per_iteration"] = 20
        # 0 metrics reporting delay, this makes sure timestep,
        # which entropy coeff depends on, is updated after each worker rollout.
        config["min_time_s_per_reporting"] = 0
        # Initial entropy coeff; it doesn't really matter because of the schedule below.
        config["entropy_coeff"] = 0.01
        schedule = [
            [0, 0.01],
            [120, 0.0001],
        ]
        config["entropy_coeff_schedule"] = schedule

        def _step_n_times(trainer, n: int):
            """Step trainer n times.

            Returns:
                Entropy coefficient at the end of the execution.
            """
            for _ in range(n):
                results = trainer.train()
            return results["info"][LEARNER_INFO][DEFAULT_POLICY_ID][
                LEARNER_STATS_KEY]["entropy_coeff"]

        # Test against all frameworks.
        for _ in framework_iterator(config):
            trainer = a3c.A3CTrainer(config=config, env="CartPole-v1")

            coeff = _step_n_times(trainer, 1)  # 20 timesteps
            # Should be close to the starting coeff of 0.01
            self.assertGreaterEqual(coeff, 0.005)

            coeff = _step_n_times(trainer, 10)  # 200 timesteps
            # Should have annealed to the final coeff of 0.0001.
            self.assertLessEqual(coeff, 0.00011)

            trainer.stop()
Example #8
def get_rl_agent(agent_name, config, env_to_agent):
    if agent_name == A2C:
        import ray.rllib.agents.a3c as a2c
        agent = a2c.A2CTrainer(config=config, env=env_to_agent)
    elif agent_name == A3C:
        import ray.rllib.agents.a3c as a3c
        agent = a3c.A3CTrainer(config=config, env=env_to_agent)
    elif agent_name == BC:
        import ray.rllib.agents.marwil as bc
        agent = bc.BCTrainer(config=config, env=env_to_agent)
    elif agent_name == DQN:
        import ray.rllib.agents.dqn as dqn
        agent = dqn.DQNTrainer(config=config, env=env_to_agent)
    elif agent_name == APEX_DQN:
        import ray.rllib.agents.dqn as dqn
        agent = dqn.ApexTrainer(config=config, env=env_to_agent)
    elif agent_name == IMPALA:
        import ray.rllib.agents.impala as impala
        agent = impala.ImpalaTrainer(config=config, env=env_to_agent)
    elif agent_name == MARWIL:
        import ray.rllib.agents.marwil as marwil
        agent = marwil.MARWILTrainer(config=config, env=env_to_agent)
    elif agent_name == PG:
        import ray.rllib.agents.pg as pg
        agent = pg.PGTrainer(config=config, env=env_to_agent)
    elif agent_name == PPO:
        import ray.rllib.agents.ppo as ppo
        agent = ppo.PPOTrainer(config=config, env=env_to_agent)
    elif agent_name == APPO:
        import ray.rllib.agents.ppo as ppo
        agent = ppo.APPOTrainer(config=config, env=env_to_agent)
    elif agent_name == SAC:
        import ray.rllib.agents.sac as sac
        agent = sac.SACTrainer(config=config, env=env_to_agent)
    elif agent_name == LIN_UCB:
        import ray.rllib.contrib.bandits.agents.lin_ucb as lin_ucb
        agent = lin_ucb.LinUCBTrainer(config=config, env=env_to_agent)
    elif agent_name == LIN_TS:
        import ray.rllib.contrib.bandits.agents.lin_ts as lin_ts
        agent = lin_ts.LinTSTrainer(config=config, env=env_to_agent)
    else:
        raise ValueError(f"Invalid agent name: {agent_name}")
    return agent
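A minimal usage sketch for the dispatch helper above. The agent-name constants (A3C, PPO, ...) are assumed to be defined in the surrounding module, and the config/env pair below is purely illustrative:

import ray

ray.init()
# A3C and "CartPole-v0" stand in for whatever the caller actually passes.
agent = get_rl_agent(A3C, {"num_workers": 1}, "CartPole-v0")
for _ in range(3):
    print(agent.train()["episode_reward_mean"])
agent.stop()
ray.shutdown()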
Example #9
def get_rllib_agent(agent_name, env_name, env, env_to_agent):
    config = get_config(env_name, env, 1) if is_rllib_agent(agent_name) else {}
    if agent_name == RLLIB_A2C:
        import ray.rllib.agents.a3c as a2c
        agent = a2c.A2CTrainer(config=config, env=env_to_agent)
    elif agent_name == RLLIB_A3C:
        import ray.rllib.agents.a3c as a3c
        agent = a3c.A3CTrainer(config=config, env=env_to_agent)
    elif agent_name == RLLIB_BC:
        import ray.rllib.agents.marwil as bc
        agent = bc.BCTrainer(config=config, env=env_to_agent)
    elif agent_name == RLLIB_DQN:
        import ray.rllib.agents.dqn as dqn
        agent = dqn.DQNTrainer(config=config, env=env_to_agent)
    elif agent_name == RLLIB_APEX_DQN:
        import ray.rllib.agents.dqn as dqn
        agent = dqn.ApexTrainer(config=config, env=env_to_agent)
    elif agent_name == RLLIB_IMPALA:
        import ray.rllib.agents.impala as impala
        agent = impala.ImpalaTrainer(config=config, env=env_to_agent)
    elif agent_name == RLLIB_MARWIL:
        import ray.rllib.agents.marwil as marwil
        agent = marwil.MARWILTrainer(config=config, env=env_to_agent)
    elif agent_name == RLLIB_PG:
        import ray.rllib.agents.pg as pg
        agent = pg.PGTrainer(config=config, env=env_to_agent)
    elif agent_name == RLLIB_PPO:
        import ray.rllib.agents.ppo as ppo
        agent = ppo.PPOTrainer(config=config, env=env_to_agent)
    elif agent_name == RLLIB_APPO:
        import ray.rllib.agents.ppo as ppo
        agent = ppo.APPOTrainer(config=config, env=env_to_agent)
    elif agent_name == RLLIB_SAC:
        import ray.rllib.agents.sac as sac
        agent = sac.SACTrainer(config=config, env=env_to_agent)
    elif agent_name == RLLIB_LIN_UCB:
        import ray.rllib.contrib.bandits.agents.lin_ucb as lin_ucb
        agent = lin_ucb.LinUCBTrainer(config=config, env=env_to_agent)
    elif agent_name == RLLIB_LIN_TS:
        import ray.rllib.contrib.bandits.agents.lin_ts as lin_ts
        agent = lin_ts.LinTSTrainer(config=config, env=env_to_agent)
    else:
        # Fail loudly on unknown names, mirroring get_rl_agent above.
        raise ValueError(f"Invalid agent name: {agent_name}")
    return agent
Example #10
    def test_a3c_compilation(self):
        """Test whether an A3CTrainer can be built with both frameworks."""
        config = a3c.DEFAULT_CONFIG.copy()
        config["num_workers"] = 2
        config["num_envs_per_worker"] = 2

        num_iterations = 2

        # Test against all frameworks.
        for _ in framework_iterator(config, with_eager_tracing=True):
            for env in ["CartPole-v1", "Pendulum-v1", "PongDeterministic-v0"]:
                print("env={}".format(env))
                config["model"]["use_lstm"] = env == "CartPole-v1"
                trainer = a3c.A3CTrainer(config=config, env=env)
                for i in range(num_iterations):
                    results = trainer.train()
                    check_train_results(results)
                    print(results)
                check_compute_single_action(
                    trainer, include_state=config["model"]["use_lstm"])
                trainer.stop()
Example #11
def main():
    ray.init()
    '''config = ppo.DEFAULT_CONFIG.copy()
    config["num_gpus"] = 0
    config["num_workers"] = 1
    trainer = ppo.PPOTrainer(config=config, env=SingleplayerGym)'''

    config = a3c.DEFAULT_CONFIG.copy()
    config["num_gpus"] = 2
    config["num_workers"] = 6
    trainer = a3c.A3CTrainer(config=config, env=SingleplayerGym)

    # Can optionally call trainer.restore(path) to load a checkpoint.

    for i in range(1000):
        # Perform one iteration of training the policy with A3C
        result = trainer.train()
        print(pretty_print(result))

        if i % 100 == 0:
            checkpoint = trainer.save()
            print("checkpoint saved at", checkpoint)
Example #12
    'reward_scale': 0.01,
    'max_steps': 1000,
    'is_continuous': True,
    'catch_distance': 0.1
}

ray.init(include_dashboard=False)
ModelCatalog.register_custom_model("CartpoleModel", PredatorVictimModel)
PVEnv = gym.make("PredatorVictim-v0", params=params)
register_env("PredatorVictimEnv", lambda _: PVEnv)

trainer = a3c.A3CTrainer(env="PredatorVictimEnv",
                         config={
                             "multiagent": {
                                 "policies": {
                                     "policy_predator": gen_policy(PVEnv, 0),
                                     "policy_victim": gen_policy(PVEnv, 1)
                                 },
                                 "policy_mapping_fn": policy_mapping_fn,
                             },
                         })

if os.path.isfile(model_file):
    weights = pickle.load(open(model_file, "rb"))
    trainer.restore_from_object(weights)

keyboard.on_press_key("q", press_key_exit)
while True:
    if ready_to_exit:
        break
    rest = trainer.train()
    print(rest['policy_reward_mean'])
Example #13
        agent_delta_config = delta_config['agent']
        for key, value in agent_delta_config.items():
            print('Agent config: ', key, ' --> ', value)
            agent_config[key] = value

        # Load parameters that control the training regime
        training_config = delta_config['training']
        evaluation_steps = training_config['evaluation_steps']
        checkpoint_path = training_config['checkpoint_path']

# Register the custom items
ModelCatalog.register_custom_model(model_name, Agent)

print('Agent config:\n', agent_config)
# agent_config['gamma'] = 0.0
agent = a3c.A3CTrainer(agent_config,
                       env=meta_env_type)  # Note use of custom Env creator fn

agent.restore(checkpoint_path)

# Use this line uncommented to see the whole config and all options
# print('\n\n\nPOLICY CONFIG',agent.get_policy().config,"\n\n\n")

# Evaluate the model


def find_json_value(key_path, json, delimiter='.'):
    paths = key_path.split(delimiter)
    data = json
    for i in range(0, len(paths)):
        data = data[paths[i]]
    return data
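A quick sketch of how find_json_value walks a nested dict; the keys and values here are invented purely for illustration:

cfg = {'training': {'checkpoint_path': '/tmp/ckpt'}}
print(find_json_value('training.checkpoint_path', cfg))  # -> /tmp/ckpt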
Example #14
    "env": NewsWorld
}

sac_config = {
    "env": NewsWorld
}

parser = argparse.ArgumentParser()
parser.add_argument("--iterations", type=int, default=10)

if __name__ == "__main__":
    args = parser.parse_args()
    ray.init()

    register_env(
        "NewsLearn",
        #lambda _: HeartsEnv()
        lambda _: ExternalWorld(env=NewsWorld(dict()))
    )

    trainer = a3c.A3CTrainer(env="NewsLearn", config=dict())

    for i in range(args.iterations):
        result = trainer.train()
        print("Iteration {}, Episodes {}, Mean Reward {}, Mean Length {}".format(
            i, result['episodes_this_iter'], result['episode_reward_mean'], result['episode_len_mean']
        ))

    ray.shutdown()
Example #15
import os

# os.environ["TUNE_RESULT_DIR"] = "/media/drake/BlackPassport/ray_results/"

import time
import ray
import ray.rllib.agents.a3c as a3c
from ray.tune.logger import pretty_print

ray.init()

config = a3c.DEFAULT_CONFIG.copy()
config["num_gpus"] = 0
config["num_workers"] = 5
config["num_envs_per_worker"] = 5
trainer = a3c.A3CTrainer(config=config, env="Blackjack-v0")

# Can optionally call trainer.restore(path) to load a checkpoint.
start = int(round(time.time()))
while True:
    # Perform one iteration of training the policy with A3C
    elapsed = int(round(time.time())) - start
    if elapsed > 30:
        break
    result = trainer.train()
    print(pretty_print(result))
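If the result of the timed run above is worth keeping, a short follow-up along these lines would checkpoint it before tearing Ray down (the checkpoint directory is chosen by RLlib):

checkpoint = trainer.save()
print("checkpoint saved at", checkpoint)
ray.shutdown()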
Example #16
    def forward(self, input_dict, state, seq_lens):
        model_out, self._value_out = self.base_model(input_dict["obs"])
        return model_out, state

    def value_function(self):
        return tf.reshape(self._value_out, [-1])


ray.init(include_dashboard=False)
ModelCatalog.register_custom_model("CartpoleModel", CartpoleModel)
CartpoleEnv = gym.make('CartPole-v0')
CartpoleEnv = ScaleReward(CartpoleEnv)
register_env("CP", lambda _:CartpoleEnv)

trainer = a3c.A3CTrainer(env="CP", config={"model": {"custom_model": "CartpoleModel"}})
if os.path.isfile('CartPole.pickle'):
    weights = pickle.load(open("CartPole.pickle", "rb"))
    trainer.restore_from_object(weights)


keyboard.on_press_key("q", press_key_exit)
while True:
    if ready_to_exit:
        break
    rest = trainer.train()
    print(rest["episode_reward_mean"])

weights = trainer.save_to_object()
pickle.dump(weights, open('CartPole.pickle', 'wb'))
print('Model saved')
Example #17
        self.has_virus = self.get_virus()

        self.destinations = self.get_position()

        states = [self.encode_state(i) for i in range(self.agent_num)]
        self.s = states

        return np.array(agent_matrix).astype(int)


if __name__ == "__main__":
    population = [
        np.load("./data/seocho.npy"),
        np.load("./data/daechi.npy"),
        np.load("./data/dogok.npy"),
        np.load("./data/yangjae.npy"),
        np.load("./data/sunreung.npy"),
        np.load("./data/nambu.npy")
    ]
    ray.init()
    trainer = a3c.A3CTrainer(
        env=EpidemicMultiEnv,
        config={
            "env_config": {
                'agent_num': 200,
                'population': population
            },  # config to pass to env class
        })
    for _ in range(90000):
        print(trainer.train())
Example #18
import ray
from model import EpiNN
import ray.rllib.agents.a3c as a3c
from ray.tune.logger import pretty_print
from ray.rllib.models import ModelCatalog

ray.init(num_gpus=2)
config = a3c.DEFAULT_CONFIG.copy()
# Register the custom model class itself (RLlib's ModelCatalog expects a class, not an instance).
ModelCatalog.register_custom_model("CDC_model", EpiNN)

config["num_gpus"] = 2
config["num_workers"] = 10
config["eager"] = True
config["model"] = "CDC_model"


# NOTE: the environment argument was left blank in the original snippet;
# "CartPole-v0" below is only a stand-in placeholder.
trainer = a3c.A3CTrainer(config=config, env="CartPole-v0")

for i in range(1000):
    result = trainer.train()
    print(pretty_print(result))

    if i % 100 == 0:
        checkpoint = trainer.save()
        print("checkpoint saved at", checkpoint)



Example #19
def render(checkpoint, home_path):
    """
    Renders pybullet and mujoco environments.
    """
    alg = re.match('.+?(?=_)', os.path.basename(os.path.normpath(home_path))).group(0)
    current_env = re.search("(?<=_).*?(?=_)", os.path.basename(os.path.normpath(home_path))).group(0)
    checkpoint_path = home_path + "checkpoint_" + str(checkpoint) + "/checkpoint-" + str(checkpoint)
    config = json.load(open(home_path + "params.json"))
    config_bin = pickle.load(open(home_path + "params.pkl", "rb"))
    ray.shutdown()
    import pybullet_envs
    ray.init()
    ModelCatalog.register_custom_model("RBF", RBFModel)
    ModelCatalog.register_custom_model("MLP_2_64", MLP)
    ModelCatalog.register_custom_model("linear", Linear)

    if alg == "PPO":
        trainer = ppo.PPOTrainer(config_bin)
    if alg == "SAC":
        trainer = sac.SACTrainer(config)
    if alg == "DDPG":
        trainer = ddpg.DDPGTrainer(config)
    if alg == "PG":
        trainer = pg.PGTrainer(config)
    if alg == "A3C":
        trainer = a3c.A3CTrainer(config)
    if alg == "TD3":
        trainer = td3.TD3Trainer(config)
    if alg == "ES":
        trainer = es.ESTrainer(config)
    if alg == "ARS":
        trainer = ars.ARSTrainer(config)
#   "normalize_actions": true,
    trainer.restore(checkpoint_path)

    if "Bullet" in current_env:
        env = gym.make(current_env, render=True)
    else:
        env = gym.make(current_env)
    #env.unwrapped.reset_model = det_reset_model
    env._max_episode_steps = 10000
    obs = env.reset()

    action_hist = []
    m_act_hist = []
    state_hist = []
    obs_hist = []
    reward_hist = []

    done = False
    step = 0

    for t in range(10000):
        # for some algorithms you can get the sample mean out, need to change the value on the index to match your env for now
        # mean_actions = out_dict['behaviour_logits'][:17]
        # actions = trainer.compute_action(obs.flatten())
        # sampled_actions, _ , out_dict = trainer.compute_action(obs.flatten(),full_fetch=True)
        sampled_actions = trainer.compute_action(obs.flatten())
        # sampled_actions, _ , out_dict = trainer.compute_action(obs.flatten(),full_fetch=True)
        
        actions = sampled_actions
        
        obs, reward, done, _ = env.step(np.asarray(actions))
        # env.camera_adjust()
        env.render(mode='human')
        time.sleep(0.01)
        # env.render()
        # env.render(mode='rgb_array', close = True)
        # p.computeViewMatrix(cameraEyePosition=[0,10,5], cameraTargetPosition=[0,0,0], cameraUpVector=[0,0,0])

        # if step % 1000 == 0:
        #     env.reset()
        # step += 1
        
        action_hist.append(np.copy(actions))
        obs_hist.append(np.copy(obs))
        reward_hist.append(np.copy(reward))
        if done:
            obs = env.reset()
    # print(sum(reward_hist))
    # print((obs_hist))
    #plt.plot(action_hist)
    #plt.figure()
    #plt.figure()
    #plt.plot(obs_hist)
    #plt.figure()

    # Reminder that the bahavior logits that come out are the mean and logstd (not log mean, despite the name logit)
    # trainer.compute_action(obs, full_fetch=True)
    trainer.compute_action(obs)
Example #20
        self.model.add(layers.Dense(2, name='l5', activation='relu'))
        return self.model.get_layer("l5").output, self.model.get_layer("l4").output



ray.init()
ModelCatalog.register_custom_model("CartpoleModel", CartpoleModel)
CartpoleEnv = gym.make('CartPole-v0')
CartpoleEnv = ScaleReward(CartpoleEnv)
register_env("CP", lambda _: CartpoleEnv)



trainer = a3c.A3CTrainer(env="CP", config={
    #"model": {"custom_model": "CartpoleModel"},
    #"observation_filter": "MeanStdFilter",
    #"vf_share_layers": True,
}, logger_creator=lambda _:ray.tune.logger.NoopLogger({},None))

if os.path.isfile('weights.pickle'):
    weights = pickle.load(open("weights.pickle", "rb"))
    trainer.restore_from_object(weights)


a_list = []
_thread.start_new_thread(input_thread, (a_list,))
while not a_list:
    rest = trainer.train()
    print(rest["episode_reward_mean"])

weights = trainer.save_to_object()
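# Hedged addition, not in the original snippet: persist the serialized weights
# so the restore_from_object() call near the top can reload them on the next run.
pickle.dump(weights, open('weights.pickle', 'wb'))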