def remove_unused_sinks(top_sdfg: dace.SDFG):
    """ Remove unused transient sink nodes and their generating computation. """
    for sdfg in top_sdfg.all_sdfgs_recursive():
        for state in sdfg.nodes():
            to_delete = set()
            # Nodes whose single outgoing edge carries no data (empty memlet)
            # are treated as sinks alongside the state's true sink nodes.
            empty_memlet_sinks = [
                node for node in state.nodes()
                if state.out_degree(node) == 1
                and state.out_edges(node)[0].data.data is None
            ]
            for candidate in state.sink_nodes() + empty_memlet_sinks:
                if not isinstance(candidate, dace.nodes.AccessNode):
                    continue
                if not sdfg.arrays[candidate.data].transient:
                    continue
                # Only consider transients accessed exactly once in the
                # entire SDFG — otherwise the data is used elsewhere.
                occurrences = sum(
                    1 for s in sdfg.nodes() for n in s.nodes()
                    if isinstance(n, dace.nodes.AccessNode)
                    and n.data == candidate.data)
                if occurrences != 1:
                    continue
                if state.in_degree(candidate) != 1:
                    continue
                producer = state.in_edges(candidate)[0].src
                # Only remove the node (and its producer) if it has one
                # unique predecessor that is not connected to anything else.
                if (state.out_degree(producer) == 1
                        and isinstance(producer, dace.nodes.CodeNode)):
                    # Also remove potentially isolated input nodes
                    for edge in state.in_edges(producer):
                        if len(state.all_edges(edge.src)) == 1:
                            to_delete.add(edge.src)
                    to_delete.add(producer)
                    to_delete.add(candidate)
            state.remove_nodes_from(to_delete)
def remove_scalar_transients(top_sdfg: dace.SDFG):
    """ Clean up tasklet->scalar-transient, replacing them with symbols.

    Scans every SDFG for transient ``Scalar`` arrays that are written exactly
    once by a single simple assignment tasklet, removes the tasklet, and hands
    the scalar names (with their constant values) to ``_remove_transients``.
    Any scalar that does not fit this pattern is skipped with a diagnostic.
    """
    dprint = print  # set to a no-op to silence diagnostics
    removed_transients = 0
    for sdfg in top_sdfg.all_sdfgs_recursive():
        transients_to_remove = {}
        for dname, desc in sdfg.arrays.items():
            skip = False
            if isinstance(desc, dace.data.Scalar) and desc.transient:
                # Find node where transient is instantiated
                init_tasklet: Optional[dace.nodes.Tasklet] = None
                itstate = None
                for state in sdfg.nodes():
                    if skip:
                        break
                    for node in state.nodes():
                        if (isinstance(node, dace.nodes.AccessNode)
                                and node.data == dname):
                            if state.in_degree(node) > 1:
                                dprint('Cannot remove scalar', dname,
                                       '(more than one input)')
                                skip = True
                                break
                            elif state.in_degree(node) == 1:
                                if init_tasklet is not None:
                                    dprint('Cannot remove scalar', dname,
                                           '(initialized multiple times)')
                                    skip = True
                                    break
                                init_tasklet = state.in_edges(node)[0].src
                                itstate = state
                if init_tasklet is None:
                    dprint('Cannot remove scalar', dname, '(uninitialized)')
                    skip = True
                if skip:
                    continue
                # We can remove transient, find value from tasklet
                if len(init_tasklet.code.code) != 1:
                    dprint('Cannot remove scalar', dname, '(complex tasklet)')
                    continue
                if not isinstance(init_tasklet.code.code[0], ast.Assign):
                    dprint('Cannot remove scalar', dname, '(complex tasklet2)')
                    continue
                # Guard the conversion: a non-literal RHS (e.g. an arithmetic
                # expression or a symbol) makes float() raise ValueError.
                # Skip such scalars like every other unsupported case instead
                # of aborting the entire pass.
                try:
                    val = float(unparse(init_tasklet.code.code[0].value))
                except (ValueError, TypeError):
                    dprint('Cannot remove scalar', dname, '(non-literal value)')
                    continue
                dprint('Converting', dname, 'to constant with value', val)
                transients_to_remove[dname] = val
                # Remove initialization tasklet
                itstate.remove_node(init_tasklet)
        _remove_transients(sdfg, transients_to_remove)
        removed_transients += len(transients_to_remove)
    print('Cleaned up %d extra scalar transients' % removed_transients)
def remove_constant_stencils(top_sdfg: dace.SDFG):
    """ Remove input-less stencils that assign a constant to their output.

    Finds ``stencil.Stencil`` nodes with no inputs and exactly one output
    whose code is a single constant assignment, deletes them, and replaces
    the written transient with the constant via ``_remove_transients``.
    Stencils that do not fit this pattern are skipped with a diagnostic.
    """
    dprint = print  # set to a no-op to silence diagnostics
    removed_transients = 0
    for sdfg in top_sdfg.all_sdfgs_recursive():
        transients_to_remove = {}
        for state in sdfg.nodes():
            for node in state.nodes():
                if (isinstance(node, stencil.Stencil)
                        and state.in_degree(node) == 0
                        and state.out_degree(node) == 1):
                    # We can remove transient, find value from tasklet
                    if len(node.code.code) != 1:
                        dprint('Cannot remove scalar stencil', node.name,
                               '(complex code)')
                        continue
                    if not isinstance(node.code.code[0], ast.Assign):
                        dprint('Cannot remove scalar stencil', node.name,
                               '(complex code2)')
                        continue
                    # Ensure no one else is writing to it
                    onode = state.memlet_path(state.out_edges(node)[0])[-1].dst
                    dname = state.out_edges(node)[0].data.data
                    if any(s.in_degree(n) > 0 for s in sdfg.nodes()
                           for n in s.nodes()
                           if n != onode and isinstance(
                               n, dace.nodes.AccessNode) and n.data == dname):
                        continue
                    # Guard the evaluation: a RHS that is not a closed-form
                    # constant raises (NameError/ValueError/...) and would
                    # otherwise abort the entire pass; skip it instead.
                    # NOTE(review): eval() here runs stencil-internal code —
                    # acceptable only because tasklet code is program-authored,
                    # not untrusted input; do not feed external data through it.
                    try:
                        val = float(eval(unparse(node.code.code[0].value)))
                    except (ValueError, TypeError, NameError, SyntaxError):
                        dprint('Cannot remove scalar stencil', node.name,
                               '(non-constant value)')
                        continue
                    dprint('Converting scalar stencil result', dname,
                           'to constant with value', val)
                    transients_to_remove[dname] = val
                    # Remove initialization tasklet
                    state.remove_node(node)
        _remove_transients(sdfg, transients_to_remove, ReplaceSubscript)
        removed_transients += len(transients_to_remove)
    print('Cleaned up %d extra scalar stencils' % removed_transients)
def transfer(sdfg: dace.SDFG, tuner, k: int = 5):
    """Transfer the top-k map-fusion configurations found by *tuner* onto *sdfg*.

    Extracts subgraph patterns from the tuner's best configurations, then for
    each state repeatedly measures candidate fusions on cutout copies
    (via subprocess GPU-event measurement) and applies the best-performing
    fusion to the real state until no pattern improves on the baseline.

    :param sdfg: The SDFG to optimize in place; must have instrumented data.
    :param tuner: Must be an ``OnTheFlyMapFusionTuner`` instance.
    :param k: Number of top tuning configurations to extract patterns from.
    """
    assert isinstance(tuner, OnTheFlyMapFusionTuner)
    dreport = sdfg.get_instrumented_data()
    assert dreport is not None
    tuning_report = tuner.optimize(apply=False)
    best_configs = cutout_tuner.CutoutTuner.top_k_configs(tuning_report, k=k)
    subgraph_patterns = tuner._extract_patterns(best_configs)
    i = 0  # state counter (only incremented; presumably kept for debugging)
    for nsdfg in sdfg.all_sdfgs_recursive():
        for state in nsdfg.states():
            i = i + 1
            # Collect outermost (non-nested) map entries of this state.
            top_maps = []
            for node in state.nodes():
                if isinstance(node, dace.nodes.MapEntry) and xfh.get_parent_map(
                        state, node) is None:
                    top_maps.append(node)
            if len(top_maps) < 2:
                continue
            # Probe that the state is cutout-able; skip states that are not.
            try:
                cutout = cutter.cutout_state(state, *(state.nodes()), make_copy=False)
            except AttributeError:
                continue
            # Keep fusing until no pattern yields an improvement.
            while True:
                base_runtime = None
                best_pattern = None
                best_pattern_runtime = math.inf
                for j, pattern in enumerate(subgraph_patterns):
                    # Re-collect maps: previous iterations may have fused some.
                    maps = []
                    for node in state.nodes():
                        if isinstance(
                                node, dace.nodes.MapEntry
                        ) and xfh.get_parent_map(state, node) is None:
                            maps.append(node)
                    if len(maps) < 2:
                        continue
                    # Group maps by descriptor and count descriptor occurrences.
                    maps_desc = {}
                    state_desc = Counter()
                    for map_entry in maps:
                        map_desc = OnTheFlyMapFusionTuner.map_descriptor(
                            state, map_entry)
                        state_desc.update({map_desc: 1})
                        if not map_desc in maps_desc:
                            maps_desc[map_desc] = []
                        maps_desc[map_desc].append(map_entry)
                    # The pattern must be a sub-multiset of the state's maps.
                    included = True
                    for key in pattern:
                        if not key in state_desc or pattern[
                                key] > state_desc[key]:
                            included = False
                            break
                    if not included:
                        continue
                    # Lazily measure the unfused baseline once per while-pass.
                    if base_runtime is None:
                        baseline = cutter.cutout_state(state, *(state.nodes()),
                                                       make_copy=False)
                        baseline.start_state.instrument = dace.InstrumentationType.GPU_Events
                        # Collect non-transient input data from the report.
                        dreport_ = {}
                        for cstate in baseline.nodes():
                            for dnode in cstate.data_nodes():
                                array = baseline.arrays[dnode.data]
                                if array.transient:
                                    continue
                                try:
                                    data = dreport.get_first_version(
                                        dnode.data)
                                    dreport_[dnode.data] = data
                                except:
                                    continue
                        # NOTE(review): i=192, j=192 look like fixed problem
                        # sizes for the measurement harness — confirm.
                        base_runtime = optim_utils.subprocess_measure(
                            baseline, dreport_, i=192, j=192)
                        best_pattern_runtime = base_runtime
                        if base_runtime == math.inf:
                            break
                    # Construct subgraph greedily
                    subgraph_maps = []
                    for desc in pattern:
                        num = pattern[desc]
                        subgraph_maps.extend(maps_desc[desc][:num])
                    # Apply
                    # Fuse on a deep copy of a cutout; node ids survive the
                    # deepcopy, so they are used to re-resolve the map entries.
                    experiment_sdfg_ = cutter.cutout_state(
                        state, *(state.nodes()), make_copy=False)
                    experiment_state_ = experiment_sdfg_.start_state
                    experiment_maps_ids = list(
                        map(lambda me: experiment_state_.node_id(me),
                            subgraph_maps))
                    experiment_sdfg = copy.deepcopy(experiment_sdfg_)
                    experiment_state = experiment_sdfg.start_state
                    experiment_state.instrument = dace.InstrumentationType.GPU_Events
                    experiment_maps = list(
                        map(lambda m_id: experiment_state.node(m_id),
                            experiment_maps_ids))
                    experiment_subgraph = helpers.subgraph_from_maps(
                        sdfg=experiment_sdfg,
                        graph=experiment_state,
                        map_entries=experiment_maps)
                    map_fusion = sg.SubgraphOTFFusion()
                    map_fusion.setup_match(
                        experiment_subgraph, experiment_sdfg.sdfg_id,
                        experiment_sdfg.node_id(experiment_state))
                    if map_fusion.can_be_applied(experiment_state,
                                                 experiment_sdfg):
                        # Best-effort: a failing experimental fusion only
                        # disqualifies this pattern.
                        try:
                            experiment_fuse_counter = map_fusion.apply(
                                experiment_state, experiment_sdfg)
                        except:
                            continue
                        if experiment_fuse_counter == 0:
                            continue
                        # Measure the fused experiment with the same inputs.
                        dreport_ = {}
                        for cstate in experiment_sdfg.nodes():
                            for dnode in cstate.data_nodes():
                                array = experiment_sdfg.arrays[dnode.data]
                                if array.transient:
                                    continue
                                try:
                                    data = dreport.get_first_version(
                                        dnode.data)
                                    dreport_[dnode.data] = data
                                except:
                                    continue
                        fused_runtime = optim_utils.subprocess_measure(
                            experiment_sdfg, dreport_, i=192, j=192)
                        if fused_runtime >= best_pattern_runtime:
                            continue
                        best_pattern = subgraph_maps
                        best_pattern_runtime = fused_runtime
                # Apply the winning pattern to the REAL state, then retry;
                # otherwise this state is done.
                if best_pattern is not None:
                    subgraph = helpers.subgraph_from_maps(
                        sdfg=nsdfg, graph=state, map_entries=best_pattern)
                    map_fusion = sg.SubgraphOTFFusion()
                    map_fusion.setup_match(subgraph, nsdfg.sdfg_id,
                                           nsdfg.node_id(state))
                    actual_fuse_counter = map_fusion.apply(state, nsdfg)
                    best_pattern = None
                    base_runtime = None
                    best_pattern_runtime = math.inf
                else:
                    break