コード例 #1
0
ファイル: gemv.py プロジェクト: mfkiwl/dace
    def expansion(node: 'Gemv', state, sdfg, m=None, n=None, **kwargs):
        """Expand a ``Gemv`` node into a distributed, PBLAS-style SDFG.

        The matrix ``_A`` and the vector ``_x`` are distributed with
        ``dace.comm.BCScatter`` (block-cyclic scatter) using block sizes
        derived from the ``Px``/``Py`` symbols, multiplied with
        ``distr.MatMult``, and the result is collected into ``_y`` with
        ``dace.comm.BCGather``.

        :param node: The ``Gemv`` library node to expand.
        :param state: The state that contains ``node``.
        :param sdfg: The SDFG that contains ``state``; the ``Px`` and ``Py``
                     symbols are added to it when not already present.
        :param m: Optional size override; falls back to ``node.m`` and then
                  to the length of the output vector ``_y``.
        :param n: Optional size override; falls back to ``node.n`` and then
                  to the length of the input vector ``_x``.
        :param kwargs: Ignored.
        :return: A new SDFG, simplified after construction, implementing the
                 node (``_gemTv_pblas`` if ``node.transA`` else
                 ``_gemNv_pblas``).
        """
        node.validate(sdfg, state)
        ((edge_a, outer_array_a, shape_a, strides_a), (edge_x, outer_array_x,
                                                       shape_x, strides_x),
         (edge_y, outer_array_y, shape_y,
          strides_y)) = _get_matmul_operands(node,
                                             state,
                                             sdfg,
                                             name_lhs="_A",
                                             name_rhs="_x",
                                             name_out="_y")
        # Element type used to annotate the arguments of the nested programs.
        dtype = outer_array_x.dtype.base_type
        m = m or node.m
        n = n or node.n
        if m is None:
            m = shape_y[0]
        if n is None:
            n = shape_x[0]

        transA = node.transA

        Px = dace.symbol('Px', dtype=dace.int32, integer=True, positive=True)
        Py = dace.symbol('Py', dtype=dace.int32, integer=True, positive=True)
        # Register each symbol on its own: with a single try-block around
        # both calls, an already-existing 'Px' would raise before 'Py' is
        # ever added.
        for sym_name in ('Px', 'Py'):
            try:
                sdfg.add_symbol(sym_name, dace.int32)
            except FileExistsError:
                pass  # Symbol already declared on this SDFG.

        @dace.program
        def _gemNv_pblas(_A: dtype[m, n], _x: dtype[n], _y: dtype[m]):
            lA = np.empty((m // Px, n // Py), dtype=_A.dtype)
            lx = np.empty((n // Px, ), dtype=_x.dtype)
            dace.comm.BCScatter(_A, lA, (m // Px, n // Py))
            dace.comm.BCScatter(_x, lx, (n // Px, 1))
            ly = distr.MatMult(_A, _x, lA, lx, (m // Px, n // Py),
                               (n // Px, 1))
            dace.comm.BCGather(ly, _y, (m // Px, 1))

        # NOTE(review): when m and n are inferred above (m = len(_y),
        # n = len(_x)), this signature declares _x with length m and _y with
        # length n, i.e. swapped relative to that inference -- confirm the
        # intended size convention for the transposed case.
        @dace.program
        def _gemTv_pblas(_A: dtype[m, n], _x: dtype[m], _y: dtype[n]):
            lA = np.empty((m // Px, n // Py), dtype=_A.dtype)
            lx = np.empty((m // Px, ), dtype=_x.dtype)
            dace.comm.BCScatter(_A, lA, (m // Px, n // Py))
            dace.comm.BCScatter(_x, lx, (m // Px, 1))
            ly = distr.MatMult(_x, _A, lx, lA, (m // Px, 1),
                               (m // Px, n // Py))
            dace.comm.BCGather(ly, _y, (n // Px, 1))

        # NOTE: The following is done to avoid scalar promotion, which results
        # in ValueError: Node type "BlockCyclicScatter" not supported for
        # promotion
        if transA:
            sdfg = _gemTv_pblas.to_sdfg(simplify=False)
        else:
            sdfg = _gemNv_pblas.to_sdfg(simplify=False)
        sdfg.simplify()
        return sdfg
コード例 #2
0
    def expansion(node, state, sdfg):
        """Expand a GEMM node into a distributed, PBLAS-style SDFG.

        Both input matrices are distributed with ``dace.comm.BCScatter``
        using block sizes derived from the ``Px``/``Py`` symbols, multiplied
        with ``distr.MatMult``, and the result is collected into ``_c`` with
        ``dace.comm.BCGather``.

        :param node: The GEMM library node to expand.
        :param state: The state that contains ``node``.
        :param sdfg: The SDFG that contains ``state``; the ``Px`` and ``Py``
                     symbols are added to it when not already present.
        :return: The SDFG generated from the nested ``_gemm_pblas`` program.
        :raises NotImplementedError: If ``node.beta`` is non-zero.
        """
        node.validate(sdfg, state)
        (_, adesc, ashape, _), (_, _, bshape, _), _ = _get_matmul_operands(
            node, state, sdfg)
        dtype = adesc.dtype.base_type

        if node.beta != 0:
            raise NotImplementedError(
                'PBLAS GEMM expansion only supports beta == 0')

        # NOTE(review): node.alpha (and any transposition flags on the node)
        # are not consulted by this expansion -- confirm it is only used for
        # a plain C = A @ B product.
        M = ashape[0]
        K = ashape[1]
        N = bshape[1]
        Px = dace.symbol('Px', dtype=dace.int32, integer=True, positive=True)
        Py = dace.symbol('Py', dtype=dace.int32, integer=True, positive=True)
        # Register each symbol on its own: with a single try-block around
        # both calls, an already-existing 'Px' would raise before 'Py' is
        # ever added.
        for sym_name in ('Px', 'Py'):
            try:
                sdfg.add_symbol(sym_name, dace.int32)
            except FileExistsError:
                pass  # Symbol already declared on this SDFG.

        @dace.program
        def _gemm_pblas(_a: dtype[M, K], _b: dtype[K, N], _c: dtype[M, N]):
            lA = np.empty((M // Px, K // Py), dtype=_a.dtype)
            lB = np.empty((K // Px, N // Py), dtype=_b.dtype)
            dace.comm.BCScatter(_a, lA, (M // Px, K // Py))
            dace.comm.BCScatter(_b, lB, (K // Px, N // Py))
            lC = distr.MatMult(_a, _b, lA, lB, (M // Px, K // Py),
                               (K // Px, N // Py))
            dace.comm.BCGather(lC, _c, (M // Px, N // Py))

        return _gemm_pblas.to_sdfg()
コード例 #3
0
    def expansion(node, state, sdfg):
        """Lower a GEMM node to a single CBLAS ``?gemm`` call in a C++ tasklet.

        ``node.alpha`` and ``node.beta`` are embedded as C++ constructor
        expressions of the element type: ``ctype(value)``, or
        ``ctype(real, imag)`` for complex values. For complex element types
        the constants are first stored in local variables and passed to CBLAS
        by address.

        :param node: The GEMM library node to expand.
        :param state: The state that contains ``node``.
        :param sdfg: The SDFG that contains ``state``.
        :return: A C++ ``Tasklet`` wrapping the ``cblas_?gemm`` call.
        """
        node.validate(sdfg, state)
        (_, a_desc, _, _), (_, b_desc, _, _), _ = _get_matmul_operands(
            node, state, sdfg)
        elem_type = a_desc.dtype.base_type
        ctype = elem_type.ctype
        blas_func = to_blastype(elem_type.type).lower() + 'gemm'

        def as_cpp_constant(value):
            # Complex values need both components for the C++ constructor.
            if isinstance(value, complex):
                return f'{ctype}({value.real}, {value.imag})'
            return f'{ctype}({value})'

        alpha = as_cpp_constant(node.alpha)
        beta = as_cpp_constant(node.beta)

        c_desc = sdfg.arrays[state.out_edges(node)[0].data.data]
        check_access(dtypes.ScheduleType.CPU_Multicore, a_desc, b_desc, c_desc)

        opt = _get_codegen_gemm_opts(node, state, sdfg, a_desc, b_desc, c_desc,
                                     alpha, beta, ctype, blas_func)

        # The CBLAS interface takes transposition enums, not 'N'/'T' flags.
        for flag in ('ta', 'tb'):
            opt[flag] = 'CblasNoTrans' if opt[flag] == 'N' else 'CblasTrans'

        prelude = ''
        if elem_type in (dace.complex64, dace.complex128):
            # Complex scalars are passed by pointer, so declare them as locals.
            prelude = (f'\n            {ctype} alpha = {alpha};'
                       f'\n            {ctype} beta = {beta};'
                       '\n            ')
            opt['alpha'] = '&alpha'
            opt['beta'] = '&beta'

        gemm_call = ("cblas_{func}(CblasColMajor, {ta}, {tb}, "
                     "{M}, {N}, {K}, {alpha}, {x}, {lda}, {y}, {ldb}, {beta}, "
                     "_c, {ldc});").format_map(opt)

        return dace.sdfg.nodes.Tasklet(
            node.name,
            node.in_connectors,
            node.out_connectors,
            prelude + gemm_call,
            language=dace.dtypes.Language.CPP,
        )
コード例 #4
0
ファイル: batched_matmul.py プロジェクト: mfkiwl/dace
    def expansion(node, state, sdfg):
        """Lower a batched matrix multiplication to a loop of CBLAS ``?gemm`` calls.

        One ``cblas_?gemm`` call is emitted per batch entry. Each call offsets
        the operand pointers by the ``stride_a``/``stride_b``/``stride_c``
        values from the options dictionary returned by
        ``_get_codegen_gemm_opts``. The product is computed with unit/zero
        scaling: ``1``/``0`` for real types, and the corresponding
        ``BlasConstants`` values for complex types.

        :param node: The batched matmul library node to expand.
        :param state: The state that contains ``node``.
        :param sdfg: The SDFG that contains ``state``.
        :return: A C++ ``Tasklet`` performing the batched multiplication.
        :raises ValueError: If the output element type is not float32,
                            float64, complex64 or complex128 (unless
                            ``to_blastype`` rejects it first).
        """
        node.validate(sdfg, state)
        (_, adesc, _, _), (_, bdesc, _, _), _ = _get_matmul_operands(
            node, state, sdfg)
        cdesc = sdfg.arrays[state.out_edges(node)[0].data.data]
        check_access(dtypes.ScheduleType.CPU_Multicore, adesc, bdesc, cdesc)
        dtype = cdesc.dtype.base_type
        func = to_blastype(dtype.type).lower() + 'gemm'
        if dtype == dace.float32:
            alpha = "1.0f"
            beta = "0.0f"
        elif dtype == dace.float64:
            alpha = "1.0"
            beta = "0.0"
        elif dtype == dace.complex64:
            alpha = "dace::blas::BlasConstants::Get().Complex64Pone()"
            beta = "dace::blas::BlasConstants::Get().Complex64Zero()"
        elif dtype == dace.complex128:
            alpha = "dace::blas::BlasConstants::Get().Complex128Pone()"
            beta = "dace::blas::BlasConstants::Get().Complex128Zero()"
        else:
            raise ValueError(
                "Unsupported type for BLAS batched matrix multiplication: " +
                str(dtype))
        opt = _get_codegen_gemm_opts(node, state, sdfg, adesc, bdesc, cdesc,
                                     alpha, beta, cdesc.dtype.ctype, func)

        # Adaptations for MKL/BLAS API: CBLAS takes transposition enums.
        opt['ta'] = 'CblasNoTrans' if opt['ta'] == 'N' else 'CblasTrans'
        opt['tb'] = 'CblasNoTrans' if opt['tb'] == 'N' else 'CblasTrans'

        # One GEMM per batch entry; pointers are advanced by per-batch strides.
        code = '''
        for (int __ib = 0; __ib < {BATCH}; ++__ib) {{
            cblas_{func}(CblasColMajor, {ta}, {tb}, {M}, {N}, {K}, {alpha},
                         (({dtype}*){x}) + __ib*{stride_a}, {lda},
                         (({dtype}*){y}) + __ib*{stride_b}, {ldb},
                         {beta},
                         (({dtype}*)_c) + __ib*{stride_c}, {ldc});
        }}'''.format_map(opt)

        tasklet = dace.sdfg.nodes.Tasklet(node.name,
                                          node.in_connectors,
                                          node.out_connectors,
                                          code,
                                          language=dace.dtypes.Language.CPP)
        return tasklet
コード例 #5
0
    def expansion(node, state, sdfg):
        """Lower a GEMM node to a single CBLAS ``?gemm`` call in a C++ tasklet.

        The generated call always uses unit/zero scaling: ``1``/``0`` for
        real types, and the corresponding ``BlasConstants`` values for complex
        types. ``node.alpha`` and ``node.beta`` are not read.

        :param node: The GEMM library node to expand.
        :param state: The state that contains ``node``.
        :param sdfg: The SDFG that contains ``state``.
        :return: A C++ ``Tasklet`` wrapping the ``cblas_?gemm`` call.
        :raises ValueError: If the element type of ``A`` is not float32,
                            float64, complex64 or complex128 (unless
                            ``to_blastype`` rejects it first).
        """
        node.validate(sdfg, state)
        (_, adesc, _, _), (_, bdesc, _, _), _ = _get_matmul_operands(
            node, state, sdfg)
        dtype = adesc.dtype.base_type
        func = to_blastype(dtype.type).lower() + 'gemm'
        # TODO: honor node.alpha / node.beta instead of the fixed 1 and 0.
        if dtype == dace.float32:
            alpha = "1.0f"
            beta = "0.0f"
        elif dtype == dace.float64:
            alpha = "1.0"
            beta = "0.0"
        elif dtype == dace.complex64:
            alpha = "dace::blas::BlasConstants::Get().Complex64Pone()"
            beta = "dace::blas::BlasConstants::Get().Complex64Zero()"
        elif dtype == dace.complex128:
            alpha = "dace::blas::BlasConstants::Get().Complex128Pone()"
            beta = "dace::blas::BlasConstants::Get().Complex128Zero()"
        else:
            raise ValueError(
                "Unsupported type for BLAS matrix multiplication: " +
                str(dtype))
        cdesc = sdfg.arrays[state.out_edges(node)[0].data.data]
        opt = _get_codegen_gemm_opts(node, state, sdfg, adesc, bdesc, cdesc,
                                     alpha, beta, cdesc.dtype.ctype, func)

        # Adaptations for MKL/BLAS API: CBLAS takes transposition enums.
        opt['ta'] = 'CblasNoTrans' if opt['ta'] == 'N' else 'CblasTrans'
        opt['tb'] = 'CblasNoTrans' if opt['tb'] == 'N' else 'CblasTrans'

        code = ("cblas_{func}(CblasColMajor, {ta}, {tb}, "
                "{M}, {N}, {K}, {alpha}, {x}, {lda}, {y}, {ldb}, {beta}, "
                "_c, {ldc});").format_map(opt)

        tasklet = dace.sdfg.nodes.Tasklet(node.name,
                                          node.in_connectors,
                                          node.out_connectors,
                                          code,
                                          language=dace.dtypes.Language.CPP)
        return tasklet
コード例 #6
0
    def expansion(node,
                  parent_state,
                  parent_sdfg,
                  num_pes=32,
                  tile_size_m=None):
        '''
        GEMM node expansion.

        :param node: Node to expand.
        :param parent_state: State that the node is in.
        :param parent_sdfg: SDFG that the node is in.
        :param num_pes: Number of Processing Elements of the systolic array. By default it is set to 32.

        :param tile_size_m: tiling size considering columns of the input matrix B and resulting matrix C.
                            If B/C are vectorized, the tile size refers to the vectorized container.
                            If set to None, no tiling is used, corresponding to setting the tile size
                            equal to the number of columns of B/C.
        :return:
        '''

        ((edge_a, outer_array_a, shape_a, strides_a), (edge_b, outer_array_b,
                                                       shape_b, strides_b),
         (edge_c, outer_array_c, shape_c,
          strides_c)) = _get_matmul_operands(node, parent_state, parent_sdfg)

        dtype_a = outer_array_a.dtype.type
        dtype_b = outer_array_b.dtype.type
        dtype_c = dace.DTYPE_TO_TYPECLASS[np.result_type(dtype_a,
                                                         dtype_b).type]
        shape_c = (shape_a[0], shape_b[1])
        if node.transA:
            raise NotImplementedError(
                "GEMM FPGA expansion not implemented for transposed A.")
        if node.transB:
            raise NotImplementedError(
                "GEMM FPGA expansion not implemented for transposed B.")

        if outer_array_a.veclen > 1:
            raise NotImplementedError(
                "Vectorization not support for input array A.")

        if len(shape_a) != 2 or len(shape_b) != 2 or shape_a[1] != shape_b[0]:
            raise SyntaxError("Matrix sizes must match")

        if outer_array_b.dtype.veclen != outer_array_c.dtype.veclen:
            raise SyntaxError("Vectorization lengths of B and C must match")

        ######################################################################
        # GEMM Parameters and checks

        # Note: the following sizes consider also vectorization
        vec_width = outer_array_b.dtype.veclen
        vec_type = dace.vector(dtype_c, vec_width)
        N, K, M = shape_a[0], shape_a[1], shape_b[1]

        P = num_pes
        T = tile_size_m
        if T is None:
            T = M

        # we will perform sanity check using T and M. But at this stage, we still
        # don't know to what outer symbol they will map.
        # We try to resolve them to constant if they are symbolic, otherwise we skip the checks
        T_constant = dace.symbolic.resolve_symbol_to_constant(T, parent_sdfg)
        K_constant = dace.symbolic.resolve_symbol_to_constant(K, parent_sdfg)

        # Safe delay: this will be used in the compute state, pipeline scope, to insert
        # a delay between accumulation on the same result if needed.
        # Further explanations are provided in the compute state.

        # Note: this is a platform and type dependent parameter.
        if T_constant is not None:
            L = max(16 - T_constant, 0)
        else:
            L = 0

        # This implementation uses a flattened nested loop, that overlaps feeding,
        # computing and draining phases. Each PE is responsible for computing one
        # tile of one row of the final result C. With the current implementation,
        # A PE needs K*T cycles to compute the results and then P*T clock cycles
        # to fully drain them (draining is distributed across PEs).
        # Therefore, in order to guarantee correctness and deadlock free we have
        # to ensure that the number of cycles needed to drain the results is less
        # or equal to the number of cycles needed to compute them.
        # That is PT <= KT.

        if K_constant is not None and P > K_constant:
            raise ValueError(
                f"GEMM-FPGA: Number of processing elements {P} must be smaller than the K-dimension {K}."
            )

        ######################################################################
        # Build the SDFG

        new_sdfg = dace.SDFG(node.label + "_sdfg")
        new_state = new_sdfg.add_state("compute")

        # Add data descriptors

        new_sdfg.add_array("_a",
                           shape_a,
                           dtype_a,
                           strides=strides_a,
                           storage=outer_array_a.storage)
        new_sdfg.add_array("_b",
                           shape_b,
                           dtype_b,
                           strides=strides_b,
                           storage=outer_array_b.storage)
        new_sdfg.add_array("_c",
                           shape_c,
                           dtype_c,
                           strides=strides_c,
                           storage=outer_array_c.storage)

        if node.beta != 0:
            new_sdfg.add_array("_cin",
                               shape_c,
                               dtype_c,
                               strides=strides_c,
                               storage=outer_array_c.storage)

        def make_read_A(state):

            # A given row of A must be repeated according to B number of tiles
            # Both N and M can be not a multiple of P and T respectively
            entry, exit = state.add_map("read_A", {
                "n0": f"0:ceiling({N}/{P})",
                "tm": f"0:ceiling({M}/{T})",
                "k": f"0:{K}",
                "n1": f"0:{P}"
            },
                                        schedule=dace.ScheduleType.FPGA_Device)

            # The reader of A reads one element per clock cycle.
            # Note that if P > T+L, then this will be the bottleneck

            mem = state.add_read("_a")
            pipe = state.add_write("A_pipe")

            # Read data from memory: if we are out-of-bound do not read from memory
            # but inject dummy data
            tasklet = state.add_tasklet(
                "read_A", {"from_memory"}, {"to_kernel"}, f"""\
data = from_memory if n0 * {P} + n1 < {N} else 0
to_kernel = data""")

            state.add_memlet_path(mem,
                                  entry,
                                  tasklet,
                                  dst_conn="from_memory",
                                  memlet=dace.Memlet(f"_a[n0 * {P} + n1, k]",
                                                     dynamic=True,
                                                     allow_oob=True))
            state.add_memlet_path(tasklet,
                                  exit,
                                  pipe,
                                  src_conn="to_kernel",
                                  memlet=dace.Memlet(f"A_pipe[{P} - n1 - 1]"))

        def make_read_B(state):

            # Also while reading B, we have to consider that T and P could not divide
            # M and N

            entry, exit = state.add_map("read_B", {
                "n": f"0:ceiling({N}/{P})",
                "tm": f"0:ceiling({M}/{T})",
                "k": f"0:{K}",
                "m": f"0:{T}"
            },
                                        schedule=dace.ScheduleType.FPGA_Device)

            # If we are out-of bound, use a dummy value
            new_sdfg.add_array("B_dummy",
                               dtype=vec_type,
                               shape=[1],
                               transient=True,
                               storage=dace.dtypes.StorageType.FPGA_Registers)
            b_dummy = state.add_access("B_dummy")
            init_tasklet = state.add_tasklet("init_dummy_B", {}, {"init_data"},
                                             "init_data = 0")

            state.add_memlet_path(init_tasklet,
                                  b_dummy,
                                  src_conn="init_data",
                                  memlet=dace.Memlet("B_dummy[0]"))

            mem = state.add_read("_b")
            pipe = state.add_write("B_pipe")
            tasklet = state.add_tasklet(
                "read_B", {"from_memory", "dummy_data"}, {"to_kernel"}, f"""\
data = from_memory if tm*{T} + m < {M} else dummy_data
to_kernel = data""")

            state.add_memlet_path(b_dummy,
                                  entry,
                                  tasklet,
                                  dst_conn="dummy_data",
                                  memlet=dace.Memlet("B_dummy[0]"))

            state.add_memlet_path(mem,
                                  entry,
                                  tasklet,
                                  dst_conn="from_memory",
                                  memlet=dace.Memlet(f"_b[k, tm*{T} + m]",
                                                     dynamic=True,
                                                     allow_oob=True))

            state.add_memlet_path(tasklet,
                                  exit,
                                  pipe,
                                  src_conn="to_kernel",
                                  memlet=dace.Memlet("B_pipe[0]"))

        def make_write_C(state):

            # Receives the results and adds it to C

            pipe = state.add_read("C_pipe")
            if node.beta != 0:
                mem_read = state.add_read("_cin")
            mem = state.add_write("_c")

            entry_map, exit_map = state.add_map(
                "write_C", {
                    "n0": f"0:ceiling({N}/{P})",
                    "tm": f"0:ceiling({M}/{T})",
                    "n1": f"0:{P}",
                    "m": f"0:{T}"
                },
                schedule=dace.ScheduleType.FPGA_Device)

            # write in memory by adding C when we copy that to memory

            # deal with out-of-bound accesses

            mul_accumulated = f"{node.alpha} * from_kernel" if node.alpha != 1.0 else "from_kernel"
            if node.beta != 0:
                if node.beta != 1.0:
                    add_prev_c = f" + {node.beta} * prev_c"
                else:
                    add_prev_c = " + prev_c"
            else:
                add_prev_c = ""
            tasklet_inputs = {"from_kernel", "prev_c"
                              } if node.beta != 0 else {"from_kernel"}
            tasklet = state.add_tasklet(
                "write_C", tasklet_inputs, {"to_memory"}, f"""\
if tm * {T} + m  < {M}  and  n0 * {P} + n1 < {N} :                                               
    to_memory = {mul_accumulated}{add_prev_c}
""")
            state.add_memlet_path(pipe,
                                  entry_map,
                                  tasklet,
                                  dst_conn="from_kernel",
                                  memlet=dace.Memlet(f"C_pipe[{P}-1]"))
            if node.beta != 0:
                state.add_memlet_path(mem_read,
                                      entry_map,
                                      tasklet,
                                      dst_conn="prev_c",
                                      memlet=dace.Memlet(
                                          f"_cin[n0 * {P} + n1, tm * {T} + m]",
                                          dynamic=True,
                                          allow_oob=True))

            state.add_memlet_path(tasklet,
                                  exit_map,
                                  mem,
                                  src_conn="to_memory",
                                  memlet=dace.Memlet(
                                      f"_c[n0 * {P} + n1, tm * {T} + m]",
                                      dynamic=True,
                                      allow_oob=True))

        def make_compute(sdfg, state):

            A_pipe_in = state.add_read("A_pipe")
            B_pipe_in = state.add_read("B_pipe")
            B_pipe_out = state.add_write("B_pipe")
            C_pipe_in = state.add_read("C_pipe")
            C_pipe_out = state.add_write("C_pipe")

            # The computation is expressed a single, flattened loop, which is generated by the following
            # pipeline scope. Each PE accumulates over T partial results. The drain phase last P*T clock cycles.
            # Draining and compute are overlapped.
            # We are generating the loop by explicitly ignoring loop carried dependencies. Therefore, we have
            # to guarantee that the PE will accumulate on the same partial result only when its value is consolidated.
            # The + L is a safe delay between accumulation between the same partial result.
            # It must be computed by considering T and the latency needed to consolidate a partial result
            # (which is the latency of the add + latency for reading and writing to BRAM).

            entry_pipeline, exit_pipeline = state.add_pipeline(
                "compute_and_drain", {
                    "n0": f"0:ceiling({N}/{P})",
                    "tm": f"0:ceiling({M}/{T})",
                    "k": f"0:{K}",
                    "m": f"0:{T} + {L}"
                },
                drain_size=P * T,
                drain_overlap=False,
                additional_iterators={
                    'm_drain': 0,
                    'k_drain': 0
                },
                schedule=dace.ScheduleType.FPGA_Device)

            # Instantiate buffers
            sdfg.add_scalar("A_reg",
                            dtype=dtype_a,
                            transient=True,
                            storage=dace.dtypes.StorageType.FPGA_Registers)
            A_reg = state.add_write("A_reg")
            A_reg_init = state.add_access("A_reg")

            # For C result we are going to use vectorized data type

            # Note: for some of the Sacred Mysteries of Intel OpenCL Compiler (TM), if this buffer is smaller
            # than 24 floats, the II of the pipeline will be 5. Therefore we check this and in case we enlarge it
            buffer_size = T if T_constant is None else max(T_constant, 24)
            sdfg.add_array("C_buffer", [buffer_size],
                           dtype=vec_type,
                           transient=True,
                           storage=dace.dtypes.StorageType.FPGA_Local)
            C_buffer_in = state.add_read("C_buffer")
            C_buffer_out = state.add_write("C_buffer")

            # Init data to reset partial results
            new_sdfg.add_array("C_init",
                               dtype=vec_type,
                               shape=[1],
                               transient=True,
                               storage=dace.dtypes.StorageType.FPGA_Registers)
            C_init = state.add_access("C_init")
            C_init_tasklet = state.add_tasklet("C_data_init", {},
                                               {"init_data"}, "init_data = 0")

            state.add_memlet_path(C_init_tasklet,
                                  C_init,
                                  src_conn="init_data",
                                  memlet=dace.Memlet("C_init[0]"))
            state.add_memlet_path(entry_pipeline,
                                  C_init_tasklet,
                                  memlet=dace.Memlet())

            # Feed A
            # every PE: reads input data, buffer the data assigned to it
            buffer_a_tasklet = state.add_tasklet(
                "buffer_a", {"a_in"}, {
                    "a_reg",
                }, f"""\
if m == 0 and not {entry_pipeline.pipeline.drain_condition()}:
    a_reg = a_in""")

            state.add_memlet_path(A_pipe_in,
                                  entry_pipeline,
                                  buffer_a_tasklet,
                                  memlet=dace.Memlet("A_pipe[p]",
                                                     dynamic=True),
                                  dst_conn="a_in")
            state.add_memlet_path(buffer_a_tasklet,
                                  A_reg,
                                  memlet=dace.Memlet("A_reg[0]", dynamic=True),
                                  src_conn="a_reg")

            # Feed B
            sdfg.add_array("B_reg",
                           shape=[1],
                           dtype=vec_type,
                           transient=True,
                           storage=dace.dtypes.StorageType.FPGA_Local)
            B_reg = state.add_access("B_reg")
            buffer_b_tasklet = state.add_tasklet(
                "buffer_b", {"b_in"}, {"b_reg_out"}, f"""\
if  m>={L} and not {entry_pipeline.pipeline.drain_condition()}:
    b_reg_out = b_in""")

            state.add_memlet_path(B_pipe_in,
                                  entry_pipeline,
                                  buffer_b_tasklet,
                                  memlet=dace.Memlet("B_pipe[p]",
                                                     dynamic=True),
                                  dst_conn="b_in")
            state.add_memlet_path(buffer_b_tasklet,
                                  B_reg,
                                  memlet=dace.Memlet("B_reg[0]", dynamic=True),
                                  src_conn="b_reg_out")

            # Compute, Forward B, and Drain
            compute_tasklet = state.add_tasklet(
                "compute_and_drain",
                {"a_in", "b_in", "c_in", "forward_in", "c_init_data"},
                {"b_out", "c_out", "c_pipe_out"}, f"""\
result = c_in
if m >= {L} and not {entry_pipeline.pipeline.drain_condition()}:
    c_prev = c_init_data if k == 0 else c_in
    result =  c_prev + a_in * b_in
    c_out = result
    if p < {P} - 1:
        b_out = b_in
# Drain
# when we have to drain:
# - if we are working on second assigned row or second tile and we have something to drain
# - if k = K-1 and m>=L: each PE has just finished to compute something
# - if we are in the draining phase
# How: 
# - if k = K-1 and m>=L: then the PE drains its own result
#-  otherwise, if k_drain<p forward data coming from previous PEs (this could happens also in the drain phase)
if((n0 > 0 or tm > 0)  and k_drain <p and m_drain <{T}) or  (k=={K}-1 and m>= {L}) or ({entry_pipeline.pipeline.drain_condition()} and k_drain < p):
    c_pipe_out = result if (p==0 or (k_drain=={K}-1 and not {entry_pipeline.pipeline.drain_condition()})) else forward_in

# adjust draining iterators
if not {entry_pipeline.pipeline.drain_condition()}:
    if m_drain >= {L} +  {T} -1:
        m_drain = 0
        if k_drain >= {K} - 1:
            k_drain = 0
        else:
            k_drain = k_drain +1
    else:
        m_drain = m_drain + 1
else:
    if m_drain >=  {T} -1:
        m_drain = 0
        if k_drain >= {K} - 1:
            k_drain = 0
        else:
            k_drain = k_drain +1
    else:
        m_drain = m_drain + 1
    """)

            state.add_memlet_path(A_reg,
                                  compute_tasklet,
                                  dst_conn="a_in",
                                  memlet=dace.Memlet("A_reg[0]"))
            state.add_memlet_path(B_reg,
                                  compute_tasklet,
                                  memlet=dace.Memlet("B_reg[0]",
                                                     dynamic=False),
                                  dst_conn="b_in")
            state.add_memlet_path(C_init,
                                  compute_tasklet,
                                  memlet=dace.Memlet("C_init[0]"),
                                  dst_conn="c_init_data")

            state.add_memlet_path(compute_tasklet,
                                  exit_pipeline,
                                  B_pipe_out,
                                  memlet=dace.Memlet("B_pipe[p + 1]",
                                                     dynamic=True),
                                  src_conn="b_out")
            state.add_memlet_path(C_buffer_in,
                                  entry_pipeline,
                                  compute_tasklet,
                                  dst_conn="c_in",
                                  memlet=dace.Memlet(f"C_buffer[m-{L}]",
                                                     allow_oob=True))

            state.add_memlet_path(compute_tasklet,
                                  exit_pipeline,
                                  C_buffer_out,
                                  memlet=dace.Memlet(f"C_buffer[m-{L}]",
                                                     allow_oob=True,
                                                     dynamic=True),
                                  src_conn="c_out")

            state.add_memlet_path(C_pipe_in,
                                  entry_pipeline,
                                  compute_tasklet,
                                  memlet=dace.Memlet("C_pipe[p-1]",
                                                     dynamic=True),
                                  dst_conn="forward_in")
            state.add_memlet_path(compute_tasklet,
                                  exit_pipeline,
                                  C_pipe_out,
                                  memlet=dace.Memlet("C_pipe[p]",
                                                     dynamic=True),
                                  src_conn="c_pipe_out")

            # Unroll processing elements
            compute_entry, compute_exit = state.add_map(
                "unroll_compute", {"p": "0:{}".format(P)},
                schedule=dace.ScheduleType.FPGA_Device,
                unroll=True)

            # Bring data nodes into scope
            state.add_memlet_path(compute_entry,
                                  A_pipe_in,
                                  memlet=dace.memlet.Memlet())
            state.add_memlet_path(compute_entry,
                                  B_pipe_in,
                                  memlet=dace.memlet.Memlet())
            state.add_memlet_path(compute_entry,
                                  C_pipe_in,
                                  memlet=dace.memlet.Memlet())

            state.add_memlet_path(B_pipe_out,
                                  compute_exit,
                                  memlet=dace.memlet.Memlet())

            state.add_memlet_path(C_pipe_out,
                                  compute_exit,
                                  memlet=dace.memlet.Memlet())

            state.add_memlet_path(compute_entry,
                                  A_reg_init,
                                  memlet=dace.memlet.Memlet())
            state.add_memlet_path(A_reg_init,
                                  entry_pipeline,
                                  memlet=dace.memlet.Memlet())
            b_init = state.add_access("B_reg")
            state.add_memlet_path(compute_entry, b_init, memlet=dace.Memlet())
            state.add_memlet_path(b_init, entry_pipeline, memlet=dace.Memlet())
            state.add_memlet_path(compute_entry,
                                  C_buffer_in,
                                  memlet=dace.Memlet())
            state.add_memlet_path(C_buffer_out,
                                  compute_exit,
                                  memlet=dace.Memlet())

        # build the compute State

        new_sdfg.add_stream("A_pipe",
                            dtype_a,
                            transient=True,
                            shape=(P, ),
                            storage=dace.dtypes.StorageType.FPGA_Local,
                            buffer_size=str(P))
        new_sdfg.add_stream("B_pipe",
                            vec_type,
                            transient=True,
                            shape=(P + 1, ),
                            buffer_size=1,
                            storage=dace.dtypes.StorageType.FPGA_Local)
        new_sdfg.add_stream("C_pipe",
                            vec_type,
                            transient=True,
                            shape=(P + 1, ),
                            buffer_size=T,
                            storage=dace.dtypes.StorageType.FPGA_Local)

        make_read_A(new_state)
        make_read_B(new_state)
        make_compute(new_sdfg, new_state)
        make_write_C(new_state)
        return new_sdfg
コード例 #7
0
    def make_sdfg(node, parent_state, parent_sdfg):
        """
        Build a map-based nested SDFG for a GEMM library node.

        The multiplication map computes ``alpha * op(_a) @ op(_b)`` and sums
        each product into ``_c`` through a write-conflict resolution
        (``lambda x, y: x + y``). ``op`` transposes its operand when
        ``node.transA`` / ``node.transB`` is set. Before that, ``_c`` is
        prepared depending on beta:

        * ``beta == 0``: an init state fills ``_c`` with zeros.
        * ``beta == 1``: no init state is created; the product is added onto
          whatever ``_c`` already holds.
        * otherwise: an init state writes ``beta * _cin`` into ``_c``.

        :param node: The GEMM library node being expanded.
        :param parent_state: State that contains ``node``.
        :param parent_sdfg: SDFG that contains ``parent_state``.
        :return: The newly built nested SDFG.
        :raises SyntaxError: If op(A) or op(B) is not 2D, or their inner
                             dimensions do not match.
        :raises ValueError: If ``_cin`` cannot be broadcast to (M, N).
        """
        sdfg = dace.SDFG(node.label + "_sdfg")

        # cdata is used below as (…, outer array, …, strides) — presumably the
        # same (edge, array, shape, strides) tuple layout as the A/B operands.
        ((edge_a, outer_array_a, shape_a, strides_a), (edge_b, outer_array_b,
                                                       shape_b, strides_b),
         cdata) = _get_matmul_operands(node, parent_state, parent_sdfg)

        dtype_a = outer_array_a.dtype.type
        dtype_b = outer_array_b.dtype.type
        # NOTE(review): the output dtype is derived from the input dtypes, not
        # from the outer C array (cdata[1]); confirm the two always agree.
        dtype_c = dace.DTYPE_TO_TYPECLASS[np.result_type(dtype_a,
                                                         dtype_b).type]

        # Shapes of op(A) and op(B) after the optional transpositions
        if node.transA:
            trans_shape_a = list(reversed(shape_a))
        else:
            trans_shape_a = shape_a

        if node.transB:
            trans_shape_b = list(reversed(shape_b))
        else:
            trans_shape_b = shape_b

        if (len(trans_shape_a) != 2 or len(trans_shape_b) != 2
                or trans_shape_a[1] != trans_shape_b[0]):
            raise SyntaxError("Matrix sizes must match")
        # op(A) is M x K, op(B) is K x N, so C is M x N
        M, K, N = trans_shape_a[0], trans_shape_a[1], trans_shape_b[1]
        shape_c = (M, N)

        # NOTE(review): 'storage' is not used below; each array takes the
        # storage of its own outer array instead.
        storage = outer_array_a.storage

        _, array_a = sdfg.add_array("_a",
                                    shape_a,
                                    dtype_a,
                                    strides=strides_a,
                                    storage=outer_array_a.storage)
        _, array_b = sdfg.add_array("_b",
                                    shape_b,
                                    dtype_b,
                                    strides=strides_b,
                                    storage=outer_array_b.storage)
        _, array_c = sdfg.add_array("_c",
                                    shape_c,
                                    dtype_c,
                                    strides=cdata[-1],
                                    storage=cdata[1].storage)

        # Skip the alpha multiplication entirely when alpha is 1
        if node.alpha == 1.0:
            mul_program = "__out = __a * __b"
        else:
            mul_program = "__out = {} * __a * __b".format(
                _cast_to_dtype_str(node.alpha, dtype_a))

        # beta == 1 needs no initialization of _c, so only the main state is
        # created; otherwise an init state runs before the main state.
        if node.beta == 1:
            state = sdfg.add_state(node.label + "_state")
        else:
            init_state = sdfg.add_state(node.label + "_initstate")
            state = sdfg.add_state_after(init_state, node.label + "_state")

        # _cin is registered whenever beta != 0, but it is only read by the
        # beta map below (beta not in {0, 1}).
        # NOTE(review): for beta == 1 nothing in this SDFG copies _cin into
        # _c, so the result seems to rely on _c already holding the input C
        # values — confirm with the callers.
        if node.beta != 0:
            sdfg.add_array("_cin",
                           shape_c,
                           dtype_c,
                           strides=cdata[-1],
                           storage=cdata[1].storage)

        # The product is accumulated directly into _c (mul_out_array is unused)
        mul_out, mul_out_array = "_c", array_c
        output_nodes = None

        # Initialization / beta map
        if node.beta == 0:
            # Zero-fill every element of _c
            init_state.add_mapped_tasklet(
                'gemm_init', {
                    '_o%d' % i: '0:%s' % symstr(d)
                    for i, d in enumerate(shape_c)
                }, {},
                'out = 0', {
                    'out':
                    dace.Memlet.simple(
                        mul_out, ','.join(
                            ['_o%d' % i for i in range(len(shape_c))]))
                },
                external_edges=True)
        elif node.beta == 1:
            # Do nothing for initialization, only update the values
            pass
        else:
            # Beta map: _c = beta * _cin
            add_program = "__y = ({} * __c)".format(
                _cast_to_dtype_str(node.beta, dtype_a))

            # manually broadcasting C to [M, N]
            # NOTE(review): shape_c is always (M, N) at this point (set above),
            # so only the first branch can be taken; the other branches look
            # intended for an input C of a different shape.
            if list(shape_c) == [M, N]:
                memlet_idx = '__i0, __i1'
            elif list(shape_c) == [1, N]:
                memlet_idx = '0, __i1'
            elif list(shape_c) == [M, 1]:
                memlet_idx = '__i0, 0'
            elif list(shape_c) == [N]:
                memlet_idx = '__i1'
            else:
                raise ValueError(
                    "Could not broadcast input _c to ({}, {})".format(M, N))

            init_state.add_mapped_tasklet(
                "gemm_init",
                {"__i%d" % i: "0:%s" % s
                 for i, s in enumerate([M, N])}, {
                     "__c": dace.Memlet.simple("_cin", memlet_idx),
                 },
                add_program, {"__y": dace.Memlet.simple("_c", "__i0, __i1")},
                external_edges=True)

        # Multiplication map over (__i0, __i1, __i2) = (row, column, inner);
        # each product is summed into _c[__i0, __i1] via the WCR lambda.
        state.add_mapped_tasklet(
            "gemm", {"__i%d" % i: "0:%s" % s
                     for i, s in enumerate([M, N, K])},
            {
                "__a":
                dace.Memlet.simple(
                    "_a", "__i2, __i0" if node.transA else "__i0, __i2"),
                "__b":
                dace.Memlet.simple(
                    "_b", "__i1, __i2" if node.transB else "__i2, __i1")
            },
            mul_program, {
                "__out":
                dace.Memlet.simple(
                    mul_out, "__i0, __i1", wcr_str="lambda x, y: x + y")
            },
            external_edges=True,
            output_nodes=output_nodes)

        return sdfg
コード例 #8
0
ファイル: batched_matmul.py プロジェクト: mfkiwl/dace
    def make_sdfg(node, parent_state, parent_sdfg):
        """
        Build the pure (map-based) SDFG for a batched matrix multiplication.

        A first state zero-initializes the output ``_c``; a second state
        multiplies ``_a`` by ``_b`` and sums each product into ``_c`` through
        a write-conflict resolution (``lambda x, y: x + y``).

        When ``_get_batchmm_opts`` reports no batch (a falsy result), ``_c``
        is 2D and the multiplication map has no batch parameter. Previously
        this case indexed ``bopt['b']`` unconditionally and crashed.

        :param node: The library node being expanded.
        :param parent_state: State that contains ``node``.
        :param parent_sdfg: SDFG that contains ``parent_state``.
        :return: The replacement SDFG.
        :raises SyntaxError: If the inner dimensions of A and B differ.
        :raises ValueError: If A and B do not share the same storage type.
        """
        # Get metadata from parent SDFG
        ((edge_a, outer_array_a, shape_a, strides_a), (edge_b, outer_array_b,
                                                       shape_b, strides_b),
         cdata) = _get_matmul_operands(node, parent_state, parent_sdfg)
        outedge = parent_state.out_edges(node)[0]
        cdesc = parent_sdfg.arrays[outedge.data.data]
        bopt = _get_batchmm_opts(shape_a, strides_a, shape_b, strides_b,
                                 cdesc.shape, cdesc.strides)

        if shape_a[-1] != shape_b[-2]:
            raise SyntaxError('Matrix sizes must match')
        if bopt:
            shape_c = (bopt['b'], shape_a[-2], shape_b[-1])
        else:
            shape_c = (shape_a[-2], shape_b[-1])

        dtype_a = outer_array_a.dtype.type
        dtype_b = outer_array_b.dtype.type
        dtype_c = cdesc.dtype.type

        if outer_array_a.storage != outer_array_b.storage:
            raise ValueError("Input matrices must have same storage")
        storage = outer_array_a.storage

        # Create replacement SDFG
        sdfg = dace.SDFG(node.label + "_sdfg")

        _, array_a = sdfg.add_array("_a",
                                    shape_a,
                                    dtype_a,
                                    strides=strides_a,
                                    storage=storage)
        _, array_b = sdfg.add_array("_b",
                                    shape_b,
                                    dtype_b,
                                    strides=strides_b,
                                    storage=storage)
        _, array_c = sdfg.add_array("_c",
                                    shape_c,
                                    dtype_c,
                                    strides=cdata[-1],
                                    storage=storage)

        # Add an initialization state that zeroes the whole output
        init_state = sdfg.add_state()
        init_state.add_mapped_tasklet(
            'batched_matmul_init',
            {'_o%d' % i: '0:%s' % symstr(d)
             for i, d in enumerate(shape_c)}, {},
            'out = 0', {
                'out':
                dace.Memlet.simple(
                    '_c', ','.join(['_o%d' % i for i in range(len(shape_c))]))
            },
            external_edges=True)

        state = sdfg.add_state_after(init_state, node.label + "_state")

        # Map parameters: __i0 = batch index (batched case only),
        # __i1 = row of C, __i2 = column of C, __i3 = reduction dimension.
        if bopt:
            first_param = 0
            map_sizes = [
                bopt['b'], array_a.shape[-2], array_b.shape[-1],
                array_a.shape[-1]
            ]
            c_subset = '__i0, __i1, __i2'
        else:
            # No batch dimension: C is 2D and there is no __i0 parameter
            first_param = 1
            map_sizes = [
                array_a.shape[-2], array_b.shape[-1], array_a.shape[-1]
            ]
            c_subset = '__i1, __i2'
        map_ranges = {
            '__i%d' % i: '0:%s' % s
            for i, s in enumerate(map_sizes, start=first_param)
        }

        state.add_mapped_tasklet(
            '_BatchedBatchedMatMult_', map_ranges, {
                '__a':
                dace.Memlet.simple("_a", ('__i1, __i3' if len(array_a.shape)
                                          == 2 else '__i0, __i1, __i3')),
                '__b':
                dace.Memlet.simple("_b", ('__i3, __i2' if len(array_b.shape)
                                          == 2 else '__i0, __i3, __i2'))
            },
            '__c = __a * __b', {
                '__c':
                dace.Memlet.simple(
                    "_c", c_subset, wcr_str='lambda x, y: x + y')
            },
            external_edges=True)

        return sdfg
コード例 #9
0
ファイル: gemv.py プロジェクト: mfkiwl/dace
    def expansion(node: 'Gemv', state, sdfg, m=None, n=None, **kwargs):
        """
        Expand a Gemv node into a single C++ tasklet calling CBLAS ``?gemv``.

        CBLAS is called with ``CblasColMajor`` layout. A is handed over as
        stored, with its leading dimension taken from its non-unit stride.
        The transpose flag and the (M, N) argument order are then derived
        from the effective transposition.

        The expansion falls back to ``ExpandGemvPure`` in three cases: the
        node is at GPU device level, A is not contiguous in either
        dimension, or x is vectorized (veclen != 1).

        :param node: The Gemv node to expand.
        :param state: State that contains ``node``.
        :param sdfg: SDFG that contains ``state``.
        :param m: Size along y; defaults to ``node.m``, then ``shape_y[0]``.
        :param n: Size along x; defaults to ``node.n``, then ``shape_x[0]``.
        :return: The CBLAS tasklet, or the result of the pure expansion.
        """
        from dace.sdfg.scope import is_devicelevel_gpu
        if is_devicelevel_gpu(sdfg, state, node):
            return ExpandGemvPure.expansion(node, state, sdfg)

        node.validate(sdfg, state)

        ((edge_a, outer_array_a, shape_a, strides_a), (edge_x, outer_array_x,
                                                       shape_x, strides_x),
         (edge_y, outer_array_y, shape_y,
          strides_y)) = _get_matmul_operands(node,
                                             state,
                                             sdfg,
                                             name_lhs="_A",
                                             name_rhs="_x",
                                             name_out="_y")
        dtype = outer_array_x.dtype.base_type
        veclen = outer_array_x.dtype.veclen
        m = m or node.m
        n = n or node.n
        if m is None:
            m = shape_y[0]
        if n is None:
            n = shape_x[0]

        # Leading dimension and effective transposition. A column-major A
        # (unit stride in dimension 0) flips the requested transposition
        # relative to a row-major A.
        transA = node.transA
        if strides_a[0] == 1:
            transA = not transA
            lda = strides_a[1]
        elif strides_a[1] == 1:
            lda = strides_a[0]
        else:
            warnings.warn('Matrix must be contiguous in at least '
                          'one dimension. Falling back to pure expansion.')
            return ExpandGemvPure.expansion(node,
                                            state,
                                            sdfg,
                                            m=m,
                                            n=n,
                                            **kwargs)

        # Checked before (m, n) are reordered for CBLAS, so the fallback
        # receives the caller's (size of y, size of x) convention.
        if veclen != 1:
            warnings.warn('Vector GEMV not supported, falling back to pure.')
            return ExpandGemvPure.expansion(node,
                                            state,
                                            sdfg,
                                            m=m,
                                            n=n,
                                            **kwargs)

        layout = 'CblasColMajor'
        trans = 'CblasNoTrans' if transA else 'CblasTrans'
        # CBLAS wants (M, N) = dimensions of A as stored in column-major
        # order. That is (len(y), len(x)) when CBLAS applies no transpose,
        # and (len(x), len(y)) otherwise. This must use the effective
        # transA; using node.transA gave wrong sizes for column-major A.
        if not transA:
            m, n = n, m

        func, ctype, runtimetype = blas_helpers.cublas_type_metadata(dtype)
        func = func.lower() + 'gemv'

        if func[0] in ('c', 'z'):
            # Complex CBLAS routines take alpha/beta by pointer
            # (const void *), so materialize them as local values first.
            if isinstance(node.alpha, complex):
                alpha = f'{dtype.ctype}({node.alpha.real}, {node.alpha.imag})'
            else:
                alpha = f'{dtype.ctype}({node.alpha})'
            if isinstance(node.beta, complex):
                beta = f'{dtype.ctype}({node.beta.real}, {node.beta.imag})'
            else:
                beta = f'{dtype.ctype}({node.beta})'
            call_prefix = (f'{dtype.ctype} alpha = {alpha};\n'
                           f'{dtype.ctype} beta = {beta};\n')
            alpha, beta = '&alpha', '&beta'
        else:
            # Real routines take alpha/beta by value
            call_prefix = ''
            alpha, beta = node.alpha, node.beta

        code = call_prefix + f"""cblas_{func}({layout}, {trans}, {m}, {n}, {alpha}, _A, {lda},
                                _x, {strides_x[0]}, {beta}, _y, {strides_y[0]});"""

        tasklet = dace.sdfg.nodes.Tasklet(node.name,
                                          node.in_connectors,
                                          node.out_connectors,
                                          code,
                                          language=dace.dtypes.Language.CPP)

        return tasklet
コード例 #10
0
ファイル: gemv.py プロジェクト: mfkiwl/dace
    def expansion(node: 'Gemv', state, sdfg, m=None, n=None, **kwargs):
        node.validate(sdfg, state)

        ((edge_a, outer_array_a, shape_a, strides_a), (edge_x, outer_array_x,
                                                       shape_x, strides_x),
         (edge_y, outer_array_y, shape_y,
          strides_y)) = _get_matmul_operands(node,
                                             state,
                                             sdfg,
                                             name_lhs="_A",
                                             name_rhs="_x",
                                             name_out="_y")
        dtype_a = outer_array_a.dtype.type
        dtype = outer_array_x.dtype.base_type
        veclen = outer_array_x.dtype.veclen
        m = m or node.m
        n = n or node.n
        if m is None:
            m = shape_y[0]
        if n is None:
            n = shape_x[0]

        transA = node.transA
        if strides_a[0] == 1:
            transA = not transA
            lda = strides_a[1]
        elif strides_a[1] == 1:
            lda = strides_a[0]
        else:
            warnings.warn('Matrix must be contiguous in at least '
                          'one dimension. Falling back to pure expansion.')
            return ExpandGemvPure.expansion(node,
                                            state,
                                            sdfg,
                                            m=m,
                                            n=n,
                                            **kwargs)

        trans = 'CUBLAS_OP_N' if transA else 'CUBLAS_OP_T'
        if not node.transA:
            m, n = n, m

        if veclen != 1:
            warnings.warn('Vector GEMV not supported, falling back to pure')
            return ExpandGemvPure.expansion(node,
                                            state,
                                            sdfg,
                                            m=m,
                                            n=n,
                                            **kwargs)

        func, ctype, runtimetype = blas_helpers.cublas_type_metadata(dtype)
        func += 'gemv'
        call_prefix = environments.cublas.cuBLAS.handle_setup_code(node)
        call_suffix = ''

        # Handle alpha / beta
        constants = {
            1.0:
            f"__state->cublas_handle.Constants(__dace_cuda_device).{runtimetype}Pone()",
            0.0:
            f"__state->cublas_handle.Constants(__dace_cuda_device).{runtimetype}Zero()",
        }
        if node.alpha not in constants or node.beta not in constants:
            # Deal with complex input constants
            if isinstance(node.alpha, complex):
                alpha = f'{dtype.ctype}({node.alpha.real}, {node.alpha.imag})'
            else:
                alpha = f'{dtype.ctype}({node.alpha})'
            if isinstance(node.beta, complex):
                beta = f'{dtype.ctype}({node.beta.real}, {node.beta.imag})'
            else:
                beta = f'{dtype.ctype}({node.beta})'

            # Set pointer mode to host
            call_prefix += f'''cublasSetPointerMode(__dace_cublas_handle, CUBLAS_POINTER_MODE_HOST);
            {dtype.ctype} alpha = {alpha};
            {dtype.ctype} beta = {beta};
            '''
            call_suffix += '''
cublasSetPointerMode(__dace_cublas_handle, CUBLAS_POINTER_MODE_DEVICE);
            '''
            alpha = f'({ctype} *)&alpha'
            beta = f'({ctype} *)&beta'
        else:
            alpha = constants[node.alpha]
            beta = constants[node.beta]

        code = (call_prefix + f"""
cublas{func}(__dace_cublas_handle, {trans}, {m}, {n}, {alpha}, _A, {lda},
             _x, {strides_x[0]}, {beta}, _y, {strides_y[0]});
                """ + call_suffix)

        tasklet = dace.sdfg.nodes.Tasklet(node.name,
                                          node.in_connectors,
                                          node.out_connectors,
                                          code,
                                          language=dace.dtypes.Language.CPP)

        return tasklet
コード例 #11
0
ファイル: gemv.py プロジェクト: mfkiwl/dace
    def expansion(node, parent_state, parent_sdfg, **kwargs):
        """
        Pure (map-based) expansion of a Gemv node into a nested SDFG computing
        ``_y = alpha * op(_A) @ _x + beta * _y``.

        ``op`` transposes ``_A`` when ``node.transA`` is set. An init state
        zeroes the accumulation target. A multiplication map then sums
        ``alpha * A * x`` into it via a write-conflict resolution.

        * With beta == 0 the target is ``_y`` itself.
        * Otherwise the target is a temporary transient. A final addition map
          computes ``_y = beta * _y + tmp``.

        :param node: The Gemv node to expand.
        :param parent_state: State that contains ``node``.
        :param parent_sdfg: SDFG that contains ``parent_state``.
        :param kwargs: Accepted so callers may pass extra options (for
                       example ``m``/``n``); they are not used here.
        :return: The newly built nested SDFG.
        :raises NotImplementedError: If A or x is vectorized (veclen > 1).
        :raises SyntaxError: If the column count of op(A) differs from the
                             length of x.
        :raises ValueError: If A and x use different storage types.
        """
        node.validate(parent_sdfg, parent_state)
        sdfg = dace.SDFG(node.label + "_sdfg")
        ((edge_a, outer_array_a, shape_a, strides_a), (edge_x, outer_array_x,
                                                       shape_x, strides_x),
         (edge_y, outer_array_y, shape_y,
          strides_y)) = _get_matmul_operands(node,
                                             parent_state,
                                             parent_sdfg,
                                             name_lhs="_A",
                                             name_rhs="_x",
                                             name_out="_y")
        dtype_a = outer_array_a.dtype.type
        dtype_x = outer_array_x.dtype.type
        dtype_y = outer_array_y.dtype.type

        if outer_array_a.dtype.veclen > 1 or outer_array_x.dtype.veclen > 1:
            raise NotImplementedError("Vectorization for pure GEMV NYI.")

        # Shape of op(A) after the optional transposition
        if node.transA:
            trans_shape_a = list(reversed(shape_a))
        else:
            trans_shape_a = shape_a

        if trans_shape_a[1] != shape_x[0]:
            raise SyntaxError(
                "Matrix-vector product size mismatch: {} vs. {}".format(
                    trans_shape_a[1], shape_x[0]))

        # NOTE: N is the row count of op(A) (size along y) and M its column
        # count (size along x) — the reverse of the usual M/N naming.
        N, M = trans_shape_a[0], trans_shape_a[1]

        # NOTE(review): only A and x storages are compared; _y (and the
        # temporary below) are also created with A's storage.
        if outer_array_a.storage != outer_array_x.storage:
            raise ValueError("Input matrices must have same storage")
        storage = outer_array_a.storage

        _, array_a = sdfg.add_array("_A",
                                    shape_a,
                                    dtype_a,
                                    strides=strides_a,
                                    storage=storage)
        _, array_x = sdfg.add_array("_x",
                                    shape_x,
                                    dtype_x,
                                    strides=strides_x,
                                    storage=storage)
        _, array_y = sdfg.add_array("_y",
                                    shape_y,
                                    dtype_y,
                                    strides=strides_y,
                                    storage=storage)

        # alpha is inlined into the tasklet code via its str() form
        mul_program = "__out = {} * __A * __x".format(node.alpha)

        init_state = sdfg.add_state(node.label + "_initstate")
        state = sdfg.add_state_after(init_state, node.label + "_state")

        if node.beta == 0:
            # Accumulate directly into the output
            mul_out, mul_out_array = "_y", array_y
            output_nodes = None
        else:
            # Accumulate into a temporary; one access node is shared as the
            # output of the multiplication map and the input of the addition
            # map, which orders the two maps within the state.
            mul_out, mul_out_array = tmp, array_tmp = sdfg.add_temp_transient(
                shape_y, dtype_y, storage=storage)

            access_tmp = state.add_read(tmp)
            output_nodes = {mul_out: access_tmp}

        # Initialization map: zero the accumulation target (_y or the temp)
        init_state.add_mapped_tasklet(
            "gemv_init", {
                "_o%d" % i: "0:%s" % symbolic.symstr(d)
                for i, d in enumerate(shape_y)
            }, {},
            "out = 0", {
                "out":
                dace.Memlet("{}[{}]".format(
                    mul_out, ",".join(
                        ["_o%d" % i for i in range(len(shape_y))])))
            },
            external_edges=True)

        # Multiplication map over (__i0, __i1) = (row of op(A), column of
        # op(A)); products are summed into mul_out[__i0] via the WCR lambda.
        state.add_mapped_tasklet(
            "_GEMV_", {"__i%d" % i: "0:%s" % s
                       for i, s in enumerate([N, M])}, {
                           "__A":
                           dace.Memlet("_A[{}]".format(
                               "__i1, __i0" if node.transA else "__i0, __i1")),
                           "__x":
                           dace.Memlet("_x[__i1]")
                       },
            mul_program, {
                "__out": dace.Memlet(f"{mul_out}[__i0]",
                                     wcr="lambda x, y: x + y")
            },
            external_edges=True,
            output_nodes=output_nodes)

        # _y = beta * _y + tmp (only used when beta != 0)
        add_program = "__y_out = ({} * __y_in) + __tmp".format(node.beta)

        memlet_idx = "__i"

        # addition map
        if node.beta != 0:
            state.add_mapped_tasklet("_Add_", {"__i": "0:{}".format(N)}, {
                "__y_in": dace.Memlet(f"_y[{memlet_idx}]"),
                "__tmp": dace.Memlet(f"{mul_out}[__i]"),
            },
                                     add_program,
                                     {"__y_out": dace.Memlet("_y[__i]")},
                                     external_edges=True,
                                     input_nodes={mul_out: access_tmp})

        return sdfg
コード例 #12
0
ファイル: gemv.py プロジェクト: mfkiwl/dace
    def expansion(node,
                  parent_state,
                  parent_sdfg,
                  tile_size_x=None,
                  tile_size_y=None,
                  num_partial_sums=16):
        """
        :param node: Node to expand.
        :param parent_state: State that the node is in.
        :param parent_sdfg: SDFG that the node is in.
        :param tile_size_x: Tile size along the dimension of the vector x. If
                            set to None, no tiling is used, corresponding to
                            setting the tile size equal to the full size of x.
        :param tile_size_y: Tile size along the dimension of the vector y. If
                            set to None, no tiling is used, corresponding to
                            setting the tile size equal to the full size of y.
        :param num_partial_sums: The number of distinct registers to accumulate
                                 contributions to the final sum into. Should be
                                 a power of two, and should be higher than the
                                 latency of adding two numbers of the given
                                 data type.
        """

        node.validate(parent_sdfg, parent_state)

        sdfg = dace.SDFG("gemv")
        state = sdfg.add_state("gemv")

        alpha = node.alpha
        beta = node.beta

        # Get input/output data (the method considers also the presence of view nodes)
        ((edge_a, desc_a, shape_a, strides_a), (edge_x, desc_x, shape_x,
                                                strides_x),
         (edge_y, desc_y, shape_y,
          strides_y)) = _get_matmul_operands(node,
                                             parent_state,
                                             parent_sdfg,
                                             name_lhs="_A",
                                             name_rhs="_x",
                                             name_out="_y")

        # Create local versions of input/output data nodes
        _, desc_a = sdfg.add_array("_A",
                                   shape_a,
                                   desc_a.dtype,
                                   strides=strides_a,
                                   storage=desc_a.storage,
                                   transient=False)
        _, desc_x = sdfg.add_array("_x",
                                   shape_x,
                                   desc_x.dtype,
                                   strides=strides_x,
                                   storage=desc_x.storage,
                                   transient=False)
        _, desc_y_y = sdfg.add_array("_y",
                                     shape_y,
                                     desc_y.dtype,
                                     strides=strides_y,
                                     storage=desc_y.storage,
                                     transient=False)

        if node.transA and desc_a.dtype.veclen > 1:
            raise NotImplementedError(
                "Vectorization not implemented for transposed A.")

        # Create accesses
        read_a = state.add_read("_A")
        read_x = state.add_read("_x")
        if beta != 0:
            read_y = state.add_read("_y")
        write_y = state.add_write("_y")

        size_x = desc_x.shape[0]
        size_y = desc_y.shape[0]
        if tile_size_x is None:
            tile_size_x = size_x
        if tile_size_y is None:
            tile_size_y = size_y
        num_tiles_y = f"{size_y}/{tile_size_y}"
        num_tiles_x = f"{size_x}/{tile_size_x}"

        veclen = desc_a.dtype.veclen

        # Create tile map
        y_tile_entry, y_tile_exit = state.add_map(
            "y_tiles", {"ty": f"0:{num_tiles_y}"},
            schedule=dace.ScheduleType.FPGA_Device)
        x_tile_entry, x_tile_exit = state.add_map(
            "x_tiles", {"tx": f"0:{num_tiles_x}"},
            schedule=dace.ScheduleType.FPGA_Device)

        # Create y map
        y_entry, y_exit = state.add_map("y", {"iy": f"0:{tile_size_y}"},
                                        schedule=dace.ScheduleType.FPGA_Device)

        # Create x map
        x_entry, x_exit = state.add_map("x", {"ix": f"0:{tile_size_x}"},
                                        schedule=dace.ScheduleType.FPGA_Device)

        # Local buffer of x
        sdfg.add_array("x_local", (tile_size_x, ),
                       desc_x.dtype,
                       storage=dace.StorageType.FPGA_Local,
                       transient=True)
        x_local_access = state.add_read("x_local")

        if beta != 0:
            raise NotImplementedError("Not yet implemented.")

        multiply_tasklet = state.add_tasklet("multiply", {"A_in", "x_in"},
                                             {f"product": desc_a.dtype},
                                             "product = A_in * x_in")

        if isinstance(desc_a, dt.Stream):
            subset = "0"
        elif node.transA:
            subset = f"tx * {tile_size_x} + ix, ty * {tile_size_y} + iy"
        else:
            subset = f"ty * {tile_size_y} + iy, tx * {tile_size_x} + ix"
        state.add_memlet_path(read_a,
                              y_tile_entry,
                              x_tile_entry,
                              y_entry,
                              x_entry,
                              multiply_tasklet,
                              dst_conn="A_in",
                              memlet=dace.Memlet(f"_A[{subset}]"))
        read_x_entry, read_x_exit = state.add_map(
            "read_x", {"ix": f"0:{tile_size_x}"},
            schedule=dace.ScheduleType.FPGA_Device)
        subset = ("0" if isinstance(desc_x, dt.Stream) else
                  f"tx*{tile_size_x} + ix")
        read_x_tasklet = state.add_tasklet("read_x", {"x_memory"},
                                           {"x_buffer"}, "x_buffer = x_memory")
        state.add_memlet_path(read_x,
                              y_tile_entry,
                              x_tile_entry,
                              read_x_entry,
                              read_x_tasklet,
                              dst_conn="x_memory",
                              memlet=dace.Memlet(f"_x[{subset}]"))
        state.add_memlet_path(read_x_tasklet,
                              read_x_exit,
                              x_local_access,
                              src_conn="x_buffer",
                              memlet=dace.Memlet(f"x_local[ix]"))
        state.add_memlet_path(x_local_access,
                              y_entry,
                              x_entry,
                              multiply_tasklet,
                              dst_conn="x_in",
                              memlet=dace.Memlet(f"x_local[ix]"))

        # Write to buffer
        sdfg.add_array("product_vector", (1, ),
                       desc_a.dtype,
                       transient=True,
                       storage=dace.StorageType.FPGA_Local)
        product_vector = state.add_access("product_vector")
        state.add_memlet_path(multiply_tasklet,
                              product_vector,
                              src_conn="product",
                              memlet=dace.Memlet(f"product_vector[0]"))

        # Vector length conversion
        sdfg.add_array("product_scalar", (veclen, ),
                       desc_a.dtype.base_type,
                       transient=True,
                       storage=dace.StorageType.FPGA_Local)
        product_scalar = state.add_access("product_scalar")
        state.add_memlet_path(product_vector,
                              product_scalar,
                              memlet=dace.Memlet(f"product_vector[0]",
                                                 other_subset=f"0:{veclen}"))

        # Now we need to collapse this
        reduce_vector_entry, reduce_vector_exit = state.add_map(
            "reduce_vector", {"u": f"0:{veclen}"},
            schedule=dace.ScheduleType.FPGA_Device,
            unroll=True)

        reduce_vector_tasklet = state.add_tasklet(
            "reduce_vector", {"product_in", "acc_in"}, {"acc_out"},
            "acc_out = product_in + acc_in")
        state.add_memlet_path(product_scalar,
                              reduce_vector_entry,
                              reduce_vector_tasklet,
                              dst_conn="product_in",
                              memlet=dace.Memlet(f"{product_scalar}[u]"))

        # Add accumulation register
        sdfg.add_array("accumulate_product", (1, ),
                       desc_a.dtype.base_type,
                       transient=True,
                       storage=dace.StorageType.FPGA_Local)
        accumulate_product_read = state.add_access("accumulate_product")
        accumulate_product_write = state.add_access("accumulate_product")

        # Initialize it to zero
        init_reduce_vector_tasklet = state.add_tasklet("init_reduce_vector",
                                                       {}, {"acc_out"},
                                                       "acc_out = 0")
        state.add_memlet_path(x_entry,
                              init_reduce_vector_tasklet,
                              memlet=dace.Memlet())
        state.add_memlet_path(init_reduce_vector_tasklet,
                              accumulate_product_read,
                              src_conn="acc_out",
                              memlet=dace.Memlet(f"accumulate_product[0]"))

        # Connect it to the tasklet
        state.add_memlet_path(accumulate_product_read,
                              reduce_vector_entry,
                              reduce_vector_tasklet,
                              dst_conn="acc_in",
                              memlet=dace.Memlet(f"accumulate_product[0]"))
        state.add_memlet_path(reduce_vector_tasklet,
                              reduce_vector_exit,
                              accumulate_product_write,
                              src_conn="acc_out",
                              memlet=dace.Memlet(f"accumulate_product[0]"))

        # Partial sums
        sdfg.add_array("partial_sums", (num_partial_sums, ),
                       desc_y.dtype,
                       storage=dace.StorageType.FPGA_Registers,
                       transient=True)
        partial_sum_read = state.add_read("partial_sums")
        partial_sum_write = state.add_access("partial_sums")

        # Output array
        sdfg.add_array("y_local", (tile_size_y, ),
                       desc_y.dtype,
                       storage=dace.StorageType.FPGA_Local,
                       transient=True)

        # Now we need to actually accumulate into a local register of y
        y_local_read = state.add_read("y_local")
        y_local_write = state.add_read("y_local")
        update_y_tasklet = state.add_tasklet(
            "update_y", {"y_in", "acc_in"}, {"acc_out"}, f"""\
prev = acc_in if ix >= {num_partial_sums} else 0
acc_out = prev + y_in""")
        state.add_memlet_path(accumulate_product_write,
                              update_y_tasklet,
                              dst_conn="y_in",
                              memlet=dace.Memlet(f"accumulate_product[0]"))
        state.add_memlet_path(
            partial_sum_read,
            x_entry,
            update_y_tasklet,
            dst_conn="acc_in",
            memlet=dace.Memlet(f"partial_sums[ix%{num_partial_sums}]"))
        state.add_memlet_path(y_tile_entry, y_local_read, memlet=dace.Memlet())
        state.add_memlet_path(y_entry, partial_sum_read, memlet=dace.Memlet())
        state.add_memlet_path(
            update_y_tasklet,
            x_exit,
            partial_sum_write,
            src_conn="acc_out",
            memlet=dace.Memlet(f"partial_sums[ix%{num_partial_sums}]"))

        # Reduce the partial sums
        reduce_sums_entry, reduce_sums_exit = state.add_map(
            "reduce_partial_sums", {"u": f"0:{num_partial_sums}"},
            schedule=dace.ScheduleType.FPGA_Device,
            unroll=True)
        reduce_sums_tasklet = state.add_tasklet(
            "reduce_partial_sums", {"sum_in", "val_in"}, {"sum_out"}, """
prev = sum_in if u > 0 else 0
sum_out = prev + val_in""")
        sdfg.add_array("accumulate_sum", (1, ),
                       desc_y.dtype,
                       transient=True,
                       storage=dace.StorageType.FPGA_Local)
        accumulate_sum_read = state.add_access("accumulate_sum")
        accumulate_sum_write = state.add_access("accumulate_sum")
        state.add_memlet_path(y_entry,
                              accumulate_sum_read,
                              memlet=dace.Memlet())
        state.add_memlet_path(accumulate_sum_read,
                              reduce_sums_entry,
                              reduce_sums_tasklet,
                              dst_conn="sum_in",
                              memlet=dace.Memlet("accumulate_sum[0]"))
        state.add_memlet_path(reduce_sums_tasklet,
                              reduce_sums_exit,
                              accumulate_sum_write,
                              src_conn="sum_out",
                              memlet=dace.Memlet("accumulate_sum[0]"))
        state.add_memlet_path(partial_sum_write,
                              reduce_sums_entry,
                              reduce_sums_tasklet,
                              dst_conn="val_in",
                              memlet=dace.Memlet("partial_sums[u]"))

        # Combine with y buffer
        combine_tasklet = state.add_tasklet(
            "combine_y", {"val", "buffer_in"}, {"buffer_out"}, """\
prev = buffer_in if tx > 0 else 0
buffer_out = prev + val""")
        state.add_memlet_path(accumulate_sum_write,
                              combine_tasklet,
                              dst_conn="val",
                              memlet=dace.Memlet("accumulate_sum[0]"))
        state.add_memlet_path(y_local_read,
                              x_tile_entry,
                              y_entry,
                              combine_tasklet,
                              dst_conn="buffer_in",
                              memlet=dace.Memlet("y_local[iy]"))

        state.add_memlet_path(combine_tasklet,
                              y_exit,
                              x_tile_exit,
                              y_local_write,
                              src_conn="buffer_out",
                              memlet=dace.Memlet(f"y_local[iy]"))

        subset = ("0" if isinstance(desc_y, dt.Stream) else
                  f"ty*{tile_size_y} + iy")
        write_y_entry, write_y_exit = state.add_map(
            "write_y", {"iy": f"0:{tile_size_y}"},
            schedule=dace.ScheduleType.FPGA_Device)
        write_y_tasklet = state.add_tasklet("write_y", {"y_buffer"},
                                            {"y_memory"},
                                            "y_memory = y_buffer")
        state.add_memlet_path(y_local_write,
                              write_y_entry,
                              write_y_tasklet,
                              dst_conn="y_buffer",
                              memlet=dace.Memlet(f"y_local[iy]"))
        state.add_memlet_path(write_y_tasklet,
                              write_y_exit,
                              y_tile_exit,
                              write_y,
                              src_conn="y_memory",
                              memlet=dace.Memlet(f"_y[{subset}]"))

        return sdfg
コード例 #13
0
    def expansion(node: 'Gemv', state, sdfg, m=None, n=None, **kwargs):
        node.validate(sdfg, state)

        ((edge_a, outer_array_a, shape_a, strides_a), (edge_x, outer_array_x,
                                                       shape_x, strides_x),
         (edge_y, outer_array_y, shape_y,
          strides_y)) = _get_matmul_operands(node,
                                             state,
                                             sdfg,
                                             name_lhs="_A",
                                             name_rhs="_x",
                                             name_out="_y")
        dtype_a = outer_array_a.dtype.type
        dtype = outer_array_x.dtype.base_type
        veclen = outer_array_x.dtype.veclen
        m = m or node.m
        n = n or node.n
        if m is None:
            m = shape_y[0]
        if n is None:
            n = shape_x[0]

        transA = node.transA
        if strides_a[0] == 1:
            transA = not transA
            lda = strides_a[1]
        elif strides_a[1] == 1:
            lda = strides_a[0]
        else:
            warnings.warn('Matrix must be contiguous in at least '
                          'one dimension. Falling back to pure expansion.')
            return ExpandGemvPure.expansion(node,
                                            state,
                                            sdfg,
                                            m=m,
                                            n=n,
                                            **kwargs)

        trans = 'CUBLAS_OP_N' if transA else 'CUBLAS_OP_T'
        if not node.transA:
            m, n = n, m

        if veclen != 1:
            warnings.warn('Vector GEMV not supported, falling back to pure')
            return ExpandGemvPure.expansion(node,
                                            state,
                                            sdfg,
                                            m=m,
                                            n=n,
                                            **kwargs)

        func, ctype, runtimetype = blas_helpers.cublas_type_metadata(dtype)
        func += 'gemv'

        # TODO: (alpha,beta) != (1,0)
        if node.alpha != 1.0 or node.beta != 0.0:
            raise NotImplementedError
        alpha = (
            '__state->cublas_handle.Constants(__dace_cuda_device).%sPone()' %
            runtimetype)
        beta = (
            '__state->cublas_handle.Constants(__dace_cuda_device).%sZero()' %
            runtimetype)

        code = (environments.cublas.cuBLAS.handle_setup_code(node) + f"""
cublas{func}(__dace_cublas_handle, {trans}, {m}, {n}, {alpha}, _A, {lda},
             _x, {strides_x[0]}, {beta}, _y, {strides_y[0]});""")

        tasklet = dace.sdfg.nodes.Tasklet(node.name,
                                          node.in_connectors,
                                          node.out_connectors,
                                          code,
                                          language=dace.dtypes.Language.CPP)

        return tasklet
コード例 #14
0
    def make_sdfg(node, parent_state, parent_sdfg):
        """Build the nested SDFG of the pure (map-based) GEMV expansion.

        Computes ``_y = alpha * op(_a) @ _x`` when ``node.beta == 0`` and
        ``_y = beta * _y + alpha * op(_a) @ _x`` otherwise. ``op`` transposes
        ``_a`` when ``node.transA`` is set.

        :param node: The Gemv library node being expanded.
        :param parent_state: The state containing ``node``.
        :param parent_sdfg: The SDFG containing ``parent_state``.
        :return: A new SDFG with arrays ``_a``, ``_x`` and ``_y``.
        :raises SyntaxError: If the number of columns of ``op(_a)`` does not
                             match the length of ``_x``.
        :raises ValueError: If ``_a`` and ``_x`` use different storage types.
        """
        sdfg = dace.SDFG(node.label + "_sdfg")

        ((_, outer_array_a, shape_a, strides_a),
         (_, outer_array_x, shape_x, strides_x),
         (_, _, _, strides_y)) = _get_matmul_operands(node,
                                                     parent_state,
                                                     parent_sdfg,
                                                     name_lhs="_a",
                                                     name_rhs="_x",
                                                     name_out="_y")

        def _layout(strides, shape):
            # Keep the outer operand's strides so strided views are indexed
            # correctly inside the nested SDFG. Fall back to the default
            # (contiguous) layout if the rank does not match the nested array.
            return strides if len(strides) == len(shape) else None

        dtype_a = outer_array_a.dtype.type
        dtype_x = outer_array_x.dtype.type
        dtype_y = dace.DTYPE_TO_TYPECLASS[np.result_type(dtype_a,
                                                         dtype_x).type]

        if node.transA:
            trans_shape_a = list(reversed(shape_a))
        else:
            trans_shape_a = shape_a

        if trans_shape_a[1] != shape_x[0]:
            raise SyntaxError(
                "Matrix-vector product size mismatch: {} vs. {}".format(
                    trans_shape_a[1], shape_x[0]))

        N, M = trans_shape_a[0], trans_shape_a[1]
        shape_y = (N, )

        if outer_array_a.storage != outer_array_x.storage:
            raise ValueError("Input matrices must have same storage")
        storage = outer_array_a.storage

        sdfg.add_array("_a",
                       shape_a,
                       dtype_a,
                       strides=strides_a,
                       storage=storage)
        sdfg.add_array("_x",
                       shape_x,
                       dtype_x,
                       strides=_layout(strides_x, shape_x),
                       storage=storage)
        sdfg.add_array("_y",
                       shape_y,
                       dtype_y,
                       strides=_layout(strides_y, shape_y),
                       storage=storage)

        if node.alpha == 1.0:
            mul_program = "__out = __a * __x"
        else:
            mul_program = "__out = {} * __a * __x".format(
                _cast_to_dtype_str(node.alpha, dtype_a))

        init_state = sdfg.add_state(node.label + "_initstate")
        state = sdfg.add_state_after(init_state, node.label + "_state")

        if node.beta == 0:
            # Accumulate the product directly into the output.
            mul_out = "_y"
            output_nodes = None
        else:
            # Accumulate into a temporary that is combined with beta * _y
            # by the addition map below.
            mul_out, _ = sdfg.add_temp_transient(shape_y,
                                                 dtype_y,
                                                 storage=storage)
            access_tmp = state.add_read(mul_out)
            output_nodes = {mul_out: access_tmp}

        # Initialization map: zero the accumulation target.
        init_state.add_mapped_tasklet(
            "gemv_init",
            {"_o%d" % i: "0:%s" % symstr(d)
             for i, d in enumerate(shape_y)}, {},
            "out = 0", {
                "out":
                dace.Memlet.simple(
                    mul_out, ",".join(
                        ["_o%d" % i for i in range(len(shape_y))]))
            },
            external_edges=True)

        # Multiplication map: sum products over __i1 via a WCR memlet.
        state.add_mapped_tasklet(
            "_GEMV_", {"__i%d" % i: "0:%s" % s
                       for i, s in enumerate([N, M])},
            {
                "__a":
                dace.Memlet.simple(
                    "_a", "__i1, __i0" if node.transA else "__i0, __i1"),
                "__x":
                dace.Memlet.simple("_x", "__i1")
            },
            mul_program, {
                "__out":
                dace.Memlet.simple(
                    mul_out, "__i0", wcr_str="lambda x, y: x + y")
            },
            external_edges=True,
            output_nodes=output_nodes)

        if node.beta != 0:
            add_program = "__y_out = ({} * __y_in) + __tmp".format(
                _cast_to_dtype_str(node.beta, dtype_a))

            # Addition map: _y = beta * _y + tmp.
            state.add_mapped_tasklet(
                "_Add_", {"__i": "0:{}".format(N)}, {
                    "__y_in": dace.Memlet.simple("_y", "__i"),
                    "__tmp": dace.Memlet.simple(mul_out, "__i"),
                },
                add_program, {"__y_out": dace.Memlet.simple("_y", "__i")},
                external_edges=True,
                input_nodes={mul_out: access_tmp})

        return sdfg
コード例 #15
0
ファイル: dot.py プロジェクト: JanKleine/dace
    def make_sdfg(node, parent_state, parent_sdfg):
        """Build the nested SDFG of the pure (map-based) dot product expansion.

        ``_result[0]`` is zeroed in an initialization state. The following
        state then accumulates ``_x[i] * _y[i]`` for every ``i`` through a
        write-conflict-resolved (summing) memlet.

        :param node: The Dot library node being expanded.
        :param parent_state: The state containing ``node``.
        :param parent_sdfg: The SDFG containing ``parent_state``.
        :return: A new SDFG with arrays ``_x``, ``_y`` and ``_result``.
        :raises SyntaxError: If ``_x`` and ``_y`` differ in shape or the
                             result is not a single-element array.
        :raises ValueError: If ``_x`` and ``_y`` use different storage types.
        """
        sdfg = dace.SDFG(node.label + "_sdfg")

        ((_, outer_array_x, shape_x, strides_x),
         (_, outer_array_y, shape_y, strides_y),
         (_, outer_array_result, shape_result,
          _)) = _get_matmul_operands(node,
                                     parent_state,
                                     parent_sdfg,
                                     name_lhs="_x",
                                     name_rhs="_y",
                                     name_out="_result")

        def _layout(strides, shape):
            # Keep the outer operand's strides so strided vector views are
            # indexed correctly inside the nested SDFG. Fall back to the
            # default (contiguous) layout if the rank does not match.
            return strides if len(strides) == len(shape) else None

        dtype_x = outer_array_x.dtype.type
        dtype_y = outer_array_y.dtype.type
        dtype_result = outer_array_result.dtype.type

        if shape_x != shape_y or shape_result != [1]:
            raise SyntaxError("Invalid shapes to dot product.")

        N = shape_x[0]

        if outer_array_x.storage != outer_array_y.storage:
            raise ValueError("Input matrices must have same storage")
        storage = outer_array_x.storage

        sdfg.add_array("_x",
                       shape_x,
                       dtype_x,
                       strides=_layout(strides_x, shape_x),
                       storage=storage)
        sdfg.add_array("_y",
                       shape_y,
                       dtype_y,
                       strides=_layout(strides_y, shape_y),
                       storage=storage)
        sdfg.add_array("_result", [1], dtype_result, storage=storage)

        mul_program = "__out = __x * __y"

        init_state = sdfg.add_state(node.label + "_initstate")
        state = sdfg.add_state_after(init_state, node.label + "_state")

        # Initialization: _result[0] = 0
        init_write = init_state.add_write("_result")
        init_tasklet = init_state.add_tasklet("dot_init", {}, {"_out"},
                                              "_out = 0",
                                              location=node.location)
        init_state.add_memlet_path(init_tasklet,
                                   init_write,
                                   src_conn="_out",
                                   memlet=dace.Memlet.simple(init_write.data,
                                                             "0",
                                                             num_accesses=1))

        # Multiplication map: products are summed into _result[0] by WCR.
        state.add_mapped_tasklet(
            "_DOT_", {"__i": "0:{}".format(N)}, {
                "__x": dace.Memlet.simple("_x", "__i"),
                "__y": dace.Memlet.simple("_y", "__i")
            },
            mul_program, {
                "__out":
                dace.Memlet.simple("_result",
                                   "0",
                                   wcr_str="lambda x, y: x + y")
            },
            external_edges=True)

        return sdfg