def _collect_resources(documents):
    """Gather all resources of one example plus the class index of each.

    ``data_utils.get_resources`` is called once per document kind with the
    shared ``resources`` / ``embedded_resources`` lists; the number of
    resources each kind contributed is derived from how much ``resources``
    grew during that call.

    Args:
        documents: mapping with the keys ``"comments"``, ``"fact_table"``,
            ``"plot"`` and ``"review"``.

    Returns:
        A ``(resources, embedded_resources, class_indices)`` tuple, where
        ``class_indices`` holds one class index per entry of ``resources``.
    """
    resources = []
    embedded_resources = []
    class_indices = []
    # (document key, class index), in the order the resources are gathered.
    sources = (("comments", 2), ("fact_table", 3), ("plot", 0), ("review", 1))
    for key, class_index in sources:
        already_collected = len(resources)
        data_utils.get_resources(documents[key], resources,
                                 embedded_resources)
        # Keep track of where each resource originated from.
        class_indices += [class_index] * (len(resources) - already_collected)
    return resources, embedded_resources, class_indices


def run(data, word2vec):
    """
    Retrieve, rerank, rewrite.

    For every example in ``data`` and every chat position with a non-empty
    response, a response is produced by retrieving resources, reranking them
    with the (optional) class prediction, picking a template and rewriting.
    The produced response is scored against the reference with ROUGE-1/2/L
    (F-score) and smoothed BLEU, and the averages are printed.

    Args:
        data: iterable of examples; each example is indexed with
            ``["documents"]`` (see ``_collect_resources``) and ``["chat"]``.
        word2vec: passed through unchanged to ``retrieve``.

    Returns:
        None. The averages are printed; if nothing was evaluated a notice is
        printed instead of dividing by zero.
    """
    emb_size = len(data_utils.embeddings[0])
    # Start/end markers are fixed vectors: 0..emb_size-1 and 1..emb_size.
    SOS_token = torch.Tensor(list(range(emb_size))).unsqueeze(0).to(device)
    EOS_token = torch.Tensor(
        list(range(1, emb_size + 1))).unsqueeze(0).to(device)

    w2emb = data_utils.load_w2emb(args.w2emb)
    w2emb["SOS_token"] = SOS_token.cpu()
    w2emb["EOS_token"] = EOS_token.cpu()

    # Keep at most the last ``max_length`` entries of every template, then pad
    # the first axis up to ``max_length`` so each group stacks into a tensor.
    templates = data_utils.load_templates(args.templates)
    templates = [[temp[-args.max_length:] for temp in part_templ]
                 for part_templ in templates]
    templates = [[
        np.pad(temp2, ((0, args.max_length - len(temp2)), (0, 0)),
               "constant", constant_values=(len(data_utils.w2i)))
        for temp2 in temp1
    ] for temp1 in templates]
    templates = [torch.Tensor(class_tm) for class_tm in templates]

    rewrite = Rewrite(args.saliency_model, args.rewrite_model,
                      data_utils.embeddings, data_utils.w2i, SOS_token,
                      EOS_token, templates, w2emb, device)
    prediction = ResourcePrediction(args.prediction_model_folder)
    rouge = Rouge()

    total = 0
    avg_rouge1 = 0
    avg_rouge2 = 0
    avg_rougeL = 0
    avg_bleu = 0
    smooth = SmoothingFunction()

    for example in tqdm(data):
        resources, embedded_resources, class_indices = _collect_resources(
            example["documents"])
        chat = example["chat"]

        # Slide over the chat: three consecutive utterances form the context.
        # NOTE(review): the reference response is chat[i + 1], so chat[i] is
        # neither context nor target -- confirm this is intended.
        for i in range(3, len(chat) - 1):
            last_utterances = chat[i - 3:i]
            response = chat[i + 1]
            if len(response) == 0:
                # Nothing to score against.
                continue

            embedded_utterances = [
                data_utils.embed_sentence(utterance)
                for utterance in last_utterances
            ]
            context, _ = data_utils.get_context(last_utterances)

            # Retrieve: Takes context and resources. Uses Word Mover's
            # Distance to obtain relevant resource candidates.
            similarities = retrieve(context, resources, word2vec)

            # Predict: Takes context and predicts the category of the
            # resource. Take the maximum length as max and pad the context
            # to maximum length if it is too short.
            if args.use_gensim:
                constant_values = len(data_utils.embeddings.index2word)
            else:
                constant_values = len(data_utils.w2i)

            # NOTE(review): index -2 selects the middle of the three context
            # utterances, not the most recent one -- confirm.
            last_utterance = embedded_utterances[-2]
            padded_utterance = last_utterance[-args.max_length:]
            padded_utterance = np.pad(
                padded_utterance,
                ((0, args.max_length - len(padded_utterance)), (0, 0)),
                "constant",
                constant_values=(constant_values))

            if args.prediction:
                predicted = prediction.predict(
                    np.expand_dims(padded_utterance, 0))
            else:
                # No predictor: every class is equally likely.
                predicted = np.array([[0.25, 0.25, 0.25, 0.25]])

            # Rerank Resources: Takes ranked resource candidates and class
            # prediction and reranks them.
            ranked_resources, ranked_classes = rerank(
                embedded_resources, class_indices, similarities, predicted)

            # Rerank Templates: Takes best resource and ranks the templates
            # accordingly. Returns the best template.
            best_resource, best_template = rewrite.rerank(
                ranked_resources[0], ranked_classes[0])

            # Rewrite: Takes the best resource and best template and
            # rewrites them into a single response.
            best_response = rewrite.rewrite(best_resource, best_template)

            total += 1
            rouge_scores = rouge.get_scores(best_response, response)[0]
            avg_rouge1 += rouge_scores["rouge-1"]["f"]
            avg_rouge2 += rouge_scores["rouge-2"]["f"]
            avg_rougeL += rouge_scores["rouge-l"]["f"]
            # NOTE(review): sentence_bleu operates on token sequences; if
            # response / best_response are plain strings, n-grams are taken
            # over characters rather than words -- confirm.
            avg_bleu += sentence_bleu([response], best_response,
                                      smoothing_function=smooth.method1)

    if total == 0:
        # Previously this fell through to a ZeroDivisionError.
        print("No responses were evaluated; averages are undefined.")
        return

    print("Average rouge1: " + str(avg_rouge1 / total))
    print("Average rouge2: " + str(avg_rouge2 / total))
    print("Average rougel: " + str(avg_rougeL / total))
    print("Average bleu: " + str(avg_bleu / total))
def setUp(self):
    """Build a mock request handler and a ``Rewrite`` rooted at ``testdata``.

    The root is the ``testdata`` directory next to this file.
    """
    testdata_dir = os.path.join(os.path.dirname(__file__), 'testdata')
    handler = MockRequestHandler()
    self.handler = handler
    self.rewrite = Rewrite(handler)
    self.rewrite.root = testdata_dir
def test_init_define(self):
    """A ``mime`` mapping given to the constructor is used by ``get_mime``."""
    custom_mime = {'.js': 'something else'}
    self.rewrite = Rewrite(self.handler, mime=custom_mime)
    resolved = self.rewrite.get_mime('.js')
    self.assertEqual(resolved, 'something else')
class test_rewrite_function(unittest.TestCase):
    """Tests for ``Rewrite``: MIME lookup, callbacks, confs and rule matching."""

    def setUp(self):
        """Build a mock handler and a ``Rewrite`` rooted at ``testdata``."""
        testdata_dir = os.path.join(os.path.dirname(__file__), 'testdata')
        self.handler = MockRequestHandler()
        self.rewrite = Rewrite(self.handler)
        self.rewrite.root = testdata_dir

    def _recording_callback(self, expected_url):
        """Return ``(fired, callback)``.

        ``fired`` is a one-element list whose item becomes ``True`` once the
        callback runs; the callback also asserts the URL it receives.
        """
        fired = [False]

        def on_match(url):
            fired[0] = True
            self.assertEqual(url, expected_url)

        return fired, on_match

    def _under_root(self, relative):
        """Return the normalised path of ``relative`` below the rewrite root."""
        return os.path.normpath(os.path.join(self.rewrite.root, relative))

    def test_inner_predefine(self):
        """``.js`` resolves to ``text/javascript`` without extra config."""
        mime = self.rewrite.get_mime('.js')
        self.assertEqual(mime, 'text/javascript')

    def test_init_define(self):
        """A ``mime`` mapping given to the constructor is used."""
        custom_mime = {'.js': 'something else'}
        self.rewrite = Rewrite(self.handler, mime=custom_mime)
        self.assertEqual(self.rewrite.get_mime('.js'), 'something else')

    def test_predefine(self):
        """``.bmp`` resolves to ``image/x-ms-bmp``."""
        mime = self.rewrite.get_mime('.bmp')
        self.assertEqual(mime, 'image/x-ms-bmp')

    def test_none_define(self):
        """An unknown extension resolves to ``application/x-<ext>``."""
        mime = self.rewrite.get_mime('.kkk')
        self.assertEqual(mime, 'application/x-kkk')

    def test_rewrite_callback(self):
        """A registered callback runs for a matching URL."""
        fired, on_match = self._recording_callback('test?param=1')
        self.rewrite.add_rewrite_callback('test', on_match)
        matched = self.rewrite.match('test?param=1')
        self.assertEqual(matched, True)
        self.assertEqual(fired[0], True)

    def test_rewrite_callback_fail(self):
        """A callback registered for another rule does not run."""
        fired, on_match = self._recording_callback('test?param=1')
        self.rewrite.add_rewrite_callback('test1', on_match)
        matched = self.rewrite.match('test?param=1')
        self.assertEqual(matched, False)
        # The callback must never have been invoked.
        self.assertEqual(fired[0], False)

    def test_get_confs(self):
        """``get_confs`` lists the conf file found in ``server-conf``."""
        expect = [self._under_root('server-conf/test.conf')]
        conf_path = self._under_root('server-conf/')
        self.assertEqual(self.rewrite.get_confs(conf_path), expect)

    def test_get_rulers(self):
        """``get_rulers`` returns the expected rewrite and redirect rules."""
        rewrite_rule = {
            "type": "rewrite",
            "rewrite": "/test/data/test.py",
            "rule": "^runpytest",
        }
        redirect_rule = {
            "type": "redirect",
            "rewrite": "/runpytest?from=redirect",
            "rule": "^redirec.*est",
        }
        expect = [rewrite_rule, redirect_rule]
        self.assertEqual(self.rewrite.get_rulers(), expect)

    def test_redirect(self):
        """A URL matching the redirect rule sets the handler's target."""
        matched = self.rewrite.match('redirecttest')
        self.assertEqual(matched, True)
        self.assertEqual(self.handler.direct, '/runpytest?from=redirect')

    def test_rewrite_py(self):
        """A URL matching the rewrite rule yields JSON content on the handler."""
        matched = self.rewrite.match('runpytest')
        self.assertEqual(matched, True)
        actual = json.loads(self.handler.content)
        content_type = self.handler.headers['Content-Type']
        self.assertEqual(content_type, 'application/json')
        self.assertEqual(actual['user'], 'hefangshi')