def expansion(node: 'Gemv', state, sdfg, m=None, n=None, **kwargs):
    node.validate(sdfg, state)

    ((edge_a, outer_array_a, shape_a, strides_a),
     (edge_x, outer_array_x, shape_x, strides_x),
     (edge_y, outer_array_y, shape_y, strides_y)) = _get_matmul_operands(
         node, state, sdfg, name_lhs="_A", name_rhs="_x", name_out="_y")
    dtype_a = outer_array_a.dtype.type
    dtype = outer_array_x.dtype.base_type
    veclen = outer_array_x.dtype.veclen
    m = m or node.m
    n = n or node.n
    if m is None:
        m = shape_y[0]
    if n is None:
        n = shape_x[0]

    transA = node.transA

    Px = dace.symbol('Px', dtype=dace.int32, integer=True, positive=True)
    Py = dace.symbol('Py', dtype=dace.int32, integer=True, positive=True)
    try:
        sdfg.add_symbol('Px', dace.int32)
        sdfg.add_symbol('Py', dace.int32)
    except FileExistsError:
        pass

    @dace.program
    def _gemNv_pblas(_A: dtype[m, n], _x: dtype[n], _y: dtype[m]):
        lA = np.empty((m // Px, n // Py), dtype=_A.dtype)
        lx = np.empty((n // Px, ), dtype=_x.dtype)
        dace.comm.BCScatter(_A, lA, (m // Px, n // Py))
        dace.comm.BCScatter(_x, lx, (n // Px, 1))
        ly = distr.MatMult(_A, _x, lA, lx, (m // Px, n // Py), (n // Px, 1))
        dace.comm.BCGather(ly, _y, (m // Px, 1))

    @dace.program
    def _gemTv_pblas(_A: dtype[m, n], _x: dtype[m], _y: dtype[n]):
        lA = np.empty((m // Px, n // Py), dtype=_A.dtype)
        lx = np.empty((m // Px, ), dtype=_x.dtype)
        dace.comm.BCScatter(_A, lA, (m // Px, n // Py))
        dace.comm.BCScatter(_x, lx, (m // Px, 1))
        ly = distr.MatMult(_x, _A, lx, lA, (m // Px, 1), (m // Px, n // Py))
        dace.comm.BCGather(ly, _y, (n // Px, 1))

    # NOTE: The following is done to avoid scalar promotion, which results
    # in ValueError: Node type "BlockCyclicScatter" not supported for
    # promotion
    if transA:
        sdfg = _gemTv_pblas.to_sdfg(simplify=False)
    else:
        sdfg = _gemNv_pblas.to_sdfg(simplify=False)
    sdfg.simplify()

    return sdfg
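# Illustrative sketch (not part of the expansion above; helper name is hypothetical):
# the single-node semantics that the PBLAS-based GEMV expansion distributes.
# BCScatter/BCGather and distr.MatMult only change where the data lives, not the result.
import numpy as np

def gemv_reference(A, x, transA=False):
    # _gemNv_pblas computes y = A @ x, _gemTv_pblas computes y = A.T @ x
    return (A.T if transA else A) @ x

A = np.random.rand(8, 4)
x = np.random.rand(4)
assert np.allclose(gemv_reference(A, x), A @ x)
assert np.allclose(gemv_reference(A.T, x, transA=True), A @ x)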
def expansion(node, state, sdfg):
    node.validate(sdfg, state)

    (_, adesc, ashape, astrides), (_, bdesc, bshape, bstrides), _ = _get_matmul_operands(node, state, sdfg)
    dtype = adesc.dtype.base_type

    if node.beta != 0:
        raise NotImplementedError

    M = ashape[0]
    K = ashape[1]
    N = bshape[1]

    Px = dace.symbol('Px', dtype=dace.int32, integer=True, positive=True)
    Py = dace.symbol('Py', dtype=dace.int32, integer=True, positive=True)
    try:
        sdfg.add_symbol('Px', dace.int32)
        sdfg.add_symbol('Py', dace.int32)
    except FileExistsError:
        pass

    @dace.program
    def _gemm_pblas(_a: dtype[M, K], _b: dtype[K, N], _c: dtype[M, N]):
        lA = np.empty((M // Px, K // Py), dtype=_a.dtype)
        lB = np.empty((K // Px, N // Py), dtype=_b.dtype)
        dace.comm.BCScatter(_a, lA, (M // Px, K // Py))
        dace.comm.BCScatter(_b, lB, (K // Px, N // Py))
        lC = distr.MatMult(_a, _b, lA, lB, (M // Px, K // Py), (K // Px, N // Py))
        dace.comm.BCGather(lC, _c, (M // Px, N // Py))

    return _gemm_pblas.to_sdfg()
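# Illustrative sketch (hypothetical helper, exact-division case only): the local tile
# shapes implied by the Px x Py process grid above. BCScatter distributes _a and _b in
# block-cyclic fashion, distr.MatMult multiplies the local tiles, and BCGather
# reassembles the (M // Px, N // Py) partial results into _c.
def local_tile_shapes(M, K, N, Px, Py):
    return {
        '_a': (M // Px, K // Py),
        '_b': (K // Px, N // Py),
        '_c': (M // Px, N // Py),
    }

print(local_tile_shapes(M=1024, K=512, N=256, Px=4, Py=2))
# {'_a': (256, 256), '_b': (128, 128), '_c': (256, 128)}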
def expansion(node, state, sdfg):
    node.validate(sdfg, state)
    (_, adesc, ashape, astrides), (_, bdesc, bshape, bstrides), _ = _get_matmul_operands(node, state, sdfg)
    dtype = adesc.dtype.base_type
    func = to_blastype(dtype.type).lower() + 'gemm'
    alpha = f'{dtype.ctype}({node.alpha})'
    beta = f'{dtype.ctype}({node.beta})'

    # Deal with complex input constants
    if isinstance(node.alpha, complex):
        alpha = f'{dtype.ctype}({node.alpha.real}, {node.alpha.imag})'
    if isinstance(node.beta, complex):
        beta = f'{dtype.ctype}({node.beta.real}, {node.beta.imag})'

    cdesc = sdfg.arrays[state.out_edges(node)[0].data.data]

    check_access(dtypes.ScheduleType.CPU_Multicore, adesc, bdesc, cdesc)

    opt = _get_codegen_gemm_opts(node, state, sdfg, adesc, bdesc, cdesc, alpha, beta, dtype.ctype, func)

    # Adaptations for BLAS API
    opt['ta'] = 'CblasNoTrans' if opt['ta'] == 'N' else 'CblasTrans'
    opt['tb'] = 'CblasNoTrans' if opt['tb'] == 'N' else 'CblasTrans'

    code = ''
    if dtype in (dace.complex64, dace.complex128):
        code = f'''
        {dtype.ctype} alpha = {alpha};
        {dtype.ctype} beta = {beta};
        '''
        opt['alpha'] = '&alpha'
        opt['beta'] = '&beta'

    code += ("cblas_{func}(CblasColMajor, {ta}, {tb}, "
             "{M}, {N}, {K}, {alpha}, {x}, {lda}, {y}, {ldb}, {beta}, "
             "_c, {ldc});").format_map(opt)

    tasklet = dace.sdfg.nodes.Tasklet(
        node.name,
        node.in_connectors,
        node.out_connectors,
        code,
        language=dace.dtypes.Language.CPP,
    )
    return tasklet
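# Illustrative sketch: rendering the cblas call template above with example values.
# The real entries of `opt` (operand names, M/N/K, leading dimensions) come from
# _get_codegen_gemm_opts at expansion time; the values below are placeholders chosen
# to show the shape of the generated code for the complex path (&alpha / &beta).
example_opt = {
    'func': 'cgemm', 'ta': 'CblasNoTrans', 'tb': 'CblasNoTrans',
    'M': 64, 'N': 32, 'K': 16, 'alpha': '&alpha', 'beta': '&beta',
    'x': '_a', 'y': '_b', 'lda': 64, 'ldb': 16, 'ldc': 64,
}
print(("cblas_{func}(CblasColMajor, {ta}, {tb}, "
       "{M}, {N}, {K}, {alpha}, {x}, {lda}, {y}, {ldb}, {beta}, "
       "_c, {ldc});").format_map(example_opt))
# cblas_cgemm(CblasColMajor, CblasNoTrans, CblasNoTrans, 64, 32, 16, &alpha, _a, 64, _b, 16, &beta, _c, 64);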
def expansion(node, state, sdfg):
    node.validate(sdfg, state)
    (_, adesc, ashape, astrides), (_, bdesc, bshape, bstrides), _ = _get_matmul_operands(node, state, sdfg)
    cdesc = sdfg.arrays[state.out_edges(node)[0].data.data]
    check_access(dtypes.ScheduleType.CPU_Multicore, adesc, bdesc, cdesc)
    dtype = cdesc.dtype.base_type
    func = to_blastype(dtype.type).lower() + 'gemm'
    if dtype == dace.float32:
        alpha = "1.0f"
        beta = "0.0f"
    elif dtype == dace.float64:
        alpha = "1.0"
        beta = "0.0"
    elif dtype == dace.complex64:
        alpha = "dace::blas::BlasConstants::Get().Complex64Pone()"
        beta = "dace::blas::BlasConstants::Get().Complex64Zero()"
    elif dtype == dace.complex128:
        alpha = "dace::blas::BlasConstants::Get().Complex128Pone()"
        beta = "dace::blas::BlasConstants::Get().Complex128Zero()"
    else:
        raise ValueError("Unsupported type for batched BLAS GEMM: " + str(dtype))

    opt = _get_codegen_gemm_opts(node, state, sdfg, adesc, bdesc, cdesc, alpha, beta, cdesc.dtype.ctype, func)

    # Adaptations for MKL/BLAS API
    opt['ta'] = 'CblasNoTrans' if opt['ta'] == 'N' else 'CblasTrans'
    opt['tb'] = 'CblasNoTrans' if opt['tb'] == 'N' else 'CblasTrans'

    code = '''
    for (int __ib = 0; __ib < {BATCH}; ++__ib) {{
        cblas_{func}(CblasColMajor, {ta}, {tb}, {M}, {N}, {K}, {alpha},
                     (({dtype}*){x}) + __ib*{stride_a}, {lda},
                     (({dtype}*){y}) + __ib*{stride_b}, {ldb},
                     {beta},
                     (({dtype}*)_c) + __ib*{stride_c}, {ldc});
    }}'''.format_map(opt)

    tasklet = dace.sdfg.nodes.Tasklet(node.name,
                                      node.in_connectors,
                                      node.out_connectors,
                                      code,
                                      language=dace.dtypes.Language.CPP)
    return tasklet
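# Illustrative sketch: the generated batch loop above is semantically a per-batch GEMM
# with alpha = 1 and beta = 0; stride_a/stride_b/stride_c are the element offsets
# between consecutive batch entries. Plain numpy equivalent:
import numpy as np

batch, M, K, N = 4, 8, 6, 5
A = np.random.rand(batch, M, K)
B = np.random.rand(batch, K, N)
C = np.empty((batch, M, N))
for b in range(batch):            # mirrors the `for (int __ib = ...)` loop
    C[b] = A[b] @ B[b]
assert np.allclose(C, A @ B)      # identical to a batched matmul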
def expansion(node, state, sdfg):
    node.validate(sdfg, state)
    (_, adesc, ashape, astrides), (_, bdesc, bshape, bstrides), _ = _get_matmul_operands(node, state, sdfg)
    dtype = adesc.dtype.base_type
    func = to_blastype(dtype.type).lower() + 'gemm'
    # TODO: Fix w.r.t. other alpha/beta values
    if dtype == dace.float32:
        alpha = "1.0f"
        beta = "0.0f"
    elif dtype == dace.float64:
        alpha = "1.0"
        beta = "0.0"
    elif dtype == dace.complex64:
        alpha = "dace::blas::BlasConstants::Get().Complex64Pone()"
        beta = "dace::blas::BlasConstants::Get().Complex64Zero()"
    elif dtype == dace.complex128:
        alpha = "dace::blas::BlasConstants::Get().Complex128Pone()"
        beta = "dace::blas::BlasConstants::Get().Complex128Zero()"
    else:
        raise ValueError("Unsupported type for BLAS GEMM: " + str(dtype))
    cdesc = sdfg.arrays[state.out_edges(node)[0].data.data]
    opt = _get_codegen_gemm_opts(node, state, sdfg, adesc, bdesc, cdesc, alpha, beta, cdesc.dtype.ctype, func)

    # Adaptations for MKL/BLAS API
    opt['ta'] = 'CblasNoTrans' if opt['ta'] == 'N' else 'CblasTrans'
    opt['tb'] = 'CblasNoTrans' if opt['tb'] == 'N' else 'CblasTrans'

    code = ("cblas_{func}(CblasColMajor, {ta}, {tb}, "
            "{M}, {N}, {K}, {alpha}, {x}, {lda}, {y}, {ldb}, {beta}, "
            "_c, {ldc});").format_map(opt)

    tasklet = dace.sdfg.nodes.Tasklet(node.name,
                                      node.in_connectors,
                                      node.out_connectors,
                                      code,
                                      language=dace.dtypes.Language.CPP)
    return tasklet
def expansion(node, parent_state, parent_sdfg, num_pes=32, tile_size_m=None): ''' GEMM node expansion. :param node: Node to expand. :param parent_state: State that the node is in. :param parent_sdfg: SDFG that the node is in. :param num_pes: Number of Processing Elements of the systolic array. By default it is set to 32. :param tile_size_m: tiling size considering columns of the input matrix B and resulting matrix C. If B/C are vectorized, the tile size refers to the vectorized container. If set to None, no tiling is used, corresponding to setting the tile size equal to the number of columns of B/C. :return: ''' ((edge_a, outer_array_a, shape_a, strides_a), (edge_b, outer_array_b, shape_b, strides_b), (edge_c, outer_array_c, shape_c, strides_c)) = _get_matmul_operands(node, parent_state, parent_sdfg) dtype_a = outer_array_a.dtype.type dtype_b = outer_array_b.dtype.type dtype_c = dace.DTYPE_TO_TYPECLASS[np.result_type(dtype_a, dtype_b).type] shape_c = (shape_a[0], shape_b[1]) if node.transA: raise NotImplementedError( "GEMM FPGA expansion not implemented for transposed A.") if node.transB: raise NotImplementedError( "GEMM FPGA expansion not implemented for transposed B.") if outer_array_a.veclen > 1: raise NotImplementedError( "Vectorization not support for input array A.") if len(shape_a) != 2 or len(shape_b) != 2 or shape_a[1] != shape_b[0]: raise SyntaxError("Matrix sizes must match") if outer_array_b.dtype.veclen != outer_array_c.dtype.veclen: raise SyntaxError("Vectorization lengths of B and C must match") ###################################################################### # GEMM Parameters and checks # Note: the following sizes consider also vectorization vec_width = outer_array_b.dtype.veclen vec_type = dace.vector(dtype_c, vec_width) N, K, M = shape_a[0], shape_a[1], shape_b[1] P = num_pes T = tile_size_m if T is None: T = M # we will perform sanity check using T and M. But at this stage, we still # don't know to what outer symbol they will map. # We try to resolve them to constant if they are symbolic, otherwise we skip the checks T_constant = dace.symbolic.resolve_symbol_to_constant(T, parent_sdfg) K_constant = dace.symbolic.resolve_symbol_to_constant(K, parent_sdfg) # Safe delay: this will be used in the compute state, pipeline scope, to insert # a delay between accumulation on the same result if needed. # Further explanations are provided in the compute state. # Note: this is a platform and type dependent parameter. if T_constant is not None: L = max(16 - T_constant, 0) else: L = 0 # This implementation uses a flattened nested loop, that overlaps feeding, # computing and draining phases. Each PE is responsible for computing one # tile of one row of the final result C. With the current implementation, # A PE needs K*T cycles to compute the results and then P*T clock cycles # to fully drain them (draining is distributed across PEs). # Therefore, in order to guarantee correctness and deadlock free we have # to ensure that the number of cycles needed to drain the results is less # or equal to the number of cycles needed to compute them. # That is PT <= KT. if K_constant is not None and P > K_constant: raise ValueError( f"GEMM-FPGA: Number of processing elements {P} must be smaller than the K-dimension {K}." 
) ###################################################################### # Build the SDFG new_sdfg = dace.SDFG(node.label + "_sdfg") new_state = new_sdfg.add_state("compute") # Add data descriptors new_sdfg.add_array("_a", shape_a, dtype_a, strides=strides_a, storage=outer_array_a.storage) new_sdfg.add_array("_b", shape_b, dtype_b, strides=strides_b, storage=outer_array_b.storage) new_sdfg.add_array("_c", shape_c, dtype_c, strides=strides_c, storage=outer_array_c.storage) if node.beta != 0: new_sdfg.add_array("_cin", shape_c, dtype_c, strides=strides_c, storage=outer_array_c.storage) def make_read_A(state): # A given row of A must be repeated according to B number of tiles # Both N and M can be not a multiple of P and T respectively entry, exit = state.add_map("read_A", { "n0": f"0:ceiling({N}/{P})", "tm": f"0:ceiling({M}/{T})", "k": f"0:{K}", "n1": f"0:{P}" }, schedule=dace.ScheduleType.FPGA_Device) # The reader of A reads one element per clock cycle. # Note that if P > T+L, then this will be the bottleneck mem = state.add_read("_a") pipe = state.add_write("A_pipe") # Read data from memory: if we are out-of-bound do not read from memory # but inject dummy data tasklet = state.add_tasklet( "read_A", {"from_memory"}, {"to_kernel"}, f"""\ data = from_memory if n0 * {P} + n1 < {N} else 0 to_kernel = data""") state.add_memlet_path(mem, entry, tasklet, dst_conn="from_memory", memlet=dace.Memlet(f"_a[n0 * {P} + n1, k]", dynamic=True, allow_oob=True)) state.add_memlet_path(tasklet, exit, pipe, src_conn="to_kernel", memlet=dace.Memlet(f"A_pipe[{P} - n1 - 1]")) def make_read_B(state): # Also while reading B, we have to consider that T and P could not divide # M and N entry, exit = state.add_map("read_B", { "n": f"0:ceiling({N}/{P})", "tm": f"0:ceiling({M}/{T})", "k": f"0:{K}", "m": f"0:{T}" }, schedule=dace.ScheduleType.FPGA_Device) # If we are out-of bound, use a dummy value new_sdfg.add_array("B_dummy", dtype=vec_type, shape=[1], transient=True, storage=dace.dtypes.StorageType.FPGA_Registers) b_dummy = state.add_access("B_dummy") init_tasklet = state.add_tasklet("init_dummy_B", {}, {"init_data"}, "init_data = 0") state.add_memlet_path(init_tasklet, b_dummy, src_conn="init_data", memlet=dace.Memlet("B_dummy[0]")) mem = state.add_read("_b") pipe = state.add_write("B_pipe") tasklet = state.add_tasklet( "read_B", {"from_memory", "dummy_data"}, {"to_kernel"}, f"""\ data = from_memory if tm*{T} + m < {M} else dummy_data to_kernel = data""") state.add_memlet_path(b_dummy, entry, tasklet, dst_conn="dummy_data", memlet=dace.Memlet("B_dummy[0]")) state.add_memlet_path(mem, entry, tasklet, dst_conn="from_memory", memlet=dace.Memlet(f"_b[k, tm*{T} + m]", dynamic=True, allow_oob=True)) state.add_memlet_path(tasklet, exit, pipe, src_conn="to_kernel", memlet=dace.Memlet("B_pipe[0]")) def make_write_C(state): # Receives the results and adds it to C pipe = state.add_read("C_pipe") if node.beta != 0: mem_read = state.add_read("_cin") mem = state.add_write("_c") entry_map, exit_map = state.add_map( "write_C", { "n0": f"0:ceiling({N}/{P})", "tm": f"0:ceiling({M}/{T})", "n1": f"0:{P}", "m": f"0:{T}" }, schedule=dace.ScheduleType.FPGA_Device) # write in memory by adding C when we copy that to memory # deal with out-of-bound accesses mul_accumulated = f"{node.alpha} * from_kernel" if node.alpha != 1.0 else "from_kernel" if node.beta != 0: if node.beta != 1.0: add_prev_c = f" + {node.beta} * prev_c" else: add_prev_c = " + prev_c" else: add_prev_c = "" tasklet_inputs = {"from_kernel", "prev_c" } if node.beta != 0 else 
{"from_kernel"} tasklet = state.add_tasklet( "write_C", tasklet_inputs, {"to_memory"}, f"""\ if tm * {T} + m < {M} and n0 * {P} + n1 < {N} : to_memory = {mul_accumulated}{add_prev_c} """) state.add_memlet_path(pipe, entry_map, tasklet, dst_conn="from_kernel", memlet=dace.Memlet(f"C_pipe[{P}-1]")) if node.beta != 0: state.add_memlet_path(mem_read, entry_map, tasklet, dst_conn="prev_c", memlet=dace.Memlet( f"_cin[n0 * {P} + n1, tm * {T} + m]", dynamic=True, allow_oob=True)) state.add_memlet_path(tasklet, exit_map, mem, src_conn="to_memory", memlet=dace.Memlet( f"_c[n0 * {P} + n1, tm * {T} + m]", dynamic=True, allow_oob=True)) def make_compute(sdfg, state): A_pipe_in = state.add_read("A_pipe") B_pipe_in = state.add_read("B_pipe") B_pipe_out = state.add_write("B_pipe") C_pipe_in = state.add_read("C_pipe") C_pipe_out = state.add_write("C_pipe") # The computation is expressed a single, flattened loop, which is generated by the following # pipeline scope. Each PE accumulates over T partial results. The drain phase last P*T clock cycles. # Draining and compute are overlapped. # We are generating the loop by explicitly ignoring loop carried dependencies. Therefore, we have # to guarantee that the PE will accumulate on the same partial result only when its value is consolidated. # The + L is a safe delay between accumulation between the same partial result. # It must be computed by considering T and the latency needed to consolidate a partial result # (which is the latency of the add + latency for reading and writing to BRAM). entry_pipeline, exit_pipeline = state.add_pipeline( "compute_and_drain", { "n0": f"0:ceiling({N}/{P})", "tm": f"0:ceiling({M}/{T})", "k": f"0:{K}", "m": f"0:{T} + {L}" }, drain_size=P * T, drain_overlap=False, additional_iterators={ 'm_drain': 0, 'k_drain': 0 }, schedule=dace.ScheduleType.FPGA_Device) # Instantiate buffers sdfg.add_scalar("A_reg", dtype=dtype_a, transient=True, storage=dace.dtypes.StorageType.FPGA_Registers) A_reg = state.add_write("A_reg") A_reg_init = state.add_access("A_reg") # For C result we are going to use vectorized data type # Note: for some of the Sacred Mysteries of Intel OpenCL Compiler (TM), if this buffer is smaller # than 24 floats, the II of the pipeline will be 5. 
Therefore we check this and in case we enlarge it buffer_size = T if T_constant is None else max(T_constant, 24) sdfg.add_array("C_buffer", [buffer_size], dtype=vec_type, transient=True, storage=dace.dtypes.StorageType.FPGA_Local) C_buffer_in = state.add_read("C_buffer") C_buffer_out = state.add_write("C_buffer") # Init data to reset partial results new_sdfg.add_array("C_init", dtype=vec_type, shape=[1], transient=True, storage=dace.dtypes.StorageType.FPGA_Registers) C_init = state.add_access("C_init") C_init_tasklet = state.add_tasklet("C_data_init", {}, {"init_data"}, "init_data = 0") state.add_memlet_path(C_init_tasklet, C_init, src_conn="init_data", memlet=dace.Memlet("C_init[0]")) state.add_memlet_path(entry_pipeline, C_init_tasklet, memlet=dace.Memlet()) # Feed A # every PE: reads input data, buffer the data assigned to it buffer_a_tasklet = state.add_tasklet( "buffer_a", {"a_in"}, { "a_reg", }, f"""\ if m == 0 and not {entry_pipeline.pipeline.drain_condition()}: a_reg = a_in""") state.add_memlet_path(A_pipe_in, entry_pipeline, buffer_a_tasklet, memlet=dace.Memlet("A_pipe[p]", dynamic=True), dst_conn="a_in") state.add_memlet_path(buffer_a_tasklet, A_reg, memlet=dace.Memlet("A_reg[0]", dynamic=True), src_conn="a_reg") # Feed B sdfg.add_array("B_reg", shape=[1], dtype=vec_type, transient=True, storage=dace.dtypes.StorageType.FPGA_Local) B_reg = state.add_access("B_reg") buffer_b_tasklet = state.add_tasklet( "buffer_b", {"b_in"}, {"b_reg_out"}, f"""\ if m>={L} and not {entry_pipeline.pipeline.drain_condition()}: b_reg_out = b_in""") state.add_memlet_path(B_pipe_in, entry_pipeline, buffer_b_tasklet, memlet=dace.Memlet("B_pipe[p]", dynamic=True), dst_conn="b_in") state.add_memlet_path(buffer_b_tasklet, B_reg, memlet=dace.Memlet("B_reg[0]", dynamic=True), src_conn="b_reg_out") # Compute, Forward B, and Drain compute_tasklet = state.add_tasklet( "compute_and_drain", {"a_in", "b_in", "c_in", "forward_in", "c_init_data"}, {"b_out", "c_out", "c_pipe_out"}, f"""\ result = c_in if m >= {L} and not {entry_pipeline.pipeline.drain_condition()}: c_prev = c_init_data if k == 0 else c_in result = c_prev + a_in * b_in c_out = result if p < {P} - 1: b_out = b_in # Drain # when we have to drain: # - if we are working on second assigned row or second tile and we have something to drain # - if k = K-1 and m>=L: each PE has just finished to compute something # - if we are in the draining phase # How: # - if k = K-1 and m>=L: then the PE drains its own result #- otherwise, if k_drain<p forward data coming from previous PEs (this could happens also in the drain phase) if((n0 > 0 or tm > 0) and k_drain <p and m_drain <{T}) or (k=={K}-1 and m>= {L}) or ({entry_pipeline.pipeline.drain_condition()} and k_drain < p): c_pipe_out = result if (p==0 or (k_drain=={K}-1 and not {entry_pipeline.pipeline.drain_condition()})) else forward_in # adjust draining iterators if not {entry_pipeline.pipeline.drain_condition()}: if m_drain >= {L} + {T} -1: m_drain = 0 if k_drain >= {K} - 1: k_drain = 0 else: k_drain = k_drain +1 else: m_drain = m_drain + 1 else: if m_drain >= {T} -1: m_drain = 0 if k_drain >= {K} - 1: k_drain = 0 else: k_drain = k_drain +1 else: m_drain = m_drain + 1 """) state.add_memlet_path(A_reg, compute_tasklet, dst_conn="a_in", memlet=dace.Memlet("A_reg[0]")) state.add_memlet_path(B_reg, compute_tasklet, memlet=dace.Memlet("B_reg[0]", dynamic=False), dst_conn="b_in") state.add_memlet_path(C_init, compute_tasklet, memlet=dace.Memlet("C_init[0]"), dst_conn="c_init_data") state.add_memlet_path(compute_tasklet, 
exit_pipeline, B_pipe_out, memlet=dace.Memlet("B_pipe[p + 1]", dynamic=True), src_conn="b_out") state.add_memlet_path(C_buffer_in, entry_pipeline, compute_tasklet, dst_conn="c_in", memlet=dace.Memlet(f"C_buffer[m-{L}]", allow_oob=True)) state.add_memlet_path(compute_tasklet, exit_pipeline, C_buffer_out, memlet=dace.Memlet(f"C_buffer[m-{L}]", allow_oob=True, dynamic=True), src_conn="c_out") state.add_memlet_path(C_pipe_in, entry_pipeline, compute_tasklet, memlet=dace.Memlet("C_pipe[p-1]", dynamic=True), dst_conn="forward_in") state.add_memlet_path(compute_tasklet, exit_pipeline, C_pipe_out, memlet=dace.Memlet("C_pipe[p]", dynamic=True), src_conn="c_pipe_out") # Unroll processing elements compute_entry, compute_exit = state.add_map( "unroll_compute", {"p": "0:{}".format(P)}, schedule=dace.ScheduleType.FPGA_Device, unroll=True) # Bring data nodes into scope state.add_memlet_path(compute_entry, A_pipe_in, memlet=dace.memlet.Memlet()) state.add_memlet_path(compute_entry, B_pipe_in, memlet=dace.memlet.Memlet()) state.add_memlet_path(compute_entry, C_pipe_in, memlet=dace.memlet.Memlet()) state.add_memlet_path(B_pipe_out, compute_exit, memlet=dace.memlet.Memlet()) state.add_memlet_path(C_pipe_out, compute_exit, memlet=dace.memlet.Memlet()) state.add_memlet_path(compute_entry, A_reg_init, memlet=dace.memlet.Memlet()) state.add_memlet_path(A_reg_init, entry_pipeline, memlet=dace.memlet.Memlet()) b_init = state.add_access("B_reg") state.add_memlet_path(compute_entry, b_init, memlet=dace.Memlet()) state.add_memlet_path(b_init, entry_pipeline, memlet=dace.Memlet()) state.add_memlet_path(compute_entry, C_buffer_in, memlet=dace.Memlet()) state.add_memlet_path(C_buffer_out, compute_exit, memlet=dace.Memlet()) # build the compute State new_sdfg.add_stream("A_pipe", dtype_a, transient=True, shape=(P, ), storage=dace.dtypes.StorageType.FPGA_Local, buffer_size=str(P)) new_sdfg.add_stream("B_pipe", vec_type, transient=True, shape=(P + 1, ), buffer_size=1, storage=dace.dtypes.StorageType.FPGA_Local) new_sdfg.add_stream("C_pipe", vec_type, transient=True, shape=(P + 1, ), buffer_size=T, storage=dace.dtypes.StorageType.FPGA_Local) make_read_A(new_state) make_read_B(new_state) make_compute(new_sdfg, new_state) make_write_C(new_state) return new_sdfg
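# Illustrative sketch (hypothetical helper, simplified model that ignores memory
# stalls): the cycle-count argument behind the P <= K check above. Each PE spends
# K * T cycles computing a tile of results and P * T cycles draining them, so the
# drain overlaps with the next computation only if P * T <= K * T, i.e. P <= K.
# L is the safe delay that keeps accumulations on the same C_buffer entry apart.
def fpga_gemm_cycle_sketch(N, K, M, P, T):
    L = max(16 - T, 0)                       # safe delay, as in the expansion
    assert P <= K, "draining would not overlap with compute"
    tiles = (-(-N // P)) * (-(-M // T))      # ceiling(N/P) * ceiling(M/T)
    compute = tiles * K * (T + L)            # flattened compute_and_drain iterations
    final_drain = P * T                      # drain_size of the pipeline scope
    return compute + final_drain

print(fpga_gemm_cycle_sketch(N=1024, K=1024, M=1024, P=32, T=128))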
def make_sdfg(node, parent_state, parent_sdfg): sdfg = dace.SDFG(node.label + "_sdfg") ((edge_a, outer_array_a, shape_a, strides_a), (edge_b, outer_array_b, shape_b, strides_b), cdata) = _get_matmul_operands(node, parent_state, parent_sdfg) dtype_a = outer_array_a.dtype.type dtype_b = outer_array_b.dtype.type dtype_c = dace.DTYPE_TO_TYPECLASS[np.result_type(dtype_a, dtype_b).type] if node.transA: trans_shape_a = list(reversed(shape_a)) else: trans_shape_a = shape_a if node.transB: trans_shape_b = list(reversed(shape_b)) else: trans_shape_b = shape_b if (len(trans_shape_a) != 2 or len(trans_shape_b) != 2 or trans_shape_a[1] != trans_shape_b[0]): raise SyntaxError("Matrix sizes must match") M, K, N = trans_shape_a[0], trans_shape_a[1], trans_shape_b[1] shape_c = (M, N) storage = outer_array_a.storage _, array_a = sdfg.add_array("_a", shape_a, dtype_a, strides=strides_a, storage=outer_array_a.storage) _, array_b = sdfg.add_array("_b", shape_b, dtype_b, strides=strides_b, storage=outer_array_b.storage) _, array_c = sdfg.add_array("_c", shape_c, dtype_c, strides=cdata[-1], storage=cdata[1].storage) if node.alpha == 1.0: mul_program = "__out = __a * __b" else: mul_program = "__out = {} * __a * __b".format( _cast_to_dtype_str(node.alpha, dtype_a)) if node.beta == 1: state = sdfg.add_state(node.label + "_state") else: init_state = sdfg.add_state(node.label + "_initstate") state = sdfg.add_state_after(init_state, node.label + "_state") if node.beta != 0: sdfg.add_array("_cin", shape_c, dtype_c, strides=cdata[-1], storage=cdata[1].storage) mul_out, mul_out_array = "_c", array_c output_nodes = None # Initialization / beta map if node.beta == 0: init_state.add_mapped_tasklet( 'gemm_init', { '_o%d' % i: '0:%s' % symstr(d) for i, d in enumerate(shape_c) }, {}, 'out = 0', { 'out': dace.Memlet.simple( mul_out, ','.join( ['_o%d' % i for i in range(len(shape_c))])) }, external_edges=True) elif node.beta == 1: # Do nothing for initialization, only update the values pass else: # Beta map add_program = "__y = ({} * __c)".format( _cast_to_dtype_str(node.beta, dtype_a)) # manually broadcasting C to [M, N] if list(shape_c) == [M, N]: memlet_idx = '__i0, __i1' elif list(shape_c) == [1, N]: memlet_idx = '0, __i1' elif list(shape_c) == [M, 1]: memlet_idx = '__i0, 0' elif list(shape_c) == [N]: memlet_idx = '__i1' else: raise ValueError( "Could not broadcast input _c to ({}, {})".format(M, N)) init_state.add_mapped_tasklet( "gemm_init", {"__i%d" % i: "0:%s" % s for i, s in enumerate([M, N])}, { "__c": dace.Memlet.simple("_cin", memlet_idx), }, add_program, {"__y": dace.Memlet.simple("_c", "__i0, __i1")}, external_edges=True) # Multiplication map state.add_mapped_tasklet( "gemm", {"__i%d" % i: "0:%s" % s for i, s in enumerate([M, N, K])}, { "__a": dace.Memlet.simple( "_a", "__i2, __i0" if node.transA else "__i0, __i2"), "__b": dace.Memlet.simple( "_b", "__i1, __i2" if node.transB else "__i2, __i1") }, mul_program, { "__out": dace.Memlet.simple( mul_out, "__i0, __i1", wcr_str="lambda x, y: x + y") }, external_edges=True, output_nodes=output_nodes) return sdfg
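# Illustrative sketch (hypothetical helper): what the init/beta map plus the
# multiplication map above compute, written in plain numpy. alpha, beta, transA and
# transB mirror the node properties; numpy's broadcasting stands in for the manual
# broadcasting of _cin to (M, N).
import numpy as np

def gemm_reference(A, B, C_in, alpha, beta, transA=False, transB=False):
    Aop = A.T if transA else A
    Bop = B.T if transB else B
    return alpha * (Aop @ Bop) + (beta * C_in if beta != 0 else 0)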
def make_sdfg(node, parent_state, parent_sdfg): # Get metadata from parent SDFG ((edge_a, outer_array_a, shape_a, strides_a), (edge_b, outer_array_b, shape_b, strides_b), cdata) = _get_matmul_operands(node, parent_state, parent_sdfg) outedge = parent_state.out_edges(node)[0] cdesc = parent_sdfg.arrays[outedge.data.data] bopt = _get_batchmm_opts(shape_a, strides_a, shape_b, strides_b, cdesc.shape, cdesc.strides) if shape_a[-1] != shape_b[-2]: raise SyntaxError('Matrix sizes must match') if bopt: shape_c = (bopt['b'], shape_a[-2], shape_b[-1]) else: shape_c = (shape_a[-2], shape_b[-1]) dtype_a = outer_array_a.dtype.type dtype_b = outer_array_b.dtype.type dtype_c = cdesc.dtype.type if outer_array_a.storage != outer_array_b.storage: raise ValueError("Input matrices must have same storage") storage = outer_array_a.storage # Create replacement SDFG sdfg = dace.SDFG(node.label + "_sdfg") _, array_a = sdfg.add_array("_a", shape_a, dtype_a, strides=strides_a, storage=storage) _, array_b = sdfg.add_array("_b", shape_b, dtype_b, strides=strides_b, storage=storage) _, array_c = sdfg.add_array("_c", shape_c, dtype_c, strides=cdata[-1], storage=storage) # Add an initialization state init_state = sdfg.add_state() init_state.add_mapped_tasklet( 'batched_matmul_init', {'_o%d' % i: '0:%s' % symstr(d) for i, d in enumerate(shape_c)}, {}, 'out = 0', { 'out': dace.Memlet.simple( '_c', ','.join(['_o%d' % i for i in range(len(shape_c))])) }, external_edges=True) state = sdfg.add_state_after(init_state, node.label + "_state") state.add_mapped_tasklet( '_BatchedBatchedMatMult_', { '__i%d' % i: '0:%s' % s for i, s in enumerate([ bopt['b'], array_a.shape[-2], array_b.shape[-1], array_a.shape[-1] ]) }, { '__a': dace.Memlet.simple("_a", ('__i1, __i3' if len(array_a.shape) == 2 else '__i0, __i1, __i3')), '__b': dace.Memlet.simple("_b", ('__i3, __i2' if len(array_b.shape) == 2 else '__i0, __i3, __i2')) }, '__c = __a * __b', { '__c': dace.Memlet.simple( "_c", '__i0, __i1, __i2', wcr_str='lambda x, y: x + y') }, external_edges=True) return sdfg
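# Illustrative sketch (hypothetical helper): the mapped tasklet above accumulates
# C[b, i, j] += A[b, i, k] * B[b, k, j] over k, with a 2D A or B reused for every
# batch entry (the '__i1, __i3' vs '__i0, __i1, __i3' memlet distinction). numpy's
# matmul has the same broadcasting behaviour.
import numpy as np

def batched_matmul_reference(A, B):
    return np.matmul(A, B)   # broadcasts a 2D operand across the batch dimension

A = np.random.rand(4, 8, 6)
B = np.random.rand(6, 5)     # shared across the batch
assert batched_matmul_reference(A, B).shape == (4, 8, 5)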
def expansion(node: 'Gemv', state, sdfg, m=None, n=None, **kwargs): from dace.sdfg.scope import is_devicelevel_gpu if is_devicelevel_gpu(sdfg, state, node): return ExpandGemvPure.expansion(node, state, sdfg) node.validate(sdfg, state) ((edge_a, outer_array_a, shape_a, strides_a), (edge_x, outer_array_x, shape_x, strides_x), (edge_y, outer_array_y, shape_y, strides_y)) = _get_matmul_operands(node, state, sdfg, name_lhs="_A", name_rhs="_x", name_out="_y") dtype_a = outer_array_a.dtype.type dtype = outer_array_x.dtype.base_type veclen = outer_array_x.dtype.veclen m = m or node.m n = n or node.n if m is None: m = shape_y[0] if n is None: n = shape_x[0] transA = node.transA if strides_a[0] == 1: transA = not transA lda = strides_a[1] elif strides_a[1] == 1: lda = strides_a[0] else: warnings.warn('Matrix must be contiguous in at least ' 'one dimension. Falling back to pure expansion.') return ExpandGemvPure.expansion(node, state, sdfg, m=m, n=n, **kwargs) layout = 'CblasColMajor' trans = 'CblasNoTrans' if transA else 'CblasTrans' if not node.transA: m, n = n, m if veclen != 1: warnings.warn('Vector GEMV not supported, falling back to pure.') return ExpandGemvPure.expansion(node, state, sdfg, m=m, n=n, **kwargs) func, ctype, runtimetype = blas_helpers.cublas_type_metadata(dtype) func = func.lower() + 'gemv' code = f"""cblas_{func}({layout}, {trans}, {m}, {n}, {node.alpha}, _A, {lda}, _x, {strides_x[0]}, {node.beta}, _y, {strides_y[0]});""" tasklet = dace.sdfg.nodes.Tasklet(node.name, node.in_connectors, node.out_connectors, code, language=dace.dtypes.Language.CPP) return tasklet
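# Illustrative sketch: why transA, (m, n) and lda are adjusted above. cblas_?gemv is
# called with CblasColMajor, and a row-major (C-contiguous) A is exactly the transpose
# of a column-major matrix built on the same buffer, so the transpose flag and the
# m/n arguments are adjusted to compensate. scipy's Fortran-order BLAS wrappers can be
# used to check the equivalence (scipy is only used here for illustration):
import numpy as np
from scipy.linalg.blas import dgemv   # column-major (Fortran) GEMV

A = np.random.rand(5, 3)              # row-major
x = np.random.rand(3)
# The same buffer, read column-major, is A.T; asking BLAS for the transpose of that
# recovers A @ x.
y = dgemv(1.0, np.asfortranarray(A.T), x, trans=1)
assert np.allclose(y, A @ x)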
def expansion(node: 'Gemv', state, sdfg, m=None, n=None, **kwargs): node.validate(sdfg, state) ((edge_a, outer_array_a, shape_a, strides_a), (edge_x, outer_array_x, shape_x, strides_x), (edge_y, outer_array_y, shape_y, strides_y)) = _get_matmul_operands(node, state, sdfg, name_lhs="_A", name_rhs="_x", name_out="_y") dtype_a = outer_array_a.dtype.type dtype = outer_array_x.dtype.base_type veclen = outer_array_x.dtype.veclen m = m or node.m n = n or node.n if m is None: m = shape_y[0] if n is None: n = shape_x[0] transA = node.transA if strides_a[0] == 1: transA = not transA lda = strides_a[1] elif strides_a[1] == 1: lda = strides_a[0] else: warnings.warn('Matrix must be contiguous in at least ' 'one dimension. Falling back to pure expansion.') return ExpandGemvPure.expansion(node, state, sdfg, m=m, n=n, **kwargs) trans = 'CUBLAS_OP_N' if transA else 'CUBLAS_OP_T' if not node.transA: m, n = n, m if veclen != 1: warnings.warn('Vector GEMV not supported, falling back to pure') return ExpandGemvPure.expansion(node, state, sdfg, m=m, n=n, **kwargs) func, ctype, runtimetype = blas_helpers.cublas_type_metadata(dtype) func += 'gemv' call_prefix = environments.cublas.cuBLAS.handle_setup_code(node) call_suffix = '' # Handle alpha / beta constants = { 1.0: f"__state->cublas_handle.Constants(__dace_cuda_device).{runtimetype}Pone()", 0.0: f"__state->cublas_handle.Constants(__dace_cuda_device).{runtimetype}Zero()", } if node.alpha not in constants or node.beta not in constants: # Deal with complex input constants if isinstance(node.alpha, complex): alpha = f'{dtype.ctype}({node.alpha.real}, {node.alpha.imag})' else: alpha = f'{dtype.ctype}({node.alpha})' if isinstance(node.beta, complex): beta = f'{dtype.ctype}({node.beta.real}, {node.beta.imag})' else: beta = f'{dtype.ctype}({node.beta})' # Set pointer mode to host call_prefix += f'''cublasSetPointerMode(__dace_cublas_handle, CUBLAS_POINTER_MODE_HOST); {dtype.ctype} alpha = {alpha}; {dtype.ctype} beta = {beta}; ''' call_suffix += ''' cublasSetPointerMode(__dace_cublas_handle, CUBLAS_POINTER_MODE_DEVICE); ''' alpha = f'({ctype} *)&alpha' beta = f'({ctype} *)&beta' else: alpha = constants[node.alpha] beta = constants[node.beta] code = (call_prefix + f""" cublas{func}(__dace_cublas_handle, {trans}, {m}, {n}, {alpha}, _A, {lda}, _x, {strides_x[0]}, {beta}, _y, {strides_y[0]}); """ + call_suffix) tasklet = dace.sdfg.nodes.Tasklet(node.name, node.in_connectors, node.out_connectors, code, language=dace.dtypes.Language.CPP) return tasklet
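# Illustrative sketch (hypothetical helper): how the alpha/beta arguments above are
# chosen. For the common (1, 0) pair the pre-allocated device-side constants are used
# and the pointer mode stays CUBLAS_POINTER_MODE_DEVICE; anything else brackets the
# call with a switch to host pointer mode, host-side scalars, and a switch back.
def pick_gemv_scalars(alpha, beta, runtimetype='Float'):
    constants = {
        1.0: f"__state->cublas_handle.Constants(__dace_cuda_device).{runtimetype}Pone()",
        0.0: f"__state->cublas_handle.Constants(__dace_cuda_device).{runtimetype}Zero()",
    }
    if alpha in constants and beta in constants:
        return constants[alpha], constants[beta], False   # device pointer mode
    return '&alpha', '&beta', True                         # host pointer mode needed

print(pick_gemv_scalars(1.0, 0.0))   # device-side constants, no pointer-mode switch
print(pick_gemv_scalars(2.5, 0.0))   # host-side scalars, pointer-mode switch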
def expansion(node, parent_state, parent_sdfg, **kwargs): node.validate(parent_sdfg, parent_state) sdfg = dace.SDFG(node.label + "_sdfg") ((edge_a, outer_array_a, shape_a, strides_a), (edge_x, outer_array_x, shape_x, strides_x), (edge_y, outer_array_y, shape_y, strides_y)) = _get_matmul_operands(node, parent_state, parent_sdfg, name_lhs="_A", name_rhs="_x", name_out="_y") dtype_a = outer_array_a.dtype.type dtype_x = outer_array_x.dtype.type dtype_y = outer_array_y.dtype.type if outer_array_a.dtype.veclen > 1 or outer_array_x.dtype.veclen > 1: raise NotImplementedError("Vectorization for pure GEMV NYI.") if node.transA: trans_shape_a = list(reversed(shape_a)) else: trans_shape_a = shape_a if trans_shape_a[1] != shape_x[0]: raise SyntaxError( "Matrix-vector product size mismatch: {} vs. {}".format( trans_shape_a[1], shape_x[0])) N, M = trans_shape_a[0], trans_shape_a[1] if outer_array_a.storage != outer_array_x.storage: raise ValueError("Input matrices must have same storage") storage = outer_array_a.storage _, array_a = sdfg.add_array("_A", shape_a, dtype_a, strides=strides_a, storage=storage) _, array_x = sdfg.add_array("_x", shape_x, dtype_x, strides=strides_x, storage=storage) _, array_y = sdfg.add_array("_y", shape_y, dtype_y, strides=strides_y, storage=storage) mul_program = "__out = {} * __A * __x".format(node.alpha) init_state = sdfg.add_state(node.label + "_initstate") state = sdfg.add_state_after(init_state, node.label + "_state") if node.beta == 0: mul_out, mul_out_array = "_y", array_y output_nodes = None else: mul_out, mul_out_array = tmp, array_tmp = sdfg.add_temp_transient( shape_y, dtype_y, storage=storage) access_tmp = state.add_read(tmp) output_nodes = {mul_out: access_tmp} # Initialization map init_state.add_mapped_tasklet( "gemv_init", { "_o%d" % i: "0:%s" % symbolic.symstr(d) for i, d in enumerate(shape_y) }, {}, "out = 0", { "out": dace.Memlet("{}[{}]".format( mul_out, ",".join( ["_o%d" % i for i in range(len(shape_y))]))) }, external_edges=True) # Multiplication map state.add_mapped_tasklet( "_GEMV_", {"__i%d" % i: "0:%s" % s for i, s in enumerate([N, M])}, { "__A": dace.Memlet("_A[{}]".format( "__i1, __i0" if node.transA else "__i0, __i1")), "__x": dace.Memlet("_x[__i1]") }, mul_program, { "__out": dace.Memlet(f"{mul_out}[__i0]", wcr="lambda x, y: x + y") }, external_edges=True, output_nodes=output_nodes) add_program = "__y_out = ({} * __y_in) + __tmp".format(node.beta) memlet_idx = "__i" # addition map if node.beta != 0: state.add_mapped_tasklet("_Add_", {"__i": "0:{}".format(N)}, { "__y_in": dace.Memlet(f"_y[{memlet_idx}]"), "__tmp": dace.Memlet(f"{mul_out}[__i]"), }, add_program, {"__y_out": dace.Memlet("_y[__i]")}, external_edges=True, input_nodes={mul_out: access_tmp}) return sdfg
def expansion(node, parent_state, parent_sdfg, tile_size_x=None, tile_size_y=None, num_partial_sums=16): """ :param node: Node to expand. :param parent_state: State that the node is in. :param parent_sdfg: SDFG that the node is in. :param tile_size_x: Tile size along the dimension of the vector x. If set to None, no tiling is used, corresponding to setting the tile size equal to the full size of x. :param tile_size_y: Tile size along the dimension of the vector y. If set to None, no tiling is used, corresponding to setting the tile size equal to the full size of y. :param num_partial_sums: The number of distinct registers to accumulate contributions to the final sum into. Should be a power of two, and should be higher than the latency of adding two numbers of the given data type. """ node.validate(parent_sdfg, parent_state) sdfg = dace.SDFG("gemv") state = sdfg.add_state("gemv") alpha = node.alpha beta = node.beta # Get input/output data (the method considers also the presence of view nodes) ((edge_a, desc_a, shape_a, strides_a), (edge_x, desc_x, shape_x, strides_x), (edge_y, desc_y, shape_y, strides_y)) = _get_matmul_operands(node, parent_state, parent_sdfg, name_lhs="_A", name_rhs="_x", name_out="_y") # Create local versions of input/output data nodes _, desc_a = sdfg.add_array("_A", shape_a, desc_a.dtype, strides=strides_a, storage=desc_a.storage, transient=False) _, desc_x = sdfg.add_array("_x", shape_x, desc_x.dtype, strides=strides_x, storage=desc_x.storage, transient=False) _, desc_y_y = sdfg.add_array("_y", shape_y, desc_y.dtype, strides=strides_y, storage=desc_y.storage, transient=False) if node.transA and desc_a.dtype.veclen > 1: raise NotImplementedError( "Vectorization not implemented for transposed A.") # Create accesses read_a = state.add_read("_A") read_x = state.add_read("_x") if beta != 0: read_y = state.add_read("_y") write_y = state.add_write("_y") size_x = desc_x.shape[0] size_y = desc_y.shape[0] if tile_size_x is None: tile_size_x = size_x if tile_size_y is None: tile_size_y = size_y num_tiles_y = f"{size_y}/{tile_size_y}" num_tiles_x = f"{size_x}/{tile_size_x}" veclen = desc_a.dtype.veclen # Create tile map y_tile_entry, y_tile_exit = state.add_map( "y_tiles", {"ty": f"0:{num_tiles_y}"}, schedule=dace.ScheduleType.FPGA_Device) x_tile_entry, x_tile_exit = state.add_map( "x_tiles", {"tx": f"0:{num_tiles_x}"}, schedule=dace.ScheduleType.FPGA_Device) # Create y map y_entry, y_exit = state.add_map("y", {"iy": f"0:{tile_size_y}"}, schedule=dace.ScheduleType.FPGA_Device) # Create x map x_entry, x_exit = state.add_map("x", {"ix": f"0:{tile_size_x}"}, schedule=dace.ScheduleType.FPGA_Device) # Local buffer of x sdfg.add_array("x_local", (tile_size_x, ), desc_x.dtype, storage=dace.StorageType.FPGA_Local, transient=True) x_local_access = state.add_read("x_local") if beta != 0: raise NotImplementedError("Not yet implemented.") multiply_tasklet = state.add_tasklet("multiply", {"A_in", "x_in"}, {f"product": desc_a.dtype}, "product = A_in * x_in") if isinstance(desc_a, dt.Stream): subset = "0" elif node.transA: subset = f"tx * {tile_size_x} + ix, ty * {tile_size_y} + iy" else: subset = f"ty * {tile_size_y} + iy, tx * {tile_size_x} + ix" state.add_memlet_path(read_a, y_tile_entry, x_tile_entry, y_entry, x_entry, multiply_tasklet, dst_conn="A_in", memlet=dace.Memlet(f"_A[{subset}]")) read_x_entry, read_x_exit = state.add_map( "read_x", {"ix": f"0:{tile_size_x}"}, schedule=dace.ScheduleType.FPGA_Device) subset = ("0" if isinstance(desc_x, dt.Stream) else f"tx*{tile_size_x} + ix") 
read_x_tasklet = state.add_tasklet("read_x", {"x_memory"}, {"x_buffer"}, "x_buffer = x_memory") state.add_memlet_path(read_x, y_tile_entry, x_tile_entry, read_x_entry, read_x_tasklet, dst_conn="x_memory", memlet=dace.Memlet(f"_x[{subset}]")) state.add_memlet_path(read_x_tasklet, read_x_exit, x_local_access, src_conn="x_buffer", memlet=dace.Memlet(f"x_local[ix]")) state.add_memlet_path(x_local_access, y_entry, x_entry, multiply_tasklet, dst_conn="x_in", memlet=dace.Memlet(f"x_local[ix]")) # Write to buffer sdfg.add_array("product_vector", (1, ), desc_a.dtype, transient=True, storage=dace.StorageType.FPGA_Local) product_vector = state.add_access("product_vector") state.add_memlet_path(multiply_tasklet, product_vector, src_conn="product", memlet=dace.Memlet(f"product_vector[0]")) # Vector length conversion sdfg.add_array("product_scalar", (veclen, ), desc_a.dtype.base_type, transient=True, storage=dace.StorageType.FPGA_Local) product_scalar = state.add_access("product_scalar") state.add_memlet_path(product_vector, product_scalar, memlet=dace.Memlet(f"product_vector[0]", other_subset=f"0:{veclen}")) # Now we need to collapse this reduce_vector_entry, reduce_vector_exit = state.add_map( "reduce_vector", {"u": f"0:{veclen}"}, schedule=dace.ScheduleType.FPGA_Device, unroll=True) reduce_vector_tasklet = state.add_tasklet( "reduce_vector", {"product_in", "acc_in"}, {"acc_out"}, "acc_out = product_in + acc_in") state.add_memlet_path(product_scalar, reduce_vector_entry, reduce_vector_tasklet, dst_conn="product_in", memlet=dace.Memlet(f"{product_scalar}[u]")) # Add accumulation register sdfg.add_array("accumulate_product", (1, ), desc_a.dtype.base_type, transient=True, storage=dace.StorageType.FPGA_Local) accumulate_product_read = state.add_access("accumulate_product") accumulate_product_write = state.add_access("accumulate_product") # Initialize it to zero init_reduce_vector_tasklet = state.add_tasklet("init_reduce_vector", {}, {"acc_out"}, "acc_out = 0") state.add_memlet_path(x_entry, init_reduce_vector_tasklet, memlet=dace.Memlet()) state.add_memlet_path(init_reduce_vector_tasklet, accumulate_product_read, src_conn="acc_out", memlet=dace.Memlet(f"accumulate_product[0]")) # Connect it to the tasklet state.add_memlet_path(accumulate_product_read, reduce_vector_entry, reduce_vector_tasklet, dst_conn="acc_in", memlet=dace.Memlet(f"accumulate_product[0]")) state.add_memlet_path(reduce_vector_tasklet, reduce_vector_exit, accumulate_product_write, src_conn="acc_out", memlet=dace.Memlet(f"accumulate_product[0]")) # Partial sums sdfg.add_array("partial_sums", (num_partial_sums, ), desc_y.dtype, storage=dace.StorageType.FPGA_Registers, transient=True) partial_sum_read = state.add_read("partial_sums") partial_sum_write = state.add_access("partial_sums") # Output array sdfg.add_array("y_local", (tile_size_y, ), desc_y.dtype, storage=dace.StorageType.FPGA_Local, transient=True) # Now we need to actually accumulate into a local register of y y_local_read = state.add_read("y_local") y_local_write = state.add_read("y_local") update_y_tasklet = state.add_tasklet( "update_y", {"y_in", "acc_in"}, {"acc_out"}, f"""\ prev = acc_in if ix >= {num_partial_sums} else 0 acc_out = prev + y_in""") state.add_memlet_path(accumulate_product_write, update_y_tasklet, dst_conn="y_in", memlet=dace.Memlet(f"accumulate_product[0]")) state.add_memlet_path( partial_sum_read, x_entry, update_y_tasklet, dst_conn="acc_in", memlet=dace.Memlet(f"partial_sums[ix%{num_partial_sums}]")) state.add_memlet_path(y_tile_entry, y_local_read, 
memlet=dace.Memlet()) state.add_memlet_path(y_entry, partial_sum_read, memlet=dace.Memlet()) state.add_memlet_path( update_y_tasklet, x_exit, partial_sum_write, src_conn="acc_out", memlet=dace.Memlet(f"partial_sums[ix%{num_partial_sums}]")) # Reduce the partial sums reduce_sums_entry, reduce_sums_exit = state.add_map( "reduce_partial_sums", {"u": f"0:{num_partial_sums}"}, schedule=dace.ScheduleType.FPGA_Device, unroll=True) reduce_sums_tasklet = state.add_tasklet( "reduce_partial_sums", {"sum_in", "val_in"}, {"sum_out"}, """ prev = sum_in if u > 0 else 0 sum_out = prev + val_in""") sdfg.add_array("accumulate_sum", (1, ), desc_y.dtype, transient=True, storage=dace.StorageType.FPGA_Local) accumulate_sum_read = state.add_access("accumulate_sum") accumulate_sum_write = state.add_access("accumulate_sum") state.add_memlet_path(y_entry, accumulate_sum_read, memlet=dace.Memlet()) state.add_memlet_path(accumulate_sum_read, reduce_sums_entry, reduce_sums_tasklet, dst_conn="sum_in", memlet=dace.Memlet("accumulate_sum[0]")) state.add_memlet_path(reduce_sums_tasklet, reduce_sums_exit, accumulate_sum_write, src_conn="sum_out", memlet=dace.Memlet("accumulate_sum[0]")) state.add_memlet_path(partial_sum_write, reduce_sums_entry, reduce_sums_tasklet, dst_conn="val_in", memlet=dace.Memlet("partial_sums[u]")) # Combine with y buffer combine_tasklet = state.add_tasklet( "combine_y", {"val", "buffer_in"}, {"buffer_out"}, """\ prev = buffer_in if tx > 0 else 0 buffer_out = prev + val""") state.add_memlet_path(accumulate_sum_write, combine_tasklet, dst_conn="val", memlet=dace.Memlet("accumulate_sum[0]")) state.add_memlet_path(y_local_read, x_tile_entry, y_entry, combine_tasklet, dst_conn="buffer_in", memlet=dace.Memlet("y_local[iy]")) state.add_memlet_path(combine_tasklet, y_exit, x_tile_exit, y_local_write, src_conn="buffer_out", memlet=dace.Memlet(f"y_local[iy]")) subset = ("0" if isinstance(desc_y, dt.Stream) else f"ty*{tile_size_y} + iy") write_y_entry, write_y_exit = state.add_map( "write_y", {"iy": f"0:{tile_size_y}"}, schedule=dace.ScheduleType.FPGA_Device) write_y_tasklet = state.add_tasklet("write_y", {"y_buffer"}, {"y_memory"}, "y_memory = y_buffer") state.add_memlet_path(y_local_write, write_y_entry, write_y_tasklet, dst_conn="y_buffer", memlet=dace.Memlet(f"y_local[iy]")) state.add_memlet_path(write_y_tasklet, write_y_exit, y_tile_exit, write_y, src_conn="y_memory", memlet=dace.Memlet(f"_y[{subset}]")) return sdfg
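# Illustrative sketch (hypothetical helper, plain Python model of the accumulation
# scheme above): contributions along x are accumulated into `num_partial_sums`
# independent registers (indexed ix % num_partial_sums) so consecutive iterations
# never update the same register, hiding the floating-point add latency; the partial
# sums are then reduced and combined with the running y tile across x tiles.
def tiled_gemv_row(A_row, x, tile_size_x=64, num_partial_sums=16):
    y = 0.0
    for tx in range(0, len(x), tile_size_x):
        partial = [0.0] * num_partial_sums
        for ix, xv in enumerate(x[tx:tx + tile_size_x]):
            partial[ix % num_partial_sums] += A_row[tx + ix] * xv
        y += sum(partial)                 # "reduce_partial_sums" + "combine_y"
    return y

assert abs(tiled_gemv_row([1.0] * 100, [2.0] * 100) - 200.0) < 1e-9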
def expansion(node: 'Gemv', state, sdfg, m=None, n=None, **kwargs): node.validate(sdfg, state) ((edge_a, outer_array_a, shape_a, strides_a), (edge_x, outer_array_x, shape_x, strides_x), (edge_y, outer_array_y, shape_y, strides_y)) = _get_matmul_operands(node, state, sdfg, name_lhs="_A", name_rhs="_x", name_out="_y") dtype_a = outer_array_a.dtype.type dtype = outer_array_x.dtype.base_type veclen = outer_array_x.dtype.veclen m = m or node.m n = n or node.n if m is None: m = shape_y[0] if n is None: n = shape_x[0] transA = node.transA if strides_a[0] == 1: transA = not transA lda = strides_a[1] elif strides_a[1] == 1: lda = strides_a[0] else: warnings.warn('Matrix must be contiguous in at least ' 'one dimension. Falling back to pure expansion.') return ExpandGemvPure.expansion(node, state, sdfg, m=m, n=n, **kwargs) trans = 'CUBLAS_OP_N' if transA else 'CUBLAS_OP_T' if not node.transA: m, n = n, m if veclen != 1: warnings.warn('Vector GEMV not supported, falling back to pure') return ExpandGemvPure.expansion(node, state, sdfg, m=m, n=n, **kwargs) func, ctype, runtimetype = blas_helpers.cublas_type_metadata(dtype) func += 'gemv' # TODO: (alpha,beta) != (1,0) if node.alpha != 1.0 or node.beta != 0.0: raise NotImplementedError alpha = ( '__state->cublas_handle.Constants(__dace_cuda_device).%sPone()' % runtimetype) beta = ( '__state->cublas_handle.Constants(__dace_cuda_device).%sZero()' % runtimetype) code = (environments.cublas.cuBLAS.handle_setup_code(node) + f""" cublas{func}(__dace_cublas_handle, {trans}, {m}, {n}, {alpha}, _A, {lda}, _x, {strides_x[0]}, {beta}, _y, {strides_y[0]});""") tasklet = dace.sdfg.nodes.Tasklet(node.name, node.in_connectors, node.out_connectors, code, language=dace.dtypes.Language.CPP) return tasklet
def make_sdfg(node, parent_state, parent_sdfg): sdfg = dace.SDFG(node.label + "_sdfg") ((edge_a, outer_array_a, shape_a, strides_a), (edge_x, outer_array_x, shape_x, _), _) = _get_matmul_operands(node, parent_state, parent_sdfg, name_lhs="_a", name_rhs="_x", name_out="_y") dtype_a = outer_array_a.dtype.type dtype_x = outer_array_x.dtype.type dtype_y = dace.DTYPE_TO_TYPECLASS[np.result_type(dtype_a, dtype_x).type] if node.transA: trans_shape_a = list(reversed(shape_a)) else: trans_shape_a = shape_a if trans_shape_a[1] != shape_x[0]: raise SyntaxError( "Matrix-vector product size mismatch: {} vs. {}".format( trans_shape_a[1], shape_x[0])) N, M = trans_shape_a[0], trans_shape_a[1] shape_y = (N, ) if outer_array_a.storage != outer_array_x.storage: raise ValueError("Input matrices must have same storage") storage = outer_array_a.storage _, array_a = sdfg.add_array("_a", shape_a, dtype_a, strides=strides_a, storage=storage) _, array_x = sdfg.add_array("_x", shape_x, dtype_x, storage=storage) _, array_y = sdfg.add_array("_y", shape_y, dtype_y, storage=storage) if node.alpha == 1.0: mul_program = "__out = __a * __x" else: mul_program = "__out = {} * __a * __x".format( _cast_to_dtype_str(node.alpha, dtype_a)) init_state = sdfg.add_state(node.label + "_initstate") state = sdfg.add_state_after(init_state, node.label + "_state") if node.beta == 0: mul_out, mul_out_array = "_y", array_y output_nodes = None else: mul_out, mul_out_array = tmp, array_tmp = sdfg.add_temp_transient( shape_y, dtype_y, storage=storage) access_tmp = state.add_read(tmp) output_nodes = {mul_out: access_tmp} # Initialization map init_state.add_mapped_tasklet( "gemv_init", {"_o%d" % i: "0:%s" % symstr(d) for i, d in enumerate(shape_y)}, {}, "out = 0", { "out": dace.Memlet.simple( mul_out, ",".join( ["_o%d" % i for i in range(len(shape_y))])) }, external_edges=True) # Multiplication map state.add_mapped_tasklet( "_GEMV_", {"__i%d" % i: "0:%s" % s for i, s in enumerate([N, M])}, { "__a": dace.Memlet.simple( "_a", "__i1, __i0" if node.transA else "__i0, __i1"), "__x": dace.Memlet.simple("_x", "__i1") }, mul_program, { "__out": dace.Memlet.simple( mul_out, "__i0", wcr_str="lambda x, y: x + y") }, external_edges=True, output_nodes=output_nodes) if node.beta != 0: add_program = "__y_out = ({} * __y_in) + __tmp".format( _cast_to_dtype_str(node.beta, dtype_a)) memlet_idx = "__i" # addition map state.add_mapped_tasklet( "_Add_", {"__i": "0:{}".format(N)}, { "__y_in": dace.Memlet.simple("_y", memlet_idx), "__tmp": dace.Memlet.simple(mul_out, "__i"), }, add_program, {"__y_out": dace.Memlet.simple("_y", "__i")}, external_edges=True, input_nodes={mul_out: access_tmp}) return sdfg
def make_sdfg(node, parent_state, parent_sdfg):
    sdfg = dace.SDFG(node.label + "_sdfg")

    ((edge_x, outer_array_x, shape_x, _),
     (edge_y, outer_array_y, shape_y, _),
     (_, outer_array_result, shape_result, _)) = _get_matmul_operands(
         node, parent_state, parent_sdfg, name_lhs="_x", name_rhs="_y", name_out="_result")

    dtype_x = outer_array_x.dtype.type
    dtype_y = outer_array_y.dtype.type
    dtype_result = outer_array_result.dtype.type

    if shape_x != shape_y or shape_result != [1]:
        raise SyntaxError("Invalid shapes to dot product.")

    N = shape_x[0]

    if outer_array_x.storage != outer_array_y.storage:
        raise ValueError("Input matrices must have same storage")
    storage = outer_array_x.storage

    _, array_x = sdfg.add_array("_x", shape_x, dtype_x, storage=storage)
    _, array_y = sdfg.add_array("_y", shape_y, dtype_y, storage=storage)
    _, array_result = sdfg.add_array("_result", [1], dtype_result, storage=storage)

    mul_program = "__out = __x * __y"

    init_state = sdfg.add_state(node.label + "_initstate")
    state = sdfg.add_state_after(init_state, node.label + "_state")

    mul_out, mul_out_array = "_result", array_result
    output_nodes = None

    # Initialization map
    init_write = init_state.add_write("_result")
    init_tasklet = init_state.add_tasklet("dot_init", {}, {"_out"}, "_out = 0",
                                          location=node.location)
    init_state.add_memlet_path(init_tasklet,
                               init_write,
                               src_conn="_out",
                               memlet=dace.Memlet.simple(init_write.data, "0", num_accesses=1))

    # Multiplication map
    state.add_mapped_tasklet(
        "_DOT_", {"__i": "0:{}".format(N)}, {
            "__x": dace.Memlet.simple("_x", "__i"),
            "__y": dace.Memlet.simple("_y", "__i")
        },
        mul_program, {
            "__out": dace.Memlet.simple(mul_out, "0", wcr_str="lambda x, y: x + y")
        },
        external_edges=True,
        output_nodes=output_nodes)

    return sdfg
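# Illustrative sketch: the dot_init tasklet plus the "_DOT_" map with
# wcr="lambda x, y: x + y" compute a plain dot product, i.e. the following reduction.
import numpy as np

x = np.random.rand(100)
y = np.random.rand(100)
result = np.zeros(1)                 # "_result", zeroed by dot_init
result[0] += np.sum(x * y)           # write-conflict-resolved accumulation
assert np.isclose(result[0], np.dot(x, y))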