def main():
    """Fine-tune a coordinate-augmented copy of the pretrained CRAFT network.

    Steps:
      1. Build a ``Craft`` detector (box crops, all thresholds 0.4) only to
         obtain its pretrained ``craft_net``; CUDA is enabled when
         ``torch.cuda.is_available()`` reports a GPU.
      2. Hand one deep copy of that network, together with the images under
         ``<DIR>/temp_train``, to ``DatasetImgCraftDefault`` (presumably the
         copy is used to produce training targets — confirm in that class).
      3. Wrap a second, independent deep copy in
         ``craft_with_coord.CraftCoordSimple`` and optimise it with Adam
         (lr=1e-5, weight_decay=5e-4) against a summed L1 loss via
         ``TrainModel``.
      4. Run 20 rounds of ``TrainModel.train(5, loader, optimizer, True)``,
         printing whatever each round returns.
    """
    import os  # only needed to derive the training-data path from DIR

    DIR = "../data/craft_temp/"
    train_dir = os.path.join(DIR, "temp_train")

    # Follow actual GPU availability instead of hard-coding cuda=True, which
    # breaks detector construction on CPU-only machines.
    # NOTE(review): device handling inside DatasetImgCraftDefault/TrainModel
    # is not visible here — confirm they also respect this choice.
    use_cuda = torch.cuda.is_available()
    print(use_cuda)
    craft = Craft(output_dir=DIR,
                  crop_type="box",
                  cuda=use_cuda,
                  text_threshold=.4,
                  link_threshold=.4,
                  low_text=.4)  # constructed only to extract the network

    # Independent copy for the dataset so training never touches it.
    base_model = copy.deepcopy(craft.craft_net)
    train_dataset = DatasetImgCraftDefault(base_model, train_dir)

    # Second independent copy: this is the network that gets trained.
    model_train = copy.deepcopy(craft.craft_net)
    craft_coord_ob = craft_with_coord.CraftCoordSimple(model_train)

    loader_gen = torch.utils.data.DataLoader(train_dataset,
                                             batch_size=1,
                                             shuffle=False,
                                             num_workers=0)
    loss_fn = nn.L1Loss(reduction="sum")
    train_obj = TrainModel(craft_coord_ob, loss_fn)
    optimizer = torch.optim.Adam(craft_coord_ob.parameters(),
                                 lr=.00001,
                                 weight_decay=.0005)

    for _ in range(20):
        print(train_obj.train(5, loader_gen, optimizer, True))
Example #2
0
 def test_load_refinenet_model(self):
     """Check that ``Craft.load_refinenet_model`` populates ``refine_net``.

     The detector is built with ``refiner=False`` and ``refine_net`` is then
     cleared explicitly, so a non-None value afterwards can only come from
     the ``load_refinenet_model()`` call.
     """
     # init craft (CPU, refiner disabled)
     craft = Craft(
         output_dir=None,
         rectify=True,
         export_extra=False,
         text_threshold=0.7,
         link_threshold=0.4,
         low_text=0.4,
         cuda=False,
         long_size=720,
         refiner=False,
         crop_type="poly",
     )
     # remove refinenet model so the load below starts from a known state
     craft.refine_net = None
     # load refinenet model
     craft.load_refinenet_model()
     # Assert the attribute was set, rather than relying on the truthiness
     # of a model object.
     self.assertIsNotNone(craft.refine_net)
Example #3
0
 def test_init(self):
     """Check that ``Craft`` can be constructed with a CPU, no-refiner config."""
     craft = Craft(
         output_dir=None,
         rectify=True,
         export_extra=False,
         text_threshold=0.7,
         link_threshold=0.4,
         low_text=0.4,
         cuda=False,
         long_size=720,
         refiner=False,
         crop_type="poly",
     )
     # assertTrue(craft) would pass for any truthy object; checking the type
     # states what the test actually expects from the constructor.
     self.assertIsInstance(craft, Craft)
Example #4
0
    def test_detect_text(self):
        """Run ``detect_text`` on ``self.image_path`` under four configurations.

        Each configuration varies ``rectify``, ``refiner`` and ``crop_type``
        while sharing the remaining detector settings. For each one the test
        checks the number of detected boxes, that the first box has four
        2-D points, and the integer x-coordinate of one corner of the first
        box against a known reference value. Each configuration runs in its
        own ``subTest`` so a failure in one does not hide the others.
        """
        cases = [
            # (rectify, refiner, crop_type, num_boxes, corner_index, expected_x)
            (True, False, "poly", 52, 0, 115),
            (True, True, "poly", 19, 2, 661),
            (False, False, "box", 52, 2, 244),
            (False, True, "box", 19, 2, 661),
        ]
        for rectify, refiner, crop_type, num_boxes, corner, expected_x in cases:
            with self.subTest(rectify=rectify, refiner=refiner,
                              crop_type=crop_type):
                # init craft
                craft = Craft(
                    output_dir=None,
                    rectify=rectify,
                    export_extra=False,
                    text_threshold=0.7,
                    link_threshold=0.4,
                    low_text=0.4,
                    cuda=False,
                    long_size=720,
                    refiner=refiner,
                    crop_type=crop_type,
                )
                # detect text
                prediction_result = craft.detect_text(image=self.image_path)
                boxes = prediction_result["boxes"]

                self.assertEqual(len(boxes), num_boxes)
                self.assertEqual(len(boxes[0]), 4)
                self.assertEqual(len(boxes[0][0]), 2)
                self.assertEqual(int(boxes[0][corner][0]), expected_x)
Example #5
0
from craft_text_detector import Craft

from glob import glob
import os
import pickle
from tqdm import tqdm
from PIL import Image

output_dir = "outputs/"
# makedirs(exist_ok=True) avoids the exists()/mkdir() race and is a no-op
# when the directory is already present.
os.makedirs(output_dir, exist_ok=True)

# CPU-only CRAFT detector configured with polygon crops and output_dir.
craft = Craft(output_dir=output_dir, crop_type="poly", cuda=False)

from transformers import TrOCRProcessor, VisionEncoderDecoderModel
from PIL import Image

# Pretrained TrOCR processor and encoder-decoder model, both loaded from the
# same checkpoint identifier (bound once so the two cannot drift apart).
_trocr_checkpoint = "microsoft/trocr-base-handwritten"
trocr_processor = TrOCRProcessor.from_pretrained(_trocr_checkpoint)
trocr_model = VisionEncoderDecoderModel.from_pretrained(_trocr_checkpoint)


def predictor(x, batch_size=1):
    print(f"{len(x)} input Images received.")
    results = []
    for _ in x:
        print(f"{len(x)} input Images received.")
        try:
            craft.detect_text(_)
            crops = sorted(
Example #6
0
    def test_detect_text(self):
        """Construct CRAFT in four configurations and visually inspect crops.

        Only the last configuration (rectify=False, refiner=True,
        crop_type="box", text_threshold=0.4, link_threshold=0.2) actually
        runs ``detect_text``. For the first three, the detection calls are
        commented out and their assertions sit inside string literals, so
        those ``Craft`` instances are built and then discarded.

        NOTE(review): no assertion executes in the visible body. Also,
        ``cv2.waitKey(0)`` blocks until a key is pressed for every crop, so
        this needs an interactive display and will stall under an unattended
        test runner.
        """
        # init craft
        craft = Craft(
            output_dir=None,
            rectify=True,
            export_extra=False,
            text_threshold=0.7,
            link_threshold=0.4,
            low_text=0.4,
            cuda=False,
            long_size=720,
            refiner=False,
            crop_type="poly",
        )
        # detect text
        #prediction_result = craft.detect_text(image_path=self.image_path)
        '''
        self.assertEqual(len(prediction_result["boxes"]), 52)
        self.assertEqual(len(prediction_result["boxes"][0]), 4)
        self.assertEqual(len(prediction_result["boxes"][0][0]), 2)
        self.assertEqual(int(prediction_result["boxes"][0][0][0]), 115)
        '''
        # init craft
        craft = Craft(
            output_dir=None,
            rectify=True,
            export_extra=False,
            text_threshold=0.7,
            link_threshold=0.4,
            low_text=0.4,
            cuda=False,
            long_size=720,
            refiner=True,
            crop_type="poly",
        )
        # detect text
        #prediction_result = craft.detect_text(image_path=self.image_path)
        '''
        self.assertEqual(len(prediction_result["boxes"]), 19)
        self.assertEqual(len(prediction_result["boxes"][0]), 4)
        self.assertEqual(len(prediction_result["boxes"][0][0]), 2)
        self.assertEqual(int(prediction_result["boxes"][0][2][0]), 661)
        '''
        # init craft
        craft = Craft(
            output_dir=None,
            rectify=False,
            export_extra=False,
            text_threshold=0.7,
            link_threshold=0.4,
            low_text=0.4,
            cuda=False,
            long_size=720,
            refiner=False,
            crop_type="box",
        )
        # detect text
        #prediction_result = craft.detect_text(image_path=self.image_path)
        '''
        self.assertEqual(len(prediction_result["boxes"]), 52)
        self.assertEqual(len(prediction_result["boxes"][0]), 4)
        self.assertEqual(len(prediction_result["boxes"][0][0]), 2)
        self.assertEqual(int(prediction_result["boxes"][0][2][0]), 244)
        '''
        # init craft (the only configuration whose detection actually runs;
        # note the lower text/link thresholds compared with the others)
        craft = Craft(
            output_dir=None,
            rectify=False,
            export_extra=False,
            text_threshold=0.4,
            link_threshold=0.2,
            low_text=0.4,
            cuda=False,
            long_size=720,
            refiner=True,
            crop_type="box",
        )
        # detect text
        print("initiating")
        # NOTE(review): another variant of this test calls
        # detect_text(image=...); confirm that the installed Craft version
        # accepts the image_path keyword.
        prediction_result = craft.detect_text(image_path=self.image_path)
        # NOTE(review): im is only referenced by the commented-out print below.
        im = cv2.imread(self.image_path)
        #print(im.shape)
        '''
        self.assertEqual(len(prediction_result["boxes"]), 19)
        self.assertEqual(len(prediction_result["boxes"][0]), 4)
        self.assertEqual(len(prediction_result["boxes"][0][0]), 2)
        self.assertEqual(int(prediction_result["boxes"][0][2][0]), 661)
        '''

        # NOTE(review): image is a second read of the same file and is not
        # used in the visible lines below.
        image = cv2.imread(self.image_path)
        # Show each text crop in an OpenCV window named "image"; waitKey(0)
        # waits indefinitely for a key press before the next crop.
        for i, img in enumerate(prediction_result["text_crops"]):
            cv2.imshow("image", img)
            cv2.waitKey(0)