Example #1
def create_flickr_multimodal_train_data(features,
                                        speech_embed_dir=None,
                                        image_embed_dir=None,
                                        speech_preprocess_func=None,
                                        image_preprocess_func=None,
                                        speaker_mode="baseline",
                                        unseen_match_set=False):
    """Load train and validation paired Flickr 8k and Flickr Audio data."""

    flickr_train_exp = flickr_multimodal.FlickrMultimodal(
        features=features,
        keywords_split="background_train",
        flickr8k_image_dir=os.path.join("data", "external", "flickr8k_images"),
        speech_embed_dir=speech_embed_dir,
        image_embed_dir=image_embed_dir,
        speech_preprocess_func=speech_preprocess_func,
        image_preprocess_func=image_preprocess_func,
        speaker_mode=speaker_mode,
        unseen_match_set=unseen_match_set)

    flickr_dev_exp = flickr_multimodal.FlickrMultimodal(
        features=features,
        keywords_split="background_dev",
        flickr8k_image_dir=os.path.join("data", "external", "flickr8k_images"),
        speech_embed_dir=speech_embed_dir,
        image_embed_dir=image_embed_dir,
        speech_preprocess_func=speech_preprocess_func,
        image_preprocess_func=image_preprocess_func,
        speaker_mode=speaker_mode,
        unseen_match_set=unseen_match_set)

    return flickr_train_exp, flickr_dev_exp
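
As a usage note, Example #4 below calls this helper roughly as in the following sketch; the base-model directories here are placeholder assumptions, not paths taken from this example.

# hypothetical usage sketch; the base-model directories are placeholder assumptions
import os

speech_embed_dir = os.path.join("models", "audio_base", "embed", "dense")
image_embed_dir = os.path.join("models", "vision_base", "embed", "dense")

train_exp, dev_exp = create_flickr_multimodal_train_data(
    "mfcc",
    speech_embed_dir=speech_embed_dir,
    image_embed_dir=image_embed_dir,
    speaker_mode="baseline",
    unseen_match_set=False)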
Example #2
def test():
    """Test extracted image and speech model embeddings for one-shot learning."""

    # load embeddings from (linear) dense layer of base speech and vision models
    speech_embed_dir = os.path.join(FLAGS.audio_base_dir, "embed", "dense")
    image_embed_dir = os.path.join(FLAGS.vision_base_dir, "embed", "dense")

    # load Flickr Audio one-shot experiment
    one_shot_exp = flickr_multimodal.FlickrMultimodal(
        features="mfcc",
        keywords_split="one_shot_evaluation",
        flickr8k_image_dir=os.path.join("data", "external", "flickr8k_images"),
        speech_embed_dir=speech_embed_dir,
        image_embed_dir=image_embed_dir,
        speech_preprocess_func=data_preprocess_func,
        image_preprocess_func=data_preprocess_func,
        speaker_mode=FLAGS.speaker_mode,
        unseen_match_set=FLAGS.unseen_match_set)

    # test model on L-way K-shot task
    task_accuracy, _, conf_interval_95 = experiment.test_multimodal_l_way_k_shot(
        one_shot_exp,
        FLAGS.K,
        FLAGS.L,
        n=FLAGS.N,
        num_episodes=FLAGS.episodes,
        k_neighbours=FLAGS.k_neighbours,
        metric=FLAGS.metric)

    logging.log(
        logging.INFO,
        f"{FLAGS.L}-way {FLAGS.K}-shot accuracy after {FLAGS.episodes} "
        f"episodes: {task_accuracy:.3%} +- {conf_interval_95*100:.4f}")
Example #3
def test(model_options, output_dir, model_file, model_step_file):
    """Load and test siamese audio-visual similarity model for one-shot learning."""

    # load embeddings from (linear) dense layer of base speech and vision models
    speech_embed_dir = os.path.join(model_options["audio_base_dir"], "embed",
                                    "dense")

    image_embed_dir = os.path.join(model_options["vision_base_dir"], "embed",
                                   "dense")

    # load Flickr Audio one-shot experiment
    one_shot_exp = flickr_multimodal.FlickrMultimodal(
        features="mfcc",
        keywords_split="one_shot_evaluation",
        flickr8k_image_dir=os.path.join("data", "external", "flickr8k_images"),
        speech_embed_dir=speech_embed_dir,
        image_embed_dir=image_embed_dir,
        speech_preprocess_func=data_preprocess_func,
        image_preprocess_func=data_preprocess_func,
        speaker_mode=FLAGS.speaker_mode,
        unseen_match_set=FLAGS.unseen_match_set)

    # load joint audio-visual model
    join_network_model, _ = model_utils.load_model(
        model_file=os.path.join(output_dir, model_file),
        model_step_file=os.path.join(output_dir, model_step_file),
        loss=get_training_objective(model_options))

    speech_network = tf.keras.Model(inputs=join_network_model.inputs[0],
                                    outputs=join_network_model.outputs[0])

    vision_network = tf.keras.Model(inputs=join_network_model.inputs[1],
                                    outputs=join_network_model.outputs[1])

    # create few-shot model from speech and vision networks for one-shot evaluation
    test_few_shot_model = create_fine_tune_model(model_options, speech_network,
                                                 vision_network)

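    # use plain SGD for episodic fine-tuning unless the Adam flag is set, in which
    # case the optimizer is left as None (presumably chosen downstream)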
    fine_tune_optimizer = None
    if not FLAGS.adam:
        fine_tune_optimizer = tf.keras.optimizers.SGD(FLAGS.fine_tune_lr)

    logging.log(logging.INFO, "Created few-shot model from speech network")
    test_few_shot_model.speech_model.model.summary()

    logging.log(logging.INFO, "Created few-shot model from vision network")
    test_few_shot_model.vision_model.model.summary()

    # test model on L-way K-shot multimodal task
    task_accuracy, _, conf_interval_95 = experiment.test_multimodal_l_way_k_shot(
        one_shot_exp,
        FLAGS.K,
        FLAGS.L,
        n=FLAGS.N,
        num_episodes=FLAGS.episodes,
        k_neighbours=FLAGS.k_neighbours,
        metric=FLAGS.metric,
        direct_match=FLAGS.direct_match,
        multimodal_model=test_few_shot_model,
        multimodal_embedding_func=None,  #create_embedding_model,
        optimizer=fine_tune_optimizer,
        fine_tune_steps=FLAGS.fine_tune_steps,
        fine_tune_lr=FLAGS.fine_tune_lr)

    logging.log(
        logging.INFO,
        f"{FLAGS.L}-way {FLAGS.K}-shot accuracy after {FLAGS.episodes} "
        f"episodes: {task_accuracy:.3%} +- {conf_interval_95*100:.4f}")
Example #4
def train(model_options, output_dir, model_file=None, model_step_file=None,
          tf_writer=None):
    """Create and train audio-visual similarity model for one-shot learning."""

    # load embeddings from (linear) dense layer of base speech and vision models
    speech_embed_dir = os.path.join(
        model_options["audio_base_dir"], "embed", "dense")

    image_embed_dir = os.path.join(
        model_options["vision_base_dir"], "embed", "dense")

    # load training data (embed dir determines mfcc/fbank speech features)
    train_exp, dev_exp = dataset.create_flickr_multimodal_train_data(
        "mfcc", speech_embed_dir=speech_embed_dir,
        image_embed_dir=image_embed_dir, speaker_mode=FLAGS.speaker_mode,
        unseen_match_set=FLAGS.unseen_match_set)

    train_speech_paths = train_exp.speech_experiment.data
    train_image_paths = train_exp.vision_experiment.data

    dev_speech_paths = dev_exp.speech_experiment.data
    dev_image_paths = dev_exp.vision_experiment.data

    # define preprocessing for base model embeddings
    preprocess_data_func = lambda example: dataset.parse_embedding_protobuf(
        example)["embed"]

    # create standard training dataset pipeline
    background_train_ds = tf.data.Dataset.zip((
        tf.data.TFRecordDataset(
            train_speech_paths, compression_type="ZLIB"),
        tf.data.TFRecordDataset(
            train_image_paths, compression_type="ZLIB")))

    # map data preprocessing function across training data
    background_train_ds = background_train_ds.map(
        lambda speech_path, image_path: (
            preprocess_data_func(speech_path), preprocess_data_func(image_path)),
        num_parallel_calls=8)

    # repeat indefinitely, shuffle, batch, cap at num_batches per epoch, and prefetch
    background_train_ds = background_train_ds.repeat(-1)

    background_train_ds = background_train_ds.shuffle(1000)

    background_train_ds = background_train_ds.batch(
        model_options["batch_size"])

    background_train_ds = background_train_ds.take(
        model_options["num_batches"])

    background_train_ds = background_train_ds.prefetch(
        tf.data.experimental.AUTOTUNE)

    # create dev set pipeline for validation
    background_dev_ds = tf.data.Dataset.zip((
        tf.data.TFRecordDataset(
            dev_speech_paths, compression_type="ZLIB"),
        tf.data.TFRecordDataset(
            dev_image_paths, compression_type="ZLIB")))

    background_dev_ds = background_dev_ds.map(
        lambda speech_path, image_path: (
            preprocess_data_func(speech_path), preprocess_data_func(image_path)),
        num_parallel_calls=8)

    background_dev_ds = background_dev_ds.batch(
        batch_size=model_options["batch_size"])

    # get training objective
    triplet_loss = get_training_objective(model_options)

    # get model input shape
    for speech_batch, image_batch in background_train_ds.take(1):
        model_options["audio_base_embed_size"] = int(
            tf.shape(speech_batch)[1].numpy())
        model_options["vision_base_embed_size"] = int(
            tf.shape(image_batch)[1].numpy())

        model_options["audio_input_shape"] = [
            model_options["audio_base_embed_size"]]
        model_options["vision_input_shape"] = [
            model_options["vision_base_embed_size"]]

    # load or create models
    if model_file is not None:
        join_network_model, train_state = model_utils.load_model(
            model_file=os.path.join(output_dir, model_file),
            model_step_file=os.path.join(output_dir, model_step_file),
            loss=get_training_objective(model_options))

        speech_network = tf.keras.Model(
            inputs=join_network_model.inputs[0],
            outputs=join_network_model.outputs[0])

        vision_network = tf.keras.Model(
            inputs=join_network_model.inputs[1],
            outputs=join_network_model.outputs[1])

        # get previous tracking variables
        initial_model = False
        global_step, model_epochs, _, best_val_score = train_state
    else:
        speech_network = create_speech_network(model_options)
        vision_network = create_vision_network(model_options)

        # create tracking variables
        initial_model = True
        global_step = 0
        model_epochs = 0

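        # one-shot accuracy is maximised and validation loss is minimised, so the
        # initial best score starts at the corresponding worst value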
        if model_options["one_shot_validation"]:
            best_val_score = -np.inf
        else:
            best_val_score = np.inf

    # load or create Adam optimizer with decayed learning rate
    lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
        model_options["learning_rate"], decay_rate=model_options["decay_rate"],
        decay_steps=model_options["decay_steps"], staircase=True)

    if model_file is not None:
        logging.log(logging.INFO, "Restoring optimizer state")
        optimizer = join_network_model.optimizer
    else:
        optimizer = tf.keras.optimizers.Adam(learning_rate=lr_schedule)

    # compile models to store optimizer with model when saving
    join_network_model = tf.keras.Model(
        inputs=[speech_network.input, vision_network.input],
        outputs=[speech_network.output, vision_network.output])

    # speech_network.compile(optimizer=optimizer, loss=triplet_loss)
    join_network_model.compile(optimizer=optimizer, loss=triplet_loss)

    # create weakly supervised multimodal model from speech and vision networks
    # for background training
    multimodal_model = base.WeaklySupervisedModel(
        speech_network, vision_network, triplet_loss)

    # test model on one-shot validation task prior to training
    if model_options["one_shot_validation"]:

        one_shot_dev_exp = flickr_multimodal.FlickrMultimodal(
            features="mfcc", keywords_split="background_dev",
            flickr8k_image_dir=os.path.join("data", "external", "flickr8k_images"),
            speech_embed_dir=speech_embed_dir, image_embed_dir=image_embed_dir,
            speech_preprocess_func=data_preprocess_func,
            image_preprocess_func=data_preprocess_func,
            speaker_mode=FLAGS.speaker_mode,
            unseen_match_set=FLAGS.unseen_match_set)

        # create few-shot model from speech and vision networks for one-shot validation
        if FLAGS.fine_tune_steps is not None:
            test_few_shot_model = create_fine_tune_model(
                model_options, speech_network, vision_network)
        else:
            test_few_shot_model = base.WeaklySupervisedModel(
                speech_network, vision_network, None,
                mc_dropout=FLAGS.mc_dropout)

        val_task_accuracy, _, conf_interval_95 = experiment.test_multimodal_l_way_k_shot(
            one_shot_dev_exp, FLAGS.K, FLAGS.L, n=FLAGS.N,
            num_episodes=FLAGS.episodes, k_neighbours=FLAGS.k_neighbours,
            metric=FLAGS.metric, direct_match=FLAGS.direct_match,
            multimodal_model=test_few_shot_model,
            multimodal_embedding_func=None, #create_embedding_model,
            fine_tune_steps=FLAGS.fine_tune_steps,
            fine_tune_lr=FLAGS.fine_tune_lr)

        logging.log(
            logging.INFO,
            f"Base model: {FLAGS.L}-way {FLAGS.K}-shot accuracy after "
            f"{FLAGS.episodes} episodes: {val_task_accuracy:.3%} +- "
            f"{conf_interval_95*100:.4f}")


    # create training metrics
    loss_metric = tf.keras.metrics.Mean()
    best_model = False

    # store model options on first run
    if initial_model:
        file_io.write_json(
            os.path.join(output_dir, "model_options.json"), model_options)

    # train model
    for epoch in range(model_epochs, model_options["epochs"]):
        logging.log(logging.INFO, f"Epoch {epoch:03d}")

        loss_metric.reset_states()

        # train on epoch of training data
        step_pbar = tqdm(background_train_ds,
                         bar_format="{desc} [{elapsed},{rate_fmt}{postfix}]")
        for step, (speech_batch, image_batch) in enumerate(step_pbar):

            loss_value, y_speech, y_image = multimodal_model.train_step(
                speech_batch, image_batch, optimizer,
                clip_norm=model_options["gradient_clip_norm"])

            loss_metric.update_state(loss_value)

            step_loss = tf.reduce_mean(loss_value)
            train_loss = loss_metric.result().numpy()

            step_pbar.set_description_str(
                f"\tStep {step:03d}: "
                f"Step loss: {step_loss:.6f}, "
                f"Loss: {train_loss:.6f}")

            if tf_writer is not None:
                with tf_writer.as_default():
                    tf.summary.scalar(
                        "Train step loss", step_loss, step=global_step)
            global_step += 1

        # validate multimodal model
        loss_metric.reset_states()

        for speech_batch, image_batch in background_dev_ds:
            y_speech = multimodal_model.speech_model.predict(
                speech_batch, training=False)
            y_image = multimodal_model.vision_model.predict(
                image_batch, training=False)
            loss_value = multimodal_model.loss(y_speech, y_image)

            loss_metric.update_state(loss_value)

        dev_loss = loss_metric.result().numpy()

        # validate model on one-shot dev task if specified
        if model_options["one_shot_validation"]:

            if FLAGS.fine_tune_steps is not None:
                test_few_shot_model = create_fine_tune_model(
                    model_options, speech_network, vision_network)
            else:
                test_few_shot_model = base.WeaklySupervisedModel(
                    speech_network, vision_network, None,
                    mc_dropout=FLAGS.mc_dropout)

            val_task_accuracy, _, conf_interval_95 = experiment.test_multimodal_l_way_k_shot(
                one_shot_dev_exp, FLAGS.K, FLAGS.L, n=FLAGS.N,
                num_episodes=FLAGS.episodes, k_neighbours=FLAGS.k_neighbours,
                metric=FLAGS.metric, direct_match=FLAGS.direct_match,
                multimodal_model=test_few_shot_model,
                multimodal_embedding_func=None, #create_embedding_model,
                fine_tune_steps=FLAGS.fine_tune_steps,
                fine_tune_lr=FLAGS.fine_tune_lr)

            val_score = val_task_accuracy
            val_metric = f"{FLAGS.L}-way {FLAGS.K}-shot accuracy"

            if val_score >= best_val_score:
                best_val_score = val_score
                best_model = True

        # otherwise, validate on siamese task
        else:
            val_score = dev_loss
            val_metric = "loss"

            if val_score <= best_val_score:
                best_val_score = val_score
                best_model = True

        # log results
        logging.log(logging.INFO, f"Train: Loss: {train_loss:.6f}")

        logging.log(
            logging.INFO,
            f"Validation: Loss: {dev_loss:.6f} {'*' if best_model else ''}")

        if model_options["one_shot_validation"]:
            logging.log(
                logging.INFO,
                f"Validation: {FLAGS.L}-way {FLAGS.K}-shot accuracy after "
                f"{FLAGS.episodes} episodes: {val_task_accuracy:.3%} +- "
                f"{conf_interval_95*100:.4f} {'*' if best_model else ''}")

        if tf_writer is not None:
            with tf_writer.as_default():
                tf.summary.scalar(
                    "Train loss", train_loss, step=global_step)
                tf.summary.scalar(
                    "Validation loss", dev_loss, step=global_step)
                if model_options["one_shot_validation"]:
                    tf.summary.scalar(
                        f"Validation {FLAGS.L}-way {FLAGS.K}-shot accuracy",
                        val_task_accuracy, step=global_step)

        # store model and results
        # model_utils.save_model(
        #     multimodal_model.model_a.model, output_dir, epoch + 1, global_step,
        #     val_metric, val_score, best_val_score, name="audio_model")
        # model_utils.save_model(
        #     multimodal_model.model_b.model, output_dir, epoch + 1, global_step,
        #     val_metric, val_score, best_val_score, name="vision_model")
        model_utils.save_model(
            join_network_model, output_dir, epoch + 1, global_step,
            val_metric, val_score, best_val_score, name="model")

        if best_model:
            best_model = False
            # model_utils.save_model(
            #     multimodal_model.model_a.model, output_dir, epoch + 1, global_step,
            #     val_metric, val_score, best_val_score, name="audio_best_model")
            # model_utils.save_model(
            #     multimodal_model.model_b.model, output_dir, epoch + 1, global_step,
            #     val_metric, val_score, best_val_score, name="vision_best_model")
            model_utils.save_model(
                join_network_model, output_dir, epoch + 1, global_step,
                val_metric, val_score, best_val_score, name="best_model")


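
Finally, a minimal sketch of driver code for this train function, assuming representative option values and a TensorBoard summary writer for tf_writer; none of the concrete values below come from this example.

# hypothetical driver; option values and directory layout are assumptions
import os

import tensorflow as tf

model_options = {
    "audio_base_dir": os.path.join("models", "audio_base"),
    "vision_base_dir": os.path.join("models", "vision_base"),
    "batch_size": 256,
    "num_batches": 150,
    "epochs": 100,
    "learning_rate": 1e-3,
    "decay_rate": 0.96,
    "decay_steps": 1000,
    "gradient_clip_norm": 5.0,
    "one_shot_validation": True,
    # plus the network- and loss-specific options expected by
    # create_speech_network, create_vision_network and get_training_objective
}

output_dir = os.path.join("models", "multimodal", "run_0")
os.makedirs(output_dir, exist_ok=True)

tf_writer = tf.summary.create_file_writer(os.path.join(output_dir, "tensorboard"))

train(model_options, output_dir, tf_writer=tf_writer)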