def main():
    """Train an image-captioning encoder/decoder pair, save it, then test it.

    Reads every setting from ``parse_args()``, echoes the hyper-parameters,
    splits the dataset 80/20 into train and test portions, trains both
    networks with separate Adam optimizers, writes their state dicts to
    ``args.output_encoder`` / ``args.output_decoder`` and finally runs the
    test routine on the held-out split.
    """
    args = parse_args()

    # (message template, attribute on ``args``) in the order they are echoed.
    hyper_parameters = (
        ('BATCH_SIZE: {}', 'batch_size'),
        ('EMBEDDING_DIM: {}', 'embedding_dim'),
        ('DEC_HIDDEN_DIM: {}', 'dec_hidden_dim'),
        ('LR: {}', 'lr'),
        ('ENCODER DROPOUT: {}', 'enc_dropout'),
        ('DECODER DROPOUT: {}', 'dec_dropout'),
        ('EPOCHS: {}', 'epochs'),
        ('LOG_INTERVAL: {}', 'log_interval'),
        ('USE PRETRAINED: {}', 'use_pretrained'),
        ('USE CURRICULUM LEARNING: {}', 'use_curriculum_learning'),
    )
    for template, attribute in hyper_parameters:
        print(template.format(getattr(args, attribute)))

    # Data: build the full dataset, then carve out train / test loaders.
    dataset = ImageCaptionDataset(args.image_folder, args.caption_path)
    train_split, test_split = dataset.random_split(train_portion=0.8)
    train_loader = DataLoader(train_split, batch_size=args.batch_size, shuffle=True)
    test_loader = DataLoader(test_split, batch_size=args.batch_size)
    print('Training set size: {}'.format(len(train_split)))
    print('Test set size: {}'.format(len(test_split)))
    vocab_size = len(dataset.vocab)
    print('Vocab size: {}'.format(vocab_size))
    print('----------------------------')

    # Models: the decoder is sized from the encoder's hidden dimension.
    encoder = ImageEncoder(device, pretrained=args.use_pretrained).to(device)
    decoder = CaptionDecoder(
        device,
        vocab_size,
        embedding_dim=args.embedding_dim,
        enc_hidden_dim=encoder.hidden_dim,
        dec_hidden_dim=args.dec_hidden_dim,
        dropout=args.dec_dropout,
        use_pretrained_emb=args.use_pretrained,
        word_to_int=dataset.word_to_int,
    ).to(device)

    # One Adam optimizer per network, both using the same learning rate.
    enc_optimizer = torch.optim.Adam(encoder.parameters(), lr=args.lr)
    dec_optimizer = torch.optim.Adam(decoder.parameters(), lr=args.lr)

    train(encoder, decoder, enc_optimizer, dec_optimizer, train_loader, dataset, args)

    # Persist the weights from the CPU copy of each model ...
    encoder_weights = encoder.cpu().state_dict()
    torch.save(encoder_weights, args.output_encoder)
    decoder_weights = decoder.cpu().state_dict()
    torch.save(decoder_weights, args.output_decoder)
    # ... then move both models back to the working device for evaluation.
    encoder.to(device)
    decoder.to(device)

    test(encoder, decoder, test_loader, dataset, args)
Пример #2
0
    def build_models(self):
        """Build the image and text encoders, optionally restoring checkpoints.

        When ``cfg.text_encoder_path`` is non-empty, the text encoder is
        restored from that path and the image encoder from the sibling path
        obtained by replacing ``'text_encoder'`` with ``'image_encoder'``.
        All parameters of both encoders are left trainable, and both models
        are moved to the GPU when ``cfg.CUDA`` is set.

        Returns:
            list: ``[text_encoder, image_encoder, epoch]``. ``epoch`` is 0 for
            a fresh start, otherwise the epoch number parsed from the text
            encoder checkpoint file name plus one.

        Raises:
            ValueError: if a checkpoint path is configured but its file name
                does not end in an epoch number (before the extension).
        """

        def load_weights(model, path):
            # A checkpoint is either a bare state dict or a dict wrapping it
            # under the 'model' key; load on CPU so no GPU is needed here.
            state_dict = torch.load(path, map_location='cpu')
            if 'model' in state_dict:
                state_dict = state_dict['model']
            model.load_state_dict(state_dict)

        resume = cfg.text_encoder_path != ''
        epoch = 0

        # ------------------------- image encoder ------------------------- #
        image_encoder = ImageEncoder(output_channels=cfg.hidden_dim)
        if resume:
            img_encoder_path = cfg.text_encoder_path.replace('text_encoder', 'image_encoder')
            print('Load image encoder from:', img_encoder_path)
            load_weights(image_encoder, img_encoder_path)
        for p in image_encoder.parameters():  # make image encoder grad on
            p.requires_grad = True

        # -------------------------- text encoder ------------------------- #
        text_encoder = TextEncoder(bert_config=self.bert_config)
        if resume:
            text_encoder_path = cfg.text_encoder_path
            # The original code sliced the path with the undefined names
            # ``istart``/``iend`` (NameError). Recover the epoch from the
            # digits that end the file name instead, e.g.
            # 'text_encoder200.pth' -> 200.
            # NOTE(review): assumes checkpoints are named '<prefix><epoch>.<ext>'
            # - confirm against the code that saves them.
            filename = text_encoder_path.replace('\\', '/').rpartition('/')[2]
            stem = filename.rpartition('.')[0] or filename
            digits = stem[len(stem.rstrip('0123456789')):]
            if not digits:
                raise ValueError(
                    'Cannot parse an epoch number from checkpoint path: %r'
                    % (text_encoder_path,))
            epoch = int(digits) + 1
            print('Load text encoder from:', text_encoder_path)
            load_weights(text_encoder, text_encoder_path)
        for p in text_encoder.parameters():  # make text encoder grad on
            p.requires_grad = True

        # ----------------------------- device ---------------------------- #
        if cfg.CUDA:
            text_encoder = text_encoder.cuda()
            image_encoder = image_encoder.cuda()

        return [text_encoder, image_encoder, epoch]
Пример #3
0
def main(args):
    """Train the VQA model: image encoder + BERT question encoder + decoder.

    Per epoch the three networks are trained jointly with Adam and
    cross-entropy loss, their state dicts are saved under
    ``args.model_path``, and a one-line summary is written to
    ``result/result_<epoch>.txt``. Training stops early once the average
    epoch loss has been above the best loss seen for five epochs in a row.

    Args:
        args: parsed command-line namespace. The attributes read here are
            ``model_path``, ``image_size``, ``vocab_path``,
            ``train_image_dir``, ``train_vqa_path``, ``ix_to_ans_file``,
            ``train_description_file``, ``batch_size``, ``num_workers``,
            ``img_feature_size``, ``pretrained_epoch``, ``learning_rate``,
            ``weight_decay``, ``num_epochs`` and ``log_step``.
    """
    # Create model directory
    if not os.path.exists(args.model_path):
        os.makedirs(args.model_path)

    # Image preprocessing
    train_transform = transforms.Compose([
        transforms.RandomCrop(args.image_size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406),
                             (0.229, 0.224, 0.225))])

    # Load vocabulary wrapper.
    # NOTE(security): unpickling can execute arbitrary code; only point
    # ``args.vocab_path`` at files from a trusted source.
    with open(args.vocab_path, 'rb') as f:
        vocab = pickle.load(f)

    # Build data loader
    train_data_loader = get_loader(
        args.train_image_dir, args.train_vqa_path, args.ix_to_ans_file,
        args.train_description_file, vocab, train_transform, args.batch_size,
        shuffle=True, num_workers=args.num_workers)

    question_emb_size = 1024
    no_ans = 1000  # number of answer classes handed to the decoder
    image_encoder = ImageEncoder(args.img_feature_size)
    question_encoder = BertEncoder(question_emb_size)
    vqa_decoder = VQA_Model(args.img_feature_size, question_emb_size, no_ans)

    pretrained_epoch = 0
    if args.pretrained_epoch > 0:
        pretrained_epoch = args.pretrained_epoch
        # Load on CPU first so weights saved from a GPU run can be resumed
        # on a CPU-only host; the models are moved to CUDA below if present.
        # NOTE(review): weights are read from './models/' while they are
        # saved to ``args.model_path`` - confirm this is intended.
        image_encoder.load_state_dict(torch.load(
            './models/image_encoder-' + str(pretrained_epoch) + '.pkl', map_location='cpu'))
        question_encoder.load_state_dict(torch.load(
            './models/question_encoder-' + str(pretrained_epoch) + '.pkl', map_location='cpu'))
        vqa_decoder.load_state_dict(torch.load(
            './models/vqa_decoder-' + str(pretrained_epoch) + '.pkl', map_location='cpu'))

    if torch.cuda.is_available():
        image_encoder.cuda()
        question_encoder.cuda()
        vqa_decoder.cuda()
        print("Cuda is enabled...")

    criterion = nn.CrossEntropyLoss()
    params = (list(image_encoder.parameters())
              + list(question_encoder.parameters())
              + list(vqa_decoder.parameters()))
    optimizer = torch.optim.Adam(params, lr=args.learning_rate, weight_decay=args.weight_decay)
    total_train_step = len(train_data_loader)

    min_avg_loss = float("inf")
    overfit_warn = 0

    # Resume directly after the last pretrained epoch.
    for epoch in range(pretrained_epoch, args.num_epochs):
        image_encoder.train()
        question_encoder.train()
        vqa_decoder.train()

        loss_sum = 0.0     # sum of per-sample losses over the epoch
        correct_sum = 0    # number of correct predictions over the epoch
        samples_seen = 0   # actual sample count (the last batch may be short)

        for bi, (question_arr, image_vqa, target_answer, _answer_str) in enumerate(train_data_loader):
            # The optimizer owns every parameter of the three networks, so
            # one call clears all of their gradients.
            optimizer.zero_grad()

            images = to_var(torch.stack(image_vqa))
            question_arr = to_var(torch.stack(question_arr))
            target_answer = to_var(torch.tensor(target_answer))

            image_emb = image_encoder(images)
            question_emb = question_encoder(question_arr)
            output = vqa_decoder(image_emb, question_emb)

            loss = criterion(output, target_answer)
            loss.backward()
            optimizer.step()

            _, prediction = torch.max(output, 1)
            no_correct_prediction = prediction.eq(target_answer).sum().item()

            # Use the real batch size: dividing by args.batch_size under-reports
            # accuracy whenever the final batch is smaller.
            batch_count = target_answer.size(0)
            accuracy = no_correct_prediction * 100.0 / batch_count

            # CrossEntropyLoss already averages over the batch, so weight the
            # batch mean by its size to accumulate a per-sample total.
            loss_value = loss.item()
            loss_sum += loss_value * batch_count
            correct_sum += no_correct_prediction
            samples_seen += batch_count

            # Print log info
            if bi % args.log_step == 0:
                print('Epoch [%d/%d], Train Step [%d/%d], Loss: %.4f, Acc: %.4f'
                      %(epoch + 1, args.num_epochs, bi, total_train_step, loss_value, accuracy))

        # Per-sample averages; guard against an empty loader.
        denominator = max(samples_seen, 1)
        avg_loss = loss_sum / denominator
        avg_acc = float(correct_sum) / denominator
        print('Epoch [%d/%d], Average Train Loss: %.4f, Average Train acc: %.4f' %(epoch + 1, args.num_epochs, avg_loss, avg_acc))

        # Save the models
        torch.save(image_encoder.state_dict(), os.path.join(args.model_path, 'image_encoder-%d.pkl' %(epoch+1)))
        torch.save(question_encoder.state_dict(), os.path.join(args.model_path, 'question_encoder-%d.pkl' %(epoch+1)))
        torch.save(vqa_decoder.state_dict(), os.path.join(args.model_path, 'vqa_decoder-%d.pkl' %(epoch+1)))

        # Count consecutive epochs whose loss is worse than the best so far.
        overfit_warn = overfit_warn + 1 if (min_avg_loss < avg_loss) else 0
        min_avg_loss = min(min_avg_loss, avg_loss)

        # The summary directory is not created anywhere else, so make sure
        # it exists before writing, and close the file even on error.
        if not os.path.isdir("result"):
            os.makedirs("result")
        loss_file_name = "result/result_"+str(epoch)+".txt"
        with open(loss_file_name, 'w') as result_fd:
            result_fd.write('Epoch: '+ str(epoch) + ' avg_loss: ' + str(avg_loss)+ " avg_acc: "+ str(avg_acc)+"\n")

        if overfit_warn >= 5:
            print("terminated as overfitted")
            break
Пример #4
0
def main():
    """Train an image encoder / caption encoder pair with a ranking loss.

    Builds the COCO train/validation loaders, the two encoders, an SGD
    optimizer with separate learning rate and momentum per encoder, and a
    ``ReduceLROnPlateau`` scheduler driven by the summed recall scores.
    Optionally resumes from ``args.checkpoint``. After every epoch the
    model is validated, a checkpoint is written to ``args.model_save_path``
    and the validation metrics are accumulated; the collected metrics are
    passed to ``visualize`` at the end.

    Raises:
        ValueError: if ``args.dataset`` is not ``'coco'``, or if the resumed
            epoch is not smaller than ``args.max_epochs``.
    """
    args = parse_args()

    # Fail early with a clear message; previously any other value fell
    # through to a NameError on ``train_dset``.
    if args.dataset != 'coco':
        raise ValueError(
            "unsupported dataset {!r}: only 'coco' is handled".format(args.dataset))

    transform = transforms.Compose([
        transforms.Resize((args.imsize, args.imsize)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])
    train_dset = CocoDataset(root=args.root_path,
                             transform=transform,
                             mode='one')
    val_dset = CocoDataset(root=args.root_path,
                           imgdir='val2017',
                           jsonfile='annotations/captions_val2017.json',
                           transform=transform,
                           mode='all')
    train_loader = DataLoader(train_dset,
                              batch_size=args.batch_size,
                              shuffle=True,
                              num_workers=args.n_cpu,
                              collate_fn=collater_train)
    val_loader = DataLoader(val_dset,
                            batch_size=args.batch_size,
                            shuffle=False,
                            num_workers=args.n_cpu,
                            collate_fn=collater_eval)

    vocab = Vocabulary(max_len=args.max_len)
    vocab.load_vocab(args.vocab_path)

    device = torch.device(
        "cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")

    imenc = ImageEncoder(args.out_size, args.cnn_type).to(device)
    capenc = CaptionEncoder(len(vocab), args.emb_size, args.out_size,
                            args.rnn_type).to(device)

    optimizer = optim.SGD([{
        'params': imenc.parameters(),
        'lr': args.lr_cnn,
        'momentum': args.mom_cnn
    }, {
        'params': capenc.parameters(),
        'lr': args.lr_rnn,
        'momentum': args.mom_rnn
    }])
    # mode='max': the monitored quantity is a recall sum, higher is better.
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                     mode='max',
                                                     factor=0.1,
                                                     patience=args.patience,
                                                     verbose=True)
    lossfunc = PairwiseRankingLoss(margin=args.margin,
                                   method=args.method,
                                   improved=args.improved,
                                   intra=args.intra)

    offset = 0
    if args.checkpoint is not None:
        print("loading model and optimizer checkpoint from {} ...".format(
            args.checkpoint),
              flush=True)
        # Remap stored tensors onto the selected device so that a
        # checkpoint written on a GPU can be resumed on a CPU-only host.
        ckpt = torch.load(args.checkpoint, map_location=device)
        imenc.load_state_dict(ckpt["encoder_state"])
        capenc.load_state_dict(ckpt["decoder_state"])
        optimizer.load_state_dict(ckpt["optimizer_state"])
        scheduler.load_state_dict(ckpt["scheduler_state"])
        offset = ckpt["epoch"]

    # Wrap only after loading: checkpoints hold the unwrapped modules' keys.
    imenc = nn.DataParallel(imenc)
    capenc = nn.DataParallel(capenc)

    # Explicit check instead of ``assert`` so it survives ``python -O``.
    if offset >= args.max_epochs:
        raise ValueError(
            "checkpoint epoch {} is not below max_epochs {}".format(
                offset, args.max_epochs))

    # The checkpoint directory only needs to be created once.
    if not os.path.exists(args.model_save_path):
        os.makedirs(args.model_save_path)

    metrics = {}
    for ep in range(offset, args.max_epochs):
        imenc, capenc, optimizer = train(ep + 1, train_loader, imenc, capenc,
                                         optimizer, lossfunc, vocab, args)
        data = validate(ep + 1, val_loader, imenc, capenc, vocab, args)

        # Sum image->caption and caption->image recall at each rank.
        totalscore = sum(
            data["i2c_recall@{}".format(rank)] + data["c2i_recall@{}".format(rank)]
            for rank in (1, 5, 10, 20))
        scheduler.step(totalscore)

        # save checkpoint (``.module`` strips the DataParallel wrapper)
        ckpt = {
            "stats": data,
            "epoch": ep + 1,
            "encoder_state": imenc.module.state_dict(),
            "decoder_state": capenc.module.state_dict(),
            "optimizer_state": optimizer.state_dict(),
            "scheduler_state": scheduler.state_dict()
        }
        savepath = os.path.join(
            args.model_save_path,
            "epoch_{:04d}_score_{:05d}.ckpt".format(ep + 1,
                                                    int(100 * totalscore)))
        print(
            "saving model and optimizer checkpoint to {} ...".format(savepath),
            flush=True)
        torch.save(ckpt, savepath)
        print("done for epoch {}".format(ep + 1), flush=True)

        # Append this epoch's value to each metric's history.
        for k, v in data.items():
            metrics.setdefault(k, []).append(v)

    visualize(metrics, args)