예제 #1
0
def plot_RBF_grid_search(data,
                         weights=(0.6, 0.8, 1, 1.2),
                         n_nodes=(10, 15, 20, 25),
                         learning_mode=LearningMode.BATCH,
                         centers_sampling=CentersSampling.LINEAR,
                         delta_max_iters=20,
                         delta_lr=0.001):
    """Grid-search RBF network size and width and scatter-plot the errors.

    For every pair ``(n, w)`` in ``n_nodes x weights`` an ``RBF`` network is
    built with ``n_nodes=n``, ``n_inter=n`` and ``sigma=w``, trained, and the
    error returned by the learning routine is recorded. The results are
    shown as a scatter plot (units on x, width on y, colour = error).

    Args:
        data: Object exposing ``x``, ``y``, ``x_test`` and ``y_test``.
        weights: Widths to try; each is passed to ``RBF`` as ``sigma``.
        n_nodes: Numbers of RBF units to try (also passed as ``n_inter``).
        learning_mode: ``LearningMode.BATCH`` trains with
            ``batch_learning``; any other value trains with
            ``delta_learning``.
        centers_sampling: Centre-placement strategy forwarded to ``RBF``.
        delta_max_iters: ``max_iters`` for delta learning (default 20).
        delta_lr: Learning rate for delta learning (default 0.001).

    Returns:
        dict mapping ``(rbf_net.n_nodes, w)`` to the recorded error. If the
        grid is empty, the empty dict is returned and nothing is plotted.
    """
    results = {}

    for n in n_nodes:
        for w in weights:
            # NOTE(review): n_inter is tied to n here -- confirm this is intended.
            rbf_net = RBF(centers_sampling, n_nodes=n, n_inter=n, sigma=w)
            if learning_mode == LearningMode.BATCH:
                _, error = rbf_net.batch_learning(data.x, data.y,
                                                  data.x_test, data.y_test)
            else:
                _, error = rbf_net.delta_learning(data.x,
                                                  data.y,
                                                  data.x_test,
                                                  data.y_test,
                                                  max_iters=delta_max_iters,
                                                  lr=delta_lr)
            # Keyed on rbf_net.n_nodes rather than n; presumably RBF may
            # adjust the requested count -- confirm against RBF.
            results[(rbf_net.n_nodes, w)] = error

    if not results:
        # An empty key array is 1-D, so keys[:, 0] below would raise.
        return results

    keys = np.array(list(results.keys()))
    plt.scatter(keys[:, 0],
                keys[:, 1],
                c=list(results.values()),
                cmap='tab20b',
                s=200)
    plt.xlabel('units')
    plt.ylabel('width')
    plt.title('Absolute residual error for different RBF configurations')
    plt.colorbar()
    plt.show()
    return results
예제 #2
0
def experiment(data,
               learning_mode,
               centers_sampling,
               n_nodes=None,
               error=None,
               n=20,
               n_iter=3,
               weight=1.0,
               drop=2**9 - 1,
               sigma=1.0,
               neigh=1,
               max_iter=20,
               lr=0.1):
    """Build and train one RBF network and return its predictions and error.

    Args:
        data: Object exposing ``x``, ``y``, ``x_test`` and ``y_test``.
        learning_mode: ``LearningMode.BATCH`` uses ``batch_learning``,
            ``LearningMode.DELTA`` uses ``delta_learning``; any other value
            uses ``hybrid_learning``.
        centers_sampling: Centre-placement strategy forwarded to ``RBF``.
        n_nodes: Optional list. When both ``n_nodes`` and ``error`` are
            given, ``rbf_net.n_nodes`` is appended to it.
        error: Optional list. When both are given, the error is appended.
        n: Requested number of RBF units (``RBF(n_nodes=...)``).
        n_iter: Forwarded to ``RBF`` as ``n_inter``.
        weight, drop, sigma: Forwarded unchanged to ``RBF``.
        neigh: Forwarded to ``hybrid_learning`` only.
        max_iter: ``max_iters`` for delta / hybrid learning.
        lr: Learning rate for delta / hybrid learning.

    Returns:
        Tuple ``(y_hat, err, rbf_net)`` as returned by the chosen learning
        routine, plus the trained network.
    """
    rbf_net = RBF(centers_sampling,
                  n_nodes=n,
                  n_inter=n_iter,
                  weight=weight,
                  drop=drop,
                  x=data.x,
                  sigma=sigma)

    if learning_mode == LearningMode.BATCH:
        y_hat, err = rbf_net.batch_learning(data.x, data.y, data.x_test,
                                            data.y_test)
    elif learning_mode == LearningMode.DELTA:
        y_hat, err = rbf_net.delta_learning(data.x,
                                            data.y,
                                            data.x_test,
                                            data.y_test,
                                            lr=lr,
                                            max_iters=max_iter)
    else:
        y_hat, err = rbf_net.hybrid_learning(data.x,
                                             data.y,
                                             data.x_test,
                                             data.y_test,
                                             lr=lr,
                                             neigh=neigh,
                                             max_iters=max_iter)

    # Identity check: `!= None` would broadcast (and then fail) if an
    # ndarray accumulator were passed.
    if n_nodes is not None and error is not None:
        n_nodes.append(rbf_net.n_nodes)
        error.append(err)

    return y_hat, err, rbf_net
예제 #3
0
def plot_estimate(data,
                  centers_sampling=CentersSampling.LINEAR,
                  learning_type='batch',
                  n_nodes=20,
                  delta_max_iters=100,
                  sigma=0.5,
                  delta_lr=0.1,
                  weight=1):
    """Train one RBF network and plot its test-set estimate against the target.

    ``learning_type == 'batch'`` trains with ``batch_learning``; any other
    value trains with ``delta_learning`` using ``delta_max_iters`` and
    ``delta_lr``. The plot shows the target curve, the estimate, and the
    RBF centres drawn in red on y = 0. The title includes the width only
    when ``sigma`` differs from 0.5.
    """
    net = RBF(centers_sampling, n_nodes=n_nodes, sigma=sigma, weight=weight)

    if learning_type != 'batch':
        estimate, err = net.delta_learning(data.x,
                                           data.y,
                                           data.x_test,
                                           data.y_test,
                                           max_iters=delta_max_iters,
                                           lr=delta_lr)
    else:
        estimate, err = net.batch_learning(data.x, data.y, data.x_test,
                                           data.y_test)

    centre_positions = net.centers
    units = net.n_nodes

    plt.plot(data.x_test, data.y_test, label="Target")
    plt.plot(data.x_test, estimate, label="Estimate")
    plt.scatter(centre_positions, [0] * units, c="r", label="RBF Centers")
    plt.xlabel("x")
    plt.ylabel("y")

    if sigma == 0.5:
        title = f'{learning_type} learning, {units} RBF units, error= {round(err,5)}'
    else:
        title = f'{learning_type}, {units} units, {sigma} width, error= {round(err,5)}'
    plt.title(title)

    plt.legend()
    plt.grid(True)
    plt.show()
예제 #4
0
def error_estimate_batch(data, n_nodes=20, sigma=0.5):
    """Train an RBF network (linear centre sampling) by batch learning and print its error.

    Args:
        data: Object exposing ``x``, ``y``, ``x_test`` and ``y_test``.
        n_nodes: Number of RBF units.
        sigma: RBF width passed to ``RBF``.
    """
    # Forward the caller's sigma; it was previously ignored in favour of a
    # literal 0.5 (the default, so default calls behave the same).
    rbf_net = RBF(CentersSampling.LINEAR, n_nodes=n_nodes, sigma=sigma)
    _, error = rbf_net.batch_learning(data.x, data.y, data.x_test,
                                      data.y_test)
    print(f'Error for batch learning: {error}')