def transform_unary_minus_expr(self, unary_minus: ir.UnaryMinusExpr) -> ir.Expr:
    """Simplify a unary minus expression.

    The operand is simplified first; then:
      -(3) => -3         (literal folding)
      -(-x) => x         (double negation)
      -(x - y) => y - x
    Otherwise a UnaryMinusExpr over the simplified operand is returned.
    """
    expr = self.transform_expr(unary_minus.inner_expr)

    # -(3) => -3
    if isinstance(expr, ir.Literal):
        assert isinstance(expr.value, int)
        return ir.Literal(-expr.value)

    # -(-x) => x
    # transform_int64_binary_op_expr lowers (x - y) to (x + -y), so (x - (-y))
    # produces a double negation here; cancelling it lets that become (x + y).
    if isinstance(expr, ir.UnaryMinusExpr):
        return expr.inner_expr

    # -(x - y) => y - x
    if isinstance(expr, ir.Int64BinaryOpExpr) and expr.op == '-':
        return ir.Int64BinaryOpExpr(lhs=expr.rhs, rhs=expr.lhs, op='-')

    return ir.UnaryMinusExpr(expr)
def transform_comparison_expr(self, comparison: ir.ComparisonExpr) -> ir.Expr:
    """Simplify a comparison after simplifying both of its operands.

    - Two literal operands are folded into a literal result.
    - For '==' / '!=', syntactically equal operands fold to true / false when
      expr_can_trigger_static_asserts(lhs) is false and both operands may be
      removed (per _can_remove_subexpression).
    - For '==' / '!=' against a bool literal, the comparison reduces to the
      other operand or its (simplified) negation.
    """
    op = comparison.op
    left = self.transform_expr(comparison.lhs)
    right = self.transform_expr(comparison.rhs)

    if isinstance(left, ir.Literal) and isinstance(right, ir.Literal):
        fold = {
            '==': lambda a, b: a == b,
            '!=': lambda a, b: a != b,
            '<': lambda a, b: a < b,
            '<=': lambda a, b: a <= b,
            '>': lambda a, b: a > b,
            '>=': lambda a, b: a >= b,
        }.get(op)
        if fold is not None:
            return ir.Literal(fold(left.value, right.value))

    # Every remaining rewrite applies only to equality comparisons.
    if op not in ('==', '!='):
        return ir.ComparisonExpr(left, right, op)

    # x == x => true, x != x => false
    if (self._is_syntactically_equal(left, right)
            and not expr_can_trigger_static_asserts(left)
            and self._can_remove_subexpression(left)
            and self._can_remove_subexpression(right)):
        return ir.Literal(op == '==')

    # Canonicalize so that a bool literal, if present, ends up on the left.
    if isinstance(right, ir.Literal) and right.expr_type == ir.BoolType():
        left, right = right, left

    if isinstance(left, ir.Literal) and left.expr_type == ir.BoolType():
        # x == true, x != false => x
        # x == false, x != true => !x
        needs_negation = {
            ('==', True): False,
            ('==', False): True,
            ('!=', True): True,
            ('!=', False): False,
        }[(op, left.value)]
        if needs_negation:
            return self.transform_expr(ir.NotExpr(right))
        return right

    return ir.ComparisonExpr(left, right, op)
def transform_bool_binary_op_expr(
        self, binary_op: ir.BoolBinaryOpExpr) -> ir.Expr:
    """Simplify a boolean '&&' / '||' expression after simplifying its operands.

    Both operators are handled uniformly through their identity element
    (true for '&&', false for '||') and absorbing element (false for '&&',
    true for '||'):
      - two literal operands fold into a literal;
      - an identity-literal operand is dropped;
      - an absorbing-literal operand makes the whole expression that literal,
        provided the other operand may be removed (per
        _can_remove_subexpression).
    Any other operator is rebuilt unchanged over the simplified operands.
    """
    op = binary_op.op
    left = self.transform_expr(binary_op.lhs)
    right = self.transform_expr(binary_op.rhs)

    if op not in ('&&', '||'):
        return ir.BoolBinaryOpExpr(left, right, op)

    is_and = op == '&&'
    identity = is_and        # true && x => x;  false || x => x
    absorbing = not is_and   # false && x => false;  true || x => true
    left_is_literal = isinstance(left, ir.Literal)
    right_is_literal = isinstance(right, ir.Literal)

    # true && false => false;  true || false => true
    if left_is_literal and right_is_literal:
        if is_and:
            return ir.Literal(left.value and right.value)
        return ir.Literal(left.value or right.value)

    if left_is_literal and left.value is identity:
        return right
    if right_is_literal and right.value is identity:
        return left

    if (left_is_literal and left.value is absorbing
            and self._can_remove_subexpression(right)):
        return ir.Literal(absorbing)
    if (right_is_literal and right.value is absorbing
            and self._can_remove_subexpression(left)):
        return ir.Literal(absorbing)

    return ir.BoolBinaryOpExpr(left, right, op)
def transform_not_expr(self, not_expr: ir.NotExpr) -> ir.Expr:
    """Simplify a boolean negation.

    The operand is simplified first, then the negation is folded into it
    where possible: bool literals, double negation, De Morgan's laws over
    '&&' / '||', and inversion of comparison operators. Otherwise a NotExpr
    over the simplified operand is returned.
    """
    expr = self.transform_expr(not_expr.inner_expr)

    # !true => false
    # !false => true
    if isinstance(expr, ir.Literal):
        assert isinstance(expr.value, bool)
        return ir.Literal(not expr.value)

    # !!x => x
    if isinstance(expr, ir.NotExpr):
        return expr.inner_expr

    # !(x && y) => (!x || !y)
    # !(x || y) => (!x && !y)
    if isinstance(expr, ir.BoolBinaryOpExpr):
        op = {
            '&&': '||',
            '||': '&&',
        }[expr.op]
        return self.transform_expr(
            ir.BoolBinaryOpExpr(lhs=ir.NotExpr(expr.lhs),
                                rhs=ir.NotExpr(expr.rhs),
                                op=op))

    # !(x == y) => x != y
    # !(x != y) => x == y
    # !(x < y) => x >= y
    # !(x <= y) => x > y
    # !(x > y) => x <= y
    # !(x >= y) => x < y
    # The ordering rewrites are exact only for totally ordered operands (no
    # NaN-like values), which holds for int64 comparisons.
    # NOTE(review): confirm no other operand types reach ordering comparisons.
    negated_comparison_op = {
        '==': '!=',
        '!=': '==',
        '<': '>=',
        '<=': '>',
        '>': '<=',
        '>=': '<',
    }
    # Guard on the table itself (previously restricted to '=='/'!='), so every
    # documented rewrite is applied; unknown operators fall through unchanged.
    if isinstance(expr, ir.ComparisonExpr) and expr.op in negated_comparison_op:
        return ir.ComparisonExpr(expr.lhs, expr.rhs,
                                 negated_comparison_op[expr.op])

    return ir.NotExpr(expr)
def transform_int64_binary_op_expr(
        self, binary_op: ir.Int64BinaryOpExpr) -> ir.Expr:
    """Simplify an int64 binary operation, folding constants where possible.

    Subtraction is first rewritten as addition of a negated operand so the
    negation can be simplified while the operands are transformed; if the
    negation survives, the expression is switched back to a subtraction.

    Literal '/' and '%' are folded with C++ integer semantics (truncation
    toward zero; the remainder takes the sign of the dividend) rather than
    Python's floor semantics, and are left unfolded when the divisor is 0.
    Operands are only dropped (x * 0, x % 1) when _can_remove_subexpression
    allows it.
    """
    lhs = binary_op.lhs
    rhs = binary_op.rhs
    op = binary_op.op

    # (x - y) => (x + -y)
    # This pushes down the minus, so that e.g. (x - (-y)) => (x + y).
    if op == '-':
        rhs = ir.UnaryMinusExpr(rhs)
        op = '+'

    lhs = self.transform_expr(lhs)
    rhs = self.transform_expr(rhs)

    if op == '+' and isinstance(rhs, ir.UnaryMinusExpr):
        # We could not push down the minus, so switch back to a subtraction.
        op = '-'
        rhs = rhs.inner_expr

    if op == '+':
        # 3 + 5 => 8
        if isinstance(lhs, ir.Literal) and isinstance(rhs, ir.Literal):
            return ir.Literal(lhs.value + rhs.value)
        # 0 + x => x
        if isinstance(lhs, ir.Literal) and lhs.value == 0:
            return rhs
        # x + 0 => x
        if isinstance(rhs, ir.Literal) and rhs.value == 0:
            return lhs

    if op == '-':
        # 8 - 5 => 3
        if isinstance(lhs, ir.Literal) and isinstance(rhs, ir.Literal):
            return ir.Literal(lhs.value - rhs.value)
        # 0 - x => -x
        if isinstance(lhs, ir.Literal) and lhs.value == 0:
            return ir.UnaryMinusExpr(rhs)
        # x - 0 => x
        if isinstance(rhs, ir.Literal) and rhs.value == 0:
            return lhs

    if op == '*':
        # 3 * 5 => 15
        if isinstance(lhs, ir.Literal) and isinstance(rhs, ir.Literal):
            return ir.Literal(lhs.value * rhs.value)
        # 0 * x => 0
        if isinstance(lhs, ir.Literal) and lhs.value == 0:
            if self._can_remove_subexpression(rhs):
                return ir.Literal(0)
        # x * 0 => 0
        if isinstance(rhs, ir.Literal) and rhs.value == 0:
            if self._can_remove_subexpression(lhs):
                return ir.Literal(0)
        # 1 * x => x
        if isinstance(lhs, ir.Literal) and lhs.value == 1:
            return rhs
        # x * 1 => x
        if isinstance(rhs, ir.Literal) and rhs.value == 1:
            return lhs

    if op == '/':
        # 16 / 3 => 5, -16 / 3 => -5
        # Division by a literal 0 is left unfolded instead of raising
        # ZeroDivisionError inside the optimizer.
        if (isinstance(lhs, ir.Literal) and isinstance(rhs, ir.Literal)
                and rhs.value != 0):
            # C++ integer division truncates toward zero, whereas Python's //
            # rounds toward negative infinity (-16 // 3 == -6). Divide the
            # magnitudes, then restore the sign.
            quotient = abs(lhs.value) // abs(rhs.value)
            if (lhs.value < 0) != (rhs.value < 0):
                quotient = -quotient
            return ir.Literal(quotient)
        # x / 1 => x
        if isinstance(rhs, ir.Literal) and rhs.value == 1:
            return lhs

    if op == '%':
        # 16 % 3 => 1, -16 % 3 => -1
        if (isinstance(lhs, ir.Literal) and isinstance(rhs, ir.Literal)
                and rhs.value != 0):
            # In C++ the remainder takes the sign of the dividend, whereas
            # Python's % takes the sign of the divisor (-16 % 3 == 2).
            remainder = abs(lhs.value) % abs(rhs.value)
            if lhs.value < 0:
                remainder = -remainder
            return ir.Literal(remainder)
        # x % 1 => 0
        # Like x * 0 above, only drop x when it is safe to remove.
        if isinstance(rhs, ir.Literal) and rhs.value == 1:
            if self._can_remove_subexpression(lhs):
                return ir.Literal(0)

    return ir.Int64BinaryOpExpr(lhs, rhs, op)
result_element_names=['value'], args=[ ir.TemplateArgDecl( name='T', expr_type=ir.TypeType(), is_variadic=False), ir.TemplateArgDecl( name='U', expr_type=ir.TypeType(), is_variadic=False) ], main_definition=ir.TemplateSpecialization( args=[ ir.TemplateArgDecl( name='T', expr_type=ir.TypeType(), is_variadic=False), ir.TemplateArgDecl( name='U', expr_type=ir.TypeType(), is_variadic=False) ], patterns=None, body=[ir.ConstantDef(name='value', expr=ir.Literal(False))], is_metafunction=True), specializations=[ ir.TemplateSpecialization( args=[ ir.TemplateArgDecl( name='T', expr_type=ir.TypeType(), is_variadic=False) ], patterns=[ ir.AtomicTypeLiteral.for_local(cpp_type='T', expr_type=ir.TypeType(), is_variadic=False), ir.AtomicTypeLiteral.for_local(cpp_type='T', expr_type=ir.TypeType(), is_variadic=False) ],