def train_with_augmentation():
    """Train the model returned by ``vgg13()`` on augmented ``fer2013()`` data.

    Training batches come from an ``ImageDataGenerator`` that rescales by
    1/255, rotates by up to 30 degrees and flips horizontally.  The
    validation and test arrays are multiplied by the same rescale factor so
    the model is evaluated on inputs at the scale it was trained on.

    Reads the module-level settings ``load_weights``, ``save_weights``,
    ``batch_size`` and ``nb_epoch``.  When ``save_weights`` is set, the
    weights are written immediately after training, before any reporting,
    so a failure in ``historic()`` / ``confusion_matrix()`` cannot discard
    the trained model.

    Returns:
        None.  The training history is passed to ``historic()`` and the
        test-set predictions to ``confusion_matrix()``.
    """
    # One factor shared by the generator and the held-out splits.
    rescale = 1. / 255
    weights_path = 'model_vgg_13_aug.h5'

    datagen = ImageDataGenerator(
        rescale=rescale,
        rotation_range=30.,
        horizontal_flip=True,
    )
    model = vgg13()
    (X_train, Y_train), (X_test, Y_test), (X_validation, Y_validation) = fer2013()

    # The generator only rescales what it yields (the training batches).
    # Apply the same factor to the arrays fed directly to the model;
    # otherwise validation/test inputs are 255x larger than training inputs.
    # NOTE(review): assumes fer2013() returns NumPy arrays -- confirm.  If
    # fer2013() already normalises pixels, drop `rescale` everywhere instead.
    X_validation = X_validation * rescale
    X_test = X_test * rescale

    if load_weights:
        model.load_weights(weights_path)

    history = model.fit_generator(
        datagen.flow(X_train, Y_train, batch_size=batch_size),
        samples_per_epoch=50000,
        nb_epoch=nb_epoch,
        validation_data=(X_validation, Y_validation),
    )

    # Persist first: matches the other training functions in this file and
    # protects the run if the reporting helpers below raise.
    if save_weights:
        model.save_weights(weights_path)

    predictions = model.predict(X_test, batch_size=batch_size, verbose=1)
    historic(history)
    confusion_matrix(predictions, Y_test)
def _report_split(model, X, Y):
    """Print ``model.evaluate`` scores for one split and report its predictions.

    The predictions for ``X`` are handed to ``confusion_matrix()`` together
    with the expected labels ``Y``.
    """
    scores = model.evaluate(X, Y)
    print(scores)
    predictions = model.predict(X, batch_size=batch_size, verbose=1)
    confusion_matrix(predictions, Y)


def test(model=None):
    """Evaluate a model on every split of the ``ck()`` and ``fer2013()`` data.

    Args:
        model: Model to evaluate.  When ``None``, a model is built with
            ``basic()`` and its weights are loaded from
            ``'basic_fer2013.h5'``.

    For each dataset (``ck()`` first, then ``fer2013()``) the train,
    validation and test splits are processed in that order: the evaluation
    scores are printed and the predictions are passed to
    ``confusion_matrix()``.  Uses the module-level ``batch_size``.

    Returns:
        None.
    """
    # Identity check: `== None` can be hijacked by a custom __eq__.
    if model is None:
        model = basic()
        model.load_weights('basic_fer2013.h5')

    # Each loader is called only when its turn comes, so the second dataset
    # is not loaded until the first has been fully reported.
    for load_dataset in (ck, fer2013):
        (X_train, Y_train), (X_test, Y_test), (X_validation, Y_validation) = load_dataset()
        splits = (
            (X_train, Y_train),
            (X_validation, Y_validation),
            (X_test, Y_test),
        )
        for X, Y in splits:
            _report_split(model, X, Y)
def train_without_augmentation():
    """Train the ``vgg16()`` model on ``fer2013()`` data without augmentation.

    Builds the model with ``lr=0.0001, dropout_in=0.25, dropout_out=0.5``,
    optionally loads initial weights, fits on the training split with the
    validation split as ``validation_data``, optionally saves the weights,
    then evaluates on the test split.

    Reads the module-level settings ``load_weights``, ``save_weights``,
    ``batch_size`` and ``nb_epoch``.

    Returns:
        None.  The test-set evaluation scores are printed, the training
        history is passed to ``historic()`` and the test-set predictions to
        ``confusion_matrix()``.
    """
    model = vgg16(lr=0.0001, dropout_in=0.25, dropout_out=0.5)
    (X_train, Y_train), (X_test, Y_test), (X_validation, Y_validation) = fer2013()

    # NOTE(review): initial weights come from a different file than the one
    # written below -- confirm this warm-start source is intentional.
    if load_weights:
        model.load_weights('model_vgg_16_eq.h5')

    history = model.fit(
        X_train,
        Y_train,
        batch_size=batch_size,
        nb_epoch=nb_epoch,
        validation_data=(X_validation, Y_validation),
        shuffle=True,
    )

    if save_weights:
        model.save_weights('vgg16_fer2013_np.h5')

    predictions = model.predict(X_test, batch_size=batch_size, verbose=1)

    # Report the evaluation result instead of discarding it, as test() does.
    scores = model.evaluate(X_test, Y_test)
    print(scores)

    historic(history)
    confusion_matrix(predictions, Y_test)
def train_without_augmentation_basic():
    """Train the ``basic()`` model on ``fer2013()`` data without augmentation.

    Optionally loads initial weights, fits on the training split with the
    validation split as ``validation_data``, optionally saves the weights,
    then evaluates on the test split.  Weights are loaded from and saved to
    the same file, ``'basic_fer2013.h5'``.

    Reads the module-level settings ``load_weights``, ``save_weights``,
    ``batch_size`` and ``nb_epoch``.

    Returns:
        None.  The test-set evaluation scores and the recorded training
        metrics are printed, the training history is passed to
        ``historic()`` and the test-set predictions to
        ``confusion_matrix()``.
    """
    # The checkpoint must belong to the basic() architecture.  The previous
    # code loaded 'model_vgg_16_eq.h5' here, which is named as a VGG-16
    # checkpoint and is not the file this function saves.
    weights_path = 'basic_fer2013.h5'

    model = basic()
    (X_train, Y_train), (X_test, Y_test), (X_validation, Y_validation) = fer2013()

    if load_weights:
        model.load_weights(weights_path)

    history = model.fit(
        X_train,
        Y_train,
        batch_size=batch_size,
        nb_epoch=nb_epoch,
        validation_data=(X_validation, Y_validation),
        shuffle=True,
    )

    if save_weights:
        model.save_weights(weights_path)

    predictions = model.predict(X_test, batch_size=batch_size, verbose=1)

    # Report the evaluation result instead of discarding it, as test() does.
    scores = model.evaluate(X_test, Y_test)
    print(scores)

    # Print the recorded per-epoch metrics rather than the object's repr.
    print(history.history)

    historic(history)
    confusion_matrix(predictions, Y_test)
# NOTE(review): fer2013 and basic are not used by the entry point below; the
# imports are retained in case other code relies on them being importable
# from this module -- confirm before removing.
from preprocess import fer2013
from models.vgg import basic
from train import train_without_augmentation

if __name__ == '__main__':
    # train_without_augmentation() takes no arguments: it loads the dataset
    # and builds its own model.  Loading the data or constructing a model
    # here would only repeat that work and then be thrown away.
    train_without_augmentation()