def check(func, arg_tys, bit_val):
    """Check that pruning rewrites a single branch condition to a const bit.

    ``func`` is compiled to IR, which must contain exactly one branch. That
    branch's predicate must be defined by a call whose first argument is
    defined by a binop. After ``dead_branch_prune(func_ir, arg_tys)``, the
    variable that held the binop result must be defined by an ``ir.Const``
    whose value equals ``bit_val``.

    Relies on ``self`` (the enclosing test case, providing ``find_branches``,
    ``_DEBUG`` and the ``assert*`` helpers) from the enclosing scope.
    """
    func_ir = compile_to_ir(func)

    # check there is 1 branch
    before_branches = self.find_branches(func_ir)
    self.assertEqual(len(before_branches), 1)

    # check the branch predicate is defined by a call whose first argument
    # is itself defined by a binop
    pred_var = before_branches[0].cond
    pred_defn = ir_utils.get_definition(func_ir, pred_var)
    self.assertEqual(pred_defn.op, 'call')
    condition_var = pred_defn.args[0]
    condition_op = ir_utils.get_definition(func_ir, condition_var)
    self.assertEqual(condition_op.op, 'binop')

    # do the prune, this should kill the dead branch and rewrite the
    # condition to a true/false const bit
    if self._DEBUG:
        print("=" * 80)
        print("before prune")
        func_ir.dump()
    dead_branch_prune(func_ir, arg_tys)
    if self._DEBUG:
        print("=" * 80)
        print("after prune")
        func_ir.dump()

    # after mutation, the condition should be a const value `bit_val`
    new_condition_defn = ir_utils.get_definition(func_ir, condition_var)
    # assertIsInstance reports the actual type on failure, whereas
    # assertTrue(isinstance(...)) only reports "False is not true"
    self.assertIsInstance(new_condition_defn, ir.Const)
    self.assertEqual(new_condition_defn.value, bit_val)
def assert_prune(self, func, args_tys, prune, *args, **kwargs):
    """Check that dead-branch pruning removes exactly the expected branches.

    The IR for ``func`` is compiled, closure-inlined, post-processed,
    semantic-constant rewritten and then pruned with
    ``dead_branch_prune(func_ir, args_tys)``. The surviving block labels
    must equal the original labels minus those named by ``prune``. Finally,
    ``func`` is compiled with ``compile_isolated`` and its result is
    compared against the pure-Python result.

    Parameters
    ----------
    func:
        The Python function to assess.
    args_tys:
        The numba types arguments tuple.
    prune:
        A sequence with one entry per branch (in ``find_branches`` order),
        encoded as:

        - ``True``: using constant inference only, the True branch will be
          pruned.
        - ``False``: using constant inference only, the False branch will be
          pruned.
        - ``None``: under no circumstances should this branch be pruned.
        - ``'both'``: both the True and the False branch are expected to be
          pruned.
    *args:
        The argument instances to pass to the function to check that
        execution is still valid post transform. With no extra positional
        arguments, the function is called with no arguments.
    **kwargs:
        ``flags``: a ``compiler.Flags`` instance to pass to
        ``compile_isolated``. This permits use of e.g. object mode.

    The test fails if the branch count does not match ``len(prune)``, if the
    surviving labels differ from the expectation, or if the compiled result
    differs from the interpreted one. An unrecognised ``prune`` entry also
    raises ``AssertionError``.
    """
    func_ir = compile_to_ir(func)
    before = func_ir.copy()
    if self._DEBUG:
        print("=" * 80)
        print("before inline")
        func_ir.dump()

    # run closure inlining to ensure that nonlocals in closures are visible
    inline_pass = InlineClosureCallPass(
        func_ir,
        cpu.ParallelOptions(False),
    )
    inline_pass.run()

    # Remove all Dels, and re-run postproc
    post_proc = postproc.PostProcessor(func_ir)
    post_proc.run()

    rewrite_semantic_constants(func_ir, args_tys)
    if self._DEBUG:
        print("=" * 80)
        print("before prune")
        func_ir.dump()

    dead_branch_prune(func_ir, args_tys)

    after = func_ir
    if self._DEBUG:
        print("after prune")
        func_ir.dump()

    before_branches = self.find_branches(before)
    # lengths are checked first, so the zip below cannot silently truncate
    self.assertEqual(len(before_branches), len(prune))

    # what is expected to be pruned
    expect_removed = []
    for branch, spec in zip(before_branches, prune):
        if spec is True:
            expect_removed.append(branch.truebr)
        elif spec is False:
            expect_removed.append(branch.falsebr)
        elif spec is None:
            pass  # nothing should be removed!
        elif spec == 'both':
            expect_removed.append(branch.falsebr)
            expect_removed.append(branch.truebr)
        else:
            # explicit raise: a bare `assert` is stripped under `python -O`,
            # which would silently accept a malformed prune spec
            raise AssertionError(
                "unreachable: unknown prune spec %r" % (spec,))

    # compare labels
    original_labels = set(before.blocks.keys())
    new_labels = set(after.blocks.keys())
    # assert that the new labels are precisely the original less the
    # expected pruned labels
    try:
        self.assertEqual(new_labels, original_labels - set(expect_removed))
    except AssertionError:
        print("new_labels", sorted(new_labels))
        print("original_labels", sorted(original_labels))
        print("expect_removed", sorted(expect_removed))
        raise

    supplied_flags = kwargs.pop('flags', False)
    compiler_kws = {'flags': supplied_flags} if supplied_flags else {}
    cres = compile_isolated(func, args_tys, **compiler_kws)
    # `args` is collected by *args, so it is always a tuple (never None);
    # an empty tuple unpacks to a plain no-argument call.
    res = cres.entry_point(*args)
    expected = func(*args)
    self.assertEqual(res, expected)