Exemplo n.º 1
0
class TextRecognizer:
    """Handwritten-text recognizer built on ``HTRModel`` and ``Tokenizer``.

    The model and tokenizer are created and the checkpoint is loaded as part
    of construction, so an instance is ready for prediction once
    ``__init__`` returns.
    """

    def __init__(self,
                 checkpoint_path="./py/checkpoint_recognizer.hdf5",
                 input_size=(1024, 128, 1),
                 max_text_length=128,
                 charset_base=string.printable[:95],
                 architecture="flor"):
        """Store the configuration and load the model.

        Args:
            checkpoint_path: Path handed to ``HTRModel.load_checkpoint``.
            input_size: Input size forwarded to the preprocessing step and
                to ``HTRModel`` (presumably width, height, channels --
                TODO confirm against ``pp.preprocess``).
            max_text_length: Maximum text length given to the tokenizer.
            charset_base: Characters the tokenizer is built from.
            architecture: Architecture name forwarded to ``HTRModel``.
        """
        self.tokenizer = None
        self.model = None
        self.checkpoint_path = checkpoint_path
        self.input_size = input_size
        self.max_text_length = max_text_length
        self.charset_base = charset_base
        self.architecture = architecture
        ml_utils.limit_gpu_memory()

        self.load_model()

    def load_model(self):
        """Build the tokenizer and model, compile, and load the checkpoint."""
        self.tokenizer = Tokenizer(chars=self.charset_base,
                                   max_text_length=self.max_text_length)
        self.model = HTRModel(architecture=self.architecture,
                              input_size=self.input_size,
                              vocab_size=self.tokenizer.vocab_size,
                              top_paths=10)
        self.model.compile()
        self.model.load_checkpoint(target=self.checkpoint_path)

    def read_all_text_from_images(self, images):
        """Return a list with the recognized text of each image, in order."""
        return [self.read_text_from_image(img) for img in images]

    def read_text_from_image(self, img):
        """Return the first decoded prediction for a single image.

        Args:
            img: Image input accepted by ``pp.preprocess``.

        Returns:
            The decoded text of the first CTC path of the first sample
            (presumably the most probable one -- verify against
            ``HTRModel.predict``), or ``""`` when the model yields no
            sample or no path.
        """
        img = pp.preprocess(img, input_size=self.input_size)
        x_test = pp.normalization([img])

        # The probabilities are not needed to pick the first path.
        predicts, _ = self.model.predict(x_test, ctc_decode=True)

        # A single image was submitted, so only the first sample matters.
        # next(iter(...)) avoids truthiness/len checks that misbehave on
        # array-like results.
        first_sample = next(iter(predicts), None)
        if first_sample is None:
            return ""

        top_path = next(iter(first_sample), None)
        if top_path is None:
            return ""

        # Decode only the path that is returned instead of all of them.
        return self.tokenizer.decode(top_path)
Exemplo n.º 2
0
            cv2.imshow("img", pp.adjust_to_see(dt[x]))
            cv2.waitKey(0)

    elif args.image:
        tokenizer = Tokenizer(chars=charset_base,
                              max_text_length=max_text_length)

        img = pp.preproc(args.image, input_size=input_size)
        x_test = pp.normalization([img])

        model = HTRModel(architecture=args.arch,
                         input_size=input_size,
                         vocab_size=tokenizer.vocab_size,
                         top_paths=10)

        model.compile()
        model.load_checkpoint(target=target_path)

        predicts, probabilities = model.predict(x_test, ctc_decode=True)
        predicts = [[tokenizer.decode(x) for x in y] for y in predicts]

        print("\n####################################")
        for i, (pred, prob) in enumerate(zip(predicts, probabilities)):
            print("\nProb.  - Predict")

            for (pd, pb) in zip(pred, prob):
                print(f"{pb:.4f} - {pd}")

            cv2.imshow(f"Image {i + 1}", pp.adjust_to_see(img))
        print("\n####################################")
        cv2.waitKey(0)
    elif args.train or args.test:
        os.makedirs(output_path, exist_ok=True)

        dtgen = DataGenerator(hdf5_src=hdf5_src,
                              batch_size=args.batch_size,
                              charset=charset_base,
                              max_text_length=max_text_length)

        network_func = getattr(architecture, args.arch)

        ioo = network_func(input_size=input_size,
                           output_size=(dtgen.tokenizer.vocab_size + 1),
                           learning_rate=0.001)

        model = HTRModel(inputs=ioo[0], outputs=ioo[1])
        model.compile(optimizer=ioo[2])

        checkpoint = "checkpoint_weights.hdf5"
        model.load_checkpoint(target=os.path.join(output_path, checkpoint))

        if args.train:
            model.summary(output_path, "summary.txt")
            callbacks = model.get_callbacks(logdir=output_path,
                                            hdf5=checkpoint,
                                            verbose=1)

            start_time = time.time()
            h = model.fit_generator(generator=dtgen.next_train_batch(),
                                    epochs=args.epochs,
                                    steps_per_epoch=dtgen.train_steps,
                                    validation_data=dtgen.next_valid_batch(),