Ejemplo n.º 1
0
def test_tasklet_parameter():
    """
        Check that an SDFG constant can be referenced from an RTL tasklet.

        The SDFG holds two single-element int32 arrays ``A`` and ``B`` plus
        the constant ``MAX_VAL`` (42), which the SystemVerilog code uses by
        name. After running the program on random data, ``B`` is expected to
        equal that constant.
    """

    # SystemVerilog body of the tasklet, handed to DaCe verbatim.
    # NOTE(review): the last ``else`` has no begin/end, so ``state <= DONE``
    # is a separate statement of the always block rather than part of that
    # branch -- confirm this is intended before editing the RTL.
    rtl_code = '''
        /*
            Convention:
               |---------------------------------------------------------------------|
            -->| ap_aclk (clock input)                                               |
            -->| ap_areset (reset input, rst on high)                                |
               |                                                                     |
            -->| {inputs}                                              reg {outputs} |-->
               |                                                                     |
            <--| s_axis_a_tready (ready for data)       (data avail) m_axis_b_tvalid |-->
            -->| s_axis_a_tvalid (new data avail)    (data consumed) m_axis_b_tready |<--
               |---------------------------------------------------------------------|
        */

        typedef enum [1:0] {READY, BUSY, DONE} state_e;
        state_e state;

        always@(posedge ap_aclk) begin
            if (ap_areset) begin // case: reset
                m_axis_b_tdata <= 0;
                s_axis_a_tready <= 1'b1;
                state <= READY;
            end else if (s_axis_a_tvalid && state == READY) begin // case: load a
                m_axis_b_tdata <= s_axis_a_tdata;
                s_axis_a_tready <= 1'b0;
                state <= BUSY;
            end else if (m_axis_b_tdata < MAX_VAL) // case: increment counter b
                m_axis_b_tdata <= m_axis_b_tdata + 1;
            else
                m_axis_b_tdata <= m_axis_b_tdata;
                state <= DONE;
        end

        assign m_axis_b_tvalid  = (m_axis_b_tdata >= MAX_VAL) ? 1'b1:1'b0;
        '''

    # program with a single state
    program = dace.SDFG('rtl_tasklet_parameter')
    body = program.add_state()

    # one-element input and output containers
    program.add_array('A', [1], dtype=dace.int32)
    program.add_array('B', [1], dtype=dace.int32)

    # constant that the RTL code refers to by name
    program.add_constant("MAX_VAL", 42)

    # the SystemVerilog tasklet itself
    rtl_node = body.add_tasklet(name='rtl_tasklet',
                                inputs={'a'},
                                outputs={'b'},
                                code=rtl_code,
                                language=dace.Language.SystemVerilog)

    # wire A[0] -> tasklet.a and tasklet.b -> B[0]
    src = body.add_read('A')
    dst = body.add_write('B')
    body.add_edge(src, None, rtl_node, 'a', dace.Memlet('A[0]'))
    body.add_edge(rtl_node, 'b', dst, None, dace.Memlet('B[0]'))

    program.validate()

    # random host buffers, then run the program
    host_in = np.random.randint(0, 100, 1).astype(np.int32)
    host_out = np.random.randint(0, 100, 1).astype(np.int32)
    program(A=host_in, B=host_out)

    # the output must equal the constant
    assert host_out == program.constants["MAX_VAL"]
Ejemplo n.º 2
0
def mapfission_sdfg():
    """
    Build and validate the 'mapfission' SDFG.

    An outer map over ``i`` in ``0:2`` encloses four producers: tasklet
    'one', two inner maps over ``j`` in ``0:2`` (wrapping tasklets 'two' and
    'three') and the input-less tasklet 'scalar'. Tasklet 'four' combines
    their results and its output leaves the outer map into ``B[i]``.

    :return: The validated SDFG.
    """
    f64 = dace.float64
    simple = dace.Memlet.simple

    # Data containers: A and B are the external arrays, the rest are
    # transients
    result = dace.SDFG('mapfission')
    result.add_array('A', [4], f64)
    result.add_array('B', [2], f64)
    result.add_scalar('scal', f64, transient=True)
    result.add_scalar('s1', f64, transient=True)
    result.add_transient('s2', [2], f64)
    result.add_transient('s3out', [1], f64)
    body = result.add_state()

    # Nodes (creation order kept as-is)
    src = body.add_read('A')
    outer_in, outer_out = body.add_map('outer', {'i': '0:2'})
    one = body.add_tasklet('one', {'a'}, {'b'}, 'b = a[0] + a[1]')
    inner2_in, inner2_out = body.add_map('inner', {'j': '0:2'})
    two = body.add_tasklet('two', {'a'}, {'b'}, 'b = a * 2')
    s2_access = body.add_access('s2')
    s3_access = body.add_access('s3out')
    inner3_in, inner3_out = body.add_map('inner', {'j': '0:2'})
    three = body.add_tasklet('three', {'a'}, {'b'}, 'b = a[0] * 3')
    const = body.add_tasklet('scalar', {}, {'out'}, 'out = 5.0')
    four = body.add_tasklet(
        'four', {'ione', 'itwo', 'ithree', 'sc'}, {'out'},
        'out = ione + itwo[0] * itwo[1] + ithree + sc')
    sink = body.add_write('B')

    # Empty memlet places the input-less tasklet inside the outer map
    body.add_nedge(outer_in, const, dace.Memlet())

    # A -> tasklet 'one'
    body.add_memlet_path(src, outer_in, one,
                         dst_conn='a',
                         memlet=simple('A', '2*i:2*i+2'))

    # A -> first inner map -> tasklet 'two' -> s2
    body.add_memlet_path(src, outer_in, inner2_in, two,
                         dst_conn='a',
                         memlet=simple('A', '2*i+j'))
    body.add_memlet_path(two, inner2_out, s2_access,
                         src_conn='b',
                         memlet=simple('s2', 'j'))

    # A -> second inner map -> tasklet 'three' -> s3out
    body.add_memlet_path(src, outer_in, inner3_in, three,
                         dst_conn='a',
                         memlet=simple('A', '2*i:2*i+2'))
    body.add_memlet_path(three, inner3_out, s3_access,
                         src_conn='b',
                         memlet=simple('s3out', '0'))

    # Everything feeds tasklet 'four'
    body.add_edge(one, 'b', four, 'ione', simple('s1', '0'))
    body.add_edge(s2_access, None, four, 'itwo', simple('s2', '0:2'))
    body.add_edge(s3_access, None, four, 'ithree', simple('s3out', '0'))
    body.add_edge(const, 'out', four, 'sc', simple('scal', '0'))

    # Result leaves the outer map into B[i]
    body.add_memlet_path(four, outer_out, sink,
                         src_conn='out',
                         memlet=simple('B', 'i'))

    result.validate()
    return result
Ejemplo n.º 3
0
    def expansion(node: 'Reduce', state: SDFGState, sdfg: SDFG):
        """
        Expand a reduce node into a nested SDFG built from two nested maps.

        The nested SDFG ('reduce') reads ``_in`` and writes ``_out``. An outer
        map ('reduce_output') iterates over the output subset; inside it a
        'reset' tasklet stores ``node.identity`` in the one-element register
        transient ``acc``, and a sequential inner map ('reduce_values') sends
        every reduced input element to ``acc`` through a memlet whose
        write-conflict resolution is ``node.wcr``. ``acc`` is then written to
        ``_out``. Finally ``MapCollapse`` is applied repeatedly.

        As a side effect, the connectors of the node's first incoming and
        first outgoing edge are renamed to ``_in`` / ``_out`` and those
        connectors are added to ``node``.

        :param node: The reduce node to expand; ``node.identity`` must not be
                     None.
        :param state: State that contains ``node``.
        :param sdfg: SDFG that contains ``state``.
        :return: The nested SDFG implementing the reduction.
        """
        node.validate(sdfg, state)
        src_edge: graph.MultiConnectorEdge = state.in_edges(node)[0]
        dst_edge: graph.MultiConnectorEdge = state.out_edges(node)[0]

        # Work on squeezed copies of the subsets; squeeze() reports which
        # original dimensions were kept
        in_range = dcpy(src_edge.data.subset)
        in_kept = in_range.squeeze()
        out_range = dcpy(dst_edge.data.subset)
        out_kept = out_range.squeeze()
        output_dims = len(out_range)
        in_desc = sdfg.arrays[src_edge.data.data]
        out_desc = sdfg.arrays[dst_edge.data.data]

        # Scalar outputs squeeze down to nothing; keep one dimension
        if len(out_kept) == 0:
            out_kept = [0]

        # No explicit axes means "all input dimensions"; afterwards only the
        # axes that survived squeezing are kept
        axes = node.axes or list(range(len(src_edge.data.subset)))
        axes = [ax for ax in axes if ax in in_kept]

        assert node.identity is not None

        # Nested SDFG and its data containers
        nsdfg = SDFG('reduce')

        in_strides = [
            stride for dim, stride in enumerate(in_desc.strides)
            if dim in in_kept
        ]
        nsdfg.add_array('_in',
                        in_range.size(),
                        in_desc.dtype,
                        strides=in_strides,
                        storage=in_desc.storage)

        out_strides = [
            stride for dim, stride in enumerate(out_desc.strides)
            if dim in out_kept
        ]
        nsdfg.add_array('_out',
                        out_range.size(),
                        out_desc.dtype,
                        strides=out_strides,
                        storage=out_desc.storage)

        nsdfg.add_transient('acc', [1], nsdfg.arrays['_in'].dtype,
                            dtypes.StorageType.Register)

        nstate = nsdfg.add_state()

        # Input index per kept dimension: reduced axes use the inner-map
        # parameters (_iK), all others the outer-map parameters (_oK)
        reduced_count, kept_count = 0, 0
        input_subset = []
        for dim in in_kept:
            if dim not in axes:
                input_subset.append(f'_o{kept_count}')
                kept_count += 1
                continue
            input_subset.append(f'_i{reduced_count}')
            reduced_count += 1

        # Outer map: one parameter per output dimension
        outer_ranges = {
            f'_o{k}': f'0:{symstr(extent)}'
            for k, extent in enumerate(out_range.size())
        }
        ome, omx = nstate.add_map('reduce_output', outer_ranges)
        out_memlet = dace.Memlet.simple(
            '_out', ','.join(f'_o{k}' for k in range(output_dims)))
        in_memlet = dace.Memlet.simple('_in', ','.join(input_subset))

        # Tasklet that (re)initializes the accumulator with the identity
        reset = nstate.add_tasklet('reset', {}, {'o'},
                                   f'o = {node.identity}')
        nstate.add_edge(ome, None, reset, None, dace.Memlet())

        acc_before = nstate.add_access('acc')
        acc_after = nstate.add_access('acc')
        nstate.add_edge(reset, 'o', acc_before, None, dace.Memlet('acc'))

        # Inner sequential map over the reduced axes
        in_sizes = in_range.size()
        inner_ranges = {
            f'_i{k}': f'0:{symstr(in_sizes[in_kept.index(ax)])}'
            for k, ax in enumerate(sorted(axes))
        }
        ime, imx = nstate.add_map('reduce_values',
                                  inner_ranges,
                                  schedule=dtypes.ScheduleType.Sequential)

        # Forwarding tasklet; the reduction itself happens through the
        # write-conflict resolution on its output memlet
        forward = nstate.add_tasklet('identity', {'a', 'b'}, {'o'}, 'o = b')

        # Wire everything together
        reader = nstate.add_read('_in')
        writer = nstate.add_write('_out')
        nstate.add_memlet_path(reader,
                               ome,
                               ime,
                               forward,
                               memlet=in_memlet,
                               dst_conn='b')
        nstate.add_memlet_path(acc_before,
                               ime,
                               forward,
                               memlet=dace.Memlet('acc[0]'),
                               dst_conn='a')
        nstate.add_memlet_path(forward,
                               imx,
                               acc_after,
                               memlet=dace.Memlet('acc[0]', wcr=node.wcr),
                               src_conn='o')
        nstate.add_memlet_path(acc_after, omx, writer, memlet=out_memlet)

        # Outer edges and the node now use the nested SDFG's connector names
        src_edge._dst_conn = '_in'
        dst_edge._src_conn = '_out'
        node.add_in_connector('_in')
        node.add_out_connector('_out')

        from dace.transformation import dataflow
        nsdfg.apply_transformations_repeated(dataflow.MapCollapse)

        return nsdfg
Ejemplo n.º 4
0
def create_vadd_multibank_sdfg(bank_count_per_array=2,
                               ndim=1,
                               unroll_map_inside=False,
                               sdfg_name="vadd_hbm"):
    """
    Build a vector-addition SDFG whose arrays are spread over HBM banks.

    ``in1``, ``in2`` and ``out`` are float32 arrays with a leading dimension
    of size ``bank_count_per_array`` followed by one to three symbolic
    dimensions (``N``, ``M``, ``S``). Each array is tagged with memory type
    "hbm" and its own bank range. An unrolled map over ``k`` walks the
    leading dimension and a second map walks the remaining ones; the tasklet
    computes ``out = in1 + in2`` element-wise. FPGA transformations are
    applied before returning.

    :param bank_count_per_array: Size of the leading dimension and number of
                                 banks assigned to each array.
    :param ndim: Number of symbolic dimensions after the leading one; values
                 below 2 give one, 2 gives two, 3 or more gives three.
    :param unroll_map_inside: If true, the unrolled map over ``k`` is nested
                              inside the element map instead of around it.
    :param sdfg_name: Name of the created SDFG.
    :return: The SDFG after ``apply_fpga_transformations``.
    """
    N = dace.symbol("N")
    M = dace.symbol("M")
    S = dace.symbol("S")

    sdfg = dace.SDFG(sdfg_name)
    state = sdfg.add_state('vadd_hbm', True)

    # (map parameter, extent symbol, map range) for every symbolic dimension
    axes = [("i", N, "0:N")]
    if ndim >= 2:
        axes.append(("j", M, "0:M"))
    if ndim >= 3:
        axes.append(("t", S, "0:S"))
    shape = [bank_count_per_array] + [extent for _, extent, _ in axes]
    access_str = ", ".join(param for param, _, _ in axes)
    inner_map_range = {param: rng for param, _, rng in axes}

    # Register the arrays and place each one on its own range of HBM banks
    banks = bank_count_per_array
    placement = (
        ("in1", f"0:{banks}"),
        ("in2", f"{banks}:{2*banks}"),
        ("out", f"{2*banks}:{3*banks}"),
    )
    for array_name, bank_range in placement:
        desc = sdfg.add_array(array_name, shape, dace.float32)[1]
        desc.location["memorytype"] = "hbm"
        desc.location["bank"] = bank_range

    read_in1 = state.add_read("in1")
    read_in2 = state.add_read("in2")
    out_write = state.add_write("out")

    # Element-wise memlets; k selects the slice along the leading dimension
    element = {
        array_name: dace.Memlet(f"{array_name}[k, {access_str}]")
        for array_name in ("in1", "in2", "out")
    }

    outer_entry, outer_exit = state.add_map(
        "vadd_outer_map", {"k": f'0:{bank_count_per_array}'})
    map_entry, map_exit = state.add_map("vadd_inner_map", inner_map_range)
    tasklet = state.add_tasklet("addandwrite", {
        "__in1": None,
        "__in2": None
    }, {"__out": None}, '__out = __in1 + __in2')
    outer_entry.map.schedule = dace.ScheduleType.Unrolled

    # Pick the nesting order of the two maps (outermost scope first on the
    # way in, innermost scope first on the way out)
    if unroll_map_inside:
        entries, exits = (map_entry, outer_entry), (outer_exit, map_exit)
    else:
        entries, exits = (outer_entry, map_entry), (map_exit, outer_exit)

    state.add_memlet_path(read_in1,
                          *entries,
                          tasklet,
                          dst_conn="__in1",
                          memlet=element["in1"])
    state.add_memlet_path(read_in2,
                          *entries,
                          tasklet,
                          dst_conn="__in2",
                          memlet=element["in2"])
    state.add_memlet_path(tasklet,
                          *exits,
                          out_write,
                          src_conn="__out",
                          memlet=element["out"])

    sdfg.apply_fpga_transformations()
    return sdfg
Ejemplo n.º 5
0
    def make_sdfg(dtype=dace.float32):
        """
        Build an SDFG with two FPGA states of three vector-add kernels each.

        Host arrays ``x, y, v, w, xx, yy, vv, ww`` (inputs) and ``z, zz``
        (outputs) of symbolic length ``size`` each get a transient
        ``device_*`` twin in FPGA global storage. The state sequence is
        copy_to_device -> fpga_state_0 -> fpga_state_1 -> copy_to_host, with

            fpga_state_0:  device_z  = (x + y)   + (v + w)
            fpga_state_1:  device_zz = (xx + yy) + (vv + ww)

        where the two partial sums of each state go through the transients
        ``device_tmp0/1`` and ``device_tmp2/3`` respectively.

        :param dtype: Element type of all arrays.
        :return: The validated SDFG.
        """
        fpga_global = dace.dtypes.StorageType.FPGA_Global

        sdfg = dace.SDFG("multiple_kernels_multiple_states")
        n = dace.symbol("size")

        host_inputs = ["x", "y", "v", "w", "xx", "yy", "vv", "ww"]
        host_outputs = ["z", "zz"]
        device_scratch = [
            "device_tmp0", "device_tmp1", "device_tmp2", "device_tmp3"
        ]

        # Every host array gets a transient counterpart on the device
        for name in host_inputs + host_outputs:
            sdfg.add_array(name, shape=[n], dtype=dtype)
            sdfg.add_array(f"device_{name}",
                           shape=[n],
                           dtype=dtype,
                           storage=fpga_global,
                           transient=True)

        # Device-only intermediates
        for name in device_scratch:
            sdfg.add_array(name,
                           shape=[n],
                           dtype=dtype,
                           storage=fpga_global,
                           transient=True)

        def add_vec_add(state, tag, lhs, rhs, result):
            """Add one FPGA map + tasklet computing result[i] = lhs[i] + rhs[i]."""
            entry, leave = state.add_map(
                f'vecAdd_map{tag}',
                dict(i=f'0:{n}'),
                schedule=dace.dtypes.ScheduleType.FPGA_Device)
            adder = state.add_tasklet(f'vec_add_task{tag}',
                                      ['x_con', 'y_con'], ['z_con'],
                                      'z_con = x_con + y_con')
            state.add_memlet_path(lhs,
                                  entry,
                                  adder,
                                  dst_conn='x_con',
                                  memlet=dace.Memlet(f"{lhs.data}[i]"))
            state.add_memlet_path(rhs,
                                  entry,
                                  adder,
                                  dst_conn='y_con',
                                  memlet=dace.Memlet(f"{rhs.data}[i]"))
            state.add_memlet_path(adder,
                                  leave,
                                  result,
                                  src_conn='z_con',
                                  memlet=dace.Memlet(f"{result.data}[i]"))

        def add_fpga_state(index, operands, scratch, output):
            """Add state 'fpga_state_<index>': output = (a + b) + (c + d)."""
            state = sdfg.add_state(f"fpga_state_{index}")
            a, b, c, d = [
                state.add_read(f"device_{name}") for name in operands
            ]
            part0, part1 = [state.add_access(name) for name in scratch]
            total = state.add_write(f"device_{output}")

            add_vec_add(state, f"{index}0", a, b, part0)
            add_vec_add(state, f"{index}1", c, d, part1)
            add_vec_add(state, f"{index}2", part0, part1, total)
            return state

        # Host -> device copies
        copy_in_state = sdfg.add_state("copy_to_device")
        for name in host_inputs:
            host_node = copy_in_state.add_read(name)
            device_node = copy_in_state.add_read(f"device_{name}")
            copy_in_state.add_memlet_path(
                host_node,
                device_node,
                memlet=dace.Memlet(f"{name}[0:{n}]"))

        # Device -> host copies
        copy_out_state = sdfg.add_state("copy_to_host")
        for name in host_outputs:
            host_node = copy_out_state.add_write(name)
            device_node = copy_out_state.add_read(f"device_{name}")
            copy_out_state.add_memlet_path(
                device_node,
                host_node,
                memlet=dace.Memlet(f"{name}[0:{n}]"))

        # The two compute states
        fpga_state_0 = add_fpga_state(0, ("x", "y", "v", "w"),
                                      ("device_tmp0", "device_tmp1"), "z")
        fpga_state_1 = add_fpga_state(1, ("xx", "yy", "vv", "ww"),
                                      ("device_tmp2", "device_tmp3"), "zz")

        # Chain the states with unconditional interstate edges
        sdfg.add_edge(copy_in_state, fpga_state_0,
                      dace.sdfg.sdfg.InterstateEdge())
        sdfg.add_edge(fpga_state_0, fpga_state_1,
                      dace.sdfg.sdfg.InterstateEdge())
        sdfg.add_edge(fpga_state_1, copy_out_state,
                      dace.sdfg.sdfg.InterstateEdge())

        sdfg.fill_scope_connectors()
        sdfg.validate()
        return sdfg
Ejemplo n.º 6
0
def make_vec_mul_sdfg(dtype=dace.float32):
    """
    Build the 'vec_mul' SDFG: an element-wise product of two vector arrays.

    ``_device_x``, ``_device_y`` and ``_device_z`` live in FPGA global
    storage, have ``size / 4`` elements and use a 4-wide vector of ``dtype``
    as element type. A single FPGA-scheduled map multiplies the inputs into
    ``_device_z``.

    :param dtype: Base type of the 4-wide vector element type.
    :return: The validated SDFG.
    """
    lanes = 4
    n = dace.symbol("size")
    kernel = dace.SDFG("vec_mul")
    packed = dace.vector(dtype, lanes)
    body = kernel.add_state("vec_mul_state")

    # All three containers share shape, element type and storage
    for array_name in ('_device_x', '_device_y', '_device_z'):
        kernel.add_array(array_name,
                         shape=[n / lanes],
                         dtype=packed,
                         storage=dace.dtypes.StorageType.FPGA_Global)

    lhs = body.add_read("_device_x")
    rhs = body.add_read("_device_y")
    product = body.add_write("_device_z")

    # One map iteration per vector element
    enter, leave = body.add_map(
        'vecMul_map', {'i': f'0:{n}/{lanes}'},
        schedule=dace.dtypes.ScheduleType.FPGA_Device)

    multiply = body.add_tasklet('vecMul_task', ['x_con', 'y_con'], ['z_con'],
                                'z_con = x_con * y_con')

    body.add_memlet_path(lhs,
                         enter,
                         multiply,
                         memlet=dace.Memlet(f"{lhs.data}[i]"),
                         dst_conn='x_con')
    body.add_memlet_path(rhs,
                         enter,
                         multiply,
                         memlet=dace.Memlet(f"{rhs.data}[i]"),
                         dst_conn='y_con')
    body.add_memlet_path(multiply,
                         leave,
                         product,
                         memlet=dace.Memlet(f"{product.data}[i]"),
                         src_conn='z_con')

    # Fill in the map scope connectors, then validate
    kernel.fill_scope_connectors()
    kernel.validate()
    return kernel
Ejemplo n.º 7
0
            m_axis_b_tdata <= m_axis_b_tdata + 1;
        else
            m_axis_b_tdata <= m_axis_b_tdata;
            state <= DONE;
    end    

    assign m_axis_b_tvalid  = (m_axis_b_tdata >= MAX_VAL) ? 1'b1:1'b0;  
    ''',
                            language=dace.Language.SystemVerilog)

# Access nodes for the input array 'A' and the output array 'B'
A = state.add_read('A')
B = state.add_write('B')

# Wire A[0] into the tasklet's 'a' connector and its 'b' connector into B[0]
state.add_edge(A, None, tasklet, 'a', dace.Memlet('A[0]'))
state.add_edge(tasklet, 'b', B, None, dace.Memlet('B[0]'))

# Validate the SDFG at module level, i.e. also when the file is imported
sdfg.validate()

######################################################################

if __name__ == '__main__':

    # Host buffers: one random int32 in [0, 100) as input, a zeroed output
    a = np.random.randint(0, 100, 1).astype(np.int32)
    b = np.array([0]).astype(np.int32)

    # Print the buffers as initialized above
    print("a={}, b={}".format(a, b))
Ejemplo n.º 8
0
def test_tasklet_array():
    """
        Run an RTL tasklet over whole arrays and check it adds 42.

        ``A`` and ``B`` are int32 arrays of symbolic length ``N`` (set to
        128). The SystemVerilog tasklet outputs ``input + 42`` for every
        element it accepts, so ``B`` must equal ``A + 42`` afterwards.
    """

    n = 128
    N = dace.symbol('N')
    N.set(n)

    # SystemVerilog body of the tasklet, handed to DaCe verbatim
    rtl_code = '''
        always@(posedge ap_aclk) begin
            if (ap_areset) begin
                s_axis_a_tready <= 1;
                m_axis_b_tvalid <= 0;
                m_axis_b_tdata <= 0;
            end else if (s_axis_a_tvalid && s_axis_a_tready) begin
                s_axis_a_tready <= 0;
                m_axis_b_tvalid <= 1;
                m_axis_b_tdata <= s_axis_a_tdata + 42;
            end else if (m_axis_b_tvalid && m_axis_b_tready) begin
                s_axis_a_tready <= 1;
                m_axis_b_tvalid <= 0;
                m_axis_b_tdata <= 0;
            end
        end
        '''

    # program with a single state
    program = dace.SDFG('rtl_tasklet_array')
    body = program.add_state()

    # input and output containers of length N
    program.add_array('A', [N], dtype=dace.int32)
    program.add_array('B', [N], dtype=dace.int32)

    # the SystemVerilog tasklet itself
    rtl_node = body.add_tasklet(name='rtl_tasklet',
                                inputs={'a'},
                                outputs={'b'},
                                code=rtl_code,
                                language=dace.Language.SystemVerilog)

    # hand the full arrays to / from the tasklet
    src = body.add_read('A')
    dst = body.add_write('B')
    body.add_edge(src, None, rtl_node, 'a', dace.Memlet('A[0:N]'))
    body.add_edge(rtl_node, 'b', dst, None, dace.Memlet('B[0:N]'))

    # fix N to its concrete value, then validate
    length = N.get()
    program.specialize({'N': length})
    program.validate()

    # random input, zeroed output
    host_in = np.random.randint(0, 100, length).astype(np.int32)
    host_out = np.zeros((length, )).astype(np.int32)

    program(A=host_in, B=host_out)

    # every element must have been incremented by 42
    assert (host_out == host_in + 42).all()
Ejemplo n.º 9
0
    end
    ''',
                                language=dace.Language.SystemVerilog)

# Two pass-through tasklets ('out = inp'): one on the input side of the RTL
# tasklet, one on its output side
read_a = state.add_tasklet('read_a', {'inp'}, {'out'}, 'out = inp')
write_b = state.add_tasklet('write_b', {'inp'}, {'out'}, 'out = inp')

# One FPGA-scheduled map over i in 0:N for each pass-through tasklet
read_a_entry, read_a_exit = state.add_map('read_a_map', dict(i='0:N'), schedule=dace.ScheduleType.FPGA_Device)
write_b_entry, write_b_exit = state.add_map('write_b_map', dict(i='0:N'), schedule=dace.ScheduleType.FPGA_Device)

# Reader: fpga_A[i] -> read_a -> A_stream[0]
# (presumably A_stream/B_stream are stream containers declared earlier --
# TODO confirm)
read_a_inp = state.add_read('fpga_A')
read_a_out = state.add_write('A_stream')
state.add_memlet_path(read_a_inp, read_a_entry, read_a, dst_conn='inp', memlet=dace.Memlet('fpga_A[i]'))
state.add_memlet_path(read_a, read_a_exit, read_a_out, src_conn='out', memlet=dace.Memlet('A_stream[0]'))

# RTL tasklet: A_stream[0] -> connector 'a', connector 'b' -> B_stream[0];
# separate access nodes are created for this side of each container
A = state.add_read('A_stream')
B = state.add_write('B_stream')
state.add_memlet_path(A, rtl_tasklet, dst_conn='a', memlet=dace.Memlet('A_stream[0]'))
state.add_memlet_path(rtl_tasklet, B, src_conn='b', memlet=dace.Memlet('B_stream[0]'))

# Writer: B_stream[0] -> write_b -> fpga_B[i]
write_b_inp = state.add_read('B_stream')
write_b_out = state.add_write('fpga_B')
state.add_memlet_path(write_b_inp, write_b_entry, write_b, dst_conn='inp', memlet=dace.Memlet('B_stream[0]'))
state.add_memlet_path(write_b, write_b_exit, write_b_out, src_conn='out', memlet=dace.Memlet('fpga_B[i]'))
# add copy to device state
Ejemplo n.º 10
0
def test_tasklet_map():
    '''
        Check map support for RTL tasklets: a SystemVerilog adder is placed inside a map over M rows,
        each row holding N vectors of W 32-bit integers, and the result is compared against NumPy.
    '''
    # concrete problem sizes
    n, m, w = 512, 8, 4

    # symbols carrying those sizes
    N = dace.symbol('N')
    M = dace.symbol('M')
    W = dace.symbol('W')
    for symbol, value in ((N, n), (M, m), (W, w)):
        symbol.set(value)

    # sdfg with a single state
    sdfg = dace.SDFG('rtl_tasklet_map')
    state = sdfg.add_state()

    # A, B and C share one layout: M x N elements, each a W-wide int32 vector
    for array_name in ('A', 'B', 'C'):
        sdfg.add_array(array_name, [M, N], dtype=dace.vector(dace.int32, W.get()))

    # one map iteration per row k
    map_entry, map_exit = state.add_map('compute_map', {'k': '0:M'})

    # RTL body: latch one element from each input stream, then emit their sum on c
    rtl_code = '''
reg [W-1:0][31:0] a_data;
reg a_valid;
reg [W-1:0][31:0] b_data;
reg b_valid;

// Read A
always@(posedge ap_aclk) begin
    if (ap_areset) begin
        s_axis_a_tready <= 0;
        a_valid <= 0;
        a_data <= 0;
    end else begin
        if (s_axis_a_tready && s_axis_a_tvalid) begin
            a_valid <= 1;
            a_data <= s_axis_a_tdata;
            s_axis_a_tready <= 0;
        end else if (m_axis_c_tvalid && m_axis_c_tready) begin
            a_valid <= 0;
            s_axis_a_tready <= 1;
        end else begin
            s_axis_a_tready <= ~a_valid;
        end
    end
end

// Read B
always@(posedge ap_aclk) begin
    if (ap_areset) begin
        s_axis_b_tready <= 0;
        b_valid <= 0;
        b_data <= 0;
    end else begin
        if (s_axis_b_tready && s_axis_b_tvalid) begin
            b_valid <= 1;
            b_data <= s_axis_b_tdata;
            s_axis_b_tready <= 0;
        end else if (m_axis_c_tvalid && m_axis_c_tready) begin
            b_valid <= 0;
            b_data <= 0;
            s_axis_b_tready <= 1;
        end else begin
            s_axis_b_tready <= ~b_valid;
        end
    end
end

// Compute and write C
always@(posedge ap_aclk) begin
    if (ap_areset) begin
        m_axis_c_tvalid <= 0;
        m_axis_c_tdata <= 0;
    end else begin
        if (m_axis_c_tvalid && m_axis_c_tready) begin
            m_axis_c_tvalid <= 0;
        end else if (a_valid && b_valid) begin
            m_axis_c_tvalid <= 1;
            m_axis_c_tdata <= a_data + b_data;
        end
    end
end'''
    adder = state.add_tasklet(name='rtl_tasklet1',
                              inputs={'a', 'b'},
                              outputs={'c'},
                              code=rtl_code,
                              language=dace.Language.SystemVerilog)

    # access nodes for the two operands and the result
    node_a = state.add_read('A')
    node_b = state.add_read('B')
    node_c = state.add_write('C')

    # row k of each operand flows through the map into the adder; the sum flows back out to C
    state.add_memlet_path(node_a, map_entry, adder, dst_conn='a', memlet=dace.Memlet('A[k,0:N]'))
    state.add_memlet_path(node_b, map_entry, adder, dst_conn='b', memlet=dace.Memlet('B[k,0:N]'))
    state.add_memlet_path(adder, map_exit, node_c, src_conn='c', memlet=dace.Memlet('C[k,0:N]'))

    sdfg.specialize({'M': M, 'N': N, 'W': W})
    sdfg.validate()

    # host buffers: random operands and a zeroed result, all int32 of shape (m, n, w)
    shape = (m, n, w)
    a = np.random.randint(0, 100, m * n * w).reshape(shape).astype(np.int32)
    b = np.random.randint(0, 100, m * n * w).reshape(shape).astype(np.int32)
    c = np.zeros(shape).astype(np.int32)

    # run the program
    sdfg(A=a, B=b, C=c)

    # every element of C must be the sum of the matching elements of A and B
    assert np.all(c == a + b)
Ejemplo n.º 11
0
def test_tasklet_double_clk_counters():
    """
        Test double clock functionality utilizing two counters, one for each clock.
        The low 16 bits of the result should contain the count from the "slow" clock.
        The high 16 bits of the result should contain the count from the "fast" clock, i.e. slow count * 2.

        The Xilinx frequency setting is overridden for the duration of the test and is always
        restored afterwards, even if building or running the SDFG raises.
    """
    freq_key = ('compiler', 'xilinx', 'frequency')
    old_freq = dace.config.Config.get(*freq_key)
    dace.config.Config.set(*freq_key, value='"0:300\\|1:600"')
    try:
        sdfg = dace.SDFG('rtl_tasklet_double_clk_counters')
        state = sdfg.add_state()
        sdfg.add_array('A', [1], dtype=dace.int32)
        sdfg.add_array('B', [1], dtype=dace.int32)

        # s_cnt is clocked by ap_aclk and d_cnt by ap_aclk_2; both count while s_cnt is below the loaded limit
        tasklet = state.add_tasklet(name='rtl_tasklet',
                                    inputs={'a'},
                                    outputs={'b'},
                                    code='''

    reg [31:0] max_cnt;
    reg [15:0] s_cnt;
    reg        s_done;
    reg [15:0] d_cnt;
    reg        d_done;

    always @(posedge ap_aclk) begin
        if (ap_areset) begin
            s_axis_a_tready <= 1;
        end else if (s_axis_a_tvalid && s_axis_a_tready) begin
            max_cnt <= s_axis_a_tdata;
            s_axis_a_tready <= 0;
        end else if (m_axis_b_tvalid && m_axis_b_tready) begin
            s_axis_a_tready <= 1;
        end
    end

    always @(posedge ap_aclk) begin
        if (ap_areset) begin
            s_cnt <= 0;
            s_done <= 0;
        end else if (s_cnt < max_cnt[15:0]) begin
            s_cnt <= s_cnt + 1;
            s_done <= 0;
        end else begin
            s_done <= max_cnt > 0;
        end
    end

    always @(posedge ap_aclk_2) begin
        if (ap_areset) begin
            d_cnt <= 0;
            d_done <= 0;
        end else if (s_cnt < max_cnt[15:0]) begin
            d_cnt <= d_cnt + 1;
            d_done <= 0;
        end else begin
            d_done <= max_cnt > 0;
        end
    end

    always @(posedge ap_aclk) begin
        if (ap_areset) begin
            m_axis_b_tvalid <= 0;
            m_axis_b_tdata <= 0;
        end else begin
            m_axis_b_tvalid <= s_done && d_done;
            m_axis_b_tdata[15:0]  <= s_cnt;
            m_axis_b_tdata[31:16] <= d_cnt;
        end
    end
                                ''',
                                    language=dace.Language.SystemVerilog)
        A = state.add_read('A')
        B = state.add_write('B')

        state.add_edge(A, None, tasklet, 'a', dace.Memlet('A[0]'))
        state.add_edge(tasklet, 'b', B, None, dace.Memlet('B[0]'))

        sdfg.validate()

        # a[0] is the count limit handed to the tasklet; b[0] receives the packed counters
        a = np.random.randint(0, 100, 1).astype(np.int32)
        b = np.zeros((1, )).astype(np.int32)

        sdfg(A=a, B=b)
    finally:
        # Restore the global setting on every exit path so a failure here cannot leak into other tests.
        dace.config.Config.set(*freq_key, value=old_freq)

    # low half: slow-clock count; high half: fast-clock count
    assert b[0] & 0xFFFF == a[0]
    assert (b[0] >> 16) & 0xFFFF == a[0] * 2
Ejemplo n.º 12
0
def test_multi_tasklet():
    """
        Test multiple rtl tasklet support.

        Two SystemVerilog tasklets are chained through array B: the first loads A[0] and counts
        up to 80 into B[0], the second loads B[0] and counts up to 100 into C[0].
    """

    # add sdfg
    sdfg = dace.SDFG('rtl_multi_tasklet')

    # add state
    state = sdfg.add_state()

    # add arrays
    sdfg.add_array('A', [1], dtype=dace.int32)
    sdfg.add_array('B', [1], dtype=dace.int32)
    sdfg.add_array('C', [1], dtype=dace.int32)

    # add first rtl tasklet (a -> b, saturating at 80)
    # NOTE: the last two branches are wrapped in begin/end. Without them, `state <= DONE;` is not
    # part of the else branch: it runs on every clock edge and overrides the READY/BUSY
    # assignments, so the "load a" branch can never be taken.
    tasklet0 = state.add_tasklet(name='rtl_tasklet0',
                                 inputs={'a'},
                                 outputs={'b'},
                                 code='''
        typedef enum [1:0] {READY, BUSY, DONE} state_e;
        state_e state;

        always@(posedge ap_aclk) begin
            if (ap_areset) begin // case: reset
                m_axis_b_tdata <= 0;
                s_axis_a_tready <= 1'b1;
                state <= READY;
            end else if (s_axis_a_tvalid && state == READY) begin // case: load a
                m_axis_b_tdata <= s_axis_a_tdata;
                s_axis_a_tready <= 1'b0;
                state <= BUSY;
            end else if (m_axis_b_tdata < 80) begin // case: increment counter b
                m_axis_b_tdata <= m_axis_b_tdata + 1;
            end else begin
                m_axis_b_tdata <= m_axis_b_tdata;
                state <= DONE;
            end
        end

        assign m_axis_b_tvalid = (m_axis_b_tdata >= 80) ? 1'b1:1'b0;
        ''',
                                 language=dace.Language.SystemVerilog)

    # add second rtl tasklet (b -> c, saturating at 100); same begin/end fix as above
    tasklet1 = state.add_tasklet(name='rtl_tasklet1',
                                 inputs={'b'},
                                 outputs={'c'},
                                 code='''
        typedef enum [1:0] {READY, BUSY, DONE} state_e;
        state_e state;

        always@(posedge ap_aclk) begin
            if (ap_areset) begin // case: reset
                m_axis_c_tdata <= 0;
                s_axis_b_tready <= 1'b1;
                state <= READY;
            end else if (s_axis_b_tvalid && state == READY) begin // case: load a
                m_axis_c_tdata <= s_axis_b_tdata;
                s_axis_b_tready <= 1'b0;
                state <= BUSY;
            end else if (m_axis_c_tdata < 100) begin // case: increment counter b
                m_axis_c_tdata <= m_axis_c_tdata + 1;
            end else begin
                m_axis_c_tdata <= m_axis_c_tdata;
                state <= DONE;
            end
        end

        assign m_axis_c_tvalid = (m_axis_c_tdata >= 100) ? 1'b1:1'b0;
        ''',
                                 language=dace.Language.SystemVerilog)

    # add input/output array (B gets separate write and read access nodes)
    A = state.add_read('A')
    B_w = state.add_write('B')
    B_r = state.add_read('B')
    C = state.add_write('C')

    # connect input/output array with the tasklet
    state.add_edge(A, None, tasklet0, 'a', dace.Memlet('A[0]'))
    state.add_edge(tasklet0, 'b', B_w, None, dace.Memlet('B[0]'))
    state.add_edge(B_r, None, tasklet1, 'b', dace.Memlet('B[0]'))
    state.add_edge(tasklet1, 'c', C, None, dace.Memlet('C[0]'))

    # validate sdfg
    sdfg.validate()

    # Execute

    # init data structures (a starts strictly below the first tasklet's limit of 80)
    a = np.random.randint(0, 80, 1).astype(np.int32)
    b = np.array([0]).astype(np.int32)
    c = np.array([0]).astype(np.int32)

    # call program
    sdfg(A=a, B=b, C=c)

    # check result: each counter must stop exactly at its limit
    assert b == 80
    assert c == 100
Ejemplo n.º 13
0
def test_tasklet_vector_conversion():
    """
        Check that an N-element int32 array can be handed to an rtl tasklet as a single
        N-wide vector input; the tasklet counts from a[0] up to a[0] + a[1].
    """

    # symbolic vector length
    N = dace.symbol('N')

    # sdfg, with N fixed to 4 at compile time
    sdfg = dace.SDFG('rtl_tasklet_vector_conversion')
    sdfg.specialize(dict(N=4))

    # single state
    state = sdfg.add_state()

    # A is a plain int32 array of length N, B a single int32
    sdfg.add_array('A', [N], dtype=dace.int32)
    sdfg.add_array('B', [1], dtype=dace.int32)

    # rtl body of the tasklet
    rtl_code = '''
        /*
            Convention:
               |---------------------------------------------------------------------|
            -->| ap_aclk (clock input)                                               |
            -->| ap_areset (reset input, rst on high)                                |
               |                                                                     |
            -->| {inputs}                                              reg {outputs} |-->
               |                                                                     |
            <--| s_axis_a_tready (ready for data)       (data avail) m_axis_b_tvalid |-->
            -->| s_axis_a_tvalid (new data avail)    (data consumed) m_axis_b_tready |<--
               |---------------------------------------------------------------------|
        */

        typedef enum [1:0] {READY, BUSY, DONE} state_e;
        state_e state;

        always@(posedge ap_aclk) begin
            if (ap_areset) begin // case: reset
                m_axis_b_tdata <= 0;
                s_axis_a_tready <= 1'b1;
                state <= READY;
            end else if (s_axis_a_tvalid && state == READY) begin // case: load a
                m_axis_b_tdata <= s_axis_a_tdata[0];
                s_axis_a_tready <= 1'b0;
                state <= BUSY;
            end else if (m_axis_b_tdata < s_axis_a_tdata[0] + s_axis_a_tdata[1] && state == BUSY) begin // case: increment counter b
                m_axis_b_tdata <= m_axis_b_tdata + 1;
            end else if (state == BUSY) begin
                m_axis_b_tdata <= m_axis_b_tdata;
                state <= DONE;
            end
        end

        assign m_axis_b_tvalid  = (m_axis_b_tdata >= s_axis_a_tdata[0] + s_axis_a_tdata[1] && (state == BUSY || state == DONE)) ? 1'b1:1'b0;
        '''

    # the input connector is explicitly typed as an N-wide int32 vector
    converter = state.add_tasklet(name='rtl_tasklet',
                                  inputs={'a': dace.vector(dace.int32, N)},
                                  outputs={'b'},
                                  code=rtl_code,
                                  language=dace.Language.SystemVerilog)

    # access nodes
    src = state.add_read('A')
    dst = state.add_write('B')

    # the whole of A goes in on 'a'; the single result comes out on 'b'
    state.add_edge(src, None, converter, 'a', dace.Memlet('A[0:N]'))
    state.add_edge(converter, 'b', dst, None, dace.Memlet('B[0]'))

    # validate sdfg
    sdfg.validate()

    # Execute

    # host buffers: N random inputs and one zeroed output
    length = dace.symbolic.evaluate(N, sdfg.constants)
    a = np.random.randint(0, 100, length).astype(np.int32)
    b = np.array([0], dtype=np.int32)

    # run the program
    sdfg(A=a, B=b)

    # the counter must stop at the sum of the first two inputs
    assert b == a[0] + a[1]
Ejemplo n.º 14
0
def test_tasklet_vector_add():
    """
        Check rtl tasklet support for vector data: every lane of a W-wide int32 vector
        is incremented by 42.
    """

    # symbolic vector width
    W = dace.symbol('W')

    # sdfg, with W fixed to 4 at compile time
    sdfg = dace.SDFG('rtl_tasklet_vector_add')
    sdfg.specialize(dict(W=4))

    # concrete width, resolved once from the sdfg constants
    veclen = dace.symbolic.evaluate(W, sdfg.constants)

    # single state
    state = sdfg.add_state()

    # A and B each hold one W-wide int32 vector
    vec_type_a = dace.vector(dace.int32, veclen)
    sdfg.add_array('A', [1], dtype=vec_type_a)
    vec_type_b = dace.vector(dace.int32, veclen)
    sdfg.add_array('B', [1], dtype=vec_type_b)

    # rtl body of the tasklet
    rtl_code = '''
    always@(posedge ap_aclk) begin
        if (ap_areset) begin
            s_axis_a_tready <= 1;
            m_axis_b_tvalid <= 0;
            m_axis_b_tdata <= 0;
        end else if (s_axis_a_tvalid && s_axis_a_tready) begin
            s_axis_a_tready <= 0;
            m_axis_b_tvalid <= 1;
            for (int i = 0; i < W; i++) begin
                m_axis_b_tdata[i] <= s_axis_a_tdata[i] + 42;
            end
        end else if (m_axis_b_tvalid && m_axis_b_tready) begin
            s_axis_a_tready <= 1;
            m_axis_b_tvalid <= 0;
            m_axis_b_tdata <= 0;
        end
    end
        '''
    adder = state.add_tasklet(name='rtl_tasklet',
                              inputs={'a'},
                              outputs={'b'},
                              code=rtl_code,
                              language=dace.Language.SystemVerilog)

    # access nodes
    src = state.add_read('A')
    dst = state.add_write('B')

    # one vector in on 'a', one vector out on 'b'
    state.add_edge(src, None, adder, 'a', dace.Memlet('A[0]'))
    state.add_edge(adder, 'b', dst, None, dace.Memlet('B[0]'))

    # validate sdfg
    sdfg.validate()

    # Execute

    # host buffers: W random inputs and W zeroed outputs
    a = np.random.randint(0, 100, (veclen, )).astype(np.int32)
    b = np.zeros((veclen, )).astype(np.int32)

    # run the program
    sdfg(A=a, B=b)

    # show both buffers, then check every lane
    print(a)
    print(b)
    assert np.all(b == a + 42)
Ejemplo n.º 15
0
#############

# Create top-level SDFG
# (sdfg and sub_sdfg are defined earlier in this file)
state = sdfg.add_state('s0')
# Map 'mymap' iterates k over 0:2; the nested SDFG node is placed between its entry and exit
me, mx = state.add_map('mymap', dict(k='0:2'))
# NOTE: The names of the inputs/outputs of the nested SDFG ('sA', 'sC') must match
#       the array names declared for it earlier in this file!
nsdfg = state.add_nested_sdfg(sub_sdfg, sdfg, {'sA'}, {'sC'})
# Array A is both the source and the destination, so it gets a read and a write access node
Ain = state.add_read('A')
Aout = state.add_write('A')

# Connect dataflow nodes
# A[k] -> map entry -> nested SDFG input 'sA'
state.add_memlet_path(Ain,
                      me,
                      nsdfg,
                      memlet=dace.Memlet(data='A', subset='k'),
                      dst_conn='sA')
# nested SDFG output 'sC' -> map exit -> A[k]
state.add_memlet_path(nsdfg,
                      mx,
                      Aout,
                      memlet=dace.Memlet(data='A', subset='k'),
                      src_conn='sC')
###

# Validate correctness of SDFG
sdfg.validate()

######################################
if __name__ == '__main__':
    # two random float32 values
    a = np.random.rand(2).astype(np.float32)
    # two zeros; np.zeros defaults to float64, unlike 'a'
    b = np.zeros([2])
Ejemplo n.º 16
0
def make_sdfg():
    """ Creates three SDFG nested within each other, where two input arrays and
        two output arrays are fed throughout the hierarchy. One input and one
        output are not used for anything in the innermost SDFG, and can thus be
        removed in all nestings.

        A second, "isolated" nested SDFG is also placed inside a map in the
        outer state; its only connectors are unused ones.

        :return: The outermost SDFG ("prune_connectors_test").
    """

    n = dace.symbol("N")

    sdfg_outer = dace.SDFG("prune_connectors_test")
    sdfg_outer.set_global_code("#include <fstream>\n#include <mutex>")
    state_outer = sdfg_outer.add_state("state_outer")
    sdfg_outer.add_symbol("N", dace.int32)

    sdfg_middle = dace.SDFG("middle")
    sdfg_middle.add_symbol("N", dace.int32)
    nsdfg_middle = state_outer.add_nested_sdfg(
        sdfg_middle,
        sdfg_outer, {"read_used_middle", "read_unused_middle"},
        {"write_used_middle", "write_unused_middle"},
        name="middle")
    state_middle = sdfg_middle.add_state("middle")

    entry_middle, exit_middle = state_middle.add_map("map_middle",
                                                     {"i": "0:N"})

    sdfg_inner = dace.SDFG("inner")
    sdfg_inner.add_symbol("N", dace.int32)
    nsdfg_inner = state_middle.add_nested_sdfg(
        sdfg_inner,
        sdfg_middle, {"read_used_inner", "read_unused_inner"},
        {"write_used_inner", "write_unused_inner"},
        name="inner")
    state_inner = sdfg_inner.add_state("inner")

    entry_inner, exit_inner = state_inner.add_map("map_inner", {"j": "0:N"})
    tasklet = state_inner.add_tasklet("tasklet", {"read_tasklet"},
                                      {"write_tasklet"},
                                      "write_tasklet = read_tasklet + 1")

    # Wire both the unused and the used arrays through the outer and middle
    # levels; only the "used" pair is connected to the tasklet further down.
    for s in ["unused", "used"]:

        # Read

        sdfg_outer.add_array(f"read_{s}", [n, n], dace.uint16)
        sdfg_outer.add_array(f"read_{s}_outer", [n, n], dace.uint16)
        sdfg_middle.add_array(f"read_{s}_middle", [n, n], dace.uint16)
        sdfg_inner.add_array(f"read_{s}_inner", [n], dace.uint16)

        read_outer = state_outer.add_read(f"read_{s}")
        read_middle = state_middle.add_read(f"read_{s}_middle")

        state_outer.add_memlet_path(read_outer,
                                    nsdfg_middle,
                                    dst_conn=f"read_{s}_middle",
                                    memlet=dace.Memlet(f"read_{s}[0:N, 0:N]"))
        state_middle.add_memlet_path(
            read_middle,
            entry_middle,
            nsdfg_inner,
            dst_conn=f"read_{s}_inner",
            memlet=dace.Memlet(f"read_{s}_middle[i, 0:N]"))

        # Write

        sdfg_outer.add_array(f"write_{s}", [n, n], dace.uint16)
        sdfg_outer.add_array(f"write_{s}_outer", [n, n], dace.uint16)
        sdfg_middle.add_array(f"write_{s}_middle", [n, n], dace.uint16)
        sdfg_inner.add_array(f"write_{s}_inner", [n], dace.uint16)

        write_outer = state_outer.add_write(f"write_{s}")
        write_middle = state_middle.add_write(f"write_{s}_middle")

        state_outer.add_memlet_path(nsdfg_middle,
                                    write_outer,
                                    src_conn=f"write_{s}_middle",
                                    memlet=dace.Memlet(f"write_{s}[0:N, 0:N]"))
        state_middle.add_memlet_path(
            nsdfg_inner,
            exit_middle,
            write_middle,
            src_conn=f"write_{s}_inner",
            memlet=dace.Memlet(f"write_{s}_middle[i, 0:N]"))

    # Innermost level: only the "used" arrays reach the tasklet. The array
    # names are spelled out instead of being built from the loop variable `s`,
    # which only happened to equal "used" after the loop finished.
    read_inner = state_inner.add_read("read_used_inner")
    write_inner = state_inner.add_write("write_used_inner")

    state_inner.add_memlet_path(read_inner,
                                entry_inner,
                                tasklet,
                                dst_conn="read_tasklet",
                                memlet=dace.Memlet("read_used_inner[j]"))

    state_inner.add_memlet_path(tasklet,
                                exit_inner,
                                write_inner,
                                src_conn="write_tasklet",
                                memlet=dace.Memlet("write_used_inner[j]"))

    # Create mapped nested SDFG where the map entry and exit would be orphaned
    # by pruning the read and write, and must have nedges added to them

    isolated_read = state_outer.add_read("read_unused_outer")
    isolated_write = state_outer.add_write("write_unused_outer")
    isolated_sdfg = dace.SDFG("isolated_sdfg")
    isolated_nsdfg = state_outer.add_nested_sdfg(isolated_sdfg,
                                                 sdfg_outer,
                                                 {"read_unused_isolated"},
                                                 {"write_unused_isolated"},
                                                 name="isolated")
    # The nested SDFG receives the map index i as a symbol
    isolated_sdfg.add_symbol("i", dace.int32)
    isolated_nsdfg.symbol_mapping["i"] = "i"
    isolated_entry, isolated_exit = state_outer.add_map(
        "isolated", {"i": "0:N"})
    state_outer.add_memlet_path(
        isolated_read,
        isolated_entry,
        isolated_nsdfg,
        dst_conn="read_unused_isolated",
        memlet=dace.Memlet("read_unused_outer[0:N, 0:N]"))
    state_outer.add_memlet_path(
        isolated_nsdfg,
        isolated_exit,
        isolated_write,
        src_conn="write_unused_isolated",
        memlet=dace.Memlet("write_unused_outer[0:N, 0:N]"))
    isolated_state = isolated_sdfg.add_state("isolated")
    # Connector-less C++ tasklet: appends i to a text file under a static mutex
    isolated_state.add_tasklet("isolated", {}, {},
                               """\
static std::mutex mutex;
std::unique_lock<std::mutex> lock(mutex);
std::ofstream of("prune_connectors_test.txt", std::ofstream::app);
of << i << "\\n";""",
                               language=dace.Language.CPP)

    return sdfg_outer
Ejemplo n.º 17
0
def make_sdfg(name="fpga_stcl_test", dtype=dace.float32, veclen=8):
    """
    Build an FPGA SDFG that applies a 4-point stencil inside a pipeline scope.

    The SDFG has three states: copy host array "a" to "a_device", compute,
    and copy "b_device" back to host array "b". In the compute state the
    input is streamed through a shift register, an unrolled map of ``veclen``
    tasklets averages the north/west/east/south neighbours (substituting 1
    outside the domain), and the packed result is streamed back to memory.

    :param name: Name of the SDFG; also used as a prefix for states, scopes
                 and tasklets.
    :param dtype: Scalar element type.
    :param veclen: Number of elements per vector.
    :return: The constructed SDFG.
    """

    vtype = dace.vector(dtype, veclen)

    n = dace.symbol("N")
    m = dace.symbol("M")

    sdfg = dace.SDFG(name)

    pre_state = sdfg.add_state(name + "_pre")
    state = sdfg.add_state(name)
    post_state = sdfg.add_state(name + "_post")
    sdfg.add_edge(pre_state, state, dace.InterstateEdge())
    sdfg.add_edge(state, post_state, dace.InterstateEdge())

    # Host arrays, plus transient device copies placed on banks 0 and 1
    _, desc_input_host = sdfg.add_array("a", (n, m / veclen), vtype)
    _, desc_output_host = sdfg.add_array("b", (n, m / veclen), vtype)
    desc_input_device = copy.copy(desc_input_host)
    desc_input_device.storage = dace.StorageType.FPGA_Global
    desc_input_device.location["bank"] = 0
    desc_input_device.transient = True
    desc_output_device = copy.copy(desc_output_host)
    desc_output_device.storage = dace.StorageType.FPGA_Global
    desc_output_device.location["bank"] = 1
    desc_output_device.transient = True
    sdfg.add_datadesc("a_device", desc_input_device)
    sdfg.add_datadesc("b_device", desc_output_device)

    # Host to device
    pre_read = pre_state.add_read("a")
    pre_write = pre_state.add_write("a_device")
    pre_state.add_memlet_path(
        pre_read, pre_write, memlet=dace.Memlet(f"a_device[0:N, 0:M/{veclen}]"))

    # Device to host
    post_read = post_state.add_read("b_device")
    post_write = post_state.add_write("b")
    post_state.add_memlet_path(
        post_read,
        post_write,
        memlet=dace.Memlet(f"b_device[0:N, 0:M/{veclen}]"))

    # Compute state
    read_memory = state.add_read("a_device")
    write_memory = state.add_write("b_device")

    # Memory streams
    sdfg.add_stream("a_stream",
                    vtype,
                    storage=dace.StorageType.FPGA_Local,
                    transient=True)
    sdfg.add_stream("b_stream",
                    vtype,
                    storage=dace.StorageType.FPGA_Local,
                    transient=True)
    produce_input_stream = state.add_write("a_stream")
    consume_input_stream = state.add_read("a_stream")
    produce_output_stream = state.add_write("b_stream")
    # This node is only read from (it feeds b_device), mirroring consume_input_stream
    consume_output_stream = state.add_read("b_stream")

    tasklet = state.add_tasklet(
        name, {"_north", "_west", "_east", "_south"}, {"result"}, """\
north = _north if i >= 1 else 1
west = _west if {W}*j + u >= 1 else 1
east = _east if {W}*j + u < M - 1 else 1
south = _south if i < N - 1 else 1

result = 0.25 * (north + west + east + south)""".format(W=veclen))

    # The exit node is named pipeline_exit to avoid shadowing the builtin exit()
    entry, pipeline_exit = state.add_pipeline(
        name, {
            "i": "0:N",
            "j": "0:M/{}".format(veclen),
        },
        schedule=dace.ScheduleType.FPGA_Device,
        init_size=m / veclen,
        init_overlap=False,
        drain_size=m / veclen,
        drain_overlap=True)

    # Unrolled map
    unroll_entry, unroll_exit = state.add_map(
        name + "_unroll", {"u": "0:{}".format(veclen)},
        schedule=dace.ScheduleType.FPGA_Device,
        unroll=True)

    # Container-to-container copies between arrays and streams
    state.add_memlet_path(read_memory,
                          produce_input_stream,
                          memlet=dace.Memlet(
                              f"{read_memory.data}[0:N, 0:M/{veclen}]",
                              other_subset="0"))
    # Same form as the read side: one expression string carrying data and range
    state.add_memlet_path(consume_output_stream,
                          write_memory,
                          memlet=dace.Memlet(
                              f"{write_memory.data}[0:N, 0:M/{veclen}]",
                              other_subset="0"))

    # Container-to-container copy from vectorized stream to non-vectorized
    # buffer
    sdfg.add_array("input_buffer", (1, ),
                   vtype,
                   storage=dace.StorageType.FPGA_Local,
                   transient=True)
    sdfg.add_array("shift_register", (2 * m + veclen, ),
                   dtype,
                   storage=dace.StorageType.FPGA_ShiftRegister,
                   transient=True)
    sdfg.add_array("output_buffer", (veclen, ),
                   dtype,
                   storage=dace.StorageType.FPGA_Local,
                   transient=True)
    sdfg.add_array("output_buffer_packed", (1, ),
                   vtype,
                   storage=dace.StorageType.FPGA_Local,
                   transient=True)
    input_buffer = state.add_access("input_buffer")
    shift_register = state.add_access("shift_register")
    output_buffer = state.add_access("output_buffer")
    output_buffer_packed = state.add_access("output_buffer_packed")

    # Only read if not draining
    read_tasklet = state.add_tasklet(
        name + "_conditional_read", {"_in"}, {"_out"},
        "if not {}:\n\t_out = _in".format(entry.pipeline.drain_condition()))

    # Input stream to buffer
    state.add_memlet_path(consume_input_stream,
                          entry,
                          read_tasklet,
                          dst_conn="_in",
                          memlet=dace.Memlet(f"{consume_input_stream.data}[0]",
                                             dynamic=True))
    state.add_memlet_path(read_tasklet,
                          input_buffer,
                          src_conn="_out",
                          memlet=dace.Memlet(f"{input_buffer.data}[0]"))
    # The new vector is written to the last veclen slots of the shift register
    state.add_memlet_path(input_buffer,
                          shift_register,
                          memlet=dace.Memlet(f"{input_buffer.data}[0]",
                                             other_subset=f"2*M:(2*M + {veclen})"))

    # Stencils accesses
    state.add_memlet_path(
        shift_register,
        unroll_entry,
        tasklet,
        dst_conn="_north",
        memlet=dace.Memlet(f"{shift_register.data}[u]"))  # North
    state.add_memlet_path(
        shift_register,
        unroll_entry,
        tasklet,
        dst_conn="_west",
        memlet=dace.Memlet(f"{shift_register.data}[u + M - 1]"))  # West
    state.add_memlet_path(
        shift_register,
        unroll_entry,
        tasklet,
        dst_conn="_east",
        memlet=dace.Memlet(f"{shift_register.data}[u + M + 1]"))  # East
    state.add_memlet_path(
        shift_register,
        unroll_entry,
        tasklet,
        dst_conn="_south",
        memlet=dace.Memlet(f"{shift_register.data}[u + 2 * M]"))  # South

    # Tasklet to buffer
    state.add_memlet_path(tasklet,
                          unroll_exit,
                          output_buffer,
                          src_conn="result",
                          memlet=dace.Memlet(f"{output_buffer.data}[u]"))

    # Pack buffer
    state.add_memlet_path(output_buffer,
                          output_buffer_packed,
                          memlet=dace.Memlet(f"{output_buffer_packed.data}[0]",
                                             other_subset=f"0:{veclen}"))

    # Only write if not initializing
    write_tasklet = state.add_tasklet(
        name + "_conditional_write", {"_in"}, {"_out"},
        "if not {}:\n\t_out = _in".format(entry.pipeline.init_condition()))

    # Packed buffer to conditional write tasklet
    state.add_memlet_path(output_buffer_packed,
                          write_tasklet,
                          dst_conn="_in",
                          memlet=dace.Memlet(f"{output_buffer_packed.data}[0]"))

    # Conditional write tasklet to output stream
    state.add_memlet_path(write_tasklet,
                          pipeline_exit,
                          produce_output_stream,
                          src_conn="_out",
                          memlet=dace.Memlet(f"{produce_output_stream.data}[0]",
                                             dynamic=True))

    return sdfg
Ejemplo n.º 18
0
        m_axis_c_tdata <= m_axis_c_tdata;
        state <= DONE;
end    

assign m_axis_c_tvalid = (m_axis_c_tdata >= 100) ? 1'b1:1'b0;   
""",
                             language=dace.Language.SystemVerilog)

# add input/output array
# (array B gets separate write and read access nodes: tasklet0 writes it, tasklet1 reads it)
A = state.add_read('A')
B_w = state.add_write('B')
B_r = state.add_read('B')
C = state.add_write('C')

# connect input/output array with the tasklet
# (tasklet0 and tasklet1 are created earlier in the file)
# A[0] -> tasklet0 -> B[0], then B[0] -> tasklet1 -> C[0]
state.add_edge(A, None, tasklet0, 'a', dace.Memlet('A[0]'))
state.add_edge(tasklet0, 'b', B_w, None, dace.Memlet('B[0]'))
state.add_edge(B_r, None, tasklet1, 'b', dace.Memlet('B[0]'))
state.add_edge(tasklet1, 'c', C, None, dace.Memlet('C[0]'))

# validate sdfg
sdfg.validate()

######################################################################

if __name__ == '__main__':

    # init data structures
    # a: one random int32 in [0, 80); b and c: single zeroed int32 values
    a = np.random.randint(0, 80, 1).astype(np.int32)
    b = np.array([0]).astype(np.int32)
    c = np.array([0]).astype(np.int32)
Ejemplo n.º 19
0
def make_fpga_sdfg():
    '''
    Assemble the "nested_sdfg_kernels" SDFG.

    The graph has three states connected in sequence:

    1. "copy_to_device": copies the host arrays x, y and v into the transient
       FPGA_Global arrays device_x, device_y and device_v.
    2. "I_do_not_want_to_be_fpga_kernel": holds two nested vector-addition
       SDFGs (built by ``make_vec_add_sdfg``). The first consumes device_x and
       device_y and produces device_z; the second consumes device_z and
       device_v and produces device_u. The state itself is tagged with
       ``is_FPGA_kernel = False``.
    3. "copy_to_host": copies device_z and device_u back into z and u.

    :return: the validated SDFG.
    '''

    n = dace.symbol("n")
    vec_width = 4
    vec_type = dace.vector(dace.float32, vec_width)
    # Every array holds n / vec_width vectors
    vec_count = n / vec_width
    # Subset suffix shared by all memlets, covering the whole array
    full_range = f"[0:{n}/{vec_width}]"

    sdfg = dace.SDFG("nested_sdfg_kernels")

    def add_host_array(array_name):
        # Non-transient array living on the host side
        sdfg.add_array(array_name, shape=[vec_count], dtype=vec_type)

    def add_device_array(array_name):
        # Transient array placed in FPGA global memory
        sdfg.add_array(array_name,
                       shape=[vec_count],
                       dtype=vec_type,
                       storage=dace.dtypes.StorageType.FPGA_Global,
                       transient=True)

    ###########################################################################
    # Copy data to FPGA

    copy_in_state = sdfg.add_state("copy_to_device")
    host_inputs = ("x", "y", "v")

    for host_name in host_inputs:
        add_host_array(host_name)
    host_reads = [copy_in_state.add_read(host_name) for host_name in host_inputs]

    for host_name in host_inputs:
        add_device_array("device_" + host_name)
    device_writes = [copy_in_state.add_write("device_" + host_name) for host_name in host_inputs]

    for host_node, device_node in zip(host_reads, device_writes):
        copy_in_state.add_memlet_path(host_node, device_node, memlet=dace.Memlet(host_node.data + full_range))

    ###########################################################################
    # Copy data from FPGA

    host_outputs = ("z", "u")
    for host_name in host_outputs:
        add_host_array(host_name)

    copy_out_state = sdfg.add_state("copy_to_host")

    for host_name in host_outputs:
        add_device_array("device_" + host_name)

    copy_back_pairs = []
    for host_name in host_outputs:
        device_node = copy_out_state.add_read("device_" + host_name)
        host_node = copy_out_state.add_write(host_name)
        copy_back_pairs.append((device_node, host_node))

    for device_node, host_node in copy_back_pairs:
        copy_out_state.add_memlet_path(device_node, host_node, memlet=dace.Memlet(host_node.data + full_range))

    ###########################################################################
    # State that must not become an FPGA kernel

    non_fpga_state = sdfg.add_state("I_do_not_want_to_be_fpga_kernel")
    non_fpga_state.location["is_FPGA_kernel"] = False

    dev_x = non_fpga_state.add_read("device_x")
    dev_y = non_fpga_state.add_read("device_y")
    dev_v = non_fpga_state.add_read("device_v")
    dev_z = non_fpga_state.add_write("device_z")
    dev_u = non_fpga_state.add_write("device_u")

    def nest_vec_add(first_operand, second_operand, result):
        # Build a fresh vector-addition SDFG, nest it with the symbol mapping
        # size -> n, and connect its two inputs and single output.
        nested = non_fpga_state.add_nested_sdfg(make_vec_add_sdfg(), sdfg, {"_device_x", "_device_y"},
                                                {"_device_z"}, {"size": "n"})
        for connector, operand in (("_device_x", first_operand), ("_device_y", second_operand)):
            non_fpga_state.add_memlet_path(operand,
                                           nested,
                                           dst_conn=connector,
                                           memlet=dace.Memlet(operand.data + full_range))
        non_fpga_state.add_memlet_path(nested,
                                       result,
                                       src_conn="_device_z",
                                       memlet=dace.Memlet(result.data + full_range))

    # First addition: device_z = device_x + device_y
    nest_vec_add(dev_x, dev_y, dev_z)
    # Second addition, fed by the first one's result: device_u = device_z + device_v
    nest_vec_add(dev_z, dev_v, dev_u)

    ######################################
    # Interstate edges
    sdfg.add_edge(copy_in_state, non_fpga_state, dace.sdfg.sdfg.InterstateEdge())
    sdfg.add_edge(non_fpga_state, copy_out_state, dace.sdfg.sdfg.InterstateEdge())
    sdfg.fill_scope_connectors()
    sdfg.validate()

    return sdfg
Ejemplo n.º 20
0
# Pass-through tasklet that forwards its input connector to its output
write_b = state.add_tasklet('write_b', {'inp'}, {'out'}, 'out = inp')

# Add the read and write maps, both iterating i over 0:N with an
# FPGA_Device schedule
read_a_entry, read_a_exit = state.add_map(
    'read_a_map', dict(i='0:N'), schedule=dace.ScheduleType.FPGA_Device)
write_b_entry, write_b_exit = state.add_map(
    'write_b_map', dict(i='0:N'), schedule=dace.ScheduleType.FPGA_Device)

# Add read_a memlets and access nodes: inside the read map, the read_a
# tasklet (defined outside this snippet) takes fpga_A[i] and writes A_stream[0]
read_a_inp = state.add_read('fpga_A')
read_a_out = state.add_write('A_stream')
state.add_memlet_path(read_a_inp,
                      read_a_entry,
                      read_a,
                      dst_conn='inp',
                      memlet=dace.Memlet('fpga_A[i]'))
state.add_memlet_path(read_a,
                      read_a_exit,
                      read_a_out,
                      src_conn='out',
                      memlet=dace.Memlet('A_stream[0]'))

# Add tasklet memlets: A_stream[0] feeds connector 'a' of rtl_tasklet
# (defined outside this snippet); B_stream is the access node it writes to
A = state.add_read('A_stream')
B = state.add_write('B_stream')
state.add_memlet_path(A,
                      rtl_tasklet,
                      dst_conn='a',
                      memlet=dace.Memlet('A_stream[0]'))
state.add_memlet_path(rtl_tasklet,
                      B,
Ejemplo n.º 21
0
    def expansion(node: "Gearbox", parent_state: dace.SDFGState,
                  parent_sdfg: dace.SDFG):
        """
        Build the SDFG that implements a Gearbox node.

        The generated SDFG contains one state with a map over
        ``_<name>_i`` in ``0:node.size`` and ``_<name>_w`` in
        ``0:gear_factor``. Two implementations are generated depending on the
        data types returned by ``node.validate``:

        * Elementwise (input and output share the same base type): a nested
          SDFG with a read state and a write state moves individual elements
          through a buffer of ``large_veclen`` elements.
        * Otherwise (one side is a vector of vectors): a single tasklet
          inserts into / extracts from a one-element buffer holding the wide
          vector.

        :param node: the Gearbox node being expanded.
        :param parent_state: the state containing the node.
        :param parent_sdfg: the SDFG containing the node.
        :return: the newly built SDFG.
        """

        # validate() supplies the connected edges, their data descriptors,
        # the direction (pack vs. unpack) and the width ratio
        (in_edge, in_desc, out_edge, out_desc, is_pack,
         gear_factor) = node.validate(parent_sdfg, parent_state)

        is_elementwise = in_desc.dtype.base_type == out_desc.dtype.base_type

        # The inner SDFG exposes the input/output containers under the
        # connector names, as non-transient copies of the outer descriptors
        sdfg = dace.SDFG(node.name)
        in_desc_inner = copy.deepcopy(in_desc)
        in_desc_inner.transient = False
        sdfg.add_datadesc(in_edge.dst_conn, in_desc_inner)
        out_desc_inner = copy.deepcopy(out_desc)
        out_desc_inner.transient = False
        sdfg.add_datadesc(out_edge.src_conn, out_desc_inner)

        state = sdfg.add_state(node.name)
        input_read = state.add_read(in_edge.dst_conn)
        output_write = state.add_write(out_edge.src_conn)
        # Name of the map parameter that indexes the narrow vector within the
        # wide one
        vec_it = f"_{node.name}_w"
        entry, exit = state.add_map(node.name, {
            f"_{node.name}_i": f"0:{node.size}",
            vec_it: f"0:{gear_factor}"
        },
                                    schedule=node.schedule)
        buffer_name = f"{node.name}_buffer"

        if is_elementwise:

            dtype = in_desc.dtype.base_type

            # The wide side is the output when packing, the input when
            # unpacking
            if is_pack:
                small_veclen = in_desc.dtype.veclen
                large_veclen = out_desc.dtype.veclen
            else:
                large_veclen = in_desc.dtype.veclen
                small_veclen = out_desc.dtype.veclen

            # Buffer of scalar elements holding one wide vector
            sdfg.add_array(buffer_name, (large_veclen, ),
                           dtype,
                           storage=out_desc.storage,
                           transient=True)

            # Nested SDFG with separate read and write states; its
            # input/output containers are prefixed with an underscore
            nested_sdfg = dace.SDFG(f"{node.name}_nested")
            nested_in_desc = copy.deepcopy(in_desc)
            nested_in_desc.transient = False
            nested_out_desc = copy.deepcopy(out_desc)
            nested_out_desc.transient = False
            nested_sdfg.add_datadesc(f"_{in_edge.dst_conn}", nested_in_desc)
            nested_sdfg.add_datadesc(f"_{out_edge.src_conn}", nested_out_desc)
            read_state = nested_sdfg.add_state(f"{node.name}_read")
            write_state = nested_sdfg.add_state(f"{node.name}_write")

            read_nested = read_state.add_read(f"_{in_edge.dst_conn}")
            write_nested = write_state.add_write(f"_{out_edge.src_conn}")

            # Non-transient view of the outer buffer inside the nested SDFG:
            # written in the read state and read in the write state
            buffer_name_inner = f"{buffer_name}_inner"
            nested_sdfg.add_array(buffer_name_inner, (large_veclen, ),
                                  dtype,
                                  storage=out_desc.storage,
                                  transient=False)
            buffer_read = write_state.add_read(buffer_name_inner)
            buffer_write = read_state.add_write(buffer_name_inner)

            # Map parameter iterating over the elements of the narrow vector
            elem_it = f"_{node.name}_e"

            if is_pack:

                # Unpack the input vector into individual elements
                unpack_name = f"{node.name}_unpack"
                nested_sdfg.add_array(unpack_name, (small_veclen, ),
                                      dtype,
                                      storage=in_desc.storage,
                                      transient=True)
                unpack_access = read_state.add_write(unpack_name)
                read_state.add_memlet_path(
                    read_nested,
                    unpack_access,
                    memlet=dace.Memlet(f"{read_nested.data}[0]"))

                # Now write the elements into the large buffer at the
                # appropriate indices
                unroll_entry, unroll_exit = read_state.add_map(
                    f"{node.name}_elementwise", {elem_it: f"0:{small_veclen}"},
                    schedule=node.schedule,
                    unroll=True)
                unroll_tasklet = read_state.add_tasklet(
                    f"{node.name}_elementwise", {"unpack_in"}, {"buffer_out"},
                    "buffer_out = unpack_in")
                read_state.add_memlet_path(
                    unpack_access,
                    unroll_entry,
                    unroll_tasklet,
                    dst_conn="unpack_in",
                    memlet=dace.Memlet(f"{unpack_name}[{elem_it}]"))
                read_state.add_memlet_path(
                    unroll_tasklet,
                    unroll_exit,
                    buffer_write,
                    src_conn="buffer_out",
                    memlet=dace.Memlet(f"{buffer_name_inner}[{vec_it} * "
                                       f"{small_veclen} + {elem_it}]"))

                # Only progress to the write state if we're on the last vector
                nested_sdfg.add_edge(
                    read_state, write_state,
                    dace.InterstateEdge(f"{vec_it} >= {gear_factor} - 1"))
                # Otherwise skip writing and go straight to the end state
                end_state = nested_sdfg.add_state(f"{node.name}_end")
                nested_sdfg.add_edge(
                    read_state, end_state,
                    dace.InterstateEdge(f"{vec_it} < {gear_factor} - 1"))
                nested_sdfg.add_edge(write_state, end_state,
                                     dace.InterstateEdge())

                # Write out
                write_state.add_memlet_path(
                    buffer_read,
                    write_nested,
                    memlet=dace.Memlet(f"{write_nested.data}[0]"))

            else:  # Is unpack

                # Only read a new wide vector on the first iteration
                start_state = nested_sdfg.add_state(f"{node.name}_start")
                nested_sdfg.add_edge(start_state, read_state,
                                     dace.InterstateEdge(f"{vec_it} == 0"))
                nested_sdfg.add_edge(start_state, write_state,
                                     dace.InterstateEdge(f"{vec_it} != 0"))
                nested_sdfg.add_edge(read_state, write_state,
                                     dace.InterstateEdge())

                # Read new wide vector
                read_state.add_memlet_path(
                    read_nested,
                    buffer_write,
                    memlet=dace.Memlet(f"{read_nested.data}[0]"))

                # Read out the appropriate elements and write them out
                pack_name = f"{node.name}_pack"
                nested_sdfg.add_array(pack_name, (small_veclen, ),
                                      dtype,
                                      storage=in_desc.storage,
                                      transient=True)
                pack_access = write_state.add_write(pack_name)
                unroll_entry, unroll_exit = write_state.add_map(
                    f"{node.name}_elementwise", {elem_it: f"0:{small_veclen}"},
                    schedule=node.schedule,
                    unroll=True)
                unroll_tasklet = write_state.add_tasklet(
                    f"{node.name}_elementwise", {"buffer_in"}, {"pack_out"},
                    "pack_out = buffer_in")
                write_state.add_memlet_path(
                    buffer_read,
                    unroll_entry,
                    unroll_tasklet,
                    dst_conn="buffer_in",
                    memlet=dace.Memlet(f"{buffer_name_inner}[{vec_it} * "
                                       f"{small_veclen} + {elem_it}]"))
                write_state.add_memlet_path(
                    unroll_tasklet,
                    unroll_exit,
                    pack_access,
                    src_conn="pack_out",
                    memlet=dace.Memlet(f"{pack_name}[{elem_it}]"))
                write_state.add_memlet_path(
                    pack_access,
                    write_nested,
                    memlet=dace.Memlet(f"{write_nested.data}[0]"))

            # Place the nested SDFG inside the map. The buffer appears as both
            # an input and an output connector of the nested SDFG.
            nested_sdfg_node = state.add_nested_sdfg(
                nested_sdfg, sdfg,
                {f"_{in_edge.dst_conn}", f"{buffer_name}_inner"},
                {f"_{out_edge.src_conn}", f"{buffer_name}_inner"})
            buffer_read = state.add_read(buffer_name)
            buffer_write = state.add_write(buffer_name)
            # Input and output accesses are marked dynamic: the nested SDFG
            # only reaches its read/write state on some iterations
            state.add_memlet_path(input_read,
                                  entry,
                                  nested_sdfg_node,
                                  dst_conn=f"_{in_edge.dst_conn}",
                                  memlet=dace.Memlet(f"{input_read.data}[0]",
                                                     dynamic=True))
            state.add_memlet_path(nested_sdfg_node,
                                  exit,
                                  output_write,
                                  src_conn=f"_{out_edge.src_conn}",
                                  memlet=dace.Memlet(f"{output_write.data}[0]",
                                                     dynamic=True))
            # The whole buffer is passed in and out of the nested SDFG
            state.add_memlet_path(
                buffer_read,
                entry,
                nested_sdfg_node,
                dst_conn=buffer_name_inner,
                memlet=dace.Memlet(f"{buffer_name}[0:{large_veclen}]"))
            state.add_memlet_path(
                nested_sdfg_node,
                exit,
                buffer_write,
                src_conn=buffer_name_inner,
                memlet=dace.Memlet(f"{buffer_name}[0:{large_veclen}]"))

        else:  # Not elementwise, one side is a vector of vectors

            # The buffer holds a single element of the wide type
            vtype = out_desc.dtype if is_pack else in_desc.dtype

            sdfg.add_array(buffer_name, (1, ),
                           vtype,
                           storage=in_desc.storage,
                           transient=True)
            buffer_read = state.add_read(buffer_name)
            buffer_write = state.add_write(buffer_name)

            # Tasklet code depends on the direction:
            #  - pack: store val_in at position _<name>_w of the wide value and
            #    emit the wide value only on the last position;
            #  - unpack: take a new wide value from val_in when _<name>_w is 0
            #    (otherwise reuse the buffer) and emit the selected position.
            tasklet = state.add_tasklet(
                node.name, {"val_in", "buffer_in"}, {"val_out", "buffer_out"},
                f"""\
wide = buffer_in
wide[_{node.name}_w] = val_in
if _{node.name}_w == {gear_factor} - 1:
    val_out = wide
buffer_out = wide""" if is_pack else f"""\
wide = val_in if _{node.name}_w == 0 else buffer_in
val_out = wide[_{node.name}_w]
buffer_out = wide""")
            # The input memlet is dynamic when unpacking (val_in is only used
            # on the first position)
            state.add_memlet_path(input_read,
                                  entry,
                                  tasklet,
                                  dst_conn="val_in",
                                  memlet=dace.Memlet(f"{in_edge.dst_conn}[0]",
                                                     dynamic=not is_pack))
            state.add_memlet_path(buffer_read,
                                  entry,
                                  tasklet,
                                  dst_conn="buffer_in",
                                  memlet=dace.Memlet(f"{buffer_name}[0]"))
            # The output memlet is dynamic when packing (val_out is only
            # assigned on the last position)
            state.add_memlet_path(tasklet,
                                  exit,
                                  output_write,
                                  src_conn="val_out",
                                  memlet=dace.Memlet(f"{out_edge.src_conn}[0]",
                                                     dynamic=is_pack))
            state.add_memlet_path(tasklet,
                                  exit,
                                  buffer_write,
                                  src_conn="buffer_out",
                                  memlet=dace.Memlet(f"{buffer_name}[0]"))

        return sdfg
def make_sdfg(squeeze, name):
    """
    Build an SDFG used to exercise memlet propagation through a nested SDFG.

    The outer SDFG has a map over ``j`` in ``1:M`` that contains a nested SDFG
    writing to array ``A`` of shape ``[N + 1, M]``. The nested SDFG loops over
    ``i`` (from 0 while ``i < N - 2``) and runs two tasklets per iteration. The
    outer memlets are deliberately overapproximated, and
    ``propagation.propagate_memlets_sdfg`` is run before returning.

    :param squeeze: if true, the nested SDFG exposes two one-dimensional
                    arrays ``a1`` and ``a2`` (with stride ``M``); otherwise it
                    exposes a single two-dimensional array ``a``.
    :param name: suffix appended to the SDFG name.
    :return: the SDFG after memlet propagation.
    """
    N, M = dace.symbol('N'), dace.symbol('M')
    sdfg = dace.SDFG('memlet_propagation_%s' % name)
    sdfg.add_symbol('N', dace.int64)
    sdfg.add_symbol('M', dace.int64)
    sdfg.add_array('A', [N + 1, M], dace.int64)
    state = sdfg.add_state()
    map_entry, map_exit = state.add_map('map', dict(j='1:M'))
    outer_write = state.add_write('A')

    # Create nested SDFG; everything that depends on `squeeze` is decided here
    nested = dace.SDFG('nested')
    if squeeze:
        nested.add_array('a1', [N + 1], dace.int64, strides=[M])
        nested.add_array('a2', [N - 1], dace.int64, strides=[M])
        first_array, second_array = 'a1', 'a2'
        first_subset, second_subset = 'a1[i]', 'a2[i]'
        output_connectors = {'a1', 'a2'}
        # Outer memlets (overapproximated):
        #  - a1 is expected to propagate to A[0:N - 2, j].
        #  - a2 is expected to propagate to A[2:N, j - 1].
        outer_memlets = (('a1', 'A[0:N+1, j]'), ('a2', 'A[2:N+1, j-1]'))
    else:
        nested.add_array('a', [N + 1, M], dace.int64)
        first_array = second_array = 'a'
        first_subset, second_subset = 'a[i, 1]', 'a[i+2, 0]'
        output_connectors = {'a'}
        # Outer memlet (overapproximated), expected to propagate to
        # A[0:N, j - 1:j + 1].
        outer_memlets = (('a', 'A[0:N+1, j-1:j+1]'), )

    # Loop body: two tasklets, each writing one element per iteration
    body = nested.add_state()
    first_access = body.add_write(first_array)
    second_access = body.add_write(second_array)
    add99 = body.add_tasklet('add99', {}, {'out'}, 'out = i + 99')
    add101 = body.add_tasklet('add101', {}, {'out'}, 'out = i + 101')
    body.add_edge(add99, 'out', first_access, None, dace.Memlet(first_subset))
    body.add_edge(add101, 'out', second_access, None,
                  dace.Memlet(second_subset))
    nested.add_loop(None, body, None, 'i', '0', 'i < N - 2', 'i + 1')

    # Connect nested SDFG to toplevel one
    nested_node = state.add_nested_sdfg(nested,
                                        None, {},
                                        output_connectors,
                                        symbol_mapping=dict(j='j',
                                                            N='N',
                                                            M='M'))
    state.add_nedge(map_entry, nested_node, dace.Memlet())

    # Add the outer memlets chosen above
    for connector, subset in outer_memlets:
        state.add_memlet_path(nested_node,
                              map_exit,
                              outer_write,
                              src_conn=connector,
                              memlet=dace.Memlet(subset))

    propagation.propagate_memlets_sdfg(sdfg)

    return sdfg
Ejemplo n.º 23
0
def make_sdfg():
    """
    Build the "fpga_conflict_resolution" SDFG.

    The SDFG copies ``host_memory`` to ``global_memory`` (FPGA_Global), runs a
    compute state, and copies ``global_memory`` back. The compute state:

    1. initializes ``local_memory[0]`` (FPGA_Local) to zero,
    2. runs a map over ``i`` in ``0:N`` that writes ``global_memory[i]`` into
       ``local_memory[0]`` with a sum write-conflict resolution, and
    3. runs a second map over ``i`` in ``0:N`` that writes
       ``local_memory[0]`` into ``global_memory[i]``, again with a sum
       write-conflict resolution.

    :return: the constructed SDFG (not validated here).
    """

    N = dace.symbol("N")
    # Write-conflict resolution used by both maps
    sum_wcr = "lambda a, b: a + b"
    fpga_schedule = dace.ScheduleType.FPGA_Device

    sdfg = dace.SDFG("fpga_conflict_resolution")

    sdfg.add_array("host_memory", [N], dace.int32)
    sdfg.add_array("global_memory", [N],
                   dace.int32,
                   storage=dace.StorageType.FPGA_Global,
                   transient=True)
    sdfg.add_array("local_memory", [1],
                   dace.int32,
                   storage=dace.StorageType.FPGA_Local,
                   transient=True)

    compute = sdfg.add_state("fpga_conflict_resolution")

    # Copy memory to FPGA
    pre_state = sdfg.add_state("pre_state")
    host_src = pre_state.add_read("host_memory")
    device_dst = pre_state.add_write("global_memory")
    pre_state.add_memlet_path(host_src,
                              device_dst,
                              memlet=dace.Memlet("global_memory[0:N]"))
    sdfg.add_edge(pre_state, compute, dace.InterstateEdge())

    # Copy memory back
    post_state = sdfg.add_state("post_state")
    device_src = post_state.add_read("global_memory")
    host_dst = post_state.add_write("host_memory")
    post_state.add_memlet_path(device_src,
                               host_dst,
                               memlet=dace.Memlet("global_memory[0:N]"))
    sdfg.add_edge(compute, post_state, dace.InterstateEdge())

    # Access nodes of the compute state
    global_in = compute.add_read("global_memory")
    global_out = compute.add_write("global_memory")
    local_zeroed = compute.add_access("local_memory")
    local_summed = compute.add_access("local_memory")

    # Initialize local memory
    zero_tasklet = compute.add_tasklet("init", {}, {"out"}, "out = 0")
    compute.add_memlet_path(zero_tasklet,
                            local_zeroed,
                            src_conn="out",
                            memlet=dace.Memlet("local_memory[0]"))

    # Accumulate on local memory
    local_entry, local_exit = compute.add_map("wcr_local", {"i": "0:N"},
                                              schedule=fpga_schedule)
    local_tasklet = compute.add_tasklet("wcr_local", {"gmem"}, {"lmem"},
                                        "lmem = gmem")
    compute.add_memlet_path(global_in,
                            local_entry,
                            local_tasklet,
                            dst_conn="gmem",
                            memlet=dace.Memlet("global_memory[i]"))
    # Empty memlet from the initialized buffer into the map entry
    compute.add_memlet_path(local_zeroed, local_entry, memlet=dace.Memlet())
    compute.add_memlet_path(local_tasklet,
                            local_exit,
                            local_summed,
                            src_conn="lmem",
                            memlet=dace.Memlet("local_memory[0]",
                                               wcr=sum_wcr))

    # Write with conflict into global memory
    global_entry, global_exit = compute.add_map("wcr_global", {"i": "0:N"},
                                                schedule=fpga_schedule)
    global_tasklet = compute.add_tasklet("wcr_global", {"lmem"}, {"gmem"},
                                         "gmem = lmem")
    compute.add_memlet_path(local_summed,
                            global_entry,
                            global_tasklet,
                            dst_conn="lmem",
                            memlet=dace.Memlet("local_memory[0]"))
    compute.add_memlet_path(global_tasklet,
                            global_exit,
                            global_out,
                            src_conn="gmem",
                            memlet=dace.Memlet("global_memory[i]",
                                               wcr=sum_wcr))

    return sdfg
Ejemplo n.º 24
0
def make_backward_function(model: ONNXModel,
                           apply_strict=False
                           ) -> Type[torch.autograd.Function]:
    """ Convert an ONNXModel to a PyTorch differentiable function. This method should not be used on its own.
        Instead use the ``backward=True`` parameter of :class:`daceml.pytorch.DaceModule`.

        The forward SDFG of ``model`` is modified in place: arrays required by the backward pass are made
        non-transient, and required scalars are copied to size-1 arrays in an extra state.

        :param model: the model to convert.
        :param apply_strict: whether to apply strict transformations before creating the backward pass.
        :return: the PyTorch compatible :class:`torch.autograd.Function`.
        :raises AutoDiffException: if the model's SDFG does not consist of exactly one state, or if an array
                                   required by the backward pass is missing from the forward SDFG.
    """

    if len(model.sdfg.nodes()) != 1:
        raise AutoDiffException(
            "Expected to find exactly one SDFGState, found {}".format(
                len(model.sdfg.nodes())))

    forward_sdfg = model.sdfg
    forward_state = model.sdfg.nodes()[0]

    backward_sdfg = dace.SDFG(forward_sdfg.name + "_backward")
    backward_state = backward_sdfg.add_state()

    gen = BackwardPassGenerator(
        sdfg=forward_sdfg,
        state=forward_state,
        given_gradients=[clean_onnx_name(name) for name in model.outputs],
        required_gradients=[clean_onnx_name(name) for name in model.inputs],
        backward_sdfg=backward_sdfg,
        backward_state=backward_state,
        apply_strict=apply_strict)

    backward_result, backward_grad_arrays, backward_input_arrays = gen.backward(
    )

    # maps the name of each forwarded scalar to the name of the size-1 array
    # that carries its value out of the forward SDFG
    replaced_scalars = {}
    for name, desc in backward_input_arrays.items():
        if name not in forward_sdfg.arrays:
            raise AutoDiffException(
                "Expected to find array with name '{}' in SDFG".format(name))

        forward_desc = forward_sdfg.arrays[name]
        # we will save this output and pass it to the backward pass

        # Views should not be forwarded. Instead the backward pass generator should forward the source of the view,
        # and rebuild the sequence of required views in the backward pass.
        assert type(forward_desc) is not dt.View
        if isinstance(forward_desc, dt.Scalar):
            # we can't return scalars from SDFGs, so we add a copy to an array of size 1
            fwd_arr_name, _ = forward_sdfg.add_array(
                name + "_array", [1],
                forward_desc.dtype,
                transient=False,
                storage=forward_desc.storage,
                find_new_name=True)
            bwd_arr_name, _ = backward_sdfg.add_array(
                name + "_array", [1],
                forward_desc.dtype,
                transient=False,
                storage=forward_desc.storage,
                find_new_name=True)
            # the scalar itself becomes internal to the backward SDFG; it is
            # filled from the size-1 array in the copy-in state below
            backward_sdfg.arrays[name].transient = True

            fwd_copy_state = forward_sdfg.add_state_after(forward_state,
                                                          label="copy_out_" +
                                                          fwd_arr_name)
            bwd_copy_state = backward_sdfg.add_state_before(backward_state,
                                                            label="copy_in_" +
                                                            bwd_arr_name)
            fwd_copy_state.add_edge(fwd_copy_state.add_read(name), None,
                                    fwd_copy_state.add_write(fwd_arr_name),
                                    None, dace.Memlet(name + "[0]"))

            bwd_copy_state.add_edge(bwd_copy_state.add_read(bwd_arr_name),
                                    None, bwd_copy_state.add_write(name), None,
                                    dace.Memlet(name + "[0]"))
            replaced_scalars[name] = fwd_arr_name
        else:
            # expose the array so that it can be read after the forward pass
            forward_sdfg.arrays[name].transient = False

    backward_sdfg.validate()

    class DaceFunction(torch.autograd.Function):
        _backward_sdfg = backward_sdfg
        _forward_model = model
        _backward_result = backward_result

        @staticmethod
        def forward(ctx, *inputs):
            # setup the intermediate buffers

            if any(not inp.is_contiguous() for inp in inputs):
                log.warning("forced to copy input since it was not contiguous")

            # is_contiguous is a method and must be called: testing the bound
            # method itself is always truthy and would skip the copy
            copied_inputs = tuple(
                inp if inp.is_contiguous() else inp.contiguous()
                for inp in inputs)

            # prepare the arguments
            inputs, params, symbols, outputs = model._call_args(
                args=copied_inputs, kwargs={})

            # create the empty tensors we need for the intermediate values
            for inp, val in backward_input_arrays.items():
                if isinstance(val, dt.Scalar):
                    # the value we need is actually in an array
                    inp = replaced_scalars[inp]

                if inp not in inputs and inp not in outputs and inp not in params:
                    inputs[inp] = create_output_array(symbols,
                                                      forward_sdfg.arrays[inp],
                                                      use_torch=True)

            DaceFunction._forward_model.sdfg(**inputs, **symbols, **params,
                                             **outputs)

            def _get_arr(name, desc):
                # look up the tensor holding `name` (or its size-1 array
                # replacement for scalars) among the forward call arguments
                if isinstance(desc, dt.Scalar):
                    name = replaced_scalars[name]
                if name in inputs:
                    value = inputs[name]
                elif name in outputs:
                    value = outputs[name]
                elif name in params:
                    value = params[name]
                else:
                    raise AutoDiffException(
                        f"Could not get value of array {name}")

                return value

            # save the arrays we need for the backward pass
            backward_inputs = {
                name: _get_arr(name, desc)
                for name, desc in backward_input_arrays.items()
            }
            # re-key forwarded scalars under the name of their size-1 array
            for name in replaced_scalars:
                backward_inputs[replaced_scalars[name]] = backward_inputs[name]
                del backward_inputs[name]
            ctx.dace_backward_inputs = backward_inputs
            ctx.dace_symbols = symbols

            if len(outputs) == 1:
                return next(iter(outputs.values()))

            return tuple(outputs.values())

        @staticmethod
        def backward(ctx, *grads):
            backward_inputs = ctx.dace_backward_inputs

            if len(grads) != len(model.outputs):
                raise ValueError("Expected to receive {} grads, got {}".format(
                    len(model.outputs), len(grads)))

            # pair each incoming gradient with the name of the corresponding
            # gradient array in the backward SDFG
            given_grads = dict(
                zip((DaceFunction._backward_result.given_grad_names[
                    clean_onnx_name(outp)] for outp in model.outputs), grads))
            for name, value in given_grads.items():
                if not isinstance(value, torch.Tensor):
                    raise ValueError(
                        "Unsupported input with type {};"
                        " currently only tensor inputs are supported".format(
                            type(value)))
                if not value.is_contiguous():
                    log.warning(
                        "forced to copy input since it was not contiguous")
                    given_grads[name] = value.contiguous()

            # these are the grads we will calculate
            input_grad_names = [
                DaceFunction._backward_result.required_grad_names[
                    clean_onnx_name(inp)] for inp in model.inputs
            ]

            # init the grads we will calculate with zeros
            grad_values = OrderedDict()
            for name in input_grad_names:
                grad_values[name] = create_output_array(
                    ctx.dace_symbols,
                    backward_grad_arrays[name],
                    use_torch=True,
                    zeros=True)

            DaceFunction._backward_sdfg(**grad_values, **backward_inputs,
                                        **given_grads)

            return tuple(grad_values.values())

    return DaceFunction
Ejemplo n.º 25
0
def make_sdfg(with_wcr, map_in_guard, reverse_loop, use_variable, assign_after,
              log_path):
    """Build an SDFG with an explicit init/guard/body/after/post state loop.

    The loop iterates the symbol ``i`` over ``0..N-1`` and, in every
    iteration, computes ``C[i] = 1/A[i]`` and
    ``D[i] = sqrt(A[i]**2 + B[i]**2)`` while a C++ tasklet appends ``i`` to
    the file at ``log_path``.

    :param with_wcr: If truthy, the writes to ``C`` and ``D`` use a
                     sum write-conflict resolution.
    :param map_in_guard: If truthy, a mapped tasklet copying ``C`` onto
                         itself is placed in the guard state.
    :param reverse_loop: If truthy, the loop counts down from ``N - 1``
                         instead of up from ``0``.
    :param use_variable: If truthy, the post state writes ``i`` to ``E``;
                         otherwise it writes ``N``.
    :param assign_after: If truthy, ``i`` is set to ``N`` on the edge from
                         the ``after`` state to the ``post`` state.
    :param log_path: Path interpolated into the C++ logging tasklet.
    :return: The constructed SDFG.
    """
    sdfg_name = (f"loop_to_map_test_{with_wcr}_{map_in_guard}_"
                 f"{reverse_loop}_{use_variable}_{assign_after}")
    sdfg = dace.SDFG(sdfg_name)
    sdfg.set_global_code("#include <fstream>\n#include <mutex>")

    init, guard, body, after, post = (
        sdfg.add_state(label)
        for label in ("init", "guard", "body", "after", "post"))

    N = dace.symbol("N", dace.int32)

    # Both loop directions share the same edge layout; only the start value,
    # the conditions and the increment differ.
    if reverse_loop:
        start, stay_cond, leave_cond, step = ("N - 1", "i >= 0", "i < 0",
                                              "i - 1")
    else:
        start, stay_cond, leave_cond, step = "0", "i < N", "i >= N", "i + 1"

    sdfg.add_edge(init, guard, dace.InterstateEdge(assignments={"i": start}))
    sdfg.add_edge(guard, body, dace.InterstateEdge(condition=stay_cond))
    sdfg.add_edge(guard, after, dace.InterstateEdge(condition=leave_cond))
    sdfg.add_edge(body, guard, dace.InterstateEdge(assignments={"i": step}))

    if assign_after:
        final_assignments = {"i": "N"}
    else:
        final_assignments = None
    sdfg.add_edge(after, post,
                  dace.InterstateEdge(assignments=final_assignments))

    for array_name in ("A", "B", "C", "D"):
        sdfg.add_array(array_name, [N], dace.float64)
    sdfg.add_array("E", [1], dace.uint16)

    a = body.add_read("A")
    b = body.add_read("B")
    c = body.add_write("C")
    d = body.add_write("D")

    if map_in_guard:
        guard_read = guard.add_read("C")
        guard_write = guard.add_write("C")
        guard.add_mapped_tasklet("write_self", {"i": "0:N"},
                                 {"c_in": dace.Memlet("C[i]")},
                                 "c_out = c_in",
                                 {"c_out": dace.Memlet("C[i]")},
                                 external_edges=True,
                                 input_nodes={"C": guard_read},
                                 output_nodes={"C": guard_write})

    tasklet0 = body.add_tasklet("tasklet0", {"a"}, {"c"}, "c = 1/a")
    tasklet1 = body.add_tasklet("tasklet1", {"a", "b"}, {"d"},
                                "d = sqrt(a**2 + b**2)")

    # Side-effect-only tasklet that logs the current iteration to a file
    logging_code = f"""\
static std::mutex mutex;
std::unique_lock<std::mutex> lock(mutex);
std::ofstream of("{log_path}", std::ofstream::app);
of << i << "\\n";"""
    body.add_tasklet("tasklet2", {}, {},
                     logging_code,
                     language=dace.Language.CPP)

    # Write-conflict resolution shared by both output memlets
    if with_wcr:
        wcr = "lambda a, b: a + b"
    else:
        wcr = None

    body.add_memlet_path(a, tasklet0, dst_conn="a", memlet=dace.Memlet("A[i]"))
    c_memlet = dace.Memlet("C[i]", wcr=wcr)
    body.add_memlet_path(tasklet0, c, src_conn="c", memlet=c_memlet)

    body.add_memlet_path(a, tasklet1, dst_conn="a", memlet=dace.Memlet("A[i]"))
    body.add_memlet_path(b, tasklet1, dst_conn="b", memlet=dace.Memlet("B[i]"))
    d_memlet = dace.Memlet("D[i]", wcr=wcr)
    body.add_memlet_path(tasklet1, d, src_conn="d", memlet=d_memlet)

    e = post.add_write("E")
    if use_variable:
        post_code = "e = i"
    else:
        post_code = "e = N"
    post_tasklet = post.add_tasklet("post", {}, {"e"}, post_code)
    post.add_memlet_path(post_tasklet,
                         e,
                         src_conn="e",
                         memlet=dace.Memlet("E[0]"))

    return sdfg
Ejemplo n.º 26
0
    def expansion(node, parent_state, parent_sdfg):
        """Expand ``node`` into a standalone SDFG built around an FPGA pipeline.

        The returned SDFG (named ``<label>_outer``) holds a single state with
        a ``Pipeline`` scope scheduled as ``FPGA_Device``. Inside the scope
        sits a nested SDFG (``<label>_inner``) with three chained states:

        * a shift state that moves the contents of every input buffer towards
          index 0,
        * an update state that copies the newly read value into the tail of
          each buffer,
        * a compute state that runs the boundary-condition code, the node's
          own code and the output assignments on the buffered values.

        Every output leaves the pipeline through a conditional-write tasklet
        that is guarded by the initialization/out-of-bounds condition when
        such a condition exists.

        :param node: The node being expanded; its ``label``, ``code`` and
                     ``boundary_conditions`` are read here.
        :param parent_state: State containing ``node`` (forwarded to
                             ``parse_connectors``).
        :param parent_sdfg: SDFG containing ``node``; supplies the data
                            descriptors and symbols.
        :return: The newly built outer SDFG.
        :raises NotImplementedError: If an output field is accessed at a
                                     non-zero offset.
        """

        sdfg = dace.SDFG(node.label + "_outer")
        state = sdfg.add_state(node.label + "_outer")

        (inputs, outputs, shape, field_to_data, field_to_desc, field_to_edge,
         vector_lengths) = parse_connectors(node, parent_state, parent_sdfg)

        #######################################################################
        # Parse the tasklet code
        #######################################################################

        # Replace relative indices with memlet names
        converter = SubscriptConverter()

        # Add copy boundary conditions
        for field in node.boundary_conditions:
            if node.boundary_conditions[field]["btype"] == "copy":
                center_index = tuple(0 for _ in range(
                    len(parent_sdfg.arrays[field_to_data[field]].shape)))
                # This will register the renaming
                converter.convert(field, center_index)

        # Replace accesses in the code
        code, field_accesses = parse_accesses(node.code.as_string, outputs)

        iterator_mapping = make_iterator_mapping(node, field_accesses, shape)
        vector_length = validate_vector_lengths(vector_lengths,
                                                iterator_mapping)
        # Only the innermost (last) dimension is divided by the vector length.
        # NOTE(review): "/" is true division; presumably the last dimension is
        # symbolic or evenly divisible by vector_length -- confirm.
        shape_vectorized = tuple(s / vector_length if i == len(shape) -
                                 1 else s for i, s in enumerate(shape))

        # Extract which fields to read from streams and what to buffer
        buffer_sizes = collections.OrderedDict()
        buffer_accesses = collections.OrderedDict()
        scalars = {}  # {name: type}
        for field_name in inputs:
            relative = field_accesses[field_name]
            dim_mask = iterator_mapping[field_name]
            if not any(dim_mask):
                # This is a scalar, no buffer needed. Instead, the SDFG must
                # take this as a symbol
                scalars[field_name] = parent_sdfg.symbols[field_name]
                sdfg.add_symbol(field_name, parent_sdfg.symbols[field_name])
                continue
            # Absolute buffer offsets of every relative access; a "copy"
            # boundary condition additionally requires the center (offset 0).
            abs_indices = ([
                dim_to_abs_val(i, tuple(s for s, m in zip(shape, dim_mask)
                                        if m), parent_sdfg) for i in relative
            ] + ([0] if field_name in node.boundary_conditions
                 and node.boundary_conditions[field_name]["btype"] == "copy"
                 else []))
            max_access = max(abs_indices)
            min_access = min(abs_indices)
            buffer_size = max_access - min_access + vector_lengths[field_name]
            buffer_sizes[field_name] = buffer_size
            # (indices relative to center, buffer indices, center index)
            buffer_accesses[field_name] = ([tuple(r) for r in relative], [
                i - min_access for i in abs_indices
            ], -min_access)

        # Create a initialization phase corresponding to the highest distance
        # to the center
        init_sizes = [
            (buffer_sizes[key] - vector_lengths[key] - val[2]) // vector_length
            for key, val in buffer_accesses.items()
        ]
        init_size_max = int(np.max(init_sizes))

        parameters = [f"_i{i}" for i in range(len(shape))]

        # Dimensions we need to iterate over
        iterator_mask = np.array([s != 0 and s != 1 for s in shape],
                                 dtype=bool)
        iterators = make_iterators(
            tuple(s for s, m in zip(shape_vectorized, iterator_mask) if m),
            parameters=tuple(s for s, m in zip(parameters, iterator_mask)
                             if m))

        # Manually add pipeline entry and exit nodes
        pipeline_range = dace.properties.SubsetProperty.from_string(', '.join(
            iterators.values()))
        pipeline = dace.sdfg.nodes.Pipeline(
            "compute_" + node.label,
            list(iterators.keys()),
            pipeline_range,
            dace.dtypes.ScheduleType.FPGA_Device,
            False,
            init_size=init_size_max,
            init_overlap=False,
            drain_size=init_size_max,
            drain_overlap=True)
        entry = dace.sdfg.nodes.PipelineEntry(pipeline)
        # NOTE(review): the local name "exit" shadows the builtin for the rest
        # of this function.
        exit = dace.sdfg.nodes.PipelineExit(pipeline)
        state.add_nodes_from([entry, exit])

        # Add nested SDFG to do 1) shift buffers 2) read from input 3) compute
        nested_sdfg = dace.SDFG(node.label + "_inner", parent=state)
        nested_sdfg_tasklet = state.add_nested_sdfg(
            nested_sdfg,
            sdfg,
            # Input connectors
            [k + "_in" for k in inputs if any(iterator_mapping[k])] +
            [name + "_buffer_in" for name, _ in buffer_sizes.items()],
            # Output connectors
            [k + "_out" for k in outputs] +
            [name + "_buffer_out" for name, _ in buffer_sizes.items()],
            schedule=dace.ScheduleType.FPGA_Device)
        # Propagate symbols
        for sym_name, sym_type in parent_sdfg.symbols.items():
            nested_sdfg.add_symbol(sym_name, sym_type)
            nested_sdfg_tasklet.symbol_mapping[sym_name] = sym_name
        # Map iterators
        for p in parameters:
            nested_sdfg.add_symbol(p, dace.int64)
            nested_sdfg_tasklet.symbol_mapping[p] = p

        # Shift state, which shifts all buffers by one
        shift_state = nested_sdfg.add_state(node.label + "_shift")

        # Update state, which reads new values from memory
        update_state = nested_sdfg.add_state(node.label + "_update")

        #######################################################################
        # Implement boundary conditions
        #######################################################################

        boundary_code, oob_cond = generate_boundary_conditions(
            node, shape, field_accesses, field_to_desc, iterator_mapping)

        #######################################################################
        # Only write if we're in bounds
        #######################################################################

        # Assign the center access of every output to its "<name>_inner_out"
        # connector at the end of the compute tasklet.
        write_code = ("\n".join([
            "{}_inner_out = {}\n".format(
                output,
                field_accesses[output][tuple(0 for _ in range(len(shape)))])
            for output in outputs
        ]))
        # Build the "if ...:" prefix that later guards the outer
        # conditional-write tasklets (empty string when no guard is needed).
        if init_size_max > 0 or len(oob_cond) > 0:
            write_cond = []
            if init_size_max > 0:
                init_cond = pipeline.init_condition()
                write_cond.append("not " + init_cond)
                nested_sdfg_tasklet.symbol_mapping[init_cond] = init_cond
                nested_sdfg.add_symbol(init_cond, dace.bool)
            if len(oob_cond) > 0:
                oob_cond = " or ".join(sorted(oob_cond))
                oob_cond = f"not ({oob_cond})"
                write_cond.append(oob_cond)
            write_cond = " and ".join(write_cond)
            write_cond = f"if {write_cond}:\n\t"
        else:
            write_cond = ""

        code = boundary_code + "\n" + code + "\n" + write_code

        #######################################################################
        # Create DaCe compute state
        #######################################################################

        # Compute state, which reads from input channels, performs the compute,
        # and writes to the output channel(s)
        compute_state = nested_sdfg.add_state(node.label + "_compute")
        compute_inputs = list(
            itertools.chain.from_iterable(
                [["_" + v for v in field_accesses[f].values()] for f in inputs
                 if any(iterator_mapping[f])]))
        compute_tasklet = compute_state.add_tasklet(
            node.label + "_compute",
            compute_inputs, {name + "_inner_out"
                             for name in outputs},
            code,
            language=dace.dtypes.Language.Python)
        # The unrolled map (and thus compute_unroll_entry/exit) only exists
        # when vectorizing; later uses are guarded by the same condition.
        if vector_length > 1:
            compute_unroll_entry, compute_unroll_exit = compute_state.add_map(
                compute_state.label + "_unroll",
                {"i_unroll": f"0:{vector_length}"},
                schedule=dace.ScheduleType.FPGA_Device,
                unroll=True)

        # Connect the three nested states
        nested_sdfg.add_edge(shift_state, update_state,
                             dace.sdfg.InterstateEdge())
        nested_sdfg.add_edge(update_state, compute_state,
                             dace.sdfg.InterstateEdge())

        # First, grab scalar variables
        for scalar, scalar_type in scalars.items():
            nested_sdfg.add_symbol(scalar, scalar_type)

        # Code to increment custom iterators
        iterator_code = ""

        # buffer_sizes and init_sizes were built in the same iteration order,
        # so zipping them pairs each buffered field with its init size.
        for (field_name, size), init_size in zip(buffer_sizes.items(),
                                                 init_sizes):

            data_name = field_to_data[field_name]
            connector = field_to_edge[field_name].dst_conn
            data_name_outer = connector
            data_name_inner = field_name + "_in"
            desc_outer = parent_sdfg.arrays[data_name].clone()
            desc_outer.transient = False
            sdfg.add_datadesc(data_name_outer, desc_outer)

            mapping = iterator_mapping[field_name]
            is_array = not isinstance(desc_outer, dt.Stream)

            # If this array is part of the initialization phase, it needs its
            # own iterator, which we need to instantiate and increment in the
            # outer SDFG
            if is_array:
                if init_size == 0:
                    field_index = [s for s, p in zip(parameters, mapping) if p]
                else:
                    # Create custom iterators for this array
                    # NOTE(review): range(num_dims) is bounded by the number
                    # of mapped dimensions but is then filtered by mapping[i];
                    # if a leading dimension is unmapped, trailing mapped
                    # dimensions appear to be skipped -- confirm whether
                    # range(len(mapping)) was intended.
                    num_dims = sum(mapping, 0)
                    field_iterators = [(f"_{field_name}_i{i}", shape[i])
                                       for i in range(num_dims) if mapping[i]]
                    start_index = init_size_max - init_size
                    tab = ""
                    if start_index > 0:
                        iterator_code += (
                            f"if {pipeline.iterator_str()} >= {start_index}:\n"
                        )
                        tab += "  "
                    # Emit a nested increment-with-wraparound chain, starting
                    # from the last iterator (the enumerate index is unused).
                    for i, (it, s) in enumerate(reversed(field_iterators)):
                        iterator_code += f"""\
{tab}if {it} < {s} - 1:
{tab}  {it} = {it} + 1
{tab}else:
{tab}  {it} = 0\n"""
                        tab += "  "
                    field_index = [fi[0] for fi in field_iterators]
                    for fi in field_index:
                        pipeline.additional_iterators[fi] = "0"
                        nested_sdfg.add_symbol(fi, dace.int64)
                        nested_sdfg_tasklet.symbol_mapping[fi] = fi
                field_index = ", ".join(field_index)
            else:
                field_index = "0"

            # Begin reading according to this field's own buffer size, which is
            # translated to an index by subtracting it from the maximum buffer
            # size
            begin_reading = init_size_max - init_size
            total_size = functools.reduce(operator.mul, shape_vectorized, 1)
            end_reading = total_size + init_size_max - init_size

            # Outer memory read
            read_node_outer = state.add_read(data_name_outer)
            # When the read window does not cover the whole pipeline range,
            # route the read through a guarded tasklet and a wavefront scalar.
            if begin_reading != 0 or end_reading != total_size + init_size_max:
                sdfg.add_scalar(f"{field_name}_wavefront",
                                desc_outer.dtype,
                                storage=dace.StorageType.FPGA_Local,
                                transient=True)
                wavefront_access = state.add_access(f"{field_name}_wavefront")
                condition = []
                it = pipeline.iterator_str()
                if begin_reading != 0:
                    condition.append(f"{it} >= {begin_reading}")
                if end_reading != total_size + init_size_max:
                    condition.append(f"{it} < {end_reading}")
                condition = " and ".join(condition)
                update_tasklet = state.add_tasklet(
                    f"read_{field_name}", {"wavefront_in"}, {"wavefront_out"},
                    f"if {condition}:\n"
                    "\twavefront_out = wavefront_in\n",
                    language=dace.dtypes.Language.Python)
                state.add_memlet_path(read_node_outer,
                                      entry,
                                      update_tasklet,
                                      dst_conn="wavefront_in",
                                      memlet=dace.Memlet(
                                          f"{data_name_outer}[{field_index}]",
                                          dynamic=True))
                state.add_memlet_path(update_tasklet,
                                      wavefront_access,
                                      src_conn="wavefront_out",
                                      memlet=dace.Memlet(
                                          f"{field_name}_wavefront",
                                          dynamic=True))
                state.add_memlet_path(
                    wavefront_access,
                    nested_sdfg_tasklet,
                    dst_conn=f"{field_name}_in",
                    memlet=dace.Memlet(f"{field_name}_wavefront"))
            else:
                state.add_memlet_path(
                    read_node_outer,
                    entry,
                    nested_sdfg_tasklet,
                    dst_conn=f"{field_name}_in",
                    memlet=dace.Memlet(f"{data_name_outer}[{field_index}]"))

            # Create inner memory access
            nested_sdfg.add_scalar(data_name_inner,
                                   desc_outer.dtype,
                                   storage=dace.StorageType.FPGA_Local,
                                   transient=False)

            buffer_name_outer = f"{node.label}_{field_name}_buffer"
            buffer_name_inner_read = f"{field_name}_buffer_in"
            buffer_name_inner_write = f"{field_name}_buffer_out"

            # Create buffer transient in outer SDFG
            # (desc_outer is rebound to the buffer descriptor from here on)
            field_dtype = parent_sdfg.data(data_name).dtype
            _, desc_outer = sdfg.add_array(
                buffer_name_outer, (size, ),
                field_dtype.base_type,
                storage=dace.dtypes.StorageType.FPGA_Local,
                transient=True)

            # Create read and write nodes
            read_node_outer = state.add_read(buffer_name_outer)
            write_node_outer = state.add_write(buffer_name_outer)

            # Outer buffer read
            state.add_memlet_path(
                read_node_outer,
                entry,
                nested_sdfg_tasklet,
                dst_conn=buffer_name_inner_read,
                memlet=dace.Memlet(f"{buffer_name_outer}[0:{size}]"))

            # Outer buffer write
            state.add_memlet_path(nested_sdfg_tasklet,
                                  exit,
                                  write_node_outer,
                                  src_conn=buffer_name_inner_write,
                                  memlet=dace.Memlet(
                                      f"{write_node_outer.data}[0:{size}]",
                                      dynamic=True))

            # Inner copy
            desc_inner_read = desc_outer.clone()
            desc_inner_read.transient = False
            desc_inner_read.name = buffer_name_inner_read
            desc_inner_write = desc_inner_read.clone()
            desc_inner_write.name = buffer_name_inner_write
            nested_sdfg.add_datadesc(buffer_name_inner_read, desc_inner_read)
            nested_sdfg.add_datadesc(buffer_name_inner_write, desc_inner_write)

            # Make shift state if necessary
            if size > 1:
                shift_read = shift_state.add_read(buffer_name_inner_read)
                shift_write = shift_state.add_write(buffer_name_inner_write)
                shift_entry, shift_exit = shift_state.add_map(
                    f"shift_{field_name}",
                    {"i_shift": f"0:{size} - {vector_lengths[field_name]}"},
                    schedule=dace.dtypes.ScheduleType.FPGA_Device,
                    unroll=True)
                shift_tasklet = shift_state.add_tasklet(
                    f"shift_{field_name}", {f"{field_name}_shift_in"},
                    {f"{field_name}_shift_out"},
                    f"{field_name}_shift_out = {field_name}_shift_in")
                shift_state.add_memlet_path(
                    shift_read,
                    shift_entry,
                    shift_tasklet,
                    dst_conn=field_name + "_shift_in",
                    memlet=dace.Memlet(
                        f"{shift_read.data}"
                        f"[i_shift + {vector_lengths[field_name]}]"))
                shift_state.add_memlet_path(
                    shift_tasklet,
                    shift_exit,
                    shift_write,
                    src_conn=field_name + "_shift_out",
                    memlet=dace.Memlet(f"{shift_write.data}[i_shift]"))

            # Make update state
            update_read = update_state.add_read(data_name_inner)
            update_write = update_state.add_write(buffer_name_inner_write)
            # NOTE(review): this uses the global vector_length, whereas the
            # shift above uses vector_lengths[field_name] -- confirm the two
            # are meant to differ for non-vectorized fields.
            subset = f"{size} - {vector_length}:{size}" if size > 1 else "0"
            update_state.add_memlet_path(update_read,
                                         update_write,
                                         memlet=dace.Memlet(
                                             f"{update_read.data}",
                                             other_subset=f"{subset}"))

            # Make compute state
            compute_read = compute_state.add_read(buffer_name_inner_read)
            for relative, offset in zip(buffer_accesses[field_name][0],
                                        buffer_accesses[field_name][1]):
                memlet_name = field_accesses[field_name][tuple(relative)]
                if vector_length > 1:
                    if vector_lengths[field_name] > 1:
                        offset = f"{offset} + i_unroll"
                    else:
                        offset = str(offset)
                    path = [
                        compute_read, compute_unroll_entry, compute_tasklet
                    ]
                else:
                    offset = str(offset)
                    path = [compute_read, compute_tasklet]
                compute_state.add_memlet_path(
                    *path,
                    dst_conn="_" + memlet_name,
                    memlet=dace.Memlet(f"{compute_read.data}[{offset}]"))

        # Tasklet to update iterators
        if iterator_code:
            update_iterator_tasklet = state.add_tasklet(
                f"{node.label}_update_iterators", {}, {}, iterator_code)
            # Empty memlets only order the tasklet after the nested SDFG and
            # keep it inside the pipeline scope.
            state.add_memlet_path(nested_sdfg_tasklet,
                                  update_iterator_tasklet,
                                  memlet=dace.Memlet())
            state.add_memlet_path(update_iterator_tasklet,
                                  exit,
                                  memlet=dace.Memlet())

        for field_name in outputs:

            # Outputs may only be written at the center (all-zero offset)
            for offset in field_accesses[field_name]:
                if offset is not None and list(offset) != [0] * len(offset):
                    raise NotImplementedError("Output offsets not implemented")

            data_name = field_to_data[field_name]

            # Outer write
            data_name_outer = field_name
            data_name_inner = field_name + "_out"
            desc_outer = parent_sdfg.arrays[data_name].clone()
            desc_outer.transient = False
            array_index = ", ".join(map(str, parameters))
            try:
                sdfg.add_datadesc(data_name_outer, desc_outer)
            except NameError:  # Already an input
                pass

            # Create inner access
            nested_sdfg.add_scalar(data_name_inner,
                                   desc_outer.dtype,
                                   storage=dace.StorageType.FPGA_Local,
                                   transient=False)

            # Inner write
            write_node_inner = compute_state.add_write(data_name_inner)

            # Intermediate buffer, mostly relevant for vectorization
            output_buffer_name = field_name + "_output_buffer"
            nested_sdfg.add_array(output_buffer_name, (vector_length, ),
                                  desc_outer.dtype.base_type,
                                  storage=dace.StorageType.FPGA_Registers,
                                  transient=True)
            output_buffer = compute_state.add_access(output_buffer_name)

            # If vectorized, we need to pass through the unrolled scope
            if vector_length > 1:
                compute_state.add_memlet_path(
                    compute_tasklet,
                    compute_unroll_exit,
                    output_buffer,
                    src_conn=field_name + "_inner_out",
                    memlet=dace.Memlet(f"{output_buffer_name}[i_unroll]"))
            else:
                # NOTE(review): the trailing comma after this call turns the
                # statement into a discarded 1-tuple; it looks unintentional.
                compute_state.add_memlet_path(
                    compute_tasklet,
                    output_buffer,
                    src_conn=field_name + "_inner_out",
                    memlet=dace.Memlet(f"{output_buffer_name}[0]")),

            # Final memlet to the output
            # NOTE(review): stray trailing comma after the call (harmless).
            compute_state.add_memlet_path(
                output_buffer,
                write_node_inner,
                memlet=dace.Memlet(f"{write_node_inner.data}")),

            # Conditional write tasklet
            sdfg.add_scalar(f"{field_name}_result",
                            desc_outer.dtype,
                            storage=dace.StorageType.FPGA_Local,
                            transient=True)
            output_access = state.add_access(f"{field_name}_result")
            state.add_memlet_path(nested_sdfg_tasklet,
                                  output_access,
                                  src_conn=data_name_inner,
                                  memlet=dace.Memlet(f"{field_name}_result"))
            # write_cond is either "" or an "if ...:\n\t" prefix, making the
            # assignment below conditional.
            output_tasklet = state.add_tasklet(
                f"{field_name}_conditional_write", {f"_{field_name}_result"},
                {f"_{data_name_inner}"},
                (write_cond + f"_{data_name_inner} = _{field_name}_result"))
            state.add_memlet_path(output_access,
                                  output_tasklet,
                                  dst_conn=f"_{field_name}_result",
                                  memlet=dace.Memlet(f"{field_name}_result"))
            write_node_outer = state.add_write(data_name_outer)
            if isinstance(desc_outer, dt.Stream):
                subset = "0"
            else:
                subset = array_index
            # NOTE(review): stray trailing comma after the call (harmless).
            state.add_memlet_path(output_tasklet,
                                  exit,
                                  write_node_outer,
                                  src_conn=f"_{data_name_inner}",
                                  memlet=dace.Memlet(
                                      f"{write_node_outer.data}[{subset}]",
                                      dynamic=True)),

        return sdfg
Ejemplo n.º 27
0
def _make_sdfg(l: str = 'Python'):
    """Build an SDFG with a map containing five directly connected tasklets.

    Tasklet ``a`` adds its two inputs and fans the result out to three
    outputs; tasklets ``b`` to ``e`` each combine two inputs with a single
    operator. The tasklets are chained through empty memlets and the result
    of ``e`` is written to ``C[i]``.

    :param l: ``'Python'`` selects Python tasklets; any other value selects
              C++ tasklets (statements are then terminated with ``;``).
    :return: The constructed SDFG.
    """
    if l == 'Python':
        language, endl = dtypes.Language.Python, '\n'
    else:
        language, endl = dtypes.Language.CPP, ';\n'

    sdfg = dace.SDFG(f'map_with_{l}_tasklets')
    for array_name, length in (('A', 20), ('B', 10), ('C', 10)):
        sdfg.add_array(array_name, (length, ), dace.float32)

    state = sdfg.add_state(is_start_state=True)
    A = state.add_read('A')
    B = state.add_read('B')
    C = state.add_write('C')
    me, mx = state.add_map('Map', {'i': '0:10'})

    inputs = {'__inp1', '__inp2'}
    outputs = {'__out'}

    # Tasklet "a": one sum, replicated onto three output connectors
    fanout_statements = ('__out1 = __inp1 + __inp2', '__out2 = __out1',
                         '__out3 = __out1', '')
    ta = state.add_tasklet('a', inputs, {'__out1', '__out2', '__out3'},
                           endl.join(fanout_statements), language)

    # Tasklets "b" to "e": a single binary operation each
    tb, tc, td, te = (state.add_tasklet(
        label, inputs, outputs, f'__out = __inp1 {operator_}__inp2{endl}',
        language) for label, operator_ in (('b', '* '), ('c', '+ '),
                                           ('d', '/ '), ('e', '* ')))

    # Array reads entering through the map entry
    array_reads = (
        (A, ta, 'A[i]', '__inp1'),
        (B, ta, 'B[i]', '__inp2'),
        (A, tb, 'A[2*i]', '__inp2'),
        (B, tc, 'B[i]', '__inp2'),
    )
    for access, tasklet, subset, connector in array_reads:
        state.add_memlet_path(access,
                              me,
                              tasklet,
                              memlet=dace.Memlet(subset),
                              dst_conn=connector)

    # Direct tasklet-to-tasklet connections carrying empty memlets
    tasklet_edges = (
        (ta, '__out1', tb, '__inp1'),
        (ta, '__out2', tc, '__inp1'),
        (tb, '__out', td, '__inp2'),
        (tc, '__out', td, '__inp1'),
        (ta, '__out3', te, '__inp1'),
        (td, '__out', te, '__inp2'),
    )
    for src, src_conn, dst, dst_conn in tasklet_edges:
        state.add_edge(src, src_conn, dst, dst_conn, dace.Memlet())

    state.add_memlet_path(te,
                          mx,
                          C,
                          memlet=dace.Memlet('C[i]'),
                          src_conn='__out')

    return sdfg
def _make_sdfg(name, storage=dace.dtypes.StorageType.CPU_Heap, isview=False):
    """Build a looping SDFG that reduces a slice of ``A`` into ``B[i]``.

    The SDFG declares ``A`` with shape ``[N, N, N]`` and ``B`` with shape
    ``[N]`` (both float64) and iterates ``i`` from 0 while ``i < N`` through
    the states ``guard -> body1 -> body2 -> body3 -> guard``:

    * ``tmp1`` has the ``i``-dependent shape ``[N - 4, N - 4, N - i]`` and
      receives ``A[2:N-2, 2:N-2, i:N]``.
    * ``body2`` feeds all of ``tmp1`` into a ``Reduce`` node with the WCR
      ``lambda a, b : a + b`` and identity 0, writing to ``tmp2[0]``.
    * ``body3`` connects ``tmp2`` to ``B`` through the memlet ``B[i]``.

    :param name: name given to the created SDFG.
    :param storage: storage type used for ``tmp1`` and ``tmp2``. For
                    ``GPU_Global`` and ``FPGA_Global`` it also selects the
                    implementation of the reduce node.
    :param isview: if true, ``tmp1`` is registered as a view that uses the
                   strides of ``A`` and the slice of ``A`` is attached to it
                   directly in ``body2``; otherwise ``tmp1`` is a transient
                   filled from ``A`` in ``body1``.
    :return: the constructed SDFG.
    """
    N = dace.symbol('N', dtype=dace.int32, integer=True, positive=True)
    i = dace.symbol('i', dtype=dace.int32, integer=True)

    sdfg = dace.SDFG(name)

    # --- data containers ---------------------------------------------------
    _, A = sdfg.add_array('A', [N, N, N], dtype=dace.float64)
    sdfg.add_array('B', [N], dtype=dace.float64)

    tmp1_shape = [N - 4, N - 4, N - i]
    if isview:
        _, tmp1 = sdfg.add_view('tmp1',
                                tmp1_shape,
                                dtype=dace.float64,
                                storage=storage,
                                strides=A.strides)
    else:
        _, tmp1 = sdfg.add_transient('tmp1',
                                     tmp1_shape,
                                     dtype=dace.float64,
                                     storage=storage)
    sdfg.add_transient('tmp2', [1], dtype=dace.float64, storage=storage)

    # --- states (creation order is kept: begin, guard, body1..3, end) ------
    begin_state = sdfg.add_state("begin", is_start_state=True)
    guard_state, body1_state, body2_state, body3_state, end_state = (
        sdfg.add_state(label)
        for label in ("guard", "body1", "body2", "body3", "end"))

    # --- loop control flow: (source, destination, InterstateEdge kwargs) ---
    transitions = (
        (begin_state, guard_state, dict(assignments=dict(i='0'))),
        (guard_state, body1_state, dict(condition=f'i<{N}')),
        (guard_state, end_state, dict(condition=f'i>={N}')),
        (body1_state, body2_state, {}),
        (body2_state, body3_state, {}),
        (body3_state, guard_state, dict(assignments=dict(i='i+1'))),
    )
    for src_state, dst_state, edge_kwargs in transitions:
        sdfg.add_edge(src_state, dst_state,
                      dace.InterstateEdge(**edge_kwargs))

    # --- A -> tmp1 ----------------------------------------------------------
    a_slice = f'A[2:{N}-2, 2:{N}-2, i:{N}]'
    if isview:
        # The view is attached to A inside body2; body1 stays empty.
        a_node = body2_state.add_read('A')
        tmp1_node = body2_state.add_access('tmp1')
        body2_state.add_nedge(a_node, tmp1_node, dace.Memlet(a_slice))
    else:
        # The transient is filled in body1 and read back in body2.
        a_node = body1_state.add_read('A')
        tmp1_sink = body1_state.add_write('tmp1')
        body1_state.add_nedge(a_node, tmp1_sink, dace.Memlet(a_slice))
        tmp1_node = body2_state.add_read('tmp1')

    # --- tmp1 -> Reduce -> tmp2 ---------------------------------------------
    reduce_node = standard.Reduce(wcr='lambda a, b : a + b', identity=0)
    if storage == dace.dtypes.StorageType.GPU_Global:
        reduce_node.implementation = 'CUDA (device)'
    elif storage == dace.dtypes.StorageType.FPGA_Global:
        reduce_node.implementation = 'FPGAPartialReduction'
    body2_state.add_node(reduce_node)
    tmp2_sink = body2_state.add_write('tmp2')
    body2_state.add_nedge(tmp1_node, reduce_node,
                          dace.Memlet.from_array('tmp1', tmp1))
    body2_state.add_nedge(reduce_node, tmp2_sink, dace.Memlet('tmp2[0]'))

    # --- tmp2 -> B[i] -------------------------------------------------------
    tmp2_source = body3_state.add_read('tmp2')
    b_sink = body3_state.add_write('B')
    body3_state.add_nedge(tmp2_source, b_sink, dace.Memlet('B[i]'))

    return sdfg
Ejemplo n.º 29
0
    def expansion(node: 'Reduce', state: SDFGState, sdfg: SDFG):
        """ Create a map around the reduce node, with in and out transients
            in registers and an if tasklet on the output, then expand the
            node again as a 'CUDA (block)' reduction.

            All changes are made in place on ``state`` / ``sdfg``:

            1. A map named ``inner_reduce_block`` is added with one
               parameter per dimension of the incoming memlet subset whose
               index is in ``node.axes``. The first in-edge and first
               out-edge of the node are redirected to the map entry / exit
               and new edges connect entry -> node -> exit.
            2. ``LocalStorage`` is applied between entry and node and
               between node and exit; both resulting transients get
               ``Register`` storage.
            3. A ``block_reduce_write`` tasklet is inserted after the output
               transient. It copies its input to its output only when every
               map parameter equals 0. (The original author describes this
               as redirecting the output of thread 0 to a shared memory
               transient; the storage of the destination is not set here.)
            4. ``node.axes`` is reset to ``None``, the implementation is set
               to ``'CUDA (block)'`` and the node is expanded once more via
               ``ExpandReduceCUDABlock``.

            :param node: the reduce node to expand; only its first incoming
                         and first outgoing edge are used.
            :param state: the state containing ``node``.
            :param sdfg: the SDFG containing ``state``.
            :return: the result of ``ExpandReduceCUDABlock.expansion`` on
                     the rewritten node.
        """
        ### define some useful vars
        graph = state
        reduce_node = node
        in_edge = graph.in_edges(reduce_node)[0]
        out_edge = graph.out_edges(reduce_node)[0]

        axes = reduce_node.axes
        ### add a map that encloses the reduce node
        # One map parameter 'i<k>' per reduced dimension k; the memlet range
        # end is inclusive, hence the +1 when building the map range.
        (new_entry, new_exit) = graph.add_map(
            name='inner_reduce_block',
            ndrange={
                'i' + str(i): f'{rng[0]}:{rng[1]+1}:{rng[2]}'
                for (i, rng) in enumerate(in_edge.data.subset) if i in axes
            },
            schedule=dtypes.ScheduleType.Default)

        ExpandReduceCUDABlockAll.redirect_edge(graph,
                                               in_edge,
                                               new_dst=new_entry)
        ExpandReduceCUDABlockAll.redirect_edge(graph,
                                               out_edge,
                                               new_src=new_exit)

        # Inside the map, every reduced dimension is narrowed to a single
        # index; note that the first map parameter is used for all of them.
        subset_in = subsets.Range([
            in_edge.data.subset[i] if i not in axes else
            (new_entry.map.params[0], new_entry.map.params[0], 1)
            for i in range(len(in_edge.data.subset))
        ])
        memlet_in = dace.Memlet(data=in_edge.data.data,
                                volume=1,
                                subset=subset_in)
        memlet_out = dcpy(out_edge.data)
        graph.add_edge(u=new_entry,
                       u_connector=None,
                       v=reduce_node,
                       v_connector=None,
                       memlet=memlet_in)
        graph.add_edge(u=reduce_node,
                       u_connector=None,
                       v=new_exit,
                       v_connector=None,
                       memlet=memlet_out)

        ### add in and out local storage
        from dace.transformation.dataflow.local_storage import LocalStorage

        # Both subgraph dictionaries are computed up front, before either
        # transformation is applied (node indices are looked up here).
        in_local_storage_subgraph = {
            LocalStorage.node_a: graph.nodes().index(new_entry),
            LocalStorage.node_b: graph.nodes().index(reduce_node)
        }
        out_local_storage_subgraph = {
            LocalStorage.node_a: graph.nodes().index(reduce_node),
            LocalStorage.node_b: graph.nodes().index(new_exit)
        }

        local_storage = LocalStorage(sdfg.sdfg_id,
                                     sdfg.nodes().index(state),
                                     in_local_storage_subgraph, 0)

        local_storage.array = in_edge.data.data
        local_storage.apply(sdfg)
        in_transient = local_storage._data_node
        sdfg.data(in_transient.data).storage = dtypes.StorageType.Register

        local_storage = LocalStorage(sdfg.sdfg_id,
                                     sdfg.nodes().index(state),
                                     out_local_storage_subgraph, 0)
        local_storage.array = out_edge.data.data
        local_storage.apply(sdfg)
        out_transient = local_storage._data_node
        sdfg.data(out_transient.data).storage = dtypes.StorageType.Register

        # hack: swap edges as local_storage does not work correctly here
        # as subsets and data get assigned wrongly (should be swapped)
        # NOTE: If local_storage ever changes, this will not work any more
        e1 = graph.in_edges(out_transient)[0]
        e2 = graph.out_edges(out_transient)[0]
        e1.data.data = dcpy(e2.data.data)
        e1.data.subset = dcpy(e2.data.subset)

        ### add an if tasklet and diverge
        # Join over the actual map parameters so that the number of ' and '
        # separators always matches the number of conditions, even if
        # len(axes) differs from the number of parameters created above.
        condition = ' and '.join(param + '== 0'
                                 for param in new_entry.map.params)
        code = 'if ' + condition + ':\n' + '\tout=inp'

        tasklet_node = graph.add_tasklet(name='block_reduce_write',
                                         inputs=['inp'],
                                         outputs=['out'],
                                         code=code)

        edge_out_outtrans = graph.out_edges(out_transient)[0]
        edge_out_innerexit = graph.out_edges(new_exit)[0]
        ExpandReduceCUDABlockAll.redirect_edge(graph,
                                               edge_out_outtrans,
                                               new_dst=tasklet_node,
                                               new_dst_conn='inp')
        e = graph.add_edge(u=tasklet_node,
                           u_connector='out',
                           v=new_exit,
                           v_connector=None,
                           memlet=dcpy(edge_out_innerexit.data))
        # set dynamic with volume 0 FORNOW
        e.data.volume = 0
        e.data.dynamic = True

        ### set reduce_node axes to all (needed)
        reduce_node.axes = None

        # fill scope connectors, done.
        sdfg.fill_scope_connectors()

        # finally, change the implementation to cuda (block)
        # itself and expand again.
        reduce_node.implementation = 'CUDA (block)'
        sub_expansion = ExpandReduceCUDABlock(0, 0, {}, 0)
        return sub_expansion.expansion(node=node, state=state, sdfg=sdfg)
Ejemplo n.º 30
0
def create_gemm_sdfg(sdfg_name,
                     alpha,
                     beta,
                     A,
                     B,
                     C,
                     dtype,
                     transA=False,
                     transB=False,
                     vec_width=1,
                     expansion_args=None):
    '''
    Build an SDFG that runs a BLAS ``Gemm`` library node with the
    ``FPGA1DSystolic`` implementation between a host-to-device and a
    device-to-host copy state.

    Sizes are taken from the arguments: ``N, K = A.shape[0], A.shape[1]`` and
    ``M = B.shape[1]``. Container ``A`` is declared with ``A.shape`` and the
    scalar ``dtype``; ``B`` and ``C`` are declared with element type
    ``dace.vector(dtype, vec_width)`` and a last dimension of
    ``M / vec_width``. (The original docstring notes that the input data
    A, B, and C is not vectorized.)

    :param sdfg_name: name of the created SDFG.
    :param alpha: ``alpha`` argument forwarded to the Gemm node.
    :param beta: ``beta`` argument forwarded to the Gemm node.
    :param A: object exposing ``.shape`` with at least two entries.
    :param B: object exposing ``.shape`` with at least two entries.
    :param C: accepted for signature compatibility; not inspected.
    :param dtype: scalar element type of the containers.
    :param transA: ``transA`` argument forwarded to the Gemm node.
    :param transB: ``transB`` argument forwarded to the Gemm node.
    :param vec_width: vector width used for the ``B`` and ``C`` containers.
    :param expansion_args: if not None, the Gemm node is expanded with these
                           keyword arguments after validation.
    :return: the constructed (and validated) SDFG.
    '''
    sdfg = dace.SDFG(sdfg_name)

    ###########################################################################
    # Copy data to FPGA

    copy_in_state = sdfg.add_state("copy_to_device")
    A_shape = A.shape
    B_shape = B.shape
    N = A_shape[0]
    K = A_shape[1]
    M = B_shape[1]
    vec_type = dace.vector(dtype, vec_width)
    # Last dimension of the vectorized B and C containers.
    vec_cols = M / vec_width

    # Create data containers
    sdfg.add_array('A', A_shape, dtype)
    sdfg.add_array("A_device",
                   shape=A_shape,
                   dtype=dtype,
                   storage=dace.dtypes.StorageType.FPGA_Global,
                   transient=True)
    sdfg.add_array("B", [K, vec_cols], dtype=vec_type)
    sdfg.add_array("B_device", [K, vec_cols],
                   dtype=vec_type,
                   transient=True,
                   storage=dace.dtypes.StorageType.FPGA_Global)

    sdfg.add_array("C", [N, vec_cols], dtype=vec_type)
    sdfg.add_array("C_device", [N, vec_cols],
                   dtype=vec_type,
                   transient=True,
                   storage=dace.dtypes.StorageType.FPGA_Global)

    # Copy A
    in_host_A = copy_in_state.add_read("A")
    in_device_A = copy_in_state.add_write("A_device")
    copy_in_state.add_memlet_path(in_host_A,
                                  in_device_A,
                                  memlet=dace.Memlet(f"A[0:{N}, 0:{K}]"))

    # Copy B
    in_host_B = copy_in_state.add_read("B")
    in_device_B = copy_in_state.add_write("B_device")
    copy_in_state.add_memlet_path(
        in_host_B,
        in_device_B,
        memlet=dace.Memlet(f"B[0:{K}, 0:{M}/{vec_width}]"))

    # Copy C
    in_host_C = copy_in_state.add_read("C")
    in_device_C = copy_in_state.add_write("C_device")
    copy_in_state.add_memlet_path(
        in_host_C,
        in_device_C,
        memlet=dace.Memlet(f"C[0:{N}, 0:{M}/{vec_width}]"))

    ###########################################################################
    # Copy data from FPGA
    copy_out_state = sdfg.add_state("copy_from_device")

    out_device = copy_out_state.add_read("C_device")
    out_host = copy_out_state.add_write("C")
    # NOTE(review): this memlet uses '//' while every other memlet in this
    # function uses '/'; left unchanged -- confirm which form is intended.
    copy_out_state.add_memlet_path(
        out_device,
        out_host,
        memlet=dace.Memlet(f"C[0:{N}, 0:{M}//{vec_width}]"))

    ########################################################################
    # FPGA State

    fpga_state = sdfg.add_state("fpga_state")
    in_A = fpga_state.add_read("A_device")
    in_B = fpga_state.add_read("B_device")
    in_C = fpga_state.add_read("C_device")
    # The result node is only ever written to, so create it as a write
    # access (it was previously created with add_read).
    out_C = fpga_state.add_write("C_device")

    gemm_node = blas.Gemm("gemm",
                          transA=transA,
                          transB=transB,
                          alpha=alpha,
                          beta=beta)
    gemm_node.implementation = "FPGA1DSystolic"

    fpga_state.add_memlet_path(in_A,
                               gemm_node,
                               dst_conn="_a",
                               memlet=dace.Memlet(f"A_device[0:{N}, 0:{K}]"))
    fpga_state.add_memlet_path(
        in_B,
        gemm_node,
        dst_conn="_b",
        memlet=dace.Memlet(f"B_device[0:{K}, 0:{M}/{vec_width}]"))
    fpga_state.add_memlet_path(
        in_C,
        gemm_node,
        dst_conn="_cin",
        memlet=dace.Memlet(f"C_device[0:{N}, 0:{M}/{vec_width}]"))
    fpga_state.add_memlet_path(
        gemm_node,
        out_C,
        src_conn="_c",
        memlet=dace.Memlet(f"C_device[0:{N}, 0:{M}/{vec_width}]"))

    ######################################
    # Interstate edges
    # Execution order: copy_to_device -> fpga_state -> copy_from_device.
    sdfg.add_edge(copy_in_state, fpga_state, dace.sdfg.sdfg.InterstateEdge())
    sdfg.add_edge(fpga_state, copy_out_state, dace.sdfg.sdfg.InterstateEdge())
    sdfg.validate()

    if expansion_args is not None:
        gemm_node.expand(sdfg, fpga_state, **expansion_args)

    return sdfg