def forward_out_desc_with_name(forward_node: nd.Node, context: BackwardContext,
                               name) -> dt.Data:
    """ Look up the data descriptor wired to output connector `name` of a forward node.

        The lookup is delegated to ``utils.out_desc_with_name`` using the forward state and
        forward SDFG stored on the backward context.

        :param forward_node: the node in the forward pass.
        :param context: the backward context holding the forward state and SDFG.
        :param name: the output connector name.
        :return: the descriptor of the data connected to output connector `name`.
    """
    fwd_state = context.forward_state
    fwd_sdfg = context.forward_sdfg
    return utils.out_desc_with_name(forward_node, fwd_state, fwd_sdfg, name)
def forward(node: onnx_op.ONNXOp, state: SDFGState,
            sdfg: SDFG) -> typing.Union[Node, SDFG]:
    """ Expand a Reshape node into a program that calls ``np.reshape``.

        The target shape is read from the descriptor of the ``reshaped`` output, so the
        ``shape`` input connector is removed from the node. The node that feeds that
        connector is handed to ``constant_folding.remove_node_and_computation``.

        :param node: the Reshape node to expand.
        :param state: the state containing ``node``.
        :param sdfg: the SDFG containing ``state``.
        :return: the expansion produced by ``program_for_node``.
    """
    target_shape = out_desc_with_name(node, state, sdfg, "reshaped").shape

    node.remove_in_connector("shape")
    # NOTE(review): the connector is removed before the edge lookup below; this presumably
    # relies on in_edge_with_name matching state edges by name, not node connectors — confirm
    shape_edge = in_edge_with_name(node, state, "shape")
    constant_folding.remove_node_and_computation(sdfg, state, shape_edge.src)

    # parameter names must match the node's connector names (see program_for_node)
    def prog(data, reshaped):
        reshaped[:] = np.reshape(data, target_shape)

    return program_for_node(prog, sdfg, state, node)
def forward(node: onnx_op.ONNXOp, state: SDFGState,
            sdfg: SDFG) -> typing.Union[Node, SDFG]:
    """ Expand a Cast node into a program that copies ``input`` into ``output``.

        If both descriptors have the same dtype, a plain slice assignment is emitted.
        Otherwise the values are passed through ``dace.elementwise`` with an identity lambda.

        :param node: the Cast node to expand.
        :param state: the state containing ``node``.
        :param sdfg: the SDFG containing ``state``.
        :return: the expansion produced by ``program_for_node``.
    """
    src_type = in_desc_with_name(node, state, sdfg, "input").dtype
    dst_type = out_desc_with_name(node, state, sdfg, "output").dtype
    dtypes_match = (src_type == dst_type)

    # the parameter names (including `input`, which shadows the builtin) must match the
    # node's connector names, because program_for_node resolves descriptors by name
    if not dtypes_match:
        # NOTE(review): presumably the dtype conversion happens when the identity result is
        # stored into `output` — confirm against dace.elementwise semantics
        def prog(input, output):
            output[:] = dace.elementwise(lambda x: x, input)
    else:
        def prog(input, output):
            output[:] = input

    return program_for_node(prog, sdfg, state, node)
def forward_can_be_applied(node: onnx_op.ONNXOp, state: SDFGState,
                           sdfg: SDFG) -> bool:
    """ Check whether the Cast expansion can handle ``node``.

        A cast between identical dtypes is always accepted. Otherwise the target type
        ``node.to`` must be accepted by ``converters.onnx_tensor_type_to_typeclass``.

        :param node: the Cast node.
        :param state: the state containing ``node``.
        :param sdfg: the SDFG containing ``state``.
        :return: ``False`` if converting ``node.to`` raises ``ValueError``, else ``True``.
    """
    src_type = in_desc_with_name(node, state, sdfg, "input").dtype
    dst_type = out_desc_with_name(node, state, sdfg, "output").dtype
    if src_type == dst_type:
        return True

    try:
        converters.onnx_tensor_type_to_typeclass(node.to)
    except ValueError:
        return False
    return True
def program_for_node(program, sdfg: SDFG, state: SDFGState,
                     node: onnx_op.ONNXOp) -> SDFG:
    """ Expand a function to a dace program.

        The dtypes for the arguments will be extracted by matching the parameter names to edges.
        Every parameter of ``program`` must be named after an input or output connector of
        ``node``. Names containing ``"__"`` are matched against the schema's variadic
        inputs/outputs via ``parse_variadic_param``. Each parameter's annotation is set to the
        data descriptor of the matching connector.

        :param program: the python function to parse; its ``__annotations__`` are overwritten.
        :param sdfg: the SDFG containing ``node``.
        :param state: the state containing ``node``.
        :param node: the ONNX node whose schema and connectors define the parameters.
        :return: the SDFG parsed from ``program``, named ``<node.label>_expansion``. GPU
                 transformations are applied to it when ``node.schedule`` is a GPU schedule.
        :raises ValueError: if a non-variadic name is both an input and an output of the node's
                            schema, or if a parameter of ``program`` matches no connector.
    """
    input_names = node.schema.non_variadic_inputs()
    variadic_input_names = node.schema.variadic_inputs()

    output_names = node.schema.non_variadic_outputs()
    variadic_output_names = node.schema.variadic_outputs()

    # build the overlap as a set once: the schema accessors are not guaranteed to return sets,
    # and next() needs an iterator, not a set
    shared_names = set(input_names).intersection(output_names)
    if shared_names:
        # this is currently the case for only one onnx op
        raise ValueError(
            "program_for_node cannot be applied on nodes of this type;"
            " '{}' is both an input and an output".format(
                next(iter(shared_names))))

    params = inspect.signature(program).parameters

    annotations = {}
    for name in params:
        if name in input_names or ("__" in name and parse_variadic_param(name)[0]
                                   in variadic_input_names):
            annotations[name] = in_desc_with_name(node, state, sdfg, name)
        elif name in output_names or ("__" in name and parse_variadic_param(name)[0]
                                      in variadic_output_names):
            annotations[name] = out_desc_with_name(node, state, sdfg, name)
        else:
            raise ValueError(
                "'{}' was not found as an input or output for {}".format(
                    name, node.schema.name))

    program.__annotations__ = annotations

    result = DaceProgram(program, (), {}, False, dace.DeviceType.CPU)
    result.name = node.label + "_expansion"
    expansion = result.to_sdfg()

    if node.schedule in [dtypes.ScheduleType.GPU_Default
                         ] + dtypes.GPU_SCHEDULES:
        expansion.apply_gpu_transformations()

    return expansion
def forward(node: onnx_op.ONNXOp, state: SDFGState,
            sdfg: SDFG) -> typing.Union[Node, SDFG]:
    """ Expand an Einsum node into a nested SDFG built by ``create_einsum_sdfg``.

        The node is validated first. Each input and output connector then becomes an array
        of the nested SDFG with the descriptor of the data attached to it. The node's
        equation (with spaces removed) is then lowered, with the inputs passed in ONNX
        order and the result written to ``Output``.

        :param node: the Einsum node to expand.
        :param state: the state containing ``node``.
        :param sdfg: the SDFG containing ``state``.
        :return: the nested SDFG implementing the einsum.
    """
    node.validate(sdfg, state)

    expansion = dace.SDFG(node.label + "_expansion")
    body_state = expansion.add_state()

    in_edges = list(node.iter_inputs_in_onnx_order(state))

    for edge in in_edges:
        conn = edge.dst_conn
        expansion.add_datadesc(conn, in_desc_with_name(node, state, sdfg, conn))

    for edge in node.iter_outputs_in_onnx_order(state):
        conn = edge.src_conn
        expansion.add_datadesc(conn, out_desc_with_name(node, state, sdfg, conn))

    equation = node.equation.replace(" ", "")
    operand_names = [edge.dst_conn for edge in in_edges]
    create_einsum_sdfg(None,
                       expansion,
                       body_state,
                       equation,
                       *operand_names,
                       output="Output")
    return expansion
def backward(
    forward_node: Node, context: BackwardContext,
    given_gradients: typing.List[typing.Optional[str]],
    required_gradients: typing.List[typing.Optional[str]]
) -> typing.Tuple[Node, BackwardResult]:
    """ Build the reverse of a ``Reduce`` node.

        Only sum reductions are supported. The gradient of the single output is scattered
        back over every reduced axis by a mapped tasklet inside a new nested SDFG. The
        tasklet accumulates into the input gradient with a ``+`` write-conflict resolution.

        :param forward_node: the forward reduce node.
        :param context: the backward context (forward state/SDFG and backward state).
        :param given_gradients: output connector names of ``forward_node`` whose gradients are
                                available; must contain exactly one entry.
        :param required_gradients: input connector names of ``forward_node`` whose gradients
                                   must be computed; must contain exactly one entry.
        :return: the nested SDFG node added to ``context.backward_state``, and the
                 ``BackwardResult`` mapping forward connectors to its connectors.
        :raises AutoDiffException: if the gradient lists do not have exactly one entry each, or
                                   the reduction is not a sum.
    """
    reduction_type = detect_reduction_type(forward_node.wcr)

    if len(given_gradients) != 1:
        raise AutoDiffException(
            "recieved invalid SDFG: reduce node {} should have exactly one output edge"
            .format(forward_node))

    if len(required_gradients) != 1:
        raise AutoDiffException(
            "recieved invalid SDFG: reduce node {} should have exactly one input edge"
            .format(forward_node))

    input_name = next(iter(required_gradients))
    in_desc = in_desc_with_name(forward_node, context.forward_state,
                                context.forward_sdfg, input_name)

    output_name = next(iter(given_gradients))
    out_desc = out_desc_with_name(forward_node, context.forward_state,
                                  context.forward_sdfg, output_name)

    all_axes: typing.List[int] = list(range(len(in_desc.shape)))
    reduce_axes: typing.List[
        int] = all_axes if forward_node.axes is None else forward_node.axes
    non_reduce_axes: typing.List[int] = [
        i for i in all_axes if i not in reduce_axes
    ]

    if reduction_type is not dtypes.ReductionType.Sum:
        raise AutoDiffException(
            "Unsupported reduction type '{}'".format(reduction_type))

    result = BackwardResult.empty()

    # in this case, we need to simply scatter the grad across the axes that were reduced
    sdfg = SDFG("_reverse_" + str(reduction_type).replace(".", "_") + "_")
    state = sdfg.add_state()

    # in-connector of the reverse node: receives the gradient of the forward output
    out_grad_conn = "output_gradient"
    # out-connector of the reverse node: produces the gradient of the forward input
    in_grad_conn = "input_gradient"
    result.required_grad_names[output_name] = out_grad_conn
    result.given_grad_names[input_name] = in_grad_conn

    # the array shapes must agree with the roles above: the incoming gradient has the
    # forward output's shape, the produced gradient has the forward input's shape
    sdfg.add_array(out_grad_conn, shape=out_desc.shape, dtype=out_desc.dtype)
    sdfg.add_array(in_grad_conn, shape=in_desc.shape, dtype=in_desc.dtype)

    if non_reduce_axes:
        grad_read_subset = ",".join("i" + str(i) for i in non_reduce_axes)
    else:
        # every axis was reduced (axes is None, or axes lists all of them); joining an empty
        # list would give an empty subset, so read element 0 like the axes-is-None case
        # (presumably the output is a 1-element array — confirm)
        grad_read_subset = "0"
    grad_write_subset = ",".join("i" + str(i) for i in all_axes)

    state.add_mapped_tasklet(
        "_distribute_grad_" + str(reduction_type).replace(".", "_") + "_", {
            "i" + str(i): "0:{}".format(shape)
            for i, shape in enumerate(in_desc.shape)
        }, {"__in": Memlet.simple(out_grad_conn, grad_read_subset)},
        "__out = __in", {
            "__out":
            Memlet.simple(in_grad_conn,
                          grad_write_subset,
                          wcr_str="lambda x, y: x + y")
        },
        external_edges=True)

    return context.backward_state.add_nested_sdfg(sdfg, None,
                                                  {out_grad_conn},
                                                  {in_grad_conn}), result
def forward(node: onnx_op.ONNXOp, state: SDFGState,
            sdfg: SDFG) -> typing.Union[Node, SDFG]:
    """ Expand a MatMul node into a nested SDFG holding a single ONNXEinsum node.

        The einsum equation is built from the ranks of ``A`` and ``B``:

        * ``A`` rank 1 and ``B`` rank 2: ``k,kj->j``
        * ``A`` rank 2 and ``B`` rank 1: ``ik,k->i``
        * otherwise the last two dimensions form ``ik,kj->ij``. Each leading (batch)
          dimension, aligned from the right, gets its own letter in the inputs that have
          it and in the result.

        In the batched case, the contracted dimensions (last of ``A``, second-to-last of ``B``)
        may differ while one of them is symbolic. The symbol is then overwritten with the
        other input's value in the descriptor returned by ``in_desc_with_name``, and a
        warning is logged.

        :param node: the MatMul node to expand.
        :param state: the state containing ``node``.
        :param sdfg: the SDFG containing ``state``.
        :return: the nested SDFG with non-transient arrays ``A``, ``B`` and ``Y``.
        :raises ValueError: if one input is 1-D and the other is not 2-D.
    """
    node.validate(sdfg, state)

    A_desc = in_desc_with_name(node, state, sdfg, "A")
    B_desc = in_desc_with_name(node, state, sdfg, "B")
    Y_desc = out_desc_with_name(node, state, sdfg, "Y")
    input0_dim = A_desc.shape
    input1_dim = B_desc.shape

    # letters ordered z..a, so pop() hands out 'a', 'b', ... for the batch dimensions;
    # i, j and k are reserved for the last (matrix) dimensions
    letters = [chr(ord('z') - i) for i in range(26)]
    letters = [c for c in letters if c not in ['i', 'j', 'k']]

    if len(input0_dim) == 1:
        if len(input1_dim) != 2:
            raise ValueError("invalid dimensions")
        arg1 = 'k'
        arg2 = 'kj'
        result = 'j'
    elif len(input1_dim) == 1:
        if len(input0_dim) != 2:
            raise ValueError("invalid dimensions")
        arg1 = 'ik'
        arg2 = 'k'
        result = 'i'
    else:
        # build the einsum. The last two dimensions are always just the matrix multiply einsum
        # dace will later specialize to a batched matmul if possible
        arg1 = 'ik'
        arg2 = 'kj'
        result = 'ij'

        # the contracted dimension 'k' is the last dim of A and the second-to-last dim of B
        if input0_dim[-1] != input1_dim[-2]:
            if dace.symbolic.issymbolic(input0_dim[-1]):
                log.warning(
                    f"overriding symbol {input0_dim[-1]} with value {input1_dim[-2]} in descriptor of input A of node {node}"
                )
                new_shape = list(A_desc.shape)
                new_shape[-1] = input1_dim[-2]
                A_desc.shape = new_shape
            elif dace.symbolic.issymbolic(input1_dim[-2]):
                log.warning(
                    f"overriding symbol {input1_dim[-2]} with value {input0_dim[-1]} in descriptor of input B of node {node}"
                )
                new_shape = list(B_desc.shape)
                new_shape[-2] = input0_dim[-1]
                B_desc.shape = new_shape

        # align the batch dimensions from the right; a dimension present in only one input
        # appears only in that input's subscript (and always in the result)
        # NOTE(review): a size-1 dim against a larger one shares a letter here, so numpy-style
        # broadcasting of batch dims is presumably not supported — confirm
        for dim0, dim1 in itertools.zip_longest(reversed(input0_dim[:-2]),
                                                reversed(input1_dim[:-2])):
            letter = letters.pop()
            if dim0 is not None:
                arg1 = letter + arg1
            if dim1 is not None:
                arg2 = letter + arg2
            result = letter + result

    einsum_str = '{},{}->{}'.format(arg1, arg2, result)

    # we lower to an ONNXEinsum node instead straight to the dace einsum to make the autodiff simpler
    nsdfg = dace.SDFG(node.label + "_expansion")
    nstate = nsdfg.add_state()
    einsum_node: nodes.LibraryNode = onnx_op.ONNXEinsum(
        node.label + "_einsum_expansion", equation=einsum_str)

    nstate.add_node(einsum_node)
    einsum_node.add_in_connector("Inputs__0")
    einsum_node.add_in_connector("Inputs__1")
    nsdfg.add_datadesc("A", copy.deepcopy(A_desc))
    nsdfg.add_datadesc("B", copy.deepcopy(B_desc))
    nsdfg.add_datadesc("Y", copy.deepcopy(Y_desc))
    nsdfg.arrays["A"].transient = False
    nsdfg.arrays["B"].transient = False
    nsdfg.arrays["Y"].transient = False

    nstate.add_edge(nstate.add_read("A"), None, einsum_node, "Inputs__0",
                    nsdfg.make_array_memlet("A"))
    nstate.add_edge(nstate.add_read("B"), None, einsum_node, "Inputs__1",
                    nsdfg.make_array_memlet("B"))
    nstate.add_edge(einsum_node, "Output", nstate.add_write("Y"), None,
                    nsdfg.make_array_memlet("Y"))

    return nsdfg