Exemplo n.º 1
0
def create_batch_gemm_sdfg(dtype, strides):
    """Create an SDFG named ``einsum`` holding one ``blas.MatMul`` node.

    Arrays ``X``, ``Y`` and ``Z`` are registered on the SDFG. ``X`` and ``Y``
    are read into the ``_a`` and ``_b`` connectors of the library node, and
    its ``_c`` connector is written to ``Z``.

    :param dtype: Element type passed to every ``add_array`` call.
    :param strides: Mapping indexed by ``'BATCH'``, ``'sAM'``, ``'sAK'``,
                    ``'sAB'``, ``'sBK'``, ``'sBN'``, ``'sBB'``, ``'sCM'``,
                    ``'sCN'`` and ``'sCB'``. An entry for which
                    ``symbolic.issymbolic`` is true is replaced by a symbol
                    named after its key; any other entry is used as given.
                    When ``strides['BATCH'] != 1`` the three arrays get a
                    leading ``BATCH`` dimension.
    :return: The constructed SDFG.
    """
    sdfg = SDFG('einsum')
    state = sdfg.add_state()

    # The matrix extents are always symbols.
    M = symbolic.symbol('M')
    K = symbolic.symbol('K')
    N = symbolic.symbol('N')

    # Resolve the batch size and every stride: symbolic inputs are replaced
    # by a symbol carrying the key's name, everything else passes through.
    stride_keys = ('BATCH', 'sAM', 'sAK', 'sAB', 'sBK', 'sBN', 'sBB', 'sCM',
                   'sCN', 'sCB')
    resolved = {
        key: (symbolic.symbol(key)
              if symbolic.issymbolic(strides[key]) else strides[key])
        for key in stride_keys
    }

    batched = strides['BATCH'] != 1

    # Pick the (shape, strides) layout of each array once, up front.
    if batched:
        x_shape = [resolved['BATCH'], M, K]
        x_strides = [resolved['sAB'], resolved['sAM'], resolved['sAK']]
        y_shape = [resolved['BATCH'], K, N]
        y_strides = [resolved['sBB'], resolved['sBK'], resolved['sBN']]
        z_shape = [resolved['BATCH'], M, N]
        z_strides = [resolved['sCB'], resolved['sCM'], resolved['sCN']]
    else:
        x_shape = [M, K]
        x_strides = [resolved['sAM'], resolved['sAK']]
        y_shape = [K, N]
        y_strides = [resolved['sBK'], resolved['sBN']]
        z_shape = [M, N]
        z_strides = [resolved['sCM'], resolved['sCN']]

    _, x_desc = sdfg.add_array('X',
                               dtype=dtype,
                               shape=x_shape,
                               strides=x_strides)
    _, y_desc = sdfg.add_array('Y',
                               dtype=dtype,
                               shape=y_shape,
                               strides=y_strides)
    _, z_desc = sdfg.add_array('Z',
                               dtype=dtype,
                               shape=z_shape,
                               strides=z_strides)

    x_read = state.add_read('X')
    y_read = state.add_read('Y')
    z_write = state.add_write('Z')

    import dace.libraries.blas as blas  # Avoid import loop

    matmul = blas.MatMul('einsum_gemm')
    state.add_node(matmul)

    # Both operands flow into the library node; each memlet covers the
    # whole array descriptor.
    for reader, connector, desc in ((x_read, '_a', x_desc),
                                    (y_read, '_b', y_desc)):
        state.add_edge(reader, None, matmul, connector,
                       Memlet.from_array(reader.data, desc))
    state.add_edge(matmul, '_c', z_write, None,
                   Memlet.from_array(z_write.data, z_desc))

    return sdfg
Exemplo n.º 2
0
class MatrixProductTranspose(transformation.Transformation):
    """ Rewrites a product of two transposed matrices using the identity

        T(A) @ T(B) = T(B @ A)

        The matched pattern is two ``blas.Transpose`` nodes, each feeding an
        access node, with both access nodes feeding one ``blas.MatMul`` node.
        Applying the transformation removes both transposes and their access
        nodes, connects the transposes' inputs to the opposite operand of the
        product, and routes the product's output through a new transient and
        a new ``blas.Transpose`` node.
    """

    # Pattern nodes; they serve as keys into the candidate/subgraph mappings.
    _transpose_a = blas.Transpose("")
    _at = nodes.AccessNode("")
    _transpose_b = blas.Transpose("")
    _bt = nodes.AccessNode("")
    _a_times_b = blas.MatMul("")

    @staticmethod
    def expressions():
        """ Returns a one-element list holding the pattern graph:
            transpose -> access node -> product, once per operand.
        """
        cls = MatrixProductTranspose
        pattern = dace.sdfg.graph.OrderedDiGraph()
        for node in (cls._transpose_a, cls._at, cls._transpose_b, cls._bt,
                     cls._a_times_b):
            pattern.add_node(node)
        for src, dst in ((cls._transpose_a, cls._at),
                         (cls._at, cls._a_times_b),
                         (cls._transpose_b, cls._bt),
                         (cls._bt, cls._a_times_b)):
            pattern.add_edge(src, dst, None)
        return [pattern]

    @staticmethod
    def can_be_applied(graph, candidate, expr_index, sdfg, permissive=False):
        """ Accepts the candidate only when exactly one edge links the node
            matched as ``_at`` to the product and that edge enters the
            product's ``_a`` connector.
        """
        matched = graph.nodes()
        access_at = matched[candidate[MatrixProductTranspose._at]]
        product = matched[candidate[MatrixProductTranspose._a_times_b]]
        connecting = graph.edges_between(access_at, product)
        # Enforce unique match
        if len(connecting) == 1:
            _, _, _, target_conn, _ = connecting[0]
            return target_conn == '_a'
        return False

    @staticmethod
    def match_to_str(graph, candidate):
        """ Returns a one-line description of the match built from the names
            of the two transposes and the product.
        """

        def matched(pattern_node):
            return graph.nodes()[candidate[pattern_node]]

        transpose_a = matched(MatrixProductTranspose._transpose_a)
        transpose_b = matched(MatrixProductTranspose._transpose_b)
        a_times_b = matched(MatrixProductTranspose._a_times_b)
        return f"{transpose_a.name} -> {a_times_b.name} <- {transpose_b.name}"

    def apply(self, sdfg):
        """ Performs the rewrite on the state selected by ``self.state_id``,
            using the node indices stored in ``self.subgraph``.
        """
        state = sdfg.nodes()[self.state_id]
        # All lookups happen on this listing before the state is modified.
        state_nodes = state.nodes()

        def matched(pattern_node):
            return state_nodes[self.subgraph[pattern_node]]

        pattern = MatrixProductTranspose
        first_transpose = matched(pattern._transpose_a)
        first_access = matched(pattern._at)
        second_transpose = matched(pattern._transpose_b)
        second_access = matched(pattern._bt)
        product = matched(pattern._a_times_b)

        # Feed each transpose's inputs straight into the product, on the
        # opposite operand connector, then drop the transpose itself.
        for transpose_node, operand_conn in ((first_transpose, '_b'),
                                             (second_transpose, '_a')):
            for src, src_conn, _, _, in_memlet in state.in_edges(
                    transpose_node):
                state.add_edge(src, src_conn, product, operand_conn,
                               in_memlet)
            state.remove_node(transpose_node)
        state.remove_node(first_access)
        state.remove_node(second_access)

        # The temporary takes the first two extents of the first output
        # memlet (after squeeze()) in swapped order; only the first
        # out-edge is consulted.
        for _, _, _, _, out_memlet in state.out_edges(product):
            out_subset = dcpy(out_memlet.subset)
            out_subset.squeeze()
            extents = out_subset.size()
            shape = [extents[1], extents[0]]
            break

        tmp_name, tmp_arr = sdfg.add_temp_transient(shape, product.dtype)
        tmp_acc = state.add_access(tmp_name)
        transpose_c = blas.Transpose('_Transpose_', product.dtype)

        # Former consumers of the product now read from the new transpose.
        for edge in state.out_edges(product):
            _, _, dst, dst_conn, memlet = edge
            state.remove_edge(edge)
            state.add_edge(transpose_c, '_out', dst, dst_conn, memlet)

        # product -> temporary -> new transpose, each edge with its own
        # memlet covering the whole temporary.
        state.add_edge(product, '_c', tmp_acc, None,
                       dace.Memlet.from_array(tmp_name, tmp_arr))
        state.add_edge(tmp_acc, None, transpose_c, '_inp',
                       dace.Memlet.from_array(tmp_name, tmp_arr))