Esempio n. 1
0
        def check(func, arg_tys, bit_val):
            # Checks that pruning `func` rewrites the condition feeding its
            # single branch into a constant.
            # func is the python function to lower to IR
            # arg_tys is forwarded unchanged to `dead_branch_prune`
            # bit_val is the value the rewritten condition const must hold
            func_ir = compile_to_ir(func)

            # check there is 1 branch
            before_branches = self.find_branches(func_ir)
            self.assertEqual(len(before_branches), 1)

            # check the branch predicate is a call whose first argument is
            # defined by a binop
            pred_var = before_branches[0].cond
            pred_defn = ir_utils.get_definition(func_ir, pred_var)
            self.assertEqual(pred_defn.op, 'call')
            condition_var = pred_defn.args[0]
            condition_op = ir_utils.get_definition(func_ir, condition_var)
            self.assertEqual(condition_op.op, 'binop')

            # do the prune, this should kill the dead branch and rewrite the
            # condition to a true/false const bit
            if self._DEBUG:
                print("=" * 80)
                print("before prune")
                func_ir.dump()
            dead_branch_prune(func_ir, arg_tys)
            if self._DEBUG:
                print("=" * 80)
                print("after prune")
                func_ir.dump()

            # after mutation, the condition should be a const value `bit_val`;
            # the definition is looked up again because the prune mutates
            # func_ir in place
            new_condition_defn = ir_utils.get_definition(
                func_ir, condition_var)
            # assertIsInstance reports the offending object and expected type
            # on failure, unlike assertTrue(isinstance(...))
            self.assertIsInstance(new_condition_defn, ir.Const)
            self.assertEqual(new_condition_defn.value, bit_val)
Esempio n. 2
0
    def assert_prune(self, func, args_tys, prune, *args, **kwargs):
        # This checks that the expected pruned branches have indeed been pruned.
        # func is a python function to assess
        # args_tys is the numba types arguments tuple
        # prune arg is a list, one entry per branch. The value in the entry is
        # encoded as follows:
        # True: using constant inference only, the True branch will be pruned
        # False: using constant inference only, the False branch will be pruned
        # None: under no circumstances should this branch be pruned
        # 'both': both the True and the False branch are expected to be pruned
        # *args: the argument instances to pass to the function to check
        #        execution is still valid post transform
        # **kwargs:
        #        - flags: compiler.Flags instance to pass to `compile_isolated`,
        #          permits use of e.g. object mode
        #        Only `flags` is read; any other keyword is ignored.

        func_ir = compile_to_ir(func)
        # snapshot taken before any transform so the original branches and
        # block labels can be compared against the pruned IR later
        before = func_ir.copy()
        if self._DEBUG:
            print("=" * 80)
            print("before inline")
            func_ir.dump()

        # run closure inlining to ensure that nonlocals in closures are visible
        inline_pass = InlineClosureCallPass(
            func_ir,
            cpu.ParallelOptions(False),
        )
        inline_pass.run()

        # Remove all Dels, and re-run postproc
        post_proc = postproc.PostProcessor(func_ir)
        post_proc.run()

        rewrite_semantic_constants(func_ir, args_tys)
        if self._DEBUG:
            print("=" * 80)
            print("before prune")
            func_ir.dump()

        dead_branch_prune(func_ir, args_tys)

        # func_ir has been transformed in place; `after` is just a clearer name
        after = func_ir
        if self._DEBUG:
            print("after prune")
            func_ir.dump()

        before_branches = self.find_branches(before)
        self.assertEqual(len(before_branches), len(prune))

        # what is expected to be pruned; the lengths were asserted equal above
        # so zip pairs every branch with its expectation
        expect_removed = []
        for branch, action in zip(before_branches, prune):
            if action is True:
                expect_removed.append(branch.truebr)
            elif action is False:
                expect_removed.append(branch.falsebr)
            elif action is None:
                pass  # nothing should be removed!
            elif action == 'both':
                expect_removed.append(branch.falsebr)
                expect_removed.append(branch.truebr)
            else:
                # self.fail still raises an assertion failure but, unlike a
                # bare `assert`, is not stripped when running under -O
                self.fail("invalid prune entry: %r" % (action,))

        # compare labels
        original_labels = set(before.blocks.keys())
        new_labels = set(after.blocks.keys())
        # assert that the new labels are precisely the original less the
        # expected pruned labels
        try:
            self.assertEqual(new_labels, original_labels - set(expect_removed))
        except AssertionError:
            # dump the label sets to aid debugging, then let the original
            # failure propagate untouched
            print("new_labels", sorted(new_labels))
            print("original_labels", sorted(original_labels))
            print("expect_removed", sorted(expect_removed))
            raise

        supplied_flags = kwargs.pop('flags', False)
        compiler_kws = {'flags': supplied_flags} if supplied_flags else {}
        cres = compile_isolated(func, args_tys, **compiler_kws)
        # `args` is always a tuple (possibly empty), so unpacking it covers
        # the zero-argument case as well
        res = cres.entry_point(*args)
        expected = func(*args)
        self.assertEqual(res, expected)