def test_warp_softmax(vector_length=1):
    """Run the softmax forward SDFG through the warp-tiling pipeline and verify it.

    The SDFG obtained from ``softmax_fwd`` is transformed step by step,
    specialized to fixed symbol values, validated, checked for its structure
    (one SDFG node holding one top-level map entry) and finally executed on
    random data; the result is compared against ``softmax``.

    :param vector_length: Value forwarded as ``vector_len`` to
                          ``Vectorization``. The vectorization step is skipped
                          when this equals 1.
    """
    # Build the SDFG
    sdfg = softmax_fwd.to_sdfg(strict=True)

    # Transformation pipeline, applied in this fixed order
    sdfg.apply_transformations_repeated(ReduceExpansion)
    MultiExpansion.apply_to(sdfg, sdfg.node(0).nodes())
    SubgraphFusion.apply_to(sdfg, sdfg.node(0).nodes())
    sdfg.expand_library_nodes()
    sdfg.apply_strict_transformations()
    sdfg.apply_transformations_repeated([TrivialMapElimination, MapFusion])
    sdfg.apply_transformations(GPUTransformSDFG)
    sdfg.apply_transformations(WarpTiling)
    sdfg.apply_transformations_repeated(
        [HoistState, InlineSDFG, StateFusion],
        strict=True,
    )
    sdfg.apply_transformations_repeated([TrivialMapElimination, MapFusion])

    # Vectorize only when a non-unit vector length was requested
    if vector_length != 1:
        vectorization_options = {
            "vector_len": vector_length,
            "preamble": False,
            "postamble": False,
            "strided_map": False,
        }
        sdfg.apply_transformations_repeated(Vectorization, vectorization_options)

    sdfg.specialize({"dn1": 2, "dn2": 16, "dn3": 128, "dr": 128})

    # Structural checks: valid SDFG, one node, one map entry at the top scope
    sdfg.validate()
    assert sdfg.number_of_nodes() == 1
    top_scope_nodes = sdfg.node(0).scope_children()[None]
    top_level_maps = [
        node for node in top_scope_nodes
        if isinstance(node, dace.nodes.MapEntry)
    ]
    assert len(top_level_maps) == 1

    # Numerical check; the shape uses the same values as the specialization
    shape = (2, 16, 128, 128)
    input_data = np.random.rand(*shape).astype(np.float32)
    result = np.random.rand(*shape).astype(np.float32)
    expected = softmax(input_data)
    sdfg(inp=input_data, out=result)
    assert np.allclose(result, expected, rtol=1e-4, atol=1e-6)
def test_applyto_subgraph():
    """Call ``SubgraphFusion.apply_to`` on all nodes of SDFG node 0 of ``dbladd``.

    The SDFG is preprocessed with ``sdfg.simplify()`` first. The test contains
    no assertions; it passes when no call raises.
    """
    sdfg = dbladd.to_sdfg()
    sdfg.simplify()

    # Hand the complete node list of the first SDFG node over as the subgraph
    subgraph_nodes = sdfg.node(0).nodes()
    SubgraphFusion.apply_to(sdfg, subgraph_nodes)
def test_applyto_subgraph():
    """Call ``SubgraphFusion.apply_to`` on all nodes of SDFG node 0 of ``dbladd``.

    The SDFG is preprocessed with ``sdfg.apply_strict_transformations()``
    first. The test contains no assertions; it passes when no call raises.
    """
    sdfg = dbladd.to_sdfg()
    sdfg.apply_strict_transformations()

    # Hand the complete node list of the first SDFG node over as the subgraph
    subgraph_nodes = sdfg.node(0).nodes()
    SubgraphFusion.apply_to(sdfg, subgraph_nodes)
def test_applyto_subgraph():
    """Call ``SubgraphFusion.apply_to`` on all nodes of SDFG node 0 of ``dbladd``.

    The SDFG is preprocessed with ``sdfg.coarsen_dataflow()`` first. The test
    contains no assertions; it passes when no call raises.
    """
    sdfg = dbladd.to_sdfg()
    sdfg.coarsen_dataflow()

    # Hand the complete node list of the first SDFG node over as the subgraph
    subgraph_nodes = sdfg.node(0).nodes()
    SubgraphFusion.apply_to(sdfg, subgraph_nodes)