Example #1
# Perform gradient descent
for i in range(N_iterations):

    # Run the features through the neural net (to compute a and z)
    y_pred = nn.feed_forward(X)

    # Compute the gradient
    grad_w = nn._gradient_fun(X, T)

    # Update the neural network weight matrices
    for w, gw in zip(nn.weights, grad_w):
        w -= eta * gw

    # Print some statistics every thousandth iteration
    if i % 1000 == 0:
        misclassified = sum(np.argmax(y_pred, axis=1) != y.ravel())
        print(
            "Iteration: {0}, Objective Function Value: {1:.3f}, Misclassified: {2}"
            .format(i, nn._J_fun(X, T), misclassified))

# Predict on the test data and classify
y_pred = np.argmax(nn.feed_forward(X_test), axis=1)
print("Test data accuracy: {0:.3f}".format(1 - sum(y_pred != y_test.ravel()) /
                                           float(len(y_test))))
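
A small extension of the accuracy check above: a per-class confusion matrix built with plain numpy. This is not part of the original example; n_classes and the conf array are introduced here for illustration, and integer class labels are assumed.

n_classes = len(np.unique(y_test))
conf = np.zeros((n_classes, n_classes), dtype=int)
for true, pred in zip(y_test.ravel(), y_pred):
    conf[int(true), int(pred)] += 1  # rows: true class, columns: predicted class
print(conf)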

# Plot a matrix of scatter plots of feature pairs, marking misclassified points
fig, axs = plt.subplots(nrows=4, ncols=4)
misfit = y_test.ravel() != y_pred
for i in range(4):
    for j in range(4):
        if i > j:
            axs[i, j].scatter(X_test[:, i], X_test[:, j], c=y_pred)
            axs[i, j].plot(X_test[misfit, i],
                           X_test[misfit, j],
                           'ro')
plt.show()
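
Both examples call nn.feed_forward, nn._gradient_fun, nn._J_fun, and iterate over nn.weights, but the network class itself is not shown. The following is a minimal sketch of a class exposing that interface; the architecture (one hidden layer, sigmoid activations, squared-error objective) is an assumption for illustration, not taken from the original source.

import numpy as np

class TinyNN:
    # Assumed architecture: one hidden layer, sigmoid activations,
    # squared-error objective. Only the interface names come from the examples.
    def __init__(self, n_in, n_hidden, n_out, seed=0):
        rng = np.random.default_rng(seed)
        # One weight matrix per layer, stored in a list so the update loop
        # can do `for w, gw in zip(nn.weights, grad_w): w -= eta * gw`.
        self.weights = [rng.normal(0, 0.1, (n_in, n_hidden)),
                        rng.normal(0, 0.1, (n_hidden, n_out))]

    @staticmethod
    def _sigmoid(z):
        return 1.0 / (1.0 + np.exp(-z))

    def feed_forward(self, X):
        # Compute and cache the pre-activations (z) and activations (a).
        self.z1 = X @ self.weights[0]
        self.a1 = self._sigmoid(self.z1)
        self.z2 = self.a1 @ self.weights[1]
        self.a2 = self._sigmoid(self.z2)
        return self.a2

    def _J_fun(self, X, T):
        # Squared-error objective between predictions and targets T.
        return 0.5 * np.sum((self.feed_forward(X) - T) ** 2)

    def _gradient_fun(self, X, T):
        # Backpropagation for the squared-error objective; returns one
        # gradient array per weight matrix, matching self.weights.
        y = self.feed_forward(X)
        delta2 = (y - T) * y * (1 - y)                      # output-layer error
        delta1 = (delta2 @ self.weights[1].T) * self.a1 * (1 - self.a1)
        return [X.T @ delta1, self.a1.T @ delta2]
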
Example #2
# Perform gradient descent
for i in range(N_iterations):

    # For stochastic gradient descent, take a random mini-batch of X and y
    batch = np.random.randint(0, m, size=batch_size)
    X_batch, y_batch = X[batch], y[batch]

    # Run the full feature set through the neural net (to compute a and z)
    y_pred = nn.feed_forward(X)

    # Compute the gradient on the mini-batch only
    grad_w = nn._gradient_fun(X_batch, y_batch)

    # Update the neural network weight matrices
    for w, gw in zip(nn.weights, grad_w):
        w -= eta * gw

    # Print some statistics every ten-thousandth iteration
    if i % 10000 == 0:
        print('Iteration: {0}, Objective Function Value: {1:.3f}'
              .format(i, nn._J_fun(X, y)))

print('Total relative error: {}'.format(
    np.abs(np.sum(y_pred - y) / np.sum(y))))
plt.plot(X, y, '*', label='target')
plt.plot(X, y_pred, '*', label='prediction')
plt.legend()
plt.show()
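
Note that np.random.randint draws indices with replacement, so within one pass some samples can be visited twice and others never. A common alternative, sketched below, shuffles the index set once per epoch and walks through it in contiguous mini-batches; sgd_epoch is a hypothetical helper introduced here, and it assumes the same nn, X, y, eta, and batch_size as the example above.

import numpy as np

def sgd_epoch(nn, X, y, eta, batch_size):
    m = X.shape[0]
    order = np.random.permutation(m)           # shuffle indices, no replacement
    for start in range(0, m, batch_size):
        batch = order[start:start + batch_size]
        grad_w = nn._gradient_fun(X[batch], y[batch])
        for w, gw in zip(nn.weights, grad_w):
            w -= eta * gw                      # in-place gradient-descent step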


# ========================= EOF ================================================================