Ejemplo n.º 1
0
 def test_generator_from_random(self):
     """Check that randomly generated images are always 32 pixels tall.

     Draws 100 samples from a default-configured ``GeneratorFromRandom`` and
     asserts that index 1 of each image's ``size`` equals 32.
     """
     generator = GeneratorFromRandom()
     for _ in range(100):
         img, _lbl = next(generator)
         # img.size is indexed as (width, height) -- presumably a PIL image;
         # TODO confirm against the generator library.
         self.assertEqual(img.size[1], 32, "Shape is not right")
Ejemplo n.º 2
0
def create_random_image_generator(count):
    """Build a generator of images showing random, gibberish text.

    The text is made of randomly chosen letters, digits and symbols. It is
    rendered with the project's bundled Vudotronic font, on top of pictures
    taken from the project's background-image directory.

    Args:
        count (int): forwarded unchanged as ``count`` to
            ``GeneratorFromRandom`` (presumably the number of samples the
            generator yields; TODO confirm against the library).

    Returns:
        GeneratorFromRandom: the configured image/label generator.
    """
    root = PROJECT_PATH
    font_files = [os.path.join(root, 'resources/fonts/Vudotronic.otf')]
    backgrounds = os.path.join(root, 'resources/background_images')
    return GeneratorFromRandom(
        count=count,
        length=2,
        allow_variable=True,
        use_letters=True,
        use_numbers=True,
        use_symbols=True,
        character_spacing=5,  # spacing inserted between characters of a word
        fonts=font_files,
        background_type=-1,  # -1 means: use images as the background
        image_dir=backgrounds,
    )
Ejemplo n.º 3
0
 def test_generator_from_random_stops(self):
     """Check that a generator built with ``count=1`` stops after one sample."""
     generator = GeneratorFromRandom(count=1)
     next(generator)
     # Use the builtin ``next`` (iterator protocol) rather than the Python 2
     # style ``.next`` attribute: that attribute is looked up before
     # assertRaises runs, so if it were missing the test would error with
     # AttributeError instead of checking for StopIteration.
     with self.assertRaises(StopIteration):
         next(generator)
    def __init__(self, args, training: bool = True, semi_amount: float = 0.0):
        """
        Set up configuration and the pools of synthetic text-image generators.

        Three lists of generators are built, all with ``language="fr"``:
        ``fake_generators`` (random characters), ``dict_generators``
        (dictionary words) and ``classic_gen`` (one of each kind, both with
        ``background_type=0``).

        Args:
            args: configuration object; only ``args.batch_size`` is read here.
            training (bool): stored as ``self.is_training``.
            semi_amount (float): accepted but not used in this constructor.
                NOTE(review): confirm whether it is consumed elsewhere or can
                be removed.
        """
        self.args = args
        self.is_training = training
        self.batch_size = args.batch_size
        self.current_iter = 0

        # Random-character generators with varied size, margins, skew and
        # character sets. count=-1 presumably means "unbounded" -- confirm
        # against the generator library.
        self.fake_generators = [
            GeneratorFromRandom(
                count=-1,
                length=1,
                language="fr",
                size=64,
                background_type=1,
                skewing_angle=2,
                margins=(2, 1, 1, 1),
                use_letters=False,
                use_symbols=False,
                use_numbers=True,
                random_skew=True,
                text_color='#000000,#888888',
            ),
            GeneratorFromRandom(
                count=-1,
                length=1,
                language="fr",
                size=48,
                background_type=1,
                skewing_angle=2,
                margins=(2, 1, 1, 1),
                use_letters=True,
                use_symbols=True,  # NOTE(review): earlier marked "false" -- confirm whether symbols should be disabled
                use_numbers=False,
                random_skew=True,
                text_color='#000000,#888888',
            ),
            GeneratorFromRandom(
                count=-1,
                length=3,
                language="fr",
                size=24,
                background_type=3,
                skewing_angle=2,
                fit=True,
                random_skew=True,
                text_color='#000000,#888888',
            ),
            GeneratorFromRandom(
                count=-1,
                length=3,
                language="fr",
                size=32,
                background_type=3,
                skewing_angle=2,
                space_width=2,
                use_symbols=False,
                margins=(8, 8, 8, 8),
                random_skew=False,
            ),
            GeneratorFromRandom(
                count=-1,
                length=2,
                language="fr",
                size=55,
                background_type=1,
                skewing_angle=3,
                use_symbols=True,  # NOTE(review): earlier marked "false" -- confirm whether symbols should be disabled
                fit=True,
                random_skew=True,
                text_color='#0171ff',
            ),
            GeneratorFromRandom(
                count=-1,
                length=3,
                language="fr",
                size=43,
                background_type=1,
                skewing_angle=2,
                margins=(4, 2, 10, 6),
                random_skew=False,
                text_color='#000000,#888888',
            ),
            GeneratorFromRandom(
                count=-1,
                length=5,
                language="fr",
                size=37,
                space_width=3,
                background_type=1,
                use_symbols=False,
                fit=True,
                text_color='#000000,#888888',
            ),
            GeneratorFromRandom(
                count=-1,
                length=5,
                language="fr",
                size=28,
                background_type=1,
                use_symbols=False,
                fit=True,
                text_color='#000000,#888888',
            ),
        ]
        # Dictionary-word generators. The first and third entries are
        # configured identically -- possibly to weight that configuration
        # when one is picked at random; NOTE(review): confirm intent.
        self.dict_generators = [
            GeneratorFromDict(
                length=5,
                allow_variable=True,
                language="fr",
                size=32,
                background_type=1,
                fit=True,
                text_color='#000000,#888888',
            ),
            GeneratorFromDict(
                length=5,
                allow_variable=True,
                language="fr",
                size=32,
                background_type=1,
                margins=(7, 4, 6, 4),
                text_color='#000000,#888888',
            ),
            GeneratorFromDict(
                length=5,
                allow_variable=True,
                language="fr",
                size=32,
                background_type=1,
                fit=True,
                text_color='#000000,#888888',
            ),
        ]
        # One dictionary-word and one random-character generator, both with
        # background_type=0 and no custom text colour.
        self.classic_gen = [
            GeneratorFromDict(
                length=3,
                allow_variable=True,
                language="fr",
                size=32,
                background_type=0,
                fit=True,
            ),
            GeneratorFromRandom(
                count=-1,
                length=5,
                language="fr",
                size=28,
                background_type=0,
                use_symbols=False,
                fit=True,
            ),
        ]
        # Target image height; width is left as None (presumably variable
        # width -- confirm with the code that consumes these attributes).
        self.height = 32
        self.width = None

        # Character set; ALPHABET is defined outside this block.
        self.alphabet = ALPHABET
        # self.alphabet = string.printable
        self.alphabet_size = len(self.alphabet)
Ejemplo n.º 5
0
def train():
    """Main training function: fit a text-recognition model on synthetic digits.

    Steps:
      1. Configure GPU memory usage.
      2. Build a random-digit image pipeline: a trdg ``GeneratorFromRandom``
         restricted to numbers, plus Keras ``ImageDataGenerator``
         augmentation, wrapped by ``batch_functions.OCR_generator``.
      3. Create the model with ``models.make_standard_CRNN``.
      4. Compile it with ``losses.custom_ctc()`` and the Adam optimizer.
      5. Save the not-yet-trained model to ``saved_models``.
      6. Fit it for ``num_epochs`` epochs with the prediction-visualisation
         and model-saving callbacks.

    Returns:
        The value returned by ``model.fit`` (a Keras ``History`` object).
        Earlier versions returned ``None``, so existing callers are unaffected.
    """
    # *********** MAGIC LINES ****************
    # you might need this if training crashes due GPU memory overload
    # or you get CuDNN load failure

    # check for gpu
    physical_devices = tf.config.list_physical_devices('GPU')
    print(physical_devices)

    # for tf2: enable memory growth on every visible GPU. Looping (instead of
    # indexing physical_devices[0]) avoids an IndexError on CPU-only machines.
    for device in physical_devices:
        tf.config.experimental.set_memory_growth(device, True)

    config = ConfigProto()
    config.gpu_options.per_process_gpu_memory_fraction = 0.7
    config.gpu_options.allow_growth = True
    # Created for its side effect (per TF docs an InteractiveSession installs
    # itself as the default session); the variable is otherwise unused.
    session = InteractiveSession(config=config)  # noqa: F841

    # *************  PARAMETERS *******************
    batch_size = 12
    img_h = 32
    num_epochs = 10

    # list of all characters
    # map each character to an integer "label", plus the reverse mapping
    all_chars = "0123456789"
    num_chars = len(all_chars)
    char_to_lbl_dict = {char: ind for ind, char in enumerate(all_chars)}
    lbl_to_char_dict = dict(enumerate(all_chars))

    # ************** DATA GENERATORS *********************
    # use the trdg for the base generator of text (digits only)
    base_generator = GeneratorFromRandom(use_symbols=False,
                                         use_letters=False,
                                         background_type=1)

    # add some more augmentation with keras ImageDataGenerator
    keras_augm = ImageDataGenerator(rotation_range=2.0,
                                    width_shift_range=5.0,
                                    height_shift_range=5.0,
                                    shear_range=4.0,
                                    zoom_range=0.1)

    # the actual datagenerator for training and visualizations (and validation)
    dg_params = {
        "batch_size": batch_size,
        "img_h": img_h,
        "keras_augmentor": keras_augm,
        "char_to_lbl_dict": char_to_lbl_dict
    }

    # NOTE: both generators wrap the same base_generator instance.
    datagen = batch_functions.OCR_generator(base_generator, **dg_params)
    val_datagen = batch_functions.OCR_generator(base_generator,
                                                **dg_params,
                                                validation=True)

    # *******MODEL******
    model = models.make_standard_CRNN(img_h, num_chars)

    # ********CALLBACKS AND LOSSES****************
    # get the cool outputs
    predvis = custom_callbacks.PredVisualize(model,
                                             val_datagen,
                                             lbl_to_char_dict,
                                             printing=True)
    model_saver = custom_callbacks.make_save_model_cb()
    custom_loss = losses.custom_ctc()

    # ********COMPILE, SAVE MODEL**************
    model.compile(loss=custom_loss, optimizer="Adam")
    # Saved before training, without optimizer state.
    tf.keras.models.save_model(model,
                               "saved_models",
                               overwrite=True,
                               include_optimizer=False)

    history = model.fit(datagen,
                        epochs=num_epochs,
                        verbose=1,
                        callbacks=[predvis, model_saver])
    return history