import sys

from defect_detection.config import Config
from defect_detection.env import set_runtime_environment
from defect_detection.model.callbacks import get_callbacks
from defect_detection.model.generator import get_generators
from defect_detection.model.loss import getLoss
from defect_detection.model.optimizer import get_optimizer
from defect_detection.model.resnet import resnet_retinanet

set_runtime_environment()

# Load the RetinaNet configuration.
config = Config('configRetinaNet.json')

# Only ResNet backbones are supported by this script.
if config.type.startswith('resnet'):
    model, bodyLayers = resnet_retinanet(num_classes=len(config.classes),
                                         backbone=config.type,
                                         weights='imagenet',
                                         nms=True,
                                         config=config)
    model.summary()
else:
    print("不存在相关网络({})".format(config.type))
    # Use sys.exit rather than the site-provided exit(): the latter is intended
    # for the interactive shell and is not defined when Python runs with -S.
    sys.exit(1)

print("backend: ", config.type)

# Load pretrained weights by layer name; with skip_mismatch=True, layers that do
# not match are skipped (Keras `load_weights` semantics).
model.load_weights('./h5/result.h5', by_name=True, skip_mismatch=True)
# wpath = '../taurus_cv/pretrained_models/resnet50_weights_tf_dim_ordering_tf_kernels_notop.h5'
# model.load_weights(wpath, by_name=True, skip_mismatch=True)
# if config.do_freeze_layers:
import sys

from defect_detection.config import Config
from taurus_cv.models.faster_rcnn.utils import np_utils, eval_utils
from taurus_cv.utils.spe import spe

# NOTE(review): `time`, `os` and `resnet_retinanet` are not imported in this
# visible span; presumably they are imported earlier in the file.
time_start = time.time()

config = Config('configRetinaNet.json')
wname = 'BASE'  # label printed after the trained weights are loaded
wpath = config.trained_weights_path
classes = config.classes

# Only ResNet backbones are supported.
if config.type.startswith('resnet'):
    model, _ = resnet_retinanet(len(classes), backbone=config.type,
                                weights='imagenet', nms=True, config=config)
else:
    # Report which backbone is unsupported (same wording as the training
    # script), then abort. sys.exit is used because the site-provided exit()
    # is meant for the interactive shell only.
    print("不存在相关网络({})".format(config.type))
    sys.exit(1)

print("backend: ", config.type)

if os.path.isfile(wpath):
    model.load_weights(wpath, by_name=True, skip_mismatch=True)
    print("权重" + wname)
else:
    # NOTE(review): execution continues with only the backbone's 'imagenet'
    # initialisation; confirm that evaluating without trained weights is intended.
    print("未找到权重文件: {}".format(wpath))
def _select_detections(detections, image_shape, scale,
                       score_threshold=0.05, max_detections=100):
    """Clip, rescale and filter the raw detections of the first image in a batch.

    Args:
        detections: array produced by the model, indexed here as
            ``[batch, box, column]`` where columns 0-3 are x1, y1, x2, y2 and
            columns 4+ are per-class scores (layout inferred from the indexing
            below -- TODO confirm against the model head). Modified in place:
            boxes are clipped and divided by ``scale``.
        image_shape: shape of the resized network input, ``(height, width, ...)``.
        scale: resize factor returned by ``resize_image``; boxes are divided by it
            to map them back to original-image coordinates.
        score_threshold: minimum per-class score kept.
        max_detections: maximum number of (box, class) pairs returned.

    Returns:
        ``(boxes, scores, labels)`` sorted by descending score: ``boxes`` has
        shape ``(k, 4)``, ``scores`` has shape ``(k, 1)``, and ``labels`` holds
        the class index of each row.
    """
    # Keep boxes inside the resized image before undoing the resize.
    detections[:, :, 0] = np.maximum(0, detections[:, :, 0])
    detections[:, :, 1] = np.maximum(0, detections[:, :, 1])
    detections[:, :, 2] = np.minimum(image_shape[1], detections[:, :, 2])
    detections[:, :, 3] = np.minimum(image_shape[0], detections[:, :, 3])
    detections[0, :, :4] /= scale

    class_scores = detections[0, :, 4:]
    # Every (box, class) pair above the threshold is a candidate detection.
    box_idx, class_idx = np.where(class_scores >= score_threshold)
    order = np.argsort(-class_scores[box_idx, class_idx])[:max_detections]
    box_idx = box_idx[order]
    class_idx = class_idx[order]

    boxes = detections[0, box_idx, :4]
    scores = np.expand_dims(class_scores[box_idx, class_idx], axis=1)
    return boxes, scores, class_idx


def _write_detections(txt_file_path, boxes, scores, labels, classes):
    """Write one ``<class> <xmin> <ymin> <xmax> <ymax> <score>`` line per detection.

    The file is created (empty) even when there are no detections. Coordinates
    are truncated to int.
    """
    with open(txt_file_path, 'w', encoding='utf-8') as f:
        for box, score, label in zip(boxes, scores, labels):
            xmin, ymin, xmax, ymax = (int(v) for v in box[:4])
            f.write('{} {} {} {} {} {}\n'.format(
                classes[label], xmin, ymin, xmax, ymax, score[0]))


def inference():
    """Run the RetinaNet model over every file in ``config.test_images_path``.

    For each readable image, writes ``<image stem>.txt`` into the txt directory
    with one line per detection (``class xmin ymin xmax ymax score``), in
    original-image pixel coordinates. Unreadable files are reported and skipped.
    """
    set_runtime_environment()
    start_time = time.time()

    config = Config('configRetinaNet.json')
    wpath = config.trained_weights_path
    result_path = config.test_result_path
    # Text outputs go to a directory derived from result_path.
    # NOTE(review): if result_path does not contain 'results', txt_path is the
    # same directory as result_path.
    txt_path = result_path.replace('results', 'txt')
    classes = config.classes

    os.makedirs(result_path, exist_ok=True)
    os.makedirs(txt_path, exist_ok=True)

    model, _ = resnet_retinanet(len(classes), backbone=config.type,
                                weights='imagenet', nms=True, config=config)
    model.load_weights(wpath, by_name=True, skip_mismatch=True)

    for imgf in sorted(os.listdir(config.test_images_path)):
        imgfp = os.path.join(config.test_images_path, imgf)
        if not os.path.isfile(imgfp):
            continue
        try:
            img = read_image_bgr(imgfp)
        except Exception as err:
            # Best-effort: skip files that cannot be decoded, but say so
            # (previously a bare `except:` that also swallowed Ctrl-C).
            print("跳过无法读取的图片 '{}': {}".format(imgf, err))
            continue

        img = preprocess_image(img.copy())
        img, scale = resize_image(img, min_side=config.img_min_size,
                                  max_side=config.img_max_size)
        _, _, detections = model.predict_on_batch(np.expand_dims(img, axis=0))

        image_boxes, image_scores, image_predicted_labels = _select_detections(
            detections, img.shape, scale)

        # splitext handles any image extension, not just '.jpg'.
        txtfile = os.path.splitext(imgf)[0] + '.txt'
        _write_detections(os.path.join(txt_path, txtfile), image_boxes,
                          image_scores, image_predicted_labels, classes)

        print("生成txt '" + txtfile + "'" + ' time:{}, 目标框:{}'.format(time.time() - start_time, len(image_boxes)))
        start_time = time.time()