# Example #1
# 0
    def test_multidim(self):
        """Fission a 2-D map that wraps a nested SDFG and check the output.

        The outer map iterates over ``i in 0:2`` and ``j in 0:3``. Its body is
        a nested SDFG whose only tasklet writes 0 into the one-element array
        ``a``; that output is routed through the map exit to ``A[i, j]``.
        ``MapFission`` must apply at least once, and executing the SDFG on
        random input must leave every element of ``A`` equal to zero.
        """
        outer = dace.SDFG('mapfission_multidim')
        outer.add_array('A', [2, 3], dace.float64)
        outer_state = outer.add_state()
        map_entry, map_exit = outer_state.add_map('outer',
                                                  dict(i='0:2', j='0:3'))

        # Nested SDFG: one tasklet storing 0 into a[0].
        inner = dace.SDFG('nested')
        inner.add_array('a', [1], dace.float64)
        inner_state = inner.add_state()
        reset = inner_state.add_tasklet('reset', {}, {'out'}, 'out = 0')
        a_write = inner_state.add_write('a')
        inner_state.add_edge(reset, 'out', a_write, None,
                             dace.Memlet.simple('a', '0'))
        nested = outer_state.add_nested_sdfg(inner, None, {}, {'a'})

        # The nested SDFG has no inputs, so an empty memlet attaches it to the
        # map entry; its output connector 'a' feeds A[i, j] via the map exit.
        outer_state.add_edge(map_entry, None, nested, None, dace.EmptyMemlet())
        a_out = outer_state.add_write('A')
        outer_state.add_memlet_path(nested, map_exit, a_out,
                                    src_conn='a',
                                    memlet=dace.Memlet.simple('A', 'i,j'))

        applied = outer.apply_transformations(MapFission)
        self.assertGreater(applied, 0)

        # Execute on random data: every element must be overwritten with 0.
        A = np.random.rand(2, 3)
        outer(A=A)
        self.assertTrue(np.allclose(A, 0.0))
# Example #2
# 0
def test_state_subgraph():
    """Check free-symbol computation for a state and for a map's interior.

    A nested SDFG sits inside a map over ``i in 0:N`` and receives the symbol
    mapping ``l -> L / 2`` and ``i -> i``. The whole state must report
    ``{'L', 'N'}`` as free symbols. The map's interior, with the entry and
    exit nodes excluded, must report ``{'L', 'i'}``.

    NOTE(review): ``L`` is not defined in this block. It is presumably a
    module-level ``dace.symbol`` declared elsewhere in the file.
    """
    sdfg = dace.SDFG('fsymtest2')
    state = sdfg.add_state()

    # Nested SDFG containing a single empty state.
    inner_sdfg = dace.SDFG('nsdfg')
    inner_sdfg.add_state()
    entry, exit_ = state.add_map('map', dict(i='0:N'))
    inner_node = state.add_nested_sdfg(inner_sdfg,
                                       None, {}, {},
                                       symbol_mapping=dict(l=L / 2, i='i'))
    for src, dst in ((entry, inner_node), (inner_node, exit_)):
        state.add_nedge(src, dst, dace.EmptyMemlet())

    # Free symbols of the whole state graph.
    assert state.free_symbols == {'L', 'N'}

    # Only the contents of the map scope (entry and exit left out).
    body = state.scope_subgraph(entry, include_entry=False, include_exit=False)
    assert body.free_symbols == {'L', 'i'}
# Example #3
# 0
    def test_connector_mismatch(self):
        """Validation must reject a map entry with mismatched connectors.

        The map entry receives input connector ``IN_a`` but exposes output
        connector ``OUT_b``. ``sdfg.validate()`` is expected to raise
        ``dace.sdfg.InvalidSDFGError``; the test fails otherwise.
        """
        sdfg = dace.SDFG('a')
        state = sdfg.add_state()
        me, mx = state.add_map('b', dict(i="0:1"))
        A = state.add_array('A', [1], dace.float32)
        T = state.add_tasklet('T', {'a'}, {}, 'printf("%f", a)')

        # Deliberately mismatched connector names: IN_a vs. OUT_b.
        me.add_in_connector('IN_a')
        me.add_out_connector('OUT_b')
        state.add_edge(A, None, me, 'IN_a',
                       dace.Memlet.from_array(A.data, A.desc(sdfg)))
        # Refer to the data by name, consistent with the edge above.
        state.add_edge(me, 'OUT_b', T, 'a', dace.Memlet.simple(A.data, '0'))
        state.add_edge(T, None, mx, None, dace.EmptyMemlet())

        # Keep the try body to validate() alone. Then an error raised while
        # building the graph above cannot make the test pass by accident.
        try:
            sdfg.validate()
        except dace.sdfg.InvalidSDFGError as ex:
            print('Exception caught:', ex)
        else:
            self.fail('Failed to detect invalid SDFG')
# Example #4
# 0
def mapfission_sdfg():
    """Build and validate the 'mapfission' test SDFG.

    A single state holds an outer map over ``i in 0:2`` that reads ``A``
    (4 elements). Inside the map:
      * tasklet ``one`` computes ``b = a[0] + a[1]`` on ``A[2*i:2*i+2]`` and
        passes its result through ``s1``;
      * an inner map over ``j in 0:2`` runs tasklet ``two`` (``b = a * 2``
        on ``A[2*i+j]``), which writes ``s2[j]``;
      * a second inner map over ``j in 0:2`` runs tasklet ``three``
        (``b = a[0] * 3`` on ``A[2*i:2*i+2]``), which writes ``s3out[0]``;
      * tasklet ``scalar`` produces ``5.0`` through ``scal``.
    Tasklet ``four`` computes ``ione + itwo[0] * itwo[1] + ithree + sc`` and
    writes ``B[i]``.

    Returns:
        The SDFG after ``sdfg.validate()`` has been called on it.
    """
    sdfg = dace.SDFG('mapfission')
    for name, size in (('A', 4), ('B', 2)):
        sdfg.add_array(name, [size], dace.float64)
    for name in ('scal', 's1'):
        sdfg.add_scalar(name, dace.float64, transient=True)
    for name, size in (('s2', 2), ('s3out', 1)):
        sdfg.add_transient(name, [size], dace.float64)
    state = sdfg.add_state()

    # Nodes. Creation order is kept fixed on purpose.
    a_in = state.add_read('A')
    outer_entry, outer_exit = state.add_map('outer', dict(i='0:2'))
    pair_sum = state.add_tasklet('one', {'a'}, {'b'}, 'b = a[0] + a[1]')
    inner2_entry, inner2_exit = state.add_map('inner', dict(j='0:2'))
    doubler = state.add_tasklet('two', {'a'}, {'b'}, 'b = a * 2')
    s2_node = state.add_access('s2')
    s3_node = state.add_access('s3out')
    inner3_entry, inner3_exit = state.add_map('inner', dict(j='0:2'))
    tripler = state.add_tasklet('three', {'a'}, {'b'}, 'b = a[0] * 3')
    const5 = state.add_tasklet('scalar', {}, {'out'}, 'out = 5.0')
    combine = state.add_tasklet('four', {'ione', 'itwo', 'ithree', 'sc'},
                                {'out'},
                                'out = ione + itwo[0] * itwo[1] + ithree + sc')
    b_out = state.add_write('B')

    # The constant tasklet has no inputs; tie it to the outer map entry.
    state.add_nedge(outer_entry, const5, dace.EmptyMemlet())

    # Reads of A through the outer map (and inner maps) into each producer.
    state.add_memlet_path(a_in, outer_entry, pair_sum,
                          dst_conn='a',
                          memlet=dace.Memlet.simple('A', '2*i:2*i+2'))
    state.add_memlet_path(a_in, outer_entry, inner2_entry, doubler,
                          dst_conn='a',
                          memlet=dace.Memlet.simple('A', '2*i+j'))
    state.add_memlet_path(doubler, inner2_exit, s2_node,
                          src_conn='b',
                          memlet=dace.Memlet.simple('s2', 'j'))
    state.add_memlet_path(a_in, outer_entry, inner3_entry, tripler,
                          dst_conn='a',
                          memlet=dace.Memlet.simple('A', '2*i:2*i+2'))
    state.add_memlet_path(tripler, inner3_exit, s3_node,
                          src_conn='b',
                          memlet=dace.Memlet.simple('s3out', '0'))

    # All four intermediate values feed the combining tasklet, in this order.
    for src, src_conn, dst_conn, memlet in (
        (pair_sum, 'b', 'ione', dace.Memlet.simple('s1', '0')),
        (s2_node, None, 'itwo', dace.Memlet.simple('s2', '0:2')),
        (s3_node, None, 'ithree', dace.Memlet.simple('s3out', '0')),
        (const5, 'out', 'sc', dace.Memlet.simple('scal', '0')),
    ):
        state.add_edge(src, src_conn, combine, dst_conn, memlet)

    state.add_memlet_path(combine, outer_exit, b_out,
                          src_conn='out',
                          memlet=dace.Memlet.simple('B', 'i'))

    sdfg.validate()
    return sdfg