Ejemplo n.º 1
0
def run_worker(model_dir, device_idx, endpoint="ipc://worker.ipc"):
    """Serve prediction requests on a ZeroMQ REP socket, forever.

    The process is first pinned to a single GPU by rewriting
    ``CUDA_VISIBLE_DEVICES`` to the ``device_idx``-th entry of its current
    comma-separated value. A ``Predictor`` is then built from ``model_dir``
    and the worker answers one serialized ``Slots`` message per request.

    Args:
        model_dir: forwarded as the first argument of ``Predictor``.
        device_idx: index into the comma-separated ``CUDA_VISIBLE_DEVICES``
            list selecting the device for this worker.
        endpoint: ZeroMQ endpoint the REP socket connects to.

    Raises:
        Exception: any failure during setup (for example
            ``CUDA_VISIBLE_DEVICES`` being unset, or the predictor failing
            to build) is logged and re-raised, since the serve loop cannot
            run without a socket and a predictor. A failure of
            ``socket.recv()`` is also logged and re-raised.
    """
    try:
        log.debug("run_worker %s" % device_idx)
        os.environ["CUDA_VISIBLE_DEVICES"] = os.getenv(
            "CUDA_VISIBLE_DEVICES").split(",")[device_idx]
        log.debug('cuda_env %s' % os.environ["CUDA_VISIBLE_DEVICES"])
        # Imported only after CUDA_VISIBLE_DEVICES has been narrowed --
        # presumably so paddle initialises against the single selected
        # device; confirm before hoisting these to module level.
        import paddle.fluid as F
        from propeller.service import interface_pb2
        import propeller.service.utils as serv_utils
        context = zmq.Context()
        socket = context.socket(zmq.REP)
        # The worker connects to the endpoint; it does not bind it.
        socket.connect(endpoint)
        log.debug("Predictor building %s" % device_idx)
        # Only one device is visible at this point, hence the constant 0
        # (assumed to be the device id expected by Predictor -- confirm).
        predictor = Predictor(model_dir, 0)
        log.debug("Predictor %s" % device_idx)
    except Exception as e:
        log.exception(e)
        # Falling through would enter the loop below with `socket` and/or
        # `predictor` unbound; fail loudly instead.
        raise

    while True:
        #  Wait for next request from client
        received = False
        try:
            message = socket.recv()
            received = True
            log.debug("get message %s" % device_idx)
            slots = interface_pb2.Slots()
            slots.ParseFromString(message)
            pts = [serv_utils.slot_to_paddlearray(s) for s in slots.slots]
            ret = predictor(pts)
            slots = interface_pb2.Slots(
                slots=[serv_utils.paddlearray_to_slot(r) for r in ret])
            socket.send(slots.SerializeToString())
        except Exception as e:
            log.exception(e)
            if not received:
                # No request is pending, so a REP socket has nothing to
                # reply to; surface the receive failure itself.
                raise
            # Exceptions have no `.message` attribute on Python 3 and the
            # socket needs bytes, so send the textual form as bytes.
            error_text = str(e)
            if not isinstance(error_text, bytes):
                error_text = error_text.encode("utf-8")
            socket.send(error_text)
Ejemplo n.º 2
0
 def Infer(self, request, context):
     """Run inference for one request on the calling thread.

     A ``Predictor`` is created the first time a given thread handles a
     request and cached in ``predictor_context`` under that thread object;
     its constructor argument is the thread's position in
     ``pool._threads``. The request's slots are converted with
     ``serv_utils.slot_to_paddlearray``, passed to the predictor, and the
     outputs converted back with ``serv_utils.paddlearray_to_slot``.

     Returns:
         An ``interface_pb2.Slots`` message wrapping the converted outputs.

     Raises:
         Exception: any error raised while handling the request is logged
             and re-raised.
     """
     try:
         incoming = request.slots
         worker_thread = threading.current_thread()
         log.debug('%d slots received dispatch to thread %s' %
                   (len(incoming), worker_thread))
         if current_thread_known := worker_thread in predictor_context:
             predictor = predictor_context[worker_thread]
         if not current_thread_known:
             # First request seen by this thread: build its predictor,
             # numbered by the thread's index in the pool.
             did = list(pool._threads).index(worker_thread)
             log.debug('spawning worker thread %d' % did)
             predictor = Predictor(did)
             predictor_context[worker_thread] = predictor
         feeds = []
         for slot in incoming:
             feeds.append(serv_utils.slot_to_paddlearray(slot))
         outputs = predictor(feeds)
         response = []
         for output in outputs:
             response.append(serv_utils.paddlearray_to_slot(output))
     except Exception as e:
         log.exception(e)
         raise e
     return interface_pb2.Slots(slots=response)
Ejemplo n.º 3
0
def nparray_list_deserialize(string):
    """Decode a serialized ``Slots`` message into a list of arrays.

    Args:
        string: serialized ``interface_pb2.Slots`` payload.

    Returns:
        A list holding ``slot_to_numpy(slot)`` for each slot of the parsed
        message, in message order.
    """
    message = interface_pb2.Slots()
    message.ParseFromString(string)
    arrays = []
    for slot in message.slots:
        arrays.append(slot_to_numpy(slot))
    return arrays
Ejemplo n.º 4
0
def nparray_list_serialize(arr_list):
    """Encode a list of arrays as a serialized ``Slots`` message.

    Args:
        arr_list: iterable of arrays; each is converted with
            ``numpy_to_slot``.

    Returns:
        The result of ``SerializeToString()`` on an ``interface_pb2.Slots``
        message containing the converted slots, in input order.
    """
    converted = []
    for arr in arr_list:
        converted.append(numpy_to_slot(arr))
    message = interface_pb2.Slots(slots=converted)
    return message.SerializeToString()