def test_dot_constant_shape_2D(shape):
    """Compare ``StackedArray.dot`` with ``.array.dot`` for equally-shaped blocks.

    Seven random ``(N, P)`` dask blocks are stacked. The product is then
    checked against a 1-D operand of length ``P`` and against a 2-D operand
    of shape ``(P, 2)``, each to 12 decimal places.

    ``shape`` is an ``(N, P)`` pair, presumably drawn by a Hypothesis
    ``@given`` decorator outside this view (TODO confirm).
    """
    n_rows, n_cols = shape
    blocks = [da.random.random(size=(n_rows, n_cols)) for _ in range(7)]
    stacked = StackedArray(blocks)
    # Vector operand first, then a two-column matrix operand.
    for operand_shape in (n_cols, (n_cols, 2)):
        operand = da.random.random(operand_shape)
        np.testing.assert_array_almost_equal(
            stacked.dot(operand),
            stacked.array.dot(operand),
            decimal=12,
        )
def test_dot_non_consistent_shape(shapes):
    """Compare ``StackedArray.dot`` with ``.array.dot`` for blocks of differing shapes.

    Examples are rejected (via ``assume``) when any block has no dimension
    larger than 1. One random dask block is built per entry of ``shapes``.
    The stacked result's column count sizes a random 1-D operand, and both
    dot products must agree to 12 decimal places.

    ``shapes`` is presumably a Hypothesis-drawn list of block shapes supplied
    by a decorator outside this view (TODO confirm).
    """
    assume(all(max(block_shape) > 1 for block_shape in shapes))
    stacked = StackedArray([da.random.random(block_shape) for block_shape in shapes])
    # Only the column count is needed; unpacking still requires a 2-D shape.
    _, n_cols = stacked.shape
    y = da.random.random(n_cols)
    # Record the drawn shapes so a failing example is easy to read.
    note(f"shapes: {shapes}, y shape: {y.shape}")
    expected = stacked.array.dot(y)
    np.testing.assert_array_almost_equal(stacked.dot(y), expected, decimal=12)
def test_dot_2D_1D(shape):
    """Compare ``StackedArray.dot`` with ``.array.dot`` for a mix of block ranks.

    The stack holds four random dask blocks:

    - a full ``(N, P)`` block;
    - an ``(N, 1)`` block;
    - a 1-D block of length ``P``;
    - a ``(1, P)`` block.

    How ``StackedArray`` reconciles these widths and ranks is not visible
    here (TODO confirm). The product is checked against a ``(P, 2)`` operand
    and then a length-``P`` operand, each to 12 decimal places. Examples with
    ``N <= 1`` or ``P <= 1`` are rejected.
    """
    n_rows, n_cols = shape
    assume(n_rows > 1 and n_cols > 1)
    blocks = [
        da.random.random(size=(n_rows, n_cols)),
        da.random.random(size=(n_rows, 1)),
        da.random.random(size=(n_cols,)),
        da.random.random(size=(1, n_cols)),
    ]
    stacked = StackedArray(blocks)
    # Matrix operand first, then a vector operand.
    for operand_size in ((n_cols, 2), (n_cols,)):
        y = da.random.random(size=operand_size)
        np.testing.assert_array_almost_equal(
            stacked.dot(y),
            stacked.array.dot(y),
            decimal=12,
        )