def test_single_if_else_w_following_undetermined(self):
    """
    Prune an `x is None` if/else that is followed by a second branch on a
    local flag set inside the first branch.

    Two variants of `impl` are checked: one whose else-arm holds a dead
    assignment and one whose else-arm is a bare `pass`.

    NOTE(review): the list handed to `assert_prune` presumably states, per
    branch, which arm is expected to be pruned (None meaning "not pruned");
    confirm against `assert_prune`.
    """

    def impl(x):
        x_is_none_work = False
        if x is None:
            x_is_none_work = True
        else:
            dead = 7  # noqa: F841 # no effect

        # second branch depends on a local, not directly on the argument
        if x_is_none_work:
            y = 10
        else:
            y = -3
        return y

    self.assert_prune(impl, (types.NoneType('none'),), [False, None], None)
    self.assert_prune(impl, (types.IntegerLiteral(10),), [True, None], 10)

    def impl(x):
        x_is_none_work = False
        if x is None:
            x_is_none_work = True
        else:
            pass  # force the True branch exit to be on backbone

        if x_is_none_work:
            y = 10
        else:
            y = -3
        return y

    self.assert_prune(impl, (types.NoneType('none'),), [None, None], None)
    self.assert_prune(impl, (types.IntegerLiteral(10),), [True, None], 10)
def test_single_if(self):
    """
    Exercise pruning on functions that contain a single `if` with no else.

    Covers literal-vs-literal conditions, `x is None`, and `x == 10`, each
    called with a None-typed and/or integer-literal-typed argument.

    NOTE(review): the list handed to `assert_prune` presumably states, per
    branch, which arm is expected to be pruned (None meaning "not pruned");
    confirm against `assert_prune`.
    """

    # condition built from two constants, always false
    def impl(x):
        if 1 == 0:
            return 3.14159

    self.assert_prune(impl, (types.NoneType('none'),), [True], None)

    # condition built from two constants, always true
    def impl(x):
        if 1 == 1:
            return 3.14159

    self.assert_prune(impl, (types.NoneType('none'),), [False], None)

    # identity comparison of the argument against None
    def impl(x):
        if x is None:
            return 3.14159

    self.assert_prune(impl, (types.NoneType('none'),), [False], None)
    self.assert_prune(impl, (types.IntegerLiteral(10),), [True], 10)

    # equality comparison of the argument against an integer constant
    def impl(x):
        if x == 10:
            return 3.14159

    self.assert_prune(impl, (types.NoneType('none'),), [True], None)
    self.assert_prune(impl, (types.IntegerLiteral(10),), [None], 10)

    # same condition, but the guarded body has no observable effect
    def impl(x):
        if x == 10:
            z = 3.14159  # noqa: F841 # no effect

    self.assert_prune(impl, (types.NoneType('none'),), [True], None)
    self.assert_prune(impl, (types.IntegerLiteral(10),), [None], 10)
def test_redefinition_analysis_different_block_cannot_exec(self):
    """
    Check the handling of an argument (`a`) that is redefined inside a
    block guarded by a condition that itself has prune potential
    (`x is not None`).
    """

    def impl(array, x=None, a=None):
        b = 0
        if x is not None:
            a = 11
        if a is None:
            b += 5
        else:
            b += 7
        return 30 + b

    array_ty = types.Array(types.float64, 2, 'C')
    none_ty = types.NoneType('none')
    float_ty = types.float64

    # (types of x and a, expected prune result, runtime values of x and a)
    scenarios = (
        ((none_ty, none_ty), [True, None], (None, None)),
        ((none_ty, float_ty), [True, None], (None, 1.2)),
        ((float_ty, none_ty), [None, None], (1.2, None)),
    )
    for trailing_tys, expected_prune, trailing_vals in scenarios:
        self.assert_prune(impl, (array_ty,) + trailing_tys, expected_prune,
                          np.zeros((2, 3)), *trailing_vals)
def test_single_if_const_val(self):
    """
    Check pruning of a single `if` comparing the argument with the constant
    100, with the constant on either side of `==`.

    NOTE(review): the list handed to `assert_prune` presumably states, per
    branch, which arm is expected to be pruned (None meaning "not pruned");
    confirm against `assert_prune`.
    """

    def impl(x):
        if x == 100:
            return 3.14159

    self.assert_prune(impl, (types.NoneType('none'), ), [True], None)
    self.assert_prune(impl, (types.IntegerLiteral(100), ), [None], 100)

    def impl(x):
        # switch the condition order
        if 100 == x:
            return 3.14159

    self.assert_prune(impl, (types.NoneType('none'), ), [True], None)
    self.assert_prune(impl, (types.IntegerLiteral(100), ), [None], 100)
def test_single_if_else_two_const_val(self):
    """
    Check pruning of an if/else whose condition compares two arguments,
    for each combination of integer-literal and None-typed inputs.

    NOTE(review): the list handed to `assert_prune` presumably states, per
    branch, which arm is expected to be pruned (None meaning "not pruned");
    confirm against `assert_prune`.
    """

    def impl(x, y):
        if x == y:
            return 3.14159
        else:
            return 1.61803

    # both integer literals
    self.assert_prune(impl, (types.IntegerLiteral(100),) * 2, [None], 100,
                      100)
    # both None
    self.assert_prune(impl, (types.NoneType('none'),) * 2, [False], None,
                      None)
    # one literal, one None
    self.assert_prune(impl, (types.IntegerLiteral(100),
                             types.NoneType('none'),), [True], 100, None)
    # two different integer literals
    self.assert_prune(impl, (types.IntegerLiteral(100),
                             types.IntegerLiteral(1000)), [None], 100, 1000)
def test_single_if_else(self):
    """
    Check pruning of a single `x is None` if/else for a None-typed and an
    integer-literal-typed argument.

    NOTE(review): the list handed to `assert_prune` presumably states, per
    branch, which arm is expected to be pruned; confirm against
    `assert_prune`.
    """

    def impl(x):
        if x is None:
            return 3.14159
        else:
            return 1.61803

    self.assert_prune(impl, (types.NoneType('none'), ), [False], None)
    self.assert_prune(impl, (types.IntegerLiteral(10), ), [True], 10)
def test_redefined_variables_are_not_considered_in_prune(self):
    """
    See issue #4163: an argument that is redefined in the user code must
    not be considered constant.

    The expected result `[None, None]` is given for both branches of
    `impl`, which reassigns `a` inside the `a is None` block.
    NOTE(review): None presumably means "not pruned"; confirm against
    `assert_prune`.
    """

    def impl(array, a=None):
        if a is None:
            a = 0  # redefinition of the argument
        if a < 0:
            return 10
        return 30

    self.assert_prune(impl, (types.Array(types.float64, 2, 'C'),
                             types.NoneType('none'),),
                      [None, None], np.zeros((2, 3)), None)
def test_single_two_branches_same_cond(self):
    """
    Check pruning of two consecutive if/else statements that test the same
    argument with opposite conditions (`is None` then `is not None`).

    NOTE(review): the list handed to `assert_prune` presumably states, per
    branch, which arm is expected to be pruned; confirm against
    `assert_prune`.
    """

    def impl(x):
        if x is None:
            y = 10
        else:
            y = 40

        if x is not None:
            z = 100
        else:
            z = 400

        return z, y

    self.assert_prune(impl, (types.NoneType('none'),), [False, True], None)
    self.assert_prune(impl, (types.IntegerLiteral(10),), [True, False], 10)
def test_double_if_else_rt_const(self):
    """
    Check pruning when an `x is None` if/else is followed by a comparison
    between two locals, one of which is assigned inside the first branch.

    The expected result for the second branch is None in both calls.
    NOTE(review): None presumably means "not pruned"; confirm against
    `assert_prune`.
    """

    def impl(x):
        one_hundred = 100
        x_is_none_work = 4
        if x is None:
            x_is_none_work = 100
        else:
            dead = 7  # noqa: F841 # no effect

        # compares two locals; neither operand is a function argument
        if x_is_none_work == one_hundred:
            y = 10
        else:
            y = -3

        return y, x_is_none_work

    self.assert_prune(impl, (types.NoneType('none'),), [False, None], None)
    self.assert_prune(impl, (types.IntegerLiteral(10),), [True, None], 10)
def test_cond_is_kwarg_none(self):
    """
    Check pruning when the tested argument is a keyword argument whose
    default is None, typed as omitted, as None, or as an integer literal.

    NOTE(review): the list handed to `assert_prune` presumably states, per
    branch, which arm is expected to be pruned; confirm against
    `assert_prune`.
    """

    def impl(x=None):
        if x is None:
            y = 10
        else:
            y = 40

        if x is not None:
            z = 100
        else:
            z = 400

        return z, y

    # default taken (argument omitted)
    self.assert_prune(impl, (types.Omitted(None),),
                      [False, True], None)
    # None passed explicitly
    self.assert_prune(impl, (types.NoneType('none'),), [False, True], None)
    # integer literal passed
    self.assert_prune(impl, (types.IntegerLiteral(10),), [True, False], 10)
def test_cond_is_kwarg_value(self):
    """
    Check pruning when the tested argument is a keyword argument with a
    non-None default (1000) that is compared against that same value.

    Only the None-typed call carries a non-None expectation.
    NOTE(review): None entries presumably mean "not pruned"; confirm
    against `assert_prune`.
    """

    def impl(x=1000):
        if x == 1000:
            y = 10
        else:
            y = 40

        if x != 1000:
            z = 100
        else:
            z = 400

        return z, y

    # default taken (argument omitted)
    self.assert_prune(impl, (types.Omitted(1000),), [None, None], 1000)
    # integer literal equal to the default
    self.assert_prune(impl, (types.IntegerLiteral(1000),), [None, None],
                      1000)
    # integer literal different from the default
    self.assert_prune(impl, (types.IntegerLiteral(0),), [None, None], 0)
    # None passed explicitly
    self.assert_prune(impl, (types.NoneType('none'),), [True, False], None)
def test_redefinition_analysis_same_block(self):
    """
    Check that redefining an argument inside the very block that has prune
    potential does not break the analysis.

    `a` is reassigned inside the `a is None` block; the expected result is
    None for all three branches of `impl`.
    NOTE(review): None presumably means "not pruned"; confirm against
    `assert_prune`.
    """

    def impl(array, x, a=None):
        b = 0
        if x < 4:
            b = 12
        if a is None:
            a = 0  # redefinition in the block guarded by the `a` test
        else:
            b = 12
        if a < 0:
            return 10
        return 30 + b + a

    self.assert_prune(impl, (types.Array(types.float64, 2, 'C'),
                             types.float64, types.NoneType('none'),),
                      [None, None, None], np.zeros((2, 3)), 1., None)
def test_cond_rewrite_is_correct(self):
    """
    Verify that when a branch condition is replaced during pruning, the
    replacement is a true/false bit that correctly represents the
    evaluated condition.
    """

    def fn(x):
        if x is None:
            return 10
        return 12

    separator = "=" * 80

    def check(func, arg_tys, bit_val):
        func_ir = compile_to_ir(func)

        def debug_dump(title):
            # only emits output when the test case has debugging enabled
            if self._DEBUG:
                print(separator)
                print(title)
                func_ir.dump()

        # exactly one branch must exist before pruning
        branches = self.find_branches(func_ir)
        self.assertEqual(len(branches), 1)

        # and its condition must be defined by a binop expression
        cond_var = branches[0].cond
        original_defn = ir_utils.get_definition(func_ir, cond_var)
        self.assertEqual(original_defn.op, 'binop')

        # pruning should kill the dead branch and rewrite the condition
        # into a true/false const bit
        debug_dump("before prune")
        dead_branch_prune(func_ir, arg_tys)
        debug_dump("after prune")

        # after mutation the condition must be the const `bit_val`
        rewritten_defn = ir_utils.get_definition(func_ir, cond_var)
        self.assertTrue(isinstance(rewritten_defn, ir.Const))
        self.assertEqual(rewritten_defn.value, bit_val)

    expectations = (
        (types.NoneType('none'), 1),
        (types.IntegerLiteral(10), 0),
    )
    for arg_ty, expected_bit in expectations:
        check(fn, (arg_ty, ), expected_bit)
def resolve_input_arg_const(input_arg_idx):
    """
    Resolve the call-site type of the argument at position `input_arg_idx`
    to a constant, if possible.

    Returns the NoneType instance for a None argument (passed explicitly or
    via an omitted default), otherwise the argument type's `literal_type`
    when it has one, otherwise an `Unknown` marker instance.
    """
    arg_ty = called_args[input_arg_idx]
    # evaluated eagerly, exactly as a `getattr` default would be
    fallback = Unknown()

    # argument is None at the call site
    if isinstance(arg_ty, types.NoneType):
        return arg_ty

    # argument was omitted, inspect the keyword default instead
    if isinstance(arg_ty, types.Omitted):
        default = arg_ty.value
        if isinstance(default, types.NoneType):
            return default
        if default is None:
            return types.NoneType('none')

    # literal type, return the type itself so comparisons like `x == None`
    # still work as e.g. x = types.int64 will never be None/NoneType so
    # the branch can still be pruned
    return getattr(arg_ty, 'literal_type', fallback)
def test_redefinition_analysis_different_block_can_exec(self):
    """
    Check that a redefinition of an argument inside a block that may be
    executed prevents pruning of later branches on that argument.

    The expected result is None for all four branches of `impl`.
    NOTE(review): None presumably means "not pruned"; confirm against
    `assert_prune`.
    """

    def impl(array, x, a=None):
        b = 0
        if x > 5:
            a = 11  # a redefined, cannot tell statically if this will exec
        if x < 4:
            b = 12
        if a is None:  # cannot prune, cannot determine if re-defn occurred
            b += 5
        else:
            b += 7
            if a < 0:
                return 10
        return 30 + b

    self.assert_prune(impl, (types.Array(types.float64, 2, 'C'),
                             types.float64, types.NoneType('none'),),
                      [None, None, None, None], np.zeros((2, 3)), 1., None)
def test_comparison_operators(self):
    """
    See issue #4163: an argument whose value is None must survive the
    TypeError raised by an invalid comparison located in code that should
    be dead.
    """

    def impl(array, a=None):
        x = 0
        if a is None:
            return 10  # dynamic exec would return here
        # static analysis requires that this is executed with a=None,
        # hence TypeError
        if a < 0:
            return 20
        return x

    array_ty = types.Array(types.float64, 2, 'C')

    # (type of `a`, expected prune result, runtime value of `a`)
    scenarios = (
        (types.NoneType('none'), [False, 'both'], None),
        (types.float64, [None, None], 12.),
    )
    for a_ty, expected_prune, a_val in scenarios:
        self.assert_prune(impl, (array_ty, a_ty,), expected_prune,
                          np.zeros((2, 3)), a_val)
def dead_branch_prune(func_ir, called_args):
    """
    Removes dead branches based on constant inference from function args.
    This directly mutates the IR.

    func_ir is the IR
    called_args are the actual arguments with which the function is called

    A branch is a candidate when its condition is a `binop` expression and
    both operands resolve to constants. An operand resolves either through
    the call-site type of a function argument that has exactly one
    definition, or through `find_const`. A candidate is replaced by an
    unconditional jump, its condition is rewritten to a constant bit
    recording the arm taken, and blocks the CFG reports as dead are
    deleted. Returns None.
    """
    from .ir_utils import get_definition, guard, find_const, GuardException

    DEBUG = 0

    def find_branches(func_ir):
        # find *all* branches whose condition has a resolvable definition
        branches = []
        for blk in func_ir.blocks.values():
            branch_or_jump = blk.body[-1]
            if isinstance(branch_or_jump, ir.Branch):
                branch = branch_or_jump
                condition = guard(get_definition, func_ir, branch.cond.name)
                if condition is not None:
                    branches.append((branch, condition, blk))
        return branches

    def do_prune(branch, take_truebr, blk):
        # `branch` is passed explicitly rather than being read from the
        # enclosing loop through the closure
        keep = branch.truebr if take_truebr else branch.falsebr
        # replace the branch with a direct jump
        jmp = ir.Jump(keep, loc=branch.loc)
        blk.body[-1] = jmp
        return 1 if keep == branch.truebr else 0

    def evaluate_and_prune(branch, condition, blk, lhs_cond, rhs_cond):
        # Shared tail of both pruning strategies: evaluate the condition on
        # the resolved constants and, if that works, rewrite the branch.
        try:
            take_truebr = condition.fn(lhs_cond, rhs_cond)
        except Exception:
            # The comparison itself may be invalid for these operands (e.g.
            # ordering a NoneType against an int raises TypeError). Such a
            # branch cannot be decided statically, so leave it alone; it is
            # removed later if it sits in a block that becomes dead.
            return False, None
        if DEBUG > 0:
            kill = branch.falsebr if take_truebr else branch.truebr
            print("Pruning %s" % kill, branch, lhs_cond, rhs_cond,
                  condition.fn)
        taken = do_prune(branch, take_truebr, blk)
        return True, taken

    def prune_by_type(branch, condition, blk, *conds):
        # this prunes a given branch and fixes up the IR
        # at least one needs to be a NoneType
        lhs_cond, rhs_cond = conds
        lhs_none = isinstance(lhs_cond, types.NoneType)
        rhs_none = isinstance(rhs_cond, types.NoneType)
        if lhs_none or rhs_none:
            return evaluate_and_prune(branch, condition, blk, lhs_cond,
                                      rhs_cond)
        return False, None

    def prune_by_value(branch, condition, blk, *conds):
        lhs_cond, rhs_cond = conds
        return evaluate_and_prune(branch, condition, blk, lhs_cond, rhs_cond)

    class Unknown(object):
        # marker for "operand could not be resolved to a constant"
        pass

    def resolve_input_arg_const(input_arg_idx):
        """
        Resolves an input arg to a constant (if possible)
        """
        input_arg_ty = called_args[input_arg_idx]

        # comparing to None?
        if isinstance(input_arg_ty, types.NoneType):
            return input_arg_ty

        # is it a kwarg default
        if isinstance(input_arg_ty, types.Omitted):
            val = input_arg_ty.value
            if isinstance(val, types.NoneType):
                return val
            elif val is None:
                return types.NoneType('none')

        # literal type, return the type itself so comparisons like `x == None`
        # still work as e.g. x = types.int64 will never be None/NoneType so
        # the branch can still be pruned
        return getattr(input_arg_ty, 'literal_type', Unknown())

    if DEBUG > 1:
        print("before".center(80, '-'))
        # NOTE(review): dump() is called directly here, as the tests do;
        # it is assumed to print the IR itself and return None - confirm.
        func_ir.dump()

    # This looks for branches where:
    # at least one arg of the condition is in input args and const
    # at least one an arg of the condition is a const
    # if the condition is met it will replace the branch with a jump
    branch_info = find_branches(func_ir)
    # stores conditions that have no impact post prune
    nullified_conditions = []
    for branch, condition, blk in branch_info:
        const_conds = []
        if isinstance(condition, ir.Expr) and condition.op == 'binop':
            prune = prune_by_value
            for arg in (condition.lhs, condition.rhs):
                resolved_const = Unknown()
                # Only a variable with a *unique* definition that is a
                # function argument may be resolved from the call-site
                # types. If the user code redefines the argument the
                # definition is no longer unique, `guard` yields None and
                # the operand is (correctly) not treated as a constant.
                arg_def = guard(get_definition, func_ir, arg)
                if isinstance(arg_def, ir.Arg):
                    # it's an e.g. literal argument to the function
                    resolved_const = resolve_input_arg_const(arg_def.index)
                    prune = prune_by_type
                else:
                    # it's some const argument to the function, cannot use
                    # guard here as the const itself may be None
                    try:
                        resolved_const = find_const(func_ir, arg)
                        if resolved_const is None:
                            resolved_const = types.NoneType('none')
                    except GuardException:
                        pass

                if not isinstance(resolved_const, Unknown):
                    const_conds.append(resolved_const)

            # lhs/rhs are consts
            if len(const_conds) == 2:
                # prune the branch, switch the branch for an unconditional
                # jump
                prune_stat, taken = prune(branch, condition, blk,
                                          *const_conds)
                if prune_stat:
                    # add the condition to the list of nullified conditions
                    nullified_conditions.append((condition, taken))

    # 'ERE BE DRAGONS...
    # It is the evaluation of the condition expression that often trips up type
    # inference, so ideally it would be removed as it is effectively rendered
    # dead by the unconditional jump if a branch was pruned. However, there may
    # be references to the condition that exist in multiple places (e.g. dels)
    # and we cannot run DCE here as typing has not taken place to give enough
    # information to run DCE safely. Upshot of all this is the condition gets
    # rewritten below into a benign const that typing will be happy with and DCE
    # can remove it and its reference post typing when it is safe to do so
    # (if desired). It is required that the const is assigned a value that
    # indicates the branch taken as its mutated value would be read in the case
    # of object mode fall back in place of the condition itself. For
    # completeness the func_ir._definitions and ._consts are also updated to
    # make the IR state self consistent.
    deadcond = [x[0] for x in nullified_conditions]
    for _, cond, blk in branch_info:
        if cond in deadcond:
            for x in blk.body:
                if isinstance(x, ir.Assign) and x.value is cond:
                    # rewrite the condition as a true/false bit
                    branch_bit = nullified_conditions[deadcond.index(cond)][1]
                    x.value = ir.Const(branch_bit, loc=x.loc)
                    # update the specific definition to the new const
                    defns = func_ir._definitions[x.target.name]
                    repl_idx = defns.index(cond)
                    defns[repl_idx] = x.value

    # Remove dead blocks, this is safe as it relies on the CFG only.
    cfg = compute_cfg_from_blocks(func_ir.blocks)
    for dead in cfg.dead_nodes():
        del func_ir.blocks[dead]

    # if conditions were nullified then consts were rewritten, update
    if nullified_conditions:
        func_ir._consts = consts.ConstantInference(func_ir)

    if DEBUG > 1:
        print("after".center(80, '-'))
        func_ir.dump()