def reduction(
    self,
    expr: expr,
    neighborhood: expr,
    op: str,
    kwargs_keys: t.List[str],
    kwargs_values: t.List[expr],
):
    """Lower a reduction over a neighbor chain to a SIR reduction expression.

    Args:
        expr: AST of the reduced expression; it is lowered inside
            ``self.ctx.location.reduction(...)`` for the resolved neighborhood.
        neighborhood: AST describing the location chain, resolved through
            ``self.location_chain``.
        op: reduction operator name: ``"sum"``, ``"mul"``, ``"min"`` or ``"max"``.
        kwargs_keys: names of the keyword arguments given to the reduction.
        kwargs_values: AST values of those keyword arguments (same order).

    Returns:
        The node built by ``make_reduction_over_neighbor_expr`` from the SIR
        operator, lowered expression, init value, neighborhood and weights.

    Raises:
        DuskSyntaxError: for keyword arguments other than ``init``/``weights``,
            an unknown ``op``, or a ``weights`` argument that is not a literal
            sequence (an AST node without ``elts``).
    """
    kwargs = dict(zip(kwargs_keys, kwargs_values))
    # internal invariant: keyword names are unique and both lists line up
    assert len(kwargs) == len(kwargs_keys) == len(kwargs_values)

    # TODO: what about these hard coded strings?
    wrong_kwargs = kwargs.keys() - {"init", "weights"}
    if wrong_kwargs:
        raise DuskSyntaxError(f"Unsupported kwargs '{wrong_kwargs}' in reduction!")

    neighborhood = self.location_chain(neighborhood)
    with self.ctx.location.reduction(neighborhood):
        expr = self.expression(expr)

    op_map = {"sum": "+", "mul": "*", "min": "min", "max": "max"}
    if op not in op_map:
        raise DuskSyntaxError(f"Invalid operator '{op}' for reduction!")

    if "init" in kwargs:
        init = self.expression(kwargs["init"])
    else:
        # TODO: "min" and "max" are still kinda stupid
        # we should use something like this:
        # https://en.cppreference.com/w/cpp/types/numeric_limits/max
        # the current solution is:
        # - appropriate for doubles
        # - okish for floats (may trip floating point exceptions but correct outcome)
        # - breaks for int (undefined behavior!)
        init_map = {
            "sum": "0",
            "mul": "1",
            "min": "1.79769313486231571e+308",
            "max": "-1.79769313486231571e+308",
        }
        init = make_literal_access_expr(
            init_map[op], sir.BuiltinType.TypeID.Value("Double")
        )

    op = op_map[op]

    weights = None
    if "weights" in kwargs:
        weights_node = kwargs["weights"]
        # TODO: check for `kwargs["weights"].ctx == Load`?
        elements = getattr(weights_node, "elts", None)
        if elements is None:
            # without this check a non-sequence argument would surface as an
            # AttributeError instead of a user-facing syntax error
            raise DuskSyntaxError(
                "Reduction weights must be given as a list or tuple literal!",
                weights_node,
            )
        weights = [self.expression(weight) for weight in elements]

    return make_reduction_over_neighbor_expr(op, expr, init, neighborhood, weights)
def vertical_region(self):
    """Mark the body of a ``with`` statement as being inside a vertical region.

    This is a generator that is entered as a context manager; any decorator
    doing that is outside this block. It sets ``self.in_vertical_region``
    for the duration of the body.

    Raises:
        DuskSyntaxError: if a vertical region is already open, or if a loop
            statement or reduction is currently active.
    """
    if self.in_vertical_region:
        raise DuskSyntaxError("Vertical regions can't be nested!")
    if self.in_loop_stmt or self.in_reduction:
        raise DuskSyntaxError(
            "Encountered vertical region inside reduction or loop statement!"
        )
    self.in_vertical_region = True
    try:
        yield
    finally:
        # reset even when the body raises, so the flag cannot stay stuck at True
        self.in_vertical_region = False
def loop_stmt(self, location_chain: LocationChain):
    """Guard the body of a loop statement over ``location_chain``.

    This is a generator that is entered as a context manager; any decorator
    doing that is outside this block. The body runs with
    ``self.in_loop_stmt`` set and inside
    ``self._neighbor_iteration(location_chain)``.

    Raises:
        DuskSyntaxError: if loop statements would nest, or if a reduction is
            currently active.
    """
    if self.in_loop_stmt:
        raise DuskSyntaxError("Nested loop statements aren't allowed!")
    if self.in_reduction:
        raise DuskSyntaxError("Loop statements can't occur inside reductions!")
    self.in_loop_stmt = True
    try:
        with self._neighbor_iteration(location_chain):
            yield
    finally:
        # also runs when entering the neighbor iteration fails or the body
        # raises, so the flag cannot stay stuck at True
        self.in_loop_stmt = False
def match(self, nodes, **kwargs):
    """Match ``nodes`` element-wise against ``self.matchers``.

    ``nodes`` must be a list with exactly one element per matcher. Each pair is
    checked with the module-level ``match`` function, and ``kwargs`` is
    forwarded unchanged.

    Raises:
        DuskSyntaxError: if ``nodes`` is not a list, has the wrong length, or
            an element does not match.
    """
    if not isinstance(nodes, list):
        raise DuskSyntaxError(f"Expected a list, but got '{type(nodes)}'!", nodes)
    if len(nodes) != len(self.matchers):
        raise DuskSyntaxError(
            f"Expected a list of length {len(self.matchers)}, "
            f"but got list of length {len(nodes)}!",
            nodes,
        )
    for matcher, node in zip(self.matchers, nodes):
        match(matcher, node, **kwargs)
def reduction(
    self,
    expr: expr,
    neighborhood: expr,
    op: str,
    kwargs_keys: t.List[str],
    kwargs_values: t.List[expr],
):
    """Lower a reduction over a neighbor chain to a SIR reduction expression.

    Args:
        expr: AST of the reduced expression; it is lowered inside
            ``self.ctx.location.reduction(...)`` for the resolved neighborhood.
        neighborhood: AST describing the location chain, resolved through
            ``self.location_chain``.
        op: reduction operator name: ``"sum"``, ``"mul"``, ``"min"`` or ``"max"``.
        kwargs_keys: names of the keyword arguments given to the reduction.
        kwargs_values: AST values of those keyword arguments (same order).

    Returns:
        The node built by ``make_reduction_over_neighbor_expr`` from the SIR
        operator, lowered expression, init value, neighborhood and weights.

    Raises:
        DuskSyntaxError: for keyword arguments other than ``init``/``weights``,
            an unknown ``op``, or a ``weights`` argument that is not a literal
            sequence (an AST node without ``elts``).
    """
    kwargs = dict(zip(kwargs_keys, kwargs_values))
    # internal invariant: keyword names are unique and both lists line up
    assert len(kwargs) == len(kwargs_keys) == len(kwargs_values)

    # TODO: what about these hard coded strings?
    wrong_kwargs = kwargs.keys() - {"init", "weights"}
    if wrong_kwargs:
        raise DuskSyntaxError(f"Unsupported kwargs '{wrong_kwargs}' in reduction!")

    neighborhood = self.location_chain(neighborhood)
    with self.ctx.location.reduction(neighborhood):
        expr = self.expression(expr)

    op_map = {"sum": "+", "mul": "*", "min": "min", "max": "max"}
    if op not in op_map:
        raise DuskSyntaxError(f"Invalid operator '{op}' for reduction!")

    if "init" in kwargs:
        init = self.expression(kwargs["init"])
    else:
        # TODO: "min" and "max" are kinda stupid
        # we should use something like this:
        # https://en.cppreference.com/w/cpp/types/numeric_limits/max
        # Until then, use the largest finite double (DBL_MAX). A 400-digit
        # literal is out of range for `double`; C++ compilers warn about or
        # reject such a constant.
        # FIXME: probably breaks for int
        init_map = {
            "sum": "0",
            "mul": "1",
            "min": "1.79769313486231571e+308",
            "max": "-1.79769313486231571e+308",
        }
        init = make_literal_access_expr(
            init_map[op], sir.BuiltinType.TypeID.Value("Double")
        )

    op = op_map[op]

    weights = None
    if "weights" in kwargs:
        weights_node = kwargs["weights"]
        # TODO: check for `kwargs["weights"].ctx == Load`?
        elements = getattr(weights_node, "elts", None)
        if elements is None:
            # without this check a non-sequence argument would surface as an
            # AttributeError instead of a user-facing syntax error
            raise DuskSyntaxError(
                "Reduction weights must be given as a list or tuple literal!",
                weights_node,
            )
        weights = [self.expression(weight) for weight in elements]

    return make_reduction_over_neighbor_expr(op, expr, init, neighborhood, weights)
def field_index(self, field: DuskField, vindex=None, hindex=None):
    """Resolve the horizontal and vertical index of an access to ``field``.

    Returns a 3-tuple ``(unstructured_offset, vertical_shift, vertical_base)``.
    The first element is ``make_unstructured_offset(flag)``. The other two come
    from ``self.relative_vertical_offset(vindex)``, or are ``(0, None)`` when
    ``vindex`` is None.

    Raises:
        DuskSyntaxError: when a horizontal index is given outside of a neighbor
            iteration, when a dense field lacks a horizontal index inside an
            ambiguous neighbor iteration, or when the horizontal index does not
            agree with the current neighbor iteration.
    """
    if vindex is None:
        voffset, vbase = 0, None
    else:
        voffset, vbase = self.relative_vertical_offset(vindex)
    if hindex is not None:
        hindex = self.location_chain(hindex)

    # Outside of any neighbor iteration no horizontal index is allowed.
    if not self.ctx.location.in_neighbor_iteration:
        if hindex is None:
            return make_unstructured_offset(False), voffset, vbase
        raise DuskSyntaxError(
            f"Invalid horizontal index for field '{field.sir.name}' "
            "outside of neighbor iteration!")

    neighbor_iteration = self.ctx.location.current_neighbor_iteration
    field_dimension = self.ctx.location.get_field_dimension(field.sir)
    # TODO: `vindex` is _non-sensical_ if the field is 2d
    # TODO: we should check that `field_dimension` is valid for
    # the current neighbor iteration(s?)

    if hindex is not None:
        # TODO: check if `hindex` is valid for this field's location type
        if len(hindex) == 1:
            # a single location must be the start of the neighbor iteration
            mismatch = neighbor_iteration[0] != hindex[0]
            offset_flag = False
        else:
            # a full chain must equal the neighbor iteration
            mismatch = hindex != neighbor_iteration
            offset_flag = True
        if mismatch:
            raise DuskSyntaxError(
                f"Invalid horizontal offset for field '{field.sir.name}'!")
        return make_unstructured_offset(offset_flag), voffset, vbase

    if not self.ctx.location.is_dense(field_dimension):
        return make_unstructured_offset(True), voffset, vbase
    if self.ctx.location.is_ambiguous(neighbor_iteration):
        raise DuskSyntaxError(
            f"Field '{field.sir.name}' requires a horizontal index "
            "inside of ambiguous neighbor iteration!")
    offset_flag = field_dimension[0] == neighbor_iteration[-1]
    return make_unstructured_offset(offset_flag), voffset, vbase
def _neighbor_iteration(self, location_chain: LocationChain):
    """Push ``location_chain`` onto ``self.neighbor_iterations`` for the body.

    This is a generator that is entered as a context manager; any decorator
    doing that is outside this block. The chain is popped again when the body
    exits, even if it raises.

    Raises:
        DuskSyntaxError: if not inside a vertical region, or if the location
            chain has length 1 or less.
    """
    if not self.in_vertical_region:
        raise DuskSyntaxError(
            "Reductions or loop statements can only occur inside vertical regions!"
        )
    if len(location_chain) <= 1:
        raise DuskSyntaxError(
            "Reductions and loop statements must have a location chain of "
            "length longer than 1!")
    self.neighbor_iterations.append(location_chain)
    try:
        yield
    finally:
        # keep the stack balanced even when the body raises
        self.neighbor_iterations.pop()
def match(self, nodes, **kwargs) -> None:
    """Match every element of ``nodes`` against ``self.matcher``.

    If ``self.n`` is an ``int``, ``nodes`` must have exactly that many
    elements. Otherwise (e.g. ``None``) the length is not checked. Elements
    are checked with the module-level ``match`` function, and ``kwargs`` is
    forwarded unchanged.

    Raises:
        DuskSyntaxError: if ``nodes`` is not a list, has the wrong length, or
            an element does not match.
    """
    if not isinstance(nodes, list):
        raise DuskSyntaxError(f"Expected a list, but got '{type(nodes)}'!", nodes)
    # The former `elif self.n == 0: return` branch was unreachable for int `n`
    # (the length check catches it first). For an empty list the loop below
    # does nothing anyway.
    if isinstance(self.n, int) and len(nodes) != self.n:
        raise DuskSyntaxError(
            f"Expected a list of length {self.n}, but got list of length {len(nodes)}!",
            nodes,
        )
    for node in nodes:
        match(self.matcher, node, **kwargs)
def assign(self, lhs: expr, rhs: expr, op: t.Optional[operator] = None):
    """Lower a plain or augmented assignment to a SIR assignment statement.

    ``op`` is None for a plain ``=``. Otherwise it is the AST operator of an
    augmented assignment (``+=``, ``-=``, ...). ``**=`` has no SIR counterpart
    and is rewritten to ``lhs = gridtools::dawn::math::pow(lhs, rhs)``.

    Raises:
        DuskSyntaxError: for augmented operators without a SIR mapping.
    """
    augmented_ops = {
        Add: "+=",
        Sub: "-=",
        Mult: "*=",
        Div: "/=",
        Mod: "%=",
        LShift: "<<=",
        RShift: ">>=",
        BitOr: "|=",
        BitXor: "^=",
        BitAnd: "&=",
    }

    if op is None:
        return make_assignment_stmt(self.expression(lhs), self.expression(rhs), "=")

    if isinstance(op, Pow):
        # the pow arguments are lowered first, then the assignment target
        power_call = make_fun_call_expr(
            "gridtools::dawn::math::pow",
            [self.expression(lhs), self.expression(rhs)],
        )
        return make_assignment_stmt(self.expression(lhs), power_call, "=")

    sir_op = augmented_ops.get(type(op))
    if sir_op is None:
        raise DuskSyntaxError(f"Unsupported assignment operator '{op}'!", op)
    # NOTE(review): `%=` is emitted verbatim; C++ `%` truncates toward zero
    # (Python floors) and is ill-formed for floating-point operands.
    return make_assignment_stmt(self.expression(lhs), self.expression(rhs), sir_op)
def binop(self, left: expr, op: t.Any, right: expr):
    """Lower a Python binary operation to SIR.

    Arithmetic, shift and bitwise operators become a SIR binary operator.
    ``**`` becomes a call to ``gridtools::dawn::math::pow``.

    Raises:
        DuskSyntaxError: for any other operator.
    """
    # NOTE(review): `%` is not mapped here; Python's floored modulo differs
    # from C++ `%` and the latter is ill-formed for doubles.
    sir_binops = {
        Add: "+",
        Sub: "-",
        Mult: "*",
        Div: "/",
        LShift: "<<",
        RShift: ">>",
        BitOr: "|",
        BitXor: "^",
        BitAnd: "&",
    }
    sir_op = sir_binops.get(type(op))
    if sir_op is not None:
        return make_binary_operator(self.expression(left), sir_op, self.expression(right))
    if not isinstance(op, Pow):
        raise DuskSyntaxError(f"Unsupported binary operator '{op}'!", op)
    return make_fun_call_expr(
        "gridtools::dawn::math::pow",
        [self.expression(left), self.expression(right)],
    )
def relative_vertical_offset(self, name: str, vindex: int = 0, vop=Add()):
    """Return the signed vertical shift ``vindex`` relative to ``name``.

    The shift is negated when ``vop`` is a ``Sub`` node.

    Raises:
        DuskSyntaxError: if ``name`` does not resolve to a
            ``VerticalIterationVariable`` in the current scope.
    """
    symbol = self.ctx.scope.current_scope.fetch(name)
    if not isinstance(symbol, VerticalIterationVariable):
        raise DuskSyntaxError(
            f"'{name}' isn't a vertical iteration variable!", name
        )
    if isinstance(vop, Sub):
        return -vindex
    return vindex
def field_access_expr(self, field: DuskField, index: expr = None):
    """Build a SIR field access for ``field``, optionally subscripted by ``index``.

    ``index`` (the subscript AST, or None) is passed positionally to
    ``self.field_index``, with ``field`` passed as a keyword.
    NOTE(review): that call shape suggests ``field_index`` is wrapped by a
    node-matching decorator not visible here; confirm before changing it.

    Raises:
        DuskSyntaxError: if the access occurs outside of a vertical region.
    """
    if not self.ctx.location.in_vertical_region:
        # the message previously referenced an undefined `name`, which raised
        # NameError instead of this error
        raise DuskSyntaxError(
            f"Invalid field access '{field.sir.name}' outside of a vertical region!"
        )
    return make_field_access_expr(
        field.sir.name, self.field_index(index, field=field)
    )
def field_index(self, field: DuskField, vindex=None, hindex=None):
    """Resolve the horizontal and vertical index of an access to ``field``.

    Returns a 2-element list ``[flag, vertical_shift]``. ``vertical_shift`` is
    ``self.relative_vertical_offset(vindex)``, or 0 when ``vindex`` is None.

    Raises:
        DuskSyntaxError: when a horizontal index is given outside of a neighbor
            iteration, when a dense field lacks a horizontal index inside an
            ambiguous neighbor iteration, or when the horizontal index does not
            agree with the current neighbor iteration.
    """
    vshift = 0 if vindex is None else self.relative_vertical_offset(vindex)
    if hindex is not None:
        hindex = self.location_chain(hindex)

    # Outside of any neighbor iteration no horizontal index is allowed.
    if not self.ctx.location.in_neighbor_iteration:
        if hindex is None:
            return [False, vshift]
        raise DuskSyntaxError(
            f"Invalid horizontal index for field '{field.sir.name}' "
            "outside of neighbor iteration!"
        )

    neighbor_iteration = self.ctx.location.current_neighbor_iteration
    field_dimension = self.ctx.location.get_field_dimension(field.sir)
    # TODO: we should check that `field_dimension` is valid for
    # the current neighbor iteration(s?)

    if hindex is not None:
        # TODO: check if `hindex` is valid for this field's location type
        if len(hindex) == 1:
            mismatch = neighbor_iteration[0] != hindex[0]
            flag = False
        else:
            mismatch = hindex != neighbor_iteration
            flag = True
        if mismatch:
            raise DuskSyntaxError(
                f"Invalid horizontal offset for field '{field.sir.name}'!"
            )
        return [flag, vshift]

    if not self.ctx.location.is_dense(field_dimension):
        return [True, vshift]
    if self.ctx.location.is_ambiguous(neighbor_iteration):
        raise DuskSyntaxError(
            f"Field '{field.sir.name}' requires a horizontal index "
            "inside of ambiguous neighbor iteration!"
        )
    return [field_dimension[0] == neighbor_iteration[-1], vshift]
def math_function(self, name: str, args: t.List):
    """Lower a call to a supported math function to ``gridtools::dawn::math::<name>``.

    Names in ``self.unary_math_functions`` take exactly one argument. Names in
    ``self.binary_math_functions`` take exactly two. Each argument is lowered
    with ``self.expression``.

    Raises:
        DuskSyntaxError: on an unknown function name or a wrong argument count.
    """
    if name in self.unary_math_functions:
        if len(args) != 1:
            raise DuskSyntaxError(f"Function '{name}' takes exactly one argument!")
    elif name in self.binary_math_functions:
        if len(args) != 2:
            raise DuskSyntaxError(f"Function '{name}' takes exactly two arguments!")
    else:
        raise DuskSyntaxError(f"Unrecognized function call '{name}'!")

    dawn_function = f"gridtools::dawn::math::{name}"
    lowered_args = [self.expression(arg) for arg in args]
    return make_fun_call_expr(dawn_function, lowered_args)
def vertical_interval_bound(self, bound):
    """Translate an integer AST bound into a ``(level, offset)`` pair.

    An integer constant ``k`` yields ``(sir.Interval.Start, k)``. A negated
    integer constant ``-k`` yields ``(sir.Interval.End, -k)``.

    Raises:
        DuskSyntaxError: if ``bound`` has neither form.
    """
    if does_match(Constant(value=int, kind=None), bound):
        return sir.Interval.Start, bound.value

    negated_constant = UnaryOp(op=USub, operand=Constant(value=int, kind=None))
    if does_match(negated_constant, bound):
        return sir.Interval.End, -bound.operand.value

    raise DuskSyntaxError(
        f"Unrecognized vertical intervals bound '{bound}'!", bound)
def var(self, name: str, index: expr = None):
    """Lower a reference to the variable ``name``, optionally subscripted by ``index``.

    Only field symbols are handled; they are delegated to
    ``self.field_access_expr``.

    Raises:
        DuskSyntaxError: if ``name`` is not declared in the current scope.
        DuskInternalError: if ``name`` resolves to a symbol that is not a field.
    """
    if not self.ctx.scope.current_scope.contains(name):
        raise DuskSyntaxError(f"Undeclared variable '{name}'!", name)

    symbol = self.ctx.scope.current_scope.fetch(name)
    if not isinstance(symbol, DuskField):
        raise DuskInternalError(
            f"Encountered unknown symbol type '{symbol}' ('{name}')!")
    return self.field_access_expr(symbol, index)
def match(self, node, **kwargs):
    """Accept ``node`` if at least one of ``self.matchers`` matches it.

    Matchers are tried in order with the module-level ``match`` function. The
    first success ends the search, and a failing matcher's DuskSyntaxError is
    discarded.

    Raises:
        DuskSyntaxError: if no matcher accepts ``node``.
    """
    for candidate in self.matchers:
        try:
            match(candidate, node, **kwargs)
        except DuskSyntaxError:
            continue
        return
    raise DuskSyntaxError(f"Encountered unrecognized node '{node}'!", node)
def funcall(self, name: str, node: Call):
    """Dispatch the call ``node`` (named ``name``) to its lowering routine.

    ``reduce_over`` goes to ``self.reduce_over``. ``sum_over``/``min_over``/
    ``max_over`` go to ``self.short_reduce_over``. Known math functions go to
    ``self.math_function``.

    Raises:
        DuskSyntaxError: for any other function name.
    """
    # TODO: bad hardcoded string
    short_reductions = {"sum_over", "min_over", "max_over"}
    if name == "reduce_over":
        handler = self.reduce_over
    elif name in short_reductions:
        handler = self.short_reduce_over
    elif name in self.unary_math_functions or name in self.binary_math_functions:
        handler = self.math_function
    else:
        raise DuskSyntaxError(f"Unrecognized function call '{name}'!", node)
    return handler(node)
def location_chain(self, locations: t.List, include_center: t.Optional[t.Literal["Origin"]] = None):
    """Resolve a list of location names into SIR location types.

    Returns ``(includes_center, location_types)``. ``includes_center`` is True
    whenever ``include_center`` is not None. ``location_types`` holds
    ``self.location_type(name)`` for each entry, in order.

    Raises:
        DuskSyntaxError: if the center is requested but
            ``self.ctx.location.is_ambiguous`` rejects the resolved chain.
    """
    includes_center = include_center is not None
    resolved = list(map(self.location_type, locations))
    if includes_center and not self.ctx.location.is_ambiguous(resolved):
        raise DuskSyntaxError(
            "including the center is only allowed if start equals end location of the neighbor chain!"
        )
    return includes_center, resolved
def match_ast(matcher: AST, node, **kwargs):
    """Structurally match ``node`` against the AST pattern ``matcher``.

    ``node`` must be an instance of ``type(matcher)``. Every field listed in
    ``matcher._fields`` is then compared with the module-level ``match``
    function, and ``kwargs`` is forwarded unchanged.

    Raises:
        DuskSyntaxError: on a type mismatch or a failing field. A nested error
            without a location gets one from ``node`` when ``node`` is a
            statement or expression.
    """
    if not isinstance(node, type(matcher)):
        raise DuskSyntaxError(
            f"Expected node type '{type(matcher)}', but got '{type(node)}'!",
            node)

    for field_name in matcher._fields:
        try:
            match(getattr(matcher, field_name), getattr(node, field_name), **kwargs)
        except DuskSyntaxError as error:
            needs_location = error.loc is None and isinstance(node, (stmt, expr))
            if needs_location:
                # attach the closest available source location
                error.loc_from_node(node)
            raise
def compare(self, left: expr, op, right: expr):
    """Lower a single Python comparison to a SIR binary operator.

    Supports ``==``, ``!=``, ``<``, ``<=``, ``>`` and ``>=``.

    Raises:
        DuskSyntaxError: for any other comparison operator (e.g. ``is``, ``in``).
    """
    # FIXME: we should probably have a better answer when we need such mappings
    sir_comparisons = {
        Eq: "==",
        NotEq: "!=",
        Lt: "<",
        LtE: "<=",
        Gt: ">",
        GtE: ">=",
    }
    sir_op = sir_comparisons.get(type(op))
    if sir_op is None:
        raise DuskSyntaxError(f"Unsupported comparison operator '{op}'!", op)
    return make_binary_operator(self.expression(left), sir_op, self.expression(right))
def relative_vertical_offset(self, base: str, shift: int = 0, vop=Add()):
    """Resolve a vertical offset of the form ``base + shift`` / ``base - shift``.

    Returns ``(signed_shift, vertical_base)``. ``signed_shift`` is ``-shift``
    when ``vop`` is a ``Sub`` node and ``shift`` otherwise. ``vertical_base``
    is None for a vertical iteration variable, or the SIR name of the index
    field for a ``DuskIndexField``.

    Raises:
        DuskSyntaxError: if ``base`` is neither a vertical iteration variable
            nor an index field, or if an index field is used inside an
            ambiguous neighbor iteration.
    """
    # keep the user-facing name separate from the resolved symbol, so error
    # messages show the name rather than the symbol object
    symbol = self.ctx.scope.current_scope.fetch(base)

    if isinstance(symbol, DuskIndexField):
        # TODO: check that `base` is valid in this context
        # * compatible neighbor iteration
        # * _correct_ usage (not clear what this entails)
        vertical_base = symbol.sir.name
        if (self.ctx.location.in_neighbor_iteration
                and self.ctx.location.is_ambiguous(
                    self.ctx.location.current_neighbor_iteration)):
            raise DuskSyntaxError(
                f"Index field {vertical_base} used in ambiguous neighbor iteration!"
            )
    elif isinstance(symbol, VerticalIterationVariable):
        vertical_base = None
    else:
        raise DuskSyntaxError(
            f"'{base}' isn't a vertical iteration variable or index field!", base)

    return (-shift if isinstance(vop, Sub) else shift), vertical_base
def constant(self, value):
    """Lower a Python constant to a SIR literal access expression.

    ``bool`` maps to Boolean (emitted as ``true``/``false``), ``int`` to
    Integer, and ``float`` to Double. Numbers are emitted via ``str``.

    Raises:
        DuskSyntaxError: for constants of any other type, or for non-finite
            floats (``inf``/``nan``), which have no valid literal spelling in
            the generated code.
    """
    import math

    # TODO: properly distinguish between float and double
    built_in_type_map = {bool: "Boolean", int: "Integer", float: "Double"}
    if type(value) not in built_in_type_map:
        raise DuskSyntaxError(
            f"Unsupported constant '{value}' of type '{type(value)}'!", value
        )
    if isinstance(value, float) and not math.isfinite(value):
        # e.g. the source literal `1e400` evaluates to inf; `str` would emit
        # "inf", which is not a C++ literal
        raise DuskSyntaxError(f"Unsupported non-finite constant '{value}'!", value)

    _type = sir.BuiltinType.TypeID.Value(built_in_type_map[type(value)])
    if isinstance(value, bool):
        literal = "true" if value else "false"
    else:
        # `str` of a finite float is its shortest round-tripping repr (possibly
        # in exponent notation such as 1e+16).
        # TODO: Python ints are unbounded and may overflow the target int type
        literal = str(value)
    return make_literal_access_expr(literal, _type)
def dispatch(rules: t.Dict[t.Any, t.Callable], node):
    """Apply the first rule whose recognizer matches ``node``.

    ``rules`` maps recognizer patterns to callables and is scanned in
    insertion order. The chosen callable is applied to ``node`` and its
    result returned.

    Raises:
        DuskSyntaxError: if no recognizer matches.
    """
    for pattern, handler in rules.items():
        if not does_match(pattern, node):
            continue
        return handler(node)
    raise DuskSyntaxError(f"Unrecognized node: '{node}'!", node)
def match_type(matcher: type, node, **kwargs):
    """Check that ``node`` is an instance of ``matcher``.

    Extra keyword arguments are accepted and ignored.

    Raises:
        DuskSyntaxError: if ``isinstance(node, matcher)`` is False.
    """
    if isinstance(node, matcher):
        return
    raise DuskSyntaxError(
        f"Expected type '{matcher}', but got '{type(node)}'", node)
def location_type(self, name: str):
    """Map a location-type name to its ``sir.LocationType`` value.

    Valid names are the ``__name__`` attributes of the entries in
    ``LOCATION_TYPES``.

    Raises:
        DuskSyntaxError: for any other name.
    """
    known_names = frozenset(location.__name__ for location in LOCATION_TYPES)
    if name in known_names:
        return sir.LocationType.Value(name)
    raise DuskSyntaxError(f"Invalid location type '{name}'!", name)
def match_primitives(matcher, node, **kwargs):
    """Require ``node`` to equal the primitive pattern ``matcher``.

    Extra keyword arguments are accepted and ignored.

    Raises:
        DuskSyntaxError: if ``matcher != node``.
    """
    if matcher != node:
        message = f"Expected '{matcher}', but got '{node}'!"
        raise DuskSyntaxError(message, node)