示例#1
0
def make_sdfg(implementation, dtype, storage=dace.StorageType.Default):
    """Build a one-state SDFG that applies a ``Cholesky`` library node.

    The SDFG reads the ``n x n`` array ``xin`` into the node's ``_a``
    connector and writes the node's ``_b`` connector to the ``n x n`` array
    ``xout``. The node is created with ``lower=True``.

    :param implementation: Implementation name assigned to the Cholesky node.
    :param dtype: Element type of ``xin`` and ``xout``; also part of the
                  SDFG name.
    :param storage: Accepted for interface compatibility; not used by this
                    function.
    :return: The constructed SDFG.
    """
    n = dace.symbol("n", dace.int64)

    sdfg = dace.SDFG("linalg_cholesky_{}_{}".format(implementation, dtype))
    state = sdfg.add_state("dataflow")

    in_name, in_desc = sdfg.add_array("xin", [n, n], dtype)
    out_name, out_desc = sdfg.add_array("xout", [n, n], dtype)

    source = state.add_read("xin")
    sink = state.add_write("xout")

    factorize = Cholesky("cholesky", lower=True)
    factorize.implementation = implementation

    state.add_memlet_path(source, factorize, dst_conn="_a",
                          memlet=Memlet.from_array(in_name, in_desc))
    state.add_memlet_path(factorize, sink, src_conn="_b",
                          memlet=Memlet.from_array(out_name, out_desc))

    return sdfg
示例#2
0
def _make_sdfg(node, parent_state, parent_sdfg, implementation):
    """Expand a Cholesky library node into a nested SDFG built around Potrf.

    Data flow of the generated state:

    * ``_a`` is copied into the Potrf working buffer: directly into ``_b``,
      or, for ``cuSolverDn``, through a cuBLAS ``Transpose`` into the
      transient ``_bt``;
    * ``Potrf`` reads the buffer (``_xin``), writes its result back into the
      same buffer (``_xout``) and its status into the transient ``_info``;
    * for ``cuSolverDn`` the buffer is transposed back into ``_b``;
    * a mapped tasklet over ``_b`` keeps only the requested triangle: it zeroes
      the strictly upper part when ``node.lower`` is set and the strictly
      lower part otherwise.

    :param node: The Cholesky node. ``node.validate`` supplies the input/output
                 descriptors and shapes; ``node.lower`` and ``node.label`` are
                 also read.
    :param parent_state: State containing ``node`` (passed to ``validate``).
    :param parent_sdfg: SDFG containing ``parent_state`` (passed to ``validate``).
    :param implementation: Implementation name given to the Potrf node;
                           ``'cuSolverDn'`` additionally inserts transposes.
    :return: The expanded SDFG.
    """
    inp_desc, inp_shape, out_desc, out_shape = node.validate(parent_sdfg, parent_state)
    dtype = inp_desc.dtype

    sdfg = dace.SDFG("{l}_sdfg".format(l=node.label))

    ain_arr = sdfg.add_array('_a', inp_shape, dtype=dtype, strides=inp_desc.strides)
    bout_arr = sdfg.add_array('_b', out_shape, dtype=dtype, strides=out_desc.strides)
    info_arr = sdfg.add_array('_info', [1], dtype=dace.int32, transient=True)
    if implementation == 'cuSolverDn':
        # cuSolverDn factorizes a transposed copy held in '_bt' (presumably
        # because cuSolver expects column-major storage).
        binout_arr = sdfg.add_array('_bt', inp_shape, dtype=dtype, transient=True)
    else:
        binout_arr = bout_arr

    state = sdfg.add_state("{l}_state".format(l=node.label))

    potrf_node = Potrf('potrf', lower=node.lower)
    potrf_node.implementation = implementation

    # Keep only the triangle that holds the requested factor: zero the
    # strictly upper part for a lower factor, the strictly lower part for an
    # upper factor. Potrf is created with lower=node.lower above, so the mask
    # has to follow the same flag.
    if node.lower:
        mask_code = '_out = (__i < __j) ? 0 : _inp;'
    else:
        mask_code = '_out = (__i > __j) ? 0 : _inp;'

    _, me, mx = state.add_mapped_tasklet('_uzero_',
                                         dict(__i="0:%s" % out_shape[0], __j="0:%s" % out_shape[1]),
                                         dict(_inp=Memlet.simple('_b', '__i, __j')),
                                         mask_code,
                                         dict(_out=Memlet.simple('_b', '__i, __j')),
                                         language=dace.dtypes.Language.CPP,
                                         external_edges=True)

    ain = state.add_read('_a')
    if implementation == 'cuSolverDn':
        binout1 = state.add_access('_bt')
        binout2 = state.add_access('_bt')
        # The '_b' read node created for the masking map; the transposed-back
        # result is written into it.
        binout3 = state.in_edges(me)[0].src
        transpose_ain = Transpose('AT', dtype=dtype)
        transpose_ain.implementation = 'cuBLAS'
        state.add_edge(ain, None, transpose_ain, '_inp', Memlet.from_array(*ain_arr))
        state.add_edge(transpose_ain, '_out', binout1, None, Memlet.from_array(*binout_arr))
        transpose_out = Transpose('BT', dtype=dtype)
        transpose_out.implementation = 'cuBLAS'
        state.add_edge(binout2, None, transpose_out, '_inp', Memlet.from_array(*binout_arr))
        state.add_edge(transpose_out, '_out', binout3, None, Memlet.from_array(*bout_arr))
    else:
        binout1 = state.add_access('_b')
        # Potrf writes straight into the '_b' read node of the masking map.
        binout2 = state.in_edges(me)[0].src
        state.add_nedge(ain, binout1, Memlet.from_array(*ain_arr))

    info = state.add_write('_info')

    state.add_memlet_path(binout1, potrf_node, dst_conn="_xin", memlet=Memlet.from_array(*binout_arr))
    state.add_memlet_path(potrf_node, info, src_conn="_res", memlet=Memlet.from_array(*info_arr))
    state.add_memlet_path(potrf_node, binout2, src_conn="_xout", memlet=Memlet.from_array(*binout_arr))

    return sdfg
示例#3
0
    def apply(self, _, sdfg: sd.SDFG):
        """Move the for-loop surrounding a map into the map's body.

        The loop is given by ``self.loop_guard`` (guard state) and
        ``self.loop_begin`` (body state), and the body state is expected to
        contain a map. The map's contents are nested into a NestedSDFG, the
        loop is re-created inside that nested SDFG around its first state,
        and the outer guard state is removed so that the body state is
        entered only once.

        :param _: Unused; the name is rebound by the ``find_for_loop`` call.
        :param sdfg: The SDFG that contains the loop; modified in place.
        """
        # Obtain loop information
        guard: sd.SDFGState = self.loop_guard
        body: sd.SDFGState = self.loop_begin

        # Obtain iteration variable, range, and stride
        itervar, (start, end, step), _ = find_for_loop(sdfg, guard, body)

        # Loop direction selects the comparison and increment used when the
        # loop is re-created inside the nested SDFG.
        forward_loop = step > 0

        # NOTE(review): the last MapEntry/MapExit visited wins. body.nodes()
        # presumably also returns maps nested inside the outer map, and if the
        # body has no map at all map_entry/map_exit stay unbound below.
        # Presumably the matching logic guarantees a single map -- confirm.
        for node in body.nodes():
            if isinstance(node, nodes.MapEntry):
                map_entry = node
            if isinstance(node, nodes.MapExit):
                map_exit = node

        # nest map's content in sdfg
        map_subgraph = body.scope_subgraph(map_entry, include_entry=False, include_exit=False)
        nsdfg = helpers.nest_state_subgraph(sdfg, body, map_subgraph, full_data=True)

        # replicate loop in nested sdfg
        # The condition uses <= / >=, i.e. `end` is treated as inclusive.
        new_before, new_guard, new_after = nsdfg.sdfg.add_loop(
            before_state=None,
            loop_state=nsdfg.sdfg.nodes()[0],
            loop_end_state=None,
            after_state=None,
            loop_var=itervar,
            initialize_expr=f'{start}',
            condition_expr=f'{itervar} <= {end}' if forward_loop else f'{itervar} >= {end}',
            increment_expr=f'{itervar} + {step}' if forward_loop else f'{itervar} - {abs(step)}')

        # remove outer loop
        # First locate the edges of the new nested loop: the init edge
        # (before -> guard) and the guard's edges to the exit and to the body.
        before_guard_edge = nsdfg.sdfg.edges_between(new_before, new_guard)[0]
        for e in nsdfg.sdfg.out_edges(new_guard):
            if e.dst is new_after:
                guard_after_edge = e
            else:
                guard_body_edge = e

        # Detach the outer body state. Assignments on the outer guard->body
        # edge move to the nested guard->body edge; all body in/out edges
        # (including the back edge to the guard) are removed.
        for body_inedge in sdfg.in_edges(body):
            if body_inedge.src is guard:
                guard_body_edge.data.assignments.update(body_inedge.data.assignments)
            sdfg.remove_edge(body_inedge)
        for body_outedge in sdfg.out_edges(body):
            sdfg.remove_edge(body_outedge)
        # Edges that entered the outer guard now enter the body directly;
        # their assignments move to the nested loop's init edge.
        for guard_inedge in sdfg.in_edges(guard):
            before_guard_edge.data.assignments.update(guard_inedge.data.assignments)
            guard_inedge.data.assignments = {}
            sdfg.add_edge(guard_inedge.src, body, guard_inedge.data)
            sdfg.remove_edge(guard_inedge)
        # Outer guard exits now leave from the body unconditionally ("1");
        # their assignments move to the nested guard->exit edge.
        # NOTE(review): the guard->body edge was already removed together with
        # the body's in-edges above, so the `dst is body` branch appears
        # unreachable.
        for guard_outedge in sdfg.out_edges(guard):
            if guard_outedge.dst is body:
                guard_body_edge.data.assignments.update(guard_outedge.data.assignments)
            else:
                guard_after_edge.data.assignments.update(guard_outedge.data.assignments)
            guard_outedge.data.condition = CodeBlock("1")
            sdfg.add_edge(body, guard_outedge.dst, guard_outedge.data)
            sdfg.remove_edge(guard_outedge)
        sdfg.remove_node(guard)
        # The iteration variable is now defined only inside the nested SDFG.
        if itervar in nsdfg.symbol_mapping:
            del nsdfg.symbol_mapping[itervar]
        if itervar in sdfg.symbols:
            del sdfg.symbols[itervar]

        # Add missing data/symbols
        # Unmapped free symbols of the nested SDFG: outer symbols are mapped
        # 1:1; outer arrays are passed in through the map entry as new,
        # non-transient input connectors; anything else is unsupported.
        for s in nsdfg.sdfg.free_symbols:
            if s in nsdfg.symbol_mapping:
                continue
            if s in sdfg.symbols:
                nsdfg.symbol_mapping[s] = s
            elif s in sdfg.arrays:
                desc = sdfg.arrays[s]
                access = body.add_access(s)
                conn = nsdfg.sdfg.add_datadesc(s, copy.deepcopy(desc))
                nsdfg.sdfg.arrays[s].transient = False
                nsdfg.add_in_connector(conn)
                body.add_memlet_path(access, map_entry, nsdfg, memlet=Memlet.from_array(s, desc), dst_conn=conn)
            else:
                raise NotImplementedError(f"Free symbol {s} is neither a symbol nor data.")
        # Drop symbol mappings that the nested SDFG no longer uses.
        to_delete = set()
        for s in nsdfg.symbol_mapping:
            if s not in nsdfg.sdfg.free_symbols:
                to_delete.add(s)
        for s in to_delete:
            del nsdfg.symbol_mapping[s]

        # propagate scope for correct volumes
        scope_tree = ScopeTree(map_entry, map_exit)
        scope_tree.parent = ScopeTree(None, None)
        # The first execution helps remove appearances of symbols
        # that are now defined only in the nested SDFG in memlets.
        propagation.propagate_memlets_scope(sdfg, body, scope_tree)

        # Remove outer symbols that became unused once their mapping was dropped.
        for s in to_delete:
            if helpers.is_symbol_unused(sdfg, s):
                sdfg.remove_symbol(s)

        # Refine the nested SDFG's accesses (applied directly to this nested
        # SDFG node in `body`).
        from dace.transformation.interstate import RefineNestedAccess
        transformation = RefineNestedAccess()
        transformation.setup_match(sdfg, 0, sdfg.node_id(body), {RefineNestedAccess.nsdfg: body.node_id(nsdfg)}, 0)
        transformation.apply(body, sdfg)

        # Second propagation for refined accesses.
        propagation.propagate_memlets_scope(sdfg, body, scope_tree)
示例#4
0
文件: inv.py 项目: mfkiwl/dace
def _make_sdfg_getrs(node, parent_state, parent_sdfg, implementation):
    """Expand a matrix-inverse node via Getrf followed by Getrs.

    An ``n x n`` identity matrix is written into the right-hand-side buffer
    by a mapped tasklet. Getrf (LU factorization, presumably LAPACK-style)
    reads the working matrix and writes the factors and pivots. Getrs then
    solves with the identity as right-hand side, writing the result into the
    output buffer.

    * ``node.overwrite``: the working matrix is ``_ain`` itself, the right-hand
      side lives in the transient ``_b``, and the result is copied from ``_b``
      back into ``_ain``.
    * otherwise: ``_ain`` is copied into the transient ``_ainout``, and the
      identity is written to, and solved in, ``_aout``.

    :param node: The inverse node. ``node.validate`` supplies shapes, dtypes,
                 strides and the size ``n``; ``node.overwrite`` and
                 ``node.label`` are also read.
    :param parent_state: State containing ``node`` (passed to ``validate``).
    :param parent_sdfg: SDFG containing ``parent_state`` (passed to ``validate``).
    :param implementation: Implementation name given to Getrf and Getrs.
    :return: The expanded SDFG.
    """
    arr_desc = node.validate(parent_sdfg, parent_state)
    if node.overwrite:
        in_shape, in_dtype, in_strides, n = arr_desc
    else:
        (in_shape, in_dtype, in_strides, out_shape, out_dtype, out_strides,
         n) = arr_desc
    dtype = in_dtype

    sdfg = dace.SDFG("{l}_sdfg".format(l=node.label))

    a_arr = sdfg.add_array('_ain',
                           in_shape,
                           dtype=in_dtype,
                           strides=in_strides)
    if not node.overwrite:
        # '_ain' stays untouched: factorize a transient copy instead.
        ain_arr = a_arr
        a_arr = sdfg.add_array('_ainout', [n, n],
                               dtype=in_dtype,
                               transient=True)
        b_arr = sdfg.add_array('_aout',
                               out_shape,
                               dtype=out_dtype,
                               strides=out_strides)
    else:
        b_arr = sdfg.add_array('_b', [n, n], dtype=dtype, transient=True)
    ipiv_arr = sdfg.add_array('_pivots', [n], dtype=dace.int32, transient=True)
    info_arr = sdfg.add_array('_info', [1], dtype=dace.int32, transient=True)

    state = sdfg.add_state("{l}_state".format(l=node.label))

    getrf_node = Getrf('getrf')
    getrf_node.implementation = implementation
    getrs_node = Getrs('getrs')
    getrs_node.implementation = implementation

    if node.overwrite:
        ain = state.add_read('_ain')
        ainout = state.add_access('_ain')
        aout = state.add_write('_ain')
        bin_name = '_b'
        # '_b' is written by Getrs and then read to copy the solution back
        # into '_ain', so it needs a read-write access node.
        bout = state.add_access('_b')
        state.add_nedge(bout, aout, Memlet.from_array(*a_arr))
    else:
        a = state.add_read('_ain')
        # '_ainout' is written by the copy below and then read by Getrf.
        ain = state.add_access('_ainout')
        ainout = state.add_access('_ainout')
        state.add_nedge(a, ain, Memlet.from_array(*ain_arr))
        bin_name = '_aout'
        bout = state.add_access('_aout')

    # Fill the right-hand side with the identity. The range is derived from
    # the validated size `n` rather than assuming a symbol literally named "n".
    _, _, mx = state.add_mapped_tasklet(
        '_eye_',
        dict(i="0:%s" % n, j="0:%s" % n), {},
        '_out = (i == j) ? 1 : 0;',
        dict(_out=Memlet.simple(bin_name, 'i, j')),
        language=dace.dtypes.Language.CPP,
        external_edges=True)
    b_in = state.out_edges(mx)[0].dst

    ipiv = state.add_access('_pivots')
    info1 = state.add_write('_info')
    info2 = state.add_write('_info')

    state.add_memlet_path(ain,
                          getrf_node,
                          dst_conn="_xin",
                          memlet=Memlet.from_array(*a_arr))
    state.add_memlet_path(getrf_node,
                          info1,
                          src_conn="_res",
                          memlet=Memlet.from_array(*info_arr))
    state.add_memlet_path(getrf_node,
                          ipiv,
                          src_conn="_ipiv",
                          memlet=Memlet.from_array(*ipiv_arr))
    state.add_memlet_path(getrf_node,
                          ainout,
                          src_conn="_xout",
                          memlet=Memlet.from_array(*a_arr))
    state.add_memlet_path(ainout,
                          getrs_node,
                          dst_conn="_a",
                          memlet=Memlet.from_array(*a_arr))
    state.add_memlet_path(b_in,
                          getrs_node,
                          dst_conn="_rhs_in",
                          memlet=Memlet.from_array(*b_arr))
    state.add_memlet_path(ipiv,
                          getrs_node,
                          dst_conn="_ipiv",
                          memlet=Memlet.from_array(*ipiv_arr))
    state.add_memlet_path(getrs_node,
                          info2,
                          src_conn="_res",
                          memlet=Memlet.from_array(*info_arr))
    state.add_memlet_path(getrs_node,
                          bout,
                          src_conn="_rhs_out",
                          memlet=Memlet.from_array(*b_arr))

    return sdfg
示例#5
0
文件: inv.py 项目: mfkiwl/dace
def _make_sdfg(node, parent_state, parent_sdfg, implementation):
    """Expand a matrix-inverse node into Getrf followed by Getri.

    Getrf (LU factorization, presumably LAPACK-style) reads the working
    matrix and writes the factors back into it, along with the pivots in the
    transient ``_pivots``. Getri then reads the factors and pivots and writes
    its result into the same working matrix. Both write a status value into
    the transient ``_info``.

    If ``node.overwrite`` is set, the working matrix is ``_ain`` itself.
    Otherwise ``_ain`` is first copied into ``_aout``, which then serves as
    the working matrix.

    :param node: The inverse node. ``node.validate`` supplies shapes, dtypes,
                 strides and the size ``n``; ``node.overwrite`` and
                 ``node.label`` are also read.
    :param parent_state: State containing ``node`` (passed to ``validate``).
    :param parent_sdfg: SDFG containing ``parent_state`` (passed to ``validate``).
    :param implementation: Implementation name given to Getrf and Getri.
    :return: The expanded SDFG.
    """
    arr_desc = node.validate(parent_sdfg, parent_state)
    if node.overwrite:
        in_shape, in_dtype, in_strides, n = arr_desc
    else:
        (in_shape, in_dtype, in_strides, out_shape, out_dtype, out_strides,
         n) = arr_desc

    sdfg = dace.SDFG("{l}_sdfg".format(l=node.label))

    a_arr = sdfg.add_array('_ain',
                           in_shape,
                           dtype=in_dtype,
                           strides=in_strides)
    if not node.overwrite:
        # From here on `a_arr` is the working matrix '_aout'; '_ain' is only
        # the source of the initial copy.
        ain_arr = a_arr
        a_arr = sdfg.add_array('_aout',
                               out_shape,
                               dtype=out_dtype,
                               strides=out_strides)
    ipiv_arr = sdfg.add_array('_pivots', [n], dtype=dace.int32, transient=True)
    info_arr = sdfg.add_array('_info', [1], dtype=dace.int32, transient=True)

    state = sdfg.add_state("{l}_state".format(l=node.label))

    getrf_node = Getrf('getrf')
    getrf_node.implementation = implementation
    getri_node = Getri('getri')
    getri_node.implementation = implementation

    if node.overwrite:
        ain = state.add_read('_ain')
        ainout = state.add_access('_ain')
        aout = state.add_write('_ain')
    else:
        a = state.add_read('_ain')
        # '_aout' is written by the copy from '_ain' and then read by Getrf,
        # so it needs a read-write access node rather than a read-only one.
        ain = state.add_access('_aout')
        ainout = state.add_access('_aout')
        aout = state.add_write('_aout')
        state.add_nedge(a, ain, Memlet.from_array(*ain_arr))

    ipiv = state.add_access('_pivots')
    info1 = state.add_write('_info')
    info2 = state.add_write('_info')

    state.add_memlet_path(ain,
                          getrf_node,
                          dst_conn="_xin",
                          memlet=Memlet.from_array(*a_arr))
    state.add_memlet_path(getrf_node,
                          info1,
                          src_conn="_res",
                          memlet=Memlet.from_array(*info_arr))
    state.add_memlet_path(getrf_node,
                          ipiv,
                          src_conn="_ipiv",
                          memlet=Memlet.from_array(*ipiv_arr))
    state.add_memlet_path(getrf_node,
                          ainout,
                          src_conn="_xout",
                          memlet=Memlet.from_array(*a_arr))
    state.add_memlet_path(ainout,
                          getri_node,
                          dst_conn="_xin",
                          memlet=Memlet.from_array(*a_arr))
    state.add_memlet_path(ipiv,
                          getri_node,
                          dst_conn="_ipiv",
                          memlet=Memlet.from_array(*ipiv_arr))
    state.add_memlet_path(getri_node,
                          info2,
                          src_conn="_res",
                          memlet=Memlet.from_array(*info_arr))
    state.add_memlet_path(getri_node,
                          aout,
                          src_conn="_xout",
                          memlet=Memlet.from_array(*a_arr))

    return sdfg
示例#6
0
    def apply(self, sdfg: SDFG):
        """Nest the selected subgraph of states into a GPU persistent kernel.

        The states of ``self.subgraph_view(sdfg)`` are moved into a new nested
        SDFG. That SDFG is launched from a single launch state inside a
        one-iteration map scheduled as ``ScheduleType.GPU_Persistent``.

        * If ``include_in_assignment`` is set and the edge entering the
          subgraph carries assignments, those assignments are moved onto an
          edge leaving a new entry-guard state that becomes part of the kernel.
        * If a state follows the subgraph, an exit-guard state is inserted
          before it and added to the kernel.

        Data accessed inside the kernel is either moved into the nested SDFG
        (streams and register-storage arrays) or passed as a kernel argument
        (everything else, as a non-transient copy with default storage).
        Arguments whose access nodes are all read-only become inputs; the
        rest become outputs.

        :param sdfg: The SDFG to transform; modified in place.
        """
        subgraph = self.subgraph_view(sdfg)
        # Label prefix shared by every state/map/SDFG created below.
        prefix = self.kernel_prefix + '_' if self.kernel_prefix != '' else ''

        entry_states_in, entry_states_out = self.get_entry_states(
            sdfg, subgraph)
        _, exit_states_out = self.get_exit_states(sdfg, subgraph)

        entry_state_in = entry_states_in.pop()
        entry_state_out = entry_states_out.pop() if entry_states_out else None
        exit_state_out = exit_states_out.pop() if exit_states_out else None

        launch_state = None
        entry_guard_state = None
        exit_guard_state = None

        # generate entry guard state if needed
        if self.include_in_assignment and entry_state_out is not None:
            entry_edge = sdfg.edges_between(entry_state_out, entry_state_in)[0]
            if len(entry_edge.data.assignments) > 0:
                entry_guard_state = sdfg.add_state(
                    label='{}kernel_entry_guard'.format(prefix))
                sdfg.add_edge(entry_state_out, entry_guard_state,
                              InterstateEdge(entry_edge.data.condition))
                sdfg.add_edge(
                    entry_guard_state, entry_state_in,
                    InterstateEdge(None, entry_edge.data.assignments))
                sdfg.remove_edge(entry_edge)

                # Update SubgraphView (copy the node list instead of
                # mutating the one owned by the previous view)
                subgraph = SubgraphView(
                    sdfg, list(subgraph.nodes()) + [entry_guard_state])

                launch_state = sdfg.add_state_before(
                    entry_guard_state,
                    label='{}kernel_launch'.format(prefix))

        # generate exit guard state
        if exit_state_out is not None:
            exit_guard_state = sdfg.add_state_before(
                exit_state_out,
                label='{}kernel_exit_guard'.format(prefix))

            # Update SubgraphView
            subgraph = SubgraphView(
                sdfg, list(subgraph.nodes()) + [exit_guard_state])

            if launch_state is None:
                launch_state = sdfg.add_state_before(
                    exit_state_out,
                    label='{}kernel_launch'.format(prefix))

        # No launch state yet: there is no state after the subgraph and no
        # entry guard was created. Create a standalone launch state; it is
        # connected to `entry_state_out` (if any) further below.
        if launch_state is None:
            # `entry_state_in` was popped from a non-empty collection and is
            # never None, so the previous `entry_state_in is None` check made
            # this path always fail. The real invariant is that no state
            # follows the subgraph.
            assert exit_state_out is None
            launch_state = sdfg.add_state(
                label='{}kernel_launch'.format(prefix))

        # create sdfg for kernel and fill it with states and edges from
        # subgraph; the sdfg will be nested at the end
        kernel_sdfg = SDFG('{}kernel'.format(prefix))

        for edge in subgraph.edges():
            kernel_sdfg.add_edge(edge.src, edge.dst, edge.data)
        # States without any edge inside the subgraph (e.g. a single-state
        # kernel) are not added by the loop above; add them explicitly so they
        # are not lost when the subgraph is removed from `sdfg` below.
        for state in subgraph.nodes():
            if state not in kernel_sdfg.nodes():
                kernel_sdfg.add_node(state)

        # Setting entry node in nested SDFG if no entry guard was created
        if entry_guard_state is None:
            kernel_sdfg.start_state = kernel_sdfg.node_id(entry_state_in)

        for state in subgraph:
            state.parent = kernel_sdfg

        # remove the now nested nodes from the outer sdfg and make sure the
        # launch state is properly connected to remaining states
        sdfg.remove_nodes_from(subgraph.nodes())

        if entry_state_out is not None \
                and len(sdfg.edges_between(entry_state_out, launch_state)) == 0:
            sdfg.add_edge(entry_state_out, launch_state, InterstateEdge())

        if exit_state_out is not None \
                and len(sdfg.edges_between(launch_state, exit_state_out)) == 0:
            sdfg.add_edge(launch_state, exit_state_out, InterstateEdge())

        # Handle data for kernel
        kernel_data = {
            node.data
            for state in kernel_sdfg for node in state.nodes()
            if isinstance(node, nodes.AccessNode)
        }

        # move Streams and Register data into the nested SDFG
        # normal data will be added as kernel argument
        kernel_args = []
        for data in kernel_data:
            if (isinstance(sdfg.arrays[data], dace.data.Stream) or
                (isinstance(sdfg.arrays[data], dace.data.Array)
                 and sdfg.arrays[data].storage == StorageType.Register)):
                kernel_sdfg.add_datadesc(data, sdfg.arrays[data])
                del sdfg.arrays[data]
            else:
                copy_desc = copy.deepcopy(sdfg.arrays[data])
                copy_desc.transient = False
                copy_desc.storage = StorageType.Default
                kernel_sdfg.add_datadesc(data, copy_desc)
                kernel_args.append(data)

        # read only data will be passed as input, writeable data will be passed
        # as 'output' otherwise kernel cannot write to data
        kernel_args_read = set()
        kernel_args_write = set()
        for data in kernel_args:
            data_accesses_read_only = [
                node.access == dtypes.AccessType.ReadOnly
                for state in kernel_sdfg for node in state
                if isinstance(node, nodes.AccessNode) and node.data == data
            ]
            if all(data_accesses_read_only):
                kernel_args_read.add(data)
            else:
                kernel_args_write.add(data)

        # Kernel SDFG is complete at this point
        if self.validate:
            kernel_sdfg.validate()

        # Filling launch state with nested SDFG, map and access nodes
        # (a single-iteration map over the dummy parameter `ignore`)
        map_entry, map_exit = launch_state.add_map(
            '{}kernel_launch_map'.format(prefix),
            dict(ignore='0'),
            schedule=ScheduleType.GPU_Persistent,
        )

        nested_sdfg = launch_state.add_nested_sdfg(
            kernel_sdfg,
            sdfg,
            kernel_args_read,
            kernel_args_write,
        )

        # Create and connect read only data access nodes
        for arg in kernel_args_read:
            read_node = launch_state.add_read(arg)
            launch_state.add_memlet_path(read_node,
                                         map_entry,
                                         nested_sdfg,
                                         dst_conn=arg,
                                         memlet=Memlet.from_array(
                                             arg, sdfg.arrays[arg]))

        # Create and connect writable data access nodes
        for arg in kernel_args_write:
            write_node = launch_state.add_write(arg)
            launch_state.add_memlet_path(nested_sdfg,
                                         map_exit,
                                         write_node,
                                         src_conn=arg,
                                         memlet=Memlet.from_array(
                                             arg, sdfg.arrays[arg]))

        # Transformation is done
        if self.validate:
            sdfg.validate()
示例#7
0
文件: solve.py 项目: mfkiwl/dace
def _make_sdfg_getrs(node, parent_state, parent_sdfg, implementation):
    """Expand a linear-solve node into Getrf followed by Getrs.

    ``_ain`` is copied into the transient ``_ainout``, which Getrf reads and
    overwrites with its factors; pivots go to the transient ``_pivots``.
    ``_bin`` is copied into the transient ``_binout``, which Getrs reads as
    right-hand side and overwrites with its result. The result is then copied
    into ``_bout``. Both library nodes write a status value into the
    transient ``_info``.

    For ``implementation == 'cuSolverDn'`` every copy is a cuBLAS
    ``Transpose`` instead, and ``_binout`` is allocated as ``[rhs, n]``
    instead of ``[n, rhs]``.

    :param node: The solve node. ``node.validate`` supplies shapes, dtypes,
                 strides, the size ``n`` and the right-hand-side count ``rhs``;
                 ``node.label`` is also read.
    :param parent_state: State containing ``node`` (passed to ``validate``).
    :param parent_sdfg: SDFG containing ``parent_state`` (passed to ``validate``).
    :param implementation: Implementation name given to Getrf and Getrs.
    :return: The expanded SDFG.
    """
    arr_desc = node.validate(parent_sdfg, parent_state)
    (ain_shape, ain_dtype, ain_strides, bin_shape, bin_dtype, bin_strides,
     out_shape, out_dtype, out_strides, n, rhs) = arr_desc

    sdfg = dace.SDFG("{l}_sdfg".format(l=node.label))

    ain_arr = sdfg.add_array('_ain',
                             ain_shape,
                             dtype=ain_dtype,
                             strides=ain_strides)
    ainout_arr = sdfg.add_array('_ainout', [n, n],
                                dtype=ain_dtype,
                                transient=True)
    bin_arr = sdfg.add_array('_bin',
                             bin_shape,
                             dtype=bin_dtype,
                             strides=bin_strides)
    # cuSolverDn operates on the transposed right-hand side.
    binout_shape = [rhs, n] if implementation == 'cuSolverDn' else [n, rhs]
    binout_arr = sdfg.add_array('_binout',
                                binout_shape,
                                dtype=out_dtype,
                                transient=True)
    bout_arr = sdfg.add_array('_bout',
                              out_shape,
                              dtype=out_dtype,
                              strides=out_strides)
    ipiv_arr = sdfg.add_array('_pivots', [n], dtype=dace.int32, transient=True)
    info_arr = sdfg.add_array('_info', [1], dtype=dace.int32, transient=True)

    state = sdfg.add_state("{l}_state".format(l=node.label))

    getrf_node = Getrf('getrf')
    getrf_node.implementation = implementation
    getrs_node = Getrs('getrs')
    getrs_node.implementation = implementation

    ain = state.add_read('_ain')
    # The transient working buffers are written (by the copies/transposes or
    # by Getrs) and then read, so they need read-write access nodes rather
    # than read-only ones.
    ainout1 = state.add_access('_ainout')
    ainout2 = state.add_access('_ainout')
    b_in = state.add_read('_bin')
    binout1 = state.add_access('_binout')
    binout2 = state.add_access('_binout')
    bout = state.add_access('_bout')
    if implementation == 'cuSolverDn':
        transpose_ain = Transpose('AT', dtype=ain_dtype)
        transpose_ain.implementation = 'cuBLAS'
        state.add_edge(ain, None, transpose_ain, '_inp',
                       Memlet.from_array(*ain_arr))
        state.add_edge(transpose_ain, '_out', ainout1, None,
                       Memlet.from_array(*ainout_arr))
        transpose_bin = Transpose('bT', dtype=bin_dtype)
        transpose_bin.implementation = 'cuBLAS'
        state.add_edge(b_in, None, transpose_bin, '_inp',
                       Memlet.from_array(*bin_arr))
        state.add_edge(transpose_bin, '_out', binout1, None,
                       Memlet.from_array(*binout_arr))
        transpose_out = Transpose('XT', dtype=bin_dtype)
        transpose_out.implementation = 'cuBLAS'
        state.add_edge(binout2, None, transpose_out, '_inp',
                       Memlet.from_array(*binout_arr))
        state.add_edge(transpose_out, '_out', bout, None,
                       Memlet.from_array(*bout_arr))
    else:
        state.add_nedge(ain, ainout1, Memlet.from_array(*ain_arr))
        state.add_nedge(b_in, binout1, Memlet.from_array(*bin_arr))
        state.add_nedge(binout2, bout, Memlet.from_array(*bout_arr))

    ipiv = state.add_access('_pivots')
    info1 = state.add_write('_info')
    info2 = state.add_write('_info')

    state.add_memlet_path(ainout1,
                          getrf_node,
                          dst_conn="_xin",
                          memlet=Memlet.from_array(*ainout_arr))
    state.add_memlet_path(getrf_node,
                          info1,
                          src_conn="_res",
                          memlet=Memlet.from_array(*info_arr))
    state.add_memlet_path(getrf_node,
                          ipiv,
                          src_conn="_ipiv",
                          memlet=Memlet.from_array(*ipiv_arr))
    state.add_memlet_path(getrf_node,
                          ainout2,
                          src_conn="_xout",
                          memlet=Memlet.from_array(*ainout_arr))
    state.add_memlet_path(ainout2,
                          getrs_node,
                          dst_conn="_a",
                          memlet=Memlet.from_array(*ainout_arr))
    state.add_memlet_path(binout1,
                          getrs_node,
                          dst_conn="_rhs_in",
                          memlet=Memlet.from_array(*binout_arr))
    state.add_memlet_path(ipiv,
                          getrs_node,
                          dst_conn="_ipiv",
                          memlet=Memlet.from_array(*ipiv_arr))
    state.add_memlet_path(getrs_node,
                          info2,
                          src_conn="_res",
                          memlet=Memlet.from_array(*info_arr))
    state.add_memlet_path(getrs_node,
                          binout2,
                          src_conn="_rhs_out",
                          memlet=Memlet.from_array(*binout_arr))

    return sdfg
示例#8
0
def test_duplicate_codegen():
    """Run a hand-built SDFG in which access node ``D`` feeds two tasklets.

    ``c_task`` copies ``C`` into ``D``. ``e_task`` and ``f_task`` each add
    ``D`` to ``A`` / ``B`` and accumulate the result into ``E`` / ``F``.
    With every input set to 1, both outputs must end up as 2.
    """
    # Built by hand on purpose: the Python frontend would not produce the
    # node ordering this test needs, so creation order below is significant.
    sdfg = dace.SDFG("dup")
    state = sdfg.add_state()

    accumulate = "lambda x, y: x + y"

    c_task = state.add_tasklet("c_task", inputs={"c"}, outputs={"d"}, code='d = c')
    e_task = state.add_tasklet("e_task", inputs={"a", "d"}, outputs={"e"}, code="e = a + d")
    f_task = state.add_tasklet("f_task", inputs={"b", "d"}, outputs={"f"}, code="f = b + d")

    descs = {}
    for name in ("A", "B", "C", "D", "E", "F"):
        _, descs[name] = sdfg.add_array(name, [1], dace.float32)

    def whole(name, **kwargs):
        """Memlet covering the full one-element array ``name``."""
        return Memlet.from_array(name, descs[name], **kwargs)

    a_node = state.add_read("A")
    b_node = state.add_read("B")
    c_node = state.add_read("C")
    d_node = state.add_access("D")
    e_node = state.add_write("E")
    f_node = state.add_write("F")

    state.add_edge(c_node, None, c_task, "c", whole("C"))
    state.add_edge(c_task, "d", d_node, None, whole("D"))

    state.add_edge(a_node, None, e_task, "a", whole("A"))
    state.add_edge(b_node, None, f_task, "b", whole("B"))
    state.add_edge(d_node, None, f_task, "d", whole("D"))
    state.add_edge(d_node, None, e_task, "d", whole("D"))

    state.add_edge(e_task, "e", e_node, None, whole("E", wcr=accumulate))
    state.add_edge(f_task, "f", f_node, None, whole("F", wcr=accumulate))

    buffers = {name: np.array([1], dtype=np.float32) for name in ("A", "B", "C", "D")}
    buffers["E"] = np.zeros_like(buffers["A"])
    buffers["F"] = np.zeros_like(buffers["A"])

    sdfg(**buffers)

    assert buffers["E"][0] == 2
    assert buffers["F"][0] == 2