def visit_FieldRef(
    self, path: tuple, node_name: str, node: gt_ir.FieldRef
) -> Tuple[bool, Union[gt_ir.FieldRef, gt_ir.VarRef]]:
    """Rewrite a reference to a demotable field as a reference to a scalar local.

    If ``node.name`` is not listed in ``self.demotables``, the node is returned
    unchanged. Otherwise a ``gt_ir.VarDecl`` of length 1 (copying the field's
    ``data_type`` and ``loc``) is registered in ``self.local_symbols`` the first
    time the name is seen, and a ``gt_ir.VarRef`` at index 0 is returned in
    place of the field reference.

    ``path`` and ``node_name`` are part of the visitor signature and are not
    used here.

    Returns:
        ``(True, new_node)``. The flag is always ``True`` in this method;
        presumably it tells the visitor framework to keep traversing -- confirm
        against the base visitor.
    """
    name = node.name

    # Fields that are not candidates for demotion stay as field references.
    if name not in self.demotables:
        return True, node

    # Declare the replacement scalar only once, on the first reference.
    if name not in self.local_symbols:
        source_decl = self.fields[name]
        self.local_symbols[name] = gt_ir.VarDecl(
            name=name,
            data_type=source_decl.data_type,
            length=1,
            is_api=False,
            loc=source_decl.loc,
        )

    return True, gt_ir.VarRef(name=name, index=0, loc=node.loc)
def _make_axis_interval(self, interval: IntervalInfo):
    """Build a ``gt_ir.AxisInterval`` from an ``IntervalInfo``.

    ``interval.start`` and ``interval.end`` are each indexed as
    ``(level_index, offset)``. A level index is mapped as follows:

    * ``0`` becomes ``LevelMarker.START``.
    * ``self.data.nk_intervals`` becomes ``LevelMarker.END``.
    * Anything else becomes element ``level_index - 1`` of the splitters
      variable named ``self.data.splitters_var``.

    The offset is carried over unchanged in every case.
    """

    def to_axis_bound(bound):
        level_index, offset = bound[0], bound[1]
        if level_index == 0:
            level = gt_ir.LevelMarker.START
        elif level_index == self.data.nk_intervals:
            level = gt_ir.LevelMarker.END
        else:
            # Interior levels are looked up in the splitters array (0-based).
            level = gt_ir.VarRef(name=self.data.splitters_var, index=level_index - 1)
        return gt_ir.AxisBound(level=level, offset=offset)

    start_bound = to_axis_bound(interval.start)
    end_bound = to_axis_bound(interval.end)
    return gt_ir.AxisInterval(start=start_bound, end=end_bound)
def visit_Name(self, node: ast.Name) -> gt_ir.Ref:
    """Translate a bare Python name into an IR reference.

    * A field name becomes a ``gt_ir.FieldRef`` with a zero offset along the
      domain axes.
    * A parameter name becomes a ``gt_ir.VarRef``.

    Raises:
        AssertionError: if the name is a local symbol (that translation is
            currently disabled) or is not defined at all.
    """
    symbol = node.id

    if self._is_field(symbol):
        # Zero offset on each domain axis; zip stops after at most three axes.
        return gt_ir.FieldRef(
            name=symbol,
            offset=dict(zip(self.domain.axes_names, (0, 0, 0))),
        )

    if self._is_parameter(symbol):
        return gt_ir.VarRef(name=symbol)

    # Raise explicitly rather than `assert False`: asserts are stripped under
    # `python -O`, which previously turned these paths into an
    # UnboundLocalError on the return value instead of a clear diagnostic.
    if self._is_local_symbol(symbol):
        # Previously intended translation (disabled): gt_ir.VarRef(name=symbol)
        raise AssertionError("Unexpected reference to local symbol '{}'".format(symbol))

    raise AssertionError("Missing '{}' symbol definition".format(symbol))