def optimize_compare(self, equation, operands=None, verbose=False):
    """Compare every einsum decomposition against :func:`numpy.einsum`.

    The reference result is computed once with :func:`numpy.einsum`.
    The equation is then decomposed with ``clean=False`` and
    ``clean=True``, applied with strategies ``'numpy'`` and ``'simple'``
    and, for ``clean=True`` only, exported to ONNX and executed with
    :class:`OnnxInference`.

    :param equation: einsum equation, e.g. ``"ab,bc->ac"``
    :param operands: input arrays. When None, one float32 array is built
        per operand of *equation*, with every dimension of size 2 and
        values ``arange(2**k) + 3**d`` for the d-th operand (k being the
        number of letters of that operand).
    :param verbose: verbosity forwarded to the decomposition and to the
        sequence application; when truthy, the numpy einsum paths are
        printed too
    """
    if operands is not None:
        inputs = operands
    else:
        eqs = equation.split("->")[0].split(",")
        inputs = []
        for d, eq in enumerate(eqs):
            base = numpy.arange(2 ** len(eq)).reshape(
                (2, ) * len(eq)).astype(numpy.float32)
            inputs.append(base + numpy.array([3 ** d], dtype=numpy.float32))

    # Nothing below depends on ``clean``: compute (and print) it once
    # instead of once per iteration of the loop further down.
    exp = numpy.einsum(equation, *inputs)
    if verbose:
        print("###### equation", equation)
        path = numpy.einsum_path(equation, *inputs, optimize=False)
        print(path[1])
        path = numpy.einsum_path(equation, *inputs)
        print(path[1])
    shapes = [m.shape for m in inputs]
    # Debugging hook: maximum verbosity for one specific equation.
    vv = 12 if equation == ",a,ab,abc->abc" else verbose
    input_names = ['X%d' % (k + 1) for k in range(len(inputs))]

    for clean in [False, True]:
        # ``clean`` is recorded so a failure tells which variant broke.
        with self.subTest(equation=equation, clean=clean):
            with self.subTest(strategy='numpy'):
                seq = decompose_einsum_equation(
                    equation, *shapes, verbose=verbose,
                    strategy='numpy', clean=clean)
                got = apply_einsum_sequence(seq, *inputs, verbose=vv)
                self.assertEqualArray(exp, got, decimal=6)
                if clean:
                    with self.subTest(strategy='onnx'):
                        try:
                            onx = seq.to_onnx(
                                'Y', *input_names, dtype=numpy.float32)
                        except NotImplementedError as e:
                            # Conversions mentioning "diagonal" are not
                            # supported: skip the ONNX check for them and
                            # re-raise anything else with its traceback.
                            if "diagonal" not in str(e):
                                raise
                            onx = None
                        if onx is not None:
                            oinf = OnnxInference(onx)
                            feeds = {
                                n: v.astype(numpy.float32)
                                for n, v in zip(input_names, inputs)
                            }
                            got = oinf.run(
                                feeds, verbose=vv, fLOG=print)['Y']
                            self.assertEqualArray(exp, got, decimal=5)

            with self.subTest(strategy='simple'):
                seq = decompose_einsum_equation(
                    equation, *shapes, clean=clean, verbose=verbose)
                got = apply_einsum_sequence(seq, *inputs, verbose=verbose)
                self.assertEqualArray(exp, got, decimal=6)
def test_many_2(self):
    """Check many permuted 3-d x 2-d einsum equations against numpy.

    The candidate equations permute the letters ``abc`` for the first
    operand and ``cd`` for the second, and keep three distinct output
    letters. Equations numpy rejects with ``ValueError`` are skipped.
    Each remaining equation is checked with the default decomposition,
    with ``strategy='numpy', clean=True``, and through the ONNX export of
    the latter.
    """
    m1 = numpy.arange(2 * 2 * 2).reshape((2, 2, 2)) + 10
    m2 = numpy.arange(4).reshape((2, 2)) + 100

    # (equation, expected) pairs. This list is kept under its own name:
    # the results computed below must not rebind it, otherwise
    # ``len(cases)`` would no longer be the number of cases.
    cases = []
    for p1 in itertools.permutations(list("abc")):
        sp1 = "".join(p1)
        for p2 in itertools.permutations(list("cd")):
            sp2 = "".join(p2)
            for i in [1, 2]:
                for j in [0, 1]:
                    if len({sp1[0], sp1[i], sp2[j]}) != 3:
                        continue
                    equation = "%s,%s->%s%s%s" % (
                        sp1, sp2, sp1[0], sp1[i], sp2[j])
                    try:
                        r = numpy.einsum(equation, m1, m2)
                    except ValueError:
                        # Not viable equation.
                        continue
                    cases.append((equation, r))

    for i, (eq, exp) in enumerate(cases):
        with self.subTest(equation=eq, index=i, total=len(cases)):
            # Debugging hook: raise verbosity for one specific equation.
            verbose = 12 if eq == ',abc,dc->acd' else 0
            if verbose:
                print(
                    '\n########################################clean=False'
                )
                print("#########0", eq)
            seq = decompose_einsum_equation(
                eq, m1.shape, m2.shape, verbose=verbose)
            got = apply_einsum_sequence(seq, m1, m2, verbose=verbose)
            self.assertEqualArray(exp, got)

            if verbose:
                print(
                    '\n########################################clean=True')
                print("#########1", eq)
            seq = decompose_einsum_equation(
                eq, m1.shape, m2.shape, strategy='numpy', clean=True,
                verbose=verbose)
            got = apply_einsum_sequence(seq, m1, m2, verbose=verbose)
            self.assertEqualArray(exp, got)

            onx = seq.to_onnx('Y', 'X1', 'X2', dtype=numpy.float32)
            oinf = OnnxInference(onx)
            res2 = oinf.run(
                {
                    'X1': m1.astype(numpy.float32),
                    'X2': m2.astype(numpy.float32)
                },
                verbose=verbose,
                fLOG=print)
            self.assertEqualArray(exp, res2['Y'])
def test_decompose_einsum_equation_exc(self):
    """Invalid arguments to ``decompose_einsum_equation`` must raise.

    The four cases cover: an unknown strategy passed by keyword
    (``ValueError``), an extra positional string argument
    (``TypeError``), a missing shape, and a shape whose rank does not
    match its operand (both ``ValueError``).
    """
    eq = "abc,ch->ah"
    failing_calls = [
        (lambda: decompose_einsum_equation(
            eq, (2, 2, 2), (2, 2), strategy="donotexist"), ValueError),
        (lambda: decompose_einsum_equation(
            eq, (2, 2, 2), (2, 2), "donotexist"), TypeError),
        (lambda: decompose_einsum_equation(eq, (2, 2, 2)), ValueError),
        (lambda: decompose_einsum_equation(eq, (2, 2), (2, 2)), ValueError),
    ]
    for call, expected_exc in failing_calls:
        self.assertRaise(call, expected_exc)
def test_bdn_in_bdi(self):
    """Check ``bdn,in->bdi`` with the numpy strategy and its ONNX export.

    The ONNX graph must contain no ``Transpose`` node, and both the
    default :class:`OnnxInference` runtime and ``onnxruntime1`` must
    reproduce :func:`numpy.einsum`. Every ``batch_dot`` node of the
    sequence is expected to report dot kind ``"11"``.
    """
    equation = "bdn,in->bdi"
    seq = decompose_einsum_equation(equation, strategy='numpy', clean=True)
    inp1 = numpy.arange(2 * 3 * 5).reshape((2, 3, 5)).astype(numpy.float32)
    inp2 = numpy.arange(5 * 7).reshape((7, 5)).astype(numpy.float32)
    exp = numpy.einsum(equation, inp1, inp2)
    got = apply_einsum_sequence(seq, inp1, inp2)
    self.assertEqualArray(exp, got)

    onx = seq.to_onnx("Y", "X1", "X2")
    self.assertNotIn('Transpose', str(onx))
    # Both inputs are already float32: no conversion needed.
    feeds = {'X1': inp1, 'X2': inp2}

    # Previously this result was computed and thrown away unchecked.
    oinf = OnnxInference(onx)
    got = oinf.run(feeds)['Y']
    self.assertEqualArray(exp, got)

    oinf = OnnxInference(onx, runtime='onnxruntime1')
    got = oinf.run(feeds)['Y']
    self.assertEqualArray(exp, got)

    for op in seq:
        if op.name == 'batch_dot':
            kind = op.get_dot_kind()
            self.assertEqual(kind, "11")
def test_decompose_einsum_equation_pyf(self):
    """Applying a sequence with ``matmul_impl='pyf'`` matches the default.

    ``bac,ch->ah`` is decomposed for shapes ``(2, 2, 2)`` and ``(2, 2)``.
    """
    left = numpy.arange(0, 8).astype(numpy.float32).reshape((2, 2, 2))
    right = numpy.arange(0, 4).astype(numpy.float32).reshape((2, 2))
    sequence = decompose_einsum_equation("bac,ch->ah", (2, 2, 2), (2, 2))
    default_res = apply_einsum_sequence(sequence, left, right)
    pyf_res = apply_einsum_sequence(sequence, left, right, matmul_impl='pyf')
    self.assertEqualArray(default_res, pyf_res)
def test_case_1_iii_ii_i(self):
    """Decomposing ``ii->i`` on a 2x2 matrix matches :func:`numpy.einsum`."""
    verbose = False
    equation = 'ii->i'
    matrix = numpy.arange(2 * 2).reshape((2, 2)) + 10
    expected = numpy.einsum(equation, matrix)
    sequence = decompose_einsum_equation(
        equation, matrix.shape, verbose=verbose)
    self.assertEqualArray(
        expected, apply_einsum_sequence(sequence, matrix, verbose=verbose))
def test_case_1_iii_ii_i_j(self):
    """Decomposing ``iij->ij`` on a 2x2x2 tensor matches numpy.

    The dot rendering of the sequence must also contain the text
    ``"i=0,1"``.
    """
    verbose = False
    equation = 'iij->ij'
    tensor = numpy.arange(2 * 2 * 2).reshape((2, 2, 2)) + 10
    expected = numpy.einsum(equation, tensor)
    sequence = decompose_einsum_equation(
        equation, tensor.shape, verbose=verbose)
    self.assertIn("i=0,1", sequence.to_dot())
    self.assertEqualArray(
        expected, apply_einsum_sequence(sequence, tensor, verbose=verbose))
def common_test_case_2(self, equation, verbose=False, strategy='simple'):
    """Check a two-operand *equation* on fixed 2x2x2 and 2x2 integer inputs.

    :param equation: einsum equation taking a 3-d and a 2-d operand
    :param verbose: verbosity forwarded to decomposition and application
    :param strategy: decomposition strategy passed to
        ``decompose_einsum_equation``
    """
    first = numpy.arange(2 * 2 * 2).reshape((2, 2, 2)) + 10
    second = numpy.arange(4).reshape((2, 2)) + 100
    expected = numpy.einsum(equation, first, second)
    sequence = decompose_einsum_equation(
        equation, first.shape, second.shape,
        verbose=verbose, strategy=strategy)
    actual = apply_einsum_sequence(sequence, first, second, verbose=verbose)
    self.assertEqualArray(expected, actual)
def fct():
    """Verbosely decompose ``bac,ch->ah``, check its dot graph, apply it.

    Uses ``self``, ``m1`` and ``m2`` from the enclosing scope, which is
    not visible here.
    """
    print("########################## DECOMPOSE")
    sequence = decompose_einsum_equation(
        "bac,ch->ah", (2, 2, 2), (2, 2), verbose=True)
    print("########################## APPLY")
    graph = sequence.to_dot()
    print(graph)
    # Splitting yields one more piece than there are occurrences of 'red'.
    self.assertEqual(len(graph.split('red')), 5)
    result = apply_einsum_sequence(sequence, m1, m2, verbose=True)
    print("########################## END")
    return result
def fct():
    """Verbosely decompose ``bac,chg->ah`` (numpy strategy, cleaned).

    Checks the dot graph, applies the sequence, and verifies that the
    ONNX export is not empty. Uses ``self``, ``m1`` and ``m2`` from the
    enclosing scope, which is not visible here.
    """
    print("########################## DECOMPOSE")
    sequence = decompose_einsum_equation(
        "bac,chg->ah", (2, 2, 2), (2, 2, 2),
        verbose=True, clean=True, strategy='numpy')
    print("########################## APPLY")
    graph = sequence.to_dot()
    print(graph)
    # Splitting yields one more piece than there are occurrences of 'red'.
    self.assertEqual(len(graph.split('red')), 6)
    result = apply_einsum_sequence(sequence, m1, m2, verbose=True)
    print("########################## END")
    exported = sequence.to_onnx('Y', 'X1', 'X2', verbose=True)
    self.assertNotEmpty(exported)
    return result
def test_decompose_einsum_equation_onnx2(self):
    """Check ``bac,cd,def->ebc`` through ONNX with three runtimes.

    After ``simplify_mm_nodes`` and ``clean_unused_nodes``, the ONNX export
    is run with the default :class:`OnnxInference` runtime, with
    ``onnxruntime2``, and directly with an onnxruntime
    :class:`InferenceSession` with every graph optimization disabled.
    Each result must match :func:`apply_einsum_sequence`.
    """
    m1 = numpy.arange(0, 24).astype(numpy.float32).reshape((2, 3, 4))
    m2 = numpy.arange(0, 20).astype(numpy.float32).reshape((4, 5))
    m3 = numpy.arange(0, 77 * 5).astype(numpy.float32).reshape((5, 7, 11))
    verbose = False
    # The inputs are already float32: one shared feed dict is enough.
    feeds = {'X1': m1, 'X2': m2, 'X3': m3}
    for strat in ['numpy']:
        with self.subTest(strategy=strat):
            seq = decompose_einsum_equation(
                "bac,cd,def->ebc", (2, 3, 4), (4, 5), (5, 7, 11),
                strategy=strat, verbose=verbose)
            res1 = apply_einsum_sequence(seq, m1, m2, m3, verbose=verbose)
            seq.simplify_mm_nodes()
            seq.clean_unused_nodes()
            onx = seq.to_onnx("Y", "X1", "X2", "X3", dtype=numpy.float32)

            oinf = OnnxInference(onx)
            self.assertEqualArray(res1, oinf.run(feeds)['Y'])

            oinf = OnnxInference(onx, runtime="onnxruntime2")
            self.assertEqualArray(res1, oinf.run(feeds)['Y'])

            so = SessionOptions()
            so.graph_optimization_level = (
                GraphOptimizationLevel.ORT_DISABLE_ALL)
            sess = InferenceSession(onx.SerializeToString(), so)
            self.assertEqualArray(res1, sess.run(None, feeds)[0])
def local_test(inp1, inp2):
    """Check ``bid,nd->bin`` via sequence application and two ONNX exports.

    In the first export both inputs are fed at run time. In the second,
    ``X2`` is embedded as an initializer built from *inp2*. Uses ``self``
    from the enclosing scope, which is not visible here.
    """
    equation = 'bid,nd->bin'
    expected = numpy.einsum(equation, inp1, inp2)
    sequence = decompose_einsum_equation(
        equation, clean=True, strategy='numpy')
    self.assertEqualArray(
        expected, apply_einsum_sequence(sequence, inp1, inp2), decimal=3)

    graph = sequence.to_onnx('Y', 'X1', 'X2')
    outputs = OnnxInference(graph).run({'X1': inp1, 'X2': inp2})
    self.assertEqualArray(expected, outputs['Y'], decimal=3)

    constant = numpy_helper.from_array(inp2, name="X2")
    graph = sequence.to_onnx('Y', 'X1', 'X2', initializer=[constant])
    outputs = OnnxInference(graph).run({'X1': inp1})
    self.assertEqualArray(expected, outputs['Y'], decimal=3)
def test_decompose_einsum_equation_py_noshape(self):
    """Decompose ``bac,ch->ah`` without shapes, for two strategies.

    For each strategy, the dot graph must mention the expected operator,
    and ``matmul_impl='py'`` must match the default implementation.
    For ``'simple'``, an unknown ``matmul_impl`` must raise ``ValueError``.
    """
    m1 = numpy.arange(0, 24).astype(numpy.float32).reshape((2, 3, 4))
    m2 = numpy.arange(0, 20).astype(numpy.float32).reshape((4, 5))
    verbose = False
    expected_ops = [('numpy', 'batch_dot'), ('simple', 'matmul')]
    for strategy, op_name in expected_ops:
        with self.subTest(strategy=strategy):
            seq = decompose_einsum_equation(
                "bac,ch->ah", strategy=strategy, verbose=verbose)
            self.assertIn(op_name, seq.to_dot())
            default_res = apply_einsum_sequence(
                seq, m1, m2, verbose=verbose)
            py_res = apply_einsum_sequence(
                seq, m1, m2, matmul_impl='py', verbose=verbose)
            if strategy == 'simple':
                self.assertRaise(
                    lambda: apply_einsum_sequence(
                        seq, m1, m2, matmul_impl='py2'),  # pylint: disable=W0640
                    ValueError)
            self.assertEqualArray(default_res, py_res)
def test_decompose_einsum_equation_onnx(self):
    """Check the ONNX export of ``bac,ch->ah`` (numpy strategy).

    Before ``simplify_mm_nodes``, the export must raise
    ``NotImplementedError``. After simplification and cleanup, the export
    must reproduce :func:`apply_einsum_sequence` with the default
    :class:`OnnxInference` runtime and with ``onnxruntime1``.
    """
    m1 = numpy.arange(0, 24).astype(numpy.float32).reshape((2, 3, 4))
    m2 = numpy.arange(0, 20).astype(numpy.float32).reshape((4, 5))
    verbose = False
    # The inputs are already float32: one shared feed dict is enough.
    feeds = {'X1': m1, 'X2': m2}
    for strat in ['numpy']:
        with self.subTest(strategy=strat):
            seq = decompose_einsum_equation(
                "bac,ch->ah", (2, 3, 4), (4, 5),
                strategy=strat, verbose=verbose)
            res1 = apply_einsum_sequence(seq, m1, m2, verbose=verbose)
            self.assertRaise(
                lambda: seq.to_onnx(  # pylint: disable=W0640
                    "Y", "X1", "X2", dtype=numpy.float32),
                NotImplementedError)
            seq.simplify_mm_nodes()
            seq.clean_unused_nodes()
            onx = seq.to_onnx("Y", "X1", "X2", dtype=numpy.float32)

            oinf = OnnxInference(onx)
            self.assertEqualArray(res1, oinf.run(feeds)['Y'])

            oinf = OnnxInference(onx, runtime="onnxruntime1")
            self.assertEqualArray(res1, oinf.run(feeds)['Y'])