Esempio n. 1
0
    def finish_adaptive(self, cb, high_order_estimate, low_order_estimate):
        """Emit code that accepts or rejects a step from an error estimate.

        The norm of ``high_order_estimate - low_order_estimate`` is scaled
        into a relative error. If that error exceeds 1 or is NaN, ``dt`` is
        shrunk and the step is failed via ``cb.fail_step()``. If ``t + dt``
        no longer differs from ``t``, :class:`TimeStepUnderflow` is raised
        instead. Otherwise the step is finished with
        :meth:`finish_nonadaptive` and ``dt`` is grown for the next step.

        :arg cb: the code builder that statements are emitted into.
        :arg high_order_estimate: expression for the higher-order solution.
        :arg low_order_estimate: expression for the lower-order solution.
        """
        from pymbolic import var
        from pymbolic.primitives import Comparison, LogicalOr, Max, Min
        from dagrt.expression import IfThenElse

        norm_start_state = var("norm_start_state")
        norm_end_state = var("norm_end_state")
        rel_error_raw = var("rel_error_raw")
        rel_error = var("rel_error")

        # Symbolic call to the runtime builtin 2-norm.
        def norm(expr):
            return var("<builtin>norm_2")(expr)

        cb(norm_start_state, norm(self.state))
        cb(norm_end_state, norm(low_order_estimate))
        # The error norm is divided by sqrt(len(state)) times the mixed
        # tolerance atol + rtol * max(|start state|, |low-order end state|).
        cb(
            rel_error_raw,
            norm(high_order_estimate - low_order_estimate) /
            (var("<builtin>len")(self.state)**0.5 *
             (self.atol + self.rtol * Max(
                 (norm_start_state, norm_end_state)))))

        # An exactly-zero error is replaced by 1e-14 so that
        # rel_error**(-1/order) below stays finite.
        cb(
            rel_error,
            IfThenElse(Comparison(rel_error_raw, "==", 0), 1.0e-14,
                       rel_error_raw))

        # Reject the step if the error is too large or not a number.
        with cb.if_(
                LogicalOr((Comparison(rel_error, ">",
                                      1), var("<builtin>isnan")(rel_error)))):

            # With no usable error estimate, shrink dt by the maximum
            # allowed factor.
            with cb.if_(var("<builtin>isnan")(rel_error)):
                cb(self.dt, self.min_dt_shrinkage * self.dt)
            # Otherwise use a 0.9 safety factor, but never shrink below
            # min_dt_shrinkage * dt.
            with cb.else_():
                cb(
                    self.dt,
                    Max((0.9 * self.dt * rel_error**(-1 / self.low_order),
                         self.min_dt_shrinkage * self.dt)))

            # dt has become too small to change t in floating point.
            with cb.if_(self.t + self.dt, "==", self.t):
                cb.raise_(TimeStepUnderflow)
            with cb.else_():
                cb.fail_step()

        with cb.else_():
            # This updates <t>: <dt> should not be set before this is called.
            self.finish_nonadaptive(cb, high_order_estimate,
                                    low_order_estimate)

            # Step accepted: grow dt with a 0.9 safety factor, capped at
            # max_dt_growth * dt.
            cb(
                self.dt,
                Min((0.9 * self.dt * rel_error**(-1 / self.high_order),
                     self.max_dt_growth * self.dt)))
Esempio n. 2
0
    def map_concatenate(self, expr: Concatenate) -> Array:
        """Lower a :class:`Concatenate` into an :class:`IndexLambda`.

        Each input ``expr.arrays[i]`` is bound as ``_in{i}``. The result
        expression is a chain of ``If`` nodes on the concatenation-axis
        index ``_{axis}``. Each branch selects the input whose half-open range
        along that axis contains the index, and shifts the index back by that
        input's start offset. The last input is the unconditional
        fall-through of the chain.
        """
        from pymbolic.primitives import If, Comparison, LogicalAnd, Subscript

        def get_subscript(array_index: int,
                          offset: ScalarExpression) -> Subscript:
            # Subscript input *array_index*, shifting the concatenation axis
            # by *offset* into that input's own index range.
            aggregate = var(f"_in{array_index}")
            index = [
                var(f"_{i}") if i != expr.axis else (var(f"_{i}") - offset)
                for i in range(len(expr.shape))
            ]
            return Subscript(aggregate, tuple(index))

        # Input i occupies [lbounds[i], ubounds[i]) along the result's
        # concatenation axis.
        lbounds: List[Any] = [0]
        ubounds: List[Any] = [expr.arrays[0].shape[expr.axis]]

        for i, array in enumerate(expr.arrays[1:], start=1):
            ubounds.append(ubounds[i - 1] + array.shape[expr.axis])
            lbounds.append(ubounds[i - 1])

        # I = axis index
        #
        # => If(0<=_I < arrays[0].shape[axis],
        #        _in0[_0, _1, ..., _I, ...],
        #        If(arrays[0].shape[axis]<= _I < (arrays[1].shape[axis]
        #                                         +arrays[0].shape[axis]),
        #            _in1[_0, _1, ..., _I-arrays[0].shape[axis], ...],
        #            ...
        #                _inNm1[_0, _1, ...] ...))
        axis_index = var(f"_{expr.axis}")
        for i in range(len(expr.arrays) - 1, -1, -1):
            lbound, ubound = lbounds[i], ubounds[i]
            subarray_expr = get_subscript(i, lbound)
            if i == len(expr.arrays) - 1:
                stack_expr = subarray_expr
            else:
                # The conjunction must be symbolic. Python's ``and`` would
                # test the truthiness of the first Comparison object instead
                # of putting both bounds into the expression tree.
                stack_expr = If(
                    LogicalAnd((
                        Comparison(axis_index, ">=", lbound),
                        Comparison(axis_index, "<", ubound))),
                    subarray_expr, stack_expr)

        bindings = {
            f"_in{i}": self.rec(array)
            for i, array in enumerate(expr.arrays)
        }

        return IndexLambda(namespace=self.namespace,
                           expr=stack_expr,
                           shape=expr.shape,
                           dtype=expr.dtype,
                           bindings=bindings)
Esempio n. 3
0
def constraint_to_expr(cns):
    """Turn the constraint *cns* into a pymbolic comparison against zero.

    The affine part of *cns* (from ``cns.get_aff()``) is converted with
    :func:`aff_to_expr`. Equality constraints become ``aff == 0``; all others
    become ``aff >= 0``. *cns* is presumably an islpy ``Constraint``; confirm
    against callers.
    """
    # FIXME: constraints with existentially quantified (div) variables are
    # not rejected here. get_aff() appears to handle them, but this has not
    # been verified.
    lhs = aff_to_expr(cns.get_aff())

    from pymbolic.primitives import Comparison
    relation = "==" if cns.is_equality() else ">="
    return Comparison(lhs, relation, 0)
Esempio n. 4
0
    def wrap_in_typecast_lazy(self, actual_dtype, needed_dtype, s):
        """Wrap *s* so that its type matches *needed_dtype*.

        *actual_dtype* is a callable. It is only invoked when a boolean is
        needed. A floating-point value used as a boolean becomes an explicit
        ``s != 0`` comparison, because CL has no implicit float-to-bool
        conversion. Every other case is handled by the parent class.
        """
        wants_bool = needed_dtype.dtype.kind == "b"
        if not (wants_bool and actual_dtype().dtype.kind == "f"):
            return super().wrap_in_typecast_lazy(actual_dtype, needed_dtype, s)

        from pymbolic.primitives import Comparison
        return Comparison(s, "!=", 0)
Esempio n. 5
0
    def parse_postfix(self, pstate, min_precedence, left_exp):
        """Try to extend *left_exp* with one comparison or logical operator.

        Tokens found in ``self.COMP_MAP`` build a ``Comparison``. The
        ``_and``/``_or`` tokens build ``LogicalAnd``/``LogicalOr``. Any other
        token is handed to :meth:`ExpressionParserBase.parse_postfix`.

        :returns: a tuple ``(expression, did_something)``.
        :raises TranslationError: if ``(`` follows an expression in call
            position.
        """
        from pymbolic.parser import (
                _PREC_CALL, _PREC_COMPARISON, _openpar,
                _PREC_LOGICAL_OR, _PREC_LOGICAL_AND)
        from pymbolic.primitives import (
                Comparison, LogicalAnd, LogicalOr)

        tag = pstate.next_tag()

        # Call syntax is only supported directly on names.
        if tag is _openpar and min_precedence < _PREC_CALL:
            raise TranslationError("parenthesis operator only works on names")

        if tag in self.COMP_MAP and min_precedence < _PREC_COMPARISON:
            pstate.advance()
            operator = self.COMP_MAP[tag]
            rhs = self.parse_expression(pstate, _PREC_COMPARISON)
            return Comparison(left_exp, operator, rhs), True

        if tag is _and and min_precedence < _PREC_LOGICAL_AND:
            pstate.advance()
            rhs = self.parse_expression(pstate, _PREC_LOGICAL_AND)
            return LogicalAnd((left_exp, rhs)), True

        if tag is _or and min_precedence < _PREC_LOGICAL_OR:
            pstate.advance()
            rhs = self.parse_expression(pstate, _PREC_LOGICAL_OR)
            return LogicalOr((left_exp, rhs)), True

        result, did_something = ExpressionParserBase.parse_postfix(
                self, pstate, min_precedence, left_exp)
        return result, did_something
Esempio n. 6
0
    def map_stack(self, expr: Stack) -> Array:
        """Lower a :class:`Stack` into an :class:`IndexLambda`.

        Each input ``expr.arrays[i]`` is bound as ``_in{i}``. Along the new
        axis ``expr.axis``, the result selects ``_in{i}`` when the index
        ``_{axis}`` equals ``i``. All other result indices are passed through
        to that input unchanged. The last input is the unconditional
        fall-through of the ``If`` chain.
        """
        from pymbolic.primitives import If, Comparison

        # Index used for every input: all result indices except the stacked
        # axis. It does not depend on which input is selected.
        passthrough_index = tuple(
            var(f"_{dim}") for dim in range(expr.ndim) if dim != expr.axis)
        axis_index = var(f"_{expr.axis}")

        # Built innermost-first:
        #   If(_I == 0, _in0[...], If(_I == 1, _in1[...], ... _inNm1[...]))
        n_inputs = len(expr.arrays)
        for idx in reversed(range(n_inputs)):
            selected = var(f"_in{idx}")[passthrough_index]
            if idx == n_inputs - 1:
                stack_expr = selected
            else:
                stack_expr = If(Comparison(axis_index, "==", idx),
                                selected, stack_expr)

        bindings = {f"_in{idx}": self.rec(ary)
                    for idx, ary in enumerate(expr.arrays)}

        return IndexLambda(namespace=self.namespace,
                           expr=stack_expr,
                           shape=expr.shape,
                           dtype=expr.dtype,
                           bindings=bindings)
Esempio n. 7
0
    def if_(self, *condition_arg):
        """Create a new block that is conditionally executed.

        The condition is given in one of two forms:

        * one argument: an expression, or a string parsed with
          :func:`dagrt.expression.parse`;
        * three arguments ``lhs, operator, rhs``: combined into a
          :class:`pymbolic.primitives.Comparison`. String operands are each
          parsed separately.

        The condition is assigned to a fresh ``<cond>`` variable. That
        variable is on the conditional-expression stack while the block's
        body runs. After normal exit it is recorded as the last if-block
        condition.

        :raises ValueError: if neither one nor three arguments are given.
        """
        from dagrt.expression import parse

        if len(condition_arg) == 1:
            condition = condition_arg[0]

            if isinstance(condition, str):
                condition = parse(condition)

        elif len(condition_arg) == 3:
            lhs, cond, rhs = condition_arg

            if isinstance(lhs, str):
                lhs = parse(lhs)
            if isinstance(rhs, str):
                rhs = parse(rhs)

            from pymbolic.primitives import Comparison
            condition = Comparison(lhs, cond, rhs)
        else:
            raise ValueError("Unrecognized condition expression")

        # Create a statement as a lead statement to assign a logical flag.
        cond_var = self.fresh_var("<cond>")
        cond_assignment = Assign(assignee=cond_var.name,
                                 assignee_subscript=(),
                                 expression=condition)

        self._add_statement(cond_assignment)

        self._conditional_expression_stack.append(cond_var)
        try:
            yield
        finally:
            # Pop even if the body raised. Otherwise the stack would keep a
            # condition for a block that is no longer being built.
            self._conditional_expression_stack.pop()
        self._last_if_block_conditional_expression = cond_var
Esempio n. 8
0
    def emit_sequential_loop(self, codegen_state, iname, iname_dtype, lbound,
                             ubound, inner):
        """Return a ``cgen.For`` running *iname* from *lbound* to *ubound*.

        Both bounds are inclusive. The loop variable is declared with
        ``POD(self, iname_dtype, iname)``. Both bound expressions are rendered
        through ``codegen_state.expression_to_code_mapper`` at ``PREC_NONE``
        with type context ``"i"``. *inner* becomes the loop body.
        """
        to_code = codegen_state.expression_to_code_mapper

        from pymbolic import var
        from pymbolic.primitives import Comparison
        from pymbolic.mapper.stringifier import PREC_NONE
        from cgen import For, InlineInitializer

        # for (<decl> iname = lbound; iname <= ubound; ++iname) inner
        declaration = POD(self, iname_dtype, iname)
        initializer = InlineInitializer(
            declaration, to_code(lbound, PREC_NONE, "i"))
        condition = to_code(
            Comparison(var(iname), "<=", ubound), PREC_NONE, "i")
        increment = "++%s" % iname

        return For(initializer, condition, increment, inner)
Esempio n. 9
0
    def make_spectra_knl(self, is_real, rank_shape):
        """Build a :class:`Histogrammer` that bins ``|fk|**2`` by ``|k|``.

        For each mode ``(i, j, k)``:

        * ``|k|`` is ``sqrt(sum_mu (dk[mu] * momenta_mu[idx_mu])**2)``;
        * the bin is ``round(|k| / self.bin_width)``;
        * the weight is ``count * |k|**k_power * abs(fk[i, j, k])**2``.

        When *is_real* is set, ``count`` is 2 for modes with
        ``0 < momenta_z[k] < grid_shape[-1] / 2`` and 1 otherwise. Otherwise
        ``count`` is 1.
        """
        from pymbolic import var, parse

        indices = parse("i, j, k")
        i, j, k = indices
        momenta = [var(f"momenta_{axis}") for axis in "xyz"]

        ksq = sum((dk_mu * mom[idx])**2
                  for mom, dk_mu, idx in zip(momenta, self.dk, indices))
        kmag = var("sqrt")(ksq)
        bin_expr = var("round")(kmag / self.bin_width)

        if not is_real:
            count = 1
        else:
            from pymbolic.primitives import If, Comparison, LogicalAnd
            kz = momenta[2][k]
            nyquist = self.grid_shape[-1] / 2
            strictly_inside = LogicalAnd((Comparison(kz, ">", 0),
                                          Comparison(kz, "<", nyquist)))
            count = If(strictly_inside, 2, 1)

        fk = var("fk")[i, j, k]
        weight_expr = count * kmag**var("k_power") * var("abs")(fk)**2

        args = [
            lp.GlobalArg("fk", self.cdtype, shape=("Nx", "Ny", "Nz"),
                         offset=lp.auto),
            lp.GlobalArg("momenta_x", self.rdtype, shape=("Nx",)),
            lp.GlobalArg("momenta_y", self.rdtype, shape=("Ny",)),
            lp.GlobalArg("momenta_z", self.rdtype, shape=("Nz",)),
            lp.ValueArg("k_power", self.rdtype),
            ...
        ]

        from pystella.histogram import Histogrammer
        histograms = {"spectrum": (bin_expr, weight_expr)}
        return Histogrammer(self.decomp, histograms, self.num_bins,
                            self.rdtype, args=args, rank_shape=rank_shape)
Esempio n. 10
0
    def parse_postfix(self, pstate, min_precedence, left_exp):
        """Try to extend *left_exp* with one comparison or logical operator.

        Comparisons (``self.COMP_MAP``) and logical ``_and``/``_or`` are
        handled here. Any other token is delegated to
        :meth:`ExpressionParserBase.parse_postfix`. If the base parser
        produces a tuple outside a function-argument context, the tuple is
        treated as a complex literal ``(re, im)`` and turned into a numpy
        complex scalar.

        :returns: a tuple ``(expression, did_something)``.
        :raises TranslationError: if ``(`` follows an expression in call
            position, or if a complex literal does not have exactly two
            entries.
        """
        from pymbolic.parser import (_PREC_CALL, _PREC_COMPARISON, _openpar,
                                     _PREC_LOGICAL_OR, _PREC_LOGICAL_AND)
        from pymbolic.primitives import (Comparison, LogicalAnd, LogicalOr)

        next_tag = pstate.next_tag()
        if next_tag is _openpar and _PREC_CALL > min_precedence:
            raise TranslationError("parenthesis operator only works on names")

        elif next_tag in self.COMP_MAP and _PREC_COMPARISON > min_precedence:
            pstate.advance()
            left_exp = Comparison(
                left_exp, self.COMP_MAP[next_tag],
                self.parse_expression(pstate, _PREC_COMPARISON))
            did_something = True
        elif next_tag is _and and _PREC_LOGICAL_AND > min_precedence:
            pstate.advance()
            left_exp = LogicalAnd(
                (left_exp, self.parse_expression(pstate, _PREC_LOGICAL_AND)))
            did_something = True
        elif next_tag is _or and _PREC_LOGICAL_OR > min_precedence:
            pstate.advance()
            left_exp = LogicalOr(
                (left_exp, self.parse_expression(pstate, _PREC_LOGICAL_OR)))
            did_something = True
        else:
            left_exp, did_something = ExpressionParserBase.parse_postfix(
                self, pstate, min_precedence, left_exp)

            if isinstance(left_exp,
                          tuple) and min_precedence < self._PREC_FUNC_ARGS:
                # this must be a complex literal
                if len(left_exp) != 2:
                    raise TranslationError("complex literals must have "
                                           "two entries")

                # Both parts are assumed to be numpy scalars (they need a
                # .dtype attribute).
                r, i = left_exp

                # Compare the promoted *dtype* of the two parts against
                # float32. A numpy scalar value such as float32(0) never
                # compares equal to the type np.float32.
                if np.result_type(r.dtype, i.dtype) == np.float32:
                    dtype = np.complex64
                else:
                    dtype = np.complex128

                left_exp = dtype(float(r) + float(i) * 1j)

        return left_exp, did_something
Esempio n. 11
0
    def __init__(self, fft, dk, dx, effective_k):
        """Set up a spectral solver kernel for a (screened) Poisson equation.

        :arg fft: FFT object. Its ``grid_shape``, ``sub_k``, ``rdtype`` and
            ``cdtype`` are used.
        :arg dk: per-axis momentum spacing.
        :arg dx: per-axis grid spacing.
        :arg effective_k: callable ``effective_k(k, dx)`` applied to the
            momenta ``dk[mu] * k`` along each axis. Its results are summed
            into the denominator of the solution. Presumably these are the
            eigenvalues of a discretized second derivative (``-k**2`` in the
            continuum); TODO confirm against callers.

        Sets ``self.momenta`` to device arrays ``k_x``, ``k_y``, ``k_z``.
        Sets ``self.knl`` to an :class:`ElementWiseMap` that computes
        ``fk = (rhok / grid_size) / (k_x[i] + k_y[j] + k_z[k] - m_squared)``
        where that momentum sum is negative, and ``fk = 0`` elsewhere.
        """
        self.fft = fft
        grid_size = fft.grid_shape[0] * fft.grid_shape[1] * fft.grid_shape[2]

        queue = self.fft.sub_k["momenta_x"].queue
        sub_k = [x.get().astype("int") for x in self.fft.sub_k.values()]
        k_names = ("k_x", "k_y", "k_z")
        self.momenta = {}
        for mu, (name, kk) in enumerate(zip(k_names, sub_k)):
            kk_mu = effective_k(dk[mu] * kk.astype(fft.rdtype), dx[mu])
            self.momenta[name] = cla.to_device(queue, kk_mu)

        args = [
            lp.GlobalArg("fk", fft.cdtype, shape="(Nx, Ny, Nz)"),
            lp.GlobalArg("k_x", fft.rdtype, shape=("Nx", )),
            lp.GlobalArg("k_y", fft.rdtype, shape=("Ny", )),
            lp.GlobalArg("k_z", fft.rdtype, shape=("Nz", )),
            lp.ValueArg("m_squared", fft.rdtype),
        ]

        from pystella.field import Field
        from pymbolic.primitives import Variable, If, Comparison

        fk = Field("fk")
        indices = fk.indices
        # rhok divided by the total number of grid points. Presumably this
        # undoes the scaling of an unnormalized FFT; TODO confirm.
        rho_tmp = Variable("rho_tmp")
        tmp_insns = [(rho_tmp, Field("rhok") * (1 / grid_size))]

        mom_vars = tuple(Variable(name) for name in k_names)
        minus_k_squared = sum(kk_i[x_i]
                              for kk_i, x_i in zip(mom_vars, indices))
        sol = rho_tmp / (minus_k_squared - Variable("m_squared"))

        # Only modes with a strictly negative momentum sum get a solution.
        # This presumably excludes the zero mode, where the sum is 0.
        solution = {
            Field("fk"): If(Comparison(minus_k_squared, "<", 0), sol, 0)
        }

        from pystella.elementwise import ElementWiseMap
        options = lp.Options(return_dict=True)
        self.knl = ElementWiseMap(solution,
                                  args=args,
                                  halo_shape=0,
                                  options=options,
                                  tmp_instructions=tmp_insns,
                                  lsize=(16, 2, 1))
Esempio n. 12
0
    def parse_postfix(self, pstate, min_precedence, left_exp):
        """Try to extend *left_exp* by one postfix or binary operator.

        Only operators whose precedence exceeds *min_precedence* are
        consumed. Left-associative binary operators parse their right operand
        at their own precedence, so ``a - b - c`` and ``a * b // c`` group to
        the left. ``**`` parses its right operand at ``_PREC_TIMES``, which
        makes it right-associative. Chains of ``+``/``-`` and ``*`` are
        collected into a single flat ``Sum``/``Product``.

        :returns: a tuple ``(expression, did_something)``. *did_something* is
            False when the next token was not consumed.
        """
        import pymbolic.primitives as primitives

        did_something = False

        next_tag = pstate.next_tag()

        if next_tag is _openpar and _PREC_CALL > min_precedence:
            pstate.advance()
            args, kwargs = self.parse_arglist(pstate)

            if kwargs:
                left_exp = primitives.CallWithKwargs(left_exp, args, kwargs)
            else:
                left_exp = primitives.Call(left_exp, args)

            did_something = True
        elif next_tag is _openbracket and _PREC_CALL > min_precedence:
            pstate.advance()
            pstate.expect_not_end()
            left_exp = primitives.Subscript(left_exp, self.parse_expression(pstate))
            pstate.expect(_closebracket)
            pstate.advance()
            did_something = True
        elif next_tag is _if and _PREC_IF > min_precedence:
            from pymbolic.primitives import If
            then_expr = left_exp
            pstate.advance()
            pstate.expect_not_end()
            condition = self.parse_expression(pstate, _PREC_LOGICAL_OR)
            pstate.expect(_else)
            pstate.advance()
            else_expr = self.parse_expression(pstate)
            left_exp = If(condition, then_expr, else_expr)
            did_something = True
        elif next_tag is _dot and _PREC_CALL > min_precedence:
            pstate.advance()
            pstate.expect(_identifier)
            left_exp = primitives.Lookup(left_exp, pstate.next_str())
            pstate.advance()
            did_something = True
        elif next_tag is _plus and _PREC_PLUS > min_precedence:
            pstate.advance()
            right_exp = self.parse_expression(pstate, _PREC_PLUS)
            if isinstance(left_exp, primitives.Sum):
                left_exp = primitives.Sum(left_exp.children + (right_exp,))
            else:
                left_exp = primitives.Sum((left_exp, right_exp))

            did_something = True
        elif next_tag is _minus and _PREC_PLUS > min_precedence:
            pstate.advance()
            right_exp = self.parse_expression(pstate, _PREC_PLUS)
            if isinstance(left_exp, primitives.Sum):
                left_exp = primitives.Sum(left_exp.children + ((-right_exp),))  # noqa pylint:disable=invalid-unary-operand-type
            else:
                left_exp = primitives.Sum((left_exp, -right_exp))  # noqa pylint:disable=invalid-unary-operand-type
            did_something = True
        elif next_tag is _times and _PREC_TIMES > min_precedence:
            pstate.advance()
            # Parse the right operand at _PREC_TIMES, not _PREC_PLUS. Then a
            # following '*', '/', '//' or '%' applies to the product built so
            # far: 'a*b % c' is '(a*b) % c', not 'a*(b % c)'. It also makes
            # 'a*b*c' a single flat Product through the branch below.
            right_exp = self.parse_expression(pstate, _PREC_TIMES)
            if isinstance(left_exp, primitives.Product):
                left_exp = primitives.Product(left_exp.children + (right_exp,))
            else:
                left_exp = primitives.Product((left_exp, right_exp))
            did_something = True
        elif next_tag is _floordiv and _PREC_TIMES > min_precedence:
            pstate.advance()
            left_exp = primitives.FloorDiv(
                    left_exp, self.parse_expression(pstate, _PREC_TIMES))
            did_something = True
        elif next_tag is _over and _PREC_TIMES > min_precedence:
            pstate.advance()
            left_exp = primitives.Quotient(
                    left_exp, self.parse_expression(pstate, _PREC_TIMES))
            did_something = True
        elif next_tag is _modulo and _PREC_TIMES > min_precedence:
            pstate.advance()
            left_exp = primitives.Remainder(
                    left_exp, self.parse_expression(pstate, _PREC_TIMES))
            did_something = True
        elif next_tag is _power and _PREC_POWER > min_precedence:
            pstate.advance()
            # A precedence below _PREC_POWER lets a following '**' bind into
            # the exponent, so '**' is right-associative.
            left_exp = primitives.Power(
                    left_exp, self.parse_expression(pstate, _PREC_TIMES))
            did_something = True
        elif next_tag is _and and _PREC_LOGICAL_AND > min_precedence:
            pstate.advance()
            from pymbolic.primitives import LogicalAnd
            left_exp = LogicalAnd((
                    left_exp,
                    self.parse_expression(pstate, _PREC_LOGICAL_AND)))
            did_something = True
        elif next_tag is _or and _PREC_LOGICAL_OR > min_precedence:
            pstate.advance()
            from pymbolic.primitives import LogicalOr
            left_exp = LogicalOr((
                    left_exp,
                    self.parse_expression(pstate, _PREC_LOGICAL_OR)))
            did_something = True
        elif next_tag is _bitwiseor and _PREC_BITWISE_OR > min_precedence:
            pstate.advance()
            from pymbolic.primitives import BitwiseOr
            left_exp = BitwiseOr((
                    left_exp,
                    self.parse_expression(pstate, _PREC_BITWISE_OR)))
            did_something = True
        elif next_tag is _bitwisexor and _PREC_BITWISE_XOR > min_precedence:
            pstate.advance()
            from pymbolic.primitives import BitwiseXor
            left_exp = BitwiseXor((
                    left_exp,
                    self.parse_expression(pstate, _PREC_BITWISE_XOR)))
            did_something = True
        elif next_tag is _bitwiseand and _PREC_BITWISE_AND > min_precedence:
            pstate.advance()
            from pymbolic.primitives import BitwiseAnd
            left_exp = BitwiseAnd((
                    left_exp,
                    self.parse_expression(pstate, _PREC_BITWISE_AND)))
            did_something = True
        elif next_tag is _rightshift and _PREC_SHIFT > min_precedence:
            pstate.advance()
            from pymbolic.primitives import RightShift
            left_exp = RightShift(
                    left_exp,
                    self.parse_expression(pstate, _PREC_SHIFT))
            did_something = True
        elif next_tag is _leftshift and _PREC_SHIFT > min_precedence:
            pstate.advance()
            from pymbolic.primitives import LeftShift
            left_exp = LeftShift(
                    left_exp,
                    self.parse_expression(pstate, _PREC_SHIFT))
            did_something = True
        elif next_tag in self._COMP_TABLE and _PREC_COMPARISON > min_precedence:
            pstate.advance()
            from pymbolic.primitives import Comparison
            left_exp = Comparison(
                    left_exp,
                    self._COMP_TABLE[next_tag],
                    self.parse_expression(pstate, _PREC_COMPARISON))
            did_something = True
        elif next_tag is _colon and _PREC_SLICE >= min_precedence:
            pstate.advance()
            expr_pstate = pstate.copy()

            assert not isinstance(left_exp, primitives.Slice)

            from pytools.lex import ParseError
            try:
                next_expr = self.parse_expression(expr_pstate, _PREC_SLICE)
            except ParseError:
                # no expression follows, too bad.
                left_exp = primitives.Slice((left_exp, None,))
            else:
                left_exp = _join_to_slice(left_exp, next_expr)
                pstate.assign(expr_pstate)

            did_something = True

        elif next_tag is _comma and _PREC_COMMA > min_precedence:
            # The precedence makes the comma left-associative.

            pstate.advance()
            if pstate.is_at_end() or pstate.next_tag() is _closepar:
                if isinstance(left_exp, (tuple, list)) \
                        and not isinstance(left_exp, FinalizedContainer):
                    # left_expr is a container with trailing commas
                    pass
                else:
                    left_exp = (left_exp,)
            else:
                new_el = self.parse_expression(pstate, _PREC_COMMA)
                if isinstance(left_exp, (tuple, list)) \
                        and not isinstance(left_exp, FinalizedContainer):
                    left_exp = left_exp + (new_el,)
                else:
                    left_exp = (left_exp, new_el)

            did_something = True

        return left_exp, did_something
Esempio n. 13
0
    def __init__(self, fft, effective_k, dk, dx):
        """Precompute effective momenta and build the projection kernels.

        :arg fft: FFT object. Its ``sub_k``, ``grid_shape``, ``rdtype``,
            ``cdtype`` and ``shape(True)`` are used.
        :arg effective_k: one of:

            * a callable ``effective_k(k, dx)`` returning effective momenta;
            * a nonzero value ``h``, in which case
              ``FirstCenteredDifference(h).get_eigenvalues`` is used;
            * ``0``, in which case the momenta ``k`` are used unchanged.

        :arg dk: per-axis momentum spacing.
        :arg dx: per-axis grid spacing.

        Sets ``self.eff_mom`` (device arrays) and the kernels
        ``transversify_knl``, ``vec_to_pol_knl``, ``pol_to_vec_knl``,
        ``vec_decomp_knl``, ``vec_decomp_knl_times_abs_k``, ``tt_knl``,
        ``tensor_to_pol_knl`` and ``pol_to_tensor_knl``.
        """
        self.fft = fft

        # Normalize effective_k into a callable effective_k(k, dx).
        if not callable(effective_k):
            if effective_k != 0:
                from pystella.derivs import FirstCenteredDifference
                h = effective_k
                effective_k = FirstCenteredDifference(h).get_eigenvalues
            else:

                def effective_k(k, dx):  # pylint: disable=function-redefined
                    return k

        queue = self.fft.sub_k["momenta_x"].queue
        sub_k = list(x.get().astype("int") for x in self.fft.sub_k.values())
        eff_mom_names = ("eff_mom_x", "eff_mom_y", "eff_mom_z")
        self.eff_mom = {}
        for mu, (name, kk) in enumerate(zip(eff_mom_names, sub_k)):
            eff_k = effective_k(dk[mu] * kk.astype(fft.rdtype), dx[mu])
            # Zero the effective momentum of the Nyquist mode (|k| == N/2)
            # and of the zero mode along this axis.
            eff_k[abs(sub_k[mu]) == fft.grid_shape[mu] // 2] = 0.
            eff_k[sub_k[mu] == 0] = 0.

            import pyopencl.array as cla
            self.eff_mom[name] = cla.to_device(queue, eff_k)

        from pymbolic import var, parse
        from pymbolic.primitives import If, Comparison, LogicalAnd
        from pystella import Field
        indices = parse("i, j, k")
        # From here on, eff_k is the tuple of symbolic lookups
        # (eff_mom_x[i], eff_mom_y[j], eff_mom_z[k]).
        eff_k = tuple(
            var(array)[mu] for array, mu in zip(eff_mom_names, indices))
        fabs, sqrt, conj = parse("fabs, sqrt, conj")
        kmag = sqrt(sum(kk**2 for kk in eff_k))

        from pystella import ElementWiseMap
        vector = Field("vector", shape=(3, ))
        vector_T = Field("vector_T", shape=(3, ))

        # True where all three effective-momentum components are (nearly)
        # zero. Used below to avoid dividing by |k| or k^2 at such modes.
        kvec_zero = LogicalAnd(
            tuple(Comparison(fabs(eff_k[mu]), "<", 1e-14) for mu in range(3)))

        # note: write all output via private temporaries to allow for in-place

        # Transverse projection:
        #   vector_T = vector - k (k . vector) / |k|^2, and 0 where k vanishes.
        div = var("div")
        div_insn = [(div, sum(eff_k[mu] * vector[mu] for mu in range(3)))]
        self.transversify_knl = ElementWiseMap(
            {
                vector_T[mu]: If(kvec_zero, 0,
                                 vector[mu] - eff_k[mu] / kmag**2 * div)
                for mu in range(3)
            },
            tmp_instructions=div_insn,
            lsize=(32, 1, 1),
            rank_shape=fft.shape(True),
        )

        import loopy as lp

        # Wrap lp.Assignment with defaults: within the i, j, k inames and
        # with no_sync_with=[("*", "any")]. kwargs override these defaults.
        def assign(asignee, expr, **kwargs):
            default = dict(within_inames=frozenset(("i", "j", "k")),
                           no_sync_with=[("*", "any")])
            default.update(kwargs)
            return lp.Assignment(asignee, expr, **default)

        # kmag and Kappa are rebound to kernel temporaries:
        # kmag = |k| and Kappa = sqrt(k_x^2 + k_y^2).
        kmag, Kappa = parse("kmag, Kappa")
        eps_insns = [
            assign(kmag, sqrt(sum(kk**2 for kk in eff_k))),
            assign(Kappa, sqrt(sum(kk**2 for kk in eff_k[:2])))
        ]

        zero = fft.cdtype.type(0)
        kx_ky_zero = LogicalAnd(
            tuple(Comparison(fabs(eff_k[mu]), "<", 1e-10) for mu in range(2)))
        kz_nonzero = Comparison(fabs(eff_k[2]), ">", 1e-10)

        # eps: 3-component complex vector built from the effective momenta.
        # When k_x and k_y (nearly) vanish it is (1, i, 0)/sqrt(2) if k_z is
        # nonzero, and 0 otherwise. Presumably a circular-polarization basis
        # vector; TODO confirm.
        eps = var("eps")
        eps_insns.extend([
            assign(
                eps[0],
                If(kx_ky_zero, If(kz_nonzero, fft.cdtype.type(1 / 2**.5),
                                  zero),
                   (eff_k[0] * eff_k[2] / kmag - 1j * eff_k[1]) / Kappa /
                   2**.5)),
            assign(
                eps[1],
                If(kx_ky_zero,
                   If(kz_nonzero, fft.cdtype.type(1j / 2**(1 / 2)),
                      zero), (eff_k[1] * eff_k[2] / kmag + 1j * eff_k[0]) /
                   Kappa / 2**.5)),
            assign(eps[2], If(kx_ky_zero, zero, -Kappa / kmag / 2**.5))
        ])

        plus, minus, lng = Field("plus"), Field("minus"), Field("lng")

        # plus = vector . conj(eps), minus = vector . eps
        plus_tmp, minus_tmp = parse("plus_tmp, minus_tmp")
        pol_isns = [(plus_tmp,
                     sum(vector[mu] * conj(eps[mu]) for mu in range(3))),
                    (minus_tmp, sum(vector[mu] * eps[mu] for mu in range(3)))]

        # The trailing ``...`` is passed through to ElementWiseMap unchanged.
        # Presumably it asks for the remaining arguments to be inferred;
        # confirm.
        args = [
            lp.TemporaryVariable("kmag"),
            lp.TemporaryVariable("Kappa"),
            lp.TemporaryVariable("eps", shape=(3, )), ...
        ]

        self.vec_to_pol_knl = ElementWiseMap(
            {
                plus: plus_tmp,
                minus: minus_tmp
            },
            tmp_instructions=eps_insns + pol_isns,
            args=args,
            lsize=(32, 1, 1),
            rank_shape=fft.shape(True),
        )

        # Inverse of the above: vector = plus * eps + minus * conj(eps).
        vector_tmp = var("vector_tmp")
        vec_insns = [(vector_tmp[mu], plus * eps[mu] + minus * conj(eps[mu]))
                     for mu in range(3)]

        self.pol_to_vec_knl = ElementWiseMap(
            {vector[mu]: vector_tmp[mu]
             for mu in range(3)},
            tmp_instructions=eps_insns + vec_insns,
            args=args,
            lsize=(32, 1, 1),
            rank_shape=fft.shape(True),
        )

        # Same as vec_to_pol_knl, but also writes the longitudinal part
        # lng = -i (k . vector) / k^2 (0 where k vanishes).
        ksq = sum(kk**2 for kk in eff_k)
        lng_rhs = If(kvec_zero, 0, -div / ksq * 1j)
        self.vec_decomp_knl = ElementWiseMap(
            {
                plus: plus_tmp,
                minus: minus_tmp,
                lng: lng_rhs
            },
            tmp_instructions=eps_insns + pol_isns + div_insn,
            args=args,
            lsize=(32, 1, 1),
            rank_shape=fft.shape(True),
        )
        # Variant with lng = -i (k . vector) / |k|.
        lng_rhs = If(kvec_zero, 0, -div / ksq**.5 * 1j)
        self.vec_decomp_knl_times_abs_k = ElementWiseMap(
            {
                plus: plus_tmp,
                minus: minus_tmp,
                lng: lng_rhs
            },
            tmp_instructions=eps_insns + pol_isns + div_insn,
            args=args,
            lsize=(32, 1, 1),
            rank_shape=fft.shape(True),
        )

        # tid(a, b) maps 1-based index pairs to positions in the 6-component
        # fields. It is presumably symmetric in (a, b), since both orderings
        # are used below; TODO confirm.
        from pystella.sectors import tensor_index as tid

        eff_k_hat = tuple(kk / sqrt(sum(kk**2 for kk in eff_k))
                          for kk in eff_k)
        hij = Field("hij", shape=(6, ))
        hij_TT = Field("hij_TT", shape=(6, ))

        # Transverse projector P_ab = delta_ab - khat_a khat_b, a <= b.
        Pab = var("P")
        Pab_insns = [(Pab[tid(a, b)], (If(Comparison(a, "==", b), 1, 0) -
                                       eff_k_hat[a - 1] * eff_k_hat[b - 1]))
                     for a in range(1, 4) for b in range(a, 4)]

        # hij_TT_ab = sum_cd (P_ac P_db - P_ab P_cd / 2) hij_cd
        hij_TT_tmp = var("hij_TT_tmp")
        TT_insns = [(hij_TT_tmp[tid(a, b)],
                     sum((Pab[tid(a, c)] * Pab[tid(d, b)] -
                          Pab[tid(a, b)] * Pab[tid(c, d)] / 2) * hij[tid(c, d)]
                         for c in range(1, 4) for d in range(1, 4)))
                    for a in range(1, 4) for b in range(a, 4)]
        # note: where conditionals (branch divergence) go can matter:
        # this kernel is twice as fast when putting the branching in the global
        # write, rather than when setting hij_TT_tmp
        write_insns = [(hij_TT[tid(a,
                                   b)], If(kvec_zero, 0, hij_TT_tmp[tid(a,
                                                                        b)]))
                       for a in range(1, 4) for b in range(a, 4)]
        self.tt_knl = ElementWiseMap(
            write_insns,
            tmp_instructions=Pab_insns + TT_insns,
            lsize=(32, 1, 1),
            rank_shape=fft.shape(True),
        )

        # plus  = sum_cd hij_cd conj(eps_c) conj(eps_d)
        # minus = sum_cd hij_cd eps_c eps_d
        tensor_to_pol_insns = {
            plus:
            sum(hij[tid(c, d)] * conj(eps[c - 1]) * conj(eps[d - 1])
                for c in range(1, 4) for d in range(1, 4)),
            minus:
            sum(hij[tid(c, d)] * eps[c - 1] * eps[d - 1] for c in range(1, 4)
                for d in range(1, 4))
        }
        self.tensor_to_pol_knl = ElementWiseMap(
            tensor_to_pol_insns,
            tmp_instructions=eps_insns,
            args=args,
            lsize=(32, 1, 1),
            rank_shape=fft.shape(True),
        )

        # hij_ab = plus eps_a eps_b + minus conj(eps_a) conj(eps_b), a <= b
        pol_to_tensor_insns = {
            hij[tid(a, b)]: (plus * eps[a - 1] * eps[b - 1] +
                             minus * conj(eps[a - 1]) * conj(eps[b - 1]))
            for a in range(1, 4) for b in range(a, 4)
        }
        self.pol_to_tensor_knl = ElementWiseMap(
            pol_to_tensor_insns,
            tmp_instructions=eps_insns,
            args=args,
            lsize=(32, 1, 1),
            rank_shape=fft.shape(True),
        )
Esempio n. 14
0
 def is_odd(expr):
     """Return a pymbolic expression equal to 1 if *expr* is odd, else 0.

     The test is ``expr % 2 != 0`` rather than ``expr % 2 == 1``. Under
     truncated (C-style) remainder semantics, a negative odd value has
     remainder -1, so only the ``!= 0`` form is correct under both floor and
     truncated semantics. *expr* is assumed to be integer-valued.
     """
     from pymbolic.primitives import If, Comparison, Remainder
     return If(Comparison(Remainder(expr, 2), "!=", 0), 1, 0)
Esempio n. 15
0
    def parse_postfix(self, pstate, min_precedence, left_exp):
        """Try to extend *left_exp* by one postfix or infix construct.

        Inspects the next token of *pstate*. If it starts a call, subscript,
        attribute lookup, binary/logical/bitwise/shift operator, comparison,
        slice or tuple whose precedence exceeds *min_precedence* (or equals
        it, for slices), the construct and its right-hand operand are
        consumed and combined with *left_exp*.

        :arg pstate: the parser token state; advanced past whatever is
            consumed.
        :arg min_precedence: constructs binding no tighter than this are
            left for an enclosing level to handle.
        :arg left_exp: the expression parsed so far.
        :returns: a tuple ``(left_exp, did_something)``. *did_something* is
            *True* iff a construct was consumed, signalling that the caller
            may try to extend the new *left_exp* further.
        """
        import pymbolic.primitives as primitives

        did_something = False

        next_tag = pstate.next_tag()

        if next_tag is _openpar and _PREC_CALL > min_precedence:
            pstate.advance()
            args, kwargs = self.parse_arglist(pstate)

            if kwargs:
                left_exp = primitives.CallWithKwargs(left_exp, args, kwargs)
            else:
                left_exp = primitives.Call(left_exp, args)

            did_something = True
        elif next_tag is _openbracket and _PREC_CALL > min_precedence:
            pstate.advance()
            pstate.expect_not_end()
            left_exp = primitives.Subscript(left_exp,
                                            self.parse_expression(pstate))
            pstate.expect(_closebracket)
            pstate.advance()
            did_something = True
        elif next_tag is _dot and _PREC_CALL > min_precedence:
            pstate.advance()
            pstate.expect(_identifier)
            left_exp = primitives.Lookup(left_exp, pstate.next_str())
            pstate.advance()
            did_something = True
        elif next_tag is _plus and _PREC_PLUS > min_precedence:
            pstate.advance()
            left_exp += self.parse_expression(pstate, _PREC_PLUS)
            did_something = True
        elif next_tag is _minus and _PREC_PLUS > min_precedence:
            pstate.advance()
            left_exp -= self.parse_expression(pstate, _PREC_PLUS)
            did_something = True
        elif next_tag is _times and _PREC_TIMES > min_precedence:
            pstate.advance()
            left_exp *= self.parse_expression(pstate, _PREC_TIMES)
            did_something = True
        elif next_tag is _floordiv and _PREC_TIMES > min_precedence:
            pstate.advance()
            left_exp //= self.parse_expression(pstate, _PREC_TIMES)
            did_something = True
        elif next_tag is _over and _PREC_TIMES > min_precedence:
            pstate.advance()
            left_exp /= self.parse_expression(pstate, _PREC_TIMES)
            did_something = True
        elif next_tag is _modulo and _PREC_TIMES > min_precedence:
            pstate.advance()
            left_exp %= self.parse_expression(pstate, _PREC_TIMES)
            did_something = True
        elif next_tag is _power and _PREC_POWER > min_precedence:
            pstate.advance()
            left_exp **= self.parse_expression(pstate, _PREC_POWER)
            did_something = True
        elif next_tag is _and and _PREC_LOGICAL_AND > min_precedence:
            pstate.advance()
            left_exp = primitives.LogicalAnd(
                (left_exp, self.parse_expression(pstate, _PREC_LOGICAL_AND)))
            did_something = True
        elif next_tag is _or and _PREC_LOGICAL_OR > min_precedence:
            pstate.advance()
            left_exp = primitives.LogicalOr(
                (left_exp, self.parse_expression(pstate, _PREC_LOGICAL_OR)))
            did_something = True
        elif next_tag is _bitwiseor and _PREC_BITWISE_OR > min_precedence:
            pstate.advance()
            left_exp = primitives.BitwiseOr(
                (left_exp, self.parse_expression(pstate, _PREC_BITWISE_OR)))
            did_something = True
        elif next_tag is _bitwisexor and _PREC_BITWISE_XOR > min_precedence:
            pstate.advance()
            left_exp = primitives.BitwiseXor(
                (left_exp, self.parse_expression(pstate, _PREC_BITWISE_XOR)))
            did_something = True
        elif next_tag is _bitwiseand and _PREC_BITWISE_AND > min_precedence:
            pstate.advance()
            left_exp = primitives.BitwiseAnd(
                (left_exp, self.parse_expression(pstate, _PREC_BITWISE_AND)))
            did_something = True
        elif next_tag is _rightshift and _PREC_SHIFT > min_precedence:
            pstate.advance()
            left_exp = primitives.RightShift(
                left_exp, self.parse_expression(pstate, _PREC_SHIFT))
            did_something = True
        elif next_tag is _leftshift and _PREC_SHIFT > min_precedence:
            pstate.advance()
            left_exp = primitives.LeftShift(
                left_exp, self.parse_expression(pstate, _PREC_SHIFT))
            did_something = True
        elif next_tag in self._COMP_TABLE and _PREC_COMPARISON > min_precedence:
            pstate.advance()
            left_exp = primitives.Comparison(
                left_exp, self._COMP_TABLE[next_tag],
                self.parse_expression(pstate, _PREC_COMPARISON))
            did_something = True
        elif next_tag is _colon and _PREC_SLICE >= min_precedence:
            pstate.advance()
            # Parse the upper bound speculatively on a copy of the state, so
            # that a missing bound (e.g. "a[1:]") does not disturb pstate.
            expr_pstate = pstate.copy()

            assert not isinstance(left_exp, primitives.Slice)

            from pytools.lex import ParseError
            try:
                next_expr = self.parse_expression(expr_pstate, _PREC_SLICE)
            except ParseError:
                # No expression follows the colon: open-ended slice. pstate
                # stays just past the colon.
                left_exp = primitives.Slice((
                    left_exp,
                    None,
                ))
            else:
                left_exp = _join_to_slice(left_exp, next_expr)
                pstate.assign(expr_pstate)

            # Tokens were consumed, so report progress like every other
            # branch. Otherwise a construct following the slice (e.g. the
            # comma in "a[1:2, 3]") would not be picked up, assuming the
            # caller keeps calling this method while did_something is True.
            did_something = True

        elif next_tag is _comma and _PREC_COMMA > min_precedence:
            # The precedence makes the comma left-associative.

            pstate.advance()
            if pstate.is_at_end() or pstate.next_tag() is _closepar:
                # Trailing comma: "x," denotes a one-element tuple.
                left_exp = (left_exp, )
            else:
                new_el = self.parse_expression(pstate, _PREC_COMMA)
                # A plain tuple built by this loop is extended flat
                # ("a, b, c"); a FinalizedTuple (presumably a completed,
                # e.g. parenthesized, tuple) becomes a single nested element.
                if isinstance(left_exp, tuple) \
                        and not isinstance(left_exp, FinalizedTuple):
                    left_exp = left_exp + (new_el, )
                else:
                    left_exp = (left_exp, new_el)

            did_something = True

        return left_exp, did_something