def test_embedding_partial_shape():
    """Embedding shape inference should carry an unknown batch size through.

    The unknown batch dimension is written as ``0`` under the legacy shape
    semantics and as ``-1`` once NumPy shape semantics are enabled with
    ``mx.np_shape``.
    """
    indices = mx.sym.Variable("x")
    table = mx.sym.Variable("w")
    embedded = mx.sym.Embedding(data=indices, weight=table,
                                input_dim=100, output_dim=10)

    # Legacy semantics: 0 marks the unknown batch size.
    _, out_shapes, _ = embedded.infer_shape_partial(x=(0, 5), w=(100, 10))
    assert out_shapes == [(0, 5, 10)]

    # NumPy semantics: -1 marks the unknown batch size.
    with mx.np_shape(True):
        _, out_shapes, _ = embedded.infer_shape_partial(x=(-1, 5),
                                                        w=(100, 10))
        assert out_shapes == [(-1, 5, 10)]
def test_dot_partial_shape():
    """``dot`` should infer its output even when the lhs batch size is unknown."""
    lhs = mx.sym.Variable("x")
    rhs = mx.sym.Variable("y")
    product = mx.sym.dot(lhs, rhs)

    # Legacy semantics: the leading 0 on the lhs is an unknown dimension.
    _, out_shapes, _ = product.infer_shape_partial(x=(0, 3, 4), y=(4, 5))
    assert out_shapes == [(0, 3, 5)]

    # NumPy semantics: the same unknown dimension is spelled -1.
    with mx.np_shape(True):
        _, out_shapes, _ = product.infer_shape_partial(x=(-1, 3, 4),
                                                       y=(4, 5))
        assert out_shapes == [(-1, 3, 5)]
def test_transpose_partial_shape():
    """Transpose a channels-first tensor to channels-last with batch unknown.

    The permutation swaps axes 1 and 3, which moves the channel axis to the
    end while the unknown leading (batch) dimension stays in front.
    """
    perm = [0, 3, 2, 1]
    data = mx.sym.Variable("x")
    transposed = mx.sym.transpose(data, axes=perm)

    # Legacy semantics: 0 marks the unknown batch size.
    _, out_shapes, _ = transposed.infer_shape_partial(x=(0, 3, 224, 224))
    assert out_shapes == [(0, 224, 224, 3)]

    # NumPy semantics: -1 marks the unknown batch size.
    with mx.np_shape(True):
        _, out_shapes, _ = transposed.infer_shape_partial(
            x=(-1, 3, 224, 224))
        assert out_shapes == [(-1, 224, 224, 3)]
def test_shape_completely_unknown():
    """A graph whose input shape is never supplied must stay fully unknown.

    Legacy semantics report a completely unknown shape as ``()``; with NumPy
    shape semantics enabled the same situation is reported as ``None``.
    """
    data = mx.sym.var("data")
    sine = mx.sym.sin(data)
    in_shapes, out_shapes, _ = sine.infer_shape_partial()
    assert in_shapes[0] == ()
    assert out_shapes[0] == ()

    # ``mx.np_shape()`` and ``mx.np_shape(True)`` are equivalent; the explicit
    # form matches the other tests in this file.
    with mx.np_shape(True):
        # Build a fresh graph inside the NumPy-semantics scope.
        data = mx.sym.var("data")
        sine = mx.sym.sin(data)
        in_shapes, out_shapes, _ = sine.infer_shape_partial()
        assert in_shapes[0] is None
        assert out_shapes[0] is None
def test_batch_dot_partial_shape():
    """Check partial shape inference for ``batch_dot``.

    When only the batch dimensions are unknown, the output shape is still
    inferred and keeps its batch dimension unknown. When the rhs dimension
    that is contracted against the lhs is also unknown, nothing is inferred
    for the output.
    """
    lhs = mx.sym.Variable("x")
    rhs = mx.sym.Variable("y")
    batched = mx.sym.batch_dot(lhs, rhs)

    # (x shape, y shape, expected output shapes) under legacy semantics,
    # where 0 marks an unknown dimension.
    legacy_cases = (
        # lhs and rhs batch size unknown
        ((0, 3, 4), (0, 4, 5), [(0, 3, 5)]),
        # rhs second dim unknown as well -> output fully unknown
        ((0, 3, 4), (0, 0, 5), [()]),
    )
    for x_shape, y_shape, expected in legacy_cases:
        _, out_shapes, _ = batched.infer_shape_partial(x=x_shape, y=y_shape)
        assert out_shapes == expected

    # Same scenarios under NumPy semantics: -1 marks an unknown dimension and
    # a fully unknown output is reported as None.
    numpy_cases = (
        ((-1, 3, 4), (-1, 4, 5), [(-1, 3, 5)]),
        ((-1, 3, 4), (-1, -1, 5), [None]),
    )
    with mx.np_shape(True):
        for x_shape, y_shape, expected in numpy_cases:
            _, out_shapes, _ = batched.infer_shape_partial(x=x_shape,
                                                           y=y_shape)
            assert out_shapes == expected
def test_pick_partial_shape():
    """``pick`` along axis 1 should keep an unknown batch dimension unknown."""
    data = mx.sym.Variable("x")
    idx = mx.sym.Variable("index")
    picked = mx.sym.pick(data, idx, axis=1)

    # Legacy semantics: 0 marks the unknown batch size.
    _, out_shapes, _ = picked.infer_shape_partial(x=(0, 3, 3), index=(0, 3))
    assert out_shapes == [(0, 3)]

    # NumPy semantics: -1 marks the unknown batch size.
    with mx.np_shape(True):
        _, out_shapes, _ = picked.infer_shape_partial(x=(-1, 3, 3),
                                                      index=(-1, 3))
        assert out_shapes == [(-1, 3)]
def test_where_partial_shape():
    """``where`` infers nothing until the condition's shape is fully known.

    In every case below the condition shape is partially or wholly unknown.
    The output shape must therefore come back unknown: ``()`` under legacy
    semantics and ``None`` under NumPy shape semantics.
    """
    lhs = mx.sym.Variable("x")
    rhs = mx.sym.Variable("y")
    mask = mx.sym.Variable("cond")
    selected = mx.sym.where(mask, lhs, rhs)

    # Legacy semantics: 0 marks an unknown dimension.
    legacy_cases = (
        dict(cond=(0, 2), x=(0, 2), y=(0, 2)),
        dict(cond=(0,), x=(2, 2), y=(2, 2)),
    )
    for shapes in legacy_cases:
        _, out_shapes, _ = selected.infer_shape_partial(**shapes)
        assert out_shapes == [()]

    # NumPy semantics: -1 marks an unknown dimension.
    numpy_cases = (
        dict(cond=(-1, 2), x=(-1, 2), y=(-1, 2)),
        dict(cond=(-1,), x=(2, 2), y=(2, 2)),
    )
    with mx.np_shape(True):
        for shapes in numpy_cases:
            _, out_shapes, _ = selected.infer_shape_partial(**shapes)
            assert out_shapes == [None]