示例#1
0
def _transpile_function(func, attributes, mode, consts, in_types, ret_type):
    """Transpile the function.

    Args:
        func (ast.FunctionDef): Target function.
        attributes (str): The attributes of target function.
        mode ('numpy' or 'cuda'): The rule for typecast.
        consts (dict): The dictionary with keys as variable names and
            values as concrete data object.
        in_types (list of _types.TypeBase): The types of arguments.
        ret_type (_types.TypeBase): The type of return value.

    Returns:
        code (str): The generated CUDA code.
        env (Environment): More details of analysis result of the function,
            which includes preambles, estimated return type and more.

    Raises:
        NotImplementedError: If ``func`` is not an ``ast.FunctionDef`` or its
            signature uses ``*args``, keyword-only arguments, ``**kwargs``
            or default values.
        TypeError: If the number of ``in_types`` does not match the number
            of positional parameters of ``func``.

    On Python 3.9+, a ``RuntimeWarning`` is emitted for every decorator whose
    source text mentions neither ``rawkernel`` nor ``vectorize``. If no
    return type is given or inferred, it defaults to ``_types.Void()``.
    """
    consts = {k: Constant(v) for k, v in consts.items()}

    if not isinstance(func, ast.FunctionDef):
        # TODO(asi1024): Support for `ast.ClassDef`.
        raise NotImplementedError('Not supported: {}'.format(type(func)))
    if len(func.decorator_list) > 0 and sys.version_info >= (3, 9):
        # Code path for Python versions that support `ast.unparse`.
        for deco in func.decorator_list:
            deco_code = ast.unparse(deco)
            # `ast.unparse` returns the whole decorator expression (e.g.
            # `cupyx.jit.rawkernel()`), so look for the known decorator
            # names as substrings; an exact comparison would warn on the
            # supported decorators themselves.
            if not any(word in deco_code
                       for word in ('rawkernel', 'vectorize')):
                warnings.warn(
                    f'Decorator {deco_code} may not supported in JIT.',
                    RuntimeWarning)
    arguments = func.args
    if arguments.vararg is not None:
        raise NotImplementedError('`*args` is not supported currently.')
    if len(arguments.kwonlyargs) > 0:  # same length with `kw_defaults`.
        raise NotImplementedError(
            'keyword only arguments are not supported currently .')
    if arguments.kwarg is not None:
        raise NotImplementedError('`**kwargs` is not supported currently.')
    if len(arguments.defaults) > 0:
        raise NotImplementedError(
            'Default values are not supported currently.')

    args = [arg.arg for arg in arguments.args]
    if len(args) != len(in_types):
        raise TypeError(
            f'{func.name}() takes {len(args)} positional arguments '
            f'but {len(in_types)} were given.')
    param_objs = {x: CudaObject(x, t) for x, t in zip(args, in_types)}
    env = Environment(mode, consts, param_objs, ret_type)
    body = _transpile_stmts(func.body, True, env)
    # Declarations are read back from `env` after transpiling the body.
    param_code = ', '.join([env[a].ctype.declvar(a) for a in args])
    local_vars = [v.ctype.declvar(n) + ';' for n, v in env.locals.items()]

    if env.ret_type is None:
        env.ret_type = _types.Void()

    head = f'{attributes} {env.ret_type} {func.name}({param_code})'
    code = CodeBlock(head, local_vars + body)
    return str(code), env
示例#2
0
def _transpile_function(
        func, attributes, mode, consts, in_types, ret_type):
    """Transpile the function.

    Args:
        func (ast.FunctionDef): Target function.
        attributes (str): The attributes of target function.
        mode ('numpy' or 'cuda'): The rule for typecast.
        consts (dict): The dictionary with keys as variable names and
            values as concrete data object.
        in_types (list of _types.TypeBase): The types of arguments.
        ret_type (_types.TypeBase): The type of return value.

    Returns:
        code (str): The generated CUDA code.
        env (Environment): More details of analysis result of the function,
            which includes preambles, estimated return type and more.

    Raises:
        NotImplementedError: If ``func`` is not an ``ast.FunctionDef``, is
            decorated, or its signature uses ``*args``, keyword-only
            arguments, ``**kwargs`` or default values.
        TypeError: If the number of ``in_types`` does not match the number
            of positional parameters of ``func``.
    """
    if not isinstance(func, ast.FunctionDef):
        # TODO(asi1024): Support for `ast.ClassDef`.
        raise NotImplementedError('Not supported: {}'.format(type(func)))
    if len(func.decorator_list) > 0:
        raise NotImplementedError('Decorator is not supported')
    arguments = func.args
    if arguments.vararg is not None:
        raise NotImplementedError('`*args` is not supported currently.')
    if len(arguments.kwonlyargs) > 0:  # same length with `kw_defaults`.
        raise NotImplementedError(
            'keyword only arguments are not supported currently .')
    if arguments.kwarg is not None:
        raise NotImplementedError('`**kwargs` is not supported currently.')
    if len(arguments.defaults) > 0:
        raise NotImplementedError(
            'Default values are not supported currently.')

    args = [arg.arg for arg in arguments.args]
    if len(args) != len(in_types):
        # Both halves of the message must be f-strings so that the given
        # argument count is actually interpolated.
        raise TypeError(
            f'{func.name}() takes {len(args)} positional arguments '
            f'but {len(in_types)} were given.')
    env = Environment(
        mode,
        {k: Constant(v) for k, v in consts.items()},
        {x: CudaObject(x, t) for x, t in zip(args, in_types)},
        ret_type)
    body = _transpile_stmts(func.body, True, env)
    params = ', '.join([f'{env[a].ctype} {a}' for a in args])
    local_vars = [f'{v.ctype} {n};' for n, v in env.locals.items()]
    # NOTE(review): if `ret_type` is None and no `return` in the body sets
    # `env.ret_type`, the head below renders `None` as the return type;
    # other versions of this function default it to `_types.Void()`.
    head = f'{attributes} {env.ret_type} {func.name}({params})'
    code = CodeBlock(head, local_vars + body)
    return str(code), env
示例#3
0
def _transpile_function_internal(func, attributes, mode, consts, in_types,
                                 ret_type):
    """Generate CUDA source text for a single Python function definition.

    Args:
        func (ast.FunctionDef): The parsed function to transpile.
        attributes (str): Text placed before the return type in the
            generated function head.
        mode ('numpy' or 'cuda'): The rule for typecast.
        consts (dict): Variable names mapped to concrete values; each value
            is wrapped in ``Constant`` before use.
        in_types (list of _types.TypeBase): One type per positional
            parameter of ``func``.
        ret_type (_types.TypeBase): The declared return type, or ``None``.

    Returns:
        tuple: ``(code, env)`` where ``code`` is the generated source string
        and ``env`` is the ``Environment`` used during transpilation. When
        no return type was determined, ``env.ret_type`` is set to
        ``_types.Void()``.

    Raises:
        NotImplementedError: For non-function nodes and for signatures that
            use ``*args``, keyword-only arguments, ``**kwargs`` or defaults.
        TypeError: When ``len(in_types)`` differs from the parameter count.
    """
    const_objs = {name: Constant(value) for name, value in consts.items()}

    if not isinstance(func, ast.FunctionDef):
        # TODO(asi1024): Support for `ast.ClassDef`.
        raise NotImplementedError('Not supported: {}'.format(type(func)))

    # Decorators are only inspected where `ast.unparse` exists (3.9+).
    if func.decorator_list and sys.version_info >= (3, 9):
        known_words = ['rawkernel', 'vectorize']
        for decorator in func.decorator_list:
            text = ast.unparse(decorator)
            if all(word not in text for word in known_words):
                warnings.warn(
                    f'Decorator {text} may not supported in JIT.',
                    RuntimeWarning)

    signature = func.args
    if signature.vararg is not None:
        raise NotImplementedError('`*args` is not supported currently.')
    if signature.kwonlyargs:  # `kw_defaults` always has the same length.
        raise NotImplementedError(
            'keyword only arguments are not supported currently .')
    if signature.kwarg is not None:
        raise NotImplementedError('`**kwargs` is not supported currently.')
    if signature.defaults:
        raise NotImplementedError(
            'Default values are not supported currently.')

    arg_names = [a.arg for a in signature.args]
    expected = len(arg_names)
    given = len(in_types)
    if expected != given:
        raise TypeError(
            f'{func.name}() takes {expected} positional arguments '
            f'but {given} were given.')

    cuda_params = {
        name: CudaObject(name, typ) for name, typ in zip(arg_names, in_types)
    }
    env = Environment(mode, const_objs, cuda_params, ret_type)
    body = _transpile_stmts(func.body, True, env)

    param_decls = []
    for name in arg_names:
        param_decls.append(env[name].ctype.declvar(name))
    param_code = ', '.join(param_decls)

    declarations = []
    for name, obj in env.locals.items():
        declarations.append(obj.ctype.declvar(name) + ';')

    if env.ret_type is None:
        env.ret_type = _types.Void()

    head = f'{attributes} {env.ret_type} {func.name}({param_code})'
    return str(CodeBlock(head, declarations + body)), env
示例#4
0
def _transpile_stmt(stmt, is_toplevel, env):
    """Transpile the statement.

    Args:
        stmt (ast.stmt): The statement to transpile.
        is_toplevel (bool): Whether ``stmt`` sits at the top level of the
            function body. Assignments of constant values are only accepted
            at the top level.
        env (Environment): The transpilation environment. It is updated in
            place: new local variables, constants and the inferred return
            type are recorded on it.

    Returns (list of [CodeBlock or str]): The generated CUDA code.

    Raises:
        NotImplementedError: If the statement kind is not supported yet.
        ValueError: If the statement kind is not allowed in the target
            function (``async``, ``with``, ``raise``/``try``, imports,
            ``global``/``nonlocal``), or the return type cannot be inferred
            consistently.
        TypeError: If a variable is re-assigned with a mismatching type.
    """

    if isinstance(stmt, ast.ClassDef):
        raise NotImplementedError('class is not supported currently.')
    if isinstance(stmt, (ast.FunctionDef, ast.AsyncFunctionDef)):
        raise NotImplementedError(
            'Nested functions are not supported currently.')
    if isinstance(stmt, ast.Return):
        value = _transpile_expr(stmt.value, env)
        value = _to_cuda_object(value, env)
        t = value.ctype
        # The first `return` fixes the return type; later ones must agree.
        if env.ret_type is None:
            env.ret_type = t
        elif env.ret_type != t:
            raise ValueError(
                f'Failed to infer the return type: {env.ret_type} or {t}')
        return [f'return {value.code};']
    if isinstance(stmt, ast.Delete):
        raise NotImplementedError('`del` is not supported currently.')

    if isinstance(stmt, ast.Assign):
        if len(stmt.targets) != 1:
            raise NotImplementedError('Not implemented.')
        target = stmt.targets[0]
        if not isinstance(target, ast.Name):
            raise NotImplementedError('Tuple is not supported.')
        name = target.id
        value = _transpile_expr(stmt.value, env)

        if is_constants([value]):
            if not isinstance(value.obj, _typeclasses):
                # Non-type constants are folded at transpile time instead of
                # emitting device code.
                if is_toplevel:
                    if env[name] is not None and not is_constants([env[name]]):
                        raise TypeError(f'Type mismatch of variable: `{name}`')
                    env.consts[name] = value
                    return []
                else:
                    raise TypeError(
                        'Cannot assign constant value not at top-level.')
            value = _to_cuda_object(value, env)

        if env[name] is None:
            env[name] = CudaObject(target.id, value.ctype)
        elif is_constants([env[name]]):
            raise TypeError(f'Type mismatch of variable: `{name}`')
        elif env[name].ctype.dtype != value.ctype.dtype:
            raise TypeError(
                f'Data type mismatch of variable: `{name}`: '
                f'{env[name].ctype.dtype} != {value.ctype.dtype}')
        return [f'{target.id} = {value.code};']

    if isinstance(stmt, ast.AugAssign):
        value = _transpile_expr(stmt.value, env)
        target = _transpile_expr(stmt.target, env)
        assert isinstance(target, CudaObject)
        value = _to_cuda_object(value, env)
        result = _eval_operand(stmt.op, (target, value), env)
        if not numpy.can_cast(
                result.ctype.dtype, target.ctype.dtype, 'same_kind'):
            raise TypeError('dtype mismatch')
        return [f'{target.code} = {result.code};']

    if isinstance(stmt, ast.For):
        if len(stmt.orelse) > 0:
            raise NotImplementedError('for-else is not supported.')
        if not isinstance(stmt.target, ast.Name):
            raise NotImplementedError('Tuple is not supported.')
        name = stmt.target.id
        iters = _transpile_expr(stmt.iter, env)

        # Reject unsupported iterators before reading `iters.ctype`,
        # registering the loop variable or transpiling the body.
        if not isinstance(iters, Range):
            raise NotImplementedError(
                'for-loop is supported only for range iterator.')

        if env[name] is None:
            env[name] = CudaObject(name, iters.ctype)
        elif is_constants([env[name]]):
            raise TypeError(f'Type mismatch of variable: `{name}`')
        elif env[name].ctype.dtype != iters.ctype.dtype:
            raise TypeError(
                f'Data type mismatch of variable: `{name}`: '
                f'{env[name].ctype.dtype} != {iters.ctype.dtype}')

        body = _transpile_stmts(stmt.body, False, env)

        init_code = (f'{iters.ctype} '
                     f'__it = {iters.start.code}, '
                     f'__stop = {iters.stop.code}, '
                     f'__step = {iters.step.code}')
        # When the sign of the step is unknown at transpile time, the
        # comparison direction is chosen at run time.
        cond = '__step >= 0 ? __it < __stop : __it > __stop'
        if iters.step_is_positive is True:
            cond = '__it < __stop'
        elif iters.step_is_positive is False:
            cond = '__it > __stop'

        head = f'for ({init_code}; {cond}; __it += __step)'
        return [CodeBlock(head, [f'{name} = __it;'] + body)]

    if isinstance(stmt, ast.AsyncFor):
        raise ValueError('`async for` is not allowed.')
    if isinstance(stmt, ast.While):
        if len(stmt.orelse) > 0:
            raise NotImplementedError('while-else is not supported.')
        condition = _transpile_expr(stmt.test, env)
        condition = _astype_scalar(condition, _types.bool_, 'unsafe', env)
        condition = _to_cuda_object(condition, env)
        body = _transpile_stmts(stmt.body, False, env)
        head = f'while ({condition.code})'
        return [CodeBlock(head, body)]
    if isinstance(stmt, ast.If):
        condition = _transpile_expr(stmt.test, env)
        if is_constants([condition]):
            # Constant condition: only the taken branch is transpiled.
            stmts = stmt.body if condition.obj else stmt.orelse
            return _transpile_stmts(stmts, is_toplevel, env)
        head = f'if ({condition.code})'
        then_body = _transpile_stmts(stmt.body, False, env)
        else_body = _transpile_stmts(stmt.orelse, False, env)
        return [CodeBlock(head, then_body), CodeBlock('else', else_body)]
    if isinstance(stmt, (ast.With, ast.AsyncWith)):
        raise ValueError('Switching contexts are not allowed.')
    if isinstance(stmt, (ast.Raise, ast.Try)):
        raise ValueError('throw/catch are not allowed.')
    if isinstance(stmt, ast.Assert):
        value = _transpile_expr(stmt.test, env)
        if is_constants([value]):
            # A constant assertion is checked at transpile time.
            assert value.obj
            return [';']
        else:
            # `value` is an expression object, not a string; emit its code.
            return ['assert(' + value.code + ');']
    if isinstance(stmt, (ast.Import, ast.ImportFrom)):
        raise ValueError('Cannot import modules from the target functions.')
    if isinstance(stmt, (ast.Global, ast.Nonlocal)):
        raise ValueError('Cannot use global/nonlocal in the target functions.')
    if isinstance(stmt, ast.Expr):
        value = _transpile_expr(stmt.value, env)
        return [';'] if is_constants([value]) else [value.code + ';']
    if isinstance(stmt, ast.Pass):
        return [';']
    if isinstance(stmt, ast.Break):
        raise NotImplementedError('Not implemented.')
    if isinstance(stmt, ast.Continue):
        raise NotImplementedError('Not implemented.')
    # Explicit raise instead of `assert False`, which is stripped under -O
    # and gives no hint about the offending statement.
    raise NotImplementedError(
        f'`{type(stmt).__name__}` is not supported currently.')