Ejemplo n.º 1
0
def test_matmul():
    """Check the layouts CorrectLayout assigns around ``sym.matmul``.

    Covers the plain product, a second pass over the already-corrected graph
    with blocked input layouts, and every combination of ``transpose_a`` /
    ``transpose_b``. Each case declares operands whose shapes are consistent
    with the transposes requested, so the symbols are well-formed.
    """
    a = sym.Variable("a", shape=(10, 20))
    b = sym.Variable("b", shape=(20, 30))
    c = sym.matmul(a, b, name="matmul")
    g, ldict = correct_layout(c, {"a": "HW", "b": "WC"})
    assert ldict["a"][0] == "HW"
    assert ldict["b"][0] == "WC"
    assert ldict["matmul"][0] == "HC"
    # Second pass on the corrected graph with blocked input layouts: the pass
    # is expected to insert layout-transform nodes ("a_HW", "b_WC") that bring
    # the inputs back to the layouts matmul was resolved with in the first pass.
    _, ldict = correct_layout(g, {"a": "HW16w", "b": "WC16c"})
    assert ldict["a"][0] == "HW16w"
    assert ldict["a_HW"][0] == "HW"
    assert ldict["b"][0] == "WC16c"
    assert ldict["b_WC"][0] == "WC"
    assert ldict["matmul"][0] == "HC"

    # transpose_a: a is (20, 10) and read transposed, so the output's first
    # axis comes from a's W axis.
    a = sym.Variable("a", shape=(20, 10))
    c = sym.matmul(a, b, name="matmul", transpose_a=True)
    g, ldict = correct_layout(c, {"a": "HW", "b": "HC"})
    assert ldict["a"][0] == "HW"
    assert ldict["b"][0] == "HC"
    assert ldict["matmul"][0] == "WC"

    # transpose_b only: a must be (10, 20) again. The (20, 10) operand from
    # the previous case cannot multiply b^T, which is (20, 30).
    a = sym.Variable("a", shape=(10, 20))
    b = sym.Variable("b", shape=(30, 20))
    c = sym.matmul(a, b, name="matmul", transpose_b=True)
    g, ldict = correct_layout(c, {"a": "HW", "b": "CW"})
    assert ldict["a"][0] == "HW"
    assert ldict["b"][0] == "CW"
    assert ldict["matmul"][0] == "HC"

    # Both operands transposed.
    a = sym.Variable("a", shape=(20, 10))
    b = sym.Variable("b", shape=(30, 20))
    c = sym.matmul(a, b, name="matmul", transpose_a=True, transpose_b=True)
    g, ldict = correct_layout(c, {"a": "HW", "b": "CH"})
    assert ldict["a"][0] == "HW"
    assert ldict["b"][0] == "CH"
    assert ldict["matmul"][0] == "WC"
Ejemplo n.º 2
0
def test_matmul():
    """Verify CorrectLayout's layout assignment for ``sym.matmul``.

    Runs the untransposed product, re-runs the pass on its result with
    blocked layouts (expecting inserted layout-transform nodes), then checks
    transpose_a, transpose_b and both together. Operand shapes are declared
    per case so that every product is shape-consistent.
    """
    a = sym.Variable("a", shape=(10, 20))
    b = sym.Variable("b", shape=(20, 30))
    c = sym.matmul(a, b, name="matmul")
    g, ldict = correct_layout(c, {"a": "HW", "b": "WC"})
    assert ldict["a"][0] == "HW"
    assert ldict["b"][0] == "WC"
    assert ldict["matmul"][0] == "HC"
    # second pass will insert layout transform nodes ("a_HW", "b_WC") that
    # convert the blocked inputs back to the layouts matmul was resolved with
    _, ldict = correct_layout(g, {"a": "HW16w", "b": "WC16c"})
    assert ldict["a"][0] == "HW16w"
    assert ldict["a_HW"][0] == "HW"
    assert ldict["b"][0] == "WC16c"
    assert ldict["b_WC"][0] == "WC"
    assert ldict["matmul"][0] == "HC"

    # transpose_a: a is (20, 10); its W axis becomes the output's first axis.
    a = sym.Variable("a", shape=(20, 10))
    c = sym.matmul(a, b, name="matmul", transpose_a=True)
    g, ldict = correct_layout(c, {"a": "HW", "b": "HC"})
    assert ldict["a"][0] == "HW"
    assert ldict["b"][0] == "HC"
    assert ldict["matmul"][0] == "WC"

    # transpose_b only: redeclare a as (10, 20). The (20, 10) variable above
    # is not compatible with b^T of shape (20, 30).
    a = sym.Variable("a", shape=(10, 20))
    b = sym.Variable("b", shape=(30, 20))
    c = sym.matmul(a, b, name="matmul", transpose_b=True)
    g, ldict = correct_layout(c, {"a": "HW", "b": "CW"})
    assert ldict["a"][0] == "HW"
    assert ldict["b"][0] == "CW"
    assert ldict["matmul"][0] == "HC"

    # transpose_a and transpose_b together.
    a = sym.Variable("a", shape=(20, 10))
    b = sym.Variable("b", shape=(30, 20))
    c = sym.matmul(a, b, name="matmul", transpose_a=True, transpose_b=True)
    g, ldict = correct_layout(c, {"a": "HW", "b": "CH"})
    assert ldict["a"][0] == "HW"
    assert ldict["b"][0] == "CH"
    assert ldict["matmul"][0] == "WC"
Ejemplo n.º 3
0
def test_matmul():
    """Check the output shape that infer_shape reports for ``sym.matmul``.

    Exercises 2-D operands with every transpose_a / transpose_b combination,
    plus 3-D operands with and without both transposes. Every case is
    expected to yield the same output shape as its 2-D or 3-D counterpart.
    """
    def check(lhs, rhs, expected, **flags):
        # Build the product node and compare the shape of its first output.
        product = sym.matmul(lhs, rhs, name="matmul", **flags)
        shapes = infer_shape(product)
        assert shapes["matmul"][0] == expected

    lhs = sym.Variable('a', shape=(10, 20))
    rhs = sym.Variable('b', shape=(20, 30))
    check(lhs, rhs, [10, 30])

    lhs = sym.Variable('a', shape=(20, 10))
    check(lhs, rhs, [10, 30], transpose_a=True)

    rhs = sym.Variable('b', shape=(30, 20))
    check(lhs, rhs, [10, 30], transpose_a=True, transpose_b=True)

    lhs = sym.Variable('a', shape=(10, 20))
    check(lhs, rhs, [10, 30], transpose_b=True)

    # Higher-rank operands: the expected shape keeps every axis of both
    # inputs except the shared length-30 axis.
    lhs = sym.Variable('a', shape=(10, 20, 30))
    rhs = sym.Variable('b', shape=(30, 40, 50))
    check(lhs, rhs, [10, 20, 40, 50])

    lhs = sym.Variable('a', shape=(30, 20, 10))
    rhs = sym.Variable('b', shape=(50, 40, 30))
    check(lhs, rhs, [10, 20, 40, 50], transpose_a=True, transpose_b=True)
Ejemplo n.º 4
0
def test_matmul():
    """Assert infer_shape results for ``sym.matmul`` across transpose modes.

    The first four cases use 2-D operands (plain, transpose_a, both,
    transpose_b) and all expect [10, 30]. The last two use 3-D operands
    (plain and both transposed) and expect [10, 20, 40, 50].
    """
    def matmul_shape(node):
        # Shape of the first output of the node named "matmul".
        inferred = infer_shape(node)
        return inferred["matmul"][0]

    x = sym.Variable('a', shape=(10, 20))
    y = sym.Variable('b', shape=(20, 30))
    assert matmul_shape(sym.matmul(x, y, name="matmul")) == [10, 30]

    x = sym.Variable('a', shape=(20, 10))
    assert matmul_shape(
        sym.matmul(x, y, name="matmul", transpose_a=True)) == [10, 30]

    y = sym.Variable('b', shape=(30, 20))
    assert matmul_shape(
        sym.matmul(x, y, name="matmul", transpose_a=True, transpose_b=True)
    ) == [10, 30]

    x = sym.Variable('a', shape=(10, 20))
    assert matmul_shape(
        sym.matmul(x, y, name="matmul", transpose_b=True)) == [10, 30]

    x = sym.Variable('a', shape=(10, 20, 30))
    y = sym.Variable('b', shape=(30, 40, 50))
    assert matmul_shape(sym.matmul(x, y, name="matmul")) == [10, 20, 40, 50]

    x = sym.Variable('a', shape=(30, 20, 10))
    y = sym.Variable('b', shape=(50, 40, 30))
    assert matmul_shape(
        sym.matmul(x, y, name="matmul", transpose_a=True, transpose_b=True)
    ) == [10, 20, 40, 50]