Ejemplo n.º 1
0
    def __init__(self, args):
        """
        Set up the random policy: resolve the ontology and store the intents.

        :param args: dictionary holding the domain ontology under the
                     'ontology' key, either as an Ontology object or as a
                     string that is handed to the Ontology constructor
        """
        super(RandomPolicy, self).__init__()

        # The ontology is mandatory.
        if 'ontology' not in args:
            raise ValueError('No ontology provided for RandomPolicy!')

        given = args['ontology']

        self.ontology = None
        if not isinstance(given, (Ontology, str)):
            raise ValueError('Unacceptable ontology type %s ' % given)

        # A ready-made Ontology is used as is; a string is handed to the
        # Ontology constructor.
        self.ontology = \
            given if isinstance(given, Ontology) else Ontology(given)

        # Intent names kept on the instance
        self.intents = [
            'welcomemsg',
            'inform',
            'request',
            'hello',
            'bye',
            'repeat',
            'offer',
        ]
    def __init__(self, args):
        """
        Initializes the internal structures of the SlotFillingDST. Resolves
        the Ontology and DataBase, retrieves the DataBase table name, and
        creates the dialogue State.

        :param args: dictionary with the following mandatory entries:
                     'ontology' - an Ontology object or a string handed to
                                  the Ontology constructor
                     'database' - a DataBase object or a string ending in
                                  '.db' (SQLDataBase) or '.json'
                                  (JSONDataBase)
                     'domain'   - the domain name
        :raises AttributeError: if a mandatory entry is missing
        :raises ValueError: if the ontology or the database has an
                            unsupported type or file extension
        """

        super(SlotFillingDST, self).__init__()

        for required in ('ontology', 'database', 'domain'):
            if required not in args:
                raise AttributeError(
                    'SlotFillingDST: Please provide %s!' % required)

        ontology = args['ontology']
        database = args['database']
        domain = args['domain']

        self.ontology = None
        if isinstance(ontology, Ontology):
            self.ontology = ontology
        elif isinstance(ontology, str):
            self.ontology = Ontology(ontology)
        else:
            raise ValueError('Unacceptable ontology type %s ' % ontology)

        # A string database is dispatched on its file extension; anything
        # else that is not a DataBase object is rejected.
        self.database = None
        if isinstance(database, DataBase):
            self.database = database
        elif isinstance(database, str) and database.endswith('.db'):
            self.database = SQLDataBase(database)
        elif isinstance(database, str) and database.endswith('.json'):
            self.database = JSONDataBase(database)
        else:
            raise ValueError('Unacceptable database type %s ' % database)

        # Get Table name
        self.db_table_name = self.database.get_table_name()

        # NOTE(review): this line used to carry the comment "This will raise
        # an error!"; the assignment itself raises nothing - confirm what
        # the comment was warning about.
        self.DB_ITEMS = 0

        self.domain = domain

        # Every domain gets the same slot-filling state; unknown domains
        # only trigger a warning first.
        if domain not in ['CamRest', 'SlotFilling']:
            print('Warning! domain has not been defined. Using Slot-Filling '
                  'dialogue State')

        self.DState = \
            SlotFillingDialogueState(
                {'slots': self.ontology.ontology['system_requestable']})
Ejemplo n.º 3
0
    def __init__(self, args):
        """
        Resolve the ontology and prepare the static IOB tag list, the
        punctuation remover and the "don't care" patterns.

        :param args: dictionary holding the 'ontology' (an Ontology object or
                     a string handed to the Ontology constructor) and,
                     optionally, 'train_online'
        """
        super(CamRestNLU, self).__init__(args)

        self.ontology = None
        self.database = None

        # The ontology is mandatory.
        if 'ontology' not in args:
            raise AttributeError('camrest_nlu: Please provide an ontology!')

        given = args['ontology']
        if not isinstance(given, (Ontology, str)):
            raise ValueError('Unacceptable ontology type %s ' % given)

        self.ontology = \
            given if isinstance(given, Ontology) else Ontology(given)

        self.iob_tag_list = []
        self.dontcare_pattern = []

        # Translation table that deletes every punctuation character
        self.punctuation_remover = str.maketrans('', '', string.punctuation)

        self.TRAIN_ONLINE = \
            bool(args['train_online']) if 'train_online' in args else False

        # One "begin" and one "inside" inform tag per requestable slot, all
        # the B- tags first.
        requestable = self.ontology.ontology['requestable']
        begin_tags = ['B-inform-' + slot for slot in requestable]
        inside_tags = ['I-inform-' + slot for slot in requestable]
        self.iob_tag_list = begin_tags + inside_tags

        self.dontcare_pattern = [
            'anything',
            'any',
            'i do not care',
            'i dont care',
            'dont care',
            'dontcare',
            'it does not matter',
            'it doesnt matter',
            'does not matter',
            'doesnt matter',
        ]
Ejemplo n.º 4
0
    def __init__(self, args):
        """
        Set up the handcrafted policy by resolving the ontology.

        :param args: dictionary holding the domain ontology under the
                     'ontology' key, either as an Ontology object or as a
                     string that is handed to the Ontology constructor
        """
        super(HandcraftedPolicy, self).__init__()

        # The ontology is mandatory.
        if 'ontology' not in args:
            raise ValueError('No ontology provided for HandcraftedPolicy!')

        given = args['ontology']

        self.ontology = None
        if not isinstance(given, (Ontology, str)):
            raise ValueError('Unacceptable ontology type %s ' % given)

        # A ready-made Ontology is used as is; a string is handed to the
        # Ontology constructor.
        self.ontology = \
            given if isinstance(given, Ontology) else Ontology(given)
Ejemplo n.º 5
0
    def __init__(self, args):
        """
        Initializes the internal structures of the Goal Generator and does
        some checks.

        :param args: the goal generator's arguments:
                     'ontology'   - mandatory; an Ontology object or a string
                                    handed to the Ontology constructor
                     'database'   - mandatory; a DataBase object or a string
                                    ending in '.db' (SQLDataBase) or '.json'
                                    (JSONDataBase)
                     'goals_file' - optional; when truthy it is passed to
                                    load_goals
        :raises ValueError: if a mandatory argument is missing, if the
                            ontology or database is of an unsupported type,
                            or if no table name can be read from the database
        """

        if 'ontology' not in args:
            raise ValueError('Goal Generator called without an ontology!')

        if 'database' not in args:
            raise ValueError('Goal Generator called without a database!')

        ontology = args['ontology']
        database = args['database']

        self.ontology = None
        if isinstance(ontology, Ontology):
            self.ontology = ontology
        elif isinstance(ontology, str):
            self.ontology = Ontology(ontology)
        else:
            raise ValueError('Unacceptable ontology type %s ' % ontology)

        # A string database is dispatched on its file extension.
        self.database = None
        if isinstance(database, DataBase):
            self.database = database
        elif isinstance(database, str) and database.endswith('.db'):
            self.database = SQLDataBase(database)
        elif isinstance(database, str) and database.endswith('.json'):
            self.database = JSONDataBase(database)
        else:
            raise ValueError('Unacceptable database type %s ' % database)

        self.goals_file = None
        if 'goals_file' in args:
            self.goals_file = args['goals_file']

        self.goals = None

        if self.goals_file:
            self.load_goals(self.goals_file)

        # Get the slot names from the database
        # NOTE(review): this reads SQL_connection whatever the database type
        # is - confirm that JSONDataBase objects provide it as well.
        cursor = self.database.SQL_connection.cursor()
        try:
            # Get Table name (the first table listed in sqlite_master)
            tables = \
                cursor.execute(
                    "select * from sqlite_master where type = 'table';"
                ).fetchall()
            if tables and tables[0] and tables[0][1]:
                self.db_table_name = tables[0][1]
            else:
                raise ValueError(
                    'Goal Generator cannot specify Table Name from '
                    'database {0}'.format(self.database.db_file_name))

            # Quote the table name as an SQL identifier so that names
            # containing spaces, keywords or quotes still form valid SQL.
            quoted_table = \
                '"' + self.db_table_name.replace('"', '""') + '"'

            # Dummy SQL command, only used to read the column names
            cursor.execute('SELECT * FROM ' + quoted_table + ' LIMIT 1;')
            self.slot_names = [column[0] for column in cursor.description]

            self.db_row_count = \
                cursor.execute(
                    'SELECT COUNT(*) FROM ' + quoted_table + ';'
                ).fetchall()[0][0]
        finally:
            # Release the cursor even when one of the queries fails.
            cursor.close()
Ejemplo n.º 6
0
    def __init__(self, args):
        """
        Parses the arguments in the dictionary and initializes the appropriate
        models for dialogue State Tracking and dialogue Policy.

        :param args: the configuration file parsed into a dictionary.
                     Mandatory entries: 'settings', 'ontology', 'database'
                     and 'domain'. Optional entries: 'agent_id',
                     'agent_role', 'policy', 'calculate_slot_entropies' and
                     'DST'.
        :raises AttributeError: if a mandatory entry is missing
        :raises ValueError: if the ontology or database is of an unsupported
                            type, if the domain is missing from the DIALOGUE
                            settings, if the policy type is not supported, or
                            if a required model path is missing
        """

        super(DialogueManager, self).__init__()

        if 'settings' not in args:
            raise AttributeError(
                'DialogueManager: Please provide settings (config)!')
        if 'ontology' not in args:
            raise AttributeError('DialogueManager: Please provide ontology!')
        if 'database' not in args:
            raise AttributeError('DialogueManager: Please provide database!')
        if 'domain' not in args:
            raise AttributeError('DialogueManager: Please provide domain!')

        settings = args['settings']
        ontology = args['ontology']
        database = args['database']
        domain = args['domain']

        agent_id = 0
        if 'agent_id' in args:
            agent_id = int(args['agent_id'])

        agent_role = 'system'
        if 'agent_role' in args:
            agent_role = args['agent_role']

        self.settings = settings

        self.TRAIN_DST = False
        self.TRAIN_POLICY = False

        self.MAX_DB_RESULTS = 10

        self.DSTracker = None
        self.policy = None
        self.policy_path = None
        self.ontology = None
        self.database = None
        self.domain = None

        self.agent_id = agent_id
        self.agent_role = agent_role

        self.dialogue_counter = 0
        self.CALCULATE_SLOT_ENTROPIES = True

        if isinstance(ontology, Ontology):
            self.ontology = ontology
        elif isinstance(ontology, str):
            self.ontology = Ontology(ontology)
        else:
            raise ValueError('Unacceptable ontology type %s ' % ontology)

        # A string database is dispatched on its file extension.
        if isinstance(database, DataBase):
            self.database = database
        elif isinstance(database, str) and database.endswith('.db'):
            self.database = SQLDataBase(database)
        elif isinstance(database, str) and database.endswith('.json'):
            self.database = JSONDataBase(database)
        else:
            raise ValueError('Unacceptable database type %s ' % database)

        # A missing or empty 'policy' entry leaves self.policy as None
        # instead of failing with a KeyError.
        policy_cfg = args.get('policy')

        if policy_cfg:
            if 'domain' in self.settings['DIALOGUE']:
                self.domain = self.settings['DIALOGUE']['domain']
            else:
                raise ValueError(
                    'domain is not specified in DIALOGUE at config.')

            if 'calculate_slot_entropies' in args:
                self.CALCULATE_SLOT_ENTROPIES = \
                    bool(args['calculate_slot_entropies'])

            def agent_arguments():
                """Build the arguments shared by all non-handcrafted
                policies."""
                return {
                    'ontology': self.ontology,
                    'database': self.database,
                    'agent_id': self.agent_id,
                    'agent_role': self.agent_role,
                    'domain': self.domain}

            def learner_arguments(defaults):
                """Add the learning hyper-parameters to the shared
                arguments; config values (as floats) override defaults."""
                rates = dict(defaults)
                for name, config_key in (
                        ('alpha', 'learning_rate'),
                        ('gamma', 'discount_factor'),
                        ('epsilon', 'exploration_rate'),
                        ('alpha_decay', 'learning_decay_rate'),
                        ('epsilon_decay', 'exploration_decay_rate')):
                    if config_key in policy_cfg:
                        rates[name] = float(policy_cfg[config_key])

                learner_args = agent_arguments()
                for name in ('alpha', 'epsilon', 'gamma', 'alpha_decay',
                             'epsilon_decay'):
                    learner_args[name] = rates[name]
                return learner_args

            # q_learning and reinforce pass None for anything that is not
            # configured; minimax_q and wolf_phc fall back to these values.
            unset_rates = dict.fromkeys(
                ('alpha', 'gamma', 'epsilon', 'alpha_decay',
                 'epsilon_decay'))
            preset_rates = {
                'alpha': 0.25,
                'gamma': 0.95,
                'epsilon': 0.25,
                'alpha_decay': 0.9995,
                'epsilon_decay': 0.995}

            policy_type = policy_cfg.get('type')

            if policy_type == 'handcrafted':
                self.policy = HandcraftedPolicy({'ontology': self.ontology})

            elif policy_type == 'q_learning':
                self.policy = QPolicy(learner_arguments(unset_rates))

            elif policy_type == 'minimax_q':
                self.policy = MinimaxQPolicy(learner_arguments(preset_rates))

            elif policy_type == 'wolf_phc':
                self.policy = WoLFPHCPolicy(learner_arguments(preset_rates))

            elif policy_type == 'reinforce':
                self.policy = ReinforcePolicy(learner_arguments(unset_rates))

            elif policy_type == 'calculated':
                self.policy = CalculatedPolicy(agent_arguments())

            elif policy_type == 'supervised':
                self.policy = SupervisedPolicy(agent_arguments())

            elif policy_type == 'ludwig':
                # Only a placeholder: no policy object is created here.
                if policy_cfg.get('policy_path'):
                    print('DialogueManager: Instantiate your ludwig-based '
                          'policy here')
                else:
                    raise ValueError(
                        'Cannot find policy_path in the config for dialogue '
                        'policy.')
            else:
                raise ValueError(
                    'DialogueManager: Unsupported policy type: {0}!'
                    .format(policy_type))

            if 'train' in policy_cfg:
                self.TRAIN_POLICY = bool(policy_cfg['train'])

            if 'policy_path' in policy_cfg:
                self.policy_path = policy_cfg['policy_path']

        # DST Settings
        dst_cfg = args.get('DST') or {}
        if dst_cfg.get('dst') == 'CamRest':
            dst_model_cfg = dst_cfg.get('policy') or {}
            if dst_model_cfg.get('model_path') and \
                    dst_model_cfg.get('metadata_path'):
                self.DSTracker = \
                    CamRestDST({'model_path': dst_model_cfg['model_path']})
            else:
                raise ValueError(
                    'Cannot find model_path or metadata_path in the '
                    'config for dialogue state tracker.')

        # Default to dummy DST
        if not self.DSTracker:
            # The tracker receives the 'domain' argument, whereas
            # self.domain is read from the DIALOGUE settings above.
            self.DSTracker = SlotFillingDST({
                'ontology': self.ontology,
                'database': self.database,
                'domain': domain})

        self.training = self.TRAIN_DST or self.TRAIN_POLICY

        self.load('')
Ejemplo n.º 7
0
    def __init__(self, args):
        """
        Set up the WoLF-PHC policy's parameters and internal structures.

        :param args: dictionary containing the dialogue_policy's settings.
                     'ontology' and 'database' are mandatory; 'agent_role',
                     'alpha', 'gamma', 'epsilon', 'alpha_decay' and
                     'epsilon_decay' are optional.
        """

        super(WoLFPHCPolicy, self).__init__()

        # Ontology: mandatory, given as an object or as a string
        self.ontology = None
        if 'ontology' not in args:
            raise ValueError('WoLFPHCPolicy: No ontology provided')

        given_ontology = args['ontology']
        if isinstance(given_ontology, Ontology):
            self.ontology = given_ontology
        elif isinstance(given_ontology, str):
            self.ontology = Ontology(given_ontology)
        else:
            raise ValueError('WoLFPHCPolicy Unacceptable '
                             'ontology type %s ' % given_ontology)

        # Database: mandatory, given as an object or as a string
        self.database = None
        if 'database' not in args:
            raise ValueError('WoLFPHCPolicy: No database provided')

        given_database = args['database']
        if isinstance(given_database, DataBase):
            self.database = given_database
        elif isinstance(given_database, str):
            self.database = DataBase(given_database)
        else:
            raise ValueError('WoLFPHCPolicy: Unacceptable '
                             'database type %s ' % given_database)

        # Optional settings, each with its fallback value
        self.agent_role = args.get('agent_role', 'system')

        self.alpha = args.get('alpha', 0.2)
        self.gamma = args.get('gamma', 0.95)
        self.epsilon = args.get('epsilon', 0.95)
        self.alpha_decay_rate = args.get('alpha_decay', 0.995)
        self.exploration_decay_rate = args.get('epsilon_decay', 0.9995)

        self.IS_GREEDY_POLICY = False

        # TODO: Put these as arguments in the config
        self.d_win = 0.0025
        self.d_lose = 0.01

        self.is_training = False

        self.Q = {}
        self.pi = {}
        self.mean_pi = {}
        self.state_counter = {}

        self.pp = pprint.PrettyPrinter(width=160)  # For debug!

        # System and user expert policies (optional)
        self.warmup_policy = None
        self.warmup_simulator = None

        if self.agent_role == 'system':
            # Put your system expert dialogue_policy here
            self.warmup_policy = \
                slot_filling_policy.HandcraftedPolicy(
                    {'ontology': self.ontology})

        elif self.agent_role == 'user':
            # Put your user expert dialogue_policy here
            self.warmup_simulator = AgendaBasedUS({
                'ontology': self.ontology,
                'database': self.database})

        # Sub-case for CamRest
        self.dstc2_acts_sys = self.dstc2_acts_usr = None

        # Plato does not use action masks (rules to define which
        # actions are valid from each state) and so training can
        # be harder. This becomes easier if we have a smaller
        # action set.

        # Neither list includes inform and request, which are modelled
        # together with their arguments.
        self.dstc2_acts_sys = [
            'offer', 'canthelp', 'affirm', 'deny', 'ack', 'bye', 'reqmore',
            'welcomemsg', 'expl-conf', 'select', 'repeat', 'confirm-domain',
            'confirm'
        ]

        self.dstc2_acts_usr = [
            'affirm', 'negate', 'deny', 'ack', 'thankyou', 'bye', 'reqmore',
            'hello', 'expl-conf', 'repeat', 'reqalts', 'restart', 'confirm'
        ]

        # Extract lists of slots that are frequently used
        domain_ontology = self.ontology.ontology
        self.informable_slots = \
            deepcopy(list(domain_ontology['informable'].keys()))
        self.requestable_slots = deepcopy(domain_ontology['requestable'])
        self.system_requestable_slots = \
            deepcopy(domain_ontology['system_requestable'])

        # Size of the system's and of the user's action set
        if self.dstc2_acts_sys:
            slot_action_count = \
                len(self.requestable_slots) + \
                len(self.system_requestable_slots)
            system_action_count = \
                len(self.dstc2_acts_sys) + slot_action_count
            user_action_count = len(self.dstc2_acts_usr) + slot_action_count
        else:
            requestable_count = len(domain_ontology['requestable'])
            system_action_count = \
                5 + len(domain_ontology['system_requestable']) + \
                requestable_count
            user_action_count = 4 + 2 * requestable_count

        # NActions is this agent's own action count, NOtherActions the
        # other side's; any other role leaves both attributes unset.
        if self.agent_role == 'system':
            self.NActions = system_action_count
            self.NOtherActions = user_action_count

        elif self.agent_role == 'user':
            self.NActions = user_action_count
            self.NOtherActions = system_action_count

        self.statistics = {'supervised_turns': 0, 'total_turns': 0}
    def __init__(self, args):
        """
        Initializes the internal structures of the Agenda-Based usr Simulator

        :param args: a dictionary containing an ontology, a database, and
                     other necessary arguments. 'ontology' and 'database'
                     are mandatory. Optional entries: 'um', 'patience',
                     'pop_distribution', 'slot_confuse_prob',
                     'op_confuse_prob', 'value_confuse_prob',
                     'goal_slot_selection_weights', 'nlu', 'nlg',
                     'nlg_model_path', 'nlg_metadata_path', 'goals_path',
                     'policy_file' and 'us_has_initiative'.
        :raises AttributeError: if the ontology or the database is missing
        :raises ValueError: if the ontology or database is of an unsupported
                            type, or if the CamRest nlg is requested without
                            both a model path and a metadata path
        """

        super(AgendaBasedUS, self).__init__()

        if 'ontology' not in args:
            raise AttributeError('AgendaBasedUS: Please provide ontology!')
        if 'database' not in args:
            raise AttributeError('AgendaBasedUS: Please provide database!')

        ontology = args['ontology']
        database = args['database']

        self.nlu = None
        self.nlg = None
        self.dialogue_turn = 0
        self.us_has_initiative = False
        self.policy = None
        self.goals_path = None

        # The user model attribute is only set when a model is supplied.
        if 'um' in args and args['um'] is not None:
            self.user_model = args['um']

        self.ontology = None
        if isinstance(ontology, Ontology):
            self.ontology = ontology
        elif isinstance(ontology, str):
            self.ontology = Ontology(ontology)
        else:
            raise ValueError('Unacceptable ontology type %s ' % ontology)

        # A string database is dispatched on its file extension.
        self.database = None
        if isinstance(database, DataBase):
            self.database = database
        elif isinstance(database, str) and database.endswith('.db'):
            self.database = SQLDataBase(database)
        elif isinstance(database, str) and database.endswith('.json'):
            self.database = JSONDataBase(database)
        else:
            raise ValueError('Unacceptable database type %s ' % database)

        # Default values for patience and probabilities. They are set
        # BEFORE the arguments are read so that configured values are not
        # overwritten by the defaults.
        self.patience = 3
        self.pop_distribution = [1.0]
        self.slot_confuse_prob = 0.0
        self.op_confuse_prob = 0.0
        self.value_confuse_prob = 0.0

        # Configured values override the defaults
        if 'patience' in args:
            self.patience = args['patience']
        if 'pop_distribution' in args:
            self.pop_distribution = args['pop_distribution']
        if 'slot_confuse_prob' in args:
            self.slot_confuse_prob = args['slot_confuse_prob']
        if 'op_confuse_prob' in args:
            self.op_confuse_prob = args['op_confuse_prob']
        if 'value_confuse_prob' in args:
            self.value_confuse_prob = args['value_confuse_prob']

        self.goal_slot_selection_weights = None
        if 'goal_slot_selection_weights' in args:
            self.goal_slot_selection_weights = \
                args['goal_slot_selection_weights']

        if 'nlu' in args:
            nlu_args = {
                'ontology': self.ontology,
                'database': self.database}

            if args['nlu'] == 'CamRest':
                self.nlu = CamRestNLU(nlu_args)

            elif args['nlu'] == 'slot_filling':
                self.nlu = SlotFillingNLU(nlu_args)

        if 'nlg' in args:
            if args['nlg'] == 'CamRest':
                # Missing paths are reported through the ValueError below
                # rather than through a KeyError.
                if args.get('nlg_model_path') and \
                        args.get('nlg_metadata_path'):
                    self.nlg = \
                        CamRestNLG({'model_path': args['nlg_model_path']})
                else:
                    raise ValueError('ABUS: Cannot initialize CamRest nlg '
                                     'without a model path AND a metadata '
                                     'path.')

            elif args['nlg'] == 'slot_filling':
                self.nlg = SlotFillingNLG()

        if 'goals_path' in args:
            self.goals_path = args['goals_path']

        if 'policy_file' in args:
            self.load(args['policy_file'])

        if 'us_has_initiative' in args:
            self.us_has_initiative = args['us_has_initiative']

        self.curr_patience = self.patience

        self.agenda = agenda.Agenda()
        self.error_model = error_model.ErrorModel(self.ontology, self.database,
                                                  self.slot_confuse_prob,
                                                  self.op_confuse_prob,
                                                  self.value_confuse_prob)

        self.goal_generator = goal.GoalGenerator({
            'ontology': self.ontology,
            'database': self.database,
            'goals_file': self.goals_path
        })
        self.goal = None
        self.offer_made = False
        self.prev_offer_name = None

        # Store previous system actions to keep track of patience
        self.prev_system_acts = None
    def __init__(self, args):
        """
        Parses the arguments in the dictionary and initializes the appropriate
        models for dialogue State Tracking and dialogue Policy.

        :param args: the configuration file parsed into a dictionary
        """

        super(DialogueManagerGeneric, self).__init__()

        if 'settings' not in args:
            raise AttributeError(
                'DialogueManagerGeneric: Please provide settings (config)!')
        if 'ontology' not in args:
            raise AttributeError(
                'DialogueManagerGeneric: Please provide ontology!')
        if 'database' not in args:
            raise AttributeError(
                'DialogueManagerGeneric: Please provide database!')
        if 'domain' not in args:
            raise AttributeError(
                'DialogueManagerGeneric: Please provide domain!')

        settings = args['settings']
        ontology = args['ontology']
        database = args['database']
        domain = args['domain']

        agent_id = 0
        if 'agent_id' in args:
            agent_id = int(args['agent_id'])

        agent_role = 'system'
        if 'agent_role' in args:
            agent_role = args['agent_role']

        self.settings = settings

        self.TRAIN_DST = False
        self.TRAIN_POLICY = False

        self.MAX_DB_RESULTS = 10

        self.DSTracker = None
        self.DSTracker_info = {}

        self.policy = None
        self.policy_info = {}

        self.policy_path = None
        self.ontology = None
        self.database = None
        self.domain = None

        self.agent_id = agent_id
        self.agent_role = agent_role

        self.dialogue_counter = 0
        self.CALCULATE_SLOT_ENTROPIES = True

        if isinstance(ontology, Ontology):
            self.ontology = ontology
        elif isinstance(ontology, str):
            self.ontology = Ontology(ontology)
        else:
            raise ValueError('Unacceptable ontology type %s ' % ontology)

        if isinstance(database, DataBase):
            self.database = database

        elif isinstance(database, str):
            if database[-3:] == '.db':
                self.database = SQLDataBase(database)
            elif database[-5:] == '.json':
                self.database = JSONDataBase(database)
            else:
                raise ValueError('Unacceptable database type %s ' % database)

        else:
            raise ValueError('Unacceptable database type %s ' % database)
                
        if args and args['policy']:
            if 'domain' in self.settings['DIALOGUE']:
                self.domain = self.settings['DIALOGUE']['domain']
            else:
                raise ValueError(
                    'domain is not specified in DIALOGUE at config.')

            if 'calculate_slot_entropies' in args:
                self.CALCULATE_SLOT_ENTROPIES = \
                    bool(args['calculate_slot_entropies'])

            if 'package' in args['policy'] and 'class' in args['policy']:
                self.policy_info = args['policy']

                if 'global_arguments' in args['settings']['GENERAL']:
                    if 'arguments' not in self.policy_info:
                        self.policy_info['arguments'] = {}

                    self.policy_info['arguments'].update(
                        args['settings']['GENERAL']['global_arguments']
                    )

                if 'train' in self.policy_info['arguments']:
                    self.TRAIN_POLICY = \
                        bool(self.policy_info['arguments']['train'])

                if 'policy_path' in self.policy_info['arguments']:
                    self.policy_path = \
                        self.policy_info['arguments']['policy_path']

                self.policy_info['arguments']['agent_role'] = self.agent_role

                # Replace ontology and database strings with the actual
                # objects to avoid repetitions (these won't change).
                if 'ontology' in self.policy_info['arguments']:
                    self.policy_info['arguments']['ontology'] = self.ontology

                if 'database' in self.policy_info['arguments']:
                    self.policy_info['arguments']['database'] = self.database

                self.policy = ConversationalGenericAgent.load_module(
                    self.policy_info['package'],
                    self.policy_info['class'],
                    self.policy_info['arguments']
                )

            else:
                raise ValueError('DialogueManagerGeneric: Cannot instantiate'
                                 'dialogue policy!')

        # DST Settings
        if 'DST' in args and args:
            if 'package' in args['DST'] and 'class' in args['DST']:
                self.DSTracker_info['package'] = args['DST']['package']
                self.DSTracker_info['class'] = args['DST']['class']

                self.DSTracker_info['args'] = {}

                if 'global_arguments' in args['settings']['GENERAL']:
                    self.DSTracker_info['args'] = \
                        args['settings']['GENERAL']['global_arguments']

                if 'arguments' in args['DST']:
                    self.DSTracker_info['args']. \
                        update(args['DST']['arguments'])

                self.DSTracker = ConversationalGenericAgent.load_module(
                    self.DSTracker_info['package'],
                    self.DSTracker_info['class'],
                    self.DSTracker_info['args']
                )

            else:
                raise ValueError('DialogueManagerGeneric: Cannot instantiate'
                                 'dialogue state tracker!')

        # Default to dummy DST, if no information is provided
        else:
            dst_args = dict(
                zip(
                    ['ontology', 'database', 'domain'],
                    [self.ontology, self.database, domain]))
            self.DSTracker = SlotFillingDST(dst_args)

        self.training = self.TRAIN_DST or self.TRAIN_POLICY

        self.load('')
Ejemplo n.º 10
0
    def __init__(self, args):
        """
        Load the ontology and database, create the keyword patterns, and
        preprocess the database so that we avoid some computations at runtime.

        :param args: dictionary that must contain:
                     'ontology': an Ontology object, or a str that is passed
                     to Ontology()
                     'database': a DataBase object, or a str ending in '.db'
                     (loaded with SQLDataBase) or '.json' (loaded with
                     JSONDataBase)
        :raises AttributeError: if 'ontology' or 'database' is missing from
                                args, or if the provided database is empty
        :raises ValueError: if the ontology or database is of an unacceptable
                            type, or if no table name can be retrieved from
                            the database
        """
        super(SlotFillingNLU, self).__init__()

        self.ontology = None
        self.database = None
        self.requestable_only_slots = None
        self.slot_values = None

        if 'ontology' not in args:
            raise AttributeError('SlotFillingNLU: Please provide ontology!')
        if 'database' not in args:
            raise AttributeError('SlotFillingNLU: Please provide database!')

        ontology = args['ontology']
        database = args['database']

        if isinstance(ontology, Ontology):
            self.ontology = ontology
        elif isinstance(ontology, str):
            self.ontology = Ontology(ontology)
        else:
            raise ValueError('Unacceptable ontology type %s ' % ontology)

        if isinstance(database, DataBase):
            self.database = database

        elif not database:
            # An empty value (None, '', ...) used to leave self.database as
            # None and fail a few lines below with an opaque error on
            # NoneType. Fail here instead, with the same exception type and
            # an explicit message.
            raise AttributeError('SlotFillingNLU: Please provide database!')

        elif isinstance(database, str):
            if database.endswith('.db'):
                self.database = SQLDataBase(database)
            elif database.endswith('.json'):
                self.database = JSONDataBase(database)
            else:
                raise ValueError('Unacceptable database type %s '
                                 % database)
        else:
            raise ValueError('Unacceptable database type %s ' % database)

        # In order to work for simulated users, we need access to possible
        # values of requestable slots
        # NOTE(review): this requires the database object to expose
        # SQL_connection; whether JSONDataBase does is not visible from
        # here - confirm.
        cursor = self.database.SQL_connection.cursor()

        print('SlotFillingNLU: Preprocessing Database... '
              '(do not use SlotFillingNLU with large databases!)')

        # Get table name (second column of the first sqlite_master row)
        db_result = cursor.execute("select * from sqlite_master "
                                   "where type = 'table';").fetchall()
        if not (db_result and db_result[0] and db_result[0][1]):
            raise ValueError(
                'dialogue Manager cannot specify Table Name from database '
                '{0}'.format(self.database.db_file_name))

        db_table_name = db_result[0][1]

        self.slot_values = {}

        # Quote the table name as an SQL identifier so that names that are
        # not bare identifiers (spaces, dashes, keywords) still work.
        quoted_table_name = '"' + db_table_name.replace('"', '""') + '"'

        # Get all entries in the database
        all_items = cursor.execute("select * from " +
                                   quoted_table_name + ";").fetchall()

        # Column names are the same for every row, so read them only once.
        slot_names = [column[0] for column in cursor.description]
        ignored_slots = ('id', 'signature', 'description')

        # Values already collected per slot; gives constant-time duplicate
        # checks while self.slot_values keeps the first-seen order.
        seen_values = {}
        total_items = len(all_items)

        for count, item in enumerate(all_items, start=1):
            for slot, value in zip(slot_names, item):
                if slot in ignored_slots:
                    continue

                if slot not in seen_values:
                    seen_values[slot] = set()
                    self.slot_values[slot] = []

                if value not in seen_values[slot]:
                    seen_values[slot].add(value)
                    self.slot_values[slot].append(value)

            if count % 2000 == 0:
                print(f'{float(count / total_items) * 100}% done')

        print('SlotFillingNLU: Done!')

        # For this SlotFillingNLU create a list of requestable-only to reduce
        # computational load
        self.requestable_only_slots = \
            [slot for slot in self.ontology.ontology['requestable']
             if slot not in self.ontology.ontology['informable']] + ['name']

        # Keyword / phrase lists, one per dialogue act
        self.bye_pattern = ['bye', 'goodbye', 'exit', 'quit', 'stop']
        self.hi_pattern = ['hi', 'hello']
        self.welcome_pattern = ['welcome', 'how may i help']
        self.deny_pattern = ['no']
        self.negate_pattern = ['is not']
        self.confirm_pattern = ['so is']
        self.repeat_pattern = ['repeat']
        self.ack_pattern = ['ok']
        self.restart_pattern = ['start over']
        self.affirm_pattern = ['yes']
        self.thankyou_pattern = ['thank you']
        self.reqmore_pattern = ['tell me more']
        self.expl_conf_pattern = ['alright']
        self.reqalts_pattern = ['anything else']
        self.select_pattern = ['you prefer']

        self.dontcare_pattern = ['anything', 'any', 'i do not care',
                                 'i dont care', 'dont care', 'dontcare',
                                 'it does not matter', 'it doesnt matter',
                                 'does not matter', 'doesnt matter']

        self.request_pattern = ['what', 'which', 'where', 'how', 'would']

        self.cant_help_pattern = ['can not help', 'cannot help', 'cant help']

        # Translation table that deletes every punctuation character except
        # '$', '_', '&' and '-'. Note that '.' IS deleted: the previous code
        # removed it from the set and then appended it again.
        kept_symbols = '$_&-'
        punctuation = ''.join(
            char for char in string.punctuation if char not in kept_symbols)
        self.punctuation_remover = str.maketrans('', '', punctuation)
Ejemplo n.º 11
0
    def __init__(self, args):
        """
        Initialise the user Simulator. Here we initialize structures that
        we need throughout the life of the DTL user Simulator.

        :param args: dictionary containing ontology, database, and policy file
                     under the keys 'ontology', 'database' and 'policy_file'
                     (all required), and optionally 'patience'
        :raises AttributeError: if 'ontology', 'database' or 'policy_file' is
                                missing from args
        :raises ValueError: if the ontology or database is neither an object
                            of the expected class nor an acceptable string
        """
        super(DTLUserSimulator, self).__init__()

        # Validate the required arguments before touching any state
        if 'ontology' not in args:
            raise AttributeError('DTLUserSimulator: Please provide ontology!')
        if 'database' not in args:
            raise AttributeError('DTLUserSimulator: Please provide database!')
        if 'policy_file' not in args:
            raise AttributeError('DTLUserSimulator: Please provide policy '
                                 'file!')

        ontology = args['ontology']
        database = args['database']
        policy_file = args['policy_file']

        # self.load is defined outside this view; presumably it populates
        # self.policy from policy_file - TODO confirm
        self.policy = None
        self.load(policy_file)

        # Accept either a ready Ontology object or a string handed to
        # Ontology()
        self.ontology = None
        if isinstance(ontology, Ontology):
            self.ontology = ontology
        elif isinstance(ontology, str):
            self.ontology = Ontology(ontology)
        else:
            raise ValueError('Unacceptable ontology type %s ' % ontology)

        # Accept either a ready DataBase object or a string; the string's
        # suffix selects the database class ('.db' or '.json')
        self.database = None
        if isinstance(database, DataBase):
            self.database = database

        elif isinstance(database, str):
            if database[-3:] == '.db':
                self.database = SQLDataBase(database)
            elif database[-5:] == '.json':
                self.database = JSONDataBase(database)
            else:
                raise ValueError('Unacceptable database type %s ' % database)
        else:
            raise ValueError('Unacceptable database type %s ' % database)

        self.input_system_acts = None
        self.goal = None

        # The goal generator receives the resolved ontology and database
        # objects rather than the raw argument values
        self.goal_generator = GoalGenerator({
            'ontology': self.ontology,
            'database': self.database
        })

        # Default patience is 3 unless overridden through args['patience']
        self.patience = 3

        if 'patience' in args:
            self.patience = args['patience']

        # The current patience starts at the configured patience value
        self.curr_patience = self.patience
        self.prev_sys_acts = None

        self.goal_met = False
        self.offer_made = False