def train(self, dataset, dataset_eval=None):
    super_time_start = time.time()

    # Obtain needed information
    data_size = dataset.data_size
    token_size = dataset.token_size
    ans_size = dataset.ans_size
    pretrained_emb = dataset.pretrained_emb

    # Define the MCAN model
    net = Net(self.__C, pretrained_emb, token_size, ans_size)
    net.cuda()
    net.train()

    # Define the binary cross entropy loss
    loss_fn = torch.nn.BCELoss(reduction='sum').cuda()

    # Load checkpoint if resuming training
    if self.__C.RESUME:
        print('========== Resume training')

        if self.__C.CKPT_PATH is not None:
            print('Warning: you are now using CKPT_PATH args, '
                  'CKPT_VERSION and CKPT_EPOCH will not work')
            path = self.__C.CKPT_PATH
        else:
            path = self.__C.CKPTS_PATH + 'ckpt_' + self.__C.CKPT_VERSION + \
                   '/epoch' + str(self.__C.CKPT_EPOCH) + '.pkl'

        # Load the network parameters
        print('========== Loading ckpt {}'.format(path))
        ckpt = torch.load(path)
        print('========== Finished!')
        net.load_state_dict(ckpt['state_dict'])

        # Load the optimizer parameters and restore the schedule step counter
        optim = get_optim(self.__C, net, data_size, ckpt['lr_base'])
        optim._step = int(data_size / self.__C.BATCH_SIZE * self.__C.CKPT_EPOCH)
        optim.optimizer.load_state_dict(ckpt['optimizer'])

        start_epoch = self.__C.CKPT_EPOCH
    else:
        if ('ckpt_' + self.__C.VERSION) in os.listdir(self.__C.CKPTS_PATH):
            shutil.rmtree(self.__C.CKPTS_PATH + 'ckpt_' + self.__C.VERSION)
        os.mkdir(self.__C.CKPTS_PATH + 'ckpt_' + self.__C.VERSION)

        optim = get_optim(self.__C, net, data_size)
        start_epoch = 0

    loss_sum = 0
    named_params = list(net.named_parameters())
    grad_norm = np.zeros(len(named_params))

    # Define the multi-thread dataloader
    dataloader = Data.DataLoader(
        dataset,
        batch_size=self.__C.BATCH_SIZE,
        shuffle=False,
        num_workers=self.__C.NUM_WORKERS,
        pin_memory=self.__C.PIN_MEM,
        drop_last=True
    )

    # Training loop
    for epoch in range(start_epoch, self.__C.MAX_EPOCH):
        epoch_finish = epoch + 1

        # Save log information
        logfile = open(
            self.__C.LOG_PATH + 'log_run_' + self.__C.VERSION + '.txt',
            'a+'
        )
        logfile.write(
            'nowTime: ' +
            datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S') +
            '\n'
        )
        logfile.close()

        # Learning rate decay
        if epoch in self.__C.LR_DECAY_LIST:
            adjust_lr(optim, self.__C.LR_DECAY_R)

        # Externally shuffle the answer list
        shuffle_list(dataset.ans_list)

        time_start = time.time()

        # Iteration
        for step, (img_feat_iter, ques_ix_iter, ans_iter) in enumerate(dataloader):
            optim.zero_grad()

            img_feat_iter = img_feat_iter.cuda()
            ques_ix_iter = ques_ix_iter.cuda()
            ans_iter = ans_iter.cuda()

            pred = net(img_feat_iter, ques_ix_iter)

            loss = loss_fn(pred, ans_iter)
            loss.backward()
            loss_sum += loss.cpu().data.numpy()

            if self.__C.VERBOSE:
                # Print the loss at every step
                if dataset_eval is not None:
                    mode_str = self.__C.SPLIT['train'] + '->' + self.__C.SPLIT['val']
                else:
                    mode_str = self.__C.SPLIT['train'] + '->' + self.__C.SPLIT['test']

                print("\r[version %s][epoch %2d][step %4d/%4d][%s] loss: %.4f, lr: %.2e" % (
                    self.__C.VERSION,
                    epoch_finish,
                    step,
                    int(data_size / self.__C.BATCH_SIZE),
                    mode_str,
                    loss.cpu().data.numpy() / self.__C.BATCH_SIZE,
                    optim._rate
                ), end=' ')

            # Accumulate the per-parameter gradient norms
            for i, (name, param) in enumerate(named_params):
                if param.grad is not None:
                    norm_v = torch.norm(param.grad).cpu().data.numpy()
                else:
                    norm_v = 0
                grad_norm[i] += norm_v

            optim.step()

        time_end = time.time()
        print('========== Finished in {}s'.format(int(time_end - time_start)))

        # Save checkpoint
        state = {
            'state_dict': net.state_dict(),
            'optimizer': optim.optimizer.state_dict(),
            'lr_base': optim.lr_base
        }
        torch.save(
            state,
            self.__C.CKPTS_PATH + 'ckpt_' + self.__C.VERSION +
            '/epoch' + str(epoch_finish) + '.pkl'
        )

        # Logging
        logfile = open(
            self.__C.LOG_PATH + 'log_run_' + self.__C.VERSION + '.txt',
            'a+'
        )
        logfile.write(
            'epoch = ' + str(epoch_finish) +
            '  loss = ' + str(loss_sum / data_size) + '\n' +
            'lr = ' + str(optim._rate) + '\n\n'
        )
        logfile.close()

        # Eval after every epoch
        if dataset_eval is not None:
            self.eval(
                dataset_eval,
                state_dict=net.state_dict(),
                valid=True
            )

        loss_sum = 0
        grad_norm = np.zeros(len(named_params))

    print('========== Total Training time is {}s'.format(
        int(time.time() - super_time_start)))
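
# Note on the resume path in train() above: get_optim appears to return a
# scheduler-style wrapper (the code reads and writes its _step, _rate and
# lr_base attributes), so restoring only optimizer.state_dict() would restart
# the learning-rate schedule from step 0. Rebuilding _step from the epoch
# count keeps the schedule continuous. A worked example of the arithmetic
# (the numbers are illustrative, not taken from this codebase): with
# data_size = 443757, BATCH_SIZE = 64 and CKPT_EPOCH = 3,
#
#   optim._step = int(443757 / 64 * 3) = 20801
#
# which closely matches the 3 * int(443757 / 64) = 20799 batches actually
# consumed with drop_last=True, so the schedule resumes near where it stopped.
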
def eval(self, dataset, state_dict=None, valid=False):

    # Load parameters
    if self.__C.CKPT_PATH is not None:
        print('Warning: you are now using CKPT_PATH args, '
              'CKPT_VERSION and CKPT_EPOCH will not work')
        path = self.__C.CKPT_PATH
    else:
        path = self.__C.CKPTS_PATH + 'ckpt_' + self.__C.CKPT_VERSION + \
               '/epoch' + str(self.__C.CKPT_EPOCH) + '.pkl'

    val_ckpt_flag = False
    if state_dict is None:
        val_ckpt_flag = True
        print('========== Loading ckpt {}'.format(path))
        state_dict = torch.load(path)['state_dict']
        print('========== Finished!')

    # Store the prediction list
    qid_list = [ques['question_id'] for ques in dataset.ques_list]
    ans_ix_list = []
    pred_list = []

    data_size = dataset.data_size    # 214354
    token_size = dataset.token_size  # 20572
    ans_size = dataset.ans_size      # 3129
    pretrained_emb = dataset.pretrained_emb

    net = Net(self.__C, pretrained_emb, token_size, ans_size)
    net.cuda()
    net.eval()
    net.load_state_dict(state_dict)

    dataloader = Data.DataLoader(
        dataset,
        batch_size=self.__C.EVAL_BATCH_SIZE,
        shuffle=False,
        num_workers=self.__C.NUM_WORKERS,
        pin_memory=True
    )

    for step, (img_feat_iter, ques_ix_iter, ans_iter) in enumerate(dataloader):
        print("\rEvaluation: [step %4d/%4d]" % (
            step,
            int(data_size / self.__C.EVAL_BATCH_SIZE)
        ), end=' ')

        img_feat_iter = img_feat_iter.cuda()
        ques_ix_iter = ques_ix_iter.cuda()

        pred = net(img_feat_iter, ques_ix_iter)
        pred_np = pred.cpu().data.numpy()
        pred_argmax = np.argmax(pred_np, axis=1)

        # Save the answer index; pad the last (smaller) batch with -1 so
        # every chunk stacks to the same shape
        if pred_argmax.shape[0] != self.__C.EVAL_BATCH_SIZE:
            pred_argmax = np.pad(
                pred_argmax,
                (0, self.__C.EVAL_BATCH_SIZE - pred_argmax.shape[0]),
                mode='constant',
                constant_values=-1
            )
        ans_ix_list.append(pred_argmax)

        # Save the whole prediction vector
        if self.__C.TEST_SAVE_PRED:
            if pred_np.shape[0] != self.__C.EVAL_BATCH_SIZE:
                pred_np = np.pad(
                    pred_np,
                    ((0, self.__C.EVAL_BATCH_SIZE - pred_np.shape[0]), (0, 0)),
                    mode='constant',
                    constant_values=-1
                )
            pred_list.append(pred_np)

    print('')
    ans_ix_list = np.array(ans_ix_list).reshape(-1)

    # ix_to_ans is loaded from json, so its keys are strings
    result = [{
        'answer': dataset.ix_to_ans[str(ans_ix_list[qix])],
        'question_id': int(qid_list[qix])
    } for qix in range(len(qid_list))]

    # Write the results to the result file
    if valid:
        if val_ckpt_flag:
            result_eval_file = \
                self.__C.CACHE_PATH + \
                'result_run_' + self.__C.CKPT_VERSION + '.json'
        else:
            result_eval_file = \
                self.__C.CACHE_PATH + \
                'result_run_' + self.__C.VERSION + '.json'
    else:
        if self.__C.CKPT_PATH is not None:
            result_eval_file = \
                self.__C.RESULT_PATH + \
                'result_run_' + self.__C.CKPT_VERSION + '.json'
        else:
            result_eval_file = \
                self.__C.RESULT_PATH + \
                'result_run_' + self.__C.CKPT_VERSION + \
                '_epoch' + str(self.__C.CKPT_EPOCH) + '.json'

    print('Save the result to file: {}'.format(result_eval_file))
    json.dump(result, open(result_eval_file, 'w'))

    # Save the whole prediction vector
    if self.__C.TEST_SAVE_PRED:
        if self.__C.CKPT_PATH is not None:
            ensemble_file = \
                self.__C.PRED_PATH + \
                'result_run_' + self.__C.CKPT_VERSION + '.json'
        else:
            ensemble_file = \
                self.__C.PRED_PATH + \
                'result_run_' + self.__C.CKPT_VERSION + \
                '_epoch' + str(self.__C.CKPT_EPOCH) + '.json'

        print('Save the prediction vector to file: {}'.format(ensemble_file))

        pred_list = np.array(pred_list).reshape(-1, ans_size)
        result_pred = [{
            'pred': pred_list[qix],
            'question_id': int(qid_list[qix])
        } for qix in range(len(qid_list))]

        pickle.dump(result_pred, open(ensemble_file, 'wb+'), protocol=-1)

    # Run the validation script
    if valid:
        # Create the vqa and vqaRes objects
        ques_file_path = self.__C.QUESTION_PATH['val']
        ans_file_path = self.__C.ANSWER_PATH['val']

        vqa = VQA(ans_file_path, ques_file_path)
        vqaRes = vqa.loadRes(result_eval_file, ques_file_path)

        # Create the vqaEval object from vqa and vqaRes; n is the precision
        # of the accuracy (number of places after the decimal), default is 2
        vqaEval = VQAEval(vqa, vqaRes, n=2)

        # Evaluate results. If you have a list of question ids on which you
        # would like to evaluate your results, pass it as a list to the
        # function below; by default it uses all the question ids in the
        # annotation file.
        vqaEval.evaluate()

        # Print accuracies
        print("\n")
        print("Overall Accuracy is: %.02f\n" % (vqaEval.accuracy['overall']))
        # print("Per Question Type Accuracy is the following:")
        # for quesType in vqaEval.accuracy['perQuestionType']:
        #     print("%s : %.02f" % (quesType, vqaEval.accuracy['perQuestionType'][quesType]))
        # print("\n")
        print("Per Answer Type Accuracy is the following:")
        for ansType in vqaEval.accuracy['perAnswerType']:
            print("%s : %.02f" % (ansType, vqaEval.accuracy['perAnswerType'][ansType]))
        print("\n")

        # Append the accuracies to this run's log file
        if val_ckpt_flag:
            log_path = self.__C.LOG_PATH + 'log_run_' + self.__C.CKPT_VERSION + '.txt'
        else:
            log_path = self.__C.LOG_PATH + 'log_run_' + self.__C.VERSION + '.txt'

        print('Write to log file: {}'.format(log_path))
        logfile = open(log_path, 'a+')
        logfile.write("Overall Accuracy is: %.02f\n" % (vqaEval.accuracy['overall']))
        for ansType in vqaEval.accuracy['perAnswerType']:
            logfile.write("%s : %.02f " % (ansType, vqaEval.accuracy['perAnswerType'][ansType]))
        logfile.write("\n\n")
        logfile.close()
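
# Note on the batch padding in eval() above: each per-batch argmax chunk is
# padded up to EVAL_BATCH_SIZE with -1 so that np.array(ans_ix_list) stacks
# into a rectangular array before reshape(-1); the filler entries sit past
# index len(qid_list) - 1 and are never read when `result` is built. A
# minimal standalone sketch of the same pattern (the names and values here
# are illustrative, not part of this codebase):
#
#   import numpy as np
#   batch = 3
#   chunks = [np.array([5, 2, 7]), np.array([1])]   # last batch is short
#   chunks[-1] = np.pad(chunks[-1], (0, batch - len(chunks[-1])),
#                       mode='constant', constant_values=-1)
#   flat = np.array(chunks).reshape(-1)             # [5, 2, 7, 1, -1, -1]
#   answers = flat[:4]                              # padding is dropped
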