Example #1
0
def test_xor(samples=10000, low=0.1, high=0.9, test_samples=1000):
    """Train a 2-2-1 ``_nn_like`` network on XOR data and print its error.

    Inputs are pairs drawn at random from ``{low, high}``.  The target is
    ``high`` when the two inputs differ and ``low`` when they are equal.
    The mean squared error on the first ``test_samples`` rows is printed
    before and after one pass of training with learning rate 0.1.  The
    network's output for the four canonical input patterns is printed last.

    NOTE(review): the "test" rows are a prefix of the training rows, so
    the reported error is training error, not held-out error.  This
    function asserts nothing; it only prints diagnostics.

    :param samples: number of random training pairs to generate.
    :param low: value standing in for logical 0.
    :param high: value standing in for logical 1.
    :param test_samples: number of leading rows used for error reporting;
        the error sum is divided by this value.
    """
    def mean_sq_error(inputs, targets):
        # Squared error of the first output unit, averaged over test_samples.
        squared = ((target[0]-_nn_like.forward_deterministic(sample)[0])**2
                   for sample, target in zip(inputs, targets))
        return sum(squared) / test_samples

    _nn_like.nn_like([2, 2, 1])
    _nn_like.print_connections()

    # Build the dataset: one row per sample, two input columns.
    flat = np.random.choice([low, high], size=samples*2)
    X_all = flat.reshape((samples, 2))
    y_all = np.array([[low] if first == second else [high]
                      for first, second in X_all])
    print(X_all[:20], y_all[:20])

    X_test = X_all[:test_samples]
    y_test = y_all[:test_samples]

    # Baseline error before any training.
    print(mean_sq_error(X_test, y_test))

    # A single online pass over every sample.
    for sample, target in zip(X_all, y_all):
        output = _nn_like.forward_deterministic(sample)
        _nn_like.backprop_deterministic(output, target, 0.1)

    # A forward pass on (low, low) runs immediately before the dumps below;
    # presumably so print_states shows that input's activations — confirm.
    _nn_like.forward_deterministic([low, low])
    _nn_like.print_connections()
    _nn_like.print_states()

    # Error after training.
    print(mean_sq_error(X_test, y_test))

    # Network output for each of the four XOR input patterns.
    truth_table = (
        ('00', [low, low]),
        ('01', [low, high]),
        ('10', [high, low]),
        ('11', [high, high]),
    )
    for label, pattern in truth_table:
        print(label, _nn_like.forward_deterministic(pattern))
Example #2
0
def test_backprop_121():
    """One backprop step on a 1-2-1 network should land near the target.

    The network is built with ``bias(0)`` and ``fixed_weights(1.2, 1.0)``.
    The exact meaning of those two setup calls lives in ``_nn_like`` and is
    not visible here.  After one backprop step at learning rate 1.0 toward
    target 0.6, the output for input ``[1]`` must lie strictly between 0.55
    and 0.65.
    """
    print('\n=== test_backprop_121 ===')
    _nn_like.bias(0)
    _nn_like.nn_like([1, 2, 1])
    _nn_like.fixed_weights(1.2, 1.0)

    x_value = 1
    target = 0.6

    # Fresh input lists are passed on every call, as in the original.
    before = _nn_like.forward_deterministic([x_value])[0]
    _nn_like.backprop_deterministic([before], [target], 1.0)
    after = _nn_like.forward_deterministic([x_value])[0]

    print(before, after)
    assert after > 0.55 and after < 0.65
Example #3
0
def test_backprop_111_neg_input():
    """One backprop step with a negative input should land near a negative target.

    A 1-1-1 network is built with ``bias(0)`` and ``fixed_weights(1.2, 1.0)``.
    The exact meaning of those setup calls lives in ``_nn_like`` and is not
    visible here.  After one backprop step at learning rate 1.0 toward target
    -0.6, the output for input ``[-2]`` must lie within 0.05 of -0.6.
    """
    print('\n=== test_backprop_111_neg_input ===')
    _nn_like.bias(0)
    _nn_like.nn_like([1, 1, 1])
    _nn_like.fixed_weights(1.2, 1.0)
    o_before = _nn_like.forward_deterministic([-2])[0]
    _nn_like.backprop_deterministic([o_before], [-0.6], 1.0)
    o_after = _nn_like.forward_deterministic([-2])[0]
    print(o_before, o_after)
    # The band is centred on the target -0.6.  The previous upper bound of
    # +0.55 (missing minus sign) made the band 1.2 wide.  That accepted
    # positive outputs, so a step in the wrong direction went undetected.
    assert -0.65 < o_after < -0.55