示例#1
0
def make_sdfg(implementation,
              dtype,
              id=0,
              in_shape=[n, n],
              out_shape=[n, n],
              in_subset="0:n, 0:n",
              out_subset="0:n, 0:n",
              overwrite=False,
              getri=True):
    """Build a one-state SDFG that routes an array through an ``Inv`` node.

    The SDFG is named ``linalg_inv_<implementation>_<dtype.__name__>_<id>``
    and declares the symbol ``n`` as ``dace.int64``.

    :param implementation: value assigned to the node's ``implementation``.
    :param dtype: element type of the registered arrays; its ``__name__``
                  becomes part of the SDFG name.
    :param id: suffix appended to the SDFG name.
    :param in_shape: shape of the ``xin`` array.
    :param out_shape: shape of the ``xout`` array (only registered when
                      ``overwrite`` is false).
    :param in_subset: subset string of the memlet entering ``_ain``.
    :param out_subset: subset string of the memlet leaving ``_aout``.
    :param overwrite: when true, the node output is written back to ``xin``
                      and no ``xout`` array is created; also forwarded to the
                      node as ``overwrite_a``.
    :param getri: forwarded to the node as ``use_getri``.
    :return: the constructed SDFG.
    """
    graph_name = "linalg_inv_{}_{}_{}".format(implementation, dtype.__name__, id)
    sdfg = dace.SDFG(graph_name)
    sdfg.add_symbol("n", dace.int64)
    state = sdfg.add_state("dataflow")

    # The input container always exists; a separate output container is only
    # registered when the result is not written back in place.
    sdfg.add_array("xin", in_shape, dtype)
    if overwrite:
        result_name = "xin"
    else:
        result_name = "xout"
        sdfg.add_array(result_name, out_shape, dtype)

    source = state.add_read("xin")
    sink = state.add_write(result_name)

    node = Inv("inv", overwrite_a=overwrite, use_getri=getri)
    node.implementation = implementation

    read_memlet = Memlet.simple(source, in_subset, num_accesses=n * n)
    state.add_memlet_path(source, node, dst_conn="_ain", memlet=read_memlet)

    write_memlet = Memlet.simple(sink, out_subset, num_accesses=n * n)
    state.add_memlet_path(node, sink, src_conn="_aout", memlet=write_memlet)

    return sdfg
示例#2
0
def make_sdfg(implementation,
              dtype,
              id=0,
              in_shape=[n, n],
              out_shape=[n, n],
              in_subset="0:n, 0:n",
              out_subset="0:n, 0:n"):
    """Build a one-state SDFG that wires two inputs and one output to a ``Solve`` node.

    The SDFG is named ``linalg_solve_<implementation>_<dtype.__name__>_<id>``
    and declares the symbol ``n`` as ``dace.int64``.

    :param implementation: value assigned to the node's ``implementation``.
    :param dtype: element type of all three arrays; its ``__name__`` becomes
                  part of the SDFG name.
    :param id: suffix appended to the SDFG name.
    :param in_shape: shape of the ``ain`` array.
    :param out_shape: shape shared by the ``bin`` and ``bout`` arrays.
    :param in_subset: subset string of the memlet entering ``_ain``.
    :param out_subset: subset string of the memlets on ``_bin`` and ``_bout``.
    :return: the constructed SDFG.
    """
    sdfg = dace.SDFG("linalg_solve_{}_{}_{}".format(implementation, dtype.__name__, id))
    sdfg.add_symbol("n", dace.int64)
    state = sdfg.add_state("dataflow")

    # Register the data containers in a fixed order.
    for array_name, array_shape in (("ain", in_shape), ("bin", out_shape), ("bout", out_shape)):
        sdfg.add_array(array_name, array_shape, dtype)

    a_read = state.add_read("ain")
    b_read = state.add_read("bin")
    b_write = state.add_write("bout")

    node = Solve("solve")
    node.implementation = implementation

    # Inputs first, then the single output.
    for access, connector, subset in ((a_read, "_ain", in_subset), (b_read, "_bin", out_subset)):
        incoming = Memlet.simple(access, subset, num_accesses=n * n)
        state.add_memlet_path(access, node, dst_conn=connector, memlet=incoming)

    outgoing = Memlet.simple(b_write, out_subset, num_accesses=n * n)
    state.add_memlet_path(node, b_write, src_conn="_bout", memlet=outgoing)

    return sdfg
示例#3
0
def _make_sdfg(node, parent_state, parent_sdfg, implementation):
    """Build the nested SDFG that wraps a ``Potrf`` node for ``node``.

    Data flow inside the single state:

    * ``_a`` is copied (or, for ``'cuSolverDn'``, transposed through a
      ``Transpose`` node using ``'cuBLAS'``) into the buffer ``Potrf`` works on.
      That buffer is ``_b`` itself, or the transient ``_bt`` for
      ``'cuSolverDn'``.
    * ``Potrf`` reads the buffer on ``_xin``, writes it on ``_xout`` and writes
      its status to the transient ``_info`` via ``_res``.
    * For ``'cuSolverDn'`` the result is transposed back from ``_bt`` into ``_b``.
    * A mapped tasklet then rewrites ``_b`` in place, storing ``0`` wherever the
      row index is smaller than the column index.

    :param node: node being expanded; must provide ``validate``, ``label`` and
                 ``lower``.
    :param parent_state: state containing ``node`` (passed to ``validate``).
    :param parent_sdfg: SDFG containing ``node`` (passed to ``validate``).
    :param implementation: implementation name set on the ``Potrf`` node;
                           ``'cuSolverDn'`` additionally enables the
                           transpose-in / transpose-out path.
    :return: the constructed SDFG.
    """
    inp_desc, inp_shape, out_desc, out_shape = node.validate(parent_sdfg, parent_state)
    dtype = inp_desc.dtype
    use_cusolver = implementation == 'cuSolverDn'

    sdfg = dace.SDFG("{l}_sdfg".format(l=node.label))

    a_arr = sdfg.add_array('_a', inp_shape, dtype=dtype, strides=inp_desc.strides)
    b_arr = sdfg.add_array('_b', out_shape, dtype=dtype, strides=out_desc.strides)
    info_arr = sdfg.add_array('_info', [1], dtype=dace.int32, transient=True)
    if use_cusolver:
        # Separate transient that holds the transposed matrix.
        work_arr = sdfg.add_array('_bt', inp_shape, dtype=dtype, transient=True)
    else:
        work_arr = b_arr

    state = sdfg.add_state("{l}_state".format(l=node.label))

    potrf_node = Potrf('potrf', lower=node.lower)
    potrf_node.implementation = implementation

    # NOTE(review): the mask below is the same whatever node.lower is -
    # confirm that this is intended.
    _, map_entry, _ = state.add_mapped_tasklet('_uzero_',
                                               dict(__i="0:%s" % out_shape[0], __j="0:%s" % out_shape[1]),
                                               dict(_inp=Memlet.simple('_b', '__i, __j')),
                                               '_out = (__i < __j) ? 0 : _inp;',
                                               dict(_out=Memlet.simple('_b', '__i, __j')),
                                               language=dace.dtypes.Language.CPP,
                                               external_edges=True)

    a_read = state.add_read('_a')
    if use_cusolver:
        potrf_in = state.add_access('_bt')
        potrf_out = state.add_access('_bt')
        # Access node of '_b' that feeds the zeroing map (created by
        # external_edges=True above).
        map_source = state.in_edges(map_entry)[0].src

        transpose_in = Transpose('AT', dtype=dtype)
        transpose_in.implementation = 'cuBLAS'
        state.add_edge(a_read, None, transpose_in, '_inp', Memlet.from_array(*a_arr))
        state.add_edge(transpose_in, '_out', potrf_in, None, Memlet.from_array(*work_arr))

        transpose_out = Transpose('BT', dtype=dtype)
        transpose_out.implementation = 'cuBLAS'
        state.add_edge(potrf_out, None, transpose_out, '_inp', Memlet.from_array(*work_arr))
        state.add_edge(transpose_out, '_out', map_source, None, Memlet.from_array(*b_arr))
    else:
        potrf_in = state.add_access('_b')
        # Potrf writes straight into the access node that the zeroing map reads.
        potrf_out = state.in_edges(map_entry)[0].src
        state.add_nedge(a_read, potrf_in, Memlet.from_array(*a_arr))

    info = state.add_write('_info')

    state.add_memlet_path(potrf_in, potrf_node, dst_conn="_xin", memlet=Memlet.from_array(*work_arr))
    state.add_memlet_path(potrf_node, info, src_conn="_res", memlet=Memlet.from_array(*info_arr))
    state.add_memlet_path(potrf_node, potrf_out, src_conn="_xout", memlet=Memlet.from_array(*work_arr))

    return sdfg
示例#4
0
文件: inv.py 项目: mfkiwl/dace
def _make_sdfg_getrs(node, parent_state, parent_sdfg, implementation):
    """Build a nested SDFG that chains a ``Getrf`` node into a ``Getrs`` node.

    The right-hand side handed to ``Getrs`` is filled by a mapped tasklet that
    writes ``1`` where ``i == j`` and ``0`` elsewhere (an identity matrix), so
    the ``Getrs`` output is presumably the inverse of the input matrix -
    confirm against the ``Getrs`` node's contract.

    :param node: node being expanded; must provide ``validate``, ``overwrite``
                 and ``label``.
    :param parent_state: state containing ``node`` (passed to ``validate``).
    :param parent_sdfg: SDFG containing ``node`` (passed to ``validate``).
    :param implementation: implementation name set on both library nodes.
    :return: the constructed SDFG.
    """

    # validate() returns a 4-tuple when the node overwrites its input and a
    # 7-tuple (with separate output shape/dtype/strides) otherwise.
    arr_desc = node.validate(parent_sdfg, parent_state)
    if node.overwrite:
        in_shape, in_dtype, in_strides, n = arr_desc
    else:
        (in_shape, in_dtype, in_strides, out_shape, out_dtype, out_strides,
         n) = arr_desc
    dtype = in_dtype

    sdfg = dace.SDFG("{l}_sdfg".format(l=node.label))

    a_arr = sdfg.add_array('_ain',
                           in_shape,
                           dtype=in_dtype,
                           strides=in_strides)
    if not node.overwrite:
        # Keep the descriptor of '_ain' as ain_arr and rebind a_arr to a
        # transient n x n working copy, so '_ain' itself is only read.
        ain_arr = a_arr
        a_arr = sdfg.add_array('_ainout', [n, n],
                               dtype=in_dtype,
                               transient=True)
        b_arr = sdfg.add_array('_aout',
                               out_shape,
                               dtype=out_dtype,
                               strides=out_strides)
    else:
        # In-place mode: the right-hand side lives in a transient '_b' and is
        # copied back into '_ain' further below.
        b_arr = sdfg.add_array('_b', [n, n], dtype=dtype, transient=True)
    ipiv_arr = sdfg.add_array('_pivots', [n], dtype=dace.int32, transient=True)
    info_arr = sdfg.add_array('_info', [1], dtype=dace.int32, transient=True)

    state = sdfg.add_state("{l}_state".format(l=node.label))

    getrf_node = Getrf('getrf')
    getrf_node.implementation = implementation
    getrs_node = Getrs('getrs')
    getrs_node.implementation = implementation

    if node.overwrite:
        ain = state.add_read('_ain')
        ainout = state.add_access('_ain')
        aout = state.add_write('_ain')
        bin_name = '_b'
        bout = state.add_write('_b')
        # Copy the Getrs result from '_b' back into '_ain'.
        state.add_nedge(bout, aout, Memlet.from_array(*a_arr))
    else:
        a = state.add_read('_ain')
        ain = state.add_read('_ainout')
        ainout = state.add_access('_ainout')
        # aout = state.add_write('_aout')
        # Copy the input into the transient working array.
        state.add_nedge(a, ain, Memlet.from_array(*ain_arr))
        bin_name = '_aout'
        bout = state.add_access('_aout')

    # Fill the right-hand-side container (bin_name) with an identity pattern.
    _, _, mx = state.add_mapped_tasklet(
        '_eye_',
        dict(i="0:n", j="0:n"), {},
        '_out = (i == j) ? 1 : 0;',
        dict(_out=Memlet.simple(bin_name, 'i, j')),
        language=dace.dtypes.Language.CPP,
        external_edges=True)
    # Access node written by the map above; it feeds Getrs' '_rhs_in'.
    bin = state.out_edges(mx)[0].dst

    ipiv = state.add_access('_pivots')
    # Each library node writes its '_res' output to its own '_info' access node.
    info1 = state.add_write('_info')
    info2 = state.add_write('_info')

    # Getrf: reads the matrix, writes status, pivots and the updated matrix.
    state.add_memlet_path(ain,
                          getrf_node,
                          dst_conn="_xin",
                          memlet=Memlet.from_array(*a_arr))
    state.add_memlet_path(getrf_node,
                          info1,
                          src_conn="_res",
                          memlet=Memlet.from_array(*info_arr))
    state.add_memlet_path(getrf_node,
                          ipiv,
                          src_conn="_ipiv",
                          memlet=Memlet.from_array(*ipiv_arr))
    state.add_memlet_path(getrf_node,
                          ainout,
                          src_conn="_xout",
                          memlet=Memlet.from_array(*a_arr))
    # Getrs: consumes Getrf's matrix output, the identity right-hand side and
    # the pivots; writes status and the right-hand-side result.
    state.add_memlet_path(ainout,
                          getrs_node,
                          dst_conn="_a",
                          memlet=Memlet.from_array(*a_arr))
    state.add_memlet_path(bin,
                          getrs_node,
                          dst_conn="_rhs_in",
                          memlet=Memlet.from_array(*b_arr))
    state.add_memlet_path(ipiv,
                          getrs_node,
                          dst_conn="_ipiv",
                          memlet=Memlet.from_array(*ipiv_arr))
    state.add_memlet_path(getrs_node,
                          info2,
                          src_conn="_res",
                          memlet=Memlet.from_array(*info_arr))
    state.add_memlet_path(getrs_node,
                          bout,
                          src_conn="_rhs_out",
                          memlet=Memlet.from_array(*b_arr))

    return sdfg
示例#5
0
def make_sdfg(specialize):
    """Assemble the parallel FPGA histogram SDFG.

    The SDFG consists of three states connected in sequence by
    default-constructed interstate edges: the state returned by
    ``make_copy_to_fpga_state``, a ``compute`` state built here, and the state
    returned by ``make_copy_to_host_state``.

    The ``compute`` state contains three parts that communicate through the
    ``A_pipes`` and ``hist_pipes`` streams: a read module, an unrolled compute
    module wrapping the nested SDFG from ``make_compute_nested_sdfg``, and a
    write module that reduces the per-pipe results into ``hist_device``.

    :param specialize: when truthy, the SDFG name also encodes the current
                       values of ``H`` and ``W`` in addition to ``P``. It has
                       no other effect inside this function.
    :return: the constructed SDFG.
    """

    if specialize:
        sdfg = SDFG("histogram_fpga_parallel_{}_{}x{}".format(
            P.get(), H.get(), W.get()))
    else:
        sdfg = SDFG("histogram_fpga_parallel_{}".format(P.get()))

    copy_to_fpga_state = make_copy_to_fpga_state(sdfg)

    state = sdfg.add_state("compute")

    # Compute module
    nested_sdfg = make_compute_nested_sdfg(state)
    tasklet = state.add_nested_sdfg(nested_sdfg, sdfg, {"A_pipe_in"},
                                    {"hist_pipe_out"})
    # Two access nodes for the same "A_pipes" stream: one is written by the
    # read module, the other is read by the compute module.
    A_pipes_out = state.add_stream("A_pipes",
                                   dtype,
                                   shape=(P, ),
                                   transient=True,
                                   storage=StorageType.FPGA_Local)
    A_pipes_in = state.add_stream("A_pipes",
                                  dtype,
                                  shape=(P, ),
                                  transient=True,
                                  storage=StorageType.FPGA_Local)
    hist_pipes_out = state.add_stream("hist_pipes",
                                      itype,
                                      shape=(P, ),
                                      transient=True,
                                      storage=StorageType.FPGA_Local)
    # Unrolled map over p in 0:P around the nested SDFG.
    unroll_entry, unroll_exit = state.add_map(
        "unroll_compute", {"p": "0:P"},
        schedule=dace.ScheduleType.FPGA_Device,
        unroll=True)
    # Empty memlets attach the stream access nodes to the map scope without
    # moving data through the map entry/exit themselves.
    state.add_memlet_path(unroll_entry, A_pipes_in, memlet=EmptyMemlet())
    state.add_memlet_path(hist_pipes_out, unroll_exit, memlet=EmptyMemlet())
    state.add_memlet_path(A_pipes_in,
                          tasklet,
                          dst_conn="A_pipe_in",
                          memlet=Memlet.simple(A_pipes_in,
                                               "p",
                                               num_accesses="W*H"))
    state.add_memlet_path(tasklet,
                          hist_pipes_out,
                          src_conn="hist_pipe_out",
                          memlet=Memlet.simple(hist_pipes_out,
                                               "p",
                                               num_accesses="num_bins"))

    # Read module
    a_device = state.add_array("A_device", (H, W),
                               dtype,
                               transient=True,
                               storage=dace.dtypes.StorageType.FPGA_Global)
    # Outer map walks rows h and columns w in steps of P.
    read_entry, read_exit = state.add_map("read_map", {
        "h": "0:H",
        "w": "0:W:P"
    },
                                          schedule=ScheduleType.FPGA_Device)
    # Local buffer of P elements filled once per outer-map iteration.
    a_val = state.add_array("A_val", (P, ),
                            dtype,
                            transient=True,
                            storage=StorageType.FPGA_Local)
    read_unroll_entry, read_unroll_exit = state.add_map(
        "read_unroll", {"p": "0:P"},
        schedule=ScheduleType.FPGA_Device,
        unroll=True)
    # Each unrolled iteration forwards element p of the buffer to pipe p.
    read_tasklet = state.add_tasklet("read", {"A_in"}, {"A_pipe"},
                                     "A_pipe = A_in[p]")
    # A_device[h, w] -> A_val[0], moved with a vector length of P.
    state.add_memlet_path(a_device,
                          read_entry,
                          a_val,
                          memlet=Memlet(a_val,
                                        num_accesses=1,
                                        subset=Indices(["0"]),
                                        vector_length=P.get(),
                                        other_subset=Indices(["h", "w"])))
    state.add_memlet_path(a_val,
                          read_unroll_entry,
                          read_tasklet,
                          dst_conn="A_in",
                          memlet=Memlet.simple(a_val,
                                               "0",
                                               veclen=P.get(),
                                               num_accesses=1))
    state.add_memlet_path(read_tasklet,
                          read_unroll_exit,
                          read_exit,
                          A_pipes_out,
                          src_conn="A_pipe",
                          memlet=Memlet.simple(A_pipes_out, "p"))

    # Write module
    hist_pipes_in = state.add_stream("hist_pipes",
                                     itype,
                                     shape=(P, ),
                                     transient=True,
                                     storage=StorageType.FPGA_Local)
    hist_device_out = state.add_array(
        "hist_device", (num_bins, ),
        itype,
        transient=True,
        storage=dace.dtypes.StorageType.FPGA_Global)
    merge_entry, merge_exit = state.add_map("merge", {"nb": "0:num_bins"},
                                            schedule=ScheduleType.FPGA_Device)
    # Sum-reduce over axis 0 of the P pipe values for each bin.
    # NOTE(review): the third positional argument "0" is presumably the
    # reduction identity - confirm against the add_reduce signature in use.
    merge_reduce = state.add_reduce("lambda a, b: a + b", (0, ),
                                    "0",
                                    schedule=ScheduleType.FPGA_Device)
    state.add_memlet_path(hist_pipes_in,
                          merge_entry,
                          merge_reduce,
                          memlet=Memlet.simple(hist_pipes_in,
                                               "0:P",
                                               num_accesses=P))
    state.add_memlet_path(merge_reduce,
                          merge_exit,
                          hist_device_out,
                          memlet=dace.memlet.Memlet.simple(
                              hist_device_out, "nb"))

    copy_to_host_state = make_copy_to_host_state(sdfg)

    # State order: copy to FPGA -> compute -> copy back to host.
    sdfg.add_edge(copy_to_fpga_state, state, dace.graph.edges.InterstateEdge())
    sdfg.add_edge(state, copy_to_host_state, dace.graph.edges.InterstateEdge())

    return sdfg
示例#6
0
    def backward(
        forward_node: Node, context: BackwardContext,
        given_gradients: typing.List[typing.Optional[str]],
        required_gradients: typing.List[typing.Optional[str]]
    ) -> typing.Tuple[Node, BackwardResult]:
        """Build the backward node for a reduce node.

        Only ``dtypes.ReductionType.Sum`` is supported. For a sum, the backward
        pass is a nested SDFG with one mapped tasklet that iterates over the
        full input shape and accumulates (``wcr`` sum) the incoming gradient
        into every input position, i.e. the gradient is broadcast across the
        axes that were reduced.

        :param forward_node: the reduce node; its ``wcr`` and ``axes``
                             attributes are read.
        :param context: provides ``forward_state`` / ``forward_sdfg`` for
                        descriptor lookup and ``backward_state``, into which the
                        nested SDFG is inserted.
        :param given_gradients: output connector names with a gradient; must
                                contain exactly one entry.
        :param required_gradients: input connector names needing a gradient;
                                   must contain exactly one entry.
        :return: the nested SDFG node added to ``context.backward_state`` and
                 the ``BackwardResult`` mapping the connector names to
                 ``"input_gradient"`` / ``"output_gradient"``.
        :raises AutoDiffException: if either list does not contain exactly one
                                   entry, or the reduction is not a sum.
        """
        reduction_type = detect_reduction_type(forward_node.wcr)

        if len(given_gradients) != 1:
            raise AutoDiffException(
                "recieved invalid SDFG: reduce node {} should have exactly one output edge"
                .format(forward_node))

        if len(required_gradients) != 1:
            raise AutoDiffException(
                "recieved invalid SDFG: reduce node {} should have exactly one input edge"
                .format(forward_node))

        input_name = next(iter(required_gradients))
        in_desc = in_desc_with_name(forward_node, context.forward_state,
                                    context.forward_sdfg, input_name)

        output_name = next(iter(given_gradients))
        out_desc = out_desc_with_name(forward_node, context.forward_state,
                                      context.forward_sdfg, output_name)

        all_axes: typing.List[int] = list(range(len(in_desc.shape)))
        reduce_axes: typing.List[
            int] = all_axes if forward_node.axes is None else forward_node.axes
        non_reduce_axes: typing.List[int] = [
            axis for axis in all_axes if axis not in reduce_axes
        ]

        if reduction_type is not dtypes.ReductionType.Sum:
            raise AutoDiffException(
                "Unsupported reduction type '{}'".format(reduction_type))

        # For a sum we simply scatter the gradient across the reduced axes.
        result = BackwardResult.empty()
        type_tag = str(reduction_type).replace(".", "_")

        sdfg = SDFG("_reverse_" + type_tag + "_")
        state = sdfg.add_state()

        rev_input_conn_name = "input_gradient"
        rev_output_conn_name = "output_gradient"
        result.required_grad_names[output_name] = rev_output_conn_name
        result.given_grad_names[input_name] = rev_input_conn_name

        sdfg.add_array(rev_input_conn_name,
                       shape=out_desc.shape,
                       dtype=out_desc.dtype)
        sdfg.add_array(rev_output_conn_name,
                       shape=in_desc.shape,
                       dtype=in_desc.dtype)

        # Index into the incoming gradient with the surviving axes only. When
        # no axis survives (axes is None, or axes explicitly lists every input
        # axis) read element "0" rather than building an empty subset string.
        # NOTE(review): assumes a full reduction produces a single-element
        # output - confirm against the Reduce node's output descriptor.
        if non_reduce_axes:
            grad_subset = ",".join("i" + str(axis) for axis in non_reduce_axes)
        else:
            grad_subset = "0"

        map_ranges = {
            "i" + str(axis): "0:{}".format(extent)
            for axis, extent in enumerate(in_desc.shape)
        }

        state.add_mapped_tasklet(
            "_distribute_grad_" + type_tag + "_",
            map_ranges,
            {"__in": Memlet.simple(rev_input_conn_name, grad_subset)},
            "__out = __in",
            {
                "__out":
                Memlet.simple(rev_output_conn_name,
                              ",".join("i" + str(axis) for axis in all_axes),
                              wcr_str="lambda x, y: x + y")
            },
            external_edges=True)

        return context.backward_state.add_nested_sdfg(
            sdfg, None, {rev_input_conn_name},
            {rev_output_conn_name}), result