Ejemplo n.º 1
0
def main():
    """Entry point of the pipeline service.

    Builds the MQTT subscription table (topic -> handler name) from three
    sources, in increasing precedence:

    1. the optional JSON file given by ``topic_config``;
    2. the two built-in control topics (mode switch and deploy);
    3. the ``(topic, handler)`` pairs given by ``topic``.

    It then creates either a dummy or a real pipeline engine and runs a
    ``PipelineService`` with them.
    """
    args = parse_args()
    logger.setLevel(logging.DEBUG if args['debug'] else logging.INFO)

    config_path = args['topic_config']
    if not config_path:
        subscriptions = {}
    else:
        with open(config_path) as config_file:
            subscriptions = json.load(config_file)

    # Control topics are always routed to the service's own methods.
    subscriptions['berrynet/data/mode'] = 'self.switch_mode'
    subscriptions['berrynet/data/deploy'] = 'self.deploy'

    extra_topics = args['topic']
    if extra_topics is not None:
        for topic_name, handler_name in extra_topics:
            subscriptions[topic_name] = handler_name

    width, height = args['warmup_size']
    # The engine and service receive the warm-up size as (height, width, 3).
    warmup_shape = (height, width, 3)

    # Setup pipeline service
    if args['disable_engine']:
        engine = PipelineDummyEngine()
    else:
        engine = PipelineEngine(
            args['pipeline_config'],
            disable_warmup=args['disable_warmup'],
            verbosity=args['verbosity'],
            benchmark=args['benchmark'],
            warmup_size=warmup_shape)

    broker = {'address': args['broker_ip'], 'port': 1883}
    comm_config = {'subscribe': subscriptions, 'broker': broker}

    service = PipelineService(
        'pipeline service',
        engine,
        comm_config,
        pid=args['pipeline_id'],
        pipeline_config_path=args['pipeline_config'],
        disable_engine=args['disable_engine'],
        disable_warmup=args['disable_warmup'],
        warmup_size=warmup_shape)
    service.run(args)
Ejemplo n.º 2
0
def main():
    """Run the mock-up engine service against a broker on localhost:1883."""
    args = parse_args()
    level = logging.DEBUG if args['debug'] else logging.INFO
    logger.setLevel(level)

    broker = dict(address='localhost', port=1883)
    comm_config = dict(subscribe={}, broker=broker)
    service = MockupService('mockup service', MockupEngine(), comm_config)
    service.run(args)
Ejemplo n.º 3
0
def main():
    """Run one OpenVINO inference on a single image and log the result.

    The engine type comes from ``args.engine`` (``classifier`` or
    ``detector``). The image at ``args.input`` is read with OpenCV, then
    pre-processed, inferred and post-processed. The model output and the
    elapsed time are logged at DEBUG level.

    Raises:
        Exception: if ``args.engine`` is neither ``classifier`` nor
            ``detector``.
        IOError: if ``args.input`` cannot be read as an image.
    """
    args = parse_argsr()

    if args.debug:
        logger.setLevel(logging.DEBUG)
    else:
        logger.setLevel(logging.INFO)

    if args.engine == 'classifier':
        engine = OpenVINOClassifierEngine(
                     model = args.model,
                     device = args.device,
                     labels = args.labels,
                     top_k = args.number_top)
    elif args.engine == 'detector':
        engine = OpenVINODetectorEngine(
                     model = args.model,
                     device = args.device,
                     labels = args.labels)
    else:
        raise Exception('Illegal engine {}, it should be '
                        'classifier or detector'.format(args.engine))

    #set_openvino_environment()
    #if args.debug:
    #    logger.debug('OpenVINO environment vars')
    #    target_vars = ['INSTALLDIR',
    #                   'INTEL_CVSDK_DIR',
    #                   'LD_LIBRARY_PATH',
    #                   'InferenceEngine_DIR',
    #                   'IE_PLUGINS_PATH',
    #                   'PATH',
    #                   'PYTHONPATH']
    #    for i in target_vars:
    #        logger.debug('\t{var}: {val}'.format(
    #            var = i,
    #            val = os.environ.get(i)))

    bgr_array = cv2.imread(args.input)
    # cv2.imread does not raise on a missing or unreadable file. It returns
    # None, which would otherwise fail obscurely inside process_input.
    if bgr_array is None:
        raise IOError('Can not read image file {}'.format(args.input))
    t = time()
    image_data = engine.process_input(bgr_array)
    output = engine.inference(image_data)
    model_outputs = engine.process_output(output)
    logger.debug('Result: {}'.format(model_outputs))
    # Timing covers pre-process + inference + post-process for either engine.
    logger.debug('Inference takes {} s'.format(time() - t))
Ejemplo n.º 4
0
def main():
    """Run a TFLite classifier or detector as an MQTT engine service.

    If ``model_package`` is non-empty, the model and label file paths are
    taken from that package's metadata (via ``DLModelManager``). They
    overwrite the ``model`` and ``label`` CLI values before use.

    Raises:
        Exception: if ``service`` is neither ``classifier`` nor ``detector``.
    """
    # Test TFLite engines
    args = parse_args()
    logger.setLevel(logging.DEBUG if args['debug'] else logging.INFO)

    package = args['model_package']
    if package != '':
        meta = DLModelManager().get_model_meta(package)
        args['model'] = meta['model']
        args['label'] = meta['label']
    logger.debug('model filepath: ' + args['model'])
    logger.debug('label filepath: ' + args['label'])

    comm_config = {
        'subscribe': {},
        'broker': {'address': 'localhost', 'port': 1883},
    }

    kind = args['service']
    if kind == 'classifier':
        service_cls = TFLiteClassifierService
        engine = TFLiteClassifierEngine(
            model=args['model'],
            labels=args['label'],
            top_k=args['top_k'],
            num_threads=args['num_threads'])
    elif kind == 'detector':
        service_cls = TFLiteDetectorService
        engine = TFLiteDetectorEngine(
            model=args['model'],
            labels=args['label'],
            num_threads=args['num_threads'])
    else:
        raise Exception('Illegal service {}, it should be '
                        'classifier or detector'.format(kind))

    service = service_cls(args['service_name'], engine, comm_config,
                          draw=args['draw'])
    service.run(args)
Ejemplo n.º 5
0
def main():
    """Run the pipeline-restarter service against a broker on localhost."""
    args = parse_args()
    logger.setLevel(logging.DEBUG if args['debug'] else logging.INFO)

    comm_config = dict(
        subscribe={},
        broker=dict(address='localhost', port=1883),
    )
    restarter = PipelineRestarterService('pipeline service restarter',
                                         comm_config)
    restarter.run(args)
Ejemplo n.º 6
0
def main():
    """Run the Dyda config-update client.

    The communicator config subscribes to the single CLI ``topic`` (with no
    handler name). It passes the CLI ``publish`` value through unchanged and
    points at the broker given by ``broker_ip`` and ``broker_port``.
    """
    args = parse_args()
    debug = args['debug']
    logger.setLevel(logging.DEBUG if debug else logging.INFO)

    broker = {'address': args['broker_ip'], 'port': args['broker_port']}
    comm_config = {
        'subscribe': {args['topic']: None},
        'publish': args['publish'],
        'broker': broker,
    }
    client = DydaConfigUpdateClient(comm_config, debug)
    client.run(args)
Ejemplo n.º 7
0
def main():
    """Run the framebuffer dashboard service, then its GLUT display loop.

    Topics come from two sources: an optional JSON ``topic_config`` file
    and the CLI ``topic`` list. Every CLI topic is mapped to the CLI
    ``topic_action`` handler, overriding a same-named entry from the file.
    """
    args = parse_args()
    logger.setLevel(logging.DEBUG if args['debug'] else logging.INFO)

    topic_config = {}
    if args['topic_config']:
        with open(args['topic_config']) as config_file:
            topic_config = json.load(config_file)
    for topic in args['topic']:
        topic_config[topic] = args['topic_action']

    comm_config = {
        'subscribe': topic_config,
        'broker': {'address': args['broker_ip'], 'port': 1883},
    }
    dashboard = FBDashboardService(comm_config,
                                   args['data_dirpath'],
                                   args['no_decoration'],
                                   args['debug'],
                                   args['debug_save_frame'])
    dashboard.run(args)

    # GLUT window setup; the call order is intentionally left unchanged.
    glutInitWindowPosition(0, 0)
    glutInit(sys.argv)
    glutInitDisplayMode(GLUT_RGBA | GLUT_DOUBLE)
    glutCreateWindow("BerryNet Result Dashboard, q to quit")
    glutDisplayFunc(dashboard.update_fb)
    glutKeyboardFunc(keyboard)
    init()
    glutIdleFunc(idle)
    if not args['no_full_screen']:
        glutFullScreen()
    glutMainLoop()
Ejemplo n.º 8
0
def main():
    """Run the Dyda config-update service.

    The communicator config is built from the CLI. It subscribes to one
    topic (with no handler name) and carries the ``publish`` value, the
    broker address/port, ``configfile`` and ``idlist``. All of it is passed
    to ``DydaConfigUpdateService``.
    """
    args = parse_args()
    debug = args['debug']
    logger.setLevel(logging.DEBUG if debug else logging.INFO)

    comm_config = dict(
        subscribe={args['topic']: None},
        publish=args['publish'],
        broker=dict(address=args['broker_ip'], port=args['broker_port']),
        configfile=args['configfile'],
        idlist=args['idlist'],
    )
    service = DydaConfigUpdateService(comm_config, debug)
    service.run(args)
Ejemplo n.º 9
0
def main():
    """Run the Darknet detector as an MQTT engine service.

    The config, weights and data-file paths are fixed to the
    ``tinyyolovoc-20170816`` package under ``/usr/share/dlmodels``. They are
    passed as bytes.
    """
    args = parse_args()
    logger.setLevel(logging.DEBUG if args['debug'] else logging.INFO)

    model_dir = b'/usr/share/dlmodels/tinyyolovoc-20170816/'
    engine = DarknetEngine(
        config=model_dir + b'tiny-yolo-voc.cfg',
        model=model_dir + b'tiny-yolo-voc.weights',
        meta=model_dir + b'voc.data')

    comm_config = dict(
        subscribe={},
        broker=dict(address='localhost', port=1883),
    )
    service = DarknetService(args['service_name'], engine, comm_config)
    service.run(args)
Ejemplo n.º 10
0
def main():
    """Publish camera frames or a single image to ``berrynet/data/rgbimage``.

    Modes (``args['mode']``):
        stream -- Read ``stream_src`` continuously. A digit string selects a
            local camera index; anything else is passed to OpenCV as-is.
            Each frame is rotated counter-clockwise. One frame out of every
            ``interval`` is published as a JPEG payload with a fixed ROI.
            This mode never returns.
        file -- Read ``filepath`` once and publish it with the same ROI.

    Any other mode is logged as an error.
    """
    import time  # only needed for the reconnect back-off below

    args = parse_args()
    if args['debug']:
        logger.setLevel(logging.DEBUG)
    else:
        logger.setLevel(logging.INFO)

    comm_config = {
        'subscribe': {},
        'broker': {
            'address': args['broker_ip'],
            'port': args['broker_port']
        }
    }
    comm = Communicator(comm_config, debug=True)

    def duration(t):
        """Return the milliseconds elapsed since datetime ``t``."""
        # timedelta.microseconds is only the sub-second component; use
        # total_seconds() so durations of 1 s or more are reported correctly.
        return (datetime.now() - t).total_seconds() * 1000

    def make_meta():
        """Return a fresh copy of the fixed ROI metadata sent with images."""
        return {
            'roi': [{
                'top': 50,
                'left': 10,
                'bottom': 600,
                'right': 600,
                'overlap_threshold': 0.5
            }]
        }

    if args['mode'] == 'stream':
        counter = 0
        # Check input stream source
        if args['stream_src'].isdigit():
            # source is a physically connected camera
            capture_source = int(args['stream_src'])
            stream_source = '/dev/video{}'.format(capture_source)
        else:
            # source is an IP camera
            capture_source = args['stream_src']
            stream_source = capture_source
        capture = cv2.VideoCapture(capture_source)
        cam_fps = capture.get(cv2.CAP_PROP_FPS)
        if cam_fps > 30 or cam_fps < 1:
            logger.warning(
                'Camera FPS is {} (>30 or <1). Set it to 30.'.format(cam_fps))
            cam_fps = 30
        out_fps = args['fps']
        # Publish one frame out of every `interval`. Clamp to 1: when out_fps
        # exceeds cam_fps, int() yields 0, `counter == interval` can never
        # hold, and nothing would ever be published.
        interval = max(1, int(cam_fps / out_fps))

        logger.debug('===== VideoCapture Information =====')
        logger.debug('Stream Source: {}'.format(stream_source))
        logger.debug('Camera FPS: {}'.format(cam_fps))
        logger.debug('Output FPS: {}'.format(out_fps))
        logger.debug('Interval: {}'.format(interval))
        logger.debug('====================================')

        while True:
            status, im = capture.read()
            if not status:
                # A failed read returns no frame (im is None), which would
                # crash the rotation and encoding below. Re-open the source
                # after a short back-off and try again.
                logger.warning('ERROR: Failure happened when reading frame')
                capture.release()
                time.sleep(5)
                capture = cv2.VideoCapture(capture_source)
                continue

            # NOTE: Hard-coding rotation for AIKEA onboard camera.
            #       We will add parameter support in the future.
            im = tinycv.rotate_ccw_opencv(im)

            counter += 1
            if counter != interval:
                continue
            logger.debug('Drop frames: {}'.format(counter - 1))
            counter = 0

            # Open a window and display the ready-to-send frame.
            # This is useful for development and debugging.
            if args['display']:
                cv2.imshow('Frame', im)
                cv2.waitKey(1)

            t = datetime.now()
            retval, jpg_bytes = cv2.imencode('.jpg', im)
            if not retval:
                logger.warning('Failed to encode frame as JPEG, skip it')
                continue
            obj = {
                'timestamp': datetime.now().isoformat(),
                'bytes': payload.stringify_jpg(jpg_bytes),
                'meta': make_meta(),
            }
            logger.debug('timestamp: {}'.format(obj['timestamp']))
            logger.debug('bytes len: {}'.format(len(obj['bytes'])))
            logger.debug('meta: {}'.format(obj['meta']))
            mqtt_payload = payload.serialize_payload([obj])
            comm.send('berrynet/data/rgbimage', mqtt_payload)
            logger.debug('send: {} ms'.format(duration(t)))
    elif args['mode'] == 'file':
        # Prepare MQTT payload
        im = cv2.imread(args['filepath'])
        if im is None:
            # cv2.imread signals an unreadable file by returning None.
            logger.error(
                'Can not read image file {}'.format(args['filepath']))
            return
        retval, jpg_bytes = cv2.imencode('.jpg', im)

        t = datetime.now()
        obj = {
            'timestamp': datetime.now().isoformat(),
            'bytes': payload.stringify_jpg(jpg_bytes),
            'meta': make_meta(),
        }
        mqtt_payload = payload.serialize_payload([obj])
        logger.debug('payload: {} ms'.format(duration(t)))
        logger.debug('payload size: {}'.format(len(mqtt_payload)))

        # Client publishes payload
        t = datetime.now()
        comm.send('berrynet/data/rgbimage', mqtt_payload)
        logger.debug('mqtt.publish: {} ms'.format(duration(t)))
        logger.debug('publish at {}'.format(datetime.now().isoformat()))
    else:
        logger.error('User assigned unknown mode {}'.format(args['mode']))
Ejemplo n.º 11
0
def main():
    """Publish camera frames or a single image as JPEG payloads via MQTT.

    Payloads are built with ``payload.serialize_jpg`` from the JPEG bytes,
    the CLI ``hash`` value and the optional JSON ``meta`` metadata. They are
    sent on the CLI ``topic``.

    Modes (``args['mode']``):
        stream -- Read ``stream_src`` continuously. A digit string selects a
            local camera index; anything else is passed to OpenCV as-is.
            One frame out of every ``interval`` is published. On a failed
            read the capture is re-created after 5 s. This mode never
            returns.
        file -- Read ``filepath`` once and publish it.

    Any other mode is logged as an error.
    """
    args = parse_args()
    if args['debug']:
        logger.setLevel(logging.DEBUG)
    else:
        logger.setLevel(logging.INFO)

    comm_config = {
        'subscribe': {},
        'broker': {
            'address': args['broker_ip'],
            'port': args['broker_port']
        }
    }
    comm = Communicator(comm_config, debug=True)

    def duration(t):
        """Return the milliseconds elapsed since datetime ``t``."""
        # timedelta.microseconds is only the sub-second component; use
        # total_seconds() so durations of 1 s or more are reported correctly.
        return (datetime.now() - t).total_seconds() * 1000

    # Optional user metadata given as a JSON string. The key may exist with
    # a None (or empty) value, which json.loads rejects; treat both as {}.
    metadata = json.loads(args.get('meta') or '{}')

    if args['mode'] == 'stream':
        counter = 0
        fail_counter = 0

        # Check input stream source. OpenCV needs an int for a local camera
        # index, so remember a human-readable URI separately for logging.
        if args['stream_src'].isdigit():
            # source is a physically connected camera
            stream_source = int(args['stream_src'])
            stream_source_uri = '/dev/video{}'.format(stream_source)
        else:
            # source is an IP camera
            stream_source = args['stream_src']
            stream_source_uri = stream_source
        capture = cv2.VideoCapture(stream_source)
        cam_fps = capture.get(cv2.CAP_PROP_FPS)
        if cam_fps > 30 or cam_fps < 1:
            logger.warning(
                'Camera FPS is {} (>30 or <1). Set it to 30.'.format(cam_fps))
            cam_fps = 30
        out_fps = args['fps']
        # Publish one frame out of every `interval`. Clamp to 1: when out_fps
        # exceeds cam_fps, int() yields 0, `counter == interval` can never
        # hold, and nothing would ever be published.
        interval = max(1, int(cam_fps / out_fps))

        logger.debug('===== VideoCapture Information =====')
        logger.debug('Stream Source: {}'.format(stream_source_uri))
        logger.debug('Camera FPS: {}'.format(cam_fps))
        logger.debug('Output FPS: {}'.format(out_fps))
        logger.debug('Interval: {}'.format(interval))
        logger.debug('Send MQTT Topic: {}'.format(args['topic']))
        logger.debug('====================================')

        while True:
            status, im = capture.read()

            # To verify whether the input source is alive, you should use the
            # return value of capture.read(). It will not work by capturing
            # exception of a capture instance, or by checking the return value
            # of capture.isOpened().
            #
            # Two reasons:
            # 1. If a dead stream is alive again, capture will not notify
            #    that input source is dead.
            #
            # 2. If you check capture.isOpened(), it will keep returning
            #    True if a stream is dead afterward. So you can not use
            #    the capture return value (capture status) to determine
            #    whether a stream is alive or not.
            if not status:
                fail_counter += 1
                logger.critical(
                    'ERROR: Failure #{} happened when reading frame'.format(
                        fail_counter))

                # Re-create capture.
                capture.release()
                logger.critical(
                    'Re-create a capture and reconnect to {} after 5s'.format(
                        stream_source))
                time.sleep(5)
                capture = cv2.VideoCapture(stream_source)
                continue

            counter += 1
            if counter != interval:
                continue
            logger.debug('Drop frames: {}'.format(counter - 1))
            counter = 0

            # Open a window and display the ready-to-send frame.
            # This is useful for development and debugging.
            if args['display']:
                cv2.imshow('Frame', im)
                cv2.waitKey(1)

            t = datetime.now()
            retval, jpg_bytes = cv2.imencode('.jpg', im)
            if not retval:
                logger.warning('Failed to encode frame as JPEG, skip it')
                continue
            mqtt_payload = payload.serialize_jpg(jpg_bytes, args['hash'],
                                                 metadata)
            comm.send(args['topic'], mqtt_payload)
            logger.debug('send: {} ms'.format(duration(t)))
    elif args['mode'] == 'file':
        # Prepare MQTT payload
        im = cv2.imread(args['filepath'])
        if im is None:
            # cv2.imread signals an unreadable file by returning None.
            logger.error(
                'Can not read image file {}'.format(args['filepath']))
            return
        retval, jpg_bytes = cv2.imencode('.jpg', im)

        t = datetime.now()
        mqtt_payload = payload.serialize_jpg(jpg_bytes, args['hash'],
                                             metadata)
        logger.debug('payload: {} ms'.format(duration(t)))
        logger.debug('payload size: {}'.format(len(mqtt_payload)))

        # Client publishes payload
        t = datetime.now()
        comm.send(args['topic'], mqtt_payload)
        logger.debug('mqtt.publish: {} ms'.format(duration(t)))
        logger.debug('publish at {}'.format(datetime.now().isoformat()))
    else:
        logger.error('User assigned unknown mode {}'.format(args['mode']))
Ejemplo n.º 12
0
def main():
    """Publish camera frames or a single image to ``berrynet/data/rgbimage``.

    Modes (``args['mode']``):
        stream -- Read ``stream_src`` continuously. A digit string selects a
            local camera index; anything else is passed to OpenCV as-is.
            One frame out of every ``interval`` is published as a
            ``payload.serialize_jpg`` payload. This mode never returns.
        file -- Read ``filepath`` once and publish it.

    Any other mode is logged as an error.
    """
    import time  # only needed for the reconnect back-off below

    args = parse_args()
    if args['debug']:
        logger.setLevel(logging.DEBUG)
    else:
        logger.setLevel(logging.INFO)

    comm_config = {
        'subscribe': {},
        'broker': {
            'address': args['broker_ip'],
            'port': args['broker_port']
        }
    }
    comm = Communicator(comm_config, debug=True)

    def duration(t):
        """Return the milliseconds elapsed since datetime ``t``."""
        # timedelta.microseconds is only the sub-second component; use
        # total_seconds() so durations of 1 s or more are reported correctly.
        return (datetime.now() - t).total_seconds() * 1000

    if args['mode'] == 'stream':
        counter = 0
        # Check input stream source
        if args['stream_src'].isdigit():
            # source is a physically connected camera
            capture_source = int(args['stream_src'])
            stream_source = '/dev/video{}'.format(capture_source)
        else:
            # source is an IP camera
            capture_source = args['stream_src']
            stream_source = capture_source
        capture = cv2.VideoCapture(capture_source)
        cam_fps = capture.get(cv2.CAP_PROP_FPS)
        if cam_fps > 30 or cam_fps < 1:
            logger.warning(
                'Camera FPS is {} (>30 or <1). Set it to 30.'.format(cam_fps))
            cam_fps = 30
        out_fps = args['fps']
        # Publish one frame out of every `interval`. Clamp to 1: when out_fps
        # exceeds cam_fps, int() yields 0, `counter == interval` can never
        # hold, and nothing would ever be published.
        interval = max(1, int(cam_fps / out_fps))

        logger.debug('===== VideoCapture Information =====')
        logger.debug('Stream Source: {}'.format(stream_source))
        logger.debug('Camera FPS: {}'.format(cam_fps))
        logger.debug('Output FPS: {}'.format(out_fps))
        logger.debug('Interval: {}'.format(interval))
        logger.debug('====================================')

        while True:
            status, im = capture.read()
            if not status:
                # A failed read returns no frame (im is None), which would
                # crash imshow/imencode below. Re-open the source after a
                # short back-off and try again.
                logger.warning('ERROR: Failure happened when reading frame')
                capture.release()
                time.sleep(5)
                capture = cv2.VideoCapture(capture_source)
                continue

            counter += 1
            if counter != interval:
                continue
            logger.debug('Drop frames: {}'.format(counter - 1))
            counter = 0

            # Open a window and display the ready-to-send frame.
            # This is useful for development and debugging.
            if args['display']:
                cv2.imshow('Frame', im)
                cv2.waitKey(1)

            t = datetime.now()
            retval, jpg_bytes = cv2.imencode('.jpg', im)
            if not retval:
                logger.warning('Failed to encode frame as JPEG, skip it')
                continue
            mqtt_payload = payload.serialize_jpg(jpg_bytes)
            comm.send('berrynet/data/rgbimage', mqtt_payload)
            logger.debug('send: {} ms'.format(duration(t)))
    elif args['mode'] == 'file':
        # Prepare MQTT payload
        im = cv2.imread(args['filepath'])
        if im is None:
            # cv2.imread signals an unreadable file by returning None.
            logger.error(
                'Can not read image file {}'.format(args['filepath']))
            return
        retval, jpg_bytes = cv2.imencode('.jpg', im)

        t = datetime.now()
        mqtt_payload = payload.serialize_jpg(jpg_bytes)
        logger.debug('payload: {} ms'.format(duration(t)))
        logger.debug('payload size: {}'.format(len(mqtt_payload)))

        # Client publishes payload
        t = datetime.now()
        comm.send('berrynet/data/rgbimage', mqtt_payload)
        logger.debug('mqtt.publish: {} ms'.format(duration(t)))
        logger.debug('publish at {}'.format(datetime.now().isoformat()))
    else:
        logger.error('User assigned unknown mode {}'.format(args['mode']))
Ejemplo n.º 13
0
def main():
    """Benchmark a TFLite engine on one image and log its results.

    The engine (``classifier`` or ``detector``) runs five times on
    ``args.input`` and logs each run's elapsed time. The annotations from
    the last run are then logged at DEBUG level.

    Raises:
        Exception: if ``args.engine`` is neither ``classifier`` nor
            ``detector``.
        IOError: if ``args.input`` cannot be read as an image.
    """
    # Example command
    #     $ python3 tflite_engine.py -e detector \
    #           -m detect.tflite --labels labels.txt -i dog.jpg --debug
    args = parse_argsr()

    if args.debug:
        logger.setLevel(logging.DEBUG)
    else:
        logger.setLevel(logging.INFO)

    if args.engine == 'classifier':
        engine = TFLiteClassifierEngine(
                     model = args.model,
                     labels = args.labels,
                     top_k = args.top_k,
                     num_threads = args.num_threads)
    elif args.engine == 'detector':
        engine = TFLiteDetectorEngine(
                     model = args.model,
                     labels = args.labels,
                     num_threads = args.num_threads)
    else:
        raise Exception('Illegal engine {}, it should be '
                        'classifier or detector'.format(args.engine))

    for _ in range(5):
        bgr_array = cv2.imread(args.input)
        # cv2.imread does not raise on a missing or unreadable file. It
        # returns None, which would otherwise fail obscurely inside
        # process_input.
        if bgr_array is None:
            raise IOError('Can not read image file {}'.format(args.input))
        t = time.time()
        image_data = engine.process_input(bgr_array)
        output = engine.inference(image_data)
        model_outputs = engine.process_output(output)
        # Reference result
        #     input:
        #         darknet/data/dog.jpg
        #     output:
        #         Inference takes 0.7533011436462402 s
        #         Inference takes 0.5741658210754395 s
        #         Inference takes 0.6120760440826416 s
        #         Inference takes 0.6191139221191406 s
        #         Inference takes 0.5809791088104248 s
        #         label: bicycle  conf: 0.9563907980918884  (139, 116) (567, 429)
        #         label: car  conf: 0.8757821917533875  (459, 80) (690, 172)
        #         label: dog  conf: 0.869189441204071  (131, 218) (314, 539)
        #         label: car  conf: 0.40003547072410583  (698, 122) (724, 152)
        logger.debug('Inference takes {} s'.format(time.time() - t))

    if args.engine == 'classifier':
        for r in model_outputs['annotations']:
            logger.debug('label: {0}  conf: {1}'.format(
                r['label'],
                r['confidence']
            ))
    elif args.engine == 'detector':
        for r in model_outputs['annotations']:
            logger.debug('label: {0}  conf: {1}  ({2}, {3}) ({4}, {5})'.format(
                r['label'],
                r['confidence'],
                r['left'],
                r['top'],
                r['right'],
                r['bottom']
            ))
    else:
        raise Exception('Can not get result '
                        'from illegal engine {}'.format(args.engine))
Ejemplo n.º 14
0
    def __init__(self, dyda_config_path='', debug=False):
        """ __init__ of DetectorOpenVINO

        Sets up the component in this order:

        1. Load parameters through the dyda base class.
        2. Read the label file.
        3. Create the OpenVINO ``IEPlugin`` for ``param['device']``, adding
           each entry of ``param['cpu_extensions']`` when the device is
           ``CPU``.
        4. Read the IR network. On CPU, check that every layer is supported.
        5. Load the network onto the plugin with two infer requests.

        Trainer Variables (as described by the original docstring; presumably
        managed by the base class -- confirm there):
            input_data: a list of image array
            results: defined by lab_tools.output_pred_detection()

        Attributes set here:
            threshold: ``param["threshold"]`` if present, otherwise 0.3.
            labels_map: stripped lines of ``param['label_file']``.
            plugin: the ``IEPlugin`` instance.
            input_blob, out_blob: names of the single input/output layer.
            n, c, h, w: the input blob shape, unpacked in this order.
            exec_net: the network loaded onto the plugin.
            is_async_mode, cur_request_id, next_request_id: request
                bookkeeping; async mode is currently disabled.

        Arguments:
            dyda_config_path -- Trainer config filepath
            debug -- if True set the module logger to DEBUG, else to INFO

        Note:
            Calls ``sys.exit(1)`` -- terminating the whole process, not just
            raising -- when the CPU plugin does not support some network
            layers. The single-input/single-output checks are ``assert``
            statements, so they are skipped under ``python -O``.
        """
        if debug:
            logger.setLevel(logging.DEBUG)
        else:
            logger.setLevel(logging.INFO)

        # Setup dyda config
        super(DetectorOpenVINO,
              self).__init__(dyda_config_path=dyda_config_path)
        # NOTE(review): set_param/check_param_keys are not defined here;
        # presumably base-class helpers that populate and validate
        # self.param for this class_name -- confirm.
        self.set_param(self.class_name)

        self.check_param_keys()

        # Detection threshold is optional in the config; default to 0.3.
        if "threshold" in self.param.keys():
            self.threshold = self.param["threshold"]
        else:
            self.threshold = 0.3

        # Setup DL model
        model_xml = self.param['model_description']
        model_bin = self.param['model_file']
        # One label per line, surrounding whitespace stripped.
        with open(self.param['label_file'], 'r') as f:
            self.labels_map = [x.strip() for x in f]

        # Setup OpenVINO
        #
        # Plugin initialization for specified device and
        # load extensions library if specified
        #
        # Note: MKLDNN CPU-targeted custom layer support is not included
        #       because we do not use it yet.
        self.plugin = IEPlugin(device=self.param['device'],
                               plugin_dirs=self.param['plugin_dirs'])
        if self.param['device'] == 'CPU':
            for ext in self.param['cpu_extensions']:
                logger.info('Add cpu extension: {}'.format(ext))
                self.plugin.add_cpu_extension(ext)
        logger.debug("Computation device: {}".format(self.param['device']))

        # Read IR
        logger.debug("Loading network files:\n\t{}\n\t{}".format(
            model_xml, model_bin))
        net = IENetwork(model=model_xml, weights=model_bin)

        # On CPU, refuse to continue if any layer is unsupported by the
        # plugin (even after the extensions added above).
        if self.plugin.device == "CPU":
            supported_layers = self.plugin.get_supported_layers(net)
            not_supported_layers = [
                l for l in net.layers.keys() if l not in supported_layers
            ]
            if len(not_supported_layers) != 0:
                logger.error(
                    ('Following layers are not supported '
                     'by the plugin for specified device {}:\n {}').format(
                         self.plugin.device, ', '.join(not_supported_layers)))
                # NOTE(review): this message refers to demo CLI flags
                # (-l / --cpu_extension); in this component extensions
                # come from param['cpu_extensions'].
                logger.error("Please try to specify cpu "
                             "extensions library path in demo's "
                             "command line parameters using -l "
                             "or --cpu_extension command line argument")
                sys.exit(1)

        assert len(net.inputs.keys()) == 1, (
            'Demo supports only single input topologies')
        assert len(
            net.outputs) == 1, ('Demo supports only single output topologies')

        # input_blob and out_blob are the layer names in string format.
        logger.debug("Preparing input blobs")
        self.input_blob = next(iter(net.inputs))
        self.out_blob = next(iter(net.outputs))

        # Four-way unpacking requires a 4-D input shape; the names suggest
        # NCHW layout -- assumption, confirm against the model.
        self.n, self.c, self.h, self.w = net.inputs[self.input_blob].shape

        # Loading model to the plugin. Two infer requests, presumably so the
        # cur/next request ids below can alternate in async mode -- confirm.
        self.exec_net = self.plugin.load(network=net, num_requests=2)

        # Drop the local reference to the IR network once it is loaded.
        del net

        # Initialize engine mode: sync or async
        #
        # FIXME: async mode does not work currently.
        #        process_input needs to provide two input tensors for async.
        self.is_async_mode = False
        self.cur_request_id = 0
        self.next_request_id = 1