示例#1
0
def main():
    """Train a LanguageModel or run its prediction, as selected on the command line.

    The parsed arguments choose one of three paths:

    * ``args.train`` and ``args.data_path`` both set: call ``lm.train`` on
      ``args.data_path``, passing ``args.train`` as ``output_path`` along with
      the learning-rate, hidden-size, batch-size and max-epoch arguments.
    * otherwise, ``args.test`` and ``args.data_path`` both set: call
      ``lm.predict(args.test, args.data_path)``.
    * otherwise: print the usage help and exit with status 2.

    Raises:
        SystemExit: with code 2 when neither mode has the arguments it needs.
    """
    p = get_argparser()
    args = p.parse_args()

    lm = LanguageModel()
    # args.DEBUG switches the log level to DEBUG (INFO otherwise);
    # write_file=True presumably also logs to a file -- see configure_logger.
    lm.configure_logger(level=logging.DEBUG if args.DEBUG else logging.INFO,
                        write_file=True)

    if args.train and args.data_path:
        lm.train(args.data_path,
                 output_path=args.train,
                 learning_rate=args.learning_rate,
                 hidden_size=args.hidden_size,
                 batch_size=args.batch_size,
                 max_epoch=args.max_epoch)

    elif args.test and args.data_path:
        lm.predict(args.test, args.data_path)

    else:
        # Neither training nor testing was fully specified: show usage and
        # exit with status 2, the same code argparse uses for usage errors.
        p.print_help()
        # The bare `exit` builtin is injected by the `site` module and is
        # missing under `python -S` or in frozen apps; SystemExit always works.
        raise SystemExit(2)
示例#2
0
class DialogBackendLocal(DialogBackend):
    """Dialog backend that runs all of its models in-process.

    Candidate responses come from two sources:

    * a ``LanguageModel``;
    * a ``ContentTransfer`` model fed with passages built from
      ``KnowledgeBase`` lookups.

    A ``Ranker`` then scores all candidates together.
    """

    # File listing one knowledge-base site per line. The path is resolved
    # relative to the current working directory. Subclasses may override it.
    KB_SITES_PATH = 'args/kb_sites.txt'

    def __init__(self):
        super().__init__()

        self.model_lm = LanguageModel()
        self.model_ct = ContentTransfer()
        self.kb = KnowledgeBase()
        # The ranker is built around the language model; presumably it uses
        # that model to score hypotheses (confirm in Ranker).
        self.ranker = Ranker(self.model_lm)
        # Flags this backend as local; presumably inspected by callers.
        self.local = True

    def predict(self, context, max_n=1):
        """Generate, score and rank candidate responses for ``context``.

        Steps:

        1. Collect hypotheses from the language model.
        2. For each site listed in ``KB_SITES_PATH``, query the knowledge
           base with ``self.get_query(context)``. Keep its first hit, which
           is unpacked as a 2-tuple whose second item is the snippet.
        3. Build passages from those snippets and add the content-transfer
           model's hypotheses for each passage.
        4. Score every hypothesis with the ranker.

        Args:
            context: dialog context. It is passed to ``get_query``,
                ``model_lm.predict`` and ``ranker.predict``.
            max_n: keep at most this many top-ranked results. A value
                <= 0 keeps them all.

        Returns:
            ``(ranked, url_snippet)``.

            * ``ranked`` is the ranker's dicts, sorted by ``'score'``
              (highest first). Each dict also gets ``'way'`` and ``'hyp'``
              set from the matching hypothesis triple.
            * ``url_snippet`` holds the first knowledge-base hit for each
              site, presumably ``(url, snippet)`` pairs.
        """
        # Wrap in a 1-tuple so a tuple-typed context cannot be mistaken for
        # several %-format arguments.
        print('backend running, context = %s' % (context,))
        query = self.get_query(context)

        # get results from different models. Copy so that adding the
        # content-transfer hypotheses below never mutates a list the
        # language model returned (and may still hold).
        results = list(self.model_lm.predict(context))

        passages = []
        url_snippet = []
        with open(self.KB_SITES_PATH, encoding='utf-8') as sites:
            for line in sites:
                cust = line.strip('\n')
                kb_args = {'domain': 'cust', 'cust': cust, 'must_include': []}
                url_snippet.append(self.kb.predict(query, args=kb_args)[0])
                # NOTE(review): each passage joins the snippets of *all*
                # sites read so far, so passage k covers sites 0..k. Confirm
                # this cumulative grounding is intended, rather than one
                # passage per site or a single passage after the loop.
                passage = ' ... '.join(snippet for _, snippet in url_snippet)
                passages.append((passage, query))

        for passage, kb_query in passages:
            results += self.model_ct.predict(kb_query, passage)

        # rank hyps from different models. Each result is a 3-tuple
        # (way, <unused>, hyp); the ranker returns one dict per hyp, in the
        # same order.
        hyps = [hyp for _, _, hyp in results]
        scored = self.ranker.predict(context, hyps)
        annotated = []
        for i, d in enumerate(scored):
            d['way'], _, d['hyp'] = results[i]
            annotated.append(d)
        # Sort on the score alone. Sorting (score, dict) tuples falls back to
        # comparing the dicts whenever two scores tie, which raises TypeError.
        # The sort is stable, so tied hypotheses keep their original order.
        ranked = sorted(annotated, key=lambda d: d['score'], reverse=True)
        if max_n > 0:
            ranked = ranked[:max_n]
        return ranked, url_snippet