def test_matmul():
    """Check layout correction for ``sym.matmul``.

    Covers the plain product and every combination of ``transpose_a`` /
    ``transpose_b``.  In each case the declared operand shapes agree with
    the requested layouts, so every graph is a well-formed matrix product.
    """
    # a @ b: a is HW (10, 20), b is WC (20, 30) -> output HC.
    a = sym.Variable("a", shape=(10, 20))
    b = sym.Variable("b", shape=(20, 30))
    c = sym.matmul(a, b, name="matmul")
    g, ldict = correct_layout(c, {"a": "HW", "b": "WC"})
    assert ldict["a"][0] == "HW"
    assert ldict["b"][0] == "WC"
    assert ldict["matmul"][0] == "HC"

    # Second pass will insert layout transforms: with blocked input layouts,
    # the asserts expect nodes "a_HW" / "b_WC" carrying the plain layouts
    # matmul consumes, while the output stays HC.
    _, ldict = correct_layout(g, {"a": "HW16w", "b": "WC16c"})
    assert ldict["a"][0] == "HW16w"
    assert ldict["a_HW"][0] == "HW"
    assert ldict["b"][0] == "WC16c"
    assert ldict["b_WC"][0] == "WC"
    assert ldict["matmul"][0] == "HC"

    # a^T @ b: a is HW (20, 10), so a^T is (10, 20); b keeps (20, 30) as HC.
    a = sym.Variable("a", shape=(20, 10))
    c = sym.matmul(a, b, name="matmul", transpose_a=True)
    _, ldict = correct_layout(c, {"a": "HW", "b": "HC"})
    assert ldict["a"][0] == "HW"
    assert ldict["b"][0] == "HC"
    assert ldict["matmul"][0] == "WC"

    # a @ b^T: b is CW (30, 20), so b^T is (20, 30).  a must be (10, 20)
    # here so its inner dimension matches b^T's leading dimension.
    a = sym.Variable("a", shape=(10, 20))
    b = sym.Variable("b", shape=(30, 20))
    c = sym.matmul(a, b, name="matmul", transpose_b=True)
    _, ldict = correct_layout(c, {"a": "HW", "b": "CW"})
    assert ldict["a"][0] == "HW"
    assert ldict["b"][0] == "CW"
    assert ldict["matmul"][0] == "HC"

    # a^T @ b^T: (20, 10)^T @ (30, 20)^T == (10, 20) @ (20, 30) -> WC.
    a = sym.Variable("a", shape=(20, 10))
    b = sym.Variable("b", shape=(30, 20))
    c = sym.matmul(a, b, name="matmul", transpose_a=True, transpose_b=True)
    _, ldict = correct_layout(c, {"a": "HW", "b": "CH"})
    assert ldict["a"][0] == "HW"
    assert ldict["b"][0] == "CH"
    assert ldict["matmul"][0] == "WC"
def test_matmul():
    """Verify the layouts ``correct_layout`` assigns around ``sym.matmul``.

    Exercises the untransposed product (including a second pass that must
    insert layout transforms) and each ``transpose_a`` / ``transpose_b``
    combination.  Operand shapes are declared consistently with the layouts
    under test so each graph describes a valid matrix product.
    """
    def check(ldict, expected):
        # Compare each node's first layout entry with the expected layout,
        # naming the offending node on failure.
        for node, layout in expected.items():
            actual = ldict[node][0]
            assert actual == layout, "%s: got %s, expected %s" % (node, actual, layout)

    # Plain product: (10, 20) HW @ (20, 30) WC -> HC.
    a = sym.Variable("a", shape=(10, 20))
    b = sym.Variable("b", shape=(20, 30))
    c = sym.matmul(a, b, name="matmul")
    g, ldict = correct_layout(c, {"a" : "HW", "b" : "WC"})
    check(ldict, {"a": "HW", "b": "WC", "matmul": "HC"})

    # second pass will insert layout transform
    # The asserts expect transform nodes "a_HW" / "b_WC" converting the
    # blocked input layouts back to the plain ones matmul consumes.
    _, ldict = correct_layout(g, {"a" : "HW16w", "b" : "WC16c"})
    check(ldict, {"a": "HW16w", "a_HW": "HW",
                  "b": "WC16c", "b_WC": "WC",
                  "matmul": "HC"})

    # transpose_a: (20, 10)^T @ (20, 30) -> WC.
    a = sym.Variable("a", shape=(20, 10))
    c = sym.matmul(a, b, name="matmul", transpose_a=True)
    _, ldict = correct_layout(c, {"a" : "HW", "b" : "HC"})
    check(ldict, {"a": "HW", "b": "HC", "matmul": "HC" if False else "WC"})

    # transpose_b: (10, 20) @ (30, 20)^T -> HC.  a is re-declared as
    # (10, 20) so its inner dimension matches b^T's leading dimension.
    a = sym.Variable("a", shape=(10, 20))
    b = sym.Variable("b", shape=(30, 20))
    c = sym.matmul(a, b, name="matmul", transpose_b=True)
    _, ldict = correct_layout(c, {"a" : "HW", "b" : "CW"})
    check(ldict, {"a": "HW", "b": "CW", "matmul": "HC"})

    # transpose_a and transpose_b: (20, 10)^T @ (30, 20)^T -> WC.
    a = sym.Variable("a", shape=(20, 10))
    b = sym.Variable("b", shape=(30, 20))
    c = sym.matmul(a, b, name="matmul", transpose_a=True, transpose_b=True)
    _, ldict = correct_layout(c, {"a" : "HW", "b" : "CH"})
    check(ldict, {"a": "HW", "b": "CH", "matmul": "WC"})
def test_matmul():
    """Check the output shapes ``infer_shape`` reports for ``sym.matmul``.

    Covers 2-D operands under every ``transpose_a`` / ``transpose_b``
    combination, plus 3-D operands.  For the 3-D cases the expected shapes
    imply that the last axis of the (possibly transposed) left operand is
    contracted with the first axis of the (possibly transposed) right one.
    """
    def matmul_shape(lhs, rhs, **flags):
        # Build the matmul node and return the inferred shape of its output.
        node = sym.matmul(lhs, rhs, name="matmul", **flags)
        return infer_shape(node)["matmul"][0]

    # 2-D: every variant below multiplies a (10, 20) by a (20, 30) matrix.
    lhs = sym.Variable('a', shape=(10, 20))
    rhs = sym.Variable('b', shape=(20, 30))
    assert matmul_shape(lhs, rhs) == [10, 30]

    lhs = sym.Variable('a', shape=(20, 10))
    assert matmul_shape(lhs, rhs, transpose_a=True) == [10, 30]

    rhs = sym.Variable('b', shape=(30, 20))
    assert matmul_shape(lhs, rhs, transpose_a=True, transpose_b=True) == [10, 30]

    lhs = sym.Variable('a', shape=(10, 20))
    assert matmul_shape(lhs, rhs, transpose_b=True) == [10, 30]

    # 3-D: the contracted axis has length 30 in every case.
    lhs = sym.Variable('a', shape=(10, 20, 30))
    rhs = sym.Variable('b', shape=(30, 40, 50))
    assert matmul_shape(lhs, rhs) == [10, 20, 40, 50]

    lhs = sym.Variable('a', shape=(30, 20, 10))
    rhs = sym.Variable('b', shape=(50, 40, 30))
    assert matmul_shape(lhs, rhs, transpose_a=True, transpose_b=True) == [10, 20, 40, 50]