Пример #1
0
def _make_axis_offset_expr(bound: common.AxisBound,
                           axis_index: int) -> cuir.Expr:
    """Translate a horizontal-axis bound into an INT32 CUIR expression.

    Args:
        bound: Axis bound; ``bound.level`` selects the form of the result and
            ``bound.offset`` is emitted (via ``str``) as an INT32 literal.
        axis_index: Axis position; 0 maps to ``i`` and 1 to ``j`` when the
            bound is END-relative.

    Returns:
        ``<axis>_size + offset`` when ``bound.level`` is
        ``LevelMarker.END``; otherwise the bare ``offset`` literal.
    """
    if bound.level == common.LevelMarker.END:
        # END-relative: add the offset to a scalar named ``i_size``/``j_size``
        # (presumably the axis extent -- confirm against the code generator).
        axis_name = ("i", "j")[axis_index]
        size_access = cuir.ScalarAccess(name=f"{axis_name}_size",
                                        dtype=common.DataType.INT32)
        offset_literal = cuir.Literal(value=str(bound.offset),
                                      dtype=common.DataType.INT32)
        return cuir.BinaryOp(op=common.ArithmeticOperator.ADD,
                             left=size_access,
                             right=offset_literal)
    # Any other level: emit the offset on its own.
    return cuir.Literal(value=str(bound.offset), dtype=common.DataType.INT32)
Пример #2
0
 def _mask_to_expr(self, mask: common.HorizontalMask,
                   ctx: "Context") -> cuir.Expr:
     """Lower a horizontal mask to a single CUIR boolean expression.

     For the interval at position ``axis_index`` in ``mask.intervals``:

     * a single-index interval yields ``positional == start``;
     * otherwise ``positional >= start`` and ``positional < end`` are
       emitted, each only when that endpoint is not ``None``.

     ``positional`` is ``ctx.make_positional(axis_index)`` and every bound is
     lowered with ``_make_axis_offset_expr``.

     Args:
         mask: Horizontal mask whose ``intervals`` are visited in order.
         ctx: Lowering context used to build positional index expressions.

     Returns:
         All comparisons combined left-to-right with logical AND, or a BOOL
         ``TRUE`` literal when no comparison was produced.
     """
     conditions: List[cuir.Expr] = []

     def compare(op, axis, bound):
         # Positional index is built before the bound expression.
         return cuir.BinaryOp(
             op=op,
             left=ctx.make_positional(axis),
             right=_make_axis_offset_expr(bound, axis),
         )

     for axis, interval in enumerate(mask.intervals):
         if interval.is_single_index():
             conditions.append(
                 compare(common.ComparisonOperator.EQ, axis, interval.start))
             continue
         # Half-open range [start, end); a None endpoint leaves that side open.
         bounds = (
             (common.ComparisonOperator.GE, interval.start),
             (common.ComparisonOperator.LT, interval.end),
         )
         conditions.extend(
             compare(op, axis, endpt)
             for op, endpt in bounds
             if endpt is not None)

     if not conditions:
         return cuir.Literal(value=common.BuiltInLiteral.TRUE,
                             dtype=common.DataType.BOOL)
     # Left fold: ((c0 AND c1) AND c2) ...
     combined = conditions[0]
     for condition in conditions[1:]:
         combined = cuir.BinaryOp(op=common.LogicalOperator.AND,
                                  left=combined,
                                  right=condition)
     return combined
Пример #3
0
 def visit_Literal(self, node: oir.Literal, **kwargs: Any) -> cuir.Literal:
     """Lower an OIR literal to a CUIR literal with the same value and dtype.

     Extra visitor keyword arguments are accepted and ignored.
     """
     value, dtype = node.value, node.dtype
     return cuir.Literal(value=value, dtype=dtype)