def forward(self, photos):
    """Detect keypoint score, scale and orientation maps from a batch of images.

    The backbone (``self.conv`` then ``self.features``) is evaluated once; a
    weight-shared score head is applied at full, half and quarter resolution
    and the coarse responses are resized back to the input size.

    :param photos: input image batch; spatial size is read from dims 2 and 3
        (assumed layout (B, C, H, W) — standard for the convs used here).
    :return: (score_map, scale_map, ori_map) — the first two come from
        soft_nms_3d + soft_max_and_argmax_1d, ori_map is channels-last
        after the final permute.
    """
    height, width = photos.shape[2:4]

    base = self.features(self.conv(photos))

    # Orientation head: two conv + instance-norm stages, then an
    # L2-normalised map laid out channels-last.
    orient = self.insnorm_ori_1(self.conv_ori_1(base))
    orient = self.insnorm_ori_2(self.conv_ori_2(orient))
    ori_map = L2Norm(self.conv_ori_3(orient), dim=1).permute(0, 2, 3, 1)

    def score_head(tensor):
        # Weight-shared scoring stage applied at every pyramid level.
        return F.leaky_relu(self.insnorm_1_1(self.conv_1_1(tensor)))

    scores_full = score_head(base)
    scores_half = score_head(
        F.interpolate(base, scale_factor=0.5, mode='bilinear'))
    scores_quarter = score_head(
        F.interpolate(base, scale_factor=0.25, mode='bilinear'))

    score_maps = torch.cat(
        (
            scores_full,
            F.interpolate(scores_half, size=(height, width), mode='bilinear'),
            F.interpolate(scores_quarter, size=(height, width), mode='bilinear'),
        ),
        1,
    ).permute(0, 2, 3, 1)  # (B, H, W, C)

    # Per-pixel probability over the scale channels, then soft argmax.
    scale_probs = soft_nms_3d(score_maps, ksize=15, com_strength=7.0)
    score_map, scale_map = soft_max_and_argmax_1d(
        input=scale_probs,
        orint_maps=None,
        dim=-1,
        scale_list=self.scale_list,
        keepdim=True,
        com_strength1=self.score_com_strength,
        com_strength2=self.scale_com_strength,
    )
    return score_map, scale_map, ori_map
def forward(self, photos):
    """Detect score, scale and orientation maps (align_corners variant).

    A weight-shared score head runs at 1x, 0.75x and 0.5x of the backbone
    resolution; the orientation head runs at 0.5x and is upsampled back to
    the input size. All resizes use bilinear interpolation with
    align_corners=True.

    :param photos: input batch; spatial size read from dims 2 and 3.
    :return: (score_map, scale_map, ori_map); ori_map is channels-last.
    """
    height, width = photos.shape[2:4]
    base = self.features(photos)

    def resize(tensor, **kwargs):
        # All interpolation in this head is bilinear with aligned corners.
        return F.interpolate(tensor, mode='bilinear', align_corners=True,
                             **kwargs)

    def score_head(tensor):
        # Weight-shared scoring stage applied at every pyramid level.
        return F.leaky_relu(self.insnorm_1_1(self.conv_1_1(tensor)))

    scores_full = score_head(base)
    scores_075 = score_head(resize(base, scale_factor=0.75))
    scores_050 = score_head(resize(base, scale_factor=0.5))

    # Orientation head at half resolution, resized back to (H, W) and
    # L2-normalised across the channel dimension.
    orient = self.insnorm_ori_1(self.conv_ori_1(resize(base, scale_factor=0.5)))
    orient = resize(self.insnorm_ori_2(self.conv_ori_2(orient)),
                    size=(height, width))
    ori_map = L2Norm(orient, dim=1).permute(0, 2, 3, 1)

    score_maps = torch.cat(
        (
            scores_full,
            resize(scores_075, size=(height, width)),
            resize(scores_050, size=(height, width)),
        ),
        1,
    ).permute(0, 2, 3, 1)  # (B, H, W, C)

    scale_probs = soft_nms_3d(score_maps, ksize=15, com_strength=7.0)
    score_map, scale_map = soft_max_and_argmax_1d(
        input=scale_probs,
        orint_maps=None,
        dim=-1,
        scale_list=self.scale_list,
        keepdim=True,
        com_strength1=self.score_com_strength,
        com_strength2=self.scale_com_strength,
    )
    return score_map, scale_map, ori_map
def soft_max_and_argmax_1d(input,
                           orint_maps,
                           scale_list,
                           com_strength1,
                           com_strength2,
                           dim=-1,
                           keepdim=True):
    """Soft-max / soft-argmax reduction over the scale dimension.

    Given per-pixel probabilities in each scale, compute the final per-pixel
    score (a softmax-weighted average of the input itself) and the expected
    scale value (a softmax-weighted average of ``scale_list``). Two separate
    sharpness factors are used so score and scale can be peaked differently.

    :param input: scale probabilities, per the original docstring
        (B, H, W, 10) when ``dim`` is -1.
    :param orint_maps: optional orientation vectors, (B, H, W, 10, 2);
        when given, a softmax-weighted, re-normalised orientation is also
        returned.
    :param scale_list: 1-D tensor of scale values, one per entry along
        ``dim``.
    :param com_strength1: sharpness (temperature inverse) for the score
        softmax.
    :param com_strength2: sharpness for the scale softmax.
    :param dim: the scale dimension to reduce over.
    :param keepdim: keep the reduced dimension as size 1.
    :return: (score_map, scale_map) or (score_map, scale_map, orint_map)
        when ``orint_maps`` is not None.
    """
    # Hoisted: the original computed this max twice. Subtracting the max
    # before exp() is the standard numerically-stable softmax shift.
    max_vals = torch.max(input, dim=dim, keepdim=True)[0]

    inputs_exp1 = torch.exp(com_strength1 * (input - max_vals))
    input_softmax1 = inputs_exp1 / (
        inputs_exp1.sum(dim=dim, keepdim=True) + 1e-8)  # (B, H, W, 10)

    inputs_exp2 = torch.exp(com_strength2 * (input - max_vals))
    input_softmax2 = inputs_exp2 / (
        inputs_exp2.sum(dim=dim, keepdim=True) + 1e-8)  # (B, H, W, 10)

    # Soft maximum of the scores themselves.
    score_map = torch.sum(input * input_softmax1, dim=dim, keepdim=keepdim)

    # Broadcast scale_list along every dimension except `dim`, then take
    # the expectation under the second softmax.
    scale_list_shape = [1] * len(input.size())
    scale_list_shape[dim] = -1
    scale_list = scale_list.view(scale_list_shape).to(input_softmax2.device)
    scale_map = torch.sum(scale_list * input_softmax2,
                          dim=dim,
                          keepdim=keepdim)

    if orint_maps is not None:
        # Weight the per-scale orientation vectors by the score softmax
        # (dim - 1 because orint_maps carries a trailing 2-vector axis),
        # then re-normalise to unit length.
        orint_map = torch.sum(orint_maps * input_softmax1.unsqueeze(-1),
                              dim=dim - 1,
                              keepdim=keepdim)  # (B, H, W, 1, 2)
        orint_map = L2Norm(orint_map, dim=-1)
        return score_map, scale_map, orint_map
    return score_map, scale_map
def forward(self, photos):
    """Detect score, scale and orientation maps with independent branches.

    Three shared backbone stages feed three *separately parameterised*
    score branches, one per pyramid level (1x, 0.75x, 0.5x) — unlike the
    weight-shared variants elsewhere in this file.

    :param photos: input batch; spatial size read from dims 2 and 3.
    :return: (score_map, scale_map, ori_map); ori_map is channels-last.
    """
    height, width = photos.shape[2:4]

    def run(tensor, stages):
        # Apply (conv, instance-norm) pairs followed by leaky ReLU, in order.
        for conv, norm in stages:
            tensor = F.leaky_relu(norm(conv(tensor)))
        return tensor

    base = run(photos, ((self.conv1, self.insnorm1),
                        (self.conv2, self.insnorm2),
                        (self.conv3, self.insnorm3)))

    level_075 = torch.nn.Upsample(scale_factor=0.75, mode='bilinear')(base)
    level_050 = torch.nn.Upsample(scale_factor=0.5, mode='bilinear')(base)

    branch_1 = run(base, ((self.conv_1_1, self.insnorm_1_1),
                          (self.conv_1_2, self.insnorm_1_2),
                          (self.conv_1_3, self.insnorm_1_3)))
    branch_2 = run(level_075, ((self.conv_2_1, self.insnorm_2_1),
                               (self.conv_2_2, self.insnorm_2_2),
                               (self.conv_2_3, self.insnorm_2_3)))
    branch_3 = run(level_050, ((self.conv_3_1, self.insnorm_3_1),
                               (self.conv_3_2, self.insnorm_3_2),
                               (self.conv_3_3, self.insnorm_3_3)))

    score_maps = torch.cat(
        (
            branch_1,
            torch.nn.Upsample(size=(height, width), mode='bilinear')(branch_2),
            torch.nn.Upsample(size=(height, width), mode='bilinear')(branch_3),
        ),
        1,
    ).permute(0, 2, 3, 1)  # (B, H, W, C) after the permute

    scale_probs = soft_nms_3d(score_maps, ksize=15, com_strength=3.0)
    score_map, scale_map = soft_max_and_argmax_1d(
        input=scale_probs,
        orint_maps=None,
        dim=-1,
        scale_list=self.scale_list,
        keepdim=True,
        com_strength1=self.score_com_strength,
        com_strength2=self.scale_com_strength,
    )
    # Orientation head: unit-norm vector per pixel from the shared backbone.
    ori_map = L2Norm(self.conv_ori(base), dim=1).permute(0, 2, 3, 1)
    return score_map, scale_map, ori_map
def forward(self, photos):
    """Compute raw multi-scale score maps and an orientation map.

    Unlike the sibling heads in this file, this variant skips the soft-NMS /
    soft-argmax post-processing and returns the concatenated score maps in
    (B, C, H, W) layout directly.

    :param photos: input batch; spatial size read from dims 2 and 3.
    :return: (score_maps, ori_map); ori_map is channels-last.
    """
    height, width = photos.shape[2:4]
    base = self.features(photos)

    # Two downsampled pyramid levels of the shared backbone features.
    level_075 = torch.nn.Upsample(scale_factor=0.75, mode='bilinear')(base)
    level_050 = torch.nn.Upsample(scale_factor=0.5, mode='bilinear')(base)

    # Orientation head runs on the coarsest level, resized back to (H, W).
    ori_map = L2Norm(
        torch.nn.Upsample(size=(height, width), mode='bilinear')(
            self.insnorm_ori(self.conv_ori(level_050))),
        dim=1).permute(0, 2, 3, 1)

    def score_head(tensor):
        # Weight-shared scoring stage applied at every pyramid level.
        return F.leaky_relu(self.insnorm_1_1(self.conv_1_1(tensor)))

    score_maps = torch.cat(
        (
            score_head(base),
            torch.nn.Upsample(size=(height, width),
                              mode='bilinear')(score_head(level_075)),
            torch.nn.Upsample(size=(height, width),
                              mode='bilinear')(score_head(level_050)),
        ),
        1,
    )  # (B, C, H, W)
    return score_maps, ori_map
def forward(self, photos):
    """Single-scale detector head.

    The multi-scale pyramid used by the sibling variants is disabled here:
    one score branch and one orientation branch run on the backbone output,
    and the soft-NMS / soft-argmax stage operates over that single scale.

    :param photos: input image batch.
    :return: (score_map, scale_map, ori_map); ori_map is channels-last.
    """
    base = self.features(photos)

    response = F.leaky_relu(self.insnorm(self.conv(base)))
    ori_map = L2Norm(self.insnorm_ori(self.conv_ori(base)),
                     dim=1).permute(0, 2, 3, 1)

    # Single-element cat kept for structural parity with the multi-scale
    # variants (it simply copies the tensor).
    score_maps = torch.cat((response,), 1)  # (B, C, H, W)
    score_maps = score_maps.permute(0, 2, 3, 1)

    scale_probs = soft_nms_3d(score_maps, ksize=15, com_strength=7.0)
    score_map, scale_map = soft_max_and_argmax_1d(
        input=scale_probs,
        orint_maps=None,
        dim=-1,
        scale_list=self.scale_list,
        keepdim=True,
        com_strength1=self.score_com_strength,
        com_strength2=self.scale_com_strength,
    )
    return score_map, scale_map, ori_map
def gt_scale_orin(im2_scale, im2_orin, homo12, homo21):
    """Warp image-2 scale and orientation maps into image 1's frame.

    Uses the homographies to measure how a keypoint's support region and
    orientation vector deform under the warp, then gathers the warped
    values back onto image 1's pixel grid.

    :param im2_scale: per-pixel scale map, (B, H, W, C) per the size() unpack.
    :param im2_orin: per-pixel orientation as a (cos, sin) pair — it is
        squeezed and chunked into two (B, H, W, 1) halves below.
    :param homo12: homography used to map image-2 centers into image 1
        (consumed by transXYZ_2_to_1; exact direction convention lives in
        that helper — NOTE(review): confirm against its definition).
    :param homo21: homography used for the inverse-direction transforms.
    :return: (im1w_scale, im1w_orin) — warped scale map (same size as
        im2_scale) and L2-normalised warped orientation.
    """
    B, H, W, C = im2_scale.size()
    im2_cos, im2_sin = im2_orin.squeeze().chunk(chunks=2, dim=-1)  # (B, H, W, 1)
    # im2_tan = im2_sin / im2_cos
    # each centX, centY, centZ is (B, H, W, 1)
    centX, centY, centZ = imgBatchXYZ(B, H,
                                      W).to(im2_scale.device).chunk(3, dim=3)

    """get im1w scale maps"""
    # Probe the warp with 4 points half a scale away from each center; the
    # mean of their warped distances to the warped center estimates the
    # scale after the homography.
    half_scale = im2_scale // 2
    centXYZ = torch.cat((centX, centY, centZ), dim=3)  # (B, H, W, 3)
    upXYZ = torch.cat((centX, centY - half_scale, centZ), dim=3)
    bottomXYZ = torch.cat((centX, centY + half_scale, centZ), dim=3)
    rightXYZ = torch.cat((centX + half_scale, centY, centZ), dim=3)
    leftXYZ = torch.cat((centX - half_scale, centY, centZ), dim=3)

    centXYw = transXYZ_2_to_1(centXYZ, homo21)  # (B, H, W, 2) (x, y)
    # Keep a float copy of the warped center for the orientation offsets
    # below, then truncate everything to integer pixel coordinates.
    centXw, centYw = centXYw.chunk(chunks=2, dim=-1)  # (B, H, W, 1)
    centXYw = centXYw.long()
    upXYw = transXYZ_2_to_1(upXYZ, homo21).long()
    rightXYw = transXYZ_2_to_1(rightXYZ, homo21).long()
    bottomXYw = transXYZ_2_to_1(bottomXYZ, homo21).long()
    leftXYw = transXYZ_2_to_1(leftXYZ, homo21).long()

    # MSD presumably measures displacement between the warped probe and the
    # warped center (defined elsewhere in this file — verify).
    upScale = MSD(upXYw, centXYw)
    rightScale = MSD(rightXYw, centXYw)
    bottomScale = MSD(bottomXYw, centXYw)
    leftScale = MSD(leftXYw, centXYw)
    centScale = (upScale + rightScale + bottomScale +
                 leftScale) / 4  # (B, H, W, 1)

    """get im1w orintation maps"""
    # Warp the tip of each orientation vector (scaled by im2_scale) and
    # re-derive (cos, sin) from the warped offset.
    offset_x, offset_y = im2_scale * im2_cos, im2_scale * im2_sin  # (B, H, W, 1)
    offsetXYZ = torch.cat((centX + offset_x, centY + offset_y, centZ), dim=3)
    offsetXYw = transXYZ_2_to_1(offsetXYZ, homo21)  # (B, H, W, 2) (x, y)
    offsetXw, offsetYw = offsetXYw.chunk(chunks=2, dim=-1)  # (B, H, W, 1)
    offset_ww, offset_hw = offsetXw - centXw, offsetYw - centYw  # (B, H, W, 1)
    offset_rw = (offset_ww**2 + offset_hw**2 + 1e-8).sqrt()
    # tan = offset_hw / (offset_ww + 1e-8)  # (B, H, W, 1)
    cos_w = offset_ww / (offset_rw + 1e-8)  # (B, H, W, 1)
    sin_w = offset_hw / (offset_rw + 1e-8)  # (B, H, W, 1)
    # atan_w = np.arctan(tan.cpu().detach())  # (B, H, W, 1)

    # get left scale by transXYZ_2_to_1
    map_xy_2_to_1 = transXYZ_2_to_1(centXYZ, homo12).round().long()  # (B, H, W, 2)
    x, y = map_xy_2_to_1.chunk(2, dim=3)  # each x and y is (B, H, W, 1)
    # Clamp out-of-frame coordinates to the image border before gathering.
    x = x.clamp(min=0, max=W - 1)
    y = y.clamp(min=0, max=H - 1)  # (B, H, W, 1)

    # Advanced-indexing gather: for every image-1 pixel, pick the warped
    # value at its mapped (y, x) location, batch-wise.
    im1w_scale = centScale[torch.arange(B)[:, None].repeat(1, H * W),
                           y.view(B, -1),
                           x.view(B, -1)].view(im2_scale.size())

    # (B, H, W, 1, 2)
    im1w_cos = cos_w[torch.arange(B)[:, None].repeat(1, H * W),
                     y.view(B, -1),
                     x.view(B, -1)].view(im2_cos.size())
    im1w_sin = sin_w[torch.arange(B)[:, None].repeat(1, H * W),
                     y.view(B, -1),
                     x.view(B, -1)].view(im2_sin.size())
    im1w_orin = torch.cat((im1w_cos[:, None], im1w_sin[:, None]), dim=-1)
    im1w_orin = L2Norm(im1w_orin, dim=-1).to(im2_orin.device)

    return im1w_scale, im1w_orin
def forward(self, photos):
    """Extract score, scale and orientation maps over scales 3 to 21.

    Ten conv stages (``conv1``..``conv10``) form a chain with residual
    accumulation; each stage emits a score map (``conv_s{3..21}`` +
    instance norm) and a unit-norm orientation map (``conv_o{3..21}``).
    The ten per-scale maps are stacked and reduced by soft-NMS plus a
    soft argmax over the scale dimension.

    :param photos: input image batch.
    :return: (score_map, scale_map, orint_map) from soft_max_and_argmax_1d.
    """
    score_map_list = []
    orint_map_list = []

    stage_input = photos
    carry = None  # residual accumulator feeding the next stage
    for idx, scale in enumerate(range(3, 22, 2), start=1):
        conv = getattr(self, 'conv%d' % idx)
        norm = getattr(self, 'insnorm%d' % idx)
        feat = F.leaky_relu(norm(conv(stage_input)))

        # Per-scale heads operate on the *pre-residual* features, exactly
        # as in the unrolled original.
        score = getattr(self, 'insnorm_s%d' % scale)(
            getattr(self, 'conv_s%d' % scale)(feat)).permute(0, 2, 3, 1)
        orint = L2Norm(getattr(self, 'conv_o%d' % scale)(feat),
                       dim=1).permute(0, 2, 3, 1).unsqueeze(-2)
        score_map_list.append(score)
        orint_map_list.append(orint)

        # Residual carry: from the second stage on, the next stage consumes
        # the sum of this stage's features and the accumulated carry. (The
        # sum computed after the last stage is never consumed, matching the
        # unrolled code's outputs.)
        stage_input = feat if carry is None else feat + carry
        carry = stage_input

    score_maps = torch.cat(score_map_list, -1)  # (B, H, W, C)
    orint_maps = torch.cat(orint_map_list, -2)  # (B, H, W, 10, 2)

    # get each pixel probability in all scale
    scale_probs = soft_nms_3d(score_maps, ksize=15, com_strength=3.0)

    # get each pixel probability summary from all scale space and
    # correspond scale value
    score_map, scale_map, orint_map = soft_max_and_argmax_1d(
        input=scale_probs,
        orint_maps=orint_maps,
        dim=-1,
        scale_list=self.scale_list,
        keepdim=True,
        com_strength1=self.score_com_strength,
        com_strength2=self.scale_com_strength,
    )
    return score_map, scale_map, orint_map