示例#1
0
def run_worker(model_dir, device_idx, endpoint="ipc://worker.ipc"):
    """Worker process entrance: serve predictions on a ZeroMQ REP socket.

    Restricts ``CUDA_VISIBLE_DEVICES`` to a single device, builds a
    ``Predictor`` from ``model_dir``, connects a REP socket to ``endpoint``
    and then loops forever. Each request is a serialized
    ``interface_pb2.Slots`` message. Its slots are converted with
    ``serv_utils.slot_to_paddlearray`` and fed to the predictor. The
    predictor outputs are converted back with
    ``serv_utils.paddlearray_to_slot`` and sent as a serialized ``Slots``.

    Args:
        model_dir: model directory passed to ``Predictor``.
        device_idx: index into the comma-separated ``CUDA_VISIBLE_DEVICES``
            list inherited from the parent process. When that variable is
            unset, the index itself is used as the device id.
        endpoint: ZeroMQ address the socket ``connect``s to (the peer is
            expected to bind it).

    Raises:
        Exception: any error during setup is logged and re-raised, so the
            worker exits instead of serving with a half-built state.
            Errors from ``socket.recv``/``socket.send`` also propagate.

    Note:
        If handling one request fails, the reply is the error text as bytes
        rather than a ``Slots`` message. A REP socket must answer every
        request before it can ``recv`` again, so a reply is always sent.
    """
    context = None
    socket = None
    try:
        try:
            log.debug("run_worker %s", device_idx)
            visible = os.getenv("CUDA_VISIBLE_DEVICES")
            if visible is None:
                # Nothing restricts visibility, so the index is the device id.
                device = str(device_idx)
            else:
                device = visible.split(",")[device_idx]
            os.environ["CUDA_VISIBLE_DEVICES"] = device
            log.debug("cuda_env %s", os.environ["CUDA_VISIBLE_DEVICES"])
            # Imported only after CUDA_VISIBLE_DEVICES is set, presumably so
            # paddle initialises against the selected device — keep the order.
            import paddle.fluid as F  # noqa: F401
            from ernie_gen.propeller.service import interface_pb2
            import ernie_gen.propeller.service.utils as serv_utils
            context = zmq.Context()
            socket = context.socket(zmq.REP)
            socket.connect(endpoint)
            log.debug("Predictor building %s", device_idx)
            # Device 0: only the selected device is visible to this process.
            predictor = Predictor(model_dir, 0)
            log.debug("Predictor %s", device_idx)
        except Exception as e:
            log.exception(e)
            raise

        while True:
            #  Wait for next request from client
            message = socket.recv()
            try:
                log.debug("get message %s", device_idx)
                slots = interface_pb2.Slots()
                slots.ParseFromString(message)
                pts = [serv_utils.slot_to_paddlearray(s) for s in slots.slots]
                ret = predictor(pts)
                slots = interface_pb2.Slots(
                    slots=[serv_utils.paddlearray_to_slot(r) for r in ret])
                reply = slots.SerializeToString()
            except Exception as e:
                log.exception(e)
                # zmq sends bytes only; Python 3 exceptions have no .message.
                reply = str(e)
                if not isinstance(reply, bytes):
                    reply = reply.encode("utf-8")
            socket.send(reply)
    finally:
        if socket is not None:
            socket.close(linger=0)
        if context is not None:
            context.term()
示例#2
0
def nparray_list_deserialize(string):
    """Decode a serialized ``interface_pb2.Slots`` message into a list.

    Args:
        string: bytes produced by ``Slots.SerializeToString()``.

    Returns:
        A list holding ``slot_to_numpy(slot)`` for every slot, in slot order.
    """
    container = interface_pb2.Slots()
    container.ParseFromString(string)
    decoded = []
    for item in container.slots:
        decoded.append(slot_to_numpy(item))
    return decoded
示例#3
0
def nparray_list_serialize(arr_list):
    """Encode a sequence of arrays as a serialized ``interface_pb2.Slots``.

    Args:
        arr_list: iterable of arrays; each one is converted with
            ``numpy_to_slot``, preserving order.

    Returns:
        The bytes from ``Slots.SerializeToString()``.
    """
    message = interface_pb2.Slots(slots=list(map(numpy_to_slot, arr_list)))
    return message.SerializeToString()