コード例 #1
0
ファイル: kgsf.py プロジェクト: linksboy/CRSLab
 def conv_evaluate(self, prediction, response):
     """Feed decoded prediction/reference pairs to the generation evaluator.

     Both arguments must expose ``tolist()`` yielding one sequence of token
     indices per sample. Each sequence is decoded with ``ind2txt`` using
     ``self.ind2tok`` and ``self.end_token_idx``. Samples are paired by
     position, and pairing stops at the shorter of the two inputs.

     Args:
         prediction: token indices produced by the model, one row per sample.
         response: reference token indices, one row per sample.
     """
     for generated_ids, reference_ids in zip(prediction.tolist(),
                                             response.tolist()):
         hypothesis = ind2txt(generated_ids, self.ind2tok,
                              self.end_token_idx)
         reference = ind2txt(reference_ids, self.ind2tok,
                             self.end_token_idx)
         # The evaluator takes the references as a list; there is exactly
         # one reference per hypothesis here.
         self.evaluator.gen_evaluate(hypothesis, [reference])
コード例 #2
0
    def conv_evaluate(self, prediction, response):
        """Decode each prediction/reference pair and hand it to the evaluator.

        Args:
            prediction: torch.LongTensor, shape=(bs, response_truncate-1)
            response: torch.LongTensor, shape=(bs, response_truncate)

        Every response row begins with <|endoftext|>, which has no
        counterpart in prediction, so that leading token is dropped before
        the reference is decoded.
        """
        predicted_rows = prediction.tolist()
        reference_rows = response.tolist()
        for predicted_ids, reference_ids in zip(predicted_rows,
                                                reference_rows):
            hypothesis = ind2txt(predicted_ids, self.ind2tok,
                                 self.end_token_idx)
            # Skip the leading token so the reference lines up with the
            # prediction.
            target_ids = reference_ids[1:]
            reference = ind2txt(target_ids, self.ind2tok,
                                self.end_token_idx)
            # One reference per hypothesis, passed as a single-element list.
            self.evaluator.gen_evaluate(hypothesis, [reference])
コード例 #3
0
    def interact(self):
        """Run the interactive recommend/respond loop.

        After ``init_interact()``, user text is read with ``get_input`` and
        the loop repeats until ``self.finished`` is truthy. Nothing in this
        method sets ``self.finished``; presumably ``get_input`` or another
        helper does -- TODO confirm.

        On each turn:

        * If the instance has a ``rec_model``, the input is scored, the
          highest-scoring candidates from ``self.item_ids`` (at most 10) are
          selected, the best one is recorded in the context, and the names
          found in ``self.id2entity`` are printed.
        * If the instance has a ``conv_model``, a response is generated,
          decoded with ``ind2txt``, written back into the context and
          printed.
        """
        self.init_interact()
        input_text = self.get_input(self.language)
        while not self.finished:
            # rec
            if hasattr(self, 'rec_model'):
                rec_input = self.process_input(input_text, 'rec')
                scores = self.rec_model.forward(rec_input, 'infer')

                # Keep the first row only (presumably a batch of one --
                # TODO confirm) and restrict it to the candidate items.
                scores = scores.cpu()[0]
                scores = scores[self.item_ids]
                # torch.topk raises when k exceeds the size of the selected
                # dimension, so clamp k for small candidate sets.
                top_k = min(10, scores.size(-1))
                _, rank = torch.topk(scores, top_k, dim=-1)
                # rank holds positions within self.item_ids; map them back
                # to the actual item ids.
                item_ids = [self.item_ids[r] for r in rank.tolist()]
                # Only the top-ranked item is written into the context.
                first_item_id = item_ids[:1]
                self.update_context('rec',
                                    entity_ids=first_item_id,
                                    item_ids=first_item_id)

                print("[Recommend]:")
                for item_id in item_ids:
                    # Items without a known entity name are skipped silently.
                    if item_id in self.id2entity:
                        print(self.id2entity[item_id])
            # conv
            if hasattr(self, 'conv_model'):
                conv_input = self.process_input(input_text, 'conv')
                # Take the first generated sequence of the model output.
                preds = self.conv_model.forward(conv_input,
                                                'infer').tolist()[0]
                p_str = ind2txt(preds, self.ind2tok, self.end_token_idx)

                # Convert the generated text back to ids so it becomes part
                # of the context for later turns.
                token_ids, entity_ids, movie_ids, word_ids = self.convert_to_id(
                    p_str, 'conv')
                self.update_context('conv', token_ids, entity_ids, movie_ids,
                                    word_ids)

                print(f"[Response]:\n{p_str}")
            # input
            input_text = self.get_input(self.language)