コード例 #1
0
    def reduction(
        self,
        expr: expr,
        neighborhood: expr,
        op: str,
        kwargs_keys: t.List[str],
        kwargs_values: t.List[expr],
    ):
        """Translate a neighbor reduction into a SIR reduction expression.

        :param expr: AST of the reduced expression; it is translated inside
            ``self.ctx.location.reduction(...)`` for the resolved neighborhood.
        :param neighborhood: AST of the neighbor location chain, resolved via
            ``self.location_chain``.
        :param op: reduction name, one of ``"sum"``, ``"mul"``, ``"min"``,
            ``"max"``.
        :param kwargs_keys: names of the keyword arguments (``init``/``weights``).
        :param kwargs_values: AST values of those keyword arguments, same order.
        :returns: the expression built by ``make_reduction_over_neighbor_expr``.
        :raises DuskSyntaxError: on unsupported keyword arguments, an unknown
            operator, or ``weights`` that aren't a literal sequence.
        """
        kwargs = dict(zip(kwargs_keys, kwargs_values))
        # sanity check: no key may have been dropped while building the dict
        assert len(kwargs) == len(kwargs_keys) == len(kwargs_values)

        # TODO: what about these hard coded strings?
        wrong_kwargs = kwargs.keys() - {"init", "weights"}
        if wrong_kwargs:
            raise DuskSyntaxError(
                f"Unsupported kwargs '{wrong_kwargs}' in reduction!")

        neighborhood = self.location_chain(neighborhood)
        with self.ctx.location.reduction(neighborhood):
            expr = self.expression(expr)

        op_map = {"sum": "+", "mul": "*", "min": "min", "max": "max"}
        if op not in op_map:
            raise DuskSyntaxError(f"Invalid operator '{op}' for reduction!")

        if "init" in kwargs:
            init = self.expression(kwargs["init"])
        else:
            # TODO: "min" and "max" are still kinda stupid
            # we should use something like this:
            # https://en.cppreference.com/w/cpp/types/numeric_limits/max
            # the current solution is:
            #   - appropriate for doubles
            #   - okish for floats (may trip floating point exceptions but correct outcome)
            #   - breaks for int (undefined behavior!)
            init_map = {
                "sum": "0",
                "mul": "1",
                "min": "1.79769313486231571e+308",
                "max": "-1.79769313486231571e+308",
            }
            init = make_literal_access_expr(
                init_map[op], sir.BuiltinType.TypeID.Value("Double"))

        op = op_map[op]

        weights = None
        if "weights" in kwargs:
            # TODO: check for `kwargs["weights"].ctx == Load`?
            weights_node = kwargs["weights"]
            # only sequence displays (list/tuple literals) carry `elts`;
            # anything else used to crash with an AttributeError
            weight_elements = getattr(weights_node, "elts", None)
            if weight_elements is None:
                raise DuskSyntaxError(
                    "Reduction weights must be given as a list or tuple literal!",
                    weights_node)
            weights = [self.expression(weight) for weight in weight_elements]

        return make_reduction_over_neighbor_expr(
            op,
            expr,
            init,
            neighborhood,
            weights,
        )
コード例 #2
0
ファイル: semantics.py プロジェクト: BenWeber42/dusk
 def vertical_region(self):
     """Context-manager body marking the enclosed code as a vertical region.

     Sets ``self.in_vertical_region`` for the duration of the ``yield``.
     NOTE(review): the decorator is not part of this snippet; presumably this
     is wrapped with ``contextlib.contextmanager`` -- confirm.

     :raises DuskSyntaxError: if a vertical region is already open, or if we
         are inside a loop statement or a reduction.
     """
     if self.in_vertical_region:
         raise DuskSyntaxError("Vertical regions can't be nested!")
     if self.in_loop_stmt or self.in_reduction:
         raise DuskSyntaxError(
             "Encountered vertical region inside reduction or loop statement!"
         )
     self.in_vertical_region = True
     try:
         yield
     finally:
         # reset even when the enclosed block raises, so the flag doesn't
         # leak into whatever is translated next
         self.in_vertical_region = False
コード例 #3
0
    def loop_stmt(self, location_chain: LocationChain):
        """Context-manager body marking the enclosed code as a loop statement.

        Sets ``self.in_loop_stmt`` and enters ``self._neighbor_iteration`` for
        ``location_chain`` for the duration of the ``yield``.

        :raises DuskSyntaxError: on nested loop statements or a loop statement
            inside a reduction.
        """
        if self.in_loop_stmt:
            raise DuskSyntaxError("Nested loop statements aren't allowed!")
        if self.in_reduction:
            raise DuskSyntaxError("Loop statements can't occur inside reductions!")

        self.in_loop_stmt = True
        try:
            # `_neighbor_iteration` may itself raise on entry; the `finally`
            # resets the flag in that case as well as when the body raises
            with self._neighbor_iteration(location_chain):
                yield
        finally:
            self.in_loop_stmt = False
コード例 #4
0
ファイル: match.py プロジェクト: BenWeber42/dusk
    def match(self, nodes, **kwargs):
        """Match a list of nodes element-wise against ``self.matchers``.

        :param nodes: the value to match; must be a list with exactly one
            element per matcher.
        :param kwargs: forwarded to every element-wise ``match`` call.
        :raises DuskSyntaxError: if ``nodes`` isn't a list, has the wrong
            length, or an element fails to match.
        """
        if not isinstance(nodes, list):
            raise DuskSyntaxError(f"Expected a list, but got '{type(nodes)}'!",
                                  nodes)

        if len(nodes) != len(self.matchers):
            # report both lengths (the old message had a stray quote and
            # omitted the length actually received)
            raise DuskSyntaxError(
                f"Expected a list of length {len(self.matchers)}, "
                f"but got list of length {len(nodes)}!", nodes)

        for matcher, node in zip(self.matchers, nodes):
            match(matcher, node, **kwargs)
コード例 #5
0
ファイル: grammar.py プロジェクト: BenWeber42/dusk
    def reduction(
        self,
        expr: expr,
        neighborhood: expr,
        op: str,
        kwargs_keys: t.List[str],
        kwargs_values: t.List[expr],
    ):
        """Translate a neighbor reduction into a SIR reduction expression.

        :param expr: AST of the reduced expression; it is translated inside
            ``self.ctx.location.reduction(...)`` for the resolved neighborhood.
        :param neighborhood: AST of the neighbor location chain, resolved via
            ``self.location_chain``.
        :param op: reduction name, one of ``"sum"``, ``"mul"``, ``"min"``, ``"max"``.
        :param kwargs_keys: names of the keyword arguments (``init``/``weights``).
        :param kwargs_values: AST values of those keyword arguments, same order.
        :returns: the expression built by ``make_reduction_over_neighbor_expr``.
        :raises DuskSyntaxError: on unsupported keyword arguments, an unknown
            operator, or ``weights`` that aren't a literal sequence.
        """
        kwargs = dict(zip(kwargs_keys, kwargs_values))
        # sanity check: no key may have been dropped while building the dict
        assert len(kwargs) == len(kwargs_keys) == len(kwargs_values)

        # TODO: what about these hard coded strings?
        wrong_kwargs = kwargs.keys() - {"init", "weights"}
        if wrong_kwargs:
            raise DuskSyntaxError(f"Unsupported kwargs '{wrong_kwargs}' in reduction!")

        neighborhood = self.location_chain(neighborhood)
        with self.ctx.location.reduction(neighborhood):
            expr = self.expression(expr)

        op_map = {"sum": "+", "mul": "*", "min": "min", "max": "max"}
        if op not in op_map:
            raise DuskSyntaxError(f"Invalid operator '{op}' for reduction!")

        if "init" in kwargs:
            init = self.expression(kwargs["init"])
        else:
            # TODO: "min" and "max" are kinda stupid
            # we should use something like this:
            # https://en.cppreference.com/w/cpp/types/numeric_limits/max
            # The bounds below are the largest finite double; a 400-digit
            # integer literal (used previously) is not a valid C++ literal.
            # FIXME: probably breaks for int
            init_map = {
                "sum": "0",
                "mul": "1",
                "min": "1.79769313486231571e+308",
                "max": "-1.79769313486231571e+308",
            }
            init = make_literal_access_expr(
                init_map[op], sir.BuiltinType.TypeID.Value("Double")
            )

        op = op_map[op]

        weights = None
        if "weights" in kwargs:
            # TODO: check for `kwargs["weights"].ctx == Load`?
            weights_node = kwargs["weights"]
            # only sequence displays (list/tuple literals) carry `elts`;
            # anything else used to crash with an AttributeError
            weight_elements = getattr(weights_node, "elts", None)
            if weight_elements is None:
                raise DuskSyntaxError(
                    "Reduction weights must be given as a list or tuple literal!",
                    weights_node,
                )
            weights = [self.expression(weight) for weight in weight_elements]

        return make_reduction_over_neighbor_expr(op, expr, init, neighborhood, weights,)
コード例 #6
0
    def field_index(self, field: DuskField, vindex=None, hindex=None):
        """Compute the SIR index of an access to ``field``.

        :param field: the accessed field.
        :param vindex: optional vertical index AST, resolved through
            ``self.relative_vertical_offset`` into ``(offset, base)``;
            defaults to ``(0, None)``.
        :param hindex: optional horizontal index AST, resolved through
            ``self.location_chain``.
        :returns: ``(unstructured_offset, vertical_offset, vertical_base)``,
            where the first element comes from ``make_unstructured_offset``.
        :raises DuskSyntaxError: on a horizontal index outside of a neighbor
            iteration, a missing index inside an ambiguous neighbor iteration,
            or a horizontal index that doesn't fit the current iteration.
        """
        if vindex is None:
            voffset, vbase = 0, None
        else:
            voffset, vbase = self.relative_vertical_offset(vindex)
        horizontal = None if hindex is None else self.location_chain(hindex)

        def indexed(has_offset):
            # every exit shares the vertical part; only the flag varies
            return make_unstructured_offset(has_offset), voffset, vbase

        location = self.ctx.location
        if not location.in_neighbor_iteration:
            if horizontal is not None:
                raise DuskSyntaxError(
                    f"Invalid horizontal index for field '{field.sir.name}' "
                    "outside of neighbor iteration!")
            return indexed(False)

        neighbor_iteration = location.current_neighbor_iteration
        field_dimension = location.get_field_dimension(field.sir)

        # TODO: `vindex` is _non-sensical_ if the field is 2d

        # TODO: we should check that `field_dimension` is valid for
        #       the current neighbor iteration(s?)

        if horizontal is None:
            if not location.is_dense(field_dimension):
                return indexed(True)
            if location.is_ambiguous(neighbor_iteration):
                raise DuskSyntaxError(
                    f"Field '{field.sir.name}' requires a horizontal index "
                    "inside of ambiguous neighbor iteration!")
            # a dense field is offset iff it lives on the chain's target location
            return indexed(field_dimension[0] == neighbor_iteration[-1])

        # TODO: check if `hindex` is valid for this field's location type

        if len(horizontal) == 1:
            # a single location must name the chain's start (no offset)
            if neighbor_iteration[0] != horizontal[0]:
                raise DuskSyntaxError(
                    f"Invalid horizontal offset for field '{field.sir.name}'!")
            return indexed(False)

        # a full chain must reproduce the current neighbor iteration exactly
        if horizontal != neighbor_iteration:
            raise DuskSyntaxError(
                f"Invalid horizontal offset for field '{field.sir.name}'!")
        return indexed(True)
コード例 #7
0
ファイル: semantics.py プロジェクト: BenWeber42/dusk
    def _neighbor_iteration(self, location_chain: LocationChain):
        """Context-manager body tracking the innermost neighbor iteration.

        Pushes ``location_chain`` onto ``self.neighbor_iterations`` for the
        duration of the ``yield`` and pops it afterwards.

        :raises DuskSyntaxError: outside of a vertical region, or if the chain
            has at most one location.
        """
        if not self.in_vertical_region:
            raise DuskSyntaxError(
                "Reductions or loop statements can only occur inside vertical regions!"
            )

        if len(location_chain) <= 1:
            # note the trailing space: the two literals are concatenated
            raise DuskSyntaxError(
                "Reductions and loop statements must have a location chain of "
                "length longer than 1!")

        self.neighbor_iterations.append(location_chain)
        try:
            yield
        finally:
            # pop even if the enclosed block raises, keeping the stack balanced
            self.neighbor_iterations.pop()
コード例 #8
0
    def match(self, nodes, **kwargs) -> None:
        """Match every element of ``nodes`` against ``self.matcher``.

        When ``self.n`` is an int, ``nodes`` must contain exactly that many
        elements; otherwise any length is accepted.

        :raises DuskSyntaxError: if ``nodes`` isn't a list, has the wrong
            length, or an element fails to match.
        """
        if not isinstance(nodes, list):
            raise DuskSyntaxError(f"Expected a list, but got '{type(nodes)}'!", nodes)

        expected_length = self.n
        if isinstance(expected_length, int) and len(nodes) != expected_length:
            raise DuskSyntaxError(
                f"Expected a list of length {expected_length}, but got list of length {len(nodes)}!",
                nodes,
            )

        # an empty list (including the n == 0 case) simply matches nothing
        for element in nodes:
            match(self.matcher, element, **kwargs)
コード例 #9
0
ファイル: grammar.py プロジェクト: BenWeber42/dusk
    def assign(self, lhs: expr, rhs: expr, op: t.Optional[operator] = None):
        """Translate a plain or augmented assignment into a SIR assignment.

        ``x **= y`` has no SIR counterpart and is lowered to
        ``x = gridtools::dawn::math::pow(x, y)``.

        :param lhs: AST of the assignment target.
        :param rhs: AST of the assigned value.
        :param op: the augmented-assignment operator node, or ``None`` for ``=``.
        :raises DuskSyntaxError: for unsupported augmented operators.
        """
        if isinstance(op, Pow):
            power = make_fun_call_expr(
                "gridtools::dawn::math::pow",
                [self.expression(lhs), self.expression(rhs)],
            )
            return make_assignment_stmt(self.expression(lhs), power, "=")

        if op is None:
            sir_op = "="
        else:
            sir_op = {
                Add: "+=",
                Sub: "-=",
                Mult: "*=",
                Div: "/=",
                Mod: "%=",
                LShift: "<<=",
                RShift: ">>=",
                BitOr: "|=",
                BitXor: "^=",
                BitAnd: "&=",
            }.get(type(op))
            if sir_op is None:
                raise DuskSyntaxError(f"Unsupported assignment operator '{op}'!", op)

        return make_assignment_stmt(self.expression(lhs), self.expression(rhs), sir_op)
コード例 #10
0
    def binop(self, left: expr, op: t.Any, right: expr):
        """Translate a Python binary operation into SIR.

        ``a ** b`` becomes a call to ``gridtools::dawn::math::pow``; the other
        supported operators map onto SIR binary operators.

        :raises DuskSyntaxError: for unsupported operators.
        """
        if isinstance(op, Pow):
            return make_fun_call_expr(
                "gridtools::dawn::math::pow",
                [self.expression(left), self.expression(right)],
            )

        sir_op = {
            Add: "+",
            Sub: "-",
            Mult: "*",
            Div: "/",
            LShift: "<<",
            RShift: ">>",
            BitOr: "|",
            BitXor: "^",
            BitAnd: "&",
        }.get(type(op))
        if sir_op is None:
            raise DuskSyntaxError(f"Unsupported binary operator '{op}'!", op)

        return make_binary_operator(self.expression(left), sir_op,
                                    self.expression(right))
コード例 #11
0
ファイル: grammar.py プロジェクト: BenWeber42/dusk
 def relative_vertical_offset(self, name: str, vindex: int = 0, vop=Add()):
     """Return the signed vertical offset ``name +/- vindex``.

     :param name: must name a vertical iteration variable in the current scope.
     :param vindex: magnitude of the offset.
     :param vop: combining operator node; ``Sub`` negates ``vindex``.
     :raises DuskSyntaxError: if ``name`` isn't a vertical iteration variable.
     """
     symbol = self.ctx.scope.current_scope.fetch(name)
     if isinstance(symbol, VerticalIterationVariable):
         if isinstance(vop, Sub):
             return -vindex
         return vindex
     raise DuskSyntaxError(
         f"'{name}' isn't a vertical iteration variable!", name
     )
コード例 #12
0
ファイル: grammar.py プロジェクト: BenWeber42/dusk
 def field_access_expr(self, field: DuskField, index: expr = None):
     """Translate an access to ``field`` (optionally indexed) into SIR.

     :param field: the accessed field.
     :param index: optional index AST forwarded to ``self.field_index``.
     :raises DuskSyntaxError: if the access occurs outside a vertical region.
     """
     if not self.ctx.location.in_vertical_region:
         # `name` is not defined in this scope; report the field's own name
         raise DuskSyntaxError(
             f"Invalid field access '{field.sir.name}' outside of a vertical region!"
         )
     # NOTE(review): `field_index` is called with the index AST first; the
     # visible `field_index` signatures take `field` first, so this presumably
     # relies on a decorator that adapts the call -- confirm.
     return make_field_access_expr(
         field.sir.name, self.field_index(index, field=field)
     )
コード例 #13
0
ファイル: grammar.py プロジェクト: BenWeber42/dusk
    def field_index(self, field: DuskField, vindex=None, hindex=None):
        """Compute the ``[has_horizontal_offset, vertical_offset]`` index of a field access.

        :param field: the accessed field.
        :param vindex: optional vertical index AST, resolved through
            ``self.relative_vertical_offset``; defaults to an offset of 0.
        :param hindex: optional horizontal index AST, resolved through
            ``self.location_chain``.
        :returns: a two-element list ``[bool, vertical_offset]``.
        :raises DuskSyntaxError: on a horizontal index outside of a neighbor
            iteration, a missing index inside an ambiguous neighbor iteration,
            or a horizontal index that doesn't fit the current iteration.
        """
        vertical = 0 if vindex is None else self.relative_vertical_offset(vindex)
        horizontal = None if hindex is None else self.location_chain(hindex)
        location = self.ctx.location

        if not location.in_neighbor_iteration:
            if horizontal is not None:
                raise DuskSyntaxError(
                    f"Invalid horizontal index for field '{field.sir.name}' "
                    "outside of neighbor iteration!"
                )
            return [False, vertical]

        neighbor_iteration = location.current_neighbor_iteration
        field_dimension = location.get_field_dimension(field.sir)

        # TODO: we should check that `field_dimension` is valid for
        #       the current neighbor iteration(s?)

        if horizontal is None:
            if not location.is_dense(field_dimension):
                return [True, vertical]
            if location.is_ambiguous(neighbor_iteration):
                raise DuskSyntaxError(
                    f"Field '{field.sir.name}' requires a horizontal index "
                    "inside of ambiguous neighbor iteration!"
                )
            # a dense field is offset iff it lives on the chain's target location
            return [field_dimension[0] == neighbor_iteration[-1], vertical]

        # TODO: check if `hindex` is valid for this field's location type

        if len(horizontal) == 1:
            # a single location must name the chain's start (no offset)
            if neighbor_iteration[0] != horizontal[0]:
                raise DuskSyntaxError(
                    f"Invalid horizontal offset for field '{field.sir.name}'!"
                )
            return [False, vertical]

        # a full chain must reproduce the current neighbor iteration exactly
        if horizontal != neighbor_iteration:
            raise DuskSyntaxError(
                f"Invalid horizontal offset for field '{field.sir.name}'!"
            )
        return [True, vertical]
コード例 #14
0
ファイル: grammar.py プロジェクト: BenWeber42/dusk
    def math_function(self, name: str, args: t.List):
        """Translate a call to a supported math function into a SIR call.

        :param name: function name; looked up in ``self.unary_math_functions``
            and then ``self.binary_math_functions``.
        :param args: argument ASTs; their count must match the function's arity.
        :raises DuskSyntaxError: on an arity mismatch or an unknown function.
        """
        callee = f"gridtools::dawn::math::{name}"

        if name in self.unary_math_functions:
            if len(args) != 1:
                raise DuskSyntaxError(f"Function '{name}' takes exactly one argument!")
            (only_arg,) = args
            return make_fun_call_expr(callee, [self.expression(only_arg)])

        if name in self.binary_math_functions:
            if len(args) != 2:
                raise DuskSyntaxError(f"Function '{name}' takes exactly two arguments!")
            return make_fun_call_expr(callee, list(map(self.expression, args)))

        raise DuskSyntaxError(f"Unrecognized function call '{name}'!")
コード例 #15
0
 def vertical_interval_bound(self, bound):
     """Translate a vertical interval bound into ``(sir.Interval level, offset)``.

     A non-negative int literal ``k`` counts from the start (``Start, k``);
     a negated int literal ``-k`` counts from the end (``End, k``).

     :raises DuskSyntaxError: for any other bound expression.
     """
     if does_match(Constant(value=int, kind=None), bound):
         return sir.Interval.Start, bound.value

     negated_int = UnaryOp(op=USub, operand=Constant(value=int, kind=None))
     if does_match(negated_int, bound):
         return sir.Interval.End, -bound.operand.value

     raise DuskSyntaxError(
         f"Unrecognized vertical intervals bound '{bound}'!", bound)
コード例 #16
0
    def var(self, name: str, index: expr = None):
        """Translate a reference to variable ``name`` (optionally indexed).

        Only fields are supported; they are handed to ``self.field_access_expr``.

        :raises DuskSyntaxError: if ``name`` isn't declared in the current scope.
        :raises DuskInternalError: if the symbol is of an unknown kind.
        """
        scope = self.ctx.scope.current_scope
        if not scope.contains(name):
            raise DuskSyntaxError(f"Undeclared variable '{name}'!", name)

        symbol = scope.fetch(name)
        if not isinstance(symbol, DuskField):
            raise DuskInternalError(
                f"Encountered unknown symbol type '{symbol}' ('{name}')!")
        return self.field_access_expr(symbol, index)
コード例 #17
0
    def match(self, node, **kwargs):
        """Succeed if ``node`` matches at least one of ``self.matchers``.

        Matchers are tried in order; the first one that doesn't raise a
        ``DuskSyntaxError`` wins.

        :raises DuskSyntaxError: if no matcher accepts ``node``.
        """
        for candidate in self.matchers:
            try:
                match(candidate, node, **kwargs)
            except DuskSyntaxError:
                continue
            else:
                return

        raise DuskSyntaxError(f"Encountered unrecognized node '{node}'!", node)
コード例 #18
0
    def funcall(self, name: str, node: Call):
        """Dispatch a function call to the translator handling ``name``.

        NOTE(review): ``math_function`` receives the whole ``Call`` node although
        its visible signature takes ``(name, args)``; presumably a decorator
        adapts the call -- confirm.

        :raises DuskSyntaxError: for unrecognized function names.
        """
        # TODO: bad hardcoded string
        if name == "reduce_over":
            handler = self.reduce_over
        elif name in {"sum_over", "min_over", "max_over"}:
            handler = self.short_reduce_over
        elif name in self.unary_math_functions or name in self.binary_math_functions:
            handler = self.math_function
        else:
            raise DuskSyntaxError(f"Unrecognized function call '{name}'!", node)

        return handler(node)
コード例 #19
0
    def location_chain(self,
                       locations: t.List,
                       include_center: t.Optional[t.Literal["Origin"]] = None):
        """Resolve a list of location-type names into a location chain.

        :param locations: location-type names, each resolved via
            ``self.location_type``.
        :param include_center: any non-``None`` value requests that the center
            be included in the chain.
        :returns: ``(includes_center, resolved_locations)``.
        :raises DuskSyntaxError: if the center is requested for a chain that
            ``self.ctx.location.is_ambiguous`` rejects.
        """
        chain = [self.location_type(loc) for loc in locations]
        wants_center = include_center is not None

        if wants_center and not self.ctx.location.is_ambiguous(chain):
            raise DuskSyntaxError(
                "including the center is only allowed if start equals end location of the neighbor chain!"
            )
        return wants_center, chain
コード例 #20
0
ファイル: match.py プロジェクト: BenWeber42/dusk
def match_ast(matcher: AST, node, **kwargs):
    """Structurally match ``node`` against the AST pattern ``matcher``.

    ``node`` must be an instance of the matcher's type, and every field listed
    in ``matcher._fields`` is matched recursively via ``match``.

    :raises DuskSyntaxError: on a type mismatch or a failing field match; when
        the error has no location yet and ``node`` is a statement or
        expression, the location is taken from ``node``.
    """
    if not isinstance(node, type(matcher)):
        raise DuskSyntaxError(
            f"Expected node type '{type(matcher)}', but got '{type(node)}'!",
            node)

    for field_name in matcher._fields:
        sub_matcher = getattr(matcher, field_name)
        sub_node = getattr(node, field_name)
        try:
            match(sub_matcher, sub_node, **kwargs)
        except DuskSyntaxError as error:
            # attach the closest enclosing location if the error has none
            if error.loc is None and isinstance(node, (stmt, expr)):
                error.loc_from_node(node)
            raise
コード例 #21
0
ファイル: grammar.py プロジェクト: BenWeber42/dusk
 def compare(self, left: expr, op, right: expr):
     """Translate a single Python comparison into a SIR binary operator.

     :raises DuskSyntaxError: for unsupported comparison operators.
     """
     # FIXME: we should probably have a better answer when we need such mappings
     sir_op = {
         Eq: "==",
         NotEq: "!=",
         Lt: "<",
         LtE: "<=",
         Gt: ">",
         GtE: ">=",
     }.get(type(op))
     if sir_op is None:
         raise DuskSyntaxError(f"Unsupported comparison operator '{op}'!", op)
     lhs = self.expression(left)
     rhs = self.expression(right)
     return make_binary_operator(lhs, sir_op, rhs)
コード例 #22
0
    def relative_vertical_offset(self, base: str, shift: int = 0, vop=Add()):
        """Resolve a vertical index ``base +/- shift`` into ``(offset, base)``.

        :param base: name of the vertical base: a vertical iteration variable
            or an index field in the current scope.
        :param shift: magnitude of the vertical shift.
        :param vop: combining operator node; ``Sub`` negates ``shift``.
        :returns: ``(offset, base_name)`` where ``base_name`` is the index
            field's SIR name, or ``None`` for a vertical iteration variable.
        :raises DuskSyntaxError: if ``base`` names neither kind of symbol, or an
            index field is used inside an ambiguous neighbor iteration.
        """
        symbol = self.ctx.scope.current_scope.fetch(base)

        if isinstance(symbol, DuskIndexField):
            # TODO: check that `base` is valid in this context
            #       * compatible neighbor iteration
            #       * _correct_ usage (not clear what this entails)
            base_name = symbol.sir.name
            if (self.ctx.location.in_neighbor_iteration
                    and self.ctx.location.is_ambiguous(
                        self.ctx.location.current_neighbor_iteration)):
                raise DuskSyntaxError(
                    f"Index field {base_name} used in ambiguous neighbor iteration!"
                )
        elif isinstance(symbol, VerticalIterationVariable):
            base_name = None
        else:
            # report the name as written (the old code overwrote `base` with
            # the looked-up symbol and printed that object instead)
            raise DuskSyntaxError(
                f"'{base}' isn't a vertical iteration variable or index field!",
                base)

        return (-shift if isinstance(vop, Sub) else shift), base_name
コード例 #23
0
ファイル: grammar.py プロジェクト: BenWeber42/dusk
    def constant(self, value):
        """Translate a Python literal into a SIR literal access expression.

        Supported: ``bool`` -> Boolean, ``int`` -> Integer, ``float`` -> Double
        (exact type match, so ``bool`` is not treated as ``int``).

        :raises DuskSyntaxError: for unsupported types and for non-finite floats.
        """
        import math

        # TODO: properly distinguish between float and double
        built_in_type_map = {bool: "Boolean", int: "Integer", float: "Double"}

        if type(value) not in built_in_type_map:
            raise DuskSyntaxError(
                f"Unsupported constant '{value}' of type '{type(value)}'!", value
            )

        # Python parses overflowing literals such as `1e400` to `inf`, whose
        # string form "inf" is not a valid literal in the generated code
        if isinstance(value, float) and not math.isfinite(value):
            raise DuskSyntaxError(
                f"Unsupported non-finite constant '{value}'!", value
            )

        _type = sir.BuiltinType.TypeID.Value(built_in_type_map[type(value)])

        if isinstance(value, bool):
            value = "true" if value else "false"
        else:
            # finite floats: `str` yields the shortest round-tripping form
            # (possibly in exponent notation, e.g. "1e+16")
            value = str(value)

        return make_literal_access_expr(value, _type,)
コード例 #24
0
def dispatch(rules: t.Dict[t.Any, t.Callable], node):
    """Apply the first rule whose recognizer matches ``node``.

    Rules are tried in dict order; matching stops at the first hit.

    :returns: whatever the selected rule returns for ``node``.
    :raises DuskSyntaxError: if no recognizer matches.
    """
    unmatched = object()
    selected = next(
        (rule for recognizer, rule in rules.items() if does_match(recognizer, node)),
        unmatched,
    )
    if selected is unmatched:
        raise DuskSyntaxError(f"Unrecognized node: '{node}'!", node)
    return selected(node)
コード例 #25
0
ファイル: match.py プロジェクト: BenWeber42/dusk
def match_type(matcher: type, node, **kwargs):
    """Require ``node`` to be an instance of ``matcher``.

    :raises DuskSyntaxError: if the ``isinstance`` check fails.
    """
    if isinstance(node, matcher):
        return
    raise DuskSyntaxError(
        f"Expected type '{matcher}', but got '{type(node)}'", node)
コード例 #26
0
    def location_type(self, name: str):
        """Map a location-type class name to its ``sir.LocationType`` value.

        :raises DuskSyntaxError: if ``name`` isn't the name of one of
            ``LOCATION_TYPES``.
        """
        valid_names = frozenset(location.__name__ for location in LOCATION_TYPES)
        if name in valid_names:
            return sir.LocationType.Value(name)
        raise DuskSyntaxError(f"Invalid location type '{name}'!", name)
コード例 #27
0
ファイル: match.py プロジェクト: BenWeber42/dusk
def match_primitives(matcher, node, **kwargs):
    """Require ``node`` to compare equal to the primitive ``matcher``.

    :raises DuskSyntaxError: if ``matcher != node``.
    """
    mismatch = matcher != node
    if mismatch:
        raise DuskSyntaxError(f"Expected '{matcher}', but got '{node}'!", node)