Ejemplo n.º 1
0
        def _check_cand(candidates, outer_edges):
            """
            Screen refinement candidates and record rejected names in the
            enclosing ``ignore`` set.

            A candidate is rejected when, restricted to its ``indices``:
            its own subset already has all-zero minimum elements, the first
            outer edge's subset has a non-zero minimum element, the
            loop-propagated subset has all-zero minimum elements, or the
            resulting memlet uses symbols that are not keys of
            ``nsdfg.symbol_mapping``.

            :param candidates: Mapping of data name to a tuple
                               ``(memlet, state, indices)``. Entries whose
                               state has loop ranges are overwritten in place
                               with the propagated memlet.
            :param outer_edges: Callable invoked as
                                ``outer_edges(nsdfg, name)``; only the first
                                edge it yields is inspected.
            """
            def _selected_mins(subset, dims):
                # Minimum elements of ``subset`` for the dimensions in ``dims``
                return [
                    low for dim, low in enumerate(subset.min_element())
                    if dim in dims
                ]

            for name, (orig_memlet, inner_state, dims) in candidates.items():
                # Candidate subset already has zero minimums in all
                # considered dimensions
                inner_mins = _selected_mins(orig_memlet.subset, dims)
                if all(low == 0 for low in inner_mins):
                    ignore.add(name)
                    continue

                # Outer memlets must begin with 0
                first_outer = next(iter(outer_edges(nsdfg, name)))
                outer_mins = _selected_mins(first_outer.data.subset, dims)
                if any(low != 0 for low in outer_mins):
                    ignore.add(name)
                    continue

                # Without loop ranges the candidate is used as-is; otherwise
                # it is propagated through the loop ranges below
                refined = orig_memlet
                if len(inner_state.ranges) > 0:
                    # Loop ranges may be stale, so they are recomputed first
                    # TODO: Move out of here!
                    inner_state.ranges = {}
                    from dace.sdfg.propagation import _annotate_loop_ranges
                    _annotate_loop_ranges(nsdfg.sdfg, [])

                    loop_params = sorted(inner_state.ranges.keys())
                    loop_range = subsets.Range([
                        rng.ndrange()[0]
                        for _, rng in sorted(inner_state.ranges.items())
                    ])
                    refined = propagation.propagate_subset(
                        [orig_memlet], nsdfg.sdfg.arrays[name], loop_params,
                        loop_range)

                    propagated_mins = _selected_mins(refined.subset, dims)
                    if all(low == 0 for low in propagated_mins):
                        ignore.add(name)
                        continue

                    # Store the propagated memlet in place of the original
                    candidates[name] = (refined, inner_state, dims)

                # Reject candidates that depend on symbols which are not
                # mapped into the nested SDFG
                undefined = (refined.free_symbols -
                             set(nsdfg.symbol_mapping.keys()))
                if undefined:
                    ignore.add(name)
                    continue
Ejemplo n.º 2
0
    def expansion(node: 'Reduce', state: SDFGState, sdfg: SDFG):
        from dace.codegen.prettycode import CodeIOStream
        from dace.codegen.targets.cpp import unparse_cr_split, cpp_array_expr

        node.validate(sdfg, state)
        input_edge: graph.MultiConnectorEdge = state.in_edges(node)[0]
        output_edge: graph.MultiConnectorEdge = state.out_edges(node)[0]
        input_dims = len(input_edge.data.subset)
        output_dims = len(output_edge.data.subset)
        input_data = sdfg.arrays[input_edge.data.data]
        output_data = sdfg.arrays[output_edge.data.data]

        # Setup all locations in which code will be written
        cuda_globalcode = CodeIOStream()
        cuda_initcode = CodeIOStream()
        cuda_exitcode = CodeIOStream()
        host_globalcode = CodeIOStream()
        host_localcode = CodeIOStream()
        output_memlet = output_edge.data

        # Try to autodetect reduction type
        redtype = detect_reduction_type(node.wcr)

        node_id = state.node_id(node)
        state_id = sdfg.node_id(state)
        idstr = '{sdfg}_{state}_{node}'.format(sdfg=sdfg.name,
                                               state=state_id,
                                               node=node_id)

        if node.out_connectors:
            dtype = next(node.out_connectors.values())
        else:
            dtype = sdfg.arrays[output_memlet.data].dtype

        output_type = dtype.ctype

        if node.identity is None:
            raise ValueError('For device reduce nodes, initial value must be '
                             'specified')

        # Create a functor or use an existing one for reduction
        if redtype == dtypes.ReductionType.Custom:
            body, [arg1, arg2] = unparse_cr_split(sdfg, node.wcr)
            cuda_globalcode.write(
                """
        struct __reduce_{id} {{
            template <typename T>
            DACE_HDFI T operator()(const T &{arg1}, const T &{arg2}) const {{
                {contents}
            }}
        }};""".format(id=idstr, arg1=arg1, arg2=arg2, contents=body), sdfg,
                state_id, node_id)
            reduce_op = ', __reduce_' + idstr + '(), ' + symstr(node.identity)
        elif redtype in ExpandReduceCUDADevice._SPECIAL_RTYPES:
            reduce_op = ''
        else:
            credtype = 'dace::ReductionType::' + str(
                redtype)[str(redtype).find('.') + 1:]
            reduce_op = ((', dace::_wcr_fixed<%s, %s>()' %
                          (credtype, output_type)) + ', ' +
                         symstr(node.identity))

        # Obtain some SDFG-related information
        input_memlet = input_edge.data
        reduce_shape = input_memlet.subset.bounding_box_size()
        num_items = ' * '.join(symstr(s) for s in reduce_shape)
        overapprox_memlet = dcpy(input_memlet)
        if any(
                str(s) not in sdfg.free_symbols.union(sdfg.constants.keys())
                for s in overapprox_memlet.subset.free_symbols):
            propagation.propagate_states(sdfg)
            for p, r in state.ranges.items():
                overapprox_memlet = propagation.propagate_subset(
                    [overapprox_memlet], input_data, [p], r)
        overapprox_shape = overapprox_memlet.subset.bounding_box_size()
        overapprox_items = ' * '.join(symstr(s) for s in overapprox_shape)

        input_dims = input_memlet.subset.dims()
        output_dims = output_memlet.subset.data_dims()

        reduce_all_axes = (node.axes is None or len(node.axes) == input_dims)
        if reduce_all_axes:
            reduce_last_axes = False
        else:
            reduce_last_axes = sorted(node.axes) == list(
                range(input_dims - len(node.axes), input_dims))

        if not reduce_all_axes and not reduce_last_axes:
            warnings.warn(
                'Multiple axis reductions not supported with this expansion. '
                'Falling back to the pure expansion.')
            return ExpandReducePureSequentialDim.expansion(node, state, sdfg)

        # Verify that data is on the GPU
        if input_data.storage not in [
                dtypes.StorageType.GPU_Global, dtypes.StorageType.CPU_Pinned
        ]:
            warnings.warn('Input of GPU reduction must either reside '
                          ' in global GPU memory or pinned CPU memory')
            return ExpandReducePure.expansion(node, state, sdfg)

        if output_data.storage not in [
                dtypes.StorageType.GPU_Global, dtypes.StorageType.CPU_Pinned
        ]:
            warnings.warn('Output of GPU reduction must either reside '
                          ' in global GPU memory or pinned CPU memory')
            return ExpandReducePure.expansion(node, state, sdfg)

        # Determine reduction type
        kname = (ExpandReduceCUDADevice._SPECIAL_RTYPES[redtype] if redtype
                 in ExpandReduceCUDADevice._SPECIAL_RTYPES else 'Reduce')

        # Create temp memory for this GPU
        cuda_globalcode.write(
            """
            void *__cub_storage_{sdfg}_{state}_{node} = NULL;
            size_t __cub_ssize_{sdfg}_{state}_{node} = 0;
        """.format(sdfg=sdfg.name, state=state_id, node=node_id), sdfg,
            state_id, node)

        if reduce_all_axes:
            reduce_type = 'DeviceReduce'
            reduce_range = overapprox_items
            reduce_range_def = 'size_t num_items'
            reduce_range_use = 'num_items'
            reduce_range_call = num_items
        elif reduce_last_axes:
            num_reduce_axes = len(node.axes)
            not_reduce_axes = reduce_shape[:-num_reduce_axes]
            reduce_axes = reduce_shape[-num_reduce_axes:]
            overapprox_not_reduce_axes = overapprox_shape[:-num_reduce_axes]
            overapprox_reduce_axes = overapprox_shape[-num_reduce_axes:]

            num_segments = ' * '.join([symstr(s) for s in not_reduce_axes])
            segment_size = ' * '.join([symstr(s) for s in reduce_axes])
            overapprox_num_segments = ' * '.join(
                [symstr(s) for s in overapprox_not_reduce_axes])
            overapprox_segment_size = ' * '.join(
                [symstr(s) for s in overapprox_reduce_axes])

            reduce_type = 'DeviceSegmentedReduce'
            iterator = 'dace::stridedIterator({size})'.format(
                size=overapprox_segment_size)
            reduce_range = '{num}, {it}, {it} + 1'.format(
                num=overapprox_num_segments, it=iterator)
            reduce_range_def = 'size_t num_segments, size_t segment_size'
            iterator_use = 'dace::stridedIterator(segment_size)'
            reduce_range_use = 'num_segments, {it}, {it} + 1'.format(
                it=iterator_use)
            reduce_range_call = '%s, %s' % (num_segments, segment_size)

        # Call CUB to get the storage size, allocate and free it
        cuda_initcode.write(
            """
            cub::{reduce_type}::{kname}(nullptr, __cub_ssize_{sdfg}_{state}_{node},
                                        ({intype}*)nullptr, ({outtype}*)nullptr, {reduce_range}{redop});
            cudaMalloc(&__cub_storage_{sdfg}_{state}_{node}, __cub_ssize_{sdfg}_{state}_{node});
""".format(sdfg=sdfg.name,
           state=state_id,
           node=node_id,
           reduce_type=reduce_type,
           reduce_range=reduce_range,
           redop=reduce_op,
           intype=input_data.dtype.ctype,
           outtype=output_data.dtype.ctype,
           kname=kname), sdfg, state_id, node)

        cuda_exitcode.write(
            'cudaFree(__cub_storage_{sdfg}_{state}_{node});'.format(
                sdfg=sdfg.name, state=state_id, node=node_id), sdfg, state_id,
            node)

        # Write reduction function definition
        cuda_globalcode.write("""
DACE_EXPORTED void __dace_reduce_{id}({intype} *input, {outtype} *output, {reduce_range_def}, cudaStream_t stream);
void __dace_reduce_{id}({intype} *input, {outtype} *output, {reduce_range_def}, cudaStream_t stream)
{{
cub::{reduce_type}::{kname}(__cub_storage_{id}, __cub_ssize_{id},
                            input, output, {reduce_range_use}{redop}, stream);
}}
        """.format(id=idstr,
                   intype=input_data.dtype.ctype,
                   outtype=output_data.dtype.ctype,
                   reduce_type=reduce_type,
                   reduce_range_def=reduce_range_def,
                   reduce_range_use=reduce_range_use,
                   kname=kname,
                   redop=reduce_op))

        # Write reduction function definition in caller file
        host_globalcode.write(
            """
DACE_EXPORTED void __dace_reduce_{id}({intype} *input, {outtype} *output, {reduce_range_def}, cudaStream_t stream);
        """.format(id=idstr,
                   reduce_range_def=reduce_range_def,
                   intype=input_data.dtype.ctype,
                   outtype=output_data.dtype.ctype), sdfg, state_id, node)

        # Call reduction function where necessary
        host_localcode.write(
            '__dace_reduce_{id}(_in, _out, {reduce_range_call}, __dace_current_stream);'
            .format(id=idstr, reduce_range_call=reduce_range_call))

        # Make tasklet
        tnode = dace.nodes.Tasklet('reduce',
                                   {'_in': dace.pointer(input_data.dtype)},
                                   {'_out': dace.pointer(output_data.dtype)},
                                   host_localcode.getvalue(),
                                   language=dace.Language.CPP)

        # Add the rest of the code
        sdfg.append_global_code(host_globalcode.getvalue())
        sdfg.append_global_code(cuda_globalcode.getvalue(), 'cuda')
        sdfg.append_init_code(cuda_initcode.getvalue(), 'cuda')
        sdfg.append_exit_code(cuda_exitcode.getvalue(), 'cuda')

        # Rename outer connectors and add to node
        input_edge._dst_conn = '_in'
        output_edge._src_conn = '_out'
        node.add_in_connector('_in')
        node.add_out_connector('_out')

        return tnode
Ejemplo n.º 3
0
    def can_be_applied(self, graph, candidate, expr_index, sdfg, strict=False):
        """
        Decide whether the matched loop may be transformed.

        The match is rejected when any of the following holds:

        * ``DetectLoop.can_be_applied`` rejects it, the guard state contains
          nodes, or ``find_for_loop`` cannot recover the loop.
        * The start, end or step expression contains sympy functions.
        * A non-WCR write inside the loop body is dynamic, or its destination
          subset fails ``_check_range``.
        * A read of a container that is also written in the loop is dynamic
          with more than one element, fails ``_check_range``, or may overlap
          any of the recorded writes to that container once both are
          propagated over the iteration range.
        * The iteration variable appears in the free symbols of a state
          visited after ``begin`` in topological order, outside the loop
          body, before all outgoing edges of a state reassign it.

        :param graph: The state machine graph in which the loop was matched.
        :param candidate: Mapping from pattern nodes to node IDs in ``graph``.
        :param expr_index: Index of the matched pattern expression.
        :param sdfg: The SDFG being transformed.
        :param strict: Forwarded to ``DetectLoop.can_be_applied``.
        :return: True if none of the rejection conditions holds.
        """
        # Is this even a loop
        if not DetectLoop.can_be_applied(graph, candidate, expr_index, sdfg,
                                         strict):
            return False

        guard = graph.node(candidate[DetectLoop._loop_guard])
        begin = graph.node(candidate[DetectLoop._loop_begin])

        # Guard state should not contain any dataflow
        if len(guard.nodes()) != 0:
            return False

        # If loop cannot be detected, fail
        found = find_for_loop(graph, guard, begin,
                              itervar=self.itervar)
        if not found:
            return False

        itervar, (start, end, step), (_, body_end) = found

        # We cannot handle symbols read from data containers unless they are
        # scalar
        for expr in (start, end, step):
            if symbolic.contains_sympy_functions(expr):
                return False

        # Find all loop-body states: everything reachable from ``begin``
        # without traversing past ``body_end``
        states = set([body_end])
        to_visit = [begin]
        while to_visit:
            state = to_visit.pop(0)
            if state is body_end:
                continue
            for _, dst, _ in graph.out_edges(state):
                if dst not in states:
                    to_visit.append(dst)
            states.add(state)

        write_set = set()
        for state in states:
            _, wset = state.read_and_write_sets()
            write_set |= wset

        # Get access nodes from other states to isolate local loop variables
        other_access_nodes = set()
        for state in sdfg.nodes():
            if state in states:
                continue
            other_access_nodes |= set(n.data for n in state.data_nodes()
                                      if sdfg.arrays[n.data].transient)
        # Add non-transient nodes from loop state
        for state in states:
            other_access_nodes |= set(n.data for n in state.data_nodes()
                                      if not sdfg.arrays[n.data].transient)

        write_memlets = defaultdict(list)

        itersym = symbolic.pystr_to_symbolic(itervar)
        a = sp.Wild('a', exclude=[itersym])
        b = sp.Wild('b', exclude=[itersym])

        for state in states:
            for dn in state.data_nodes():
                if dn.data not in other_access_nodes:
                    continue
                # Take all writes that are not conflicted into consideration
                if dn.data in write_set:
                    for e in state.in_edges(dn):
                        if e.data.dynamic and e.data.wcr is None:
                            # If pointers are involved, give up
                            return False
                        # To be sure that the value is only written at unique
                        # indices per loop iteration, we want to match symbols
                        # of the form "a*i+b" where a >= 1, and i is the iteration
                        # variable. The iteration variable must be used.
                        if e.data.wcr is None:
                            dst_subset = e.data.get_dst_subset(e, state)
                            if not _check_range(dst_subset, a, itersym, b, step):
                                return False
                        # End of check

                        write_memlets[dn.data].append(e.data)

        # Imported locally, as in the original code; hoisted out of the
        # nested loops below
        from dace.sdfg.propagation import propagate_subset

        # The iteration range is loop-invariant for all propagation calls
        iter_range = subsets.Range([(start, end, step)])

        # After looping over relevant writes, consider reads that may overlap
        for state in states:
            for dn in state.data_nodes():
                if dn.data not in other_access_nodes:
                    continue
                data = dn.data
                if data not in write_memlets:
                    continue

                for e in state.out_edges(dn):
                    # If the same container is both read and written, only
                    # match if it is read and written at locations that will
                    # not create data races
                    if e.data.dynamic and e.data.src_subset.num_elements() != 1:
                        # If pointers are involved, give up
                        return False
                    src_subset = e.data.get_src_subset(e, state)
                    if not _check_range(src_subset, a, itersym, b, step):
                        return False

                    pread = propagate_subset([e.data], sdfg.arrays[data],
                                             [itervar], iter_range)
                    # Every recorded write must be compatible with this read.
                    # ``continue`` (not ``break``) is required so that writes
                    # after the first one are also checked.
                    for wmemlet in write_memlets[data]:
                        # Simple case: read and write are in the same subset
                        if e.data.subset == wmemlet.subset:
                            continue
                        # Propagated read does not overlap with propagated
                        # write
                        pwrite = propagate_subset([wmemlet],
                                                  sdfg.arrays[data],
                                                  [itervar], iter_range)
                        if subsets.intersects(pread.subset,
                                              pwrite.subset) is False:
                            continue
                        # Possible (or undecidable) overlap
                        return False

        # Check that the iteration variable is not used on other edges or states
        # before it is reassigned
        prior_states = True
        for state in cfg.stateorder_topological_sort(sdfg):
            # Skip all states up to and including the first loop-body state
            if prior_states:
                if state is begin:
                    prior_states = False
                continue
            # We do not need to check the loop-body states
            if state in states:
                continue
            if itervar in state.free_symbols:
                return False
            # Don't continue in this direction, as the variable has
            # now been reassigned
            # TODO: Handle case of subset of out_edges
            if all(itervar in e.data.assignments
                   for e in sdfg.out_edges(state)):
                break

        return True