def apply(self, sdfg):
    """
    Inline the symbol assignments found on the matched state's incoming edge.

    Each assignment returned by ``_assignments_to_consider`` is substituted
    into the state via ``state.replace`` and deleted from the edge. When a
    deleted name does not occur in the free symbols of any inter-state edge,
    it is removed from ``sdfg.symbols`` (if present there); if the value it
    was assigned is itself a key of ``sdfg.symbols``, the pair is collected
    and handed to ``symbolic.safe_replace`` together with a callback that
    calls ``sdfg.replace`` for every collected pair.

    :param sdfg: The SDFG being transformed (modified in place).
    """
    end_state = sdfg.nodes()[self.subgraph[
        StateAssignElimination._end_state]]
    in_edge = sdfg.in_edges(end_state)[0]

    # Since inter-state assignments that use an assigned value leads to
    # undefined behavior (e.g., {m: n, n: m}), every assignment can be
    # substituted on its own.
    removed_names = set()
    candidates = _assignments_to_consider(sdfg, in_edge)
    for name, value in candidates.items():
        end_state.replace(name, value)
        removed_names.add(name)

    sdfg_wide_replacements = {}
    for name in removed_names:
        # The assignment has been inlined, so drop it from the edge
        del in_edge.data.assignments[name]

        if any(name in e.data.free_symbols for e in sdfg.edges()):
            # Still referenced by some edge: keep the symbol around
            continue

        # Not referenced by any edge: replace and remove the symbol
        value = candidates[name]
        if value in sdfg.symbols:
            sdfg_wide_replacements[name] = value
        if name in sdfg.symbols:
            sdfg.remove_symbol(name)

    def _replace_all(target, mapping):
        for old, new in mapping.items():
            target.replace(str(old), str(new))

    if sdfg_wide_replacements:
        symbolic.safe_replace(sdfg_wide_replacements,
                              lambda m: _replace_all(sdfg, m))
def apply(self, sdfg):
    """
    Remove the matched start state while keeping its outgoing assignments.

    Every assignment on the state's first outgoing edge is written into the
    ``symbol_mapping`` of ``sdfg.parent_nsdfg_node``; afterwards the state is
    removed from the SDFG.

    :param sdfg: The (nested) SDFG being transformed (modified in place).
    """
    start = sdfg.nodes()[self.subgraph[StartStateElimination.start_state]]

    # The assignments move to the symbol mapping of the nested SDFG node
    nsdfg_node = sdfg.parent_nsdfg_node
    out_edge = sdfg.out_edges(start)[0]
    for symbol, value in out_edge.data.assignments.items():
        nsdfg_node.symbol_mapping[symbol] = value

    sdfg.remove_node(start)
def apply(self, sdfg):
    """
    Remove the matched end state and the symbols it leaves orphaned.

    The names assigned on the state's first incoming edge are recorded before
    the state is removed. Afterwards, every recorded name that is contained
    in ``sdfg.free_symbols`` is passed to ``sdfg.remove_symbol``.

    :param sdfg: The SDFG being transformed (modified in place).
    """
    end_state = sdfg.nodes()[self.subgraph[EndStateElimination._end_state]]

    # Record the assigned names first: deleting the state deletes the
    # incoming edge as well, which is what may orphan these symbols
    incoming = sdfg.in_edges(end_state)[0]
    assigned_names = incoming.data.assignments.keys()

    sdfg.remove_node(end_state)

    # Remove orphan symbols
    for name in assigned_names:
        if name not in sdfg.free_symbols:
            continue
        sdfg.remove_symbol(name)
def apply(self, sdfg):
    """
    Inline all assignments of the matched state's incoming edge.

    Each ``name -> value`` assignment on the state's first incoming edge is
    substituted into the state with ``state.replace``; the edge's
    assignments are then reset to an empty dictionary.

    :param sdfg: The SDFG being transformed (modified in place).
    """
    target = sdfg.nodes()[self.subgraph[StateAssignElimination._end_state]]
    interstate = sdfg.in_edges(target)[0].data

    # Since inter-state assignments that use an assigned value leads to
    # undefined behavior (e.g., {m: n, n: m}), every assignment can be
    # substituted on its own.
    for symbol, value in interstate.assignments.items():
        target.replace(symbol, value)

    # Everything has been inlined; clear the edge
    interstate.assignments = {}
def apply(self, sdfg):
    """
    Promote alias assignments from the matched edge to the preceding edge.

    Candidates come from ``_alias_assignments`` for the edge between the two
    matched states. A candidate ``k -> v`` is rejected if ``k`` occurs in the
    edge's condition, if the incoming edge already assigns ``k`` a different
    value, if ``v`` is an SDFG symbol assigned on the incoming edge, or if
    ``v`` is a scalar that has an access node in the first state. Remaining
    candidates are deleted from the matched edge and added to the first
    incoming edge of the first state.

    :param sdfg: The SDFG being transformed (modified in place).
    """
    fstate = sdfg.nodes()[self.subgraph[SymbolAliasPromotion._first_state]]
    sstate = sdfg.nodes()[self.subgraph[
        SymbolAliasPromotion._second_state]]

    edge = sdfg.edges_between(fstate, sstate)[0].data
    in_edge = sdfg.in_edges(fstate)[0].data

    to_consider = _alias_assignments(sdfg, edge)
    if not to_consider:
        # Nothing to promote; every step below would be a no-op
        return

    # The edge is not modified while candidates are filtered, so the symbols
    # of its condition are collected once instead of once per candidate.
    condsyms = {str(s) for s in edge.condition_sympy().free_symbols}

    to_not_consider = set()
    for k, v in to_consider.items():
        # Remove symbols that are taking part in the edge's condition
        if k in condsyms:
            to_not_consider.add(k)
        # Remove symbols that are set in the in_edge
        # with a different assignment
        if k in in_edge.assignments and in_edge.assignments[k] != v:
            to_not_consider.add(k)
        # Remove symbols whose assignment (RHS) is a symbol
        # and is set in the in_edge.
        if v in sdfg.symbols and v in in_edge.assignments:
            to_not_consider.add(k)
        # Remove symbols whose assignment (RHS) is a scalar
        # and is set in the first state.
        if v in sdfg.arrays and isinstance(sdfg.arrays[v], dt.Scalar):
            if any(
                    isinstance(n, nodes.AccessNode) and n.data == v
                    for n in fstate.nodes()):
                to_not_consider.add(k)

    for k in to_not_consider:
        del to_consider[k]

    # Move the surviving assignments one edge earlier
    for k, v in to_consider.items():
        del edge.assignments[k]
        in_edge.assignments[k] = v
def copy_memory(self, sdfg: sdfg.SDFG, dfg: state.StateSubgraphView,
                state_id: int, src_node: nodes.Node, dst_node: nodes.Node,
                edge: graph.MultiConnectorEdge,
                function_stream: prettycode.CodeIOStream,
                callsite_stream: prettycode.CodeIOStream):
    """
    Generate input/output memory copies from the array references to local
    variables (i.e. for the tasklet code).

    Exactly one accessor declaration is written to ``callsite_stream``.
    Two edge shapes are handled: AccessNode -> Tasklet (pointer, vector or
    plain connector) and MapEntry -> Tasklet. For the latter, the entry
    ``self.n_unrolled[<unique name>]`` is additionally set from the first
    map range.

    :param sdfg: SDFG that owns the data descriptors and constants.
    :param dfg: Not used by this method.
    :param state_id: Index of the state, used to build the unique name.
    :param src_node: Not used by this method (``edge.src`` is read instead).
    :param dst_node: Destination node whose ``in_connectors`` are read.
    :param edge: The edge for which the accessor is generated.
    :param function_stream: Not used by this method.
    :param callsite_stream: Stream the accessor line is written to.
    :raises RuntimeError: If the edge endpoints, or the data descriptor of
                          a plain (non-pointer, non-vector) connector, are
                          of an unsupported type.
    """
    if isinstance(edge.src, nodes.AccessNode) and isinstance(
            edge.dst, nodes.Tasklet):
        # handle AccessNode->Tasklet
        connector = dst_node.in_connectors[edge.dst_conn]
        if isinstance(connector, dtypes.pointer):
            # pointer accessor
            line: str = "{} {} = &{}[0];".format(connector.ctype,
                                                 edge.dst_conn,
                                                 edge.src.data)
        elif isinstance(connector, dtypes.vector):
            # vector accessor
            line: str = "{} {} = *({} *)(&{}[0]);".format(
                connector.ctype, edge.dst_conn, connector.ctype,
                edge.src.data)
        else:
            # scalar accessor
            arr = sdfg.arrays[edge.data.data]
            if isinstance(arr, data.Array):
                line: str = "{}* {} = &{}[0];".format(
                    connector.ctype, edge.dst_conn, edge.src.data)
            elif isinstance(arr, data.Scalar):
                line: str = "{} {} = {};".format(connector.ctype,
                                                 edge.dst_conn,
                                                 edge.src.data)
            else:
                # Without this branch `line` would stay unbound and the
                # write below would fail with an UnboundLocalError.
                raise RuntimeError(
                    "Not handling copy_memory case for data of type {}.".
                    format(type(arr)))
    elif isinstance(edge.src, nodes.MapEntry) and isinstance(
            edge.dst, nodes.Tasklet):
        rtl_name = self.unique_name(edge.dst, sdfg.nodes()[state_id], sdfg)
        self.n_unrolled[rtl_name] = symbolic.evaluate(
            edge.src.map.range[0][1] + 1, sdfg.constants)
        line: str = f'{dst_node.in_connectors[edge.dst_conn]} {edge.dst_conn} = &{edge.data.data}[{edge.src.map.params[0]}*{edge.data.volume}];'
    else:
        raise RuntimeError(
            "Not handling copy_memory case of type {} -> {}.".format(
                type(edge.src), type(edge.dst)))

    # write accessor to file
    callsite_stream.write(line)
def can_be_applied(graph, candidate, expr_index, sdfg, strict=False):
    """
    Decide whether the assignments leading into the matched state can be
    eliminated.

    The match is accepted only if the state has exactly one incoming edge
    and ``_assignments_to_consider`` yields at least one assignment for it.
    A state without outgoing edges is then accepted immediately. Otherwise
    the assigned names must not appear in the free symbols of any other
    inter-state edge, nor in the free symbols of any other state.

    :param graph: Graph in which the candidate was matched.
    :param candidate: Mapping from pattern nodes to node indices in
                      ``graph``.
    :param expr_index: Not used by this check.
    :param sdfg: The SDFG whose edges and states are inspected.
    :param strict: Not used by this check.
    :return: True if the transformation may be applied, False otherwise.
    """
    state = graph.nodes()[candidate[StateAssignElimination._end_state]]
    successors = graph.out_edges(state)
    predecessors = graph.in_edges(state)

    # Only states with a single source are matched ...
    if len(predecessors) != 1:
        return False
    edge = predecessors[0]

    # ... and only if there is at least one assignment to eliminate
    candidates = _assignments_to_consider(sdfg, edge)
    if not candidates:
        return False

    # An end state has no further edges that could use the symbols
    if not successors:
        return True

    names = set(candidates.keys())

    # The symbols must not show up on any other inter-state edge
    if any(e.data.free_symbols & names for e in sdfg.edges()
           if e is not edge):
        return False

    # Nor may they be used in any state other than the matched one
    if any(s.free_symbols & names for s in sdfg.nodes() if s is not state):
        return False

    return True
def apply(self, sdfg):
    """
    Fuse the two matched states into the first one.

    All inter-state edges between the two states are removed; their
    assignments (if any) are merged into every edge entering the first
    state. If either state is empty, the edges of the empty state are
    redirected to the other one and the empty state is removed. Otherwise
    the nodes and edges of the second state are added to the first state,
    and each source access node of the second state that has no incoming
    edges is merged into an access node of the first state carrying the
    same label (taken from a reversed topological order); the surviving
    node's access type is set to ``ReadWrite``. Finally, the outgoing edges
    of the second state are redirected to the first state and the second
    state is removed.

    :param sdfg: The SDFG being transformed (modified in place).
    """
    first_state = sdfg.nodes()[self.subgraph[StateFusion._first_state]]
    second_state = sdfg.nodes()[self.subgraph[StateFusion._second_state]]

    # Remove interstate edge(s)
    edges = sdfg.edges_between(first_state, second_state)
    for edge in edges:
        if edge.data.assignments:
            for _, _, other_data in sdfg.in_edges(first_state):
                other_data.assignments.update(edge.data.assignments)
        sdfg.remove_edge(edge)

    # Special case 1: first state is empty
    if first_state.is_empty():
        sdutil.change_edge_dest(sdfg, first_state, second_state)
        sdfg.remove_node(first_state)
        return

    # Special case 2: second state is empty
    if second_state.is_empty():
        sdutil.change_edge_src(sdfg, second_state, first_state)
        sdutil.change_edge_dest(sdfg, second_state, first_state)
        sdfg.remove_node(second_state)
        return

    # Normal case: both states are not empty

    # Only the source data nodes of the second state are needed for the
    # merge below; the first state's nodes are looked up through `order`.
    second_input = [
        node for node in sdutil.find_source_nodes(second_state)
        if isinstance(node, nodes.AccessNode)
    ]

    # Merge second state to first state
    # First keep a backup of the topological sorted order of the nodes
    # (it must be taken before the second state's nodes are added)
    order = [
        x for x in reversed(list(nx.topological_sort(first_state._nx)))
        if isinstance(x, nodes.AccessNode)
    ]
    for node in second_state.nodes():
        first_state.add_node(node)
    for src, src_conn, dst, dst_conn, data in second_state.edges():
        first_state.add_edge(src, src_conn, dst, dst_conn, data)

    # Merge common (data) nodes
    for node in second_input:
        if first_state.in_degree(node) == 0:
            n = next((x for x in order if x.label == node.label), None)
            # Explicit None test: "no match" must not depend on the
            # truthiness of a node object
            if n is not None:
                sdutil.change_edge_src(first_state, node, n)
                first_state.remove_node(node)
                n.access = dtypes.AccessType.ReadWrite

    # Redirect edges and remove second state
    sdutil.change_edge_src(sdfg, second_state, first_state)
    sdfg.remove_node(second_state)
    if Config.get_bool("debugprint"):
        StateFusion._states_fused += 1
def unparse_tasklet(self, sdfg: sdfg.SDFG, dfg: state.StateSubgraphView,
                    state_id: int, node: nodes.Node,
                    function_stream: prettycode.CodeIOStream,
                    callsite_stream: prettycode.CodeIOStream):
    """
    Emit the code objects and call-site code for one RTL tasklet.

    The connectors of the tasklet are first sorted into ``buses`` and
    ``scalars`` according to the type of the data descriptor on the other
    end of each edge. A SystemVerilog code object is then always appended
    to ``self.code_objects``. Depending on ``self.hardware_target``, either
    four additional code objects are appended (only for the ``'xilinx'``
    vendor), or Verilator C++ simulation code is generated: headers are
    appended to the SDFG's global code and the main template is written to
    ``callsite_stream``.

    :param sdfg: SDFG that contains the tasklet.
    :param dfg: Not used by this method.
    :param state_id: Index of the state that contains the tasklet.
    :param node: The tasklet node to generate code for.
    :param function_stream: Not used by this method.
    :param callsite_stream: Stream that receives the simulation code.
    :raises NotImplementedError: For array connectors when
                                 ``self.hardware_target`` is set, and for
                                 hardware targets whose vendor is not
                                 ``'xilinx'``.
    """
    # extract data
    state = sdfg.nodes()[state_id]
    tasklet = node

    # construct variables paths
    unique_name: str = "{}_{}_{}_{}".format(tasklet.name, sdfg.sdfg_id,
                                            sdfg.node_id(state),
                                            state.node_id(tasklet))

    # Collect all of the input and output connectors into buses and scalars
    # buses:   connector name -> (is_output, total_size, vec_len)
    # scalars: connector name -> (is_output, connector byte size times 8)
    buses = {}
    scalars = {}
    for edge in state.in_edges(tasklet):
        arr = sdfg.arrays[edge.src.data]

        # catch symbolic (compile time variables)
        check_issymbolic([
            tasklet.in_connectors[edge.dst_conn].veclen,
            tasklet.in_connectors[edge.dst_conn].bytes
        ], sdfg)

        # extract parameters
        vec_len = int(
            symbolic.evaluate(tasklet.in_connectors[edge.dst_conn].veclen,
                              sdfg.constants))
        total_size = int(
            symbolic.evaluate(tasklet.in_connectors[edge.dst_conn].bytes,
                              sdfg.constants))
        if isinstance(arr, data.Array):
            if self.hardware_target:
                raise NotImplementedError(
                    'Array input for hardware* not implemented')
            else:
                buses[edge.dst_conn] = (False, total_size, vec_len)
        elif isinstance(arr, data.Stream):
            buses[edge.dst_conn] = (False, total_size, vec_len)
        elif isinstance(arr, data.Scalar):
            scalars[edge.dst_conn] = (False, total_size * 8)
    for edge in state.out_edges(tasklet):
        arr = sdfg.arrays[edge.dst.data]

        # catch symbolic (compile time variables)
        check_issymbolic([
            tasklet.out_connectors[edge.src_conn].veclen,
            tasklet.out_connectors[edge.src_conn].bytes
        ], sdfg)

        # extract parameters
        vec_len = int(
            symbolic.evaluate(tasklet.out_connectors[edge.src_conn].veclen,
                              sdfg.constants))
        total_size = int(
            symbolic.evaluate(tasklet.out_connectors[edge.src_conn].bytes,
                              sdfg.constants))
        if isinstance(arr, data.Array):
            if self.hardware_target:
                # NOTE(review): this is the output loop, yet the message
                # says "input" -- confirm whether that is intended.
                raise NotImplementedError(
                    'Array input for hardware* not implemented')
            else:
                buses[edge.src_conn] = (True, total_size, vec_len)
        elif isinstance(arr, data.Stream):
            buses[edge.src_conn] = (True, total_size, vec_len)
        elif isinstance(arr, data.Scalar):
            # Scalar outputs are only reported; nothing is recorded for
            # them and code generation continues.
            print('Scalar output not implemented')

    # generate system verilog module components
    parameter_string: str = self.generate_rtl_parameters(sdfg.constants)
    inputs, outputs = self.generate_rtl_inputs_outputs(buses, scalars)

    # create rtl code object (that is later written to file)
    self.code_objects.append(
        codeobject.CodeObject(
            name="{}".format(unique_name),
            code=RTLCodeGen.RTL_HEADER.format(name=unique_name,
                                              parameters=parameter_string,
                                              inputs="\n".join(inputs),
                                              outputs="\n".join(outputs)) +
            tasklet.code.code + RTLCodeGen.RTL_FOOTER,
            language="sv",
            target=RTLCodeGen,
            title="rtl",
            target_type="{}".format(unique_name),
            additional_compiler_kwargs="",
            linkable=True,
            environments=None))

    if self.hardware_target:
        if self.vendor == 'xilinx':
            # Configuration shared by the four rtllib_* generators below
            rtllib_config = {
                "name": unique_name,
                "buses": {
                    name: ('m_axis' if is_output else 's_axis', vec_len)
                    for name, (is_output, _, vec_len) in buses.items()
                },
                "params": {
                    "scalars": {
                        name: total_size
                        for name, (_, total_size) in scalars.items()
                    },
                    "memory": {}
                },
                "ip_cores": tasklet.ip_cores if isinstance(
                    tasklet, nodes.RTLTasklet) else {},
            }

            self.code_objects.append(
                codeobject.CodeObject(name=f"{unique_name}_control",
                                      code=rtllib_control(rtllib_config),
                                      language="v",
                                      target=RTLCodeGen,
                                      title="rtl",
                                      target_type="{}".format(unique_name),
                                      additional_compiler_kwargs="",
                                      linkable=True,
                                      environments=None))

            self.code_objects.append(
                codeobject.CodeObject(name=f"{unique_name}_top",
                                      code=rtllib_top(rtllib_config),
                                      language="v",
                                      target=RTLCodeGen,
                                      title="rtl",
                                      target_type="{}".format(unique_name),
                                      additional_compiler_kwargs="",
                                      linkable=True,
                                      environments=None))

            self.code_objects.append(
                codeobject.CodeObject(name=f"{unique_name}_package",
                                      code=rtllib_package(rtllib_config),
                                      language="tcl",
                                      target=RTLCodeGen,
                                      title="rtl",
                                      target_type="scripts",
                                      additional_compiler_kwargs="",
                                      linkable=True,
                                      environments=None))

            self.code_objects.append(
                codeobject.CodeObject(name=f"{unique_name}_synth",
                                      code=rtllib_synth(rtllib_config),
                                      language="tcl",
                                      target=RTLCodeGen,
                                      title="rtl",
                                      target_type="scripts",
                                      additional_compiler_kwargs="",
                                      linkable=True,
                                      environments=None))
        else:  # self.vendor != "xilinx"
            raise NotImplementedError(
                'Only RTL codegen for Xilinx is implemented')
    else:  # not hardware_target
        # generate verilator simulation cpp code components
        # (`inputs`/`outputs` are rebound here; the RTL versions above have
        # already been consumed by the code object)
        inputs, outputs = self.generate_cpp_inputs_outputs(tasklet)
        valid_zeros, ready_zeros = self.generate_cpp_zero_inits(tasklet)
        vector_init = self.generate_cpp_vector_init(tasklet)
        num_elements = self.generate_cpp_num_elements(tasklet)
        internal_state_str, internal_state_var = self.generate_cpp_internal_state(
            tasklet)
        read_input_hs = self.generate_input_hs(tasklet)
        feed_elements = self.generate_feeding(tasklet, inputs)
        in_ptrs, out_ptrs = self.generate_ptrs(tasklet)
        export_elements = self.generate_exporting(tasklet, outputs)
        write_output_hs = self.generate_write_output_hs(tasklet)
        hs_flags = self.generate_hs_flags(tasklet)
        input_hs_toggle = self.generate_input_hs_toggle(tasklet)
        output_hs_toggle = self.generate_output_hs_toggle(tasklet)
        running_condition = self.generate_running_condition(tasklet)

        # add header code to stream
        # (the general header is appended only on the first call; the flag
        # on `self` remembers that it has been added)
        if not self.cpp_general_header_added:
            sdfg.append_global_code(
                cpp_code=RTLCodeGen.CPP_GENERAL_HEADER_TEMPLATE.format(
                    debug_include="// generic includes\n#include <iostream>"
                    if self.verilator_debug else ""))
            self.cpp_general_header_added = True
        sdfg.append_global_code(
            cpp_code=RTLCodeGen.CPP_MODEL_HEADER_TEMPLATE.format(
                name=unique_name))

        # add main cpp code to stream
        # NOTE(review): the `{name}` inside the debug_sim_start/debug_sim_end
        # values is part of an argument, so this .format call does not
        # expand it -- confirm whether the output is formatted again later.
        # NOTE(review): as shown here the debug_internal_state literal has no
        # line breaks, so the leading `//` would comment out the rest of it
        # in the emitted C++ -- confirm against the intended layout.
        callsite_stream.write(contents=RTLCodeGen.CPP_MAIN_TEMPLATE.format(
            name=unique_name,
            inputs=inputs,
            outputs=outputs,
            num_elements=str.join('\n', num_elements),
            vector_init=vector_init,
            valid_zeros=str.join('\n', valid_zeros),
            ready_zeros=str.join('\n', ready_zeros),
            read_input_hs=str.join('\n', read_input_hs),
            feed_elements=str.join('\n', feed_elements),
            in_ptrs=str.join('\n', in_ptrs),
            out_ptrs=str.join('\n', out_ptrs),
            export_elements=str.join('\n', export_elements),
            write_output_hs=str.join('\n', write_output_hs),
            hs_flags=str.join('\n', hs_flags),
            input_hs_toggle=str.join('\n', input_hs_toggle),
            output_hs_toggle=str.join('\n', output_hs_toggle),
            running_condition=str.join(' && ', running_condition),
            internal_state_str=internal_state_str,
            internal_state_var=internal_state_var,
            debug_sim_start="std::cout << \"SIM {name} START\" << std::endl;"
            if self.verilator_debug else "",
            debug_internal_state=""" // report internal state VL_PRINTF("[t=%lu] ap_aclk=%u ap_areset=%u valid_i=%u ready_i=%u valid_o=%u ready_o=%u \\n", main_time, model->ap_aclk, model->ap_areset, model->valid_i, model->ready_i, model->valid_o, model->ready_o); VL_PRINTF("{internal_state_str}\\n", {internal_state_var}); std::cout << std::flush; """.format(internal_state_str=internal_state_str,
                       internal_state_var=internal_state_var)
            if self.verilator_debug else "",
            debug_sim_end="std::cout << \"SIM {name} END\" << std::endl;"
            if self.verilator_debug else ""),
                              sdfg=sdfg,
                              state_id=state_id,
                              node_id=node)
def apply(self, sdfg):
    """
    Remove the matched end state from the SDFG.

    :param sdfg: The SDFG being transformed (modified in place).
    """
    end_state_id = self.subgraph[EndStateElimination._end_state]
    end_state = sdfg.nodes()[end_state_id]
    sdfg.remove_node(end_state)
def unparse_tasklet(self, sdfg: sdfg.SDFG, dfg: state.StateSubgraphView,
                    state_id: int, node: nodes.Node,
                    function_stream: prettycode.CodeIOStream,
                    callsite_stream: prettycode.CodeIOStream):
    """
    Emit the RTL code object and the Verilator call-site code for a tasklet.

    A SystemVerilog code object (header + tasklet code + footer) is appended
    to ``self.code_objects``. The C++ simulation headers are then appended
    to the SDFG's global code and the filled-in main template is written to
    ``callsite_stream``.

    :param sdfg: SDFG that contains the tasklet.
    :param dfg: Not used by this method.
    :param state_id: Index of the state that contains the tasklet.
    :param node: The tasklet node to generate code for.
    :param function_stream: Not used by this method.
    :param callsite_stream: Stream that receives the simulation code.
    """
    # extract data
    state = sdfg.nodes()[state_id]
    tasklet = node

    # construct variables paths
    unique_name: str = "top_{}_{}_{}".format(sdfg.sdfg_id,
                                             sdfg.node_id(state),
                                             state.node_id(tasklet))

    # generate system verilog module components
    parameter_string: str = self.generate_rtl_parameters(sdfg.constants)
    inputs, outputs = self.generate_rtl_inputs_outputs(sdfg, tasklet)

    # create rtl code object (that is later written to file)
    self.code_objects.append(
        codeobject.CodeObject(
            name="{}".format(unique_name),
            code=RTLCodeGen.RTL_HEADER.format(name=unique_name,
                                              parameters=parameter_string,
                                              inputs="\n".join(inputs),
                                              outputs="\n".join(outputs)) +
            tasklet.code.code + RTLCodeGen.RTL_FOOTER,
            language="sv",
            target=RTLCodeGen,
            title="rtl",
            target_type="",
            additional_compiler_kwargs="",
            linkable=True,
            environments=None))

    # generate verilator simulation cpp code components
    # (`inputs`/`outputs` are rebound here; the RTL versions above have
    # already been consumed by the code object)
    inputs, outputs = self.generate_cpp_inputs_outputs(tasklet)
    vector_init = self.generate_cpp_vector_init(tasklet)
    num_elements = self.generate_cpp_num_elements()
    internal_state_str, internal_state_var = self.generate_cpp_internal_state(
        tasklet)

    # add header code to stream
    # (the general header is appended only on the first call; the flag on
    # `self` remembers that it has been added)
    if not self.cpp_general_header_added:
        sdfg.append_global_code(
            cpp_code=RTLCodeGen.CPP_GENERAL_HEADER_TEMPLATE.format(
                debug_include="// generic includes\n#include <iostream>"
                if self.verilator_debug else ""))
        self.cpp_general_header_added = True
    sdfg.append_global_code(
        cpp_code=RTLCodeGen.CPP_MODEL_HEADER_TEMPLATE.format(
            name=unique_name))

    # add main cpp code to stream
    # Every debug_* argument is an empty string unless self.verilator_debug
    # is set.
    # NOTE(review): the `{name}` inside the debug_sim_start/debug_sim_end
    # values is part of an argument, so this .format call does not expand
    # it -- confirm whether the output is formatted again later.
    # NOTE(review): as shown here the debug_internal_state literal has no
    # line breaks, so the leading `//` would comment out the rest of it in
    # the emitted C++ -- confirm against the intended layout.
    callsite_stream.write(contents=RTLCodeGen.CPP_MAIN_TEMPLATE.format(
        name=unique_name,
        inputs=inputs,
        outputs=outputs,
        num_elements=num_elements,
        vector_init=vector_init,
        internal_state_str=internal_state_str,
        internal_state_var=internal_state_var,
        debug_sim_start="std::cout << \"SIM {name} START\" << std::endl;"
        if self.verilator_debug else "",
        debug_feed_element="std::cout << \"feed new element\" << std::endl;"
        if self.verilator_debug else "",
        debug_export_element="std::cout << \"export element\" << std::endl;"
        if self.verilator_debug else "",
        debug_internal_state=""" // report internal state VL_PRINTF("[t=%lu] clk_i=%u rst_i=%u valid_i=%u ready_i=%u valid_o=%u ready_o=%u \\n", main_time, model->clk_i, model->rst_i, model->valid_i, model->ready_i, model->valid_o, model->ready_o); VL_PRINTF("{internal_state_str}\\n", {internal_state_var}); std::cout << std::flush; """.format(internal_state_str=internal_state_str,
                   internal_state_var=internal_state_var)
        if self.verilator_debug else "",
        debug_read_input_hs=
        "std::cout << \"remove read_input_hs flag\" << std::endl;"
        if self.verilator_debug else "",
        debug_output_hs=
        "std::cout << \"remove write_output_hs flag\" << std::endl;"
        if self.verilator_debug else "",
        debug_sim_end="std::cout << \"SIM {name} END\" << std::endl;"
        if self.verilator_debug else ""),
                          sdfg=sdfg,
                          state_id=state_id,
                          node_id=node)
def apply(self, sdfg):
    """
    Fuse the two matched states into the first one.

    The inter-state edges between both states are removed, after their
    assignments (if any) have been merged into every edge entering the
    first state. An empty first or second state is simply bypassed and
    removed. Otherwise the contents of the second state are added to the
    first state and access nodes with equal labels are merged: source nodes
    of the second state are folded into matching source nodes of the first
    state, and sink nodes of the first state are folded into matching
    remaining source nodes of the second state.

    :param sdfg: The SDFG being transformed (modified in place).
    """

    def _first_with_label(candidates, label):
        # First candidate carrying `label`, or None if there is none
        return next((c for c in candidates if c.label == label), None)

    def _access_nodes(node_iter):
        return [n for n in node_iter if isinstance(n, nodes.AccessNode)]

    first_state = sdfg.nodes()[self.subgraph[StateFusion._first_state]]
    second_state = sdfg.nodes()[self.subgraph[StateFusion._second_state]]

    # Drop the connecting inter-state edge(s), pushing any assignments onto
    # the edges that lead into the first state
    for connecting in sdfg.edges_between(first_state, second_state):
        if connecting.data.assignments:
            for _, _, incoming in sdfg.in_edges(first_state):
                incoming.assignments.update(connecting.data.assignments)
        sdfg.remove_edge(connecting)

    # Trivial case: nothing in the first state
    if first_state.is_empty():
        nxutil.change_edge_dest(sdfg, first_state, second_state)
        sdfg.remove_node(first_state)
        return

    # Trivial case: nothing in the second state
    if second_state.is_empty():
        nxutil.change_edge_src(sdfg, second_state, first_state)
        nxutil.change_edge_dest(sdfg, second_state, first_state)
        sdfg.remove_node(second_state)
        return

    # General case: collect the boundary data nodes of both states
    first_sources = _access_nodes(nxutil.find_source_nodes(first_state))
    first_sinks = _access_nodes(nxutil.find_sink_nodes(first_state))
    second_sources = _access_nodes(nxutil.find_source_nodes(second_state))

    # Sources of the first state whose label is not also a sink label
    first_sources = [
        n for n in first_sources
        if not any(s.label == n.label for s in first_sinks)
    ]

    # Copy the contents of the second state into the first state
    for moved in second_state.nodes():
        first_state.add_node(moved)
    for src, src_conn, dst, dst_conn, data in second_state.edges():
        first_state.add_edge(src, src_conn, dst, dst_conn, data)

    # Inputs common to both states: keep the first state's node
    for kept in first_sources:
        duplicate = _first_with_label(second_sources, kept.label)
        if duplicate is None:
            continue
        nxutil.change_edge_src(first_state, duplicate, kept)
        first_state.remove_node(duplicate)
        second_sources.remove(duplicate)

    # Outputs of the first state read by the second: keep the second's node
    for produced in first_sinks:
        consumer = _first_with_label(second_sources, produced.label)
        if consumer is None:
            continue
        nxutil.change_edge_dest(first_state, produced, consumer)
        first_state.remove_node(produced)
        second_sources.remove(consumer)

    # Let the first state take over the outgoing edges, then drop the second
    nxutil.change_edge_src(sdfg, second_state, first_state)
    sdfg.remove_node(second_state)
    if Config.get_bool("debugprint"):
        StateFusion._states_fused += 1
def apply(self, sdfg):
    """ Replace the matched reduce node with maps around a copy tasklet.

        The inner map ranges over the reduction axes, while the outer map
        ranges over the rest of the input dimensions; the outer map is only
        created if there are such dimensions. The inner map contains a
        trivial tasklet (``out_1 = in_1``), while the outgoing memlets
        carry the WCR of the reduce node.

        :param sdfg: The SDFG being transformed (modified in place).
        :raises NotImplementedError: If the reduce node has more than one
                                     distinct source or destination node.
    """
    graph = sdfg.nodes()[self.state_id]
    red_node = graph.nodes()[self.subgraph[ReduceExpansion._reduce]]

    # Collect the distinct source nodes, with the first memlet seen for each
    inputs = []
    in_memlets = []
    for src, _, _, _, memlet in graph.in_edges(red_node):
        if src not in inputs:
            inputs.append(src)
            in_memlets.append(memlet)
    if len(inputs) > 1:
        raise NotImplementedError

    # Same for the destination nodes
    outputs = []
    out_memlets = []
    for _, _, dst, _, memlet in graph.out_edges(red_node):
        if dst not in outputs:
            outputs.append(dst)
            out_memlets.append(memlet)
    if len(outputs) > 1:
        raise NotImplementedError

    # No axes given: reduce over every dimension of the input subset
    axes = red_node.axes
    if axes is None:
        axes = tuple(i for i in range(in_memlets[0].subset.dims()))

    # Split the input dimensions: reduced ones go to the inner map, the
    # rest to the outer map. Map parameters are named "__dim_<index>".
    outer_map_range = {}
    inner_map_range = {}
    for idx, r in enumerate(in_memlets[0].subset):
        if idx in axes:
            inner_map_range.update({
                "__dim_{}".format(str(idx)):
                subsets.Range.dim_to_string(r)
            })
        else:
            outer_map_range.update({
                "__dim_{}".format(str(idx)):
                subsets.Range.dim_to_string(r)
            })

    if len(outer_map_range) > 0:
        outer_map_entry, outer_map_exit = graph.add_map(
            'reduce_outer', outer_map_range, schedule=red_node.schedule)

    # The reduce node's schedule goes to the outermost map that exists; an
    # inner map nested inside an outer map gets the default schedule.
    inner_map_entry, inner_map_exit = graph.add_map(
        'reduce_inner',
        inner_map_range,
        schedule=(dtypes.ScheduleType.Default
                  if len(outer_map_range) > 0 else red_node.schedule))

    tasklet = graph.add_tasklet(name='red_tasklet',
                                inputs={'in_1'},
                                outputs={'out_1'},
                                code='out_1 = in_1')

    inner_map_entry.in_connectors = {'IN_1'}
    inner_map_entry.out_connectors = {'OUT_1'}

    # Input side: source node -> outermost map entry
    outer_in_memlet = dcpy(in_memlets[0])

    if len(outer_map_range) > 0:
        outer_map_entry.in_connectors = {'IN_1'}
        outer_map_entry.out_connectors = {'OUT_1'}
        graph.add_edge(inputs[0], None, outer_map_entry, 'IN_1',
                       outer_in_memlet)
    else:
        graph.add_edge(inputs[0], None, inner_map_entry, 'IN_1',
                       outer_in_memlet)

    # Memlet between the two map entries: reduced dimensions keep their
    # range, the others are pinned to the outer map parameter
    med_in_memlet = dcpy(in_memlets[0])
    med_in_range = []
    for idx, r in enumerate(med_in_memlet.subset):
        if idx in axes:
            med_in_range.append(r)
        else:
            med_in_range.append(("__dim_{}".format(str(idx)),
                                 "__dim_{}".format(str(idx)), 1))
    med_in_memlet.subset = subsets.Range(med_in_range)
    med_in_memlet.num_accesses = med_in_memlet.subset.num_elements()

    if len(outer_map_range) > 0:
        graph.add_edge(outer_map_entry, 'OUT_1', inner_map_entry, 'IN_1',
                       med_in_memlet)

    # Memlet into the tasklet: one index per input dimension
    inner_in_memlet = dcpy(med_in_memlet)
    inner_in_idx = []
    for idx in range(len(inner_in_memlet.subset)):
        inner_in_idx.append("__dim_{}".format(str(idx)))
    inner_in_memlet.subset = subsets.Indices(inner_in_idx)
    inner_in_memlet.num_accesses = inner_in_memlet.subset.num_elements()
    graph.add_edge(inner_map_entry, 'OUT_1', tasklet, 'in_1',
                   inner_in_memlet)

    inner_map_exit.in_connectors = {'IN_1'}
    inner_map_exit.out_connectors = {'OUT_1'}

    # Memlet out of the tasklet: only the non-reduced indices remain; index
    # 0 is used when every dimension is reduced
    inner_out_memlet = dcpy(out_memlets[0])
    inner_out_idx = []
    for idx, r in enumerate(inner_in_memlet.subset):
        if idx not in axes:
            inner_out_idx.append(r)
    if len(inner_out_idx) == 0:
        inner_out_idx = [0]
    inner_out_memlet.subset = subsets.Indices(inner_out_idx)
    inner_out_memlet.wcr = red_node.wcr
    inner_out_memlet.num_accesses = inner_out_memlet.subset.num_elements()
    graph.add_edge(tasklet, 'out_1', inner_map_exit, 'IN_1',
                   inner_out_memlet)

    # Memlet into the destination node
    # NOTE(review): the output subset is filtered by positions listed in
    # `axes`, which index the input dimensions -- confirm that the output
    # subset is expected to have the same dimensionality here.
    outer_out_memlet = dcpy(out_memlets[0])
    outer_out_range = []
    for idx, r in enumerate(outer_out_memlet.subset):
        if idx not in axes:
            outer_out_range.append(r)
    if len(outer_out_range) == 0:
        outer_out_range = [(0, 0, 1)]
    outer_out_memlet.subset = subsets.Range(outer_out_range)
    outer_out_memlet.wcr = red_node.wcr

    if len(outer_map_range) > 0:
        outer_map_exit.in_connectors = {'IN_1'}
        outer_map_exit.out_connectors = {'OUT_1'}
        med_out_memlet = dcpy(inner_out_memlet)
        med_out_memlet.num_accesses = med_out_memlet.subset.num_elements()
        graph.add_edge(inner_map_exit, 'OUT_1', outer_map_exit, 'IN_1',
                       med_out_memlet)
        graph.add_edge(outer_map_exit, 'OUT_1', outputs[0], None,
                       outer_out_memlet)
    else:
        graph.add_edge(inner_map_exit, 'OUT_1', outputs[0], None,
                       outer_out_memlet)

    # Finally detach and delete the original reduce node
    graph.remove_edge(graph.in_edges(red_node)[0])
    graph.remove_edge(graph.out_edges(red_node)[0])
    graph.remove_node(red_node)

    return