Example #1
    def __init__(self, text_of_the_image):
        self.text_of_the_image = text_of_the_image
        self.image_width = cfg.char_width * len(text_of_the_image)
        self.char_height = cfg.char_height
        perceptual_weights_dir = os.path.join(cfg.working_dir,
                                              "projector/perceptual_weights")
        self.vgg_ckpt_fn = os.path.join(perceptual_weights_dir, "vgg",
                                        "exported")
        self.lin_ckpt_fn = os.path.join(perceptual_weights_dir, "lin",
                                        "exported")
        self.perceptual_loss = learned_perceptual_metric_model(
            self.char_height, self.image_width, self.vgg_ckpt_fn,
            self.lin_ckpt_fn)
        self.infere = Infere()

        self.aster_ocr = AsterInferer()
        self.generator = ModelLoader().load_generator(is_g_clone=True,
                                                      ckpt_dir=cfg.ckpt_dir)

        self.n_mean_latent = 10000  # number of latents to take the mean from
        self.num_steps = 1000
        self.save_and_log_frequency = 100
        self.lr_rampup = 0.05  # duration of the learning rate warmup
        self.lr_rampdown = 0.25  # duration of the learning rate decay
        self.lr = 0.1
        self.noise_strength_level = 0.05
        self.noise_ramp = 0.75  # duration of the noise level decay
        self.optimizer = tf.keras.optimizers.Adam()
        self.ocr_loss_factor = 0.1
Example #2
    def log_images(
        self,
        input_words: tf.int32,
        generator: Generator,
        aster_ocr: AsterInferer,
        step: int,
    ) -> None:
        """
        Generates text boxes and logs them to TensorBoard along with the OCR's reading of them.

        Parameters
        ----------
        input_words: Integer sequences obtained from the input words (initially strings) using the MAIN_CHAR_VECTOR.
        generator: Generator used for inference (moving average of the trained Generator).
        aster_ocr: Pre-trained OCR.
        step: Current training step.

        """
        test_z = tf.random.normal(shape=[self.num_images_per_log, self.z_dim],
                                  dtype=tf.dtypes.float32)

        if cfg.strategy.num_replicas_in_sync > 1:
            input_words = input_words.values  # a list [x_from_dev_a, x_from_dev_b, ...]
            input_words = tf.concat(input_words, axis=0)

        input_words = tf.tile(input_words[0:1], [self.num_images_per_log, 1])

        batch_concat_images, height_concat_images = self._gen_samples(
            test_z, input_words, generator)

        (
            batch_concat_images,
            height_concat_images,
            input_words,
        ) = self._convert_per_replica_tensor(
            self.strategy,
            batch_concat_images,
            height_concat_images,
            input_words,
        )

        ocr_images = aster_ocr.convert_inputs(batch_concat_images,
                                              tf.tile(input_words, [2, 1]),
                                              blank_label=0)

        text_log = self._get_text_log(input_words[0:1], ocr_images, aster_ocr)
        summary_images = generator_output_to_uint8(height_concat_images)

        with self.train_summary_writer.as_default():
            tf.summary.image("images", summary_images, step=step)
            tf.summary.text("words", text_log, step=step)
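
    # Hypothetical call site for log_images (the owning class and its attributes are
    # not part of this excerpt; the names below are assumptions for illustration):
    #
    #     if step % image_summary_step_frequency == 0:
    #         image_logger.log_images(input_words, g_clone, aster_ocr, step)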
Example #3
def filter_out_bad_images() -> None:
    """ Filters out the images of the text box dataset for which the OCR loss is below the OCR_LOSS_THRESHOLD """

    print("Filtering out bad images")
    aster_ocr = AsterInferer()

    with open(
        os.path.join(cfg.training_text_boxes_dir, "annotations.txt"), "r"
    ) as annotations:
        with open(
            os.path.join(cfg.training_text_boxes_dir, "annotations_filtered.txt"), "w"
        ) as annotations_filtered:
            lines = annotations.readlines()

            for i, data in tqdm(enumerate(lines)):
                image_name, word = data.split(",", 1)
                word = word.strip("\n")

                if len(word) > cfg.max_char_number or len(word) == 0:
                    continue

                image = cv2.imread(
                    os.path.join(cfg.training_text_boxes_dir, image_name)
                )
                h, w, _ = image.shape

                image = cv2.resize(
                    image, (cfg.aster_image_dims[1], cfg.aster_image_dims[0])
                )
                image = image.astype(np.float32) / 127.5 - 1.0
                image = tf.expand_dims(tf.constant(image), 0)

                ocr_label_array = tf.constant(string_to_aster_int_sequence([word]))

                prediction = aster_ocr(image)

                loss = (
                    softmax_cross_entropy_loss(prediction, ocr_label_array)
                    * cfg.batch_size
                )
                if loss < OCR_LOSS_THRESHOLD:
                    annotations_filtered.write(data)
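
# Hypothetical invocation sketch: the excerpt defines the filter but does not show
# how it is launched; running it as a standalone script is an assumption.
if __name__ == "__main__":
    filter_out_bad_images()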
Example #4
    def __init__(self):

        self.batch_size = cfg.batch_size
        self.strategy = cfg.strategy
        self.max_steps = cfg.max_steps
        self.summary_steps_frequency = cfg.summary_steps_frequency
        self.image_summary_step_frequency = cfg.image_summary_step_frequency
        self.save_step_frequency = cfg.save_step_frequency
        self.log_dir = cfg.log_dir

        self.validation_step_frequency = cfg.validation_step_frequency
        self.tensorboard_writer = TensorboardWriter(self.log_dir)
        # set optimizer params
        self.g_opt = self.update_optimizer_params(cfg.g_opt)
        self.d_opt = self.update_optimizer_params(cfg.d_opt)
        self.pl_mean = tf.Variable(
            initial_value=0.0,
            name="pl_mean",
            trainable=False,
            synchronization=tf.VariableSynchronization.ON_READ,
            aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA,
        )
        self.training_data_loader = TrainingDataLoader()
        self.validation_data_loader = ValidationDataLoader("validation_corpus.txt")
        self.model_loader = ModelLoader()
        # create model: model and optimizer must be created under `strategy.scope`
        (
            self.discriminator,
            self.generator,
            self.g_clone,
        ) = self.model_loader.initiate_models()

        # set optimizers
        self.d_optimizer = tf.keras.optimizers.Adam(
            self.d_opt["learning_rate"],
            beta_1=self.d_opt["beta1"],
            beta_2=self.d_opt["beta2"],
            epsilon=self.d_opt["epsilon"],
        )
        self.g_optimizer = tf.keras.optimizers.Adam(
            self.g_opt["learning_rate"],
            beta_1=self.g_opt["beta1"],
            beta_2=self.g_opt["beta2"],
            epsilon=self.g_opt["epsilon"],
        )
        self.ocr_optimizer = tf.keras.optimizers.Adam(
            self.g_opt["learning_rate"],
            beta_1=self.g_opt["beta1"],
            beta_2=self.g_opt["beta2"],
            epsilon=self.g_opt["epsilon"],
        )
        self.ocr_loss_weight = cfg.ocr_loss_weight

        self.aster_ocr = AsterInferer()

        self.training_step = TrainingStep(
            self.generator,
            self.discriminator,
            self.aster_ocr,
            self.g_optimizer,
            self.ocr_optimizer,
            self.d_optimizer,
            self.g_opt["reg_interval"],
            self.d_opt["reg_interval"],
            self.pl_mean,
        )

        self.validation_step = ValidationStep(self.g_clone, self.aster_ocr)

        self.manager = self.model_loader.load_checkpoint(
            ckpt_kwargs={
                "d_optimizer": self.d_optimizer,
                "g_optimizer": self.g_optimizer,
                "ocr_optimizer": self.ocr_optimizer,
                "discriminator": self.discriminator,
                "generator": self.generator,
                "g_clone": self.g_clone,
                "pl_mean": self.pl_mean,
            },
            model_description="Full model",
            expect_partial=False,
            ckpt_dir=cfg.ckpt_dir,
            max_to_keep=cfg.num_ckpts_to_keep,
        )
Example #5
    def __init__(self):
        self.generator = ModelLoader().load_generator(is_g_clone=True,
                                                      ckpt_dir=cfg.ckpt_dir)
        self.aster_ocr = AsterInferer()
        self.test_step = ValidationStep(self.generator, self.aster_ocr)
        self.strategy = cfg.strategy
Example #6
def infer_images(aster: AsterInferer, images_dir: str) -> None:
    """
    Runs the pre-trained OCR on every image of a directory and prints the predicted words.

    Parameters
    ----------
    aster: Pre-trained OCR.
    images_dir: Directory containing the images to run the OCR on.

    """
    for image_name in os.listdir(images_dir):
        image_path = os.path.join(images_dir, image_name)
        image = cv2.imread(image_path)
        h, w, _ = image.shape

        ocr_image = cv2.resize(
            image, (cfg.aster_image_dims[1], cfg.aster_image_dims[0])
        )
        ocr_image = ocr_image.astype(np.float32) / 127.5 - 1.0
        ocr_image = tf.expand_dims(tf.constant(ocr_image), 0)

        logits = aster(ocr_image)
        sequence_length = [logits.shape[1]]
        sequences_decoded = tf.nn.ctc_greedy_decoder(
            tf.transpose(logits, [1, 0, 2]), sequence_length, merge_repeated=False
        )[0][0]
        sequences_decoded = tf.sparse.to_dense(sequences_decoded).numpy()
        word = cfg.char_tokenizer.aster.sequences_to_texts(sequences_decoded)[0]
        print(image_path)
        print(word)


if __name__ == "__main__":
    aster = AsterInferer()
    infer_images(aster, IMAGES_DIR)
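
# Minimal sketch of the CTC greedy decoding used in infer_images, on dummy logits.
# The shapes (batch of 1, 16 time steps, 40 character classes) are illustrative
# assumptions, not values taken from the configuration above.
import tensorflow as tf

dummy_logits = tf.random.normal([1, 16, 40])           # (batch, time, classes)
sequence_length = [dummy_logits.shape[1]]              # one full-length sequence
decoded = tf.nn.ctc_greedy_decoder(
    tf.transpose(dummy_logits, [1, 0, 2]),             # CTC expects (time, batch, classes)
    sequence_length,
    merge_repeated=False,
)[0][0]                                                 # first decoded SparseTensor
dense_labels = tf.sparse.to_dense(decoded).numpy()      # integer label indices per image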
Example #7
class Projector:
    """ Projects a text box to find the latent vector responsible for its style. """
    def __init__(self, text_of_the_image):
        self.text_of_the_image = text_of_the_image
        self.image_width = cfg.char_width * len(text_of_the_image)
        self.char_height = cfg.char_height
        perceptual_weights_dir = os.path.join(cfg.working_dir,
                                              "projector/perceptual_weights")
        self.vgg_ckpt_fn = os.path.join(perceptual_weights_dir, "vgg",
                                        "exported")
        self.lin_ckpt_fn = os.path.join(perceptual_weights_dir, "lin",
                                        "exported")
        self.perceptual_loss = learned_perceptual_metric_model(
            self.char_height, self.image_width, self.vgg_ckpt_fn,
            self.lin_ckpt_fn)
        self.infere = Infere()

        self.aster_ocr = AsterInferer()
        self.generator = ModelLoader().load_generator(is_g_clone=True,
                                                      ckpt_dir=cfg.ckpt_dir)

        self.n_mean_latent = 10000  # number of latents to take the mean from
        self.num_steps = 1000
        self.save_and_log_frequency = 100
        self.lr_rampup = 0.05  # duration of the learning rate warmup
        self.lr_rampdown = 0.25  # duration of the learning rate decay
        self.lr = 0.1
        self.noise_strength_level = 0.05
        self.noise_ramp = 0.75  # duration of the noise level decay
        self.optimizer = tf.keras.optimizers.Adam()
        self.ocr_loss_factor = 0.1

    def _get_lr(self, t: float) -> float:
        """
        Computes a new learning rate.

        Parameters
        ----------
        t: Ratio of the current step over the total number of steps.

        Returns
        -------
        The new learning rate

        """
        lr_ramp = min(1, (1 - t) / self.lr_rampdown)
        lr_ramp = 0.5 - 0.5 * math.cos(lr_ramp * math.pi)
        lr_ramp = lr_ramp * min(1, t / self.lr_rampup)

        return self.lr * lr_ramp
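
    # Illustrative values of the schedule above, assuming the defaults set in
    # __init__ (lr=0.1, lr_rampup=0.05, lr_rampdown=0.25): the rate ramps up
    # linearly over the first 5% of the steps, plateaus at 0.1, then follows a
    # cosine decay over the last 25%.
    #   t = 0.025 -> lr = 0.05   (halfway through the warm-up)
    #   t = 0.50  -> lr = 0.10   (plateau)
    #   t = 0.90  -> lr ~ 0.035  (inside the ramp-down)
    #   t = 1.00  -> lr = 0.00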

    def _compute_w_latent(self) -> Tuple[tf.float32, tf.float32]:
        """
        Computes the style vector variable to train. This variable is initialized as the mean of self.n_mean_latent
        random style vectors.

        Returns
        -------
        w_latent_std: Deviation of the sampled style vectors around their mean, later used to scale the latent noise.
        w_latent_var: Style vector variable to train.

        """
        z_latent = tf.random.normal(shape=[self.n_mean_latent, cfg.z_dim])
        w_latent = self.generator.latent_encoder(z_latent,
                                                 training=False)[:, 1, :]
        w_latent_mean = tf.reduce_mean(w_latent, axis=0, keepdims=True)
        w_latent_std = (tf.reduce_sum(
            (w_latent - w_latent_mean)**2) / self.n_mean_latent)**0.5

        w_latent_var = tf.Variable(w_latent_mean,
                                   name="w_latent_var",
                                   trainable=True)
        return w_latent_std, w_latent_var
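
    # In effect w_latent_std = sqrt( sum_i ||w_i - w_mean||^2 / n_mean_latent ), the
    # root-mean-square distance of the sampled style vectors to their mean; main()
    # uses it below to scale the noise added to w_latent_var.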

    def _load_image(self, target_image_path: str,
                    image_width: int) -> tf.float32:
        """
        Loads and preprocesses the target image.

        Parameters
        ----------
        target_image_path: Path of the image to project.
        image_width: Width of the preprocessed image.

        Returns
        -------
        The loaded image as a tensor of shape (1, char_height, image_width, 3).

        """
        image = cv2.imread(target_image_path)
        image = cv2.resize(image, (image_width, self.char_height))
        return tf.expand_dims(tf.constant(image), 0)

    def main(self, target_image_path: str, output_dir: str) -> None:
        """
        Entry point of the Projector.

        Parameters
        ----------
        target_image_path: Path of the image to project.
        output_dir: Directory in which the output styles and images are saved.

        """
        target_image = self._load_image(target_image_path, self.image_width)
        input_word_array = string_to_main_int_sequence(
            [self.text_of_the_image])
        ocr_label = string_to_aster_int_sequence([self.text_of_the_image])

        w_latent_std, w_latent_var = self._compute_w_latent()

        word_encoded = self.generator.word_encoder(
            input_word_array,
            batch_size=1,
            training=False,
        )

        saved_latents = []
        loss_tracker = LossTracker(["perceptual_loss"])

        for step in tqdm(range(1, self.num_steps + 1)):
            t = step / self.num_steps
            lr = self._get_lr(t)
            self.optimizer.lr.assign(lr)

            noise_strength = (w_latent_std * self.noise_strength_level *
                              max(0, 1 - t / self.noise_ramp)**2)
            w_latent_noise = tf.random.normal(
                shape=w_latent_var.shape) * noise_strength

            loss = self._projector_step(
                w_latent_noise,
                w_latent_var,
                ocr_label,
                word_encoded,
                input_word_array,
                target_image,
            )
            loss_tracker.increment_losses({"perceptual_loss": loss})

            if step % self.save_and_log_frequency == 0:
                saved_latents.append(w_latent_var.numpy())
                loss_tracker.print_losses(step)
                self.infere.genererate_chosen_words(
                    [
                        self.text_of_the_image,
                    ],
                    prefix="projected_image" + str(step),
                    output_dir=output_dir,
                    do_sentence=False,
                    w_latents=saved_latents[-1],
                )
                with open(os.path.join(output_dir, "latents.txt"),
                          "w") as file:
                    for latent in saved_latents:
                        file.write(str(latent) + "\n")

    def _get_ocr_loss(self, ocr_label: tf.int32, generated_image: tf.float32,
                      input_word: tf.int32) -> tf.float32:
        """
        Computes the softmax crossentropy OCR loss.

        Parameters
        ----------
        ocr_label: Integer sequence obtained from the input word (initially a string) using the ASTER_CHAR_VECTOR.
        generated_image: Text box generated with our model.
        input_word: Integer sequence obtained from the input word (initially a string) using the MAIN_CHAR_VECTOR.

        Returns
        -------
        The OCR loss

        """

        fake_images_ocr_format = self.aster_ocr.convert_inputs(generated_image,
                                                               input_word,
                                                               blank_label=0)

        logits = self.aster_ocr(fake_images_ocr_format)
        return softmax_cross_entropy_loss(logits, ocr_label)

    def get_perceptual_loss(self, generated_image: tf.float32,
                            target_image: tf.float32) -> tf.float32:
        """
        Computes the perceptual loss.

        Parameters
        ----------
        generated_image: Text box generated with our model.
        target_image: The text box the projector is trying to extract the style from.

        Returns
        -------
        The perceptual loss

        """
        generated_image = generated_image[:, :, :, :self.image_width]
        generated_image = tf.transpose(generated_image, (0, 2, 3, 1))
        generated_image = (tf.clip_by_value(generated_image, -1.0, 1.0) +
                           1.0) * 127.5
        return self.perceptual_loss([target_image, generated_image])
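
    # The clip-and-rescale above maps generator outputs from [-1, 1] to [0, 255],
    # matching the raw pixel range of the target image loaded in _load_image:
    # -1 -> 0, 0 -> 127.5, 1 -> 255.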

    @tf.function()
    def _projector_step(
        self,
        w_latent_noise: tf.float32,
        w_latent_var: tf.Variable,
        ocr_label: tf.int32,
        word_encoded: tf.float32,
        input_word: tf.int32,
        target_image: tf.float32,
    ) -> tf.float32:
        """
        Training step for the projector.

        Parameters
        ----------
        w_latent_noise: Noise applied on w_latent_var.
        w_latent_var: Style vector the projector is training.
        ocr_label: Integer sequence obtained from the input word (initially a string) using the ASTER_CHAR_VECTOR.
        word_encoded: Output of the Word Encoder when inferring input_word.
        input_word: Integer sequence obtained from the input word (initially a string) using the MAIN_CHAR_VECTOR.
        target_image: The text box the projector is trying to extract the style from.

        Returns
        -------
        The resulting loss

        """
        with tf.GradientTape() as tape:
            w_latent_final = tf.tile(
                tf.expand_dims(
                    w_latent_var + w_latent_noise,
                    0,
                ),
                [1, self.generator.n_style, 1],
            )

            generated_image = self.generator.synthesis(
                [word_encoded, w_latent_final], training=False)

            ocr_loss = self._get_ocr_loss(ocr_label, generated_image,
                                          input_word)
            p_loss = self.get_perceptual_loss(generated_image, target_image)
            loss = p_loss + self.ocr_loss_factor * ocr_loss
        gradients = tape.gradient(loss, [w_latent_var])
        self.optimizer.apply_gradients(zip(gradients, [w_latent_var]))
        return loss
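
# Hypothetical usage sketch for the Projector above; the text and the paths are
# placeholders, not values taken from the source.
#
#     projector = Projector("hello")
#     projector.main("path/to/hello_text_box.png", "path/to/output_dir")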