def train(train_loader, model, optimizer, epoch):
    """Run one training epoch over ``train_loader`` and save a checkpoint.

    For every ``(data_a, data_p)`` batch the anchor batch is re-sampled with
    LAFs from ``get_random_rotation_LAFs(data_a, math.pi)`` whose 2x2 part is
    multiplied by a per-sample factor drawn uniformly from [0.9, 1.2) and
    whose translation column is offset by ``get_random_shifts_LAFs``.  The
    model is evaluated on the central ``model.PS`` crop of the warped anchor,
    of the positive and of the unwarped anchor.  Patches re-extracted with the
    predicted transforms are passed to the module-level ``descriptor`` and the
    loss selected by ``args.loss`` is optimised.

    Args:
        train_loader: iterable yielding ``(data_a, data_p)`` batches; it must
            also support ``len()`` and expose ``.dataset`` (used in the
            progress message).
        model: network called as ``model(x, True)``; must expose ``PS``.
        optimizer: optimizer stepped once per batch and handed to
            ``adjust_learning_rate`` afterwards.
        epoch: epoch index used in the progress message and checkpoint name.

    Returns:
        None.  Writes ``'{LOG_DIR}/checkpoint_{epoch}.pth'`` when the loader
        is exhausted.  Prints a message and exits with status 1 if
        ``args.loss`` is not one of ``'HardNet'``, ``'HardNetDetach'``,
        ``'Geom'`` or ``'PosDist'``.
    """
    # switch to train mode
    model.train()
    pbar = tqdm(enumerate(train_loader))
    for batch_idx, data in pbar:
        data_a, data_p = data
        if args.cuda:
            data_a, data_p = data_a.float().cuda(), data_p.float().cuda()
        # Wrap on the CPU path as well; previously only CUDA batches were
        # wrapped, so CPU runs fed raw tensors to the model.
        data_a, data_p = Variable(data_a), Variable(data_p)
        batch_size = data_a.size(0)

        # Random geometric perturbation of the anchor batch.
        rot_LAFs, inv_rotmat = get_random_rotation_LAFs(data_a, math.pi)
        scale = Variable(0.9 + 0.3 * torch.rand(batch_size, 1, 1))
        if args.cuda:
            scale = scale.cuda()
        rot_LAFs[:, 0:2, 0:2] = rot_LAFs[:, 0:2, 0:2] * scale.expand(batch_size, 2, 2)
        shift_w, shift_h = get_random_shifts_LAFs(data_a, 2, 2)
        # Shifts are divided by the size of dims 3 and 2 of the batch
        # (presumably width and height of a (B, C, H, W) tensor - TODO confirm).
        rot_LAFs[:, 0, 2] = rot_LAFs[:, 0, 2] + shift_w / float(data_a.size(3))
        rot_LAFs[:, 1, 2] = rot_LAFs[:, 1, 2] + shift_h / float(data_a.size(2))
        data_a_rot = extract_patches(data_a, rot_LAFs, PS=data_a.size(2))

        # Central model.PS x model.PS window fed to the network.
        st = int((data_p.size(2) - model.PS) / 2)
        fin = st + model.PS
        data_p_crop = data_p[:, :, st:fin, st:fin].contiguous()
        data_a_rot_crop = data_a_rot[:, :, st:fin, st:fin].contiguous()

        out_a_rot = model(data_a_rot_crop, True)
        out_p = model(data_p_crop, True)
        # NOTE(review): out_a is not used below.  The call is retained because
        # a forward pass in train mode may update layer statistics; confirm
        # whether it can be dropped.
        out_a = model(data_a[:, :, st:fin, st:fin].contiguous(), True)
        out_p_rotatad = torch.bmm(inv_rotmat, out_p)

        ######Apply rot and get sifts
        out_patches_a_crop = extract_and_crop_patches_by_predicted_transform(
            data_a_rot, out_a_rot, crop_size=model.PS)
        out_patches_p_crop = extract_and_crop_patches_by_predicted_transform(
            data_p, out_p, crop_size=model.PS)

        desc_a = descriptor(out_patches_a_crop)
        desc_p = descriptor(out_patches_p_crop)
        descr_dist = torch.sqrt(
            ((desc_a - desc_p) ** 2).view(batch_size, -1).sum(dim=1) + 1e-6
        ).mean()
        # Per-sample L2 distance (each prediction flattened to 4 values),
        # averaged over the batch.  The previous ``.sum(dim=1)[0]`` kept only
        # the first sample, so the rest of the batch never contributed.
        geom_dist = torch.sqrt(
            ((out_a_rot - out_p_rotatad) ** 2).view(-1, 4).sum(dim=1) + 1e-8
        ).mean()

        if args.loss == 'HardNet':
            loss = loss_HardNet(desc_a, desc_p)
        elif args.loss == 'HardNetDetach':
            loss = loss_HardNetDetach(desc_a, desc_p)
        elif args.loss == 'Geom':
            loss = geom_dist
        elif args.loss == 'PosDist':
            loss = descr_dist
        else:
            print('Unknown loss function')
            sys.exit(1)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        adjust_learning_rate(optimizer)

        if batch_idx % args.log_interval == 0:
            pbar.set_description(
                'Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.4f}, {:.4f},{:.4f}'.format(
                    epoch, batch_idx * len(data_a), len(train_loader.dataset),
                    100. * batch_idx / len(train_loader),
                    float(loss.detach().cpu().numpy()),
                    float(geom_dist.detach().cpu().numpy()),
                    float(descr_dist.detach().cpu().numpy())))
    torch.save({'epoch': epoch + 1, 'state_dict': model.state_dict()},
               '{}/checkpoint_{}.pth'.format(LOG_DIR, epoch))
def extract_random_LAF(data, max_rot = math.pi, max_tilt = 1.0, crop_size = 32):
    """Warp ``data`` with a random affine LAF and return the warped patches.

    Args:
        data: input batch; dimension 2 is used as the patch size for
            ``extract_patches`` and for centring the crop (assumes a
            (B, C, H, W) layout with H == W - TODO confirm).
        max_rot: either a plain number, forwarded to
            ``get_random_rotation_LAFs(data, max_rot)``, or an already built
            batch of rotation LAFs that is used as-is.
        max_tilt: forwarded to ``get_random_norm_affine_LAFs``.
        crop_size: side length of the central window cut from the warped
            patches.

    Returns:
        Tuple ``(data_affcrop, data_aff, rot_LAFs, inv_rotmat, inv_TA)``:
        the central crop of the warped batch, the full warped batch, the
        rotation LAFs that were used, the inverse rotation returned by
        ``get_random_rotation_LAFs`` (``None`` when ``max_rot`` was supplied
        as LAFs) and the inverse returned by ``get_random_norm_affine_LAFs``.
    """
    st = int((data.size(2) - crop_size) / 2)
    fin = st + crop_size
    # isinstance instead of ``type(...) is float``: ints and float subclasses
    # are angles too; the exact-type test sent them down the "precomputed
    # LAFs" branch, where indexing a scalar fails.
    if isinstance(max_rot, (int, float)):
        rot_LAFs, inv_rotmat = get_random_rotation_LAFs(data, max_rot)
    else:
        rot_LAFs = max_rot
        inv_rotmat = None
    aff_LAFs, inv_TA = get_random_norm_affine_LAFs(data, max_tilt)
    # Compose the rotation with the affine part: A <- R @ A on the 2x2 block.
    aff_LAFs[:, 0:2, 0:2] = torch.bmm(rot_LAFs[:, 0:2, 0:2], aff_LAFs[:, 0:2, 0:2])
    data_aff = extract_patches(data, aff_LAFs, PS=data.size(2))
    data_affcrop = data_aff[:, :, st:fin, st:fin].contiguous()
    return data_affcrop, data_aff, rot_LAFs, inv_rotmat, inv_TA
# Example no. 3
0
def test(test_loader, model, epoch):
    """Evaluate ``model`` on ``test_loader`` and print two mean distances.

    For each ``(data_a, data_p)`` batch the anchor is re-sampled with LAFs
    from ``get_random_rotation_LAFs(data_a, math.pi)``; both the warped anchor
    and the positive are cut to their central ``model.PS`` window before being
    fed to the model.  Two per-batch scalars are collected:

    * the mean of ``sqrt((out_a_rot - inv_rotmat @ out_p) ** 2 + 1e-12)``;
    * the mean L2 distance between the ``descriptor`` outputs of the patches
      re-extracted with the predicted transforms.

    Args:
        test_loader: iterable yielding ``(data_a, data_p)`` batches; it must
            also support ``len()`` and expose ``.dataset`` (used in the
            progress message).
        model: network called as ``model(x, True)``; must expose ``PS``.
        epoch: epoch index, used only in the progress message.

    Returns:
        None; the averages over all batches are printed.
    """
    # switch to evaluate mode
    model.eval()

    geom_per_batch = []
    desc_per_batch = []

    progress = tqdm(enumerate(test_loader))
    for batch_idx, (data_a, data_p) in progress:
        if args.cuda:
            data_a = data_a.float().cuda()
            data_p = data_p.float().cuda()
        # NOTE(review): ``volatile=True`` is only honoured by PyTorch < 0.4;
        # confirm the targeted version, newer releases ignore the flag.
        data_a = Variable(data_a, volatile=True)
        data_p = Variable(data_p, volatile=True)

        rot_LAFs, inv_rotmat = get_random_rotation_LAFs(data_a, math.pi)
        data_a_rot = extract_patches(data_a, rot_LAFs, PS=data_a.size(2))

        # Keep only the central model.PS x model.PS window of both batches.
        lo = int((data_p.size(2) - model.PS) / 2)
        hi = lo + model.PS
        data_p = data_p[:, :, lo:hi, lo:hi].contiguous()
        data_a_rot = data_a_rot[:, :, lo:hi, lo:hi].contiguous()

        out_a_rot = model(data_a_rot, True)
        out_p = model(data_p, True)
        out_p_rotatad = torch.bmm(inv_rotmat, out_p)
        residual = out_a_rot - out_p_rotatad
        geom_dist = torch.sqrt(residual ** 2 + 1e-12).mean()

        # The already-cropped batches are what gets re-sampled here.
        out_patches_a_crop = extract_and_crop_patches_by_predicted_transform(
            data_a_rot, out_a_rot, crop_size=model.PS)
        out_patches_p_crop = extract_and_crop_patches_by_predicted_transform(
            data_p, out_p, crop_size=model.PS)

        desc_a = descriptor(out_patches_a_crop)
        desc_p = descriptor(out_patches_p_crop)
        squared_diff = ((desc_a - desc_p) ** 2).view(data_a.size(0), -1)
        descr_dist = torch.sqrt(squared_diff.sum(dim=1) + 1e-6).mean()

        geom_per_batch.append(geom_dist.data.cpu().numpy().reshape(-1, 1))
        desc_per_batch.append(descr_dist.data.cpu().numpy().reshape(-1, 1))

        if batch_idx % args.log_interval != 0:
            continue
        message = ' Test Epoch: {} [{}/{} ({:.0f}%)]'.format(
            epoch,
            batch_idx * len(data_a),
            len(test_loader.dataset),
            100. * batch_idx / len(test_loader))
        progress.set_description(message)

    geom_all = np.vstack(geom_per_batch).reshape(-1, 1)
    desc_all = np.vstack(desc_per_batch).reshape(-1, 1)

    print('\33[91mTest set: Geom MSE: {:.8f}\n\33[0m'.format(geom_all.mean()))
    print('\33[91mTest set: Desc dist: {:.8f}\n\33[0m'.format(desc_all.mean()))
# Example no. 4
0
def train(train_loader, model, optimizer, epoch):
    """Run one training epoch with random affine warps and save a checkpoint.

    Anchor and positive batches are each warped with their own LAFs from
    ``get_random_norm_affine_LAFs``; the same rotation LAFs (drawn once per
    batch from ``get_random_rotation_LAFs``) are composed into both.  The
    model is run on the central ``model.PS`` crop of each warped batch,
    patches are re-extracted with the predicted transforms, described by the
    module-level ``descriptor`` and optimised with
    ``loss_HardNet(..., anchor_swap=True)``.

    The ``max_tilt`` passed to ``get_random_norm_affine_LAFs`` grows with the
    epoch: 3.0 up to epoch 1, 4.0 for epochs 2-3, 4.5 for epochs 4-5 and 4.8
    afterwards.

    Args:
        train_loader: iterable yielding ``(data_a, data_p)`` batches; it must
            also support ``len()`` and expose ``.dataset`` (used in the
            progress message).
        model: network called as ``model(x, True)``; must expose ``PS``.
        optimizer: optimizer stepped once per batch and handed to
            ``adjust_learning_rate`` afterwards.
        epoch: epoch index; drives the tilt schedule, the progress message
            and the checkpoint name.

    Returns:
        None.  Writes ``'{LOG_DIR}/checkpoint_{epoch}.pth'`` when the loader
        is exhausted.  Prints a message and exits with status 1 if
        ``args.merge`` is neither ``'sum'`` nor ``'mul'``.
    """
    # switch to train mode
    model.train()

    # The tilt schedule depends only on the epoch, so compute it once.
    if epoch > 5:
        max_tilt = 4.8
    elif epoch > 3:
        max_tilt = 4.5
    elif epoch > 1:
        max_tilt = 4.0
    else:
        max_tilt = 3.0

    pbar = tqdm(enumerate(train_loader))
    for batch_idx, data in pbar:
        data_a, data_p = data
        if args.cuda:
            data_a, data_p = data_a.float().cuda(), data_p.float().cuda()
        # Wrap on the CPU path as well; previously only CUDA batches were
        # wrapped.
        data_a, data_p = Variable(data_a), Variable(data_p)

        # Central model.PS x model.PS window fed to the network.
        st = int((data_p.size(2) - model.PS) / 2)
        fin = st + model.PS

        # The inverse rotation is not needed in this variant.
        rot_LAFs_a, _ = get_random_rotation_LAFs(data_a, math.pi)

        aff_LAFs_a, inv_TA_a = get_random_norm_affine_LAFs(data_a, max_tilt)
        aff_LAFs_a[:, 0:2, 0:2] = torch.bmm(rot_LAFs_a[:, 0:2, 0:2],
                                            aff_LAFs_a[:, 0:2, 0:2])
        data_a_aff = extract_patches(data_a, aff_LAFs_a, PS=data_a.size(2))
        data_a_aff_crop = data_a_aff[:, :, st:fin, st:fin].contiguous()

        aff_LAFs_p, inv_TA_p = get_random_norm_affine_LAFs(data_p, max_tilt)
        # The anchor's rotation is reused for the positive batch.
        aff_LAFs_p[:, 0:2, 0:2] = torch.bmm(rot_LAFs_a[:, 0:2, 0:2],
                                            aff_LAFs_p[:, 0:2, 0:2])
        data_p_aff = extract_patches(data_p, aff_LAFs_p, PS=data_p.size(2))
        data_p_aff_crop = data_p_aff[:, :, st:fin, st:fin].contiguous()

        out_a_aff = model(data_a_aff_crop, True)
        out_p_aff = model(data_p_aff_crop, True)
        # Undo the random affine part so both predictions are comparable.
        out_p_aff_back = torch.bmm(inv_TA_p, out_p_aff)
        out_a_aff_back = torch.bmm(inv_TA_a, out_a_aff)

        ######Apply rot and get sifts
        out_patches_a_crop = extract_and_crop_patches_by_predicted_transform(
            data_a_aff, out_a_aff, crop_size=model.PS)
        out_patches_p_crop = extract_and_crop_patches_by_predicted_transform(
            data_p_aff, out_p_aff, crop_size=model.PS)
        desc_a = descriptor(out_patches_a_crop)
        desc_p = descriptor(out_patches_p_crop)

        # descr_dist and geom_dist are per-sample and only reported in the
        # progress message; they do not enter the loss.
        descr_dist = torch.sqrt(
            ((desc_a - desc_p) ** 2).view(data_a.size(0), -1).sum(dim=1) + 1e-6)
        descr_loss = loss_HardNet(desc_a, desc_p, anchor_swap=True)
        geom_dist = torch.sqrt(
            ((out_a_aff_back - out_p_aff_back) ** 2).view(-1, 4).mean(dim=1) + 1e-8)

        # Both accepted merge modes currently optimise the descriptor loss.
        if args.merge not in ('sum', 'mul'):
            print('Unknown merge option')
            # Non-zero status: this is an error path (it used to exit with 0).
            sys.exit(1)
        loss = descr_loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        adjust_learning_rate(optimizer)

        if batch_idx % 2 == 0:
            # ``.data[0]`` fails on 0-dim tensors in newer PyTorch releases;
            # going through numpy handles both 0-dim and 1-element tensors.
            loss_value = float(loss.data.cpu().numpy())
            geom_value = float(geom_dist.mean().data.cpu().numpy())
            descr_value = float(descr_dist.mean().data.cpu().numpy())
            pbar.set_description(
                'Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.4f}, {},{:.4f}'.
                format(epoch, batch_idx * len(data_a),
                       len(train_loader.dataset),
                       100. * batch_idx / len(train_loader), loss_value,
                       geom_value,
                       descr_value))
    torch.save({
        'epoch': epoch + 1,
        'state_dict': model.state_dict()
    }, '{}/checkpoint_{}.pth'.format(LOG_DIR, epoch))