def main(plot=False, train=False, epochs=5):
    """Train or load the fully connected MNIST classifier, then evaluate it.

    Loads the real MNIST train/test split and the GAN-generated test set,
    preprocesses all three with ``preprocess_raw_mnist_data`` and builds
    ``fcnn_classifier()``. When ``train`` is set, the model is fitted with
    10% of the training data held out for validation. Its weights are then
    saved to ``weights/fcnn_clf_<epochs>.h5`` and accuracy/loss curves are
    written to ``output/``. Otherwise the weights are loaded from that same
    path. Finally the model is evaluated on the first 1000 real test images
    and the first 1000 GAN images.

    Args:
        plot: If True, call ``plot_mist`` on the training data and save the
            result to ``plots/test.png``.
        train: If True, train from scratch; otherwise load saved weights.
        epochs: Number of training epochs; also part of the weights filename.
    """
    import os

    # Get mnist train and test dataset
    (x_train, y_train), (x_test, y_test) = get_real_mnist()

    # Get gan test dataset
    (x_gan_test, y_gan_test) = get_gan_mnist()

    # Preprocess raw data
    print('preprocess raw data')
    x_train = preprocess_raw_mnist_data(x_train)
    x_test = preprocess_raw_mnist_data(x_test)
    x_gan_test = preprocess_raw_mnist_data(x_gan_test)

    # Build classifier
    fcnn_clf = fcnn_classifier()

    if train:
        # Train classifier
        print('\ntrain the classifier')

        history = fcnn_clf.fit(x_train,
                               y_train,
                               epochs=epochs,
                               validation_split=0.1)

        # Save weights
        fcnn_clf.save_weights('weights/fcnn_clf_%s.h5' % epochs)

        print(history.history.keys())
        # Older Keras versions log accuracy as 'acc', newer ones as 'accuracy'.
        acc_key = 'acc' if 'acc' in history.history else 'accuracy'

        # Each plot gets its own figure: with a non-interactive backend
        # plt.show() does not close the current figure, so the loss curves
        # would otherwise be drawn on top of the accuracy curves.
        plt.figure()
        plt.plot(history.history[acc_key])
        plt.plot(history.history['val_' + acc_key])
        plt.title("model accuracy")
        plt.ylabel('accuracy')
        plt.xlabel('epoch')
        plt.legend(['train', 'val'], loc='upper left')
        plt.savefig("output/fully_connected_model_accuracy.png")
        plt.show()

        # Plot the loss
        plt.figure()
        plt.plot(history.history['loss'])
        plt.plot(history.history['val_loss'])
        plt.title('model loss')
        plt.ylabel('loss')
        plt.xlabel('epoch')
        plt.legend(['train', 'val'], loc='upper left')
        plt.savefig("output/fully_connected_model_loss.png")
        plt.show()
    else:
        # Load the model weights
        weights_file_path = os.path.abspath(
            os.path.join(os.curdir, 'weights/fcnn_clf_%s.h5' % epochs))
        # Warn, naming the resolved path, only when the file is missing;
        # loading is still attempted afterwards.
        if not os.path.exists(weights_file_path):
            print("The weights file path specified does not exists: %s" %
                  weights_file_path)
        fcnn_clf.load_weights(weights_file_path)

    print('\ntest the classifier')
    test_loss, test_acc = fcnn_clf.evaluate(x_test[:1000], y_test[:1000])

    print('\n#######################################')
    print('Test loss:', test_loss)
    print('Test accuracy:', test_acc)

    print('\ntest the classifier on gan mnist')
    test_loss, test_acc = fcnn_clf.evaluate(x_gan_test[:1000],
                                            y_gan_test[:1000])

    print('\n#######################################')
    print('Test loss gan:', test_loss)
    print('Test accuracy gan:', test_acc)

    if plot:
        plot_mist(x_train, y_train, 9, save_file_path='plots/test.png')
def main(plot=False, train=False, epochs=5):
    """Train or load the convolutional MNIST classifier, then evaluate it.

    Loads the real MNIST train/test split and the GAN-generated test set and
    preprocesses them for a convolutional model (``conv=True``). When
    ``train`` is set, ``cnn_classifier()`` is fitted with a 10% validation
    split. Its weights are saved to ``weights/cnn_clf_<epochs>.h5`` and
    accuracy/loss curves are written to ``output/cnn_model_*.png``.
    Otherwise the weights are loaded from that path. The model is evaluated
    on the full real test set and on the first 100 GAN images, and the first
    class-0 test image is displayed as a sanity check.

    Args:
        plot: If True, call ``plot_mist`` on the training data and save the
            result to ``plots/test.png``.
        train: If True, train from scratch; otherwise load saved weights.
        epochs: Number of training epochs (default 5, the value previously
            hard-coded); also part of the weights filename.
    """
    import os

    # Get mnist train and test dataset
    (x_train, y_train), (x_test, y_test) = get_real_mnist()

    # Get gan test dataset
    (x_gan_test, y_gan_test) = get_gan_mnist()

    # Preprocess raw data
    print('preprocess raw data')
    x_train = preprocess_raw_mnist_data(x_train, conv=True)
    x_test = preprocess_raw_mnist_data(x_test, conv=True)
    x_gan_test = preprocess_raw_mnist_data(x_gan_test, conv=True)

    # Build classifier
    cnn_clf = cnn_classifier()

    if train:
        # Train classifier
        print('\ntrain the classifier')

        history = cnn_clf.fit(x_train,
                              y_train,
                              epochs=epochs,
                              validation_split=0.1)

        # Save weights
        cnn_clf.save_weights('weights/cnn_clf_%s.h5' % epochs)

        print(history.history.keys())
        # Older Keras versions log accuracy as 'acc', newer ones as 'accuracy'.
        acc_key = 'acc' if 'acc' in history.history else 'accuracy'

        # Each plot gets its own figure so curves never pile up when
        # plt.show() leaves the figure open (non-interactive backends).
        # The curves come from validation_split, hence the 'val' legend, and
        # the files use a CNN-specific name so the fully connected model's
        # plots are not overwritten.
        plt.figure()
        plt.plot(history.history[acc_key])
        plt.plot(history.history['val_' + acc_key])
        plt.title("model accuracy")
        plt.ylabel('accuracy')
        plt.xlabel('epoch')
        plt.legend(['train', 'val'], loc='upper left')
        plt.savefig("output/cnn_model_accuracy.png")
        plt.show()

        # Plot the loss
        plt.figure()
        plt.plot(history.history['loss'])
        plt.plot(history.history['val_loss'])
        plt.title('model loss')
        plt.ylabel('loss')
        plt.xlabel('epoch')
        plt.legend(['train', 'val'], loc='upper left')
        plt.savefig("output/cnn_model_loss.png")
        plt.show()

    else:
        # Load the model weights
        weights_file_path = os.path.abspath(
            os.path.join(os.curdir, 'weights/cnn_clf_%s.h5' % epochs))
        # Warn, naming the resolved path, only when the file is missing;
        # loading is still attempted afterwards.
        if not os.path.exists(weights_file_path):
            print("The weights file path specified does not exists: %s" %
                  weights_file_path)
        cnn_clf.load_weights(weights_file_path)

    print('\ntest the classifier')
    test_loss, test_acc = cnn_clf.evaluate(x_test, y_test)

    print('\n#######################################')
    print('Test loss:', test_loss)
    print('Test accuracy:', test_acc)

    print('\ntest the classifier on gan mnist')
    test_loss, test_acc = cnn_clf.evaluate(x_gan_test[:100], y_gan_test[:100])

    print('\n#######################################')
    print('Test loss gan:', test_loss)
    print('Test accuracy gan:', test_acc)

    # Sanity check: show the first test image whose label is class 0.
    # NOTE(review): indexing y_test[:, class_idx] assumes one-hot labels —
    # confirm against get_real_mnist.
    class_idx = 0
    indices = np.where(y_test[:, class_idx] == 1.)[0]
    idx = indices[0]

    # figure.figsize only affects newly created figures, so open one here
    # instead of drawing onto whatever figure is still current.
    plt.rcParams['figure.figsize'] = (18, 6)
    plt.figure()
    plt.imshow(x_test[idx][..., 0])

    if plot:
        plot_mist(x_train, y_train, 9, save_file_path='plots/test.png')
def main(plot=False, train=False, epochs=5, attention=False):
    """Train or load the 4-layer convolutional classifier, then evaluate it.

    Loads and preprocesses (``conv=True``) the real MNIST data and the
    GAN-generated test set. When ``train`` is set, ``cnn_classifier()`` is
    fitted with a 10% validation split. Its weights are saved to
    ``weights/cnn_clf_4layer_<epochs>.h5`` and accuracy/loss curves are
    written to ``output/conv_4layer_model_*.png``. Otherwise the weights are
    loaded from that path. The model is evaluated on the first 1000 real
    and the first 1000 GAN test images.

    Args:
        plot: If True, call ``plot_mist`` on the training data and save the
            result to ``plots/test.png``.
        train: If True, train from scratch; otherwise load saved weights.
        epochs: Number of training epochs; also part of the weights filename.
        attention: If True, run ``attention_visualization`` on the test set.
    """
    import os

    # Get mnist train and test dataset
    (x_train, y_train), (x_test, y_test) = get_real_mnist()

    # Get gan test dataset
    (x_gan_test, y_gan_test) = get_gan_mnist()

    # Preprocess raw data
    print('preprocess raw data')
    x_train = preprocess_raw_mnist_data(x_train, conv=True)
    x_test = preprocess_raw_mnist_data(x_test, conv=True)
    x_gan_test = preprocess_raw_mnist_data(x_gan_test, conv=True)

    # Build classifier
    cnn_clf = cnn_classifier()

    if train:
        # Train classifier
        print('\ntrain the classifier')
        history = cnn_clf.fit(x_train,
                              y_train,
                              epochs=epochs,
                              validation_split=0.1)

        # Save weights
        cnn_clf.save_weights('weights/cnn_clf_4layer_%s.h5' % epochs)

        print(history.history.keys())
        # Older Keras versions log accuracy as 'acc', newer ones as 'accuracy'.
        acc_key = 'acc' if 'acc' in history.history else 'accuracy'

        # Each plot gets its own figure so the loss curves are not drawn on
        # the accuracy figure when plt.show() leaves it open.
        plt.figure()
        plt.plot(history.history[acc_key])
        plt.plot(history.history['val_' + acc_key])
        plt.title("model accuracy")
        plt.ylabel('accuracy')
        plt.xlabel('epoch')
        plt.legend(['train', 'val'], loc='upper left')
        plt.savefig("output/conv_4layer_model_accuracy.png")
        plt.show()

        # Plot the loss
        plt.figure()
        plt.plot(history.history['loss'])
        plt.plot(history.history['val_loss'])
        plt.title('model loss')
        plt.ylabel('loss')
        plt.xlabel('epoch')
        plt.legend(['train', 'val'], loc='upper left')
        plt.savefig("output/conv_4layer_model_loss.png")
        plt.show()

    else:
        # Load the model weights
        weights_file_path = os.path.abspath(
            os.path.join(os.curdir, 'weights/cnn_clf_4layer_%s.h5' % epochs))
        # Warn, naming the resolved path, only when the file is missing;
        # loading is still attempted afterwards.
        if not os.path.exists(weights_file_path):
            print("The weights file path specified does not exists: %s" %
                  weights_file_path)
        cnn_clf.load_weights(weights_file_path)

    print('\ntest the classifier')
    test_loss, test_acc = cnn_clf.evaluate(x_test[:1000], y_test[:1000])

    print('\n#######################################')
    print('Test loss:', test_loss)
    print('Test accuracy:', test_acc)

    print('\ntest the classifier on gan mnist')
    test_loss, test_acc = cnn_clf.evaluate(x_gan_test[:1000],
                                           y_gan_test[:1000])

    print('\n#######################################')
    print('Test loss gan:', test_loss)
    print('Test accuracy gan:', test_acc)

    if attention:
        attention_visualization(cnn_clf, x_test, y_test)

    if plot:
        plot_mist(x_train, y_train, 9, save_file_path='plots/test.png')
# Ejemplo n.º 4  (Spanish: "Example no. 4")
# 0
def main(plot=False, train=False, epochs=5, unlabelled_size=5400,
         label_size=5400):
    """Train or load the semi-supervised GAN discriminator, then evaluate it.

    The model returned by ``discriminator()`` has two targets:

    * a real/fake target: 1 for real MNIST images, 0 for GAN images;
    * a digit target, in which every GAN image gets the extra class 10
      ("fake") and real images keep their MNIST label.

    Training mixes ``unlabelled_size`` GAN images with ``label_size`` real
    images. The weights file name encodes ``epochs`` and the real-data ratio
    ``label_size / (label_size + unlabelled_size)``.

    Args:
        plot: If True, call ``plot_mist`` on the training data and save the
            result to ``plots/test.png``.
        train: If True, train from scratch; otherwise load saved weights.
        epochs: Number of training epochs.
        unlabelled_size: Number of GAN images used as unlabelled/fake data.
        label_size: Number of real, labelled MNIST images (default 5400, the
            value previously hard-coded).
    """
    import os

    # Get mnist train and test dataset
    (x_train, y_train), (x_test, y_test) = get_real_mnist()

    # Get gan dataset used as the unlabelled training pool
    (x_gan_train, y_gan_train) = get_gan_mnist()

    # Preprocess raw data
    print('preprocess raw data')
    x_train = preprocess_raw_mnist_data(x_train, conv=True)
    x_test = preprocess_raw_mnist_data(x_test, conv=True)
    x_gan_train = preprocess_raw_mnist_data(x_gan_train, conv=True)

    # Build modified discriminator
    discrim = discriminator()

    # Fraction of real labelled images in the combined training set; it is
    # also part of the weights filename.
    ratio = label_size / (label_size + unlabelled_size)
    weights_file_name = 'weights/semi_sup_clf_%s_r%s.h5' % (epochs, ratio)

    if train:
        # Train classifier
        print('\ntrain the classifier')

        # Combine synthetic (first) and real (second) images.
        x_comb_train = np.append(x_gan_train[:unlabelled_size],
                                 x_train[:label_size],
                                 axis=0)
        y_comb_train = [
            # Real/fake target: 0 for GAN images, 1 for real images.
            np.append(np.zeros(unlabelled_size), np.ones(label_size), axis=0),
            # Digit target: all unlabelled examples get label 10, the extra
            # class that represents a fake image.
            np.append(np.zeros(unlabelled_size) + 10,
                      y_train[:label_size],
                      axis=0)
        ]

        discrim.fit(x_comb_train, y_comb_train, epochs=epochs)

        # Save weights
        discrim.save_weights(weights_file_name)

    else:
        # Load the model weights
        weights_file_path = os.path.abspath(
            os.path.join(os.curdir, weights_file_name))
        # Warn, naming the resolved path, only when the file is missing;
        # loading is still attempted afterwards.
        if not os.path.exists(weights_file_path):
            print("The weights file path specified does not exists: %s" %
                  weights_file_path)
        discrim.load_weights(weights_file_path)

    print('\ntest the classifier')
    # Real test images all get real/fake target 1.
    # NOTE(review): the 5-way unpacking assumes the discriminator is compiled
    # with one accuracy metric per output — confirm in discriminator().
    (comb_test_loss, discr_test_loss, label_clf_test_loss, discr_test_acc,
     label_clf_test_acc) = discrim.evaluate(
         x_test[:1000], [np.ones(len(y_test[:1000])), y_test[:1000]])

    print('\n#######################################')
    print('Combined output layers Test loss:', comb_test_loss)
    print('Discriminator output layer Test loss:', discr_test_loss)
    print('Label classifier output layer Test loss:', label_clf_test_loss)
    print('Discriminator output layer Test accuracy:', discr_test_acc)
    print('Label classifier output layer Test accuracy:', label_clf_test_acc)

    if plot:
        plot_mist(x_train, y_train, 9, save_file_path='plots/test.png')
def main(plot=False,
         train=False,
         epochs=5,
         attention=False,
         conf_matrix=False):
    """Train or load the 5-layer convolutional classifier, then evaluate it.

    Loads and preprocesses (``conv=True``) the real MNIST data and the
    GAN-generated test set. When ``train`` is set, ``cnn_classifier()`` is
    fitted with a 10% validation split. Its weights are saved to
    ``weights/cnn_clf_5layer_<epochs>.h5`` and accuracy/loss curves are
    written to ``output/``. Otherwise the weights are loaded from that path.
    The model is evaluated on the first 1000 real and the first 1000 GAN
    test images.

    Args:
        plot: If True, call ``plot_mist`` on the training data and save the
            result to ``plots/test.png``.
        train: If True, train from scratch; otherwise load saved weights.
        epochs: Number of training epochs; also part of the output filenames.
        attention: If True, run ``attention_visualization`` on the real and
            synthetic test subsets.
        conf_matrix: If True, save raw and normalized confusion matrices for
            the real and synthetic test subsets.
    """
    import os

    # Get mnist train and test dataset
    (x_train, y_train), (x_test, y_test) = get_real_mnist()

    # Get gan test dataset
    (x_gan_test, y_gan_test) = get_gan_mnist()

    # Preprocess raw data
    print('preprocess raw data')
    x_train = preprocess_raw_mnist_data(x_train, conv=True)
    x_test = preprocess_raw_mnist_data(x_test, conv=True)
    x_gan_test = preprocess_raw_mnist_data(x_gan_test, conv=True)

    # Build classifier
    cnn_clf = cnn_classifier()

    if train:
        # Train classifier
        print('\ntrain the classifier')

        history = cnn_clf.fit(x_train,
                              y_train,
                              epochs=epochs,
                              validation_split=0.1)

        # Save weights
        cnn_clf.save_weights('weights/cnn_clf_5layer_%s.h5' % epochs)

        print(history.history.keys())
        # Older Keras versions log accuracy as 'acc', newer ones as 'accuracy'.
        acc_key = 'acc' if 'acc' in history.history else 'accuracy'

        # Each plot gets its own figure so the loss curves are not drawn on
        # the accuracy figure when plt.show() leaves it open.
        plt.figure()
        plt.plot(history.history[acc_key])
        plt.plot(history.history['val_' + acc_key])
        plt.title("model accuracy")
        plt.ylabel('accuracy')
        plt.xlabel('epoch')
        plt.legend(['train', 'val'], loc='upper left')
        plt.savefig("output/conv_5layer_epoch_" + str(epochs) +
                    "_model_accuracy.png")
        plt.show()

        # Plot the loss; the filename mirrors the accuracy plot's pattern.
        plt.figure()
        plt.plot(history.history['loss'])
        plt.plot(history.history['val_loss'])
        plt.title('model loss')
        plt.ylabel('loss')
        plt.xlabel('epoch')
        plt.legend(['train', 'val'], loc='upper left')
        plt.savefig("output/conv_5layer_epoch_" + str(epochs) +
                    "_model_loss.png")
        plt.show()

    else:
        # Load the model weights
        weights_file_path = os.path.abspath(
            os.path.join(os.curdir, 'weights/cnn_clf_5layer_%s.h5' % epochs))
        # Warn, naming the resolved path, only when the file is missing;
        # loading is still attempted afterwards.
        if not os.path.exists(weights_file_path):
            print("The weights file path specified does not exists: %s" %
                  weights_file_path)
        cnn_clf.load_weights(weights_file_path)

    print('\ntest the classifier')
    test_loss, test_acc = cnn_clf.evaluate(x_test[:1000], y_test[:1000])

    print('\n#######################################')
    print('Test loss:', test_loss)
    print('Test accuracy:', test_acc)

    print('\ntest the classifier on gan mnist')
    test_loss, test_acc = cnn_clf.evaluate(x_gan_test[:1000],
                                           y_gan_test[:1000])

    print('\n#######################################')
    print('Test loss gan:', test_loss)
    print('Test accuracy gan:', test_acc)

    if attention:
        attention_visualization(cnn_clf, x_test[:1000], y_test[:1000], epochs,
                                "real")
        attention_visualization(cnn_clf, x_gan_test[:1000], y_gan_test[:1000],
                                epochs, "synthetic")

    if conf_matrix:
        class_names = list(range(10))

        # predict() returns per-class scores, so argmax gives the digit.
        # Passing labels= keeps each matrix 10x10 (matching class_names)
        # even if some digit never appears in the 1000-sample slice.
        # NOTE(review): assumes y_test / y_gan_test hold integer digit labels.
        y_pred = np.argmax(cnn_clf.predict(x_test[:1000]), axis=1)
        cm = confusion_matrix(y_test[:1000], y_pred, labels=class_names)

        y_gan_pred = np.argmax(cnn_clf.predict(x_gan_test[:1000]), axis=1)
        gan_cm = confusion_matrix(y_gan_test[:1000],
                                  y_gan_pred,
                                  labels=class_names)

        # Plot non-normalized confusion matrix
        plt.figure()
        plot_confusion_matrix(
            cm,
            classes=class_names,
            title='Confusion matrix real images, without normalization')
        plt.savefig("output/confusion_matrix_real_cnn_5layer_epoch_" +
                    str(epochs) + ".png")

        plt.figure()
        plot_confusion_matrix(
            gan_cm,
            classes=class_names,
            title='Confusion matrix synthetic images, without normalization')
        plt.savefig("output/confusion_matrix_synthetic_cnn_5layer_epoch_" +
                    str(epochs) + ".png")

        # Plot normalized confusion matrix
        plt.figure()
        plot_confusion_matrix(cm,
                              classes=class_names,
                              normalize=True,
                              title='Normalized confusion matrix real images')
        plt.savefig(
            "output/confusion_matrix_real_normalized_cnn_5layer_epoch_" +
            str(epochs) + ".png")

        plt.figure()
        plot_confusion_matrix(
            gan_cm,
            classes=class_names,
            normalize=True,
            title='Normalized confusion matrix synthetic images')
        plt.savefig(
            "output/confusion_matrix_synthetic_normalized_cnn_5layer_epoch_" +
            str(epochs) + ".png")

        plt.show()

    if plot:
        plot_mist(x_train, y_train, 9, save_file_path='plots/test.png')