コード例 #1
0
ファイル: profiles.py プロジェクト: fuminorihomma/consistency
    def check_conflict(self, sys_texts, sys_labels):
        """Screen candidate system sentences against the stored profile.

        Every (sentence, label) pair is scored with
        ``is_repetition_with_context`` against all sentences held in
        ``self.profile``.  Only an inquiry for which
        ``self.is_inquiry_answered`` is true is marked ``cfg.REPETITION``;
        every other sentence is marked ``cfg.PASS``.

        e.g. user: i want to donate a dollar
             sys: how much will you donate?

        Returns:
            A 4-tuple ``(status, repetition_ratio, sentences, labels)``:

            * exactly one candidate: its own status and ratio, together with
              ``sys_texts`` / ``sys_labels`` unchanged;
            * otherwise, if at least one candidate passed: ``cfg.PASS``, the
              ratio of the surviving sentences joined by spaces, and the
              surviving sentences and labels;
            * otherwise: ``cfg.NOT_PASS``, the ratio of all sentences joined
              by spaces, and the original ``sys_texts`` / ``sys_labels``.
        """

        def score_against_profile(text):
            # The chain is rebuilt on every call because it is a one-shot iterator.
            return is_repetition_with_context(
                text,
                itertools.chain(*self.profile.values()),
                threshold=cfg.repetition_threshold)

        def judge(text, label):
            repeated, ratio = score_against_profile(text)

            if "inquiry" not in label:
                # Non-inquiries always pass; a detected repetition is only logged.
                if repeated and cfg.debug:
                    print("exact repetition with user utterance encountered in user_profile check! {}: {}".format(label, text))
                return cfg.PASS, ratio

            # e.g. sys: have you heard of save the children?
            #      usr: i have heard of it
            #      sys: do you know about save the children?
            if not self.is_inquiry_answered(text, label):
                # "Fake" repetition: the user never replied to the inquiry,
                # so asking again is allowed.
                return cfg.PASS, ratio

            # Real repetition: the inquiry has already been answered.
            if cfg.debug:
                print("{} inquiry encountered in user_profile check! {}: {}".format(cfg.REPETITION, label, text))
            return cfg.REPETITION, ratio

        verdicts = [judge(text, label) for text, label in zip(sys_texts, sys_labels)]

        if len(sys_texts) == 1:
            # A single candidate is reported as-is, never filtered.
            only_status, only_ratio = verdicts[0]
            return only_status, only_ratio, sys_texts, sys_labels

        survivors = [
            (text, label)
            for (status, _), text, label in zip(verdicts, sys_texts, sys_labels)
            if status == cfg.PASS
        ]

        if survivors:
            kept_texts = [text for text, _ in survivors]
            kept_labels = [label for _, label in survivors]
            _, ratio = score_against_profile(" ".join(kept_texts))
            return cfg.PASS, ratio, kept_texts, kept_labels

        # Nothing survived: report the ratio of the full, unfiltered utterance.
        _, ratio = score_against_profile(" ".join(sys_texts))
        return cfg.NOT_PASS, ratio, sys_texts, sys_labels
コード例 #2
0
ファイル: profiles.py プロジェクト: fuminorihomma/consistency
        def check_conflict_for_one_utt(sys_text, sys_label):
            """Return ``(status, repetition_ratio)`` for one candidate sentence.

            The ratio always comes from ``is_repetition_with_context`` run
            against every sentence in ``self.profile``.  The status is
            ``cfg.REPETITION`` only for an inquiry that
            ``self.is_inquiry_answered`` reports as answered; in every other
            case it is ``cfg.PASS``.
            """
            profile_sentences = itertools.chain(*self.profile.values())
            is_repetition, repetition_ratio = is_repetition_with_context(
                sys_text,
                profile_sentences,
                threshold=cfg.repetition_threshold)

            if "inquiry" not in sys_label:
                # Statements always pass; an exact repetition is only logged.
                if is_repetition and cfg.debug:
                    print("exact repetition with user utterance encountered in user_profile check! {}: {}".format(sys_label, sys_text))
                return cfg.PASS, repetition_ratio

            # e.g. sys: have you heard of save the children?
            #      usr: i have heard of it
            #      sys: do you know about save the children?
            if not self.is_inquiry_answered(sys_text, sys_label):
                # Fake repetition: the user never replied to the inquiry.
                return cfg.PASS, repetition_ratio

            # Real repetition: the same inquiry has already been answered.
            if cfg.debug:
                print("{} inquiry encountered in user_profile check! {}: {}".format(cfg.REPETITION, sys_label, sys_text))
            return cfg.REPETITION, repetition_ratio
コード例 #3
0
        def check_conflict_for_one_utt(sys_text, sys_label):
            """Return ``(status, repetition_ratio)`` for one candidate sentence.

            The sentence is scored with ``is_repetition_with_context`` against
            every sentence in ``self.sent_profile``.

            * Inquiries (labels containing ``"inquiry"``) always get
              ``cfg.PASS``; no repetition check is applied to them here.
            * A statement that is not a repetition gets ``cfg.PASS``.
            * A repeated statement gets ``cfg.PASS`` when ``self.is_qa_pair()``
              is true (a "fake" repetition) and ``cfg.REPETITION`` otherwise.
            """
            is_repetition, repetition_ratio = is_repetition_with_context(
                sys_text,
                itertools.chain(*self.sent_profile.values()),
                threshold=cfg.repetition_threshold)

            # 1. inquiry, or 2. statement that does not repeat anything
            if "inquiry" in sys_label or not is_repetition:
                return cfg.PASS, repetition_ratio

            # 2.1 fake repetition: when the user asks, the system may have to
            # repeat a similar answer.  Ideas noted for handling this properly:
            #   1) external database support, or
            #   2) if 20 sampled candidates are all the same, there is
            #      probably only one best answer to the question.
            if self.is_qa_pair():
                return cfg.PASS, repetition_ratio

            # 2.2 real repetition
            if cfg.verbose:
                print("{} encountered! {}: {}".format(cfg.REPETITION, sys_label, sys_text))
            return cfg.REPETITION, repetition_ratio
コード例 #4
0
ファイル: profiles.py プロジェクト: fuminorihomma/consistency
    def check_conflict(self, sys_texts, sys_labels):
        """Check system candidate sentences for conflicts with the profile.

        Every (sentence, label) pair is scored with
        ``is_repetition_with_context`` against all sentences in
        ``self.profile`` and classified as ``cfg.PASS`` or
        ``cfg.REPETITION``:

        * label not in ``self.profile``: ``cfg.REPETITION`` if the sentence
          is a repetition, else ``cfg.PASS``;
        * label in ``self.profile`` and containing ``"inquiry"``: always
          ``cfg.PASS``;
        * label in ``self.profile``, statement: ``cfg.REPETITION`` only when
          the sentence is a repetition and ``self.is_qa_pair()`` is false.

        Returns:
            A 4-tuple ``(status, repetition_ratio, sentences, labels)``:

            * exactly one candidate: its own status and ratio, together with
              ``sys_texts`` / ``sys_labels`` unchanged;
            * otherwise, if at least one candidate passed: ``cfg.PASS``, the
              ratio of the passing sentences joined by spaces, and the
              passing sentences and labels;
            * otherwise: ``cfg.NOT_PASS``, the ratio of all sentences joined
              by spaces, and the original ``sys_texts`` / ``sys_labels``.
        """

        def repetition_with_profile(text):
            # A fresh chain per call: the iterator is consumed by each check.
            return is_repetition_with_context(
                text,
                itertools.chain(*self.profile.values()),
                threshold=cfg.repetition_threshold)

        def check_conflict_for_one_utt(sys_text, sys_label):
            is_repetition, repetition_ratio = repetition_with_profile(sys_text)

            if sys_label not in self.profile:
                # Label never seen in the profile: any repetition is a conflict.
                status = cfg.REPETITION if is_repetition else cfg.PASS
                return status, repetition_ratio

            # 1. inquiry, or 2. statement that does not repeat anything
            if "inquiry" in sys_label or not is_repetition:
                return cfg.PASS, repetition_ratio

            # 2.1 fake repetition: when the user asks, the system may have to
            # repeat a similar answer.  Ideas noted for handling this properly:
            #   1) external database support, or
            #   2) if 20 sampled candidates are all the same, there is
            #      probably only one best answer to the question.
            if self.is_qa_pair():
                return cfg.PASS, repetition_ratio

            # 2.2 real repetition (this message is printed unconditionally)
            print("{} encountered! {}: {}".format(cfg.REPETITION, sys_label, sys_text))
            return cfg.REPETITION, repetition_ratio

        results = [
            check_conflict_for_one_utt(sys_text, sys_label)
            for sys_text, sys_label in zip(sys_texts, sys_labels)
        ]

        if len(sys_texts) == 1:
            # A single candidate is reported as-is, never filtered.
            conflict_status, conflict_amount = results[0]
            return conflict_status, conflict_amount, sys_texts, sys_labels

        # Keep only the sentences (and their labels) that passed the check.
        edited_sents = []
        edited_sent_acts = []
        for (status, _), sys_text, sys_label in zip(results, sys_texts, sys_labels):
            if status == cfg.PASS:
                edited_sents.append(sys_text)
                edited_sent_acts.append(sys_label)

        if not edited_sents:
            # Every sentence conflicted: score the full, unfiltered utterance.
            _, repetition_ratio = repetition_with_profile(" ".join(sys_texts))
            return cfg.NOT_PASS, repetition_ratio, sys_texts, sys_labels

        _, repetition_ratio = repetition_with_profile(" ".join(edited_sents))
        return cfg.PASS, repetition_ratio, edited_sents, edited_sent_acts
コード例 #5
0
    def enforce(self, sents, sent_acts, past):
        """Decide whether the donation-proposal rule applies to this turn.

        The rule is only considered once ``self.chatbot.turn_i`` has reached
        ``cfg.HAVE_TO_PROPOSE``, and then only if the user profile's
        ``WANT_TO_DONATE`` entry differs from ``INIT`` or no
        ``SystemAct.propose_donation_inquiry`` is recorded in the system's
        ``sent_profile``.

        Args:
            sents: not used by this method.
            sent_acts: iterable of acts for the candidate; each element is
                compared against ``SystemAct.propose_donation_inquiry``.
            past: not used by this method.

        Returns:
            None: no rule needs to be enforced (too early in the dialogue,
                the rule condition does not hold, or the first enforced
                template would repeat something already sent).
            True: ``sent_acts`` already contains a donation proposal.
            tuple: ``(enforced_templates, enforced_acts)`` when no candidate
                act proposes a donation, so the caller should append these
                sentences.
        """
        # if cfg.rl_finetune:
        #     return None
        if cfg.verbose:
            print("\n\n\n--------- rule enforce --------------")

        if self.chatbot.turn_i < cfg.HAVE_TO_PROPOSE:
            if cfg.verbose:
                print("case 5")
            return None

        # Have to propose donation at this turn if it hasn't been proposed yet.
        enforced_acts = [
            SystemAct.propose_donation_inquiry,
            SystemAct.PROVIDE_DONATION_PROCEDURE,
        ]
        # Templates are fetched before the eligibility check on purpose, so
        # that get_template is called exactly when it was before.
        enforced_templates = self.sys_template.get_template(enforced_acts)

        global_profile = self.chatbot.global_profile
        domain = self.chatbot.domain
        rule_applies = (
            global_profile.usr_world.usr_profile[domain.WANT_TO_DONATE] != domain.INIT
            or SystemAct.propose_donation_inquiry not in global_profile.sys_world.sent_profile
        )
        if not rule_applies:
            if cfg.verbose:
                print("case 4")
            return None

        # The enforced template must not itself repeat something already sent.
        is_repetition, _ = is_repetition_with_context(
            enforced_templates[0],
            itertools.chain(*global_profile.sys_world.sent_profile.values()),
            threshold=cfg.repetition_threshold)
        if is_repetition:
            if cfg.verbose:
                print("case 1")
                print(enforced_templates[0])
            return None

        if any(act == SystemAct.propose_donation_inquiry for act in sent_acts):
            if cfg.verbose:
                print("case 2")
            return True

        # Didn't find an appropriate candidate, so hand back the enforced
        # sentences for the caller to append.
        if cfg.verbose:
            print("case 3")
        return enforced_templates, enforced_acts