Ejemplo n.º 1
0
Archivo: dot.py Proyecto: thobauma/dace
    def expansion(node, parent_state, parent_sdfg, n=None, **kwargs):
        """Build a C++ tasklet that calls the cuBLAS ``dot`` routine.

        :param node: Library node being expanded; supplies ``validate``,
                     ``n``, ``name`` and ``in_connectors``.
        :param parent_state: State forwarded to ``node.validate``.
        :param parent_sdfg: SDFG forwarded to ``node.validate``.
        :param n: Optional element count. When falsy, ``node.n`` and then the
                  size reported by ``node.validate`` are tried in that order.
        :param kwargs: Accepted for interface compatibility; not read here.
        :return: The tasklet wrapping the generated call. Its single output
                 connector ``_result`` is typed as a pointer to the base type
                 of the ``x`` operand.
        """
        x_info, y_info, desc_res, validated_size = node.validate(
            parent_sdfg, parent_state)
        desc_x, stride_x = x_info
        desc_y, stride_y = y_info

        x_type = desc_x.dtype
        dtype = x_type.base_type
        veclen = x_type.veclen

        # Only the first metadata entry (the routine-name prefix) is needed.
        prefix, _, _ = blas_helpers.cublas_type_metadata(dtype)
        routine = prefix + 'dot'

        # Resolve the element count: explicit argument, node property, then
        # the validated size.
        if not n:
            n = node.n or validated_size
        # For vector types the count is divided by the vector length.
        if veclen != 1:
            n /= veclen

        setup = environments.cublas.cuBLAS.handle_setup_code(node)
        call = f"""cublas{routine}(__dace_cublas_handle, {n}, _x, {stride_x}, _y, 
                             {stride_y}, _result);"""

        return dace.sdfg.nodes.Tasklet(
            node.name,
            node.in_connectors,
            {'_result': dtypes.pointer(dtype)},
            setup + call,
            language=dace.dtypes.Language.CPP,
        )
Ejemplo n.º 2
0
    def expansion(node, parent_state, parent_sdfg, n=None, **kwargs):
        """Build a C++ tasklet that calls the cuSOLVER ``getrf`` routine.

        The generated code queries the workspace size, allocates the
        workspace with ``cudaMalloc``, runs the routine and frees the
        workspace again.

        :param node: Library node being expanded; supplies ``validate``,
                     ``n``, ``name`` and its connectors.
        :param parent_state: State forwarded to ``node.validate``.
        :param parent_sdfg: SDFG forwarded to ``node.validate``.
        :param n: Optional size override; falls back to ``node.n``.
        :param kwargs: Accepted for interface compatibility; not read here.
        :return: The tasklet wrapping the generated code, with the ``_res``
                 output connector retyped to a pointer to ``dace.int32``.
        """
        x_info, desc_ipiv, desc_result = node.validate(parent_sdfg,
                                                       parent_state)
        desc_x, stride_x, rows_x, cols_x = x_info

        x_type = desc_x.dtype
        dtype = x_type.base_type
        veclen = x_type.veclen

        prefix, cuda_type, _ = blas_helpers.cublas_type_metadata(dtype)
        routine = prefix + 'getrf'

        # NOTE: ``n`` is resolved here but is not interpolated into the
        # generated code below.
        if not n:
            n = node.n
        if veclen != 1:
            n /= veclen

        setup = environments.cusolverdn.cuSolverDn.handle_setup_code(node)
        body = f"""
                int __dace_workspace_size = 0;
                {cuda_type}* __dace_workspace;
                cusolverDn{routine}_bufferSize(
                    __dace_cusolverDn_handle, {rows_x}, {cols_x}, ({cuda_type}*)_xin,
                    {stride_x}, &__dace_workspace_size);
                cudaMalloc<{cuda_type}>(
                    &__dace_workspace,
                    sizeof({cuda_type}) * __dace_workspace_size);
                cusolverDn{routine}(
                    __dace_cusolverDn_handle, {rows_x}, {cols_x}, ({cuda_type}*)_xin,
                    {stride_x}, __dace_workspace, _ipiv, _res);
                cudaFree(__dace_workspace);
                """

        tasklet = dace.sdfg.nodes.Tasklet(
            node.name,
            node.in_connectors,
            node.out_connectors,
            setup + body,
            language=dace.dtypes.Language.CPP,
        )

        # Rebuild the output connectors, replacing only the type of ``_res``.
        retyped = {}
        for conn_name, conn_type in tasklet.out_connectors.items():
            if conn_name == '_res':
                conn_type = dtypes.pointer(dace.int32)
            retyped[conn_name] = conn_type
        tasklet.out_connectors = retyped

        return tasklet
Ejemplo n.º 3
0
    def expansion(node, parent_state, parent_sdfg, n=None, **kwargs):
        """Build a C++ tasklet that calls the CBLAS ``dot`` routine.

        :param node: Library node being expanded; supplies ``validate``,
                     ``n``, ``name`` and its connectors.
        :param parent_state: State forwarded to ``node.validate``.
        :param parent_sdfg: SDFG forwarded to ``node.validate``.
        :param n: Optional element count. When falsy, ``node.n`` and then the
                  size reported by ``node.validate`` are tried in that order.
        :param kwargs: Accepted for interface compatibility; not read here.
        :return: The tasklet assigning the routine's return value to
                 ``_result``.
        """
        x_info, y_info, desc_res, validated_size = node.validate(
            parent_sdfg, parent_state)
        desc_x, stride_x = x_info
        desc_y, stride_y = y_info

        x_type = desc_x.dtype
        dtype = x_type.base_type
        veclen = x_type.veclen

        # The prefix from the cuBLAS metadata is lower-cased to form the
        # CBLAS routine name.
        prefix, _, _ = blas_helpers.cublas_type_metadata(dtype)
        routine = prefix.lower() + 'dot'

        if not n:
            n = node.n or validated_size
        # For vector types the count is divided by the vector length.
        if veclen != 1:
            n /= veclen

        code = f"_result = cblas_{routine}({n}, _x, {stride_x}, _y, {stride_y});"
        return dace.sdfg.nodes.Tasklet(
            node.name,
            node.in_connectors,
            node.out_connectors,
            code,
            language=dace.dtypes.Language.CPP,
        )
Ejemplo n.º 4
0
    def expansion(node, parent_state, parent_sdfg, n=None, **kwargs):
        """Build a C++ tasklet that calls the cuSOLVER ``getrs`` routine.

        :param node: Library node being expanded; supplies ``validate``,
                     ``n``, ``name`` and its connectors.
        :param parent_state: State forwarded to ``node.validate``.
        :param parent_sdfg: SDFG forwarded to ``node.validate``.
        :param n: Optional size override; falls back to ``node.n``.
        :param kwargs: Accepted for interface compatibility; not read here.
        :return: The tasklet wrapping the generated code, with the ``_res``
                 output connector retyped to a pointer to ``dace.int32``.
        """
        a_info, rhs_info, desc_ipiv, desc_res = node.validate(
            parent_sdfg, parent_state)
        desc_a, stride_a, rows_a, cols_a = a_info
        desc_rhs, stride_rhs, rows_rhs, cols_rhs = rhs_info

        a_type = desc_a.dtype
        dtype = a_type.base_type
        veclen = a_type.veclen

        prefix, cuda_type, _ = blas_helpers.cublas_type_metadata(dtype)
        routine = prefix + 'getrs'

        # NOTE: ``n`` is resolved here but is not interpolated into the
        # generated code below.
        if not n:
            n = node.n
        if veclen != 1:
            n /= veclen

        # NOTE: In the case where the RHS is only a single vector (1D array),
        # cuSOLVER still expects ldb to be the "number of rows"
        ldb = rows_rhs if len(desc_rhs.shape) == 1 else stride_rhs

        setup = environments.cusolverdn.cuSolverDn.handle_setup_code(node)
        body = f"""
                cusolverDn{routine}(
                    __dace_cusolverDn_handle, CUBLAS_OP_N, {rows_a}, {cols_rhs},
                    ({cuda_type}*)_a, {stride_a}, _ipiv, ({cuda_type}*)_rhs_in, {ldb}, _res); 
                """

        tasklet = dace.sdfg.nodes.Tasklet(
            node.name,
            node.in_connectors,
            node.out_connectors,
            setup + body,
            language=dace.dtypes.Language.CPP,
        )

        # Rebuild the output connectors, replacing only the type of ``_res``.
        retyped = {}
        for conn_name, conn_type in tasklet.out_connectors.items():
            if conn_name == '_res':
                conn_type = dtypes.pointer(dace.int32)
            retyped[conn_name] = conn_type
        tasklet.out_connectors = retyped

        return tasklet
Ejemplo n.º 5
0
    def expansion(node, state, sdfg, **kwargs):
        """Build a C++ tasklet that calls the cuBLAS ``geam`` routine.

        The call is emitted with ``CUBLAS_OP_T`` on the first operand, the
        "plus one" runtime constant as alpha and the "zero" runtime constant
        as beta, reading from ``_inp`` and writing to ``_out``.

        :param node: Library node being expanded; supplies ``validate``,
                     ``dtype``, ``name`` and its connectors.
        :param state: State forwarded to ``node.validate`` and to
                      ``_get_transpose_input``.
        :param sdfg: SDFG forwarded to ``node.validate`` and to
                     ``_get_transpose_input``.
        :param kwargs: Accepted for interface compatibility; not read here.
        :return: The tasklet wrapping the generated call.
        """
        node.validate(sdfg, state)

        prefix, cdtype, factort = blas_helpers.cublas_type_metadata(node.dtype)
        routine = prefix + 'geam'

        # Scaling factors come from the constants held by the cuBLAS handle.
        alpha = f"__state->cublas_handle.Constants(__dace_cuda_device).{factort}Pone()"
        beta = f"__state->cublas_handle.Constants(__dace_cuda_device).{factort}Zero()"

        # Only the third element of the helper's result (the pair of sizes)
        # is used.
        _, _, sizes = _get_transpose_input(node, state, sdfg)
        m, n = sizes

        setup = environments.cublas.cuBLAS.handle_setup_code(node)
        call = f"""cublas{routine}(
                    __dace_cublas_handle, CUBLAS_OP_T, CUBLAS_OP_N,
                    {m}, {n}, {alpha}, ({cdtype}*)_inp, {n}, {beta}, ({cdtype}*)_inp, {m}, ({cdtype}*)_out, {m});
                """

        return dace.sdfg.nodes.Tasklet(
            node.name,
            node.in_connectors,
            node.out_connectors,
            setup + call,
            language=dace.dtypes.Language.CPP,
        )
Ejemplo n.º 6
0
Archivo: gemv.py Proyecto: mfkiwl/dace
    def expansion(node: 'Gemv', state, sdfg, m=None, n=None, **kwargs):
        """Build a C++ tasklet that calls the CBLAS ``gemv`` routine.

        The expansion is delegated to ``ExpandGemvPure`` in three cases: the
        node is reported as GPU device-level, neither of the first two strides
        of ``_A`` equals 1, or the ``_x`` operand has a vector length other
        than 1.

        :param node: The Gemv library node being expanded.
        :param state: State containing the node.
        :param sdfg: SDFG containing the state.
        :param m: Optional size override; falls back to ``node.m`` and then
                  to the first dimension of ``_y``.
        :param n: Optional size override; falls back to ``node.n`` and then
                  to the first dimension of ``_x``.
        :param kwargs: Forwarded to ``ExpandGemvPure.expansion`` on the
                       stride and vector-length fallbacks.
        :return: The tasklet wrapping the generated call, or the result of
                 the pure expansion on fallback.
        """
        from dace.sdfg.scope import is_devicelevel_gpu
        if is_devicelevel_gpu(sdfg, state, node):
            return ExpandGemvPure.expansion(node, state, sdfg)

        node.validate(sdfg, state)

        operand_a, operand_x, operand_y = _get_matmul_operands(
            node, state, sdfg, name_lhs="_A", name_rhs="_x", name_out="_y")
        edge_a, outer_array_a, shape_a, strides_a = operand_a
        edge_x, outer_array_x, shape_x, strides_x = operand_x
        edge_y, outer_array_y, shape_y, strides_y = operand_y

        # Not used further below; kept so the attribute is still read.
        dtype_a = outer_array_a.dtype.type
        dtype = outer_array_x.dtype.base_type
        veclen = outer_array_x.dtype.veclen

        # Resolve sizes: explicit argument, node property, operand shape.
        m = m or node.m
        n = n or node.n
        if m is None:
            m = shape_y[0]
        if n is None:
            n = shape_x[0]

        # Pick the leading dimension from whichever stride of A is not the
        # unit one; a unit stride in dimension 0 also flips the flag.
        transA = node.transA
        if strides_a[0] == 1:
            transA, lda = (not transA), strides_a[1]
        elif strides_a[1] == 1:
            lda = strides_a[0]
        else:
            warnings.warn('Matrix must be contiguous in at least '
                          'one dimension. Falling back to pure expansion.')
            return ExpandGemvPure.expansion(node, state, sdfg, m=m, n=n,
                                            **kwargs)

        layout = 'CblasColMajor'
        if transA:
            trans = 'CblasNoTrans'
        else:
            trans = 'CblasTrans'
        if not node.transA:
            m, n = n, m

        if veclen != 1:
            warnings.warn('Vector GEMV not supported, falling back to pure.')
            return ExpandGemvPure.expansion(node, state, sdfg, m=m, n=n,
                                            **kwargs)

        # The prefix from the cuBLAS metadata is lower-cased to form the
        # CBLAS routine name.
        prefix, _, _ = blas_helpers.cublas_type_metadata(dtype)
        func = prefix.lower() + 'gemv'

        code = f"""cblas_{func}({layout}, {trans}, {m}, {n}, {node.alpha}, _A, {lda},
                                _x, {strides_x[0]}, {node.beta}, _y, {strides_y[0]});"""

        return dace.sdfg.nodes.Tasklet(
            node.name,
            node.in_connectors,
            node.out_connectors,
            code,
            language=dace.dtypes.Language.CPP,
        )
Ejemplo n.º 7
0
Archivo: gemv.py Proyecto: mfkiwl/dace
    def expansion(node: 'Gemv', state, sdfg, m=None, n=None, **kwargs):
        node.validate(sdfg, state)

        ((edge_a, outer_array_a, shape_a, strides_a), (edge_x, outer_array_x,
                                                       shape_x, strides_x),
         (edge_y, outer_array_y, shape_y,
          strides_y)) = _get_matmul_operands(node,
                                             state,
                                             sdfg,
                                             name_lhs="_A",
                                             name_rhs="_x",
                                             name_out="_y")
        dtype_a = outer_array_a.dtype.type
        dtype = outer_array_x.dtype.base_type
        veclen = outer_array_x.dtype.veclen
        m = m or node.m
        n = n or node.n
        if m is None:
            m = shape_y[0]
        if n is None:
            n = shape_x[0]

        transA = node.transA
        if strides_a[0] == 1:
            transA = not transA
            lda = strides_a[1]
        elif strides_a[1] == 1:
            lda = strides_a[0]
        else:
            warnings.warn('Matrix must be contiguous in at least '
                          'one dimension. Falling back to pure expansion.')
            return ExpandGemvPure.expansion(node,
                                            state,
                                            sdfg,
                                            m=m,
                                            n=n,
                                            **kwargs)

        trans = 'CUBLAS_OP_N' if transA else 'CUBLAS_OP_T'
        if not node.transA:
            m, n = n, m

        if veclen != 1:
            warnings.warn('Vector GEMV not supported, falling back to pure')
            return ExpandGemvPure.expansion(node,
                                            state,
                                            sdfg,
                                            m=m,
                                            n=n,
                                            **kwargs)

        func, ctype, runtimetype = blas_helpers.cublas_type_metadata(dtype)
        func += 'gemv'
        call_prefix = environments.cublas.cuBLAS.handle_setup_code(node)
        call_suffix = ''

        # Handle alpha / beta
        constants = {
            1.0:
            f"__state->cublas_handle.Constants(__dace_cuda_device).{runtimetype}Pone()",
            0.0:
            f"__state->cublas_handle.Constants(__dace_cuda_device).{runtimetype}Zero()",
        }
        if node.alpha not in constants or node.beta not in constants:
            # Deal with complex input constants
            if isinstance(node.alpha, complex):
                alpha = f'{dtype.ctype}({node.alpha.real}, {node.alpha.imag})'
            else:
                alpha = f'{dtype.ctype}({node.alpha})'
            if isinstance(node.beta, complex):
                beta = f'{dtype.ctype}({node.beta.real}, {node.beta.imag})'
            else:
                beta = f'{dtype.ctype}({node.beta})'

            # Set pointer mode to host
            call_prefix += f'''cublasSetPointerMode(__dace_cublas_handle, CUBLAS_POINTER_MODE_HOST);
            {dtype.ctype} alpha = {alpha};
            {dtype.ctype} beta = {beta};
            '''
            call_suffix += '''
cublasSetPointerMode(__dace_cublas_handle, CUBLAS_POINTER_MODE_DEVICE);
            '''
            alpha = f'({ctype} *)&alpha'
            beta = f'({ctype} *)&beta'
        else:
            alpha = constants[node.alpha]
            beta = constants[node.beta]

        code = (call_prefix + f"""
cublas{func}(__dace_cublas_handle, {trans}, {m}, {n}, {alpha}, _A, {lda},
             _x, {strides_x[0]}, {beta}, _y, {strides_y[0]});
                """ + call_suffix)

        tasklet = dace.sdfg.nodes.Tasklet(node.name,
                                          node.in_connectors,
                                          node.out_connectors,
                                          code,
                                          language=dace.dtypes.Language.CPP)

        return tasklet
Ejemplo n.º 8
0
    def expansion(node: 'Gemv', state, sdfg, m=None, n=None, **kwargs):
        node.validate(sdfg, state)

        ((edge_a, outer_array_a, shape_a, strides_a), (edge_x, outer_array_x,
                                                       shape_x, strides_x),
         (edge_y, outer_array_y, shape_y,
          strides_y)) = _get_matmul_operands(node,
                                             state,
                                             sdfg,
                                             name_lhs="_A",
                                             name_rhs="_x",
                                             name_out="_y")
        dtype_a = outer_array_a.dtype.type
        dtype = outer_array_x.dtype.base_type
        veclen = outer_array_x.dtype.veclen
        m = m or node.m
        n = n or node.n
        if m is None:
            m = shape_y[0]
        if n is None:
            n = shape_x[0]

        transA = node.transA
        if strides_a[0] == 1:
            transA = not transA
            lda = strides_a[1]
        elif strides_a[1] == 1:
            lda = strides_a[0]
        else:
            warnings.warn('Matrix must be contiguous in at least '
                          'one dimension. Falling back to pure expansion.')
            return ExpandGemvPure.expansion(node,
                                            state,
                                            sdfg,
                                            m=m,
                                            n=n,
                                            **kwargs)

        trans = 'CUBLAS_OP_N' if transA else 'CUBLAS_OP_T'
        if not node.transA:
            m, n = n, m

        if veclen != 1:
            warnings.warn('Vector GEMV not supported, falling back to pure')
            return ExpandGemvPure.expansion(node,
                                            state,
                                            sdfg,
                                            m=m,
                                            n=n,
                                            **kwargs)

        func, ctype, runtimetype = blas_helpers.cublas_type_metadata(dtype)
        func += 'gemv'

        # TODO: (alpha,beta) != (1,0)
        if node.alpha != 1.0 or node.beta != 0.0:
            raise NotImplementedError
        alpha = (
            '__state->cublas_handle.Constants(__dace_cuda_device).%sPone()' %
            runtimetype)
        beta = (
            '__state->cublas_handle.Constants(__dace_cuda_device).%sZero()' %
            runtimetype)

        code = (environments.cublas.cuBLAS.handle_setup_code(node) + f"""
cublas{func}(__dace_cublas_handle, {trans}, {m}, {n}, {alpha}, _A, {lda},
             _x, {strides_x[0]}, {beta}, _y, {strides_y[0]});""")

        tasklet = dace.sdfg.nodes.Tasklet(node.name,
                                          node.in_connectors,
                                          node.out_connectors,
                                          code,
                                          language=dace.dtypes.Language.CPP)

        return tasklet