Пример #1
0
    def __init__(self,
                 features="mfcc",
                 keywords_split="one_shot_evaluation",
                 flickr8k_image_dir=None,
                 speech_embed_dir=None,
                 image_embed_dir=None,
                 image_preprocess_func=None,
                 speech_preprocess_func=None,
                 speaker_mode="baseline",
                 unseen_match_set=False,
                 **kwargs):
        """Set up the speech side and the paired Flickr vision experiment.

        The speech arguments (`features`, `speech_embed_dir`,
        `speech_preprocess_func`, `speaker_mode`) are forwarded to the parent
        initialiser. The image arguments (`flickr8k_image_dir`,
        `image_embed_dir`, `image_preprocess_func`) configure a
        `flickr_vision.FlickrVision` instance that is stored as
        `self.flickr_vision_exp`. `keywords_split` and any extra `**kwargs`
        are passed to both. `unseen_match_set` is stored on the instance
        unchanged.
        """
        logging.log(logging.INFO, "Creating Flickr multimodal experiment")

        # speech side is handled by the parent class
        speech_args = dict(
            features=features,
            keywords_split=keywords_split,
            embed_dir=speech_embed_dir,
            preprocess_func=speech_preprocess_func,
            speaker_mode=speaker_mode)
        super().__init__(**speech_args, **kwargs)

        # TODO: add flickr30k/mscoco and sample paired audio with same keywords?
        vision_args = dict(
            keywords_split=keywords_split,
            flickr8k_image_dir=flickr8k_image_dir,
            flickr30k_image_dir=None,
            mscoco_image_dir=None,
            embed_dir=image_embed_dir,
            preprocess_func=image_preprocess_func)
        self.flickr_vision_exp = flickr_vision.FlickrVision(
            **vision_args, **kwargs)

        self.unseen_match_set = unseen_match_set
Пример #2
0
def create_flickr_vision_train_data(data_sets, embed_dir=None):
    """Build the background train and dev Flickr vision experiments.

    An image directory under `data/external` is only passed for the data sets
    named in `data_sets` ("flickr8k", "flickr30k", "mscoco"); the others are
    left as `None`. The Flickr directories are shared by both experiments,
    while MSCOCO uses `train2017` for training and `val2017` for validation.

    Returns:
        Tuple `(train_exp, dev_exp)` of `flickr_vision.FlickrVision` objects
        created with the "background_train" and "background_dev" keyword
        splits respectively.
    """

    def external_dir_if(selected, *parts):
        # path below data/external when the data set is requested, else None
        return os.path.join("data", "external", *parts) if selected else None

    shared_args = dict(
        flickr8k_image_dir=external_dir_if(
            "flickr8k" in data_sets, "flickr8k_images"),
        flickr30k_image_dir=external_dir_if(
            "flickr30k" in data_sets, "flickr30k_images"),
        embed_dir=embed_dir)

    use_mscoco = "mscoco" in data_sets

    # (keywords split, MSCOCO sub-directory) for the train and dev experiments
    split_folders = (("background_train", "train2017"),
                     ("background_dev", "val2017"))

    flickr_train_exp, flickr_dev_exp = (
        flickr_vision.FlickrVision(
            keywords_split=split,
            mscoco_image_dir=external_dir_if(use_mscoco, "mscoco", folder),
            **shared_args)
        for split, folder in split_folders)

    return flickr_train_exp, flickr_dev_exp
Пример #3
0
def test(model_options, output_dir, model_file, model_step_file):
    """Evaluate a saved image classification model on the one-shot task.

    Loads the Flickr 8k "one_shot_evaluation" experiment and the model stored
    at `output_dir/model_file` (with `output_dir/model_step_file`), wraps the
    network in a few-shot model and runs `experiment.test_l_way_k_shot` with
    the settings taken from `FLAGS`. The resulting accuracy and its 95%
    confidence interval are logged; nothing is returned.
    """

    # Flickr 8k one-shot evaluation data
    flickr8k_dir = os.path.join("data", "external", "flickr8k_images")
    one_shot_exp = flickr_vision.FlickrVision(
        keywords_split="one_shot_evaluation",
        flickr8k_image_dir=flickr8k_dir,
        preprocess_func=get_data_preprocess_func(model_options))

    # restore the trained network; the second return value is not needed here
    vision_network, _ = model_utils.load_model(
        model_file=os.path.join(output_dir, model_file),
        model_step_file=os.path.join(output_dir, model_step_file),
        loss=get_training_objective(model_options))

    def embedding_model_func(vision_network):
        # bind the model options so callers only supply the network
        return create_embedding_model(model_options, vision_network)

    # wrap the network for testing: plain wrapper unless fine-tuning is asked
    if FLAGS.fine_tune_steps is None:
        test_few_shot_model = base.BaseModel(
            vision_network, None, mc_dropout=FLAGS.mc_dropout)
    else:
        test_few_shot_model = create_fine_tune_model(
            model_options, vision_network, num_classes=FLAGS.L)

    # classification mode is only valid on the logits/softmax layers
    classification = bool(FLAGS.classification)
    if classification:
        assert FLAGS.embed_layer in ["logits", "softmax"]

    logging.log(logging.INFO, "Created few-shot model from vision network")
    test_few_shot_model.model.summary()

    # run the L-way K-shot episodes
    task_accuracy, _, conf_interval_95 = experiment.test_l_way_k_shot(
        one_shot_exp,
        FLAGS.K,
        FLAGS.L,
        n=FLAGS.N,
        num_episodes=FLAGS.episodes,
        k_neighbours=FLAGS.k_neighbours,
        metric=FLAGS.metric,
        classification=classification,
        model=test_few_shot_model,
        embedding_model_func=embedding_model_func,
        fine_tune_steps=FLAGS.fine_tune_steps,
        fine_tune_lr=FLAGS.fine_tune_lr)

    logging.log(
        logging.INFO,
        f"{FLAGS.L}-way {FLAGS.K}-shot accuracy after {FLAGS.episodes} "
        f"episodes: {task_accuracy:.3%} +- {conf_interval_95*100:.4f}")
Пример #4
0
def embed(model_options, output_dir, model_file, model_step_file):
    """Embed all Flickr 8k/30k and MSCOCO subsets with a saved siamese model.

    For each data set and each keywords split ("one_shot_evaluation",
    "background_train", "background_dev"), the unique image paths (or the
    base-embedding paths when `model_options["use_embeddings"]` is set) are
    passed through the embedding model in batches. Every embedding is written
    to its own ZLIB-compressed TFRecord file under
    `<output_dir>/embed/dense/<data set>/<subset>/`, named after the source
    file. Nothing is returned.
    """

    # base embeddings come from the dense layer of the base model, if used
    embed_dir = (
        os.path.join(model_options["base_dir"], "embed", "dense")
        if model_options["use_embeddings"] else None)

    # restore the trained network; the second return value is not needed here
    vision_network, _ = model_utils.load_model(
        model_file=os.path.join(output_dir, model_file),
        model_step_file=os.path.join(output_dir, model_step_file),
        loss=get_training_objective(model_options))

    # embedding model and the matching input preprocessing
    embedding_model = create_embedding_model(vision_network)
    data_preprocess_func = get_data_preprocess_func(model_options)

    external_dir = os.path.join("data", "external")
    flickr8k_dirs = {
        "flickr8k_image_dir": os.path.join(external_dir, "flickr8k_images")}
    flickr30k_dirs = {
        "flickr30k_image_dir": os.path.join(external_dir, "flickr30k_images")}
    mscoco_train_dirs = {
        "mscoco_image_dir": os.path.join(external_dir, "mscoco", "train2017")}
    mscoco_dev_dirs = {
        "mscoco_image_dir": os.path.join(external_dir, "mscoco", "val2017")}

    # (data set name, image dirs for train-side splits, image dirs for dev)
    data_set_dirs = [
        ("flickr8k", flickr8k_dirs, flickr8k_dirs),
        ("flickr30k", flickr30k_dirs, flickr30k_dirs),
        ("mscoco", mscoco_train_dirs, mscoco_dev_dirs),
    ]

    for data, train_image_dir_dict, dev_image_dir_dict in data_set_dirs:

        # one experiment per keywords split; only the dev split uses dev dirs
        split_image_dirs = (
            ("one_shot_evaluation", train_image_dir_dict),
            ("background_train", train_image_dir_dict),
            ("background_dev", dev_image_dir_dict),
        )
        subset_exp = {
            split: flickr_vision.FlickrVision(
                keywords_split=split, **image_dirs, embed_dir=embed_dir)
            for split, image_dirs in split_image_dirs
        }

        for subset, exp in subset_exp.items():
            output_embed_dir = os.path.join(
                output_dir, "embed", "dense", data, subset)
            file_io.check_create_dir(output_embed_dir)

            subset_paths = (
                exp.embed_paths if model_options["use_embeddings"]
                else exp.image_paths)
            unique_paths = np.unique(subset_paths)

            # batch images/base embeddings for faster embedding inference
            batch_size = model_options["batch_size"]
            path_ds = (
                tf.data.Dataset.from_tensor_slices(unique_paths)
                .batch(batch_size)
                .prefetch(tf.data.experimental.AUTOTUNE))

            # number of batches, used for the progress bar total
            num_batches = int(np.ceil(len(unique_paths) / batch_size))

            start_time = time.time()
            paths, embeddings = [], []
            for path_batch in tqdm(path_ds, total=num_batches):
                batch_embeddings = embedding_model.predict(
                    data_preprocess_func(path_batch))

                paths.extend(path_batch.numpy())
                embeddings.extend(batch_embeddings.numpy())
            elapsed = time.time() - start_time

            logging.log(
                logging.INFO, f"Computed embeddings for {data} {subset} in "
                f"{elapsed:.4f} seconds")

            # serialize and write each embedding to its own TFRecord file
            for path, embedding in zip(paths, embeddings):
                example_proto = dataset.embedding_to_example_protobuf(
                    embedding)

                # drop any ".tfrecord" ext and directories, keep the file stem
                source_path = path.decode("utf-8").split(".tfrecord")[0]
                stem = os.path.basename(source_path)
                record_path = os.path.join(output_embed_dir,
                                           f"{stem}.tfrecord")

                with tf.io.TFRecordWriter(record_path,
                                          options="ZLIB") as writer:
                    writer.write(example_proto.SerializeToString())

            logging.log(logging.INFO,
                        f"Embeddings stored at: {output_embed_dir}")
Пример #5
0
def train(model_options,
          output_dir,
          model_file=None,
          model_step_file=None,
          tf_writer=None):
    """Create and train image similarity model for one-shot learning.

    Builds the background train/dev `tf.data` pipelines (raw images, or
    base-model embeddings read from TFRecord files when
    `model_options["use_embeddings"]` is set), creates or restores the vision
    network and its optimizer, then trains from the restored epoch up to
    `model_options["epochs"]`. After every epoch the model is validated (on
    the dev loss, and on the L-way K-shot task when
    `model_options["one_shot_validation"]` is set) and saved; the best scoring
    model is additionally saved under the name "best_model".

    Args:
        model_options: dict of model/training options. This function writes
            back into it: "input_shape" always, "base_embed_size" when
            embeddings are used, and "batch_size" (set to p * k) when
            "balanced" is set.
        output_dir: directory used for the dataset cache, the stored
            `model_options.json` and saved models; `model_file` and
            `model_step_file` are resolved relative to it.
        model_file: filename of a saved model to resume from, or None to
            create a new network.
        model_step_file: filename of the saved training state that is loaded
            together with `model_file`.
        tf_writer: optional summary writer; when given, example images/labels
            and loss/accuracy scalars are written through it.
    """

    # load embeddings from dense layer of base model
    embed_dir = None
    if model_options["use_embeddings"]:
        embed_dir = os.path.join(model_options["base_dir"], "embed", "dense")

    # load training data
    train_exp, dev_exp = dataset.create_flickr_vision_train_data(
        model_options["data"], embed_dir=embed_dir)

    # map each training keyword to its integer label
    train_labels = []
    for keyword in train_exp.keywords_set[3]:
        label = train_exp.keyword_labels[keyword]
        train_labels.append(label)
    train_labels = np.asarray(train_labels)

    # NOTE(review): dev keywords are looked up in `train_exp.keyword_labels`,
    # not `dev_exp` -- presumably the train mapping covers the dev keywords;
    # confirm against FlickrVision.
    dev_labels = []
    for keyword in dev_exp.keywords_set[3]:
        label = train_exp.keyword_labels[keyword]
        dev_labels.append(label)
    dev_labels = np.asarray(dev_labels)

    # define preprocessing for images if no base model embeddings specified
    if model_options["base_dir"] is None or not model_options["use_embeddings"]:
        train_paths = train_exp.image_paths

        preprocess_data_func = functools.partial(
            dataset.load_and_preprocess_image,
            crop_size=model_options["crop_size"],
            augment_crop=model_options["augment_train"],
            random_scales=model_options["random_scales"],
            horizontal_flip=model_options["horizontal_flip"],
            colour=model_options["colour"])

    # otherwise define preprocessing for base model embeddings
    else:
        train_paths = train_exp.embed_paths

        preprocess_data_func = lambda example: dataset.parse_embedding_protobuf(
            example)["embed"]

    # create balanced batch training dataset pipeline
    if model_options["balanced"]:
        assert model_options["p"] is not None
        assert model_options["k"] is not None

        shuffle_train = False
        prefetch_train = False
        num_repeat = model_options["num_batches"]
        # batch size is overwritten in the shared options dict (p * k)
        model_options["batch_size"] = model_options["p"] * model_options["k"]

        # get unique path train indices per unique label
        train_labels_series = pd.Series(train_labels)
        train_label_idx = {
            label: idx.values[np.unique(train_paths[idx.values],
                                        return_index=True)[1]]
            for label, idx in train_labels_series.groupby(
                train_labels_series).groups.items()
        }

        # cache paths to speed things up a little ...
        file_io.check_create_dir(os.path.join(output_dir, "cache"))

        # create a dataset for each unique keyword label (shuffled and cached)
        train_label_datasets = [
            tf.data.Dataset.zip(
                (tf.data.Dataset.from_tensor_slices(train_paths[idx]),
                 tf.data.Dataset.from_tensor_slices(train_labels[idx]))).cache(
                     os.path.join(output_dir, "cache",
                                  str(label))).shuffle(20)  # len(idx)
            for label, idx in train_label_idx.items()
        ]

        # create a dataset that samples balanced batches from the label datasets
        background_train_ds = dataset.create_balanced_batch_dataset(
            model_options["p"], model_options["k"], train_label_datasets)

    # create standard training dataset pipeline (shuffle and load training set)
    else:
        shuffle_train = True
        prefetch_train = True
        num_repeat = model_options["num_augment"]

        background_train_ds = tf.data.Dataset.zip(
            (tf.data.Dataset.from_tensor_slices(train_paths),
             tf.data.Dataset.from_tensor_slices(train_labels)))

    # load embedding TFRecords (faster here than before balanced sampling)
    if model_options["use_embeddings"]:
        # batch to read files in parallel
        background_train_ds = background_train_ds.batch(
            model_options["batch_size"])

        # each batch of paths becomes a stream of (record, label) pairs again
        background_train_ds = background_train_ds.flat_map(
            lambda paths, labels: tf.data.Dataset.zip((tf.data.TFRecordDataset(
                paths, compression_type="ZLIB", num_parallel_reads=8
            ), tf.data.Dataset.from_tensor_slices(labels))))

    # map data preprocessing function across training data
    background_train_ds = background_train_ds.map(
        lambda data, label: (preprocess_data_func(data), label),
        num_parallel_calls=tf.data.experimental.AUTOTUNE)

    # repeat augmentation, shuffle and batch train data
    if num_repeat is not None:
        background_train_ds = background_train_ds.repeat(num_repeat)

    if shuffle_train:
        background_train_ds = background_train_ds.shuffle(1000)

    background_train_ds = background_train_ds.batch(
        model_options["batch_size"])

    if prefetch_train:
        background_train_ds = background_train_ds.prefetch(
            tf.data.experimental.AUTOTUNE)

    # create dev set pipeline for siamese validation
    if model_options["use_embeddings"]:
        background_dev_ds = tf.data.Dataset.zip(
            (tf.data.TFRecordDataset(
                dev_exp.embed_paths,
                compression_type="ZLIB",
                num_parallel_reads=8).map(preprocess_data_func),
             tf.data.Dataset.from_tensor_slices(dev_labels)))
    else:
        background_dev_ds = tf.data.Dataset.zip(
            (tf.data.Dataset.from_tensor_slices(
                dev_exp.image_paths).map(preprocess_data_func),
             tf.data.Dataset.from_tensor_slices(dev_labels)))

    background_dev_ds = background_dev_ds.batch(
        batch_size=model_options["batch_size"])

    # write example batch to TensorBoard
    if tf_writer is not None and not model_options["use_embeddings"]:
        logging.log(logging.INFO, "Writing example images to TensorBoard")
        with tf_writer.as_default():
            for x_batch, y_batch in background_train_ds.take(1):
                # (x + 1) / 2: presumably rescales images from [-1, 1] to
                # [0, 1] for display -- confirm against the image preprocessing
                tf.summary.image("Example train images", (x_batch + 1) / 2,
                                 max_outputs=30,
                                 step=0)
                labels = ""
                for i, label in enumerate(y_batch[:30]):
                    labels += f"{i}: {np.asarray(train_exp.keywords)[label]} "
                tf.summary.text("Example train labels", labels, step=0)

    # get training objective
    triplet_loss = get_training_objective(model_options)

    # get model input shape
    if model_options["use_embeddings"]:
        # infer the base embedding size from the first training batch
        for x_batch, _ in background_train_ds.take(1):
            model_options["base_embed_size"] = int(
                tf.shape(x_batch)[1].numpy())

        model_options["input_shape"] = [model_options["base_embed_size"]]
    else:
        model_options["input_shape"] = [
            model_options["crop_size"], model_options["crop_size"], 3
        ]

    # load or create model
    if model_file is not None:
        vision_network, train_state = model_utils.load_model(
            model_file=os.path.join(output_dir, model_file),
            model_step_file=os.path.join(output_dir, model_step_file),
            loss=triplet_loss)

        # get previous tracking variables
        initial_model = False
        global_step, model_epochs, _, best_val_score = train_state
    else:
        vision_network = create_vision_network(model_options)

        # create tracking variables
        initial_model = True
        global_step = 0
        model_epochs = 0

        # one-shot accuracy is maximised, the siamese dev loss is minimised
        if model_options["one_shot_validation"]:
            best_val_score = -np.inf
        else:
            best_val_score = np.inf

    # load or create Adam optimizer with decayed learning rate
    lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
        model_options["learning_rate"],
        decay_rate=model_options["decay_rate"],
        decay_steps=model_options["decay_steps"],
        staircase=True)

    if model_file is not None:
        logging.log(logging.INFO, "Restoring optimizer state")
        optimizer = vision_network.optimizer
    else:
        optimizer = tf.keras.optimizers.Adam(learning_rate=lr_schedule)

    # compile model to store optimizer with model when saving
    vision_network.compile(optimizer=optimizer, loss=triplet_loss)

    # create few-shot model from vision network for background training
    vision_few_shot_model = base.BaseModel(vision_network, triplet_loss)

    # test model on one-shot validation task prior to training
    if model_options["one_shot_validation"]:

        one_shot_dev_exp = flickr_vision.FlickrVision(
            keywords_split="background_dev",
            flickr8k_image_dir=os.path.join("data", "external",
                                            "flickr8k_images"),
            flickr30k_image_dir=os.path.join("data", "external",
                                             "flickr30k_images"),
            mscoco_image_dir=os.path.join("data", "external", "mscoco",
                                          "val2017"),
            preprocess_func=get_data_preprocess_func(model_options),
            embed_dir=embed_dir)

        # `embedding_model_func` and `classification` are only defined in this
        # branch; the per-epoch validation below reuses them under the same
        # `one_shot_validation` condition
        embedding_model_func = create_embedding_model

        classification = False
        if FLAGS.classification:
            assert FLAGS.fine_tune_steps is not None
            classification = True

        # create few-shot model from vision network for one-shot validation
        if FLAGS.fine_tune_steps is not None:
            test_few_shot_model = create_fine_tune_model(
                model_options,
                vision_few_shot_model.model,
                num_classes=FLAGS.L)
        else:
            test_few_shot_model = base.BaseModel(vision_few_shot_model.model,
                                                 None,
                                                 mc_dropout=FLAGS.mc_dropout)

        val_task_accuracy, _, conf_interval_95 = experiment.test_l_way_k_shot(
            one_shot_dev_exp,
            FLAGS.K,
            FLAGS.L,
            n=FLAGS.N,
            num_episodes=FLAGS.episodes,
            k_neighbours=FLAGS.k_neighbours,
            metric=FLAGS.metric,
            classification=classification,
            model=test_few_shot_model,
            embedding_model_func=embedding_model_func,
            fine_tune_steps=FLAGS.fine_tune_steps,
            fine_tune_lr=FLAGS.fine_tune_lr)

        logging.log(
            logging.INFO,
            f"Base model: {FLAGS.L}-way {FLAGS.K}-shot accuracy after "
            f"{FLAGS.episodes} episodes: {val_task_accuracy:.3%} +- "
            f"{conf_interval_95*100:.4f}")

    # create training metrics
    loss_metric = tf.keras.metrics.Mean()
    best_model = False

    # store model options on first run
    if initial_model:
        file_io.write_json(os.path.join(output_dir, "model_options.json"),
                           model_options)

        # also store initial model for probing and things
        model_utils.save_model(vision_few_shot_model.model,
                               output_dir,
                               0,
                               0,
                               "not tested",
                               0.,
                               0.,
                               name="initial_model")

    # train model (resumes from the restored epoch count when a model is loaded)
    for epoch in range(model_epochs, model_options["epochs"]):
        logging.log(logging.INFO, f"Epoch {epoch:03d}")

        loss_metric.reset_states()

        # train on epoch of training data
        step_pbar = tqdm(background_train_ds,
                         bar_format="{desc} [{elapsed},{rate_fmt}{postfix}]")
        for step, (x_batch, y_batch) in enumerate(step_pbar):

            loss_value, y_predict = vision_few_shot_model.train_step(
                x_batch,
                y_batch,
                optimizer,
                clip_norm=model_options["gradient_clip_norm"])

            loss_metric.update_state(loss_value)

            # loss of this step vs. running mean over the epoch so far
            step_loss = tf.reduce_mean(loss_value)
            train_loss = loss_metric.result().numpy()

            step_pbar.set_description_str(f"\tStep {step:03d}: "
                                          f"Step loss: {step_loss:.6f}, "
                                          f"Loss: {train_loss:.6f}")

            if tf_writer is not None:
                with tf_writer.as_default():
                    tf.summary.scalar("Train step loss",
                                      step_loss,
                                      step=global_step)
            global_step += 1

        # validate siamese model
        loss_metric.reset_states()

        for x_batch, y_batch in background_dev_ds:
            y_predict = vision_few_shot_model.predict(x_batch, training=False)
            loss_value = vision_few_shot_model.loss(y_batch, y_predict)

            loss_metric.update_state(loss_value)

        dev_loss = loss_metric.result().numpy()

        # validate model on one-shot dev task if specified
        if model_options["one_shot_validation"]:

            # rebuild the test wrapper around the current network weights
            if FLAGS.fine_tune_steps is not None:
                test_few_shot_model = create_fine_tune_model(
                    model_options,
                    vision_few_shot_model.model,
                    num_classes=FLAGS.L)
            else:
                test_few_shot_model = base.BaseModel(
                    vision_few_shot_model.model,
                    None,
                    mc_dropout=FLAGS.mc_dropout)

            val_task_accuracy, _, conf_interval_95 = experiment.test_l_way_k_shot(
                one_shot_dev_exp,
                FLAGS.K,
                FLAGS.L,
                n=FLAGS.N,
                num_episodes=FLAGS.episodes,
                k_neighbours=FLAGS.k_neighbours,
                metric=FLAGS.metric,
                classification=classification,
                model=test_few_shot_model,
                embedding_model_func=embedding_model_func,
                fine_tune_steps=FLAGS.fine_tune_steps,
                fine_tune_lr=FLAGS.fine_tune_lr)

            val_score = val_task_accuracy
            val_metric = f"{FLAGS.L}-way {FLAGS.K}-shot accuracy"

            # higher accuracy is better; ties count as a new best
            if val_score >= best_val_score:
                best_val_score = val_score
                best_model = True

        # otherwise, validate on siamese task
        else:
            val_score = dev_loss
            val_metric = "loss"

            # lower loss is better; ties count as a new best
            if val_score <= best_val_score:
                best_val_score = val_score
                best_model = True

        # log results
        # NOTE(review): `train_loss` is only assigned inside the step loop
        # above, so an empty training dataset would leave it unbound (or stale
        # from a previous epoch) here -- confirm the dataset is never empty.
        logging.log(logging.INFO, f"Train: Loss: {train_loss:.6f}")

        logging.log(
            logging.INFO,
            f"Validation: Loss: {dev_loss:.6f} {'*' if best_model else ''}")

        if model_options["one_shot_validation"]:
            logging.log(
                logging.INFO,
                f"Validation: {FLAGS.L}-way {FLAGS.K}-shot accuracy after "
                f"{FLAGS.episodes} episodes: {val_task_accuracy:.3%} +- "
                f"{conf_interval_95*100:.4f} {'*' if best_model else ''}")

        if tf_writer is not None:
            with tf_writer.as_default():
                # NOTE(review): the epoch mean loss is written under the same
                # "Train step loss" tag as the per-step loss above -- confirm
                # this is intended rather than a separate epoch-level tag.
                tf.summary.scalar("Train step loss",
                                  train_loss,
                                  step=global_step)
                tf.summary.scalar(f"Validation loss",
                                  dev_loss,
                                  step=global_step)
                if model_options["one_shot_validation"]:
                    tf.summary.scalar(
                        f"Validation {FLAGS.L}-way {FLAGS.K}-shot accuracy",
                        val_task_accuracy,
                        step=global_step)

        # store model and results
        model_utils.save_model(vision_few_shot_model.model,
                               output_dir,
                               epoch + 1,
                               global_step,
                               val_metric,
                               val_score,
                               best_val_score,
                               name="model")

        if best_model:
            # reset the flag so the next epoch has to earn it again
            best_model = False
            model_utils.save_model(vision_few_shot_model.model,
                                   output_dir,
                                   epoch + 1,
                                   global_step,
                                   val_metric,
                                   val_score,
                                   best_val_score,
                                   name="best_model")
Пример #6
0
def _f1_score(precision, recall):
    """Return the F-1 score (harmonic mean) of `precision` and `recall`.

    Returns 0.0 when either value is zero rather than dividing by zero; this
    is the value the harmonic mean tends to in that limit.
    """
    if precision <= 0 or recall <= 0:
        return 0.0
    return 2 / ((1 / precision) + (1 / recall))


def train(model_options, output_dir, model_file=None, model_step_file=None,
          tf_writer=None):
    """Create and train image classification model for one-shot learning.

    Trains a multi-label keyword classifier on the Flickr vision background
    data. After every epoch the model is validated and saved to `output_dir`
    under the name "model"; whenever the validation score matches or beats the
    best score so far it is additionally saved as "best_model".

    The validation score is the one-shot task accuracy when
    `model_options["one_shot_validation"]` is set, and the dev F-1 otherwise.
    In both cases higher is better.

    Args:
        model_options: dict of model and training options. This function
            writes "input_shape" into it, and "n_classes" when a new model is
            created.
        output_dir: directory where models and "model_options.json" are
            stored, and relative to which `model_file` and `model_step_file`
            are resolved.
        model_file: optional file name (inside `output_dir`) of a saved model
            to resume training from. A new model is created when None.
        model_step_file: file name (inside `output_dir`) of the saved training
            state; required when `model_file` is given.
        tf_writer: optional TensorBoard summary writer; nothing is written to
            TensorBoard when None.
    """

    # load training data
    train_exp, dev_exp = dataset.create_flickr_vision_train_data(
        model_options["data"])

    # Images carry a varying number of keywords, so the labels are kept as
    # plain nested lists: a ragged `np.asarray` raises on NumPy >= 1.24 and
    # MultiLabelBinarizer accepts any iterable of iterables.
    train_labels = [
        [train_exp.keyword_labels[keyword] for keyword in image_keywords]
        for image_keywords in train_exp.unique_image_keywords]

    # dev keywords are mapped with the *train* label mapping so that both
    # sets share one label space
    dev_labels = [
        [train_exp.keyword_labels[keyword] for keyword in image_keywords]
        for image_keywords in dev_exp.unique_image_keywords]

    train_paths = train_exp.unique_image_paths
    dev_paths = dev_exp.unique_image_paths

    mlb = MultiLabelBinarizer()
    train_labels_multi_hot = mlb.fit_transform(train_labels)
    dev_labels_multi_hot = mlb.transform(dev_labels)

    # define preprocessing for images
    preprocess_images_func = functools.partial(
        dataset.load_and_preprocess_image,
        crop_size=model_options["crop_size"],
        augment_crop=model_options["augment_train"],
        random_scales=model_options["random_scales"],
        horizontal_flip=model_options["horizontal_flip"],
        colour=model_options["colour"])

    # create standard training dataset pipeline
    background_train_ds = tf.data.Dataset.zip((
        tf.data.Dataset.from_tensor_slices(train_paths),
        tf.data.Dataset.from_tensor_slices(train_labels_multi_hot)))

    # map data preprocessing function across training data
    background_train_ds = background_train_ds.map(
        lambda path, label: (preprocess_images_func(path), label),
        num_parallel_calls=tf.data.experimental.AUTOTUNE)

    # repeat augmentation, shuffle and batch train data
    if model_options["num_augment"] is not None:
        background_train_ds = background_train_ds.repeat(
            model_options["num_augment"])

    background_train_ds = background_train_ds.shuffle(1000)

    background_train_ds = background_train_ds.batch(
        model_options["batch_size"])

    background_train_ds = background_train_ds.prefetch(
        tf.data.experimental.AUTOTUNE)

    # create dev set pipeline for classification validation
    background_dev_ds = tf.data.Dataset.zip((
        tf.data.Dataset.from_tensor_slices(
            dev_paths).map(preprocess_images_func),
        tf.data.Dataset.from_tensor_slices(dev_labels_multi_hot)))

    background_dev_ds = background_dev_ds.batch(
        batch_size=model_options["batch_size"])

    # write example batch to TensorBoard
    if tf_writer is not None:
        logging.log(logging.INFO, "Writing example images to TensorBoard")
        keywords = np.asarray(train_exp.keywords)
        with tf_writer.as_default():
            for x_batch, y_batch in background_train_ds.take(1):
                tf.summary.image("Example train images", (x_batch+1)/2,
                                 max_outputs=30, step=0)
                # Each label is a multi-hot vector, so select the keywords at
                # its non-zero positions (indexing with the raw 0/1 values
                # would only ever pick keywords 0 and 1).
                # NOTE(review): assumes multi-hot column j corresponds to
                # train_exp.keywords[j] -- confirm against keyword_labels.
                labels = "".join(
                    f"{i}: {keywords[np.flatnonzero(label.numpy())]} "
                    for i, label in enumerate(y_batch[:30]))
                tf.summary.text("Example train labels", labels, step=0)

    # get training objective
    multi_label_loss = get_training_objective(model_options)

    # get model input shape
    model_options["input_shape"] = (
        model_options["crop_size"], model_options["crop_size"], 3)

    # load or create model
    if model_file is not None:
        assert model_options["n_classes"] == len(train_exp.keywords)

        vision_network, train_state = model_utils.load_model(
            model_file=os.path.join(output_dir, model_file),
            model_step_file=os.path.join(output_dir, model_step_file),
            loss=multi_label_loss)

        # get previous tracking variables
        initial_model = False
        global_step, model_epochs, _, best_val_score = train_state

        # Both validation scores are higher-is-better, so a stored +inf can
        # never be matched or beaten; treat it as "no best score yet".
        if best_val_score == np.inf:
            best_val_score = -np.inf
    else:
        model_options["n_classes"] = len(train_exp.keywords)

        vision_network = create_vision_network(model_options)

        # create tracking variables
        initial_model = True
        global_step = 0
        model_epochs = 0

        # One-shot accuracy and dev F-1 are both maximised (see the `>=`
        # comparison in the epoch loop), so start below any attainable score.
        best_val_score = -np.inf

    # load or create Adam optimizer with decayed learning rate
    lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
        model_options["learning_rate"], decay_rate=model_options["decay_rate"],
        decay_steps=model_options["decay_steps"], staircase=True)

    if model_file is not None:
        logging.log(logging.INFO, "Restoring optimizer state")
        optimizer = vision_network.optimizer
    else:
        optimizer = tf.keras.optimizers.Adam(learning_rate=lr_schedule)

    # compile model to store optimizer with model when saving
    vision_network.compile(optimizer=optimizer, loss=multi_label_loss)

    # create few-shot model from vision network for background training
    vision_few_shot_model = base.BaseModel(vision_network, multi_label_loss)

    # test model on one-shot validation task prior to training
    if model_options["one_shot_validation"]:

        one_shot_dev_exp = flickr_vision.FlickrVision(
            keywords_split="background_dev",
            flickr8k_image_dir=os.path.join(
                "data", "external", "flickr8k_images"),
            flickr30k_image_dir=os.path.join(
                "data", "external", "flickr30k_images"),
            mscoco_image_dir=os.path.join(
                "data", "external", "mscoco", "val2017"),
            preprocess_func=get_data_preprocess_func(model_options))

        embedding_model_func = lambda vision_network: create_embedding_model(
            model_options, vision_network)

        classification = False
        if FLAGS.classification:
            assert FLAGS.embed_layer in ["logits", "softmax"]
            classification = True

        def run_one_shot_validation():
            """Run the L-way K-shot dev task on the network being trained.

            A fresh few-shot wrapper is built around
            `vision_few_shot_model.model` on every call. Returns the task
            accuracy and its 95% confidence interval.
            """
            # create few-shot model from vision network for one-shot validation
            if FLAGS.fine_tune_steps is not None:
                test_few_shot_model = create_fine_tune_model(
                    model_options, vision_few_shot_model.model,
                    num_classes=FLAGS.L)
            else:
                test_few_shot_model = base.BaseModel(
                    vision_few_shot_model.model, None,
                    mc_dropout=FLAGS.mc_dropout)

            task_accuracy, _, conf_interval = experiment.test_l_way_k_shot(
                one_shot_dev_exp, FLAGS.K, FLAGS.L, n=FLAGS.N,
                num_episodes=FLAGS.episodes, k_neighbours=FLAGS.k_neighbours,
                metric=FLAGS.metric, classification=classification,
                model=test_few_shot_model,
                embedding_model_func=embedding_model_func,
                fine_tune_steps=FLAGS.fine_tune_steps,
                fine_tune_lr=FLAGS.fine_tune_lr)

            return task_accuracy, conf_interval

        val_task_accuracy, conf_interval_95 = run_one_shot_validation()

        logging.log(
            logging.INFO,
            f"Base model: {FLAGS.L}-way {FLAGS.K}-shot accuracy after "
            f"{FLAGS.episodes} episodes: {val_task_accuracy:.3%} +- "
            f"{conf_interval_95*100:.4f}")

    # create training metrics
    precision_metric = tf.keras.metrics.Precision()
    recall_metric = tf.keras.metrics.Recall()
    loss_metric = tf.keras.metrics.Mean()
    best_model = False

    # store model options on first run (after "input_shape" and "n_classes"
    # have been filled in above)
    if initial_model:
        file_io.write_json(
            os.path.join(output_dir, "model_options.json"), model_options)

    # train model
    for epoch in range(model_epochs, model_options["epochs"]):
        logging.log(logging.INFO, f"Epoch {epoch:03d}")

        precision_metric.reset_states()
        recall_metric.reset_states()
        loss_metric.reset_states()

        # train on epoch of training data
        step_pbar = tqdm(background_train_ds,
                         bar_format="{desc} [{elapsed},{rate_fmt}{postfix}]")
        for step, (x_batch, y_batch) in enumerate(step_pbar):

            loss_value, y_predict = vision_few_shot_model.train_step(
                x_batch, y_batch, optimizer,
                clip_norm=model_options["gradient_clip_norm"])

            # threshold the sigmoid outputs at 0.5 to get multi-hot predictions
            y_one_hot_predict = tf.round(tf.nn.sigmoid(y_predict))

            precision_metric.update_state(y_batch, y_one_hot_predict)
            recall_metric.update_state(y_batch, y_one_hot_predict)
            loss_metric.update_state(loss_value)

            # running (epoch-so-far) metrics alongside the current step loss
            step_loss = tf.reduce_mean(loss_value)
            train_loss = loss_metric.result().numpy()
            train_precision = precision_metric.result().numpy()
            train_recall = recall_metric.result().numpy()
            train_f1 = _f1_score(train_precision, train_recall)

            step_pbar.set_description_str(
                f"\tStep {step:03d}: "
                f"Step loss: {step_loss:.6f}, "
                f"Loss: {train_loss:.6f}, "
                f"Precision: {train_precision:.3%}, "
                f"Recall: {train_recall:.3%}, "
                f"F-1: {train_f1:.3%}")

            if tf_writer is not None:
                with tf_writer.as_default():
                    tf.summary.scalar(
                        "Train step loss", step_loss, step=global_step)
            global_step += 1

        # validate classification model (metrics are reused, so reset first)
        precision_metric.reset_states()
        recall_metric.reset_states()
        loss_metric.reset_states()

        for x_batch, y_batch in background_dev_ds:
            y_predict = vision_few_shot_model.predict(x_batch, training=False)
            loss_value = vision_few_shot_model.loss(y_batch, y_predict)

            y_one_hot_predict = tf.round(tf.nn.sigmoid(y_predict))

            precision_metric.update_state(y_batch, y_one_hot_predict)
            recall_metric.update_state(y_batch, y_one_hot_predict)
            loss_metric.update_state(loss_value)

        dev_loss = loss_metric.result().numpy()
        dev_precision = precision_metric.result().numpy()
        dev_recall = recall_metric.result().numpy()
        dev_f1 = _f1_score(dev_precision, dev_recall)

        # validate model on one-shot dev task if specified
        if model_options["one_shot_validation"]:
            val_task_accuracy, conf_interval_95 = run_one_shot_validation()

            val_score = val_task_accuracy
            val_metric = f"{FLAGS.L}-way {FLAGS.K}-shot accuracy"

        # otherwise, validate on classification task
        else:
            val_score = dev_f1
            val_metric = "F-1"

        # higher is better for both metrics; ties count as a new best
        if val_score >= best_val_score:
            best_val_score = val_score
            best_model = True

        # log results
        logging.log(
            logging.INFO,
            f"Train: Loss: {train_loss:.6f}, Precision: {train_precision:.3%}, "
            f"Recall: {train_recall:.3%}, F-1: {train_f1:.3%}")

        logging.log(
            logging.INFO,
            f"Validation: Loss: {dev_loss:.6f}, Precision: "
            f"{dev_precision:.3%}, Recall: {dev_recall:.3%}, F-1: "
            f"{dev_f1:.3%} {'*' if best_model else ''}")

        if model_options["one_shot_validation"]:
            logging.log(
                logging.INFO,
                f"Validation: {FLAGS.L}-way {FLAGS.K}-shot accuracy after "
                f"{FLAGS.episodes} episodes: {val_task_accuracy:.3%} +- "
                f"{conf_interval_95*100:.4f} {'*' if best_model else ''}")

        if tf_writer is not None:
            with tf_writer.as_default():
                # NOTE(review): this epoch-average loss is written under the
                # same "Train step loss" tag as the per-step loss above.
                tf.summary.scalar(
                    "Train step loss", train_loss, step=global_step)
                tf.summary.scalar(
                    "Train precision", train_precision, step=global_step)
                tf.summary.scalar(
                    "Train recall", train_recall, step=global_step)
                tf.summary.scalar(
                    "Train F-1", train_f1, step=global_step)
                tf.summary.scalar(
                    "Validation loss", dev_loss, step=global_step)
                tf.summary.scalar(
                    "Validation precision", dev_precision, step=global_step)
                tf.summary.scalar(
                    "Validation recall", dev_recall, step=global_step)
                tf.summary.scalar(
                    "Validation F-1", dev_f1, step=global_step)
                if model_options["one_shot_validation"]:
                    tf.summary.scalar(
                        f"Validation {FLAGS.L}-way {FLAGS.K}-shot accuracy",
                        val_task_accuracy, step=global_step)

        # store model and results
        model_utils.save_model(
            vision_few_shot_model.model, output_dir, epoch + 1, global_step,
            val_metric, val_score, best_val_score, name="model")

        if best_model:
            # reset the flag so the next epoch has to earn it again
            best_model = False

            model_utils.save_model(
                vision_few_shot_model.model, output_dir, epoch + 1, global_step,
                val_metric, val_score, best_val_score, name="best_model")