示例#1
0
    def forward(self, photos):
        """Predict score, soft-scale and orientation maps at three resolutions.

        The shared features (``self.conv`` then ``self.features``) are resampled
        to 0.5x and 0.25x. One shared ``conv_1_1``/``insnorm_1_1`` head with
        leaky ReLU runs on each resolution. The three outputs are resized back
        to (H, W), stacked on the channel axis, passed through ``soft_nms_3d``
        and reduced by ``soft_max_and_argmax_1d``. Orientation comes from three
        conv/norm layers on the full-resolution features, followed by
        ``L2Norm`` over dim 1 and a permute to channels-last.

        :param photos: batch whose dims 2 and 3 are H and W
            (presumably (B, C, H, W) -- TODO confirm)
        :return: (score_map, scale_map, ori_map)
        """
        H, W = photos.shape[2:4]

        def resize(tensor, **target):
            # align_corners=False is PyTorch's bilinear default, so results are
            # unchanged. Stating it explicitly avoids the implicit-default
            # UserWarning that some PyTorch versions emit.
            return F.interpolate(tensor, mode='bilinear', align_corners=False,
                                 **target)

        feature_map = self.conv(photos)
        feature_map = self.features(feature_map)
        feature_map_2 = resize(feature_map, scale_factor=0.5)
        feature_map_3 = resize(feature_map, scale_factor=0.25)

        ori_map = self.insnorm_ori_1(self.conv_ori_1(feature_map))
        ori_map = self.insnorm_ori_2(self.conv_ori_2(ori_map))
        ori_map = L2Norm(self.conv_ori_3(ori_map), dim=1).permute(0, 2, 3, 1)

        # The same head (shared weights) is applied at every resolution.
        feature_map_1 = F.leaky_relu(
            self.insnorm_1_1(self.conv_1_1(feature_map)))
        feature_map_2 = F.leaky_relu(
            self.insnorm_1_1(self.conv_1_1(feature_map_2)))
        feature_map_3 = F.leaky_relu(
            self.insnorm_1_1(self.conv_1_1(feature_map_3)))

        score_maps = torch.cat(
            (
                feature_map_1,
                resize(feature_map_2, size=(H, W)),
                resize(feature_map_3, size=(H, W)),
            ),
            1,
        )  # (B, C, H, W)

        score_maps = score_maps.permute(0, 2, 3, 1)
        scale_probs = soft_nms_3d(score_maps, ksize=15, com_strength=7.0)
        score_map, scale_map = soft_max_and_argmax_1d(
            input=scale_probs,
            orint_maps=None,
            dim=-1,
            scale_list=self.scale_list,
            keepdim=True,
            com_strength1=self.score_com_strength,
            com_strength2=self.scale_com_strength,
        )

        return score_map, scale_map, ori_map
示例#2
0
    def forward(self, photos):
        """Score/scale/orientation heads over three resolutions of shared features.

        ``self.features`` output is resampled to 0.75x and 0.5x (bilinear,
        ``align_corners=True``). One shared ``conv_1_1``/``insnorm_1_1`` head
        with leaky ReLU runs on each resolution. The three results are resized
        to the input size and stacked along channels, then soft-NMS'd and
        reduced to a score map and a soft scale map. Orientation is computed
        on a 0.5x copy, resized to full size, L2Norm'd over dim 1 and made
        channels-last.

        :param photos: batch whose dims 2 and 3 are H and W
            (presumably (B, C, H, W) -- TODO confirm)
        :return: (score_map, scale_map, ori_map)
        """
        height, width = photos.shape[2:4]

        def bilinear(tensor, **target):
            return F.interpolate(tensor, mode='bilinear', align_corners=True,
                                 **target)

        def score_head(tensor):
            return F.leaky_relu(self.insnorm_1_1(self.conv_1_1(tensor)))

        shared = self.features(photos)
        shared_075 = bilinear(shared, scale_factor=0.75)
        shared_050 = bilinear(shared, scale_factor=0.5)

        full_res = score_head(shared)
        res_075 = score_head(shared_075)
        res_050 = score_head(shared_050)

        # Orientation branch: half resolution in, full resolution out.
        orient = self.conv_ori_1(bilinear(shared, scale_factor=0.5))
        orient = self.insnorm_ori_1(orient)
        orient = self.insnorm_ori_2(self.conv_ori_2(orient))
        orient = bilinear(orient, size=(height, width))
        ori_map = L2Norm(orient, dim=1).permute(0, 2, 3, 1)

        stacked = torch.cat(
            [
                full_res,
                bilinear(res_075, size=(height, width)),
                bilinear(res_050, size=(height, width)),
            ],
            dim=1,
        )  # (B, C, H, W)
        channels_last = stacked.permute(0, 2, 3, 1)
        scale_probs = soft_nms_3d(channels_last, ksize=15, com_strength=7.0)
        score_map, scale_map = soft_max_and_argmax_1d(
            input=scale_probs,
            orint_maps=None,
            dim=-1,
            scale_list=self.scale_list,
            keepdim=True,
            com_strength1=self.score_com_strength,
            com_strength2=self.scale_com_strength,
        )
        return score_map, scale_map, ori_map
示例#3
0
def soft_max_and_argmax_1d(input,
                           orint_maps,
                           scale_list,
                           com_strength1,
                           com_strength2,
                           dim=-1,
                           keepdim=True):
    """
    Collapse per-scale pixel probabilities into one score and a soft scale.

    Two sharpened softmaxes are taken over ``dim`` of ``input``:

    * the first (sharpness ``com_strength1``) weights ``input`` itself to
      give the score map, and also weights ``orint_maps`` when they are given;
    * the second (sharpness ``com_strength2``) weights ``scale_list`` to give
      a soft-argmax scale map.

    :param input: per-scale probabilities, e.g. scale_probs (B, H, W, S)
    :param orint_maps: optional orientation maps with one extra trailing axis,
        e.g. (B, H, W, S, 2); pass None to skip the orientation output
    :param scale_list: the S scale values, as a 1-D tensor or any sequence
        accepted by ``torch.as_tensor``
    :param com_strength1: sharpening factor of the score softmax
    :param com_strength2: sharpening factor of the scale softmax
    :param dim: axis of ``input`` that indexes the scales (negative or not)
    :param keepdim: keep the reduced axis with size 1
    :return: (score_map, scale_map), or (score_map, scale_map, orint_map) when
        ``orint_maps`` is given. With the defaults the shapes are
        (B, H, W, 1), (B, H, W, 1) and (B, H, W, 1, 2).
    """
    # Subtract the max before exponentiating for numerical stability; the
    # shift cancels out in the normalisation.
    shifted = input - torch.max(input, dim=dim, keepdim=True)[0]

    inputs_exp1 = torch.exp(com_strength1 * shifted)
    input_softmax1 = inputs_exp1 / (
        inputs_exp1.sum(dim=dim, keepdim=True) + 1e-8)

    inputs_exp2 = torch.exp(com_strength2 * shifted)
    input_softmax2 = inputs_exp2 / (
        inputs_exp2.sum(dim=dim, keepdim=True) + 1e-8)

    score_map = torch.sum(input * input_softmax1, dim=dim, keepdim=keepdim)

    # Shape the scale values so they broadcast along ``dim`` only.
    scale_list_shape = [1] * input.dim()
    scale_list_shape[dim] = -1
    scales = torch.as_tensor(scale_list, device=input_softmax2.device)
    scales = scales.view(scale_list_shape)
    scale_map = torch.sum(scales * input_softmax2, dim=dim, keepdim=keepdim)

    if orint_maps is None:
        return score_map, scale_map

    # orint_maps carries one extra trailing axis. A negative ``dim`` must
    # therefore move one step further left, while a non-negative ``dim``
    # already names the same scale axis.
    orint_dim = dim - 1 if dim < 0 else dim
    orint_map = torch.sum(orint_maps * input_softmax1.unsqueeze(-1),
                          dim=orint_dim,
                          keepdim=keepdim)  # (B, H, W, 1, 2) by default
    orint_map = L2Norm(orint_map, dim=-1)
    return score_map, scale_map, orint_map
示例#4
0
    def forward(self, photos):
        """Three-branch score network with a separate conv stack per branch.

        A shared trunk (conv1..conv3, each with instance norm and leaky ReLU)
        is resampled to 0.75x and 0.5x. Each resolution then goes through its
        own three conv/norm/leaky-ReLU stages. The branch outputs are resized
        back to (H, W), stacked on channels, soft-NMS'd and reduced to a score
        map and a soft scale map. The orientation map is ``conv_ori`` on the
        trunk features, L2Norm'd over dim 1 and made channels-last.

        :param photos: batch whose dims 2 and 3 are H and W
            (presumably (B, C, H, W) -- TODO confirm)
        :return: (score_map, scale_map, ori_map)
        """
        H, W = photos.shape[2:4]

        def resize(tensor, **target):
            # align_corners=False is PyTorch's bilinear default, so results are
            # unchanged. Stating it explicitly avoids the implicit-default
            # UserWarning that some PyTorch versions emit.
            return F.interpolate(tensor, mode='bilinear', align_corners=False,
                                 **target)

        feature_map = F.leaky_relu(self.insnorm1(self.conv1(photos)))
        feature_map = F.leaky_relu(self.insnorm2(self.conv2(feature_map)))
        feature_map = F.leaky_relu(self.insnorm3(self.conv3(feature_map)))

        feature_map_2 = resize(feature_map, scale_factor=0.75)
        feature_map_3 = resize(feature_map, scale_factor=0.5)

        feature_map_1 = F.leaky_relu(self.insnorm_1_1(self.conv_1_1(feature_map)))
        feature_map_1 = F.leaky_relu(self.insnorm_1_2(self.conv_1_2(feature_map_1)))
        feature_map_1 = F.leaky_relu(self.insnorm_1_3(self.conv_1_3(feature_map_1)))

        feature_map_2 = F.leaky_relu(self.insnorm_2_1(self.conv_2_1(feature_map_2)))
        feature_map_2 = F.leaky_relu(self.insnorm_2_2(self.conv_2_2(feature_map_2)))
        feature_map_2 = F.leaky_relu(self.insnorm_2_3(self.conv_2_3(feature_map_2)))

        feature_map_3 = F.leaky_relu(self.insnorm_3_1(self.conv_3_1(feature_map_3)))
        feature_map_3 = F.leaky_relu(self.insnorm_3_2(self.conv_3_2(feature_map_3)))
        feature_map_3 = F.leaky_relu(self.insnorm_3_3(self.conv_3_3(feature_map_3)))

        score_maps = torch.cat(
            (
                feature_map_1,
                resize(feature_map_2, size=(H, W)),
                resize(feature_map_3, size=(H, W)),
            ),
            1,
        )  # (B, C, H, W)
        score_maps = score_maps.permute(0, 2, 3, 1)

        scale_probs = soft_nms_3d(score_maps, ksize=15, com_strength=3.0)

        score_map, scale_map = soft_max_and_argmax_1d(
            input=scale_probs,
            orint_maps=None,
            dim=-1,
            scale_list=self.scale_list,
            keepdim=True,
            com_strength1=self.score_com_strength,
            com_strength2=self.scale_com_strength,
        )
        ori_map = L2Norm(self.conv_ori(feature_map), dim=1).permute(0, 2, 3, 1)
        return score_map, scale_map, ori_map
示例#5
0
    def forward(self, photos):
        """Return the stacked multi-resolution score features and an orientation map.

        No soft-NMS or scale reduction happens here. ``score_maps`` is
        returned channels-first, with the 1x, 0.75x and 0.5x head outputs
        (all resized to (H, W)) concatenated on dim 1.

        :param photos: batch whose dims 2 and 3 are H and W
            (presumably (B, C, H, W) -- TODO confirm)
        :return: (score_maps, ori_map). ori_map is predicted at 0.5x, resized
            to (H, W), L2Norm'd over dim 1 and permuted to channels-last.
        """
        H, W = photos.shape[2:4]

        def resize(tensor, **target):
            # align_corners=False is PyTorch's bilinear default, so results are
            # unchanged. Stating it explicitly avoids the implicit-default
            # UserWarning that some PyTorch versions emit.
            return F.interpolate(tensor, mode='bilinear', align_corners=False,
                                 **target)

        feature_map = self.features(photos)
        feature_map_2 = resize(feature_map, scale_factor=0.75)
        feature_map_3 = resize(feature_map, scale_factor=0.5)

        # Orientation uses the 0.5x features before they are overwritten below.
        ori_map = resize(self.insnorm_ori(self.conv_ori(feature_map_3)),
                         size=(H, W))
        ori_map = L2Norm(ori_map, dim=1).permute(0, 2, 3, 1)

        # One shared head (same weights) for every resolution.
        feature_map_1 = F.leaky_relu(self.insnorm_1_1(self.conv_1_1(feature_map)))
        feature_map_2 = F.leaky_relu(self.insnorm_1_1(self.conv_1_1(feature_map_2)))
        feature_map_3 = F.leaky_relu(self.insnorm_1_1(self.conv_1_1(feature_map_3)))

        score_maps = torch.cat(
            (
                feature_map_1,
                resize(feature_map_2, size=(H, W)),
                resize(feature_map_3, size=(H, W)),
            ),
            1,
        )  # (B, C, H, W)
        return score_maps, ori_map
示例#6
0
    def forward(self, photos):
        """Single-resolution variant: score map, soft scale map and orientation.

        One conv/instance-norm/leaky-ReLU head on ``self.features(photos)``
        produces the per-scale score channels. These go through
        ``soft_nms_3d`` and ``soft_max_and_argmax_1d``. Orientation is
        ``conv_ori``/``insnorm_ori`` on the same features, L2Norm'd over dim 1
        and permuted to channels-last.

        :param photos: input batch (presumably (B, C, H, W) -- TODO confirm)
        :return: (score_map, scale_map, ori_map)
        """
        feature_map = self.features(photos)

        # Only the full-resolution branch is used, so there is nothing to
        # concatenate: the head output is the score stack itself.
        score_maps = F.leaky_relu(self.insnorm(self.conv(feature_map)))  # (B, C, H, W)

        ori_map = L2Norm(self.insnorm_ori(self.conv_ori(feature_map)),
                         dim=1).permute(0, 2, 3, 1)

        score_maps = score_maps.permute(0, 2, 3, 1)

        scale_probs = soft_nms_3d(score_maps, ksize=15, com_strength=7.0)
        score_map, scale_map = soft_max_and_argmax_1d(
            input=scale_probs,
            orint_maps=None,
            dim=-1,
            scale_list=self.scale_list,
            keepdim=True,
            com_strength1=self.score_com_strength,
            com_strength2=self.scale_com_strength,
        )

        return score_map, scale_map, ori_map
示例#7
0
    def gt_scale_orin(im2_scale, im2_orin, homo12, homo21):
        """Map image-2 scale and orientation maps through the homographies.

        NOTE(review): there is no ``self``/``cls`` parameter. Presumably this
        is a ``@staticmethod`` whose decorator lies outside the visible chunk;
        confirm.

        For every pixel of image 2:
          * scale: four probe points ``im2_scale // 2`` pixels above, below,
            right and left of the pixel are mapped with ``homo21`` via
            ``transXYZ_2_to_1`` and converted to integers. ``MSD`` (defined
            elsewhere) is taken against the mapped centre, and the four
            results are averaged.
          * orientation: the point ``im2_scale * (cos, sin)`` away is mapped
            with ``homo21``. The normalised direction from the mapped centre
            gives (cos_w, sin_w).
        Both per-pixel maps are then sampled at the grid positions produced by
        ``transXYZ_2_to_1(..., homo12)``, rounded and clamped to the image
        bounds.

        :param im2_scale: scale map, unpacked as (B, H, W, C); C is unused
        :param im2_orin: orientation map whose last axis holds (cos, sin);
            presumably (B, H, W, 1, 2) -- TODO confirm
        :param homo12: homography used to build the final lookup grid
        :param homo21: homography used to map the scale/orientation probes
        :return: (im1w_scale, im1w_orin). im1w_scale is viewed to
            ``im2_scale.size()``. im1w_orin is L2Norm'd over its last axis and
            moved to ``im2_orin.device``.
        """
        B, H, W, C = im2_scale.size()
        # NOTE(review): .squeeze() drops EVERY size-1 axis, including B when
        # B == 1 (and H/W when 1). im2_cos/im2_sin are only (B, H, W, 1) when
        # none of those is 1.
        im2_cos, im2_sin = im2_orin.squeeze().chunk(chunks=2,
                                                    dim=-1)  # (B, H, W, 1)
        # im2_tan = im2_sin / im2_cos

        # each centX, centY, centZ is (B, H, W, 1)
        centX, centY, centZ = imgBatchXYZ(B, H,
                                          W).to(im2_scale.device).chunk(3,
                                                                        dim=3)
        """get im1w scale maps"""
        # Floor division: the probe offset is floor(scale / 2).
        half_scale = im2_scale // 2
        centXYZ = torch.cat((centX, centY, centZ), dim=3)  # (B, H, W, 3)
        upXYZ = torch.cat((centX, centY - half_scale, centZ), dim=3)
        bottomXYZ = torch.cat((centX, centY + half_scale, centZ), dim=3)
        rightXYZ = torch.cat((centX + half_scale, centY, centZ), dim=3)
        leftXYZ = torch.cat((centX - half_scale, centY, centZ), dim=3)

        centXYw = transXYZ_2_to_1(centXYZ, homo21)  # (B, H, W, 2) (x, y)
        # Float centre coords, kept for the orientation offsets below.
        centXw, centYw = centXYw.chunk(chunks=2, dim=-1)  # (B, H, W, 1)
        # .long() truncates toward zero (no .round() here, unlike the homo12
        # lookup grid further down).
        centXYw = centXYw.long()
        upXYw = transXYZ_2_to_1(upXYZ, homo21).long()
        rightXYw = transXYZ_2_to_1(rightXYZ, homo21).long()
        bottomXYw = transXYZ_2_to_1(bottomXYZ, homo21).long()
        leftXYw = transXYZ_2_to_1(leftXYZ, homo21).long()

        upScale = MSD(upXYw, centXYw)
        rightScale = MSD(rightXYw, centXYw)
        bottomScale = MSD(bottomXYw, centXYw)
        leftScale = MSD(leftXYw, centXYw)
        centScale = (upScale + rightScale + bottomScale +
                     leftScale) / 4  # (B, H, W, 1)
        """get im1w orintation maps"""
        offset_x, offset_y = im2_scale * im2_cos, im2_scale * im2_sin  # (B, H, W, 1)
        offsetXYZ = torch.cat((centX + offset_x, centY + offset_y, centZ),
                              dim=3)
        offsetXYw = transXYZ_2_to_1(offsetXYZ, homo21)  # (B, H, W, 2) (x, y)
        offsetXw, offsetYw = offsetXYw.chunk(chunks=2, dim=-1)  # (B, H, W, 1)
        offset_ww, offset_hw = offsetXw - centXw, offsetYw - centYw  # (B, H, W, 1)
        # 1e-8 guards sqrt/division against a zero-length offset.
        offset_rw = (offset_ww**2 + offset_hw**2 + 1e-8).sqrt()
        # tan = offset_hw / (offset_ww + 1e-8)  # (B, H, W, 1)
        cos_w = offset_ww / (offset_rw + 1e-8)  # (B, H, W, 1)
        sin_w = offset_hw / (offset_rw + 1e-8)  # (B, H, W, 1)
        # atan_w = np.arctan(tan.cpu().detach())  # (B, H, W, 1)

        # Lookup grid: map the pixel grid with homo12, round, clamp in-bounds.
        map_xy_2_to_1 = transXYZ_2_to_1(centXYZ,
                                        homo12).round().long()  # (B, H, W, 2)
        x, y = map_xy_2_to_1.chunk(2, dim=3)  # each x and y is (B, H, W, 1)
        x = x.clamp(min=0, max=W - 1)
        y = y.clamp(min=0, max=H - 1)

        # (B, H, W, 1): per-batch gather at (y, x), flattened over H * W.
        im1w_scale = centScale[torch.arange(B)[:, None].repeat(1, H * W),
                               y.view(B, -1),
                               x.view(B, -1)].view(im2_scale.size())

        # NOTE(review): im1w_cos/im1w_sin take im2_cos's (possibly squeezed)
        # shape, not necessarily (B, H, W, 1).
        im1w_cos = cos_w[torch.arange(B)[:, None].repeat(1, H * W),
                         y.view(B, -1),
                         x.view(B, -1)].view(im2_cos.size())
        im1w_sin = sin_w[torch.arange(B)[:, None].repeat(1, H * W),
                         y.view(B, -1),
                         x.view(B, -1)].view(im2_sin.size())
        # NOTE(review): [:, None] inserts the new axis at position 1. With
        # (B, H, W, 1) inputs the result is (B, 1, H, W, 2), not
        # (B, H, W, 1, 2); confirm what callers expect.
        im1w_orin = torch.cat((im1w_cos[:, None], im1w_sin[:, None]), dim=-1)
        im1w_orin = L2Norm(im1w_orin, dim=-1).to(im2_orin.device)

        return im1w_scale, im1w_orin
    def forward(self, photos):
        """Ten stacked stages, one per scale 3, 5, ..., 21.

        Stage k (1-based) applies ``conv{k}`` + ``insnorm{k}`` + leaky ReLU to
        the running trunk. From that stage's own output, before any residual
        sum, it emits:
          * a score map via ``conv_s{s}``/``insnorm_s{s}``;
          * an orientation map via ``conv_o{s}``, L2Norm'd over dim 1.
        From stage 2 on, the stage output plus its input becomes the next
        stage's input; the last stage feeds nothing further.

        The score maps are stacked on the last axis and the orientation maps
        on the second-to-last axis. Both go through ``soft_nms_3d`` and
        ``soft_max_and_argmax_1d``.

        :param photos: input batch (presumably (B, C, H, W) -- TODO confirm)
        :return: (score_map, scale_map, orint_map)
        """
        scales = tuple(range(3, 23, 2))  # 3, 5, ..., 21
        last_stage = len(scales)

        per_scale_scores = []
        per_scale_orients = []
        trunk = photos
        for stage, s in enumerate(scales, start=1):
            conv = getattr(self, 'conv%d' % stage)
            norm = getattr(self, 'insnorm%d' % stage)
            feats = F.leaky_relu(norm(conv(trunk)))

            score_conv = getattr(self, 'conv_s%d' % s)
            score_norm = getattr(self, 'insnorm_s%d' % s)
            per_scale_scores.append(
                score_norm(score_conv(feats)).permute(0, 2, 3, 1))

            orient_conv = getattr(self, 'conv_o%d' % s)
            per_scale_orients.append(
                L2Norm(orient_conv(feats), dim=1).permute(0, 2, 3, 1).unsqueeze(-2))

            # Residual link: the first stage has none; the last is not reused.
            if stage == 1:
                trunk = feats
            elif stage < last_stage:
                trunk = feats + trunk

        score_maps = torch.cat(per_scale_scores, -1)  # (B, H, W, C)
        orint_maps = torch.cat(per_scale_orients, -2)  # (B, H, W, 10, 2)

        # Per-pixel probability in each scale.
        scale_probs = soft_nms_3d(score_maps, ksize=15, com_strength=3.0)

        # Collapse the scale axis into a score, a soft scale and an orientation.
        score_map, scale_map, orint_map = soft_max_and_argmax_1d(
            input=scale_probs,
            orint_maps=orint_maps,
            dim=-1,
            scale_list=self.scale_list,
            keepdim=True,
            com_strength1=self.score_com_strength,
            com_strength2=self.scale_com_strength,
        )

        return score_map, scale_map, orint_map