コード例 #1
0
 def test_decompose_einsum_equation_deep_case(self):
     "test bsnh,btnh->bnts on two float32 operands of shape (2, 2, 2, 2)"
     equation = "bsnh,btnh->bnts"
     left = np.arange(16, dtype=np.float32).reshape((2, 2, 2, 2))
     right = left.copy()
     # Reference result from numpy, then the decomposed sequence (built
     # without shape hints) must reproduce it.
     expected = np.einsum(equation, left, right)
     sequence = decompose_einsum_equation(equation)
     got = self.apply_einsum_sequence(sequence, left, right)
     assert_almost_equal(expected, got)
コード例 #2
0
 def test_decompose_einsum_equation_noshape(self):
     "test bac,ch->ah when no input shapes are passed to the decomposition"
     equation = "bac,ch->ah"
     first = np.arange(24, dtype=np.float32).reshape((2, 3, 4))
     second = np.arange(20, dtype=np.float32).reshape((4, 5))
     sequence = decompose_einsum_equation(equation)
     expected = np.einsum(equation, first, second)
     got = self.apply_einsum_sequence(sequence, first, second)
     assert_almost_equal(expected, got)
コード例 #3
0
    def test_many_3(self):
        """test many equation with 3 inputs

        Builds every equation ``p1,p2,p3->p1[0]p1[i]p3[j]`` where p1, p2, p3
        range over the permutations of ``abc``, ``cd`` and ``def``, i is 1 or
        2 and j is 0 or 1. Each equation numpy accepts is decomposed, and the
        decomposed sequence must reproduce ``numpy.einsum``.
        """
        m1 = np.arange(2 * 2 * 2).reshape((2, 2, 2)) + 10
        m2 = np.arange(4).reshape((2, 2)) + 100
        m3 = np.arange(8).reshape((2, 2, 2)) + 1000

        # (equation, expected) pairs. Kept under a name distinct from the
        # per-equation result below so that ``total`` reports the number of
        # equations, not the first dimension of a result array.
        cases = []
        # itertools.product varies its last argument fastest, which gives the
        # same enumeration order as five nested for-loops.
        for p1, p2, p3, i, j in itertools.product(
                itertools.permutations("abc"),
                itertools.permutations("cd"),
                itertools.permutations("def"),
                [1, 2], [0, 1]):
            sp1, sp2, sp3 = "".join(p1), "".join(p2), "".join(p3)
            equation = "%s,%s,%s->%s%s%s" % (
                sp1, sp2, sp3, sp1[0], sp1[i], sp3[j])
            try:
                expected = np.einsum(equation, m1, m2, m3)
            except ValueError:
                # Not viable equation.
                continue
            cases.append((equation, expected))

        total = len(cases)
        for index, (eq, exp) in enumerate(cases):
            with self.subTest(equation=eq, index=index, total=total):
                seq = decompose_einsum_equation(
                    eq, m1.shape, m2.shape, m3.shape)
                got = self.apply_einsum_sequence(seq, m1, m2, m3)
                # ``exp`` was already computed with numpy above; no need to
                # evaluate the same einsum a second time.
                assert_almost_equal(exp, got)
コード例 #4
0
    def common_test_case_2(self, equation):
        """Check *equation* on a (2, 2, 2) and a (2, 2) integer operand.

        numpy evaluates the equation first, then the decomposed sequence
        (built with explicit shapes) must give the same result.
        """
        left = np.arange(8).reshape((2, 2, 2)) + 10
        right = np.arange(4).reshape((2, 2)) + 100
        expected = np.einsum(equation, left, right)

        sequence = decompose_einsum_equation(
            equation, left.shape, right.shape)
        got = self.apply_einsum_sequence(sequence, left, right)
        assert_almost_equal(expected, got)
コード例 #5
0
    def test_decompose_einsum_equation_onnx2(self):
        "test bac,cd,def->ebc"
        equation = "bac,cd,def->ebc"
        first = np.arange(24, dtype=np.float32).reshape((2, 3, 4))
        second = np.arange(20, dtype=np.float32).reshape((4, 5))
        third = np.arange(5 * 7 * 11, dtype=np.float32).reshape((5, 7, 11))

        # Shapes come from the operands themselves so the two cannot drift.
        sequence = decompose_einsum_equation(
            equation, first.shape, second.shape, third.shape)
        expected = np.einsum(equation, first, second, third)
        got = self.apply_einsum_sequence(sequence, first, second, third)
        assert_almost_equal(expected, got)
コード例 #6
0
 def test_decompose_einsum_equation(self):
     "test decompose einsum on bac,ch->ah, including its DOT export"
     equation = "bac,ch->ah"
     left = np.arange(8, dtype=np.float32).reshape((2, 2, 2))
     right = np.arange(4, dtype=np.float32).reshape((2, 2))
     expected = np.einsum(equation, left, right)
     sequence = decompose_einsum_equation(equation, left.shape, right.shape)
     # Splitting the DOT text on 'red' yields five pieces, i.e. the
     # substring occurs exactly four times (meaning of 'red' in the graph
     # is defined by to_dot elsewhere -- presumably a highlight colour).
     pieces = sequence.to_dot().split('red')
     self.assertEqual(len(pieces), 5)
     got = self.apply_einsum_sequence(sequence, left, right)
     assert_almost_equal(expected, got)
コード例 #7
0
    def optimize_compare(self, equation, operands=None):
        """Compares numpy einsum and ONNX.

        Evaluates *equation* with ``numpy.einsum`` and with the decomposed
        einsum sequence, and checks both agree to 5 decimals.

        :param equation: einsum equation such as ``"ab,bc->ac"``
        :param operands: input arrays; when None, one float32 array is
            generated per input subscript, of shape ``(2,) * len(subscript)``,
            filled with ``arange`` values shifted by ``3 ** position``
        """
        with self.subTest(equation=equation):
            if operands is not None:
                inputs = operands
            else:
                subscripts = equation.split("->")[0].split(",")
                inputs = []
                for d, sub in enumerate(subscripts):
                    base = np.arange(2 ** len(sub)).reshape(
                        (2,) * len(sub)).astype(np.float32)
                    # Add a float32 *scalar* rather than a 1-element array:
                    # a zero-dimensional operand (empty subscript) must keep
                    # shape () instead of being broadcast to (1,).
                    inputs.append(base + np.float32(3 ** d))

            exp = np.einsum(equation, *inputs)
            shapes = [m.shape for m in inputs]

            seq = decompose_einsum_equation(equation, *shapes)
            got = self.apply_einsum_sequence(seq, *inputs)
            assert_almost_equal(exp, got, decimal=5)
コード例 #8
0
 def test_abbba(self):
     "smoke test: decomposing ab,b->ba without shapes must not raise"
     equation = "ab,b->ba"
     decompose_einsum_equation(equation)