def make_sdfg(implementation, dtype, id=0, in_shape=[n, n], out_shape=[n, n], in_subset="0:n, 0:n", out_subset="0:n, 0:n", overwrite=False, getri=True): sdfg = dace.SDFG("linalg_inv_{}_{}_{}".format(implementation, dtype.__name__, id)) sdfg.add_symbol("n", dace.int64) state = sdfg.add_state("dataflow") sdfg.add_array("xin", in_shape, dtype) if not overwrite: sdfg.add_array("xout", out_shape, dtype) xin = state.add_read("xin") if overwrite: xout = state.add_write("xin") else: xout = state.add_write("xout") inv_node = Inv("inv", overwrite_a=overwrite, use_getri=getri) inv_node.implementation = implementation state.add_memlet_path(xin, inv_node, dst_conn="_ain", memlet=Memlet.simple(xin, in_subset, num_accesses=n * n)) state.add_memlet_path(inv_node, xout, src_conn="_aout", memlet=Memlet.simple(xout, out_subset, num_accesses=n * n)) return sdfg
def make_sdfg(implementation, dtype, id=0, in_shape=[n, n], out_shape=[n, n], in_subset="0:n, 0:n", out_subset="0:n, 0:n"): sdfg = dace.SDFG("linalg_solve_{}_{}_{}".format(implementation, dtype.__name__, id)) sdfg.add_symbol("n", dace.int64) state = sdfg.add_state("dataflow") sdfg.add_array("ain", in_shape, dtype) sdfg.add_array("bin", out_shape, dtype) sdfg.add_array("bout", out_shape, dtype) ain = state.add_read("ain") bin = state.add_read("bin") bout = state.add_write("bout") solve_node = Solve("solve") solve_node.implementation = implementation state.add_memlet_path(ain, solve_node, dst_conn="_ain", memlet=Memlet.simple(ain, in_subset, num_accesses=n * n)) state.add_memlet_path(bin, solve_node, dst_conn="_bin", memlet=Memlet.simple(bin, out_subset, num_accesses=n * n)) state.add_memlet_path(solve_node, bout, src_conn="_bout", memlet=Memlet.simple(bout, out_subset, num_accesses=n * n)) return sdfg
def _make_sdfg(node, parent_state, parent_sdfg, implementation):
    """Builds the expansion SDFG for the Cholesky node: POTRF plus a map that zeros
    the strictly upper triangle of the result."""

    inp_desc, inp_shape, out_desc, out_shape = node.validate(parent_sdfg, parent_state)
    dtype = inp_desc.dtype

    sdfg = dace.SDFG("{l}_sdfg".format(l=node.label))

    ain_arr = sdfg.add_array('_a', inp_shape, dtype=dtype, strides=inp_desc.strides)
    bout_arr = sdfg.add_array('_b', out_shape, dtype=dtype, strides=out_desc.strides)
    info_arr = sdfg.add_array('_info', [1], dtype=dace.int32, transient=True)
    if implementation == 'cuSolverDn':
        # cuSolverDn works on column-major data, so factorize a transposed copy.
        binout_arr = sdfg.add_array('_bt', inp_shape, dtype=dtype, transient=True)
    else:
        binout_arr = bout_arr

    state = sdfg.add_state("{l}_state".format(l=node.label))

    potrf_node = Potrf('potrf', lower=node.lower)
    potrf_node.implementation = implementation

    # POTRF leaves the unused triangle untouched; zero every entry above the diagonal.
    _, me, mx = state.add_mapped_tasklet('_uzero_',
                                         dict(__i="0:%s" % out_shape[0],
                                              __j="0:%s" % out_shape[1]),
                                         dict(_inp=Memlet.simple('_b', '__i, __j')),
                                         '_out = (__i < __j) ? 0 : _inp;',
                                         dict(_out=Memlet.simple('_b', '__i, __j')),
                                         language=dace.dtypes.Language.CPP,
                                         external_edges=True)
    ain = state.add_read('_a')
    if implementation == 'cuSolverDn':
        binout1 = state.add_access('_bt')
        binout2 = state.add_access('_bt')
        binout3 = state.in_edges(me)[0].src
        bout = state.out_edges(mx)[0].dst
        transpose_ain = Transpose('AT', dtype=dtype)
        transpose_ain.implementation = 'cuBLAS'
        state.add_edge(ain, None, transpose_ain, '_inp', Memlet.from_array(*ain_arr))
        state.add_edge(transpose_ain, '_out', binout1, None, Memlet.from_array(*binout_arr))
        transpose_out = Transpose('BT', dtype=dtype)
        transpose_out.implementation = 'cuBLAS'
        state.add_edge(binout2, None, transpose_out, '_inp', Memlet.from_array(*binout_arr))
        state.add_edge(transpose_out, '_out', binout3, None, Memlet.from_array(*bout_arr))
    else:
        binout1 = state.add_access('_b')
        binout2 = state.in_edges(me)[0].src
        binout3 = state.out_edges(mx)[0].dst
        state.add_nedge(ain, binout1, Memlet.from_array(*ain_arr))

    info = state.add_write('_info')

    state.add_memlet_path(binout1,
                          potrf_node,
                          dst_conn="_xin",
                          memlet=Memlet.from_array(*binout_arr))
    state.add_memlet_path(potrf_node,
                          info,
                          src_conn="_res",
                          memlet=Memlet.from_array(*info_arr))
    state.add_memlet_path(potrf_node,
                          binout2,
                          src_conn="_xout",
                          memlet=Memlet.from_array(*binout_arr))

    return sdfg
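
# Reference sketch of what the expansion above computes when node.lower is True
# (illustrative only; plain NumPy, independent of the SDFG machinery): the
# Cholesky factor with the strictly upper triangle zeroed, as the '_uzero_' map does.
def _example_cholesky_reference():
    import numpy as np
    A = np.random.rand(4, 4)
    A = A @ A.T + 4 * np.eye(4)   # make the input symmetric positive definite
    L = np.linalg.cholesky(A)     # lower-triangular factor; entries above the diagonal are zero
    assert np.allclose(L @ L.T, A)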
def _make_sdfg_getrs(node, parent_state, parent_sdfg, implementation):
    """Builds the inverse via LU: GETRF factorizes A, then GETRS solves A X = I for X."""

    arr_desc = node.validate(parent_sdfg, parent_state)
    if node.overwrite:
        in_shape, in_dtype, in_strides, n = arr_desc
    else:
        (in_shape, in_dtype, in_strides, out_shape, out_dtype, out_strides, n) = arr_desc
    dtype = in_dtype

    sdfg = dace.SDFG("{l}_sdfg".format(l=node.label))

    a_arr = sdfg.add_array('_ain', in_shape, dtype=in_dtype, strides=in_strides)
    if not node.overwrite:
        ain_arr = a_arr
        a_arr = sdfg.add_array('_ainout', [n, n], dtype=in_dtype, transient=True)
        b_arr = sdfg.add_array('_aout', out_shape, dtype=out_dtype, strides=out_strides)
    else:
        b_arr = sdfg.add_array('_b', [n, n], dtype=dtype, transient=True)
    ipiv_arr = sdfg.add_array('_pivots', [n], dtype=dace.int32, transient=True)
    info_arr = sdfg.add_array('_info', [1], dtype=dace.int32, transient=True)

    state = sdfg.add_state("{l}_state".format(l=node.label))

    getrf_node = Getrf('getrf')
    getrf_node.implementation = implementation
    getrs_node = Getrs('getrs')
    getrs_node.implementation = implementation

    if node.overwrite:
        ain = state.add_read('_ain')
        ainout = state.add_access('_ain')
        aout = state.add_write('_ain')
        bin_name = '_b'
        bout = state.add_write('_b')
        # Copy the solution back into the (overwritten) input array.
        state.add_nedge(bout, aout, Memlet.from_array(*a_arr))
    else:
        a = state.add_read('_ain')
        ain = state.add_read('_ainout')
        ainout = state.add_access('_ainout')
        # aout = state.add_write('_aout')
        state.add_nedge(a, ain, Memlet.from_array(*ain_arr))
        bin_name = '_aout'
        bout = state.add_access('_aout')

    # Initialize the right-hand side with the identity matrix.
    _, _, mx = state.add_mapped_tasklet('_eye_',
                                        dict(i="0:n", j="0:n"), {},
                                        '_out = (i == j) ? 1 : 0;',
                                        dict(_out=Memlet.simple(bin_name, 'i, j')),
                                        language=dace.dtypes.Language.CPP,
                                        external_edges=True)
    bin = state.out_edges(mx)[0].dst

    ipiv = state.add_access('_pivots')
    info1 = state.add_write('_info')
    info2 = state.add_write('_info')

    state.add_memlet_path(ain, getrf_node, dst_conn="_xin", memlet=Memlet.from_array(*a_arr))
    state.add_memlet_path(getrf_node, info1, src_conn="_res", memlet=Memlet.from_array(*info_arr))
    state.add_memlet_path(getrf_node, ipiv, src_conn="_ipiv", memlet=Memlet.from_array(*ipiv_arr))
    state.add_memlet_path(getrf_node, ainout, src_conn="_xout", memlet=Memlet.from_array(*a_arr))
    state.add_memlet_path(ainout, getrs_node, dst_conn="_a", memlet=Memlet.from_array(*a_arr))
    state.add_memlet_path(bin, getrs_node, dst_conn="_rhs_in", memlet=Memlet.from_array(*b_arr))
    state.add_memlet_path(ipiv, getrs_node, dst_conn="_ipiv", memlet=Memlet.from_array(*ipiv_arr))
    state.add_memlet_path(getrs_node, info2, src_conn="_res", memlet=Memlet.from_array(*info_arr))
    state.add_memlet_path(getrs_node, bout, src_conn="_rhs_out", memlet=Memlet.from_array(*b_arr))

    return sdfg
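
# Reference sketch of the numerical recipe above (illustrative only; uses SciPy,
# which is not a dependency of this module): GETRF factorizes A as P L U, and
# GETRS with an identity right-hand side then yields the inverse.
def _example_getrf_getrs_inverse():
    import numpy as np
    from scipy.linalg import lu_factor, lu_solve
    A = np.random.rand(4, 4) + 4 * np.eye(4)
    lu, piv = lu_factor(A)                   # GETRF: LU factorization with partial pivoting
    Ainv = lu_solve((lu, piv), np.eye(4))    # GETRS: solve A X = I for X = A^{-1}
    assert np.allclose(A @ Ainv, np.eye(4))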
def make_sdfg(specialize):

    if specialize:
        sdfg = SDFG("histogram_fpga_parallel_{}_{}x{}".format(P.get(), H.get(), W.get()))
    else:
        sdfg = SDFG("histogram_fpga_parallel_{}".format(P.get()))

    copy_to_fpga_state = make_copy_to_fpga_state(sdfg)

    state = sdfg.add_state("compute")

    # Compute module
    nested_sdfg = make_compute_nested_sdfg(state)
    tasklet = state.add_nested_sdfg(nested_sdfg, sdfg, {"A_pipe_in"}, {"hist_pipe_out"})
    A_pipes_out = state.add_stream("A_pipes",
                                   dtype,
                                   shape=(P, ),
                                   transient=True,
                                   storage=StorageType.FPGA_Local)
    A_pipes_in = state.add_stream("A_pipes",
                                  dtype,
                                  shape=(P, ),
                                  transient=True,
                                  storage=StorageType.FPGA_Local)
    hist_pipes_out = state.add_stream("hist_pipes",
                                      itype,
                                      shape=(P, ),
                                      transient=True,
                                      storage=StorageType.FPGA_Local)
    unroll_entry, unroll_exit = state.add_map("unroll_compute", {"p": "0:P"},
                                              schedule=dace.ScheduleType.FPGA_Device,
                                              unroll=True)
    state.add_memlet_path(unroll_entry, A_pipes_in, memlet=EmptyMemlet())
    state.add_memlet_path(hist_pipes_out, unroll_exit, memlet=EmptyMemlet())
    state.add_memlet_path(A_pipes_in,
                          tasklet,
                          dst_conn="A_pipe_in",
                          memlet=Memlet.simple(A_pipes_in, "p", num_accesses="W*H"))
    state.add_memlet_path(tasklet,
                          hist_pipes_out,
                          src_conn="hist_pipe_out",
                          memlet=Memlet.simple(hist_pipes_out, "p", num_accesses="num_bins"))

    # Read module
    a_device = state.add_array("A_device", (H, W),
                               dtype,
                               transient=True,
                               storage=dace.dtypes.StorageType.FPGA_Global)
    read_entry, read_exit = state.add_map("read_map", {
        "h": "0:H",
        "w": "0:W:P"
    },
                                          schedule=ScheduleType.FPGA_Device)
    a_val = state.add_array("A_val", (P, ),
                            dtype,
                            transient=True,
                            storage=StorageType.FPGA_Local)
    read_unroll_entry, read_unroll_exit = state.add_map("read_unroll", {"p": "0:P"},
                                                        schedule=ScheduleType.FPGA_Device,
                                                        unroll=True)
    read_tasklet = state.add_tasklet("read", {"A_in"}, {"A_pipe"}, "A_pipe = A_in[p]")
    state.add_memlet_path(a_device,
                          read_entry,
                          a_val,
                          memlet=Memlet(a_val,
                                        num_accesses=1,
                                        subset=Indices(["0"]),
                                        vector_length=P.get(),
                                        other_subset=Indices(["h", "w"])))
    state.add_memlet_path(a_val,
                          read_unroll_entry,
                          read_tasklet,
                          dst_conn="A_in",
                          memlet=Memlet.simple(a_val, "0", veclen=P.get(), num_accesses=1))
    state.add_memlet_path(read_tasklet,
                          read_unroll_exit,
                          read_exit,
                          A_pipes_out,
                          src_conn="A_pipe",
                          memlet=Memlet.simple(A_pipes_out, "p"))

    # Write module
    hist_pipes_in = state.add_stream("hist_pipes",
                                     itype,
                                     shape=(P, ),
                                     transient=True,
                                     storage=StorageType.FPGA_Local)
    hist_device_out = state.add_array("hist_device", (num_bins, ),
                                      itype,
                                      transient=True,
                                      storage=dace.dtypes.StorageType.FPGA_Global)
    merge_entry, merge_exit = state.add_map("merge", {"nb": "0:num_bins"},
                                            schedule=ScheduleType.FPGA_Device)
    merge_reduce = state.add_reduce("lambda a, b: a + b", (0, ),
                                    "0",
                                    schedule=ScheduleType.FPGA_Device)
    state.add_memlet_path(hist_pipes_in,
                          merge_entry,
                          merge_reduce,
                          memlet=Memlet.simple(hist_pipes_in, "0:P", num_accesses=P))
    state.add_memlet_path(merge_reduce,
                          merge_exit,
                          hist_device_out,
                          memlet=dace.memlet.Memlet.simple(hist_device_out, "nb"))

    copy_to_host_state = make_copy_to_host_state(sdfg)

    sdfg.add_edge(copy_to_fpga_state, state, dace.graph.edges.InterstateEdge())
    sdfg.add_edge(state, copy_to_host_state, dace.graph.edges.InterstateEdge())

    return sdfg
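
# Illustrative (untested) driver sketch for the builder above. It assumes the
# module-level symbols and helpers of this sample (P, H, W, num_bins, dtype,
# itype, make_copy_to_fpga_state, ...) have already been configured, e.g. by
# the sample's __main__ section, and that num_bins evaluates to 256.
def _example_histogram_driver():
    import numpy as np
    sdfg = make_sdfg(specialize=True)
    sdfg.specialize(dict(P=P, H=H, W=W))
    A = np.random.rand(H.get(), W.get()).astype(dtype.type)
    hist = np.zeros((256, ), dtype=itype.type)  # assumes num_bins == 256
    sdfg(A=A, hist=hist)
    assert hist.sum() == H.get() * W.get()      # every pixel lands in exactly one bin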
def backward(
        forward_node: Node, context: BackwardContext,
        given_gradients: typing.List[typing.Optional[str]],
        required_gradients: typing.List[typing.Optional[str]]
) -> typing.Tuple[Node, BackwardResult]:

    reduction_type = detect_reduction_type(forward_node.wcr)

    if len(given_gradients) != 1:
        raise AutoDiffException(
            "received invalid SDFG: reduce node {} should have exactly one output edge"
            .format(forward_node))

    if len(required_gradients) != 1:
        raise AutoDiffException(
            "received invalid SDFG: reduce node {} should have exactly one input edge"
            .format(forward_node))

    input_name = next(iter(required_gradients))
    in_desc = in_desc_with_name(forward_node, context.forward_state,
                                context.forward_sdfg, input_name)

    output_name = next(iter(given_gradients))
    out_desc = out_desc_with_name(forward_node, context.forward_state,
                                  context.forward_sdfg, output_name)

    all_axes: typing.List[int] = list(range(len(in_desc.shape)))
    reduce_axes: typing.List[int] = all_axes if forward_node.axes is None else forward_node.axes
    non_reduce_axes: typing.List[int] = [i for i in all_axes if i not in reduce_axes]

    result = BackwardResult.empty()

    if reduction_type is dtypes.ReductionType.Sum:
        # In this case, we simply scatter the gradient across the axes that were reduced.
        sdfg = SDFG("_reverse_" + str(reduction_type).replace(".", "_") + "_")
        state = sdfg.add_state()

        rev_input_conn_name = "input_gradient"
        rev_output_conn_name = "output_gradient"
        result.required_grad_names[output_name] = rev_output_conn_name
        result.given_grad_names[input_name] = rev_input_conn_name

        _, rev_input_arr = sdfg.add_array(rev_input_conn_name,
                                          shape=out_desc.shape,
                                          dtype=out_desc.dtype)
        _, rev_output_arr = sdfg.add_array(rev_output_conn_name,
                                           shape=in_desc.shape,
                                           dtype=in_desc.dtype)

        state.add_mapped_tasklet(
            "_distribute_grad_" + str(reduction_type).replace(".", "_") + "_", {
                "i" + str(i): "0:{}".format(shape)
                for i, shape in enumerate(in_desc.shape)
            }, {
                "__in":
                Memlet.simple(
                    rev_input_conn_name, "0" if forward_node.axes is None else ",".join(
                        "i" + str(i) for i in non_reduce_axes))
            },
            "__out = __in", {
                "__out":
                Memlet.simple(rev_output_conn_name,
                              ",".join("i" + str(i) for i in all_axes),
                              wcr_str="lambda x, y: x + y")
            },
            external_edges=True)

        return context.backward_state.add_nested_sdfg(sdfg, None, {rev_input_conn_name},
                                                      {rev_output_conn_name}), result
    else:
        raise AutoDiffException(
            "Unsupported reduction type '{}'".format(reduction_type))
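
# Reference sketch of the rule implemented above for Sum reductions (illustrative
# only; plain NumPy, independent of the SDFG machinery): the gradient of a sum is
# the given output gradient broadcast back along every reduced axis.
def _example_sum_reduce_backward():
    import numpy as np
    x = np.random.rand(3, 4)
    dy = np.random.rand(4)                # gradient w.r.t. y = x.sum(axis=0)
    dx = np.broadcast_to(dy, x.shape)     # gradient w.r.t. x: dy replicated along axis 0
    assert dx.shape == x.shape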