Пример #1
0
    batch_size = args.batch_size
    test_batch_size = args.test_batch_size
    epochs = args.epochs  # 遍历数据集次数
    lr = args.lr  # 学习率
    test_interval = args.test_interval #测试间隔
    pretrained_model = args.pretrained_model #预训练模型
    net = CRAFT(pretrained=True)  # craft模型

    if args.cuda:
        net.load_state_dict(copyStateDict(torch.load(pretrained_model)))
    else:
        net.load_state_dict(copyStateDict(torch.load(pretrained_model, map_location='cpu')))

    if args.cuda:
        net = net.cuda()
        net = torch.nn.DataParallel(net)
        cudnn.benchmark = False

    net = net.to(device)
    net.train()
    model_save_prefix = 'finetune/craft_finetune_'
    try:
        train(net=net,
              epochs=epochs,
              batch_size=batch_size,
              test_batch_size=test_batch_size,
              lr=lr,test_interval=test_interval,
              test_model_path=pretrained_model,
              model_save_prefix =  model_save_prefix)
    except KeyboardInterrupt:
class Icdar2013Dataset(torch.utils.data.Dataset):
    """ICDAR2013 dataset producing CRAFT training targets from word boxes.

    Each item is ``(image, region_scores, affinity_scores, sc_map)``. The
    character-level targets are derived from word-level annotations with the
    help of an auxiliary CRAFT network loaded from ``model_path``.
    """

    def __init__(self,
                 cuda=False,
                 image_transform=None,
                 label_transform=None,
                 target_transform=None,
                 model_path=None,
                 images_dir=None,
                 labels_dir=None):
        super(Icdar2013Dataset, self).__init__()
        self.model_path = model_path
        # Image file names and label entries, paired by index;
        # __getitem__ joins each entry with images_dir / labels_dir.
        self.images_dir = images_dir
        self.labels_dir = labels_dir
        self.image_names, self.label_names = load_icdar2013(
            self.images_dir, self.labels_dir)

        # Auxiliary CRAFT network, used in fake_char_boxes to score crops.
        self.craft = CRAFT()
        self.cuda = cuda
        # map_location=None is torch.load's default; 'cpu' forces CPU tensors.
        map_location = None if self.cuda else 'cpu'
        self.craft.load_state_dict(
            copyStateDict(torch.load(self.model_path,
                                     map_location=map_location)))
        if self.cuda:
            self.craft = self.craft.cuda()
            # NOTE(review): self.net is never used inside this class;
            # test_net is always called with self.craft.
            self.net = torch.nn.DataParallel(self.craft)
            cudnn.benchmark = False

        self.craft.eval()

        self.image_transform = image_transform
        self.label_transform = label_transform
        # NOTE(review): stored but not applied anywhere in this class.
        self.target_transform = target_transform

    def __len__(self):
        return len(self.image_names)

    # Labels are Gaussian heat maps.
    def __getitem__(self, idx):
        """Return ``(image, region_scores, affinity_scores, sc_map)`` for idx.

        The three maps are scaled by 255 and converted to uint8 PIL images
        before ``label_transform`` is applied to each of them.
        """
        fn = self.image_names[idx]
        image = load_image(os.path.join(self.images_dir, fn))
        label_name = self.label_names[idx]
        word_boxes, words = get_wordsList(
            os.path.join(self.labels_dir, label_name))
        char_boxes_list, affinity_boxes_list, confidence_list = \
            self.get_affinity_boxes_list(image, word_boxes, words)
        height, width = image.shape[:2]  # (rows, cols), OpenCV layout
        heat_map_size = (height, width)
        # Pixel-wise confidence map.
        sc_map = self.get_sc_map(heat_map_size, word_boxes,
                                 confidence_list) * 255
        region_scores = self.get_region_scores(heat_map_size,
                                               char_boxes_list) * 255
        affinity_scores = self.get_region_scores(heat_map_size,
                                                 affinity_boxes_list) * 255

        # OpenCV BGR array -> RGB PIL.Image.
        image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
        # numpy.ndarray -> PIL.Image.
        region_scores = Image.fromarray(np.uint8(region_scores))
        affinity_scores = Image.fromarray(np.uint8(affinity_scores))
        sc_map = Image.fromarray(np.uint8(sc_map))
        if self.image_transform is not None:
            image = self.image_transform(image)

        if self.label_transform is not None:
            region_scores = self.label_transform(region_scores)
            affinity_scores = self.label_transform(affinity_scores)
            sc_map = self.label_transform(sc_map)
        return image, region_scores, affinity_scores, sc_map

    @staticmethod
    def _divide_word_box(word_box, word_length):
        """Split word_box with divide_region and reorder each piece's points."""
        return [
            reorder_points(region_box)
            for region_box in divide_region(word_box, word_length)
        ]

    def fake_char_boxes(self, src, word_box, word_length):
        """Produce ``word_length`` character boxes and a confidence for a word.

        The word region is cropped from ``src`` and scored with the auxiliary
        CRAFT network. The watershed boxes found in the region score are passed
        to cal_confidence together with ``word_length``. Confidences at or
        below 0.5 are raised to 0.5; an empty crop gets 0.5 directly.

        NOTE(review): the returned boxes always come from divide_region on
        ``word_box``. The watershed boxes from find_box only influence the
        confidence. Using them would presumably require mapping them back via
        src_points/crop_points — TODO confirm this is intended.

        Returns:
            (region_boxes, confidence)
        """
        img, src_points, crop_points = crop_image(src,
                                                  word_box,
                                                  dst_height=64.)
        h, w = img.shape[:2]
        if min(h, w) == 0:
            # Degenerate crop: nothing to run the network on.
            return self._divide_word_box(word_box, word_length), 0.5

        img = img_normalize(img)
        region_score, _affinity_score = self.test_net(self.craft, img,
                                                      self.cuda)
        heat_map = (region_score * 255.).astype(np.uint8)
        marker_map = watershed(heat_map)
        detected_boxes = find_box(marker_map)
        confidence = cal_confidence(detected_boxes, word_length)
        if confidence <= 0.5:
            confidence = 0.5
        return self._divide_word_box(word_box, word_length), confidence

    def get_affinity_boxes_list(self, image, word_boxes, words):
        """Compute character boxes, affinity boxes and per-word confidences.

        Returns:
            (char_boxes_list, affinity_boxes_list, confidence_list). The first
            two are flattened across all words; confidence_list has one entry
            per word.
        """
        char_boxes_list = []
        affinity_boxes_list = []
        confidence_list = []

        for word_box, word in zip(word_boxes, words):
            char_boxes, confidence = self.fake_char_boxes(
                image, word_box, len(word))
            affinity_boxes = cal_affinity_boxes(char_boxes)
            affinity_boxes_list.append(affinity_boxes)
            char_boxes_list.append(char_boxes)
            confidence_list.append(confidence)

        char_boxes_list = list(chain.from_iterable(char_boxes_list))
        affinity_boxes_list = list(chain.from_iterable(affinity_boxes_list))
        return char_boxes_list, affinity_boxes_list, confidence_list

    def get_region_scores(self, heat_map_size, char_boxes_list):
        """Render a Gaussian heat map of ``heat_map_size`` for the given boxes."""
        gaussian_generator = GaussianGenerator()
        char_boxes_list = np.array(char_boxes_list, dtype=np.float32)
        region_scores = gaussian_generator.gen(heat_map_size, char_boxes_list)
        return region_scores

    def get_sc_map(self, heat_map_size, word_boxes, confidence_list):
        """Build the pixel-wise confidence map Sc.

        Every pixel starts at 1.0. The axis-aligned rectangle between point 0
        (top-left) and point 2 (bottom-right) of each word box is filled with
        that word's confidence. Coordinates are clamped to the map, so boxes
        that partly leave the image still mark their visible part.

        :param heat_map_size: (height, width) of the map.
        :param word_boxes: boxes convertible to an int array indexed as
            word_box[point, coord] with coord 0 = x and 1 = y.
        :param confidence_list: one confidence value per word box.
        :return: float32 array of shape heat_map_size.
        """
        height, width = heat_map_size[0], heat_map_size[1]
        word_boxes = np.array(word_boxes, dtype=int)
        sc_map = np.ones(heat_map_size, dtype=np.float32)
        for word_box, confidence in zip(word_boxes, confidence_list):
            # Clamp: a negative slice bound would wrap around from the end of
            # the axis and fill the wrong region (or nothing at all).
            x_left = min(max(int(word_box[0, 0]), 0), width)
            y_top = min(max(int(word_box[0, 1]), 0), height)
            x_right = min(max(int(word_box[2, 0]), 0), width)
            y_down = min(max(int(word_box[2, 1]), 0), height)
            sc_map[y_top:y_down, x_left:x_right] = confidence
        return sc_map

    def test_net(self, net, image, cuda, canvas_size=1280, mag_ratio=1.5):
        """Run ``net`` on one image and return (region score, affinity score).

        The image is resized with imgproc.resize_aspect_ratio (``canvas_size``,
        ``mag_ratio``) and normalized before the forward pass. The pass runs
        under torch.no_grad(). Both maps are returned as numpy arrays taken
        from channels 0 and 1 of the network's first output.
        """
        img_resized, _, _ = imgproc.resize_aspect_ratio(
            image, canvas_size, interpolation=cv2.INTER_LINEAR,
            mag_ratio=mag_ratio)
        x = imgproc.normalizeMeanVariance(img_resized)
        x = torch.from_numpy(x).permute(2, 0, 1)  # [h, w, c] -> [c, h, w]
        x = x.unsqueeze(0)  # [c, h, w] -> [b, c, h, w]
        if cuda:
            x = x.cuda()
        with torch.no_grad():
            y, _ = net(x)
        # Split the output into score (text) and link (affinity) maps.
        score_text = y[0, :, :, 0].detach().cpu().numpy()
        score_link = y[0, :, :, 1].detach().cpu().numpy()
        return score_text, score_link