Example #1
0
    def call_const(self, env, ndim):
        """Build the CUDA expression for an ``ndim``-dimensional query.

        ``self._code`` is a format string with an ``{n}`` placeholder that is
        filled with the axis names ``x``, ``y`` and ``z`` in turn.

        Args:
            env: The transpilation environment (not used by this method).
            ndim (int): Requested number of dimensions; must be 1, 2 or 3.

        Returns:
            Data: For ``ndim == 1`` a single ``uint32`` expression; otherwise
            a ``thrust::make_tuple`` of ``ndim`` ``uint32`` expressions.

        Raises:
            TypeError: If ``ndim`` is not an ``int``.
            ValueError: If ``ndim`` is not one of 1, 2 or 3.
        """
        if not isinstance(ndim, int):
            raise TypeError('ndim must be an integer')

        axes = {1: ('x',), 2: ('x', 'y'), 3: ('x', 'y', 'z')}.get(ndim)
        if axes is None:
            raise ValueError('Only ndim=1,2,3 are supported')

        exprs = [self._code.format(n=axis) for axis in axes]
        # Numba convention: 1D yields a bare scalar, higher ranks a tuple.
        if ndim == 1:
            return Data(exprs[0], _cuda_types.uint32)
        tuple_type = _cuda_types.Tuple([_cuda_types.uint32] * ndim)
        joined = ', '.join(exprs)
        return Data(f'thrust::make_tuple({joined})', tuple_type)
Example #2
0
def _transpile_expr_internal(expr, env):
    """Transpile a Python AST expression node into a CUDA expression.

    Dispatches on the node type of ``expr``; sub-expressions are transpiled
    recursively through ``_transpile_expr``.

    Args:
        expr: An ``ast`` expression node.
        env: The transpilation environment. Names are looked up with
            ``env[name]`` (``None`` means unbound) and ``env`` is forwarded
            to every helper.

    Returns:
        Usually a ``Constant`` (value known at compile time) or a ``Data``
        (CUDA source snippet plus its C type). Attribute access on a thread
        group yields a ``BuiltinFunc`` bound to that group, and builtin
        function calls return whatever ``BuiltinFunc.call`` returns.

    Raises:
        NotImplementedError: For lambdas, comparisons of 3+ operands and
            unsupported attribute access.
        TypeError: For complex-valued conditions, mismatched branch types in
            a conditional expression, non-callable or unsupported calls, bad
            ufunc keywords and bad explicit casts.
        NameError: When a name is unbound in ``env``.
        ValueError: For AST node types this transpiler does not handle.
    """
    if isinstance(expr, ast.BoolOp):
        # `a and b and c` is folded left-to-right into binary operations.
        values = [_transpile_expr(e, env) for e in expr.values]
        value = values[0]
        for rhs in values[1:]:
            value = _eval_operand(expr.op, (value, rhs), env)
        return value
    if isinstance(expr, ast.BinOp):
        left = _transpile_expr(expr.left, env)
        right = _transpile_expr(expr.right, env)
        return _eval_operand(expr.op, (left, right), env)
    if isinstance(expr, ast.UnaryOp):
        value = _transpile_expr(expr.operand, env)
        return _eval_operand(expr.op, (value, ), env)
    if isinstance(expr, ast.Lambda):
        raise NotImplementedError('Not implemented.')
    if isinstance(expr, ast.Compare):
        values = [expr.left] + expr.comparators
        if len(values) != 2:
            raise NotImplementedError(
                'Comparison of 3 or more values is not implemented.')
        values = [_transpile_expr(e, env) for e in values]
        return _eval_operand(expr.ops[0], values, env)
    if isinstance(expr, ast.IfExp):
        cond = _transpile_expr(expr.test, env)
        x = _transpile_expr(expr.body, env)
        y = _transpile_expr(expr.orelse, env)

        # A condition known at compile time selects its branch statically.
        # The check must be on `cond`: `expr` is an ast.IfExp and is never
        # a Constant.
        if isinstance(cond, Constant):
            return x if cond.obj else y
        if cond.ctype.dtype.kind == 'c':
            raise TypeError("Complex type value cannot be boolean condition.")
        # Both branches must agree on a single C type for the ternary.
        x, y = _infer_type(x, y, env), _infer_type(y, x, env)
        if x.ctype.dtype != y.ctype.dtype:
            raise TypeError('Type mismatch in conditional expression.: '
                            f'{x.ctype.dtype} != {y.ctype.dtype}')
        cond = _astype_scalar(cond, _cuda_types.bool_, 'unsafe', env)
        return Data(f'({cond.code} ? {x.code} : {y.code})', x.ctype)

    if isinstance(expr, ast.Call):
        func = _transpile_expr(expr.func, env)
        args = [_transpile_expr(x, env) for x in expr.args]
        kwargs = {kw.arg: _transpile_expr(kw.value, env)
                  for kw in expr.keywords}

        builtin_funcs = _builtin_funcs.builtin_functions_dict
        if is_constants(func) and (func.obj in builtin_funcs):
            func = builtin_funcs[func.obj]

        if isinstance(func, _internal_types.BuiltinFunc):
            return func.call(env, *args, **kwargs)

        if not is_constants(func):
            raise TypeError(f"'{func}' is not callable.")

        func = func.obj

        if is_constants(*args, *kwargs.values()):
            # compile-time function call
            args = [x.obj for x in args]
            kwargs = {k: v.obj for k, v in kwargs.items()}
            return Constant(func(*args, **kwargs))

        if isinstance(func, _kernel.ufunc):
            # ufunc call
            dtype = kwargs.pop('dtype', Constant(None)).obj
            if len(kwargs) > 0:
                name = next(iter(kwargs))
                raise TypeError(
                    f"'{name}' is an invalid keyword to ufunc {func.name}")
            return _call_ufunc(func, args, dtype, env)

        if inspect.isclass(func) and issubclass(func, _typeclasses):
            # explicit typecast
            if len(args) != 1:
                raise TypeError(
                    f'function takes {func} invalid number of argument')
            ctype = _cuda_types.Scalar(func)
            return _astype_scalar(args[0], ctype, 'unsafe', env)

        if isinstance(func, _interface._JitRawKernel) and func._device:
            # Device function: transpile it as a `__device__` function
            # specialized for the argument types, then emit a call to it.
            args = [Data.init(x, env) for x in args]
            in_types = tuple([x.ctype for x in args])
            fname, return_type = _transpile_func_obj(func._func,
                                                     ['__device__'], env.mode,
                                                     in_types, None,
                                                     env.generated)
            in_params = ', '.join([x.code for x in args])
            return Data(f'{fname}({in_params})', return_type)

        # `fname` is only bound in the device-function branch above, so the
        # name for the message comes from `func`. Not every callable has
        # `__name__` (e.g. functools.partial), hence the fallback.
        func_name = getattr(func, '__name__', func)
        raise TypeError(f"Invalid function call '{func_name}'.")

    if isinstance(expr, ast.Constant):
        return Constant(expr.value)
    # Pre-3.8 literal nodes. They are fetched with getattr because these
    # deprecated aliases were removed from `ast` in Python 3.14; there the
    # empty tuple makes the isinstance check simply fail.
    if isinstance(expr, getattr(ast, 'Num', ())):
        return Constant(expr.n)
    if isinstance(expr, getattr(ast, 'Str', ())):
        return Constant(expr.s)
    if isinstance(expr, getattr(ast, 'NameConstant', ())):
        return Constant(expr.value)
    if isinstance(expr, ast.Subscript):
        array = _transpile_expr(expr.value, env)
        index = _transpile_expr(expr.slice, env)
        return _indexing(array, index, env)
    if isinstance(expr, ast.Name):
        value = env[expr.id]
        if value is None:
            raise NameError(f'Unbound name: {expr.id}')
        return value
    if isinstance(expr, ast.Attribute):
        value = _transpile_expr(expr.value, env)
        if is_constants(value):
            return Constant(getattr(value.obj, expr.attr))
        if isinstance(value.ctype, _cuda_types.ArrayBase):
            if expr.attr == 'ndim':
                return Constant(value.ctype.ndim)
        if isinstance(value.ctype, _cuda_types.CArray):
            if expr.attr == 'size':
                return Data(f'static_cast<long long>({value.code}.size())',
                            _cuda_types.Scalar('q'))
        if isinstance(value.ctype, _interface._Dim3):
            if expr.attr in ('x', 'y', 'z'):
                return Data(f'{value.code}.{expr.attr}', _cuda_types.uint32)
        # TODO(leofang): support arbitrary Python class methods
        if isinstance(value.ctype, _ThreadGroup):
            return _internal_types.BuiltinFunc.from_class_method(
                value.code, getattr(value.ctype, expr.attr))
        raise NotImplementedError('Not implemented: __getattr__')

    if isinstance(expr, ast.Tuple):
        elts = [_transpile_expr(x, env) for x in expr.elts]
        # TODO: Support compile time constants.
        elts = [Data.init(x, env) for x in elts]
        elts_code = ', '.join([x.code for x in elts])
        ctype = _cuda_types.Tuple([x.ctype for x in elts])
        return Data(f'thrust::make_tuple({elts_code})', ctype)

    # Pre-3.9 subscript wrapper node (deprecated alias in newer Pythons).
    if isinstance(expr, getattr(ast, 'Index', ())):
        return _transpile_expr(expr.value, env)

    raise ValueError('Not supported: type {}'.format(type(expr)))
Example #3
0
def _transpile_expr_internal(expr, env):
    """Transpile a Python AST expression node into a CUDA expression.

    Dispatches on the node type of ``expr``; sub-expressions are transpiled
    recursively through ``_transpile_expr``.

    Args:
        expr: An ``ast`` expression node.
        env: The transpilation environment. Names are looked up with
            ``env[name]`` (``None`` means unbound) and ``env`` is forwarded
            to every helper.

    Returns:
        Usually a ``Constant`` (value known at compile time) or a ``Data``
        (CUDA source snippet plus its C type). Builtin function calls return
        whatever ``BuiltinFunc.call`` returns.

    Raises:
        NotImplementedError: For lambdas, comparisons of 3+ operands,
            complex-valued conditions, calls through non-constant callees,
            unsupported function calls, tuple indexing and attribute access
            on non-constants.
        TypeError: For mismatched branch types, bad ufunc keywords, bad
            casts, non-subscriptable values and invalid array indices.
        IndexError: When a tuple index has the wrong number of elements.
        NameError: When a name is unbound in ``env``.
        ValueError: For AST node types this transpiler does not handle.
    """
    if isinstance(expr, ast.BoolOp):
        # `a and b and c` is folded left-to-right into binary operations.
        values = [_transpile_expr(e, env) for e in expr.values]
        value = values[0]
        for rhs in values[1:]:
            value = _eval_operand(expr.op, (value, rhs), env)
        return value
    if isinstance(expr, ast.BinOp):
        left = _transpile_expr(expr.left, env)
        right = _transpile_expr(expr.right, env)
        return _eval_operand(expr.op, (left, right), env)
    if isinstance(expr, ast.UnaryOp):
        value = _transpile_expr(expr.operand, env)
        return _eval_operand(expr.op, (value, ), env)
    if isinstance(expr, ast.Lambda):
        raise NotImplementedError('Not implemented.')
    if isinstance(expr, ast.Compare):
        values = [expr.left] + expr.comparators
        if len(values) != 2:
            raise NotImplementedError(
                'Comparison of 3 or more values is not implemented.')
        values = [_transpile_expr(e, env) for e in values]
        return _eval_operand(expr.ops[0], values, env)
    if isinstance(expr, ast.IfExp):
        cond = _transpile_expr(expr.test, env)
        x = _transpile_expr(expr.body, env)
        y = _transpile_expr(expr.orelse, env)

        # A condition known at compile time selects its branch statically.
        # The check must be on `cond`: `expr` is an ast.IfExp and is never
        # a Constant.
        if isinstance(cond, Constant):
            return x if cond.obj else y
        if cond.ctype.dtype.kind == 'c':
            raise NotImplementedError(
                'Complex type value cannot be boolean condition.')
        x = Data.init(x, env)
        y = Data.init(y, env)
        if x.ctype.dtype != y.ctype.dtype:
            raise TypeError('Type mismatch in conditional expression.: '
                            f'{x.ctype.dtype} != {y.ctype.dtype}')
        cond = _astype_scalar(cond, _cuda_types.bool_, 'unsafe', env)
        return Data(f'({cond.code} ? {x.code} : {y.code})', x.ctype)

    if isinstance(expr, ast.Call):
        func = _transpile_expr(expr.func, env)
        args = [_transpile_expr(x, env) for x in expr.args]
        kwargs = {kw.arg: _transpile_expr(kw.value, env)
                  for kw in expr.keywords}

        builtin_funcs = _builtin_funcs.builtin_functions_dict
        if is_constants(func) and (func.obj in builtin_funcs):
            func = builtin_funcs[func.obj]

        if isinstance(func, _internal_types.BuiltinFunc):
            return func.call(env, *args, **kwargs)

        if not is_constants(func):
            raise NotImplementedError(
                'device function call is not implemented.')

        func = func.obj

        if is_constants(*args, *kwargs.values()):
            # compile-time function call
            args = [x.obj for x in args]
            kwargs = {k: v.obj for k, v in kwargs.items()}
            return Constant(func(*args, **kwargs))

        if isinstance(func, _kernel.ufunc):
            # ufunc call
            dtype = kwargs.pop('dtype', Constant(None)).obj
            if len(kwargs) > 0:
                name = next(iter(kwargs))
                raise TypeError(
                    f"'{name}' is an invalid keyword to ufunc {func.name}")
            return _call_ufunc(func, args, dtype, env)

        if inspect.isclass(func) and issubclass(func, _typeclasses):
            # explicit typecast
            if len(args) != 1:
                raise TypeError(
                    f'function takes {func} invalid number of argument')
            ctype = _cuda_types.Scalar(func)
            return _astype_scalar(args[0], ctype, 'unsafe', env)

        # Not every callable has `__name__` (e.g. functools.partial); fall
        # back to the object itself so the intended error is still raised.
        func_name = getattr(func, '__name__', func)
        raise NotImplementedError(
            f'function call of `{func_name}` is not implemented')

    if isinstance(expr, ast.Constant):
        return Constant(expr.value)
    # Pre-3.8 literal nodes. They are fetched with getattr because these
    # deprecated aliases were removed from `ast` in Python 3.14; there the
    # empty tuple makes the isinstance check simply fail.
    if isinstance(expr, getattr(ast, 'Num', ())):
        return Constant(expr.n)
    if isinstance(expr, getattr(ast, 'Str', ())):
        return Constant(expr.s)
    if isinstance(expr, getattr(ast, 'NameConstant', ())):
        return Constant(expr.value)

    if isinstance(expr, ast.Subscript):
        value = _transpile_expr(expr.value, env)
        index = _transpile_expr(expr.slice, env)

        if is_constants(value):
            if is_constants(index):
                return Constant(value.obj[index.obj])
            raise TypeError(
                f'{type(value.obj)} is not subscriptable with non-constants.')

        value = Data.init(value, env)

        if isinstance(value.ctype, _cuda_types.Tuple):
            raise NotImplementedError('Tuple indexing is not implemented.')

        if isinstance(value.ctype, _cuda_types.ArrayBase):
            index = Data.init(index, env)
            ndim = value.ctype.ndim
            if isinstance(index.ctype, _cuda_types.Scalar):
                index_dtype = index.ctype.dtype
                if ndim != 1:
                    raise TypeError(
                        'Scalar indexing is supported only for 1-dim array.')
                if index_dtype.kind not in 'ui':
                    raise TypeError('Array indices must be integers.')
                return Data(f'{value.code}[{index.code}]',
                            value.ctype.child_type)
            if isinstance(index.ctype, _cuda_types.Tuple):
                if ndim != len(index.ctype.types):
                    raise IndexError(f'The size of index must be {ndim}')
                for t in index.ctype.types:
                    if not isinstance(t, _cuda_types.Scalar):
                        raise TypeError('Array indices must be scalar.')
                    if t.dtype.kind not in 'iu':
                        raise TypeError('Array indices must be integer.')
                # ndim == 0 here means an empty-tuple index.
                if ndim == 0:
                    return Data(f'{value.code}[0]', value.ctype.child_type)
                if ndim == 1:
                    return Data(f'{value.code}[thrust::get<0>({index.code})]',
                                value.ctype.child_type)
                return Data(f'{value.code}._indexing({index.code})',
                            value.ctype.child_type)
            if isinstance(index.ctype, _cuda_types.Array):
                raise TypeError('Advanced indexing is not supported.')
            # Reject any other index type explicitly; an `assert` would be
            # stripped under `python -O` and fall through to a misleading
            # "not subscriptable" error below.
            raise TypeError(f'Unsupported index type: {index.ctype}')

        raise TypeError(f'{value.code} is not subscriptable.')

    if isinstance(expr, ast.Name):
        value = env[expr.id]
        if value is None:
            raise NameError(f'Unbound name: {expr.id}')
        return value
    if isinstance(expr, ast.Attribute):
        value = _transpile_expr(expr.value, env)
        if is_constants(value):
            return Constant(getattr(value.obj, expr.attr))
        raise NotImplementedError('Not implemented: __getattr__')

    if isinstance(expr, ast.Tuple):
        elts = [_transpile_expr(x, env) for x in expr.elts]
        # TODO: Support compile time constants.
        elts = [Data.init(x, env) for x in elts]
        elts_code = ', '.join([x.code for x in elts])
        ctype = _cuda_types.Tuple([x.ctype for x in elts])
        return Data(f'thrust::make_tuple({elts_code})', ctype)

    # Pre-3.9 subscript wrapper node (deprecated alias in newer Pythons).
    if isinstance(expr, getattr(ast, 'Index', ())):
        return _transpile_expr(expr.value, env)

    raise ValueError('Not supported: type {}'.format(type(expr)))