import torch
from torch.autograd import Variable

# HA, SIFT, affineAug, LAFMagic, extract_patches, normalizeLAFs,
# get_GT_correspondence_indexes_Fro_and_center and LOG_DIR are expected to be
# defined or imported elsewhere in the repository; hedged sketches of the
# small helpers pr_l and adjust_learning_rate appear between the two training
# functions below.


def train(train_loader, model, optimizer, epoch, cuda=True):
    # switch to train mode
    model.train()
    log_interval = 1
    total_loss = 0
    total_feats = 0
    spatial_only = True
    pbar = enumerate(train_loader)
    for batch_idx, data in pbar:
        print 'Batch idx', batch_idx
        #print model.detector.shift_net[0].weight.data.cpu().numpy()
        img1, img2, H1to2  = data
        H1to2 = H1to2.squeeze(0)
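        # Heuristic: pairs whose ground-truth homography is (near) the
        # identity are treated as "same view" and get extra synthetic affine
        # augmentation; pairs with a real geometric change are left as-is.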
        do_aug = True
        if torch.abs(H1to2 - torch.eye(3)).sum() > 0.05:
            do_aug = False
        if img1.size(3) * img1.size(4) > 1340 * 1000:
            print img1.shape, ' too big, skipping'
            continue
        img1 = img1.float().squeeze(0)
        img2 = img2.float().squeeze(0)
        if cuda:
            img1, img2, H1to2 = img1.cuda(), img2.cuda(), H1to2.cuda()
        img1, img2, H1to2 = Variable(img1, requires_grad=False), Variable(
            img2, requires_grad=False), Variable(H1to2, requires_grad=False)
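        # affineAug warps img2 with a random affine transform and returns the
        # warped image together with the homography H_Orig2New of that warp,
        # so the ground truth can be composed: H1to2new = H_Orig2New * H1to2.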
        if do_aug:
            new_img2, H_Orig2New = affineAug(img2, max_add=0.2)
            H1to2new = torch.mm(H_Orig2New, H1to2)
        else:
            new_img2 = img2
            H1to2new = H1to2
        #print H1to2
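        # HA is the detector: given an image it returns local affine frames
        # (LAFs), affine-normalized patches, responses and, depending on the
        # positional boolean flags, intermediate detections and the estimated
        # affine shapes A. The exact flag semantics are repo-specific.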
        LAFs1, aff_norm_patches1, resp1, dets1, A1 = HA(img1, True, False, True)
        LAFs2Aug, aff_norm_patches2, resp2, dets2, A2 = HA(new_img2, True, False)
        if (len(LAFs1) == 0) or (len(LAFs2Aug) == 0):
            optimizer.zero_grad()
            continue
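        # LAFMagic reprojects the LAFs detected in the augmented image back
        # into image 1 via H1to2new, matches them against LAFs1, and returns a
        # geometric loss plus the indices of mutually consistent detections.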
        geom_loss, idxs_in1, idxs_in2, LAFs2_in_1 = LAFMagic(
            LAFs1,
            LAFs2Aug,
            H1to2new,
            3.0,
            scale_log=0.3)
        if len(idxs_in1.size()) == 0:
            optimizer.zero_grad()
            print 'skip'
            continue
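        # Re-extract patches from img1 at the locations of the reprojected
        # LAFs, so patch appearance can be compared across the two views.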
        aff_patches_from_LAFs2_in_1 = extract_patches(
            img1,
            normalizeLAFs(LAFs2_in_1[idxs_in2.long(), :, :], img1.size(3),
                          img1.size(2)))
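        # SIFT descriptors of the corresponding patches; used below only for
        # monitoring (sift_dist), not for the optimized loss. Note that the
        # hard-coded .cuda() calls here and below assume cuda=True.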
        SIFTs1 = SIFT(aff_norm_patches1[idxs_in1.long(),:,:,:]).cuda()
        SIFTs2 = SIFT(aff_patches_from_LAFs2_in_1).cuda()
        #sift_snn_loss = loss_HardNet(SIFTs1, SIFTs2, column_row_swap = True,
        #                 margin = 1.0, batch_reduce = 'min', loss_type = "triplet_margin");
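        # Mean absolute photometric difference between corresponding
        # normalized patches; sqrt(x**2 + eps) is a smooth |x|. Like
        # sift_dist, this is logged for monitoring but not backpropagated.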
        patch_dist = 2.0 * torch.sqrt(
            (aff_norm_patches1[idxs_in1.long(), :, :, :] / 100. -
             aff_patches_from_LAFs2_in_1 / 100.)**2 + 1e-8).view(
                 idxs_in1.size(0), -1).mean(dim=1)
        sift_dist = torch.sqrt(((SIFTs1 - SIFTs2)**2 + 1e-12).mean(dim=1))
        loss = geom_loss.cuda().mean()
        total_loss += loss.data.cpu().numpy()[0]
        #loss += patch_dist
        total_feats += aff_patches_from_LAFs2_in_1.size(0)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        adjust_learning_rate(optimizer)
        if batch_idx % 10 == 0:
            print 'A', A1.data.cpu().numpy()[0:1,:,:]
        print 'crafted loss', pr_l(geom_loss), 'patch', pr_l(patch_dist), 'sift', pr_l(sift_dist)  #, 'hardnet', pr_l(sift_snn_loss)
        print epoch,batch_idx, loss.data.cpu().numpy()[0], idxs_in1.shape

    print 'Train total loss:', total_loss / float(batch_idx+1), ' features ', float(total_feats) / float(batch_idx+1)
    torch.save({'epoch': epoch + 1, 'state_dict': model.state_dict()},
               '{}/new_loss_sep_checkpoint_{}.pth'.format(LOG_DIR, epoch))
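

# pr_l and adjust_learning_rate are used above but not defined in this file.
# The two functions below are minimal sketches of plausible implementations,
# written in the PyTorch 0.3 style of the rest of this file; the decay
# schedule in particular is an assumption, not the repository's actual code.


def pr_l(loss):
    # Scalar mean of a loss tensor as a Python float, for compact printing.
    return float(loss.mean().data.cpu().numpy()[0])


def adjust_learning_rate(optimizer, decay=0.9999):
    # Per-step multiplicative learning-rate decay, applied to every param group.
    for group in optimizer.param_groups:
        group['lr'] = group['lr'] * decay
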

# Variant of train() above: same loop structure, but LAFs are matched with
# get_GT_correspondence_indexes_Fro_and_center and the Frobenius frame
# distance is weighted by a photometric patch distance. Defined under a
# distinct name so it does not silently shadow the first definition.
def train_fro(train_loader, model, optimizer, epoch, cuda=True):
    # switch to train mode
    model.train()
    log_interval = 1
    total_loss = 0
    total_feats = 0
    spatial_only = True
    pbar = enumerate(train_loader)
    for batch_idx, data in pbar:
        print 'Batch idx', batch_idx
        #print model.detector.shift_net[0].weight.data.cpu().numpy()
        img1, img2, H1to2 = data
        H1to2 = H1to2.squeeze(0)
        if img1.size(3) * img1.size(4) > 1340 * 1000:
            print img1.shape, ' too big, skipping'
            continue
        img1 = img1.float().squeeze(0)
        img2 = img2.float().squeeze(0)
        if cuda:
            img1, img2, H1to2 = img1.cuda(), img2.cuda(), H1to2.cuda()
        img1, img2, H1to2 = Variable(img1, requires_grad=False), Variable(
            img2, requires_grad=False), Variable(H1to2, requires_grad=False)
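        # Unlike the first variant, augmentation is applied unconditionally
        # here, with affineAug's default magnitude.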
        new_img2, H_Orig2New = affineAug(img2)
        H1to2new = torch.mm(H_Orig2New, H1to2)
        #print H1to2
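        # With these flags HA returns only LAFs, normalized patches and
        # responses (three values, vs. five in the variant above).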
        LAFs1, aff_norm_patches1, resp1 = HA(img1, True, True, True)
        LAFs2Aug, aff_norm_patches2, resp2 = HA(new_img2, True, True)
        if (len(LAFs1) == 0) or (len(LAFs2Aug) == 0):
            optimizer.zero_grad()
            continue
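        # Match LAFs across views: reproject LAFs2Aug into image 1 through
        # H1to2new and keep pairs whose center distances and Frobenius-norm
        # frame distances pass the thresholds below.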
        fro_dists, idxs_in1, idxs_in2, LAFs2_in_1 = get_GT_correspondence_indexes_Fro_and_center(
            LAFs1,
            LAFs2Aug,
            H1to2new,
            dist_threshold=10.0,
            center_dist_th=7.0,
            scale_diff_coef=0.4,
            skip_center_in_Fro=True,
            do_up_is_up=True,
            return_LAF2_in_1=True)
        if len(fro_dists.size()) == 0:
            optimizer.zero_grad()
            print 'skip'
            continue
        aff_patches_from_LAFs2_in_1 = extract_patches(
            img1,
            normalizeLAFs(LAFs2_in_1[idxs_in2.data.long(), :, :], img1.size(3),
                          img1.size(2)))

        #loss = fro_dists.mean()
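        # Photometric distance between corresponding patches; here it weights
        # the geometric (Frobenius) distance in the final loss, rather than
        # being logged only.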
        patch_dist = torch.sqrt(
            (aff_norm_patches1[idxs_in1.data.long(), :, :, :] / 100. -
             aff_patches_from_LAFs2_in_1 / 100.)**2 + 1e-8).view(
                 fro_dists.size(0), -1).mean(dim=1)
        loss = (fro_dists * patch_dist).mean()
        print 'Fro dist', fro_dists.mean().data.cpu().numpy()[0], \
            loss.data.cpu().numpy()[0]
        total_loss += loss.data.cpu().numpy()[0]
        #loss += patch_dist
        total_feats += fro_dists.size(0)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        #adjust_learning_rate(optimizer)
        print epoch, batch_idx, loss.data.cpu().numpy()[0], idxs_in1.shape

    print 'Train total loss:', total_loss / float(
        batch_idx + 1), ' features ', float(total_feats) / float(batch_idx + 1)
    torch.save({
        'epoch': epoch + 1,
        'state_dict': model.state_dict()
    }, '{}/elu_new_checkpoint_{}.pth'.format(LOG_DIR, epoch))
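

# Hypothetical usage sketch (not part of this file): both variants share the
# same signature, so a driver might look like the following, assuming the
# model, data loader and hyperparameters are built elsewhere in the repo:
#
#   optimizer = torch.optim.SGD(model.parameters(), lr=0.005, momentum=0.9)
#   for epoch in range(n_epochs):
#       train(train_loader, model, optimizer, epoch, cuda=True)
#       # or: train_fro(train_loader, model, optimizer, epoch, cuda=True)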