def measure_performance(labels_dir,
                        img_dir,
                        img_fts_dir,
                        weights_file,
                        subset="inference"):
    """
    Given a weights file, calculate the mean average precision over all the queries for the corresponding dataset

    Args:
        labels_dir  : Directory for ground truth labels
        img_dir     : Directory holding the images
        img_fts_dir : Directory holding the pca reduced features generated through create_db.py script
        weights_file: path of trained weights file
        subset      : train/ valid/ inference
    
    Returns:
        Mean Average Precision over all queries corresponding to the dataset
    """
    # Create Query extractor object
    QUERY_EXTRACTOR = QueryExtractor(labels_dir, img_dir, subset=subset)

    # Get the list of query image names
    query_images = QUERY_EXTRACTOR.get_query_names()

    # Create paths
    query_image_paths = [os.path.join(img_dir, file) for file in query_images]

    aps = []
    # Evaluate average precision for every query
    for query_path in query_image_paths:
        ap = inference_on_single_labelled_image_pca(query_img_file=query_path,
                                                    labels_dir=labels_dir,
                                                    img_dir=img_dir,
                                                    img_fts_dir=img_fts_dir,
                                                    weights_file=weights_file,
                                                    plot=False)
        aps.append(ap)

    return np.array(aps).mean()
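
A minimal usage sketch, assuming the oxbuild directory layout used in the training example further below; all paths and the weights filename are placeholders:

# Hypothetical paths; adjust to your own layout.
mean_ap = measure_performance(labels_dir="./data/oxbuild/gt_files",
                              img_dir="./data/oxbuild/images",
                              img_fts_dir="./fts_pca/oxbuild",
                              weights_file="./weights/oxbuild-exp-3.pth",
                              subset="inference")
print("mAP = {:.4f}".format(mean_ap))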
Example #2
def train_n_eval(feat_type='SIFT'):
    q_valid = QueryExtractor(dataset=hp.mode,
                             image_dir=hp.image_dir,
                             label_dir=hp.label_dir,
                             subset='valid')
    data_valid = list(q_valid.get_queries().keys())
    data_train = [
        fname for fname in os.listdir(hp.image_dir) if fname not in data_valid
    ]
    if feat_type == 'SIFT':
        # Train and create the database if it does not already exist
        if len(os.listdir('./database/BoW/SIFT/')) == 0:
            create_db_SIFT(data_train)
        # Evaluate on the validation queries
        eval(feat_type, q_valid)

    elif feat_type == 'SURF':
        # Train and create the database if it does not already exist
        if len(os.listdir('./database/BoW/SURF/')) == 0:
            create_db_SURF(data_train)
        # Evaluate on the validation queries
        eval(feat_type, q_valid)

    else:
        pass
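
A minimal usage sketch; the Bag-of-Words database directories must exist before train_n_eval checks whether they are empty (the directory names come from the code above):

import os

# Create the (possibly empty) database directories, then train and evaluate.
os.makedirs('./database/BoW/SIFT/', exist_ok=True)
os.makedirs('./database/BoW/SURF/', exist_ok=True)
train_n_eval(feat_type='SIFT')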
Example #3
def getQueryNames(labels_dir="./static/data/oxbuild/gt_files/",
                  img_dir="./static/data/oxbuild/images/"):
    """
    Function that returns a list of images that are part of validation set

    Args:
        labels_dir  : Directory for ground truth labels
        img_dir     : Directory holding the images

    Returns:
        List of file paths for images that are part of validation set
    """
    QUERY_EXTRACTOR = QueryExtractor(labels_dir, img_dir, subset="inference")
    query_names = QUERY_EXTRACTOR.get_query_names()
    # Strip the leading "." so the paths are usable as web-root-relative URLs
    for i in range(len(query_names)):
        query_names[i] = img_dir[1:] + query_names[i]
    return query_names
Example #4
def inference_on_single_labelled_image_pca(
    query_img_file,
    labels_dir,
    img_dir,
    img_fts_dir,
    weights_file,
    top_k=1000,
    plot=True,
):
    """
    Function that returns the average precision for a given query image and also plots the top 20 results

    Args:
        query_img_file  : path of query image file
        labels_dir  : Directory for ground truth labels
        img_dir     : Directory holding the images
        img_fts_dir : Directory holding the pca reduced features generated through create_db.py script
        weights_file: path of trained weights file
        top_k       : top_k values used to calculate the average precison
        plot        : if True, top 20 results are plotted

    Returns:
        Average precision for the query image file
    """
    # Create cuda parameters
    use_cuda = torch.cuda.is_available()
    np.random.seed(2019)
    torch.manual_seed(2019)
    device = torch.device("cuda" if use_cuda else "cpu")
    print("Available device = ", device)

    # Create embedding network
    resnet_model = create_embedding_net()
    model = TripletNet(resnet_model)
    model.load_state_dict(torch.load(weights_file, map_location=device))
    model.to(device)
    model.eval()

    # Get query name
    query_img_name = query_img_file.split("/")[-1]
    query_img_path = os.path.join(img_dir, query_img_name)

    # Create Query extractor object
    QUERY_EXTRACTOR = QueryExtractor(labels_dir, img_dir, subset="inference")

    # Create query ground truth dictionary
    query_gt_dict = QUERY_EXTRACTOR.get_query_map()[query_img_name]

    # Create the image database (feature files and the corresponding images)
    QUERY_IMAGES_FTS = [
        os.path.join(img_fts_dir, file)
        for file in sorted(os.listdir(img_fts_dir))
    ]
    QUERY_IMAGES = [
        os.path.join(img_dir, file) for file in sorted(os.listdir(img_dir))
    ]

    # Compute the query embedding and reduce it with PCA
    query_fts = get_query_embedding(model, device,
                                    query_img_file).detach().cpu().numpy()
    query_fts = perform_pca_on_single_vector(query_fts)

    # Create similarity list
    similarity = []
    for file in tqdm(QUERY_IMAGES_FTS):
        file_fts = np.squeeze(np.load(file))
        cos_sim = np.dot(query_fts, file_fts) / (np.linalg.norm(query_fts) *
                                                 np.linalg.norm(file_fts))
        similarity.append(cos_sim)

    # Get best matches using similarity
    similarity = np.asarray(similarity)
    indexes = (-similarity).argsort()[:top_k]
    best_matches = [QUERY_IMAGES[index] for index in indexes]

    # Get preds
    if plot:
        preds = get_preds_and_visualize(best_matches, query_gt_dict, img_dir,
                                        20)
    else:
        preds = get_preds(best_matches, query_gt_dict)

    # Get average precision
    ap = ap_per_query(best_matches, query_gt_dict)

    return ap
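
The per-file similarity loop above is simple but slow for large databases. Below is a vectorized alternative, a sketch only: rank_by_cosine_similarity is a hypothetical helper, and it assumes all feature files fit in memory and share one dimensionality.

import numpy as np

def rank_by_cosine_similarity(query_fts, feature_files, top_k=1000):
    # Stack all database features into one (N, D) matrix.
    db = np.stack([np.squeeze(np.load(f)) for f in feature_files])
    # Normalize the rows and the query; a single matrix-vector product
    # then yields every cosine similarity at once.
    db = db / np.linalg.norm(db, axis=1, keepdims=True)
    query = query_fts / np.linalg.norm(query_fts)
    similarity = db @ query
    # Indices of the top_k most similar database entries.
    return (-similarity).argsort()[:top_k]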
Example #5
def inference_on_single_labelled_image_pca_web(
    model,
    query_img_file,
    labels_dir="./static/data/oxbuild/gt_files/",
    img_dir="./static/data/oxbuild/images/",
    img_fts_dir="./static/fts_pca/oxbuild/",
    top_k=60,
    plot=False,
):
    """
    Function similar to inference_on_single_labelled_image_pca, but modified return values for usage during deployment

    Args:
        model       : model used (either paris or oxford)
        query_img_file  : path of query image file
        labels_dir  : Directory for ground truth labels
        img_dir     : Directory holding the images
        img_fts_dir : Directory holding the pca reduced features generated through create_db.py script
        top_k       : top_k values used to calculate the average precison, default is 60 for web deployment
        plot        : if True, top 20 results are plotted

    Returns:
        List of top k similar images; list of ground truth labels for top k images
    """
    # Create cuda parameters
    use_cuda = torch.cuda.is_available()
    np.random.seed(2019)
    torch.manual_seed(2019)
    device = torch.device("cuda" if use_cuda else "cpu")
    print("Available device = ", device)

    # Get query name
    query_img_name = query_img_file.split("/")[-1]
    query_img_path = os.path.join(img_dir, query_img_name)

    # Create Query extractor object
    QUERY_EXTRACTOR = QueryExtractor(labels_dir, img_dir, subset="inference")

    # Create the image database (feature files and the corresponding images)
    QUERY_IMAGES_FTS = [
        os.path.join(img_fts_dir, file)
        for file in sorted(os.listdir(img_fts_dir))
    ]
    QUERY_IMAGES = [
        os.path.join(img_dir, file) for file in sorted(os.listdir(img_dir))
    ]

    # Compute the query embedding; the web front end passes a root-relative
    # path, hence the leading "."
    query_fts = get_query_embedding(model, device, "." +
                                    query_img_file).detach().cpu().numpy()
    query_fts = perform_pca_on_single_vector(query_fts)

    # Create similarity list
    similarity = []
    for file in tqdm(QUERY_IMAGES_FTS):
        file_fts = np.squeeze(np.load(file))
        cos_sim = np.dot(query_fts, file_fts) / (np.linalg.norm(query_fts) *
                                                 np.linalg.norm(file_fts))
        similarity.append(cos_sim)

    # Get best matches using similarity
    similarity = np.asarray(similarity)
    indexes = (-similarity).argsort()[:top_k]
    best_matches = [QUERY_IMAGES[index] for index in indexes]
    print(best_matches)

    # Create query ground truth dictionary (default: no labels available)
    gt_map = [0] * top_k
    try:
        query_gt_dict = QUERY_EXTRACTOR.get_query_map()[query_img_name]
        gt_map = get_gt_web(best_matches, query_gt_dict)
    except KeyError:
        # The uploaded query may not be part of the labelled query set
        pass
    print(gt_map)
    # Strip the leading "." so the paths can be served as web URLs
    for i in range(len(best_matches)):
        best_matches[i] = best_matches[i][1:]
    return best_matches, gt_map
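
ap_per_query itself is not shown in these examples. As a reference, here is a minimal sketch of the standard average-precision computation over a ranked list of binary relevance flags; this is an assumption about its behaviour, not the original implementation:

def average_precision(relevance):
    # relevance: list of 0/1 flags, one per ranked result.
    hits, precision_sum = 0, 0.0
    for rank, rel in enumerate(relevance, start=1):
        if rel:
            hits += 1
            precision_sum += hits / rank
    return precision_sum / hits if hits else 0.0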
Example #6
def main(data_dir,
         results_dir,
         weights_dir,
         which_dataset,
         image_resize,
         image_crop_size,
         exp_num,
         max_epochs,
         batch_size,
         samples_update_size,
         num_workers=4,
         lr=5e-6,
         weight_decay=1e-5):
    """
    This is the main function. You need to interface only with this function to train. (It will record all the results)
    Once you have trained use create_db.py to create the embeddings and then use the inference_on_single_image.py to test
    
    Arguments:
        data_dir    : parent directory for data
        results_dir : directory to store the results (Make sure you create this directory first)
        weights_dir : directory to store the weights (Make sure you create this directory first)
        which_dataset : "oxford" or "paris" 
        image_resize : resize to this size
        image_crop_size : square crop size
        exp_num     : experiment number to record the log and results
        max_epochs  : maximum epochs to run
        batch_size  : batch size (I used 5)
        samples_update_size : Number of samples the network should see before it performs one parameter update (I used 64)
    
    Keyword Arguments:
        num_workers : default 4
        lr      : Initial learning rate (default 5e-6)
        weight_decay: default 1e-5

    Eg run:
        if __name__ == '__main__':
            main(data_dir="./data/", results_dir="./results", weights_dir="./weights",
            which_dataset="oxbuild", image_resize=460, image_crop_size=448,
            exp_num=3, max_epochs=10, batch_size=5, samples_update_size=64)
    """
    # Define directories
    labels_dir = os.path.join(data_dir, which_dataset, "gt_files")
    image_dir = os.path.join(data_dir, which_dataset, "images")

    # Create Query extractor object
    q_train = QueryExtractor(labels_dir, image_dir, subset="train")
    q_valid = QueryExtractor(labels_dir, image_dir, subset="valid")

    # Create transforms
    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]
    transforms_train = transforms.Compose([
        transforms.Resize(image_resize),
        transforms.RandomResizedCrop(image_crop_size, scale=(0.8, 1.2)),
        transforms.ColorJitter(brightness=(0.80, 1.20)),
        transforms.RandomHorizontalFlip(p=0.50),
        transforms.RandomChoice([
            transforms.RandomRotation(15),
            transforms.Grayscale(num_output_channels=3),
        ]),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std),
    ])

    transforms_valid = transforms.Compose([
        transforms.Resize(image_resize),
        transforms.CenterCrop(image_crop_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std),
    ])

    # Create dataset
    dataset_train = VggImageRetrievalDataset(labels_dir,
                                             image_dir,
                                             q_train,
                                             transforms=transforms_train)
    dataset_valid = VggImageRetrievalDataset(labels_dir,
                                             image_dir,
                                             q_valid,
                                             transforms=transforms_valid)

    # Create dataloader
    train_loader = DataLoader(dataset_train,
                              batch_size=batch_size,
                              num_workers=num_workers,
                              shuffle=True)
    valid_loader = DataLoader(dataset_valid,
                              batch_size=batch_size,
                              num_workers=num_workers,
                              shuffle=False)

    # Create cuda parameters
    use_cuda = torch.cuda.is_available()
    np.random.seed(2020)
    torch.manual_seed(2020)
    device = torch.device("cuda" if use_cuda else "cpu")

    # Create embedding network
    embedding_model = create_embedding_net()
    model = TripletNet(embedding_model)
    model.to(device)

    # Create optimizer and scheduler
    optimizer = optim.Adam(model.parameters(),
                           lr=lr,
                           weight_decay=weight_decay)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=10)

    # Create log file
    log_file = open(os.path.join(results_dir, "log-{}.txt".format(exp_num)),
                    "w+")
    log_file.write("----------Experiment {}----------\n".format(exp_num))
    log_file.write("Dataset = {}, Image sizes = {}, {}\n".format(
        which_dataset, image_resize, image_crop_size))

    # Compute how many batches to accumulate before each parameter update
    update_batch = int(math.ceil(float(samples_update_size) / batch_size))
    model_name = "{}-exp-{}.pth".format(which_dataset, exp_num)
    loss_plot_save_path = os.path.join(
        results_dir, "{}-loss-exp-{}.png".format(which_dataset, exp_num))

    # Print stats before starting training
    print("Running VGG Image Retrieval Training script")
    print("Dataset used\t\t:{}".format(which_dataset))
    print("Max epochs\t\t: {}".format(max_epochs))
    print("Gradient update\t\t: every {} batches ({} samples)".format(
        update_batch, samples_update_size))
    print("Initial Learning rate\t: {}".format(lr))
    print("Image resize, crop size\t: {}, {}".format(image_resize,
                                                     image_crop_size))
    print("Available device \t:", device)

    # Train the triplet network
    tr_hist, val_hist = train_model(model,
                                    device,
                                    optimizer,
                                    scheduler,
                                    train_loader,
                                    valid_loader,
                                    epochs=max_epochs,
                                    update_batch=update_batch,
                                    model_name=model_name,
                                    save_dir=weights_dir,
                                    log_file=log_file)

    # Close the file
    log_file.close()

    # Plot and save
    plot_history(tr_hist,
                 val_hist,
                 "Triplet Loss",
                 loss_plot_save_path,
                 labels=["train", "validation"])


# if __name__ == '__main__':
#     main(data_dir="./data/", results_dir="./results", weights_dir="./weights",
#         which_dataset="oxbuild", image_resize=460, image_crop_size=448,
#         exp_num=3, max_epochs=10, batch_size=5, samples_update_size=64)
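
train_model is not shown here, but update_batch implies gradient accumulation: the loss is backpropagated on every batch while the optimizer steps only once per update_batch batches. A minimal sketch of that pattern, hypothetical since TripletNet's forward signature and the loss used by train_model are assumptions:

import torch.nn as nn

def train_one_epoch(model, device, optimizer, loader, update_batch):
    criterion = nn.TripletMarginLoss(margin=1.0)
    model.train()
    optimizer.zero_grad()
    for batch_idx, (anchor, positive, negative) in enumerate(loader):
        emb_a, emb_p, emb_n = model(anchor.to(device), positive.to(device),
                                    negative.to(device))
        loss = criterion(emb_a, emb_p, emb_n)
        # Scale so the accumulated gradient matches one large-batch update.
        (loss / update_batch).backward()
        if (batch_idx + 1) % update_batch == 0:
            optimizer.step()
            optimizer.zero_grad()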
Example #7
    if os.path.exists(hp.logdir + 'model.pkl'):
        if use_gpu:
            map_location = lambda storage, loc: storage.cuda()
        else:
            map_location = 'cpu'
        ckpt = torch.load(hp.logdir + 'model.pkl', map_location=map_location)
        model.load_state_dict(ckpt['state_dict'])
        print('Restored model from checkpoint')

    model.eval()
    cs_func = torch.nn.CosineSimilarity(dim=1, eps=1e-6)
    db_embeddings, maps = create_DB(model)
    mAP, time_running = [], []

    q_valid = QueryExtractor(dataset=hp.mode,
                             image_dir=hp.image_dir,
                             label_dir=hp.label_dir,
                             subset='valid')

    for q_name, attribute in q_valid.get_queries().items():
        start = time.time()
        bbox, class_idx = attribute[0], attribute[1]
        # create image tensor
        query_img = image_preprocessing(hp.image_dir + q_name)
        query_tensor = torch.FloatTensor(
            np.transpose(np.expand_dims(query_img, axis=0), axes=[0, 3, 1, 2]))
        # get embedding vector
        if use_gpu:
            query_tensor = query_tensor.cuda()
            cs_func = cs_func.cuda()
        query_embedding = inference(model, query_tensor)
        similarity = cs_func(query_embedding, db_embeddings).topk(