def make_sdfg(implementation, dtype, storage=dace.StorageType.Default):
    """Build a one-state SDFG that runs a single ``Cholesky`` library node.

    The state reads the ``n x n`` array ``xin`` into the node's ``_a``
    connector and writes the node's ``_b`` connector to the ``n x n`` array
    ``xout``. The node is created with ``lower=True``.

    :param implementation: Value assigned to the node's ``implementation``
                           attribute; also part of the SDFG name.
    :param dtype: Element type of both arrays; also part of the SDFG name.
    :param storage: Accepted but not used by this function.
    :return: The constructed SDFG.
    """
    size = dace.symbol("n", dace.int64)

    graph = dace.SDFG("linalg_cholesky_{}_{}".format(implementation, dtype))
    dataflow = graph.add_state("dataflow")

    # add_array hands back a (name, descriptor) pair.
    src_name, src_desc = graph.add_array("xin", [size, size], dtype)
    dst_name, dst_desc = graph.add_array("xout", [size, size], dtype)

    reader = dataflow.add_read("xin")
    writer = dataflow.add_write("xout")

    cholesky = Cholesky("cholesky", lower=True)
    cholesky.implementation = implementation

    dataflow.add_memlet_path(reader,
                             cholesky,
                             dst_conn="_a",
                             memlet=Memlet.from_array(src_name, src_desc))
    dataflow.add_memlet_path(cholesky,
                             writer,
                             src_conn="_b",
                             memlet=Memlet.from_array(dst_name, dst_desc))

    return graph
def _make_sdfg(node, parent_state, parent_sdfg, implementation):
    """Build the nested SDFG that expands ``node`` around a ``Potrf`` node.

    Data flow:

    * default path: ``_a`` is copied into ``_b``, ``Potrf`` reads ``_b`` and
      writes to the access node that feeds the ``_uzero_`` map;
    * ``'cuSolverDn'`` path: ``_a`` goes through a cuBLAS ``Transpose`` into
      the transient ``_bt``, ``Potrf`` works on ``_bt``, and a second
      ``Transpose`` writes to the access node that feeds the ``_uzero_`` map.

    The ``_uzero_`` map rewrites ``_b`` element-wise, storing ``0`` wherever
    ``__i < __j`` and the existing value elsewhere.

    :param node: Library node being expanded; supplies ``validate``, ``label``
                 and ``lower``.
    :param parent_state: State containing ``node`` (forwarded to ``validate``).
    :param parent_sdfg: SDFG containing ``node`` (forwarded to ``validate``).
    :param implementation: Implementation name given to the ``Potrf`` node.
    :return: The constructed SDFG.
    """
    in_desc, in_shape, res_desc, res_shape = node.validate(parent_sdfg, parent_state)
    elem_type = in_desc.dtype
    on_cusolver = implementation == 'cuSolverDn'

    nested = dace.SDFG("{l}_sdfg".format(l=node.label))

    src_arr = nested.add_array('_a', in_shape, dtype=elem_type, strides=in_desc.strides)
    dst_arr = nested.add_array('_b', res_shape, dtype=elem_type, strides=res_desc.strides)
    status_arr = nested.add_array('_info', [1], dtype=dace.int32, transient=True)
    # Buffer Potrf operates on: a scratch array for cuSolverDn, '_b' otherwise.
    if on_cusolver:
        work_arr = nested.add_array('_bt', in_shape, dtype=elem_type, transient=True)
    else:
        work_arr = dst_arr

    st = nested.add_state("{l}_state".format(l=node.label))

    factorize = Potrf('potrf', lower=node.lower)
    factorize.implementation = implementation

    # Element-wise pass over '_b' that clears every entry with __i < __j.
    _, zero_entry, zero_exit = st.add_mapped_tasklet(
        '_uzero_',
        {
            '__i': "0:%s" % res_shape[0],
            '__j': "0:%s" % res_shape[1]
        },
        {'_inp': Memlet.simple('_b', '__i, __j')},
        '_out = (__i < __j) ? 0 : _inp;',
        {'_out': Memlet.simple('_b', '__i, __j')},
        language=dace.dtypes.Language.CPP,
        external_edges=True)

    src = st.add_read('_a')
    if on_cusolver:
        potrf_in = st.add_access('_bt')
        potrf_out = st.add_access('_bt')
        zero_src = st.in_edges(zero_entry)[0].src

        to_work = Transpose('AT', dtype=elem_type)
        to_work.implementation = 'cuBLAS'
        st.add_edge(src, None, to_work, '_inp', Memlet.from_array(*src_arr))
        st.add_edge(to_work, '_out', potrf_in, None, Memlet.from_array(*work_arr))

        from_work = Transpose('BT', dtype=elem_type)
        from_work.implementation = 'cuBLAS'
        st.add_edge(potrf_out, None, from_work, '_inp', Memlet.from_array(*work_arr))
        st.add_edge(from_work, '_out', zero_src, None, Memlet.from_array(*dst_arr))
    else:
        potrf_in = st.add_access('_b')
        # Potrf writes straight into the node that the zeroing map reads from.
        potrf_out = st.in_edges(zero_entry)[0].src
        st.add_nedge(src, potrf_in, Memlet.from_array(*src_arr))

    status = st.add_write('_info')

    st.add_memlet_path(potrf_in, factorize, dst_conn="_xin", memlet=Memlet.from_array(*work_arr))
    st.add_memlet_path(factorize, status, src_conn="_res", memlet=Memlet.from_array(*status_arr))
    st.add_memlet_path(factorize, potrf_out, src_conn="_xout", memlet=Memlet.from_array(*work_arr))

    return nested
def apply(self, _, sdfg: sd.SDFG):
    """Move the loop given by ``self.loop_guard`` / ``self.loop_begin`` into the map of its body.

    Steps, in order:

    1. Read the loop variable and its ``(start, end, step)`` via ``find_for_loop``.
    2. Nest the contents of the body's map scope into a nested SDFG and
       re-create the loop inside that nested SDFG with ``add_loop``.
    3. Remove the outer loop: drop the edges around the body, reconnect the
       guard's neighbours directly to the body, move the inter-state
       assignments onto the new inner loop edges, and delete the guard.
    4. Drop the loop variable from the nested symbol mapping and the outer
       symbols, then map or pass in whatever the nested SDFG still needs.
    5. Propagate memlets, apply ``RefineNestedAccess``, and propagate again.

    :param _: Unused.
    :param sdfg: The SDFG that contains the loop; modified in place.
    """
    # Obtain loop information
    guard: sd.SDFGState = self.loop_guard
    body: sd.SDFGState = self.loop_begin

    # Obtain iteration variable, range, and stride
    itervar, (start, end, step), _ = find_for_loop(sdfg, guard, body)

    forward_loop = step > 0

    # Keep the map entry/exit found in the body; if several exist, the last
    # one visited wins.
    # NOTE(review): map_entry/map_exit stay unbound when the body has no map;
    # presumably the match check guarantees one - confirm.
    for node in body.nodes():
        if isinstance(node, nodes.MapEntry):
            map_entry = node
        if isinstance(node, nodes.MapExit):
            map_exit = node

    # nest map's content in sdfg
    map_subgraph = body.scope_subgraph(map_entry, include_entry=False, include_exit=False)
    nsdfg = helpers.nest_state_subgraph(sdfg, body, map_subgraph, full_data=True)

    # replicate loop in nested sdfg
    # (the condition and increment are mirrored for loops with a non-positive step)
    new_before, new_guard, new_after = nsdfg.sdfg.add_loop(
        before_state=None,
        loop_state=nsdfg.sdfg.nodes()[0],
        loop_end_state=None,
        after_state=None,
        loop_var=itervar,
        initialize_expr=f'{start}',
        condition_expr=f'{itervar} <= {end}' if forward_loop else f'{itervar} >= {end}',
        increment_expr=f'{itervar} + {step}' if forward_loop else f'{itervar} - {abs(step)}')

    # remove outer loop
    # First identify the edges of the new inner loop that will receive the
    # assignments currently attached to the outer loop's edges.
    before_guard_edge = nsdfg.sdfg.edges_between(new_before, new_guard)[0]
    for e in nsdfg.sdfg.out_edges(new_guard):
        if e.dst is new_after:
            guard_after_edge = e
        else:
            guard_body_edge = e

    # Detach the body from the outer graph, keeping the guard->body assignments.
    for body_inedge in sdfg.in_edges(body):
        if body_inedge.src is guard:
            guard_body_edge.data.assignments.update(body_inedge.data.assignments)
        sdfg.remove_edge(body_inedge)
    for body_outedge in sdfg.out_edges(body):
        sdfg.remove_edge(body_outedge)
    # Whatever used to enter the guard now enters the body directly; its
    # assignments move to the inner loop's initialization edge.
    for guard_inedge in sdfg.in_edges(guard):
        before_guard_edge.data.assignments.update(guard_inedge.data.assignments)
        guard_inedge.data.assignments = {}
        sdfg.add_edge(guard_inedge.src, body, guard_inedge.data)
        sdfg.remove_edge(guard_inedge)
    # Whatever the guard used to lead to is now reached from the body, with
    # the edge condition replaced by the constant "1".
    for guard_outedge in sdfg.out_edges(guard):
        if guard_outedge.dst is body:
            guard_body_edge.data.assignments.update(guard_outedge.data.assignments)
        else:
            guard_after_edge.data.assignments.update(guard_outedge.data.assignments)
        guard_outedge.data.condition = CodeBlock("1")
        sdfg.add_edge(body, guard_outedge.dst, guard_outedge.data)
        sdfg.remove_edge(guard_outedge)
    sdfg.remove_node(guard)
    # The loop variable is now defined inside the nested SDFG only.
    if itervar in nsdfg.symbol_mapping:
        del nsdfg.symbol_mapping[itervar]
    if itervar in sdfg.symbols:
        del sdfg.symbols[itervar]

    # Add missing data/symbols
    for s in nsdfg.sdfg.free_symbols:
        if s in nsdfg.symbol_mapping:
            continue
        if s in sdfg.symbols:
            nsdfg.symbol_mapping[s] = s
        elif s in sdfg.arrays:
            # The free symbol is actually a data container of the outer SDFG:
            # copy its descriptor into the nested SDFG as non-transient and
            # feed it in through a new input connector.
            desc = sdfg.arrays[s]
            access = body.add_access(s)
            conn = nsdfg.sdfg.add_datadesc(s, copy.deepcopy(desc))
            nsdfg.sdfg.arrays[s].transient = False
            nsdfg.add_in_connector(conn)
            body.add_memlet_path(access, map_entry, nsdfg, memlet=Memlet.from_array(s, desc), dst_conn=conn)
        else:
            raise NotImplementedError(f"Free symbol {s} is neither a symbol nor data.")
    # Drop symbol mappings the nested SDFG no longer uses (collected first so
    # the mapping is not modified while it is being iterated).
    to_delete = set()
    for s in nsdfg.symbol_mapping:
        if s not in nsdfg.sdfg.free_symbols:
            to_delete.add(s)
    for s in to_delete:
        del nsdfg.symbol_mapping[s]

    # propagate scope for correct volumes
    scope_tree = ScopeTree(map_entry, map_exit)
    scope_tree.parent = ScopeTree(None, None)
    # The first execution helps remove appearances of symbols
    # that are now defined only in the nested SDFG in memlets.
    propagation.propagate_memlets_scope(sdfg, body, scope_tree)

    for s in to_delete:
        if helpers.is_symbol_unused(sdfg, s):
            sdfg.remove_symbol(s)

    # Imported here rather than at module level (presumably to avoid a
    # circular import - confirm).
    from dace.transformation.interstate import RefineNestedAccess
    transformation = RefineNestedAccess()
    transformation.setup_match(sdfg, 0, sdfg.node_id(body), {RefineNestedAccess.nsdfg: body.node_id(nsdfg)}, 0)
    transformation.apply(body, sdfg)

    # Second propagation for refined accesses.
    propagation.propagate_memlets_scope(sdfg, body, scope_tree)
def _make_sdfg_getrs(node, parent_state, parent_sdfg, implementation):
    """Build the nested SDFG that expands ``node`` as ``Getrf`` followed by ``Getrs``.

    The matrix is passed through a ``Getrf`` node; its output and pivots are
    then fed to a ``Getrs`` node whose right-hand side is filled by the
    ``_eye_`` map with ``1`` where ``i == j`` and ``0`` elsewhere (i.e. an
    LU-based inverse, assuming LAPACK semantics for both library nodes).

    * ``node.overwrite`` set: ``Getrf`` works in place on ``_ain``, the
      right-hand side lives in the transient ``_b``, and the ``Getrs`` result
      is copied from ``_b`` back into ``_ain``.
    * otherwise: ``_ain`` is first copied into the transient ``_ainout`` and
      the right-hand side / result buffer is the output array ``_aout``.

    :param node: Library node being expanded; supplies ``validate``, ``label``
                 and ``overwrite``.
    :param parent_state: State containing ``node`` (forwarded to ``validate``).
    :param parent_sdfg: SDFG containing ``node`` (forwarded to ``validate``).
    :param implementation: Implementation name given to both library nodes.
    :return: The constructed SDFG.
    """
    arr_desc = node.validate(parent_sdfg, parent_state)
    if node.overwrite:
        in_shape, in_dtype, in_strides, n = arr_desc
    else:
        (in_shape, in_dtype, in_strides, out_shape, out_dtype, out_strides, n) = arr_desc

    sdfg = dace.SDFG("{l}_sdfg".format(l=node.label))

    a_arr = sdfg.add_array('_ain', in_shape, dtype=in_dtype, strides=in_strides)
    if not node.overwrite:
        ain_arr = a_arr
        # From here on a_arr names the scratch copy that Getrf factorizes.
        a_arr = sdfg.add_array('_ainout', [n, n], dtype=in_dtype, transient=True)
        b_arr = sdfg.add_array('_aout', out_shape, dtype=out_dtype, strides=out_strides)
    else:
        b_arr = sdfg.add_array('_b', [n, n], dtype=in_dtype, transient=True)
    ipiv_arr = sdfg.add_array('_pivots', [n], dtype=dace.int32, transient=True)
    info_arr = sdfg.add_array('_info', [1], dtype=dace.int32, transient=True)

    state = sdfg.add_state("{l}_state".format(l=node.label))

    getrf_node = Getrf('getrf')
    getrf_node.implementation = implementation
    getrs_node = Getrs('getrs')
    getrs_node.implementation = implementation

    if node.overwrite:
        ain = state.add_read('_ain')
        ainout = state.add_access('_ain')
        aout = state.add_write('_ain')
        rhs_name = '_b'
        bout = state.add_write('_b')
        # Copy the Getrs result back over the input matrix.
        state.add_nedge(bout, aout, Memlet.from_array(*a_arr))
    else:
        a = state.add_read('_ain')
        ain = state.add_read('_ainout')
        ainout = state.add_access('_ainout')
        # Work on a copy so the caller's input stays untouched.
        state.add_nedge(a, ain, Memlet.from_array(*ain_arr))
        rhs_name = '_aout'
        bout = state.add_access('_aout')

    # Build the range from the validated dimension itself. A literal "0:n"
    # only resolves when that dimension happens to be a symbol spelled "n".
    eye_range = "0:%s" % n
    _, _, mx = state.add_mapped_tasklet('_eye_',
                                        dict(i=eye_range, j=eye_range), {},
                                        '_out = (i == j) ? 1 : 0;',
                                        dict(_out=Memlet.simple(rhs_name, 'i, j')),
                                        language=dace.dtypes.Language.CPP,
                                        external_edges=True)
    # Access node written by the identity map; it is the Getrs right-hand side.
    rhs_in = state.out_edges(mx)[0].dst
    ipiv = state.add_access('_pivots')
    info1 = state.add_write('_info')
    info2 = state.add_write('_info')

    state.add_memlet_path(ain, getrf_node, dst_conn="_xin", memlet=Memlet.from_array(*a_arr))
    state.add_memlet_path(getrf_node, info1, src_conn="_res", memlet=Memlet.from_array(*info_arr))
    state.add_memlet_path(getrf_node, ipiv, src_conn="_ipiv", memlet=Memlet.from_array(*ipiv_arr))
    state.add_memlet_path(getrf_node, ainout, src_conn="_xout", memlet=Memlet.from_array(*a_arr))
    state.add_memlet_path(ainout, getrs_node, dst_conn="_a", memlet=Memlet.from_array(*a_arr))
    state.add_memlet_path(rhs_in, getrs_node, dst_conn="_rhs_in", memlet=Memlet.from_array(*b_arr))
    state.add_memlet_path(ipiv, getrs_node, dst_conn="_ipiv", memlet=Memlet.from_array(*ipiv_arr))
    state.add_memlet_path(getrs_node, info2, src_conn="_res", memlet=Memlet.from_array(*info_arr))
    state.add_memlet_path(getrs_node, bout, src_conn="_rhs_out", memlet=Memlet.from_array(*b_arr))

    return sdfg
def _make_sdfg(node, parent_state, parent_sdfg, implementation):
    """Build the nested SDFG that expands ``node`` as ``Getrf`` followed by ``Getri``.

    ``Getrf`` reads the working matrix and writes it back together with the
    pivots; ``Getri`` then consumes both and writes the final result to the
    same array. With ``node.overwrite`` the working matrix is ``_ain``
    itself; otherwise ``_ain`` is first copied into ``_aout`` and all work
    happens there.

    :param node: Library node being expanded; supplies ``validate``, ``label``
                 and ``overwrite``.
    :param parent_state: State containing ``node`` (forwarded to ``validate``).
    :param parent_sdfg: SDFG containing ``node`` (forwarded to ``validate``).
    :param implementation: Implementation name given to both library nodes.
    :return: The constructed SDFG.
    """
    validated = node.validate(parent_sdfg, parent_state)
    in_place = node.overwrite
    if in_place:
        in_shape, in_dtype, in_strides, n = validated
    else:
        (in_shape, in_dtype, in_strides, out_shape, out_dtype, out_strides, n) = validated

    nested = dace.SDFG("{l}_sdfg".format(l=node.label))

    input_arr = nested.add_array('_ain', in_shape, dtype=in_dtype, strides=in_strides)
    # Array both library nodes operate on.
    if in_place:
        work_arr = input_arr
    else:
        work_arr = nested.add_array('_aout', out_shape, dtype=out_dtype, strides=out_strides)
    pivots_arr = nested.add_array('_pivots', [n], dtype=dace.int32, transient=True)
    status_arr = nested.add_array('_info', [1], dtype=dace.int32, transient=True)

    st = nested.add_state("{l}_state".format(l=node.label))

    factorize = Getrf('getrf')
    factorize.implementation = implementation
    invert = Getri('getri')
    invert.implementation = implementation

    if in_place:
        lu_in = st.add_read('_ain')
        lu_out = st.add_access('_ain')
        result = st.add_write('_ain')
    else:
        original = st.add_read('_ain')
        lu_in = st.add_read('_aout')
        lu_out = st.add_access('_aout')
        result = st.add_write('_aout')
        # Seed the output array with the input before factorizing it.
        st.add_nedge(original, lu_in, Memlet.from_array(*input_arr))

    pivots = st.add_access('_pivots')
    getrf_status = st.add_write('_info')
    getri_status = st.add_write('_info')

    st.add_memlet_path(lu_in, factorize, dst_conn="_xin", memlet=Memlet.from_array(*work_arr))
    st.add_memlet_path(factorize, getrf_status, src_conn="_res", memlet=Memlet.from_array(*status_arr))
    st.add_memlet_path(factorize, pivots, src_conn="_ipiv", memlet=Memlet.from_array(*pivots_arr))
    st.add_memlet_path(factorize, lu_out, src_conn="_xout", memlet=Memlet.from_array(*work_arr))
    st.add_memlet_path(lu_out, invert, dst_conn="_xin", memlet=Memlet.from_array(*work_arr))
    st.add_memlet_path(pivots, invert, dst_conn="_ipiv", memlet=Memlet.from_array(*pivots_arr))
    st.add_memlet_path(invert, getri_status, src_conn="_res", memlet=Memlet.from_array(*status_arr))
    st.add_memlet_path(invert, result, src_conn="_xout", memlet=Memlet.from_array(*work_arr))

    return nested
def apply(self, sdfg: SDFG):
    """Move the matched states into a nested kernel SDFG launched from a ``GPU_Persistent`` map.

    Steps, in order:

    1. Determine the entry/exit states of the subgraph and, where needed,
       insert guard states and a launch state into the outer SDFG.
    2. Create the kernel SDFG from the subgraph's edges and remove those
       states from the outer SDFG.
    3. Distribute data: streams and ``Register`` arrays are moved into the
       kernel SDFG; everything else is deep-copied as a non-transient
       descriptor with ``Default`` storage and becomes a kernel argument.
    4. Put the kernel SDFG into the launch state inside a map whose only
       range is ``ignore='0'`` and connect the arguments.

    :param sdfg: The SDFG to transform; modified in place.
    """
    subgraph = self.subgraph_view(sdfg)

    entry_states_in, entry_states_out = self.get_entry_states(
        sdfg, subgraph)
    _, exit_states_out = self.get_exit_states(sdfg, subgraph)

    # One inner entry state is always taken; the outer neighbours are optional.
    entry_state_in = entry_states_in.pop()
    entry_state_out = entry_states_out.pop() \
        if len(entry_states_out) > 0 else None
    exit_state_out = exit_states_out.pop() \
        if len(exit_states_out) > 0 else None

    launch_state = None
    entry_guard_state = None
    exit_guard_state = None

    # generate entry guard state if needed
    if self.include_in_assignment and entry_state_out is not None:
        entry_edge = sdfg.edges_between(entry_state_out, entry_state_in)[0]
        if len(entry_edge.data.assignments) > 0:
            # Split the entry edge: its condition stays on the edge into the
            # guard, its assignments move to the edge leaving the guard.
            entry_guard_state = sdfg.add_state(
                label='{}kernel_entry_guard'.format(
                    self.kernel_prefix + '_' if self.kernel_prefix != '' else ''))
            sdfg.add_edge(entry_state_out, entry_guard_state, InterstateEdge(entry_edge.data.condition))
            sdfg.add_edge(
                entry_guard_state, entry_state_in, InterstateEdge(None, entry_edge.data.assignments))
            sdfg.remove_edge(entry_edge)

            # Update SubgraphView
            new_node_list = subgraph.nodes()
            new_node_list.append(entry_guard_state)
            subgraph = SubgraphView(sdfg, new_node_list)

            launch_state = sdfg.add_state_before(
                entry_guard_state, label='{}kernel_launch'.format(
                    self.kernel_prefix + '_' if self.kernel_prefix != '' else ''))

    # generate exit guard state
    if exit_state_out is not None:
        exit_guard_state = sdfg.add_state_before(
            exit_state_out, label='{}kernel_exit_guard'.format(
                self.kernel_prefix + '_' if self.kernel_prefix != '' else ''))

        # Update SubgraphView
        new_node_list = subgraph.nodes()
        new_node_list.append(exit_guard_state)
        subgraph = SubgraphView(sdfg, new_node_list)

        if launch_state is None:
            launch_state = sdfg.add_state_before(
                exit_state_out, label='{}kernel_launch'.format(
                    self.kernel_prefix + '_' if self.kernel_prefix != '' else ''))

    # If the launch state doesn't exist at this point then there is no other
    # states outside of the kernel, so create a stand alone launch state
    if launch_state is None:
        # NOTE(review): entry_state_in comes from the unconditional pop()
        # above, so it is None only if the set itself contained None, and
        # this assertion looks unsatisfiable. Confirm whether it was meant
        # to test entry_state_out instead.
        assert (entry_state_in is None and exit_state_out is None)
        launch_state = sdfg.add_state(label='{}kernel_launch'.format(
            self.kernel_prefix + '_' if self.kernel_prefix != '' else ''))

    # Create the kernel SDFG and fill it with the states and edges of the
    # subgraph; it is nested into the launch state further below.
    kernel_sdfg = SDFG(
        '{}kernel'.format(self.kernel_prefix + '_' if self.kernel_prefix != '' else ''))

    edges = subgraph.edges()
    for edge in edges:
        kernel_sdfg.add_edge(edge.src, edge.dst, edge.data)

    # Setting entry node in nested SDFG if no entry guard was created
    if entry_guard_state is None:
        kernel_sdfg.start_state = kernel_sdfg.node_id(entry_state_in)

    for state in subgraph:
        state.parent = kernel_sdfg

    # remove the now nested nodes from the outer sdfg and make sure the
    # launch state is properly connected to remaining states
    sdfg.remove_nodes_from(subgraph.nodes())

    if entry_state_out is not None \
            and len(sdfg.edges_between(entry_state_out, launch_state)) == 0:
        sdfg.add_edge(entry_state_out, launch_state, InterstateEdge())

    if exit_state_out is not None \
            and len(sdfg.edges_between(launch_state, exit_state_out)) == 0:
        sdfg.add_edge(launch_state, exit_state_out, InterstateEdge())

    # Handle data for kernel
    # (names of all containers referenced by access nodes inside the kernel)
    kernel_data = set(node.data for state in kernel_sdfg
                      for node in state.nodes()
                      if isinstance(node, nodes.AccessNode))

    # move Streams and Register data into the nested SDFG
    # normal data will be added as kernel argument
    kernel_args = []
    for data in kernel_data:
        if (isinstance(sdfg.arrays[data], dace.data.Stream) or
            (isinstance(sdfg.arrays[data], dace.data.Array)
             and sdfg.arrays[data].storage == StorageType.Register)):
            # Moved, not copied: the outer SDFG loses the descriptor.
            kernel_sdfg.add_datadesc(data, sdfg.arrays[data])
            del sdfg.arrays[data]
        else:
            # The outer descriptor is kept; the kernel gets a copy that is
            # non-transient and uses Default storage.
            copy_desc = copy.deepcopy(sdfg.arrays[data])
            copy_desc.transient = False
            copy_desc.storage = StorageType.Default
            kernel_sdfg.add_datadesc(data, copy_desc)
            kernel_args.append(data)

    # read only data will be passed as input, writeable data will be passed
    # as 'output' otherwise kernel cannot write to data
    kernel_args_read = set()
    kernel_args_write = set()
    for data in kernel_args:
        # One flag per access node of this container, across all kernel states.
        data_accesses_read_only = [
            node.access == dtypes.AccessType.ReadOnly
            for state in kernel_sdfg for node in state
            if isinstance(node, nodes.AccessNode) and node.data == data
        ]
        if all(data_accesses_read_only):
            kernel_args_read.add(data)
        else:
            kernel_args_write.add(data)

    # Kernel SDFG is complete at this point
    if self.validate:
        kernel_sdfg.validate()

    # Filling launch state with nested SDFG, map and access nodes
    map_entry, map_exit = launch_state.add_map(
        '{}kernel_launch_map'.format(
            self.kernel_prefix + '_' if self.kernel_prefix != '' else ''),
        dict(ignore='0'),
        schedule=ScheduleType.GPU_Persistent,
    )

    nested_sdfg = launch_state.add_nested_sdfg(
        kernel_sdfg,
        sdfg,
        kernel_args_read,
        kernel_args_write,
    )

    # Create and connect read only data access nodes
    for arg in kernel_args_read:
        read_node = launch_state.add_read(arg)
        launch_state.add_memlet_path(read_node,
                                     map_entry,
                                     nested_sdfg,
                                     dst_conn=arg,
                                     memlet=Memlet.from_array(
                                         arg, sdfg.arrays[arg]))

    # Create and connect writable data access nodes
    for arg in kernel_args_write:
        write_node = launch_state.add_write(arg)
        launch_state.add_memlet_path(nested_sdfg,
                                     map_exit,
                                     write_node,
                                     src_conn=arg,
                                     memlet=Memlet.from_array(
                                         arg, sdfg.arrays[arg]))

    # Transformation is done
    if self.validate:
        sdfg.validate()
def _make_sdfg_getrs(node, parent_state, parent_sdfg, implementation):
    """Build the nested SDFG that expands ``node`` as ``Getrf`` followed by ``Getrs``.

    Inputs ``_ain`` and ``_bin`` are staged into the transients ``_ainout``
    (``n x n``) and ``_binout``. ``Getrf`` works on ``_ainout`` and produces
    the pivots; ``Getrs`` takes ``_ainout``, the pivots and ``_binout`` and
    writes its result back into ``_binout``, which is finally moved to
    ``_bout``.

    For ``'cuSolverDn'`` every staging step is a cuBLAS ``Transpose`` and
    ``_binout`` has shape ``[rhs, n]``; otherwise the staging steps are plain
    copies and ``_binout`` has shape ``[n, rhs]``.

    :param node: Library node being expanded; supplies ``validate`` and ``label``.
    :param parent_state: State containing ``node`` (forwarded to ``validate``).
    :param parent_sdfg: SDFG containing ``node`` (forwarded to ``validate``).
    :param implementation: Implementation name given to both library nodes.
    :return: The constructed SDFG.
    """
    (ain_shape, ain_dtype, ain_strides, bin_shape, bin_dtype, bin_strides, out_shape, out_dtype, out_strides, n,
     rhs) = node.validate(parent_sdfg, parent_state)
    on_cusolver = implementation == 'cuSolverDn'

    nested = dace.SDFG("{l}_sdfg".format(l=node.label))

    mat_arr = nested.add_array('_ain', ain_shape, dtype=ain_dtype, strides=ain_strides)
    mat_work_arr = nested.add_array('_ainout', [n, n], dtype=ain_dtype, transient=True)
    rhs_arr = nested.add_array('_bin', bin_shape, dtype=bin_dtype, strides=bin_strides)
    rhs_work_shape = [rhs, n] if on_cusolver else [n, rhs]
    rhs_work_arr = nested.add_array('_binout', rhs_work_shape, dtype=out_dtype, transient=True)
    result_arr = nested.add_array('_bout', out_shape, dtype=out_dtype, strides=out_strides)
    pivots_arr = nested.add_array('_pivots', [n], dtype=dace.int32, transient=True)
    status_arr = nested.add_array('_info', [1], dtype=dace.int32, transient=True)

    st = nested.add_state("{l}_state".format(l=node.label))

    factorize = Getrf('getrf')
    factorize.implementation = implementation
    solve = Getrs('getrs')
    solve.implementation = implementation

    mat_read = st.add_read('_ain')
    mat_work_in = st.add_read('_ainout')
    mat_work_out = st.add_access('_ainout')
    rhs_read = st.add_read('_bin')
    rhs_work_in = st.add_read('_binout')
    rhs_work_out = st.add_read('_binout')
    result = st.add_access('_bout')

    if not on_cusolver:
        # Plain copies in and out of the working buffers.
        st.add_nedge(mat_read, mat_work_in, Memlet.from_array(*mat_arr))
        st.add_nedge(rhs_read, rhs_work_in, Memlet.from_array(*rhs_arr))
        st.add_nedge(rhs_work_out, result, Memlet.from_array(*result_arr))
    else:
        # Same three staging steps, each routed through a cuBLAS transpose.
        mat_t = Transpose('AT', dtype=ain_dtype)
        mat_t.implementation = 'cuBLAS'
        st.add_edge(mat_read, None, mat_t, '_inp', Memlet.from_array(*mat_arr))
        st.add_edge(mat_t, '_out', mat_work_in, None, Memlet.from_array(*mat_work_arr))

        rhs_t = Transpose('bT', dtype=bin_dtype)
        rhs_t.implementation = 'cuBLAS'
        st.add_edge(rhs_read, None, rhs_t, '_inp', Memlet.from_array(*rhs_arr))
        st.add_edge(rhs_t, '_out', rhs_work_in, None, Memlet.from_array(*rhs_work_arr))

        result_t = Transpose('XT', dtype=bin_dtype)
        result_t.implementation = 'cuBLAS'
        st.add_edge(rhs_work_out, None, result_t, '_inp', Memlet.from_array(*rhs_work_arr))
        st.add_edge(result_t, '_out', result, None, Memlet.from_array(*result_arr))

    pivots = st.add_access('_pivots')
    getrf_status = st.add_write('_info')
    getrs_status = st.add_write('_info')

    st.add_memlet_path(mat_work_in, factorize, dst_conn="_xin", memlet=Memlet.from_array(*mat_work_arr))
    st.add_memlet_path(factorize, getrf_status, src_conn="_res", memlet=Memlet.from_array(*status_arr))
    st.add_memlet_path(factorize, pivots, src_conn="_ipiv", memlet=Memlet.from_array(*pivots_arr))
    st.add_memlet_path(factorize, mat_work_out, src_conn="_xout", memlet=Memlet.from_array(*mat_work_arr))
    st.add_memlet_path(mat_work_out, solve, dst_conn="_a", memlet=Memlet.from_array(*mat_work_arr))
    st.add_memlet_path(rhs_work_in, solve, dst_conn="_rhs_in", memlet=Memlet.from_array(*rhs_work_arr))
    st.add_memlet_path(pivots, solve, dst_conn="_ipiv", memlet=Memlet.from_array(*pivots_arr))
    st.add_memlet_path(solve, getrs_status, src_conn="_res", memlet=Memlet.from_array(*status_arr))
    st.add_memlet_path(solve, rhs_work_out, src_conn="_rhs_out", memlet=Memlet.from_array(*rhs_work_arr))

    return nested
def test_duplicate_codegen():
    """Run a hand-built graph where one intermediate feeds two tasklets.

    ``c_task`` copies ``C`` into ``D``; ``e_task`` and ``f_task`` each add
    ``D`` to ``A`` and ``B`` respectively and accumulate the sum into ``E``
    and ``F`` through a write-conflict resolution lambda. With all inputs
    equal to 1 and zeroed outputs, both results must be 2.
    """
    # The graph is assembled by hand because the Python frontend would not
    # produce the node ordering this test relies on. Keep the creation order.
    graph = dace.SDFG("dup")
    st = graph.add_state()

    copy_task = st.add_tasklet("c_task", inputs={"c"}, outputs={"d"}, code='d = c')
    sum_e_task = st.add_tasklet("e_task", inputs={"a", "d"}, outputs={"e"}, code="e = a + d")
    sum_f_task = st.add_tasklet("f_task", inputs={"b", "d"}, outputs={"f"}, code="f = b + d")

    # Six single-element float32 arrays, registered in alphabetical order.
    descs = {}
    for name in "ABCDEF":
        _, descs[name] = graph.add_array(name, [1], dace.float32)

    node_a = st.add_read("A")
    node_b = st.add_read("B")
    node_c = st.add_read("C")
    node_d = st.add_access("D")
    node_e = st.add_write("E")
    node_f = st.add_write("F")

    accumulate = "lambda x, y: x + y"

    st.add_edge(node_c, None, copy_task, "c", Memlet.from_array("C", descs["C"]))
    st.add_edge(copy_task, "d", node_d, None, Memlet.from_array("D", descs["D"]))

    st.add_edge(node_a, None, sum_e_task, "a", Memlet.from_array("A", descs["A"]))
    st.add_edge(node_b, None, sum_f_task, "b", Memlet.from_array("B", descs["B"]))
    st.add_edge(node_d, None, sum_f_task, "d", Memlet.from_array("D", descs["D"]))
    st.add_edge(node_d, None, sum_e_task, "d", Memlet.from_array("D", descs["D"]))

    st.add_edge(sum_e_task, "e", node_e, None, Memlet.from_array("E", descs["E"], wcr=accumulate))
    st.add_edge(sum_f_task, "f", node_f, None, Memlet.from_array("F", descs["F"], wcr=accumulate))

    inputs = {name: np.array([1], dtype=np.float32) for name in "ABCD"}
    out_e = np.zeros_like(inputs["A"])
    out_f = np.zeros_like(inputs["A"])

    graph(**inputs, E=out_e, F=out_f)

    assert out_e[0] == 2
    assert out_f[0] == 2