Exemplo n.º 1
0
    def map_stack(self, expr: Stack) -> Array:
        """Lower a :class:`Stack` node to an :class:`IndexLambda`.

        The output index along ``expr.axis`` selects which input array is
        read; the remaining output indices, in order, index into that input.
        The inputs become the bindings ``_in0``, ``_in1``, ... of the
        resulting :class:`IndexLambda`.

        :raises ValueError: if ``expr.arrays`` is empty.
        """
        from pymbolic.primitives import If, Comparison

        n_arrays = len(expr.arrays)
        if n_arrays == 0:
            # Without any input there is no fall-through branch to build.
            raise ValueError("cannot lower a Stack of zero arrays")

        # Every input is indexed by the output index with the stacking axis
        # dropped. This does not depend on which input is read, so build it
        # once instead of once per input.
        subscript = tuple(var(f"_{i}") for i in range(expr.ndim)
                          if i != expr.axis)
        axis_var = var(f"_{expr.axis}")

        # I = axis index
        #
        # => If(_I == 0,
        #        _in0[_0, _1, ...],
        #        If(_I == 1,
        #            _in1[_0, _1, ...],
        #            ...
        #                _inNm1[_0, _1, ...] ...))
        #
        # Built inside-out: the last input is the unconditional fall-through.
        last = n_arrays - 1
        stack_expr = var(f"_in{last}")[subscript]
        for i in range(last - 1, -1, -1):
            stack_expr = If(Comparison(axis_var, "==", i),
                            var(f"_in{i}")[subscript], stack_expr)

        bindings = {
            f"_in{i}": self.rec(array)
            for i, array in enumerate(expr.arrays)
        }

        return IndexLambda(namespace=self.namespace,
                           expr=stack_expr,
                           shape=expr.shape,
                           dtype=expr.dtype,
                           bindings=bindings)
Exemplo n.º 2
0
    def map_call(self, expr):
        """Translate the special call forms understood by this parser.

        Recognized forms (by the called variable's name):

        - ``cse(expr)`` / ``cse(expr, tag)``: a
          :class:`~pymbolic.primitives.CommonSubexpression`, with *tag* taken
          from the name of the second argument, which must be a variable.
        - ``reduce(op, inames, exprs...)`` / ``simul_reduce(...)``: a
          reduction built by ``self._parse_reduction``; only ``simul_reduce``
          passes ``allow_simultaneous=True``.
        - ``if(cond, then, else)``: a :class:`~pymbolic.primitives.If`.
        - ``<op>(inames, exprs...)`` where ``<op>`` is accepted by loopy's
          ``parse_reduction_op``: a reduction with that operation.

        Any other call, including calls whose function is not a plain
        variable, is delegated to ``IdentityMapper.map_call``.

        :raises TypeError: on a malformed ``cse``, ``reduce``,
            ``simul_reduce`` or ``if`` call.
        :raises RuntimeError: if a named reduction operation receives the
            wrong number of arguments.
        """
        from loopy.library.reduction import parse_reduction_op

        if not isinstance(expr.function, p.Variable):
            return IdentityMapper.map_call(self, expr)

        name = expr.function.name
        if name == "cse":
            if len(expr.parameters) in [1, 2]:
                if len(expr.parameters) == 2:
                    if not isinstance(expr.parameters[1], p.Variable):
                        raise TypeError("second argument to cse() must be a symbol")
                    tag = expr.parameters[1].name
                else:
                    tag = None

                return p.CommonSubexpression(
                        self.rec(expr.parameters[0]), tag)
            else:
                # cse() accepts either one or two arguments.
                raise TypeError("cse takes one or two arguments")

        elif name in ["reduce", "simul_reduce"]:

            if len(expr.parameters) >= 3:
                operation, inames = expr.parameters[:2]
                red_exprs = expr.parameters[2:]

                operation = parse_reduction_op(str(operation))
                return self._parse_reduction(operation, inames,
                        tuple(self.rec(red_expr) for red_expr in red_exprs),
                        allow_simultaneous=(name == "simul_reduce"))
            else:
                # Name the form actually used (reduce or simul_reduce).
                raise TypeError("invalid '%s' calling sequence" % name)

        elif name == "if":
            if len(expr.parameters) == 3:
                from pymbolic.primitives import If
                # 'param', not 'p', so the module alias 'p' used above is not
                # visually shadowed.
                return If(*tuple(self.rec(param) for param in expr.parameters))
            else:
                raise TypeError("if takes three arguments")

        else:
            # see if 'name' is an existing reduction op

            operation = parse_reduction_op(name)
            if operation:
                # arg_count counts arguments but not inames
                if len(expr.parameters) != 1 + operation.arg_count:
                    raise RuntimeError("invalid invocation of "
                            "reduction operation '%s': expected %d arguments, "
                            "got %d instead" % (expr.function.name,
                                                1 + operation.arg_count,
                                                len(expr.parameters)))

                inames = expr.parameters[0]
                red_exprs = tuple(self.rec(param) for param in expr.parameters[1:])
                return self._parse_reduction(operation, inames, red_exprs)

            else:
                return IdentityMapper.map_call(self, expr)
Exemplo n.º 3
0
    def map_concatenate(self, expr: Concatenate) -> Array:
        """Lower a :class:`Concatenate` node to an :class:`IndexLambda`.

        Along ``expr.axis``, input ``i`` covers the half-open output range
        ``[lbounds[i], ubounds[i])``. It is read with the axis index shifted
        back by ``lbounds[i]``, and all other indices are passed through
        unchanged. The inputs become the bindings ``_in0``, ``_in1``, ... of
        the resulting :class:`IndexLambda`.
        """
        from pymbolic.primitives import If, Comparison, LogicalAnd, Subscript

        def get_subscript(array_index: int,
                          offset: ScalarExpression) -> Subscript:
            # Index input *array_index* by the output index, shifting the
            # concatenation axis into that input's own coordinates.
            aggregate = var(f"_in{array_index}")
            index = [
                var(f"_{i}") if i != expr.axis else (var(f"_{i}") - offset)
                for i in range(len(expr.shape))
            ]
            return Subscript(aggregate, tuple(index))

        # Running offsets along the concatenation axis:
        # input i occupies [lbounds[i], ubounds[i]).
        lbounds: List[Any] = [0]
        ubounds: List[Any] = [expr.arrays[0].shape[expr.axis]]

        for i, array in enumerate(expr.arrays[1:], start=1):
            ubounds.append(ubounds[i - 1] + array.shape[expr.axis])
            lbounds.append(ubounds[i - 1])

        axis_var = var(f"_{expr.axis}")

        # I = axis index
        #
        # => If(0<=_I < arrays[0].shape[axis],
        #        _in0[_0, _1, ..., _I, ...],
        #        If(arrays[0].shape[axis]<= _I < (arrays[1].shape[axis]
        #                                         +arrays[0].shape[axis]),
        #            _in1[_0, _1, ..., _I-arrays[0].shape[axis], ...],
        #            ...
        #                _inNm1[_0, _1, ...] ...))
        for i in range(len(expr.arrays) - 1, -1, -1):
            lbound, ubound = lbounds[i], ubounds[i]
            subarray_expr = get_subscript(i, lbound)
            if i == len(expr.arrays) - 1:
                # The last input is the unconditional fall-through branch.
                stack_expr = subarray_expr
            else:
                # Python's ``and`` cannot build a symbolic conjunction. It
                # evaluates the truth value of the first Comparison object,
                # then either keeps only the second comparison (silently
                # dropping the lower bound) or raises, depending on how the
                # installed pymbolic defines truthiness. Build it explicitly.
                stack_expr = If(
                    LogicalAnd((
                        Comparison(axis_var, ">=", lbound),
                        Comparison(axis_var, "<", ubound))),
                    subarray_expr, stack_expr)

        bindings = {
            f"_in{i}": self.rec(array)
            for i, array in enumerate(expr.arrays)
        }

        return IndexLambda(namespace=self.namespace,
                           expr=stack_expr,
                           shape=expr.shape,
                           dtype=expr.dtype,
                           bindings=bindings)
Exemplo n.º 4
0
def pw_aff_to_expr(pw_aff, int_ok=False):
    """Convert a piecewise affine object into a nested pymbolic ``If``.

    Each piece except the last becomes a branch
    ``If(set_to_cond_expr(domain), aff_to_expr(aff), <rest>)``. The last
    piece is the unconditional fall-through value and is never tested
    against its domain.

    :arg pw_aff: an object providing ``get_pieces()`` (presumably an
        ``islpy.PwAff``; TODO confirm against callers), or a plain
        :class:`int`.
    :arg int_ok: if false, passing an :class:`int` emits a warning. The
        int is returned unchanged either way.
    :returns: the pymbolic expression, or *pw_aff* itself if it is an int.
    """
    if not isinstance(pw_aff, int):
        pieces = pw_aff.get_pieces()
        result = aff_to_expr(pieces[-1][1])

        # Conditions and values of all pieces but the last, in piece order.
        branches = [
            (set_to_cond_expr(domain), aff_to_expr(piece_aff))
            for domain, piece_aff in pieces[:-1]
        ]

        from pymbolic.primitives import If

        # Wrap from the innermost branch outward, so the first piece ends
        # up as the outermost test.
        for cond, value in branches[::-1]:
            result = If(cond, value, result)
        return result

    if not int_ok:
        from warnings import warn
        warn("expected PwAff, got int", stacklevel=2)
    return pw_aff
Exemplo n.º 5
0
 def is_odd(expr):
     """Return a symbolic expression that is 1 if *expr* is odd, else 0.

     Oddness is tested as ``expr % 2 != 0`` rather than ``expr % 2 == 1``.
     For integers, the two agree under floor (Python-style) remainder. They
     differ under truncating (C-style) remainder, which yields -1 for a
     negative odd value; ``!= 0`` stays correct there as well, in case the
     expression is lowered to such a target.
     """
     from pymbolic.primitives import If, Comparison, Remainder
     return If(Comparison(Remainder(expr, 2), "!=", 0), 1, 0)