def backward(ctx, grad_output):
    kernel_size, stride, padding, dilation = ctx.kernel_size, ctx.stride, ctx.padding, ctx.dilation
    input, weight = ctx.saved_tensors
    assert grad_output.is_cuda
    if not grad_output.is_contiguous():
        grad_output = grad_output.contiguous()
    batch_size, input_channels, input_height, input_width = input.size()
    _, weight_channels, weight_height, weight_width = weight.size()
    output_height, output_width = grad_output.size()[2:]
    grad_input, grad_weight = None, None
    opt = dict(Dtype=Dtype(grad_output),
               num=batch_size, input_channels=input_channels, weight_channels=weight_channels,
               bottom_height=input_height, bottom_width=input_width,
               top_height=output_height, top_width=output_width,
               kernel_h=kernel_size[0], kernel_w=kernel_size[1],
               stride_h=stride[0], stride_w=stride[1],
               dilation_h=dilation[0], dilation_w=dilation[1],
               pad_h=padding[0], pad_w=padding[1])
    with torch.cuda.device_of(input):
        if ctx.needs_input_grad[0]:
            # Gradient w.r.t. the input: one CUDA thread per element of grad_input.
            grad_input = input.new(input.size())
            n = grad_input.numel()
            opt['nthreads'] = n
            f = load_kernel('aggregation_zeropad_input_backward_kernel',
                            _aggregation_zeropad_input_backward_kernel, **opt)
            f(block=(CUDA_NUM_THREADS, 1, 1), grid=(GET_BLOCKS(n), 1, 1),
              args=[grad_output.data_ptr(), weight.data_ptr(), grad_input.data_ptr()],
              stream=Stream(ptr=torch.cuda.current_stream().cuda_stream))
        if ctx.needs_input_grad[1]:
            # Gradient w.r.t. the weight: one thread per (batch, channel, output position);
            # the kernel-position dimension (weight.shape[2]) is covered inside the kernel.
            grad_weight = weight.new(weight.size())
            n = grad_weight.numel() // weight.shape[2]
            opt['nthreads'] = n
            f = load_kernel('aggregation_zeropad_weight_backward_kernel',
                            _aggregation_zeropad_weight_backward_kernel, **opt)
            f(block=(CUDA_NUM_THREADS, 1, 1), grid=(GET_BLOCKS(n), 1, 1),
              args=[grad_output.data_ptr(), input.data_ptr(), grad_weight.data_ptr()],
              stream=Stream(ptr=torch.cuda.current_stream().cuda_stream))
    return grad_input, grad_weight, None, None, None, None
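
# A hedged sketch of how the two launches above could be verified numerically, assuming
# this backward belongs to a torch.autograd.Function whose forward matches it (the class
# is called AggregationZeropad here purely for illustration). torch.autograd.gradcheck
# compares grad_input / grad_weight against finite differences; it needs double precision,
# a CUDA device, and the hypothetical shapes below.
def _gradcheck_aggregation_zeropad_example():
    kernel_size, stride, padding, dilation = 3, 1, 1, 1
    n, c_x, c_w, h, w = 1, 4, 2, 5, 5
    x = torch.randn(n, c_x, h, w, dtype=torch.double, device='cuda', requires_grad=True)
    # weight_width must equal output_height * output_width (25 here, since stride 1 and
    # padding 1 keep the 5x5 spatial size).
    wt = torch.randn(n, c_w, kernel_size * kernel_size, h * w,
                     dtype=torch.double, device='cuda', requires_grad=True)
    return torch.autograd.gradcheck(
        lambda a, b: AggregationZeropad.apply(a, b, kernel_size, stride, padding, dilation),
        (x, wt), atol=1e-4)
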
def forward(ctx, input, weight, kernel_size, stride, padding, dilation):
    kernel_size, stride, padding, dilation = _pair(kernel_size), _pair(stride), _pair(padding), _pair(dilation)
    ctx.kernel_size, ctx.stride, ctx.padding, ctx.dilation = kernel_size, stride, padding, dilation
    assert input.dim() == 4 and input.is_cuda and weight.is_cuda
    batch_size, input_channels, input_height, input_width = input.size()
    _, weight_channels, weight_height, weight_width = weight.size()
    output_height = int((input_height + 2 * padding[0] - (dilation[0] * (kernel_size[0] - 1) + 1)) / stride[0] + 1)
    output_width = int((input_width + 2 * padding[1] - (dilation[1] * (kernel_size[1] - 1) + 1)) / stride[1] + 1)
    # The weight must supply one aggregation vector per output position.
    assert output_height * output_width == weight_width
    # Allocate the output tensor.
    output = input.new(batch_size, input_channels, output_height, output_width)
    n = output.numel()
    with torch.cuda.device_of(input):
        f = load_kernel('aggregation_refpad_forward_kernel', _aggregation_refpad_forward_kernel,
                        Dtype=Dtype(input), nthreads=n,
                        num=batch_size, input_channels=input_channels, weight_channels=weight_channels,
                        bottom_height=input_height, bottom_width=input_width,
                        top_height=output_height, top_width=output_width,
                        kernel_h=kernel_size[0], kernel_w=kernel_size[1],
                        stride_h=stride[0], stride_w=stride[1],
                        dilation_h=dilation[0], dilation_w=dilation[1],
                        pad_h=padding[0], pad_w=padding[1])
        f(block=(CUDA_NUM_THREADS, 1, 1), grid=(GET_BLOCKS(n), 1, 1),
          args=[input.data_ptr(), weight.data_ptr(), output.data_ptr()],
          stream=Stream(ptr=torch.cuda.current_stream().cuda_stream))
    ctx.save_for_backward(input, weight)
    return output
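
# A small illustration of the output-size arithmetic used above (the standard convolution
# output formula) and of the assert tying weight_width to the number of output positions.
# The parameters are hypothetical: a 7x7 kernel, stride 1, padding 3, dilation 1 on a 32x32 map.
def _output_size_example():
    def out_size(size, kernel, stride, padding, dilation):
        return (size + 2 * padding - (dilation * (kernel - 1) + 1)) // stride + 1

    out_h = out_size(32, 7, 1, 3, 1)  # 32
    out_w = out_size(32, 7, 1, 3, 1)  # 32
    # The forward above would therefore require weight_width == 32 * 32 == 1024.
    return out_h * out_w
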
def backward(ctx, grad_output):
    kernel_size, stride, padding, dilation = ctx.kernel_size, ctx.stride, ctx.padding, ctx.dilation
    input, = ctx.saved_tensors
    assert grad_output.is_cuda
    if not grad_output.is_contiguous():
        grad_output = grad_output.contiguous()
    batch_size, input_channels, input_height, input_width = input.size()
    output_height = int((input_height + 2 * padding[0] - (dilation[0] * (kernel_size[0] - 1) + 1)) / stride[0] + 1)
    output_width = int((input_width + 2 * padding[1] - (dilation[1] * (kernel_size[1] - 1) + 1)) / stride[1] + 1)
    grad_input = None
    opt = dict(Dtype=Dtype(grad_output),
               num=batch_size, input_channels=input_channels,
               bottom_height=input_height, bottom_width=input_width,
               top_height=output_height, top_width=output_width,
               kernel_h=kernel_size[0], kernel_w=kernel_size[1],
               stride_h=stride[0], stride_w=stride[1],
               dilation_h=dilation[0], dilation_w=dilation[1],
               pad_h=padding[0], pad_w=padding[1])
    with torch.cuda.device_of(input):
        if ctx.needs_input_grad[0]:
            # The kernel writes gradients w.r.t. the reflection-padded input, so allocate a
            # (input_height + 2 * pad_h, input_width + 2 * pad_w) buffer first.
            grad_input = input.new(batch_size, input_channels,
                                   input_height + 2 * padding[0], input_width + 2 * padding[1])
            n = grad_input.numel()
            opt['nthreads'] = n
            f = load_kernel('subtraction_refpad_input_backward_kernel',
                            _subtraction_refpad_input_backward_kernel, **opt)
            f(block=(CUDA_NUM_THREADS, 1, 1), grid=(GET_BLOCKS(n), 1, 1),
              args=[grad_output.data_ptr(), grad_input.data_ptr()],
              stream=Stream(ptr=torch.cuda.current_stream().cuda_stream))
            # Fold the gradients that landed in the reflected border back onto the interior
            # rows/columns they mirror, then crop away the padding.
            grad_input[:, :, padding[0] + 1:2 * padding[0] + 1, :] += torch.flip(grad_input[:, :, :padding[0], :], dims=[2])
            grad_input[:, :, input_height - 1:input_height + padding[0] - 1, :] += torch.flip(grad_input[:, :, input_height + padding[0]:, :], dims=[2])
            grad_input[:, :, :, padding[1] + 1:2 * padding[1] + 1] += torch.flip(grad_input[:, :, :, :padding[1]], dims=[3])
            grad_input[:, :, :, input_width - 1:input_width + padding[1] - 1] += torch.flip(grad_input[:, :, :, input_width + padding[1]:], dims=[3])
            grad_input = grad_input[:, :, padding[0]:padding[0] + input_height, padding[1]:padding[1] + input_width]
    return grad_input, None, None, None, None
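
# The flip-and-add lines above reproduce by hand what autograd does for reflection padding:
# gradient that falls on the mirrored border is folded back onto the interior rows/columns
# it reflects. A minimal CPU illustration with nn.ReflectionPad2d (hypothetical 5x5 size):
def _reflection_grad_foldback_example():
    x = torch.zeros(1, 1, 5, 5, requires_grad=True)
    torch.nn.ReflectionPad2d(1)(x).sum().backward()
    # Positions whose row or column lies one step inside the border also receive the
    # reflected copy's gradient, so x.grad is 2 there (4 where both coordinates are one
    # step inside) and 1 everywhere else.
    return x.grad[0, 0]
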
def forward(ctx, input1, input2, kernel_size, stride, padding, dilation):
    kernel_size, stride, padding, dilation = _pair(kernel_size), _pair(stride), _pair(padding), _pair(dilation)
    ctx.kernel_size, ctx.stride, ctx.padding, ctx.dilation = kernel_size, stride, padding, dilation
    assert input1.dim() == 4 and input1.is_cuda
    batch_size, input_channels, input_height, input_width = input1.size()
    output_height = int((input_height + 2 * padding[0] - (dilation[0] * (kernel_size[0] - 1) + 1)) / stride[0] + 1)
    output_width = int((input_width + 2 * padding[1] - (dilation[1] * (kernel_size[1] - 1) + 1)) / stride[1] + 1)
    output = input1.new(batch_size, input_channels, kernel_size[0] * kernel_size[1], output_height * output_width)
    n = output.numel() // output.shape[2]
    with torch.cuda.device_of(input1):
        f = load_kernel('subtraction2_zeropad_forward_kernel', _subtraction2_zeropad_forward_kernel,
                        Dtype=Dtype(input1), nthreads=n,
                        num=batch_size, input_channels=input_channels,
                        bottom_height=input_height, bottom_width=input_width,
                        top_height=output_height, top_width=output_width,
                        kernel_h=kernel_size[0], kernel_w=kernel_size[1],
                        stride_h=stride[0], stride_w=stride[1],
                        dilation_h=dilation[0], dilation_w=dilation[1],
                        pad_h=padding[0], pad_w=padding[1])
        f(block=(CUDA_NUM_THREADS, 1, 1), grid=(GET_BLOCKS(n), 1, 1),
          args=[input1.data_ptr(), input2.data_ptr(), output.data_ptr()],
          stream=Stream(ptr=torch.cuda.current_stream().cuda_stream))
    ctx.save_for_backward(input1, input2)
    return output
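
# A hedged, CPU-only sketch of the output layout produced above, built with F.unfold.
# Shapes are hypothetical; stride 1 and zero padding of kernel_size // 2 keep the output
# spatial size equal to the input's. The presumed semantics (suggested by the shapes and
# the kernel name): each output column holds input1 at an output position minus the k*k
# zero-padded neighborhood of input2 around that position.
def _subtraction2_zeropad_shape_example():
    import torch.nn.functional as F
    n, c, h, w, k = 2, 8, 16, 16, 3
    x1, x2 = torch.randn(n, c, h, w), torch.randn(n, c, h, w)
    neighborhood = F.unfold(x2, kernel_size=k, padding=k // 2).view(n, c, k * k, h * w)
    reference = x1.view(n, c, 1, h * w) - neighborhood
    return reference.shape  # (2, 8, 9, 256) == (batch, channels, kernel_h * kernel_w, out_h * out_w)
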