def test_decompose_einsum_equation_deep_case(self):
    """Check a 4-D two-input equation where ``h`` is summed out.

    Both operands are float32 tensors of shape (2, 2, 2, 2); the result of
    the decomposed sequence must match :func:`numpy.einsum`.
    """
    equation = "bsnh,btnh->bnts"
    shape = (2, 2, 2, 2)
    left = np.arange(0, 16).astype(np.float32).reshape(shape)
    right = np.arange(0, 16).astype(np.float32).reshape(shape)
    expected = np.einsum(equation, left, right)
    # No shapes are passed: the decomposition works from the equation only.
    sequence = decompose_einsum_equation(equation)
    got = self.apply_einsum_sequence(sequence, left, right)
    assert_almost_equal(expected, got)
def test_decompose_einsum_equation_noshape(self):
    """Decompose ``bac,ch->ah`` without giving the input shapes.

    Inputs are float32 tensors of shape (2, 3, 4) and (4, 5); the sequence
    result is compared with :func:`numpy.einsum`.
    """
    equation = "bac,ch->ah"
    first = np.arange(0, 24).astype(np.float32).reshape((2, 3, 4))
    second = np.arange(0, 20).astype(np.float32).reshape((4, 5))
    sequence = decompose_einsum_equation(equation)
    expected = np.einsum(equation, first, second)
    assert_almost_equal(
        expected, self.apply_einsum_sequence(sequence, first, second))
def test_many_3(self):
    """Test many equations with 3 inputs.

    Every permutation of the subscripts ``abc``, ``cd`` and ``def`` is
    combined with the outputs ``p1[0] + p1[i] + p3[j]`` for ``i`` in
    ``(1, 2)`` and ``j`` in ``(0, 1)``.  Equations rejected by numpy
    (``ValueError``) are skipped.  Each remaining equation is decomposed
    with the input shapes and its result compared to :func:`numpy.einsum`.
    """
    m1 = np.arange(2 * 2 * 2).reshape((2, 2, 2)) + 10
    m2 = np.arange(4).reshape((2, 2)) + 100
    m3 = np.arange(8).reshape((2, 2, 2)) + 1000

    # Collect (equation, expected) pairs first so ``total`` below is exact.
    # itertools.product keeps the same order as the original nested loops.
    cases = []
    for p1, p2, p3 in itertools.product(
            itertools.permutations("abc"),
            itertools.permutations("cd"),
            itertools.permutations("def")):
        sp1, sp2, sp3 = "".join(p1), "".join(p2), "".join(p3)
        for i, j in itertools.product([1, 2], [0, 1]):
            equation = "%s,%s,%s->%s%s%s" % (
                sp1, sp2, sp3, sp1[0], sp1[i], sp3[j])
            try:
                expected = np.einsum(equation, m1, m2, m3)
            except ValueError:
                # Not viable equation.
                continue
            cases.append((equation, expected))

    total = len(cases)
    for index, (equation, expected) in enumerate(cases):
        with self.subTest(equation=equation, index=index, total=total):
            seq = decompose_einsum_equation(
                equation, m1.shape, m2.shape, m3.shape)
            # Distinct name: the original rebound the case list here, which
            # made ``total`` report the length of the last result array.
            got = self.apply_einsum_sequence(seq, m1, m2, m3)
            assert_almost_equal(expected, got)
def common_test_case_2(self, equation):
    """Shared check for a two-input *equation*.

    The first operand has shape (2, 2, 2) and values ``arange + 10``.  The
    second has shape (2, 2) and values ``arange + 100``.  The decomposed
    sequence, built with these shapes, must reproduce :func:`numpy.einsum`.
    """
    left = np.arange(2 * 2 * 2).reshape((2, 2, 2)) + 10
    right = np.arange(4).reshape((2, 2)) + 100
    expected = np.einsum(equation, left, right)
    sequence = decompose_einsum_equation(equation, left.shape, right.shape)
    assert_almost_equal(
        expected, self.apply_einsum_sequence(sequence, left, right))
def test_decompose_einsum_equation_onnx2(self):
    """Test ``bac,cd,def->ebc`` with explicit input shapes.

    Each operand is a float32 ``arange`` reshaped to (2, 3, 4), (4, 5) and
    (5, 7, 11) respectively; the decomposed result must match numpy.
    """
    equation = "bac,cd,def->ebc"
    shapes = [(2, 3, 4), (4, 5), (5, 7, 11)]
    inputs = [
        np.arange(0, int(np.prod(shape))).astype(np.float32).reshape(shape)
        for shape in shapes]
    sequence = decompose_einsum_equation(equation, *shapes)
    expected = np.einsum(equation, *inputs)
    got = self.apply_einsum_sequence(sequence, *inputs)
    assert_almost_equal(expected, got)
def test_decompose_einsum_equation(self):
    """Decompose ``bac,ch->ah``, check its dot output, then its result."""
    equation = "bac,ch->ah"
    first = np.arange(0, 8).astype(np.float32).reshape((2, 2, 2))
    second = np.arange(0, 4).astype(np.float32).reshape((2, 2))
    expected = np.einsum(equation, first, second)
    sequence = decompose_einsum_equation(equation, (2, 2, 2), (2, 2))
    # Splitting into 5 pieces means the dot text holds 'red' exactly 4 times.
    pieces = sequence.to_dot().split('red')
    self.assertEqual(len(pieces), 5)
    got = self.apply_einsum_sequence(sequence, first, second)
    assert_almost_equal(expected, got)
def optimize_compare(self, equation, operands=None):
    """Compare :func:`numpy.einsum` with the decomposed *equation*.

    :param equation: einsum equation, e.g. ``"ab,bc->ac"``.
    :param operands: optional input arrays.  When omitted, one float32
        input is built per left-hand term, with every dimension equal to 2
        and values ``arange + 3 ** position`` (position of the term).

    The decomposed sequence is evaluated through ``apply_einsum_sequence``
    (per the original docstring, presumably an ONNX-based runtime — confirm).
    Results are compared to 5 decimals.
    """
    with self.subTest(equation=equation):
        if operands is None:
            terms = equation.split("->")[0].split(",")
            inputs = [
                np.arange(2 ** len(term)).reshape(
                    (2,) * len(term)).astype(np.float32)
                + np.array([3 ** position], dtype=np.float32)
                for position, term in enumerate(terms)]
        else:
            inputs = operands
        expected = np.einsum(equation, *inputs)
        sequence = decompose_einsum_equation(
            equation, *[array.shape for array in inputs])
        got = self.apply_einsum_sequence(sequence, *inputs)
        assert_almost_equal(expected, got, decimal=5)
def test_abbba(self):
    """Decompose ``ab,b->ba``; the test passes if no exception is raised."""
    equation = "ab,b->ba"
    decompose_einsum_equation(equation)