def main():
    """Entry point: parse CLI args, train an image-captioning model, save it, then test it."""
    args = parse_args()

    # Echo the run configuration so it is captured in the job log.
    config_lines = [
        ('BATCH_SIZE: {}', args.batch_size),
        ('EMBEDDING_DIM: {}', args.embedding_dim),
        ('DEC_HIDDEN_DIM: {}', args.dec_hidden_dim),
        ('LR: {}', args.lr),
        ('ENCODER DROPOUT: {}', args.enc_dropout),
        ('DECODER DROPOUT: {}', args.dec_dropout),
        ('EPOCHS: {}', args.epochs),
        ('LOG_INTERVAL: {}', args.log_interval),
        ('USE PRETRAINED: {}', args.use_pretrained),
        ('USE CURRICULUM LEARNING: {}', args.use_curriculum_learning),
    ]
    for template, value in config_lines:
        print(template.format(value))

    # Prepare data & split (80/20 train/test).
    dataset = ImageCaptionDataset(args.image_folder, args.caption_path)
    train_set, test_set = dataset.random_split(train_portion=0.8)
    loader_train = DataLoader(train_set, batch_size=args.batch_size, shuffle=True)
    loader_test = DataLoader(test_set, batch_size=args.batch_size)
    print('Training set size: {}'.format(len(train_set)))
    print('Test set size: {}'.format(len(test_set)))
    print('Vocab size: {}'.format(len(dataset.vocab)))
    print('----------------------------')

    # Build encoder/decoder and one Adam optimizer per sub-model.
    encoder = ImageEncoder(device, pretrained=args.use_pretrained).to(device)
    decoder = CaptionDecoder(
        device,
        len(dataset.vocab),
        embedding_dim=args.embedding_dim,
        enc_hidden_dim=encoder.hidden_dim,
        dec_hidden_dim=args.dec_hidden_dim,
        dropout=args.dec_dropout,
        use_pretrained_emb=args.use_pretrained,
        word_to_int=dataset.word_to_int,
    ).to(device)
    opt_enc = torch.optim.Adam(encoder.parameters(), lr=args.lr)
    opt_dec = torch.optim.Adam(decoder.parameters(), lr=args.lr)

    # Train, checkpoint (weights moved to CPU for a portable save), then evaluate.
    train(encoder, decoder, opt_enc, opt_dec, loader_train, dataset, args)
    torch.save(encoder.cpu().state_dict(), args.output_encoder)
    torch.save(decoder.cpu().state_dict(), args.output_decoder)
    encoder.to(device)
    decoder.to(device)
    test(encoder, decoder, loader_test, dataset, args)
def build_models(self):
    """Construct the image and text encoders, optionally restoring both from
    the checkpoint pointed to by ``cfg.text_encoder_path``.

    Returns:
        list: ``[text_encoder, image_encoder, epoch]`` where ``epoch`` is the
        epoch to resume training from (0 when starting fresh).
    """

    def _restore(module, ckpt_path, banner):
        # Checkpoints are stored either as {'model': state_dict} or as a bare
        # state_dict; accept both layouts. Loaded on CPU, moved to GPU below.
        print(banner, ckpt_path)
        state = torch.load(ckpt_path, map_location='cpu')
        if 'model' in state.keys():
            module.load_state_dict(state['model'])
        else:
            module.load_state_dict(state)

    # ---------------- image encoder ----------------
    image_encoder = ImageEncoder(output_channels=cfg.hidden_dim)
    if cfg.text_encoder_path != '':
        # The paired image-encoder checkpoint sits next to the text one,
        # differing only in the filename stem.
        _restore(image_encoder,
                 cfg.text_encoder_path.replace('text_encoder', 'image_encoder'),
                 'Load image encoder from:')
    for p in image_encoder.parameters():  # keep the image encoder trainable
        p.requires_grad = True
    # image_encoder.eval()

    # ---------------- text encoder ----------------
    epoch = 0
    text_encoder = TextEncoder(bert_config=self.bert_config)
    if cfg.text_encoder_path != '':
        # NOTE(review): ``istart``/``iend`` are not defined in this method --
        # presumably module-level indices slicing the epoch number out of the
        # checkpoint filename; verify they are assigned before this runs.
        epoch = int(cfg.text_encoder_path[istart:iend]) + 1
        _restore(text_encoder, cfg.text_encoder_path, 'Load text encoder from:')
    for p in text_encoder.parameters():  # keep the text encoder trainable
        p.requires_grad = True

    if cfg.CUDA:
        text_encoder = text_encoder.cuda()
        image_encoder = image_encoder.cuda()
    return [text_encoder, image_encoder, epoch]
def main(args):
    """Train a VQA model (image encoder + BERT question encoder + answer decoder).

    Builds the training data loader, optionally resumes from epoch
    ``args.pretrained_epoch`` checkpoints, then trains with cross-entropy over
    the fixed 1000-answer vocabulary, checkpointing every epoch and stopping
    early after 5 consecutive epochs without a new minimum average loss.

    NOTE(review): indentation reconstructed from a flattened source dump --
    verify block structure against the original file.
    """
    # Create model directory
    if not os.path.exists(args.model_path):
        os.makedirs(args.model_path)

    # Image preprocessing: random crop + flip augmentation, ImageNet normalization.
    train_transform = transforms.Compose([
        transforms.RandomCrop(args.image_size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406),
                             (0.229, 0.224, 0.225))])
    # val_transform = transforms.Compose([
    #     transforms.Resize(args.image_size, interpolation=Image.LANCZOS),
    #     transforms.ToTensor(),
    #     transforms.Normalize((0.485, 0.456, 0.406),
    #                          (0.229, 0.224, 0.225))])

    # Load vocabulary wrapper.
    with open(args.vocab_path, 'rb') as f:
        vocab = pickle.load(f)

    # Build data loader
    train_data_loader = get_loader(args.train_image_dir, args.train_vqa_path,
                                   args.ix_to_ans_file, args.train_description_file,
                                   vocab, train_transform, args.batch_size,
                                   shuffle=True, num_workers=args.num_workers)
    #val_data_loader = get_loader(args.val_image_dir, args.val_vqa_path, args.ix_to_ans_file, vocab, val_transform, args.batch_size, shuffle=False, num_workers=args.num_workers)

    image_encoder = ImageEncoder(args.img_feature_size)
    question_emb_size = 1024
    # description_emb_size = 512
    no_ans = 1000  # size of the closed answer set (classification target)
    question_encoder = BertEncoder(question_emb_size)
    # ques_description_encoder = BertEncoder(description_emb_size)
    # vqa_decoder = VQA_Model(args.img_feature_size, question_emb_size, description_emb_size, no_ans)
    vqa_decoder = VQA_Model(args.img_feature_size, question_emb_size, no_ans)

    # Optionally resume all three sub-models from epoch-numbered checkpoints.
    pretrained_epoch = 0
    if args.pretrained_epoch > 0:
        pretrained_epoch = args.pretrained_epoch
        image_encoder.load_state_dict(torch.load('./models/image_encoder-' + str(pretrained_epoch) + '.pkl'))
        question_encoder.load_state_dict(torch.load('./models/question_encoder-' + str(pretrained_epoch) + '.pkl'))
        # ques_description_encoder.load_state_dict(torch.load('./models/ques_description_encoder-' + str(pretrained_epoch) + '.pkl'))
        vqa_decoder.load_state_dict(torch.load('./models/vqa_decoder-' + str(pretrained_epoch) + '.pkl'))

    if torch.cuda.is_available():
        image_encoder.cuda()
        question_encoder.cuda()
        # ques_description_encoder.cuda()
        vqa_decoder.cuda()
        print("Cuda is enabled...")

    criterion = nn.CrossEntropyLoss()
    # All sub-models are optimized jointly by a single Adam optimizer.
    # params = image_encoder.get_params() + question_encoder.get_params() + ques_description_encoder.get_params() + vqa_decoder.get_params()
    params = list(image_encoder.parameters()) + list(question_encoder.parameters()) + list(vqa_decoder.parameters())
    #print("params: ", params)
    optimizer = torch.optim.Adam(params, lr=args.learning_rate, weight_decay=args.weight_decay)

    total_train_step = len(train_data_loader)
    min_avg_loss = float("inf")
    overfit_warn = 0  # consecutive epochs whose avg loss exceeded the best so far

    for epoch in range(args.num_epochs):
        # Skip epochs already covered by the resumed checkpoint.
        if epoch < pretrained_epoch:
            continue
        image_encoder.train()
        question_encoder.train()
        #ques_description_encoder.train()
        vqa_decoder.train()
        avg_loss = 0.0
        avg_acc = 0.0
        for bi, (question_arr, image_vqa, target_answer, answer_str) in enumerate(train_data_loader):
            loss = 0
            image_encoder.zero_grad()
            question_encoder.zero_grad()
            #ques_description_encoder.zero_grad()
            vqa_decoder.zero_grad()
            # Batch items arrive as lists of tensors; stack into batch tensors.
            images = to_var(torch.stack(image_vqa))
            question_arr = to_var(torch.stack(question_arr))
            #ques_desc_arr = to_var(torch.stack(ques_desc_arr))
            target_answer = to_var(torch.tensor(target_answer))
            image_emb = image_encoder(images)
            question_emb = question_encoder(question_arr)
            #ques_desc_emb = ques_description_encoder(ques_desc_arr)
            #output = vqa_decoder(image_emb, question_emb, ques_desc_emb)
            output = vqa_decoder(image_emb, question_emb)
            loss = criterion(output, target_answer)
            _, prediction = torch.max(output,1)
            no_correct_prediction = prediction.eq(target_answer).sum().item()
            # NOTE(review): divides by the nominal batch size, so the reported
            # per-batch accuracy is understated for a smaller final batch.
            accuracy = no_correct_prediction * 100/ args.batch_size
            ####
            target_answer_no = target_answer.tolist()
            prediction_no = prediction.tolist()
            ####
            loss_num = loss.item()
            avg_loss += loss.item()
            avg_acc += no_correct_prediction
            #loss /= (args.batch_size)
            loss.backward()
            optimizer.step()

            # Print log info
            if bi % args.log_step == 0:
                print('Epoch [%d/%d], Train Step [%d/%d], Loss: %.4f, Acc: %.4f'
                      %(epoch + 1, args.num_epochs, bi, total_train_step, loss.item(), accuracy))

        # NOTE(review): these averages also assume every batch is full-size.
        avg_loss /= (args.batch_size * total_train_step)
        avg_acc /= (args.batch_size * total_train_step)
        print('Epoch [%d/%d], Average Train Loss: %.4f, Average Train acc: %.4f'
              %(epoch + 1, args.num_epochs, avg_loss, avg_acc))

        # Save the models
        torch.save(image_encoder.state_dict(), os.path.join(args.model_path, 'image_encoder-%d.pkl' %(epoch+1)))
        torch.save(question_encoder.state_dict(), os.path.join(args.model_path, 'question_encoder-%d.pkl' %(epoch+1)))
        #torch.save(ques_description_encoder.state_dict(), os.path.join(args.model_path, 'ques_description_encoder-%d.pkl' %(epoch+1)))
        torch.save(vqa_decoder.state_dict(), os.path.join(args.model_path, 'vqa_decoder-%d.pkl' %(epoch+1)))

        # Early-stopping bookkeeping: count consecutive non-improving epochs.
        overfit_warn = overfit_warn + 1 if (min_avg_loss < avg_loss) else 0
        min_avg_loss = min(min_avg_loss, avg_loss)

        # Append this epoch's stats to a per-epoch result file.
        lossFileName = "result/result_"+str(epoch)+".txt"
        test_fd = open(lossFileName, 'w')
        test_fd.write('Epoch: '+ str(epoch) + ' avg_loss: ' + str(avg_loss)+ " avg_acc: "+ str(avg_acc)+"\n")
        test_fd.close()

        if overfit_warn >= 5:
            print("terminated as overfitted")
            break
def main():
    """Train an image--caption retrieval model (image CNN + caption RNN encoders).

    Supports resuming from a checkpoint, schedules the LR on the summed
    recall@{1,5,10,20} retrieval score, saves a checkpoint every epoch, and
    visualizes the accumulated metric history.

    NOTE(review): indentation reconstructed from a flattened source dump --
    verify block structure against the original file.
    """
    args = parse_args()

    # Resize + ImageNet normalization (no augmentation here).
    transform = transforms.Compose([
        transforms.Resize((args.imsize, args.imsize)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])
    if args.dataset == 'coco':
        # 'one' mode yields a single caption per image for training;
        # 'all' keeps every caption for retrieval evaluation.
        train_dset = CocoDataset(root=args.root_path, transform=transform, mode='one')
        val_dset = CocoDataset(root=args.root_path, imgdir='val2017',
                               jsonfile='annotations/captions_val2017.json',
                               transform=transform, mode='all')
    train_loader = DataLoader(train_dset, batch_size=args.batch_size, shuffle=True,
                              num_workers=args.n_cpu, collate_fn=collater_train)
    val_loader = DataLoader(val_dset, batch_size=args.batch_size, shuffle=False,
                            num_workers=args.n_cpu, collate_fn=collater_eval)

    vocab = Vocabulary(max_len=args.max_len)
    vocab.load_vocab(args.vocab_path)

    imenc = ImageEncoder(args.out_size, args.cnn_type)
    capenc = CaptionEncoder(len(vocab), args.emb_size, args.out_size, args.rnn_type)
    device = torch.device(
        "cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
    imenc = imenc.to(device)
    capenc = capenc.to(device)

    # Separate SGD hyperparameters for the CNN and RNN parameter groups.
    optimizer = optim.SGD([{
        'params': imenc.parameters(),
        'lr': args.lr_cnn,
        'momentum': args.mom_cnn
    }, {
        'params': capenc.parameters(),
        'lr': args.lr_rnn,
        'momentum': args.mom_rnn
    }])
    # mode='max' because the scheduler is stepped with a score to maximize.
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max',
                                                     factor=0.1,
                                                     patience=args.patience,
                                                     verbose=True)
    lossfunc = PairwiseRankingLoss(margin=args.margin, method=args.method,
                                   improved=args.improved, intra=args.intra)

    # Resume model/optimizer/scheduler state and the epoch offset, if given.
    if args.checkpoint is not None:
        print("loading model and optimizer checkpoint from {} ...".format(
            args.checkpoint), flush=True)
        ckpt = torch.load(args.checkpoint)
        imenc.load_state_dict(ckpt["encoder_state"])
        capenc.load_state_dict(ckpt["decoder_state"])
        optimizer.load_state_dict(ckpt["optimizer_state"])
        scheduler.load_state_dict(ckpt["scheduler_state"])
        offset = ckpt["epoch"]
    else:
        offset = 0
    # Wrap AFTER loading state dicts so checkpoint keys stay unprefixed.
    imenc = nn.DataParallel(imenc)
    capenc = nn.DataParallel(capenc)

    metrics = {}
    assert offset < args.max_epochs
    for ep in range(offset, args.max_epochs):
        imenc, capenc, optimizer = train(ep + 1, train_loader, imenc, capenc,
                                         optimizer, lossfunc, vocab, args)
        data = validate(ep + 1, val_loader, imenc, capenc, vocab, args)

        # Sum image-to-caption and caption-to-image recall at several ranks;
        # this single scalar drives the LR scheduler.
        totalscore = 0
        for rank in [1, 5, 10, 20]:
            totalscore += data["i2c_recall@{}".format(rank)] + data[
                "c2i_recall@{}".format(rank)]
        scheduler.step(totalscore)

        # save checkpoint (unwrap .module so keys match a non-DataParallel load)
        ckpt = {
            "stats": data,
            "epoch": ep + 1,
            "encoder_state": imenc.module.state_dict(),
            "decoder_state": capenc.module.state_dict(),
            "optimizer_state": optimizer.state_dict(),
            "scheduler_state": scheduler.state_dict()
        }
        if not os.path.exists(args.model_save_path):
            os.makedirs(args.model_save_path)
        savepath = os.path.join(
            args.model_save_path,
            "epoch_{:04d}_score_{:05d}.ckpt".format(ep + 1,
                                                    int(100 * totalscore)))
        print(
            "saving model and optimizer checkpoint to {} ...".format(savepath),
            flush=True)
        torch.save(ckpt, savepath)
        print("done for epoch {}".format(ep + 1), flush=True)

        # Append this epoch's metrics to the running history and re-plot.
        for k, v in data.items():
            if k not in metrics.keys():
                metrics[k] = [v]
            else:
                metrics[k].append(v)
        # NOTE(review): flattened source is ambiguous on whether visualize()
        # runs every epoch or once after the loop -- confirm against original.
        visualize(metrics, args)