Exemplo n.º 1
0
def train(train_loader, model, optimizer, epoch, cuda=True):
    """Run one training epoch over image pairs and save a checkpoint.

    Every batch is unpacked as ``(img1, img2, H1to2)``. Local affine frames
    (LAFs) are detected in both images with the module-level ``HA``, matched
    against ``H1to2`` with ``get_GT_correspondence_indexes_Fro_and_center``,
    and the mean of the returned ``fro_dists`` is minimised with ``optimizer``.

    A batch is skipped (no optimizer step) when the image has more than
    1340 * 1000 pixels, when either image yields no LAFs, or when no
    correspondences are found. The reported average loss is divided by the
    number of batches iterated, skipped ones included.

    Args:
        train_loader: iterable of ``(img1, img2, H1to2)`` tensors. Each tensor
            is ``squeeze(0)``-ed, so presumably batch size 1 -- TODO confirm.
        model: module put into train mode; its ``state_dict()`` is saved.
        optimizer: optimizer stepped once per non-skipped batch.
        epoch: epoch index used in log lines and in the checkpoint filename.
        cuda: if True, move the tensors to the GPU before processing.
    """
    model.train()
    total_loss = 0
    # Counted explicitly so an empty loader still reaches the summary and the
    # checkpoint instead of referencing an unbound loop variable.
    num_batches = 0
    for batch_idx, data in enumerate(train_loader):
        num_batches += 1
        # Single-argument print() produces the same output on Python 2 and 3.
        print('Batch idx %s' % (batch_idx,))
        img1, img2, H1to2 = data
        H1to2 = H1to2.squeeze(0)
        if img1.size(3) * img1.size(4) > 1340 * 1000:
            print('%s  too big, skipping' % (img1.shape,))
            continue
        img1 = img1.float().squeeze(0)
        img2 = img2.float().squeeze(0)
        if cuda:
            img1, img2, H1to2 = img1.cuda(), img2.cuda(), H1to2.cuda()
        img1 = Variable(img1, requires_grad=False)
        img2 = Variable(img2, requires_grad=False)
        H1to2 = Variable(H1to2, requires_grad=False)
        # NOTE(review): detection uses the global ``HA`` rather than the
        # ``model`` argument; presumably they are the same object -- verify.
        LAFs1, aff_norm_patches1, resp1, pyr1 = HA(img1)
        LAFs2, aff_norm_patches2, resp2, pyr2 = HA(img2)
        if len(LAFs1) == 0 or len(LAFs2) == 0:
            optimizer.zero_grad()
            continue
        fro_dists, idxs_in1, idxs_in2 = get_GT_correspondence_indexes_Fro_and_center(
            LAFs1,
            LAFs2,
            H1to2,
            dist_threshold=2.,
            center_dist_th=5.0,
            skip_center_in_Fro=True,
            do_up_is_up=True)
        if len(fro_dists.size()) == 0:
            # No ground-truth correspondences for this pair.
            optimizer.zero_grad()
            print('skip')
            continue
        loss = fro_dists.mean()
        # Fetch the scalar once (a single device-to-host copy) and reuse it.
        # NOTE(review): ``[0]`` assumes pre-0.4 PyTorch, where the mean is a
        # 1-element tensor; on newer versions it is 0-dim and ``loss.item()``
        # would be needed instead.
        loss_value = loss.data.cpu().numpy()[0]
        total_loss += loss_value
        print(loss_value)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        print('%s %s %s %s' % (epoch, batch_idx, loss_value, idxs_in1.shape))

    print('Train total loss: %s' % (total_loss / float(max(num_batches, 1)),))
    torch.save({
        'epoch': epoch + 1,
        'state_dict': model.state_dict()
    }, '{}/checkpoint_{}.pth'.format(LOG_DIR, epoch))
def test(test_loader, model, cuda=True):
    """Evaluate ground-truth LAF matching on image pairs and print statistics.

    For every ``(img1, img2, H1to2)`` batch, LAFs are detected in both images
    with the module-level ``HA`` and matched against ``H1to2`` with
    ``get_GT_correspondence_indexes_Fro_and_center``. For each pair it prints
    the mean of ``fro_dists`` and the number of matches. At the end it prints
    both totals averaged over the number of batches iterated, skipped ones
    included.

    ``model.num`` is set to 1500 for the duration of the evaluation and is
    restored on exit, also when an exception propagates.

    Args:
        test_loader: iterable of ``(img1, img2, H1to2)`` tensors. Each tensor
            is ``squeeze(0)``-ed, so presumably batch size 1 -- TODO confirm.
        model: module put into eval mode; must expose a ``num`` attribute.
        cuda: if True, move the tensors to the GPU before processing.
    """
    model_num_feats = model.num
    model.num = 1500
    try:
        model.eval()  # switch to evaluation mode
        total_loss = 0
        total_feats = 0
        # Counted explicitly so an empty loader does not leave the averaging
        # step referencing an unbound loop variable.
        num_batches = 0
        for batch_idx, data in enumerate(test_loader):
            num_batches += 1
            # Single-argument print() produces the same output on Python 2 and 3.
            print('Batch idx %s' % (batch_idx,))
            img1, img2, H1to2 = data
            if img1.size(3) * img1.size(4) > 1500 * 1200:
                print('%s  too big, skipping' % (img1.shape,))
                continue
            H1to2 = H1to2.squeeze(0)
            img1 = img1.float().squeeze(0)
            img2 = img2.float().squeeze(0)
            if cuda:
                img1, img2, H1to2 = img1.cuda(), img2.cuda(), H1to2.cuda()
            # ``volatile=True`` is the pre-0.4 PyTorch way to run inference
            # without recording an autograd graph.
            img1 = Variable(img1, volatile=True)
            img2 = Variable(img2, volatile=True)
            H1to2 = Variable(H1to2, volatile=True)
            # NOTE(review): detection uses the global ``HA``; overriding
            # ``model.num`` above only matters if ``HA`` is (or wraps) ``model``
            # -- verify against caller.
            LAFs1, aff_norm_patches1, resp1 = HA(img1)
            LAFs2, aff_norm_patches2, resp2 = HA(img2)
            if len(LAFs1) == 0 or len(LAFs2) == 0:
                continue
            fro_dists, idxs_in1, idxs_in2 = get_GT_correspondence_indexes_Fro_and_center(
                LAFs1,
                LAFs2,
                H1to2,
                dist_threshold=3.,
                center_dist_th=7.0,
                scale_diff_coef=0.4,
                skip_center_in_Fro=True,
                do_up_is_up=True)
            if len(fro_dists.size()) == 0:
                # No ground-truth correspondences for this pair.
                print('skip')
                continue
            loss = fro_dists.mean()
            # Fetch the scalar once; ``[0]`` assumes pre-0.4 PyTorch (see
            # ``Variable``/``volatile`` above) -- newer versions need .item().
            loss_value = loss.data.cpu().numpy()[0]
            total_feats += fro_dists.size(0)
            total_loss += loss_value
            print('test img %s %s %s' % (batch_idx, loss_value, fro_dists.size(0)))
        denom = float(max(num_batches, 1))
        print('Total loss: %s features %s' % (total_loss / denom,
                                               float(total_feats) / denom))
    finally:
        # Restore the caller's feature count even if evaluation failed.
        model.num = model_num_feats
Exemplo n.º 3
0
def train(train_loader, model, optimizer, epoch, cuda=True):
    """Run one training epoch with a patch-weighted LAF loss; save a checkpoint.

    LAFs are detected in both images of each ``(img1, img2, H1to2)`` batch with
    the module-level ``HA`` and matched against ``H1to2``. The matcher also
    returns ``LAFs2_in_1`` (presumably image-2 LAFs mapped into image-1
    coordinates -- TODO confirm). Patches are extracted from ``img1`` at the
    matched ``LAFs2_in_1``, and the per-match ``fro_dists`` are weighted by a
    smoothed absolute difference between those patches and
    ``aff_norm_patches1``. The mean of that product is minimised.

    A batch is skipped when the image has more than 1340 * 1000 pixels, when
    either image yields no LAFs, or when no correspondences are found. The
    reported averages divide by the number of batches iterated, skipped ones
    included.

    Args:
        train_loader: iterable of ``(img1, img2, H1to2)`` tensors. Each tensor
            is ``squeeze(0)``-ed, so presumably batch size 1 -- TODO confirm.
        model: module put into train mode; its ``state_dict()`` is saved.
        optimizer: optimizer stepped once per non-skipped batch.
        epoch: epoch index used in log lines and in the checkpoint filename.
        cuda: if True, move the tensors to the GPU before processing.
    """
    model.train()
    total_loss = 0
    total_feats = 0
    # Counted explicitly so an empty loader still reaches the summary and the
    # checkpoint instead of referencing an unbound loop variable.
    num_batches = 0
    for batch_idx, data in enumerate(train_loader):
        num_batches += 1
        # Single-argument print() produces the same output on Python 2 and 3.
        print('Batch idx %s' % (batch_idx,))
        img1, img2, H1to2 = data
        H1to2 = H1to2.squeeze(0)
        if img1.size(3) * img1.size(4) > 1340 * 1000:
            print('%s  too big, skipping' % (img1.shape,))
            continue
        img1 = img1.float().squeeze(0)
        img2 = img2.float().squeeze(0)
        if cuda:
            img1, img2, H1to2 = img1.cuda(), img2.cuda(), H1to2.cuda()
        img1 = Variable(img1, requires_grad=False)
        img2 = Variable(img2, requires_grad=False)
        H1to2 = Variable(H1to2, requires_grad=False)
        # NOTE(review): the two images are detected with different flag
        # counts (three vs. two positional ``True``s), and the global ``HA``
        # is used rather than ``model`` -- confirm both are intentional.
        LAFs1, aff_norm_patches1, resp1 = HA(img1, True, True, True)
        LAFs2, aff_norm_patches2, resp2 = HA(img2, True, True)
        if len(LAFs1) == 0 or len(LAFs2) == 0:
            optimizer.zero_grad()
            continue
        fro_dists, idxs_in1, idxs_in2, LAFs2_in_1 = get_GT_correspondence_indexes_Fro_and_center(
            LAFs1,
            LAFs2,
            H1to2,
            dist_threshold=4.,
            center_dist_th=7.0,
            skip_center_in_Fro=True,
            do_up_is_up=True,
            return_LAF2_in_1=True)
        if len(fro_dists.size()) == 0:
            # No ground-truth correspondences for this pair.
            optimizer.zero_grad()
            print('skip')
            continue
        aff_patches_from_LAFs2_in_1 = extract_patches(
            img1,
            normalizeLAFs(LAFs2_in_1[idxs_in2.data.long(), :, :], img1.size(3),
                          img1.size(2)))

        # sqrt(d**2 + eps) is a smoothed |d| that stays differentiable at 0;
        # it is averaged over all patch elements of each match.
        patch_dist = torch.sqrt(
            (aff_norm_patches1[idxs_in1.data.long(), :, :, :] / 100. -
             aff_patches_from_LAFs2_in_1 / 100.)**2 + 1e-8).view(
                 fro_dists.size(0), -1).mean(dim=1)
        loss = (fro_dists * patch_dist).mean()
        # Fetch the scalar once (a single device-to-host copy) and reuse it.
        # NOTE(review): ``[0]`` assumes pre-0.4 PyTorch; newer versions return
        # a 0-dim tensor here and need ``loss.item()``.
        loss_value = loss.data.cpu().numpy()[0]
        print('Fro dist %s %s' % (fro_dists.mean().data.cpu().numpy()[0],
                                  loss_value))
        total_loss += loss_value
        total_feats += fro_dists.size(0)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        print('%s %s %s %s' % (epoch, batch_idx, loss_value, idxs_in1.shape))

    denom = float(max(num_batches, 1))
    print('Train total loss: %s  features  %s' % (total_loss / denom,
                                                   float(total_feats) / denom))
    torch.save({
        'epoch': epoch + 1,
        'state_dict': model.state_dict()
    }, '{}/elu_new_checkpoint_{}.pth'.format(LOG_DIR, epoch))