예제 #1
0
        def visit_FieldRef(
            self, path: tuple, node_name: str, node: gt_ir.FieldRef
        ) -> Tuple[bool, Union[gt_ir.FieldRef, gt_ir.VarRef]]:
            """Rewrite a reference to a demotable field as a scalar variable reference.

            Field names listed in ``self.demotables`` are replaced by a
            ``gt_ir.VarRef`` with ``index=0``. The first time a given field is
            demoted, a ``gt_ir.VarDecl`` with ``length=1`` and ``is_api=False`` is
            registered in ``self.local_symbols``. It copies ``data_type`` and
            ``loc`` from the field's declaration in ``self.fields``. Every other
            field reference is returned unchanged.

            ``path`` and ``node_name`` are accepted for the visitor signature but
            are not used here.

            Returns:
                ``(True, new_node)``, where ``new_node`` is either the new
                ``VarRef`` or the original ``FieldRef``. NOTE(review): the leading
                ``True`` is presumably the traversal framework's "keep/continue"
                flag; confirm against the visitor base class.
            """
            name = node.name
            if name not in self.demotables:
                # Not a demotion candidate: leave the field reference untouched.
                return True, node

            if name not in self.local_symbols:
                # Declare the scalar replacement only once, mirroring the
                # original field declaration's type and source location.
                decl = self.fields[name]
                self.local_symbols[name] = gt_ir.VarDecl(
                    name=name,
                    data_type=decl.data_type,
                    length=1,
                    is_api=False,
                    loc=decl.loc,
                )

            return True, gt_ir.VarRef(name=name, index=0, loc=node.loc)
예제 #2
0
    def _make_axis_interval(self, interval: IntervalInfo):
        """Build a ``gt_ir.AxisInterval`` from ``interval.start`` and ``interval.end``.

        Each endpoint is read as ``bound[0]`` (an interval index) and
        ``bound[1]`` (an offset), and the level is chosen from the index:

        * ``0`` maps to ``gt_ir.LevelMarker.START``;
        * ``self.data.nk_intervals`` maps to ``gt_ir.LevelMarker.END``;
        * any other index becomes a ``gt_ir.VarRef`` to element ``index - 1``
          of the variable named by ``self.data.splitters_var``.

        The offset is passed through unchanged.
        """

        def to_axis_bound(bound):
            # Translate a single (index, offset) endpoint into an AxisBound.
            position = bound[0]
            if position == 0:
                level = gt_ir.LevelMarker.START
            elif position == self.data.nk_intervals:
                level = gt_ir.LevelMarker.END
            else:
                # Interior boundary: reference element ``position - 1`` of the
                # splitters variable.
                level = gt_ir.VarRef(name=self.data.splitters_var, index=position - 1)
            return gt_ir.AxisBound(level=level, offset=bound[1])

        start_bound = to_axis_bound(interval.start)
        end_bound = to_axis_bound(interval.end)
        return gt_ir.AxisInterval(start=start_bound, end=end_bound)
예제 #3
0
    def visit_Name(self, node: ast.Name) -> gt_ir.Ref:
        """Translate a Python ``ast.Name`` into a gt_ir reference.

        * A field name becomes a ``gt_ir.FieldRef`` with a zero offset on every
          axis in ``self.domain.axes_names``.
        * A parameter name becomes a ``gt_ir.VarRef``.

        Raises:
            AssertionError: if ``node.id`` is a local symbol, which is not
                handled here yet, or if it is neither a field, a parameter nor
                a local symbol.
        """
        symbol = node.id
        if self._is_field(symbol):
            # An unsubscripted field name is referenced with zero offset on
            # every domain axis. The comprehension covers all axes; the previous
            # zip against (0, 0, 0) silently dropped any axis beyond the third.
            return gt_ir.FieldRef(
                name=symbol,
                offset={axis: 0 for axis in self.domain.axes_names},
            )
        if self._is_parameter(symbol):
            return gt_ir.VarRef(name=symbol)

        # The failures below raise explicitly instead of using `assert False`.
        # Assertions are stripped under `python -O`, which let execution fall
        # through to an unbound `result`. AssertionError is kept so callers
        # still see the same exception type.
        if self._is_local_symbol(symbol):
            # TODO: local symbols are not supported yet; the intended
            # translation appears to be gt_ir.VarRef(name=symbol).
            raise AssertionError("Local symbol '{}' is not supported here".format(symbol))
        raise AssertionError("Missing '{}' symbol definition".format(symbol))