def plot_RBF_grid_search(data, weights=(0.6, 0.8, 1, 1.2), n_nodes=(10, 15, 20, 25),
                         learning_mode=LearningMode.BATCH,
                         centers_sampling=CentersSampling.LINEAR):
    """Grid-search RBF unit count and kernel width, then scatter-plot the errors.

    For every ``n`` in *n_nodes* and every ``w`` in *weights*, an ``RBF`` network
    is built with ``n_nodes=n``, ``n_inter=n`` and ``sigma=w``. It is trained with
    batch learning when *learning_mode* is ``LearningMode.BATCH``. Otherwise it
    uses delta learning with ``max_iters=20`` and ``lr=0.001``.

    Args:
        data: object exposing ``x``, ``y``, ``x_test`` and ``y_test``.
        weights: iterable of kernel widths, passed to ``RBF`` as ``sigma``.
        n_nodes: iterable of requested unit counts.
        learning_mode: ``LearningMode`` selecting the training routine.
        centers_sampling: ``CentersSampling`` strategy forwarded to ``RBF``.

    Returns:
        dict mapping ``(rbf_net.n_nodes, width)`` to the error returned by the
        training routine. The key uses the network's own ``n_nodes`` attribute.
        NOTE(review): this presumably may differ from the requested ``n``; confirm.
        If the grid is empty, an empty dict is returned and nothing is plotted.
    """
    # Defaults are tuples rather than lists: they are only iterated, and
    # immutable defaults cannot leak state between calls.
    results = {}
    for n in n_nodes:
        for w in weights:
            rbf_net = RBF(centers_sampling, n_nodes=n, n_inter=n, sigma=w)
            if learning_mode == LearningMode.BATCH:
                _, error = rbf_net.batch_learning(data.x, data.y, data.x_test, data.y_test)
            else:
                _, error = rbf_net.delta_learning(data.x, data.y, data.x_test, data.y_test,
                                                  max_iters=20, lr=0.001)
            results[(rbf_net.n_nodes, w)] = error

    # With no results, np.array([]) is 1-D and keys[:, 0] would raise IndexError.
    if not results:
        return results

    keys = np.array(list(results.keys()))
    plt.scatter(keys[:, 0], keys[:, 1], c=list(results.values()), cmap='tab20b', s=200)
    plt.xlabel('units')
    plt.ylabel('width')
    plt.title('Absolute residual error for different RBF configurations')
    plt.colorbar()
    plt.show()
    return results
def experiment(data, learning_mode, centers_sampling, n_nodes=None, error=None, n=20,
               n_iter=3, weight=1.0, drop=2**9 - 1, sigma=1.0, neigh=1, max_iter=20,
               lr=0.1):
    """Train a single RBF network and optionally record its size and error.

    Args:
        data: object exposing ``x``, ``y``, ``x_test`` and ``y_test``.
        learning_mode: ``LearningMode.BATCH`` selects batch learning, and
            ``LearningMode.DELTA`` selects delta learning. Any other value selects
            hybrid learning.
        centers_sampling: strategy forwarded to ``RBF``.
        n_nodes: optional list. When both it and *error* are given, the trained
            network's ``n_nodes`` is appended to it.
        error: optional list. When both it and *n_nodes* are given, the error is
            appended to it.
        n: forwarded to ``RBF`` as ``n_nodes``.
        n_iter: forwarded to ``RBF`` as ``n_inter``.
        weight, drop, sigma: forwarded to ``RBF`` unchanged.
        neigh: forwarded to ``hybrid_learning`` only.
        max_iter: forwarded as ``max_iters`` to delta and hybrid learning.
        lr: learning rate for delta and hybrid learning.

    Returns:
        Tuple ``(y_hat, err, rbf_net)`` containing the training routine's
        prediction, its error, and the trained network.
    """
    rbf_net = RBF(centers_sampling, n_nodes=n, n_inter=n_iter, weight=weight,
                  drop=drop, x=data.x, sigma=sigma)
    if learning_mode == LearningMode.BATCH:
        y_hat, err = rbf_net.batch_learning(data.x, data.y, data.x_test, data.y_test)
    elif learning_mode == LearningMode.DELTA:
        y_hat, err = rbf_net.delta_learning(data.x, data.y, data.x_test, data.y_test,
                                            lr=lr, max_iters=max_iter)
    else:
        # Any mode other than BATCH or DELTA falls through to hybrid learning.
        y_hat, err = rbf_net.hybrid_learning(data.x, data.y, data.x_test, data.y_test,
                                             lr=lr, neigh=neigh, max_iters=max_iter)
    # Test identity against None (PEP 8). ``!=`` would dispatch to the
    # argument's own __ne__, which need not return a plain bool.
    if n_nodes is not None and error is not None:
        n_nodes.append(rbf_net.n_nodes)
        error.append(err)
    return y_hat, err, rbf_net
def plot_estimate(data, centers_sampling=CentersSampling.LINEAR, learning_type='batch',
                  n_nodes=20, delta_max_iters=100, sigma=0.5, delta_lr=0.1, weight=1):
    """Fit an RBF network and plot its test-set estimate alongside the target.

    The network is trained with batch learning when *learning_type* is
    ``'batch'``. Any other value selects delta learning, which uses
    *delta_max_iters* and *delta_lr*. The plot shows:

    - the target curve;
    - the estimate;
    - the RBF centers, drawn as red points on y = 0.

    The title includes the width only when *sigma* differs from 0.5.

    Args:
        data: object exposing ``x``, ``y``, ``x_test`` and ``y_test``.
        centers_sampling: strategy forwarded to ``RBF``.
        learning_type: ``'batch'`` for batch learning, anything else for delta.
        n_nodes: requested number of RBF units.
        delta_max_iters: iteration cap for delta learning.
        sigma: kernel width forwarded to ``RBF``.
        delta_lr: learning rate for delta learning.
        weight: forwarded to ``RBF``.
    """
    net = RBF(centers_sampling, n_nodes=n_nodes, sigma=sigma, weight=weight)
    train_args = (data.x, data.y, data.x_test, data.y_test)
    if learning_type != 'batch':
        y_hat, error = net.delta_learning(*train_args, max_iters=delta_max_iters,
                                          lr=delta_lr)
    else:
        y_hat, error = net.batch_learning(*train_args)

    # Use the network's own unit count for the marker row and the title.
    centers, units = net.centers, net.n_nodes

    plt.plot(data.x_test, data.y_test, label="Target")
    plt.plot(data.x_test, y_hat, label="Estimate")
    plt.scatter(centers, [0] * units, c="r", label="RBF Centers")
    plt.xlabel("x")
    plt.ylabel("y")

    rounded = round(error, 5)
    if sigma == 0.5:
        title = f'{learning_type} learning, {units} RBF units, error= {rounded}'
    else:
        title = f'{learning_type}, {units} units, {sigma} width, error= {rounded}'
    plt.title(title)

    plt.legend()
    plt.grid(True)
    plt.show()
def error_estimate_batch(data, n_nodes=20, sigma=0.5):
    """Train a linearly-sampled RBF network with batch learning and print its error.

    Args:
        data: object exposing ``x``, ``y``, ``x_test`` and ``y_test``.
        n_nodes: requested number of RBF units.
        sigma: kernel width forwarded to ``RBF``.
    """
    # Forward the caller's sigma. It was previously hard-coded to 0.5,
    # which silently ignored the argument.
    rbf_net = RBF(CentersSampling.LINEAR, n_nodes=n_nodes, sigma=sigma)
    _, error = rbf_net.batch_learning(data.x, data.y, data.x_test, data.y_test)
    print(f'Error for batch learning: {error}')