def __init__(self, args):
    """
    Set up the random policy with its domain ontology.

    :param args: dictionary holding the domain ontology under 'ontology',
                 either as an Ontology instance or as a path string
    :raises ValueError: if no ontology is provided or its type is not
                        supported
    """
    super(RandomPolicy, self).__init__()

    if 'ontology' not in args:
        raise ValueError('No ontology provided for RandomPolicy!')
    domain_ontology = args['ontology']

    self.ontology = None
    if not isinstance(domain_ontology, (Ontology, str)):
        raise ValueError('Unacceptable ontology type %s ' % domain_ontology)

    # Reuse an Ontology instance as-is; build one when given a path.
    self.ontology = (domain_ontology
                     if isinstance(domain_ontology, Ontology)
                     else Ontology(domain_ontology))

    # Dialogue act (intent) names kept for this policy.
    self.intents = ['welcomemsg', 'inform', 'request', 'hello', 'bye',
                    'repeat', 'offer']
def __init__(self, args):
    """
    Initialize the internal structures of the SlotFillingDST.

    Resolves the ontology and the database (objects or paths), retrieves
    the database table name and creates the dialogue state.

    :param args: dictionary with the keys 'ontology' (Ontology or path),
                 'database' (DataBase, or a path ending in '.db' or
                 '.json') and 'domain'
    :raises AttributeError: if any of the three keys is missing
    :raises ValueError: if the ontology or database has an unsupported type
    """
    super(SlotFillingDST, self).__init__()

    if 'ontology' not in args:
        raise AttributeError('SlotFillingDST: Please provide ontology!')
    if 'database' not in args:
        raise AttributeError('SlotFillingDST: Please provide database!')
    if 'domain' not in args:
        raise AttributeError('SlotFillingDST: Please provide domain!')

    ontology = args['ontology']
    database = args['database']
    domain = args['domain']

    self.ontology = None
    if isinstance(ontology, Ontology):
        self.ontology = ontology
    elif isinstance(ontology, str):
        self.ontology = Ontology(ontology)
    else:
        # The 1-tuple keeps '%' from failing when the value is a tuple.
        raise ValueError('Unacceptable ontology type %s ' % (ontology,))

    self.database = None
    if isinstance(database, DataBase):
        self.database = database
    elif isinstance(database, str):
        if database.endswith('.db'):
            self.database = SQLDataBase(database)
        elif database.endswith('.json'):
            self.database = JSONDataBase(database)
        else:
            raise ValueError('Unacceptable database type %s ' % (database,))
    else:
        raise ValueError('Unacceptable database type %s ' % (database,))

    # Get Table name
    self.db_table_name = self.database.get_table_name()

    # NOTE(review): this attribute carried an unexplained "This will raise
    # an error!" comment; confirm whether DB_ITEMS is still needed.
    self.DB_ITEMS = 0

    self.domain = domain
    if domain not in ['CamRest', 'SlotFilling']:
        print('Warning! domain has not been defined. Using Slot-Filling '
              'dialogue State')

    # Known and unknown domains alike get a slot-filling state over the
    # ontology's system-requestable slots.
    self.DState = SlotFillingDialogueState(
        {'slots': self.ontology.ontology['system_requestable']})
def __init__(self, args):
    """
    Prepare the CamRest NLU.

    Resolves the ontology, builds the IOB tag list over the requestable
    slots, and sets up the punctuation remover and the "don't care"
    phrases.

    :param args: dictionary with an 'ontology' entry (Ontology instance or
                 path) and an optional 'train_online' flag
    :raises AttributeError: if no ontology is provided
    :raises ValueError: if the ontology has an unsupported type
    """
    super(CamRestNLU, self).__init__(args)

    self.ontology = None
    self.database = None

    if 'ontology' not in args:
        raise AttributeError('camrest_nlu: Please provide an ontology!')

    onto = args['ontology']
    if not isinstance(onto, (Ontology, str)):
        raise ValueError('Unacceptable ontology type %s ' % onto)
    self.ontology = onto if isinstance(onto, Ontology) else Ontology(onto)

    self.punctuation_remover = str.maketrans('', '', string.punctuation)

    self.TRAIN_ONLINE = (bool(args['train_online'])
                         if 'train_online' in args else False)

    # All 'B-inform-<slot>' tags first, then all 'I-inform-<slot>' tags.
    requestable = self.ontology.ontology['requestable']
    self.iob_tag_list = [prefix + slot
                         for prefix in ('B-inform-', 'I-inform-')
                         for slot in requestable]

    self.dontcare_pattern = [
        'anything', 'any', 'i do not care', 'i dont care', 'dont care',
        'dontcare', 'it does not matter', 'it doesnt matter',
        'does not matter', 'doesnt matter',
    ]
def __init__(self, args):
    """
    Set up the handcrafted policy with its domain ontology.

    :param args: dictionary holding the domain ontology under 'ontology',
                 either as an Ontology instance or as a path string
    :raises ValueError: if no ontology is provided or its type is not
                        supported
    """
    super(HandcraftedPolicy, self).__init__()

    if 'ontology' not in args:
        raise ValueError('No ontology provided for HandcraftedPolicy!')
    domain_ontology = args['ontology']

    self.ontology = None
    if not isinstance(domain_ontology, (Ontology, str)):
        raise ValueError('Unacceptable ontology type %s ' % domain_ontology)

    # Reuse an Ontology instance as-is; build one when given a path.
    self.ontology = (domain_ontology
                     if isinstance(domain_ontology, Ontology)
                     else Ontology(domain_ontology))
def __init__(self, args):
    """
    Initialize the internal structures of the Goal Generator and check its
    inputs.

    Resolves the ontology and the database (objects or paths) and
    optionally loads goals from a file. It then reads the table name,
    column (slot) names and row count of the first table listed in the
    database's sqlite_master.

    :param args: the goal generator's arguments. 'ontology' (Ontology or
                 path) and 'database' (DataBase, or a path ending in '.db'
                 or '.json') are required; 'goals_file' is optional.
    :raises ValueError: if the ontology or database is missing or has an
                        unsupported type, or if no table name can be found
    """
    if 'ontology' not in args:
        raise ValueError('Goal Generator called without an ontology!')
    if 'database' not in args:
        raise ValueError('Goal Generator called without a database!')

    ontology = args['ontology']
    database = args['database']

    self.ontology = None
    if isinstance(ontology, Ontology):
        self.ontology = ontology
    elif isinstance(ontology, str):
        self.ontology = Ontology(ontology)
    else:
        # The 1-tuple keeps '%' from failing when the value is a tuple.
        raise ValueError('Unacceptable ontology type %s ' % (ontology,))

    self.database = None
    if isinstance(database, DataBase):
        self.database = database
    elif isinstance(database, str):
        if database.endswith('.db'):
            self.database = SQLDataBase(database)
        elif database.endswith('.json'):
            self.database = JSONDataBase(database)
        else:
            raise ValueError('Unacceptable database type %s ' % (database,))
    else:
        raise ValueError('Unacceptable database type %s ' % (database,))

    self.goals_file = args['goals_file'] if 'goals_file' in args else None

    self.goals = None
    if self.goals_file:
        self.load_goals(self.goals_file)

    # The schema queries below use SQLite's sqlite_master table through the
    # database's SQL_connection. The cursor is closed even on error.
    cursor = self.database.SQL_connection.cursor()
    try:
        # Get Table name
        result = cursor.execute(
            "select * from sqlite_master where type = 'table';").fetchall()
        if result and result[0] and result[0][1]:
            self.db_table_name = result[0][1]
        else:
            raise ValueError('Goal Generator cannot specify Table Name from '
                             'database {0}'.format(
                                 self.database.db_file_name))

        # Quote the identifier so that table names containing spaces,
        # quotes or SQL keywords do not break the statements.
        table = '"' + self.db_table_name.replace('"', '""') + '"'

        # Dummy query, issued only to read the column names
        cursor.execute("SELECT * FROM " + table + " LIMIT 1;")
        self.slot_names = [column[0] for column in cursor.description]

        self.db_row_count = cursor.execute(
            "SELECT COUNT(*) FROM " + table + ";").fetchall()[0][0]
    finally:
        cursor.close()
def __init__(self, args):
    """
    Parse the configuration dictionary and build the dialogue state tracker
    and the dialogue policy.

    :param args: the configuration file parsed into a dictionary.
                 Required keys: 'settings', 'ontology' (Ontology or path),
                 'database' (DataBase, or a path ending in '.db' or
                 '.json'), 'domain' and 'policy' (a falsy 'policy' skips
                 policy creation). Optional keys: 'agent_id', 'agent_role',
                 'calculate_slot_entropies' and 'DST'.
    :raises AttributeError: if settings, ontology, database or domain is
                            missing
    :raises ValueError: on an unsupported ontology, database or policy
                        type; when settings['DIALOGUE'] has no domain; or
                        when the ludwig policy or CamRest DST config is
                        incomplete
    """
    super(DialogueManager, self).__init__()

    if 'settings' not in args:
        raise AttributeError(
            'DialogueManager: Please provide settings (config)!')
    if 'ontology' not in args:
        raise AttributeError('DialogueManager: Please provide ontology!')
    if 'database' not in args:
        raise AttributeError('DialogueManager: Please provide database!')
    if 'domain' not in args:
        raise AttributeError('DialogueManager: Please provide domain!')

    settings = args['settings']
    ontology = args['ontology']
    database = args['database']
    domain = args['domain']

    agent_id = 0
    if 'agent_id' in args:
        agent_id = int(args['agent_id'])

    agent_role = 'system'
    if 'agent_role' in args:
        agent_role = args['agent_role']

    self.settings = settings

    self.TRAIN_DST = False
    self.TRAIN_POLICY = False

    self.MAX_DB_RESULTS = 10

    self.DSTracker = None
    self.policy = None
    self.policy_path = None
    self.ontology = None
    self.database = None
    self.domain = None

    self.agent_id = agent_id
    self.agent_role = agent_role

    self.dialogue_counter = 0
    self.CALCULATE_SLOT_ENTROPIES = True

    if isinstance(ontology, Ontology):
        self.ontology = ontology
    elif isinstance(ontology, str):
        self.ontology = Ontology(ontology)
    else:
        # The 1-tuple keeps '%' from failing when the value is a tuple.
        raise ValueError('Unacceptable ontology type %s ' % (ontology,))

    if isinstance(database, DataBase):
        self.database = database
    elif isinstance(database, str):
        if database.endswith('.db'):
            self.database = SQLDataBase(database)
        elif database.endswith('.json'):
            self.database = JSONDataBase(database)
        else:
            raise ValueError('Unacceptable database type %s ' % (database,))
    else:
        raise ValueError('Unacceptable database type %s ' % (database,))

    policy_args = args['policy']
    if policy_args:
        if 'domain' in self.settings['DIALOGUE']:
            self.domain = self.settings['DIALOGUE']['domain']
        else:
            raise ValueError(
                'domain is not specified in DIALOGUE at config.')

        if 'calculate_slot_entropies' in args:
            self.CALCULATE_SLOT_ENTROPIES = \
                bool(args['calculate_slot_entropies'])

        # Constructor arguments shared by the calculated, supervised and
        # reinforcement-learning policies.
        common_args = {
            'ontology': self.ontology,
            'database': self.database,
            'agent_id': self.agent_id,
            'agent_role': self.agent_role,
            'domain': self.domain,
        }

        def rl_policy_args(defaults):
            """
            Build a reinforcement-learning policy's constructor arguments.

            Returns the common arguments plus alpha, gamma, epsilon,
            alpha_decay and epsilon_decay. Each hyperparameter is read as
            a float from the policy config when present; otherwise it
            comes from `defaults`, a 5-tuple in that same order.
            """
            hyper = dict(zip(
                ('alpha', 'gamma', 'epsilon', 'alpha_decay',
                 'epsilon_decay'),
                defaults))
            for config_key, arg_name in (
                    ('learning_rate', 'alpha'),
                    ('discount_factor', 'gamma'),
                    ('exploration_rate', 'epsilon'),
                    ('learning_decay_rate', 'alpha_decay'),
                    ('exploration_decay_rate', 'epsilon_decay')):
                if config_key in policy_args:
                    hyper[arg_name] = float(policy_args[config_key])
            return dict(common_args, **hyper)

        # Q-learning and REINFORCE receive None for unset hyperparameters;
        # Minimax-Q and WoLF-PHC receive these explicit defaults.
        no_defaults = (None, None, None, None, None)
        explicit_defaults = (0.25, 0.95, 0.25, 0.9995, 0.995)

        policy_type = policy_args['type']
        if policy_type == 'handcrafted':
            self.policy = HandcraftedPolicy({'ontology': self.ontology})
        elif policy_type == 'q_learning':
            self.policy = QPolicy(rl_policy_args(no_defaults))
        elif policy_type == 'minimax_q':
            self.policy = MinimaxQPolicy(rl_policy_args(explicit_defaults))
        elif policy_type == 'wolf_phc':
            self.policy = WoLFPHCPolicy(rl_policy_args(explicit_defaults))
        elif policy_type == 'reinforce':
            self.policy = ReinforcePolicy(rl_policy_args(no_defaults))
        elif policy_type == 'calculated':
            self.policy = CalculatedPolicy(dict(common_args))
        elif policy_type == 'supervised':
            self.policy = SupervisedPolicy(dict(common_args))
        elif policy_type == 'ludwig':
            # .get() so that a missing policy_path reaches the explicit
            # error below instead of raising KeyError.
            if policy_args.get('policy_path'):
                print('DialogueManager: Instantiate your ludwig-based '
                      'policy here')
            else:
                raise ValueError(
                    'Cannot find policy_path in the config for dialogue '
                    'policy.')
        else:
            raise ValueError(
                'DialogueManager: Unsupported policy type {0}!'
                .format(policy_type))

        if 'train' in policy_args:
            self.TRAIN_POLICY = bool(policy_args['train'])

        if 'policy_path' in policy_args:
            self.policy_path = policy_args['policy_path']

    # DST Settings
    dst_config = args['DST'] if 'DST' in args else None
    if dst_config and dst_config.get('dst'):
        if dst_config['dst'] == 'CamRest':
            dst_policy = dst_config.get('policy') or {}
            if dst_policy.get('model_path') and \
                    dst_policy.get('metadata_path'):
                self.DSTracker = CamRestDST(
                    {'model_path': dst_policy['model_path']})
            else:
                raise ValueError(
                    'Cannot find model_path or metadata_path in the '
                    'config for dialogue state tracker.')

    # Default to dummy DST
    if not self.DSTracker:
        self.DSTracker = SlotFillingDST({
            'ontology': self.ontology,
            'database': self.database,
            'domain': domain,
        })

    self.training = self.TRAIN_DST or self.TRAIN_POLICY

    self.load('')
def __init__(self, args):
    """
    Initialize the parameters and internal structures of the WoLF-PHC
    policy.

    :param args: dictionary containing the dialogue_policy's settings.
                 'ontology' (Ontology or path) and 'database' (DataBase or
                 path) are required. The optional settings and their
                 defaults are: 'agent_role' ('system'), 'alpha' (0.2),
                 'gamma' (0.95), 'epsilon' (0.95), 'alpha_decay' (0.995)
                 and 'epsilon_decay' (0.9995). A hyperparameter passed as
                 None also falls back to its default.
    :raises ValueError: if the ontology or database is missing or has an
                        unsupported type
    """
    super(WoLFPHCPolicy, self).__init__()

    self.ontology = None
    if 'ontology' in args:
        ontology = args['ontology']
        if isinstance(ontology, Ontology):
            self.ontology = ontology
        elif isinstance(ontology, str):
            self.ontology = Ontology(ontology)
        else:
            raise ValueError('WoLFPHCPolicy Unacceptable '
                             'ontology type %s ' % (ontology,))
    else:
        raise ValueError('WoLFPHCPolicy: No ontology provided')

    self.database = None
    if 'database' in args:
        database = args['database']
        if isinstance(database, DataBase):
            self.database = database
        elif isinstance(database, str):
            # NOTE(review): other components pick SQLDataBase/JSONDataBase
            # by file extension; confirm DataBase is usable directly here.
            self.database = DataBase(database)
        else:
            raise ValueError('WoLFPHCPolicy: Unacceptable '
                             'database type %s ' % (database,))
    else:
        raise ValueError('WoLFPHCPolicy: No database provided')

    def setting(key, default):
        # An explicit None counts as "not set", so callers that forward
        # unset hyperparameters as None still get the defaults.
        value = args[key] if key in args else None
        return default if value is None else value

    self.agent_role = \
        args['agent_role'] if 'agent_role' in args else 'system'
    self.alpha = setting('alpha', 0.2)
    self.gamma = setting('gamma', 0.95)
    self.epsilon = setting('epsilon', 0.95)
    self.alpha_decay_rate = setting('alpha_decay', 0.995)
    self.exploration_decay_rate = setting('epsilon_decay', 0.9995)

    self.IS_GREEDY_POLICY = False

    # TODO: Put these as arguments in the config
    self.d_win = 0.0025
    self.d_lose = 0.01

    self.is_training = False

    self.Q = {}
    self.pi = {}
    self.mean_pi = {}
    self.state_counter = {}

    self.pp = pprint.PrettyPrinter(width=160)  # For debug!

    # System and user expert policies (optional)
    self.warmup_policy = None
    self.warmup_simulator = None

    if self.agent_role == 'system':
        # Put your system expert dialogue_policy here
        self.warmup_policy = \
            slot_filling_policy.HandcraftedPolicy({
                'ontology': self.ontology})

    elif self.agent_role == 'user':
        usim_args = {'ontology': self.ontology, 'database': self.database}
        # Put your user expert dialogue_policy here
        self.warmup_simulator = AgendaBasedUS(usim_args)

    # Plato does not use action masks (rules to define which actions are
    # valid from each state), so training can be harder. A smaller action
    # set makes it easier.

    # Does not include inform and request, which are modelled together
    # with their arguments
    self.dstc2_acts_sys = [
        'offer', 'canthelp', 'affirm', 'deny', 'ack', 'bye', 'reqmore',
        'welcomemsg', 'expl-conf', 'select', 'repeat', 'confirm-domain',
        'confirm'
    ]

    # Does not include inform and request, which are modelled together
    # with their arguments
    self.dstc2_acts_usr = [
        'affirm', 'negate', 'deny', 'ack', 'thankyou', 'bye', 'reqmore',
        'hello', 'expl-conf', 'repeat', 'reqalts', 'restart', 'confirm'
    ]

    # Extract lists of slots that are frequently used
    self.informable_slots = \
        deepcopy(list(self.ontology.ontology['informable'].keys()))
    self.requestable_slots = \
        deepcopy(self.ontology.ontology['requestable'])
    self.system_requestable_slots = \
        deepcopy(self.ontology.ontology['system_requestable'])

    # Each side's action count is its plain dialogue acts, plus one
    # slot-parameterised action per requestable and per
    # system-requestable slot.
    n_slot_actions = \
        len(self.requestable_slots) + len(self.system_requestable_slots)
    n_sys_actions = len(self.dstc2_acts_sys) + n_slot_actions
    n_usr_actions = len(self.dstc2_acts_usr) + n_slot_actions

    if self.agent_role == 'system':
        self.NActions = n_sys_actions
        self.NOtherActions = n_usr_actions

    elif self.agent_role == 'user':
        self.NActions = n_usr_actions
        self.NOtherActions = n_sys_actions

    self.statistics = {'supervised_turns': 0, 'total_turns': 0}
def __init__(self, args):
    """
    Initialize the internal structures of the Agenda-Based user Simulator.

    :param args: a dictionary containing an 'ontology' and a 'database'
                 (objects or paths). Optional keys:
                 - 'um', 'patience', 'pop_distribution',
                   'slot_confuse_prob', 'op_confuse_prob',
                   'value_confuse_prob', 'goal_slot_selection_weights'
                 - 'nlu' ('CamRest' / 'slot_filling')
                 - 'nlg' ('CamRest' / 'slot_filling'); 'CamRest' also needs
                   'nlg_model_path' and 'nlg_metadata_path'
                 - 'goals_path', 'policy_file', 'us_has_initiative'
    :raises AttributeError: if the ontology or database is missing
    :raises ValueError: on an unsupported ontology / database type, or a
                        CamRest NLG without model and metadata paths
    """
    super(AgendaBasedUS, self).__init__()

    if 'ontology' not in args:
        raise AttributeError('AgendaBasedUS: Please provide ontology!')
    if 'database' not in args:
        raise AttributeError('AgendaBasedUS: Please provide database!')

    ontology = args['ontology']
    database = args['database']

    um = None
    if 'um' in args:
        um = args['um']

    self.nlu = None
    self.nlg = None
    self.dialogue_turn = 0
    self.us_has_initiative = False
    self.policy = None
    self.goals_path = None

    if um is not None:
        self.user_model = um

    self.ontology = None
    if isinstance(ontology, Ontology):
        self.ontology = ontology
    elif isinstance(ontology, str):
        self.ontology = Ontology(ontology)
    else:
        # The 1-tuple keeps '%' from failing when the value is a tuple.
        raise ValueError('Unacceptable ontology type %s ' % (ontology,))

    self.database = None
    if isinstance(database, DataBase):
        self.database = database
    elif isinstance(database, str):
        if database.endswith('.db'):
            self.database = SQLDataBase(database)
        elif database.endswith('.json'):
            self.database = JSONDataBase(database)
        else:
            raise ValueError('Unacceptable database type %s ' % (database,))
    else:
        raise ValueError('Unacceptable database type %s ' % (database,))

    # Default values first, so that any value supplied in args below takes
    # precedence (and reaches the error model built at the end).
    self.patience = 3
    self.pop_distribution = [1.0]
    self.slot_confuse_prob = 0.0
    self.op_confuse_prob = 0.0
    self.value_confuse_prob = 0.0

    if 'patience' in args:
        self.patience = args['patience']
    if 'pop_distribution' in args:
        self.pop_distribution = args['pop_distribution']
    if 'slot_confuse_prob' in args:
        self.slot_confuse_prob = args['slot_confuse_prob']
    if 'op_confuse_prob' in args:
        self.op_confuse_prob = args['op_confuse_prob']
    if 'value_confuse_prob' in args:
        self.value_confuse_prob = args['value_confuse_prob']

    self.goal_slot_selection_weights = None
    if 'goal_slot_selection_weights' in args:
        self.goal_slot_selection_weights = \
            args['goal_slot_selection_weights']

    if 'nlu' in args:
        nlu_args = {'ontology': self.ontology, 'database': self.database}

        if args['nlu'] == 'CamRest':
            self.nlu = CamRestNLU(nlu_args)

        elif args['nlu'] == 'slot_filling':
            self.nlu = SlotFillingNLU(nlu_args)

    if 'nlg' in args:
        if args['nlg'] == 'CamRest':
            # .get() so that missing paths reach the explicit error below
            # instead of raising KeyError.
            if args.get('nlg_model_path') and args.get('nlg_metadata_path'):
                self.nlg = \
                    CamRestNLG({'model_path': args['nlg_model_path']})
            else:
                raise ValueError('ABUS: Cannot initialize CamRest nlg '
                                 'without a model path AND a metadata '
                                 'path.')

        elif args['nlg'] == 'slot_filling':
            self.nlg = SlotFillingNLG()

    if 'goals_path' in args:
        self.goals_path = args['goals_path']

    if 'policy_file' in args:
        self.load(args['policy_file'])

    if 'us_has_initiative' in args:
        self.us_has_initiative = args['us_has_initiative']

    self.curr_patience = self.patience

    self.agenda = agenda.Agenda()
    self.error_model = error_model.ErrorModel(self.ontology,
                                              self.database,
                                              self.slot_confuse_prob,
                                              self.op_confuse_prob,
                                              self.value_confuse_prob)

    self.goal_generator = goal.GoalGenerator({
        'ontology': self.ontology,
        'database': self.database,
        'goals_file': self.goals_path
    })
    self.goal = None
    self.offer_made = False
    self.prev_offer_name = None

    # Store previous system actions to keep track of patience
    self.prev_system_acts = None
def __init__(self, args):
    """
    Parse the configuration dictionary and load the dialogue state tracker
    and dialogue policy modules it names.

    :param args: the configuration file parsed into a dictionary. Must
                 contain 'settings', 'ontology' (Ontology or path),
                 'database' (DataBase, or a path ending in '.db' or
                 '.json'), 'domain' and 'policy'. A truthy 'policy' must
                 name a 'package' and 'class' (optional 'arguments').
                 settings['GENERAL']['global_arguments'], if present, is
                 merged into the arguments of both components. An optional
                 'DST' entry with 'package' / 'class' / 'arguments' selects
                 the tracker; without 'DST' a SlotFillingDST is used.
    :raises AttributeError: if settings, ontology, database or domain is
                            missing
    :raises ValueError: on an unsupported ontology / database type, a
                        missing DIALOGUE domain, or a policy / DST config
                        without 'package' and 'class'
    """
    super(DialogueManagerGeneric, self).__init__()

    if 'settings' not in args:
        raise AttributeError(
            'DialogueManagerGeneric: Please provide settings (config)!')
    if 'ontology' not in args:
        raise AttributeError(
            'DialogueManagerGeneric: Please provide ontology!')
    if 'database' not in args:
        raise AttributeError(
            'DialogueManagerGeneric: Please provide database!')
    if 'domain' not in args:
        raise AttributeError(
            'DialogueManagerGeneric: Please provide domain!')

    settings = args['settings']
    ontology = args['ontology']
    database = args['database']
    domain = args['domain']

    agent_id = 0
    if 'agent_id' in args:
        agent_id = int(args['agent_id'])

    agent_role = 'system'
    if 'agent_role' in args:
        agent_role = args['agent_role']

    self.settings = settings

    self.TRAIN_DST = False
    self.TRAIN_POLICY = False

    self.MAX_DB_RESULTS = 10

    self.DSTracker = None
    self.DSTracker_info = {}

    self.policy = None
    self.policy_info = {}
    self.policy_path = None

    self.ontology = None
    self.database = None
    self.domain = None

    self.agent_id = agent_id
    self.agent_role = agent_role

    self.dialogue_counter = 0
    self.CALCULATE_SLOT_ENTROPIES = True

    if isinstance(ontology, Ontology):
        self.ontology = ontology
    elif isinstance(ontology, str):
        self.ontology = Ontology(ontology)
    else:
        # The 1-tuple keeps '%' from failing when the value is a tuple.
        raise ValueError('Unacceptable ontology type %s ' % (ontology,))

    if isinstance(database, DataBase):
        self.database = database
    elif isinstance(database, str):
        if database.endswith('.db'):
            self.database = SQLDataBase(database)
        elif database.endswith('.json'):
            self.database = JSONDataBase(database)
        else:
            raise ValueError('Unacceptable database type %s ' % (database,))
    else:
        raise ValueError('Unacceptable database type %s ' % (database,))

    general_settings = args['settings']['GENERAL']

    if args['policy']:
        if 'domain' in self.settings['DIALOGUE']:
            self.domain = self.settings['DIALOGUE']['domain']
        else:
            raise ValueError(
                'domain is not specified in DIALOGUE at config.')

        if 'calculate_slot_entropies' in args:
            self.CALCULATE_SLOT_ENTROPIES = \
                bool(args['calculate_slot_entropies'])

        if 'package' in args['policy'] and 'class' in args['policy']:
            self.policy_info = args['policy']

            # Ensure an arguments dict exists even when no global arguments
            # are configured; the lookups below need it.
            if 'arguments' not in self.policy_info:
                self.policy_info['arguments'] = {}
            policy_arguments = self.policy_info['arguments']

            if 'global_arguments' in general_settings:
                policy_arguments.update(
                    general_settings['global_arguments'])

            if 'train' in policy_arguments:
                self.TRAIN_POLICY = bool(policy_arguments['train'])

            if 'policy_path' in policy_arguments:
                self.policy_path = policy_arguments['policy_path']

            policy_arguments['agent_role'] = self.agent_role

            # Replace ontology and database strings with the actual
            # objects to avoid repetitions (these won't change).
            if 'ontology' in policy_arguments:
                policy_arguments['ontology'] = self.ontology

            if 'database' in policy_arguments:
                policy_arguments['database'] = self.database

            self.policy = ConversationalGenericAgent.load_module(
                self.policy_info['package'],
                self.policy_info['class'],
                policy_arguments
            )

        else:
            raise ValueError('DialogueManagerGeneric: Cannot instantiate '
                             'dialogue policy!')

    # DST Settings
    if 'DST' in args:
        dst_config = args['DST']
        if 'package' in dst_config and 'class' in dst_config:
            self.DSTracker_info['package'] = dst_config['package']
            self.DSTracker_info['class'] = dst_config['class']

            self.DSTracker_info['args'] = {}
            if 'global_arguments' in general_settings:
                # Copy: updating the settings' dict in place would leak the
                # DST-specific arguments into global_arguments for any
                # later reader of these settings.
                self.DSTracker_info['args'] = \
                    dict(general_settings['global_arguments'])

            if 'arguments' in dst_config:
                self.DSTracker_info['args'].update(dst_config['arguments'])

            self.DSTracker = ConversationalGenericAgent.load_module(
                self.DSTracker_info['package'],
                self.DSTracker_info['class'],
                self.DSTracker_info['args']
            )

        else:
            raise ValueError('DialogueManagerGeneric: Cannot instantiate '
                             'dialogue state tracker!')

    # Default to dummy DST, if no information is provided
    else:
        self.DSTracker = SlotFillingDST({
            'ontology': self.ontology,
            'database': self.database,
            'domain': domain,
        })

    self.training = self.TRAIN_DST or self.TRAIN_POLICY

    self.load('')
def __init__(self, args):
    """
    Load the ontology and database, create the keyword patterns, and
    preprocess the database to avoid some computations at runtime.

    Preprocessing collects, for every column of the first table listed in
    sqlite_master (except 'id', 'signature' and 'description'), the
    distinct values in first-seen order into self.slot_values.

    :param args: dictionary with 'ontology' (Ontology or path) and
                 'database' (DataBase, or a path ending in '.db' or
                 '.json')
    :raises AttributeError: if the ontology or database key is missing
    :raises ValueError: on an unsupported ontology / database type, or if
                        no table name can be found in the database
    """
    super(SlotFillingNLU, self).__init__()

    self.ontology = None
    self.database = None
    self.requestable_only_slots = None
    self.slot_values = None

    if 'ontology' not in args:
        raise AttributeError('SlotFillingNLU: Please provide ontology!')
    if 'database' not in args:
        raise AttributeError('SlotFillingNLU: Please provide database!')

    ontology = args['ontology']
    database = args['database']

    if isinstance(ontology, Ontology):
        self.ontology = ontology
    elif isinstance(ontology, str):
        self.ontology = Ontology(ontology)
    else:
        # The 1-tuple keeps '%' from failing when the value is a tuple.
        raise ValueError('Unacceptable ontology type %s ' % (ontology,))

    # NOTE(review): a falsy database leaves self.database as None, and the
    # cursor() call below then fails with AttributeError.
    if database:
        if isinstance(database, DataBase):
            self.database = database
        elif isinstance(database, str):
            if database.endswith('.db'):
                self.database = SQLDataBase(database)
            elif database.endswith('.json'):
                self.database = JSONDataBase(database)
            else:
                raise ValueError(
                    'Unacceptable database type %s ' % (database,))
        else:
            raise ValueError('Unacceptable database type %s ' % (database,))

    # In order to work for simulated users, we need access to possible
    # values of requestable slots
    cursor = self.database.SQL_connection.cursor()

    print('SlotFillingNLU: Preprocessing Database... '
          '(do not use SlotFillingNLU with large databases!)')

    try:
        # Get table name
        db_result = cursor.execute("select * from sqlite_master "
                                   "where type = 'table';").fetchall()
        if db_result and db_result[0] and db_result[0][1]:
            db_table_name = db_result[0][1]

            self.slot_values = {}

            # Quote the identifier so unusual table names stay valid SQL.
            quoted_table = '"' + db_table_name.replace('"', '""') + '"'

            # Get all entries in the database
            all_items = cursor.execute(
                "select * from " + quoted_table + ";").fetchall()

            # Column names are the same for every row: read them once.
            slot_names = [column[0] for column in cursor.description]
            skipped_slots = {'id', 'signature', 'description'}

            # Set mirror of each slot_values list: O(1) duplicate checks
            # while the lists keep first-seen order.
            seen_values = {}

            for i, item in enumerate(all_items, start=1):
                for slot, value in zip(slot_names, item):
                    if slot in skipped_slots:
                        continue

                    if slot not in self.slot_values:
                        self.slot_values[slot] = []
                        seen_values[slot] = set()

                    if value not in seen_values[slot]:
                        seen_values[slot].add(value)
                        self.slot_values[slot].append(value)

                if i % 2000 == 0:
                    print(f'{float(i/len(all_items))*100}% done')

            print('SlotFillingNLU: Done!')

        else:
            raise ValueError(
                'dialogue Manager cannot specify Table Name from database '
                '{0}'.format(self.database.db_file_name))
    finally:
        cursor.close()

    # For this SlotFillingNLU create a list of requestable-only to reduce
    # computational load
    self.requestable_only_slots = \
        [slot for slot in self.ontology.ontology['requestable']
         if slot not in self.ontology.ontology['informable']] + ['name']

    self.bye_pattern = ['bye', 'goodbye', 'exit', 'quit', 'stop']

    self.hi_pattern = ['hi', 'hello']

    self.welcome_pattern = ['welcome', 'how may i help']

    self.deny_pattern = ['no']

    self.negate_pattern = ['is not']

    self.confirm_pattern = ['so is']

    self.repeat_pattern = ['repeat']

    self.ack_pattern = ['ok']

    self.restart_pattern = ['start over']

    self.affirm_pattern = ['yes']

    self.thankyou_pattern = ['thank you']

    self.reqmore_pattern = ['tell me more']

    self.expl_conf_pattern = ['alright']

    self.reqalts_pattern = ['anything else']

    self.select_pattern = ['you prefer']

    self.dontcare_pattern = ['anything', 'any', 'i do not care',
                             'i dont care', 'dont care', 'dontcare',
                             'it does not matter', 'it doesnt matter',
                             'does not matter', 'doesnt matter']

    self.request_pattern = ['what', 'which', 'where', 'how', 'would']

    self.cant_help_pattern = ['can not help', 'cannot help', 'cant help']

    # Strip all punctuation except '$', '_', '&' and '-'. The '.' is
    # removed from the set and then appended back, so it is still
    # stripped (the net effect only moves it to the end of the set).
    punctuation = string.punctuation.replace('$', '')
    punctuation = punctuation.replace('_', '')
    punctuation = punctuation.replace('.', '')
    punctuation = punctuation.replace('&', '')
    punctuation = punctuation.replace('-', '')
    punctuation += '.'
    self.punctuation_remover = str.maketrans('', '', punctuation)
def __init__(self, args):
    """
    Initialise the DTL user Simulator and the structures it needs for its
    whole lifetime.

    :param args: dictionary with 'ontology' (Ontology or path), 'database'
                 (DataBase, or a path ending in '.db' or '.json'),
                 'policy_file', and an optional 'patience' (default 3)
    :raises AttributeError: if ontology, database or policy file is missing
    :raises ValueError: if the ontology or database has an unsupported type
    """
    super(DTLUserSimulator, self).__init__()

    # Checked in this order; the first missing key is reported.
    required = (
        ('ontology', 'DTLUserSimulator: Please provide ontology!'),
        ('database', 'DTLUserSimulator: Please provide database!'),
        ('policy_file', 'DTLUserSimulator: Please provide policy file!'),
    )
    for key, message in required:
        if key not in args:
            raise AttributeError(message)

    onto = args['ontology']
    db = args['database']
    policy_path = args['policy_file']

    self.policy = None
    self.load(policy_path)

    self.ontology = None
    if not isinstance(onto, (Ontology, str)):
        raise ValueError('Unacceptable ontology type %s ' % onto)
    self.ontology = onto if isinstance(onto, Ontology) else Ontology(onto)

    self.database = None
    if not isinstance(db, (DataBase, str)):
        raise ValueError('Unacceptable database type %s ' % db)
    if isinstance(db, DataBase):
        self.database = db
    elif db.endswith('.db'):
        self.database = SQLDataBase(db)
    elif db.endswith('.json'):
        self.database = JSONDataBase(db)
    else:
        raise ValueError('Unacceptable database type %s ' % db)

    self.input_system_acts = None
    self.goal = None

    self.goal_generator = GoalGenerator({
        'ontology': self.ontology,
        'database': self.database
    })

    self.patience = args['patience'] if 'patience' in args else 3
    self.curr_patience = self.patience

    self.prev_sys_acts = None

    self.goal_met = False
    self.offer_made = False