def send_worker(address, send_queue, worker_alive):
    """Forward queued messages to a remote endpoint over a ZMQ PUSH socket.

    Intended as a ``multiprocessing.Process`` target. Loops while
    ``worker_alive.value`` is truthy, taking message tuples off
    ``send_queue`` and sending them with ``sender.send_data(*msg)``.

    Args:
        address: ZMQ endpoint the PUSH socket connects to.
        send_queue: queue of argument tuples for ``send_data``.
        worker_alive: shared integer flag; cleared on exit so that sibling
            workers sharing it stop as well.
    """
    timing = AccumDict()
    log = Logger('./var/log/send_worker.log', verbose=opt.verbose)

    ctx = SerializingContext()
    sender = ctx.socket(zmq.PUSH)
    sender.connect(address)

    log(f"Sending to {address}")

    try:
        while worker_alive.value:
            tt = TicToc()

            try:
                msg = send_queue.get(timeout=GET_TIMEOUT)
            except queue.Empty:
                # Timeout only exists so worker_alive is re-checked regularly.
                continue

            tt.tic()
            sender.send_data(*msg)
            timing.add('SEND', tt.toc())

            if opt.verbose:
                Once(timing, log, per=1)
    except KeyboardInterrupt:
        log("send_worker: user interrupt")
    finally:
        # Signal the other workers and release the socket/context on every
        # exit path, including unexpected exceptions escaping the loop.
        worker_alive.value = 0
        sender.disconnect(address)
        sender.close()
        ctx.destroy()

    log("send_worker exit")
예제 #2
0
    def __init__(self,
                 *args,
                 worker_host='localhost',
                 worker_port=DEFAULT_PORT,
                 **kwargs):
        """Open a PAIR socket to the remote worker and initialise the predictor there.

        Positional and keyword arguments are kept in ``predictor_args`` and
        sent to the remote side by ``init_worker()``.

        Args:
            worker_host: host name of the remote worker.
            worker_port: TCP port of the remote worker.
        """
        self.worker_host, self.worker_port = worker_host, worker_port
        self.predictor_args = (args, kwargs)

        endpoint = f"tcp://{worker_host}:{worker_port}"
        self.context = SerializingContext()
        self.socket = self.context.socket(zmq.PAIR)
        self.socket.connect(endpoint)
        # NOTE(review): ZMQ connects asynchronously, so this message does not
        # prove a worker is listening; the first real round-trip is init_worker().
        log(f"Connected to {worker_host}:{worker_port}")

        self.timing = AccumDict()

        self.init_worker()
예제 #3
0
    def __init__(self, *args, worker_host='localhost', worker_port=DEFAULT_PORT, **kwargs):
        """Connect to a remote predictor worker, verify it answers, then initialise it.

        Args:
            *args, **kwargs: stored in ``predictor_args`` and sent to the
                remote side by ``init_worker()``.
            worker_host: host name of the remote worker.
            worker_port: TCP port of the remote worker.

        Raises:
            ConnectionError: if the worker does not answer the ``check_connection``
                probe.
        """
        self.worker_host = worker_host
        self.worker_port = worker_port
        self.predictor_args = (args, kwargs)
        self.timing = AccumDict()

        self.address = f"tcp://{worker_host}:{worker_port}"
        self.context = SerializingContext()
        self.socket = self.context.socket(zmq.PAIR)
        self.socket.connect(self.address)

        if not self.check_connection():
            # The unanswered probe can still be pending on the socket. With
            # ZMQ's default infinite linger, context termination would block
            # waiting to deliver it (the old "destroy hangs" TODO). Closing
            # with linger=0 discards it, so the context can be terminated.
            self.socket.close(linger=0)
            self.context.term()
            raise ConnectionError(f"Could not connect to {worker_host}:{worker_port}")

        log(f"Connected to {self.address}")

        self.init_worker()
예제 #4
0
    def recv_worker(port, recv_queue, worker_alive):
        """Receive ``(meta, data)`` messages on a bound PULL socket and enqueue them.

        Messages whose ``meta['critical']`` is truthy are enqueued with a
        blocking ``put``. Non-critical ones are dropped, with a log line, when
        ``recv_queue`` is full.

        Args:
            port: TCP port bound on all interfaces.
            recv_queue: destination queue for received messages.
            worker_alive: shared integer flag; the loop runs while it is
                truthy, and it is cleared on every exit path.
        """
        timing = AccumDict()
        log = Logger('./var/log/recv_worker.log', verbose=opt.verbose)

        ctx = SerializingContext()
        socket = ctx.socket(zmq.PULL)
        socket.bind(f"tcp://*:{port}")
        # Bounded receive so worker_alive is re-checked regularly.
        socket.RCVTIMEO = RECV_TIMEOUT

        log(f'Receiving on port {port}', important=True)

        try:
            while worker_alive.value:
                tt = TicToc()

                try:
                    tt.tic()
                    msg = socket.recv_data()
                    timing.add('RECV', tt.toc())
                except zmq.error.Again:
                    log("recv timeout")
                    continue

                meta, _data = msg
                if meta['critical']:
                    # NOTE(review): blocks until there is room; a full queue
                    # here also delays noticing worker_alive going to 0.
                    recv_queue.put(msg)
                else:
                    try:
                        recv_queue.put(msg, block=False)
                    except queue.Full:
                        log('recv_queue full')

                Once(timing, log, per=1)
        except KeyboardInterrupt:
            log("recv_worker: user interrupt", important=True)
        finally:
            # Stop sibling workers and release the bound port on every exit
            # path. linger=0 so shutdown can never block on pending messages.
            worker_alive.value = 0
            ctx.destroy(linger=0)

        log("recv_worker exit", important=True)
    def recv_worker(address, recv_queue, worker_alive):
        """Pull messages from a remote endpoint and put them on ``recv_queue``.

        Intended as a ``multiprocessing.Process`` target. A message that cannot
        be enqueued within ``PUT_TIMEOUT`` is dropped with a log line.

        Args:
            address: ZMQ endpoint the PULL socket connects to.
            recv_queue: destination queue for received messages.
            worker_alive: shared integer flag; the loop runs while it is
                truthy, and it is cleared on every exit path.
        """
        timing = AccumDict()
        log = Logger('./var/log/recv_worker.log')

        ctx = SerializingContext()
        receiver = ctx.socket(zmq.PULL)
        receiver.connect(address)
        # Bounded receive so worker_alive is re-checked regularly.
        receiver.RCVTIMEO = RECV_TIMEOUT

        log(f"Receiving from {address}")

        try:
            while worker_alive.value:
                tt = TicToc()

                try:
                    tt.tic()
                    msg = receiver.recv_data()
                    timing.add('RECV', tt.toc())
                except zmq.error.Again:
                    continue

                try:
                    recv_queue.put(msg, timeout=PUT_TIMEOUT)
                except queue.Full:
                    log('recv_queue full')
                    continue

                if opt.verbose:
                    Once(timing, log, per=1)
        except KeyboardInterrupt:
            log("recv_worker: user interrupt")
        finally:
            # Signal the other workers and release the socket/context on every
            # exit path, including unexpected exceptions escaping the loop.
            worker_alive.value = 0
            receiver.disconnect(address)
            receiver.close()
            ctx.destroy()

        log("recv_worker exit")
예제 #6
0
    def __init__(self, *args, in_addr=None, out_addr=None, **kwargs):
        """Prepare queues and the (not yet started) sender/receiver processes.

        Args:
            *args, **kwargs: stored in ``predictor_args`` for the remote
                predictor's initialisation.
            in_addr: address passed to ``send_worker``.
            out_addr: address passed to ``recv_worker``.
        """
        self.in_addr, self.out_addr = in_addr, out_addr
        self.predictor_args = (args, kwargs)
        self.timing = AccumDict()
        self.log = Tee('/tmp/predictor_remote.log')

        self.send_queue = mp.Queue(QUEUE_SIZE)
        self.recv_queue = mp.Queue(QUEUE_SIZE)

        # Shared run flag: 0 until started, reset to 0 by any worker on exit.
        alive_flag = mp.Value('i', 0)
        self.worker_alive = alive_flag

        self.send_process = mp.Process(
            target=self.send_worker,
            args=(self.in_addr, self.send_queue, alive_flag),
            name="send_process",
        )
        self.recv_process = mp.Process(
            target=self.recv_worker,
            args=(self.out_addr, self.recv_queue, alive_flag),
            name="recv_process",
        )

        # Incremented before each request, so the first message id is 0.
        self._i_msg = -1
예제 #7
0
    def send_worker(port, send_queue, worker_alive):
        """Send queued ``(meta, data)`` replies out over a bound PUSH socket.

        A non-critical message is skipped whenever newer messages are already
        queued, so only the most recent one is sent. A critical message
        (``meta['critical']`` truthy) is never skipped.

        Args:
            port: TCP port bound on all interfaces.
            send_queue: queue of ``(meta, data)`` pairs to send.
            worker_alive: shared integer flag; the loop runs while it is
                truthy, and it is cleared on every exit path.
        """
        timing = AccumDict()
        log = Logger('./var/log/send_worker.log', verbose=opt.verbose)

        ctx = SerializingContext()
        socket = ctx.socket(zmq.PUSH)
        socket.bind(f"tcp://*:{port}")

        log(f'Sending on port {port}', important=True)

        try:
            while worker_alive.value:
                tt = TicToc()

                try:
                    method, data = send_queue.get(timeout=GET_TIMEOUT)
                except queue.Empty:
                    log("send queue empty")
                    continue

                # get the latest non-critical request from the queue
                # don't skip critical request
                while not send_queue.empty() and not method['critical']:
                    log(f"skip {method}")
                    method, data = send_queue.get()

                log("sending", method)

                tt.tic()
                socket.send_data(method, data)
                timing.add('SEND', tt.toc())

                Once(timing, log, per=1)
        except KeyboardInterrupt:
            log("send_worker: user interrupt", important=True)
        finally:
            # Stop sibling workers and release the bound port on every exit
            # path. linger=0 so shutdown cannot hang on undelivered messages.
            worker_alive.value = 0
            ctx.destroy(linger=0)

        log("send_worker exit", important=True)
class PredictorRemote:
    """Client-side proxy that forwards predictor calls to a remote worker.

    Requests go through ``send_queue`` to a ``send_worker`` process (ZMQ PUSH
    to ``in_addr``). Replies are pulled from ``out_addr`` by a ``recv_worker``
    process into ``recv_queue``. Any attribute not defined on this class
    becomes a remote call:

    * ``predict`` is non-critical. The request may be dropped when the queue
      is full, and the reply returned is simply the next one in
      ``recv_queue``, which may belong to an earlier request.
    * Every other method is critical and waits for the reply carrying its
      own metadata.
    """

    def __init__(self, *args, in_addr=None, out_addr=None, **kwargs):
        """Create queues and the (not yet started) worker processes.

        Args:
            *args, **kwargs: sent to the remote side as the ``__init__`` call
                by ``init_remote_worker()``.
            in_addr: endpoint ``send_worker`` connects its PUSH socket to.
            out_addr: endpoint ``recv_worker`` connects its PULL socket to.
        """
        self.in_addr = in_addr
        self.out_addr = out_addr
        self.predictor_args = (args, kwargs)
        self.timing = AccumDict()
        self.log = Logger('./var/log/predictor_remote.log',
                          verbose=opt.verbose)

        self.send_queue = mp.Queue(QUEUE_SIZE)
        self.recv_queue = mp.Queue(QUEUE_SIZE)

        # Shared run flag: set to 1 by start(), reset to 0 by stop() or by
        # any worker on exit.
        self.worker_alive = mp.Value('i', 0)

        self.send_process = mp.Process(target=self.send_worker,
                                       args=(self.in_addr, self.send_queue,
                                             self.worker_alive),
                                       name="send_process")
        self.recv_process = mp.Process(target=self.recv_worker,
                                       args=(self.out_addr, self.recv_queue,
                                             self.worker_alive),
                                       name="recv_process")

        # Incremented before each request, so the first message id is 0.
        self._i_msg = -1

    def start(self):
        """Start both worker processes and initialise the remote predictor."""
        self.worker_alive.value = 1
        self.send_process.start()
        self.recv_process.start()

        self.init_remote_worker()

    def stop(self):
        """Ask the workers to exit, wait up to 5 s each, then kill stragglers."""
        self.worker_alive.value = 0
        self.log("join worker processes...")
        processes = (self.send_process, self.recv_process)
        for proc in processes:
            proc.join(timeout=5)
        for proc in processes:
            if proc.is_alive():
                proc.terminate()
                # Reap the killed process so it does not linger as a zombie.
                proc.join(timeout=5)

    def init_remote_worker(self):
        """Send the stored constructor arguments as a critical ``__init__`` call."""
        return self._send_recv_async('__init__',
                                     self.predictor_args,
                                     critical=True)

    def __getattr__(self, attr):
        """Return a stub that performs ``attr`` as a remote call.

        Dunder names raise ``AttributeError`` instead. Protocol lookups
        (pickle, copy, numpy, ...) then see the normal "not defined" answer,
        not a callable that would send a request to the worker.
        """
        if attr.startswith('__') and attr.endswith('__'):
            raise AttributeError(attr)
        is_critical = attr != 'predict'
        return lambda *args, **kwargs: self._send_recv_async(
            attr, (args, kwargs), critical=is_critical)

    def _send_recv_async(self, method, args, critical):
        """Encode and send one request, then wait for / pick up a reply.

        Args:
            method: remote method name; ``'predict'`` payloads are JPEG-encoded
                images, everything else is msgpack ``(args, kwargs)``.
            args: ``(args, kwargs)`` pair for the remote call.
            critical: if true, block until the reply with identical metadata
                arrives; otherwise send best-effort and take whatever reply
                is available within ``GET_TIMEOUT``.

        Returns:
            The decoded reply, or ``None`` when a non-critical call gets no
            reply in time.

        Raises:
            ConnectionError: if the worker processes stop while a critical
                call is still waiting for its reply.
        """
        self._i_msg += 1

        args, kwargs = args

        tt = TicToc()
        tt.tic()
        if method == 'predict':
            image = args[0]
            assert isinstance(image, np.ndarray), 'Expected image'
            _ok, data = cv2.imencode(
                ".jpg", image,
                [int(cv2.IMWRITE_JPEG_QUALITY), opt.jpg_quality])
        else:
            data = msgpack.packb((args, kwargs))
        self.timing.add('PACK', tt.toc())

        meta = {'name': method, 'critical': critical, 'id': self._i_msg}

        self.log("send", meta)

        if critical:
            self.send_queue.put((meta, data))

            # Discard replies to earlier requests until ours arrives. Poll
            # with a timeout so dead workers raise an error instead of
            # blocking this call forever.
            while True:
                try:
                    meta_recv, data_recv = self.recv_queue.get(timeout=GET_TIMEOUT)
                except queue.Empty:
                    if not self.worker_alive.value:
                        raise ConnectionError(
                            f"Worker processes stopped while waiting for reply to {meta}")
                    continue
                if meta_recv == meta:
                    break
        else:
            try:
                # TODO: find good timeout
                self.send_queue.put((meta, data), timeout=PUT_TIMEOUT)
            except queue.Full:
                self.log('send_queue is full')

            try:
                meta_recv, data_recv = self.recv_queue.get(timeout=GET_TIMEOUT)
            except queue.Empty:
                self.log('recv_queue is empty')
                return None

        self.log("recv", meta_recv)

        tt.tic()
        if meta_recv['name'] == 'predict':
            result = cv2.imdecode(np.frombuffer(data_recv, dtype='uint8'), -1)
        else:
            result = msgpack.unpackb(data_recv)
        self.timing.add('UNPACK', tt.toc())

        if opt.verbose:
            Once(self.timing, per=1)

        return result

    @staticmethod
    def send_worker(address, send_queue, worker_alive):
        """Process target: forward queued messages to ``address`` over ZMQ PUSH.

        Loops while ``worker_alive.value`` is truthy and clears it on every
        exit path.
        """
        timing = AccumDict()
        log = Logger('./var/log/send_worker.log', verbose=opt.verbose)

        ctx = SerializingContext()
        sender = ctx.socket(zmq.PUSH)
        sender.connect(address)

        log(f"Sending to {address}")

        try:
            while worker_alive.value:
                tt = TicToc()

                try:
                    msg = send_queue.get(timeout=GET_TIMEOUT)
                except queue.Empty:
                    continue

                tt.tic()
                sender.send_data(*msg)
                timing.add('SEND', tt.toc())

                if opt.verbose:
                    Once(timing, log, per=1)
        except KeyboardInterrupt:
            log("send_worker: user interrupt")
        finally:
            # Release the socket/context even if an unexpected exception
            # escapes the loop.
            worker_alive.value = 0
            sender.disconnect(address)
            sender.close()
            ctx.destroy()

        log("send_worker exit")

    @staticmethod
    def recv_worker(address, recv_queue, worker_alive):
        """Process target: pull replies from ``address`` into ``recv_queue``.

        A reply that cannot be enqueued within ``PUT_TIMEOUT`` is dropped.
        Loops while ``worker_alive.value`` is truthy and clears it on every
        exit path.
        """
        timing = AccumDict()
        log = Logger('./var/log/recv_worker.log')

        ctx = SerializingContext()
        receiver = ctx.socket(zmq.PULL)
        receiver.connect(address)
        receiver.RCVTIMEO = RECV_TIMEOUT

        log(f"Receiving from {address}")

        try:
            while worker_alive.value:
                tt = TicToc()

                try:
                    tt.tic()
                    msg = receiver.recv_data()
                    timing.add('RECV', tt.toc())
                except zmq.error.Again:
                    continue

                try:
                    recv_queue.put(msg, timeout=PUT_TIMEOUT)
                except queue.Full:
                    log('recv_queue full')
                    continue

                if opt.verbose:
                    Once(timing, log, per=1)
        except KeyboardInterrupt:
            log("recv_worker: user interrupt")
        finally:
            # Release the socket/context even if an unexpected exception
            # escapes the loop.
            worker_alive.value = 0
            receiver.disconnect(address)
            receiver.close()
            ctx.destroy()

        log("recv_worker exit")
예제 #9
0
def message_handler(port):
    """Serve predictor requests arriving on a bound ZMQ PAIR socket.

    Each request is ``(attr, data)``:

    * For ``attr == 'predict'`` the payload is an encoded image, decoded with
      ``cv2.imdecode``. For anything else it is a msgpack ``(args, kwargs)``
      pair.
    * ``'hello'`` is answered with ``"OK"``; this is the connection probe sent
      by clients that use ``check_connection``.
    * ``'__init__'`` (re)creates the ``PredictorLocal``.
    * Every other name is called on the predictor.

    The reply is ``(attr, payload)``, encoded the same way. Runs until
    Ctrl-C. The socket and context are released on any exit.

    Args:
        port: TCP port bound on all interfaces.

    Raises:
        RuntimeError: if a predictor method is requested before ``__init__``.
    """
    log("Creating socket")
    context = SerializingContext()
    socket = context.socket(zmq.PAIR)
    socket.bind("tcp://*:%s" % port)
    log("Listening for messages on port:", port)

    predictor = None
    predictor_args = ()
    timing = AccumDict()

    try:
        while True:
            tt = TicToc()

            tt.tic()
            attr, data = socket.recv_data()
            timing.add('RECV', tt.toc())

            try:
                tt.tic()
                if attr == 'predict':
                    image = cv2.imdecode(np.frombuffer(data, dtype='uint8'),
                                         -1)
                else:
                    args = msgpack.unpackb(data)
                timing.add('UNPACK', tt.toc())
            except ValueError:
                # NOTE(review): no reply is sent here, so a client blocked in
                # recv on the PAIR socket keeps waiting (up to its RCVTIMEO).
                log("Invalid Message")
                continue

            tt.tic()
            if attr == "hello":
                result = "OK"
            elif attr == "__init__":
                if args == predictor_args:
                    log("Same config as before... reusing previous predictor")
                else:
                    del predictor
                    predictor_args = args
                    predictor = PredictorLocal(*predictor_args[0],
                                               **predictor_args[1])
                    log("Initialized predictor with:", predictor_args)
                result = True
                tt.tic()  # don't account for init
            else:
                if predictor is None:
                    raise RuntimeError(
                        f"Predictor was not initialized (got {attr!r})")
                if attr == 'predict':
                    result = predictor.predict(image)
                else:
                    result = getattr(predictor, attr)(*args[0], **args[1])
            timing.add('CALL', tt.toc())

            tt.tic()
            if attr == 'predict':
                assert isinstance(result, np.ndarray), 'Expected image'
                _ok, data_send = cv2.imencode(
                    ".jpg", result,
                    [int(cv2.IMWRITE_JPEG_QUALITY), opt.jpg_quality])
            else:
                data_send = msgpack.packb(result)
            timing.add('PACK', tt.toc())

            tt.tic()
            socket.send_data(attr, data_send)
            timing.add('SEND', tt.toc())

            Once(timing, per=1)
    except KeyboardInterrupt:
        pass
    finally:
        # Closes the socket too; linger=0 so shutdown never blocks on
        # undelivered replies.
        context.destroy(linger=0)
예제 #10
0
class PredictorRemote:
    def __init__(self, *args, worker_host='localhost', worker_port=DEFAULT_PORT, **kwargs):
        self.worker_host = worker_host
        self.worker_port = worker_port
        self.predictor_args = (args, kwargs)
        self.timing = AccumDict()

        self.address = f"tcp://{worker_host}:{worker_port}"
        self.context = SerializingContext()
        self.socket = self.context.socket(zmq.PAIR)
        self.socket.connect(self.address)

        if not self.check_connection():
            self.socket.disconnect(self.address)
            # TODO: this hangs, as well as context.__del__
            # self.context.destroy()
            raise ConnectionError(f"Could not connect to {worker_host}:{worker_port}")

        log(f"Connected to {self.address}")

        self.init_worker()

    def check_connection(self, timeout=1000):
        msg = (
            'hello',
            [], {}
        )

        try:
            old_rcvtimeo = self.socket.RCVTIMEO
            self.socket.RCVTIMEO = timeout
            response = self._send_recv_msg(msg)
            self.socket.RCVTIMEO = old_rcvtimeo
        except zmq.error.Again:
            return False

        return response == 'OK'

    def init_worker(self):
        msg = (
            '__init__',
            *self.predictor_args,
        )
        return self._send_recv_msg(msg)

    def __getattr__(self, attr):
        return lambda *args, **kwargs: self._send_recv_msg((attr, args, kwargs))

    def _send_recv_msg(self, msg):
        attr, args, kwargs = msg

        tt = TicToc()
        tt.tic()
        if attr == 'predict':
            image = args[0]
            assert isinstance(image, np.ndarray), 'Expected image'
            ret_code, data = cv2.imencode(".jpg", image, [int(cv2.IMWRITE_JPEG_QUALITY), opt.jpg_quality])
        else:
            data = msgpack.packb((args, kwargs))
        self.timing.add('PACK', tt.toc())

        tt.tic()
        self.socket.send_data(attr, data)
        self.timing.add('SEND', tt.toc())

        tt.tic()
        attr_recv, data_recv = self.socket.recv_data()
        self.timing.add('RECV', tt.toc())

        tt.tic()
        if attr_recv == 'predict':
            result = cv2.imdecode(np.frombuffer(data_recv, dtype='uint8'), -1)
        else:
            result = msgpack.unpackb(data_recv)
        self.timing.add('UNPACK', tt.toc())

        Once(self.timing, per=1)

        return result
예제 #11
0
class PredictorRemote:
    def __init__(self,
                 *args,
                 worker_host='localhost',
                 worker_port=DEFAULT_PORT,
                 **kwargs):
        self.worker_host = worker_host
        self.worker_port = worker_port
        self.predictor_args = (args, kwargs)

        self.context = SerializingContext()
        self.socket = self.context.socket(zmq.PAIR)
        self.socket.connect(f"tcp://{worker_host}:{worker_port}")
        log(f"Connected to {worker_host}:{worker_port}")

        self.timing = AccumDict()

        self.init_worker()

    def init_worker(self):
        msg = (
            '__init__',
            *self.predictor_args,
        )
        return self._send_recv_msg(msg)

    def __getattr__(self, attr):
        return lambda *args, **kwargs: self._send_recv_msg(
            (attr, args, kwargs))

    def _send_recv_msg(self, msg):
        attr, args, kwargs = msg

        tt = TicToc()
        tt.tic()
        if attr == 'predict':
            image = args[0]
            assert isinstance(image, np.ndarray), 'Expected image'
            ret_code, data = cv2.imencode(
                ".jpg", image,
                [int(cv2.IMWRITE_JPEG_QUALITY), opt.jpg_quality])
        else:
            data = msgpack.packb((args, kwargs))
        self.timing.add('PACK', tt.toc())

        tt.tic()
        self.socket.send_data(attr, data)
        self.timing.add('SEND', tt.toc())

        tt.tic()
        attr_recv, data_recv = self.socket.recv_data()
        self.timing.add('RECV', tt.toc())

        tt.tic()
        if attr_recv == 'predict':
            result = cv2.imdecode(np.frombuffer(data_recv, dtype='uint8'), -1)
        else:
            result = msgpack.unpackb(data_recv)
        self.timing.add('UNPACK', tt.toc())

        Once(self.timing, per=1)

        return result
예제 #12
0
    def predictor_worker(recv_queue, send_queue, worker_alive):
        """Execute queued predictor requests and queue the encoded replies.

        Each request is ``(meta, payload)``, where ``meta`` carries at least
        ``'name'`` and ``'critical'``. ``predict`` payloads are encoded images
        (``cv2.imdecode``); all others are msgpack ``(args, kwargs)`` pairs.

        * While more requests are waiting, a non-critical request is replaced
          by the next one, so stale frames are skipped.
        * ``hello`` answers ``"OK"``.
        * ``__init__`` (re)builds a ``PredictorLocal`` unless the arguments
          are unchanged.
        * Any other name is called on the predictor.

        Replies to critical requests are queued with a blocking put;
        non-critical replies are dropped if ``send_queue`` is full. The loop
        runs while ``worker_alive.value`` is truthy. Any exception ends it
        with a traceback, and the flag is cleared on the way out.
        """
        predictor = None
        predictor_args = ()
        timing = AccumDict()
        log = Logger('./var/log/predictor_worker.log', verbose=opt.verbose)

        try:
            while worker_alive.value:
                stopwatch = TicToc()

                try:
                    meta, payload = recv_queue.get(timeout=GET_TIMEOUT)
                except queue.Empty:
                    continue

                # Jump to the newest request unless the current one is critical.
                while not recv_queue.empty() and not meta['critical']:
                    log(f"skip {meta}")
                    meta, payload = recv_queue.get()

                log("working on", meta)

                try:
                    stopwatch.tic()
                    if meta['name'] == 'predict':
                        frame = cv2.imdecode(np.frombuffer(payload, dtype='uint8'), -1)
                    else:
                        call_args = msgpack.unpackb(payload)
                    timing.add('UNPACK', stopwatch.toc())
                except ValueError:
                    log("Invalid Message", important=True)
                    continue

                name = meta['name']
                stopwatch.tic()
                if name == "hello":
                    result = "OK"
                elif name == "__init__":
                    if call_args == predictor_args:
                        log("Same config as before... reusing previous predictor")
                    else:
                        del predictor
                        predictor_args = call_args
                        predictor = PredictorLocal(*predictor_args[0], **predictor_args[1])
                        log("Initialized predictor with:", predictor_args, important=True)
                    result = True
                    stopwatch.tic()  # keep model construction out of CALL timing
                else:
                    assert predictor is not None, "Predictor was not initialized"
                    if name == 'predict':
                        result = predictor.predict(frame)
                    else:
                        result = getattr(predictor, name)(*call_args[0], **call_args[1])
                timing.add('CALL', stopwatch.toc())

                stopwatch.tic()
                if name != 'predict':
                    reply = msgpack.packb(result)
                else:
                    assert isinstance(result, np.ndarray), f'Expected np.ndarray, got {result.__class__}'
                    _, reply = cv2.imencode(".jpg", result, [int(cv2.IMWRITE_JPEG_QUALITY), opt.jpg_quality])
                timing.add('PACK', stopwatch.toc())

                outgoing = (meta, reply)
                if not meta['critical']:
                    try:
                        send_queue.put(outgoing, block=False)
                    except queue.Full:
                        log("send_queue full")
                else:
                    send_queue.put(outgoing)

                Once(timing, log, per=1)
        except KeyboardInterrupt:
            log("predictor_worker: user interrupt", important=True)
        except Exception:
            log("predictor_worker error", important=True)
            traceback.print_exc()

        worker_alive.value = 0
        log("predictor_worker exit", important=True)
예제 #13
0
    def predictor_worker(recv_queue, send_queue, worker_alive):
        """Serve panoptic-segmentation requests from ``recv_queue``.

        Each request is ``(meta, payload)``, where ``meta`` carries at least
        ``'name'`` and ``'critical'``. ``predict`` payloads are encoded images
        (``cv2.imdecode``); all others are msgpack ``(args, kwargs)`` pairs.

        * ``hello`` answers ``"OK"``.
        * ``__init__`` builds a detectron2 ``DefaultPredictor`` for a fixed
          panoptic-FPN model. The received arguments are only compared with
          the previous call, so an identical ``__init__`` reuses the model.
        * ``predict`` returns the panoptic visualisation as a JPEG.
        * Any other method is unsupported and answered with ``None``, so a
          waiting caller is not left hanging.

        While more requests are waiting, a non-critical request is replaced
        by the next one. Replies to critical requests are queued with a
        blocking put; non-critical replies are dropped if ``send_queue`` is
        full. The loop runs while ``worker_alive.value`` is truthy. Any
        exception ends it with a traceback, and the flag is cleared on exit.
        """
        predictor = None
        predictor_args = ()
        cfg = None
        timing = AccumDict()
        log = Logger('./var/log/predictor_worker.log', verbose=opt.verbose)

        model_config = "COCO-PanopticSegmentation/panoptic_fpn_R_101_3x.yaml"

        def _build_predictor():
            """Return ``(cfg, DefaultPredictor(cfg))`` for the fixed model config."""
            new_cfg = get_cfg()
            new_cfg.merge_from_file(model_zoo.get_config_file(model_config))
            new_cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url(model_config)
            new_cfg.MODEL.RETINANET.SCORE_THRESH_TEST = 0.6
            new_cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.6
            new_cfg.MODEL.PANOPTIC_FPN.COMBINE.INSTANCES_CONFIDENCE_THRESH = 0.6
            return new_cfg, DefaultPredictor(new_cfg)

        def _panoptic_overlay(model, model_cfg, img):
            """Run ``model`` on a cv2-decoded image and return the drawn overlay.

            Channels are flipped for the Visualizer and flipped back on the
            result, so the output uses the same channel order as ``img``.
            """
            panoptic_seg, segments_info = model(img)["panoptic_seg"]
            v = Visualizer(img[:, :, ::-1],
                           MetadataCatalog.get(model_cfg.DATASETS.TRAIN[0]),
                           scale=1.2)
            out = v.draw_panoptic_seg_predictions(
                panoptic_seg.to("cpu"), segments_info)
            return out.get_image()[:, :, ::-1]

        try:
            while worker_alive.value:
                tt = TicToc()

                try:
                    method, data = recv_queue.get(timeout=GET_TIMEOUT)
                except queue.Empty:
                    continue

                # get the latest non-critical request from the queue
                # don't skip critical request
                while not recv_queue.empty() and not method['critical']:
                    log(f"skip {method}")
                    method, data = recv_queue.get()

                log("working on", method)

                try:
                    tt.tic()
                    if method['name'] == 'predict':
                        image = cv2.imdecode(
                            np.frombuffer(data, dtype='uint8'), -1)
                    else:
                        args = msgpack.unpackb(data)
                    timing.add('UNPACK', tt.toc())
                except ValueError:
                    log("Invalid Message", important=True)
                    continue

                tt.tic()
                if method['name'] == "hello":
                    result = "OK"
                elif method['name'] == "__init__":
                    if args == predictor_args:
                        log("Same config as before... reusing previous predictor")
                    else:
                        # Drop the old model before loading a new one to keep
                        # peak memory down.
                        predictor = None
                        cfg, predictor = _build_predictor()
                        # Remember the arguments so an identical __init__ reuses the model.
                        predictor_args = args
                        log("Initialized predictor with:", predictor_args, important=True)
                    result = True
                    tt.tic()  # don't account for init
                elif method['name'] == 'predict':
                    assert predictor is not None, "Predictor was not initialized"
                    result = _panoptic_overlay(predictor, cfg, image)
                else:
                    # Predicting here would use a stale or undefined image, and
                    # an ndarray result cannot be msgpack-encoded. Reply None.
                    log(f"Unsupported method {method['name']!r}, replying None",
                        important=True)
                    result = None
                timing.add('CALL', tt.toc())

                tt.tic()
                if method['name'] == 'predict':
                    assert isinstance(
                        result, np.ndarray
                    ), f'Expected np.ndarray, got {result.__class__}'
                    _ok, data_send = cv2.imencode(
                        ".jpg", result,
                        [int(cv2.IMWRITE_JPEG_QUALITY), opt.jpg_quality])
                else:
                    data_send = msgpack.packb(result)
                timing.add('PACK', tt.toc())

                if method['critical']:
                    send_queue.put((method, data_send))
                else:
                    try:
                        send_queue.put((method, data_send), block=False)
                    except queue.Full:
                        log("send_queue full")

                Once(timing, log, per=1)
        except KeyboardInterrupt:
            log("predictor_worker: user interrupt", important=True)
        except Exception:
            log("predictor_worker error", important=True)
            traceback.print_exc()

        worker_alive.value = 0
        log("predictor_worker exit", important=True)