def test_keepdims():
    """Reductions with ``keepdims=True`` retain each reduced axis as length 1.

    A ``5 * 3`` symbol summed over axis 0, axis 1, or both should keep its
    rank, with the reduced dimension(s) collapsed to 1. ``std`` is checked
    through ``.shape`` rather than ``.dshape``.
    """
    x = Symbol('x', '5 * 3 * float32')

    # (axis argument, datashape expected after a keepdims sum)
    sum_cases = [
        (0, '1 * 3 * float32'),
        (1, '5 * 1 * float32'),
        ((0, 1), '1 * 1 * float32'),
    ]
    for axis, expected in sum_cases:
        reduced = x.sum(axis=axis, keepdims=True)
        assert reduced.dshape == dshape(expected)

    assert x.std(axis=0, keepdims=True).shape == (1, 3)
def test_axis_kwarg_is_normalized_to_tuple():
    """Reduction expressions always store ``axis`` as a tuple.

    Covers ``axis`` omitted, passed as an int, and passed as a list, across
    ``sum``, ``std`` and ``mean``.
    """
    x = Symbol('x', '5 * 3 * float32')
    reductions = (
        x.sum(),
        x.sum(axis=1),
        x.sum(axis=[1]),
        x.std(),
        x.mean(axis=1),
    )
    for reduction in reductions:
        assert isinstance(reduction.axis, tuple)