Example #1
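This example trains a gradient-enhanced neural network on the 1D function y = x*sin(x) using four training points and their analytic derivatives, checks the fit against 30 validation points, and plots the prediction against the true curve.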
    def test_genn(self):
        import numpy as np
        import matplotlib.pyplot as plt
        from smt.surrogate_models.genn import GENN, load_smt_data

        # Training data
        lower_bound = -np.pi
        upper_bound = np.pi
        number_of_training_points = 4
        xt = np.linspace(lower_bound, upper_bound, number_of_training_points)
        yt = xt * np.sin(xt)
        dyt_dxt = np.sin(xt) + xt * np.cos(xt)

        # Validation data
        number_of_validation_points = 30
        xv = np.linspace(lower_bound, upper_bound, number_of_validation_points)
        yv = xv * np.sin(xv)
        dyv_dxv = np.sin(xv) + xv * np.cos(xv)

        # Truth model
        x = np.arange(lower_bound, upper_bound, 0.01)
        y = x * np.sin(x)

        # GENN
        genn = GENN()
        genn.options["alpha"] = 0.1  # learning rate that controls optimizer step size
        genn.options["beta1"] = 0.9  # tuning parameter to control ADAM optimization
        genn.options["beta2"] = 0.99  # tuning parameter to control ADAM optimization
        genn.options["lambd"] = 0.1  # lambd = 0. = no regularization, lambd > 0 = regularization
        genn.options["gamma"] = 1.0  # gamma = 0. = no grad-enhancement, gamma > 0 = grad-enhancement
        genn.options["deep"] = 2  # number of hidden layers
        genn.options["wide"] = 6  # number of nodes per hidden layer
        genn.options["mini_batch_size"] = 64  # used to divide data into training batches (use for large data sets)
        genn.options["num_epochs"] = 20  # number of passes through the data
        genn.options["num_iterations"] = 100  # number of optimizer iterations per mini-batch
        genn.options["is_print"] = True  # print output (or not)
        load_smt_data(genn, xt, yt, dyt_dxt)  # convenience function to read in data that is in SMT format
        genn.train()  # API function to train the model
        genn.plot_training_history()  # non-API function to plot training history (to check convergence)
        genn.goodness_of_fit(xv, yv, dyv_dxv)  # non-API function to check accuracy of the regression
        y_pred = genn.predict_values(x)  # API function to predict values at new (unseen) points
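        # Minimal sketch (assumed standard SMT surrogate API; not part of the original
        # example): a trained GENN can also return derivatives of the prediction with
        # respect to input 0 via predict_derivatives.
        dy_pred_dx = genn.predict_derivatives(x.reshape(-1, 1), 0)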

        # Plot
        fig, ax = plt.subplots()
        ax.plot(x, y_pred)
        ax.plot(x, y, "k--")
        ax.plot(xv, yv, "ro")
        ax.plot(xt, yt, "k+", mew=3, ms=10)
        ax.set(xlabel="x", ylabel="y", title="GENN")
        ax.legend(["Predicted", "True", "Test", "Train"])
        plt.show()
Example #2
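This example wraps the same 1D demo in a standalone function, run_demo_1D, whose is_gradient_enhancement flag toggles the gamma option so the fits with and without gradient enhancement can be compared.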
import numpy as np
import matplotlib.pyplot as plt

from smt.surrogate_models.genn import GENN, load_smt_data

SEED = 11  # reproducibility seed (assumed value; defined at module level in the original source)


def run_demo_1D(is_gradient_enhancement=True):  # pragma: no cover
    """Test and demonstrate GENN using a 1D example"""

    # Test function
    f = lambda x: x * np.sin(x)
    df_dx = lambda x: np.sin(x) + x * np.cos(x)

    # Domain
    lb = -np.pi
    ub = np.pi

    # Training data
    m = 4
    xt = np.linspace(lb, ub, m)
    yt = f(xt)
    dyt_dxt = df_dx(xt)

    # Validation data
    xv = lb + np.random.rand(30, 1) * (ub - lb)
    yv = f(xv)
    dyv_dxv = df_dx(xv)

    # Initialize GENN object
    genn = GENN()
    genn.options["alpha"] = 0.05
    genn.options["beta1"] = 0.9
    genn.options["beta2"] = 0.99
    genn.options["lambd"] = 0.05
    genn.options["gamma"] = int(is_gradient_enhancement)
    genn.options["deep"] = 2
    genn.options["wide"] = 6
    genn.options["mini_batch_size"] = 64
    genn.options["num_epochs"] = 25
    genn.options["num_iterations"] = 100
    genn.options["seed"] = SEED
    genn.options["is_print"] = True

    # Load data
    load_smt_data(genn, xt, yt, dyt_dxt)

    # Train
    genn.train()
    genn.plot_training_history()
    genn.goodness_of_fit(xv, yv, dyv_dxv)

    # Plot comparison
    if genn.options["gamma"] == 1.0:
        title = 'with gradient enhancement'
    else:
        title = 'without gradient enhancement'
    x = np.arange(lb, ub, 0.01)
    y = f(x)
    y_pred = genn.predict_values(x)
    fig, ax = plt.subplots()
    ax.plot(x, y_pred)
    ax.plot(x, y, 'k--')
    ax.plot(xv, yv, 'ro')
    ax.plot(xt, yt, 'k+', mew=3, ms=10)
    ax.set(xlabel='x', ylabel='y', title=title)
    ax.legend(['Predicted', 'True', 'Test', 'Train'])
    plt.show()
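A usage sketch (these calls are not part of the original module): running the demo once with and once without gradient enhancement makes the comparison in the plot titles concrete.

run_demo_1D(is_gradient_enhancement=True)   # gamma = 1: gradient-enhanced fit
run_demo_1D(is_gradient_enhancement=False)  # gamma = 0: plain neural network fit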