Example #1
0
def main():
    """Run TextSnake inference on a deploy dataset built from ``cfg.img_root``.

    Creates the dataset and loader, restores the checkpoint named after the
    backbone and ``cfg.checkepoch``, wraps the network in a ``TextDetector``
    and calls ``inference`` with ``cfg.output_dir/cfg.exp_name`` as the
    output directory.
    """
    preprocess = BaseTransform(size=cfg.input_size,
                               mean=cfg.means,
                               std=cfg.stds)
    testset = DeployDataset(image_root=cfg.img_root, transform=preprocess)
    print(cfg)
    test_loader = data.DataLoader(
        testset, batch_size=1, shuffle=False, num_workers=cfg.num_workers)

    # Build the network, then restore its weights from the experiment folder.
    model = TextNet(is_training=False, backbone=cfg.net)
    checkpoint_name = 'textsnake_{}_{}.pth'.format(model.backbone_name,
                                                   cfg.checkepoch)
    model.load_model(os.path.join(cfg.save_dir, cfg.exp_name, checkpoint_name))

    # Move the network to the configured device.
    model = model.to(cfg.device)
    if cfg.cuda:
        cudnn.benchmark = True

    detector = TextDetector(
        model, tr_thresh=cfg.tr_thresh, tcl_thresh=cfg.tcl_thresh)

    print('Start testing TextSnake.')
    inference(detector, test_loader,
              os.path.join(cfg.output_dir, cfg.exp_name))
Example #2
0
    def Model_Params(self, model_type="vgg", model_path=None, use_gpu=True):
        """Store the model settings, load the weights and place the network.

        Args:
            model_type: Backbone name passed to ``TextNet`` (stored as "net").
            model_path: Checkpoint path passed to ``TextNet.load_model``.
            use_gpu: When true, cuDNN benchmark mode is enabled and the
                ``cuda`` device is used; otherwise the ``cpu`` device.
        """
        local = self.system_dict["local"]
        local["net"] = model_type
        local["model_path"] = model_path
        local["cuda"] = use_gpu

        # Mirror the locally stored settings onto the shared config object.
        config = cfg
        local["cfg"] = config
        config.net = local["net"]
        config.cuda = local["cuda"]
        config.means = local["means"]
        config.stds = local["stds"]
        config.input_size = local["input_size"]

        network = TextNet(is_training=False, backbone=config.net)
        network.load_model(model_path)

        # Choose the device from the cuda flag, then move the network there.
        if config.cuda:
            cudnn.benchmark = True
            device_name = "cuda"
        else:
            device_name = "cpu"
        config.device = torch.device(device_name)

        local["model"] = network.to(config.device)
Example #3
0
def main():
    """Run TextSnake on the Total-Text dataset and score it with DetEval.

    Builds the ``TotalText`` dataset (``is_training=False``) from
    ``data/total-text``, restores the checkpoint named after the backbone and
    ``cfg.checkepoch``, runs ``inference`` with ``cfg.output_dir/cfg.exp_name``
    as the output directory, then invokes the DetEval script twice with two
    ``--tr`` / ``--tp`` threshold pairs.
    """
    testset = TotalText(
        data_root='data/total-text',
        ignore_list=None,
        is_training=False,
        transform=BaseTransform(size=cfg.input_size, mean=cfg.means, std=cfg.stds)
    )
    test_loader = data.DataLoader(testset, batch_size=1, shuffle=False, num_workers=cfg.num_workers)

    # Model
    # Build the network
    model = TextNet(is_training=False, backbone=cfg.net)
    # Load the trained parameters
    model_path = os.path.join(cfg.save_dir, cfg.exp_name,
                              'textsnake_{}_{}.pth'.format(model.backbone_name, cfg.checkepoch))
    model.load_model(model_path)

    # copy to cuda
    model = model.to(cfg.device)
    if cfg.cuda:
        # Let cuDNN benchmark and pick its algorithms
        cudnn.benchmark = True
    detector = TextDetector(model, tr_thresh=cfg.tr_thresh, tcl_thresh=cfg.tcl_thresh)

    print('Start testing TextSnake.')
    output_dir = os.path.join(cfg.output_dir, cfg.exp_name)
    inference(detector, test_loader, output_dir)

    # compute DetEval
    print('Computing DetEval in {}/{}'.format(cfg.output_dir, cfg.exp_name))
    deteval_script = 'dataset/total_text/Evaluation_Protocol/Python_scripts/Deteval.py'
    # The experiment name comes from cfg, like every other use in this
    # function, so main() does not depend on a name that is not defined here.
    for tr, tp in (('0.7', '0.6'), ('0.8', '0.4')):
        subprocess.call(['python', deteval_script, cfg.exp_name, '--tr', tr, '--tp', tp])
    print('End.')
Example #4
0
def main(vis_dir_path):
    """Run TextGraph inference on the dataset selected by ``cfg.exp_name``.

    Args:
        vis_dir_path: Directory passed to ``osmkdir`` before anything else.

    Raises:
        ValueError: If ``cfg.exp_name`` is not one of "Totaltext", "Ctw1500"
            or "TD500", or if ``cfg.graph_link`` is false (no other detector
            is built by this function).
    """
    osmkdir(vis_dir_path)

    # All supported datasets share the same test-time transform.
    transform = BaseTransform(size=cfg.test_size,
                              mean=cfg.means,
                              std=cfg.stds)
    if cfg.exp_name == "Totaltext":
        testset = TotalText(data_root='data/total-text-mat',
                            ignore_list=None,
                            is_training=False,
                            transform=transform)
    elif cfg.exp_name == "Ctw1500":
        testset = Ctw1500Text(data_root='data/ctw1500',
                              is_training=False,
                              transform=transform)
    elif cfg.exp_name == "TD500":
        testset = TD500Text(data_root='data/TD500',
                            is_training=False,
                            transform=transform)
    else:
        # Fail here with a clear message instead of continuing with an
        # unbound ``testset``.
        raise ValueError(
            "{} is not a supported dataset "
            "(expected Totaltext, Ctw1500 or TD500)".format(cfg.exp_name))

    test_loader = data.DataLoader(testset,
                                  batch_size=1,
                                  shuffle=False,
                                  num_workers=cfg.num_workers)

    # Model
    model = TextNet(is_training=False, backbone=cfg.net)
    model_path = os.path.join(cfg.save_dir, cfg.exp_name,
                              'TextGraph_{}.pth'.format(model.backbone_name))
    model.load_model(model_path)

    # copy to cuda
    model = model.to(cfg.device)
    if cfg.cuda:
        cudnn.benchmark = True

    if not cfg.graph_link:
        # Only the graph-based detector is constructed here; without it
        # ``detector`` would be unbound when inference() is called.
        raise ValueError(
            "cfg.graph_link is disabled, but TextDetector_graph is the only "
            "detector this function builds")
    detector = TextDetector_graph(model)

    print('Start testing TextGraph.')
    output_dir = os.path.join(cfg.output_dir, cfg.exp_name)
    inference(detector, test_loader, output_dir)
Example #5
0
class text_detection(object):
    def __init__(self):
        """Load the TextSnake network and set up the node's ROS interfaces.

        Reads the commodity list, builds the ``vgg`` TextNet (on CUDA when
        available), loads ``weights/textsnake.pth`` from the ``textsnake``
        package, then creates the publishers, services and the synchronized
        colour/depth subscriber.
        """
        self.switch = False
        r = rospkg.RosPack()
        self.path = r.get_path('textsnake')
        self.commodity_list = []
        self.read_commodity(
            r.get_path('text_msgs') + "/config/commodity_list.txt")
        self.prob_threshold = 0.90
        self.cv_bridge = CvBridge()
        self.means = (0.485, 0.456, 0.406)
        self.stds = (0.229, 0.224, 0.225)

        self.saver = False

        self.color_map = [(255, 0, 0), (0, 255, 0), (0, 0, 255), (255, 255, 0),
                          (255, 255, 255)]  # 0 90 180 270 noise

        self.objects = []
        self.network = TextNet(is_training=False, backbone='vgg')
        self.is_compressed = False

        self.cuda_use = torch.cuda.is_available()

        if self.cuda_use:
            self.network = self.network.cuda()

        model_name = "textsnake.pth"
        self.network.load_model(os.path.join(self.path, "weights/",
                                             model_name))

        self.detector = TextDetector(self.network,
                                     tr_thresh=0.6,
                                     tcl_thresh=0.4)
        self.network.eval()

        # Saver state is prepared before any service or subscriber is
        # registered, so a callback that fires right after registration never
        # sees a partially initialised instance.
        self.saver_count = 0
        if self.saver:
            saver_root = os.path.join(self.path, "saver")
            self.p_img = os.path.join(saver_root, "img")
            self.p_depth = os.path.join(saver_root, "depth")
            self.p_mask = os.path.join(saver_root, "mask")
            self.p_result = os.path.join(saver_root, "result")
            for directory in (self.p_img, self.p_depth, self.p_mask,
                              self.p_result):
                if not os.path.exists(directory):
                    os.makedirs(directory)

        #### Publisher
        self.image_pub = rospy.Publisher("~predict_img", Image, queue_size=1)
        self.img_bbox_pub = rospy.Publisher("~predict_bbox",
                                            Image,
                                            queue_size=1)
        self.predict_img_pub = rospy.Publisher("/prediction_img",
                                               Image,
                                               queue_size=1)
        self.predict_mask_pub = rospy.Publisher("/prediction_mask",
                                                Image,
                                                queue_size=1)
        self.text_detection_pub = rospy.Publisher("/text_detection_array",
                                                  text_detection_array,
                                                  queue_size=1)
        ### service
        self.predict_switch_ser = rospy.Service("~predict_switch_server",
                                                predict_switch,
                                                self.switch_callback)
        self.predict_ser = rospy.Service("~text_detection", text_detection_srv,
                                         self.srv_callback)
        ### msg filter
        image_sub = message_filters.Subscriber('/camera/color/image_raw',
                                               Image)
        depth_sub = message_filters.Subscriber(
            '/camera/aligned_depth_to_color/image_raw', Image)
        ts = message_filters.TimeSynchronizer([image_sub, depth_sub], 10)
        ts.registerCallback(self.callback)

        # Single-argument print(...) produces the same output on Python 2
        # and Python 3.
        print("============ Ready ============")
        print("TextSnake Model Parameters number: " + str(
            self.count_parameters(self.network)))

    def read_commodity(self, path):
        """Append every line of the file at ``path`` to ``self.commodity_list``.

        Only the trailing newline is removed from each line; other whitespace
        and empty lines are kept as they are.

        Args:
            path: Path of the text file to read, one entry per line.
        """
        # The context manager closes the file even if reading fails.
        with open(path, "r") as list_file:
            for line in list_file:
                self.commodity_list.append(line.rstrip('\n'))
        print("Node (text_detection): Finish reading list")

    def count_parameters(self, model):
        """Return the summed ``numel()`` of ``model``'s trainable parameters.

        A parameter is counted only when its ``requires_grad`` is truthy.
        """
        trainable_total = 0
        for param in model.parameters():
            if param.requires_grad:
                trainable_total += param.numel()
        return trainable_total

    def srv_callback(self, req):
        """Service handler: detect and recognise text in the next camera frame.

        Waits for one colour image and one aligned depth image, zero-pads the
        colour image so its height and width are multiples of 32, runs
        ``self.predict`` on the first four images returned by ``rotate_cv``
        and sends each set of contours to the ``RECOG_SRV`` service. The
        returned masks are rotated back and merged into one mask.

        Args:
            req: Service request; not read by this handler.

        Returns:
            The ``text_detection_srvResponse`` with ``image`` and ``depth``
            set, plus ``mask`` when the colour image could be converted. If
            conversion fails the response is returned early with the error
            stored on it.
        """
        resp = text_detection_srvResponse()
        img_msg = rospy.wait_for_message('/camera/color/image_raw',
                                         Image,
                                         timeout=None)
        resp.depth = rospy.wait_for_message(
            '/camera/aligned_depth_to_color/image_raw', Image, timeout=None)
        resp.image = img_msg
        try:
            if self.is_compressed:
                # frombuffer replaces the deprecated binary mode of
                # np.fromstring and yields the same uint8 view of the bytes.
                np_arr = np.frombuffer(img_msg.data, np.uint8)
                cv_image = cv2.imdecode(np_arr, cv2.IMREAD_COLOR)
            else:
                cv_image = self.cv_bridge.imgmsg_to_cv2(img_msg, "bgr8")
        except CvBridgeError as e:
            # NOTE(review): this handler writes ``resp.status`` here but
            # ``resp.state`` further down - confirm which field the srv
            # definition actually has, and whether it accepts an exception.
            resp.status = e
            print(e)
            # No image was decoded, so nothing below can run.
            return resp

        # Zero-pad so both spatial dimensions are multiples of 32.
        (rows, cols, channels) = cv_image.shape
        padded_rows = int(np.ceil(rows / 32.) * 32)
        padded_cols = int(np.ceil(cols / 32.) * 32)
        padded_image = np.zeros((padded_rows, padded_cols, channels),
                                dtype=np.uint8)
        padded_image[:rows, :cols, :channels] = cv_image
        # Work on a copy; the padded original is handed to save_func below.
        cv_image = padded_image.copy()

        mask = np.zeros([cv_image.shape[0], cv_image.shape[1]], dtype=np.uint8)
        img_list_0_90_180_270 = rotate_cv(cv_image)

        # Helper and angle used to undo each rotation; index 0 needs none.
        undo_rotation = {
            1: (rotate_back_change_h_w, -90),
            2: (rotate_back, -180),
            3: (rotate_back_change_h_w, -270),
        }

        for i in range(4):
            rotated = img_list_0_90_180_270[i]
            predict_img, contours = self.predict(rotated)
            img_bbox = rotated.copy()

            text_array = text_detection_array()
            text_array.image = self.cv_bridge.cv2_to_imgmsg(rotated, "bgr8")
            text_array.depth = resp.depth
            for _cont in contours:
                text_bb = text_detection_msg()
                for p in _cont:
                    int_array = int_arr()
                    int_array.point.append(p[0])
                    int_array.point.append(p[1])
                    text_bb.contour.append(int_array)
                cv2.drawContours(predict_img, [_cont], -1, self.color_map[i],
                                 3)
                # Plain ints keep the message fields and the cv2 point tuples
                # independent of the contour array's numpy dtype.
                text_bb.box.xmin = int(_cont[:, 0].min())
                text_bb.box.xmax = int(_cont[:, 0].max())
                text_bb.box.ymin = int(_cont[:, 1].min())
                text_bb.box.ymax = int(_cont[:, 1].max())
                text_array.text_array.append(text_bb)
                cv2.rectangle(img_bbox, (text_bb.box.xmin, text_bb.box.ymin),
                              (text_bb.box.xmax, text_bb.box.ymax),
                              self.color_map[i], 3)
            text_array.bb_count = len(text_array.text_array)
            # self.text_detection_pub.publish(text_array)

            recog_req = text_recognize_srvRequest()
            try:
                rospy.wait_for_service(RECOG_SRV, timeout=10)
                recog_req.data = text_array
                recog_req.direct = i
                recognition_srv = rospy.ServiceProxy(RECOG_SRV,
                                                     text_recognize_srv)
                recog_resp = recognition_srv(recog_req)
            except (rospy.ServiceException, rospy.ROSException) as e:
                resp.state = e
                # No recognition result for this rotation: skip its mask
                # rather than converting an empty default response.
                recog_mask = None
            else:
                recog_mask = self.cv_bridge.imgmsg_to_cv2(
                    recog_resp.mask, "8UC1")

            if i in undo_rotation:
                rotate_fn, angle = undo_rotation[i]
                if recog_mask is not None:
                    recog_mask = rotate_fn(recog_mask, angle=angle)
                predict_img = rotate_fn(predict_img, angle=angle)
                img_bbox = rotate_fn(img_bbox, angle=angle)

            if recog_mask is not None:
                # Later rotations overwrite earlier ones where both are set.
                recognised = recog_mask != 0
                mask[recognised] = recog_mask[recognised]

            try:
                self.image_pub.publish(
                    self.cv_bridge.cv2_to_imgmsg(predict_img, "bgr8"))
                self.img_bbox_pub.publish(
                    self.cv_bridge.cv2_to_imgmsg(img_bbox, "bgr8"))
            except CvBridgeError as e:
                resp.state = e
                print(e)

        ## publish visualization
        self.img_show(mask, cv_image)
        resp.mask = self.cv_bridge.cv2_to_imgmsg(mask, "8UC1")
        vis_mask = np.zeros([cv_image.shape[0], cv_image.shape[1]],
                            dtype=np.uint8)
        # Invert the non-zero mask values for the visualisation image.
        vis_mask[mask != 0] = 255 - mask[mask != 0]
        if self.saver:
            self.save_func(padded_image, vis_mask,
                           self.cv_bridge.imgmsg_to_cv2(resp.depth, "16UC1"),
                           cv_image)
        ## srv end
        self.predict_img_pub.publish(
            self.cv_bridge.cv2_to_imgmsg(cv_image, "bgr8"))
        self.predict_mask_pub.publish(
            self.cv_bridge.cv2_to_imgmsg(vis_mask, "8UC1"))
        return resp