def test_awn_fit():
    """Smoke-test that ``AwnNetwork.fit`` runs on a tiny random regression set.

    Draws 10 samples with 2 input columns and 1 target column, then fits for
    at most ``max_iter=10`` iterations. Nothing is asserted; the test only
    checks that fitting completes without raising.
    """
    inputs = np.random.standard_normal((10, 2))
    targets = np.random.standard_normal((10, 1))

    def loss(target, prediction):
        # Only the leading ``target.shape[1]`` prediction columns are scored;
        # presumably the network emits extra columns beyond the target width
        # (TODO confirm against AwnNetwork's output layout).
        return squared(target, prediction[:, :target.shape[1]])

    network = AwnNetwork(2, [10], 1, ['rectifier'], 'identity', loss,
                         max_iter=10)
    network.fit(inputs, targets)
def test_awn_iter_fit():
    """Smoke-test that ``AwnNetwork.iter_fit`` can be stepped a few times.

    Uses the same tiny random data set (10 samples, 2 inputs, 1 target) and
    consumes at most 11 items from the ``iter_fit`` generator. Nothing is
    asserted about the yielded values; the test only checks that iterating
    does not raise.
    """
    inputs = np.random.standard_normal((10, 2))
    targets = np.random.standard_normal((10, 1))

    def loss(target, prediction):
        # Score only the leading ``target.shape[1]`` columns of the prediction.
        return squared(target, prediction[:, :target.shape[1]])

    network = AwnNetwork(2, [10], 1, ['rectifier'], 'identity', loss,
                         max_iter=10)
    # zip() stops as soon as range(11) is exhausted, so at most 11 items
    # (steps 0..10) are pulled from iter_fit. It never advances the
    # generator a twelfth time.
    for _step, _info in zip(range(11), network.iter_fit(inputs, targets)):
        pass