def forward(self, x, random_Baum=False, random_resp=False, return_patches=False):
     ### Detection
     num_features_prefilter = self.num
     #if self.num_Baum_iters > 0:
     #    num_features_prefilter = 2 * self.num;
     if random_resp:
         num_features_prefilter *= 4
     responses, LAFs, final_pyr_idxs, final_level_idxs, scale_pyr = self.multiScaleDetector(x,num_features_prefilter)
     if random_resp:
         if self.num < responses.size(0):
             ridxs = torch.randperm(responses.size(0))[:self.num]
             if x.is_cuda:
                 ridxs = ridxs.cuda() 
             responses = responses[ridxs]
             LAFs = LAFs[ridxs ,:,:]
             final_pyr_idxs = final_pyr_idxs[ridxs]
             final_level_idxs = final_level_idxs[ridxs]
     LAFs[:, 0:2, 0:2] = self.mrSize * LAFs[:, :, 0:2]
     n_iters = self.num_Baum_iters
     if random_Baum and (n_iters > 1):
         n_iters = int(np.random.randint(1,n_iters + 1)) 
     if n_iters > 0:
         responses, LAFs, final_pyr_idxs, final_level_idxs  = self.getAffineShape(scale_pyr, responses, LAFs,
                                                                                  final_pyr_idxs, final_level_idxs, self.num, n_iters = n_iters)
     #LAFs = self.getOrientation(scale_pyr, LAFs, final_pyr_idxs, final_level_idxs)
     #if return_patches:
     #    pyr_inv_idxs = get_inverted_pyr_index(scale_pyr, final_pyr_idxs, final_level_idxs)
     #    patches = extract_patches_from_pyramid_with_inv_index(scale_pyr, pyr_inv_idxs, LAFs, PS = self.PS)
     if return_patches:
         patches = extract_patches(x, LAFs, PS = self.PS)
     else:
         patches = None
     return denormalizeLAFs(LAFs, x.size(3), x.size(2)), patches, responses  # , scale_pyr
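# Hypothetical usage sketch for the detector forward() above; the owning class name and
# its constructor arguments are assumed here, not taken from this example:
# HA = ScaleSpaceAffinePatchExtractor(mrSize=5.0, num_features=3000, num_Baum_iters=1, PS=32)
# LAFs, patches, responses = HA(img, random_Baum=True, random_resp=True, return_patches=True)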
def train(train_loader, model, optimizer, epoch):
    # switch to train mode
    model.train()
    pbar = tqdm(enumerate(train_loader))
    for batch_idx, data in pbar:
        data_a, data_p = data
        data_a, data_p = data_a.float(), data_p.float()
        if args.cuda:
            data_a, data_p = data_a.cuda(), data_p.cuda()
        data_a, data_p = Variable(data_a), Variable(data_p)
        rot_LAFs, inv_rotmat = get_random_rotation_LAFs(data_a, math.pi)
        scale = Variable(0.9 + 0.3 * torch.rand(data_a.size(0), 1, 1))
        if args.cuda:
            scale = scale.cuda()
        rot_LAFs[:, 0:2, 0:2] = rot_LAFs[:, 0:2, 0:2] * scale.expand(data_a.size(0), 2, 2)
        shift_w, shift_h = get_random_shifts_LAFs(data_a, 2, 2)
        rot_LAFs[:, 0, 2] = rot_LAFs[:, 0, 2] + shift_w / float(data_a.size(3))
        rot_LAFs[:, 1, 2] = rot_LAFs[:, 1, 2] + shift_h / float(data_a.size(2))
        data_a_rot = extract_patches(data_a, rot_LAFs, PS=data_a.size(2))
        st = int((data_p.size(2) - model.PS)/2)
        fin = st + model.PS

        data_p_crop = data_p[:,:, st:fin, st:fin].contiguous()
        data_a_rot_crop = data_a_rot[:,:, st:fin, st:fin].contiguous()
        out_a_rot = model(data_a_rot_crop, True)
        out_p = model(data_p_crop, True)
        out_a = model(data_a[:, :, st:fin, st:fin].contiguous(), True)  # not used below
        out_p_rotated = torch.bmm(inv_rotmat, out_p)

        ######Apply rot and get sifts
        out_patches_a_crop = extract_and_crop_patches_by_predicted_transform(data_a_rot, out_a_rot, crop_size = model.PS)
        out_patches_p_crop = extract_and_crop_patches_by_predicted_transform(data_p, out_p, crop_size = model.PS)

        desc_a = descriptor(out_patches_a_crop)
        desc_p = descriptor(out_patches_p_crop)
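        # Positive-pair distances: descr_dist is the mean L2 distance between the two
        # descriptor sets and geom_dist the mean distance between the predicted affine
        # matrices; the small epsilon terms keep sqrt() differentiable at zero.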
        descr_dist = torch.sqrt(((desc_a - desc_p)**2).view(data_a.size(0), -1).sum(dim=1) + 1e-6).mean()
        geom_dist = torch.sqrt(((out_a_rot - out_p_rotated)**2).view(-1, 4).sum(dim=1) + 1e-8).mean()
        if args.loss == 'HardNet':
            loss = loss_HardNet(desc_a, desc_p)
        elif args.loss == 'HardNetDetach':
            loss = loss_HardNetDetach(desc_a, desc_p)
        elif args.loss == 'Geom':
            loss = geom_dist
        elif args.loss == 'PosDist':
            loss = descr_dist
        else:
            print('Unknown loss function')
            sys.exit(1)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        adjust_learning_rate(optimizer)
        if batch_idx % args.log_interval == 0:
            pbar.set_description(
                'Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.4f}, {:.4f},{:.4f}'.format(
                    epoch, batch_idx * len(data_a), len(train_loader.dataset),
                           100. * batch_idx / len(train_loader),
                    float(loss.detach().cpu().numpy()), float(geom_dist.detach().cpu().numpy()), float(descr_dist.detach().cpu().numpy())))
    torch.save({'epoch': epoch + 1, 'state_dict': model.state_dict()},
               '{}/checkpoint_{}.pth'.format(LOG_DIR,epoch))
Example #3
def extract_and_crop_patches_by_predicted_transform(patches, trans, crop_size = 32):
    assert patches.size(0) == trans.size(0)
    st = int((patches.size(2) - crop_size) / 2)
    fin = st + crop_size
    rot_LAFs = Variable(torch.FloatTensor([[0.5, 0, 0.5], [0, 0.5, 0.5]]).unsqueeze(0).repeat(patches.size(0), 1, 1))
    if patches.is_cuda:
        rot_LAFs = rot_LAFs.cuda()
        trans = trans.cuda()
    rot_LAFs1 = torch.cat([torch.bmm(trans, rot_LAFs[:, 0:2, 0:2]), rot_LAFs[:, 0:2, 2:]], dim=2)
    return extract_patches(patches, rot_LAFs1, PS=patches.size(2))[:, :, st:fin, st:fin].contiguous()
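# What the helper above computes, assuming `trans` is an (N, 2, 2) batch of predicted
# affine matrices and `patches` is an (N, C, H, W) tensor: the canonical LAF
# [[0.5, 0, 0.5], [0, 0.5, 0.5]] covers the central half of each patch, the predicted
# 2x2 transform is composed with it via torch.bmm, the warped patch is re-sampled at
# full resolution, and the result is center-cropped to `crop_size`.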
def extract_random_LAF(data, max_rot = math.pi, max_tilt = 1.0, crop_size = 32):
    st = int((data.size(2) - crop_size)/2)
    fin = st + crop_size
    if type(max_rot) is float:
        rot_LAFs, inv_rotmat = get_random_rotation_LAFs(data, max_rot)
    else:
        rot_LAFs = max_rot
        inv_rotmat = None
    aff_LAFs, inv_TA = get_random_norm_affine_LAFs(data, max_tilt)
    aff_LAFs[:, 0:2, 0:2] = torch.bmm(rot_LAFs[:, 0:2, 0:2], aff_LAFs[:, 0:2, 0:2])
    data_aff = extract_patches(data, aff_LAFs, PS=data.size(2))
    data_affcrop = data_aff[:, :, st:fin, st:fin].contiguous()
    return data_affcrop, data_aff, rot_LAFs, inv_rotmat, inv_TA
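# Hypothetical call of the helper above (argument values assumed for illustration):
# data_crop, data_full, rot_LAFs, inv_rot, inv_TA = extract_random_LAF(
#     batch, max_rot=math.pi, max_tilt=3.0, crop_size=32)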
def forward(self,
            x,
            random_Baum=False,
            random_resp=False,
            return_patches=False):
     ### Detection
     num_features_prefilter = self.num
     #if self.num_Baum_iters > 0:
     #    num_features_prefilter = 2 * self.num;
     if random_resp:
         num_features_prefilter *= 4
     responses, LAFs, final_pyr_idxs, final_level_idxs, scale_pyr = self.multiScaleDetector(
         x, num_features_prefilter)
     if random_resp:
         if self.num < responses.size(0):
             ridxs = torch.randperm(responses.size(0))[:self.num]
             if x.is_cuda:
                 ridxs = ridxs.cuda()
             responses = responses[ridxs]
             LAFs = LAFs[ridxs, :, :]
             final_pyr_idxs = final_pyr_idxs[ridxs]
             final_level_idxs = final_level_idxs[ridxs]
     LAFs[:, 0:2, 0:2] = self.mrSize * LAFs[:, :, 0:2]
     n_iters = self.num_Baum_iters
     if random_Baum and (n_iters > 1):
         n_iters = int(np.random.randint(1, n_iters + 1))
     dets, A = None, None
     if n_iters > 0:
         responses, LAFs, final_pyr_idxs, final_level_idxs, dets, A = self.ImproveLAFsEstimation(
             scale_pyr,
             responses,
             LAFs,
             final_pyr_idxs,
             final_level_idxs,
             self.num,
             n_iters=n_iters)
     if return_patches:
         patches = extract_patches(x, LAFs, PS=self.PS)
     else:
         patches = None
     return denormalizeLAFs(
         LAFs, x.size(3),
         x.size(2)), patches, responses, dets, A  #, scale_pyr
Example #6
def test(test_loader, model, epoch):
    # switch to evaluate mode
    model.eval()

    geom_distances, desc_distances = [], []

    pbar = tqdm(enumerate(test_loader))
    for batch_idx, (data_a, data_p) in pbar:

        if args.cuda:
            data_a, data_p = data_a.float().cuda(), data_p.float().cuda()
        data_a, data_p = Variable(data_a, volatile=True), Variable(data_p, volatile=True)
        rot_LAFs, inv_rotmat = get_random_rotation_LAFs(data_a, math.pi)
        data_a_rot = extract_patches(data_a,  rot_LAFs, PS = data_a.size(2))
        st = int((data_p.size(2) - model.PS)/2)
        fin = st + model.PS
        data_p = data_p[:,:, st:fin, st:fin].contiguous()
        data_a_rot = data_a_rot[:,:, st:fin, st:fin].contiguous()
        out_a_rot, out_p = model(data_a_rot, True), model(data_p, True)
        out_p_rotated = torch.bmm(inv_rotmat, out_p)
        geom_dist = torch.sqrt((out_a_rot - out_p_rotated)**2 + 1e-12).mean()
        out_patches_a_crop = extract_and_crop_patches_by_predicted_transform(data_a_rot, out_a_rot, crop_size = model.PS)
        out_patches_p_crop = extract_and_crop_patches_by_predicted_transform(data_p, out_p, crop_size = model.PS)
        desc_a = descriptor(out_patches_a_crop)
        desc_p = descriptor(out_patches_p_crop)
        descr_dist = torch.sqrt(((desc_a - desc_p)**2).view(data_a.size(0), -1).sum(dim=1) + 1e-6)  # / float(desc_a.size(1))
        descr_dist = descr_dist.mean()
        geom_distances.append(geom_dist.data.cpu().numpy().reshape(-1,1))
        desc_distances.append(descr_dist.data.cpu().numpy().reshape(-1,1))
        if batch_idx % args.log_interval == 0:
            pbar.set_description(' Test Epoch: {} [{}/{} ({:.0f}%)]'.format(
                epoch, batch_idx * len(data_a), len(test_loader.dataset),
                       100. * batch_idx / len(test_loader)))

    geom_distances = np.vstack(geom_distances).reshape(-1,1)
    desc_distances = np.vstack(desc_distances).reshape(-1,1)

    print('\33[91mTest set: Geom MSE: {:.8f}\n\33[0m'.format(geom_distances.mean()))
    print('\33[91mTest set: Desc dist: {:.8f}\n\33[0m'.format(desc_distances.mean()))
    return
Example #7
def forward(self, cropped_feats, nLAFs):
    return self.HNHead(extract_patches(cropped_feats, nLAFs, PS=16))
Example #8
def train(train_loader, model, optimizer, epoch, cuda=True):
    # switch to train mode
    model.train()
    log_interval = 1
    total_loss = 0
    total_feats = 0
    spatial_only = True
    pbar = enumerate(train_loader)
    for batch_idx, data in pbar:
        print('Batch idx', batch_idx)
        #print model.detector.shift_net[0].weight.data.cpu().numpy()
        img1, img2, H1to2 = data
        #if np.abs(np.sum(H.numpy()) - 3.0) > 0.01:
        #    continue
        H1to2 = H1to2.squeeze(0)
        if (img1.size(3) * img1.size(4) > 1340 * 1000):
            print(img1.shape, ' too big, skipping')
            continue
        img1 = img1.float().squeeze(0)
        img2 = img2.float().squeeze(0)
        if cuda:
            img1, img2, H1to2 = img1.cuda(), img2.cuda(), H1to2.cuda()
        img1, img2, H1to2 = Variable(img1, requires_grad=False), Variable(
            img2, requires_grad=False), Variable(H1to2, requires_grad=False)
        LAFs1, aff_norm_patches1, resp1 = HA(img1, True, True, True)
        LAFs2, aff_norm_patches2, resp2 = HA(img2, True, True)
        if (len(LAFs1) == 0) or (len(LAFs2) == 0):
            optimizer.zero_grad()
            continue
        fro_dists, idxs_in1, idxs_in2, LAFs2_in_1 = get_GT_correspondence_indexes_Fro_and_center(
            LAFs1,
            LAFs2,
            H1to2,
            dist_threshold=4.,
            center_dist_th=7.0,
            skip_center_in_Fro=True,
            do_up_is_up=True,
            return_LAF2_in_1=True)
        if len(fro_dists.size()) == 0:
            optimizer.zero_grad()
            print('skip')
            continue
        aff_patches_from_LAFs2_in_1 = extract_patches(
            img1,
            normalizeLAFs(LAFs2_in_1[idxs_in2.data.long(), :, :], img1.size(3),
                          img1.size(2)))

        #loss = fro_dists.mean()
        patch_dist = torch.sqrt(
            (aff_norm_patches1[idxs_in1.data.long(), :, :, :] / 100. -
             aff_patches_from_LAFs2_in_1 / 100.)**2 + 1e-8).view(
                 fro_dists.size(0), -1).mean(dim=1)
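        # Loss below: per-correspondence geometric (Frobenius) distance weighted by the
        # photometric patch distance, so poorly repeated detections contribute less.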
        loss = (fro_dists * patch_dist).mean()
        print('Fro dist', fro_dists.mean().data.cpu().numpy()[0], loss.data.cpu().numpy()[0])
        total_loss += loss.data.cpu().numpy()[0]
        #loss += patch_dist
        total_feats += fro_dists.size(0)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        #adjust_learning_rate(optimizer)
        print(epoch, batch_idx, loss.data.cpu().numpy()[0], idxs_in1.shape)

    print('Train total loss:', total_loss / float(batch_idx + 1), ' features ', float(total_feats) / float(batch_idx + 1))
    torch.save({
        'epoch': epoch + 1,
        'state_dict': model.state_dict()
    }, '{}/elu_new_checkpoint_{}.pth'.format(LOG_DIR, epoch))
Example #9
def test(test_loader, model, epoch):
    # switch to evaluate mode
    model.eval()
    geom_distances, desc_distances = [], []
    pbar = tqdm(enumerate(test_loader))
    for batch_idx, data in pbar:
        data_a, data_p = data
        if args.cuda:
            data_a, data_p = data_a.float().cuda(), data_p.float().cuda()
        data_a, data_p = Variable(data_a,
                                  volatile=True), Variable(data_p,
                                                           volatile=True)
        st = int((data_p.size(2) - model.PS) / 2)
        fin = st + model.PS
        aff_LAFs_a, inv_TA_a = get_random_norm_affine_LAFs(data_a, 3.0)
        shift_w_a, shift_h_a = get_random_shifts_LAFs(data_a, 3, 3)
        aff_LAFs_a[:, 0, 2] = aff_LAFs_a[:, 0, 2] + shift_w_a / float(data_a.size(3))
        aff_LAFs_a[:, 1, 2] = aff_LAFs_a[:, 1, 2] + shift_h_a / float(data_a.size(2))
        data_a_aff = extract_patches(data_a, aff_LAFs_a, PS=data_a.size(2))
        data_a_aff_crop = data_a_aff[:, :, st:fin, st:fin].contiguous()
        aff_LAFs_p, inv_TA_p = get_random_norm_affine_LAFs(data_p, 3.0)
        shift_w_p, shift_h_p = get_random_shifts_LAFs(data_p, 3, 3)
        aff_LAFs_p[:, 0, 2] = aff_LAFs_p[:, 0, 2] + shift_w_p / float(data_p.size(3))
        aff_LAFs_p[:, 1, 2] = aff_LAFs_p[:, 1, 2] + shift_h_p / float(data_p.size(2))
        data_p_aff = extract_patches(data_p, aff_LAFs_p, PS=data_p.size(2))
        data_p_aff_crop = data_p_aff[:, :, st:fin, st:fin].contiguous()
        out_a_aff = model(data_a_aff_crop, True)
        out_p_aff = model(data_p_aff_crop, True)
        out_p_aff_back = torch.bmm(inv_TA_p, out_p_aff)
        out_a_aff_back = torch.bmm(inv_TA_a, out_a_aff)
        ######Apply rot and get sifts
        out_patches_a_crop = extract_and_crop_patches_by_predicted_transform(
            data_a_aff, out_a_aff, crop_size=model.PS)
        out_patches_p_crop = extract_and_crop_patches_by_predicted_transform(
            data_p_aff, out_p_aff, crop_size=model.PS)
        desc_a = descriptor(out_patches_a_crop)
        desc_p = descriptor(out_patches_p_crop)
        descr_dist = torch.sqrt((
            (desc_a - desc_p)**2).view(data_a.size(0), -1).sum(dim=1) +
                                1e-6) / float(desc_a.size(1))
        geom_dist = torch.sqrt((
            (out_a_aff_back - out_p_aff_back)**2).view(-1, 4).mean(dim=1) +
                               1e-8)
        geom_distances.append(geom_dist.mean().data.cpu().numpy().reshape(
            -1, 1))
        desc_distances.append(descr_dist.mean().data.cpu().numpy().reshape(
            -1, 1))
        if batch_idx % args.log_interval == 0:
            pbar.set_description(' Test Epoch: {} [{}/{} ({:.0f}%)]'.format(
                epoch, batch_idx * len(data_a), len(test_loader.dataset),
                100. * batch_idx / len(test_loader)))

    geom_distances = np.vstack(geom_distances).reshape(-1, 1)
    desc_distances = np.vstack(desc_distances).reshape(-1, 1)
    print('\33[91mTest set: Geom MSE: {:.8f}\n\33[0m'.format(
        geom_distances.mean()))
    print('\33[91mTest set: Desc dist: {:.8f}\n\33[0m'.format(
        desc_distances.mean()))
    return
Example #10
def train(train_loader, model, optimizer, epoch):
    # switch to train mode
    model.train()
    pbar = tqdm(enumerate(train_loader))
    for batch_idx, data in pbar:
        data_a, data_p = data
        data_a, data_p = data_a.float(), data_p.float()
        if args.cuda:
            data_a, data_p = data_a.cuda(), data_p.cuda()
        data_a, data_p = Variable(data_a), Variable(data_p)
        st = int((data_p.size(2) - model.PS) / 2)
        fin = st + model.PS
        # Curriculum on the affine augmentation: allow progressively stronger tilt
        # (anisotropic scaling) as training advances.
        max_tilt = 3.0
        if epoch > 1:
            max_tilt = 4.0
        if epoch > 3:
            max_tilt = 4.5
        if epoch > 5:
            max_tilt = 4.8
        rot_LAFs_a, inv_rotmat_a = get_random_rotation_LAFs(data_a, math.pi)
        aff_LAFs_a, inv_TA_a = get_random_norm_affine_LAFs(data_a, max_tilt)
        aff_LAFs_a[:, 0:2, 0:2] = torch.bmm(rot_LAFs_a[:, 0:2, 0:2],
                                            aff_LAFs_a[:, 0:2, 0:2])
        data_a_aff = extract_patches(data_a, aff_LAFs_a, PS=data_a.size(2))
        data_a_aff_crop = data_a_aff[:, :, st:fin, st:fin].contiguous()
        aff_LAFs_p, inv_TA_p = get_random_norm_affine_LAFs(data_p, max_tilt)
        aff_LAFs_p[:, 0:2, 0:2] = torch.bmm(rot_LAFs_a[:, 0:2, 0:2],
                                            aff_LAFs_p[:, 0:2, 0:2])
        data_p_aff = extract_patches(data_p, aff_LAFs_p, PS=data_p.size(2))
        data_p_aff_crop = data_p_aff[:, :, st:fin, st:fin].contiguous()
        out_a_aff = model(data_a_aff_crop, True)
        out_p_aff = model(data_p_aff_crop, True)
        out_p_aff_back = torch.bmm(inv_TA_p, out_p_aff)
        out_a_aff_back = torch.bmm(inv_TA_a, out_a_aff)
        ######Apply rot and get sifts
        out_patches_a_crop = extract_and_crop_patches_by_predicted_transform(
            data_a_aff, out_a_aff, crop_size=model.PS)
        out_patches_p_crop = extract_and_crop_patches_by_predicted_transform(
            data_p_aff, out_p_aff, crop_size=model.PS)
        desc_a = descriptor(out_patches_a_crop)
        desc_p = descriptor(out_patches_p_crop)
        descr_dist = torch.sqrt((
            (desc_a - desc_p)**2).view(data_a.size(0), -1).sum(dim=1) + 1e-6)
        descr_loss = loss_HardNet(desc_a, desc_p, anchor_swap=True)
        geom_dist = torch.sqrt((
            (out_a_aff_back - out_p_aff_back)**2).view(-1, 4).mean(dim=1) +
                               1e-8)
        # NB: both merge options currently use only the descriptor loss; geom_dist and
        # descr_dist are computed for logging but not merged into the objective here.
        if args.merge == 'sum':
            loss = descr_loss
        elif args.merge == 'mul':
            loss = descr_loss
        else:
            print('Unknown merge option')
            sys.exit(1)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        adjust_learning_rate(optimizer)
        if batch_idx % 2 == 0:
            pbar.set_description(
                'Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.4f}, {},{:.4f}'.
                format(epoch, batch_idx * len(data_a),
                       len(train_loader.dataset),
                       100. * batch_idx / len(train_loader), loss.data[0],
                       geom_dist.mean().data[0],
                       descr_dist.mean().data[0]))
    torch.save({
        'epoch': epoch + 1,
        'state_dict': model.state_dict()
    }, '{}/checkpoint_{}.pth'.format(LOG_DIR, epoch))
def train(train_loader, model, optimizer, epoch, cuda = True):
    # switch to train mode
    model.train()
    log_interval = 1
    total_loss = 0
    total_feats = 0
    spatial_only = True
    pbar = enumerate(train_loader)
    for batch_idx, data in pbar:
        #if batch_idx > 0:
        #    continue
        print('Batch idx', batch_idx)
        #print model.detector.shift_net[0].weight.data.cpu().numpy()
        img1, img2, H1to2  = data
        #if np.abs(np.sum(H.numpy()) - 3.0) > 0.01:
        #    continue
        H1to2 = H1to2.squeeze(0)
        do_aug = True
        if torch.abs(H1to2 - torch.eye(3)).sum() > 0.05:
            do_aug = False
        if (img1.size(3) * img1.size(4) > 1340 * 1000):
            print(img1.shape, ' too big, skipping')
            continue
        img1 = img1.float().squeeze(0)
        img2 = img2.float().squeeze(0)
        if cuda:
            img1, img2, H1to2 = img1.cuda(), img2.cuda(), H1to2.cuda()
        img1 = Variable(img1, requires_grad=False)
        img2 = Variable(img2, requires_grad=False)
        H1to2 = Variable(H1to2, requires_grad=False)
        if do_aug:
            new_img2, H_Orig2New = affineAug(img2, max_add=0.2)
            H1to2new = torch.mm(H_Orig2New, H1to2)
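            # Compose the random augmentation homography with the ground-truth one, so
            # correspondences are checked against the augmented second image.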
        else:
            new_img2 = img2
            H1to2new = H1to2
        #print H1to2
        LAFs1, aff_norm_patches1, resp1, dets1, A1 = HA(img1, True, False, True)
        LAFs2Aug, aff_norm_patches2, resp2, dets2, A2 = HA(new_img2, True, False)
        if (len(LAFs1) == 0) or (len(LAFs2Aug) == 0):
            optimizer.zero_grad()
            continue
        geom_loss, idxs_in1, idxs_in2, LAFs2_in_1 = LAFMagic(LAFs1,
                                                             LAFs2Aug,
                                                             H1to2new,
                                                             3.0,
                                                             scale_log=0.3)
        if len(idxs_in1.size()) == 0:
            optimizer.zero_grad()
            print('skip')
            continue
        aff_patches_from_LAFs2_in_1 = extract_patches(img1,
                                                      normalizeLAFs(LAFs2_in_1[idxs_in2.long(),:,:], 
                                                      img1.size(3), img1.size(2)))
        SIFTs1 = SIFT(aff_norm_patches1[idxs_in1.long(),:,:,:]).cuda()
        SIFTs2 = SIFT(aff_patches_from_LAFs2_in_1).cuda()
        #sift_snn_loss = loss_HardNet(SIFTs1, SIFTs2, column_row_swap = True,
        #                 margin = 1.0, batch_reduce = 'min', loss_type = "triplet_margin");
        patch_dist = 2.0 * torch.sqrt((aff_norm_patches1[idxs_in1.long(), :, :, :] / 100. - aff_patches_from_LAFs2_in_1 / 100.)**2 + 1e-8).view(idxs_in1.size(0), -1).mean(dim=1)
        sift_dist = torch.sqrt(((SIFTs1 - SIFTs2)**2 + 1e-12).mean(dim=1))
        loss = geom_loss.cuda().mean()
        total_loss += loss.data.cpu().numpy()[0]
        #loss += patch_dist
        total_feats += aff_patches_from_LAFs2_in_1.size(0)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        adjust_learning_rate(optimizer)
        if batch_idx % 10 == 0:
            print('A', A1.data.cpu().numpy()[0:1, :, :])
        print('crafted loss', pr_l(geom_loss), 'patch', pr_l(patch_dist), 'sift', pr_l(sift_dist))  # , 'hardnet', pr_l(sift_snn_loss)
        print(epoch, batch_idx, loss.data.cpu().numpy()[0], idxs_in1.shape)

    print('Train total loss:', total_loss / float(batch_idx + 1), ' features ', float(total_feats) / float(batch_idx + 1))
    torch.save({'epoch': epoch + 1, 'state_dict': model.state_dict()},
               '{}/new_loss_sep_checkpoint_{}.pth'.format(LOG_DIR, epoch))
Example #12
def train(train_loader, model, optimizer, epoch, cuda=True):
    # switch to train mode
    model.train()
    log_interval = 1
    total_loss = 0
    spatial_only = True
    pbar = enumerate(train_loader)
    for batch_idx, data in pbar:
        print('Batch idx', batch_idx)
        #print model.detector.shift_net[0].weight.data.cpu().numpy()
        img1, img2, H1to2 = data
        #if np.abs(np.sum(H.numpy()) - 3.0) > 0.01:
        #    continue
        H1to2 = H1to2.squeeze(0)
        if (img1.size(3) * img1.size(4) > 1340 * 1000):
            print(img1.shape, ' too big, skipping')
            continue
        img1 = img1.float().squeeze(0)
        #img1 = img1 - img1.mean()
        #img1 = img1 / 50.#(img1.std() + 1e-8)
        img2 = img2.float().squeeze(0)
        #img2 = img2 - img2.mean()
        #img2 = img2 / 50.#(img2.std() + 1e-8)
        if cuda:
            img1, img2, H1to2 = img1.cuda(), img2.cuda(), H1to2.cuda()
        img1, img2, H1to2 = Variable(img1, requires_grad=False), Variable(
            img2, requires_grad=False), Variable(H1to2, requires_grad=False)
        LAFs1, aff_norm_patches1, resp1, pyr1 = HA(img1)
        LAFs2, aff_norm_patches2, resp2, pyr2 = HA(img2)
        if (len(LAFs1) == 0) or (len(LAFs2) == 0):
            optimizer.zero_grad()
            continue
        fro_dists, idxs_in1, idxs_in2, LAFs2_in_1 = get_GT_correspondence_indexes_Fro_and_center(
            LAFs1,
            LAFs2,
            H1to2,
            dist_threshold=2.,
            center_dist_th=5.0,
            skip_center_in_Fro=True,
            do_up_is_up=True,
            return_LAF2_in_1=True)
        if len(fro_dists.size()) == 0:
            optimizer.zero_grad()
            print('skip')
            continue
        aff_patches_from_LAFs2_in_1 = extract_patches(img1, LAFs2_in_1)
        loss = fro_dists.mean()
        total_loss += loss.data.cpu().numpy()[0]
        patch_dist = torch.mean(
            (aff_norm_patches1[idxs_in1.data.long(), :, :, :] -
             aff_patches_from_LAFs2_in_1[idxs_in2.data.long(), :, :, :])**2)
        print(loss.data.cpu().numpy()[0], patch_dist.data.cpu().numpy()[0])
        loss += patch_dist / 100.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        #adjust_learning_rate(optimizer)
        print(epoch, batch_idx, loss.data.cpu().numpy()[0], idxs_in1.shape)

    print('Train total loss:', total_loss / float(batch_idx + 1))
    torch.save({
        'epoch': epoch + 1,
        'state_dict': model.state_dict()
    }, '{}/checkpoint_{}.pth'.format(LOG_DIR, epoch))