Code Example #1

# tensorflow 1.x and tqdm are implied by the API calls below; FlickrDataset,
# Evaluator, InferenceLoader, VsePpModel, and logger are defined elsewhere in
# the same project and are assumed to be importable.
import tensorflow as tf
from tqdm import tqdm

def inference(
    images_path: str,
    texts_path: str,
    test_imgs_file_path: str,
    batch_size: int,
    prefetch_size: int,
    checkpoint_path: str,
    joint_space: int,
    num_layers: int,
) -> None:
    """Performs inference on the Flickr8k/30k test set.

    Args:
        images_path: A path where all the images are located.
        texts_path: Path to the text file with the image descriptions.
        test_imgs_file_path: Path to a file with the test image names.
        batch_size: The batch size to be used.
        prefetch_size: How many batches to prefetch.
        checkpoint_path: Path to a valid model checkpoint.
        joint_space: The size of the joint latent space.
        num_layers: The number of rnn layers.

    Returns:
        None

    """
    dataset = FlickrDataset(images_path, texts_path)
    # Get the test image paths and captions
    test_image_paths, test_captions = dataset.get_data(test_imgs_file_path)
    logger.info("Test dataset created...")
    evaluator_test = Evaluator(len(test_image_paths), joint_space)

    logger.info("Test evaluator created...")

    # Resetting the default graph
    tf.reset_default_graph()

    loader = InferenceLoader(test_image_paths, test_captions, batch_size,
                             prefetch_size)
    images, captions, captions_lengths = loader.get_next()
    logger.info("Loader created...")

    model = VsePpModel(images, captions, captions_lengths, joint_space,
                       num_layers)
    logger.info("Model created...")
    logger.info("Inference is starting...")

    with tf.Session() as sess:

        # Restore the model weights from the checkpoint
        model.init(sess, checkpoint_path)
        try:
            with tqdm(total=len(test_image_paths)) as pbar:
                while True:
                    loss, lengths, embedded_images, embedded_captions = sess.run(
                        [
                            model.loss,
                            model.captions_len,
                            model.image_encoded,
                            model.text_encoded,
                        ])
                    evaluator_test.update_metrics(loss)
                    evaluator_test.update_embeddings(embedded_images,
                                                     embedded_captions)
                    pbar.update(len(lengths))
        except tf.errors.OutOfRangeError:
            pass

        logger.info(f"The image2text recall at (1, 5, 10) is: "
                    f"{evaluator_test.image2text_recall_at_k()}")
        logger.info(f"The text2image recall at (1, 5, 10) is: "
                    f"{evaluator_test.text2image_recall_at_k()}")
Code Example #2

# As in Code Example #1: tensorflow 1.x and tqdm are implied by the API calls
# below; FlickrDataset, Evaluator, TrainValLoader, VsePpModel, and logger are
# project-internal and assumed to be importable.
import tensorflow as tf
from tqdm import tqdm

def train(
    images_path: str,
    texts_path: str,
    train_imgs_file_path: str,
    val_imgs_file_path: str,
    joint_space: int,
    num_layers: int,
    learning_rate: float,
    margin: float,
    clip_val: float,
    decay_rate: int,
    weight_decay: float,
    batch_size: int,
    prefetch_size: int,
    epochs: int,
    save_model_path: str,
) -> None:
    """Starts a training session with the Flickr8k dataset.

    Args:
        images_path: A path where all the images are located.
        texts_path: Path to the text file with the image descriptions.
        train_imgs_file_path: Path to a file with the train image names.
        val_imgs_file_path: Path to a file with the val image names.
        joint_space: The size of the joint latent space where the encoded
            images and text are projected.
        num_layers: The number of rnn layers.
        learning_rate: The learning rate.
        margin: The contrastive margin.
        clip_val: The max grad norm.
        decay_rate: After how many epochs to decay the learning rate.
        weight_decay: The L2 loss constant.
        batch_size: The batch size to be used.
        prefetch_size: How many batches to prefetch.
        epochs: The number of epochs to train the model.
        save_model_path: Where to save the model.

    Returns:
        None

    """
    dataset = FlickrDataset(images_path, texts_path)
    train_image_paths, train_captions = dataset.get_data(train_imgs_file_path)
    val_image_paths, val_captions = dataset.get_data(val_imgs_file_path)
    logger.info("Dataset created...")
    evaluator_val = Evaluator(len(val_image_paths), joint_space)
    logger.info("Evaluators created...")

    # Resetting the default graph
    tf.reset_default_graph()
    loader = TrainValLoader(
        train_image_paths,
        train_captions,
        val_image_paths,
        val_captions,
        batch_size,
        prefetch_size,
    )
    images, captions, captions_lengths = loader.get_next()
    logger.info("Loader created...")

    # Decay the learning rate every `decay_rate` epochs worth of steps
    decay_steps = decay_rate * len(train_image_paths) / batch_size
    model = VsePpModel(images, captions, captions_lengths, joint_space, num_layers)
    logger.info("Model created...")
    logger.info("Training is starting...")

    with tf.Session() as sess:
        # Initialize the model variables
        model.init(sess)
        for e in range(epochs):
            # Reset evaluators
            evaluator_val.reset_all_vars()

            # Initialize iterator with train data
            sess.run(loader.train_init)
            try:
                with tqdm(total=len(train_image_paths)) as pbar:
                    while True:
                        _, loss, lengths = sess.run(
                            [model.optimize, model.loss, model.captions_len],
                            feed_dict={
                                model.weight_decay: weight_decay,
                                model.learning_rate: learning_rate,
                                model.margin: margin,
                                model.decay_steps: decay_steps,
                                model.clip_value: clip_val,
                            },
                        )
                        pbar.update(len(lengths))
                        pbar.set_postfix({"Batch loss": loss})
            except tf.errors.OutOfRangeError:
                pass

            # Initialize iterator with validation data
            sess.run(loader.val_init)
            try:
                with tqdm(total=len(val_image_paths)) as pbar:
                    while True:
                        loss, lengths, embedded_images, embedded_captions = sess.run(
                            [
                                model.loss,
                                model.captions_len,
                                model.image_encoded,
                                model.text_encoded,
                            ]
                        )
                        evaluator_val.update_metrics(loss)
                        evaluator_val.update_embeddings(
                            embedded_images, embedded_captions
                        )
                        pbar.update(len(lengths))
            except tf.errors.OutOfRangeError:
                pass

            if evaluator_val.is_best_recall_at_k():
                evaluator_val.update_best_recall_at_k()
                logger.info("=============================")
                logger.info(
                    f"Found new best on epoch {e+1}!! Saving model!\n"
                    f"Current image-text recall at 1, 5, 10: "
                    f"{evaluator_val.best_image2text_recall_at_k} \n"
                    f"Current text-image recall at 1, 5, 10: "
                    f"{evaluator_val.best_text2image_recall_at_k}"
                )
                logger.info("=============================")
                model.save_model(sess, save_model_path)
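
Similarly, a hedged sketch of a training invocation; the hyperparameter values are common VSE++-style defaults chosen for illustration, not the original author's settings:

if __name__ == "__main__":
    # Hypothetical paths and hyperparameters; adjust to the actual dataset
    train(
        images_path="data/Flickr8k/images",
        texts_path="data/Flickr8k/captions.txt",
        train_imgs_file_path="data/Flickr8k/train_images.txt",
        val_imgs_file_path="data/Flickr8k/val_images.txt",
        joint_space=1024,
        num_layers=1,
        learning_rate=2e-4,
        margin=0.2,
        clip_val=2.0,
        decay_rate=15,
        weight_decay=1e-5,
        batch_size=128,
        prefetch_size=2,
        epochs=30,
        save_model_path="checkpoints/vse_pp_best",
    )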