def getModel(weights_file="./static/weights/oxbuild_final.pth"):
    """
    Function that returns the model (saved during the deploy stage to reduce load time)

    Args:
        weights_file : path of trained weights file

    Returns:
        Model built from weights_file
    """
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")

    # Create the embedding network and load the trained weights
    # (map_location lets GPU-trained weights load on a CPU-only host)
    resnet_model = create_embedding_net()
    model = TripletNet(resnet_model)
    model.load_state_dict(torch.load(weights_file, map_location=device))
    model.to(device)
    model.eval()

    return model

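# Example usage (sketch): load the model once at deploy/start-up time and reuse it
# for every query instead of re-reading the weights file per request. The weights
# path below is just the function's default and is assumed to exist.
#
# MODEL = getModel("./static/weights/oxbuild_final.pth")
# # ...pass MODEL to whatever request handlers / inference helpers need it
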
def inference_on_single_labelled_image_pca(
    query_img_file,
    labels_dir,
    img_dir,
    img_fts_dir,
    weights_file,
    top_k=1000,
    plot=True,
):
    """
    Function that returns the average precision for a given query image and
    optionally plots the top 20 results

    Args:
        query_img_file : path of query image file
        labels_dir     : directory for ground truth labels
        img_dir        : directory holding the images
        img_fts_dir    : directory holding the PCA-reduced features generated
                         through the create_db.py script
        weights_file   : path of trained weights file
        top_k          : number of top results used to calculate the average precision
        plot           : if True, the top 20 results are plotted

    Returns:
        Average precision for the query image file
    """
    # Create cuda parameters
    use_cuda = torch.cuda.is_available()
    np.random.seed(2019)
    torch.manual_seed(2019)
    device = torch.device("cuda" if use_cuda else "cpu")
    print("Available device = ", device)

    # Create embedding network
    resnet_model = create_embedding_net()
    model = TripletNet(resnet_model)
    model.load_state_dict(torch.load(weights_file, map_location=device))
    model.to(device)
    model.eval()

    # Get query name
    query_img_name = query_img_file.split("/")[-1]
    query_img_path = os.path.join(img_dir, query_img_name)

    # Create Query extractor object
    QUERY_EXTRACTOR = QueryExtractor(labels_dir, img_dir, subset="inference")

    # Create query ground truth dictionary
    query_gt_dict = QUERY_EXTRACTOR.get_query_map()[query_img_name]

    # Create image database: feature files and the corresponding image paths
    # (image paths are joined against img_dir so that best_matches point at real images)
    QUERY_IMAGES_FTS = [
        os.path.join(img_fts_dir, file) for file in sorted(os.listdir(img_fts_dir))
    ]
    QUERY_IMAGES = [
        os.path.join(img_dir, file) for file in sorted(os.listdir(img_dir))
    ]

    # Query features
    query_fts = (
        get_query_embedding(model, device, query_img_file).detach().cpu().numpy()
    )
    query_fts = perform_pca_on_single_vector(query_fts)

    # Create similarity list (cosine similarity between query and database features)
    similarity = []
    for file in tqdm(QUERY_IMAGES_FTS):
        file_fts = np.squeeze(np.load(file))
        cos_sim = np.dot(query_fts, file_fts) / (
            np.linalg.norm(query_fts) * np.linalg.norm(file_fts)
        )
        similarity.append(cos_sim)

    # Get best matches using similarity
    similarity = np.asarray(similarity)
    indexes = (-similarity).argsort()[:top_k]
    best_matches = [QUERY_IMAGES[index] for index in indexes]

    # Get predictions (and optionally visualize the top 20 results)
    if plot:
        preds = get_preds_and_visualize(best_matches, query_gt_dict, img_dir, 20)
    else:
        preds = get_preds(best_matches, query_gt_dict)

    # Get average precision
    ap = ap_per_query(best_matches, query_gt_dict)

    return ap

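# Example usage (sketch): compute mean average precision over a handful of labelled
# query images. The query file names and directory paths below are illustrative
# assumptions and should be replaced with real paths from the dataset.
#
# query_files = [
#     "./data/oxbuild/images/all_souls_000013.jpg",
#     "./data/oxbuild/images/ashmolean_000058.jpg",
# ]
# aps = [
#     inference_on_single_labelled_image_pca(
#         query_img_file=f,
#         labels_dir="./data/oxbuild/gt_files/",
#         img_dir="./data/oxbuild/images/",
#         img_fts_dir="./fts_pca/oxbuild/",
#         weights_file="./weights/oxbuild-exp-3.pth",
#         plot=False,
#     )
#     for f in query_files
# ]
# print("mAP =", sum(aps) / len(aps))
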
def create_embeddings_db_pca(model_weights_path, img_dir, fts_dir):
    """
    Given a model weights path, this function creates a triplet network, loads the
    parameters, generates the dimension-reduced (using PCA) vectors and saves them
    in the provided feature directory.

    Args:
        model_weights_path : path of trained weights
        img_dir            : directory that holds the images
        fts_dir            : directory to store the embeddings

    Returns:
        None

    Eg run:
        create_embeddings_db_pca("./weights/oxbuild-exp-3.pth", img_dir="./data/oxbuild/images/", fts_dir="./fts_pca/oxbuild/")
    """
    # Create cuda parameters
    use_cuda = torch.cuda.is_available()
    np.random.seed(2019)
    torch.manual_seed(2019)
    device = torch.device("cuda" if use_cuda else "cpu")
    print("Available device = ", device)

    # Create transforms (FiveCrop produces 5 crops per image; each crop is
    # converted to a tensor and normalized with ImageNet statistics)
    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]
    transforms_test = transforms.Compose(
        [
            transforms.Resize(460),
            transforms.FiveCrop(448),
            transforms.Lambda(
                lambda crops: torch.stack([transforms.ToTensor()(crop) for crop in crops])
            ),
            transforms.Lambda(
                lambda crops: torch.stack(
                    [transforms.Normalize(mean=mean, std=std)(crop) for crop in crops]
                )
            ),
        ]
    )

    # Create image database (remove blacklisted Paris images if needed)
    if "paris" in img_dir:
        print("> Blacklisted images must be removed")
        blacklist = [
            "paris_louvre_000136.jpg",
            "paris_louvre_000146.jpg",
            "paris_moulinrouge_000422.jpg",
            "paris_museedorsay_001059.jpg",
            "paris_notredame_000188.jpg",
            "paris_pantheon_000284.jpg",
            "paris_pantheon_000960.jpg",
            "paris_pantheon_000974.jpg",
            "paris_pompidou_000195.jpg",
            "paris_pompidou_000196.jpg",
            "paris_pompidou_000201.jpg",
            "paris_pompidou_000467.jpg",
            "paris_pompidou_000640.jpg",
            "paris_sacrecoeur_000299.jpg",
            "paris_sacrecoeur_000330.jpg",
            "paris_sacrecoeur_000353.jpg",
            "paris_triomphe_000662.jpg",
            "paris_triomphe_000833.jpg",
            "paris_triomphe_000863.jpg",
            "paris_triomphe_000867.jpg",
        ]

        files = os.listdir(img_dir)
        for blacklisted_file in blacklist:
            files.remove(blacklisted_file)

        QUERY_IMAGES = [os.path.join(img_dir, file) for file in sorted(files)]
    else:
        QUERY_IMAGES = [
            os.path.join(img_dir, file) for file in sorted(os.listdir(img_dir))
        ]

    # Create dataset and loader
    eval_dataset = EmbeddingDataset(img_dir, QUERY_IMAGES, transforms=transforms_test)
    eval_loader = DataLoader(eval_dataset, batch_size=1, num_workers=0, shuffle=False)

    # Create embedding network
    resnet_model = create_embedding_net()
    model = TripletNet(resnet_model)
    model.load_state_dict(torch.load(model_weights_path, map_location=device))
    model.to(device)
    model.eval()

    # Create features
    with torch.no_grad():
        for idx, image in enumerate(tqdm(eval_loader)):
            # Move image to device and get crops
            image = image.to(device)
            bs, ncrops, c, h, w = image.size()

            # Get output: embed each crop, then average over the crops
            output = model.get_embedding(image.view(-1, c, h, w))
            output = output.view(bs, ncrops, -1).mean(1).cpu().numpy()

            # Perform PCA
            output = perform_pca_on_single_vector(output)

            # Save features under the image name (without the .jpg extension)
            img_name = (QUERY_IMAGES[idx].split("/")[-1]).replace(".jpg", "")
            save_path = os.path.join(fts_dir, img_name)
            np.save(save_path, output.flatten())

            del output, image
            gc.collect()


# if __name__ == '__main__':
#     create_embeddings_db_pca("./weights/oxbuild-exp-3.pth", img_dir="./data/oxbuild/images/", fts_dir="./fts_pca/oxbuild/")
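
# Sanity check (sketch): after running create_embeddings_db_pca, each image should
# have a saved .npy feature vector in fts_dir. The file name below is a hypothetical
# example; the vector length depends on perform_pca_on_single_vector.
#
# import numpy as np
# fts = np.load("./fts_pca/oxbuild/some_image_name.npy")
# print(fts.shape, np.linalg.norm(fts))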