Exemplo n.º 1
0
def get_vertical_loop_section_sdfg(section: "VerticalLoopSection") -> SDFG:
    """Build an SDFG that runs the horizontal executions of *section* in sequence.

    The result is a linear chain of states connected by unconditional
    ``InterstateEdge``s: an empty start state, followed by one state per
    horizontal execution. Each of those holds a single
    ``HorizontalExecutionLibraryNode`` wrapping that execution.

    :param section: vertical loop section whose ``horizontal_executions`` are chained.
    :return: the newly created SDFG.
    """
    # Imported locally, as in the original implementation (presumably to avoid an
    # import cycle -- not verifiable from this file).
    from gtc.dace.nodes import HorizontalExecutionLibraryNode

    result = SDFG(f"VerticalLoopSection_{id(section)}")
    tail = result.add_state("start_state", is_start_state=True)
    for execution in section.horizontal_executions:
        state = result.add_state(f"HorizontalExecution_{id(execution)}_state")
        result.add_edge(tail, state, InterstateEdge())
        state.add_node(HorizontalExecutionLibraryNode(oir_node=execution))
        # The freshly added state becomes the predecessor of the next one.
        tail = state
    return result
Exemplo n.º 2
0
def mkc(sdfg: dace.SDFG,
        state_before,
        src_name,
        dst_name,
        src_storage=None,
        dst_storage=None,
        src_shape=None,
        dst_shape=None,
        copy_expr=None,
        src_loc=None,
        dst_loc=None):
    """
    Helper MaKe_Copy that creates and appends a state performing exactly one copy.

    If a provided array name already exists in ``sdfg``, the existing array is
    reused and all newly passed values for it (storage, shape, location) are ignored.

    :param sdfg: SDFG to which the copy state (and any missing arrays) is added.
    :param state_before: state after which the new state is appended; if ``None``
                         the new state is added as the start state.
    :param src_name: name of the source array.
    :param dst_name: name of the destination array.
    :param src_storage: storage type used when the source array has to be created.
    :param dst_storage: storage type used when the destination array has to be created.
    :param src_shape: shape used when the source array has to be created; also the
                      shape of the returned source NumPy buffer.
    :param dst_shape: like ``src_shape``, for the destination array.
    :param copy_expr: memlet expression of the copy; defaults to ``src_name``.
    :param src_loc: optional ``(memorytype, bank)`` pair written into the newly
                    created source array's ``location``.
    :param dst_loc: like ``src_loc``, for the destination array.
    :return: ``(state, a_np_arr, b_np_arr)`` where the arrays are zero-filled
             ``np.int32`` buffers of ``src_shape`` / ``dst_shape``. Each is ``None``
             when its shape is not given or cannot be turned into a NumPy array.
    """
    if copy_expr is None:
        copy_expr = src_name
    if state_before is None:
        state = sdfg.add_state(is_start_state=True)
    else:
        state = sdfg.add_state_after(state_before)

    def mkarray(name, shape, storage, loc):
        # Reuse an existing array unchanged (see docstring).
        if name in sdfg.arrays:
            return sdfg.arrays[name]
        # Arrays placed in one of the FPGA storage types are created as transients.
        is_transient = storage in _FPGA_STORAGE_TYPES
        arr = sdfg.add_array(name,
                             shape,
                             dace.int32,
                             storage,
                             transient=is_transient)
        if loc is not None:
            arr[1].location["memorytype"] = loc[0]
            arr[1].location["bank"] = loc[1]
        return arr

    def mk_np_buffer(shape):
        # Best effort: a shape NumPy cannot interpret (e.g. one with symbolic sizes)
        # yields None instead of failing the helper. Catch Exception rather than
        # using a bare except, so KeyboardInterrupt/SystemExit still propagate.
        if shape is None:
            return None
        try:
            return np.zeros(shape, dtype=np.int32)
        except Exception:
            return None

    mkarray(src_name, src_shape, src_storage, src_loc)
    mkarray(dst_name, dst_shape, dst_storage, dst_loc)

    src_access = state.add_access(src_name)
    dst_access = state.add_access(dst_name)
    state.add_edge(src_access, None, dst_access, None, mem.Memlet(copy_expr))

    return (state, mk_np_buffer(src_shape), mk_np_buffer(dst_shape))
Exemplo n.º 3
0
def _expand_and_finalize_sdfg(stencil_ir: gtir.Stencil, sdfg: dace.SDFG,
                              layout_map) -> dace.SDFG:
    """Expand the library nodes of *sdfg* and apply the pre/post-expansion passes.

    If every entry of the stencil's ``field_info`` is ``None``, the stencil has
    no effect. In that case *sdfg* is discarded and a fresh SDFG is returned,
    holding a single empty state named after the stencil.

    Otherwise *sdfg* is modified in place. Every transient array gets a
    persistent allocation lifetime, then these run in order:
    ``_pre_expand_trafos``, recursive library-node expansion,
    ``_specialize_transient_strides`` and ``_post_expand_trafos``.

    :return: the resulting SDFG.
    """
    field_info = make_args_data_from_gtir(GtirPipeline(stencil_ir)).field_info

    has_effect = any(info is not None for info in field_info.values())
    if not has_effect:
        # Stencil without effect: replace it by an empty one-state SDFG.
        empty = dace.SDFG(stencil_ir.name)
        empty.add_state(stencil_ir.name)
        return empty

    transient_descs = (desc for desc in sdfg.arrays.values() if desc.transient)
    for desc in transient_descs:
        desc.lifetime = dace.AllocationLifetime.Persistent

    _pre_expand_trafos(sdfg)
    sdfg.expand_library_nodes(recursive=True)
    _specialize_transient_strides(sdfg, layout_map=layout_map)
    _post_expand_trafos(sdfg)
    return sdfg
Exemplo n.º 4
0
def split_condition_interstate_edges(sdfg: dace.SDFG):
    """Split interstate edges that carry both a condition and assignments.

    Each such edge ``src -> dst`` is replaced by ``src -> interim -> dst``. The
    first new edge keeps only the condition and the second only the assignments.
    Edges are processed in the order ``sdfg.edges()`` returns them, so the
    interim states are created in a reproducible order.

    :param sdfg: SDFG, modified in place.
    """
    # Collect first, because the edge set is mutated below. A list is used instead
    # of a set: set iteration order for identity-hashed edge objects varies from
    # run to run, which made the creation order of interim states nondeterministic.
    edges_to_split = [
        isedge for isedge in sdfg.edges()
        if not isedge.data.is_unconditional() and isedge.data.assignments
    ]

    for ise in edges_to_split:
        sdfg.remove_edge(ise)
        interim = sdfg.add_state()
        sdfg.add_edge(ise.src, interim,
                      dace.InterstateEdge(ise.data.condition))
        sdfg.add_edge(interim, ise.dst,
                      dace.InterstateEdge(assignments=ise.data.assignments))
Exemplo n.º 5
0
class BaseOirSDFGBuilder(ABC):
    """Build a single-state SDFG from a sequence of library nodes.

    Nodes are added in the order given to ``build``. Data dependencies are
    expressed through per-field access nodes connected to the library nodes via
    ``IN_<name>`` / ``OUT_<name>`` connectors.

    Write-after-write and write-after-read hazards are expressed by
    ordering-only edges (empty memlets). ``finalize`` drops those edges again
    when another path between the same nodes already exists.

    Subclasses supply the vertical (K) handling through the abstract methods.
    """

    # Whether ``Temporary`` declarations become transient arrays (see add_arrays).
    has_transients = True

    def __init__(self, name, stencil: Stencil, nodes):
        self._stencil = stencil
        self._sdfg = SDFG(name)
        self._state = self._sdfg.add_state(name + "_state")
        self._extents = nodes_extent_calculation(nodes)

        self._dtypes = {
            decl.name: decl.dtype
            for decl in stencil.declarations + stencil.params
        }
        self._axes = {
            decl.name: decl.dimensions
            for decl in stencil.declarations + stencil.params
            if isinstance(decl, FieldDecl)
        }

        # field name -> IntervalMapping of the most recent writing / reading nodes
        self._recent_write_acc: Dict[str, IntervalMapping] = dict()
        self._recent_read_acc: Dict[str, IntervalMapping] = dict()

        # field name -> access nodes of that field; the source node (if any) is
        # inserted first and new sinks are appended, so the last one is the
        # current sink.
        self._access_nodes: Dict[str, List[dace.nodes.AccessNode]] = dict()
        # id(oir node) -> access collection computed for it
        self._access_collection_cache: Dict[
            int, AccessCollector.CartesianAccessCollection] = dict()
        self._source_nodes: Dict[str, dace.nodes.AccessNode] = dict()
        # ordering-only edges that finalize() removes again if redundant
        self._delete_candidates: List[MultiConnectorEdge] = list()

    def _access_space_to_subset(self, name, access_space):
        """Return the I/J range strings for field *name*, shifted by its extent origin.

        Only the horizontal axes the field actually has get an entry.
        """
        extent = self._extents[name]
        origin = (extent[0][0], extent[1][0])
        subsets = []
        if self._axes[name][0]:
            subsets.append("{start}:__I{end:+d}".format(
                start=origin[0] + access_space[0][0],
                end=origin[0] + access_space[0][1]))
        if self._axes[name][1]:
            subsets.append("{start}:__J{end:+d}".format(
                start=origin[1] + access_space[1][0],
                end=origin[1] + access_space[1][1]))
        return subsets

    def _are_nodes_ordered(self, name, node1, node2):
        """Return True if *node1* precedes *node2* among the access nodes of *name*."""
        assert name in self._access_nodes
        assert node1.data == name
        assert node2.data == name
        return self._access_nodes[name].index(
            node1) < self._access_nodes[name].index(node2)

    def _get_source(self, name):
        """Return the read node of *name*, creating it on first use.

        A newly created source node is registered as the first access node of
        the field.
        """
        if name not in self._source_nodes:
            self._source_nodes[name] = self._state.add_read(name)
            self._access_nodes.setdefault(name, []).insert(
                0, self._source_nodes[name])
        return self._source_nodes[name]

    def _get_new_sink(self, name):
        """Create a new access node for *name* and register it as the current sink."""
        res = self._state.add_access(name)
        self._access_nodes.setdefault(name, []).append(res)
        return res

    def _get_current_sink(self, name):
        """Return the last registered access node of *name*, or None if there is none."""
        if name in self._access_nodes:
            return self._access_nodes[name][-1]
        return None

    def _get_access_collection(
        self,
        node: "Union[HorizontalExecutionLibraryNode, VerticalLoopLibraryNode, SDFG]",
    ) -> AccessCollector.CartesianAccessCollection:
        """Return the cartesian accesses performed by *node*.

        For an SDFG, the accesses of all horizontal-execution and vertical-loop
        nodes in its first state are concatenated. For a vertical loop, those of
        all its section SDFGs are concatenated. Horizontal executions are cached
        by the id of their OIR node.
        """
        if isinstance(node, SDFG):
            res = AccessCollector.CartesianAccessCollection([])
            # Use a distinct loop name so the `node` parameter is not shadowed.
            for child in node.states()[0].nodes():
                if isinstance(
                        child,
                    (HorizontalExecutionLibraryNode, VerticalLoopLibraryNode)):
                    collection = self._get_access_collection(child)
                    res._ordered_accesses.extend(collection._ordered_accesses)
            return res
        elif isinstance(node, HorizontalExecutionLibraryNode):
            key = id(node.oir_node)
            if key not in self._access_collection_cache:
                self._access_collection_cache[key] = AccessCollector.apply(
                    node.oir_node).cartesian_accesses()
            return self._access_collection_cache[key]
        else:
            assert isinstance(node, VerticalLoopLibraryNode)
            res = AccessCollector.CartesianAccessCollection([])
            for _, sdfg in node.sections:
                collection = self._get_access_collection(sdfg)
                res._ordered_accesses.extend(collection._ordered_accesses)
            return res

    def _recent_mapping(self, store, name, interval):
        """Return ``(mapping, interval)`` used to track recent accesses of *name*.

        Creates an empty IntervalMapping in *store* on first use. Fields without
        a K axis are tracked over ``Interval.full()`` regardless of *interval*.
        """
        if name not in store:
            store[name] = IntervalMapping()
        if not self._axes[name][2]:
            interval = Interval.full()
        return store[name], interval

    def _get_recent_reads(self, name, interval):
        mapping, interval = self._recent_mapping(self._recent_read_acc, name,
                                                 interval)
        return mapping[interval]

    def _get_recent_writes(self, name, interval):
        mapping, interval = self._recent_mapping(self._recent_write_acc, name,
                                                 interval)
        return mapping[interval]

    def _set_read(self, name, interval, node):
        mapping, interval = self._recent_mapping(self._recent_read_acc, name,
                                                 interval)
        mapping[interval] = node

    def _set_write(self, name, interval, node):
        mapping, interval = self._recent_mapping(self._recent_write_acc, name,
                                                 interval)
        mapping[interval] = node

    def _reset_writes(self):
        """Forget all recorded writes (used between the forward and reverse passes)."""
        self._recent_write_acc = dict()

    def _add_read_edges(
        self, node,
        collections: List[Tuple[Interval,
                                AccessCollector.CartesianAccessCollection]]):
        """Connect every field read by *node* via an ``IN_<name>`` connector.

        The edge comes from the most recent access node that wrote an
        overlapping (offset-shifted) interval, or from the field's source node
        if there is none. The chosen node is recorded as a recent read.
        """
        read_accesses: Dict[str, dace.nodes.AccessNode] = dict()
        for interval, access_collection in collections:

            for name in access_collection.read_fields():
                for offset in access_collection.read_offsets()[name]:
                    read_interval = interval.shifted(offset[2])
                    for candidate_access in self._get_recent_writes(
                            name, read_interval):
                        if name not in read_accesses or self._are_nodes_ordered(
                                name, read_accesses[name], candidate_access):
                            # candidate_access is downstream from recent_access, therefore candidate is more recent
                            read_accesses[name] = candidate_access

        for interval, access_collection in collections:
            for name in access_collection.read_fields():
                for offset in access_collection.read_offsets()[name]:
                    read_interval = interval.shifted(offset[2])
                    if name not in read_accesses:
                        read_accesses[name] = self._get_source(name)
                    self._set_read(name, read_interval, read_accesses[name])

        for name, recent_access in read_accesses.items():
            node.add_in_connector("IN_" + name)
            self._state.add_edge(recent_access, None, node, "IN_" + name,
                                 dace.Memlet())

    def _add_write_edges(
        self, node,
        collections: List[Tuple[Interval,
                                AccessCollector.CartesianAccessCollection]]):
        """Connect every field written by *node* via an ``OUT_<name>`` connector.

        The current sink is reused unless it does not exist, or (on the first
        write of the field by this node) it was recently read or written in
        the interval, or a path from it to *node* already exists. A path would
        mean that adding ``node -> sink`` creates a cycle. Otherwise a new sink
        is created.
        """
        write_accesses: Dict[str, dace.nodes.AccessNode] = dict()
        for interval, access_collection in collections:
            for name in access_collection.write_fields():
                access_node = self._get_current_sink(name)
                if access_node is None or (
                    (name not in write_accesses) and
                    (access_node in self._get_recent_reads(name, interval)
                     or access_node in self._get_recent_writes(name, interval)
                     or nx.has_path(self._state.nx, access_node, node))):
                    write_accesses[name] = self._get_new_sink(name)
                else:
                    write_accesses[name] = access_node
                self._set_write(name, interval, write_accesses[name])

        for name, access_node in write_accesses.items():
            node.add_out_connector("OUT_" + name)
            self._state.add_edge(node, "OUT_" + name, access_node, None,
                                 dace.Memlet())

    def _add_write_after_write_edges(
        self, node,
        collections: List[Tuple[Interval,
                                AccessCollector.CartesianAccessCollection]]):
        """Add ordering edges from recent writers of *node*'s written fields to *node*."""
        for interval, collection in collections:
            for name in collection.write_fields():
                for src in self._get_recent_writes(name, interval):
                    edge = self._state.add_edge(src, None, node, None,
                                                dace.Memlet())
                    self._delete_candidates.append(edge)

    def _add_write_after_read_edges(
        self, node,
        collections: List[Tuple[Interval,
                                AccessCollector.CartesianAccessCollection]]):
        """Add ordering edges from *node* to later writers of the fields it reads.

        Meant to be called in reverse node order (see ``build``). The recorded
        "recent writes" then belong to nodes that come after *node*.
        """
        for interval, collection in collections:
            for name in collection.read_fields():
                for offset in collection.read_offsets()[name]:
                    read_interval = interval.shifted(offset[2])
                    for dst in self._get_recent_writes(name, read_interval):
                        edge = self._state.add_edge(node, None, dst, None,
                                                    dace.Memlet())
                        self._delete_candidates.append(edge)

        for interval, collection in collections:
            for name in collection.write_fields():
                self._set_write(name, interval, node)

    def add_node(self, node):
        self._state.add_node(node)

    def finalize(self):
        """Prune ordering edges, fill in memlets and arrays, and set access types."""
        # An ordering edge is only kept if no other path connects its endpoints.
        for edge in self._delete_candidates:
            assert edge.src_conn is None
            assert edge.dst_conn is None
            self._state.remove_edge(edge)
            if not nx.has_path(self._state.nx, edge.src, edge.dst):
                self._state.add_edge(edge.src, edge.src_conn, edge.dst,
                                     edge.dst_conn, edge.data)

        self.add_subsets()
        self.add_arrays()

        # Access type is derived from the non-empty memlets on each access node.
        for acc in (n for n in self._state.nodes()
                    if isinstance(n, dace.nodes.AccessNode)):
            is_write = len(self._state.in_edges(acc)) > 0 and all(
                edge.data.data is not None
                for edge in self._state.in_edges(acc))
            is_read = len(self._state.out_edges(acc)) > 0 and all(
                edge.data.data is not None
                for edge in self._state.out_edges(acc))
            if is_read and is_write:
                acc.access = dace.AccessType.ReadWrite
            elif is_read:
                acc.access = dace.AccessType.ReadOnly
            else:
                assert is_write
                acc.access = dace.AccessType.WriteOnly

    def _get_sdfg(self):
        self.finalize()
        return self._sdfg

    def add_arrays(self):
        """Declare scalars as symbols and accessed fields as arrays in the SDFG.

        Fields that are never accessed are skipped. Strides are symbolic:
        ``__<name>_<I|J|K>_stride`` for present axes, then
        ``__<name>_d<dim>_stride`` for data dimensions.
        """
        shapes = self.get_shapes()
        # Computed lazily, once: the set of accessed fields does not change
        # while arrays and symbols are being added.
        accessed_offsets = None
        for decl in self._stencil.params + self._stencil.declarations:
            name = decl.name
            dtype = dace.dtypes.typeclass(
                np.dtype(data_type_to_typestr(self._dtypes[name])).name)
            if isinstance(decl, ScalarDecl):
                self._sdfg.add_symbol(name, stype=dtype)
            else:
                if accessed_offsets is None:
                    accessed_offsets = self._get_access_collection(
                        self._sdfg).offsets()
                if name not in accessed_offsets:
                    continue
                assert name in self._dtypes
                strides = tuple(
                    dace.symbolic.pystr_to_symbolic(f"__{name}_{var}_stride")
                    for is_axis, var in zip(self._axes[name], "IJK") if is_axis
                ) + tuple(
                    dace.symbolic.pystr_to_symbolic(f"__{name}_d{dim}_stride")
                    for dim, _ in enumerate(decl.data_dims))
                self._sdfg.add_array(
                    name,
                    dtype=dtype,
                    shape=shapes[name],
                    strides=strides,
                    transient=isinstance(decl, Temporary)
                    and self.has_transients,
                    lifetime=dace.AllocationLifetime.Persistent,
                )

    def add_subsets(self):
        """Replace the placeholder memlets on connector edges of library nodes.

        Each connected edge gets a memlet whose subset covers the node's access
        space (I/J), the K subset if the subclass provides one, and the full
        range of every data dimension. Input memlets are dynamic for horizontal
        executions whose body contains a ``MaskStmt``.
        """
        decls = {
            decl.name: decl
            for decl in self._stencil.params + self._stencil.declarations
        }
        for node in self._state.nodes():
            if isinstance(node, dace.nodes.LibraryNode):
                access_spaces_input, access_spaces_output = self.get_access_spaces(
                    node)
                k_subset_strs_input, k_subset_strs_output = self.get_k_subsets(
                    node)
                for edge in self._state.in_edges(node) + self._state.out_edges(
                        node):
                    if edge.dst_conn is not None:
                        name = edge.src.data
                        access_space = access_spaces_input[name]
                        subset_str_k = k_subset_strs_input.get(name, None)
                        dynamic = isinstance(
                            node, HorizontalExecutionLibraryNode) and any(
                                isinstance(stmt, oir.MaskStmt)
                                for stmt in node.oir_node.body)
                    elif edge.src_conn is not None:
                        name = edge.dst.data
                        access_space = access_spaces_output[name]
                        subset_str_k = k_subset_strs_output.get(name, None)
                        dynamic = False
                    else:
                        # ordering-only edge: keep its empty memlet
                        continue
                    subset_strs = self._access_space_to_subset(
                        name, access_space)
                    if subset_str_k is not None:
                        subset_strs.append(subset_str_k)
                    for dim in decls[name].data_dims:
                        subset_strs.append(f"0:{dim}")
                    edge.data = dace.Memlet.simple(
                        data=name,
                        subset_str=",".join(subset_strs),
                        dynamic=dynamic)

    @abstractmethod
    def get_k_size(self, name):
        pass

    @abstractmethod
    def add_read_edges(self, node):
        pass

    @abstractmethod
    def add_write_edges(self, node):
        pass

    @abstractmethod
    def add_write_after_read_edges(self, node):
        pass

    @abstractmethod
    def add_write_after_write_edges(self, node):
        pass

    @abstractmethod
    def get_k_subsets(self, node):
        pass

    @abstractmethod
    def get_access_spaces(self, node):
        pass

    @abstractmethod
    def get_shapes(self):
        pass

    @classmethod
    def build(cls, name, stencil: Stencil,
              nodes: List[dace.nodes.LibraryNode]):
        """Build, finalize and validate the SDFG for *nodes* (in program order)."""
        builder = cls(name, stencil, nodes)
        # Forward pass: data edges plus write-after-write ordering.
        for n in nodes:
            builder.add_node(n)
            builder.add_write_after_write_edges(n)
            builder.add_read_edges(n)
            builder.add_write_edges(n)
        # Reverse pass: write-after-read ordering needs the *later* writers.
        builder._reset_writes()
        for n in reversed(nodes):
            builder.add_write_after_read_edges(n)
        res = builder._get_sdfg()
        res.validate()
        return res
Exemplo n.º 6
0
def make_sdfg(specialize):
    """Build the parallel FPGA histogram SDFG.

    The SDFG is a chain of three states connected by unconditional interstate
    edges:

    1. the state returned by ``make_copy_to_fpga_state``;
    2. a "compute" state;
    3. the state returned by ``make_copy_to_host_state``.

    The "compute" state contains three modules connected through streams:

    * read module: a map over ``h`` in ``0:H`` and ``w`` in ``0:W:P`` loads a
      ``P``-wide vector of ``A_device`` into ``A_val``. An unrolled map over
      ``p`` then pushes ``A_val[p]`` into stream ``A_pipes[p]``.
    * compute module: an unrolled map over ``p`` in ``0:P`` instantiates the
      nested SDFG from ``make_compute_nested_sdfg``. Each instance reads
      ``A_pipes[p]`` (``W*H`` accesses) and writes ``hist_pipes[p]``
      (``num_bins`` accesses).
    * write module: a map over ``nb`` in ``0:num_bins`` reduces
      ``hist_pipes[0:P]`` with ``lambda a, b: a + b`` into ``hist_device[nb]``.

    Relies on module-level ``P``, ``H``, ``W``, ``num_bins``, ``dtype`` and
    ``itype`` defined elsewhere in the file.

    :param specialize: if truthy, the current values of ``P``, ``H`` and ``W``
                       are embedded in the SDFG name; otherwise only ``P`` is.
                       Within this function it affects nothing but the name.
    :return: the constructed SDFG.
    """

    if specialize:
        sdfg = SDFG("histogram_fpga_parallel_{}_{}x{}".format(
            P.get(), H.get(), W.get()))
    else:
        sdfg = SDFG("histogram_fpga_parallel_{}".format(P.get()))

    copy_to_fpga_state = make_copy_to_fpga_state(sdfg)

    state = sdfg.add_state("compute")

    # Compute module
    nested_sdfg = make_compute_nested_sdfg(state)
    tasklet = state.add_nested_sdfg(nested_sdfg, sdfg, {"A_pipe_in"},
                                    {"hist_pipe_out"})
    # NOTE(review): "A_pipes" (and below "hist_pipes") is added twice; presumably
    # each call yields a separate access node to the same stream container (one
    # for the producer side, one for the consumer side) -- confirm for this
    # dace version.
    A_pipes_out = state.add_stream("A_pipes",
                                   dtype,
                                   shape=(P, ),
                                   transient=True,
                                   storage=StorageType.FPGA_Local)
    A_pipes_in = state.add_stream("A_pipes",
                                  dtype,
                                  shape=(P, ),
                                  transient=True,
                                  storage=StorageType.FPGA_Local)
    hist_pipes_out = state.add_stream("hist_pipes",
                                      itype,
                                      shape=(P, ),
                                      transient=True,
                                      storage=StorageType.FPGA_Local)
    # One unrolled copy of the nested compute SDFG per stream index p.
    unroll_entry, unroll_exit = state.add_map(
        "unroll_compute", {"p": "0:P"},
        schedule=dace.ScheduleType.FPGA_Device,
        unroll=True)
    # Empty memlets only place the stream access nodes inside the unrolled map scope.
    state.add_memlet_path(unroll_entry, A_pipes_in, memlet=EmptyMemlet())
    state.add_memlet_path(hist_pipes_out, unroll_exit, memlet=EmptyMemlet())
    state.add_memlet_path(A_pipes_in,
                          tasklet,
                          dst_conn="A_pipe_in",
                          memlet=Memlet.simple(A_pipes_in,
                                               "p",
                                               num_accesses="W*H"))
    state.add_memlet_path(tasklet,
                          hist_pipes_out,
                          src_conn="hist_pipe_out",
                          memlet=Memlet.simple(hist_pipes_out,
                                               "p",
                                               num_accesses="num_bins"))

    # Read module
    a_device = state.add_array("A_device", (H, W),
                               dtype,
                               transient=True,
                               storage=dace.dtypes.StorageType.FPGA_Global)
    # w advances in steps of P: each iteration reads one P-wide vector.
    read_entry, read_exit = state.add_map("read_map", {
        "h": "0:H",
        "w": "0:W:P"
    },
                                          schedule=ScheduleType.FPGA_Device)
    a_val = state.add_array("A_val", (P, ),
                            dtype,
                            transient=True,
                            storage=StorageType.FPGA_Local)
    read_unroll_entry, read_unroll_exit = state.add_map(
        "read_unroll", {"p": "0:P"},
        schedule=ScheduleType.FPGA_Device,
        unroll=True)
    read_tasklet = state.add_tasklet("read", {"A_in"}, {"A_pipe"},
                                     "A_pipe = A_in[p]")
    # Vector load of A_device[h, w] (vector length P) into A_val[0].
    state.add_memlet_path(a_device,
                          read_entry,
                          a_val,
                          memlet=Memlet(a_val,
                                        num_accesses=1,
                                        subset=Indices(["0"]),
                                        vector_length=P.get(),
                                        other_subset=Indices(["h", "w"])))
    state.add_memlet_path(a_val,
                          read_unroll_entry,
                          read_tasklet,
                          dst_conn="A_in",
                          memlet=Memlet.simple(a_val,
                                               "0",
                                               veclen=P.get(),
                                               num_accesses=1))
    state.add_memlet_path(read_tasklet,
                          read_unroll_exit,
                          read_exit,
                          A_pipes_out,
                          src_conn="A_pipe",
                          memlet=Memlet.simple(A_pipes_out, "p"))

    # Write module
    hist_pipes_in = state.add_stream("hist_pipes",
                                     itype,
                                     shape=(P, ),
                                     transient=True,
                                     storage=StorageType.FPGA_Local)
    hist_device_out = state.add_array(
        "hist_device", (num_bins, ),
        itype,
        transient=True,
        storage=dace.dtypes.StorageType.FPGA_Global)
    merge_entry, merge_exit = state.add_map("merge", {"nb": "0:num_bins"},
                                            schedule=ScheduleType.FPGA_Device)
    # Combine the P partial histograms (axis 0 of hist_pipes) per bin.
    merge_reduce = state.add_reduce("lambda a, b: a + b", (0, ),
                                    "0",
                                    schedule=ScheduleType.FPGA_Device)
    state.add_memlet_path(hist_pipes_in,
                          merge_entry,
                          merge_reduce,
                          memlet=Memlet.simple(hist_pipes_in,
                                               "0:P",
                                               num_accesses=P))
    state.add_memlet_path(merge_reduce,
                          merge_exit,
                          hist_device_out,
                          memlet=dace.memlet.Memlet.simple(
                              hist_device_out, "nb"))

    copy_to_host_state = make_copy_to_host_state(sdfg)

    # copy to FPGA -> compute -> copy back to host
    sdfg.add_edge(copy_to_fpga_state, state, dace.graph.edges.InterstateEdge())
    sdfg.add_edge(state, copy_to_host_state, dace.graph.edges.InterstateEdge())

    return sdfg
Exemplo n.º 7
0
    def backward(
        forward_node: Node, context: BackwardContext,
        given_gradients: typing.List[typing.Optional[str]],
        required_gradients: typing.List[typing.Optional[str]]
    ) -> typing.Tuple[Node, BackwardResult]:
        """Build the backward pass of a reduce node.

        Only ``dtypes.ReductionType.Sum`` is supported. The incoming gradient is
        scattered back over the input's shape, so every input element receives
        the gradient of the output element it was reduced into. Results are
        accumulated with a ``lambda x, y: x + y`` write-conflict resolution.

        :param forward_node: forward reduce node (its ``wcr`` and ``axes`` are read).
        :param context: backward context providing the forward state/SDFG and the
                        backward state the nested SDFG is added to.
        :param given_gradients: connector names with given gradients; must have length 1.
        :param required_gradients: connector names whose gradients are required;
                                   must have length 1.
        :return: the nested-SDFG node added to ``context.backward_state`` and the
                 populated ``BackwardResult``.
        :raises AutoDiffException: if there is not exactly one given and one
                                   required gradient, or the reduction is not a sum.
        """
        reduction_type = detect_reduction_type(forward_node.wcr)

        if len(given_gradients) != 1:
            raise AutoDiffException(
                "received invalid SDFG: reduce node {} should have exactly one output edge"
                .format(forward_node))

        if len(required_gradients) != 1:
            raise AutoDiffException(
                "received invalid SDFG: reduce node {} should have exactly one input edge"
                .format(forward_node))

        input_name = next(iter(required_gradients))
        in_desc = in_desc_with_name(forward_node, context.forward_state,
                                    context.forward_sdfg, input_name)

        output_name = next(iter(given_gradients))
        out_desc = out_desc_with_name(forward_node, context.forward_state,
                                      context.forward_sdfg, output_name)

        all_axes: typing.List[int] = list(range(len(in_desc.shape)))
        reduce_axes: typing.List[int] = (all_axes if forward_node.axes is None
                                         else forward_node.axes)
        non_reduce_axes: typing.List[int] = [
            i for i in all_axes if i not in reduce_axes
        ]

        result = BackwardResult.empty()

        if reduction_type is not dtypes.ReductionType.Sum:
            raise AutoDiffException(
                "Unsupported reduction type '{}'".format(reduction_type))

        # Sum: simply scatter the grad across the axes that were reduced.
        type_suffix = str(reduction_type).replace(".", "_")
        sdfg = SDFG("_reverse_" + type_suffix + "_")
        state = sdfg.add_state()

        rev_input_conn_name = "input_gradient"
        rev_output_conn_name = "output_gradient"
        result.required_grad_names[output_name] = rev_output_conn_name
        result.given_grad_names[input_name] = rev_input_conn_name

        # The gradient input has the forward output's shape; the gradient output
        # has the forward input's shape.
        sdfg.add_array(rev_input_conn_name,
                       shape=out_desc.shape,
                       dtype=out_desc.dtype)
        sdfg.add_array(rev_output_conn_name,
                       shape=in_desc.shape,
                       dtype=in_desc.dtype)

        # Index the gradient input with the non-reduced map indices. If every
        # axis is reduced (axes=None, or an explicit list covering all axes),
        # the single element "0" is read. Joining an empty list would give the
        # invalid subset "".
        if non_reduce_axes:
            input_subset = ",".join("i" + str(i) for i in non_reduce_axes)
        else:
            input_subset = "0"

        state.add_mapped_tasklet(
            "_distribute_grad_" + type_suffix + "_", {
                "i" + str(i): "0:{}".format(shape)
                for i, shape in enumerate(in_desc.shape)
            }, {"__in": Memlet.simple(rev_input_conn_name, input_subset)},
            "__out = __in", {
                "__out":
                Memlet.simple(rev_output_conn_name,
                              ",".join("i" + str(i) for i in all_axes),
                              wcr_str="lambda x, y: x + y")
            },
            external_edges=True)

        return context.backward_state.add_nested_sdfg(
            sdfg, None, {rev_input_conn_name}, {rev_output_conn_name}), result