# These functions are written as the forward/backward of torch.autograd.Function
# subclasses wrapping hand-written CUDA kernels. They rely on helpers defined in
# the surrounding module: Dtype, Stream, load_kernel, CUDA_NUM_THREADS, GET_BLOCKS,
# and the CUDA kernel source strings passed to load_kernel.
import torch
from torch.nn.modules.utils import _pair

def forward(ctx, input, weight, kernel_size, stride, padding, dilation):
     kernel_size, stride, padding, dilation = _pair(kernel_size), _pair(stride), _pair(padding), _pair(dilation)
     ctx.kernel_size, ctx.stride, ctx.padding, ctx.dilation = kernel_size, stride, padding, dilation
     assert input.dim() == 4 and input.is_cuda and weight.is_cuda
     batch_size, input_channels, input_height, input_width = input.size()
     _, weight_channels, weight_height, weight_width = weight.size()
     output_height = int((input_height + 2 * padding[0] - (dilation[0] * (kernel_size[0] - 1) + 1)) / stride[0] + 1)
     output_width = int((input_width + 2 * padding[1] - (dilation[1] * (kernel_size[1] - 1) + 1)) / stride[1] + 1)
     assert output_height * output_width == weight_width
     output = input.new(batch_size, input_channels, output_height, output_width)
     n = output.numel()
     with torch.cuda.device_of(input):
         f = load_kernel('aggregation_zeropad_forward_kernel', _aggregation_zeropad_forward_kernel, Dtype=Dtype(input), nthreads=n,
                         num=batch_size, input_channels=input_channels, weight_channels=weight_channels,
                         bottom_height=input_height, bottom_width=input_width,
                         top_height=output_height, top_width=output_width,
                         kernel_h=kernel_size[0], kernel_w=kernel_size[1],
                         stride_h=stride[0], stride_w=stride[1],
                         dilation_h=dilation[0], dilation_w=dilation[1],
                         pad_h=padding[0], pad_w=padding[1])
         f(block=(CUDA_NUM_THREADS, 1, 1),
           grid=(GET_BLOCKS(n), 1, 1),
           args=[input.data_ptr(), weight.data_ptr(), output.data_ptr()],
           stream=Stream(ptr=torch.cuda.current_stream().cuda_stream))
     ctx.save_for_backward(input, weight)
     return output
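
The forward only asserts that weight's last dimension equals output_height * output_width; weight_height is presumably kernel_h * kernel_w, i.e. one value per kernel position and output location. A minimal shape sketch, assuming the function above is exposed on a torch.autograd.Function subclass hypothetically named AggregationZeropad:

import torch

batch, channels, height, width = 2, 8, 9, 9
kernel_size, stride, padding, dilation = 3, 1, 1, 1
# same output-size arithmetic as in the forward above
out_h = (height + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1
out_w = (width + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1

x = torch.randn(batch, channels, height, width, device='cuda')
# assumed weight layout: (batch, weight_channels, kernel_h * kernel_w, out_h * out_w)
w = torch.randn(batch, channels, kernel_size * kernel_size, out_h * out_w, device='cuda')
y = AggregationZeropad.apply(x, w, kernel_size, stride, padding, dilation)
assert y.shape == (batch, channels, out_h, out_w)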
Example #2
# Backward for the subtraction2_zeropad operation (it loads the
# subtraction2_zeropad_*_backward kernels below).
def backward(ctx, grad_output):
     kernel_size, stride, padding, dilation = ctx.kernel_size, ctx.stride, ctx.padding, ctx.dilation
     input1, input2 = ctx.saved_tensors
     assert grad_output.is_cuda
     if not grad_output.is_contiguous():
         grad_output = grad_output.contiguous()
     batch_size, input_channels, input_height, input_width = input1.size()
     output_height = int((input_height + 2 * padding[0] - (dilation[0] * (kernel_size[0] - 1) + 1)) / stride[0] + 1)
     output_width = int((input_width + 2 * padding[1] - (dilation[1] * (kernel_size[1] - 1) + 1)) / stride[1] + 1)
     grad_input1, grad_input2 = None, None
     opt = dict(Dtype=Dtype(grad_output),
                num=batch_size, input_channels=input_channels,
                bottom_height=input_height, bottom_width=input_width,
                top_height=output_height, top_width=output_width,
                kernel_h=kernel_size[0], kernel_w=kernel_size[1],
                stride_h=stride[0], stride_w=stride[1],
                dilation_h=dilation[0], dilation_w=dilation[1],
                pad_h=padding[0], pad_w=padding[1])
     with torch.cuda.device_of(input1):
         if ctx.needs_input_grad[0]:
             grad_input1 = input1.new(input1.size())
             n = grad_input1.numel()
             opt['nthreads'] = n
             f = load_kernel('subtraction2_zeropad_input1_backward_kernel', _subtraction2_zeropad_input1_backward_kernel, **opt)
             f(block=(CUDA_NUM_THREADS, 1, 1),
               grid=(GET_BLOCKS(n), 1, 1),
               args=[grad_output.data_ptr(), grad_input1.data_ptr()],
               stream=Stream(ptr=torch.cuda.current_stream().cuda_stream))
     with torch.cuda.device_of(input2):
         if ctx.needs_input_grad[1]:
             grad_input2 = input2.new(input2.size())
             n = grad_input2.numel()
             opt['nthreads'] = n
             f = load_kernel('subtraction2_zeropad_input2_backward_kernel', _subtraction2_zeropad_input2_backward_kernel, **opt)
             f(block=(CUDA_NUM_THREADS, 1, 1),
               grid=(GET_BLOCKS(n), 1, 1),
               args=[grad_output.data_ptr(), grad_input2.data_ptr()],
               stream=Stream(ptr=torch.cuda.current_stream().cuda_stream))
     return grad_input1, grad_input2, None, None, None, None
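
One hedged way to sanity-check a hand-written backward like this is torch.autograd.gradcheck on small double-precision inputs; Subtraction2Zeropad is an assumed name for the Function subclass this backward belongs to, and the check only works if the CUDA kernels are compiled for the input dtype (see the Dtype(...) option above).

import torch

# tiny inputs keep the numerical Jacobian cheap; kernel_size=3, stride=1, padding=1, dilation=1
x1 = torch.randn(1, 4, 5, 5, device='cuda', dtype=torch.float64, requires_grad=True)
x2 = torch.randn(1, 4, 5, 5, device='cuda', dtype=torch.float64, requires_grad=True)
ok = torch.autograd.gradcheck(
    lambda a, b: Subtraction2Zeropad.apply(a, b, 3, 1, 1, 1),
    (x1, x2), eps=1e-4, atol=1e-3)
print(ok)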

# Backward matching the aggregation_zeropad forward above.
def backward(ctx, grad_output):
     kernel_size, stride, padding, dilation = ctx.kernel_size, ctx.stride, ctx.padding, ctx.dilation
     input, weight = ctx.saved_tensors
     assert grad_output.is_cuda
     if not grad_output.is_contiguous():
         grad_output = grad_output.contiguous()
     batch_size, input_channels, input_height, input_width = input.size()
     _, weight_channels, weight_height, weight_width = weight.size()
     output_height, output_width = grad_output.size()[2:]
     grad_input, grad_weight = None, None
     opt = dict(Dtype=Dtype(grad_output),
                num=batch_size, input_channels=input_channels, weight_channels=weight_channels,
                bottom_height=input_height, bottom_width=input_width,
                top_height=output_height, top_width=output_width,
                kernel_h=kernel_size[0], kernel_w=kernel_size[1],
                stride_h=stride[0], stride_w=stride[1],
                dilation_h=dilation[0], dilation_w=dilation[1],
                pad_h=padding[0], pad_w=padding[1])
     with torch.cuda.device_of(input):
         if ctx.needs_input_grad[0]:
             grad_input = input.new(input.size())
             n = grad_input.numel()
             opt['nthreads'] = n
             f = load_kernel('aggregation_zeropad_input_backward_kernel', _aggregation_zeropad_input_backward_kernel, **opt)
             f(block=(CUDA_NUM_THREADS, 1, 1),
               grid=(GET_BLOCKS(n), 1, 1),
               args=[grad_output.data_ptr(), weight.data_ptr(), grad_input.data_ptr()],
               stream=Stream(ptr=torch.cuda.current_stream().cuda_stream))
         if ctx.needs_input_grad[1]:
             grad_weight = weight.new(weight.size())
             n = grad_weight.numel() // weight.shape[2]
             opt['nthreads'] = n
             f = load_kernel('aggregation_zeropad_weight_backward_kernel', _aggregation_zeropad_weight_backward_kernel, **opt)
             f(block=(CUDA_NUM_THREADS, 1, 1),
               grid=(GET_BLOCKS(n), 1, 1),
               args=[grad_output.data_ptr(), input.data_ptr(), grad_weight.data_ptr()],
               stream=Stream(ptr=torch.cuda.current_stream().cuda_stream))
     return grad_input, grad_weight, None, None, None, None
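
grad_input is launched with one thread per input element, while grad_weight uses numel() // weight.shape[2] threads, presumably so that each thread accumulates all kernel_h * kernel_w weight gradients for one output location. An end-to-end sketch under the same assumptions as above (hypothetical AggregationZeropad wrapper):

import torch

x = torch.randn(2, 8, 9, 9, device='cuda', requires_grad=True)
w = torch.randn(2, 8, 9, 81, device='cuda', requires_grad=True)  # 9 = 3*3 kernel positions, 81 = 9*9 output locations
y = AggregationZeropad.apply(x, w, 3, 1, 1, 1)                   # output is (2, 8, 9, 9)
y.sum().backward()                                               # drives the backward defined above
assert x.grad.shape == x.shape and w.grad.shape == w.shape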