def getModel(weights_file="./static/weights/oxbuild_final.pth"):
    """
    Build the triplet network and load trained weights into it.

    Meant to be called once (e.g. at deploy time) so the weights are not
    reloaded on every use.

    Args:
        weights_file: path of trained weights file

    Returns:
        TripletNet with the weights from weights_file loaded, switched to
        eval mode and moved to CUDA when available, otherwise CPU
    """
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")

    resnet_model = create_embedding_net()
    model = TripletNet(resnet_model)

    # map_location remaps the stored tensors onto the selected device, so a
    # checkpoint saved on a GPU still loads on a CPU-only host.
    state_dict = torch.load(weights_file, map_location=device)
    model.load_state_dict(state_dict)

    model.to(device)
    model.eval()
    return model
def inference_on_single_labelled_image_pca(
    query_img_file,
    labels_dir,
    img_dir,
    img_fts_dir,
    weights_file,
    top_k=1000,
    plot=True,
):
    """
    Compute the average precision for a labelled query image and optionally
    plot the top 20 retrieved results.

    Args:
        query_img_file : path of query image file
        labels_dir : Directory for ground truth labels
        img_dir : Directory holding the images
        img_fts_dir : Directory holding the pca reduced features generated
            through create_db.py script
        weights_file : path of trained weights file
        top_k : number of best matches used to calculate the average precision
        plot : if True, top 20 results are plotted

    Returns:
        Average precision for the query image file

    Note:
        Feature files and image files are paired purely by their position in
        the sorted directory listings, so img_fts_dir and img_dir must
        contain matching entries.
    """
    # Create cuda parameters
    use_cuda = torch.cuda.is_available()
    np.random.seed(2019)
    torch.manual_seed(2019)
    device = torch.device("cuda" if use_cuda else "cpu")
    print("Available device = ", device)

    # Create embedding network; map_location lets GPU-saved weights load on
    # a CPU-only machine.
    resnet_model = create_embedding_net()
    model = TripletNet(resnet_model)
    model.load_state_dict(torch.load(weights_file, map_location=device))
    model.to(device)
    model.eval()

    # Get query name (the ground-truth map is keyed by file name)
    query_img_name = query_img_file.split("/")[-1]

    # Create Query extractor object
    query_extractor = QueryExtractor(labels_dir, img_dir, subset="inference")

    # Create query ground truth dictionary
    query_gt_dict = query_extractor.get_query_map()[query_img_name]

    # Create image database: stored feature vectors and the images they
    # describe, both in sorted order so index i refers to the same item.
    query_images_fts = [
        os.path.join(img_fts_dir, file) for file in sorted(os.listdir(img_fts_dir))
    ]
    # Image paths belong under img_dir (they were previously joined onto the
    # features directory, yielding paths that do not exist).
    query_images = [
        os.path.join(img_dir, file) for file in sorted(os.listdir(img_dir))
    ]

    # Query fts
    query_fts = get_query_embedding(model, device, query_img_file).detach().cpu().numpy()
    query_fts = perform_pca_on_single_vector(query_fts)
    # Squeeze like the stored features below so each similarity is a scalar.
    # NOTE(review): shape returned by perform_pca_on_single_vector not visible
    # here; squeeze is a no-op if it is already 1-D.
    query_fts = np.squeeze(query_fts)
    query_norm = np.linalg.norm(query_fts)  # loop-invariant, computed once

    # Create similarity list (cosine similarity against every stored vector)
    similarity = []
    for file in tqdm(query_images_fts):
        file_fts = np.squeeze(np.load(file))
        cos_sim = np.dot(query_fts, file_fts) / (query_norm * np.linalg.norm(file_fts))
        similarity.append(cos_sim)

    # Get best matches using similarity (descending order, top_k entries)
    similarity = np.asarray(similarity)
    indexes = (-similarity).argsort()[:top_k]
    best_matches = [query_images[index] for index in indexes]

    # Get preds (return value is not needed for the average precision;
    # the calls are kept for the behaviour they had in the original flow)
    if plot:
        get_preds_and_visualize(best_matches, query_gt_dict, img_dir, 20)
    else:
        get_preds(best_matches, query_gt_dict)

    # Get average precision
    ap = ap_per_query(best_matches, query_gt_dict)

    return ap
def create_embeddings_db_pca(model_weights_path, img_dir, fts_dir):
    """
    Given a model weights path, this function creates a triplet network,
    loads the parameters, generates the dimension reduced (using pca)
    vector for every image and saves it in the provided feature directory.

    Each image is resized, five-cropped, embedded crop by crop, and the crop
    embeddings are averaged before the pca step. One ``.npy`` file is
    written per image, named after the image file with ".jpg" removed.

    Args:
        model_weights_path : path of trained weights
        img_dir : directory that holds the images
        fts_dir : directory to store the embeddings (created if missing)

    Returns:
        None

    Eg run:
        create_embeddings_db_pca("./weights/oxbuild-exp-3.pth", img_dir="./data/oxbuild/images/", fts_dir="./fts_pca/oxbuild/")
    """
    # Create cuda parameters
    use_cuda = torch.cuda.is_available()
    np.random.seed(2019)
    torch.manual_seed(2019)
    device = torch.device("cuda" if use_cuda else "cpu")
    print("Available device = ", device)

    # Create transforms (per-crop transforms are built once and reused)
    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]
    to_tensor = transforms.ToTensor()
    normalize = transforms.Normalize(mean=mean, std=std)
    transforms_test = transforms.Compose([
        transforms.Resize(460),
        transforms.FiveCrop(448),
        transforms.Lambda(lambda crops: torch.stack(
            [to_tensor(crop) for crop in crops])),
        transforms.Lambda(lambda crops: torch.stack(
            [normalize(crop) for crop in crops])),
    ])

    # Create image database
    if "paris" in img_dir:
        print("> Blacklisted images must be removed")
        blacklist = {
            "paris_louvre_000136.jpg",
            "paris_louvre_000146.jpg",
            "paris_moulinrouge_000422.jpg",
            "paris_museedorsay_001059.jpg",
            "paris_notredame_000188.jpg",
            "paris_pantheon_000284.jpg",
            "paris_pantheon_000960.jpg",
            "paris_pantheon_000974.jpg",
            "paris_pompidou_000195.jpg",
            "paris_pompidou_000196.jpg",
            "paris_pompidou_000201.jpg",
            "paris_pompidou_000467.jpg",
            "paris_pompidou_000640.jpg",
            "paris_sacrecoeur_000299.jpg",
            "paris_sacrecoeur_000330.jpg",
            "paris_sacrecoeur_000353.jpg",
            "paris_triomphe_000662.jpg",
            "paris_triomphe_000833.jpg",
            "paris_triomphe_000863.jpg",
            "paris_triomphe_000867.jpg",
        }
        # Filter instead of list.remove(): remove() raises ValueError when a
        # blacklisted file has already been deleted from img_dir.
        files = [file for file in os.listdir(img_dir) if file not in blacklist]
    else:
        files = os.listdir(img_dir)

    QUERY_IMAGES = [os.path.join(img_dir, file) for file in sorted(files)]

    # Create dataset; batch_size=1 and shuffle=False keep the loader index
    # aligned with QUERY_IMAGES, which the save step below relies on.
    eval_dataset = EmbeddingDataset(img_dir, QUERY_IMAGES, transforms=transforms_test)
    eval_loader = DataLoader(eval_dataset, batch_size=1, num_workers=0, shuffle=False)

    # Create embedding network; map_location lets GPU-saved weights load on
    # a CPU-only machine.
    resnet_model = create_embedding_net()
    model = TripletNet(resnet_model)
    model.load_state_dict(torch.load(model_weights_path, map_location=device))
    model.to(device)
    model.eval()

    # Make sure the output directory exists before the first np.save
    os.makedirs(fts_dir, exist_ok=True)

    # Create features
    with torch.no_grad():
        for idx, image in enumerate(tqdm(eval_loader)):
            # Move image to device and get crops
            image = image.to(device)
            bs, ncrops, c, h, w = image.size()

            # Get output: embed every crop, then average over the crops
            output = model.get_embedding(image.view(-1, c, h, w))
            output = output.view(bs, ncrops, -1).mean(1).cpu().numpy()

            # Perform pca
            output = perform_pca_on_single_vector(output)

            # Save fts (np.save appends the ".npy" extension)
            img_name = (QUERY_IMAGES[idx].split("/")[-1]).replace(".jpg", "")
            save_path = os.path.join(fts_dir, img_name)
            np.save(save_path, output.flatten())

            # Release per-image buffers before the next iteration
            del output, image
            gc.collect()


# if __name__ == '__main__':
#     create_embeddings_db_pca("./weights/oxbuild-exp-3.pth", img_dir="./data/oxbuild/images/", fts_dir="./fts_pca/oxbuild/")
def main(data_dir, results_dir, weights_dir,
         which_dataset, image_resize, image_crop_size,
         exp_num,
         max_epochs, batch_size, samples_update_size,
         num_workers=4, lr=5e-6, weight_decay=1e-5):
    """
    This is the main function. You need to interface only with this function
    to train. (It will record all the results)

    Once you have trained use create_db.py to create the embeddings and then
    use the inference_on_single_image.py to test

    Arguments:
        data_dir : parent directory for data
        results_dir : directory to store the results (created if missing)
        weights_dir : directory to store the weights (created if missing)
        which_dataset : "oxford" or "paris"
        image_resize : resize to this size
        image_crop_size : square crop size
        exp_num : experiment number to record the log and results
        max_epochs : maximum epochs to run
        batch_size : batch size (I used 5)
        samples_update_size : Number of samples the network should see before
            it performs one parameter update (I used 64)

    Keyword Arguments:
        num_workers : default 4
        lr : Initial learning rate (default 5e-6)
        weight_decay: default 1e-5

    Eg run:
        if __name__ == '__main__':
            main(data_dir="./data/", results_dir="./results", weights_dir="./weights",
                 which_dataset="oxbuild", image_resize=460, image_crop_size=448,
                 exp_num=3, max_epochs=10, batch_size=5, samples_update_size=64)
    """
    # Define directories
    labels_dir = os.path.join(data_dir, which_dataset, "gt_files")
    image_dir = os.path.join(data_dir, which_dataset, "images")

    # Create Query extractor object
    q_train = QueryExtractor(labels_dir, image_dir, subset="train")
    q_valid = QueryExtractor(labels_dir, image_dir, subset="valid")

    # Create transforms
    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]
    transforms_train = transforms.Compose([
        transforms.Resize(image_resize),
        transforms.RandomResizedCrop(image_crop_size, scale=(0.8, 1.2)),
        transforms.ColorJitter(brightness=(0.80, 1.20)),
        transforms.RandomHorizontalFlip(p=0.50),
        transforms.RandomChoice([
            transforms.RandomRotation(15),
            transforms.Grayscale(num_output_channels=3),
        ]),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std),
    ])

    transforms_valid = transforms.Compose([
        transforms.Resize(image_resize),
        transforms.CenterCrop(image_crop_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std),
    ])

    # Create dataset
    dataset_train = VggImageRetrievalDataset(labels_dir, image_dir, q_train, transforms=transforms_train)
    dataset_valid = VggImageRetrievalDataset(labels_dir, image_dir, q_valid, transforms=transforms_valid)

    # Create dataloader
    train_loader = DataLoader(dataset_train, batch_size=batch_size, num_workers=num_workers, shuffle=True)
    valid_loader = DataLoader(dataset_valid, batch_size=batch_size, num_workers=num_workers, shuffle=False)

    # Create cuda parameters
    use_cuda = torch.cuda.is_available()
    np.random.seed(2020)
    torch.manual_seed(2020)
    device = torch.device("cuda" if use_cuda else "cpu")

    # Create embedding network
    embedding_model = create_embedding_net()
    model = TripletNet(embedding_model)
    model.to(device)

    # Create optimizer and scheduler
    optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=10)

    # Create batch update value: number of batches accumulated per
    # parameter update so that roughly samples_update_size samples are seen
    update_batch = int(math.ceil(float(samples_update_size) / batch_size))
    model_name = "{}-exp-{}.pth".format(which_dataset, exp_num)
    loss_plot_save_path = os.path.join(
        results_dir, "{}-loss-exp-{}.png".format(which_dataset, exp_num))

    # Print stats before starting training
    print("Running VGG Image Retrieval Training script")
    print("Dataset used\t\t:{}".format(which_dataset))
    print("Max epochs\t\t: {}".format(max_epochs))
    print("Gradient update\t\t: every {} batches ({} samples)".format(
        update_batch, samples_update_size))
    print("Initial Learning rate\t: {}".format(lr))
    print("Image resize, crop size\t: {}, {}".format(image_resize, image_crop_size))
    print("Available device \t:", device)

    # Make sure the output directories exist before anything is written
    os.makedirs(results_dir, exist_ok=True)
    os.makedirs(weights_dir, exist_ok=True)

    # Create log file; the context manager closes it even when training
    # raises, so the handle is never leaked and buffered lines are flushed.
    log_path = os.path.join(results_dir, "log-{}.txt".format(exp_num))
    with open(log_path, "w+") as log_file:
        log_file.write("----------Experiment {}----------\n".format(exp_num))
        log_file.write("Dataset = {}, Image sizes = {}, {}\n".format(
            which_dataset, image_resize, image_crop_size))

        # Train the triplet network
        tr_hist, val_hist = train_model(model, device, optimizer, scheduler,
                                        train_loader, valid_loader,
                                        epochs=max_epochs,
                                        update_batch=update_batch,
                                        model_name=model_name,
                                        save_dir=weights_dir,
                                        log_file=log_file)

    # Plot and save
    plot_history(tr_hist, val_hist, "Triplet Loss", loss_plot_save_path,
                 labels=["train", "validation"])


# if __name__ == '__main__':
#     main(data_dir="./data/", results_dir="./results", weights_dir="./weights",
#          which_dataset="oxbuild", image_resize=460, image_crop_size=448,
#          exp_num=3, max_epochs=10, batch_size=5, samples_update_size=64)