def eval(self, dataset, state_dict=None, valid=False):

    # Load parameters
    if self.__C.CKPT_PATH is not None:
        print('Warning: you are now using CKPT_PATH args, '
              'CKPT_VERSION and CKPT_EPOCH will not work')
        path = self.__C.CKPT_PATH
    else:
        path = self.__C.CKPTS_PATH + \
               'ckpt_' + self.__C.CKPT_VERSION + \
               '/epoch' + str(self.__C.CKPT_EPOCH) + '.pkl'

    val_ckpt_flag = False
    if state_dict is None:
        val_ckpt_flag = True
        print('Loading ckpt {}'.format(path))
        state_dict = torch.load(path)['state_dict']
        print('Finish!')

    # Store the prediction list
    qid_list = [ques['question_id'] for ques in dataset.ques_list]
    ans_ix_list = []
    pred_list = []

    data_size = dataset.data_size
    token_size = dataset.token_size
    ans_size = dataset.ans_size
    pretrained_emb = dataset.pretrained_emb

    net = Net(self.__C, pretrained_emb, token_size, ans_size)
    net.cuda()
    net.eval()

    if self.__C.N_GPU > 1:
        net = nn.DataParallel(net, device_ids=self.__C.DEVICES)

    net.load_state_dict(state_dict)

    dataloader = Data.DataLoader(
        dataset,
        batch_size=self.__C.EVAL_BATCH_SIZE,
        shuffle=False,
        num_workers=self.__C.NUM_WORKERS,
        pin_memory=True
    )

    for step, (img_feat_iter, ques_ix_iter, ans_iter) in enumerate(dataloader):
        print("\rEvaluation: [step %4d/%4d]" % (
            step,
            int(data_size / self.__C.EVAL_BATCH_SIZE),
        ), end=' ')

        img_feat_iter = img_feat_iter.cuda()
        ques_ix_iter = ques_ix_iter.cuda()

        pred = net(img_feat_iter, ques_ix_iter)
        # print(pred)
        pred_np = pred[0].cpu().data.numpy()
        pred_argmax = np.argmax(pred_np, axis=1)

        # Save the answer index
        if pred_argmax.shape[0] != self.__C.EVAL_BATCH_SIZE:
            pred_argmax = np.pad(
                pred_argmax,
                (0, self.__C.EVAL_BATCH_SIZE - pred_argmax.shape[0]),
                mode='constant',
                constant_values=-1
            )

        ans_ix_list.append(pred_argmax)

        # Save the whole prediction vector
        if self.__C.TEST_SAVE_PRED:
            if pred_np.shape[0] != self.__C.EVAL_BATCH_SIZE:
                pred_np = np.pad(
                    pred_np,
                    ((0, self.__C.EVAL_BATCH_SIZE - pred_np.shape[0]), (0, 0)),
                    mode='constant',
                    constant_values=-1
                )

            pred_list.append(pred_np)

    print('')
    ans_ix_list = np.array(ans_ix_list).reshape(-1)

    result = [{
        # ix_to_ans is loaded from json, so its keys are strings
        'answer': dataset.ix_to_ans[str(ans_ix_list[qix])],
        'question_id': int(qid_list[qix])
    } for qix in range(len(qid_list))]

    # Write the results to the result file
    if valid:
        if val_ckpt_flag:
            result_eval_file = \
                self.__C.CACHE_PATH + \
                'result_run_' + self.__C.CKPT_VERSION + \
                '.json'
        else:
            result_eval_file = \
                self.__C.CACHE_PATH + \
                'result_run_' + self.__C.VERSION + \
                '.json'
    else:
        if self.__C.CKPT_PATH is not None:
            result_eval_file = \
                self.__C.RESULT_PATH + \
                'result_run_' + self.__C.CKPT_VERSION + \
                '.json'
        else:
            result_eval_file = \
                self.__C.RESULT_PATH + \
                'result_run_' + self.__C.CKPT_VERSION + \
                '_epoch' + str(self.__C.CKPT_EPOCH) + \
                '.json'

    print('Save the result to file: {}'.format(result_eval_file))
    json.dump(result, open(result_eval_file, 'w'))

    # Save the whole prediction vector
    if self.__C.TEST_SAVE_PRED:
        if self.__C.CKPT_PATH is not None:
            ensemble_file = \
                self.__C.PRED_PATH + \
                'result_run_' + self.__C.CKPT_VERSION + \
                '.json'
        else:
            ensemble_file = \
                self.__C.PRED_PATH + \
                'result_run_' + self.__C.CKPT_VERSION + \
                '_epoch' + str(self.__C.CKPT_EPOCH) + \
                '.json'

        print('Save the prediction vector to file: {}'.format(ensemble_file))

        pred_list = np.array(pred_list).reshape(-1, ans_size)
        result_pred = [{
            'pred': pred_list[qix],
            'question_id': int(qid_list[qix])
        } for qix in range(len(qid_list))]

        pickle.dump(result_pred, open(ensemble_file, 'wb+'), protocol=-1)

    # Run validation script
    if valid:
        # create vqa object and vqaRes object
        ques_file_path = self.__C.QUESTION_PATH['test']
        ans_file_path = self.__C.ANSWER_PATH['test']

        vqa = VQA(ans_file_path, ques_file_path)
        vqaRes = vqa.loadRes(result_eval_file, ques_file_path)

        # create vqaEval object by taking vqa and vqaRes
        # n is the precision of accuracy (number of places after the decimal), default is 2
        vqaEval = VQAEval(vqa, vqaRes, n=2)

        # evaluate results
        """
        If you have a list of question ids on which you would like to evaluate your results,
        pass it as a list to the function below.
        By default it uses all the question ids in the annotation file.
        """
        vqaEval.evaluate()

        # print accuracies
        print("\n")
        print("Overall Accuracy is: %.02f\n" % (vqaEval.accuracy['overall']))
        # print("Per Question Type Accuracy is the following:")
        # for quesType in vqaEval.accuracy['perQuestionType']:
        #     print("%s : %.02f" % (quesType, vqaEval.accuracy['perQuestionType'][quesType]))
        # print("\n")
        print("Per Answer Type Accuracy is the following:")
        for ansType in vqaEval.accuracy['perAnswerType']:
            print("%s : %.02f" % (ansType, vqaEval.accuracy['perAnswerType'][ansType]))
        print("\n")

        # Append the accuracies to the run's log file
        if val_ckpt_flag:
            logfile_path = self.__C.LOG_PATH + 'log_run_' + self.__C.CKPT_VERSION + '.txt'
        else:
            logfile_path = self.__C.LOG_PATH + 'log_run_' + self.__C.VERSION + '.txt'

        print('Write to log file: {}'.format(logfile_path))
        logfile = open(logfile_path, 'a+')

        logfile.write("Overall Accuracy is: %.02f\n" % (vqaEval.accuracy['overall']))
        for ansType in vqaEval.accuracy['perAnswerType']:
            logfile.write("%s : %.02f " % (ansType, vqaEval.accuracy['perAnswerType'][ansType]))
        logfile.write("\n\n")
        logfile.close()
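
# Usage sketch (not from the original file): one way this eval() method might be
# driven end-to-end. The class name `Execution`, the config class `Cfgs` and the
# dataset class `DataSet` are assumptions made for illustration; substitute the
# names actually used in the repository.
#
#   __C = Cfgs()
#   __C.CKPT_VERSION = 'small_model'          # hypothetical checkpoint version
#   __C.CKPT_EPOCH = 13                       # hypothetical checkpoint epoch
#   dataset = DataSet(__C)                    # must expose ques_list, data_size, token_size,
#                                             # ans_size, pretrained_emb and ix_to_ans
#   Execution(__C).eval(dataset, valid=True)  # valid=True also runs the VQA accuracy script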
# set up file names and paths
versionType = 'v2_'          # this should be '' when using VQA v2.0 dataset
taskType = 'MultipleChoice'  # 'OpenEnded' only for v2.0. 'OpenEnded' or 'MultipleChoice' for v1.0
dataSubType = 'test'
annFile = '{}/annotations/{}.json'.format(dataDir, dataSubType)
quesFile = '{}/questions/{}.json'.format(dataDir, dataSubType)
imgDir = '{}/images/{}/'.format(dataDir, dataSubType)
resFile = './results.json'

# create vqa object and vqaRes object
vqa = VQA(annFile, quesFile)
vqaRes = vqa.loadRes(resFile, quesFile)

# create vqaEval object by taking vqa and vqaRes
# n is precision of accuracy (number of places after decimal), default is 2
vqaEval = VQAEval(vqa, vqaRes, n=2)

# evaluate results
"""
If you have a list of question ids on which you would like to evaluate your results,
pass it as a list to below function.
By default it uses all the question ids in annotation file.
"""
vqaEval.evaluate()

# print accuracies
print("\n")
print("Overall Accuracy is: %.02f\n" % (vqaEval.accuracy['overall']))
print("Per Question Type Accuracy is the following:")
for quesType in vqaEval.accuracy['perQuestionType']:
    print("%s : %.02f" % (quesType, vqaEval.accuracy['perQuestionType'][quesType]))
print("\n")
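
# A minimal follow-up sketch, assuming the standard VQA evaluation API bundled with
# the dataset tools (VQA.getQuesIds and VQAEval.evaluate(quesIds=...), as the docstring
# above suggests): evaluate only a subset of question ids, here the yes/no questions,
# instead of every question in the annotation file.
yesNoQuesIds = vqa.getQuesIds(ansTypes='yes/no')   # ids of yes/no questions only
vqaEval.evaluate(quesIds=yesNoQuesIds)             # restrict evaluation to that subset
print("Overall Accuracy (yes/no subset): %.02f" % (vqaEval.accuracy['overall']))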