def test_valid_axes(shape, data, permit_none, pos_only):
    """Check that `valid_axes` honours its `permit_none`, `pos_only`,
    `min_dim` and `max_dim` constraints for an array of the given shape.

    Parameters
    ----------
    shape : Sequence[int]
        Shape of the array that the drawn axes must be valid for.
    data
        Object exposing ``draw(strategy, label=...)``; used to draw the
        dimension bounds and the axes themselves.
    permit_none : bool
        Forwarded to `valid_axes`; when False, ``None`` must never be drawn.
    pos_only : bool
        Forwarded to `valid_axes`; when True, every drawn axis must be >= 0.
    """
    ndim = len(shape)

    # Draw a consistent pair of bounds: 0 <= lower <= upper <= ndim,
    # where the upper bound may also be left unspecified (None).
    lower = data.draw(st.integers(0, ndim), label="min_dim")
    upper = data.draw(
        st.one_of(st.none(), st.integers(lower, ndim)), label="max_dim"
    )

    axes_strategy = valid_axes(
        ndim=ndim,
        permit_none=permit_none,
        pos_only=pos_only,
        min_dim=lower,
        max_dim=upper,
    )
    axis = data.draw(axes_strategy, label="axis")

    # Let numpy itself validate the drawn axes against an array of this shape.
    arr = np.empty(shape)
    np.sum(arr, axis=axis)

    if axis is None:
        # `None` is only an acceptable draw when it was explicitly permitted;
        # none of the remaining checks apply to it.
        assert permit_none
        return

    if pos_only:
        assert all(i >= 0 for i in axis)

    assert lower <= len(axis)
    if upper is not None:
        assert len(axis) <= upper
def test_valid_single_axis(shape, data, permit_none):
    """Check that `valid_axes(..., single_axis_only=True)` draws an axis that
    numpy accepts for an array of the given shape.

    Parameters
    ----------
    shape : Sequence[int]
        Shape of the array that the drawn axis must be valid for.
    data
        Object exposing ``draw(strategy, label=...)``.
    permit_none : bool
        Forwarded to `valid_axes`; when False, ``None`` must never be drawn.
    """
    single_axis_strategy = valid_axes(
        ndim=len(shape), single_axis_only=True, permit_none=permit_none
    )
    axis = data.draw(single_axis_strategy, label="axis")

    arr = np.empty(shape)
    np.argmin(arr, axis=axis)  # raises if `axis` is invalid

    # Either `None` was allowed, or the draw must be a concrete axis.
    assert permit_none or axis is not None
def test_logsumexp(data: st.DataObject, x: np.ndarray, keepdims: bool):
    """Check that mygrad's `logsumexp` reproduces `scipy.special.logsumexp`.

    Parameters
    ----------
    data : st.DataObject
        Interactive-draw object; used to draw the reduction axes for `x`.
        (It is annotated as `DataObject` rather than `SearchStrategy` because
        the test calls ``data.draw``.)
    x : np.ndarray
        The array to reduce.
    keepdims : bool
        Forwarded to both implementations.
    """
    axes = data.draw(valid_axes(ndim=x.ndim), label="axes")

    mygrad_result = logsumexp(x, axis=axes, keepdims=keepdims)
    scipy_result = special.logsumexp(x, axis=axes, keepdims=keepdims)

    assert_array_equal(
        mygrad_result,
        scipy_result,
        err_msg="mygrad's implementation of logsumexp does "
        "not match that of scipy's",
    )
def test_log_softmax_numerical_stability(x: np.ndarray, data: st.DataObject):
    """Exponentiating `logsoftmax(x)` must give values in [0, 1] that sum to
    1 along the reduced axes.

    Parameters
    ----------
    x : np.ndarray
        Input passed to `logsoftmax`.
    data : st.DataObject
        Interactive-draw object; used to draw the axes for `x`.
    """
    axis = data.draw(valid_axes(x.ndim), label="axis")

    log_probs = logsoftmax(x, axis=axis)
    out = np.exp(log_probs.data)

    # Every recovered probability must lie in the closed unit interval;
    # on failure the offending array is reported as the assertion message.
    within_unit_interval = np.logical_and(out >= 0, out <= 1)
    assert np.all(within_unit_interval), out

    assert_allclose(out.sum(axis=axis), 1.0)
def numpy_softmax(x, axis):
    """Reference softmax of `x` over `axis`, written with plain numpy.

    Parameters
    ----------
    x : array_like
        Input values; converted with ``np.asarray``.
    axis : None or int or tuple of ints
        Axis/axes over which the softmax is normalized.

    Returns
    -------
    np.ndarray
        Array with the same shape as `x`.
    """
    arr = np.asarray(x)
    # Shift by the per-axis maximum so the largest exponent is exp(0) == 1.
    shifted = arr - arr.max(axis, keepdims=True)
    exps = np.exp(shifted)
    return exps / exps.sum(axis, keepdims=True)


def numpy_logsoftmax(x, axis):
    """Reference log-softmax: the elementwise log of `numpy_softmax(x, axis)`."""
    probabilities = numpy_softmax(x, axis)
    return np.log(probabilities)


@fwdprop_test_factory(
    mygrad_func=softmax,
    true_func=numpy_softmax,
    num_arrays=1,
    kwargs={"axis": lambda arrs: valid_axes(arrs.ndim)},
)
def test_softmax_fwd():
    """Compare `softmax` against `numpy_softmax`.

    The body is intentionally empty; the test logic is presumably supplied
    by the `fwdprop_test_factory` decorator.
    """


@backprop_test_factory(
    mygrad_func=softmax,
    true_func=numpy_softmax,
    num_arrays=1,
    kwargs={"axis": lambda arrs: valid_axes(arrs.ndim)},
    vary_each_element=True,
)
def test_softmax_bkwd():
    """Compare `softmax` against `numpy_softmax` via `backprop_test_factory`.

    The body is intentionally empty; the test logic is presumably supplied
    by the `backprop_test_factory` decorator.
    """