Ejemplo n.º 1
0
def get_strides(index, stride_vars):
    """Measure how far an index expression moves per unit step of each var.

    Every variable occurring in ``index`` is first pinned to the single point
    0 to obtain a base value (the lower bound of the expression). Then, one
    variable at a time, each occurring variable that is also listed in
    ``stride_vars`` is pinned to the point 1 instead. The difference between
    the resulting upper bound and the base value is recorded as that
    variable's stride.

    Parameters
    ----------
    index : tvm.tir.PrimExpr
        The index expression where the stride vars are present.
    stride_vars : list of tvm.tir.Var
        The vars to determine the striding of.

    Returns
    -------
    strides : list of int
        The stride of each entry of ``stride_vars``, in the same order as
        given. An entry whose var never occurs in ``index`` keeps the
        default value of 1.
        NOTE(review): a stride of 0 would be the mathematically exact answer
        for an absent var; confirm callers rely on the 1 default before
        changing it.

    """
    domains = {}

    def _collect(node):
        # Pin every variable encountered in the expression to the point 0.
        if isinstance(node, tvm.tir.Var):
            domains[node] = arith.IntervalSet(0, 0)

    tvm.tir.stmt_functor.post_order_visit(index, _collect)
    base = int(arith.Analyzer().int_set(index, domains).min_value)

    result = [1] * len(stride_vars)
    for var in domains:
        if var not in stride_vars:
            continue
        # A [0, 1] interval was reported to be unreliable here (suspected
        # analyzer bug), so evaluate at the single point 1 instead and put
        # the var back at 0 before probing the next one.
        domains[var] = arith.IntervalSet(1, 1)
        upper = int(arith.Analyzer().int_set(index, domains).max_value)
        domains[var] = arith.IntervalSet(0, 0)
        result[stride_vars.index(var)] = upper - base
    return result
Ejemplo n.º 2
0
 def _visit(stmt):
     """Pin a visited TIR variable to the single point 0 in ``dmap``.

     ``dmap`` is not defined in this fragment; presumably it is a dict from
     the enclosing scope mapping vars to interval sets — verify against the
     surrounding code. Non-variable nodes are ignored.
     """
     if not isinstance(stmt, tvm.tir.Var):
         return
     dmap[stmt] = arith.IntervalSet(0, 0)