Exemplo n.º 1
0
 def testTrainCartpoleOffPolicy(self):
     """Train DQN on CartPole-v0 wrapped in PartOffPolicyServing.

     The env is registered as "test3" with off_pol_frac=0.2. Training runs
     for at most 100 iterations and the test returns as soon as the mean
     episode reward reaches 100.

     Raises:
         Exception: if the reward threshold is never reached.
     """

     def make_env(_):
         # The env-config argument supplied by the registry is unused.
         return PartOffPolicyServing(
             gym.make("CartPole-v0"), off_pol_frac=0.2)

     register_env("test3", make_env)
     agent = DQNAgent(env="test3", config={"exploration_fraction": 0.001})
     for iteration in range(100):
         stats = agent.train()
         mean_reward = stats["episode_reward_mean"]
         print("Iteration {}, reward {}, timesteps {}".format(
             iteration, mean_reward, stats["timesteps_total"]))
         if mean_reward >= 100:
             return
     raise Exception("failed to improve reward")
Exemplo n.º 2
0
 def testTrainCartpoleOffPolicy(self):
     """Run DQN against the "test3" env until mean reward reaches 100.

     "test3" is CartPole-v0 wrapped in PartOffPolicyServing with
     off_pol_frac=0.2. Up to 100 training iterations are attempted; each
     one prints its index, mean episode reward and total timesteps.

     Raises:
         Exception: if no iteration reports a mean reward of at least 100.
     """
     env_name = "test3"
     register_env(
         env_name,
         lambda _: PartOffPolicyServing(
             gym.make("CartPole-v0"), off_pol_frac=0.2))
     trainer = DQNAgent(
         env=env_name, config={"exploration_fraction": 0.001})
     for step in range(100):
         outcome = trainer.train()
         reward = outcome["episode_reward_mean"]
         timesteps = outcome["timesteps_total"]
         print("Iteration {}, reward {}, timesteps {}".format(
             step, reward, timesteps))
         if reward >= 100:
             return
     raise Exception("failed to improve reward")
Exemplo n.º 3
0
 def testEvaluationOption(self):
     """Check how the "evaluation" entry changes across train() calls.

     A DQNAgent is built with "evaluation_interval": 2 and trained five
     times. The first result must contain an "evaluation" entry holding
     "episode_reward_mean". The entry is then expected to be equal for
     results 0/1 and 2/3, and to differ between results 1/2 and 3/4.
     """
     ray.init()
     agent = DQNAgent(env="CartPole-v0", config={"evaluation_interval": 2})
     # Five consecutive training iterations, evaluated in order.
     r0, r1, r2, r3, r4 = [agent.train() for _ in range(5)]
     # assertIn reports the missing key and the container on failure,
     # whereas assertTrue(x in y) only reports "False is not true".
     self.assertIn("evaluation", r0)
     self.assertIn("episode_reward_mean", r0["evaluation"])
     self.assertEqual(r0["evaluation"], r1["evaluation"])
     self.assertNotEqual(r1["evaluation"], r2["evaluation"])
     self.assertEqual(r2["evaluation"], r3["evaluation"])
     self.assertNotEqual(r3["evaluation"], r4["evaluation"])
Exemplo n.º 4
0
# File that holds the path of the most recent checkpoint.
CHECKPOINT_FILE = "last_checkpoint.out"

# NOTE(review): an address is passed to ray.init, presumably to join an
# already-running Ray cluster rather than start a local one -- confirm.
ray.init("localhost:6379")
ModelCatalog.register_custom_model("parametric", ParametricActionsModel)
register_env("halite_env", env_creator)
dqn = DQNAgent(
    env="halite_env",
    config={
        "env_config": {},
        "num_workers": 60,
        "num_cpus_per_worker": 1,
        "num_envs_per_worker": 20,
        "num_gpus": 1,
        "hiddens": [],
        "schedule_max_timesteps": 7500000,
        "timesteps_per_iteration": 4000,
        "exploration_fraction": 0.8,
        "exploration_final_eps": 0.02,
        "lr": 1e-3,
        "model": {
            "custom_model": "parametric",
            "custom_options": {},  # extra options to pass to your model
        }
    })

# Attempt to restore from checkpoint if possible.
if os.path.exists(CHECKPOINT_FILE):
    # Read inside a `with` block so the file handle is closed immediately
    # instead of being left open until garbage collection.
    with open(CHECKPOINT_FILE) as f:
        checkpoint_path = f.read()
    print("Restoring from checkpoint path", checkpoint_path)
    dqn.restore(checkpoint_path)
Exemplo n.º 5
0
                                                       SERVER_PORT))
        server = PolicyServer(self, SERVER_ADDRESS, SERVER_PORT)
        server.serve_forever()


if __name__ == "__main__":
    ray.init()
    register_env("srv", lambda _: CartpoleServing())

    # We use DQN since it supports off-policy actions, but you can choose and
    # configure any agent.
    dqn = DQNAgent(
        env="srv",
        config={
            # Use a single process to avoid needing to set up a load balancer
            "num_workers": 0,
            # Configure the agent to run short iterations for debugging
            "exploration_fraction": 0.01,
            "learning_starts": 100,
            "timesteps_per_iteration": 200,
        })

    # Attempt to restore from checkpoint if possible.
    if os.path.exists(CHECKPOINT_FILE):
        # Read inside a `with` block so the file handle is closed
        # immediately instead of being left open until garbage collection.
        with open(CHECKPOINT_FILE) as f:
            checkpoint_path = f.read()
        print("Restoring from checkpoint path", checkpoint_path)
        dqn.restore(checkpoint_path)

    # Serving and training loop
    # NOTE(review): the visible code never writes the saved path back to
    # CHECKPOINT_FILE, so the restore branch above depends on something
    # else producing that file -- confirm against the full script.
    while True:
        print(pretty_print(dqn.train()))
        checkpoint_path = dqn.save()
Exemplo n.º 6
0
# File that holds the path of the most recent checkpoint.
CHECKPOINT_FILE = "last_checkpoint.out"

ray.init(local_mode=True)
ModelCatalog.register_custom_model("parametric", ParametricActionsModel)
register_env("halite_env", env_creator)
dqn = DQNAgent(
    env="halite_env",
    config={
        "env_config": {},
        "num_workers": 1,
        "num_cpus_per_worker": 1,
        "num_envs_per_worker": 1,
        "num_gpus": 1,
        "hiddens": [],
        "schedule_max_timesteps": 100000000,
        "timesteps_per_iteration": 1000,
        "exploration_fraction": 0.8,
        "exploration_final_eps": 0.02,
        "lr": 1e-3,
        "model": {
            "custom_model": "parametric",
            "custom_options": {},  # extra options to pass to your model
        }
    })

# Attempt to restore from checkpoint if possible.
if os.path.exists(CHECKPOINT_FILE):
    # Read inside a `with` block so the file handle is closed immediately
    # instead of being left open until garbage collection.
    with open(CHECKPOINT_FILE) as f:
        checkpoint_path = f.read()
    print("Restoring from checkpoint path", checkpoint_path)
    dqn.restore(checkpoint_path)
Exemplo n.º 7
0

if __name__ == '__main__':

    args = parser.parse_args()
    ray.init()
    # The env creator forwards the env_config dict built below.
    register_env('srv', lambda config: CartPoleServer(config))

    ModelCatalog.register_custom_model("CM", CustomModel)
    dqn = DQNAgent(env='srv',
                   config={
                       'num_workers': 0,
                       'env_config': {
                           'observation_size': args.observation_size,
                           'action_size': args.action_size,
                           'checkpoint_file': args.checkpoint_file
                       },
                       'model': {
                           'custom_model': 'CM',
                           'custom_options': {},
                       },
                       'learning_starts': 150
                   })
    # Attempt to restore from the checkpoint path stored in the file, if any.
    if os.path.exists(args.checkpoint_file):
        # Read inside a `with` block so the file handle is closed
        # immediately instead of being left open until garbage collection.
        with open(args.checkpoint_file) as f:
            checkpoint_path = f.read()
        print("Restoring from checkpoint path", checkpoint_path)
        dqn.restore(checkpoint_path)

    # Train indefinitely, saving a checkpoint after every iteration.
    # NOTE(review): this loop does not write the path to
    # args.checkpoint_file; presumably the env does, since it receives
    # 'checkpoint_file' through env_config -- confirm.
    while True:
        print(pretty_print(dqn.train()))
        checkpoint_path = dqn.save()
        print("Last checkpoint", checkpoint_path)
Exemplo n.º 8
0
        server.serve_forever()


if __name__ == "__main__":
    args = parser.parse_args()
    ray.init()
    register_env("srv", lambda config: SimpleServing(config))

    if args.run == "DQN":
        agent = DQNAgent(
            env="srv",
            config={
                # Use a single process to avoid needing a load balancer
                "num_workers": 0,
                # Configure the agent to run short iterations for debugging
                "exploration_fraction": 0.01,
                "learning_starts": 100,
                "timesteps_per_iteration": 200,
                "env_config": {
                    "observation_size": args.observation_size,
                    "action_size": args.action_size,
                },
            })
    elif args.run == "PG":
        agent = PGAgent(
            env="srv",
            config={
                "num_workers": 0,
                "env_config": {
                    "observation_size": args.observation_size,
                    "action_size": args.action_size,
                },