def transform_is_same(self, args: Tuple[ir.Expr, ...]):
    """Simplify ``std::is_same<lhs, rhs>::value`` where possible.

    Element-wise expansion applies when both arguments are instantiations of
    the same list template (``List``, ``BoolList`` or ``Int64List``) with:

    - the same non-zero number of arguments, and
    - no variadic expansion among those arguments.

    In that case the check is expanded element-wise into a conjunction:

        std::is_same<List<X1, X2, X3>, List<Y1, Y2, Y3>>::value
          -> std::is_same<X1, Y1>::value && std::is_same<X2, Y2>::value
             && std::is_same<X3, Y3>::value

        std::is_same<Int64List<n1, n2>, Int64List<m1, m2>>::value
          -> (n1 == m1) && (n2 == m2)        (same for BoolList)

    The expanded expression is passed through ``self.transform_expr`` again.
    In every other case the result is ``self._create_is_same_expr(lhs, rhs)``.

    :param args: exactly two expressions, the arguments of ``std::is_same``.
    :returns: the (possibly expanded) replacement expression.
    """
    assert len(args) == 2
    lhs, rhs = args
    list_template_names = {'List', 'BoolList', 'Int64List'}

    def is_expandable_list(expr: ir.Expr) -> bool:
        # An instantiation of one of the list templates whose argument count
        # is fixed, i.e. no variadic expansion appears among its arguments.
        return (isinstance(expr, ir.TemplateInstantiation)
                and isinstance(expr.template_expr, ir.AtomicTypeLiteral)
                and expr.template_expr.cpp_type in list_template_names
                and not any(isinstance(arg, ir.VariadicTypeExpansion)
                            for arg in expr.args))

    if not (is_expandable_list(lhs)
            and is_expandable_list(rhs)
            and lhs.template_expr.cpp_type == rhs.template_expr.cpp_type
            and len(lhs.args) == len(rhs.args)
            and lhs.args):
        return self._create_is_same_expr(lhs, rhs)

    if lhs.template_expr.cpp_type == 'List':
        # List<...>: compare corresponding elements with is_same.
        make_conjunct = self._create_is_same_expr
    else:
        # BoolList<...> / Int64List<...>: compare corresponding elements
        # with ==.
        def make_conjunct(lhs_arg: ir.Expr, rhs_arg: ir.Expr) -> ir.Expr:
            return ir.ComparisonExpr(lhs_arg, rhs_arg, op='==')

    result = None
    for lhs_arg, rhs_arg in zip(lhs.args, rhs.args):
        conjunct = make_conjunct(lhs_arg, rhs_arg)
        # Use an explicit `is None` sentinel check rather than truthiness,
        # so an IR node can never be mistaken for "no result yet".
        if result is None:
            result = conjunct
        else:
            result = ir.BoolBinaryOpExpr(lhs=result, rhs=conjunct, op='&&')
    return self.transform_expr(result)
def transform_comparison_expr(self, comparison: ir.ComparisonExpr) -> ir.Expr:
    """Transform both operands of a comparison, then try to simplify it.

    Simplifications, in order:

    1. Both operands are literals and the operator is one of
       ``== != < <= > >=``: the comparison is evaluated to an ``ir.Literal``.
    2. ``x == x`` / ``x != x``: becomes ``true`` / ``false``. This requires
       that the operands are syntactically equal, that ``x`` cannot trigger
       static asserts, and that both sides may be removed.
    3. Equality against a bool literal (on either side):
       ``x == true`` and ``x != false`` become ``x``;
       ``x == false`` and ``x != true`` become the transformed ``!x``.

    Otherwise a new ``ir.ComparisonExpr`` is built from the transformed
    operands.
    """
    op = comparison.op
    new_lhs = self.transform_expr(comparison.lhs)
    new_rhs = self.transform_expr(comparison.rhs)

    if isinstance(new_lhs, ir.Literal) and isinstance(new_rhs, ir.Literal):
        evaluators = {
            '==': lambda a, b: a == b,
            '!=': lambda a, b: a != b,
            '<': lambda a, b: a < b,
            '<=': lambda a, b: a <= b,
            '>': lambda a, b: a > b,
            '>=': lambda a, b: a >= b,
        }
        evaluate = evaluators.get(op)
        if evaluate is not None:
            return ir.Literal(evaluate(new_lhs.value, new_rhs.value))

    # Every remaining rewrite only concerns equality / inequality.
    if op not in ('==', '!='):
        return ir.ComparisonExpr(new_lhs, new_rhs, op)

    if (self._is_syntactically_equal(new_lhs, new_rhs)
            and not expr_can_trigger_static_asserts(new_lhs)
            and self._can_remove_subexpression(new_lhs)
            and self._can_remove_subexpression(new_rhs)):
        return ir.Literal(op == '==')

    # Normalize so that a bool literal, if present, ends up on the left.
    if isinstance(new_rhs, ir.Literal) and new_rhs.expr_type == ir.BoolType():
        new_lhs, new_rhs = new_rhs, new_lhs

    if isinstance(new_lhs, ir.Literal) and new_lhs.expr_type == ir.BoolType():
        needs_negation = {
            ('==', True): False,
            ('==', False): True,
            ('!=', True): True,
            ('!=', False): False,
        }[(op, new_lhs.value)]
        if needs_negation:
            return self.transform_expr(ir.NotExpr(new_rhs))
        return new_rhs

    return ir.ComparisonExpr(new_lhs, new_rhs, op)
def transform_not_expr(self, not_expr: ir.NotExpr) -> ir.Expr:
    """Transform the operand of a logical negation, then simplify it.

    Rewrites (the first one that matches wins):

        !true => false,  !false => true
        !!x => x
        !(x && y) => (!x || !y),  !(x || y) => (!x && !y)
            (the rewritten expression is transformed again)
        !(x == y) => x != y,  !(x != y) => x == y
        !(x < y)  => x >= y,  !(x <= y) => x > y
        !(x > y)  => x <= y,  !(x >= y) => x < y

    If none applies, the result is ``ir.NotExpr`` wrapping the transformed
    operand.
    """
    # Operator that expresses the negation of each comparison operator.
    # NOTE(review): the ordering rewrites assume the compared values are
    # totally ordered (presumably int64 here; confirm), so that
    # !(x < y) is exactly x >= y.
    negated_comparison_ops = {
        '==': '!=',
        '!=': '==',
        '<': '>=',
        '<=': '>',
        '>': '<=',
        '>=': '<',
    }

    expr = self.transform_expr(not_expr.inner_expr)

    # !true => false
    # !false => true
    if isinstance(expr, ir.Literal):
        assert isinstance(expr.value, bool)
        return ir.Literal(not expr.value)

    # !!x => x (the inner operand was already transformed)
    if isinstance(expr, ir.NotExpr):
        return expr.inner_expr

    # !(x && y) => (!x || !y)
    # !(x || y) => (!x && !y)
    if isinstance(expr, ir.BoolBinaryOpExpr):
        op = {
            '&&': '||',
            '||': '&&',
        }[expr.op]
        return self.transform_expr(
            ir.BoolBinaryOpExpr(lhs=ir.NotExpr(expr.lhs),
                                rhs=ir.NotExpr(expr.rhs),
                                op=op))

    # Negating a comparison flips its operator, for all six operators in the
    # table rather than only == and !=.
    if isinstance(expr, ir.ComparisonExpr) and expr.op in negated_comparison_ops:
        return ir.ComparisonExpr(expr.lhs, expr.rhs,
                                 negated_comparison_ops[expr.op])

    return ir.NotExpr(expr)
def transform_comparison_expr(self, comparison: ir.ComparisonExpr) -> ir.Expr:
    """Rebuild a comparison from its transformed operands.

    Both operands are transformed together through ``self.transform_exprs``,
    with the comparison itself passed as the second argument. The operator is
    carried over unchanged.
    """
    operands = [comparison.lhs, comparison.rhs]
    new_lhs, new_rhs = self.transform_exprs(operands, comparison)
    return ir.ComparisonExpr(lhs=new_lhs, rhs=new_rhs, op=comparison.op)