コード例 #1
0
    def detect(self, img):
        """Detect text boxes in one image and map them back to its coordinates.

        Pipeline:
          1. resize with ``resize_img`` (which also returns the scale factor),
          2. normalize with ``img_normailize`` and add a batch axis of size 1,
          3. run the network via ``self._get_net_output``,
          4. merge the proposals with ``TextConnector.detect`` (the boxes it
             returns are in resized-image coordinates),
          5. convert the merged boxes to absolute original-image positions
             with ``self._resize_bbox``.

        Debug output (resized shape, scale, per-stage timings and the raw
        network scores) is printed to stdout.

        Args:
            img: input image array; presumably H x W x C, since the resized
                image is unpacked as ``h, w, c`` -- TODO confirm colour order.

        Returns:
            tuple: ``(pp_boxes / scale, original_bbox, scores)``. These are:
            the raw network proposal boxes divided by the resize scale
            (presumably to undo the resize -- confirm against ``resize_img``);
            the merged boxes from ``self._resize_bbox``; and the scores that
            ``self._resize_bbox`` returns alongside them.
        """
        resized, scale = resize_img(img)
        print('shape', resized.shape)
        print('scale', scale)

        normed = img_normailize(resized)
        height, width, channels = normed.shape
        # The network expects a batch, so prepend a batch dimension of 1.
        batch = np.reshape(normed, [1, height, width, channels])
        info = [height, width, 1]

        # Stage 1: forward pass.
        t_start = time.time()
        net_scores, pp_boxes = self._get_net_output(batch, info)
        print('net:', time.time() - t_start)
        print(net_scores)

        # Stage 2: merge proposals. The connector is built inside the timed
        # region, so its construction counts towards the 'merge' timing.
        t_start = time.time()
        connector = TextConnector()
        _, merged_scores, merged_boxes = connector.detect(
            pp_boxes, net_scores[:, np.newaxis], info[:2])
        print('merge:', time.time() - t_start)

        # Map merged boxes (resized-image coordinates) to absolute positions
        # in the original image.
        original_bbox, final_scores = self._resize_bbox(merged_boxes, scale)

        return pp_boxes / scale, original_bbox, final_scores
コード例 #2
0
    def getbatch(self):
        """Load the sample at ``self.current_index`` and advance the cursor.

        The image ``self.img_dir/<name>`` is read with OpenCV and converted
        from BGR to RGB. It is then resized with ``resize_img`` (the scale is
        discarded) and normalized with ``img_normailize``. Its labels are
        read from ``self.label_dir/<stem>.txt``, where ``<stem>`` is the
        image name without its final extension.

        Each non-blank label line is tab-separated. Field 0 is a class name
        looked up in ``self.class_dict``; fields 1-4 are numeric
        (presumably box coordinates -- TODO confirm their order). Labels are
        not rescaled here; presumably the label files already match the
        resized image -- NOTE(review): confirm.

        Side effects: ``self.current_index`` is advanced. After the last
        image in the list, ``self.epoch`` is incremented, the index wraps to
        0 and ``self.img_path_list`` is shuffled in place.

        Returns:
            tuple: ``(img_data, labels, shape)``, where:
              - ``img_data`` is the normalized image reshaped to
                ``[1, h, w, c]``;
              - ``labels`` is an integer array with one row per label line,
                ``[f1, f2, f3, f4, class_id]`` (the float fields are
                truncated to int);
              - ``shape`` is the image shape as a ``[1, 3]`` array.

        Raises:
            FileNotFoundError: if the image cannot be read by OpenCV.
            AssertionError: if the label file does not exist.
        """
        img_name = self.img_path_list[self.current_index]
        img_path = os.path.join(self.img_dir, img_name)
        img = cv2.imread(img_path)
        if img is None:
            # cv2.imread signals failure by returning None instead of raising;
            # fail here with a clear message rather than inside cvtColor.
            raise FileNotFoundError('cannot read image {}'.format(img_path))
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img, _ = resize_img(img)
        img = img_normailize(img)
        h, w, c = img.shape
        img_data = np.reshape(img, [1, h, w, c])

        # splitext strips only the final extension, so names such as
        # "a.b.jpg" or "./a.jpg" map to the correct label stem.
        label_name = os.path.splitext(img_name)[0] + '.txt'
        label_path = os.path.join(self.label_dir, label_name)
        assert os.path.exists(label_path), \
            "{} is not exist".format(label_path)

        labels_data = []
        with open(label_path) as f:
            for line in f:
                if not line.strip():
                    # Tolerate blank lines (e.g. a trailing empty line).
                    continue
                fields = line.split('\t')
                labels_data.append([
                    float(fields[1]),
                    float(fields[2]),
                    float(fields[3]),
                    float(fields[4]),
                    int(self.class_dict[fields[0]]),
                ])

        # Advance; after the last image start a new epoch with a fresh order.
        if self.current_index + 1 < len(self.img_path_list):
            self.current_index += 1
        else:
            self.epoch += 1
            self.current_index = 0
            random.shuffle(self.img_path_list)

        # Builtin int: np.int was only an alias for it and is removed in
        # NumPy >= 1.24.
        return img_data, np.array(labels_data, int), np.array(img.shape).reshape([1, 3])
コード例 #3
0
    if basename.lower().split('.')[-1] not in [
            'jpg', 'png', 'JPG', 'JPEG', 'jpeg', 'PNG'
    ]:
        print(basename.lower().split('.')[-1])
        continue
    stem, ext = os.path.splitext(basename)
    txt_file = os.path.join(xml_dir, stem + '.txt')
    img_path = os.path.join(img_dir, file)
    # print(img_path)

    img = cv.imread(img_path)
    img_size = img.shape
    im_size_min = np.min(img_size[0:2])
    im_size_max = np.max(img_size[0:2])

    re_im, im_scale = resize_img(img)

    re_size = re_im.shape
    cv.imwrite(os.path.join(out_path, stem) + '.jpg', re_im)

    class_list, bbox_list = parse_txt(txt_file)

    assert len(class_list) == len(bbox_list), 'bbox和label不对应'

    assert len(class_list) > 0, 'xml文件有问题{}'.format(txt_file)

    for bbox_index in range(len(bbox_list)):
        # if class_list[bbox_index] == 2:
        #     continue

        if len(bbox_list[bbox_index]) == 8: