def log_images(
    self,
    input_words: tf.int32,
    generator: Generator,
    aster_ocr: AsterInferer,
    step: int,
) -> None:
    """
    Generates text boxes and saves them.

    Parameters
    ----------
    input_words: Integer sequences obtained from the input words (initially
        strings) using the MAIN_CHAR_VECTOR.
    generator: Generator used for inference (moving average of the trained
        Generator).
    aster_ocr: Pre-trained OCR.
    step: Current training step.
    """
    test_z = tf.random.normal(
        shape=[self.num_images_per_log, self.z_dim], dtype=tf.dtypes.float32
    )

    if cfg.strategy.num_replicas_in_sync > 1:
        input_words = input_words.values  # a list [x_from_dev_a, x_from_dev_b, ...]
        input_words = tf.concat(input_words, axis=0)

    input_words = tf.tile(input_words[0:1], [self.num_images_per_log, 1])

    batch_concat_images, height_concat_images = self._gen_samples(
        test_z, input_words, generator
    )
    (
        batch_concat_images,
        height_concat_images,
        input_words,
    ) = self._convert_per_replica_tensor(
        self.strategy, batch_concat_images, height_concat_images, input_words
    )

    ocr_images = aster_ocr.convert_inputs(
        batch_concat_images, tf.tile(input_words, [2, 1]), blank_label=0
    )
    text_log = self._get_text_log(input_words[0:1], ocr_images, aster_ocr)

    summary_images = generator_output_to_uint8(height_concat_images)

    with self.train_summary_writer.as_default():
        tf.summary.image("images", summary_images, step=step)
        tf.summary.text("words", text_log, step=step)
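# For reference, a minimal sketch of what `generator_output_to_uint8` is assumed
# to do here: map channels-first generator outputs in [-1, 1] to channels-last
# uint8 images suitable for `tf.summary.image`. This mirrors the clipping and
# rescaling used in `get_perceptual_loss` further below; the function name is
# from the repo but this body is an assumption, not the repo's exact code.
def generator_output_to_uint8_sketch(images: tf.float32) -> tf.uint8:
    images = tf.transpose(images, (0, 2, 3, 1))  # NCHW -> NHWC
    images = (tf.clip_by_value(images, -1.0, 1.0) + 1.0) * 127.5  # [-1, 1] -> [0, 255]
    return tf.cast(images, tf.uint8)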
def filter_out_bad_images() -> None:
    """
    Filters out the images of the text box dataset whose OCR loss is greater
    than or equal to OCR_LOSS_THRESHOLD, keeping only the images the OCR can
    read reliably.
    """
    print("Filtering out bad images")
    aster_ocr = AsterInferer()

    with open(
        os.path.join(cfg.training_text_boxes_dir, "annotations.txt"), "r"
    ) as annotations:
        with open(
            os.path.join(cfg.training_text_boxes_dir, "annotations_filtered.txt"), "w"
        ) as annotations_filtered:
            lines = annotations.readlines()

            for i, data in tqdm(enumerate(lines)):
                image_name, word = data.split(",", 1)
                word = word.strip("\n")

                if len(word) > cfg.max_char_number or len(word) == 0:
                    continue

                image = cv2.imread(
                    os.path.join(cfg.training_text_boxes_dir, image_name)
                )
                h, w, _ = image.shape
                image = cv2.resize(
                    image, (cfg.aster_image_dims[1], cfg.aster_image_dims[0])
                )
                image = image.astype(np.float32) / 127.5 - 1.0
                image = tf.expand_dims(tf.constant(image), 0)

                ocr_label_array = tf.constant(string_to_aster_int_sequence([word]))

                prediction = aster_ocr(image)
                loss = (
                    softmax_cross_entropy_loss(prediction, ocr_label_array)
                    * cfg.batch_size
                )

                if loss < OCR_LOSS_THRESHOLD:
                    annotations_filtered.write(data)
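# Note on the `* cfg.batch_size` rescaling above: under a tf.distribute strategy
# the loss helper presumably averages over the *global* batch size, so scoring a
# single image requires multiplying back by cfg.batch_size to recover a per-image
# loss. A hypothetical sketch of that convention (not the repo's exact helper):
def softmax_cross_entropy_loss_sketch(logits: tf.float32, labels: tf.int32) -> tf.float32:
    per_timestep = tf.nn.sparse_softmax_cross_entropy_with_logits(
        labels=labels, logits=logits
    )  # [batch, time]
    per_example = tf.reduce_mean(per_timestep, axis=-1)  # [batch]
    # Divide by the global batch size, as tf.distribute expects.
    return tf.nn.compute_average_loss(per_example, global_batch_size=cfg.batch_size)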
def __init__(self):
    self.batch_size = cfg.batch_size
    self.strategy = cfg.strategy
    self.max_steps = cfg.max_steps
    self.summary_steps_frequency = cfg.summary_steps_frequency
    self.image_summary_step_frequency = cfg.image_summary_step_frequency
    self.save_step_frequency = cfg.save_step_frequency
    self.log_dir = cfg.log_dir
    self.validation_step_frequency = cfg.validation_step_frequency
    self.tensorboard_writer = TensorboardWriter(self.log_dir)

    # set optimizer params
    self.g_opt = self.update_optimizer_params(cfg.g_opt)
    self.d_opt = self.update_optimizer_params(cfg.d_opt)
    self.pl_mean = tf.Variable(
        initial_value=0.0,
        name="pl_mean",
        trainable=False,
        synchronization=tf.VariableSynchronization.ON_READ,
        aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA,
    )
    self.training_data_loader = TrainingDataLoader()
    self.validation_data_loader = ValidationDataLoader("validation_corpus.txt")
    self.model_loader = ModelLoader()

    # create model: model and optimizer must be created under `strategy.scope`
    (
        self.discriminator,
        self.generator,
        self.g_clone,
    ) = self.model_loader.initiate_models()

    # set optimizers
    self.d_optimizer = tf.keras.optimizers.Adam(
        self.d_opt["learning_rate"],
        beta_1=self.d_opt["beta1"],
        beta_2=self.d_opt["beta2"],
        epsilon=self.d_opt["epsilon"],
    )
    self.g_optimizer = tf.keras.optimizers.Adam(
        self.g_opt["learning_rate"],
        beta_1=self.g_opt["beta1"],
        beta_2=self.g_opt["beta2"],
        epsilon=self.g_opt["epsilon"],
    )
    self.ocr_optimizer = tf.keras.optimizers.Adam(
        self.g_opt["learning_rate"],
        beta_1=self.g_opt["beta1"],
        beta_2=self.g_opt["beta2"],
        epsilon=self.g_opt["epsilon"],
    )
    self.ocr_loss_weight = cfg.ocr_loss_weight

    self.aster_ocr = AsterInferer()
    self.training_step = TrainingStep(
        self.generator,
        self.discriminator,
        self.aster_ocr,
        self.g_optimizer,
        self.ocr_optimizer,
        self.d_optimizer,
        self.g_opt["reg_interval"],
        self.d_opt["reg_interval"],
        self.pl_mean,
    )
    self.validation_step = ValidationStep(self.g_clone, self.aster_ocr)

    self.manager = self.model_loader.load_checkpoint(
        ckpt_kwargs={
            "d_optimizer": self.d_optimizer,
            "g_optimizer": self.g_optimizer,
            "ocr_optimizer": self.ocr_optimizer,
            "discriminator": self.discriminator,
            "generator": self.generator,
            "g_clone": self.g_clone,
            "pl_mean": self.pl_mean,
        },
        model_description="Full model",
        expect_partial=False,
        ckpt_dir=cfg.ckpt_dir,
        max_to_keep=cfg.num_ckpts_to_keep,
    )
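# `update_optimizer_params` is assumed to apply StyleGAN2-style lazy
# regularization: when a regularizer only runs every `reg_interval` steps, the
# Adam hyperparameters are rescaled so the combined updates keep the same
# effective magnitude. A hypothetical sketch of that standard adjustment:
def update_optimizer_params_sketch(opt_params: dict) -> dict:
    mb_ratio = opt_params["reg_interval"] / (opt_params["reg_interval"] + 1)
    adjusted = dict(opt_params)
    adjusted["learning_rate"] = opt_params["learning_rate"] * mb_ratio
    adjusted["beta1"] = opt_params["beta1"] ** mb_ratio
    adjusted["beta2"] = opt_params["beta2"] ** mb_ratio
    return adjusted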
def __init__(self):
    self.generator = ModelLoader().load_generator(
        is_g_clone=True, ckpt_dir=cfg.ckpt_dir
    )
    self.aster_ocr = AsterInferer()
    self.test_step = ValidationStep(self.generator, self.aster_ocr)
    self.strategy = cfg.strategy
def infer_images(aster: AsterInferer, images_dir: str) -> None:
    """
    Runs the OCR on every image of a directory and prints the decoded words.

    Parameters
    ----------
    aster: Pre-trained OCR.
    images_dir: Directory containing the images to infer.
    """
    for image_name in os.listdir(images_dir):
        image_path = os.path.join(images_dir, image_name)
        image = cv2.imread(image_path)
        h, w, _ = image.shape
        ocr_image = cv2.resize(
            image, (cfg.aster_image_dims[1], cfg.aster_image_dims[0])
        )
        ocr_image = ocr_image.astype(np.float32) / 127.5 - 1.0
        ocr_image = tf.expand_dims(tf.constant(ocr_image), 0)

        logits = aster(ocr_image)
        sequence_length = [logits.shape[1]]
        sequences_decoded = tf.nn.ctc_greedy_decoder(
            tf.transpose(logits, [1, 0, 2]), sequence_length, merge_repeated=False
        )[0][0]
        sequences_decoded = tf.sparse.to_dense(sequences_decoded).numpy()
        word = cfg.char_tokenizer.aster.sequences_to_texts(sequences_decoded)[0]

        print(image_path)
        print(word)


if __name__ == "__main__":
    aster = AsterInferer()
    infer_images(aster, IMAGES_DIR)
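# `tf.nn.ctc_greedy_decoder` expects time-major inputs of shape
# [max_time, batch_size, num_classes] (hence the transpose above) and returns a
# list of SparseTensors of label indices. A self-contained toy example of the
# call, with made-up shapes rather than ASTER's real vocabulary size:
import tensorflow as tf

toy_logits = tf.random.normal([1, 20, 37])  # [batch, time, classes]
decoded, _ = tf.nn.ctc_greedy_decoder(
    tf.transpose(toy_logits, [1, 0, 2]),  # -> [time, batch, classes]
    sequence_length=[20],
    merge_repeated=False,
)
dense = tf.sparse.to_dense(decoded[0])  # [batch, decoded_length] integer labels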
class Projector:
    """
    Projects a text box to find the latent vector responsible for its style.
    """

    def __init__(self, text_of_the_image):
        self.text_of_the_image = text_of_the_image
        self.image_width = cfg.char_width * len(text_of_the_image)
        self.char_height = cfg.char_height

        perceptual_weights_dir = os.path.join(
            cfg.working_dir, "projector/perceptual_weights"
        )
        self.vgg_ckpt_fn = os.path.join(perceptual_weights_dir, "vgg", "exported")
        self.lin_ckpt_fn = os.path.join(perceptual_weights_dir, "lin", "exported")
        self.perceptual_loss = learned_perceptual_metric_model(
            self.char_height, self.image_width, self.vgg_ckpt_fn, self.lin_ckpt_fn
        )

        self.infere = Infere()
        self.aster_ocr = AsterInferer()
        self.generator = ModelLoader().load_generator(
            is_g_clone=True, ckpt_dir=cfg.ckpt_dir
        )

        self.n_mean_latent = 10000  # number of latents to take the mean from
        self.num_steps = 1000
        self.save_and_log_frequency = 100
        self.lr_rampup = 0.05  # duration of the learning rate warmup
        self.lr_rampdown = 0.25  # duration of the learning rate decay
        self.lr = 0.1
        self.noise_strength_level = 0.05
        self.noise_ramp = 0.75  # duration of the noise level decay
        self.optimizer = tf.keras.optimizers.Adam()
        self.ocr_loss_factor = 0.1

    def _get_lr(self, t: float) -> float:
        """
        Computes a new learning rate following a cosine schedule with warmup
        and decay.

        Parameters
        ----------
        t: Ratio of the current step over the total number of steps.

        Returns
        -------
        The new learning rate.
        """
        lr_ramp = min(1, (1 - t) / self.lr_rampdown)
        lr_ramp = 0.5 - 0.5 * math.cos(lr_ramp * math.pi)
        lr_ramp = lr_ramp * min(1, t / self.lr_rampup)
        return self.lr * lr_ramp

    def _compute_w_latent(self) -> Tuple[tf.float32, tf.float32]:
        """
        Computes the style vector variable to train. This variable is
        initialized as the mean of self.n_mean_latent random style vectors.

        Returns
        -------
        w_latent_std: Standard deviation of the computed mean style vector.
        w_latent_var: Style vector variable to train.
        """
        z_latent = tf.random.normal(shape=[self.n_mean_latent, cfg.z_dim])
        w_latent = self.generator.latent_encoder(z_latent, training=False)[:, 1, :]
        w_latent_mean = tf.reduce_mean(w_latent, axis=0, keepdims=True)
        w_latent_std = (
            tf.reduce_sum((w_latent - w_latent_mean) ** 2) / self.n_mean_latent
        ) ** 0.5
        w_latent_var = tf.Variable(
            w_latent_mean, name="w_latent_var", trainable=True
        )
        return w_latent_std, w_latent_var

    def _load_image(self, target_image_path: str, image_width: int) -> tf.float32:
        """
        Loads and preprocesses the target image.

        Parameters
        ----------
        target_image_path: Path of the image to project.
        image_width: Width of the preprocessed image.

        Returns
        -------
        The loaded image, resized and batched to shape
        (1, char_height, image_width, 3).
        """
        image = cv2.imread(target_image_path)
        image = cv2.resize(image, (image_width, self.char_height))
        return tf.expand_dims(tf.constant(image), 0)

    def main(self, target_image_path: str, output_dir: str) -> None:
        """
        Entry point of the Projector.

        Parameters
        ----------
        target_image_path: Path of the image to project.
        output_dir: Directory in which the output styles and images are saved.
""" target_image = self._load_image(target_image_path, self.image_width) input_word_array = string_to_main_int_sequence( [self.text_of_the_image]) ocr_label = string_to_aster_int_sequence([self.text_of_the_image]) w_latent_std, w_latent_var = self._compute_w_latent() word_encoded = self.generator.word_encoder( input_word_array, batch_size=1, training=False, ) saved_latents = [] loss_tracker = LossTracker(["perceptual_loss"]) for step in tqdm(range(1, self.num_steps + 1)): t = step / self.num_steps lr = self._get_lr(t) self.optimizer.lr.assign(lr) noise_strength = (w_latent_std * self.noise_strength_level * max(0, 1 - t / self.noise_ramp)**2) w_latent_noise = tf.random.normal( shape=w_latent_var.shape) * noise_strength loss = self._projector_step( w_latent_noise, w_latent_var, ocr_label, word_encoded, input_word_array, target_image, ) loss_tracker.increment_losses({"perceptual_loss": loss}) if step % self.save_and_log_frequency == 0: saved_latents.append(w_latent_var.numpy()) loss_tracker.print_losses(step) self.infere.genererate_chosen_words( [ self.text_of_the_image, ], prefix="projected_image" + str(step), output_dir=output_dir, do_sentence=False, w_latents=saved_latents[-1], ) with open(os.path.join(output_dir, "latents.txt"), "w") as file: for latent in saved_latents: file.write(str(latent) + "\n") def _get_ocr_loss(self, ocr_label: tf.int32, generated_image: tf.float32, input_word: tf.int32) -> tf.float32: """ Computes the softmax crossentropy OCR loss. Parameters ---------- ocr_label: Integer sequence obtained from the input word (initially a string) using the ASTER_CHAR_VECTOR. generated_image: Text box generated with our model. input_word: Integer sequence obtained from the input word (initially a string) using the MAIN_CHAR_VECTOR. Returns ------- The OCR loss """ fake_images_ocr_format = self.aster_ocr.convert_inputs(generated_image, input_word, blank_label=0) logits = self.aster_ocr(fake_images_ocr_format) return softmax_cross_entropy_loss(logits, ocr_label) def get_perceptual_loss(self, generated_image: tf.float32, target_image: tf.float32) -> tf.float32: """ Computes the perceptual loss. Parameters ---------- generated_image: Text box generated with our model. target_image: The text box the projector is trying to extract the style from. Returns ------- The perceptual loss """ generated_image = generated_image[:, :, :, :self.image_width] generated_image = tf.transpose(generated_image, (0, 2, 3, 1)) generated_image = (tf.clip_by_value(generated_image, -1.0, 1.0) + 1.0) * 127.5 return self.perceptual_loss([target_image, generated_image]) @tf.function() def _projector_step( self, w_latent_noise: tf.float32, w_latent_var: tf.Variable, ocr_label: tf.int32, word_encoded: tf.float32, input_word: tf.int32, target_image: tf.float32, ) -> tf.float32: """ Training step for the projector. Parameters ---------- w_latent_noise: Noise applied on w_latent_var. w_latent_var: Style vector the projector is training. ocr_label: Integer sequence obtained from the input word (initially a string) using the ASTER_CHAR_VECTOR. word_encoded: Output of the Word Encoder when inferring input_word. input_word: Integer sequence obtained from the input word (initially a string) using the MAIN_CHAR_VECTOR. target_image: The text box the projector is trying to extract the style from. 
        Returns
        -------
        The resulting loss.
        """
        with tf.GradientTape() as tape:
            w_latent_final = tf.tile(
                tf.expand_dims(w_latent_var + w_latent_noise, 0),
                [1, self.generator.n_style, 1],
            )

            generated_image = self.generator.synthesis(
                [word_encoded, w_latent_final], training=False
            )

            ocr_loss = self._get_ocr_loss(ocr_label, generated_image, input_word)
            p_loss = self.get_perceptual_loss(generated_image, target_image)
            loss = p_loss + self.ocr_loss_factor * ocr_loss

        gradients = tape.gradient(loss, [w_latent_var])
        self.optimizer.apply_gradients(zip(gradients, [w_latent_var]))
        return loss
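# Example usage, assuming hypothetical paths: project the style of an existing
# text box image whose rendered text is known, writing intermediate images and
# the saved latents to the output directory.
if __name__ == "__main__":
    projector = Projector(text_of_the_image="hello")
    projector.main(
        target_image_path="data/target_text_box.png",  # assumed path
        output_dir="projector_outputs",  # assumed path
    )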