示例#1
0
def nccl_send(pv: 'ProgramVisitor',
              sdfg: SDFG,
              state: SDFGState,
              in_buffer: str,
              peer: symbolic.SymbolicType = 0,
              group_handle: str = None):
    """Add an NCCL ``Send`` library node reading ``in_buffer`` to ``state``.

    :param pv: Program visitor, forwarded to ``_define_local_scalar`` when a
               new group-handle scalar has to be created.
    :param sdfg: SDFG owning ``state`` and the referenced arrays.
    :param state: State the library node and its edges are added to.
    :param in_buffer: Name of the array to send, or a ``(name, range)`` tuple
                      to send only that subset of the array.
    :param peer: Forwarded as ``peer`` to the ``Send`` node.
    :param group_handle: Optional group-handle name. If it names an existing
                         array, the node reads and writes it; otherwise a new
                         int32 ``GPU_Global`` scalar is defined and the node
                         only writes it.
    :return: An empty list.
    """
    inputs = {"_inbuffer"}
    outputs = set()

    if isinstance(group_handle, str):
        gh_start = False
        if group_handle in sdfg.arrays.keys():
            # Existing handle: read it in and write it back out.
            gh_name = group_handle
            gh_out = state.add_access(gh_name)
            gh_in = state.add_access(gh_name)
            inputs.add("_group_handle")
        else:
            # Unknown name: start a new handle that the node only writes.
            gh_start = True
            gh_name = _define_local_scalar(pv, sdfg, state, dace.int32,
                                           dtypes.StorageType.GPU_Global)
            gh_out = state.add_access(gh_name)
        outputs.add("_group_handle")

    libnode = Send(inputs=inputs, outputs=outputs, peer=peer)

    if isinstance(group_handle, str):
        gh_memlet = Memlet.simple(gh_name, '0')
        if not gh_start:
            state.add_edge(gh_in, None, libnode, "_group_handle", gh_memlet)
        state.add_edge(libnode, "_group_handle", gh_out, None, gh_memlet)

    in_range = None
    if isinstance(in_buffer, tuple):
        in_name, in_range = in_buffer
    else:
        in_name = in_buffer

    desc = sdfg.arrays[in_name]
    # Type the data connector as a pointer to the buffer's element type.
    # The data input connector of this node is '_inbuffer' (see ``inputs``
    # above), so that is the name that must be matched here.
    libnode.in_connectors = {
        c: (dtypes.pointer(desc.dtype) if c == '_inbuffer' else t)
        for c, t in libnode.in_connectors.items()
    }
    in_node = state.add_read(in_name)

    if in_range:
        buf_mem = Memlet.simple(in_name, in_range)
    else:
        buf_mem = Memlet.from_array(in_name, desc)

    state.add_edge(in_node, None, libnode, '_inbuffer', buf_mem)

    return []
示例#2
0
文件: recv.py 项目: thobauma/dace
def nccl_recv(pv: 'ProgramVisitor',
              sdfg: SDFG,
              state: SDFGState,
              out_buffer: str,
              peer: symbolic.SymbolicType = 0,
              group_handle: str = None):
    """Add an NCCL ``Recv`` library node writing into ``out_buffer``.

    :param pv: Program visitor, forwarded to ``_define_local_scalar`` when a
               new group-handle scalar has to be created.
    :param sdfg: SDFG owning ``state`` and the referenced arrays.
    :param state: State the library node and its edges are added to.
    :param out_buffer: Name of an array in ``sdfg`` that receives the data,
                       or a ``(name, range)`` tuple to receive into that
                       subset only.
    :param peer: Forwarded as ``peer`` to the ``Recv`` node.
    :param group_handle: Optional group-handle name. If it names an existing
                         array, the node reads and writes it; otherwise a new
                         int32 ``GPU_Global`` scalar is defined and the node
                         only writes it.
    :return: An empty list.
    :raises ValueError: if ``out_buffer`` is neither a tuple nor the name of
                        an array in ``sdfg``.
    """
    # Validate the output buffer before anything is added to the state.
    out_range = None
    if isinstance(out_buffer, tuple):
        out_name, out_range = out_buffer
    elif isinstance(out_buffer, str) and out_buffer in sdfg.arrays.keys():
        out_name = out_buffer
    else:
        raise ValueError(
            "NCCL_Recv out_buffer must be an array, or an array range tuple.")

    inputs = set()
    outputs = {"_outbuffer"}

    if isinstance(group_handle, str):
        gh_start = False
        if group_handle in sdfg.arrays.keys():
            # Existing handle: read it in and write it back out.
            gh_name = group_handle
            gh_out = state.add_access(gh_name)
            gh_in = state.add_access(gh_name)
            inputs.add("_group_handle")
        else:
            # Unknown name: start a new handle that the node only writes.
            gh_start = True
            gh_name = _define_local_scalar(pv, sdfg, state, dace.int32,
                                           dtypes.StorageType.GPU_Global)
            gh_out = state.add_access(gh_name)
        outputs.add("_group_handle")

    libnode = Recv(inputs=inputs, outputs=outputs, peer=peer)

    if isinstance(group_handle, str):
        gh_memlet = Memlet.simple(gh_name, '0')
        if not gh_start:
            state.add_edge(gh_in, None, libnode, "_group_handle", gh_memlet)
        state.add_edge(libnode, "_group_handle", gh_out, None, gh_memlet)

    out_node = state.add_write(out_name)

    if out_range:
        out_mem = Memlet.simple(out_name, out_range)
    else:
        # No range given: the node writes the whole array (as nccl_send reads
        # the whole array), not only its first element.
        out_mem = Memlet.from_array(out_name, sdfg.arrays[out_name])

    state.add_edge(libnode, '_outbuffer', out_node, None, out_mem)

    return []
示例#3
0
文件: reduce.py 项目: thobauma/dace
def nccl_reduce(pv: 'ProgramVisitor',
                sdfg: SDFG,
                state: SDFGState,
                redfunction: Callable[[Any, Any], Any],
                in_buffer: str,
                out_buffer: Union[str, None] = None,
                root: str = None,
                group_handle: str = None):
    """Add an NCCL ``Reduce`` library node from ``in_buffer`` to ``out_buffer``.

    :param pv: Program visitor, forwarded to ``_define_local_scalar`` when a
               new group-handle scalar has to be created.
    :param sdfg: SDFG owning ``state`` and the referenced arrays.
    :param state: State the library node and its edges are added to.
    :param redfunction: Passed as ``wcr`` to the ``Reduce`` node.
    :param in_buffer: Name of the input array.
    :param out_buffer: Name of the output array; ``None`` reduces in place
                       into ``in_buffer``.
    :param root: Passed as ``root`` to the ``Reduce`` node.
    :param group_handle: Optional group-handle name. If it names an existing
                         array, the node reads and writes it; otherwise a new
                         int32 ``GPU_Global`` scalar is defined and the node
                         only writes it.
    :return: An empty list.
    """
    has_handle = isinstance(group_handle, str)
    handle_exists = has_handle and group_handle in sdfg.arrays.keys()

    in_conns = {"_inbuffer"}
    out_conns = {"_outbuffer"}

    if handle_exists:
        handle = group_handle
        handle_dst = state.add_access(handle)
        handle_src = state.add_access(handle)
        in_conns.add("_group_handle")
        out_conns.add("_group_handle")
    elif has_handle:
        handle = _define_local_scalar(pv, sdfg, state, dace.int32,
                                      dtypes.StorageType.GPU_Global)
        handle_dst = state.add_access(handle)
        out_conns.add("_group_handle")

    node = Reduce(inputs=in_conns,
                  outputs=out_conns,
                  wcr=redfunction,
                  root=root)

    if has_handle:
        handle_memlet = Memlet.simple(handle, '0')
        if handle_exists:
            state.add_edge(handle_src, None, node, "_group_handle",
                           handle_memlet)
        state.add_edge(node, "_group_handle", handle_dst, None, handle_memlet)

    # No explicit output means an in-place reduction.
    target = in_buffer if out_buffer is None else out_buffer

    source_access = state.add_read(in_buffer)
    target_access = state.add_write(target)

    state.add_edge(source_access, None, node, '_inbuffer', Memlet(in_buffer))
    state.add_edge(node, '_outbuffer', target_access, None, Memlet(target))

    return []
示例#4
0
文件: distr.py 项目: am-ivanov/dace
def _Reduce(pv: 'ProgramVisitor',
            sdfg: SDFG,
            state: SDFGState,
            buffer: str,
            op: str,
            root: Union[str, sp.Expr, Number] = 0,
            grid: str = None):
    """Add an MPI ``Reduce`` library node that reads and writes ``buffer``.

    :param pv: Program visitor, forwarded to ``_define_local_scalar``.
    :param sdfg: SDFG owning ``state`` and ``buffer``.
    :param state: State the node and its edges are added to.
    :param buffer: Name of the array used as both input and output.
    :param op: Reduction operation, passed to the ``Reduce`` node.
    :param root: Name of an array in ``sdfg`` holding the root, or a value
                 written into a new int32 scalar (same storage as ``buffer``)
                 by a ``_set_root_`` tasklet.
    :param grid: Optional grid argument, passed to the ``Reduce`` node.
    :return: None.
    """
    from dace.libraries.mpi.nodes.reduce import Reduce

    node = Reduce('_Reduce_', op, grid)
    buf_desc = sdfg.arrays[buffer]
    reader = state.add_read(buffer)
    writer = state.add_write(buffer)

    if not (isinstance(root, str) and root in sdfg.arrays.keys()):
        # Materialize a non-array root into a fresh scalar.
        root_scalar = _define_local_scalar(pv, sdfg, state, dace.int32,
                                           buf_desc.storage)
        root_access = state.add_access(root_scalar)
        setter = state.add_tasklet('_set_root_', {}, {'__out'},
                                   f'__out = {root}')
        state.add_edge(setter, '__out', root_access, None,
                       Memlet.simple(root_scalar, '0'))
    else:
        root_access = state.add_read(root)

    state.add_edge(reader, None, node, '_inbuffer',
                   Memlet.from_array(buffer, buf_desc))
    state.add_edge(root_access, None, node, '_root',
                   Memlet.simple(root_access.data, '0'))
    state.add_edge(node, '_outbuffer', writer, None,
                   Memlet.from_array(buffer, buf_desc))

    return None
示例#5
0
文件: distr.py 项目: am-ivanov/dace
def _wait(pv: 'ProgramVisitor', sdfg: SDFG, state: SDFGState, request: str):
    """Add an MPI ``Wait`` library node on ``request``.

    Two new single-element int32 transients receive the ``_stat_source``
    and ``_stat_tag`` outputs of the node.

    :param pv: Program visitor (unused here).
    :param sdfg: SDFG owning ``state`` and ``request``.
    :param state: State the node and its edges are added to.
    :param request: Request array name, or a ``(name, range)`` tuple.
    :return: None.
    """
    from dace.libraries.mpi.nodes.wait import Wait

    node = Wait('_Wait_')

    if isinstance(request, tuple):
        req_name, req_subset = request
    else:
        req_name, req_subset = request, None

    req_desc = sdfg.arrays[req_name]
    req_access = state.add_access(req_name)

    status_source = sdfg.add_temp_transient([1], dtypes.int32)
    source_access = state.add_write(status_source[0])
    status_tag = sdfg.add_temp_transient([1], dtypes.int32)
    tag_access = state.add_write(status_tag[0])

    req_memlet = (Memlet.simple(req_name, req_subset) if req_subset else
                  Memlet.from_array(req_name, req_desc))

    state.add_edge(req_access, None, node, '_request', req_memlet)
    state.add_edge(node, '_stat_source', source_access, None,
                   Memlet.from_array(*status_source))
    state.add_edge(node, '_stat_tag', tag_access, None,
                   Memlet.from_array(*status_tag))

    return None
示例#6
0
文件: distr.py 项目: am-ivanov/dace
def _gather(pv: 'ProgramVisitor',
            sdfg: SDFG,
            state: SDFGState,
            in_buffer: str,
            out_buffer: str,
            root: Union[str, sp.Expr, Number] = 0):
    """Add an MPI ``Gather`` library node from ``in_buffer`` to ``out_buffer``.

    :param pv: Program visitor, forwarded to ``_define_local_scalar``.
    :param sdfg: SDFG owning ``state`` and the referenced arrays.
    :param state: State the node and its edges are added to.
    :param in_buffer: Name of the input array (read in full).
    :param out_buffer: Name of the output array (written in full).
    :param root: Name of an array in ``sdfg`` holding the root, or a value
                 written into a new int32 scalar (same storage as
                 ``in_buffer``) by a ``_set_root_`` tasklet.
    :return: None.
    """
    from dace.libraries.mpi.nodes.gather import Gather

    node = Gather('_Gather_')
    src_desc = sdfg.arrays[in_buffer]
    dst_desc = sdfg.arrays[out_buffer]
    src_access = state.add_read(in_buffer)
    dst_access = state.add_write(out_buffer)

    if not (isinstance(root, str) and root in sdfg.arrays.keys()):
        # Materialize a non-array root into a fresh scalar.
        root_scalar = _define_local_scalar(pv, sdfg, state, dace.int32,
                                           src_desc.storage)
        root_access = state.add_access(root_scalar)
        setter = state.add_tasklet('_set_root_', {}, {'__out'},
                                   f'__out = {root}')
        state.add_edge(setter, '__out', root_access, None,
                       Memlet.simple(root_scalar, '0'))
    else:
        root_access = state.add_read(root)

    state.add_edge(src_access, None, node, '_inbuffer',
                   Memlet.from_array(in_buffer, src_desc))
    state.add_edge(root_access, None, node, '_root',
                   Memlet.simple(root_access.data, '0'))
    state.add_edge(node, '_outbuffer', dst_access, None,
                   Memlet.from_array(out_buffer, dst_desc))

    return None
示例#7
0
文件: distr.py 项目: am-ivanov/dace
def _wait(pv: 'ProgramVisitor', sdfg: SDFG, state: SDFGState, request: str):
    """Add an MPI ``Waitall`` library node that reads ``request``.

    NOTE(review): this function is also named ``_wait``, like the ``Wait``
    variant; if both are defined in one module, the later definition
    replaces the earlier module attribute — confirm they are registered
    (e.g. by a decorator) before that happens.

    :param pv: Program visitor (unused here).
    :param sdfg: SDFG owning ``state`` and ``request``.
    :param state: State the node and its edges are added to.
    :param request: Request array name, or a ``(name, range)`` tuple.
    :return: None.
    """
    from dace.libraries.mpi.nodes.wait import Waitall

    node = Waitall('_Waitall_')

    req_name, req_subset = (request if isinstance(request, tuple) else
                            (request, None))
    req_desc = sdfg.arrays[req_name]
    req_access = state.add_access(req_name)

    if req_subset:
        req_memlet = Memlet.simple(req_name, req_subset)
    else:
        req_memlet = Memlet.from_array(req_name, req_desc)

    state.add_edge(req_access, None, node, '_request', req_memlet)

    return None
示例#8
0
文件: distr.py 项目: am-ivanov/dace
def _bcgather(pv: 'ProgramVisitor', sdfg: SDFG, state: SDFGState,
              in_buffer: str, out_buffer: str,
              block_sizes: Union[str, Sequence[Union[sp.Expr, Number]]]):
    """Add a PBLAS ``BlockCyclicGather`` library node.

    :param pv: Program visitor (unused here).
    :param sdfg: SDFG owning ``state`` and the referenced arrays.
    :param state: State the node and its edges are added to.
    :param in_buffer: Input array name, or a ``(name, range)`` tuple.
    :param out_buffer: Output array name, or a ``(name, range)`` tuple.
    :param block_sizes: An array name, a list/tuple whose first element is an
                        array name and second a range, or a sequence of
                        sizes; a sequence of sizes is written into a new
                        int32 transient by a ``_set_bsizes_`` tasklet.
    :return: None.
    """
    from dace.libraries.pblas.nodes.pgeadd import BlockCyclicGather

    def as_name_and_range(spec):
        # Tuples are (name, range) pairs; anything else is a bare name.
        return spec if isinstance(spec, tuple) else (spec, None)

    def memlet_for(name, subset, data_desc):
        if subset:
            return Memlet.simple(name, subset)
        return Memlet.from_array(name, data_desc)

    node = BlockCyclicGather('_BCGather_')

    src_name, src_subset = as_name_and_range(in_buffer)
    src_desc = sdfg.arrays[src_name]
    src_access = state.add_read(src_name)

    bs_subset = None
    if not isinstance(block_sizes, (list, tuple)):
        bs_name = block_sizes
        bs_desc = sdfg.arrays[bs_name]
        bs_access = state.add_read(bs_name)
    elif isinstance(block_sizes[0], str):
        bs_name, bs_subset = block_sizes
        bs_desc = sdfg.arrays[bs_name]
        bs_access = state.add_read(bs_name)
    else:
        bs_name, bs_desc = sdfg.add_temp_transient(
            (len(block_sizes), ), dtype=dace.int32)
        bs_access = state.add_access(bs_name)
        fill_code = ";".join(
            f"__out[{i}] = {sz}" for i, sz in enumerate(block_sizes))
        filler = state.add_tasklet('_set_bsizes_', {}, {'__out'}, fill_code)
        state.add_edge(filler, '__out', bs_access, None,
                       Memlet.from_array(bs_name, bs_desc))

    dst_name, dst_subset = as_name_and_range(out_buffer)
    dst_desc = sdfg.arrays[dst_name]
    dst_access = state.add_write(dst_name)

    src_memlet = memlet_for(src_name, src_subset, src_desc)
    bs_memlet = memlet_for(bs_name, bs_subset, bs_desc)
    dst_memlet = memlet_for(dst_name, dst_subset, dst_desc)

    state.add_edge(src_access, None, node, '_inbuffer', src_memlet)
    state.add_edge(bs_access, None, node, '_block_sizes', bs_memlet)
    state.add_edge(node, '_outbuffer', dst_access, None, dst_memlet)

    return None
示例#9
0
文件: distr.py 项目: am-ivanov/dace
def _irecv(pv: 'ProgramVisitor', sdfg: SDFG, state: SDFGState, buffer: str,
           src: Union[str, sp.Expr, Number], tag: Union[str, sp.Expr,
                                                        Number], request: str):
    """Add an MPI ``Irecv`` library node to ``state``.

    The node writes ``buffer`` (``_buffer``) and ``request`` (``_request``)
    and reads ``src`` (``_src``) and ``tag`` (``_tag``).

    :param pv: Program visitor, forwarded to ``_define_local_scalar``.
    :param sdfg: SDFG owning ``state`` and the referenced arrays.
    :param state: State the node and its edges are added to.
    :param buffer: Receive buffer name, or a ``(name, range)`` tuple.
    :param src: A ``(name, range)`` tuple, the name of an array in ``sdfg``,
                or a value written into a new int32 scalar (same storage as
                ``buffer``) by a ``_set_src_`` tasklet.
    :param tag: Accepted in the same forms as ``src`` (tasklet ``_set_tag_``).
    :param request: Request array name, or a ``(name, range)`` tuple.
    :return: None.
    """
    from dace.libraries.mpi.nodes.irecv import Irecv

    libnode = Irecv('_Irecv_')

    buf_range = None
    if isinstance(buffer, tuple):
        buf_name, buf_range = buffer
    else:
        buf_name = buffer
    desc = sdfg.arrays[buf_name]
    # The node writes the received data into this buffer.
    buf_node = state.add_write(buf_name)

    req_range = None
    if isinstance(request, tuple):
        req_name, req_range = request
    else:
        req_name = request
    req_desc = sdfg.arrays[req_name]
    req_node = state.add_write(req_name)

    # Both outputs are passed as pointers to their element types.
    conn = libnode.out_connectors
    conn = {
        c: (dtypes.pointer(desc.dtype) if c == '_buffer' else t)
        for c, t in conn.items()
    }
    conn = {
        c: (dtypes.pointer(req_desc.dtype) if c == '_request' else t)
        for c, t in conn.items()
    }
    libnode.out_connectors = conn

    src_range = None
    if isinstance(src, tuple):
        src_name, src_range = src
        src_node = state.add_read(src_name)
    elif isinstance(src, str) and src in sdfg.arrays.keys():
        src_name = src
        src_node = state.add_read(src_name)
    else:
        storage = desc.storage
        src_name = _define_local_scalar(pv, sdfg, state, dace.int32, storage)
        src_node = state.add_access(src_name)
        src_tasklet = state.add_tasklet('_set_src_', {}, {'__out'},
                                        '__out = {}'.format(src))
        state.add_edge(src_tasklet, '__out', src_node, None,
                       Memlet.simple(src_name, '0'))

    # ``elif`` (not ``if``): a (name, range) tuple must not fall through to
    # the scalar-creating ``else`` branch below.
    tag_range = None
    if isinstance(tag, tuple):
        tag_name, tag_range = tag
        tag_node = state.add_read(tag_name)
    elif isinstance(tag, str) and tag in sdfg.arrays.keys():
        tag_name = tag
        tag_node = state.add_read(tag)
    else:
        storage = desc.storage
        tag_name = _define_local_scalar(pv, sdfg, state, dace.int32, storage)
        tag_node = state.add_access(tag_name)
        tag_tasklet = state.add_tasklet('_set_tag_', {}, {'__out'},
                                        '__out = {}'.format(tag))
        state.add_edge(tag_tasklet, '__out', tag_node, None,
                       Memlet.simple(tag_name, '0'))

    if buf_range:
        buf_mem = Memlet.simple(buf_name, buf_range)
    else:
        buf_mem = Memlet.from_array(buf_name, desc)
    if req_range:
        req_mem = Memlet.simple(req_name, req_range)
    else:
        req_mem = Memlet.from_array(req_name, req_desc)
    if src_range:
        src_mem = Memlet.simple(src_name, src_range)
    else:
        src_mem = Memlet.simple(src_name, '0')
    if tag_range:
        tag_mem = Memlet.simple(tag_name, tag_range)
    else:
        tag_mem = Memlet.simple(tag_name, '0')

    state.add_edge(libnode, '_buffer', buf_node, None, buf_mem)
    state.add_edge(src_node, None, libnode, '_src', src_mem)
    state.add_edge(tag_node, None, libnode, '_tag', tag_mem)
    state.add_edge(libnode, '_request', req_node, None, req_mem)

    return None
示例#10
0
文件: distr.py 项目: am-ivanov/dace
def _isend(pv: 'ProgramVisitor', sdfg: SDFG, state: SDFGState, buffer: str,
           dst: Union[str, sp.Expr, Number], tag: Union[str, sp.Expr,
                                                        Number], request: str):
    """Add an MPI ``Isend`` library node to ``state``.

    The node reads ``buffer`` (``_buffer``), ``dst`` (``_dest``) and ``tag``
    (``_tag``) and writes ``request`` (``_request``).

    :param pv: Program visitor, forwarded to ``_define_local_scalar``.
    :param sdfg: SDFG owning ``state`` and the referenced arrays.
    :param state: State the node and its edges are added to.
    :param buffer: Send buffer name, or a ``(name, range)`` tuple.
    :param dst: A ``(name, range)`` tuple, the name of an array in ``sdfg``,
                or a value written into a new int32 scalar (same storage as
                ``buffer``) by a ``_set_dst_`` tasklet.
    :param tag: Accepted in the same forms as ``dst`` (tasklet ``_set_tag_``).
    :param request: Request array name, or a ``(name, range)`` tuple.
    :return: None.
    """
    from dace.libraries.mpi.nodes.isend import Isend

    libnode = Isend('_Isend_')

    buf_range = None
    if isinstance(buffer, tuple):
        buf_name, buf_range = buffer
    else:
        buf_name = buffer
    desc = sdfg.arrays[buf_name]
    buf_node = state.add_read(buf_name)

    req_range = None
    if isinstance(request, tuple):
        req_name, req_range = request
    else:
        req_name = request
    req_desc = sdfg.arrays[req_name]
    req_node = state.add_write(req_name)

    # Buffer and request are passed as pointers to their element types.
    iconn = libnode.in_connectors
    iconn = {
        c: (dtypes.pointer(desc.dtype) if c == '_buffer' else t)
        for c, t in iconn.items()
    }
    libnode.in_connectors = iconn
    oconn = libnode.out_connectors
    oconn = {
        c: (dtypes.pointer(req_desc.dtype) if c == '_request' else t)
        for c, t in oconn.items()
    }
    libnode.out_connectors = oconn

    dst_range = None
    if isinstance(dst, tuple):
        dst_name, dst_range = dst
        dst_node = state.add_read(dst_name)
    elif isinstance(dst, str) and dst in sdfg.arrays.keys():
        dst_name = dst
        dst_node = state.add_read(dst_name)
    else:
        storage = desc.storage
        dst_name = _define_local_scalar(pv, sdfg, state, dace.int32, storage)
        dst_node = state.add_access(dst_name)
        dst_tasklet = state.add_tasklet('_set_dst_', {}, {'__out'},
                                        '__out = {}'.format(dst))
        state.add_edge(dst_tasklet, '__out', dst_node, None,
                       Memlet.simple(dst_name, '0'))

    # ``elif`` (not ``if``): a (name, range) tuple must not fall through to
    # the scalar-creating ``else`` branch below.
    tag_range = None
    if isinstance(tag, tuple):
        tag_name, tag_range = tag
        tag_node = state.add_read(tag_name)
    elif isinstance(tag, str) and tag in sdfg.arrays.keys():
        tag_name = tag
        tag_node = state.add_read(tag)
    else:
        storage = desc.storage
        tag_name = _define_local_scalar(pv, sdfg, state, dace.int32, storage)
        tag_node = state.add_access(tag_name)
        tag_tasklet = state.add_tasklet('_set_tag_', {}, {'__out'},
                                        '__out = {}'.format(tag))
        state.add_edge(tag_tasklet, '__out', tag_node, None,
                       Memlet.simple(tag_name, '0'))

    if buf_range:
        buf_mem = Memlet.simple(buf_name, buf_range)
    else:
        buf_mem = Memlet.from_array(buf_name, desc)
    if req_range:
        req_mem = Memlet.simple(req_name, req_range)
    else:
        req_mem = Memlet.from_array(req_name, req_desc)
    if dst_range:
        dst_mem = Memlet.simple(dst_name, dst_range)
    else:
        dst_mem = Memlet.simple(dst_name, '0')
    if tag_range:
        tag_mem = Memlet.simple(tag_name, tag_range)
    else:
        tag_mem = Memlet.simple(tag_name, '0')

    state.add_edge(buf_node, None, libnode, '_buffer', buf_mem)
    state.add_edge(dst_node, None, libnode, '_dest', dst_mem)
    state.add_edge(tag_node, None, libnode, '_tag', tag_mem)
    state.add_edge(libnode, '_request', req_node, None, req_mem)

    return None
示例#11
0
文件: distr.py 项目: am-ivanov/dace
def _distr_matmult(pv: 'ProgramVisitor',
                   sdfg: SDFG,
                   state: SDFGState,
                   opa: str,
                   opb: str,
                   shape: Sequence[Union[sp.Expr, Number]],
                   a_block_sizes: Union[str, Sequence[Union[sp.Expr,
                                                            Number]]] = None,
                   b_block_sizes: Union[str, Sequence[Union[sp.Expr,
                                                            Number]]] = None,
                   c_block_sizes: Union[str, Sequence[Union[sp.Expr,
                                                            Number]]] = None):
    """Add a distributed (PBLAS) matrix-matrix or matrix-vector product.

    Depending on the ranks of ``opa`` and ``opb``, a ``Pgemm`` (2D x 2D),
    ``Pgemv`` (2D x 1D) or transposed ``Pgemv`` (1D x 2D, with the operands
    swapped) node is created; its ``_c`` output goes to a new transient.

    :param pv: Program visitor (unused here).
    :param sdfg: SDFG owning ``state`` and the operands.
    :param state: State the node and its edges are added to.
    :param opa: Name of the first operand.
    :param opb: Name of the second operand.
    :param shape: Global sizes: ``(m, n, k)`` or ``(m, n)``; a matrix-matrix
                  product requires the 3-element form.
    :param a_block_sizes: Block sizes of ``opa`` (defaults to its shape): an
                          array name, a ``(name, range)`` pair, or a sequence
                          of sizes (a single size is padded with ``1``).
    :param b_block_sizes: Same as ``a_block_sizes`` for ``opb``.
    :param c_block_sizes: Optional; for matrix-vector products its first
                          entry sets the length of the result.
    :return: Name of the transient holding the result.
    :raises NotImplementedError: for operand ranks other than 2x2, 2x1, 1x2.
    :raises ValueError: for a matrix-matrix product without a ``k`` in
                        ``shape``.
    """
    arra = sdfg.arrays[opa]
    arrb = sdfg.arrays[opb]
    ranks = (len(arra.shape), len(arrb.shape))

    gk = None
    if len(shape) == 3:
        gm, gn, gk = shape
    else:
        gm, gn = shape

    # Validate before anything is added to the SDFG.
    if ranks not in ((2, 2), (2, 1), (1, 2)):
        raise NotImplementedError(
            "Distributed matmult supports 2D x 2D, 2D x 1D and 1D x 2D "
            "operands, got {}D x {}D.".format(*ranks))
    if ranks == (2, 2) and gk is None:
        raise ValueError("Distributed matrix-matrix multiplication requires "
                         "a 3-element shape (m, n, k).")

    # Pad a single size to two entries; an array name (str) is kept as is.
    a_block_sizes = a_block_sizes or arra.shape
    if not isinstance(a_block_sizes, str) and len(a_block_sizes) < 2:
        a_block_sizes = (a_block_sizes[0], 1)
    b_block_sizes = b_block_sizes or arrb.shape
    if not isinstance(b_block_sizes, str) and len(b_block_sizes) < 2:
        b_block_sizes = (b_block_sizes[0], 1)

    if ranks == (1, 2):
        # The operands are swapped below, so swap their block sizes too.
        a_block_sizes, b_block_sizes = b_block_sizes, a_block_sizes

    a_bsizes_name, a_bsizes_desc, a_bsizes_range, a_bsizes_node = (
        _distr_matmult_block_sizes(sdfg, state, a_block_sizes,
                                   '_set_a_bsizes_'))
    b_bsizes_name, b_bsizes_desc, b_bsizes_range, b_bsizes_node = (
        _distr_matmult_block_sizes(sdfg, state, b_block_sizes,
                                   '_set_b_sizes_'))

    if ranks == (2, 2):
        # Gemm
        from dace.libraries.pblas.nodes.pgemm import Pgemm
        tasklet = Pgemm("__DistrMatMult__", gm, gn, gk)
        m = arra.shape[0]
        n = arrb.shape[-1]
        out = sdfg.add_temp_transient((m, n), dtype=arra.dtype)
    elif ranks == (2, 1):
        # Gemv
        from dace.libraries.pblas.nodes.pgemv import Pgemv
        tasklet = Pgemv("__DistrMatVecMult__", m=gm, n=gn)
        if c_block_sizes:
            m = c_block_sizes[0]
        else:
            m = arra.shape[0]
        out = sdfg.add_temp_transient((m, ), dtype=arra.dtype)
    else:
        # Gemv transposed: swap a and b so that ``arra`` is the matrix.
        opa, opb = opb, opa
        arra, arrb = arrb, arra
        from dace.libraries.pblas.nodes.pgemv import Pgemv
        tasklet = Pgemv("__DistrMatVecMult__", transa='T', m=gm, n=gn)
        if c_block_sizes:
            n = c_block_sizes[0]
        else:
            n = arra.shape[1]
        out = sdfg.add_temp_transient((n, ), dtype=arra.dtype)

    anode = state.add_read(opa)
    bnode = state.add_read(opb)
    cnode = state.add_write(out[0])

    if a_bsizes_range:
        a_bsizes_mem = Memlet.simple(a_bsizes_name, a_bsizes_range)
    else:
        a_bsizes_mem = Memlet.from_array(a_bsizes_name, a_bsizes_desc)
    if b_bsizes_range:
        b_bsizes_mem = Memlet.simple(b_bsizes_name, b_bsizes_range)
    else:
        b_bsizes_mem = Memlet.from_array(b_bsizes_name, b_bsizes_desc)

    state.add_edge(anode, None, tasklet, '_a', Memlet.from_array(opa, arra))
    state.add_edge(bnode, None, tasklet, '_b', Memlet.from_array(opb, arrb))
    state.add_edge(a_bsizes_node, None, tasklet, '_a_block_sizes',
                   a_bsizes_mem)
    state.add_edge(b_bsizes_node, None, tasklet, '_b_block_sizes',
                   b_bsizes_mem)
    state.add_edge(tasklet, '_c', cnode, None, Memlet.from_array(*out))

    return out[0]


def _distr_matmult_block_sizes(sdfg: SDFG, state: SDFGState, block_sizes,
                               tasklet_label: str):
    """Create the access node for one block-sizes input of ``_distr_matmult``.

    ``block_sizes`` may be an array name, a list/tuple whose first element is
    an array name and second a range, or a sequence of sizes. A sequence of
    sizes is written into a new int32 transient by a tasklet labeled
    ``tasklet_label``.

    :return: ``(name, descriptor, range or None, access node)``.
    """
    if isinstance(block_sizes, (list, tuple)):
        if isinstance(block_sizes[0], str):
            name, subset = block_sizes
            desc = sdfg.arrays[name]
            node = state.add_read(name)
            return name, desc, subset, node
        name, desc = sdfg.add_temp_transient((len(block_sizes), ),
                                             dtype=dace.int32)
        node = state.add_access(name)
        code = ";".join([
            "__out[{}] = {}".format(i, sz)
            for i, sz in enumerate(block_sizes)
        ])
        setter = state.add_tasklet(tasklet_label, {}, {'__out'}, code)
        state.add_edge(setter, '__out', node, None,
                       Memlet.from_array(name, desc))
        return name, desc, None, node
    desc = sdfg.arrays[block_sizes]
    node = state.add_read(block_sizes)
    return block_sizes, desc, None, node