def test_single_if_else_w_following_undetermined(self):
        """The `x is None` branch is prunable; a following branch on a flag
        assigned in both arms of the first if/else is expected to stay
        undetermined (second expected entry is always None)."""
        # NOTE(review): the expected lists appear to record, per branch in
        # order, which side got pruned (True/False) or None for no change;
        # exact semantics live in assert_prune (not visible here).

        def impl(x):
            x_is_none_work = False
            if x is None:
                x_is_none_work = True
            else:
                dead = 7  # noqa: F841 # no effect

            if x_is_none_work:
                y = 10
            else:
                y = -3
            return y

        self.assert_prune(impl, (types.NoneType('none'),), [False, None], None)
        self.assert_prune(impl, (types.IntegerLiteral(10),), [True, None], 10)

        # Same as above but with an empty `else`; the expected outcome for
        # the None-typed call changes from [False, None] to [None, None].
        def impl(x):
            x_is_none_work = False
            if x is None:
                x_is_none_work = True
            else:
                pass  # force the True branch exit to be on backbone

            if x_is_none_work:
                y = 10
            else:
                y = -3
            return y

        self.assert_prune(impl, (types.NoneType('none'),), [None, None], None)
        self.assert_prune(impl, (types.IntegerLiteral(10),), [True, None], 10)
    def test_single_if(self):
        """Single `if` without `else`: conditions on pure constants and on
        the argument `x`, with a None-typed and an integer-literal argument."""

        # constant-false condition
        def impl(x):
            if 1 == 0:
                return 3.14159

        self.assert_prune(impl, (types.NoneType('none'),), [True], None)

        # constant-true condition
        def impl(x):
            if 1 == 1:
                return 3.14159

        self.assert_prune(impl, (types.NoneType('none'),), [False], None)

        # identity check against None on the argument
        def impl(x):
            if x is None:
                return 3.14159

        self.assert_prune(impl, (types.NoneType('none'),), [False], None)
        self.assert_prune(impl, (types.IntegerLiteral(10),), [True], 10)

        # value comparison on the argument: only expected to prune when the
        # argument is None-typed
        def impl(x):
            if x == 10:
                return 3.14159

        self.assert_prune(impl, (types.NoneType('none'),), [True], None)
        self.assert_prune(impl, (types.IntegerLiteral(10),), [None], 10)

        # as above, but the body has no return
        def impl(x):
            if x == 10:
                z = 3.14159  # noqa: F841 # no effect

        self.assert_prune(impl, (types.NoneType('none'),), [True], None)
        self.assert_prune(impl, (types.IntegerLiteral(10),), [None], 10)
    def test_redefinition_analysis_different_block_cannot_exec(self):
        # checks that a redefinition in a block guarded by something that
        # has prune potential (here `if x is not None`) is still respected:
        # `a` may be reassigned there, so the later `if a is None` is expected
        # to stay unpruned (second entry always None) even in the calls where
        # the redefining block itself is pruned away.

        def impl(array, x=None, a=None):
            b = 0
            if x is not None:
                a = 11
            if a is None:
                b += 5
            else:
                b += 7
            return 30 + b

        # x is None: the redefining block is pruned
        self.assert_prune(impl,
                          (types.Array(types.float64, 2, 'C'),
                           types.NoneType('none'), types.NoneType('none')),
                          [True, None], np.zeros((2, 3)), None, None)

        self.assert_prune(impl, (types.Array(
            types.float64, 2, 'C'), types.NoneType('none'), types.float64),
                          [True, None], np.zeros((2, 3)), None, 1.2)

        # x is a float: nothing can be decided statically
        self.assert_prune(impl, (types.Array(
            types.float64, 2, 'C'), types.float64, types.NoneType('none')),
                          [None, None], np.zeros((2, 3)), 1.2, None)
    def test_single_if_const_val(self):
        """Argument compared to an integer constant, with the constant on
        either side of `==`; only the None-typed call is expected to prune."""
        def impl(x):
            if x == 100:
                return 3.14159

        self.assert_prune(impl, (types.NoneType('none'), ), [True], None)
        self.assert_prune(impl, (types.IntegerLiteral(100), ), [None], 100)

        def impl(x):
            # switch the condition order
            if 100 == x:
                return 3.14159

        self.assert_prune(impl, (types.NoneType('none'), ), [True], None)
        self.assert_prune(impl, (types.IntegerLiteral(100), ), [None], 100)
    def test_single_if_else_two_const_val(self):
        """Two arguments compared with `==`: expected to prune only when at
        least one of them is None-typed."""

        def impl(x, y):
            if x == y:
                return 3.14159
            else:
                return 1.61803

        # both int literals: not pruned
        self.assert_prune(impl, (types.IntegerLiteral(100),) * 2, [None], 100,
                          100)
        # both None: condition is true
        self.assert_prune(impl, (types.NoneType('none'),) * 2, [False], None,
                          None)
        # int vs None: condition is false
        self.assert_prune(impl, (types.IntegerLiteral(100),
                                 types.NoneType('none'),), [True], 100, None)
        # two different int literals: not pruned
        self.assert_prune(impl, (types.IntegerLiteral(100),
                                 types.IntegerLiteral(1000)), [None], 100, 1000)
    def test_single_if_else(self):
        """`if x is None` / `else` is pruned one way for a None-typed argument
        and the other way for an integer literal."""
        def impl(x):
            if x is None:
                return 3.14159
            else:
                return 1.61803

        self.assert_prune(impl, (types.NoneType('none'), ), [False], None)
        self.assert_prune(impl, (types.IntegerLiteral(10), ), [True], 10)
    def test_redefined_variables_are_not_considered_in_prune(self):
        # see issue #4163, checks that if a variable that is an argument is
        # redefined in the user code it is not considered const; here `a` is
        # reassigned inside the first `if`, so neither branch is expected to
        # be pruned.

        def impl(array, a=None):
            if a is None:
                a = 0
            if a < 0:
                return 10
            return 30

        self.assert_prune(impl, (
            types.Array(types.float64, 2, 'C'),
            types.NoneType('none'),
        ), [None, None], np.zeros((2, 3)), None)
    def test_single_two_branches_same_cond(self):
        """Two independent branches on the same argument (`is None` and
        `is not None`) are expected to be pruned in opposite directions."""

        def impl(x):
            if x is None:
                y = 10
            else:
                y = 40

            if x is not None:
                z = 100
            else:
                z = 400

            return z, y

        self.assert_prune(impl, (types.NoneType('none'),), [False, True], None)
        self.assert_prune(impl, (types.IntegerLiteral(10),), [True, False], 10)
    def test_double_if_else_rt_const(self):
        """The first branch (`x is None`) is prunable. The second compares a
        variable assigned in more than one place against a local constant,
        and is expected to stay unpruned."""

        def impl(x):
            one_hundred = 100
            x_is_none_work = 4
            if x is None:
                x_is_none_work = 100
            else:
                dead = 7  # noqa: F841 # no effect

            if x_is_none_work == one_hundred:
                y = 10
            else:
                y = -3

            return y, x_is_none_work

        self.assert_prune(impl, (types.NoneType('none'),), [False, None], None)
        self.assert_prune(impl, (types.IntegerLiteral(10),), [True, None], 10)
    def test_cond_is_kwarg_none(self):
        """Keyword argument defaulting to None: an omitted argument
        (`types.Omitted(None)`) is expected to prune exactly like an explicit
        None."""

        def impl(x=None):
            if x is None:
                y = 10
            else:
                y = 40

            if x is not None:
                z = 100
            else:
                z = 400

            return z, y

        self.assert_prune(impl, (types.Omitted(None),),
                          [False, True], None)
        self.assert_prune(impl, (types.NoneType('none'),), [False, True], None)
        self.assert_prune(impl, (types.IntegerLiteral(10),), [True, False], 10)
    def test_cond_is_kwarg_value(self):
        """Keyword argument with a non-None default compared by value: only
        the None-typed call is expected to prune; omitted and integer-literal
        arguments leave both branches in place."""

        def impl(x=1000):
            if x == 1000:
                y = 10
            else:
                y = 40

            if x != 1000:
                z = 100
            else:
                z = 400

            return z, y

        self.assert_prune(impl, (types.Omitted(1000),), [None, None], 1000)
        self.assert_prune(impl, (types.IntegerLiteral(1000),), [None, None],
                          1000)
        self.assert_prune(impl, (types.IntegerLiteral(0),), [None, None], 0)
        self.assert_prune(impl, (types.NoneType('none'),), [True, False], None)
# Example #12
# 0
    def test_redefinition_analysis_same_block(self):
        # checks that a redefinition in a block with prunable potential doesn't
        # break: `a` is reassigned in the `if a is None` arm and `b` in several
        # places, and no branch is expected to be pruned.

        def impl(array, x, a=None):
            b = 0
            if x < 4:
                b = 12
            if a is None:
                a = 0
            else:
                b = 12
            if a < 0:
                return 10
            return 30 + b + a

        self.assert_prune(impl, (
            types.Array(types.float64, 2, 'C'),
            types.float64,
            types.NoneType('none'),
        ), [None, None, None], np.zeros((2, 3)), 1., None)
# Example #13
# 0
    def test_cond_rewrite_is_correct(self):
        # this checks that when a condition is replaced, it is replaced by a
        # true/false bit that correctly represents the evaluated condition
        def fn(x):
            if x is None:
                return 10
            return 12

        # Compile `func`, prune it for `arg_tys` and assert that the branch
        # condition variable is afterwards defined as ir.Const(bit_val).
        def check(func, arg_tys, bit_val):
            func_ir = compile_to_ir(func)

            # check there is 1 branch
            before_branches = self.find_branches(func_ir)
            self.assertEqual(len(before_branches), 1)

            # check the condition in the branch is a binop
            condition_var = before_branches[0].cond
            condition_defn = ir_utils.get_definition(func_ir, condition_var)
            self.assertEqual(condition_defn.op, 'binop')

            # do the prune, this should kill the dead branch and rewrite the
            # condition to a true/false const bit
            if self._DEBUG:
                print("=" * 80)
                print("before prune")
                func_ir.dump()
            dead_branch_prune(func_ir, arg_tys)
            if self._DEBUG:
                print("=" * 80)
                print("after prune")
                func_ir.dump()

            # after mutation, the condition should be a const value `bit_val`
            new_condition_defn = ir_utils.get_definition(
                func_ir, condition_var)
            self.assertTrue(isinstance(new_condition_defn, ir.Const))
            self.assertEqual(new_condition_defn.value, bit_val)

        # `x is None` is true for a None-typed arg (bit 1), false otherwise
        check(fn, (types.NoneType('none'), ), 1)
        check(fn, (types.IntegerLiteral(10), ), 0)
# Example #14
# 0
    def resolve_input_arg_const(input_arg_idx):
        """
        Map the type of the input argument at position ``input_arg_idx`` to a
        value usable for static branch pruning.

        Returns the NoneType for a None-valued argument (including an omitted
        kwarg whose default is None). Otherwise returns the argument type's
        ``literal_type`` attribute when it has one, else an ``Unknown()``
        sentinel.
        """
        arg_ty = called_args[input_arg_idx]

        # explicit None argument
        if isinstance(arg_ty, types.NoneType):
            return arg_ty

        # omitted kwarg: resolve through its default value
        if isinstance(arg_ty, types.Omitted):
            default = arg_ty.value
            if isinstance(default, types.NoneType):
                return default
            if default is None:
                return types.NoneType('none')

        # Literal types resolve to their literal_type, so comparisons such as
        # `x == None` still work: e.g. types.int64 can never be None/NoneType,
        # hence the branch can still be pruned. Anything else is Unknown.
        return getattr(arg_ty, 'literal_type', Unknown())
# Example #15
# 0
    def test_redefinition_analysis_different_block_can_exec(self):
        # checks that a redefinition in a block that may be executed prevents
        # pruning: all four branches are expected to stay in place

        def impl(array, x, a=None):
            b = 0
            if x > 5:
                a = 11  # a redefined, cannot tell statically if this will exec
            if x < 4:
                b = 12
            if a is None:  # cannot prune, cannot determine if re-defn occurred
                b += 5
            else:
                b += 7
                if a < 0:
                    return 10
            return 30 + b

        self.assert_prune(impl, (
            types.Array(types.float64, 2, 'C'),
            types.float64,
            types.NoneType('none'),
        ), [None, None, None, None], np.zeros((2, 3)), 1., None)
# Example #16
# 0
    def test_comparison_operators(self):
        # see issue #4163, checks that a variable that is an argument and has
        # value None survives TypeError from invalid comparison which should be
        # dead

        def impl(array, a=None):
            x = 0
            if a is None:
                return 10  # dynamic exec would return here
            # static analysis requires that this is executed with a=None,
            # hence TypeError
            if a < 0:
                return 20
            return x

        # a is None: first branch pruned; the `a < 0` branch is expected to be
        # reported as 'both' (presumably: its whole block became dead)
        self.assert_prune(impl, (
            types.Array(types.float64, 2, 'C'),
            types.NoneType('none'),
        ), [False, 'both'], np.zeros((2, 3)), None)

        # a is a float: nothing can be decided statically
        self.assert_prune(impl, (
            types.Array(types.float64, 2, 'C'),
            types.float64,
        ), [None, None], np.zeros((2, 3)), 12.)
def dead_branch_prune(func_ir, called_args):
    """
    Removes dead branches based on constant inference from function args.
    This directly mutates the IR.

    func_ir is the IR
    called_args are the actual arguments with which the function is called;
    they are indexed by the position recorded on the ``ir.Arg`` nodes

    For each block terminated by an ``ir.Branch`` whose condition is a
    ``binop`` with both operands statically resolvable:

    - the branch is replaced by an unconditional ``ir.Jump`` to the target
      selected by the condition;
    - the assignment of the condition is rewritten to an ``ir.Const`` branch
      bit (1 if the true branch is taken, else 0).

    Blocks made unreachable are then deleted and ``func_ir._consts`` is
    recomputed.

    Operands are resolved as follows:

    - a variable whose unique definition is an ``ir.Arg`` resolves through
      ``called_args``; such a branch is only pruned if one side resolves to
      ``types.NoneType``;
    - any other variable resolves via ``find_const``, with a ``None``
      constant mapped to ``types.NoneType('none')``.

    A variable without a unique definition (e.g. an argument reassigned in
    the function body, issue #4163) is not resolved, so its branches are
    left alone. A branch is also left alone if evaluating its condition on
    the resolved operands raises (e.g. ``None < 0``).
    """
    from .ir_utils import get_definition, guard, find_const, GuardException

    DEBUG = 0

    def find_branches(func_ir):
        # find *all* branches as (branch, condition definition, block)
        # triples; branches whose condition has no unique definition are
        # skipped
        branches = []
        for blk in func_ir.blocks.values():
            branch_or_jump = blk.body[-1]
            if isinstance(branch_or_jump, ir.Branch):
                branch = branch_or_jump
                condition = guard(get_definition, func_ir, branch.cond.name)
                if condition is not None:
                    branches.append((branch, condition, blk))
        return branches

    def do_prune(branch, take_truebr, blk):
        # Replace the branch terminating `blk` with a direct jump to the kept
        # target and return the branch bit (1: true branch taken, 0: false).
        # The bit comes from the evaluated condition rather than from
        # comparing targets, so it stays correct when truebr == falsebr.
        keep = branch.truebr if take_truebr else branch.falsebr
        blk.body[-1] = ir.Jump(keep, loc=branch.loc)
        return 1 if take_truebr else 0

    def evaluate_and_prune(branch, condition, blk, lhs_cond, rhs_cond):
        # Statically evaluate the condition and prune accordingly.
        # Returns (pruned?, branch bit or None).
        try:
            # The comparison may not be statically evaluable (e.g. `None < 0`
            # raises TypeError); such branches are left for later stages.
            take_truebr = bool(condition.fn(lhs_cond, rhs_cond))
        except Exception:
            return False, None
        if DEBUG > 0:
            kill = branch.falsebr if take_truebr else branch.truebr
            print("Pruning %s" % kill, branch, lhs_cond, rhs_cond,
                  condition.fn)
        return True, do_prune(branch, take_truebr, blk)

    def prune_by_type(branch, condition, blk, *conds):
        # this prunes a given branch and fixes up the IR
        # at least one needs to be a NoneType
        lhs_cond, rhs_cond = conds
        lhs_none = isinstance(lhs_cond, types.NoneType)
        rhs_none = isinstance(rhs_cond, types.NoneType)
        if lhs_none or rhs_none:
            return evaluate_and_prune(branch, condition, blk, lhs_cond,
                                      rhs_cond)
        return False, None

    def prune_by_value(branch, condition, blk, *conds):
        # both operands are compile-time constants: prune on their value
        lhs_cond, rhs_cond = conds
        return evaluate_and_prune(branch, condition, blk, lhs_cond, rhs_cond)

    class Unknown(object):
        # sentinel: operand could not be resolved to a constant
        pass

    def resolve_input_arg_const(input_arg_idx):
        """
        Resolves the input arg at position ``input_arg_idx`` to a constant
        (if possible), otherwise returns an ``Unknown`` instance.
        """
        input_arg_ty = called_args[input_arg_idx]

        # comparing to None?
        if isinstance(input_arg_ty, types.NoneType):
            return input_arg_ty

        # is it a kwarg default
        if isinstance(input_arg_ty, types.Omitted):
            val = input_arg_ty.value
            if isinstance(val, types.NoneType):
                return val
            elif val is None:
                return types.NoneType('none')

        # literal type, return its literal_type so comparisons like `x == None`
        # still work as e.g. x = types.int64 will never be None/NoneType so
        # the branch can still be pruned
        return getattr(input_arg_ty, 'literal_type', Unknown())

    if DEBUG > 1:
        print("before".center(80, '-'))
        func_ir.dump()

    # This looks for branches where:
    # at least one arg of the condition is in input args and const
    # at least one an arg of the condition is a const
    # if the condition is met it will replace the branch with a jump
    branch_info = find_branches(func_ir)
    # (condition, branch bit) for every pruned branch; these conditions have
    # no impact on control flow post prune
    nullified_conditions = []
    for branch, condition, blk in branch_info:
        if not (isinstance(condition, ir.Expr) and condition.op == 'binop'):
            continue
        const_conds = []
        prune = prune_by_value
        for arg in [condition.lhs, condition.rhs]:
            resolved_const = Unknown()
            # guard(get_definition, ...) gives None unless `arg` has exactly
            # one definition, so an argument that is reassigned in the body
            # is not mistaken for the incoming value (issue #4163). Matching
            # on the name alone (`arg.name in func_ir.arg_names`) did that.
            arg_def = guard(get_definition, func_ir, arg)
            if isinstance(arg_def, ir.Arg):
                # it's an e.g. literal argument to the function
                resolved_const = resolve_input_arg_const(arg_def.index)
                prune = prune_by_type
            else:
                # it's some const argument to the function, cannot use guard
                # here as the const itself may be None
                try:
                    resolved_const = find_const(func_ir, arg)
                    if resolved_const is None:
                        resolved_const = types.NoneType('none')
                except GuardException:
                    pass

            if not isinstance(resolved_const, Unknown):
                const_conds.append(resolved_const)

        # lhs/rhs are consts
        if len(const_conds) == 2:
            # prune the branch, switch the branch for an unconditional jump
            prune_stat, taken = prune(branch, condition, blk, *const_conds)
            if prune_stat:
                # add the condition to the list of nullified conditions
                nullified_conditions.append((condition, taken))

    # 'ERE BE DRAGONS...
    # It is the evaluation of the condition expression that often trips up type
    # inference, so ideally it would be removed as it is effectively rendered
    # dead by the unconditional jump if a branch was pruned. However, there may
    # be references to the condition that exist in multiple places (e.g. dels)
    # and we cannot run DCE here as typing has not taken place to give enough
    # information to run DCE safely. Upshot of all this is the condition gets
    # rewritten below into a benign const that typing will be happy with and DCE
    # can remove it and its reference post typing when it is safe to do so
    # (if desired). It is required that the const is assigned a value that
    # indicates the branch taken as its mutated value would be read in the case
    # of object mode fall back in place of the condition itself. For
    # completeness the func_ir._definitions and ._consts are also updated to
    # make the IR state self consistent.

    deadcond = [x[0] for x in nullified_conditions]
    for _, cond, blk in branch_info:
        if cond in deadcond:
            for x in blk.body:
                if isinstance(x, ir.Assign) and x.value is cond:
                    # rewrite the condition as a true/false bit
                    branch_bit = nullified_conditions[deadcond.index(cond)][1]
                    x.value = ir.Const(branch_bit, loc=x.loc)
                    # update the specific definition to the new const
                    defns = func_ir._definitions[x.target.name]
                    repl_idx = defns.index(cond)
                    defns[repl_idx] = x.value

    # Remove dead blocks, this is safe as it relies on the CFG only.
    cfg = compute_cfg_from_blocks(func_ir.blocks)
    for dead in cfg.dead_nodes():
        del func_ir.blocks[dead]

    # if conditions were nullified then consts were rewritten, update
    if nullified_conditions:
        func_ir._consts = consts.ConstantInference(func_ir)

    if DEBUG > 1:
        print("after".center(80, '-'))
        func_ir.dump()