def class_avg_chainthaw(model, nb_classes, train, val, test, batch_size,
                        loss, epoch_size, nb_epochs, checkpoint_weight_path,
                        f1_init_weight_path, patience=5,
                        initial_lr=0.001, next_lr=0.0001,
                        seed=None, verbose=True):
    """ Finetunes the given model with chain-thaw and scores it by
        class-averaged F1.

        When the dataset has more than two classes, one training run is
        performed per class, with the labels rewritten (via `prepare_labels`)
        for that class index. With two classes or fewer a single run is
        performed. The returned value is the mean of the per-run test F1
        scores.

    # Arguments:
        model: Model to be finetuned.
        nb_classes: Number of classes in the given dataset.
        train: Training data, given as a tuple of (inputs, outputs)
        val: Validation data, given as a tuple of (inputs, outputs)
        test: Testing data, given as a tuple of (inputs, outputs)
        batch_size: Batch size.
        loss: Loss function to be used during training.
        epoch_size: Number of samples in an epoch.
        nb_epochs: Number of epochs.
        checkpoint_weight_path: Filepath where weights will be checkpointed
            to during training. This file will be rewritten by the function.
        f1_init_weight_path: Filepath where the starting weights are saved
            once up front and reloaded from before every run, so that each
            run starts from the same weights. This file will be rewritten.
        patience: Forwarded to `finetuning_callbacks` (presumably the
            early-stopping patience -- confirm against that helper).
        initial_lr: Initial learning rate. Will only be used for the first
            training step (i.e. the softmax layer)
        next_lr: Learning rate for every subsequent step.
        seed: Random number generator seed.
            NOTE(review): not referenced anywhere in this function's body.
        verbose: Verbosity flag.

    # Returns:
        Averaged F1 score.
    """
    # NOTE(review): `class_avg_chainthaw` is defined a second time, with the
    # same signature, further down in this file -- confirm whether one of the
    # two definitions should be removed.

    inputs_train, labels_train = train
    inputs_val, labels_val = val
    inputs_test, labels_test = test

    # A binary (or smaller) problem needs only a single run.
    nb_iter = 1
    if nb_classes > 2:
        nb_iter = nb_classes

    # Snapshot the starting weights so they can be restored before each run.
    model.save_weights(f1_init_weight_path)

    f1_sum = 0
    for class_idx in range(nb_iter):
        if verbose:
            print('Iteration number {}/{}'.format(class_idx + 1, nb_iter))

        # Restore the snapshot taken before the loop.
        model.load_weights(f1_init_weight_path, by_name=False)

        relabeled = prepare_labels(labels_train, labels_val, labels_test,
                                   class_idx, nb_classes)
        cur_train_labels, cur_val_labels, cur_test_labels = relabeled

        generators = prepare_generators(inputs_train, cur_train_labels,
                                        inputs_val, cur_val_labels,
                                        batch_size, epoch_size)
        train_gen, val_inputs_resamp, val_labels_resamp = generators

        if verbose:
            print("Training..")
        callbacks = finetuning_callbacks(checkpoint_weight_path,
                                         patience=patience)

        # Chain-thaw training for this run.
        train_by_chain_thaw(
            model=model,
            train_gen=train_gen,
            val_data=(val_inputs_resamp, val_labels_resamp),
            loss=loss,
            callbacks=callbacks,
            epoch_size=epoch_size,
            nb_epochs=nb_epochs,
            checkpoint_weight_path=checkpoint_weight_path,
            initial_lr=initial_lr,
            next_lr=next_lr,
            batch_size=batch_size,
            verbose=verbose,
        )

        # Predict on the original (non-resampled) validation inputs and on
        # the test inputs, then hand both to the threshold search.
        val_output = model.predict(inputs_val, batch_size=batch_size)
        val_predictions = np.array(val_output)
        test_output = model.predict(inputs_test, batch_size=batch_size)
        test_predictions = np.array(test_output)

        f1_test, best_t = find_f1_threshold(cur_val_labels, val_predictions,
                                            cur_test_labels, test_predictions)

        if verbose:
            print('f1_test: {}'.format(f1_test))
            print('best_t: {}'.format(best_t))

        f1_sum += f1_test

    return f1_sum / nb_iter
def class_avg_chainthaw(model, nb_classes, train, val, test, batch_size,
                        loss, epoch_size, nb_epochs, checkpoint_weight_path,
                        f1_init_weight_path, patience=5,
                        initial_lr=0.001, next_lr=0.0001,
                        seed=None, verbose=True):
    """ Runs chain-thaw finetuning on the given model and reports the
        F1 score averaged over classes.

        For datasets with more than two classes the model is trained once
        per class index, each time on labels produced by `prepare_labels`
        for that index; otherwise it is trained exactly once. The test F1
        of every run is accumulated and the mean is returned.

    # Arguments:
        model: Model to be finetuned.
        nb_classes: Number of classes in the given dataset.
        train: Training data, given as a tuple of (inputs, outputs)
        val: Validation data, given as a tuple of (inputs, outputs)
        test: Testing data, given as a tuple of (inputs, outputs)
        batch_size: Batch size.
        loss: Loss function to be used during training.
        epoch_size: Number of samples in an epoch.
        nb_epochs: Number of epochs.
        checkpoint_weight_path: Filepath where weights will be checkpointed
            to during training. This file will be rewritten by the function.
        f1_init_weight_path: Filepath the initial weights are written to
            before the first run and read back from at the start of every
            run, keeping the runs independent. This file will be rewritten.
        patience: Passed through to `finetuning_callbacks` (presumably the
            early-stopping patience -- confirm against that helper).
        initial_lr: Initial learning rate. Will only be used for the first
            training step (i.e. the softmax layer)
        next_lr: Learning rate for every subsequent step.
        seed: Random number generator seed.
            NOTE(review): this argument is never read inside this function.
        verbose: Verbosity flag.

    # Returns:
        Averaged F1 score.
    """
    # NOTE(review): an earlier definition of `class_avg_chainthaw` with the
    # same signature exists above in this file -- confirm whether one of the
    # two definitions should be removed.

    train_x, train_y = train
    val_x, val_y = val
    test_x, test_y = test

    # One run per class, except that two or fewer classes collapse to one.
    nb_iter = 1
    if nb_classes > 2:
        nb_iter = nb_classes

    # Persist the untouched weights; every run reloads them below.
    model.save_weights(f1_init_weight_path)

    accumulated_f1 = 0
    for run in range(nb_iter):
        if verbose:
            print('Iteration number {}/{}'.format(run + 1, nb_iter))

        # Start this run from the weights saved before the loop.
        model.load_weights(f1_init_weight_path, by_name=False)

        run_labels = prepare_labels(train_y, val_y, test_y, run, nb_classes)
        train_y_run, val_y_run, test_y_run = run_labels

        prepared = prepare_generators(train_x, train_y_run,
                                      val_x, val_y_run,
                                      batch_size, epoch_size)
        train_gen, val_x_resamp, val_y_resamp = prepared

        if verbose:
            print("Training..")
        callbacks = finetuning_callbacks(checkpoint_weight_path,
                                         patience=patience)

        # Train using chain-thaw.
        train_by_chain_thaw(
            model=model,
            train_gen=train_gen,
            val_data=(val_x_resamp, val_y_resamp),
            loss=loss,
            callbacks=callbacks,
            epoch_size=epoch_size,
            nb_epochs=nb_epochs,
            checkpoint_weight_path=checkpoint_weight_path,
            initial_lr=initial_lr,
            next_lr=next_lr,
            batch_size=batch_size,
            verbose=verbose,
        )

        # Evaluate: predictions come from the full validation and test
        # inputs, not from the resampled validation set used in training.
        raw_val = model.predict(val_x, batch_size=batch_size)
        pred_val = np.array(raw_val)
        raw_test = model.predict(test_x, batch_size=batch_size)
        pred_test = np.array(raw_test)

        f1_test, best_t = find_f1_threshold(val_y_run, pred_val,
                                            test_y_run, pred_test)

        if verbose:
            print('f1_test: {}'.format(f1_test))
            print('best_t: {}'.format(best_t))

        accumulated_f1 += f1_test

    return accumulated_f1 / nb_iter