Esempio n. 1
0
def _transpile_assign_stmt(target, env, value, is_toplevel, depth=0):
    """Transpile the assignment of ``value`` to ``target``.

    Three kinds of targets are handled:

    * ``ast.Name``: when ``env[name]`` is ``None`` the name is registered in
      ``env.locals``. If ``is_toplevel`` is true and ``depth`` is ``0`` a
      declaration with initializer is returned directly; otherwise the name
      is also registered in ``env.decls`` and a plain assignment is emitted.
    * ``ast.Subscript``: the target expression is transpiled and a plain
      assignment is emitted.
    * ``ast.Tuple``: ``value`` is stored in a temporary named
      ``_temp{depth}`` and every element is assigned recursively from
      ``thrust::get<i>(_temp{depth})`` with ``depth + 1``.

    Args:
        target (ast.expr): Left-hand side node of the assignment.
        env: Transpile environment; read with ``env[name]`` and updated via
            ``env.locals`` and ``env.decls``.
        value: Right-hand side object exposing a ``ctype`` attribute.
        is_toplevel (bool): Flag forwarded by the caller; combined with
            ``depth == 0`` it selects the declare-and-initialize form for a
            previously unknown name.
        depth (int): Nesting level of tuple unpacking, used to name the
            temporary variable.

    Returns (list of [CodeBlock or str]): The generated code.

    Raises:
        ValueError: If a tuple target is assigned from a non-tuple type, or
            the number of elements does not match.
        NotImplementedError: If ``target`` is none of the supported node
            kinds.
    """
    if isinstance(target, ast.Name):
        name = target.id
        if env[name] is None:
            env.locals[name] = Data(name, value.ctype)
            if is_toplevel and depth == 0:
                return [value.ctype.declvar(name, value) + ';']
            env.decls[name] = Data(name, value.ctype)
        return _emit_assign_stmt(env[name], value, env)

    if isinstance(target, ast.Subscript):
        target = _transpile_expr(target, env)
        return _emit_assign_stmt(target, value, env)

    if isinstance(target, ast.Tuple):
        if not isinstance(value.ctype, _cuda_types.Tuple):
            raise ValueError(f'{value.ctype} cannot be unpack')
        size = len(target.elts)
        if len(value.ctype.types) > size:
            raise ValueError(f'too many values to unpack (expected {size})')
        if len(value.ctype.types) < size:
            raise ValueError(f'not enough values to unpack (expected {size})')
        temp_name = f'_temp{depth}'
        codes = [value.ctype.declvar(temp_name, value) + ';']
        for i, elt in enumerate(target.elts):
            element = Data(f'thrust::get<{i}>({temp_name})',
                           value.ctype.types[i])
            codes.extend(_transpile_assign_stmt(
                elt, env, element, is_toplevel, depth + 1))
        # The empty head makes the temporary live in its own code block.
        return [CodeBlock('', codes)]

    # Without this the function would silently return ``None`` and the
    # caller would fail later with an unrelated error.
    raise NotImplementedError(
        f'Unsupported assignment target: {type(target).__name__}')
Esempio n. 2
0
def _transpile_function_internal(func, attributes, mode, consts, in_types,
                                 ret_type):
    """Transpile one function definition into source text.

    Args:
        func (ast.FunctionDef): The function node to transpile.
        attributes (str): Text placed in front of the return type in the
            generated function head.
        mode: Passed through to ``Environment``.
        consts (dict): Mapping from names to values; each value is wrapped
            in ``Constant`` before being handed to ``Environment``.
        in_types: One type per positional argument of ``func``.
        ret_type: Initial return type passed to ``Environment``; when the
            environment still has no return type after the body has been
            transpiled, ``_types.Void()`` is used.

    Returns (tuple): The generated code as ``str`` and the environment.

    Raises:
        NotImplementedError: If ``func`` is not an ``ast.FunctionDef`` or it
            uses ``*args``, keyword-only arguments, ``**kwargs`` or default
            values.
        TypeError: If the number of positional arguments differs from
            ``len(in_types)``.
    """
    consts = {key: Constant(val) for key, val in consts.items()}

    if not isinstance(func, ast.FunctionDef):
        # TODO(asi1024): Support for `ast.ClassDef`.
        raise NotImplementedError('Not supported: {}'.format(type(func)))

    # Decorators can only be inspected on Python versions that provide
    # `ast.unparse`.
    if func.decorator_list and sys.version_info >= (3, 9):
        known_words = ['rawkernel', 'vectorize']
        for deco in func.decorator_list:
            deco_code = ast.unparse(deco)
            if any(word in deco_code for word in known_words):
                continue
            warnings.warn(
                f'Decorator {deco_code} may not supported in JIT.',
                RuntimeWarning)

    signature = func.args
    if signature.vararg is not None:
        raise NotImplementedError('`*args` is not supported currently.')
    if signature.kwonlyargs:  # same length with `kw_defaults`.
        raise NotImplementedError(
            'keyword only arguments are not supported currently .')
    if signature.kwarg is not None:
        raise NotImplementedError('`**kwargs` is not supported currently.')
    if signature.defaults:
        raise NotImplementedError(
            'Default values are not supported currently.')

    arg_names = [a.arg for a in signature.args]
    n_expected = len(arg_names)
    n_given = len(in_types)
    if n_expected != n_given:
        raise TypeError(
            f'{func.name}() takes {n_expected} positional arguments '
            f'but {n_given} were given.')

    param_objs = {
        arg_name: CudaObject(arg_name, arg_type)
        for arg_name, arg_type in zip(arg_names, in_types)
    }
    env = Environment(mode, consts, param_objs, ret_type)
    body = _transpile_stmts(func.body, True, env)

    param_decls = [env[arg_name].ctype.declvar(arg_name)
                   for arg_name in arg_names]
    params = ', '.join(param_decls)
    local_vars = []
    for var_name, var in env.locals.items():
        local_vars.append(var.ctype.declvar(var_name) + ';')

    # A body that never set a return type yields a void function.
    if env.ret_type is None:
        env.ret_type = _types.Void()

    head = f'{attributes} {env.ret_type} {func.name}({params})'
    return str(CodeBlock(head, local_vars + body)), env
Esempio n. 3
0
def _transpile_stmt(stmt, is_toplevel, env):
    """Transpile the statement.

    Args:
        stmt (ast.stmt): The statement node to transpile.
        is_toplevel (bool): Flag forwarded by the caller. Assigning a
            constant value to a name is only accepted when it is true, and
            the bodies of ``for``, ``while`` and non-constant ``if``
            statements are transpiled with it set to ``False``.
        env: Transpile environment; read with ``env[name]`` and updated via
            ``env.consts``, ``env.locals``, ``env.decls`` and
            ``env.ret_type``.

    Returns (list of [CodeBlock or str]): The generated CUDA code.

    Raises:
        NotImplementedError: For statement kinds or forms that are
            recognized but not supported.
        ValueError: For statement kinds that are not allowed, or when
            ``return`` statements disagree on the return type.
        TypeError: On variable type mismatches.
        AssertionError: If a constant ``assert`` condition is false, or the
            statement kind is not recognized at all.
    """

    if isinstance(stmt, ast.ClassDef):
        raise NotImplementedError('class is not supported currently.')
    if isinstance(stmt, (ast.FunctionDef, ast.AsyncFunctionDef)):
        raise NotImplementedError(
            'Nested functions are not supported currently.')
    if isinstance(stmt, ast.Return):
        value = _transpile_expr(stmt.value, env)
        value = Data.init(value, env)
        t = value.ctype
        # The first `return` fixes the return type; later ones must agree.
        if env.ret_type is None:
            env.ret_type = t
        elif env.ret_type != t:
            raise ValueError(
                f'Failed to infer the return type: {env.ret_type} or {t}')
        return [f'return {value.code};']
    if isinstance(stmt, ast.Delete):
        raise NotImplementedError('`del` is not supported currently.')

    if isinstance(stmt, ast.Assign):
        if len(stmt.targets) != 1:
            raise NotImplementedError('Not implemented.')

        value = _transpile_expr(stmt.value, env)
        target = stmt.targets[0]

        # A constant (other than a type class object) bound to a plain name
        # is recorded in `env.consts` and generates no code.
        if is_constants(value) and isinstance(target, ast.Name):
            name = target.id
            if not isinstance(value.obj, _typeclasses):
                if is_toplevel:
                    if env[name] is not None and not is_constants(env[name]):
                        raise TypeError(f'Type mismatch of variable: `{name}`')
                    env.consts[name] = value
                    return []
                else:
                    raise TypeError(
                        'Cannot assign constant value not at top-level.')

        value = Data.init(value, env)
        return _transpile_assign_stmt(target, env, value, is_toplevel)

    if isinstance(stmt, ast.AugAssign):
        value = _transpile_expr(stmt.value, env)
        target = _transpile_expr(stmt.target, env)
        assert isinstance(target, Data)
        value = Data.init(value, env)
        result = _eval_operand(stmt.op, (target, value), env)
        # The result has to be storable back into the target.
        if not numpy.can_cast(result.ctype.dtype, target.ctype.dtype,
                              'same_kind'):
            raise TypeError('dtype mismatch')
        return [target.ctype.assign(target, result) + ';']

    if isinstance(stmt, ast.For):
        if len(stmt.orelse) > 0:
            raise NotImplementedError('for-else is not supported.')
        # Only `for <name> in ...` is handled; report other targets
        # explicitly instead of failing on the missing `.id` attribute.
        if not isinstance(stmt.target, ast.Name):
            raise NotImplementedError(
                'for-loop target must be a single variable name.')
        name = stmt.target.id
        iters = _transpile_expr(stmt.iter, env)

        # Reject unsupported iterables before the environment is modified
        # and before the loop body is transpiled.
        if not isinstance(iters, _internal_types.Range):
            raise NotImplementedError(
                'for-loop is supported only for range iterator.')

        if env[name] is None:
            var = Data(name, iters.ctype)
            env.locals[name] = var
            env.decls[name] = var
        elif env[name].ctype.dtype != iters.ctype.dtype:
            raise TypeError(f'Data type mismatch of variable: `{name}`: '
                            f'{env[name].ctype.dtype} != {iters.ctype.dtype}')

        body = _transpile_stmts(stmt.body, False, env)

        init_code = (f'{iters.ctype} '
                     f'__it = {iters.start.code}, '
                     f'__stop = {iters.stop.code}, '
                     f'__step = {iters.step.code}')
        # `step_is_positive` may be neither True nor False; the direction is
        # then decided by the generated condition itself.
        if iters.step_is_positive is True:
            cond = '__it < __stop'
        elif iters.step_is_positive is False:
            cond = '__it > __stop'
        else:
            cond = '__step >= 0 ? __it < __stop : __it > __stop'

        head = f'for ({init_code}; {cond}; __it += __step)'
        return [CodeBlock(head, [f'{name} = __it;'] + body)]

    if isinstance(stmt, ast.AsyncFor):
        raise ValueError('`async for` is not allowed.')
    if isinstance(stmt, ast.While):
        if len(stmt.orelse) > 0:
            raise NotImplementedError('while-else is not supported.')
        condition = _transpile_expr(stmt.test, env)
        condition = _astype_scalar(condition, _cuda_types.bool_, 'unsafe', env)
        condition = Data.init(condition, env)
        body = _transpile_stmts(stmt.body, False, env)
        head = f'while ({condition.code})'
        return [CodeBlock(head, body)]
    if isinstance(stmt, ast.If):
        condition = _transpile_expr(stmt.test, env)
        if is_constants(condition):
            # Constant condition: only the selected branch is transpiled,
            # with the caller's `is_toplevel` kept as is.
            stmts = stmt.body if condition.obj else stmt.orelse
            return _transpile_stmts(stmts, is_toplevel, env)
        head = f'if ({condition.code})'
        then_body = _transpile_stmts(stmt.body, False, env)
        else_body = _transpile_stmts(stmt.orelse, False, env)
        return [CodeBlock(head, then_body), CodeBlock('else', else_body)]
    if isinstance(stmt, (ast.With, ast.AsyncWith)):
        raise ValueError('Switching contexts are not allowed.')
    if isinstance(stmt, (ast.Raise, ast.Try)):
        raise ValueError('throw/catch are not allowed.')
    if isinstance(stmt, ast.Assert):
        value = _transpile_expr(stmt.test, env)
        if is_constants(value):
            assert value.obj
            return [';']
        else:
            # Use the generated expression text, as the other branches do;
            # concatenating the expression object itself with `str` fails.
            return [f'assert({value.code});']
    if isinstance(stmt, (ast.Import, ast.ImportFrom)):
        raise ValueError('Cannot import modules from the target functions.')
    if isinstance(stmt, (ast.Global, ast.Nonlocal)):
        raise ValueError('Cannot use global/nonlocal in the target functions.')
    if isinstance(stmt, ast.Expr):
        value = _transpile_expr(stmt.value, env)
        return [';'] if is_constants(value) else [value.code + ';']
    if isinstance(stmt, ast.Pass):
        return [';']
    if isinstance(stmt, ast.Break):
        raise NotImplementedError('Not implemented.')
    if isinstance(stmt, ast.Continue):
        raise NotImplementedError('Not implemented.')
    # Raised explicitly (same exception type as the former `assert False`)
    # so that it is not stripped when Python runs with `-O`.
    raise AssertionError(f'Unsupported statement: {type(stmt).__name__}')