def test_inv2_multiple():
    """Check inv2 on a multi-dimensional stack of 2x2 matrices.

    A (154, 7, 2, 2) batch is inverted by inv2 and compared element-wise
    against np.linalg.inv, which broadcasts over the leading axes itself.
    A local, seeded RandomState keeps the sample reproducible without
    touching NumPy's global RNG state.
    """
    rng = np.random.RandomState(1234)
    Ts = rng.random_sample((154, 7, 2, 2))

    # np.linalg.inv operates on stacks of matrices over the last two axes,
    # so no manual reshape/map is needed to build the reference.
    Tinv_np = np.linalg.inv(Ts)
    Tinv_blitz = inv2(Ts)

    # Scope the readable (non-scientific) print format to this assertion
    # instead of mutating NumPy's global print options for later tests.
    with np.printoptions(suppress=True):
        # actual = code under test, desired = reference, so rtol is
        # measured relative to NumPy's result.
        np.testing.assert_allclose(Tinv_blitz, Tinv_np)
def test_inv2_float32():
    """Compare inv2 with np.linalg.inv for single-precision input.

    The global NumPy RNG is seeded with 42 so that the same 1000 float32
    2x2 matrices are drawn on every run; the comparison uses a relaxed
    relative tolerance of 1e-3.
    """
    np.random.seed(42)  # NOTE: mutates NumPy's global RNG state
    samples = np.random.random((1000, 2, 2)).astype(np.float32)

    # Reference: invert each matrix individually, then restack.
    per_matrix = [np.linalg.inv(m) for m in samples]
    reference = np.array(per_matrix).reshape(samples.shape)

    candidate = inv2(samples)
    np.testing.assert_allclose(reference, candidate, rtol=1.e-3)
def test_inv2():
    """Check inv2 on a single, unbatched 2x2 matrix against np.linalg.inv.

    A local, seeded RandomState makes the input reproducible without
    altering NumPy's global RNG state.
    """
    rng = np.random.RandomState(0)
    T = rng.random_sample((2, 2))
    # actual = code under test, desired = reference (rtol is relative to it).
    np.testing.assert_allclose(inv2(T), np.linalg.inv(T))