Example #1
0
    def apply(self, graph: SDFGState, sdfg: SDFG):
        """
        Route the accumulation between the two map exits through a local
        array and, if an identity value is set, initialize that array.

        When ``self.array`` is unset or empty, the data name is taken from
        the first edge between ``self.map_exit`` and ``self.outer_map_exit``
        whose memlet carries a write-conflict resolution (``wcr``).
        ``OutLocalStorage`` is then applied between those two nodes. If
        ``self.identity`` is ``None`` a warning is issued and the method
        stops there; otherwise the inner map scope is moved into a nested
        SDFG, and a state that writes ``self.identity`` to every element of
        the local array is inserted before the nested SDFG's first state.

        :param graph: State that is searched for the edges between the two
                      map exits.
        :param sdfg: SDFG that the local-storage transformation is applied
                     to and that owns the resulting array descriptor.
        :raises StopIteration: If no array was given and no edge between the
                               map exits has a ``wcr``.
        """
        inner_exit = self.map_exit
        outer_exit = self.outer_map_exit

        # Pick the array to accumulate into: explicit choice first, else the
        # first WCR edge found between the two exits.
        target = self.array
        if target is None or len(target) == 0:
            wcr_edges = (edge
                         for edge in graph.edges_between(inner_exit, outer_exit)
                         if edge.data.wcr is not None)
            target = next(wcr_edges).data.data

        # Imported here rather than at module level to avoid an import loop
        from dace.transformation.dataflow.local_storage import OutLocalStorage

        local_node: nodes.AccessNode = OutLocalStorage.apply_to(
            sdfg, {'array': target},
            verify=False,
            save=False,
            node_a=inner_exit,
            node_b=outer_exit)

        # Without an identity value there is nothing to initialize with
        if self.identity is None:
            warnings.warn('AccumulateTransient did not properly initialize '
                          'newly-created transient!')
            return

        state: SDFGState = sdfg.node(self.state_id)
        inner_entry = state.entry_node(inner_exit)

        # Nest the complete inner map scope (entry, exit and all in between)
        scope_nodes = ({inner_entry, inner_exit}
                       | state.all_nodes_between(inner_entry, inner_exit))
        nested: NestedSDFG = nest_state_subgraph(sdfg=sdfg,
                                                 state=state,
                                                 subgraph=SubgraphView(
                                                     state, scope_nodes))

        # The initialization state goes in front of the nested SDFG's first
        # state
        first_state: SDFGState = nested.sdfg.nodes()[0]
        init_state = nested.sdfg.add_state_before(first_state)

        local_desc: Array = sdfg.arrays[local_node.data]
        dims = list(local_desc.shape)

        # One map parameter per dimension of the local array
        index_names = ['_o%d' % dim_no for dim_no in range(len(dims))]
        init_ranges = {
            index: '0:%s' % symstr(extent)
            for index, extent in zip(index_names, dims)
        }
        init_memlet = dace.Memlet.simple(data=local_node.data,
                                         subset_str=','.join(index_names))

        init_state.add_mapped_tasklet(name='acctrans_init',
                                      map_ranges=init_ranges,
                                      inputs={},
                                      code='out = %s' % self.identity,
                                      outputs={'out': init_memlet},
                                      external_edges=True)
Example #2
0
    def apply(self, graph: SDFGState, sdfg: SDFG):
        """
        Splice an access node for a local array into the edge that connects
        ``self.node_a`` to ``self.node_b``.

        When ``self.array`` is unset or empty, the data name is taken from
        the first edge between the two nodes whose memlet names data and has
        no ``wcr``. If no edge between the nodes carries the chosen array,
        the first edge is used instead and a warning is issued. With
        ``self.create_array`` set, a transient named ``self.prefix`` plus the
        original data name is added to the SDFG, shaped after the memlet's
        over-approximated bounding box; otherwise that name is used directly
        without creating a descriptor. The original edge is replaced by two
        edges through the new access node, and every memlet in the memlet
        tree of one of them is offset and re-pointed at the local array.

        The chosen name and node are also stored in ``self._local_name`` and
        ``self._data_node``.

        :param graph: State whose edges are rewired.
        :param sdfg: SDFG that receives the new transient (if created).
        :return: The newly created access node.
        :raises NameError: If there is no edge between the two nodes.
        :raises StopIteration: If no array was given and no edge between the
                               nodes names data without a ``wcr``.
        """
        src = self.node_a
        dst = self.node_b
        name_prefix = self.prefix

        # The scope relation of the two nodes decides which of the new edges
        # the subset offsets are propagated from
        scopes = graph.scope_dict()
        propagate_forward = sd.scope_contains_scope(scopes, src, dst)

        wanted = self.array
        if wanted is None or len(wanted) == 0:
            plain_edges = (edge for edge in graph.edges_between(src, dst)
                           if edge.data.data is not None
                           and edge.data.wcr is None)
            wanted = next(plain_edges).data.data

        # Prefer the edge carrying the requested array; otherwise settle for
        # the first edge between the two nodes (with a warning)
        original_edge = next((edge for edge in graph.edges_between(src, dst)
                              if wanted == edge.data.data), None)
        if original_edge is None:
            original_edge = next(iter(graph.edges_between(src, dst)), None)
            if original_edge is not None:
                warnings.warn('Array %s not found! Using array %s instead.' %
                              (wanted, original_edge.data.data))
                wanted = original_edge.data.data
        if original_edge is None:
            raise NameError('Array %s not found!' % wanted)
        invariant_memlet = original_edge.data

        local_name = name_prefix + invariant_memlet.data
        if self.create_array:
            # Register a transient with the over-approximated bounding box of
            # the memlet as its shape
            local_shape = [
                symbolic.overapproximate(extent).simplify()
                for extent in invariant_memlet.bounding_box_size()
            ]
            local_name, _ = sdfg.add_transient(
                name=local_name,
                shape=local_shape,
                dtype=sdfg.arrays[invariant_memlet.data].dtype,
                find_new_name=True)

        data_node = nodes.AccessNode(local_name)
        # Kept on the instance so that other transformations can use them
        self._local_name = local_name
        self._data_node = data_node

        into_local = copy.deepcopy(invariant_memlet)
        out_of_local = copy.deepcopy(invariant_memlet)
        # First component of every range in the original subset
        # (presumably the range start -- confirm against the subsets module)
        offset = subsets.Indices(
            [rng[0] for rng in invariant_memlet.subset])

        # Swap the original edge for a pair of edges through the access node
        graph.remove_edge(original_edge)
        inbound = graph.add_edge(src, original_edge.src_conn, data_node, None,
                                 into_local)
        outbound = graph.add_edge(data_node, None, dst, original_edge.dst_conn,
                                  out_of_local)
        new_edge = outbound if propagate_forward else inbound

        # Offset and rename every memlet in the tree of the selected edge
        # (the selected edge itself included)
        for tree_edge in graph.memlet_tree(new_edge):
            tree_edge.data.subset.offset(offset, True)
            tree_edge.data.data = local_name

        return data_node