def main():
    """Fine-tune a CRAFT-derived coordinate model on the temp training set.

    Builds a Craft detector only to borrow its pretrained backbone
    (``craft_net``), wraps a deep copy of that backbone in
    ``CraftCoordSimple``, and trains it with summed L1 loss over 20
    outer rounds of 5 epochs each.
    """
    data_dir = "../data/craft_temp/"
    # Sanity check that a GPU is visible before we request cuda=True below.
    print(torch.cuda.is_available())
    # Loose thresholds (0.4) so the detector is permissive during training.
    detector = Craft(
        output_dir=data_dir,
        crop_type="box",
        cuda=True,
        text_threshold=.4,
        link_threshold=.4,
        low_text=.4,
    )
    # just extract network — one frozen copy feeds the dataset, a second
    # independent copy is the one actually trained.
    frozen_net = copy.deepcopy(detector.craft_net)
    dataset = DatasetImgCraftDefault(frozen_net, "../data/craft_temp/temp_train")
    trainable_net = copy.deepcopy(detector.craft_net)
    coord_model = craft_with_coord.CraftCoordSimple(trainable_net)
    loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=1,
        shuffle=False,
        num_workers=0,
    )
    criterion = nn.L1Loss(reduction="sum")
    trainer = TrainModel(coord_model, criterion)
    optimizer = torch.optim.Adam(
        coord_model.parameters(),
        lr=.00001,
        weight_decay=.0005,
    )
    # 20 rounds x 5 epochs; train() presumably returns a loss summary — TODO confirm.
    for _ in range(20):
        print(trainer.train(5, loader, optimizer, True))
def test_load_refinenet_model(self):
    """Clearing refine_net and calling load_refinenet_model restores it."""
    detector = Craft(
        output_dir=None,
        rectify=True,
        export_extra=False,
        text_threshold=0.7,
        link_threshold=0.4,
        low_text=0.4,
        cuda=False,
        long_size=720,
        refiner=False,
        crop_type="poly",
    )
    # Drop the refiner so the loader has something to restore.
    detector.refine_net = None
    detector.load_refinenet_model()
    # A truthy refine_net means the model object was (re)loaded.
    self.assertTrue(detector.refine_net)
def test_init(self):
    """Craft constructs successfully with a typical CPU/poly configuration."""
    detector = Craft(
        output_dir=None,
        rectify=True,
        export_extra=False,
        text_threshold=0.7,
        link_threshold=0.4,
        low_text=0.4,
        cuda=False,
        long_size=720,
        refiner=False,
        crop_type="poly",
    )
    # Construction returning a truthy object is the whole contract here.
    self.assertTrue(detector)
def test_detect_text(self):
    """Exercise detect_text over four configurations and check the
    resulting box counts, box geometry, and one sample coordinate each.

    Configurations: poly crops with and without the refiner, then box
    crops with and without the refiner. The refiner merges word boxes,
    which is why refined runs expect 19 boxes versus 52 unrefined.
    """
    # --- poly crops, rectified, refiner off ---
    detector = Craft(
        output_dir=None,
        rectify=True,
        export_extra=False,
        text_threshold=0.7,
        link_threshold=0.4,
        low_text=0.4,
        cuda=False,
        long_size=720,
        refiner=False,
        crop_type="poly",
    )
    boxes = detector.detect_text(image=self.image_path)["boxes"]
    self.assertEqual(len(boxes), 52)
    self.assertEqual(len(boxes[0]), 4)
    self.assertEqual(len(boxes[0][0]), 2)
    self.assertEqual(int(boxes[0][0][0]), 115)

    # --- poly crops, rectified, refiner on ---
    detector = Craft(
        output_dir=None,
        rectify=True,
        export_extra=False,
        text_threshold=0.7,
        link_threshold=0.4,
        low_text=0.4,
        cuda=False,
        long_size=720,
        refiner=True,
        crop_type="poly",
    )
    boxes = detector.detect_text(image=self.image_path)["boxes"]
    self.assertEqual(len(boxes), 19)
    self.assertEqual(len(boxes[0]), 4)
    self.assertEqual(len(boxes[0][0]), 2)
    self.assertEqual(int(boxes[0][2][0]), 661)

    # --- box crops, no rectify, refiner off ---
    detector = Craft(
        output_dir=None,
        rectify=False,
        export_extra=False,
        text_threshold=0.7,
        link_threshold=0.4,
        low_text=0.4,
        cuda=False,
        long_size=720,
        refiner=False,
        crop_type="box",
    )
    boxes = detector.detect_text(image=self.image_path)["boxes"]
    self.assertEqual(len(boxes), 52)
    self.assertEqual(len(boxes[0]), 4)
    self.assertEqual(len(boxes[0][0]), 2)
    self.assertEqual(int(boxes[0][2][0]), 244)

    # --- box crops, no rectify, refiner on ---
    detector = Craft(
        output_dir=None,
        rectify=False,
        export_extra=False,
        text_threshold=0.7,
        link_threshold=0.4,
        low_text=0.4,
        cuda=False,
        long_size=720,
        refiner=True,
        crop_type="box",
    )
    boxes = detector.detect_text(image=self.image_path)["boxes"]
    self.assertEqual(len(boxes), 19)
    self.assertEqual(len(boxes[0]), 4)
    self.assertEqual(len(boxes[0][0]), 2)
    self.assertEqual(int(boxes[0][2][0]), 661)
# Script: detect text regions with CRAFT, then (presumably) feed the crops
# to a TrOCR handwriting-recognition model.
from craft_text_detector import Craft
from glob import glob
import os
import pickle
from tqdm import tqdm
from PIL import Image

# All CRAFT artifacts (crops, heatmaps) are written under this directory.
output_dir = "outputs/"
if not os.path.exists(output_dir):
    os.mkdir(output_dir)

# CPU-only detector producing polygon crops.
craft = Craft(output_dir=output_dir, crop_type="poly", cuda=False)

# TrOCR (handwritten) encoder-decoder for OCR on the detected crops.
from transformers import TrOCRProcessor, VisionEncoderDecoderModel
from PIL import Image  # NOTE(review): duplicate import, harmless

trocr_processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-handwritten")
trocr_model = VisionEncoderDecoderModel.from_pretrained(
    "microsoft/trocr-base-handwritten"
)


def predictor(x, batch_size=1):
    # x: iterable of image inputs accepted by craft.detect_text — presumably
    # file paths or arrays; verify against caller.
    print(f"{len(x)} input Images received.")
    results = []
    for _ in x:
        print(f"{len(x)} input Images received.")
        try:
            # Runs detection; crops are written under output_dir as a side effect.
            craft.detect_text(_)
            # NOTE(review): source is truncated here — the sorted(...) call and
            # the rest of this function are not visible in this chunk.
            crops = sorted(
def test_detect_text(self):
    """Debugging variant of the detect_text test.

    The first three configurations are constructed but their detection
    calls and assertions are disabled (commented out / wrapped in no-op
    string statements), so only the final box+refiner configuration
    actually runs. That run uses looser thresholds (0.4/0.2) than the
    original test and finishes by displaying each text crop with
    OpenCV windows — this requires a display and blocks on key presses.
    """
    # init craft — constructed but never used; detection call disabled below.
    craft = Craft(
        output_dir=None,
        rectify=True,
        export_extra=False,
        text_threshold=0.7,
        link_threshold=0.4,
        low_text=0.4,
        cuda=False,
        long_size=720,
        refiner=False,
        crop_type="poly",
    )
    # detect text
    #prediction_result = craft.detect_text(image_path=self.image_path)
    # The triple-quoted blocks below are no-op string statements holding
    # the disabled assertions.
    '''
    self.assertEqual(len(prediction_result["boxes"]), 52)
    self.assertEqual(len(prediction_result["boxes"][0]), 4)
    self.assertEqual(len(prediction_result["boxes"][0][0]), 2)
    self.assertEqual(int(prediction_result["boxes"][0][0][0]), 115)
    '''
    # init craft — also unused; detection disabled.
    craft = Craft(
        output_dir=None,
        rectify=True,
        export_extra=False,
        text_threshold=0.7,
        link_threshold=0.4,
        low_text=0.4,
        cuda=False,
        long_size=720,
        refiner=True,
        crop_type="poly",
    )
    # detect text
    #prediction_result = craft.detect_text(image_path=self.image_path)
    '''
    self.assertEqual(len(prediction_result["boxes"]), 19)
    self.assertEqual(len(prediction_result["boxes"][0]), 4)
    self.assertEqual(len(prediction_result["boxes"][0][0]), 2)
    self.assertEqual(int(prediction_result["boxes"][0][2][0]), 661)
    '''
    # init craft — also unused; detection disabled.
    craft = Craft(
        output_dir=None,
        rectify=False,
        export_extra=False,
        text_threshold=0.7,
        link_threshold=0.4,
        low_text=0.4,
        cuda=False,
        long_size=720,
        refiner=False,
        crop_type="box",
    )
    # detect text
    #prediction_result = craft.detect_text(image_path=self.image_path)
    '''
    self.assertEqual(len(prediction_result["boxes"]), 52)
    self.assertEqual(len(prediction_result["boxes"][0]), 4)
    self.assertEqual(len(prediction_result["boxes"][0][0]), 2)
    self.assertEqual(int(prediction_result["boxes"][0][2][0]), 244)
    '''
    # init craft — the only configuration that actually runs; note the
    # lowered thresholds (0.4/0.2) compared with the disabled runs above.
    craft = Craft(
        output_dir=None,
        rectify=False,
        export_extra=False,
        text_threshold=0.4,
        link_threshold=0.2,
        low_text=0.4,
        cuda=False,
        long_size=720,
        refiner=True,
        crop_type="box",
    )
    # detect text
    print("initiating")
    prediction_result = craft.detect_text(image_path=self.image_path)
    im = cv2.imread(self.image_path)
    #print(im.shape)
    '''
    self.assertEqual(len(prediction_result["boxes"]), 19)
    self.assertEqual(len(prediction_result["boxes"][0]), 4)
    self.assertEqual(len(prediction_result["boxes"][0][0]), 2)
    self.assertEqual(int(prediction_result["boxes"][0][2][0]), 661)
    '''
    # NOTE(review): `image` is loaded but never used afterwards.
    image = cv2.imread(self.image_path)
    # Interactive inspection: shows each detected crop; waitKey(0) blocks
    # until a key is pressed, so this cannot run headless/in CI.
    for i, img in enumerate(prediction_result["text_crops"]):
        cv2.imshow("image", img)
        cv2.waitKey(0)