def test_multikron_single():
    """Compare multikron with np.kron when only one operand is stacked.

    The reference result applies np.kron once per matrix of the stacked
    operand while the other operand is reused unchanged.
    """
    # Stack of matrices on the left, single matrix on the right.
    stacked = np.random.random((2, 3, 5))
    single = np.random.random((4, 8))
    actual = multikron(stacked, single)
    per_item = [np.kron(mat, single) for mat in stacked]
    expected = np.array(per_item)
    assert actual.shape == expected.shape
    np.testing.assert_allclose(actual, expected)

    # Single matrix on the left, stack of matrices on the right.
    single = np.random.random((6, 3))
    stacked = np.random.random((5, 2, 9))
    actual = multikron(single, stacked)
    per_item = [np.kron(single, mat) for mat in stacked]
    expected = np.array(per_item)
    assert actual.shape == expected.shape
    np.testing.assert_allclose(actual, expected)
def test_multikron_broadcast():
    """Compare multikron with nested np.kron loops for size-1 leading axes.

    Both scenarios insert np.newaxis into the operands before the call; the
    reference values are built with explicit Python iteration instead.
    """
    # Right operand gets a new leading axis and is paired with every
    # outer slice of the left operand.
    lhs = np.random.random((10, 11, 2, 4))
    rhs = np.random.random((11, 4, 3))
    actual = multikron(lhs, rhs[np.newaxis])
    expected = np.array([list(map(np.kron, outer, rhs)) for outer in lhs])
    np.testing.assert_allclose(actual, expected)

    # New axes on both operands: every left matrix is combined with every
    # right matrix.
    lhs = np.random.random((7, 6, 3))
    rhs = np.random.random((5, 2, 9))
    actual = multikron(lhs[:, np.newaxis], rhs[np.newaxis, :])
    expected = np.array([
        [np.kron(left, right) for right in rhs]
        for left in lhs
    ])
    assert actual.shape == expected.shape
    np.testing.assert_allclose(actual, expected)
def test_multikron_noncontiguous():
    """Compare multikron with per-item np.kron on axis-swapped operands.

    Each case draws fresh random operands and swaps axes 1 and 2 of the
    flagged one(s); the test asserts that those operands are reported as
    non-contiguous before comparing the results.
    """
    # (shape of a before swap, swap a?, shape of b before swap, swap b?)
    cases = (
        ((10, 3, 3), True, (10, 3, 3), False),  # a non-contiguous
        ((10, 3, 4), False, (10, 3, 4), True),  # b non-contiguous
        ((10, 3, 5), True, (10, 6, 3), True),   # both non-contiguous
    )
    for lhs_shape, swap_lhs, rhs_shape, swap_rhs in cases:
        lhs = np.random.random(lhs_shape)
        if swap_lhs:
            lhs = lhs.swapaxes(1, 2)
        rhs = np.random.random(rhs_shape)
        if swap_rhs:
            rhs = rhs.swapaxes(1, 2)

        # Guard the premise of the case: a swapped operand must not be
        # flagged as contiguous, otherwise the case tests nothing new.
        if swap_lhs:
            assert not lhs.flags.contiguous
        if swap_rhs:
            assert not rhs.flags.contiguous

        actual = multikron(lhs, rhs)
        expected = np.array([np.kron(x, y) for x, y in zip(lhs, rhs)])
        np.testing.assert_allclose(actual, expected)
def test_multikron_ndim():
    """Compare multikron with np.kron for operands with two leading axes.

    The reference flattens the leading axes, applies np.kron pairwise and
    restores the leading axes afterwards.
    """
    lhs = np.random.random((10, 11, 2, 4))
    rhs = np.random.random((10, 11, 4, 3))
    actual = multikron(lhs, rhs)

    flat_lhs = lhs.reshape(-1, 2, 4)
    flat_rhs = rhs.reshape(-1, 4, 3)
    flat_expected = np.array(
        [np.kron(x, y) for x, y in zip(flat_lhs, flat_rhs)]
    )
    # np.kron of a (2, 4) and a (4, 3) matrix has shape (2*4, 4*3).
    expected = flat_expected.reshape(10, 11, 2*4, 4*3)
    np.testing.assert_allclose(actual, expected)
def test_multikron_eqshape():
    """Compare multikron with per-item np.kron for identically shaped stacks."""
    shape = (31, 4, 4)
    lhs = np.random.random(shape)
    rhs = np.random.random(shape)
    actual = multikron(lhs, rhs)
    expected = np.array([np.kron(x, y) for x, y in zip(lhs, rhs)])
    np.testing.assert_allclose(actual, expected)