def _make_axis_offset_expr(
    bound: common.AxisBound,
    axis_index: int,
    axis_length_accessor: Callable[[int], gtcpp.AccessorRef],
) -> gtcpp.Expr:
    """Translate an axis bound into a gtcpp integer expression.

    Args:
        bound: The bound to translate; its ``level`` and ``offset`` are read.
        axis_index: Index of the axis the bound belongs to; it is passed to
            ``axis_length_accessor`` when the length is needed.
        axis_length_accessor: Callable mapping an axis index to a reference
            to that axis' length.

    Returns:
        For a bound at ``LevelMarker.END``, the expression
        ``length(axis_index) + offset``. For any other level, just the offset
        as an INT32 literal (treated as absolute; presumably START-relative,
        i.e. measured from index 0 — confirm against ``LevelMarker``).
    """
    offset_text = str(bound.offset)

    # Non-END bounds need no length lookup: the offset alone is the position.
    if bound.level != common.LevelMarker.END:
        return gtcpp.Literal(value=offset_text, dtype=common.DataType.INT32)

    # END-relative bound: anchor the offset at the axis length.
    length_ref = axis_length_accessor(axis_index)
    offset_literal = gtcpp.Literal(value=offset_text, dtype=common.DataType.INT32)
    return gtcpp.BinaryOp(
        op=common.ArithmeticOperator.ADD,
        left=length_ref,
        right=offset_literal,
    )
def _mask_to_expr(self, mask: common.HorizontalMask, comp_ctx: "GTComputationContext") -> gtcpp.Expr:
    """Build a boolean gtcpp expression that selects the points inside a horizontal mask.

    For every axis interval in ``mask.intervals`` (enumerated, so the list
    position is the axis index):

    * a single-index interval contributes ``pos == start``;
    * any other interval contributes ``pos >= start`` and ``pos < end``,
      omitting whichever endpoint is ``None`` (unbounded on that side).

    Each bound is converted with ``_make_axis_offset_expr``, using
    ``comp_ctx.make_length`` as the axis-length accessor. All conditions are
    combined with a left-associated logical AND in the order they were
    produced.

    Args:
        mask: The horizontal mask to translate.
        comp_ctx: Context providing ``make_positional`` and ``make_length``.

    Returns:
        The combined condition. If no condition was produced, the literal
        boolean ``TRUE``.
    """

    def compare(op: Any, axis_index: int, bound: common.AxisBound) -> gtcpp.Expr:
        # A fresh positional reference is requested for every comparison,
        # in the same order as the conditions are emitted.
        return gtcpp.BinaryOp(
            op=op,
            left=comp_ctx.make_positional(axis_index),
            right=_make_axis_offset_expr(bound, axis_index, comp_ctx.make_length),
        )

    conditions: List[gtcpp.Expr] = []
    for axis_index, interval in enumerate(mask.intervals):
        if interval.is_single_index():
            conditions.append(compare(common.ComparisonOperator.EQ, axis_index, interval.start))
            continue
        # Half-open range [start, end); a missing endpoint leaves that side open.
        if interval.start is not None:
            conditions.append(compare(common.ComparisonOperator.GE, axis_index, interval.start))
        if interval.end is not None:
            conditions.append(compare(common.ComparisonOperator.LT, axis_index, interval.end))

    if not conditions:
        return gtcpp.Literal(value=common.BuiltInLiteral.TRUE, dtype=common.DataType.BOOL)

    # Left fold: ((c0 AND c1) AND c2) AND ...
    combined = conditions[0]
    for condition in conditions[1:]:
        combined = gtcpp.BinaryOp(op=common.LogicalOperator.AND, left=combined, right=condition)
    return combined
def visit_Literal(self, node: oir.Literal, **kwargs: Any) -> gtcpp.Literal:
    """Lower an OIR literal to a gtcpp literal with the same value and dtype.

    Extra keyword arguments from the visitor are accepted and ignored.
    """
    literal_value = node.value
    literal_dtype = node.dtype
    return gtcpp.Literal(value=literal_value, dtype=literal_dtype)