def finish_adaptive(self, cb, high_order_estimate, low_order_estimate):
    """Emit the error-control statements that conclude an adaptive step.

    A scaled error between the two embedded estimates is stored in
    ``rel_error``.  If it exceeds 1 (or is NaN) the step is rejected: ``dt``
    is shrunk, and :class:`TimeStepUnderflow` is raised when ``t + dt == t``;
    otherwise the step fails and is retried.  If the error is acceptable,
    :meth:`finish_nonadaptive` completes the step and ``dt`` is grown.

    :arg cb: the code builder that statements are emitted into.
    :arg high_order_estimate: expression for the higher-order solution.
    :arg low_order_estimate: expression for the lower-order solution.
    """
    from pymbolic import var
    from pymbolic.primitives import Comparison, LogicalOr, Max, Min
    from dagrt.expression import IfThenElse

    start_norm = var("norm_start_state")
    end_norm = var("norm_end_state")
    raw_err = var("rel_error_raw")
    err = var("rel_error")

    norm_2 = var("<builtin>norm_2")
    is_nan = var("<builtin>isnan")

    cb(start_norm, norm_2(self.state))
    cb(end_norm, norm_2(low_order_estimate))

    # Difference of the estimates, divided by sqrt(len(state)) and by the
    # tolerance atol + rtol * max(norm_start_state, norm_end_state).
    numerator = norm_2(high_order_estimate - low_order_estimate)
    denominator = (var("<builtin>len")(self.state)**0.5
                   * (self.atol + self.rtol * Max((start_norm, end_norm))))
    cb(raw_err, numerator / denominator)

    # An exactly-zero error estimate is replaced by 1e-14, so the
    # dt-update powers below never see a zero base.
    cb(err, IfThenElse(Comparison(raw_err, "==", 0), 1.0e-14, raw_err))

    step_rejected = LogicalOr((Comparison(err, ">", 1), is_nan(err)))
    with cb.if_(step_rejected):
        with cb.if_(is_nan(err)):
            cb(self.dt, self.min_dt_shrinkage * self.dt)
        with cb.else_():
            shrunk_dt = 0.9 * self.dt * err**(-1 / self.low_order)
            cb(self.dt, Max((shrunk_dt, self.min_dt_shrinkage * self.dt)))

        with cb.if_(self.t + self.dt, "==", self.t):
            cb.raise_(TimeStepUnderflow)
        with cb.else_():
            cb.fail_step()

    with cb.else_():
        # finish_nonadaptive updates <t>, so <dt> must only be assigned
        # after it has been emitted.
        self.finish_nonadaptive(cb, high_order_estimate, low_order_estimate)
        grown_dt = 0.9 * self.dt * err**(-1 / self.high_order)
        cb(self.dt, Min((grown_dt, self.max_dt_growth * self.dt)))
def map_concatenate(self, expr: Concatenate) -> Array:
    """Lower a :class:`Concatenate` to an :class:`IndexLambda`.

    Input ``i`` is bound as ``_in{i}``.  It occupies the half-open range
    ``[lbounds[i], ubounds[i])`` of the output along ``expr.axis``.  Output
    points in that range read ``_in{i}`` with the axis index shifted down by
    ``lbounds[i]``.
    """
    from pymbolic.primitives import If, Comparison, Subscript

    def get_subscript(array_index: int, offset: ScalarExpression) -> Subscript:
        """Index ``_in{array_index}`` by the output indices, with the
        concatenation axis shifted by *offset*."""
        aggregate = var(f"_in{array_index}")
        index = [
            var(f"_{i}") if i != expr.axis else (var(f"_{i}") - offset)
            for i in range(len(expr.shape))
        ]
        return Subscript(aggregate, tuple(index))

    # Running start/end offsets of each input along the concatenation axis.
    lbounds: List[Any] = [0]
    ubounds: List[Any] = [expr.arrays[0].shape[expr.axis]]
    for i, array in enumerate(expr.arrays[1:], start=1):
        ubounds.append(ubounds[i - 1] + array.shape[expr.axis])
        lbounds.append(ubounds[i - 1])

    # I = axis index
    #
    # => If(_I < ubounds[0],
    #       _in0[_0, _1, ..., _I, ...],
    #       If(_I < ubounds[1],
    #          _in1[_0, _1, ..., _I - lbounds[1], ...],
    #          ...
    #             _inNm1[_0, _1, ...] ...))
    #
    # The lower bound of a range does not need testing.  The chain is
    # evaluated outermost-first, so branch i is only reached once
    # _I >= ubounds[i-1] == lbounds[i].  (Joining two Comparisons with
    # Python ``and`` would test the truthiness of a symbolic object instead
    # of building a conjunction.)
    axis_idx = var(f"_{expr.axis}")
    for i in range(len(expr.arrays) - 1, -1, -1):
        subarray_expr = get_subscript(i, lbounds[i])
        if i == len(expr.arrays) - 1:
            stack_expr = subarray_expr
        else:
            stack_expr = If(Comparison(axis_idx, "<", ubounds[i]),
                            subarray_expr,
                            stack_expr)

    bindings = {
        f"_in{i}": self.rec(array) for i, array in enumerate(expr.arrays)
    }

    return IndexLambda(namespace=self.namespace,
            expr=stack_expr,
            shape=expr.shape,
            dtype=expr.dtype,
            bindings=bindings)
def constraint_to_expr(cns):
    """Turn a constraint into a pymbolic :class:`Comparison` against zero.

    The constraint's affine part (``cns.get_aff()``) is converted with
    :func:`aff_to_expr`.  Equality constraints yield ``aff == 0``; all
    others yield ``aff >= 0``.
    """
    # FIXME: constraints with existentially quantified (div) variables used
    # to be rejected here.  get_aff() seems to cope with them, but this has
    # not been verified.
    aff_expr = aff_to_expr(cns.get_aff())

    from pymbolic.primitives import Comparison

    relation = "==" if cns.is_equality() else ">="
    return Comparison(aff_expr, relation, 0)
def wrap_in_typecast_lazy(self, actual_dtype, needed_dtype, s):
    """Wrap *s* for conversion from *actual_dtype* to *needed_dtype*.

    *actual_dtype* is a zero-argument callable and is only invoked when
    *needed_dtype* is boolean.  Float-to-bool is emitted as ``s != 0``;
    every other case is handled by the base class.
    """
    float_to_bool = (needed_dtype.dtype.kind == "b"
                     and actual_dtype().dtype.kind == "f")
    if not float_to_bool:
        return super().wrap_in_typecast_lazy(actual_dtype, needed_dtype, s)

    # CL does not perform implicit conversion from a float type to bool.
    from pymbolic.primitives import Comparison
    return Comparison(s, "!=", 0)
def parse_postfix(self, pstate, min_precedence, left_exp):
    """Parse one dialect-specific postfix/infix operator after *left_exp*.

    Comparisons (operators looked up in ``self.COMP_MAP``) and logical
    and/or are handled here.  Everything else goes to
    :meth:`ExpressionParserBase.parse_postfix`.

    :raises TranslationError: if a parenthesis follows an expression; the
        parenthesis operator is only supported on names.
    :returns: ``(left_exp, did_something)``.
    """
    from pymbolic.parser import (
        _PREC_CALL, _PREC_COMPARISON, _openpar,
        _PREC_LOGICAL_OR, _PREC_LOGICAL_AND)
    from pymbolic.primitives import Comparison, LogicalAnd, LogicalOr

    tag = pstate.next_tag()

    if tag is _openpar and min_precedence < _PREC_CALL:
        raise TranslationError("parenthesis operator only works on names")

    if tag in self.COMP_MAP and min_precedence < _PREC_COMPARISON:
        pstate.advance()
        operator = self.COMP_MAP[tag]
        rhs = self.parse_expression(pstate, _PREC_COMPARISON)
        return Comparison(left_exp, operator, rhs), True

    if tag is _and and min_precedence < _PREC_LOGICAL_AND:
        pstate.advance()
        rhs = self.parse_expression(pstate, _PREC_LOGICAL_AND)
        return LogicalAnd((left_exp, rhs)), True

    if tag is _or and min_precedence < _PREC_LOGICAL_OR:
        pstate.advance()
        rhs = self.parse_expression(pstate, _PREC_LOGICAL_OR)
        return LogicalOr((left_exp, rhs)), True

    new_exp, did_something = ExpressionParserBase.parse_postfix(
        self, pstate, min_precedence, left_exp)
    return new_exp, did_something
def map_stack(self, expr: Stack) -> Array:
    """Lower a :class:`Stack` to an :class:`IndexLambda`.

    The output index along ``expr.axis`` selects which input ``_in{i}`` is
    read.  All other output indices are passed through as that input's
    subscript.
    """
    from pymbolic.primitives import If, Comparison

    # Every input is read with the same subscript: all output indices
    # except the stacking axis.
    in_index = tuple(var(f"_{dim}")
                     for dim in range(expr.ndim) if dim != expr.axis)
    stack_index = var(f"_{expr.axis}")

    # I = axis index
    #
    # => If(_I == 0,
    #        _in0[_0, _1, ...],
    #        If(_I == 1,
    #            _in1[_0, _1, ...],
    #            ...
    #                _inNm1[_0, _1, ...] ...))
    last = len(expr.arrays) - 1
    for idx in range(last, -1, -1):
        operand = var(f"_in{idx}")[in_index]
        if idx == last:
            stack_expr = operand
        else:
            stack_expr = If(Comparison(stack_index, "==", idx),
                            operand, stack_expr)

    bindings = {
        f"_in{idx}": self.rec(array) for idx, array in enumerate(expr.arrays)
    }

    return IndexLambda(namespace=self.namespace,
            expr=stack_expr,
            shape=expr.shape,
            dtype=expr.dtype,
            bindings=bindings)
def if_(self, *condition_arg):
    """Create a new block that is conditionally executed.

    Call it with either of:

    * one argument: a condition expression, or a string that is parsed;
    * three arguments ``lhs, operator, rhs``: these form a
      :class:`~pymbolic.primitives.Comparison`, and string operands are
      parsed.

    The condition is assigned to a fresh ``<cond>`` variable.  That variable
    sits on the conditional-expression stack while the body of the block
    runs, and is then recorded as the last if-block condition.  The
    generator yields exactly once and is used as ``with builder.if_(...):``.

    :raises ValueError: for any other number of arguments.
    """
    from dagrt.expression import parse

    if len(condition_arg) == 1:
        condition = condition_arg[0]
        if isinstance(condition, str):
            condition = parse(condition)
    elif len(condition_arg) == 3:
        lhs, cond, rhs = condition_arg
        if isinstance(lhs, str):
            lhs = parse(lhs)
        if isinstance(rhs, str):
            rhs = parse(rhs)
        from pymbolic.primitives import Comparison
        condition = Comparison(lhs, cond, rhs)
    else:
        raise ValueError("Unrecognized condition expression")

    # Create a statement as a lead statement to assign a logical flag.
    cond_var = self.fresh_var("<cond>")
    cond_assignment = Assign(
        assignee=cond_var.name,
        assignee_subscript=(),
        expression=condition)

    self._add_statement(cond_assignment)
    self._conditional_expression_stack.append(cond_var)
    try:
        yield
    finally:
        # Keep the stack balanced even if the block body raises.
        self._conditional_expression_stack.pop()
    self._last_if_block_conditional_expression = cond_var
def emit_sequential_loop(self, codegen_state, iname, iname_dtype,
        lbound, ubound, inner):
    """Emit a ``for`` loop running *iname* from *lbound* to *ubound* inclusive.

    The loop variable is declared in the initializer with type
    *iname_dtype*, tested with ``iname <= ubound``, and pre-incremented.
    *inner* becomes the loop body.
    """
    from pymbolic import var
    from pymbolic.primitives import Comparison
    from pymbolic.mapper.stringifier import PREC_NONE
    from cgen import For, InlineInitializer

    to_code = codegen_state.expression_to_code_mapper

    initializer = InlineInitializer(
        POD(self, iname_dtype, iname),
        to_code(lbound, PREC_NONE, "i"))
    loop_condition = to_code(
        Comparison(var(iname), "<=", ubound), PREC_NONE, "i")
    increment = "++%s" % iname

    return For(initializer, loop_condition, increment, inner)
def make_spectra_knl(self, is_real, rank_shape):
    """Build the :class:`Histogrammer` that bins ``|fk|**2`` by ``|k|``.

    Each mode ``(i, j, k)`` goes into bin ``round(|k| / self.bin_width)``.
    ``|k|`` is computed from ``self.dk`` and the ``momenta_x/y/z`` arrays.
    The mode's weight is ``count * |k|**k_power * abs(fk)**2``.

    When *is_real*, modes with ``0 < momenta_z[k] < grid_shape[-1] / 2`` get
    ``count = 2`` and all others ``count = 1``.  Otherwise ``count`` is 1.
    """
    from pymbolic import var, parse

    indices = parse("i, j, k")
    i, j, k = indices

    momenta = [var("momenta_" + axis) for axis in ("x", "y", "z")]
    k_squared = sum((dk_mu * mom[idx])**2
                    for mom, dk_mu, idx in zip(momenta, self.dk, indices))
    kmag = var("sqrt")(k_squared)
    bin_expr = var("round")(kmag / self.bin_width)

    count = 1
    if is_real:
        from pymbolic.primitives import If, Comparison, LogicalAnd
        kz = momenta[2][k]
        nyquist = self.grid_shape[-1] / 2
        interior = LogicalAnd((Comparison(kz, ">", 0),
                               Comparison(kz, "<", nyquist)))
        count = If(interior, 2, 1)

    amplitude = var("fk")[i, j, k]
    weight_expr = count * kmag**(var("k_power")) * var("abs")(amplitude)**2

    histograms = {"spectrum": (bin_expr, weight_expr)}

    args = [
        lp.GlobalArg("fk", self.cdtype, shape=("Nx", "Ny", "Nz"),
                     offset=lp.auto),
        lp.GlobalArg("momenta_x", self.rdtype, shape=("Nx",)),
        lp.GlobalArg("momenta_y", self.rdtype, shape=("Ny",)),
        lp.GlobalArg("momenta_z", self.rdtype, shape=("Nz",)),
        lp.ValueArg("k_power", self.rdtype),
        ...
    ]

    from pystella.histogram import Histogrammer
    return Histogrammer(self.decomp, histograms, self.num_bins,
                        self.rdtype, args=args, rank_shape=rank_shape)
def parse_postfix(self, pstate, min_precedence, left_exp):
    """Parse one dialect-specific postfix/infix operator after *left_exp*.

    Comparisons (``self.COMP_MAP``) and logical and/or are handled here;
    everything else goes to :meth:`ExpressionParserBase.parse_postfix`.
    A resulting two-element tuple, outside an argument list, is then folded
    into a complex literal.  The literal is ``complex64`` when both parts
    promote to ``float32`` and ``complex128`` otherwise.

    :raises TranslationError: if a parenthesis follows an expression, or if
        a complex literal does not have exactly two entries.
    :returns: ``(left_exp, did_something)``.
    """
    from pymbolic.parser import (_PREC_CALL, _PREC_COMPARISON, _openpar,
            _PREC_LOGICAL_OR, _PREC_LOGICAL_AND)
    from pymbolic.primitives import (Comparison, LogicalAnd, LogicalOr)

    next_tag = pstate.next_tag()
    if next_tag is _openpar and _PREC_CALL > min_precedence:
        raise TranslationError("parenthesis operator only works on names")
    elif next_tag in self.COMP_MAP and _PREC_COMPARISON > min_precedence:
        pstate.advance()
        left_exp = Comparison(
                left_exp,
                self.COMP_MAP[next_tag],
                self.parse_expression(pstate, _PREC_COMPARISON))
        did_something = True
    elif next_tag is _and and _PREC_LOGICAL_AND > min_precedence:
        pstate.advance()
        left_exp = LogicalAnd(
                (left_exp, self.parse_expression(pstate, _PREC_LOGICAL_AND)))
        did_something = True
    elif next_tag is _or and _PREC_LOGICAL_OR > min_precedence:
        pstate.advance()
        left_exp = LogicalOr(
                (left_exp, self.parse_expression(pstate, _PREC_LOGICAL_OR)))
        did_something = True
    else:
        left_exp, did_something = ExpressionParserBase.parse_postfix(
                self, pstate, min_precedence, left_exp)

    if isinstance(left_exp, tuple) and min_precedence < self._PREC_FUNC_ARGS:
        # this must be a complex literal
        if len(left_exp) != 2:
            raise TranslationError("complex literals must have "
                    "two entries")

        r, i = left_exp

        # Promote both parts to a common real type.  The promoted scalar's
        # *dtype* decides the precision.  Comparing the scalar value itself
        # to np.float32 is always False, which would force complex128.
        real_dtype = (r.dtype.type(0) + i.dtype.type(0)).dtype
        if real_dtype == np.float32:
            dtype = np.complex64
        else:
            dtype = np.complex128

        left_exp = dtype(float(r) + float(i) * 1j)

    return left_exp, did_something
def __init__(self, fft, dk, dx, effective_k):
    """Set up a Fourier-space solve for ``fk`` given the source ``rhok``.

    For each mode the kernel computes ``K = k_x[i] + k_y[j] + k_z[k]`` and
    sets ``fk = (rhok / grid_size) / (K - m_squared)`` where ``K < 0``, and
    ``fk = 0`` elsewhere.

    :arg fft: FFT object.  The ``sub_k`` dict supplies the per-axis integer
        mode numbers and the queue (via ``"momenta_x"``).  ``grid_shape``,
        ``rdtype`` and ``cdtype`` are also read.
    :arg dk: per-axis momentum spacing, indexed by axis.
    :arg dx: per-axis grid spacing, passed to *effective_k*.
    :arg effective_k: callable ``(k, dx)`` whose result is stored as
        ``k_x``/``k_y``/``k_z``.  These values are summed into
        ``minus_k_squared``, so they presumably are the (negative)
        eigenvalues of a 1D second-derivative stencil.  TODO confirm.
    """
    self.fft = fft
    grid_size = fft.grid_shape[0] * fft.grid_shape[1] * fft.grid_shape[2]

    queue = self.fft.sub_k["momenta_x"].queue
    sub_k = [x.get().astype("int") for x in self.fft.sub_k.values()]
    k_names = ("k_x", "k_y", "k_z")
    self.momenta = {}
    for mu, (name, kk) in enumerate(zip(k_names, sub_k)):
        kk_mu = effective_k(dk[mu] * kk.astype(fft.rdtype), dx[mu])
        self.momenta[name] = cla.to_device(queue, kk_mu)

    args = [
        lp.GlobalArg("fk", fft.cdtype, shape="(Nx, Ny, Nz)"),
        lp.GlobalArg("k_x", fft.rdtype, shape=("Nx", )),
        lp.GlobalArg("k_y", fft.rdtype, shape=("Ny", )),
        lp.GlobalArg("k_z", fft.rdtype, shape=("Nz", )),
        lp.ValueArg("m_squared", fft.rdtype),
    ]

    from pystella.field import Field
    from pymbolic.primitives import Variable, If, Comparison

    fk = Field("fk")
    indices = fk.indices

    # Normalize the source by the number of grid points before dividing.
    rho_tmp = Variable("rho_tmp")
    tmp_insns = [(rho_tmp, Field("rhok") * (1 / grid_size))]

    mom_vars = tuple(Variable(name) for name in k_names)
    minus_k_squared = sum(kk_i[x_i] for kk_i, x_i in zip(mom_vars, indices))
    sol = rho_tmp / (minus_k_squared - Variable("m_squared"))

    # Modes with minus_k_squared >= 0 (e.g. the zero mode) are set to 0.
    solution = {
        Field("fk"): If(Comparison(minus_k_squared, "<", 0), sol, 0)
    }

    from pystella.elementwise import ElementWiseMap
    options = lp.Options(return_dict=True)
    self.knl = ElementWiseMap(solution, args=args, halo_shape=0,
                              options=options, tmp_instructions=tmp_insns,
                              lsize=(16, 2, 1))
def parse_postfix(self, pstate, min_precedence, left_exp):
    """Extend *left_exp* by at most one postfix or binary operator.

    An operator is consumed only if its precedence exceeds
    *min_precedence*.  The right operand is parsed at that operator's own
    precedence, which makes it left-associative; ``**`` and ``if/else``
    are the exceptions and bind to the right.  Repeated ``+``/``-`` and
    ``*`` extend an existing :class:`Sum`/:class:`Product` instead of
    nesting it.

    :returns: ``(left_exp, did_something)``; *did_something* tells the
        caller whether to try parsing another operator.
    """
    import pymbolic.primitives as primitives

    did_something = False

    next_tag = pstate.next_tag()

    if next_tag is _openpar and _PREC_CALL > min_precedence:
        pstate.advance()
        args, kwargs = self.parse_arglist(pstate)

        if kwargs:
            left_exp = primitives.CallWithKwargs(left_exp, args, kwargs)
        else:
            left_exp = primitives.Call(left_exp, args)

        did_something = True
    elif next_tag is _openbracket and _PREC_CALL > min_precedence:
        pstate.advance()
        pstate.expect_not_end()
        left_exp = primitives.Subscript(left_exp, self.parse_expression(pstate))
        pstate.expect(_closebracket)
        pstate.advance()
        did_something = True
    elif next_tag is _if and _PREC_IF > min_precedence:
        from pymbolic.primitives import If
        then_expr = left_exp
        pstate.advance()
        pstate.expect_not_end()
        condition = self.parse_expression(pstate, _PREC_LOGICAL_OR)
        pstate.expect(_else)
        pstate.advance()
        else_expr = self.parse_expression(pstate)
        left_exp = If(condition, then_expr, else_expr)
        did_something = True
    elif next_tag is _dot and _PREC_CALL > min_precedence:
        pstate.advance()
        pstate.expect(_identifier)
        left_exp = primitives.Lookup(left_exp, pstate.next_str())
        pstate.advance()
        did_something = True
    elif next_tag is _plus and _PREC_PLUS > min_precedence:
        pstate.advance()
        right_exp = self.parse_expression(pstate, _PREC_PLUS)
        if isinstance(left_exp, primitives.Sum):
            left_exp = primitives.Sum(left_exp.children + (right_exp,))
        else:
            left_exp = primitives.Sum((left_exp, right_exp))

        did_something = True
    elif next_tag is _minus and _PREC_PLUS > min_precedence:
        pstate.advance()
        right_exp = self.parse_expression(pstate, _PREC_PLUS)
        if isinstance(left_exp, primitives.Sum):
            left_exp = primitives.Sum(left_exp.children + ((-right_exp),))  # noqa pylint:disable=invalid-unary-operand-type
        else:
            left_exp = primitives.Sum((left_exp, -right_exp))  # noqa pylint:disable=invalid-unary-operand-type

        did_something = True
    elif next_tag is _times and _PREC_TIMES > min_precedence:
        pstate.advance()
        # Parse the right operand at _PREC_TIMES, like //, / and %.  Then
        # "a*b % c" is (a*b) % c and "a*b*c" flattens into a single Product.
        right_exp = self.parse_expression(pstate, _PREC_TIMES)
        if isinstance(left_exp, primitives.Product):
            left_exp = primitives.Product(left_exp.children + (right_exp,))
        else:
            left_exp = primitives.Product((left_exp, right_exp))

        did_something = True
    elif next_tag is _floordiv and _PREC_TIMES > min_precedence:
        pstate.advance()
        left_exp = primitives.FloorDiv(
                left_exp, self.parse_expression(pstate, _PREC_TIMES))
        did_something = True
    elif next_tag is _over and _PREC_TIMES > min_precedence:
        pstate.advance()
        left_exp = primitives.Quotient(
                left_exp, self.parse_expression(pstate, _PREC_TIMES))
        did_something = True
    elif next_tag is _modulo and _PREC_TIMES > min_precedence:
        pstate.advance()
        left_exp = primitives.Remainder(
                left_exp, self.parse_expression(pstate, _PREC_TIMES))
        did_something = True
    elif next_tag is _power and _PREC_POWER > min_precedence:
        pstate.advance()
        # The exponent is parsed at _PREC_TIMES so that a following ** is
        # absorbed into it, i.e. ** is right-associative.
        left_exp = primitives.Power(
                left_exp, self.parse_expression(pstate, _PREC_TIMES))
        did_something = True
    elif next_tag is _and and _PREC_LOGICAL_AND > min_precedence:
        pstate.advance()
        from pymbolic.primitives import LogicalAnd
        left_exp = LogicalAnd((
                left_exp,
                self.parse_expression(pstate, _PREC_LOGICAL_AND)))
        did_something = True
    elif next_tag is _or and _PREC_LOGICAL_OR > min_precedence:
        pstate.advance()
        from pymbolic.primitives import LogicalOr
        left_exp = LogicalOr((
                left_exp,
                self.parse_expression(pstate, _PREC_LOGICAL_OR)))
        did_something = True
    elif next_tag is _bitwiseor and _PREC_BITWISE_OR > min_precedence:
        pstate.advance()
        from pymbolic.primitives import BitwiseOr
        left_exp = BitwiseOr((
                left_exp,
                self.parse_expression(pstate, _PREC_BITWISE_OR)))
        did_something = True
    elif next_tag is _bitwisexor and _PREC_BITWISE_XOR > min_precedence:
        pstate.advance()
        from pymbolic.primitives import BitwiseXor
        left_exp = BitwiseXor((
                left_exp,
                self.parse_expression(pstate, _PREC_BITWISE_XOR)))
        did_something = True
    elif next_tag is _bitwiseand and _PREC_BITWISE_AND > min_precedence:
        pstate.advance()
        from pymbolic.primitives import BitwiseAnd
        left_exp = BitwiseAnd((
                left_exp,
                self.parse_expression(pstate, _PREC_BITWISE_AND)))
        did_something = True
    elif next_tag is _rightshift and _PREC_SHIFT > min_precedence:
        pstate.advance()
        from pymbolic.primitives import RightShift
        left_exp = RightShift(
                left_exp,
                self.parse_expression(pstate, _PREC_SHIFT))
        did_something = True
    elif next_tag is _leftshift and _PREC_SHIFT > min_precedence:
        pstate.advance()
        from pymbolic.primitives import LeftShift
        left_exp = LeftShift(
                left_exp,
                self.parse_expression(pstate, _PREC_SHIFT))
        did_something = True
    elif next_tag in self._COMP_TABLE and _PREC_COMPARISON > min_precedence:
        pstate.advance()
        from pymbolic.primitives import Comparison
        left_exp = Comparison(
                left_exp,
                self._COMP_TABLE[next_tag],
                self.parse_expression(pstate, _PREC_COMPARISON))
        did_something = True
    elif next_tag is _colon and _PREC_SLICE >= min_precedence:
        pstate.advance()
        expr_pstate = pstate.copy()

        assert not isinstance(left_exp, primitives.Slice)

        from pytools.lex import ParseError
        try:
            next_expr = self.parse_expression(expr_pstate, _PREC_SLICE)
        except ParseError:
            # no expression follows, too bad.
            left_exp = primitives.Slice((left_exp, None,))
        else:
            left_exp = _join_to_slice(left_exp, next_expr)
            pstate.assign(expr_pstate)

        did_something = True

    elif next_tag is _comma and _PREC_COMMA > min_precedence:
        # The precedence makes the comma left-associative.

        pstate.advance()
        if pstate.is_at_end() or pstate.next_tag() is _closepar:
            if isinstance(left_exp, (tuple, list)) \
                    and not isinstance(left_exp, FinalizedContainer):
                # left_expr is a container with trailing commas
                pass
            else:
                left_exp = (left_exp,)
        else:
            new_el = self.parse_expression(pstate, _PREC_COMMA)
            if isinstance(left_exp, (tuple, list)) \
                    and not isinstance(left_exp, FinalizedContainer):
                left_exp = left_exp + (new_el,)
            else:
                left_exp = (left_exp, new_el)

        did_something = True

    return left_exp, did_something
def __init__(self, fft, effective_k, dk, dx):
    """Build the k-space projection kernels for vector and tensor fields.

    :arg fft: FFT object.  Its ``sub_k`` dict of device arrays supplies the
        per-axis integer mode numbers and the queue (via ``"momenta_x"``).
        ``grid_shape``, ``rdtype``, ``cdtype`` and ``shape(True)`` are also
        read.
    :arg effective_k: one of:

        * a callable ``(k, dx) -> k_eff`` giving the effective momentum;
        * a nonzero value ``h``, in which case
          ``FirstCenteredDifference(h).get_eigenvalues`` is used;
        * ``0``, in which case ``k`` is used unchanged.
    :arg dk: per-axis momentum spacing, indexed by axis ``mu``.
    :arg dx: per-axis grid spacing, indexed by axis ``mu``.

    Sets ``self.eff_mom`` (device arrays ``eff_mom_x``, ``eff_mom_y`` and
    ``eff_mom_z``) and the :class:`ElementWiseMap` kernels
    ``transversify_knl``, ``vec_to_pol_knl``, ``pol_to_vec_knl``,
    ``vec_decomp_knl``, ``vec_decomp_knl_times_abs_k``, ``tt_knl``,
    ``tensor_to_pol_knl`` and ``pol_to_tensor_knl``.
    """
    self.fft = fft

    # Normalize effective_k into a callable (k, dx) -> effective momentum.
    if not callable(effective_k):
        if effective_k != 0:
            from pystella.derivs import FirstCenteredDifference
            h = effective_k
            effective_k = FirstCenteredDifference(h).get_eigenvalues
        else:
            def effective_k(k, dx):  # pylint: disable=function-redefined
                return k

    queue = self.fft.sub_k["momenta_x"].queue
    sub_k = list(x.get().astype("int") for x in self.fft.sub_k.values())
    eff_mom_names = ("eff_mom_x", "eff_mom_y", "eff_mom_z")
    self.eff_mom = {}
    for mu, (name, kk) in enumerate(zip(eff_mom_names, sub_k)):
        eff_k = effective_k(dk[mu] * kk.astype(fft.rdtype), dx[mu])
        # Along this axis, zero the effective momentum of the mode with
        # |n| == N // 2 (the Nyquist mode) and of the n == 0 mode.
        eff_k[abs(sub_k[mu]) == fft.grid_shape[mu] // 2] = 0.
        eff_k[sub_k[mu] == 0] = 0.

        import pyopencl.array as cla
        self.eff_mom[name] = cla.to_device(queue, eff_k)

    from pymbolic import var, parse
    from pymbolic.primitives import If, Comparison, LogicalAnd
    from pystella import Field

    # From here on eff_k is rebound: it is the symbolic tuple
    # (eff_mom_x[i], eff_mom_y[j], eff_mom_z[k]), not a numpy array.
    indices = parse("i, j, k")
    eff_k = tuple(
        var(array)[mu]
        for array, mu in zip(eff_mom_names, indices))
    fabs, sqrt, conj = parse("fabs, sqrt, conj")
    # Inline |k| expression; used only by transversify_knl (rebound below).
    kmag = sqrt(sum(kk**2 for kk in eff_k))

    from pystella import ElementWiseMap
    vector = Field("vector", shape=(3, ))
    vector_T = Field("vector_T", shape=(3, ))

    # True where every effective momentum component satisfies |k_mu| < 1e-14.
    kvec_zero = LogicalAnd(
        tuple(Comparison(fabs(eff_k[mu]), "<", 1e-14) for mu in range(3)))

    # note: write all output via private temporaries to allow for in-place

    # div = k . vector
    div = var("div")
    div_insn = [(div, sum(eff_k[mu] * vector[mu] for mu in range(3)))]
    # vector_T = vector - k (k . vector) / |k|^2, or 0 where kvec_zero.
    self.transversify_knl = ElementWiseMap(
        {
            vector_T[mu]: If(kvec_zero, 0,
                             vector[mu] - eff_k[mu] / kmag**2 * div)
            for mu in range(3)
        },
        tmp_instructions=div_insn,
        lsize=(32, 1, 1),
        rank_shape=fft.shape(True),
    )

    import loopy as lp

    # loopy Assignment over inames i, j, k with no_sync_with=[("*", "any")];
    # extra kwargs override these defaults.
    def assign(asignee, expr, **kwargs):
        default = dict(within_inames=frozenset(("i", "j", "k")),
                       no_sync_with=[("*", "any")])
        default.update(kwargs)
        return lp.Assignment(asignee, expr, **default)

    # kmag and Kappa are now plain variables (declared as temporaries in
    # args below) assigned by eps_insns: kmag = |k| and Kappa = |(k_x, k_y)|.
    kmag, Kappa = parse("kmag, Kappa")
    eps_insns = [
        assign(kmag, sqrt(sum(kk**2 for kk in eff_k))),
        assign(Kappa, sqrt(sum(kk**2 for kk in eff_k[:2])))
    ]

    zero = fft.cdtype.type(0)
    kx_ky_zero = LogicalAnd(
        tuple(Comparison(fabs(eff_k[mu]), "<", 1e-10) for mu in range(2)))
    kz_nonzero = Comparison(fabs(eff_k[2]), ">", 1e-10)

    # Polarization vector eps[0..2].  Where k_x = k_y = 0 it falls back to
    # (1, i, 0) / sqrt(2) if k_z != 0, and to zero otherwise.
    eps = var("eps")
    eps_insns.extend([
        assign(
            eps[0],
            If(kx_ky_zero,
               If(kz_nonzero, fft.cdtype.type(1 / 2**.5), zero),
               (eff_k[0] * eff_k[2] / kmag - 1j * eff_k[1])
               / Kappa / 2**.5)),
        assign(
            eps[1],
            If(kx_ky_zero,
               If(kz_nonzero, fft.cdtype.type(1j / 2**(1 / 2)), zero),
               (eff_k[1] * eff_k[2] / kmag + 1j * eff_k[0])
               / Kappa / 2**.5)),
        assign(eps[2], If(kx_ky_zero, zero, -Kappa / kmag / 2**.5))
    ])

    plus, minus, lng = Field("plus"), Field("minus"), Field("lng")

    # plus = sum_mu vector_mu conj(eps_mu); minus = sum_mu vector_mu eps_mu
    plus_tmp, minus_tmp = parse("plus_tmp, minus_tmp")
    pol_isns = [(plus_tmp, sum(vector[mu] * conj(eps[mu]) for mu in range(3))),
                (minus_tmp, sum(vector[mu] * eps[mu] for mu in range(3)))]

    args = [
        lp.TemporaryVariable("kmag"),
        lp.TemporaryVariable("Kappa"),
        lp.TemporaryVariable("eps", shape=(3, )),
        ...
    ]

    self.vec_to_pol_knl = ElementWiseMap(
        {
            plus: plus_tmp,
            minus: minus_tmp
        },
        tmp_instructions=eps_insns + pol_isns,
        args=args,
        lsize=(32, 1, 1),
        rank_shape=fft.shape(True),
    )

    # vector = plus * eps + minus * conj(eps)
    vector_tmp = var("vector_tmp")
    vec_insns = [(vector_tmp[mu], plus * eps[mu] + minus * conj(eps[mu]))
                 for mu in range(3)]
    self.pol_to_vec_knl = ElementWiseMap(
        {vector[mu]: vector_tmp[mu] for mu in range(3)},
        tmp_instructions=eps_insns + vec_insns,
        args=args,
        lsize=(32, 1, 1),
        rank_shape=fft.shape(True),
    )

    # Longitudinal part: lng = -i (k . vector) / |k|^2, or 0 where kvec_zero.
    ksq = sum(kk**2 for kk in eff_k)
    lng_rhs = If(kvec_zero, 0, -div / ksq * 1j)
    self.vec_decomp_knl = ElementWiseMap(
        {
            plus: plus_tmp,
            minus: minus_tmp,
            lng: lng_rhs
        },
        tmp_instructions=eps_insns + pol_isns + div_insn,
        args=args,
        lsize=(32, 1, 1),
        rank_shape=fft.shape(True),
    )
    # Same as above, but lng = -i (k . vector) / |k| (multiplied by |k|).
    lng_rhs = If(kvec_zero, 0, -div / ksq**.5 * 1j)
    self.vec_decomp_knl_times_abs_k = ElementWiseMap(
        {
            plus: plus_tmp,
            minus: minus_tmp,
            lng: lng_rhs
        },
        tmp_instructions=eps_insns + pol_isns + div_insn,
        args=args,
        lsize=(32, 1, 1),
        rank_shape=fft.shape(True),
    )

    # tid(a, b) for 1-based a, b indexes the 6-component fields below;
    # presumably a packed symmetric-tensor index (TODO confirm in
    # pystella.sectors).
    from pystella.sectors import tensor_index as tid

    # Unit momentum vector.  It divides by |k|, which is zero at k == 0;
    # that case is masked by kvec_zero in write_insns.
    eff_k_hat = tuple(kk / sqrt(sum(kk**2 for kk in eff_k)) for kk in eff_k)
    hij = Field("hij", shape=(6, ))
    hij_TT = Field("hij_TT", shape=(6, ))

    # P_ab = delta_ab - khat_a khat_b, stored for a <= b.
    Pab = var("P")
    Pab_insns = [
        (Pab[tid(a, b)],
         (If(Comparison(a, "==", b), 1, 0)
          - eff_k_hat[a - 1] * eff_k_hat[b - 1]))
        for a in range(1, 4) for b in range(a, 4)
    ]

    # hij_TT_ab = sum_cd (P_ac P_db - P_ab P_cd / 2) h_cd
    hij_TT_tmp = var("hij_TT_tmp")
    TT_insns = [
        (hij_TT_tmp[tid(a, b)],
         sum((Pab[tid(a, c)] * Pab[tid(d, b)]
              - Pab[tid(a, b)] * Pab[tid(c, d)] / 2) * hij[tid(c, d)]
             for c in range(1, 4) for d in range(1, 4)))
        for a in range(1, 4) for b in range(a, 4)
    ]
    # note: where conditionals (branch divergence) go can matter:
    # this kernel is twice as fast when putting the branching in the global
    # write, rather than when setting hij_TT_tmp
    write_insns = [(hij_TT[tid(a, b)], If(kvec_zero, 0, hij_TT_tmp[tid(a, b)]))
                   for a in range(1, 4) for b in range(a, 4)]
    self.tt_knl = ElementWiseMap(
        write_insns,
        tmp_instructions=Pab_insns + TT_insns,
        lsize=(32, 1, 1),
        rank_shape=fft.shape(True),
    )

    # plus = sum_cd h_cd conj(eps_c) conj(eps_d);
    # minus = sum_cd h_cd eps_c eps_d
    tensor_to_pol_insns = {
        plus: sum(hij[tid(c, d)] * conj(eps[c - 1]) * conj(eps[d - 1])
                  for c in range(1, 4) for d in range(1, 4)),
        minus: sum(hij[tid(c, d)] * eps[c - 1] * eps[d - 1]
                   for c in range(1, 4) for d in range(1, 4))
    }
    self.tensor_to_pol_knl = ElementWiseMap(
        tensor_to_pol_insns,
        tmp_instructions=eps_insns,
        args=args,
        lsize=(32, 1, 1),
        rank_shape=fft.shape(True),
    )

    # h_ab = plus eps_a eps_b + minus conj(eps_a) conj(eps_b), for a <= b
    pol_to_tensor_insns = {
        hij[tid(a, b)]: (plus * eps[a - 1] * eps[b - 1]
                         + minus * conj(eps[a - 1]) * conj(eps[b - 1]))
        for a in range(1, 4) for b in range(a, 4)
    }
    self.pol_to_tensor_knl = ElementWiseMap(
        pol_to_tensor_insns,
        tmp_instructions=eps_insns,
        args=args,
        lsize=(32, 1, 1),
        rank_shape=fft.shape(True),
    )
def is_odd(expr):
    """Return a symbolic expression that evaluates to 1 if *expr* is odd, else 0.

    The test is ``expr % 2 != 0`` rather than ``expr % 2 == 1``.  That way
    it also holds for negative odd values when the target language's
    remainder takes the sign of the dividend (C's ``%``: ``-3 % 2 == -1``).
    For integer inputs both forms agree under floor-mod (Python) semantics.
    """
    from pymbolic.primitives import If, Comparison, Remainder
    return If(Comparison(Remainder(expr, 2), "!=", 0), 1, 0)
def parse_postfix(self, pstate, min_precedence, left_exp):
    """Extend *left_exp* by at most one postfix or binary operator.

    An operator is consumed only if its precedence exceeds *min_precedence*
    (for ``:``, if it is at least *min_precedence*).  Arithmetic operators
    build their result through operator overloading on *left_exp*.

    :returns: ``(left_exp, did_something)``; *did_something* tells the
        caller whether to try parsing another operator.
    """
    import pymbolic.primitives as primitives

    did_something = False

    next_tag = pstate.next_tag()

    if next_tag is _openpar and _PREC_CALL > min_precedence:
        pstate.advance()
        args, kwargs = self.parse_arglist(pstate)

        if kwargs:
            left_exp = primitives.CallWithKwargs(left_exp, args, kwargs)
        else:
            left_exp = primitives.Call(left_exp, args)

        did_something = True
    elif next_tag is _openbracket and _PREC_CALL > min_precedence:
        pstate.advance()
        pstate.expect_not_end()
        left_exp = primitives.Subscript(left_exp, self.parse_expression(pstate))
        pstate.expect(_closebracket)
        pstate.advance()
        did_something = True
    elif next_tag is _dot and _PREC_CALL > min_precedence:
        pstate.advance()
        pstate.expect(_identifier)
        left_exp = primitives.Lookup(left_exp, pstate.next_str())
        pstate.advance()
        did_something = True
    elif next_tag is _plus and _PREC_PLUS > min_precedence:
        pstate.advance()
        left_exp += self.parse_expression(pstate, _PREC_PLUS)
        did_something = True
    elif next_tag is _minus and _PREC_PLUS > min_precedence:
        pstate.advance()
        left_exp -= self.parse_expression(pstate, _PREC_PLUS)
        did_something = True
    elif next_tag is _times and _PREC_TIMES > min_precedence:
        pstate.advance()
        left_exp *= self.parse_expression(pstate, _PREC_TIMES)
        did_something = True
    elif next_tag is _floordiv and _PREC_TIMES > min_precedence:
        pstate.advance()
        left_exp //= self.parse_expression(pstate, _PREC_TIMES)
        did_something = True
    elif next_tag is _over and _PREC_TIMES > min_precedence:
        pstate.advance()
        left_exp /= self.parse_expression(pstate, _PREC_TIMES)
        did_something = True
    elif next_tag is _modulo and _PREC_TIMES > min_precedence:
        pstate.advance()
        left_exp %= self.parse_expression(pstate, _PREC_TIMES)
        did_something = True
    elif next_tag is _power and _PREC_POWER > min_precedence:
        pstate.advance()
        # Parse the exponent at _PREC_TIMES so that a following ** is
        # absorbed into it: a**b**c == a**(b**c), as in Python.
        left_exp **= self.parse_expression(pstate, _PREC_TIMES)
        did_something = True
    elif next_tag is _and and _PREC_LOGICAL_AND > min_precedence:
        pstate.advance()
        from pymbolic.primitives import LogicalAnd
        left_exp = LogicalAnd(
            (left_exp, self.parse_expression(pstate, _PREC_LOGICAL_AND)))
        did_something = True
    elif next_tag is _or and _PREC_LOGICAL_OR > min_precedence:
        pstate.advance()
        from pymbolic.primitives import LogicalOr
        left_exp = LogicalOr(
            (left_exp, self.parse_expression(pstate, _PREC_LOGICAL_OR)))
        did_something = True
    elif next_tag is _bitwiseor and _PREC_BITWISE_OR > min_precedence:
        pstate.advance()
        from pymbolic.primitives import BitwiseOr
        left_exp = BitwiseOr(
            (left_exp, self.parse_expression(pstate, _PREC_BITWISE_OR)))
        did_something = True
    elif next_tag is _bitwisexor and _PREC_BITWISE_XOR > min_precedence:
        pstate.advance()
        from pymbolic.primitives import BitwiseXor
        left_exp = BitwiseXor(
            (left_exp, self.parse_expression(pstate, _PREC_BITWISE_XOR)))
        did_something = True
    elif next_tag is _bitwiseand and _PREC_BITWISE_AND > min_precedence:
        pstate.advance()
        from pymbolic.primitives import BitwiseAnd
        left_exp = BitwiseAnd(
            (left_exp, self.parse_expression(pstate, _PREC_BITWISE_AND)))
        did_something = True
    elif next_tag is _rightshift and _PREC_SHIFT > min_precedence:
        pstate.advance()
        from pymbolic.primitives import RightShift
        left_exp = RightShift(left_exp, self.parse_expression(pstate, _PREC_SHIFT))
        did_something = True
    elif next_tag is _leftshift and _PREC_SHIFT > min_precedence:
        pstate.advance()
        from pymbolic.primitives import LeftShift
        left_exp = LeftShift(left_exp, self.parse_expression(pstate, _PREC_SHIFT))
        did_something = True
    elif next_tag in self._COMP_TABLE and _PREC_COMPARISON > min_precedence:
        pstate.advance()
        from pymbolic.primitives import Comparison
        left_exp = Comparison(
            left_exp,
            self._COMP_TABLE[next_tag],
            self.parse_expression(pstate, _PREC_COMPARISON))
        did_something = True
    elif next_tag is _colon and _PREC_SLICE >= min_precedence:
        pstate.advance()
        expr_pstate = pstate.copy()

        assert not isinstance(left_exp, primitives.Slice)

        from pytools.lex import ParseError
        try:
            next_expr = self.parse_expression(expr_pstate, _PREC_SLICE)
        except ParseError:
            # no expression follows, too bad.
            left_exp = primitives.Slice((
                left_exp,
                None,
            ))
        else:
            left_exp = _join_to_slice(left_exp, next_expr)
            pstate.assign(expr_pstate)

        # A slice was built; let the caller keep parsing, e.g. the ", 3"
        # in "x[1:2, 3]".
        did_something = True
    elif next_tag is _comma and _PREC_COMMA > min_precedence:
        # The precedence makes the comma left-associative.

        pstate.advance()
        if pstate.is_at_end() or pstate.next_tag() is _closepar:
            if isinstance(left_exp, tuple) \
                    and not isinstance(left_exp, FinalizedTuple):
                # Trailing comma after an open tuple, as in "(a, b,)":
                # the tuple is already complete, so keep it as is.
                pass
            else:
                left_exp = (left_exp, )
        else:
            new_el = self.parse_expression(pstate, _PREC_COMMA)
            if isinstance(left_exp, tuple) \
                    and not isinstance(left_exp, FinalizedTuple):
                left_exp = left_exp + (new_el, )
            else:
                left_exp = (left_exp, new_el)

        did_something = True

    return left_exp, did_something