示例#1
0
文件: main.py 项目: Caiit/IR2
def run(data, word2vec):
    """
    Retrieve, rerank, rewrite, and report average ROUGE / BLEU scores.

    For every example the resources of all four document types are
    collected. Then, for each window of three consecutive chat utterances,
    the resources are ranked with ``retrieve`` and ``rerank``, the best
    resource is combined with a template by ``rewrite`` and the generated
    response is scored against the target utterance.

    Relies on the module-level ``args``, ``device`` and ``data_utils``.

    Args:
        data: Iterable of examples. Each example is indexed with
            ``"documents"`` (holding ``"comments"``, ``"fact_table"``,
            ``"plot"`` and ``"review"``) and ``"chat"``.
        word2vec: Passed unchanged to ``retrieve``.

    Returns:
        None. The averaged scores are printed. When no response was scored
        a notice is printed instead of dividing by zero.
    """

    # Reading a module-level name needs no ``global`` declaration.
    emb_size = len(data_utils.embeddings[0])
    # Start/end-of-sequence markers: the vectors [0..emb_size-1] and
    # [1..emb_size], each with a leading dimension added by unsqueeze(0).
    SOS_token = torch.Tensor(list(range(emb_size))).unsqueeze(0).to(device)
    EOS_token = torch.Tensor(list(range(1, emb_size + 1))).unsqueeze(0).to(device)
    w2emb = data_utils.load_w2emb(args.w2emb)
    w2emb["SOS_token"] = SOS_token.cpu()
    w2emb["EOS_token"] = EOS_token.cpu()

    # Keep only the last max_length rows of every template, then pad the
    # first axis up to max_length so each class can become one tensor.
    templates = data_utils.load_templates(args.templates)
    templates = [[temp[-args.max_length:] for temp in part_templ]
                 for part_templ in templates]
    templates = [[
        np.pad(temp2, ((0, args.max_length - len(temp2)), (0, 0)),
               "constant",
               constant_values=(len(data_utils.w2i))) for temp2 in temp1
    ] for temp1 in templates]
    templates = [torch.Tensor(class_tm) for class_tm in templates]
    rewrite = Rewrite(args.saliency_model, args.rewrite_model,
                      data_utils.embeddings, data_utils.w2i, SOS_token,
                      EOS_token, templates, w2emb, device)
    prediction = ResourcePrediction(args.prediction_model_folder)

    rouge = Rouge()
    total = 0
    avg_rouge1 = 0
    avg_rouge2 = 0
    avg_rougeL = 0
    avg_bleu = 0

    smooth = SmoothingFunction()

    # (document key, class index) in the order the resources are gathered.
    resource_classes = (("comments", 2), ("fact_table", 3), ("plot", 0),
                        ("review", 1))

    for example in tqdm(data):
        resources = []
        embedded_resources = []
        class_indices = []

        # Keep track of where each resource originated from: every entry
        # that a document type adds to ``resources`` gets that type's index.
        for doc_key, class_index in resource_classes:
            count_before = len(resources)
            data_utils.get_resources(example["documents"][doc_key], resources,
                                     embedded_resources)
            class_indices += [class_index] * (len(resources) - count_before)

        chat = example["chat"]

        # Slide a window of three consecutive utterances (the context) over
        # the chat.
        for i in range(3, len(chat) - 1):
            last_utterances = chat[i - 3:i]
            # NOTE(review): the target is chat[i + 1], so chat[i] is neither
            # context nor target - confirm this is intended.
            response = chat[i + 1]

            if len(response) == 0:
                continue

            embedded_utterances = [
                data_utils.embed_sentence(utterance)
                for utterance in last_utterances
            ]
            # The second value returned by get_context is not used here.
            context, _ = data_utils.get_context(last_utterances)

            # Retrieve: Takes context and resources. Uses Word Mover's
            # Distance to obtain relevant resource candidates.
            similarities = retrieve(context, resources, word2vec)

            # Predict: Takes context and predicts the category of the
            # resource. Take the maximum length as max and pad the context
            # to maximum length if it is too short.
            if args.use_gensim:
                constant_values = len(data_utils.embeddings.index2word)
            else:
                constant_values = len(data_utils.w2i)

            # NOTE(review): index -2 selects the middle utterance of the
            # three-utterance window, not the final one - confirm.
            last_utterance = embedded_utterances[-2]
            padded_utterance = last_utterance[-args.max_length:]
            padded_utterance = np.pad(
                padded_utterance,
                ((0, args.max_length - len(padded_utterance)), (0, 0)),
                "constant",
                constant_values=(constant_values))
            if args.prediction:
                predicted = prediction.predict(
                    np.expand_dims(padded_utterance, 0))
            else:
                # Without a prediction model every class is equally likely.
                predicted = np.array([[0.25, 0.25, 0.25, 0.25]])

            # Rerank Resources: Takes ranked resource candidates and class
            # prediction and reranks them.
            ranked_resources, ranked_classes = rerank(
                embedded_resources, class_indices, similarities, predicted)

            # Rerank Templates: Takes best resource and ranks the templates
            # accordingly. Returns the best template.
            best_resource, best_template = rewrite.rerank(
                ranked_resources[0], ranked_classes[0])

            # Rewrite: Takes the best resource and best template and
            # rewrites them into a single response.
            best_response = rewrite.rewrite(best_resource, best_template)
            total += 1
            rouge_scores = rouge.get_scores(best_response, response)[0]
            avg_rouge1 += rouge_scores["rouge-1"]["f"]
            avg_rouge2 += rouge_scores["rouge-2"]["f"]
            avg_rougeL += rouge_scores["rouge-l"]["f"]
            # NOTE(review): sentence_bleu receives the same objects as
            # rouge.get_scores. Presumably this is NLTK's sentence_bleu,
            # which works on token sequences; plain strings would be scored
            # per character - confirm the expected input type.
            avg_bleu += sentence_bleu([response],
                                      best_response,
                                      smoothing_function=smooth.method1)

    # With empty data, or when no window had a non-empty response, there is
    # nothing to average and the divisions below would raise.
    if total == 0:
        print("No responses were evaluated; cannot compute average scores.")
        return

    print("Average rouge1: " + str(avg_rouge1 / total))
    print("Average rouge2: " + str(avg_rouge2 / total))
    print("Average rougel: " + str(avg_rougeL / total))
    print("Average bleu: " + str(avg_bleu / total))
示例#2
0
 def setUp(self):
     """Create a mock handler and a Rewrite rooted at the local testdata folder."""
     self.handler = MockRequestHandler()
     # The fixtures live in a 'testdata' directory next to this test module.
     testdata_dir = os.path.join(os.path.dirname(__file__), 'testdata')
     self.rewrite = Rewrite(self.handler)
     self.rewrite.root = testdata_dir
示例#3
0
 def test_init_define(self):
     """A mime mapping given to the constructor is what get_mime returns."""
     custom_mime = {'.js': 'something else'}
     self.rewrite = Rewrite(self.handler, mime=custom_mime)
     looked_up = self.rewrite.get_mime('.js')
     self.assertEqual(looked_up, 'something else')
示例#4
0
class test_rewrite_function(unittest.TestCase):
    """Tests for Rewrite: mime lookup, rewrite callbacks, conf and rule handling."""

    def setUp(self):
        """Create a mock handler and a Rewrite rooted at the local testdata folder."""
        self.handler = MockRequestHandler()
        # The fixtures live in a 'testdata' directory next to this test module.
        testdata_dir = os.path.join(os.path.dirname(__file__), 'testdata')
        self.rewrite = Rewrite(self.handler)
        self.rewrite.root = testdata_dir

    def test_inner_predefine(self):
        """'.js' resolves to 'text/javascript' without any custom mapping."""
        mime = self.rewrite.get_mime('.js')
        self.assertEqual(mime, 'text/javascript')

    def test_init_define(self):
        """A mime mapping given to the constructor is what get_mime returns."""
        custom_mime = {'.js': 'something else'}
        self.rewrite = Rewrite(self.handler, mime=custom_mime)
        mime = self.rewrite.get_mime('.js')
        self.assertEqual(mime, 'something else')

    def test_predefine(self):
        """'.bmp' resolves to 'image/x-ms-bmp'."""
        mime = self.rewrite.get_mime('.bmp')
        self.assertEqual(mime, 'image/x-ms-bmp')

    def test_none_define(self):
        """The unknown extension '.kkk' resolves to 'application/x-kkk'."""
        mime = self.rewrite.get_mime('.kkk')
        self.assertEqual(mime, 'application/x-kkk')

    def test_rewrite_callback(self):
        """A callback registered for 'test' is invoked with the matched URL."""
        state = {'called': False}

        def on_match(url):
            # Record the call first so it is visible even if the check fails.
            state['called'] = True
            self.assertEqual(url, 'test?param=1')

        self.rewrite.add_rewrite_callback('test', on_match)
        matched = self.rewrite.match('test?param=1')
        self.assertEqual(matched, True)
        self.assertEqual(state['called'], True)

    def test_rewrite_callback_fail(self):
        """A callback registered for 'test1' is not invoked for 'test?param=1'."""
        state = {'called': False}

        def on_match(url):
            state['called'] = True
            self.assertEqual(url, 'test?param=1')

        self.rewrite.add_rewrite_callback('test1', on_match)
        matched = self.rewrite.match('test?param=1')
        self.assertEqual(matched, False)
        # The callback must never have run.
        self.assertEqual(state['called'], False)

    def test_get_confs(self):
        """get_confs on the server-conf folder yields only test.conf."""
        root = self.rewrite.root
        expected_conf = os.path.normpath(
            os.path.join(root, 'server-conf/test.conf'))
        conf_dir = os.path.normpath(os.path.join(root, 'server-conf/'))
        self.assertEqual(self.rewrite.get_confs(conf_dir), [expected_conf])

    def test_get_rulers(self):
        """get_rulers returns one rewrite rule followed by one redirect rule."""
        rewrite_rule = {
            "type": "rewrite",
            "rewrite": "/test/data/test.py",
            "rule": "^runpytest",
        }
        redirect_rule = {
            "type": "redirect",
            "rewrite": "/runpytest?from=redirect",
            "rule": "^redirec.*est",
        }
        self.assertEqual(self.rewrite.get_rulers(),
                         [rewrite_rule, redirect_rule])

    def test_redirect(self):
        """Matching 'redirecttest' sets the handler's redirect target."""
        matched = self.rewrite.match('redirecttest')
        self.assertEqual(matched, True)
        self.assertEqual(self.handler.direct, '/runpytest?from=redirect')

    def test_rewrite_py(self):
        """Matching 'runpytest' leaves a JSON body and JSON content type on the handler."""
        matched = self.rewrite.match('runpytest')
        self.assertEqual(matched, True)
        # Parse the body before checking headers, so malformed JSON fails first.
        payload = json.loads(self.handler.content)
        content_type = self.handler.headers['Content-Type']
        self.assertEqual(content_type, 'application/json')
        self.assertEqual(payload['user'], 'hefangshi')