def extract_patches_from_pyr(self, dLAFs, PS = 41):
     """Sample PS x PS patches for dLAFs from self.scale_pyr.

     The pyramid/level for every LAF is chosen with
     get_pyramid_and_level_index_for_LAFs (using self.sigmas and
     self.pix_dists). The LAFs are then normalized against the first
     pyramid level before sampling.

     Args:
         dLAFs: LAFs to sample; presumably (N, 2, 3) in pixel units -- TODO confirm.
         PS: patch side length, used both for level selection and extraction.

     Returns:
         The result of extract_patches_from_pyramid_with_inv_index.
     """
     pyramid = self.scale_pyr
     pyr_ids, lvl_ids = get_pyramid_and_level_index_for_LAFs(
         dLAFs, self.sigmas, self.pix_dists, PS)
     inverse_index = get_inverted_pyr_index(pyramid, pyr_ids, lvl_ids)
     first_level = pyramid[0][0]
     # size(3) / size(2) are the normalization extents; assumes an NCHW
     # tensor, i.e. width then height -- TODO confirm.
     normalized_LAFs = normalizeLAFs(dLAFs, first_level.size(3), first_level.size(2))
     return extract_patches_from_pyramid_with_inv_index(
         pyramid, inverse_index, normalized_LAFs, PS = PS)
 def getOrientation(self, LAFs, final_pyr_idxs, final_level_idxs):
     """Rotate LAFs by the angles predicted by self.OriNet.

     Patches of size self.OriNet.PS are sampled from self.scale_pyr at the
     given pyramid/level indices. OriNet predicts angles for them, and the
     2x2 part of every LAF is right-multiplied by angles2A(angles).

     Args:
         LAFs: LAFs; columns 0:2 are the 2x2 part, columns 2: are kept as is.
         final_pyr_idxs, final_level_idxs: pyramid / level index per LAF.

     Returns:
         The rotated LAFs.
     """
     pyr_inv_idxs = get_inverted_pyr_index(self.scale_pyr, final_pyr_idxs, final_level_idxs)
     patches_small = extract_patches_from_pyramid_with_inv_index(self.scale_pyr, pyr_inv_idxs, LAFs, PS = self.OriNet.PS)
     max_iters = 1
     ### Detect orientation
     for i in range(max_iters):
         angles = self.OriNet(patches_small)
         # NOTE(review): the rotation is applied on the right (A * R) here but
         # on the left (R * A) in the scale_pyr-argument getOrientation --
         # confirm which convention OriNet was trained with.
         LAFs = torch.cat([torch.bmm(LAFs[:,:,:2], angles2A(angles)), LAFs[:,:,2:]], dim = 2)
         # Re-sample only when a later iteration will consume the patches;
         # i ranges over 0..max_iters-1, so the last index is max_iters - 1.
         if i != max_iters - 1:
             patches_small = extract_patches_from_pyramid_with_inv_index(self.scale_pyr, pyr_inv_idxs, LAFs, PS = self.OriNet.PS)
     return LAFs
 def getAffineShape(self, final_resp, LAFs, final_pyr_idxs, final_level_idxs, num_features = 0):
     pe_time = 0
     affnet_time = 0
     pyr_inv_idxs = get_inverted_pyr_index(self.scale_pyr, final_pyr_idxs, final_level_idxs)
     t = time.time()
     patches_small = extract_patches_from_pyramid_with_inv_index(self.scale_pyr, pyr_inv_idxs, LAFs, PS = self.AffNet.PS)
     pe_time+=time.time() - t
     t = time.time()
     base_A = torch.eye(2).unsqueeze(0).expand(final_pyr_idxs.size(0),2,2)
     if final_resp.is_cuda:
         base_A = base_A.cuda()
     base_A = Variable(base_A)
     is_good = None
     n_patches = patches_small.size(0)
     for i in range(self.num_Baum_iters):
         t = time.time()
         A = batched_forward(self.AffNet, patches_small, 512)
         is_good_current = 1
         affnet_time += time.time() - t
         if is_good is None:
             is_good = is_good_current
         else:
             is_good = is_good * is_good_current
         base_A = torch.bmm(A, base_A); 
         new_LAFs = torch.cat([torch.bmm(base_A,LAFs[:,:,0:2]), LAFs[:,:,2:] ], dim =2)
         #print torch.sqrt(new_LAFs[0,0,0]*new_LAFs[0,1,1] - new_LAFs[0,1,0] *new_LAFs[0,0,1]) * scale_pyr[0][0].size(2)
         if i != self.num_Baum_iters - 1:
             pe_time+=time.time() - t
             t = time.time()
             patches_small =  extract_patches_from_pyramid_with_inv_index(self.scale_pyr, pyr_inv_idxs, new_LAFs, PS = self.AffNet.PS)
             pe_time+= time.time() - t
             l1,l2 = batch_eig2x2(A)      
             ratio1 =  torch.abs(l1 / (l2 + 1e-8))
             converged_mask = (ratio1 <= 1.2) * (ratio1 >= (0.8)) 
     l1,l2 = batch_eig2x2(base_A)
     ratio = torch.abs(l1 / (l2 + 1e-8))
     idxs_mask = ((ratio < 6.0) * (ratio > (1./6.)))# * converged_mask.float()) > 0
     num_survived = idxs_mask.float().sum()
     if (num_features > 0) and (num_survived.data[0] > num_features):
         final_resp =  final_resp * idxs_mask.float() #zero bad points
         final_resp, idxs = torch.topk(final_resp, k = num_features);
     else:
         idxs = torch.nonzero(idxs_mask.data).view(-1).long()
         final_resp = final_resp[idxs]
     final_pyr_idxs = final_pyr_idxs[idxs]
     final_level_idxs = final_level_idxs[idxs]
     base_A = torch.index_select(base_A, 0, idxs)
     LAFs = torch.index_select(LAFs, 0, idxs)
     new_LAFs = torch.cat([torch.bmm(rectifyAffineTransformationUpIsUp(base_A), LAFs[:,:,0:2]),
                            LAFs[:,:,2:]], dim =2)
     print 'affnet_time',affnet_time
     print 'pe_time', pe_time
     return final_resp, new_LAFs, final_pyr_idxs, final_level_idxs  
    def getAffineShape(self,scale_pyr, final_resp, LAFs, final_pyr_idxs, final_level_idxs, num_features = 0, n_iters = 1):
        """Refine LAF shapes with self.AffNet and optionally keep the strongest.

        Runs n_iters rounds. Each round samples self.AffNet.PS patches from
        scale_pyr at the current LAFs and calls self.AffNet, which returns a
        pair whose first item is a batch of 2x2 matrices A. The round then
        accumulates base_A = A * base_A.

        The eigenvalue-ratio filter is currently disabled by the "CHANGE after
        training" line: ratio is forced to 1.0 + 0 * |l1 / (l2 + 1e-8)|.
        Only features whose raw ratio is non-finite fail the (1/6, 6) test,
        because 0 * inf and 0 * nan both give nan.

        Args:
            scale_pyr: scale pyramid passed to the patch-extraction helpers.
            final_resp: per-feature responses.
            LAFs: LAFs; columns 0:2 are the shape part, columns 2: the translation.
            final_pyr_idxs, final_level_idxs: pyramid / level index per feature.
            num_features: if > 0 and more features survive, keep only the
                num_features highest responses (failed features get 0 first).
            n_iters: number of AffNet refinement rounds.

        Returns:
            (final_resp, new_LAFs, final_pyr_idxs, final_level_idxs) for the
            kept features, with new_LAFs[:, :, 0:2] = base_A * LAFs[:, :, 0:2].
        """
        pyr_inv_idxs = get_inverted_pyr_index(scale_pyr, final_pyr_idxs, final_level_idxs)
        patches_small = extract_patches_from_pyramid_with_inv_index(scale_pyr, pyr_inv_idxs, LAFs, PS = self.AffNet.PS)
        base_A = torch.eye(2).unsqueeze(0).expand(final_pyr_idxs.size(0),2,2)
        if final_resp.is_cuda:
            base_A = base_A.cuda()
        base_A = Variable(base_A)
        for i in range(n_iters):
            # The second AffNet output is not used by the current filtering.
            A, _ = self.AffNet(patches_small)
            base_A = torch.bmm(A, base_A)
            # Re-sample only when another round follows. The bound is the
            # loop's own n_iters, not self.num_Baum_iters.
            if i != n_iters - 1:
                new_LAFs = torch.cat([torch.bmm(base_A,LAFs[:,:,0:2]), LAFs[:,:,2:] ], dim =2)
                patches_small = extract_patches_from_pyramid_with_inv_index(scale_pyr, pyr_inv_idxs, new_LAFs, PS = self.AffNet.PS)
        l1,l2 = batch_eig2x2(base_A)
        ratio = 1.0 + 0 * torch.abs(l1 / (l2 + 1e-8)) #CHANGE after training
        idxs_mask = (ratio < 6.0) * (ratio > (1./6.))
        num_survived = idxs_mask.float().sum()
        if (num_features > 0) and (num_survived.data[0] > num_features):
            final_resp =  final_resp * idxs_mask.float() #zero bad points
            final_resp, idxs = torch.topk(final_resp, k = num_features)
        else:
            idxs = torch.nonzero(idxs_mask.data).view(-1).long()
            # An empty index (presumably "no survivors" in old PyTorch's
            # empty-tensor representation) or "everyone survived": skip the
            # selection and keep every feature.
            if (len(idxs.size()) == 0) or (idxs.size(0) == idxs_mask.size(0)):
                idxs = None
        if idxs is not None:
            final_resp = torch.index_select(final_resp, 0, idxs)
            final_pyr_idxs = final_pyr_idxs[idxs]
            final_level_idxs = final_level_idxs[idxs]
            base_A = torch.index_select(base_A, 0, idxs)
            LAFs = torch.index_select(LAFs, 0, idxs)
        new_LAFs = torch.cat([torch.bmm(base_A, LAFs[:,:,0:2]),
                               LAFs[:,:,2:]], dim =2)
        return final_resp, new_LAFs, final_pyr_idxs, final_level_idxs
 def getOrientation(self,scale_pyr, LAFs, final_pyr_idxs, final_level_idxs):
     """Rotate LAFs by the angles predicted by self.OriNet.

     Patches of size self.OriNet.PS are sampled from scale_pyr at the given
     pyramid/level indices. OriNet predicts angles for them, and the 2x2 part
     of every LAF is left-multiplied by angles2A(angles).

     Args:
         scale_pyr: scale pyramid passed to the patch-extraction helpers.
         LAFs: LAFs; columns 0:2 are the 2x2 part, columns 2: are kept as is.
         final_pyr_idxs, final_level_idxs: pyramid / level index per LAF.

     Returns:
         The rotated LAFs.
     """
     pyr_inv_idxs = get_inverted_pyr_index(scale_pyr, final_pyr_idxs, final_level_idxs)
     patches_small = extract_patches_from_pyramid_with_inv_index(scale_pyr, pyr_inv_idxs, LAFs, PS = self.OriNet.PS)
     max_iters = 1
     ### Detect orientation
     for i in range(max_iters):
         angles = self.OriNet(patches_small)
         LAFs = torch.cat([torch.bmm(angles2A(angles), LAFs[:,:,:2]), LAFs[:,:,2:]], dim = 2)
         # Re-sample only when a later iteration will consume the patches;
         # i ranges over 0..max_iters-1, so the last index is max_iters - 1.
         if i != max_iters - 1:
             patches_small = extract_patches_from_pyramid_with_inv_index(scale_pyr, pyr_inv_idxs, LAFs, PS = self.OriNet.PS)
     return LAFs
 def ImproveLAFsEstimation(self,
                           scale_pyr,
                           final_resp,
                           LAFs,
                           final_pyr_idxs,
                           final_level_idxs,
                           num_features=0,
                           n_iters=1):
     """Refine LAF shape and position with self.AffNet.

     Each of the n_iters rounds:
       1. samples self.AffNet.PS patches from scale_pyr at the current LAFs
          and calls self.AffNet;
       2. takes A = output[0] (columns 0:2 are a 2x2 shape update, columns 2:
          a translation offset) and accumulates base_A = A[:, :, 0:2] * base_A;
       3. offsets the translation by A[:, :, 2:] scaled by
          sqrt(|det(LAFs[:, :, 0:2])|).

     Args:
         scale_pyr: scale pyramid passed to the patch-extraction helpers.
         final_resp: per-feature responses; returned unchanged.
         LAFs: input LAFs; columns 0:2 are the shape part, columns 2: the translation.
         final_pyr_idxs, final_level_idxs: pyramid / level index per feature;
             returned unchanged.
         num_features: accepted for interface compatibility; not used here.
         n_iters: number of AffNet refinement rounds (must be >= 1).

     Returns:
         (final_resp, new_LAFs, final_pyr_idxs, final_level_idxs, out[2], A),
         where out and A come from the last AffNet call.

     Raises:
         ValueError: if n_iters < 1 (there would be no AffNet output to return).
     """
     if n_iters < 1:
         raise ValueError('n_iters must be >= 1, got %r' % (n_iters,))
     pyr_inv_idxs = get_inverted_pyr_index(scale_pyr, final_pyr_idxs,
                                           final_level_idxs)
     patches_small = extract_patches_from_pyramid_with_inv_index(
         scale_pyr, pyr_inv_idxs, LAFs, PS=self.AffNet.PS)
     base_A = torch.eye(2).unsqueeze(0).expand(final_pyr_idxs.size(0), 2, 2)
     if final_resp.is_cuda:
         base_A = base_A.cuda()
     base_A = Variable(base_A)
     # sqrt(|det|) of each input LAF's 2x2 part. LAFs never changes inside
     # the loop, so this is computed once.
     features_scale = torch.sqrt(
         torch.abs(LAFs[:, 0:1, 0:1] * LAFs[:, 1:2, 1:2] -
                   LAFs[:, 1:2, 0:1] * LAFs[:, 0:1, 1:2]))
     for i in range(n_iters):
         aff_out = self.AffNet(patches_small)
         A = aff_out[0]
         base_A = torch.bmm(A[:, :, 0:2], base_A)
         # NOTE(review): base_A is composed across rounds, but the translation
         # offset is recomputed from the original LAFs each round rather than
         # accumulated. Confirm this is intended for n_iters > 1.
         new_LAFs = torch.cat([
             torch.bmm(base_A, LAFs[:, :, 0:2]), LAFs[:, :, 2:] +
             features_scale.expand(A.size(0), 2, 1) * A[:, :, 2:]
         ],
                              dim=2)
         # Re-sample only when another round follows; the bound is the loop's
         # own n_iters, not self.num_Baum_iters.
         if i != n_iters - 1:
             patches_small = extract_patches_from_pyramid_with_inv_index(
                 scale_pyr, pyr_inv_idxs, new_LAFs, PS=self.AffNet.PS)
     return final_resp, new_LAFs, final_pyr_idxs, final_level_idxs, aff_out[
         2], A
 def forward(self,x, do_ori = False):
     ### Detection
     t = time.time()
     num_features_prefilter = self.num
     if self.num_Baum_iters > 0:
         num_features_prefilter = int(1.5 * self.num);
     responses, LAFs, final_pyr_idxs, final_level_idxs = self.multiScaleDetector(x,num_features_prefilter)
     print time.time() - t, 'detection multiscale'
     t = time.time()
     LAFs[:,0:2,0:2] =   self.mrSize * LAFs[:,:,0:2]
     if self.num_Baum_iters > 0:
         responses, LAFs, final_pyr_idxs, final_level_idxs  = self.getAffineShape(responses, LAFs, final_pyr_idxs, final_level_idxs, self.num)
     print time.time() - t, 'affine shape iters'
     t = time.time()
     if do_ori:
         LAFs = self.getOrientation(LAFs, final_pyr_idxs, final_level_idxs)
         pyr_inv_idxs = get_inverted_pyr_index(self.scale_pyr, final_pyr_idxs, final_level_idxs)
     #patches = extract_patches_from_pyramid_with_inv_index(scale_pyr, pyr_inv_idxs, LAFs, PS = self.PS)
     #patches = extract_patches(x, LAFs, PS = self.PS)
     #print time.time() - t, len(LAFs), ' patches extraction'
     return denormalizeLAFs(LAFs, x.size(3), x.size(2)), responses