def main(args):
    """Generate captions for the validation set and dump them to text files.

    Loads the pickled vocabulary and a trained CVAE checkpoint, runs
    ``model.inference`` on every validation batch and writes one sentence per
    line to ``<dataset_root_dir>/results/generated_captions.txt`` (sampled
    captions) and ``<dataset_root_dir>/results/ground_truth_captions.txt``
    (reference captions). The first sentence of each batch is also printed.

    Args:
        args: Parsed command-line namespace. Reads ``vocab_path``,
            ``load_checkpoint``, the CVAE hyper-parameters
            (``embedding_size``, ``rnn_type``, ``hidden_size``,
            ``word_dropout``, ``embedding_dropout``, ``latent_size``,
            ``max_sequence_length``, ``num_layers``, ``bidirectional``) and
            the data-loader settings (``train_image_dir``, ``val_image_dir``,
            ``train_caption_path``, ``val_caption_path``, ``batch_size``,
            ``num_workers``).

    Raises:
        FileNotFoundError: If ``args.load_checkpoint`` does not exist.
    """
    # Load vocabulary wrapper.
    # NOTE: unpickling can execute arbitrary code; only load vocabulary files
    # that come from a trusted source.
    with open(args.vocab_path, 'rb') as f:
        vocab = pickle.load(f)

    pad_idx = vocab.word2idx['<pad>']
    sos_idx = vocab.word2idx['<start>']
    eos_idx = vocab.word2idx['<end>']
    unk_idx = vocab.word2idx['<unk>']

    def ids_to_line(word_ids):
        """Decode ids up to and including '<end>' into one text line."""
        words = []
        for word_id in word_ids:
            word = vocab.idx2word[word_id]
            words.append(word)
            if word == '<end>':
                break
        sentence = ' '.join(words).rstrip().replace("\n", "")
        return "{0}\n".format(sentence)

    # Build the models
    model = CVAE(
        vocab_size=len(vocab),
        embedding_size=args.embedding_size,
        rnn_type=args.rnn_type,
        hidden_size=args.hidden_size,
        word_dropout=args.word_dropout,
        embedding_dropout=args.embedding_dropout,
        latent_size=args.latent_size,
        max_sequence_length=args.max_sequence_length,
        num_layers=args.num_layers,
        bidirectional=args.bidirectional,
        pad_idx=pad_idx,
        sos_idx=sos_idx,
        eos_idx=eos_idx,
        unk_idx=unk_idx
    )

    if not os.path.exists(args.load_checkpoint):
        raise FileNotFoundError(args.load_checkpoint)

    # map_location lets a checkpoint saved on GPU be restored on whatever
    # device this run uses (e.g. a CPU-only machine).
    model.load_state_dict(
        torch.load(args.load_checkpoint, map_location=device))
    print("Model loaded from {}".format(args.load_checkpoint))

    model.to(device)
    model.eval()

    # Build data loader (only the validation loader is consumed here).
    train_data_loader, valid_data_loader = get_loader(
        args.train_image_dir, args.val_image_dir,
        args.train_caption_path, args.val_caption_path, vocab,
        args.batch_size,
        shuffle=True, num_workers=args.num_workers)

    # Make sure the output directory exists before opening the result files.
    os.makedirs('{}/results'.format(dataset_root_dir), exist_ok=True)
    generated_path = '{}/results/generated_captions.txt'.format(
        dataset_root_dir)
    ground_truth_path = '{}/results/ground_truth_captions.txt'.format(
        dataset_root_dir)

    # ``with`` guarantees both files are closed even if inference fails;
    # ``no_grad`` avoids building autograd graphs during evaluation.
    with open(generated_path, 'w') as f1, \
            open(ground_truth_path, 'w') as f2, \
            torch.no_grad():
        for i, (images, captions, lengths) in enumerate(valid_data_loader):
            images = images.to(device)

            # Use the actual batch size: the last batch may hold fewer than
            # args.batch_size images.
            # NOTE(review): assumes ``n`` is the number of samples to draw and
            # must match the first dimension of ``c`` -- confirm against
            # CVAE.inference.
            sampled_ids, _ = model.inference(n=images.size(0), c=images)

            sampled_ids_batches = sampled_ids.cpu().numpy()
            captions = captions.cpu().numpy()

            # Convert generated word ids to sentences.
            for j, sample in enumerate(sampled_ids_batches):
                generated_sentence = ids_to_line(sample)
                if j == 0:
                    print("RE: {}".format(generated_sentence))
                f1.write(generated_sentence)

            # Convert reference word ids to sentences.
            for g, ground_truth_ids in enumerate(captions):
                ground_truth_sentence = ids_to_line(ground_truth_ids)
                if g == 0:
                    print("GT: {}".format(ground_truth_sentence))
                f2.write(ground_truth_sentence)

            if i % 10 == 0:
                print("This is the {0}th batch".format(i))
# Example #2
# 0
def main(args):
    """Train the CVAE captioning model with KL annealing and early stopping.

    For every epoch the model is trained on all batches of the training
    loader, checkpoints are written to ``args.model_path`` (every
    ``args.save_step`` steps and once per epoch), and the loss on the first
    three validation batches is computed. Training stops early once the
    validation loss has not improved for five consecutive epochs.

    Args:
        args: Parsed command-line namespace. Reads ``model_path``,
            ``vocab_path``, the data-loader settings (``train_image_dir``,
            ``val_image_dir``, ``train_caption_path``, ``val_caption_path``,
            ``batch_size``, ``num_workers``), the CVAE hyper-parameters
            (``embedding_size``, ``rnn_type``, ``hidden_size``,
            ``word_dropout``, ``embedding_dropout``, ``latent_size``,
            ``max_sequence_length``, ``num_layers``, ``bidirectional``) and
            the training settings (``learning_rate``, ``num_epochs``,
            ``anneal_function``, ``k``, ``x0``, ``log_step``, ``save_step``).

    Raises:
        ValueError: If ``args.anneal_function`` is neither ``'logistic'`` nor
            ``'linear'``.
    """
    # Number of consecutive non-improving epochs tolerated before stopping.
    max_patience = 5
    # Validation only looks at batches 0..2 to keep the per-epoch check cheap.
    last_valid_batch = 2

    # Create model directory (no error if it already exists).
    os.makedirs(args.model_path, exist_ok=True)

    # Load vocabulary wrapper.
    # NOTE: unpickling can execute arbitrary code; only load vocabulary files
    # that come from a trusted source.
    with open(args.vocab_path, 'rb') as f:
        vocab = pickle.load(f)

    pad_idx = vocab.word2idx['<pad>']
    sos_idx = vocab.word2idx['<start>']
    eos_idx = vocab.word2idx['<end>']
    unk_idx = vocab.word2idx['<unk>']

    # Build data loader
    train_data_loader, valid_data_loader = get_loader(
        args.train_image_dir,
        args.val_image_dir,
        args.train_caption_path,
        args.val_caption_path,
        vocab,
        args.batch_size,
        shuffle=True,
        num_workers=args.num_workers)

    def kl_anneal_function(anneal_function, step, k, x0):
        """Return the KL weight for ``step`` under the chosen schedule."""
        if anneal_function == 'logistic':
            # Sigmoid of k * (step - x0).
            return float(expit(k * (step - x0)))
        if anneal_function == 'linear':
            return min(1, step / x0)
        # Fail loudly instead of returning None, which would only surface
        # later as an obscure error when multiplied with the KL loss.
        raise ValueError(
            "unknown anneal_function: {!r}".format(anneal_function))

    nll = torch.nn.NLLLoss(ignore_index=pad_idx)

    def loss_fn(logp, target, length, mean, logv, anneal_function, step, k,
                x0):
        """Return ``(nll_loss, KL_loss, KL_weight)`` for one batch."""
        # Cut off unnecessary padding from target, and flatten.
        # ``.item()`` is required here: indexing a 0-dim tensor with ``[0]``
        # raises on PyTorch >= 0.4.
        max_length = torch.max(length).item()
        target = target[:, :max_length].contiguous().view(-1)
        logp = logp.view(-1, logp.size(2))

        # Negative Log Likelihood
        nll_loss = nll(logp, target)

        # KL Divergence
        KL_loss = -0.5 * torch.sum(1 + logv - mean.pow(2) - logv.exp())
        KL_weight = kl_anneal_function(anneal_function, step, k, x0)

        return nll_loss, KL_loss, KL_weight

    def prepare_batch(images, captions, lengths):
        """Move a batch to ``device`` and split captions into source/target.

        The source drops the last token and the target drops the first one,
        so each length shrinks by one.
        """
        images = images.to(device)
        captions_src = captions[:, :captions.size()[1] - 1].to(device)
        captions_tgt = captions[:, 1:].to(device)
        lengths = (lengths - 1).to(device)
        return images, captions_src, captions_tgt, lengths

    def decode(word_ids):
        """Join the words for ``word_ids`` up to and including '<end>'."""
        words = []
        for word_id in word_ids:
            word = vocab.idx2word[word_id]
            words.append(word)
            if word == '<end>':
                break
        return ' '.join(words)

    # Build the models
    model = CVAE(vocab_size=len(vocab),
                 embedding_size=args.embedding_size,
                 rnn_type=args.rnn_type,
                 hidden_size=args.hidden_size,
                 word_dropout=args.word_dropout,
                 embedding_dropout=args.embedding_dropout,
                 latent_size=args.latent_size,
                 max_sequence_length=args.max_sequence_length,
                 num_layers=args.num_layers,
                 bidirectional=args.bidirectional,
                 pad_idx=pad_idx,
                 sos_idx=sos_idx,
                 eos_idx=eos_idx,
                 unk_idx=unk_idx)
    model.to(device)
    # Loss and optimizer
    optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)

    # Train the models
    total_step = len(train_data_loader)
    step_for_kl_annealing = 0
    best_valid_loss = float("inf")
    patience = 0

    for epoch in range(args.num_epochs):
        # The validation pass at the end of each epoch switches to eval
        # mode, so training mode must be restored here.
        model.train()
        for i, (images, captions, lengths) in enumerate(train_data_loader):

            # Set mini-batch dataset
            images, captions_src, captions_tgt, lengths = prepare_batch(
                images, captions, lengths)

            # Forward pass
            logp, mean, logv, z = model(images, captions_src, lengths)

            # Loss calculation
            NLL_loss, KL_loss, KL_weight = loss_fn(logp, captions_tgt, lengths,
                                                   mean, logv,
                                                   args.anneal_function,
                                                   step_for_kl_annealing,
                                                   args.k, args.x0)

            loss = (NLL_loss + KL_weight * KL_loss) / args.batch_size

            # Backward + optimization
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            step_for_kl_annealing += 1

            # Print log info
            if i % args.log_step == 0:
                print(
                    'Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}, Perplexity: {:5.4f}'
                    .format(epoch, args.num_epochs, i, total_step, loss.item(),
                            np.exp(loss.item())))
                outputs = model._sample(logp)
                outputs = outputs.cpu().numpy()

                # Show the last example of the batch: reference caption
                # next to the model's reconstruction.
                reconstructed = decode(outputs[-1])
                ground_truth = decode(captions_tgt.cpu().numpy()[-1])
                print("ground_truth: {0} \n reconstructed: {1}\n".format(
                    ground_truth, reconstructed))

            # Save the model checkpoints
            if (i + 1) % args.save_step == 0:
                torch.save(
                    model.state_dict(),
                    os.path.join(args.model_path,
                                 'model-{}-{}.ckpt'.format(epoch + 1, i + 1)))

        torch.save(
            model.state_dict(),
            os.path.join(args.model_path,
                         'model-{}-epoch.ckpt'.format(epoch + 1)))

        # Check against the validation set and stop early if the validation
        # loss has not improved within the patience period. Eval mode and
        # no_grad keep dropout and autograd bookkeeping out of the
        # measurement; accumulating a Python float avoids keeping the
        # graphs of the validation batches alive.
        valid_loss = 0.0
        model.eval()
        with torch.no_grad():
            for j, (images, captions, lengths) in enumerate(valid_data_loader):
                images, captions_src, captions_tgt, lengths = prepare_batch(
                    images, captions, lengths)

                logp, mean, logv, z = model(images, captions_src, lengths)

                NLL_loss, KL_loss, KL_weight = loss_fn(
                    logp, captions_tgt, lengths, mean, logv,
                    args.anneal_function, step_for_kl_annealing,
                    args.k, args.x0)

                batch_loss = (NLL_loss + KL_weight * KL_loss) / args.batch_size
                valid_loss += batch_loss.item()

                if j == last_valid_batch:
                    break

        print("validation loss for epoch {}: {}".format(epoch + 1, valid_loss))
        print("patience is at {}".format(patience))
        if valid_loss < best_valid_loss:
            best_valid_loss = valid_loss
            patience = 0
        else:
            patience += 1

        if patience >= max_patience:
            print("early stopping at epoch {}".format(epoch + 1))
            break