Example #1
0
def _rerank(answers, config, l_pred):
    """Return the search candidates reordered by ``config.reranking_method``.

    Candidates are sorted in ascending order of the selected key; the caller
    treats the first element of the result as the best candidate. Each
    candidate ``x`` is read as ``x[0]`` (a sequence whose ``sum`` is its raw
    score) and, for ``'bounded_adaptive_reward'`` only, ``x[2]``.

    Supported methods:
      * ``'none'``: ``sum(x[0])``.
      * ``'length_norm'``: ``length_norm(x[0])``.
      * ``'bounded_word_reward'``:
        ``bounded_word_reward(x[0], config.reward, l_pred)``.
      * ``'bounded_adaptive_reward'``:
        ``bounded_adaptive_reward(x[0], x[2], l_pred)``.

    Any other method value returns ``answers`` unchanged (searcher order).
    """
    # Lambdas keep attribute access lazy: config.reward is only read when
    # the bounded_word_reward method is actually selected.
    sort_keys = {
        'none': lambda x: sum(x[0]),
        'length_norm': lambda x: length_norm(x[0]),
        'bounded_word_reward':
            lambda x: bounded_word_reward(x[0], config.reward, l_pred),
        'bounded_adaptive_reward':
            lambda x: bounded_adaptive_reward(x[0], x[2], l_pred),
    }
    key = sort_keys.get(config.reranking_method)
    if key is None:
        return answers
    return sorted(answers, key=key)


def test(config):
    """Decode every line of ``config.inputFile`` and write the summaries.

    For each input line the model's candidates are produced by ``Searcher``,
    optionally reranked (see ``_rerank``), echoed to stdout, and written to:

      * ``'summary' + config.suffix + '.txt'``: the top candidate, one per
        input line;
      * the same name plus ``'.' + str(config.answer_size)``: for each input
        line, the number of candidates followed by each candidate's
        ``x[0]`` value and its detokenized text.

    Config attributes read here: ``test_model``, ``model``, ``inputFile``,
    ``dataParallel``, ``suffix``, ``answer_size``, ``reranking_method`` and
    (for bounded_word_reward) ``reward``; ``Searcher`` may read more.
    """
    use_cuda = torch.cuda.is_available()

    # Without a GPU, CUDA tensors in the checkpoint cannot be deserialised
    # as-is, so map them to CPU; load_state_dict below copies the values
    # into the network's parameters on whatever device they live.
    best_model = torch.load(config.test_model,
                            map_location=None if use_cuda else 'cpu')
    tokenizer = BertTokenizer.from_pretrained(config.model)
    net = BertForMaskedLM.from_pretrained(config.model)

    if use_cuda:
        net = net.cuda(0)
        if config.dataParallel:
            net = DataParallelModel(net)

    # The state dict is loaded *after* the optional DataParallel wrapping so
    # that checkpoints saved from a wrapped model load into a wrapped net.
    # NOTE(review): a checkpoint saved without the wrapper presumably needs
    # to be loaded before wrapping instead; confirm against training code.
    net.load_state_dict(best_model['state_dict'])
    net.eval()

    searcher = Searcher(net, config)

    top1_path = 'summary' + config.suffix + '.txt'
    topk_path = top1_path + '.' + str(config.answer_size)
    ed = '\n------------------------\n'

    # The input is opened first so a missing input file fails before the
    # output files are created/truncated; `with` guarantees all three are
    # closed (and outputs flushed) even if decoding raises mid-way.
    with open(config.inputFile, 'r', encoding='utf-8') as f_in, \
            open(top1_path, 'w', encoding='utf-8') as f_top1, \
            open(topk_path, 'w', encoding='utf-8') as f_topK:
        for idx, line in enumerate(f_in):
            stripped = line.strip()
            words = stripped.split()
            tokens = tokenizer.tokenize(stripped)
            # Links whitespace words to word-piece tokens so outputs can be
            # detokenized back to the original word segmentation.
            mapping = mapping_tokenize(words, tokens)
            source = tokenizer.convert_tokens_to_ids(tokens)

            print(idx)
            print(detokenize(translate(source, tokenizer), mapping), end=ed)

            l_pred = searcher.length_Predict(source)
            answers = searcher.search(source)
            # Assumes the searcher returns at least one candidate.
            baseline = sum(answers[0][0])

            answers = _rerank(answers, config, l_pred)

            texts = [
                detokenize(translate(answer[1], tokenizer), mapping)
                for answer in answers
            ]

            if baseline != sum(answers[0][0]):
                print('Reranked!')

            print(texts[0], end=ed)
            print(texts[0], file=f_top1)
            print(len(texts), file=f_topK)
            for answer, text in zip(answers, texts):
                print(answer[0], file=f_topK)
                print(text, file=f_topK)