Пример #1
0
def mkc(sdfg: dace.SDFG,
        state_before,
        src_name,
        dst_name,
        src_storage=None,
        dst_storage=None,
        src_shape=None,
        dst_shape=None,
        copy_expr=None,
        src_loc=None,
        dst_loc=None):
    """
    Helper MaKe_Copy that creates and appends a state performing exactly one copy.

    If a provided array name already exists in ``sdfg.arrays``, the existing array is reused
    and the shape/storage/location values passed for it are ignored.

    :param sdfg: the SDFG the new state (and, if needed, the arrays) are added to.
    :param state_before: state after which the new state is inserted; if ``None`` the new
                         state is added as the start state.
    :param src_name: name of the source array.
    :param dst_name: name of the destination array.
    :param src_storage: storage type used when the source array has to be created.
    :param dst_storage: storage type used when the destination array has to be created.
    :param src_shape: shape used when the source array has to be created.
    :param dst_shape: shape used when the destination array has to be created.
    :param copy_expr: memlet expression of the copy edge; defaults to ``src_name``.
    :param src_loc: optional ``(memorytype, bank)`` pair stored in the location of a newly
                    created source array.
    :param dst_loc: optional ``(memorytype, bank)`` pair stored in the location of a newly
                    created destination array.
    :return: tuple ``(state, a_np_arr, b_np_arr)`` where the last two entries are zero-filled
             ``int32`` NumPy arrays of ``src_shape`` / ``dst_shape``, or ``None`` when the
             respective shape is ``None`` or NumPy cannot allocate an array for it.
    """

    if copy_expr is None:
        copy_expr = src_name
    if state_before is None:
        state = sdfg.add_state(is_start_state=True)
    else:
        state = sdfg.add_state_after(state_before)

    def ensure_array(name, shape, storage, loc):
        # Existing arrays win: never redefine or relocate them.
        if name in sdfg.arrays:
            return
        # Arrays placed in one of the FPGA storage types are created as transients.
        created = sdfg.add_array(name,
                                 shape,
                                 dace.int32,
                                 storage,
                                 transient=storage in _FPGA_STORAGE_TYPES)
        if loc is not None:
            # add_array's result is indexed at [1] to reach the data descriptor.
            created[1].location["memorytype"] = loc[0]
            created[1].location["bank"] = loc[1]

    def host_zeros(shape):
        if shape is None:
            return None
        try:
            return np.zeros(shape, dtype=np.int32)
        except Exception:
            # Best effort: NumPy rejected the shape (presumably because it is not
            # made of concrete integers), so no host buffer is provided for it.
            # Unlike a bare ``except``, this lets KeyboardInterrupt/SystemExit through.
            return None

    ensure_array(src_name, src_shape, src_storage, src_loc)
    ensure_array(dst_name, dst_shape, dst_storage, dst_loc)

    src_access = state.add_access(src_name)
    dst_access = state.add_access(dst_name)
    state.add_edge(src_access, None, dst_access, None, mem.Memlet(copy_expr))

    return (state, host_zeros(src_shape), host_zeros(dst_shape))
Пример #2
0
def add_backward_pass(
    sdfg: SDFG,
    state: SDFGState,
    outputs: typing.List[typing.Union[nd.AccessNode, str]],
    inputs: typing.List[typing.Union[nd.AccessNode, str]],
):
    """ Experimental: extend ``sdfg`` with the reverse-mode automatic differentiation of ``state``.

        The backward pass is generated into a new state that is inserted directly after ``state``.

        Entries of ``inputs`` and ``outputs`` may be ``AccessNode`` objects or ``str`` data names; for a ``str``
        the graph is searched for exactly one ``AccessNode`` whose data matches it.

        The SDFG should be free of inplace operations. Nodes that may appear in it:

        * Maps
        * AccessNodes
        * Reductions (Sum, Min, Max)
        * ONNXOps
        * NestedSDFGs containing a single SDFGState (subject to the same constraints). A NestedSDFG may have more
          than one state provided every additional state is only used for zero initialization.

        For an :class:`~daceml.onnx.nodes.onnx_op.ONNXOp`, the ONNXBackward registry is consulted first for a
        matching backward pass implementation. When it has none, the ONNXForward registry is searched for a
        matching pure implementation, and symbolic differentiation of that implementation is attempted. The method
        fails if that attempt fails or if there is no pure forward implementation.

        :param sdfg: the parent SDFG of ``state``.
        :param state: the forward pass state; the backward pass state is added after it.
        :param outputs: the forward pass outputs of the function to differentiate.
        :param inputs: the inputs w.r.t. which the gradient will be returned.
    """
    # Validate the graph before anything is added to it.
    sdfg.validate()

    # Forward and backward pass live in the same SDFG; only the state differs.
    reverse_state = sdfg.add_state_after(state)
    generator_args = dict(
        sdfg=sdfg,
        state=state,
        given_gradients=outputs,
        required_gradients=inputs,
        backward_sdfg=sdfg,
        backward_state=reverse_state,
    )
    BackwardPassGenerator(**generator_args).backward()
Пример #3
0
    def apply(self, sdfg: SDFG):
        """
        Rewrite a tasklet of the form ``out = in <op> expr;`` into ``out = expr;`` and
        set a ``lambda a,b: a <op> b`` conflict resolution (``wcr``) on its output memlet.

        With ``expr_index == 0`` the matched pattern is input access node -> tasklet ->
        output access node; otherwise it is map entry -> tasklet -> map exit. In the
        former case, if the input node has incoming edges and the output node has no
        outgoing edges, the tasklet is first moved into a new state appended after the
        current one.

        :param sdfg: the SDFG containing the matched state.
        :raises NotImplementedError: if the tasklet language is Python or anything
                                     other than C++.
        :raises ValueError: if no input edge of the tasklet both matches the
                            augmented-assignment pattern and covers the same subset
                            as the output edge.
        """
        input_node: nodes.AccessNode = self.input(sdfg)
        tasklet: nodes.Tasklet = self.tasklet(sdfg)
        output_node: nodes.AccessNode = self.output(sdfg)
        state: SDFGState = sdfg.node(self.state_id)

        # If state fission is necessary to keep semantics, do it first
        if (self.expr_index == 0 and state.in_degree(input_node) > 0
                and state.out_degree(output_node) == 0):
            newstate = sdfg.add_state_after(state)
            newstate.add_node(tasklet)
            new_input, new_output = None, None

            # Keep old edges for after we remove tasklet from the original state
            in_edges = list(state.in_edges(tasklet))
            out_edges = list(state.out_edges(tasklet))

            for e in in_edges:
                r = newstate.add_read(e.src.data)
                newstate.add_edge(r, e.src_conn, e.dst, e.dst_conn, e.data)
                if e.src is input_node:
                    new_input = r
            for e in out_edges:
                w = newstate.add_write(e.dst.data)
                newstate.add_edge(e.src, e.src_conn, w, e.dst_conn, e.data)
                if e.dst is output_node:
                    new_output = w

            # Remove tasklet and resulting isolated nodes
            state.remove_node(tasklet)
            for e in in_edges:
                if state.degree(e.src) == 0:
                    state.remove_node(e.src)
            for e in out_edges:
                if state.degree(e.dst) == 0:
                    state.remove_node(e.dst)

            # Reset state and nodes for rest of transformation
            input_node = new_input
            output_node = new_output
            state = newstate
        # End of state fission

        if self.expr_index == 0:
            inedges = state.edges_between(input_node, tasklet)
            outedge = state.edges_between(tasklet, output_node)[0]
        else:
            me = self.map_entry(sdfg)
            mx = self.map_exit(sdfg)

            inedges = state.edges_between(me, tasklet)
            outedge = state.edges_between(tasklet, mx)[0]

        # Get relevant output connector
        outconn = outedge.src_conn

        # Character class made of all supported operator characters
        ops = '[%s]' % ''.join(
            re.escape(o) for o in AugAssignToWCR._EXPRESSIONS)

        # Change tasklet code
        if tasklet.language is dtypes.Language.Python:
            raise NotImplementedError
        elif tasklet.language is dtypes.Language.CPP:
            cstr = tasklet.code.as_string.strip()
            for edge in inedges:
                inconn = edge.dst_conn
                # Only "out = in <op> expr;" is recognized; the mirrored form
                # "out = expr <op> in;" is not handled.
                match = re.match(
                    r'^\s*%s\s*=\s*%s\s*(%s)(.*);$' %
                    (re.escape(outconn), re.escape(inconn), ops), cstr)
                if match is None:
                    continue

                # The read and the write must cover the same subset
                if edge.data.subset != outedge.data.subset:
                    continue

                op = match.group(1)
                expr = match.group(2)

                # Map asymmetric WCRs to symmetric ones if possible
                if op in AugAssignToWCR._EXPR_MAP:
                    op, newexpr = AugAssignToWCR._EXPR_MAP[op]
                    expr = newexpr.format(expr=expr)

                tasklet.code.code = '%s = %s;' % (outconn, expr)
                inedge = edge
                break
            else:
                # Fail before touching the graph instead of hitting an unbound
                # ``op``/``inedge`` further down.
                raise ValueError(
                    'No input edge of the tasklet matches an augmented '
                    'assignment to the output connector')
        else:
            raise NotImplementedError

        # Change output edge
        wcr = f'lambda a,b: a {op} b'
        outedge.data.wcr = wcr

        if self.expr_index == 0:
            # Remove input node and connector
            state.remove_edge_and_connectors(inedge)
            if state.degree(input_node) == 0:
                state.remove_node(input_node)
        else:
            # Remove input edge and dst connector, but not necessarily src
            state.remove_memlet_path(inedge)

        # If outedge leads to non-transient, and this is a nested SDFG,
        # propagate outwards
        sd = sdfg
        while (not sd.arrays[outedge.data.data].transient
               and sd.parent_nsdfg_node is not None):
            nsdfg = sd.parent_nsdfg_node
            nstate = sd.parent
            sd = sd.parent_sdfg
            outedge = next(
                iter(nstate.out_edges_by_connector(nsdfg, outedge.data.data)))
            # The loop variable is deliberately ``outedge``: after the loop it
            # holds the last edge of the path, which the ``while`` condition
            # inspects on the next level up.
            for outedge in nstate.memlet_path(outedge):
                outedge.data.wcr = wcr