def apply(self, state: SDFGState, sdfg: SDFG):
    """Apply the transformation to the map at ``self.map_entry``.

    If the map has more than one parameter, all dimensions except the last
    are split off with ``extract_map_dims`` and the remaining steps operate
    on the second map entry that call returns. That map is then given the
    ``SVE_Map`` schedule, connector types are inferred and applied on its
    scope subgraph, and vector inference is run with ``apply=True``.

    :param state: The state containing the map.
    :param sdfg: The SDFG containing ``state``.
    """
    entry = self.map_entry
    num_params = len(entry.map.params)

    # Multidimensional map: peel off every dimension but the last one and
    # keep working on the remainder.
    if num_params > 1:
        outer_dims = list(range(num_params - 1))
        _, entry = dace.transformation.helpers.extract_map_dims(sdfg, entry, outer_dims)

    inner_map = entry.map
    scope = state.scope_subgraph(entry)

    # Mark the map as an SVE map
    inner_map.schedule = dace.dtypes.ScheduleType.SVE_Map

    # Connector types: infer first, then write them back
    connector_types = infer_types.infer_connector_types(sdfg, state, scope)
    infer_types.apply_connector_types(connector_types)

    # Vector connectors and AccessNodes: infer and apply in one go
    stride_flag = vector_inference.VectorInferenceFlags.Allow_Stride
    vector_inference.infer_vectors(sdfg, state, entry, self.vec_len, flags=stride_flag, apply=True)
def vectorize(sdfg: SDFG) -> vector_inference.VectorInferenceGraph:
    """Run vector inference on the start state of ``sdfg`` without applying it.

    The map entry is looked up with ``find_map_entry(sdfg)`` and inference is
    called with a vector length argument of ``-1`` and ``apply=False``.

    :param sdfg: The SDFG to analyze.
    :return: Whatever ``vector_inference.infer_vectors`` returns.
    """
    first_state = sdfg.start_state
    entry = find_map_entry(sdfg)
    return vector_inference.infer_vectors(sdfg, first_state, entry, -1, apply=False)
def can_be_applied(self, state: SDFGState, expr_index, sdfg: SDFG, permissive=False) -> bool:
    """Decide whether the transformation may be applied to ``self.map_entry``.

    The checks run in this order, and the first failing one returns ``False``:

    1. The map must not already have the ``SVE_Map`` schedule.
    2. The map scope may only contain Tasklets and AccessNodes.
    3. Every inferred connector type must be in ``sve.util.TYPE_TO_SVE``,
       and all connectors must share a single base-type byte width.
    4. No memlet stride may depend on the last map parameter, and any
       write-conflict resolution must be a supported reduction that does
       not end in a vector connector or vector scalar.
    5. Tasklet inputs carrying data must originate from a Tasklet or a
       non-Stream AccessNode.
    6. A dry run of vector inference (``apply=False``) must not raise
       ``VectorInferenceException``.

    :param state: The state containing the map.
    :param expr_index: Not used by this method.
    :param sdfg: The SDFG containing ``state``.
    :param permissive: Not used by this method.
    :return: ``True`` if all checks pass, ``False`` otherwise.
    """
    map_entry = self.map_entry
    current_map = map_entry.map
    subgraph = state.scope_subgraph(map_entry)
    subgraph_contents = state.scope_subgraph(map_entry, include_entry=False, include_exit=False)

    # Prevent infinite repeats
    if current_map.schedule == dace.dtypes.ScheduleType.SVE_Map:
        return False

    # Infer all connector types for later checks (without modifying the graph)
    inferred = infer_types.infer_connector_types(sdfg, state, subgraph)

    ########################
    # Ensure only Tasklets and AccessNodes are within the map
    for node, _ in subgraph_contents.all_nodes_recursive():
        if not isinstance(node, (nodes.Tasklet, nodes.AccessNode)):
            return False

    ########################
    # Check for unsupported datatypes on the connectors (including on the Map itself)
    bit_widths = set()
    for node, _ in subgraph.all_nodes_recursive():
        for conn in node.in_connectors:
            t = inferred[(node, conn, True)]
            bit_widths.add(util.get_base_type(t).bytes)
            if t.type not in sve.util.TYPE_TO_SVE:
                return False
        for conn in node.out_connectors:
            t = inferred[(node, conn, False)]
            bit_widths.add(util.get_base_type(t).bytes)
            if t.type not in sve.util.TYPE_TO_SVE:
                return False

    # Multiple different bit widths occurring (messes up the predicates)
    if len(bit_widths) > 1:
        return False

    ########################
    # Check for unsupported memlets
    param_name = current_map.params[-1]
    # Symbol form of the innermost parameter; it does not change per edge,
    # so build it once instead of on every loop iteration.
    param_sym = symbolic.symbol(param_name)
    for e, _ in subgraph.all_edges_recursive():
        # Check for unsupported strides
        # The only unsupported strides are the ones containing the innermost
        # loop param because they are not constant during a vector step
        if param_sym in e.data.get_stride(sdfg, map_entry.map).free_symbols:
            return False

        # Check for unsupported WCR
        if e.data.wcr is not None:
            # Unsupported reduction type
            reduction_type = dace.frontend.operations.detect_reduction_type(e.data.wcr)
            if reduction_type not in sve.util.REDUCTION_TYPE_TO_SVE:
                return False

            # Param in memlet during WCR is not supported
            if param_name in e.data.subset.free_symbols and e.data.wcr_nonatomic:
                return False

            # vreduce is not supported.
            # memlet_path() yields edges (see the ``[0].src`` use below), so
            # the destination node is the ``dst`` of the last path edge.
            # Testing the edge object itself against node classes could
            # never match and left this check without effect.
            last_edge = state.memlet_path(e)[-1]
            dst_node = last_edge.dst
            if isinstance(dst_node, nodes.Tasklet):
                # Look the connector up by the last edge's own dst_conn,
                # which is the one that belongs to dst_node.
                if isinstance(dst_node.in_connectors.get(last_edge.dst_conn), dtypes.vector):
                    return False
            elif isinstance(dst_node, nodes.AccessNode):
                desc = dst_node.desc(sdfg)
                if isinstance(desc, data.Scalar) and isinstance(desc.dtype, dtypes.vector):
                    return False

    ########################
    # Check for invalid copies in the subgraph
    for node, _ in subgraph.all_nodes_recursive():
        if not isinstance(node, nodes.Tasklet):
            continue

        for e in state.in_edges(node):
            # Check for valid copies from other tasklets and/or streams
            if e.data.data is not None:
                src_node = state.memlet_path(e)[0].src
                if not isinstance(src_node, (nodes.Tasklet, nodes.AccessNode)):
                    # Make sure we only have Code->Code copies and from arrays
                    return False

                if isinstance(src_node, nodes.AccessNode):
                    src_desc = src_node.desc(sdfg)
                    if isinstance(src_desc, dace.data.Stream):
                        # Stream pops are not implemented
                        return False

    # Run the vector inference algorithm to check if vectorization is feasible
    try:
        vector_inference.infer_vectors(sdfg,
                                       state,
                                       map_entry,
                                       self.vec_len,
                                       flags=vector_inference.VectorInferenceFlags.Allow_Stride,
                                       apply=False)
    except vector_inference.VectorInferenceException:
        return False

    return True