def transform_is_same(self, args: Tuple[ir.Expr, ...]) -> ir.Expr:
    """Simplify an is-same check between two list-template instantiations.

    std::is_same<List<X1, ..., Xn>, List<Y1, ..., Yn>>::value becomes
    std::is_same<X1, Y1>::value && ... && std::is_same<Xn, Yn>::value, and for
    BoolList / Int64List it becomes (X1 == Y1) && ... && (Xn == Yn).

    The rewrite only fires when both sides instantiate the same list template
    (an AtomicTypeLiteral whose cpp_type is 'List', 'BoolList' or 'Int64List')
    with the same, non-zero number of arguments and no variadic expansions.
    Otherwise the pair is handed unchanged to self._create_is_same_expr.

    Args:
        args: exactly two expressions, the lhs and rhs of the is-same check.

    Returns:
        The expanded conjunction after self.transform_expr, or
        self._create_is_same_expr(lhs, rhs) when no expansion applies.
    """
    import functools

    assert len(args) == 2
    lhs, rhs = args
    list_template_names = {'List', 'BoolList', 'Int64List'}

    def is_expandable_list(expr: ir.Expr) -> bool:
        # Variadic expansions can stand for any number of elements, so
        # presumably positional pairing would not be valid with them present.
        return (isinstance(expr, ir.TemplateInstantiation)
                and isinstance(expr.template_expr, ir.AtomicTypeLiteral)
                and expr.template_expr.cpp_type in list_template_names
                and not any(isinstance(arg, ir.VariadicTypeExpansion)
                            for arg in expr.args))

    if not (is_expandable_list(lhs)
            and is_expandable_list(rhs)
            and lhs.template_expr.cpp_type == rhs.template_expr.cpp_type
            and len(lhs.args) == len(rhs.args)
            and lhs.args):
        return self._create_is_same_expr(lhs, rhs)

    pairs = zip(lhs.args, rhs.args)
    if lhs.template_expr.cpp_type == 'List':
        # List<...> holds types: compare element types pairwise with is_same.
        element_checks = [self._create_is_same_expr(lhs_arg, rhs_arg)
                          for lhs_arg, rhs_arg in pairs]
    else:
        # BoolList<...> / Int64List<...> hold values: compare them with ==.
        element_checks = [ir.ComparisonExpr(lhs_arg, rhs_arg, op='==')
                          for lhs_arg, rhs_arg in pairs]

    # lhs.args is non-empty (checked above), so reduce needs no initial value
    # and no "first element" sentinel is required. The fold is left-nested:
    # ((c1 && c2) && c3) ...
    conjunction = functools.reduce(
        lambda acc, check: ir.BoolBinaryOpExpr(lhs=acc, rhs=check, op='&&'),
        element_checks)
    return self.transform_expr(conjunction)
def transform_bool_binary_op_expr(
        self, binary_op: ir.BoolBinaryOpExpr) -> ir.Expr:
    """Transform both operands, then constant-fold `&&` / `||`.

    Folding rules (checked in this order, for op in ('&&', '||')):
      * both operands literal   -> literal of (lhs op rhs)
      * one side is the identity (True for &&, False for ||) -> the other side
      * one side is the absorbing value (False for &&, True for ||) -> that
        literal, but only if self._can_remove_subexpression allows dropping
        the other side
    Otherwise a new BoolBinaryOpExpr is built from the transformed operands.
    """
    op = binary_op.op
    new_lhs = self.transform_expr(binary_op.lhs)
    new_rhs = self.transform_expr(binary_op.rhs)

    if op in ('&&', '||'):
        lhs_is_literal = isinstance(new_lhs, ir.Literal)
        rhs_is_literal = isinstance(new_rhs, ir.Literal)

        # true && false => false ; true || false => true
        if lhs_is_literal and rhs_is_literal:
            if op == '&&':
                return ir.Literal(new_lhs.value and new_rhs.value)
            return ir.Literal(new_lhs.value or new_rhs.value)

        identity = (op == '&&')
        absorbing = not identity

        # true && x => x ; false || x => x (and the mirrored forms)
        if lhs_is_literal and new_lhs.value is identity:
            return new_rhs
        if rhs_is_literal and new_rhs.value is identity:
            return new_lhs

        # false && x => false ; true || x => true (and the mirrored forms),
        # only when the discarded operand may legally be dropped.
        if (lhs_is_literal and new_lhs.value is absorbing
                and self._can_remove_subexpression(new_rhs)):
            return ir.Literal(absorbing)
        if (rhs_is_literal and new_rhs.value is absorbing
                and self._can_remove_subexpression(new_lhs)):
            return ir.Literal(absorbing)

    return ir.BoolBinaryOpExpr(new_lhs, new_rhs, op)
def transform_not_expr(self, not_expr: ir.NotExpr) -> ir.Expr:
    """Simplify a logical negation after transforming its operand.

    Rewrites applied, in order:
      * !true => false, !false => true
      * !!x => x
      * !(x && y) => !x || !y and !(x || y) => !x && !y (result re-transformed)
      * !(x == y) => x != y, !(x != y) => x == y,
        !(x < y) => x >= y, !(x <= y) => x > y,
        !(x > y) => x <= y, !(x >= y) => x < y
    Any other operand is returned wrapped in a new NotExpr.
    """
    # Negated form of every comparison operator. Previously the guard only
    # admitted '==' / '!=', which left the ordering entries unreachable.
    # NOTE(review): the ordering rewrites assume operands are totally ordered
    # (e.g. integers), so that !(x < y) is exactly x >= y -- confirm.
    negated_comparison_ops = {
        '==': '!=',
        '!=': '==',
        '<': '>=',
        '<=': '>',
        '>': '<=',
        '>=': '<',
    }

    expr = self.transform_expr(not_expr.inner_expr)

    # !true => false ; !false => true
    if isinstance(expr, ir.Literal):
        assert isinstance(expr.value, bool)
        return ir.Literal(not expr.value)

    # !!x => x
    if isinstance(expr, ir.NotExpr):
        return expr.inner_expr

    # De Morgan: push the negation into both operands and re-simplify.
    if isinstance(expr, ir.BoolBinaryOpExpr):
        negated_op = {
            '&&': '||',
            '||': '&&',
        }[expr.op]
        return self.transform_expr(
            ir.BoolBinaryOpExpr(lhs=ir.NotExpr(expr.lhs),
                                rhs=ir.NotExpr(expr.rhs),
                                op=negated_op))

    # !(x OP y) => x NEG_OP y
    if isinstance(expr, ir.ComparisonExpr) and expr.op in negated_comparison_ops:
        return ir.ComparisonExpr(expr.lhs, expr.rhs,
                                 negated_comparison_ops[expr.op])

    return ir.NotExpr(expr)
def transform_bool_binary_op_expr(
        self, binary_op: ir.BoolBinaryOpExpr) -> ir.Expr:
    """Rebuild a boolean binary expression from its transformed operands.

    Both operands are transformed together via self.transform_exprs, which
    also receives the original node. The operator is carried over unchanged
    and no folding is performed here.
    """
    operands = [binary_op.lhs, binary_op.rhs]
    new_lhs, new_rhs = self.transform_exprs(operands, binary_op)
    return ir.BoolBinaryOpExpr(op=binary_op.op, lhs=new_lhs, rhs=new_rhs)