def test_out():
    """Redundant GPU transient on the copy-out path must be eliminated.

    Graph built here: ``A -> B -> C -> Transpose -> D`` where ``B`` is a
    GPU-global transient and ``C`` is a default-storage transient. After
    ``apply_strict_transformations`` the state must contain exactly three
    nodes, ``B`` must no longer be one of them, and running the SDFG must
    still write the transpose of ``A`` into ``D``.
    """
    sdfg = dace.SDFG("test_redundant_copy_out")
    state = sdfg.add_state()

    # Data descriptors (registration order kept as in the original test).
    sdfg.add_array("A", [3, 3], dace.float32)
    sdfg.add_transient("B", [3, 3], dace.float32, storage=dace.StorageType.GPU_Global)
    sdfg.add_transient("C", [3, 3], dace.float32)
    sdfg.add_array("D", [3, 3], dace.float32)

    # Access nodes and the transpose library node. The transpose node is not
    # added explicitly; it enters the state through the edge that uses it.
    a_node = state.add_access("A")
    b_node = state.add_access("B")
    c_node = state.add_access("C")
    transpose = Transpose("transpose", dtype=dace.float32)
    d_node = state.add_access("D")

    state.add_edge(a_node, None, b_node, None, sdfg.make_array_memlet("A"))
    state.add_edge(b_node, None, c_node, None, sdfg.make_array_memlet("B"))
    state.add_edge(c_node, None, transpose, "_inp", sdfg.make_array_memlet("C"))
    state.add_edge(transpose, "_out", d_node, None, sdfg.make_array_memlet("D"))

    sdfg.apply_strict_transformations()

    remaining = state.nodes()
    assert len(remaining) == 3
    assert b_node not in remaining
    sdfg.validate()

    src = np.arange(9, dtype=np.float32).reshape(3, 3)
    dst = np.zeros_like(src)
    sdfg(A=src, D=dst)
    assert np.array_equal(src, dst.T)
def test_in():
    """Redundant GPU transient on the copy-in path must be eliminated.

    Graph built here: ``A -> Transpose -> B -> C -> D`` where ``B`` is a
    default-storage transient and ``C`` is a GPU-global transient. After
    ``coarsen_dataflow`` followed by repeated ``RedundantArrayCopyingIn``,
    the state must contain exactly three nodes, ``C`` must be gone, and the
    SDFG must still write the transpose of ``A`` into ``D``.
    """
    sdfg = dace.SDFG("test_redundant_copy_in")
    state = sdfg.add_state()

    # Data descriptors (registration order kept as in the original test).
    sdfg.add_array("A", [3, 3], dace.float32)
    sdfg.add_transient("B", [3, 3], dace.float32)
    sdfg.add_transient("C", [3, 3], dace.float32, storage=dace.StorageType.GPU_Global)
    sdfg.add_array("D", [3, 3], dace.float32)

    a_node = state.add_access("A")
    transpose = Transpose("transpose", dtype=dace.float32)
    state.add_node(transpose)
    b_node = state.add_access("B")
    c_node = state.add_access("C")
    d_node = state.add_access("D")

    state.add_edge(a_node, None, transpose, "_inp", sdfg.make_array_memlet("A"))
    state.add_edge(transpose, "_out", b_node, None, sdfg.make_array_memlet("B"))
    state.add_edge(b_node, None, c_node, None, sdfg.make_array_memlet("B"))
    state.add_edge(c_node, None, d_node, None, sdfg.make_array_memlet("C"))

    sdfg.coarsen_dataflow()
    sdfg.apply_transformations_repeated(RedundantArrayCopyingIn)

    remaining = state.nodes()
    assert len(remaining) == 3
    assert c_node not in remaining
    sdfg.validate()

    src = np.arange(9, dtype=np.float32).reshape(3, 3).copy()
    dst = np.zeros_like(src)
    sdfg(A=src, D=dst)
    assert np.allclose(src, dst.T)
def _make_sdfg(node, parent_state, parent_sdfg, implementation):
    """Build the nested SDFG that expands ``node`` into a ``Potrf`` call.

    The input ``_a`` is copied (or, for ``cuSolverDn``, transposed with
    cuBLAS) into a work buffer. The buffer is factorized in place by a
    ``Potrf`` library node, and the result ends up in ``_b``. A final mapped
    tasklet zeroes the triangle of ``_b`` that does not hold the factor.

    :param node: Library node being expanded. ``node.validate`` supplies the
                 input/output descriptors and shapes, and ``node.lower``
                 selects which triangle holds the factor.
    :param parent_state: State containing ``node``.
    :param parent_sdfg: SDFG containing ``parent_state``.
    :param implementation: Implementation name forwarded to ``Potrf``.
                           ``'cuSolverDn'`` additionally wraps the
                           factorization in cuBLAS transposes.
    :return: The constructed ``dace.SDFG``.
    """
    inp_desc, inp_shape, out_desc, out_shape = node.validate(parent_sdfg, parent_state)
    dtype = inp_desc.dtype

    sdfg = dace.SDFG("{l}_sdfg".format(l=node.label))

    ain_arr = sdfg.add_array('_a', inp_shape, dtype=dtype, strides=inp_desc.strides)
    bout_arr = sdfg.add_array('_b', out_shape, dtype=dtype, strides=out_desc.strides)
    # Status output of Potrf; transient, so it is not exposed to the caller.
    info_arr = sdfg.add_array('_info', [1], dtype=dace.int32, transient=True)
    if implementation == 'cuSolverDn':
        # Separate work buffer holding the transposed input. Presumably this
        # bridges row-major storage and cuSOLVER's column-major convention.
        binout_arr = sdfg.add_array('_bt', inp_shape, dtype=dtype, transient=True)
    else:
        # Factorize directly inside the output array.
        binout_arr = bout_arr

    state = sdfg.add_state("{l}_state".format(l=node.label))

    potrf_node = Potrf('potrf', lower=node.lower)
    potrf_node.implementation = implementation

    # A potrf-style factorization only writes the triangle selected by
    # `lower`; the opposite triangle keeps input values and must be cleared.
    # Keep the lower triangle for a lower factor and the upper one otherwise.
    if node.lower:
        zero_code = '_out = (__i < __j) ? 0 : _inp;'
    else:
        zero_code = '_out = (__i > __j) ? 0 : _inp;'
    _, me, mx = state.add_mapped_tasklet('_uzero_',
                                         dict(__i="0:%s" % out_shape[0], __j="0:%s" % out_shape[1]),
                                         dict(_inp=Memlet.simple('_b', '__i, __j')),
                                         zero_code,
                                         dict(_out=Memlet.simple('_b', '__i, __j')),
                                         language=dace.dtypes.Language.CPP,
                                         external_edges=True)

    ain = state.add_read('_a')
    if implementation == 'cuSolverDn':
        # _a -> Transpose(AT) -> _bt -> Potrf -> _bt -> Transpose(BT) -> _b -> zeroing map
        binout1 = state.add_access('_bt')
        binout2 = state.add_access('_bt')
        # Read access node of `_b` that feeds the zeroing map.
        binout3 = state.in_edges(me)[0].src
        transpose_ain = Transpose('AT', dtype=dtype)
        transpose_ain.implementation = 'cuBLAS'
        state.add_edge(ain, None, transpose_ain, '_inp', Memlet.from_array(*ain_arr))
        state.add_edge(transpose_ain, '_out', binout1, None, Memlet.from_array(*binout_arr))
        transpose_out = Transpose('BT', dtype=dtype)
        transpose_out.implementation = 'cuBLAS'
        state.add_edge(binout2, None, transpose_out, '_inp', Memlet.from_array(*binout_arr))
        state.add_edge(transpose_out, '_out', binout3, None, Memlet.from_array(*bout_arr))
    else:
        # _a -> _b -> Potrf -> _b (the zeroing map's input access node)
        binout1 = state.add_access('_b')
        binout2 = state.in_edges(me)[0].src
        state.add_nedge(ain, binout1, Memlet.from_array(*ain_arr))

    info = state.add_write('_info')
    state.add_memlet_path(binout1, potrf_node, dst_conn="_xin", memlet=Memlet.from_array(*binout_arr))
    state.add_memlet_path(potrf_node, info, src_conn="_res", memlet=Memlet.from_array(*info_arr))
    state.add_memlet_path(potrf_node, binout2, src_conn="_xout", memlet=Memlet.from_array(*binout_arr))

    return sdfg