Example #1
0
def _parse_function_object(func):
    """Parse a callable into an ``ast.FunctionDef`` plus its source text.

    Named functions are parsed from their own source, after stripping the
    indentation of the first line so that methods and nested functions
    parse as top-level definitions.

    Lambdas cannot be isolated by ``inspect`` alone. Instead, the
    enclosing source is parsed, and the single ``ast.Lambda`` node inside
    the lambda's line range is wrapped in a synthetic ``FunctionDef``
    named ``_lambda_kernel`` whose body returns the lambda expression.

    Args:
        func: The callable to parse.

    Returns:
        tuple: ``(function_def, source)``, where ``function_def`` is an
        ``ast.FunctionDef`` and ``source`` is the source text it came from.

    Raises:
        ValueError: If ``func`` is not callable, if its source file cannot
            be located, or if its lambda cannot be identified uniquely.
    """
    if not callable(func):
        raise ValueError('`func` must be a callable object.')

    if func.__name__ != '<lambda>':
        if jit._getsource_func is None:
            source = inspect.getsource(func)
        else:
            source = jit._getsource_func(func)
        # Strip the first line's leading whitespace from every line that
        # starts with it. Only the prefix is removed, never an interior run
        # of spaces. When the source is already at column 0, the prefix is
        # '' and nothing changes.
        lines = source.split('\n')
        first = lines[0]
        indent = first[:len(first) - len(first.lstrip())]
        source = '\n'.join(
            line[len(indent):] if line.startswith(indent) else line
            for line in lines)
        tree = ast.parse(source)
        assert isinstance(tree, ast.Module)
        assert len(tree.body) == 1
        return tree.body[0], source

    if jit._getsource_func is not None:
        full_source = jit._getsource_func(func)
        # No line information is available for custom source, so every
        # lambda in it is a candidate.
        start_line, end_line = 0, math.inf
        source = full_source
    else:
        try:
            filename = inspect.getsourcefile(func)
        except TypeError:
            filename = None
        if filename is None:
            raise ValueError(f'JIT needs access to Python source for {func} '
                             'but could not be located')
        # tokenize.open decodes the file the way Python does (PEP 263 coding
        # cookie / BOM). This keeps full_source consistent with the lines
        # that inspect.getsourcelines returns, whereas plain open() uses the
        # locale encoding.
        import tokenize
        with tokenize.open(filename) as f:
            full_source = f.read()
        source_lines, start_line = inspect.getsourcelines(func)
        end_line = start_line + len(source_lines)
        source = ''.join(source_lines)

    tree = ast.parse(full_source)

    nodes = [
        node for node in ast.walk(tree)
        if isinstance(node, ast.Lambda)
        and start_line <= node.lineno < end_line
    ]
    if not nodes:
        # This can happen if the file changed after the module was loaded,
        # or if the custom source getter returned code without a lambda.
        raise ValueError(f'JIT could not find the definition of {func} '
                         'in its source code.')
    if len(nodes) > 1:
        raise ValueError('Multiple callables are found near the'
                         f' definition of {func}, and JIT could not'
                         ' identify the source code for it.')
    node = nodes[0]
    return ast.FunctionDef(
        name='_lambda_kernel',
        args=node.args,
        body=[ast.Return(node.body)],
        decorator_list=[],
        returns=None,
        type_comment=None,
    ), source
Example #2
0
def transpile(func, attributes, mode, in_types, ret_type):
    """Transpile the target function.

    Args:
        func (function): Target function.
        attributes (list of str): Attributes of the generated CUDA function.
        mode ('numpy' or 'cuda'): The rule for typecast.
        in_types (list of _types.TypeBase): Types of the arguments.
        ret_type (_types.TypeBase or None): Type of the return value.

    Returns:
        Result: The function name, the generated code (preambles followed
        by the transpiled function), and the return type recorded by
        ``_transpile_function``'s environment.

    Raises:
        ValueError: If ``func`` is not callable.
        NotImplementedError: If ``func`` is a lambda.
    """
    if not callable(func):
        raise ValueError('`func` must be a callable object.')

    if func.__name__ == '<lambda>':
        raise NotImplementedError('Lambda function is not supported.')

    attributes = ' '.join(attributes)

    # jit._getsource_func may be None (no custom getter installed).
    # Calling it unconditionally would raise TypeError in that case.
    if jit._getsource_func is None:
        source = inspect.getsource(func)
    else:
        source = jit._getsource_func(func)

    # Remove the first line's indentation, and only as a leading prefix,
    # so that methods and nested functions parse at module level.
    lines = source.split('\n')
    first = lines[0]
    indent = first[:len(first) - len(first.lstrip())]
    source = '\n'.join(
        line[len(indent):] if line.startswith(indent) else line
        for line in lines)

    cvars = inspect.getclosurevars(func)
    # Follow Python's name lookup order: enclosing-scope (nonlocal) values
    # shadow globals, which in turn shadow builtins. getclosurevars can
    # report one name in several mappings. dict(**a, **b) raised TypeError
    # on such duplicates; this merge keeps the original key order otherwise.
    consts = {**cvars.globals, **cvars.nonlocals}
    consts.update(
        (name, value) for name, value in cvars.builtins.items()
        if name not in consts)

    tree = ast.parse(source)
    assert isinstance(tree, ast.Module)
    assert len(tree.body) == 1
    cuda_code, env = _transpile_function(tree.body[0],
                                         attributes,
                                         mode,
                                         consts,
                                         in_types,
                                         ret_type,
                                         source=source)
    cuda_code = ''.join(code + '\n' for code in env.preambles) + cuda_code
    return Result(
        func_name=func.__name__,
        code=cuda_code,
        return_type=env.ret_type,
    )