Esempio n. 1
0
def run(
    conn: Connection,
    config: dict,
    model: torch.nn.Module,
    log_queue: Optional[mp.Queue] = None,
):
    """Serve an ``InferenceProcess`` for ``model`` over ``conn``.

    Logging is configured from ``log_queue`` first. Then an
    ``InferenceProcess(config, model)`` is wrapped in an ``MPServer`` bound to
    ``conn``, and control is handed to that server's ``listen()`` method.
    """
    log.configure(log_queue)
    MPServer(InferenceProcess(config, model), conn).listen()
Esempio n. 2
0
def run(conn: Connection,
        config: dict,
        model: torch.nn.Module,
        log_queue: Optional[mp.Queue] = None):
    """Serve a ``DryRunProcess`` for ``model`` over ``conn``.

    Logging is configured from ``log_queue``. A ``DryRunProcess(config, model)``
    is then wrapped in an ``MPServer`` bound to ``conn``, and that server's
    ``listen()`` method is called.

    Args:
        conn: Connection the ``MPServer`` communicates over.
        config: Configuration passed to ``DryRunProcess``.
        model: Model passed to ``DryRunProcess``.
        log_queue: Queue passed to ``log.configure``; may be ``None``.
    """
    log.configure(log_queue)
    dryrun_proc = DryRunProcess(config, model)
    srv = MPServer(dryrun_proc, conn)
    srv.listen()
Esempio n. 3
0
def run(
    conn: Connection,
    config: dict,
    model: torch.nn.Module,
    optimizer_state: bytes = b"",
    log_queue: Optional[mp.Queue] = None,
):
    """Serve a ``TrainingProcess`` for ``model`` over ``conn``.

    After configuring logging from ``log_queue``, this builds
    ``TrainingProcess(config, model, optimizer_state)``, wraps it in an
    ``MPServer`` bound to ``conn``, and calls the server's ``listen()``.
    ``optimizer_state`` defaults to empty bytes.
    """
    log.configure(log_queue)
    MPServer(TrainingProcess(config, model, optimizer_state), conn).listen()
Esempio n. 4
0
def run(
    conn: Connection,
    config: dict,
    model_file: bytes,
    model_state: bytes,
    optimizer_state: bytes,
    log_queue: Optional[mp.Queue] = None,
):
    """Serve a ``HandlerProcess`` over ``conn``.

    ``log_queue`` is used twice. It configures logging in this process, and it
    is also forwarded, together with the serialized model file, model state and
    optimizer state, to ``HandlerProcess``. The handler is served through an
    ``MPServer`` bound to ``conn`` via its ``listen()`` method.
    """
    log.configure(log_queue)
    process = HandlerProcess(
        config, model_file, model_state, optimizer_state, log_queue
    )
    server = MPServer(process, conn)
    server.listen()
Esempio n. 5
0
def _run_model_session_process(
    conn: Connection, model_zip: bytes, devices: List[str], log_queue: Optional[_mp.Queue] = None
):
    """Serve a ``ModelSessionProcess`` for ``model_zip`` over ``conn``.

    Before the session is built, the soft ``RLIMIT_NOFILE`` limit is raised to
    at least 4096 where the ``resource`` module exists (it is absent on
    Windows). This is best effort:

    * the limit is never lowered (an already higher or unlimited soft limit is
      kept);
    * it is never pushed above the hard limit;
    * a ``setrlimit`` call the OS refuses is ignored.

    Args:
        conn: Connection the ``MPServer`` communicates over.
        model_zip: Bytes passed to ``ModelSessionProcess``.
        devices: Device names passed to ``ModelSessionProcess``.
        log_queue: If not ``None``, passed to ``log.configure``.
    """
    min_open_files = 4096
    try:
        # from: https://github.com/pytorch/pytorch/issues/973#issuecomment-346405667
        import resource
    except ModuleNotFoundError:
        pass  # probably running on windows
    else:
        soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
        # RLIM_INFINITY means "unlimited"; ordinary numeric comparison doesn't
        # apply to it (on Linux it is reported as -1).
        if soft != resource.RLIM_INFINITY and soft < min_open_files:
            target = min_open_files
            if hard != resource.RLIM_INFINITY:
                # setrlimit raises ValueError if the soft limit exceeds the hard one.
                target = min(target, hard)
            if target > soft:
                try:
                    resource.setrlimit(resource.RLIMIT_NOFILE, (target, hard))
                except (ValueError, OSError):
                    # Best effort: keep the current limit if the OS refuses.
                    pass

    if log_queue is not None:
        log.configure(log_queue)

    session_proc = ModelSessionProcess(model_zip, devices)
    srv = MPServer(session_proc, conn)
    srv.listen()
Esempio n. 6
0
def _srv(conn, log_queue):
    """Configure logging from ``log_queue``, then serve an ``ApiImpl`` over ``conn`` via ``MPServer.listen()``."""
    log.configure(log_queue)
    MPServer(ApiImpl(), conn).listen()
Esempio n. 7
0
def _run_srv(srv_cls, conn, log_queue):
    """Configure logging, instantiate ``srv_cls()`` and serve it over ``conn``.

    ``srv_cls`` is called with no arguments. The resulting object is wrapped in
    an ``MPServer`` bound to ``conn``, whose ``listen()`` method is then called.
    """
    log.configure(log_queue)
    MPServer(srv_cls(), conn).listen()
Esempio n. 8
0
def _cancel_srv(conn, log_queue):
    """Configure logging from ``log_queue``, then serve a ``CancelableSrv`` over ``conn`` via ``MPServer.listen()``."""
    log.configure(log_queue)
    MPServer(CancelableSrv(), conn).listen()