Example #1
def count_arithmetic_ops(sdfg: dace.SDFG,
                         symbols: Dict[str, Any] = None) -> int:
    result = 0
    symbols = symbols or {}
    for state in sdfg.nodes():
        result += count_arithmetic_ops_state(state, symbols)
    return result
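A minimal usage sketch (the saxpy kernel is illustrative; count_arithmetic_ops_state is assumed to be defined alongside this function):

import dace

N = dace.symbol('N')

@dace.program
def saxpy(a: dace.float64, x: dace.float64[N], y: dace.float64[N]):
    y[:] = a * x + y

# With the symbol bound, the per-state counts sum to a concrete integer
print(count_arithmetic_ops(saxpy.to_sdfg(), {'N': 1024}))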
Example #2
    def apply_pass(self, sdfg: SDFG, _) -> Dict[SDFGState, Set[SDFGState]]:
        """
        :return: A dictionary mapping each state to its other reachable states.
        """
        reachable: Dict[SDFGState, Set[SDFGState]] = {}
        tc: nx.DiGraph = nx.transitive_closure(sdfg.nx)
        for state in sdfg.nodes():
            reachable[state] = set(tc.successors(state))
        return reachable
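A hedged usage sketch, assuming this method belongs to a pass object such as DaCe's StateReachability (the three-state SDFG below is illustrative):

import dace
from dace.transformation.passes.analysis import StateReachability

sdfg = dace.SDFG('chain')
s0 = sdfg.add_state('s0')
s1 = sdfg.add_state('s1')
s2 = sdfg.add_state('s2')
sdfg.add_edge(s0, s1, dace.InterstateEdge())
sdfg.add_edge(s1, s2, dace.InterstateEdge())

reachable = StateReachability().apply_pass(sdfg, {})
assert reachable[s0] == {s1, s2}  # transitive successors of s0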
Example #3
def optimize_for_gpu(sdfg: dace.SDFG, m: int, n: int, k: int):
    """ Optimize the matrix multiplication example for GPUs. """
    # Ensure integers are 32-bit by default
    dace.Config.set('compiler', 'default_data_types', value='C')

    # Fuse the map and reduce nodes
    sdfg.apply_transformations(MapReduceFusion)

    # Apply GPU transformation
    sdfg.apply_gpu_transformations()

    # Find multiplication map
    entry = find_map_by_param(sdfg, 'k')

    # Create a tiling strategy
    divides_evenly = (m % 64 == 0) and (n % 64 == 0) and (k % 8 == 0)
    xfutil.tile(sdfg, entry, divides_evenly, True, i=64, j=64, k=8)
    xfutil.tile(sdfg, entry, divides_evenly, True, i=8, j=4)

    # Create kernel schedule by collapsing and reordering maps
    gtile_i = find_map_by_param(sdfg, 'tile_i')
    gtile_j = find_map_by_param(sdfg, 'tile_j')
    btile_i = find_map_by_param(sdfg, 'tile1_i')
    btile_j = find_map_by_param(sdfg, 'tile1_j')
    MapCollapse.apply_to(sdfg, outer_map_entry=gtile_i, inner_map_entry=gtile_j, permissive=True)
    MapCollapse.apply_to(sdfg, outer_map_entry=btile_i, inner_map_entry=btile_j, permissive=True)
    btile = find_map_by_param(sdfg, 'tile1_i')
    btile.map.schedule = dace.ScheduleType.GPU_ThreadBlock

    # Add local storage (shared memory) for A and B on GPU
    ktile = find_map_by_param(sdfg, 'tile_k')
    smem_a = InLocalStorage.apply_to(sdfg, dict(array='A'), node_a=ktile, node_b=btile)
    smem_b = InLocalStorage.apply_to(sdfg, dict(array='B'), node_a=ktile, node_b=btile)
    sdfg.arrays[smem_a.data].storage = dace.StorageType.GPU_Shared
    sdfg.arrays[smem_b.data].storage = dace.StorageType.GPU_Shared

    # Add local storage (registers) for A and B
    ttile = find_map_by_param(sdfg, 'k')
    warptile, ttile = xfutil.extract_map_dims(sdfg, ttile, [2])
    InLocalStorage.apply_to(sdfg, dict(array='trans_gpu_A'), node_a=warptile, node_b=ttile)
    InLocalStorage.apply_to(sdfg, dict(array='trans_gpu_B'), node_a=warptile, node_b=ttile)

    # Add local storage (registers) for C
    state = next(s for s in sdfg.nodes() if warptile in s.nodes())
    warptile_exit = state.exit_node(warptile)
    btile_exit = state.exit_node(btile)
    AccumulateTransient.apply_to(sdfg, map_exit=warptile_exit, outer_map_exit=btile_exit)
    # Set C tile to zero on allocation
    c_access = next(n for n in state.data_nodes() if n.data == 'trans_gpu_C')
    c_access.setzero = True

    # Unroll microkernel maps
    ttile.map.unroll = True

    # Apply double-buffering on shared memory
    DoubleBuffering.apply_to(sdfg, map_entry=ktile, transient=smem_a)
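An illustrative invocation (matmul is an assumed @dace.program computing C = A @ B; the sizes are placeholders that satisfy the divides_evenly check above):

sdfg = matmul.to_sdfg()
optimize_for_gpu(sdfg, m=1024, n=1024, k=1024)
csdfg = sdfg.compile()  # generates and builds the GPU code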
Example #4
    def _sdfg_freeze_domain_and_origin(self, inner_sdfg: dace.SDFG,
                                       domain: Tuple[int, ...],
                                       origin: Dict[str, Tuple[int, ...]]):
        wrapper_sdfg = dace.SDFG("frozen_" + inner_sdfg.name)
        state = wrapper_sdfg.add_state("frozen_" + inner_sdfg.name + "_state")

        inputs = set()
        outputs = set()
        for inner_state in inner_sdfg.nodes():
            for node in inner_state.nodes():
                if (not isinstance(node, dace.nodes.AccessNode)
                        or inner_sdfg.arrays[node.data].transient):
                    continue
                if node.has_reads(inner_state):
                    inputs.add(node.data)
                if node.has_writes(inner_state):
                    outputs.add(node.data)

        nsdfg = state.add_nested_sdfg(inner_sdfg, None, inputs, outputs)

        self._sdfg_add_arrays_and_edges(wrapper_sdfg,
                                        state,
                                        inner_sdfg,
                                        nsdfg,
                                        inputs,
                                        outputs,
                                        origins=origin)

        # In the special case of an empty domain, remove the entire SDFG contents.
        if any(d == 0 for d in domain):
            states = wrapper_sdfg.states()
            assert len(states) == 1
            for node in states[0].nodes():
                state.remove_node(node)

        # Make sure that symbols are passed through to the inner SDFG
        for symbol in nsdfg.sdfg.free_symbols:
            if symbol not in wrapper_sdfg.symbols:
                wrapper_sdfg.add_symbol(symbol, nsdfg.sdfg.symbols[symbol])

        # Try to inline wrapped SDFG before symbols are specialized to avoid extra views
        inline_sdfgs(wrapper_sdfg)

        self._sdfg_specialize_symbols(wrapper_sdfg, domain)

        for _, _, array in wrapper_sdfg.arrays_recursive():
            if array.transient:
                array.lifetime = dace.dtypes.AllocationLifetime.SDFG

        signature = self.__sdfg_signature__()
        wrapper_sdfg.arg_names = [
            a for a in signature[0] if a not in signature[1]
        ]

        return wrapper_sdfg
Example #5
def _permute_array(array: Array, perm: Tuple[int, int, int], sdfg: dace.SDFG,
                   array_name: str):
    array.shape = [array.shape[i] for i in perm]
    array.strides = [array.strides[i] for i in perm]
    array.offset = [array.offset[i] for i in perm]
    # Modify all edges coming in/out of the array
    for state in sdfg.nodes():
        for e in state.edges():
            if e.data.data == array_name:
                e.data.subset = type(
                    e.data.subset)([e.data.subset[i] for i in perm])
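A short, hedged example: reinterpreting a hypothetical three-dimensional array 'A' from (i, j, k) to (k, i, j) order. Since strides are permuted together with the shape, the layout is reinterpreted in place rather than physically transposed:

_permute_array(sdfg.arrays['A'], (2, 0, 1), sdfg, 'A')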
Example #6
def _specialize_transient_strides(sdfg: dace.SDFG, layout_map):
    repldict = replace_strides(
        [array for array in sdfg.arrays.values() if array.transient],
        layout_map,
    )
    sdfg.replace_dict(repldict)
    for state in sdfg.nodes():
        for node in state.nodes():
            if isinstance(node, dace.nodes.NestedSDFG):
                for k, v in repldict.items():
                    if k in node.symbol_mapping:
                        node.symbol_mapping[k] = v
    for k in repldict.keys():
        if k in sdfg.symbols:
            sdfg.remove_symbol(k)
Example #7
def _remove_transients(sdfg: dace.SDFG,
                       transients_to_remove: Dict[str, float],
                       replacer: ast.NodeTransformer = ASTFindReplace):
    """ Replaces transients with constants, removing associated access
        nodes. """
    # Remove transients
    for dname, val in transients_to_remove.items():
        # Add constant, remove data descriptor
        del sdfg.arrays[dname]
        sdfg.add_constant(dname, val)

        for state in sdfg.nodes():
            for node in state.nodes():
                if (isinstance(node, dace.nodes.AccessNode)
                        and node.data == dname):
                    # For all access node instances, remove
                    # outgoing edge connectors from subsequent nodes,
                    # then remove access nodes
                    for edge in state.out_edges(node):
                        for e in state.memlet_tree(edge):
                            # Do not break scopes if there are no other edges
                            if len(state.edges_between(e.src, e.dst)) == 1:
                                state.add_edge(e.src, None, e.dst, None,
                                               dace.Memlet())
                            state.remove_edge_and_connectors(e)
                            # If tasklet, replace connector name with constant
                            if isinstance(e.dst, dace.nodes.Tasklet):
                                replacer({
                                    e.dst_conn: dname
                                }).visit(e.dst.code.code)
                            # If stencil, handle similarly
                            elif isinstance(e.dst, stencil.Stencil):
                                del e.dst.accesses[e.dst_conn]
                                for i, stmt in enumerate(e.dst.code.code):
                                    e.dst.code.code[i] = replacer({
                                        e.dst_conn:
                                        dname
                                    }).visit(stmt)
                            # If dst is a NestedSDFG, add the dst_connector as
                            # a constant and remove internal nodes
                            elif isinstance(e.dst, dace.nodes.NestedSDFG):
                                nsdfg: dace.SDFG = e.dst.sdfg
                                _remove_transients(nsdfg, {dname: val})

                    # Lastly, remove the node itself
                    state.remove_node(node)
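An illustrative call, assuming the SDFG holds scalar transients 'eps' and 'scale' whose values are known ahead of time; both become SDFG constants and their access nodes (including those inside nested SDFGs) are removed:

_remove_transients(sdfg, {'eps': 1e-6, 'scale': 0.5})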
Example #8
    def can_be_applied(graph: dace.SDFGState,
                       candidate: Dict[Any, int],
                       expr_index: int,
                       sdfg: dace.SDFG,
                       strict=False):
        stencil_a: Stencil = graph.node(candidate[StencilFusion._stencil_a])
        stencil_b: Stencil = graph.node(candidate[StencilFusion._stencil_b])
        array: nodes.AccessNode = graph.node(
            candidate[StencilFusion._tmp_array])

        # Ensure the stencil shapes match
        if len(stencil_a.shape) != len(stencil_b.shape):
            return False
        if any(sa != sb for sa, sb in zip(stencil_a.shape, stencil_b.shape)):
            return False

        # Ensure that the transient is not used anywhere else and can be
        # removed
        if len(graph.all_edges(array)) != 2:
            return False
        if not sdfg.arrays[array.data].transient:
            return False
        if (len([
                n for state in sdfg.nodes() for n in state.nodes()
                if isinstance(n, nodes.AccessNode) and n.data == array.data
        ]) > 1):
            return False

        # Ensure that second stencil only has one input access of the
        # candidate transient to remove
        edge = graph.out_edges(array)[0]
        if len(stencil_b.accesses[edge.dst_conn][1]) > 1:
            return False

        # TODO: Remove check once stencils can be offset
        if any(a != 0 for a in stencil_b.accesses[edge.dst_conn][1][0]):
            return False

        # Code languages must match
        if stencil_a.code.language != stencil_b.code.language:
            return False

        # TODO: Boundary condition matching checks

        return True
Example #9
def add_gpu_location(sdfg: dace.SDFG, mapEntry, gpu):
    graph = sdfg.nodes()[sdfg.sdfg_id]
    mapEntry.location = {'gpu': gpu}
    exit_edges = [
        e for e in graph.out_edges(mapEntry)
        if isinstance(e.dst, nodes.Tasklet)
    ]
    for e in exit_edges:
        tasklet = e.dst
        tasklet.location = {'gpu': gpu}
    entry_edges = [
        e for e in graph.in_edges(mapEntry)
        if isinstance(e.src, nodes.AccessNode)
        and not isinstance(e.src.desc(sdfg), Scalar)
    ]
    for e in entry_edges:
        data_node = e.src
        data_node.desc(sdfg).location = {'gpu': gpu}
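A hedged sketch of pinning a map to one device. Note that this helper looks the map up in the state indexed by sdfg.sdfg_id, so it assumes a particular SDFG layout; the lookup below mirrors that assumption:

from dace.sdfg import nodes

state = sdfg.nodes()[sdfg.sdfg_id]  # the same state the helper inspects
map_entry = next(n for n in state.nodes() if isinstance(n, nodes.MapEntry))
add_gpu_location(sdfg, map_entry, gpu=1)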
Example #10
    def unify_symbols(sdfg: dace.SDFG):
        """ Uses one set of symbols across all nested SDFGs. """
        for state in sdfg.nodes():
            for node in state.nodes():
                if isinstance(node, dace.nodes.NestedSDFG):
                    # First, get nested symbols and replace them if they match
                    # the names of the outer symbols
                    usedsyms: Set[str] = set()
                    for symvalue in node.symbol_mapping.values():
                        usedsyms |= set(
                            map(
                                str,
                                dace.symbolic.pystr_to_symbolic(
                                    symvalue).free_symbols))

                    # Replace clashing names
                    clashing = usedsyms & (node.sdfg.symbols.keys()
                                           | node.sdfg.arrays.keys())
                    for clash in clashing:
                        new_name = find_new_name(node.sdfg, clash)
                        node.sdfg.replace(clash, new_name)
                        if clash in node.symbol_mapping:
                            node.symbol_mapping[
                                new_name] = node.symbol_mapping[clash]
                            del node.symbol_mapping[clash]

                    # Two-step replacement (N -> __dacesym_N -> map[N]) to avoid clashes
                    for symname, symvalue in node.symbol_mapping.items():
                        if str(symname) != str(symvalue):
                            node.sdfg.replace(symname, '__dacesym_' + symname)
                    for symname, symvalue in node.symbol_mapping.items():
                        if str(symname) != str(symvalue):
                            if str(symvalue) in node.sdfg.symbols:
                                del node.sdfg.symbols[str(symvalue)]
                            node.sdfg.replace('__dacesym_' + symname,
                                              str(symvalue))

                    # Replace symbol mapping
                    node.symbol_mapping = {k: k for k in usedsyms}

                    # Recursively descend
                    unify_symbols(node.sdfg)
Example #11
    def get_parents(outermost_node: nodes.Node, node: nodes.Node, sdfg: dace.SDFG, state_id: int) -> List[nodes.Node]:

        parent = None
        # Because dfg is only a subgraph view, it does not contain the entry
        # node for a given entry. This O(n) solution is suboptimal
        for state in sdfg.nodes():
            s_d = state.scope_dict()
            try:
                scope = s_d[node]
            except KeyError:
                continue

            if scope is not None:
                parent = scope
                break
        if parent is None:
            return []
        if parent == outermost_node:
            return [parent]

        return PAPIUtils.get_parents(outermost_node, parent, sdfg, state_id) + [parent]
Example #12
def remove_node_and_computation(sdfg: dace.SDFG, state: dace.SDFGState,
                                node: nd.Node):
    """ Remove a node and the parent nodes that compute this node, if the outputs are not used elsewhere.

        :param sdfg: the sdfg containing the node.
        :param state: the state containing the node.
        :param node: the node to remove
    """
    queue = deque([node])
    while len(queue) > 0:
        current_node = queue.popleft()

        edges = state.in_edges(current_node)
        state.remove_node(current_node)
        for e in edges:
            next_node = e.src
            data_used_in_other_states = isinstance(next_node, nd.AccessNode) and \
                                        any(n.data == next_node.data
                                            for s in sdfg.nodes()
                                            for n in s.nodes() if s is not state)

            if len(state.out_edges(
                    next_node)) == 0 and not data_used_in_other_states:
                queue.append(next_node)
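A minimal sketch: removing a hypothetical transient output 'tmp' together with the chain of producer nodes that exists only to compute it (deque is assumed imported from collections):

anode = next(n for n in state.data_nodes() if n.data == 'tmp')
remove_node_and_computation(sdfg, state, anode)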
Example #13
    def apply_pass(self, sdfg: SDFG,
                   _) -> Dict[SDFGState, Tuple[Set[str], Set[str]]]:
        """
        :return: A dictionary mapping each state to its other reachable states.
        """
        result: Dict[SDFGState, Tuple[Set[str], Set[str]]] = {}
        for state in sdfg.nodes():
            readset, writeset = set(), set()
            for anode in state.data_nodes():
                if state.in_degree(anode) > 0:
                    writeset.add(anode.data)
                if state.out_degree(anode) > 0:
                    readset.add(anode.data)

            result[state] = (readset, writeset)

        # Edges that read from arrays add to both ends' access sets
        anames = sdfg.arrays.keys()
        for e in sdfg.edges():
            fsyms = e.data.free_symbols & anames
            if fsyms:
                result[e.src][0].update(fsyms)
                result[e.dst][0].update(fsyms)
        return result
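Hedged usage, assuming this is the apply_pass of DaCe's AccessSets analysis pass:

from dace.transformation.passes.analysis import AccessSets

read_write = AccessSets().apply_pass(sdfg, {})
for state, (reads, writes) in read_write.items():
    print(state.label, 'reads', sorted(reads), 'writes', sorted(writes))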
Example #14
    def apply(self, sdfg: dace.SDFG):
        graph = sdfg.nodes()[self.state_id]
        t1 = graph.nodes()[self.subgraph[self.t1]]
        t2 = graph.nodes()[self.subgraph[self.t2]]

        def rename_conn(conn: str, names: Set[str]) -> str:
            """ Renames connector so that it doesn't clash with names.
            """
            match = re.match('(.*?)([0-9]+)$', conn)
            if match:
                pre = match.group(1)
            else:
                pre = f'{conn}_'
            i = 0
            while f'{pre}{i}' in names:
                i += 1
            return f'{pre}{i}'

        def replace(tasklet, repl_dict):
            """ Renames connectors based on the input replacement dictionary.
            """
            if tasklet.language is dtypes.Language.Python:
                repl = ConnectorRenamer(repl_dict)
                for stmt in tasklet.code.code:
                    repl.visit(stmt)
            elif tasklet.language is dtypes.Language.CPP:
                for old, new in repl_dict.items():
                    tasklet.code.code = re.sub(r'\b%s\b' % re.escape(old), new,
                                               tasklet.code.as_string)

        def replace_lhs(tasklet, repl_dict):
            """ Replaces assignments' LHS based on the input replacement
                dictionary. This is used only on CPP tasklets.
            """
            if tasklet.language is dtypes.Language.Python:
                raise ValueError(
                    "This method should only be used with CPP Tasklets")
            elif tasklet.language is dtypes.Language.CPP:
                for old, new in repl_dict.items():
                    tasklet.code.code = re.sub(
                        r'(?<!auto\s)%s[\s\t]*=' % re.escape(old), new,
                        tasklet.code.as_string)

        def extract_lhs(tasklet) -> Set[str]:
            """ Returns the LHS of assignments in Tasklet code.
            """
            if tasklet.language is dtypes.Language.Python:
                extr = PythonLHSExtractor()
                for stmt in tasklet.code.code:
                    extr.visit(stmt)
                return extr.assignments
            elif tasklet.language is dtypes.Language.CPP:
                lhs = set()
                for match in re.findall(r'[\s\t\n\r]*([\w]*)[\s\t]*=',
                                        tasklet.code.code):
                    lhs.add(match)
                return lhs

        rdict = dict()
        rdict_inout = dict()

        # Find names of current and former connectors
        # (assignments' LHS that are not connectors).
        t1_names = t1.in_connectors.keys() | t1.out_connectors.keys()
        t1_rhs = extract_lhs(t1)
        if t1_rhs:
            t1_names |= t1_rhs
        t2_names = t2.in_connectors.keys() | t2.out_connectors.keys()
        t2_rhs = extract_lhs(t2)
        if t2_rhs:
            t2_names |= t2_rhs

        # Change t2 connector names.
        nlist = list(t2_names)
        for name in nlist:
            if name in t1_names:
                newname = rename_conn(name, t1_names | t2_names)
                rdict[name] = newname
                t2_names.remove(name)
                t2_names.add(newname)
        if rdict:
            replace(t2, rdict)

        # Handle input edges.
        inconn = {}
        for e in graph.in_edges(t1):
            inconn[e.dst_conn] = t1.in_connectors[e.dst_conn]
        for e in graph.in_edges(t2):
            graph.remove_edge(e)
            conn = e.dst_conn
            if conn in rdict.keys():
                conn = rdict[conn]
            if e.src is t1:
                rdict_inout[conn] = e.src_conn
            else:
                inconn[conn] = t2.in_connectors[e.dst_conn]
                graph.add_edge(e.src, e.src_conn, t1, conn, e.data)

        # Handle output edges.
        outconn = {}
        for e in graph.out_edges(t1):
            outconn[e.src_conn] = t1.out_connectors[e.src_conn]
        for e in graph.out_edges(t2):
            graph.remove_edge(e)
            conn = e.src_conn
            if conn in rdict:
                conn = rdict[conn]
            outconn[conn] = t2.out_connectors[e.src_conn]
            graph.add_edge(t1, conn, e.dst, e.dst_conn, e.data)

        # Rename in-out connectors.
        if rdict_inout:
            replace(t2, rdict_inout)

        # Update t1 connectors and code.
        t1.in_connectors = inconn
        t1.out_connectors = outconn
        if t1.language is dtypes.Language.Python:
            t1.code.code.extend(t2.code.code)
        elif t1.language is dtypes.Language.CPP:
            t1.code.code += f'\n{t2.code.code}'
        graph.remove_node(t2)

        # Fix CPP assignment LHSs that are not connectors.
        if t1.language is dtypes.Language.CPP:
            rhs = extract_lhs(t1)
            repl_dict = dict()
            for name in rhs:
                if name not in inconn and name not in outconn:
                    repl_dict[name] = f'auto {name} ='
            if repl_dict:
                replace_lhs(t1, repl_dict)
Example #15
def count_moved_data(sdfg: dace.SDFG, symbols: Dict[str, Any] = None) -> int:
    result = 0
    symbols = symbols or {}
    for state in sdfg.nodes():
        result += count_moved_data_state(state, symbols)
    return result
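Usage mirrors count_arithmetic_ops in Example #1 (count_moved_data_state is assumed to be defined alongside):

moved = count_moved_data(sdfg, {'N': 1024})  # total elements moved for N = 1024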
Example #16
def _add_ort_init_code(sdfg: SDFG):
    """ Add onnxruntime initialization code to the SDFG if required """

    if "OrtKernelSession" not in sdfg.global_code['frame'].as_string:
        sdfg.append_global_code("""
        // Start global ORT setup
        const OrtApi* __ort_api = OrtGetApiBase()->GetApi(ORT_API_VERSION);

        // helper function to check for status
        void __ort_check_status(OrtStatus* status)
        {
            if (status != NULL) {
                const char* msg = __ort_api->GetErrorMessage(status);
                fprintf(stderr, "%s\\n", msg);
                __ort_api->ReleaseStatus(status);
                exit(1);
            }
        }
        OrtEnv* __ort_env;
        OrtKernelSession* __ort_session;
        OrtSessionOptions* __ort_session_options;

        OrtMemoryInfo* __ort_cpu_mem_info;
        """)

        sdfg.append_init_code("""
        __ort_check_status(__ort_api->CreateCpuMemoryInfo(OrtDeviceAllocator, OrtMemTypeDefault, &__ort_cpu_mem_info));
        __ort_check_status(__ort_api->CreateEnv(ORT_LOGGING_LEVEL_WARNING, "dace_graph", &__ort_env));
        __ort_check_status(__ort_api->CreateSessionOptions(&__ort_session_options));
        __ort_check_status(OrtSessionOptionsAppendExecutionProvider_CPU(__ort_session_options, /*use_arena=*/0));
        """)

        session_cleanup_code = """
        __ort_api->ReleaseMemoryInfo(__ort_cpu_mem_info);
        __ort_api->ReleaseKernelSession(__ort_session);
        __ort_api->ReleaseSessionOptions(__ort_session_options);
        __ort_api->ReleaseEnv(__ort_env);
        """

        if any(
                hasattr(node, "schedule")
                and node.schedule == ScheduleType.GPU_Device
                for state in sdfg.nodes() for node in state.nodes()):
            # if the SDFG contains a GPU node, add the CUDA provider and the memory_info
            sdfg.append_global_code("OrtMemoryInfo* __ort_cuda_mem_info;\n")
            sdfg.append_global_code(
                "OrtMemoryInfo* __ort_cuda_pinned_mem_info;\n")
            sdfg.append_init_code("""
            __ort_check_status(__ort_api->CreateMemoryInfo("Cuda", /*allocator_type=*/OrtDeviceAllocator, /*device=*/0, /*mem_type=*/OrtMemTypeDefault, &__ort_cuda_mem_info));
            __ort_check_status(__ort_api->CreateMemoryInfo("CudaPinned", /*allocator_type=*/OrtDeviceAllocator, /*device=*/0, /*mem_type=*/OrtMemTypeCPU, &__ort_cuda_pinned_mem_info));
            __ort_check_status(OrtSessionOptionsAppendExecutionProvider_CUDA(__ort_session_options, /*device=*/0));
            """)
            session_cleanup_code = ("""
            __ort_api->ReleaseMemoryInfo(__ort_cuda_mem_info);
            __ort_api->ReleaseMemoryInfo(__ort_cuda_pinned_mem_info);
            """ + session_cleanup_code)

        sdfg.append_global_code("// End global ORT setup\n")
        sdfg.prepend_exit_code(session_cleanup_code)
        sdfg.append_init_code("""
        __ort_check_status(__ort_api->CreateKernelSession(__ort_session_options, &__ort_session, 12));
        """)
Example #17
    def apply(self, sdfg: dace.SDFG):
        # Extract the map and its entry and exit nodes.
        graph = sdfg.nodes()[self.state_id]
        map_entry = graph.nodes()[self.subgraph[MapExpansion._map_entry]]
        map_exit = graph.exit_node(map_entry)
        current_map = map_entry.map

        # Create new maps
        new_maps = [
            nodes.Map(current_map.label + '_' + str(param), [param],
                      subsets.Range([param_range]),
                      schedule=dtypes.ScheduleType.Sequential)
            for param, param_range in zip(current_map.params[1:],
                                          current_map.range[1:])
        ]
        current_map.params = [current_map.params[0]]
        current_map.range = subsets.Range([current_map.range[0]])

        # Create new map entries and exits
        entries = [nodes.MapEntry(new_map) for new_map in new_maps]
        exits = [nodes.MapExit(new_map) for new_map in new_maps]

        # Create edges, abiding by the following rules:
        # 1. If there are no edges coming from the outside, use empty memlets
        # 2. Edges with IN_* connectors replicate along the maps
        # 3. Edges for dynamic map ranges replicate until reaching range(s)
        for edge in graph.out_edges(map_entry):
            graph.remove_edge(edge)
            graph.add_memlet_path(map_entry,
                                  *entries,
                                  edge.dst,
                                  src_conn=edge.src_conn,
                                  memlet=edge.data,
                                  dst_conn=edge.dst_conn)

        # Modify dynamic map ranges
        dynamic_edges = dace.sdfg.dynamic_map_inputs(graph, map_entry)
        for edge in dynamic_edges:
            # Remove old edge and connector
            graph.remove_edge(edge)
            edge.dst.remove_in_connector(edge.dst_conn)

            # Propagate to each range it belongs to
            path = []
            for mapnode in [map_entry] + entries:
                path.append(mapnode)
                if any(edge.dst_conn in map(str, symbolic.symlist(r))
                       for r in mapnode.map.range):
                    graph.add_memlet_path(edge.src,
                                          *path,
                                          memlet=edge.data,
                                          src_conn=edge.src_conn,
                                          dst_conn=edge.dst_conn)

        # Create new map exits
        for edge in graph.in_edges(map_exit):
            graph.remove_edge(edge)
            graph.add_memlet_path(edge.src,
                                  *exits[::-1],
                                  map_exit,
                                  memlet=edge.data,
                                  src_conn=edge.src_conn,
                                  dst_conn=edge.dst_conn)
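Rather than calling apply directly, a transformation like this is usually requested through the pattern-matching API; a hedged sketch:

from dace.transformation.dataflow import MapExpansion

# Expands every matching multi-dimensional map; returns how many applications occurred
applied = sdfg.apply_transformations_repeated(MapExpansion)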
Example #18
    def apply(self, sdfg: dace.SDFG):
        graph = sdfg.nodes()[self.state_id]
        map_entry = graph.nodes()[self.subgraph[self._map_entry]]
        map_exit = graph.exit_node(map_entry)

        sz = dace.symbol('commsize',
                         dtype=dace.int32,
                         integer=True,
                         positive=True)
        Px = dace.symbol('Px', dtype=dace.int32, integer=True, positive=True)
        Py = dace.symbol('Py', dtype=dace.int32, integer=True, positive=True)

        def _prod(sequence):
            return reduce(lambda a, b: a * b, sequence, 1)

        # NOTE: Maps with step in their ranges are currently not supported
        if len(map_entry.map.params) == 2:
            params = map_entry.map.params
            ranges = [None] * 2
            b, e, _ = map_entry.map.range[0]
            ranges[0] = (0, (e - b + 1) / Px - 1, 1)
            b, e, _ = map_entry.map.range[1]
            ranges[1] = (0, (e - b + 1) / Py - 1, 1)
            strides = [1]
        else:
            params = ['__iflat']
            sizes = map_entry.map.range.size_exact()
            total_size = _prod(sizes)
            ranges = [(0, (total_size) / sz - 1, 1)]
            strides = [_prod(sizes[i + 1:]) for i in range(len(sizes))]

        root_name = sdfg.temp_data_name()
        sdfg.add_scalar(root_name, dace.int32, transient=True)
        root_node = graph.add_access(root_name)
        root_tasklet = graph.add_tasklet('_set_root_', {}, {'__out'},
                                         '__out = 0')
        graph.add_edge(root_tasklet, '__out', root_node, None,
                       dace.Memlet.simple(root_name, '0'))

        from dace.libraries.mpi import Bcast
        from dace.libraries.pblas import BlockCyclicScatter, BlockCyclicGather

        inputs = set()
        for src, _, _, _, m in graph.in_edges(map_entry):
            if not isinstance(src, nodes.AccessNode):
                raise NotImplementedError
            desc = src.desc(sdfg)
            if not isinstance(desc, (data.Scalar, data.Array)):
                raise NotImplementedError
            if list(desc.shape) != m.src_subset.size_exact():
                # Second attempt
                # TODO: We need a solution for symbols not matching
                if str(list(desc.shape)) != str(m.src_subset.size_exact()):
                    raise NotImplementedError
            inputs.add(src)

        for inp in inputs:
            desc = inp.desc(sdfg)

            if isinstance(desc, data.Scalar):
                local_access = graph.add_access(inp.data)
                bcast_node = Bcast('_Bcast_')
                graph.add_edge(inp, None, bcast_node, '_inbuffer',
                               dace.Memlet.from_array(inp.data, desc))
                graph.add_edge(root_node, None, bcast_node, '_root',
                               dace.Memlet.simple(root_name, '0'))
                graph.add_edge(bcast_node, '_outbuffer', local_access, None,
                               dace.Memlet.from_array(inp.data, desc))
                for e in graph.edges_between(inp, map_entry):
                    graph.add_edge(local_access, None, map_entry, e.dst_conn,
                                   dace.Memlet.from_array(inp.data, desc))
                    graph.remove_edge(e)

            elif isinstance(desc, data.Array):

                local_name, local_arr = sdfg.add_temp_transient(
                    [(desc.shape[0]) // Px, (desc.shape[1]) // Py],
                    dtype=desc.dtype,
                    storage=desc.storage)
                local_access = graph.add_access(local_name)
                bsizes_name, bsizes_arr = sdfg.add_temp_transient(
                    (2, ), dtype=dace.int32)
                bsizes_access = graph.add_access(bsizes_name)
                bsizes_tasklet = nodes.Tasklet(
                    '_set_bsizes_', {}, {'__out'},
                    "__out[0] = {x}; __out[1] = {y}".format(
                        x=(desc.shape[0]) // Px, y=(desc.shape[1]) // Py))
                graph.add_edge(bsizes_tasklet, '__out', bsizes_access, None,
                               dace.Memlet.from_array(bsizes_name, bsizes_arr))
                gdesc_name, gdesc_arr = sdfg.add_temp_transient(
                    (9, ), dtype=dace.int32)
                gdesc_access = graph.add_access(gdesc_name)
                ldesc_name, ldesc_arr = sdfg.add_temp_transient(
                    (9, ), dtype=dace.int32)
                ldesc_access = graph.add_access(ldesc_name)
                scatter_node = BlockCyclicScatter('_Scatter_')
                graph.add_edge(inp, None, scatter_node, '_inbuffer',
                               dace.Memlet.from_array(inp.data, desc))
                graph.add_edge(bsizes_access, None, scatter_node,
                               '_block_sizes',
                               dace.Memlet.from_array(bsizes_name, bsizes_arr))
                graph.add_edge(scatter_node, '_outbuffer', local_access, None,
                               dace.Memlet.from_array(local_name, local_arr))
                graph.add_edge(scatter_node, '_gdescriptor', gdesc_access,
                               None,
                               dace.Memlet.from_array(gdesc_name, gdesc_arr))
                graph.add_edge(scatter_node, '_ldescriptor', ldesc_access,
                               None,
                               dace.Memlet.from_array(ldesc_name, ldesc_arr))
                for e in graph.edges_between(inp, map_entry):
                    graph.add_edge(
                        local_access, None, map_entry, e.dst_conn,
                        dace.Memlet.from_array(local_name, local_arr))
                    graph.remove_edge(e)
                for e in graph.out_edges(map_entry):
                    if e.data.data == inp.data:
                        e.data.data = local_name

            else:
                raise NotImplementedError

        outputs = set()
        for _, _, dst, _, m in graph.out_edges(map_exit):
            if not isinstance(dst, nodes.AccessNode):
                raise NotImplementedError
            desc = dst.desc(sdfg)
            if not isinstance(desc, data.Array):
                raise NotImplementedError
            try:
                if list(desc.shape) != m.dst_subset.size_exact():
                    # Second attempt
                    # TODO: We need a solution for symbols not matching
                    if str(list(desc.shape)) != str(m.dst_subset.size_exact()):
                        raise NotImplementedError
            except AttributeError:
                if list(desc.shape) != m.subset.size_exact():
                    # Second attempt
                    # TODO: We need a solution for symbols not matching
                    if str(list(desc.shape)) != str(m.subset.size_exact()):
                        raise NotImplementedError
            outputs.add(dst)

        for out in outputs:
            desc = out.desc(sdfg)
            if isinstance(desc, data.Scalar):
                raise NotImplementedError
            elif isinstance(desc, data.Array):
                local_name, local_arr = sdfg.add_temp_transient(
                    [(desc.shape[0]) // Px, (desc.shape[1]) // Py],
                    dtype=desc.dtype,
                    storage=desc.storage)
                local_access = graph.add_access(local_name)
                bsizes_name, bsizes_arr = sdfg.add_temp_transient(
                    (2, ), dtype=dace.int32)
                bsizes_access = graph.add_access(bsizes_name)
                bsizes_tasklet = nodes.Tasklet(
                    '_set_bsizes_', {}, {'__out'},
                    "__out[0] = {x}; __out[1] = {y}".format(
                        x=(desc.shape[0]) // Px, y=(desc.shape[1]) // Py))
                graph.add_edge(bsizes_tasklet, '__out', bsizes_access, None,
                               dace.Memlet.from_array(bsizes_name, bsizes_arr))
                scatter_node = BlockCyclicGather('_Gather_')
                graph.add_edge(local_access, None, scatter_node, '_inbuffer',
                               dace.Memlet.from_array(local_name, local_arr))
                graph.add_edge(bsizes_access, None, scatter_node,
                               '_block_sizes',
                               dace.Memlet.from_array(bsizes_name, bsizes_arr))
                graph.add_edge(scatter_node, '_outbuffer', out, None,
                               dace.Memlet.from_array(out.data, desc))

                for e in graph.edges_between(map_exit, out):
                    graph.add_edge(
                        map_exit, e.src_conn, local_access, None,
                        dace.Memlet.from_array(local_name, local_arr))
                    graph.remove_edge(e)
                for e in graph.in_edges(map_exit):
                    if e.data.data == out.data:
                        e.data.data = local_name
            else:
                raise NotImplementedError

        map_entry.map.params = params
        map_entry.map.range = subsets.Range(ranges)
Example #19
    def apply(self, sdfg: dace.SDFG):
        # Extract the subgraph, execute it and insert an AccessNode to the result

        parent: ONNXModel = sdfg._parent_onnx_model
        state = sdfg.nodes()[self.state_id]
        node = state.nodes()[self.subgraph[ConstantFolding._onnx_node]]

        if isinstance(node, donnx.ONNXShape):
            # if we have a shape node, replace it with a constant
            assert len(state.in_edges(node)) == 1
            shape_in_edge = state.in_edges(node)[0]
            assert shape_in_edge.dst_conn == "data"
            shape_desc = sdfg.arrays[shape_in_edge.src.data]

            constant_name = sdfg.temp_data_name()
            clean_constant_name = clean_onnx_name(constant_name)
            sdfg.add_array(clean_constant_name, (len(shape_desc.shape), ),
                           dace.int64)

            assert constant_name not in parent.clean_weights
            parent.weights[constant_name] = np.array(shape_desc.shape,
                                                     np.int64)

            assert len(state.out_edges(node)) == 1
            output_edge = state.out_edges(node)[0]
            access_shape = state.add_access(clean_constant_name)
            state.add_edge(access_shape, None, output_edge.dst,
                           output_edge.dst_conn,
                           sdfg.make_array_memlet(clean_constant_name))
        else:
            # otherwise compute the result of the op
            sub_sdfg = dace.SDFG("sub_sdfg")
            sub_state = sub_sdfg.add_state()

            node_copy = copy.deepcopy(node)
            sub_state.add_node(node_copy)

            inputs = {}
            for edge in state.in_edges(node):
                # we know from can_be_applied that all in edges are from AccessNodes
                assert (isinstance(edge.src, nd.AccessNode)
                        and hasattr(sdfg, "_parent_onnx_model") and
                        edge.src.data in sdfg._parent_onnx_model.clean_weights)

                desc = copy.deepcopy(sdfg.arrays[edge.data.data])
                desc.transient = False
                sub_sdfg.add_datadesc('array_' + edge.dst_conn, desc)

                input_value = sdfg._parent_onnx_model.clean_weights[
                    edge.src.data]

                if len(input_value.shape) == 0:
                    inputs['array_' + edge.dst_conn] = input_value[()]
                else:
                    inputs['array_' + edge.dst_conn] = input_value.copy()

                access = sub_state.add_access('array_' + edge.dst_conn)
                sub_state.add_edge(
                    access, None, node_copy, edge.dst_conn,
                    sub_sdfg.make_array_memlet('array_' + edge.dst_conn))

            outputs = {}
            for edge in state.out_edges(node):
                desc = copy.deepcopy(sdfg.arrays[edge.data.data])
                if isinstance(desc, dt.Scalar):
                    # we need to copy to an array of size [1] so that we can "return" the output from the sdfg
                    desc.transient = True
                    sub_sdfg.add_datadesc('scalar_array_' + edge.src_conn,
                                          desc)
                    sub_sdfg.add_array('array_' + edge.src_conn, [1],
                                       desc.dtype,
                                       transient=False)

                    access_scalar = sub_state.add_access('scalar_array_' +
                                                         edge.src_conn)
                    access = sub_state.add_access('array_' + edge.src_conn)
                    sub_state.add_edge(
                        node_copy, edge.src_conn, access_scalar, None,
                        sub_sdfg.make_array_memlet('scalar_array_' +
                                                   edge.src_conn))

                    sub_state.add_edge(
                        access_scalar, None, access, None,
                        sub_sdfg.make_array_memlet('array_' + edge.src_conn))
                else:
                    desc.transient = False
                    sub_sdfg.add_datadesc('array_' + edge.src_conn, desc)
                    access = sub_state.add_access('array_' + edge.src_conn)
                    sub_state.add_edge(
                        node_copy, edge.src_conn, access, None,
                        sub_sdfg.make_array_memlet('array_' + edge.src_conn))

                if len(desc.shape) == 0:
                    outputs['array_' + edge.src_conn] = np.empty(
                        (1, ), desc.dtype.as_numpy_dtype())
                else:
                    outputs['array_' + edge.src_conn] = np.empty(
                        tuple(desc.shape), desc.dtype.as_numpy_dtype())

            sub_sdfg(**outputs, **inputs)

            for edge in state.out_edges(node):
                desc = copy.deepcopy(sdfg.arrays[edge.data.data])
                desc.transient = False
                output_value = outputs['array_' + edge.src_conn]

                constant_name = sdfg.temp_data_name()
                clean_constant_name = clean_onnx_name(constant_name)
                sdfg.add_datadesc(clean_constant_name, desc)

                assert constant_name not in parent.weights
                if isinstance(desc, dt.Scalar):
                    parent.weights[constant_name] = output_value.reshape(())
                else:
                    parent.weights[constant_name] = output_value

                access_constant = state.add_access(clean_constant_name)
                state.add_edge(access_constant, None, edge.dst, edge.dst_conn,
                               sdfg.make_array_memlet(clean_constant_name))

        # remove all now useless nodes with a reverse BFS
        queue = deque([node])
        while len(queue) > 0:
            current_node = queue.popleft()

            edges = state.in_edges(current_node)
            state.remove_node(current_node)
            for e in edges:
                next_node = e.src
                if len(state.out_edges(next_node)) == 0:
                    queue.append(next_node)
Example #20
    def apply(self, sdfg: SDFG) -> Union[Any, None]:
        # Load/parse infos from the SDFG
        graph = sdfg.nodes()[self.state_id]
        src = graph.nodes()[self.subgraph[BankSplit._src_node]]
        dst = graph.nodes()[self.subgraph[BankSplit._dst_node]]
        src_array = sdfg.arrays[src.data]
        dst_array = sdfg.arrays[dst.data]
        # If this is not true we have to distribute to dst (checked in can_apply)
        collect_src = len(src_array.shape) - 1 == len(dst_array.shape)
        if collect_src:
            bank_count = int(src_array.shape[0])
            true_size = dst_array.shape
        else:
            bank_count = int(dst_array.shape[0])
            true_size = src_array.shape
        ndim = len(true_size)

        # Move Default storage
        if sdfg.arrays[src.data].storage == dtypes.StorageType.Default:
            sdfg.arrays[src.data].storage = self.default_to_storage
        if sdfg.arrays[dst.data].storage == dtypes.StorageType.Default:
            sdfg.arrays[dst.data].storage = self.default_to_storage

        # Figure out how to split
        if self.split_array_info is None:
            split_info = [1] * ndim
            split_info[0] = bank_count
        else:
            split_info = self.split_array_info
            if len(split_info) != ndim:
                raise RuntimeError(
                    "Length of split_array_info must match number of "
                    "dimensions")
        if functools.reduce(lambda a, b: a * b, split_info) != bank_count:
            raise RuntimeError(
                "Splitting is not possible with the selected splits "
                "and this number of HBM-banks (required number of banks "
                "!= actual number of banks)")

        # create the copy-subgraph
        ndrange = dict()
        usable_params = []
        for i in range(ndim):
            usable_params.append(f"i{i}")
        for i in range(ndim):
            ndrange[usable_params[i]] = f"0:{split_info[i]}"
        graph.remove_edge_and_connectors(graph.edges_between(src, dst)[0])
        copy_map_enter, copy_map_exit = graph.add_map(
            "hbm_bank_split", ndrange, dtypes.ScheduleType.Unrolled)
        graph.add_edge(copy_map_enter, None, src, None, memlet.Memlet())
        graph.add_edge(dst, None, copy_map_exit, None, memlet.Memlet())

        target_size = [
            str(x) for x in self._get_split_size(true_size, split_info)
        ]
        target_hbm_bank = []
        for i in range(ndim):
            target_hbm_bank.append(usable_params[i])
            for j in range(i):
                target_hbm_bank[j] = f"{split_info[i]}*{target_hbm_bank[j]}"
        target_offset = []
        for i in range(ndim):
            target_offset.append(f"{usable_params[i]}*{target_size[i]}")

        target_size_str = ", ".join(
            [f"{x}:{y}" for x, y in zip([0] * ndim, target_size)])
        target_hbm_bank_str = "+ ".join(target_hbm_bank)
        target_offset_str = ", ".join(
            [f"({x}):({x}+{y})" for x, y in zip(target_offset, target_size)])
        if collect_src:
            copy_memlet = memlet.Memlet(
                f"{src.data}[{target_hbm_bank_str}, {target_size_str}]->"
                f"{target_offset_str}")
        else:
            copy_memlet = memlet.Memlet(
                f"{src.data}[{target_offset_str}]->{target_hbm_bank_str}, "
                f"{target_size_str}")
        graph.add_edge(src, None, dst, None, copy_memlet)
Example #21
    def apply(self, sdfg: dace.SDFG):
        # Extract the map and its entry and exit nodes.
        graph = sdfg.nodes()[self.state_id]
        map_entry = graph.nodes()[self.subgraph[MapExpansion._map_entry]]
        map_exit = graph.exit_nodes(map_entry)[0]
        current_map = map_entry.map

        # Create new maps
        maps = [
            nodes.Map(current_map.label + '_' + str(param), [param],
                      subsets.Range([param_range]),
                      schedule=dtypes.ScheduleType.Sequential)
            for param, param_range in zip(current_map.params,
                                          current_map.range)
        ]
        maps[0]._schedule = dtypes.ScheduleType.Default

        # Create new map entries
        entries = [nodes.MapEntry(new_map) for new_map in maps]
        entries[0].in_connectors = map_entry.in_connectors
        entries[0].out_connectors = map_entry.out_connectors
        num_entry_out_edges = len(graph.out_edges(map_entry))
        for i in range(1, len(entries)):
            entries[i].in_connectors = set('IN_' + str(i + 1)
                                           for i in range(num_entry_out_edges))
            entries[i].out_connectors = set(
                'OUT_' + str(i + 1) for i in range(num_entry_out_edges))

        # Create new map exits
        exits = [nodes.MapExit(new_map) for new_map in maps]
        exits.reverse()
        exits[-1].in_connectors = map_exit.in_connectors
        exits[-1].out_connectors = map_exit.out_connectors
        num_entry_out_edges = len(graph.out_edges(map_exit))
        for i in range(0, len(exits) - 1):
            exits[i].in_connectors = set('IN_' + str(i + 1)
                                         for i in range(num_entry_out_edges))
            exits[i].out_connectors = set('OUT_' + str(i + 1)
                                          for i in range(num_entry_out_edges))

        # Add new nodes to state
        graph.add_nodes_from(entries)
        graph.add_nodes_from(exits)

        # Redirect edges to new nodes
        dace.graph.nxutil.change_edge_dest(graph, map_entry, entries[0])
        dace.graph.nxutil.change_edge_src(graph, map_exit, exits[-1])

        for i, e in enumerate(graph.out_edges(map_entry)):
            graph.remove_edge(e)
            graph.add_edge(entries[0], e.src_conn, entries[1],
                           'IN_' + str(i + 1), copy.deepcopy(e.data))
            graph.add_edge(entries[-1], 'OUT_' + str(i + 1), e.dst, e.dst_conn,
                           copy.deepcopy(e.data))
            for j in range(1, len(entries) - 1):
                graph.add_edge(entries[j], 'OUT_' + str(i + 1), entries[j + 1],
                               'IN_' + str(i + 1), copy.deepcopy(e.data))
        for i, e in enumerate(graph.in_edges(map_exit)):
            graph.remove_edge(e)
            graph.add_edge(e.src, e.src_conn, exits[0], 'IN_' + str(i + 1),
                           copy.deepcopy(e.data))
            graph.add_edge(exits[-2], 'OUT_' + str(i + 1), exits[-1],
                           e.dst_conn, copy.deepcopy(e.data))
            for j in range(0, len(exits) - 2):
                graph.add_edge(exits[j], 'OUT_' + str(i + 1), exits[j + 1],
                               'IN_' + str(i + 1), copy.deepcopy(e.data))

        # Remove old nodes
        graph.remove_node(map_entry)
        graph.remove_node(map_exit)
Example #22
    def apply_pass(self,
                   sdfg: SDFG,
                   _,
                   initial_symbols: Optional[Dict[str, Any]] = None
                   ) -> Optional[Set[str]]:
        """
        Propagates constants throughout the SDFG.
        :param sdfg: The SDFG to modify.
        :param pipeline_results: If in the context of a ``Pipeline``, a dictionary that is populated with prior Pass
                                 results as ``{Pass subclass name: returned object from pass}``. If not run in a
                                 pipeline, an empty dictionary is expected.
        :param initial_symbols: If not None, sets values of initial symbols.
        :return: A set of propagated constants, or None if nothing was changed.
        """
        initial_symbols = initial_symbols or {}

        # Early exit if no constants can be propagated
        if not initial_symbols and not self.should_apply(sdfg):
            result = {}
        else:
            # Trace all constants and symbols through states
            per_state_constants: Dict[SDFGState,
                                      Dict[str, Any]] = self.collect_constants(
                                          sdfg, initial_symbols)

            # Keep track of replaced and ambiguous symbols
            symbols_replaced: Dict[str, Any] = {}
            remaining_unknowns: Set[str] = set()

            # Collect symbols from symbol-dependent data descriptors
            # If there can be multiple values over the SDFG, the symbols are not propagated
            desc_symbols, multivalue_desc_symbols = self._find_desc_symbols(
                sdfg, per_state_constants)

            # Replace constants per state
            for state, mapping in per_state_constants.items():
                remaining_unknowns.update({
                    k
                    for k, v in mapping.items()
                    if v is _UnknownValue or k in multivalue_desc_symbols
                })
                mapping = {
                    k: v
                    for k, v in mapping.items() if v is not _UnknownValue
                    and k not in multivalue_desc_symbols
                }

                # Update replaced symbols for later replacements
                symbols_replaced.update(mapping)

                # Replace in state contents
                state.replace_dict(mapping)
                # Replace in outgoing edges as well
                for e in sdfg.out_edges(state):
                    e.data.replace_dict(mapping, replace_keys=False)

            # If symbols are never unknown any longer, remove from SDFG
            result = {
                k: v
                for k, v in symbols_replaced.items()
                if k not in remaining_unknowns
            }
            # Remove from symbol repository
            for sym in result:
                if sym in sdfg.symbols:
                    sdfg.remove_symbol(sym)

            # Remove single-valued symbols from data descriptors (e.g., symbolic array size)
            sdfg.replace_dict(
                {k: v
                 for k, v in result.items() if k in desc_symbols},
                replace_in_graph=False,
                replace_keys=False)

            # Remove constant symbol assignments in interstate edges
            for edge in sdfg.edges():
                intersection = result & edge.data.assignments.keys()
                for sym in intersection:
                    del edge.data.assignments[sym]

        result = set(result.keys())

        if self.recursive:
            # Change result to set of tuples
            sid = sdfg.sdfg_id
            result = set((sid, sym) for sym in result)

            for state in sdfg.nodes():
                for node in state.nodes():
                    if isinstance(node, nodes.NestedSDFG):
                        nested_id = node.sdfg.sdfg_id
                        const_syms = {
                            k: v
                            for k, v in node.symbol_mapping.items()
                            if not symbolic.issymbolic(v)
                        }
                        internal = self.apply_pass(node.sdfg, _, const_syms)
                        if internal:
                            for nid, removed in internal:
                                result.add((nid, removed))
                                # Remove symbol mapping if constant was completely propagated
                                if nid == nested_id and removed in node.symbol_mapping:
                                    del node.symbol_mapping[removed]

        # Return result
        if not result:
            return None
        return result
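Hedged usage, assuming this is DaCe's ConstantPropagation pass; initial_symbols lets the caller seed known symbol values:

from dace.transformation.passes.constant_propagation import ConstantPropagation

propagated = ConstantPropagation().apply_pass(sdfg, {}, initial_symbols={'N': 64})
if propagated:
    print('Propagated and removed:', propagated)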
Example #23
    def find_dead_states(
            self,
            sdfg: SDFG,
            set_unconditional_edges: bool = True) -> Set[SDFGState]:
        '''
        Finds "dead" (unreachable) states in an SDFG. A state is deemed unreachable if it is:
            * Unreachable from the starting state
            * Conditions leading to it will always evaluate to False
            * There is another unconditional (always True) inter-state edge that leads to another state

        :param sdfg: The SDFG to traverse.
        :param set_unconditional_edges: If True, conditions of edges evaluated as unconditional are removed.
        :return: A set of unreachable states.
        '''
        visited: Set[SDFGState] = set()

        # Run a modified BFS where definitely False edges are not traversed, or if there is an
        # unconditional edge the rest are not. The inverse of the visited states is the dead set.
        queue = collections.deque([sdfg.start_state])
        while len(queue) > 0:
            node = queue.popleft()
            if node in visited:
                continue
            visited.add(node)

            # First, check for unconditional edges
            unconditional = None
            for e in sdfg.out_edges(node):
                # If an unconditional edge is found, ignore all other outgoing edges
                if self.is_definitely_taken(e.data):
                    # If more than one unconditional outgoing edge exist, fail with Invalid SDFG
                    if unconditional is not None:
                        raise InvalidSDFGInterstateEdgeError(
                            'Multiple unconditional edges leave the same state',
                            sdfg, sdfg.edge_id(e))
                    unconditional = e
                    if set_unconditional_edges and not e.data.is_unconditional():
                        # Annotate edge as unconditional
                        e.data.condition = CodeBlock('1')

                    # Continue traversal through edge
                    if e.dst not in visited:
                        queue.append(e.dst)
                        continue
            if unconditional is not None:  # Unconditional edge exists, skip traversal
                continue
            # End of unconditional check

            # Check outgoing edges normally
            for e in sdfg.out_edges(node):
                next_node = e.dst

                # Test for edges that definitely evaluate to False
                if self.is_definitely_not_taken(e.data):
                    continue

                # Continue traversal through edge
                if next_node not in visited:
                    queue.append(next_node)

        # Dead states are states that are not live (i.e., visited)
        return set(sdfg.nodes()) - visited
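
Usage sketch for the traversal above (assuming the surrounding class is DaCe's DeadStateElimination pass; the SDFG construction below uses the standard DaCe API with made-up state names):

import dace
from dace.transformation.passes.dead_state_elimination import DeadStateElimination

sdfg = dace.SDFG('branches')
start = sdfg.add_state('start', is_start_state=True)
taken = sdfg.add_state('taken')
dead = sdfg.add_state('dead')
sdfg.add_edge(start, taken, dace.InterstateEdge(condition='1 > 0'))
sdfg.add_edge(start, dead, dace.InterstateEdge(condition='1 < 0'))

# 'dead' is only reachable through a definitely-False condition
assert dead in DeadStateElimination().find_dead_states(sdfg)
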
Example #24
0
    def apply(self, sdfg: dace.SDFG):
        # Extract the subgraph, execute it, and insert an AccessNode for the result.
        # This method of execution is slow but simple; a better option would be to
        # call the ORT C API from a python object (like the OpChecker).

        parent: ONNXModel = sdfg._parent_onnx_model
        state = sdfg.nodes()[self.state_id]
        node = state.nodes()[self.subgraph[ConstantFolding._onnx_node]]
        log.debug(f"Applying constant folding: {node} in {state}")

        if isinstance(node, donnx.ONNXShape):
            # If we have a Shape node, replace it with a precomputed constant
            assert len(state.in_edges(node)) == 1
            shape_in_edge = state.in_edges(node)[0]
            assert shape_in_edge.dst_conn == "data"
            shape_desc = sdfg.arrays[shape_in_edge.src.data]

            constant_name = sdfg.temp_data_name()
            clean_constant_name = clean_onnx_name(constant_name)
            sdfg.add_array(clean_constant_name, (len(shape_desc.shape), ),
                           dace.int64)

            assert constant_name not in parent.clean_weights
            parent.weights[constant_name] = torch.from_numpy(
                np.array(shape_desc.shape, np.int64))

            assert len(state.out_edges(node)) == 1
            output_edge = state.out_edges(node)[0]
            access_shape = state.add_access(clean_constant_name)
            state.add_edge(access_shape, None, output_edge.dst,
                           output_edge.dst_conn,
                           sdfg.make_array_memlet(clean_constant_name))
        else:
            # Otherwise, compute the result of the op by executing it in an isolated sub-SDFG
            global UNIQUE_ID
            UNIQUE_ID += 1
            sub_sdfg = dace.SDFG("sub_sdfg_" + str(UNIQUE_ID))
            sub_state = sub_sdfg.add_state()

            node_copy = copy.deepcopy(node)
            sub_state.add_node(node_copy)

            inputs = {}
            for edge in state.in_edges(node):
                # We know from can_be_applied that all in-edges originate from AccessNodes
                assert (isinstance(edge.src, nd.AccessNode)
                        and hasattr(sdfg, "_parent_onnx_model") and
                        edge.src.data in sdfg._parent_onnx_model.clean_weights)

                desc = copy.deepcopy(sdfg.arrays[edge.data.data])
                desc.transient = False
                sub_sdfg.add_datadesc('array_' + edge.dst_conn, desc)

                input_value = sdfg._parent_onnx_model.clean_weights[
                    edge.src.data]

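                # 0-dimensional weights are passed as host scalars; larger
                # tensors are cloned so the original weights stay untouched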
                if len(input_value.shape) == 0:
                    inputs['array_' +
                           edge.dst_conn] = input_value.cpu().numpy()[()]
                else:
                    inputs['array_' + edge.dst_conn] = input_value.clone()

                access = sub_state.add_access('array_' + edge.dst_conn)
                sub_state.add_edge(
                    access, None, node_copy, edge.dst_conn,
                    sub_sdfg.make_array_memlet('array_' + edge.dst_conn))

            outputs = {}
            for edge in state.out_edges(node):
                desc = copy.deepcopy(sdfg.arrays[edge.data.data])
                if isinstance(desc, dt.Scalar):
                    # Copy to an array of size [1] so that the output can be "returned" from the SDFG
                    desc.transient = True
                    sub_sdfg.add_datadesc('scalar_array_' + edge.src_conn,
                                          desc)
                    sub_sdfg.add_array('array_' + edge.src_conn, [1],
                                       desc.dtype,
                                       transient=False)

                    access_scalar = sub_state.add_access('scalar_array_' +
                                                         edge.src_conn)
                    access = sub_state.add_access('array_' + edge.src_conn)
                    sub_state.add_edge(
                        node_copy, edge.src_conn, access_scalar, None,
                        sub_sdfg.make_array_memlet('scalar_array_' +
                                                   edge.src_conn))

                    sub_state.add_edge(
                        access_scalar, None, access, None,
                        sub_sdfg.make_array_memlet('array_' + edge.src_conn))
                else:
                    desc.transient = False
                    sub_sdfg.add_datadesc('array_' + edge.src_conn, desc)
                    access = sub_state.add_access('array_' + edge.src_conn)
                    sub_state.add_edge(
                        node_copy, edge.src_conn, access, None,
                        sub_sdfg.make_array_memlet('array_' + edge.src_conn))

                if len(desc.shape) == 0:
                    empty_array = np.empty((1, ), desc.dtype.as_numpy_dtype())
                else:
                    empty_array = np.empty(tuple(desc.shape),
                                           desc.dtype.as_numpy_dtype())

                empty_array = torch.from_numpy(empty_array)

                if desc.storage is dtypes.StorageType.GPU_Global:
                    empty_array = empty_array.cuda()

                outputs['array_' + edge.src_conn] = empty_array

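            # Compile and execute the isolated sub-SDFG; the preallocated
            # tensors in `outputs` are filled in place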
            sub_sdfg(**outputs, **inputs)

            for edge in state.out_edges(node):
                desc = copy.deepcopy(sdfg.arrays[edge.data.data])
                desc.transient = False
                output_value = outputs['array_' + edge.src_conn]

                constant_name = sdfg.temp_data_name()
                clean_constant_name = clean_onnx_name(constant_name)
                sdfg.add_datadesc(clean_constant_name, desc)

                assert constant_name not in parent.weights
                assert type(output_value) is torch.Tensor

                if not dtypes.can_access(dtypes.ScheduleType.CPU_Multicore,
                                         desc.storage):
                    cpu_desc = copy.deepcopy(desc)
                    cpu_desc.storage = dtypes.StorageType.CPU_Heap
                    cpu_desc.transient = False
                    desc.transient = True
                    copy_in_name = sdfg.temp_data_name()
                    clean_copy_in_name = clean_onnx_name(copy_in_name)
                    sdfg.add_datadesc(clean_copy_in_name, cpu_desc)

                    access_constant = state.add_access(clean_constant_name)
                    state.add_edge(state.add_read(clean_copy_in_name), None,
                                   access_constant, None,
                                   sdfg.make_array_memlet(clean_copy_in_name))

                    name_to_add = copy_in_name
                else:
                    access_constant = state.add_read(clean_constant_name)
                    name_to_add = constant_name

                if isinstance(desc, dt.Scalar):
                    parent.weights[name_to_add] = output_value.reshape(())
                else:
                    parent.weights[name_to_add] = output_value

                state.add_edge(access_constant, None, edge.dst, edge.dst_conn,
                               sdfg.make_array_memlet(clean_constant_name))

        # Remove all now-useless nodes with a reverse BFS
        remove_node_and_computation(sdfg, state, node)
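
For intuition on the ONNXShape branch above, the folded constant is just the input's static shape as an int64 tensor; roughly (illustrative values, not the daceml API):

import numpy as np
import torch

shape = (8, 16, 3)  # stands in for shape_desc.shape
constant = torch.from_numpy(np.array(shape, np.int64))
print(constant)  # tensor([ 8, 16,  3])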