def train(model_options, output_dir, model_file=None, model_step_file=None,
          tf_writer=None):
    """Create and train spoken word similarity model for one-shot learning."""

    # load embeddings from dense layer of base model
    embed_dir = os.path.join(model_options["base_dir"], "embed", "dense")

    # load training data (embed dir determines if mfcc/fbank)
    train_exp, dev_exp = dataset.create_flickr_audio_train_data(
        "mfcc", embed_dir=embed_dir, speaker_mode=FLAGS.speaker_mode)

    train_labels = []
    for keyword in train_exp.keywords_set[3]:
        label = train_exp.keyword_labels[keyword]
        train_labels.append(label)
    train_labels = np.asarray(train_labels)

    # map dev keywords with the training experiment's keyword-to-label mapping
    # so that label indices stay consistent across the two splits
    dev_labels = []
    for keyword in dev_exp.keywords_set[3]:
        label = train_exp.keyword_labels[keyword]
        dev_labels.append(label)
    dev_labels = np.asarray(dev_labels)

    train_paths = train_exp.embed_paths
    dev_paths = dev_exp.embed_paths

    # define preprocessing for base model embeddings
    preprocess_data_func = lambda example: dataset.parse_embedding_protobuf(
        example)["embed"]

    # create balanced batch training dataset pipeline
    if model_options["balanced"]:
        assert model_options["p"] is not None
        assert model_options["k"] is not None

        shuffle_train = False
        prefetch_train = False
        num_repeat = model_options["num_batches"]
        model_options["batch_size"] = model_options["p"] * model_options["k"]

        # get unique path train indices per unique label
        train_labels_series = pd.Series(train_labels)
        train_label_idx = {
            label: idx.values[
                np.unique(train_paths[idx.values], return_index=True)[1]]
            for label, idx in train_labels_series.groupby(
                train_labels_series).groups.items()}

        # cache paths to speed things up a little ...
        file_io.check_create_dir(os.path.join(output_dir, "cache"))

        # create a dataset for each unique keyword label (shuffled and cached)
        train_label_datasets = [
            tf.data.Dataset.zip((
                tf.data.Dataset.from_tensor_slices(train_paths[idx]),
                tf.data.Dataset.from_tensor_slices(train_labels[idx]))).cache(
                    os.path.join(
                        output_dir, "cache", str(label))).shuffle(20)  # len(idx)
            for label, idx in train_label_idx.items()]

        # create a dataset that samples balanced batches from the label datasets
        background_train_ds = dataset.create_balanced_batch_dataset(
            model_options["p"], model_options["k"], train_label_datasets)

    # create standard training dataset pipeline (shuffle and load training set)
    else:
        shuffle_train = True
        prefetch_train = True
        num_repeat = model_options["num_augment"]

        background_train_ds = tf.data.Dataset.zip((
            tf.data.Dataset.from_tensor_slices(train_paths),
            tf.data.Dataset.from_tensor_slices(train_labels)))

    # load embedding TFRecords (faster here than before balanced sampling)
    # batch to read files in parallel
    background_train_ds = background_train_ds.batch(
        model_options["batch_size"])
    background_train_ds = background_train_ds.flat_map(
        lambda paths, labels: tf.data.Dataset.zip((
            tf.data.TFRecordDataset(
                paths, compression_type="ZLIB", num_parallel_reads=8),
            tf.data.Dataset.from_tensor_slices(labels))))

    # map data preprocessing function across training data
    background_train_ds = background_train_ds.map(
        lambda data, label: (preprocess_data_func(data), label),
        num_parallel_calls=tf.data.experimental.AUTOTUNE)

    # repeat augmentation, shuffle and batch train data
    if num_repeat is not None:
        background_train_ds = background_train_ds.repeat(num_repeat)

    if shuffle_train:
        background_train_ds = background_train_ds.shuffle(1000)

    background_train_ds = background_train_ds.batch(
        model_options["batch_size"])

    if prefetch_train:
        background_train_ds = background_train_ds.prefetch(
            tf.data.experimental.AUTOTUNE)

    # create dev set pipeline for siamese validation
    background_dev_ds = tf.data.Dataset.zip((
        tf.data.TFRecordDataset(
            dev_paths, compression_type="ZLIB",
            num_parallel_reads=8).map(preprocess_data_func),
        tf.data.Dataset.from_tensor_slices(dev_labels)))
    background_dev_ds = background_dev_ds.batch(
        batch_size=model_options["batch_size"])

    # get training objective
    triplet_loss = get_training_objective(model_options)

    # get model input shape
    for x_batch, _ in background_train_ds.take(1):
        model_options["base_embed_size"] = int(
            tf.shape(x_batch)[1].numpy())
        model_options["input_shape"] = [model_options["base_embed_size"]]

    # load or create model
    if model_file is not None:
        speech_network, train_state = model_utils.load_model(
            model_file=os.path.join(output_dir, model_file),
            model_step_file=os.path.join(output_dir, model_step_file),
            loss=triplet_loss)

        # get previous tracking variables
        initial_model = False
        global_step, model_epochs, _, best_val_score = train_state
    else:
        speech_network = create_speech_network(model_options)

        # create tracking variables
        initial_model = True
        global_step = 0
        model_epochs = 0

        if model_options["one_shot_validation"]:
            best_val_score = -np.inf
        else:
            best_val_score = np.inf

    # load or create Adam optimizer with decayed learning rate
    lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
        model_options["learning_rate"],
        decay_rate=model_options["decay_rate"],
        decay_steps=model_options["decay_steps"],
        staircase=True)

    if model_file is not None:
        logging.log(logging.INFO, "Restoring optimizer state")
        optimizer = speech_network.optimizer
    else:
        optimizer = tf.keras.optimizers.Adam(learning_rate=lr_schedule)

    # compile model to store optimizer with model when saving
    speech_network.compile(optimizer=optimizer, loss=triplet_loss)

    # create few-shot model from speech network for background training
    speech_few_shot_model = base.BaseModel(speech_network, triplet_loss)

    # test model on one-shot validation task prior to training
    if model_options["one_shot_validation"]:

        one_shot_dev_exp = flickr_speech.FlickrSpeech(
            features="mfcc",
            keywords_split="background_dev",
            preprocess_func=get_data_preprocess_func(),
            embed_dir=embed_dir,
            speaker_mode=FLAGS.speaker_mode)

        embedding_model_func = create_embedding_model

        classification = False
        if FLAGS.classification:
            assert FLAGS.fine_tune_steps is not None
            classification = True

        # create few-shot model from speech network for one-shot validation
        if FLAGS.fine_tune_steps is not None:
            test_few_shot_model = create_fine_tune_model(
                model_options, speech_few_shot_model.model,
                num_classes=FLAGS.L)
        else:
            test_few_shot_model = base.BaseModel(
                speech_few_shot_model.model, None,
                mc_dropout=FLAGS.mc_dropout)

        val_task_accuracy, _, conf_interval_95 = experiment.test_l_way_k_shot(
            one_shot_dev_exp, FLAGS.K, FLAGS.L, n=FLAGS.N,
            num_episodes=FLAGS.episodes, k_neighbours=FLAGS.k_neighbours,
            metric=FLAGS.metric, classification=classification,
            model=test_few_shot_model,
            embedding_model_func=embedding_model_func,
            fine_tune_steps=FLAGS.fine_tune_steps,
            fine_tune_lr=FLAGS.fine_tune_lr)

        logging.log(
            logging.INFO,
            f"Base model: {FLAGS.L}-way {FLAGS.K}-shot accuracy after "
            f"{FLAGS.episodes} episodes: {val_task_accuracy:.3%} +- "
            f"{conf_interval_95*100:.4f}")

    # create training metrics
    loss_metric = tf.keras.metrics.Mean()
    best_model = False

    # store model options on first run
    if initial_model:
        file_io.write_json(
            os.path.join(output_dir, "model_options.json"), model_options)

    # train model
    for epoch in range(model_epochs, model_options["epochs"]):
        logging.log(logging.INFO, f"Epoch {epoch:03d}")

        loss_metric.reset_states()

        # train on epoch of training data
        step_pbar = tqdm(background_train_ds,
                         bar_format="{desc} [{elapsed},{rate_fmt}{postfix}]")
        for step, (x_batch, y_batch) in enumerate(step_pbar):

            loss_value, y_predict = speech_few_shot_model.train_step(
                x_batch, y_batch, optimizer,
                clip_norm=model_options["gradient_clip_norm"])

            loss_metric.update_state(loss_value)

            step_loss = tf.reduce_mean(loss_value)
            train_loss = loss_metric.result().numpy()

            step_pbar.set_description_str(
                f"\tStep {step:03d}: "
                f"Step loss: {step_loss:.6f}, "
                f"Loss: {train_loss:.6f}")

            if tf_writer is not None:
                with tf_writer.as_default():
                    tf.summary.scalar(
                        "Train step loss", step_loss, step=global_step)
            global_step += 1

        # validate siamese model
        loss_metric.reset_states()

        for x_batch, y_batch in background_dev_ds:
            y_predict = speech_few_shot_model.predict(x_batch, training=False)
            loss_value = speech_few_shot_model.loss(y_batch, y_predict)

            loss_metric.update_state(loss_value)

        dev_loss = loss_metric.result().numpy()

        # validate model on one-shot dev task if specified
        if model_options["one_shot_validation"]:

            if FLAGS.fine_tune_steps is not None:
                test_few_shot_model = create_fine_tune_model(
                    model_options, speech_few_shot_model.model,
                    num_classes=FLAGS.L)
            else:
                test_few_shot_model = base.BaseModel(
                    speech_few_shot_model.model, None,
                    mc_dropout=FLAGS.mc_dropout)

            val_task_accuracy, _, conf_interval_95 = experiment.test_l_way_k_shot(
                one_shot_dev_exp, FLAGS.K, FLAGS.L, n=FLAGS.N,
                num_episodes=FLAGS.episodes, k_neighbours=FLAGS.k_neighbours,
                metric=FLAGS.metric, classification=classification,
                model=test_few_shot_model,
                embedding_model_func=embedding_model_func,
                fine_tune_steps=FLAGS.fine_tune_steps,
                fine_tune_lr=FLAGS.fine_tune_lr)

            val_score = val_task_accuracy
            val_metric = f"{FLAGS.L}-way {FLAGS.K}-shot accuracy"

            if val_score >= best_val_score:
                best_val_score = val_score
                best_model = True

        # otherwise, validate on siamese task
        else:
            val_score = dev_loss
            val_metric = "loss"

            if val_score <= best_val_score:
                best_val_score = val_score
                best_model = True

        # log results
        logging.log(logging.INFO, f"Train: Loss: {train_loss:.6f}")
        logging.log(
            logging.INFO,
            f"Validation: Loss: {dev_loss:.6f} {'*' if best_model else ''}")

        if model_options["one_shot_validation"]:
            logging.log(
                logging.INFO,
                f"Validation: {FLAGS.L}-way {FLAGS.K}-shot accuracy after "
                f"{FLAGS.episodes} episodes: {val_task_accuracy:.3%} +- "
                f"{conf_interval_95*100:.4f} {'*' if best_model else ''}")

        if tf_writer is not None:
            with tf_writer.as_default():
                tf.summary.scalar(
                    "Train loss", train_loss, step=global_step)
                tf.summary.scalar(
                    "Validation loss", dev_loss, step=global_step)
                if model_options["one_shot_validation"]:
                    tf.summary.scalar(
                        f"Validation {FLAGS.L}-way {FLAGS.K}-shot accuracy",
                        val_task_accuracy, step=global_step)

        # store model and results
        model_utils.save_model(
            speech_few_shot_model.model, output_dir, epoch + 1, global_step,
            val_metric, val_score, best_val_score, name="model")

        if best_model:
            best_model = False
            model_utils.save_model(
                speech_few_shot_model.model, output_dir, epoch + 1,
                global_step, val_metric, val_score, best_val_score,
                name="best_model")
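
# Hedged sketch (illustrative, not the repo's implementation): one way to build
# balanced P*K batches from per-label datasets with tf.data, in the spirit of
# `dataset.create_balanced_batch_dataset` used above. The helper name and the
# use of tf.data.experimental.choose_from_datasets are assumptions.
def _balanced_pk_batches_sketch(label_datasets, p, k, num_batches):
    """Yield batches containing p distinct labels with k examples per label."""
    num_labels = len(label_datasets)

    def choice_generator():
        # for each batch, pick p distinct labels and emit each index k times
        for _ in range(num_batches):
            for label in np.random.choice(num_labels, size=p, replace=False):
                for _ in range(k):
                    yield label

    choices = tf.data.Dataset.from_generator(
        choice_generator, output_types=tf.int64)
    # repeat each per-label dataset so k examples can always be drawn from it
    repeated = [ds.repeat() for ds in label_datasets]
    balanced_ds = tf.data.experimental.choose_from_datasets(repeated, choices)
    return balanced_ds.batch(p * k)
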
def train(model_options, output_dir, model_file=None, model_step_file=None,
          tf_writer=None):
    """Create and train spoken word classification model for one-shot learning."""

    # load training data
    train_exp, dev_exp = dataset.create_flickr_audio_train_data(
        model_options["features"], speaker_mode=FLAGS.speaker_mode)

    train_labels = []
    for keyword in train_exp.keywords_set[3]:
        label = train_exp.keyword_labels[keyword]
        train_labels.append(label)
    train_labels = np.asarray(train_labels)

    # map dev keywords with the training experiment's keyword-to-label mapping
    # so that label indices stay consistent across the two splits
    dev_labels = []
    for keyword in dev_exp.keywords_set[3]:
        label = train_exp.keyword_labels[keyword]
        dev_labels.append(label)
    dev_labels = np.asarray(dev_labels)

    train_paths = train_exp.audio_paths
    dev_paths = dev_exp.audio_paths

    lb = LabelBinarizer()
    train_labels_one_hot = lb.fit_transform(train_labels)
    dev_labels_one_hot = lb.transform(dev_labels)

    # define preprocessing for speech features
    preprocess_speech_func = functools.partial(
        dataset.load_and_preprocess_speech,
        features=model_options["features"],
        max_length=model_options["max_length"],
        scaling=model_options["scaling"])
    preprocess_speech_ds_func = lambda path: tf.py_function(
        func=preprocess_speech_func, inp=[path], Tout=tf.float32)

    # create standard training dataset pipeline
    background_train_ds = tf.data.Dataset.zip((
        tf.data.Dataset.from_tensor_slices(train_paths),
        tf.data.Dataset.from_tensor_slices(train_labels_one_hot)))

    # map data preprocessing function across training data
    background_train_ds = background_train_ds.map(
        lambda path, label: (preprocess_speech_ds_func(path), label),
        num_parallel_calls=tf.data.experimental.AUTOTUNE)

    # shuffle and batch train data
    background_train_ds = background_train_ds.shuffle(1000)
    background_train_ds = background_train_ds.batch(
        model_options["batch_size"])
    background_train_ds = background_train_ds.prefetch(
        tf.data.experimental.AUTOTUNE)

    # create dev set pipeline for classification validation
    background_dev_ds = tf.data.Dataset.zip((
        tf.data.Dataset.from_tensor_slices(dev_paths).map(
            preprocess_speech_ds_func),
        tf.data.Dataset.from_tensor_slices(dev_labels_one_hot)))
    background_dev_ds = background_dev_ds.batch(
        batch_size=model_options["batch_size"])

    # write example batch to TensorBoard
    if tf_writer is not None:
        logging.log(logging.INFO, "Writing example features to TensorBoard")
        with tf_writer.as_default():
            for x_batch, y_batch in background_train_ds.take(1):
                speech_feats = []
                for feats in x_batch[:30]:
                    feats = np.transpose(feats)
                    # min-max scale features to [0, 1] for display
                    feats = feats - np.min(feats)
                    speech_feats.append(feats / np.max(feats))
                tf.summary.image(
                    f"Example train speech {model_options['features']}",
                    np.expand_dims(speech_feats, axis=-1),
                    max_outputs=30, step=0)
                labels = ""
                for i, label in enumerate(y_batch[:30]):
                    # select the keyword where the one-hot label is set
                    keyword = np.asarray(
                        train_exp.keywords)[np.asarray(label) == 1]
                    labels += f"{i}: {keyword} "
                tf.summary.text("Example train labels", labels, step=0)

    # get training objective
    loss = get_training_objective(model_options)

    # get model input shape
    if model_options["features"] == "mfcc":
        model_options["input_shape"] = [model_options["max_length"], 39]
    else:
        model_options["input_shape"] = [model_options["max_length"], 40]

    # load or create model
    if model_file is not None:
        assert model_options["n_classes"] == len(train_exp.keywords)

        speech_network, train_state = model_utils.load_model(
            model_file=os.path.join(output_dir, model_file),
            model_step_file=os.path.join(output_dir, model_step_file),
            loss=loss)

        # get previous tracking variables
        initial_model = False
        global_step, model_epochs, _, best_val_score = train_state
    else:
        model_options["n_classes"] = len(train_exp.keywords)

        speech_network = create_speech_network(model_options)

        # create tracking variables
        initial_model = True
        global_step = 0
        model_epochs = 0

        # both validation metrics (one-shot accuracy and categorical
        # accuracy) are higher-is-better, so start from -inf in either case
        best_val_score = -np.inf

    # load or create Adam optimizer with decayed learning rate
    lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
        model_options["learning_rate"],
        decay_rate=model_options["decay_rate"],
        decay_steps=model_options["decay_steps"],
        staircase=True)

    if model_file is not None:
        logging.log(logging.INFO, "Restoring optimizer state")
        optimizer = speech_network.optimizer
    else:
        optimizer = tf.keras.optimizers.Adam(learning_rate=lr_schedule)

    # compile model to store optimizer with model when saving
    speech_network.compile(optimizer=optimizer, loss=loss)

    # create few-shot model from speech network for background training
    speech_few_shot_model = base.BaseModel(speech_network, loss)

    # test model on one-shot validation task prior to training
    if model_options["one_shot_validation"]:

        one_shot_dev_exp = flickr_speech.FlickrSpeech(
            features=model_options["features"],
            keywords_split="background_dev",
            preprocess_func=get_data_preprocess_func(model_options),
            speaker_mode=FLAGS.speaker_mode)

        embedding_model_func = lambda speech_network: create_embedding_model(
            model_options, speech_network)

        classification = False
        if FLAGS.classification:
            assert FLAGS.embed_layer in ["logits", "softmax"]
            classification = True

        # create few-shot model from speech network for one-shot validation
        if FLAGS.fine_tune_steps is not None:
            test_few_shot_model = create_fine_tune_model(
                model_options, speech_few_shot_model.model,
                num_classes=FLAGS.L)
        else:
            test_few_shot_model = base.BaseModel(
                speech_few_shot_model.model, None,
                mc_dropout=FLAGS.mc_dropout)

        val_task_accuracy, _, conf_interval_95 = experiment.test_l_way_k_shot(
            one_shot_dev_exp, FLAGS.K, FLAGS.L, n=FLAGS.N,
            num_episodes=FLAGS.episodes, k_neighbours=FLAGS.k_neighbours,
            metric=FLAGS.metric, classification=classification,
            model=test_few_shot_model,
            embedding_model_func=embedding_model_func,
            fine_tune_steps=FLAGS.fine_tune_steps,
            fine_tune_lr=FLAGS.fine_tune_lr)

        logging.log(
            logging.INFO,
            f"Base model: {FLAGS.L}-way {FLAGS.K}-shot accuracy after "
            f"{FLAGS.episodes} episodes: {val_task_accuracy:.3%} +- "
            f"{conf_interval_95*100:.4f}")

    # create training metrics
    accuracy_metric = tf.keras.metrics.CategoricalAccuracy()
    loss_metric = tf.keras.metrics.Mean()
    best_model = False

    # store model options on first run
    if initial_model:
        file_io.write_json(
            os.path.join(output_dir, "model_options.json"), model_options)

    # train model
    for epoch in range(model_epochs, model_options["epochs"]):
        logging.log(logging.INFO, f"Epoch {epoch:03d}")

        accuracy_metric.reset_states()
        loss_metric.reset_states()

        # train on epoch of training data
        step_pbar = tqdm(background_train_ds,
                         bar_format="{desc} [{elapsed},{rate_fmt}{postfix}]")
        for step, (x_batch, y_batch) in enumerate(step_pbar):

            loss_value, y_predict = speech_few_shot_model.train_step(
                x_batch, y_batch, optimizer,
                clip_norm=model_options["gradient_clip_norm"])

            accuracy_metric.update_state(y_batch, y_predict)
            loss_metric.update_state(loss_value)

            step_loss = tf.reduce_mean(loss_value)
            train_loss = loss_metric.result().numpy()
            train_accuracy = accuracy_metric.result().numpy()

            step_pbar.set_description_str(
                f"\tStep {step:03d}: "
                f"Step loss: {step_loss:.6f}, "
                f"Loss: {train_loss:.6f}, "
                f"Categorical accuracy: {train_accuracy:.3%}")

            if tf_writer is not None:
                with tf_writer.as_default():
                    tf.summary.scalar(
                        "Train step loss", step_loss, step=global_step)
            global_step += 1

        # validate classification model
        accuracy_metric.reset_states()
        loss_metric.reset_states()

        for x_batch, y_batch in background_dev_ds:
            y_predict = speech_few_shot_model.predict(x_batch, training=False)
            loss_value = speech_few_shot_model.loss(y_batch, y_predict)

            accuracy_metric.update_state(y_batch, y_predict)
            loss_metric.update_state(loss_value)

        dev_loss = loss_metric.result().numpy()
        dev_accuracy = accuracy_metric.result().numpy()

        # validate model on one-shot dev task if specified
        if model_options["one_shot_validation"]:

            if FLAGS.fine_tune_steps is not None:
                test_few_shot_model = create_fine_tune_model(
                    model_options, speech_few_shot_model.model,
                    num_classes=FLAGS.L)
            else:
                test_few_shot_model = base.BaseModel(
                    speech_few_shot_model.model, None,
                    mc_dropout=FLAGS.mc_dropout)

            val_task_accuracy, _, conf_interval_95 = experiment.test_l_way_k_shot(
                one_shot_dev_exp, FLAGS.K, FLAGS.L, n=FLAGS.N,
                num_episodes=FLAGS.episodes, k_neighbours=FLAGS.k_neighbours,
                metric=FLAGS.metric, classification=classification,
                model=test_few_shot_model,
                embedding_model_func=embedding_model_func,
                fine_tune_steps=FLAGS.fine_tune_steps,
                fine_tune_lr=FLAGS.fine_tune_lr)

            val_score = val_task_accuracy
            val_metric = f"{FLAGS.L}-way {FLAGS.K}-shot accuracy"

        # otherwise, validate on classification task
        else:
            val_score = dev_accuracy
            val_metric = "categorical accuracy"

        if val_score >= best_val_score:
            best_val_score = val_score
            best_model = True

        # log results
        logging.log(
            logging.INFO,
            f"Train: Loss: {train_loss:.6f}, Categorical accuracy: "
            f"{train_accuracy:.3%}")
        logging.log(
            logging.INFO,
            f"Validation: Loss: {dev_loss:.6f}, Categorical accuracy: "
            f"{dev_accuracy:.3%} {'*' if best_model else ''}")

        if model_options["one_shot_validation"]:
            logging.log(
                logging.INFO,
                f"Validation: {FLAGS.L}-way {FLAGS.K}-shot accuracy after "
                f"{FLAGS.episodes} episodes: {val_task_accuracy:.3%} +- "
                f"{conf_interval_95*100:.4f} {'*' if best_model else ''}")

        if tf_writer is not None:
            with tf_writer.as_default():
                tf.summary.scalar(
                    "Train loss", train_loss, step=global_step)
                tf.summary.scalar(
                    "Train categorical accuracy", train_accuracy,
                    step=global_step)
                tf.summary.scalar(
                    "Validation loss", dev_loss, step=global_step)
                tf.summary.scalar(
                    "Validation categorical accuracy", dev_accuracy,
                    step=global_step)
                if model_options["one_shot_validation"]:
                    tf.summary.scalar(
                        f"Validation {FLAGS.L}-way {FLAGS.K}-shot accuracy",
                        val_task_accuracy, step=global_step)

        # store model and results
        model_utils.save_model(
            speech_few_shot_model.model, output_dir, epoch + 1, global_step,
            val_metric, val_score, best_val_score, name="model")

        if best_model:
            best_model = False
            model_utils.save_model(
                speech_few_shot_model.model, output_dir, epoch + 1,
                global_step, val_metric, val_score, best_val_score,
                name="best_model")
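
# Hedged sketch (illustrative): the pipeline above wraps the NumPy feature
# loader with tf.py_function, which erases static shape information. If a
# downstream layer needs a known feature dimension, it can be restored with
# set_shape. `_load_speech_stub` is a hypothetical stand-in for
# dataset.load_and_preprocess_speech, and the default max_length/num_coeffs
# values are placeholders (the 39-dim case matches the MFCC branch above).
def _py_function_speech_sketch(path, max_length=100, num_coeffs=39):
    def _load_speech_stub(p):
        # placeholder loader returning fixed-size float32 features
        return np.zeros((max_length, num_coeffs), dtype=np.float32)

    feats = tf.py_function(
        func=_load_speech_stub, inp=[path], Tout=tf.float32)
    # tf.py_function outputs have unknown shape; re-attach the static shape
    feats.set_shape([max_length, num_coeffs])
    return feats
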
def train(model_options, output_dir, model_file=None, model_step_file=None,
          tf_writer=None):
    """Create and train audio-visual similarity model for one-shot learning."""

    # load embeddings from (linear) dense layer of base speech and vision models
    speech_embed_dir = os.path.join(
        model_options["audio_base_dir"], "embed", "dense")
    image_embed_dir = os.path.join(
        model_options["vision_base_dir"], "embed", "dense")

    # load training data (embed dir determines mfcc/fbank speech features)
    train_exp, dev_exp = dataset.create_flickr_multimodal_train_data(
        "mfcc", speech_embed_dir=speech_embed_dir,
        image_embed_dir=image_embed_dir, speaker_mode=FLAGS.speaker_mode,
        speech_preprocess_func=data_preprocess_func,
        image_preprocess_func=data_preprocess_func,
        unseen_match_set=FLAGS.unseen_match_set)

    # create training dataset generator
    def sample_task_function(task_exp, l, k, n):
        def sample():
            curr_episode_train, curr_episode_test = task_exp.sample_episode(
                l, k, n)
            x_train_s = task_exp.speech_experiment.data[curr_episode_train[0]]
            x_train_i = task_exp.vision_experiment.data[curr_episode_train[0]]
            x_query_s = task_exp.speech_experiment.data[curr_episode_test[0]]
            x_query_i = task_exp.vision_experiment.data[curr_episode_test[0]]
            return x_train_s, x_train_i, x_query_s, x_query_i
        return sample

    train_generator = sample_task_function(
        train_exp, FLAGS.L, FLAGS.K, FLAGS.N)

    # create dummy dataset with infinite zero elements
    background_train_ds = tf.data.Dataset.from_tensors(tf.constant(0.))
    background_train_ds = background_train_ds.repeat(-1)

    # parallel map dummy elements to task data
    background_train_ds = background_train_ds.map(
        lambda _: tf.py_function(
            train_generator, inp=[], Tout=[tf.float32] * 4),
        num_parallel_calls=tf.data.experimental.AUTOTUNE)
    background_train_ds = background_train_ds.batch(
        batch_size=model_options["meta_batch_size"])
    background_train_ds = background_train_ds.prefetch(
        tf.data.experimental.AUTOTUNE)

    # get training objective
    triplet_loss = get_training_objective(model_options)

    # get model input shape
    for speech_batch, image_batch, _, _ in background_train_ds.take(1):
        model_options["audio_base_embed_size"] = int(
            tf.shape(speech_batch)[-1].numpy())
        model_options["vision_base_embed_size"] = int(
            tf.shape(image_batch)[-1].numpy())
        model_options["audio_input_shape"] = [
            model_options["audio_base_embed_size"]]
        model_options["vision_input_shape"] = [
            model_options["vision_base_embed_size"]]

    # load or create models
    if model_file is not None:
        join_network_model, train_state = model_utils.load_model(
            model_file=os.path.join(output_dir, model_file),
            model_step_file=os.path.join(output_dir, model_step_file),
            loss=get_training_objective(model_options))

        speech_network = tf.keras.Model(
            inputs=join_network_model.inputs[0],
            outputs=join_network_model.outputs[0])
        vision_network = tf.keras.Model(
            inputs=join_network_model.inputs[1],
            outputs=join_network_model.outputs[1])

        # get previous tracking variables
        initial_model = False
        global_step, _, _, best_val_score = train_state
    else:
        speech_network = create_speech_network(model_options)
        vision_network = create_vision_network(model_options)

        # create tracking variables
        initial_model = True
        global_step = 0

        if model_options["one_shot_validation"]:
            best_val_score = -np.inf
        else:
            best_val_score = np.inf

    # load or create Adam optimizer with decayed learning rate
    lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
        model_options["meta_learning_rate"],
        decay_rate=model_options["decay_rate"],
        decay_steps=model_options["decay_steps"],
        staircase=True)

    if model_file is not None:
        logging.log(logging.INFO, "Restoring optimizer state")
        optimizer = join_network_model.optimizer
    else:
        optimizer = tf.keras.optimizers.Adam(learning_rate=lr_schedule)

    # compile joined model to store optimizer with model when saving
    join_network_model = tf.keras.Model(
        inputs=[speech_network.input, vision_network.input],
        outputs=[speech_network.output, vision_network.output])
    join_network_model.compile(optimizer=optimizer, loss=triplet_loss)

    # create few-shot model from speech and vision networks for background
    # training
    multimodal_model = base.WeaklySupervisedMAML(
        speech_network, vision_network, triplet_loss,
        inner_optimizer_lr=FLAGS.fine_tune_lr)

    # create few-shot model from speech and vision networks for one-shot
    # validation
    test_few_shot_model = create_fine_tune_model(
        model_options, speech_network, vision_network)

    fine_tune_optimizer = None
    if not FLAGS.adam:
        fine_tune_optimizer = tf.keras.optimizers.SGD(FLAGS.fine_tune_lr)

    # test model on one-shot validation task prior to training
    val_task_accuracy, _, conf_interval_95 = experiment.test_multimodal_l_way_k_shot(
        dev_exp, FLAGS.K, FLAGS.L, n=FLAGS.N, num_episodes=FLAGS.episodes,
        k_neighbours=FLAGS.k_neighbours, metric=FLAGS.metric,
        direct_match=FLAGS.direct_match,
        multimodal_model=test_few_shot_model,
        multimodal_embedding_func=None,  # create_embedding_model
        optimizer=fine_tune_optimizer,
        fine_tune_steps=FLAGS.fine_tune_steps * 2,
        fine_tune_lr=FLAGS.fine_tune_lr)

    logging.log(
        logging.INFO,
        f"Base model: {FLAGS.L}-way {FLAGS.K}-shot accuracy after "
        f"{FLAGS.episodes} episodes: {val_task_accuracy:.3%} +- "
        f"{conf_interval_95*100:.4f}")

    # create training metrics
    step_loss_metric = tf.keras.metrics.Mean()
    avg_loss_metric = tf.keras.metrics.Mean()
    best_model = False

    # store model options on first run
    if initial_model:
        file_io.write_json(
            os.path.join(output_dir, "model_options.json"), model_options)

        # also store initial model for probing and things
        model_utils.save_model(
            join_network_model, output_dir, 0, 0, "not tested", 0., 0.,
            name="initial_model")

    # train model
    step_pbar = tqdm(background_train_ds, bar_format="{desc} {r_bar}",
                     initial=global_step, total=model_options["meta_steps"])
    for batch in step_pbar:

        # train on batch of training task data
        train_s_batch, train_i_batch, test_s_batch, test_i_batch = batch

        meta_loss, inner_losses, meta_losses = multimodal_model.maml_train_step(
            train_s_batch, train_i_batch, test_s_batch, test_i_batch,
            FLAGS.fine_tune_steps, optimizer,
            stop_gradients=FLAGS.first_order,
            clip_norm=model_options["gradient_clip_norm"])

        step_loss_metric.reset_states()
        step_loss_metric.update_state(meta_loss)
        avg_loss_metric.update_state(meta_loss)

        avg_loss = avg_loss_metric.result().numpy()

        step_pbar.set_description_str(
            f"Step loss: {meta_loss:.6f}, "
            f"Average loss: {avg_loss:.6f}")

        if tf_writer is not None:
            with tf_writer.as_default():
                tf.summary.scalar(
                    "Train step loss", meta_loss, step=global_step)

        if (global_step % model_options["validation_interval"] == 0 or
                global_step == model_options["meta_steps"]) and global_step > 0:

            # validate model on one-shot dev task
            test_few_shot_model = create_fine_tune_model(
                model_options, speech_network, vision_network)

            val_task_accuracy, _, conf_interval_95 = experiment.test_multimodal_l_way_k_shot(
                dev_exp, FLAGS.K, FLAGS.L, n=FLAGS.N,
                num_episodes=FLAGS.episodes, k_neighbours=FLAGS.k_neighbours,
                metric=FLAGS.metric, direct_match=FLAGS.direct_match,
                multimodal_model=test_few_shot_model,
                multimodal_embedding_func=None,  # create_embedding_model
                optimizer=fine_tune_optimizer,
                fine_tune_steps=FLAGS.fine_tune_steps,
                fine_tune_lr=FLAGS.fine_tune_lr)

            val_metric = f"{FLAGS.L}-way {FLAGS.K}-shot accuracy"

            if val_task_accuracy >= best_val_score:
                best_val_score = val_task_accuracy
                best_model = True

            # log results
            avg_loss_metric.reset_states()

            logging.log(logging.INFO, f"Step {global_step:03d}")
            logging.log(
                logging.INFO,
                f"Train: Step loss: {meta_loss:.6f}, "
                f"Average loss: {avg_loss:.6f}")
            logging.log(
                logging.INFO,
                f"Validation: {FLAGS.L}-way {FLAGS.K}-shot accuracy after "
                f"{FLAGS.episodes} episodes: {val_task_accuracy:.3%} +- "
                f"{conf_interval_95*100:.4f} {'*' if best_model else ''}")

            if tf_writer is not None:
                with tf_writer.as_default():
                    tf.summary.scalar(
                        f"Validation {FLAGS.L}-way {FLAGS.K}-shot accuracy",
                        val_task_accuracy, step=global_step)

            # store model and results
            model_utils.save_model(
                join_network_model, output_dir, global_step, global_step,
                val_metric, val_task_accuracy, best_val_score, name="model")

            if best_model:
                best_model = False
                model_utils.save_model(
                    join_network_model, output_dir, global_step, global_step,
                    val_metric, val_task_accuracy, best_val_score,
                    name="best_model")

        global_step += 1

        if global_step > model_options["meta_steps"]:
            logging.log(
                logging.INFO,
                f"Training complete after {global_step-1:03d} steps")
            break
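
# Hedged sketch (illustrative): one first-order MAML-style meta-step on a
# single task, in the spirit of WeaklySupervisedMAML.maml_train_step above.
# This is a generic supervised variant, not the repo's weakly supervised
# objective; all names below are assumptions.
def _fomaml_step_sketch(model, loss_fn, x_support, y_support, x_query, y_query,
                        inner_lr, meta_optimizer):
    # remember the pre-adaptation weights
    original_weights = [tf.identity(v) for v in model.trainable_variables]

    # inner loop: one SGD step on the support set, applied in place
    with tf.GradientTape() as tape:
        inner_loss = loss_fn(y_support, model(x_support, training=True))
    inner_grads = tape.gradient(inner_loss, model.trainable_variables)
    for v, g in zip(model.trainable_variables, inner_grads):
        v.assign_sub(inner_lr * g)

    # outer loop: query-set gradient at the adapted weights (first-order
    # approximation, i.e. no second derivatives through the inner step)
    with tf.GradientTape() as tape:
        meta_loss = loss_fn(y_query, model(x_query, training=True))
    meta_grads = tape.gradient(meta_loss, model.trainable_variables)

    # restore the pre-adaptation weights, then apply the meta-gradient
    for v, w in zip(model.trainable_variables, original_weights):
        v.assign(w)
    meta_optimizer.apply_gradients(zip(meta_grads, model.trainable_variables))
    return meta_loss
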
def train(model_options, output_dir, model_file=None, model_step_file=None,
          tf_writer=None):
    """Create and train image classification model for one-shot learning."""

    # load training data
    train_exp, dev_exp = dataset.create_flickr_vision_train_data(
        model_options["data"])

    train_labels = []
    for image_keywords in train_exp.unique_image_keywords:
        labels = map(
            lambda keyword: train_exp.keyword_labels[keyword],
            image_keywords)
        train_labels.append(np.array(list(labels)))
    train_labels = np.asarray(train_labels)

    # map dev keywords with the training experiment's keyword-to-label mapping
    # so that label indices stay consistent across the two splits
    dev_labels = []
    for image_keywords in dev_exp.unique_image_keywords:
        labels = map(
            lambda keyword: train_exp.keyword_labels[keyword],
            image_keywords)
        dev_labels.append(np.array(list(labels)))
    dev_labels = np.asarray(dev_labels)

    train_paths = train_exp.unique_image_paths
    dev_paths = dev_exp.unique_image_paths

    mlb = MultiLabelBinarizer()
    train_labels_multi_hot = mlb.fit_transform(train_labels)
    dev_labels_multi_hot = mlb.transform(dev_labels)

    # define preprocessing for images
    preprocess_images_func = functools.partial(
        dataset.load_and_preprocess_image,
        crop_size=model_options["crop_size"],
        augment_crop=model_options["augment_train"],
        random_scales=model_options["random_scales"],
        horizontal_flip=model_options["horizontal_flip"],
        colour=model_options["colour"])

    # create standard training dataset pipeline
    background_train_ds = tf.data.Dataset.zip((
        tf.data.Dataset.from_tensor_slices(train_paths),
        tf.data.Dataset.from_tensor_slices(train_labels_multi_hot)))

    # map data preprocessing function across training data
    background_train_ds = background_train_ds.map(
        lambda path, label: (preprocess_images_func(path), label),
        num_parallel_calls=tf.data.experimental.AUTOTUNE)

    # repeat augmentation, shuffle and batch train data
    if model_options["num_augment"] is not None:
        background_train_ds = background_train_ds.repeat(
            model_options["num_augment"])

    background_train_ds = background_train_ds.shuffle(1000)
    background_train_ds = background_train_ds.batch(
        model_options["batch_size"])
    background_train_ds = background_train_ds.prefetch(
        tf.data.experimental.AUTOTUNE)

    # create dev set pipeline for classification validation
    background_dev_ds = tf.data.Dataset.zip((
        tf.data.Dataset.from_tensor_slices(
            dev_paths).map(preprocess_images_func),
        tf.data.Dataset.from_tensor_slices(dev_labels_multi_hot)))
    background_dev_ds = background_dev_ds.batch(
        batch_size=model_options["batch_size"])

    # write example batch to TensorBoard
    if tf_writer is not None:
        logging.log(logging.INFO, "Writing example images to TensorBoard")
        with tf_writer.as_default():
            for x_batch, y_batch in background_train_ds.take(1):
                # map images from [-1, 1] back to [0, 1] for display
                tf.summary.image(
                    "Example train images", (x_batch + 1) / 2,
                    max_outputs=30, step=0)
                labels = ""
                for i, label in enumerate(y_batch[:30]):
                    # select the keywords where the multi-hot label is set
                    keywords = np.asarray(
                        train_exp.keywords)[np.asarray(label) == 1]
                    labels += f"{i}: {keywords} "
                tf.summary.text("Example train labels", labels, step=0)

    # get training objective
    multi_label_loss = get_training_objective(model_options)

    # get model input shape
    model_options["input_shape"] = (
        model_options["crop_size"], model_options["crop_size"], 3)

    # load or create model
    if model_file is not None:
        assert model_options["n_classes"] == len(train_exp.keywords)

        vision_network, train_state = model_utils.load_model(
            model_file=os.path.join(output_dir, model_file),
            model_step_file=os.path.join(output_dir, model_step_file),
            loss=multi_label_loss)

        # get previous tracking variables
        initial_model = False
        global_step, model_epochs, _, best_val_score = train_state
    else:
        model_options["n_classes"] = len(train_exp.keywords)

        vision_network = create_vision_network(model_options)

        # create tracking variables
        initial_model = True
        global_step = 0
        model_epochs = 0

        # both validation metrics (one-shot accuracy and F-1) are
        # higher-is-better, so start from -inf in either case
        best_val_score = -np.inf

    # load or create Adam optimizer with decayed learning rate
    lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
        model_options["learning_rate"],
        decay_rate=model_options["decay_rate"],
        decay_steps=model_options["decay_steps"],
        staircase=True)

    if model_file is not None:
        logging.log(logging.INFO, "Restoring optimizer state")
        optimizer = vision_network.optimizer
    else:
        optimizer = tf.keras.optimizers.Adam(learning_rate=lr_schedule)

    # compile model to store optimizer with model when saving
    vision_network.compile(optimizer=optimizer, loss=multi_label_loss)

    # create few-shot model from vision network for background training
    vision_few_shot_model = base.BaseModel(vision_network, multi_label_loss)

    # test model on one-shot validation task prior to training
    if model_options["one_shot_validation"]:

        one_shot_dev_exp = flickr_vision.FlickrVision(
            keywords_split="background_dev",
            flickr8k_image_dir=os.path.join(
                "data", "external", "flickr8k_images"),
            flickr30k_image_dir=os.path.join(
                "data", "external", "flickr30k_images"),
            mscoco_image_dir=os.path.join(
                "data", "external", "mscoco", "val2017"),
            preprocess_func=get_data_preprocess_func(model_options))

        embedding_model_func = lambda vision_network: create_embedding_model(
            model_options, vision_network)

        classification = False
        if FLAGS.classification:
            assert FLAGS.embed_layer in ["logits", "softmax"]
            classification = True

        # create few-shot model from vision network for one-shot validation
        if FLAGS.fine_tune_steps is not None:
            test_few_shot_model = create_fine_tune_model(
                model_options, vision_few_shot_model.model,
                num_classes=FLAGS.L)
        else:
            test_few_shot_model = base.BaseModel(
                vision_few_shot_model.model, None,
                mc_dropout=FLAGS.mc_dropout)

        val_task_accuracy, _, conf_interval_95 = experiment.test_l_way_k_shot(
            one_shot_dev_exp, FLAGS.K, FLAGS.L, n=FLAGS.N,
            num_episodes=FLAGS.episodes, k_neighbours=FLAGS.k_neighbours,
            metric=FLAGS.metric, classification=classification,
            model=test_few_shot_model,
            embedding_model_func=embedding_model_func,
            fine_tune_steps=FLAGS.fine_tune_steps,
            fine_tune_lr=FLAGS.fine_tune_lr)

        logging.log(
            logging.INFO,
            f"Base model: {FLAGS.L}-way {FLAGS.K}-shot accuracy after "
            f"{FLAGS.episodes} episodes: {val_task_accuracy:.3%} +- "
            f"{conf_interval_95*100:.4f}")

    # create training metrics
    precision_metric = tf.keras.metrics.Precision()
    recall_metric = tf.keras.metrics.Recall()
    loss_metric = tf.keras.metrics.Mean()
    best_model = False

    # store model options on first run
    if initial_model:
        file_io.write_json(
            os.path.join(output_dir, "model_options.json"), model_options)

    # train model
    for epoch in range(model_epochs, model_options["epochs"]):
        logging.log(logging.INFO, f"Epoch {epoch:03d}")

        precision_metric.reset_states()
        recall_metric.reset_states()
        loss_metric.reset_states()

        # train on epoch of training data
        step_pbar = tqdm(background_train_ds,
                         bar_format="{desc} [{elapsed},{rate_fmt}{postfix}]")
        for step, (x_batch, y_batch) in enumerate(step_pbar):

            loss_value, y_predict = vision_few_shot_model.train_step(
                x_batch, y_batch, optimizer,
                clip_norm=model_options["gradient_clip_norm"])

            y_one_hot_predict = tf.round(tf.nn.sigmoid(y_predict))

            precision_metric.update_state(y_batch, y_one_hot_predict)
            recall_metric.update_state(y_batch, y_one_hot_predict)
            loss_metric.update_state(loss_value)

            step_loss = tf.reduce_mean(loss_value)
            train_loss = loss_metric.result().numpy()
            train_precision = precision_metric.result().numpy()
            train_recall = recall_metric.result().numpy()
            # F-1 as harmonic mean of precision and recall (assumes both are
            # non-zero; see the guarded sketch after this function)
            train_f1 = 2 / ((1/train_precision) + (1/train_recall))

            step_pbar.set_description_str(
                f"\tStep {step:03d}: "
                f"Step loss: {step_loss:.6f}, "
                f"Loss: {train_loss:.6f}, "
                f"Precision: {train_precision:.3%}, "
                f"Recall: {train_recall:.3%}, "
                f"F-1: {train_f1:.3%}")

            if tf_writer is not None:
                with tf_writer.as_default():
                    tf.summary.scalar(
                        "Train step loss", step_loss, step=global_step)
            global_step += 1

        # validate classification model
        precision_metric.reset_states()
        recall_metric.reset_states()
        loss_metric.reset_states()

        for x_batch, y_batch in background_dev_ds:
            y_predict = vision_few_shot_model.predict(x_batch, training=False)
            loss_value = vision_few_shot_model.loss(y_batch, y_predict)

            y_one_hot_predict = tf.round(tf.nn.sigmoid(y_predict))

            precision_metric.update_state(y_batch, y_one_hot_predict)
            recall_metric.update_state(y_batch, y_one_hot_predict)
            loss_metric.update_state(loss_value)

        dev_loss = loss_metric.result().numpy()
        dev_precision = precision_metric.result().numpy()
        dev_recall = recall_metric.result().numpy()
        dev_f1 = 2 / ((1/dev_precision) + (1/dev_recall))

        # validate model on one-shot dev task if specified
        if model_options["one_shot_validation"]:

            if FLAGS.fine_tune_steps is not None:
                test_few_shot_model = create_fine_tune_model(
                    model_options, vision_few_shot_model.model,
                    num_classes=FLAGS.L)
            else:
                test_few_shot_model = base.BaseModel(
                    vision_few_shot_model.model, None,
                    mc_dropout=FLAGS.mc_dropout)

            val_task_accuracy, _, conf_interval_95 = experiment.test_l_way_k_shot(
                one_shot_dev_exp, FLAGS.K, FLAGS.L, n=FLAGS.N,
                num_episodes=FLAGS.episodes, k_neighbours=FLAGS.k_neighbours,
                metric=FLAGS.metric, classification=classification,
                model=test_few_shot_model,
                embedding_model_func=embedding_model_func,
                fine_tune_steps=FLAGS.fine_tune_steps,
                fine_tune_lr=FLAGS.fine_tune_lr)

            val_score = val_task_accuracy
            val_metric = f"{FLAGS.L}-way {FLAGS.K}-shot accuracy"

        # otherwise, validate on classification task
        else:
            val_score = dev_f1
            val_metric = "F-1"

        if val_score >= best_val_score:
            best_val_score = val_score
            best_model = True

        # log results
        logging.log(
            logging.INFO,
            f"Train: Loss: {train_loss:.6f}, Precision: {train_precision:.3%}, "
            f"Recall: {train_recall:.3%}, F-1: {train_f1:.3%}")
        logging.log(
            logging.INFO,
            f"Validation: Loss: {dev_loss:.6f}, Precision: "
            f"{dev_precision:.3%}, Recall: {dev_recall:.3%}, F-1: "
            f"{dev_f1:.3%} {'*' if best_model else ''}")

        if model_options["one_shot_validation"]:
            logging.log(
                logging.INFO,
                f"Validation: {FLAGS.L}-way {FLAGS.K}-shot accuracy after "
                f"{FLAGS.episodes} episodes: {val_task_accuracy:.3%} +- "
                f"{conf_interval_95*100:.4f} {'*' if best_model else ''}")

        if tf_writer is not None:
            with tf_writer.as_default():
                tf.summary.scalar(
                    "Train loss", train_loss, step=global_step)
                tf.summary.scalar(
                    "Train precision", train_precision, step=global_step)
                tf.summary.scalar(
                    "Train recall", train_recall, step=global_step)
                tf.summary.scalar(
                    "Train F-1", train_f1, step=global_step)
                tf.summary.scalar(
                    "Validation loss", dev_loss, step=global_step)
                tf.summary.scalar(
                    "Validation precision", dev_precision, step=global_step)
                tf.summary.scalar(
                    "Validation recall", dev_recall, step=global_step)
                tf.summary.scalar(
                    "Validation F-1", dev_f1, step=global_step)
                if model_options["one_shot_validation"]:
                    tf.summary.scalar(
                        f"Validation {FLAGS.L}-way {FLAGS.K}-shot accuracy",
                        val_task_accuracy, step=global_step)

        # store model and results
        model_utils.save_model(
            vision_few_shot_model.model, output_dir, epoch + 1, global_step,
            val_metric, val_score, best_val_score, name="model")

        if best_model:
            best_model = False
            model_utils.save_model(
                vision_few_shot_model.model, output_dir, epoch + 1,
                global_step, val_metric, val_score, best_val_score,
                name="best_model")
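
# Hedged sketch (illustrative): the F-1 computed in the training loop above is
# the harmonic mean 2 / (1/precision + 1/recall), which divides by zero when
# either metric is zero (e.g. on the first few training steps). A guarded
# equivalent:
def _f1_sketch(precision, recall):
    if precision + recall == 0.0:
        return 0.0
    return 2.0 * precision * recall / (precision + recall)
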
def train(model_options, output_dir, model_file=None, model_step_file=None,
          tf_writer=None):
    """Create and train audio-visual similarity model for one-shot learning."""

    # load embeddings from (linear) dense layer of base speech and vision models
    speech_embed_dir = os.path.join(
        model_options["audio_base_dir"], "embed", "dense")
    image_embed_dir = os.path.join(
        model_options["vision_base_dir"], "embed", "dense")

    # load training data (embed dir determines mfcc/fbank speech features)
    train_exp, dev_exp = dataset.create_flickr_multimodal_train_data(
        "mfcc", speech_embed_dir=speech_embed_dir,
        image_embed_dir=image_embed_dir, speaker_mode=FLAGS.speaker_mode,
        unseen_match_set=FLAGS.unseen_match_set)

    train_speech_paths = train_exp.speech_experiment.data
    train_image_paths = train_exp.vision_experiment.data
    dev_speech_paths = dev_exp.speech_experiment.data
    dev_image_paths = dev_exp.vision_experiment.data

    # define preprocessing for base model embeddings
    preprocess_data_func = lambda example: dataset.parse_embedding_protobuf(
        example)["embed"]

    # create standard training dataset pipeline
    background_train_ds = tf.data.Dataset.zip((
        tf.data.TFRecordDataset(
            train_speech_paths, compression_type="ZLIB"),
        tf.data.TFRecordDataset(
            train_image_paths, compression_type="ZLIB")))

    # map data preprocessing function across training data
    background_train_ds = background_train_ds.map(
        lambda speech_path, image_path: (
            preprocess_data_func(speech_path),
            preprocess_data_func(image_path)),
        num_parallel_calls=8)

    # shuffle and batch train data
    background_train_ds = background_train_ds.repeat(-1)
    background_train_ds = background_train_ds.shuffle(1000)
    background_train_ds = background_train_ds.batch(
        model_options["batch_size"])
    background_train_ds = background_train_ds.take(
        model_options["num_batches"])
    background_train_ds = background_train_ds.prefetch(
        tf.data.experimental.AUTOTUNE)

    # create dev set pipeline for validation
    background_dev_ds = tf.data.Dataset.zip((
        tf.data.TFRecordDataset(
            dev_speech_paths, compression_type="ZLIB"),
        tf.data.TFRecordDataset(
            dev_image_paths, compression_type="ZLIB")))
    background_dev_ds = background_dev_ds.map(
        lambda speech_path, image_path: (
            preprocess_data_func(speech_path),
            preprocess_data_func(image_path)),
        num_parallel_calls=8)
    background_dev_ds = background_dev_ds.batch(
        batch_size=model_options["batch_size"])

    # get training objective
    triplet_loss = get_training_objective(model_options)

    # get model input shape
    for speech_batch, image_batch in background_train_ds.take(1):
        model_options["audio_base_embed_size"] = int(
            tf.shape(speech_batch)[1].numpy())
        model_options["vision_base_embed_size"] = int(
            tf.shape(image_batch)[1].numpy())
        model_options["audio_input_shape"] = [
            model_options["audio_base_embed_size"]]
        model_options["vision_input_shape"] = [
            model_options["vision_base_embed_size"]]

    # load or create models
    if model_file is not None:
        join_network_model, train_state = model_utils.load_model(
            model_file=os.path.join(output_dir, model_file),
            model_step_file=os.path.join(output_dir, model_step_file),
            loss=get_training_objective(model_options))

        speech_network = tf.keras.Model(
            inputs=join_network_model.inputs[0],
            outputs=join_network_model.outputs[0])
        vision_network = tf.keras.Model(
            inputs=join_network_model.inputs[1],
            outputs=join_network_model.outputs[1])

        # get previous tracking variables
        initial_model = False
        global_step, model_epochs, _, best_val_score = train_state
    else:
        speech_network = create_speech_network(model_options)
        vision_network = create_vision_network(model_options)

        # create tracking variables
        initial_model = True
        global_step = 0
        model_epochs = 0

        if model_options["one_shot_validation"]:
            best_val_score = -np.inf
        else:
            best_val_score = np.inf

    # load or create Adam optimizer with decayed learning rate
    lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
        model_options["learning_rate"],
        decay_rate=model_options["decay_rate"],
        decay_steps=model_options["decay_steps"],
        staircase=True)

    if model_file is not None:
        logging.log(logging.INFO, "Restoring optimizer state")
        optimizer = join_network_model.optimizer
    else:
        optimizer = tf.keras.optimizers.Adam(learning_rate=lr_schedule)

    # compile joined model to store optimizer with model when saving
    join_network_model = tf.keras.Model(
        inputs=[speech_network.input, vision_network.input],
        outputs=[speech_network.output, vision_network.output])
    join_network_model.compile(optimizer=optimizer, loss=triplet_loss)

    # create few-shot model from speech and vision networks for background
    # training
    multimodal_model = base.WeaklySupervisedModel(
        speech_network, vision_network, triplet_loss)

    # test model on one-shot validation task prior to training
    if model_options["one_shot_validation"]:

        one_shot_dev_exp = flickr_multimodal.FlickrMultimodal(
            features="mfcc", keywords_split="background_dev",
            flickr8k_image_dir=os.path.join(
                "data", "external", "flickr8k_images"),
            speech_embed_dir=speech_embed_dir,
            image_embed_dir=image_embed_dir,
            speech_preprocess_func=data_preprocess_func,
            image_preprocess_func=data_preprocess_func,
            speaker_mode=FLAGS.speaker_mode,
            unseen_match_set=FLAGS.unseen_match_set)

        # create few-shot model from speech and vision networks for one-shot
        # validation
        if FLAGS.fine_tune_steps is not None:
            test_few_shot_model = create_fine_tune_model(
                model_options, speech_network, vision_network)
        else:
            test_few_shot_model = base.WeaklySupervisedModel(
                speech_network, vision_network, None,
                mc_dropout=FLAGS.mc_dropout)

        val_task_accuracy, _, conf_interval_95 = experiment.test_multimodal_l_way_k_shot(
            one_shot_dev_exp, FLAGS.K, FLAGS.L, n=FLAGS.N,
            num_episodes=FLAGS.episodes, k_neighbours=FLAGS.k_neighbours,
            metric=FLAGS.metric, direct_match=FLAGS.direct_match,
            multimodal_model=test_few_shot_model,
            multimodal_embedding_func=None,  # create_embedding_model
            fine_tune_steps=FLAGS.fine_tune_steps,
            fine_tune_lr=FLAGS.fine_tune_lr)

        logging.log(
            logging.INFO,
            f"Base model: {FLAGS.L}-way {FLAGS.K}-shot accuracy after "
            f"{FLAGS.episodes} episodes: {val_task_accuracy:.3%} +- "
            f"{conf_interval_95*100:.4f}")

    # create training metrics
    loss_metric = tf.keras.metrics.Mean()
    best_model = False

    # store model options on first run
    if initial_model:
        file_io.write_json(
            os.path.join(output_dir, "model_options.json"), model_options)

    # train model
    for epoch in range(model_epochs, model_options["epochs"]):
        logging.log(logging.INFO, f"Epoch {epoch:03d}")

        loss_metric.reset_states()

        # train on epoch of training data
        step_pbar = tqdm(background_train_ds,
                         bar_format="{desc} [{elapsed},{rate_fmt}{postfix}]")
        for step, (speech_batch, image_batch) in enumerate(step_pbar):

            loss_value, y_speech, y_image = multimodal_model.train_step(
                speech_batch, image_batch, optimizer,
                clip_norm=model_options["gradient_clip_norm"])

            loss_metric.update_state(loss_value)

            step_loss = tf.reduce_mean(loss_value)
            train_loss = loss_metric.result().numpy()

            step_pbar.set_description_str(
                f"\tStep {step:03d}: "
                f"Step loss: {step_loss:.6f}, "
                f"Loss: {train_loss:.6f}")

            if tf_writer is not None:
                with tf_writer.as_default():
                    tf.summary.scalar(
                        "Train step loss", step_loss, step=global_step)
            global_step += 1

        # validate multimodal model
        loss_metric.reset_states()

        for speech_batch, image_batch in background_dev_ds:
            y_speech = multimodal_model.speech_model.predict(
                speech_batch, training=False)
            y_image = multimodal_model.vision_model.predict(
                image_batch, training=False)
            loss_value = multimodal_model.loss(y_speech, y_image)

            loss_metric.update_state(loss_value)

        dev_loss = loss_metric.result().numpy()

        # validate model on one-shot dev task if specified
        if model_options["one_shot_validation"]:

            if FLAGS.fine_tune_steps is not None:
                test_few_shot_model = create_fine_tune_model(
                    model_options, speech_network, vision_network)
            else:
                test_few_shot_model = base.WeaklySupervisedModel(
                    speech_network, vision_network, None,
                    mc_dropout=FLAGS.mc_dropout)

            val_task_accuracy, _, conf_interval_95 = experiment.test_multimodal_l_way_k_shot(
                one_shot_dev_exp, FLAGS.K, FLAGS.L, n=FLAGS.N,
                num_episodes=FLAGS.episodes, k_neighbours=FLAGS.k_neighbours,
                metric=FLAGS.metric, direct_match=FLAGS.direct_match,
                multimodal_model=test_few_shot_model,
                multimodal_embedding_func=None,  # create_embedding_model
                fine_tune_steps=FLAGS.fine_tune_steps,
                fine_tune_lr=FLAGS.fine_tune_lr)

            val_score = val_task_accuracy
            val_metric = f"{FLAGS.L}-way {FLAGS.K}-shot accuracy"

            if val_score >= best_val_score:
                best_val_score = val_score
                best_model = True

        # otherwise, validate on siamese task
        else:
            val_score = dev_loss
            val_metric = "loss"

            if val_score <= best_val_score:
                best_val_score = val_score
                best_model = True

        # log results
        logging.log(logging.INFO, f"Train: Loss: {train_loss:.6f}")
        logging.log(
            logging.INFO,
            f"Validation: Loss: {dev_loss:.6f} {'*' if best_model else ''}")

        if model_options["one_shot_validation"]:
            logging.log(
                logging.INFO,
                f"Validation: {FLAGS.L}-way {FLAGS.K}-shot accuracy after "
                f"{FLAGS.episodes} episodes: {val_task_accuracy:.3%} +- "
                f"{conf_interval_95*100:.4f} {'*' if best_model else ''}")

        if tf_writer is not None:
            with tf_writer.as_default():
                tf.summary.scalar(
                    "Train loss", train_loss, step=global_step)
                tf.summary.scalar(
                    "Validation loss", dev_loss, step=global_step)
                if model_options["one_shot_validation"]:
                    tf.summary.scalar(
                        f"Validation {FLAGS.L}-way {FLAGS.K}-shot accuracy",
                        val_task_accuracy, step=global_step)

        # store model and results
        model_utils.save_model(
            join_network_model, output_dir, epoch + 1, global_step,
            val_metric, val_score, best_val_score, name="model")

        if best_model:
            best_model = False
            model_utils.save_model(
                join_network_model, output_dir, epoch + 1, global_step,
                val_metric, val_score, best_val_score, name="best_model")