# Code example #1
# 0
def test_train_and_error_calculation():
    """Train a 6-input, 2-node Perceptron one epoch at a time and check its error.

    The percent error on the training set is recorded before each of 20
    single-epoch training passes and once more after the final pass. The
    resulting 21 values are compared against a fixed expected trajectory.
    """
    input_dimensions = 6
    number_of_nodes = 2
    model = Perceptron(input_dimensions=input_dimensions,
                       number_of_nodes=number_of_nodes)
    # Fixed training data of shape (input_dimensions, 5). Y_train has one row
    # per node and the same 5 columns, so columns appear to be samples.
    X_train = np.array(
        [[0.60302512, -0.94856749, -0.88904878, -0.02272178, 0.90226341],
         [0.23682103, -0.29573089, -0.2328635, 0.85537468, 0.19151025],
         [-0.35294026, -0.99263065, 0.41645806, -0.22561292, 2.72811021],
         [0.06180665, 1.0643302, 0.49739215, -1.81960612, 0.50104263],
         [-0.4875518, -0.98996947, 2.38729703, 0.95753127, -0.20929545],
         [-1.20261055, -1.84727613, -1.13450254, -1.57499346, 0.4382195]])
    Y_train = np.array([[0, 1, 1, 1, 0], [0, 0, 1, 0, 1]])
    # NOTE(review): assumes initialize_weights(seed=2) reseeds the RNG, so the
    # expected trajectory does not depend on prior global RNG state — confirm.
    model.initialize_weights(seed=2)
    error = []
    for _ in range(20):
        # Record the error *before* each epoch so the first entry is the
        # untrained baseline.
        error.append(model.calculate_percent_error(X_train, Y_train))
        model.train(X_train, Y_train, num_epochs=1, alpha=0.025)
    # Final measurement after the last training epoch (21 values in total).
    error.append(model.calculate_percent_error(X_train, Y_train))
    np.testing.assert_allclose(error, [
        80.0, 80.0, 80.0, 60.0, 60.0, 60.0, 60.0, 60.0, 60.0, 60.0, 40.0, 20.0,
        20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0, 0.0, 0.0
    ],
                               rtol=1e-3,
                               atol=1e-3)
# Code example #2
# 0
def test_error_calculation():
    """Train a zero-initialized 2-input, 2-class Perceptron and check its error.

    The model is trained for 20 single-epoch passes with a small learning
    rate. After each pass the training-set error is recorded, and the 20
    values are compared against a fixed expected trajectory.

    With 4 training samples, the expected values are multiples of 0.25. This
    suggests this Perceptron's ``calculate_percent_error`` returns a fraction
    rather than a percentage (TODO confirm against the implementation).
    """
    input_dimensions = 2
    number_of_classes = 2
    model = Perceptron(input_dimensions=input_dimensions,
                       number_of_classes=number_of_classes,
                       seed=1)
    X_train = np.array([[-1.43815556, 0.10089809, -1.25432937, 1.48410426],
                        [-1.81784194, 0.42935033, -1.2806198, 0.06527391]])
    Y_train = np.array([[1, 0, 0, 1], [0, 1, 1, 0]])
    # Start from all-zero weights so the trajectory does not depend on the
    # random initialization.
    model.initialize_all_weights_to_zeros()
    error = []
    for _ in range(20):
        model.train(X_train, Y_train, num_epochs=1, alpha=0.0001)
        # Evaluate once per epoch. The original also printed a second,
        # redundant evaluation as debug output.
        error.append(model.calculate_percent_error(X_train, Y_train))
    np.testing.assert_allclose(error,
                               [0.25, 0.5, 0.5, 0.25, 0.25, 0.25, 0.5, 0.25, 0.25, 0.25, 0.25, 0.0, 0.0, 0.0, 0.0, 0.0,
                                0.0, 0.0, 0.0, 0.0], rtol=1e-3, atol=1e-3)