Пример #1
0
class BaseOirSDFGBuilder(ABC):
    """Abstract builder that wires library nodes into a single-state SDFG.

    The builder owns one ``SDFG`` with one state.  Library nodes are added
    one by one; for every node, access nodes and (initially empty) memlet
    edges are created for the fields it reads and writes, and auxiliary
    ordering edges are added between nodes touching the same field.
    ``finalize`` then prunes redundant ordering edges, fills in the memlet
    subsets and registers the arrays on the SDFG.

    Subclasses supply the node-specific pieces through the abstract methods
    at the bottom of the class; ``build`` is the entry point.
    """

    # When True, ``Temporary`` declarations are registered as transient
    # arrays in ``add_arrays``.
    has_transients = True

    def __init__(self, name, stencil: Stencil, nodes):
        """Create the SDFG with a single state and empty bookkeeping.

        Args:
            name: Name of the SDFG; the state is named ``name + "_state"``.
            stencil: Stencil whose ``declarations`` and ``params`` provide
                the data types and, for field declarations, the dimensions.
            nodes: Passed to ``nodes_extent_calculation`` to obtain the
                extents that are later looked up by field name.
        """
        self._stencil = stencil
        self._sdfg = SDFG(name)
        self._state = self._sdfg.add_state(name + "_state")
        self._extents = nodes_extent_calculation(nodes)

        self._dtypes = {
            decl.name: decl.dtype
            for decl in stencil.declarations + stencil.params
        }
        # Only field declarations have dimensions; scalars are absent here.
        self._axes = {
            decl.name: decl.dimensions
            for decl in stencil.declarations + stencil.params
            if isinstance(decl, FieldDecl)
        }

        # Per field: interval -> most recent writers / readers.  The values
        # are created lazily as ``IntervalMapping`` instances.
        self._recent_write_acc: Dict[str, IntervalMapping] = dict()
        self._recent_read_acc: Dict[str, IntervalMapping] = dict()

        # Per field: access nodes in creation order (source node first).
        self._access_nodes: Dict[str, List[dace.nodes.AccessNode]] = dict()
        # Keyed by ``id(oir_node)`` of horizontal execution nodes.
        self._access_collection_cache: Dict[
            int, AccessCollector.CartesianAccessCollection] = dict()
        self._source_nodes: Dict[str, dace.nodes.AccessNode] = dict()
        # Ordering-only edges that ``finalize`` tries to remove again.
        self._delete_candidates: List[MultiConnectorEdge] = list()

    def _access_space_to_subset(self, name, access_space):
        """Return the I/J subset strings for ``access_space`` of ``name``.

        Offsets are taken relative to the lower bound of the field's extent
        in each horizontal direction.  A dimension is only emitted when the
        field has the corresponding axis, so the result has 0 to 2 entries.
        """
        extent = self._extents[name]
        origin = (extent[0][0], extent[1][0])
        subsets = []
        if self._axes[name][0]:
            subsets.append("{start}:__I{end:+d}".format(
                start=origin[0] + access_space[0][0],
                end=origin[0] + access_space[0][1]))
        if self._axes[name][1]:
            subsets.append("{start}:__J{end:+d}".format(
                start=origin[1] + access_space[1][0],
                end=origin[1] + access_space[1][1]))
        return subsets

    def _are_nodes_ordered(self, name, node1, node2):
        """Return True if ``node1`` was registered before ``node2``.

        Both nodes must be access nodes of field ``name`` that are present
        in the per-field access node list.
        """
        assert name in self._access_nodes
        assert node1.data == name
        assert node2.data == name
        return self._access_nodes[name].index(
            node1) < self._access_nodes[name].index(node2)

    def _get_source(self, name):
        """Return the read access node of ``name``, creating it on first use.

        A newly created source node is inserted at the front of the
        per-field access node list so that it orders before all sinks.
        """
        if name not in self._source_nodes:
            self._source_nodes[name] = self._state.add_read(name)
            if name not in self._access_nodes:
                self._access_nodes[name] = []
            self._access_nodes[name].insert(0, self._source_nodes[name])
        return self._source_nodes[name]

    def _get_new_sink(self, name):
        """Add a fresh access node for ``name`` and append it to its list."""
        res = self._state.add_access(name)
        if name not in self._access_nodes:
            self._access_nodes[name] = []
        self._access_nodes[name].append(res)
        return res

    def _get_current_sink(self, name):
        """Return the latest access node of ``name`` or None if none exists."""
        if name in self._access_nodes:
            return self._access_nodes[name][-1]
        return None

    def _get_access_collection(
        self, node:
        "Union[HorizontalExecutionLibraryNode, VerticalLoopLibraryNode, SDFG]"
    ) -> AccessCollector.CartesianAccessCollection:
        """Collect the cartesian accesses performed by ``node``.

        * ``SDFG``: concatenation over the library nodes of its first state.
        * ``HorizontalExecutionLibraryNode``: computed from its ``oir_node``
          and cached by the identity of that OIR node.
        * ``VerticalLoopLibraryNode``: concatenation over the SDFGs of its
          sections.
        """
        if isinstance(node, SDFG):
            res = AccessCollector.CartesianAccessCollection([])
            # NOTE: the loop variable rebinds ``node``; the iterable was
            # already evaluated, so the iteration is unaffected.
            for node in node.states()[0].nodes():
                if isinstance(
                        node,
                    (HorizontalExecutionLibraryNode, VerticalLoopLibraryNode)):
                    collection = self._get_access_collection(node)
                    res._ordered_accesses.extend(collection._ordered_accesses)
            return res
        elif isinstance(node, HorizontalExecutionLibraryNode):
            if id(node.oir_node) not in self._access_collection_cache:
                self._access_collection_cache[id(
                    node.oir_node)] = AccessCollector.apply(
                        node.oir_node).cartesian_accesses()
            return self._access_collection_cache[id(node.oir_node)]
        else:
            assert isinstance(node, VerticalLoopLibraryNode)
            res = AccessCollector.CartesianAccessCollection([])
            for _, sdfg in node.sections:
                collection = self._get_access_collection(sdfg)
                res._ordered_accesses.extend(collection._ordered_accesses)
            return res

    def _get_recent_reads(self, name, interval):
        """Return the recorded readers of ``name`` for ``interval``.

        Fields without the third axis are always tracked on the full
        interval.
        """
        if name not in self._recent_read_acc:
            self._recent_read_acc[name] = IntervalMapping()
        if not self._axes[name][2]:
            interval = Interval.full()
        return self._recent_read_acc[name][interval]

    def _get_recent_writes(self, name, interval):
        """Return the recorded writers of ``name`` for ``interval``.

        Fields without the third axis are always tracked on the full
        interval.
        """
        if name not in self._recent_write_acc:
            self._recent_write_acc[name] = IntervalMapping()
        if not self._axes[name][2]:
            interval = Interval.full()
        return self._recent_write_acc[name][interval]

    def _set_read(self, name, interval, node):
        """Record ``node`` as the reader of ``name`` on ``interval``."""
        if name not in self._recent_read_acc:
            self._recent_read_acc[name] = IntervalMapping()
        if not self._axes[name][2]:
            interval = Interval.full()
        self._recent_read_acc[name][interval] = node

    def _set_write(self, name, interval, node):
        """Record ``node`` as the writer of ``name`` on ``interval``.

        ``node`` is an access node during the forward pass and a library
        node when called from ``_add_write_after_read_edges``.
        """
        if name not in self._recent_write_acc:
            self._recent_write_acc[name] = IntervalMapping()
        if not self._axes[name][2]:
            interval = Interval.full()
        self._recent_write_acc[name][interval] = node

    def _reset_writes(self):
        """Forget all recorded writers (used between the two build passes)."""
        self._recent_write_acc = dict()

    def _add_read_edges(
        self, node,
        collections: List[Tuple[Interval,
                                AccessCollector.CartesianAccessCollection]]):
        """Connect ``node`` to the access nodes it has to read from.

        For every field read, the most downstream access node among the
        recent writers of the (vertically shifted) read intervals is chosen;
        fields without a recent writer are read from their source node.  One
        ``IN_<name>`` connector and one edge with an empty memlet is added
        per field.
        """
        read_accesses: Dict[str, dace.nodes.AccessNode] = dict()
        # Pass 1: pick the latest writer per field.
        for interval, access_collection in collections:

            for name in access_collection.read_fields():
                for offset in access_collection.read_offsets()[name]:
                    read_interval = interval.shifted(offset[2])
                    for candidate_access in self._get_recent_writes(
                            name, read_interval):
                        if name not in read_accesses or self._are_nodes_ordered(
                                name, read_accesses[name], candidate_access):
                            # candidate_access is downstream from recent_access, therefore candidate is more recent
                            read_accesses[name] = candidate_access

        # Pass 2: fall back to the source node and record the reads.
        for interval, access_collection in collections:
            for name in access_collection.read_fields():
                for offset in access_collection.read_offsets()[name]:
                    read_interval = interval.shifted(offset[2])
                    if name not in read_accesses:
                        read_accesses[name] = self._get_source(name)
                    self._set_read(name, read_interval, read_accesses[name])

        # Pass 3: create connectors and edges; memlets are filled in later
        # by ``add_subsets``.
        for name, recent_access in read_accesses.items():
            node.add_in_connector("IN_" + name)
            self._state.add_edge(recent_access, None, node, "IN_" + name,
                                 dace.Memlet())

    def _add_write_edges(
        self, node,
        collections: List[Tuple[Interval,
                                AccessCollector.CartesianAccessCollection]]):
        """Connect ``node`` to the access nodes it writes to.

        The current sink of a field is reused unless there is none, or
        (the first time the field is seen for this node) the sink is a
        recent reader or writer on the interval or is already upstream of
        ``node``; in those cases a new sink is created.  One ``OUT_<name>``
        connector and one edge with an empty memlet is added per field.
        """
        write_accesses = dict()
        for interval, access_collection in collections:
            for name in access_collection.write_fields():
                access_node = self._get_current_sink(name)
                if access_node is None or (
                    (name not in write_accesses) and
                    (access_node in self._get_recent_reads(name, interval)
                     or access_node in self._get_recent_writes(name, interval)
                     or nx.has_path(self._state.nx, access_node, node))):
                    write_accesses[name] = self._get_new_sink(name)
                else:
                    write_accesses[name] = access_node
                self._set_write(name, interval, write_accesses[name])

        for name, access_node in write_accesses.items():
            node.add_out_connector("OUT_" + name)
            self._state.add_edge(node, "OUT_" + name, access_node, None,
                                 dace.Memlet())

    def _add_write_after_write_edges(
        self, node,
        collections: List[Tuple[Interval,
                                AccessCollector.CartesianAccessCollection]]):
        """Add ordering edges from recent writers of a field to ``node``.

        The edges carry no connectors and an empty memlet and are recorded
        as delete candidates for ``finalize``.
        """
        for interval, collection in collections:
            for name in collection.write_fields():
                for src in self._get_recent_writes(name, interval):
                    edge = self._state.add_edge(src, None, node, None,
                                                dace.Memlet())
                    self._delete_candidates.append(edge)

    def _add_write_after_read_edges(
        self, node,
        collections: List[Tuple[Interval,
                                AccessCollector.CartesianAccessCollection]]):
        """Add ordering edges from ``node`` to recorded writers of its reads.

        ``build`` calls this in reverse node order after resetting the
        writer records, so the recorded writers are the nodes processed
        earlier in that reverse pass.  Afterwards ``node`` itself is
        recorded as writer of the fields it writes.
        """
        for interval, collection in collections:
            for name in collection.read_fields():
                for offset in collection.read_offsets()[name]:
                    read_interval = interval.shifted(offset[2])
                    for dst in self._get_recent_writes(name, read_interval):
                        edge = self._state.add_edge(node, None, dst, None,
                                                    dace.Memlet())
                        self._delete_candidates.append(edge)

        for interval, collection in collections:
            for name in collection.write_fields():
                self._set_write(name, interval, node)

    def add_node(self, node):
        """Add ``node`` to the builder's state."""
        self._state.add_node(node)

    def finalize(self):
        """Prune ordering edges, set memlets, add arrays and access types."""
        # An ordering edge is kept only if removing it would leave no path
        # from its source to its destination.
        for edge in self._delete_candidates:
            assert edge.src_conn is None
            assert edge.dst_conn is None
            self._state.remove_edge(edge)
            if not nx.has_path(self._state.nx, edge.src, edge.dst):
                self._state.add_edge(edge.src, edge.src_conn, edge.dst,
                                     edge.dst_conn, edge.data)

        self.add_subsets()
        self.add_arrays()

        # An access node counts as written (read) only if it has incoming
        # (outgoing) edges and all of them carry data.
        for acc in (n for n in self._state.nodes()
                    if isinstance(n, dace.nodes.AccessNode)):
            is_write = len(self._state.in_edges(acc)) > 0 and all(
                edge.data.data is not None
                for edge in self._state.in_edges(acc))
            is_read = len(self._state.out_edges(acc)) > 0 and all(
                edge.data.data is not None
                for edge in self._state.out_edges(acc))
            if is_read and is_write:
                acc.access = dace.AccessType.ReadWrite
            elif is_read:
                acc.access = dace.AccessType.ReadOnly
            else:
                assert is_write
                acc.access = dace.AccessType.WriteOnly

    def _get_sdfg(self):
        """Finalize the graph and return the SDFG."""
        self.finalize()
        return self._sdfg

    def add_arrays(self):
        """Register symbols and arrays for the stencil's declarations.

        Scalar declarations become SDFG symbols.  Other declarations become
        arrays with symbolic strides, but only if the name occurs in the
        access collection of the SDFG.
        """
        shapes = self.get_shapes()
        for decl in self._stencil.params + self._stencil.declarations:
            name = decl.name
            dtype = dace.dtypes.typeclass(
                np.dtype(data_type_to_typestr(self._dtypes[name])).name)
            if isinstance(decl, ScalarDecl):
                self._sdfg.add_symbol(name, stype=dtype)
            else:
                # Skip fields that are never accessed.
                if name not in self._get_access_collection(
                        self._sdfg).offsets():
                    continue
                assert name in self._dtypes
                # One stride symbol per present I/J/K axis, followed by one
                # per data dimension.
                strides = tuple(
                    dace.symbolic.pystr_to_symbolic(f"__{name}_{var}_stride")
                    for is_axis, var in zip(self._axes[name], "IJK") if is_axis
                ) + tuple(
                    dace.symbolic.pystr_to_symbolic(f"__{name}_d{dim}_stride")
                    for dim, _ in enumerate(decl.data_dims))
                self._sdfg.add_array(
                    name,
                    dtype=dtype,
                    shape=shapes[name],
                    strides=strides,
                    transient=isinstance(decl, Temporary)
                    and self.has_transients,
                    lifetime=dace.AllocationLifetime.Persistent,
                )

    def add_subsets(self):
        """Replace the empty memlets on connector edges by real subsets.

        For every library node, each edge attached to one of its connectors
        gets a memlet built from the horizontal access space, the optional
        K subset string and the full range of every data dimension.  Edges
        without connectors (ordering edges) are left untouched.
        """
        decls = {
            decl.name: decl
            for decl in self._stencil.params + self._stencil.declarations
        }
        for node in self._state.nodes():
            if isinstance(node, dace.nodes.LibraryNode):
                access_spaces_input, access_spaces_output = self.get_access_spaces(
                    node)
                k_subset_strs_input, k_subset_strs_output = self.get_k_subsets(
                    node)
                for edge in self._state.in_edges(node) + self._state.out_edges(
                        node):
                    if edge.dst_conn is not None:
                        # Input edge: access node -> library node.
                        name = edge.src.data
                        access_space = access_spaces_input[name]
                        subset_str_k = k_subset_strs_input.get(name, None)
                        # Reads of a horizontal execution containing a mask
                        # statement are marked dynamic.
                        dynamic = isinstance(
                            node, HorizontalExecutionLibraryNode) and any(
                                isinstance(stmt, oir.MaskStmt)
                                for stmt in node.oir_node.body)
                    elif edge.src_conn is not None:
                        # Output edge: library node -> access node.
                        name = edge.dst.data
                        access_space = access_spaces_output[name]
                        subset_str_k = k_subset_strs_output.get(name, None)
                        dynamic = False
                    else:
                        continue
                    subset_strs = self._access_space_to_subset(
                        name, access_space)
                    if subset_str_k is not None:
                        subset_strs.append(subset_str_k)
                    for dim in decls[name].data_dims:
                        subset_strs.append(f"0:{dim}")
                    edge.data = dace.Memlet.simple(
                        data=name,
                        subset_str=",".join(subset_strs),
                        dynamic=dynamic)

    @abstractmethod
    def get_k_size(self, name):
        """Subclass hook; not called from within this class."""
        pass

    @abstractmethod
    def add_read_edges(self, node):
        """Subclass hook called by ``build`` to add the read edges of ``node``."""
        pass

    @abstractmethod
    def add_write_edges(self, node):
        """Subclass hook called by ``build`` to add the write edges of ``node``."""
        pass

    @abstractmethod
    def add_write_after_read_edges(self, node):
        """Subclass hook called by ``build`` in the reverse pass over the nodes."""
        pass

    @abstractmethod
    def add_write_after_write_edges(self, node):
        """Subclass hook called by ``build`` right after ``node`` is added."""
        pass

    @abstractmethod
    def get_k_subsets(self, node):
        """Return ``(input, output)`` mappings from field name to K subset string."""
        pass

    @abstractmethod
    def get_access_spaces(self, node):
        """Return ``(input, output)`` mappings from field name to access space."""
        pass

    @abstractmethod
    def get_shapes(self):
        """Return a mapping from field name to array shape."""
        pass

    @classmethod
    def build(cls, name, stencil: Stencil,
              nodes: List[dace.nodes.LibraryNode]):
        """Build, finalize and validate the SDFG for ``nodes``.

        A forward pass adds every node with its write-after-write, read and
        write edges.  The writer records are then reset and a reverse pass
        adds the write-after-read edges.

        Returns:
            The validated SDFG.
        """
        builder = cls(name, stencil, nodes)
        for n in nodes:
            builder.add_node(n)
            builder.add_write_after_write_edges(n)
            builder.add_read_edges(n)
            builder.add_write_edges(n)
        builder._reset_writes()
        for n in reversed(nodes):
            builder.add_write_after_read_edges(n)
        res = builder._get_sdfg()
        res.validate()
        return res
Пример #2
0
def insert_sdfg_element(sdfg_str, type, parent_uuid, edge_a_uuid):
    """Insert a new element into a serialized SDFG and return the new SDFG.

    Args:
        sdfg_str: Serialized SDFG, handed to ``load_sdfg_from_json``.
        type: Kind of element to insert: ``'SDFGState'``, ``'AccessNode'``,
            ``'Map'``, ``'Consume'``, ``'Tasklet'``, ``'NestedSDFG'``,
            ``'LibraryNode'`` or ``'Edge'``.  For library nodes the string
            has the form ``'LibraryNode|<dotted.path.to.Class>'``.
        parent_uuid: UUID of the element the new element is added to.  For
            ``'Edge'`` this is the destination of the edge.
        edge_a_uuid: UUID of the edge's source; only used for ``'Edge'``.

    Returns:
        ``{'sdfg': <json>, 'uuid': <list of uuids>}`` on success.  ``'uuid'``
        stays ``'error'`` when ``type`` is not recognized.  A dictionary with
        a single ``'error'`` entry is returned when a library node type is
        missing or cannot be located, or when an edge would connect
        incompatible elements.

    Raises:
        ValueError: If ``type`` is ``'Edge'`` and no element is found for
            ``edge_a_uuid``.
    """
    sdfg_answer = load_sdfg_from_json(sdfg_str)
    sdfg = sdfg_answer['sdfg']
    uuid = 'error'
    ret = find_graph_element_by_uuid(sdfg, parent_uuid)
    parent = ret['element']

    # A library node request carries the class path after a '|' separator.
    libname = None
    if isinstance(type, str):
        split_type = type.split('|')
        if len(split_type) == 2:
            type, libname = split_type

    if type == 'SDFGState':
        if parent is None:
            parent = sdfg
        elif isinstance(parent, nodes.NestedSDFG):
            parent = parent.sdfg
        state = parent.add_state()
        uuid = [get_uuid(state)]
    elif type == 'AccessNode':
        arrays = list(parent.parent.arrays.keys())
        if len(arrays) == 0:
            # An access node needs a data container to point at.
            parent.parent.add_array('tmp', [1], dtype=dtypes.float64)
            arrays = list(parent.parent.arrays.keys())
        node = parent.add_access(arrays[0])
        uuid = [get_uuid(node, parent)]
    elif type == 'Map':
        map_entry, map_exit = parent.add_map('map', dict(i='0:1'))
        uuid = [get_uuid(map_entry, parent), get_uuid(map_exit, parent)]
    elif type == 'Consume':
        consume_entry, consume_exit = parent.add_consume('consume', ('i', '1'))
        uuid = [get_uuid(consume_entry, parent), get_uuid(consume_exit, parent)]
    elif type == 'Tasklet':
        tasklet = parent.add_tasklet(
            name='placeholder',
            inputs={'in'},
            outputs={'out'},
            code='')
        uuid = [get_uuid(tasklet, parent)]
    elif type == 'NestedSDFG':
        sub_sdfg = SDFG('nested_sdfg')
        sub_sdfg.add_array('in', [1], dtypes.float32)
        sub_sdfg.add_array('out', [1], dtypes.float32)

        nsdfg = parent.add_nested_sdfg(sub_sdfg, sdfg, {'in'}, {'out'})
        uuid = [get_uuid(nsdfg, parent)]
    elif type == 'LibraryNode':
        if libname is None:
            return {
                'error': {
                    'message': 'Failed to add library node',
                    'details': 'Must provide a valid library node type',
                },
            }
        # NOTE(review): pydoc.locate imports whatever dotted path it is
        # given and the result is instantiated below.  If ``type`` can come
        # from an untrusted client, this should be restricted to known
        # library node classes.
        libnode_class = pydoc.locate(libname)
        # pydoc.locate returns None when the path cannot be resolved;
        # report that instead of failing with "NoneType is not callable".
        if not callable(libnode_class):
            return {
                'error': {
                    'message': 'Failed to add library node',
                    'details': 'Unable to locate library node type ' +
                               str(libname),
                },
            }
        libnode = libnode_class()
        parent.add_node(libnode)
        uuid = [get_uuid(libnode, parent)]
    elif type == 'Edge':
        edge_start_ret = find_graph_element_by_uuid(sdfg, edge_a_uuid)
        edge_start = edge_start_ret['element']
        edge_parent = edge_start_ret['parent']
        if edge_start is not None:
            if edge_parent is None:
                edge_parent = sdfg

            if isinstance(edge_parent, SDFGState):
                # Inside a state, only nodes can be connected (by a memlet).
                if not (isinstance(edge_start, nodes.Node) and
                        isinstance(parent, nodes.Node)):
                    return {
                        'error': {
                            'message': 'Failed to add edge',
                            'details': 'Must connect two nodes or two states',
                        },
                    }
                memlet = Memlet()
                edge_parent.add_edge(edge_start, None, parent, None, memlet)
            elif isinstance(edge_parent, SDFG):
                # On SDFG level, only states can be connected.
                if not (isinstance(edge_start, SDFGState) and
                        isinstance(parent, SDFGState)):
                    return {
                        'error': {
                            'message': 'Failed to add edge',
                            'details': 'Must connect two nodes or two states',
                        },
                    }
                isedge = InterstateEdge()
                edge_parent.add_edge(edge_start, parent, isedge)
            uuid = ['NONE']
        else:
            raise ValueError('No edge starting point provided')

    # Serialize without metadata; always restore the previous setting, even
    # if serialization fails, so the switch does not leak to later calls.
    old_meta = disable_save_metadata()
    try:
        new_sdfg_str = sdfg.to_json()
    finally:
        restore_save_metadata(old_meta)

    return {
        'sdfg': new_sdfg_str,
        'uuid': uuid,
    }
Пример #3
0
    def apply(self, sdfg: dace.SDFG):
        """Replace the matched ONNX node by constants holding its result.

        A shape node is replaced by an ``int64`` array holding the shape of
        its input.  Any other node is copied into a fresh SDFG, executed on
        the weights feeding it, and each output is replaced by a new
        non-transient data container whose value is stored in the parent
        ONNX model's ``weights``.  Finally the node and every predecessor
        that is left without consumers is removed from the state.

        Args:
            sdfg: SDFG the transformation is applied to.  It must carry the
                owning model in ``_parent_onnx_model``.
        """
        # Extract the subgraph, execute it and insert an AccessNode to the result

        parent: ONNXModel = sdfg._parent_onnx_model
        state = sdfg.nodes()[self.state_id]
        node = state.nodes()[self.subgraph[ConstantFolding._onnx_node]]

        if isinstance(node, donnx.ONNXShape):
            # if we have a shape node, replace it with a constant
            assert len(state.in_edges(node)) == 1
            shape_in_edge = state.in_edges(node)[0]
            assert shape_in_edge.dst_conn == "data"
            shape_desc = sdfg.arrays[shape_in_edge.src.data]

            constant_name = sdfg.temp_data_name()
            clean_constant_name = clean_onnx_name(constant_name)
            sdfg.add_array(clean_constant_name, (len(shape_desc.shape), ),
                           dace.int64)

            # NOTE(review): this checks ``clean_weights`` while the branch
            # below checks ``weights`` for the same purpose; confirm which
            # one is intended.
            assert constant_name not in parent.clean_weights
            parent.weights[constant_name] = np.array(shape_desc.shape,
                                                     np.int64)

            assert len(state.out_edges(node)) == 1
            output_edge = state.out_edges(node)[0]
            access_shape = state.add_access(clean_constant_name)
            state.add_edge(access_shape, None, output_edge.dst,
                           output_edge.dst_conn,
                           sdfg.make_array_memlet(clean_constant_name))
        else:
            # otherwise compute the result of the op
            sub_sdfg = dace.SDFG("sub_sdfg")
            sub_state = sub_sdfg.add_state()

            node_copy = copy.deepcopy(node)
            sub_state.add_node(node_copy)

            # Keyword arguments for the call: one array per input connector.
            inputs = {}
            for edge in state.in_edges(node):
                # we know from can_be_applied that all in edges are from AccessNodes
                assert (isinstance(edge.src, nd.AccessNode)
                        and hasattr(sdfg, "_parent_onnx_model") and
                        edge.src.data in sdfg._parent_onnx_model.clean_weights)

                desc = copy.deepcopy(sdfg.arrays[edge.data.data])
                desc.transient = False
                sub_sdfg.add_datadesc('array_' + edge.dst_conn, desc)

                input_value = sdfg._parent_onnx_model.clean_weights[
                    edge.src.data]

                if len(input_value.shape) == 0:
                    # zero-dimensional arrays are passed as scalars
                    inputs['array_' + edge.dst_conn] = input_value[()]
                else:
                    inputs['array_' + edge.dst_conn] = input_value.copy()

                access = sub_state.add_access('array_' + edge.dst_conn)
                sub_state.add_edge(
                    access, None, node_copy, edge.dst_conn,
                    sub_sdfg.make_array_memlet('array_' + edge.dst_conn))

            # Keyword arguments for the call: one buffer per output connector.
            outputs = {}
            for edge in state.out_edges(node):
                desc = copy.deepcopy(sdfg.arrays[edge.data.data])
                if isinstance(desc, dt.Scalar):
                    # we need to copy to an array of size [1] so that we can "return" the output from the sdfg
                    desc.transient = True
                    sub_sdfg.add_datadesc('scalar_array_' + edge.src_conn,
                                          desc)
                    sub_sdfg.add_array('array_' + edge.src_conn, [1],
                                       desc.dtype,
                                       transient=False)

                    access_scalar = sub_state.add_access('scalar_array_' +
                                                         edge.src_conn)
                    access = sub_state.add_access('array_' + edge.src_conn)
                    sub_state.add_edge(
                        node_copy, edge.src_conn, access_scalar, None,
                        sub_sdfg.make_array_memlet('scalar_array_' +
                                                   edge.src_conn))

                    sub_state.add_edge(
                        access_scalar, None, access, None,
                        sub_sdfg.make_array_memlet('array_' + edge.src_conn))
                else:
                    desc.transient = False
                    sub_sdfg.add_datadesc('array_' + edge.src_conn, desc)
                    access = sub_state.add_access('array_' + edge.src_conn)
                    sub_state.add_edge(
                        node_copy, edge.src_conn, access, None,
                        sub_sdfg.make_array_memlet('array_' + edge.src_conn))

                if len(desc.shape) == 0:
                    outputs['array_' + edge.src_conn] = np.empty(
                        (1, ), desc.dtype.as_numpy_dtype())
                else:
                    outputs['array_' + edge.src_conn] = np.empty(
                        tuple(desc.shape), desc.dtype.as_numpy_dtype())

            # Run the op; the results are written into ``outputs``.
            sub_sdfg(**outputs, **inputs)

            for edge in state.out_edges(node):
                desc = copy.deepcopy(sdfg.arrays[edge.data.data])
                desc.transient = False
                output_value = outputs['array_' + edge.src_conn]

                constant_name = sdfg.temp_data_name()
                clean_constant_name = clean_onnx_name(constant_name)
                sdfg.add_datadesc(clean_constant_name, desc)

                assert constant_name not in parent.weights
                if isinstance(desc, dt.Scalar):
                    # undo the size-[1] wrapping used for the call
                    parent.weights[constant_name] = output_value.reshape(())
                else:
                    parent.weights[constant_name] = output_value

                access_constant = state.add_access(clean_constant_name)
                state.add_edge(access_constant, None, edge.dst, edge.dst_conn,
                               sdfg.make_array_memlet(clean_constant_name))

        # remove all now useless nodes with a reverse BFS
        queue = deque([node])
        # A predecessor connected to the removed node by several edges must
        # only be enqueued once; a second visit would operate on a node that
        # is no longer part of the state.
        enqueued = {node}
        while queue:
            current_node = queue.popleft()

            edges = state.in_edges(current_node)
            state.remove_node(current_node)
            for e in edges:
                next_node = e.src
                if (next_node not in enqueued
                        and len(state.out_edges(next_node)) == 0):
                    enqueued.add(next_node)
                    queue.append(next_node)
Пример #4
0
    def apply(self, sdfg: dace.SDFG):
        """Fold the matched ONNX node into constants stored on the model.

        Shape nodes become an ``int64`` tensor holding the input's shape.
        Every other node is executed in a throw-away SDFG on the weights
        feeding it; each output becomes a new data container whose value is
        stored in the parent model's ``weights``.  Outputs whose storage is
        not accessible from the CPU get an additional CPU container that is
        copied into the original one.  The folded node and its now unused
        inputs are removed via ``remove_node_and_computation``.
        """
        # Extract the subgraph, execute it and insert an AccessNode to the result
        # this method of execution is slow but simple. A better option would be to call the ORT
        # C API from a python object (like the OpChecker).
        global UNIQUE_ID

        parent: ONNXModel = sdfg._parent_onnx_model
        state = sdfg.nodes()[self.state_id]
        node = state.nodes()[self.subgraph[ConstantFolding._onnx_node]]
        log.debug(f"Applying constant folding: {node} in {state}")

        if isinstance(node, donnx.ONNXShape):
            # a shape node is replaced by a constant holding the shape
            incoming = state.in_edges(node)
            assert len(incoming) == 1
            shape_in_edge = incoming[0]
            assert shape_in_edge.dst_conn == "data"
            shape_desc = sdfg.arrays[shape_in_edge.src.data]

            constant_name = sdfg.temp_data_name()
            clean_constant_name = clean_onnx_name(constant_name)
            sdfg.add_array(clean_constant_name, (len(shape_desc.shape), ),
                           dace.int64)

            assert constant_name not in parent.clean_weights
            shape_values = np.array(shape_desc.shape, np.int64)
            parent.weights[constant_name] = torch.from_numpy(shape_values)

            outgoing = state.out_edges(node)
            assert len(outgoing) == 1
            output_edge = outgoing[0]
            access_shape = state.add_access(clean_constant_name)
            shape_memlet = sdfg.make_array_memlet(clean_constant_name)
            state.add_edge(access_shape, None, output_edge.dst,
                           output_edge.dst_conn, shape_memlet)
        else:
            # any other op is evaluated in its own, uniquely named SDFG
            UNIQUE_ID += 1
            sub_sdfg = dace.SDFG("sub_sdfg_" + str(UNIQUE_ID))
            sub_state = sub_sdfg.add_state()

            node_copy = copy.deepcopy(node)
            sub_state.add_node(node_copy)

            # --- inputs: one non-transient container per input connector
            inputs = {}
            for in_edge in state.in_edges(node):
                # can_be_applied guarantees that all in edges come from
                # AccessNodes backed by model weights
                assert (isinstance(in_edge.src, nd.AccessNode)
                        and hasattr(sdfg, "_parent_onnx_model")
                        and in_edge.src.data
                        in sdfg._parent_onnx_model.clean_weights)

                in_desc = copy.deepcopy(sdfg.arrays[in_edge.data.data])
                in_desc.transient = False
                in_name = 'array_' + in_edge.dst_conn
                sub_sdfg.add_datadesc(in_name, in_desc)

                weight = sdfg._parent_onnx_model.clean_weights[
                    in_edge.src.data]

                # zero-dimensional tensors are passed as numpy scalars
                if len(weight.shape) == 0:
                    inputs[in_name] = weight.cpu().numpy()[()]
                else:
                    inputs[in_name] = weight.clone()

                in_access = sub_state.add_access(in_name)
                sub_state.add_edge(in_access, None, node_copy,
                                   in_edge.dst_conn,
                                   sub_sdfg.make_array_memlet(in_name))

            # --- outputs: one buffer per output connector
            outputs = {}
            for out_edge in state.out_edges(node):
                out_desc = copy.deepcopy(sdfg.arrays[out_edge.data.data])
                out_name = 'array_' + out_edge.src_conn
                if isinstance(out_desc, dt.Scalar):
                    # scalars are routed through an array of size [1] so
                    # that the value can be "returned" from the sdfg
                    scalar_name = 'scalar_array_' + out_edge.src_conn
                    out_desc.transient = True
                    sub_sdfg.add_datadesc(scalar_name, out_desc)
                    sub_sdfg.add_array(out_name, [1],
                                       out_desc.dtype,
                                       transient=False)

                    scalar_access = sub_state.add_access(scalar_name)
                    out_access = sub_state.add_access(out_name)
                    sub_state.add_edge(
                        node_copy, out_edge.src_conn, scalar_access, None,
                        sub_sdfg.make_array_memlet(scalar_name))
                    sub_state.add_edge(scalar_access, None, out_access, None,
                                       sub_sdfg.make_array_memlet(out_name))
                else:
                    out_desc.transient = False
                    sub_sdfg.add_datadesc(out_name, out_desc)
                    out_access = sub_state.add_access(out_name)
                    sub_state.add_edge(node_copy, out_edge.src_conn,
                                       out_access, None,
                                       sub_sdfg.make_array_memlet(out_name))

                buffer_shape = ((1, ) if len(out_desc.shape) == 0 else tuple(
                    out_desc.shape))
                buffer = torch.from_numpy(
                    np.empty(buffer_shape, out_desc.dtype.as_numpy_dtype()))

                if out_desc.storage is dtypes.StorageType.GPU_Global:
                    buffer = buffer.cuda()

                outputs[out_name] = buffer

            sub_sdfg(**outputs, **inputs)

            # --- register every result as a constant of the outer SDFG
            for out_edge in state.out_edges(node):
                const_desc = copy.deepcopy(sdfg.arrays[out_edge.data.data])
                const_desc.transient = False
                output_value = outputs['array_' + out_edge.src_conn]

                constant_name = sdfg.temp_data_name()
                clean_constant_name = clean_onnx_name(constant_name)
                sdfg.add_datadesc(clean_constant_name, const_desc)

                assert constant_name not in parent.weights
                assert type(output_value) is torch.Tensor

                if dtypes.can_access(dtypes.ScheduleType.CPU_Multicore,
                                     const_desc.storage):
                    access_constant = state.add_read(clean_constant_name)
                    name_to_add = constant_name
                else:
                    # the constant is not reachable from the CPU: register a
                    # CPU container instead and copy it into the (now
                    # transient) original container
                    cpu_desc = copy.deepcopy(const_desc)
                    cpu_desc.storage = dtypes.StorageType.CPU_Heap
                    cpu_desc.transient = False
                    const_desc.transient = True
                    copy_in_name = sdfg.temp_data_name()
                    clean_copy_in_name = clean_onnx_name(copy_in_name)
                    sdfg.add_datadesc(clean_copy_in_name, cpu_desc)

                    access_constant = state.add_access(clean_constant_name)
                    copy_in_access = state.add_read(clean_copy_in_name)
                    state.add_edge(copy_in_access, None, access_constant,
                                   None,
                                   sdfg.make_array_memlet(clean_copy_in_name))

                    name_to_add = copy_in_name

                parent.weights[name_to_add] = (output_value.reshape(
                    ()) if isinstance(const_desc, dt.Scalar) else output_value)

                state.add_edge(access_constant, None, out_edge.dst,
                               out_edge.dst_conn,
                               sdfg.make_array_memlet(clean_constant_name))

        # remove all now useless nodes with a reverse BFS
        remove_node_and_computation(sdfg, state, node)
Пример #5
0
    def backward(
        forward_node: Node, context: BackwardContext,
        given_gradients: typing.List[typing.Optional[str]],
        required_gradients: typing.List[typing.Optional[str]]
    ) -> typing.Tuple[Node, BackwardResult]:
        """Build the backward node for a reduce node.

        Only sum reductions are handled. For those, a nested SDFG is created
        that maps over every index of the forward input and copies the
        matching element of the incoming gradient (indexed by the axes that
        were *not* reduced) into the gradient of the forward input.

        :param forward_node: the reduce node to differentiate; its ``wcr`` and
                             ``axes`` attributes are read.
        :param context: provides ``forward_state``/``forward_sdfg`` (used to
                        look up data descriptors) and ``backward_state`` (the
                        state the nested SDFG is added to).
        :param given_gradients: must contain exactly one entry, the forward
                                output whose gradient is available.
        :param required_gradients: must contain exactly one entry, the forward
                                   input whose gradient is requested.
        :return: the nested SDFG node added to ``context.backward_state`` and
                 the ``BackwardResult`` holding the connector name mapping.
        :raises AutoDiffException: if either list does not contain exactly one
                                   entry, or if the reduction is not a sum.
        """
        reduction_type = detect_reduction_type(forward_node.wcr)

        if len(given_gradients) != 1:
            raise AutoDiffException(
                "recieved invalid SDFG: reduce node {} should have exactly one output edge"
                .format(forward_node))

        if len(required_gradients) != 1:
            raise AutoDiffException(
                "recieved invalid SDFG: reduce node {} should have exactly one input edge"
                .format(forward_node))

        input_name = next(iter(required_gradients))
        in_desc = in_desc_with_name(forward_node, context.forward_state,
                                    context.forward_sdfg, input_name)

        output_name = next(iter(given_gradients))
        out_desc = out_desc_with_name(forward_node, context.forward_state,
                                      context.forward_sdfg, output_name)

        all_axes: typing.List[int] = list(range(len(in_desc.shape)))
        # ``axes is None`` means that every axis of the input is reduced
        reduce_axes: typing.List[
            int] = all_axes if forward_node.axes is None else forward_node.axes
        non_reduce_axes: typing.List[int] = [
            i for i in all_axes if i not in reduce_axes
        ]

        result = BackwardResult.empty()

        if reduction_type is not dtypes.ReductionType.Sum:
            raise AutoDiffException(
                "Unsupported reduction type '{}'".format(reduction_type))

        # for a sum, we simply need to scatter the grad across the axes that
        # were reduced
        type_label = str(reduction_type).replace(".", "_")

        sdfg = SDFG("_reverse_" + type_label + "_")
        state = sdfg.add_state()

        rev_input_conn_name = "input_gradient"
        rev_output_conn_name = "output_gradient"
        # NOTE(review): required_grad_names is keyed by the forward *output*
        # name and given_grad_names by the forward *input* name; this is kept
        # as in the original -- confirm against BackwardResult's contract.
        result.required_grad_names[output_name] = rev_output_conn_name
        result.given_grad_names[input_name] = rev_input_conn_name

        # incoming gradient has the shape of the forward output, the produced
        # gradient has the shape of the forward input
        sdfg.add_array(rev_input_conn_name,
                       shape=out_desc.shape,
                       dtype=out_desc.dtype)
        sdfg.add_array(rev_output_conn_name,
                       shape=in_desc.shape,
                       dtype=in_desc.dtype)

        # When no axis survives the reduction there is no map index left to
        # address the incoming gradient with. This happens for ``axes is
        # None`` and also when ``axes`` explicitly lists every dimension; the
        # latter used to produce an empty subset string. Index "0" is used in
        # both cases (assumes the forward output then holds a single element,
        # as the ``axes is None`` case already did -- TODO confirm).
        if non_reduce_axes:
            in_subset = ",".join("i" + str(i) for i in non_reduce_axes)
        else:
            in_subset = "0"
        out_subset = ",".join("i" + str(i) for i in all_axes)

        # one map parameter per dimension of the forward input
        map_ranges = {
            "i" + str(i): "0:{}".format(shape)
            for i, shape in enumerate(in_desc.shape)
        }

        state.add_mapped_tasklet(
            "_distribute_grad_" + type_label + "_",
            map_ranges,
            {"__in": Memlet.simple(rev_input_conn_name, in_subset)},
            "__out = __in",
            {
                "__out":
                Memlet.simple(rev_output_conn_name,
                              out_subset,
                              wcr_str="lambda x, y: x + y")
            },
            external_edges=True)

        backward_node = context.backward_state.add_nested_sdfg(
            sdfg, None, {rev_input_conn_name}, {rev_output_conn_name})
        return backward_node, result