示例#1
0
def continue_training():
    """Resume training of the SqueezeNet-v1.1 chesspiece model with every layer unfrozen.

    Loads ``./models/SqueezeNet1p1.h5``, marks all of its layers trainable,
    recompiles it with Adam / categorical cross-entropy, and trains it via
    ``train_model`` (100 is passed as the second argument, presumably the
    epoch count). Afterwards it writes accuracy and loss plots, evaluates on
    the validation generator and saves the final model to
    ``./models/SqueezeNet1p1_all_last.h5``.
    """
    net = load_model("./models/SqueezeNet1p1.h5")

    # SqueezeNet-v1.1 is fed 227x227 images; 64 is passed as the third arg
    # (presumably the batch size -- defined by data_generators).
    train_gen, val_gen = data_generators(preprocess_input, (227, 227), 64)

    # Unfreeze every layer so the whole network is fine-tuned.
    for lyr in net.layers:
        lyr.trainable = True

    # Keras only picks up changed `trainable` flags after a (re)compile.
    net.compile(
        optimizer='Adam',
        loss='categorical_crossentropy',
        metrics=['accuracy'],
    )

    # NOTE(review): the meaning of the numeric arguments is defined by
    # model_callbacks, which is not visible here.
    cbs = model_callbacks(20, "./models/SqueezeNet1p1_all.h5", 0.2, 8)

    hist = train_model(
        net, 100, train_gen, val_gen, cbs,
        use_weights=False, workers=5,
    )

    plot_model_history(
        hist,
        "./models/SqueezeNet1p1_all_acc.png",
        "./models/SqueezeNet1p1_all_loss.png",
    )
    evaluate_model(net, val_gen)

    net.save("./models/SqueezeNet1p1_all_last.h5")
示例#2
0
def continue_training():
    """Continue training the AlexNet chesspiece model from its saved checkpoint.

    Loads ``./models/AlexNet.h5``, recompiles it with Adam at a learning rate
    of 1e-4 and categorical cross-entropy, and trains it via ``train_model``
    (100 is passed as the second argument, presumably the epoch count).
    Afterwards it writes accuracy and loss plots, evaluates on the validation
    generator and saves the final model to ``./models/AlexNet_2_last.h5``.
    """
    model = load_model("./models/AlexNet.h5")

    train_generator, validation_generator = data_generators(
        preprocess_input, (224, 224), 64)

    # Use the `learning_rate` keyword: `lr` is a deprecated alias that newer
    # Keras optimizers either drop with only a warning (silently training at
    # the default 1e-3 instead of 1e-4) or reject with a ValueError.
    model.compile(optimizer=Adam(learning_rate=1e-4),
                  loss='categorical_crossentropy',
                  metrics=['accuracy'])

    # NOTE(review): the meaning of the numeric arguments is defined by
    # model_callbacks, which is not visible here.
    callbacks = model_callbacks(20, "./models/AlexNet_2.h5", 0.2, 8)

    history = train_model(model,
                          100,
                          train_generator,
                          validation_generator,
                          callbacks,
                          use_weights=False,
                          workers=5)

    plot_model_history(history, "./models/AlexNet_2_acc.png",
                       "./models/AlexNet_2_loss.png")
    evaluate_model(model, validation_generator)

    model.save("./models/AlexNet_2_last.h5")
示例#3
0
def train_chesspiece_model():
    """Train the MobileNetV2 (alpha=0.5) chesspiece classifier in two stages.

    Stage 1 freezes the ImageNet-pretrained backbone, so only the layers that
    ``build_model`` adds on top can learn, and trains with 20 as the second
    ``train_model`` argument. Stage 2 keeps ``model.layers[:126]`` frozen and
    unfreezes ``model.layers[126:]`` (blocks 14-16 per the original author).
    It then recompiles with Adam and trains with 100. Each stage writes
    accuracy/loss plots under ``./models/`` and then evaluates on the
    validation generator. The final model is saved to
    ``./models/MobileNetV2_0p5_last.h5``.
    """
    backbone = MobileNetV2(input_shape=(224, 224, 3),
                           include_top=False,
                           weights='imagenet',
                           alpha=0.5)

    # Stage 1: freeze the pretrained backbone; only the new top layers learn.
    for lyr in backbone.layers:
        lyr.trainable = False

    model = build_model(backbone)

    train_gen, val_gen = data_generators(preprocess_input, (224, 224), 64)

    def run_stage(cbs, epochs, acc_png, loss_png):
        # Train, plot the history and evaluate -- the sequence both stages share.
        hist = train_model(model, epochs, train_gen, val_gen, cbs,
                           use_weights=False, workers=10)
        plot_model_history(hist, acc_png, loss_png)
        evaluate_model(model, val_gen)

    run_stage(model_callbacks(5, "./models/MobileNetV2_0p5_pre.h5", 0.1, 10),
              20,
              "./models/MobileNetV2_0p5_pre_acc.png",
              "./models/MobileNetV2_0p5_pre_loss.png")

    # Stage 2: layers before index 126 stay frozen, the rest become trainable.
    # NOTE(review): the hard-coded index depends on the Keras MobileNetV2
    # layer layout and on build_model flattening the backbone into model.layers.
    for idx, lyr in enumerate(model.layers):
        lyr.trainable = idx >= 126

    # Recompile so the new trainable flags take effect.
    model.compile(optimizer='Adam',
                  loss='categorical_crossentropy',
                  metrics=['accuracy'])

    run_stage(model_callbacks(20, "./models/MobileNetV2_0p5.h5", 0.2, 8),
              100,
              "./models/MobileNetV2_0p5_acc.png",
              "./models/MobileNetV2_0p5_loss.png")

    model.save("./models/MobileNetV2_0p5_last.h5")
示例#4
0
def train_chesspiece_model():
    """Train the SqueezeNet-v1.1 chesspiece classifier in two stages.

    Stage 1 freezes the ImageNet-pretrained backbone, so only the layers that
    ``build_model`` adds on top can learn, and trains with 20 as the second
    ``train_model`` argument. Stage 2 keeps ``model.layers[:41]`` frozen and
    unfreezes ``model.layers[41:]`` (fire modules 7-9 per the original
    author). It then recompiles with Adam and trains with 100. Each stage
    writes accuracy/loss plots under ``./models/`` and then evaluates on the
    validation generator. The final model is saved to
    ``./models/SqueezeNet1p1_last.h5``.
    """
    backbone = SqueezeNet(input_shape=(227, 227, 3), include_top=False,
                          weights='imagenet')

    # Stage 1: freeze the pretrained backbone; only the new top layers learn.
    for lyr in backbone.layers:
        lyr.trainable = False

    model = build_model(backbone)

    train_gen, val_gen = data_generators(preprocess_input, (227, 227), 64)

    def run_stage(cbs, epochs, acc_png, loss_png):
        # Train, plot the history and evaluate -- the sequence both stages share.
        hist = train_model(model, epochs, train_gen, val_gen, cbs,
                           use_weights=False, workers=5)
        plot_model_history(hist, acc_png, loss_png)
        evaluate_model(model, val_gen)

    run_stage(model_callbacks(5, "./models/SqueezeNet1p1_pre.h5", 0.1, 10),
              20,
              "./models/SqueezeNet1p1_pre_acc.png",
              "./models/SqueezeNet1p1_pre_loss.png")

    # Stage 2: layers before index 41 stay frozen, the rest become trainable.
    # NOTE(review): the hard-coded index depends on the SqueezeNet layer layout
    # and on build_model flattening the backbone into model.layers.
    for idx, lyr in enumerate(model.layers):
        lyr.trainable = idx >= 41

    # Recompile so the new trainable flags take effect.
    model.compile(optimizer='Adam', loss='categorical_crossentropy',
                  metrics=['accuracy'])

    run_stage(model_callbacks(20, "./models/SqueezeNet1p1.h5", 0.2, 8),
              100,
              "./models/SqueezeNet1p1_acc.png",
              "./models/SqueezeNet1p1_loss.png")

    model.save("./models/SqueezeNet1p1_last.h5")
示例#5
0
def train_chesspiece_model():
    """Train the AlexNet chesspiece classifier.

    Builds the network with ``alexnet`` for 224x224x3 inputs and trains it via
    ``train_model`` (100 is passed as the second argument, presumably the
    epoch count). It then writes accuracy/loss plots, evaluates on the
    validation generator and saves the final model to
    ``./models/AlexNet_last.h5``.
    """
    # NOTE(review): there is no compile() call here, so alexnet() presumably
    # returns an already-compiled model -- confirm.
    net = alexnet(input_shape=(224, 224, 3))

    gen_train, gen_val = data_generators(preprocess_input, (224, 224), 64)

    # NOTE(review): the meaning of the numeric arguments is defined by
    # model_callbacks, which is not visible here.
    cbs = model_callbacks(20, "./models/AlexNet.h5", 0.2, 8)

    hist = train_model(net, 100, gen_train, gen_val, cbs,
                       use_weights=False, workers=5)

    plot_model_history(
        hist,
        "./models/AlexNet_acc.png",
        "./models/AlexNet_loss.png",
    )
    evaluate_model(net, gen_val)

    net.save("./models/AlexNet_last.h5")