def check(func, arg_tys, bit_val):
    """Assert that dead-branch pruning folds ``func``'s single branch
    condition into the constant ``bit_val``.

    NOTE(review): ``self`` is not a parameter of this function; it is
    presumably captured from an enclosing test method -- confirm against the
    surrounding code.

    func: Python function compiled to IR via ``compile_to_ir``.
    arg_tys: argument types handed to ``dead_branch_prune``.
    bit_val: the value the branch condition's ``ir.Const`` must hold after
        pruning.
    """
    func_ir = compile_to_ir(func)

    # the function under test must contain exactly one branch ...
    before_branches = self.find_branches(func_ir)
    self.assertEqual(len(before_branches), 1)

    # ... and that branch's condition must be defined by a binop
    condition_var = before_branches[0].cond
    condition_defn = ir_utils.get_definition(func_ir, condition_var)
    self.assertEqual(condition_defn.op, 'binop')

    # do the prune; this should kill the dead branch and rewrite the
    # condition to a true/false const bit
    if self._DEBUG:
        print("=" * 80)
        print("before prune")
        func_ir.dump()

    dead_branch_prune(func_ir, arg_tys)

    if self._DEBUG:
        print("=" * 80)
        print("after prune")
        func_ir.dump()

    # after mutation, the same condition variable should now be defined by a
    # const holding `bit_val`; assertIsInstance reports the actual type on
    # failure, unlike assertTrue(isinstance(...))
    new_condition_defn = ir_utils.get_definition(func_ir, condition_var)
    self.assertIsInstance(new_condition_defn, ir.Const)
    self.assertEqual(new_condition_defn.value, bit_val)
def assert_prune(self, func, args_tys, prune, *args):
    """Check that dead-branch pruning removes exactly the expected branches.

    func: the Python function to assess.
    args_tys: the numba argument types tuple, passed to both
        ``dead_branch_prune`` and ``compile_isolated``.
    prune: a sequence with one entry per branch (in the order returned by
        ``self.find_branches``), encoded as:
          True  -- using constant inference only, the True branch is pruned
          False -- using constant inference only, the False branch is pruned
          None  -- under no circumstances should this branch be pruned
    *args: argument instances passed to the compiled function and to the
        plain Python function, to check execution is still valid
        post-transform.

    Raises ValueError if an entry of ``prune`` is not True/False/None.
    """
    func_ir = compile_to_ir(func)
    # snapshot taken before the prune so the original branches and labels
    # can be compared against the mutated IR afterwards
    before = func_ir.copy()
    if self._DEBUG:
        print("=" * 80)
        print("before prune")
        func_ir.dump()

    dead_branch_prune(func_ir, args_tys)
    after = func_ir

    if self._DEBUG:
        print("=" * 80)
        print("after prune")
        func_ir.dump()

    before_branches = self.find_branches(before)
    self.assertEqual(len(before_branches), len(prune))

    # work out which target labels are expected to be pruned; a distinct loop
    # name is used so the `prune` parameter is not rebound while iterating
    expect_removed = []
    for branch, expectation in zip(before_branches, prune):
        if expectation is True:
            expect_removed.append(branch.truebr)
        elif expectation is False:
            expect_removed.append(branch.falsebr)
        elif expectation is None:
            pass  # nothing should be removed!
        else:
            # explicit raise: a bare `assert` would vanish under `python -O`
            # and the bad entry would be silently ignored
            raise ValueError(
                "prune entries must be True, False or None, got %r"
                % (expectation,))

    # compare labels
    original_labels = set(before.blocks)
    new_labels = set(after.blocks)

    # the new labels must be precisely the original less the expected
    # pruned labels; diagnostics go into the failure message
    self.assertEqual(
        new_labels,
        original_labels - set(expect_removed),
        msg="new_labels=%s original_labels=%s expect_removed=%s" % (
            new_labels, original_labels, expect_removed),
    )

    # the pruned function must still compute the same result as Python
    cres = compile_isolated(func, args_tys)
    res = cres.entry_point(*args)
    expected = func(*args)
    self.assertEqual(res, expected)
def stage_dead_branch_prune(self):
    """
    Prune dead branches from ``self.func_ir``.

    A dead branch is one which is derivable as not taken at compile time,
    purely based on const/literal evaluation. The prune runs inside
    ``self.fallback_context`` with an error message that names the function
    being compiled. When ``config.DEBUG`` or ``config.DUMP_IR`` is set, the
    pruned IR is dumped between banner lines.
    """
    assert self.func_ir
    msg = ('Internal error in pre-inference dead branch pruning '
           'pass encountered during compilation of '
           'function "%s"' % (self.func_id.func_name, ))
    with self.fallback_context(msg):
        dead_branch_prune(self.func_ir, self.args)

    if config.DEBUG or config.DUMP_IR:
        print('branch_pruned_ir'.center(80, '-'))
        # dump() writes the IR itself and returns nothing, so wrapping it in
        # print() only emitted an extra "None" line
        self.func_ir.dump()
        print('end branch_pruned_ir'.center(80, '-'))