def __init__(self, args):
        """
        Initializes the internal structures of the SlotFillingDST. Loads the
        DataBase and Ontology, retrieves the DataBase table name, and creates
        the dialogue State.

        :param args: dictionary that must contain the keys 'ontology' (an
                     Ontology instance or a path string), 'database' (a
                     DataBase instance or a path string ending in '.db' or
                     '.json') and 'domain'
        :raises AttributeError: if one of the required keys is missing
        :raises ValueError: if the ontology or database value has an
                            unsupported type or file extension
        """

        super(SlotFillingDST, self).__init__()

        # Checked in this order so the first missing key is the one reported
        for required in ('ontology', 'database', 'domain'):
            if required not in args:
                raise AttributeError(
                    'SlotFillingDST: Please provide %s!' % required)

        ontology = args['ontology']
        database = args['database']
        domain = args['domain']

        # Accept either a ready-made Ontology or a path to load one from.
        # The offending value is wrapped in a 1-tuple so that a tuple
        # argument is reported as ValueError and does not break the
        # %-formatting itself.
        self.ontology = None
        if isinstance(ontology, Ontology):
            self.ontology = ontology
        elif isinstance(ontology, str):
            self.ontology = Ontology(ontology)
        else:
            raise ValueError(
                'Unacceptable ontology type %s ' % (ontology,))

        # Accept either a ready-made DataBase or a path; the file extension
        # selects the database implementation.
        self.database = None
        if isinstance(database, DataBase):
            self.database = database
        elif isinstance(database, str) and database.endswith('.db'):
            self.database = SQLDataBase(database)
        elif isinstance(database, str) and database.endswith('.json'):
            self.database = JSONDataBase(database)
        else:
            raise ValueError(
                'Unacceptable database type %s ' % (database,))

        # Get Table name
        self.db_table_name = self.database.get_table_name()

        # NOTE(review): the original comment on this line read "This will
        # raise an error!" - presumably wherever this placeholder value is
        # consumed; confirm against the users of DB_ITEMS.
        self.DB_ITEMS = 0

        self.domain = domain
        if domain not in ['CamRest', 'SlotFilling']:
            print('Warning! domain has not been defined. Using Slot-Filling '
                  'dialogue State')

        # Known and unknown domains alike use the slot-filling dialogue
        # state, built from the ontology's system requestable slots.
        self.DState = \
            SlotFillingDialogueState(
                {'slots': self.ontology.ontology['system_requestable']})
示例#2
0
    def __init__(self, args):
        """
        Sets up the Goal Generator: validates the arguments, resolves the
        ontology and the database, optionally loads goals from a file, and
        reads the table name, the column names and the row count from the
        database.

        :param args: the goal generator's arguments; 'ontology' and
                     'database' are required, 'goals_file' is optional
        """

        for key, noun in (('ontology', 'an ontology'),
                          ('database', 'a database')):
            if key not in args:
                raise ValueError('Goal Generator called without %s!' % noun)

        ontology_arg = args['ontology']
        database_arg = args['database']

        # An Ontology instance is used as is, a string is treated as a path
        self.ontology = None
        if isinstance(ontology_arg, Ontology):
            self.ontology = ontology_arg
        elif isinstance(ontology_arg, str):
            self.ontology = Ontology(ontology_arg)
        else:
            raise ValueError('Unacceptable ontology type %s ' % ontology_arg)

        # A DataBase instance is used as is; for a string the file extension
        # picks the database implementation
        self.database = None
        if isinstance(database_arg, DataBase):
            self.database = database_arg
        elif isinstance(database_arg, str) and database_arg.endswith('.db'):
            self.database = SQLDataBase(database_arg)
        elif isinstance(database_arg, str) and \
                database_arg.endswith('.json'):
            self.database = JSONDataBase(database_arg)
        else:
            raise ValueError('Unacceptable database type %s ' % database_arg)

        self.goals_file = args['goals_file'] if 'goals_file' in args else None
        self.goals = None

        if self.goals_file:
            self.load_goals(self.goals_file)

        # NOTE(review): everything below goes through SQL_connection, while a
        # '.json' path yields a JSONDataBase above - confirm that class
        # exposes SQL_connection too.
        cursor = self.database.SQL_connection.cursor()

        # The table name is taken from the first table listed in
        # sqlite_master (column 1 of the row)
        tables = cursor.execute(
            "select * from sqlite_master where type = 'table';").fetchall()
        if not (tables and tables[0] and tables[0][1]):
            raise ValueError('Goal Generator cannot specify Table Name from '
                             'database {0}'.format(self.database.db_file_name))
        self.db_table_name = tables[0][1]

        # A throw-away query whose only purpose is to populate
        # cursor.description with the column names
        probe_query = "SELECT * FROM " + self.db_table_name + " LIMIT 1;"
        cursor.execute(probe_query)
        self.slot_names = [column[0] for column in cursor.description]

        count_query = "SELECT COUNT(*) FROM " + self.db_table_name + ";"
        self.db_row_count = cursor.execute(count_query).fetchall()[0][0]
示例#3
0
    def __init__(self, args):
        """
        Parses the arguments in the dictionary and initializes the appropriate
        models for dialogue State Tracking and dialogue Policy.

        :param args: the configuration file parsed into a dictionary. The keys
                     'settings', 'ontology', 'database' and 'domain' are
                     required; 'agent_id', 'agent_role', 'policy',
                     'calculate_slot_entropies' and 'DST' are optional.
        :raises AttributeError: if a required key is missing from args
        :raises ValueError: if the ontology or database cannot be resolved,
                            the domain is missing from the DIALOGUE settings,
                            the policy type is unsupported, or a required
                            policy / model path is not configured
        """

        super(DialogueManager, self).__init__()

        # Checked in this order so the first missing key is the one reported
        for required, label in (('settings', 'settings (config)'),
                                ('ontology', 'ontology'),
                                ('database', 'database'),
                                ('domain', 'domain')):
            if required not in args:
                raise AttributeError(
                    'DialogueManager: Please provide %s!' % label)

        ontology = args['ontology']
        database = args['database']
        # This value is only forwarded to the default state tracker;
        # self.domain is read from the DIALOGUE settings further below.
        domain = args['domain']

        self.settings = args['settings']

        self.TRAIN_DST = False
        self.TRAIN_POLICY = False

        self.MAX_DB_RESULTS = 10

        self.DSTracker = None
        self.policy = None
        self.policy_path = None
        self.ontology = None
        self.database = None
        self.domain = None

        self.agent_id = int(args['agent_id']) if 'agent_id' in args else 0
        self.agent_role = \
            args['agent_role'] if 'agent_role' in args else 'system'

        self.dialogue_counter = 0
        self.CALCULATE_SLOT_ENTROPIES = True

        # Accept either a ready-made Ontology or a path to load one from.
        # The offending value is wrapped in a 1-tuple so that a tuple
        # argument is reported as ValueError and does not break the
        # %-formatting itself.
        if isinstance(ontology, Ontology):
            self.ontology = ontology
        elif isinstance(ontology, str):
            self.ontology = Ontology(ontology)
        else:
            raise ValueError(
                'Unacceptable ontology type %s ' % (ontology,))

        # Accept either a ready-made DataBase or a path; the file extension
        # selects the database implementation.
        if isinstance(database, DataBase):
            self.database = database
        elif isinstance(database, str) and database.endswith('.db'):
            self.database = SQLDataBase(database)
        elif isinstance(database, str) and database.endswith('.json'):
            self.database = JSONDataBase(database)
        else:
            raise ValueError(
                'Unacceptable database type %s ' % (database,))

        # (config key, policy argument name) pairs for the hyper-parameters
        # that can be overridden from the policy section of the config
        hyperparameters = (
            ('learning_rate', 'alpha'),
            ('discount_factor', 'gamma'),
            ('exploration_rate', 'epsilon'),
            ('learning_decay_rate', 'alpha_decay'),
            ('exploration_decay_rate', 'epsilon_decay'))

        # Values used by 'minimax_q' and 'wolf_phc' when the config does not
        # override them; 'q_learning' and 'reinforce' pass None instead.
        fallback_hyperparameters = {
            'alpha': 0.25,
            'gamma': 0.95,
            'epsilon': 0.25,
            'alpha_decay': 0.9995,
            'epsilon_decay': 0.995}

        def common_policy_args():
            # Arguments shared by every policy except the handcrafted one
            return {
                'ontology': self.ontology,
                'database': self.database,
                'agent_id': self.agent_id,
                'agent_role': self.agent_role,
                'domain': self.domain}

        def learning_policy_args(policy_config, fallbacks):
            # Common arguments plus the hyper-parameters, taken from the
            # config (as floats) when present and from fallbacks otherwise
            policy_kwargs = common_policy_args()
            for config_key, name in hyperparameters:
                if config_key in policy_config:
                    policy_kwargs[name] = float(policy_config[config_key])
                else:
                    policy_kwargs[name] = fallbacks.get(name)
            return policy_kwargs

        # The policy section is optional: without it self.policy stays None
        policy_config = args.get('policy')
        if policy_config:
            if 'domain' in self.settings['DIALOGUE']:
                self.domain = self.settings['DIALOGUE']['domain']
            else:
                raise ValueError(
                    'domain is not specified in DIALOGUE at config.')

            if 'calculate_slot_entropies' in args:
                self.CALCULATE_SLOT_ENTROPIES = \
                    bool(args['calculate_slot_entropies'])

            policy_type = policy_config['type']

            if policy_type == 'handcrafted':
                self.policy = HandcraftedPolicy({'ontology': self.ontology})

            elif policy_type == 'q_learning':
                self.policy = \
                    QPolicy(learning_policy_args(policy_config, {}))

            elif policy_type == 'minimax_q':
                self.policy = \
                    MinimaxQPolicy(
                        learning_policy_args(
                            policy_config, fallback_hyperparameters))

            elif policy_type == 'wolf_phc':
                self.policy = \
                    WoLFPHCPolicy(
                        learning_policy_args(
                            policy_config, fallback_hyperparameters))

            elif policy_type == 'reinforce':
                self.policy = \
                    ReinforcePolicy(learning_policy_args(policy_config, {}))

            elif policy_type == 'calculated':
                self.policy = CalculatedPolicy(common_policy_args())

            elif policy_type == 'supervised':
                self.policy = SupervisedPolicy(common_policy_args())

            elif policy_type == 'ludwig':
                # No ludwig policy is instantiated here; the branch only
                # checks that a policy path is configured.
                if policy_config.get('policy_path'):
                    print('DialogueManager: Instantiate your ludwig-based '
                          'policy here')
                else:
                    raise ValueError(
                        'Cannot find policy_path in the config for dialogue '
                        'policy.')
            else:
                raise ValueError(
                    'DialogueManager: Unsupported policy type: {0}!'
                    .format(policy_type))

            if 'train' in policy_config:
                self.TRAIN_POLICY = bool(policy_config['train'])

            if 'policy_path' in policy_config:
                self.policy_path = policy_config['policy_path']

        # DST Settings
        dst_config = args.get('DST')
        if dst_config and dst_config.get('dst') == 'CamRest':
            dst_policy = dst_config.get('policy') or {}
            if dst_policy.get('model_path') and \
                    dst_policy.get('metadata_path'):
                self.DSTracker = \
                    CamRestDST({'model_path': dst_policy['model_path']})
            else:
                raise ValueError(
                    'Cannot find model_path or metadata_path in the '
                    'config for dialogue state tracker.')

        # Default to dummy DST
        if not self.DSTracker:
            self.DSTracker = SlotFillingDST({
                'ontology': self.ontology,
                'database': self.database,
                'domain': domain})

        self.training = self.TRAIN_DST or self.TRAIN_POLICY

        self.load('')
示例#4
0
class DialogueManager(ConversationalModule):
    def __init__(self, args):
        """
        Parses the arguments in the dictionary and initializes the appropriate
        models for dialogue State Tracking and dialogue Policy.

        :param args: the configuration file parsed into a dictionary. The keys
                     'settings', 'ontology', 'database' and 'domain' are
                     required; 'agent_id', 'agent_role', 'policy',
                     'calculate_slot_entropies' and 'DST' are optional.
        :raises AttributeError: if a required key is missing from args
        :raises ValueError: if the ontology or database cannot be resolved,
                            the domain is missing from the DIALOGUE settings,
                            the policy type is unsupported, or a required
                            policy / model path is not configured
        """

        super(DialogueManager, self).__init__()

        # Checked in this order so the first missing key is the one reported
        for required, label in (('settings', 'settings (config)'),
                                ('ontology', 'ontology'),
                                ('database', 'database'),
                                ('domain', 'domain')):
            if required not in args:
                raise AttributeError(
                    'DialogueManager: Please provide %s!' % label)

        ontology = args['ontology']
        database = args['database']
        # This value is only forwarded to the default state tracker;
        # self.domain is read from the DIALOGUE settings further below.
        domain = args['domain']

        self.settings = args['settings']

        self.TRAIN_DST = False
        self.TRAIN_POLICY = False

        self.MAX_DB_RESULTS = 10

        self.DSTracker = None
        self.policy = None
        self.policy_path = None
        self.ontology = None
        self.database = None
        self.domain = None

        self.agent_id = int(args['agent_id']) if 'agent_id' in args else 0
        self.agent_role = \
            args['agent_role'] if 'agent_role' in args else 'system'

        self.dialogue_counter = 0
        self.CALCULATE_SLOT_ENTROPIES = True

        # Accept either a ready-made Ontology or a path to load one from.
        # The offending value is wrapped in a 1-tuple so that a tuple
        # argument is reported as ValueError and does not break the
        # %-formatting itself.
        if isinstance(ontology, Ontology):
            self.ontology = ontology
        elif isinstance(ontology, str):
            self.ontology = Ontology(ontology)
        else:
            raise ValueError(
                'Unacceptable ontology type %s ' % (ontology,))

        # Accept either a ready-made DataBase or a path; the file extension
        # selects the database implementation.
        if isinstance(database, DataBase):
            self.database = database
        elif isinstance(database, str) and database.endswith('.db'):
            self.database = SQLDataBase(database)
        elif isinstance(database, str) and database.endswith('.json'):
            self.database = JSONDataBase(database)
        else:
            raise ValueError(
                'Unacceptable database type %s ' % (database,))

        # (config key, policy argument name) pairs for the hyper-parameters
        # that can be overridden from the policy section of the config
        hyperparameters = (
            ('learning_rate', 'alpha'),
            ('discount_factor', 'gamma'),
            ('exploration_rate', 'epsilon'),
            ('learning_decay_rate', 'alpha_decay'),
            ('exploration_decay_rate', 'epsilon_decay'))

        # Values used by 'minimax_q' and 'wolf_phc' when the config does not
        # override them; 'q_learning' and 'reinforce' pass None instead.
        fallback_hyperparameters = {
            'alpha': 0.25,
            'gamma': 0.95,
            'epsilon': 0.25,
            'alpha_decay': 0.9995,
            'epsilon_decay': 0.995}

        def common_policy_args():
            # Arguments shared by every policy except the handcrafted one
            return {
                'ontology': self.ontology,
                'database': self.database,
                'agent_id': self.agent_id,
                'agent_role': self.agent_role,
                'domain': self.domain}

        def learning_policy_args(policy_config, fallbacks):
            # Common arguments plus the hyper-parameters, taken from the
            # config (as floats) when present and from fallbacks otherwise
            policy_kwargs = common_policy_args()
            for config_key, name in hyperparameters:
                if config_key in policy_config:
                    policy_kwargs[name] = float(policy_config[config_key])
                else:
                    policy_kwargs[name] = fallbacks.get(name)
            return policy_kwargs

        # The policy section is optional: without it self.policy stays None
        policy_config = args.get('policy')
        if policy_config:
            if 'domain' in self.settings['DIALOGUE']:
                self.domain = self.settings['DIALOGUE']['domain']
            else:
                raise ValueError(
                    'domain is not specified in DIALOGUE at config.')

            if 'calculate_slot_entropies' in args:
                self.CALCULATE_SLOT_ENTROPIES = \
                    bool(args['calculate_slot_entropies'])

            policy_type = policy_config['type']

            if policy_type == 'handcrafted':
                self.policy = HandcraftedPolicy({'ontology': self.ontology})

            elif policy_type == 'q_learning':
                self.policy = \
                    QPolicy(learning_policy_args(policy_config, {}))

            elif policy_type == 'minimax_q':
                self.policy = \
                    MinimaxQPolicy(
                        learning_policy_args(
                            policy_config, fallback_hyperparameters))

            elif policy_type == 'wolf_phc':
                self.policy = \
                    WoLFPHCPolicy(
                        learning_policy_args(
                            policy_config, fallback_hyperparameters))

            elif policy_type == 'reinforce':
                self.policy = \
                    ReinforcePolicy(learning_policy_args(policy_config, {}))

            elif policy_type == 'calculated':
                self.policy = CalculatedPolicy(common_policy_args())

            elif policy_type == 'supervised':
                self.policy = SupervisedPolicy(common_policy_args())

            elif policy_type == 'ludwig':
                # No ludwig policy is instantiated here; the branch only
                # checks that a policy path is configured.
                if policy_config.get('policy_path'):
                    print('DialogueManager: Instantiate your ludwig-based '
                          'policy here')
                else:
                    raise ValueError(
                        'Cannot find policy_path in the config for dialogue '
                        'policy.')
            else:
                raise ValueError(
                    'DialogueManager: Unsupported policy type: {0}!'
                    .format(policy_type))

            if 'train' in policy_config:
                self.TRAIN_POLICY = bool(policy_config['train'])

            if 'policy_path' in policy_config:
                self.policy_path = policy_config['policy_path']

        # DST Settings
        dst_config = args.get('DST')
        if dst_config and dst_config.get('dst') == 'CamRest':
            dst_policy = dst_config.get('policy') or {}
            if dst_policy.get('model_path') and \
                    dst_policy.get('metadata_path'):
                self.DSTracker = \
                    CamRestDST({'model_path': dst_policy['model_path']})
            else:
                raise ValueError(
                    'Cannot find model_path or metadata_path in the '
                    'config for dialogue state tracker.')

        # Default to dummy DST
        if not self.DSTracker:
            self.DSTracker = SlotFillingDST({
                'ontology': self.ontology,
                'database': self.database,
                'domain': domain})

        self.training = self.TRAIN_DST or self.TRAIN_POLICY

        self.load('')

    def initialize(self, args):
        """
        Initializes the state tracker and the policy, and sets the dialogue
        counter back to zero.

        :param args: dictionary of arguments; when it contains the key
                     'goal', that value is forwarded to the policy
        :return: Nothing
        """

        self.DSTracker.initialize()

        policy_args = {
            'is_training': self.TRAIN_POLICY,
            'policy_path': self.policy_path,
            'ontology': self.ontology}

        # The goal is only passed along when the caller supplied one
        if 'goal' in args:
            policy_args['goal'] = args['goal']

        self.policy.initialize(policy_args)

        self.dialogue_counter = 0

    def receive_input(self, inpt):
        """
        Feeds the input to the state tracker and, for the slot-filling style
        domains, updates the state a second time: with the database lookup
        results when acting as 'system', or with the received input as system
        acts for any other role.

        :param inpt: the input handed to the state tracker
        :return: the same inpt object that was passed in
        """

        # Update dialogue state given the new input
        self.DSTracker.update_state(inpt)

        # Nothing more to do for an unset or unlisted domain
        if not self.domain or \
                self.domain not in ['CamRest', 'SFH', 'SlotFilling']:
            return inpt

        if self.agent_role != 'system':
            # Update the dialogue state again to include the system actions
            self.DSTracker.update_state_db(db_result=None, sys_acts=inpt)
            return inpt

        # Perform a database lookup, then update the dialogue state again to
        # include the database results
        lookup_result, slot_entropies = self.db_lookup()
        self.DSTracker.update_state_db(
            db_result=lookup_result,
            sys_req_slot_entropies=slot_entropies)

        return inpt

    def generate_output(self, args=None):
        """
        Consult the current policy to generate a response.

        :return: List of DialogueAct representing the system's output.
        """
        
        d_state = self.DSTracker.get_state()

        sys_acts = self.policy.next_action(d_state)
        # Copy the sys_acts to be able to iterate over all sys_acts while also
        # replacing some acts
        sys_acts_copy = deepcopy(sys_acts)
        new_sys_acts = []

        # Safeguards to support policies that make decisions on intents only
        # (i.e. do not output slots or values)
        for sys_act in sys_acts:
            if sys_act.intent == 'canthelp' and not sys_act.params:
                slots = \
                    [
                        s for s in d_state.slots_filled if
                        d_state.slots_filled[s]
                    ]
                if slots:
                    slot = random.choice(slots)

                    # Remove the empty canthelp
                    sys_acts_copy.remove(sys_act)

                    new_sys_acts.append(
                        DialogueAct(
                            'canthelp',
                            [DialogueActItem(
                                slot,
                                Operator.EQ,
                                d_state.slots_filled[slot])]))

                else:
                    print('DialogueManager Warning! No slot provided by '
                          'policy for canthelp and cannot find a reasonable '
                          'one!')

            if sys_act.intent == 'offer' and not sys_act.params:
                # Remove the empty offer
                sys_acts_copy.remove(sys_act)

                if d_state.item_in_focus:
                    new_sys_acts.append(
                        DialogueAct(
                            'offer',
                            [DialogueActItem(
                                'name',
                                Operator.EQ,
                                d_state.item_in_focus['name'])]))

                    # Only add these slots if no other acts were output
                    # by the DM
                    if len(sys_acts) == 1:
                        for slot in d_state.slots_filled:
                            if slot in d_state.item_in_focus:
                                if slot not in ['id', 'name'] and \
                                        slot != d_state.requested_slot:
                                    new_sys_acts.append(
                                        DialogueAct(
                                            'inform',
                                            [DialogueActItem(
                                                slot,
                                                Operator.EQ,
                                                d_state.item_in_focus[slot])]))
                            else:
                                new_sys_acts.append(
                                    DialogueAct(
                                        'inform',
                                        [DialogueActItem(
                                            slot,
                                            Operator.EQ,
                                            'no info')]))

            elif sys_act.intent == 'inform':
                if self.agent_role == 'system':
                    if sys_act.params and sys_act.params[0].value:
                        continue

                    if sys_act.params:
                        slot = sys_act.params[0].slot
                    else:
                        slot = d_state.requested_slot

                    if not slot:
                        slot = random.choice(list(d_state.slots_filled.keys()))

                    if d_state.item_in_focus:
                        if slot not in d_state.item_in_focus or \
                                not d_state.item_in_focus[slot]:
                            new_sys_acts.append(
                                DialogueAct(
                                    'inform',
                                    [DialogueActItem(
                                        slot,
                                        Operator.EQ,
                                        'no info')]))
                        else:
                            if slot == 'name':
                                new_sys_acts.append(
                                    DialogueAct(
                                        'offer',
                                        [DialogueActItem(
                                            slot,
                                            Operator.EQ,
                                            d_state.item_in_focus[slot])]))
                            else:
                                new_sys_acts.append(
                                    DialogueAct(
                                        'inform',
                                        [DialogueActItem(
                                            slot,
                                            Operator.EQ,
                                            d_state.item_in_focus[slot])]))

                    else:
                        new_sys_acts.append(
                            DialogueAct(
                                'inform',
                                [DialogueActItem(
                                    slot,
                                    Operator.EQ,
                                    'no info')]))

                elif self.agent_role == 'user':
                    if sys_act.params:
                        slot = sys_act.params[0].slot

                        # Do nothing if the slot is already filled
                        if sys_act.params[0].value:
                            continue

                    elif d_state.last_sys_acts and d_state.user_acts and \
                            d_state.user_acts[0].intent == 'request':
                        slot = d_state.user_acts[0].params[0].slot

                    else:
                        slot = \
                            random.choice(
                                list(d_state.user_goal.constraints.keys()))

                    # Populate the inform with a slot from the user goal
                    if d_state.user_goal:
                        # Look for the slot in the user goal
                        if slot in d_state.user_goal.constraints:
                            value = d_state.user_goal.constraints[slot].value
                        else:
                            value = 'dontcare'

                        new_sys_acts.append(
                            DialogueAct(
                                'inform',
                                [DialogueActItem(
                                    slot,
                                    Operator.EQ,
                                    value)]))

                # Remove the empty inform
                sys_acts_copy.remove(sys_act)

            elif sys_act.intent == 'request':
                # If the policy did not select a slot
                if not sys_act.params:
                    found = False

                    if self.agent_role == 'system':
                        # Select unfilled slot
                        for slot in d_state.slots_filled:
                            if not d_state.slots_filled[slot]:
                                found = True
                                new_sys_acts.append(
                                    DialogueAct(
                                        'request',
                                        [DialogueActItem(
                                            slot,
                                            Operator.EQ,
                                            '')]))
                                break

                    elif self.agent_role == 'user':
                        # Select request from goal
                        if d_state.user_goal:
                            for req in d_state.user_goal.requests:
                                if not d_state.user_goal.requests[req].value:
                                    found = True
                                    new_sys_acts.append(
                                        DialogueAct(
                                            'request',
                                            [DialogueActItem(
                                                req,
                                                Operator.EQ,
                                                '')]))
                                    break

                    if not found:
                        # All slots are filled
                        new_sys_acts.append(
                            DialogueAct(
                                'request',
                                [DialogueActItem(
                                    random.choice(
                                        list(
                                            d_state.slots_filled.keys())[:-1]),
                                    Operator.EQ, '')]))

                    # Remove the empty request
                    sys_acts_copy.remove(sys_act)

        # Append unique new sys acts
        for sa in new_sys_acts:
            if sa not in sys_acts_copy:
                sys_acts_copy.append(sa)

        self.DSTracker.update_state_sysact(sys_acts_copy)

        return sys_acts_copy

    def db_lookup(self):
        """
        Query the database with the current dialogue state and, when
        CALCULATE_SLOT_ENTROPIES is on, compute the entropy of every
        system-requestable slot over the matching items.

        :return: a tuple (results, entropies). results holds at most
                 MAX_DB_RESULTS matching items; entropies maps each
                 system-requestable slot to its entropy, or to None when the
                 entropy calculation is switched off. When nothing matches
                 the tuple is (['empty'], {}).
        """

        # TODO: Add check to assert if each slot in d_state.slots_filled
        # actually exists in the schema.

        current_state = self.DSTracker.get_state()

        # Query the database
        matches = self.database.db_lookup(current_state)

        if not matches:
            # Failed to retrieve anything
            return ['empty'], {}

        requestable = self.ontology.ontology['system_requestable']

        # One entry per requestable slot; the values stay None if the
        # entropy calculation is switched off
        entropies = dict.fromkeys(requestable)

        if self.CALCULATE_SLOT_ENTROPIES:
            n_matches = len(matches)
            distributions = {}

            for req_slot in requestable:
                # Count how often each value occurs among the matches
                counts = {}
                for item in matches:
                    item_value = item[req_slot]
                    counts[item_value] = counts.get(item_value, 0) + 1

                # Turn the counts into relative frequencies
                distributions[req_slot] = {
                    item_value: occurrences / n_matches
                    for item_value, occurrences in counts.items()
                }

            for slot in entropies:
                # Accumulate sum(p * log(p)) term by term and negate it
                accumulated = 0
                for probability in distributions.get(slot, {}).values():
                    accumulated += probability * math.log(probability)

                entropies[slot] = -accumulated

        return matches[:self.MAX_DB_RESULTS], entropies

    def restart(self, args):
        """
        Get ready for a new dialogue: re-initialize the dialogue state
        tracker, restart the policy, and count one more dialogue.

        :param args: arguments handed unchanged to the tracker's
                     initialize() and to the policy's restart()
        :return: Nothing
        """

        tracker = self.DSTracker
        tracker.initialize(args)

        policy = self.policy
        policy.restart(args)

        self.dialogue_counter = self.dialogue_counter + 1

    def update_goal(self, goal):
        """
        Hand a new goal to the dialogue State Tracker. If there is no
        tracker, a warning is printed and nothing else happens.

        :param goal: a Goal
        :return: nothing
        """

        tracker = self.DSTracker

        if not tracker:
            print('WARNING: dialogue Manager goal update failed: No dialogue '
                  'State Tracker!')
            return

        tracker.update_goal(goal)

    def get_state(self):
        """
        Fetch the dialogue state held by the dialogue State Tracker.

        :return: whatever the tracker's get_state() returns
        """

        tracker = self.DSTracker
        return tracker.get_state()

    def at_terminal_state(self):
        """
        Ask the tracked dialogue state whether it is terminal.

        :return: the result of is_terminal() on the current dialogue state
        """

        current_state = self.DSTracker.get_state()
        return current_state.is_terminal()

    def train(self, dialogues):
        """
        Train whichever components are flagged as trainable: the policy when
        TRAIN_POLICY is set and the dialogue state tracker when TRAIN_DST is
        set. The policy is trained first.

        :param dialogues: dialogue experience
        :return: nothing
        """

        if self.TRAIN_POLICY:
            policy = self.policy
            policy.train(dialogues)

        if self.TRAIN_DST:
            tracker = self.DSTracker
            tracker.train(dialogues)

    def is_training(self):
        """
        Tell whether the dialogue State Tracker or the policy is flagged as
        trainable.

        :return: TRAIN_DST if it is truthy, otherwise TRAIN_POLICY
        """

        dst_flag = self.TRAIN_DST
        if dst_flag:
            return dst_flag

        return self.TRAIN_POLICY

    def load(self, path):
        """
        Load models for the dialogue State Tracker and Policy.

        The tracker is asked to load from an empty path and the policy loads
        from self.policy_path. A component that has not been instantiated
        (i.e. is None) is skipped, mirroring the guards in save(), so that a
        manager without a policy or tracker does not fail with an
        AttributeError on None.

        :param path: path to the policy model (currently not used; see TODO)
        :return: nothing
        """

        # TODO: Handle path and loading properly
        if self.DSTracker is not None:
            self.DSTracker.load('')

        if self.policy is not None:
            self.policy.load(self.policy_path)

    def save(self):
        """
        Persist the models of the components that exist: the dialogue State
        Tracker saves itself and the policy saves to self.policy_path.

        :return: nothing
        """

        tracker = self.DSTracker
        if tracker:
            tracker.save()

        policy = self.policy
        if policy:
            policy.save(self.policy_path)
class SlotFillingDST(DialogueStateTracker):
    """
    Dialogue state tracker for slot-filling dialogues. It owns a
    SlotFillingDialogueState and updates it from incoming dialogue acts,
    database results, and the system's own acts.
    """

    def __init__(self, args):
        """
        Initializes the internal structures of the SlotFillingDST. Loads the
        DataBase and Ontology, retrieves the DataBase table name, and creates
        the dialogue State.

        :param args: a dictionary with the keys 'ontology' (an Ontology or a
                     string handed to the Ontology constructor), 'database'
                     (a DataBase or a string ending in '.db' or '.json'),
                     and 'domain'
        :raises AttributeError: if one of the required keys is missing
        :raises ValueError: if the ontology or database cannot be used
        """

        super(SlotFillingDST, self).__init__()

        for required in ('ontology', 'database', 'domain'):
            if required not in args:
                raise AttributeError(
                    'SlotFillingDST: Please provide %s!' % required)

        ontology = args['ontology']
        database = args['database']
        domain = args['domain']

        self.ontology = None
        if isinstance(ontology, Ontology):
            self.ontology = ontology
        elif isinstance(ontology, str):
            self.ontology = Ontology(ontology)
        else:
            raise ValueError('Unacceptable ontology type %s ' % ontology)

        self.database = None
        if isinstance(database, DataBase):
            self.database = database

        elif isinstance(database, str):
            # The file extension decides which DataBase class is used
            if database.endswith('.db'):
                self.database = SQLDataBase(database)
            elif database.endswith('.json'):
                self.database = JSONDataBase(database)
            else:
                raise ValueError('Unacceptable database type %s ' % database)

        else:
            raise ValueError('Unacceptable database type %s ' % database)

        # Get Table name
        self.db_table_name = self.database.get_table_name()

        # Placeholder until initialize() counts the database items.
        # update_state_db() divides by this value, so using it before
        # initialize() has run raises a ZeroDivisionError.
        self.DB_ITEMS = 0

        self.domain = domain
        if domain not in ['CamRest', 'SlotFilling']:
            print('Warning! domain has not been defined. Using Slot-Filling '
                  'dialogue State')

        # Known and unknown domains both use the same slot-filling state
        self.DState = \
            SlotFillingDialogueState(
                {'slots': self.ontology.ontology['system_requestable']})

    def initialize(self, args=None):
        """
        Initializes the database results and dialogue state.

        :param args: passed through to the dialogue state's initialize()
        :return: nothing
        :raises NotImplementedError: if the database is not an SQLDataBase
        """

        if not isinstance(self.database, SQLDataBase):
            raise NotImplementedError('JSONDataBase not supported yet.')

        # Count the items in the table; the count is the denominator of the
        # db_matches_ratio computed in update_state_db()
        cursor = self.database.SQL_connection.cursor()
        cursor.execute("SELECT * FROM " + self.db_table_name)
        self.DB_ITEMS = len(cursor.fetchall())

        if self.DB_ITEMS <= 0:
            print('Warning! DST could not get number of DB items.')
            self.DB_ITEMS = 110  # Default for CamRestaurants

        self.DState.initialize(args)
        # No constraints have been expressed yet
        self.DState.db_matches_ratio = 1.0

        self.DState.turn = 0

    def update_state(self, dacts):
        """
        Update the dialogue state given the input dialogue acts. This function
        basically tracks which intents, slots, and values have been mentioned
        and updates the dialogue state accordingly.

        :param dacts: a list of dialogue acts (usually the output of nlu)
        :return: the updated dialogue state
        """

        # TODO: These rules will create a field in the dialogue state slots
        # filled dictionary if one doesn't exist.
        self.DState.user_acts = deepcopy(dacts)

        # Reset past request
        self.DState.requested_slot = ''

        for dact in dacts:
            if dact.intent in ['inform', 'offer']:
                # The user provided new information so the system hasn't made
                # any offers taking that into account yet.
                # self.DState.system_made_offer = False

                if dact.intent == 'offer':
                    self.DState.system_made_offer = True

                for dact_item in dact.params:
                    if dact_item.slot in self.DState.slots_filled:
                        self.DState.slots_filled[dact_item.slot] = \
                            dact_item.value

                    elif self.DState.user_goal:
                        # The slot is not tracked in slots_filled, so record
                        # the value against the goal's requests instead
                        if dact_item.slot in \
                                self.DState.user_goal.actual_requests:
                            self.DState.user_goal.actual_requests[
                                dact_item.slot].value = dact_item.value

                        # Only update requests that have been asked for
                        if dact_item.slot in self.DState.user_goal.requests:
                            self.DState.user_goal.requests[
                                dact_item.slot].value = dact_item.value

            elif dact.intent == 'request':
                for dact_item in dact.params:
                    # TODO: THIS WILL ONLY SAVE THE LAST DACT ITEM! --
                    # THIS APPLIES TO THE FOLLOWING RULES AS WELL

                    if dact_item.slot == 'slot' and dact_item.value:
                        # Case where we have request(slot = slot_name)
                        self.DState.requested_slot = dact_item.value
                    else:
                        # Case where we have: request(slot_name)
                        self.DState.requested_slot = dact_item.slot

            elif dact.intent == 'bye':
                self.DState.is_terminal_state = True

        # Increment turn
        self.DState.turn += 1

        return self.DState

    def update_state_db(self,
                        db_result=None,
                        sys_req_slot_entropies=None,
                        sys_acts=None):
        """
        This is a special function that is mostly designed for the multi-agent
        setup. If the state belongs to a 'system' agent, then this function
        will update the current database results. If the state belongs to a
        'user' agent, then this function will update the 'item in focus' fields
        of the dialogue state, given the last system action.

        :param db_result: the database query results; the first entry becomes
                          the item in focus, and ['empty'] marks no results
        :param sys_req_slot_entropies: calculated entropies for requestable
                                       slots
        :param sys_acts: the system's acts
        :return: the updated dialogue state
        :raises ValueError: if both db_result and sys_acts are provided
        """

        if db_result and sys_acts:
            raise ValueError('dialogue State Tracker: Cannot update state as '
                             'both system and user (i.e. please use only one '
                             'argument as appropriate).')

        # This should be called if the agent is a system
        if db_result:
            # NOTE(review): an ['empty'] result has length 1, so it still
            # contributes 1 / DB_ITEMS to the ratio -- confirm this is wanted
            self.DState.db_matches_ratio = \
                float(len(db_result) / self.DB_ITEMS)

            if db_result[0] == 'empty':
                self.DState.item_in_focus = []

            else:
                self.DState.item_in_focus = db_result[0]

            if sys_req_slot_entropies:
                self.DState.system_requestable_slot_entropies = \
                    deepcopy(sys_req_slot_entropies)

            self.DState.db_result = db_result

        # This should be called if the agent is a user
        elif sys_acts:
            # Create dictionary if it doesn't exist or reset it if a new offer
            # has been made
            if not self.DState.item_in_focus or \
                    'offer' in [a.intent for a in sys_acts]:
                self.DState.item_in_focus = \
                    dict.fromkeys(self.ontology.ontology['requestable'])

            for sys_act in sys_acts:
                if sys_act.intent in ['inform', 'offer']:
                    for item in sys_act.params:
                        self.DState.item_in_focus[item.slot] = item.value

                        if self.DState.user_goal:
                            if item.slot in \
                                    self.DState.user_goal.actual_requests:
                                self.DState.user_goal.actual_requests[
                                    item.slot].value = item.value

                            # Only update requests that have been asked for
                            if item.slot in self.DState.user_goal.requests:
                                self.DState.user_goal.requests[
                                    item.slot].value = item.value

        return self.DState

    def update_state_sysact(self, sys_acts):
        """
        Updates the last system act and the goal, given that act. This is
        useful as we may want to update parts of the state given nlu output
        and then update again once the system produces a response.

        :param sys_acts: the last system acts
        :return: nothing
        """

        if sys_acts:
            self.DState.last_sys_acts = sys_acts

            for sys_act in sys_acts:
                if sys_act.intent == 'offer':
                    self.DState.system_made_offer = True

                # Keep track of actual requests made. These are used in reward
                # and success calculation for systems. The
                # reasoning is that it does not make sense to penalise a system
                # for an unanswered request that was
                # never actually made by the user.
                # If the current agent is a system then these will be
                # disregarded.
                # Only the first parameter of the act is recorded.
                if sys_act.intent == 'request' and sys_act.params and \
                        self.DState.user_goal:
                    self.DState.user_goal.actual_requests[
                        sys_act.params[0].slot] = sys_act.params[0]

                # Similarly, keep track of actual constraints made.
                if sys_act.intent == 'inform' and sys_act.params and \
                        self.DState.user_goal:
                    self.DState.user_goal.actual_constraints[
                        sys_act.params[0].slot] = sys_act.params[0]

                # Reset the request if the system asks for more information,
                # assuming that any previously offered item
                # is now invalid.
                # elif sys_act.intent == 'request':
                #     self.DState.system_made_offer = False

    def update_goal(self, goal):
        """
        Updates the agent's goal. The goal object is stored by reference in
        the dialogue state, not copied.

        :param goal: a Goal
        :return: nothing
        """

        # TODO: Do a deep copy?
        self.DState.user_goal = goal

    def train(self, data):
        """
        Nothing to do here.

        :param data: ignored
        :return: nothing
        """
        pass

    def get_state(self):
        """
        Returns the current dialogue state.

        :return: the current dialogue state
        """
        return self.DState

    def save(self, path=None):
        """
        Nothing to do here.

        :param path: ignored
        :return: nothing
        """
        pass

    def load(self, path):
        """
        Nothing to do here.

        :param path: ignored
        :return: nothing
        """
        pass
    def __init__(self, args):
        """
        Initializes the internal structures of the Agenda-Based usr Simulator

        The default patience and error probabilities are assigned first and
        then overridden by any values present in args, so that the
        ErrorModel built at the end receives the configured probabilities.

        :param args: a dictionary containing an ontology, a database, and
                     other necessary arguments. Optional keys read here:
                     'um', 'patience', 'pop_distribution',
                     'slot_confuse_prob', 'op_confuse_prob',
                     'value_confuse_prob', 'goal_slot_selection_weights',
                     'nlu', 'nlg', 'nlg_model_path', 'nlg_metadata_path',
                     'goals_path', 'policy_file', 'us_has_initiative'
        :raises AttributeError: if 'ontology' or 'database' is missing
        :raises ValueError: if the ontology or database cannot be used, or
                            if the CamRest nlg is requested without both a
                            model path and a metadata path
        """

        super(AgendaBasedUS, self).__init__()

        if 'ontology' not in args:
            raise AttributeError('AgendaBasedUS: Please provide ontology!')
        if 'database' not in args:
            raise AttributeError('AgendaBasedUS: Please provide database!')

        ontology = args['ontology']
        database = args['database']
        um = args.get('um')

        self.nlu = None
        self.nlg = None
        self.dialogue_turn = 0
        self.us_has_initiative = False
        self.policy = None
        self.goals_path = None

        if um is not None:
            self.user_model = um

        self.ontology = None
        if isinstance(ontology, Ontology):
            self.ontology = ontology
        elif isinstance(ontology, str):
            self.ontology = Ontology(ontology)
        else:
            raise ValueError('Unacceptable ontology type %s ' % ontology)

        self.database = None
        if isinstance(database, DataBase):
            self.database = database

        elif isinstance(database, str):
            # The file extension decides which DataBase class is used
            if database.endswith('.db'):
                self.database = SQLDataBase(database)
            elif database.endswith('.json'):
                self.database = JSONDataBase(database)
            else:
                raise ValueError('Unacceptable database type %s ' % database)
        else:
            raise ValueError('Unacceptable database type %s ' % database)

        # Default values for patience and probabilities. These have to be
        # assigned before the arguments are read; assigning them afterwards
        # would silently discard whatever the caller configured.
        self.patience = 3
        self.pop_distribution = [1.0]
        self.slot_confuse_prob = 0.0
        self.op_confuse_prob = 0.0
        self.value_confuse_prob = 0.0

        # Override the defaults with any values provided in the arguments
        for option in ('patience', 'pop_distribution', 'slot_confuse_prob',
                       'op_confuse_prob', 'value_confuse_prob'):
            if option in args:
                setattr(self, option, args[option])

        self.goal_slot_selection_weights = None
        if 'goal_slot_selection_weights' in args:
            self.goal_slot_selection_weights = \
                args['goal_slot_selection_weights']

        if 'nlu' in args:
            nlu_args = {'ontology': self.ontology, 'database': self.database}

            if args['nlu'] == 'CamRest':
                self.nlu = CamRestNLU(nlu_args)

            elif args['nlu'] == 'slot_filling':
                self.nlu = SlotFillingNLU(nlu_args)

        if 'nlg' in args:
            if args['nlg'] == 'CamRest':
                # .get() so that a missing path produces the ValueError
                # below instead of a KeyError
                if args.get('nlg_model_path') and \
                        args.get('nlg_metadata_path'):
                    self.nlg = \
                        CamRestNLG({'model_path': args['nlg_model_path']})
                else:
                    raise ValueError('ABUS: Cannot initialize CamRest nlg '
                                     'without a model path AND a metadata '
                                     'path.')

            elif args['nlg'] == 'slot_filling':
                self.nlg = SlotFillingNLG()

        if 'goals_path' in args:
            self.goals_path = args['goals_path']

        if 'policy_file' in args:
            self.load(args['policy_file'])

        if 'us_has_initiative' in args:
            self.us_has_initiative = args['us_has_initiative']

        self.curr_patience = self.patience

        self.agenda = agenda.Agenda()
        self.error_model = error_model.ErrorModel(self.ontology, self.database,
                                                  self.slot_confuse_prob,
                                                  self.op_confuse_prob,
                                                  self.value_confuse_prob)

        self.goal_generator = goal.GoalGenerator({
            'ontology': self.ontology,
            'database': self.database,
            'goals_file': self.goals_path
        })
        self.goal = None
        self.offer_made = False
        self.prev_offer_name = None

        # Store previous system actions to keep track of patience
        self.prev_system_acts = None
    def __init__(self, args):
        """
        Parses the arguments in the dictionary and initializes the appropriate
        models for dialogue State Tracking and dialogue Policy.

        :param args: the configuration file parsed into a dictionary. The
                     keys 'settings', 'ontology', 'database' and 'domain'
                     are checked explicitly; 'policy' is read
                     unconditionally; 'agent_id', 'agent_role',
                     'calculate_slot_entropies' and 'DST' are optional
        :raises AttributeError: if 'settings', 'ontology', 'database' or
                                'domain' is missing
        :raises KeyError: if 'policy' is missing
        :raises ValueError: if the ontology or database cannot be used, if
                            the domain is missing from the DIALOGUE
                            settings, or if the policy or DST entry lacks
                            'package' or 'class'
        """

        super(DialogueManagerGeneric, self).__init__()

        if 'settings' not in args:
            raise AttributeError(
                'DialogueManagerGeneric: Please provide settings (config)!')
        if 'ontology' not in args:
            raise AttributeError(
                'DialogueManagerGeneric: Please provide ontology!')
        if 'database' not in args:
            raise AttributeError(
                'DialogueManagerGeneric: Please provide database!')
        if 'domain' not in args:
            raise AttributeError(
                'DialogueManagerGeneric: Please provide domain!')

        settings = args['settings']
        ontology = args['ontology']
        database = args['database']
        domain = args['domain']

        agent_id = 0
        if 'agent_id' in args:
            agent_id = int(args['agent_id'])

        agent_role = 'system'
        if 'agent_role' in args:
            agent_role = args['agent_role']

        self.settings = settings

        self.TRAIN_DST = False
        self.TRAIN_POLICY = False

        self.MAX_DB_RESULTS = 10

        self.DSTracker = None
        self.DSTracker_info = {}

        self.policy = None
        self.policy_info = {}

        self.policy_path = None
        self.ontology = None
        self.database = None
        self.domain = None

        self.agent_id = agent_id
        self.agent_role = agent_role

        self.dialogue_counter = 0
        self.CALCULATE_SLOT_ENTROPIES = True

        if isinstance(ontology, Ontology):
            self.ontology = ontology
        elif isinstance(ontology, str):
            self.ontology = Ontology(ontology)
        else:
            raise ValueError('Unacceptable ontology type %s ' % ontology)

        if isinstance(database, DataBase):
            self.database = database

        elif isinstance(database, str):
            # The file extension decides which DataBase class is used
            if database.endswith('.db'):
                self.database = SQLDataBase(database)
            elif database.endswith('.json'):
                self.database = JSONDataBase(database)
            else:
                raise ValueError('Unacceptable database type %s ' % database)

        else:
            raise ValueError('Unacceptable database type %s ' % database)

        # Policy settings (args is known to be non-empty at this point
        # because the required keys were found above)
        if args['policy']:
            if 'domain' in self.settings['DIALOGUE']:
                self.domain = self.settings['DIALOGUE']['domain']
            else:
                raise ValueError(
                    'domain is not specified in DIALOGUE at config.')

            if 'calculate_slot_entropies' in args:
                self.CALCULATE_SLOT_ENTROPIES = \
                    bool(args['calculate_slot_entropies'])

            if 'package' in args['policy'] and 'class' in args['policy']:
                self.policy_info = args['policy']

                # Make sure the arguments dictionary exists before it is
                # read below, even when no global arguments are configured
                policy_arguments = self.policy_info.setdefault('arguments', {})

                if 'global_arguments' in args['settings']['GENERAL']:
                    policy_arguments.update(
                        args['settings']['GENERAL']['global_arguments']
                    )

                if 'train' in policy_arguments:
                    self.TRAIN_POLICY = bool(policy_arguments['train'])

                if 'policy_path' in policy_arguments:
                    self.policy_path = policy_arguments['policy_path']

                policy_arguments['agent_role'] = self.agent_role

                # Replace ontology and database strings with the actual
                # objects to avoid repetitions (these won't change).
                if 'ontology' in policy_arguments:
                    policy_arguments['ontology'] = self.ontology

                if 'database' in policy_arguments:
                    policy_arguments['database'] = self.database

                self.policy = ConversationalGenericAgent.load_module(
                    self.policy_info['package'],
                    self.policy_info['class'],
                    self.policy_info['arguments']
                )

            else:
                raise ValueError('DialogueManagerGeneric: Cannot instantiate '
                                 'dialogue policy!')

        # DST Settings
        if 'DST' in args:
            if 'package' in args['DST'] and 'class' in args['DST']:
                self.DSTracker_info['package'] = args['DST']['package']
                self.DSTracker_info['class'] = args['DST']['class']

                # Build a fresh dictionary: the global arguments are copied
                # into it so that merging the DST-specific arguments does
                # not modify the shared settings
                dst_arguments = {}

                if 'global_arguments' in args['settings']['GENERAL']:
                    dst_arguments.update(
                        args['settings']['GENERAL']['global_arguments'])

                if 'arguments' in args['DST']:
                    dst_arguments.update(args['DST']['arguments'])

                self.DSTracker_info['args'] = dst_arguments

                self.DSTracker = ConversationalGenericAgent.load_module(
                    self.DSTracker_info['package'],
                    self.DSTracker_info['class'],
                    self.DSTracker_info['args']
                )

            else:
                raise ValueError('DialogueManagerGeneric: Cannot instantiate '
                                 'dialogue state tracker!')

        # Default to dummy DST, if no information is provided
        else:
            dst_args = {
                'ontology': self.ontology,
                'database': self.database,
                'domain': domain
            }
            self.DSTracker = SlotFillingDST(dst_args)

        self.training = self.TRAIN_DST or self.TRAIN_POLICY

        self.load('')
示例#8
0
    def __init__(self, args):
        """
        Load the ontology and database, create the keyword patterns, and
        preprocess the database so that we avoid some computations at
        runtime.

        :param args: dictionary that must contain:
                     'ontology': an Ontology object, or a string that is
                                 passed to the Ontology constructor
                     'database': a DataBase object, or a string ending in
                                 '.db' (SQLDataBase) or '.json' (JSONDataBase)
        :raises AttributeError: if 'ontology' or 'database' is missing from
                                args, or if the 'database' entry is empty
        :raises ValueError: if the ontology or database has an unsupported
                            type, or if no table name can be read from the
                            database
        """
        super(SlotFillingNLU, self).__init__()

        self.ontology = None
        self.database = None
        self.requestable_only_slots = None
        self.slot_values = None

        if 'ontology' not in args:
            raise AttributeError('SlotFillingNLU: Please provide ontology!')
        if 'database' not in args:
            raise AttributeError('SlotFillingNLU: Please provide database!')

        ontology = args['ontology']
        database = args['database']

        if isinstance(ontology, Ontology):
            self.ontology = ontology
        elif isinstance(ontology, str):
            self.ontology = Ontology(ontology)
        else:
            raise ValueError('Unacceptable ontology type %s ' % ontology)

        # An empty database entry (None, '', ...) would leave self.database
        # unset and fail below with an opaque "'NoneType' object has no
        # attribute 'SQL_connection'". Fail early instead, keeping the same
        # exception type but with a meaningful message.
        if not database:
            raise AttributeError('SlotFillingNLU: Please provide database!')

        if isinstance(database, DataBase):
            self.database = database
        elif isinstance(database, str):
            if database.endswith('.db'):
                self.database = SQLDataBase(database)
            elif database.endswith('.json'):
                self.database = JSONDataBase(database)
            else:
                raise ValueError('Unacceptable database type %s '
                                 % database)
        else:
            raise ValueError('Unacceptable database type %s ' % database)

        # In order to work for simulated users, we need access to possible
        # values of requestable slots
        cursor = self.database.SQL_connection.cursor()

        print('SlotFillingNLU: Preprocessing Database... '
              '(do not use SlotFillingNLU with large databases!)')

        # Get table name (the first table listed in sqlite_master is used)
        db_result = cursor.execute("select * from sqlite_master "
                                   "where type = 'table';").fetchall()
        if not (db_result and db_result[0] and db_result[0][1]):
            raise ValueError(
                'dialogue Manager cannot specify Table Name from database '
                '{0}'.format(self.database.db_file_name))

        db_table_name = db_result[0][1]

        # Get all entries in the database
        all_items = cursor.execute("select * from " +
                                   db_table_name + ";").fetchall()

        # The column names are identical for every row, so read them once
        # instead of once per item.
        slot_names = [column[0] for column in cursor.description]

        # Columns that do not carry slot values
        ignored_slots = ('id', 'signature', 'description')

        self.slot_values = {}
        total_items = len(all_items)

        # Collect the distinct values of every slot, in order of first
        # appearance.
        for count, item in enumerate(all_items, start=1):
            for slot, value in dict(zip(slot_names, item)).items():
                if slot in ignored_slots:
                    continue

                known_values = self.slot_values.setdefault(slot, [])
                if value not in known_values:
                    known_values.append(value)

            # Report progress every 2000 items
            if count % 2000 == 0:
                print(f'{float(count/total_items)*100}% done')

        print('SlotFillingNLU: Done!')

        # For this SlotFillingNLU create a list of requestable-only to reduce
        # computational load
        informable = self.ontology.ontology['informable']
        self.requestable_only_slots = \
            [slot for slot in self.ontology.ontology['requestable']
             if slot not in informable] + ['name']

        # Keyword / phrase lists, one per dialogue act
        self.bye_pattern = ['bye', 'goodbye', 'exit', 'quit', 'stop']
        self.hi_pattern = ['hi', 'hello']
        self.welcome_pattern = ['welcome', 'how may i help']
        self.deny_pattern = ['no']
        self.negate_pattern = ['is not']
        self.confirm_pattern = ['so is']
        self.repeat_pattern = ['repeat']
        self.ack_pattern = ['ok']
        self.restart_pattern = ['start over']
        self.affirm_pattern = ['yes']
        self.thankyou_pattern = ['thank you']
        self.reqmore_pattern = ['tell me more']
        self.expl_conf_pattern = ['alright']
        self.reqalts_pattern = ['anything else']
        self.select_pattern = ['you prefer']

        self.dontcare_pattern = ['anything', 'any', 'i do not care',
                                 'i dont care', 'dont care', 'dontcare',
                                 'it does not matter', 'it doesnt matter',
                                 'does not matter', 'doesnt matter']

        self.request_pattern = ['what', 'which', 'where', 'how', 'would']

        self.cant_help_pattern = ['can not help', 'cannot help', 'cant help']

        # Translation table that deletes every ASCII punctuation character
        # except '$', '_', '&' and '-' (note that '.' IS deleted).
        preserved = '$_&-'
        punctuation = ''.join(
            char for char in string.punctuation if char not in preserved)
        self.punctuation_remover = str.maketrans('', '', punctuation)
示例#9
0
    def __init__(self, args):
        """
        Set up the DTL user Simulator: hand the policy file to self.load,
        resolve the ontology and database, create the goal generator, and
        reset the bookkeeping attributes used during a dialogue.

        :param args: dictionary containing ontology, database, and policy file
                     (and, optionally, 'patience', which defaults to 3)
        """
        super(DTLUserSimulator, self).__init__()

        # Every one of these entries is mandatory; they are checked in order
        required = (
            ('ontology', 'DTLUserSimulator: Please provide ontology!'),
            ('database', 'DTLUserSimulator: Please provide database!'),
            ('policy_file', 'DTLUserSimulator: Please provide policy file!'),
        )
        for key, message in required:
            if key not in args:
                raise AttributeError(message)

        ontology_arg = args['ontology']
        database_arg = args['database']
        policy_file_arg = args['policy_file']

        # The policy file is handed to self.load before anything else is
        # resolved
        self.policy = None
        self.load(policy_file_arg)

        # Ontology: accept a ready-made object or a string for the
        # Ontology constructor
        self.ontology = None
        if isinstance(ontology_arg, Ontology):
            self.ontology = ontology_arg
        elif isinstance(ontology_arg, str):
            self.ontology = Ontology(ontology_arg)
        else:
            raise ValueError('Unacceptable ontology type %s ' % ontology_arg)

        # Database: accept a ready-made object, or a string whose suffix
        # selects the backend
        self.database = None
        if isinstance(database_arg, DataBase):
            self.database = database_arg
        elif not isinstance(database_arg, str):
            raise ValueError('Unacceptable database type %s ' % database_arg)
        elif database_arg.endswith('.db'):
            self.database = SQLDataBase(database_arg)
        elif database_arg.endswith('.json'):
            self.database = JSONDataBase(database_arg)
        else:
            raise ValueError('Unacceptable database type %s ' % database_arg)

        self.input_system_acts = None
        self.goal = None

        generator_args = {
            'ontology': self.ontology,
            'database': self.database
        }
        self.goal_generator = GoalGenerator(generator_args)

        # Use the caller-supplied patience if there is one, otherwise 3
        self.patience = args['patience'] if 'patience' in args else 3
        self.curr_patience = self.patience

        self.prev_sys_acts = None
        self.goal_met = False
        self.offer_made = False