def inference(
    images_path: str,
    sentences_path: str,
    test_imgs_file_path: str,
    batch_size: int,
    load_model_path: str,
    joint_space: int,
):
    """Run image-text retrieval inference on the Flickr test split.

    Loads a trained ImageTextMatchingModel, embeds every test image and
    sentence into the joint space, and prints recall@1/5/10 for both
    image->text and text->image retrieval.

    Args:
        images_path: Path to the directory with the Flickr images.
        sentences_path: Path to the sentences/captions file.
        test_imgs_file_path: Path to the file listing the test image names.
        batch_size: Batch size used by the test DataLoader.
        load_model_path: Path to the saved model state dict.
        joint_space: Dimensionality of the joint embedding space.
    """
    # Check for CUDA
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    dataset_test = FlickrDatasetValTest(images_path, sentences_path,
                                        test_imgs_file_path)
    test_loader = DataLoader(
        dataset_test,
        batch_size=batch_size,
        num_workers=4,
        collate_fn=collate_pad_batch,
        pin_memory=True,
    )
    # Create the model
    model = nn.DataParallel(ImageTextMatchingModel(joint_space)).to(device)
    # Load model weights. map_location is required so a checkpoint saved on
    # a CUDA machine can still be loaded on the CPU fallback chosen above.
    model.load_state_dict(torch.load(load_model_path, map_location=device))
    # Set model in evaluation mode (disables dropout/batch-norm updates)
    model.eval()
    # Create evaluator
    evaluator = Evaluator(len(dataset_test), joint_space)
    with torch.no_grad():
        evaluator.reset_all_vars()
        for images, sentences in tqdm(test_loader):
            images, sentences = images.to(device), sentences.to(device)
            embedded_images, embedded_sentences = model(images, sentences)
            # Copy embeddings to host memory for the numpy-based evaluator
            evaluator.update_embeddings(
                embedded_images.cpu().numpy().copy(),
                embedded_sentences.cpu().numpy().copy(),
            )

    print("=============================")
    print(f"Image-text recall at 1, 5, 10: "
          f"{evaluator.image2text_recall_at_k()} \n"
          f"Text-image recall at 1, 5, 10: "
          f"{evaluator.text2image_recall_at_k()}")
    print("=============================")
# Example n. 2
# 0
def train(
    images_path: str,
    sentences_path: str,
    train_imgs_file_path: str,
    val_imgs_file_path: str,
    epochs: int,
    batch_size: int,
    checkpoint_path: str,
    save_model_path: str,
    learning_rate: float,
    weight_decay: float,
    clip_val: float,
    joint_space: int,
    margin: float,
    accumulation_steps: int,
):
    """Fine-tune the image-text matching model on the Flickr training split.

    Resumes from ``checkpoint_path``, trains with a triplet loss and
    gradient accumulation, evaluates recall@1/5/10 on the validation split
    after every epoch, and saves the model whenever the recall improves.

    Args:
        images_path: Path to the directory with the Flickr images.
        sentences_path: Path to the sentences/captions file.
        train_imgs_file_path: File listing the training image names.
        val_imgs_file_path: File listing the validation image names.
        epochs: Number of training epochs.
        batch_size: Batch size for both loaders.
        checkpoint_path: State dict to resume training from.
        save_model_path: Where to save the best model.
        learning_rate: Adam learning rate.
        weight_decay: Adam weight decay.
        clip_val: Max gradient norm for clipping.
        joint_space: Dimensionality of the joint embedding space.
        margin: Triplet-loss margin.
        accumulation_steps: Number of batches to accumulate before an
            optimizer step.
    """
    # Check for CUDA
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    dataset_train = FlickrDatasetTrain(images_path, sentences_path,
                                       train_imgs_file_path)
    train_loader = DataLoader(
        dataset_train,
        batch_size=batch_size,
        shuffle=True,
        num_workers=4,
        collate_fn=collate_pad_batch,
        pin_memory=True,
    )
    dataset_val = FlickrDatasetValTest(images_path, sentences_path,
                                       val_imgs_file_path)
    val_loader = DataLoader(
        dataset_val,
        batch_size=batch_size,
        num_workers=4,
        collate_fn=collate_pad_batch,
        pin_memory=True,
    )
    model = nn.DataParallel(ImageTextMatchingModel(joint_space,
                                                   finetune=True)).to(device)
    # Load checkpoint weights. map_location is required so a checkpoint
    # saved on a CUDA machine can still be loaded on the CPU fallback.
    model.load_state_dict(torch.load(checkpoint_path, map_location=device))
    # Create loss
    criterion = TripletLoss(margin)
    # noinspection PyUnresolvedReferences
    optimizer = optim.Adam(model.parameters(),
                           lr=learning_rate,
                           weight_decay=weight_decay)
    evaluator = Evaluator(len(dataset_val), joint_space)
    for epoch in range(epochs):
        print(f"Starting epoch {epoch + 1}...")
        evaluator.reset_all_vars()

        # Set model in train mode
        model.train(True)
        # remove past gradients
        optimizer.zero_grad()
        with tqdm(total=len(train_loader)) as pbar:
            for i, (images, sentences) in enumerate(train_loader):
                # As per: https://gist.github.com/thomwolf/ac7a7da6b1888c2eeac8ac8b9b05d3d3
                images, sentences = images.to(device), sentences.to(device)
                # forward
                embedded_images, embedded_sentences = model(images, sentences)
                # Not averaging over batch, hence not normalizing loss
                loss = criterion(embedded_images, embedded_sentences)
                # backward
                loss.backward()
                # Wait for several backward steps before updating
                # NOTE(review): if len(train_loader) is not a multiple of
                # accumulation_steps, the trailing batches never trigger a
                # step and their gradients are discarded by zero_grad() at
                # the start of the next epoch — confirm this is intended.
                if (i + 1) % accumulation_steps == 0:
                    # clip the gradients
                    torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                   clip_val)
                    # update weights
                    optimizer.step()
                    # Remove gradients
                    optimizer.zero_grad()

                # Update progress bar
                pbar.update(1)
                pbar.set_postfix({"Batch loss": loss.item()})

        # Set model in evaluation mode (disables dropout/batch-norm updates)
        model.eval()
        with torch.no_grad():
            for images, sentences in tqdm(val_loader):
                images, sentences = images.to(device), sentences.to(device)
                embedded_images, embedded_sentences = model(images, sentences)
                # Copy embeddings to host memory for the numpy-based evaluator
                evaluator.update_embeddings(
                    embedded_images.cpu().numpy().copy(),
                    embedded_sentences.cpu().numpy().copy(),
                )

        if evaluator.is_best_recall_at_k():
            evaluator.update_best_recall_at_k()
            print("=============================")
            print(f"Found new best on epoch {epoch + 1}!! Saving model!\n"
                  f"Current image-text recall at 1, 5, 10: "
                  f"{evaluator.best_image2text_recall_at_k} \n"
                  f"Current text-image recall at 1, 5, 10: "
                  f"{evaluator.best_text2image_recall_at_k}")
            print("=============================")
            torch.save(model.state_dict(), save_model_path)
        else:
            print("=============================")
            # Fixed: missing space after the colon made this line's format
            # inconsistent with every other recall printout in the file.
            print(f"Metrics on epoch {epoch + 1}\n"
                  f"Current image-text recall at 1, 5, 10: "
                  f"{evaluator.cur_image2text_recall_at_k} \n"
                  f"Current text-image recall at 1, 5, 10: "
                  f"{evaluator.cur_text2image_recall_at_k}")
            print("=============================")