Esempio n. 1
0
def test_pick_partial_shape():
    """Check partial shape inference of ``pick`` (axis=1) with an unknown batch size."""
    data_sym = mx.sym.Variable("x")
    index_sym = mx.sym.Variable("index")
    picked = mx.sym.pick(data_sym, index_sym, axis=1)

    # Outside np_compat the unknown batch dimension is passed as 0.
    legacy_shapes = {"x": (0, 3, 3), "index": (0, 3)}
    _args, out_shapes, _aux = picked.infer_shape_partial(**legacy_shapes)
    assert out_shapes == [(0, 3)]

    # Inside np_compat the same unknown dimension is passed as -1.
    with mx.np_compat(True):
        compat_shapes = {"x": (-1, 3, 3), "index": (-1, 3)}
        _args, out_shapes, _aux = picked.infer_shape_partial(**compat_shapes)
        assert out_shapes == [(-1, 3)]
Esempio n. 2
0
def test_embedding_partial_shape():
    """Check partial shape inference of ``Embedding`` with an unknown batch size."""
    indices = mx.sym.Variable("x")
    table = mx.sym.Variable("w")
    embedded = mx.sym.Embedding(
        data=indices,
        weight=table,
        input_dim=100,
        output_dim=10,
    )

    # Outside np_compat the unknown batch dimension is passed as 0.
    _args, out_shapes, _aux = embedded.infer_shape_partial(x=(0, 5), w=(100, 10))
    assert out_shapes == [(0, 5, 10)]

    # Inside np_compat the unknown batch dimension is passed as -1.
    with mx.np_compat(True):
        _args, out_shapes, _aux = embedded.infer_shape_partial(x=(-1, 5), w=(100, 10))
        assert out_shapes == [(-1, 5, 10)]
Esempio n. 3
0
def test_dot_partial_shape():
    """Check partial shape inference of ``dot`` when the lhs first dimension is unknown."""
    lhs = mx.sym.Variable("x")
    rhs = mx.sym.Variable("y")
    product = mx.sym.dot(lhs, rhs)

    # First dimension of the lhs is unknown, passed as 0 outside np_compat.
    legacy_shapes = {"x": (0, 3, 4), "y": (4, 5)}
    _args, out_shapes, _aux = product.infer_shape_partial(**legacy_shapes)
    assert out_shapes == [(0, 3, 5)]

    # Same case inside np_compat, where the unknown dimension is passed as -1.
    with mx.np_compat(True):
        compat_shapes = {"x": (-1, 3, 4), "y": (4, 5)}
        _args, out_shapes, _aux = product.infer_shape_partial(**compat_shapes)
        assert out_shapes == [(-1, 3, 5)]
Esempio n. 4
0
def test_shape_completely_unknown():
    """Check partial inference of ``sin(data)`` when no input shape is supplied.

    Outside np_compat both the argument and output shapes are expected to be
    the empty tuple; inside np_compat they are expected to be ``None``.
    """
    source = mx.sym.var("data")
    sine = mx.sym.sin(source)
    in_shapes, result_shapes, _aux = sine.infer_shape_partial()
    assert in_shapes[0] == ()
    assert result_shapes[0] == ()

    # The symbols are rebuilt inside the np_compat scope before inferring again.
    with mx.np_compat():
        source = mx.sym.var("data")
        sine = mx.sym.sin(source)
        in_shapes, result_shapes, _aux = sine.infer_shape_partial()
        assert in_shapes[0] is None
        assert result_shapes[0] is None
Esempio n. 5
0
def test_shape_completely_unknown():
    """Infer shapes for ``sin(data)`` without providing any shape information.

    Expects ``()`` for the first argument and output shape outside np_compat,
    and ``None`` for both inside np_compat.
    """
    placeholder = mx.sym.var("data")
    op = mx.sym.sin(placeholder)
    inferred_args, inferred_outs, _aux = op.infer_shape_partial()
    assert inferred_args[0] == ()
    assert inferred_outs[0] == ()

    # Build fresh symbols within the np_compat scope and repeat the inference.
    with mx.np_compat():
        placeholder = mx.sym.var("data")
        op = mx.sym.sin(placeholder)
        inferred_args, inferred_outs, _aux = op.infer_shape_partial()
        assert inferred_args[0] is None
        assert inferred_outs[0] is None
Esempio n. 6
0
def test_transpose_partial_shape():
    """Check partial shape inference of ``transpose`` with an unknown batch size.

    The permutation converts a channels-first tensor shape to channels-last.
    """
    permutation = [0, 3, 2, 1]
    tensor = mx.sym.Variable("x")
    transposed = mx.sym.transpose(tensor, axes=permutation)

    # Outside np_compat the unknown batch dimension is passed as 0.
    _args, out_shapes, _aux = transposed.infer_shape_partial(x=(0, 3, 224, 224))
    assert out_shapes == [(0, 224, 224, 3)]

    # Inside np_compat the unknown batch dimension is passed as -1.
    with mx.np_compat(True):
        _args, out_shapes, _aux = transposed.infer_shape_partial(x=(-1, 3, 224, 224))
        assert out_shapes == [(-1, 224, 224, 3)]
Esempio n. 7
0
def test_where_partial_shape():
    """Check that ``where`` yields no output shape when the condition shape is partial."""
    then_sym = mx.sym.Variable("x")
    else_sym = mx.sym.Variable("y")
    cond_sym = mx.sym.Variable("cond")
    selector = mx.sym.where(cond_sym, then_sym, else_sym)

    # The condition must be fully known to infer a shape; outside np_compat an
    # uninferred output is expected as the empty tuple.
    _args, out_shapes, _aux = selector.infer_shape_partial(
        cond=(0, 2), x=(0, 2), y=(0, 2)
    )
    assert out_shapes == [()]
    _args, out_shapes, _aux = selector.infer_shape_partial(
        cond=(0,), x=(2, 2), y=(2, 2)
    )
    assert out_shapes == [()]

    # Inside np_compat an uninferred output is expected as None.
    with mx.np_compat(True):
        _args, out_shapes, _aux = selector.infer_shape_partial(
            cond=(-1, 2), x=(-1, 2), y=(-1, 2)
        )
        assert out_shapes == [None]
        _args, out_shapes, _aux = selector.infer_shape_partial(
            cond=(-1,), x=(2, 2), y=(2, 2)
        )
        assert out_shapes == [None]
Esempio n. 8
0
def test_batch_dot_partial_shape():
    """Check partial shape inference of ``batch_dot`` with unknown dimensions."""
    lhs = mx.sym.Variable("x")
    rhs = mx.sym.Variable("y")
    product = mx.sym.batch_dot(lhs, rhs)

    # Batch size unknown on both operands: the remaining dims are still inferred.
    _args, out_shapes, _aux = product.infer_shape_partial(x=(0, 3, 4), y=(0, 4, 5))
    assert out_shapes == [(0, 3, 5)]

    # Second dimension of the rhs also unknown: no output shape is expected.
    _args, out_shapes, _aux = product.infer_shape_partial(x=(0, 3, 4), y=(0, 0, 5))
    assert out_shapes == [()]

    # Same two cases inside np_compat, with -1 for unknown dims and None for
    # an uninferred output.
    with mx.np_compat(True):
        _args, out_shapes, _aux = product.infer_shape_partial(x=(-1, 3, 4), y=(-1, 4, 5))
        assert out_shapes == [(-1, 3, 5)]
        _args, out_shapes, _aux = product.infer_shape_partial(x=(-1, 3, 4), y=(-1, -1, 5))
        assert out_shapes == [None]