Example #1
import os
import pickle

import numpy as np
import torch
import torch.backends.cudnn as cudnn
from torch.nn.utils.rnn import pack_padded_sequence
from torchvision import transforms

# Project-specific helpers (EncoderCNN, DecoderRNNwithAttention,
# Custom_Flickr30k, AverageMeter, bleu_eval, save_checkpoint, collate_fn,
# get_parser) are assumed to be importable from the surrounding repository.


def main():

    cudnn.benchmark = True
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    args = get_parser().parse_args()

    NUM_WORKERS = 4
    CROP_SIZE = 256
    NUM_PIXELS = 64
    ENCODER_SIZE = 2048
    learning_rate = args.lr
    start_epoch = 0

    max_BLEU = 0

    # Load the pre-built vocabulary
    with open('vocab.p', 'rb') as f:
        vocab = pickle.load(f)

    train_transform = transforms.Compose([
        transforms.RandomCrop(CROP_SIZE),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.444, 0.421, 0.385), (0.285, 0.277, 0.286))
    ])

    val_transform = transforms.Compose([
        transforms.CenterCrop(CROP_SIZE),
        transforms.ToTensor(),
        transforms.Normalize((0.444, 0.421, 0.385), (0.285, 0.277, 0.286))
    ])

    train_dataset = Custom_Flickr30k(
        '../flickr30k-images',
        '../flickr30k-captions/results_20130124.token',
        vocab,
        transform=train_transform,
        train=True)
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=NUM_WORKERS,
                                               collate_fn=collate_fn)

    val_dataset = Custom_Flickr30k(
        '../flickr30k-images',
        '../flickr30k-captions/results_20130124.token',
        vocab,
        transform=val_transform,
        train=False)
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=NUM_WORKERS,
                                             collate_fn=collate_fn)

    # Initialize models
    encoder = EncoderCNN(args.fine_tune).to(device)
    decoder = DecoderRNNwithAttention(len(vocab),
                                      args.embed_size,
                                      args.hid_size,
                                      1,
                                      args.attn_size,
                                      ENCODER_SIZE,
                                      NUM_PIXELS,
                                      dropout=args.drop).to(device)

    # Initialize optimization
    criterion = torch.nn.CrossEntropyLoss()
    if args.fine_tune:
        params = list(encoder.parameters()) + list(decoder.parameters())
    else:
        params = list(decoder.parameters()) + list(
            encoder.linear.parameters()) + list(encoder.bn.parameters())
    optimizer = torch.optim.Adam(params, lr=learning_rate)

    if args.resume:
        if os.path.isfile(args.resume):
            checkpoint = torch.load(args.resume)
            start_epoch = checkpoint['epoch']
            max_BLEU = checkpoint['max_BLEU']
            encoder.load_state_dict(checkpoint['encoder'])
            decoder.load_state_dict(checkpoint['decoder'])
            optimizer.load_state_dict(checkpoint['optimizer'])
        else:
            print("No checkpoint found at '{}'".format(args.resume))

    XEntropy = AverageMeter()
    PPL = AverageMeter()

    # Save
    if not args.resume:
        with open(f'{args.save}/results.txt', 'a') as f:
            f.write('Loss,PPL,BLEU\n')

    for epoch in range(start_epoch, args.epoch):
        print('Epoch {}'.format(epoch + 1))
        print('training...')
        # Reset the running meters so each epoch logs its own averages
        # (assumes the usual AverageMeter.reset() method)
        XEntropy.reset()
        PPL.reset()
        for i, (images, captions, lengths) in enumerate(train_loader):
            # Move the batch to the training device
            images = images.to(device)
            captions = captions.to(device)

            encoder.train()
            decoder.train()

            features = encoder(images)
            predictions, attention_weights = decoder(features, captions,
                                                     lengths)

            # Align predictions with targets: drop the final prediction and the
            # <start> token, so every sequence shortens by two tokens
            scores = pack_padded_sequence(predictions[:, :-1, :],
                                          torch.tensor(lengths) - 2,
                                          batch_first=True)
            targets = pack_padded_sequence(captions[:, 1:-1],
                                           torch.tensor(lengths) - 2,
                                           batch_first=True)

            loss = criterion(scores.data, targets.data)
            decoder.zero_grad()
            encoder.zero_grad()
            loss.backward()
            optimizer.step()

            XEntropy.update(loss.item(), len(lengths))
            PPL.update(np.exp(loss.item()), len(lengths))
        print('Train Perplexity = {}'.format(PPL.avg))

        # Decay the learning rate by 5x every 50 epochs (skipping epoch 0,
        # which would otherwise cut the rate before any training happened)
        if epoch > 0 and epoch % 50 == 0:
            learning_rate /= 5
            for param_group in optimizer.param_groups:
                param_group['lr'] = learning_rate

        print('validating...')
        curr_BLEU = bleu_eval(encoder, decoder, val_loader, args.batch_size,
                              device)
        is_best = curr_BLEU > max_BLEU
        max_BLEU = max(curr_BLEU, max_BLEU)
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'encoder': encoder.state_dict(),
                'decoder': decoder.state_dict(),
                'max_BLEU': max_BLEU,
                'optimizer': optimizer.state_dict(),
            }, is_best, args.save)

        print('Validation BLEU = {}'.format(curr_BLEU))

        # Append this epoch's metrics
        with open(f'{args.save}/results.txt', 'a') as f:
            f.write('{},{},{}\n'.format(XEntropy.avg, PPL.avg, curr_BLEU))
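
Example #1 relies on a few helpers it does not show (AverageMeter, save_checkpoint, collate_fn). A minimal sketch of each, following the common PyTorch image-captioning tutorial pattern; the names match the calls above, but the bodies are assumptions, not code from this repository:

import shutil

import torch


class AverageMeter:
    """Keeps a running average of a scalar (e.g. loss) across batches."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.sum = 0.0
        self.count = 0
        self.avg = 0.0

    def update(self, val, n=1):
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


def save_checkpoint(state, is_best, save_dir):
    """Save the latest checkpoint; copy it aside when it is the best so far."""
    path = f'{save_dir}/checkpoint.pth.tar'
    torch.save(state, path)
    if is_best:
        shutil.copyfile(path, f'{save_dir}/model_best.pth.tar')


def collate_fn(data):
    """Pad variable-length captions into one LongTensor, longest first."""
    data.sort(key=lambda x: len(x[1]), reverse=True)
    images, captions = zip(*data)
    images = torch.stack(images, 0)
    lengths = [len(cap) for cap in captions]
    targets = torch.zeros(len(captions), max(lengths)).long()
    for i, cap in enumerate(captions):
        targets[i, :lengths[i]] = cap[:lengths[i]]
    return images, targets, lengths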
Example #2
import numpy as np
import torch
import torch.nn as nn
from torchvision import transforms

# EncoderCNN, DecoderRNN, SentenceLSTM, WordLSTM, get_loader and evaluate are
# assumed to be importable from the surrounding repository.


def train(n_epochs, train_loader, valid_loader, save_location_path, embed_size,
          hidden_size, vocab_size):

    encoder = EncoderCNN(embed_size)
    decoder = DecoderRNN(embed_size, hidden_size, vocab_size)

    # Move to GPU, if available
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    encoder = encoder.to(device)
    decoder = decoder.to(device)

    criterion = nn.CrossEntropyLoss().to(device)
    params = list(decoder.parameters()) + list(encoder.embed.parameters())
    optimizer = torch.optim.Adam(params, lr=0.001)

    # Track the best validation loss seen so far; start at infinity so the
    # model is saved after the first epoch
    valid_loss_min = np.inf

    for epoch in range(1, n_epochs + 1):

        # Keep track of training and validation loss
        train_loss = 0.0
        valid_loss = 0.0

        encoder.train()
        decoder.train()
        for data in train_loader:
            images, captions = data['image'], data['caption']
            # .to() is not in-place: the result must be reassigned
            images = images.float().to(device)
            captions = captions.to(device)

            decoder.zero_grad()
            encoder.zero_grad()

            features = encoder(images)
            outputs = decoder(features, captions)

            loss = criterion(outputs.contiguous().view(-1, vocab_size),
                             captions.view(-1))
            loss.backward()
            optimizer.step()

            train_loss += loss.item() * images.size(0)

        # Validation pass: no gradients needed
        encoder.eval()
        decoder.eval()
        with torch.no_grad():
            for data in valid_loader:
                images, captions = data['image'], data['caption']
                images = images.float().to(device)
                captions = captions.to(device)

                features = encoder(images)
                outputs = decoder(features, captions)

                loss = criterion(outputs.contiguous().view(-1, vocab_size),
                                 captions.view(-1))

                valid_loss += loss.item() * images.size(0)

        # Average the per-sample losses over each dataset
        train_loss = train_loss / len(train_loader.dataset)
        valid_loss = valid_loss / len(valid_loader.dataset)

        print(
            f"Epoch: {epoch} \tTraining Loss: {train_loss:.6f} \tValidation Loss: {valid_loss:.6f}"
        )

        # Save the model if validation loss has decreased
        if valid_loss <= valid_loss_min:
            print(
                f"Validation loss decreased ({valid_loss_min:.6f} --> {valid_loss:.6f}).  Saving model ..."
            )
            torch.save(encoder.state_dict(),
                       f'{save_location_path}/encoder{n_epochs}.pt')
            torch.save(decoder.state_dict(),
                       f'{save_location_path}/decoder{n_epochs}.pt')
            valid_loss_min = valid_loss
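
For reference, train() expects each loader to yield dict batches with 'image' and 'caption' keys. A hypothetical invocation; CaptionDataset and all sizes are placeholders, not from the source:

from torch.utils.data import DataLoader

# CaptionDataset is a placeholder for a dataset whose __getitem__ returns
# {'image': FloatTensor, 'caption': LongTensor}.
train_loader = DataLoader(CaptionDataset('train'), batch_size=32, shuffle=True)
valid_loader = DataLoader(CaptionDataset('val'), batch_size=32, shuffle=False)

train(n_epochs=10,
      train_loader=train_loader,
      valid_loader=valid_loader,
      save_location_path='./models',
      embed_size=256,
      hidden_size=512,
      vocab_size=9955)  # illustrative; use len(vocab) from your vocabulary
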
def script(args):
    transform = transforms.Compose([
        transforms.Resize(args.img_size),
        transforms.RandomCrop(args.crop_size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor()
    ])

    train_loader, vocab = get_loader(args.root_dir, args.train_tsv_path,
                                     args.image_path, transform,
                                     args.batch_size, args.shuffle,
                                     args.num_workers)

    vocab_size = len(vocab)
    print("vocab_size: ", vocab_size)

    val_loader, _ = get_loader(args.root_dir, args.val_tsv_path,
                               args.image_path, transform, args.batch_size,
                               args.shuffle, args.num_workers, vocab)

    encoderCNN = EncoderCNN().to(args.device)

    sentLSTM = SentenceLSTM(encoderCNN.enc_dim, args.sent_hidden_dim,
                            args.att_dim, args.sent_input_dim,
                            args.word_input_dim,
                            args.int_stop_dim).to(args.device)

    wordLSTM = WordLSTM(args.word_input_dim, args.word_hidden_dim, vocab_size,
                        args.num_layers).to(args.device)

    criterion_stop = nn.CrossEntropyLoss().to(args.device)
    criterion_words = nn.CrossEntropyLoss().to(args.device)

    params_cnn = list(encoderCNN.parameters())
    params_lstm = list(sentLSTM.parameters()) + list(wordLSTM.parameters())

    optim_cnn = torch.optim.Adam(params=params_cnn, lr=args.learning_rate_cnn)
    optim_lstm = torch.optim.Adam(params=params_lstm,
                                  lr=args.learning_rate_lstm)

    total_step = len(train_loader)

    # Baseline evaluation before any training
    evaluate(args, val_loader, encoderCNN, sentLSTM, wordLSTM, vocab)

    for epoch in range(args.num_epochs):
        encoderCNN.train()
        sentLSTM.train()
        wordLSTM.train()

        for i, (images, captions, prob) in enumerate(train_loader):
            optim_cnn.zero_grad()
            optim_lstm.zero_grad()

            batch_size = images.shape[0]
            images = images.to(args.device)
            captions = captions.to(args.device)
            prob = prob.to(args.device)

            vis_enc_output = encoderCNN(images)

            topics, ps = sentLSTM(vis_enc_output, captions, args.device)

            # Sentence-level loss: should generation stop or continue?
            loss_sent = criterion_stop(ps.view(-1, 2), prob.view(-1))

            # Word-level loss, accumulated over each sentence of the report
            loss_word = torch.tensor([0.0]).to(args.device)

            for j in range(captions.shape[1]):
                word_outputs = wordLSTM(topics[:, j, :], captions[:, j, :])

                loss_word += criterion_words(
                    word_outputs.contiguous().view(-1, vocab_size),
                    captions[:, j, :].contiguous().view(-1))

            loss = args.lambda_sent * loss_sent + args.lambda_word * loss_word

            loss.backward()
            optim_cnn.step()
            optim_lstm.step()

            # Print log info
            if i % args.log_step == 0:
                print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'.format(
                    epoch, args.num_epochs, i, total_step, loss.item()))

            ## Save the model checkpoints
            # if (i + 1) % args.save_step == 0:
            #     torch.save(sentLSTM.state_dict(), os.path.join(
            #         args.model_path, 'sentLSTM-{}-{}.ckpt'.format(epoch + 1, i + 1)))
            #     torch.save(wordLSTM.state_dict(), os.path.join(
            #         args.model_path, 'wordLSTM-{}-{}.ckpt'.format(epoch + 1, i + 1)))
            #     torch.save(encoderCNN.state_dict(), os.path.join(
            #         args.model_path, 'encoderCNN-{}-{}.ckpt'.format(epoch + 1, i + 1)))

        evaluate(args, val_loader, encoderCNN, sentLSTM, wordLSTM, vocab)
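
script(args) reads a large set of attributes off args. A sketch of a matching argparse parser; the attribute names come from the code above, but every default value is an illustrative assumption:

import argparse

import torch


def get_args():
    p = argparse.ArgumentParser()
    # Data
    p.add_argument('--root_dir', default='.')
    p.add_argument('--train_tsv_path', required=True)
    p.add_argument('--val_tsv_path', required=True)
    p.add_argument('--image_path', required=True)
    p.add_argument('--img_size', type=int, default=256)
    p.add_argument('--crop_size', type=int, default=224)
    p.add_argument('--batch_size', type=int, default=32)
    p.add_argument('--no_shuffle', dest='shuffle', action='store_false')
    p.add_argument('--num_workers', type=int, default=4)
    # Model dimensions
    p.add_argument('--sent_hidden_dim', type=int, default=512)
    p.add_argument('--att_dim', type=int, default=256)
    p.add_argument('--sent_input_dim', type=int, default=1024)
    p.add_argument('--word_input_dim', type=int, default=512)
    p.add_argument('--word_hidden_dim', type=int, default=512)
    p.add_argument('--int_stop_dim', type=int, default=64)
    p.add_argument('--num_layers', type=int, default=1)
    # Optimization
    p.add_argument('--learning_rate_cnn', type=float, default=1e-5)
    p.add_argument('--learning_rate_lstm', type=float, default=5e-4)
    p.add_argument('--num_epochs', type=int, default=50)
    p.add_argument('--lambda_sent', type=float, default=1.0)
    p.add_argument('--lambda_word', type=float, default=1.0)
    p.add_argument('--log_step', type=int, default=10)
    args = p.parse_args()
    args.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    return args


if __name__ == '__main__':
    script(get_args())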