Exemple #1
0
def train(model_options, output_dir, model_file=None, model_step_file=None,
          tf_writer=None):
    """Create and train spoken word similarity model for one-shot learning.

    Builds a tf.data pipeline over the base model's embedding TFRecords,
    creates (or restores) the speech network, and trains it for the remaining
    epochs. After every epoch the model is validated on the dev set (and
    optionally on a one-shot task) and saved to `output_dir`.

    Args:
        model_options: dict of model/training options. Keys read here are
            "base_dir", "balanced", "p", "k", "num_batches", "num_augment",
            "batch_size", "one_shot_validation", "learning_rate",
            "decay_rate", "decay_steps", "gradient_clip_norm" and "epochs".
            The dict is mutated: "batch_size" is overwritten with p * k when
            "balanced" is set, and "base_embed_size" / "input_shape" are
            filled in from the first training batch.
        output_dir: directory used for the dataset cache, the stored
            "model_options.json" and the saved models. `model_file` and
            `model_step_file` are resolved relative to it.
        model_file: optional model filename (inside `output_dir`) to resume
            training from. A new network is created when None.
        model_step_file: train state filename (inside `output_dir`) that is
            loaded together with `model_file`.
        tf_writer: optional summary writer providing `as_default()`. No
            summaries are written when None.
    """

    # load embeddings from dense layer of base model
    embed_dir = os.path.join(model_options["base_dir"], "embed", "dense")

    # load training data (embed dir determines if mfcc/fbank)
    train_exp, dev_exp = dataset.create_flickr_audio_train_data(
        "mfcc", embed_dir=embed_dir, speaker_mode=FLAGS.speaker_mode)

    train_labels = np.asarray([
        train_exp.keyword_labels[keyword]
        for keyword in train_exp.keywords_set[3]])

    # NOTE(review): dev keywords are looked up in the *train* label map, as in
    # the original code; presumably train and dev share one keyword
    # vocabulary -- confirm against the dataset module.
    dev_labels = np.asarray([
        train_exp.keyword_labels[keyword]
        for keyword in dev_exp.keywords_set[3]])

    train_paths = train_exp.embed_paths
    dev_paths = dev_exp.embed_paths

    # define preprocessing for base model embeddings
    def preprocess_data_func(example):
        return dataset.parse_embedding_protobuf(example)["embed"]

    # create balanced batch training dataset pipeline
    if model_options["balanced"]:
        assert model_options["p"] is not None
        assert model_options["k"] is not None

        shuffle_train = False
        prefetch_train = False
        num_repeat = model_options["num_batches"]
        model_options["batch_size"] = model_options["p"] * model_options["k"]

        # get unique path train indices per unique label
        train_labels_series = pd.Series(train_labels)
        train_label_idx = {
            label: idx.values[
                np.unique(train_paths[idx.values], return_index=True)[1]]
            for label, idx in train_labels_series.groupby(
                train_labels_series).groups.items()}

        # cache paths to speed things up a little ...
        file_io.check_create_dir(os.path.join(output_dir, "cache"))

        # create a dataset for each unique keyword label (shuffled and cached)
        train_label_datasets = [
            tf.data.Dataset.zip((
                tf.data.Dataset.from_tensor_slices(train_paths[idx]),
                tf.data.Dataset.from_tensor_slices(train_labels[idx]))).cache(
                    os.path.join(
                        output_dir, "cache", str(label))).shuffle(20)  # len(idx)
            for label, idx in train_label_idx.items()]

        # create a dataset that samples balanced batches from the label datasets
        background_train_ds = dataset.create_balanced_batch_dataset(
            model_options["p"], model_options["k"], train_label_datasets)

    # create standard training dataset pipeline (shuffle and load training set)
    else:
        shuffle_train = True
        prefetch_train = True
        num_repeat = model_options["num_augment"]

        background_train_ds = tf.data.Dataset.zip((
            tf.data.Dataset.from_tensor_slices(train_paths),
            tf.data.Dataset.from_tensor_slices(train_labels)))

    # load embedding TFRecords (faster here than before balanced sampling)
    # batch to read files in parallel
    background_train_ds = background_train_ds.batch(
        model_options["batch_size"])

    background_train_ds = background_train_ds.flat_map(
        lambda paths, labels: tf.data.Dataset.zip((
            tf.data.TFRecordDataset(
                paths, compression_type="ZLIB", num_parallel_reads=8),
            tf.data.Dataset.from_tensor_slices(labels))))

    # map data preprocessing function across training data
    background_train_ds = background_train_ds.map(
        lambda data, label: (preprocess_data_func(data), label),
        num_parallel_calls=tf.data.experimental.AUTOTUNE)

    # repeat augmentation, shuffle and batch train data
    if num_repeat is not None:
        background_train_ds = background_train_ds.repeat(num_repeat)

    if shuffle_train:
        background_train_ds = background_train_ds.shuffle(1000)

    background_train_ds = background_train_ds.batch(
        model_options["batch_size"])

    if prefetch_train:
        background_train_ds = background_train_ds.prefetch(
            tf.data.experimental.AUTOTUNE)

    # create dev set pipeline for siamese validation
    background_dev_ds = tf.data.Dataset.zip((
        tf.data.TFRecordDataset(
            dev_paths, compression_type="ZLIB",
            num_parallel_reads=8).map(preprocess_data_func),
        tf.data.Dataset.from_tensor_slices(dev_labels)))

    background_dev_ds = background_dev_ds.batch(
        batch_size=model_options["batch_size"])

    # get training objective
    triplet_loss = get_training_objective(model_options)

    # get model input shape (read from the first training batch)
    for x_batch, _ in background_train_ds.take(1):
        model_options["base_embed_size"] = int(
            tf.shape(x_batch)[1].numpy())

        model_options["input_shape"] = [model_options["base_embed_size"]]

    # load or create model
    if model_file is not None:
        speech_network, train_state = model_utils.load_model(
            model_file=os.path.join(output_dir, model_file),
            model_step_file=os.path.join(output_dir, model_step_file),
            loss=triplet_loss)

        # get previous tracking variables
        initial_model = False
        global_step, model_epochs, _, best_val_score = train_state
    else:
        speech_network = create_speech_network(model_options)

        # create tracking variables
        initial_model = True
        global_step = 0
        model_epochs = 0

        # one-shot accuracy is maximised, siamese dev loss is minimised
        if model_options["one_shot_validation"]:
            best_val_score = -np.inf
        else:
            best_val_score = np.inf

    # load or create Adam optimizer with decayed learning rate
    lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
        model_options["learning_rate"], decay_rate=model_options["decay_rate"],
        decay_steps=model_options["decay_steps"], staircase=True)

    if model_file is not None:
        logging.log(logging.INFO, "Restoring optimizer state")
        optimizer = speech_network.optimizer
    else:
        optimizer = tf.keras.optimizers.Adam(learning_rate=lr_schedule)

    # compile model to store optimizer with model when saving
    speech_network.compile(optimizer=optimizer, loss=triplet_loss)

    # create few-shot model from speech network for background training
    speech_few_shot_model = base.BaseModel(speech_network, triplet_loss)

    # test model on one-shot validation task prior to training
    if model_options["one_shot_validation"]:

        one_shot_dev_exp = flickr_speech.FlickrSpeech(
            features="mfcc",
            keywords_split="background_dev",
            preprocess_func=get_data_preprocess_func(),
            embed_dir=embed_dir,
            speaker_mode=FLAGS.speaker_mode)

        embedding_model_func = create_embedding_model

        classification = False
        if FLAGS.classification:
            assert FLAGS.fine_tune_steps is not None
            classification = True

        def run_one_shot_validation():
            """Run the L-way K-shot dev task on the current speech network.

            Returns the task accuracy and its 95% confidence interval as
            reported by `experiment.test_l_way_k_shot`.
            """
            # create few-shot model from speech network for one-shot validation
            if FLAGS.fine_tune_steps is not None:
                few_shot_model = create_fine_tune_model(
                    model_options, speech_few_shot_model.model,
                    num_classes=FLAGS.L)
            else:
                few_shot_model = base.BaseModel(
                    speech_few_shot_model.model, None,
                    mc_dropout=FLAGS.mc_dropout)

            task_accuracy, _, conf_interval = experiment.test_l_way_k_shot(
                one_shot_dev_exp, FLAGS.K, FLAGS.L, n=FLAGS.N,
                num_episodes=FLAGS.episodes, k_neighbours=FLAGS.k_neighbours,
                metric=FLAGS.metric, classification=classification,
                model=few_shot_model,
                embedding_model_func=embedding_model_func,
                fine_tune_steps=FLAGS.fine_tune_steps,
                fine_tune_lr=FLAGS.fine_tune_lr)

            return task_accuracy, conf_interval

        val_task_accuracy, conf_interval_95 = run_one_shot_validation()

        logging.log(
            logging.INFO,
            f"Base model: {FLAGS.L}-way {FLAGS.K}-shot accuracy after "
            f"{FLAGS.episodes} episodes: {val_task_accuracy:.3%} +- "
            f"{conf_interval_95*100:.4f}")

    # create training metrics
    loss_metric = tf.keras.metrics.Mean()
    best_model = False

    # store model options on first run
    if initial_model:
        file_io.write_json(
            os.path.join(output_dir, "model_options.json"), model_options)

    # train model
    for epoch in range(model_epochs, model_options["epochs"]):
        logging.log(logging.INFO, f"Epoch {epoch:03d}")

        loss_metric.reset_states()

        # train on epoch of training data
        step_pbar = tqdm(background_train_ds,
                         bar_format="{desc} [{elapsed},{rate_fmt}{postfix}]")
        for step, (x_batch, y_batch) in enumerate(step_pbar):

            loss_value, y_predict = speech_few_shot_model.train_step(
                x_batch, y_batch, optimizer,
                clip_norm=model_options["gradient_clip_norm"])

            loss_metric.update_state(loss_value)

            step_loss = tf.reduce_mean(loss_value)
            train_loss = loss_metric.result().numpy()

            step_pbar.set_description_str(
                f"\tStep {step:03d}: "
                f"Step loss: {step_loss:.6f}, "
                f"Loss: {train_loss:.6f}")

            if tf_writer is not None:
                with tf_writer.as_default():
                    tf.summary.scalar(
                        "Train step loss", step_loss, step=global_step)
            global_step += 1

        # validate siamese model
        loss_metric.reset_states()

        for x_batch, y_batch in background_dev_ds:
            y_predict = speech_few_shot_model.predict(x_batch, training=False)
            loss_value = speech_few_shot_model.loss(y_batch, y_predict)

            loss_metric.update_state(loss_value)

        dev_loss = loss_metric.result().numpy()

        # validate model on one-shot dev task if specified
        if model_options["one_shot_validation"]:
            val_task_accuracy, conf_interval_95 = run_one_shot_validation()

            val_score = val_task_accuracy
            val_metric = f"{FLAGS.L}-way {FLAGS.K}-shot accuracy"

            if val_score >= best_val_score:
                best_val_score = val_score
                best_model = True

        # otherwise, validate on siamese task
        else:
            val_score = dev_loss
            val_metric = "loss"

            if val_score <= best_val_score:
                best_val_score = val_score
                best_model = True

        # log results
        logging.log(logging.INFO, f"Train: Loss: {train_loss:.6f}")

        logging.log(
            logging.INFO,
            f"Validation: Loss: {dev_loss:.6f} {'*' if best_model else ''}")

        if model_options["one_shot_validation"]:
            logging.log(
                logging.INFO,
                f"Validation: {FLAGS.L}-way {FLAGS.K}-shot accuracy after "
                f"{FLAGS.episodes} episodes: {val_task_accuracy:.3%} +- "
                f"{conf_interval_95*100:.4f} {'*' if best_model else ''}")

        if tf_writer is not None:
            with tf_writer.as_default():
                # The epoch mean gets its own tag: the next epoch's first
                # "Train step loss" point is written at this same global_step,
                # so sharing the tag would put two different quantities at one
                # (tag, step) pair.
                tf.summary.scalar(
                    "Train loss", train_loss, step=global_step)
                tf.summary.scalar(
                    "Validation loss", dev_loss, step=global_step)
                if model_options["one_shot_validation"]:
                    tf.summary.scalar(
                        f"Validation {FLAGS.L}-way {FLAGS.K}-shot accuracy",
                        val_task_accuracy, step=global_step)

        # store model and results
        model_utils.save_model(
            speech_few_shot_model.model, output_dir, epoch + 1, global_step,
            val_metric, val_score, best_val_score, name="model")

        if best_model:
            # reset the flag so later epochs must improve on the score again
            best_model = False
            model_utils.save_model(
                speech_few_shot_model.model, output_dir, epoch + 1, global_step,
                val_metric, val_score, best_val_score, name="best_model")
Exemple #2
0
def train(model_options,
          output_dir,
          model_file=None,
          model_step_file=None,
          tf_writer=None):
    """Create and train spoken word classification model for one-shot learning.

    Builds a tf.data pipeline that loads and preprocesses speech features
    from audio paths, creates (or restores) the speech network, and trains it
    for the remaining epochs. After every epoch the model is validated on the
    dev set (and optionally on a one-shot task) and saved to `output_dir`.

    Args:
        model_options: dict of model/training options. Keys read here are
            "features", "max_length", "scaling", "batch_size",
            "one_shot_validation", "learning_rate", "decay_rate",
            "decay_steps", "gradient_clip_norm" and "epochs". The dict is
            mutated: "input_shape" is always set, and "n_classes" is set when
            a new network is created (it is only checked when resuming).
        output_dir: directory used for the stored "model_options.json" and
            the saved models. `model_file` and `model_step_file` are
            resolved relative to it.
        model_file: optional model filename (inside `output_dir`) to resume
            training from. A new network is created when None.
        model_step_file: train state filename (inside `output_dir`) that is
            loaded together with `model_file`.
        tf_writer: optional summary writer providing `as_default()`. No
            summaries are written when None.
    """

    # load training data
    train_exp, dev_exp = dataset.create_flickr_audio_train_data(
        model_options["features"], speaker_mode=FLAGS.speaker_mode)

    train_labels = np.asarray([
        train_exp.keyword_labels[keyword]
        for keyword in train_exp.keywords_set[3]])

    # NOTE(review): dev keywords are looked up in the *train* label map, as in
    # the original code; presumably train and dev share one keyword
    # vocabulary -- confirm against the dataset module.
    dev_labels = np.asarray([
        train_exp.keyword_labels[keyword]
        for keyword in dev_exp.keywords_set[3]])

    train_paths = train_exp.audio_paths
    dev_paths = dev_exp.audio_paths

    # one-hot encode labels (binarizer is fit on the training labels only)
    lb = LabelBinarizer()
    train_labels_one_hot = lb.fit_transform(train_labels)
    dev_labels_one_hot = lb.transform(dev_labels)

    # define preprocessing for speech features
    preprocess_speech_func = functools.partial(
        dataset.load_and_preprocess_speech,
        features=model_options["features"],
        max_length=model_options["max_length"],
        scaling=model_options["scaling"])

    def preprocess_speech_ds_func(path):
        # wrap the Python loader so it can run inside the tf.data graph
        return tf.py_function(
            func=preprocess_speech_func, inp=[path], Tout=tf.float32)

    # create standard training dataset pipeline
    background_train_ds = tf.data.Dataset.zip(
        (tf.data.Dataset.from_tensor_slices(train_paths),
         tf.data.Dataset.from_tensor_slices(train_labels_one_hot)))

    # map data preprocessing function across training data
    background_train_ds = background_train_ds.map(
        lambda path, label: (preprocess_speech_ds_func(path), label),
        num_parallel_calls=tf.data.experimental.AUTOTUNE)

    # shuffle and batch train data
    background_train_ds = background_train_ds.shuffle(1000)

    background_train_ds = background_train_ds.batch(
        model_options["batch_size"])

    background_train_ds = background_train_ds.prefetch(
        tf.data.experimental.AUTOTUNE)

    # create dev set pipeline for classification validation
    background_dev_ds = tf.data.Dataset.zip(
        (tf.data.Dataset.from_tensor_slices(dev_paths).map(
            preprocess_speech_ds_func),
         tf.data.Dataset.from_tensor_slices(dev_labels_one_hot)))

    background_dev_ds = background_dev_ds.batch(
        batch_size=model_options["batch_size"])

    # write example batch to TensorBoard
    if tf_writer is not None:
        logging.log(logging.INFO, "Writing example features to TensorBoard")
        with tf_writer.as_default():
            keywords = np.asarray(train_exp.keywords)

            for x_batch, y_batch in background_train_ds.take(1):

                speech_feats = []
                for feats in x_batch[:30]:
                    # min-max scale to [0, 1] for image display: shift the
                    # minimum to zero, then divide by the shifted maximum
                    # (guarding against constant features)
                    feats = np.transpose(feats)
                    feats = feats - np.min(feats)
                    peak = np.max(feats)
                    speech_feats.append(feats / (peak if peak > 0 else 1.0))

                tf.summary.image(
                    f"Example train speech {model_options['features']}",
                    np.expand_dims(speech_feats, axis=-1),
                    max_outputs=30,
                    step=0)

                # y_batch holds one-hot rows; map them back to the integer
                # label ids before looking up keywords.
                # NOTE(review): assumes label ids index `train_exp.keywords`,
                # as the original direct indexing implied -- confirm.
                label_ids = lb.inverse_transform(y_batch[:30].numpy())
                labels = "".join(
                    f"{i}: {keywords[label_id]} "
                    for i, label_id in enumerate(label_ids))

                tf.summary.text("Example train labels", labels, step=0)

    # get training objective
    loss = get_training_objective(model_options)

    # get model input shape
    if model_options["features"] == "mfcc":
        model_options["input_shape"] = [model_options["max_length"], 39]
    else:
        model_options["input_shape"] = [model_options["max_length"], 40]

    # load or create model
    if model_file is not None:
        assert model_options["n_classes"] == len(train_exp.keywords)

        speech_network, train_state = model_utils.load_model(
            model_file=os.path.join(output_dir, model_file),
            model_step_file=os.path.join(output_dir, model_step_file),
            loss=loss)

        # get previous tracking variables
        initial_model = False
        global_step, model_epochs, _, best_val_score = train_state
    else:
        model_options["n_classes"] = len(train_exp.keywords)

        speech_network = create_speech_network(model_options)

        # create tracking variables
        initial_model = True
        global_step = 0
        model_epochs = 0

        # NOTE(review): without one-shot validation the score is categorical
        # accuracy, compared with `>=` below, so this +inf start means no
        # epoch ever counts as "best" -- kept as in the original; confirm
        # whether -inf was intended.
        if model_options["one_shot_validation"]:
            best_val_score = -np.inf
        else:
            best_val_score = np.inf

    # load or create Adam optimizer with decayed learning rate
    lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
        model_options["learning_rate"],
        decay_rate=model_options["decay_rate"],
        decay_steps=model_options["decay_steps"],
        staircase=True)

    if model_file is not None:
        logging.log(logging.INFO, "Restoring optimizer state")
        optimizer = speech_network.optimizer
    else:
        optimizer = tf.keras.optimizers.Adam(learning_rate=lr_schedule)

    # compile model to store optimizer with model when saving
    speech_network.compile(optimizer=optimizer, loss=loss)

    # create few-shot model from speech network for background training
    speech_few_shot_model = base.BaseModel(speech_network, loss)

    # test model on one-shot validation task prior to training
    if model_options["one_shot_validation"]:

        one_shot_dev_exp = flickr_speech.FlickrSpeech(
            features=model_options["features"],
            keywords_split="background_dev",
            preprocess_func=get_data_preprocess_func(model_options),
            speaker_mode=FLAGS.speaker_mode)

        def embedding_model_func(speech_network):
            return create_embedding_model(model_options, speech_network)

        classification = False
        if FLAGS.classification:
            assert FLAGS.embed_layer in ["logits", "softmax"]
            classification = True

        def run_one_shot_validation():
            """Run the L-way K-shot dev task on the current speech network.

            Returns the task accuracy and its 95% confidence interval as
            reported by `experiment.test_l_way_k_shot`.
            """
            # create few-shot model from speech network for one-shot validation
            if FLAGS.fine_tune_steps is not None:
                few_shot_model = create_fine_tune_model(
                    model_options,
                    speech_few_shot_model.model,
                    num_classes=FLAGS.L)
            else:
                few_shot_model = base.BaseModel(
                    speech_few_shot_model.model,
                    None,
                    mc_dropout=FLAGS.mc_dropout)

            task_accuracy, _, conf_interval = experiment.test_l_way_k_shot(
                one_shot_dev_exp,
                FLAGS.K,
                FLAGS.L,
                n=FLAGS.N,
                num_episodes=FLAGS.episodes,
                k_neighbours=FLAGS.k_neighbours,
                metric=FLAGS.metric,
                classification=classification,
                model=few_shot_model,
                embedding_model_func=embedding_model_func,
                fine_tune_steps=FLAGS.fine_tune_steps,
                fine_tune_lr=FLAGS.fine_tune_lr)

            return task_accuracy, conf_interval

        val_task_accuracy, conf_interval_95 = run_one_shot_validation()

        logging.log(
            logging.INFO,
            f"Base model: {FLAGS.L}-way {FLAGS.K}-shot accuracy after "
            f"{FLAGS.episodes} episodes: {val_task_accuracy:.3%} +- "
            f"{conf_interval_95*100:.4f}")

    # create training metrics
    accuracy_metric = tf.keras.metrics.CategoricalAccuracy()
    loss_metric = tf.keras.metrics.Mean()
    best_model = False

    # store model options on first run
    if initial_model:
        file_io.write_json(os.path.join(output_dir, "model_options.json"),
                           model_options)

    # train model
    for epoch in range(model_epochs, model_options["epochs"]):
        logging.log(logging.INFO, f"Epoch {epoch:03d}")

        accuracy_metric.reset_states()
        loss_metric.reset_states()

        # train on epoch of training data
        step_pbar = tqdm(background_train_ds,
                         bar_format="{desc} [{elapsed},{rate_fmt}{postfix}]")
        for step, (x_batch, y_batch) in enumerate(step_pbar):

            loss_value, y_predict = speech_few_shot_model.train_step(
                x_batch,
                y_batch,
                optimizer,
                clip_norm=model_options["gradient_clip_norm"])

            accuracy_metric.update_state(y_batch, y_predict)
            loss_metric.update_state(loss_value)

            step_loss = tf.reduce_mean(loss_value)
            train_loss = loss_metric.result().numpy()
            train_accuracy = accuracy_metric.result().numpy()

            step_pbar.set_description_str(
                f"\tStep {step:03d}: "
                f"Step loss: {step_loss:.6f}, "
                f"Loss: {train_loss:.6f}, "
                f"Categorical accuracy: {train_accuracy:.3%}")

            if tf_writer is not None:
                with tf_writer.as_default():
                    tf.summary.scalar("Train step loss",
                                      step_loss,
                                      step=global_step)
            global_step += 1

        # validate classification model
        accuracy_metric.reset_states()
        loss_metric.reset_states()

        for x_batch, y_batch in background_dev_ds:
            y_predict = speech_few_shot_model.predict(x_batch, training=False)
            loss_value = speech_few_shot_model.loss(y_batch, y_predict)

            accuracy_metric.update_state(y_batch, y_predict)
            loss_metric.update_state(loss_value)

        dev_loss = loss_metric.result().numpy()
        dev_accuracy = accuracy_metric.result().numpy()

        # validate model on one-shot dev task if specified
        if model_options["one_shot_validation"]:
            val_task_accuracy, conf_interval_95 = run_one_shot_validation()

            val_score = val_task_accuracy
            val_metric = f"{FLAGS.L}-way {FLAGS.K}-shot accuracy"

        # otherwise, validate on classification task
        else:
            val_score = dev_accuracy
            val_metric = "categorical accuracy"

        if val_score >= best_val_score:
            best_val_score = val_score
            best_model = True

        # log results
        logging.log(
            logging.INFO,
            f"Train: Loss: {train_loss:.6f}, Categorical accuracy: "
            f"{train_accuracy:.3%}")

        logging.log(
            logging.INFO,
            f"Validation: Loss: {dev_loss:.6f}, Categorical accuracy: "
            f"{dev_accuracy:.3%} {'*' if best_model else ''}")

        if model_options["one_shot_validation"]:
            logging.log(
                logging.INFO,
                f"Validation: {FLAGS.L}-way {FLAGS.K}-shot accuracy after "
                f"{FLAGS.episodes} episodes: {val_task_accuracy:.3%} +- "
                f"{conf_interval_95*100:.4f} {'*' if best_model else ''}")

        if tf_writer is not None:
            with tf_writer.as_default():
                # The epoch mean gets its own tag: the next epoch's first
                # "Train step loss" point is written at this same global_step,
                # so sharing the tag would put two different quantities at one
                # (tag, step) pair.
                tf.summary.scalar("Train loss",
                                  train_loss,
                                  step=global_step)
                tf.summary.scalar("Train categorical accuracy",
                                  train_accuracy,
                                  step=global_step)
                tf.summary.scalar("Validation loss",
                                  dev_loss,
                                  step=global_step)
                tf.summary.scalar("Validation categorical accuracy",
                                  dev_accuracy,
                                  step=global_step)
                if model_options["one_shot_validation"]:
                    tf.summary.scalar(
                        f"Validation {FLAGS.L}-way {FLAGS.K}-shot accuracy",
                        val_task_accuracy,
                        step=global_step)

        # store model and results
        model_utils.save_model(speech_few_shot_model.model,
                               output_dir,
                               epoch + 1,
                               global_step,
                               val_metric,
                               val_score,
                               best_val_score,
                               name="model")

        if best_model:
            # reset the flag so later epochs must improve on the score again
            best_model = False

            model_utils.save_model(speech_few_shot_model.model,
                                   output_dir,
                                   epoch + 1,
                                   global_step,
                                   val_metric,
                                   val_score,
                                   best_val_score,
                                   name="best_model")
Exemple #3
0
def train(model_options,
          output_dir,
          model_file=None,
          model_step_file=None,
          tf_writer=None):
    """Create and train audio-visual similarity model for one-shot learning.

    Builds (or restores) a speech network and a vision network, trains them
    with `base.WeaklySupervisedMAML.maml_train_step` on batches of sampled
    training episodes, and every `validation_interval` steps evaluates an
    L-way K-shot task on the dev set and stores checkpoints in `output_dir`.

    Args:
        model_options: dict of options. Keys read here: "audio_base_dir",
            "vision_base_dir", "meta_batch_size", "meta_learning_rate",
            "decay_rate", "decay_steps", "meta_steps", "validation_interval"
            and "gradient_clip_norm". The keys "audio_base_embed_size",
            "vision_base_embed_size", "audio_input_shape" and
            "vision_input_shape" are written from the first training batch.
        output_dir: directory where "model_options.json" and the model
            checkpoints are written; `model_file` and `model_step_file` are
            resolved relative to it.
        model_file: optional file name of a stored model to resume from.
        model_step_file: file name of the stored training state that
            accompanies `model_file`.
        tf_writer: optional summary writer; when given, the train step loss
            and the validation accuracy are written as scalars.

    Returns:
        None. Results are logged and stored through `model_utils.save_model`.
    """

    # load embeddings from (linear) dense layer of base speech and vision models
    speech_embed_dir = os.path.join(model_options["audio_base_dir"], "embed",
                                    "dense")

    image_embed_dir = os.path.join(model_options["vision_base_dir"], "embed",
                                   "dense")

    # load training data (embed dir determines mfcc/fbank speech features)
    train_exp, dev_exp = dataset.create_flickr_multimodal_train_data(
        "mfcc",
        speech_embed_dir=speech_embed_dir,
        image_embed_dir=image_embed_dir,
        speaker_mode=FLAGS.speaker_mode,
        speech_preprocess_func=data_preprocess_func,
        image_preprocess_func=data_preprocess_func,
        unseen_match_set=FLAGS.unseen_match_set)

    # create training dataset generator
    def sample_task_function(task_exp, l, k, n):
        """Return a zero-argument callable that samples one training task."""
        def sample():
            curr_episode_train, curr_episode_test = task_exp.sample_episode(
                l, k, n)

            # NOTE(review): speech and vision data are both indexed with
            # element [0] of the sampled episode sets -- confirm that the
            # paired image indices are not stored in a different element
            x_train_s = task_exp.speech_experiment.data[curr_episode_train[0]]
            x_train_i = task_exp.vision_experiment.data[curr_episode_train[0]]

            x_query_s = task_exp.speech_experiment.data[curr_episode_test[0]]
            x_query_i = task_exp.vision_experiment.data[curr_episode_test[0]]

            return x_train_s, x_train_i, x_query_s, x_query_i

        return sample

    train_generator = sample_task_function(train_exp, FLAGS.L, FLAGS.K,
                                           FLAGS.N)

    # create dummy dataset with infinite zero elements
    background_train_ds = tf.data.Dataset.from_tensors(tf.constant(0.))
    background_train_ds = background_train_ds.repeat(-1)

    # parallel map dummy elements to task data (the dummy value is ignored;
    # each element is replaced by a freshly sampled task)
    background_train_ds = background_train_ds.map(
        lambda _: tf.py_function(
            train_generator, inp=[], Tout=[tf.float32] * 4),
        num_parallel_calls=tf.data.experimental.AUTOTUNE)

    # each batch holds "meta_batch_size" sampled tasks
    background_train_ds = background_train_ds.batch(
        batch_size=model_options["meta_batch_size"])

    background_train_ds = background_train_ds.prefetch(
        tf.data.experimental.AUTOTUNE)

    # get training objective
    triplet_loss = get_training_objective(model_options)

    # get model input shape from the last dimension of one example batch
    for speech_batch, image_batch, _, _ in background_train_ds.take(1):
        model_options["audio_base_embed_size"] = int(
            tf.shape(speech_batch)[-1].numpy())
        model_options["vision_base_embed_size"] = int(
            tf.shape(image_batch)[-1].numpy())

        model_options["audio_input_shape"] = [
            model_options["audio_base_embed_size"]
        ]
        model_options["vision_input_shape"] = [
            model_options["vision_base_embed_size"]
        ]

    # load or create models
    if model_file is not None:
        join_network_model, train_state = model_utils.load_model(
            model_file=os.path.join(output_dir, model_file),
            model_step_file=os.path.join(output_dir, model_step_file),
            loss=triplet_loss)

        # split the stored joint model back into its two branches
        speech_network = tf.keras.Model(inputs=join_network_model.inputs[0],
                                        outputs=join_network_model.outputs[0])

        vision_network = tf.keras.Model(inputs=join_network_model.inputs[1],
                                        outputs=join_network_model.outputs[1])

        # get previous tracking variables
        initial_model = False
        global_step, _, _, best_val_score = train_state
    else:
        speech_network = create_speech_network(model_options)
        vision_network = create_vision_network(model_options)

        # create tracking variables
        initial_model = True
        global_step = 0

        # the validation score in this function is always the one-shot task
        # accuracy (higher is better, compared with `>=` below), so start
        # from the lowest possible value; starting from +inf would mean that
        # no checkpoint is ever stored as "best_model"
        best_val_score = -np.inf

    # load or create Adam optimizer with decayed learning rate
    lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
        model_options["meta_learning_rate"],
        decay_rate=model_options["decay_rate"],
        decay_steps=model_options["decay_steps"],
        staircase=True)

    if model_file is not None:
        logging.log(logging.INFO, "Restoring optimizer state")
        # must be read before `join_network_model` is rebuilt below
        optimizer = join_network_model.optimizer
    else:
        optimizer = tf.keras.optimizers.Adam(learning_rate=lr_schedule)

    # compile models to store optimizer with model when saving
    join_network_model = tf.keras.Model(
        inputs=[speech_network.input, vision_network.input],
        outputs=[speech_network.output, vision_network.output])

    join_network_model.compile(optimizer=optimizer, loss=triplet_loss)

    # create few-shot model from speech network for background training
    multimodal_model = base.WeaklySupervisedMAML(
        speech_network,
        vision_network,
        triplet_loss,
        inner_optimizer_lr=FLAGS.fine_tune_lr)

    # create few-shot model from speech and vision networks for one-shot validation
    test_few_shot_model = create_fine_tune_model(model_options, speech_network,
                                                 vision_network)

    # no explicit fine-tune optimizer is passed when FLAGS.adam is set
    fine_tune_optimizer = None
    if not FLAGS.adam:
        fine_tune_optimizer = tf.keras.optimizers.SGD(FLAGS.fine_tune_lr)

    # test model on one-shot validation task prior to training
    # NOTE(review): this baseline uses twice the fine-tune steps of the
    # in-training validation below -- confirm this is intentional
    val_task_accuracy, _, conf_interval_95 = experiment.test_multimodal_l_way_k_shot(
        dev_exp,
        FLAGS.K,
        FLAGS.L,
        n=FLAGS.N,
        num_episodes=FLAGS.episodes,
        k_neighbours=FLAGS.k_neighbours,
        metric=FLAGS.metric,
        direct_match=FLAGS.direct_match,
        multimodal_model=test_few_shot_model,
        multimodal_embedding_func=None,
        optimizer=fine_tune_optimizer,
        fine_tune_steps=FLAGS.fine_tune_steps * 2,
        fine_tune_lr=FLAGS.fine_tune_lr)

    logging.log(
        logging.INFO,
        f"Base model: {FLAGS.L}-way {FLAGS.K}-shot accuracy after "
        f"{FLAGS.episodes} episodes: {val_task_accuracy:.3%} +- "
        f"{conf_interval_95*100:.4f}")

    # create training metrics
    step_loss_metric = tf.keras.metrics.Mean()
    avg_loss_metric = tf.keras.metrics.Mean()
    best_model = False

    # store model options on first run
    if initial_model:
        file_io.write_json(os.path.join(output_dir, "model_options.json"),
                           model_options)

        # also store initial model for probing and things
        model_utils.save_model(join_network_model,
                               output_dir,
                               0,
                               0,
                               "not tested",
                               0.,
                               0.,
                               name="initial_model")

    # train model (the dataset is infinite; the loop is ended by the
    # `global_step` check at the bottom)
    step_pbar = tqdm(background_train_ds,
                     bar_format="{desc} {r_bar}",
                     initial=global_step,
                     total=model_options["meta_steps"])
    for batch in step_pbar:

        # train on batch of training task data
        train_s_batch, train_i_batch, test_s_batch, test_i_batch = batch

        meta_loss, _, _ = multimodal_model.maml_train_step(
            train_s_batch,
            train_i_batch,
            test_s_batch,
            test_i_batch,
            FLAGS.fine_tune_steps,
            optimizer,
            stop_gradients=FLAGS.first_order,
            clip_norm=model_options["gradient_clip_norm"])

        step_loss_metric.reset_states()
        step_loss_metric.update_state(meta_loss)
        avg_loss_metric.update_state(meta_loss)

        # running mean of the loss since the last validation
        avg_loss = avg_loss_metric.result().numpy()

        step_pbar.set_description_str(f"Step loss: {meta_loss:.6f}, "
                                      f"Average loss: {avg_loss:.6f}")

        if tf_writer is not None:
            with tf_writer.as_default():
                tf.summary.scalar("Train step loss",
                                  meta_loss,
                                  step=global_step)

        # validate every "validation_interval" steps and on the final step,
        # but never on step 0
        if (global_step % model_options["validation_interval"] == 0
                or global_step
                == model_options["meta_steps"]) and global_step > 0:

            # validate model on one-shot dev task
            test_few_shot_model = create_fine_tune_model(
                model_options, speech_network, vision_network)

            val_task_accuracy, _, conf_interval_95 = experiment.test_multimodal_l_way_k_shot(
                dev_exp,
                FLAGS.K,
                FLAGS.L,
                n=FLAGS.N,
                num_episodes=FLAGS.episodes,
                k_neighbours=FLAGS.k_neighbours,
                metric=FLAGS.metric,
                direct_match=FLAGS.direct_match,
                multimodal_model=test_few_shot_model,
                multimodal_embedding_func=None,
                optimizer=fine_tune_optimizer,
                fine_tune_steps=FLAGS.fine_tune_steps,
                fine_tune_lr=FLAGS.fine_tune_lr)

            val_metric = f"{FLAGS.L}-way {FLAGS.K}-shot accuracy"

            if val_task_accuracy >= best_val_score:
                best_val_score = val_task_accuracy
                best_model = True

            # log results
            avg_loss_metric.reset_states()

            logging.log(logging.INFO, f"Step {global_step:03d}")

            logging.log(
                logging.INFO, f"Train: Step loss: {meta_loss:.6f}, "
                f"Average loss: {avg_loss:.6f}")

            logging.log(
                logging.INFO,
                f"Validation: {FLAGS.L}-way {FLAGS.K}-shot accuracy after "
                f"{FLAGS.episodes} episodes: {val_task_accuracy:.3%} +- "
                f"{conf_interval_95*100:.4f} {'*' if best_model else ''}")

            if tf_writer is not None:
                with tf_writer.as_default():
                    tf.summary.scalar(
                        f"Validation {FLAGS.L}-way {FLAGS.K}-shot accuracy",
                        val_task_accuracy,
                        step=global_step)

            # store latest model and results
            model_utils.save_model(join_network_model,
                                   output_dir,
                                   global_step,
                                   global_step,
                                   val_metric,
                                   val_task_accuracy,
                                   best_val_score,
                                   name="model")

            # additionally store the model as best model when it matched or
            # improved on the best validation score so far
            if best_model:
                best_model = False
                model_utils.save_model(join_network_model,
                                       output_dir,
                                       global_step,
                                       global_step,
                                       val_metric,
                                       val_task_accuracy,
                                       best_val_score,
                                       name="best_model")

        global_step += 1

        if global_step > model_options["meta_steps"]:
            logging.log(logging.INFO,
                        f"Training complete after {global_step-1:03d} steps")
            break
Exemple #4
0
def train(model_options, output_dir, model_file=None, model_step_file=None,
          tf_writer=None):
    """Create and train image classification model for one-shot learning.

    Trains a multi-label keyword classifier on the background training images
    and, after every epoch, validates it on the dev images (precision, recall,
    F-1) and optionally on an L-way K-shot task, storing checkpoints in
    `output_dir`.

    Args:
        model_options: dict of options. Keys read here: "data", "crop_size",
            "augment_train", "random_scales", "horizontal_flip", "colour",
            "num_augment", "batch_size", "one_shot_validation",
            "learning_rate", "decay_rate", "decay_steps",
            "gradient_clip_norm" and "epochs". "input_shape" is always
            written; "n_classes" is written for a new model and checked
            against the number of training keywords when resuming.
        output_dir: directory where "model_options.json" and the model
            checkpoints are written; `model_file` and `model_step_file` are
            resolved relative to it.
        model_file: optional file name of a stored model to resume from.
        model_step_file: file name of the stored training state that
            accompanies `model_file`.
        tf_writer: optional summary writer; when given, example images,
            example labels and train/validation scalars are written.

    Returns:
        None. Results are logged and stored through `model_utils.save_model`.
    """

    # load training data
    train_exp, dev_exp = dataset.create_flickr_vision_train_data(
        model_options["data"])

    # one array of label ids per image; kept as plain lists because images
    # may have different numbers of keywords and recent NumPy versions refuse
    # to build an array from such ragged input (the dev keywords are mapped
    # through the *train* keyword labels as well)
    train_labels = [
        np.array([train_exp.keyword_labels[keyword]
                  for keyword in image_keywords])
        for image_keywords in train_exp.unique_image_keywords]

    dev_labels = [
        np.array([train_exp.keyword_labels[keyword]
                  for keyword in image_keywords])
        for image_keywords in dev_exp.unique_image_keywords]

    train_paths = train_exp.unique_image_paths
    dev_paths = dev_exp.unique_image_paths

    # multi-hot encode the labels (binarizer is fit on the train labels only)
    mlb = MultiLabelBinarizer()
    train_labels_multi_hot = mlb.fit_transform(train_labels)
    dev_labels_multi_hot = mlb.transform(dev_labels)

    # define preprocessing for images
    preprocess_images_func = functools.partial(
        dataset.load_and_preprocess_image,
        crop_size=model_options["crop_size"],
        augment_crop=model_options["augment_train"],
        random_scales=model_options["random_scales"],
        horizontal_flip=model_options["horizontal_flip"],
        colour=model_options["colour"])

    # create standard training dataset pipeline
    background_train_ds = tf.data.Dataset.zip((
        tf.data.Dataset.from_tensor_slices(train_paths),
        tf.data.Dataset.from_tensor_slices(train_labels_multi_hot)))

    # map data preprocessing function across training data
    background_train_ds = background_train_ds.map(
        lambda path, label: (preprocess_images_func(path), label),
        num_parallel_calls=tf.data.experimental.AUTOTUNE)

    # repeat augmentation, shuffle and batch train data
    if model_options["num_augment"] is not None:
        background_train_ds = background_train_ds.repeat(
            model_options["num_augment"])

    background_train_ds = background_train_ds.shuffle(1000)

    background_train_ds = background_train_ds.batch(
        model_options["batch_size"])

    background_train_ds = background_train_ds.prefetch(
        tf.data.experimental.AUTOTUNE)

    # create dev set pipeline for classification validation
    background_dev_ds = tf.data.Dataset.zip((
        tf.data.Dataset.from_tensor_slices(
            dev_paths).map(preprocess_images_func),
        tf.data.Dataset.from_tensor_slices(dev_labels_multi_hot)))

    background_dev_ds = background_dev_ds.batch(
        batch_size=model_options["batch_size"])

    # write example batch to TensorBoard
    if tf_writer is not None:
        logging.log(logging.INFO, "Writing example images to TensorBoard")

        keywords = np.asarray(train_exp.keywords)
        # column j of a multi-hot vector stands for label id mlb.classes_[j]
        # (assumes, like the original indexing, that a label id is the index
        # of its keyword in `train_exp.keywords` -- TODO confirm)
        class_label_ids = np.asarray(mlb.classes_, dtype=int)

        with tf_writer.as_default():
            for x_batch, y_batch in background_train_ds.take(1):
                # NOTE(review): (x + 1) / 2 presumably maps images from
                # [-1, 1] to [0, 1] for display -- confirm the image range
                tf.summary.image("Example train images", (x_batch+1)/2,
                                 max_outputs=30, step=0)

                # use the multi-hot vector as a boolean mask; using it as an
                # integer index would only ever select keywords 0 and 1
                label_texts = []
                for i, label in enumerate(y_batch[:30]):
                    present = class_label_ids[np.asarray(label).astype(bool)]
                    label_texts.append(f"{i}: {keywords[present]} ")
                labels = "".join(label_texts)

                tf.summary.text("Example train labels", labels, step=0)

    # get training objective
    multi_label_loss = get_training_objective(model_options)

    # get model input shape
    model_options["input_shape"] = (
        model_options["crop_size"], model_options["crop_size"], 3)

    # load or create model
    if model_file is not None:
        assert model_options["n_classes"] == len(train_exp.keywords)

        vision_network, train_state = model_utils.load_model(
            model_file=os.path.join(output_dir, model_file),
            model_step_file=os.path.join(output_dir, model_step_file),
            loss=multi_label_loss)

        # get previous tracking variables
        initial_model = False
        global_step, model_epochs, _, best_val_score = train_state
    else:
        model_options["n_classes"] = len(train_exp.keywords)

        vision_network = create_vision_network(model_options)

        # create tracking variables
        initial_model = True
        global_step = 0
        model_epochs = 0

        # both possible validation scores (one-shot accuracy and F-1) are
        # higher-is-better and compared with `>=` below, so start from the
        # lowest possible value; starting from +inf would mean that no
        # checkpoint is ever stored as "best_model"
        best_val_score = -np.inf

    # load or create Adam optimizer with decayed learning rate
    lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
        model_options["learning_rate"], decay_rate=model_options["decay_rate"],
        decay_steps=model_options["decay_steps"], staircase=True)

    if model_file is not None:
        logging.log(logging.INFO, "Restoring optimizer state")
        optimizer = vision_network.optimizer
    else:
        optimizer = tf.keras.optimizers.Adam(learning_rate=lr_schedule)

    # compile model to store optimizer with model when saving
    vision_network.compile(optimizer=optimizer, loss=multi_label_loss)

    # create few-shot model from vision network for background training
    vision_few_shot_model = base.BaseModel(vision_network, multi_label_loss)

    # test model on one-shot validation task prior to training
    # (the names defined in this branch are reused by the per-epoch one-shot
    # validation, which is guarded by the same option)
    if model_options["one_shot_validation"]:

        one_shot_dev_exp = flickr_vision.FlickrVision(
            keywords_split="background_dev",
            flickr8k_image_dir=os.path.join(
                "data", "external", "flickr8k_images"),
            flickr30k_image_dir=os.path.join(
                "data", "external", "flickr30k_images"),
            mscoco_image_dir=os.path.join(
                "data", "external", "mscoco", "val2017"),
            preprocess_func=get_data_preprocess_func(model_options))

        embedding_model_func = lambda vision_network: create_embedding_model(
            model_options, vision_network)

        classification = False
        if FLAGS.classification:
            assert FLAGS.embed_layer in ["logits", "softmax"]
            classification = True

        # create few-shot model from vision network for one-shot validation
        if FLAGS.fine_tune_steps is not None:
            test_few_shot_model = create_fine_tune_model(
                model_options, vision_few_shot_model.model, num_classes=FLAGS.L)
        else:
            test_few_shot_model = base.BaseModel(
                vision_few_shot_model.model, None, mc_dropout=FLAGS.mc_dropout)

        val_task_accuracy, _, conf_interval_95 = experiment.test_l_way_k_shot(
            one_shot_dev_exp, FLAGS.K, FLAGS.L, n=FLAGS.N,
            num_episodes=FLAGS.episodes, k_neighbours=FLAGS.k_neighbours,
            metric=FLAGS.metric, classification=classification,
            model=test_few_shot_model,
            embedding_model_func=embedding_model_func,
            fine_tune_steps=FLAGS.fine_tune_steps,
            fine_tune_lr=FLAGS.fine_tune_lr)

        logging.log(
            logging.INFO,
            f"Base model: {FLAGS.L}-way {FLAGS.K}-shot accuracy after "
            f"{FLAGS.episodes} episodes: {val_task_accuracy:.3%} +- "
            f"{conf_interval_95*100:.4f}")

    # create training metrics (shared between the train and dev passes and
    # reset in between)
    precision_metric = tf.keras.metrics.Precision()
    recall_metric = tf.keras.metrics.Recall()
    loss_metric = tf.keras.metrics.Mean()
    best_model = False

    # store model options on first run
    if initial_model:
        file_io.write_json(
            os.path.join(output_dir, "model_options.json"), model_options)

    # train model
    for epoch in range(model_epochs, model_options["epochs"]):
        logging.log(logging.INFO, f"Epoch {epoch:03d}")

        precision_metric.reset_states()
        recall_metric.reset_states()
        loss_metric.reset_states()

        # train on epoch of training data
        step_pbar = tqdm(background_train_ds,
                         bar_format="{desc} [{elapsed},{rate_fmt}{postfix}]")
        for step, (x_batch, y_batch) in enumerate(step_pbar):

            loss_value, y_predict = vision_few_shot_model.train_step(
                x_batch, y_batch, optimizer,
                clip_norm=model_options["gradient_clip_norm"])

            # threshold the sigmoid outputs at 0.5 to get multi-hot predictions
            y_one_hot_predict = tf.round(tf.nn.sigmoid(y_predict))

            precision_metric.update_state(y_batch, y_one_hot_predict)
            recall_metric.update_state(y_batch, y_one_hot_predict)
            loss_metric.update_state(loss_value)

            # metrics accumulate over the epoch, so these are running values
            step_loss = tf.reduce_mean(loss_value)
            train_loss = loss_metric.result().numpy()
            train_precision = precision_metric.result().numpy()
            train_recall = recall_metric.result().numpy()
            # F-1 is the harmonic mean of precision and recall
            train_f1 = 2 / ((1/train_precision) + (1/train_recall))

            step_pbar.set_description_str(
                f"\tStep {step:03d}: "
                f"Step loss: {step_loss:.6f}, "
                f"Loss: {train_loss:.6f}, "
                f"Precision: {train_precision:.3%}, "
                f"Recall: {train_recall:.3%}, "
                f"F-1: {train_f1:.3%}")

            if tf_writer is not None:
                with tf_writer.as_default():
                    tf.summary.scalar(
                        "Train step loss", step_loss, step=global_step)
            global_step += 1

        # validate classification model
        precision_metric.reset_states()
        recall_metric.reset_states()
        loss_metric.reset_states()

        for x_batch, y_batch in background_dev_ds:
            y_predict = vision_few_shot_model.predict(x_batch, training=False)
            loss_value = vision_few_shot_model.loss(y_batch, y_predict)

            y_one_hot_predict = tf.round(tf.nn.sigmoid(y_predict))

            precision_metric.update_state(y_batch, y_one_hot_predict)
            recall_metric.update_state(y_batch, y_one_hot_predict)
            loss_metric.update_state(loss_value)

        dev_loss = loss_metric.result().numpy()
        dev_precision = precision_metric.result().numpy()
        dev_recall = recall_metric.result().numpy()
        dev_f1 = 2 / ((1/dev_precision) + (1/dev_recall))

        # validate model on one-shot dev task if specified
        if model_options["one_shot_validation"]:

            if FLAGS.fine_tune_steps is not None:
                test_few_shot_model = create_fine_tune_model(
                    model_options, vision_few_shot_model.model,
                    num_classes=FLAGS.L)
            else:
                test_few_shot_model = base.BaseModel(
                    vision_few_shot_model.model, None,
                    mc_dropout=FLAGS.mc_dropout)

            val_task_accuracy, _, conf_interval_95 = experiment.test_l_way_k_shot(
                one_shot_dev_exp, FLAGS.K, FLAGS.L, n=FLAGS.N,
                num_episodes=FLAGS.episodes, k_neighbours=FLAGS.k_neighbours,
                metric=FLAGS.metric, classification=classification,
                model=test_few_shot_model,
                embedding_model_func=embedding_model_func,
                fine_tune_steps=FLAGS.fine_tune_steps,
                fine_tune_lr=FLAGS.fine_tune_lr)

            val_score = val_task_accuracy
            val_metric = f"{FLAGS.L}-way {FLAGS.K}-shot accuracy"

        # otherwise, validate on classification task
        else:
            val_score = dev_f1
            val_metric = "F-1"

        if val_score >= best_val_score:
            best_val_score = val_score
            best_model = True

        # log results
        logging.log(
            logging.INFO,
            f"Train: Loss: {train_loss:.6f}, Precision: {train_precision:.3%}, "
            f"Recall: {train_recall:.3%}, F-1: {train_f1:.3%}")

        logging.log(
            logging.INFO,
            f"Validation: Loss: {dev_loss:.6f}, Precision: "
            f"{dev_precision:.3%}, Recall: {dev_recall:.3%}, F-1: "
            f"{dev_f1:.3%} {'*' if best_model else ''}")

        if model_options["one_shot_validation"]:
            logging.log(
                logging.INFO,
                f"Validation: {FLAGS.L}-way {FLAGS.K}-shot accuracy after "
                f"{FLAGS.episodes} episodes: {val_task_accuracy:.3%} +- "
                f"{conf_interval_95*100:.4f} {'*' if best_model else ''}")

        if tf_writer is not None:
            with tf_writer.as_default():
                # NOTE(review): this epoch-level loss is written under the
                # same "Train step loss" tag as the per-step loss above --
                # confirm whether a separate tag was intended
                tf.summary.scalar(
                    "Train step loss", train_loss, step=global_step)
                tf.summary.scalar(
                    "Train precision", train_precision, step=global_step)
                tf.summary.scalar(
                    "Train recall", train_recall, step=global_step)
                tf.summary.scalar(
                    "Train F-1", train_f1, step=global_step)
                tf.summary.scalar(
                    "Validation loss", dev_loss, step=global_step)
                tf.summary.scalar(
                    "Validation precision", dev_precision, step=global_step)
                tf.summary.scalar(
                    "Validation recall", dev_recall, step=global_step)
                tf.summary.scalar(
                    "Validation F-1", dev_f1, step=global_step)
                if model_options["one_shot_validation"]:
                    tf.summary.scalar(
                        f"Validation {FLAGS.L}-way {FLAGS.K}-shot accuracy",
                        val_task_accuracy, step=global_step)

        # store model and results
        model_utils.save_model(
            vision_few_shot_model.model, output_dir, epoch + 1, global_step,
            val_metric, val_score, best_val_score, name="model")

        # additionally store the model as best model when it matched or
        # improved on the best validation score so far
        if best_model:
            best_model = False

            model_utils.save_model(
                vision_few_shot_model.model, output_dir, epoch + 1, global_step,
                val_metric, val_score, best_val_score, name="best_model")
Exemple #5
0
def train(model_options, output_dir, model_file=None, model_step_file=None,
          tf_writer=None):
    """Create and train audio-visual similarity model for one-shot learning.

    Speech and image embeddings (dense layer outputs of the base speech and
    vision models) are zipped element-wise into (speech, image) batches. A
    speech network and a vision network are then trained jointly on the
    objective returned by `get_training_objective`. After every epoch the
    model is validated on the dev pairs (and optionally on an L-way K-shot
    task), saved as "model", and additionally saved as "best_model" whenever
    the validation score matches or improves on the best score so far.

    Args:
        model_options: dict of model/training options. Keys read here:
            "audio_base_dir", "vision_base_dir", "batch_size", "num_batches",
            "one_shot_validation", "learning_rate", "decay_rate",
            "decay_steps", "gradient_clip_norm" and "epochs". The dict is
            updated in place with the embedding sizes and input shapes read
            from the first training batch, and is written to
            "model_options.json" in `output_dir` when training from scratch.
        output_dir: directory where models and model options are stored, and
            relative to which `model_file`/`model_step_file` are resolved.
        model_file: optional model file name (inside `output_dir`) to resume
            training from; when None new networks are created.
        model_step_file: model step file name (inside `output_dir`) that
            accompanies `model_file`; only used when `model_file` is given.
        tf_writer: optional summary writer; when given, scalars are written
            inside `tf_writer.as_default()`.
    """

    # load embeddings from (linear) dense layer of base speech and vision models
    speech_embed_dir = os.path.join(
        model_options["audio_base_dir"], "embed", "dense")

    image_embed_dir = os.path.join(
        model_options["vision_base_dir"], "embed", "dense")

    # load training data (embed dir determines mfcc/fbank speech features)
    train_exp, dev_exp = dataset.create_flickr_multimodal_train_data(
        "mfcc", speech_embed_dir=speech_embed_dir,
        image_embed_dir=image_embed_dir, speaker_mode=FLAGS.speaker_mode,
        unseen_match_set=FLAGS.unseen_match_set)

    train_speech_paths = train_exp.speech_experiment.data
    train_image_paths = train_exp.vision_experiment.data

    dev_speech_paths = dev_exp.speech_experiment.data
    dev_image_paths = dev_exp.vision_experiment.data

    # define preprocessing for base model embeddings
    preprocess_data_func = lambda example: dataset.parse_embedding_protobuf(
        example)["embed"]

    def _paired_embedding_dataset(speech_paths, image_paths):
        """Zip speech/image TFRecord datasets and parse out the embeddings."""
        paired_ds = tf.data.Dataset.zip((
            tf.data.TFRecordDataset(speech_paths, compression_type="ZLIB"),
            tf.data.TFRecordDataset(image_paths, compression_type="ZLIB")))

        return paired_ds.map(
            lambda speech_example, image_example: (
                preprocess_data_func(speech_example),
                preprocess_data_func(image_example)),
            num_parallel_calls=8)

    # create standard training dataset pipeline: repeat indefinitely, shuffle,
    # batch, and cut each pass over the dataset off after `num_batches` batches
    background_train_ds = _paired_embedding_dataset(
        train_speech_paths, train_image_paths)
    background_train_ds = background_train_ds.repeat(-1)
    background_train_ds = background_train_ds.shuffle(1000)
    background_train_ds = background_train_ds.batch(
        model_options["batch_size"])
    background_train_ds = background_train_ds.take(
        model_options["num_batches"])
    background_train_ds = background_train_ds.prefetch(
        tf.data.experimental.AUTOTUNE)

    # create dev set pipeline for validation (no shuffling or repeating)
    background_dev_ds = _paired_embedding_dataset(
        dev_speech_paths, dev_image_paths)
    background_dev_ds = background_dev_ds.batch(
        batch_size=model_options["batch_size"])

    # get training objective
    triplet_loss = get_training_objective(model_options)

    # get model input shape: pull a single batch and read dimension 1 of the
    # speech and image embeddings
    for speech_batch, image_batch in background_train_ds.take(1):
        model_options["audio_base_embed_size"] = int(
            tf.shape(speech_batch)[1].numpy())
        model_options["vision_base_embed_size"] = int(
            tf.shape(image_batch)[1].numpy())

        model_options["audio_input_shape"] = [
            model_options["audio_base_embed_size"]]
        model_options["vision_input_shape"] = [
            model_options["vision_base_embed_size"]]

    # load or create models
    if model_file is not None:
        join_network_model, train_state = model_utils.load_model(
            model_file=os.path.join(output_dir, model_file),
            model_step_file=os.path.join(output_dir, model_step_file),
            loss=get_training_objective(model_options))

        # split the stored joint model back into its two branches
        speech_network = tf.keras.Model(
            inputs=join_network_model.inputs[0],
            outputs=join_network_model.outputs[0])

        vision_network = tf.keras.Model(
            inputs=join_network_model.inputs[1],
            outputs=join_network_model.outputs[1])

        # get previous tracking variables
        initial_model = False
        global_step, model_epochs, _, best_val_score = train_state
    else:
        speech_network = create_speech_network(model_options)
        vision_network = create_vision_network(model_options)

        # create tracking variables
        initial_model = True
        global_step = 0
        model_epochs = 0

        # one-shot accuracy is maximised, validation loss is minimised
        if model_options["one_shot_validation"]:
            best_val_score = -np.inf
        else:
            best_val_score = np.inf

    # load or create Adam optimizer with decayed learning rate
    lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
        model_options["learning_rate"], decay_rate=model_options["decay_rate"],
        decay_steps=model_options["decay_steps"], staircase=True)

    if model_file is not None:
        logging.log(logging.INFO, "Restoring optimizer state")
        optimizer = join_network_model.optimizer
    else:
        optimizer = tf.keras.optimizers.Adam(learning_rate=lr_schedule)

    # compile joint model to store optimizer with model when saving
    join_network_model = tf.keras.Model(
        inputs=[speech_network.input, vision_network.input],
        outputs=[speech_network.output, vision_network.output])

    join_network_model.compile(optimizer=optimizer, loss=triplet_loss)

    # create multimodal model from the two networks for background training
    multimodal_model = base.WeaklySupervisedModel(
        speech_network, vision_network, triplet_loss)

    def _one_shot_validation(one_shot_exp):
        """Run the L-way K-shot dev task with the current networks.

        Returns the task accuracy and the `conf_interval_95` value produced by
        `experiment.test_multimodal_l_way_k_shot`.
        """
        # create few-shot model from speech and vision networks
        if FLAGS.fine_tune_steps is not None:
            test_few_shot_model = create_fine_tune_model(
                model_options, speech_network, vision_network)
        else:
            test_few_shot_model = base.WeaklySupervisedModel(
                speech_network, vision_network, None,
                mc_dropout=FLAGS.mc_dropout)

        task_accuracy, _, conf_interval = (
            experiment.test_multimodal_l_way_k_shot(
                one_shot_exp, FLAGS.K, FLAGS.L, n=FLAGS.N,
                num_episodes=FLAGS.episodes, k_neighbours=FLAGS.k_neighbours,
                metric=FLAGS.metric, direct_match=FLAGS.direct_match,
                multimodal_model=test_few_shot_model,
                multimodal_embedding_func=None,
                fine_tune_steps=FLAGS.fine_tune_steps,
                fine_tune_lr=FLAGS.fine_tune_lr))

        return task_accuracy, conf_interval

    # test model on one-shot validation task prior to training
    if model_options["one_shot_validation"]:

        # NOTE(review): `data_preprocess_func` is not defined in this function
        # (the local is `preprocess_data_func`); presumably a module-level
        # helper -- confirm it exists.
        one_shot_dev_exp = flickr_multimodal.FlickrMultimodal(
            features="mfcc", keywords_split="background_dev",
            flickr8k_image_dir=os.path.join("data", "external", "flickr8k_images"),
            speech_embed_dir=speech_embed_dir, image_embed_dir=image_embed_dir,
            speech_preprocess_func=data_preprocess_func,
            image_preprocess_func=data_preprocess_func,
            speaker_mode=FLAGS.speaker_mode,
            unseen_match_set=FLAGS.unseen_match_set)

        val_task_accuracy, conf_interval_95 = _one_shot_validation(
            one_shot_dev_exp)

        logging.log(
            logging.INFO,
            f"Base model: {FLAGS.L}-way {FLAGS.K}-shot accuracy after "
            f"{FLAGS.episodes} episodes: {val_task_accuracy:.3%} +- "
            f"{conf_interval_95*100:.4f}")

    # create training metrics
    loss_metric = tf.keras.metrics.Mean()
    best_model = False

    # store model options on first run
    if initial_model:
        file_io.write_json(
            os.path.join(output_dir, "model_options.json"), model_options)

    # train model (resumes from the stored epoch count when restoring)
    for epoch in range(model_epochs, model_options["epochs"]):
        logging.log(logging.INFO, f"Epoch {epoch:03d}")

        loss_metric.reset_states()

        # train on epoch of training data
        step_pbar = tqdm(background_train_ds,
                         bar_format="{desc} [{elapsed},{rate_fmt}{postfix}]")
        for step, (speech_batch, image_batch) in enumerate(step_pbar):

            # network outputs returned by the train step are not needed here
            loss_value, _, _ = multimodal_model.train_step(
                speech_batch, image_batch, optimizer,
                clip_norm=model_options["gradient_clip_norm"])

            loss_metric.update_state(loss_value)

            step_loss = tf.reduce_mean(loss_value)
            train_loss = loss_metric.result().numpy()  # running epoch mean

            step_pbar.set_description_str(
                f"\tStep {step:03d}: "
                f"Step loss: {step_loss:.6f}, "
                f"Loss: {train_loss:.6f}")

            if tf_writer is not None:
                with tf_writer.as_default():
                    tf.summary.scalar(
                        "Train step loss", step_loss, step=global_step)
            global_step += 1

        # validate multimodal model (metric is reused for the dev loss)
        loss_metric.reset_states()

        for speech_batch, image_batch in background_dev_ds:
            y_speech = multimodal_model.speech_model.predict(
                speech_batch, training=False)
            y_image = multimodal_model.vision_model.predict(
                image_batch, training=False)
            loss_value = multimodal_model.loss(y_speech, y_image)

            loss_metric.update_state(loss_value)

        dev_loss = loss_metric.result().numpy()

        # validate model on one-shot dev task if specified
        if model_options["one_shot_validation"]:
            val_task_accuracy, conf_interval_95 = _one_shot_validation(
                one_shot_dev_exp)

            val_score = val_task_accuracy
            val_metric = f"{FLAGS.L}-way {FLAGS.K}-shot accuracy"

            # higher accuracy is better; ties count as a new best model
            if val_score >= best_val_score:
                best_val_score = val_score
                best_model = True

        # otherwise, validate on siamese task
        else:
            val_score = dev_loss
            val_metric = "loss"

            # lower loss is better; ties count as a new best model
            if val_score <= best_val_score:
                best_val_score = val_score
                best_model = True

        # log results
        logging.log(logging.INFO, f"Train: Loss: {train_loss:.6f}")

        logging.log(
            logging.INFO,
            f"Validation: Loss: {dev_loss:.6f} {'*' if best_model else ''}")

        if model_options["one_shot_validation"]:
            logging.log(
                logging.INFO,
                f"Validation: {FLAGS.L}-way {FLAGS.K}-shot accuracy after "
                f"{FLAGS.episodes} episodes: {val_task_accuracy:.3%} +- "
                f"{conf_interval_95*100:.4f} {'*' if best_model else ''}")

        if tf_writer is not None:
            with tf_writer.as_default():
                # NOTE: the epoch mean loss is written under the same
                # "Train step loss" tag as the per-step losses above
                tf.summary.scalar(
                    "Train step loss", train_loss, step=global_step)
                tf.summary.scalar(
                    "Validation loss", dev_loss, step=global_step)
                if model_options["one_shot_validation"]:
                    tf.summary.scalar(
                        f"Validation {FLAGS.L}-way {FLAGS.K}-shot accuracy",
                        val_task_accuracy, step=global_step)

        # store model and results (the joint model holds both networks)
        model_utils.save_model(
            join_network_model, output_dir, epoch + 1, global_step,
            val_metric, val_score, best_val_score, name="model")

        if best_model:
            best_model = False  # reset flag for the next epoch
            model_utils.save_model(
                join_network_model, output_dir, epoch + 1, global_step,
                val_metric, val_score, best_val_score, name="best_model")