예제 #1
0
def _make_axis_offset_expr(
    bound: common.AxisBound,
    axis_index: int,
    axis_length_accessor: Callable[[int], gtcpp.AccessorRef],
) -> gtcpp.Expr:
    """Return an INT32 gtcpp expression for the position described by ``bound``.

    Args:
        bound: Axis bound holding a ``level`` and an integer-like ``offset``.
        axis_index: Index of the axis the bound refers to.
        axis_length_accessor: Called with ``axis_index`` to obtain an accessor
            for that axis' length; only used for END-relative bounds.

    Returns:
        ``axis_length_accessor(axis_index) + offset`` when ``bound.level`` is
        ``common.LevelMarker.END``; otherwise just the literal ``offset``
        (presumably measured from the start of the axis — confirm against the
        AxisBound definition).
    """
    if bound.level != common.LevelMarker.END:
        # Start-relative bound: the offset alone is the position.
        return gtcpp.Literal(value=str(bound.offset),
                             dtype=common.DataType.INT32)

    # End-relative bound: shift the axis length by the (typically
    # non-positive) offset.
    axis_length = axis_length_accessor(axis_index)
    offset_literal = gtcpp.Literal(value=str(bound.offset),
                                   dtype=common.DataType.INT32)
    return gtcpp.BinaryOp(op=common.ArithmeticOperator.ADD,
                          left=axis_length,
                          right=offset_literal)
예제 #2
0
 def _mask_to_expr(self, mask: common.HorizontalMask,
                   comp_ctx: "GTComputationContext") -> gtcpp.Expr:
     """Build a boolean gtcpp expression selecting the points inside ``mask``.

     For each axis ``i`` of ``mask.intervals``:
       * a single-index interval yields ``positional(i) == start``;
       * any other interval yields ``positional(i) >= start`` and
         ``positional(i) < end``, omitting whichever endpoint is ``None``.

     All comparisons are combined with logical AND, folded left to right in
     the order they were produced.  If no comparison is produced, the
     boolean literal TRUE is returned.

     Args:
         mask: Horizontal mask whose ``intervals`` are enumerated per axis.
         comp_ctx: Provides ``make_positional(axis)`` for the left-hand side
             and ``make_length`` for END-relative bounds.
     """
     comparisons: List[gtcpp.Expr] = []
     for axis, axis_interval in enumerate(mask.intervals):
         # Normalize every interval into (operator, bound) pairs first.
         if axis_interval.is_single_index():
             op_bound_pairs = ((common.ComparisonOperator.EQ,
                                axis_interval.start),)
         else:
             candidate_pairs = (
                 (common.ComparisonOperator.GE, axis_interval.start),
                 (common.ComparisonOperator.LT, axis_interval.end),
             )
             # Open-ended sides (None) contribute no constraint.
             op_bound_pairs = tuple((cmp_op, bound)
                                    for cmp_op, bound in candidate_pairs
                                    if bound is not None)

         for cmp_op, bound in op_bound_pairs:
             # Build the positional index before the offset expression so
             # calls into comp_ctx happen in a stable order.
             position = comp_ctx.make_positional(axis)
             offset_expr = _make_axis_offset_expr(bound, axis,
                                                  comp_ctx.make_length)
             comparisons.append(
                 gtcpp.BinaryOp(op=cmp_op, left=position, right=offset_expr))

     if not comparisons:
         return gtcpp.Literal(value=common.BuiltInLiteral.TRUE,
                              dtype=common.DataType.BOOL)

     # Left fold: ((c0 AND c1) AND c2) ...
     combined = comparisons[0]
     for term in comparisons[1:]:
         combined = gtcpp.BinaryOp(op=common.LogicalOperator.AND,
                                   left=combined,
                                   right=term)
     return combined
예제 #3
0
 def visit_Literal(self, node: oir.Literal, **kwargs: Any) -> gtcpp.Literal:
     """Lower an OIR literal to a gtcpp literal.

     ``node.value`` and ``node.dtype`` are carried over unchanged; any extra
     visitor keyword arguments are accepted and ignored.
     """
     literal_value = node.value
     literal_dtype = node.dtype
     return gtcpp.Literal(value=literal_value, dtype=literal_dtype)