コード例 #1
0
ファイル: test_mlp.py プロジェクト: ddofer/breze
def test_awn_fit():
    X = np.random.standard_normal((10, 2))
    Z = np.random.standard_normal((10, 1))
    loss = lambda target, prediction: squared(target, prediction[:, :target.shape[1]])
    mlp = AwnNetwork(
        2, [10], 1, ['rectifier'], 'identity', loss, max_iter=10)
    mlp.fit(X, Z)
コード例 #2
0
ファイル: test_mlp.py プロジェクト: ddofer/breze
def test_awn_iter_fit():
    X = np.random.standard_normal((10, 2))
    Z = np.random.standard_normal((10, 1))
    loss = lambda target, prediction: squared(target, prediction[:, :target.shape[1]])
    mlp = AwnNetwork(
        2, [10], 1, ['rectifier'], 'identity', loss, max_iter=10)
    for i, info in enumerate(mlp.iter_fit(X, Z)):
        if i >= 10:
            break