def test():
    """Evaluate the DTW matching baseline on the one-shot speech evaluation split.

    Builds the Flickr Audio one-shot experiment from FLAGS, runs an L-way
    K-shot evaluation with DTW matching enabled, and logs the task accuracy
    together with its 95% confidence interval.
    """
    # held-out keywords used for the final one-shot evaluation
    evaluation_exp = flickr_speech.FlickrSpeech(
        features=FLAGS.features,
        keywords_split="one_shot_evaluation",
        preprocess_func=data_preprocess_func,
        speaker_mode=FLAGS.speaker_mode)

    accuracy, _, ci_95 = experiment.test_l_way_k_shot(
        evaluation_exp, FLAGS.K, FLAGS.L,
        n=FLAGS.N,
        num_episodes=FLAGS.episodes,
        k_neighbours=FLAGS.k_neighbours,
        metric=FLAGS.metric,
        dtw=True,
        random=FLAGS.random)

    summary = (f"{FLAGS.L}-way {FLAGS.K}-shot accuracy after {FLAGS.episodes} "
               f"episodes: {accuracy:.3%} +- {ci_95*100:.4f}")
    logging.log(logging.INFO, summary)
def test(model_options, output_dir, model_file, model_step_file):
    """Load a saved spoken word classification model and run one-shot testing.

    Args:
        model_options: Dict of model options. "features" selects the speech
            features of the evaluation experiment. The dict is also passed to
            get_data_preprocess_func, get_training_objective,
            create_embedding_model and create_fine_tune_model.
        output_dir: Directory that holds the saved model files.
        model_file: Saved model file name, relative to output_dir.
        model_step_file: Saved training-state file name, relative to
            output_dir.
    """
    # evaluation split of the Flickr Audio one-shot experiment
    evaluation_exp = flickr_speech.FlickrSpeech(
        features=model_options["features"],
        keywords_split="one_shot_evaluation",
        preprocess_func=get_data_preprocess_func(model_options),
        speaker_mode=FLAGS.speaker_mode)

    # restore the trained network; the training state is not needed here
    network, _ = model_utils.load_model(
        model_file=os.path.join(output_dir, model_file),
        model_step_file=os.path.join(output_dir, model_step_file),
        loss=get_training_objective(model_options))

    def embedding_model_func(speech_network):
        # bind model_options so the experiment only supplies the network
        return create_embedding_model(model_options, speech_network)

    # wrap the network either for per-episode fine-tuning or as-is
    if FLAGS.fine_tune_steps is None:
        few_shot_model = base.BaseModel(
            network, None, mc_dropout=FLAGS.mc_dropout)
    else:
        few_shot_model = create_fine_tune_model(
            model_options, network, num_classes=FLAGS.L)

    # classification mode requires a logits/softmax embedding layer
    if FLAGS.classification:
        assert FLAGS.embed_layer in ["logits", "softmax"]
    use_classification = bool(FLAGS.classification)

    logging.log(logging.INFO, "Created few-shot model from speech network")
    few_shot_model.model.summary()

    accuracy, _, ci_95 = experiment.test_l_way_k_shot(
        evaluation_exp, FLAGS.K, FLAGS.L,
        n=FLAGS.N,
        num_episodes=FLAGS.episodes,
        k_neighbours=FLAGS.k_neighbours,
        metric=FLAGS.metric,
        classification=use_classification,
        model=few_shot_model,
        embedding_model_func=embedding_model_func,
        fine_tune_steps=FLAGS.fine_tune_steps,
        fine_tune_lr=FLAGS.fine_tune_lr)

    summary = (f"{FLAGS.L}-way {FLAGS.K}-shot accuracy after {FLAGS.episodes} "
               f"episodes: {accuracy:.3%} +- {ci_95*100:.4f}")
    logging.log(logging.INFO, summary)
def train(model_options, output_dir, model_file=None, model_step_file=None,
          tf_writer=None):
    """Create and train spoken word similarity model for one-shot learning.

    The network is trained on embeddings read from the dense layer of a base
    model (``<base_dir>/embed/dense``) with the objective returned by
    ``get_training_objective``. After every epoch the model is validated and
    saved as "model". Validation uses the siamese dev loss, or an L-way
    K-shot dev task when ``one_shot_validation`` is set. The best model so
    far is also saved as "best_model".

    Args:
        model_options: Dict of model options. Keys read include "base_dir",
            "balanced", "p", "k", "num_batches", "num_augment", "batch_size",
            "learning_rate", "decay_rate", "decay_steps",
            "one_shot_validation", "epochs" and "gradient_clip_norm".
            Modified in place: "base_embed_size" and "input_shape" are set,
            and "batch_size" is overwritten with p * k when "balanced" is set.
        output_dir: Directory for the balanced-batch cache, the model options
            JSON and saved models.
        model_file: Optional saved model (relative to output_dir) to resume
            training from.
        model_step_file: Saved training state that accompanies model_file.
        tf_writer: Optional TensorBoard summary writer.
    """
    # load embeddings from dense layer of base model
    embed_dir = os.path.join(model_options["base_dir"], "embed", "dense")

    # load training data (embed dir determines if mfcc/fbank)
    train_exp, dev_exp = dataset.create_flickr_audio_train_data(
        "mfcc", embed_dir=embed_dir, speaker_mode=FLAGS.speaker_mode)

    # dev keywords are labelled with the training keyword -> label mapping
    train_labels = np.asarray([
        train_exp.keyword_labels[keyword]
        for keyword in train_exp.keywords_set[3]])
    dev_labels = np.asarray([
        train_exp.keyword_labels[keyword]
        for keyword in dev_exp.keywords_set[3]])

    train_paths = train_exp.embed_paths
    dev_paths = dev_exp.embed_paths

    # define preprocessing for base model embeddings
    def preprocess_data_func(example):
        return dataset.parse_embedding_protobuf(example)["embed"]

    # create balanced batch training dataset pipeline
    if model_options["balanced"]:
        assert model_options["p"] is not None
        assert model_options["k"] is not None

        shuffle_train = False
        prefetch_train = False
        num_repeat = model_options["num_batches"]
        model_options["batch_size"] = model_options["p"] * model_options["k"]

        # get unique path train indices per unique label
        train_labels_series = pd.Series(train_labels)
        train_label_idx = {
            label: idx.values[
                np.unique(train_paths[idx.values], return_index=True)[1]]
            for label, idx in train_labels_series.groupby(
                train_labels_series).groups.items()}

        # cache paths to speed things up a little ...
        file_io.check_create_dir(os.path.join(output_dir, "cache"))

        # create a dataset for each unique keyword label (shuffled and cached)
        train_label_datasets = [
            tf.data.Dataset.zip((
                tf.data.Dataset.from_tensor_slices(train_paths[idx]),
                tf.data.Dataset.from_tensor_slices(train_labels[idx]))).cache(
                    os.path.join(
                        output_dir, "cache", str(label))).shuffle(20)  # len(idx)
            for label, idx in train_label_idx.items()]

        # create a dataset that samples balanced batches from the label datasets
        background_train_ds = dataset.create_balanced_batch_dataset(
            model_options["p"], model_options["k"], train_label_datasets)

    # create standard training dataset pipeline (shuffle and load training set)
    else:
        shuffle_train = True
        prefetch_train = True
        num_repeat = model_options["num_augment"]

        background_train_ds = tf.data.Dataset.zip((
            tf.data.Dataset.from_tensor_slices(train_paths),
            tf.data.Dataset.from_tensor_slices(train_labels)))

    # load embedding TFRecords (faster here than before balanced sampling)
    # batch to read files in parallel
    background_train_ds = background_train_ds.batch(
        model_options["batch_size"])

    background_train_ds = background_train_ds.flat_map(
        lambda paths, labels: tf.data.Dataset.zip((
            tf.data.TFRecordDataset(
                paths, compression_type="ZLIB", num_parallel_reads=8),
            tf.data.Dataset.from_tensor_slices(labels))))

    # map data preprocessing function across training data
    background_train_ds = background_train_ds.map(
        lambda data, label: (preprocess_data_func(data), label),
        num_parallel_calls=tf.data.experimental.AUTOTUNE)

    # repeat augmentation, shuffle and batch train data
    if num_repeat is not None:
        background_train_ds = background_train_ds.repeat(num_repeat)

    if shuffle_train:
        background_train_ds = background_train_ds.shuffle(1000)

    background_train_ds = background_train_ds.batch(
        model_options["batch_size"])

    if prefetch_train:
        background_train_ds = background_train_ds.prefetch(
            tf.data.experimental.AUTOTUNE)

    # create dev set pipeline for siamese validation
    background_dev_ds = tf.data.Dataset.zip((
        tf.data.TFRecordDataset(
            dev_paths, compression_type="ZLIB",
            num_parallel_reads=8).map(preprocess_data_func),
        tf.data.Dataset.from_tensor_slices(dev_labels)))

    background_dev_ds = background_dev_ds.batch(
        batch_size=model_options["batch_size"])

    # get training objective
    triplet_loss = get_training_objective(model_options)

    # get model input shape from the first training batch
    for x_batch, _ in background_train_ds.take(1):
        model_options["base_embed_size"] = int(
            tf.shape(x_batch)[1].numpy())

        model_options["input_shape"] = [model_options["base_embed_size"]]

    # load or create model
    if model_file is not None:
        speech_network, train_state = model_utils.load_model(
            model_file=os.path.join(output_dir, model_file),
            model_step_file=os.path.join(output_dir, model_step_file),
            loss=triplet_loss)

        # get previous tracking variables
        initial_model = False
        global_step, model_epochs, _, best_val_score = train_state
    else:
        speech_network = create_speech_network(model_options)

        # create tracking variables: one-shot accuracy is maximised, while
        # the siamese dev loss is minimised
        initial_model = True
        global_step = 0
        model_epochs = 0

        if model_options["one_shot_validation"]:
            best_val_score = -np.inf
        else:
            best_val_score = np.inf

    # load or create Adam optimizer with decayed learning rate
    lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
        model_options["learning_rate"],
        decay_rate=model_options["decay_rate"],
        decay_steps=model_options["decay_steps"],
        staircase=True)

    if model_file is not None:
        logging.log(logging.INFO, "Restoring optimizer state")
        optimizer = speech_network.optimizer
    else:
        optimizer = tf.keras.optimizers.Adam(learning_rate=lr_schedule)

    # compile model to store optimizer with model when saving
    speech_network.compile(optimizer=optimizer, loss=triplet_loss)

    # create few-shot model from speech network for background training
    speech_few_shot_model = base.BaseModel(speech_network, triplet_loss)

    # test model on one-shot validation task prior to training
    if model_options["one_shot_validation"]:
        one_shot_dev_exp = flickr_speech.FlickrSpeech(
            features="mfcc",
            keywords_split="background_dev",
            preprocess_func=get_data_preprocess_func(),
            embed_dir=embed_dir,
            speaker_mode=FLAGS.speaker_mode)

        embedding_model_func = create_embedding_model

        classification = False
        if FLAGS.classification:
            assert FLAGS.fine_tune_steps is not None
            classification = True

        # create few-shot model from speech network for one-shot validation
        if FLAGS.fine_tune_steps is not None:
            test_few_shot_model = create_fine_tune_model(
                model_options, speech_few_shot_model.model,
                num_classes=FLAGS.L)
        else:
            test_few_shot_model = base.BaseModel(
                speech_few_shot_model.model, None,
                mc_dropout=FLAGS.mc_dropout)

        val_task_accuracy, _, conf_interval_95 = experiment.test_l_way_k_shot(
            one_shot_dev_exp, FLAGS.K, FLAGS.L, n=FLAGS.N,
            num_episodes=FLAGS.episodes, k_neighbours=FLAGS.k_neighbours,
            metric=FLAGS.metric, classification=classification,
            model=test_few_shot_model,
            embedding_model_func=embedding_model_func,
            fine_tune_steps=FLAGS.fine_tune_steps,
            fine_tune_lr=FLAGS.fine_tune_lr)

        logging.log(
            logging.INFO,
            f"Base model: {FLAGS.L}-way {FLAGS.K}-shot accuracy after "
            f"{FLAGS.episodes} episodes: {val_task_accuracy:.3%} +- "
            f"{conf_interval_95*100:.4f}")

    # create training metrics
    loss_metric = tf.keras.metrics.Mean()
    best_model = False

    # store model options on first run
    if initial_model:
        file_io.write_json(
            os.path.join(output_dir, "model_options.json"), model_options)

    # train model
    for epoch in range(model_epochs, model_options["epochs"]):
        logging.log(logging.INFO, f"Epoch {epoch:03d}")

        loss_metric.reset_states()

        # train on epoch of training data
        step_pbar = tqdm(background_train_ds,
                         bar_format="{desc} [{elapsed},{rate_fmt}{postfix}]")
        for step, (x_batch, y_batch) in enumerate(step_pbar):

            loss_value, _ = speech_few_shot_model.train_step(
                x_batch, y_batch, optimizer,
                clip_norm=model_options["gradient_clip_norm"])

            loss_metric.update_state(loss_value)

            step_loss = tf.reduce_mean(loss_value)
            train_loss = loss_metric.result().numpy()

            step_pbar.set_description_str(
                f"\tStep {step:03d}: "
                f"Step loss: {step_loss:.6f}, "
                f"Loss: {train_loss:.6f}")

            if tf_writer is not None:
                with tf_writer.as_default():
                    tf.summary.scalar(
                        "Train step loss", step_loss, step=global_step)

            global_step += 1

        # validate siamese model
        loss_metric.reset_states()

        for x_batch, y_batch in background_dev_ds:
            y_predict = speech_few_shot_model.predict(x_batch, training=False)
            loss_value = speech_few_shot_model.loss(y_batch, y_predict)

            loss_metric.update_state(loss_value)

        dev_loss = loss_metric.result().numpy()

        # validate model on one-shot dev task if specified
        if model_options["one_shot_validation"]:

            if FLAGS.fine_tune_steps is not None:
                test_few_shot_model = create_fine_tune_model(
                    model_options, speech_few_shot_model.model,
                    num_classes=FLAGS.L)
            else:
                test_few_shot_model = base.BaseModel(
                    speech_few_shot_model.model, None,
                    mc_dropout=FLAGS.mc_dropout)

            val_task_accuracy, _, conf_interval_95 = experiment.test_l_way_k_shot(
                one_shot_dev_exp, FLAGS.K, FLAGS.L, n=FLAGS.N,
                num_episodes=FLAGS.episodes, k_neighbours=FLAGS.k_neighbours,
                metric=FLAGS.metric, classification=classification,
                model=test_few_shot_model,
                embedding_model_func=embedding_model_func,
                fine_tune_steps=FLAGS.fine_tune_steps,
                fine_tune_lr=FLAGS.fine_tune_lr)

            val_score = val_task_accuracy
            val_metric = f"{FLAGS.L}-way {FLAGS.K}-shot accuracy"

            if val_score >= best_val_score:
                best_val_score = val_score
                best_model = True

        # otherwise, validate on siamese task
        else:
            val_score = dev_loss
            val_metric = "loss"

            if val_score <= best_val_score:
                best_val_score = val_score
                best_model = True

        # log results
        logging.log(logging.INFO, f"Train: Loss: {train_loss:.6f}")

        logging.log(
            logging.INFO,
            f"Validation: Loss: {dev_loss:.6f} {'*' if best_model else ''}")

        if model_options["one_shot_validation"]:
            logging.log(
                logging.INFO,
                f"Validation: {FLAGS.L}-way {FLAGS.K}-shot accuracy after "
                f"{FLAGS.episodes} episodes: {val_task_accuracy:.3%} +- "
                f"{conf_interval_95*100:.4f} {'*' if best_model else ''}")

        if tf_writer is not None:
            with tf_writer.as_default():
                # epoch-level loss gets its own tag: logging it as
                # "Train step loss" collided with the next epoch's first
                # step, which is written at the same global_step
                tf.summary.scalar(
                    "Train loss", train_loss, step=global_step)
                tf.summary.scalar(
                    "Validation loss", dev_loss, step=global_step)
                if model_options["one_shot_validation"]:
                    tf.summary.scalar(
                        f"Validation {FLAGS.L}-way {FLAGS.K}-shot accuracy",
                        val_task_accuracy, step=global_step)

        # store model and results
        model_utils.save_model(
            speech_few_shot_model.model, output_dir, epoch + 1, global_step,
            val_metric, val_score, best_val_score, name="model")

        if best_model:
            best_model = False
            model_utils.save_model(
                speech_few_shot_model.model, output_dir, epoch + 1,
                global_step, val_metric, val_score, best_val_score,
                name="best_model")
def train(model_options, output_dir, model_file=None, model_step_file=None,
          tf_writer=None):
    """Create and train spoken word classification model for one-shot learning.

    Trains a speech network to classify the background training keywords
    from one-hot labels. After every epoch the model is validated and saved
    as "model". Validation uses dev categorical accuracy, or an L-way K-shot
    dev task when ``one_shot_validation`` is set. The best model so far is
    also saved as "best_model".

    Args:
        model_options: Dict of model options. Keys read include "features",
            "max_length", "scaling", "batch_size", "learning_rate",
            "decay_rate", "decay_steps", "one_shot_validation", "epochs" and
            "gradient_clip_norm". Modified in place: "input_shape" is set, and
            "n_classes" is set for new models (asserted when resuming).
        output_dir: Directory for the model options JSON and saved models.
        model_file: Optional saved model (relative to output_dir) to resume
            training from.
        model_step_file: Saved training state that accompanies model_file.
        tf_writer: Optional TensorBoard summary writer.
    """
    # load training data
    train_exp, dev_exp = dataset.create_flickr_audio_train_data(
        model_options["features"], speaker_mode=FLAGS.speaker_mode)

    # dev keywords are labelled with the training keyword -> label mapping
    train_labels = np.asarray([
        train_exp.keyword_labels[keyword]
        for keyword in train_exp.keywords_set[3]])
    dev_labels = np.asarray([
        train_exp.keyword_labels[keyword]
        for keyword in dev_exp.keywords_set[3]])

    train_paths = train_exp.audio_paths
    dev_paths = dev_exp.audio_paths

    # one-hot encode labels; the binarizer is fit on the training labels only
    lb = LabelBinarizer()
    train_labels_one_hot = lb.fit_transform(train_labels)
    dev_labels_one_hot = lb.transform(dev_labels)

    # define preprocessing for speech features
    preprocess_speech_func = functools.partial(
        dataset.load_and_preprocess_speech,
        features=model_options["features"],
        max_length=model_options["max_length"],
        scaling=model_options["scaling"])

    # wrap the python preprocessing function so it can run inside tf.data
    preprocess_speech_ds_func = lambda path: tf.py_function(
        func=preprocess_speech_func, inp=[path], Tout=tf.float32)

    # create standard training dataset pipeline
    background_train_ds = tf.data.Dataset.zip((
        tf.data.Dataset.from_tensor_slices(train_paths),
        tf.data.Dataset.from_tensor_slices(train_labels_one_hot)))

    # map data preprocessing function across training data
    background_train_ds = background_train_ds.map(
        lambda path, label: (preprocess_speech_ds_func(path), label),
        num_parallel_calls=tf.data.experimental.AUTOTUNE)

    # shuffle and batch train data
    background_train_ds = background_train_ds.shuffle(1000)

    background_train_ds = background_train_ds.batch(
        model_options["batch_size"])

    background_train_ds = background_train_ds.prefetch(
        tf.data.experimental.AUTOTUNE)

    # create dev set pipeline for classification validation
    background_dev_ds = tf.data.Dataset.zip((
        tf.data.Dataset.from_tensor_slices(dev_paths).map(
            preprocess_speech_ds_func),
        tf.data.Dataset.from_tensor_slices(dev_labels_one_hot)))

    background_dev_ds = background_dev_ds.batch(
        batch_size=model_options["batch_size"])

    # write example batch to TensorBoard
    if tf_writer is not None:
        logging.log(logging.INFO, "Writing example features to TensorBoard")
        # NOTE(review): assumes keyword_labels maps each keyword to its index
        # in train_exp.keywords (the original lookup relied on this) - confirm
        keywords = np.asarray(train_exp.keywords)
        with tf_writer.as_default():
            for x_batch, y_batch in background_train_ds.take(1):

                speech_feats = []
                for feats in x_batch[:30]:
                    feats = np.transpose(np.asarray(feats))
                    # min-max scale into [0, 1] for tf.summary.image; the old
                    # code divided by max (not max - min); constant features
                    # become zeros instead of dividing by zero
                    feats_min = np.min(feats)
                    feats_range = np.max(feats) - feats_min
                    if feats_range > 0:
                        feats = (feats - feats_min) / feats_range
                    else:
                        feats = np.zeros_like(feats)
                    speech_feats.append(feats)

                tf.summary.image(
                    f"Example train speech {model_options['features']}",
                    np.expand_dims(speech_feats, axis=-1), max_outputs=30,
                    step=0)

                # decode one-hot rows to label values before indexing the
                # keywords; indexing with the one-hot vector itself only ever
                # selects keywords 0 and 1
                batch_labels = lb.inverse_transform(np.asarray(y_batch[:30]))
                labels = "".join(
                    f"{i}: {keywords[label]} "
                    for i, label in enumerate(batch_labels))

                tf.summary.text("Example train labels", labels, step=0)

    # get training objective
    loss = get_training_objective(model_options)

    # get model input shape (39 features for mfcc, otherwise 40)
    num_features = 39 if model_options["features"] == "mfcc" else 40
    model_options["input_shape"] = [model_options["max_length"], num_features]

    # load or create model
    if model_file is not None:
        assert model_options["n_classes"] == len(train_exp.keywords)

        speech_network, train_state = model_utils.load_model(
            model_file=os.path.join(output_dir, model_file),
            model_step_file=os.path.join(output_dir, model_step_file),
            loss=loss)

        # get previous tracking variables
        initial_model = False
        global_step, model_epochs, _, best_val_score = train_state
    else:
        model_options["n_classes"] = len(train_exp.keywords)

        speech_network = create_speech_network(model_options)

        # create tracking variables; both validation scores (one-shot accuracy
        # and categorical accuracy) are maximised with `>=`, so start at -inf
        # (starting at +inf meant no model was ever flagged as best)
        initial_model = True
        global_step = 0
        model_epochs = 0
        best_val_score = -np.inf

    # load or create Adam optimizer with decayed learning rate
    lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
        model_options["learning_rate"],
        decay_rate=model_options["decay_rate"],
        decay_steps=model_options["decay_steps"],
        staircase=True)

    if model_file is not None:
        logging.log(logging.INFO, "Restoring optimizer state")
        optimizer = speech_network.optimizer
    else:
        optimizer = tf.keras.optimizers.Adam(learning_rate=lr_schedule)

    # compile model to store optimizer with model when saving
    speech_network.compile(optimizer=optimizer, loss=loss)

    # create few-shot model from speech network for background training
    speech_few_shot_model = base.BaseModel(speech_network, loss)

    # test model on one-shot validation task prior to training
    if model_options["one_shot_validation"]:

        one_shot_dev_exp = flickr_speech.FlickrSpeech(
            features=model_options["features"],
            keywords_split="background_dev",
            preprocess_func=get_data_preprocess_func(model_options),
            speaker_mode=FLAGS.speaker_mode)

        embedding_model_func = lambda speech_network: create_embedding_model(
            model_options, speech_network)

        classification = False
        if FLAGS.classification:
            assert FLAGS.embed_layer in ["logits", "softmax"]
            classification = True

        # create few-shot model from speech network for one-shot validation
        if FLAGS.fine_tune_steps is not None:
            test_few_shot_model = create_fine_tune_model(
                model_options, speech_few_shot_model.model,
                num_classes=FLAGS.L)
        else:
            test_few_shot_model = base.BaseModel(
                speech_few_shot_model.model, None,
                mc_dropout=FLAGS.mc_dropout)

        val_task_accuracy, _, conf_interval_95 = experiment.test_l_way_k_shot(
            one_shot_dev_exp, FLAGS.K, FLAGS.L, n=FLAGS.N,
            num_episodes=FLAGS.episodes, k_neighbours=FLAGS.k_neighbours,
            metric=FLAGS.metric, classification=classification,
            model=test_few_shot_model,
            embedding_model_func=embedding_model_func,
            fine_tune_steps=FLAGS.fine_tune_steps,
            fine_tune_lr=FLAGS.fine_tune_lr)

        logging.log(
            logging.INFO,
            f"Base model: {FLAGS.L}-way {FLAGS.K}-shot accuracy after "
            f"{FLAGS.episodes} episodes: {val_task_accuracy:.3%} +- "
            f"{conf_interval_95*100:.4f}")

    # create training metrics
    accuracy_metric = tf.keras.metrics.CategoricalAccuracy()
    loss_metric = tf.keras.metrics.Mean()
    best_model = False

    # store model options on first run
    if initial_model:
        file_io.write_json(
            os.path.join(output_dir, "model_options.json"), model_options)

    # train model
    for epoch in range(model_epochs, model_options["epochs"]):
        logging.log(logging.INFO, f"Epoch {epoch:03d}")

        accuracy_metric.reset_states()
        loss_metric.reset_states()

        # train on epoch of training data
        step_pbar = tqdm(background_train_ds,
                         bar_format="{desc} [{elapsed},{rate_fmt}{postfix}]")
        for step, (x_batch, y_batch) in enumerate(step_pbar):

            loss_value, y_predict = speech_few_shot_model.train_step(
                x_batch, y_batch, optimizer,
                clip_norm=model_options["gradient_clip_norm"])

            accuracy_metric.update_state(y_batch, y_predict)
            loss_metric.update_state(loss_value)

            step_loss = tf.reduce_mean(loss_value)
            train_loss = loss_metric.result().numpy()
            train_accuracy = accuracy_metric.result().numpy()

            step_pbar.set_description_str(
                f"\tStep {step:03d}: "
                f"Step loss: {step_loss:.6f}, "
                f"Loss: {train_loss:.6f}, "
                f"Categorical accuracy: {train_accuracy:.3%}")

            if tf_writer is not None:
                with tf_writer.as_default():
                    tf.summary.scalar(
                        "Train step loss", step_loss, step=global_step)

            global_step += 1

        # validate classification model
        accuracy_metric.reset_states()
        loss_metric.reset_states()

        for x_batch, y_batch in background_dev_ds:
            y_predict = speech_few_shot_model.predict(x_batch, training=False)
            loss_value = speech_few_shot_model.loss(y_batch, y_predict)

            accuracy_metric.update_state(y_batch, y_predict)
            loss_metric.update_state(loss_value)

        dev_loss = loss_metric.result().numpy()
        dev_accuracy = accuracy_metric.result().numpy()

        # validate model on one-shot dev task if specified
        if model_options["one_shot_validation"]:

            if FLAGS.fine_tune_steps is not None:
                test_few_shot_model = create_fine_tune_model(
                    model_options, speech_few_shot_model.model,
                    num_classes=FLAGS.L)
            else:
                test_few_shot_model = base.BaseModel(
                    speech_few_shot_model.model, None,
                    mc_dropout=FLAGS.mc_dropout)

            val_task_accuracy, _, conf_interval_95 = experiment.test_l_way_k_shot(
                one_shot_dev_exp, FLAGS.K, FLAGS.L, n=FLAGS.N,
                num_episodes=FLAGS.episodes, k_neighbours=FLAGS.k_neighbours,
                metric=FLAGS.metric, classification=classification,
                model=test_few_shot_model,
                embedding_model_func=embedding_model_func,
                fine_tune_steps=FLAGS.fine_tune_steps,
                fine_tune_lr=FLAGS.fine_tune_lr)

            val_score = val_task_accuracy
            val_metric = f"{FLAGS.L}-way {FLAGS.K}-shot accuracy"

        # otherwise, validate on classification task
        else:
            val_score = dev_accuracy
            val_metric = "categorical accuracy"

        if val_score >= best_val_score:
            best_val_score = val_score
            best_model = True

        # log results
        logging.log(
            logging.INFO,
            f"Train: Loss: {train_loss:.6f}, Categorical accuracy: "
            f"{train_accuracy:.3%}")

        logging.log(
            logging.INFO,
            f"Validation: Loss: {dev_loss:.6f}, Categorical accuracy: "
            f"{dev_accuracy:.3%} {'*' if best_model else ''}")

        if model_options["one_shot_validation"]:
            logging.log(
                logging.INFO,
                f"Validation: {FLAGS.L}-way {FLAGS.K}-shot accuracy after "
                f"{FLAGS.episodes} episodes: {val_task_accuracy:.3%} +- "
                f"{conf_interval_95*100:.4f} {'*' if best_model else ''}")

        if tf_writer is not None:
            with tf_writer.as_default():
                # epoch-level loss gets its own tag so it does not collide
                # with the per-step "Train step loss" series
                tf.summary.scalar(
                    "Train loss", train_loss, step=global_step)
                tf.summary.scalar(
                    "Train categorical accuracy", train_accuracy,
                    step=global_step)
                tf.summary.scalar(
                    "Validation loss", dev_loss, step=global_step)
                tf.summary.scalar(
                    "Validation categorical accuracy", dev_accuracy,
                    step=global_step)
                if model_options["one_shot_validation"]:
                    tf.summary.scalar(
                        f"Validation {FLAGS.L}-way {FLAGS.K}-shot accuracy",
                        val_task_accuracy, step=global_step)

        # store model and results
        model_utils.save_model(
            speech_few_shot_model.model, output_dir, epoch + 1, global_step,
            val_metric, val_score, best_val_score, name="model")

        if best_model:
            best_model = False
            model_utils.save_model(
                speech_few_shot_model.model, output_dir, epoch + 1,
                global_step, val_metric, val_score, best_val_score,
                name="best_model")
def test(model_options, output_dir, model_file, model_step_file):
    """Load a saved siamese image similarity model and run one-shot testing.

    Args:
        model_options: Dict of model options. When "use_embeddings" is set,
            images are represented by the base model embeddings found under
            ``<base_dir>/embed/dense``. The dict is also passed to
            get_data_preprocess_func, get_training_objective and
            create_fine_tune_model.
        output_dir: Directory that holds the saved model files.
        model_file: Saved model file name, relative to output_dir.
        model_step_file: Saved training-state file name, relative to
            output_dir.
    """
    # use dense-layer embeddings of the base model if requested; None means
    # the images themselves are embedded
    embed_dir = (
        os.path.join(model_options["base_dir"], "embed", "dense")
        if model_options["use_embeddings"] else None)

    # evaluation split of the Flickr 8k one-shot experiment
    evaluation_exp = flickr_vision.FlickrVision(
        keywords_split="one_shot_evaluation",
        flickr8k_image_dir=os.path.join("data", "external", "flickr8k_images"),
        preprocess_func=get_data_preprocess_func(model_options),
        embed_dir=embed_dir)

    # restore the trained network; the training state is not needed here
    network, _ = model_utils.load_model(
        model_file=os.path.join(output_dir, model_file),
        model_step_file=os.path.join(output_dir, model_step_file),
        loss=get_training_objective(model_options))

    # wrap the network either for per-episode fine-tuning or as-is
    if FLAGS.fine_tune_steps is None:
        few_shot_model = base.BaseModel(
            network, None, mc_dropout=FLAGS.mc_dropout)
    else:
        few_shot_model = create_fine_tune_model(
            model_options, network, num_classes=FLAGS.L)

    # classification mode is only supported together with fine-tuning
    if FLAGS.classification:
        assert FLAGS.fine_tune_steps is not None
    use_classification = bool(FLAGS.classification)

    logging.log(logging.INFO, "Created few-shot model from vision network")
    few_shot_model.model.summary()

    accuracy, _, ci_95 = experiment.test_l_way_k_shot(
        evaluation_exp, FLAGS.K, FLAGS.L,
        n=FLAGS.N,
        num_episodes=FLAGS.episodes,
        k_neighbours=FLAGS.k_neighbours,
        metric=FLAGS.metric,
        classification=use_classification,
        model=few_shot_model,
        embedding_model_func=create_embedding_model,
        fine_tune_steps=FLAGS.fine_tune_steps,
        fine_tune_lr=FLAGS.fine_tune_lr)

    summary = (f"{FLAGS.L}-way {FLAGS.K}-shot accuracy after {FLAGS.episodes} "
               f"episodes: {accuracy:.3%} +- {ci_95*100:.4f}")
    logging.log(logging.INFO, summary)
def train(model_options, output_dir, model_file=None, model_step_file=None,
          tf_writer=None):
    """Create and train image classification model for one-shot learning.

    Trains a vision network on multi-hot keyword labels per image. After
    every epoch the model is validated and saved as "model". Validation uses
    dev F-1, or an L-way K-shot dev task when ``one_shot_validation`` is set.
    The best model so far is also saved as "best_model".

    Args:
        model_options: Dict of model options. Keys read include "data",
            "crop_size", "augment_train", "random_scales", "horizontal_flip",
            "colour", "num_augment", "batch_size", "learning_rate",
            "decay_rate", "decay_steps", "one_shot_validation", "epochs" and
            "gradient_clip_norm". Modified in place: "input_shape" is set,
            and "n_classes" is set for new models (asserted when resuming).
        output_dir: Directory for the model options JSON and saved models.
        model_file: Optional saved model (relative to output_dir) to resume
            training from.
        model_step_file: Saved training state that accompanies model_file.
        tf_writer: Optional TensorBoard summary writer.
    """

    def _f1_score(precision, recall):
        """Harmonic mean of precision and recall (0 when both are 0)."""
        # same value as 2 / (1/p + 1/r) without divide-by-zero warnings
        if precision + recall <= 0:
            return 0.0
        return 2 * precision * recall / (precision + recall)

    # load training data
    train_exp, dev_exp = dataset.create_flickr_vision_train_data(
        model_options["data"])

    # per-image label lists can differ in length, so keep them as lists:
    # np.asarray on ragged lists raises ValueError in NumPy >= 1.24
    # (MultiLabelBinarizer accepts any iterable of label iterables)
    train_labels = [
        [train_exp.keyword_labels[keyword] for keyword in image_keywords]
        for image_keywords in train_exp.unique_image_keywords]

    # dev images are labelled with the training keyword -> label mapping
    dev_labels = [
        [train_exp.keyword_labels[keyword] for keyword in image_keywords]
        for image_keywords in dev_exp.unique_image_keywords]

    train_paths = train_exp.unique_image_paths
    dev_paths = dev_exp.unique_image_paths

    # multi-hot encode labels; the binarizer is fit on training labels only
    mlb = MultiLabelBinarizer()
    train_labels_multi_hot = mlb.fit_transform(train_labels)
    dev_labels_multi_hot = mlb.transform(dev_labels)

    # define preprocessing for images
    preprocess_images_func = functools.partial(
        dataset.load_and_preprocess_image,
        crop_size=model_options["crop_size"],
        augment_crop=model_options["augment_train"],
        random_scales=model_options["random_scales"],
        horizontal_flip=model_options["horizontal_flip"],
        colour=model_options["colour"])

    # create standard training dataset pipeline
    background_train_ds = tf.data.Dataset.zip((
        tf.data.Dataset.from_tensor_slices(train_paths),
        tf.data.Dataset.from_tensor_slices(train_labels_multi_hot)))

    # map data preprocessing function across training data
    background_train_ds = background_train_ds.map(
        lambda path, label: (preprocess_images_func(path), label),
        num_parallel_calls=tf.data.experimental.AUTOTUNE)

    # repeat augmentation, shuffle and batch train data
    if model_options["num_augment"] is not None:
        background_train_ds = background_train_ds.repeat(
            model_options["num_augment"])

    background_train_ds = background_train_ds.shuffle(1000)

    background_train_ds = background_train_ds.batch(
        model_options["batch_size"])

    background_train_ds = background_train_ds.prefetch(
        tf.data.experimental.AUTOTUNE)

    # create dev set pipeline for classification validation
    background_dev_ds = tf.data.Dataset.zip((
        tf.data.Dataset.from_tensor_slices(
            dev_paths).map(preprocess_images_func),
        tf.data.Dataset.from_tensor_slices(dev_labels_multi_hot)))

    background_dev_ds = background_dev_ds.batch(
        batch_size=model_options["batch_size"])

    # write example batch to TensorBoard
    if tf_writer is not None:
        logging.log(logging.INFO, "Writing example images to TensorBoard")
        # NOTE(review): assumes keyword_labels maps each keyword to its index
        # in train_exp.keywords (the original lookup relied on this) - confirm
        keywords = np.asarray(train_exp.keywords)
        with tf_writer.as_default():
            for x_batch, y_batch in background_train_ds.take(1):
                tf.summary.image("Example train images", (x_batch+1)/2,
                                 max_outputs=30, step=0)

                # decode multi-hot rows to label values before indexing the
                # keywords; indexing with the 0/1 vector itself only ever
                # selects keywords 0 and 1
                batch_labels = mlb.inverse_transform(np.asarray(y_batch[:30]))
                labels = "".join(
                    f"{i}: {keywords[np.asarray(label, dtype=int)]} "
                    for i, label in enumerate(batch_labels))

                tf.summary.text("Example train labels", labels, step=0)

    # get training objective
    multi_label_loss = get_training_objective(model_options)

    # get model input shape
    model_options["input_shape"] = (
        model_options["crop_size"], model_options["crop_size"], 3)

    # load or create model
    if model_file is not None:
        assert model_options["n_classes"] == len(train_exp.keywords)

        vision_network, train_state = model_utils.load_model(
            model_file=os.path.join(output_dir, model_file),
            model_step_file=os.path.join(output_dir, model_step_file),
            loss=multi_label_loss)

        # get previous tracking variables
        initial_model = False
        global_step, model_epochs, _, best_val_score = train_state
    else:
        model_options["n_classes"] = len(train_exp.keywords)

        vision_network = create_vision_network(model_options)

        # create tracking variables; both validation scores (one-shot accuracy
        # and F-1) are maximised with `>=`, so start at -inf (starting at
        # +inf meant no model was ever flagged as best)
        initial_model = True
        global_step = 0
        model_epochs = 0
        best_val_score = -np.inf

    # load or create Adam optimizer with decayed learning rate
    lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
        model_options["learning_rate"],
        decay_rate=model_options["decay_rate"],
        decay_steps=model_options["decay_steps"],
        staircase=True)

    if model_file is not None:
        logging.log(logging.INFO, "Restoring optimizer state")
        optimizer = vision_network.optimizer
    else:
        optimizer = tf.keras.optimizers.Adam(learning_rate=lr_schedule)

    # compile model to store optimizer with model when saving
    vision_network.compile(optimizer=optimizer, loss=multi_label_loss)

    # create few-shot model from vision network for background training
    vision_few_shot_model = base.BaseModel(vision_network, multi_label_loss)

    # test model on one-shot validation task prior to training
    if model_options["one_shot_validation"]:

        one_shot_dev_exp = flickr_vision.FlickrVision(
            keywords_split="background_dev",
            flickr8k_image_dir=os.path.join(
                "data", "external", "flickr8k_images"),
            flickr30k_image_dir=os.path.join(
                "data", "external", "flickr30k_images"),
            mscoco_image_dir=os.path.join(
                "data", "external", "mscoco", "val2017"),
            preprocess_func=get_data_preprocess_func(model_options))

        embedding_model_func = lambda vision_network: create_embedding_model(
            model_options, vision_network)

        classification = False
        if FLAGS.classification:
            assert FLAGS.embed_layer in ["logits", "softmax"]
            classification = True

        # create few-shot model from vision network for one-shot validation
        if FLAGS.fine_tune_steps is not None:
            test_few_shot_model = create_fine_tune_model(
                model_options, vision_few_shot_model.model,
                num_classes=FLAGS.L)
        else:
            test_few_shot_model = base.BaseModel(
                vision_few_shot_model.model, None,
                mc_dropout=FLAGS.mc_dropout)

        val_task_accuracy, _, conf_interval_95 = experiment.test_l_way_k_shot(
            one_shot_dev_exp, FLAGS.K, FLAGS.L, n=FLAGS.N,
            num_episodes=FLAGS.episodes, k_neighbours=FLAGS.k_neighbours,
            metric=FLAGS.metric, classification=classification,
            model=test_few_shot_model,
            embedding_model_func=embedding_model_func,
            fine_tune_steps=FLAGS.fine_tune_steps,
            fine_tune_lr=FLAGS.fine_tune_lr)

        logging.log(
            logging.INFO,
            f"Base model: {FLAGS.L}-way {FLAGS.K}-shot accuracy after "
            f"{FLAGS.episodes} episodes: {val_task_accuracy:.3%} +- "
            f"{conf_interval_95*100:.4f}")

    # create training metrics
    precision_metric = tf.keras.metrics.Precision()
    recall_metric = tf.keras.metrics.Recall()
    loss_metric = tf.keras.metrics.Mean()
    best_model = False

    # store model options on first run
    if initial_model:
        file_io.write_json(
            os.path.join(output_dir, "model_options.json"), model_options)

    # train model
    for epoch in range(model_epochs, model_options["epochs"]):
        logging.log(logging.INFO, f"Epoch {epoch:03d}")

        precision_metric.reset_states()
        recall_metric.reset_states()
        loss_metric.reset_states()

        # train on epoch of training data
        step_pbar = tqdm(background_train_ds,
                         bar_format="{desc} [{elapsed},{rate_fmt}{postfix}]")
        for step, (x_batch, y_batch) in enumerate(step_pbar):

            loss_value, y_predict = vision_few_shot_model.train_step(
                x_batch, y_batch, optimizer,
                clip_norm=model_options["gradient_clip_norm"])

            # threshold sigmoid outputs at 0.5 to get multi-hot predictions
            y_one_hot_predict = tf.round(tf.nn.sigmoid(y_predict))

            precision_metric.update_state(y_batch, y_one_hot_predict)
            recall_metric.update_state(y_batch, y_one_hot_predict)
            loss_metric.update_state(loss_value)

            step_loss = tf.reduce_mean(loss_value)
            train_loss = loss_metric.result().numpy()
            train_precision = precision_metric.result().numpy()
            train_recall = recall_metric.result().numpy()
            train_f1 = _f1_score(train_precision, train_recall)

            step_pbar.set_description_str(
                f"\tStep {step:03d}: "
                f"Step loss: {step_loss:.6f}, "
                f"Loss: {train_loss:.6f}, "
                f"Precision: {train_precision:.3%}, "
                f"Recall: {train_recall:.3%}, "
                f"F-1: {train_f1:.3%}")

            if tf_writer is not None:
                with tf_writer.as_default():
                    tf.summary.scalar(
                        "Train step loss", step_loss, step=global_step)

            global_step += 1

        # validate classification model
        precision_metric.reset_states()
        recall_metric.reset_states()
        loss_metric.reset_states()

        for x_batch, y_batch in background_dev_ds:
            y_predict = vision_few_shot_model.predict(x_batch, training=False)
            loss_value = vision_few_shot_model.loss(y_batch, y_predict)

            y_one_hot_predict = tf.round(tf.nn.sigmoid(y_predict))

            precision_metric.update_state(y_batch, y_one_hot_predict)
            recall_metric.update_state(y_batch, y_one_hot_predict)
            loss_metric.update_state(loss_value)

        dev_loss = loss_metric.result().numpy()
        dev_precision = precision_metric.result().numpy()
        dev_recall = recall_metric.result().numpy()
        dev_f1 = _f1_score(dev_precision, dev_recall)

        # validate model on one-shot dev task if specified
        if model_options["one_shot_validation"]:

            if FLAGS.fine_tune_steps is not None:
                test_few_shot_model = create_fine_tune_model(
                    model_options, vision_few_shot_model.model,
                    num_classes=FLAGS.L)
            else:
                test_few_shot_model = base.BaseModel(
                    vision_few_shot_model.model, None,
                    mc_dropout=FLAGS.mc_dropout)

            val_task_accuracy, _, conf_interval_95 = experiment.test_l_way_k_shot(
                one_shot_dev_exp, FLAGS.K, FLAGS.L, n=FLAGS.N,
                num_episodes=FLAGS.episodes, k_neighbours=FLAGS.k_neighbours,
                metric=FLAGS.metric, classification=classification,
                model=test_few_shot_model,
                embedding_model_func=embedding_model_func,
                fine_tune_steps=FLAGS.fine_tune_steps,
                fine_tune_lr=FLAGS.fine_tune_lr)

            val_score = val_task_accuracy
            val_metric = f"{FLAGS.L}-way {FLAGS.K}-shot accuracy"

        # otherwise, validate on classification task
        else:
            val_score = dev_f1
            val_metric = "F-1"

        if val_score >= best_val_score:
            best_val_score = val_score
            best_model = True

        # log results
        logging.log(
            logging.INFO,
            f"Train: Loss: {train_loss:.6f}, Precision: {train_precision:.3%}, "
            f"Recall: {train_recall:.3%}, F-1: {train_f1:.3%}")

        logging.log(
            logging.INFO,
            f"Validation: Loss: {dev_loss:.6f}, Precision: "
            f"{dev_precision:.3%}, Recall: {dev_recall:.3%}, F-1: "
            f"{dev_f1:.3%} {'*' if best_model else ''}")

        if model_options["one_shot_validation"]:
            logging.log(
                logging.INFO,
                f"Validation: {FLAGS.L}-way {FLAGS.K}-shot accuracy after "
                f"{FLAGS.episodes} episodes: {val_task_accuracy:.3%} +- "
                f"{conf_interval_95*100:.4f} {'*' if best_model else ''}")

        if tf_writer is not None:
            with tf_writer.as_default():
                # epoch-level loss gets its own tag so it does not collide
                # with the per-step "Train step loss" series
                tf.summary.scalar(
                    "Train loss", train_loss, step=global_step)
                tf.summary.scalar(
                    "Train precision", train_precision, step=global_step)
                tf.summary.scalar(
                    "Train recall", train_recall, step=global_step)
                tf.summary.scalar(
                    "Train F-1", train_f1, step=global_step)
                tf.summary.scalar(
                    "Validation loss", dev_loss, step=global_step)
                tf.summary.scalar(
                    "Validation precision", dev_precision, step=global_step)
                tf.summary.scalar(
                    "Validation recall", dev_recall, step=global_step)
                tf.summary.scalar(
                    "Validation F-1", dev_f1, step=global_step)
                if model_options["one_shot_validation"]:
                    tf.summary.scalar(
                        f"Validation {FLAGS.L}-way {FLAGS.K}-shot accuracy",
                        val_task_accuracy, step=global_step)

        # store model and results
        model_utils.save_model(
            vision_few_shot_model.model, output_dir, epoch + 1, global_step,
            val_metric, val_score, best_val_score, name="model")

        if best_model:
            best_model = False
            model_utils.save_model(
                vision_few_shot_model.model, output_dir, epoch + 1,
                global_step, val_metric, val_score, best_val_score,
                name="best_model")