Exemplo n.º 1
0
class LLVMVector(types.GenericVector):
    '''LLVM vector type made of ``elemcount`` lanes of a builtin ``elemtype``.'''

    elemtype = Descriptor(constant=True,
                          constrains=instanceof(types.BuiltinType))
    elemcount = Descriptor(constant=True, constrains=lambda N: N > 1)

    def type(self):
        '''Return the LLVM vector type for this element type and count.'''
        return llvm.TypeFactory.make_vector(self.elemtype.type(),
                                            self.elemcount)

    def cast(self, old, builder):
        '''Cast the value object *old* to this vector type.

        A non-vector (scalar) value is first cast to ``elemtype`` and then
        inserted into every lane of an undef vector. A vector value is
        returned unchanged, provided it has the same element type and the
        same element count.

        Returns the raw LLVM value. Raises TypeError for an incompatible
        vector.
        '''
        from values import LLVMConstant
        if not isinstance(old.type, LLVMVector):  # from scalar to vector
            elem_casted = self.elemtype.cast(old, builder)
            index_ty = LLVMType(types.Int)  # loop-invariant index type
            vector = llvm.ConstantFactory.make_undef(self.type())
            for i in range(self.elemcount):
                idx = LLVMConstant(index_ty, i).value(builder)
                vector = builder.insert_element(vector, elem_casted, idx)
            return vector

        self._check_compatible(old.type)
        return old.value(builder)

    def coerce(self, other):
        '''Return the common type of this vector type and type *other*.

        A scalar *other* is promoted to this vector type. A vector *other*
        must match element type and count; otherwise TypeError is raised.
        '''
        if not isinstance(other, LLVMVector):  # other is of scalar type
            # then, try to promote it to a vector type
            return self

        self._check_compatible(other)
        return self

    def _check_compatible(self, other):
        '''Raise TypeError unless vector type *other* has the same shape.'''
        if other.elemtype != self.elemtype:
            raise TypeError('Invalid casting')
        # Vectors of different lengths are distinct LLVM types; letting them
        # through would produce ill-typed IR.
        if other.elemcount != self.elemcount:
            raise TypeError('Invalid casting: vector length mismatch')

    def argument_adaptor(self, val):
        '''Vectors cannot be passed as function arguments.'''
        raise NotImplementedError('Cannot use vector as argument.')
Exemplo n.º 2
0
class LLVMConstant(LLVMValue):
    '''A constant value of type *ty*, built once at construction time.

    The constant object is produced by ``ty.constant(val)`` in ``__init__``;
    ``value`` simply hands that same object back.
    '''

    constant = Descriptor(constant=True)

    def __init__(self, ty, val):
        self.type = ty
        made = self.type.constant(val)
        self.constant = made

    def value(self, builder):
        '''Return the stored constant; *builder* is not used.'''
        stored = self.constant
        return stored
Exemplo n.º 3
0
class LLVMVariable(LLVMValue):
    '''A named variable backed by a slot obtained from ``builder.alloc``.

    The slot is created once in ``__init__``; every call to ``value`` emits
    a fresh load from it.
    '''

    pointer = Descriptor(constant=True)

    def __init__(self, name, ty, builder):
        self.type = ty
        slot = builder.alloc(ty.type(), name)
        self.pointer = slot

    def value(self, builder):
        '''Emit a load of the variable's slot and return the loaded value.'''
        slot = self.pointer
        loaded = builder.load(slot)
        return loaded
Exemplo n.º 4
0
class LLVMTempValue(LLVMValue):
    '''Pairs an already-computed LLVM value with its type.

    No code is emitted by ``value``; the wrapped object is returned as is.
    '''

    temp_value = Descriptor(constant=True)

    def __init__(self, val, ty):
        self.type = ty
        self.temp_value = val

    def value(self, builder):
        '''Return the wrapped value; *builder* is not used.'''
        wrapped = self.temp_value
        return wrapped
Exemplo n.º 5
0
class LLVMUnboundedArray(types.GenericUnboundedArray):
    '''Array of unknown length, represented as a pointer to ``elemtype``.'''

    elemtype = Descriptor(constant=True,
                          constrains=instanceof(types.BuiltinType))

    def cast(self, old, builder):
        '''Only another unbounded array can be cast to this type.

        Returns the raw LLVM value of *old*; raises TypeError otherwise.
        '''
        if isinstance(old.type, LLVMUnboundedArray):
            return old.value(builder)
        else:
            raise TypeError('Casting unbounded-array to something else.')

    def ctype(self):
        '''ctypes pointer type matching this array's element type.'''
        return ctypes.POINTER(self.elemtype.ctype())

    def type(self):
        '''LLVM pointer type to the element type.'''
        return llvm.TypeFactory.make_pointer(self.elemtype.type())

    def argument_adaptor(self, val):
        '''Convert a Python object into something passable as this argument.

        Accepted inputs, tried in order:
          * numpy.ndarray (if numpy is importable): dtype must match the
            element ctype and the data must be contiguous; a pointer to the
            array's own buffer is returned (no copy).
          * array.array: typecode must map to the element ctype; a pointer to
            the array's buffer is returned (no copy).
          * any other iterable: copied into a new ctypes array (slow).

        Raises TypeError on a dtype/typecode mismatch or a non-contiguous
        ndarray.
        '''
        elemctype = self.elemtype.ctype()

        # Keep the try minimal: only the optional import can fail here.
        try:
            from numpy import ndarray
        except ImportError:
            ndarray = None

        if ndarray is not None and isinstance(val, ndarray):
            if val.dtype != elemctype:
                raise TypeError('dtype of the numpy.ndarray '
                                'does not match argument type.')
            # A strided view (e.g. ``a[::2]``) does not store its elements
            # back to back; a raw data pointer would expose the wrong items.
            if not (val.flags['C_CONTIGUOUS'] or val.flags['F_CONTIGUOUS']):
                raise TypeError('numpy.ndarray must be contiguous.')
            return val.ctypes.data_as(self.ctype())

        # No numpy or val is not ndarray.
        # Try to use array.array
        from array import array
        if isinstance(val, array):
            if _array_type_code_to_ctype[val.typecode] != elemctype:
                raise TypeError('array.array contains a different datatype.')
            address = val.buffer_info()[0]
            return ctypes.cast(address, self.ctype())

        # Build a ctype array from iterable. This can be very slow.
        # Materialize first so generators/iterators (no len()) also work.
        values = list(val)
        argtype = elemctype * len(values)
        return argtype(*values)
Exemplo n.º 6
0
class CodeGenerationBase(ast.NodeVisitor):
    '''Backend-independent translator of a Python function AST.

    The ``visit_*`` methods validate the supported subset of Python and
    delegate code emission to ``generate_*`` hooks. A backend subclass must
    override those hooks; the defaults here raise NotImplementedError.
    '''

    # Symbol table: name -> value / type / numeric constant.
    symbols = Descriptor(constant=True, constrains=instanceof(dict))
    # Stack of AST nodes currently being visited (innermost last).
    __nodes = Descriptor(constant=True)

    def __init__(self, globalsymbols):
        '''
        globalsymbols -- A dict containing global symbols for the function.
                         It is copied, so declarations made during code
                         generation do not modify the caller's dict.
        '''
        self.symbols = globalsymbols.copy()
        self.__nodes = []

    @property
    def current_node(self):
        '''The AST node whose visit is currently in progress.'''
        return self.__nodes[-1]

    def visit(self, node):
        '''Dispatch *node* to ``visit_<NodeClassName>``.

        While the handler runs, *node* is on the node stack (see
        ``current_node``). TypeError, NotImplementedError and AssertionError
        raised by the handler are logged and re-raised as InternalError
        attributed to *node*. A missing handler also yields InternalError.
        '''
        try:
            fn = getattr(self, 'visit_%s' % type(node).__name__)
        except AttributeError as e:
            logger.exception(e)
            logger.error('Unhandled visit to %s', ast.dump(node))
            raise InternalError(node, 'Not yet implemented.')

        self.__nodes.append(node)  # push current node
        try:
            return fn(node)
        except (TypeError, NotImplementedError, AssertionError) as e:
            logger.exception(e)
            raise InternalError(node, str(e))
        finally:
            self.__nodes.pop()  # pop current node

    def visit_FunctionDef(self, node):
        '''Generate the function: its arguments, then its body statements.

        A leading string-expression statement (docstring) is skipped.
        '''
        with self.generate_function(node.name):
            self.visit(node.args)

            statements = node.body
            first = statements[0]
            if isinstance(first, ast.Expr) and isinstance(first.value, ast.Str):
                # Python doc string
                logger.info('Ignoring python doc string.')
                statements = statements[1:]

            for stmt in statements:
                self.visit(stmt)
        # close function

    def visit_arguments(self, node):
        '''Collect plain positional argument names and declare them.

        Raises FunctionDeclarationError for *args, **kwargs or defaults.
        '''
        if node.vararg or node.kwarg or node.defaults:
            raise FunctionDeclarationError(
                node, 'Does not support variable/keyword/default arguments.')

        arguments = []
        for arg in node.args:
            if not isinstance(arg.ctx, ast.Param):
                raise InternalError(arg, 'Argument is not ast.Param?')
            arguments.append(arg.id)

        if len(set(arguments)) != len(arguments):
            raise InternalError(
                node, '''Argument redeclared.
This error should have been caught by the Python parser.''')

        self.generate_function_arguments(arguments)

    def generate_function(self, name):
        '''Return a context manager wrapping generation of function *name*.'''
        raise NotImplementedError

    def generate_function_arguments(self, arguments):
        '''Declare the function arguments named in list *arguments*.'''
        raise NotImplementedError

    def visit_Pass(self, node):
        pass

    def visit_Call(self, node):
        '''Handle a dialect construct (e.g. ``var``) or a real function call.'''
        fn = self.visit(node.func)

        if type(fn) is type and issubclass(fn, dialect.Construct):
            # Special construct for our dialect
            handlers = {
                dialect.var: self.construct_var,
            }
            try:
                handler = handlers[fn]
            except KeyError:
                raise NotImplementedError(
                    'This construct has yet to be implemented.')
            return handler(fn, node)

        # A real function call.
        if node.keywords or node.starargs or node.kwargs:
            raise InvalidCall(node, 'Cannot use keyword or star arguments.')

        args = [self.visit(arg) for arg in node.args]
        return self.generate_call(fn, args)  # return value

    def generate_call(self, fn, args):
        '''Emit a call of *fn* with visited argument values *args*.'''
        raise NotImplementedError

    def generate_declare(self, name, ty):
        '''Declare variable *name* of type class *ty*; return the variable.'''
        raise NotImplementedError

    def construct_var(self, fn, node):
        '''Handle ``var(name=Type, ...)``: declare each keyword as a variable.

        Raises InvalidUseOfConstruct when positional arguments are given or
        no keyword is given, VariableRedeclarationError when a name is
        already in the symbol table.
        '''
        if fn is not dialect.var:
            raise AssertionError('Implementation error.')

        if node.args or not node.keywords:
            raise InvalidUseOfConstruct(
                node,
                ('Construct "var" must contain at least one'
                 ' keyword arguments in the form of "var ( name=type )".'),
            )

        # for each defined variable
        for kw in node.keywords:
            ty = self.extract_type(kw.value)
            name = kw.arg
            if name in self.symbols:
                raise VariableRedeclarationError(kw.value)
            # store new variable to symbol table
            self.symbols[name] = self.generate_declare(name, ty)
        return  # return None

    def extract_type(self, node):
        '''Resolve a type expression to a type class.

        Accepts a plain name bound in the symbol table, or one of the
        constructors ``Slice(T)``, ``Vector(T, N)`` and ``Array(T, N)``,
        which create a new class on the fly.
        '''
        if isinstance(node, ast.Name):  # simple symbols
            if not isinstance(node.ctx, ast.Load):
                raise InternalError(node, 'Only load context is possible.')
            try:
                return self.symbols[node.id]
            except KeyError:
                raise UndefinedSymbolError(node)
        elif isinstance(node, ast.Call):  # complex type
            fn = self.visit(node.func)
            if not issubclass(fn, types.DummyType):
                raise AssertionError('Not a dummy type.')
            if fn is types.Slice:
                return self._extract_slice_type(node)
            elif fn is types.Vector:
                return self._extract_vector_type(node)
            elif fn is types.Array:
                return self._extract_array_type(node)

        raise InternalError(node, 'Cannot resolve type.')

    def _check_type_constructor_call(self, node, nargs, arity_message):
        '''Require exactly *nargs* positional and no keyword/star arguments.'''
        if len(node.args) != nargs:
            raise InvalidUseOfConstruct(node, arity_message)
        if node.keywords or node.starargs or node.kwargs:
            raise InvalidUseOfConstruct(
                node, 'Cannot use keyword or star arguments.')

    @staticmethod
    def _is_type_class(obj):
        '''True when *obj* is a class deriving from ``types.Type``.'''
        return type(obj) is type and issubclass(obj, types.Type)

    def _extract_slice_type(self, node):
        '''Build an unbounded-array class from ``Slice(ElemType)``.'''
        self._check_type_constructor_call(
            node, 1, ('Slice constructor takes 1 arguments.\n'
                      'Hint: Slice(ElemType)'))

        elemty = self.visit(node.args[0])
        if not self._is_type_class(elemty):
            raise InvalidUseOfConstruct(
                node, 'Expecting a type for element type of array.')

        newclsname = '__CustomSlice__%s' % (elemty.__name__)
        # List all class members of the new array type.
        return type(newclsname, (types.GenericUnboundedArray, ),
                    {'elemtype': elemty})

    def _extract_vector_type(self, node):
        '''Build a vector class from ``Vector(ElemType, ElemCount)``.

        ElemCount must be a numeric literal or a name bound to a number.
        '''
        self._check_type_constructor_call(
            node, 2, ('Vector constructor takes two arguments.\n'
                      'Hint: Vector(ElemType, ElemCount)'))

        elemty = self.visit(node.args[0])
        elemct = self.constant_number(node.args[1])

        if not self._is_type_class(elemty):
            raise InvalidUseOfConstruct(
                node, 'Expecting a type for element type of vector.')
        if elemct <= 0:
            raise InvalidUseOfConstruct(
                node, 'Vector type must have at least one element.')

        newclsname = '__CustomVector__%s%d' % (elemty.__name__, elemct)
        # List all class members of the new vector type.
        return type(newclsname, (types.GenericVector, ),
                    {'elemtype': elemty, 'elemcount': elemct})

    def _extract_array_type(self, node):
        '''Build a bounded-array class from ``Array(ElemType, ElemCount)``.

        ElemCount is visited like any expression, so it may be a constant
        or a variable.
        '''
        self._check_type_constructor_call(
            node, 2, ('Array constructor takes two arguments.\n'
                      'Hint: Array(ElemType, ElemCount)'))

        elemty = self.visit(node.args[0])
        elemct = self.visit(node.args[1])  # accept constants & variables

        if not self._is_type_class(elemty):
            raise InvalidUseOfConstruct(
                node, 'Expecting a type for element type of array.')
        # NOTE(review): elemct is whatever visit() returns (e.g. the result
        # of generate_constant_int), not a Python number, so this numeric
        # comparison may not be meaningful for it -- confirm.
        if elemct <= 0:
            raise InvalidUseOfConstruct(
                node, 'array type must have at least one element.')

        newclsname = '__CustomArray__%s%s' % (elemty.__name__, elemct)
        # List all class members of the new array type.
        return type(newclsname, (types.GenericBoundedArray, ),
                    {'elemtype': elemty, 'elemcount': elemct})

    def visit_Expr(self, node):
        self.generic_visit(node)

    def visit_Attribute(self, node):
        '''Load ``value.attr`` on the visited value (store is unsupported).'''
        if not isinstance(node.ctx, ast.Load):
            raise NotImplementedError(
                'Storing into attribute is not supported.')
        value = self.visit(node.value)
        return getattr(value, node.attr)

    def visit_Compare(self, node):
        '''Handle a single binary comparison ``lhs <op> rhs``.'''
        if len(node.ops) != 1:
            raise NotImplementedError('Multiple operators in ast.Compare')

        if len(node.comparators) != 1:
            raise NotImplementedError('Multiple comparators in ast.Compare')

        lhs = self.visit(node.left)
        rhs = self.visit(node.comparators[0])
        op = type(node.ops[0])
        return self.generate_compare(op, lhs, rhs)

    def visit_Return(self, node):
        if node.value is not None:
            self.generate_return(self.visit(node.value))
        else:
            self.generate_return()

    def generate_return(self, value=None):
        '''Emit a return of *value*, or a void return when it is None.'''
        raise NotImplementedError

    def generate_compare(self, op_class, lhs, rhs):
        raise NotImplementedError

    def visit_BinOp(self, node):
        lhs = self.visit(node.left)
        rhs = self.visit(node.right)
        op = type(node.op)
        return self.generate_binop(op, lhs, rhs)

    def generate_binop(self, op_class, lhs, rhs):
        raise NotImplementedError

    def visit_Assign(self, node):
        '''Handle ``target = value`` with exactly one target.'''
        if len(node.targets) != 1:
            raise NotImplementedError('Multiple targets in assignment.')
        target = self.visit(node.targets[0])
        value = self.visit(node.value)

        return self.generate_assign(value, target)

    def generate_assign(self, from_value, to_target):
        raise NotImplementedError

    def constant_number(self, node):
        '''Return the Python number denoted by a literal or a bound name.

        Raises NotImplementedError for other node kinds, UndefinedSymbolError
        for an unknown name, TypeError when the value is not numeric.
        '''
        if isinstance(node, ast.Num):
            retval = node.n
        elif isinstance(node, ast.Name):
            if not isinstance(node.ctx, ast.Load):
                raise NotImplementedError('Expecting a name in load context.')
            try:
                retval = self.symbols[node.id]
            except KeyError:
                raise UndefinedSymbolError(node)
        else:
            raise NotImplementedError(
                'Expecting a numeric literal or a constant name.')

        if not isinstance(retval, (int, long, float)):
            raise TypeError('Not a numeric constant.')
        return retval

    def visit_Num(self, node):
        '''Emit an integer (int/long) or real (float) constant.'''
        if isinstance(node.n, (int, long)):
            return self.generate_constant_int(node.n)
        elif isinstance(node.n, float):
            return self.generate_constant_real(node.n)
        # Previously fell through and returned None, failing later obscurely.
        raise NotImplementedError(
            'Unsupported numeric literal: %r' % (node.n, ))

    def generate_constant_int(self, val):
        raise NotImplementedError

    def generate_constant_real(self, val):
        raise NotImplementedError

    def visit_Subscript(self, node):
        '''Handle ``a[lo:]`` (array slice load) and ``v[i]`` / ``a[i]``.'''
        if isinstance(node.slice, ast.Slice):
            sl = node.slice
            if sl.lower is None:
                raise NotImplementedError('Slice without a lower bound.')
            if sl.upper is not None or sl.step is not None:
                raise NotImplementedError

            ptr = self.visit(node.value)
            idx = self.visit(sl.lower)

            if not isinstance(ptr.type,
                              types.GenericUnboundedArray):  # only array
                raise NotImplementedError
            if not isinstance(node.ctx, ast.Load):  # only at load context
                raise NotImplementedError

            return self.generate_array_slice(ptr, idx, None, None)

        if not isinstance(node.slice, ast.Index):
            raise AssertionError(ast.dump(node.slice))
        ptr = self.visit(node.value)
        idx = self.visit(node.slice.value)
        if isinstance(ptr.type, types.GenericVector):
            # Access vector element
            if isinstance(node.ctx, ast.Load):  # load
                return self.generate_vector_load_elem(ptr, idx)
            elif isinstance(node.ctx, ast.Store):  # store
                return self.generate_vector_store_elem(ptr, idx)
        elif isinstance(ptr.type, types.GenericUnboundedArray):
            # Access array element
            if isinstance(node.ctx, ast.Load):  # load
                return self.generate_array_load_elem(ptr, idx)
            elif isinstance(node.ctx, ast.Store):  # store
                return self.generate_array_store_elem(ptr, idx)
        else:  # Unsupported types
            raise InvalidSubscriptError(node)

    def generate_array_slice(self, ptr, lower, upper=None, step=None):
        raise NotImplementedError

    def generate_vector_load_elem(self, ptr, idx):
        raise NotImplementedError

    def generate_vector_store_elem(self, ptr, idx):
        raise NotImplementedError

    def generate_array_load_elem(self, ptr, idx):
        raise NotImplementedError

    def generate_array_store_elem(self, ptr, idx):
        raise NotImplementedError

    def visit_Name(self, node):
        '''Look up a name in the symbol table.

        In load context, Python int/long/float symbols become emitted
        constants; anything else is returned as stored. In store context the
        stored object itself is returned as the assignment target.
        '''
        if isinstance(node.ctx, ast.Load):  # load
            try:  # lookup in the symbol table
                val = self.symbols[node.id]
            except KeyError:  # does not exist
                raise UndefinedSymbolError(node)
            if isinstance(val, (int, long)):
                return self.generate_constant_int(val)
            elif isinstance(val, float):
                return self.generate_constant_real(val)
            return val
        elif isinstance(node.ctx, ast.Store):  # store
            try:
                return self.symbols[node.id]
            except KeyError:
                raise UndefinedSymbolError(node)
        # unreachable
        raise AssertionError('unreachable')

    def visit_If(self, node):
        '''Handle ``if``/``elif``/``else``; the else body may be any length.'''
        test = self.visit(node.test)
        self.generate_if(test, node.body, node.orelse)

    def generate_if(self, test, iftrue, orelse):
        '''Emit a conditional over the statement lists *iftrue*/*orelse*.'''
        raise NotImplementedError

    def visit_For(self, node):
        '''Handle ``for name in range(...)`` / ``xrange(...)`` (1 to 3 args).

        The loop counter is declared as a new Int variable; it must not
        already exist. Range arguments are visited left to right, as Python
        evaluates them.
        '''
        if node.orelse:
            raise NotImplementedError('Else in for-loop is not implemented.')
        iternode = node.iter

        str_only_support_forrange = 'Only for-range|for-xrange are supported.'
        if not (isinstance(iternode, ast.Call)
                and isinstance(iternode.func, ast.Name)
                and iternode.func.id in ('range', 'xrange')):
            raise InvalidUseOfConstruct(iternode, str_only_support_forrange)
        if iternode.keywords or iternode.starargs or iternode.kwargs:
            raise InvalidUseOfConstruct(
                iternode, 'Cannot use keyword or star arguments.')

        iternode_arg_N = len(iternode.args)
        if not 1 <= iternode_arg_N <= 3:
            raise InvalidUseOfConstruct(
                iternode, 'range/xrange takes 1 to 3 arguments.')

        # counter variable
        if not isinstance(node.target, ast.Name):
            raise NotImplementedError('Loop counter must be a simple name.')
        counter_name = node.target.id
        if counter_name in self.symbols:
            raise VariableRedeclarationError(node.target)

        counter_ptr = self.generate_declare(counter_name, types.Int)
        self.symbols[counter_name] = counter_ptr

        # range information
        if iternode_arg_N == 1:  # only END is given
            initcount = self.generate_constant_int(0)  # implicitly zero
            endcount = self.visit(iternode.args[0])
        else:  # BEGIN and END are given
            initcount = self.visit(iternode.args[0])
            endcount = self.visit(iternode.args[1])

        if iternode_arg_N == 3:  # STEP is given
            step = self.visit(iternode.args[2])
        else:
            step = self.generate_constant_int(1)

        self.generate_for_range(counter_ptr, initcount, endcount, step,
                                node.body)

    def generate_for_range(self, counter, init, end, step, body):
        raise NotImplementedError

    def visit_BoolOp(self, node):
        '''Handle ``a and b`` / ``a or b`` with exactly two operands.'''
        if len(node.values) != 2:
            raise NotImplementedError(
                'Boolean operations with more than two operands.')
        return self.generate_boolop(node.op, node.values[0], node.values[1])

    def generate_boolop(self, op_class, lhs, rhs):
        raise NotImplementedError

    def visit_UnaryOp(self, node):
        '''Handle ``not x``; other unary operators are unsupported.'''
        operand = self.visit(node.operand)
        if isinstance(node.op, ast.Not):
            return self.generate_not(operand)
        raise NotImplementedError(ast.dump(node))

    def generate_not(self, operand):
        raise NotImplementedError

    def visit_AugAssign(self, node):
        '''Handle ``target op= value`` as ``target = target op value``.

        The target is visited in store context (for the destination) and
        again in load context (for its current value). The load context is
        only temporary; the AST is restored afterwards.
        '''
        target = self.visit(node.target)
        store_ctx = node.target.ctx
        node.target.ctx = ast.Load()  # change context to load
        try:
            target_val = self.visit(node.target)
        finally:
            node.target.ctx = store_ctx
        value = self.visit(node.value)

        result = self.generate_binop(type(node.op), target_val, value)
        return self.generate_assign(result, target)

    def visit_While(self, node):
        '''Handle ``while test: body`` (no ``else`` clause).'''
        if node.orelse:
            raise NotImplementedError(
                'Else in while-loop is not implemented.')
        self.generate_while(node.test, node.body)

    def generate_while(self, test, body):
        raise NotImplementedError
Exemplo n.º 7
0
class LLVMCodeGenerator(CodeGenerationBase):
    '''Code generator that emits LLVM IR through an ``llvm.Builder``.'''

    retty = Descriptor(constant=True)        # return type object
    argtys = Descriptor(constant=True)       # list of argument type objects
    function = Descriptor(constant=True)     # function object being filled
    entry_block = Descriptor(constant=True)  # block receiving declarations

    def __init__(self, fnobj, retty, argtys, symbols):
        super(LLVMCodeGenerator, self).__init__(symbols)
        self.function = fnobj
        self.retty = retty
        self.argtys = argtys

    @contextmanager
    def generate_function(self, name):
        '''Set up ``self.function`` and yield while its body is generated.

        Creates an "entry" block (for declarations) and a "body" block
        (where code generation starts). After the ``with`` body finishes,
        entry is linked to body. If the last block is still open, a void
        return is added for Void functions; otherwise MissingReturnError.
        '''
        if not self.function.valid():
            # This class defines no ``jit_engine``; look it up defensively so
            # the intended FunctionDeclarationError is raised instead of an
            # AttributeError.
            engine = getattr(self, 'jit_engine', None)
            if engine is not None:
                detail = engine.last_error()
            else:
                detail = 'Invalid function object.'
            raise FunctionDeclarationError(self.current_node, detail)

        self.symbols[name] = self.function

        # make basic block
        self.entry_block = self.function.append_basic_block("entry")
        self.__blockcounter = 0

        # make instruction builder
        self.builder = llvm.Builder()
        bb_body = self.function.append_basic_block("body")
        self.builder.insert_at(bb_body)

        yield  # wait until args & body are generated

        # link entry to body
        bb_last = self.builder.get_basic_block()  # remember last block
        self.builder.insert_at(self.entry_block)  # goto entry block
        self.builder.branch(bb_body)              # branch to body
        self.builder.insert_at(bb_last)           # return to last block

        # close function
        if not self.builder.is_block_closed():
            if isinstance(self.retty, types.Void):
                # no return
                self.builder.ret_void()
            else:
                raise MissingReturnError(self.current_node)

    def generate_function_arguments(self, arguments):
        '''Spill each incoming argument into a named variable.

        Raises FunctionDeclarationError when the number of names does not
        equal the number of declared argument types.
        '''
        if len(arguments) != len(self.argtys):
            raise FunctionDeclarationError(
                self.current_node,
                'Actual number of argument mismatch declaration.')

        with self.relocate_to_entry():
            fn_args = self.function.arguments()
            for name, argty, fn_arg in zip(arguments, self.argtys, fn_args):
                var = LLVMVariable(name, argty, self.builder)
                self.builder.store(fn_arg, var.pointer)
                self.symbols[name] = var

    def generate_call(self, fn, args):
        '''Call another compiled LLVMFunction, or this function recursively.'''
        from function import LLVMFunction
        if isinstance(fn, LLVMFunction):  # another function
            retty = fn.retty
            argtys = fn.argtys
            fn = fn.code_llvm
        elif fn is self.function:  # recursion
            retty = self.retty
            argtys = self.argtys
        else:
            raise InvalidCall(self.current_node)

        return self._call_function(fn, args, retty, argtys)

    def generate_assign(self, from_value, to_target):
        '''Cast *from_value* to the target's type and store it.'''
        casted = to_target.type.cast(from_value, self.builder)
        self.builder.store(casted, to_target.pointer)
        return casted

    def _operator_method(self, ty, op_class):
        '''Return ``ty.op_<opname>`` for ast operator class *op_class*.

        Raises NotImplementedError (reported as InternalError by visit)
        when the type has no such operation.
        '''
        fn = getattr(ty, 'op_%s' % op_class.__name__.lower(), None)
        if fn is None:
            raise NotImplementedError(
                'Operator %s is not supported by %s.'
                % (op_class.__name__, type(ty).__name__))
        return fn

    def generate_compare(self, op_class, lhs, rhs):
        '''Compare *lhs* and *rhs* in their coerced type; yields a Bool.'''
        ty = lhs.type.coerce(rhs.type)
        lval = ty.cast(lhs, self.builder)
        rval = ty.cast(rhs, self.builder)
        fn = self._operator_method(ty, op_class)
        pred = fn(lval, rval, self.builder)
        return LLVMTempValue(pred, LLVMType(types.Bool))

    def generate_return(self, value=None):
        '''Emit ``ret void`` (no value) or ``ret`` of *value* cast to retty.'''
        if value is None:  # no return value
            self.builder.ret_void()
            return
        if isinstance(self.retty, LLVMVoid):
            raise InvalidReturnError(
                self.current_node,
                'This function does not return any value.')
        casted = self.retty.cast(value, self.builder)
        self.builder.ret(casted)

    def generate_binop(self, op_class, lhs, rhs):
        '''Apply a binary operator in the coerced type of the operands.'''
        ty = lhs.type.coerce(rhs.type)
        lval = ty.cast(lhs, self.builder)
        rval = ty.cast(rhs, self.builder)
        fn = self._operator_method(ty, op_class)
        return LLVMTempValue(fn(lval, rval, self.builder), ty)

    def generate_constant_int(self, value):
        return LLVMConstant(LLVMType(types.Int), value)

    def generate_constant_real(self, value):
        return LLVMConstant(LLVMType(types.Double), value)

    def generate_declare(self, name, ty):
        '''Declare a variable of type class *ty* in the entry block.'''
        with self.relocate_to_entry():
            if issubclass(ty, types.GenericBoundedArray):  # array
                return LLVMArrayVariable(name, LLVMType(ty),
                                         ty.elemcount.value(self.builder),
                                         self.builder)
            else:  # other types
                realty = LLVMType(ty)
                return LLVMVariable(name, realty, self.builder)

    def _call_function(self, fn, args, retty, argtys):
        '''Emit a call of *fn*, casting each argument to its declared type.

        Raises InvalidCall when the argument count differs from *argtys*
        (both too few and too many).
        '''
        args = list(args)
        if len(args) != len(argtys):
            raise InvalidCall(self.current_node, 'Number of argument mismatch')

        # Evaluate every argument first, then cast (same order as before).
        temps = [LLVMTempValue(arg.value(self.builder), arg.type)
                 for arg in args]
        arg_values = [argty.cast(temp, self.builder)
                      for temp, argty in zip(temps, argtys)]
        out = self.builder.call(fn, arg_values)
        return LLVMTempValue(out, retty)

    def new_basic_block(self, name='uname'):
        '''Append a uniquely numbered basic block to the function.'''
        self.__blockcounter += 1
        return self.function.append_basic_block(
            '%s_%d' % (name, self.__blockcounter))

    def generate_vector_load_elem(self, ptr, idx):
        elemval = self.builder.extract_element(
            ptr.value(self.builder),
            idx.value(self.builder),
        )
        return LLVMTempValue(elemval, ptr.type.elemtype)

    def generate_vector_store_elem(self, ptr, idx):
        '''Return a pointer to element *idx* of vector variable *ptr*.'''
        zero = self.generate_constant_int(0)
        indices = [X.value(self.builder) for X in (zero, idx)]
        addr = self.builder.gep2(ptr.pointer, indices)
        return LLVMTempPointer(addr, ptr.type.elemtype)

    def generate_array_load_elem(self, ptr, idx):
        ptr_val = ptr.value(self.builder)
        idx_val = idx.value(self.builder)
        ptr_offset = self.builder.gep(ptr_val, idx_val)
        return LLVMTempValue(self.builder.load(ptr_offset), ptr.type.elemtype)

    def generate_array_store_elem(self, ptr, idx):
        ptr_val = ptr.value(self.builder)
        idx_val = idx.value(self.builder)
        ptr_offset = self.builder.gep(ptr_val, idx_val)
        return LLVMTempPointer(ptr_offset, ptr.type.elemtype)

    def generate_if(self, test, iftrue, orelse):
        '''Emit if/else over statement lists; *test* is cast to Bool.

        A branch to "endif" is added only from arms whose last block is
        still open; if neither arm reaches "endif" it is marked unreachable.
        '''
        bb_if = self.new_basic_block('if')
        bb_else = self.new_basic_block('else')
        bb_endif = self.new_basic_block('endif')

        boolean = LLVMType(types.Bool).cast(test, self.builder)
        self.builder.cond_branch(boolean, bb_if, bb_else)

        is_endif_reachable = False
        for bb_arm, statements in ((bb_if, iftrue), (bb_else, orelse)):
            self.builder.insert_at(bb_arm)
            for stmt in statements:
                self.visit(stmt)
            if not self.builder.is_block_closed():
                self.builder.branch(bb_endif)
                is_endif_reachable = True

        # endif
        self.builder.insert_at(bb_endif)
        if not is_endif_reachable:
            self.builder.unreachable()

    def generate_while(self, test, body):
        '''Emit ``while test: body``; *test* (an AST node) is re-evaluated
        in the condition block on every iteration and cast to Bool.
        '''
        bb_cond = self.new_basic_block('loopcond')
        bb_body = self.new_basic_block('loopbody')
        bb_exit = self.new_basic_block('loopexit')

        self.builder.branch(bb_cond)

        # condition
        self.builder.insert_at(bb_cond)
        cond = LLVMType(types.Bool).cast(self.visit(test), self.builder)
        self.builder.cond_branch(cond, bb_body, bb_exit)

        # body
        self.builder.insert_at(bb_body)
        for stmt in body:
            self.visit(stmt)
        # A body ending in e.g. ``return`` already has a terminator; a second
        # branch would make the block invalid.
        if not self.builder.is_block_closed():
            self.builder.branch(bb_cond)

        # end loop
        self.builder.insert_at(bb_exit)

    def generate_for_range(self, counter_ptr, initcount, endcount, step,
                           loopbody):
        '''Emit a counted loop ``for counter in range(init, end, step)``.

        init/end/step are cast to the counter's type and evaluated once,
        before the loop, as Python's range() does. The loop runs while
        ``counter < end`` (signed).
        NOTE: the signed less-than test assumes a positive step; a negative
        step gives zero iterations.
        '''
        counter_ty = counter_ptr.type
        init_val = counter_ty.cast(initcount, self.builder)
        end_val = counter_ty.cast(endcount, self.builder)
        step_val = counter_ty.cast(step, self.builder)

        self.builder.store(init_val, counter_ptr.pointer)

        bb_cond = self.new_basic_block('loopcond')
        bb_body = self.new_basic_block('loopbody')
        bb_incr = self.new_basic_block('loopincr')
        bb_exit = self.new_basic_block('loopexit')

        self.builder.branch(bb_cond)

        # condition
        self.builder.insert_at(bb_cond)
        test = self.builder.icmp(llvm.ICMP_SLT,
                                 counter_ptr.value(self.builder), end_val)
        self.builder.cond_branch(test, bb_body, bb_exit)

        # body
        self.builder.insert_at(bb_body)
        for stmt in loopbody:
            self.visit(stmt)
        # Avoid a second terminator when the body already returned.
        if not self.builder.is_block_closed():
            self.builder.branch(bb_incr)

        # incr
        self.builder.insert_at(bb_incr)
        counter_next = counter_ty.op_add(counter_ptr.value(self.builder),
                                         step_val, self.builder)
        self.builder.store(counter_next, counter_ptr.pointer)
        self.builder.branch(bb_cond)

        # exit
        self.builder.insert_at(bb_exit)

    def generate_boolop(self, op_class, lhs, rhs):
        '''Emit short-circuit ``and``/``or`` of AST nodes *lhs* and *rhs*.

        The result is a phi over the two paths. Its incoming blocks are the
        blocks current *after* each operand is generated, because visiting
        an operand (e.g. a nested boolop) may itself create new blocks.
        '''
        boolty = LLVMType(types.Bool)

        left = boolty.cast(self.visit(lhs), self.builder)
        bb_left = self.builder.get_basic_block()  # block that branches out

        bb_right = self.new_basic_block('bool_right')
        bb_result = self.new_basic_block('bool_result')

        if isinstance(op_class, ast.And):
            self.builder.cond_branch(left, bb_right, bb_result)
        elif isinstance(op_class, ast.Or):
            self.builder.cond_branch(left, bb_result, bb_right)
        else:
            raise AssertionError('Unknown Boolean operator')

        self.builder.insert_at(bb_right)
        right = boolty.cast(self.visit(rhs), self.builder)
        bb_right_end = self.builder.get_basic_block()
        self.builder.branch(bb_result)

        self.builder.insert_at(bb_result)
        pred = self.builder.phi(boolty.type(), [bb_left, bb_right_end],
                                [left, right])
        return LLVMTempValue(pred, boolty)

    def generate_not(self, operand):
        boolty = LLVMType(types.Bool)
        boolval = boolty.cast(operand, self.builder)
        negated = boolty.op_not(boolval, self.builder)
        return LLVMTempValue(negated, boolty)

    def generate_array_slice(self, ptr, lower, upper=None, step=None):
        '''Return *ptr* offset by *lower* (only ``a[lower:]`` is supported).'''
        if upper is not None or step is not None:
            raise NotImplementedError('Slice upper bound/step not supported.')
        ptr_val = ptr.value(self.builder)
        lower_val = lower.value(self.builder)
        offsetted = self.builder.gep(ptr_val, lower_val)
        return LLVMTempValue(offsetted, ptr.type)

    @contextmanager
    def relocate_to_entry(self):
        '''Temporarily move the builder to the entry block.

        The previous insertion block is restored even if the body raises.
        '''
        bb_last = self.builder.get_basic_block()
        self.builder.insert_at(self.entry_block)
        try:
            yield  # relocated
        finally:
            # pickup at last block
            self.builder.insert_at(bb_last)
Exemplo n.º 8
0
class LLVMModule(object):
    '''Wrapper around an ``llvm.JITEngine`` that creates LLVM function
    definitions (from Python functions) and declarations (from names).
    '''
    jit_engine = Descriptor(constant=True)

    def __init__(self, name, optlevel=3, vectorize=True):
        self.jit_engine = llvm.JITEngine(name, optlevel, vectorize)

    def optimize(self):
        '''Delegate to ``jit_engine.optimize()``.'''
        self.jit_engine.optimize()

    def verify(self):
        '''Delegate to ``jit_engine.verify()``.'''
        self.jit_engine.verify()

    def dump_asm(self, fn):
        '''Return ``jit_engine.dump_asm(fn)`` for the given function.'''
        return self.jit_engine.dump_asm(fn)

    def dump(self):
        '''Return ``jit_engine.dump()``.'''
        return self.jit_engine.dump()

    def _new_func_def_or_decl(self, ret, args, name_or_func):
        '''Create an LLVM function for a Python function or a bare name.

        ret          -- return type (a ``types`` entry); ``types.Bool`` is
                        mapped to ``types.Int8`` as a workaround.
        args         -- iterable of argument types; each ``types.Bool`` is
                        likewise mapped to ``types.Int8``.
        name_or_func -- a string (-> declaration with that exact name) or a
                        Python function (-> definition named
                        "<defining module>.<function name>").

        Returns an LLVMFuncDef (LLVMFuncDef_BoolRet when ``ret`` is Bool) for
        a function, or an LLVMFuncDecl for a name.
        Raises NameError if the JIT engine gave the function a different name.
        '''
        from function import LLVMFuncDef, LLVMFuncDecl, LLVMFuncDef_BoolRet
        is_func_def = not isinstance(name_or_func, basestring)
        if is_func_def:
            func = name_or_func
            namespace = func.func_globals['__name__']
            realname = '.'.join([namespace, func.__name__])
        else:
            realname = name_or_func

        # workaround for boolean return type
        is_ret_bool = ret is types.Bool
        if is_ret_bool:
            # Change return type to 8-bit int
            retty = LLVMType(types.Int8)
            logger.warning(
                'Using workaround (change to Int8) for boolean return type.')
        else:
            retty = LLVMType(ret)

        # workaround for boolean argument type
        argtys = []
        count_converted_boolean = 0
        for arg in args:
            if arg is types.Bool:
                argtys.append(LLVMType(types.Int8))
                count_converted_boolean += 1
            else:
                argtys.append(LLVMType(arg))
        # Warn once for all converted arguments (not once per argument).
        if count_converted_boolean:
            logger.warning(
                'Using workaround (changed to Int8) for boolean argument type.'
            )

        fn_decl = self.jit_engine.make_function(
            realname,
            retty.type(),
            [argty.type() for argty in argtys],
        )

        # The engine may rename on a clash; callers rely on the exact name.
        if fn_decl.name() != realname:
            raise NameError('Generated function has a different name: %s' %
                            (fn_decl.name()))

        if not is_func_def:
            return LLVMFuncDecl(retty, argtys, self, fn_decl)
        if is_ret_bool:
            return LLVMFuncDef_BoolRet(func, retty, argtys, self, fn_decl)
        return LLVMFuncDef(func, retty, argtys, self, fn_decl)

    def new_function(self, func, ret, args):
        '''Create an LLVM function definition for Python function ``func``.'''
        return self._new_func_def_or_decl(ret, args, func)

    def new_declaration(self, realname, ret, args):
        '''Create an LLVM declaration named exactly ``realname``.'''
        return self._new_func_def_or_decl(ret, args, realname)
Exemplo n.º 9
0
class LLVMValue(object):
    '''Abstract base for values handled by the LLVM code generator.

    Subclasses set ``type`` in their own ``__init__``; ``type`` is declared
    with ``Descriptor(constant=True)`` (presumably set-once -- confirm in the
    Descriptor implementation).
    '''
    type = Descriptor(constant=True)

    def __init__(self, *args, **kwargs):
        # An explicit method instead of ``__init__ = NotImplemented``: that
        # sentinel is meant for binary operators, and calling it produced the
        # misleading "'NotImplementedType' object is not callable".  Raising
        # TypeError keeps the exception type callers saw before.
        raise TypeError('%s is abstract; instantiate a concrete subclass'
                        % type(self).__name__)
Exemplo n.º 10
0
class LLVMTempPointer(LLVMValue):
    '''LLVM value that wraps an already-built pointer.

    ptr -- the pointer produced by the builder, stored unchanged.
    ty  -- the value's type, stored as ``type``.
    '''
    pointer = Descriptor(constant=True)

    def __init__(self, ptr, ty):
        self.pointer = ptr
        self.type = ty
Exemplo n.º 11
0
class LLVMFuncDef(LLVMFunction):
    '''A Python function compiled to native code through the LLVM JIT.

    ``compile()`` generates LLVM IR from the Python source of ``fnobj``.
    Calling the instance (``run_jit``) invokes the JIT-compiled code through
    a ctypes function pointer; the pointer is bound lazily on the first call.
    ``run_py`` calls the original Python function instead.
    '''
    code_python = Descriptor(constant=True)

    c_funcptr_type = Descriptor(constant=True)
    c_funcptr = Descriptor(constant=True)

    manager = Descriptor(constant=True)

    def __init__(self, fnobj, retty, argtys, module, fn_decl):
        self.code_python = fnobj
        self.retty = retty
        self.argtys = argtys
        self.manager = module
        self.code_llvm = fn_decl

    def compile(self):
        '''Generate, verify and optimize LLVM IR for the Python function.

        Raises the result of ``wrap_by_function`` for any CompilerError raised
        during code generation.
        '''
        from pymothoa.compiler_errors import CompilerError, wrap_by_function

        func = self.code_python
        source = inspect.getsource(func)

        logger.debug('Compiling function: %s', func.__name__)

        tree = ast.parse(source)

        # getsource() of a single function should yield exactly one statement.
        assert isinstance(tree, ast.Module)
        assert len(tree.body) == 1

        # Code generation for LLVM
        try:
            codegen = LLVMCodeGenerator(
                            self.code_llvm,
                            self.retty,
                            self.argtys,
                            symbols=func.func_globals
                        )
            codegen.visit(tree.body[0])
        except CompilerError as e:
            raise wrap_by_function(e, func)

        self.code_llvm.verify()     # verify generated code
        self.manager.jit_engine.optimize_function(self.code_llvm) # optimize generated code to reduce space

        logger.debug('Dump LLVM IR\n%s', self.code_llvm.dump())

    def assembly(self):
        '''Return the module manager's assembly dump of this function.'''
        return self.manager.dump_asm(self.code_llvm)

    def prepare_pointer_to_function(self):
        '''Obtain pointer to function from the JIT engine'''
        addr = self.manager.jit_engine.get_pointer_to_function(self.code_llvm)
        # Create binding with ctypes library
        from ctypes import CFUNCTYPE, cast
        c_argtys = [argty.ctype() for argty in self.argtys]
        c_retty = self.retty.ctype()
        self.c_funcptr_type = CFUNCTYPE(c_retty, *c_argtys)
        self.c_funcptr = cast( int(addr), self.c_funcptr_type )

    def run_py(self, *args):
        '''Call the original (uncompiled) Python function.'''
        return self.code_python(*args)

    def run_jit(self, *args):
        '''Call the JIT-compiled function with ``args``.

        Each argument is converted with its type's ``argument_adaptor``.
        Raises TypeError when the number of arguments does not match the
        signature (previously surplus arguments were silently dropped).
        '''
        argtys = self.argtys
        if len(args) != len(argtys):
            raise TypeError('%s() takes exactly %d argument(s) (%d given)' %
                            (self.code_python.__name__, len(argtys),
                             len(args)))
        # Cast the arguments to corresponding types
        argvals = [aty.argument_adaptor(aval)
                   for aty, aval in zip(argtys, args)]

        # Only the lookup of the binding is guarded: an AttributeError raised
        # by the native call itself must not trigger a rebind-and-retry.
        try:
            funcptr = self.c_funcptr
        except AttributeError:  # Binding to the function not created yet.
            self.prepare_pointer_to_function()
            funcptr = self.c_funcptr

        return funcptr(*argvals)

    __call__ = run_jit
Exemplo n.º 12
0
class LLVMFunction(object):
    retty = Descriptor(constant=True)
    argtys = Descriptor(constant=True)

    code_llvm = Descriptor(constant=True)