Пример #1
0
def cyclegan_predict(G_generator, input_data, ckpt_path):
    """
    Single image predict with a CycleGAN generator.

    Loads the checkpoint at `ckpt_path` into `G_generator`, runs the generator
    on `input_data` with a leading batch axis prepended, and decodes the result.

    Args:
       G_generator (Generator): Generator, such as 'G_A'.
       input_data (Tensor): input image data without a batch axis (an axis of
           size 1 is prepended here before calling the generator).
       ckpt_path (str): the checkpoint path.

    Returns:
        numpy.ndarray. If the generator returns a Tensor, it is decoded from
        [1, C, H, W] into an (H, W, C) uint8 image; if it already returns a
        numpy array, that array is returned unchanged.

    Raises:
        ValueError: If `ckpt_path` does not exist.
        ValueError: If `fake_img` is not Tensor or Numpy.
    """
    # NOTE(review): the network is put in training mode although this is an
    # inference path; kept as in the original — confirm this is intentional.
    G_generator.set_train(True)
    # load checkpoint
    if not os.path.isfile(ckpt_path):
        err_msg = "The model path " + ckpt_path + " not exist!"
        raise ValueError(err_msg)
    param_G = load_checkpoint(ckpt_path)
    load_param_into_net(G_generator, param_G)
    fake_img = G_generator(ts.expand_dims(input_data, 0))
    if isinstance(fake_img, Tensor):
        # Decode a [1, C, H, W] Tensor to image numpy array: v -> v * 127.5 + 127.5
        # maps [-1, 1] onto [0, 255] (presumably the generator output range —
        # TODO confirm).
        mean = 0.5 * 255
        std = 0.5 * 255
        pixels = fake_img.asnumpy()[0] * std + mean
        # Clip before casting: float -> uint8 conversion of out-of-range values
        # wraps around (or is platform dependent) instead of saturating.
        fake_img = np.clip(pixels, 0, 255).astype(
            np.uint8).transpose((1, 2, 0))
    elif not isinstance(fake_img, np.ndarray):
        raise ValueError(
            "img should be Tensor or numpy array, but get {}".format(
                type(fake_img)))
    return fake_img
Пример #2
0
    def convert2tensor(self, transform_input):
        r"""
        Turn a preprocessed numpy image into a batched tensor.

        Args:
            transform_input (numpy.ndarray): the preprocessed image.

        Returns:
            Tensor, the image with a new leading axis of size 1 prepended.

        Raises:
            TypeError: If `transform_input` is not a numpy.ndarray.
        """
        if isinstance(transform_input, np.ndarray):
            # Split along the first axis into a Python list before building the tensor.
            rows = list(transform_input)
            return ts.expand_dims(ts.array(rows), 0)
        raise TypeError(
            'The transform_input type should be numpy.ndarray, got {}.'.format(type(transform_input)))
Пример #3
0
def cyclegan_predict(G_generator, input_data, ckpt_path):
    """
    Single image predict with a CycleGAN generator.

    Loads the checkpoint at `ckpt_path` into `G_generator`, runs the generator
    on `input_data` with a leading batch axis prepended, and decodes the result.

    Args:
        G_generator (Generator): Generator, such as 'G_A'.
        input_data (Tensor): input image data without a batch axis (an axis of
            size 1 is prepended here before calling the generator).
        ckpt_path (str): the checkpoint path.

    Returns:
        numpy.ndarray. If the generator returns a Tensor, it is decoded from
        [1, C, H, W] into an (H, W, C) uint8 image; if it already returns a
        numpy array, that array is returned unchanged.

    Raises:
        ValueError: If `ckpt_path` does not exist.
        ValueError: If the generator output is neither a Tensor nor a numpy array.
    """
    # NOTE(review): the network is put in training mode although this is an
    # inference path; kept as in the original — confirm this is intentional.
    G_generator.set_train(True)
    # load checkpoint
    if not os.path.isfile(ckpt_path):
        err_msg = "The model path " + ckpt_path + " not exist!"
        raise ValueError(err_msg)
    param_G = load_checkpoint(ckpt_path)
    load_param_into_net(G_generator, param_G)
    fake_img = G_generator(ts.expand_dims(input_data, 0))
    if isinstance(fake_img, Tensor):
        # Decode a [1, C, H, W] Tensor to image numpy array: v -> v * 127.5 + 127.5
        # maps [-1, 1] onto [0, 255] (presumably the generator output range —
        # TODO confirm).
        mean = 0.5 * 255
        std = 0.5 * 255
        pixels = fake_img.asnumpy()[0] * std + mean
        # Clip before casting: float -> uint8 conversion of out-of-range values
        # wraps around (or is platform dependent) instead of saturating.
        fake_img = np.clip(pixels, 0, 255).astype(np.uint8).transpose((1, 2, 0))
    elif not isinstance(fake_img, np.ndarray):
        raise ValueError("img should be Tensor or numpy array, but get {}".format(type(fake_img)))
    return fake_img
Пример #4
0
def web_predict(instance, servable_name, servable_model, dataset_name, strategy):
    """
    Predict the result based on the input data for the web front end.

    A network is constructed from the servable data, its checkpoint is loaded,
    and the prediction is post-processed into a message plus encoded image data.

    Args:
        instance (dict): the input image, with keys `shape`, `dtype` and `data`
            (`data` is JSON text decoded with `json.loads` into pixel values).
        servable_name (str): servable name; the checkpoint is looked up in
            `serving_path/<servable_name>/`.
        servable_model (dict): model description; reads `name`, `format`, and
            `g_model` (cycle_gan) or `class_num` (other models).
        dataset_name (str): dataset name, used to pick the transform from
            `transform_checker`.
        strategy (str): output strategy, usually select between `TOP1_CLASS` and
            `TOP5_CLASS`, for cyclegan, select between `gray2color` and `color2gray`.

    Returns:
        dict: ``{"status": 0, "instance": {"res_msg": ..., "data": ...}}`` on
        success, where `data` is the output of `numpy2base64` on the result
        image; ``{"status": 1, "err_msg": ...}`` on failure.

    Examples:
        >>> # In the server part, after servable_search
        >>> res = web_predict(instance, servable_name, servable['model'], dataset_name, strategy)
        >>> return jsonify(res)
    """

    # check if servable model name is valid
    model_name = servable_model['name']
    net_func = model_checker.get(model_name)
    if net_func is None:
        err_msg = "Currently model_name only supports " + str(list(model_checker.keys())) + "!"
        return {"status": 1, "err_msg": err_msg}

    # check if model_format is valid. The trailing comma makes this a tuple;
    # a bare ("ckpt") is a str and `in` would accept substrings such as "ck".
    model_format = servable_model['format']
    if model_format not in ("ckpt",):
        err_msg = "Currently model_format only supports `ckpt`!"
        return {"status": 1, "err_msg": err_msg}

    # Check if dataset is supported. Report it like the other validation
    # errors rather than exiting the whole serving process.
    trans_func = transform_checker.get(dataset_name)
    if trans_func is None:
        err_msg = "Currently dataset_name only supports {}!".format(list(transform_checker.keys()))
        return {"status": 1, "err_msg": err_msg}

    # process the original data
    ori_img = np.array(json.loads(instance['data']), dtype=instance['dtype'])
    if dataset_name in ['mnist']:
        image = trans_func(ori_img)
    else:
        # non-MNIST inputs are converted from BGR to RGB before the transform
        cvt_image = cv2.cvtColor(ori_img, cv2.COLOR_BGR2RGB)
        image = trans_func(cvt_image)

    input_data = ts.array(image.tolist(), dtype=image.dtype.name)

    if model_name == "cycle_gan":
        g_model = servable_model['g_model']
        if strategy == 'gray2color':
            # build the network; net_func returns a pair of generators
            G_generator, _ = net_func(g_model=g_model)
            ckpt_name = 'G_A'
        elif strategy == 'color2gray':
            _, G_generator = net_func(g_model=g_model)
            ckpt_name = 'G_B'
        else:
            err_msg = "Currently cycle_gan strategy only supports `gray2color` and `color2gray`!"
            return {"status": 1, "err_msg": err_msg}
        ckpt_path = os.path.join(serving_path, servable_name, ckpt_name + "." + model_format)
        # cyclegan_predict raises ValueError for a missing checkpoint; check here
        # so callers get the same error dict as for the other models.
        if not os.path.isfile(ckpt_path):
            err_msg = "The model path " + ckpt_path + " not exist!"
            return {"status": 1, "err_msg": err_msg}
        out_img = cyclegan_predict(G_generator, input_data, ckpt_path)
        res_msg = '原图使用{}风格迁移效果'.format(strategy)
        data = numpy2base64(out_img)
    else:
        # build the network
        class_num = servable_model['class_num']
        net = net_func(class_num=class_num, is_training=False)
        serve_model = model.Model(net)

        # load checkpoint
        ckpt_path = os.path.join(serving_path, servable_name, model_name + "." + model_format)
        if not os.path.isfile(ckpt_path):
            err_msg = "The model path " + ckpt_path + " not exist!"
            return {"status": 1, "err_msg": err_msg}
        serve_model.load_checkpoint(ckpt_path)

        # execute the network to perform model prediction (batch axis of size 1 added)
        output = serve_model.predict(ts.expand_dims(input_data, 0))

        if model_name == "ssd300":
            # ssd300 yields two outputs that are joined along the last axis
            output_np = ts.concatenate((output[0], output[1]), axis=-1).asnumpy()
            ih, iw, _ = instance['shape']
            bbox_data = trans_func.postprocess(output_np, (ih, iw), strategy)
            bbox_num = len(bbox_data)
            if not bbox_num:
                err_msg = "抱歉!未检测到任何种类,无法标注。"
                return {"status": 1, "err_msg": err_msg}
            out_img = draw_boxes_in_image(bbox_data, ori_img)
            # the highest-scoring detection supplies both category and score
            best_det = max(bbox_data, key=lambda det: det['score'])
            res_msg = '图中共标注了:{}个框,其中物种{}的得分最高, 为{}。'.format(
                bbox_num, best_det['category_id'], round(best_det['score'], 3))
            data = numpy2base64(cv2.cvtColor(out_img, cv2.COLOR_BGR2RGB))
        else:
            output_np = output.asnumpy()
            res_msg = trans_func.postprocess(output_np, strategy)
            data = numpy2base64(ori_img)

    return {
        "status": 0,
        "instance": {
            "res_msg": res_msg,
            "data": data
        }
    }
Пример #5
0
def predict(instance, servable_name, servable_model, strategy):
    """
    Predict the result based on the input data.

    A network will be constructed based on the input and servable data, then load the checkpoint and do the predict.

    Args:
        instance (dict): the input image, with keys `shape`, `dtype` and `data`
            (`data` is JSON text decoded with `json.loads`).
        servable_name (str): servable name; the checkpoint is looked up in
            `serving_path/<servable_name>/`.
        servable_model (dict): model description; reads `name`, `format`, and
            `g_model` (cycle_gan) or `class_num` (other models).
        strategy (str): output strategy, usually select between `TOP1_CLASS` and
            `TOP5_CLASS`, for cyclegan, select between `gray2color` and
            `color2gray`. Only the cycle_gan branch reads it here.

    Returns:
        dict: on success ``{"status": 0, "instance": {"shape", "dtype", "data"}}``
        where `data` is the prediction array encoded with `json.dumps`;
        on failure ``{"status": 1, "err_msg": ...}``.

    Examples:
        >>> # In the server part, after servable_search
        >>> res = predict(instance, servable_name, servable['model'], strategy)
        >>> return jsonify(res)
    """

    # check if servable model name is valid
    model_name = servable_model['name']
    net_func = model_checker.get(model_name)
    if net_func is None:
        err_msg = "Currently model_name only supports " + str(list(model_checker.keys())) + "!"
        return {"status": 1, "err_msg": err_msg}

    # check if model_format is valid. The trailing comma makes this a tuple;
    # a bare ("ckpt") is a str and `in` would accept substrings such as "ck".
    model_format = servable_model['format']
    if model_format not in ("ckpt",):
        err_msg = "Currently model_format only supports `ckpt`!"
        return {"status": 1, "err_msg": err_msg}

    # parse the input data
    input_data = ts.array(json.loads(instance['data']), dtype=instance['dtype'])

    if model_name == "cycle_gan":
        g_model = servable_model['g_model']
        if strategy == 'gray2color':
            # build the network; net_func returns a pair of generators
            G_generator, _ = net_func(g_model=g_model)
            ckpt_name = 'G_A'
        elif strategy == 'color2gray':
            _, G_generator = net_func(g_model=g_model)
            ckpt_name = 'G_B'
        else:
            err_msg = "Currently cycle_gan strategy only supports `gray2color` and `color2gray`!"
            return {"status": 1, "err_msg": err_msg}
        ckpt_path = os.path.join(serving_path, servable_name, ckpt_name + "." + model_format)
        # cyclegan_predict raises ValueError for a missing checkpoint; check here
        # so callers get the same error dict as for the other models.
        if not os.path.isfile(ckpt_path):
            err_msg = "The model path " + ckpt_path + " not exist!"
            return {"status": 1, "err_msg": err_msg}
        data = cyclegan_predict(G_generator, input_data, ckpt_path)
    else:
        # build the network
        class_num = servable_model['class_num']
        net = net_func(class_num=class_num, is_training=False)
        serve_model = model.Model(net)

        # load checkpoint
        ckpt_path = os.path.join(serving_path, servable_name, model_name + "." + model_format)
        if not os.path.isfile(ckpt_path):
            err_msg = "The model path " + ckpt_path + " not exist!"
            return {"status": 1, "err_msg": err_msg}
        serve_model.load_checkpoint(ckpt_path)

        # execute the network to perform model prediction (batch axis of size 1 added)
        output = serve_model.predict(ts.expand_dims(input_data, 0))

        # ssd300 yields two outputs that are joined along the last axis
        data = (ts.concatenate((output[0], output[1]), axis=-1).asnumpy() if model_name == "ssd300"
                else output.asnumpy())
    return {
        "status": 0,
        "instance": {
            "shape": data.shape,
            "dtype": data.dtype.name,
            "data": json.dumps(data.tolist())
        }
    }
Пример #6
0
def predict(instance, servable_name, servable_model, strategy):
    """
    Predict the result based on the input data.

    A network is constructed from the servable data, its checkpoint is loaded
    from ``/etc/tinyms/serving/<servable_name>/``, and the prediction is run.

    Args:
        instance (dict): the input image, with keys `shape`, `dtype` and `data`
            (`data` is JSON text decoded with `json.loads`).
        servable_name (str): servable name (checkpoint sub-directory).
        servable_model (dict): model description; reads `name`, `format`, and
            `g_model` (cycle_gan) or `class_num` (other models).
        strategy (str): for cycle_gan, `gray2color` or `color2gray`; not read by
            the other branch here.

    Returns:
        dict: on success ``{"status": 0, "instance": {"shape", "dtype", "data"}}``
        where `data` is the prediction array encoded with `json.dumps`;
        on failure ``{"status": 1, "err_msg": ...}``.
    """
    # root directory holding one sub-directory of checkpoints per servable
    serving_root = "/etc/tinyms/serving"

    # check if servable model name is valid
    model_name = servable_model['name']
    net_func = model_checker.get(model_name)
    if net_func is None:
        err_msg = "Currently model_name only supports " + str(
            list(model_checker.keys())) + "!"
        return {"status": 1, "err_msg": err_msg}

    # check if model_format is valid. The trailing comma makes this a tuple;
    # a bare ("ckpt") is a str and `in` would accept substrings such as "ck".
    model_format = servable_model['format']
    if model_format not in ("ckpt",):
        err_msg = "Currently model_format only supports `ckpt`!"
        return {"status": 1, "err_msg": err_msg}

    # parse the input data
    input_data = ts.array(json.loads(instance['data']),
                          dtype=instance['dtype'])

    if model_name == "cycle_gan":
        g_model = servable_model['g_model']
        if strategy == 'gray2color':
            # build the network; net_func returns a pair of generators
            G_generator, _ = net_func(g_model=g_model)
            ckpt_name = 'G_A'
        elif strategy == 'color2gray':
            _, G_generator = net_func(g_model=g_model)
            ckpt_name = 'G_B'
        else:
            err_msg = "Currently cycle_gan strategy only supports `gray2color` and `color2gray`!"
            return {"status": 1, "err_msg": err_msg}
        ckpt_path = os.path.join(serving_root, servable_name,
                                 ckpt_name + "." + model_format)
        # cyclegan_predict raises ValueError for a missing checkpoint; check here
        # so callers get the same error dict as for the other models.
        if not os.path.isfile(ckpt_path):
            err_msg = "The model path " + ckpt_path + " not exist!"
            return {"status": 1, "err_msg": err_msg}

        data = cyclegan_predict(G_generator, input_data, ckpt_path)
    else:
        # build the network
        # NOTE(review): unlike other variants, no `is_training=False` is passed
        # here; confirm net_func's default is inference mode.
        class_num = servable_model['class_num']
        net = net_func(class_num=class_num)
        serve_model = model.Model(net)

        # load checkpoint
        ckpt_path = os.path.join(serving_root, servable_name,
                                 model_name + "." + model_format)
        if not os.path.isfile(ckpt_path):
            err_msg = "The model path " + ckpt_path + " not exist!"
            return {"status": 1, "err_msg": err_msg}
        serve_model.load_checkpoint(ckpt_path)

        # execute the network to perform model prediction (batch axis of size 1 added)
        output = serve_model.predict(ts.expand_dims(input_data, 0))

        # ssd300 yields two outputs that are joined along the last axis
        data = (ts.concatenate((output[0], output[1]), axis=-1).asnumpy()
                if model_name == "ssd300" else output.asnumpy())
    return {
        "status": 0,
        "instance": {
            "shape": data.shape,
            "dtype": data.dtype.name,
            "data": json.dumps(data.tolist())
        }
    }