Example #1
0
def check_gradient():
    """Numerically check ``DCNv2Function`` gradients and print the outcome.

    Random input, offset, mask, weight and bias tensors are created, moved
    with ``.cuda()`` and marked as requiring grad. They are then handed to
    ``torch.autograd.gradcheck`` with ``eps=1e-3``, ``atol=1e-3`` and
    ``rtol=1e-2``, and the returned value is printed.

    NOTE(review): relies on module-level names not visible here
    (``N``, ``inC``, ``inH``, ``inW``, ``outC``, ``kH``, ``kW``,
    ``deformable_groups``). TODO: confirm they are defined before calling.
    The tensors use torch's default dtype (presumably float32). That is
    presumably why the tolerances are loose; gradcheck is most reliable in
    double precision.
    """
    taps = kW * kH  # sampling positions of one kernel window

    # Tensors are drawn in a fixed order: input, offset, mask, weight, bias.
    x = torch.randn(N, inC, inH, inW).cuda().requires_grad_()

    # To probe special cases, zero or shift ``offset.data`` here before the check.
    offset = torch.randn(N, deformable_groups * 2 * taps, inH, inW).cuda().requires_grad_()

    # The grad flag is set on the raw uniform sample. The tensor given to
    # gradcheck is its sigmoid, i.e. a non-leaf that still tracks grad.
    mask = torch.sigmoid(
        torch.rand(N, deformable_groups * 1 * taps, inH, inW).cuda().requires_grad_()
    )

    weight = torch.randn(outC, inC, kH, kW).cuda().requires_grad_()
    bias = torch.rand(outC).cuda().requires_grad_()

    dcn = DCNv2Function(
        stride=1,
        padding=1,
        dilation=1,
        deformable_groups=deformable_groups,
    )

    passed = gradcheck(
        dcn,
        (x, offset, mask, weight, bias),
        eps=1e-3,
        atol=1e-3,
        rtol=1e-2,
    )
    print(passed)
Example #2
0
 def forward(self, input):
     """Predict offsets and a modulation mask from ``input``, then apply DCNv2.

     ``self.conv_offset_mask(input)`` is split into three chunks along dim 1.
     The first two chunks are concatenated along dim 1 to form the offsets.
     The third is squashed into (0, 1) with a sigmoid and used as the mask.
     """
     offset_mask = self.conv_offset_mask(input)
     first, second, raw_mask = offset_mask.chunk(3, dim=1)
     offset = torch.cat([first, second], dim=1)
     mask = raw_mask.sigmoid()
     dcn = DCNv2Function(self.stride, self.padding, self.dilation, self.deformable_groups)
     return dcn(input, offset, mask, self.weight, self.bias)
Example #3
0
    def forward(self, input, fea):
        """Apply DCNv2 to ``input`` with offsets and mask predicted from ``fea``.

        Args:
            input: input features for the deformable convolution.
            fea: other features used for generating the offsets and the mask;
                they are passed through ``self.conv_offset_mask``.

        Returns:
            The output of ``DCNv2Function`` applied to ``input`` with
            ``self.weight`` and ``self.bias``.
        """
        out = self.conv_offset_mask(fea)
        # Split into three chunks along dim 1: two offset parts and one mask part.
        o1, o2, mask = torch.chunk(out, 3, dim=1)
        offset = torch.cat((o1, o2), dim=1)
        # No offset-magnitude statistic is computed here. Without a consumer it
        # would only add an abs + mean over the whole offset tensor on every call.
        mask = torch.sigmoid(mask)  # modulation weights in (0, 1)

        func = DCNv2Function(self.stride, self.padding, self.dilation,
                             self.deformable_groups)
        return func(input, offset, mask, self.weight, self.bias)
Example #4
0
 def forward(self, input, offset, mask):
     """Apply DCNv2 to ``input`` using caller-supplied ``offset`` and ``mask``.

     ``DCNv2Function`` is built from this layer's stride, padding, dilation and
     deformable_groups. It is called with ``self.weight`` and ``self.bias``.
     """
     conv_config = (self.stride, self.padding, self.dilation, self.deformable_groups)
     return DCNv2Function(*conv_config)(input, offset, mask, self.weight, self.bias)