Example #1
0
def make_sdfg(implementation,
              dtype,
              id=0,
              in_shape=None,
              out_shape=None,
              in_subset="0:n, 0:n",
              out_subset="0:n, 0:n"):
    """Build a single-state SDFG that runs a ``Solve`` library node.

    The SDFG reads ``ain`` (shape ``in_shape``) and ``bin`` (shape
    ``out_shape``) and writes ``bout`` (shape ``out_shape``) through the
    node's ``_ain``/``_bin``/``_bout`` connectors. The module-level symbol
    ``n`` is used for the default shapes and the memlet access counts.

    :param implementation: Implementation name assigned to the ``Solve`` node.
    :param dtype: Element type of all arrays; ``dtype.__name__`` is part of
                  the SDFG name.
    :param id: Suffix that keeps SDFG names distinct between calls.
    :param in_shape: Shape of ``ain``; ``None`` means ``[n, n]``.
    :param out_shape: Shape of ``bin`` and ``bout``; ``None`` means ``[n, n]``.
    :param in_subset: Subset string used on the ``ain`` memlet.
    :param out_subset: Subset string used on the ``bin``/``bout`` memlets.
    :return: The constructed SDFG.
    """
    # Build fresh default shapes on every call instead of sharing one list
    # object that was created once at function-definition time.
    if in_shape is None:
        in_shape = [n, n]
    if out_shape is None:
        out_shape = [n, n]

    sdfg = dace.SDFG("linalg_solve_{}_{}_{}".format(implementation, dtype.__name__, id))
    sdfg.add_symbol("n", dace.int64)
    state = sdfg.add_state("dataflow")

    sdfg.add_array("ain", in_shape, dtype)
    sdfg.add_array("bin", out_shape, dtype)
    sdfg.add_array("bout", out_shape, dtype)

    a_in = state.add_read("ain")
    b_in = state.add_read("bin")  # named to avoid shadowing the builtin ``bin``
    b_out = state.add_write("bout")

    solve_node = Solve("solve")
    solve_node.implementation = implementation

    state.add_memlet_path(a_in, solve_node, dst_conn="_ain", memlet=Memlet.simple(a_in, in_subset, num_accesses=n * n))
    state.add_memlet_path(b_in, solve_node, dst_conn="_bin", memlet=Memlet.simple(b_in, out_subset, num_accesses=n * n))
    state.add_memlet_path(solve_node,
                          b_out,
                          src_conn="_bout",
                          memlet=Memlet.simple(b_out, out_subset, num_accesses=n * n))

    return sdfg
Example #2
0
def make_sdfg(implementation,
              dtype,
              id=0,
              in_shape=None,
              out_shape=None,
              in_subset="0:n, 0:n",
              out_subset="0:n, 0:n",
              overwrite=False,
              getri=True):
    """Build a single-state SDFG that runs an ``Inv`` library node.

    The node reads ``xin`` through ``_ain`` and writes through ``_aout``,
    either into a separate ``xout`` array or back into ``xin`` when
    ``overwrite`` is set. The module-level symbol ``n`` is used for the
    default shapes and the memlet access counts.

    :param implementation: Implementation name assigned to the ``Inv`` node.
    :param dtype: Element type of the arrays; ``dtype.__name__`` is part of
                  the SDFG name.
    :param id: Suffix that keeps SDFG names distinct between calls.
    :param in_shape: Shape of ``xin``; ``None`` means ``[n, n]``.
    :param out_shape: Shape of ``xout``; ``None`` means ``[n, n]``.
    :param in_subset: Subset string used on the input memlet.
    :param out_subset: Subset string used on the output memlet.
    :param overwrite: If true, no ``xout`` array is created and the result is
                      written to ``xin``; also forwarded as ``overwrite_a``.
    :param getri: Forwarded to the node as ``use_getri``.
    :return: The constructed SDFG.
    """
    # Build fresh default shapes on every call instead of sharing one list
    # object that was created once at function-definition time.
    if in_shape is None:
        in_shape = [n, n]
    if out_shape is None:
        out_shape = [n, n]

    sdfg = dace.SDFG("linalg_inv_{}_{}_{}".format(implementation, dtype.__name__, id))
    sdfg.add_symbol("n", dace.int64)
    state = sdfg.add_state("dataflow")

    sdfg.add_array("xin", in_shape, dtype)
    if not overwrite:
        sdfg.add_array("xout", out_shape, dtype)

    xin = state.add_read("xin")
    # In overwrite mode the result goes back into ``xin`` itself.
    xout = state.add_write("xin" if overwrite else "xout")

    inv_node = Inv("inv", overwrite_a=overwrite, use_getri=getri)
    inv_node.implementation = implementation

    state.add_memlet_path(xin, inv_node, dst_conn="_ain", memlet=Memlet.simple(xin, in_subset, num_accesses=n * n))
    state.add_memlet_path(inv_node, xout, src_conn="_aout", memlet=Memlet.simple(xout, out_subset, num_accesses=n * n))

    return sdfg
Example #3
0
def make_sdfg(implementation, dtype, storage=dace.StorageType.Default):
    """Build a single-state SDFG applying a ``Cholesky`` node (``lower=True``).

    The node reads the ``n x n`` array ``xin`` through ``_a`` and writes the
    ``n x n`` array ``xout`` through ``_b``.

    :param implementation: Implementation name assigned to the Cholesky node.
    :param dtype: Element type of both arrays (formatted into the SDFG name).
    :param storage: Accepted for call-site compatibility; this body does not
                    use it.
    :return: The constructed SDFG.
    """
    size = dace.symbol("n", dace.int64)

    sdfg = dace.SDFG("linalg_cholesky_{}_{}".format(implementation, dtype))
    state = sdfg.add_state("dataflow")

    in_name, in_desc = sdfg.add_array("xin", [size, size], dtype)
    out_name, out_desc = sdfg.add_array("xout", [size, size], dtype)

    source = state.add_read("xin")
    sink = state.add_write("xout")

    factorize = Cholesky("cholesky", lower=True)
    factorize.implementation = implementation

    # Whole-array memlets in and out of the library node.
    state.add_memlet_path(source, factorize, dst_conn="_a",
                          memlet=Memlet.from_array(in_name, in_desc))
    state.add_memlet_path(factorize, sink, src_conn="_b",
                          memlet=Memlet.from_array(out_name, out_desc))

    return sdfg
Example #4
0
def _make_sdfg(node, parent_state, parent_sdfg, implementation):
    """Expand a Cholesky-style library node into an SDFG around ``Potrf``.

    The generated SDFG reads ``_a`` and produces the factor in ``_b``. After
    ``Potrf`` has run, a mapped tasklet sets every ``_b[i, j]`` with
    ``i < j`` to zero and leaves the other elements unchanged.

    :param node: Library node being expanded. ``node.validate`` must return
                 ``(inp_desc, inp_shape, out_desc, out_shape)``, and
                 ``node.lower`` is forwarded to ``Potrf``.
    :param parent_state: State that contains ``node``.
    :param parent_sdfg: SDFG that contains ``parent_state``.
    :param implementation: Implementation name assigned to ``Potrf``. For
                           ``'cuSolverDn'``, cuBLAS ``Transpose`` nodes are
                           also inserted before and after it.
    :return: The expanded SDFG.
    """

    inp_desc, inp_shape, out_desc, out_shape = node.validate(parent_sdfg, parent_state)
    dtype = inp_desc.dtype

    sdfg = dace.SDFG("{l}_sdfg".format(l=node.label))

    # ``_a``/``_b`` mirror the node's connectors and keep the outer strides;
    # ``_info`` is a transient receiving Potrf's ``_res`` output.
    ain_arr = sdfg.add_array('_a', inp_shape, dtype=dtype, strides=inp_desc.strides)
    bout_arr = sdfg.add_array('_b', out_shape, dtype=dtype, strides=out_desc.strides)
    info_arr = sdfg.add_array('_info', [1], dtype=dace.int32, transient=True)
    if implementation == 'cuSolverDn':
        # Potrf operates on a transposed copy held in the transient ``_bt``
        # (presumably to convert between row- and column-major layouts --
        # TODO confirm).
        binout_arr = sdfg.add_array('_bt', inp_shape, dtype=dtype, transient=True)
    else:
        # Potrf operates directly on ``_b``.
        binout_arr = bout_arr

    state = sdfg.add_state("{l}_state".format(l=node.label))

    potrf_node = Potrf('potrf', lower=node.lower)
    potrf_node.implementation = implementation

    # Masking map: _b[i, j] = 0 where i < j, otherwise unchanged.
    # ``external_edges=True`` lets the helper create the outer ``_b`` access
    # nodes; they are retrieved below via the edges of ``me``/``mx``.
    # NOTE(review): the mask is applied regardless of ``node.lower``.
    _, me, mx = state.add_mapped_tasklet('_uzero_',
                                         dict(__i="0:%s" % out_shape[0], __j="0:%s" % out_shape[1]),
                                         dict(_inp=Memlet.simple('_b', '__i, __j')),
                                         '_out = (__i < __j) ? 0 : _inp;',
                                         dict(_out=Memlet.simple('_b', '__i, __j')),
                                         language=dace.dtypes.Language.CPP,
                                         external_edges=True)

    ain = state.add_read('_a')
    if implementation == 'cuSolverDn':
        # Dataflow: _a --AT--> _bt (binout1) --potrf--> _bt (binout2)
        #           --BT--> _b (binout3, the mask map's input node).
        binout1 = state.add_access('_bt')
        binout2 = state.add_access('_bt')
        binout3 = state.in_edges(me)[0].src
        bout = state.out_edges(mx)[0].dst
        transpose_ain = Transpose('AT', dtype=dtype)
        transpose_ain.implementation = 'cuBLAS'
        state.add_edge(ain, None, transpose_ain, '_inp', Memlet.from_array(*ain_arr))
        state.add_edge(transpose_ain, '_out', binout1, None, Memlet.from_array(*binout_arr))
        transpose_out = Transpose('BT', dtype=dtype)
        transpose_out.implementation = 'cuBLAS'
        state.add_edge(binout2, None, transpose_out, '_inp', Memlet.from_array(*binout_arr))
        state.add_edge(transpose_out, '_out', binout3, None, Memlet.from_array(*bout_arr))
    else:
        # Dataflow: _a is copied into _b (binout1); potrf writes _b
        # (binout2), which is the mask map's input node. binout3 is not
        # used further in this branch.
        binout1 = state.add_access('_b')
        binout2 = state.in_edges(me)[0].src
        binout3 = state.out_edges(mx)[0].dst
        state.add_nedge(ain, binout1, Memlet.from_array(*ain_arr))

    info = state.add_write('_info')

    # Potrf: matrix in via ``_xin``, status via ``_res``, factor via ``_xout``.
    state.add_memlet_path(binout1, potrf_node, dst_conn="_xin", memlet=Memlet.from_array(*binout_arr))
    state.add_memlet_path(potrf_node, info, src_conn="_res", memlet=Memlet.from_array(*info_arr))
    state.add_memlet_path(potrf_node, binout2, src_conn="_xout", memlet=Memlet.from_array(*binout_arr))

    return sdfg
Example #5
0
    def apply(self, sdfg):
        """Inline the matched nested SDFG node into its containing state.

        The first state of the nested SDFG is copied into the outer state.
        Global/init/exit code and constants are merged into ``sdfg``, and
        nested transients are re-registered in ``sdfg`` under new names.
        Inner data names are rewritten to the outer data their connectors
        were attached to, and the nested SDFG node is finally removed.

        :param sdfg: Top-level SDFG; ``self.state_id`` and ``self.subgraph``
                     locate the matched nested SDFG node inside it.
        """
        state: SDFGState = sdfg.nodes()[self.state_id]
        nsdfg_node = state.nodes()[self.subgraph[InlineSDFG._nested_sdfg]]
        nsdfg: SDFG = nsdfg_node.sdfg
        # Only the first state is inlined; presumably the match condition
        # guarantees the nested SDFG has a single state -- TODO confirm.
        nstate: SDFGState = nsdfg.nodes()[0]

        # Scope (if any) enclosing the nested SDFG node; inlined nodes left
        # without connections are attached to it further below.
        nsdfg_scope_entry = state.entry_node(nsdfg_node)
        nsdfg_scope_exit = (state.exit_node(nsdfg_scope_entry)
                            if nsdfg_scope_entry is not None else None)

        #######################################################
        # Collect and update top-level SDFG metadata

        # Global/init/exit code
        for loc, code in nsdfg.global_code.items():
            sdfg.append_global_code(code.code, loc)
        for loc, code in nsdfg.init_code.items():
            sdfg.append_init_code(code.code, loc)
        for loc, code in nsdfg.exit_code.items():
            sdfg.append_exit_code(code.code, loc)

        # Constants: keep the outer value on conflict, but warn about it.
        for cstname, cstval in nsdfg.constants.items():
            if cstname in sdfg.constants:
                if cstval != sdfg.constants[cstname]:
                    warnings.warn('Constant value mismatch for "%s" while '
                                  'inlining SDFG. Inner = %s != %s = outer' %
                                  (cstname, cstval, sdfg.constants[cstname]))
            else:
                sdfg.add_constant(cstname, cstval)

        # Find original source/destination edges (there is only one edge per
        # connector, according to match)
        # inputs/outputs: connector name -> outer edge;
        # input_set/output_set: outer data name -> connector name.
        inputs: Dict[str, MultiConnectorEdge] = {}
        outputs: Dict[str, MultiConnectorEdge] = {}
        input_set: Dict[str, str] = {}
        output_set: Dict[str, str] = {}
        for e in state.in_edges(nsdfg_node):
            inputs[e.dst_conn] = e
            input_set[e.data.data] = e.dst_conn
        for e in state.out_edges(nsdfg_node):
            outputs[e.src_conn] = e
            output_set[e.data.data] = e.src_conn

        # All transients become transients of the parent (if data already
        # exists, find new name)
        # Mapping from nested transient name to top-level name
        transients: Dict[str, str] = {}
        for node in nstate.nodes():
            if isinstance(node, nodes.AccessNode):
                datadesc = nsdfg.arrays[node.data]
                if node.data not in transients and datadesc.transient:
                    name = sdfg.add_datadesc('%s_%s' %
                                             (nsdfg.label, node.data),
                                             datadesc,
                                             find_new_name=True)
                    transients[node.data] = name

        # All transients of edges between code nodes are also added to parent
        for edge in nstate.edges():
            if (isinstance(edge.src, nodes.CodeNode)
                    and isinstance(edge.dst, nodes.CodeNode)):
                datadesc = nsdfg.arrays[edge.data.data]
                if edge.data.data not in transients and datadesc.transient:
                    name = sdfg.add_datadesc('%s_%s' %
                                             (nsdfg.label, edge.data.data),
                                             datadesc,
                                             find_new_name=True)
                    transients[edge.data.data] = name

        # Collect nodes to add to top-level graph
        new_incoming_edges: Dict[nodes.Node, MultiConnectorEdge] = {}
        new_outgoing_edges: Dict[nodes.Node, MultiConnectorEdge] = {}

        # Non-transient source/sink access nodes are replaced by the outer
        # edges. This assumes each of them has a same-named connector
        # (otherwise ``inputs``/``outputs`` raise KeyError).
        source_accesses = set()
        sink_accesses = set()
        for node in nstate.source_nodes():
            if (isinstance(node, nodes.AccessNode)
                    and node.data not in transients):
                new_incoming_edges[node] = inputs[node.data]
                source_accesses.add(node)
        for node in nstate.sink_nodes():
            if (isinstance(node, nodes.AccessNode)
                    and node.data not in transients):
                new_outgoing_edges[node] = outputs[node.data]
                sink_accesses.add(node)

        #######################################################
        # Add nested SDFG into top-level SDFG

        # Add nested nodes into original state
        subgraph = SubgraphView(nstate, [
            n for n in nstate.nodes()
            if n not in (source_accesses | sink_accesses)
        ])
        state.add_nodes_from(subgraph.nodes())
        for edge in subgraph.edges():
            state.add_edge(edge.src, edge.src_conn, edge.dst, edge.dst_conn,
                           edge.data)

        #######################################################
        # Replace data on inlined SDFG nodes/edges

        # Replace symbols using invocation symbol mapping
        # Two-step replacement (N -> __dacesym_N --> map[N]) to avoid clashes
        for symname, symvalue in nsdfg_node.symbol_mapping.items():
            if str(symname) != str(symvalue):
                nsdfg.replace(symname, '__dacesym_' + symname)
        for symname, symvalue in nsdfg_node.symbol_mapping.items():
            if str(symname) != str(symvalue):
                nsdfg.replace('__dacesym_' + symname, symvalue)

        # Replace data names with their top-level counterparts
        # (renamed transients, and connector names -> outer data names).
        repldict = {}
        repldict.update(transients)
        repldict.update({
            k: v.data.data
            for k, v in itertools.chain(inputs.items(), outputs.items())
        })
        for node in subgraph.nodes():
            if isinstance(node, nodes.AccessNode) and node.data in repldict:
                node.data = repldict[node.data]
        for edge in subgraph.edges():
            if edge.data.data in repldict:
                edge.data.data = repldict[edge.data.data]

        #######################################################
        # Reconnect inlined SDFG

        # If a source/sink node is one of the inputs/outputs, reconnect it,
        # replacing memlets in outgoing/incoming paths
        modified_edges = set()
        modified_edges |= self._modify_memlet_path(new_incoming_edges, nstate,
                                                   state, True)
        modified_edges |= self._modify_memlet_path(new_outgoing_edges, nstate,
                                                   state, False)

        # Modify all other internal edges pertaining to input/output nodes
        # by composing (unsqueezing) their memlets with the outer memlet.
        for node in subgraph.nodes():
            if isinstance(node, nodes.AccessNode):
                if node.data in input_set or node.data in output_set:
                    if node.data in input_set:
                        outer_edge = inputs[input_set[node.data]]
                    else:
                        outer_edge = outputs[output_set[node.data]]

                    for edge in state.all_edges(node):
                        if (edge not in modified_edges
                                and edge.data.data == node.data):
                            for e in state.memlet_tree(edge):
                                if e.data.data == node.data:
                                    e._data = helpers.unsqueeze_memlet(
                                        e.data, outer_edge.data)

        # If source/sink node is not connected to a source/destination access
        # node, and the nested SDFG is in a scope, connect to scope with empty
        # memlets
        if nsdfg_scope_entry is not None:
            for node in subgraph.nodes():
                if state.in_degree(node) == 0:
                    state.add_edge(nsdfg_scope_entry, None, node, None,
                                   Memlet())
                if state.out_degree(node) == 0:
                    state.add_edge(node, None, nsdfg_scope_exit, None,
                                   Memlet())

        # Replace nested SDFG parents with new SDFG
        for node in nstate.nodes():
            if isinstance(node, nodes.NestedSDFG):
                node.sdfg.parent = state
                node.sdfg.parent_sdfg = sdfg
                node.sdfg.parent_nsdfg_node = node

        # Remove all unused external inputs/output memlet paths, as well as
        # resulting isolated nodes
        # NOTE(review): ``inputs``/``outputs`` keys are connector names
        # (strings), while ``source_accesses``/``sink_accesses`` hold
        # AccessNode objects, so these set differences remove nothing and
        # every connector is passed to ``_remove_edge_path``. Confirm whether
        # that is intended.
        removed_in_edges = self._remove_edge_path(state,
                                                  inputs,
                                                  set(inputs.keys()) -
                                                  source_accesses,
                                                  reverse=True)
        removed_out_edges = self._remove_edge_path(state,
                                                   outputs,
                                                   set(outputs.keys()) -
                                                   sink_accesses,
                                                   reverse=False)

        # Re-add in/out edges to first/last nodes in subgraph
        # (``next`` raises StopIteration if no access node with the edge's
        # data exists in the nested state).
        order = [
            x for x in nx.topological_sort(nstate._nx)
            if isinstance(x, nodes.AccessNode)
        ]
        for edge in removed_in_edges:
            # Find first access node that refers to this edge
            node = next(n for n in order if n.data == edge.data.data)
            state.add_edge(edge.src, edge.src_conn, node, edge.dst_conn,
                           edge.data)
        for edge in removed_out_edges:
            # Find last access node that refers to this edge
            node = next(n for n in reversed(order) if n.data == edge.data.data)
            state.add_edge(node, edge.src_conn, edge.dst, edge.dst_conn,
                           edge.data)

        #######################################################
        # Remove nested SDFG node
        state.remove_node(nsdfg_node)
Example #6
0
def test_duplicate_codegen():
    """Check code generation when one access node feeds two tasklets.

    ``c_task`` writes ``D``, which is then read by both ``e_task`` and
    ``f_task``. Their results are accumulated into ``E`` and ``F``. The graph
    is assembled by hand because the Python frontend would not produce the
    node ordering this test relies on.
    """
    sdfg = dace.SDFG("dup")
    state = sdfg.add_state()

    # Tasklets: d = c, e = a + d, f = b + d.
    c_task = state.add_tasklet("c_task", inputs={"c"}, outputs={"d"}, code='d = c')
    e_task = state.add_tasklet("e_task", inputs={"a", "d"}, outputs={"e"}, code="e = a + d")
    f_task = state.add_tasklet("f_task", inputs={"b", "d"}, outputs={"f"}, code="f = b + d")

    # One single-element float32 array per name, registered in A..F order.
    descs = {}
    for name in "ABCDEF":
        _, descs[name] = sdfg.add_array(name, [1], dace.float32)

    def whole(name, **kwargs):
        return Memlet.from_array(name, descs[name], **kwargs)

    # Access nodes, created in A..F order.
    a_node, b_node, c_node = [state.add_read(name) for name in "ABC"]
    d_node = state.add_access("D")
    e_node, f_node = [state.add_write(name) for name in "EF"]

    state.add_edge(c_node, None, c_task, "c", whole("C"))
    state.add_edge(c_task, "d", d_node, None, whole("D"))

    state.add_edge(a_node, None, e_task, "a", whole("A"))
    state.add_edge(b_node, None, f_task, "b", whole("B"))
    state.add_edge(d_node, None, f_task, "d", whole("D"))
    state.add_edge(d_node, None, e_task, "d", whole("D"))

    accumulate = "lambda x, y: x + y"
    state.add_edge(e_task, "e", e_node, None, whole("E", wcr=accumulate))
    state.add_edge(f_task, "f", f_node, None, whole("F", wcr=accumulate))

    args = {name: np.array([1], dtype=np.float32) for name in "ABCD"}
    args["E"] = np.zeros_like(args["A"])
    args["F"] = np.zeros_like(args["A"])

    sdfg(**args)

    assert args["E"][0] == 2
    assert args["F"][0] == 2
Example #7
0
    def backward(
        forward_node: Node, context: BackwardContext,
        given_gradients: typing.List[typing.Optional[str]],
        required_gradients: typing.List[typing.Optional[str]]
    ) -> typing.Tuple[Node, BackwardResult]:
        """Build the reverse node for a ``Reduce`` node.

        Only sum reductions are supported. The gradient of a sum is the
        incoming gradient broadcast back over the reduced axes. The reverse
        node is therefore a nested SDFG containing one map over the forward
        input's shape, which accumulates (``x + y``) the incoming gradient
        into every element.

        :param forward_node: The forward ``Reduce`` node. Its ``wcr`` selects
                             the reduction type, and ``axes`` gives the
                             reduced axes (``None`` meaning all of them).
        :param context: Backward-pass context with the forward SDFG/state and
                        the backward state that receives the new node.
        :param given_gradients: Output connector names whose gradients are
                                given; must contain exactly one entry.
        :param required_gradients: Input connector names whose gradients are
                                   required; must contain exactly one entry.
        :return: The nested SDFG node added to ``context.backward_state`` and
                 the ``BackwardResult`` describing its connectors.
        :raises AutoDiffException: If either gradient list does not have
                                   exactly one entry, or the reduction is not
                                   a sum.
        """
        reduction_type = detect_reduction_type(forward_node.wcr)

        if len(given_gradients) != 1:
            raise AutoDiffException(
                "received invalid SDFG: reduce node {} should have exactly one output edge"
                .format(forward_node))

        if len(required_gradients) != 1:
            raise AutoDiffException(
                "received invalid SDFG: reduce node {} should have exactly one input edge"
                .format(forward_node))

        input_name = next(iter(required_gradients))
        in_desc = in_desc_with_name(forward_node, context.forward_state,
                                    context.forward_sdfg, input_name)

        output_name = next(iter(given_gradients))
        out_desc = out_desc_with_name(forward_node, context.forward_state,
                                      context.forward_sdfg, output_name)

        all_axes: typing.List[int] = list(range(len(in_desc.shape)))
        reduce_axes: typing.List[int] = (all_axes if forward_node.axes is None
                                         else forward_node.axes)
        non_reduce_axes: typing.List[int] = [
            i for i in all_axes if i not in reduce_axes
        ]

        result = BackwardResult.empty()

        if reduction_type is not dtypes.ReductionType.Sum:
            raise AutoDiffException(
                "Unsupported reduction type '{}'".format(reduction_type))

        # For a sum, scatter the incoming gradient across the reduced axes.
        suffix = str(reduction_type).replace(".", "_") + "_"
        sdfg = SDFG("_reverse_" + suffix)
        state = sdfg.add_state()

        rev_input_conn_name = "input_gradient"
        rev_output_conn_name = "output_gradient"
        result.required_grad_names[output_name] = rev_output_conn_name
        result.given_grad_names[input_name] = rev_input_conn_name

        # The reverse input has the forward output's shape, and the reverse
        # output has the forward input's shape.
        sdfg.add_array(rev_input_conn_name,
                       shape=out_desc.shape,
                       dtype=out_desc.dtype)
        sdfg.add_array(rev_output_conn_name,
                       shape=in_desc.shape,
                       dtype=in_desc.dtype)

        if forward_node.axes is None or not non_reduce_axes:
            # Every axis is reduced: index the gradient at 0, exactly as in
            # the ``axes is None`` case. Joining an empty list of kept axes
            # would otherwise produce the invalid subset string "".
            grad_in_subset = "0"
        else:
            grad_in_subset = ",".join("i" + str(i) for i in non_reduce_axes)

        state.add_mapped_tasklet(
            "_distribute_grad_" + suffix,
            {"i" + str(i): "0:{}".format(shape)
             for i, shape in enumerate(in_desc.shape)},
            {"__in": Memlet.simple(rev_input_conn_name, grad_in_subset)},
            "__out = __in",
            {
                "__out":
                Memlet.simple(rev_output_conn_name,
                              ",".join("i" + str(i) for i in all_axes),
                              wcr_str="lambda x, y: x + y")
            },
            external_edges=True)

        return context.backward_state.add_nested_sdfg(
            sdfg, None, {rev_input_conn_name},
            {rev_output_conn_name}), result
Example #8
0
    def apply(self, _, sdfg: sd.SDFG):
        """Move a sequential for-loop inside the map contained in its body.

        The map's contents are nested into a new SDFG, and a loop with the
        same iteration variable, start, end and step is recreated inside that
        nested SDFG. The outer guard state is then removed: its incoming
        edges are redirected to the body, and its outgoing edges now leave
        the body unconditionally. Memlets are re-propagated afterwards, and
        ``RefineNestedAccess`` is applied to the new nested SDFG.

        :param _: Unused.
        :param sdfg: SDFG containing the loop (``self.loop_guard`` and
                     ``self.loop_begin``).
        """
        # Obtain loop information
        guard: sd.SDFGState = self.loop_guard
        body: sd.SDFGState = self.loop_begin

        # Obtain iteration variable, range, and stride
        itervar, (start, end, step), _ = find_for_loop(sdfg, guard, body)

        # Assumes ``step`` can be compared with 0 (numeric, or a symbolic
        # expression with known sign) -- TODO confirm.
        forward_loop = step > 0

        # Locate the map in the loop body. If several maps exist, the last
        # one seen wins; if none exists, map_entry/map_exit stay unbound.
        # Presumably the match guarantees exactly one -- TODO confirm.
        for node in body.nodes():
            if isinstance(node, nodes.MapEntry):
                map_entry = node
            if isinstance(node, nodes.MapExit):
                map_exit = node

        # nest map's content in sdfg
        map_subgraph = body.scope_subgraph(map_entry, include_entry=False, include_exit=False)
        nsdfg = helpers.nest_state_subgraph(sdfg, body, map_subgraph, full_data=True)

        # replicate loop in nested sdfg
        # (``end`` is used inclusively: ``<=`` / ``>=``).
        new_before, new_guard, new_after = nsdfg.sdfg.add_loop(
            before_state=None,
            loop_state=nsdfg.sdfg.nodes()[0],
            loop_end_state=None,
            after_state=None,
            loop_var=itervar,
            initialize_expr=f'{start}',
            condition_expr=f'{itervar} <= {end}' if forward_loop else f'{itervar} >= {end}',
            increment_expr=f'{itervar} + {step}' if forward_loop else f'{itervar} - {abs(step)}')

        # remove outer loop
        # First identify the new inner loop's edges, which receive the
        # interstate assignments of the outer loop's edges.
        before_guard_edge = nsdfg.sdfg.edges_between(new_before, new_guard)[0]
        for e in nsdfg.sdfg.out_edges(new_guard):
            if e.dst is new_after:
                guard_after_edge = e
            else:
                guard_body_edge = e

        # Drop all edges into/out of the body (loop entry and back-edge),
        # keeping the guard->body assignments on the inner guard->body edge.
        for body_inedge in sdfg.in_edges(body):
            if body_inedge.src is guard:
                guard_body_edge.data.assignments.update(body_inedge.data.assignments)
            sdfg.remove_edge(body_inedge)
        for body_outedge in sdfg.out_edges(body):
            sdfg.remove_edge(body_outedge)
        # Redirect edges entering the guard straight to the body; their
        # assignments move to the inner before->guard edge.
        for guard_inedge in sdfg.in_edges(guard):
            before_guard_edge.data.assignments.update(guard_inedge.data.assignments)
            guard_inedge.data.assignments = {}
            sdfg.add_edge(guard_inedge.src, body, guard_inedge.data)
            sdfg.remove_edge(guard_inedge)
        # Edges leaving the guard now leave the body unconditionally; their
        # assignments move to the matching inner loop edge.
        for guard_outedge in sdfg.out_edges(guard):
            if guard_outedge.dst is body:
                guard_body_edge.data.assignments.update(guard_outedge.data.assignments)
            else:
                guard_after_edge.data.assignments.update(guard_outedge.data.assignments)
            guard_outedge.data.condition = CodeBlock("1")
            sdfg.add_edge(body, guard_outedge.dst, guard_outedge.data)
            sdfg.remove_edge(guard_outedge)
        sdfg.remove_node(guard)
        # The iteration variable is now defined only inside the nested SDFG.
        if itervar in nsdfg.symbol_mapping:
            del nsdfg.symbol_mapping[itervar]
        if itervar in sdfg.symbols:
            del sdfg.symbols[itervar]

        # Add missing data/symbols
        # Free symbols of the nested SDFG are mapped from outer symbols, or,
        # if they name outer data, passed in as new read connectors routed
        # through the map entry.
        for s in nsdfg.sdfg.free_symbols:
            if s in nsdfg.symbol_mapping:
                continue
            if s in sdfg.symbols:
                nsdfg.symbol_mapping[s] = s
            elif s in sdfg.arrays:
                desc = sdfg.arrays[s]
                access = body.add_access(s)
                conn = nsdfg.sdfg.add_datadesc(s, copy.deepcopy(desc))
                nsdfg.sdfg.arrays[s].transient = False
                nsdfg.add_in_connector(conn)
                body.add_memlet_path(access, map_entry, nsdfg, memlet=Memlet.from_array(s, desc), dst_conn=conn)
            else:
                raise NotImplementedError(f"Free symbol {s} is neither a symbol nor data.")
        # Drop mappings for symbols the nested SDFG no longer uses.
        to_delete = set()
        for s in nsdfg.symbol_mapping:
            if s not in nsdfg.sdfg.free_symbols:
                to_delete.add(s)
        for s in to_delete:
            del nsdfg.symbol_mapping[s]

        # propagate scope for correct volumes
        scope_tree = ScopeTree(map_entry, map_exit)
        scope_tree.parent = ScopeTree(None, None)
        # The first execution helps remove appearances of symbols
        # that are now defined only in the nested SDFG in memlets.
        propagation.propagate_memlets_scope(sdfg, body, scope_tree)

        # Remove outer symbols that nothing references any more.
        for s in to_delete:
            if helpers.is_symbol_unused(sdfg, s):
                sdfg.remove_symbol(s)

        # Narrow the nested SDFG's memlets to the accesses it actually makes.
        from dace.transformation.interstate import RefineNestedAccess
        transformation = RefineNestedAccess()
        transformation.setup_match(sdfg, 0, sdfg.node_id(body), {RefineNestedAccess.nsdfg: body.node_id(nsdfg)}, 0)
        transformation.apply(body, sdfg)

        # Second propagation for refined accesses.
        propagation.propagate_memlets_scope(sdfg, body, scope_tree)
Example #9
0
def _make_sdfg(node, parent_state, parent_sdfg, implementation):
    """Expand a matrix-inversion library node into ``Getrf`` then ``Getri``.

    ``Getrf`` factorizes the working matrix and produces pivots in the
    transient ``_pivots``. ``Getri`` then consumes the factorization and the
    pivots to write the result. The working matrix is ``_ain`` itself when
    ``node.overwrite`` is set; otherwise ``_ain`` is first copied into
    ``_aout``. Both status outputs go to the transient ``_info``.

    :param node: Library node being expanded; ``node.validate`` returns the
                 input (and, unless ``node.overwrite``, output) shape, dtype
                 and strides, followed by ``n``.
    :param parent_state: State that contains ``node``.
    :param parent_sdfg: SDFG that contains ``parent_state``.
    :param implementation: Implementation name for both LAPACK-style nodes.
    :return: The expanded SDFG.
    """
    arr_desc = node.validate(parent_sdfg, parent_state)
    if not node.overwrite:
        (in_shape, in_dtype, in_strides, out_shape, out_dtype, out_strides,
         n) = arr_desc
    else:
        in_shape, in_dtype, in_strides, n = arr_desc

    sdfg = dace.SDFG("{l}_sdfg".format(l=node.label))

    # ``work_arr`` is the (name, descriptor) pair both nodes operate on.
    ain_arr = sdfg.add_array('_ain', in_shape, dtype=in_dtype, strides=in_strides)
    work_arr = ain_arr
    if not node.overwrite:
        work_arr = sdfg.add_array('_aout', out_shape, dtype=out_dtype, strides=out_strides)
    ipiv_arr = sdfg.add_array('_pivots', [n], dtype=dace.int32, transient=True)
    info_arr = sdfg.add_array('_info', [1], dtype=dace.int32, transient=True)

    state = sdfg.add_state("{l}_state".format(l=node.label))

    getrf_node = Getrf('getrf')
    getrf_node.implementation = implementation
    getri_node = Getri('getri')
    getri_node.implementation = implementation

    work_name = '_ain' if node.overwrite else '_aout'
    if not node.overwrite:
        original = state.add_read('_ain')
    work_in = state.add_read(work_name)
    work_mid = state.add_access(work_name)
    work_out = state.add_write(work_name)
    if not node.overwrite:
        # Copy the input into the output array, which is then inverted in place.
        state.add_nedge(original, work_in, Memlet.from_array(*ain_arr))

    ipiv = state.add_access('_pivots')
    info_getrf = state.add_write('_info')
    info_getri = state.add_write('_info')

    # (source, destination, src_conn, dst_conn, array) in insertion order.
    wiring = (
        (work_in, getrf_node, None, "_xin", work_arr),
        (getrf_node, info_getrf, "_res", None, info_arr),
        (getrf_node, ipiv, "_ipiv", None, ipiv_arr),
        (getrf_node, work_mid, "_xout", None, work_arr),
        (work_mid, getri_node, None, "_xin", work_arr),
        (ipiv, getri_node, None, "_ipiv", ipiv_arr),
        (getri_node, info_getri, "_res", None, info_arr),
        (getri_node, work_out, "_xout", None, work_arr),
    )
    for src, dst, src_conn, dst_conn, arr in wiring:
        state.add_memlet_path(src,
                              dst,
                              src_conn=src_conn,
                              dst_conn=dst_conn,
                              memlet=Memlet.from_array(*arr))

    return sdfg
Example #10
0
def _make_sdfg_getrs(node, parent_state, parent_sdfg, implementation):
    """Expand ``node`` into a nested SDFG that solves ``A X = I`` via GETRF + GETRS.

    The input matrix is LU-factorized by a ``Getrf`` library node. Its
    factors and pivot vector feed a ``Getrs`` node whose right-hand side is
    filled with the identity matrix by a mapped tasklet. The solution
    written back to the right-hand side is therefore the inverse of the
    input.

    If ``node.overwrite`` is set, the solution is computed in a transient
    ``_b`` and then copied into ``_ain``. Otherwise ``_ain`` is first copied
    into a transient ``_ainout``, so the input is left untouched, and the
    solution is written to ``_aout``.

    :param node: Library node being expanded (provides ``validate``,
                 ``overwrite`` and ``label``).
    :param parent_state: State that contains ``node``.
    :param parent_sdfg: SDFG that contains ``parent_state``.
    :param implementation: Implementation assigned to both the Getrf and the
                           Getrs node.
    :return: The newly built nested SDFG.
    """
    descriptors = node.validate(parent_sdfg, parent_state)
    in_place = node.overwrite
    if in_place:
        in_shape, in_dtype, in_strides, n = descriptors
    else:
        (in_shape, in_dtype, in_strides, out_shape, out_dtype, out_strides,
         n) = descriptors

    sdfg = dace.SDFG("{l}_sdfg".format(l=node.label))

    # Every ``add_array`` result is a (name, descriptor) pair that is later
    # splatted into ``Memlet.from_array``.
    src_arr = sdfg.add_array('_ain',
                             in_shape,
                             dtype=in_dtype,
                             strides=in_strides)
    if in_place:
        lu_arr = src_arr
        rhs_arr = sdfg.add_array('_b', [n, n], dtype=in_dtype, transient=True)
    else:
        lu_arr = sdfg.add_array('_ainout', [n, n],
                                dtype=in_dtype,
                                transient=True)
        rhs_arr = sdfg.add_array('_aout',
                                 out_shape,
                                 dtype=out_dtype,
                                 strides=out_strides)
    pivots_arr = sdfg.add_array('_pivots', [n],
                                dtype=dace.int32,
                                transient=True)
    status_arr = sdfg.add_array('_info', [1], dtype=dace.int32, transient=True)

    state = sdfg.add_state("{l}_state".format(l=node.label))

    factorize = Getrf('getrf')
    factorize.implementation = implementation
    solve = Getrs('getrs')
    solve.implementation = implementation

    if in_place:
        lu_src = state.add_read('_ain')
        lu_factors = state.add_access('_ain')
        result_dst = state.add_write('_ain')
        rhs_name = '_b'
        rhs_dst = state.add_write('_b')
        # Copy the solution held in ``_b`` back into the input array.
        state.add_nedge(rhs_dst, result_dst, Memlet.from_array(*lu_arr))
    else:
        original = state.add_read('_ain')
        lu_src = state.add_read('_ainout')
        lu_factors = state.add_access('_ainout')
        # Factorize a copy so that ``_ain`` keeps its contents.
        state.add_nedge(original, lu_src, Memlet.from_array(*src_arr))
        rhs_name = '_aout'
        rhs_dst = state.add_access('_aout')

    # Initialize the right-hand side with the identity matrix.
    _, _, eye_exit = state.add_mapped_tasklet(
        '_eye_',
        dict(i="0:n", j="0:n"), {},
        '_out = (i == j) ? 1 : 0;',
        dict(_out=Memlet.simple(rhs_name, 'i, j')),
        language=dace.dtypes.Language.CPP,
        external_edges=True)
    rhs_src = state.out_edges(eye_exit)[0].dst

    pivots = state.add_access('_pivots')
    status_lu = state.add_write('_info')
    status_solve = state.add_write('_info')

    # (source, destination, src_conn, dst_conn, array covered by the memlet)
    connections = (
        (lu_src, factorize, None, "_xin", lu_arr),
        (factorize, status_lu, "_res", None, status_arr),
        (factorize, pivots, "_ipiv", None, pivots_arr),
        (factorize, lu_factors, "_xout", None, lu_arr),
        (lu_factors, solve, None, "_a", lu_arr),
        (rhs_src, solve, None, "_rhs_in", rhs_arr),
        (pivots, solve, None, "_ipiv", pivots_arr),
        (solve, status_solve, "_res", None, status_arr),
        (solve, rhs_dst, "_rhs_out", None, rhs_arr),
    )
    for src, dst, src_conn, dst_conn, arr in connections:
        state.add_memlet_path(src,
                              dst,
                              src_conn=src_conn,
                              dst_conn=dst_conn,
                              memlet=Memlet.from_array(*arr))

    return sdfg
Exemple #11
0
    def apply(self, state: SDFGState, sdfg: SDFG) -> nodes.AccessNode:
        """Reroute the memlet trees of ``self.access`` through transient streams.

        For every edge leaving (``expr_index == 0``: read) or entering
        (otherwise: write) the matched access node, the memlets along its
        memlet path are replaced by accesses to a newly created stream. The
        edge is then re-attached to a new access node for that stream.

        Next, read or write "components" are created. One component is built
        per group of edges whose innermost memlets canonicalize to the same
        expression. It consists of a copy of the surrounding map nest and a
        tasklet that moves data between the array and the group's stream(s).
        In the read case, an edge is moved into its own component if its
        destination is reachable from another edge's destination in the same
        group. This avoids deadlocks.

        With ``self.use_memory_buffering`` set, the following also happens:

        * A ``Gearbox`` node is inserted between a vector-typed stream on the
          memory side and a scalar-typed stream on the kernel side.
        * The array is re-typed as a vector of
          ``int(memory_buffering_target_bytes / dtype.bytes)`` elements.
        * Its shape and strides, the last range of the component maps, and
          the last-dimension subsets of matching edges in the state returned
          by ``get_post_state`` are scaled down by that width.

        :param state: State containing the matched access node.
        :param sdfg: SDFG containing ``state``.
        :return: The list of stream access nodes created for the components
                 (the return annotation names a single ``AccessNode``).
        """
        dnode: nodes.AccessNode = self.access
        if self.expr_index == 0:
            edges = state.out_edges(dnode)
        else:
            edges = state.in_edges(dnode)

        # To understand how many components we need to create, all map ranges
        # throughout memlet paths must match exactly. We thus create a
        # dictionary of unique ranges
        mapping: Dict[Tuple[subsets.Range],
                      List[gr.MultiConnectorEdge[mm.Memlet]]] = defaultdict(
                          list)
        ranges = {}
        for edge in edges:
            mpath = state.memlet_path(edge)
            ranges[edge] = _collect_map_ranges(state, mpath)
            mapping[tuple(r[1] for r in ranges[edge])].append(edge)

        # Collect all edges with the same memory access pattern. Each entry
        # is an (innermost edge copy, outermost edge) pair.
        components_to_create: Dict[
            Tuple[symbolic.SymbolicType],
            List[Tuple[gr.MultiConnectorEdge[mm.Memlet],
                       gr.MultiConnectorEdge[mm.Memlet]]]] = defaultdict(list)
        for edges_with_same_range in mapping.values():
            for edge in edges_with_same_range:
                # Get memlet path and innermost edge
                mpath = state.memlet_path(edge)
                innermost_edge = copy.deepcopy(mpath[-1] if self.expr_index ==
                                               0 else mpath[0])

                # Store memlets of the same access in the same component
                expr = _canonicalize_memlet(innermost_edge.data, ranges[edge])
                components_to_create[expr].append((innermost_edge, edge))
        components = list(components_to_create.values())

        # Split out components that have dependencies between them to avoid
        # deadlocks
        if self.expr_index == 0:
            ccs_to_add = []
            for i, component in enumerate(components):
                edges_to_remove = set()
                for cedge in component:
                    if any(
                            nx.has_path(state.nx, o[1].dst, cedge[1].dst)
                            for o in component if o is not cedge):
                        ccs_to_add.append([cedge])
                        edges_to_remove.add(cedge)
                if edges_to_remove:
                    components[i] = [
                        c for c in component if c not in edges_to_remove
                    ]
            # ``has_path(G, u, u)`` is True, so edges that share a destination
            # are all split out and can leave their original component empty.
            # Drop such components; they would fail at ``component[0]`` below.
            components = [c for c in components if c]
            components.extend(ccs_to_add)
        # End of split

        desc = sdfg.arrays[dnode.data]

        # Create new streams of shape 1
        streams = {}
        mpaths = {}
        for edge in edges:

            if self.use_memory_buffering:

                # Add gearbox
                total_size = edge.data.volume
                vector_size = int(self.memory_buffering_target_bytes /
                                  desc.dtype.bytes)

                if not is_int(sdfg.arrays[dnode.data].shape[-1]):
                    warnings.warn(
                        "Using the MemoryBuffering transformation is potential unsafe since {sym} is not an integer. There should be no issue if {sym} % {vec} == 0"
                        .format(sym=sdfg.arrays[dnode.data].shape[-1],
                                vec=vector_size))

                for i in sdfg.arrays[dnode.data].strides:
                    if not is_int(i):
                        warnings.warn(
                            "Using the MemoryBuffering transformation is potential unsafe since {sym} is not an integer. There should be no issue if {sym} % {vec} == 0"
                            .format(sym=i, vec=vector_size))

                # The memory-side stream carries vectors, the kernel-side
                # stream carries scalars.
                if self.expr_index == 0:  # Read
                    gearbox_input_type = dtypes.vector(desc.dtype, vector_size)
                    gearbox_output_type = desc.dtype
                    gearbox_read_volume = total_size / vector_size
                    gearbox_write_volume = total_size
                else:  # Write
                    gearbox_input_type = desc.dtype
                    gearbox_output_type = dtypes.vector(
                        desc.dtype, vector_size)
                    gearbox_read_volume = total_size
                    gearbox_write_volume = total_size / vector_size

                input_gearbox_name, input_gearbox_newdesc = sdfg.add_stream(
                    "gearbox_input",
                    gearbox_input_type,
                    buffer_size=self.buffer_size,
                    storage=self.storage,
                    transient=True,
                    find_new_name=True)

                output_gearbox_name, output_gearbox_newdesc = sdfg.add_stream(
                    "gearbox_output",
                    gearbox_output_type,
                    buffer_size=self.buffer_size,
                    storage=self.storage,
                    transient=True,
                    find_new_name=True)

                read_to_gearbox = state.add_read(input_gearbox_name)
                write_from_gearbox = state.add_write(output_gearbox_name)

                gearbox = Gearbox(total_size / vector_size)

                state.add_node(gearbox)

                state.add_memlet_path(read_to_gearbox,
                                      gearbox,
                                      dst_conn="from_memory",
                                      memlet=Memlet(
                                          input_gearbox_name + "[0]",
                                          volume=gearbox_read_volume))
                state.add_memlet_path(gearbox,
                                      write_from_gearbox,
                                      src_conn="to_kernel",
                                      memlet=Memlet(
                                          output_gearbox_name + "[0]",
                                          volume=gearbox_write_volume))

                # ``streams[edge]`` is the stream used by the read/write
                # component; ``name``/``newdesc`` describe the stream that
                # replaces the original memlet path.
                if self.expr_index == 0:
                    streams[edge] = input_gearbox_name
                    name = output_gearbox_name
                    newdesc = output_gearbox_newdesc
                else:
                    streams[edge] = output_gearbox_name
                    name = input_gearbox_name
                    newdesc = input_gearbox_newdesc

            else:
                # Qualify name to avoid name clashes if memory interfaces are not decoupled for Xilinx
                stream_name = "stream_" + dnode.data
                name, newdesc = sdfg.add_stream(stream_name,
                                                desc.dtype,
                                                buffer_size=self.buffer_size,
                                                storage=self.storage,
                                                transient=True,
                                                find_new_name=True)
                streams[edge] = name

                # Add these such that we can easily use output_gearbox_name and input_gearbox_name without using if statements
                output_gearbox_name = name
                input_gearbox_name = name

            mpath = state.memlet_path(edge)
            mpaths[edge] = mpath

            # Replace memlets in path with stream access
            for e in mpath:
                e.data = mm.Memlet(data=name,
                                   subset='0',
                                   other_subset=e.data.other_subset)
                if isinstance(e.src, nodes.NestedSDFG):
                    e.data.dynamic = True
                    _streamify_recursive(e.src, e.src_conn, newdesc)
                if isinstance(e.dst, nodes.NestedSDFG):
                    e.data.dynamic = True
                    _streamify_recursive(e.dst, e.dst_conn, newdesc)

            # Replace access node and memlet tree with one access
            if self.expr_index == 0:
                replacement = state.add_read(output_gearbox_name)
                state.remove_edge(edge)
                state.add_edge(replacement, edge.src_conn, edge.dst,
                               edge.dst_conn, edge.data)
            else:
                replacement = state.add_write(input_gearbox_name)
                state.remove_edge(edge)
                state.add_edge(edge.src, edge.src_conn, replacement,
                               edge.dst_conn, edge.data)

        if self.use_memory_buffering:

            arrname = str(self.access)
            vector_size = int(self.memory_buffering_target_bytes /
                              desc.dtype.bytes)

            # Vectorize access to global array.
            dtype = sdfg.arrays[arrname].dtype
            sdfg.arrays[arrname].dtype = dtypes.vector(dtype, vector_size)
            new_shape = list(sdfg.arrays[arrname].shape)
            contigidx = sdfg.arrays[arrname].strides.index(1)
            new_shape[contigidx] /= vector_size
            try:
                new_shape[contigidx] = int(new_shape[contigidx])
            except TypeError:
                # Symbolic extents cannot be converted; keep them symbolic.
                pass
            sdfg.arrays[arrname].shape = new_shape

            # Change strides. The last dimension is skipped since its stride
            # is assumed to always be 1.
            new_strides: List = list(sdfg.arrays[arrname].strides)
            for i in range(len(new_strides) - 1):
                new_strides[i] = new_strides[i] / vector_size
            sdfg.arrays[arrname].strides = new_strides

            post_state = get_post_state(sdfg, state)

            if post_state is not None:
                # Change subset in the post state such that the correct amount of memory is copied back from the device
                for e in post_state.edges():
                    if e.data.data == self.access.data:
                        new_subset = list(e.data.subset)
                        i, j, k = new_subset[-1]
                        new_subset[-1] = (i, (j + 1) / vector_size - 1, k)
                        e.data = mm.Memlet(data=str(e.src),
                                           subset=subsets.Range(new_subset))

        def _divide_last_range(map_ranges, vec):
            """Divide the last map range of ``map_ranges`` (in place) by ``vec``.

            The inclusive end ``e`` becomes ``(e + 1) / vec - 1``. A warning is
            emitted when ``e`` is not an integer, since the division is only
            exact if the extent is a multiple of ``vec``.
            """
            param, (start, end, step) = map_ranges[-1]
            if not is_int(end):
                # Report the symbolic part of the end (e.g. ``N`` in
                # ``N - 1``). Fall back to the whole expression when it has
                # fewer than two arguments (e.g. a bare symbol), instead of
                # raising IndexError while formatting the warning.
                end_args = getattr(end, 'args', ())
                sym = end_args[1] if len(end_args) > 1 else end
                warnings.warn(
                    "Using the MemoryBuffering transformation is potential unsafe since {sym} is not an integer. There should be no issue if {sym} % {vec} == 0"
                    .format(sym=sym, vec=vec))
            map_ranges[-1] = (param, (start, (end + 1) / vec - 1, step))

        # Make read/write components
        ionodes = []
        for component in components:

            # Pick the first edge as the edge to make the component from
            innermost_edge, outermost_edge = component[0]
            mpath = mpaths[outermost_edge]
            mapname = streams[outermost_edge]
            innermost_edge.data.other_subset = None

            # Get edge data and streams
            if self.expr_index == 0:
                opname = 'read'
                path = [e.dst for e in mpath[:-1]]
                rmemlets = [(dnode, '__inp', innermost_edge.data)]
                wmemlets = []
                for i, (_, edge) in enumerate(component):
                    name = streams[edge]
                    ionode = state.add_write(name)
                    ionodes.append(ionode)
                    wmemlets.append(
                        (ionode, '__out%d' % i, mm.Memlet(data=name,
                                                          subset='0')))
                code = '\n'.join('__out%d = __inp' % i
                                 for i in range(len(component)))
            else:
                # More than one input stream might mean a data race, so we only
                # address the first one in the tasklet code
                if len(component) > 1:
                    warnings.warn(
                        f'More than one input found for the same index for {dnode.data}'
                    )
                opname = 'write'
                path = [state.entry_node(e.src) for e in reversed(mpath[1:])]
                wmemlets = [(dnode, '__out', innermost_edge.data)]
                rmemlets = []
                for i, (_, edge) in enumerate(component):
                    name = streams[edge]
                    ionode = state.add_read(name)
                    ionodes.append(ionode)
                    rmemlets.append(
                        (ionode, '__inp%d' % i, mm.Memlet(data=name,
                                                          subset='0')))
                code = '__out = __inp0'

            # Create map structure for read/write component
            maps = []
            for entry in path:
                map_obj: nodes.Map = entry.map

                new_ranges = [(p, (r[0], r[1], r[2]))
                              for p, r in zip(map_obj.params, map_obj.range)]

                # Change ranges of map
                if self.use_memory_buffering:
                    # Start index of each dimension of the innermost access
                    edge_subset = [
                        a_tuple[0]
                        for a_tuple in list(innermost_edge.data.subset)
                    ]

                    # Shrink the last map range if the map's last parameter
                    # indexes the last dimension, either directly...
                    if isinstance(edge_subset[-1], symbol) and str(
                            edge_subset[-1]) == map_obj.params[-1]:
                        _divide_last_range(new_ranges, vector_size)

                    # ...or as one term of a sum (e.g. ``i + 1``)
                    elif isinstance(edge_subset[-1], sympy.core.add.Add):
                        for arg in edge_subset[-1].args:
                            if isinstance(
                                    arg,
                                    symbol) and str(arg) == map_obj.params[-1]:
                                _divide_last_range(new_ranges, vector_size)

                maps.append(
                    state.add_map(f'__s{opname}_{mapname}', new_ranges,
                                  map_obj.schedule))
            tasklet = state.add_tasklet(
                f'{opname}_{mapname}',
                {m[1]
                 for m in rmemlets},
                {m[1]
                 for m in wmemlets},
                code,
            )
            for node, cname, memlet in rmemlets:
                state.add_memlet_path(node,
                                      *(me for me, _ in maps),
                                      tasklet,
                                      dst_conn=cname,
                                      memlet=memlet)
            for node, cname, memlet in wmemlets:
                state.add_memlet_path(tasklet,
                                      *(mx for _, mx in reversed(maps)),
                                      node,
                                      src_conn=cname,
                                      memlet=memlet)

        return ionodes
Exemple #12
0
    def apply(self, sdfg: SDFG):
        """Nest the matched states into a kernel SDFG launched by a persistent GPU map.

        The states of the matched subgraph are moved into a new SDFG. That SDFG
        is placed inside a single-iteration map (``ignore='0'``) scheduled as
        ``ScheduleType.GPU_Persistent`` in a new "launch" state of ``sdfg``.

        * If ``self.include_in_assignment`` is set and the edge entering the
          subgraph carries assignments, a new entry guard state is added to
          the kernel. The assignments move onto its outgoing edge, so they are
          executed inside the kernel.
        * If a state follows the subgraph, a new exit guard state is inserted
          before it and moved into the kernel.
        * Streams and register-storage arrays accessed by the kernel are moved
          into the kernel SDFG. Any other accessed data is copied as a
          non-transient, default-storage descriptor and passed as a kernel
          argument. It becomes an input if every access node of it is
          read-only, otherwise an output.

        :param sdfg: The SDFG to transform in place.
        """
        subgraph = self.subgraph_view(sdfg)

        entry_states_in, entry_states_out = self.get_entry_states(
            sdfg, subgraph)
        _, exit_states_out = self.get_exit_states(sdfg, subgraph)

        entry_state_in = entry_states_in.pop()
        entry_state_out = entry_states_out.pop() \
            if len(entry_states_out) > 0 else None
        exit_state_out = exit_states_out.pop() \
            if len(exit_states_out) > 0 else None

        # Prefix shared by every label created below: '' or '<kernel_prefix>_'
        label_prefix = (self.kernel_prefix +
                        '_' if self.kernel_prefix != '' else '')

        launch_state = None
        entry_guard_state = None
        exit_guard_state = None

        # generate entry guard state if needed
        if self.include_in_assignment and entry_state_out is not None:
            entry_edge = sdfg.edges_between(entry_state_out, entry_state_in)[0]
            if len(entry_edge.data.assignments) > 0:
                entry_guard_state = sdfg.add_state(
                    label='{}kernel_entry_guard'.format(label_prefix))
                sdfg.add_edge(entry_state_out, entry_guard_state,
                              InterstateEdge(entry_edge.data.condition))
                sdfg.add_edge(
                    entry_guard_state, entry_state_in,
                    InterstateEdge(None, entry_edge.data.assignments))
                sdfg.remove_edge(entry_edge)

                # Update SubgraphView
                new_node_list = subgraph.nodes()
                new_node_list.append(entry_guard_state)
                subgraph = SubgraphView(sdfg, new_node_list)

                launch_state = sdfg.add_state_before(
                    entry_guard_state,
                    label='{}kernel_launch'.format(label_prefix))

        # generate exit guard state
        if exit_state_out is not None:
            exit_guard_state = sdfg.add_state_before(
                exit_state_out,
                label='{}kernel_exit_guard'.format(label_prefix))

            # Update SubgraphView
            new_node_list = subgraph.nodes()
            new_node_list.append(exit_guard_state)
            subgraph = SubgraphView(sdfg, new_node_list)

            if launch_state is None:
                launch_state = sdfg.add_state_before(
                    exit_state_out,
                    label='{}kernel_launch'.format(label_prefix))

        # No launch state exists yet. There is no entry guard and no state
        # after the kernel, so create a stand-alone launch state; a preceding
        # state, if any, is connected to it below. ``entry_state_in`` is a
        # state of the kernel and is never None, so only the absence of a
        # following state can be asserted here.
        if launch_state is None:
            assert exit_state_out is None
            launch_state = sdfg.add_state(
                label='{}kernel_launch'.format(label_prefix))

        # create sdfg for kernel and fill it with states and edges from
        # subgraph; the sdfg will be nested at the end
        kernel_sdfg = SDFG('{}kernel'.format(label_prefix))

        edges = subgraph.edges()
        for edge in edges:
            kernel_sdfg.add_edge(edge.src, edge.dst, edge.data)

        # Setting entry node in nested SDFG if no entry guard was created
        if entry_guard_state is None:
            kernel_sdfg.start_state = kernel_sdfg.node_id(entry_state_in)

        for state in subgraph:
            state.parent = kernel_sdfg

        # remove the now nested nodes from the outer sdfg and make sure the
        # launch state is properly connected to remaining states
        sdfg.remove_nodes_from(subgraph.nodes())

        if entry_state_out is not None \
                and len(sdfg.edges_between(entry_state_out, launch_state)) == 0:
            sdfg.add_edge(entry_state_out, launch_state, InterstateEdge())

        if exit_state_out is not None \
                and len(sdfg.edges_between(launch_state, exit_state_out)) == 0:
            sdfg.add_edge(launch_state, exit_state_out, InterstateEdge())

        # Handle data for kernel
        kernel_data = set(node.data for state in kernel_sdfg
                          for node in state.nodes()
                          if isinstance(node, nodes.AccessNode))

        # move Streams and Register data into the nested SDFG
        # normal data will be added as kernel argument
        kernel_args = []
        for data in kernel_data:
            if (isinstance(sdfg.arrays[data], dace.data.Stream) or
                (isinstance(sdfg.arrays[data], dace.data.Array)
                 and sdfg.arrays[data].storage == StorageType.Register)):
                kernel_sdfg.add_datadesc(data, sdfg.arrays[data])
                del sdfg.arrays[data]
            else:
                copy_desc = copy.deepcopy(sdfg.arrays[data])
                copy_desc.transient = False
                copy_desc.storage = StorageType.Default
                kernel_sdfg.add_datadesc(data, copy_desc)
                kernel_args.append(data)

        # read only data will be passed as input, writeable data will be passed
        # as 'output' otherwise kernel cannot write to data
        kernel_args_read = set()
        kernel_args_write = set()
        for data in kernel_args:
            data_accesses_read_only = [
                node.access == dtypes.AccessType.ReadOnly
                for state in kernel_sdfg for node in state
                if isinstance(node, nodes.AccessNode) and node.data == data
            ]
            if all(data_accesses_read_only):
                kernel_args_read.add(data)
            else:
                kernel_args_write.add(data)

        # Kernel SDFG is complete at this point
        if self.validate:
            kernel_sdfg.validate()

        # Filling launch state with nested SDFG, map and access nodes
        map_entry, map_exit = launch_state.add_map(
            '{}kernel_launch_map'.format(label_prefix),
            dict(ignore='0'),
            schedule=ScheduleType.GPU_Persistent,
        )

        nested_sdfg = launch_state.add_nested_sdfg(
            kernel_sdfg,
            sdfg,
            kernel_args_read,
            kernel_args_write,
        )

        # Create and connect read only data access nodes
        for arg in kernel_args_read:
            read_node = launch_state.add_read(arg)
            launch_state.add_memlet_path(read_node,
                                         map_entry,
                                         nested_sdfg,
                                         dst_conn=arg,
                                         memlet=Memlet.from_array(
                                             arg, sdfg.arrays[arg]))

        # Create and connect writable data access nodes
        for arg in kernel_args_write:
            write_node = launch_state.add_write(arg)
            launch_state.add_memlet_path(nested_sdfg,
                                         map_exit,
                                         write_node,
                                         src_conn=arg,
                                         memlet=Memlet.from_array(
                                             arg, sdfg.arrays[arg]))

        # Transformation is done
        if self.validate:
            sdfg.validate()
Exemple #13
0
    def apply(self, sdfg: SDFG):
        state: SDFGState = sdfg.nodes()[self.state_id]
        nsdfg_node = state.nodes()[self.subgraph[InlineSDFG._nested_sdfg]]
        nsdfg: SDFG = nsdfg_node.sdfg
        nstate: SDFGState = nsdfg.nodes()[0]

        if nsdfg_node.schedule is not dtypes.ScheduleType.Default:
            infer_types.set_default_schedule_and_storage_types(
                nsdfg, nsdfg_node.schedule)

        nsdfg_scope_entry = state.entry_node(nsdfg_node)
        nsdfg_scope_exit = (state.exit_node(nsdfg_scope_entry)
                            if nsdfg_scope_entry is not None else None)

        #######################################################
        # Collect and update top-level SDFG metadata

        # Global/init/exit code
        for loc, code in nsdfg.global_code.items():
            sdfg.append_global_code(code.code, loc)
        for loc, code in nsdfg.init_code.items():
            sdfg.append_init_code(code.code, loc)
        for loc, code in nsdfg.exit_code.items():
            sdfg.append_exit_code(code.code, loc)

        # Constants
        for cstname, cstval in nsdfg.constants.items():
            if cstname in sdfg.constants:
                if cstval != sdfg.constants[cstname]:
                    warnings.warn('Constant value mismatch for "%s" while '
                                  'inlining SDFG. Inner = %s != %s = outer' %
                                  (cstname, cstval, sdfg.constants[cstname]))
            else:
                sdfg.add_constant(cstname, cstval)

        # Find original source/destination edges (there is only one edge per
        # connector, according to match)
        inputs: Dict[str, MultiConnectorEdge] = {}
        outputs: Dict[str, MultiConnectorEdge] = {}
        input_set: Dict[str, str] = {}
        output_set: Dict[str, str] = {}
        for e in state.in_edges(nsdfg_node):
            inputs[e.dst_conn] = e
            input_set[e.data.data] = e.dst_conn
        for e in state.out_edges(nsdfg_node):
            outputs[e.src_conn] = e
            output_set[e.data.data] = e.src_conn

        # Access nodes that need to be reshaped
        reshapes: Set(str) = set()
        for aname, array in nsdfg.arrays.items():
            if array.transient:
                continue
            edge = None
            if aname in inputs:
                edge = inputs[aname]
                if len(array.shape) > len(edge.data.subset):
                    reshapes.add(aname)
                    continue
            if aname in outputs:
                edge = outputs[aname]
                if len(array.shape) > len(edge.data.subset):
                    reshapes.add(aname)
                    continue
            if edge is not None and not InlineSDFG._check_strides(
                    array.strides, sdfg.arrays[edge.data.data].strides,
                    edge.data, nsdfg_node):
                reshapes.add(aname)

        # Replace symbols using invocation symbol mapping
        # Two-step replacement (N -> __dacesym_N --> map[N]) to avoid clashes
        for symname, symvalue in nsdfg_node.symbol_mapping.items():
            if str(symname) != str(symvalue):
                nsdfg.replace(symname, '__dacesym_' + symname)
        for symname, symvalue in nsdfg_node.symbol_mapping.items():
            if str(symname) != str(symvalue):
                nsdfg.replace('__dacesym_' + symname, symvalue)

        # All transients become transients of the parent (if data already
        # exists, find new name)
        # Mapping from nested transient name to top-level name
        transients: Dict[str, str] = {}
        for node in nstate.nodes():
            if isinstance(node, nodes.AccessNode):
                datadesc = nsdfg.arrays[node.data]
                if node.data not in transients and datadesc.transient:
                    name = sdfg.add_datadesc('%s_%s' %
                                             (nsdfg.label, node.data),
                                             datadesc,
                                             find_new_name=True)
                    transients[node.data] = name

        # All transients of edges between code nodes are also added to parent
        for edge in nstate.edges():
            if (isinstance(edge.src, nodes.CodeNode)
                    and isinstance(edge.dst, nodes.CodeNode)):
                if edge.data.data is not None:
                    datadesc = nsdfg.arrays[edge.data.data]
                    if edge.data.data not in transients and datadesc.transient:
                        name = sdfg.add_datadesc('%s_%s' %
                                                 (nsdfg.label, edge.data.data),
                                                 datadesc,
                                                 find_new_name=True)
                        transients[edge.data.data] = name

        # Collect nodes to add to top-level graph
        new_incoming_edges: Dict[nodes.Node, MultiConnectorEdge] = {}
        new_outgoing_edges: Dict[nodes.Node, MultiConnectorEdge] = {}

        source_accesses = set()
        sink_accesses = set()
        for node in nstate.source_nodes():
            if (isinstance(node, nodes.AccessNode)
                    and node.data not in transients
                    and node.data not in reshapes):
                new_incoming_edges[node] = inputs[node.data]
                source_accesses.add(node)
        for node in nstate.sink_nodes():
            if (isinstance(node, nodes.AccessNode)
                    and node.data not in transients
                    and node.data not in reshapes):
                new_outgoing_edges[node] = outputs[node.data]
                sink_accesses.add(node)

        #######################################################
        # Replace data on inlined SDFG nodes/edges

        # Replace data names with their top-level counterparts
        repldict = {}
        repldict.update(transients)
        repldict.update({
            k: v.data.data
            for k, v in itertools.chain(inputs.items(), outputs.items())
        })

        # Add views whenever reshapes are necessary
        for dname in reshapes:
            desc = nsdfg.arrays[dname]
            # To avoid potential confusion, rename protected __return keyword
            if dname.startswith('__return'):
                newname = f'{nsdfg.name}_ret{dname[8:]}'
            else:
                newname = dname
            newname, _ = sdfg.add_view(newname,
                                       desc.shape,
                                       desc.dtype,
                                       storage=desc.storage,
                                       strides=desc.strides,
                                       offset=desc.offset,
                                       debuginfo=desc.debuginfo,
                                       allow_conflicts=desc.allow_conflicts,
                                       total_size=desc.total_size,
                                       alignment=desc.alignment,
                                       may_alias=desc.may_alias,
                                       find_new_name=True)
            repldict[dname] = newname

        for node in nstate.nodes():
            if isinstance(node, nodes.AccessNode) and node.data in repldict:
                node.data = repldict[node.data]
        for edge in nstate.edges():
            if edge.data.data in repldict:
                edge.data.data = repldict[edge.data.data]

        # Add extra access nodes for out/in view nodes
        for node in nstate.nodes():
            if isinstance(node, nodes.AccessNode) and node.data in reshapes:
                if nstate.in_degree(node) > 0 and nstate.out_degree(node) > 0:
                    # Such a node has to be in the output set
                    edge = outputs[node.data]

                    # Redirect outgoing edges through access node
                    out_edges = list(nstate.out_edges(node))
                    anode = nstate.add_access(edge.data.data)
                    vnode = nstate.add_access(node.data)
                    nstate.add_nedge(node, anode, edge.data)
                    nstate.add_nedge(anode, vnode, edge.data)
                    for e in out_edges:
                        nstate.remove_edge(e)
                        nstate.add_edge(vnode, e.src_conn, e.dst, e.dst_conn,
                                        e.data)

        #######################################################
        # Add nested SDFG into top-level SDFG

        # Add nested nodes into original state
        subgraph = SubgraphView(nstate, [
            n for n in nstate.nodes()
            if n not in (source_accesses | sink_accesses)
        ])
        state.add_nodes_from(subgraph.nodes())
        for edge in subgraph.edges():
            state.add_edge(edge.src, edge.src_conn, edge.dst, edge.dst_conn,
                           edge.data)

        #######################################################
        # Reconnect inlined SDFG

        # If a source/sink node is one of the inputs/outputs, reconnect it,
        # replacing memlets in outgoing/incoming paths
        modified_edges = set()
        modified_edges |= self._modify_memlet_path(new_incoming_edges, nstate,
                                                   state, True)
        modified_edges |= self._modify_memlet_path(new_outgoing_edges, nstate,
                                                   state, False)

        # Reshape: add connections to viewed data
        self._modify_reshape_data(reshapes, repldict, inputs, nstate, state,
                                  True)
        self._modify_reshape_data(reshapes, repldict, outputs, nstate, state,
                                  False)

        # Modify all other internal edges pertaining to input/output nodes
        for node in subgraph.nodes():
            if isinstance(node, nodes.AccessNode):
                if node.data in input_set or node.data in output_set:
                    if node.data in input_set:
                        outer_edge = inputs[input_set[node.data]]
                    else:
                        outer_edge = outputs[output_set[node.data]]

                    for edge in state.all_edges(node):
                        if (edge not in modified_edges
                                and edge.data.data == node.data):
                            for e in state.memlet_tree(edge):
                                if e.data.data == node.data:
                                    e._data = helpers.unsqueeze_memlet(
                                        e.data, outer_edge.data)

        # If source/sink node is not connected to a source/destination access
        # node, and the nested SDFG is in a scope, connect to scope with empty
        # memlets
        if nsdfg_scope_entry is not None:
            for node in subgraph.nodes():
                if state.in_degree(node) == 0:
                    state.add_edge(nsdfg_scope_entry, None, node, None,
                                   Memlet())
                if state.out_degree(node) == 0:
                    state.add_edge(node, None, nsdfg_scope_exit, None,
                                   Memlet())

        # Replace nested SDFG parents with new SDFG
        for node in nstate.nodes():
            if isinstance(node, nodes.NestedSDFG):
                node.sdfg.parent = state
                node.sdfg.parent_sdfg = sdfg
                node.sdfg.parent_nsdfg_node = node

        # Remove all unused external inputs/output memlet paths, as well as
        # resulting isolated nodes
        removed_in_edges = self._remove_edge_path(state,
                                                  inputs,
                                                  set(inputs.keys()) -
                                                  source_accesses,
                                                  reverse=True)
        removed_out_edges = self._remove_edge_path(state,
                                                   outputs,
                                                   set(outputs.keys()) -
                                                   sink_accesses,
                                                   reverse=False)

        # Re-add in/out edges to first/last nodes in subgraph
        order = [
            x for x in nx.topological_sort(nstate._nx)
            if isinstance(x, nodes.AccessNode)
        ]
        for edge in removed_in_edges:
            # Find first access node that refers to this edge
            node = next(n for n in order if n.data == edge.data.data)
            state.add_edge(edge.src, edge.src_conn, node, edge.dst_conn,
                           edge.data)
        for edge in removed_out_edges:
            # Find last access node that refers to this edge
            node = next(n for n in reversed(order) if n.data == edge.data.data)
            state.add_edge(node, edge.src_conn, edge.dst, edge.dst_conn,
                           edge.data)

        #######################################################
        # Remove nested SDFG node
        state.remove_node(nsdfg_node)
    def apply_pass(
        self, sdfg: SDFG,
        pipeline_results: Dict[str,
                               Any]) -> Optional[Dict[SDFGState, Set[str]]]:
        """
        Removes unreachable dataflow throughout SDFG states.

        States are processed in reverse topological order. Within each state, nodes are visited in reverse
        topological order and tested with ``_is_node_dead``, which receives the nodes already found dead and the
        set of data written in this state but not read by any state reachable from it. Dead nodes are removed
        together with their incoming memlet paths. Scope exits that would lose all their inputs are kept connected
        with empty memlets, and access nodes left without edges are removed too. Finally, the read set of the state
        in ``access_sets`` is recomputed so that predecessor states (processed later) see the reduced reads.

        :param sdfg: The SDFG to modify.
        :param pipeline_results: If in the context of a ``Pipeline``, a dictionary that is populated with prior Pass
                                 results as ``{Pass subclass name: returned object from pass}``. If not run in a
                                 pipeline, an empty dictionary is expected. This pass indexes the
                                 ``'StateReachability'`` and ``'AccessSets'`` entries directly, so both must be
                                 present; the ``'AccessSets'`` mapping is updated in place.
        :return: A dictionary mapping states to the removed nodes (dead nodes plus access nodes isolated by the
                 removal), or None if nothing changed. NOTE(review): the ``Set[str]`` in the return annotation does
                 not match the node objects that are actually stored.
        """
        # Depends on the following analysis passes:
        #  * State reachability
        #  * Read/write access sets per state
        reachable: Dict[SDFGState,
                        Set[SDFGState]] = pipeline_results['StateReachability']
        # state -> (names read in the state, names written in the state)
        access_sets: Dict[SDFGState,
                          Tuple[Set[str],
                                Set[str]]] = pipeline_results['AccessSets']
        # Removed nodes per state (node objects, see the return note above)
        result: Dict[SDFGState, Set[nodes.Node]] = defaultdict(set)

        # Traverse SDFG backwards
        for state in reversed(list(cfg.stateorder_topological_sort(sdfg))):
            #############################################
            # Analysis
            #############################################

            # Compute states where memory will no longer be read
            writes = access_sets[state][1]
            descendants = reachable[state]
            descendant_reads = set().union(*(access_sets[succ][0]
                                             for succ in descendants))
            # Data written here that no reachable successor state reads
            no_longer_used: Set[str] = set(data for data in writes
                                           if data not in descendant_reads)

            # Compute dead nodes
            dead_nodes: List[nodes.Node] = []

            # Propagate deadness backwards within a state: successors are
            # visited first, so ``dead_nodes`` already holds them when a
            # node is tested.
            for node in sdutil.dfs_topological_sort(state, reverse=True):
                if self._is_node_dead(node, sdfg, state, dead_nodes,
                                      no_longer_used):
                    dead_nodes.append(node)

            # Scope exit nodes are only dead if their corresponding entry nodes are
            live_nodes = set()
            for node in dead_nodes:
                if isinstance(node, nodes.ExitNode) and state.entry_node(
                        node) not in dead_nodes:
                    live_nodes.add(node)
            dead_nodes = dtypes.deduplicate(
                [n for n in dead_nodes if n not in live_nodes])

            if not dead_nodes:
                continue

            # Remove nodes while preserving scopes
            scopes_to_reconnect: Set[nodes.Node] = set()
            for node in state.nodes():
                # Look for scope exits that will be disconnected
                if isinstance(node, nodes.ExitNode) and node not in dead_nodes:
                    if any(n in dead_nodes for n in state.predecessors(node)):
                        scopes_to_reconnect.add(node)

            # Two types of scope disconnections may occur:
            # 1. Two scope exits will no longer be connected
            # 2. A predecessor of dead nodes is in a scope and not connected to its exit
            # Case (1) is taken care of by ``remove_memlet_path``
            # Case (2) is handled below
            # Reconnect scopes
            if scopes_to_reconnect:
                schildren = state.scope_children()
                for exit_node in scopes_to_reconnect:
                    entry_node = state.entry_node(exit_node)
                    for node in schildren[entry_node]:
                        if node is exit_node:
                            continue
                        # A nested scope is represented by its exit node here
                        if isinstance(node, nodes.EntryNode):
                            node = state.exit_node(node)
                        # If node will be disconnected from exit node, add an empty memlet
                        if all(succ in dead_nodes
                               for succ in state.successors(node)):
                            state.add_nedge(node, exit_node, Memlet())

            #############################################
            # Removal
            #############################################
            # Nested SDFG node -> output connectors whose memlet paths were removed
            predecessor_nsdfgs: Dict[nodes.NestedSDFG,
                                     Set[str]] = defaultdict(set)
            for node in dead_nodes:
                # Remove memlet paths and connectors pertaining to dead nodes
                # NOTE(review): removing a memlet path may also remove edges that
                # are still pending in this loop (later in-edges of this node or
                # of another dead node); confirm remove_memlet_path/memlet_tree
                # tolerate edges that were already removed.
                for e in state.in_edges(node):
                    mtree = state.memlet_tree(e)
                    for leaf in mtree.leaves():
                        # Keep track of predecessors of removed nodes for connector pruning
                        if isinstance(leaf.src, nodes.NestedSDFG):
                            predecessor_nsdfgs[leaf.src].add(leaf.src_conn)
                        state.remove_memlet_path(leaf)

                # Remove the node itself as necessary
                state.remove_node(node)

            result[state].update(dead_nodes)

            # Remove isolated access nodes after elimination
            access_nodes = set(state.data_nodes())
            for node in access_nodes:
                if state.degree(node) == 0:
                    state.remove_node(node)
                    result[state].add(node)

            # Prune now-dead connectors
            for node, dead_conns in predecessor_nsdfgs.items():
                for conn in dead_conns:
                    # If removed connector belonged to a nested SDFG, and no other input connector shares name,
                    # make nested data transient (dead dataflow elimination would remove internally as necessary)
                    if conn not in node.in_connectors:
                        node.sdfg.arrays[conn].transient = True

            # Update read sets for the predecessor states to reuse: only
            # surviving access nodes with outgoing edges count as reads.
            access_nodes -= result[state]
            access_node_names = set(n.data for n in access_nodes
                                    if state.out_degree(n) > 0)
            access_sets[state] = (access_node_names, access_sets[state][1])

        # An empty defaultdict is falsy, so "nothing removed" returns None
        return result or None
Exemple #15
0
def insert_sdfg_element(sdfg_str, type, parent_uuid, edge_a_uuid):
    """
    Add a new element to a serialized SDFG and return the re-serialized SDFG.

    :param sdfg_str: Serialized SDFG, as accepted by ``load_sdfg_from_json``.
    :param type: Kind of element to insert: ``'SDFGState'``, ``'AccessNode'``, ``'Map'``, ``'Consume'``,
                 ``'Tasklet'``, ``'NestedSDFG'``, ``'LibraryNode'`` or ``'Edge'``. A library node is requested as
                 ``'LibraryNode|<dotted.path.to.NodeClass>'``.
    :param parent_uuid: UUID of the graph element that receives the new element. For ``'Edge'`` it is the edge
                        destination (a node or a state).
    :param edge_a_uuid: For ``'Edge'``, the UUID of the edge source; ignored otherwise.
    :return: ``{'sdfg': sdfg.to_json(), 'uuid': [...]}`` on success. ``uuid`` stays ``'error'`` if ``type`` is not
             recognized. Returns ``{'error': {'message': ..., 'details': ...}}`` when a library node type is
             missing or cannot be located, or when edge endpoints have the wrong kind.
    :raises ValueError: If an edge is requested but its starting element cannot be found.
    """

    def _error(message, details):
        # Error payload returned to the caller instead of raising
        return {
            'error': {
                'message': message,
                'details': details,
            },
        }

    sdfg_answer = load_sdfg_from_json(sdfg_str)
    sdfg = sdfg_answer['sdfg']
    uuid = 'error'
    ret = find_graph_element_by_uuid(sdfg, parent_uuid)
    parent = ret['element']

    # A library node is requested as 'LibraryNode|<dotted.path.to.Class>'
    libname = None
    if type is not None and isinstance(type, str):
        split_type = type.split('|')
        if len(split_type) == 2:
            type = split_type[0]
            libname = split_type[1]

    if type == 'SDFGState':
        if parent is None:
            parent = sdfg
        elif isinstance(parent, nodes.NestedSDFG):
            parent = parent.sdfg
        state = parent.add_state()
        uuid = [get_uuid(state)]
    elif type == 'AccessNode':
        # An access node needs an existing data container; create a
        # placeholder array if the owning SDFG has none.
        arrays = list(parent.parent.arrays.keys())
        if not arrays:
            parent.parent.add_array('tmp', [1], dtype=dtypes.float64)
            arrays = list(parent.parent.arrays.keys())
        node = parent.add_access(arrays[0])
        uuid = [get_uuid(node, parent)]
    elif type == 'Map':
        map_entry, map_exit = parent.add_map('map', dict(i='0:1'))
        uuid = [get_uuid(map_entry, parent), get_uuid(map_exit, parent)]
    elif type == 'Consume':
        consume_entry, consume_exit = parent.add_consume('consume', ('i', '1'))
        uuid = [get_uuid(consume_entry, parent), get_uuid(consume_exit, parent)]
    elif type == 'Tasklet':
        tasklet = parent.add_tasklet(
            name='placeholder',
            inputs={'in'},
            outputs={'out'},
            code='')
        uuid = [get_uuid(tasklet, parent)]
    elif type == 'NestedSDFG':
        sub_sdfg = SDFG('nested_sdfg')
        sub_sdfg.add_array('in', [1], dtypes.float32)
        sub_sdfg.add_array('out', [1], dtypes.float32)

        # The parent SDFG is the one that owns the target state, which may
        # itself be nested (i.e., not necessarily the top-level ``sdfg``).
        nsdfg = parent.add_nested_sdfg(sub_sdfg, parent.parent, {'in'}, {'out'})
        uuid = [get_uuid(nsdfg, parent)]
    elif type == 'LibraryNode':
        if libname is None:
            return _error('Failed to add library node',
                          'Must provide a valid library node type')
        # NOTE(security): pydoc.locate imports whatever module the
        # caller-supplied path names; only expose this to trusted clients.
        libnode_class = pydoc.locate(libname)
        if not callable(libnode_class):
            # Unknown path (locate returned None) or not instantiable
            return _error('Failed to add library node',
                          'Must provide a valid library node type')
        libnode = libnode_class()
        parent.add_node(libnode)
        uuid = [get_uuid(libnode, parent)]
    elif type == 'Edge':
        edge_start_ret = find_graph_element_by_uuid(sdfg, edge_a_uuid)
        edge_start = edge_start_ret['element']
        edge_parent = edge_start_ret['parent']
        if edge_start is None:
            raise ValueError('No edge starting point provided')

        if edge_parent is None:
            edge_parent = sdfg

        if isinstance(edge_parent, SDFGState):
            # Dataflow edge between two nodes of the same state
            if not (isinstance(edge_start, nodes.Node) and
                    isinstance(parent, nodes.Node)):
                return _error('Failed to add edge',
                              'Must connect two nodes or two states')
            memlet = Memlet()
            edge_parent.add_edge(edge_start, None, parent, None, memlet)
        elif isinstance(edge_parent, SDFG):
            # Interstate edge between two states
            if not (isinstance(edge_start, SDFGState) and
                    isinstance(parent, SDFGState)):
                return _error('Failed to add edge',
                              'Must connect two nodes or two states')
            isedge = InterstateEdge()
            edge_parent.add_edge(edge_start, parent, isedge)
        uuid = ['NONE']

    old_meta = disable_save_metadata()
    try:
        new_sdfg_str = sdfg.to_json()
    finally:
        # Restore the previous metadata setting even if serialization fails
        restore_save_metadata(old_meta)

    return {
        'sdfg': new_sdfg_str,
        'uuid': uuid,
    }
def make_sdfg(specialize):
    """
    Build the parallel FPGA histogram SDFG.

    Three states are chained by interstate edges: a copy-to-FPGA state, a
    ``compute`` state and a copy-to-host state. The compute state holds:

    * a read module that streams ``A_device`` through ``A_val`` into the
      ``A_pipes`` streams,
    * the nested compute SDFG, inside an unrolled map over ``p in 0:P``,
    * a write module that reduces the ``hist_pipes`` streams into
      ``hist_device``.

    :param specialize: If truthy, the SDFG name also embeds the values of
                       ``H`` and ``W`` in addition to ``P``.
    :return: The constructed SDFG.
    """
    if specialize:
        label = f"histogram_fpga_parallel_{P.get()}_{H.get()}x{W.get()}"
    else:
        label = f"histogram_fpga_parallel_{P.get()}"
    sdfg = SDFG(label)

    to_device = make_copy_to_fpga_state(sdfg)

    compute = sdfg.add_state("compute")

    # Keyword arguments shared by every on-chip transient in this state
    on_chip = dict(transient=True, storage=StorageType.FPGA_Local)

    # ---- Compute module: unrolled instances of the nested compute SDFG ----
    kernel_sdfg = make_compute_nested_sdfg(compute)
    kernel = compute.add_nested_sdfg(kernel_sdfg, sdfg, {"A_pipe_in"},
                                     {"hist_pipe_out"})
    a_pipes_sink = compute.add_stream("A_pipes", dtype, shape=(P, ), **on_chip)
    a_pipes_src = compute.add_stream("A_pipes", dtype, shape=(P, ), **on_chip)
    hist_pipes_sink = compute.add_stream("hist_pipes", itype, shape=(P, ),
                                         **on_chip)
    lanes_entry, lanes_exit = compute.add_map(
        "unroll_compute", {"p": "0:P"},
        schedule=dace.ScheduleType.FPGA_Device,
        unroll=True)
    compute.add_memlet_path(lanes_entry, a_pipes_src, memlet=EmptyMemlet())
    compute.add_memlet_path(hist_pipes_sink, lanes_exit, memlet=EmptyMemlet())
    kernel_in = Memlet.simple(a_pipes_src, "p", num_accesses="W*H")
    compute.add_memlet_path(a_pipes_src, kernel, dst_conn="A_pipe_in",
                            memlet=kernel_in)
    kernel_out = Memlet.simple(hist_pipes_sink, "p", num_accesses="num_bins")
    compute.add_memlet_path(kernel, hist_pipes_sink, src_conn="hist_pipe_out",
                            memlet=kernel_out)

    # ---- Read module: vectorized reads of A_device fanned out to pipes ----
    a_global = compute.add_array("A_device", (H, W), dtype, transient=True,
                                 storage=dace.dtypes.StorageType.FPGA_Global)
    rd_entry, rd_exit = compute.add_map("read_map", {"h": "0:H", "w": "0:W:P"},
                                        schedule=ScheduleType.FPGA_Device)
    a_vec = compute.add_array("A_val", (P, ), dtype, **on_chip)
    rd_lanes_entry, rd_lanes_exit = compute.add_map(
        "read_unroll", {"p": "0:P"},
        schedule=ScheduleType.FPGA_Device,
        unroll=True)
    reader = compute.add_tasklet("read", {"A_in"}, {"A_pipe"},
                                 "A_pipe = A_in[p]")
    burst = Memlet(a_vec,
                   num_accesses=1,
                   subset=Indices(["0"]),
                   vector_length=P.get(),
                   other_subset=Indices(["h", "w"]))
    compute.add_memlet_path(a_global, rd_entry, a_vec, memlet=burst)
    lane_read = Memlet.simple(a_vec, "0", veclen=P.get(), num_accesses=1)
    compute.add_memlet_path(a_vec, rd_lanes_entry, reader, dst_conn="A_in",
                            memlet=lane_read)
    compute.add_memlet_path(reader, rd_lanes_exit, rd_exit, a_pipes_sink,
                            src_conn="A_pipe",
                            memlet=Memlet.simple(a_pipes_sink, "p"))

    # ---- Write module: reduce per-lane histograms into hist_device ----
    hist_pipes_src = compute.add_stream("hist_pipes", itype, shape=(P, ),
                                        **on_chip)
    hist_global = compute.add_array(
        "hist_device", (num_bins, ), itype, transient=True,
        storage=dace.dtypes.StorageType.FPGA_Global)
    mrg_entry, mrg_exit = compute.add_map("merge", {"nb": "0:num_bins"},
                                          schedule=ScheduleType.FPGA_Device)
    summer = compute.add_reduce("lambda a, b: a + b", (0, ), "0",
                                schedule=ScheduleType.FPGA_Device)
    gather = Memlet.simple(hist_pipes_src, "0:P", num_accesses=P)
    compute.add_memlet_path(hist_pipes_src, mrg_entry, summer, memlet=gather)
    scatter = dace.memlet.Memlet.simple(hist_global, "nb")
    compute.add_memlet_path(summer, mrg_exit, hist_global, memlet=scatter)

    to_host = make_copy_to_host_state(sdfg)

    sdfg.add_edge(to_device, compute, dace.graph.edges.InterstateEdge())
    sdfg.add_edge(compute, to_host, dace.graph.edges.InterstateEdge())

    return sdfg
Exemple #17
0
    def apply(self, graph: SDFGState, sdfg: SDFG) -> nodes.MapEntry:
        """
        Tile the matched map across the threads of a GPU warp.

        The method proceeds in four steps:

        1. A ``warp_tile`` map over ``__tid`` in ``0:warp_size`` with the
           ``GPU_ThreadBlock`` schedule is inserted directly inside
           ``self.mapentry``.
        2. Every immediate internal map is tiled on its last dimension, unless
           that extent is known to be smaller than the warp size. Its step is
           multiplied by ``warp_size``, its end is reduced by ``__tid``, and its
           parameter is offset by ``__tid`` in the scope body.
        3. If ``self.replicate_maps`` is set, a tiled map whose single output is
           a transient consumed by several edges is replicated once per extra
           consumer, each replica writing its own copy of the data.
        4. Single-element WCR outputs whose destination is not ``GPU_Global``
           are reduced with ``dace::warpReduce``.

        Finally, the new scope is passed to ``xfh.nest_state_subgraph``.

        :param graph: The state containing ``self.mapentry``.
        :param sdfg: The SDFG containing ``graph``.
        :return: The entry node of the new ``warp_tile`` map.
        :raises NotImplementedError: If a WCR output uses a custom reduction.
        """
        me = self.mapentry

        # Add new map within map
        mx = graph.exit_node(me)
        new_me, new_mx = graph.add_map('warp_tile',
                                       dict(__tid=f'0:{self.warp_size}'),
                                       dtypes.ScheduleType.GPU_ThreadBlock)
        __tid = symbolic.pystr_to_symbolic('__tid')
        for e in graph.out_edges(me):
            xfh.reconnect_edge_through_map(graph, e, new_me, True)
        for e in graph.in_edges(mx):
            xfh.reconnect_edge_through_map(graph, e, new_mx, False)

        # Stride and offset all internal maps
        maps_to_stride = xfh.get_internal_scopes(graph, new_me, immediate=True)
        for nstate, nmap in maps_to_stride:
            nsdfg = nstate.parent
            nsdfg_node = nsdfg.parent_nsdfg_node

            # Map cannot be partitioned across a warp. ``== True`` (rather than
            # truthiness) presumably avoids calling bool() on a symbolic
            # comparison; an undecided comparison leads to tiling the map.
            if (nmap.range.size()[-1] < self.warp_size) == True:
                continue

            # Expose __tid to the nested SDFG that owns this map
            if nsdfg is not sdfg and nsdfg_node is not None:
                nsdfg_node.symbol_mapping['__tid'] = __tid
                if '__tid' not in nsdfg.symbols:
                    nsdfg.add_symbol('__tid', dtypes.int32)
            # Stride the last dimension by the warp size and offset it by __tid
            nmap.range[-1] = (nmap.range[-1][0], nmap.range[-1][1] - __tid,
                              nmap.range[-1][2] * self.warp_size)
            subgraph = nstate.scope_subgraph(nmap)
            subgraph.replace(nmap.params[-1], f'{nmap.params[-1]} + __tid')
            inner_map_exit = nstate.exit_node(nmap)
            # If requested, replicate maps with multiple dependent maps
            if self.replicate_maps:
                destinations = [
                    nstate.memlet_path(edge)[-1].dst
                    for edge in nstate.out_edges(inner_map_exit)
                ]

                # Transformation will not replicate map with more than one
                # output. The single destination must also be a transient
                # access node inside the new warp scope.
                if len(destinations) == 1:
                    dst = destinations[0]
                    if (isinstance(dst, nodes.AccessNode)
                            and xfh.contained_in(nstate, dst, new_me)
                            and nsdfg.arrays[dst.data].transient):
                        # One replica of the map per additional consumer edge
                        for edge in nstate.out_edges(dst)[1:]:
                            rep_subgraph = xfh.replicate_scope(
                                nsdfg, nstate, subgraph)
                            rep_edge = nstate.out_edges(
                                rep_subgraph.sink_nodes()[0])[0]
                            # Add copy of data. The descriptor lives in the SDFG
                            # that owns nstate, which may be a nested SDFG.
                            newdesc = copy.deepcopy(nsdfg.arrays[dst.data])
                            newname = nsdfg.add_datadesc(dst.data,
                                                         newdesc,
                                                         find_new_name=True)
                            newaccess = nstate.add_access(newname)
                            # Redirect edges
                            xfh.redirect_edge(nstate,
                                              rep_edge,
                                              new_dst=newaccess,
                                              new_data=newname)
                            xfh.redirect_edge(nstate,
                                              edge,
                                              new_src=newaccess,
                                              new_data=newname)

            # If has WCR, add warp-collaborative reduction on outputs
            for out_edge in nstate.out_edges(inner_map_exit):
                dst = nstate.memlet_path(out_edge)[-1].dst
                if not xfh.contained_in(nstate, dst, new_me):
                    # Skip edges going out of map
                    continue
                if dst.desc(nsdfg).storage == dtypes.StorageType.GPU_Global:
                    # Skip outputs that reside in global memory.
                    # NOTE(review): confirm that global (not shared) memory is
                    # the storage meant to be excluded here.
                    continue
                if out_edge.data.wcr is not None:
                    ctype = nsdfg.arrays[out_edge.data.data].dtype.ctype
                    redtype = detect_reduction_type(out_edge.data.wcr)
                    if redtype == dtypes.ReductionType.Custom:
                        raise NotImplementedError
                    # e.g. 'ReductionType.Sum' -> 'dace::ReductionType::Sum'
                    credtype = ('dace::ReductionType::' +
                                str(redtype)[str(redtype).find('.') + 1:])

                    # One element: tasklet
                    if out_edge.data.subset.num_elements() == 1:
                        # Add local access between thread-local and warp reduction
                        name = nsdfg._find_new_name(out_edge.data.data)
                        nsdfg.add_scalar(
                            name,
                            nsdfg.arrays[out_edge.data.data].dtype,
                            transient=True)

                        # Initialize thread-local to global value (moved into a
                        # separate state by state_fission).
                        # NOTE(review): every lane starts from the global value
                        # before warpReduce combines the lanes; for
                        # non-idempotent reductions (e.g. Sum), confirm the
                        # initial value is not counted once per lane.
                        read = nstate.add_read(out_edge.data.data)
                        write = nstate.add_write(name)
                        edge = nstate.add_nedge(read, write,
                                                copy.deepcopy(out_edge.data))
                        edge.data.wcr = None
                        xfh.state_fission(nsdfg,
                                          SubgraphView(nstate, [read, write]))

                        # Redirect the map output into the thread-local scalar
                        newnode = nstate.add_access(name)
                        nstate.remove_edge(out_edge)
                        edge = nstate.add_edge(out_edge.src, out_edge.src_conn,
                                               newnode, None,
                                               copy.deepcopy(out_edge.data))
                        for e in nstate.memlet_path(edge):
                            e.data.data = name
                            e.data.subset = subsets.Range([(0, 0, 1)])

                        # Reduce across the warp and write the result without WCR
                        wrt = nstate.add_tasklet(
                            'warpreduce', {'__a'}, {'__out'},
                            f'__out = dace::warpReduce<{credtype}, {ctype}>::reduce(__a);',
                            dtypes.Language.CPP)
                        nstate.add_edge(newnode, None, wrt, '__a',
                                        Memlet(name))
                        out_edge.data.wcr = None
                        nstate.add_edge(wrt, '__out', out_edge.dst, None,
                                        out_edge.data)
                    else:  # More than one element: mapped tasklet
                        # Could be a parallel summation
                        # TODO(later): Check if reduction
                        continue
            # End of WCR to warp reduction

        # Make nested SDFG out of new scope
        xfh.nest_state_subgraph(sdfg, graph,
                                graph.scope_subgraph(new_me, False, False))

        return new_me
Exemple #18
0
def _make_sdfg_getrs(node, parent_state, parent_sdfg, implementation):
    """Build a nested SDFG that expands ``node`` into a Getrf + Getrs pipeline.

    The generated state does the following:

    * copies ``_ain`` into the transient work buffer ``_ainout``;
    * copies ``_bin`` into the transient work buffer ``_binout``;
    * runs a ``Getrf`` library node that reads ``_ainout`` and writes back to
      ``_ainout`` plus the pivot array ``_pivots`` and status ``_info``;
    * runs a ``Getrs`` library node that consumes ``_ainout``, ``_pivots``
      and ``_binout`` and writes its result back into ``_binout``;
    * copies ``_binout`` into ``_bout``.

    For the ``'cuSolverDn'`` implementation, the copies are cuBLAS ``Transpose``
    library nodes and ``_binout`` is allocated as ``[rhs, n]``. For every other
    implementation they are plain copy edges and ``_binout`` is ``[n, rhs]``.

    :param node: The library node being expanded. Its ``validate`` method
                 supplies the shapes, dtypes and strides of the operands,
                 together with ``n`` and ``rhs``.
    :param parent_state: State containing ``node`` (forwarded to ``validate``).
    :param parent_sdfg: SDFG containing ``parent_state`` (forwarded to
                        ``validate``).
    :param implementation: Implementation name assigned to the ``Getrf`` and
                           ``Getrs`` nodes. ``'cuSolverDn'`` additionally
                           selects the transposing data path.
    :return: The newly constructed ``dace.SDFG``.
    """
    arr_desc = node.validate(parent_sdfg, parent_state)
    (ain_shape, ain_dtype, ain_strides, bin_shape, bin_dtype, bin_strides,
     out_shape, out_dtype, out_strides, n, rhs) = arr_desc

    sdfg = dace.SDFG("{l}_sdfg".format(l=node.label))

    # Non-transient arrays correspond to the data connected to the expanded
    # node. Transient arrays are work buffers local to this nested SDFG.
    ain_arr = sdfg.add_array('_ain',
                             ain_shape,
                             dtype=ain_dtype,
                             strides=ain_strides)
    ainout_arr = sdfg.add_array('_ainout', [n, n],
                                dtype=ain_dtype,
                                transient=True)
    bin_arr = sdfg.add_array('_bin',
                             bin_shape,
                             dtype=bin_dtype,
                             strides=bin_strides)
    # The cuSolverDn path moves B in and out of the work buffer through
    # transposes (below), so the buffer is laid out as [rhs, n] there.
    binout_shape = [rhs, n] if implementation == 'cuSolverDn' else [n, rhs]
    binout_arr = sdfg.add_array('_binout',
                                binout_shape,
                                dtype=out_dtype,
                                transient=True)
    bout_arr = sdfg.add_array('_bout',
                              out_shape,
                              dtype=out_dtype,
                              strides=out_strides)
    ipiv_arr = sdfg.add_array('_pivots', [n], dtype=dace.int32, transient=True)
    info_arr = sdfg.add_array('_info', [1], dtype=dace.int32, transient=True)

    state = sdfg.add_state("{l}_state".format(l=node.label))

    getrf_node = Getrf('getrf')
    getrf_node.implementation = implementation
    getrs_node = Getrs('getrs')
    getrs_node.implementation = implementation

    # Access nodes. Each work-buffer node is written and then read within this
    # state, so it is created with add_access rather than add_read.
    ain = state.add_read('_ain')
    ainout1 = state.add_access('_ainout')  # copy of A; input to getrf
    ainout2 = state.add_access('_ainout')  # getrf output; input to getrs
    bin_node = state.add_read('_bin')  # not named `bin`: avoids shadowing the builtin
    binout1 = state.add_access('_binout')  # copy of B; input to getrs
    binout2 = state.add_access('_binout')  # getrs output; copied to _bout
    bout = state.add_access('_bout')

    if implementation == 'cuSolverDn':
        # Data enters and leaves the work buffers through cuBLAS transposes.
        transpose_ain = Transpose('AT', dtype=ain_dtype)
        transpose_ain.implementation = 'cuBLAS'
        state.add_edge(ain, None, transpose_ain, '_inp',
                       Memlet.from_array(*ain_arr))
        state.add_edge(transpose_ain, '_out', ainout1, None,
                       Memlet.from_array(*ainout_arr))
        transpose_bin = Transpose('bT', dtype=bin_dtype)
        transpose_bin.implementation = 'cuBLAS'
        state.add_edge(bin_node, None, transpose_bin, '_inp',
                       Memlet.from_array(*bin_arr))
        state.add_edge(transpose_bin, '_out', binout1, None,
                       Memlet.from_array(*binout_arr))
        # Both _binout (source) and _bout (destination) hold out_dtype
        # elements, so this transpose must operate on out_dtype.
        transpose_out = Transpose('XT', dtype=out_dtype)
        transpose_out.implementation = 'cuBLAS'
        state.add_edge(binout2, None, transpose_out, '_inp',
                       Memlet.from_array(*binout_arr))
        state.add_edge(transpose_out, '_out', bout, None,
                       Memlet.from_array(*bout_arr))
    else:
        # Plain copies into and out of the work buffers.
        state.add_nedge(ain, ainout1, Memlet.from_array(*ain_arr))
        state.add_nedge(bin_node, binout1, Memlet.from_array(*bin_arr))
        state.add_nedge(binout2, bout, Memlet.from_array(*bout_arr))

    ipiv = state.add_access('_pivots')
    # _info is written by both getrf and getrs; each write gets its own node.
    info1 = state.add_write('_info')
    info2 = state.add_write('_info')

    # getrf: reads the copy of A, writes _info, the pivots, and _ainout.
    state.add_memlet_path(ainout1,
                          getrf_node,
                          dst_conn="_xin",
                          memlet=Memlet.from_array(*ainout_arr))
    state.add_memlet_path(getrf_node,
                          info1,
                          src_conn="_res",
                          memlet=Memlet.from_array(*info_arr))
    state.add_memlet_path(getrf_node,
                          ipiv,
                          src_conn="_ipiv",
                          memlet=Memlet.from_array(*ipiv_arr))
    state.add_memlet_path(getrf_node,
                          ainout2,
                          src_conn="_xout",
                          memlet=Memlet.from_array(*ainout_arr))
    # getrs: consumes getrf's outputs plus the right-hand side and writes the
    # result back into _binout.
    state.add_memlet_path(ainout2,
                          getrs_node,
                          dst_conn="_a",
                          memlet=Memlet.from_array(*ainout_arr))
    state.add_memlet_path(binout1,
                          getrs_node,
                          dst_conn="_rhs_in",
                          memlet=Memlet.from_array(*binout_arr))
    state.add_memlet_path(ipiv,
                          getrs_node,
                          dst_conn="_ipiv",
                          memlet=Memlet.from_array(*ipiv_arr))
    state.add_memlet_path(getrs_node,
                          info2,
                          src_conn="_res",
                          memlet=Memlet.from_array(*info_arr))
    state.add_memlet_path(getrs_node,
                          binout2,
                          src_conn="_rhs_out",
                          memlet=Memlet.from_array(*binout_arr))

    return sdfg
Exemple #19
0
    def apply(self, state: SDFGState, sdfg: SDFG):
        adesc = self.a.desc(sdfg)
        bdesc = self.b.desc(sdfg)
        edge = state.edges_between(self.a, self.b)[0]

        if len(adesc.shape) >= len(bdesc.shape):
            copy_shape = edge.data.get_src_subset(edge, state).size()
            copy_a = True
        else:
            copy_shape = edge.data.get_dst_subset(edge, state).size()
            copy_a = False

        maprange = {f'__i{i}': (0, s - 1, 1) for i, s in enumerate(copy_shape)}

        av = self.a.data
        bv = self.b.data
        avnode = self.a
        bvnode = self.b

        # Linearize and delinearize to get index expression for other side
        if copy_a:
            a_index = [
                symbolic.pystr_to_symbolic(f'__i{i}')
                for i in range(len(copy_shape))
            ]
            b_index = self.delinearize_linearize(
                bdesc, copy_shape, edge.data.get_dst_subset(edge, state))
        else:
            a_index = self.delinearize_linearize(
                adesc, copy_shape, edge.data.get_src_subset(edge, state))
            b_index = [
                symbolic.pystr_to_symbolic(f'__i{i}')
                for i in range(len(copy_shape))
            ]

        a_subset = subsets.Range([(ind, ind, 1) for ind in a_index])
        b_subset = subsets.Range([(ind, ind, 1) for ind in b_index])

        # Set schedule based on GPU arrays
        schedule = dtypes.ScheduleType.Default
        if adesc.storage == dtypes.StorageType.GPU_Global or bdesc.storage == dtypes.StorageType.GPU_Global:
            # If already inside GPU kernel
            if is_devicelevel_gpu(sdfg, state, self.a):
                schedule = dtypes.ScheduleType.Sequential
            else:
                schedule = dtypes.ScheduleType.GPU_Device

        # Add copy map
        t, _, _ = state.add_mapped_tasklet(
            'copy',
            maprange,
            dict(__inp=Memlet(data=av, subset=a_subset)),
            '__out = __inp',
            dict(__out=Memlet(data=bv, subset=b_subset)),
            schedule,
            external_edges=True,
            input_nodes={av: avnode},
            output_nodes={bv: bvnode})

        # Set connector types (due to this transformation appearing in codegen, after connector
        # types have been resolved)
        t.in_connectors['__inp'] = adesc.dtype
        t.out_connectors['__out'] = bdesc.dtype

        # Remove old edge
        state.remove_edge(edge)