Esempio n. 1
0
    def model_build(self, is_training=False):
        r"""
        Build the object detection model to predict the image.

        The network constructor is looked up in ``model_checker`` using the
        ``model_net`` config entry, instantiated with the configured
        ``class_num``, and wrapped in ``model.Model``.

        Args:
            is_training (bool): passed unchanged to the network constructor.
                Default: False.

        Returns:
            model.Model, generated object detection model.

        Raises:
            KeyError: if ``model_net`` is missing from the config or is not a
                key of ``model_checker``.
            ValueError: if ``class_num`` is missing from the config or is not
                greater than 0.
        """
        model_net = model_checker.get(self.config.get('model_net'))
        if not model_net:
            err_msg = 'Currently model_net only supports {}!'.format(
                str(list(model_checker.keys())))
            raise KeyError(err_msg)

        class_num = self.config.get('class_num')
        # A missing key yields None, and ``None <= 0`` would raise an unhelpful
        # TypeError; report it through the same ValueError as a bad value.
        if class_num is None or class_num <= 0:
            err_msg = 'The class_num should be an integer greater than 0, got {}.'.format(
                class_num)
            raise ValueError(err_msg)

        net = model_net(class_num=class_num, is_training=is_training)
        serve_model = model.Model(net)
        return serve_model
Esempio n. 2
0
def web_predict(instance, servable_name, servable_model, dataset_name, strategy):
    """
    Predict the result based on the input data for the web front end.

    A network is constructed from the servable description, its checkpoint is
    loaded, and the prediction is post-processed into a human-readable message
    plus an image encoded with ``numpy2base64``.

    Args:
        instance (dict): the input image, with keys ``shape``, ``dtype`` and
            ``data`` (the JSON-encoded pixel array).
        servable_name (str): servable name; also the sub-directory of
            ``serving_path`` that holds the checkpoint files.
        servable_model (dict): model description with keys ``name`` and
            ``format``, plus ``g_model`` for cycle_gan or ``class_num`` for
            the other models.
        dataset_name (str): dataset name used to select the transform from
            ``transform_checker``.
        strategy (str): output strategy, usually select between `TOP1_CLASS`
            and `TOP5_CLASS`; for cycle_gan, select between `gray2color` and
            `color2gray`.

    Returns:
        dict: ``{"status": 0, "instance": {"res_msg": ..., "data": ...}}`` on
        success, or ``{"status": 1, "err_msg": ...}`` on failure.

    Examples:
        >>> # In the server part, after servable_search
        >>> res = web_predict(instance, servable_name, servable['model'], dataset_name, strategy)
        >>> return jsonify(res)
    """

    # check if servable model name is valid
    model_name = servable_model['name']
    net_func = model_checker.get(model_name)
    if net_func is None:
        err_msg = "Currently model_name only supports " + str(list(model_checker.keys())) + "!"
        return {"status": 1, "err_msg": err_msg}

    # check if model_format is valid; the trailing comma makes this a tuple,
    # otherwise `in` would be a substring test accepting "", "ck", "kp", ...
    model_format = servable_model['format']
    if model_format not in ("ckpt",):
        err_msg = "Currently model_format only supports `ckpt`!"
        return {"status": 1, "err_msg": err_msg}

    # Check if dataset is supported. Report the error to the caller like the
    # other checks do; sys.exit() would raise SystemExit inside the handler.
    trans_func = transform_checker.get(dataset_name)
    if trans_func is None:
        err_msg = "Currently dataset_name only supports {}!".format(list(transform_checker.keys()))
        return {"status": 1, "err_msg": err_msg}

    # process the original data
    ori_img = np.array(json.loads(instance['data']), dtype=instance['dtype'])
    if dataset_name in ['mnist']:
        image = trans_func(ori_img)
    else:
        # non-mnist inputs are converted from BGR to RGB before the transform
        cvt_image = cv2.cvtColor(ori_img, cv2.COLOR_BGR2RGB)
        image = trans_func(cvt_image)

    input_data = ts.array(image.tolist(), dtype=image.dtype.name)

    res_msg = ''
    if model_name == "cycle_gan":
        g_model = servable_model['g_model']
        if strategy == 'gray2color':
            # build the network; the first generator is used for gray2color
            G_generator, _ = net_func(g_model=g_model)
            ckpt_name = 'G_A'

        elif strategy == 'color2gray':
            _, G_generator = net_func(g_model=g_model)
            ckpt_name = 'G_B'
        else:
            err_msg = "Currently cycle_gan strategy only supports `gray2color` and `color2gray`!"
            return {"status": 1, "err_msg": err_msg}
        ckpt_path = os.path.join(serving_path, servable_name, ckpt_name + "." + model_format)
        # same existence check as the non-cycle_gan branch below
        if not os.path.isfile(ckpt_path):
            err_msg = "The model path " + ckpt_path + " not exist!"
            return {"status": 1, "err_msg": err_msg}
        out_img = cyclegan_predict(G_generator, input_data, ckpt_path)
        res_msg = '原图使用{}风格迁移效果'.format(strategy)
        data = numpy2base64(out_img)
    else:
        # build the network
        class_num = servable_model['class_num']
        net = net_func(class_num=class_num, is_training=False)
        serve_model = model.Model(net)

        # load checkpoint
        ckpt_path = os.path.join(serving_path, servable_name, model_name + "." + model_format)
        if not os.path.isfile(ckpt_path):
            err_msg = "The model path " + ckpt_path + " not exist!"
            return {"status": 1, "err_msg": err_msg}
        serve_model.load_checkpoint(ckpt_path)

        # execute the network to perform model prediction (batch of one)
        output = serve_model.predict(ts.expand_dims(input_data, 0))

        if model_name == "ssd300":
            # the two ssd300 outputs are joined on the last axis before postprocess
            output_np = (ts.concatenate((output[0], output[1]), axis=-1).asnumpy())
            ih, iw, _ = instance['shape']
            bbox_data = trans_func.postprocess(output_np, (ih, iw), strategy)
            bbox_num = len(bbox_data)
            if not bbox_num:
                err_msg = "抱歉!未检测到任何种类,无法标注。"
                return {"status": 1, "err_msg": err_msg}
            out_img = draw_boxes_in_image(bbox_data, ori_img)
            max_det = max(bbox_data, key=lambda k: k['score'])
            max_score = max_det['score']
            category = max_det['category_id']
            res_msg = '图中共标注了:{}个框,其中物种{}的得分最高, 为{}。'.format(bbox_num, category, round(max_score, 3))
            data = numpy2base64(cv2.cvtColor(out_img, cv2.COLOR_BGR2RGB))
        else:
            output_np = output.asnumpy()
            res_msg = trans_func.postprocess(output_np, strategy)
            data = numpy2base64(ori_img)

    res = {
        "status": 0,
        "instance": {
            "res_msg": res_msg,
            "data": data
        }
    }
    return res
Esempio n. 3
0
def predict(instance, servable_name, servable_model, strategy):
    """
    Predict the result based on the input data.

    A network will be constructed based on the input and servable data, then
    load the checkpoint and do the predict. The raw network output is returned
    as a JSON-encoded array; no post-processing is applied here.

    Args:
        instance (dict): the input array, with keys ``shape``, ``dtype`` and
            ``data`` (the JSON-encoded array).
        servable_name (str): servable name; also the sub-directory of
            ``serving_path`` that holds the checkpoint files.
        servable_model (dict): model description with keys ``name`` and
            ``format``, plus ``g_model`` for cycle_gan or ``class_num`` for
            the other models.
        strategy (str): output strategy; only consulted for cycle_gan here,
            where it selects between `gray2color` and `color2gray`.

    Returns:
        dict: ``{"status": 0, "instance": {"shape", "dtype", "data"}}`` where
        ``data`` is the JSON-encoded prediction, or
        ``{"status": 1, "err_msg": ...}`` on failure.

    Examples:
        >>> # In the server part, after servable_search
        >>> res = predict(instance, servable_name, servable['model'], strategy)
        >>> return jsonify(res)
    """

    # check if servable model name is valid
    model_name = servable_model['name']
    net_func = model_checker.get(model_name)
    if net_func is None:
        err_msg = "Currently model_name only supports " + str(list(model_checker.keys())) + "!"
        return {"status": 1, "err_msg": err_msg}

    # check if model_format is valid; the trailing comma makes this a tuple,
    # otherwise `in` would be a substring test accepting "", "ck", "kp", ...
    model_format = servable_model['format']
    if model_format not in ("ckpt",):
        err_msg = "Currently model_format only supports `ckpt`!"
        return {"status": 1, "err_msg": err_msg}

    # parse the input data
    input_data = ts.array(json.loads(instance['data']), dtype=instance['dtype'])

    if model_name == "cycle_gan":
        g_model = servable_model['g_model']
        if strategy == 'gray2color':
            # build the network; the first generator is used for gray2color
            G_generator, _ = net_func(g_model=g_model)
            ckpt_name = 'G_A'

        elif strategy == 'color2gray':
            _, G_generator = net_func(g_model=g_model)
            ckpt_name = 'G_B'
        else:
            err_msg = "Currently cycle_gan strategy only supports `gray2color` and `color2gray`!"
            return {"status": 1, "err_msg": err_msg}
        ckpt_path = os.path.join(serving_path, servable_name, ckpt_name + "." + model_format)
        # same existence check as the non-cycle_gan branch below
        if not os.path.isfile(ckpt_path):
            err_msg = "The model path " + ckpt_path + " not exist!"
            return {"status": 1, "err_msg": err_msg}
        # NOTE(review): assumes cyclegan_predict returns a numpy array, since
        # .shape / .dtype / .tolist() are used below.
        data = cyclegan_predict(G_generator, input_data, ckpt_path)
    else:
        # build the network
        class_num = servable_model['class_num']
        net = net_func(class_num=class_num, is_training=False)
        serve_model = model.Model(net)

        # load checkpoint
        ckpt_path = os.path.join(serving_path, servable_name, model_name + "." + model_format)
        if not os.path.isfile(ckpt_path):
            err_msg = "The model path " + ckpt_path + " not exist!"
            return {"status": 1, "err_msg": err_msg}
        serve_model.load_checkpoint(ckpt_path)

        # execute the network to perform model prediction (batch of one)
        output = serve_model.predict(ts.expand_dims(input_data, 0))

        # ssd300 yields two outputs, which are joined on the last axis
        data = (ts.concatenate((output[0], output[1]), axis=-1).asnumpy() if model_name == "ssd300"
                else output.asnumpy())
    return {
        "status": 0,
        "instance": {
            "shape": data.shape,
            "dtype": data.dtype.name,
            "data": json.dumps(data.tolist())
        }
    }
Esempio n. 4
0
def predict(instance, servable_name, servable_model, strategy):
    """
    Predict the result based on the input data.

    A network is constructed from the servable description, its checkpoint is
    loaded from ``/etc/tinyms/serving/<servable_name>/`` and the raw network
    output is returned as a JSON-encoded array.

    Args:
        instance (dict): the input array, with keys ``dtype`` and ``data``
            (the JSON-encoded array).
        servable_name (str): servable name; also the sub-directory of
            ``/etc/tinyms/serving`` that holds the checkpoint files.
        servable_model (dict): model description with keys ``name`` and
            ``format``, plus ``g_model`` for cycle_gan or ``class_num`` for
            the other models.
        strategy (str): output strategy; only consulted for cycle_gan here,
            where it selects between `gray2color` and `color2gray`.

    Returns:
        dict: ``{"status": 0, "instance": {"shape", "dtype", "data"}}`` where
        ``data`` is the JSON-encoded prediction, or
        ``{"status": 1, "err_msg": ...}`` on failure.
    """
    # check if servable model name is valid
    model_name = servable_model['name']
    net_func = model_checker.get(model_name)
    if net_func is None:
        err_msg = "Currently model_name only supports " + str(
            list(model_checker.keys())) + "!"
        return {"status": 1, "err_msg": err_msg}

    # check if model_format is valid; the trailing comma makes this a tuple,
    # otherwise `in` would be a substring test accepting "", "ck", "kp", ...
    model_format = servable_model['format']
    if model_format not in ("ckpt",):
        err_msg = "Currently model_format only supports `ckpt`!"
        return {"status": 1, "err_msg": err_msg}

    # parse the input data
    input_data = ts.array(json.loads(instance['data']),
                          dtype=instance['dtype'])

    if model_name == "cycle_gan":
        g_model = servable_model['g_model']
        if strategy == 'gray2color':
            # build the network; the first generator is used for gray2color
            G_generator, _ = net_func(g_model=g_model)
            ckpt_name = 'G_A'

        elif strategy == 'color2gray':
            _, G_generator = net_func(g_model=g_model)
            ckpt_name = 'G_B'
        else:
            err_msg = "Currently cycle_gan strategy only supports `gray2color` and `color2gray`!"
            return {"status": 1, "err_msg": err_msg}
        ckpt_path = os.path.join("/etc/tinyms/serving", servable_name,
                                 ckpt_name + "." + model_format)
        # same existence check as the non-cycle_gan branch below
        if not os.path.isfile(ckpt_path):
            err_msg = "The model path " + ckpt_path + " not exist!"
            return {"status": 1, "err_msg": err_msg}

        # NOTE(review): assumes cyclegan_predict returns a numpy array, since
        # .shape / .dtype / .tolist() are used below.
        data = cyclegan_predict(G_generator, input_data, ckpt_path)
    else:
        # build the network (is_training is left to net_func's default here)
        class_num = servable_model['class_num']
        net = net_func(class_num=class_num)
        serve_model = model.Model(net)

        # load checkpoint
        ckpt_path = os.path.join("/etc/tinyms/serving", servable_name,
                                 model_name + "." + model_format)
        if not os.path.isfile(ckpt_path):
            err_msg = "The model path " + ckpt_path + " not exist!"
            return {"status": 1, "err_msg": err_msg}
        serve_model.load_checkpoint(ckpt_path)

        # execute the network to perform model prediction (batch of one)
        output = serve_model.predict(ts.expand_dims(input_data, 0))

        # ssd300 yields two outputs, which are joined on the last axis
        data = (ts.concatenate((output[0], output[1]), axis=-1).asnumpy()
                if model_name == "ssd300" else output.asnumpy())
    return {
        "status": 0,
        "instance": {
            "shape": data.shape,
            "dtype": data.dtype.name,
            "data": json.dumps(data.tolist())
        }
    }