Example #1
0
# DQN configuration for the SSA tasking environment (Ape-X-style settings:
# per-worker epsilon-greedy exploration and worker-side prioritization).
config["target_network_update_freq"] = 500000
config["exploration_config"] = {"type": "PerWorkerEpsilonGreedy"}
config["worker_side_prioritization"] = True
# config["min_iter_time_s"] = 30
# config["training_intensity"] = None
# config["log_level"] = 'DEBUG'
config["env_config"] = env_config
trainer = DQNTrainer(config=config, env=SSA_Tasker_Env)
# Can optionally call trainer.restore(path) to load a checkpoint.

# Train until 10M environment timesteps have been collected, saving a
# checkpoint every 4 training iterations.
checkpoints = []
result = {'timesteps_total': 0}
while result['timesteps_total'] < 1e7:
    # Perform one iteration of training the policy with DQN
    result = trainer.train()
    print(pretty_print(result))

    if result['training_iteration'] % 4 == 0:
        checkpoint = trainer.save()
        print("checkpoint saved at", checkpoint)
        # trainer.save() returns the checkpoint path as a string; strings are
        # immutable, so no copy() is needed before storing it.
        checkpoints.append(checkpoint)

# path = '/home/ash/ray_results/DQN_SSA_Tasker_Env_2020-08-13_13-26-22s_wgpxq5/checkpoint_1/'
# trainer.restore(path)

# trainer.import_model("my_weights.h5")
'''
ray.shutdown()
'''
    elif args.run == "PG":
        # Policy-gradient trainer fed by the external policy server rather
        # than by a simulated environment.
        agent = PGTrainer(
            env="srv",
            config={
                # No rollout workers: experiences arrive over the server
                # connection, so everything runs in the local process.
                "num_workers": 0,
                "env_config": {
                    # Use the connector server to generate experiences.
                    "input": (
                        lambda ioctx: PolicyServerInput(ioctx, SERVER_ADDRESS, SERVER_PORT)
                    ),
                    "observation_size": args.observation_size,
                    "action_size": args.action_size,
                },
            })
    else:
        raise ValueError("--run must be DQN or PG")

    # Attempt to restore from checkpoint if possible.  The checkpoint file,
    # when present, contains the path of the most recently saved checkpoint.
    if os.path.exists(args.checkpoint_file):
        # Use a context manager so the file handle is closed deterministically
        # (the original open(...).read() leaked the handle).
        with open(args.checkpoint_file) as f:
            checkpoint_file = f.read()
        print("Restoring from checkpoint path", checkpoint_file)
        agent.restore(checkpoint_file)

    # Serving and training loop: train indefinitely, and after every
    # iteration persist the latest checkpoint path so a restart can resume.
    while True:
        print(pretty_print(agent.train()))
        checkpoint_file = agent.save()
        print("Last checkpoint", checkpoint_file)
        with open(args.checkpoint_file, "w") as f:
            f.write(checkpoint_file)
Example #3
0
        # else:
        #     return "dqn_policy"
        # Each agent is mapped to the policy sharing its id — presumably the
        # keys of the ``policies`` dict used below; verify against the caller.
        return agent_id

    # Multi-agent DQN over the "cityflow_multi" environment: one entry in
    # ``policies`` per intersection, and every intersection id is trained.
    dqn_trainer = DQNTrainer(env="cityflow_multi",
                             config={
                                 "multiagent": {
                                     "policies":
                                     policies,
                                     "policy_mapping_fn":
                                     policy_mapping_fn,
                                     # Train one policy per intersection id.
                                     "policies_to_train":
                                     [id_ for id_ in intersection_id]
                                 },
                                 "gamma": 0.95,
                                 "n_step": 3,
                                 "num_workers": 1,
                                 "num_cpus_per_worker": 20,
                                 "env_config": config
                             })

    # Run ``args.epoch`` training iterations, checkpointing every 100th.
    for i in range(args.epoch):
        print("== Iteration", i, "==")

        # improve the DQN policy
        print("-- DQN --")
        print(pretty_print(dqn_trainer.train()))

        # Persist a checkpoint every 100 iterations (1-based count).
        if (i + 1) % 100 == 0:
            checkpoint = dqn_trainer.save()
            print("checkpoint saved at", checkpoint)