def model_build(self, is_training=False):
    r"""
    Build the object detection model to predict the image.

    Looks up the network constructor registered in `model_checker` under
    `self.config['model_net']`, instantiates it with the configured
    `class_num`, and wraps the network in `model.Model`.

    Args:
        is_training (bool): forwarded to the network constructor as
            `is_training`. Default: False.

    Returns:
        model.Model, generated object detection model.

    Raises:
        KeyError: if `model_net` is missing from the config or is not a key
            of `model_checker`.
        ValueError: if `class_num` is missing from the config or is not
            greater than 0.
    """
    model_net = model_checker.get(self.config.get('model_net'))
    if not model_net:
        err_msg = 'Currently model_net only supports {}!'.format(
            str(list(model_checker.keys())))
        raise KeyError(err_msg)

    class_num = self.config.get('class_num')
    # A missing key yields None; compare explicitly so the caller gets the
    # intended ValueError rather than a TypeError from `None <= 0`.
    if class_num is None or class_num <= 0:
        err_msg = 'The class_num should be an integer greater than 0, got {}.'.format(
            class_num)
        raise ValueError(err_msg)

    net = model_net(class_num=class_num, is_training=is_training)
    serve_model = model.Model(net)
    return serve_model
def web_predict(instance, servable_name, servable_model, dataset_name, strategy):
    """
    Predict the result based on the input data for the web front end.

    A network is constructed from the servable description, its checkpoint is
    loaded, and the prediction is post-processed into a message string plus a
    base64-encoded image.

    Args:
        instance (dict): the input image, with keys `shape`, `dtype` and
            `data` (a JSON-encoded pixel array, decoded with `json.loads`).
        servable_name (str): servable name; used as the checkpoint
            sub-directory under `serving_path`.
        servable_model (dict): model description; the keys `name` and `format`
            are always read, plus `g_model` for cycle_gan or `class_num`
            otherwise.
        dataset_name (str): dataset name used to look up the transform in
            `transform_checker`.
        strategy (str): output strategy, usually select between `TOP1_CLASS`
            and `TOP5_CLASS`; for cyclegan, select between `gray2color` and
            `color2gray`.

    Returns:
        dict, `{"status": 0, "instance": {"res_msg": ..., "data": ...}}` on
        success, or `{"status": 1, "err_msg": ...}` on any failure.

    Examples:
        >>> # In the server part, after servable_search
        >>> res = web_predict(instance, servable_name, servable['model'], dataset_name, strategy)
        >>> return jsonify(res)
    """
    # check if servable model name is valid
    model_name = servable_model['name']
    net_func = model_checker.get(model_name)
    if net_func is None:
        err_msg = "Currently model_name only supports " + str(list(model_checker.keys())) + "!"
        return {"status": 1, "err_msg": err_msg}

    # check if model_format is valid; an exact comparison is required because
    # `("ckpt")` is a plain string and `in` would accept any substring of it
    model_format = servable_model['format']
    if model_format != "ckpt":
        err_msg = "Currently model_format only supports `ckpt`!"
        return {"status": 1, "err_msg": err_msg}

    # check if the dataset is supported; report it in the response instead of
    # terminating the whole serving process
    trans_func = transform_checker.get(dataset_name)
    if trans_func is None:
        err_msg = "Currently dataset_name only supports {}!".format(list(transform_checker.keys()))
        return {"status": 1, "err_msg": err_msg}

    # process the original data
    ori_img = np.array(json.loads(instance['data']), dtype=instance['dtype'])
    if dataset_name in ['mnist']:
        image = trans_func(ori_img)
    else:
        # NOTE(review): assumes the incoming pixels are in BGR order — confirm with the client
        image = trans_func(cv2.cvtColor(ori_img, cv2.COLOR_BGR2RGB))
    input_data = ts.array(image.tolist(), dtype=image.dtype.name)

    if model_name == "cycle_gan":
        g_model = servable_model['g_model']
        if strategy == 'gray2color':
            ckpt_name = 'G_A'
        elif strategy == 'color2gray':
            ckpt_name = 'G_B'
        else:
            err_msg = "Currently cycle_gan strategy only supports `gray2color` and `color2gray`!"
            return {"status": 1, "err_msg": err_msg}

        ckpt_path = os.path.join(serving_path, servable_name, ckpt_name + "." + model_format)
        if not os.path.isfile(ckpt_path):
            err_msg = "The model path " + ckpt_path + " not exist!"
            return {"status": 1, "err_msg": err_msg}

        # build the network: net_func returns a pair of generators; pick the
        # one matching the requested strategy
        gen_a, gen_b = net_func(g_model=g_model)
        G_generator = gen_a if strategy == 'gray2color' else gen_b
        out_img = cyclegan_predict(G_generator, input_data, ckpt_path)
        res_msg = '原图使用{}风格迁移效果'.format(strategy)
        data = numpy2base64(out_img)
    else:
        class_num = servable_model['class_num']
        # verify the checkpoint before paying for network construction
        ckpt_path = os.path.join(serving_path, servable_name, model_name + "." + model_format)
        if not os.path.isfile(ckpt_path):
            err_msg = "The model path " + ckpt_path + " not exist!"
            return {"status": 1, "err_msg": err_msg}

        # build the network and load checkpoint
        net = net_func(class_num=class_num, is_training=False)
        serve_model = model.Model(net)
        serve_model.load_checkpoint(ckpt_path)

        # execute the network to perform model prediction on a batch of one
        output = serve_model.predict(ts.expand_dims(input_data, 0))
        if model_name == "ssd300":
            # ssd300 yields two outputs which are concatenated on the last axis
            output_np = ts.concatenate((output[0], output[1]), axis=-1).asnumpy()
            ih, iw, _ = instance['shape']
            bbox_data = trans_func.postprocess(output_np, (ih, iw), strategy)
            bbox_num = len(bbox_data)
            if not bbox_num:
                err_msg = "抱歉!未检测到任何种类,无法标注。"
                return {"status": 1, "err_msg": err_msg}
            out_img = draw_boxes_in_image(bbox_data, ori_img)
            max_det = max(bbox_data, key=lambda k: k['score'])
            res_msg = '图中共标注了:{}个框,其中物种{}的得分最高, 为{}。'.format(
                bbox_num, max_det['category_id'], round(max_det['score'], 3))
            data = numpy2base64(cv2.cvtColor(out_img, cv2.COLOR_BGR2RGB))
        else:
            output_np = output.asnumpy()
            res_msg = trans_func.postprocess(output_np, strategy)
            data = numpy2base64(ori_img)

    return {
        "status": 0,
        "instance": {
            "res_msg": res_msg,
            "data": data,
        },
    }
def predict(instance, servable_name, servable_model, strategy):
    """
    Predict the result based on the input data.

    A network is constructed from the servable description, its checkpoint is
    loaded from `serving_path/<servable_name>/`, and the raw network output
    is returned as a JSON-encoded array.

    Args:
        instance (dict): the input data, with keys `shape`, `dtype` and
            `data` (a JSON-encoded array, decoded with `json.loads`).
        servable_name (str): servable name; used as the checkpoint
            sub-directory under `serving_path`.
        servable_model (dict): model description; the keys `name` and `format`
            are always read, plus `g_model` for cycle_gan or `class_num`
            otherwise.
        strategy (str): output strategy, usually select between `TOP1_CLASS`
            and `TOP5_CLASS`; for cyclegan, select between `gray2color` and
            `color2gray`.

    Returns:
        dict, `{"status": 0, "instance": {"shape", "dtype", "data"}}` on
        success (`data` is the JSON-encoded output array), or
        `{"status": 1, "err_msg": ...}` on failure.

    Examples:
        >>> # In the server part, after servable_search
        >>> res = predict(instance, servable_name, servable['model'], strategy)
        >>> return jsonify(res)
    """
    # check if servable model name is valid
    model_name = servable_model['name']
    net_func = model_checker.get(model_name)
    if net_func is None:
        err_msg = "Currently model_name only supports " + str(list(model_checker.keys())) + "!"
        return {"status": 1, "err_msg": err_msg}

    # check if model_format is valid; an exact comparison is required because
    # `("ckpt")` is a plain string and `in` would accept any substring of it
    model_format = servable_model['format']
    if model_format != "ckpt":
        err_msg = "Currently model_format only supports `ckpt`!"
        return {"status": 1, "err_msg": err_msg}

    # parse the input data
    input_data = ts.array(json.loads(instance['data']), dtype=instance['dtype'])

    if model_name == "cycle_gan":
        g_model = servable_model['g_model']
        if strategy == 'gray2color':
            ckpt_name = 'G_A'
        elif strategy == 'color2gray':
            ckpt_name = 'G_B'
        else:
            err_msg = "Currently cycle_gan strategy only supports `gray2color` and `color2gray`!"
            return {"status": 1, "err_msg": err_msg}

        ckpt_path = os.path.join(serving_path, servable_name, ckpt_name + "." + model_format)
        if not os.path.isfile(ckpt_path):
            err_msg = "The model path " + ckpt_path + " not exist!"
            return {"status": 1, "err_msg": err_msg}

        # build the network: net_func returns a pair of generators; pick the
        # one matching the requested strategy
        gen_a, gen_b = net_func(g_model=g_model)
        G_generator = gen_a if strategy == 'gray2color' else gen_b
        data = cyclegan_predict(G_generator, input_data, ckpt_path)
    else:
        class_num = servable_model['class_num']
        # verify the checkpoint before paying for network construction
        ckpt_path = os.path.join(serving_path, servable_name, model_name + "." + model_format)
        if not os.path.isfile(ckpt_path):
            err_msg = "The model path " + ckpt_path + " not exist!"
            return {"status": 1, "err_msg": err_msg}

        # build the network and load checkpoint
        net = net_func(class_num=class_num, is_training=False)
        serve_model = model.Model(net)
        serve_model.load_checkpoint(ckpt_path)

        # execute the network to perform model prediction on a batch of one
        output = serve_model.predict(ts.expand_dims(input_data, 0))
        if model_name == "ssd300":
            # ssd300 yields two outputs which are concatenated on the last axis
            data = ts.concatenate((output[0], output[1]), axis=-1).asnumpy()
        else:
            data = output.asnumpy()

    return {
        "status": 0,
        "instance": {
            "shape": data.shape,
            "dtype": data.dtype.name,
            "data": json.dumps(data.tolist()),
        },
    }
def predict(instance, servable_name, servable_model, strategy, serving_dir="/etc/tinyms/serving"):
    """
    Predict the result based on the input data.

    A network is constructed from the servable description, its checkpoint is
    loaded from `<serving_dir>/<servable_name>/`, and the raw network output
    is returned as a JSON-encoded array.

    Args:
        instance (dict): the input data, with keys `shape`, `dtype` and
            `data` (a JSON-encoded array, decoded with `json.loads`).
        servable_name (str): servable name; used as the checkpoint
            sub-directory under `serving_dir`.
        servable_model (dict): model description; the keys `name` and `format`
            are always read, plus `g_model` for cycle_gan or `class_num`
            otherwise.
        strategy (str): for cycle_gan, either `gray2color` or `color2gray`;
            not read by the other model branches.
        serving_dir (str): root directory holding the servable checkpoints.
            Default: "/etc/tinyms/serving".

    Returns:
        dict, `{"status": 0, "instance": {"shape", "dtype", "data"}}` on
        success (`data` is the JSON-encoded output array), or
        `{"status": 1, "err_msg": ...}` on failure.
    """
    # check if servable model name is valid
    model_name = servable_model['name']
    net_func = model_checker.get(model_name)
    if net_func is None:
        err_msg = "Currently model_name only supports " + str(
            list(model_checker.keys())) + "!"
        return {"status": 1, "err_msg": err_msg}

    # check if model_format is valid; an exact comparison is required because
    # `("ckpt")` is a plain string and `in` would accept any substring of it
    model_format = servable_model['format']
    if model_format != "ckpt":
        err_msg = "Currently model_format only supports `ckpt`!"
        return {"status": 1, "err_msg": err_msg}

    # parse the input data
    input_data = ts.array(json.loads(instance['data']), dtype=instance['dtype'])

    if model_name == "cycle_gan":
        g_model = servable_model['g_model']
        if strategy == 'gray2color':
            ckpt_name = 'G_A'
        elif strategy == 'color2gray':
            ckpt_name = 'G_B'
        else:
            err_msg = "Currently cycle_gan strategy only supports `gray2color` and `color2gray`!"
            return {"status": 1, "err_msg": err_msg}

        ckpt_path = os.path.join(serving_dir, servable_name, ckpt_name + "." + model_format)
        if not os.path.isfile(ckpt_path):
            err_msg = "The model path " + ckpt_path + " not exist!"
            return {"status": 1, "err_msg": err_msg}

        # build the network: net_func returns a pair of generators; pick the
        # one matching the requested strategy
        gen_a, gen_b = net_func(g_model=g_model)
        G_generator = gen_a if strategy == 'gray2color' else gen_b
        data = cyclegan_predict(G_generator, input_data, ckpt_path)
    else:
        class_num = servable_model['class_num']
        # verify the checkpoint before paying for network construction
        ckpt_path = os.path.join(serving_dir, servable_name, model_name + "." + model_format)
        if not os.path.isfile(ckpt_path):
            err_msg = "The model path " + ckpt_path + " not exist!"
            return {"status": 1, "err_msg": err_msg}

        # build the network and load checkpoint
        net = net_func(class_num=class_num)
        serve_model = model.Model(net)
        serve_model.load_checkpoint(ckpt_path)

        # execute the network to perform model prediction on a batch of one
        output = serve_model.predict(ts.expand_dims(input_data, 0))
        if model_name == "ssd300":
            # ssd300 yields two outputs which are concatenated on the last axis
            data = ts.concatenate((output[0], output[1]), axis=-1).asnumpy()
        else:
            data = output.asnumpy()

    return {
        "status": 0,
        "instance": {
            "shape": data.shape,
            "dtype": data.dtype.name,
            "data": json.dumps(data.tolist()),
        },
    }