import torch
from torch import nn, optim
from torch.utils.data import DataLoader
from tqdm import tqdm

# FlickrDatasetTrain, FlickrDatasetValTest, collate_pad_batch, ImageTextMatchingModel,
# TripletLoss and Evaluator are assumed to be imported from this repo's own modules
# (their import paths are not shown in this snippet).


def inference(
    images_path: str,
    sentences_path: str,
    test_imgs_file_path: str,
    batch_size: int,
    load_model_path: str,
    joint_space: int,
):
    # Check for CUDA
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    dataset_test = FlickrDatasetValTest(images_path, sentences_path, test_imgs_file_path)
    test_loader = DataLoader(
        dataset_test,
        batch_size=batch_size,
        num_workers=4,
        collate_fn=collate_pad_batch,
        pin_memory=True,
    )
    # Create the model
    model = nn.DataParallel(ImageTextMatchingModel(joint_space)).to(device)
    # Load the trained weights
    model.load_state_dict(torch.load(load_model_path))
    # Set the model in evaluation mode
    model.train(False)
    # Create the evaluator
    evaluator = Evaluator(len(dataset_test), joint_space)
    with torch.no_grad():
        evaluator.reset_all_vars()
        for images, sentences in tqdm(test_loader):
            images, sentences = images.to(device), sentences.to(device)
            embedded_images, embedded_sentences = model(images, sentences)
            evaluator.update_embeddings(
                embedded_images.cpu().numpy().copy(),
                embedded_sentences.cpu().numpy().copy(),
            )
    print("=============================")
    print(
        f"Image-text recall at 1, 5, 10: "
        f"{evaluator.image2text_recall_at_k()}\n"
        f"Text-image recall at 1, 5, 10: "
        f"{evaluator.text2image_recall_at_k()}"
    )
    print("=============================")
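
# A minimal usage sketch (not part of the original script): evaluating a saved
# checkpoint on the test split. Every path and hyperparameter below is an
# illustrative assumption, not a value taken from this repo.
def example_inference_run():
    inference(
        images_path="data/flickr30k/images",                 # assumed image directory
        sentences_path="data/flickr30k/sentences.token",     # assumed captions file
        test_imgs_file_path="data/flickr30k/test_imgs.txt",  # assumed split file
        batch_size=64,
        load_model_path="models/image_text_matching.pt",     # assumed checkpoint path
        joint_space=1024,
    )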
def train(
    images_path: str,
    sentences_path: str,
    train_imgs_file_path: str,
    val_imgs_file_path: str,
    epochs: int,
    batch_size: int,
    checkpoint_path: str,
    save_model_path: str,
    learning_rate: float,
    weight_decay: float,
    clip_val: float,
    joint_space: int,
    margin: float,
    accumulation_steps: int,
):
    # Check for CUDA
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    dataset_train = FlickrDatasetTrain(images_path, sentences_path, train_imgs_file_path)
    train_loader = DataLoader(
        dataset_train,
        batch_size=batch_size,
        shuffle=True,
        num_workers=4,
        collate_fn=collate_pad_batch,
        pin_memory=True,
    )
    dataset_val = FlickrDatasetValTest(images_path, sentences_path, val_imgs_file_path)
    val_loader = DataLoader(
        dataset_val,
        batch_size=batch_size,
        num_workers=4,
        collate_fn=collate_pad_batch,
        pin_memory=True,
    )
    model = nn.DataParallel(ImageTextMatchingModel(joint_space, finetune=True)).to(device)
    # Load the checkpoint to fine-tune from
    model.load_state_dict(torch.load(checkpoint_path))
    # Create the loss
    criterion = TripletLoss(margin)
    # noinspection PyUnresolvedReferences
    optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
    evaluator = Evaluator(len(dataset_val), joint_space)
    for epoch in range(epochs):
        print(f"Starting epoch {epoch + 1}...")
        evaluator.reset_all_vars()
        # Set the model in train mode
        model.train(True)
        # Remove past gradients
        optimizer.zero_grad()
        with tqdm(total=len(train_loader)) as pbar:
            for i, (images, sentences) in enumerate(train_loader):
                # Gradient accumulation as per:
                # https://gist.github.com/thomwolf/ac7a7da6b1888c2eeac8ac8b9b05d3d3
                images, sentences = images.to(device), sentences.to(device)
                # Forward pass
                embedded_images, embedded_sentences = model(images, sentences)
                # Not averaging over the batch, hence not normalizing the loss
                loss = criterion(embedded_images, embedded_sentences)
                # Backward pass
                loss.backward()
                # Wait for several backward steps
                if (i + 1) % accumulation_steps == 0:
                    # Clip the gradients
                    torch.nn.utils.clip_grad_norm_(model.parameters(), clip_val)
                    # Update the weights
                    optimizer.step()
                    # Remove gradients
                    optimizer.zero_grad()
                # Update the progress bar
                pbar.update(1)
                pbar.set_postfix({"Batch loss": loss.item()})
        # Set the model in evaluation mode
        model.train(False)
        with torch.no_grad():
            for images, sentences in tqdm(val_loader):
                images, sentences = images.to(device), sentences.to(device)
                embedded_images, embedded_sentences = model(images, sentences)
                evaluator.update_embeddings(
                    embedded_images.cpu().numpy().copy(),
                    embedded_sentences.cpu().numpy().copy(),
                )
        if evaluator.is_best_recall_at_k():
            evaluator.update_best_recall_at_k()
            print("=============================")
            print(
                f"Found new best on epoch {epoch + 1}!! Saving model!\n"
                f"Current image-text recall at 1, 5, 10: "
                f"{evaluator.best_image2text_recall_at_k}\n"
                f"Current text-image recall at 1, 5, 10: "
                f"{evaluator.best_text2image_recall_at_k}"
            )
            print("=============================")
            torch.save(model.state_dict(), save_model_path)
        else:
            print("=============================")
            print(
                f"Metrics on epoch {epoch + 1}\n"
                f"Current image-text recall at 1, 5, 10: "
                f"{evaluator.cur_image2text_recall_at_k}\n"
                f"Current text-image recall at 1, 5, 10: "
                f"{evaluator.cur_text2image_recall_at_k}"
            )
            print("=============================")
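
# A minimal command-line sketch, assuming this script is run directly. The flag
# names and defaults below are illustrative assumptions, not the repo's actual CLI.
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Fine-tune the image-text matching model.")
    parser.add_argument("--images_path", required=True)
    parser.add_argument("--sentences_path", required=True)
    parser.add_argument("--train_imgs_file_path", required=True)
    parser.add_argument("--val_imgs_file_path", required=True)
    parser.add_argument("--epochs", type=int, default=10)
    parser.add_argument("--batch_size", type=int, default=64)
    parser.add_argument("--checkpoint_path", required=True)
    parser.add_argument("--save_model_path", required=True)
    parser.add_argument("--learning_rate", type=float, default=2e-5)
    parser.add_argument("--weight_decay", type=float, default=0.0)
    parser.add_argument("--clip_val", type=float, default=2.0)
    parser.add_argument("--joint_space", type=int, default=1024)
    parser.add_argument("--margin", type=float, default=0.2)
    parser.add_argument("--accumulation_steps", type=int, default=1)
    args = parser.parse_args()
    train(**vars(args))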