Beispiel #1
0
def _story_batch_loss(encoder, decoder, criterion, image_stories, targets_set,
                      lengths_set):
    """Return the summed caption loss for one batch of stories.

    The image stories are stacked and encoded in one call; each resulting
    per-story feature is then decoded against its own captions and the
    criterion is accumulated over every (output, caption) pair, with each
    caption truncated to its recorded length.

    The result is un-normalised (a plain sum of per-caption losses); callers
    are responsible for any averaging.
    """
    images = to_var(torch.stack(image_stories))
    features, _ = encoder(images)

    loss = 0
    for feature, targets, lengths in zip(features, targets_set, lengths_set):
        captions = to_var(targets)
        outputs = decoder(feature, captions, lengths)
        for output, caption, length in zip(outputs, captions, lengths):
            # Only the first `length` tokens of the caption are scored.
            loss += criterion(output, caption[0:length])
    return loss


def main(args):
    """Train the story encoder/decoder and validate after every epoch.

    Reads the following attributes from ``args``: ``model_path``,
    ``image_size``, ``vocab_path``, ``train_image_dir``, ``train_sis_path``,
    ``val_image_dir``, ``val_sis_path``, ``batch_size``, ``num_workers``,
    ``img_feature_size``, ``hidden_size``, ``num_layers``, ``embed_size``,
    ``pretrained_epoch``, ``learning_rate``, ``weight_decay``, ``num_epochs``
    and ``log_step``.

    Side effects: creates ``args.model_path`` if missing, writes
    ``encoder-<epoch>.pkl`` / ``decoder-<epoch>.pkl`` there after every
    training epoch, and prints progress to stdout.

    When ``args.pretrained_epoch`` is positive, the matching checkpoints are
    loaded from ``args.model_path`` and training resumes after that epoch.
    Training stops early once the validation loss has been worse than the
    best value seen so far for 10 consecutive epochs.
    """
    # Create model directory
    if not os.path.exists(args.model_path):
        os.makedirs(args.model_path)

    # Image preprocessing (same normalisation constants for both splits).
    normalize = transforms.Normalize((0.485, 0.456, 0.406),
                                     (0.229, 0.224, 0.225))
    train_transform = transforms.Compose([
        transforms.RandomCrop(args.image_size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    val_transform = transforms.Compose([
        transforms.Resize(args.image_size, interpolation=Image.LANCZOS),
        transforms.ToTensor(),
        normalize,
    ])

    # Load vocabulary wrapper.
    # NOTE: pickle.load can execute arbitrary code; only use trusted files.
    with open(args.vocab_path, 'rb') as f:
        vocab = pickle.load(f)

    # Build data loaders
    train_data_loader = get_loader(args.train_image_dir,
                                   args.train_sis_path,
                                   vocab,
                                   train_transform,
                                   args.batch_size,
                                   shuffle=True,
                                   num_workers=args.num_workers)
    val_data_loader = get_loader(args.val_image_dir,
                                 args.val_sis_path,
                                 vocab,
                                 val_transform,
                                 args.batch_size,
                                 shuffle=False,
                                 num_workers=args.num_workers)

    encoder = EncoderStory(args.img_feature_size, args.hidden_size,
                           args.num_layers)
    decoder = DecoderStory(args.embed_size, args.hidden_size, vocab)

    pretrained_epoch = 0
    if args.pretrained_epoch > 0:
        pretrained_epoch = args.pretrained_epoch
        # Resume from the same directory and file naming the checkpoints are
        # saved with below. map_location='cpu' lets a checkpoint written on a
        # GPU machine be restored on a CPU-only one; the models are moved to
        # the GPU afterwards when one is available.
        encoder.load_state_dict(
            torch.load(os.path.join(args.model_path,
                                    'encoder-%d.pkl' % pretrained_epoch),
                       map_location='cpu'))
        decoder.load_state_dict(
            torch.load(os.path.join(args.model_path,
                                    'decoder-%d.pkl' % pretrained_epoch),
                       map_location='cpu'))

    if torch.cuda.is_available():
        encoder.cuda()
        decoder.cuda()
        print("Cuda is enabled...")

    criterion = nn.CrossEntropyLoss()
    params = decoder.get_params() + encoder.get_params()
    optimizer = torch.optim.Adam(params,
                                 lr=args.learning_rate,
                                 weight_decay=args.weight_decay)

    total_train_step = len(train_data_loader)
    total_val_step = len(val_data_loader)

    # NOTE(review): 5 is presumably the number of captions per story; the
    # summed batch loss is divided by batch_size * 5 to get a per-caption
    # average. Confirm against the dataset.
    captions_per_batch = args.batch_size * 5

    min_avg_loss = float("inf")
    overfit_warn = 0
    patience = 10  # consecutive non-improving epochs tolerated before stopping

    # Epochs already covered by the loaded checkpoint are skipped.
    for epoch in range(pretrained_epoch, args.num_epochs):

        # ---- Training ----
        encoder.train()
        decoder.train()
        avg_loss = 0.0
        for bi, (image_stories, targets_set, lengths_set, _photo_sequence_set,
                 _album_ids_set) in enumerate(train_data_loader):
            decoder.zero_grad()
            encoder.zero_grad()

            loss = _story_batch_loss(encoder, decoder, criterion,
                                     image_stories, targets_set, lengths_set)

            # Accumulate the raw sum; it is normalised once per epoch below.
            avg_loss += loss.item()
            loss /= captions_per_batch
            loss.backward()
            optimizer.step()

            # Print log info
            if bi % args.log_step == 0:
                print(
                    'Epoch [%d/%d], Train Step [%d/%d], Loss: %.4f, Perplexity: %5.4f'
                    % (epoch + 1, args.num_epochs, bi, total_train_step,
                       loss.item(), np.exp(loss.item())))

        avg_loss /= (captions_per_batch * total_train_step)
        print(
            'Epoch [%d/%d], Average Train Loss: %.4f, Average Train Perplexity: %5.4f'
            % (epoch + 1, args.num_epochs, avg_loss, np.exp(avg_loss)))

        # Save the models
        torch.save(
            decoder.state_dict(),
            os.path.join(args.model_path, 'decoder-%d.pkl' % (epoch + 1)))
        torch.save(
            encoder.state_dict(),
            os.path.join(args.model_path, 'encoder-%d.pkl' % (epoch + 1)))

        # ---- Validation ----
        encoder.eval()
        decoder.eval()
        avg_loss = 0.0
        # No gradients are needed here; disabling autograd avoids building
        # (and holding in memory) a graph for every validation batch.
        with torch.no_grad():
            for bi, (image_stories, targets_set, lengths_set,
                     _photo_sequence_set,
                     _album_ids_set) in enumerate(val_data_loader):
                loss = _story_batch_loss(encoder, decoder, criterion,
                                         image_stories, targets_set,
                                         lengths_set)

                avg_loss += loss.item()
                loss /= captions_per_batch

                # Print log info
                if bi % args.log_step == 0:
                    print(
                        'Epoch [%d/%d], Val Step [%d/%d], Loss: %.4f, Perplexity: %5.4f'
                        % (epoch + 1, args.num_epochs, bi, total_val_step,
                           loss.item(), np.exp(loss.item())))

        avg_loss /= (captions_per_batch * total_val_step)
        print(
            'Epoch [%d/%d], Average Val Loss: %.4f, Average Val Perplexity: %5.4f'
            % (epoch + 1, args.num_epochs, avg_loss, np.exp(avg_loss)))

        # Termination condition: count consecutive epochs whose validation
        # loss is worse than the best seen so far; any epoch that matches or
        # beats the best resets the counter.
        overfit_warn = overfit_warn + 1 if (min_avg_loss < avg_loss) else 0
        min_avg_loss = min(min_avg_loss, avg_loss)

        if overfit_warn >= patience:
            break
Beispiel #2
0
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
])

# Load the vocabulary wrapper used to build the decoder and the data loader.
# NOTE: pickle.load can execute arbitrary code; only use trusted files.
with open(args.vocab_path, 'rb') as f:
    vocab = pickle.load(f)

# Evaluation data is read in a fixed order (no shuffling).
data_loader = get_loader(image_dir,
                         sis_path,
                         vocab,
                         transform,
                         args.batch_size,
                         shuffle=False,
                         num_workers=args.num_workers)

encoder = EncoderStory(args.img_feature_size, args.hidden_size,
                       args.num_layers)
decoder = DecoderStory(args.embed_size, args.hidden_size, vocab)

# Restore trained weights. map_location='cpu' lets checkpoints that were
# saved on a GPU machine load on a CPU-only host; the models are moved to the
# GPU below when one is available.
encoder.load_state_dict(torch.load(encoder_path, map_location='cpu'))
decoder.load_state_dict(torch.load(decoder_path, map_location='cpu'))

# Switch both models to evaluation mode.
encoder.eval()
decoder.eval()

if torch.cuda.is_available():
    encoder.cuda()
    decoder.cuda()
    print("Cuda is enabled...")

criterion = nn.CrossEntropyLoss()