示例#1
0
def test_warp_softmax(vector_length=1):
    """Transform ``softmax_fwd`` step by step, then validate and run it.

    The transformation calls below are applied in exactly the listed order.
    Afterwards the SDFG must validate, contain a single node (state), and
    have exactly one ``MapEntry`` among the children stored under the
    ``None`` key of ``scope_children()`` (presumably the top-level scope --
    confirm against the DaCe docs). Finally the SDFG is executed on random
    float32 input and compared against the reference ``softmax``.

    :param vector_length: Forwarded to ``Vectorization`` as ``vector_len``;
                          the vectorization step is skipped when it is 1.
    """
    sdfg = softmax_fwd.to_sdfg(strict=True)

    # --- Transformation pipeline (order matters, do not reorder) ---
    sdfg.apply_transformations_repeated(ReduceExpansion)
    MultiExpansion.apply_to(sdfg, sdfg.node(0).nodes())
    SubgraphFusion.apply_to(sdfg, sdfg.node(0).nodes())
    sdfg.expand_library_nodes()
    sdfg.apply_strict_transformations()
    sdfg.apply_transformations_repeated([TrivialMapElimination, MapFusion])
    sdfg.apply_transformations(GPUTransformSDFG)
    sdfg.apply_transformations(WarpTiling)
    sdfg.apply_transformations_repeated(
        [HoistState, InlineSDFG, StateFusion], strict=True)
    sdfg.apply_transformations_repeated([TrivialMapElimination, MapFusion])

    vectorize = vector_length != 1
    if vectorize:
        vectorization_options = {
            'vector_len': vector_length,
            'preamble': False,
            'postamble': False,
            'strided_map': False,
        }
        sdfg.apply_transformations_repeated(Vectorization,
                                            vectorization_options)

    sdfg.specialize({'dn1': 2, 'dn2': 16, 'dn3': 128, 'dr': 128})

    # --- Structural checks ---
    sdfg.validate()
    assert sdfg.number_of_nodes() == 1
    only_state = sdfg.node(0)
    map_entry_count = sum(
        1 for child in only_state.scope_children()[None]
        if isinstance(child, dace.nodes.MapEntry))
    assert map_entry_count == 1

    # --- Numerical check against the reference implementation ---
    shape = (2, 16, 128, 128)
    inp = np.random.rand(*shape).astype(np.float32)
    # ``out`` starts with random contents and is passed to the SDFG as the
    # output buffer.
    out = np.random.rand(*shape).astype(np.float32)
    expected = softmax(inp)

    sdfg(inp=inp, out=out)

    assert np.allclose(out, expected, rtol=1e-4, atol=1e-6)
示例#2
0
def test_applyto_subgraph():
    """Apply ``SubgraphFusion`` to all nodes of the first state of ``dbladd``.

    The SDFG is simplified first. The test contains no assertions: it passes
    as long as none of the calls raises.
    """
    graph = dbladd.to_sdfg()
    graph.simplify()
    # The state is looked up only after simplification, then all of its
    # nodes are handed to the transformation as the subgraph.
    SubgraphFusion.apply_to(graph, graph.node(0).nodes())
示例#3
0
def test_applyto_subgraph():
    """Apply ``SubgraphFusion`` to all nodes of the first state of ``dbladd``.

    Strict transformations are applied to the SDFG first. The test contains
    no assertions: it passes as long as none of the calls raises.
    """
    graph = dbladd.to_sdfg()
    graph.apply_strict_transformations()
    # The state is looked up only after the strict transformations ran, then
    # all of its nodes are handed to the transformation as the subgraph.
    SubgraphFusion.apply_to(graph, graph.node(0).nodes())
示例#4
0
def test_applyto_subgraph():
    """Apply ``SubgraphFusion`` to all nodes of the first state of ``dbladd``.

    ``coarsen_dataflow()`` is called on the SDFG first. The test contains no
    assertions: it passes as long as none of the calls raises.
    """
    graph = dbladd.to_sdfg()
    graph.coarsen_dataflow()
    # The state is looked up only after coarsening, then all of its nodes
    # are handed to the transformation as the subgraph.
    SubgraphFusion.apply_to(graph, graph.node(0).nodes())