def inference(
    images_path: str,
    texts_path: str,
    test_imgs_file_path,
    batch_size: int,
    prefetch_size: int,
    checkpoint_path: str,
    joint_space: int,
    num_layers: int,
) -> None:
    """Evaluate a checkpointed model on the Flickr8k/30k test split.

    Iterates once over the test data, accumulates the loss and the image and
    caption embeddings in an evaluator, and finally logs the image-to-text and
    text-to-image recall.

    Args:
        images_path: A path where all the images are located.
        texts_path: Path where the text doc with the descriptions is.
        test_imgs_file_path: Path to a file with the test image names.
        batch_size: The batch size to be used.
        prefetch_size: How many batches to prefetch.
        checkpoint_path: Path to a valid model checkpoint.
        joint_space: The size of the joint latent space.
        num_layers: The number of rnn layers.

    Returns:
        None

    """
    # Load the image paths and captions that belong to the test split
    flickr = FlickrDataset(images_path, texts_path)
    test_image_paths, test_captions = flickr.get_data(test_imgs_file_path)
    logger.info("Test dataset created...")

    num_test_images = len(test_image_paths)
    evaluator = Evaluator(num_test_images, joint_space)
    logger.info("Test evaluator created...")

    # Start from a clean default graph before building the input pipeline
    tf.reset_default_graph()
    data_loader = InferenceLoader(
        test_image_paths, test_captions, batch_size, prefetch_size
    )
    images, captions, captions_lengths = data_loader.get_next()
    logger.info("Loader created...")

    model = VsePpModel(images, captions, captions_lengths, joint_space, num_layers)
    logger.info("Model created...")
    logger.info("Inference is starting...")

    with tf.Session() as sess:
        # Initializers
        model.init(sess, checkpoint_path)
        fetches = [
            model.loss,
            model.captions_len,
            model.image_encoded,
            model.text_encoded,
        ]
        try:
            with tqdm(total=num_test_images) as pbar:
                # The loop ends when the iterator raises OutOfRangeError
                while True:
                    batch_out = sess.run(fetches)
                    loss, lengths, embedded_images, embedded_captions = batch_out
                    evaluator.update_metrics(loss)
                    evaluator.update_embeddings(embedded_images, embedded_captions)
                    # One entry in lengths per sample of the batch
                    pbar.update(len(lengths))
        except tf.errors.OutOfRangeError:
            pass

    image2text_recall = evaluator.image2text_recall_at_k()
    logger.info(f"The image2text recall at (1, 5, 10) is: {image2text_recall}")
    text2image_recall = evaluator.text2image_recall_at_k()
    logger.info(f"The text2image recall at (1, 5, 10) is: {text2image_recall}")
def train(
    images_path: str,
    texts_path: str,
    train_imgs_file_path: str,
    val_imgs_file_path: str,
    joint_space: int,
    num_layers: int,
    learning_rate: float,
    margin: float,
    clip_val: float,
    decay_rate: int,
    weight_decay: float,
    batch_size: int,
    prefetch_size: int,
    epochs: int,
    save_model_path: str,
) -> None:
    """Starts a training session with the Flickr dataset.

    Every epoch runs one optimisation pass over the train split followed by
    one pass over the validation split. The model is saved only on epochs
    where the validation evaluator reports a new best recall.

    Args:
        images_path: A path where all the images are located.
        texts_path: Path where the text doc with the descriptions is.
        train_imgs_file_path: Path to a file with the train image names.
        val_imgs_file_path: Path to a file with the val image names.
        joint_space: The space where the encoded images and text will be
            projected.
        num_layers: Number of layers of the rnn.
        learning_rate: The learning rate.
        margin: The contrastive margin.
        clip_val: The max grad norm.
        decay_rate: Multiplied by the number of batches per epoch to obtain
            the decay steps fed to the model.
        weight_decay: The L2 loss constant.
        batch_size: The batch size to be used.
        prefetch_size: How many batches to prefetch (passed to the loader).
        epochs: The number of epochs to train the model.
        save_model_path: Where to save the model.

    Returns:
        None

    """
    dataset = FlickrDataset(images_path, texts_path)
    train_image_paths, train_captions = dataset.get_data(train_imgs_file_path)
    val_image_paths, val_captions = dataset.get_data(val_imgs_file_path)
    logger.info("Dataset created...")

    evaluator_val = Evaluator(len(val_image_paths), joint_space)
    logger.info("Evaluators created...")

    # Resetting the default graph
    tf.reset_default_graph()
    loader = TrainValLoader(
        train_image_paths,
        train_captions,
        val_image_paths,
        val_captions,
        batch_size,
        prefetch_size,
    )
    images, captions, captions_lengths = loader.get_next()
    logger.info("Loader created...")

    # decay_rate times the number of batches in one pass over the train data
    decay_steps = decay_rate * len(train_image_paths) / batch_size
    model = VsePpModel(images, captions, captions_lengths, joint_space, num_layers)
    logger.info("Model created...")
    logger.info("Training is starting...")

    with tf.Session() as sess:
        # Initializers
        model.init(sess)

        # The hyperparameters do not change between steps, so build the feed
        # dict once instead of on every iteration.
        train_feed = {
            model.weight_decay: weight_decay,
            model.learning_rate: learning_rate,
            model.margin: margin,
            model.decay_steps: decay_steps,
            model.clip_value: clip_val,
        }

        for epoch in range(epochs):
            # Reset evaluators
            evaluator_val.reset_all_vars()

            # Initialize iterator with train data
            sess.run(loader.train_init)
            try:
                with tqdm(total=len(train_image_paths)) as pbar:
                    # The loop ends when the iterator raises OutOfRangeError
                    while True:
                        # captions_len is fetched only to count the samples in
                        # the batch, the same way inference() does; this avoids
                        # pulling the whole captions batch out of the session.
                        _, loss, lengths = sess.run(
                            [model.optimize, model.loss, model.captions_len],
                            feed_dict=train_feed,
                        )
                        pbar.update(len(lengths))
                        pbar.set_postfix({"Batch loss": loss})
            except tf.errors.OutOfRangeError:
                pass

            # Initialize iterator with validation data
            sess.run(loader.val_init)
            try:
                with tqdm(total=len(val_image_paths)) as pbar:
                    while True:
                        loss, lengths, embedded_images, embedded_captions = sess.run(
                            [
                                model.loss,
                                model.captions_len,
                                model.image_encoded,
                                model.text_encoded,
                            ]
                        )
                        evaluator_val.update_metrics(loss)
                        evaluator_val.update_embeddings(
                            embedded_images, embedded_captions
                        )
                        pbar.update(len(lengths))
            except tf.errors.OutOfRangeError:
                pass

            # Persist the model only when validation recall improved
            if evaluator_val.is_best_recall_at_k():
                evaluator_val.update_best_recall_at_k()
                logger.info("=============================")
                logger.info(
                    f"Found new best on epoch {epoch + 1}!! Saving model!\n"
                    f"Current image-text recall at 1, 5, 10: "
                    f"{evaluator_val.best_image2text_recall_at_k} \n"
                    f"Current text-image recall at 1, 5, 10: "
                    f"{evaluator_val.best_text2image_recall_at_k}"
                )
                logger.info("=============================")
                model.save_model(sess, save_model_path)