def backward(ctx, grad_output):
    """Compute deformable-convolution gradients through the CUDA extension.

    Unpacks ``(input, offset, weight, bias)`` from ``ctx.saved_tensors``.
    It then allocates a zero-filled gradient buffer shaped like each of them
    and hands the buffers to ``MDCONV_CUDA.deform_conv2d_backward_cuda``,
    together with the hyper-parameters the forward pass stored on ``ctx``.

    Args:
        ctx: autograd context. It must provide ``saved_tensors``, the
            ``stride``/``padding``/``dilation`` pairs, ``groups``,
            ``deformable_groups``, ``in_step`` and ``with_bias``.
        grad_output: gradient of the loss w.r.t. the forward output.

    Returns:
        A 10-tuple ``(grad_input, grad_offset, grad_weight, grad_bias,
        None, None, None, None, None, None)``. ``grad_bias`` is ``None``
        when ``ctx.with_bias`` is falsy.

    Raises:
        NotImplementedError: if ``grad_output`` is not a CUDA tensor.
    """
    # NOTE(review): no mask is unpacked and the non-modulated kernel is
    # called, so this looks like the backward of the plain (non-modulated)
    # deform-conv Function. The six trailing Nones presumably cover that
    # forward's non-tensor arguments; confirm against its signature.
    grad_out = grad_output.contiguous()
    if not grad_out.is_cuda:
        raise NotImplementedError

    inp, off, wgt, b = ctx.saved_tensors
    # Zero-initialised buffers passed to the extension. The extension is
    # expected to fill them in place; they are returned unchanged afterwards.
    d_inp, d_off, d_wgt, d_b = (torch.zeros_like(t) for t in (inp, off, wgt, b))

    kernel_h, kernel_w = wgt.shape[2], wgt.shape[3]
    # C++ signature, for argument-order reference:
    #   int deform_conv2d_backward_cuda(
    #       input, weight, bias, offset,
    #       grad_input, grad_weight, grad_bias, grad_offset, grad_output,
    #       kernel_h, kernel_w, stride_h, stride_w, pad_h, pad_w,
    #       dilation_h, dilation_w, group, deformable_group,
    #       in_step, with_bias)
    MDCONV_CUDA.deform_conv2d_backward_cuda(
        inp, wgt, b, off,
        d_inp, d_wgt, d_b, d_off,
        grad_out,
        kernel_h, kernel_w,
        ctx.stride[0], ctx.stride[1],
        ctx.padding[0], ctx.padding[1],
        ctx.dilation[0], ctx.dilation[1],
        ctx.groups, ctx.deformable_groups,
        ctx.in_step, ctx.with_bias,
    )

    # The bias buffer was a placeholder when the forward pass had no bias.
    grad_bias = d_b if ctx.with_bias else None
    return (d_inp, d_off, d_wgt, grad_bias) + (None,) * 6
def forward(ctx, input, offset, mask, weight, bias=None, stride=1, padding=0,
            dilation=1, groups=1, deformable_groups=1, in_step=64):
    """Run the modulated deformable convolution forward pass on CUDA.

    Stores the convolution hyper-parameters on ``ctx``. It also saves the
    tensor inputs for the backward pass whenever any of them requires a
    gradient.

    Args:
        ctx: autograd context.
        input, offset, mask, weight: tensors passed to
            ``MDCONV_CUDA.modulated_deform_conv2d_forward_cuda``.
        bias: optional bias tensor. When ``None``, an empty placeholder
            tensor is passed to the extension and ``ctx.with_bias`` is
            ``False``.
        stride, padding, dilation: int or pair, normalised with ``_pair``.
        groups, deformable_groups, in_step: forwarded unchanged to the
            extension.

    Returns:
        A tensor allocated with ``input.new_empty`` in the shape returned by
        ``ModulatedDeformConv2dFunction._infer_shape``. The extension call
        presumably writes the result into it.

    Raises:
        NotImplementedError: if ``input`` is not a CUDA tensor.
    """
    ctx.stride = _pair(stride)
    ctx.padding = _pair(padding)
    ctx.dilation = _pair(dilation)
    ctx.groups = groups
    ctx.deformable_groups = deformable_groups
    ctx.in_step = in_step
    ctx.with_bias = bias is not None
    if not ctx.with_bias:
        # The extension always takes a bias tensor; pass an empty placeholder.
        bias = input.new_empty(0)

    if not input.is_cuda:
        raise NotImplementedError(
            "modulated_deform_conv2d is only implemented for CUDA tensors")

    # Save whenever *any* differentiable input needs a gradient, bias
    # included. The previous check ignored bias.requires_grad. With only the
    # bias trainable, nothing was saved, so a backward pass that unpacks
    # ctx.saved_tensors would fail.
    saved = (input, offset, mask, weight, bias)
    if any(t.requires_grad for t in saved):
        ctx.save_for_backward(*saved)

    output = input.new_empty(
        ModulatedDeformConv2dFunction._infer_shape(ctx, input, weight))
    # C++ signature, for argument-order reference:
    #   int modulated_deform_conv2d_forward_cuda(
    #       input, weight, bias, offset, mask, output,
    #       kernel_h, kernel_w, stride_h, stride_w, pad_h, pad_w,
    #       dilation_h, dilation_w, group, deformable_group,
    #       in_step, with_bias)
    MDCONV_CUDA.modulated_deform_conv2d_forward_cuda(
        input, weight, bias, offset, mask, output,
        weight.shape[2], weight.shape[3],
        ctx.stride[0], ctx.stride[1],
        ctx.padding[0], ctx.padding[1],
        ctx.dilation[0], ctx.dilation[1],
        ctx.groups, ctx.deformable_groups,
        ctx.in_step, ctx.with_bias,
    )
    return output