Example #1
    def init_ddp_connection(self, global_rank: int, world_size: int) -> None:
        master_addr = self.cluster_environment.master_address()
        master_port = self.cluster_environment.master_port()
        world_size = self.cluster_environment.world_size()
        assert_world_config(self.cluster_environment, master_addr, master_port,
                            world_size)

        if not is_world_initialized():
            pl_logger.info(
                "initializing world: GLOBAL_RANK: {}, MEMBER: {}/{}".format(
                    global_rank,
                    int(global_rank) + 1, world_size))
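            # World initializes both the torch.distributed process group
            # (dist_init_method) and the RPC layer (rpc_init_method, which
            # uses the next port above the master port).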
            _w = World(
                name=str(global_rank),
                rank=int(global_rank),
                world_size=int(world_size),
                dist_backend=self.torch_distributed_backend,
                dist_init_method=f"tcp://{master_addr}:{master_port}",
                rpc_init_method=f"tcp://{master_addr}:{int(master_port) + 1}",
            )
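
This snippet reads its rendezvous settings from a Lightning-style cluster environment object. A minimal sketch of the interface it relies on, to make the calls above concrete; the class name and return values are illustrative, not part of the original code:

class FakeClusterEnvironment:
    # answers the three queries init_ddp_connection makes above
    def master_address(self) -> str:
        return "127.0.0.1"

    def master_port(self) -> str:
        return "29500"

    def world_size(self) -> int:
        return 3
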
Example #2
    def init_ddp_connection(self, global_rank: int, world_size: int) -> None:
        master_addr = self.cluster_environment.master_address()
        master_port = self.cluster_environment.master_port()
        world_size = self.cluster_environment.world_size()
        assert_world_config(
            self.cluster_environment, master_addr, master_port, world_size
        )

        if not is_world_initialized():
            pl_logger.info(
                f"initializing world: GLOBAL_RANK: {global_rank}, "
                f"MEMBER: {int(global_rank) + 1}/{world_size}"
            )
            # TODO: nccl currently has problems supporting different
            # configurations; use gloo as a replacement.
            # See: https://github.com/pytorch/pytorch/issues/47885
            _w = World(
                name=str(global_rank),
                rank=int(global_rank),
                world_size=int(world_size),
                dist_backend="gloo",
                dist_init_method=f"tcp://{master_addr}:{master_port}",
                rpc_init_method=f"tcp://{master_addr}:{int(master_port) + 1}",
            )
        # NOTE: the two lines below are a fragment from a separate helper in
        # the source file (cross-process reward aggregation); `avg` and
        # `t_plugin` are not defined inside init_ddp_connection itself.
        avg = t_plugin.reduce(avg, reduce_op=ReduceOp.SUM)
        return float(avg)
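
For context, t_plugin.reduce(avg, reduce_op=ReduceOp.SUM) asks the Lightning training-type plugin to sum a value across every DDP process. The same idea with raw torch.distributed, as a self-contained sketch (the function name is ours, and it assumes the default process group is already initialized):

import torch
import torch.distributed as dist

def sum_across_ranks(value: float) -> float:
    # all-reduce a scalar over every process in the default group (SUM)
    tensor = torch.tensor([value], dtype=torch.float32)
    dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
    return float(tensor.item())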


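The main block below passes the QNet class straight into frame_config; its definition lives elsewhere in the test file. A typical shape for it, following machin's DQN examples (a small MLP; the layer sizes are illustrative):

import torch as t
import torch.nn as nn

class QNet(nn.Module):
    # maps a state vector to one Q value per action
    def __init__(self, state_dim: int, action_num: int):
        super().__init__()
        self.fc1 = nn.Linear(state_dim, 16)
        self.fc2 = nn.Linear(16, 16)
        self.fc3 = nn.Linear(16, action_num)

    def forward(self, state):
        a = t.relu(self.fc1(state))
        a = t.relu(self.fc2(a))
        return self.fc3(a)
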
import os
import pickle

# launch, the generate_*_config helpers and DDPInspectCallback come from
# the surrounding test code (their import paths are not shown in this extract).

if __name__ == "__main__":
    os.environ["WORLD_SIZE"] = "3"
    print(os.environ["TEST_SAVE_PATH"])
    config = generate_env_config("CartPole-v0", {})
    config = generate_training_config(root_dir=os.environ["ROOT_DIR"], config=config)
    config = generate_algorithm_config("DQNApex", config)

    # use DDP on GPU; all three processes share GPU 0 here
    config["gpus"] = [0, 0, 0]
    config["num_processes"] = 3
    # this test process corresponds to a single node
    config["num_nodes"] = 1
    config["early_stopping_patience"] = 100
    # Use classes instead of string names, since the algorithm runs in
    # distributed processes.
    config["frame_config"]["models"] = [QNet, QNet]
    config["frame_config"]["model_kwargs"] = [
        {"state_dim": 4, "action_num": 2},
        {"state_dim": 4, "action_num": 2},
    ]

    # cb = [DDPInspectCallback(), LoggerDebugCallback()]
    cb = [DDPInspectCallback()]
    launch(config, pl_callbacks=cb)
    if is_world_initialized() and get_cur_rank() == 0:
        with open(os.environ["TEST_SAVE_PATH"], "wb") as f:
            pickle.dump(cb[0].avg_max_total_reward, f)
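
Rank 0 pickles the callback's avg_max_total_reward so a parent process can inspect it after the workers exit. A hedged sketch of the reading side (the helper name is ours):

import os
import pickle

def load_avg_max_total_reward() -> float:
    # read back the metric written by rank 0 above
    with open(os.environ["TEST_SAVE_PATH"], "rb") as f:
        return float(pickle.load(f))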