def create_batch_gemm_sdfg(dtype, strides):
    """Construct an SDFG computing ``Z = X @ Y`` with one MatMul library node.

    :param dtype: Element type of the ``X``, ``Y`` and ``Z`` arrays.
    :param strides: Mapping that must provide the keys ``'BATCH'``, ``'sAM'``,
                    ``'sAK'``, ``'sAB'``, ``'sBK'``, ``'sBN'``, ``'sBB'``,
                    ``'sCM'``, ``'sCN'`` and ``'sCB'``. Entries for which
                    ``symbolic.issymbolic`` is true are replaced by
                    ``symbolic.symbol(<key>)``; other entries are used as-is.
    :return: The SDFG named ``'einsum'``. When ``strides['BATCH'] != 1`` each
             array gets a leading dimension of extent ``BATCH`` (with its own
             stride); otherwise ``X`` is M x K, ``Y`` is K x N and ``Z`` is
             M x N.
    """
    sdfg = SDFG('einsum')
    state = sdfg.add_state()

    M, K, N = (symbolic.symbol(name) for name in ['M', 'K', 'N'])

    # Symbolic strides are replaced by SDFG symbols named after their key;
    # concrete strides are embedded directly.
    stride_keys = [
        'BATCH', 'sAM', 'sAK', 'sAB', 'sBK', 'sBN', 'sBB', 'sCM', 'sCN', 'sCB'
    ]
    sym = {}
    for key in stride_keys:
        if symbolic.issymbolic(strides[key]):
            sym[key] = symbolic.symbol(key)
        else:
            sym[key] = strides[key]

    batched = strides['BATCH'] != 1

    def register(name, rows, cols, batch_stride, row_stride, col_stride):
        # Prepend the batch dimension (and its stride) only in batched mode.
        if batched:
            shape = [sym['BATCH'], rows, cols]
            layout = [batch_stride, row_stride, col_stride]
        else:
            shape = [rows, cols]
            layout = [row_stride, col_stride]
        _, desc = sdfg.add_array(name,
                                 dtype=dtype,
                                 shape=shape,
                                 strides=layout)
        return desc

    x_desc = register('X', M, K, sym['sAB'], sym['sAM'], sym['sAK'])
    y_desc = register('Y', K, N, sym['sBB'], sym['sBK'], sym['sBN'])
    z_desc = register('Z', M, N, sym['sCB'], sym['sCM'], sym['sCN'])

    read_x = state.add_read('X')
    read_y = state.add_read('Y')
    write_z = state.add_write('Z')

    import dace.libraries.blas as blas  # Avoid import loop
    gemm = blas.MatMul('einsum_gemm')
    state.add_node(gemm)

    # Wire X and Y into the MatMul operands and its result into Z; each
    # memlet is built from the corresponding array descriptor.
    state.add_edge(read_x, None, gemm, '_a',
                   Memlet.from_array(read_x.data, x_desc))
    state.add_edge(read_y, None, gemm, '_b',
                   Memlet.from_array(read_y.data, y_desc))
    state.add_edge(gemm, '_c', write_z, None,
                   Memlet.from_array(write_z.data, z_desc))
    return sdfg
class MatrixProductTranspose(transformation.Transformation):
    """ Implements the matrix-matrix product transpose transformation.

        T(A) @ T(B) = T(B @ A)

        Matches two ``Transpose`` library nodes whose outputs (through the
        access nodes ``_at`` and ``_bt``) feed the two operands of a
        ``MatMul`` node. The match is rewritten so that the original,
        untransposed inputs are multiplied in swapped order into a temporary,
        and a single ``Transpose`` node then writes the result to the
        product's original destinations.
    """

    _transpose_a = blas.Transpose("")
    _at = nodes.AccessNode("")
    _transpose_b = blas.Transpose("")
    _bt = nodes.AccessNode("")
    _a_times_b = blas.MatMul("")

    @staticmethod
    def expressions():
        """Return the single pattern graph
        ``_transpose_a -> _at -> _a_times_b <- _bt <- _transpose_b``."""
        graph = dace.sdfg.graph.OrderedDiGraph()
        graph.add_node(MatrixProductTranspose._transpose_a)
        graph.add_node(MatrixProductTranspose._at)
        graph.add_node(MatrixProductTranspose._transpose_b)
        graph.add_node(MatrixProductTranspose._bt)
        graph.add_node(MatrixProductTranspose._a_times_b)
        graph.add_edge(MatrixProductTranspose._transpose_a,
                       MatrixProductTranspose._at, None)
        graph.add_edge(MatrixProductTranspose._at,
                       MatrixProductTranspose._a_times_b, None)
        graph.add_edge(MatrixProductTranspose._transpose_b,
                       MatrixProductTranspose._bt, None)
        graph.add_edge(MatrixProductTranspose._bt,
                       MatrixProductTranspose._a_times_b, None)
        return [graph]

    @staticmethod
    def _is_removable(graph, sdfg, transpose, access):
        """Return True if ``transpose`` and its output ``access`` may be deleted.

        ``apply`` removes both nodes, which is only safe when each has a
        single outgoing edge (nothing else consumes them) and the
        intermediate array is a transient (not visible outside the SDFG).
        """
        return (graph.out_degree(transpose) == 1
                and graph.out_degree(access) == 1
                and sdfg.arrays[access.data].transient)

    @staticmethod
    def can_be_applied(graph, candidate, expr_index, sdfg, permissive=False):
        transpose_a = graph.nodes()[candidate[
            MatrixProductTranspose._transpose_a]]
        _at = graph.nodes()[candidate[MatrixProductTranspose._at]]
        transpose_b = graph.nodes()[candidate[
            MatrixProductTranspose._transpose_b]]
        _bt = graph.nodes()[candidate[MatrixProductTranspose._bt]]
        _a_times_b = graph.nodes()[candidate[
            MatrixProductTranspose._a_times_b]]
        edges = graph.edges_between(_at, _a_times_b)
        # Enforce unique match: T(A) must feed exactly the first operand, so
        # the symmetric match (A and B swapped) is rejected.
        if len(edges) != 1:
            return False
        _, _, _, dst_conn, _ = edges[0]
        if dst_conn != '_a':
            return False
        # apply() deletes both transposes and their output access nodes.
        for transpose, access in ((transpose_a, _at), (transpose_b, _bt)):
            if not MatrixProductTranspose._is_removable(
                    graph, sdfg, transpose, access):
                return False
        # apply() derives the temporary's shape from the product's output.
        if graph.out_degree(_a_times_b) == 0:
            return False
        return True

    @staticmethod
    def match_to_str(graph, candidate):
        transpose_a = graph.nodes()[candidate[
            MatrixProductTranspose._transpose_a]]
        transpose_b = graph.nodes()[candidate[
            MatrixProductTranspose._transpose_b]]
        a_times_b = graph.nodes()[candidate[MatrixProductTranspose._a_times_b]]
        return f"{transpose_a.name} -> {a_times_b.name} <- {transpose_b.name}"

    def apply(self, sdfg):
        graph = sdfg.nodes()[self.state_id]
        transpose_a = graph.nodes()[self.subgraph[
            MatrixProductTranspose._transpose_a]]
        _at = graph.nodes()[self.subgraph[MatrixProductTranspose._at]]
        transpose_b = graph.nodes()[self.subgraph[
            MatrixProductTranspose._transpose_b]]
        _bt = graph.nodes()[self.subgraph[MatrixProductTranspose._bt]]
        a_times_b = graph.nodes()[self.subgraph[
            MatrixProductTranspose._a_times_b]]

        # Work out the temporary's shape and dtype before mutating the graph,
        # so a malformed match fails without leaving the state half-rewritten.
        out_edges = list(graph.out_edges(a_times_b))
        if not out_edges:
            raise ValueError(
                f'{a_times_b.name} has no outgoing edges; cannot infer the '
                'shape of the transposed product')
        _, _, _, _, first_memlet = out_edges[0]
        subset = dcpy(first_memlet.subset)
        subset.squeeze()
        size = subset.size()
        # The squeezed output subset is taken as the shape of T(A) @ T(B);
        # B @ A has the transposed shape.
        # NOTE(review): assumes the squeezed subset is exactly 2D — a unit
        # M or N dimension would be squeezed away; confirm this is excluded.
        shape = [size[1], size[0]]
        # Fall back to the destination array's dtype when the MatMul node
        # carries none (e.g. a MatMul constructed with only a name).
        dtype = a_times_b.dtype
        if dtype is None:
            dtype = sdfg.arrays[first_memlet.data].dtype

        # Route T(A)'s input into the second operand and T(B)'s input into
        # the first, turning the product into B @ A.
        for src, src_conn, _, _, memlet in graph.in_edges(transpose_a):
            graph.add_edge(src, src_conn, a_times_b, '_b', memlet)
        graph.remove_node(transpose_a)
        for src, src_conn, _, _, memlet in graph.in_edges(transpose_b):
            graph.add_edge(src, src_conn, a_times_b, '_a', memlet)
        graph.remove_node(transpose_b)
        graph.remove_node(_at)
        graph.remove_node(_bt)

        # Store B @ A in a new temporary transient and transpose it into the
        # product's original destinations.
        tmp_name, tmp_arr = sdfg.add_temp_transient(shape, dtype)
        tmp_acc = graph.add_access(tmp_name)
        transpose_c = blas.Transpose('_Transpose_', dtype)
        graph.add_node(transpose_c)
        # Copy the edge list: edges are removed while iterating.
        for edge in list(graph.out_edges(a_times_b)):
            _, _, dst, dst_conn, memlet = edge
            graph.remove_edge(edge)
            graph.add_edge(transpose_c, '_out', dst, dst_conn, memlet)
        graph.add_edge(a_times_b, '_c', tmp_acc, None,
                       dace.Memlet.from_array(tmp_name, tmp_arr))
        graph.add_edge(tmp_acc, None, transpose_c, '_inp',
                       dace.Memlet.from_array(tmp_name, tmp_arr))