Example #1
def make_question_vocab(qdic):
    """
    Returns a dictionary that maps words to indices.
    """
    vdict = {'': 0}
    vid = 1
    for qid in qdic.keys():
        # sequence to list
        q_str = qdic[qid]['qstr']
        q_list = VQADataProvider.seq_to_list(q_str)

        # create dict
        for w in q_list:
            if w not in vdict:
                vdict[w] = vid
                vid += 1

    return vdict
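
A minimal usage sketch of the helper above (the toy qdic is hypothetical, and it assumes VQADataProvider.seq_to_list simply tokenizes the question string):

toy_qdic = {
    'q1': {'qstr': 'what color is the sign'},
    'q2': {'qstr': 'what does the sign say'},
}
vdict = make_question_vocab(toy_qdic)
# '' is reserved at index 0; every other word gets the next free index, e.g.
# {'': 0, 'what': 1, 'color': 2, 'is': 3, 'the': 4, 'sign': 5, 'does': 6, 'say': 7}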
Example #2
def exec_validation(model, opt, mode, folder, it, logger, visualize=False, dp=None):
    """
    execute validation and save predictions as json file for visualization
    avg_loss:       average loss on given validation dataset split
    acc_overall:    overall accuracy
    """
    if opt.LATE_FUSION:
        # LATE_FUSION: model is a pair; model[0] predicts whether the answer is an
        # OCR token, model[1] predicts the answer probabilities
        criterion = nn.BCELoss()
        model_prob = model[1]
        model = model[0]
    else:
        criterion = nn.NLLLoss()

    check_mkdir(folder)
    model.eval()
    # criterion = nn.KLDivLoss(reduction='batchmean')
    if opt.BINARY:
        criterion2 = nn.BCELoss()
        acc_counter = 0
        all_counter = 0
    if not dp:
        dp = VQADataProvider(opt, batchsize=opt.VAL_BATCH_SIZE, mode=mode, logger=logger)
    epoch = 0
    pred_list = []
    loss_list = []
    stat_list = []
    total_questions = len(dp.getQuesIds())

    percent_counter = 0

    logger.info('Validating...')
    # dp.get_batch_vec() yields one mini-batch at a time and flips `epoch` to 1
    # once the validation split has been fully consumed
    while epoch == 0:
        data, word_length, img_feature, answer, embed_matrix, ocr_length, ocr_embedding, ocr_tokens, ocr_answer_flags, qid_list, iid_list, epoch = dp.get_batch_vec()
        data = cuda_wrapper(Variable(torch.from_numpy(data))).long()
        word_length = cuda_wrapper(torch.from_numpy(word_length))
        img_feature = cuda_wrapper(Variable(torch.from_numpy(img_feature))).float()
        label = cuda_wrapper(Variable(torch.from_numpy(answer)))
        ocr_answer_flags = cuda_wrapper(torch.from_numpy(ocr_answer_flags))

        if opt.OCR:
            embed_matrix = cuda_wrapper(Variable(torch.from_numpy(embed_matrix))).float()
            ocr_length = cuda_wrapper(torch.from_numpy(ocr_length))
            ocr_embedding = cuda_wrapper(Variable(torch.from_numpy(ocr_embedding))).float()
            if opt.BINARY:
                ocr_answer_flags = cuda_wrapper(ocr_answer_flags)
                if opt.LATE_FUSION:
                    binary = model(data, img_feature, embed_matrix, ocr_length, ocr_embedding, mode)
                    pred = model_prob(data, img_feature, embed_matrix, ocr_length, ocr_embedding, mode)
                    pred1 = pred[:, 0:opt.MAX_ANSWER_VOCAB_SIZE]
                    pred2 = pred[:, opt.MAX_ANSWER_VOCAB_SIZE:]
                else:
                    binary, pred1, pred2 = model(data, img_feature, embed_matrix, ocr_length, ocr_embedding, mode)
            else:
                pred = model(data, img_feature, embed_matrix, ocr_length, ocr_embedding, mode)
        elif opt.EMBED:
            embed_matrix = cuda_wrapper(Variable(torch.from_numpy(embed_matrix))).float()
            pred = model(data, img_feature, embed_matrix, mode)
        else:
            pred = model(data, word_length, img_feature, mode)

        if mode == 'test-dev' or mode == 'test':
            pass
        else:
            if opt.BINARY:
                if opt.LATE_FUSION:
                    loss = criterion(binary, ocr_answer_flags.float())
                else:
                    loss = criterion2(binary, ocr_answer_flags.float()) * opt.BIN_LOSS_RATE
                    loss += criterion(pred1[label < opt.MAX_ANSWER_VOCAB_SIZE], label[label < opt.MAX_ANSWER_VOCAB_SIZE].long())
                    loss += criterion(pred2[label >= opt.MAX_ANSWER_VOCAB_SIZE], label[label >= opt.MAX_ANSWER_VOCAB_SIZE].long() - opt.MAX_ANSWER_VOCAB_SIZE)
                all_counter += binary.size()[0]
                acc_counter += torch.sum((binary <= 0.5) * (ocr_answer_flags == 0) + (binary > 0.5) * (ocr_answer_flags == 1))
                #print(all_counter, acc_counter)
            else:
                loss = criterion(pred, label.long())
            loss = (loss.data).cpu().numpy()
            loss_list.append(loss)

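        # Decode: in the BINARY setting the two answer heads are concatenated so a single
        # (masked) argmax can pick either a vocabulary answer (first MAX_ANSWER_VOCAB_SIZE
        # columns) or a copied OCR token (remaining columns).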
        if opt.BINARY:
            binary = (binary.data).cpu().numpy()
            pred1 = (pred1.data).cpu().numpy()
            pred2 = (pred2.data).cpu().numpy()
            pred = np.hstack([pred1, pred2])
        else:
            pred = (pred.data).cpu().numpy()
        if opt.OCR:
            # select the largest index within the ocr length boundary
            ocr_mask = np.fromfunction(lambda i, j: j >= (ocr_length[i].cpu().numpy() + opt.MAX_ANSWER_VOCAB_SIZE), pred.shape, dtype=int)
            if opt.BINARY:
                #ocr_mask += np.fromfunction(lambda i, j: np.logical_or(np.logical_and(binary[i] <= 0.5, j >= opt.MAX_ANSWER_VOCAB_SIZE), np.logical_and(binary[i] > 0.5, j < opt.MAX_ANSWER_VOCAB_SIZE)), pred.shape, dtype=int)
                #ocr_mask += np.fromfunction(lambda i, j: np.logical_or(np.logical_and(ocr_answer_flags[i] == 0, j >= opt.MAX_ANSWER_VOCAB_SIZE), np.logical_and(ocr_answer_flags[i] == 1, j < opt.MAX_ANSWER_VOCAB_SIZE)), pred.shape, dtype=int)
                ocr_mask += np.fromfunction(lambda i, j: np.logical_or(np.logical_and(ocr_answer_flags[i].cpu().numpy() == 0, j >= opt.MAX_ANSWER_VOCAB_SIZE), np.logical_and(ocr_answer_flags[i].cpu().numpy() == 1, j < opt.MAX_ANSWER_VOCAB_SIZE)), pred.shape, dtype=int)

            masked_pred = np.ma.array(pred, mask=ocr_mask)
            pred_max = np.ma.argmax(masked_pred, axis=1)
            pred_str = [dp.vec_to_answer_ocr(pred_symbol, ocr) for pred_symbol, ocr in zip(pred_max, ocr_tokens)]
        else:
            pred_max = np.argmax(pred, axis=1)
            pred_str = [dp.vec_to_answer(pred_symbol) for pred_symbol in pred_max]

        for qid, iid, ans, pred, ocr in zip(qid_list, iid_list, answer.tolist(), pred_str, ocr_tokens):
            pred_list.append((pred, int(dp.getStrippedQuesId(qid))))
            # prepare pred json file
            if visualize:
                q_list = dp.seq_to_list(dp.getQuesStr(qid), opt.MAX_QUESTION_LENGTH)
                if mode == 'test-dev' or mode == 'test':
                    ans_str = ''
                    ans_list = ['']*10
                else:
                    if opt.OCR:
                        ans_str = dp.vec_to_answer_ocr(int(ans), ocr)
                    else:
                        ans_str = dp.vec_to_answer(int(ans))
                    ans_list = [dp.getAnsObj(qid)[i]['answer'] for i in range(10)]
                stat_list.append({
                    'qid': qid,
                    'q_list': q_list,
                    'iid': iid,
                    'answer': ans_str,
                    'ans_list': ans_list,
                    'pred': pred,
                    'ocr_tokens': ocr
                })
        percent = 100 * float(len(pred_list)) / total_questions
        if percent <= 100 and percent - percent_counter >= 5:
            percent_counter = percent
            sys.stdout.write('\r' + ('%.2f' % percent) + '%')
            sys.stdout.flush()

    if visualize:
        with open(os.path.join(folder, 'visualize.json'), 'w') as f:
            json.dump(stat_list, f, indent=4, sort_keys=True)

    if opt.BINARY:
        logger.info('Binary Acc: {},({}/{})'.format(acc_counter.item()/all_counter, acc_counter, all_counter))

    logger.info('Deduping arr of len {}'.format(len(pred_list)))
    deduped = []
    seen = set()
    for ans, qid in pred_list:
        if qid not in seen:
            seen.add(qid)
            deduped.append((ans, qid))
    logger.info('New len {}'.format(len(deduped)))
    final_list = []
    for ans, qid in deduped:
        final_list.append({u'answer': ans, u'question_id': qid})

    # default return values; the metrics below are only computed for the 'val' split
    avg_loss = acc_overall = acc_perQuestionType = acc_perAnswerType = None

    if mode == 'val':
        avg_loss = np.array(loss_list).mean()
        valFile = os.path.join(folder, 'val2015_resfile')
        with open(valFile, 'w') as f:
            json.dump(final_list, f)
        # if visualize:
        #     visualize_pred(stat_list,mode)

        exp_type = opt.EXP_TYPE

        annFile = config.DATA_PATHS[exp_type]['val']['ans_file']
        quesFile = config.DATA_PATHS[exp_type]['val']['ques_file']
        vqa = VQA(annFile, quesFile)
        vqaRes = vqa.loadRes(valFile, quesFile)
        vqaEval = VQAEval(vqa, vqaRes, n=2)
        vqaEval.evaluate()
        acc_overall = vqaEval.accuracy['overall']
        acc_perQuestionType = vqaEval.accuracy['perQuestionType']
        acc_perAnswerType = vqaEval.accuracy['perAnswerType']
    elif mode == 'test-dev':
        filename = os.path.join(folder, 'test-dev_results_' + str(it).zfill(8))
        with open(filename+'.json', 'w') as f:
            json.dump(final_list, f)
        # if visualize:
        #     visualize_pred(stat_list,mode)
    elif mode == 'test':
        filename = os.path.join(folder, 'test_results_' + str(it).zfill(8))
        with open(filename+'.json', 'w') as f:
            json.dump(final_list, f)
        # if visualize:
        #     visualize_pred(stat_list,mode)
    return avg_loss, acc_overall, acc_perQuestionType, acc_perAnswerType
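
The OCR branch above selects an answer with a masked argmax: columns past the fixed answer vocabulary that fall outside a sample's actual OCR token list are masked out before taking the argmax. A minimal standalone sketch of that trick with toy numbers (independent of the model code above):

import numpy as np

MAX_ANSWER_VOCAB_SIZE = 4                                  # toy vocabulary size
pred = np.array([[0.1, 0.3, 0.2, 0.1, 0.9, 0.0, 0.0]])    # scores: 4 vocab slots + 3 OCR slots
ocr_length = np.array([1])                                 # this sample has only 1 OCR token

# mask every OCR slot beyond the sample's OCR length
ocr_mask = np.fromfunction(
    lambda i, j: j >= (ocr_length[i] + MAX_ANSWER_VOCAB_SIZE),
    pred.shape, dtype=int)

masked_pred = np.ma.array(pred, mask=ocr_mask)
pred_max = np.ma.argmax(masked_pred, axis=1)               # -> [4], i.e. the first OCR slot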