def detect(self, img):
    """Detect text boxes in ``img`` and report them at the input image's scale.

    The image is passed through ``resize_img`` and ``img_normailize``, fed to
    the network via ``self._get_net_output``, and the resulting proposals are
    merged by ``TextConnector.detect``. Timing and debug information is
    printed along the way.

    Args:
        img: Image accepted by ``resize_img``; after normalization it must
            have a three-element ``shape`` (unpacked as h, w, c).

    Returns:
        tuple: ``(pp_boxes / scale, original_bbox, scores)`` where
        ``pp_boxes`` are the raw network boxes divided by the resize scale,
        and ``original_bbox`` / ``scores`` are what ``self._resize_bbox``
        returns for the merged boxes.
    """
    resized, scale = resize_img(img)
    print('shape', resized.shape)
    print('scale', scale)

    normalized = img_normailize(resized)
    height, width, channels = normalized.shape
    # Add a leading axis of size 1 before handing the image to the network.
    net_input = np.reshape(normalized, [1, height, width, channels])
    # NOTE(review): the trailing 1 is presumably the scale factor expected by
    # the network's image-info input -- confirm against _get_net_output.
    img_info = [height, width, 1]

    net_start = time.time()
    scores, pp_boxes = self._get_net_output(net_input, img_info)
    print('net:', time.time() - net_start)
    print(scores)

    # The merge timer deliberately includes TextConnector construction.
    merge_start = time.time()
    connector = TextConnector()
    # Boxes produced here are still expressed in resized-image coordinates.
    _, scores, boxes = connector.detect(
        pp_boxes, scores[:, np.newaxis], img_info[:2])
    print('merge:', time.time() - merge_start)

    # Convert to absolute box positions in the original image.
    original_bbox, scores = self._resize_bbox(boxes, scale)
    return pp_boxes / scale, original_bbox, scores
def getbatch(self):
    """Load the current image with its labels and advance the dataset cursor.

    Reads ``self.img_path_list[self.current_index]`` from ``self.img_dir``,
    converts it from BGR to RGB, and passes it through ``resize_img`` and
    ``img_normailize``. The label file shares the image's stem (``.txt``
    extension) and lives in ``self.label_dir``; each non-blank line is
    tab-separated as ``<class> <v1> <v2> <v3> <v4>``, with the class name
    mapped to an id through ``self.class_dict``.

    After loading, the cursor moves to the next image. When the last image
    has just been served, ``self.epoch`` is incremented, the cursor resets to
    0 and ``self.img_path_list`` is shuffled in place.

    Returns:
        tuple: ``(img_data, labels, img_shape)`` where ``img_data`` is the
        processed image reshaped to ``[1, h, w, c]``, ``labels`` is an integer
        array with one ``[v1, v2, v3, v4, class_id]`` row per label line
        (float values are truncated by the integer cast), and ``img_shape``
        is the processed image's shape reshaped to ``[1, 3]``.

    Raises:
        OSError: If the image cannot be read/decoded by ``cv2.imread``.
        AssertionError: If the label file does not exist.
    """
    img_name = self.img_path_list[self.current_index]
    img_path = os.path.join(self.img_dir, img_name)

    img = cv2.imread(img_path)
    # cv2.imread signals failure by returning None instead of raising; fail
    # here with the offending path rather than deep inside cvtColor.
    if img is None:
        raise OSError("failed to read image: {}".format(img_path))
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    img, _ = resize_img(img)
    img = img_normailize(img)
    h, w, c = img.shape
    img_data = np.reshape(img, [1, h, w, c])

    # splitext strips only the final extension, so names containing extra
    # dots (e.g. "a.b.jpg") map to "a.b.txt" instead of colliding on "a.txt".
    label_name = os.path.splitext(img_name)[0] + '.txt'
    label_path = os.path.join(self.label_dir, label_name)
    assert os.path.exists(label_path), \
        "{} is not exist".format(label_path)

    labels_data = []
    with open(label_path) as f:
        for line in f:
            # Tolerate blank lines (e.g. a trailing newline at end of file).
            if not line.strip():
                continue
            fields = line.split('\t')
            labels_data.append([
                float(fields[1]),
                float(fields[2]),
                float(fields[3]),
                float(fields[4]),
                int(self.class_dict[fields[0]]),
            ])

    # Advance to the next image, or start a new shuffled epoch at the end.
    if self.current_index + 1 < len(self.img_path_list):
        self.current_index += 1
    else:
        self.epoch += 1
        self.current_index = 0
        random.shuffle(self.img_path_list)

    # ``np.int`` was removed in NumPy 1.24; the builtin ``int`` is the type it
    # aliased, so the resulting dtype is unchanged.
    return (img_data,
            np.array(labels_data, int),
            np.array(img.shape).reshape([1, 3]))
if basename.lower().split('.')[-1] not in [ 'jpg', 'png', 'JPG', 'JPEG', 'jpeg', 'PNG' ]: print(basename.lower().split('.')[-1]) continue stem, ext = os.path.splitext(basename) txt_file = os.path.join(xml_dir, stem + '.txt') img_path = os.path.join(img_dir, file) # print(img_path) img = cv.imread(img_path) img_size = img.shape im_size_min = np.min(img_size[0:2]) im_size_max = np.max(img_size[0:2]) re_im, im_scale = resize_img(img) re_size = re_im.shape cv.imwrite(os.path.join(out_path, stem) + '.jpg', re_im) class_list, bbox_list = parse_txt(txt_file) assert len(class_list) == len(bbox_list), 'bbox和label不对应' assert len(class_list) > 0, 'xml文件有问题{}'.format(txt_file) for bbox_index in range(len(bbox_list)): # if class_list[bbox_index] == 2: # continue if len(bbox_list[bbox_index]) == 8: