# Example 1
def gpu_task(img_names, db_dir, save_dir):
    """Extract SIFT keypoints and SOSNet descriptors for each image and
    write them to Lowe-style ``.sosnet.sift`` text files.

    Args:
        img_names: iterable of image file names (expected ``*.jpg``).
        db_dir: directory containing the input images.
        save_dir: directory where descriptor files are written.
    """
    sosnet32 = sosnet_model.SOSNet32x32()
    net_name = 'notredame'
    sosnet32.load_state_dict(
        torch.load(os.path.join('sosnet-weights',
                                "sosnet-32x32-" + net_name + ".pth")))
    sosnet32.cuda().eval()

    local_detector = cv2.xfeatures2d.SIFT_create()

    for i, line in enumerate(img_names):
        img_path = os.path.join(db_dir, line)
        # Python 3 print (the original used Python 2 print statements,
        # which are syntax errors in Python 3).
        print(img_path)
        img = cv2.imread(img_path, 1)
        # NOTE: the original computed a half-size resize that was never
        # used; keypoints are detected on the full-resolution image.
        kpt = local_detector.detect(img, None)
        desc = tfeat_utils.describe_opencv(
            sosnet32, cv2.cvtColor(img, cv2.COLOR_BGR2GRAY), kpt,
            patch_size=32, mag_factor=7, use_gpu=True)
        out_path = os.path.join(save_dir, line.split('.jpg')[0] + '.sosnet.sift')
        with open(out_path, 'w') as f:
            # The with-statement closes the file; the original's explicit
            # f.close() calls inside the block were redundant.
            if desc is None:
                # No describable patches: write an empty header.
                f.write(str(128) + '\n')
                f.write(str(0) + '\n')
                print("Null: %s" % line)
                continue
            if len(desc) > 0:
                f.write(str(128) + '\n')
                f.write(str(len(kpt)) + '\n')
                for j in range(len(desc)):
                    # Location fields are unused downstream; zero them out.
                    locs_str = '0 0 0 0 0 '
                    descs_str = " ".join([str(float(value)) for value in desc[j]])
                    f.write(locs_str + descs_str + '\n')
            print("%d(%d), %s, desc: %d" % (i + 1, len(img_names), line, len(desc)))
# Example 2
    def __init__(self, do_cuda=True):
        """Load the pre-trained SOSNet 32x32 descriptor network.

        Args:
            do_cuda: request GPU extraction; silently falls back to CPU
                when CUDA is not available.
        """
        print('Using SosnetFeature2D')
        self.model_base_path = config.cfg.root_folder + '/thirdparty/SOSNet/'

        # Only use CUDA when both requested and actually available.
        self.do_cuda = do_cuda and torch.cuda.is_available()
        print('cuda:', self.do_cuda)
        device = torch.device("cuda:0" if self.do_cuda else "cpu")

        # Inference only: disable autograd bookkeeping globally.
        torch.set_grad_enabled(False)

        # mag_factor is how many times the original keypoint scale
        # is enlarged to generate a patch from a keypoint
        self.mag_factor = 3

        print('==> Loading pre-trained network.')
        # Init SOSNet and load the trained weights.
        self.model = sosnet_model.SOSNet32x32()
        self.net_name = 'liberty'  # liberty, hpatches_a, notredame, yosemite  (see folder /thirdparty/SOSNet/sosnet-weights)
        self.model.load_state_dict(
            torch.load(
                os.path.join(self.model_base_path, 'sosnet-weights',
                             "sosnet-32x32-" + self.net_name + ".pth")))
        if self.do_cuda:
            self.model.cuda()
            print('Extracting on GPU')
        else:
            print('Extracting on CPU')
            # BUG FIX: the original assigned `model.cpu()` from the
            # undefined name `model`, raising NameError on the CPU path.
            self.model = self.model.cpu()
        self.model.eval()
        print('==> Successfully loaded pre-trained network.')
# Example 3
def init():
    """Initialise the SOSNet model and the image-descriptor index.

    Populates the module-level ``img_dict`` either from a precomputed
    pickle file or by extracting BRISK+SOSNet descriptors for every image
    in SOSNET_IMAGE_DIR. Idempotent: later calls are no-ops.
    """
    global is_initialized, img_dict, sosnet32

    if not is_initialized:
        is_initialized = True
        # BUG FIX: `torch.no_grad()` as a bare statement constructs a
        # context manager and throws it away without entering it, so
        # gradients were never actually disabled. Disable them globally.
        torch.set_grad_enabled(False)

        # Init the 32x32 version of SOSNet.
        sosnet32 = sosnet_model.SOSNet32x32()
        net_name = 'notredame'
        sosnet32.load_state_dict(torch.load(
            os.path.join('sosnet-weights',
                         "sosnet-32x32-" + net_name + ".pth")),
                                 strict=False)
        sosnet32.cuda().eval()

        # BRISK keypoint detector (detection threshold 100).
        brisk = cv2.BRISK_create(100)

        # Use precomputed features when present; otherwise build them.
        if os.path.exists(sosnet_constants.SOSNET_FEATURES_PATH):
            # Close the file deterministically (the original leaked the
            # handle returned by open()).
            with open(sosnet_constants.SOSNET_FEATURES_PATH, "rb") as fh:
                img_dict = pickle.load(fh)
        else:
            print(
                "sosnet_search.py :: init :: Constructing features for the images"
            )

            for img in os.listdir(sosnet_constants.SOSNET_IMAGE_DIR):
                try:
                    # Load the image (grayscale) and detect BRISK keypoints.
                    # os.path.join is safe whether or not SOSNET_IMAGE_DIR
                    # ends with a path separator.
                    image_vec = cv2.imread(
                        os.path.join(sosnet_constants.SOSNET_IMAGE_DIR, img), 0)
                    kp = brisk.detect(image_vec, None)

                    # Rectify patches around the openCV keypoints and
                    # describe them with SOSNet.
                    desc_tfeat = tfeat_utils.describe_opencv(sosnet32,
                                                             image_vec,
                                                             kp,
                                                             patch_size=32,
                                                             mag_factor=3)

                    img_dict[img] = desc_tfeat
                except Exception as e:
                    # Best-effort indexing: report and skip unreadable images.
                    print(e)
                    print(
                        "sosnet_search.py :: init :: Error while indexing :: ",
                        img)

            with open(sosnet_constants.SOSNET_FEATURES_PATH, 'wb') as handle:
                pickle.dump(img_dict, handle, protocol=pickle.HIGHEST_PROTOCOL)

            print("sosnet_search.py :: init :: Finished processing ",
                  len(img_dict), "images")
# Example 4
import sys
sys.path.append("../../")
import config
config.cfg.set_lib('sosnet')

import torch
import sosnet_model
import os

# Smoke test: load pre-trained SOSNet weights and describe random patches.
tfeat_base_path = '../../thirdparty/SOSNet/'

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
torch.set_grad_enabled(False)  # inference only

sosnet32 = sosnet_model.SOSNet32x32()
net_name = 'liberty'  # liberty, hpatches_a, notredame, yosemite
sosnet32.load_state_dict(
    torch.load(
        os.path.join(tfeat_base_path, 'sosnet-weights',
                     "sosnet-32x32-" + net_name + ".pth")))
# BUG FIX: the original called .cuda() unconditionally, which crashes on
# CPU-only machines and contradicts the `device` chosen above for the
# input tensor. Move the model to the same device instead.
sosnet32.to(device).eval()

# 100 random 32x32 single-channel patches -> 100 descriptors.
patches = torch.rand(100, 1, 32, 32).to(device)
descrs = sosnet32(patches)

print('done!')
# Example 5
def train():
    """Train SOSNet32x32 with a triplet margin loss; report FPR@95 per epoch.

    Relies on module-level names being defined elsewhere in the file:
    `get_config`, `train_data`, `test_data`, `train_data_loader`,
    `test_data_loader`, `data_criterion`, `model_criterion`.
    Writes best/last checkpoints under config.save_dir and appends the
    per-epoch FPR@95 curve to fpr.txt.
    """
    # ----------------------------------------
    # Parse configuration
    config, unparsed = get_config()
    # If we have unparsed arguments, print usage and exit
    if len(unparsed) > 0:
        print_usage()
        exit(1)
    print_config(config)
    print("Number of train samples: ", len(train_data))
    print("Number of test samples: ", len(test_data))

    # Create log and save directories (race-free, unlike exists()+makedirs).
    os.makedirs(config.log_dir, exist_ok=True)
    os.makedirs(config.save_dir, exist_ok=True)

    # Initialize training
    iter_idx = -1   # make counter start at zero
    best_loss = -1  # sentinel: no best loss recorded yet
    # Prepare checkpoint file and model files to save and load from
    checkpoint_file = os.path.join(config.save_dir, "checkpoint.pth")
    bestmodel_file = os.path.join(config.save_dir, "best_model.pth")
    savemodel_file = os.path.join(config.save_dir, "save_model.pth")

    model = sosnet_model.SOSNet32x32().cuda()
    optimizer = optim.SGD(model.parameters(), lr=0.1, weight_decay=0.0001)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                step_size=10,
                                                gamma=0.8)

    seed = 42
    torch.manual_seed(seed)
    np.random.seed(seed)
    # Create loss objects
    data_loss = data_criterion(config)
    model_loss = model_criterion(config)

    fpr_per_epoch = []
    # Training loop
    for epoch in range(config.num_epoch):
        # BUG FIX: model.eval() at the end of the previous epoch was never
        # undone, so every epoch after the first trained in eval mode.
        model.train()
        prefix = "Training Epoch {:3d}: ".format(epoch)
        for batch_idx, (data_a, data_p,
                        data_n) in tqdm(enumerate(train_data_loader)):
            # Triplet batches: anchor / positive / negative patches.
            data_a = data_a.unsqueeze(1).float().cuda()
            data_p = data_p.unsqueeze(1).float().cuda()
            data_n = data_n.unsqueeze(1).float().cuda()
            out_a, out_p, out_n = model(data_a), model(data_p), model(data_n)
            loss = F.triplet_margin_loss(out_a,
                                         out_p,
                                         out_n,
                                         margin=2,
                                         swap=True)

            # Track the best (lowest) loss as a plain float so the "best"
            # batch's autograd graph and CUDA tensor are not retained.
            loss_value = loss.item()
            if best_loss == -1 or loss_value < best_loss:
                best_loss = loss_value
                # Save best-so-far checkpoint and raw weights.
                torch.save(
                    {
                        "iter_idx": iter_idx,
                        "best_loss": best_loss,
                        "model": model.state_dict(),
                        "optimizer": optimizer.state_dict(),
                    }, bestmodel_file)
                torch.save(model.state_dict(), savemodel_file)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # Save end-of-epoch checkpoint
        torch.save(
            {
                "iter_idx": iter_idx,
                "best_loss": best_loss,
                "model": model.state_dict(),
                "optimizer": optimizer.state_dict(),
            }, checkpoint_file)

        # Evaluate the network after each epoch
        model.eval()
        l = np.empty((0, ))
        d = np.empty((0, ))
        for batch_idx, (data_l, data_r, lbls) in enumerate(test_data_loader):
            data_l = data_l.unsqueeze(1).float().cuda()
            data_r = data_r.unsqueeze(1).float().cuda()
            out_l, out_r = model(data_l), model(data_r)
            dists = torch.norm(out_l - out_r, 2, 1).detach().cpu().numpy()
            l = np.hstack((l, lbls.numpy()))
            d = np.hstack((d, dists))

        # FPR95 code from Yurun Tian
        d = torch.from_numpy(d)
        l = torch.from_numpy(l)
        dist_pos = d[l == 1]
        dist_neg = d[l != 1]
        dist_pos, indice = torch.sort(dist_pos)
        # BUG FIX: ceil(0.95 * n) equals n whenever n < 20, which indexed
        # one past the end of dist_pos; clamp to the last valid index.
        loc_thr = min(int(np.ceil(dist_pos.numel() * 0.95)),
                      dist_pos.numel() - 1)
        thr = dist_pos[loc_thr]
        fpr95 = float(dist_neg.le(thr).sum()) / dist_neg.numel()
        print(epoch, fpr95)
        fpr_per_epoch.append([epoch, fpr95])
        scheduler.step()
        # Rewrite the whole curve each epoch so a crash still leaves data.
        np.savetxt('fpr.txt', np.array(fpr_per_epoch), delimiter=',')