Code example #1
def run_all_folds(
    fm: FineModel,
    depth_index,
    lr_index,
    epochs,
    train_gens,
    val_gens,
    fold_index=None,
    generate_plots=True,
):
    """
    Train the model (frozen at the given depth) for all K folds OR a single
    specified fold. Weights, history, and results are saved under instance
    keys in the following format:

    D01_L03_F01:
    1st freeze depth, 3rd learning rate, fold 1
    D01_L03_F01_E025:
    1st freeze depth, 3rd learning rate, fold 1, trained until the 25th epoch

    :param fm:
    The FineModel to train, i.e., the base network being fine-tuned

    :param depth_index:
    The INDEX of the "freeze depth" for the given FineModel

    :param lr_index:
    The INDEX of the learning rate, i.e., lr = LEARNING_RATES[lr_index]

    :param epochs:
    Number of epochs to train. Should be a multiple of the save interval T;
    the final chunk is clamped if it is not.

    :param train_gens:
    List of train ImageDataGenerators for each fold

    :param val_gens:
    List of validation ImageDataGenerators for each fold

    :param fold_index:
    If specified, only the given fold index is trained

    :param generate_plots:
    Whether to generate analysis plots after all K folds have been trained
    """
    _depth_key = "D{:02}"
    _final_key = _depth_key + "_FINAL"
    _fold_key = _depth_key + "_L{:02}_F{:02}"
    _epoch_key = _fold_key + "_E{:03}"
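    # e.g. _fold_key.format(1, 3, 1)      -> "D01_L03_F01"
    #      _epoch_key.format(1, 3, 1, 25) -> "D01_L03_F01_E025"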

    lr = LEARNING_RATES[lr_index]

    folds = range(K)
    if fold_index is not None:
        if fold_index < 0 or K <= fold_index:
            raise IndexError("Invalid fold_index: {}".format(fold_index))
        folds = [fold_index]
        print("Fold index {} specified".format(fold_index))

    # Train model K times, one for each fold
    for i in folds:
        fold_key = _fold_key.format(depth_index, lr_index, i)

        # Warm-start from the final weights of the previous freeze depth;
        # for the first depth stage, reload the base model from scratch
        previous_depth_index = depth_index - 1
        if previous_depth_index < 0:
            fm.reload_model()
        else:
            fm.load_weights(_final_key.format(previous_depth_index))
        fm.set_depth(depth_index)
        fm.compile_model(lr=lr)
        model = fm.get_model()

        print("[DEBUG] Batch size: {}".format(BATCH_SIZE))
        print("[DEBUG] Number of images: {}".format(train_gens[i].n))
        print("[DEBUG] Steps: {}".format(len(train_gens[i])))

        # Train T epochs at a time
        start_epoch = 0
        save_interval = T
        # Reset training history
        ch.reset_history(fm.get_key(), fold_key)
        while start_epoch < epochs:
            print("[DEBUG] Starting epoch {}".format(start_epoch))
            target_epoch = min(start_epoch + save_interval, epochs)
            result = model.fit_generator(
                train_gens[i],
                validation_data=val_gens[i],
                steps_per_epoch=len(train_gens[i]),
                validation_steps=len(val_gens[i]),
                workers=MULTIPROCESSING_WORKERS,
                use_multiprocessing=USE_MULTIPROCESSING,
                shuffle=True,
                epochs=target_epoch,
                initial_epoch=start_epoch,
            )
            start_epoch = target_epoch

            # Update training history every T epochs
            ch.append_history(result.history, fm.get_key(), fold_key)

            # Save intermediate weights every T epochs
            if SAVE_ALL_WEIGHTS:
                epoch_key = _epoch_key.format(depth_index, lr_index, i,
                                              target_epoch)
                fm.save_weights(epoch_key)

        # Save final weights
        fm.save_weights(fold_key)

    if fold_index is None and generate_plots:
        # Only generate analysis when running all K folds
        print("[debug] generating analysis of training process")
        for metric_key in analysis.metric_names:
            analysis.analyze_lr(fm, fm.get_key(), depth_index, lr_index, lr,
                                metric_key)
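
The loop above trains in chunks of `save_interval` (= T) epochs, appending
history and optionally checkpointing weights after each chunk, and clamping
the final chunk so `epochs` need not be an exact multiple of T. A
self-contained sketch of just that scheduling logic (the `epoch_chunks`
helper is illustrative, not part of the project):

def epoch_chunks(total_epochs, save_interval):
    # Yield (initial_epoch, epochs) pairs for Keras-style resumable fits,
    # clamping the final chunk at the total epoch count.
    start = 0
    while start < total_epochs:
        target = min(start + save_interval, total_epochs)
        yield start, target
        start = target

print(list(epoch_chunks(12, 5)))  # [(0, 5), (5, 10), (10, 12)]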
Code example #2
def run(fm: FineModel,
        training_set: cri.CrCollection,
        epochs=EPOCHS,
        depth_index=DEPTH_INDEX,
        batch_size=BATCH_SIZE,
        augment_factor=BALANCE,
        learning_rate=LEARNING_RATE,
        save_interval=T,
        use_multiprocessing=USE_MULTIPROCESSING,
        workers=MULTIPROCESSING_WORKERS):
    """
    Train the model and evaluate the results. Output files are saved to
    `output/<model_key>/D<depth_index>_FINAL/` (e.g. `D00_FINAL` for depth
    index 0). These include:

    - Intermediate model weights
    - Final model weights
    - Test set results
    - Training history
    """
    _depth_key = 'D{:02d}_FINAL'
    instance_key = _depth_key.format(depth_index)
    _epoch_key = instance_key + "_E{:03}"

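    # Warm-start from the previous depth stage's final weights, if any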
    if depth_index >= 1:
        fm.load_weights(_depth_key.format(depth_index - 1))
    fm.set_depth(depth_index)
    fm.compile_model(lr=learning_rate)
    model = fm.get_model()

    gen = fm.get_directory_iterator(training_set,
                                    'train',
                                    augment=True,
                                    augment_factor=augment_factor,
                                    shuffle=True,
                                    batch_size=batch_size,
                                    verbose=1,
                                    title='final training set')

    print("[DEBUG] Batch size: {}".format(batch_size))
    print("[DEBUG] Number of images: {}".format(gen.n))
    print("[DEBUG] Steps: {}".format(len(gen)))

    # Train T epochs at a time
    start_epoch = 0
    # Reset training history
    ch.reset_history(fm.get_key(), instance_key)
    while start_epoch < epochs:
        print("[DEBUG] Starting epoch {}".format(start_epoch))
        target_epoch = min(start_epoch + save_interval, epochs)
        result = model.fit_generator(
            gen,
            steps_per_epoch=len(gen),
            shuffle=True,
            epochs=target_epoch,
            use_multiprocessing=use_multiprocessing,
            workers=workers,
            initial_epoch=start_epoch,
        )
        start_epoch = target_epoch

        # Update training history every T epochs
        ch.append_history(result.history, fm.get_key(), instance_key)

        # Save intermediate weights every T epochs
        if SAVE_ALL_WEIGHTS:
            epoch_key = _epoch_key.format(target_epoch)
            fm.save_weights(epoch_key)

    # Save final weights
    fm.save_weights(instance_key)

    # Generate test results
    print("[DEBUG] Generating test results...")
    results.generate_test_result(
        fm,
        instance_key,
        learning_rate,
        epochs,
        load_weights=False,
        workers=workers,
        use_multiprocessing=use_multiprocessing,
    )
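
Both examples delegate history bookkeeping to `ch.reset_history` /
`ch.append_history`, accumulating the `History.history` dict returned by
each training chunk. A minimal sketch of what such an append-style store
could look like; this is a hypothetical stand-in, not the project's actual
`ch` module:

from collections import defaultdict

_history = defaultdict(lambda: defaultdict(list))

def reset_history(model_key, instance_key):
    # Drop any previously accumulated metrics for this instance
    _history[(model_key, instance_key)].clear()

def append_history(chunk, model_key, instance_key):
    # `chunk` is a Keras History.history dict: {metric: [value per epoch]}
    for metric, values in chunk.items():
        _history[(model_key, instance_key)][metric].extend(values)

# Hypothetical usage with made-up keys and values:
reset_history("mobilenet_a", "D00_FINAL")
append_history({"loss": [0.9, 0.7]}, "mobilenet_a", "D00_FINAL")
append_history({"loss": [0.6]}, "mobilenet_a", "D00_FINAL")
print(_history[("mobilenet_a", "D00_FINAL")]["loss"])  # [0.9, 0.7, 0.6]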