def recommend_train(context_text_encoder: TextEncoder,
                    context_image_encoder: ImageEncoder,
                    context_encoder: ContextEncoder,
                    train_dataset: Dataset,
                    valid_dataset: Dataset,
                    test_dataset: Dataset,
                    model_file: str,
                    vocab_size: int,
                    embed_init=None):
    """Recommend train.

    Args:
        context_text_encoder (TextEncoder): Context text encoder.
        context_image_encoder (ImageEncoder): Context image encoder.
        context_encoder (ContextEncoder): Context encoder.
        train_dataset (Dataset): Train dataset.
        valid_dataset (Dataset): Valid dataset.
        test_dataset (Dataset): Test dataset.
        model_file (str): Saved model file.
        vocab_size (int): Vocabulary size.
        embed_init: Initial embedding (vocab_size, embed_size).

    """

    # Data loader.
    train_data_loader = DataLoader(
        dataset=train_dataset,
        batch_size=RecommendTrainConfig.batch_size,
        shuffle=True,
        num_workers=RecommendTrainConfig.num_data_loader_workers)

    # Model.
    similarity_config = SimilarityConfig(vocab_size, embed_init)
    similarity = Similarity(similarity_config).to(GlobalConfig.device)

    # Model parameters.
    params = list(
        chain.from_iterable([
            list(model.parameters()) for model in [
                context_text_encoder, context_image_encoder,
                context_encoder, similarity
            ]
        ]))
    optimizer = Adam(params, lr=RecommendTrainConfig.learning_rate)
    epoch_id = 0
    min_valid_loss = None

    # Load saved state.
    if isfile(model_file):
        state = torch.load(model_file)
        similarity.load_state_dict(state['similarity'])
        optimizer.load_state_dict(state['optimizer'])
        epoch_id = state['epoch_id']
        min_valid_loss = state['min_valid_loss']

    # Loss.
    sum_loss = 0
    bad_loss_cnt = 0

    # Switch to train mode.
    context_text_encoder.train()
    context_image_encoder.train()
    context_encoder.train()
    similarity.train()

    finished = False

    for epoch_id in range(epoch_id, RecommendTrainConfig.num_iterations):
        for batch_id, train_data in enumerate(train_data_loader):
            # Sets gradients to 0.
            optimizer.zero_grad()

            context_dialog, pos_products, neg_products = train_data
            texts, text_lengths, images, utter_types = context_dialog
            # Sizes:
            # texts: (batch_size, dialog_context_size + 1, dialog_text_max_len)
            # text_lengths: (batch_size, dialog_context_size + 1)
            # images: (batch_size, dialog_context_size + 1,
            #          pos_images_max_num, 3, image_size, image_size)
            # utter_types: (batch_size, )

            batch_size = texts.size(0)

            # To device.
            texts = texts.to(GlobalConfig.device)
            text_lengths = text_lengths.to(GlobalConfig.device)
            images = images.to(GlobalConfig.device)
            # utter_types = utter_types.to(GlobalConfig.device)

            texts.transpose_(0, 1)
            # (dialog_context_size + 1, batch_size, dialog_text_max_len)
            text_lengths.transpose_(0, 1)
            # (dialog_context_size + 1, batch_size)
            images.transpose_(0, 1)
            images.transpose_(1, 2)
            # (dialog_context_size + 1, pos_images_max_num, batch_size, 3,
            #  image_size, image_size)

            # Encode context.
            context, _ = encode_context(context_text_encoder,
                                        context_image_encoder,
                                        context_encoder,
                                        texts, text_lengths, images)
            # (batch_size, context_vector_size)

            loss = recommend_loss(similarity, batch_size, context,
                                  pos_products, neg_products)
            sum_loss += loss

            loss.backward()
            optimizer.step()

            # Print loss every `TrainConfig.print_freq` batches.
            if (batch_id + 1) % RecommendTrainConfig.print_freq == 0:
                cur_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
                sum_loss /= RecommendTrainConfig.print_freq
                print('epoch: {} \tbatch: {} \tloss: {} \ttime: {}'.format(
                    epoch_id + 1, batch_id + 1, sum_loss, cur_time))
                sum_loss = 0

            # Valid every `TrainConfig.valid_freq` batches.
            if (batch_id + 1) % RecommendTrainConfig.valid_freq == 0:
                valid_loss = recommend_valid(context_text_encoder,
                                             context_image_encoder,
                                             context_encoder, similarity,
                                             valid_dataset)
                cur_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
                print('valid_loss: {} \ttime: {}'.format(valid_loss, cur_time))

                # Save current best model.
                if min_valid_loss is None or valid_loss < min_valid_loss:
                    min_valid_loss = valid_loss
                    bad_loss_cnt = 0
                    save_dict = {
                        'task': RECOMMEND_TASK,
                        'epoch_id': epoch_id,
                        'min_valid_loss': min_valid_loss,
                        'optimizer': optimizer.state_dict(),
                        'context_text_encoder':
                            context_text_encoder.state_dict(),
                        'context_image_encoder':
                            context_image_encoder.state_dict(),
                        'context_encoder': context_encoder.state_dict(),
                        'similarity': similarity.state_dict()
                    }
                    torch.save(save_dict, model_file)
                    print('Best model saved.')
                else:
                    bad_loss_cnt += 1
                    if bad_loss_cnt > RecommendTrainConfig.patience:
                        recommend_test(context_text_encoder,
                                       context_image_encoder,
                                       context_encoder, similarity,
                                       test_dataset)
                        finished = True
                        break
        if finished:
            break
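
# A minimal, self-contained sketch (illustration only, not part of the
# original module) of the patience-based early-stopping pattern used in
# `recommend_train` above: remember the best validation loss, reset the
# bad-loss counter whenever it improves (and save a checkpoint there), and
# stop once `patience` consecutive validations fail to improve.
def _early_stopping_sketch(valid_losses, patience=3):
    min_valid_loss = None
    bad_loss_cnt = 0
    for valid_loss in valid_losses:
        if min_valid_loss is None or valid_loss < min_valid_loss:
            min_valid_loss = valid_loss
            bad_loss_cnt = 0  # a checkpoint would be saved here
        else:
            bad_loss_cnt += 1
            if bad_loss_cnt > patience:
                return True  # the final test pass would run here
    return False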
def knowledge_attribute_train(context_text_encoder: TextEncoder,
                              context_image_encoder: ImageEncoder,
                              context_encoder: ContextEncoder,
                              train_dataset: Dataset,
                              valid_dataset: Dataset,
                              test_dataset: Dataset,
                              model_file: str,
                              attribute_data: AttributeData,
                              vocab: Dict[str, int],
                              embed_init=None):
    """Knowledge attribute train.

    Args:
        context_text_encoder (TextEncoder): Context text encoder.
        context_image_encoder (ImageEncoder): Context image encoder.
        context_encoder (ContextEncoder): Context encoder.
        train_dataset (Dataset): Train dataset.
        valid_dataset (Dataset): Valid dataset.
        test_dataset (Dataset): Test dataset.
        model_file (str): Saved model file.
        attribute_data (AttributeData): Attribute data.
        vocab (Dict[str, int]): Vocabulary.
        embed_init: Initial embedding (vocab_size, embed_size).

    """

    # Data loader.
    train_data_loader = DataLoader(
        dataset=train_dataset,
        batch_size=KnowledgeAttributeTrainConfig.batch_size,
        shuffle=True,
        num_workers=KnowledgeAttributeTrainConfig.num_data_loader_workers)

    # Model.
    vocab_size = len(vocab)
    attribute_kv_memory_config = AttributeKVMemoryConfig(
        len(attribute_data.key_vocab), len(attribute_data.value_vocab))
    text_decoder_config = KnowledgeTextDecoderConfig(vocab_size,
                                                     MemoryConfig.memory_size,
                                                     MemoryConfig.output_size,
                                                     embed_init)

    to_hidden = ToHidden(text_decoder_config)
    to_hidden = to_hidden.to(GlobalConfig.device)
    attribute_kv_memory = KVMemory(attribute_kv_memory_config)
    attribute_kv_memory = attribute_kv_memory.to(GlobalConfig.device)
    text_decoder = TextDecoder(text_decoder_config)
    text_decoder = text_decoder.to(GlobalConfig.device)

    # Model parameters.
    params = list(
        chain.from_iterable([
            list(model.parameters()) for model in [
                context_text_encoder, context_image_encoder,
                context_encoder, to_hidden, attribute_kv_memory,
                text_decoder
            ]
        ]))
    optimizer = Adam(params, lr=KnowledgeAttributeTrainConfig.learning_rate)
    epoch_id = 0
    min_valid_loss = None

    # Load saved state.
    if isfile(model_file):
        state = torch.load(model_file)
        to_hidden.load_state_dict(state['to_hidden'])
        attribute_kv_memory.load_state_dict(state['attribute_kv_memory'])
        text_decoder.load_state_dict(state['text_decoder'])
        optimizer.load_state_dict(state['optimizer'])
        epoch_id = state['epoch_id']
        min_valid_loss = state['min_valid_loss']

    # Loss.
    sum_loss = 0
    bad_loss_cnt = 0

    # Switch to train mode.
    context_text_encoder.train()
    context_image_encoder.train()
    context_encoder.train()
    to_hidden.train()
    attribute_kv_memory.train()
    text_decoder.train()

    finished = False

    for epoch_id in range(epoch_id,
                          KnowledgeAttributeTrainConfig.num_iterations):
        for batch_id, train_data in enumerate(train_data_loader):
            # Set gradients to 0.
            optimizer.zero_grad()

            train_data, products = train_data
            keys, values, pair_length = products
            keys = keys.to(GlobalConfig.device)
            values = values.to(GlobalConfig.device)
            pair_length = pair_length.to(GlobalConfig.device)

            texts, text_lengths, images, utter_types = train_data
            # Sizes:
            # texts: (batch_size, dialog_context_size + 1, dialog_text_max_len)
            # text_lengths: (batch_size, dialog_context_size + 1)
            # images: (batch_size, dialog_context_size + 1,
            #          pos_images_max_num, 3, image_size, image_size)
            # utter_types: (batch_size, )

            # To device.
            texts = texts.to(GlobalConfig.device)
            text_lengths = text_lengths.to(GlobalConfig.device)
            images = images.to(GlobalConfig.device)
            utter_types = utter_types.to(GlobalConfig.device)

            texts.transpose_(0, 1)
            # (dialog_context_size + 1, batch_size, dialog_text_max_len)
            text_lengths.transpose_(0, 1)
            # (dialog_context_size + 1, batch_size)
            images.transpose_(0, 1)
            images.transpose_(1, 2)
            # (dialog_context_size + 1, pos_images_max_num, batch_size, 3,
            #  image_size, image_size)

            # Encode context.
            context, hiddens = encode_context(context_text_encoder,
                                              context_image_encoder,
                                              context_encoder,
                                              texts, text_lengths, images)
            # (batch_size, context_vector_size)

            encode_knowledge_func = partial(attribute_kv_memory,
                                            keys, values, pair_length)

            loss, n_totals = text_loss(to_hidden, text_decoder,
                                       text_decoder_config.text_length,
                                       context, texts[-1], text_lengths[-1],
                                       hiddens, encode_knowledge_func)
            sum_loss += loss / text_decoder_config.text_length

            loss.backward()
            optimizer.step()

            # Print loss every `TrainConfig.print_freq` batches.
            if (batch_id + 1) % KnowledgeAttributeTrainConfig.print_freq == 0:
                cur_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
                sum_loss /= KnowledgeAttributeTrainConfig.print_freq
                print('epoch: {} \tbatch: {} \tloss: {} \ttime: {}'.format(
                    epoch_id + 1, batch_id + 1, sum_loss, cur_time))
                sum_loss = 0

            # Valid every `TrainConfig.valid_freq` batches.
            if (batch_id + 1) % KnowledgeAttributeTrainConfig.valid_freq == 0:
                valid_loss = knowledge_attribute_valid(
                    context_text_encoder, context_image_encoder,
                    context_encoder, to_hidden, attribute_kv_memory,
                    text_decoder, valid_dataset,
                    text_decoder_config.text_length)
                cur_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
                print('valid_loss: {} \ttime: {}'.format(valid_loss, cur_time))

                # Save current best model.
                if min_valid_loss is None or valid_loss < min_valid_loss:
                    min_valid_loss = valid_loss
                    bad_loss_cnt = 0
                    save_dict = {
                        'task': KNOWLEDGE_ATTRIBUTE_SUBTASK,
                        'epoch_id': epoch_id,
                        'min_valid_loss': min_valid_loss,
                        'optimizer': optimizer.state_dict(),
                        'context_text_encoder':
                            context_text_encoder.state_dict(),
                        'context_image_encoder':
                            context_image_encoder.state_dict(),
                        'context_encoder': context_encoder.state_dict(),
                        'to_hidden': to_hidden.state_dict(),
                        'attribute_kv_memory':
                            attribute_kv_memory.state_dict(),
                        'text_decoder': text_decoder.state_dict()
                    }
                    torch.save(save_dict, model_file)
                    print('Best model saved.')
                else:
                    bad_loss_cnt += 1
                    if bad_loss_cnt > KnowledgeAttributeTrainConfig.patience:
                        knowledge_attribute_test(
                            context_text_encoder, context_image_encoder,
                            context_encoder, to_hidden, attribute_kv_memory,
                            text_decoder, test_dataset,
                            text_decoder_config.text_length, vocab)
                        finished = True
                        break
        if finished:
            break
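
# A minimal, standalone sketch (hypothetical sizes, illustration only) of
# the in-place transposes used in the training loops above: the dialog-turn
# axis is moved in front of the batch axis, and for images the per-turn
# image axis is moved in front of the batch axis as well.
def _dialog_tensor_shapes_sketch():
    import torch

    batch_size, ctx_plus_1, text_max_len = 4, 3, 30
    pos_images_max_num, image_size = 2, 64

    texts = torch.zeros(batch_size, ctx_plus_1, text_max_len, dtype=torch.long)
    images = torch.zeros(batch_size, ctx_plus_1, pos_images_max_num,
                         3, image_size, image_size)

    texts.transpose_(0, 1)
    images.transpose_(0, 1)
    images.transpose_(1, 2)

    assert texts.shape == (ctx_plus_1, batch_size, text_max_len)
    assert images.shape == (ctx_plus_1, pos_images_max_num, batch_size,
                            3, image_size, image_size)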
            loss.backward()
            optimizer.step()

            distributed.all_reduce(loss_manual_mining.data)
            losses_manual_mining += loss_manual_mining.data.item() / len(
                args.devices)

            if local_rank == 0:
                total = i + 1
                tbar.set_description('epoch: %d, loss manual mining: %.3f'
                                     % (epoch + 1, losses_manual_mining / total))
                # tbar.set_description('epoch: %d, loss1: %.3f, loss2: %.3f'
                #                      % (epoch + 1, losses_manual_mining / (i + 1),
                #                         losses_hard_mining / (i + 1)))

        scheduler.step(epoch)

        if local_rank == 0:
            checkpoints = {
                'score': score_model.module.state_dict(),
                'item': image_encoder.state_dict(),
                'optimizer': optimizer.state_dict()
            }
            torch.save(
                checkpoints,
                os.path.join(checkpoints_dir,
                             'model-epoch{}.pth'.format(epoch + 1)))
            # score_model.eval()
            with torch.no_grad():
                valid(epoch + 1, checkpoints_dir, use_bert=use_bert)

        distributed.barrier()
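
# A minimal sketch (illustration only, assuming the default process group is
# already initialised) of the cross-device loss averaging used above:
# `all_reduce` sums the tensor in place on every rank, so dividing by the
# world size gives the mean loss per device.
def _mean_loss_across_ranks_sketch(loss):
    import torch.distributed as distributed
    reduced = loss.detach().clone()
    distributed.all_reduce(reduced)  # SUM over all ranks (default op)
    return reduced.item() / distributed.get_world_size()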
def main():
    # ignore warnings
    # warnings.simplefilter('ignore')

    args = get_arguments()
    SETTING = Dict(yaml.safe_load(
        open(os.path.join('arguments', args.arg + '.yaml'), encoding='utf8')))
    print(args)
    args.device = list(map(str, args.device))
    os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(args.device)

    # image transformer
    train_transform = transforms.Compose([
        transforms.Resize(SETTING.imsize_pre),
        transforms.RandomCrop(SETTING.imsize),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])
    val_transform = transforms.Compose([
        transforms.Resize(SETTING.imsize_pre),
        transforms.CenterCrop(SETTING.imsize),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])

    # data load
    if args.dataset == 'coco':
        train_dset = CocoDset(root=SETTING.root_path, img_dir='train2017',
                              ann_dir='annotations/captions_train2017.json',
                              transform=train_transform)
        val_dset = CocoDset(root=SETTING.root_path, img_dir='val2017',
                            ann_dir='annotations/captions_val2017.json',
                            transform=val_transform)
    train_loader = DataLoader(train_dset, batch_size=SETTING.batch_size,
                              shuffle=True, num_workers=SETTING.n_cpu,
                              collate_fn=collater)
    val_loader = DataLoader(val_dset, batch_size=SETTING.batch_size,
                            shuffle=False, num_workers=SETTING.n_cpu,
                            collate_fn=collater)

    # setup vocab dict
    vocab = Vocabulary(max_len=SETTING.max_len)
    vocab.load_vocab(args.vocab_path)

    # setup encoder
    imenc = ImageEncoder(SETTING.out_size, SETTING.cnn_type)
    capenc = CaptionEncoder(len(vocab), SETTING.emb_size, SETTING.out_size,
                            SETTING.rnn_type, vocab.padidx)
    device = torch.device("cuda"
                          if torch.cuda.is_available() and not args.no_cuda
                          else "cpu")
    imenc = imenc.to(device)
    capenc = capenc.to(device)

    # learning rate
    cfgs = [{'params': imenc.fc.parameters(), 'lr': float(SETTING.lr_cnn)},
            {'params': capenc.parameters(), 'lr': float(SETTING.lr_rnn)}]

    # optimizer
    if SETTING.optimizer == 'SGD':
        optimizer = optim.SGD(cfgs, momentum=SETTING.momentum,
                              weight_decay=SETTING.weight_decay)
    elif SETTING.optimizer == 'Adam':
        optimizer = optim.Adam(cfgs, betas=(SETTING.beta1, SETTING.beta2),
                               weight_decay=SETTING.weight_decay)
    elif SETTING.optimizer == 'RMSprop':
        optimizer = optim.RMSprop(cfgs, alpha=SETTING.alpha,
                                  weight_decay=SETTING.weight_decay)
    if SETTING.scheduler == 'Plateau':
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, mode='max', factor=SETTING.dampen_factor,
            patience=SETTING.patience, verbose=True)
    elif SETTING.scheduler == 'Step':
        scheduler = optim.lr_scheduler.StepLR(
            optimizer, step_size=SETTING.patience,
            gamma=SETTING.dampen_factor)

    # loss
    lossfunc = PairwiseRankingLoss(margin=SETTING.margin,
                                   method=SETTING.method,
                                   improved=args.improved,
                                   intra=SETTING.intra,
                                   lamb=SETTING.imp_weight)

    # if start from checkpoint
    if args.checkpoint is not None:
        print("loading model and optimizer checkpoint from {} ..."
              .format(args.checkpoint), flush=True)
        ckpt = torch.load(args.checkpoint)
        imenc.load_state_dict(ckpt["encoder_state"])
        capenc.load_state_dict(ckpt["decoder_state"])
        optimizer.load_state_dict(ckpt["optimizer_state"])
        if SETTING.scheduler != 'None':
            scheduler.load_state_dict(ckpt["scheduler_state"])
        offset = ckpt["epoch"]
        data = ckpt["stats"]
        bestscore = 0
        for rank in [1, 5, 10, 20]:
            bestscore += (data["i2c_recall@{}".format(rank)]
                          + data["c2i_recall@{}".format(rank)])
        bestscore = int(bestscore)
    # start new training
    else:
        offset = 0
        bestscore = -1

    if args.dataparallel:
        print("Using Multiple GPU . . . ")
        imenc = nn.DataParallel(imenc)
        capenc = nn.DataParallel(capenc)

    metrics = {}
    es_cnt = 0

    # training
    assert offset < SETTING.max_epochs
    for ep in range(offset, SETTING.max_epochs):
        epoch = ep + 1
        # unfreeze cnn parameters
        if epoch == SETTING.freeze_epoch:
            if args.dataparallel:
                optimizer.add_param_group(
                    {'params': imenc.module.cnn.parameters(),
                     'lr': float(SETTING.lr_cnn)})
            else:
                optimizer.add_param_group(
                    {'params': imenc.cnn.parameters(),
                     'lr': float(SETTING.lr_cnn)})

        # train (1 epoch)
        train(epoch, train_loader, imenc, capenc, optimizer, lossfunc,
              vocab, args, SETTING)
        # validate
        data = validate(epoch, val_loader, imenc, capenc, vocab, args, SETTING)
        totalscore = 0
        for rank in [1, 5, 10, 20]:
            totalscore += (data["i2c_recall@{}".format(rank)]
                           + data["c2i_recall@{}".format(rank)])
        totalscore = int(totalscore)

        # scheduler update
        if SETTING.scheduler == 'Plateau':
            scheduler.step(totalscore)
        if SETTING.scheduler == 'Step':
            scheduler.step()

        # update checkpoint
        if args.dataparallel:
            ckpt = {
                "stats": data,
                "epoch": epoch,
                "encoder_state": imenc.module.state_dict(),
                "decoder_state": capenc.module.state_dict(),
                "optimizer_state": optimizer.state_dict()
            }
        else:
            ckpt = {
                "stats": data,
                "epoch": epoch,
                "encoder_state": imenc.state_dict(),
                "decoder_state": capenc.state_dict(),
                "optimizer_state": optimizer.state_dict()
            }
        if SETTING.scheduler != 'None':
            ckpt['scheduler_state'] = scheduler.state_dict()

        # make savedir
        savedir = os.path.join("models", args.arg)
        if not os.path.exists(savedir):
            os.makedirs(savedir)

        for k, v in data.items():
            if k not in metrics.keys():
                metrics[k] = [v]
            else:
                metrics[k].append(v)

        # save checkpoint
        savepath = os.path.join(
            savedir, "epoch_{:04d}_score_{:03d}.ckpt".format(epoch, totalscore))
        if int(totalscore) > int(bestscore):
            print("score: {:03d}, saving model and optimizer checkpoint to {} ..."
                  .format(totalscore, savepath), flush=True)
            bestscore = totalscore
            torch.save(ckpt, savepath)
            es_cnt = 0
        else:
            print("score: {:03d}, no improvement from best score of {:03d}, "
                  "not saving".format(totalscore, bestscore), flush=True)
            es_cnt += 1
            # early stopping
            if es_cnt == SETTING.es_cnt:
                print("early stopping at epoch {} because of no improvement "
                      "for {} epochs".format(epoch, SETTING.es_cnt))
                break

        print("done for epoch {:04d}".format(epoch), flush=True)

    visualize(metrics, args, SETTING)
    print("complete training")
def main(args):
    # Create model directory
    if not os.path.exists(args.model_path):
        os.makedirs(args.model_path)

    # Image preprocessing
    train_transform = transforms.Compose([
        transforms.RandomCrop(args.image_size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406),
                             (0.229, 0.224, 0.225))])
    # val_transform = transforms.Compose([
    #     transforms.Resize(args.image_size, interpolation=Image.LANCZOS),
    #     transforms.ToTensor(),
    #     transforms.Normalize((0.485, 0.456, 0.406),
    #                          (0.229, 0.224, 0.225))])

    # Load vocabulary wrapper.
    with open(args.vocab_path, 'rb') as f:
        vocab = pickle.load(f)

    # Build data loader
    train_data_loader = get_loader(args.train_image_dir, args.train_vqa_path,
                                   args.ix_to_ans_file,
                                   args.train_description_file, vocab,
                                   train_transform, args.batch_size,
                                   shuffle=True, num_workers=args.num_workers)
    # val_data_loader = get_loader(args.val_image_dir, args.val_vqa_path,
    #                              args.ix_to_ans_file, vocab, val_transform,
    #                              args.batch_size, shuffle=False,
    #                              num_workers=args.num_workers)

    image_encoder = ImageEncoder(args.img_feature_size)
    question_emb_size = 1024
    # description_emb_size = 512
    no_ans = 1000
    question_encoder = BertEncoder(question_emb_size)
    # ques_description_encoder = BertEncoder(description_emb_size)
    # vqa_decoder = VQA_Model(args.img_feature_size, question_emb_size,
    #                         description_emb_size, no_ans)
    vqa_decoder = VQA_Model(args.img_feature_size, question_emb_size, no_ans)

    pretrained_epoch = 0
    if args.pretrained_epoch > 0:
        pretrained_epoch = args.pretrained_epoch
        image_encoder.load_state_dict(torch.load(
            './models/image_encoder-' + str(pretrained_epoch) + '.pkl'))
        question_encoder.load_state_dict(torch.load(
            './models/question_encoder-' + str(pretrained_epoch) + '.pkl'))
        # ques_description_encoder.load_state_dict(torch.load(
        #     './models/ques_description_encoder-' + str(pretrained_epoch) + '.pkl'))
        vqa_decoder.load_state_dict(torch.load(
            './models/vqa_decoder-' + str(pretrained_epoch) + '.pkl'))

    if torch.cuda.is_available():
        image_encoder.cuda()
        question_encoder.cuda()
        # ques_description_encoder.cuda()
        vqa_decoder.cuda()
        print("Cuda is enabled...")

    criterion = nn.CrossEntropyLoss()
    # params = image_encoder.get_params() + question_encoder.get_params() \
    #     + ques_description_encoder.get_params() + vqa_decoder.get_params()
    params = (list(image_encoder.parameters())
              + list(question_encoder.parameters())
              + list(vqa_decoder.parameters()))
    # print("params: ", params)
    optimizer = torch.optim.Adam(params, lr=args.learning_rate,
                                 weight_decay=args.weight_decay)

    total_train_step = len(train_data_loader)

    min_avg_loss = float("inf")
    overfit_warn = 0

    for epoch in range(args.num_epochs):
        if epoch < pretrained_epoch:
            continue

        image_encoder.train()
        question_encoder.train()
        # ques_description_encoder.train()
        vqa_decoder.train()
        avg_loss = 0.0
        avg_acc = 0.0

        for bi, (question_arr, image_vqa, target_answer,
                 answer_str) in enumerate(train_data_loader):
            loss = 0
            image_encoder.zero_grad()
            question_encoder.zero_grad()
            # ques_description_encoder.zero_grad()
            vqa_decoder.zero_grad()

            images = to_var(torch.stack(image_vqa))
            question_arr = to_var(torch.stack(question_arr))
            # ques_desc_arr = to_var(torch.stack(ques_desc_arr))
            target_answer = to_var(torch.tensor(target_answer))

            image_emb = image_encoder(images)
            question_emb = question_encoder(question_arr)
            # ques_desc_emb = ques_description_encoder(ques_desc_arr)
            # output = vqa_decoder(image_emb, question_emb, ques_desc_emb)
            output = vqa_decoder(image_emb, question_emb)

            loss = criterion(output, target_answer)
            _, prediction = torch.max(output, 1)
            no_correct_prediction = prediction.eq(target_answer).sum().item()
            accuracy = no_correct_prediction * 100 / args.batch_size

            ####
            target_answer_no = target_answer.tolist()
            prediction_no = prediction.tolist()
            ####

            loss_num = loss.item()
            avg_loss += loss.item()
            avg_acc += no_correct_prediction
            # loss /= (args.batch_size)
            loss.backward()
            optimizer.step()

            # Print log info
            if bi % args.log_step == 0:
                print('Epoch [%d/%d], Train Step [%d/%d], Loss: %.4f, Acc: %.4f'
                      % (epoch + 1, args.num_epochs, bi, total_train_step,
                         loss.item(), accuracy))

        avg_loss /= (args.batch_size * total_train_step)
        avg_acc /= (args.batch_size * total_train_step)
        print('Epoch [%d/%d], Average Train Loss: %.4f, Average Train acc: %.4f'
              % (epoch + 1, args.num_epochs, avg_loss, avg_acc))

        # Save the models
        torch.save(image_encoder.state_dict(),
                   os.path.join(args.model_path,
                                'image_encoder-%d.pkl' % (epoch + 1)))
        torch.save(question_encoder.state_dict(),
                   os.path.join(args.model_path,
                                'question_encoder-%d.pkl' % (epoch + 1)))
        # torch.save(ques_description_encoder.state_dict(),
        #            os.path.join(args.model_path,
        #                         'ques_description_encoder-%d.pkl' % (epoch + 1)))
        torch.save(vqa_decoder.state_dict(),
                   os.path.join(args.model_path,
                                'vqa_decoder-%d.pkl' % (epoch + 1)))

        overfit_warn = overfit_warn + 1 if (min_avg_loss < avg_loss) else 0
        min_avg_loss = min(min_avg_loss, avg_loss)

        lossFileName = "result/result_" + str(epoch) + ".txt"
        test_fd = open(lossFileName, 'w')
        test_fd.write('Epoch: ' + str(epoch) + ' avg_loss: ' + str(avg_loss)
                      + " avg_acc: " + str(avg_acc) + "\n")
        test_fd.close()

        if overfit_warn >= 5:
            print("terminated as overfitted")
            break
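
# A minimal sketch (illustration only) of the batch accuracy computation in
# the training loop above: take the argmax over the answer logits and count
# matches against the target answers.
def _batch_accuracy_sketch(output, target_answer):
    import torch
    _, prediction = torch.max(output, 1)
    no_correct = prediction.eq(target_answer).sum().item()
    return no_correct * 100.0 / target_answer.size(0)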
word_embedding = word_embedding.to(device)
image_encoder = image_encoder.to(device)
image_decoder = image_decoder.to(device)

""" 11) text file logging """
log_filename = f"logs/training_log-BIGDATASET-BIGMODEL.log"
logging.basicConfig(filename=log_filename, level=logging.DEBUG)

EPOCHS = 100
START_EPOCH = 0

print("Beginning Training")
for epoch in range(START_EPOCH, EPOCHS):
    # TRAIN
    results = run_epoch(epoch, train_loader, image_encoder, image_decoder,
                        word_embedding, loss_fn, optim, device, train=True)
    print(results.to_string(-1))
    logging.debug(results.to_string(-1))

    # VAL
    results = run_epoch(epoch, val_loader, image_encoder, image_decoder,
                        word_embedding, loss_fn, optim, device, train=False)
    print('Val ' + results.to_string(-1))
    logging.debug('Val ' + results.to_string(-1))

    # SAVE
    torch.save(word_embedding.state_dict(),
               f"checkpoints/BIGMODEL-BIGDATASET-weights-embedding-epoch-{epoch}.pt")
    torch.save(image_encoder.state_dict(),
               f"checkpoints/BIGMODEL-BIGDATASET-weights-encoder-epoch-{epoch}.pt")
    torch.save(image_decoder.state_dict(),
               f"checkpoints/BIGMODEL-BIGDATASET-weights-decoder-epoch-{epoch}.pt")
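
# A minimal, hypothetical sketch (not called anywhere) of how a run could be
# resumed from the per-epoch checkpoints written above, by reloading the
# three state dicts and returning the next epoch to start from.
def _resume_from_epoch_sketch(word_embedding, image_encoder, image_decoder,
                              epoch):
    import torch
    prefix = "checkpoints/BIGMODEL-BIGDATASET-weights"
    word_embedding.load_state_dict(
        torch.load(f"{prefix}-embedding-epoch-{epoch}.pt"))
    image_encoder.load_state_dict(
        torch.load(f"{prefix}-encoder-epoch-{epoch}.pt"))
    image_decoder.load_state_dict(
        torch.load(f"{prefix}-decoder-epoch-{epoch}.pt"))
    return epoch + 1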
def intention_train(context_text_encoder: TextEncoder,
                    context_image_encoder: ImageEncoder,
                    context_encoder: ContextEncoder,
                    train_dataset: Dataset,
                    valid_dataset: Dataset,
                    test_dataset: Dataset,
                    model_file: str):
    """Intention train.

    Args:
        context_text_encoder (TextEncoder): Context text encoder.
        context_image_encoder (ImageEncoder): Context image encoder.
        context_encoder (ContextEncoder): Context encoder.
        train_dataset (Dataset): Train dataset.
        valid_dataset (Dataset): Valid dataset.
        test_dataset (Dataset): Test dataset.
        model_file (str): Saved model file.

    """

    # Data loader.
    train_data_loader = DataLoader(
        dataset=train_dataset,
        batch_size=IntentionTrainConfig.batch_size,
        shuffle=True,
        num_workers=IntentionTrainConfig.num_data_loader_workers)

    # Model.
    intention_config = IntentionConfig()
    intention = Intention(intention_config).to(GlobalConfig.device)

    # Model parameters.
    params = list(
        chain.from_iterable([
            list(model.parameters()) for model in [
                context_text_encoder, context_image_encoder,
                context_encoder, intention
            ]
        ]))
    optimizer = Adam(params, lr=IntentionTrainConfig.learning_rate)
    epoch_id = 0
    min_valid_loss = None

    # Load saved state.
    if isfile(model_file):
        state = torch.load(model_file)
        intention.load_state_dict(state['intention'])
        optimizer.load_state_dict(state['optimizer'])
        epoch_id = state['epoch_id']
        min_valid_loss = state['min_valid_loss']

    # Loss.
    sum_loss = 0
    bad_loss_cnt = 0

    # Switch to train mode.
    context_text_encoder.train()
    context_image_encoder.train()
    context_encoder.train()
    intention.train()

    finished = False

    for epoch_id in range(epoch_id, IntentionTrainConfig.num_iterations):
        for batch_id, train_data in enumerate(train_data_loader):
            # Sets gradients to 0.
            optimizer.zero_grad()

            texts, text_lengths, images, utter_types = train_data
            # Sizes:
            # texts: (batch_size, dialog_context_size + 1, dialog_text_max_len)
            # text_lengths: (batch_size, dialog_context_size + 1)
            # images: (batch_size, dialog_context_size + 1,
            #          pos_images_max_num, 3, image_size, image_size)
            # utter_types: (batch_size, )

            # To device.
            texts = texts.to(GlobalConfig.device)
            text_lengths = text_lengths.to(GlobalConfig.device)
            images = images.to(GlobalConfig.device)
            utter_types = utter_types.to(GlobalConfig.device)

            texts.transpose_(0, 1)
            # (dialog_context_size + 1, batch_size, dialog_text_max_len)
            text_lengths.transpose_(0, 1)
            # (dialog_context_size + 1, batch_size)
            images.transpose_(0, 1)
            images.transpose_(1, 2)
            # (dialog_context_size + 1, pos_images_max_num, batch_size, 3,
            #  image_size, image_size)

            # Encode context.
            context, _ = encode_context(context_text_encoder,
                                        context_image_encoder,
                                        context_encoder,
                                        texts, text_lengths, images)
            # (batch_size, context_vector_size)

            intent_prob = intention(context)
            # (batch_size, utterance_type_size)

            loss = nll_loss(intent_prob, utter_types)
            sum_loss += loss

            loss.backward()
            optimizer.step()

            # Print loss every `TrainConfig.print_freq` batches.
            if (batch_id + 1) % IntentionTrainConfig.print_freq == 0:
                cur_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
                sum_loss /= IntentionTrainConfig.print_freq
                print('epoch: {} \tbatch: {} \tloss: {} \ttime: {}'.format(
                    epoch_id + 1, batch_id + 1, sum_loss, cur_time))
                sum_loss = 0

            # Valid every `TrainConfig.valid_freq` batches.
            if (batch_id + 1) % IntentionTrainConfig.valid_freq == 0:
                valid_loss, accuracy = intention_valid(context_text_encoder,
                                                       context_image_encoder,
                                                       context_encoder,
                                                       intention,
                                                       valid_dataset)
                cur_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
                print('valid_loss: {} \taccuracy: {} \ttime: {}'.format(
                    valid_loss, accuracy, cur_time))

                # Save current best model.
                if min_valid_loss is None or valid_loss < min_valid_loss:
                    min_valid_loss = valid_loss
                    bad_loss_cnt = 0
                    save_dict = {
                        'task': INTENTION_TASK,
                        'epoch_id': epoch_id,
                        'min_valid_loss': min_valid_loss,
                        'optimizer': optimizer.state_dict(),
                        'context_text_encoder':
                            context_text_encoder.state_dict(),
                        'context_image_encoder':
                            context_image_encoder.state_dict(),
                        'context_encoder': context_encoder.state_dict(),
                        'intention': intention.state_dict()
                    }
                    torch.save(save_dict, model_file)
                    print('Best model saved.')
                else:
                    bad_loss_cnt += 1
                    if bad_loss_cnt > IntentionTrainConfig.patience:
                        intention_test(context_text_encoder,
                                       context_image_encoder,
                                       context_encoder, intention,
                                       test_dataset)
                        finished = True
                        break
        if finished:
            break
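
# A minimal sketch (illustration only, with hypothetical sizes) of the
# intention loss above: `nll_loss` expects log-probabilities, so a model
# like `Intention` would typically end with a log_softmax over the
# utterance types (an assumption about its internals, not shown here).
def _intention_loss_sketch():
    import torch
    import torch.nn.functional as F

    batch_size, utterance_type_size = 4, 5
    logits = torch.randn(batch_size, utterance_type_size)
    intent_prob = F.log_softmax(logits, dim=1)  # (batch_size, utterance_type_size)
    utter_types = torch.randint(0, utterance_type_size, (batch_size,))
    return F.nll_loss(intent_prob, utter_types)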