예제 #1
0
    def optimize_compare(self, equation, operands=None, verbose=False):
        """Check every decomposition strategy of *equation* against numpy.

        The einsum *equation* is decomposed with the ``'numpy'`` and
        ``'simple'`` strategies, first with ``clean=False`` then with
        ``clean=True``, and each resulting sequence is applied and compared
        to ``numpy.einsum``. With ``clean=True`` the ``'numpy'`` sequence is
        also exported to ONNX and run with OnnxInference. The exception is
        when the export raises a NotImplementedError mentioning
        ``diagonal``; that case is skipped.

        :param equation: einsum equation, e.g. ``"abc,cd->abd"``
        :param operands: explicit input arrays; when None, one float32
            tensor whose dimensions are all 2 is generated per input term,
            operand ``d`` being shifted by ``3 ** d``
        :param verbose: forwarded to the decomposition/application helpers;
            when truthy the numpy einsum paths are printed as well
        """
        # Inputs, expected output and debug output do not depend on
        # ``clean``: compute them once.
        if operands is not None:
            inputs = operands
        else:
            terms = equation.split("->")[0].split(",")
            inputs = []
            for d, term in enumerate(terms):
                base = numpy.arange(2**len(term)).reshape(
                    (2, ) * len(term)).astype(numpy.float32)
                inputs.append(base +
                              numpy.array([3**d], dtype=numpy.float32))

        exp = numpy.einsum(equation, *inputs)
        if verbose:
            print("###### equation", equation)
            path = numpy.einsum_path(equation, *inputs, optimize=False)
            print(path[1])
            path = numpy.einsum_path(equation, *inputs)
            print(path[1])

        shapes = [m.shape for m in inputs]
        # Debug hook: maximum verbosity for one specific equation.
        vv = 12 if equation == ",a,ab,abc->abc" else verbose

        for clean in [False, True]:
            # ``clean`` is part of the sub-test parameters so that a failure
            # report tells which pass broke.
            with self.subTest(equation=equation, clean=clean):
                with self.subTest(strategy='numpy'):
                    seq = decompose_einsum_equation(equation,
                                                    *shapes,
                                                    verbose=verbose,
                                                    strategy='numpy',
                                                    clean=clean)
                    got = apply_einsum_sequence(seq, *inputs, verbose=vv)
                    self.assertEqualArray(exp, got, decimal=6)

                if clean:
                    with self.subTest(strategy='onnx'):
                        names = ['X%d' % (i + 1) for i in range(len(inputs))]
                        try:
                            onx = seq.to_onnx('Y', *names, dtype=numpy.float32)
                        except NotImplementedError as e:
                            # Only a failure mentioning "diagonal" is
                            # tolerated; anything else propagates unchanged.
                            if "diagonal" not in str(e):
                                raise
                            onx = None
                        if onx is not None:
                            oinf = OnnxInference(onx)
                            feeds = {
                                n: v.astype(numpy.float32)
                                for n, v in zip(names, inputs)
                            }
                            got = oinf.run(feeds, verbose=vv, fLOG=print)['Y']
                            self.assertEqualArray(exp, got, decimal=5)

                with self.subTest(strategy='simple'):
                    seq = decompose_einsum_equation(equation,
                                                    *shapes,
                                                    clean=clean,
                                                    verbose=verbose)
                    got = apply_einsum_sequence(seq, *inputs, verbose=verbose)
                    self.assertEqualArray(exp, got, decimal=6)
예제 #2
0
    def test_many_2(self):
        """Check two-operand contractions over many index permutations.

        Every viable equation of the form ``<perm of abc>,<perm of cd>->xyz``
        (three distinct output letters) is decomposed twice: once with the
        default settings, and once with ``strategy='numpy', clean=True``.
        Both sequences are compared with numpy.einsum. The second sequence
        is also exported to ONNX and executed with OnnxInference.
        """
        m1 = numpy.arange(2 * 2 * 2).reshape((2, 2, 2)) + 10
        m2 = numpy.arange(4).reshape((2, 2)) + 100

        cases = []
        for p1 in itertools.permutations(list("abc")):
            for p2 in itertools.permutations(list("cd")):
                for i in [1, 2]:
                    for j in [0, 1]:
                        sp1 = "".join(p1)
                        sp2 = "".join(p2)
                        # The output needs three distinct letters.
                        if len({sp1[0], sp1[i], sp2[j]}) != 3:
                            continue
                        equation = "%s,%s->%s%s%s" % (sp1, sp2, sp1[0], sp1[i],
                                                      sp2[j])
                        try:
                            r = numpy.einsum(equation, m1, m2)
                        except ValueError:
                            # Not viable equation.
                            continue
                        cases.append((equation, r))

        for i, (eq, exp) in enumerate(cases):
            # Per-case results use their own names (``got``, ``got_onnx``)
            # so ``cases`` is never rebound inside the loop and ``total``
            # stays the real number of cases.
            with self.subTest(equation=eq, index=i, total=len(cases)):
                # NOTE(review): generated equations always start with a
                # letter, so this debug hook never triggers as written.
                verbose = 12 if eq == ',abc,dc->acd' else 0
                if verbose:
                    print(
                        '\n########################################clean=False'
                    )
                    print("#########0", eq)
                seq = decompose_einsum_equation(eq,
                                                m1.shape,
                                                m2.shape,
                                                verbose=verbose)
                got = apply_einsum_sequence(seq, m1, m2, verbose=verbose)
                self.assertEqualArray(exp, got)

                if verbose:
                    print(
                        '\n########################################clean=True')
                    print("#########1", eq)
                seq = decompose_einsum_equation(eq,
                                                m1.shape,
                                                m2.shape,
                                                strategy='numpy',
                                                clean=True,
                                                verbose=verbose)
                got = apply_einsum_sequence(seq, m1, m2, verbose=verbose)
                self.assertEqualArray(exp, got)
                onx = seq.to_onnx('Y', 'X1', 'X2', dtype=numpy.float32)
                oinf = OnnxInference(onx)
                got_onnx = oinf.run(
                    {
                        'X1': m1.astype(numpy.float32),
                        'X2': m2.astype(numpy.float32)
                    },
                    verbose=verbose,
                    fLOG=print)
                self.assertEqualArray(exp, got_onnx['Y'])
예제 #3
0
 def test_decompose_einsum_equation_exc(self):
     """Invalid arguments to decompose_einsum_equation must raise."""
     equation = "abc,ch->ah"
     bad_calls = [
         # unknown strategy name
         (lambda: decompose_einsum_equation(equation, (2, 2, 2), (2, 2),
                                            strategy="donotexist"),
          ValueError),
         # a string passed positionally after the two shapes
         (lambda: decompose_einsum_equation(equation, (2, 2, 2), (2, 2),
                                            "donotexist"),
          TypeError),
         # one shape for a two-term equation
         (lambda: decompose_einsum_equation(equation, (2, 2, 2)),
          ValueError),
         # a rank-2 shape for the three-letter term "abc"
         (lambda: decompose_einsum_equation(equation, (2, 2), (2, 2)),
          ValueError),
     ]
     for call, expected_exc in bad_calls:
         self.assertRaise(call, expected_exc)
예제 #4
0
    def test_bdn_in_bdi(self):
        """Decompose ``bdn,in->bdi`` and check it with numpy and ONNX.

        The checks are:

        * the cleaned numpy-strategy sequence matches numpy.einsum;
        * the ONNX export contains no Transpose node;
        * the ONNX graph gives the same result with OnnxInference's default
          runtime and with ``onnxruntime1``;
        * every ``batch_dot`` node reports dot kind ``"11"``.
        """
        equation = "bdn,in->bdi"
        seq = decompose_einsum_equation(equation, strategy='numpy', clean=True)

        inp1 = numpy.arange(2 * 3 * 5).reshape((2, 3, 5)).astype(numpy.float32)
        inp2 = numpy.arange(5 * 7).reshape((7, 5)).astype(numpy.float32)
        exp = numpy.einsum(equation, inp1, inp2)
        got = apply_einsum_sequence(seq, inp1, inp2)
        self.assertEqualArray(exp, got)

        onx = seq.to_onnx("Y", "X1", "X2")
        self.assertNotIn('Transpose', str(onx))
        # Assert the output of every runtime that is executed; the default
        # runtime's result must not be computed and then discarded.
        runtimes = [('default', {}),
                    ('onnxruntime1', {'runtime': 'onnxruntime1'})]
        for label, oinf_kwargs in runtimes:
            with self.subTest(runtime=label):
                oinf = OnnxInference(onx, **oinf_kwargs)
                res = oinf.run({
                    'X1': inp1.astype(numpy.float32),
                    'X2': inp2.astype(numpy.float32)
                })
                self.assertEqualArray(exp, res['Y'])

        for op in seq:
            if op.name == 'batch_dot':
                self.assertEqual(op.get_dot_kind(), "11")
예제 #5
0
 def test_decompose_einsum_equation_pyf(self):
     """The 'pyf' matmul implementation must agree with the default one."""
     left = numpy.arange(0, 8).astype(numpy.float32).reshape((2, 2, 2))
     right = numpy.arange(0, 4).astype(numpy.float32).reshape((2, 2))
     sequence = decompose_einsum_equation("bac,ch->ah", left.shape,
                                          right.shape)
     reference = apply_einsum_sequence(sequence, left, right)
     candidate = apply_einsum_sequence(sequence, left, right,
                                       matmul_impl='pyf')
     self.assertEqualArray(reference, candidate)
예제 #6
0
 def test_case_1_iii_ii_i(self):
     """Decompose the single-operand equation ``ii->i`` and compare to numpy."""
     verbose = False
     equation = 'ii->i'
     mat = numpy.arange(2 * 2).reshape((2, 2)) + 10
     expected = numpy.einsum(equation, mat)
     sequence = decompose_einsum_equation(equation, mat.shape,
                                          verbose=verbose)
     self.assertEqualArray(
         expected, apply_einsum_sequence(sequence, mat, verbose=verbose))
예제 #7
0
 def test_case_1_iii_ii_i_j(self):
     """Decompose ``iij->ij`` and check both its dot graph and its result.

     The dot representation of the sequence must contain ``i=0,1``.
     """
     verbose = False
     equation = 'iij->ij'
     tensor = numpy.arange(2 * 2 * 2).reshape((2, 2, 2)) + 10
     expected = numpy.einsum(equation, tensor)
     sequence = decompose_einsum_equation(equation, tensor.shape,
                                          verbose=verbose)
     self.assertIn("i=0,1", sequence.to_dot())
     got = apply_einsum_sequence(sequence, tensor, verbose=verbose)
     self.assertEqualArray(expected, got)
예제 #8
0
    def common_test_case_2(self, equation, verbose=False, strategy='simple'):
        """Check a two-operand *equation* against numpy.einsum.

        The operands are a (2, 2, 2) and a (2, 2) integer tensor. The
        equation is decomposed with *strategy*, and the resulting sequence
        is applied and compared with numpy.
        """
        first = numpy.arange(2 * 2 * 2).reshape((2, 2, 2)) + 10
        second = numpy.arange(4).reshape((2, 2)) + 100
        operands = (first, second)
        expected = numpy.einsum(equation, *operands)

        shapes = [op.shape for op in operands]
        sequence = decompose_einsum_equation(equation, *shapes,
                                             verbose=verbose,
                                             strategy=strategy)
        got = apply_einsum_sequence(sequence, *operands, verbose=verbose)
        self.assertEqualArray(expected, got)
예제 #9
0
 def fct():
     """Decompose "bac,ch->ah" verbosely and apply it to m1 and m2.

     The dot graph must contain 'red' exactly four times, i.e. split into
     five pieces. ``self``, ``m1`` and ``m2`` come from the enclosing scope.
     """
     print("########################## DECOMPOSE")
     sequence = decompose_einsum_equation(
         "bac,ch->ah", (2, 2, 2), (2, 2), verbose=True)
     print("########################## APPLY")
     graph = sequence.to_dot()
     print(graph)
     self.assertEqual(len(graph.split('red')), 5)
     result = apply_einsum_sequence(sequence, m1, m2, verbose=True)
     print("########################## END")
     return result
예제 #10
0
 def fct():
     """Decompose "bac,chg->ah" (numpy strategy, cleaned) verbosely.

     The dot graph must split into six pieces on 'red'. The sequence is
     applied to m1/m2 and exported to a non-empty ONNX model.
     ``self``, ``m1`` and ``m2`` come from the enclosing scope.
     """
     print("########################## DECOMPOSE")
     sequence = decompose_einsum_equation(
         "bac,chg->ah", (2, 2, 2), (2, 2, 2),
         verbose=True, clean=True, strategy='numpy')
     print("########################## APPLY")
     graph = sequence.to_dot()
     print(graph)
     self.assertEqual(len(graph.split('red')), 6)
     result = apply_einsum_sequence(sequence, m1, m2, verbose=True)
     print("########################## END")
     model = sequence.to_onnx('Y', 'X1', 'X2', verbose=True)
     self.assertNotEmpty(model)
     return result
예제 #11
0
    def test_decompose_einsum_equation_onnx2(self):
        """Export a three-operand einsum to ONNX and run it three ways.

        ``bac,cd,def->ebc`` is decomposed with the numpy strategy, then
        simplified and exported to ONNX. The graph is executed in three
        ways:

        * with OnnxInference's default runtime;
        * with OnnxInference and ``runtime="onnxruntime2"``;
        * with a raw onnxruntime InferenceSession with all graph
          optimizations disabled.

        Every result must match the python application of the sequence.
        """
        m1 = numpy.arange(0, 24).astype(numpy.float32).reshape((2, 3, 4))
        m2 = numpy.arange(0, 20).astype(numpy.float32).reshape((4, 5))
        m3 = numpy.arange(0, 77 * 5).astype(numpy.float32).reshape((5, 7, 11))
        verbose = False

        def make_feeds():
            # Fresh float32 copies of the inputs for every run.
            return {
                'X1': m1.astype(numpy.float32),
                'X2': m2.astype(numpy.float32),
                'X3': m3.astype(numpy.float32)
            }

        for strat in ['numpy']:
            with self.subTest(strategy=strat):
                seq = decompose_einsum_equation("bac,cd,def->ebc", m1.shape,
                                                m2.shape, m3.shape,
                                                strategy=strat,
                                                verbose=verbose)
                res1 = apply_einsum_sequence(seq, m1, m2, m3, verbose=verbose)
                seq.simplify_mm_nodes()
                seq.clean_unused_nodes()
                onx = seq.to_onnx("Y", "X1", "X2", "X3", dtype=numpy.float32)

                oinf = OnnxInference(onx)
                self.assertEqualArray(res1, oinf.run(make_feeds())['Y'])

                oinf = OnnxInference(onx, runtime="onnxruntime2")
                self.assertEqualArray(res1, oinf.run(make_feeds())['Y'])

                so = SessionOptions()
                so.graph_optimization_level = (
                    GraphOptimizationLevel.ORT_DISABLE_ALL)
                # Explicit providers: since onnxruntime 1.9, builds with
                # several execution providers refuse to create a session
                # without them.
                sess = InferenceSession(onx.SerializeToString(), so,
                                        providers=['CPUExecutionProvider'])
                res2 = sess.run(None, make_feeds())[0]
                self.assertEqualArray(res1, res2)
예제 #12
0
        def local_test(inp1, inp2):
            """Check ``bid,nd->bin`` against numpy in the python and ONNX paths.

            The ONNX graph is run twice: once with both inputs fed, and once
            with *inp2* embedded as the initializer named ``X2``.
            """
            equation = 'bid,nd->bin'
            expected = numpy.einsum(equation, inp1, inp2)
            sequence = decompose_einsum_equation(equation, clean=True,
                                                 strategy='numpy')
            self.assertEqualArray(expected,
                                  apply_einsum_sequence(sequence, inp1, inp2),
                                  decimal=3)

            model = sequence.to_onnx('Y', 'X1', 'X2')
            got = OnnxInference(model).run({'X1': inp1, 'X2': inp2})['Y']
            self.assertEqualArray(expected, got, decimal=3)

            init = numpy_helper.from_array(inp2, name="X2")
            model = sequence.to_onnx('Y', 'X1', 'X2', initializer=[init])
            got = OnnxInference(model).run({'X1': inp1})['Y']
            self.assertEqualArray(expected, got, decimal=3)
예제 #13
0
 def test_decompose_einsum_equation_py_noshape(self):
     """Decompose without shapes; the 'py' matmul must match the default.

     For each strategy the expected operator name must appear in the dot
     graph. With the 'simple' strategy an unknown ``matmul_impl`` must
     raise ValueError.
     """
     m1 = numpy.arange(0, 24).astype(numpy.float32).reshape((2, 3, 4))
     m2 = numpy.arange(0, 20).astype(numpy.float32).reshape((4, 5))
     verbose = False
     expectations = [('numpy', 'batch_dot'), ('simple', 'matmul')]
     for strat, opname in expectations:
         with self.subTest(strategy=strat):
             seq = decompose_einsum_equation("bac,ch->ah", strategy=strat,
                                             verbose=verbose)
             self.assertIn(opname, seq.to_dot())
             default_res = apply_einsum_sequence(seq, m1, m2, verbose=verbose)
             py_res = apply_einsum_sequence(seq, m1, m2, matmul_impl='py',
                                            verbose=verbose)
             if strat == 'simple':
                 # ``seq=seq`` binds the current sequence explicitly
                 # instead of closing over the loop variable.
                 self.assertRaise(
                     lambda seq=seq: apply_einsum_sequence(
                         seq, m1, m2, matmul_impl='py2'),
                     ValueError)
             self.assertEqualArray(default_res, py_res)
예제 #14
0
    def test_decompose_einsum_equation_onnx(self):
        """Check that simplification is required before ONNX export.

        Before ``simplify_mm_nodes`` / ``clean_unused_nodes`` are called,
        ``to_onnx`` must raise NotImplementedError. After them, the ONNX
        graph must reproduce the python result with OnnxInference's default
        runtime and with ``onnxruntime1``.
        """
        m1 = numpy.arange(0, 24).astype(numpy.float32).reshape((2, 3, 4))
        m2 = numpy.arange(0, 20).astype(numpy.float32).reshape((4, 5))
        verbose = False

        def make_feeds():
            # Fresh float32 copies of the inputs for every run.
            return {
                'X1': m1.astype(numpy.float32),
                'X2': m2.astype(numpy.float32)
            }

        for strat in ['numpy']:
            with self.subTest(strategy=strat):
                seq = decompose_einsum_equation("bac,ch->ah", (2, 3, 4),
                                                (4, 5),
                                                strategy=strat,
                                                verbose=verbose)
                expected = apply_einsum_sequence(seq, m1, m2, verbose=verbose)
                self.assertRaise(
                    lambda seq=seq: seq.to_onnx("Y", "X1", "X2",
                                                dtype=numpy.float32),
                    NotImplementedError)
                seq.simplify_mm_nodes()
                seq.clean_unused_nodes()
                onx = seq.to_onnx("Y", "X1", "X2", dtype=numpy.float32)

                oinf = OnnxInference(onx)
                self.assertEqualArray(expected, oinf.run(make_feeds())['Y'])

                oinf = OnnxInference(onx, runtime="onnxruntime1")
                self.assertEqualArray(expected, oinf.run(make_feeds())['Y'])