def predict_common_func1(final, image_paths, params, out_dir):
    """Draw each image's predicted optic-disc box and save the result as JPEG.

    Args:
        final: Model output whose ``.data.cpu().numpy()`` yields one row per
            image. Only the first six columns are used: indices 0-3 are drawn
            as a rectangle ``(x1, y1, x2, y2)`` and indices 4-5 as a point
            ``(x, y)`` with a fixed 40x40 square around it.
        image_paths: Image paths, in the same order as the rows of ``final``.
        params: Per-image transform parameters. Unused while the inverse
            transform (``pts_trans_inv``) stays disabled; kept so callers do
            not break.
        out_dir: Output directory (created if missing). Each image is written
            to ``<basename up to its first '.'>_det.jpg``.
    """
    pred_od_bboxs = final.data.cpu().numpy()
    os.makedirs(out_dir, exist_ok=True)
    # Pair every path with its own prediction row. Previously every iteration
    # re-drew image_paths[0] with row 0 and only the output file name changed.
    for image_file, pred in zip(image_paths, pred_od_bboxs):
        # Draw on the image rescaled by scale_image(..., 512). No inverse
        # transform is applied to the boxes, which presumes the predictions
        # are already in that scaled frame (TODO confirm).
        with Image.open(image_file) as pil_img:
            scaled, _, _, _ = scale_image(pil_img, 512)
            raw = np.array(scaled)
        # Expects a 3-channel RGB image; OpenCV drawing/saving works in BGR.
        im2show = cv2.cvtColor(raw, cv2.COLOR_RGB2BGR)
        bbox = [int(x) for x in pred]
        # Disabled: map the box back to original-image coordinates, e.g.
        # bbox = pts_trans_inv(bbox, params[0][i], params[1][i], params[2][i])
        cv2.rectangle(im2show, (bbox[0], bbox[1]), (bbox[2], bbox[3]),
                      (255, 255, 0), 4)
        cx, cy = bbox[4], bbox[5]
        cv2.rectangle(im2show, (cx - 20, cy - 20), (cx + 20, cy + 20),
                      (255, 255, 0), 4)
        cv2.circle(im2show, (cx, cy), 4, (0, 255, 255))
        # NOTE(review): per the OpenCV docs imshow only renders when followed
        # by waitKey; kept as-is to preserve the original side effect.
        cv2.imshow('test', im2show)
        out_file = os.path.join(
            out_dir, os.path.basename(image_file).split('.')[0] + '_det.jpg')
        cv2.imwrite(out_file, im2show)
def cls_predict(val_data_loader, model, criterion, display):
    """Evaluate optic-disc detection over a validation loader.

    For every batch the model is run once and its predictions are scored
    against the ground-truth boxes with ``get_detect_od_array``. The first
    image of the batch is shown with its predicted box for 2 s.

    Args:
        val_data_loader: Iterable yielding
            ``(images, _, image_paths, bboxs, bboxs_c, params)``; ``bboxs``
            must support ``.numpy()``.
        model: Network returning a pair whose first element holds one
            prediction row per image (indices 0-3 box, 4-5 point).
        criterion: Unused (loss computation is disabled).
        display: Unused. NOTE(review): visualization runs unconditionally;
            this flag was presumably meant to gate it — confirm with callers.

    Returns:
        list[str]: the accuracy summary line, followed by a line listing the
        images whose IoU is below 0.5.
    """
    model.eval()
    batch_time = AverageMeter()
    data_time = AverageMeter()
    result = np.array([], dtype=int)
    ious = np.array([], dtype=float)
    images_list = []
    logger = []
    end = time.time()
    for images, _, image_paths, bboxs, _bboxs_c, _params in val_data_loader:
        images_list.extend(image_paths)
        data_time.update(time.time() - end)
        # The model's second output is not used here (it was previously bound
        # to `map`, shadowing the builtin).
        final, _ = model(Variable(images))
        batch_time.update(time.time() - end)
        end = time.time()

        pred_od_bboxs = final.data.cpu().numpy()
        tmp, tmp_ious = get_detect_od_array(pred_od_bboxs, bboxs.numpy())
        result = np.append(result, tmp)
        ious = np.append(ious, tmp_ious)

        # Visualize the first image of the batch on its 512-rescaled version.
        with Image.open(image_paths[0]) as pil_img:
            scaled, _, _, _ = scale_image(pil_img, 512)
            raw = np.array(scaled)
        im2show = cv2.cvtColor(raw, cv2.COLOR_RGB2BGR)
        bbox = [int(x) for x in pred_od_bboxs[0]]
        cv2.rectangle(im2show, (bbox[0], bbox[1]), (bbox[2], bbox[3]),
                      (255, 255, 0), 4)
        cx, cy = bbox[4], bbox[5]
        cv2.rectangle(im2show, (cx - 20, cy - 20), (cx + 20, cy + 20),
                      (255, 255, 0), 4)
        cv2.circle(im2show, (cx, cy), 4, (0, 255, 255))
        cv2.imshow('test', im2show)
        cv2.waitKey(2000)

    # An empty loader yields nan, exactly as numpy's 0/0 did, minus the warning.
    accuracy = result.sum() / len(result) if len(result) else float('nan')
    print_info = '[optic_disc detection]:\tthreshold:{}\tdetection accuracy:{}'.format(0.5, accuracy)
    print(print_info)
    logger.append(print_info)

    # IoUs are mapped back to file names by position, so counts must match.
    assert len(ious) == len(images_list)
    error_thres = 0.5
    error_image_list = [
        path for path, iou in zip(images_list, ious) if iou < error_thres
    ]
    print(error_image_list)
    # Previously ''.format(...) was logged, which always produced ''.
    logger.append('[optic_disc detection]:\terror images (iou<{}):{}'.format(
        error_thres, error_image_list))
    return logger
def single_processing(image_file, data_root, p_data_root):
    """Rescale one image and its XML box annotations into ``p_data_root``.

    The annotation is looked up as ``<data_root>/<stem>.xml``, where
    ``<stem>`` is the image basename up to its first '.'. If that file does
    not exist, nothing is written and the function returns immediately.

    Otherwise:
      * the image is passed through ``scale_image(img, 512)`` and saved as
        ``<p_data_root>/<stem>.png``;
      * the ``bndbox`` of every object named ``optic_disk``, ``optic_disc``,
        ``optic-disc`` or ``macular`` is mapped with ``pts_trans`` using the
        extra values returned by ``scale_image`` (coordinates truncated to
        int). All other objects are left unchanged;
      * the updated tree is written to ``<p_data_root>/<stem>.xml``.

    ``p_data_root`` is not created here; it must already exist.

    Returns:
        None.
    """
    target_labels = {'optic_disk', 'optic_disc', 'optic-disc', 'macular'}
    box_fields = ('xmin', 'ymin', 'xmax', 'ymax')

    stem = os.path.basename(image_file).split('.')[0]
    xml_file = os.path.join(data_root, stem + '.xml')
    if not os.path.exists(xml_file):
        return
    p_image_file = os.path.join(p_data_root, stem + '.png')
    p_xml_file = os.path.join(p_data_root, stem + '.xml')

    tree = ET.parse(xml_file)
    with Image.open(image_file) as src:
        # Presumably (left, upper) offset and scale ratio — see scale_image.
        scaled, left, upper, ratio = scale_image(src, 512)
        # ElementTree.getiterator() was removed in Python 3.9; iter() is the
        # equivalent.
        for obj in tree.iter('object'):
            if obj.find('name').text not in target_labels:
                continue
            bndbox = obj.find('bndbox')
            coords = np.array([int(bndbox.find(f).text) for f in box_fields],
                              dtype=int)
            new_coords = [int(v) for v in pts_trans(coords, left, upper, ratio)]
            for field, value in zip(box_fields, new_coords):
                bndbox.find(field).text = str(value)
        # Save while the source file is still open, in case scale_image
        # returned the (lazily loaded) source image itself.
        scaled.save(p_image_file)
    tree.write(p_xml_file)