Example #1
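Trains an RNN word encoder to match frozen CNN image features under an MSE loss, saving both models after every epoch.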
def main():
    # Load vocabulary wrapper.
    with open(vocab_path, 'rb') as f:  # pickle files must be opened in binary mode
        vocab = pickle.load(f)

    encoder = EncoderCNN(4096, embed_dim)
    encoder.load_state_dict(torch.load('searchimage.pkl'))
    for p in encoder.parameters():
        p.requires_grad = False

    word_encoder = EncoderRNN(embed_dim, embed_dim, len(vocab), num_layers_rnn)
    word_encoder.load_state_dict(torch.load('searchword.pkl'))
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    encoder.to(device)
    word_encoder.to(device)
    # Loss and Optimizer
    criterion = nn.MSELoss()  # note: the loop below computes the MSE inline instead
    params = list(word_encoder.parameters())  # + list(encoder.linear.parameters())
    optimizer = torch.optim.Adam(params, lr=2e-6, weight_decay=0.001)

    #load data
    with open(image_data_file, 'rb') as f:
        image_data = pickle.load(f)
    image_features = si.loadmat(image_feature_file)

    img_features = image_features['fc7'][0]
    img_features = np.concatenate(img_features)

    print('here')
    iteration = 0

    for i in range(10):  # epoch
        use_caption = i % 5
        print('Epoch', i)
        losses = []
        for x, y in make_mini_batch(img_features,
                                    image_data,
                                    use_caption=use_caption):
            encoder.zero_grad()
            word_encoder.zero_grad()

            word_padding, lengths = make_word_padding(y, vocab)
            x = torch.from_numpy(x).to(device)
            word_index = torch.from_numpy(word_padding).to(device)

            features = encoder(x)
            outputs = word_encoder(word_index, lengths)
            loss = torch.mean((features - outputs).pow(2))
            loss.backward()
            optimizer.step()
            losses.append(loss.item())
            if iteration % 100 == 0:
                print('loss', sum(losses) / len(losses))
                losses = []

            iteration += 1

        torch.save(word_encoder.state_dict(), 'searchword.pkl')
        torch.save(encoder.state_dict(), 'searchimage.pkl')
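
# make_word_padding is not shown in this example; a minimal sketch of what it
# is assumed to do (turn a batch of tokenized captions into a fixed-width
# array of vocabulary indices plus the true lengths; assumes the vocab
# wrapper is callable, mapping a word to its index):
def make_word_padding(captions, vocab):
    lengths = [len(caption) for caption in captions]
    padded = np.zeros((len(captions), max(lengths)), dtype=np.int64)
    for i, caption in enumerate(captions):
        padded[i, :len(caption)] = [vocab(word) for word in caption]
    return padded, lengths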
Example #2
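Builds or resumes an attention-based captioning model from a checkpoint, decays the learning rate after epochs without improvement, and checkpoints on BLEU-4.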
def main(args):
    global best_bleu4, epochs_since_improvement, checkpoint, start_epoch, fine_tune_encoder, data_name, word_map

    # Load vocabulary wrapper
    with open(args.vocab_path, 'rb') as f:
        vocab = pickle.load(f)

    if args.checkpoint is None:
        decoder = AttnDecoderRNN(attention_dim=args.attention_dim,
                                 embed_dim=args.embed_dim,
                                 decoder_dim=args.decoder_dim,
                                 vocab_size=len(vocab),
                                 dropout=args.dropout)
        decoder_optimizer = torch.optim.Adam(
            params=filter(lambda p: p.requires_grad, decoder.parameters()),
            lr=args.decoder_lr)
        encoder = EncoderCNN()
        encoder.fine_tune(args.fine_tune_encoder)
        encoder_optimizer = torch.optim.Adam(
            params=filter(lambda p: p.requires_grad, encoder.parameters()),
            lr=args.encoder_lr) if args.fine_tune_encoder else None
    else:
        checkpoint = torch.load(args.checkpoint)
        start_epoch = checkpoint['epoch'] + 1
        epochs_since_improvement = checkpoint['epochs_since_improvement']
        best_bleu4 = checkpoint['bleu-4']
        decoder = checkpoint['decoder']
        decoder_optimizer = checkpoint['decoder_optimizer']
        encoder = checkpoint['encoder']
        encoder_optimizer = checkpoint['encoder_optimizer']
        if fine_tune_encoder and encoder_optimizer is None:
            encoder.fine_tune(fine_tune_encoder)
            encoder_optimizer = torch.optim.Adam(
                params=filter(lambda p: p.requires_grad, encoder.parameters()),
                lr=args.encoder_lr)
    decoder = decoder.to(device)
    encoder = encoder.to(device)

    criterion = nn.CrossEntropyLoss().to(device)

    # Image preprocessing, normalization for the pretrained resnet
    transform = transforms.Compose([
        transforms.RandomCrop(args.crop_size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
    ])

    # Build data loader
    train_loader = get_loader(args.image_dir,
                              args.caption_path,
                              vocab,
                              transform,
                              args.batch_size,
                              shuffle=True,
                              num_workers=args.num_workers)

    val_loader = get_loader(args.image_dir_val,
                            args.caption_path_val,
                            vocab,
                            transform,
                            args.batch_size,
                            shuffle=True,
                            num_workers=args.num_workers)

    # Use the module-level start_epoch / epochs_since_improvement so that
    # resuming from a checkpoint (set above) is honored.
    for epoch in range(start_epoch, args.epochs):
        if epochs_since_improvement == 20:
            break
        if epochs_since_improvement > 0 and epochs_since_improvement % 8 == 0:
            adjust_learning_rate(decoder_optimizer, 0.8)
            if args.fine_tune_encoder:
                adjust_learning_rate(encoder_optimizer, 0.8)

        train(train_loader=train_loader,
              encoder=encoder,
              decoder=decoder,
              criterion=criterion,
              encoder_optimizer=encoder_optimizer,
              decoder_optimizer=decoder_optimizer,
              epoch=epoch)

        recent_bleu4 = validate(val_loader=val_loader,
                                encoder=encoder,
                                decoder=decoder,
                                criterion=criterion)

        is_best = recent_bleu4 > best_bleu4
        best_bleu4 = max(recent_bleu4, best_bleu4)
        if not is_best:
            epochs_since_improvement += 1
            print("\nEpochs since last improvement: %d\n" %
                  (epochs_since_improvement,))
        else:
            epochs_since_improvement = 0

        save_checkpoint(args.data_name, epoch, epochs_since_improvement,
                        encoder, decoder, encoder_optimizer, decoder_optimizer,
                        recent_bleu4, is_best)
Example #3
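Fragment that wires up forward/backward LSTMs, an SGD optimizer with a step LR scheduler, and a file logger.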
if model == "lstm":
    f_rnn = LSTMModel(emb_size,
                      emb_size,
                      emb_size,
                      device,
                      bidirectional=False)
    b_rnn = LSTMModel(emb_size,
                      emb_size,
                      emb_size,
                      device,
                      bidirectional=False)
f_rnn = f_rnn.to(device)
b_rnn = b_rnn.to(device)

criterion = nn.CrossEntropyLoss()
params_to_train = (list(encoder_cnn.parameters()) + list(f_rnn.parameters()) +
                   list(b_rnn.parameters()))
optimizer = torch.optim.SGD(params_to_train, lr=2e-1, momentum=0.9)
scheduler = lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.5)

################################## Logger #####################################
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)

file_handler = logging.FileHandler("log_{}{}.log".format(
    __file__.split(".")[0], comment))
log_format = "%(asctime)s [%(levelname)-5.5s] %(message)s"
formatter = logging.Formatter(log_format)
file_handler.setFormatter(formatter)
file_handler.setLevel(logging.INFO)
logger.addHandler(file_handler)  # attach the handler, otherwise nothing is written
Example #4
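Continual-learning captioning setup supporting fine-tuning with vocabulary expansion, encoder/decoder freezing, critical freezing, learning-without-forgetting via pseudo-labels, and knowledge distillation, with early stopping throughout.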
def main(args):
    print(args)
    epochs_since_improvement = 0

    # Create model directory
    make_dir(args.model_path)

    # Image pre-processing, normalization for the pre-trained res-net
    transform = transforms.Compose([
        transforms.RandomCrop(args.crop_size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
    ])

    # Load vocabulary wrapper
    vocab_path = args.vocab_path
    with open(vocab_path, 'rb') as f:
        vocab = pickle.load(f)

    # Build data loader
    train_root = args.image_dir + cfg['train']['TRAIN_DIR']
    train_json = args.caption_path + cfg['train']['train_annotation']

    val_root = args.image_dir + cfg['train']['VAL_DIR']
    val_json = args.caption_path + cfg['train']['valid_annotation']

    # After patience epochs without improvement, break training
    patience = cfg['train']['patience']
    early_stopping = EarlyStopping(patience=patience, verbose=True)

    if args.check_point and os.path.isfile(args.check_point):
        checkpoint = torch.load(args.check_point)

    old_vocab_size = 0
    if args.fine_tuning:
        encoder = checkpoint['encoder']
        decoder = checkpoint['decoder']
        print("Fine tuning with check point is {}".format(args.check_point))

        vocab, old_vocab_size = append_vocab(args.check_point_vocab, vocab)

        with open(vocab_path, 'wb') as v:
            print("Dump {} entries to vocab {}".format(vocab.idx, vocab_path))
            pickle.dump(vocab, v)
        vocab_size = len(vocab)

        # Get the decoder's previous state (old_vocab_size rows; 4124 for S19)
        old_embed = decoder.embed.weight.data[:old_vocab_size]
        old_weight = decoder.linear.weight.data[:old_vocab_size]
        old_bias = decoder.linear.bias.data[:old_vocab_size]

        # Initialize new embedding and linear layers
        decoder.embed = nn.Embedding(vocab_size, args.embed_size)
        decoder.linear = nn.Linear(args.hidden_size, vocab_size)

        if args.freeze_cri or args.lwf or args.distill:
            # Copy the old neurons into the newly initialized layers (plain fine-tuning skips this step)
            print(
                "Assigning old neurons of embedding and linear layer to new decoder..."
            )

            # Initialize the first old_vocab_size rows from the old decoder
            decoder.embed.weight.data[:old_vocab_size, :] = old_embed
            decoder.linear.weight.data[:old_vocab_size] = old_weight
            decoder.linear.bias.data[:old_vocab_size] = old_bias

        encoder.to(device)
        decoder.to(device)

    else:
        # Normal training procedure
        encoder = EncoderCNN(args.embed_size).to(device)
        decoder = DecoderRNN(args.embed_size, args.hidden_size, len(vocab),
                             args.num_layers).to(device)

    if args.freeze_enc:
        args.task_name += '_freeze_enc'
    elif args.freeze_dec:
        args.task_name += '_freeze_dec'
    elif args.freeze_cri:
        args.task_name += '_freeze_cri'
    elif args.lwf:
        args.task_name += '_lwf'
    elif args.distill and args.kd1:
        args.task_name += '_kd1'
    elif args.distill and args.kd2:
        args.task_name += '_kd2'

    if args.task_type == 'seq':
        args.model_path = cfg['model']['model_path_format'].format(
            args.task_type, args.task_name + '_seq', 'models')
        args.cpkt_path = cfg['model']['model_path_format'].format(
            args.task_type, args.task_name + '_seq', 'best')
    else:
        args.model_path = cfg['model']['model_path_format'].format(
            args.task_type, args.task_name, 'models')
        args.cpkt_path = cfg['model']['model_path_format'].format(
            args.task_type, args.task_name, 'best')

    # Create model directory
    make_dir(args.model_path)

    # Pseudo-labeling option
    if args.lwf:
        print("Running pseudo-labeling option...")
        # Infer pseudo-labels using previous model
        pseudo_labels = infer_caption(img_path=train_root,
                                      json_path=train_json,
                                      model=args.check_point,
                                      vocab_path=vocab_path,
                                      prediction_path=None,
                                      id2class_path=None)

        # Freeze the decoder's LSTM and the whole encoder until joint optimization later
        for param in decoder.lstm.parameters():
            param.requires_grad_(False)
        for param in encoder.parameters():
            param.requires_grad_(False)

        data = append_json(pseudo_labels, train_json)

        # Create a new json file from the train_json
        train_json = args.caption_path + 'captions_train_lwf.json'
        with open(train_json, 'w') as file:
            json.dump(data, file)

    # Knowledge distillation option
    if args.distill:
        print("Running knowledge distillation...")
        # Teacher
        teacher_cnn = checkpoint['encoder']
        teacher_lstm = checkpoint['decoder']
        teacher_cnn.train()  # note: eval() is the more typical choice for a distillation teacher
        teacher_lstm.train()

        # Initialize a totally new captioning model - Student
        encoder = EncoderCNN(args.embed_size).to(device)
        decoder = DecoderRNN(args.embed_size, args.hidden_size, len(vocab),
                             args.num_layers).to(device)

        # Student
        student_cnn = encoder
        student_lstm = decoder

        # Move teacher to cuda
        teacher_cnn.to(device)
        teacher_lstm.to(device)

        # Loss between GT caption and the prediction
        criterion_lstm = nn.CrossEntropyLoss()
        # Loss between predictions of teacher and student
        criterion_distill = nn.MSELoss()

        # Params of student
        params_st = list(student_lstm.parameters()) + list(
            student_cnn.parameters())

        optimizer_lstm = torch.optim.Adam(params_st, lr=1e-4)
        optimizer_distill = torch.optim.Adam(student_cnn.parameters(), lr=1e-5)

    if args.freeze_enc:
        print("Freeze encoder technique!")
        for param in encoder.parameters():
            param.requires_grad_(False)

    if args.freeze_dec:
        print("Freeze decoder technique!")
        for param in decoder.lstm.parameters():
            param.requires_grad_(False)

    if args.freeze_cri:
        print("Critical Freezing technique!")
        for layer_idx, child in enumerate(encoder.resnet.children()):
            if layer_idx in (0, 4):  # blocks 1 & 2
                for param in child.parameters():
                    param.requires_grad = False

    train_loader = get_loader(root=train_root,
                              json=train_json,
                              vocab=vocab,
                              transform=transform,
                              batch_size=args.batch_size,
                              shuffle=True,
                              num_workers=args.num_workers)

    val_loader = get_loader(root=val_root,
                            json=val_json,
                            vocab=vocab,
                            transform=transform,
                            batch_size=args.batch_size,
                            shuffle=True,
                            num_workers=args.num_workers)

    # Loss and optimizer
    criterion = nn.CrossEntropyLoss()
    params = list(decoder.parameters()) + list(encoder.parameters())
    optimizer = torch.optim.Adam(params, lr=args.learning_rate)

    # These variables are for plotting
    avg_train_losses = []
    avg_val_losses = []

    for epoch in range(args.num_epochs):

        if args.distill:
            print("Training with distillation option!")
            train_step, train_loss_step = train_distill(
                epoch,
                train_loader=train_loader,
                student_cnn=student_cnn,
                student_lstm=student_lstm,
                teacher_cnn=teacher_cnn,
                teacher_lstm=teacher_lstm,
                criterion_lstm=criterion_lstm,
                criterion_distill=criterion_distill,
                optimizer_lstm=optimizer_lstm,
                optimizer_distill=optimizer_distill)
            # Validate after an epoch
            recent_val_loss, val_step, val_loss_step = validate(
                epoch,
                val_loader=val_loader,
                encoder=student_cnn,
                decoder=student_lstm,
                criterion=criterion)
        else:

            train_step, train_loss_step = train(epoch,
                                                train_loader=train_loader,
                                                encoder=encoder,
                                                decoder=decoder,
                                                criterion=criterion,
                                                optimizer=optimizer,
                                                first_training=True,
                                                old_vocab_size=old_vocab_size)
            # Validate after an epoch
            recent_val_loss, val_step, val_loss_step = validate(
                epoch,
                val_loader=val_loader,
                encoder=encoder,
                decoder=decoder,
                criterion=criterion)
        train_loss = np.average(train_loss_step)
        val_loss = np.average(val_loss_step)

        avg_train_losses.append(train_loss)
        avg_val_losses.append(val_loss)

        # Save checkpoint (the single optimizer fills both the encoder- and
        # decoder-optimizer slots expected by early_stopping)
        make_dir(args.cpkt_path)
        early_stopping(args.cpkt_path, cfg['train']['data_name'], epoch,
                       epochs_since_improvement, encoder, decoder, optimizer,
                       optimizer, val_loss)

        if early_stopping.early_stop:
            print("Early Stopping!")
            break

    if args.lwf:
        # Make all trainable
        for param in decoder.linear.parameters():
            param.requires_grad_(True)
        for param in decoder.embed.parameters():
            param.requires_grad_(True)
        for param in decoder.lstm.parameters():
            param.requires_grad_(True)
        for param in encoder.parameters():
            param.requires_grad_(True)

        print("Unfreezing parameters ...")

        print("Critical Freezing technique!")
        for layer_idx, child in enumerate(encoder.resnet.children()):
            if layer_idx in (0, 4):  # blocks 1 & 2
                for param in child.parameters():
                    param.requires_grad = False

        # Joint optimization starts

        early_stopping = EarlyStopping(patience=patience, verbose=True)
        for epoch in range(args.num_epochs):
            train_step, train_loss_step = train(epoch,
                                                train_loader=train_loader,
                                                encoder=encoder,
                                                decoder=decoder,
                                                criterion=criterion,
                                                optimizer=optimizer,
                                                first_training=False,
                                                old_vocab_size=old_vocab_size)
            # Validate after an epoch
            recent_val_loss, val_step, val_loss_step = validate(
                epoch,
                val_loader=val_loader,
                encoder=encoder,
                decoder=decoder,
                criterion=criterion)

            train_loss = np.average(train_loss_step)
            val_loss = np.average(val_loss_step)

            avg_train_losses.append(train_loss)
            avg_val_losses.append(val_loss)

            # Save checkpoint
            make_dir(args.cpkt_path)
            early_stopping(args.cpkt_path, cfg['train']['data_name'], epoch,
                           epochs_since_improvement, encoder, decoder,
                           optimizer, optimizer, val_loss)

            if early_stopping.early_stop:
                print("Early Stopping!")
                break
Example #5
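Trains a CNN encoder and RNN decoder on generated image/transform-sequence pairs, padding the token sequences by hand and plotting loss and perplexity.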
def main(args):
    if not os.path.exists(args.model_path):
        os.makedirs(args.model_path)

    # Build the models, can use a feedforward/convolutional encoder and an RNN decoder
    encoder = EncoderCNN(args.embed_size).to(device)  # can be sequential or convolutional
    decoder = DecoderRNN(args.embed_size, args.hidden_size, len(vocab),
                         args.num_layers).to(device)
    # Loss and optimizer
    criterion1 = nn.CrossEntropyLoss()
    criterion2 = nn.NLLLoss()
    softmax = nn.LogSoftmax(dim=1)
    params = list(decoder.parameters()) + list(encoder.parameters())
    optimizer = torch.optim.Adam(params, lr=args.learning_rate)
    total_training_steps = args.num_iters
    losses = []
    perplexity = []
    for epoch in range(args.num_epochs):
        for i in range(total_training_steps):
            prog_data = generate_training_data(args.batch_size)

            images = [item[0] for item in prog_data]
            transforms = [item[1] for item in prog_data]

            # Add start and end tokens to each sequence
            for ele in transforms:
                ele.insert(0, '<start>')
                ele.append('<end>')

            lengths = [len(trans) for trans in transforms]

            maximum_len = max(lengths)
            for trans in transforms:
                if len(trans) != maximum_len:
                    trans.extend(['pad'] * (maximum_len - len(trans)))

            padded_lengths = [len(trans) for trans in transforms]
            transforms = [[word_to_int(word) for word in transform]
                          for transform in transforms]
            transforms = torch.tensor(transforms, device=device)
            images = torch.tensor(images, device=device)
            images = images.unsqueeze(1)  # add a channel dimension (needed when training with EncoderCNN)
            lengths = torch.tensor(lengths, device=device)
            padded_lengths = torch.tensor(padded_lengths, device=device)
            targets = pack_padded_sequence(transforms,
                                           padded_lengths,
                                           batch_first=True)[0]

            features = encoder(images)
            outputs = decoder(features, transforms, padded_lengths)
            #print(outputs)

            loss = criterion1(outputs, targets)
            losses.append(loss.item())
            perplexity.append(np.exp(loss.item()))

            decoder.zero_grad()
            encoder.zero_grad()
            loss.backward()
            optimizer.step()

            # Print log info
            if i % args.log_step == 0:
                print(
                    'Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}, Perplexity: {:5.4f}'
                    .format(epoch, args.num_epochs, i, total_training_steps,
                            loss.item(), np.exp(loss.item())))

            # Save the model checkpoints
            if (i + 1) % args.save_step == 0:
                torch.save(
                    decoder.state_dict(),
                    os.path.join(args.model_path,
                                 'decoder-{}-{}.ckpt'.format(epoch + 1,
                                                             i + 1)))
                torch.save(
                    encoder.state_dict(),
                    os.path.join(args.model_path,
                                 'encoder-{}-{}.ckpt'.format(epoch + 1,
                                                             i + 1)))

    y = losses
    z = perplexity
    x = np.arange(len(losses))
    plt.plot(x, y, label='Cross Entropy Loss')
    plt.plot(x, z, label='Perplexity')
    plt.xlabel('Iterations')
    plt.ylabel('Cross Entropy Loss and Perplexity')
    plt.title("Cross Entropy Loss and Model Perplexity During Training")
    plt.legend()
    plt.savefig('plots/plots_cnn/cnn4_gpu', dpi=100)
Example #6
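Standard encoder-decoder captioning loop with a per-epoch test pass over a held-out loader.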
def main(args):
    # Create model directory
    if not os.path.exists(args.model_path):
        os.makedirs(args.model_path)

    # Image preprocessing, normalization for the pretrained resnet
    # (disabled in this example):
    '''
    transform = transforms.Compose([
        transforms.RandomCrop(args.crop_size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406),
                             (0.229, 0.224, 0.225))])
    '''
    # Load vocabulary wrapper
    with open(args.vocab_path, 'rb') as f:
        vocab = pickle.load(f)

    set_vocab(vocab)
    # Build data loader
    data_loader, test_loader = get_loader(args.dataset_dir,
                                          args.batch_size,
                                          shuffle=True,
                                          num_workers=args.num_workers)

    # Build the models
    encoder = EncoderCNN(args.embed_size).to(device)
    decoder = DecoderRNN(args.embed_size, args.hidden_size, len(vocab),
                         args.num_layers).to(device)

    # Loss and optimizer
    criterion = nn.CrossEntropyLoss()
    params = list(decoder.parameters()) + list(encoder.parameters())
    optimizer = torch.optim.Adam(params, lr=args.learning_rate)

    # Train the models
    total_step1 = len(data_loader)
    total_step2 = len(test_loader)
    for epoch in range(args.num_epochs):
        print("\nTrain\n")
        for i, (images, captions, lengths) in enumerate(data_loader):
            if isinstance(images, int):  # skip batches the loader flags with an int sentinel
                continue
            # Set mini-batch dataset
            images = images.to(device)
            captions = captions.to(device)
            targets = pack_padded_sequence(captions, lengths,
                                           batch_first=True)[0]

            # decoder.zero_grad()
            # encoder.zero_grad()
            optimizer.zero_grad()

            features = encoder(images)
            outputs = decoder(features, captions, lengths)

            loss = criterion(outputs, targets)

            loss.backward()
            optimizer.step()

            print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'.format(
                epoch, args.num_epochs, i, total_step1, loss.item()),
                  end="\r")
            # Print log info
            if i % args.log_step == 0:
                print(
                    'Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}, Perplexity: {:5.4f}'
                    .format(epoch, args.num_epochs, i, total_step1,
                            loss.item(), np.exp(loss.item())))

            # Save the model checkpoints
            if (i + 1) % args.save_step == 0:
                torch.save(
                    decoder.state_dict(),
                    os.path.join(args.model_path,
                                 'decoder-{}-{}.ckpt'.format(epoch + 1,
                                                             i + 1)))
                torch.save(
                    encoder.state_dict(),
                    os.path.join(args.model_path,
                                 'encoder-{}-{}.ckpt'.format(epoch + 1,
                                                             i + 1)))

        print("\nTest\n")

        with torch.no_grad():  # note: encoder/decoder are left in train mode (no eval() call)
            for i, (images, captions, lengths) in enumerate(test_loader):
                if isinstance(images, int):
                    continue

                images = images.to(device)
                captions = captions.to(device)
                targets = pack_padded_sequence(captions,
                                               lengths,
                                               batch_first=True)[0]

                features = encoder(images)
                outputs = decoder(features, captions, lengths)
                loss = criterion(outputs, targets)

                print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'.format(
                    epoch, args.num_epochs, i, total_step2, loss.item()),
                      end="\r")
                # Print log info
                if i % args.log_step == 0:
                    print(
                        'Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}, Perplexity: {:5.4f}'
                        .format(epoch, args.num_epochs, i, total_step2,
                                loss.item(), np.exp(loss.item())))
Example #7
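Fragment of a CRNN train/validation script: validation loss and accuracy reporting plus the Adam optimizer setup.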
            # (the snippet begins inside the validation loop)
            loss = F.cross_entropy(output, y)
            losses.append(loss.item())  # store floats, not graph-holding tensors

            y_pred = output.max(1, keepdim=True)[1]  # index of the max log-probability
            correct += y_pred.eq(y.view_as(y_pred)).sum().item()

    # show information
    acc = 100. * (correct / N_count)
    average_loss = sum(losses) / len(validation_generator)
    print('Validation set ({:d} samples): Average loss: {:.4f}\tAcc: {:.4f}%'.format(N_count, average_loss, acc))
    return average_loss, acc


# optimizer
crnn_params = list(encoder_cnn.parameters()) + list(decoder_rnn.parameters())
optimizer = torch.optim.Adam(crnn_params, lr=learning_rate)

# start training
for epoch in range(epochs):
    # train, test model
    train_loss, train_acc = train([encoder_cnn, decoder_rnn], device, training_generator, optimizer, epoch, log_interval)
    val_loss, val_acc = validation([encoder_cnn, decoder_rnn], device, optimizer, validation_generator)

# def validation(model, device, optimizer, test_loader):
#     # set model as testing mode
#     cnn_encoder, rnn_decoder = model
#     cnn_encoder.eval()
#     rnn_decoder.eval()
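
# A minimal completion of the commented-out stub above, consistent with the
# validation fragment at the top of this example (the (X, y) batch layout and
# the encoder -> decoder call are assumptions):
def validation(model, device, optimizer, test_loader):
    # set models to evaluation mode
    cnn_encoder, rnn_decoder = model
    cnn_encoder.eval()
    rnn_decoder.eval()

    losses = []
    correct = 0
    N_count = 0
    with torch.no_grad():
        for X, y in test_loader:
            X, y = X.to(device), y.to(device)
            N_count += X.size(0)

            output = rnn_decoder(cnn_encoder(X))
            loss = F.cross_entropy(output, y)
            losses.append(loss.item())

            y_pred = output.max(1, keepdim=True)[1]
            correct += y_pred.eq(y.view_as(y_pred)).sum().item()

    acc = 100. * (correct / N_count)
    average_loss = sum(losses) / len(test_loader)
    print('Validation set ({:d} samples): Average loss: {:.4f}\tAcc: {:.4f}%'.format(
        N_count, average_loss, acc))
    return average_loss, acc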
Example #8
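Variant of Example #3 that also trains word-embedding and image-embedding layers.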
encoder_cnn = encoder_cnn.to(device)

if model == "lstm":
    f_rnn = LSTMModel(emb_size, emb_size, emb_size, device, bidirectional=False)
    b_rnn = LSTMModel(emb_size, emb_size, emb_size, device, bidirectional=False)
f_rnn = f_rnn.to(device)
b_rnn = b_rnn.to(device)

embedding = nn.Embedding(len(train_dataset.vocabulary), emb_size)
embedding = embedding.to(device)
image_embedding = nn.Linear(2048, emb_size)
image_embedding = image_embedding.to(device)

criterion = nn.CrossEntropyLoss()
params_to_train = (
    list(encoder_cnn.parameters())
    + list(f_rnn.parameters())
    + list(b_rnn.parameters())
    + list(embedding.parameters())
    + list(image_embedding.parameters())
)
optimizer = torch.optim.SGD(params_to_train, lr=2e-1, momentum=0.9)
scheduler = lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.5)

################################## Logger #####################################
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)

file_handler = logging.FileHandler(
    "log_{}{}.log".format(__file__.split(".")[0], comment)
)
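Example #9
Sets up file logging, COCO data loaders with augmentations, FastText embedding weights, and trains an attention captioner with cosine annealing warm restarts and delayed encoder unfreezing.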
def main(args):
    # Setup Logger
    log_file_path = os.path.join(os.getcwd(), 'logs')
    if not os.path.exists(log_file_path):
        os.makedirs(log_file_path)

    start_time = time.strftime('%Y-%m-%d_%H%M%S')  # filesystem-safe timestamp
    logging.basicConfig(filename=os.path.join(
        log_file_path, "training_log_" + start_time + ".log"),
                        format='%(asctime)s - %(message)s',
                        level=logging.INFO)

    # Create model directory
    if not os.path.exists(args.model_path):
        os.makedirs(args.model_path)

    data_augmentations = get_augmentations(args.crop_size)

    # Image preprocessing, normalization for the pretrained resnet
    transform = transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize((args.crop_size, args.crop_size)),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
    ])

    train_coco = COCO(args.train_caption_path)
    val_coco = COCO(args.val_caption_path)
    # Load vocabulary wrapper
    with open(args.vocab_path, 'rb') as f:
        vocab = pickle.load(f)

    # Build data loader
    train_dataloader = get_loader(args.train_dir,
                                  train_coco,
                                  vocab,
                                  transform,
                                  data_augmentations,
                                  args.batch_size,
                                  shuffle=True,
                                  num_workers=args.num_workers)

    validation_dataloader = get_loader(args.val_dir,
                                       val_coco,
                                       vocab,
                                       transform,
                                       None,
                                       args.batch_size,
                                       shuffle=False,
                                       num_workers=args.num_workers)

    # Load word embeddings
    fasttext_wv = load_fasttext(args.train_caption_path)
    print("Loaded FastText word embeddings")
    embed_dim = fasttext_wv.vectors_vocab.shape[1]
    embedding_weights = np.zeros((len(vocab), embed_dim))
    for word, idx in vocab.word2idx.items():  # map by stored index, not enumeration order
        embedding_weights[idx] = fasttext_wv[word]

    # Build the models
    encoder = EncoderCNN().to(device)
    decoder = DecoderRNN(args.lstm1_size,
                         args.lstm2_size,
                         args.att_size,
                         vocab,
                         args.embed_size,
                         embedding_weights,
                         feature_size=args.feature_size).to(device)

    # Metrics
    train_metrics = [Accuracy(), RougeBleuScore(train_coco, vocab)]

    val_metrics = [Accuracy(), RougeBleuScore(val_coco, vocab)]

    # Loss and optimizer
    loss = nn.CrossEntropyLoss(ignore_index=0)  # assumes index 0 is the padding token
    # The encoder starts frozen; it is unfrozen at encoder_unfreeze_epoch below
    for p in encoder.parameters():
        p.requires_grad = False
    optimizer = torch.optim.Adam(
        [{'params': encoder.parameters(), 'lr': 0.5 * args.learning_rate},
         {'params': decoder.parameters()}],
        lr=args.learning_rate)
    # optimizer = torch.optim.Adam(encoder.parameters(), lr=args.learning_rate)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
        optimizer, T_0=10, T_mult=1, eta_min=0)
    encoder_unfreeze_epoch = 2
    train_model(
        train_dataloader=train_dataloader,
        validation_dataloader=validation_dataloader,
        model=[encoder, decoder],
        loss=loss,
        train_metrics=train_metrics,
        val_metrics=val_metrics,
        optimizer=optimizer,
        scheduler=scheduler,
        batch_size=args.batch_size,
        num_epochs=args.num_epochs,
        encoder_unfreeze_epoch=encoder_unfreeze_epoch,
        device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu"),
        logger=logging,
        verbose=True,
        model_save_path=os.path.join(os.getcwd(), 'model'),
        plots_save_path=os.path.join(os.getcwd(), 'plots'))
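Example #10
Notebook-style fragment: builds the data loader and vocabulary, then trains by sampling batches of captions with a common length.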
# The snippet begins mid-call; a plausible reconstruction of the missing
# opening, assuming the usual get_loader signature (transform_train,
# batch_size and vocab_threshold are assumed names):
data_loader = get_loader(transform=transform_train,
                         mode='train',
                         batch_size=batch_size,
                         vocab_threshold=vocab_threshold,
                         vocab_from_file=vocab_from_file)

print('Total number of tokens in vocabulary:', len(data_loader.dataset.vocab))

vocab_size = len(data_loader.dataset.vocab)

encoder = EncoderCNN(embed_size)
decoder = DecoderRNN(embed_size, hidden_size, vocab_size)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
encoder.to(device)
decoder.to(device)

criterion = nn.CrossEntropyLoss().cuda() if torch.cuda.is_available(
) else nn.CrossEntropyLoss()
params = list(decoder.parameters()) + list(encoder.parameters())
optimizer = torch.optim.Adam(params, lr=.001, betas=(.9, .999), eps=1e-08)

total_step = math.ceil(
    len(data_loader.dataset.caption_lengths) /
    data_loader.batch_sampler.batch_size)

import torch.utils.data as data
import numpy as np

f = open(log_file, 'w')

for epoch in range(1, num_epochs + 1):
    for i_step in range(1, total_step + 1):
        # Randomly sample a caption length, and sample indices with that length.
        indices = data_loader.dataset.get_train_indices()
Example #11
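Captioning training loop that iterates over five captions per image and saves encoder/decoder checkpoints each epoch.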
def train_main(args):
    if not os.path.exists(args.base_dir + "model/"):
        os.mkdir(args.base_dir + "model/")

    transform = transforms.Compose([
        transforms.RandomCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
    ])

    with open(base_dir + "vocab.pkl", "rb") as f:
        vocab = pickle.load(f)
    vocab_size = len(vocab)

    # Build the data loader
    loader = get_loader(args.base_dir,
                        args.part,
                        vocab,
                        transform,
                        args.batch_size,
                        shuffle=True,
                        num_workers=args.num_workers)
    # Randomly display one image with its caption
    # plotting(loader, args)

    # Instantiate the encoder and decoder
    encoder = EncoderCNN(args.embed_size)
    decoder = DecoderRNN(args.embed_size,
                         vocab_size,
                         args.hidden_size,
                         args.num_layers,
                         max_seq=20)

    num_captions = 5
    num_examples = len(loader)
    loss_func = nn.CrossEntropyLoss()
    # encoder.parameters() already includes the BatchNorm parameters, so
    # listing encoder.bn separately would duplicate them in the optimizer
    params = list(decoder.parameters()) + list(encoder.parameters())
    optimizer = Adam(params, lr=0.001)

    for epoch in range(args.num_epoch):
        for i, (images, captions, lengths) in enumerate(loader):
            for j in range(num_captions):
                caption = captions[:, j, :]
                length = torch.tensor(lengths)[:, j].long()
                # pack_padded_sequence needs lengths in descending order, and
                # the captions/images must be reordered to match
                length, sort_idx = torch.sort(length, dim=0, descending=True)
                caption = caption[sort_idx]
                targets = pack_padded_sequence(caption,
                                               length,
                                               batch_first=True)[0]

                # Forward pass, backward pass and optimization
                features = encoder(images[sort_idx])
                outputs = decoder(features, caption, length)
                loss = loss_func(outputs, targets)

                decoder.zero_grad()
                encoder.zero_grad()
                loss.backward()

                optimizer.step()
            if i % 10 == 0:
                print(
                    "Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}, Perplexity: {:5.4f}"
                    .format(epoch + 1, args.num_epoch, i, num_examples,
                            loss.item(), np.exp(loss.item())))
        torch.save(
            decoder.state_dict(),
            os.path.join(args.model_path,
                         'decoder-epoch-{}.ckpt'.format(epoch + 1)))
        torch.save(
            encoder.state_dict(),
            os.path.join(args.model_path,
                         'encoder-epoch-{}.ckpt'.format(epoch + 1)))