def forward(self,x, random_Baum = False, random_resp = False, return_patches = False):
    """Detect multi-scale features in image batch `x` and return their LAFs.

    Returns a tuple (LAFs, patches, responses):
      LAFs      -- detected local affine frames, denormalized to the pixel
                   coordinates of `x` (assumed N x 2 x 3 -- TODO confirm);
      patches   -- patches of size self.PS if return_patches else None;
      responses -- detector response per kept feature.

    random_Baum -- if True, run a random number (1..num_Baum_iters) of
                   affine-shape adaptation iterations instead of the fixed count.
    random_resp -- if True, over-detect 4x features and keep a random subset
                   of self.num instead of the top-response ones.
    """
    ### Detection
    num_features_prefilter = self.num
    #if self.num_Baum_iters > 0:
    #    num_features_prefilter = 2 * self.num;
    if random_resp:
        # over-detect so the random subsample below has enough candidates
        num_features_prefilter *= 4
    responses, LAFs, final_pyr_idxs, final_level_idxs, scale_pyr = self.multiScaleDetector(x,num_features_prefilter)
    if random_resp:
        if self.num < responses.size(0):
            # keep a random subset of self.num detections; the same index
            # permutation is applied to every per-feature tensor to keep
            # them aligned
            ridxs = torch.randperm(responses.size(0))[:self.num]
            if x.is_cuda:
                ridxs = ridxs.cuda()
            responses = responses[ridxs]
            LAFs = LAFs[ridxs ,:,:]
            final_pyr_idxs = final_pyr_idxs[ridxs]
            final_level_idxs = final_level_idxs[ridxs]
    # Scale the 2x2 affine part by the measurement-region size.
    # NOTE(review): RHS indexes [:,:,0:2] while LHS uses [:,0:2,0:2]; for
    # 2x3 LAFs these address the same elements, but confirm the LAF layout
    # before relying on it.
    LAFs[:,0:2,0:2] = self.mrSize * LAFs[:,:,0:2]
    n_iters = self.num_Baum_iters;
    if random_Baum and (n_iters > 1):
        # training-time augmentation: random number of Baumberg iterations
        n_iters = int(np.random.randint(1,n_iters + 1))
    if n_iters > 0:
        # affine shape adaptation (Baumberg iteration)
        responses, LAFs, final_pyr_idxs, final_level_idxs = self.getAffineShape(scale_pyr, responses, LAFs, final_pyr_idxs, final_level_idxs, self.num, n_iters = n_iters)
    #LAFs = self.getOrientation(scale_pyr, LAFs, final_pyr_idxs, final_level_idxs)
    #if return_patches:
    #    pyr_inv_idxs = get_inverted_pyr_index(scale_pyr, final_pyr_idxs, final_level_idxs)
    #    patches = extract_patches_from_pyramid_with_inv_index(scale_pyr, pyr_inv_idxs, LAFs, PS = self.PS)
    if return_patches:
        patches = extract_patches(x, LAFs, PS = self.PS)
    else:
        patches = None
    # x.size(3) = width, x.size(2) = height
    return denormalizeLAFs(LAFs, x.size(3), x.size(2)), patches, responses#, scale_pyr
def train(train_loader, model, optimizer, epoch):
    """One training epoch: augment the anchor patch with a random rotation,
    isotropic scale and shift, predict a transform with `model`, extract
    descriptors from the rectified patches, and optimize the loss selected
    by args.loss.

    Uses globals: args, descriptor, LOG_DIR plus the LAF helper functions.
    Saves a checkpoint to LOG_DIR at the end of the epoch.
    """
    # switch to train mode
    model.train()
    pbar = tqdm(enumerate(train_loader))
    for batch_idx, data in pbar:
        data_a, data_p = data  # anchor / positive patch pair
        if args.cuda:
            data_a, data_p = data_a.float().cuda(), data_p.float().cuda()
        data_a, data_p = Variable(data_a), Variable(data_p)
        # random rotation up to pi, then random isotropic scale in [0.9, 1.2)
        rot_LAFs, inv_rotmat = get_random_rotation_LAFs(data_a, math.pi)
        scale = Variable( 0.9 + 0.3* torch.rand(data_a.size(0), 1, 1));
        if args.cuda:
            scale = scale.cuda()
        rot_LAFs[:,0:2,0:2] = rot_LAFs[:,0:2,0:2] * scale.expand(data_a.size(0),2,2)
        # random translation of up to 2 px, normalized by image width/height
        shift_w, shift_h = get_random_shifts_LAFs(data_a, 2, 2)
        rot_LAFs[:,0,2] = rot_LAFs[:,0,2] + shift_w / float(data_a.size(3))
        rot_LAFs[:,1,2] = rot_LAFs[:,1,2] + shift_h / float(data_a.size(2))
        data_a_rot = extract_patches(data_a, rot_LAFs, PS = data_a.size(2))
        # center-crop everything to the model's input size
        st = int((data_p.size(2) - model.PS)/2)
        fin = st + model.PS
        data_p_crop = data_p[:,:, st:fin, st:fin].contiguous()
        data_a_rot_crop = data_a_rot[:,:, st:fin, st:fin].contiguous()
        out_a_rot, out_p, out_a = model(data_a_rot_crop,True), model(data_p_crop,True), model(data_a[:,:, st:fin, st:fin].contiguous(), True)
        # map the positive's predicted transform back through the known rotation
        out_p_rotatad = torch.bmm(inv_rotmat, out_p)
        ######Apply rot and get sifts
        out_patches_a_crop = extract_and_crop_patches_by_predicted_transform(data_a_rot, out_a_rot, crop_size = model.PS)
        out_patches_p_crop = extract_and_crop_patches_by_predicted_transform(data_p, out_p, crop_size = model.PS)
        desc_a = descriptor(out_patches_a_crop)
        desc_p = descriptor(out_patches_p_crop)
        # mean L2 distance between descriptors (eps keeps sqrt differentiable)
        descr_dist = torch.sqrt(((desc_a - desc_p)**2).view(data_a.size(0),-1).sum(dim=1) + 1e-6).mean()
        # NOTE(review): the [0] keeps only the first sample's geometric
        # distance, so the trailing .mean() is over a scalar -- possibly the
        # intent was to average over the whole batch; confirm before changing.
        geom_dist = torch.sqrt(((out_a_rot - out_p_rotatad)**2 ).view(-1,4).sum(dim=1)[0] + 1e-8).mean()
        if args.loss == 'HardNet':
            loss = loss_HardNet(desc_a,desc_p);
        elif args.loss == 'HardNetDetach':
            loss = loss_HardNetDetach(desc_a,desc_p);
        elif args.loss == 'Geom':
            loss = geom_dist;
        elif args.loss == 'PosDist':
            loss = descr_dist;
        else:
            print('Unknown loss function')
            sys.exit(1)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        adjust_learning_rate(optimizer)
        if batch_idx % args.log_interval == 0:
            pbar.set_description(
                'Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.4f}, {:.4f},{:.4f}'.format(
                    epoch, batch_idx * len(data_a), len(train_loader.dataset),
                    100. * batch_idx / len(train_loader),
                    float(loss.detach().cpu().numpy()),
                    float(geom_dist.detach().cpu().numpy()),
                    float(descr_dist.detach().cpu().numpy())))
    torch.save({'epoch': epoch + 1, 'state_dict': model.state_dict()},
               '{}/checkpoint_{}.pth'.format(LOG_DIR,epoch))
def extract_and_crop_patches_by_predicted_transform(patches, trans, crop_size = 32):
    """Warp each patch by its predicted 2x2 transform and center-crop it.

    Composes `trans` (B x 2 x 2) with a canonical half-size LAF centered at
    (0.5, 0.5) in normalized patch coordinates, re-samples at the original
    patch resolution, and returns the central crop_size x crop_size region
    as a contiguous tensor.
    """
    assert patches.size(0) == trans.size(0)
    batch = patches.size(0)
    full_size = patches.size(2)
    # Canonical LAF: half scale, centered in the patch.
    base = torch.FloatTensor([[0.5, 0, 0.5], [0, 0.5, 0.5]])
    base_lafs = Variable(base.unsqueeze(0).repeat(batch, 1, 1))
    if patches.is_cuda:
        base_lafs = base_lafs.cuda()
        trans = trans.cuda()
    # Apply the predicted linear part to the canonical frame; keep its center.
    warped_linear = torch.bmm(trans, base_lafs[:, 0:2, 0:2])
    composed = torch.cat([warped_linear, base_lafs[:, 0:2, 2:]], dim=2)
    warped = extract_patches(patches, composed, PS=full_size)
    start = int((full_size - crop_size) / 2)
    end = start + crop_size
    return warped[:, :, start:end, start:end].contiguous()
def extract_random_LAF(data, max_rot = math.pi, max_tilt = 1.0, crop_size = 32):
    """Apply a random rotation + normalized-affine LAF augmentation to `data`.

    max_rot   -- either a maximum rotation angle in radians (any real number;
                 a random rotation up to it is drawn), or a precomputed
                 rotation-LAF tensor to reuse (then inv_rotmat is None).
    max_tilt  -- maximum anisotropic tilt for the random affine part.
    crop_size -- side of the central crop returned in data_affcrop.

    Returns (data_affcrop, data_aff, rot_LAFs, inv_rotmat, inv_TA).
    """
    st = int((data.size(2) - crop_size)/2)
    fin = st + crop_size
    # BUGFIX/generalization: accept any real number for max_rot, not only
    # `float` -- previously an int angle (e.g. 0) fell through to the else
    # branch and was wrongly treated as a precomputed rotation-LAF tensor.
    if isinstance(max_rot, (int, float)):
        rot_LAFs, inv_rotmat = get_random_rotation_LAFs(data, max_rot)
    else:
        rot_LAFs = max_rot
        inv_rotmat = None
    aff_LAFs, inv_TA = get_random_norm_affine_LAFs(data, max_tilt)
    # compose rotation with the random normalized affine (2x2 parts only)
    aff_LAFs[:,0:2,0:2] = torch.bmm(rot_LAFs[:,0:2,0:2], aff_LAFs[:,0:2,0:2])
    data_aff = extract_patches(data, aff_LAFs, PS = data.size(2))
    data_affcrop = data_aff[:,:, st:fin, st:fin].contiguous()
    return data_affcrop, data_aff, rot_LAFs, inv_rotmat, inv_TA
def forward(self, x, random_Baum=False, random_resp=False, return_patches=False):
    """Detect multi-scale features in image batch `x`.

    Returns (LAFs, patches, responses, dets, A):
      LAFs      -- local affine frames denormalized to pixel coords of `x`;
      patches   -- patches of size self.PS if return_patches else None;
      responses -- detector response per kept feature;
      dets, A   -- extra outputs of ImproveLAFsEstimation, or None when
                   no refinement iterations were run.

    random_Baum -- run a random number (1..num_Baum_iters) of LAF-refinement
                   iterations instead of the fixed count.
    random_resp -- over-detect 4x features and keep a random subset of
                   self.num instead of the top-response ones.
    """
    ### Detection
    num_features_prefilter = self.num
    if random_resp:
        # over-detect so the random subsample below has enough candidates
        num_features_prefilter *= 4
    responses, LAFs, final_pyr_idxs, final_level_idxs, scale_pyr = self.multiScaleDetector(
        x, num_features_prefilter)
    if random_resp:
        if self.num < responses.size(0):
            # random subset; same permutation on all per-feature tensors
            ridxs = torch.randperm(responses.size(0))[:self.num]
            if x.is_cuda:
                ridxs = ridxs.cuda()
            responses = responses[ridxs]
            LAFs = LAFs[ridxs, :, :]
            final_pyr_idxs = final_pyr_idxs[ridxs]
            final_level_idxs = final_level_idxs[ridxs]
    # scale the 2x2 affine part by the measurement-region size
    LAFs[:, 0:2, 0:2] = self.mrSize * LAFs[:, :, 0:2]
    n_iters = self.num_Baum_iters
    if random_Baum and (n_iters > 1):
        n_iters = int(np.random.randint(1, n_iters + 1))
    # BUGFIX: dets/A were referenced unconditionally in the return statement
    # but only assigned inside the `n_iters > 0` branch, so num_Baum_iters == 0
    # raised NameError.  Default them to None.
    dets, A = None, None
    if n_iters > 0:
        responses, LAFs, final_pyr_idxs, final_level_idxs, dets, A = self.ImproveLAFsEstimation(
            scale_pyr, responses, LAFs, final_pyr_idxs, final_level_idxs,
            self.num, n_iters=n_iters)
    if return_patches:
        patches = extract_patches(x, LAFs, PS=self.PS)
    else:
        patches = None
    # x.size(3) = width, x.size(2) = height
    return denormalizeLAFs(
        LAFs, x.size(3), x.size(2)), patches, responses, dets, A  #, scale_pyr
def test(test_loader, model, epoch):
    """Evaluate `model` with rotation-only augmentation of the anchor patch;
    prints mean geometric and descriptor distances over the test set.

    Uses globals: args, descriptor plus the LAF helper functions.
    """
    # switch to evaluate mode
    model.eval()
    geom_distances, desc_distances = [], []
    pbar = tqdm(enumerate(test_loader))
    for batch_idx, (data_a, data_p) in pbar:
        if args.cuda:
            data_a, data_p = data_a.float().cuda(), data_p.float().cuda()
        # volatile=True: legacy (pre-0.4) PyTorch no-grad inference mode
        data_a, data_p = Variable(data_a, volatile=True), Variable(data_p, volatile=True)
        rot_LAFs, inv_rotmat = get_random_rotation_LAFs(data_a, math.pi)
        data_a_rot = extract_patches(data_a, rot_LAFs, PS = data_a.size(2))
        # center-crop to the model's input size
        st = int((data_p.size(2) - model.PS)/2)
        fin = st + model.PS
        data_p = data_p[:,:, st:fin, st:fin].contiguous()
        data_a_rot = data_a_rot[:,:, st:fin, st:fin].contiguous()
        out_a_rot, out_p = model(data_a_rot, True), model(data_p, True)
        # undo the known rotation on the positive's predicted transform
        out_p_rotatad = torch.bmm(inv_rotmat, out_p)
        geom_dist = torch.sqrt((out_a_rot - out_p_rotatad)**2 + 1e-12).mean()
        out_patches_a_crop = extract_and_crop_patches_by_predicted_transform(data_a_rot, out_a_rot, crop_size = model.PS)
        out_patches_p_crop = extract_and_crop_patches_by_predicted_transform(data_p, out_p, crop_size = model.PS)
        desc_a = descriptor(out_patches_a_crop)
        desc_p = descriptor(out_patches_p_crop)
        descr_dist = torch.sqrt(((desc_a - desc_p)**2).view(data_a.size(0),-1).sum(dim=1) + 1e-6)#/ float(desc_a.size(1))
        descr_dist = descr_dist.mean()
        geom_distances.append(geom_dist.data.cpu().numpy().reshape(-1,1))
        desc_distances.append(descr_dist.data.cpu().numpy().reshape(-1,1))
        if batch_idx % args.log_interval == 0:
            pbar.set_description(' Test Epoch: {} [{}/{} ({:.0f}%)]'.format(
                epoch, batch_idx * len(data_a), len(test_loader.dataset),
                100. * batch_idx / len(test_loader)))
    geom_distances = np.vstack(geom_distances).reshape(-1,1)
    desc_distances = np.vstack(desc_distances).reshape(-1,1)
    # \33[91m ... \33[0m: red ANSI terminal color
    print('\33[91mTest set: Geom MSE: {:.8f}\n\33[0m'.format(geom_distances.mean()))
    print('\33[91mTest set: Desc dist: {:.8f}\n\33[0m'.format(desc_distances.mean()))
    return
def forward(self, cropped_feats, nLAFs):
    """Sample 16x16 patches from the feature map at the given LAFs and run
    them through the HardNet head."""
    sampled = extract_patches(cropped_feats, nLAFs, PS=16)
    return self.HNHead(sampled)
def train(train_loader, model, optimizer, epoch, cuda=True):
    """One epoch of detector training on image pairs related by homography
    H1to2: detect LAFs in both images with the global HA, match them to
    ground truth via the homography, and minimize the Frobenius-distance-
    weighted photometric patch difference.

    Python 2 syntax (print statements).  Uses globals: HA, LOG_DIR.
    Saves a checkpoint to LOG_DIR at the end of the epoch.
    """
    # switch to train mode
    model.train()
    log_interval = 1
    total_loss = 0
    total_feats = 0
    spatial_only = True
    pbar = enumerate(train_loader)
    for batch_idx, data in pbar:
        print 'Batch idx', batch_idx
        #print model.detector.shift_net[0].weight.data.cpu().numpy()
        img1, img2, H1to2 = data
        #if np.abs(np.sum(H.numpy()) - 3.0) > 0.01:
        #    continue
        H1to2 = H1to2.squeeze(0)
        # skip very large images (memory guard)
        if (img1.size(3) * img1.size(4) > 1340 * 1000):
            print img1.shape, ' too big, skipping'
            continue
        img1 = img1.float().squeeze(0)
        img2 = img2.float().squeeze(0)
        if cuda:
            img1, img2, H1to2 = img1.cuda(), img2.cuda(), H1to2.cuda()
        img1, img2, H1to2 = Variable(img1, requires_grad=False), Variable(
            img2, requires_grad=False), Variable(H1to2, requires_grad=False)
        LAFs1, aff_norm_patches1, resp1 = HA(img1, True, True, True)
        LAFs2, aff_norm_patches2, resp2 = HA(img2, True, True)
        if (len(LAFs1) == 0) or (len(LAFs2) == 0):
            optimizer.zero_grad()
            continue
        # ground-truth correspondences: frames whose reprojections are within
        # the Frobenius / center-distance thresholds
        fro_dists, idxs_in1, idxs_in2, LAFs2_in_1 = get_GT_correspondence_indexes_Fro_and_center(
            LAFs1,
            LAFs2,
            H1to2,
            dist_threshold=4.,
            center_dist_th=7.0,
            skip_center_in_Fro=True,
            do_up_is_up=True,
            return_LAF2_in_1=True)
        if len(fro_dists.size()) == 0:
            optimizer.zero_grad()
            print 'skip'
            continue
        aff_patches_from_LAFs2_in_1 = extract_patches(
            img1,
            normalizeLAFs(LAFs2_in_1[idxs_in2.data.long(), :, :], img1.size(3),
                          img1.size(2)))
        #loss = fro_dists.mean()
        # per-correspondence mean photometric distance (patches scaled by 1/100)
        patch_dist = torch.sqrt(
            (aff_norm_patches1[idxs_in1.data.long(), :, :, :] / 100.
             - aff_patches_from_LAFs2_in_1 / 100.)**2 + 1e-8).view(
                 fro_dists.size(0), -1).mean(dim=1)
        # geometric distance weights the photometric term
        loss = (fro_dists * patch_dist).mean()
        print 'Fro dist', fro_dists.mean().data.cpu().numpy(
        )[0], loss.data.cpu().numpy()[0]
        total_loss += loss.data.cpu().numpy()[0]
        #loss += patch_dist
        total_feats += fro_dists.size(0)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        #adjust_learning_rate(optimizer)
        print epoch, batch_idx, loss.data.cpu().numpy()[0], idxs_in1.shape
    print 'Train total loss:', total_loss / float(
        batch_idx + 1), ' features ', float(total_feats) / float(batch_idx + 1)
    torch.save({
        'epoch': epoch + 1,
        'state_dict': model.state_dict()
    }, '{}/elu_new_checkpoint_{}.pth'.format(LOG_DIR, epoch))
def test(test_loader, model, epoch):
    """Evaluate `model` with random affine (tilt up to 3.0) + shift
    augmentation on both anchor and positive; prints mean geometric and
    descriptor distances over the test set.

    Uses globals: args, descriptor plus the LAF helper functions.
    """
    # switch to evaluate mode
    model.eval()
    geom_distances, desc_distances = [], []
    pbar = tqdm(enumerate(test_loader))
    for batch_idx, data in pbar:
        data_a, data_p = data
        if args.cuda:
            data_a, data_p = data_a.float().cuda(), data_p.float().cuda()
        # volatile=True: legacy (pre-0.4) PyTorch no-grad inference mode
        data_a, data_p = Variable(
            data_a, volatile=True), Variable(
                data_p, volatile=True)
        st = int((data_p.size(2) - model.PS) / 2)
        fin = st + model.PS
        aff_LAFs_a, inv_TA_a = get_random_norm_affine_LAFs(data_a, 3.0)
        shift_w_a, shift_h_a = get_random_shifts_LAFs(data_a, 3, 3)
        aff_LAFs_a[:, 0, 2] = aff_LAFs_a[:, 0, 2] + shift_w_a / float(
            data_a.size(3))
        aff_LAFs_a[:, 1, 2] = aff_LAFs_a[:, 1, 2] + shift_h_a / float(
            data_a.size(2))
        data_a_aff = extract_patches(data_a, aff_LAFs_a, PS=data_a.size(2))
        data_a_aff_crop = data_a_aff[:, :, st:fin, st:fin].contiguous()
        aff_LAFs_p, inv_TA_p = get_random_norm_affine_LAFs(data_p, 3.0)
        shift_w_p, shift_h_p = get_random_shifts_LAFs(data_p, 3, 3)
        # NOTE(review): the positive's shifts are normalized by data_a's
        # size, not data_p's -- harmless only if both have the same spatial
        # size; confirm.
        aff_LAFs_p[:, 0, 2] = aff_LAFs_p[:, 0, 2] + shift_w_p / float(
            data_a.size(3))
        aff_LAFs_p[:, 1, 2] = aff_LAFs_p[:, 1, 2] + shift_h_p / float(
            data_a.size(2))
        data_p_aff = extract_patches(data_p, aff_LAFs_p, PS=data_p.size(2))
        data_p_aff_crop = data_p_aff[:, :, st:fin, st:fin].contiguous()
        out_a_aff, out_p_aff = model(data_a_aff_crop, True), model(
            data_p_aff_crop, True)
        # undo the known affine augmentation on both predictions
        out_p_aff_back = torch.bmm(inv_TA_p, out_p_aff)
        out_a_aff_back = torch.bmm(inv_TA_a, out_a_aff)
        ######Apply rot and get sifts
        out_patches_a_crop = extract_and_crop_patches_by_predicted_transform(
            data_a_aff, out_a_aff, crop_size=model.PS)
        out_patches_p_crop = extract_and_crop_patches_by_predicted_transform(
            data_p_aff, out_p_aff, crop_size=model.PS)
        desc_a = descriptor(out_patches_a_crop)
        desc_p = descriptor(out_patches_p_crop)
        descr_dist = torch.sqrt((
            (desc_a - desc_p)**2).view(data_a.size(0), -1).sum(dim=1) +
                                1e-6) / float(desc_a.size(1))
        geom_dist = torch.sqrt((
            (out_a_aff_back - out_p_aff_back)**2).view(-1, 4).mean(dim=1) +
                               1e-8)
        geom_distances.append(geom_dist.mean().data.cpu().numpy().reshape(
            -1, 1))
        desc_distances.append(descr_dist.mean().data.cpu().numpy().reshape(
            -1, 1))
        if batch_idx % args.log_interval == 0:
            pbar.set_description(' Test Epoch: {} [{}/{} ({:.0f}%)]'.format(
                epoch, batch_idx * len(data_a), len(test_loader.dataset),
                100. * batch_idx / len(test_loader)))
    geom_distances = np.vstack(geom_distances).reshape(-1, 1)
    desc_distances = np.vstack(desc_distances).reshape(-1, 1)
    # \33[91m ... \33[0m: red ANSI terminal color
    print('\33[91mTest set: Geom MSE: {:.8f}\n\33[0m'.format(
        geom_distances.mean()))
    print('\33[91mTest set: Desc dist: {:.8f}\n\33[0m'.format(
        desc_distances.mean()))
    return
def train(train_loader, model, optimizer, epoch):
    """One training epoch with affine (tilt) + shared-rotation augmentation
    on both anchor and positive patches; the optimized loss is HardNet on
    the rectified-patch descriptors.

    Uses globals: args, descriptor, LOG_DIR plus the LAF helper functions.
    Saves a checkpoint to LOG_DIR at the end of the epoch.
    """
    # switch to train mode
    model.train()
    pbar = tqdm(enumerate(train_loader))
    for batch_idx, data in pbar:
        data_a, data_p = data
        if args.cuda:
            data_a, data_p = data_a.float().cuda(), data_p.float().cuda()
        data_a, data_p = Variable(data_a), Variable(data_p)
        st = int((data_p.size(2) - model.PS) / 2)
        fin = st + model.PS
        # curriculum: increase the maximum tilt as training progresses
        max_tilt = 3.0
        if epoch > 1:
            max_tilt = 4.0
        if epoch > 3:
            max_tilt = 4.5
        if epoch > 5:
            max_tilt = 4.8
        rot_LAFs_a, inv_rotmat_a = get_random_rotation_LAFs(data_a, math.pi)
        aff_LAFs_a, inv_TA_a = get_random_norm_affine_LAFs(data_a, max_tilt)
        # compose the shared random rotation with per-image random affines
        aff_LAFs_a[:, 0:2, 0:2] = torch.bmm(rot_LAFs_a[:, 0:2, 0:2],
                                            aff_LAFs_a[:, 0:2, 0:2])
        data_a_aff = extract_patches(data_a, aff_LAFs_a, PS=data_a.size(2))
        data_a_aff_crop = data_a_aff[:, :, st:fin, st:fin].contiguous()
        aff_LAFs_p, inv_TA_p = get_random_norm_affine_LAFs(data_p, max_tilt)
        aff_LAFs_p[:, 0:2, 0:2] = torch.bmm(rot_LAFs_a[:, 0:2, 0:2],
                                            aff_LAFs_p[:, 0:2, 0:2])
        data_p_aff = extract_patches(data_p, aff_LAFs_p, PS=data_p.size(2))
        data_p_aff_crop = data_p_aff[:, :, st:fin, st:fin].contiguous()
        out_a_aff, out_p_aff = model(data_a_aff_crop, True), model(
            data_p_aff_crop, True)
        # undo the known affine augmentation on both predictions
        out_p_aff_back = torch.bmm(inv_TA_p, out_p_aff)
        out_a_aff_back = torch.bmm(inv_TA_a, out_a_aff)
        ######Apply rot and get sifts
        out_patches_a_crop = extract_and_crop_patches_by_predicted_transform(
            data_a_aff, out_a_aff, crop_size=model.PS)
        out_patches_p_crop = extract_and_crop_patches_by_predicted_transform(
            data_p_aff, out_p_aff, crop_size=model.PS)
        desc_a = descriptor(out_patches_a_crop)
        desc_p = descriptor(out_patches_p_crop)
        descr_dist = torch.sqrt((
            (desc_a - desc_p)**2).view(data_a.size(0), -1).sum(dim=1) + 1e-6)
        descr_loss = loss_HardNet(desc_a, desc_p, anchor_swap=True)
        geom_dist = torch.sqrt((
            (out_a_aff_back - out_p_aff_back)**2).view(-1, 4).mean(dim=1) +
                               1e-8)
        # NOTE(review): 'sum' and 'mul' branches assign the identical
        # descr_loss and geom_dist never enters the loss -- looks like a
        # leftover; confirm the intended merge behavior before changing.
        if args.merge == 'sum':
            loss = descr_loss
        elif args.merge == 'mul':
            loss = descr_loss
        else:
            print('Unknown merge option')
            sys.exit(0)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        adjust_learning_rate(optimizer)
        if batch_idx % 2 == 0:
            pbar.set_description(
                'Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.4f}, {},{:.4f}'.
                format(epoch, batch_idx * len(data_a),
                       len(train_loader.dataset),
                       100. * batch_idx / len(train_loader), loss.data[0],
                       geom_dist.mean().data[0],
                       descr_dist.mean().data[0]))
    torch.save({
        'epoch': epoch + 1,
        'state_dict': model.state_dict()
    }, '{}/checkpoint_{}.pth'.format(LOG_DIR, epoch))
def train(train_loader, model, optimizer, epoch, cuda = True):
    """One epoch of detector training on homography-related image pairs,
    with extra synthetic affine augmentation of the second image when the
    pair's homography is (near-)identity.  The optimized loss is the
    LAFMagic geometric loss; the patch / SIFT distances are computed for
    monitoring only.

    Python 2 syntax (print statements).  Uses globals: HA, SIFT, pr_l,
    LOG_DIR.  Saves a checkpoint to LOG_DIR at the end of the epoch.
    """
    # switch to train mode
    model.train()
    log_interval = 1
    total_loss = 0
    total_feats = 0
    spatial_only = True
    pbar = enumerate(train_loader)
    for batch_idx, data in pbar:
        #if batch_idx > 0:
        #    continue
        print 'Batch idx', batch_idx
        #print model.detector.shift_net[0].weight.data.cpu().numpy()
        img1, img2, H1to2 = data
        #if np.abs(np.sum(H.numpy()) - 3.0) > 0.01:
        #    continue
        H1to2 = H1to2.squeeze(0)
        # only synthesize extra augmentation when the pair is related by a
        # (near-)identity homography; otherwise keep the real one as-is
        do_aug = True
        if torch.abs(H1to2 - torch.eye(3)).sum() > 0.05:
            do_aug = False
        # skip very large images (memory guard)
        if (img1.size(3) *img1.size(4) > 1340*1000):
            print img1.shape, ' too big, skipping'
            continue
        img1 = img1.float().squeeze(0)
        img2 = img2.float().squeeze(0)
        if cuda:
            img1, img2, H1to2 = img1.cuda(), img2.cuda(), H1to2.cuda()
        img1, img2, H1to2 = Variable(img1, requires_grad = False), Variable(img2, requires_grad = False), Variable(H1to2, requires_grad = False)
        if do_aug:
            new_img2, H_Orig2New = affineAug(img2, max_add = 0.2 )
            # chain the synthetic augmentation onto the pair homography
            H1to2new = torch.mm(H_Orig2New, H1to2)
        else:
            new_img2 = img2
            H1to2new = H1to2
        #print H1to2
        LAFs1, aff_norm_patches1, resp1, dets1, A1 = HA(img1, True, False, True)
        LAFs2Aug, aff_norm_patches2, resp2, dets2, A2 = HA(new_img2, True, False)
        if (len(LAFs1) == 0) or (len(LAFs2Aug) == 0):
            optimizer.zero_grad()
            continue
        geom_loss, idxs_in1, idxs_in2, LAFs2_in_1 = LAFMagic(LAFs1,
                                                             LAFs2Aug,
                                                             H1to2new,
                                                             3.0,
                                                             scale_log = 0.3)
        if len(idxs_in1.size()) == 0:
            optimizer.zero_grad()
            print 'skip'
            continue
        aff_patches_from_LAFs2_in_1 = extract_patches(img1, normalizeLAFs(LAFs2_in_1[idxs_in2.long(),:,:], img1.size(3), img1.size(2)))
        SIFTs1 = SIFT(aff_norm_patches1[idxs_in1.long(),:,:,:]).cuda()
        SIFTs2 = SIFT(aff_patches_from_LAFs2_in_1).cuda()
        #sift_snn_loss = loss_HardNet(SIFTs1, SIFTs2, column_row_swap = True,
        #                             margin = 1.0, batch_reduce = 'min', loss_type = "triplet_margin");
        # monitoring-only distances (not part of the optimized loss)
        patch_dist = 2.0 * torch.sqrt((aff_norm_patches1[idxs_in1.long(),:,:,:]/100. - aff_patches_from_LAFs2_in_1/100.)**2 + 1e-8).view(idxs_in1.size(0),-1).mean(dim = 1)
        sift_dist = torch.sqrt(((SIFTs1 - SIFTs2)**2 + 1e-12).mean(dim=1))
        loss = geom_loss.cuda().mean()
        total_loss += loss.data.cpu().numpy()[0]
        #loss += patch_dist
        total_feats += aff_patches_from_LAFs2_in_1.size(0)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        adjust_learning_rate(optimizer)
        if batch_idx % 10 == 0:
            print 'A', A1.data.cpu().numpy()[0:1,:,:]
            print 'crafted loss', pr_l(geom_loss), 'patch', pr_l(patch_dist), 'sift', pr_l(sift_dist)#, 'hardnet', pr_l(sift_snn_loss)
        print epoch,batch_idx, loss.data.cpu().numpy()[0], idxs_in1.shape
    print 'Train total loss:', total_loss / float(batch_idx+1), ' features ', float(total_feats) / float(batch_idx+1)
    torch.save({'epoch': epoch + 1, 'state_dict': model.state_dict()}, '{}/new_loss_sep_checkpoint_{}.pth'.format(LOG_DIR, epoch))
def train(train_loader, model, optimizer, epoch, cuda=True):
    """One epoch of detector training on homography-related image pairs:
    minimizes the mean Frobenius distance of corresponding LAFs plus a
    down-weighted photometric patch term.

    Python 2 syntax (print statements).  Uses globals: HA, LOG_DIR.
    Saves a checkpoint to LOG_DIR at the end of the epoch.
    """
    # switch to train mode
    model.train()
    log_interval = 1
    total_loss = 0
    spatial_only = True
    pbar = enumerate(train_loader)
    for batch_idx, data in pbar:
        print 'Batch idx', batch_idx
        #print model.detector.shift_net[0].weight.data.cpu().numpy()
        img1, img2, H1to2 = data
        #if np.abs(np.sum(H.numpy()) - 3.0) > 0.01:
        #    continue
        H1to2 = H1to2.squeeze(0)
        # skip very large images (memory guard)
        if (img1.size(3) * img1.size(4) > 1340 * 1000):
            print img1.shape, ' too big, skipping'
            continue
        img1 = img1.float().squeeze(0)
        #img1 = img1 - img1.mean()
        #img1 = img1 / 50.#(img1.std() + 1e-8)
        img2 = img2.float().squeeze(0)
        #img2 = img2 - img2.mean()
        #img2 = img2 / 50.#(img2.std() + 1e-8)
        if cuda:
            img1, img2, H1to2 = img1.cuda(), img2.cuda(), H1to2.cuda()
        img1, img2, H1to2 = Variable(img1, requires_grad=False), Variable(
            img2, requires_grad=False), Variable(H1to2, requires_grad=False)
        LAFs1, aff_norm_patches1, resp1, pyr1 = HA(img1)
        LAFs2, aff_norm_patches2, resp2, pyr2 = HA(img2)
        if (len(LAFs1) == 0) or (len(LAFs2) == 0):
            optimizer.zero_grad()
            continue
        fro_dists, idxs_in1, idxs_in2, LAFs2_in_1 = get_GT_correspondence_indexes_Fro_and_center(
            LAFs1,
            LAFs2,
            H1to2,
            dist_threshold=2.,
            center_dist_th=5.0,
            skip_center_in_Fro=True,
            do_up_is_up=True,
            return_LAF2_in_1=True)
        # NOTE(review): unlike the sibling train(), LAFs2_in_1 is passed to
        # extract_patches without normalizeLAFs here -- confirm the expected
        # coordinate convention of extract_patches before relying on it.
        aff_patches_from_LAFs2_in_1 = extract_patches(img1, LAFs2_in_1)
        if len(fro_dists.size()) == 0:
            optimizer.zero_grad()
            print 'skip'
            continue
        loss = fro_dists.mean()
        total_loss += loss.data.cpu().numpy()[0]
        patch_dist = torch.mean(
            (aff_norm_patches1[idxs_in1.data.long(), :, :, :] -
             aff_patches_from_LAFs2_in_1[idxs_in2.data.long(), :, :, :])**2)
        print loss.data.cpu().numpy()[0], patch_dist.data.cpu().numpy()[0]
        # photometric term, down-weighted by 100
        loss += patch_dist / 100.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        #adjust_learning_rate(optimizer)
        print epoch, batch_idx, loss.data.cpu().numpy()[0], idxs_in1.shape
    print 'Train total loss:', total_loss / float(batch_idx + 1)
    torch.save({
        'epoch': epoch + 1,
        'state_dict': model.state_dict()
    }, '{}/checkpoint_{}.pth'.format(LOG_DIR, epoch))