def test_static_logsoftmax2d():
    """Check the value and gradient of a weighted log-softmax sum on a 3x4 input.

    ``logsoftmax`` is applied to ``[[0..3], [4..7], [8..11]]`` with
    ``constant=False``, multiplied elementwise by a fixed weight matrix, and
    summed. Reference values were verified against Theano
    (``theano.tensor.softmax``).
    """
    weights = np.array(
        [
            [0.87566484, 0.53596079, 0.85693981, 0.09526036],
            [0.32024455, 0.81532148, 0.2480434, 0.85119342],
            [0.57943085, 0.33958252, 0.95864464, 0.22881712],
        ]
    )
    # Same float64 values as the literal [[0., 1., 2., 3.], [4., ...], [8., ...]].
    x = Tensor(np.arange(12.0).reshape(3, 4))

    loss = (logsoftmax(x, constant=False) * weights).sum()
    expected_loss = np.array(-13.722895761739732)
    assert_allclose(actual=loss.data, desired=expected_loss)

    loss.backward()
    # Analytically, d/dx sum(w * logsoftmax(x)) = w - softmax(x) * w.sum(axis=-1, keepdims=True).
    expected_grad = np.array(
        [
            [0.79988389, 0.3299668, 0.29699009, -1.42684078],
            [0.24859989, 0.62057111, -0.281343, -0.587828],
            [0.5119002, 0.15601518, 0.45965687, -1.12757225],
        ]
    )
    assert_allclose(x.grad, expected_grad, atol=1e-5, rtol=1e-5)
def test_static_logsoftmax1d():
    """Check the value and gradient of a weighted log-softmax sum on a 1-D input.

    ``logsoftmax`` is applied to ``[0, 1, 2, 3]`` with ``constant=False``,
    weighted elementwise, and summed. Reference values were verified against
    Theano (``theano.tensor.softmax``).
    """
    weights = np.array([0.87566484, 0.53596079, 0.85693981, 0.09526036])
    # Same float64 values as the literal [0., 1., 2., 3.].
    x = Tensor(np.arange(4.0))

    loss = (logsoftmax(x, constant=False) * weights).sum()
    assert_allclose(actual=loss.data, desired=np.array(-5.596387676353177))

    loss.backward()
    # Analytically, d/dx sum(w * logsoftmax(x)) = w - softmax(x) * w.sum().
    expected_grad = np.array([0.79988389, 0.3299668, 0.29699009, -1.42684078])
    assert_allclose(x.grad, expected_grad, atol=1e-5, rtol=1e-5)
def test_static_logsoftmax():
    """Check the value and gradient of a weighted log-softmax sum on a 3x4 input.

    Unlike the ``constant=False`` static tests, ``logsoftmax`` is called here
    without the ``constant`` keyword, so this exercises its default behaviour.
    Reference values were verified against Theano (``theano.tensor.softmax``).
    """
    skew = np.array(
        [
            [0.87566484, 0.53596079, 0.85693981, 0.09526036],
            [0.32024455, 0.81532148, 0.2480434, 0.85119342],
            [0.57943085, 0.33958252, 0.95864464, 0.22881712],
        ]
    )
    x = np.array([[0.0, 1.0, 2.0, 3.0], [4.0, 5.0, 6.0, 7.0], [8.0, 9.0, 10.0, 11.0]])
    x = Tensor(x)
    f = (logsoftmax(x) * skew).sum()

    out = np.array(-13.722895761739732)
    # assert_allclose (instead of a bare `assert np.allclose`) reports the
    # mismatching values on failure and matches the sibling tests' assertions.
    assert_allclose(actual=f.data, desired=out)

    f.backward()
    # Analytically, d/dx sum(w * logsoftmax(x)) = w - softmax(x) * w.sum(axis=-1, keepdims=True).
    dx = np.array(
        [
            [0.79988389, 0.3299668, 0.29699009, -1.42684078],
            [0.24859989, 0.62057111, -0.281343, -0.587828],
            [0.5119002, 0.15601518, 0.45965687, -1.12757225],
        ]
    )
    # The reference gradient is given to ~8 significant digits, hence the 1e-5 tolerances.
    assert_allclose(x.grad, dx, atol=1e-5, rtol=1e-5)
def test_log_softmax_numerical_stability(x: np.ndarray, data: st.DataObject):
    """Exponentiated log-softmax must form a valid probability distribution.

    A reduction axis is drawn via ``valid_axes(x.ndim)``. Then
    ``exp(logsoftmax(x, axis=axis))`` must lie in ``[0, 1]`` elementwise
    (NaNs fail this check, since comparisons with NaN are False) and must sum
    to one along that axis.

    NOTE(review): ``x`` and ``data`` are presumably supplied by a Hypothesis
    ``@given`` decorator that is not visible here. Confirm.
    """
    chosen_axis = data.draw(valid_axes(x.ndim), label="axis")
    probs = np.exp(logsoftmax(x, axis=chosen_axis).data)

    in_unit_interval = (probs >= 0) & (probs <= 1)
    assert in_unit_interval.all(), probs

    assert_allclose(probs.sum(axis=chosen_axis), 1.0)