コード例 #1
0
ファイル: test_analysis.py プロジェクト: numba/numba
        def check(func, arg_tys, bit_val):
            """Compile `func`, run dead_branch_prune with `arg_tys`, and check
            that the single branch's condition becomes the constant `bit_val`.
            """
            func_ir = compile_to_ir(func)

            def dump(title):
                # Dump the IR between separators when debugging is enabled.
                if self._DEBUG:
                    print("=" * 80)
                    print(title)
                    func_ir.dump()

            # check there is exactly 1 branch
            before_branches = self.find_branches(func_ir)
            self.assertEqual(len(before_branches), 1)

            # check the condition in the branch is a binop
            condition_var = before_branches[0].cond
            condition_defn = ir_utils.get_definition(func_ir, condition_var)
            self.assertEqual(condition_defn.op, 'binop')

            # do the prune; this should kill the dead branch and rewrite the
            # condition to a true/false const bit
            dump("before prune")
            dead_branch_prune(func_ir, arg_tys)
            dump("after prune")

            # after mutation, the condition should be a const value `bit_val`
            new_condition_defn = ir_utils.get_definition(func_ir, condition_var)
            self.assertIsInstance(new_condition_defn, ir.Const)
            self.assertEqual(new_condition_defn.value, bit_val)
コード例 #2
0
ファイル: inline_closurecall.py プロジェクト: yuguen/numba
 def _fix_stencil_neighborhood(self, options):
     """
     Replace options['neighborhood'] (an IR variable) with a two-level tuple
     of IR items taken from the program IR, so it can be handed to
     StencilFunc: one inner tuple per dimension.

     Bails out with GuardException (via require) if either level is not
     defined by a node that has `items`. Returns True on success.
     """
     func_ir = self.func_ir

     def items_of(var):
         # Resolve `var` to its defining node, which must carry `items`
         # (e.g. a build_tuple).
         node = get_definition(func_ir, var)
         require(hasattr(node, 'items'))
         return node.items

     # Outer level: one entry per dimension; inner level: that window.
     options['neighborhood'] = tuple(
         tuple(items_of(window_var))
         for window_var in items_of(options['neighborhood']))
     return True
コード例 #3
0
    def _infer_h5_typ(self, rhs):
        # Infer the type for expressions of the form f['A']['B'][:] or
        # f['A'][b, :], where the file name is a constant.
        # TODO: static_getitem has index_var for sure?
        # TODO: verify the index is a slice; support non-slice (e.g. integer)
        require(rhs.op in ('getitem', 'static_getitem'))

        def index_of(expr):
            # getitem keeps its index in `index`, static_getitem in `index_var`
            return expr.index if expr.op == 'getitem' else expr.index_var

        # XXX can't know the type of index here especially if it is bool arr.
        # A constant string index means we are in the middle of a select
        # chain, which is rejected.
        outer_index = guard(find_const, self.func_ir, index_of(rhs))
        require(not isinstance(outer_index, str))

        # Walk the getitem chain back to the call, collecting the constant
        # string object names in the order they are encountered.
        names = []
        node = rhs
        while True:
            node = get_definition(self.func_ir, node.value)
            require(isinstance(node, ir.Expr))
            if node.op == 'call':
                return self._get_h5_type_file(node, names)
            # every intermediate object name must be a constant str
            require(node.op in ('getitem', 'static_getitem'))
            names.append(find_str_const(self.func_ir, index_of(node)))
コード例 #4
0
 def _fix_stencil_neighborhood(self, options):
     """
     Replace options['neighborhood'] (an IR variable) with a two-level tuple
     of IR items taken from the program IR, so it can be handed to
     StencilFunc: one inner tuple per dimension.

     Bails out with GuardException (via require) if either level is not
     defined by a node that has `items`. Returns True on success.
     """
     func_ir = self.func_ir

     def items_of(var):
         # Resolve `var` to its defining node, which must carry `items`
         # (e.g. a build_tuple).
         node = get_definition(func_ir, var)
         require(hasattr(node, 'items'))
         return node.items

     # Outer level: one entry per dimension; inner level: that window.
     options['neighborhood'] = tuple(
         tuple(items_of(window_var))
         for window_var in items_of(options['neighborhood']))
     return True
コード例 #5
0
def _find_unsafe_empty_inferred(func_ir, expr):
    """Return whether `expr` is a call whose callee resolves to an ir.Global
    equal to `unsafe_empty_inferred`.

    Raises GuardException (via require) if `expr` is not a call or its callee
    is not defined by an ir.Global.
    """
    # Look the global up first, so a missing name fails immediately.
    target = unsafe_empty_inferred
    require(isinstance(expr, ir.Expr) and expr.op == 'call')
    callee_def = get_definition(func_ir, expr.func)
    require(isinstance(callee_def, ir.Global))
    _make_debug_print("_find_unsafe_empty_inferred")(callee_def.value)
    return callee_def.value == target
コード例 #6
0
 def _fix_stencil_index_offsets(self, options):
     """
     Replace options['index_offsets'] (an IR variable) with a tuple of the
     items of its defining node, so it can be handed to StencilFunc.

     Bails out with GuardException (via require) if that node has no
     `items`. Returns True on success.
     """
     key = 'index_offsets'
     offsets_node = get_definition(self.func_ir, options[key])
     require(hasattr(offsets_node, 'items'))
     options[key] = tuple(offsets_node.items)
     return True
コード例 #7
0
 def _get_str_contains_col(self, func_def):
     """Match `func_def` against `col.str.contains` and return `col`.

     Raises GuardException (via require) unless func_def is a getattr of
     'contains' on a getattr of 'str' whose base variable is in self.df_cols.
     """
     def is_getattr_of(node, attr_name):
         # short-circuits so `attr` is read only on getattr expressions
         return (isinstance(node, ir.Expr) and node.op == 'getattr'
                 and node.attr == attr_name)

     require(is_getattr_of(func_def, 'contains'))
     str_accessor = get_definition(self.func_ir, func_def.value)
     require(is_getattr_of(str_accessor, 'str'))
     column_var = str_accessor.value
     require(column_var.name in self.df_cols)
     return column_var
コード例 #8
0
def find_build_tuple(func_ir, var):
    """Return the items of the build_tuple expression that defines `var`.

    `var` may be an ir.Var or a variable name. Raises GuardException (via
    require) if `var` has another type or is not defined by a build_tuple.
    """
    require(isinstance(var, (ir.Var, str)))
    definition = get_definition(func_ir, var)
    is_tuple_build = (isinstance(definition, ir.Expr)
                      and definition.op == 'build_tuple')
    require(is_tuple_build)
    return definition.items
コード例 #9
0
def find_build_sequence(func_ir, var):
    """Reimplemented from numba.ir_utils.find_build_sequence
    Added 'build_map' to build_ops list.

    Returns (items, op) for the build_tuple / build_list / build_set /
    build_map expression that defines `var`; raises GuardException (via
    require) otherwise.
    """
    from numba.ir_utils import require, get_definition

    require(isinstance(var, ir.Var))
    seq_def = get_definition(func_ir, var)
    supported_ops = ('build_tuple', 'build_list', 'build_set', 'build_map')
    require(isinstance(seq_def, ir.Expr))
    require(seq_def.op in supported_ops)
    return seq_def.items, seq_def.op
コード例 #10
0
def _find_iter_range(func_ir, range_iter_var):
    """Find the iterator's actual range if it is either range(n) or
    range(m, n); otherwise raise GuardException.

    Returns a tuple (start, stop, func_def), where func_def is the ir.Global
    for ``range``. For range(n), start is the int 0 and stop is the raw
    argument variable. For range(m, n), both are resolved through
    ``get_definition(..., lhs_only=True)``.
    """
    debug_print = _make_debug_print("find_iter_range")

    iter_def = get_definition(func_ir, range_iter_var)
    debug_print("range_iter_var = ", range_iter_var, " def = ", iter_def)
    require(isinstance(iter_def, ir.Expr) and iter_def.op == 'getiter')

    call_var = iter_def.value
    call_def = get_definition(func_ir, call_var)
    debug_print("range_var = ", call_var, " range_def = ", call_def)
    require(isinstance(call_def, ir.Expr) and call_def.op == 'call')

    callee_var = call_def.func
    callee_def = get_definition(func_ir, callee_var)
    debug_print("func_var = ", callee_var, " func_def = ", callee_def)
    require(isinstance(callee_def, ir.Global) and callee_def.value == range)

    call_args = call_def.args
    arg_count = len(call_args)
    if arg_count == 1:
        # NOTE(review): the resolved value is discarded and the raw argument
        # is returned, unlike the two-argument branch; the lookup is kept so
        # any error it raises still surfaces. Confirm whether intended.
        get_definition(func_ir, call_args[0], lhs_only=True)
        return (0, call_args[0], callee_def)
    if arg_count == 2:
        start, stop = (get_definition(func_ir, arg, lhs_only=True)
                       for arg in call_args)
        return (start, stop, callee_def)
    raise GuardException
コード例 #11
0
def _find_arraycall(func_ir, block):
    """Look for a statement like "x = numpy.array(y)" or "x[..] = y" in
    `block`, followed by the deletion (ir.Del) of the list variable y.

    The scan walks leading Del / Assign / SetItem statements and stops at the
    first other kind of statement.

    Returns a tuple (list_var, array_stmt_index, array_kws):
      - list_var: the ir.Var holding the list y;
      - array_stmt_index: index in block.body of the array statement;
      - array_kws: dict of the numpy.array keyword arguments ({} for the
        SetItem form).
    Raises GuardException (via require) if no such statement is found or the
    list is not deleted afterwards.
    """
    array_var = None
    array_stmt_index = None
    array_kws = None
    list_var_dead_after_array_call = False
    list_var = None

    for i, instr in enumerate(block.body):
        if isinstance(instr, ir.Del):
            # Stop the process once list_var becomes dead
            if list_var and array_var and instr.value == list_var.name:
                list_var_dead_after_array_call = True
                break
        elif isinstance(instr, ir.Assign):
            # Found array_var = array(list_var)
            lhs = instr.target
            expr = instr.value
            # `expr.args` may be empty (e.g. numpy.array(object=y)); check it
            # before indexing so this bails out instead of an IndexError.
            if (guard(find_callname, func_ir, expr) == ('array', 'numpy') and
                    expr.args and isinstance(expr.args[0], ir.Var)):
                list_var = expr.args[0]
                array_var = lhs
                array_stmt_index = i
                array_kws = dict(expr.kws)
        elif (isinstance(instr, ir.SetItem) and
              isinstance(instr.value, ir.Var) and
              not list_var):
            list_var = instr.value
            # Found array_var[..] = list_var, the case for nested array
            array_var = instr.target
            array_def = get_definition(func_ir, array_var)
            require(guard(_find_unsafe_empty_inferred, func_ir, array_def))
            array_stmt_index = i
            array_kws = {}
        else:
            # Bail out otherwise
            break
    # require array_var is found, and list_var is dead after array_call.
    require(array_var and list_var_dead_after_array_call)
    _make_debug_print("find_array_call")(block.body[array_stmt_index])
    return list_var, array_stmt_index, array_kws
コード例 #12
0
ファイル: utils.py プロジェクト: stuartarchibald/hpat
def find_str_const(func_ir, var):
    """Infer `var` as a string constant and return its value.

    Accepts a variable defined by a str ir.Const, or by a binop using
    operator.add whose operands are themselves inferable string constants
    (s1 + s2, recursively). Raises GuardException (via require) otherwise.
    """
    require(isinstance(var, ir.Var))
    var_def = get_definition(func_ir, var)
    if not isinstance(var_def, ir.Const):
        # only s1 + s2 is supported; TODO: extend to other expressions
        require(isinstance(var_def, ir.Expr) and var_def.op == 'binop'
                and var_def.fn == operator.add)
        return (find_str_const(func_ir, var_def.lhs)
                + find_str_const(func_ir, var_def.rhs))
    const_value = var_def.value
    require(isinstance(const_value, str))
    return const_value
コード例 #13
0
 def fix_dependencies(expr, varlist):
     """Double check that all variables in varlist are defined before
     expr is used, and return the list of variables to use in their place.

     A variable that is not yet defined at the statement assigning `expr`,
     but whose definition (per get_definition) is an ir.Const, is replaced
     by a fresh variable holding a copy of that constant. The copy is
     inserted right before that statement. Bails out by raising
     GuardException if a variable cannot be handled this way, or if no
     statement in `blocks` assigns `expr`.
     """
     debug_print = _make_debug_print("fix_dependencies")
     for label, block in blocks.items():
         scope = block.scope
         body = block.body
         defined = set()
         for i, inst in enumerate(body):
             if isinstance(inst, ir.Assign):
                 defined.add(inst.target.name)
                 if inst.value == expr:
                     new_varlist = []
                     # Constant copies to insert before body[i]. They are
                     # inserted together after the loop: rebuilding
                     # block.body from the stale `body` list once per
                     # variable would keep only the last inserted copy.
                     new_defs = []
                     for var in varlist:
                         # var must be defined before this inst, or live
                         # and not later defined.
                         if (var.name in defined or
                             (var.name in livemap[label]
                              and not (var.name in usedefs.defmap[label]))):
                             debug_print(var.name, " already defined")
                             new_varlist.append(var)
                         else:
                             debug_print(var.name, " not yet defined")
                             var_def = get_definition(func_ir, var.name)
                             if not isinstance(var_def, ir.Const):
                                 raise GuardException
                             loc = var.loc
                             new_var = ir.Var(scope,
                                              mk_unique_var("new_var"),
                                              loc)
                             new_const = ir.Const(var_def.value, loc)
                             new_defs.append(_new_definition(
                                 func_ir, new_var, new_const, loc))
                             new_varlist.append(new_var)
                     if new_defs:
                         block.body = body[:i] + new_defs + body[i:]
                     return new_varlist
     # when expr is not found in any block
     raise GuardException
コード例 #14
0
ファイル: inline_closurecall.py プロジェクト: yuguen/numba
 def fix_dependencies(expr, varlist):
     """Double check that all variables in varlist are defined before
     expr is used, and return the list of variables to use in their place.

     A variable that is not yet defined at the statement assigning `expr`,
     but whose definition (per get_definition) is an ir.Const, is replaced
     by a fresh variable holding a copy of that constant. The copy is
     inserted right before that statement. Bails out by raising
     GuardException if a variable cannot be handled this way, or if no
     statement in `blocks` assigns `expr`.
     """
     debug_print = _make_debug_print("fix_dependencies")
     for label, block in blocks.items():
         scope = block.scope
         body = block.body
         defined = set()
         for i, inst in enumerate(body):
             if isinstance(inst, ir.Assign):
                 defined.add(inst.target.name)
                 if inst.value == expr:
                     new_varlist = []
                     # Constant copies to insert before body[i]. They are
                     # inserted together after the loop: rebuilding
                     # block.body from the stale `body` list once per
                     # variable would keep only the last inserted copy.
                     new_defs = []
                     for var in varlist:
                         # var must be defined before this inst, or live
                         # and not later defined.
                         if (var.name in defined or
                             (var.name in livemap[label] and
                              not (var.name in usedefs.defmap[label]))):
                             debug_print(var.name, " already defined")
                             new_varlist.append(var)
                         else:
                             debug_print(var.name, " not yet defined")
                             var_def = get_definition(func_ir, var.name)
                             if not isinstance(var_def, ir.Const):
                                 raise GuardException
                             loc = var.loc
                             new_var = ir.Var(scope, mk_unique_var("new_var"), loc)
                             new_const = ir.Const(var_def.value, loc)
                             new_defs.append(_new_definition(
                                 func_ir, new_var, new_const, loc))
                             new_varlist.append(new_var)
                     if new_defs:
                         block.body = body[:i] + new_defs + body[i:]
                     return new_varlist
     # when expr is not found in any block
     raise GuardException
コード例 #15
0
def _get_const_index_expr_inner(stencil_ir, func_ir, index_var):
    """Try, in order, to infer `index_var` as a constant:

      1. the variable itself is a constant (in either IR);
      2. it is defined by a unary expression of a constant;
      3. it is defined by a binary expression (arg1 + arg2) of constants.

    Returns the first non-None result; raises GuardException otherwise.
    """
    require(isinstance(index_var, ir.Var))
    # case where the index is a const itself in outer function
    result = guard(_get_const_two_irs, stencil_ir, func_ir, index_var)
    if result is not None:
        return result
    # the remaining matchers work on the index definition
    index_def = ir_utils.get_definition(stencil_ir, index_var)
    for matcher in (_get_const_unary_expr, _get_const_binary_expr):
        result = guard(matcher, stencil_ir, func_ir, index_def)
        if result is not None:
            return result
    raise GuardException
コード例 #16
0
 def fix_array_assign(stmt):
     """Rewrite an assignment `lhs[idx] = rhs` where both lhs and rhs are
     arrays:

     1. the definition of rhs (looking through one cast) must be a call to
        numba.unsafe.ndarray.empty_inferred;
     2. the size tuple of the array creation behind lhs is extended with
        rhs's dimensions;
     3. the definition `rhs = numba.unsafe.ndarray.empty_inferred(...)` is
        rewritten in place to `rhs = lhs[idx]`.

     Returns True on success; bails out with GuardException via require.
     """
     require(isinstance(stmt, ir.SetItem) and isinstance(stmt.value, ir.Var))
     debug_print = _make_debug_print("fix_array_assign")
     debug_print("found SetItem: ", stmt)
     lhs = stmt.target
     # Locate the array creation that produced lhs
     lhs_def = find_array_def(lhs)
     debug_print("found lhs_def: ", lhs_def)
     rhs_def = get_definition(func_ir, stmt.value)
     debug_print("found rhs_def: ", rhs_def)
     require(isinstance(rhs_def, ir.Expr))
     if rhs_def.op == 'cast':
         # look through the cast to the underlying allocation
         rhs_def = get_definition(func_ir, rhs_def.value)
         require(isinstance(rhs_def, ir.Expr))
     require(_find_unsafe_empty_inferred(func_ir, rhs_def))
     # rhs's dimensions come from the build_tuple passed as its first argument
     dim_def = get_definition(func_ir, rhs_def.args[0])
     require(isinstance(dim_def, ir.Expr) and dim_def.op == 'build_tuple')
     debug_print("dim_def = ", dim_def)
     extra_dims = [get_definition(func_ir, dim_item, lhs_only=True)
                   for dim_item in dim_def.items]
     debug_print("extra_dims = ", extra_dims)
     # Grow the size tuple of lhs's creation by rhs's dimensions
     size_tuple_def = get_definition(func_ir, lhs_def.args[0])
     is_size_tuple = (isinstance(size_tuple_def, ir.Expr)
                      and size_tuple_def.op == 'build_tuple')
     require(is_size_tuple)
     debug_print("size_tuple_def = ", size_tuple_def)
     extra_dims = fix_dependencies(size_tuple_def, extra_dims)
     size_tuple_def.items += extra_dims
     # Turn rhs's allocation call into `lhs[idx]` in place
     rhs_def.op = 'getitem'
     rhs_def.value = get_definition(func_ir, lhs, lhs_only=True)
     rhs_def.index = stmt.index
     # drop the call-only fields that no longer apply to a getitem
     for stale_key in ('func', 'args', 'vararg', 'kws'):
         del rhs_def._kws[stale_key]
     return True
コード例 #17
0
ファイル: stencilparfor.py プロジェクト: numba/numba
def _get_const_index_expr_inner(stencil_ir, func_ir, index_var):
    """Try, in order, to infer `index_var` as a constant:

      1. the variable itself is a constant (in either IR);
      2. it is defined by a unary expression of a constant;
      3. it is defined by a binary expression (arg1 + arg2) of constants.

    Returns the first non-None result; raises GuardException otherwise.
    """
    require(isinstance(index_var, ir.Var))
    # case where the index is a const itself in outer function
    result = guard(_get_const_two_irs, stencil_ir, func_ir, index_var)
    if result is not None:
        return result
    # the remaining matchers work on the index definition
    index_def = ir_utils.get_definition(stencil_ir, index_var)
    for matcher in (_get_const_unary_expr, _get_const_binary_expr):
        result = guard(matcher, stencil_ir, func_ir, index_def)
        if result is not None:
            return result
    raise GuardException
コード例 #18
0
ファイル: inline_closurecall.py プロジェクト: yuguen/numba
 def fix_array_assign(stmt):
     """Rewrite an assignment `lhs[idx] = rhs` where both lhs and rhs are
     arrays:

     1. the definition of rhs (looking through one cast) must be a call to
        numba.unsafe.ndarray.empty_inferred;
     2. the size tuple of the array creation behind lhs is extended with
        rhs's dimensions;
     3. the definition `rhs = numba.unsafe.ndarray.empty_inferred(...)` is
        rewritten in place to `rhs = lhs[idx]`.

     Returns True on success; bails out with GuardException via require.
     """
     require(isinstance(stmt, ir.SetItem) and isinstance(stmt.value, ir.Var))
     debug_print = _make_debug_print("fix_array_assign")
     debug_print("found SetItem: ", stmt)
     lhs = stmt.target
     # Locate the array creation that produced lhs
     lhs_def = find_array_def(lhs)
     debug_print("found lhs_def: ", lhs_def)
     rhs_def = get_definition(func_ir, stmt.value)
     debug_print("found rhs_def: ", rhs_def)
     require(isinstance(rhs_def, ir.Expr))
     if rhs_def.op == 'cast':
         # look through the cast to the underlying allocation
         rhs_def = get_definition(func_ir, rhs_def.value)
         require(isinstance(rhs_def, ir.Expr))
     require(_find_unsafe_empty_inferred(func_ir, rhs_def))
     # rhs's dimensions come from the build_tuple passed as its first argument
     dim_def = get_definition(func_ir, rhs_def.args[0])
     require(isinstance(dim_def, ir.Expr) and dim_def.op == 'build_tuple')
     debug_print("dim_def = ", dim_def)
     extra_dims = [get_definition(func_ir, dim_item, lhs_only=True)
                   for dim_item in dim_def.items]
     debug_print("extra_dims = ", extra_dims)
     # Grow the size tuple of lhs's creation by rhs's dimensions
     size_tuple_def = get_definition(func_ir, lhs_def.args[0])
     is_size_tuple = (isinstance(size_tuple_def, ir.Expr)
                      and size_tuple_def.op == 'build_tuple')
     require(is_size_tuple)
     debug_print("size_tuple_def = ", size_tuple_def)
     extra_dims = fix_dependencies(size_tuple_def, extra_dims)
     size_tuple_def.items += extra_dims
     # Turn rhs's allocation call into `lhs[idx]` in place
     rhs_def.op = 'getitem'
     rhs_def.value = get_definition(func_ir, lhs, lhs_only=True)
     rhs_def.index = stmt.index
     # drop the call-only fields that no longer apply to a getitem
     for stale_key in ('func', 'args', 'vararg', 'kws'):
         del rhs_def._kws[stale_key]
     return True
コード例 #19
0
def _inline_arraycall(func_ir, cfg, visited, loop, enable_prange=False):
    """Look for an array(list) call in the exit block of a given loop, and turn
    list operations into array operations in the loop if the following
    conditions are met:
      1. The exit block contains an array call on the list;
      2. The list variable is no longer live after the array call;
      3. The list is created in the loop entry block;
      4. The loop is created from a range iterator whose length is known prior
         to the loop;
      5. There is only one list_append operation on the list variable in the
         loop body;
      6. The block that contains list_append dominates the loop head, which
         ensures the list length is the same as the loop length;
    If any condition check fails (GuardException raised via require), no
    modification is made to the incoming IR. Returns True on success.
    """
    debug_print = _make_debug_print("inline_arraycall")
    # There should only be one loop exit
    require(len(loop.exits) == 1)
    exit_block = next(iter(loop.exits))
    list_var, array_call_index, array_kws = _find_arraycall(
        func_ir, func_ir.blocks[exit_block])

    # check if dtype is present in array call
    dtype_def = None
    dtype_mod_def = None
    if 'dtype' in array_kws:
        require(isinstance(array_kws['dtype'], ir.Var))
        # We require the dtype argument to be defined by a getattr Expr, and
        # we'll remember its definition for later use.
        dtype_def = get_definition(func_ir, array_kws['dtype'])
        require(isinstance(dtype_def, ir.Expr) and dtype_def.op == 'getattr')
        dtype_mod_def = get_definition(func_ir, dtype_def.value)

    list_var_def = get_definition(func_ir, list_var)
    debug_print("list_var = ", list_var, " def = ", list_var_def)
    if isinstance(list_var_def, ir.Expr) and list_var_def.op == 'cast':
        list_var_def = get_definition(func_ir, list_var_def.value)
    # Check if the definition is a build_list
    require(isinstance(list_var_def, ir.Expr) and
            list_var_def.op == 'build_list')

    # Look for list_append in "last" block in loop body, which should be a
    # block that is a post-dominator of the loop header.
    list_append_stmts = []
    for label in loop.body:
        # We have to consider blocks of this loop, but not sub-loops.
        # To achieve this, we require the set of "in_loops" of "label" to be
        # visited loops.
        in_visited_loops = [l.header in visited for l in cfg.in_loops(label)]
        if not all(in_visited_loops):
            continue
        block = func_ir.blocks[label]
        debug_print("check loop body block ", label)
        for stmt in block.find_insts(ir.Assign):
            expr = stmt.value
            if isinstance(expr, ir.Expr) and expr.op == 'call':
                func_def = get_definition(func_ir, expr.func)
                if (isinstance(func_def, ir.Expr) and func_def.op == 'getattr'
                        and func_def.attr == 'append'):
                    list_def = get_definition(func_ir, func_def.value)
                    debug_print("list_def = ", list_def,
                                list_def == list_var_def)
                    if list_def == list_var_def:
                        # found matching append call
                        list_append_stmts.append((label, block, stmt))

    # Require only one list_append, otherwise we won't know the indices
    require(len(list_append_stmts) == 1)
    append_block_label, append_block, append_stmt = list_append_stmts[0]

    # Check if append_block (besides loop entry) dominates loop header.
    # Since CFG doesn't give us this info without loop entry, we approximate
    # by checking if the predecessor set of the header block is the same
    # as loop_entries plus append_block, which is certainly more restrictive
    # than necessary, and can be relaxed if needed.
    preds = set(l for l, b in cfg.predecessors(loop.header))
    debug_print("preds = ", preds, (loop.entries | set([append_block_label])))
    require(preds == (loop.entries | set([append_block_label])))

    # Find iterator in loop header
    iter_vars = []
    iter_first_vars = []
    loop_header = func_ir.blocks[loop.header]
    for stmt in loop_header.find_insts(ir.Assign):
        expr = stmt.value
        if isinstance(expr, ir.Expr):
            if expr.op == 'iternext':
                iter_def = get_definition(func_ir, expr.value)
                debug_print("iter_def = ", iter_def)
                iter_vars.append(expr.value)
            elif expr.op == 'pair_first':
                iter_first_vars.append(stmt.target)

    # Require only one iterator in loop header
    require(len(iter_vars) == 1 and len(iter_first_vars) == 1)
    iter_var = iter_vars[0]  # variable that holds the iterator object
    iter_first_var = iter_first_vars[0]  # variable that holds the value out of iterator

    # Final requirement: only one loop entry, and we're going to modify it by:
    # 1. replacing the list definition with an array definition;
    # 2. adding a counter for the array iteration.
    require(len(loop.entries) == 1)
    loop_entry = func_ir.blocks[next(iter(loop.entries))]
    terminator = loop_entry.terminator
    scope = loop_entry.scope
    loc = loop_entry.loc
    stmts = []
    removed = []

    def is_removed(val, removed):
        # True if `val` is a Var whose name matches an already-removed target
        if isinstance(val, ir.Var):
            for x in removed:
                if x.name == val.name:
                    return True
        return False

    # Skip list construction and skip terminator, add the rest to stmts.
    # Compare against list_var_def (the definition of the list passed to the
    # array call), not the loop variable `list_def` above: that one holds the
    # list targeted by the *last* inspected append call, which may be a
    # different list.
    for i in range(len(loop_entry.body) - 1):
        stmt = loop_entry.body[i]
        if isinstance(stmt, ir.Assign) and (stmt.value == list_var_def or
                                            is_removed(stmt.value, removed)):
            removed.append(stmt.target)
        else:
            stmts.append(stmt)
    debug_print("removed variables: ", removed)

    # Define an index_var to index the array.
    # If the range happens to be single step ranges like range(n), or
    # range(m, n), then the index_var correlates to iterator index; otherwise
    # we'll have to define a new counter.
    range_def = guard(_find_iter_range, func_ir, iter_var)
    index_var = ir.Var(scope, mk_unique_var("index"), loc)
    if range_def and range_def[0] == 0:
        # iterator starts with 0, index_var can just be iter_first_var
        index_var = iter_first_var
    else:
        # index_var = -1 # starting the index with -1 since it will be
        # incremented in loop header
        stmts.append(_new_definition(func_ir, index_var,
                     ir.Const(value=-1, loc=loc), loc))

    # Insert statement to get the size of the loop iterator
    size_var = ir.Var(scope, mk_unique_var("size"), loc)
    if range_def:
        start, stop, range_func_def = range_def
        if start == 0:
            size_val = stop
        else:
            size_val = ir.Expr.binop(fn='-', lhs=stop, rhs=start, loc=loc)
        # we can parallelize this loop if enable_prange = True, by changing
        # range function from range, to prange.
        if enable_prange and isinstance(range_func_def, ir.Global):
            range_func_def.name = 'internal_prange'
            range_func_def.value = internal_prange

    else:
        len_func_var = ir.Var(scope, mk_unique_var("len_func"), loc)
        stmts.append(_new_definition(func_ir, len_func_var,
                     ir.Global('range_iter_len', range_iter_len, loc=loc), loc))
        size_val = ir.Expr.call(len_func_var, (iter_var,), (), loc=loc)

    stmts.append(_new_definition(func_ir, size_var, size_val, loc))

    size_tuple_var = ir.Var(scope, mk_unique_var("size_tuple"), loc)
    stmts.append(_new_definition(func_ir, size_tuple_var,
                 ir.Expr.build_tuple(items=[size_var], loc=loc), loc))

    # Insert array allocation
    array_var = ir.Var(scope, mk_unique_var("array"), loc)
    empty_func = ir.Var(scope, mk_unique_var("empty_func"), loc)
    if dtype_def and dtype_mod_def:
        # when dtype is present, we'll call empty with dtype
        dtype_mod_var = ir.Var(scope, mk_unique_var("dtype_mod"), loc)
        dtype_var = ir.Var(scope, mk_unique_var("dtype"), loc)
        stmts.append(_new_definition(func_ir, dtype_mod_var, dtype_mod_def, loc))
        stmts.append(_new_definition(func_ir, dtype_var,
                         ir.Expr.getattr(dtype_mod_var, dtype_def.attr, loc), loc))
        stmts.append(_new_definition(func_ir, empty_func,
                         ir.Global('empty', np.empty, loc=loc), loc))
        array_kws = [('dtype', dtype_var)]
    else:
        # otherwise we'll call unsafe_empty_inferred
        stmts.append(_new_definition(func_ir, empty_func,
                         ir.Global('unsafe_empty_inferred',
                             unsafe_empty_inferred, loc=loc), loc))
        array_kws = []
    # array_var = empty_func(size_tuple_var)
    stmts.append(_new_definition(func_ir, array_var,
                 ir.Expr.call(empty_func, (size_tuple_var,), list(array_kws),
                              loc=loc), loc))

    # Add back removed just in case they are used by something else
    for var in removed:
        stmts.append(_new_definition(func_ir, var, array_var, loc))

    # Add back terminator
    stmts.append(terminator)
    # Modify loop_entry
    loop_entry.body = stmts

    if range_def:
        if range_def[0] != 0:
            # when range doesn't start from 0, index_var becomes loop index
            # (iter_first_var) minus an offset (range_def[0])
            terminator = loop_header.terminator
            assert(isinstance(terminator, ir.Branch))
            # find the block in the loop body that header jumps to
            block_id = terminator.truebr
            blk = func_ir.blocks[block_id]
            loc = blk.loc
            blk.body.insert(0, _new_definition(func_ir, index_var,
                ir.Expr.binop(fn='-', lhs=iter_first_var,
                                      rhs=range_def[0], loc=loc),
                loc))
    else:
        # Insert index_var increment to the end of loop header
        loc = loop_header.loc
        terminator = loop_header.terminator
        stmts = loop_header.body[0:-1]
        next_index_var = ir.Var(scope, mk_unique_var("next_index"), loc)
        one = ir.Var(scope, mk_unique_var("one"), loc)
        # one = 1
        stmts.append(_new_definition(func_ir, one,
                     ir.Const(value=1, loc=loc), loc))
        # next_index_var = index_var + 1
        stmts.append(_new_definition(func_ir, next_index_var,
                     ir.Expr.binop(fn='+', lhs=index_var, rhs=one, loc=loc), loc))
        # index_var = next_index_var
        stmts.append(_new_definition(func_ir, index_var, next_index_var, loc))
        stmts.append(terminator)
        loop_header.body = stmts

    # In append_block, change list_append into array assign
    for i in range(len(append_block.body)):
        if append_block.body[i] == append_stmt:
            debug_print("Replace append with SetItem")
            append_block.body[i] = ir.SetItem(target=array_var, index=index_var,
                                              value=append_stmt.value.args[0],
                                              loc=append_stmt.loc)

    # replace array call, by changing "a = array(b)" to "a = b"
    stmt = func_ir.blocks[exit_block].body[array_call_index]
    # stmt can be either array call or SetItem, we only replace array call
    if isinstance(stmt, ir.Assign) and isinstance(stmt.value, ir.Expr):
        stmt.value = array_var
        func_ir._definitions[stmt.target.name] = [stmt.value]

    return True
コード例 #20
0
def get_slice_step(typemap, func_ir, var):
    """Return the third argument of the call that defines the slice `var`
    (the step, for a slice(start, stop, step) call).

    Requires `var` to be typed as types.slice3_type and defined by a call
    with exactly three positional arguments. Otherwise raises GuardException
    (via require), so callers that use guard() bail out instead of crashing.
    """
    require(typemap[var.name] == types.slice3_type)
    call_expr = get_definition(func_ir, var)
    require(isinstance(call_expr, ir.Expr) and call_expr.op == 'call')
    # require rather than assert: asserts are stripped under `python -O`, and
    # a slice3 value built by a call of another shape should make this
    # matcher bail out, not raise AssertionError.
    require(len(call_expr.args) == 3)
    return call_expr.args[2]