def test_countnonzero(self):
    """Check ``bk.count_nonzero`` gives the same result for every framework.

    One random integer array (values 0-9, shape (25, 12, 8)) is converted to
    each framework in ``FRAMEWORKS`` via ``bk.array`` and counted over every
    combination of axis / keepdims / dtype below; the per-framework results
    must all agree.
    """
    values = np.random.randint(0, 10, size=(25, 12, 8))
    # Same iteration order as three nested loops: axis, then keepdims, then dtype.
    grid = [
        (axis, keepdims, dtype)
        for axis in (None, 0, 1, 2, (1, 2))
        for keepdims in (True, False)
        for dtype in ('int32', 'float32')
    ]
    for axis, keepdims, dtype in grid:
        counts = [
            bk.count_nonzero(
                bk.array(values, fw), axis=axis, keepdims=keepdims, dtype=dtype
            )
            for fw in FRAMEWORKS
        ]
        assert_equal(self, (axis, keepdims, dtype), *counts)
def test_norm(self):
    """Check ``bk.norm`` agrees across x, y and z after flattening to rank 2.

    Every combination of ``p`` (1, 2, 'fro', inf), ``axis`` (None, 0, 1,
    (0, 1)) and ``keepdims`` (True, False) is tested. The ``assert_equal``
    label includes ``keepdims`` (as ``test_countnonzero`` does), so a
    failure identifies which keepdims variant diverged.

    NOTE(review): x, y, z come from an enclosing/module scope not visible
    here; presumably the same data held by different backends — confirm.
    """
    # The flattened inputs do not depend on p/axis/keepdims, so build them once
    # instead of re-flattening on every iteration.
    flat_inputs = [bk.flatten(tensor, 2) for tensor in (x, y, z)]
    for p in [1, 2, 'fro', np.inf]:
        for axis in [None, 0, 1, (0, 1)]:
            for keepdims in (True, False):
                norms = [
                    bk.norm(tensor, p=p, axis=axis, keepdims=keepdims)
                    for tensor in flat_inputs
                ]
                assert_equal(self, (p, axis, keepdims), *norms)
def test_stats_and_reduce(self):
    """Check the ``bk.reduce_*`` family and ``bk.moments`` agree across x, y, z.

    Each reduction is run with ``keepdims=True`` (labelled ``"<op>_keepdims"``)
    and ``keepdims=False`` (labelled ``"<op>"``) for axis 1, 2 and None.
    Afterwards ``bk.moments(..., axis=1)`` is compared for mean and variance.

    NOTE(review): x, y, z come from an enclosing/module scope not visible here.
    """
    reducers = (
        ("min", bk.reduce_min),
        ("max", bk.reduce_max),
        ("mean", bk.reduce_mean),
        ("var", bk.reduce_var),
        ("std", bk.reduce_std),
        ("sum", bk.reduce_sum),
        ("prod", bk.reduce_prod),
        ("all", bk.reduce_all),
        ("any", bk.reduce_any),
        ("logsumexp", bk.reduce_logsumexp),
    )
    # some functions are not supported by pytorch when reducing over all axes
    unsupported_full_reduce = ('min', 'max', 'prod', 'all', 'any')

    for axis in (1, 2, None):
        for op_name, reducer in reducers:
            if axis is None and op_name in unsupported_full_reduce:
                continue
            for keepdims in (True, False):
                label = op_name + "_keepdims" if keepdims else op_name
                outputs = [
                    reducer(tensor, axis=axis, keepdims=keepdims)
                    for tensor in (x, y, z)
                ]
                assert_equal(self, label, *outputs)

    means, variances = [], []
    for tensor in (x, y, z):
        mean, variance = bk.moments(tensor, axis=1)
        means.append(mean)
        variances.append(variance)
    assert_equal(self, "moments_mean", *means)
    assert_equal(self, "moments_var", *variances)
def test_clip_by_value(self):
    """Check ``bk.clip`` agrees across x, y, z for upper-only, lower-only and
    two-sided bounds.

    NOTE(review): x, y, z come from an enclosing/module scope not visible here.
    """
    bound_pairs = [(None, 1), (1, None), (1, 2)]
    for lower, upper in bound_pairs:
        clipped = [bk.clip(tensor, lower, upper) for tensor in (x, y, z)]
        assert_equal(self, (lower, upper), *clipped)
def tile_and_test(reps, axis):
    """Tile x, y and z with the same ``reps``/``axis`` and assert they match.

    NOTE(review): ``self``, x, y, z are free names from the enclosing scope
    (not visible here).
    """
    tiled = [bk.tile(tensor, reps=reps, axis=axis) for tensor in (x, y, z)]
    assert_equal(self, (reps, axis), *tiled)
def swapaxes_and_test(a1, a2):
    """Swap axes ``a1`` and ``a2`` of x, y and z and assert the results match.

    NOTE(review): ``self``, x, y, z are free names from the enclosing scope
    (not visible here).
    """
    swapped = [bk.swapaxes(tensor, a1, a2) for tensor in (x, y, z)]
    assert_equal(self, (a1, a2), *swapped)
def flatten_and_test(n):
    """Apply ``bk.flatten(..., n)`` to x, y and z and assert the results match.

    NOTE(review): ``self``, x, y, z are free names from the enclosing scope
    (not visible here).
    """
    flattened = [bk.flatten(tensor, n) for tensor in (x, y, z)]
    assert_equal(self, n, *flattened)
def transpose_and_test(pattern):
    """Transpose x, y and z by ``pattern`` and assert the results match.

    NOTE(review): ``self``, x, y, z are free names from the enclosing scope
    (not visible here).
    """
    transposed = [bk.transpose(tensor, pattern) for tensor in (x, y, z)]
    assert_equal(self, pattern, *transposed)
def reshape_and_test(newshape):
    """Reshape x, y and z to ``newshape`` and assert the results match.

    NOTE(review): ``self``, x, y, z are free names from the enclosing scope
    (not visible here).
    """
    reshaped = [bk.reshape(tensor, newshape) for tensor in (x, y, z)]
    assert_equal(self, newshape, *reshaped)