def test_nnet():
    """Check that several ``aet_nnet`` ops agree between JAX and Python.

    A symbolic vector ``x`` is given the test value ``[1.0, 2.0]`` cast to
    ``config.floatX``.  Each op is applied to ``x`` and wrapped in its own
    single-output ``FunctionGraph``.  ``compare_jax_and_py`` then runs that
    graph on the inputs' test values.
    """
    x = vector("x")
    x.tag.test_value = np.r_[1.0, 2.0].astype(config.floatX)

    def _check(output):
        # One graph per op, so a failure points at a single operation.
        graph = FunctionGraph([x], [output])
        test_inputs = [get_test_value(inp) for inp in graph.inputs]
        compare_jax_and_py(graph, test_inputs)

    # Ops are built lazily, one at a time, matching the original call order.
    _check(aet_nnet.sigmoid(x))
    _check(aet_nnet.ultra_fast_sigmoid(x))
    _check(aet_nnet.softplus(x))
    _check(aet_nnet.softmax(x))
    _check(aet_nnet.logsoftmax(x))
def test_logsoftmax(axis):
    """Check that ``aet_nnet.logsoftmax`` along ``axis`` agrees between JAX and Python.

    The input is a symbolic matrix whose test value is ``arange(6)`` reshaped
    to ``(2, 3)`` with dtype ``config.floatX``.

    NOTE(review): ``axis`` is a test parameter. It presumably comes from a
    ``pytest.mark.parametrize`` decorator that is not visible here; confirm
    the decorator exists, otherwise pytest will look for a fixture named
    ``axis``.
    """
    mat = matrix("x")
    mat.tag.test_value = np.arange(6, dtype=config.floatX).reshape(2, 3)

    graph = FunctionGraph([mat], [aet_nnet.logsoftmax(mat, axis=axis)])
    test_inputs = [get_test_value(inp) for inp in graph.inputs]
    compare_jax_and_py(graph, test_inputs)