Example #1
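Training script for a CNN-to-RNN image-captioning model on the COCO captions
data: train() runs the training loop with per-epoch checkpointing and BLEU-4
validation, and the __main__ block reloads a saved checkpoint to print example
captions. The imports below are assumed (project-local modules such as
Constants, Hyper, CNNtoRNN, get_dataloader, the checkpoint helpers, validate
and print_examples are defined elsewhere in the project); in particular, T is
assumed to be torch imported under that alias.

import os
import pickle

import numpy as np
import torch as T
import torch.nn as nn
import torch.optim as optim
from PIL import Image
from tqdm import tqdm
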
def train():
    file_path_cap = os.path.join(Constants.data_folder_ann,
                                 Constants.captions_train_file)
    file_path_inst = os.path.join(Constants.data_folder_ann,
                                  Constants.instances_train_file)
    coco_dataloader_train, coco_data_train = get_dataloader(
        file_path_cap, file_path_inst, "train")
    file_path_cap = os.path.join(Constants.data_folder_ann,
                                 Constants.captions_val_file)
    file_path_inst = os.path.join(Constants.data_folder_ann,
                                  Constants.instances_val_file)
    coco_dataloader_val, coco_data_val = get_dataloader(
        file_path_cap, file_path_inst, "val")
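    # running counters: the global training step stored in checkpoints, the best
    # validation BLEU-4 so far, and the number of epochs without improvement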
    step = 0
    best_bleu4 = 0
    epochs_since_improvement = 0
    # initialize model, loss, etc.
    model = CNNtoRNN(coco_data_train.vocab)
    model = model.to(Constants.device)
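    # ignore the PAD token index so padded positions do not contribute to the loss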
    criterion = nn.CrossEntropyLoss(
        ignore_index=coco_data_train.vocab.stoi[Constants.PAD])
    optimizer = optim.Adam(model.parameters(), lr=Hyper.learning_rate)
    #####################################################################
    if Constants.load_model:
        step = load_checkpoint(model, optimizer)

    for i in range(Hyper.total_epochs):
        model.train()  # Set model to training mode
        model.decoderRNN.train()
        model.encoderCNN.train()
        epoch = i + 1
        print(f"Epoch: {epoch}")
        if Constants.save_model:
            checkpoint = {
                "state_dict": model.state_dict(),
                "optimizer": optimizer.state_dict(),
                "step": step,
            }
            save_checkpoint(checkpoint)

        for _, (imgs, captions) in tqdm(enumerate(coco_dataloader_train),
                                        total=len(coco_dataloader_train),
                                        leave=False):
            imgs = imgs.to(Constants.device)
            captions = captions.to(Constants.device)
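            # captions are assumed to be shaped (seq_len, batch); the decoder is
            # assumed to prepend the image feature, so feeding captions[:-1] gives
            # seq_len output steps that line up with the full target captions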
            outputs = model(imgs, captions[:-1])  # forward pass
            vocab_size = outputs.shape[2]
            outputs1 = outputs.reshape(-1, vocab_size)
            captions1 = captions.reshape(-1)
            loss = criterion(outputs1, captions1)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            step += 1  # advance the global step counter stored in checkpoints

        if Constants.save_model:
            save_checkpoint_epoch(checkpoint, epoch)
        # One epoch's validation
        recent_bleu4 = validate(val_loader=coco_dataloader_val,
                                model=model,
                                criterion=criterion)

        # Check if there was an improvement
        is_best = recent_bleu4 > best_bleu4
        best_bleu4 = max(recent_bleu4, best_bleu4)
        if not is_best:
            epochs_since_improvement += 1
            print("\nEpochs since last improvement: %d\n" %
                  (epochs_since_improvement, ))
        else:
            epochs_since_improvement = 0
    # (optional) print example captions once training finishes, e.g.:
    # print_examples(model, coco_data_train.vocab)
    model.train()


def get_image(path, t_):
    image = Image.open(path).convert("RGB")
    image = np.array(image)
    image = t_(image)           # apply the preprocessing transform passed in by the caller
    image = image.unsqueeze(0)  # add a batch dimension
    if Hyper.is_grayscale:
        # replicate the channel dimension so a grayscale tensor becomes 3-channel
        return T.cat((image, image, image), 1)

    return image
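
# Minimal usage sketch for get_image (hypothetical transform and path; the
# project's real preprocessing pipeline may differ; transforms here would be
# torchvision.transforms):
#
#   t_ = transforms.Compose([
#       transforms.ToTensor(),          # HWC uint8 array -> CHW float tensor
#       transforms.Resize((356, 356)),  # whatever input size the encoder expects
#   ])
#   img = get_image("path/to/image.jpg", t_)   # tensor of shape (1, 3, 356, 356)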


if __name__ == "__main__":
    with open(Constants.vocab_file, 'rb') as f:
        vocab = pickle.load(f)
        print('Vocabulary successfully loaded from the vocab.pkl file')
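    # pick which per-epoch checkpoint to restore before printing example captions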
    epoch = 2
    model = CNNtoRNN(vocab)
    model = model.to(Constants.device)
    optimizer = optim.Adam(model.parameters(), lr=Hyper.learning_rate)
    #####################################################################
    _ = load_checkpoint_epoch(model, optimizer, epoch)
    print_examples(model, vocab)