Example #1
def do(args: argparse.Namespace):
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
    print('gpu:', args.gpu)
    if not os.path.exists(args.save_model_path):
        os.mkdir(args.save_model_path)
    # preprocess
    preprocess = transforms.Compose([
        transforms.RandomCrop(args.random_crop_size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))
    ])
    with open(args.vocab_path, 'rb') as f:
        vocab = pickle.load(f)
    # dataset
    coco_loader = get_dataloader(root=args.dataset_path, json_path=args.json_path, vocab=vocab, batch_size=args.batch_size, num_workers=args.num_workers,
                                 transform=preprocess, shuffle=False)
    # models
    encoder = EncoderCNN(args.embed_size).cuda()
    decoder = DecoderRNN(len(vocab), args.embed_size, args.hidden_size, args.num_layers).cuda()
    loss_cls = nn.CrossEntropyLoss().cuda()
    params = list(encoder.fc.parameters()) + list(encoder.bn1d.parameters()) + list(decoder.parameters())
    optimizer = torch.optim.Adam(params, lr=args.learning_rate)
    # resume
    if args.resume:
        model_states = torch.load(os.path.join(args.save_model_path, 'model.ckpt'))
        print('checkpoint epoch: %d\tstep: %d' % (model_states['epoch'], model_states['step']))
        encoder.load_state_dict(model_states['encoder'])
        decoder.load_state_dict(model_states['decoder'])
        print('load successfully')
    # train
    total_step = len(coco_loader)
    print('total step in each epoch : ', total_step)
    # Train only the new projection layers; keep the pretrained backbone frozen in eval mode
    encoder.fc.train(mode=True)
    encoder.bn1d.train(mode=True)
    encoder.encoder.eval()
    decoder.train(mode=True)
    input('ready')  # pause until Enter is pressed before training starts
    for cur_epoch in range(args.num_epochs):
        for cur_step, (image, caption, length) in enumerate(coco_loader):
            image = image.cuda()
            caption = caption.cuda()
            target = pack_padded_sequence(caption, length, batch_first=True)[0]
            out = decoder(encoder(image), caption, length)
            loss = loss_cls(out, target)
            encoder.zero_grad()
            decoder.zero_grad()
            loss.backward()
            optimizer.step()
            if (cur_step + 1) % args.print_step == 0:
                print('Epoch : %d/%d\tStep : %d/%d\tLoss : %.8f\tPerplexity : %.8f' % (
                    cur_epoch + 1, args.num_epochs, cur_step + 1, total_step, loss.item(), np.exp(loss.item())))
            if (cur_step + 1) % args.save_model_step == 0:
                torch.save({'epoch': cur_epoch + 1, 'step': cur_step + 1, 'encoder': encoder.state_dict(), 'decoder': decoder.state_dict()},
                           os.path.join(args.save_model_path, 'model.ckpt'))
                print('model saved at E:%d\tS:%d' % (cur_epoch + 1, cur_step + 1))
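Note: Example #1 relies on pack_padded_sequence to flatten the padded caption batch into a single target vector for nn.CrossEntropyLoss, so padding positions never contribute to the loss. A minimal sketch of just that step, with made-up toy values:

# Toy illustration (hypothetical shapes): packing drops the padding tokens.
import torch
from torch.nn.utils.rnn import pack_padded_sequence

captions = torch.tensor([[1, 4, 7, 2, 0],   # batch of 2 padded captions; 0 = <pad>
                         [1, 5, 2, 0, 0]])
lengths = [4, 3]                            # true lengths, sorted descending

targets = pack_padded_sequence(captions, lengths, batch_first=True)[0]
print(targets)  # tensor([1, 1, 4, 5, 7, 2, 2]) -- the 7 real tokens, time-major, no padding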
Example #2
def train(n_epochs, train_loader, valid_loader, save_location_path, embed_size,
          hidden_size, vocab_size):

    encoder = EncoderCNN(embed_size)
    decoder = DecoderRNN(embed_size, hidden_size, vocab_size)

    # Move to GPU, if available
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    encoder = encoder.to(device)
    decoder = decoder.to(device)

    criterion = nn.CrossEntropyLoss().to(device)
    params = list(decoder.parameters()) + list(encoder.embed.parameters())
    optimizer = torch.optim.Adam(params, lr=0.001)

    # Start at infinity so the first validation loss always triggers a checkpoint save
    valid_loss_min = np.inf

    for epoch in range(1, n_epochs + 1):

        # Keep track of training and validation loss
        train_loss = 0.0
        valid_loss = 0.0

        encoder.train()
        decoder.train()
        for data in train_loader:
            images, captions = data['image'], data['caption']
            images = images.type(torch.FloatTensor)
            images = images.to(device)      # .to() returns a new tensor; reassign
            captions = captions.to(device)

            decoder.zero_grad()
            encoder.zero_grad()

            features = encoder(images)
            outputs = decoder(features, captions)

            loss = criterion(outputs.contiguous().view(-1, vocab_size),
                             captions.view(-1))
            loss.backward()
            optimizer.step()

            train_loss += loss.item() * images.size(0)

        encoder.eval()
        decoder.eval()
        with torch.no_grad():  # no gradients are needed for validation
            for data in valid_loader:
                images, captions = data['image'], data['caption']
                images = images.type(torch.FloatTensor)
                images = images.to(device)      # .to() returns a new tensor; reassign
                captions = captions.to(device)

                features = encoder(images)
                outputs = decoder(features, captions)

                loss = criterion(outputs.contiguous().view(-1, vocab_size),
                                 captions.view(-1))

                valid_loss += loss.item() * images.size(0)

        # Average the per-sample losses over each dataset
        train_loss = train_loss / len(train_loader.dataset)
        valid_loss = valid_loss / len(valid_loader.dataset)

        print(
            f"Epoch: {epoch} \tTraining Loss: {train_loss} \tValidation Loss: {valid_loss}"
        )

        # Save the models if the validation loss has decreased
        if valid_loss <= valid_loss_min:
            print(
                f"Validation loss decreased ({valid_loss_min} --> {valid_loss}).  Saving model ..."
            )
            torch.save(encoder.state_dict(),
                       f'{save_location_path}/encoder{n_epochs}.pt')
            torch.save(decoder.state_dict(),
                       f'{save_location_path}/decoder{n_epochs}.pt')
            valid_loss_min = valid_loss
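Note: Tensor.to(device) returns a new tensor instead of moving the tensor in place, which is an easy mistake in training loops like the one above. A minimal sketch of the correct pattern:

import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

x = torch.zeros(4, 3)
x.to(device)      # no effect on x: the moved copy is discarded
x = x.to(device)  # correct: rebind the name to the moved tensor
print(x.device)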
Example #3
def main():

    cudnn.benchmark = True
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    args = get_parser().parse_args()

    NUM_WORKERS = 4
    CROP_SIZE = 256
    NUM_PIXELS = 64
    ENCODER_SIZE = 2048
    learning_rate = args.lr
    start_epoch = 0

    max_BLEU = 0

    with open('vocab.p', 'rb') as f:
        vocab = pickle.load(f)

    train_transform = transforms.Compose([
        transforms.RandomCrop(CROP_SIZE),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.444, 0.421, 0.385), (0.285, 0.277, 0.286))
    ])

    val_transform = transforms.Compose([
        transforms.CenterCrop(CROP_SIZE),
        transforms.ToTensor(),
        transforms.Normalize((0.444, 0.421, 0.385), (0.285, 0.277, 0.286))
    ])

    train_loader = torch.utils.data.DataLoader(dataset=Custom_Flickr30k(
        '../flickr30k-images',
        '../flickr30k-captions/results_20130124.token',
        vocab,
        transform=train_transform,
        train=True),
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=NUM_WORKERS,
                                               collate_fn=collate_fn)

    val_loader = torch.utils.data.DataLoader(dataset=Custom_Flickr30k(
        '../flickr30k-images',
        '../flickr30k-captions/results_20130124.token',
        vocab,
        transform=val_transform,
        train=False),
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=NUM_WORKERS,
                                             collate_fn=collate_fn)

    # Initialize models
    encoder = EncoderCNN(args.fine_tune).to(device)
    decoder = DecoderRNNwithAttention(len(vocab),
                                      args.embed_size,
                                      args.hid_size,
                                      1,
                                      args.attn_size,
                                      ENCODER_SIZE,
                                      NUM_PIXELS,
                                      dropout=args.drop).to(device)

    # Initialize optimization
    criterion = torch.nn.CrossEntropyLoss()
    if args.fine_tune:
        params = list(encoder.parameters()) + list(decoder.parameters())
    else:
        params = list(decoder.parameters()) + list(
            encoder.linear.parameters()) + list(encoder.bn.parameters())
    optimizer = torch.optim.Adam(params, lr=learning_rate)

    if args.resume:
        if os.path.isfile(args.resume):
            checkpoint = torch.load(args.resume)
            start_epoch = checkpoint['epoch']
            max_BLEU = checkpoint['max_BLEU']
            encoder.load_state_dict(checkpoint['encoder'])
            decoder.load_state_dict(checkpoint['decoder'])
            optimizer.load_state_dict(checkpoint['optimizer'])
        else:
            print("No checkpoint found at '{}'".format(args.resume))

    XEntropy = AverageMeter()
    PPL = AverageMeter()

    # Start a fresh results log when not resuming
    if not args.resume:
        with open(f'{args.save}/results.txt', 'a') as file:
            file.write('Loss,PPL,BLEU \n')

    for epoch in range(start_epoch, args.epoch):
        print('Epoch {}'.format(epoch + 1))
        print('training...')
        for i, (images, captions, lengths) in enumerate(train_loader):
            # Batch to device
            images = images.to(device)
            captions = captions.to(device)

            encoder.train()
            decoder.train()

            features = encoder(images)
            predictions, attention_weights = decoder(features, captions,
                                                     lengths)

            scores = pack_padded_sequence(predictions[:, :-1, :],
                                          torch.tensor(lengths) - 2,
                                          batch_first=True).cpu()
            targets = pack_padded_sequence(captions[:, 1:-1],
                                           torch.tensor(lengths) - 2,
                                           batch_first=True).cpu()

            loss = criterion(scores.data, targets.data)
            decoder.zero_grad()
            encoder.zero_grad()
            loss.backward()
            optimizer.step()

            XEntropy.update(loss.item(), len(lengths))
            PPL.update(np.exp(loss.item()), len(lengths))
        print('Train Perplexity = {}'.format(PPL.avg))

        # Decay the learning rate by 5x every 50 epochs (skip epoch 0)
        if epoch > 0 and epoch % 50 == 0:
            learning_rate /= 5
            for param_group in optimizer.param_groups:
                param_group['lr'] = learning_rate

        print('validating...')
        curr_BLEU = bleu_eval(encoder, decoder, val_loader, args.batch_size,
                              device)
        is_best = curr_BLEU > max_BLEU
        max_BLEU = max(curr_BLEU, max_BLEU)
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'encoder': encoder.state_dict(),
                'decoder': decoder.state_dict(),
                'max_BLEU': max_BLEU,
                'optimizer': optimizer.state_dict(),
            }, is_best, args.save)

        print('Validation BLEU = {}'.format(curr_BLEU))

        # Append this epoch's metrics to the results log
        with open(f'{args.save}/results.txt', 'a') as file:
            file.write('{},{},{} \n'.format(XEntropy.avg, PPL.avg, curr_BLEU))
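Note: the manual decay above (divide the learning rate by 5 every 50 epochs) can be expressed equivalently with torch.optim.lr_scheduler.StepLR. A minimal, self-contained sketch with a stand-in model:

import torch

model = torch.nn.Linear(8, 2)  # stand-in for the encoder/decoder parameters
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
# gamma=0.2 multiplies the learning rate by 1/5 every step_size epochs
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=50, gamma=0.2)

for epoch in range(100):
    # ... one epoch of training with optimizer.step() per batch ...
    scheduler.step()  # advance the schedule once per epoch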
Example #4
def train_captioner():
    print("Training the captioner ... ")
    # Create model directory
    if not os.path.exists(path_trained_model):
        os.makedirs(path_trained_model)

    # Image preprocessing: resize the input image, then augment and normalize
    # with the pretrained ResNet statistics
    transform = transforms.Compose([
        transforms.Resize((input_resnet_size, input_resnet_size),
                          interpolation=Image.LANCZOS),  # Image.ANTIALIAS was removed in Pillow 10
        transforms.RandomCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
    ])

    # Load the vocabulary dictionary (pickled binary data)
    with open(dict_path, 'rb') as file:
        dictionary = pickle.load(file)

    # Build data loader
    data_loader = get_loader(imgs_path,
                             data_caps,
                             dictionary,
                             transform,
                             BATCH_SIZE,
                             shuffle=True,
                             num_workers=2)

    # Build the models
    encoder = EncoderCNN(word_embedding_size).to(device)
    decoder = DecoderRNN(word_embedding_size, lstm_output_size,
                         len(dictionary[0]), num_layers).to(device)

    # Loss and optimizer
    criterion = nn.CrossEntropyLoss()
    params = list(decoder.parameters()) + list(
        encoder.linear.parameters()) + list(encoder.bn.parameters())
    optimizer = torch.optim.Adam(params, lr=LEARN_RATE)

    # Train the models
    total_step = len(data_loader)
    for epoch in range(NUM_EPOCHS):
        for i, (images, captions, lengths) in enumerate(data_loader):
            # Set mini-batch dataset
            images = images.to(device)
            captions = captions.to(device)
            targets = pack_padded_sequence(captions, lengths,
                                           batch_first=True)[0]

            # Forward, backward and optimize
            features = encoder(images)
            outputs = decoder(features, captions, lengths)
            loss = criterion(outputs, targets)
            decoder.zero_grad()
            encoder.zero_grad()
            loss.backward()
            optimizer.step()

            # Print log info
            if i % 20 == 0:
                print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'.format(
                    epoch, NUM_EPOCHS, i, total_step, loss.item()))

        # Save the models after each epoch ...
        torch.save(
            decoder.state_dict(),
            os.path.join(path_trained_model,
                         'captioner{}.ckpt'.format(epoch + 1)))
        torch.save(
            encoder.state_dict(),
            os.path.join(path_trained_model,
                         'feature-extractor-{}.ckpt'.format(epoch + 1)))
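Note: Example #4 writes the encoder and decoder weights to separate files each epoch. To make a run resumable the way Examples #1 and #3 are, a common pattern is one checkpoint dict that also carries the optimizer state. A minimal sketch with stand-in models (the real EncoderCNN/DecoderRNN would go in their place):

import torch

encoder = torch.nn.Linear(8, 4)   # stand-ins for the real models
decoder = torch.nn.Linear(4, 8)
optimizer = torch.optim.Adam(list(encoder.parameters()) + list(decoder.parameters()))
epoch = 0

# Save one file with everything needed to resume
torch.save({
    'epoch': epoch + 1,
    'encoder': encoder.state_dict(),
    'decoder': decoder.state_dict(),
    'optimizer': optimizer.state_dict(),
}, 'checkpoint.ckpt')

# Resume later
ckpt = torch.load('checkpoint.ckpt')
encoder.load_state_dict(ckpt['encoder'])
decoder.load_state_dict(ckpt['decoder'])
optimizer.load_state_dict(ckpt['optimizer'])
start_epoch = ckpt['epoch']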
Example #6
def main(args):
    model_path = args.model_path
    if not os.path.exists(model_path):
        os.makedirs(model_path)

    # load vocabulary
    with open(args.vocab_path, 'rb') as f:
        vocab = pickle.load(f)

    img_path = args.img_path
    factual_cap_path = args.factual_caption_path
    humorous_cap_path = args.humorous_caption_path

    # build data loaders
    data_loader = get_data_loader(img_path,
                                  factual_cap_path,
                                  vocab,
                                  args.caption_batch_size,
                                  shuffle=True)
    styled_data_loader = get_styled_data_loader(humorous_cap_path,
                                                vocab,
                                                args.language_batch_size,
                                                shuffle=True)

    # build models
    emb_dim = args.emb_dim
    hidden_dim = args.hidden_dim
    factored_dim = args.factored_dim
    vocab_size = len(vocab)
    encoder = EncoderCNN(emb_dim)
    decoder = FactoredLSTM(emb_dim, hidden_dim, factored_dim, vocab_size)

    if torch.cuda.is_available():
        encoder = encoder.cuda()
        decoder = decoder.cuda()

    # loss and optimizer
    criterion = masked_cross_entropy
    cap_params = list(decoder.parameters()) + list(encoder.A.parameters())
    lang_params = list(decoder.parameters())
    optimizer_cap = torch.optim.Adam(cap_params, lr=args.lr_caption)
    optimizer_lang = torch.optim.Adam(lang_params, lr=args.lr_language)

    # train
    total_cap_step = len(data_loader)
    total_lang_step = len(styled_data_loader)
    epoch_num = args.epoch_num
    for epoch in range(epoch_num):
        # caption
        for i, (images, captions, lengths) in enumerate(data_loader):
            images = to_var(images)  # gradients must reach encoder.A, so do not mark inputs volatile
            captions = to_var(captions.long())

            # forward, backward and optimize
            decoder.zero_grad()
            encoder.zero_grad()
            features = encoder(images)
            outputs = decoder(captions, features, mode="factual")
            loss = criterion(outputs[:, 1:, :].contiguous(),
                             captions[:, 1:].contiguous(), lengths - 1)
            loss.backward()
            optimizer_cap.step()

            # print log
            if i % args.log_step_caption == 0:
                print("Epoch [%d/%d], CAP, Step [%d/%d], Loss: %.4f" %
                      (epoch + 1, epoch_num, i, total_cap_step,
                       loss.data.mean()))

        eval_outputs(outputs, vocab)

        # language
        for i, (captions, lengths) in enumerate(styled_data_loader):
            captions = to_var(captions.long())

            # forward, backward and optimize
            decoder.zero_grad()
            outputs = decoder(captions, mode='humorous')
            loss = criterion(outputs, captions[:, 1:].contiguous(),
                             lengths - 1)
            loss.backward()
            optimizer_lang.step()

            # print log
            if i % args.log_step_language == 0:
                print("Epoch [%d/%d], LANG, Step [%d/%d], Loss: %.4f" %
                      (epoch + 1, epoch_num, i, total_lang_step,
                       loss.data.mean()))

        # save models
        torch.save(decoder.state_dict(),
                   os.path.join(model_path, 'decoder-%d.pkl' % (epoch + 1, )))

        torch.save(encoder.state_dict(),
                   os.path.join(model_path, 'encoder-%d.pkl' % (epoch + 1, )))
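Note: Example #6's criterion is a masked_cross_entropy helper defined elsewhere in that codebase. The sketch below is an assumption about its behaviour, not the repository's actual implementation: it averages cross-entropy only over positions inside each sequence's true length.

import torch
import torch.nn.functional as F

def masked_cross_entropy(logits, targets, lengths):
    """Hypothetical reimplementation: per-token NLL averaged over real tokens only.

    logits:  (batch, max_len, vocab_size) unnormalized scores
    targets: (batch, max_len) padded token ids
    lengths: (batch,) true length of each sequence
    """
    batch, max_len, vocab_size = logits.size()
    # mask[b, t] is 1.0 while t < lengths[b], else 0.0
    mask = (torch.arange(max_len)[None, :] < lengths[:, None]).float()
    log_probs = F.log_softmax(logits.view(-1, vocab_size), dim=1)
    nll = -log_probs.gather(1, targets.view(-1, 1)).squeeze(1)  # per-token NLL
    return (nll * mask.view(-1)).sum() / mask.sum()

# Toy usage with random values
logits = torch.randn(2, 5, 10)
targets = torch.randint(0, 10, (2, 5))
lengths = torch.tensor([5, 3])
print(masked_cross_entropy(logits, targets, lengths))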