def run_all_folds(
    fm: FineModel,
    depth_index,
    lr_index,
    epochs,
    train_gens,
    val_gens,
    fold_index=None,
    generate_plots=True,
):
    """
    Train the model (frozen at some depth) for all K folds OR a single
    specified fold. Weights, history and results are saved using instance
    keys in the following format (all indices are zero-based):

    D01_L03_F01: freeze depth index 1, learning rate index 3, fold index 1
    D01_L03_F01_E025: freeze depth index 1, learning rate index 3, fold
    index 1, trained until the 25th epoch

    :param fm: FineModel to train, i.e., the base network to train on
    :param depth_index: The INDEX of the "freeze depth" for the given FineModel
    :param lr_index: The INDEX of the learning rate, i.e., lr = LEARNING_RATES[lr_index]
    :param epochs: Number of epochs to train. Training runs in increments of
    T epochs; the final increment is truncated if epochs is not a multiple of T.
    :param train_gens: List of training ImageDataGenerators, one per fold
    :param val_gens: List of validation ImageDataGenerators, one per fold
    :param fold_index: If specified, only run this fold
    :param generate_plots: If True and all folds are run, generate analysis
    plots for each metric
    """
    _depth_key = "D{:02}"
    _final_key = _depth_key + "_FINAL"
    _fold_key = _depth_key + "_L{:02}_F{:02}"
    _epoch_key = _fold_key + "_E{:03}"

    lr = LEARNING_RATES[lr_index]
    folds = range(K)
    if fold_index is not None:
        if fold_index < 0 or K <= fold_index:
            raise IndexError("Invalid fold_index: {}".format(fold_index))
        folds = [fold_index]
        print("Fold index {} specified".format(fold_index))

    # Train the model K times, once per fold
    for i in folds:
        fold_key = _fold_key.format(depth_index, lr_index, i)

        # Start from the previous freeze depth's final weights, or from the
        # pretrained base model if this is the first depth
        previous_depth_index = depth_index - 1
        if previous_depth_index < 0:
            fm.reload_model()
        else:
            fm.load_weights(_final_key.format(previous_depth_index))
        fm.set_depth(depth_index)
        fm.compile_model(lr=lr)
        model = fm.get_model()

        print("[DEBUG] Batch size: {}".format(BATCH_SIZE))
        print("[DEBUG] Number of images: {}".format(train_gens[i].n))
        print("[DEBUG] Steps: {}".format(len(train_gens[i])))

        # Train T epochs at a time
        start_epoch = 0
        save_interval = T

        # Reset training history
        ch.reset_history(fm.get_key(), fold_key)
        while start_epoch < epochs:
            print("[DEBUG] Starting epoch {}".format(start_epoch))
            target_epoch = start_epoch + save_interval
            if target_epoch > epochs:
                target_epoch = epochs
            result = model.fit_generator(
                train_gens[i],
                validation_data=val_gens[i],
                steps_per_epoch=len(train_gens[i]),
                validation_steps=len(val_gens[i]),
                workers=MULTIPROCESSING_WORKERS,
                use_multiprocessing=USE_MULTIPROCESSING,
                shuffle=True,
                epochs=target_epoch,
                initial_epoch=start_epoch,
            )
            start_epoch = target_epoch

            # Update training history every T epochs
            ch.append_history(result.history, fm.get_key(), fold_key)

            # Save intermediate weights every T epochs
            if SAVE_ALL_WEIGHTS:
                epoch_key = _epoch_key.format(depth_index, lr_index, i,
                                              target_epoch)
                fm.save_weights(epoch_key)

        # Save final weights
        fm.save_weights(fold_key)

    if fold_index is None and generate_plots:
        # Only generate analysis when running all K folds
        print("[DEBUG] Generating analysis of training process")
        for metric_key in analysis.metric_names:
            analysis.analyze_lr(fm, fm.get_key(), depth_index, lr_index, lr,
                                metric_key)
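
# Usage sketch (illustrative, not part of the original module): a
# cross-validated learning-rate search built on run_all_folds. `fm`,
# `train_gens` and `val_gens` are assumed to be constructed elsewhere in the
# project (one training/validation ImageDataGenerator per fold);
# LEARNING_RATES and EPOCHS are the module-level settings used above.
def _example_lr_search(fm: FineModel, train_gens, val_gens, depth_index=0):
    # One full K-fold run per candidate learning rate. Results are keyed as
    # D<depth>_L<lr>_F<fold>, so analyze_lr can compare rates afterwards.
    for lr_index in range(len(LEARNING_RATES)):
        run_all_folds(fm, depth_index, lr_index, EPOCHS,
                      train_gens, val_gens)
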
def run(fm: FineModel, training_set: cri.CrCollection,
        epochs=EPOCHS,
        depth_index=DEPTH_INDEX,
        batch_size=BATCH_SIZE,
        augment_factor=BALANCE,
        learning_rate=LEARNING_RATE,
        save_interval=T,
        use_multiprocessing=USE_MULTIPROCESSING,
        workers=MULTIPROCESSING_WORKERS):
    """
    Train the final model and evaluate the results. Output files are saved
    to `output/<model_key>/D<depth_index>_FINAL/` (e.g. `D00_FINAL` for
    depth index 0). These include:

    - Intermediate model weights
    - Final model weights
    - Test set results
    - Training history
    """
    _depth_key = 'D{:02d}_FINAL'
    instance_key = _depth_key.format(depth_index)
    _epoch_key = instance_key + "_E{:03}"

    # Resume from the previous freeze depth's final weights, if any
    if depth_index >= 1:
        fm.load_weights(_depth_key.format(depth_index - 1))
    fm.set_depth(depth_index)
    fm.compile_model(lr=learning_rate)
    model = fm.get_model()

    gen = fm.get_directory_iterator(training_set,
                                    'train',
                                    augment=True,
                                    augment_factor=augment_factor,
                                    shuffle=True,
                                    batch_size=batch_size,
                                    verbose=1,
                                    title='final training set')
    print("[DEBUG] Batch size: {}".format(batch_size))
    print("[DEBUG] Number of images: {}".format(gen.n))
    print("[DEBUG] Steps: {}".format(len(gen)))

    # Train save_interval epochs at a time
    start_epoch = 0

    # Reset training history
    ch.reset_history(fm.get_key(), instance_key)
    while start_epoch < epochs:
        print("[DEBUG] Starting epoch {}".format(start_epoch))
        target_epoch = start_epoch + save_interval
        if target_epoch > epochs:
            target_epoch = epochs
        result = model.fit_generator(
            gen,
            steps_per_epoch=len(gen),
            shuffle=True,
            epochs=target_epoch,
            use_multiprocessing=use_multiprocessing,
            workers=workers,
            initial_epoch=start_epoch,
        )
        start_epoch = target_epoch

        # Update training history every save_interval epochs
        ch.append_history(result.history, fm.get_key(), instance_key)

        # Save intermediate weights every save_interval epochs
        if SAVE_ALL_WEIGHTS:
            epoch_key = _epoch_key.format(target_epoch)
            fm.save_weights(epoch_key)

    # Save final weights
    fm.save_weights(instance_key)

    # Generate test results (using the worker settings passed to this
    # function rather than the module-level defaults)
    print("[DEBUG] Generating test results...")
    results.generate_test_result(
        fm,
        instance_key,
        learning_rate,
        epochs,
        load_weights=False,
        workers=workers,
        use_multiprocessing=use_multiprocessing,
    )
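
# Usage sketch (illustrative only): run() trains a single freeze depth and
# saves its weights under D<depth>_FINAL, which the next depth picks up as
# its starting point (see the load_weights call at the top of run()). A
# progressive fine-tuning pass might therefore chain depths as below; `fm`
# and `training_set` are assumed to be built elsewhere in the project.
def _example_progressive_finetune(fm: FineModel,
                                  training_set: cri.CrCollection,
                                  max_depth_index=1):
    # Each depth_index >= 1 resumes from the previous depth's final weights.
    for depth_index in range(max_depth_index + 1):
        run(fm, training_set, depth_index=depth_index)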