Code example #1
def save_keyword_images(keywords_set,
                        images_dir,
                        keyword_list,
                        output_dir,
                        max_per_row=5,
                        max_images=20):
    """TODO(rpeloff) document and move to plotting (?)"""
    import os
    import matplotlib.pyplot as plt
    from moonshot.utils import image_utils
    from moonshot.utils import plotting

    file_io.check_create_dir(output_dir)

    for keyword in keyword_list:
        if keyword not in keywords_set[3]:
            logging.log(logging.INFO,
                        "Keyword not found in set: {}".format(keyword))
            continue  # skip to next keyword

        image_uids = np.unique(
            keywords_set[0][np.where(keywords_set[3] == keyword)[0]])
        n_cols = min(len(image_uids), max_per_row)
        n_rows = min(int(np.ceil(len(image_uids) / max_per_row)),
                     int(max_images / max_per_row))

        plt.figure(figsize=(n_cols * 3, n_rows * 3))
        plt.suptitle(keyword, fontsize=14)

        for image_index, uid in enumerate(image_uids):
            if image_index + 1 > max_images:
                break
            plt.subplot(n_rows, n_cols, image_index + 1)
            plt.imshow(image_utils.load_image_array(
                os.path.join(images_dir, "{}.jpg".format(uid))),
                       interpolation="lanczos")
            plt.title(uid)
            plt.axis("off")

        plotting.save_figure("{}_filtered_images.png".format(keyword),
                             path=output_dir,
                             tight_layout=True)
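
A minimal usage sketch (hypothetical, not from the project): keywords_set is assumed to be a tuple of parallel arrays in which index 0 holds image UIDs and index 3 holds keywords, as the indexing above implies.

# hypothetical call; variable names and paths are illustrative only
save_keyword_images(
    keywords_set=dev_keywords_set,  # assumed (uids, ..., keywords) tuple
    images_dir="data/external/flickr8k_images",
    keyword_list=["dog", "ball"],
    output_dir="figures/keyword_examples",
    max_per_row=5,
    max_images=20)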
Code example #2
File: plotting.py Project: rpeloff/moonshot
import os

from absl import logging

import matplotlib.pyplot as plt

from moonshot.utils import file_io


def save_figure(filename,
                path="figures",
                figure=None,
                tight_layout=False,
                fig_extension="png",
                resolution=300):
    """Save current plot or specified figure to disk."""
    file_io.check_create_dir(path)

    logging.log(logging.INFO,
                "Saving figure '{}' to directory: {}".format(filename, path))

    if figure is None:  # fetch default plot (only works before plt.show)
        figure = plt

    if tight_layout:
        figure.tight_layout()

    fig_path = os.path.join(path, filename)
    figure.savefig(fig_path, format=fig_extension, dpi=resolution)
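
Usage sketch: save_figure operates on the current pyplot figure when no figure is passed, so call it before plt.show(). The plot below is illustrative only.

import matplotlib.pyplot as plt

plt.plot([0, 1, 2], [0, 1, 4])
save_figure("loss_curve.png", path="figures", tight_layout=True)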
Code example #3
def main(argv):
    del argv  # unused

    logging.log(logging.INFO, "Logging application {}".format(__file__))
    if FLAGS.debug:
        logging.set_verbosity(logging.DEBUG)
        logging.log(logging.DEBUG, "Running in debug mode")

    physical_devices = tf.config.experimental.list_physical_devices("GPU")
    if physical_devices:  # guard against IndexError on CPU-only machines
        tf.config.experimental.set_memory_growth(physical_devices[0], True)

    # create output directory if none specified
    if FLAGS.output_dir is None:
        output_dir = os.path.join(
            "logs", __file__,
            datetime.datetime.now().strftime("%Y%m%d-%H%M%S"))
        file_io.check_create_dir(output_dir)

    # output directory specified, load model options if found
    else:
        output_dir = FLAGS.output_dir

    # print flag options
    flag_options = {}
    for flag in FLAGS.get_key_flags_for_module(__file__):
        flag_options[flag.name] = flag.value

    # logging
    logging_utils.absl_file_logger(output_dir, "log.test")

    logging.log(logging.INFO, f"Model directory: {output_dir}")
    logging.log(logging.INFO, f"Flag options: {flag_options}")

    # set seeds for reproducibility
    np.random.seed(FLAGS.seed)
    tf.random.set_seed(FLAGS.seed)

    # test baseline matching model (no background training step)
    test()
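
These main(argv) functions assume the standard absl-py entry point. A minimal sketch of the surrounding boilerplate; the flag names match those read above, but the help strings and defaults are assumptions:

from absl import app
from absl import flags

FLAGS = flags.FLAGS

flags.DEFINE_bool("debug", False, "run in debug mode")
flags.DEFINE_integer("seed", 42, "random seed (default is an assumption)")
flags.DEFINE_string("output_dir", None,
                    "directory of a prior run (created if not specified)")

if __name__ == "__main__":
    app.run(main)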
Code example #4
def main(argv):
    """Main program logic."""
    del argv  # unused

    logging.log(logging.INFO, "Logging application {}".format(__file__))
    if FLAGS.debug:
        logging.set_verbosity(logging.DEBUG)
        logging.log(logging.DEBUG, "Running in debug mode")

    physical_devices = tf.config.experimental.list_physical_devices("GPU")
    if physical_devices:  # guard against IndexError on CPU-only machines
        tf.config.experimental.set_memory_growth(physical_devices[0], True)

    model_found = False
    # no prior run specified, train model
    if FLAGS.output_dir is None:
        if FLAGS.target != "train":
            raise ValueError(
                f"Target `{FLAGS.target}` requires --output_dir to be specified.")

        output_dir = os.path.join(
            "logs", __file__, datetime.datetime.now().strftime("%Y%m%d-%H%M%S"))
        file_io.check_create_dir(output_dir)

        # copy the defaults so flag overrides don't mutate the module constant
        model_options = dict(DEFAULT_OPTIONS)

        if FLAGS.base_dir is None:
            raise ValueError(
                f"Target `{FLAGS.target}` requires --base_dir to be specified.")

        # add flag options to model options
        model_options["base_dir"] = FLAGS.base_dir
        model_options["base_model"] = (
            "best_model" if FLAGS.load_best else "model")

    # prior run specified, resume training or test model
    else:
        output_dir = FLAGS.output_dir

        # load current or best model
        model_file = "best_model.h5" if FLAGS.load_best else "model.h5"
        model_step_file = "best_model.step" if FLAGS.load_best else "model.step"

        if FLAGS.base_dir is not None:
            raise ValueError(
                f"Flag --base_dir should not be set for target `{FLAGS.target}`.")

        if os.path.exists(os.path.join(output_dir, model_file)):
            model_found = True
            model_options = file_io.read_json(
                os.path.join(output_dir, "model_options.json"))

        elif FLAGS.target != "train":
            raise ValueError(
                f"Target `{FLAGS.target}` specified but `{model_file}` not "
                f"found in {output_dir}.")

        else:
            # training a fresh model in an existing directory: fall back to
            # the defaults so model_options is not left unbound below
            model_options = dict(DEFAULT_OPTIONS)

    # gather flag options
    flag_options = {}
    for flag in FLAGS.get_key_flags_for_module(__file__):
        flag_options[flag.name] = flag.value

    # logging
    logging_utils.absl_file_logger(output_dir, f"log.{FLAGS.target}")

    logging.log(logging.INFO, f"Model directory: {output_dir}")
    logging.log(logging.INFO, f"Model options: {model_options}")
    logging.log(logging.INFO, f"Flag options: {flag_options}")

    tf_writer = None
    if FLAGS.tensorboard and FLAGS.target == "train":
        tf_writer = tf.summary.create_file_writer(output_dir)

    # set seeds for reproducibility
    np.random.seed(model_options["seed"])
    tf.random.set_seed(model_options["seed"])

    # run target
    if FLAGS.target == "train":
        if model_found and FLAGS.resume:
            train(model_options, output_dir, model_file, model_step_file,
                  tf_writer=tf_writer)
        else:
            train(model_options, output_dir, tf_writer=tf_writer)
    elif FLAGS.target == "validate":  # TODO
        raise NotImplementedError
    elif FLAGS.target == "embed":
        embed(model_options, output_dir, model_file, model_step_file)
    else:
        test(model_options, output_dir, model_file, model_step_file)
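
DEFAULT_OPTIONS is not shown in these excerpts. The keys below are inferred from the model_options[...] accesses in code examples #5 and #6; every value is a placeholder, not the project's actual default:

# hypothetical reconstruction of DEFAULT_OPTIONS; keys are taken from
# model_options accesses in the surrounding code, values are placeholders
DEFAULT_OPTIONS = {
    "seed": 42,
    "batch_size": 256,
    "balanced": True,
    "p": 32,                 # labels per balanced batch
    "k": 8,                  # examples per label
    "num_batches": 300,      # batches per epoch when balanced
    "num_augment": None,     # dataset repeats for the standard pipeline
    "epochs": 100,
    "learning_rate": 1e-3,
    "decay_rate": 0.96,
    "decay_steps": 1000,
    "gradient_clip_norm": None,
    "one_shot_validation": True,
}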
Code example #5
def embed(model_options, output_dir, model_file, model_step_file):
    """Load siamese spoken word similarity model and extract embeddings."""

    # load embeddings from dense layer of base model
    embed_dir = os.path.join(model_options["base_dir"], "embed", "dense")

    # load model
    speech_network, _ = model_utils.load_model(
        model_file=os.path.join(output_dir, model_file),
        model_step_file=os.path.join(output_dir, model_step_file),
        loss=get_training_objective(model_options))

    # get model embedding model and data preprocessing
    embedding_model = create_embedding_model(speech_network)
    data_preprocess_func = get_data_preprocess_func()

    # load Flickr Audio dataset and compute embeddings
    one_shot_exp = flickr_speech.FlickrSpeech(
        features="mfcc", keywords_split="one_shot_evaluation",
        embed_dir=embed_dir)

    background_train_exp = flickr_speech.FlickrSpeech(
        features="mfcc", keywords_split="background_train", embed_dir=embed_dir)

    background_dev_exp = flickr_speech.FlickrSpeech(
        features="mfcc", keywords_split="background_dev", embed_dir=embed_dir)

    subset_exp = {
        "one_shot_evaluation": one_shot_exp,
        "background_train": background_train_exp,
        "background_dev": background_dev_exp,
    }

    for subset, exp in subset_exp.items():
        # use a distinct name so the base embed_dir above is not shadowed
        output_embed_dir = os.path.join(
            output_dir, "embed", "dense", "flickr_audio", subset)
        file_io.check_create_dir(output_embed_dir)

        unique_paths = np.unique(exp.embed_paths)

        # batch base embeddings for faster embedding inference
        path_ds = tf.data.Dataset.from_tensor_slices(unique_paths)
        path_ds = path_ds.batch(model_options["batch_size"])
        path_ds = path_ds.prefetch(tf.data.experimental.AUTOTUNE)

        num_samples = int(
            np.ceil(len(unique_paths) / model_options["batch_size"]))

        start_time = time.time()
        paths, embeddings = [], []
        for path_batch in tqdm(path_ds, total=num_samples):
            path_embeddings = embedding_model.predict(
                data_preprocess_func(path_batch))

            paths.extend(path_batch.numpy())
            embeddings.extend(path_embeddings.numpy())
        end_time = time.time()

        logging.log(
            logging.INFO,
            f"Computed embeddings for Flickr Audio {subset} in "
            f"{end_time - start_time:.4f} seconds")

        # serialize and write embeddings to TFRecord files
        for path, embedding in zip(paths, embeddings):
            example_proto = dataset.embedding_to_example_protobuf(embedding)

            path = path.decode("utf-8")
            path = path.split(".tfrecord")[0]  # remove any ".tfrecord" ext
            path = os.path.join(
                output_embed_dir, f"{os.path.split(path)[1]}.tfrecord")

            with tf.io.TFRecordWriter(path, options="ZLIB") as writer:
                writer.write(example_proto.SerializeToString())

        logging.log(logging.INFO, f"Embeddings stored at: {embed_dir}")
Code example #6
def train(model_options, output_dir, model_file=None, model_step_file=None,
          tf_writer=None):
    """Create and train spoken word similarity model for one-shot learning."""

    # load embeddings from dense layer of base model
    embed_dir = os.path.join(model_options["base_dir"], "embed", "dense")

    # load training data (embed dir determines if mfcc/fbank)
    train_exp, dev_exp = dataset.create_flickr_audio_train_data(
        "mfcc", embed_dir=embed_dir, speaker_mode=FLAGS.speaker_mode)

    # build integer labels from keywords; the train mapping is used for both
    # splits so that train and dev labels share a single label space
    train_labels = []
    for keyword in train_exp.keywords_set[3]:
        label = train_exp.keyword_labels[keyword]
        train_labels.append(label)
    train_labels = np.asarray(train_labels)

    dev_labels = []
    for keyword in dev_exp.keywords_set[3]:
        label = train_exp.keyword_labels[keyword]
        dev_labels.append(label)
    dev_labels = np.asarray(dev_labels)

    train_paths = train_exp.embed_paths
    dev_paths = dev_exp.embed_paths

    # define preprocessing for base model embeddings
    preprocess_data_func = lambda example: dataset.parse_embedding_protobuf(
        example)["embed"]

    # create balanced batch training dataset pipeline
    if model_options["balanced"]:
        assert model_options["p"] is not None
        assert model_options["k"] is not None

        shuffle_train = False
        prefetch_train = False
        num_repeat = model_options["num_batches"]
        model_options["batch_size"] = model_options["p"] * model_options["k"]

        # get unique path train indices per unique label
        train_labels_series = pd.Series(train_labels)
        train_label_idx = {
            label: idx.values[
                np.unique(train_paths[idx.values], return_index=True)[1]]
            for label, idx in train_labels_series.groupby(
                train_labels_series).groups.items()}

        # cache paths to speed things up a little ...
        file_io.check_create_dir(os.path.join(output_dir, "cache"))

        # create a dataset for each unique keyword label (shuffled and cached)
        train_label_datasets = [
            tf.data.Dataset.zip((
                tf.data.Dataset.from_tensor_slices(train_paths[idx]),
                tf.data.Dataset.from_tensor_slices(train_labels[idx]))).cache(
                    os.path.join(
                        output_dir, "cache", str(label))).shuffle(20)  # len(idx)
            for label, idx in train_label_idx.items()]

        # create a dataset that samples balanced batches from the label datasets
        background_train_ds = dataset.create_balanced_batch_dataset(
            model_options["p"], model_options["k"], train_label_datasets)

    # create standard training dataset pipeline (shuffle and load training set)
    else:
        shuffle_train = True
        prefetch_train = True
        num_repeat = model_options["num_augment"]

        background_train_ds = tf.data.Dataset.zip((
            tf.data.Dataset.from_tensor_slices(train_paths),
            tf.data.Dataset.from_tensor_slices(train_labels)))

    # load embedding TFRecords (faster here than before balanced sampling)
    # batch to read files in parallel
    background_train_ds = background_train_ds.batch(
        model_options["batch_size"])

    background_train_ds = background_train_ds.flat_map(
        lambda paths, labels: tf.data.Dataset.zip((
            tf.data.TFRecordDataset(
                paths, compression_type="ZLIB", num_parallel_reads=8),
            tf.data.Dataset.from_tensor_slices(labels))))

    # map data preprocessing function across training data
    background_train_ds = background_train_ds.map(
        lambda data, label: (preprocess_data_func(data), label),
        num_parallel_calls=tf.data.experimental.AUTOTUNE)

    # repeat augmentation, shuffle and batch train data
    if num_repeat is not None:
        background_train_ds = background_train_ds.repeat(num_repeat)

    if shuffle_train:
        background_train_ds = background_train_ds.shuffle(1000)

    background_train_ds = background_train_ds.batch(
        model_options["batch_size"])

    if prefetch_train:
        background_train_ds = background_train_ds.prefetch(
            tf.data.experimental.AUTOTUNE)

    # create dev set pipeline for siamese validation
    background_dev_ds = tf.data.Dataset.zip((
        tf.data.TFRecordDataset(
            dev_paths, compression_type="ZLIB",
            num_parallel_reads=8).map(preprocess_data_func),
        tf.data.Dataset.from_tensor_slices(dev_labels)))

    background_dev_ds = background_dev_ds.batch(
        batch_size=model_options["batch_size"])

    # get training objective
    triplet_loss = get_training_objective(model_options)

    # get model input shape
    for x_batch, _ in background_train_ds.take(1):
        model_options["base_embed_size"] = int(
            tf.shape(x_batch)[1].numpy())

        model_options["input_shape"] = [model_options["base_embed_size"]]

    # load or create model
    if model_file is not None:
        speech_network, train_state = model_utils.load_model(
            model_file=os.path.join(output_dir, model_file),
            model_step_file=os.path.join(output_dir, model_step_file),
            loss=triplet_loss)

        # get previous tracking variables
        initial_model = False
        global_step, model_epochs, _, best_val_score = train_state
    else:
        speech_network = create_speech_network(model_options)

        # create tracking variables
        initial_model = True
        global_step = 0
        model_epochs = 0

        if model_options["one_shot_validation"]:
            best_val_score = -np.inf
        else:
            best_val_score = np.inf

    # load or create Adam optimizer with decayed learning rate
    lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
        model_options["learning_rate"], decay_rate=model_options["decay_rate"],
        decay_steps=model_options["decay_steps"], staircase=True)

    if model_file is not None:
        logging.log(logging.INFO, "Restoring optimizer state")
        optimizer = speech_network.optimizer
    else:
        optimizer = tf.keras.optimizers.Adam(learning_rate=lr_schedule)

    # compile model to store optimizer with model when saving
    speech_network.compile(optimizer=optimizer, loss=triplet_loss)

    # create few-shot model from speech network for background training
    speech_few_shot_model = base.BaseModel(speech_network, triplet_loss)

    # test model on one-shot validation task prior to training
    if model_options["one_shot_validation"]:

        one_shot_dev_exp = flickr_speech.FlickrSpeech(
            features="mfcc",
            keywords_split="background_dev",
            preprocess_func=get_data_preprocess_func(),
            embed_dir=embed_dir,
            speaker_mode=FLAGS.speaker_mode)

        embedding_model_func = create_embedding_model

        classification = False
        if FLAGS.classification:
            assert FLAGS.fine_tune_steps is not None
            classification = True

        # create few-shot model from speech network for one-shot validation
        if FLAGS.fine_tune_steps is not None:
            test_few_shot_model = create_fine_tune_model(
                model_options, speech_few_shot_model.model, num_classes=FLAGS.L)
        else:
            test_few_shot_model = base.BaseModel(
                speech_few_shot_model.model, None, mc_dropout=FLAGS.mc_dropout)

        val_task_accuracy, _, conf_interval_95 = experiment.test_l_way_k_shot(
            one_shot_dev_exp, FLAGS.K, FLAGS.L, n=FLAGS.N, num_episodes=FLAGS.episodes,
            k_neighbours=FLAGS.k_neighbours, metric=FLAGS.metric,
            classification=classification, model=test_few_shot_model,
            embedding_model_func=embedding_model_func,
            fine_tune_steps=FLAGS.fine_tune_steps,
            fine_tune_lr=FLAGS.fine_tune_lr)

        logging.log(
            logging.INFO,
            f"Base model: {FLAGS.L}-way {FLAGS.K}-shot accuracy after "
            f"{FLAGS.episodes} episodes: {val_task_accuracy:.3%} +- "
            f"{conf_interval_95*100:.4f}")

    # create training metrics
    loss_metric = tf.keras.metrics.Mean()
    best_model = False

    # store model options on first run
    if initial_model:
        file_io.write_json(
            os.path.join(output_dir, "model_options.json"), model_options)

    # train model
    for epoch in range(model_epochs, model_options["epochs"]):
        logging.log(logging.INFO, f"Epoch {epoch:03d}")

        loss_metric.reset_states()

        # train on epoch of training data
        step_pbar = tqdm(background_train_ds,
                         bar_format="{desc} [{elapsed},{rate_fmt}{postfix}]")
        for step, (x_batch, y_batch) in enumerate(step_pbar):

            loss_value, y_predict = speech_few_shot_model.train_step(
                x_batch, y_batch, optimizer,
                clip_norm=model_options["gradient_clip_norm"])

            loss_metric.update_state(loss_value)

            step_loss = tf.reduce_mean(loss_value)
            train_loss = loss_metric.result().numpy()

            step_pbar.set_description_str(
                f"\tStep {step:03d}: "
                f"Step loss: {step_loss:.6f}, "
                f"Loss: {train_loss:.6f}")

            if tf_writer is not None:
                with tf_writer.as_default():
                    tf.summary.scalar(
                        "Train step loss", step_loss, step=global_step)
            global_step += 1

        # validate siamese model
        loss_metric.reset_states()

        for x_batch, y_batch in background_dev_ds:
            y_predict = speech_few_shot_model.predict(x_batch, training=False)
            loss_value = speech_few_shot_model.loss(y_batch, y_predict)

            loss_metric.update_state(loss_value)

        dev_loss = loss_metric.result().numpy()

        # validate model on one-shot dev task if specified
        if model_options["one_shot_validation"]:

            if FLAGS.fine_tune_steps is not None:
                test_few_shot_model = create_fine_tune_model(
                    model_options, speech_few_shot_model.model,
                    num_classes=FLAGS.L)
            else:
                test_few_shot_model = base.BaseModel(
                    speech_few_shot_model.model, None,
                    mc_dropout=FLAGS.mc_dropout)

            val_task_accuracy, _, conf_interval_95 = experiment.test_l_way_k_shot(
                one_shot_dev_exp, FLAGS.K, FLAGS.L, n=FLAGS.N,
                num_episodes=FLAGS.episodes, k_neighbours=FLAGS.k_neighbours,
                metric=FLAGS.metric, classification=classification,
                model=test_few_shot_model,
                embedding_model_func=embedding_model_func,
                fine_tune_steps=FLAGS.fine_tune_steps,
                fine_tune_lr=FLAGS.fine_tune_lr)

            val_score = val_task_accuracy
            val_metric = f"{FLAGS.L}-way {FLAGS.K}-shot accuracy"

            if val_score >= best_val_score:
                best_val_score = val_score
                best_model = True

        # otherwise, validate on siamese task
        else:
            val_score = dev_loss
            val_metric = "loss"

            if val_score <= best_val_score:
                best_val_score = val_score
                best_model = True

        # log results
        logging.log(logging.INFO, f"Train: Loss: {train_loss:.6f}")

        logging.log(
            logging.INFO,
            f"Validation: Loss: {dev_loss:.6f} {'*' if best_model else ''}")

        if model_options["one_shot_validation"]:
            logging.log(
                logging.INFO,
                f"Validation: {FLAGS.L}-way {FLAGS.K}-shot accuracy after "
                f"{FLAGS.episodes} episodes: {val_task_accuracy:.3%} +- "
                f"{conf_interval_95*100:.4f} {'*' if best_model else ''}")

        if tf_writer is not None:
            with tf_writer.as_default():
                tf.summary.scalar(
                    "Train step loss", train_loss, step=global_step)
                tf.summary.scalar(
                    "Validation loss", dev_loss, step=global_step)
                if model_options["one_shot_validation"]:
                    tf.summary.scalar(
                        f"Validation {FLAGS.L}-way {FLAGS.K}-shot accuracy",
                        val_task_accuracy, step=global_step)

        # store model and results
        model_utils.save_model(
            speech_few_shot_model.model, output_dir, epoch + 1, global_step,
            val_metric, val_score, best_val_score, name="model")

        if best_model:
            best_model = False
            model_utils.save_model(
                speech_few_shot_model.model, output_dir, epoch + 1, global_step,
                val_metric, val_score, best_val_score, name="best_model")
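
create_balanced_batch_dataset is not shown in these excerpts. One plausible implementation of P-K balanced batching with tf.data (a sketch under assumptions, not the project's actual code): each group of p * k consecutive elements covers p distinct labels with k examples per label, so batching with batch_size = p * k, as the training code above does, yields balanced batches.

import numpy as np
import tensorflow as tf

def create_balanced_batch_dataset(p, k, label_datasets):
    """Sketch: interleave the per-label datasets so that every p * k
    consecutive elements hold k examples from each of p random labels."""
    num_labels = len(label_datasets)
    # repeat each per-label dataset so k examples can always be drawn
    label_datasets = [ds.repeat() for ds in label_datasets]

    def label_choices():
        while True:  # one iteration per balanced batch
            labels = np.random.choice(num_labels, size=p, replace=False)
            for label in labels:
                for _ in range(k):
                    yield label

    choice_ds = tf.data.Dataset.from_generator(
        label_choices, output_types=tf.int64)
    return tf.data.experimental.choose_from_datasets(
        label_datasets, choice_ds)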
Code example #7
File: run.py Project: rpeloff/moonshot
def embed(model_options, output_dir, model_file, model_step_file):
    """Load siamese image similarity model and extract embeddings."""

    # get base embeddings directory if specified, otherwise embed images
    embed_dir = None
    if model_options["use_embeddings"]:
        # load embeddings from dense layer of base model
        embed_dir = os.path.join(model_options["base_dir"], "embed", "dense")

    # load model
    vision_network, _ = model_utils.load_model(
        model_file=os.path.join(output_dir, model_file),
        model_step_file=os.path.join(output_dir, model_step_file),
        loss=get_training_objective(model_options))

    # get model embedding model and data preprocessing
    embedding_model = create_embedding_model(vision_network)
    data_preprocess_func = get_data_preprocess_func(model_options)

    # load image datasets and compute embeddings
    for data in ["flickr8k", "flickr30k", "mscoco"]:

        train_image_dir_dict = {}
        dev_image_dir_dict = {}

        if data == "flickr8k":
            train_image_dir_dict["flickr8k_image_dir"] = os.path.join(
                "data", "external", "flickr8k_images")
            dev_image_dir_dict = train_image_dir_dict

        if data == "flickr30k":
            train_image_dir_dict["flickr30k_image_dir"] = os.path.join(
                "data", "external", "flickr30k_images")
            dev_image_dir_dict = train_image_dir_dict

        if data == "mscoco":
            train_image_dir_dict["mscoco_image_dir"] = os.path.join(
                "data", "external", "mscoco", "train2017")
            dev_image_dir_dict["mscoco_image_dir"] = os.path.join(
                "data", "external", "mscoco", "val2017")

        one_shot_exp = flickr_vision.FlickrVision(
            keywords_split="one_shot_evaluation",
            **train_image_dir_dict,
            embed_dir=embed_dir)

        background_train_exp = flickr_vision.FlickrVision(
            keywords_split="background_train",
            **train_image_dir_dict,
            embed_dir=embed_dir)

        background_dev_exp = flickr_vision.FlickrVision(
            keywords_split="background_dev",
            **dev_image_dir_dict,
            embed_dir=embed_dir)

        subset_exp = {
            "one_shot_evaluation": one_shot_exp,
            "background_train": background_train_exp,
            "background_dev": background_dev_exp,
        }

        for subset, exp in subset_exp.items():
            output_embed_dir = os.path.join(output_dir, "embed", "dense", data,
                                            subset)
            file_io.check_create_dir(output_embed_dir)

            if model_options["use_embeddings"]:
                subset_paths = exp.embed_paths
            else:
                subset_paths = exp.image_paths

            unique_paths = np.unique(subset_paths)

            # batch images/base embeddings for faster embedding inference
            path_ds = tf.data.Dataset.from_tensor_slices(unique_paths)
            path_ds = path_ds.batch(model_options["batch_size"])
            path_ds = path_ds.prefetch(tf.data.experimental.AUTOTUNE)

            num_samples = int(
                np.ceil(len(unique_paths) / model_options["batch_size"]))

            start_time = time.time()
            paths, embeddings = [], []
            for path_batch in tqdm(path_ds, total=num_samples):
                path_embeddings = embedding_model.predict(
                    data_preprocess_func(path_batch))

                paths.extend(path_batch.numpy())
                embeddings.extend(path_embeddings.numpy())
            end_time = time.time()

            logging.log(
                logging.INFO, f"Computed embeddings for {data} {subset} in "
                f"{end_time - start_time:.4f} seconds")

            # serialize and write embeddings to TFRecord files
            for path, embedding in zip(paths, embeddings):
                example_proto = dataset.embedding_to_example_protobuf(
                    embedding)

                path = path.decode("utf-8")
                path = path.split(".tfrecord")[0]  # remove any ".tfrecord" ext
                path = os.path.join(output_embed_dir,
                                    f"{os.path.split(path)[1]}.tfrecord")

                with tf.io.TFRecordWriter(path, options="ZLIB") as writer:
                    writer.write(example_proto.SerializeToString())

            logging.log(logging.INFO,
                        f"Embeddings stored at: {output_embed_dir}")