def _make_axis_offset_expr(bound: common.AxisBound, axis_index: int) -> cuir.Expr:
    """Translate a horizontal axis bound into an INT32 CUIR expression.

    Args:
        bound: The axis bound to translate.
        axis_index: 0 selects the ``i`` axis, 1 selects the ``j`` axis. Any
            other value is used directly to index ``("i", "j")``.

    Returns:
        ``Literal(offset)`` when ``bound.level`` is not ``LevelMarker.END``.
        Otherwise ``<axis>_size + offset``: a ``BinaryOp`` adding the
        ``i_size``/``j_size`` scalar to the offset literal.
    """
    # Guard clause: bounds not anchored at END are just the raw offset.
    if bound.level != common.LevelMarker.END:
        return cuir.Literal(value=str(bound.offset), dtype=common.DataType.INT32)

    # END-relative bounds are expressed relative to the axis extent scalar.
    axis_name = ("i", "j")[axis_index]
    size_access = cuir.ScalarAccess(name=f"{axis_name}_size", dtype=common.DataType.INT32)
    offset_literal = cuir.Literal(value=str(bound.offset), dtype=common.DataType.INT32)
    return cuir.BinaryOp(
        op=common.ArithmeticOperator.ADD,
        left=size_access,
        right=offset_literal,
    )
def _mask_to_expr(self, mask: common.HorizontalMask, ctx: "Context") -> cuir.Expr:
    """Lower a horizontal mask into a single boolean CUIR expression.

    Each interval in ``mask.intervals`` is paired with its position (the axis
    index) and contributes comparisons between ``ctx.make_positional(axis)``
    and the bound expression from ``_make_axis_offset_expr``:

    * if ``interval.is_single_index()``: ``position == start``;
    * otherwise: ``position >= start`` (if ``start`` is not None) and
      ``position < end`` (if ``end`` is not None).

    Args:
        mask: The horizontal mask whose intervals are lowered.
        ctx: Lowering context used to create positional accessors.

    Returns:
        All comparisons AND-combined as a left-nested chain of ``BinaryOp``s,
        in the order they were generated. If no comparison was generated, a
        ``TRUE`` boolean literal is returned.
    """
    conditions: List[cuir.Expr] = []

    for axis, interval in enumerate(mask.intervals):
        if interval.is_single_index():
            checks = [(common.ComparisonOperator.EQ, interval.start)]
        else:
            # Half-open range [start, end); a missing endpoint means that side
            # is unconstrained, so no comparison is emitted for it.
            candidates = (
                (common.ComparisonOperator.GE, interval.start),
                (common.ComparisonOperator.LT, interval.end),
            )
            checks = [(cmp_op, bound) for cmp_op, bound in candidates if bound is not None]

        for cmp_op, bound in checks:
            conditions.append(
                cuir.BinaryOp(
                    op=cmp_op,
                    left=ctx.make_positional(axis),
                    right=_make_axis_offset_expr(bound, axis),
                )
            )

    # An unconstrained mask selects every point.
    if not conditions:
        return cuir.Literal(value=common.BuiltInLiteral.TRUE, dtype=common.DataType.BOOL)

    return functools.reduce(
        lambda acc, cond: cuir.BinaryOp(op=common.LogicalOperator.AND, left=acc, right=cond),
        conditions,
    )
def visit_Literal(self, node: oir.Literal, **kwargs: Any) -> cuir.Literal:
    """Convert an OIR literal into a CUIR literal with the same value and dtype.

    Extra keyword arguments from the visitor are accepted and ignored.
    """
    value, dtype = node.value, node.dtype
    return cuir.Literal(value=value, dtype=dtype)