예제 #1
0
파일: distr.py 프로젝트: am-ivanov/dace
def _wait(pv: 'ProgramVisitor', sdfg: SDFG, state: SDFGState, request: str):
    """ Adds an MPI Wait library node that consumes a request.
        The node's `_stat_source` and `_stat_tag` outputs are written into two
        fresh single-element int32 transients.
        :param request: Name of the request array, or a `(name, range)` tuple
                        that selects a subset of it.
        :return: None.
    """
    from dace.libraries.mpi.nodes.wait import Wait

    wait_node = Wait('_Wait_')

    # Accept either a bare request name or a (name, subset-range) pair.
    if isinstance(request, tuple):
        request_name, request_subset = request
    else:
        request_name, request_subset = request, None

    request_desc = sdfg.arrays[request_name]
    request_access = state.add_access(request_name)

    # One int32 transient (plus its write node and full memlet) per status output.
    status_sinks = {}
    for connector in ('_stat_source', '_stat_tag'):
        temp_name, temp_desc = sdfg.add_temp_transient([1], dtypes.int32)
        status_sinks[connector] = (state.add_write(temp_name),
                                   Memlet.from_array(temp_name, temp_desc))

    if request_subset:
        request_memlet = Memlet.simple(request_name, request_subset)
    else:
        request_memlet = Memlet.from_array(request_name, request_desc)

    state.add_edge(request_access, None, wait_node, '_request', request_memlet)
    for connector, (sink, sink_memlet) in status_sinks.items():
        state.add_edge(wait_node, connector, sink, None, sink_memlet)

    return None
예제 #2
0
파일: distr.py 프로젝트: am-ivanov/dace
def _gather(pv: 'ProgramVisitor',
            sdfg: SDFG,
            state: SDFGState,
            in_buffer: str,
            out_buffer: str,
            root: Union[str, sp.Expr, Number] = 0):
    """ Adds an MPI Gather library node that gathers `in_buffer` into `out_buffer`.
        :param in_buffer: Name of the input array.
        :param out_buffer: Name of the output array.
        :param root: Either the name of an existing data container holding the
                     root, or a value/expression that a tasklet writes into a
                     new local int32 scalar.
        :return: None.
    """
    from dace.libraries.mpi.nodes.gather import Gather

    gather_node = Gather('_Gather_')
    src_desc = sdfg.arrays[in_buffer]
    dst_desc = sdfg.arrays[out_buffer]
    src_access = state.add_read(in_buffer)
    dst_access = state.add_write(out_buffer)

    root_is_container = isinstance(root, str) and root in sdfg.arrays.keys()
    if root_is_container:
        root_access = state.add_read(root)
    else:
        # Materialize the root value in a scalar that lives in the input's storage.
        scalar_name = _define_local_scalar(pv, sdfg, state, dace.int32,
                                           src_desc.storage)
        root_access = state.add_access(scalar_name)
        init_root = state.add_tasklet('_set_root_', {}, {'__out'},
                                      '__out = {}'.format(root))
        state.add_edge(init_root, '__out', root_access, None,
                       Memlet.simple(scalar_name, '0'))

    state.add_edge(src_access, None, gather_node, '_inbuffer',
                   Memlet.from_array(in_buffer, src_desc))
    state.add_edge(root_access, None, gather_node, '_root',
                   Memlet.simple(root_access.data, '0'))
    state.add_edge(gather_node, '_outbuffer', dst_access, None,
                   Memlet.from_array(out_buffer, dst_desc))

    return None
예제 #3
0
파일: distr.py 프로젝트: am-ivanov/dace
def _Reduce(pv: 'ProgramVisitor',
            sdfg: SDFG,
            state: SDFGState,
            buffer: str,
            op: str,
            root: Union[str, sp.Expr, Number] = 0,
            grid: str = None):
    """ Adds an MPI Reduce library node operating on `buffer`.
        The same array is connected as both `_inbuffer` and `_outbuffer`.
        :param buffer: Name of the array to reduce.
        :param op: Reduction operator handed to the Reduce node.
        :param root: Either the name of an existing data container holding the
                     root, or a value/expression that a tasklet writes into a
                     new local int32 scalar.
        :param grid: Optional process-grid name handed to the Reduce node.
        :return: None.
    """
    from dace.libraries.mpi.nodes.reduce import Reduce

    reduce_node = Reduce('_Reduce_', op, grid)
    buf_desc = sdfg.arrays[buffer]
    buf_in = state.add_read(buffer)
    buf_out = state.add_write(buffer)

    if not (isinstance(root, str) and root in sdfg.arrays.keys()):
        # Store the given root value in a fresh int32 scalar placed in the
        # buffer's storage, set by a one-line tasklet.
        root_scalar = _define_local_scalar(pv, sdfg, state, dace.int32,
                                           buf_desc.storage)
        root_access = state.add_access(root_scalar)
        set_root = state.add_tasklet('_set_root_', {}, {'__out'},
                                     '__out = {}'.format(root))
        state.add_edge(set_root, '__out', root_access, None,
                       Memlet.simple(root_scalar, '0'))
    else:
        root_access = state.add_read(root)

    state.add_edge(buf_in, None, reduce_node, '_inbuffer',
                   Memlet.from_array(buffer, buf_desc))
    state.add_edge(root_access, None, reduce_node, '_root',
                   Memlet.simple(root_access.data, '0'))
    state.add_edge(reduce_node, '_outbuffer', buf_out, None,
                   Memlet.from_array(buffer, buf_desc))

    return None
예제 #4
0
파일: distr.py 프로젝트: am-ivanov/dace
def _block_gather(pv: 'ProgramVisitor',
                  sdfg: SDFG,
                  state: SDFGState,
                  in_buffer: str,
                  out_buffer: str,
                  gather_grid: str,
                  reduce_grid: str = None,
                  correspondence: Sequence[Integral] = None):
    """ Block-gathers an Array using process-grids, sub-arrays, and the BlockGather library node.
        This method currently does not support Array slices and imperfect tiling.
        :param in_buffer: Name of the (local) Array descriptor.
        :param out_buffer: Name of the (global) Array descriptor.
        :param gather_grid: Name of the sub-grid used for gathering the Array (reduction group leaders).
        :param reduce_grid: Name of the sub-grid used for broadcasting the Array (reduction groups).
        :param correspondence: Matching of the array/sub-array's dimensions to the process-grid's dimensions.
        :return: Name of the new sub-array descriptor.
        :raises ValueError: If the two buffers have different datatypes.
    """
    if sdfg.arrays[in_buffer].dtype != sdfg.arrays[out_buffer].dtype:
        raise ValueError("Input/output buffer datatypes must match!")

    # Sub-array of the global (out) array shaped like the local (in) array.
    subarray_name = _subarray(pv,
                              sdfg,
                              state,
                              out_buffer,
                              in_buffer,
                              process_grid=gather_grid,
                              correspondence=correspondence)

    from dace.libraries.mpi import BlockGather
    block_gather = BlockGather('_BlockGather_', subarray_name, gather_grid,
                               reduce_grid)

    # Descriptors are fetched again now that the sub-array has been registered.
    src_desc = sdfg.arrays[in_buffer]
    src_access = state.add_read(in_buffer)
    dst_desc = sdfg.arrays[out_buffer]
    dst_access = state.add_write(out_buffer)

    state.add_edge(src_access, None, block_gather, '_inp_buffer',
                   Memlet.from_array(in_buffer, src_desc))
    state.add_edge(block_gather, '_out_buffer', dst_access, None,
                   Memlet.from_array(out_buffer, dst_desc))

    return subarray_name
예제 #5
0
def _reduce(sdfg: SDFG,
            state: SDFGState,
            redfunction: Callable[[Any, Any], Any],
            in_array: str,
            out_array=None,
            axis=None,
            identity=None):
    """ Adds a Reduce node that applies `redfunction` to `in_array`.
        :param redfunction: Binary reduction function handed to the Reduce node.
        :param in_array: Name of the input array.
        :param out_array: Name of the output array. If None, a new transient is
                          created with the input's dtype and storage.
        :param axis: Axis or sequence of axes to reduce; None reduces into a
                     single element.
        :param identity: Optional identity value handed to the Reduce node.
        :return: The new transient's name when `out_array` is None, otherwise
                 an empty list.
    """
    # Normalize `axis` into a tuple of symbolic expressions (None stays None).
    if axis is not None:
        if not isinstance(axis, (tuple, list)):
            axis = (axis, )
        axis = tuple(pystr_to_symbolic(a) for a in axis)

    inarr = in_array
    in_desc = sdfg.arrays[inarr]
    input_subset = parse_memlet_subset(in_desc,
                                       ast.parse(in_array).body[0].value, {})
    input_memlet = Memlet.simple(inarr, input_subset)

    if out_array is None:
        if axis is None:
            output_shape = [1]
        else:
            # Drop the reduced dimensions from a copy of the input subset.
            kept_subset = copy.deepcopy(input_subset)
            kept_subset.pop(axis)
            output_shape = kept_subset.size()
        outarr, out_desc = sdfg.add_temp_transient(output_shape,
                                                   in_desc.dtype,
                                                   in_desc.storage)
        output_memlet = Memlet.from_array(outarr, out_desc)
    else:
        outarr = out_array
        output_subset = parse_memlet_subset(sdfg.arrays[outarr],
                                            ast.parse(out_array).body[0].value,
                                            {})
        output_memlet = Memlet.simple(outarr, output_subset)

    # Wire input -> reduce -> output.
    read_node = state.add_read(inarr)
    reduce_node = state.add_reduce(redfunction, axis, identity)
    write_node = state.add_write(outarr)
    state.add_nedge(read_node, reduce_node, input_memlet)
    state.add_nedge(reduce_node, write_node, output_memlet)

    return outarr if out_array is None else []
예제 #6
0
파일: distr.py 프로젝트: am-ivanov/dace
def _cart_create(pv: 'ProgramVisitor', sdfg: SDFG, state: SDFGState,
                 dims: ShapeType):
    """ Creates a process-grid and adds it to the DaCe program. The process-grid is implemented with [MPI_Cart_create](https://www.mpich.org/static/docs/latest/www3/MPI_Cart_create.html).
        :param dims: Shape of the process-grid (see `dims` parameter of `MPI_Cart_create`), e.g., [2, 3, 3].
        :return: Name of the new process-grid descriptor.
    """
    pgrid_name = sdfg.add_pgrid(dims)
    ndims = len(dims)

    # Declarations of the MPI variables backing the grid; a Dummy library node
    # adds them to the program's state.
    declarations = [
        f'MPI_Comm {pgrid_name}_comm;',
        f'MPI_Group {pgrid_name}_group;',
        f'int {pgrid_name}_coords[{ndims}];',
        f'int {pgrid_name}_dims[{ndims}];',
        f'int {pgrid_name}_rank;',
        f'int {pgrid_name}_size;',
        f'bool {pgrid_name}_valid;',
    ]
    from dace.libraries.mpi import Dummy
    dummy_node = Dummy(pgrid_name, declarations)
    state.add_node(dummy_node)

    # The Dummy node writes a transient int32 scalar named after the grid so
    # that transformations do not remove it.
    _, marker_desc = sdfg.add_scalar(pgrid_name, dace.int32, transient=True)
    marker_access = state.add_write(pgrid_name)
    state.add_edge(dummy_node, '__out', marker_access, None,
                   Memlet.from_array(pgrid_name, marker_desc))

    return pgrid_name
예제 #7
0
def test_gpu_dma():
    """ Runs `gpu_dma` on the GPU with a GPU_Global copy of the scalar `alpha`
        and checks the SDFG's output against NumPy's `X * alpha`.
    """
    sdfg: dace.SDFG = gpu_dma.to_sdfg(strict=True)
    sdfg.name = 'gpu_dma'
    sdfg.apply_transformations(GPUTransformSDFG,
                               options={'strict_transform': False})

    map_ = next(n for n, _ in sdfg.all_nodes_recursive()
                if isinstance(n, nodes.MapEntry))

    add_gpu_location(sdfg, map_, 0)

    sdfg.arrays['gpu_X'].location = {'gpu': 1}

    # Clone the host scalar `alpha` into a GPU_Global transient on GPU 0.
    inodename = 'alpha'
    inode = sdfg.arrays['alpha']
    newdesc = inode.clone()
    newdesc.location = {'gpu': 0}
    newdesc.storage = StorageType.GPU_Global
    newdesc.transient = True
    name = sdfg.add_datadesc('gpu_' + inodename, newdesc, find_new_name=True)
    # Redirect every access node of the original scalar to the clone.
    for state in sdfg.nodes():
        for node in state.nodes():
            if (isinstance(node, nodes.AccessNode) and node.data == inodename):
                node.data = name
    # Redirect every memlet of the original scalar to the clone.
    for state in sdfg.nodes():
        for edge in state.edges():
            if edge.data.data == inodename:
                edge.data.data = name

    # Copy host `alpha` into the clone in the start (copy-in) state.
    copyin_state = sdfg.start_state

    src_array = nodes.AccessNode(inodename, debuginfo=inode.debuginfo)
    dst_array = nodes.AccessNode(name, debuginfo=inode.debuginfo)
    copyin_state.add_node(src_array)
    copyin_state.add_node(dst_array)
    copyin_state.add_nedge(
        src_array, dst_array,
        Memlet.from_array(src_array.data, src_array.desc(sdfg)))

    sdfg.apply_strict_transformations()

    np.random.seed(0)
    n = 16
    alpha = np.ndarray(shape=1, dtype=np_dtype)
    alpha.fill(np.random.rand())
    # np.ndarray() does not initialize its buffer; fill X with seeded random
    # data so the comparison is deterministic and free of garbage NaN/inf.
    X = np.random.rand(n).astype(np_dtype)

    a_times_X = sdfg(X=X, alpha=alpha[0], N=n)
    res = X * alpha
    # Report each mismatching element before asserting.
    idx = zip(*np.where(~np.isclose(res, a_times_X, atol=0, rtol=1e-7)))
    for i in idx:
        print(i, res[i], a_times_X[i], X[i] * alpha, X[i], alpha)
    assert np.allclose(res, a_times_X)
    print('PASS')
예제 #8
0
def _simple_call(sdfg: SDFG,
                 state: SDFGState,
                 inpname: str,
                 func: str,
                 restype: dace.typeclass = None):
    """ Implements a simple call of the form `out = func(inp)`.
        `func` is applied element-wise. A single tasklet is used when the input
        has exactly one element; otherwise a mapped tasklet visits every index.
        :param inpname: Name of the input array.
        :param func: Function name, inserted verbatim into the tasklet code.
        :param restype: Datatype of the result; defaults to the input's dtype.
        :return: Name of a new transient with the input's shape and storage.
    """
    inparr = sdfg.arrays[inpname]
    if restype is None:
        restype = inparr.dtype
    outname, outarr = sdfg.add_temp_transient(inparr.shape, restype,
                                              inparr.storage)
    code = '__out = {f}(__inp)'.format(f=func)
    # The initializer 1 makes a zero-dimensional shape count as one element
    # instead of making reduce() raise TypeError on an empty sequence.
    num_elements = reduce(lambda x, y: x * y, inparr.shape, 1)
    if num_elements == 1:
        inp = state.add_read(inpname)
        out = state.add_write(outname)
        tasklet = state.add_tasklet(func, {'__inp'}, {'__out'}, code)
        state.add_edge(inp, None, tasklet, '__inp',
                       Memlet.from_array(inpname, inparr))
        state.add_edge(tasklet, '__out', out, None,
                       Memlet.from_array(outname, outarr))
    else:
        # One map parameter per dimension (__i0, __i1, ...), shared by the
        # input and output memlets.
        params = ['__i%d' % i for i in range(len(inparr.shape))]
        index = ','.join(params)
        state.add_mapped_tasklet(
            name=func,
            map_ranges={p: '0:%s' % n for p, n in zip(params, inparr.shape)},
            inputs={'__inp': Memlet.simple(inpname, index)},
            code=code,
            outputs={'__out': Memlet.simple(outname, index)},
            external_edges=True)

    return outname
예제 #9
0
파일: distr.py 프로젝트: am-ivanov/dace
def _Allreduce(pv: 'ProgramVisitor',
               sdfg: SDFG,
               state: SDFGState,
               buffer: str,
               op: str,
               grid: str = None):
    """ Adds an MPI Allreduce library node operating on `buffer`.
        The same array is connected as both `_inbuffer` and `_outbuffer`.
        :param buffer: Name of the array to reduce.
        :param op: Reduction operator handed to the Allreduce node.
        :param grid: Optional process-grid name handed to the Allreduce node.
        :return: None.
    """
    from dace.libraries.mpi.nodes.allreduce import Allreduce

    allreduce_node = Allreduce('_Allreduce_', op, grid)
    buf_desc = sdfg.arrays[buffer]
    src, dst = state.add_read(buffer), state.add_write(buffer)

    # Each edge gets its own full-array memlet.
    state.add_edge(src, None, allreduce_node, '_inbuffer',
                   Memlet.from_array(buffer, buf_desc))
    state.add_edge(allreduce_node, '_outbuffer', dst, None,
                   Memlet.from_array(buffer, buf_desc))

    return None
예제 #10
0
def create_batch_gemm_sdfg(dtype, strides):
    """ Builds an SDFG named 'einsum' that multiplies X by Y into Z with a BLAS
        MatMul library node.
        :param dtype: Element datatype of X, Y and Z.
        :param strides: Mapping with the keys 'BATCH', 'sAM', 'sAK', 'sAB',
                        'sBK', 'sBN', 'sBB', 'sCM', 'sCN' and 'sCB'. Entries for
                        which `symbolic.issymbolic` holds are replaced by a
                        symbol of the same name; other entries are used as-is.
        :return: The constructed SDFG. If strides['BATCH'] != 1, every array
                 gets a leading BATCH dimension.
    """
    sdfg = SDFG('einsum')
    state = sdfg.add_state()
    M, K, N = (symbolic.symbol(s) for s in ['M', 'K', 'N'])

    keys = ('BATCH', 'sAM', 'sAK', 'sAB', 'sBK', 'sBN', 'sBB', 'sCM', 'sCN',
            'sCB')
    sym = {
        key: (symbolic.symbol(key)
              if symbolic.issymbolic(strides[key]) else strides[key])
        for key in keys
    }

    batched = strides['BATCH'] != 1

    def _layout(rows, cols, row_stride, col_stride, batch_stride):
        # Prepend the batch dimension and its stride only in the batched case.
        if batched:
            return ([sym['BATCH'], rows, cols],
                    [sym[batch_stride], sym[row_stride], sym[col_stride]])
        return [rows, cols], [sym[row_stride], sym[col_stride]]

    descs = {}
    for arr_name, (rows, cols, rs, cs, bs) in (
        ('X', (M, K, 'sAM', 'sAK', 'sAB')),
        ('Y', (K, N, 'sBK', 'sBN', 'sBB')),
        ('Z', (M, N, 'sCM', 'sCN', 'sCB')),
    ):
        arr_shape, arr_strides = _layout(rows, cols, rs, cs, bs)
        _, descs[arr_name] = sdfg.add_array(arr_name,
                                            dtype=dtype,
                                            shape=arr_shape,
                                            strides=arr_strides)

    x_read = state.add_read('X')
    y_read = state.add_read('Y')
    z_write = state.add_write('Z')

    import dace.libraries.blas as blas  # Avoid import loop

    matmul = blas.MatMul('einsum_gemm')
    state.add_node(matmul)
    state.add_edge(x_read, None, matmul, '_a',
                   Memlet.from_array(x_read.data, descs['X']))
    state.add_edge(y_read, None, matmul, '_b',
                   Memlet.from_array(y_read.data, descs['Y']))
    state.add_edge(matmul, '_c', z_write, None,
                   Memlet.from_array(z_write.data, descs['Z']))

    return sdfg
예제 #11
0
def nccl_send(pv: 'ProgramVisitor',
              sdfg: SDFG,
              state: SDFGState,
              in_buffer: str,
              peer: symbolic.SymbolicType = 0,
              group_handle: str = None):
    """ Adds an NCCL Send library node that reads `in_buffer`.
        :param in_buffer: Name of the array to send, or a `(name, range)` tuple
                          that selects a subset of it.
        :param peer: Peer handed to the Send node (presumably the destination
                     rank — confirm against the Send node).
        :param group_handle: Optional group-handle name. If it names an
                             existing container, the node reads and writes it;
                             otherwise a new int32 GPU_Global scalar is created
                             and only written by the node.
        :return: An empty list.
    """
    inputs = {"_inbuffer"}
    outputs = set()

    if isinstance(group_handle, str):
        gh_start = False
        if group_handle in sdfg.arrays.keys():
            # Existing handle: the node both reads and updates it.
            gh_name = group_handle
            gh_out = state.add_access(gh_name)
            gh_in = state.add_access(gh_name)
            inputs.add("_group_handle")
        else:
            # No such container: create a fresh scalar that the node only writes.
            gh_start = True
            gh_name = _define_local_scalar(pv, sdfg, state, dace.int32,
                                           dtypes.StorageType.GPU_Global)
            gh_out = state.add_access(gh_name)
        outputs.add("_group_handle")

    libnode = Send(inputs=inputs, outputs=outputs, peer=peer)

    if isinstance(group_handle, str):
        gh_memlet = Memlet.simple(gh_name, '0')
        if not gh_start:
            state.add_edge(gh_in, None, libnode, "_group_handle", gh_memlet)
        state.add_edge(libnode, "_group_handle", gh_out, None, gh_memlet)

    if isinstance(in_buffer, tuple):
        in_name, in_range = in_buffer
    else:
        in_name, in_range = in_buffer, None

    desc = sdfg.arrays[in_name]
    # Type the buffer connector as a pointer to the buffer's element type.
    # The key must be '_inbuffer', the connector this node declares and the
    # one the buffer edge below connects to.
    libnode.in_connectors = {
        c: (dtypes.pointer(desc.dtype) if c == '_inbuffer' else t)
        for c, t in libnode.in_connectors.items()
    }
    in_node = state.add_read(in_name)

    if in_range:
        buf_mem = Memlet.simple(in_name, in_range)
    else:
        buf_mem = Memlet.from_array(in_name, desc)

    state.add_edge(in_node, None, libnode, '_inbuffer', buf_mem)

    return []
예제 #12
0
파일: distr.py 프로젝트: am-ivanov/dace
def _subarray(pv: 'ProgramVisitor',
              sdfg: SDFG,
              state: SDFGState,
              array: Union[str, ShapeType],
              subarray: Union[str, ShapeType],
              dtype: dtypes.typeclass = None,
              process_grid: str = None,
              correspondence: Sequence[Integral] = None):
    """ Adds a sub-array descriptor to the DaCe Program.
        Sub-arrays are implemented (when `process_grid` is set) with [MPI_Type_create_subarray](https://www.mpich.org/static/docs/v3.2/www3/MPI_Type_create_subarray.html).
        :param array: Either the name of an Array descriptor or the shape of the array (similar to the `array_of_sizes` parameter of `MPI_Type_create_subarray`).
        :param subarray: Either the name of an Array descriptor or the sub-shape of the (sub-)array (similar to the `array_of_subsizes` parameter of `MPI_Type_create_subarray`).
        :param dtype: Datatype of the array/sub-array (similar to the `oldtype` parameter of `MPI_Type_create_subarray`). If not given, the dtype of `array`, then of `subarray`, is used when those are descriptor names.
        :param process_grid: Name of the process-grid for collective scatter/gather operations.
        :param correspondence: Matching of the array/sub-array's dimensions to the process-grid's dimensions.
        :return: Name of the new sub-array descriptor.
    """

    def _shape_and_dtype(spec):
        # A string names an existing descriptor; anything else is a raw shape.
        if isinstance(spec, str):
            desc = sdfg.arrays[spec]
            return desc.shape, desc.dtype
        return spec, None

    shape, arr_dtype = _shape_and_dtype(array)
    subshape, sub_dtype = _shape_and_dtype(subarray)
    # First truthy choice wins: explicit dtype, then the array's, then the sub-array's.
    dtype = dtype or arr_dtype or sub_dtype

    subarray_name = sdfg.add_subarray(dtype, shape, subshape, process_grid,
                                      correspondence)

    # Only sub-arrays bound to a process-grid (used by collective
    # scatter/gather ops) get a subgraph.
    if not process_grid:
        return subarray_name

    # A Dummy library node adds the MPI variables to the program's state.
    from dace.libraries.mpi import Dummy
    dummy_node = Dummy(subarray_name, [
        f'MPI_Datatype {subarray_name};', f'int* {subarray_name}_counts;',
        f'int* {subarray_name}_displs;'
    ])
    state.add_node(dummy_node)

    # The Dummy node writes a transient int32 scalar named after the sub-array
    # so that transformations do not remove it.
    _, marker_desc = sdfg.add_scalar(subarray_name, dace.int32, transient=True)
    marker_access = state.add_write(subarray_name)
    state.add_edge(dummy_node, '__out', marker_access, None,
                   Memlet.from_array(subarray_name, marker_desc))

    return subarray_name
예제 #13
0
파일: distr.py 프로젝트: am-ivanov/dace
def _cart_sub(pv: 'ProgramVisitor',
              sdfg: SDFG,
              state: SDFGState,
              parent_grid: str,
              color: Sequence[Union[Integral, bool]],
              exact_grid: RankType = None):
    """ Partitions the `parent_grid` to lower-dimensional sub-grids and adds them to the DaCe program.
        The sub-grids are implemented with [MPI_Cart_sub](https://www.mpich.org/static/docs/latest/www3/MPI_Cart_sub.html).
        :param parent_grid: Parent process-grid (similar to the `comm` parameter of `MPI_Cart_sub`).
        :param color: The i-th entry specifies whether the i-th dimension is kept in the sub-grid or is dropped (see `remain_dims` input of `MPI_Cart_sub`).
        :param exact_grid: [DEVELOPER] If set then, out of all the sub-grids created, only the one that contains the rank with id `exact_grid` will be utilized for collective communication.
        :return: Name of the new sub-grid descriptor.
    """
    pgrid_name = sdfg.add_pgrid(parent_grid=parent_grid,
                                color=color,
                                exact_grid=exact_grid)

    # The sub-grid has one dimension per truthy entry of `color`.
    kept_dims = sum(1 for c in color if c)

    # A Dummy library node adds the MPI variables to the program's state.
    from dace.libraries.mpi import Dummy
    declarations = [
        f'MPI_Comm {pgrid_name}_comm;',
        f'MPI_Group {pgrid_name}_group;',
        f'int {pgrid_name}_coords[{kept_dims}];',
        f'int {pgrid_name}_dims[{kept_dims}];',
        f'int {pgrid_name}_rank;',
        f'int {pgrid_name}_size;',
        f'bool {pgrid_name}_valid;',
    ]
    dummy_node = Dummy(pgrid_name, declarations)
    state.add_node(dummy_node)

    # The Dummy node writes a transient int32 scalar named after the grid so
    # that transformations do not remove it.
    _, marker_desc = sdfg.add_scalar(pgrid_name, dace.int32, transient=True)
    marker_access = state.add_write(pgrid_name)
    state.add_edge(dummy_node, '__out', marker_access, None,
                   Memlet.from_array(pgrid_name, marker_desc))

    return pgrid_name
예제 #14
0
파일: distr.py 프로젝트: am-ivanov/dace
def _wait(pv: 'ProgramVisitor', sdfg: SDFG, state: SDFGState, request: str):
    """ Adds an MPI Waitall library node that reads the given request array.
        :param request: Name of the request array, or a `(name, range)` tuple
                        that selects a subset of it.
        :return: None.
    """
    from dace.libraries.mpi.nodes.wait import Waitall

    waitall_node = Waitall('_Waitall_')

    # Accept either a bare request name or a (name, subset-range) pair.
    if isinstance(request, tuple):
        request_name, request_subset = request
    else:
        request_name, request_subset = request, None

    request_desc = sdfg.arrays[request_name]
    request_access = state.add_access(request_name)

    request_memlet = (Memlet.simple(request_name, request_subset)
                      if request_subset else
                      Memlet.from_array(request_name, request_desc))

    state.add_edge(request_access, None, waitall_node, '_request',
                   request_memlet)

    return None
예제 #15
0
def make_sdfg(implementation, dtype, storage=dace.StorageType.Default):
    """ Builds an SDFG that runs a Getrf node followed by a Getrs node on A and B.

        When `storage` is not Default, host A and B are copied into transient
        "_device" arrays. A is transposed with a cuBLAS Transpose node into
        "AT_device", which feeds Getrf. The Getrs result is copied from
        "B_device" back to host B.
        :param implementation: Implementation name assigned to both the Getrf
                               and Getrs nodes.
        :param dtype: Element datatype of A and B.
        :param storage: Storage for the working arrays; Default works directly
                        on the non-transient A and B.
        :return: The constructed SDFG.
    """
    n = dace.symbol("n")

    suffix = "_device" if storage != dace.StorageType.Default else ""
    transient = storage != dace.StorageType.Default

    sdfg = dace.SDFG("matrix_solve_getrf_getrs_{}_{}".format(
        implementation, dtype))
    state = sdfg.add_state("dataflow")

    Ahost_arr = sdfg.add_array("A", [n, n],
                               dtype,
                               storage=dace.StorageType.Default)
    Bhost_arr = sdfg.add_array("B", [n],
                               dtype,
                               storage=dace.StorageType.Default)

    if transient:
        A_arr = sdfg.add_array("A" + suffix, [n, n],
                               dtype,
                               storage=storage,
                               transient=transient)
        AT_arr = sdfg.add_array('AT' + suffix, [n, n],
                                dtype,
                                storage=storage,
                                transient=transient)
        sdfg.add_array("B" + suffix, [n],
                       dtype,
                       storage=storage,
                       transient=transient)
    sdfg.add_array("pivots" + suffix, [n],
                   dace.dtypes.int32,
                   storage=storage,
                   transient=transient)
    sdfg.add_array("result_getrf" + suffix, [1],
                   dace.dtypes.int32,
                   storage=storage,
                   transient=transient)
    sdfg.add_array("result_getrs" + suffix, [1],
                   dace.dtypes.int32,
                   storage=storage,
                   transient=transient)

    if transient:
        Ahi = state.add_read("A")
        Ai = state.add_access("A" + suffix)
        Ain = state.add_access("AT" + suffix)
        Aout = state.add_access("AT" + suffix)
        Bhi = state.add_read("B")
        # Host B is the destination of the copy-back, so it is a write node.
        Bho = state.add_write("B")
        Bin = state.add_access("B" + suffix)
        Bout = state.add_access("B" + suffix)
        transpose_in = blas.nodes.transpose.Transpose("transpose_in",
                                                      dtype=dtype)
        transpose_in.implementation = "cuBLAS"
        # Copy-in A and B, copy-out B; the host descriptors describe all three.
        state.add_nedge(Ahi, Ai, Memlet.from_array(*Ahost_arr))
        state.add_nedge(Bhi, Bin, Memlet.from_array(*Bhost_arr))
        state.add_nedge(Bout, Bho, Memlet.from_array(*Bhost_arr))
        state.add_memlet_path(Ai,
                              transpose_in,
                              dst_conn='_inp',
                              memlet=Memlet.from_array(*A_arr))
        state.add_memlet_path(transpose_in,
                              Ain,
                              src_conn='_out',
                              memlet=Memlet.from_array(*AT_arr))
    else:
        Ain = state.add_access("A" + suffix)
        Aout = state.add_access("A" + suffix)
        Bin = state.add_access("B" + suffix)
        Bout = state.add_access("B" + suffix)
    pivots = state.add_access("pivots" + suffix)
    res_getrf = state.add_access("result_getrf" + suffix)
    res_getrs = state.add_access("result_getrs" + suffix)

    getrf_node = lapack.nodes.getrf.Getrf("getrf")
    getrf_node.implementation = implementation
    getrs_node = lapack.nodes.getrs.Getrs("getrs")
    getrs_node.implementation = implementation

    # Getrf consumes A and produces the factored matrix, pivots and a status.
    state.add_memlet_path(Ain,
                          getrf_node,
                          dst_conn="_xin",
                          memlet=Memlet.simple(Ain,
                                               "0:n, 0:n",
                                               num_accesses=n * n))
    state.add_memlet_path(getrf_node,
                          res_getrf,
                          src_conn="_res",
                          memlet=Memlet.simple(res_getrf, "0", num_accesses=1))
    state.add_memlet_path(getrs_node,
                          res_getrs,
                          src_conn="_res",
                          memlet=Memlet.simple(res_getrs, "0", num_accesses=1))
    state.add_memlet_path(getrf_node,
                          pivots,
                          src_conn="_ipiv",
                          memlet=Memlet.simple(pivots, "0:n", num_accesses=n))
    # Getrs consumes the pivots, the factored matrix and B, producing the new B.
    state.add_memlet_path(pivots,
                          getrs_node,
                          dst_conn="_ipiv",
                          memlet=Memlet.simple(pivots, "0:n", num_accesses=n))
    state.add_memlet_path(getrf_node,
                          Aout,
                          src_conn="_xout",
                          memlet=Memlet.simple(Aout,
                                               "0:n, 0:n",
                                               num_accesses=n * n))
    state.add_memlet_path(Aout,
                          getrs_node,
                          dst_conn="_a",
                          memlet=Memlet.simple(Aout,
                                               "0:n, 0:n",
                                               num_accesses=n * n))
    state.add_memlet_path(Bin,
                          getrs_node,
                          dst_conn="_rhs_in",
                          memlet=Memlet.simple(Bin, "0:n", num_accesses=n))
    state.add_memlet_path(getrs_node,
                          Bout,
                          src_conn="_rhs_out",
                          memlet=Memlet.simple(Bout, "0:n", num_accesses=n))

    return sdfg
예제 #16
0
파일: winograd.py 프로젝트: mratsim/dace
def mm_small(
    state,
    A_node,
    B_node,
    C_node,
    A_subset=None,
    B_subset=None,
    C_subset=None,
    A_memlet=None,
    B_memlet=None,
    C_memlet=None,
    map_entry=None,
    map_exit=None,
    A_direct=True,
    B_direct=True,
):
    """ Emits a sequential map "matmul_sequential" that accumulates A @ B into C.

        Exactly one of `A_node` / `B_node` is expected to be a string. That
        operand is spliced verbatim into the tasklet code (e.g.
        "out=<A_node>[i2, i3]*j1"), presumably naming an SDFG constant. The
        other operand is an access node whose elements reach the tasklet
        through a memlet.

        The map iterates (i2, i3, i4) over [rows, inner, cols]. Two of these
        sizes come from the access-node operand's shape (or its `*_subset`);
        the third comes from `C_node`'s shape (or `C_subset`).

        :param state: State to add nodes/edges to; `state.parent` is the SDFG.
        :param A_node: Left operand: a string or an access node.
        :param B_node: Right operand: a string or an access node.
        :param C_node: Output access node. All writes use a "lambda a,b: a+b"
                       WCR, so results are accumulated into C; the caller
                       presumably initializes C (NOTE(review): confirm).
        :param A_subset, B_subset, C_subset: Optional shape overrides used
                       instead of the corresponding descriptor's shape.
        :param A_memlet, B_memlet, C_memlet: Optional memlets that route the
                       node operand / output through `map_entry` / `map_exit`.
                       The start index of each of their last two ranges is
                       appended to the inner memlet's subset.
        :param map_entry, map_exit: Enclosing scope nodes. They are required
                       when the corresponding memlet is given (assumption; not
                       checked here).
        :param A_direct, B_direct: When no memlet is given, connect the operand
                       from its own access node (True) or from `map_entry`
                       (False).
        :raises AssertionError: If neither `A_node` nor `B_node` is a string.
    """
    # C = A@B
    sdfg = state.parent
    Cshape = C_node.desc(sdfg).shape if C_subset is None else C_subset
    # someshape = [rows, inner, cols] of the product.
    if isinstance(B_node, str):
        someshape = list(
            A_node.desc(sdfg).shape) if A_subset is None else A_subset
        someshape = [someshape[0]] + [someshape[1]] + [Cshape[1]]
    elif isinstance(A_node, str):
        someshape = list(
            B_node.desc(sdfg).shape) if B_subset is None else B_subset
        someshape = [Cshape[0]] + [someshape[0]] + [someshape[1]]
    else:
        raise AssertionError(
            "one of these have to be an str, else don't call this function")

    mapRange = ["0:" + str(_s) for _s in someshape]
    mapParams = ["i2", "i3", "i4"]
    mmapEntry, mmapExit = state.add_map(
        "matmul_sequential",
        dict(zip(mapParams, mapRange)),
        schedule=dace.ScheduleType.Sequential,
    )
    if isinstance(A_node, str):
        # A is the string operand: it is indexed inside the code, while B
        # arrives through connector "j1".
        tasklet = state.add_tasklet("matmul_sequential", {"j1"}, {"out"},
                                    "out=" + A_node + "[i2, i3]" + "*j1")
        if B_memlet:
            state.add_edge(map_entry, None, mmapEntry, None, B_memlet)
            # Start index of the last two ranges of B_memlet's subset
            # (presumably a Range), appended after i3, i4.
            b_memlet_trailing = [str(_t[0]) for _t in B_memlet.subset[-2:]]
            state.add_edge(
                mmapEntry,
                None,
                tasklet,
                "j1",
                Memlet.simple(B_node,
                              ",".join(["i3", "i4"] + b_memlet_trailing)),
            )
        else:
            state.add_edge(
                B_node if B_direct else map_entry,
                None,
                mmapEntry,
                None,
                Memlet.from_array(B_node, B_node.desc(sdfg)),
            )
            state.add_edge(
                mmapEntry,
                None,
                tasklet,
                "j1",
                Memlet.simple(B_node, ",".join(["i3", "i4"])),
            )
    else:
        # B is the string operand: it is indexed inside the code, while A
        # arrives through connector "j0".
        tasklet = state.add_tasklet("matmul_sequential", {"j0"}, {"out"},
                                    "out=j0*" + B_node + "[i3, i4]")
        if A_memlet:
            state.add_edge(map_entry, None, mmapEntry, None, A_memlet)
            # Start index of the last two ranges of A_memlet's subset
            # (presumably a Range), appended after i2, i3.
            a_memlet_trailing = [str(_t[0]) for _t in A_memlet.subset[-2:]]
            state.add_edge(
                mmapEntry,
                None,
                tasklet,
                "j0",
                Memlet.simple(A_node,
                              ",".join(["i2", "i3"] + a_memlet_trailing)),
            )
        else:
            state.add_edge(
                A_node if A_direct else map_entry,
                None,
                mmapEntry,
                None,
                Memlet.from_array(A_node, A_node.desc(sdfg)),
            )
            state.add_edge(
                mmapEntry,
                None,
                tasklet,
                "j0",
                Memlet.simple(A_node, ",".join(["i2", "i3"])),
            )

    # Output: every tasklet result is summed into C[i2, i4] via WCR.
    if C_memlet:
        c_memlet_trailing = [str(_t[0]) for _t in C_memlet.subset[-2:]]
        state.add_edge(
            tasklet,
            "out",
            mmapExit,
            None,
            Memlet.simple(
                C_node,
                ",".join(["i2", "i4"] + c_memlet_trailing),
                wcr_str="lambda a,b: a+b",
                wcr_conflict=False,
            ),
        )
        state.add_edge(mmapExit, None, map_exit, None, C_memlet)
    else:
        state.add_edge(
            tasklet,
            "out",
            mmapExit,
            None,
            Memlet.simple(
                C_node,
                ",".join(["i2", "i4"]),
                wcr_str="lambda a,b:a+b",
                wcr_conflict=False,
            ),
        )
        # Whole-array WCR memlet out of the inner map, into C_node directly or
        # into the enclosing map_exit when one is given.
        state.add_edge(
            mmapExit,
            None,
            C_node if map_exit is None else map_exit,
            None,
            Memlet.simple(
                C_node,
                ",".join(["0:" + str(_s) for _s in C_node.desc(sdfg).shape]),
                wcr_str="lambda a,b:a+b",
                wcr_conflict=False,
            ),
        )
예제 #17
0
파일: winograd.py 프로젝트: mratsim/dace
def winograd_convolution(dace_session, tf_node):
    """Lower a TensorFlow convolution node into a Winograd-style convolution
    built inside ``dace_session.state``.

    The generated dataflow, in order:
      1. copies the kernel input (``tf_node.inputs[1]``) into a GPU_Global
         transient;
      2. pads the image input (``tf_node.inputs[0]``) when the node's padding
         attribute is "SAME";
      3. tiles the image into ``inputView`` and transforms every tile with the
         ``Btrans`` / ``B`` constants into ``V_output...``;
      4. transforms the kernel with the ``G`` / ``Gtrans`` constants into
         ``U...``;
      5. multiplies slices of V and U for every in-tile position into ``m...``;
      6. applies the ``Atrans`` / ``A`` constants to get output tiles;
      7. un-tiles the result into the node's output (NHWC).
    A debug-callback section is appended at the end (see the note there).

    :param dace_session: Session object providing the SDFG (``graph``), the
        current ``state`` and helpers to create input/output nodes and memlets.
    :param tf_node: TensorFlow operation being lowered; ``inputs[0]`` is the
        image and ``inputs[1]`` the kernel.
    :return: None; the state is modified in place.
    """
    # Nodes to be copied to the CPU and handed to the printer callback below.
    debugNodes = []
    state = dace_session.state
    add_cublas_cusolver(dace_session.graph)
    #############Add constants for transformation matrices###############
    # bt/b, g/gt and at/a are module-level transformation matrices; they are
    # registered as SDFG constants and referenced below by name.
    dace_session.graph.add_constant('Btrans', bt)
    dace_session.graph.add_constant('B', b)
    bNode = 'B'
    bTransposeNode = 'Btrans'

    dace_session.graph.add_constant('G', g)
    dace_session.graph.add_constant('Gtrans', gt)
    gNode = 'G'
    gTransposeNode = 'Gtrans'

    dace_session.graph.add_constant('Atrans', at)
    dace_session.graph.add_constant('A', a)
    aNode = 'A'
    aTransposeNode = 'Atrans'

    inputNodes = []
    inputParams = []
    inputDims = []
    for _inp in tf_node.inputs:
        _node, _params, _dims = dace_session.create_and_add_input_node(_inp)
        inputNodes.append(_node)
        inputParams.append(_params)
        inputDims.append(_dims)
    # Manually add copy for kernel from CPU to GPU
    kernel_desc = inputNodes[1].desc(dace_session.graph)
    kernelGPU = state.add_transient(
        inputNodes[1].data + "GPU",
        shape=kernel_desc.shape,
        dtype=kernel_desc.dtype,
        lifetime=dtypes.AllocationLifetime.SDFG,
        storage=dace.StorageType.GPU_Global,
    )
    state.add_edge(
        inputNodes[1],
        None,
        kernelGPU,
        None,
        Memlet.from_array(inputNodes[1],
                          inputNodes[1].desc(dace_session.graph)),
    )
    # From here on the kernel is read from the GPU copy.
    inputNodes[1] = kernelGPU
    outputList = dace_session.create_and_add_output_node(tf_node)
    outputDims = dace_session.get_default_dims(tf_node.outputs[0])
    # NOTE(review): [2:-1] presumably strips the "b'...'" wrapper from a bytes
    # attribute (b'SAME'); confirm the attribute type.
    if str(tf_node.get_attr("padding"))[2:-1] == "SAME":
        paddedInput, paddedDims = dace_session.inputPadding(
            tf_node,
            inputNodes[0],
            inputNodes[0].desc(dace_session.graph),
            outputList[0].desc(dace_session.graph).shape[1],
            inputNodes[1].desc(dace_session.graph).shape[0],
            tf_node.get_attr("strides")[1],
            inputDims[0],
        )
        inputDims[0] = paddedDims
        inputNodes[0] = paddedInput
    outputShape = [int(_s) for _s in tf_node.outputs[0].shape]
    # Tiled view: [tile, tile, input channels, number of tiles], where the
    # number of tiles is batch * ceil(out dim1 / OTS) * ceil(out dim2 / OTS).
    inputViewShape = [
        IMAGE_TILE_SIZE,
        IMAGE_TILE_SIZE,
        tf_node.inputs[0].shape[-1],
        outputShape[0] * ceil(outputShape[1] / OUTPUT_TILE_SIZE) *
        ceil(outputShape[2] / OUTPUT_TILE_SIZE),
    ]
    inputViewDims = ["0:" + str(_x) for _x in inputViewShape]
    ########Tiling the image#################################
    # Index expressions mapping the view indices (i0, i1, i2, i3) back into
    # the 4-D input: i3 % batch selects the image, the next two expressions
    # select the tile origin (stepping by OUTPUT_TILE_SIZE) plus the in-tile
    # offset i0 / i1, and i2 is the channel.
    # NOTE(review): these expressions hard-code the names "i0".."i3", so they
    # assume inputParams[0] == ["i0", "i1", "i2", "i3"] — confirm.
    # NOTE(review): both spatial indices derive from
    # ceil(outputShape[2] / OUTPUT_TILE_SIZE); if outputShape[1] !=
    # outputShape[2] this looks inconsistent — presumably square spatial
    # dimensions are assumed. Confirm.
    inputViewParams = [
        "i3%" + str(outputShape[0]),
        "(i3/" + str(outputShape[0]) + ")%"
        # + str(output_shape[0] * ceil(output_shape[1] / OUTPUT_TILE_SIZE))
        + str(ceil(outputShape[2] / OUTPUT_TILE_SIZE)) + "*" +
        str(OUTPUT_TILE_SIZE) + "+i0",
        # + str(
        #    ceil(output_shape[1] / OUTPUT_TILE_SIZE)
        #    * ceil(output_shape[2] / OUTPUT_TILE_SIZE)
        # ),
        "int_floor(i3,"
        # + str(ceil(output_shape[1] / OUTPUT_TILE_SIZE))
        + str(outputShape[0] * ceil(outputShape[2] / OUTPUT_TILE_SIZE)) +
        ")*" + str(OUTPUT_TILE_SIZE) + "+i1",
        "i2",
    ]
    inputView = state.add_transient(
        "V" + "_".join([str(_s) for _s in inputViewShape]),
        inputViewShape,
        dace.float32,
        dace.StorageType.GPU_Global,
    )
    # Element-wise copy from the input into the tiled view.
    mapEntry, mapExit = state.add_map(
        string_builder(tf_node.name) + "_input_tile",
        dict(zip(inputParams[0], inputViewDims)),
    )
    tasklet = state.add_tasklet(
        string_builder(tf_node.name) + "_input_tile", {"j0"}, {"out"},
        "out = j0")
    dace_session.add_in_memlets([inputNodes[0]], mapEntry, tasklet,
                                [inputDims[0]], [inputViewParams])
    dace_session.add_out_memlets([inputView], mapExit, tasklet,
                                 [inputViewDims], [inputParams[0]])
    ##################Transforming all input tiles#########################
    # V[:, :, c, t] = Btrans @ tile @ B for every (channel, tile) pair, via two
    # small matrix products through a register-stored intermediate "BtI".
    #[TODO] try to re-use memory
    vNode = state.add_transient(
        "V_output" + "_".join([str(_s) for _s in inputViewShape]),
        inputViewShape,
        dace.float32,
        dace.StorageType.GPU_Global,
    )
    # Zero-initialize: results are written with an a+b write-conflict resolution.
    vNode.setzero = True
    mapEntry, mapExit = state.add_map(
        string_builder(tf_node.name) + "_input_txform",
        dict(zip(inputParams[0][0:2], inputViewDims[2:4])),
        dace.ScheduleType.GPU_Device,
    )
    intermediateResultNode = state.add_transient("BtI", bt.shape, dace.float32,
                                                 dace.StorageType.Register)
    intermediateResultNode.setzero = True
    state.add_edge(
        inputView,
        None,
        mapEntry,
        None,
        Memlet.simple(inputView, ",".join(inputViewDims)),
    )
    mm_small(
        state,
        bTransposeNode,
        inputView,
        intermediateResultNode,
        B_subset=[IMAGE_TILE_SIZE, IMAGE_TILE_SIZE],
        B_memlet=Memlet.simple(
            inputView, ",".join(inputViewDims[0:2] + inputParams[0][0:2])),
        map_entry=mapEntry,
        B_direct=False,
    )
    mm_small(
        state,
        intermediateResultNode,
        bNode,
        vNode,
        map_exit=mapExit,
        C_subset=[IMAGE_TILE_SIZE, IMAGE_TILE_SIZE],
        C_memlet=Memlet.simple(
            vNode,
            ",".join(inputViewDims[0:2] + inputParams[0][0:2]),
            wcr_str="lambda a,b: a+b",
            wcr_conflict=False,
        ),
        map_entry=mapEntry,
        A_direct=True,
    )
    state.add_edge(
        mapExit,
        None,
        vNode,
        None,
        Memlet.simple(
            vNode,
            ",".join(inputViewDims),
            wcr_str="lambda a,b: a+b",
            wcr_conflict=False,
        ),
    )
    #############Transforming the kernel###############################
    # U = G @ kernel @ Gtrans, stored with the kernel's last two dims reversed
    # (shape[-1:-3:-1]) after the two tile dimensions.
    mapEntry, mapExit = state.add_map(
        string_builder(tf_node.name) + "_kernel_txform",
        dict(zip(inputParams[1][0:2], inputDims[1][2:4])),
        dace.ScheduleType.GPU_Device,
    )
    intermediateResultNode = state.add_transient("GF", g.shape, dace.float32,
                                                 dace.StorageType.Register)
    intermediateResultNode.setzero = True
    processedKernelNode = state.add_transient(
        "U" + "_".join([
            str(_s) for _s in inputViewShape[0:2] +
            list(tf_node.inputs[1].shape[-1:-3:-1])
        ]),
        inputViewShape[0:2] + list(tf_node.inputs[1].shape[-1:-3:-1]),
        dace.float32,
        dace.StorageType.GPU_Global,
    )
    processedKernelNode.setzero = True
    state.add_edge(
        inputNodes[1],
        None,
        mapEntry,
        None,
        dace.Memlet.from_array(inputNodes[1].data,
                               inputNodes[1].desc(dace_session.graph)),
    )

    mm_small(
        state,
        gNode,
        inputNodes[1],
        intermediateResultNode,
        map_entry=mapEntry,
        B_subset=tf_node.inputs[1].shape[0:2],
        B_memlet=Memlet.simple(
            inputNodes[1], ",".join(inputDims[1][0:2] + inputParams[1][0:2])),
        B_direct=False,
    )
    # The output indices are swapped ([1] then [0]) to write the reversed layout.
    mm_small(
        state,
        intermediateResultNode,
        gTransposeNode,
        processedKernelNode,
        C_subset=[IMAGE_TILE_SIZE, IMAGE_TILE_SIZE],
        C_memlet=Memlet.simple(
            processedKernelNode,
            ",".join(inputViewDims[0:2] + [inputParams[0][1]] +
                     [inputParams[0][0]]),
            wcr_str="lambda a,b: a+b",
            wcr_conflict=False,
        ),
        map_entry=mapEntry,
        map_exit=mapExit,
        A_direct=True,
    )
    state.add_edge(
        mapExit,
        None,
        processedKernelNode,
        None,
        Memlet.simple(
            processedKernelNode.data,
            ",".join([
                "0:" + str(_s)
                for _s in processedKernelNode.desc(dace_session.graph).shape
            ]),
            wcr_str="lambda a,b: a+b",
            wcr_conflict=False,
        ),
    )
    # M holds, per in-tile position (i0, i1), a
    # [kernel last dim, number of tiles] matrix product of U and V slices.
    mNode = state.add_transient(
        "m" + "_".join([
            str(_s) for _s in inputViewShape[0:2] +
            [tf_node.inputs[1].shape[-1], inputViewShape[-1]]
        ]),
        inputViewShape[0:2] +
        [tf_node.inputs[1].shape[-1], inputViewShape[-1]],
        dace.float32,
        dace.StorageType.GPU_Global,
    )
    mNodeDims = ["0:" + str(_d) for _d in mNode.desc(dace_session.graph).shape]
    mapEntry, mapExit = state.add_map(
        string_builder(tf_node.name) + "_eltwise_product",
        dict(zip(inputParams[0][0:2], inputViewDims[0:2])),
        dace.ScheduleType.Sequential,
    )
    state.add_edge(
        vNode,
        None,
        mapEntry,
        None,
        Memlet.from_array(vNode.data, vNode.desc(dace_session.graph)),
    )
    state.add_edge(
        processedKernelNode,
        None,
        mapEntry,
        None,
        Memlet.from_array(processedKernelNode.data,
                          processedKernelNode.desc(dace_session.graph)),
    )
    mm(
        state,
        vNode,
        processedKernelNode,
        mNode,
        A_subset=inputViewShape[2:4],
        A_memlet=Memlet.simple(
            vNode, ",".join(inputParams[0][0:2] + inputViewDims[-2:])),
        B_subset=tf_node.inputs[1].shape[-1:-3:-1],
        B_memlet=Memlet.simple(
            processedKernelNode,
            ",".join(
                inputParams[0][0:2] +
                ["0:" + str(_s) for _s in tf_node.inputs[1].shape[-1:-3:-1]]),
        ),
        C_subset=[tf_node.inputs[1].shape[-1], inputViewShape[-1]],
        C_memlet=Memlet.simple(mNode,
                               ",".join(inputParams[0][0:2] + mNodeDims[-2:])),
        map_entry=mapEntry,
        map_exit=mapExit,
        shadow_a=True,
        shadow_b=True,
    )
    state.add_edge(mapExit, None, mNode, None,
                   Memlet.simple(mNode, ",".join(mNodeDims)))
    #################OUTPUT TRANSFORMATION################################
    # Y[:, :, k, t] = Atrans @ M[:, :, k, t] @ A, producing
    # OUTPUT_TILE_SIZE x OUTPUT_TILE_SIZE output tiles.
    mapRange = [inputDims[1][-1]] + [inputViewDims[-1]]
    mapEntry, mapExit = state.add_map(
        string_builder(tf_node.name) + "_output_txform",
        dict(zip(inputParams[0][0:2], mapRange)),
        dace.ScheduleType.GPU_Device,
    )
    intermediateResultNode = state.add_transient("AtM", at.shape, dace.float32,
                                                 dace.StorageType.Register)
    intermediateResultNode.setzero = True
    transformedOutputNode = state.add_transient(
        "inv_txformed_output" + "_".join([str(tf_node.inputs[1].shape[-1])] +
                                         [str(inputViewShape[-1])]),
        [OUTPUT_TILE_SIZE, OUTPUT_TILE_SIZE] + [tf_node.inputs[1].shape[-1]] +
        [inputViewShape[-1]],
        dace.float32,
        dace.StorageType.GPU_Global,
    )
    transformedOutputNode.setzero = True
    state.add_edge(mNode, None, mapEntry, None,
                   Memlet.simple(mNode, ",".join(mNodeDims)))
    mm_small(
        state,
        aTransposeNode,
        mNode,
        intermediateResultNode,
        B_subset=inputViewShape[0:2],
        B_memlet=Memlet.simple(
            mNode, ",".join(inputViewDims[0:2] + inputParams[0][0:2])),
        map_entry=mapEntry,
        B_direct=False,
    )
    mm_small(
        state,
        intermediateResultNode,
        aNode,
        transformedOutputNode,
        C_subset=[OUTPUT_TILE_SIZE, OUTPUT_TILE_SIZE],
        C_memlet=Memlet.simple(
            transformedOutputNode,
            ",".join(
                ["0:" + str(OUTPUT_TILE_SIZE), "0:" + str(OUTPUT_TILE_SIZE)] +
                inputParams[0][0:2]),
            wcr_str="lambda a,b:a+b",
            wcr_conflict=False,
        ),
        map_entry=mapEntry,
        map_exit=mapExit,
        A_direct=True,
    )
    state.add_edge(
        mapExit,
        None,
        transformedOutputNode,
        None,
        Memlet.simple(
            transformedOutputNode.data,
            ",".join([
                "0:" + str(_s)
                for _s in transformedOutputNode.desc(dace_session.graph).shape
            ]),
            wcr_str="lambda a,b: a+b",
            wcr_conflict=False,
        ),
    )
    ###################Un-Tile the output to NHWC format###################
    # Same index decomposition as inputViewParams (and the same square-size
    # assumption), this time scattering output tiles into the output array.
    outputParams = [
        "i3%" + str(outputShape[0]),
        "(i3/" + str(outputShape[0]) + ")%" +
        str(ceil(outputShape[2] / OUTPUT_TILE_SIZE)) + "*" +
        str(OUTPUT_TILE_SIZE) + "+i0",
        "int_floor(i3," +
        str(outputShape[0] * ceil(outputShape[2] / OUTPUT_TILE_SIZE)) + ")*" +
        str(OUTPUT_TILE_SIZE) + "+i1",
        "i2",
    ]
    mapRange = [
        "0:" + str(_s)
        for _s in transformedOutputNode.desc(dace_session.graph).shape
    ]
    mapEntry, mapExit = state.add_map(
        string_builder(tf_node.name) + "_output_untile",
        dict(zip(inputParams[0], mapRange)),
    )
    tasklet = state.add_tasklet(
        string_builder(tf_node.name) + "_output_untile", {"j0"}, {"out"},
        "out = j0")
    dace_session.add_in_memlets([transformedOutputNode], mapEntry, tasklet,
                                [mapRange], [inputParams[0]])
    dace_session.add_out_memlets(outputList, mapExit, tasklet, [outputDims],
                                 [outputParams])
    ################# Debugging with callbacks #############
    # NOTE(review): debugNodes starts empty and nothing in this function
    # appends to it. The copy loop below therefore does nothing, but a printer
    # tasklet with no inputs is still added to the state and a callback is
    # still registered. Append nodes to debugNodes to copy them into
    # CPU_Heap transients and pass them to `printer`.
    taskletInputs = ["i" + str(index) for index in range(len(debugNodes))]
    callback_tasklet = state.add_tasklet(
        string_builder(tf_node.name) + "_printer",
        {*taskletInputs},
        {},
        string_builder(tf_node.name) + "_printer" + "(" +
        ",".join(taskletInputs) + ");",
        language=dace.dtypes.Language.CPP,
    )
    for _n, _conn in zip(debugNodes, taskletInputs):
        _n_cpu = state.add_transient(_n.data + "_cpucopy",
                                     _n.desc(dace_session.graph).shape,
                                     _n.desc(dace_session.graph).dtype,
                                     storage=dace.StorageType.CPU_Heap,
                                     lifetime=dtypes.AllocationLifetime.SDFG)
        state.add_edge(_n, None, _n_cpu, None,
                       Memlet.from_array(_n, _n.desc(dace_session.graph)))
        state.add_edge(
            _n_cpu,
            None,
            callback_tasklet,
            _conn,
            Memlet.from_array(_n_cpu, _n_cpu.desc(dace_session.graph)),
        )
    callback_input_types = []
    for somenode in debugNodes:
        callback_input_types.append(somenode.desc(dace_session.graph))
    dace_session.callbackFunctionDict[string_builder(tf_node.name) +
                                      "_printer"] = printer
    dace_session.callbackTypeDict[string_builder(tf_node.name) +
                                  "_printer"] = dace.data.Scalar(
                                      dace.callback(None,
                                                    *callback_input_types))
예제 #18
0
파일: distr.py 프로젝트: am-ivanov/dace
def _bcgather(pv: 'ProgramVisitor', sdfg: SDFG, state: SDFGState,
              in_buffer: str, out_buffer: str,
              block_sizes: Union[str, Sequence[Union[sp.Expr, Number]]]):
    """ Wires a BlockCyclicGather library node between an input and an
        output buffer.

        :param pv: The calling ProgramVisitor (not used).
        :param sdfg: SDFG holding the data descriptors.
        :param state: State that receives the access nodes and library node.
        :param in_buffer: Input array name, or a ``(name, range)`` tuple.
        :param out_buffer: Output array name, or a ``(name, range)`` tuple.
        :param block_sizes: Name of an array holding the block sizes, a
                            ``(name, range)`` pair, or a list/tuple of
                            constant sizes that are written into a new int32
                            temporary array.
        :return: None.
    """
    from dace.libraries.pblas.nodes.pgeadd import BlockCyclicGather

    gather = BlockCyclicGather('_BCGather_')

    def name_and_subset(arg):
        # A tuple carries an explicit subset; a bare name means "whole array".
        if isinstance(arg, tuple):
            name, subset = arg
            return name, subset
        return arg, None

    def make_memlet(name, subset, desc):
        if subset:
            return Memlet.simple(name, subset)
        return Memlet.from_array(name, desc)

    src_name, src_subset = name_and_subset(in_buffer)
    src_desc = sdfg.arrays[src_name]
    src_node = state.add_read(src_name)

    sizes_subset = None
    if not isinstance(block_sizes, (list, tuple)):
        sizes_name = block_sizes
        sizes_desc = sdfg.arrays[sizes_name]
        sizes_node = state.add_read(sizes_name)
    elif isinstance(block_sizes[0], str):
        sizes_name, sizes_subset = block_sizes
        sizes_desc = sdfg.arrays[sizes_name]
        sizes_node = state.add_read(sizes_name)
    else:
        # Constant sizes: store them in a fresh int32 temporary via a tasklet.
        sizes_name, sizes_desc = sdfg.add_temp_transient(
            (len(block_sizes), ), dtype=dace.int32)
        sizes_node = state.add_access(sizes_name)
        init_code = ";".join(f"__out[{idx}] = {size}"
                             for idx, size in enumerate(block_sizes))
        init_tasklet = state.add_tasklet('_set_bsizes_', {}, {'__out'},
                                         init_code)
        state.add_edge(init_tasklet, '__out', sizes_node, None,
                       Memlet.from_array(sizes_name, sizes_desc))

    dst_name, dst_subset = name_and_subset(out_buffer)
    dst_desc = sdfg.arrays[dst_name]
    dst_node = state.add_write(dst_name)

    src_memlet = make_memlet(src_name, src_subset, src_desc)
    sizes_memlet = make_memlet(sizes_name, sizes_subset, sizes_desc)
    dst_memlet = make_memlet(dst_name, dst_subset, dst_desc)

    state.add_edge(src_node, None, gather, '_inbuffer', src_memlet)
    state.add_edge(sizes_node, None, gather, '_block_sizes', sizes_memlet)
    state.add_edge(gather, '_outbuffer', dst_node, None, dst_memlet)

    return None
예제 #19
0
def _create_einsum_internal(sdfg: SDFG,
                            state: SDFGState,
                            einsum_string: str,
                            *arrays: str,
                            dtype: Optional[dtypes.typeclass] = None,
                            optimize: bool = False,
                            output: Optional[str] = None,
                            nodes: Optional[Dict[str, AccessNode]] = None,
                            init_output: bool = None):
    """ Adds the computation of an einsum expression over ``arrays`` to
        ``state``.

        :param sdfg: SDFG containing the input arrays; temporaries are
                     added to it.
        :param state: State in which the computation is created.
        :param einsum_string: Einsum subscript expression.
        :param arrays: Names of the input arrays, one per input subscript.
        :param dtype: Output data type. Defaults to the type of the first
                      input (new output) or of ``output`` (existing output).
        :param optimize: If True, use ``opt_einsum`` to split the expression
                         into a chain of pairwise contractions. In that mode
                         ``output`` and ``init_output`` are not forwarded;
                         the result is always a new transient.
        :param output: Name of an existing output array. If None, a new
                       transient is created.
        :param nodes: Optional mapping from array name to an existing access
                      node to reuse for the inputs.
        :param init_output: Whether to zero the output in a new state placed
                            before ``state`` (non-GEMM path only).
                            None: zero only if some input index is summed
                            over, i.e. the map accumulates with a WCR.
                            True: always zero.
                            False: do not zero an existing ``output``. A newly
                            created output is still zeroed when accumulation
                            is needed.
        :return: Tuple of (output array name, output access node).
        :raises ValueError: On a mismatched number of arrays or dimensions.
        :raises ImportError: If ``optimize`` is set but opt_einsum is missing.
    """
    # Infer shapes and strides of input/output arrays
    einsum = EinsumParser(einsum_string)

    if len(einsum.inputs) != len(arrays):
        raise ValueError('Invalid number of arrays for einsum expression')

    # Get shapes from arrays and verify dimensionality
    chardict = {}
    for inp, inpname in zip(einsum.inputs, arrays):
        inparr = sdfg.arrays[inpname]
        if len(inp) != len(inparr.shape):
            raise ValueError('Dimensionality mismatch in input "%s"' % inpname)
        for char, shp in zip(inp, inparr.shape):
            if char in chardict and shp != chardict[char]:
                raise ValueError('Dimension mismatch in einsum expression')
            chardict[char] = shp

    if optimize:
        # Try to import opt_einsum
        try:
            import opt_einsum as oe
        except (ModuleNotFoundError, NameError, ImportError):
            raise ImportError('To optimize einsum expressions, please install '
                              'the "opt_einsum" package.')

        for char, shp in chardict.items():
            if symbolic.issymbolic(shp):
                raise ValueError('Einsum optimization cannot be performed '
                                 'on symbolically-sized array dimension "%s" '
                                 'for subscript character "%s"' % (shp, char))

        # Create optimal contraction path
        # noinspection PyTypeChecker
        _, path_info = oe.contract_path(
            einsum_string, *oe.helpers.build_views(einsum_string, chardict))

        input_nodes = nodes or {arr: state.add_read(arr) for arr in arrays}
        result_node = None

        # Follow path and create a chain of operation SDFG states
        for pair, nonfree, expr, after, blas in path_info.contraction_list:
            result, result_node = _create_einsum_internal(sdfg,
                                                          state,
                                                          expr,
                                                          arrays[pair[0]],
                                                          arrays[pair[1]],
                                                          dtype=dtype,
                                                          optimize=False,
                                                          output=None,
                                                          nodes=input_nodes)
            arrays = ([a for i, a in enumerate(arrays) if i not in pair] +
                      [result])
            input_nodes[result] = result_node

        return arrays[0], result_node
        # END of einsum optimization

    input_nodes = nodes or {arr: state.add_read(arr) for arr in arrays}

    # Get output shape from chardict, or [1] for a scalar output
    output_shape = list(map(lambda k: chardict[k], einsum.output)) or [1]
    output_index = ','.join(o for o in einsum.output) or '0'

    if output is None:
        dtype = dtype or sdfg.arrays[arrays[0]].dtype
        output, odesc = sdfg.add_temp_transient(output_shape, dtype)
        to_init = True
    else:
        odesc = sdfg.arrays[output]
        dtype = dtype or odesc.dtype
        # None means "decide automatically" (refined below). An explicit
        # False must be honored; the previous `init_output or True` always
        # evaluated to True.
        to_init = True if init_output is None else bool(init_output)

    # Conflicted: some input index is summed over, so several map iterations
    # write the same output element and a WCR is needed.
    is_conflicted = not all(
        all(indim in einsum.output for indim in inp) for inp in einsum.inputs)
    if not is_conflicted and init_output is None:
        to_init = False

    if not einsum.is_bmm():
        # Fall back to "pure" SDFG einsum with conflict resolution
        c = state.add_write(output)

        # Add state before this one to initialize the output value
        if to_init:
            init_state = sdfg.add_state_before(state)
            if len(einsum.output) > 0:
                init_state.add_mapped_tasklet(
                    'einsum_reset',
                    {k: '0:%s' % chardict[k]
                     for k in einsum.output}, {},
                    'out_%s = 0' % output,
                    {'out_%s' % output: Memlet.simple(output, output_index)},
                    external_edges=True)
            else:  # Scalar output
                t = init_state.add_tasklet('einsum_reset', set(),
                                           {'out_%s' % output},
                                           'out_%s = 0' % output)
                onode = init_state.add_write(output)
                init_state.add_edge(t, 'out_%s' % output, onode, None,
                                    Memlet.simple(output, '0'))

        wcr = 'lambda a,b: a+b' if is_conflicted else None
        # Pure einsum map
        state.add_mapped_tasklet(
            'einsum', {k: '0:%s' % v
                       for k, v in chardict.items()}, {
                           'inp_%s' % arr: Memlet.simple(arr, ','.join(inp))
                           for inp, arr in zip(einsum.inputs, arrays)
                       },
            'out_%s = %s' % (output, ' * '.join('inp_%s' % arr
                                                for arr in arrays)),
            {
                'out_%s' % output: Memlet.simple(
                    output, output_index, wcr_str=wcr)
            },
            input_nodes=input_nodes,
            output_nodes={output: c},
            external_edges=True)
    else:
        # Represent einsum as a GEMM or batched GEMM (using library nodes)
        a_shape = sdfg.arrays[arrays[0]].shape
        b_shape = sdfg.arrays[arrays[1]].shape
        c_shape = output_shape

        a = input_nodes[arrays[0]]
        b = input_nodes[arrays[1]]
        c = state.add_write(output)

        # Compute GEMM dimensions and strides
        strides = dict(
            BATCH=prod([c_shape[dim] for dim in einsum.c_batch]),
            M=prod([a_shape[dim] for dim in einsum.a_only]),
            K=prod([a_shape[dim] for dim in einsum.a_sum]),
            N=prod([b_shape[dim] for dim in einsum.b_only]),
            sAM=prod(a_shape[einsum.a_only[-1] + 1:]) if einsum.a_only else 1,
            sAK=prod(a_shape[einsum.a_sum[-1] + 1:]) if einsum.a_sum else 1,
            sAB=prod(a_shape[einsum.a_batch[-1] +
                             1:]) if einsum.a_batch else 1,
            sBK=prod(b_shape[einsum.b_sum[-1] + 1:]) if einsum.b_sum else 1,
            sBN=prod(b_shape[einsum.b_only[-1] + 1:]) if einsum.b_only else 1,
            sBB=prod(b_shape[einsum.b_batch[-1] +
                             1:]) if einsum.b_batch else 1,
            sCM=prod(c_shape[einsum.c_a_only[-1] +
                             1:]) if einsum.c_a_only else 1,
            sCN=prod(c_shape[einsum.c_b_only[-1] +
                             1:]) if einsum.c_b_only else 1,
            sCB=prod(c_shape[einsum.c_batch[-1] +
                             1:]) if einsum.c_batch else 1)

        # Complement strides to make matrices as necessary
        if len(a_shape) == 1 and len(einsum.a_sum) == 1:
            strides['sAK'] = 1
            strides['sAB'] = strides['sAM'] = strides['K']
        if len(b_shape) == 1 and len(einsum.b_sum) == 1:
            strides['sBN'] = 1
            strides['sBK'] = 1
            strides['sBB'] = strides['K']
        if len(c_shape) == 1 and len(einsum.a_sum) == len(einsum.b_sum):
            strides['sCN'] = 1
            strides['sCB'] = strides['sCM'] = strides['N']

        # Create nested SDFG for GEMM
        nsdfg = create_batch_gemm_sdfg(dtype, strides)

        nsdfg_node = state.add_nested_sdfg(nsdfg, None, {'X', 'Y'}, {'Z'},
                                           strides)
        state.add_edge(a, None, nsdfg_node, 'X',
                       Memlet.from_array(a.data, a.desc(sdfg)))
        state.add_edge(b, None, nsdfg_node, 'Y',
                       Memlet.from_array(b.data, b.desc(sdfg)))
        state.add_edge(nsdfg_node, 'Z', c, None,
                       Memlet.from_array(c.data, c.desc(sdfg)))

    return output, c
예제 #20
0
파일: distr.py 프로젝트: am-ivanov/dace
def _redistribute(pv: 'ProgramVisitor', sdfg: SDFG, state: SDFGState,
                  in_buffer: str, in_subarray: str, out_buffer: str,
                  out_subarray: str):
    """ Redistributes an Array using process-grids, sub-arrays, and the Redistribute library node.

        :param pv: The calling ProgramVisitor (not used).
        :param sdfg: SDFG in which the redistribution descriptor is registered.
        :param state: State in which the nodes are created.
        :param in_buffer: Name of the (local) input Array descriptor, or a
                          ``(name, range)`` tuple selecting a subset of it.
        :param in_subarray: Input sub-array descriptor.
        :param out_buffer: Name of the (local) output Array descriptor, or a
                           ``(name, range)`` tuple selecting a subset of it.
        :param out_subarray: Output sub-array descriptor.
        :return: Name of the new redistribution descriptor.
    """
    # Resolve optional (name, range) tuples and look up the descriptors before
    # the SDFG is modified. Previously sdfg.arrays was indexed with the raw
    # argument first, which raised KeyError for tuple arguments.
    inbuf_range = None
    if isinstance(in_buffer, tuple):
        inbuf_name, inbuf_range = in_buffer
    else:
        inbuf_name = in_buffer
    in_desc = sdfg.arrays[inbuf_name]

    outbuf_range = None
    if isinstance(out_buffer, tuple):
        outbuf_name, outbuf_range = out_buffer
    else:
        outbuf_name = out_buffer
    out_desc = sdfg.arrays[outbuf_name]

    rdistrarray_name = sdfg.add_rdistrarray(in_subarray, out_subarray)

    from dace.libraries.mpi import Dummy, Redistribute
    # The Dummy node is given the C declarations of the redistribution handles
    # (presumably emitted into the generated code — confirm) and writes a
    # transient int32 scalar named after the descriptor.
    tasklet = Dummy(rdistrarray_name, [
        f'MPI_Datatype {rdistrarray_name};', f'int {rdistrarray_name}_sends;',
        f'MPI_Datatype* {rdistrarray_name}_send_types;',
        f'int* {rdistrarray_name}_dst_ranks;',
        f'int {rdistrarray_name}_recvs;',
        f'MPI_Datatype* {rdistrarray_name}_recv_types;',
        f'int* {rdistrarray_name}_src_ranks;',
        f'int {rdistrarray_name}_self_copies;',
        f'int* {rdistrarray_name}_self_src;',
        f'int* {rdistrarray_name}_self_dst;',
        f'int* {rdistrarray_name}_self_size;'
    ])
    state.add_node(tasklet)
    _, scal = sdfg.add_scalar(rdistrarray_name, dace.int32, transient=True)
    wnode = state.add_write(rdistrarray_name)
    state.add_edge(tasklet, '__out', wnode, None,
                   Memlet.from_array(rdistrarray_name, scal))

    libnode = Redistribute('_Redistribute_', rdistrarray_name)

    inbuf_node = state.add_read(inbuf_name)
    outbuf_node = state.add_write(outbuf_name)

    if inbuf_range:
        inbuf_mem = Memlet.simple(inbuf_name, inbuf_range)
    else:
        inbuf_mem = Memlet.from_array(inbuf_name, in_desc)
    if outbuf_range:
        outbuf_mem = Memlet.simple(outbuf_name, outbuf_range)
    else:
        outbuf_mem = Memlet.from_array(outbuf_name, out_desc)

    state.add_edge(inbuf_node, None, libnode, '_inp_buffer', inbuf_mem)
    state.add_edge(libnode, '_out_buffer', outbuf_node, None, outbuf_mem)

    return rdistrarray_name
예제 #21
0
파일: distr.py 프로젝트: am-ivanov/dace
def _irecv(pv: 'ProgramVisitor', sdfg: SDFG, state: SDFGState, buffer: str,
           src: Union[str, sp.Expr, Number], tag: Union[str, sp.Expr,
                                                        Number], request: str):
    """ Adds an Irecv library node that receives into ``buffer`` and writes a
        request handle into ``request``.

        :param pv: The calling ProgramVisitor (used to define local scalars).
        :param sdfg: SDFG containing the data descriptors.
        :param state: State in which the nodes are created.
        :param buffer: Receive buffer name, or a ``(name, range)`` tuple.
        :param src: Source rank: a ``(name, range)`` tuple, the name of an
                    existing data container, or a constant/symbolic value that
                    is stored into a new int32 scalar.
        :param tag: Message tag, accepted in the same forms as ``src``.
        :param request: Request handle name, or a ``(name, range)`` tuple.
        :return: None.
    """
    from dace.libraries.mpi.nodes.irecv import Irecv

    libnode = Irecv('_Irecv_')

    buf_range = None
    if isinstance(buffer, tuple):
        buf_name, buf_range = buffer
    else:
        buf_name = buffer
    desc = sdfg.arrays[buf_name]
    # The receive buffer is an output of the library node.
    buf_node = state.add_write(buf_name)

    req_range = None
    if isinstance(request, tuple):
        req_name, req_range = request
    else:
        req_name = request
    req_desc = sdfg.arrays[req_name]
    req_node = state.add_write(req_name)

    # Both outputs are passed as pointers to their element types.
    libnode.out_connectors = {
        c: (dtypes.pointer(desc.dtype) if c == '_buffer' else
            dtypes.pointer(req_desc.dtype) if c == '_request' else t)
        for c, t in libnode.out_connectors.items()
    }

    def _int_input(value, tasklet_name):
        # Returns (data name, optional range, access node) for an int32 input
        # given as a (name, range) tuple, an existing data name, or a value.
        # The branches are mutually exclusive; previously a tuple tag also
        # fell into the "constant value" branch and was overwritten.
        if isinstance(value, tuple):
            name, rng = value
            return name, rng, state.add_read(name)
        if isinstance(value, str) and value in sdfg.arrays.keys():
            return value, None, state.add_read(value)
        name = _define_local_scalar(pv, sdfg, state, dace.int32, desc.storage)
        node = state.add_access(name)
        set_tasklet = state.add_tasklet(tasklet_name, {}, {'__out'},
                                        '__out = {}'.format(value))
        state.add_edge(set_tasklet, '__out', node, None,
                       Memlet.simple(name, '0'))
        return name, None, node

    src_name, src_range, src_node = _int_input(src, '_set_src_')
    tag_name, tag_range, tag_node = _int_input(tag, '_set_tag_')

    if buf_range:
        buf_mem = Memlet.simple(buf_name, buf_range)
    else:
        buf_mem = Memlet.from_array(buf_name, desc)
    if req_range:
        req_mem = Memlet.simple(req_name, req_range)
    else:
        req_mem = Memlet.from_array(req_name, req_desc)
    # Scalars without an explicit range are accessed at element 0.
    src_mem = Memlet.simple(src_name, src_range or '0')
    tag_mem = Memlet.simple(tag_name, tag_range or '0')

    state.add_edge(libnode, '_buffer', buf_node, None, buf_mem)
    state.add_edge(src_node, None, libnode, '_src', src_mem)
    state.add_edge(tag_node, None, libnode, '_tag', tag_mem)
    state.add_edge(libnode, '_request', req_node, None, req_mem)

    return None
예제 #22
0
def _elementwise(sdfg: SDFG,
                 state: SDFGState,
                 func: str,
                 in_array: str,
                 out_array=None):
    """Apply a single-argument lambda function to each element of an array.

    :param sdfg: The SDFG containing ``in_array`` (and ``out_array``).
    :param state: The state to which the computation is added.
    :param func: Source code of a lambda with exactly one argument, e.g.
                 ``"lambda x: x * 2"``.
    :param in_array: Name of the input data container.
    :param out_array: Name of the output data container. If None, a new
                      temporary transient with the input's shape, dtype and
                      storage is created.
    :return: The name of the output data container.
    :raise SyntaxError: If ``func`` is not an expression holding a lambda
                        with exactly one argument.
    """
    inparr = sdfg.arrays[in_array]
    restype = inparr.dtype

    if out_array is None:
        out_array, outarr = sdfg.add_temp_transient(inparr.shape, restype,
                                                    inparr.storage)
    else:
        outarr = sdfg.arrays[out_array]

    func_ast = ast.parse(func)
    try:
        lambda_ast = func_ast.body[0].value
        lambda_args = lambda_ast.args.args
        if len(lambda_args) != 1:
            # Raised directly to the caller (not caught by the handler below)
            raise SyntaxError(
                "Expected lambda with one arg, but {} has {}".format(
                    func, len(lambda_args)))
        arg = lambda_args[0].arg
        body = astutils.unparse(lambda_ast.body)
    except (AttributeError, IndexError) as ex:
        # Empty source, a non-expression statement, or an expression that
        # is not a lambda
        raise SyntaxError("Could not parse func {}".format(func)) from ex

    code = "__out = {}".format(body)

    num_elements = reduce(lambda x, y: x * y, inparr.shape)
    if num_elements == 1:
        # Single element: one tasklet, no map
        inp = state.add_read(in_array)
        out = state.add_write(out_array)
        tasklet = state.add_tasklet("_elementwise_", {arg}, {'__out'}, code)
        state.add_edge(inp, None, tasklet, arg,
                       Memlet.from_array(in_array, inparr))
        state.add_edge(tasklet, '__out', out, None,
                       Memlet.from_array(out_array, outarr))
    else:
        # One map dimension per array dimension; element index "__i0,__i1,..."
        index = ','.join(['__i%d' % i for i in range(len(inparr.shape))])
        state.add_mapped_tasklet(
            name="_elementwise_",
            map_ranges={
                '__i%d' % i: '0:%s' % n
                for i, n in enumerate(inparr.shape)
            },
            inputs={arg: Memlet.simple(in_array, index)},
            code=code,
            outputs={'__out': Memlet.simple(out_array, index)},
            external_edges=True)

    return out_array
예제 #23
0
def nest_state_subgraph(sdfg: SDFG,
                        state: SDFGState,
                        subgraph: SubgraphView,
                        name: Optional[str] = None,
                        full_data: bool = False) -> nodes.NestedSDFG:
    """ Turns a state subgraph into a nested SDFG. Operates in-place.

        Every edge that crosses the subgraph boundary is reconnected to the
        new nested SDFG node. Transients that are only accessed inside the
        subgraph are moved into the nested SDFG and removed from ``sdfg``.

        :param sdfg: The SDFG containing the state subgraph.
        :param state: The state containing the subgraph.
        :param subgraph: Subgraph to nest.
        :param name: An optional name for the nested SDFG (defaults to
                     ``'nested_' + state.label``).
        :param full_data: If True, nests entire input/output data; otherwise
                          each boundary container is shrunk to the subset
                          accessed by its boundary edge.
        :return: The nested SDFG node.
        :raise KeyError: Some or all nodes in the subgraph are not located in
                         this state, or the state does not belong to the given
                         SDFG.
        :raise ValueError: The subgraph is contained in more than one scope,
                           or contains only part of a scope.
    """
    if state.parent != sdfg:
        raise KeyError('State does not belong to given SDFG')
    if subgraph.graph != state:
        raise KeyError('Subgraph does not belong to given state')

    # Ordered node list (for insertion) and a set (for O(1) membership)
    snodes = subgraph.nodes()
    snode_set = set(snodes)

    # Find the top-level scope
    scope_tree = state.scope_tree()
    scope_dict = state.scope_dict()
    scope_dict_children = state.scope_dict(True)
    top_scopenode = -1  # Initialized to -1 since "None" already means top-level

    for node in snodes:
        if node not in scope_dict:
            raise KeyError('Node not found in state')

        # If scope entry/exit, ensure entire scope is in subgraph
        if isinstance(node, nodes.EntryNode):
            scope_nodes = scope_dict_children[node]
            if any(n not in snode_set for n in scope_nodes):
                raise ValueError('Subgraph contains partial scopes (entry)')
        elif isinstance(node, nodes.ExitNode):
            entry = state.entry_node(node)
            scope_nodes = scope_dict_children[entry] + [entry]
            if any(n not in snode_set for n in scope_nodes):
                raise ValueError('Subgraph contains partial scopes (exit)')

        scope_node = scope_dict[node]
        if scope_node not in snode_set:
            if top_scopenode != -1 and top_scopenode != scope_node:
                raise ValueError(
                    'Subgraph is contained in more than one scope')
            top_scopenode = scope_node

    scope = scope_tree[top_scopenode]

    # Collect every edge with exactly one endpoint in the subgraph. Looking
    # only at source/sink nodes would miss external edges that attach to
    # interior nodes, and those edges would be lost when the nodes are removed.
    inputs: List[MultiConnectorEdge] = []
    outputs: List[MultiConnectorEdge] = []
    for node in snodes:
        inputs.extend(e for e in state.in_edges(node)
                      if e.src not in snode_set)
        outputs.extend(e for e in state.out_edges(node)
                       if e.dst not in snode_set)

    # Collect transients not used outside of subgraph (will be removed from
    # the top-level graph)
    data_in_subgraph = set(n.data for n in snodes
                           if isinstance(n, nodes.AccessNode))
    # Find other occurrences in SDFG
    other_nodes = set(
        n.data for s in sdfg.nodes() for n in s.nodes()
        if isinstance(n, nodes.AccessNode) and n not in snode_set)
    subgraph_transients = set()
    for data in data_in_subgraph:
        datadesc = sdfg.arrays[data]
        if datadesc.transient and data not in other_nodes:
            subgraph_transients.add(data)

    # All transients of edges between code nodes are also added to nested graph
    for edge in subgraph.edges():
        if (isinstance(edge.src, nodes.CodeNode)
                and isinstance(edge.dst, nodes.CodeNode)):
            subgraph_transients.add(edge.data.data)

    # Collect data used in access nodes within subgraph (will be referenced in
    # full upon nesting)
    input_arrays = set()
    output_arrays = set()
    for node in snodes:
        if (isinstance(node, nodes.AccessNode)
                and node.data not in subgraph_transients):
            if state.out_degree(node) > 0:
                input_arrays.add(node.data)
            if state.in_degree(node) > 0:
                output_arrays.add(node.data)

    # Create the nested SDFG
    nsdfg = SDFG(name or 'nested_' + state.label)

    # Transients are added to the nested graph as-is
    for dname in subgraph_transients:
        nsdfg.add_datadesc(dname, sdfg.arrays[dname])

    # Input/output data that are not source/sink nodes are added to the graph
    # as non-transients
    for dname in (input_arrays | output_arrays):
        datadesc = copy.deepcopy(sdfg.arrays[dname])
        datadesc.transient = False
        nsdfg.add_datadesc(dname, datadesc)

    # Data on boundary edges becomes global (non-transient) data in the nested
    # SDFG. Each new name is kept paired with its edge: empty-memlet edges get
    # no name, so separate name/edge lists would fall out of step.
    input_pairs = []  # (nested data/connector name, boundary edge)
    output_pairs = []
    empty_input = None
    empty_output = None
    for edge in inputs:
        if edge.data.data is None:  # Empty memlet: nothing to nest
            empty_input = edge
            continue
        datadesc = copy.deepcopy(sdfg.arrays[edge.data.data])
        datadesc.transient = False
        if not full_data:
            datadesc.shape = edge.data.subset.size()
        conn = nsdfg.add_datadesc('__in_' + edge.data.data,
                                  datadesc,
                                  find_new_name=True)
        input_pairs.append((conn, edge))
    for edge in outputs:
        if edge.data.data is None:  # Empty memlet: nothing to nest
            empty_output = edge
            continue
        datadesc = copy.deepcopy(sdfg.arrays[edge.data.data])
        datadesc.transient = False
        if not full_data:
            datadesc.shape = edge.data.subset.size()
        conn = nsdfg.add_datadesc('__out_' + edge.data.data,
                                  datadesc,
                                  find_new_name=True)
        output_pairs.append((conn, edge))
    input_names = [conn for conn, _ in input_pairs]
    output_names = [conn for conn, _ in output_pairs]

    # Add scope symbols to the nested SDFG
    for v in scope.defined_vars:
        if v in sdfg.symbols:
            sym = sdfg.symbols[v]
            nsdfg.add_symbol(v, sym.dtype)

    # Create nested state
    nstate = nsdfg.add_state()

    # Add subgraph nodes and edges to nested state
    nstate.add_nodes_from(snodes)
    for e in subgraph.edges():
        nstate.add_edge(e.src, e.src_conn, e.dst, e.dst_conn, e.data)

    # Modify nested SDFG parents in subgraph
    for node in snodes:
        if isinstance(node, nodes.NestedSDFG):
            node.sdfg.parent = nstate
            node.sdfg.parent_sdfg = nsdfg

    # Add access nodes and edges as necessary
    edges_to_offset = []
    for conn, edge in input_pairs:
        node = nstate.add_read(conn)
        new_memlet = copy.deepcopy(edge.data)
        new_memlet.data = conn
        edges_to_offset.append((edge,
                                nstate.add_edge(node, None, edge.dst,
                                                edge.dst_conn, new_memlet)))
    for conn, edge in output_pairs:
        node = nstate.add_write(conn)
        new_memlet = copy.deepcopy(edge.data)
        new_memlet.data = conn
        edges_to_offset.append((edge,
                                nstate.add_edge(edge.src, edge.src_conn, node,
                                                None, new_memlet)))

    # Offset memlet paths inside nested SDFG according to subsets
    for original_edge, new_edge in edges_to_offset:
        for edge in nstate.memlet_tree(new_edge):
            edge.data.data = new_edge.data.data
            if not full_data:
                edge.data.subset.offset(original_edge.data.subset, True)

    # Add nested SDFG node to the input state
    nested_sdfg = state.add_nested_sdfg(nsdfg, None,
                                        set(input_names) | input_arrays,
                                        set(output_names) | output_arrays)

    # Reconnect memlets to nested SDFG
    for conn, edge in input_pairs:
        if full_data:
            data = Memlet.from_array(edge.data.data,
                                     sdfg.arrays[edge.data.data])
        else:
            data = edge.data
        state.add_edge(edge.src, edge.src_conn, nested_sdfg, conn, data)
    for conn, edge in output_pairs:
        if full_data:
            data = Memlet.from_array(edge.data.data,
                                     sdfg.arrays[edge.data.data])
        else:
            data = edge.data
        state.add_edge(nested_sdfg, conn, edge.dst, edge.dst_conn, data)

    # Connect access nodes to internal input/output data as necessary
    scope_entry = scope.entry
    scope_exit = scope.exit
    for dname in input_arrays:
        node = state.add_read(dname)
        if scope_entry is not None:
            state.add_nedge(scope_entry, node, EmptyMemlet())
        state.add_edge(node, None, nested_sdfg, dname,
                       Memlet.from_array(dname, sdfg.arrays[dname]))
    for dname in output_arrays:
        node = state.add_write(dname)
        if scope_exit is not None:
            state.add_nedge(node, scope_exit, EmptyMemlet())
        state.add_edge(nested_sdfg, dname, node, None,
                       Memlet.from_array(dname, sdfg.arrays[dname]))

    # Keep the node attached to its surroundings through an original empty
    # memlet if nothing else connects it
    if state.in_degree(nested_sdfg) == 0 and empty_input is not None:
        state.add_edge(empty_input.src, empty_input.src_conn, nested_sdfg,
                       None, empty_input.data)
    if state.out_degree(nested_sdfg) == 0 and empty_output is not None:
        state.add_edge(nested_sdfg, None, empty_output.dst,
                       empty_output.dst_conn, empty_output.data)

    # Remove subgraph nodes from graph
    state.remove_nodes_from(snodes)

    # Remove subgraph transients from top-level graph
    for transient in subgraph_transients:
        del sdfg.arrays[transient]

    return nested_sdfg
예제 #24
0
파일: distr.py 프로젝트: am-ivanov/dace
def _distr_matmult_block_sizes(sdfg: SDFG, state: SDFGState, block_sizes,
                               tasklet_name: str):
    """ Creates the block-sizes input of a distributed matmul library node.

        :param sdfg: The SDFG being built.
        :param state: The state receiving the new nodes.
        :param block_sizes: One of the following:
                            - the name of an existing array;
                            - a ``(name, range)`` pair selecting a subset of
                              an existing array;
                            - a sequence of sizes, which a tasklet writes into
                              a new int32 transient.
        :param tasklet_name: Label of the tasklet used for literal sizes.
        :return: A tuple ``(access_node, memlet)`` to connect to the node.
    """
    bsizes_range = None
    if isinstance(block_sizes, (list, tuple)):
        if isinstance(block_sizes[0], str):
            # (array name, subset) pair referring to existing data
            bsizes_name, bsizes_range = block_sizes
            bsizes_desc = sdfg.arrays[bsizes_name]
            bsizes_node = state.add_read(bsizes_name)
        else:
            # Literal sizes: materialize them in a temporary int32 array
            bsizes_name, bsizes_desc = sdfg.add_temp_transient(
                (len(block_sizes), ), dtype=dace.int32)
            bsizes_node = state.add_access(bsizes_name)
            bsizes_tasklet = state.add_tasklet(
                tasklet_name, {}, {'__out'}, ";".join([
                    "__out[{}] = {}".format(i, sz)
                    for i, sz in enumerate(block_sizes)
                ]))
            state.add_edge(bsizes_tasklet, '__out', bsizes_node, None,
                           Memlet.from_array(bsizes_name, bsizes_desc))
    else:
        # Name of an existing array holding the block sizes
        bsizes_name = block_sizes
        bsizes_desc = sdfg.arrays[bsizes_name]
        bsizes_node = state.add_read(bsizes_name)

    if bsizes_range:
        bsizes_mem = Memlet.simple(bsizes_name, bsizes_range)
    else:
        bsizes_mem = Memlet.from_array(bsizes_name, bsizes_desc)
    return bsizes_node, bsizes_mem


def _distr_matmult(pv: 'ProgramVisitor',
                   sdfg: SDFG,
                   state: SDFGState,
                   opa: str,
                   opb: str,
                   shape: Sequence[Union[sp.Expr, Number]],
                   a_block_sizes: Union[str, Sequence[Union[sp.Expr,
                                                            Number]]] = None,
                   b_block_sizes: Union[str, Sequence[Union[sp.Expr,
                                                            Number]]] = None,
                   c_block_sizes: Union[str, Sequence[Union[sp.Expr,
                                                            Number]]] = None):
    """ Adds a distributed matrix-matrix (Pgemm) or matrix-vector (Pgemv)
        product of ``opa`` and ``opb`` to ``state``.

        :param shape: Global problem sizes: ``(m, n, k)`` for matrix-matrix
                      products, ``(m, n)`` for matrix-vector products.
        :param a_block_sizes: Block sizes of ``opa``. This is an array name, a
                              ``(name, range)`` pair, or a sequence of sizes.
                              It defaults to the local shape of ``opa``.
        :param b_block_sizes: Same as ``a_block_sizes``, for ``opb``.
        :param c_block_sizes: Optional sizes; the first entry overrides the
                              local length of a matrix-vector result.
        :return: Name of the new transient holding the local result.
        :raise ValueError: If the operand ranks are not 2x2, 2x1 or 1x2, or
                           a matrix-matrix product gets a 2-element ``shape``.
    """
    arra = sdfg.arrays[opa]
    arrb = sdfg.arrays[opb]

    if len(shape) == 3:
        gm, gn, gk = shape
    else:
        gm, gn = shape
        gk = None  # Only needed (and required) for matrix-matrix products

    # Validate before mutating the graph
    ranks = (len(arra.shape), len(arrb.shape))
    if ranks not in ((2, 2), (2, 1), (1, 2)):
        raise ValueError(
            'Distributed matmult supports operand ranks 2x2, 2x1 and 1x2, '
            'got {}x{}'.format(*ranks))
    if ranks == (2, 2) and gk is None:
        raise ValueError('Distributed matrix-matrix product requires a '
                         'three-element shape (m, n, k)')

    # Default block sizes to the local operand shapes; pad 1-D sizes to 2-D.
    # A string names an existing array and must not be padded.
    a_block_sizes = a_block_sizes or arra.shape
    if not isinstance(a_block_sizes, str) and len(a_block_sizes) < 2:
        a_block_sizes = (a_block_sizes[0], 1)
    b_block_sizes = b_block_sizes or arrb.shape
    if not isinstance(b_block_sizes, str) and len(b_block_sizes) < 2:
        b_block_sizes = (b_block_sizes[0], 1)

    # Vector-matrix products are computed as transposed matrix-vector
    # products below, so the block sizes are swapped along with the operands
    if ranks == (1, 2):
        a_block_sizes, b_block_sizes = b_block_sizes, a_block_sizes

    a_bsizes_node, a_bsizes_mem = _distr_matmult_block_sizes(
        sdfg, state, a_block_sizes, '_set_a_bsizes_')
    b_bsizes_node, b_bsizes_mem = _distr_matmult_block_sizes(
        sdfg, state, b_block_sizes, '_set_b_sizes_')

    if ranks == (2, 2):
        # Gemm
        from dace.libraries.pblas.nodes.pgemm import Pgemm
        tasklet = Pgemm("__DistrMatMult__", gm, gn, gk)
        m = arra.shape[0]
        n = arrb.shape[-1]
        out = sdfg.add_temp_transient((m, n), dtype=arra.dtype)
    elif ranks == (2, 1):
        # Gemv
        from dace.libraries.pblas.nodes.pgemv import Pgemv
        tasklet = Pgemv("__DistrMatVecMult__", m=gm, n=gn)
        if c_block_sizes:
            m = c_block_sizes[0]
        else:
            m = arra.shape[0]
        out = sdfg.add_temp_transient((m, ), dtype=arra.dtype)
    else:
        # Gemv transposed
        # Swap a and b
        opa, opb = opb, opa
        arra, arrb = arrb, arra
        from dace.libraries.pblas.nodes.pgemv import Pgemv
        tasklet = Pgemv("__DistrMatVecMult__", transa='T', m=gm, n=gn)
        if c_block_sizes:
            n = c_block_sizes[0]
        else:
            n = arra.shape[1]
        out = sdfg.add_temp_transient((n, ), dtype=arra.dtype)

    anode = state.add_read(opa)
    bnode = state.add_read(opb)
    cnode = state.add_write(out[0])

    state.add_edge(anode, None, tasklet, '_a', Memlet.from_array(opa, arra))
    state.add_edge(bnode, None, tasklet, '_b', Memlet.from_array(opb, arrb))
    state.add_edge(a_bsizes_node, None, tasklet, '_a_block_sizes',
                   a_bsizes_mem)
    state.add_edge(b_bsizes_node, None, tasklet, '_b_block_sizes',
                   b_bsizes_mem)
    state.add_edge(tasklet, '_c', cnode, None, Memlet.from_array(*out))

    return out[0]
예제 #25
0
파일: helpers.py 프로젝트: mfkiwl/dace
def nest_state_subgraph(sdfg: SDFG,
                        state: SDFGState,
                        subgraph: SubgraphView,
                        name: Optional[str] = None,
                        full_data: bool = False) -> nodes.NestedSDFG:
    """ Turns a state subgraph into a nested SDFG. Operates in-place.

        Edges that cross the subgraph boundary are reconnected to the new
        nested SDFG node. Transients accessed only inside the subgraph (and
        data on edges between two code nodes) are moved into the nested SDFG
        and removed from ``sdfg``. Symbols defined at the enclosing scope that
        are SDFG symbols, and all constants of ``sdfg``, are copied into the
        nested SDFG.

        :param sdfg: The SDFG containing the state subgraph.
        :param state: The state containing the subgraph.
        :param subgraph: Subgraph to nest (``state`` itself is also accepted).
        :param name: An optional name for the nested SDFG (defaults to
                     ``'nested_' + state.label``).
        :param full_data: If True, nests entire input/output data; otherwise
                          each boundary container is shrunk to the union of
                          the subsets accessed on its boundary edges.
        :return: The nested SDFG node.
        :raise KeyError: Some or all nodes in the subgraph are not located in
                         this state, or the state does not belong to the given
                         SDFG.
        :raise ValueError: The subgraph is contained in more than one scope,
                           or contains only part of a scope.
    """
    if state.parent != sdfg:
        raise KeyError('State does not belong to given SDFG')
    if subgraph is not state and subgraph.graph is not state:
        raise KeyError('Subgraph does not belong to given state')

    # Find the top-level scope
    scope_tree = state.scope_tree()
    scope_dict = state.scope_dict()
    scope_dict_children = state.scope_children()
    top_scopenode = -1  # Initialized to -1 since "None" already means top-level

    for node in subgraph.nodes():
        if node not in scope_dict:
            raise KeyError('Node not found in state')

        # If scope entry/exit, ensure entire scope is in subgraph
        if isinstance(node, nodes.EntryNode):
            scope_nodes = scope_dict_children[node]
            if any(n not in subgraph.nodes() for n in scope_nodes):
                raise ValueError('Subgraph contains partial scopes (entry)')
        elif isinstance(node, nodes.ExitNode):
            entry = state.entry_node(node)
            scope_nodes = scope_dict_children[entry] + [entry]
            if any(n not in subgraph.nodes() for n in scope_nodes):
                raise ValueError('Subgraph contains partial scopes (exit)')

        # The scope enclosing the subgraph must be the same for every node
        # whose parent scope lies outside the subgraph
        scope_node = scope_dict[node]
        if scope_node not in subgraph.nodes():
            if top_scopenode != -1 and top_scopenode != scope_node:
                raise ValueError('Subgraph is contained in more than one scope')
            top_scopenode = scope_node

    scope = scope_tree[top_scopenode]
    ###

    # Consolidate edges in top scope
    utils.consolidate_edges(sdfg, scope)
    # Snapshot of the subgraph nodes, taken after consolidation
    snodes = subgraph.nodes()

    # Collect inputs and outputs of the nested SDFG: every edge with exactly
    # one endpoint inside the subgraph
    inputs: List[MultiConnectorEdge] = []
    outputs: List[MultiConnectorEdge] = []
    for node in snodes:
        for edge in state.in_edges(node):
            if edge.src not in snodes:
                inputs.append(edge)
        for edge in state.out_edges(node):
            if edge.dst not in snodes:
                outputs.append(edge)

    # Collect transients not used outside of subgraph (will be removed of
    # top-level graph)
    data_in_subgraph = set(n.data for n in subgraph.nodes() if isinstance(n, nodes.AccessNode))
    # Find other occurrences in SDFG
    other_nodes = set(n.data for s in sdfg.nodes() for n in s.nodes()
                      if isinstance(n, nodes.AccessNode) and n not in subgraph.nodes())
    subgraph_transients = set()
    for data in data_in_subgraph:
        datadesc = sdfg.arrays[data]
        if datadesc.transient and data not in other_nodes:
            subgraph_transients.add(data)

    # All transients of edges between code nodes are also added to nested graph
    for edge in subgraph.edges():
        if (isinstance(edge.src, nodes.CodeNode) and isinstance(edge.dst, nodes.CodeNode)):
            subgraph_transients.add(edge.data.data)

    # Collect data used in access nodes within subgraph (will be referenced in
    # full upon nesting)
    input_arrays = set()
    # Written data name -> WCR of the first in-edge of the (last visited)
    # access node writing it
    output_arrays = {}
    for node in subgraph.nodes():
        if (isinstance(node, nodes.AccessNode) and node.data not in subgraph_transients):
            if node.has_reads(state):
                input_arrays.add(node.data)
            if node.has_writes(state):
                output_arrays[node.data] = state.in_edges(node)[0].data.wcr

    # Create the nested SDFG
    nsdfg = SDFG(name or 'nested_' + state.label)

    # Transients are added to the nested graph as-is
    for name in subgraph_transients:
        nsdfg.add_datadesc(name, sdfg.arrays[name])

    # Input/output data that are not source/sink nodes are added to the graph
    # as non-transients
    for name in (input_arrays | output_arrays.keys()):
        datadesc = copy.deepcopy(sdfg.arrays[name])
        datadesc.transient = False
        nsdfg.add_datadesc(name, datadesc)

    # Connected source/sink nodes outside subgraph become global data
    # descriptors in nested SDFG
    # Boundary edge -> data/connector name in the nested SDFG
    input_names = {}
    output_names = {}
    # Outer data name -> (nested name, union of subsets accessed on boundary
    # edges). Shared by inputs and outputs, so data that is both read and
    # written across the boundary maps to a single nested name.
    global_subsets: Dict[str, Tuple[str, Subset]] = {}
    for edge in inputs:
        if edge.data.data is None:  # Skip edges with an empty memlet
            continue
        name = edge.data.data
        if name not in global_subsets:
            datadesc = copy.deepcopy(sdfg.arrays[edge.data.data])
            datadesc.transient = False
            if not full_data:
                datadesc.shape = edge.data.subset.size()
            new_name = nsdfg.add_datadesc(name, datadesc, find_new_name=True)
            global_subsets[name] = (new_name, edge.data.subset)
        else:
            new_name, subset = global_subsets[name]
            if not full_data:
                new_subset = union(subset, edge.data.subset)
                # No union could be computed: fall back to the whole array
                if new_subset is None:
                    new_subset = Range.from_array(sdfg.arrays[name])
                global_subsets[name] = (new_name, new_subset)
                nsdfg.arrays[new_name].shape = new_subset.size()
        input_names[edge] = new_name
    for edge in outputs:
        if edge.data.data is None:  # Skip edges with an empty memlet
            continue
        name = edge.data.data
        if name not in global_subsets:
            datadesc = copy.deepcopy(sdfg.arrays[edge.data.data])
            datadesc.transient = False
            if not full_data:
                datadesc.shape = edge.data.subset.size()
            new_name = nsdfg.add_datadesc(name, datadesc, find_new_name=True)
            global_subsets[name] = (new_name, edge.data.subset)
        else:
            new_name, subset = global_subsets[name]
            if not full_data:
                new_subset = union(subset, edge.data.subset)
                # No union could be computed: fall back to the whole array
                if new_subset is None:
                    new_subset = Range.from_array(sdfg.arrays[name])
                global_subsets[name] = (new_name, new_subset)
                nsdfg.arrays[new_name].shape = new_subset.size()
        output_names[edge] = new_name
    ###################

    # Add scope symbols to the nested SDFG: symbols defined at the enclosing
    # scope, restricted to those that are SDFG-level symbols
    defined_vars = set(
        symbolic.pystr_to_symbolic(s) for s in (state.symbols_defined_at(top_scopenode).keys()
                                                | sdfg.symbols))
    for v in defined_vars:
        if v in sdfg.symbols:
            sym = sdfg.symbols[v]
            nsdfg.add_symbol(v, sym.dtype)

    # Add constants to nested SDFG
    for cstname, cstval in sdfg.constants.items():
        nsdfg.add_constant(cstname, cstval)

    # Create nested state
    nstate = nsdfg.add_state()

    # Add subgraph nodes and edges to nested state (memlets are deep-copied so
    # later in-place subset offsets do not affect the outer graph)
    nstate.add_nodes_from(subgraph.nodes())
    for e in subgraph.edges():
        nstate.add_edge(e.src, e.src_conn, e.dst, e.dst_conn, copy.deepcopy(e.data))

    # Modify nested SDFG parents in subgraph
    for node in subgraph.nodes():
        if isinstance(node, nodes.NestedSDFG):
            node.sdfg.parent = nstate
            node.sdfg.parent_sdfg = nsdfg
            node.sdfg.parent_nsdfg_node = node

    # Add access nodes and edges as necessary: one new access node per
    # boundary edge, standing in for the outside endpoint
    edges_to_offset = []
    for edge, name in input_names.items():
        node = nstate.add_read(name)
        new_edge = copy.deepcopy(edge.data)
        new_edge.data = name
        edges_to_offset.append((edge, nstate.add_edge(node, None, edge.dst, edge.dst_conn, new_edge)))
    for edge, name in output_names.items():
        node = nstate.add_write(name)
        new_edge = copy.deepcopy(edge.data)
        new_edge.data = name
        edges_to_offset.append((edge, nstate.add_edge(edge.src, edge.src_conn, node, None, new_edge)))

    # Offset memlet paths inside nested SDFG according to subsets
    # (rename each memlet in the tree, and unless full_data is set, make its
    # subset relative to the accumulated subset of the nested array)
    for original_edge, new_edge in edges_to_offset:
        for edge in nstate.memlet_tree(new_edge):
            edge.data.data = new_edge.data.data
            if not full_data:
                edge.data.subset.offset(global_subsets[original_edge.data.data][1], True)

    # Add nested SDFG node to the input state
    nested_sdfg = state.add_nested_sdfg(nsdfg, None,
                                        set(input_names.values()) | input_arrays,
                                        set(output_names.values()) | output_arrays.keys())

    # Reconnect memlets to nested SDFG. Several boundary edges may map to the
    # same nested name; each connector is connected only once.
    reconnected_in = set()
    reconnected_out = set()
    # Last empty-memlet boundary edge seen, used below if the node would
    # otherwise have no in/out edges
    empty_input = None
    empty_output = None
    for edge in inputs:
        if edge.data.data is None:
            empty_input = edge
            continue

        name = input_names[edge]
        if name in reconnected_in:
            continue
        if full_data:
            data = Memlet.from_array(edge.data.data, sdfg.arrays[edge.data.data])
        else:
            data = copy.deepcopy(edge.data)
            data.subset = global_subsets[edge.data.data][1]
        state.add_edge(edge.src, edge.src_conn, nested_sdfg, name, data)
        reconnected_in.add(name)

    for edge in outputs:
        if edge.data.data is None:
            empty_output = edge
            continue

        name = output_names[edge]
        if name in reconnected_out:
            continue
        if full_data:
            data = Memlet.from_array(edge.data.data, sdfg.arrays[edge.data.data])
        else:
            data = copy.deepcopy(edge.data)
            data.subset = global_subsets[edge.data.data][1]
        # Preserve write-conflict resolution of the original boundary edge
        data.wcr = edge.data.wcr
        state.add_edge(nested_sdfg, name, edge.dst, edge.dst_conn, data)
        reconnected_out.add(name)

    # Connect access nodes to internal input/output data as necessary
    # (empty memlets keep the new access nodes inside the enclosing scope)
    entry = scope.entry
    exit = scope.exit
    for name in input_arrays:
        node = state.add_read(name)
        if entry is not None:
            state.add_nedge(entry, node, Memlet())
        state.add_edge(node, None, nested_sdfg, name, Memlet.from_array(name, sdfg.arrays[name]))
    for name, wcr in output_arrays.items():
        node = state.add_write(name)
        if exit is not None:
            state.add_nedge(node, exit, Memlet())
        state.add_edge(nested_sdfg, name, node, None, Memlet(data=name, wcr=wcr))

    # Graph was not reconnected, but needs to be
    if state.in_degree(nested_sdfg) == 0 and empty_input is not None:
        state.add_edge(empty_input.src, empty_input.src_conn, nested_sdfg, None, empty_input.data)
    if state.out_degree(nested_sdfg) == 0 and empty_output is not None:
        state.add_edge(nested_sdfg, None, empty_output.dst, empty_output.dst_conn, empty_output.data)

    # Remove subgraph nodes from graph
    state.remove_nodes_from(subgraph.nodes())

    # Remove subgraph transients from top-level graph
    for transient in subgraph_transients:
        del sdfg.arrays[transient]

    # Remove newly isolated nodes due to memlet consolidation
    for edge in inputs:
        if state.in_degree(edge.src) + state.out_degree(edge.src) == 0:
            state.remove_node(edge.src)
    for edge in outputs:
        if state.in_degree(edge.dst) + state.out_degree(edge.dst) == 0:
            state.remove_node(edge.dst)

    return nested_sdfg
예제 #26
0
# Thread-local stream test SDFG ('tlstream'). Inside a map over i in [0, N),
# a tasklet writes i to the transient stream 'ls'. 'ls' is connected to the
# 10-element transient array 'la', and 'la' leaves the map through its exit
# into the transient stream 'gs'. 'gs' is finally connected to the array 'ga'.
N = dp.symbol('N')
sdfg = SDFG('tlstream')
state = sdfg.add_state('doit')

# Data containers (access nodes)
localarr = state.add_transient('la', [10], dp.float32)
localstream = state.add_stream('ls', dp.float32, 1, transient=True)
globalstream = state.add_stream('gs', dp.float32, 1, transient=True)
globalarr = state.add_array('ga', [N], dp.float32)

# Map over 0:N and a tasklet emitting the iteration index on connector 'a'
me, mx = state.add_map('par', dict(i='0:N'))
tasklet = state.add_tasklet('arange', set(), {'a'}, 'a = i')

# Empty memlet only places the tasklet inside the map scope
state.add_nedge(me, tasklet, EmptyMemlet())
state.add_edge(tasklet, 'a', localstream, None,
               Memlet.from_array(localstream.data, localstream.desc(sdfg)))
state.add_nedge(localstream, localarr,
                Memlet.from_array(localarr.data, localarr.desc(sdfg)))
# The memlet into the map exit already refers to the outer stream 'gs'
state.add_nedge(localarr, mx,
                Memlet.from_array(globalstream.data, globalstream.desc(sdfg)))
state.add_nedge(mx, globalstream,
                Memlet.from_array(globalstream.data, globalstream.desc(sdfg)))
state.add_nedge(globalstream, globalarr,
                Memlet.from_array(globalarr.data, globalarr.desc(sdfg)))

# Presumably fills in the map entry/exit connectors implied by the edges above
sdfg.fill_scope_connectors()

if __name__ == '__main__':
    # As visible here, this only prints a banner and sets N; it does not
    # compile or run the SDFG.
    print('Thread-local stream test')

    N.set(20)
예제 #27
0
def make_sdfg(implementation, dtype, storage=dace.StorageType.Default):
    """Build an SDFG that runs a single LAPACK ``Getrf`` library node.

    The ``n x n`` matrix is fed to the node's ``_xin`` connector. The
    ``_xout``, ``_ipiv`` and ``_res`` outputs are written to the matrix
    container, ``pivots`` and ``result`` respectively.

    With default storage, all containers are non-transient and named ``x``,
    ``pivots`` and ``result``. With any other storage, they are transient and
    carry a ``_device`` suffix. In that case the host array ``x`` is copied
    to ``x_device`` and transposed into ``xt_device`` by a cuBLAS
    ``Transpose`` node before the factorization. Afterwards it is transposed
    back and copied out to ``x``.

    :param implementation: Implementation name assigned to the Getrf node.
    :param dtype: Element type of the matrix.
    :param storage: Storage type of the containers used by the node.
    :return: The constructed SDFG.
    """
    n = dace.symbol("n")

    on_device = storage != dace.StorageType.Default
    suffix = "_device" if on_device else ""
    transient = on_device

    sdfg = dace.SDFG("matrix_lufact_getrf_{}_{}".format(
        implementation, str(dtype)))
    state = sdfg.add_state("dataflow")

    # The host-side matrix always exists, in default storage
    host_x = sdfg.add_array("x", [n, n],
                            dtype,
                            storage=dace.StorageType.Default)

    def add_container(base, shape, elem_type):
        # Containers used by the node: suffixed, with the requested storage
        return sdfg.add_array(base + suffix,
                              shape,
                              elem_type,
                              storage=storage,
                              transient=transient)

    if transient:
        dev_x = add_container("x", [n, n], dtype)
        dev_xt = add_container("xt", [n, n], dtype)
    add_container("pivots", [n], dace.dtypes.int32)
    add_container("result", [1], dace.dtypes.int32)

    if not transient:
        factor_in = state.add_access("x" + suffix)
        factor_out = state.add_access("x" + suffix)
    else:
        host_read = state.add_read("x")
        host_write = state.add_write("x")
        dev_read = state.add_access("x" + suffix)
        dev_write = state.add_access("x" + suffix)
        factor_in = state.add_access("xt" + suffix)
        factor_out = state.add_access("xt" + suffix)

        # Both transposes are implemented with cuBLAS
        transposes = []
        for label in ("transpose_in", "transpose_out"):
            node = blas.nodes.transpose.Transpose(label, dtype=dtype)
            node.implementation = "cuBLAS"
            transposes.append(node)
        t_in, t_out = transposes

        # Host <-> device copies
        state.add_nedge(host_read, dev_read, Memlet.from_array(*host_x))
        state.add_nedge(dev_write, host_write, Memlet.from_array(*host_x))

        # x_device -> transpose_in -> xt_device (factorized in between)
        # xt_device -> transpose_out -> x_device
        for src, dst, src_conn, dst_conn, container in (
            (dev_read, t_in, None, '_inp', dev_x),
            (t_in, factor_in, '_out', None, dev_xt),
            (factor_out, t_out, None, '_inp', dev_xt),
            (t_out, dev_write, '_out', None, dev_x),
        ):
            state.add_memlet_path(src,
                                  dst,
                                  src_conn=src_conn,
                                  dst_conn=dst_conn,
                                  memlet=Memlet.from_array(*container))

    pivots = state.add_access("pivots" + suffix)
    result = state.add_access("result" + suffix)

    getrf = lapack.nodes.getrf.Getrf("getrf")
    getrf.implementation = implementation

    whole_matrix = "0:n, 0:n"
    state.add_memlet_path(factor_in,
                          getrf,
                          dst_conn="_xin",
                          memlet=Memlet.simple(factor_in,
                                               whole_matrix,
                                               num_accesses=n * n))
    state.add_memlet_path(getrf,
                          result,
                          src_conn="_res",
                          memlet=Memlet.simple(result, "0", num_accesses=1))
    state.add_memlet_path(getrf,
                          pivots,
                          src_conn="_ipiv",
                          memlet=Memlet.simple(pivots, "0:n", num_accesses=n))
    state.add_memlet_path(getrf,
                          factor_out,
                          src_conn="_xout",
                          memlet=Memlet.simple(factor_out,
                                               whole_matrix,
                                               num_accesses=n * n))

    return sdfg
예제 #28
0
파일: distr.py 프로젝트: am-ivanov/dace
def _isend(pv: 'ProgramVisitor', sdfg: SDFG, state: SDFGState, buffer: str,
           dst: Union[str, sp.Expr, Number], tag: Union[str, sp.Expr,
                                                        Number], request: str):
    """Add an MPI ``Isend`` library node (and its input/output edges) to ``state``.

    :param pv: The calling ProgramVisitor; only forwarded to
               ``_define_local_scalar`` when a constant operand must be
               materialized.
    :param sdfg: SDFG that owns ``state`` and the referenced data containers.
    :param state: State in which the access nodes, tasklets and the library
                  node are created.
    :param buffer: Name of the data to send, or a ``(name, range)`` tuple to
                   send only a subset.
    :param dst: Destination rank. Either the name of an existing data
                container, a ``(name, range)`` tuple, or any other value
                (constant / symbolic expression), which is written into a new
                int32 scalar by a ``_set_dst_`` tasklet.
    :param tag: Message tag; accepted in the same forms as ``dst`` (a
                ``_set_tag_`` tasklet is used for non-container values).
    :param request: Name of the request container written by the node, or a
                    ``(name, range)`` tuple.
    :return: None (the SDFG is modified in place).
    """

    from dace.libraries.mpi.nodes.isend import Isend

    libnode = Isend('_Isend_')

    def _split(arg):
        # Data arguments are either a plain name or a ``(name, range)`` tuple.
        if isinstance(arg, tuple):
            name, rng = arg
            return name, rng
        return arg, None

    buf_name, buf_range = _split(buffer)
    desc = sdfg.arrays[buf_name]
    buf_node = state.add_read(buf_name)

    req_name, req_range = _split(request)
    req_desc = sdfg.arrays[req_name]
    req_node = state.add_write(req_name)

    # The buffer/request connectors are typed as pointers to the element
    # types of the connected containers; every other connector keeps its type.
    libnode.in_connectors = {
        c: (dtypes.pointer(desc.dtype) if c == '_buffer' else t)
        for c, t in libnode.in_connectors.items()
    }
    libnode.out_connectors = {
        c: (dtypes.pointer(req_desc.dtype) if c == '_request' else t)
        for c, t in libnode.out_connectors.items()
    }

    def _int_operand(value, tasklet_label):
        """Return ``(name, access_node, range)`` for an integer operand.

        Tuples and names of existing containers are read directly. Any other
        value is stored into a new int32 scalar (same storage as the send
        buffer) by a tasklet named ``tasklet_label``. Using one helper for
        both ``dst`` and ``tag`` keeps the two code paths consistent.
        """
        if isinstance(value, tuple):
            name, rng = value
            return name, state.add_read(name), rng
        if isinstance(value, str) and value in sdfg.arrays:
            return value, state.add_read(value), None
        name = _define_local_scalar(pv, sdfg, state, dace.int32, desc.storage)
        node = state.add_access(name)
        tasklet = state.add_tasklet(tasklet_label, {}, {'__out'},
                                    '__out = {}'.format(value))
        state.add_edge(tasklet, '__out', node, None,
                       Memlet.simple(name, '0'))
        return name, node, None

    dst_name, dst_node, dst_range = _int_operand(dst, '_set_dst_')
    tag_name, tag_node, tag_range = _int_operand(tag, '_set_tag_')

    # Without an explicit range, the buffer and request memlets cover the
    # whole container, and the scalar operands use element '0'.
    if buf_range:
        buf_mem = Memlet.simple(buf_name, buf_range)
    else:
        buf_mem = Memlet.from_array(buf_name, desc)
    if req_range:
        req_mem = Memlet.simple(req_name, req_range)
    else:
        req_mem = Memlet.from_array(req_name, req_desc)
    if dst_range:
        dst_mem = Memlet.simple(dst_name, dst_range)
    else:
        dst_mem = Memlet.simple(dst_name, '0')
    if tag_range:
        tag_mem = Memlet.simple(tag_name, tag_range)
    else:
        tag_mem = Memlet.simple(tag_name, '0')

    state.add_edge(buf_node, None, libnode, '_buffer', buf_mem)
    state.add_edge(dst_node, None, libnode, '_dest', dst_mem)
    state.add_edge(tag_node, None, libnode, '_tag', tag_mem)
    state.add_edge(libnode, '_request', req_node, None, req_mem)

    return None