def expansion(node: 'Reduce', state: SDFGState, sdfg: SDFG):
    """Expand a ``Reduce`` node into a call to a CUB device reduction.

    Generated artifacts:

    * CUDA global code:
      * an optional custom reduction functor;
      * the CUB temporary-storage pointer/size globals;
      * the exported ``__dace_reduce_<id>`` wrapper.
    * CUDA init code: queries the CUB temporary-storage size and
      ``cudaMalloc``s it.
    * CUDA exit code: ``cudaFree``s the temporary storage.
    * Host global code: a declaration of the wrapper.
    * A C++ tasklet that calls the wrapper on ``_in``/``_out``.

    Reducing all axes uses ``cub::DeviceReduce``. Reducing a trailing run of
    axes uses ``cub::DeviceSegmentedReduce``. Any other axis set falls back
    to ``ExpandReducePureSequentialDim``. Input or output data outside
    GPU-global / pinned-CPU storage falls back to ``ExpandReducePure``.

    Side effects:

    * May run ``propagation.propagate_states(sdfg)`` when the input subset
      uses symbols that are not SDFG free symbols or constants.
    * On the CUB path, appends code to ``sdfg``, renames the adjacent edge
      connectors to ``_in``/``_out`` and adds those connectors to ``node``.

    :param node: The reduction library node to expand.
    :param state: The state that contains ``node``.
    :param sdfg: The SDFG that contains ``state``.
    :return: The replacement tasklet, or whatever a fallback expansion
             returns.
    :raises ValueError: If ``node.identity`` is ``None``.
    """
    from dace.codegen.prettycode import CodeIOStream
    from dace.codegen.targets.cpp import unparse_cr_split

    def product_expr(sizes):
        # C++ expression for the product of symbolic sizes, e.g. "N * M".
        return ' * '.join(symstr(s) for s in sizes)

    node.validate(sdfg, state)
    input_edge: graph.MultiConnectorEdge = state.in_edges(node)[0]
    output_edge: graph.MultiConnectorEdge = state.out_edges(node)[0]
    input_memlet = input_edge.data
    output_memlet = output_edge.data
    input_data = sdfg.arrays[input_memlet.data]
    output_data = sdfg.arrays[output_memlet.data]

    # Setup all locations in which code will be written
    cuda_globalcode = CodeIOStream()
    cuda_initcode = CodeIOStream()
    cuda_exitcode = CodeIOStream()
    host_globalcode = CodeIOStream()
    host_localcode = CodeIOStream()

    # Try to autodetect reduction type
    redtype = detect_reduction_type(node.wcr)

    node_id = state.node_id(node)
    state_id = sdfg.node_id(state)
    # Suffix that makes every generated C++ symbol unique to this node.
    idstr = '{sdfg}_{state}_{node}'.format(sdfg=sdfg.name,
                                           state=state_id,
                                           node=node_id)

    if node.out_connectors:
        # ``next`` requires an iterator. A values *view* (e.g. ``dict_values``)
        # is only iterable, so it must be wrapped with ``iter`` first.
        dtype = next(iter(node.out_connectors.values()))
    else:
        dtype = sdfg.arrays[output_memlet.data].dtype
    output_type = dtype.ctype

    if node.identity is None:
        raise ValueError('For device reduce nodes, initial value must be '
                         'specified')

    # Create a functor or use an existing one for reduction
    if redtype == dtypes.ReductionType.Custom:
        body, [arg1, arg2] = unparse_cr_split(sdfg, node.wcr)
        cuda_globalcode.write(
            """
struct __reduce_{id} {{
    template <typename T>
    DACE_HDFI T operator()(const T &{arg1}, const T &{arg2}) const {{
        {contents}
    }}
}};""".format(id=idstr, arg1=arg1, arg2=arg2, contents=body), sdfg,
            state_id, node_id)
        reduce_op = ', __reduce_' + idstr + '(), ' + symstr(node.identity)
    elif redtype in ExpandReduceCUDADevice._SPECIAL_RTYPES:
        # The CUB entry point chosen from _SPECIAL_RTYPES (``kname`` below)
        # fixes the operator, so neither a functor nor an identity is passed.
        reduce_op = ''
    else:
        # Strip the enum class prefix ("ReductionType.X" -> "X").
        credtype = 'dace::ReductionType::' + str(redtype)[
            str(redtype).find('.') + 1:]
        reduce_op = (', dace::_wcr_fixed<%s, %s>()' % (credtype, output_type) +
                     ', ' + symstr(node.identity))

    # Obtain some SDFG-related information
    reduce_shape = input_memlet.subset.bounding_box_size()
    num_items = product_expr(reduce_shape)

    # The temporary storage is sized once at init time. If the input subset
    # depends on symbols that are not known at SDFG scope (e.g. map
    # parameters), propagate it outward to obtain an over-approximation.
    overapprox_memlet = dcpy(input_memlet)
    known_symbols = sdfg.free_symbols.union(sdfg.constants.keys())
    if any(str(s) not in known_symbols
           for s in overapprox_memlet.subset.free_symbols):
        propagation.propagate_states(sdfg)
        for p, r in state.ranges.items():
            overapprox_memlet = propagation.propagate_subset(
                [overapprox_memlet], input_data, [p], r)
    overapprox_shape = overapprox_memlet.subset.bounding_box_size()
    overapprox_items = product_expr(overapprox_shape)

    input_dims = input_memlet.subset.dims()
    reduce_all_axes = (node.axes is None or len(node.axes) == input_dims)
    if reduce_all_axes:
        reduce_last_axes = False
    else:
        reduce_last_axes = sorted(node.axes) == list(
            range(input_dims - len(node.axes), input_dims))

    if not reduce_all_axes and not reduce_last_axes:
        warnings.warn(
            'Multiple axis reductions not supported with this expansion. '
            'Falling back to the pure expansion.')
        return ExpandReducePureSequentialDim.expansion(node, state, sdfg)

    # Verify that data is on the GPU
    gpu_accessible = [
        dtypes.StorageType.GPU_Global, dtypes.StorageType.CPU_Pinned
    ]
    if input_data.storage not in gpu_accessible:
        warnings.warn('Input of GPU reduction must either reside '
                      ' in global GPU memory or pinned CPU memory')
        return ExpandReducePure.expansion(node, state, sdfg)
    if output_data.storage not in gpu_accessible:
        warnings.warn('Output of GPU reduction must either reside '
                      ' in global GPU memory or pinned CPU memory')
        return ExpandReducePure.expansion(node, state, sdfg)

    # Determine reduction type
    kname = (ExpandReduceCUDADevice._SPECIAL_RTYPES[redtype]
             if redtype in ExpandReduceCUDADevice._SPECIAL_RTYPES else 'Reduce')

    # Create temp memory for this GPU
    cuda_globalcode.write(
        """
        void *__cub_storage_{id} = NULL;
        size_t __cub_ssize_{id} = 0;
        """.format(id=idstr), sdfg, state_id, node)

    if reduce_all_axes:
        reduce_type = 'DeviceReduce'
        reduce_range = overapprox_items
        reduce_range_def = 'size_t num_items'
        reduce_range_use = 'num_items'
        reduce_range_call = num_items
    else:
        # Trailing-axes reduction: every leading index is one segment of
        # ``segment_size`` contiguous elements, whose offsets are produced by
        # a strided iterator (begin = it, end = it + 1).
        num_reduce_axes = len(node.axes)
        num_segments = product_expr(reduce_shape[:-num_reduce_axes])
        segment_size = product_expr(reduce_shape[-num_reduce_axes:])
        overapprox_num_segments = product_expr(
            overapprox_shape[:-num_reduce_axes])
        overapprox_segment_size = product_expr(
            overapprox_shape[-num_reduce_axes:])

        reduce_type = 'DeviceSegmentedReduce'
        iterator = 'dace::stridedIterator({size})'.format(
            size=overapprox_segment_size)
        reduce_range = '{num}, {it}, {it} + 1'.format(
            num=overapprox_num_segments, it=iterator)
        reduce_range_def = 'size_t num_segments, size_t segment_size'
        iterator_use = 'dace::stridedIterator(segment_size)'
        reduce_range_use = 'num_segments, {it}, {it} + 1'.format(
            it=iterator_use)
        reduce_range_call = '%s, %s' % (num_segments, segment_size)

    # Call CUB to get the storage size, allocate and free it
    cuda_initcode.write(
        """
        cub::{reduce_type}::{kname}(nullptr, __cub_ssize_{id},
            ({intype}*)nullptr, ({outtype}*)nullptr, {reduce_range}{redop});
        cudaMalloc(&__cub_storage_{id}, __cub_ssize_{id});
        """.format(id=idstr,
                   reduce_type=reduce_type,
                   kname=kname,
                   intype=input_data.dtype.ctype,
                   outtype=output_data.dtype.ctype,
                   reduce_range=reduce_range,
                   redop=reduce_op), sdfg, state_id, node)

    cuda_exitcode.write('cudaFree(__cub_storage_{id});'.format(id=idstr),
                        sdfg, state_id, node)

    # Write reduction function definition
    cuda_globalcode.write("""
        DACE_EXPORTED void __dace_reduce_{id}({intype} *input, {outtype} *output, {reduce_range_def}, cudaStream_t stream);
        void __dace_reduce_{id}({intype} *input, {outtype} *output, {reduce_range_def}, cudaStream_t stream)
        {{
            cub::{reduce_type}::{kname}(__cub_storage_{id}, __cub_ssize_{id},
                input, output, {reduce_range_use}{redop}, stream);
        }}
        """.format(id=idstr,
                   intype=input_data.dtype.ctype,
                   outtype=output_data.dtype.ctype,
                   reduce_type=reduce_type,
                   reduce_range_def=reduce_range_def,
                   reduce_range_use=reduce_range_use,
                   kname=kname,
                   redop=reduce_op))

    # Write reduction function definition in caller file
    host_globalcode.write(
        """
        DACE_EXPORTED void __dace_reduce_{id}({intype} *input, {outtype} *output, {reduce_range_def}, cudaStream_t stream);
        """.format(id=idstr,
                   reduce_range_def=reduce_range_def,
                   intype=input_data.dtype.ctype,
                   outtype=output_data.dtype.ctype), sdfg, state_id, node)

    # Call reduction function where necessary
    host_localcode.write(
        '__dace_reduce_{id}(_in, _out, {reduce_range_call}, '
        '__dace_current_stream);'.format(id=idstr,
                                         reduce_range_call=reduce_range_call))

    # Make tasklet
    tnode = dace.nodes.Tasklet('reduce',
                               {'_in': dace.pointer(input_data.dtype)},
                               {'_out': dace.pointer(output_data.dtype)},
                               host_localcode.getvalue(),
                               language=dace.Language.CPP)

    # Add the rest of the code
    sdfg.append_global_code(host_globalcode.getvalue())
    sdfg.append_global_code(cuda_globalcode.getvalue(), 'cuda')
    sdfg.append_init_code(cuda_initcode.getvalue(), 'cuda')
    sdfg.append_exit_code(cuda_exitcode.getvalue(), 'cuda')

    # Rename outer connectors and add to node
    input_edge._dst_conn = '_in'
    output_edge._src_conn = '_out'
    node.add_in_connector('_in')
    node.add_out_connector('_out')

    return tnode
def test_conditional_full_merge():
    """Check execution counts around an if/else nested inside an if.

    The states of the outer true branch and both inner branches depend on
    ``a`` and ``b``, so they are expected to be dynamic with count 1. The
    outer merge and the final ``c += 1`` run exactly once.
    """
    @dace.program(dace.int32, dace.int32, dace.int32)
    def conditional_full_merge(a, b, c):
        if a < 10:
            if b < 10:
                c = 0
            else:
                c = 1
        c += 1

    sdfg = conditional_full_merge.to_sdfg(strict=False)
    propagate_states(sdfg)

    def guard_edges(guard_state, condition):
        """Return the (true, false) out-edges of ``guard_state``.

        Edges are matched by interstate-edge label: ``(<condition>)`` for the
        true branch and ``(not (<condition>))`` for the false branch.
        """
        true_edge = None
        false_edge = None
        for edge in sdfg.out_edges(guard_state):
            if edge.data.label == '({})'.format(condition):
                true_edge = edge
            elif edge.data.label == '(not ({}))'.format(condition):
                false_edge = edge
        if false_edge is None or true_edge is None:
            raise RuntimeError('Couldn\'t identify guard edges')
        return true_edge, false_edge

    # Check start state.
    state = sdfg.start_state
    state_check_executions(state, 1)

    # Check the first if guard, `a < 10`.
    state = sdfg.out_edges(state)[0].dst
    state_check_executions(state, 1)

    # Get edges to the true and false branches.
    true_branch_edge, false_branch_edge = guard_edges(state, 'a < 10')

    # Check the true branch.
    state = true_branch_edge.dst
    state_check_executions(state, 1, expected_dynamic=True)

    # Check the nested if guard, `b < 10`.
    state = sdfg.out_edges(state)[0].dst
    state_check_executions(state, 1, expected_dynamic=True)

    # Get edges to the true and false branches.
    true_branch_edge, false_branch_edge = guard_edges(state, 'b < 10')

    # Check the true branch.
    state = true_branch_edge.dst
    state_check_executions(state, 1, expected_dynamic=True)
    state = sdfg.out_edges(state)[0].dst
    state_check_executions(state, 1, expected_dynamic=True)

    # Check the false branch.
    state = false_branch_edge.dst
    state_check_executions(state, 1, expected_dynamic=True)
    state = sdfg.out_edges(state)[0].dst
    state_check_executions(state, 1, expected_dynamic=True)

    # Check the first (inner) branch merge state.
    state = sdfg.out_edges(state)[0].dst
    state_check_executions(state, 1, expected_dynamic=True)

    # Check the second (outer) branch merge state.
    state = sdfg.out_edges(state)[0].dst
    state_check_executions(state, 1)

    state = sdfg.out_edges(state)[0].dst
    state_check_executions(state, 1)
def test_3_fold_nested_loop():
    """Check execution counts for three nested loops with dependent bounds.

    Loop guards run once per iteration plus one failing check per entry:

    * i-guard: 20 + 1 = 21.
    * j-guard: sum over i of (20 - i + 1) = 210 + 20 = 230.
    * k-guard: 1330 iterations + 210 entries = 1540.
    """
    @dace.program(dace.int32[20, 20])
    def nested_3(A):
        for i in range(20):
            for j in range(i, 20):
                for k in range(i, j):
                    A[k, j] += 5

    sdfg = nested_3.to_sdfg(strict=False)
    propagate_states(sdfg)

    def guard_edges(guard_state, condition):
        """Return the (loop-body, loop-exit) out-edges of ``guard_state``.

        Edges are matched by interstate-edge label: ``(<condition>)`` enters
        the loop and ``(not (<condition>))`` leaves it.
        """
        body_edge = None
        exit_edge = None
        for edge in sdfg.out_edges(guard_state):
            if edge.data.label == '({})'.format(condition):
                body_edge = edge
            elif edge.data.label == '(not ({}))'.format(condition):
                exit_edge = edge
        if exit_edge is None or body_edge is None:
            raise RuntimeError('Couldn\'t identify guard edges')
        return body_edge, exit_edge

    # Check start state.
    state = sdfg.start_state
    state_check_executions(state, 1)

    # 1st level loop, check loop guard, `for i in range(20)`.
    state = sdfg.out_edges(state)[0].dst
    state_check_executions(state, 21)

    # Get edges to inside and outside the loop.
    for_branch_edge, end_branch_edge = guard_edges(state, 'i < 20')

    # Check loop-end branch.
    state = end_branch_edge.dst
    state_check_executions(state, 1)

    # Check inside the loop.
    state = for_branch_edge.dst
    state_check_executions(state, 20)

    # 2nd level nested loop, check loop guard, `for j in range(i, 20)`.
    state = sdfg.out_edges(state)[0].dst
    state_check_executions(state, 230)

    # Get edges to inside and outside the loop.
    for_branch_edge, end_branch_edge = guard_edges(state, 'j < 20')

    # Check loop-end branch.
    state = end_branch_edge.dst
    state_check_executions(state, 20)

    # Check inside the loop.
    state = for_branch_edge.dst
    state_check_executions(state, 210)

    # 3rd level nested loop, check loop guard, `for k in range(i, j)`.
    state = sdfg.out_edges(state)[0].dst
    state_check_executions(state, 1540)

    # Get edges to inside and outside the loop.
    for_branch_edge, end_branch_edge = guard_edges(state, 'k < j')

    # Check loop-end branch.
    state = end_branch_edge.dst
    state_check_executions(state, 210)

    # Check inside the loop.
    state = for_branch_edge.dst
    state_check_executions(state, 1330)

    state = sdfg.out_edges(state)[0].dst
    state_check_executions(state, 1330)
def test_3_fold_nested_loop_with_symbolic_bounds():
    """Check symbolic execution counts for three nested loops over N, M, K.

    Each loop guard runs once per iteration plus one failing check per entry
    into its loop. This gives ``N + 1``, ``M*N + N`` and ``M*N*K + M*N``.
    """
    N = dace.symbol('N')
    M = dace.symbol('M')
    K = dace.symbol('K')

    @dace.program(dace.int32)
    def nested_3_symbolic(a):
        for i in range(N):
            for j in range(M):
                for k in range(K):
                    a += 5

    sdfg = nested_3_symbolic.to_sdfg(strict=False)
    propagate_states(sdfg)

    def guard_edges(guard_state, condition):
        """Return the (loop-body, loop-exit) out-edges of ``guard_state``.

        Edges are matched by interstate-edge label: ``(<condition>)`` enters
        the loop and ``(not (<condition>))`` leaves it.
        """
        body_edge = None
        exit_edge = None
        for edge in sdfg.out_edges(guard_state):
            if edge.data.label == '({})'.format(condition):
                body_edge = edge
            elif edge.data.label == '(not ({}))'.format(condition):
                exit_edge = edge
        if exit_edge is None or body_edge is None:
            raise RuntimeError('Couldn\'t identify guard edges')
        return body_edge, exit_edge

    # Check start state.
    state = sdfg.start_state
    state_check_executions(state, 1)

    # 1st level loop, check loop guard, `for i in range(N)`.
    state = sdfg.out_edges(state)[0].dst
    state_check_executions(state, N + 1)

    # Get edges to inside and outside the loop.
    for_branch_edge, end_branch_edge = guard_edges(state, 'i < N')

    # Check loop-end branch.
    state = end_branch_edge.dst
    state_check_executions(state, 1)

    # Check inside the loop.
    state = for_branch_edge.dst
    state_check_executions(state, N)

    # 2nd level nested loop, check loop guard, `for j in range(M)`.
    state = sdfg.out_edges(state)[0].dst
    state_check_executions(state, M * N + N)

    # Get edges to inside and outside the loop.
    for_branch_edge, end_branch_edge = guard_edges(state, 'j < M')

    # Check loop-end branch.
    state = end_branch_edge.dst
    state_check_executions(state, N)

    # Check inside the loop.
    state = for_branch_edge.dst
    state_check_executions(state, M * N)

    # 3rd level nested loop, check loop guard, `for k in range(K)`.
    state = sdfg.out_edges(state)[0].dst
    state_check_executions(state, M * N * K + M * N)

    # Get edges to inside and outside the loop.
    for_branch_edge, end_branch_edge = guard_edges(state, 'k < K')

    # Check loop-end branch.
    state = end_branch_edge.dst
    state_check_executions(state, M * N)

    # Check inside the loop.
    state = for_branch_edge.dst
    state_check_executions(state, M * N * K)

    state = sdfg.out_edges(state)[0].dst
    state_check_executions(state, M * N * K)
def test_while_with_nested_full_merge_branch():
    """Check execution counts for an if/else nested inside a while loop.

    The trip count of the ``while`` loop is data-dependent. Every state
    inside the loop, and the loop guard itself, is therefore expected to be
    dynamic with an unbounded count (0). Only the start and exit states run
    exactly once.
    """
    @dace.program(dace.int32)
    def while_with_nested_full_merge_branch(a):
        while a < 20:
            if a < 10:
                a += 2
            else:
                a += 1

    sdfg = while_with_nested_full_merge_branch.to_sdfg(strict=False)
    propagate_states(sdfg)

    def guard_edges(guard_state, condition,
                    error='Couldn\'t identify guard edges'):
        """Return the (condition-true, condition-false) out-edges of a guard.

        Edges are matched by interstate-edge label: ``(<condition>)`` and
        ``(not (<condition>))``. Raises ``RuntimeError(error)`` if either
        edge is missing.
        """
        true_edge = None
        false_edge = None
        for edge in sdfg.out_edges(guard_state):
            if edge.data.label == '({})'.format(condition):
                true_edge = edge
            elif edge.data.label == '(not ({}))'.format(condition):
                false_edge = edge
        if true_edge is None or false_edge is None:
            raise RuntimeError(error)
        return true_edge, false_edge

    # Check start state.
    state = sdfg.start_state
    state_check_executions(state, 1)

    # While loop, check loop guard, `while a < 20`. Must be dynamic unbounded.
    state = sdfg.out_edges(state)[0].dst
    state_check_executions(state, 0, expected_dynamic=True)

    # Get edges to inside and outside the loop.
    for_branch_edge, end_branch_edge = guard_edges(state, 'a < 20')

    # Check loop-end branch.
    state = end_branch_edge.dst
    state_check_executions(state, 1)

    # Check inside the loop.
    state = for_branch_edge.dst
    state_check_executions(state, 0, expected_dynamic=True)

    # Check the branch guard, `if a < 10`.
    state = sdfg.out_edges(state)[0].dst
    state_check_executions(state, 0, expected_dynamic=True)

    # Get edges to both sides of the conditional split.
    condition_met_edge, condition_broken_edge = guard_edges(
        state, 'a < 10', 'Couldn\'t identify conditional guard edges')

    # Check the 'true' branch.
    state = condition_met_edge.dst
    state_check_executions(state, 0, expected_dynamic=True)
    state = sdfg.out_edges(state)[0].dst
    state_check_executions(state, 0, expected_dynamic=True)

    # Check the 'false' branch.
    state = condition_broken_edge.dst
    state_check_executions(state, 0, expected_dynamic=True)
    state = sdfg.out_edges(state)[0].dst
    state_check_executions(state, 0, expected_dynamic=True)

    # Check where the branches meet again.
    state = sdfg.out_edges(state)[0].dst
    state_check_executions(state, 0, expected_dynamic=True)
def test_while_inside_for():
    """Check execution counts for a while loop nested inside a for loop.

    The for loop is statically bounded:

    * its guard runs 21 times;
    * its body states run 20 times;
    * its exit state runs once.

    The inner while loop's trip count is data-dependent. Its guard and body
    are therefore expected to be dynamic and unbounded (0), while its exit
    state still runs once per outer iteration (20).
    """
    @dace.program(dace.int32)
    def while_inside_for(a):
        for i in range(20):
            j = 0
            while j < 20:
                a += 5

    sdfg = while_inside_for.to_sdfg(strict=False)
    propagate_states(sdfg)

    def successor(current):
        # Follow the first outgoing interstate edge.
        return sdfg.out_edges(current)[0].dst

    def split_on(guard, cond):
        # Map labels to edges; for duplicate labels the last edge wins.
        by_label = {e.data.label: e for e in sdfg.out_edges(guard)}
        inside = by_label.get('(' + cond + ')')
        outside = by_label.get('(not (' + cond + '))')
        if outside is None or inside is None:
            raise RuntimeError(
                'Couldn\'t identify guard edges'
            )
        return inside, outside

    # Entry state executes once.
    current = sdfg.start_state
    state_check_executions(current, 1)

    # Outer guard `i in range(20)`: 20 passing checks plus the failing one.
    current = successor(current)
    state_check_executions(current, 21)
    into_for, out_of_for = split_on(current, 'i < 20')

    # Leaving the for loop happens exactly once.
    state_check_executions(out_of_for.dst, 1)

    # Both states at the top of the for body (including `j = 0`) run 20 times.
    current = into_for.dst
    state_check_executions(current, 20)
    current = successor(current)
    state_check_executions(current, 20)

    # Inner guard `j < 20` is data-dependent, hence dynamic and unbounded.
    current = successor(current)
    state_check_executions(current, 0, expected_dynamic=True)
    into_while, out_of_while = split_on(current, 'j < 20')

    # The while loop is exited once per outer iteration.
    state_check_executions(out_of_while.dst, 20)

    # The while body and its successor are dynamic and unbounded.
    current = into_while.dst
    state_check_executions(current, 0, expected_dynamic=True)
    current = successor(current)
    state_check_executions(current, 0, expected_dynamic=True)