def check_gradient():
    """Numerically check the gradients of ``DCNv2Function`` and print the result.

    Builds random CUDA tensors for the feature map, offsets, modulation mask,
    weight and bias, and passes them to ``gradcheck``. The printed value is what
    ``gradcheck`` returns.

    Relies on module-level names defined elsewhere in this file (presumably the
    test configuration): ``N``, ``inC``, ``inH``, ``inW``, ``outC``, ``kH``,
    ``kW``, ``deformable_groups``, as well as ``DCNv2Function`` and
    ``gradcheck``. Requires a CUDA device.
    """
    # Tensors are drawn in the same order as before, so a fixed RNG seed still
    # yields identical inputs.
    feat = torch.randn(N, inC, inH, inW).cuda().requires_grad_()
    # 2 * kH * kW offset channels per deformable group.
    offset = torch.randn(N, deformable_groups * 2 * kW * kH, inH, inW).cuda().requires_grad_()
    # The raw mask is the grad-requiring leaf; gradcheck receives its sigmoid
    # (a non-leaf tensor), which keeps mask values in (0, 1).
    mask = torch.sigmoid(
        torch.rand(N, deformable_groups * kW * kH, inH, inW).cuda().requires_grad_()
    )
    weight = torch.randn(outC, inC, kH, kW).cuda().requires_grad_()
    bias = torch.rand(outC).cuda().requires_grad_()

    dcn = DCNv2Function(stride=1, padding=1, dilation=1,
                        deformable_groups=deformable_groups)
    # Loose eps/tolerances: the tensors use the default dtype (float32 unless
    # changed), for which finite differences are noisy.
    passed = gradcheck(dcn, (feat, offset, mask, weight, bias),
                       eps=1e-3, atol=1e-3, rtol=1e-2)
    print(passed)
def forward(self, input):
    """Apply modulated deformable convolution with offsets/mask predicted from ``input``.

    ``self.conv_offset_mask(input)`` is split into three chunks along dim 1.
    The first two chunks are concatenated to form the offset tensor. The third
    chunk goes through a sigmoid to form the modulation mask. Both are then fed
    to ``DCNv2Function`` together with ``self.weight`` and ``self.bias``.
    """
    pred = self.conv_offset_mask(input)
    off_a, off_b, raw_mask = pred.chunk(3, dim=1)
    offset = torch.cat([off_a, off_b], dim=1)
    modulation = raw_mask.sigmoid()
    dcn = DCNv2Function(self.stride, self.padding, self.dilation, self.deformable_groups)
    return dcn(input, offset, modulation, self.weight, self.bias)
def forward(self, input, fea):
    """Apply modulated deformable convolution to ``input`` with offsets/mask predicted from ``fea``.

    Args:
        input: features the deformable convolution is applied to.
        fea: other features, used only to predict the offsets and the mask
            through ``self.conv_offset_mask``.

    Returns:
        The result of ``DCNv2Function`` applied to ``input`` with the predicted
        offsets and mask, ``self.weight`` and ``self.bias``.
    """
    out = self.conv_offset_mask(fea)
    # Split along dim 1: the first two chunks form the offsets, the third the mask.
    o1, o2, mask = torch.chunk(out, 3, dim=1)
    offset = torch.cat((o1, o2), dim=1)
    # Offset-magnitude monitoring is intentionally not computed here. It costs an
    # abs + full reduction on every call, plus a device sync if compared on the
    # host. If it is re-enabled, gate it behind logging and use torch.no_grad().
    mask = torch.sigmoid(mask)
    func = DCNv2Function(self.stride, self.padding, self.dilation, self.deformable_groups)
    return func(input, offset, mask, self.weight, self.bias)
def forward(self, input, offset, mask):
    """Apply modulated deformable convolution with caller-supplied offsets and mask.

    Builds a ``DCNv2Function`` from this module's stride, padding, dilation and
    deformable-group settings. It is applied to ``input``, ``offset`` and
    ``mask`` using ``self.weight`` and ``self.bias``.
    """
    conv_cfg = (self.stride, self.padding, self.dilation, self.deformable_groups)
    op = DCNv2Function(*conv_cfg)
    return op(input, offset, mask, self.weight, self.bias)