예제 #1
0
    def expansion(node: 'Reduce', state: SDFGState, sdfg: SDFG):
        from dace.codegen.prettycode import CodeIOStream
        from dace.codegen.targets.cpp import unparse_cr_split, cpp_array_expr

        node.validate(sdfg, state)
        input_edge: graph.MultiConnectorEdge = state.in_edges(node)[0]
        output_edge: graph.MultiConnectorEdge = state.out_edges(node)[0]
        input_dims = len(input_edge.data.subset)
        output_dims = len(output_edge.data.subset)
        input_data = sdfg.arrays[input_edge.data.data]
        output_data = sdfg.arrays[output_edge.data.data]

        # Setup all locations in which code will be written
        cuda_globalcode = CodeIOStream()
        cuda_initcode = CodeIOStream()
        cuda_exitcode = CodeIOStream()
        host_globalcode = CodeIOStream()
        host_localcode = CodeIOStream()
        output_memlet = output_edge.data

        # Try to autodetect reduction type
        redtype = detect_reduction_type(node.wcr)

        node_id = state.node_id(node)
        state_id = sdfg.node_id(state)
        idstr = '{sdfg}_{state}_{node}'.format(sdfg=sdfg.name,
                                               state=state_id,
                                               node=node_id)

        if node.out_connectors:
            dtype = next(node.out_connectors.values())
        else:
            dtype = sdfg.arrays[output_memlet.data].dtype

        output_type = dtype.ctype

        if node.identity is None:
            raise ValueError('For device reduce nodes, initial value must be '
                             'specified')

        # Create a functor or use an existing one for reduction
        if redtype == dtypes.ReductionType.Custom:
            body, [arg1, arg2] = unparse_cr_split(sdfg, node.wcr)
            cuda_globalcode.write(
                """
        struct __reduce_{id} {{
            template <typename T>
            DACE_HDFI T operator()(const T &{arg1}, const T &{arg2}) const {{
                {contents}
            }}
        }};""".format(id=idstr, arg1=arg1, arg2=arg2, contents=body), sdfg,
                state_id, node_id)
            reduce_op = ', __reduce_' + idstr + '(), ' + symstr(node.identity)
        elif redtype in ExpandReduceCUDADevice._SPECIAL_RTYPES:
            reduce_op = ''
        else:
            credtype = 'dace::ReductionType::' + str(
                redtype)[str(redtype).find('.') + 1:]
            reduce_op = ((', dace::_wcr_fixed<%s, %s>()' %
                          (credtype, output_type)) + ', ' +
                         symstr(node.identity))

        # Obtain some SDFG-related information
        input_memlet = input_edge.data
        reduce_shape = input_memlet.subset.bounding_box_size()
        num_items = ' * '.join(symstr(s) for s in reduce_shape)
        overapprox_memlet = dcpy(input_memlet)
        if any(
                str(s) not in sdfg.free_symbols.union(sdfg.constants.keys())
                for s in overapprox_memlet.subset.free_symbols):
            propagation.propagate_states(sdfg)
            for p, r in state.ranges.items():
                overapprox_memlet = propagation.propagate_subset(
                    [overapprox_memlet], input_data, [p], r)
        overapprox_shape = overapprox_memlet.subset.bounding_box_size()
        overapprox_items = ' * '.join(symstr(s) for s in overapprox_shape)

        input_dims = input_memlet.subset.dims()
        output_dims = output_memlet.subset.data_dims()

        reduce_all_axes = (node.axes is None or len(node.axes) == input_dims)
        if reduce_all_axes:
            reduce_last_axes = False
        else:
            reduce_last_axes = sorted(node.axes) == list(
                range(input_dims - len(node.axes), input_dims))

        if not reduce_all_axes and not reduce_last_axes:
            warnings.warn(
                'Multiple axis reductions not supported with this expansion. '
                'Falling back to the pure expansion.')
            return ExpandReducePureSequentialDim.expansion(node, state, sdfg)

        # Verify that data is on the GPU
        if input_data.storage not in [
                dtypes.StorageType.GPU_Global, dtypes.StorageType.CPU_Pinned
        ]:
            warnings.warn('Input of GPU reduction must either reside '
                          ' in global GPU memory or pinned CPU memory')
            return ExpandReducePure.expansion(node, state, sdfg)

        if output_data.storage not in [
                dtypes.StorageType.GPU_Global, dtypes.StorageType.CPU_Pinned
        ]:
            warnings.warn('Output of GPU reduction must either reside '
                          ' in global GPU memory or pinned CPU memory')
            return ExpandReducePure.expansion(node, state, sdfg)

        # Determine reduction type
        kname = (ExpandReduceCUDADevice._SPECIAL_RTYPES[redtype] if redtype
                 in ExpandReduceCUDADevice._SPECIAL_RTYPES else 'Reduce')

        # Create temp memory for this GPU
        cuda_globalcode.write(
            """
            void *__cub_storage_{sdfg}_{state}_{node} = NULL;
            size_t __cub_ssize_{sdfg}_{state}_{node} = 0;
        """.format(sdfg=sdfg.name, state=state_id, node=node_id), sdfg,
            state_id, node)

        if reduce_all_axes:
            reduce_type = 'DeviceReduce'
            reduce_range = overapprox_items
            reduce_range_def = 'size_t num_items'
            reduce_range_use = 'num_items'
            reduce_range_call = num_items
        elif reduce_last_axes:
            num_reduce_axes = len(node.axes)
            not_reduce_axes = reduce_shape[:-num_reduce_axes]
            reduce_axes = reduce_shape[-num_reduce_axes:]
            overapprox_not_reduce_axes = overapprox_shape[:-num_reduce_axes]
            overapprox_reduce_axes = overapprox_shape[-num_reduce_axes:]

            num_segments = ' * '.join([symstr(s) for s in not_reduce_axes])
            segment_size = ' * '.join([symstr(s) for s in reduce_axes])
            overapprox_num_segments = ' * '.join(
                [symstr(s) for s in overapprox_not_reduce_axes])
            overapprox_segment_size = ' * '.join(
                [symstr(s) for s in overapprox_reduce_axes])

            reduce_type = 'DeviceSegmentedReduce'
            iterator = 'dace::stridedIterator({size})'.format(
                size=overapprox_segment_size)
            reduce_range = '{num}, {it}, {it} + 1'.format(
                num=overapprox_num_segments, it=iterator)
            reduce_range_def = 'size_t num_segments, size_t segment_size'
            iterator_use = 'dace::stridedIterator(segment_size)'
            reduce_range_use = 'num_segments, {it}, {it} + 1'.format(
                it=iterator_use)
            reduce_range_call = '%s, %s' % (num_segments, segment_size)

        # Call CUB to get the storage size, allocate and free it
        cuda_initcode.write(
            """
            cub::{reduce_type}::{kname}(nullptr, __cub_ssize_{sdfg}_{state}_{node},
                                        ({intype}*)nullptr, ({outtype}*)nullptr, {reduce_range}{redop});
            cudaMalloc(&__cub_storage_{sdfg}_{state}_{node}, __cub_ssize_{sdfg}_{state}_{node});
""".format(sdfg=sdfg.name,
           state=state_id,
           node=node_id,
           reduce_type=reduce_type,
           reduce_range=reduce_range,
           redop=reduce_op,
           intype=input_data.dtype.ctype,
           outtype=output_data.dtype.ctype,
           kname=kname), sdfg, state_id, node)

        cuda_exitcode.write(
            'cudaFree(__cub_storage_{sdfg}_{state}_{node});'.format(
                sdfg=sdfg.name, state=state_id, node=node_id), sdfg, state_id,
            node)

        # Write reduction function definition
        cuda_globalcode.write("""
DACE_EXPORTED void __dace_reduce_{id}({intype} *input, {outtype} *output, {reduce_range_def}, cudaStream_t stream);
void __dace_reduce_{id}({intype} *input, {outtype} *output, {reduce_range_def}, cudaStream_t stream)
{{
cub::{reduce_type}::{kname}(__cub_storage_{id}, __cub_ssize_{id},
                            input, output, {reduce_range_use}{redop}, stream);
}}
        """.format(id=idstr,
                   intype=input_data.dtype.ctype,
                   outtype=output_data.dtype.ctype,
                   reduce_type=reduce_type,
                   reduce_range_def=reduce_range_def,
                   reduce_range_use=reduce_range_use,
                   kname=kname,
                   redop=reduce_op))

        # Write reduction function definition in caller file
        host_globalcode.write(
            """
DACE_EXPORTED void __dace_reduce_{id}({intype} *input, {outtype} *output, {reduce_range_def}, cudaStream_t stream);
        """.format(id=idstr,
                   reduce_range_def=reduce_range_def,
                   intype=input_data.dtype.ctype,
                   outtype=output_data.dtype.ctype), sdfg, state_id, node)

        # Call reduction function where necessary
        host_localcode.write(
            '__dace_reduce_{id}(_in, _out, {reduce_range_call}, __dace_current_stream);'
            .format(id=idstr, reduce_range_call=reduce_range_call))

        # Make tasklet
        tnode = dace.nodes.Tasklet('reduce',
                                   {'_in': dace.pointer(input_data.dtype)},
                                   {'_out': dace.pointer(output_data.dtype)},
                                   host_localcode.getvalue(),
                                   language=dace.Language.CPP)

        # Add the rest of the code
        sdfg.append_global_code(host_globalcode.getvalue())
        sdfg.append_global_code(cuda_globalcode.getvalue(), 'cuda')
        sdfg.append_init_code(cuda_initcode.getvalue(), 'cuda')
        sdfg.append_exit_code(cuda_exitcode.getvalue(), 'cuda')

        # Rename outer connectors and add to node
        input_edge._dst_conn = '_in'
        output_edge._src_conn = '_out'
        node.add_in_connector('_in')
        node.add_out_connector('_out')

        return tnode
예제 #2
0
def test_conditional_full_merge():
    """
    Walk the SDFG of a nested if/else followed by a common statement and
    pass the expected execution count of every visited state to
    ``state_check_executions`` after running ``propagate_states``.
    """
    @dace.program(dace.int32, dace.int32, dace.int32)
    def conditional_full_merge(a, b, c):
        if a < 10:
            if b < 10:
                c = 0
            else:
                c = 1
        c += 1

    sdfg = conditional_full_merge.to_sdfg(strict=False)
    propagate_states(sdfg)

    def follow(src):
        # Destination of the first outgoing edge of ``src``.
        return sdfg.out_edges(src)[0].dst

    def branch_edges(guard, true_label, false_label):
        # Find the outgoing edges of ``guard`` carrying the two labels;
        # both must be present.
        on_true, on_false = None, None
        for candidate in sdfg.out_edges(guard):
            if candidate.data.label == true_label:
                on_true = candidate
            elif candidate.data.label == false_label:
                on_false = candidate
        if on_false is None or on_true is None:
            raise RuntimeError('Couldn\'t identify guard edges')
        return on_true, on_false

    # Start state.
    start = sdfg.start_state
    state_check_executions(start, 1)

    # Guard of the first conditional, `a < 10`.
    outer_guard = follow(start)
    state_check_executions(outer_guard, 1)
    outer_true, _ = branch_edges(outer_guard, '(a < 10)', '(not (a < 10))')

    # True branch of the first conditional.
    outer_body = outer_true.dst
    state_check_executions(outer_body, 1, expected_dynamic=True)

    # Guard of the second conditional, `b < 10`.
    inner_guard = follow(outer_body)
    state_check_executions(inner_guard, 1, expected_dynamic=True)
    inner_true, inner_false = branch_edges(inner_guard, '(b < 10)',
                                           '(not (b < 10))')

    # True branch first, then false branch: two consecutive states each.
    for arm in (inner_true, inner_false):
        state = arm.dst
        state_check_executions(state, 1, expected_dynamic=True)
        state = follow(state)
        state_check_executions(state, 1, expected_dynamic=True)

    # First branch merge state, reached from the end of the false branch.
    first_merge = follow(state)
    state_check_executions(first_merge, 1, expected_dynamic=True)

    # Second branch merge state.
    second_merge = follow(first_merge)
    state_check_executions(second_merge, 1)

    # State following the second merge.
    state_check_executions(follow(second_merge), 1)
예제 #3
0
def test_3_fold_nested_loop():
    """
    Walk the SDFG of a triangular 3-level loop nest and pass the expected
    execution count of each guard, loop-exit and loop-body state to
    ``state_check_executions`` after running ``propagate_states``.
    """
    @dace.program(dace.int32[20, 20])
    def nested_3(A):
        for i in range(20):
            for j in range(i, 20):
                for k in range(i, j):
                    A[k, j] += 5

    sdfg = nested_3.to_sdfg(strict=False)
    propagate_states(sdfg)

    def loop_edges(guard, enter_label, leave_label):
        # Find the loop-entry and loop-exit edges leaving ``guard``.
        enter, leave = None, None
        for candidate in sdfg.out_edges(guard):
            if candidate.data.label == enter_label:
                enter = candidate
            elif candidate.data.label == leave_label:
                leave = candidate
        if leave is None or enter is None:
            raise RuntimeError('Couldn\'t identify guard edges')
        return enter, leave

    # One row per loop level, outermost first:
    # (entry label, exit label, guard count, loop-end count, body count)
    levels = (
        # `for i in range(20)`
        ('(i < 20)', '(not (i < 20))', 21, 1, 20),
        # `for j in range(i, 20)`
        ('(j < 20)', '(not (j < 20))', 230, 20, 210),
        # `for k in range(i, j)`
        ('(k < j)', '(not (k < j))', 1540, 210, 1330),
    )

    # Check start state.
    state = sdfg.start_state
    state_check_executions(state, 1)

    for enter_label, leave_label, n_guard, n_end, n_body in levels:
        # Loop guard.
        guard = sdfg.out_edges(state)[0].dst
        state_check_executions(guard, n_guard)
        enter, leave = loop_edges(guard, enter_label, leave_label)
        # Loop-end branch.
        state_check_executions(leave.dst, n_end)
        # Inside the loop.
        state = enter.dst
        state_check_executions(state, n_body)

    # State following the first state of the innermost loop body.
    state = sdfg.out_edges(state)[0].dst
    state_check_executions(state, 1330)
예제 #4
0
def test_3_fold_nested_loop_with_symbolic_bounds():
    """
    Walk the SDFG of a 3-level loop nest with symbolic bounds ``N``, ``M``
    and ``K`` and pass the expected symbolic execution count of each guard,
    loop-exit and loop-body state to ``state_check_executions`` after
    running ``propagate_states``.
    """
    N = dace.symbol('N')
    M = dace.symbol('M')
    K = dace.symbol('K')

    @dace.program(dace.int32)
    def nested_3_symbolic(a):
        for i in range(N):
            for j in range(M):
                for k in range(K):
                    a += 5

    sdfg = nested_3_symbolic.to_sdfg(strict=False)
    propagate_states(sdfg)

    def loop_edges(guard, enter_label, leave_label):
        # Find the loop-entry and loop-exit edges leaving ``guard``.
        enter, leave = None, None
        for candidate in sdfg.out_edges(guard):
            if candidate.data.label == enter_label:
                enter = candidate
            elif candidate.data.label == leave_label:
                leave = candidate
        if leave is None or enter is None:
            raise RuntimeError('Couldn\'t identify guard edges')
        return enter, leave

    # One row per loop level, outermost first:
    # (entry label, exit label, guard count, loop-end count, body count)
    levels = (
        # `for i in range(N)`
        ('(i < N)', '(not (i < N))', N + 1, 1, N),
        # `for j in range(M)`
        ('(j < M)', '(not (j < M))', M * N + N, N, M * N),
        # `for k in range(K)`
        ('(k < K)', '(not (k < K))', M * N * K + M * N, M * N, M * N * K),
    )

    # Check start state.
    state = sdfg.start_state
    state_check_executions(state, 1)

    for enter_label, leave_label, n_guard, n_end, n_body in levels:
        # Loop guard.
        guard = sdfg.out_edges(state)[0].dst
        state_check_executions(guard, n_guard)
        enter, leave = loop_edges(guard, enter_label, leave_label)
        # Loop-end branch.
        state_check_executions(leave.dst, n_end)
        # Inside the loop.
        state = enter.dst
        state_check_executions(state, n_body)

    # State following the first state of the innermost loop body.
    state = sdfg.out_edges(state)[0].dst
    state_check_executions(state, M * N * K)
예제 #5
0
def test_while_with_nested_full_merge_branch():
    """
    Walk the SDFG of a while loop whose body is an if/else and pass the
    expected execution count of every visited state to
    ``state_check_executions`` after running ``propagate_states``. All
    states from the loop guard onward, except the loop exit, are checked
    with a count of 0 and ``expected_dynamic=True``.
    """
    @dace.program(dace.int32)
    def while_with_nested_full_merge_branch(a):
        while a < 20:
            if a < 10:
                a += 2
            else:
                a += 1

    sdfg = while_with_nested_full_merge_branch.to_sdfg(strict=False)
    propagate_states(sdfg)

    def follow(src):
        # Destination of the first outgoing edge of ``src``.
        return sdfg.out_edges(src)[0].dst

    def pick_edges(guard, first_label, second_label, failure):
        # Find the outgoing edges of ``guard`` carrying the two labels;
        # raise RuntimeError(failure) unless both are present.
        first, second = None, None
        for candidate in sdfg.out_edges(guard):
            if candidate.data.label == first_label:
                first = candidate
            elif candidate.data.label == second_label:
                second = candidate
        if second is None or first is None:
            raise RuntimeError(failure)
        return first, second

    def check_dynamic(target):
        state_check_executions(target, 0, expected_dynamic=True)

    # Check start state.
    start = sdfg.start_state
    state_check_executions(start, 1)

    # While loop guard, `while a < 20`.
    loop_guard = follow(start)
    check_dynamic(loop_guard)
    into_loop, out_of_loop = pick_edges(loop_guard, '(a < 20)',
                                        '(not (a < 20))',
                                        'Couldn\'t identify guard edges')
    # Loop-end branch.
    state_check_executions(out_of_loop.dst, 1)
    # Inside the loop.
    loop_body = into_loop.dst
    check_dynamic(loop_body)

    # Branch guard, `if a < 10`.
    branch_guard = follow(loop_body)
    check_dynamic(branch_guard)
    condition_met, condition_broken = pick_edges(
        branch_guard, '(a < 10)', '(not (a < 10))',
        'Couldn\'t identify conditional guard edges')

    # 'True' branch first, then 'false' branch: two consecutive states each.
    for arm in (condition_met, condition_broken):
        state = arm.dst
        check_dynamic(state)
        state = follow(state)
        check_dynamic(state)

    # Where the branches meet again, reached from the end of the 'false'
    # branch.
    check_dynamic(follow(state))
def test_while_inside_for():
    """
    Walk the SDFG of a for loop containing a while loop and pass the
    expected execution count of every visited state to
    ``state_check_executions`` after running ``propagate_states``. The for
    loop states are checked with fixed counts; the while guard and body are
    checked with a count of 0 and ``expected_dynamic=True``.
    """
    @dace.program(dace.int32)
    def while_inside_for(a):
        for i in range(20):
            j = 0
            while j < 20:
                a += 5

    sdfg = while_inside_for.to_sdfg(strict=False)
    propagate_states(sdfg)

    def follow(src):
        # Destination of the first outgoing edge of ``src``.
        return sdfg.out_edges(src)[0].dst

    def loop_edges(guard, enter_label, leave_label):
        # Find the loop-entry and loop-exit edges leaving ``guard``.
        enter, leave = None, None
        for candidate in sdfg.out_edges(guard):
            if candidate.data.label == enter_label:
                enter = candidate
            elif candidate.data.label == leave_label:
                leave = candidate
        if leave is None or enter is None:
            raise RuntimeError(
                'Couldn\'t identify guard edges'
            )
        return enter, leave

    # Check start state.
    start = sdfg.start_state
    state_check_executions(start, 1)

    # For loop guard, `i in range(20)`.
    for_guard = follow(start)
    state_check_executions(for_guard, 21)
    for_enter, for_leave = loop_edges(for_guard, '(i < 20)',
                                      '(not (i < 20))')
    # Loop-end branch.
    state_check_executions(for_leave.dst, 1)
    # Inside the for loop: two consecutive states.
    for_body = for_enter.dst
    state_check_executions(for_body, 20)
    for_body_next = follow(for_body)
    state_check_executions(for_body_next, 20)

    # While guard, `j < 20`.
    while_guard = follow(for_body_next)
    state_check_executions(while_guard, 0, expected_dynamic=True)
    while_enter, while_leave = loop_edges(while_guard, '(j < 20)',
                                          '(not (j < 20))')
    # Loop-end branch.
    state_check_executions(while_leave.dst, 20)
    # Inside the while loop: two consecutive states.
    while_body = while_enter.dst
    state_check_executions(while_body, 0, expected_dynamic=True)
    state_check_executions(follow(while_body), 0, expected_dynamic=True)