def test_forward_output_on_cpu(self):
        """Check the DeformConv (DCN v1) forward pass on CPU against fixed values.

        A 1x1x5x5 ramp input (0..24) is convolved with an all-ones 3x3 kernel
        while every offset channel is set to 0.5; the flattened result must be
        close to the hard-coded reference table below.
        """
        cpu = torch.device("cpu")
        kernel_size, padding = 3, 1
        shape = (1, 1, 5, 5)
        N, C, H, W = shape

        # Reference output for the configuration described in the docstring.
        expected = np.array([
            [30, 41.25, 48.75, 45, 28.75],
            [62.25, 81, 90, 80.25, 50.25],
            [99.75, 126, 135, 117.75, 72.75],
            [105, 131.25, 138.75, 120, 73.75],
            [71.75, 89.25, 93.75, 80.75, 49.5],
        ])

        ramp = torch.arange(np.prod(shape), dtype=torch.float32)
        inputs = ramp.reshape(*shape).to(cpu)

        # Two offset channels per kernel tap, all filled with 0.5.
        num_taps = kernel_size * kernel_size
        offset = torch.full((N, num_taps * 2, H, W), 0.5,
                            dtype=torch.float32).to(cpu)

        # Test DCN v1 on cpu, with the weights forced to one so the result
        # does not depend on random initialisation.
        deform = DeformConv(C, C, kernel_size=kernel_size, padding=padding)
        deform = deform.to(cpu)
        deform.weight = torch.nn.Parameter(torch.ones_like(deform.weight))

        actual = deform(inputs, offset).detach().cpu().numpy()
        is_close = np.allclose(actual.flatten(), expected.flatten())
        self.assertTrue(is_close)
    def test_forward_output(self):
        """Check DCN v1 and DCN v2 forward values on CUDA against fixed references.

        The v1 layer uses all-ones weights and 0.5 offsets on a 5x5 ramp input.
        The v2 layer reuses the same weights with a constant 0.5 mask, so its
        output is expected to be exactly half of the v1 reference.
        """
        cuda = torch.device("cuda")
        kernel_size, padding = 3, 1
        shape = (1, 1, 5, 5)
        N, C, H, W = shape

        # The input is the 5x5 ramp:
        #    0  1  2  3  4
        #    5  6  7  8  9
        #   10 11 12 13 14
        #   15 16 17 18 19
        #   20 21 22 23 24
        ramp = torch.arange(np.prod(shape), dtype=torch.float32)
        inputs = ramp.reshape(*shape).to(cuda)

        num_taps = kernel_size * kernel_size
        offset = torch.full((N, num_taps * 2, H, W), 0.5,
                            dtype=torch.float32).to(cuda)

        expected = np.array([
            [30, 41.25, 48.75, 45, 28.75],
            [62.25, 81, 90, 80.25, 50.25],
            [99.75, 126, 135, 117.75, 72.75],
            [105, 131.25, 138.75, 120, 73.75],
            [71.75, 89.25, 93.75, 80.75, 49.5],
        ])

        # Test DCN v1
        deform = DeformConv(C, C, kernel_size=kernel_size, padding=padding)
        deform = deform.to(cuda)
        deform.weight = torch.nn.Parameter(torch.ones_like(deform.weight))
        v1_out = deform(inputs, offset).detach().cpu().numpy()
        self.assertTrue(np.allclose(v1_out.flatten(), expected.flatten()))

        # Test DCN v2: one mask channel per kernel tap, all 0.5.
        mask = torch.full((N, num_taps, H, W), 0.5,
                          dtype=torch.float32).to(cuda)
        modulate_deform = ModulatedDeformConv(
            C, C, kernel_size, padding=padding, bias=False).to(cuda)
        modulate_deform.weight = deform.weight
        v2_out = modulate_deform(inputs, offset, mask).detach().cpu().numpy()
        half_expected = expected.flatten() * 0.5
        self.assertTrue(np.allclose(v2_out.flatten(), half_expected))
Exemple #3
0
    def make_feature_adaptive_layers(self):
        """Build the classification and localisation-refinement 3x3 convs.

        Returns:
            tuple: ``(cls_conv, loc_conv_refine)``. Both are plain ``nn.Conv2d``
            layers when ``self.feat_adaption == "Empty"`` and ``DeformConv``
            layers for every other allowed method.
        """
        assert self.feat_adaption in FEAT_ADAPTION_METHODS, \
            "{} {}".format(self.feat_adaption, type(self.feat_adaption))

        # The assertion above already restricts the method to the known set,
        # so anything that is not "Empty" is treated as a deformable variant.
        conv_type = nn.Conv2d if self.feat_adaption == "Empty" else DeformConv

        in_channels = self.feat_channels
        cls_conv = conv_type(in_channels, self.feat_channels, 3, 1, 1)
        loc_conv_refine = conv_type(in_channels, self.loc_feat_channels, 3,
                                    1, 1)
        return cls_conv, loc_conv_refine
    def test_raise_exception(self):
        """Both deformable convs must raise RuntimeError on wrong channel counts.

        DeformConv is fed an offset tensor with half the required channels, and
        ModulatedDeformConv is fed a mask with twice the required channels.
        """
        cuda = torch.device("cuda")
        kernel_size, padding = 3, 1
        shape = (1, 1, 3, 3)
        N, C, H, W = shape
        num_taps = kernel_size * kernel_size

        inputs = torch.rand(shape, dtype=torch.float32).to(cuda)

        # Wrong offset channel count: k*k instead of 2*k*k.
        bad_offset = torch.randn((N, num_taps, H, W),
                                 dtype=torch.float32).to(cuda)
        deform = DeformConv(C, C, kernel_size=kernel_size,
                            padding=padding).to(cuda)
        with self.assertRaises(RuntimeError):
            deform(inputs, bad_offset)

        # Correct offset, but wrong mask channel count: 2*k*k instead of k*k.
        offset = torch.randn((N, num_taps * 2, H, W),
                             dtype=torch.float32).to(cuda)
        bad_mask = torch.ones((N, num_taps * 2, H, W),
                              dtype=torch.float32).to(cuda)
        modulate_deform = ModulatedDeformConv(
            C, C, kernel_size, padding=padding, bias=False).to(cuda)
        with self.assertRaises(RuntimeError):
            modulate_deform(inputs, offset, bad_mask)
    def test_small_input(self):
        """Inputs smaller than the kernel must still yield same-shaped outputs.

        For kernel sizes 3 and 5 the input is (k-1)x(k-1) with padding k // 2;
        both DeformConv and ModulatedDeformConv must keep the input shape.
        """
        cuda = torch.device("cuda")
        for kernel_size in (3, 5):
            padding = kernel_size // 2
            side = kernel_size - 1
            shape = (1, 1, side, side)
            N, C, H, W = shape
            num_taps = kernel_size * kernel_size

            # input size is smaller than kernel size
            inputs = torch.rand(shape).to(cuda)

            offset = torch.randn((N, num_taps * 2, H, W),
                                 dtype=torch.float32).to(cuda)
            deform = DeformConv(C, C, kernel_size=kernel_size,
                                padding=padding).to(cuda)
            v1_out = deform(inputs, offset)
            self.assertTrue(v1_out.shape == inputs.shape)

            mask = torch.ones((N, num_taps, H, W),
                              dtype=torch.float32).to(cuda)
            modulate_deform = ModulatedDeformConv(
                C, C, kernel_size, padding=padding, bias=False).to(cuda)
            v2_out = modulate_deform(inputs, offset, mask)
            self.assertTrue(v2_out.shape == inputs.shape)
 def __init__(
     self,
     in_channels,
     out_channels,
     kernel_size=3,
     stride=1,
     padding=1,
     dilation=1,
     groups=1,
     deformable_groups=1,
     norm=None,
     activation=None,
 ):
     """Create a deformable conv together with the conv that predicts its offsets.

     Args:
         in_channels: input channels of both the offset conv and the DCN.
         out_channels: output channels of the DCN.
         kernel_size: kernel size shared by both layers; it is multiplied with
             itself below, so it has to be a number rather than a tuple.
         stride, padding: forwarded to both layers.
         dilation, groups: forwarded to the DCN only.
         deformable_groups: forwarded to the DCN and used to size the offset
             conv output.
         norm, activation: forwarded to the DCN only.
     """
     super(DeformConvWithOffset, self).__init__()
     # The deformable convolution itself; it is created first, so it is also
     # registered (and randomly initialised) before the offset conv.
     self.dcn = DeformConv(
         in_channels,
         out_channels,
         kernel_size,
         stride,
         padding,
         dilation,
         groups,
         deformable_groups,
         norm=norm,
         activation=activation,
     )
     # Offset predictor: 2 values per kernel tap per deformable group.
     # NOTE(review): `dilation` is not passed here; with dilation != 1 the
     # offset map may not match the DCN output size - confirm intended.
     self.offset = Conv2d(
         in_channels,
         deformable_groups * 2 * kernel_size * kernel_size,
         kernel_size=kernel_size,
         stride=stride,
         padding=padding,
     )
Exemple #7
0
 def __init__(self, chi, cho, norm='BN'):
     """Build a 3x3 deformable conv block with its offset-predicting conv.

     Args:
         chi: number of input channels.
         cho: number of output channels.
         norm: normalisation spec handed to ``get_norm`` together with ``cho``.

     The module-level flag ``DCNV1`` (defined outside this block) selects
     between ``DeformConv`` and ``ModulatedDeformConv``.
     """
     super(_DeformConv, self).__init__()
     # Post-conv normalisation followed by an in-place ReLU.
     self.actf = nn.Sequential(get_norm(norm, cho), nn.ReLU(inplace=True))
     if DCNV1:
         # 18 = 2 * 3 * 3 output channels for a 3x3 kernel.
         self.offset = Conv2d(chi,
                              18,
                              kernel_size=3,
                              stride=1,
                              padding=1,
                              dilation=1)
         self.conv = DeformConv(chi,
                                cho,
                                kernel_size=(3, 3),
                                stride=1,
                                padding=1,
                                dilation=1,
                                deformable_groups=1)
     else:
         # 27 = 3 * 3 * 3 output channels; presumably 18 offsets plus 9 mask
         # values that are split in forward() - TODO confirm against forward.
         self.offset = Conv2d(chi,
                              27,
                              kernel_size=3,
                              stride=1,
                              padding=1,
                              dilation=1)
         self.conv = ModulatedDeformConv(chi,
                                         cho,
                                         kernel_size=3,
                                         stride=1,
                                         padding=1,
                                         dilation=1,
                                         deformable_groups=1)
     # Zero-initialise the offset predictor so its initial output is all zeros.
     nn.init.constant_(self.offset.weight, 0)
     nn.init.constant_(self.offset.bias, 0)
Exemple #8
0
 def __init__(self, in_channels, out_channels, kernel_size =3, deformable_groups =4):
     """Build the offset conv, the deformable adaption conv and a ReLU.

     Args:
         in_channels: input channels of the deformable conv.
         out_channels: output channels of the deformable conv.
         kernel_size: kernel size of the deformable conv; it is multiplied
             with itself below, so it has to be a number.
         deformable_groups: number of deformable groups; also scales the
             number of offset channels produced by ``conv_feat``.
     """
     super(FeatureAdaption, self).__init__()
     # Two offset values per kernel tap.
     offset_channels = kernel_size* kernel_size*2
     # 1x1, bias-free conv mapping a 2-channel input to the offset channels.
     # NOTE(review): what the 2-channel input represents is not visible
     # here - check the forward pass.
     self.conv_feat = nn.Conv2d(
         2, deformable_groups * offset_channels, 1, bias=False)
     # (kernel_size - 1) // 2 padding keeps the spatial size for odd kernels.
     self.conv_adaption = DeformConv(
         in_channels,
         out_channels,
         kernel_size=kernel_size,
         padding=(kernel_size - 1) // 2,
         deformable_groups=deformable_groups)
     self.relu = nn.ReLU(inplace = True)
Exemple #9
0
 def __init__(self, in_chs, out_chs, kernel_size=3, deformable_groups=1, activation='relu'):
     """Build a deformable conv layer with its offset predictor and activation.

     Args:
         in_chs: number of input channels.
         out_chs: number of output channels of the deformable conv.
         kernel_size: kernel size of the deformable conv (an int, since it is
             squared to size the offset channels).
         deformable_groups: number of deformable groups.
         activation: activation spec handed to ``actFunc``.
     """
     super(DeformLayer, self).__init__()
     # Offset predictor: 2 values per kernel tap per deformable group.
     # NOTE(review): assumes `conv3x3` preserves the spatial size - confirm.
     self.deform_offset = conv3x3(in_chs, (2 * kernel_size ** 2) * deformable_groups)
     self.act = actFunc(activation)
     # Derive the padding from the kernel size instead of hard-coding 1, so
     # that odd kernels other than 3 keep the spatial size and therefore stay
     # aligned with the offset map. For the default kernel_size=3 this is
     # still padding=1, i.e. unchanged behaviour.
     self.deform = DeformConv(
         in_chs,
         out_chs,
         kernel_size,
         stride=1,
         padding=kernel_size // 2,
         deformable_groups=deformable_groups
     )
    def test_forward_output_on_cpu_equals_output_on_gpu(self):
        """DeformConv must produce the same output on CPU and on CUDA.

        The comparison is made for ``groups`` of 1 and 2, with all-ones
        weights, a ramp input of shape (2, 4, 10, 10) and constant 0.5 offsets.
        """
        N, C, H, W = shape = 2, 4, 10, 10
        kernel_size = 3
        padding = 1

        for groups in [1, 2]:
            inputs = torch.arange(np.prod(shape),
                                  dtype=torch.float32).reshape(*shape)
            offset_channels = kernel_size * kernel_size * 2
            offset = torch.full((N, offset_channels, H, W),
                                0.5,
                                dtype=torch.float32)

            deform_gpu = DeformConv(C,
                                    C,
                                    kernel_size=kernel_size,
                                    padding=padding,
                                    groups=groups).to("cuda")
            deform_gpu.weight = torch.nn.Parameter(
                torch.ones_like(deform_gpu.weight))
            output_gpu = deform_gpu(inputs.to("cuda"),
                                    offset.to("cuda")).detach().cpu().numpy()

            deform_cpu = DeformConv(C,
                                    C,
                                    kernel_size=kernel_size,
                                    padding=padding,
                                    groups=groups).to("cpu")
            deform_cpu.weight = torch.nn.Parameter(
                torch.ones_like(deform_cpu.weight))
            output_cpu = deform_cpu(inputs.to("cpu"),
                                    offset.to("cpu")).detach().numpy()

            # Compare inside the loop so that every `groups` setting is
            # verified; previously only the last iteration was checked.
            self.assertTrue(
                np.allclose(output_gpu.flatten(), output_cpu.flatten()),
                msg="CPU and GPU outputs differ for groups={}".format(groups))
Exemple #11
0
    def init_layers(self):
        """Create the head's layers and initialise their parameters.

        Builds the two conv towers from ``self.stacked_convs()``, the two
        deformable convs, the point-offset predictors and the classification
        logits, then applies normal(std=0.01)/zero-bias initialisation to
        every ``nn.Conv2d`` found and a fixed bias to the logits.
        """
        channels = self.point_feat_channels

        self.cls_conv = nn.Sequential(*self.stacked_convs())
        self.reg_conv = nn.Sequential(*self.stacked_convs())

        self.deform_cls_conv = DeformConv(channels, channels,
                                          self.dcn_kernel, 1, self.dcn_pad)
        self.deform_reg_conv = DeformConv(channels, channels,
                                          self.dcn_kernel, 1, self.dcn_pad)

        # 4 outputs when grid points are used, otherwise 2 per point.
        if self.use_grid_points:
            points_out_dim = 4
        else:
            points_out_dim = 2 * self.num_points

        self.offsets_init = nn.Sequential(
            nn.Conv2d(channels, channels, 3, 1, 1),
            nn.ReLU(inplace=True),
            nn.Conv2d(channels, points_out_dim, 1, 1, 0),
        )
        self.offsets_refine = nn.Sequential(
            nn.ReLU(),
            nn.Conv2d(channels, points_out_dim, 1, 1, 0),
        )
        self.logits = nn.Sequential(
            nn.ReLU(),
            nn.Conv2d(channels, self.cls_out_channels, 1, 1, 0),
        )

        # Only layers that are instances of nn.Conv2d are touched here.
        # NOTE(review): whether DeformConv counts as nn.Conv2d depends on its
        # class hierarchy, which is not visible in this block.
        containers = (
            self.cls_conv, self.reg_conv, self.offsets_init,
            self.offsets_refine, self.deform_cls_conv, self.deform_reg_conv,
        )
        for container in containers:
            for layer in container.modules():
                if not isinstance(layer, nn.Conv2d):
                    continue
                torch.nn.init.normal_(layer.weight, mean=0, std=0.01)
                torch.nn.init.constant_(layer.bias, 0)

        # Logit bias derived from the constant 0.01 used in the expression.
        bias_init = float(-np.log((1 - 0.01) / 0.01))
        for module in self.logits.modules():
            if getattr(module, 'bias', None) is not None:
                torch.nn.init.constant_(module.bias, bias_init)
    def test_repr(self):
        """repr() of both deformable conv modules must match the fixed strings."""
        dcn_v1 = DeformConv(3, 10, kernel_size=3, padding=1, deformable_groups=2)
        expected_v1 = (
            "DeformConv(in_channels=3, out_channels=10, kernel_size=(3, 3), "
            "stride=(1, 1), padding=(1, 1), dilation=(1, 1), "
            "groups=1, deformable_groups=2, bias=False)"
        )
        actual_v1 = repr(dcn_v1)
        self.assertEqual(actual_v1, expected_v1)

        dcn_v2 = ModulatedDeformConv(3, 10, kernel_size=3, padding=1, deformable_groups=2)
        expected_v2 = (
            "ModulatedDeformConv(in_channels=3, out_channels=10, kernel_size=(3, 3), "
            "stride=1, padding=1, dilation=1, groups=1, deformable_groups=2, bias=True)"
        )
        actual_v2 = repr(dcn_v2)
        self.assertEqual(actual_v2, expected_v2)
Exemple #13
0
    def make_feature_adaptive_layers(self):
        """Create the feature-adaption convs selected by ``self.feat_adaptive``.

        Modes handled:
            * ``None`` / ``"none"``: plain 3x3 convs, ``offset_conv`` is None.
            * ``"unsupervised"``: one 18-channel 1x1 offset conv plus two
              deformable convs.
            * ``"split"``: separate 18-channel offset convs for the cls and loc
              branches plus two deformable convs.
            * ``"supervised"`` (asserted): a 14-channel offset conv, the base
              3x3 sampling offsets stored in ``self.dcn_base_offset``, and two
              deformable convs.
        """
        mode = self.feat_adaptive
        channels = self.feat_channels

        if mode is None or mode == "none":
            self.offset_conv = None
            self.cls_conv = nn.Conv2d(channels, channels, 3, 1, 1)
            self.loc_refine_conv = nn.Conv2d(channels, channels, 3, 1, 1)
            return

        if mode == "unsupervised":
            self.offset_conv = nn.Conv2d(channels, 18, 1, 1, 0)
            self.cls_conv = DeformConv(channels, channels, 3, 1, 1)
            self.loc_refine_conv = DeformConv(channels, channels, 3, 1, 1)
            return

        if mode == "split":
            self.offset_conv_cls = nn.Conv2d(channels, 18, 1, 1, 0)
            self.offset_conv_loc = nn.Conv2d(channels, 18, 1, 1, 0)
            self.cls_conv = DeformConv(channels, channels, 3, 1, 1)
            self.loc_refine_conv = DeformConv(channels, channels, 3, 1, 1)
            return

        assert mode == "supervised", mode

        self.dcn_kernel = 3
        self.dcn_pad = int((self.dcn_kernel - 1) / 2)

        # Base sampling grid of the kernel as interleaved (y, x) pairs,
        # viewed as a (1, -1, 1, 1) tensor.
        axis_coords = np.arange(-self.dcn_pad,
                                self.dcn_pad + 1).astype(np.float64)
        ys = np.repeat(axis_coords, self.dcn_kernel)
        xs = np.tile(axis_coords, self.dcn_kernel)
        yx_flat = np.stack([ys, xs], axis=1).reshape(-1)
        self.dcn_base_offset = torch.tensor(yx_flat).view(1, -1, 1, 1)

        self.offset_conv = nn.Conv2d(channels, 14, 1, 1, 0)
        self.cls_conv = DeformConv(channels, channels, 3, 1, 1)
        self.loc_refine_conv = DeformConv(channels, channels, 3, 1, 1)
Exemple #14
0
    def __init__(self, cfg, input_shape: List[ShapeSpec]):
        """Build the RepPoints head from the ``cfg.MODEL.REPPOINTS`` settings.

        Args:
            cfg: config object; only ``cfg.MODEL.REPPOINTS`` is read here.
            input_shape: list of feature shapes; only the channel count of the
                first entry is used.

        Raises:
            ValueError: if ``NORM_MODE`` is not ``'GN'`` (only reached when at
                least one stacked conv is requested).
        """
        super().__init__()
        head_params = cfg.MODEL.REPPOINTS
        self.in_channels = input_shape[0].channels
        self.num_classes = head_params.NUM_CLASSES
        self.feat_channels = head_params.FEAT_CHANNELS
        self.point_feat_channels = head_params.POINT_FEAT_CHANNELS
        self.stacked_convs = head_params.STACK_CONVS
        self.norm_mode = head_params.NORM_MODE
        self.num_points = head_params.NUM_POINTS
        self.gradient_mul = head_params.GRADIENT_MUL
        self.prior_prob = head_params.PRIOR_PROB

        # The DCN kernel side is the integer square root of the point count.
        # NOTE(review): presumably NUM_POINTS is meant to be a perfect square
        # (e.g. 9); the int() truncation silently accepts other values.
        self.dcn_kernel = int(np.sqrt(self.num_points))
        self.dcn_pad = int((self.dcn_kernel - 1) / 2)
        # Base kernel sampling grid as interleaved (y, x) pairs, viewed as a
        # (1, -1, 1, 1) tensor.
        dcn_base = np.arange(-self.dcn_pad, self.dcn_pad + 1).astype(np.float64)
        dcn_base_y = np.repeat(dcn_base, self.dcn_kernel)
        dcn_base_x = np.tile(dcn_base, self.dcn_kernel)
        dcn_base_offset = np.stack([dcn_base_y, dcn_base_x], axis=1).reshape((-1))
        # NOTE(review): this is a float64 tensor stored as a plain attribute
        # (no register_buffer), so it will not follow the module across
        # devices/dtypes automatically - confirm the forward pass handles it.
        self.dcn_base_offset = torch.tensor(dcn_base_offset).view(1, -1, 1, 1)

        self.cls_convs = nn.ModuleList()
        self.reg_convs = nn.ModuleList()

        # Each iteration appends conv -> GroupNorm -> ReLU to both towers;
        # only the first conv of each tower takes `in_channels` as input.
        for i in range(self.stacked_convs):
            chn = self.in_channels if i == 0 else self.feat_channels
            self.cls_convs.append(
                nn.Conv2d(chn,
                          self.feat_channels,
                          kernel_size=3,
                          stride=1,
                          padding=1,
                          bias=False))
            if self.norm_mode == 'GN':
                # The group count scales with the width: 32 groups per 256 channels.
                self.cls_convs.append(
                    nn.GroupNorm(32 * self.feat_channels // 256, self.feat_channels))
            else:
                raise ValueError('The normalization method in reppoints head should be GN')
            self.cls_convs.append(nn.ReLU(inplace=True))

            self.reg_convs.append(
                nn.Conv2d(chn,
                          self.feat_channels,
                          kernel_size=3,
                          stride=1,
                          padding=1,
                          bias=False))
            if self.norm_mode == 'GN':
                self.reg_convs.append(
                    nn.GroupNorm(32 * self.feat_channels // 256, self.feat_channels))
            else:
                raise ValueError('The normalization method in reppoints head should be GN')
            self.reg_convs.append(nn.ReLU(inplace=True))

        # Two values per point.
        point_out_dim = 2 * self.num_points
        self.reppoints_cls_conv = DeformConv(
            self.feat_channels, self.point_feat_channels, self.dcn_kernel, 1, self.dcn_pad)
        # NOTE(review): this 1x1 conv takes `feat_channels` inputs although the
        # DeformConv above outputs `point_feat_channels`; the two only agree
        # when both settings are equal - confirm against forward().
        self.reppoints_cls_out = nn.Conv2d(self.feat_channels, self.num_classes, 1, 1, 0)
        self.reppoints_pts_init_conv = nn.Conv2d(self.feat_channels, self.point_feat_channels, 3, 1, 1)
        self.reppoints_pts_init_out = nn.Conv2d(self.point_feat_channels, point_out_dim, 1, 1, 0)
        self.reppoints_pts_refine_conv = DeformConv(
            self.feat_channels, self.point_feat_channels, self.dcn_kernel, 1, self.dcn_pad)
        self.reppoints_pts_refine_out = nn.Conv2d(self.point_feat_channels, point_out_dim, 1, 1, 0)
        # Weight initialisation is delegated to a method defined outside this block.
        self.init_weights()
Exemple #15
0
    def __init__(self, cfg, input_shape: List[ShapeSpec]):
        """Build a RetinaNet-style head with an added deformable-conv refinement stage.

        Args:
            cfg: config object; reads ``cfg.MODEL.RETINANET`` and
                ``cfg.MODEL.RPN.BBOX_REG_WEIGHTS``.
            input_shape: list of feature shapes; the channel count of the first
                entry is used, and the whole list goes to the anchor generator.
        """
        super().__init__()
        # the same as RetinaNetHead, we replace the cls_score net to logits net, which utilizes the deform_conv
        # fmt: off
        in_channels = input_shape[0].channels
        num_classes = cfg.MODEL.RETINANET.NUM_CLASSES
        num_convs = cfg.MODEL.RETINANET.NUM_CONVS
        # NOTE(review): `prior_prob` is read but never used below; the logits
        # bias is computed from a hard-coded 0.01 instead - confirm intended.
        prior_prob = cfg.MODEL.RETINANET.PRIOR_PROB
        num_anchors = build_anchor_generator(cfg, input_shape).num_cell_anchors
        # fmt: on
        # All feature levels must use the same number of anchors per cell.
        assert (
                len(set(num_anchors)) == 1
        ), "Using different number of anchors between levels is not currently supported!"
        num_anchors = num_anchors[0]

        # Two parallel towers of `num_convs` x (3x3 conv + ReLU).
        cls_subnet = []
        bbox_subnet = []
        for _ in range(num_convs):
            cls_subnet.append(
                nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
            )
            cls_subnet.append(nn.ReLU())
            bbox_subnet.append(
                nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
            )
            bbox_subnet.append(nn.ReLU())

        self.cls_subnet = nn.Sequential(*cls_subnet)
        self.bbox_subnet = nn.Sequential(*bbox_subnet)
        #        self.cls_score = nn.Conv2d(
        #            in_channels, num_anchors * num_classes, kernel_size=3, stride=1, padding=1
        #        )
        self.bbox_pred = nn.Conv2d(in_channels, num_anchors * 4, kernel_size=3, stride=1, padding=1)

        # Initialization
        for modules in [self.cls_subnet, self.bbox_subnet, self.bbox_pred]:
            for layer in modules.modules():
                if isinstance(layer, nn.Conv2d):
                    torch.nn.init.normal_(layer.weight, mean=0, std=0.01)
                    torch.nn.init.constant_(layer.bias, 0)

        # Deform_conv block, added as a second stage refinement. The implementation follows reppoints.
        self.dcn_kernel = 3
        self.dcn_pad = 1
        self.point_base_scale = 4
        self.gradient_mul = 0.1
        self.in_channels = in_channels
        self.num_anchors = num_anchors
        self.box2box_transform = Box2BoxTransform(weights=cfg.MODEL.RPN.BBOX_REG_WEIGHTS)
        # Base 3x3 sampling grid as interleaved (y, x) pairs, stored as a
        # float32 buffer viewed as (1, -1, 1, 1).
        dcn_base = np.arange(-self.dcn_pad,
                             self.dcn_pad + 1).astype(np.float64)
        dcn_base_y = np.repeat(dcn_base, self.dcn_kernel)
        dcn_base_x = np.tile(dcn_base, self.dcn_kernel)
        dcn_base_offset = np.stack([dcn_base_y, dcn_base_x], axis=1).reshape((-1))
        dcn_base_offset = torch.tensor(dcn_base_offset, dtype=torch.float32).view(1, -1, 1, 1)
        self.register_buffer("dcn_base_offset", dcn_base_offset)

        self.deform_cls_conv = DeformConv(
            self.in_channels,
            self.in_channels,
            self.dcn_kernel, 1, self.dcn_pad)
        self.deform_reg_conv = DeformConv(
            self.in_channels,
            self.in_channels,
            self.dcn_kernel, 1, self.dcn_pad)
        # NOTE(review): both 1x1 convs below expect num_anchors * in_channels
        # inputs while the deformable convs output in_channels; presumably the
        # forward pass concatenates per-anchor features - verify.
        self.offsets_refine = nn.Sequential(
            nn.ReLU(),
            nn.Conv2d(self.num_anchors * self.in_channels,
                      num_anchors * 4,
                      1, 1, 0))
        self.logits = nn.Sequential(
            nn.ReLU(),
            nn.Conv2d(self.num_anchors * self.in_channels,
                      num_anchors * num_classes,
                      1, 1, 0))

        # Logit bias derived from the constant 0.01 used in the expression.
        bias_init = float(-np.log((1 - 0.01) / 0.01))
        # Only nn.Conv2d instances are re-initialised in this loop.
        for modules in [
            self.offsets_refine,
            self.deform_cls_conv,
            self.deform_reg_conv]:
            for layer in modules.modules():
                if isinstance(layer, nn.Conv2d):
                    torch.nn.init.normal_(layer.weight, mean=0, std=0.01)
                    torch.nn.init.constant_(layer.bias, 0)

        for module in self.logits.modules():
            if hasattr(module, 'bias') and module.bias is not None:
                torch.nn.init.constant_(module.bias, bias_init)
Exemple #16
0
    def _init_layers(self):
        """Create the conv towers, prediction convs and per-level scales of the head.

        Layers created (in this order):
            * ``shared_convs``: ``num_shared_convs`` 3x3 convs.
            * ``cls_convs`` / ``reg_convs``: ``num_stacked_convs`` 3x3 convs
              each; when ``use_deformable`` is set, the last layer of each
              tower is a ``DeformConv`` instead of a plain conv.
            * ``cls_logits``, ``bbox_pred``, ``centerness``: 3x3 prediction
              convs with ``num_classes``, 4 and 1 output channels.
            * ``scales``: five ``Scale`` modules.

        Every conv layer gets its own normalisation module. Passing a single
        ``nn.GroupNorm`` instance to all layers (as was done before) makes
        them share one set of affine parameters.
        """
        # ReLU holds no parameters, so one instance can be shared safely.
        activation = nn.ReLU()

        def make_norm():
            # A fresh GroupNorm per layer (or None when GN is disabled).
            if self.norm_layer == 'GN':
                return nn.GroupNorm(32, self.in_channels)
            return None

        def make_conv():
            # Plain 3x3 conv block that keeps the channel count.
            return Conv2d(self.in_channels,
                          self.in_channels,
                          kernel_size=3,
                          padding=1,
                          norm=make_norm(),
                          activation=activation)

        def make_tower():
            # Tower of stacked convs; optionally ends in a deformable conv.
            if not self.use_deformable:
                return nn.Sequential(
                    *[make_conv() for _ in range(self.num_stacked_convs)])
            layers = [make_conv() for _ in range(self.num_stacked_convs - 1)]
            # Following the original implementation, the last layer of the
            # stacked convs is a DeformConv (if applied).
            # NOTE(review): nn.Sequential calls each layer with one argument;
            # confirm how the offset input of DeformConv is supplied.
            layers.append(
                DeformConv(self.in_channels,
                           self.in_channels,
                           kernel_size=3,
                           padding=1,
                           norm=make_norm(),
                           activation=activation))
            return nn.Sequential(*layers)

        self.shared_convs = nn.Sequential(
            *[make_conv() for _ in range(self.num_shared_convs)])

        # Classification / centerness tower and box-regression tower.
        self.cls_convs = make_tower()
        self.reg_convs = make_tower()

        self.cls_logits = Conv2d(self.in_channels,
                                 self.num_classes,
                                 kernel_size=3,
                                 padding=1)
        self.bbox_pred = Conv2d(self.in_channels, 4, kernel_size=3, padding=1)
        self.centerness = Conv2d(self.in_channels, 1, kernel_size=3, padding=1)

        self.scales = nn.ModuleList([Scale() for _ in range(5)])
Exemple #17
0
import torch

from detectron2.layers import DeformConv

import init_paths
from slender_det.modeling.grid_generator import zero_center_grid, uniform_grid

# Ad-hoc debugging script: runs one DeformConv on a small grid with two
# different offset tensors and then drops into the ipdb debugger so that the
# two outputs (y_1, y_2) can be inspected interactively.

# Make every subsequently created tensor a CUDA float tensor by default.
# NOTE(review): this call is marked deprecated in recent PyTorch releases -
# confirm against the installed version.
torch.set_default_tensor_type('torch.cuda.FloatTensor')
# 3x3 conv with stride 1, padding 1
deform_conv = DeformConv(2, 1, 3, 1, 1)
# All-ones weights remove the dependence on random initialisation.
torch.nn.init.constant_(deform_conv.weight, 1)

# 1, 2, 4, 4
grid = uniform_grid(4).unsqueeze(0).permute(0, 3, 1, 2)
# Keep channel 0 of the grid and replace channel 1 with a constant 0.1.
grid = torch.stack([grid[:, 0], torch.zeros_like(grid[:, 0]) + 0.1], 1)

# 9, 2
offsets_1 = zero_center_grid(3).reshape(1, -1, 1, 1)
# Repeat the same offsets over a 4x4 spatial extent.
offsets_1 = offsets_1.repeat(1, 1, 4, 4)

# Baseline: all-zero offsets of the same shape.
offsets_2 = torch.zeros_like(offsets_1)

y_1 = deform_conv(grid, offsets_1)
y_2 = deform_conv(grid, offsets_2)

# Interactive breakpoint; execution stops here until the debugger is resumed.
import ipdb
ipdb.set_trace()
    def __init__(self, cfg, input_shape: List[ShapeSpec]):
        """
        Build all layers of the head and initialise their weights.

        Layers created: the ``cls_tower`` / ``bbox_tower`` conv stacks, the two
        3x3 deformable convs (``deform_cls_conv`` / ``deform_reg_conv``), the
        ``offsets_init`` / ``offsets_refine`` / ``logits`` branches, the
        ``centerness`` and ``ratio`` predictors and five ``Scale`` modules.

        Arguments:
            cfg: config node. The ``MODEL.FCOS`` entries NUM_CLASSES,
                FPN_STRIDES, NORM_REG_TARGETS, CENTERNESS_ON_REG,
                USE_DCN_IN_TOWER, USE_DCN_V2, NUM_CONVS and PRIOR_PROB are read.
            input_shape (List[ShapeSpec]): shapes of the input feature maps.
                Only ``input_shape[0].channels`` is used; it is the channel
                count of every intermediate layer built here.
        """
        super(FCOSRepPointsHead, self).__init__()
        # TODO: Implement the sigmoid version first.
        # fmt: off
        in_channels = input_shape[0].channels
        num_classes = cfg.MODEL.FCOS.NUM_CLASSES
        num_convs = cfg.MODEL.FCOS.NUM_CONVS
        self.fpn_strides = cfg.MODEL.FCOS.FPN_STRIDES
        self.norm_reg_targets = cfg.MODEL.FCOS.NORM_REG_TARGETS
        self.centerness_on_reg = cfg.MODEL.FCOS.CENTERNESS_ON_REG
        self.use_dcn_in_tower = cfg.MODEL.FCOS.USE_DCN_IN_TOWER
        self.use_dcn_v2 = cfg.MODEL.FCOS.USE_DCN_V2
        # fmt: on

        def make_tower_conv(deformable):
            """Return one 3x3, stride-1, padding-1 conv layer for a tower."""
            if not deformable:
                return nn.Conv2d(in_channels, in_channels,
                                 kernel_size=3, stride=1, padding=1, bias=True)
            if self.use_dcn_v2:
                # DFConv2d is left on its default (with_modulated_dcn unset).
                return DFConv2d(in_channels, in_channels,
                                kernel_size=3, stride=1, padding=1, bias=False)
            return DFConv2d(in_channels, in_channels,
                            with_modulated_dcn=False,
                            kernel_size=3, stride=1, padding=1, bias=False)

        # Both towers share one layout: NUM_CONVS x (conv, GroupNorm(32), ReLU).
        # With USE_DCN_IN_TOWER only the last conv of each tower is deformable.
        cls_tower = []
        bbox_tower = []
        for i in range(num_convs):
            deformable = self.use_dcn_in_tower and i == num_convs - 1
            # cls first, then bbox, so layers are constructed in the same
            # order as before (keeps seeded runs reproducible).
            for tower in (cls_tower, bbox_tower):
                tower.append(make_tower_conv(deformable))
                tower.append(nn.GroupNorm(32, in_channels))
                tower.append(nn.ReLU())

        self.add_module('cls_tower', nn.Sequential(*cls_tower))
        self.add_module('bbox_tower', nn.Sequential(*bbox_tower))

        # rep part
        self.point_feat_channels = in_channels
        self.num_points = 9
        self.dcn_kernel = int(np.sqrt(self.num_points))
        self.dcn_pad = (self.dcn_kernel - 1) // 2
        self.cls_out_channels = num_classes
        self.gradient_mul = 0.1

        # Base sampling offsets of the dcn_kernel x dcn_kernel window, stored
        # as interleaved (y, x) pairs in row-major order and shaped
        # (1, 2 * dcn_kernel ** 2, 1, 1).
        dcn_base = np.arange(-self.dcn_pad,
                             self.dcn_pad + 1).astype(np.float64)
        dcn_base_y = np.repeat(dcn_base, self.dcn_kernel)
        dcn_base_x = np.tile(dcn_base, self.dcn_kernel)
        dcn_base_offset = np.stack([dcn_base_y, dcn_base_x], axis=1).reshape((-1))
        dcn_base_offset = torch.tensor(dcn_base_offset, dtype=torch.float32).view(1, -1, 1, 1)
        self.register_buffer("dcn_base_offset", dcn_base_offset)

        self.deform_cls_conv = DeformConv(
            self.point_feat_channels,
            self.point_feat_channels,
            self.dcn_kernel, 1, self.dcn_pad)
        self.deform_reg_conv = DeformConv(
            self.point_feat_channels,
            self.point_feat_channels,
            self.dcn_kernel, 1, self.dcn_pad)

        # Two values (one coordinate pair) per point.
        points_out_dim = 2 * self.num_points
        self.offsets_init = nn.Sequential(
            nn.Conv2d(self.point_feat_channels,
                      self.point_feat_channels,
                      3, 1, 1),
            nn.ReLU(inplace=True),
            nn.Conv2d(self.point_feat_channels,
                      points_out_dim,
                      1, 1, 0))

        self.offsets_refine = nn.Sequential(
            nn.ReLU(),
            nn.Conv2d(self.point_feat_channels,
                      points_out_dim,
                      1, 1, 0))
        self.logits = nn.Sequential(
            nn.ReLU(),
            nn.Conv2d(self.point_feat_channels,
                      self.cls_out_channels,
                      1, 1, 0))
        # NOTE: no cls_logits / bbox_pred convs are created in this head.
        self.centerness = nn.Conv2d(in_channels, 1, kernel_size=3, stride=1, padding=1)

        self.ratio = nn.Conv2d(in_channels, 1, kernel_size=3, stride=1, padding=1)

        # initialization
        # The two deformable convs are matched by identity: an
        # isinstance(nn.Conv2d) test alone would skip them whenever DeformConv
        # does not derive from nn.Conv2d, leaving them on their constructor
        # initialisation although they are listed here.
        # self.ratio is not in the list and keeps its constructor init.
        deform_heads = (self.deform_cls_conv, self.deform_reg_conv)
        for modules in [self.cls_tower, self.bbox_tower,
                        self.offsets_init,
                        self.offsets_refine,
                        self.deform_cls_conv,
                        self.deform_reg_conv,
                        self.centerness]:
            for layer in modules.modules():
                is_deform_head = any(layer is head for head in deform_heads)
                if not (isinstance(layer, nn.Conv2d) or is_deform_head):
                    continue
                torch.nn.init.normal_(layer.weight, std=0.01)
                # Bias-free layers expose bias=None; there is nothing to reset.
                if getattr(layer, "bias", None) is not None:
                    torch.nn.init.constant_(layer.bias, 0)

        # initialize the bias for focal loss
        prior_prob = cfg.MODEL.FCOS.PRIOR_PROB
        bias_value = -math.log((1 - prior_prob) / prior_prob)
        for module in self.logits.modules():
            if hasattr(module, 'bias') and module.bias is not None:
                torch.nn.init.constant_(module.bias, bias_value)

        self.scales = nn.ModuleList([Scale(init_value=1.0) for _ in range(5)])
    def __init__(self, cfg, input_shape: List[ShapeSpec]):
        """
        Build the layers of the head and initialise their weights.

        Same role as RetinaNetHead, but the cls_score net is replaced by a
        logits net that is fed by a deformable conv.

        Arguments:
            cfg: config node. ``MODEL.RETINANET.NUM_CLASSES``, ``NUM_CONVS``
                and ``PRIOR_PROB`` plus ``MODEL.PROPOSAL_GENERATOR.NUM_POINTS``
                are read.
            input_shape (List[ShapeSpec]): shapes of the input feature maps;
                only ``input_shape[0].channels`` is read.
        """
        super().__init__()
        # fmt: off
        in_channels = input_shape[0].channels
        num_classes = cfg.MODEL.RETINANET.NUM_CLASSES
        num_convs = cfg.MODEL.RETINANET.NUM_CONVS
        prior_prob = cfg.MODEL.RETINANET.PRIOR_PROB
        # please add it in cfg!
        self.num_points = cfg.MODEL.PROPOSAL_GENERATOR.NUM_POINTS
        # NOTE(review): hard-coded width; in_channels and num_convs are read
        # above but not used in this part of the constructor - confirm that
        # self.stacked_convs() picks up what it needs on its own.
        self.point_feat_channels = 256
        self.cls_out_channels = num_classes - 1  # maybe not right
        # fmt: on

        # dcn_base_offset
        # NOTE(review): the kernel is fixed to 3x3 (9 sample points) and does
        # not follow self.num_points - confirm NUM_POINTS is always 9.
        self.dcn_kernel = int(np.sqrt(9))
        # 1 for kernel 3.
        self.dcn_pad = (self.dcn_kernel - 1) // 2
        # Base sampling offsets of the kernel window, stored as interleaved
        # (y, x) pairs in row-major order and shaped (1, 18, 1, 1).
        dcn_base = np.arange(-self.dcn_pad,
                             self.dcn_pad + 1).astype(np.float64)
        dcn_base_y = np.repeat(dcn_base, self.dcn_kernel)
        dcn_base_x = np.tile(dcn_base, self.dcn_kernel)
        dcn_base_offset = np.stack([dcn_base_y, dcn_base_x], axis=1).reshape(
            (-1))
        dcn_base_offset = torch.tensor(dcn_base_offset,
                                       dtype=torch.float32).view(1, -1, 1, 1)
        self.register_buffer("dcn_base_offset", dcn_base_offset)

        self.gradient_mul = 0.1

        # One independent conv stack per branch (stacked_convs is called twice).
        self.cls_conv = nn.Sequential(*self.stacked_convs())
        self.reg_conv = nn.Sequential(*self.stacked_convs())

        self.deform_cls_conv = DeformConv(self.point_feat_channels,
                                          self.point_feat_channels,
                                          self.dcn_kernel, 1, self.dcn_pad)
        self.deform_reg_conv = DeformConv(self.point_feat_channels,
                                          self.point_feat_channels,
                                          self.dcn_kernel, 1, self.dcn_pad)

        # Two values (one coordinate pair) per point.
        points_out_dim = 2 * self.num_points
        self.offsets_init = nn.Sequential(
            nn.Conv2d(self.point_feat_channels, self.point_feat_channels, 3, 1,
                      1), nn.ReLU(inplace=True),
            nn.Conv2d(self.point_feat_channels, points_out_dim, 1, 1, 0))

        self.offsets_refine = nn.Sequential(
            nn.ReLU(),
            nn.Conv2d(self.point_feat_channels, points_out_dim, 1, 1, 0))
        self.logits = nn.Sequential(
            nn.ReLU(),
            nn.Conv2d(self.point_feat_channels, self.cls_out_channels, 1, 1,
                      0))

        # Focal-loss prior for the classification bias, taken from the config
        # value read above rather than a hard-coded 0.01.
        bias_init = float(-np.log((1 - prior_prob) / prior_prob))

        # The two deformable convs are matched by identity: an
        # isinstance(nn.Conv2d) test alone would skip them whenever DeformConv
        # does not derive from nn.Conv2d, leaving them on their constructor
        # initialisation although they are listed here.
        deform_heads = (self.deform_cls_conv, self.deform_reg_conv)
        for modules in [
                self.cls_conv, self.reg_conv, self.offsets_init,
                self.offsets_refine, self.deform_cls_conv, self.deform_reg_conv
        ]:
            for layer in modules.modules():
                is_deform_head = any(layer is head for head in deform_heads)
                if not (isinstance(layer, nn.Conv2d) or is_deform_head):
                    continue
                torch.nn.init.normal_(layer.weight, mean=0, std=0.01)
                # Bias-free layers expose bias=None; there is nothing to reset.
                if getattr(layer, "bias", None) is not None:
                    torch.nn.init.constant_(layer.bias, 0)

        for module in self.logits.modules():
            if hasattr(module, 'bias') and module.bias is not None:
                torch.nn.init.constant_(module.bias, bias_init)