def special_shape(draw, static_shape, shape=tuple(), min_dim=0, max_dim=5):
    """Draw broadcastable leading dimensions and append a static shape to them.

    Search strategy that permits broadcastable dimensions to be prepended
    to a static shape - for the purposes of drawing diverse shaped-arrays
    for matmul.

    Parameters
    ----------
    draw : Callable
        Function used to draw a value from a strategy.
        NOTE(review): presumably injected by ``hypothesis.strategies.composite``;
        the decorator is not visible in this chunk - confirm.

    static_shape : Tuple[int, ...]
        Trailing dimensions that are appended, unchanged, to the drawn shape.

    shape : Tuple[int, ...], optional (default=())
        Shape forwarded to ``broadcastable_shapes``.

    min_dim : int, optional (default=0)
        Forwarded to ``broadcastable_shapes`` as ``min_dims``.

    max_dim : int, optional (default=5)
        Forwarded to ``broadcastable_shapes`` as ``max_dims``.

    Returns
    -------
    Tuple[int, ...]
        The drawn shape concatenated with ``static_shape``. When this function
        is wrapped as a composite strategy, callers receive a
        hypothesis.searchstrategy.SearchStrategy -> Tuple[int, ...] instead.
    """
    # The bounds are passed by keyword, matching every other call to
    # `broadcastable_shapes` in this file; recent Hypothesis releases make
    # these parameters keyword-only, so a positional call raises TypeError.
    leading_dims = draw(
        broadcastable_shapes(shape=shape, min_dims=min_dim, max_dims=max_dim)
    )
    return leading_dims + static_shape
def test_reduce_broadcast_nokeepdim(var_shape, data):
    """Check `reduce_broadcast` on a gradient with more dimensions than the variable.

    Example broadcasting: (2, 3) -> (5, 2, 3)

    An all-ones gradient is drawn with between one and three more dimensions
    than `var_shape` (each drawn side at least 2), reduced back via
    `reduce_broadcast`, rescaled, and compared against an all-ones array of
    shape `var_shape`.
    """
    ndim = len(var_shape)
    shape_strategy = broadcastable_shapes(
        shape=var_shape,
        min_dims=ndim + 1,
        max_dims=ndim + 3,
        min_side=2,
    )
    grad_shape = data.draw(shape_strategy, label="grad_shape")
    grad = np.ones(grad_shape, dtype=float)

    reduced_grad = reduce_broadcast(grad=grad, var_shape=var_shape)

    # scale reduced-grad so all elements are 1
    rescale = np.prod(var_shape) / grad.size
    reduced_grad *= rescale

    expected = np.ones(var_shape)
    assert_allclose(actual=reduced_grad, desired=expected)
def test_reduce_broadcast_keepdim(var_shape, data):
    """Check `reduce_broadcast` on a gradient with the same number of dimensions.

    Example broadcasting: (2, 1, 4) -> (2, 5, 4)

    An all-ones gradient is drawn whose shape is broadcast-compatible with
    `var_shape` and has exactly `len(var_shape)` dimensions. The reduced
    gradient must:

    - have, along each axis, the smaller of the variable's and the
      gradient's sizes,
    - only shrink axes where the variable's size is 1,
    - equal the gradient summed (with `keepdims=True`) over every axis whose
      size differs between the variable and the gradient.
    """
    ndim = len(var_shape)
    grad = data.draw(
        hnp.arrays(
            dtype=float,
            shape=broadcastable_shapes(shape=var_shape, min_dims=ndim, max_dims=ndim),
            elements=st.just(1.0),
        ),
        label="grad",
    )
    reduced_grad = reduce_broadcast(grad=grad, var_shape=var_shape)

    # Same expression as the original `i if i < j else j`.
    assert reduced_grad.shape == tuple(
        min(i, j) for i, j in zip(var_shape, grad.shape)
    )

    # `all(...)` is required here: asserting a bare generator expression is
    # always truthy, so the check could never fail.
    assert all(i == 1 for i, j in zip(var_shape, grad.shape) if i < j)

    sum_axes = tuple(
        n for n, (i, j) in enumerate(zip(var_shape, grad.shape)) if i != j
    )
    assert_allclose(actual=reduced_grad, desired=grad.sum(axis=sum_axes, keepdims=True))