Example #1
0
def load_dataset(args):
    """Build the caption datasets and wrap each one in a DataLoader.

    Args:
        args: parsed arguments; must provide ``img_dir`` and ``batch_size``.

    Returns:
        Tuple of (train_loader, val_img_loader, test_img_loader,
        val_caption_loader, test_caption_loader, vocab_table).
    """
    # Standard ImageNet preprocessing: resize, random 224-crop, normalize.
    transform = transforms.Compose([
        transforms.Resize(256),
        transforms.RandomCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])

    (train_caption, validation_caption, test_caption, train_id,
     validation_id, val_img_id, test_img_id, img_idx,
     vocab_table) = read_text_data(args)

    # Datasets.
    train_set = TrainDataset(args.img_dir, vocab_table, img_idx,
                             train_caption, train_id, transform)
    val_img_set = ImgDataset(args.img_dir,
                             val_img_id,
                             is_val=True,
                             transform=transform,
                             img_idx=img_idx)
    test_img_set = ImgDataset(args.img_dir,
                              test_img_id,
                              transform=transform)
    val_cap_set = CaptionDataset(vocab_table,
                                 validation_caption,
                                 is_val=True,
                                 label=validation_id)
    test_cap_set = CaptionDataset(vocab_table, test_caption)

    # Loaders — all share the same batch size and worker count.
    def _make_loader(dataset, **kwargs):
        return DataLoader(dataset,
                          batch_size=args.batch_size,
                          num_workers=4,
                          **kwargs)

    train_loader = _make_loader(train_set,
                                shuffle=True,
                                collate_fn=train_collate_fn)
    val_img_loader = _make_loader(val_img_set)
    test_img_loader = _make_loader(test_img_set)
    val_caption_loader = _make_loader(val_cap_set,
                                      collate_fn=test_collate_fn)
    test_caption_loader = _make_loader(test_cap_set,
                                       collate_fn=test_collate_fn)

    return (train_loader, val_img_loader, test_img_loader,
            val_caption_loader, test_caption_loader, vocab_table)
    def test_dataset(self):
        """Smoke-test CaptionDataset: read three samples and display them.

        Passing means the dataset loads without raising; the final trivial
        assertion only marks the test as executed.
        """
        flickr_dataset = CaptionDataset(
            os.path.join(os.path.abspath(os.path.join(__file__, "../..")),
                         'data', 'VAL.hdf5'))
        plt.figure()  # unused `fig` local removed — only the figure side effect is needed
        for i in range(3):
            data = flickr_dataset[i]
            print(data['caption'])
            print(data['caption_unencode'])
            # Select the i-th subplot so imshow draws into it
            # (the returned axes object was never used).
            plt.subplot(1, 3, i + 1)
            # CHW tensor -> HWC numpy array for matplotlib.
            img = data['image'].permute(1, 2, 0).numpy()
            plt.imshow(img)
        plt.show()

        self.assertEqual(True, True)
Example #3
0
    decoder.load_state_dict(
        torch.load('../models/image_captioning_{}.model'.format(START_EPOCH)))
    print('done')

# Embedding: optionally load pretrained word vectors when the configured
# dimension matches the embedding file (200-d vectors).
if EMBBEDING_DIM == 200:
    print('Loading embeddings', end='...')
    embedding, _ = load_embeddings(embedding_file, DATA_FOLDER)
    # fine_tune=True keeps the pretrained embedding weights trainable.
    decoder.load_pretrained_embeddings(embedding, fine_tune=True)
    print('done')

# Loss function
criterion = nn.CrossEntropyLoss().to(DEVICE)

# Data loaders for the TRAIN and VAL splits.
# NOTE(review): num_workers=1 — presumably because the h5py-backed dataset
# does not share handles across worker processes; confirm before raising it.
train_loader = torch.utils.data.DataLoader(CaptionDataset(
    DATA_FOLDER, 'TRAIN'),
                                           batch_size=BATCH_SIZE,
                                           shuffle=True,
                                           num_workers=1,
                                           pin_memory=True)

valid_loader = torch.utils.data.DataLoader(CaptionDataset(DATA_FOLDER, 'VAL'),
                                           batch_size=BATCH_SIZE,
                                           shuffle=True,
                                           num_workers=1,
                                           pin_memory=True)

# Optimizer (decoder only; encoder optimization is handled elsewhere).
optimizer = torch.optim.Adam(decoder.parameters(), lr=LEARNING_RATE)

# Parameters check
Example #4
0
def main():
    """Parse CLI arguments, build the caption datasets and model, and train.

    Side effects: creates ``args.save_dir`` and writes the parsed arguments
    to ``log.txt`` inside it before training starts.
    """
    parser = argparse.ArgumentParser(description='caption model')

    parser.add_argument('--save_dir',
                        type=str,
                        default='logs/tmp',
                        help='directory of model save')

    # Dataset parameters
    parser.add_argument('--data_folder',
                        type=str,
                        default='./datasets/caption_data',
                        help='caption dataset folder')
    parser.add_argument('--data_name',
                        type=str,
                        default='flickr8k_5_cap_per_img_5_min_word_freq',
                        help='dataset name [coco, flickr8k, flickr30k]')
    parser.add_argument('--batch_size',
                        type=int,
                        default=32,
                        help='training batch size')
    parser.add_argument('--print_freq',
                        type=int,
                        default=100,
                        help='print training state every n times')
    parser.add_argument('--num_workers',
                        type=int,
                        default=0,  # raise (e.g. to 8) where multi-process loading works
                        help='number of data loader workers ')

    parser.add_argument('--epochs',
                        type=int,
                        default=120,
                        help='total training epochs')
    parser.add_argument('--grad_clip',
                        type=float,
                        default=5.,
                        help='number of gradient clip')
    parser.add_argument('--alpha_c',
                        type=float,
                        default=1.,
                        help='ratio of attention matrix')
    parser.add_argument('--encoder_lr',
                        type=float,
                        default=1e-4,
                        help='encoder learning rate')
    parser.add_argument('--decoder_lr',
                        type=float,
                        default=4e-4,
                        help='decoder learning rate')

    # Model parameters
    # BUG FIX: these are integer layer sizes but were declared type=float,
    # so setting them on the command line would pass e.g. 512.0 into the
    # model constructors and break size arguments.
    parser.add_argument('--attention_dim',
                        type=int,
                        default=512,
                        help='dimension of attention')
    parser.add_argument('--embed_dim',
                        type=int,
                        default=512,
                        help='dimension of word embedding')
    parser.add_argument('--decoder_dim',
                        type=int,
                        default=512,
                        help='dimension of decoder')
    parser.add_argument('--dropout',
                        type=float,
                        default=0.5,
                        help='rate of dropout')
    parser.add_argument('-frz',
                        '--freeze_encoder',
                        action='store_true',
                        help='whether freeze encoder parameters')

    args = parser.parse_args()

    # Record the full configuration for this run.
    mkdir_if_missing(args.save_dir)
    log_path = os.path.join(args.save_dir, 'log.txt')
    with open(log_path, 'w') as f:
        f.write('{}\n'.format(args))

    # Images are already resized to 256 x 256 on disk, so both splits only
    # need tensor conversion plus ImageNet mean/std normalization.
    tfms = T.Compose([
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    train_dataset = CaptionDataset(args.data_folder,
                                   args.data_name,
                                   split='TRAIN',
                                   transform=tfms)
    val_dataset = CaptionDataset(args.data_folder,
                                 args.data_name,
                                 split='VAL',
                                 transform=tfms)
    train_loader = DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        shuffle=True,  # shuffle training samples each epoch
        num_workers=args.num_workers)
    val_loader = DataLoader(val_dataset,
                            batch_size=args.batch_size,
                            shuffle=False,
                            num_workers=args.num_workers)

    # Word map gives the vocabulary size for the decoder's output layer.
    word_map_file = os.path.join(args.data_folder,
                                 'WORDMAP_' + args.data_name + '.json')
    with open(word_map_file, 'r') as f:
        word_map = json.load(f)

    # Initialize models.
    encoder = Encoder()
    encoder.freeze_params(args.freeze_encoder)
    decoder = DecoderWithAttention(attention_dim=args.attention_dim,
                                   embed_dim=args.embed_dim,
                                   decoder_dim=args.decoder_dim,
                                   vocab_size=len(word_map),
                                   dropout=args.dropout)

    # Separate optimizers for encoder and decoder (different learning rates);
    # filter keeps frozen parameters out of the optimizer.
    encoder_optimizer = torch.optim.Adam(params=filter(
        lambda p: p.requires_grad, encoder.parameters()),
                                         lr=args.encoder_lr)
    decoder_optimizer = torch.optim.Adam(params=filter(
        lambda p: p.requires_grad, decoder.parameters()),
                                         lr=args.decoder_lr)

    # Move models to the compute device.
    encoder = encoder.to(device)
    decoder = decoder.to(device)

    criterion = nn.CrossEntropyLoss()

    train(args=args,
          train_loader=train_loader,
          val_loader=val_loader,
          encoder=encoder,
          decoder=decoder,
          criterion=criterion,
          encoder_optimizer=encoder_optimizer,
          decoder_optimizer=decoder_optimizer,
          log_path=log_path)
Example #5
0
    if fine_tune_encoder is True and encoder_optimizer is None:
        encoder.fine_tune(fine_tune_encoder)
        encoder_optimizer = torch.optim.Adam(params=filter(
            lambda p: p.requires_grad, encoder.parameters()),
                                             lr=encoder_lr)

# Move models to the compute device and set up the loss.
decoder = decoder.to(device)
encoder = encoder.to(device)

criterion = nn.CrossEntropyLoss().to(device)

# ImageNet channel statistics.
# BUG FIX: the blue-channel std was mistyped as 0.255; the canonical
# ImageNet value is 0.225 (as used by the other normalization calls in
# this project).
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])

# Data loaders for the TRAIN and VAL splits.
train_loader = torch.utils.data.DataLoader(CaptionDataset(
    path, base_filename, 'TRAIN', transform=transforms.Compose([normalize])),
                                           batch_size=batch_size,
                                           shuffle=True,
                                           num_workers=workers,
                                           pin_memory=True)
val_loader = torch.utils.data.DataLoader(CaptionDataset(
    path, base_filename, 'VAL', transform=transforms.Compose([normalize])),
                                         batch_size=batch_size,
                                         shuffle=True,
                                         num_workers=workers,
                                         pin_memory=True)

# Main training loop.
for epoch in range(start_epoch, epochs):
    # Early stopping: abort after 10 consecutive epochs without improvement.
    if epochs_since_improvement == 10:
        break
Example #6
0
if __name__ == '__main__':
    # Unpack command-line options into locals.
    opt = parser.parse_args()
    epochs = opt.epochs
    lr = opt.lr
    batch_size = opt.batch_size
    check_point = opt.check_point
    save_dir = opt.save_dir

    # Load the vocabulary (word -> id mapping).
    vocabulary_path = 'vocab/updown_vocab.json'
    with open(vocabulary_path, 'r') as v:
        vocab = json.load(v)

    # Load training set.
    training_set_path = os.path.join('data', 'TRAIN.hdf5')
    training_set = CaptionDataset(training_set_path)

    # Load eval set.
    eval_set_path = os.path.join('data', 'VAL.hdf5')
    eval_set = CaptionDataset(eval_set_path)

    # Build iterable data-loaders for both splits.
    # NOTE(review): the training loader is not shuffled — confirm whether
    # shuffle=True was intended for training.
    training_loader = DataLoader(dataset=training_set, batch_size=batch_size)
    eval_loader = DataLoader(dataset=eval_set, batch_size=batch_size)

    # Build encoder (region-feature extractor).
    encoder = MaskRCNN_Benchmark()

    # Build decoder (bottom-up/top-down captioner over the vocabulary).
    decoder = UpDownCaptioner(vocab=vocab)
def main():
    """Train an image-captioning model (CNN encoder + attention decoder).

    Builds the model from scratch or restores it from a checkpoint, then
    runs the train/validate loop with BLEU-4-based early stopping,
    plateau learning-rate decay, and checkpointing after every epoch.
    """
    print('Training parameters Initialized')
    training_parameters = TrainingParameters(
        start_epoch=0,
        epochs=120,                  # number of epochs to train for
        epochs_since_improvement=0,  # epochs since improvement in BLEU score
        batch_size=32,
        workers=1,                   # for data-loading; right now, only 1 works with h5py
        fine_tune_encoder=True,      # fine-tune encoder
        encoder_lr=1e-4,             # learning rate for encoder, if fine-tuning is used
        decoder_lr=4e-4,             # learning rate for decoder
        grad_clip=5.0,               # clip gradients at an absolute value of
        alpha_c=1.0,                 # regularization parameter for 'doubly stochastic attention'
        best_bleu4=0.0,              # best BLEU-4 score so far
        print_freq=100,              # print training/validation stats every __ batches
        checkpoint='./Result/BEST_checkpoint_flickr8k_5_captions_per_image_5_minimum_word_frequency.pth.tar'  # path to checkpoint, None if none
        # checkpoint=None            # deliberate toggle: use to train from scratch
    )

    print('Loading Word-Map')
    word_map_file = os.path.join(data_folder, 'WORDMAP_' + data_name + '.json')
    with open(word_map_file, 'r') as j:
        word_map = json.load(j)

    print('Creating Model')

    if training_parameters.checkpoint is None:
        # Fresh model: optimizer for the encoder only when fine-tuning it.
        encoder = Encoder()
        encoder.fine_tune(training_parameters.fine_tune_encoder)
        encoder_optimizer = torch.optim.Adam(
            params=filter(lambda p: p.requires_grad, encoder.parameters()),
            lr=training_parameters.encoder_lr
        ) if training_parameters.fine_tune_encoder else None

        decoder = Decoder(attention_dimension=attention_dimension,
                          embedding_dimension=embedding_dimension,
                          hidden_dimension=hidden_dimension,
                          vocab_size=len(word_map),
                          device=device,
                          dropout=dropout)
        decoder_optimizer = torch.optim.Adam(
            params=filter(lambda p: p.requires_grad, decoder.parameters()),
            lr=training_parameters.decoder_lr)

    else:
        # Resume: restore training progress, weights, and optimizers.
        checkpoint = torch.load(training_parameters.checkpoint)
        training_parameters.start_epoch = checkpoint['epoch'] + 1
        training_parameters.epochs_since_improvement = checkpoint['epochs_since_improvement']
        training_parameters.best_bleu4 = checkpoint['bleu4']

        encoder = Encoder()
        encoder.load_state_dict(checkpoint['encoder_state_dict'])
        encoder_optimizer = checkpoint['encoder_optimizer']

        decoder = Decoder(attention_dimension=attention_dimension,
                          embedding_dimension=embedding_dimension,
                          hidden_dimension=hidden_dimension,
                          vocab_size=len(word_map),
                          device=device,
                          dropout=dropout)
        decoder.load_state_dict(checkpoint['decoder_state_dict'])
        decoder_optimizer = checkpoint['decoder_optimizer']

        # BUG FIX: the optimizer re-creation below was previously OUTSIDE
        # this `if`, unconditionally discarding the encoder optimizer
        # restored from the checkpoint. Build a new one only when
        # fine-tuning is enabled and the checkpoint carried no optimizer.
        if training_parameters.fine_tune_encoder is True and encoder_optimizer is None:
            encoder.fine_tune(training_parameters.fine_tune_encoder)
            encoder_optimizer = torch.optim.Adam(
                params=filter(lambda p: p.requires_grad, encoder.parameters()),
                lr=training_parameters.encoder_lr)

    encoder.to(device)
    decoder.to(device)

    criterion = nn.CrossEntropyLoss().to(device)

    # ImageNet channel statistics for input normalization.
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    print('Creating Data Loaders')
    train_dataloader = torch.utils.data.DataLoader(
        CaptionDataset(data_folder, data_name, 'TRAIN', transform=transforms.Compose([normalize])),
        batch_size=training_parameters.batch_size, shuffle=True)

    validation_dataloader = torch.utils.data.DataLoader(
        CaptionDataset(data_folder, data_name, 'VALID', transform=transforms.Compose([normalize])),
        batch_size=training_parameters.batch_size, shuffle=True, pin_memory=True)

    for epoch in range(training_parameters.start_epoch, training_parameters.epochs):

        # Early stopping after 20 epochs with no BLEU-4 improvement.
        if training_parameters.epochs_since_improvement == 20:
            break
        # Decay the learning rate every 8 stagnant epochs.
        if training_parameters.epochs_since_improvement > 0 and training_parameters.epochs_since_improvement % 8 == 0:
            adjust_learning_rate(decoder_optimizer, 0.8)
            if training_parameters.fine_tune_encoder:
                adjust_learning_rate(encoder_optimizer, 0.8)

        train(train_loader=train_dataloader,
              encoder=encoder,
              decoder=decoder,
              criterion=criterion,
              encoder_optimizer=encoder_optimizer,
              decoder_optimizer=decoder_optimizer,
              epoch=epoch,
              device=device,
              training_parameters=training_parameters)

        recent_bleu4_score = validate(validation_loader=validation_dataloader,
                                      encoder=encoder,
                                      decoder=decoder,
                                      criterion=criterion,
                                      word_map=word_map,
                                      device=device,
                                      training_parameters=training_parameters)

        # Track improvement for early stopping / LR decay.
        is_best_score = recent_bleu4_score > training_parameters.best_bleu4
        training_parameters.best_bleu4 = max(recent_bleu4_score, training_parameters.best_bleu4)
        if not is_best_score:
            training_parameters.epochs_since_improvement += 1
            print('\nEpochs since last improvement : %d\n' % (training_parameters.epochs_since_improvement))
        else:
            training_parameters.epochs_since_improvement = 0

        save_checkpoint(data_name, epoch, training_parameters.epochs_since_improvement, encoder, decoder,
                        encoder_optimizer, decoder_optimizer, recent_bleu4_score, is_best_score)