Example #1
    def test_dqn_apex_cpu_spawn_full_train(self, tmpdir):
        # By default, PyTorch Lightning falls back to ddp-spawn instead of ddp
        # when only CPUs are available.
        os.environ["WORLD_SIZE"] = "3"
        config = generate_env_config("CartPole-v0", {})
        config = generate_training_config(root_dir=tmpdir.make_numbered_dir(),
                                          config=config)
        config = generate_algorithm_config("DQNApex", config)
        # use ddp_cpu
        config["gpus"] = None
        config["num_processes"] = 3
        # this testing process corresponds to this node
        config["num_nodes"] = 1
        config["early_stopping_patience"] = 100
        # Use classes instead of string names since the algorithm is distributed.
        config["frame_config"]["models"] = [QNet, QNet]
        config["frame_config"]["model_kwargs"] = [
            {
                "state_dim": 4,
                "action_num": 2
            },
            {
                "state_dim": 4,
                "action_num": 2
            },
        ]

        # For spawn mode we use a special callback, because we cannot access
        # max_total_reward from sub-processes.
        queue = SimpleQueue(ctx=mp.get_context("spawn"))
        # cb = [SpawnInspectCallback(queue), LoggerDebugCallback()]
        cb = [SpawnInspectCallback(queue)]
        t = Thread(target=launch, args=(config, ), kwargs={"pl_callbacks": cb})
        t.start()

        default_logger.info("Start tracking")
        subproc_max_total_reward = [0, 0, 0]
        while True:
            try:
                result = queue.quick_get(timeout=60)
                default_logger.info(
                    f"Result from process [{result[0]}]: {result[1]}")
                subproc_max_total_reward[result[0]] = result[1]
            except TimeoutError:
                # no more results
                default_logger.info("No more results.")
                break
        t.join()
        assert (
            sum(subproc_max_total_reward) / 3 >= 150
        ), f"Max total reward {sum(subproc_max_total_reward) / 3} below threshold 150."
Example #2
def test_dqn_apex(_, tmpdir):
    config = generate_gym_env_config("CartPole-v0", {})
    config = generate_training_config(
        trials_dir=str(tmpdir.make_numbered_dir()), config=config
    )
    config = generate_algorithm_config("DQNApex", config)
    config["early_stopping_patience"] = 10
    config["frame_config"]["models"] = ["QNet", "QNet"]
    config["frame_config"]["model_kwargs"] = [
        {"state_dim": 4, "action_num": 2},
        {"state_dim": 4, "action_num": 2},
    ]
    cb = InspectCallback()
    launch_gym(config, pl_callbacks=[cb])
    assert cb.max_total_reward >= 150
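InspectCallback, used in this and the next example, is likewise not shown here. A minimal sketch, assuming it simply records the best "total_reward" metric seen so far as an attribute (the hook and metric name are assumptions):

from pytorch_lightning.callbacks import Callback


class InspectCallbackSketch(Callback):
    def __init__(self):
        self.max_total_reward = float("-inf")

    def on_train_epoch_end(self, trainer, pl_module):
        reward = trainer.callback_metrics.get("total_reward")
        if reward is not None:
            self.max_total_reward = max(self.max_total_reward, float(reward))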
Example #3
def test_dqn_full_train(self, tmpdir):
    config = generate_env_config("CartPole-v0", {})
    config = generate_training_config(
        root_dir=str(tmpdir.make_numbered_dir()), config=config
    )
    config = generate_algorithm_config("DQN", config)
    config["early_stopping_patience"] = 100
    config["frame_config"]["models"] = ["QNet", "QNet"]
    config["frame_config"]["model_kwargs"] = [
        {"state_dim": 4, "action_num": 2},
        {"state_dim": 4, "action_num": 2},
    ]
    cb = InspectCallback()
    launch(config, pl_callbacks=[cb])
    assert (
        cb.max_total_reward >= 150
    ), f"Max total reward {cb.max_total_reward} below threshold 150."
Example #4
import torch as t
import torch.nn as nn

# Note: generate_algorithm_config, generate_env_config, generate_training_config
# and launch are assumed to come from the library's auto-config module; their
# imports are not shown in this snippet.


class SomeQNet(nn.Module):
    def __init__(self, state_dim, action_num):
        super().__init__()

        self.fc1 = nn.Linear(state_dim, 16)
        self.fc2 = nn.Linear(16, 16)
        self.fc3 = nn.Linear(16, action_num)

    def forward(self, state):
        a = t.relu(self.fc1(state))
        a = t.relu(self.fc2(a))
        return self.fc3(a)


if __name__ == "__main__":
    config = generate_algorithm_config("DQN")
    config = generate_env_config("openai_gym", config)
    config = generate_training_config(root_dir="trial",
                                      episode_per_epoch=10,
                                      max_episodes=10000,
                                      config=config)
    config["frame_config"]["models"] = ["SomeQNet", "SomeQNet"]
    config["frame_config"]["model_kwargs"] = [{
        "state_dim": 4,
        "action_num": 2
    }] * 2
    launch(config)
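As a quick, self-contained sanity check (not part of the original script, and reusing the torch import above), the network can be instantiated with the CartPole-v0 dimensions used in the config and probed for its output shape:

# CartPole-v0: 4-dimensional state, 2 discrete actions
net = SomeQNet(state_dim=4, action_num=2)
q_values = net(t.zeros(1, 4))
assert q_values.shape == (1, 2)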
Example #5
        )
        should_stop = t_plugin.reduce(should_stop, reduce_op=ReduceOp.SUM)
        should_stop = bool(should_stop == trainer.world_size)
        return should_stop

    def reduce_max_total_reward(self, trainer, t_plugin):
        avg = t.tensor(self.max_total_reward, device=trainer.lightning_module.device)
        avg = t_plugin.reduce(avg, reduce_op=ReduceOp.SUM)
        return float(avg)


if __name__ == "__main__":
    os.environ["WORLD_SIZE"] = "3"
    print(os.environ["TEST_SAVE_PATH"])
    config = generate_env_config("CartPole-v0", {})
    config = generate_training_config(root_dir=os.environ["ROOT_DIR"], config=config)
    config = generate_algorithm_config("DQNApex", config)

    # use ddp gpu
    config["gpus"] = [0, 0, 0]
    config["num_processes"] = 3
    # this testing process corresponds to this node
    config["num_nodes"] = 1
    config["early_stopping_patience"] = 100
    # Use classes instead of string names since the algorithm is distributed.
    config["frame_config"]["models"] = [QNet, QNet]
    config["frame_config"]["model_kwargs"] = [
        {"state_dim": 4, "action_num": 2},
        {"state_dim": 4, "action_num": 2},
    ]
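The should_stop reduction near the top of the fragment above encodes an all-workers-agree rule: each process contributes 0 or 1, the values are summed across processes, and training stops only when the sum equals world_size. A plain-Python illustration of that rule:

world_size = 3
votes = [1, 1, 0]  # one worker has not reached the stopping criterion yet
should_stop = sum(votes) == world_size
assert should_stop is False  # stop only when every worker votes 1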
Example #6
    elif args.command == "generate":
        if args.algo not in get_available_algorithms():
            print(
                f"{args.algo} is not a valid algorithm, use list "
                "--algo to get a list of available algorithms."
            )
            exit(0)
        if args.env not in get_available_environments():
            print(
                f"{args.env} is not a valid environment, use list "
                "--env to get a list of available environments."
            )
            exit(0)
        config = {}
        config = generate_env_config(args.env, config=config)
        config = generate_algorithm_config(args.algo, config=config)
        config = generate_training_config(config=config)

        if args.print:
            pprint(config)

        with open(args.output, "w") as f:
            json.dump(config, f, indent=4, sort_keys=True)
        print(f"Config saved to {args.output}")

    elif args.command == "launch":
        with open(args.config, "r") as f:
            conf = json.load(f)
        launch(conf)
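The argument parser that produces args is not part of this fragment. A minimal argparse sketch consistent with the attributes used above (subcommand and flag names are assumptions, not the actual CLI definition):

import argparse

parser = argparse.ArgumentParser()
subparsers = parser.add_subparsers(dest="command")

# "generate" writes a combined config to a JSON file
gen = subparsers.add_parser("generate")
gen.add_argument("--algo", required=True)
gen.add_argument("--env", required=True)
gen.add_argument("--output", default="config.json")
gen.add_argument("--print", action="store_true")

# "launch" reads a previously generated config and starts training
lnc = subparsers.add_parser("launch")
lnc.add_argument("--config", required=True)

args = parser.parse_args()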