def test(model_options, output_dir, model_file, model_step_file):
    """Evaluate a stored spoken word classification model on one-shot speech.

    Loads the Flickr Audio "one_shot_evaluation" split and the model stored
    under `output_dir`, wraps the network in a few-shot model and logs the
    L-way K-shot accuracy returned by `experiment.test_l_way_k_shot`.

    Args:
        model_options: model options dict; "features" is read here and the
            dict is passed on to the preprocessing, loss and model helpers.
        output_dir: directory that `model_file` and `model_step_file` are
            joined onto.
        model_file: file name of the stored model.
        model_step_file: file name of the stored training step state.
    """
    # episodes are sampled from the Flickr Audio one-shot evaluation split
    one_shot_exp = flickr_speech.FlickrSpeech(
        features=model_options["features"],
        keywords_split="one_shot_evaluation",
        preprocess_func=get_data_preprocess_func(model_options),
        speaker_mode=FLAGS.speaker_mode)

    # restore the trained network (the stored training state is not used)
    model_path = os.path.join(output_dir, model_file)
    step_path = os.path.join(output_dir, model_step_file)
    speech_network, _ = model_utils.load_model(
        model_file=model_path,
        model_step_file=step_path,
        loss=get_training_objective(model_options))

    def embedding_model_func(speech_network):
        # bind the model options so the experiment only passes a network
        return create_embedding_model(model_options, speech_network)

    # wrap the network as-is unless fine-tuning steps were requested
    if FLAGS.fine_tune_steps is None:
        test_few_shot_model = base.BaseModel(
            speech_network, None, mc_dropout=FLAGS.mc_dropout)
    else:
        test_few_shot_model = create_fine_tune_model(
            model_options, speech_network, num_classes=FLAGS.L)

    # classification is only accepted on the "logits"/"softmax" embed layers
    classification = bool(FLAGS.classification)
    if classification:
        assert FLAGS.embed_layer in ("logits", "softmax")

    logging.log(logging.INFO, "Created few-shot model from speech network")
    test_few_shot_model.model.summary()

    # run the L-way K-shot episodes
    task_accuracy, _, conf_interval_95 = experiment.test_l_way_k_shot(
        one_shot_exp,
        FLAGS.K,
        FLAGS.L,
        n=FLAGS.N,
        num_episodes=FLAGS.episodes,
        k_neighbours=FLAGS.k_neighbours,
        metric=FLAGS.metric,
        classification=classification,
        model=test_few_shot_model,
        embedding_model_func=embedding_model_func,
        fine_tune_steps=FLAGS.fine_tune_steps,
        fine_tune_lr=FLAGS.fine_tune_lr)

    logging.log(
        logging.INFO,
        f"{FLAGS.L}-way {FLAGS.K}-shot accuracy after {FLAGS.episodes} "
        f"episodes: {task_accuracy:.3%} +- {conf_interval_95*100:.4f}")
def embed(model_options, output_dir, model_file, model_step_file):
    """Extract siamese spoken word model embeddings for Flickr Audio.

    Loads the stored model, embeds every unique base embedding path of the
    "one_shot_evaluation", "background_train" and "background_dev" splits and
    writes one ZLIB compressed TFRecord per path under
    `<output_dir>/embed/dense/flickr_audio/<split>`.

    Args:
        model_options: model options dict; "base_dir" and "batch_size" are
            read here and the dict is passed to `get_training_objective`.
        output_dir: directory holding the stored model; embeddings are
            written below it.
        model_file: file name of the stored model.
        model_step_file: file name of the stored training step state.
    """
    # inputs are the embeddings from the dense layer of the base model
    base_embed_dir = os.path.join(model_options["base_dir"], "embed", "dense")

    # restore the trained network (the stored training state is not used)
    speech_network, _ = model_utils.load_model(
        model_file=os.path.join(output_dir, model_file),
        model_step_file=os.path.join(output_dir, model_step_file),
        loss=get_training_objective(model_options))

    # embedding model and the matching input preprocessing
    embedding_model = create_embedding_model(speech_network)
    data_preprocess_func = get_data_preprocess_func()

    # one Flickr Audio experiment per keyword split, keyed by split name
    subset_exp = {
        split: flickr_speech.FlickrSpeech(
            features="mfcc", keywords_split=split, embed_dir=base_embed_dir)
        for split in (
            "one_shot_evaluation", "background_train", "background_dev")}

    for subset, exp in subset_exp.items():
        output_embed_dir = os.path.join(
            output_dir, "embed", "dense", "flickr_audio", subset)
        file_io.check_create_dir(output_embed_dir)

        unique_paths = np.unique(exp.embed_paths)

        # batch base embeddings for faster embedding inference
        path_ds = (
            tf.data.Dataset.from_tensor_slices(unique_paths)
            .batch(model_options["batch_size"])
            .prefetch(tf.data.experimental.AUTOTUNE))

        # number of batches, used as the progress bar total
        num_batches = int(
            np.ceil(len(unique_paths) / model_options["batch_size"]))

        start_time = time.time()
        paths, embeddings = [], []
        for path_batch in tqdm(path_ds, total=num_batches):
            batch_embeddings = embedding_model.predict(
                data_preprocess_func(path_batch))
            paths.extend(path_batch.numpy())
            embeddings.extend(batch_embeddings.numpy())
        elapsed = time.time() - start_time

        logging.log(
            logging.INFO,
            f"Computed embeddings for Flickr Audio {subset} in "
            f"{elapsed:.4f} seconds")

        # serialize and write embeddings to TFRecord files
        for source_path, embedding in zip(paths, embeddings):
            example_proto = dataset.embedding_to_example_protobuf(embedding)

            # keep the source file name, dropping any ".tfrecord" extension
            stem = source_path.decode("utf-8").split(".tfrecord")[0]
            record_path = os.path.join(
                output_embed_dir, f"{os.path.basename(stem)}.tfrecord")

            with tf.io.TFRecordWriter(record_path, options="ZLIB") as writer:
                writer.write(example_proto.SerializeToString())

        logging.log(
            logging.INFO, f"Embeddings stored at: {output_embed_dir}")
def train(model_options, output_dir, model_file=None, model_step_file=None,
          tf_writer=None):
    """Train the siamese spoken word similarity model on base embeddings.

    Builds training and dev `tf.data` pipelines over the Flickr Audio base
    model embeddings, creates or restores the speech network, then trains for
    the configured number of epochs, validating and saving the model after
    every epoch.

    Args:
        model_options: model options dict. It is updated in place
            ("batch_size" for balanced batches, "base_embed_size" and
            "input_shape") and written to "model_options.json" on a first
            run.
        output_dir: directory for the path cache, model options and models.
        model_file: file name of a stored model to resume from, or None to
            create a new model.
        model_step_file: file name of the stored training state; only used
            when `model_file` is given.
        tf_writer: optional summary writer for TensorBoard scalars.
    """
    autotune = tf.data.experimental.AUTOTUNE

    # inputs are the embeddings from the dense layer of the base model
    embed_dir = os.path.join(model_options["base_dir"], "embed", "dense")

    # load training data (embed dir determines if mfcc/fbank)
    train_exp, dev_exp = dataset.create_flickr_audio_train_data(
        "mfcc", embed_dir=embed_dir, speaker_mode=FLAGS.speaker_mode)

    # both splits are labelled through the training experiment's mapping
    train_labels = np.asarray([
        train_exp.keyword_labels[keyword]
        for keyword in train_exp.keywords_set[3]])
    dev_labels = np.asarray([
        train_exp.keyword_labels[keyword]
        for keyword in dev_exp.keywords_set[3]])

    train_paths = train_exp.embed_paths
    dev_paths = dev_exp.embed_paths

    def preprocess_data_func(example):
        # parse a serialized base embedding and keep its "embed" entry
        return dataset.parse_embedding_protobuf(example)["embed"]

    def read_embedding_records(paths, labels):
        # read a batch of embedding files in parallel alongside their labels
        return tf.data.Dataset.zip((
            tf.data.TFRecordDataset(
                paths, compression_type="ZLIB", num_parallel_reads=8),
            tf.data.Dataset.from_tensor_slices(labels)))

    if model_options["balanced"]:
        # balanced batch pipeline: batches come from per-label datasets
        assert model_options["p"] is not None
        assert model_options["k"] is not None

        shuffle_train, prefetch_train = False, False
        num_repeat = model_options["num_batches"]
        model_options["batch_size"] = model_options["p"] * model_options["k"]

        # get unique path train indices per unique label
        label_series = pd.Series(train_labels)
        train_label_idx = {}
        for label, idx in label_series.groupby(label_series).groups.items():
            _, first_occurrence = np.unique(
                train_paths[idx.values], return_index=True)
            train_label_idx[label] = idx.values[first_occurrence]

        # cache paths to speed things up a little ...
        file_io.check_create_dir(os.path.join(output_dir, "cache"))

        # create a dataset for each unique keyword label (shuffled and cached)
        train_label_datasets = []
        for label, idx in train_label_idx.items():
            label_ds = tf.data.Dataset.zip((
                tf.data.Dataset.from_tensor_slices(train_paths[idx]),
                tf.data.Dataset.from_tensor_slices(train_labels[idx])))
            label_ds = label_ds.cache(
                os.path.join(output_dir, "cache", str(label)))
            train_label_datasets.append(label_ds.shuffle(20))  # len(idx)

        # create a dataset that samples balanced batches from label datasets
        background_train_ds = dataset.create_balanced_batch_dataset(
            model_options["p"], model_options["k"], train_label_datasets)
    else:
        # standard pipeline: shuffle and load the whole training set
        shuffle_train, prefetch_train = True, True
        num_repeat = model_options["num_augment"]

        background_train_ds = tf.data.Dataset.zip((
            tf.data.Dataset.from_tensor_slices(train_paths),
            tf.data.Dataset.from_tensor_slices(train_labels)))

    # load embedding TFRecords (faster here than before balanced sampling);
    # paths are batched first so that files are read in parallel
    background_train_ds = background_train_ds.batch(
        model_options["batch_size"])
    background_train_ds = background_train_ds.flat_map(read_embedding_records)

    # map data preprocessing function across training data
    background_train_ds = background_train_ds.map(
        lambda data, label: (preprocess_data_func(data), label),
        num_parallel_calls=autotune)

    # repeat augmentation, shuffle and batch train data
    if num_repeat is not None:
        background_train_ds = background_train_ds.repeat(num_repeat)
    if shuffle_train:
        background_train_ds = background_train_ds.shuffle(1000)
    background_train_ds = background_train_ds.batch(
        model_options["batch_size"])
    if prefetch_train:
        background_train_ds = background_train_ds.prefetch(autotune)

    # create dev set pipeline for siamese validation
    dev_embed_ds = tf.data.TFRecordDataset(
        dev_paths, compression_type="ZLIB",
        num_parallel_reads=8).map(preprocess_data_func)
    background_dev_ds = tf.data.Dataset.zip((
        dev_embed_ds, tf.data.Dataset.from_tensor_slices(dev_labels)))
    background_dev_ds = background_dev_ds.batch(
        batch_size=model_options["batch_size"])

    # get training objective
    triplet_loss = get_training_objective(model_options)

    # take the model input shape from the first training batch
    for embed_batch, _ in background_train_ds.take(1):
        model_options["base_embed_size"] = int(
            tf.shape(embed_batch)[1].numpy())
        model_options["input_shape"] = [model_options["base_embed_size"]]

    # load or create model together with its tracking variables
    resume = model_file is not None
    if resume:
        speech_network, train_state = model_utils.load_model(
            model_file=os.path.join(output_dir, model_file),
            model_step_file=os.path.join(output_dir, model_step_file),
            loss=triplet_loss)
        initial_model = False
        global_step, model_epochs, _, best_val_score = train_state
    else:
        speech_network = create_speech_network(model_options)
        initial_model = True
        global_step, model_epochs = 0, 0
        # accuracy is maximised for one-shot validation, loss is minimised
        best_val_score = (
            -np.inf if model_options["one_shot_validation"] else np.inf)

    # load or create Adam optimizer with decayed learning rate
    lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
        model_options["learning_rate"],
        decay_rate=model_options["decay_rate"],
        decay_steps=model_options["decay_steps"],
        staircase=True)

    if resume:
        logging.log(logging.INFO, "Restoring optimizer state")
        optimizer = speech_network.optimizer
    else:
        optimizer = tf.keras.optimizers.Adam(learning_rate=lr_schedule)

    # compile model to store optimizer with model when saving
    speech_network.compile(optimizer=optimizer, loss=triplet_loss)

    # create few-shot model from speech network for background training
    speech_few_shot_model = base.BaseModel(speech_network, triplet_loss)

    def run_one_shot_validation():
        # build a fresh few-shot model around the current network and score
        # it on the dev task; returns (accuracy, 95% confidence interval)
        if FLAGS.fine_tune_steps is None:
            test_few_shot_model = base.BaseModel(
                speech_few_shot_model.model, None,
                mc_dropout=FLAGS.mc_dropout)
        else:
            test_few_shot_model = create_fine_tune_model(
                model_options, speech_few_shot_model.model,
                num_classes=FLAGS.L)

        accuracy, _, interval = experiment.test_l_way_k_shot(
            one_shot_dev_exp,
            FLAGS.K,
            FLAGS.L,
            n=FLAGS.N,
            num_episodes=FLAGS.episodes,
            k_neighbours=FLAGS.k_neighbours,
            metric=FLAGS.metric,
            classification=classification,
            model=test_few_shot_model,
            embedding_model_func=embedding_model_func,
            fine_tune_steps=FLAGS.fine_tune_steps,
            fine_tune_lr=FLAGS.fine_tune_lr)
        return accuracy, interval

    # test model on one-shot validation task prior to training
    if model_options["one_shot_validation"]:
        one_shot_dev_exp = flickr_speech.FlickrSpeech(
            features="mfcc",
            keywords_split="background_dev",
            preprocess_func=get_data_preprocess_func(),
            embed_dir=embed_dir,
            speaker_mode=FLAGS.speaker_mode)

        embedding_model_func = create_embedding_model

        # classification is only accepted together with fine-tuning
        classification = bool(FLAGS.classification)
        if classification:
            assert FLAGS.fine_tune_steps is not None

        val_task_accuracy, conf_interval_95 = run_one_shot_validation()

        logging.log(
            logging.INFO,
            f"Base model: {FLAGS.L}-way {FLAGS.K}-shot accuracy after "
            f"{FLAGS.episodes} episodes: {val_task_accuracy:.3%} +- "
            f"{conf_interval_95*100:.4f}")

    # create training metrics
    loss_metric = tf.keras.metrics.Mean()
    best_model = False

    # store model options on first run
    if initial_model:
        file_io.write_json(
            os.path.join(output_dir, "model_options.json"), model_options)

    # train model
    for epoch in range(model_epochs, model_options["epochs"]):
        logging.log(logging.INFO, f"Epoch {epoch:03d}")

        loss_metric.reset_states()

        # train on epoch of training data
        step_pbar = tqdm(
            background_train_ds,
            bar_format="{desc} [{elapsed},{rate_fmt}{postfix}]")
        for step, (x_batch, y_batch) in enumerate(step_pbar):
            loss_value, y_predict = speech_few_shot_model.train_step(
                x_batch, y_batch, optimizer,
                clip_norm=model_options["gradient_clip_norm"])

            loss_metric.update_state(loss_value)

            step_loss = tf.reduce_mean(loss_value)
            train_loss = loss_metric.result().numpy()

            step_pbar.set_description_str(
                f"\tStep {step:03d}: "
                f"Step loss: {step_loss:.6f}, "
                f"Loss: {train_loss:.6f}")

            if tf_writer is not None:
                with tf_writer.as_default():
                    tf.summary.scalar(
                        "Train step loss", step_loss, step=global_step)
            global_step += 1

        # validate siamese model
        loss_metric.reset_states()
        for x_batch, y_batch in background_dev_ds:
            y_predict = speech_few_shot_model.predict(x_batch, training=False)
            loss_value = speech_few_shot_model.loss(y_batch, y_predict)
            loss_metric.update_state(loss_value)

        dev_loss = loss_metric.result().numpy()

        if model_options["one_shot_validation"]:
            # validate model on one-shot dev task (higher is better)
            val_task_accuracy, conf_interval_95 = run_one_shot_validation()
            val_score = val_task_accuracy
            val_metric = f"{FLAGS.L}-way {FLAGS.K}-shot accuracy"
            improved = val_score >= best_val_score
        else:
            # otherwise, validate on siamese task (lower is better)
            val_score = dev_loss
            val_metric = "loss"
            improved = val_score <= best_val_score

        if improved:
            best_val_score = val_score
            best_model = True

        # log results
        logging.log(logging.INFO, f"Train: Loss: {train_loss:.6f}")
        logging.log(
            logging.INFO,
            f"Validation: Loss: {dev_loss:.6f} {'*' if best_model else ''}")
        if model_options["one_shot_validation"]:
            logging.log(
                logging.INFO,
                f"Validation: {FLAGS.L}-way {FLAGS.K}-shot accuracy after "
                f"{FLAGS.episodes} episodes: {val_task_accuracy:.3%} +- "
                f"{conf_interval_95*100:.4f} {'*' if best_model else ''}")

        if tf_writer is not None:
            with tf_writer.as_default():
                tf.summary.scalar(
                    "Train step loss", train_loss, step=global_step)
                tf.summary.scalar(
                    "Validation loss", dev_loss, step=global_step)
                if model_options["one_shot_validation"]:
                    tf.summary.scalar(
                        f"Validation {FLAGS.L}-way {FLAGS.K}-shot accuracy",
                        val_task_accuracy, step=global_step)

        # store the latest model, plus a copy when it is the best so far
        checkpoint_names = ["model"]
        if best_model:
            best_model = False
            checkpoint_names.append("best_model")

        for checkpoint_name in checkpoint_names:
            model_utils.save_model(
                speech_few_shot_model.model, output_dir, epoch + 1,
                global_step, val_metric, val_score, best_val_score,
                name=checkpoint_name)
def train(model_options, output_dir, model_file=None, model_step_file=None,
          tf_writer=None):
    """Create and train spoken word classification model for one-shot learning.

    Builds training and dev `tf.data` pipelines over the Flickr Audio speech
    features, creates or restores the speech network, then trains for the
    configured number of epochs, validating and saving the model after every
    epoch. The model is additionally saved as "best_model" whenever the
    validation score (one-shot accuracy if "one_shot_validation" is set,
    categorical accuracy otherwise) matches or beats the best so far.

    Args:
        model_options: model options dict. It is updated in place
            ("input_shape", and "n_classes" for a new model) and written to
            "model_options.json" on a first run.
        output_dir: directory for the model options and saved models.
        model_file: file name of a stored model to resume from, or None to
            create a new model.
        model_step_file: file name of the stored training state; only used
            when `model_file` is given.
        tf_writer: optional summary writer for TensorBoard summaries.
    """
    autotune = tf.data.experimental.AUTOTUNE

    # load training data
    train_exp, dev_exp = dataset.create_flickr_audio_train_data(
        model_options["features"], speaker_mode=FLAGS.speaker_mode)

    # both splits are labelled through the training experiment's mapping
    train_labels = np.asarray([
        train_exp.keyword_labels[keyword]
        for keyword in train_exp.keywords_set[3]])
    dev_labels = np.asarray([
        train_exp.keyword_labels[keyword]
        for keyword in dev_exp.keywords_set[3]])

    train_paths = train_exp.audio_paths
    dev_paths = dev_exp.audio_paths

    # binarizer is fitted on the training labels only
    lb = LabelBinarizer()
    train_labels_one_hot = lb.fit_transform(train_labels)
    dev_labels_one_hot = lb.transform(dev_labels)

    # define preprocessing for speech features
    preprocess_speech_func = functools.partial(
        dataset.load_and_preprocess_speech,
        features=model_options["features"],
        max_length=model_options["max_length"],
        scaling=model_options["scaling"])

    # the loader is wrapped in `tf.py_function` for use in dataset maps
    preprocess_speech_ds_func = lambda path: tf.py_function(
        func=preprocess_speech_func, inp=[path], Tout=tf.float32)

    # create standard training dataset pipeline
    background_train_ds = tf.data.Dataset.zip((
        tf.data.Dataset.from_tensor_slices(train_paths),
        tf.data.Dataset.from_tensor_slices(train_labels_one_hot)))

    # map data preprocessing function across training data
    background_train_ds = background_train_ds.map(
        lambda path, label: (preprocess_speech_ds_func(path), label),
        num_parallel_calls=autotune)

    # shuffle and batch train data
    background_train_ds = background_train_ds.shuffle(1000)
    background_train_ds = background_train_ds.batch(
        model_options["batch_size"])
    background_train_ds = background_train_ds.prefetch(autotune)

    # create dev set pipeline for classification validation
    background_dev_ds = tf.data.Dataset.zip((
        tf.data.Dataset.from_tensor_slices(dev_paths).map(
            preprocess_speech_ds_func),
        tf.data.Dataset.from_tensor_slices(dev_labels_one_hot)))
    background_dev_ds = background_dev_ds.batch(
        batch_size=model_options["batch_size"])

    # write example batch to TensorBoard
    if tf_writer is not None:
        logging.log(logging.INFO, "Writing example features to TensorBoard")
        keywords = np.asarray(train_exp.keywords)
        with tf_writer.as_default():
            for x_batch, y_batch in background_train_ds.take(1):
                speech_feats = []
                for feats in x_batch[:30]:
                    feats = np.transpose(feats)
                    # min-max scale to [0, 1]; the offset must be removed
                    # from the divisor too, otherwise values leave [0, 1]
                    # whenever the minimum is not zero
                    feats_min = np.min(feats)
                    feats_range = np.max(feats) - feats_min
                    if feats_range > 0:
                        speech_feats.append((feats - feats_min) / feats_range)
                    else:
                        # constant features carry no contrast to display
                        speech_feats.append(np.zeros_like(feats))

                tf.summary.image(
                    f"Example train speech {model_options['features']}",
                    np.expand_dims(speech_feats, axis=-1),
                    max_outputs=30, step=0)

                # NOTE(review): `label` is a row of the binarized labels, so
                # this indexes `keywords` with the row's 0/1 values rather
                # than with a class id -- confirm this is the intended text
                labels = "".join(
                    f"{i}: {keywords[label]} "
                    for i, label in enumerate(y_batch[:30]))

                tf.summary.text("Example train labels", labels, step=0)

    # get training objective
    loss = get_training_objective(model_options)

    # get model input shape (39 feature dims for mfcc, 40 otherwise)
    feature_dim = 39 if model_options["features"] == "mfcc" else 40
    model_options["input_shape"] = [model_options["max_length"], feature_dim]

    # load or create model together with its tracking variables
    resume = model_file is not None
    if resume:
        assert model_options["n_classes"] == len(train_exp.keywords)

        speech_network, train_state = model_utils.load_model(
            model_file=os.path.join(output_dir, model_file),
            model_step_file=os.path.join(output_dir, model_step_file),
            loss=loss)

        # get previous tracking variables
        initial_model = False
        global_step, model_epochs, _, best_val_score = train_state

        # a state stored with a +inf "no best yet" value can never be
        # matched by an accuracy, so treat it as no best score
        if best_val_score == np.inf:
            best_val_score = -np.inf
    else:
        model_options["n_classes"] = len(train_exp.keywords)

        speech_network = create_speech_network(model_options)

        # create tracking variables
        initial_model = True
        global_step = 0
        model_epochs = 0

        # both validation scores used below are accuracies compared with
        # `>=`, so the starting value must be -inf in either mode; with +inf
        # the best model would never be saved
        best_val_score = -np.inf

    # load or create Adam optimizer with decayed learning rate
    lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
        model_options["learning_rate"],
        decay_rate=model_options["decay_rate"],
        decay_steps=model_options["decay_steps"],
        staircase=True)

    if resume:
        logging.log(logging.INFO, "Restoring optimizer state")
        optimizer = speech_network.optimizer
    else:
        optimizer = tf.keras.optimizers.Adam(learning_rate=lr_schedule)

    # compile model to store optimizer with model when saving
    speech_network.compile(optimizer=optimizer, loss=loss)

    # create few-shot model from speech network for background training
    speech_few_shot_model = base.BaseModel(speech_network, loss)

    def run_one_shot_validation():
        # build a fresh few-shot model around the current network and score
        # it on the dev task; returns (accuracy, 95% confidence interval)
        if FLAGS.fine_tune_steps is not None:
            test_few_shot_model = create_fine_tune_model(
                model_options, speech_few_shot_model.model,
                num_classes=FLAGS.L)
        else:
            test_few_shot_model = base.BaseModel(
                speech_few_shot_model.model, None,
                mc_dropout=FLAGS.mc_dropout)

        accuracy, _, interval = experiment.test_l_way_k_shot(
            one_shot_dev_exp,
            FLAGS.K,
            FLAGS.L,
            n=FLAGS.N,
            num_episodes=FLAGS.episodes,
            k_neighbours=FLAGS.k_neighbours,
            metric=FLAGS.metric,
            classification=classification,
            model=test_few_shot_model,
            embedding_model_func=embedding_model_func,
            fine_tune_steps=FLAGS.fine_tune_steps,
            fine_tune_lr=FLAGS.fine_tune_lr)
        return accuracy, interval

    # test model on one-shot validation task prior to training
    if model_options["one_shot_validation"]:
        one_shot_dev_exp = flickr_speech.FlickrSpeech(
            features=model_options["features"],
            keywords_split="background_dev",
            preprocess_func=get_data_preprocess_func(model_options),
            speaker_mode=FLAGS.speaker_mode)

        embedding_model_func = lambda speech_network: create_embedding_model(
            model_options, speech_network)

        # classification is only accepted on the "logits"/"softmax" layers
        classification = False
        if FLAGS.classification:
            assert FLAGS.embed_layer in ["logits", "softmax"]
            classification = True

        val_task_accuracy, conf_interval_95 = run_one_shot_validation()

        logging.log(
            logging.INFO,
            f"Base model: {FLAGS.L}-way {FLAGS.K}-shot accuracy after "
            f"{FLAGS.episodes} episodes: {val_task_accuracy:.3%} +- "
            f"{conf_interval_95*100:.4f}")

    # create training metrics
    accuracy_metric = tf.keras.metrics.CategoricalAccuracy()
    loss_metric = tf.keras.metrics.Mean()
    best_model = False

    # store model options on first run
    if initial_model:
        file_io.write_json(
            os.path.join(output_dir, "model_options.json"), model_options)

    # train model
    for epoch in range(model_epochs, model_options["epochs"]):
        logging.log(logging.INFO, f"Epoch {epoch:03d}")

        accuracy_metric.reset_states()
        loss_metric.reset_states()

        # train on epoch of training data
        step_pbar = tqdm(
            background_train_ds,
            bar_format="{desc} [{elapsed},{rate_fmt}{postfix}]")
        for step, (x_batch, y_batch) in enumerate(step_pbar):
            loss_value, y_predict = speech_few_shot_model.train_step(
                x_batch, y_batch, optimizer,
                clip_norm=model_options["gradient_clip_norm"])

            accuracy_metric.update_state(y_batch, y_predict)
            loss_metric.update_state(loss_value)

            step_loss = tf.reduce_mean(loss_value)
            train_loss = loss_metric.result().numpy()
            train_accuracy = accuracy_metric.result().numpy()

            step_pbar.set_description_str(
                f"\tStep {step:03d}: "
                f"Step loss: {step_loss:.6f}, "
                f"Loss: {train_loss:.6f}, "
                f"Categorical accuracy: {train_accuracy:.3%}")

            if tf_writer is not None:
                with tf_writer.as_default():
                    tf.summary.scalar(
                        "Train step loss", step_loss, step=global_step)
            global_step += 1

        # validate classification model
        accuracy_metric.reset_states()
        loss_metric.reset_states()
        for x_batch, y_batch in background_dev_ds:
            y_predict = speech_few_shot_model.predict(x_batch, training=False)
            loss_value = speech_few_shot_model.loss(y_batch, y_predict)

            accuracy_metric.update_state(y_batch, y_predict)
            loss_metric.update_state(loss_value)

        dev_loss = loss_metric.result().numpy()
        dev_accuracy = accuracy_metric.result().numpy()

        # validate model on one-shot dev task if specified
        if model_options["one_shot_validation"]:
            val_task_accuracy, conf_interval_95 = run_one_shot_validation()

            val_score = val_task_accuracy
            val_metric = f"{FLAGS.L}-way {FLAGS.K}-shot accuracy"
        # otherwise, validate on classification task
        else:
            val_score = dev_accuracy
            val_metric = "categorical accuracy"

        # both scores are accuracies: higher (or equal) is better
        if val_score >= best_val_score:
            best_val_score = val_score
            best_model = True

        # log results
        logging.log(
            logging.INFO,
            f"Train: Loss: {train_loss:.6f}, Categorical accuracy: "
            f"{train_accuracy:.3%}")
        logging.log(
            logging.INFO,
            f"Validation: Loss: {dev_loss:.6f}, Categorical accuracy: "
            f"{dev_accuracy:.3%} {'*' if best_model else ''}")
        if model_options["one_shot_validation"]:
            logging.log(
                logging.INFO,
                f"Validation: {FLAGS.L}-way {FLAGS.K}-shot accuracy after "
                f"{FLAGS.episodes} episodes: {val_task_accuracy:.3%} +- "
                f"{conf_interval_95*100:.4f} {'*' if best_model else ''}")

        if tf_writer is not None:
            with tf_writer.as_default():
                # NOTE(review): the epoch mean loss reuses the per-step tag
                # "Train step loss" -- confirm whether a separate tag was
                # intended
                tf.summary.scalar(
                    "Train step loss", train_loss, step=global_step)
                tf.summary.scalar(
                    "Train categorical accuracy", train_accuracy,
                    step=global_step)
                tf.summary.scalar(
                    "Validation loss", dev_loss, step=global_step)
                tf.summary.scalar(
                    "Validation categorical accuracy", dev_accuracy,
                    step=global_step)
                if model_options["one_shot_validation"]:
                    tf.summary.scalar(
                        f"Validation {FLAGS.L}-way {FLAGS.K}-shot accuracy",
                        val_task_accuracy, step=global_step)

        # store model and results
        model_utils.save_model(
            speech_few_shot_model.model, output_dir, epoch + 1, global_step,
            val_metric, val_score, best_val_score, name="model")

        if best_model:
            best_model = False
            model_utils.save_model(
                speech_few_shot_model.model, output_dir, epoch + 1,
                global_step, val_metric, val_score, best_val_score,
                name="best_model")
def embed(model_options, output_dir, model_file, model_step_file):
    """Extract siamese image similarity model embeddings for all datasets.

    Loads the stored model and, for each of "flickr8k", "flickr30k" and
    "mscoco", embeds every unique input path of the "one_shot_evaluation",
    "background_train" and "background_dev" splits. One ZLIB compressed
    TFRecord per input is written under
    `<output_dir>/embed/dense/<dataset>/<split>`.

    Args:
        model_options: model options dict; "use_embeddings", "base_dir" and
            "batch_size" are read here and the dict is passed on to the
            loss and preprocessing helpers.
        output_dir: directory holding the stored model; embeddings are
            written below it.
        model_file: file name of the stored model.
        model_step_file: file name of the stored training step state.
    """
    # inputs are base model dense layer embeddings if specified, else images
    if model_options["use_embeddings"]:
        embed_dir = os.path.join(model_options["base_dir"], "embed", "dense")
    else:
        embed_dir = None

    # restore the trained network (the stored training state is not used)
    vision_network, _ = model_utils.load_model(
        model_file=os.path.join(output_dir, model_file),
        model_step_file=os.path.join(output_dir, model_step_file),
        loss=get_training_objective(model_options))

    # embedding model and the matching input preprocessing
    embedding_model = create_embedding_model(vision_network)
    data_preprocess_func = get_data_preprocess_func(model_options)

    # image directory keyword arguments per dataset: (train, dev); only
    # mscoco uses a different directory for the dev split
    external_dir = os.path.join("data", "external")
    flickr8k_dirs = {
        "flickr8k_image_dir": os.path.join(external_dir, "flickr8k_images")}
    flickr30k_dirs = {
        "flickr30k_image_dir": os.path.join(external_dir, "flickr30k_images")}
    mscoco_train_dirs = {
        "mscoco_image_dir": os.path.join(external_dir, "mscoco", "train2017")}
    mscoco_dev_dirs = {
        "mscoco_image_dir": os.path.join(external_dir, "mscoco", "val2017")}
    image_dirs = {
        "flickr8k": (flickr8k_dirs, flickr8k_dirs),
        "flickr30k": (flickr30k_dirs, flickr30k_dirs),
        "mscoco": (mscoco_train_dirs, mscoco_dev_dirs),
    }

    for data, (train_image_dir_dict, dev_image_dir_dict) in image_dirs.items():
        # one experiment per keyword split, keyed by split name
        subset_exp = {
            "one_shot_evaluation": flickr_vision.FlickrVision(
                keywords_split="one_shot_evaluation",
                **train_image_dir_dict, embed_dir=embed_dir),
            "background_train": flickr_vision.FlickrVision(
                keywords_split="background_train",
                **train_image_dir_dict, embed_dir=embed_dir),
            "background_dev": flickr_vision.FlickrVision(
                keywords_split="background_dev",
                **dev_image_dir_dict, embed_dir=embed_dir),
        }

        for subset, exp in subset_exp.items():
            output_embed_dir = os.path.join(
                output_dir, "embed", "dense", data, subset)
            file_io.check_create_dir(output_embed_dir)

            subset_paths = (
                exp.embed_paths if model_options["use_embeddings"]
                else exp.image_paths)
            unique_paths = np.unique(subset_paths)

            # batch images/base embeddings for faster embedding inference
            path_ds = (
                tf.data.Dataset.from_tensor_slices(unique_paths)
                .batch(model_options["batch_size"])
                .prefetch(tf.data.experimental.AUTOTUNE))

            # number of batches, used as the progress bar total
            num_batches = int(
                np.ceil(len(unique_paths) / model_options["batch_size"]))

            start_time = time.time()
            paths, embeddings = [], []
            for path_batch in tqdm(path_ds, total=num_batches):
                batch_embeddings = embedding_model.predict(
                    data_preprocess_func(path_batch))
                paths.extend(path_batch.numpy())
                embeddings.extend(batch_embeddings.numpy())
            elapsed = time.time() - start_time

            logging.log(
                logging.INFO,
                f"Computed embeddings for {data} {subset} in "
                f"{elapsed:.4f} seconds")

            # serialize and write embeddings to TFRecord files
            for source_path, embedding in zip(paths, embeddings):
                example_proto = dataset.embedding_to_example_protobuf(
                    embedding)

                # keep the source file name, dropping any ".tfrecord" ext
                stem = source_path.decode("utf-8").split(".tfrecord")[0]
                record_path = os.path.join(
                    output_embed_dir, f"{os.path.basename(stem)}.tfrecord")

                with tf.io.TFRecordWriter(
                        record_path, options="ZLIB") as writer:
                    writer.write(example_proto.SerializeToString())

            logging.log(
                logging.INFO, f"Embeddings stored at: {output_embed_dir}")
def test(model_options, output_dir, model_file, model_step_file):
    """Evaluate a stored siamese image similarity model on one-shot images.

    Loads the Flickr 8k "one_shot_evaluation" split and the model stored
    under `output_dir`, wraps the network in a few-shot model and logs the
    L-way K-shot accuracy returned by `experiment.test_l_way_k_shot`.

    Args:
        model_options: model options dict; "use_embeddings" and "base_dir"
            are read here and the dict is passed on to the preprocessing,
            loss and model helpers.
        output_dir: directory that `model_file` and `model_step_file` are
            joined onto.
        model_file: file name of the stored model.
        model_step_file: file name of the stored training step state.
    """
    # inputs are base model dense layer embeddings if specified, else images
    if model_options["use_embeddings"]:
        embed_dir = os.path.join(model_options["base_dir"], "embed", "dense")
    else:
        embed_dir = None

    # episodes are sampled from the Flickr 8k one-shot evaluation split
    flickr8k_image_dir = os.path.join("data", "external", "flickr8k_images")
    one_shot_exp = flickr_vision.FlickrVision(
        keywords_split="one_shot_evaluation",
        flickr8k_image_dir=flickr8k_image_dir,
        preprocess_func=get_data_preprocess_func(model_options),
        embed_dir=embed_dir)

    # restore the trained network (the stored training state is not used)
    model_path = os.path.join(output_dir, model_file)
    step_path = os.path.join(output_dir, model_step_file)
    vision_network, _ = model_utils.load_model(
        model_file=model_path,
        model_step_file=step_path,
        loss=get_training_objective(model_options))

    embedding_model_func = create_embedding_model

    # wrap the network as-is unless fine-tuning steps were requested
    if FLAGS.fine_tune_steps is None:
        test_few_shot_model = base.BaseModel(
            vision_network, None, mc_dropout=FLAGS.mc_dropout)
    else:
        test_few_shot_model = create_fine_tune_model(
            model_options, vision_network, num_classes=FLAGS.L)

    # classification is only accepted together with fine-tuning
    classification = bool(FLAGS.classification)
    if classification:
        assert FLAGS.fine_tune_steps is not None

    logging.log(logging.INFO, "Created few-shot model from vision network")
    test_few_shot_model.model.summary()

    # run the L-way K-shot episodes
    task_accuracy, _, conf_interval_95 = experiment.test_l_way_k_shot(
        one_shot_exp,
        FLAGS.K,
        FLAGS.L,
        n=FLAGS.N,
        num_episodes=FLAGS.episodes,
        k_neighbours=FLAGS.k_neighbours,
        metric=FLAGS.metric,
        classification=classification,
        model=test_few_shot_model,
        embedding_model_func=embedding_model_func,
        fine_tune_steps=FLAGS.fine_tune_steps,
        fine_tune_lr=FLAGS.fine_tune_lr)

    logging.log(
        logging.INFO,
        f"{FLAGS.L}-way {FLAGS.K}-shot accuracy after {FLAGS.episodes} "
        f"episodes: {task_accuracy:.3%} +- {conf_interval_95*100:.4f}")
def create_vision_network(model_options, build_model=True):
    """Assemble the vision similarity network described by `model_options`.

    Four configurations are handled: an oracle InceptionV3 with imagenet
    weights, an InceptionV3 trained from scratch (or from imagenet weights)
    with custom top layers, dense layers on extracted base embeddings, and a
    warm start from a stored base network.

    Args:
        model_options: model options dict. "dense_units" is set to None in
            place for the oracle and warm start configurations.
        build_model: when true the layers get "input_shape" from the
            options, progress is logged and the model summary is printed;
            otherwise the input shape is left as None.

    Returns:
        A `tf.keras.Sequential` model of the selected layers.
    """
    input_shape = model_options["input_shape"] if build_model else None

    def log_if_building(message):
        # messages are only wanted when the model is actually being built
        if build_model:
            logging.log(logging.INFO, message)

    base_dir = model_options["base_dir"]

    if base_dir is None and model_options["oracle"]:
        # oracle model: imagenet weights, class logits as embedding layer
        log_if_building("Fine-tuning oracle with imagenet weights")

        inception_network = inceptionv3.create_inceptionv3_network(
            input_shape=input_shape, pretrained=True, include_top=True)

        # train only the logits embeddings layer
        inceptionv3.freeze_weights(inception_network, trainable="logits")

        # set output layer to be linear (instead of softmax)
        inception_network.layers[-1].activation = None

        layers = [inception_network]

        # no additional dense layers if training oracle model
        model_options["dense_units"] = None

    elif base_dir is None:
        # train network from scratch (or with imagenet weights for debugging)
        inception_network = inceptionv3.create_inceptionv3_network(
            input_shape=input_shape,
            pretrained=model_options["pretrained"],
            include_top=False)

        if model_options["pretrained"]:
            # imagenet weights and our own top dense layers
            # ... alternative "oracle"?
            log_if_building(
                "Debugging with imagenet weights and custom top layer")

            # train final inception module and the top dense layers
            inceptionv3.freeze_weights(
                inception_network, trainable="final_inception")
        else:
            # random weights and our own top dense layers
            log_if_building("Training entire model from scratch")

        layers = [inception_network, tf.keras.layers.GlobalAveragePooling2D()]
        if model_options["dropout_rate"] is not None:
            layers.append(
                tf.keras.layers.Dropout(model_options["dropout_rate"]))
        layers.append(tf.keras.layers.Dense(model_options["dense_units"][0]))

    elif model_options["use_embeddings"]:
        # train dense layers on extracted base embeddings; the first layer
        # in the stack carries the input shape
        log_if_building("Training model on base embeddings")

        if model_options["dropout_rate"] is None:
            layers = [
                tf.keras.layers.Dense(
                    model_options["dense_units"][0], input_shape=input_shape)]
        else:
            layers = [
                tf.keras.layers.Dropout(
                    model_options["dropout_rate"], input_shape=input_shape),
                tf.keras.layers.Dense(model_options["dense_units"][0])]

    else:
        # load base model and fine-tune final layer
        log_if_building("Warm start training from pretrained network")

        base_name = model_options["base_model"]
        base_network, _ = model_utils.load_model(
            model_file=os.path.join(base_dir, f"{base_name}.h5"),
            model_step_file=os.path.join(base_dir, f"{base_name}.step"),
            loss=lambda y_t, y_p: y_p)  # dummy loss to load model ..

        # set output layer to be linear (in case of softmax, sigmoid, etc.)
        base_network.layers[-1].activation = None

        # fine-tune only the final base embedding layer
        for frozen_layer in base_network.layers[:-1]:
            frozen_layer.trainable = False
        base_network.layers[-1].trainable = True

        layers = [base_network]

        # no additional dense layers
        model_options["dense_units"] = None

    # add top layer hidden units (final layer linear)
    top_units = model_options["dense_units"]
    if top_units is not None:
        for units in top_units[1:]:
            layers.append(tf.keras.layers.ReLU())
            if model_options["dropout_rate"] is not None:
                layers.append(
                    tf.keras.layers.Dropout(model_options["dropout_rate"]))
            layers.append(tf.keras.layers.Dense(units))

    vision_network = tf.keras.Sequential(layers)

    if build_model:
        vision_network.summary()

    return vision_network
def train(model_options, output_dir, model_file=None, model_step_file=None,
          tf_writer=None):
    """Create and train image classification model for one-shot learning.

    Builds multi-hot keyword targets for the Flickr vision background data,
    trains the vision network on them, validates after every epoch and stores
    the latest ("model") and best ("best_model") checkpoints in `output_dir`.

    Validation uses the L-way K-shot dev task when
    `model_options["one_shot_validation"]` is set, and the dev F-1 score
    otherwise. Both scores are "higher is better".

    Args:
        model_options: dict of model/training options. Updated in place with
            "input_shape" and, for a new model, "n_classes".
        output_dir: directory that checkpoints and "model_options.json" are
            written to, and that `model_file` is loaded from.
        model_file: optional checkpoint file name (relative to `output_dir`)
            to resume training from.
        model_step_file: optional train-state file name (relative to
            `output_dir`) that accompanies `model_file`.
        tf_writer: optional TensorBoard summary writer; nothing is written to
            TensorBoard when it is None.
    """
    # load training data
    train_exp, dev_exp = dataset.create_flickr_vision_train_data(
        model_options["data"])

    # map each image's keywords to their integer labels
    train_labels = np.asarray([
        np.array([train_exp.keyword_labels[keyword]
                  for keyword in image_keywords])
        for image_keywords in train_exp.unique_image_keywords])

    # dev keywords are looked up in the train label map (as in the original)
    dev_labels = np.asarray([
        np.array([train_exp.keyword_labels[keyword]
                  for keyword in image_keywords])
        for image_keywords in dev_exp.unique_image_keywords])

    train_paths = train_exp.unique_image_paths
    dev_paths = dev_exp.unique_image_paths

    # binarizer is fitted on train labels only and re-used for dev labels
    mlb = MultiLabelBinarizer()
    train_labels_multi_hot = mlb.fit_transform(train_labels)
    dev_labels_multi_hot = mlb.transform(dev_labels)

    # define preprocessing for images
    preprocess_images_func = functools.partial(
        dataset.load_and_preprocess_image,
        crop_size=model_options["crop_size"],
        augment_crop=model_options["augment_train"],
        random_scales=model_options["random_scales"],
        horizontal_flip=model_options["horizontal_flip"],
        colour=model_options["colour"])

    # create standard training dataset pipeline
    background_train_ds = tf.data.Dataset.zip((
        tf.data.Dataset.from_tensor_slices(train_paths),
        tf.data.Dataset.from_tensor_slices(train_labels_multi_hot)))

    # map data preprocessing function across training data
    background_train_ds = background_train_ds.map(
        lambda path, label: (preprocess_images_func(path), label),
        num_parallel_calls=tf.data.experimental.AUTOTUNE)

    # repeat augmentation, shuffle and batch train data
    if model_options["num_augment"] is not None:
        background_train_ds = background_train_ds.repeat(
            model_options["num_augment"])

    background_train_ds = background_train_ds.shuffle(1000)
    background_train_ds = background_train_ds.batch(
        model_options["batch_size"])
    background_train_ds = background_train_ds.prefetch(
        tf.data.experimental.AUTOTUNE)

    # create dev set pipeline for classification validation
    background_dev_ds = tf.data.Dataset.zip((
        tf.data.Dataset.from_tensor_slices(
            dev_paths).map(preprocess_images_func),
        tf.data.Dataset.from_tensor_slices(dev_labels_multi_hot)))

    background_dev_ds = background_dev_ds.batch(
        batch_size=model_options["batch_size"])

    # write example batch to TensorBoard
    if tf_writer is not None:
        logging.log(logging.INFO, "Writing example images to TensorBoard")
        with tf_writer.as_default():
            for x_batch, y_batch in background_train_ds.take(1):
                # (x + 1) / 2 presumably maps images from [-1, 1] to [0, 1]
                # -- TODO confirm against load_and_preprocess_image
                tf.summary.image("Example train images", (x_batch+1)/2,
                                 max_outputs=30, step=0)
                # NOTE(review): `label` is a multi-hot row, so it is used as
                # an index array here rather than a mask -- confirm that the
                # logged keywords are the intended ones.
                labels = ""
                for i, label in enumerate(y_batch[:30]):
                    labels += f"{i}: {np.asarray(train_exp.keywords)[label]} "
                tf.summary.text("Example train labels", labels, step=0)

    # get training objective
    multi_label_loss = get_training_objective(model_options)

    # get model input shape
    model_options["input_shape"] = (
        model_options["crop_size"], model_options["crop_size"], 3)

    # load or create model
    if model_file is not None:
        assert model_options["n_classes"] == len(train_exp.keywords)

        vision_network, train_state = model_utils.load_model(
            model_file=os.path.join(output_dir, model_file),
            model_step_file=os.path.join(output_dir, model_step_file),
            loss=multi_label_loss)

        # get previous tracking variables
        initial_model = False
        global_step, model_epochs, _, best_val_score = train_state
    else:
        model_options["n_classes"] = len(train_exp.keywords)

        vision_network = create_vision_network(model_options)

        # create tracking variables
        initial_model = True
        global_step = 0
        model_epochs = 0

        # Both validation scores used below (one-shot accuracy and dev F-1)
        # are maximised and compared with `>=`, so the initial best score must
        # be -inf. Starting from +inf in the F-1 case meant no epoch could
        # ever be selected as the best model.
        best_val_score = -np.inf

    # load or create Adam optimizer with decayed learning rate
    lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
        model_options["learning_rate"],
        decay_rate=model_options["decay_rate"],
        decay_steps=model_options["decay_steps"],
        staircase=True)

    if model_file is not None:
        logging.log(logging.INFO, "Restoring optimizer state")
        optimizer = vision_network.optimizer
    else:
        optimizer = tf.keras.optimizers.Adam(learning_rate=lr_schedule)

    # compile model to store optimizer with model when saving
    vision_network.compile(optimizer=optimizer, loss=multi_label_loss)

    # create few-shot model from vision network for background training
    vision_few_shot_model = base.BaseModel(vision_network, multi_label_loss)

    # test model on one-shot validation task prior to training
    if model_options["one_shot_validation"]:

        one_shot_dev_exp = flickr_vision.FlickrVision(
            keywords_split="background_dev",
            flickr8k_image_dir=os.path.join(
                "data", "external", "flickr8k_images"),
            flickr30k_image_dir=os.path.join(
                "data", "external", "flickr30k_images"),
            mscoco_image_dir=os.path.join(
                "data", "external", "mscoco", "val2017"),
            preprocess_func=get_data_preprocess_func(model_options))

        embedding_model_func = lambda vision_network: create_embedding_model(
            model_options, vision_network)

        classification = False
        if FLAGS.classification:
            assert FLAGS.embed_layer in ["logits", "softmax"]
            classification = True

        # create few-shot model from vision network for one-shot validation
        if FLAGS.fine_tune_steps is not None:
            test_few_shot_model = create_fine_tune_model(
                model_options, vision_few_shot_model.model,
                num_classes=FLAGS.L)
        else:
            test_few_shot_model = base.BaseModel(
                vision_few_shot_model.model, None,
                mc_dropout=FLAGS.mc_dropout)

        val_task_accuracy, _, conf_interval_95 = experiment.test_l_way_k_shot(
            one_shot_dev_exp, FLAGS.K, FLAGS.L, n=FLAGS.N,
            num_episodes=FLAGS.episodes, k_neighbours=FLAGS.k_neighbours,
            metric=FLAGS.metric, classification=classification,
            model=test_few_shot_model,
            embedding_model_func=embedding_model_func,
            fine_tune_steps=FLAGS.fine_tune_steps,
            fine_tune_lr=FLAGS.fine_tune_lr)

        logging.log(
            logging.INFO,
            f"Base model: {FLAGS.L}-way {FLAGS.K}-shot accuracy after "
            f"{FLAGS.episodes} episodes: {val_task_accuracy:.3%} +- "
            f"{conf_interval_95*100:.4f}")

    # create training metrics
    precision_metric = tf.keras.metrics.Precision()
    recall_metric = tf.keras.metrics.Recall()
    loss_metric = tf.keras.metrics.Mean()
    best_model = False

    # store model options on first run
    if initial_model:
        file_io.write_json(
            os.path.join(output_dir, "model_options.json"), model_options)

    # train model (resumes from the stored epoch count when restoring)
    for epoch in range(model_epochs, model_options["epochs"]):
        logging.log(logging.INFO, f"Epoch {epoch:03d}")

        precision_metric.reset_states()
        recall_metric.reset_states()
        loss_metric.reset_states()

        # train on epoch of training data
        step_pbar = tqdm(background_train_ds,
                         bar_format="{desc} [{elapsed},{rate_fmt}{postfix}]")
        for step, (x_batch, y_batch) in enumerate(step_pbar):

            loss_value, y_predict = vision_few_shot_model.train_step(
                x_batch, y_batch, optimizer,
                clip_norm=model_options["gradient_clip_norm"])

            # threshold sigmoid outputs at 0.5 to get multi-hot predictions
            y_one_hot_predict = tf.round(tf.nn.sigmoid(y_predict))

            precision_metric.update_state(y_batch, y_one_hot_predict)
            recall_metric.update_state(y_batch, y_one_hot_predict)
            loss_metric.update_state(loss_value)

            step_loss = tf.reduce_mean(loss_value)
            train_loss = loss_metric.result().numpy()
            train_precision = precision_metric.result().numpy()
            train_recall = recall_metric.result().numpy()
            # harmonic mean of the running precision and recall
            train_f1 = 2 / ((1/train_precision) + (1/train_recall))

            step_pbar.set_description_str(
                f"\tStep {step:03d}: "
                f"Step loss: {step_loss:.6f}, "
                f"Loss: {train_loss:.6f}, "
                f"Precision: {train_precision:.3%}, "
                f"Recall: {train_recall:.3%}, "
                f"F-1: {train_f1:.3%}")

            if tf_writer is not None:
                with tf_writer.as_default():
                    tf.summary.scalar(
                        "Train step loss", step_loss, step=global_step)
            global_step += 1

        # validate classification model
        precision_metric.reset_states()
        recall_metric.reset_states()
        loss_metric.reset_states()

        for x_batch, y_batch in background_dev_ds:
            y_predict = vision_few_shot_model.predict(x_batch, training=False)
            loss_value = vision_few_shot_model.loss(y_batch, y_predict)

            y_one_hot_predict = tf.round(tf.nn.sigmoid(y_predict))

            precision_metric.update_state(y_batch, y_one_hot_predict)
            recall_metric.update_state(y_batch, y_one_hot_predict)
            loss_metric.update_state(loss_value)

        dev_loss = loss_metric.result().numpy()
        dev_precision = precision_metric.result().numpy()
        dev_recall = recall_metric.result().numpy()
        dev_f1 = 2 / ((1/dev_precision) + (1/dev_recall))

        # validate model on one-shot dev task if specified
        if model_options["one_shot_validation"]:

            # rebuild the test model so it wraps the current network weights
            if FLAGS.fine_tune_steps is not None:
                test_few_shot_model = create_fine_tune_model(
                    model_options, vision_few_shot_model.model,
                    num_classes=FLAGS.L)
            else:
                test_few_shot_model = base.BaseModel(
                    vision_few_shot_model.model, None,
                    mc_dropout=FLAGS.mc_dropout)

            val_task_accuracy, _, conf_interval_95 = experiment.test_l_way_k_shot(
                one_shot_dev_exp, FLAGS.K, FLAGS.L, n=FLAGS.N,
                num_episodes=FLAGS.episodes, k_neighbours=FLAGS.k_neighbours,
                metric=FLAGS.metric, classification=classification,
                model=test_few_shot_model,
                embedding_model_func=embedding_model_func,
                fine_tune_steps=FLAGS.fine_tune_steps,
                fine_tune_lr=FLAGS.fine_tune_lr)

            val_score = val_task_accuracy
            val_metric = f"{FLAGS.L}-way {FLAGS.K}-shot accuracy"

        # otherwise, validate on classification task
        else:
            val_score = dev_f1
            val_metric = "F-1"

        if val_score >= best_val_score:
            best_val_score = val_score
            best_model = True

        # log results
        logging.log(
            logging.INFO,
            f"Train: Loss: {train_loss:.6f}, Precision: {train_precision:.3%}, "
            f"Recall: {train_recall:.3%}, F-1: {train_f1:.3%}")
        logging.log(
            logging.INFO,
            f"Validation: Loss: {dev_loss:.6f}, Precision: "
            f"{dev_precision:.3%}, Recall: {dev_recall:.3%}, F-1: "
            f"{dev_f1:.3%} {'*' if best_model else ''}")

        if model_options["one_shot_validation"]:
            logging.log(
                logging.INFO,
                f"Validation: {FLAGS.L}-way {FLAGS.K}-shot accuracy after "
                f"{FLAGS.episodes} episodes: {val_task_accuracy:.3%} +- "
                f"{conf_interval_95*100:.4f} {'*' if best_model else ''}")

        if tf_writer is not None:
            with tf_writer.as_default():
                # NOTE(review): the epoch-level train loss shares its tag with
                # the per-step summary written in the training loop above.
                tf.summary.scalar(
                    "Train step loss", train_loss, step=global_step)
                tf.summary.scalar(
                    "Train precision", train_precision, step=global_step)
                tf.summary.scalar(
                    "Train recall", train_recall, step=global_step)
                tf.summary.scalar(
                    "Train F-1", train_f1, step=global_step)
                tf.summary.scalar(
                    "Validation loss", dev_loss, step=global_step)
                tf.summary.scalar(
                    "Validation precision", dev_precision, step=global_step)
                tf.summary.scalar(
                    "Validation recall", dev_recall, step=global_step)
                tf.summary.scalar(
                    "Validation F-1", dev_f1, step=global_step)
                if model_options["one_shot_validation"]:
                    tf.summary.scalar(
                        f"Validation {FLAGS.L}-way {FLAGS.K}-shot accuracy",
                        val_task_accuracy, step=global_step)

        # store model and results
        model_utils.save_model(
            vision_few_shot_model.model, output_dir, epoch + 1, global_step,
            val_metric, val_score, best_val_score, name="model")

        if best_model:
            best_model = False
            model_utils.save_model(
                vision_few_shot_model.model, output_dir, epoch + 1,
                global_step, val_metric, val_score, best_val_score,
                name="best_model")
def test(model_options, output_dir, model_file, model_step_file):
    """Load and test siamese audio-visual similarity model for one-shot learning.

    Restores the joint audio-visual network from `output_dir`, splits it into
    its speech and vision branches, wraps them in a fine-tune model and logs
    the L-way K-shot accuracy on the "one_shot_evaluation" keyword split.

    Args:
        model_options: dict of model options; the "audio_base_dir" and
            "vision_base_dir" entries locate the base model embeddings.
        output_dir: directory containing the stored model files.
        model_file: model checkpoint file name, relative to `output_dir`.
        model_step_file: train-state file name, relative to `output_dir`.
    """
    # embeddings come from the (linear) dense layer of the base models
    embed_subdirs = ("embed", "dense")
    speech_embed_dir = os.path.join(
        model_options["audio_base_dir"], *embed_subdirs)
    image_embed_dir = os.path.join(
        model_options["vision_base_dir"], *embed_subdirs)

    # Flickr Audio one-shot evaluation experiment
    experiment_kwargs = dict(
        features="mfcc",
        keywords_split="one_shot_evaluation",
        flickr8k_image_dir=os.path.join("data", "external", "flickr8k_images"),
        speech_embed_dir=speech_embed_dir,
        image_embed_dir=image_embed_dir,
        speech_preprocess_func=data_preprocess_func,
        image_preprocess_func=data_preprocess_func,
        speaker_mode=FLAGS.speaker_mode,
        unseen_match_set=FLAGS.unseen_match_set)
    one_shot_exp = flickr_multimodal.FlickrMultimodal(**experiment_kwargs)

    # restore the joint audio-visual model ...
    checkpoint_path = os.path.join(output_dir, model_file)
    step_path = os.path.join(output_dir, model_step_file)
    joint_network, _ = model_utils.load_model(
        model_file=checkpoint_path,
        model_step_file=step_path,
        loss=get_training_objective(model_options))

    # ... and split it into its speech (index 0) and vision (index 1) branches
    speech_network, vision_network = (
        tf.keras.Model(inputs=joint_network.inputs[branch],
                       outputs=joint_network.outputs[branch])
        for branch in range(2))

    # few-shot model built from both branches for one-shot testing
    few_shot_model = create_fine_tune_model(
        model_options, speech_network, vision_network)

    # an explicit SGD optimizer is only passed on when FLAGS.adam is not set
    fine_tune_optimizer = (
        None if FLAGS.adam else tf.keras.optimizers.SGD(FLAGS.fine_tune_lr))

    logging.log(logging.INFO, "Created few-shot model from speech network")
    few_shot_model.speech_model.model.summary()

    logging.log(logging.INFO, "Created few-shot model from vision network")
    few_shot_model.vision_model.model.summary()

    # run the L-way K-shot multimodal task
    results = experiment.test_multimodal_l_way_k_shot(
        one_shot_exp, FLAGS.K, FLAGS.L, n=FLAGS.N,
        num_episodes=FLAGS.episodes, k_neighbours=FLAGS.k_neighbours,
        metric=FLAGS.metric, direct_match=FLAGS.direct_match,
        multimodal_model=few_shot_model,
        multimodal_embedding_func=None,
        optimizer=fine_tune_optimizer,
        fine_tune_steps=FLAGS.fine_tune_steps,
        fine_tune_lr=FLAGS.fine_tune_lr)
    task_accuracy, _, conf_interval_95 = results

    summary_message = (
        f"{FLAGS.L}-way {FLAGS.K}-shot accuracy after {FLAGS.episodes} "
        f"episodes: {task_accuracy:.3%} +- {conf_interval_95*100:.4f}")
    logging.log(logging.INFO, summary_message)
def train(model_options, output_dir, model_file=None, model_step_file=None,
          tf_writer=None):
    """Create and train audio-visual similarity model for one-shot learning.

    Meta-trains the speech and vision networks with
    `base.WeaklySupervisedMAML` on episodes sampled from the background
    training data. Every `model_options["validation_interval"]` steps (and at
    the final step) the model is validated on the L-way K-shot dev task and
    the latest ("model") and best ("best_model") checkpoints are stored in
    `output_dir`.

    Args:
        model_options: dict of model/training options. Updated in place with
            the audio/vision base embedding sizes and input shapes.
        output_dir: directory that checkpoints and "model_options.json" are
            written to, and that `model_file` is loaded from.
        model_file: optional checkpoint file name (relative to `output_dir`)
            to resume training from.
        model_step_file: optional train-state file name (relative to
            `output_dir`) that accompanies `model_file`.
        tf_writer: optional TensorBoard summary writer; nothing is written to
            TensorBoard when it is None.
    """
    # load embeddings from (linear) dense layer of base speech and vision models
    speech_embed_dir = os.path.join(
        model_options["audio_base_dir"], "embed", "dense")
    image_embed_dir = os.path.join(
        model_options["vision_base_dir"], "embed", "dense")

    # load training data (embed dir determines mfcc/fbank speech features)
    train_exp, dev_exp = dataset.create_flickr_multimodal_train_data(
        "mfcc", speech_embed_dir=speech_embed_dir,
        image_embed_dir=image_embed_dir,
        speaker_mode=FLAGS.speaker_mode,
        speech_preprocess_func=data_preprocess_func,
        image_preprocess_func=data_preprocess_func,
        unseen_match_set=FLAGS.unseen_match_set)

    # create training dataset generator
    def sample_task_function(task_exp, l_way, k_shot, n):
        """Return a function that samples one episode from `task_exp`."""
        def sample():
            curr_episode_train, curr_episode_test = task_exp.sample_episode(
                l_way, k_shot, n)

            # element 0 of each episode split indexes the experiment data
            x_train_s = task_exp.speech_experiment.data[curr_episode_train[0]]
            x_train_i = task_exp.vision_experiment.data[curr_episode_train[0]]

            x_query_s = task_exp.speech_experiment.data[curr_episode_test[0]]
            x_query_i = task_exp.vision_experiment.data[curr_episode_test[0]]

            return x_train_s, x_train_i, x_query_s, x_query_i
        return sample

    train_generator = sample_task_function(
        train_exp, FLAGS.L, FLAGS.K, FLAGS.N)

    # create dummy dataset with infinite zero elements
    background_train_ds = tf.data.Dataset.from_tensors(tf.constant(0.))
    background_train_ds = background_train_ds.repeat(-1)

    # parallel map dummy elements to task data
    background_train_ds = background_train_ds.map(
        lambda _: tf.py_function(
            train_generator, inp=[], Tout=[tf.float32] * 4),
        num_parallel_calls=tf.data.experimental.AUTOTUNE)

    background_train_ds = background_train_ds.batch(
        batch_size=model_options["meta_batch_size"])
    background_train_ds = background_train_ds.prefetch(
        tf.data.experimental.AUTOTUNE)

    # get training objective
    triplet_loss = get_training_objective(model_options)

    # get model input shape from the last axis of one sampled batch
    for speech_batch, image_batch, _, _ in background_train_ds.take(1):
        model_options["audio_base_embed_size"] = int(
            tf.shape(speech_batch)[-1].numpy())
        model_options["vision_base_embed_size"] = int(
            tf.shape(image_batch)[-1].numpy())

    model_options["audio_input_shape"] = [
        model_options["audio_base_embed_size"]]
    model_options["vision_input_shape"] = [
        model_options["vision_base_embed_size"]]

    # load or create models
    if model_file is not None:
        join_network_model, train_state = model_utils.load_model(
            model_file=os.path.join(output_dir, model_file),
            model_step_file=os.path.join(output_dir, model_step_file),
            loss=get_training_objective(model_options))

        # split the joint model into its speech and vision branches
        speech_network = tf.keras.Model(
            inputs=join_network_model.inputs[0],
            outputs=join_network_model.outputs[0])
        vision_network = tf.keras.Model(
            inputs=join_network_model.inputs[1],
            outputs=join_network_model.outputs[1])

        # get previous tracking variables
        initial_model = False
        global_step, _, _, best_val_score = train_state
    else:
        speech_network = create_speech_network(model_options)
        vision_network = create_vision_network(model_options)

        # create tracking variables
        initial_model = True
        global_step = 0

        # Validation in this function is always the one-shot accuracy, which
        # is maximised and compared with `>=`. The initial best score is
        # therefore always -inf; starting from +inf (as was done when
        # "one_shot_validation" was disabled) meant that no step could ever
        # be selected as the best model.
        best_val_score = -np.inf

    # load or create Adam optimizer with decayed learning rate
    lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
        model_options["meta_learning_rate"],
        decay_rate=model_options["decay_rate"],
        decay_steps=model_options["decay_steps"],
        staircase=True)

    if model_file is not None:
        logging.log(logging.INFO, "Restoring optimizer state")
        optimizer = join_network_model.optimizer
    else:
        optimizer = tf.keras.optimizers.Adam(learning_rate=lr_schedule)

    # compile joint model to store optimizer with model when saving
    join_network_model = tf.keras.Model(
        inputs=[speech_network.input, vision_network.input],
        outputs=[speech_network.output, vision_network.output])
    join_network_model.compile(optimizer=optimizer, loss=triplet_loss)

    # create few-shot model from speech and vision networks for background
    # (meta) training
    multimodal_model = base.WeaklySupervisedMAML(
        speech_network, vision_network, triplet_loss,
        inner_optimizer_lr=FLAGS.fine_tune_lr)

    # create few-shot model from speech and vision networks for one-shot
    # validation
    test_few_shot_model = create_fine_tune_model(
        model_options, speech_network, vision_network)

    # an explicit SGD optimizer is only passed on when FLAGS.adam is not set
    fine_tune_optimizer = None
    if not FLAGS.adam:
        fine_tune_optimizer = tf.keras.optimizers.SGD(FLAGS.fine_tune_lr)

    # test model on one-shot validation task prior to training
    # NOTE(review): this baseline run uses twice the fine-tune steps of the
    # in-training validation below -- confirm this is intended.
    val_task_accuracy, _, conf_interval_95 = experiment.test_multimodal_l_way_k_shot(
        dev_exp, FLAGS.K, FLAGS.L, n=FLAGS.N, num_episodes=FLAGS.episodes,
        k_neighbours=FLAGS.k_neighbours, metric=FLAGS.metric,
        direct_match=FLAGS.direct_match,
        multimodal_model=test_few_shot_model,
        multimodal_embedding_func=None,
        optimizer=fine_tune_optimizer,
        fine_tune_steps=FLAGS.fine_tune_steps * 2,
        fine_tune_lr=FLAGS.fine_tune_lr)

    logging.log(
        logging.INFO,
        f"Base model: {FLAGS.L}-way {FLAGS.K}-shot accuracy after "
        f"{FLAGS.episodes} episodes: {val_task_accuracy:.3%} +- "
        f"{conf_interval_95*100:.4f}")

    # create training metrics
    step_loss_metric = tf.keras.metrics.Mean()
    avg_loss_metric = tf.keras.metrics.Mean()
    best_model = False

    # store model options on first run
    if initial_model:
        file_io.write_json(
            os.path.join(output_dir, "model_options.json"), model_options)

        # also store initial model for probing and things
        model_utils.save_model(
            join_network_model, output_dir, 0, 0, "not tested", 0., 0.,
            name="initial_model")

    # train model; the dataset is infinite, so the loop is ended explicitly
    # once `global_step` exceeds the configured number of meta steps
    step_pbar = tqdm(background_train_ds,
                     bar_format="{desc} {r_bar}",
                     initial=global_step,
                     total=model_options["meta_steps"])

    for batch in step_pbar:

        # train on batch of training task data
        train_s_batch, train_i_batch, test_s_batch, test_i_batch = batch

        meta_loss, _inner_losses, _meta_losses = multimodal_model.maml_train_step(
            train_s_batch, train_i_batch, test_s_batch, test_i_batch,
            FLAGS.fine_tune_steps, optimizer,
            stop_gradients=FLAGS.first_order,
            clip_norm=model_options["gradient_clip_norm"])

        step_loss_metric.reset_states()
        step_loss_metric.update_state(meta_loss)

        # running average since the last validation (reset when validating)
        avg_loss_metric.update_state(meta_loss)
        avg_loss = avg_loss_metric.result().numpy()

        step_pbar.set_description_str(
            f"Step loss: {meta_loss:.6f}, "
            f"Average loss: {avg_loss:.6f}")

        if tf_writer is not None:
            with tf_writer.as_default():
                tf.summary.scalar(
                    "Train step loss", meta_loss, step=global_step)

        if (global_step % model_options["validation_interval"] == 0
                or global_step == model_options["meta_steps"]) and global_step > 0:

            # validate model on one-shot dev task
            test_few_shot_model = create_fine_tune_model(
                model_options, speech_network, vision_network)

            val_task_accuracy, _, conf_interval_95 = experiment.test_multimodal_l_way_k_shot(
                dev_exp, FLAGS.K, FLAGS.L, n=FLAGS.N,
                num_episodes=FLAGS.episodes,
                k_neighbours=FLAGS.k_neighbours, metric=FLAGS.metric,
                direct_match=FLAGS.direct_match,
                multimodal_model=test_few_shot_model,
                multimodal_embedding_func=None,
                optimizer=fine_tune_optimizer,
                fine_tune_steps=FLAGS.fine_tune_steps,
                fine_tune_lr=FLAGS.fine_tune_lr)

            val_metric = f"{FLAGS.L}-way {FLAGS.K}-shot accuracy"

            if val_task_accuracy >= best_val_score:
                best_val_score = val_task_accuracy
                best_model = True

            # log results
            avg_loss_metric.reset_states()

            logging.log(logging.INFO, f"Step {global_step:03d}")
            logging.log(
                logging.INFO,
                f"Train: Step loss: {meta_loss:.6f}, "
                f"Average loss: {avg_loss:.6f}")
            logging.log(
                logging.INFO,
                f"Validation: {FLAGS.L}-way {FLAGS.K}-shot accuracy after "
                f"{FLAGS.episodes} episodes: {val_task_accuracy:.3%} +- "
                f"{conf_interval_95*100:.4f} {'*' if best_model else ''}")

            if tf_writer is not None:
                with tf_writer.as_default():
                    tf.summary.scalar(
                        f"Validation {FLAGS.L}-way {FLAGS.K}-shot accuracy",
                        val_task_accuracy, step=global_step)

            # store model and results
            model_utils.save_model(
                join_network_model, output_dir, global_step, global_step,
                val_metric, val_task_accuracy, best_val_score, name="model")

            if best_model:
                best_model = False
                model_utils.save_model(
                    join_network_model, output_dir, global_step, global_step,
                    val_metric, val_task_accuracy, best_val_score,
                    name="best_model")

        global_step += 1

        if global_step > model_options["meta_steps"]:
            logging.log(
                logging.INFO,
                f"Training complete after {global_step-1:03d} steps")
            break
def train(model_options, output_dir, model_file=None, model_step_file=None,
          tf_writer=None):
    """Create and train audio-visual similarity model for one-shot learning.

    Trains the speech and vision networks jointly with
    `base.WeaklySupervisedModel` on paired base-model embeddings read from
    TFRecord files. After every epoch the model is validated and the latest
    ("model") and best ("best_model") checkpoints are stored in `output_dir`.

    Validation uses the L-way K-shot dev task (higher is better) when
    `model_options["one_shot_validation"]` is set, and the dev loss (lower is
    better) otherwise.

    Args:
        model_options: dict of model/training options. Updated in place with
            the audio/vision base embedding sizes and input shapes.
        output_dir: directory that checkpoints and "model_options.json" are
            written to, and that `model_file` is loaded from.
        model_file: optional checkpoint file name (relative to `output_dir`)
            to resume training from.
        model_step_file: optional train-state file name (relative to
            `output_dir`) that accompanies `model_file`.
        tf_writer: optional TensorBoard summary writer; nothing is written to
            TensorBoard when it is None.
    """
    # load embeddings from (linear) dense layer of base speech and vision models
    speech_embed_dir = os.path.join(
        model_options["audio_base_dir"], "embed", "dense")
    image_embed_dir = os.path.join(
        model_options["vision_base_dir"], "embed", "dense")

    # load training data (embed dir determines mfcc/fbank speech features)
    train_exp, dev_exp = dataset.create_flickr_multimodal_train_data(
        "mfcc", speech_embed_dir=speech_embed_dir,
        image_embed_dir=image_embed_dir,
        speaker_mode=FLAGS.speaker_mode,
        unseen_match_set=FLAGS.unseen_match_set)

    train_speech_paths = train_exp.speech_experiment.data
    train_image_paths = train_exp.vision_experiment.data

    dev_speech_paths = dev_exp.speech_experiment.data
    dev_image_paths = dev_exp.vision_experiment.data

    # define preprocessing for base model embeddings
    preprocess_data_func = lambda example: dataset.parse_embedding_protobuf(
        example)["embed"]

    # create standard training dataset pipeline
    background_train_ds = tf.data.Dataset.zip((
        tf.data.TFRecordDataset(
            train_speech_paths, compression_type="ZLIB"),
        tf.data.TFRecordDataset(
            train_image_paths, compression_type="ZLIB")))

    # map data preprocessing function across training data
    background_train_ds = background_train_ds.map(
        lambda speech_path, image_path: (
            preprocess_data_func(speech_path),
            preprocess_data_func(image_path)),
        num_parallel_calls=8)

    # shuffle and batch train data; an "epoch" is a fixed number of batches
    # taken from the infinitely repeated stream
    background_train_ds = background_train_ds.repeat(-1)
    background_train_ds = background_train_ds.shuffle(1000)
    background_train_ds = background_train_ds.batch(
        model_options["batch_size"])
    background_train_ds = background_train_ds.take(
        model_options["num_batches"])
    background_train_ds = background_train_ds.prefetch(
        tf.data.experimental.AUTOTUNE)

    # create dev set pipeline for validation
    background_dev_ds = tf.data.Dataset.zip((
        tf.data.TFRecordDataset(
            dev_speech_paths, compression_type="ZLIB"),
        tf.data.TFRecordDataset(
            dev_image_paths, compression_type="ZLIB")))

    background_dev_ds = background_dev_ds.map(
        lambda speech_path, image_path: (
            preprocess_data_func(speech_path),
            preprocess_data_func(image_path)),
        num_parallel_calls=8)

    background_dev_ds = background_dev_ds.batch(
        batch_size=model_options["batch_size"])

    # get training objective
    triplet_loss = get_training_objective(model_options)

    # get model input shape from axis 1 of one training batch
    for speech_batch, image_batch in background_train_ds.take(1):
        model_options["audio_base_embed_size"] = int(
            tf.shape(speech_batch)[1].numpy())
        model_options["vision_base_embed_size"] = int(
            tf.shape(image_batch)[1].numpy())

    model_options["audio_input_shape"] = [
        model_options["audio_base_embed_size"]]
    model_options["vision_input_shape"] = [
        model_options["vision_base_embed_size"]]

    # load or create models
    if model_file is not None:
        join_network_model, train_state = model_utils.load_model(
            model_file=os.path.join(output_dir, model_file),
            model_step_file=os.path.join(output_dir, model_step_file),
            loss=get_training_objective(model_options))

        # split the joint model into its speech and vision branches
        speech_network = tf.keras.Model(
            inputs=join_network_model.inputs[0],
            outputs=join_network_model.outputs[0])
        vision_network = tf.keras.Model(
            inputs=join_network_model.inputs[1],
            outputs=join_network_model.outputs[1])

        # get previous tracking variables
        initial_model = False
        global_step, model_epochs, _, best_val_score = train_state
    else:
        speech_network = create_speech_network(model_options)
        vision_network = create_vision_network(model_options)

        # create tracking variables
        initial_model = True
        global_step = 0
        model_epochs = 0

        # accuracy is maximised, loss is minimised
        if model_options["one_shot_validation"]:
            best_val_score = -np.inf
        else:
            best_val_score = np.inf

    # load or create Adam optimizer with decayed learning rate
    lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
        model_options["learning_rate"],
        decay_rate=model_options["decay_rate"],
        decay_steps=model_options["decay_steps"],
        staircase=True)

    if model_file is not None:
        logging.log(logging.INFO, "Restoring optimizer state")
        optimizer = join_network_model.optimizer
    else:
        optimizer = tf.keras.optimizers.Adam(learning_rate=lr_schedule)

    # compile joint model to store optimizer with model when saving
    join_network_model = tf.keras.Model(
        inputs=[speech_network.input, vision_network.input],
        outputs=[speech_network.output, vision_network.output])
    join_network_model.compile(optimizer=optimizer, loss=triplet_loss)

    # create few-shot model from speech and vision networks for background
    # training
    multimodal_model = base.WeaklySupervisedModel(
        speech_network, vision_network, triplet_loss)

    # test model on one-shot validation task prior to training
    if model_options["one_shot_validation"]:

        one_shot_dev_exp = flickr_multimodal.FlickrMultimodal(
            features="mfcc", keywords_split="background_dev",
            flickr8k_image_dir=os.path.join(
                "data", "external", "flickr8k_images"),
            speech_embed_dir=speech_embed_dir,
            image_embed_dir=image_embed_dir,
            speech_preprocess_func=data_preprocess_func,
            image_preprocess_func=data_preprocess_func,
            speaker_mode=FLAGS.speaker_mode,
            unseen_match_set=FLAGS.unseen_match_set)

        # create few-shot model from speech and vision networks for one-shot
        # validation
        if FLAGS.fine_tune_steps is not None:
            test_few_shot_model = create_fine_tune_model(
                model_options, speech_network, vision_network)
        else:
            test_few_shot_model = base.WeaklySupervisedModel(
                speech_network, vision_network, None,
                mc_dropout=FLAGS.mc_dropout)

        val_task_accuracy, _, conf_interval_95 = experiment.test_multimodal_l_way_k_shot(
            one_shot_dev_exp, FLAGS.K, FLAGS.L, n=FLAGS.N,
            num_episodes=FLAGS.episodes, k_neighbours=FLAGS.k_neighbours,
            metric=FLAGS.metric, direct_match=FLAGS.direct_match,
            multimodal_model=test_few_shot_model,
            multimodal_embedding_func=None,
            fine_tune_steps=FLAGS.fine_tune_steps,
            fine_tune_lr=FLAGS.fine_tune_lr)

        logging.log(
            logging.INFO,
            f"Base model: {FLAGS.L}-way {FLAGS.K}-shot accuracy after "
            f"{FLAGS.episodes} episodes: {val_task_accuracy:.3%} +- "
            f"{conf_interval_95*100:.4f}")

    # create training metrics
    loss_metric = tf.keras.metrics.Mean()
    best_model = False

    # store model options on first run
    if initial_model:
        file_io.write_json(
            os.path.join(output_dir, "model_options.json"), model_options)

    # train model (resumes from the stored epoch count when restoring)
    for epoch in range(model_epochs, model_options["epochs"]):
        logging.log(logging.INFO, f"Epoch {epoch:03d}")

        loss_metric.reset_states()

        # train on epoch of training data
        step_pbar = tqdm(background_train_ds,
                         bar_format="{desc} [{elapsed},{rate_fmt}{postfix}]")
        for step, (speech_batch, image_batch) in enumerate(step_pbar):

            loss_value, _y_speech, _y_image = multimodal_model.train_step(
                speech_batch, image_batch, optimizer,
                clip_norm=model_options["gradient_clip_norm"])

            loss_metric.update_state(loss_value)

            step_loss = tf.reduce_mean(loss_value)
            train_loss = loss_metric.result().numpy()

            step_pbar.set_description_str(
                f"\tStep {step:03d}: "
                f"Step loss: {step_loss:.6f}, "
                f"Loss: {train_loss:.6f}")

            if tf_writer is not None:
                with tf_writer.as_default():
                    tf.summary.scalar(
                        "Train step loss", step_loss, step=global_step)
            global_step += 1

        # validate multimodal model
        loss_metric.reset_states()

        for speech_batch, image_batch in background_dev_ds:
            y_speech = multimodal_model.speech_model.predict(
                speech_batch, training=False)
            y_image = multimodal_model.vision_model.predict(
                image_batch, training=False)
            loss_value = multimodal_model.loss(y_speech, y_image)

            loss_metric.update_state(loss_value)

        dev_loss = loss_metric.result().numpy()

        # validate model on one-shot dev task if specified
        if model_options["one_shot_validation"]:

            # rebuild the test model so it wraps the current network weights
            if FLAGS.fine_tune_steps is not None:
                test_few_shot_model = create_fine_tune_model(
                    model_options, speech_network, vision_network)
            else:
                test_few_shot_model = base.WeaklySupervisedModel(
                    speech_network, vision_network, None,
                    mc_dropout=FLAGS.mc_dropout)

            val_task_accuracy, _, conf_interval_95 = experiment.test_multimodal_l_way_k_shot(
                one_shot_dev_exp, FLAGS.K, FLAGS.L, n=FLAGS.N,
                num_episodes=FLAGS.episodes, k_neighbours=FLAGS.k_neighbours,
                metric=FLAGS.metric, direct_match=FLAGS.direct_match,
                multimodal_model=test_few_shot_model,
                multimodal_embedding_func=None,
                fine_tune_steps=FLAGS.fine_tune_steps,
                fine_tune_lr=FLAGS.fine_tune_lr)

            val_score = val_task_accuracy
            val_metric = f"{FLAGS.L}-way {FLAGS.K}-shot accuracy"

            # accuracy: higher is better
            if val_score >= best_val_score:
                best_val_score = val_score
                best_model = True

        # otherwise, validate on siamese task
        else:
            val_score = dev_loss
            val_metric = "loss"

            # loss: lower is better
            if val_score <= best_val_score:
                best_val_score = val_score
                best_model = True

        # log results
        logging.log(logging.INFO, f"Train: Loss: {train_loss:.6f}")
        logging.log(
            logging.INFO,
            f"Validation: Loss: {dev_loss:.6f} {'*' if best_model else ''}")

        if model_options["one_shot_validation"]:
            logging.log(
                logging.INFO,
                f"Validation: {FLAGS.L}-way {FLAGS.K}-shot accuracy after "
                f"{FLAGS.episodes} episodes: {val_task_accuracy:.3%} +- "
                f"{conf_interval_95*100:.4f} {'*' if best_model else ''}")

        if tf_writer is not None:
            with tf_writer.as_default():
                # NOTE(review): the epoch-level train loss shares its tag with
                # the per-step summary written in the training loop above.
                tf.summary.scalar(
                    "Train step loss", train_loss, step=global_step)
                tf.summary.scalar(
                    "Validation loss", dev_loss, step=global_step)
                if model_options["one_shot_validation"]:
                    tf.summary.scalar(
                        f"Validation {FLAGS.L}-way {FLAGS.K}-shot accuracy",
                        val_task_accuracy, step=global_step)

        # store joint model and results
        model_utils.save_model(
            join_network_model, output_dir, epoch + 1, global_step,
            val_metric, val_score, best_val_score, name="model")

        if best_model:
            best_model = False
            model_utils.save_model(
                join_network_model, output_dir, epoch + 1, global_step,
                val_metric, val_score, best_val_score, name="best_model")