Example #1
def evaluate_and_dump_predictions(pred, qids, qfile, afile, ix_ans_dict, filename):
    """
    dumps predictions to some default file
    :param pred: list of predictions, like [1, 2, 3, 2, ...]. one number for each example
    :param qids: question ids in the same order of predictions, they need to align and match
    :param qfile:
    :param afile:
    :param ix_ans_dict:
    :return:
    """
    assert len(pred) == len(qids), "Number of predictions needs to match number of question IDs"
    answers = []
    for i, val in enumerate(pred):
        qa_pair = {}
        qa_pair['question_id'] = int(qids[i])
        qa_pair['answer'] = ix_ans_dict[str(val + 1)]  # note indexing diff between python and torch
        answers.append(qa_pair)
    vqa = VQA(afile, qfile)
    with open(filename, 'w') as fod:  # text mode so json.dump can write strings
        json.dump(answers, fod)
    # VQA evaluation
    vqaRes = vqa.loadRes(filename, qfile)
    vqaEval = VQAEval(vqa, vqaRes, n=2)
    vqaEval.evaluate()
    acc = vqaEval.accuracy['overall']
    print("Overall Accuracy is: %.02f\n" % acc)
    return acc
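A hedged usage sketch for the helper above (file names, question ids, and the index-to-answer mapping are placeholders, not from the original):
# Hypothetical inputs: model argmax indices, matching question ids, and a
# dictionary mapping 1-based answer indices (as strings) to answer text.
preds = [0, 2, 1]
qids = [458752000, 458752001, 458752002]
ix_to_answer = {'1': 'yes', '2': 'no', '3': '2'}
acc = evaluate_and_dump_predictions(
    preds, qids,
    'OpenEnded_mscoco_val2014_questions.json',  # qfile: VQA questions JSON
    'mscoco_val2014_annotations.json',          # afile: VQA annotations JSON
    ix_to_answer,
    'val_predictions.json')                     # filename: where the answers are dumped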
Example #2
def vqaEval(config=Config(), epoch_list=range(10)):
    accuracy_dic = {}
    best_accuracy, best_epoch = 0.0, -1

    # set up file names and paths
    annFile = config.selected_val_annotations_path
    quesFile = config.selected_val_questions_path

    for epoch in epoch_list:

        resFile = config.result_path % (epoch)

        vqa = VQA(annFile, quesFile)
        vqaRes = vqa.loadRes(resFile, quesFile)
        vqaEval = VQAEval(
            vqa, vqaRes, n=2
        )  #n is precision of accuracy (number of places after decimal), default is 2

        # evaluate results
        """
        If you have a list of question ids on which you would like to evaluate your results, pass it as a list to below function
        By default it uses all the question ids in annotation file
        """
        vqaEval.evaluate()

        # print accuracies
        accuracy = vqaEval.accuracy['overall']
        print "Overall Accuracy is: %.02f\n" % (accuracy)
        """
        print "Per Question Type Accuracy is the following:"
        for quesType in vqaEval.accuracy['perQuestionType']:
    	    print "%s : %.02f" %(quesType, vqaEval.accuracy['perQuestionType'][quesType])
        print "\n"
        """
        accuracy_dic[epoch] = {'overall': accuracy}
        print "Per Answer Type Accuracy is the following:"
        for ansType in vqaEval.accuracy['perAnswerType']:
            accuracy_dic[epoch][ansType] = vqaEval.accuracy['perAnswerType'][
                ansType]

#print "%s : %.02f" %(ansType, vqaEval.accuracy['perAnswerType'][ansType])

        if accuracy > best_accuracy:
            best_accuracy = accuracy
            best_epoch = epoch

    #print "** Done for every epoch! **"
    #print "Accuracy Dictionry"
    #print accuracy_dic
    print "Best Epoch is %d with Accuracy %.02f" % (best_epoch, best_accuracy)
    return accuracy_dic
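A hedged usage sketch (Config and result_path come from the surrounding project; the epoch range is a placeholder):
acc_by_epoch = vqaEval(Config(), epoch_list=range(10))
best_epoch = max(acc_by_epoch, key=lambda e: acc_by_epoch[e]['overall'])  # epoch with the highest overall accuracy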
Example #3
def vqaEval(config = Config(), epoch_list = range(10)):
    accuracy_dic = {}
    best_accuracy, best_epoch = 0.0, -1

    # set up file names and paths
    annFile = config.selected_val_annotations_path
    quesFile = config.selected_val_questions_path

    for epoch in epoch_list:

        resFile = config.result_path%(epoch)

        vqa = VQA(annFile, quesFile)
        vqaRes = vqa.loadRes(resFile, quesFile)
        vqaEval = VQAEval(vqa, vqaRes, n=2)   #n is precision of accuracy (number of places after decimal), default is 2

        # evaluate results
        """
        If you have a list of question ids on which you would like to evaluate your results, pass it as a list to below function
        By default it uses all the question ids in annotation file
        """
        vqaEval.evaluate()

        # print accuracies
        accuracy = vqaEval.accuracy['overall']
        print "Overall Accuracy is: %.02f\n" %(accuracy)
        """
        print "Per Question Type Accuracy is the following:"
        for quesType in vqaEval.accuracy['perQuestionType']:
    	    print "%s : %.02f" %(quesType, vqaEval.accuracy['perQuestionType'][quesType])
        print "\n"
        """
        accuracy_dic[epoch] = {'overall' : accuracy}
        print "Per Answer Type Accuracy is the following:"
        for ansType in vqaEval.accuracy['perAnswerType']:
            accuracy_dic[epoch][ansType] = vqaEval.accuracy['perAnswerType'][ansType]
	    #print "%s : %.02f" %(ansType, vqaEval.accuracy['perAnswerType'][ansType])

        if accuracy > best_accuracy:
            best_accuracy = accuracy
            best_epoch = epoch


    #print "** Done for every epoch! **"
    #print "Accuracy Dictionry"
    #print accuracy_dic
    print "Best Epoch is %d with Accuracy %.02f"%(best_epoch, best_accuracy)
    return accuracy_dic
Example #4
def exec_validation(device_id, mode, it='', visualize=False):

    caffe.set_device(device_id)
    caffe.set_mode_gpu()
    net = caffe.Net('./result/proto_test.prototxt',\
              './result/tmp.caffemodel',\
              caffe.TEST)

    dp = VQADataProvider(mode=mode,batchsize=64)
    total_questions = len(dp.getQuesIds())
    epoch = 0

    pred_list = []
    testloss_list = []
    stat_list = []

    while epoch == 0:
        t_word, t_cont, t_img_feature, t_answer, t_glove_matrix, t_qid_list, t_iid_list, epoch = dp.get_batch_vec()
        net.blobs['data'].data[...] = np.transpose(t_word,(1,0))
        net.blobs['cont'].data[...] = np.transpose(t_cont,(1,0))
        net.blobs['img_feature'].data[...] = t_img_feature
        net.blobs['label'].data[...] = t_answer
        net.blobs['glove'].data[...] = np.transpose(t_glove_matrix, (1,0,2))
        net.forward()
        t_pred_list = net.blobs['prediction'].data.argmax(axis=1)
        t_pred_str = [dp.vec_to_answer(pred_symbol) for pred_symbol in t_pred_list]
        testloss_list.append(net.blobs['loss'].data)
        for qid, iid, ans, pred in zip(t_qid_list, t_iid_list, t_answer.tolist(), t_pred_str):
            pred_list.append({u'answer':pred, u'question_id': int(dp.getStrippedQuesId(qid))})
            if visualize:
                q_list = dp.seq_to_list(dp.getQuesStr(qid))
                if mode == 'test-dev' or mode == 'test':
                    ans_str = ''
                    ans_list = ['']*10
                else:
                    ans_str = dp.vec_to_answer(ans)
                    ans_list = [ dp.getAnsObj(qid)[i]['answer'] for i in xrange(10)]
                stat_list.append({\
                                    'qid'   : qid,
                                    'q_list' : q_list,
                                    'iid'   : iid,
                                    'answer': ans_str,
                                    'ans_list': ans_list,
                                    'pred'  : pred })
        percent = 100 * float(len(pred_list)) / total_questions
        sys.stdout.write('\r' + ('%.2f' % percent) + '%')
        sys.stdout.flush()



    mean_testloss = np.array(testloss_list).mean()

    if mode == 'val':
        valFile = './result/val2015_resfile'
        with open(valFile, 'w') as f:
            json.dump(pred_list, f)
        if visualize:
            visualize_failures(stat_list,mode)
        annFile = config.DATA_PATHS['val']['ans_file']
        quesFile = config.DATA_PATHS['val']['ques_file']
        vqa = VQA(annFile, quesFile)
        vqaRes = vqa.loadRes(valFile, quesFile)
        vqaEval = VQAEval(vqa, vqaRes, n=2)
        vqaEval.evaluate()
        acc_overall = vqaEval.accuracy['overall']
        acc_perQuestionType = vqaEval.accuracy['perQuestionType']
        acc_perAnswerType = vqaEval.accuracy['perAnswerType']
        return mean_testloss, acc_overall, acc_perQuestionType, acc_perAnswerType
    elif mode == 'test-dev':
        filename = './result/vqa_OpenEnded_mscoco_test-dev2015_v3t'+str(it).zfill(8)+'_results'
        with open(filename+'.json', 'w') as f:
            json.dump(pred_list, f)
        if visualize:
            visualize_failures(stat_list,mode)
    elif mode == 'test':
        filename = './result/vqa_OpenEnded_mscoco_test2015_v3c'+str(it).zfill(8)+'_results'
        with open(filename+'.json', 'w') as f:
            json.dump(pred_list, f)
        if visualize:
            visualize_failures(stat_list,mode)
Example #5
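This fragment omits its setup; a minimal sketch of the variables it assumes, based on the standard VQA evaluation demo layout (the path and version values are placeholders, and the VQA / VQAEval classes are assumed to be importable from the VQA evaluation tools):
dataDir     = '../../VQA'   # placeholder root of the VQA data layout
versionType = 'v2_'         # '' for v1.0 annotations, 'v2_' for v2.0
taskType    = 'OpenEnded'   # 'OpenEnded' only for v2.0; 'OpenEnded'/'MultipleChoice' for v1.0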
dataType    ='mscoco'  # 'mscoco' only for v2.0; for v1.0 use 'mscoco' for real images and 'abstract_v002' for abstract scenes.
dataSubType ='train2014'
annFile     ='%s/Annotations/%s%s_%s_annotations.json'%(dataDir, versionType, dataType, dataSubType)
quesFile    ='%s/Questions/%s%s_%s_%s_questions.json'%(dataDir, versionType, taskType, dataType, dataSubType)
imgDir      ='%s/Images/%s/%s/' %(dataDir, dataType, dataSubType)
resultType  ='fake'
fileTypes   = ['results', 'accuracy', 'evalQA', 'evalQuesType', 'evalAnsType']

# An example result json file has been provided in './Results' folder.

[resFile, accuracyFile, evalQAFile, evalQuesTypeFile, evalAnsTypeFile] = ['%s/Results/%s%s_%s_%s_%s_%s.json'%(dataDir, versionType, taskType, dataType, dataSubType, \
resultType, fileType) for fileType in fileTypes]

# create vqa object and vqaRes object
vqa = VQA(annFile, quesFile)
vqaRes = vqa.loadRes(resFile, quesFile)

# create vqaEval object by taking vqa and vqaRes
vqaEval = VQAEval(vqa, vqaRes, n=2)   #n is precision of accuracy (number of places after decimal), default is 2

# evaluate results
"""
If you have a list of question ids on which you would like to evaluate your results, pass it as a list to below function
By default it uses all the question ids in annotation file
"""
vqaEval.evaluate()
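# A sketch (not part of the original demo): to score only a subset of questions,
# pass their ids explicitly, e.g.
#   vqaEval.evaluate(quesIds=[409380, 409381])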

# print accuracies
print("\n")
print("Overall Accuracy is: %.02f\n" %(vqaEval.accuracy['overall']))
print("Per Question Type Accuracy is the following:")
Example #6
def exec_validation(device_id, mode, it='', visualize=False):

    caffe.set_device(device_id)
    caffe.set_mode_gpu()
    net = caffe.Net('./result/proto_test.prototxt',\
              './result/tmp.caffemodel',\
              caffe.TEST)

    dp = VQADataProvider(mode=mode, batchsize=64)
    total_questions = len(dp.getQuesIds())
    epoch = 0

    pred_list = []
    testloss_list = []
    stat_list = []

    while epoch == 0:
        t_word, t_cont, t_img_feature, t_answer, t_qid_list, t_iid_list, epoch = dp.get_batch_vec(
        )
        net.blobs['data'].data[...] = np.transpose(t_word, (1, 0))
        net.blobs['cont'].data[...] = np.transpose(t_cont, (1, 0))
        net.blobs['img_feature'].data[...] = t_img_feature
        net.blobs['label'].data[...] = t_answer
        net.forward()
        t_pred_list = net.blobs['prediction'].data.argmax(axis=1)
        t_pred_str = [
            dp.vec_to_answer(pred_symbol) for pred_symbol in t_pred_list
        ]
        testloss_list.append(net.blobs['loss'].data)
        for qid, iid, ans, pred in zip(t_qid_list, t_iid_list,
                                       t_answer.tolist(), t_pred_str):
            pred_list.append({
                'answer': pred,
                'question_id': int(dp.getStrippedQuesId(qid))
            })
            if visualize:
                q_list = dp.seq_to_list(dp.getQuesStr(qid))
                if mode == 'test-dev' or mode == 'test':
                    ans_str = ''
                    ans_list = [''] * 10
                else:
                    ans_str = dp.vec_to_answer(ans)
                    ans_list = [
                        dp.getAnsObj(qid)[i]['answer'] for i in range(10)
                    ]
                stat_list.append({\
                                    'qid'   : qid,
                                    'q_list' : q_list,
                                    'iid'   : iid,
                                    'answer': ans_str,
                                    'ans_list': ans_list,
                                    'pred'  : pred })
        percent = 100 * float(len(pred_list)) / total_questions
        sys.stdout.write('\r' + ('%.2f' % percent) + '%')
        sys.stdout.flush()

    mean_testloss = np.array(testloss_list).mean()

    if mode == 'val':
        valFile = './result/val2015_resfile'
        with open(valFile, 'w') as f:
            json.dump(pred_list, f)
        if visualize:
            visualize_failures(stat_list, mode)
        annFile = config.DATA_PATHS['val']['ans_file']
        quesFile = config.DATA_PATHS['val']['ques_file']
        vqa = VQA(annFile, quesFile)
        vqaRes = vqa.loadRes(valFile, quesFile)
        vqaEval = VQAEval(vqa, vqaRes, n=2)
        vqaEval.evaluate()
        acc_overall = vqaEval.accuracy['overall']
        acc_perQuestionType = vqaEval.accuracy['perQuestionType']
        acc_perAnswerType = vqaEval.accuracy['perAnswerType']
        return mean_testloss, acc_overall, acc_perQuestionType, acc_perAnswerType
    elif mode == 'test-dev':
        filename = './result/vqa_OpenEnded_mscoco_test-dev2015_v3t' + str(
            it).zfill(8) + '_results'
        with open(filename + '.json', 'w') as f:
            json.dump(pred_list, f)
        if visualize:
            visualize_failures(stat_list, mode)
    elif mode == 'test':
        filename = './result/vqa_OpenEnded_mscoco_test2015_v3c' + str(
            it).zfill(8) + '_results'
        with open(filename + '.json', 'w') as f:
            json.dump(pred_list, f)
        if visualize:
            visualize_failures(stat_list, mode)
Example #7
def exec_validation(model, opt, mode, folder, it, visualize=False):
    model.eval()
    criterion = nn.NLLLoss()
    dp = VQADataProvider(opt,
                         batchsize=opt.VAL_BATCH_SIZE,
                         mode='val',
                         folder=folder)
    epoch = 0
    pred_list = []
    testloss_list = []
    stat_list = []
    total_questions = len(dp.getQuesIds())

    print('Validating...')
    while epoch == 0:
        t_word, word_length, t_img_feature, t_answer, t_qid_list, t_iid_list, epoch = dp.get_batch_vec(
        )
        word_length = np.sum(word_length, axis=1)

        data = Variable(torch.from_numpy(t_word)).cuda()
        word_length = torch.from_numpy(word_length).cuda()
        img_feature = Variable(torch.from_numpy(t_img_feature)).cuda()
        label = Variable(torch.from_numpy(t_answer)).cuda()
        pred = model(data, word_length, img_feature, 'val')
        if mode == 'test-dev' or mode == 'test':
            pass
        else:
            # compute the loss while `pred` is still a tensor, before converting it to numpy
            loss = criterion(pred, label.long())
            testloss_list.append(loss.data.cpu().numpy())
        pred = pred.data.cpu().numpy()
        t_pred_list = np.argmax(pred, axis=1)
        t_pred_str = [
            dp.vec_to_answer(pred_symbol) for pred_symbol in t_pred_list
        ]

        for qid, iid, ans, pred in zip(t_qid_list, t_iid_list,
                                       t_answer.tolist(), t_pred_str):
            pred_list.append((pred, int(dp.getStrippedQuesId(qid))))
            if visualize:
                q_list = dp.seq_to_list(dp.getQuesStr(qid))
                if mode == 'test-dev' or mode == 'test':
                    ans_str = ''
                    ans_list = [''] * 10
                else:
                    ans_str = dp.vec_to_answer(ans)
                    ans_list = [
                        dp.getAnsObj(qid)[i]['answer'] for i in range(10)
                    ]
                stat_list.append({\
                                    'qid'   : qid,
                                    'q_list' : q_list,
                                    'iid'   : iid,
                                    'answer': ans_str,
                                    'ans_list': ans_list,
                                    'pred'  : pred })
        percent = 100 * float(len(pred_list)) / total_questions
        sys.stdout.write('\r' + ('%.2f' % percent) + '%')
        sys.stdout.flush()

    print('Deduping arr of len', len(pred_list))
    deduped = []
    seen = set()
    for ans, qid in pred_list:
        if qid not in seen:
            seen.add(qid)
            deduped.append((ans, qid))
    print('New len', len(deduped))
    final_list = []
    for ans, qid in deduped:
        final_list.append({u'answer': ans, u'question_id': qid})

    if mode == 'val':
        mean_testloss = np.array(testloss_list).mean()
        valFile = './%s/val2015_resfile' % folder
        with open(valFile, 'w') as f:
            json.dump(final_list, f)
        if visualize:
            visualize_failures(stat_list, mode)
        annFile = config.DATA_PATHS['val']['ans_file']
        quesFile = config.DATA_PATHS['val']['ques_file']
        vqa = VQA(annFile, quesFile)
        vqaRes = vqa.loadRes(valFile, quesFile)
        vqaEval = VQAEval(vqa, vqaRes, n=2)
        vqaEval.evaluate()
        acc_overall = vqaEval.accuracy['overall']
        acc_perQuestionType = vqaEval.accuracy['perQuestionType']
        acc_perAnswerType = vqaEval.accuracy['perAnswerType']
        return mean_testloss, acc_overall, acc_perQuestionType, acc_perAnswerType
    elif mode == 'test-dev':
        filename = './%s/vqa_OpenEnded_mscoco_test-dev2015_%s-' % (
            folder, folder) + str(it).zfill(8) + '_results'
        with open(filename + '.json', 'w') as f:
            json.dump(final_list, f)
        if visualize:
            visualize_failures(stat_list, mode)
    elif mode == 'test':
        filename = './%s/vqa_OpenEnded_mscoco_test2015_%s-' % (
            folder, folder) + str(it).zfill(8) + '_results'
        with open(filename + '.json', 'w') as f:
            json.dump(final_list, f)
        if visualize:
            visualize_failures(stat_list, mode)
Example #8
    def exec_validation(self, sess, mode, folder, it=0, visualize=False):

        dp = VQADataLoader(mode=mode,
                           batchsize=config.VAL_BATCH_SIZE,
                           folder=folder)
        total_questions = len(dp.getQuesIds())
        epoch = 0
        pred_list = []
        testloss_list = []
        stat_list = []
        while epoch == 0:
            q_strs, q_word_vec_list, q_len_list, ans_vectors, img_features, a_word_vec, ans_score, ans_space_score, t_qid_list, img_ids, epoch = dp.next_batch(
                config.BATCH_SIZE)
            feed_dict = {
                self.model.q_input: q_word_vec_list,
                self.model.ans1: ans_vectors,
                self.model.seqlen: q_len_list,
                self.model.img_vec: img_features,
                self.lr: config.VQA_LR,
                self.model.keep_prob: 1.0,
                self.model.is_training: False
            }

            t_predict_list, predict_loss = sess.run(
                [self.model.predict1, self.model.softmax_cross_entrophy1],
                feed_dict=feed_dict)
            t_pred_str = [
                dp.vec_to_answer(pred_symbol) for pred_symbol in t_predict_list
            ]
            testloss_list.append(predict_loss)
            ans_vectors = np.asarray(ans_vectors).argmax(1)
            for qid, iid, ans, pred in zip(t_qid_list, img_ids, ans_vectors,
                                           t_pred_str):
                # pred_list.append({u'answer':pred, u'question_id': int(dp.getStrippedQuesId(qid))})
                pred_list.append((pred, int(dp.getStrippedQuesId(qid))))
                if visualize:
                    q_list = dp.seq_to_list(dp.getQuesStr(qid))
                    if mode == 'test-dev' or mode == 'test':
                        ans_str = ''
                        ans_list = [''] * 10
                    else:
                        ans_str = dp.vec_to_answer(ans)
                        ans_list = [
                            dp.getAnsObj(qid)[i]['answer'] for i in xrange(10)
                        ]
                    stat_list.append({ \
                        'qid': qid,
                        'q_list': q_list,
                        'iid': iid,
                        'answer': ans_str,
                        'ans_list': ans_list,
                        'pred': pred})
            percent = 100 * float(len(pred_list)) / total_questions
            sys.stdout.write('\r' + ('%.2f' % percent) + '%')
            sys.stdout.flush()

        print 'Deduping arr of len', len(pred_list)
        deduped = []
        seen = set()
        for ans, qid in pred_list:
            if qid not in seen:
                seen.add(qid)
                deduped.append((ans, qid))
        print 'New len', len(deduped)
        final_list = []
        for ans, qid in deduped:
            final_list.append({u'answer': ans, u'question_id': qid})

        mean_testloss = np.array(testloss_list).mean()

        if mode == 'val':
            valFile = './%s/val2015_resfile_%d' % (folder, it)
            with open(valFile, 'w') as f:
                json.dump(final_list, f)
            if visualize:
                visualize_failures(stat_list, mode)
            annFile = config.DATA_PATHS['val']['ans_file']
            quesFile = config.DATA_PATHS['val']['ques_file']
            vqa = VQA(annFile, quesFile)
            vqaRes = vqa.loadRes(valFile, quesFile)
            vqaEval = VQAEval(vqa, vqaRes, n=2)
            vqaEval.evaluate()
            acc_overall = vqaEval.accuracy['overall']
            acc_perQuestionType = vqaEval.accuracy['perQuestionType']
            acc_perAnswerType = vqaEval.accuracy['perAnswerType']
            return mean_testloss, acc_overall, acc_perQuestionType, acc_perAnswerType
        elif mode == 'test-dev':
            filename = './%s/vqa_OpenEnded_mscoco_test-dev2015_%s-%d-' % (
                folder, folder, it) + str(it).zfill(8) + '_results'
            with open(filename + '.json', 'w') as f:
                json.dump(final_list, f)
            if visualize:
                visualize_failures(stat_list, mode)
        elif mode == 'test':
            filename = './%s/vqa_OpenEnded_mscoco_test2015_%s-%d-' % (
                folder, folder, it) + str(it).zfill(8) + '_results'
            with open(filename + '.json', 'w') as f:
                json.dump(final_list, f)
            if visualize:
                visualize_failures(stat_list, mode)
Example #9
def exec_validation(model, opt, mode, folder, it, logger, visualize=False, dp=None):
    """
    execute validation and save predictions as json file for visualization
    avg_loss:       average loss on given validation dataset split
    acc_overall:    overall accuracy
    """
    if opt.LATE_FUSION:
        criterion = nn.BCELoss()
        model_prob = model[1]
        model = model[0]
    else:
        criterion = nn.NLLLoss()

    check_mkdir(folder)
    model.eval()
    # criterion = nn.KLDivLoss(reduction='batchmean')
    if opt.BINARY:
        criterion2 = nn.BCELoss()
        acc_counter = 0
        all_counter = 0
    if not dp:
        dp = VQADataProvider(opt, batchsize=opt.VAL_BATCH_SIZE, mode=mode, logger=logger)
    epoch = 0
    pred_list = []
    loss_list = []
    stat_list = []
    total_questions = len(dp.getQuesIds())

    percent_counter = 0

    logger.info('Validating...')
    while epoch == 0:
        data, word_length, img_feature, answer, embed_matrix, ocr_length, ocr_embedding, ocr_tokens, ocr_answer_flags, qid_list, iid_list, epoch = dp.get_batch_vec()
        data = cuda_wrapper(Variable(torch.from_numpy(data))).long()
        word_length = cuda_wrapper(torch.from_numpy(word_length))
        img_feature = cuda_wrapper(Variable(torch.from_numpy(img_feature))).float()
        label = cuda_wrapper(Variable(torch.from_numpy(answer)))
        ocr_answer_flags = cuda_wrapper(torch.from_numpy(ocr_answer_flags))

        if opt.OCR:
            embed_matrix = cuda_wrapper(Variable(torch.from_numpy(embed_matrix))).float()
            ocr_length = cuda_wrapper(torch.from_numpy(ocr_length))
            ocr_embedding= cuda_wrapper(Variable(torch.from_numpy(ocr_embedding))).float()
            if opt.BINARY:
                ocr_answer_flags = cuda_wrapper(ocr_answer_flags)
                if opt.LATE_FUSION:
                    binary = model(data, img_feature, embed_matrix, ocr_length, ocr_embedding, mode)
                    pred = model_prob(data, img_feature, embed_matrix, ocr_length, ocr_embedding, mode)
                    pred1 = pred[:, 0:opt.MAX_ANSWER_VOCAB_SIZE]
                    pred2 = pred[:, opt.MAX_ANSWER_VOCAB_SIZE:]
                else:
                    binary, pred1, pred2 = model(data, img_feature, embed_matrix, ocr_length, ocr_embedding, mode)
            else:
                pred = model(data, img_feature, embed_matrix, ocr_length, ocr_embedding, mode)
        elif opt.EMBED:
            embed_matrix = cuda_wrapper(Variable(torch.from_numpy(embed_matrix))).float()
            pred = model(data, img_feature, embed_matrix, mode)
        else:
            pred = model(data, word_length, img_feature, mode)

        if mode == 'test-dev' or mode == 'test':
            pass
        else:
            if opt.BINARY:
                if opt.LATE_FUSION:
                    loss = criterion(binary, ocr_answer_flags.float())
                else:
                    loss = criterion2(binary, ocr_answer_flags.float()) * opt.BIN_LOSS_RATE
                    loss += criterion(pred1[label < opt.MAX_ANSWER_VOCAB_SIZE], label[label < opt.MAX_ANSWER_VOCAB_SIZE].long())
                    loss += criterion(pred2[label >= opt.MAX_ANSWER_VOCAB_SIZE], label[label >= opt.MAX_ANSWER_VOCAB_SIZE].long() - opt.MAX_ANSWER_VOCAB_SIZE)
                all_counter += binary.size()[0]
                acc_counter += torch.sum((binary <= 0.5) * (ocr_answer_flags == 0) + (binary > 0.5) * (ocr_answer_flags == 1))
                #print(all_counter, acc_counter)
            else:
                loss = criterion(pred, label.long())
            loss = (loss.data).cpu().numpy()
            loss_list.append(loss)

        if opt.BINARY:
            binary = (binary.data).cpu().numpy()
            pred1 = (pred1.data).cpu().numpy()
            pred2 = (pred2.data).cpu().numpy()
            pred = np.hstack([pred1, pred2])
        else:
            pred = (pred.data).cpu().numpy()
        if opt.OCR:
            # select the largest index within the ocr length boundary
            ocr_mask = np.fromfunction(lambda i, j: j >= (ocr_length[i].cpu().numpy() + opt.MAX_ANSWER_VOCAB_SIZE), pred.shape, dtype=int)
            if opt.BINARY:
                #ocr_mask += np.fromfunction(lambda i, j: np.logical_or(np.logical_and(binary[i] <= 0.5, j >= opt.MAX_ANSWER_VOCAB_SIZE), np.logical_and(binary[i] > 0.5, j < opt.MAX_ANSWER_VOCAB_SIZE)), pred.shape, dtype=int)
                #ocr_mask += np.fromfunction(lambda i, j: np.logical_or(np.logical_and(ocr_answer_flags[i] == 0, j >= opt.MAX_ANSWER_VOCAB_SIZE), np.logical_and(ocr_answer_flags[i] == 1, j < opt.MAX_ANSWER_VOCAB_SIZE)), pred.shape, dtype=int)
                ocr_mask += np.fromfunction(lambda i, j: np.logical_or(np.logical_and(ocr_answer_flags[i].cpu().numpy() == 0, j >= opt.MAX_ANSWER_VOCAB_SIZE), np.logical_and(ocr_answer_flags[i].cpu().numpy() == 1, j < opt.MAX_ANSWER_VOCAB_SIZE)), pred.shape, dtype=int)

            masked_pred = np.ma.array(pred, mask=ocr_mask)
            pred_max = np.ma.argmax(masked_pred, axis=1)
            pred_str = [dp.vec_to_answer_ocr(pred_symbol, ocr) for pred_symbol, ocr in zip(pred_max, ocr_tokens)]
        else:
            pred_max = np.argmax(pred, axis=1)
            pred_str = [dp.vec_to_answer(pred_symbol) for pred_symbol in pred_max]

        for qid, iid, ans, pred, ocr in zip(qid_list, iid_list, answer.tolist(), pred_str, ocr_tokens):
            pred_list.append((pred, int(dp.getStrippedQuesId(qid))))
            # prepare pred json file
            if visualize:
                q_list = dp.seq_to_list(dp.getQuesStr(qid), opt.MAX_QUESTION_LENGTH)
                if mode == 'test-dev' or mode == 'test':
                    ans_str = ''
                    ans_list = ['']*10
                else:
                    if opt.OCR:
                        ans_str = dp.vec_to_answer_ocr(int(ans), ocr)
                    else:
                        ans_str = dp.vec_to_answer(int(ans))
                    ans_list = [ dp.getAnsObj(qid)[i]['answer'] for i in range(10)]
                stat_list.append({
                    'qid': qid,
                    'q_list': q_list,
                    'iid': iid,
                    'answer': ans_str,
                    'ans_list': ans_list,
                    'pred': pred,
                    'ocr_tokens': ocr
                })
        percent = 100 * float(len(pred_list)) / total_questions
        if percent <= 100 and percent - percent_counter >= 5:
            percent_counter = percent
            sys.stdout.write('\r' + ('%.2f' % percent) + '%')
            sys.stdout.flush()

    if visualize:
        with open(os.path.join(folder, 'visualize.json'), 'w') as f:
            json.dump(stat_list, f, indent=4, sort_keys=True)

    if opt.BINARY:
        logger.info('Binary Acc: {},({}/{})'.format(acc_counter.item()/all_counter, acc_counter, all_counter))

    logger.info('Deduping arr of len {}'.format(len(pred_list)))
    deduped = []
    seen = set()
    for ans, qid in pred_list:
        if qid not in seen:
            seen.add(qid)
            deduped.append((ans, qid))
    logger.info('New len {}'.format(len(deduped)))
    final_list=[]
    for ans,qid in deduped:
        final_list.append({u'answer': ans, u'question_id': qid})

    if mode == 'val':
        avg_loss = np.array(loss_list).mean()
        valFile = os.path.join(folder, 'val2015_resfile')
        with open(valFile, 'w') as f:
            json.dump(final_list, f)
        # if visualize:
        #     visualize_pred(stat_list,mode)

        exp_type = opt.EXP_TYPE

        annFile = config.DATA_PATHS[exp_type]['val']['ans_file']
        quesFile = config.DATA_PATHS[exp_type]['val']['ques_file']
        vqa = VQA(annFile, quesFile)
        vqaRes = vqa.loadRes(valFile, quesFile)
        vqaEval = VQAEval(vqa, vqaRes, n=2)
        vqaEval.evaluate()
        acc_overall = vqaEval.accuracy['overall']
        acc_perQuestionType = vqaEval.accuracy['perQuestionType']
        acc_perAnswerType = vqaEval.accuracy['perAnswerType']
    elif mode == 'test-dev':
        filename = os.path.join(folder, 'test-dev_results_' + str(it).zfill(8))
        with open(filename+'.json', 'w') as f:
            json.dump(final_list, f)
        # if visualize:
        #     visualize_pred(stat_list,mode)
    elif mode == 'test':
        filename = os.path.join(folder, 'test_results_' + str(it).zfill(8))
        with open(filename+'.json', 'w') as f:
            json.dump(final_list, f)
        # if visualize:
        #     visualize_pred(stat_list,mode)
    # avg_loss and the accuracy breakdowns are only computed when mode == 'val';
    # the test splits have no ground truth, so there is nothing to evaluate there.
    if mode != 'val':
        return None, None, None, None
    return avg_loss, acc_overall, acc_perQuestionType, acc_perAnswerType
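A hedged call sketch for the validation helper above (model, opt, folder, and logger are whatever the surrounding training loop provides; the values here are placeholders):
avg_loss, acc_overall, acc_per_qtype, acc_per_atype = exec_validation(
    model, opt, mode='val', folder='results/exp1', it=10000, logger=logger)
logger.info('val loss {:.4f}, overall accuracy {:.2f}'.format(avg_loss, acc_overall))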