Exemplo n.º 1
0
    def can_be_applied(self, graph, expr_index, sdfg, permissive=False):
        """Check whether the matched loop executes its body at most once.

        Returns ``False`` when the base-class check rejects the match, when
        ``find_for_loop`` yields no loop information, when the range comparison
        shows that a second iteration would still lie within the loop bounds,
        or when that comparison cannot be evaluated. Returns ``True``
        otherwise.
        """
        # Is this even a loop
        if not super().can_be_applied(graph, expr_index, sdfg, permissive):
            return False

        # Obtain loop information
        guard: sd.SDFGState = self.loop_guard
        body: sd.SDFGState = self.loop_begin

        # Obtain iteration variable, range, and stride
        loop_info = find_for_loop(sdfg, guard, body)
        if not loop_info:
            return False
        _, (start, end, step), _ = loop_info

        try:
            # Forward loop: a second iteration (start + step) is still in range
            if step > 0 and start + step < end + 1:
                return False
            # Backward loop: a second iteration (start + step) is still in range
            if step < 0 and start + step > end - 1:
                return False
        except Exception:
            # If the relation can't be determined it's not a trivial loop.
            # Only ordinary errors are interpreted this way; KeyboardInterrupt
            # and SystemExit propagate instead of being read as "no match".
            return False

        return True
Exemplo n.º 2
0
    def can_be_applied(self, graph, expr_index, sdfg, permissive=False):
        """Match only if the base-class check passes and a for-loop is found.

        Returns ``True`` exactly when ``find_for_loop`` returns something other
        than ``None`` for the matched guard and first body state.
        """
        base_matches = super().can_be_applied(graph, expr_index, sdfg, permissive)
        if not base_matches:
            return False

        guard_state = self.loop_guard
        first_body_state = self.loop_begin

        # The match stands or falls with loop detection
        loop_info = find_for_loop(sdfg, guard_state, first_body_state)
        return loop_info is not None
Exemplo n.º 3
0
    def can_be_applied(graph, candidate, expr_index, sdfg, strict=False):
        """Match only if ``DetectLoop`` accepts the candidate and a for-loop is found.

        The guard and first body state are looked up in ``graph`` through the
        ``candidate`` mapping. Returns ``True`` exactly when ``find_for_loop``
        returns something other than ``None`` for them.
        """
        base_matches = DetectLoop.can_be_applied(graph, candidate, expr_index,
                                                 sdfg, strict)
        if not base_matches:
            return False

        guard_state = graph.node(candidate[DetectLoop._loop_guard])
        first_body_state = graph.node(candidate[DetectLoop._loop_begin])

        # The match stands or falls with loop detection
        loop_info = find_for_loop(sdfg, guard_state, first_body_state)
        return loop_info is not None
Exemplo n.º 4
0
    def can_be_applied(self, graph, expr_index, sdfg, permissive=False):
        """Match loops whose stride and range extent are not symbolic.

        Returns ``False`` if the base-class check fails, if no for-loop is
        detected, or if either the stride or the difference between the second
        and first range entries is symbolic with respect to ``sdfg.constants``.
        """
        # Is this even a loop
        if not super().can_be_applied(graph, expr_index, sdfg, permissive):
            return False

        guard_state = self.loop_guard
        first_body_state = self.loop_begin

        # Without detected loop information there is nothing to check
        loop_info = find_for_loop(graph, guard_state, first_body_state)
        if not loop_info:
            return False
        _, loop_range, _ = loop_info

        # Both the stride and the range extent must be specialized/constant;
        # the extent is only evaluated when the stride check passes.
        return not (symbolic.issymbolic(loop_range[2], sdfg.constants)
                    or symbolic.issymbolic(loop_range[1] - loop_range[0], sdfg.constants))
Exemplo n.º 5
0
    def apply(self, _, sdfg: sd.SDFG):
        """Remove the loop around the body and substitute the start value.

        Every state reachable from the first body state without passing through
        the guard has the iteration variable replaced by the loop start value.
        The guard is then removed: its incoming edges are redirected to the
        first body state (with assignments cleared) and its outgoing edges are
        re-attached to the last body state (with the condition set to "1").
        """
        # Obtain loop information
        guard: sd.SDFGState = self.loop_guard
        body: sd.SDFGState = self.loop_begin

        # Obtain iteration variable, range and stride
        loop_info = find_for_loop(sdfg, guard, body)
        itervar, (start, end, step), (_, body_end) = loop_info

        # Collect all loop-body states, breadth-first from the first body state
        body_states = set()
        pending = [body]
        while pending:
            current = pending.pop(0)
            pending.extend([
                successor for _, successor, _ in sdfg.out_edges(current)
                if successor not in body_states and successor is not guard
            ])
            body_states.add(current)

        # The body runs with the iteration variable fixed to its first value
        for body_state in body_states:
            body_state.replace(itervar, start)

        # Detach the body from the loop structure
        for incoming in sdfg.in_edges(body):
            sdfg.remove_edge(incoming)
        for outgoing in sdfg.out_edges(body_end):
            sdfg.remove_edge(outgoing)

        # Whatever led into the guard now leads into the body, unassigned
        for incoming in sdfg.in_edges(guard):
            incoming.data.assignments = {}
            sdfg.add_edge(incoming.src, body, incoming.data)
            sdfg.remove_edge(incoming)
        # Whatever left the guard now leaves the end of the body unconditionally
        for outgoing in sdfg.out_edges(guard):
            outgoing.data.condition = CodeBlock("1")
            sdfg.add_edge(body_end, outgoing.dst, outgoing.data)
            sdfg.remove_edge(outgoing)
        sdfg.remove_node(guard)

        # Drop the iteration variable if nothing refers to it any more
        if itervar in sdfg.symbols and helpers.is_symbol_unused(sdfg, itervar):
            sdfg.remove_symbol(itervar)
Exemplo n.º 6
0
    def can_be_applied(graph, candidate, expr_index, sdfg, strict=False):
        """Match loops whose stride and range extent are not symbolic.

        Returns ``False`` if ``DetectLoop`` rejects the candidate, if no
        for-loop is detected, or if either the stride or the difference between
        the second and first range entries is symbolic with respect to
        ``sdfg.constants``.
        """
        # Is this even a loop
        base_matches = DetectLoop.can_be_applied(graph, candidate, expr_index,
                                                 sdfg, strict)
        if not base_matches:
            return False

        guard_state = graph.node(candidate[DetectLoop._loop_guard])
        first_body_state = graph.node(candidate[DetectLoop._loop_begin])

        # Without detected loop information there is nothing to check
        loop_info = find_for_loop(graph, guard_state, first_body_state)
        if not loop_info:
            return False
        _, loop_range, _ = loop_info

        # Both the stride and the range extent must be specialized/constant;
        # the extent is only evaluated when the stride check passes.
        return not (symbolic.issymbolic(loop_range[2], sdfg.constants)
                    or symbolic.issymbolic(loop_range[1] - loop_range[0], sdfg.constants))
Exemplo n.º 7
0
    def can_be_applied(self, graph, expr_index, sdfg, permissive=False):
        """Check whether the matched loop can be moved inside the map in its body.

        The match is accepted only if all of the following hold:

        * the base-class check passes and ``find_for_loop`` detects a loop;
        * the loop step is ``1`` or ``-1``;
        * the loop body is a single state containing exactly one ``MapEntry``;
        * nothing outside the map scope, the map itself, and none of the data
          descriptors referenced by the body's memlets use the iteration
          variable;
        * memlets inside the map scope use the iteration variable consistently
          per container and never together with a map parameter in the same
          dimension.
        """
        # Is this even a loop
        if not super().can_be_applied(graph, expr_index, sdfg, permissive):
            return False

        # Obtain loop information
        guard: sd.SDFGState = self.loop_guard
        body: sd.SDFGState = self.loop_begin
        after: sd.SDFGState = self.exit_state

        # Obtain iteration variable, range, and stride
        loop_info = find_for_loop(sdfg, guard, body)
        if not loop_info:
            return False
        itervar, (start, end, step), (_, body_end) = loop_info

        # Only unit strides (forward or backward) are supported
        if step not in [-1, 1]:
            return False

        # Body must contain a single state
        if body != body_end:
            return False

        # Check if body contains exactly one map
        maps = [node for node in body.nodes() if isinstance(node, nodes.MapEntry)]
        if len(maps) != 1:
            return False

        # Check that everything else is independent of the loop's itervar
        subgraph = body.scope_subgraph(maps[0])
        map_exit = body.exit_node(maps[0])
        # Names of all containers referenced by non-empty memlets in the body
        descs: Set[str] = set()
        for e in body.edges():
            if not e.data.is_empty():
                descs.add(e.data.data)
            # Edges touching the map scope are checked separately further below
            if e.src in subgraph.nodes() or e.dst in subgraph.nodes():
                continue
            # Edges between access nodes and the map entry/exit are exempt
            if e.dst is maps[0] and isinstance(e.src, nodes.AccessNode):
                continue
            if e.src is map_exit and isinstance(e.dst, nodes.AccessNode):
                continue
            if str(itervar) in e.data.free_symbols:
                return False
        for n in body.nodes():
            if n in subgraph.nodes():
                continue
            if str(itervar) in n.free_symbols:
                return False

        # Check for iteration variable in map and data descriptors
        if str(itervar) in maps[0].free_symbols:
            return False
        for arr in descs:
            if str(itervar) in set(map(str, sdfg.arrays[arr].free_symbols)):
                return False

        def test_subset_dependency(subset: sbs.Subset, mparams: Set[int]) -> Tuple[bool, List[int]]:
            """Find the dimensions of ``subset`` that use the iteration variable.

            Returns ``(False, [])`` as soon as a dimension uses both the
            iteration variable and one of ``mparams``; otherwise returns
            ``(True, dims)`` with the indices of the dimensions that use the
            iteration variable.
            """
            dims = []
            for i, r in enumerate(subset):
                # Normalize a single index expression to a one-element list
                if not isinstance(r, (list, tuple)):
                    r = [r]
                # Names of all free symbols appearing anywhere in this dimension
                fsymbols = set()
                for token in r:
                    if symbolic.issymbolic(token):
                        fsymbols = fsymbols.union({str(s) for s in token.free_symbols})
                if itervar in fsymbols:
                    if fsymbols.intersection(mparams):
                        return (False, [])
                    else:
                        dims.append(i)
            return (True, dims)

        # Check that Map memlets depend on itervar in a consistent manner
        # a. A container must either not depend at all on itervar, or depend on it always in the same dimensions.
        # b. Abort when a dimension depends on both the itervar and a Map parameter.
        mparams = set(maps[0].map.params)
        # Maps container name -> list of dimensions that depend on itervar
        data_dependency = dict()
        for e in body.edges():
            if e.src in subgraph.nodes() and e.dst in subgraph.nodes():
                if itervar in e.data.free_symbols:
                    for i, subset in enumerate((e.data.src_subset, e.data.dst_subset)):
                        if subset:
                            # The source subset refers to the node at the start of
                            # the memlet path, the destination subset to its end
                            if i == 0:
                                access = body.memlet_path(e)[0].src
                            else:
                                access = body.memlet_path(e)[-1].dst
                            passed, dims = test_subset_dependency(subset, mparams)
                            if not passed:
                                return False
                            if dims:
                                if access.data in data_dependency:
                                    if data_dependency[access.data] != dims:
                                        return False
                                else:
                                    data_dependency[access.data] = dims

        # NOTE(review): list.count(True) counts the elements that compare equal
        # to True; it does not count the edges in the list. Unless edge objects
        # compare equal to True, both conditions below never fire. If the intent
        # is to reject access nodes with more than one incoming or outgoing
        # edge, this probably needs len(...) instead -- confirm before changing,
        # since that would make the transformation stricter.
        for node in body.nodes():
            if isinstance(node, nodes.AccessNode):
                if body.in_edges(node).count(True) > 1:
                    return False
                if body.out_edges(node).count(True) > 1:
                    return False

        return True
Exemplo n.º 8
0
    def apply(self, _, sdfg: sd.SDFG):
        """Move the loop into the map of its body state.

        The contents of the map scope are nested into a nested SDFG, the loop
        is re-created inside that nested SDFG around its first state, and the
        original guard state is removed from ``sdfg``. Afterwards, symbols and
        data used by the nested SDFG are connected, memlets are propagated, and
        ``RefineNestedAccess`` is applied to the new nested SDFG node.
        """
        # Obtain loop information
        guard: sd.SDFGState = self.loop_guard
        body: sd.SDFGState = self.loop_begin

        # Obtain iteration variable, range, and stride
        itervar, (start, end, step), _ = find_for_loop(sdfg, guard, body)

        forward_loop = step > 0

        # Locate the map entry/exit nodes; if several exist, the last one
        # encountered is used
        for node in body.nodes():
            if isinstance(node, nodes.MapEntry):
                map_entry = node
            if isinstance(node, nodes.MapExit):
                map_exit = node

        # nest map's content in sdfg
        map_subgraph = body.scope_subgraph(map_entry, include_entry=False, include_exit=False)
        nsdfg = helpers.nest_state_subgraph(sdfg, body, map_subgraph, full_data=True)

        # replicate loop in nested sdfg
        new_before, new_guard, new_after = nsdfg.sdfg.add_loop(
            before_state=None,
            loop_state=nsdfg.sdfg.nodes()[0],
            loop_end_state=None,
            after_state=None,
            loop_var=itervar,
            initialize_expr=f'{start}',
            condition_expr=f'{itervar} <= {end}' if forward_loop else f'{itervar} >= {end}',
            increment_expr=f'{itervar} + {step}' if forward_loop else f'{itervar} - {abs(step)}')

        # remove outer loop
        # First identify the edges of the new inner loop so that assignments
        # from the outer loop's edges can be transferred onto them
        before_guard_edge = nsdfg.sdfg.edges_between(new_before, new_guard)[0]
        for e in nsdfg.sdfg.out_edges(new_guard):
            if e.dst is new_after:
                guard_after_edge = e
            else:
                guard_body_edge = e

        for body_inedge in sdfg.in_edges(body):
            if body_inedge.src is guard:
                guard_body_edge.data.assignments.update(body_inedge.data.assignments)
            sdfg.remove_edge(body_inedge)
        for body_outedge in sdfg.out_edges(body):
            sdfg.remove_edge(body_outedge)
        # Edges that entered the guard now enter the body directly; their
        # assignments move to the inner loop's initialization edge
        for guard_inedge in sdfg.in_edges(guard):
            before_guard_edge.data.assignments.update(guard_inedge.data.assignments)
            guard_inedge.data.assignments = {}
            sdfg.add_edge(guard_inedge.src, body, guard_inedge.data)
            sdfg.remove_edge(guard_inedge)
        # Edges that left the guard now leave the body unconditionally
        for guard_outedge in sdfg.out_edges(guard):
            if guard_outedge.dst is body:
                guard_body_edge.data.assignments.update(guard_outedge.data.assignments)
            else:
                guard_after_edge.data.assignments.update(guard_outedge.data.assignments)
            guard_outedge.data.condition = CodeBlock("1")
            sdfg.add_edge(body, guard_outedge.dst, guard_outedge.data)
            sdfg.remove_edge(guard_outedge)
        sdfg.remove_node(guard)
        # The iteration variable is no longer passed in from, or registered in,
        # the outer SDFG
        if itervar in nsdfg.symbol_mapping:
            del nsdfg.symbol_mapping[itervar]
        if itervar in sdfg.symbols:
            del sdfg.symbols[itervar]

        # Add missing data/symbols
        for s in nsdfg.sdfg.free_symbols:
            if s in nsdfg.symbol_mapping:
                continue
            if s in sdfg.symbols:
                nsdfg.symbol_mapping[s] = s
            elif s in sdfg.arrays:
                # The free symbol is a data container: pass it in through a new
                # input connector fed from an access node via the map entry
                desc = sdfg.arrays[s]
                access = body.add_access(s)
                conn = nsdfg.sdfg.add_datadesc(s, copy.deepcopy(desc))
                nsdfg.sdfg.arrays[s].transient = False
                nsdfg.add_in_connector(conn)
                body.add_memlet_path(access, map_entry, nsdfg, memlet=Memlet.from_array(s, desc), dst_conn=conn)
            else:
                raise NotImplementedError(f"Free symbol {s} is neither a symbol nor data.")
        # Drop symbol mappings that the nested SDFG does not use; collected
        # first because the mapping cannot be modified while iterating over it
        to_delete = set()
        for s in nsdfg.symbol_mapping:
            if s not in nsdfg.sdfg.free_symbols:
                to_delete.add(s)
        for s in to_delete:
            del nsdfg.symbol_mapping[s]

        # propagate scope for correct volumes
        scope_tree = ScopeTree(map_entry, map_exit)
        scope_tree.parent = ScopeTree(None, None)
        # The first execution helps remove appearances of symbols
        # that are now defined only in the nested SDFG in memlets.
        propagation.propagate_memlets_scope(sdfg, body, scope_tree)

        for s in to_delete:
            if helpers.is_symbol_unused(sdfg, s):
                sdfg.remove_symbol(s)

        # Imported locally rather than at module level
        from dace.transformation.interstate import RefineNestedAccess
        transformation = RefineNestedAccess()
        transformation.setup_match(sdfg, 0, sdfg.node_id(body), {RefineNestedAccess.nsdfg: body.node_id(nsdfg)}, 0)
        transformation.apply(body, sdfg)

        # Second propagation for refined accesses.
        propagation.propagate_memlets_scope(sdfg, body, scope_tree)
Exemplo n.º 9
0
    def can_be_applied(self, graph, candidate, expr_index, sdfg, strict=False):
        """Check whether the matched loop can be converted into a map.

        The match is rejected if ``DetectLoop`` rejects it, if the guard state
        contains any nodes, if no for-loop is detected (using ``self.itervar``
        as the requested iteration variable), if the loop range contains sympy
        functions, if the writes and reads of the loop body fail the range and
        overlap checks below, or if the iteration variable is used by a state
        outside the loop body before being reassigned.
        """
        # Is this even a loop
        if not DetectLoop.can_be_applied(graph, candidate, expr_index, sdfg,
                                         strict):
            return False

        guard = graph.node(candidate[DetectLoop._loop_guard])
        begin = graph.node(candidate[DetectLoop._loop_begin])

        # Guard state should not contain any dataflow
        if len(guard.nodes()) != 0:
            return False

        # If loop cannot be detected, fail
        found = find_for_loop(graph, guard, begin,
                              itervar=self.itervar)
        if not found:
            return False

        itervar, (start, end, step), (_, body_end) = found

        # We cannot handle symbols read from data containers unless they are
        # scalar
        for expr in (start, end, step):
            if symbolic.contains_sympy_functions(expr):
                return False

        # Find all loop-body states
        # (traversal from the first body state that does not expand body_end)
        states = set([body_end])
        to_visit = [begin]
        while to_visit:
            state = to_visit.pop(0)
            if state is body_end:
                continue
            for _, dst, _ in graph.out_edges(state):
                if dst not in states:
                    to_visit.append(dst)
            states.add(state)

        # Union of everything written by any loop-body state
        write_set = set()
        for state in states:
            _, wset = state.read_and_write_sets()
            write_set |= wset

        # Get access nodes from other states to isolate local loop variables
        other_access_nodes = set()
        for state in sdfg.nodes():
            if state in states:
                continue
            other_access_nodes |= set(n.data for n in state.data_nodes()
                                      if sdfg.arrays[n.data].transient)
        # Add non-transient nodes from loop state
        for state in states:
            other_access_nodes |= set(n.data for n in state.data_nodes()
                                      if not sdfg.arrays[n.data].transient)

        # Maps container name -> memlets that write to it inside the loop body
        write_memlets = defaultdict(list)

        # Wildcards that must not contain the iteration symbol; passed to
        # _check_range (presumably to match expressions of the form a*i + b --
        # confirm against _check_range)
        itersym = symbolic.pystr_to_symbolic(itervar)
        a = sp.Wild('a', exclude=[itersym])
        b = sp.Wild('b', exclude=[itersym])

        for state in states:
            for dn in state.data_nodes():
                if dn.data not in other_access_nodes:
                    continue
                # Take all writes that are not conflicted into consideration
                if dn.data in write_set:
                    for e in state.in_edges(dn):
                        if e.data.dynamic and e.data.wcr is None:
                            # If pointers are involved, give up
                            return False
                        # To be sure that the value is only written at unique
                        # indices per loop iteration, we want to match symbols
                        # of the form "a*i+b" where a >= 1, and i is the iteration
                        # variable. The iteration variable must be used.
                        if e.data.wcr is None:
                            dst_subset = e.data.get_dst_subset(e, state)
                            if not _check_range(dst_subset, a, itersym, b, step):
                                return False
                        # End of check

                        write_memlets[dn.data].append(e.data)

        # After looping over relevant writes, consider reads that may overlap
        for state in states:
            for dn in state.data_nodes():
                if dn.data not in other_access_nodes:
                    continue
                data = dn.data
                if data in write_memlets:
                    # Import as necessary
                    from dace.sdfg.propagation import propagate_subset

                    for e in state.out_edges(dn):
                        # If the same container is both read and written, only match if
                        # it read and written at locations that will not create data races
                        if e.data.dynamic and e.data.src_subset.num_elements() != 1:
                            # If pointers are involved, give up
                            return False
                        src_subset = e.data.get_src_subset(e, state)
                        if not _check_range(src_subset, a, itersym, b, step):
                            return False

                        # Read subset propagated over the whole iteration range
                        pread = propagate_subset([e.data], sdfg.arrays[data],
                                                [itervar],
                                                subsets.Range([(start, end, step)
                                                                ]))
                        # NOTE(review): every path through this loop body ends in
                        # break or return, so only the first write memlet of the
                        # container is ever compared against the read -- confirm
                        # that this is intended. The loop variable also rebinds
                        # the `candidate` parameter, which is not used afterwards.
                        for candidate in write_memlets[data]:
                            # Simple case: read and write are in the same subset
                            if e.data.subset == candidate.subset:
                                break
                            # Propagated read does not overlap with propagated write
                            pwrite = propagate_subset([candidate],
                                                    sdfg.arrays[data], [itervar],
                                                    subsets.Range([(start, end,
                                                                    step)]))
                            if subsets.intersects(pread.subset,
                                                pwrite.subset) is False:
                                break
                            return False

        # Check that the iteration variable is not used on other edges or states
        # before it is reassigned
        prior_states = True
        for state in cfg.stateorder_topological_sort(sdfg):
            # Skip all states up to guard
            # (the skipping ends once the first body state `begin` is reached)
            if prior_states:
                if state is begin:
                    prior_states = False
                continue
            # We do not need to check the loop-body states
            if state in states:
                continue
            if itervar in state.free_symbols:
                return False
            # Don't continue in this direction, as the variable has
            # now been reassigned
            # TODO: Handle case of subset of out_edges
            if all(itervar in e.data.assignments
                   for e in sdfg.out_edges(state)):
                break

        return True
Exemplo n.º 10
0
    def apply(self, sdfg: sd.SDFG):
        """Convert the matched loop into a map over the loop body.

        If the loop body spans more than one state, those states are first
        moved into a new nested SDFG placed in a single new body state. If the
        guard-to-body edge carries assignments, the body is nested once more so
        that the assignments run on an edge inside the nested SDFG. A map over
        the iteration variable and loop range is then added to the body state,
        all source and sink nodes are routed through its entry and exit, and
        the loop edges (exit condition, back-edge, iteration-variable
        assignment) are removed so the body leads directly to the exit state.

        :param sdfg: The SDFG containing the matched loop; modified in place.
        """
        # Obtain loop information
        guard: sd.SDFGState = sdfg.node(self.subgraph[DetectLoop._loop_guard])
        body: sd.SDFGState = sdfg.node(self.subgraph[DetectLoop._loop_begin])
        after: sd.SDFGState = sdfg.node(self.subgraph[DetectLoop._exit_state])

        # Obtain iteration variable, range, and stride
        itervar, (start, end, step), (_, body_end) = find_for_loop(
            sdfg, guard, body, itervar=self.itervar)

        # Find all loop-body states
        # (traversal from the first body state that does not expand body_end)
        states = set([body_end])
        to_visit = [body]
        while to_visit:
            state = to_visit.pop(0)
            if state is body_end:
                continue
            for _, dst, _ in sdfg.out_edges(state):
                if dst not in states:
                    to_visit.append(dst)
            states.add(state)

        # Nest loop-body states
        if len(states) > 1:

            # Find read/write sets
            read_set, write_set = set(), set()
            for state in states:
                rset, wset = state.read_and_write_sets()
                read_set |= rset
                write_set |= wset

            # Add data from edges. This does not depend on any single state,
            # so it runs once after the per-state sets have been collected.
            for src in states:
                for dst in states:
                    for edge in sdfg.edges_between(src, dst):
                        for s in edge.data.free_symbols:
                            if s in sdfg.arrays:
                                read_set.add(s)

            # Find NestedSDFG's unique data
            # (transients with no access node in any state outside the loop)
            rw_set = read_set | write_set
            unique_set = set()
            for name in rw_set:
                if not sdfg.arrays[name].transient:
                    continue
                found = False
                for state in sdfg.states():
                    if state in states:
                        continue
                    for node in state.nodes():
                        if (isinstance(node, nodes.AccessNode) and
                                node.data == name):
                            found = True
                            break
                if not found:
                    unique_set.add(name)

            # Find NestedSDFG's connectors
            read_set = {n for n in read_set if n not in unique_set or not sdfg.arrays[n].transient}
            write_set = {n for n in write_set if n not in unique_set or not sdfg.arrays[n].transient}

            # Create NestedSDFG and add all loop-body states and edges
            # Also, find defined symbols in NestedSDFG
            fsymbols = set(sdfg.free_symbols)
            new_body = sdfg.add_state('single_state_body')
            nsdfg = SDFG("loop_body", constants=sdfg.constants, parent=new_body)
            nsdfg.add_node(body, is_start_state=True)
            body.parent = nsdfg
            exit_state = nsdfg.add_state('exit')
            nsymbols = dict()
            for state in states:
                if state is body:
                    continue
                nsdfg.add_node(state)
                state.parent = nsdfg
            for state in states:
                if state is body:
                    continue
                for src, dst, data in sdfg.in_edges(state):
                    # Symbols assigned on moved edges must be known inside
                    nsymbols.update({s: sdfg.symbols[s] for s in data.assignments.keys() if s in sdfg.symbols})
                    nsdfg.add_edge(src, dst, data)
            nsdfg.add_edge(body_end, exit_state, InterstateEdge())

            # Move guard -> body edge to guard -> new_body
            for src, dst, data, in sdfg.edges_between(guard, body):
                sdfg.add_edge(src, new_body, data)
            # Move body_end -> guard edge to new_body -> guard
            for src, dst, data in sdfg.edges_between(body_end, guard):
                sdfg.add_edge(new_body, dst, data)

            # Delete loop-body states and edges from parent SDFG
            for state in states:
                for e in sdfg.all_edges(state):
                    sdfg.remove_edge(e)
                sdfg.remove_node(state)

            # Add NestedSDFG arrays
            # Shared containers are copied and become non-transient inside;
            # loop-local containers are moved out of the parent SDFG entirely
            for name in read_set | write_set:
                nsdfg.arrays[name] = copy.deepcopy(sdfg.arrays[name])
                nsdfg.arrays[name].transient = False
            for name in unique_set:
                nsdfg.arrays[name] = sdfg.arrays[name]
                del sdfg.arrays[name]

            # Add NestedSDFG node
            cnode = new_body.add_nested_sdfg(nsdfg, None, read_set, write_set)
            if sdfg.parent:
                # Forward symbol mappings of the enclosing nested SDFG node
                for s, m in sdfg.parent_nsdfg_node.symbol_mapping.items():
                    if s not in cnode.symbol_mapping:
                        cnode.symbol_mapping[s] = m
                        nsdfg.add_symbol(s, sdfg.symbols[s])
            for name in read_set:
                r = new_body.add_read(name)
                new_body.add_edge(
                    r, None, cnode, name,
                    memlet.Memlet.from_array(name, sdfg.arrays[name]))
            for name in write_set:
                w = new_body.add_write(name)
                new_body.add_edge(
                    cnode, name, w, None,
                    memlet.Memlet.from_array(name, sdfg.arrays[name]))

            # Fix SDFG symbols
            for sym in sdfg.free_symbols - fsymbols:
                del sdfg.symbols[sym]
            for sym, dtype in nsymbols.items():
                nsdfg.symbols[sym] = dtype

            # Change body state reference
            body = new_body

        # The comparison may yield a symbolic relational instead of a bool;
        # the explicit `== True` only takes the branch when it is known true.
        if (step < 0) == True:
            # If step is negative, we have to flip start and end to produce a
            # correct map with a positive increment
            start, end, step = end, start, -step

        # If necessary, make a nested SDFG with assignments
        isedge = sdfg.edges_between(guard, body)[0]
        symbols_to_remove = set()
        if len(isedge.data.assignments) > 0:
            nsdfg = helpers.nest_state_subgraph(
                sdfg, body, gr.SubgraphView(body, body.nodes()))
            for sym in isedge.data.free_symbols:
                if sym in nsdfg.symbol_mapping or sym in nsdfg.in_connectors:
                    continue
                if sym in sdfg.symbols:
                    nsdfg.symbol_mapping[sym] = symbolic.pystr_to_symbolic(sym)
                    nsdfg.sdfg.add_symbol(sym, sdfg.symbols[sym])
                elif sym in sdfg.arrays:
                    if sym in nsdfg.sdfg.arrays:
                        raise NotImplementedError
                    # Pass the container into the nested SDFG as a new input
                    rnode = body.add_read(sym)
                    nsdfg.add_in_connector(sym)
                    desc = copy.deepcopy(sdfg.arrays[sym])
                    desc.transient = False
                    nsdfg.sdfg.add_datadesc(sym, desc)
                    body.add_edge(rnode, None, nsdfg, sym, memlet.Memlet(sym))

            # Move the assignments onto an edge from a new initial state
            # inside the nested SDFG
            nstate = nsdfg.sdfg.node(0)
            init_state = nsdfg.sdfg.add_state_before(nstate)
            nisedge = nsdfg.sdfg.edges_between(init_state, nstate)[0]
            nisedge.data.assignments = isedge.data.assignments
            symbols_to_remove = set(nisedge.data.assignments.keys())
            for k in nisedge.data.assignments.keys():
                if k in nsdfg.symbol_mapping:
                    del nsdfg.symbol_mapping[k]
            isedge.data.assignments = {}

        # Captured before the map nodes are added, so that the map entry and
        # exit are not themselves treated as sources or sinks
        source_nodes = body.source_nodes()
        sink_nodes = body.sink_nodes()

        map = nodes.Map(body.label + "_map", [itervar], [(start, end, step)])
        entry = nodes.MapEntry(map)
        exit = nodes.MapExit(map)
        body.add_node(entry)
        body.add_node(exit)

        # If the map uses symbols from data containers, instantiate reads
        containers_to_read = entry.free_symbols & sdfg.arrays.keys()
        for rd in containers_to_read:
            # We are guaranteed that this is always a scalar, because
            # can_be_applied makes sure there are no sympy functions in each of
            # the loop expressions
            access_node = body.add_read(rd)
            body.add_memlet_path(access_node,
                                 entry,
                                 dst_conn=rd,
                                 memlet=memlet.Memlet(rd))

        # Reroute all memlets through the entry and exit nodes
        for n in source_nodes:
            if isinstance(n, nodes.AccessNode):
                for e in body.out_edges(n):
                    body.remove_edge(e)
                    body.add_edge_pair(entry,
                                       e.dst,
                                       n,
                                       e.data,
                                       internal_connector=e.dst_conn)
            else:
                # Non-data sources are attached to the entry by an empty memlet
                body.add_nedge(entry, n, memlet.Memlet())
        for n in sink_nodes:
            if isinstance(n, nodes.AccessNode):
                for e in body.in_edges(n):
                    body.remove_edge(e)
                    body.add_edge_pair(exit,
                                       e.src,
                                       n,
                                       e.data,
                                       internal_connector=e.src_conn)
            else:
                # Non-data sinks are attached to the exit by an empty memlet
                body.add_nedge(n, exit, memlet.Memlet())

        # Get rid of the loop exit condition edge
        after_edge = sdfg.edges_between(guard, after)[0]
        sdfg.remove_edge(after_edge)

        # Remove the assignment on the edge to the guard
        for e in sdfg.in_edges(guard):
            if itervar in e.data.assignments:
                del e.data.assignments[itervar]

        # Remove the condition on the entry edge
        condition_edge = sdfg.edges_between(guard, body)[0]
        condition_edge.data.condition = CodeBlock("1")

        # Get rid of backedge to guard
        sdfg.remove_edge(sdfg.edges_between(body, guard)[0])

        # Route body directly to after state, maintaining any other assignments
        # it might have had
        sdfg.add_edge(
            body, after,
            sd.InterstateEdge(assignments=after_edge.data.assignments))

        # If this had made the iteration variable a free symbol, we can remove
        # it from the SDFG symbols
        if itervar in sdfg.free_symbols:
            sdfg.remove_symbol(itervar)
        for sym in symbols_to_remove:
            if helpers.is_symbol_unused(sdfg, sym):
                sdfg.remove_symbol(sym)
Exemplo n.º 11
0
    def apply(self, sdfg: sd.SDFG):
        """
        Rewrite the matched for-loop as a map scope inside the loop body state.

        The guard, body and exit states are looked up through
        ``self.subgraph``. Assignments found on the guard-to-body edge are
        first moved into a nested SDFG wrapping the body; the body contents
        are then enclosed between a map entry/exit pair that iterates over the
        loop range, and the loop control edges are removed or neutralized.

        :param sdfg: The SDFG that contains the loop; it is modified in place.
        :raises NotImplementedError: If a data container referenced by the
                                     guard-to-body edge already exists inside
                                     the nested SDFG.
        """
        # Resolve the three states that make up the detected loop
        guard: sd.SDFGState = sdfg.node(self.subgraph[DetectLoop._loop_guard])
        body: sd.SDFGState = sdfg.node(self.subgraph[DetectLoop._loop_begin])
        after: sd.SDFGState = sdfg.node(self.subgraph[DetectLoop._exit_state])

        # Name of the loop variable and its (start, end, step) range
        itervar, (start, end, step), _ = find_for_loop(sdfg, guard, body)

        # Only an explicit True triggers the flip (presumably the comparison
        # can yield a symbolic relation instead of a bool -- TODO confirm)
        if (step < 0) == True:
            # Swap the bounds and negate the step so that the resulting map
            # uses a positive increment
            start, end, step = end, start, -step

        # When the loop entry edge carries assignments, wrap the body in a
        # nested SDFG that can host them
        entry_edge = sdfg.edges_between(guard, body)[0]
        symbols_to_remove = set()
        if len(entry_edge.data.assignments) > 0:
            nested = helpers.nest_state_subgraph(
                sdfg, body, gr.SubgraphView(body, body.nodes()))

            # Forward everything the entry edge refers to into the nested SDFG
            for name in entry_edge.data.free_symbols:
                already_forwarded = (name in nested.symbol_mapping
                                     or name in nested.in_connectors)
                if already_forwarded:
                    continue

                if name in sdfg.symbols:
                    # Plain symbol: map it one-to-one and declare it inside
                    nested.symbol_mapping[name] = symbolic.pystr_to_symbolic(
                        name)
                    nested.sdfg.add_symbol(name, sdfg.symbols[name])
                    continue

                if name not in sdfg.arrays:
                    continue

                # Data container: pass it in through a new input connector
                if name in nested.sdfg.arrays:
                    raise NotImplementedError
                read_node = body.add_read(name)
                nested.add_in_connector(name)
                inner_desc = copy.deepcopy(sdfg.arrays[name])
                inner_desc.transient = False
                nested.sdfg.add_datadesc(name, inner_desc)
                body.add_edge(read_node, None, nested, name,
                              memlet.Memlet(name))

            # Prepend a state to the nested SDFG and move the assignments onto
            # the edge leading out of it
            inner_first = nested.sdfg.node(0)
            inner_init = nested.sdfg.add_state_before(inner_first)
            inner_edge = nested.sdfg.edges_between(inner_init, inner_first)[0]
            inner_edge.data.assignments = entry_edge.data.assignments
            symbols_to_remove = set(inner_edge.data.assignments)

            # Symbols assigned inside must not also be mapped from outside
            for assigned in inner_edge.data.assignments:
                if assigned in nested.symbol_mapping:
                    del nested.symbol_mapping[assigned]
            entry_edge.data.assignments = {}

        # Record the scope boundaries before the map nodes are inserted
        source_nodes = body.source_nodes()
        sink_nodes = body.sink_nodes()

        loop_map = nodes.Map(body.label + "_map", [itervar],
                             [(start, end, step)])
        map_entry = nodes.MapEntry(loop_map)
        map_exit = nodes.MapExit(loop_map)
        body.add_node(map_entry)
        body.add_node(map_exit)

        # Feed data containers that appear in the map's symbols into the entry
        containers_to_read = map_entry.free_symbols & sdfg.arrays.keys()
        for container in containers_to_read:
            # Per the original note these are always scalars, because
            # can_be_applied rules out sympy functions in the loop
            # expressions (not verifiable from this method alone)
            scalar_read = body.add_read(container)
            body.add_memlet_path(scalar_read,
                                 map_entry,
                                 dst_conn=container,
                                 memlet=memlet.Memlet(container))

        # Former source nodes now receive their data through the map entry
        for node in source_nodes:
            if not isinstance(node, nodes.AccessNode):
                # Non-access nodes are attached with an empty memlet
                body.add_nedge(map_entry, node, memlet.Memlet())
                continue
            for edge in body.out_edges(node):
                body.remove_edge(edge)
                body.add_edge_pair(map_entry,
                                   edge.dst,
                                   node,
                                   edge.data,
                                   internal_connector=edge.dst_conn)

        # Former sink nodes now deliver their data through the map exit
        for node in sink_nodes:
            if not isinstance(node, nodes.AccessNode):
                body.add_nedge(node, map_exit, memlet.Memlet())
                continue
            for edge in body.in_edges(node):
                body.remove_edge(edge)
                body.add_edge_pair(map_exit,
                                   edge.src,
                                   node,
                                   edge.data,
                                   internal_connector=edge.src_conn)

        # Drop the guard -> after edge (the loop exit condition); its data is
        # still needed below for the replacement edge
        after_edge = sdfg.edges_between(guard, after)[0]
        sdfg.remove_edge(after_edge)

        # The iteration variable is no longer assigned on the way to the guard
        for incoming in sdfg.in_edges(guard):
            if itervar in incoming.data.assignments:
                del incoming.data.assignments[itervar]

        # Entering the body becomes unconditional
        condition_edge = sdfg.edges_between(guard, body)[0]
        condition_edge.data.condition = CodeBlock("1")

        # Remove the back-edge from the body to the guard
        back_edge = sdfg.edges_between(body, guard)[0]
        sdfg.remove_edge(back_edge)

        # Continue from the body straight to the after state, keeping the
        # assignments that the removed exit edge carried
        exit_assignments = after_edge.data.assignments
        sdfg.add_edge(body, after,
                      sd.InterstateEdge(assignments=exit_assignments))

        # If the iteration variable ended up as a free symbol of the SDFG,
        # remove it from the SDFG symbols
        if itervar in sdfg.free_symbols:
            sdfg.remove_symbol(itervar)

        # Likewise drop symbols that were moved into the nested SDFG and are
        # no longer used outside
        for moved in symbols_to_remove:
            if helpers.is_symbol_unused(sdfg, moved):
                sdfg.remove_symbol(moved)
Exemplo n.º 12
0
    def apply(self, _, sdfg: sd.SDFG):
        """
        Peel ``self.count`` iterations off the detected loop.

        When ``self.begin`` is set, copies of the loop states are chained in
        front of the guard and the loop's initial value is advanced by
        ``self.count`` strides. Otherwise both guard conditions are rewritten
        through ``self._modify_cond`` and the copies are chained between the
        guard and the exit state.

        :param _: Unused.
        :param sdfg: The SDFG that contains the loop; it is modified in place.
        """
        # Characters that are substituted in generated state names; each one
        # is replaced by a single letter
        name_table = str.maketrans({'-': 'm', '+': 'p', '*': 'M', '/': 'D'})

        ####################################################################
        # Gather loop information
        guard: sd.SDFGState = self.loop_guard
        begin: sd.SDFGState = self.loop_begin
        after_state: sd.SDFGState = self.exit_state

        # Edges leaving the guard, plus iteration variable, range and the
        # surrounding loop structure
        condition_edge = sdfg.edges_between(guard, begin)[0]
        not_condition_edge = sdfg.edges_between(guard, after_state)[0]
        itervar, rng, loop_struct = find_for_loop(sdfg, guard, begin)

        # Every state reachable from the first loop state without passing
        # through the guard
        reachable = sdutil.dfs_conditional(
            sdfg,
            sources=[begin],
            condition=lambda _, child: child != guard)
        loop_states = list(reachable)
        first_id = loop_states.index(begin)
        last_state = loop_struct[1]
        last_id = loop_states.index(last_state)
        loop_subgraph = gr.SubgraphView(sdfg, loop_states)

        ####################################################################
        # Transform

        if self.begin:
            # Peeling from the front: move the loop's initial value forward
            # by `count` strides on every edge that enters the guard
            before_states = loop_struct[0]
            init_edges = []
            for predecessor in before_states:
                entering = sdfg.edges_between(predecessor, guard)[0]
                entering.data.assignments[itervar] = str(rng[0] +
                                                         self.count * rng[2])
                init_edges.append(entering)
            tails = before_states

            # Create `count` copies, each with a fixed iteration value
            for i in range(self.count):
                state_name: str = 'start_' + itervar + str(i * rng[2])
                state_name = state_name.translate(name_table)
                copies = self.instantiate_loop(
                    sdfg,
                    loop_states,
                    loop_subgraph,
                    itervar,
                    rng[0] + i * rng[2],
                    state_name,
                )

                # Chain the copy after the previous tail(s) unconditionally
                for tail in tails:
                    sdfg.add_edge(tail, copies[first_id], sd.InterstateEdge())
                tails = [copies[last_id]]

            # If anything was peeled, the guard is now entered from the last
            # copy instead of from the original predecessor states
            for tail in tails:
                if tail not in before_states:
                    for stale in init_edges:
                        sdfg.remove_edge(stale)
                    sdfg.add_edge(tail, guard, init_edges[0].data)
            return

        # Peeling from the back: rewrite both guard conditions, then place
        # the copies between the guard and the exit state
        itervar_sym = pystr_to_symbolic(itervar)
        continue_cond = self._modify_cond(condition_edge.data.condition,
                                          itervar, rng[2])
        condition_edge.data.condition = CodeBlock(continue_cond)
        leave_cond = self._modify_cond(not_condition_edge.data.condition,
                                       itervar, rng[2])
        not_condition_edge.data.condition = CodeBlock(leave_cond)
        head = after_state

        # Create `count` copies, building the chain backwards from the exit
        # state; iteration values are offsets from the symbolic loop variable
        for i in reversed(range(self.count)):
            state_name: str = 'end_' + itervar + str(-i * rng[2])
            state_name = state_name.translate(name_table)
            copies = self.instantiate_loop(
                sdfg,
                loop_states,
                loop_subgraph,
                itervar,
                itervar_sym + i * rng[2],
                state_name,
            )

            # Chain the copy in front of the current head unconditionally
            sdfg.add_edge(copies[last_id], head, sd.InterstateEdge())
            head = copies[first_id]

        # If anything was peeled, the guard's exit edge now leads to the first
        # copy instead of directly to the exit state
        if head != after_state:
            sdfg.remove_edge(not_condition_edge)
            sdfg.add_edge(guard, head, not_condition_edge.data)
Exemplo n.º 13
0
    def apply(self, sdfg):
        """
        Replace the matched loop with a chain of per-iteration state copies.

        One copy of the loop states is created via ``self.instantiate_loop``
        for each iteration value; the copies are connected with unconditional
        edges, hooked up to the states before and after the loop, and the
        guard and original loop states are removed.

        :param sdfg: The SDFG that contains the loop; it is modified in place.
        :raises NotImplementedError: If ``self.count`` is not 0.
        :raises TypeError: If evaluating the stride or the iteration span
                           raises a TypeError.
        """
        # Resolve the states that make up the detected loop
        guard: sd.SDFGState = sdfg.node(self.subgraph[DetectLoop._loop_guard])
        begin: sd.SDFGState = sdfg.node(self.subgraph[DetectLoop._loop_begin])
        after_state: sd.SDFGState = sdfg.node(
            self.subgraph[DetectLoop._exit_state])

        # Iteration variable, range and stride, plus the state(s) preceding
        # the loop and the last loop state
        itervar, rng, loop_struct = find_for_loop(sdfg, guard, begin)

        # Only complete unrolling is supported for now
        if self.count != 0:
            raise NotImplementedError  # TODO(later)

        # Every state reachable from the first loop state without passing
        # through the guard
        reachable = sdutil.dfs_conditional(
            sdfg,
            sources=[begin],
            condition=lambda _, child: child != guard)
        loop_states = list(reachable)
        first_id = loop_states.index(begin)
        last_state = loop_struct[1]
        last_id = loop_states.index(last_state)
        loop_subgraph = gr.SubgraphView(sdfg, loop_states)

        try:
            start, end, stride = rng
            stride = symbolic.evaluate(stride, sdfg.constants)
            span = end - start + 1
            loop_diff = int(symbolic.evaluate(span, sdfg.constants))
            bound_flags = [symbolic.issymbolic(bound) for bound in rng[:2]]
            is_symbolic = any(bound_flags)
        except TypeError:
            raise TypeError('Loop difference and strides cannot be symbolic.')

        # (first state, last state) of each unrolled iteration, in order
        iterations = []

        for offset in range(0, loop_diff, stride):
            # Copy the loop states with the iteration variable fixed; a label
            # is only passed along when the loop bounds are symbolic
            label = str(offset) if is_symbolic else None
            copies = self.instantiate_loop(sdfg, loop_states, loop_subgraph,
                                           itervar, start + offset, label)
            iter_first = copies[first_id]
            iter_last = copies[last_id]

            # Chain this iteration after the previous one unconditionally
            if iterations:
                prev_last = iterations[-1][1]
                sdfg.add_edge(prev_last, iter_first, sd.InterstateEdge())

            iterations.append((iter_first, iter_last))

        # Assignments on the guard's exit edge are carried over to the new
        # edge into the after state
        exit_edge = sdfg.edges_between(guard, after_state)[0]
        after_assignments = exit_edge.data.assignments

        # Hook the unrolled chain up to its surroundings without conditions
        if iterations:
            chain_head = iterations[0][0]
            chain_tail = iterations[-1][1]
            for predecessor in loop_struct[0]:
                sdfg.add_edge(predecessor, chain_head, sd.InterstateEdge())
            sdfg.add_edge(chain_tail, after_state,
                          sd.InterstateEdge(assignments=after_assignments))

        # The guard and the original loop states are no longer needed
        sdfg.remove_nodes_from([guard] + loop_states)